# hakaton/db.py
import os
import requests
from chromadb import Settings, PersistentClient
from chromadb.utils.embedding_functions.sentence_transformer_embedding_function import SentenceTransformerEmbeddingFunction
from rank_bm25 import BM25Okapi
import logging
logger = logging.getLogger("db")
logging.basicConfig(
    format="%(asctime)s %(levelname)-8s %(message)s",
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
)
JINA_KEY = os.getenv('JINA_API_KEY')
jina_headers = {
    "Authorization": f"Bearer {JINA_KEY}"
}


def get_data_url(url):
    """Fetch a page as plain text through the Jina Reader proxy (r.jina.ai)."""
    logger.info(f"Scraping {url}")
    # NOTE: verify=False disables TLS certificate verification for this request.
    jina_response = requests.get(f"https://r.jina.ai/{url}", headers=jina_headers, verify=False)
    return jina_response.text
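
# Example call (hypothetical URL), kept as a comment so importing this module has no side effects:
#     page_text = get_data_url("https://example.com/article/")
#     # -> the page rendered as plain text by Jina Reader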


class HacatonDB:
    def __init__(self):
        # Persistent local Chroma store; documents are embedded with the BGE-M3
        # sentence-transformer and indexed with cosine distance.
        self.client = PersistentClient(settings=Settings(anonymized_telemetry=False))
        self.embed = SentenceTransformerEmbeddingFunction(
            model_name="BAAI/bge-m3"
        )
        self.collection = self.client.create_collection('test_hakaton', embedding_function=self.embed, metadata={"hnsw:space": "cosine"}, get_or_create=True)

    def add(self, urls):
        logger.info("Adding sources to the collection")
        texts = []
        meta = []
        new_urls = []
        for url in urls:
            # The URL itself is used as the document id, so skip URLs that are already stored.
            if len(self.collection.get(ids=[url])["ids"]) > 0:
                logger.info(f"URL {url} already exists")
                continue
            new_urls.append(url)
            texts.append(get_data_url(url))
            meta.append({"file_name": f"file_{url.split('/')[-2]}"})
            logger.info(f"URL {url} added")
        if len(new_urls) > 0:
            self.collection.add(documents=texts, ids=new_urls, metadatas=meta)
            logger.info(f"Added {len(new_urls)} new sources")
        else:
            logger.info("No new sources")

    def update(self, urls):
        pass

    def get_ids(self):
        return self.collection.get()["ids"]

    def query(self, query, top_k, alpha=0.5):
        # Hybrid retrieval: blend dense similarity from Chroma with sparse BM25 scores.
        results = self.collection.query(
            query_texts=query,
            n_results=top_k * 2
        )
        chroma_scores = dict(zip(results['ids'][0], results['distances'][0]))
        bm25_scores = self.get_bm25_scores(query)
        combined = {}
        bm25_max = max(bm25_scores.values())
        for doc_id in set(chroma_scores.keys()).union(bm25_scores.keys()):
            chroma_score = chroma_scores.get(doc_id, 1)
            bm25_score = bm25_scores.get(doc_id, 0)
            if bm25_max > 0:
                bm25_score = bm25_score / bm25_max
            # Use 1 - chroma distance, since a smaller distance means a better match.
            combined[doc_id] = alpha * (1 - chroma_score) + (1 - alpha) * bm25_score
        sorted_docs = sorted(combined.items(), key=lambda x: x[1], reverse=True)
        top_ids = [doc[0] for doc in sorted_docs[:top_k]]
        return self.collection.get(ids=top_ids)
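
    # Illustrative arithmetic for the hybrid score above (made-up numbers): with
    # alpha=0.5, a cosine distance of 0.2 and a normalized BM25 score of 0.8 give
    # combined = 0.5 * (1 - 0.2) + 0.5 * 0.8 = 0.8.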

    def get_bm25_scores(self, query):
        # Score every stored document against the query with BM25 over lowercased whitespace tokens.
        retrieved = self.collection.get(include=['documents'])
        all_docs = retrieved['documents']
        all_ids = retrieved['ids']
        num2id = {}
        for num, doc_id in enumerate(all_ids):
            num2id[num] = doc_id
        tokenized_corpus = [doc.lower().split() for doc in all_docs]
        bm25 = BM25Okapi(tokenized_corpus)
        tokenized_query = query.lower().split()
        bm25_scores = bm25.get_scores(tokenized_query)
        bm25_scores = {num2id[i]: float(score) for i, score in enumerate(bm25_scores)}
        return bm25_scores

    def get_all_documents(self):
        return self.collection.get(include=['documents'])['documents']


db = HacatonDB()
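

# Minimal usage sketch, guarded so it only runs when this file is executed directly.
# The URLs below are hypothetical placeholders, not sources used by the project;
# a valid JINA_API_KEY must be set for scraping to work.
if __name__ == "__main__":
    example_urls = [
        "https://example.com/articles/first/",
        "https://example.com/articles/second/",
    ]
    db.add(example_urls)  # scrape via Jina Reader and index any URLs not stored yet
    hits = db.query("example search query", top_k=2)
    for doc_id in hits["ids"]:
        print(doc_id)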