File size: 2,836 Bytes
b5deaf1
e83b975
b5deaf1
 
9a73c5d
b5deaf1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2bf1ef6
b5deaf1
2bf1ef6
b5deaf1
7097576
 
 
b5deaf1
1494f1e
b5deaf1
 
7097576
1494f1e
b5deaf1
 
 
 
e83b975
 
1031c5b
7097576
e83b975
e3a12d5
e83b975
1031c5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e83b975
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
"""Module for retrievers that fetch documents from various sources."""
from venv import logger
from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores import VectorStoreRetriever
from models.db import vectorstore

class DocRetriever(BaseRetriever):
    """
    Retrieve a user's documents from the shared vector store.

    Wraps a ``VectorStoreRetriever`` built from ``models.db.vectorstore`` that
    runs a similarity search restricted by a ``{"user_id": ...}`` filter taken
    from the incoming request.

    Attributes:
        retriever (VectorStoreRetriever): The underlying vector-store retriever,
            built in ``__init__``.
        k (int): The number of documents to retrieve. Default is 6.
    """
    retriever: VectorStoreRetriever = None
    k: int = 6

    def __init__(self, req, k: int = 6) -> None:
        """
        Build a retriever scoped to the requesting user.

        Args:
            req: The incoming request object; only its ``user_id`` attribute
                is read.
            k (int): The number of documents to retrieve. Default is 6.
        """
        # Pass k through the model constructor so ``self.k`` reflects the
        # requested value instead of silently staying at the class default.
        super().__init__(k=k)
        # Only documents whose metadata matches this user_id are searched.
        search_filter = {"user_id": req.user_id}
        logger.debug("DocRetriever filter: %s", search_filter)
        self.retriever = vectorstore.as_retriever(
            search_type='similarity',
            search_kwargs={
                "k": self.k,
                "filter": search_filter,
            },
        )

    def _get_relevant_documents(self, query: str, *, run_manager) -> list:
        """
        Return the documents most similar to ``query``.

        Each returned document's metadata is updated in place: ``id`` is set
        to ``doc.id`` and ``content`` to ``doc.page_content``.

        Args:
            query (str): The query string to search for relevant documents.
            run_manager: Callback manager supplied by ``BaseRetriever``; not
                used here.

        Returns:
            list: The retrieved Document objects, or an empty list if the
            underlying retriever raises ``RuntimeError``.
        """
        try:
            retrieved_docs = self.retriever.invoke(query)
        except RuntimeError as e:
            # NOTE(review): ``logger`` is imported from the stdlib ``venv``
            # module, so this record is emitted on that module's logger;
            # consider a module-level ``logging.getLogger(__name__)``.
            logger.error("Error retrieving documents: %s", e)
            return []
        logger.debug("Retrieved %d documents", len(retrieved_docs))
        for doc in retrieved_docs:
            # Copy the document id and text into metadata alongside the
            # existing metadata keys.
            doc.metadata['id'] = doc.id
            doc.metadata['content'] = doc.page_content
        return retrieved_docs