from __future__ import annotations import logging from datasets import load_dataset from sentence_transformers.evaluation import SequentialEvaluator, SimilarityFunction from sentence_transformers.models import Pooling, Transformer from sentence_transformers.sparse_encoder import SparseEncoder from sentence_transformers.sparse_encoder.evaluation import ( SparseEmbeddingSimilarityEvaluator, ) from sentence_transformers.sparse_encoder.evaluation.SparseNanoBEIREvaluator import SparseNanoBEIREvaluator from sentence_transformers.sparse_encoder.losses import CSRLoss from sentence_transformers.sparse_encoder.models import CSRSparsity from sentence_transformers.sparse_encoder.trainer import SparseEncoderTrainer from sentence_transformers.sparse_encoder.training_args import ( SparseEncoderTrainingArguments, ) from sentence_transformers.training_args import BatchSamplers # Set up logging logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO) def main(): # Initialize model components model_name = "microsoft/mpnet-base" transformer = Transformer(model_name) # transformer.requires_grad_(False) # Freeze the transformer model pooling = Pooling(transformer.get_word_embedding_dimension(), pooling_mode="mean") csr_sparsity = CSRSparsity( input_dim=transformer.get_word_embedding_dimension(), hidden_dim=4 * transformer.get_word_embedding_dimension(), k=256, # Number of top values to keep k_aux=512, # Number of top values for auxiliary loss ) # Create the SparseEncoder model model = SparseEncoder(modules=[transformer, pooling, csr_sparsity]) # 2a. Load the NQ dataset: https://huggingface.co./datasets/sentence-transformers/natural-questions logging.info("Read the Natural Questions training dataset") full_dataset = load_dataset("sentence-transformers/natural-questions", split="train").select(range(100_000)) dataset_dict = full_dataset.train_test_split(test_size=1_000, seed=12) train_dataset = dataset_dict["train"] eval_dataset = dataset_dict["test"] logging.info(train_dataset) logging.info(eval_dataset) # 3. Initialize the loss loss = CSRLoss( model=model, beta=0.1, # Weight for auxiliary loss gamma=1, # Weight for ranking loss scale=20.0, # Scale for similarity computation ) # 4. Define an evaluator for use during training. This is useful to keep track of alongside the evaluation loss. evaluators = [] for k_dim in [16, 32, 64, 128, 256]: evaluators.append(SparseNanoBEIREvaluator(["msmarco", "nfcorpus", "nq"], truncate_dim=k_dim)) dev_evaluator = SequentialEvaluator(evaluators, main_score_function=lambda scores: scores[-1]) dev_evaluator(model) # Set up training arguments run_name = "sparse-mpnet-base-nq-fresh" training_args = SparseEncoderTrainingArguments( output_dir=f"models/{run_name}", num_train_epochs=1, per_device_train_batch_size=32, per_device_eval_batch_size=32, warmup_ratio=0.1, fp16=False, # Set to False if you get an error that your GPU can't run on FP16 bf16=True, # Set to True if you have a GPU that supports BF16 batch_sampler=BatchSamplers.NO_DUPLICATES, # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch logging_steps=200, eval_strategy="steps", eval_steps=400, save_strategy="steps", save_steps=400, learning_rate=4e-5, optim="adamw_torch", weight_decay=1e-4, adam_epsilon=6.25e-10, run_name=run_name, ) # Initialize trainer trainer = SparseEncoderTrainer( model=model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset, loss=loss, evaluator=dev_evaluator, ) # Train model trainer.train() # 7. Evaluate the model performance again after training dev_evaluator(model) # 8. Save the trained & evaluated model locally model.save_pretrained(f"models/{run_name}/final") model.push_to_hub(run_name) if __name__ == "__main__": main()