File size: 2,685 Bytes
44a025a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import json
import numpy as np
from sentence_transformers import util
from typing import Tuple, List, Dict, Any, Optional

from app.services.model_service import get_model, get_embeddings


def search_answer(
    user_input: str,
    model,
    question_embeddings: np.ndarray,
    answer_embeddings: np.ndarray,
    threshold_q: float,
    threshold_a: float,
    answers: List[str],
) -> Tuple[str, str]:
    """
    Search for an answer using cosine similarity.

    The user input is embedded once, then matched against the question
    embeddings first; if the best question score does not reach
    ``threshold_q``, it is matched against the answer embeddings with
    ``threshold_a``.

    Args:
        user_input: Raw user query text.
        model: SentenceTransformer-style object exposing ``encode``.
        question_embeddings: (n, d) matrix of question embeddings.
        answer_embeddings: (n, d) matrix of answer embeddings.
        threshold_q: Minimum cosine similarity to accept a question match.
        threshold_a: Minimum cosine similarity to accept an answer match.
        answers: Answer texts aligned row-for-row with both matrices.

    Returns:
        Tuple of (answer_text, score_string). When no candidate clears
        either threshold, a fallback apology message is returned with
        the score string "ไธ€่‡ดใชใ—".
    """
    # Encode with batch_size=1 and no progress bar to keep latency low;
    # normalize_embeddings=True pre-normalizes for cosine similarity.
    user_embedding = model.encode(
        [user_input],
        convert_to_numpy=True,
        batch_size=1,
        show_progress_bar=False,
        normalize_embeddings=True,
    )

    # Two-pass match: questions first, then answers as a fallback.
    for embeddings, threshold in (
        (question_embeddings, threshold_q),
        (answer_embeddings, threshold_a),
    ):
        best_idx, best_score = _best_match(user_embedding, embeddings)
        if best_score >= threshold:
            # "  \n" is a Markdown hard line break for each newline.
            return (
                answers[best_idx].replace("\n", "  \n"),
                f"{best_score:.2f}",
            )

    return (
        "็”ณใ—่จณใ‚ใ‚Šใพใ›ใ‚“ใŒใ€ใ”่ณชๅ•ใฎ็ญ”ใˆใ‚’่ฆ‹ใคใ‘ใ‚‹ใ“ใจใŒใงใใพใ›ใ‚“ใงใ—ใŸใ€‚ใ‚‚ใ†ๅฐ‘ใ—่ฉณใ—ใ่ชฌๆ˜Žใ—ใฆใ„ใŸใ ใ‘ใพใ™ใ‹๏ผŸ",
        "ไธ€่‡ดใชใ—",
    )


def _best_match(query: np.ndarray, matrix: np.ndarray) -> Tuple[int, float]:
    """Return (row index, cosine score) of the row in *matrix* most similar to *query*.

    Pure-numpy equivalent of ``sentence_transformers.util.cos_sim``:
    both operands are L2-normalized (with a tiny floor to avoid division
    by zero) before the dot product, so pre-normalized inputs are safe.
    Avoids the torch round-trip of ``util.cos_sim``, whose tensor output
    breaks ``np.argmax`` for non-CPU tensors.
    """
    q = query / np.maximum(np.linalg.norm(query, axis=1, keepdims=True), 1e-12)
    m = matrix / np.maximum(np.linalg.norm(matrix, axis=1, keepdims=True), 1e-12)
    scores = (q @ m.T)[0]
    best_idx = int(np.argmax(scores))
    return best_idx, float(scores[best_idx])


def predict_answer(
    user_input: str, threshold_q: float = 0.5, threshold_a: float = 0.5
) -> Dict[str, Any]:
    """
    Predict an answer based on user input.

    Loads the shared model and precomputed embeddings, then delegates
    the similarity search to ``search_answer``.

    Args:
        user_input: Raw user query text.
        threshold_q: Minimum similarity for a question-side match.
        threshold_a: Minimum similarity for an answer-side match.

    Returns:
        ``{"status": "success", "answer": ..., "score": ...}`` on
        success, or ``{"status": "error", "message": ...}`` when the
        embeddings are missing or any exception occurs.
    """
    try:
        # Shared, lazily-initialized model and embedding caches.
        model = get_model()
        question_embeddings, answer_embeddings, qa_data = get_embeddings()

        # Embeddings must have been created before serving queries.
        embeddings_missing = (
            question_embeddings is None
            or answer_embeddings is None
            or qa_data is None
        )
        if embeddings_missing:
            return {
                "status": "error",
                "message": "Embeddings not found. Please create embeddings first.",
            }

        candidate_answers = [entry["answer"] for entry in qa_data]

        answer, score = search_answer(
            user_input,
            model,
            question_embeddings,
            answer_embeddings,
            threshold_q,
            threshold_a,
            candidate_answers,
        )
        return {"status": "success", "answer": answer, "score": score}
    except Exception as e:
        # Boundary handler: surface any failure as a structured error.
        return {"status": "error", "message": str(e)}