# Spaces:
# Running
# Running
import logging
import os

import requests
from chromadb import Client, Settings, PersistentClient
from chromadb.utils.embedding_functions.sentence_transformer_embedding_function import (
    SentenceTransformerEmbeddingFunction,
)
from rank_bm25 import BM25Okapi

# Named logger for this module plus a basic root logging configuration.
logger = logging.getLogger("db")
logging.basicConfig(
    format="%(asctime)s %(levelname)-8s %(message)s",
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
)

# Jina Reader API key, read once at import time.
# NOTE(review): if JINA_API_KEY is unset this header becomes "Bearer None".
JINA_KEY = os.getenv('JINA_API_KEY')
jina_headers = {
    "Authorization": f"Bearer {JINA_KEY}"
}
def get_data_url(url):
    """Fetch *url* through the Jina Reader proxy and return the page text.

    Args:
        url: Page URL to scrape (scheme included).

    Returns:
        The text/markdown rendition of the page produced by r.jina.ai.

    Raises:
        requests.HTTPError: if the proxy responds with an error status.
        requests.Timeout: if the proxy does not answer within the timeout.
    """
    logger.info(f"Scraping {url}")
    # SECURITY: verify=False disables TLS certificate checking. Kept to
    # preserve existing behavior, but this should be reviewed.
    jina_response = requests.get(
        f"https://r.jina.ai/{url}",
        headers=jina_headers,
        verify=False,
        timeout=60,  # avoid hanging forever on a stalled proxy
    )
    # Fail loudly instead of silently indexing an error page as a document.
    jina_response.raise_for_status()
    return jina_response.text
class HacatonDB:
    """Hybrid retrieval store: Chroma vector search blended with BM25.

    Documents are scraped web pages, keyed by their URL, stored in a
    persistent Chroma collection with BGE-M3 sentence-transformer embeddings.
    """

    def __init__(self):
        # Local persistent Chroma client with anonymized telemetry disabled.
        self.client = PersistentClient(settings=Settings(anonymized_telemetry=False))
        # BGE-M3 embeddings; cosine distance configured for the HNSW index.
        self.embed = SentenceTransformerEmbeddingFunction(
            model_name="BAAI/bge-m3"
        )
        self.collection = self.client.create_collection(
            'test_hakaton',
            embedding_function=self.embed,
            metadata={"hnsw:space": "cosine"},
            get_or_create=True,
        )

    def add(self, urls):
        """Scrape each not-yet-stored URL and add its text to the collection.

        The URL itself is used as the document id, so re-adding a known URL
        is a no-op.
        """
        logger.info(f"Add info to collection")
        texts = []
        meta = []
        new_urls = []
        for url in urls:
            # Skip URLs already present (ids are the URLs themselves).
            if len(self.collection.get(ids=[url])["ids"]) > 0:
                logger.info(f"URL {url} already exist")
                continue
            new_urls.append(url)
            texts.append(get_data_url(url))
            # NOTE(review): assumes the URL ends with a slash so that
            # split('/')[-2] is the last path segment — confirm with callers.
            meta.append({"file_name": f"file_{url.split('/')[-2]}"})
            logger.info(f"URL {url} added")
        if len(new_urls) > 0:
            self.collection.add(documents=texts, ids=new_urls, metadatas=meta)
            logger.info(f"Addition {len(new_urls)} sources completed")
        else:
            logger.info(f"No new sources")

    def update(self, urls):
        # TODO: re-scrape and upsert already-stored URLs; not implemented yet.
        pass

    def get_ids(self):
        """Return the ids (URLs) of every stored document."""
        return self.collection.get()["ids"]

    def query(self, query, top_k, alpha=0.5):
        """Hybrid search blending cosine similarity and normalized BM25.

        Args:
            query: Query string.
            top_k: Number of documents to return.
            alpha: Weight of the vector score; (1 - alpha) goes to BM25.

        Returns:
            The Chroma ``get`` payload for the top_k highest-scoring ids.
        """
        results = self.collection.query(
            query_texts=query,
            n_results=top_k * 2  # over-fetch so BM25 can re-rank the pool
        )
        chroma_scores = dict(zip(results['ids'][0], results['distances'][0]))
        bm25_scores = self.get_bm25_scores(query)
        combined = {}
        # default=0 guards the empty-collection case: max() on an empty
        # sequence raises ValueError.
        bm25_max = max(bm25_scores.values(), default=0)
        for doc_id in set(chroma_scores.keys()).union(bm25_scores.keys()):
            # Missing vector hit -> worst cosine distance (1);
            # missing BM25 hit -> score 0.
            chroma_score = chroma_scores.get(doc_id, 1)
            bm25_score = bm25_scores.get(doc_id, 0)
            if bm25_max > 0:
                bm25_score = bm25_score / bm25_max
            # (1 - distance) because a smaller cosine distance is better.
            combined[doc_id] = alpha * (1 - chroma_score) + (1 - alpha) * bm25_score
        sorted_docs = sorted(combined.items(), key=lambda x: x[1], reverse=True)
        top_ids = [doc[0] for doc in sorted_docs[:top_k]]
        # NOTE(review): collection.get does not guarantee it returns ids in
        # the ranked order of top_ids — confirm whether callers rely on order.
        return self.collection.get(ids=top_ids)

    def get_bm25_scores(self, query):
        """Score every stored document against *query* with BM25Okapi.

        Returns:
            dict mapping document id -> raw (unnormalized) BM25 score;
            empty dict when the collection holds no documents.
        """
        retrieved = self.collection.get(include=['documents'])
        all_docs = retrieved['documents']
        all_ids = retrieved['ids']
        # Guard: BM25Okapi divides by the corpus size, so an empty corpus
        # would raise ZeroDivisionError.
        if not all_docs:
            return {}
        num2id = dict(enumerate(all_ids))
        tokenized_corpus = [doc.lower().split() for doc in all_docs]
        bm25 = BM25Okapi(tokenized_corpus)
        tokenized_query = query.lower().split()
        bm25_scores = bm25.get_scores(tokenized_query)
        return {num2id[i]: float(score) for i, score in enumerate(bm25_scores)}

    def get_all_documents(self):
        """Return the raw text of every stored document."""
        return self.collection.get(include=['documents'])['documents']
# Module-level singleton. NOTE: constructing it eagerly loads the embedding
# model and opens the persistent Chroma store as an import-time side effect.
db = HacatonDB()