|
""" |
|
Retrievers for text chunks. |
|
""" |
|
|
|
import os |
|
|
|
from langchain.text_splitter import ( |
|
RecursiveCharacterTextSplitter, |
|
SpacyTextSplitter, |
|
) |
|
|
|
from rerank import BgeRerank |
|
from langchain.retrievers import ContextualCompressionRetriever |
|
|
|
|
|
def get_parent_doc_retriever(
    documents,
    vectorstore,
    add_documents=True,
    docstore="in_memory",
    save_path_root="./",
    docstore_file="store_location",
    save_vectorstore=False,
    save_docstore=False,
    k=10,
):
    """Build a parent document (small-to-big) retriever.

    Child chunks are indexed in ``vectorstore`` while the larger parent chunks
    are kept in a docstore.

    Args:
        documents: Documents to index; only used when ``add_documents`` is true.
        vectorstore: Vector store that receives the child chunks.
        add_documents: When true, ``documents`` are added to the retriever.
        docstore: Either a ready-made store object, or one of the selectors
            ``"in_memory"`` (``InMemoryStore``), ``"local_storage"``
            (``LocalFileStore`` rooted at ``docstore_file``) or ``"sql"``
            (``SQLStore`` on a fixed local SQLite database).
        save_path_root: Directory that receives the saved vector store and/or
            the pickled docstore.
        docstore_file: Root path handed to ``LocalFileStore``; it is used as
            given and is not joined with ``save_path_root``.
        save_vectorstore: When true, call ``vectorstore.save_local`` with
            ``<save_path_root>/faiss_index``.
        save_docstore: When true, pickle the docstore to
            ``<save_path_root>/docstore.pkl``.
        k: Passed to the retriever as ``search_kwargs={"k": k}``.

    Returns:
        The configured ``ParentDocumentRetriever``.

    Raises:
        ValueError: If ``docstore`` is a string that is not a known selector.
    """
    # Validate the selector before doing any work: a mistyped backend name
    # would otherwise be forwarded as the docstore itself.
    known_backends = ("in_memory", "local_storage", "sql")
    if isinstance(docstore, str) and docstore not in known_backends:
        raise ValueError(
            f"Unknown docstore {docstore!r}; expected one of {known_backends} "
            "or a docstore instance."
        )

    from langchain.retrievers import ParentDocumentRetriever
    from langchain.storage import InMemoryStore, create_kv_docstore
    from langchain.storage.file_system import LocalFileStore

    if docstore == "in_memory":
        docstore = InMemoryStore()
    elif docstore == "local_storage":
        # Wrap the byte-oriented file store so it can hold documents.
        docstore = create_kv_docstore(LocalFileStore(docstore_file))
    elif docstore == "sql":
        from langchain_rag.storage import SQLStore

        # NOTE: namespace and database URL are fixed here, not configurable.
        docstore = SQLStore(
            namespace="test", db_url="sqlite:///parent_retrieval_db.db"
        )
    # Any other value is taken to be a ready-to-use docstore object.

    # Parent chunks are twice the size of the child chunks.
    parent_splitter = SpacyTextSplitter.from_tiktoken_encoder(
        chunk_size=512,
        chunk_overlap=128,
    )
    child_splitter = SpacyTextSplitter.from_tiktoken_encoder(
        chunk_size=256,
        chunk_overlap=64,
    )

    retriever = ParentDocumentRetriever(
        vectorstore=vectorstore,
        docstore=docstore,
        child_splitter=child_splitter,
        parent_splitter=parent_splitter,
        search_kwargs={"k": k},
    )

    if add_documents:
        retriever.add_documents(documents)

    # Make sure the output directory exists before writing anything into it.
    if (save_vectorstore or save_docstore) and save_path_root:
        os.makedirs(save_path_root, exist_ok=True)

    if save_vectorstore:
        # The "faiss_index" name suggests a FAISS store; any vector store
        # passed here must provide ``save_local``.
        vectorstore.save_local(os.path.join(save_path_root, "faiss_index"))

    if save_docstore:
        import pickle

        # Only load this file back from a trusted location: unpickling
        # untrusted data can execute arbitrary code.
        with open(os.path.join(save_path_root, "docstore.pkl"), "wb") as file:
            pickle.dump(docstore, file, pickle.HIGHEST_PROTOCOL)

    return retriever
|
|
|
|
|
def get_rerank_retriever(base_retriever, reranker=None):
    """Wrap ``base_retriever`` in a reranking ``ContextualCompressionRetriever``.

    Args:
        base_retriever: Retriever handed to the compression retriever as its
            base retriever.
        reranker: Compressor to use; when ``None`` a default ``BgeRerank()``
            is created.

    Returns:
        The ``ContextualCompressionRetriever`` combining the two.
    """
    # Only build the default reranker when the caller did not provide one.
    chosen_compressor = BgeRerank() if reranker is None else reranker
    return ContextualCompressionRetriever(
        base_compressor=chosen_compressor,
        base_retriever=base_retriever,
    )
|
|