# Taiken_chatbot_API/app/services/prediction_service.py
# Author: vumichien
# Commit 44a025a: "Add application file"
import json
import logging
from typing import Tuple, List, Dict, Any, Optional

import numpy as np
from sentence_transformers import util

from app.services.model_service import get_model, get_embeddings
def _best_match(user_embedding, corpus_embeddings) -> Optional[Tuple[int, float]]:
    """Return ``(row_index, cosine_score)`` of the corpus row closest to the query.

    Returns ``None`` when the corpus is missing or empty, so the caller can
    skip that search stage instead of failing inside ``argmax``.
    """
    if corpus_embeddings is None or len(corpus_embeddings) == 0:
        return None
    # util.cos_sim returns a similarity matrix with one row per query
    # embedding; a single query was encoded, so take row 0.
    scores = util.cos_sim(user_embedding, corpus_embeddings)[0]
    # Use the tensor's own argmax and convert to plain Python numbers instead
    # of routing a torch tensor through np.argmax.
    best_idx = int(scores.argmax())
    return best_idx, float(scores[best_idx])


def search_answer(
    user_input: str,
    model,
    question_embeddings: np.ndarray,
    answer_embeddings: np.ndarray,
    threshold_q: float,
    threshold_a: float,
    answers: List[str],
) -> Tuple[str, str]:
    """Find the stored answer that best matches ``user_input`` by cosine similarity.

    The query is first compared with the stored question embeddings. If the
    best question score is below ``threshold_q``, it is then compared with
    the stored answer embeddings against ``threshold_a``. In both stages the
    winning row index is used to index ``answers``, so both embedding
    matrices must be row-aligned with ``answers``.

    Args:
        user_input: Free-text query from the user.
        model: Encoder exposing a sentence-transformers style ``encode``.
        question_embeddings: Embeddings of the stored questions.
        answer_embeddings: Embeddings of the stored answers.
        threshold_q: Minimum score for a question-stage match.
        threshold_a: Minimum score for an answer-stage match.
        answers: Answer texts, one per embedding row.

    Returns:
        ``(answer_text, score)``, where ``score`` is formatted with two
        decimals. If neither stage reaches its threshold, returns a fixed
        "no answer found" message and "一致なし" ("no match").
    """
    user_embedding = model.encode(
        [user_input],
        convert_to_numpy=True,
        batch_size=1,
        show_progress_bar=False,
        normalize_embeddings=True,
    )

    # Stages run in order; the answer stage is only evaluated when the
    # question stage does not produce a confident match.
    stages = (
        (question_embeddings, threshold_q),
        (answer_embeddings, threshold_a),
    )
    for corpus, threshold in stages:
        match = _best_match(user_embedding, corpus)
        if match is None:
            continue
        best_idx, score = match
        if score >= threshold:
            # Adds a space before each newline, presumably for the client's
            # Markdown rendering. NOTE(review): a Markdown hard line break
            # needs two trailing spaces; confirm a single space is intended.
            return answers[best_idx].replace("\n", " \n"), f"{score:.2f}"

    return (
        "申し訳ありませんが、ご質問の答えを見つけることができませんでした。もう少し詳しく説明していただけますか？",
        "一致なし",
    )
def predict_answer(
    user_input: str, threshold_q: float = 0.5, threshold_a: float = 0.5
) -> Dict[str, Any]:
    """Answer ``user_input`` using the globally loaded model and QA embeddings.

    Args:
        user_input: Free-text question from the user.
        threshold_q: Minimum cosine score for a match against stored questions.
        threshold_a: Minimum cosine score for the fallback match against
            stored answers.

    Returns:
        ``{"status": "success", "answer": str, "score": str}`` when a lookup
        was performed, including the "no match" fallback from
        ``search_answer``.
        ``{"status": "error", "message": str}`` when the embeddings or QA
        data are unavailable, or when any step raises an ``Exception``.
        Such exceptions are logged and reported here rather than propagated.
    """
    not_found = {
        "status": "error",
        "message": "Embeddings not found. Please create embeddings first.",
    }
    try:
        model = get_model()
        question_embeddings, answer_embeddings, qa_data = get_embeddings()
        # `is None` rather than truthiness: the embeddings may be numpy
        # arrays, whose truth value is ambiguous.
        if question_embeddings is None or answer_embeddings is None or qa_data is None:
            return not_found

        answers = [item["answer"] for item in qa_data]
        # Empty QA data cannot produce a match; report it the same way as
        # missing embeddings instead of failing later inside the search.
        if not answers:
            return not_found

        answer, score = search_answer(
            user_input,
            model,
            question_embeddings,
            answer_embeddings,
            threshold_q,
            threshold_a,
            answers,
        )
        return {"status": "success", "answer": answer, "score": score}
    except Exception as e:
        # API boundary: callers get an error payload, but keep the traceback
        # in the logs instead of discarding it.
        logging.getLogger(__name__).exception("predict_answer failed")
        return {"status": "error", "message": str(e)}