|
""" |
|
Rerank with cross encoder. |
|
Ref: |
|
https://medium.aiplanet.com/advanced-rag-cohere-re-ranker-99acc941601c |
|
https://github.com/langchain-ai/langchain/issues/13076 |
|
""" |
|
|
|
from __future__ import annotations |
|
from typing import Optional, Sequence |
|
from langchain.schema import Document |
|
from langchain.pydantic_v1 import Extra |
|
|
|
from langchain.callbacks.manager import Callbacks |
|
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor |
|
|
|
from sentence_transformers import CrossEncoder |
|
|
|
|
|
class BgeRerank(BaseDocumentCompressor):
    """
    Re-rank documents with a sentence-transformers ``CrossEncoder``.

    Each document's ``page_content`` is scored against the query by the
    cross-encoder named in ``model_name``. The ``top_n`` highest-scoring
    documents are returned best-first, each carrying its score under
    ``metadata["relevance_score"]``. Despite the class name, any
    CrossEncoder-compatible model can be used; the default is
    ``jinaai/jina-reranker-v1-turbo-en``.

    Ref:
    https://medium.aiplanet.com/advanced-rag-cohere-re-ranker-99acc941601c
    https://github.com/langchain-ai/langchain/issues/13076

    Good to read:
    https://zhuanlan.zhihu.com/p/676008717 or its source
    https://teemukanstren.com/2023/12/25/llmrag-based-question-answering/
    """

    model_name: str = "jinaai/jina-reranker-v1-turbo-en"
    """Model name to use for reranking."""

    top_n: int = 6
    """Number of documents to return."""

    model: Optional[CrossEncoder] = None
    """CrossEncoder instance to use for reranking.

    If not supplied, one is loaded from ``model_name`` when the instance is
    constructed.
    """

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        if self.model is None:
            # Load per instance, after field validation, so that importing
            # this module has no side effects, a custom ``model_name`` is
            # honoured, and pydantic never deep-copies a shared default model
            # for each new instance.
            # NOTE(review): trust_remote_code=True executes code shipped with
            # the model repository; only point model_name at trusted repos.
            self.model = CrossEncoder(self.model_name, trust_remote_code=True)

    def bge_rerank(
        self, query: str, docs: Sequence[str]
    ) -> list[tuple[int, float]]:
        """
        Score ``docs`` against ``query`` and keep the best ``top_n``.

        Args:
            query: The query text.
            docs: The texts to score.

        Returns:
            ``(index, score)`` pairs, where ``index`` is the position in
            ``docs``, sorted by descending score and truncated to ``top_n``.
            Returns an empty list when ``docs`` is empty.
        """
        pairs = [[query, doc] for doc in docs]
        if not pairs:
            # Nothing to score; avoid calling the model on empty input.
            return []
        scores = self.model.predict(pairs)
        # float() turns numpy scalars into plain Python floats so the scores
        # can be stored in metadata and serialized (e.g. to JSON).
        ranked = sorted(
            ((index, float(score)) for index, score in enumerate(scores)),
            key=lambda pair: pair[1],
            reverse=True,
        )
        return ranked[: self.top_n]

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """
        Re-rank ``documents`` against ``query`` with the cross-encoder.

        Args:
            documents: A sequence of documents to compress.
            query: The query to use for compressing the documents.
            callbacks: Callbacks to run during the compression process
                (accepted for interface compatibility; not used here).

        Returns:
            Up to ``top_n`` documents, best first. Each is a new ``Document``
            whose metadata is a copy of the original's plus a
            ``"relevance_score"`` entry. The input documents are not modified.
        """
        if not documents:
            return []
        doc_list = list(documents)
        ranked = self.bge_rerank(query, [doc.page_content for doc in doc_list])
        # Build fresh Documents instead of writing into the caller's metadata,
        # so scores do not leak into shared or cached Document objects.
        return [
            Document(
                page_content=doc_list[index].page_content,
                metadata={**doc_list[index].metadata, "relevance_score": score},
            )
            for index, score in ranked
        ]
|
|