"""Question/answer lookup over precomputed sentence embeddings."""

import json
import logging
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
from sentence_transformers import util

from app.services.model_service import get_model, get_embeddings

logger = logging.getLogger(__name__)


def _best_match(query_embedding, candidate_embeddings) -> Optional[Tuple[int, float]]:
    """Return ``(index, score)`` of the candidate most cosine-similar to the query.

    Returns ``None`` when the similarity row is empty (no candidates), instead of
    letting ``argmax`` fail on an empty vector.
    """
    # util.cos_sim returns a (num_queries, num_candidates) matrix; there is a
    # single query here, so keep row 0.
    scores = util.cos_sim(query_embedding, candidate_embeddings)[0]
    if len(scores) == 0:
        return None
    # Convert to plain Python scalars so the index and score do not depend on
    # numpy/torch interop (np.argmax over a torch tensor, 0-d tensor indexing).
    best_idx = int(scores.argmax())
    return best_idx, float(scores[best_idx])


def _format_answer(text: str) -> str:
    """Turn each newline into a Markdown hard line break.

    A Markdown hard break needs at least two spaces before the newline; a single
    trailing space has no rendering effect.
    """
    return text.replace("\n", "  \n")


def search_answer(
    user_input: str,
    model,
    question_embeddings: np.ndarray,
    answer_embeddings: np.ndarray,
    threshold_q: float,
    threshold_a: float,
    answers: List[str],
) -> Tuple[str, str]:
    """Find the stored answer that best matches ``user_input`` by cosine similarity.

    The query is compared against the stored questions first. If the best question
    score is below ``threshold_q``, the query is then compared against the stored
    answers, using ``threshold_a``.

    Args:
        user_input: The user's query text.
        model: Encoder exposing a SentenceTransformer-style ``encode`` method.
        question_embeddings: Embeddings of the stored questions.
        answer_embeddings: Embeddings of the stored answers.
        threshold_q: Minimum score for a question match to be accepted.
        threshold_a: Minimum score for an answer match to be accepted.
        answers: Answer texts. Row ``i`` of either embedding matrix is used to
            index ``answers[i]``, so the rows must be aligned with this list.

    Returns:
        ``(answer_text, score)``, where ``score`` is formatted with two decimals.
        If neither pass meets its threshold, returns a fixed apology message and
        ``"一致なし"`` ("no match").
    """
    # normalize_embeddings only rescales the query vector; util.cos_sim
    # normalizes both inputs itself, so the scores are the same either way.
    user_embedding = model.encode(
        [user_input],
        convert_to_numpy=True,
        batch_size=1,
        show_progress_bar=False,
        normalize_embeddings=True,
    )

    # Questions first, then answers. The answer pass is only computed when the
    # question pass does not clear its threshold.
    for candidate_embeddings, threshold in (
        (question_embeddings, threshold_q),
        (answer_embeddings, threshold_a),
    ):
        match = _best_match(user_embedding, candidate_embeddings)
        if match is not None and match[1] >= threshold:
            best_idx, score = match
            return _format_answer(answers[best_idx]), f"{score:.2f}"

    return (
        "申し訳ありませんが、ご質問の答えを見つけることができませんでした。もう少し詳しく説明していただけますか?",
        "一致なし",
    )


def predict_answer(
    user_input: str, threshold_q: float = 0.5, threshold_a: float = 0.5
) -> Dict[str, Any]:
    """Answer ``user_input`` using the globally loaded model and embeddings.

    Args:
        user_input: The user's query text.
        threshold_q: Minimum similarity for a question match.
        threshold_a: Minimum similarity for an answer match.

    Returns:
        ``{"status": "success", "answer": ..., "score": ...}`` on success.
        ``{"status": "error", "message": ...}`` when the embeddings are missing
        or any step raises.
    """
    try:
        model = get_model()
        question_embeddings, answer_embeddings, qa_data = get_embeddings()

        if question_embeddings is None or answer_embeddings is None or qa_data is None:
            return {
                "status": "error",
                "message": "Embeddings not found. Please create embeddings first.",
            }

        answers = [item["answer"] for item in qa_data]
        answer, score = search_answer(
            user_input,
            model,
            question_embeddings,
            answer_embeddings,
            threshold_q,
            threshold_a,
            answers,
        )
        return {"status": "success", "answer": answer, "score": score}
    except Exception as e:
        # This is the service boundary: callers expect an error dict rather than
        # an exception. Record the traceback so failures are not silently lost.
        logger.exception("Failed to predict an answer")
        return {"status": "error", "message": str(e)}