# Retrieval with cross-encoder reranking (retriever + BGE reranker + LangChain compression).
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
class DocumentRetrieverWithReranker:
    """Wrap a base retriever with a cross-encoder reranking stage.

    Documents fetched by ``retriever`` are re-scored by a HuggingFace
    cross-encoder (query, document) classifier and the ``top_n`` best are
    returned via LangChain's contextual-compression pipeline.

    Args:
        retriever: Any LangChain retriever used to fetch candidate documents.
        reranker_model_name: HuggingFace model id of the cross-encoder.
        top_n: Number of documents to keep after reranking.
    """

    def __init__(self, retriever, reranker_model_name="BAAI/bge-reranker-base", top_n=3):
        self.retriever = retriever
        self.reranker_model_name = reranker_model_name
        self.top_n = top_n
        # Cross-encoder scoring is compute-heavy; use the GPU when available.
        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(self.reranker_model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(self.reranker_model_name)
        self.model = self.model.to(self.device)
        # NOTE(review): CrossEncoderReranker expects a BaseCrossEncoder-like
        # object; we pass ``self`` and provide ``score`` below to satisfy it.
        self.compressor = CrossEncoderReranker(model=self, top_n=self.top_n)
        self.compression_retriever = ContextualCompressionRetriever(
            base_compressor=self.compressor, base_retriever=self.retriever
        )

    def score(self, text_pairs):
        """Score (query, document) pairs for LangChain's reranker.

        LangChain's ``CrossEncoderReranker`` calls ``model.score(...)`` (the
        ``BaseCrossEncoder`` interface), not ``model(...)``; without this
        delegate, reranking would fail with an AttributeError. Delegates to
        :meth:`__call__`.
        """
        return self(text_pairs)

    def __call__(self, pairs):
        """Return one relevance score (float) per (query, document) pair.

        Args:
            pairs: List of (query, document) text pairs.

        Returns:
            list[float]: Raw cross-encoder logits, one per input pair.
        """
        with torch.inference_mode():
            inputs = self.tokenizer(
                pairs,
                padding=True,
                truncation=True,
                return_tensors="pt",
                max_length=512,
            )
            inputs = inputs.to(self.device)
            # Single-logit head: flatten (batch, 1) logits to a 1-D score vector.
            scores = self.model(**inputs, return_dict=True).logits.view(-1).float()
            return scores.detach().cpu().tolist()

    def retrieve_and_rerank(self, query):
        """Fetch candidates for ``query``, rerank, and return the top_n docs."""
        return self.compression_retriever.invoke(query)
def pretty_print_docs(docs):
    """Print each document's content, separated by a 100-character dashed rule.

    Args:
        docs: Iterable of objects exposing a ``page_content`` attribute
            (e.g. LangChain ``Document`` instances).
    """
    separator = f"\n{'-' * 100}\n"
    rendered = []
    for index, doc in enumerate(docs, start=1):
        rendered.append(f"Document {index}:\n\n" + doc.page_content)
    print(separator.join(rendered))