Spaces:
Sleeping
Sleeping
import json
import logging
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
from sentence_transformers import util

from app.services.model_service import get_embeddings, get_model
def _best_match(
    query_embedding: np.ndarray, corpus_embeddings: np.ndarray
) -> Optional[Tuple[int, float]]:
    """Return ``(index, score)`` of the corpus row most similar to the query.

    Similarity is cosine similarity computed by ``util.cos_sim``. Returns
    ``None`` when the corpus has no rows, since there is nothing to rank.
    """
    # cos_sim yields an (n_queries, n_corpus) matrix; row 0 is our single query.
    scores = util.cos_sim(query_embedding, corpus_embeddings)[0]
    # Explicit len() check: truthiness of a multi-element array/tensor is ambiguous.
    if len(scores) == 0:
        return None
    # Use the result's own argmax and convert to plain Python scalars instead of
    # routing a (torch) tensor through np.argmax's fallback conversion.
    best_idx = int(scores.argmax())
    return best_idx, float(scores[best_idx])


def search_answer(
    user_input: str,
    model,
    question_embeddings: np.ndarray,
    answer_embeddings: np.ndarray,
    threshold_q: float,
    threshold_a: float,
    answers: List[str],
) -> Tuple[str, str]:
    """Find the stored answer that best matches ``user_input``.

    The query is first compared against ``question_embeddings``. If the best
    cosine similarity reaches ``threshold_q``, the answer at that index is
    returned. Otherwise the query is compared directly against
    ``answer_embeddings`` using ``threshold_a``. The second comparison only
    runs when the first one fails.

    Args:
        user_input: Free-text query from the user.
        model: Encoder exposing a sentence-transformers style ``encode`` method.
        question_embeddings: Question embeddings; row ``i`` is assumed to pair
            with ``answers[i]``.
        answer_embeddings: Answer embeddings; row ``i`` is assumed to pair with
            ``answers[i]``.
        threshold_q: Minimum similarity required for a question match.
        threshold_a: Minimum similarity required for a direct answer match.
        answers: Answer texts, indexed like the embedding rows.

    Returns:
        ``(answer_text, score)`` where ``score`` is the similarity formatted to
        two decimals. If neither pass clears its threshold (or both embedding
        sets are empty), returns a fixed Japanese apology asking the user to
        rephrase, together with the label "一致なし" ("no match").
    """
    user_embedding = model.encode(
        [user_input],
        convert_to_numpy=True,
        batch_size=1,
        show_progress_bar=False,
        # util.cos_sim normalizes its inputs itself, so this does not change
        # the scores; it only makes the query vector unit length.
        normalize_embeddings=True,
    )

    # Pass 1: match against known questions. Pass 2: match the answer text itself.
    for embeddings, threshold in (
        (question_embeddings, threshold_q),
        (answer_embeddings, threshold_a),
    ):
        match = _best_match(user_embedding, embeddings)
        if match is None:
            continue
        idx, score = match
        if score >= threshold:
            # Two trailing spaces before "\n" form a Markdown hard line break,
            # so multi-line answers keep their line structure when rendered.
            return answers[idx].replace("\n", "  \n"), f"{score:.2f}"

    return (
        "申し訳ありませんが、ご質問の答えを見つけることができませんでした。もう少し詳しく説明していただけますか？",
        "一致なし",
    )
def predict_answer(
    user_input: str, threshold_q: float = 0.5, threshold_a: float = 0.5
) -> Dict[str, Any]:
    """Answer ``user_input`` using the globally loaded model and embeddings.

    Args:
        user_input: Free-text query from the user.
        threshold_q: Minimum cosine similarity for a question match.
        threshold_a: Minimum cosine similarity for a direct answer match.

    Returns:
        ``{"status": "success", "answer": ..., "score": ...}`` on success.
        ``{"status": "error", "message": ...}`` when any of the embeddings or
        QA data are missing, or when any step raises. In the raising case the
        exception is also logged with its traceback.
    """
    try:
        model = get_model()
        question_embeddings, answer_embeddings, qa_data = get_embeddings()
        if question_embeddings is None or answer_embeddings is None or qa_data is None:
            return {
                "status": "error",
                "message": "Embeddings not found. Please create embeddings first.",
            }
        # NOTE(review): assumes each qa_data item is a mapping with an "answer"
        # key, aligned by index with the embedding rows — confirm with get_embeddings.
        answers = [item["answer"] for item in qa_data]
        answer, score = search_answer(
            user_input,
            model,
            question_embeddings,
            answer_embeddings,
            threshold_q,
            threshold_a,
            answers,
        )
        return {"status": "success", "answer": answer, "score": score}
    except Exception as e:
        # API boundary: keep the error-dict contract for callers, but record the
        # full traceback, which str(e) alone would discard.
        logging.getLogger(__name__).exception("predict_answer failed")
        return {"status": "error", "message": str(e)}