from __future__ import annotations

import logging

from datasets import load_dataset

from sentence_transformers.evaluation import SequentialEvaluator
from sentence_transformers.models import Pooling, Transformer
from sentence_transformers.sparse_encoder import SparseEncoder
from sentence_transformers.sparse_encoder.evaluation.SparseNanoBEIREvaluator import SparseNanoBEIREvaluator
from sentence_transformers.sparse_encoder.losses import CSRLoss
from sentence_transformers.sparse_encoder.models import CSRSparsity
from sentence_transformers.sparse_encoder.trainer import SparseEncoderTrainer
from sentence_transformers.sparse_encoder.training_args import (
    SparseEncoderTrainingArguments,
)
from sentence_transformers.training_args import BatchSamplers

# Set up logging
logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)


def main():
    # 1. Initialize the model components
    model_name = "microsoft/mpnet-base"
    transformer = Transformer(model_name)
    # transformer.requires_grad_(False)  # Freeze the transformer model
    pooling = Pooling(transformer.get_word_embedding_dimension(), pooling_mode="mean")
    csr_sparsity = CSRSparsity(
        input_dim=transformer.get_word_embedding_dimension(),
        hidden_dim=4 * transformer.get_word_embedding_dimension(),
        k=256,  # Number of top values to keep
        k_aux=512,  # Number of top values for auxiliary loss
    )
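    # Roughly speaking, CSRSparsity learns an overcomplete code on top of the pooled embedding:
    # texts are projected into a hidden_dim-sized space (4 * 768 = 3072 for mpnet-base) and only
    # the top-k activations are kept, so each output embedding has at most k non-zero values.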

    # Create the SparseEncoder model
    model = SparseEncoder(modules=[transformer, pooling, csr_sparsity])

    # 2. Load the NQ dataset: https://huggingface.co./datasets/sentence-transformers/natural-questions
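    # Each example is a (query, answer) text pair; the trainer feeds these columns to the loss
    # as anchor/positive inputs.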
    logging.info("Read the Natural Questions training dataset")
    full_dataset = load_dataset("sentence-transformers/natural-questions", split="train").select(range(100_000))
    dataset_dict = full_dataset.train_test_split(test_size=1_000, seed=12)
    train_dataset = dataset_dict["train"]
    eval_dataset = dataset_dict["test"]
    logging.info(train_dataset)
    logging.info(eval_dataset)

    # 3. Initialize the loss
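    # CSRLoss roughly combines a reconstruction term for the sparse autoencoder, an auxiliary
    # term on the residual (weighted by beta), and an in-batch-negatives ranking term
    # (weighted by gamma).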
    loss = CSRLoss(
        model=model,
        beta=0.1,  # Weight for auxiliary loss
        gamma=1,  # Weight for ranking loss
        scale=20.0,  # Scale for similarity computation
    )

    # 4. Define an evaluator for use during training, so progress can be tracked alongside the evaluation loss.
    evaluators = []
    for k_dim in [16, 32, 64, 128, 256]:
        evaluators.append(SparseNanoBEIREvaluator(["msmarco", "nfcorpus", "nq"], truncate_dim=k_dim))
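    # Use the score of the last evaluator (truncate_dim=256) as the main metric to track.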
    dev_evaluator = SequentialEvaluator(evaluators, main_score_function=lambda scores: scores[-1])
    dev_evaluator(model)

    # 5. Define the training arguments
    run_name = "sparse-mpnet-base-nq-fresh"
    training_args = SparseEncoderTrainingArguments(
        output_dir=f"models/{run_name}",
        num_train_epochs=1,
        per_device_train_batch_size=32,
        per_device_eval_batch_size=32,
        warmup_ratio=0.1,
        fp16=False,  # Set to True if your GPU supports FP16 but not BF16
        bf16=True,  # Set to False if your GPU does not support BF16
        batch_sampler=BatchSamplers.NO_DUPLICATES,  # The ranking term inside CSRLoss uses in-batch negatives, so avoid duplicate samples in a batch
        logging_steps=200,
        eval_strategy="steps",
        eval_steps=400,
        save_strategy="steps",
        save_steps=400,
        learning_rate=4e-5,
        optim="adamw_torch",
        weight_decay=1e-4,
        adam_epsilon=6.25e-10,
        run_name=run_name,
    )

    # 6. Initialize the trainer
    trainer = SparseEncoderTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        loss=loss,
        evaluator=dev_evaluator,
    )

    # Train model
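    # During training, the loss on eval_dataset and the dev_evaluator scores are both computed
    # every eval_steps (see eval_strategy/eval_steps above).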
    trainer.train()

    # 7. Evaluate the model performance again after training
    dev_evaluator(model)

    # 8. Save the trained & evaluated model locally
    model.save_pretrained(f"models/{run_name}/final")
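
    # Quick usage sketch (illustrative only): encode a couple of made-up sentences with the
    # trained model to confirm it behaves like a regular SentenceTransformer-style encoder.
    sentences = [
        "What is the capital of France?",
        "Paris is the capital and largest city of France.",
    ]
    embeddings = model.encode(sentences)
    logging.info(f"Encoded {len(sentences)} sentences; embedding shape: {embeddings.shape}")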

    # 9. Push the trained model to the Hugging Face Hub
    model.push_to_hub(run_name)

if __name__ == "__main__":
    main()