tomaarsen (HF Staff) committed · commit 91b1760 · verified · Parent(s): 1c4db27

Create train_nq.py

Files changed (1):
  1. train_nq.py (+111, -0)

train_nq.py (new file):
from __future__ import annotations

import logging

from datasets import load_dataset

from sentence_transformers.evaluation import SequentialEvaluator
from sentence_transformers.models import Pooling, Transformer
from sentence_transformers.sparse_encoder import SparseEncoder
from sentence_transformers.sparse_encoder.evaluation.SparseNanoBEIREvaluator import SparseNanoBEIREvaluator
from sentence_transformers.sparse_encoder.losses import CSRLoss
from sentence_transformers.sparse_encoder.models import CSRSparsity
from sentence_transformers.sparse_encoder.trainer import SparseEncoderTrainer
from sentence_transformers.sparse_encoder.training_args import SparseEncoderTrainingArguments
from sentence_transformers.training_args import BatchSamplers

# Set up logging
logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)


def main():
    # 1. Initialize the model components: a transformer backbone, mean pooling,
    # and a CSR sparsity head that sparsifies the pooled embeddings
    model_name = "microsoft/mpnet-base"
    transformer = Transformer(model_name)
    # transformer.requires_grad_(False)  # Uncomment to freeze the transformer backbone
    pooling = Pooling(transformer.get_word_embedding_dimension(), pooling_mode="mean")
    csr_sparsity = CSRSparsity(
        input_dim=transformer.get_word_embedding_dimension(),
        hidden_dim=4 * transformer.get_word_embedding_dimension(),
        k=256,  # Number of top values to keep in the final embeddings
        k_aux=512,  # Number of top values used for the auxiliary loss
    )
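    # Descriptive aside: following the CSR (Contrastive Sparse Representation)
    # recipe, this head acts as a TopK sparse autoencoder on the pooled
    # embedding x, roughly z = topk(encode(x), k), so at most k latent
    # activations stay non-zero; k_aux plays the same role for the auxiliary
    # reconstruction term in the loss below.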

    # Create the SparseEncoder model from the three modules
    model = SparseEncoder(modules=[transformer, pooling, csr_sparsity])

    # 2. Load the NQ dataset: https://huggingface.co/datasets/sentence-transformers/natural-questions
    logging.info("Read the Natural Questions training dataset")
    full_dataset = load_dataset("sentence-transformers/natural-questions", split="train").select(range(100_000))
    dataset_dict = full_dataset.train_test_split(test_size=1_000, seed=12)
    train_dataset = dataset_dict["train"]
    eval_dataset = dataset_dict["test"]
    logging.info(train_dataset)
    logging.info(eval_dataset)

    # 3. Initialize the loss
    loss = CSRLoss(
        model=model,
        beta=0.1,  # Weight for the auxiliary reconstruction loss
        gamma=1.0,  # Weight for the ranking loss
        scale=20.0,  # Scale for the similarity computation in the ranking loss
    )
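    # Descriptive aside: CSRLoss combines a reconstruction objective with the
    # two weighted terms above, approximately
    #   total_loss = L_reconstruction + beta * L_aux + gamma * L_ranking,
    # where the ranking term scores query-answer pairs against in-batch negatives.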

    # 4. Define an evaluator for use during training. This is useful to keep track of alongside the evaluation loss.
    evaluators = []
    for k_dim in [16, 32, 64, 128, 256]:
        evaluators.append(SparseNanoBEIREvaluator(["msmarco", "nfcorpus", "nq"], truncate_dim=k_dim))
    dev_evaluator = SequentialEvaluator(evaluators, main_score_function=lambda scores: scores[-1])
    dev_evaluator(model)

    # 5. Set up the training arguments
    run_name = "sparse-mpnet-base-nq-fresh"
    training_args = SparseEncoderTrainingArguments(
        output_dir=f"models/{run_name}",
        num_train_epochs=1,
        per_device_train_batch_size=32,
        per_device_eval_batch_size=32,
        warmup_ratio=0.1,
        fp16=False,  # Set to True if your GPU supports FP16 but not BF16
        bf16=True,  # Set to False (and fp16 to True) if your GPU does not support BF16
        batch_sampler=BatchSamplers.NO_DUPLICATES,  # The in-batch negatives ranking loss benefits from no duplicate samples in a batch
        logging_steps=200,
        eval_strategy="steps",
        eval_steps=400,
        save_strategy="steps",
        save_steps=400,
        learning_rate=4e-5,
        optim="adamw_torch",
        weight_decay=1e-4,
        adam_epsilon=6.25e-10,
        run_name=run_name,
    )

    # 6. Initialize the trainer
    trainer = SparseEncoderTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        loss=loss,
        evaluator=dev_evaluator,
    )

    # Train the model
    trainer.train()

    # 7. Evaluate the model performance again after training
    dev_evaluator(model)

    # 8. Save the trained & evaluated model locally, then push it to the Hub
    model.save_pretrained(f"models/{run_name}/final")
    model.push_to_hub(run_name)


if __name__ == "__main__":
    main()
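
After training, the saved model can be loaded and used like any other Sentence Transformers model. Below is a minimal inference sketch, not part of the commit, assuming the local save path from the script above and the standard encode/similarity API of SparseEncoder:

from sentence_transformers.sparse_encoder import SparseEncoder

# Load the locally saved model (path used by the training script above)
model = SparseEncoder("models/sparse-mpnet-base-nq-fresh/final")

queries = ["who invented the telephone"]
documents = [
    "Alexander Graham Bell was awarded the first US patent for the telephone in 1876.",
    "The Nile is the longest river in Africa.",
]

# Each embedding has at most k=256 non-zero dimensions
query_embeddings = model.encode(queries)
document_embeddings = model.encode(documents)

# Score each query against each document; higher means more relevant
scores = model.similarity(query_embeddings, document_embeddings)
print(scores)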