File size: 3,288 Bytes
1a20a59
 
 
 
2fafc94
 
 
 
 
 
 
 
 
 
1a20a59
2fafc94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1a20a59
 
 
2fafc94
 
1a20a59
2fafc94
1a20a59
2fafc94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1a20a59
2fafc94
 
 
1a20a59
2fafc94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
"""
Retrievers for text chunks.
"""

import os

from langchain.text_splitter import (
    RecursiveCharacterTextSplitter,
    SpacyTextSplitter,
)

from rerank import BgeRerank
from langchain.retrievers import ContextualCompressionRetriever


def get_parent_doc_retriever(
    documents,
    vectorstore,
    add_documents=True,
    docstore="in_memory",
    save_path_root="./",
    docstore_file="store_location",
    save_vectorstore=False,
    save_docstore=False,
    k=10,
):
    """Build a parent-document ("small-to-big") retriever.

    Child chunks are embedded into ``vectorstore`` for similarity search,
    while the larger parent chunks live in a separate docstore and are
    what the retriever ultimately returns.

    Args:
        documents: Documents to split and index when ``add_documents`` is True.
        vectorstore: Vector store holding the (child) chunk embeddings.
        add_documents: Whether to ingest ``documents`` immediately.
        docstore: One of the strings ``"in_memory"``, ``"local_storage"``,
            ``"sql"`` — or an already-constructed docstore object, which is
            used as-is.
        save_path_root: Directory for any persisted artifacts.
        docstore_file: Path used by the ``"local_storage"`` backend.
        save_vectorstore: Persist the vector store as a local FAISS index.
        save_docstore: Pickle the docstore under ``save_path_root``.
        k: Number of child chunks to fetch from the vector store.

    Returns:
        A configured ``ParentDocumentRetriever``.
    """

    # TODO need better design
    # Ref: explain how it works: https://clusteredbytes.pages.dev/posts/2023/langchain-parent-document-retriever/
    from langchain.storage.file_system import LocalFileStore
    from langchain.storage import InMemoryStore
    from langchain.storage._lc_store import create_kv_docstore
    from langchain.retrievers import ParentDocumentRetriever

    # Resolve the parent docstore (distinct from the child chunks held in
    # the vectorstore).
    if docstore == "in_memory":
        store = InMemoryStore()
    elif docstore == "local_storage":
        # Ref: https://stackoverflow.com/questions/77385587/persist-parentdocumentretriever-of-langchain
        store = create_kv_docstore(LocalFileStore(docstore_file))
    elif docstore == "sql":
        from langchain_rag.storage import SQLStore

        # Instantiate the SQLStore with the root path
        store = SQLStore(
            namespace="test", db_url="sqlite:///parent_retrieval_db.db"
        )  # TODO: WIP
    else:
        # Assume the caller supplied a ready-made docstore object.
        store = docstore  # TODO: add check
        # raise  # TODO implement other docstores

    # TODO: how to better set these values?
    # parent_splitter = RecursiveCharacterTextSplitter(separators=["\n\n", "\n"], chunk_size=1024, chunk_overlap=256)
    # child_splitter = RecursiveCharacterTextSplitter(separators=["\n\n", "\n"], chunk_size=256, chunk_overlap=64)
    # Parents are ~2x the child size; both split on sentence boundaries via
    # spaCy with token counts measured by tiktoken.
    parent_splitter = SpacyTextSplitter.from_tiktoken_encoder(
        chunk_size=512, chunk_overlap=128
    )
    child_splitter = SpacyTextSplitter.from_tiktoken_encoder(
        chunk_size=256, chunk_overlap=64
    )

    retriever = ParentDocumentRetriever(
        vectorstore=vectorstore,
        docstore=store,
        child_splitter=child_splitter,
        parent_splitter=parent_splitter,
        search_kwargs={"k": k},
    )

    if add_documents:
        retriever.add_documents(documents)

    if save_vectorstore:
        vectorstore.save_local(os.path.join(save_path_root, "faiss_index"))

    if save_docstore:
        import pickle

        with open(os.path.join(save_path_root, "docstore.pkl"), "wb") as fh:
            pickle.dump(store, fh, pickle.HIGHEST_PROTOCOL)

    return retriever


def get_rerank_retriever(base_retriever, reranker=None):
    """Wrap ``base_retriever`` with a reranking compression stage.

    Args:
        base_retriever: Retriever whose hits will be reranked.
        reranker: Optional compressor to use; when None, a default
            ``BgeRerank`` instance is created.

    Returns:
        A ``ContextualCompressionRetriever`` that reranks the base
        retriever's results.
    """
    # Fall back to the default BGE reranker when none was supplied.
    compressor = BgeRerank() if reranker is None else reranker  # TODO : add check
    return ContextualCompressionRetriever(
        base_compressor=compressor, base_retriever=base_retriever
    )