potion-retrieval-32M Model Card

<div align="center"> <img width="35%" alt="Model2Vec logo" src="https://raw.githubusercontent.com/MinishLab/model2vec/main/assets/images/logo_v2.png"> </div>

This is a Model2Vec model optimized for retrieval tasks. It was finetuned from potion-base-32M using a modified version of the training approach described in this blogpost. It uses static embeddings, allowing text embeddings to be computed orders of magnitude faster on both GPU and CPU. It is designed for applications where computational resources are limited or where real-time performance is required.

Installation

Install model2vec using pip:

pip install model2vec

Usage

Load this model using the from_pretrained method:

from model2vec import StaticModel
# Load the pretrained Model2Vec model
model = StaticModel.from_pretrained("minishlab/potion-retrieval-32M")
# Compute text embeddings
embeddings = model.encode(["Example sentence"])
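Since the model is tuned for retrieval, a typical pattern is to embed a query and a set of candidate documents with the same model and rank the documents by cosine similarity. A minimal sketch (the example documents and the cosine_sim helper are made up for illustration; only StaticModel.from_pretrained and encode come from the snippet above):

import numpy as np
from model2vec import StaticModel

model = StaticModel.from_pretrained("minishlab/potion-retrieval-32M")

# Embed one query and a few candidate documents (made-up examples).
query_emb = model.encode(["how do static embeddings work"])
doc_embs = model.encode([
    "Static embeddings map each token to a fixed vector that is averaged over the text.",
    "The weather in Paris is mild in spring.",
])

# Rank documents by cosine similarity to the query.
def cosine_sim(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    a = a / np.linalg.norm(a, axis=1, keepdims=True)
    b = b / np.linalg.norm(b, axis=1, keepdims=True)
    return a @ b.T

scores = cosine_sim(query_emb, doc_embs)[0]
print(scores.argsort()[::-1])  # document indices, most relevant first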

How it works

Model2Vec creates a small, static model that outperforms other static embedding models by a large margin on all tasks on MTEB. This model is pre-trained using Tokenlearn. It is created with the following steps:

  • Distillation: first, a model is distilled from a sentence transformer model using Model2Vec.
  • Training data creation: the sentence transformer model is used to create training data by computing mean output embeddings on a large corpus.
  • Training: the distilled model is trained on this training data using Tokenlearn.
  • Post-training re-regularization: after training, the model is re-regularized by weighting the tokens based on their frequency, applying PCA, and finally applying SIF weighting (see the sketch after this list).
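
For intuition, here is a minimal, illustrative sketch of such a re-regularization step, assuming a token embedding matrix of shape (vocab_size, dim) and an array of corpus token counts. The Zipf-style weight, the full-dimensionality PCA, and the SIF constant a = 1e-3 are common choices assumed here for illustration, not taken from the actual Model2Vec implementation:

import numpy as np
from sklearn.decomposition import PCA

def re_regularize(embeddings: np.ndarray, token_counts: np.ndarray, a: float = 1e-3) -> np.ndarray:
    """Illustrative re-regularization of a (vocab_size, dim) token embedding matrix."""
    # 1. Weight tokens by frequency rank (Zipf-style weight; assumed form).
    ranks = np.argsort(np.argsort(-token_counts)) + 1  # rank 1 = most frequent token
    embeddings = embeddings * np.log1p(ranks)[:, None]
    # 2. Re-project the space with PCA, keeping the full dimensionality.
    embeddings = PCA(n_components=embeddings.shape[1]).fit_transform(embeddings)
    # 3. Apply SIF weighting: w(t) = a / (a + p(t)), down-weighting frequent tokens.
    probs = token_counts / token_counts.sum()
    return embeddings * (a / (a + probs))[:, None]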

Results

The results for this model are shown in the table below. The full Model2Vec results for all models can be found on the Model2Vec results page.

Task                    Score
Average (All)           49.73
Average (MTEB)          49.76
Classification          59.56
Clustering              30.55
Pair Classification     76.38
Reranking               50.05
Retrieval               36.35
STS                     73.22
Summarization           28.85
PEARL                   49.31
Word Similarity         50.02

Additional Resources

Library Authors

Model2Vec was developed by the Minish Lab team, consisting of Stephan Tulkens and Thomas van Dongen.

Citation

Please cite the Model2Vec repository if you use this model in your work:

@software{minishlab2024model2vec,
  author = {Stephan Tulkens and Thomas van Dongen},
  title = {Model2Vec: The Fastest State-of-the-Art Static Embeddings in the World},
  year = {2024},
  url = {https://github.com/MinishLab/model2vec}
}

Reproducibility

The following script can be used to reproduce this model. All credit for the fine-tuning method and code goes to Tom Aarsen and his blogpost. We made a few small modifications to the original code:

  • We start from a pre-trained Model2Vec model (potion-base-32M).
  • We reduce the dataset size by a factor of 10. During experimentation we found that the model converges without needing the full dataset.
  • We lower the learning rate and train for 3 epochs instead of 1. A learning rate that is too high wipes out the benefit of starting from a pre-trained model.

import random
import logging
from datasets import load_dataset, Dataset, DatasetDict
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    SentenceTransformerModelCardData,
)
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers, MultiDatasetBatchSamplers
from sentence_transformers.evaluation import NanoBEIREvaluator
from sentence_transformers.models.StaticEmbedding import StaticEmbedding
import wandb

logging.basicConfig(
    format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO
)
random.seed(12)


def load_train_eval_datasets(factor: int = 1):
    """
    Loads train and eval datasets from disk if available. Otherwise, downloads 
    them from Hugging Face, preprocesses, and saves them to disk. If `factor` is 
    greater than 1, returns a fraction (1/factor) of each dataset subset.

    :param factor: The factor by which the data is reduced. If factor=1, no reduction is performed.
    :return: (train_dataset: DatasetDict, eval_dataset: DatasetDict)
    """
    try:
        # Try loading from disk
        train_dataset = DatasetDict.load_from_disk("datasets/train_dataset")
        eval_dataset = DatasetDict.load_from_disk("datasets/eval_dataset")
    except FileNotFoundError:
        print("Prebuilt datasets not found on disk. Building from scratch...")

        print("Loading gooaq dataset...")
        gooaq_dataset = load_dataset("sentence-transformers/gooaq", split="train")
        gooaq_dataset_dict = gooaq_dataset.train_test_split(test_size=10_000, seed=12)
        gooaq_train_dataset: Dataset = gooaq_dataset_dict["train"]
        gooaq_eval_dataset: Dataset = gooaq_dataset_dict["test"]
        print("Loaded gooaq dataset.")

        print("Loading msmarco dataset...")
        msmarco_dataset = load_dataset(
            "sentence-transformers/msmarco-co-condenser-margin-mse-sym-mnrl-mean-v1",
            "triplet",
            split="train"
        )
        msmarco_dataset_dict = msmarco_dataset.train_test_split(test_size=10_000, seed=12)
        msmarco_train_dataset: Dataset = msmarco_dataset_dict["train"]
        msmarco_eval_dataset: Dataset = msmarco_dataset_dict["test"]
        print("Loaded msmarco dataset.")

        print("Loading squad dataset...")
        squad_dataset = load_dataset("sentence-transformers/squad", split="train")
        squad_dataset_dict = squad_dataset.train_test_split(test_size=10_000, seed=12)
        squad_train_dataset: Dataset = squad_dataset_dict["train"]
        squad_eval_dataset: Dataset = squad_dataset_dict["test"]
        print("Loaded squad dataset.")

        print("Loading s2orc dataset...")
        s2orc_dataset = load_dataset(
            "sentence-transformers/s2orc", 
            "title-abstract-pair", 
            split="train[:100000]"  # limit to 100k
        )
        s2orc_dataset_dict = s2orc_dataset.train_test_split(test_size=10_000, seed=12)
        s2orc_train_dataset: Dataset = s2orc_dataset_dict["train"]
        s2orc_eval_dataset: Dataset = s2orc_dataset_dict["test"]
        print("Loaded s2orc dataset.")

        print("Loading allnli dataset...")
        allnli_train_dataset = load_dataset(
            "sentence-transformers/all-nili", 
            "triplet", 
            split="train"
        )
        allnli_eval_dataset = load_dataset(
            "sentence-transformers/all-nili", 
            "triplet", 
            split="dev"
        )
        print("Loaded allnli dataset.")

        print("Loading paq dataset...")
        paq_dataset = load_dataset("sentence-transformers/paq", split="train")
        paq_dataset_dict = paq_dataset.train_test_split(test_size=10_000, seed=12)
        paq_train_dataset: Dataset = paq_dataset_dict["train"]
        paq_eval_dataset: Dataset = paq_dataset_dict["test"]
        print("Loaded paq dataset.")

        print("Loading trivia_qa dataset...")
        trivia_qa = load_dataset("sentence-transformers/trivia-qa", split="train")
        trivia_qa_dataset_dict = trivia_qa.train_test_split(test_size=5_000, seed=12)
        trivia_qa_train_dataset: Dataset = trivia_qa_dataset_dict["train"]
        trivia_qa_eval_dataset: Dataset = trivia_qa_dataset_dict["test"]
        print("Loaded trivia_qa dataset.")

        print("Loading msmarco_10m dataset...")
        msmarco_10m_dataset = load_dataset("bclavie/msmarco-10m-triplets", split="train")
        msmarco_10m_dataset_dict = msmarco_10m_dataset.train_test_split(
            test_size=10_000, seed=12
        )
        msmarco_10m_train_dataset: Dataset = msmarco_10m_dataset_dict["train"]
        msmarco_10m_eval_dataset: Dataset = msmarco_10m_dataset_dict["test"]
        print("Loaded msmarco_10m dataset.")

        print("Loading swim_ir dataset...")
        swim_ir_dataset = load_dataset(
            "nthakur/swim-ir-monolingual", 
            "en", 
            split="train"
        ).select_columns(["query", "text"])
        swim_ir_dataset_dict = swim_ir_dataset.train_test_split(
            test_size=10_000, seed=12
        )
        swim_ir_train_dataset: Dataset = swim_ir_dataset_dict["train"]
        swim_ir_eval_dataset: Dataset = swim_ir_dataset_dict["test"]
        print("Loaded swim_ir dataset.")

        # NOTE: 20 negatives
        print("Loading pubmedqa dataset...")
        pubmedqa_dataset = load_dataset(
            "sentence-transformers/pubmedqa", 
            "triplet-20", 
            split="train"
        )
        pubmedqa_dataset_dict = pubmedqa_dataset.train_test_split(test_size=100, seed=12)
        pubmedqa_train_dataset: Dataset = pubmedqa_dataset_dict["train"]
        pubmedqa_eval_dataset: Dataset = pubmedqa_dataset_dict["test"]
        print("Loaded pubmedqa dataset.")

        # NOTE: A lot of overlap with anchor/positives
        print("Loading miracl dataset...")
        miracl_dataset = load_dataset(
            "sentence-transformers/miracl", 
            "en-triplet-all", 
            split="train"
        )
        miracl_dataset_dict = miracl_dataset.train_test_split(test_size=10_000, seed=12)
        miracl_train_dataset: Dataset = miracl_dataset_dict["train"]
        miracl_eval_dataset: Dataset = miracl_dataset_dict["test"]
        print("Loaded miracl dataset.")

        # NOTE: A lot of overlap with anchor/positives
        print("Loading mldr dataset...")
        mldr_dataset = load_dataset(
            "sentence-transformers/mldr", 
            "en-triplet-all", 
            split="train"
        )
        mldr_dataset_dict = mldr_dataset.train_test_split(test_size=10_000, seed=12)
        mldr_train_dataset: Dataset = mldr_dataset_dict["train"]
        mldr_eval_dataset: Dataset = mldr_dataset_dict["test"]
        print("Loaded mldr dataset.")

        # NOTE: A lot of overlap with anchor/positives
        print("Loading mr_tydi dataset...")
        mr_tydi_dataset = load_dataset(
            "sentence-transformers/mr-tydi", 
            "en-triplet-all", 
            split="train"
        )
        mr_tydi_dataset_dict = mr_tydi_dataset.train_test_split(test_size=10_000, seed=12)
        mr_tydi_train_dataset: Dataset = mr_tydi_dataset_dict["train"]
        mr_tydi_eval_dataset: Dataset = mr_tydi_dataset_dict["test"]
        print("Loaded mr_tydi dataset.")

        train_dataset = DatasetDict({
            "gooaq": gooaq_train_dataset,
            "msmarco": msmarco_train_dataset,
            "squad": squad_train_dataset,
            "s2orc": s2orc_train_dataset,
            "allnli": allnli_train_dataset,
            "paq": paq_train_dataset,
            "trivia_qa": trivia_qa_train_dataset,
            "msmarco_10m": msmarco_10m_train_dataset,
            "swim_ir": swim_ir_train_dataset,
            "pubmedqa": pubmedqa_train_dataset,
            "miracl": miracl_train_dataset,
            "mldr": mldr_train_dataset,
            "mr_tydi": mr_tydi_train_dataset,
        })
        eval_dataset = DatasetDict({
            "gooaq": gooaq_eval_dataset,
            "msmarco": msmarco_eval_dataset,
            "squad": squad_eval_dataset,
            "s2orc": s2orc_eval_dataset,
            "allnili": allnli_eval_dataset,
            "paq": paq_eval_dataset,
            "trivia_qa": trivia_qa_eval_dataset,
            "msmarco_10m": msmarco_10m_eval_dataset,
            "swim_ir": swim_ir_eval_dataset,
            "pubmedqa": pubmedqa_eval_dataset,
            "miracl": miracl_eval_dataset,
            "mldr": mldr_eval_dataset,
            "mr_tydi": mr_tydi_eval_dataset,
        })

        # Save to disk for next time
        train_dataset.save_to_disk("datasets/train_dataset")
        eval_dataset.save_to_disk("datasets/eval_dataset")

        # Quit to avoid memory overhead on large datasets
        quit()

    # Reduce the dataset if factor > 1
    if factor > 1:
        for subset_name in train_dataset:
            ds = train_dataset[subset_name].shuffle(seed=42)
            new_len = len(ds) // factor
            train_dataset[subset_name] = ds.select(range(new_len))

        for subset_name in eval_dataset:
            ds = eval_dataset[subset_name].shuffle(seed=42)
            new_len = len(ds) // factor
            eval_dataset[subset_name] = ds.select(range(new_len))

    return train_dataset, eval_dataset


def main():
    wandb.init(entity="minishlab", project="minishlab")
    # 1. Load a model to finetune
    static_embedding = StaticEmbedding.from_model2vec("minishlab/potion-base-32M")

    # 2. Initialize the SentenceTransformer model
    model_name = "potion-retrieval-32M"
    model = SentenceTransformer(
        modules=[static_embedding],
        model_card_data=SentenceTransformerModelCardData(
            language="en",
            license="MIT",
            model_name=model_name,
        ),
    )

    # 3. Load training & evaluation datasets
    # NOTE: we reduce the total dataset size by a factor of 10 
    train_dataset, eval_dataset = load_train_eval_datasets(factor=10)
    print(train_dataset)

    # 4. Define a loss function
    loss = MultipleNegativesRankingLoss(model)
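    # Wrapping in MatryoshkaLoss (next line) also optimizes embeddings truncated
    # to each of the listed dimensionalities, so shorter embeddings stay useful.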
    loss = MatryoshkaLoss(model, loss, matryoshka_dims=[32, 64, 128, 256, 512])

    # 5. Specify training arguments
    run_name = model_name
    epochs = 3
    lr = 0.05
    args = SentenceTransformerTrainingArguments(
        output_dir=f"models/{run_name}",
        num_train_epochs=epochs,
        per_device_train_batch_size=2048,
        per_device_eval_batch_size=2048,
        learning_rate=lr,
        warmup_ratio=0.1,
        fp16=False,
        bf16=True,
        batch_sampler=BatchSamplers.NO_DUPLICATES,
        multi_dataset_batch_sampler=MultiDatasetBatchSamplers.PROPORTIONAL,
        eval_strategy="steps",
        eval_steps=250,
        save_strategy="steps",
        save_steps=250,
        save_total_limit=2,
        logging_steps=250,
        logging_first_step=True,
        run_name=run_name,
        report_to=["wandb"],
        load_best_model_at_end=True,
        metric_for_best_model="eval_NanoBEIR_mean_cosine_ndcg@10",
        greater_is_better=True,
    )

    # 6. Create an evaluator & evaluate the base model
    evaluator = NanoBEIREvaluator()
    evaluator(model)

    # 7. Create a trainer & train
    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        loss=loss,
        evaluator=evaluator,
    )
    trainer.train()

    # 8. Evaluate the trained model and save
    evaluator(model)
    model.save_pretrained(f"models/{run_name}/final")


if __name__ == "__main__":
    main()

minishlab/potion-retrieval-32M

Author: minishlab


Created: 2025-01-23 15:05:16+00:00

Updated: 2025-01-29 11:00:09+00:00


Files (10)

.gitattributes
README.md
config.json
model.safetensors
modules.json
onnx/model.onnx ONNX
special_tokens_map.json
tokenizer.json
tokenizer_config.json
vocab.txt