# RAG/backend/reranking.py
# Cross-encoder reranking for retrieved documents.
# (Provenance: commit 939262b, "Added split files and tables".)
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
class DocumentRetrieverWithReranker:
    """Wrap a base retriever with a cross-encoder reranking stage.

    Documents returned by ``retriever`` are rescored with a HuggingFace
    sequence-classification cross-encoder (default: BAAI/bge-reranker-base)
    and the ``top_n`` highest-scoring documents are kept, via LangChain's
    ``CrossEncoderReranker`` + ``ContextualCompressionRetriever``.
    """

    def __init__(self, retriever, reranker_model_name="BAAI/bge-reranker-base", top_n=3):
        """
        Args:
            retriever: Any LangChain retriever providing candidate documents.
            reranker_model_name: HF hub id of the cross-encoder reranker.
            top_n: Number of documents to keep after reranking.
        """
        self.retriever = retriever
        self.reranker_model_name = reranker_model_name
        self.top_n = top_n
        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(self.reranker_model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(self.reranker_model_name)
        self.model = self.model.to(self.device)
        # CrossEncoderReranker calls `model.score(text_pairs)` (the LangChain
        # BaseCrossEncoder interface), so this class must expose `score` —
        # see the method below; passing only a __call__-able object would
        # raise AttributeError at query time.
        self.compressor = CrossEncoderReranker(model=self, top_n=self.top_n)
        self.compression_retriever = ContextualCompressionRetriever(
            base_compressor=self.compressor, base_retriever=self.retriever
        )

    def score(self, pairs):
        """Score (query, document) text pairs with the cross-encoder.

        This is the entry point LangChain's ``CrossEncoderReranker`` calls.

        Args:
            pairs: Sequence of (query, document) string pairs.

        Returns:
            list[float]: One relevance score per pair (raw logits).
        """
        with torch.inference_mode():
            inputs = self.tokenizer(
                pairs,
                padding=True,
                truncation=True,
                return_tensors="pt",
                max_length=512,  # typical cross-encoder context limit
            )
            inputs = inputs.to(self.device)
            # Single-logit head: flatten (batch, 1) -> (batch,) relevance scores.
            scores = self.model(**inputs, return_dict=True).logits.view(-1).float()
        return scores.detach().cpu().tolist()

    # Backward-compatible alias: earlier code invoked the instance directly.
    def __call__(self, pairs):
        return self.score(pairs)

    def retrieve_and_rerank(self, query):
        """Retrieve candidates for ``query``, rerank, and return the top_n docs."""
        return self.compression_retriever.invoke(query)

    @staticmethod
    def pretty_print_docs(docs):
        """Print each document's page_content separated by a dashed rule."""
        print(
            f"\n{'-' * 100}\n".join(
                [f"Document {i+1}:\n\n" + d.page_content for i, d in enumerate(docs)]
            )
        )