File size: 3,775 Bytes
8cf08be
 
 
 
44cbb2e
8cf08be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9de1f87
 
 
44cbb2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8cf08be
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import os
import requests
from chromadb import Client, Settings, PersistentClient
from chromadb.utils.embedding_functions.sentence_transformer_embedding_function import SentenceTransformerEmbeddingFunction
from rank_bm25 import BM25Okapi
import logging

# Module-level logger; basicConfig touches the root logger, so importing this
# module configures logging for the whole process (deliberate for this script).
logger = logging.getLogger("db")
logging.basicConfig(
    format="%(asctime)s %(levelname)-8s %(message)s",
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
)

# API key for the Jina reader service (https://r.jina.ai). May be None if the
# environment variable is unset — requests would then go out unauthenticated.
JINA_KEY = os.getenv('JINA_API_KEY')

# Shared auth header for all Jina reader calls.
jina_headers = {
"Authorization": f"Bearer {JINA_KEY}"
}

def get_data_url(url):
    """Fetch *url* through the Jina reader proxy and return the extracted text.

    The page is scraped via ``https://r.jina.ai/<url>``, which returns the page
    content as markdown/plain text in the response body.

    :param url: fully-qualified URL of the page to scrape.
    :return: response body text from the Jina reader.
    :raises requests.RequestException: on network failure or timeout.
    """
    logger.info(f"Scraping {url}")
    # timeout added so a stalled connection cannot hang the caller forever.
    # SECURITY NOTE(review): verify=False disables TLS certificate checking and
    # allows MITM of the scraped content — confirm whether this is really needed.
    jina_response = requests.get(
        f"https://r.jina.ai/{url}",
        headers=jina_headers,
        verify=False,
        timeout=60,
    )

    return jina_response.text

class HacatonDB:
    """Hybrid retrieval store: Chroma cosine vector search fused with BM25.

    Documents are web pages keyed by their URL (the URL is the Chroma document
    id). ``query`` blends normalized vector similarity and BM25 lexical scores
    with a convex combination controlled by ``alpha``.
    """

    def __init__(self):
        # Persistent client keeps the index on disk between runs; telemetry off.
        self.client = PersistentClient(settings=Settings(anonymized_telemetry=False))
        self.embed = SentenceTransformerEmbeddingFunction(
            model_name="BAAI/bge-m3"
        )
        # get_or_create=True makes repeated start-ups idempotent.
        self.collection = self.client.create_collection(
            'test_hakaton',
            embedding_function=self.embed,
            metadata={"hnsw:space": "cosine"},
            get_or_create=True,
        )

    def add(self, urls):
        """Scrape each new URL and add its text to the collection.

        URLs already present (matched by id) are skipped, so the call is
        idempotent per URL.

        :param urls: iterable of URLs to ingest.
        """
        logger.info(f"Add info to collection")
        texts = []
        meta = []
        new_urls = []
        for url in urls:
            # The URL itself is the document id — skip known ones.
            if len(self.collection.get(ids=[url])["ids"]) > 0:
                logger.info(f"URL {url} already exist")
                continue

            new_urls.append(url)
            texts.append(get_data_url(url))
            # NOTE(review): split('/')[-2] assumes a trailing slash on the URL;
            # for "https://x/page" it yields the host — confirm against callers.
            meta.append({"file_name": f"file_{url.split('/')[-2]}"})
            logger.info(f"URL {url} added")

        if len(new_urls) > 0:
            self.collection.add(documents=texts, ids=new_urls, metadatas=meta)
            logger.info(f"Addition {len(new_urls)} sources completed")
        else:
            logger.info(f"No new sources")

    def update(self, urls):
        # TODO: not implemented — re-scrape and upsert changed documents.
        pass

    def get_ids(self):
        """Return all document ids (URLs) currently stored."""
        return self.collection.get()["ids"]

    def query(self, query, top_k, alpha=0.5):
        """Hybrid search: return the top_k documents for *query*.

        Vector (cosine) and BM25 scores are each normalized to [0, 1] and
        combined as ``alpha * vector + (1 - alpha) * bm25``.

        :param query: search string.
        :param top_k: number of documents to return.
        :param alpha: weight of the vector score (1.0 = pure vector search).
        :return: Chroma ``get`` result dict for the selected ids.
        """
        # Over-fetch from the vector index so the fusion step has candidates
        # that BM25 alone might rank differently.
        results = self.collection.query(
            query_texts=query,
            n_results=top_k*2
        )
        chroma_scores = dict(zip(results['ids'][0], results['distances'][0]))

        bm25_scores = self.get_bm25_scores(query)

        combined = {}

        # FIX: default=0 — max() on an empty dict (empty collection) raised
        # ValueError before.
        bm25_max = max(bm25_scores.values(), default=0)
        for doc_id in set(chroma_scores.keys()).union(bm25_scores.keys()):
            # Missing vector score defaults to distance 1 (i.e. similarity 0).
            chroma_score = chroma_scores.get(doc_id, 1)
            bm25_score = bm25_scores.get(doc_id, 0)

            if bm25_max > 0:
                bm25_score = bm25_score / bm25_max

            # 1 - distance because a smaller cosine distance is better.
            combined[doc_id] = alpha * (1 - chroma_score) + (1 - alpha) * bm25_score

        sorted_docs = sorted(combined.items(), key=lambda x: x[1], reverse=True)
        top_ids = [doc[0] for doc in sorted_docs[:top_k]]

        return self.collection.get(ids=top_ids)

    def get_bm25_scores(self, query):
        """Score every stored document against *query* with BM25Okapi.

        :param query: search string (whitespace-tokenized, lowercased).
        :return: dict mapping document id -> float BM25 score; empty dict when
            the collection holds no documents.
        """
        retrieved = self.collection.get(include=['documents'])
        all_docs = retrieved['documents']
        all_ids = retrieved['ids']

        # FIX: BM25Okapi raises ZeroDivisionError on an empty corpus.
        if not all_docs:
            return {}

        tokenized_corpus = [doc.lower().split() for doc in all_docs]
        bm25 = BM25Okapi(tokenized_corpus)

        tokenized_query = query.lower().split()
        raw_scores = bm25.get_scores(tokenized_query)

        return {all_ids[i]: float(score) for i, score in enumerate(raw_scores)}

    def get_all_documents(self):
        """Return the raw text of every stored document."""
        return self.collection.get(include=['documents'])['documents']


db = HacatonDB()