# LISA-demo / retrievers.py
# Kadi-IAM's picture
# Clean code and add readme
# 1a20a59
"""
Retrievers for text chunks.
"""
import os
from langchain.text_splitter import (
RecursiveCharacterTextSplitter,
SpacyTextSplitter,
)
from rerank import BgeRerank
from langchain.retrievers import ContextualCompressionRetriever
def get_parent_doc_retriever(
    documents,
    vectorstore,
    add_documents=True,
    docstore="in_memory",
    save_path_root="./",
    docstore_file="store_location",
    save_vectorstore=False,
    save_docstore=False,
    k=10,
):
    """Build a parent-document (small-to-big) retriever.

    Small child chunks are embedded in ``vectorstore`` for similarity
    search, while the larger parent chunks are kept in ``docstore`` and
    returned to the caller on retrieval.

    Args:
        documents: Documents to split and index (used only when
            ``add_documents`` is True).
        vectorstore: LangChain vector store holding the child chunks.
        add_documents: Whether to split and index ``documents`` now.
        docstore: One of the keywords ``"in_memory"``, ``"local_storage"``,
            ``"sql"``, or an already-constructed docstore instance.
        save_path_root: Directory where persisted artifacts are written.
        docstore_file: Path for the ``local_storage`` file-store backend.
        save_vectorstore: Persist the vector store via FAISS ``save_local``.
        save_docstore: Pickle the docstore to ``docstore.pkl``.
        k: Number of child hits fetched per query.

    Returns:
        A configured ``ParentDocumentRetriever``.

    Raises:
        ValueError: If ``docstore`` is an unrecognized keyword string.
    """
    # TODO need better design
    # Ref: explain how it works: https://clusteredbytes.pages.dev/posts/2023/langchain-parent-document-retriever/
    from langchain.storage.file_system import LocalFileStore
    from langchain.storage import InMemoryStore
    from langchain.storage._lc_store import create_kv_docstore
    from langchain.retrievers import ParentDocumentRetriever

    # Document store for parent chunks, distinct from the (child) chunks
    # that live in the vector store.
    if docstore == "in_memory":
        docstore = InMemoryStore()
    elif docstore == "local_storage":
        # Ref: https://stackoverflow.com/questions/77385587/persist-parentdocumentretriever-of-langchain
        fs = LocalFileStore(docstore_file)
        docstore = create_kv_docstore(fs)
    elif docstore == "sql":
        from langchain_rag.storage import SQLStore

        # Instantiate the SQLStore with the root path
        docstore = SQLStore(
            namespace="test", db_url="sqlite:///parent_retrieval_db.db"
        )  # TODO: WIP
    elif isinstance(docstore, str):
        # Fail fast on an unknown keyword instead of silently passing the
        # bare string through as a "docstore" (which would crash later in
        # ParentDocumentRetriever with a far less helpful error).
        raise ValueError(f"Unknown docstore type: {docstore!r}")
    # else: assume the caller passed a ready-made docstore instance.

    # TODO: how to better set these chunk sizes?
    parent_splitter = SpacyTextSplitter.from_tiktoken_encoder(
        chunk_size=512,
        chunk_overlap=128,
    )
    child_splitter = SpacyTextSplitter.from_tiktoken_encoder(
        chunk_size=256,
        chunk_overlap=64,
    )

    retriever = ParentDocumentRetriever(
        vectorstore=vectorstore,
        docstore=docstore,
        child_splitter=child_splitter,
        parent_splitter=parent_splitter,
        search_kwargs={"k": k},
    )
    if add_documents:
        retriever.add_documents(documents)

    if save_vectorstore:
        vectorstore.save_local(os.path.join(save_path_root, "faiss_index"))
    if save_docstore:
        import pickle

        # NOTE(review): InMemoryStore pickles fine; file/SQL-backed stores
        # may not round-trip through pickle — confirm before relying on it.
        with open(os.path.join(save_path_root, "docstore.pkl"), "wb") as file:
            pickle.dump(docstore, file, pickle.HIGHEST_PROTOCOL)

    return retriever
def get_rerank_retriever(base_retriever, reranker=None):
    """Wrap ``base_retriever`` with a reranking compression stage.

    Args:
        base_retriever: Retriever producing the candidate documents.
        reranker: Optional custom reranker; defaults to ``BgeRerank``.

    Returns:
        A ``ContextualCompressionRetriever`` that reranks the base results.
    """
    # Fall back to the default BGE reranker when none is supplied.
    compressor = BgeRerank() if reranker is None else reranker  # TODO: add check
    return ContextualCompressionRetriever(
        base_compressor=compressor, base_retriever=base_retriever
    )