File size: 2,625 Bytes
44a025a
 
11aa943
44a025a
 
 
11aa943
 
 
44a025a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11aa943
 
 
 
 
 
44a025a
11aa943
44a025a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import json
import numpy as np
import os
from sentence_transformers import SentenceTransformer
from typing import List, Dict, Tuple, Any, Optional

# Location of the precomputed embedding arrays and the QA-pair JSON file.
DATA_DIR: str = "/app/data"

# Module-level caches; each starts empty and is filled by the loader
# functions in this module (initialize_model / load_embeddings).
_model: "Optional[SentenceTransformer]" = None
_question_embeddings: Optional[np.ndarray] = None
_answer_embeddings: Optional[np.ndarray] = None
_qa_data: Optional[List[Dict[str, str]]] = None


def initialize_model() -> "SentenceTransformer":
    """
    Load the sentence-embedding model on first use and cache it in ``_model``.

    Later calls return the cached instance without reloading it.

    Returns:
        The cached ``SentenceTransformer`` built from
        "pkshatech/GLuCoSE-base-ja".
    """
    global _model
    # Only construct the model once; construction is the expensive part.
    if _model is None:
        _model = SentenceTransformer("pkshatech/GLuCoSE-base-ja")
    return _model


def get_model() -> SentenceTransformer:
    """
    Return the shared SentenceTransformer, loading it on first access.

    Loading is delegated to ``initialize_model``; its result is stored in the
    module-level ``_model`` cache.
    """
    global _model
    if _model is not None:
        return _model
    _model = initialize_model()
    return _model


def load_embeddings() -> (
    Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[List[Dict[str, str]]]]
):
    """
    Load question/answer embeddings and the QA pairs from ``DATA_DIR``.

    Reads ``question_embeddings.npy``, ``answer_embeddings.npy`` and
    ``qa_data.json`` (UTF-8). The module-level caches are updated only after
    all three files have been read successfully, so a failed load never
    leaves them half-replaced.

    Returns:
        ``(question_embeddings, answer_embeddings, qa_data)`` on success, or
        ``(None, None, None)`` if any file is missing or cannot be read or
        parsed; a message is printed in that case.
    """
    global _question_embeddings, _answer_embeddings, _qa_data

    q_emb_path = os.path.join(DATA_DIR, "question_embeddings.npy")
    a_emb_path = os.path.join(DATA_DIR, "answer_embeddings.npy")
    qa_data_path = os.path.join(DATA_DIR, "qa_data.json")

    try:
        # Read into locals first: assigning the globals one by one would leave
        # them inconsistent if a later file were missing or corrupt.
        question_embeddings = np.load(q_emb_path)
        answer_embeddings = np.load(a_emb_path)
        with open(qa_data_path, "r", encoding="utf-8") as f:
            qa_data = json.load(f)
    except FileNotFoundError as e:
        print(f"Warning: Embeddings not found. {str(e)}")
        return None, None, None
    except Exception as e:
        # Best-effort loader: report and signal failure instead of raising.
        print(f"Error loading embeddings: {str(e)}")
        return None, None, None

    # Everything loaded; commit all three caches together.
    _question_embeddings = question_embeddings
    _answer_embeddings = answer_embeddings
    _qa_data = qa_data
    return _question_embeddings, _answer_embeddings, _qa_data


def get_embeddings() -> (
    Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[List[Dict[str, str]]]]
):
    """
    Return the cached embeddings and QA data, loading them if any is missing.

    If any of the three caches is still ``None``, ``load_embeddings`` is
    called and its result is stored back into the module-level caches.
    """
    global _question_embeddings, _answer_embeddings, _qa_data

    cached = (_question_embeddings, _answer_embeddings, _qa_data)
    if any(item is None for item in cached):
        cached = load_embeddings()
        _question_embeddings, _answer_embeddings, _qa_data = cached
    return cached


def reload_embeddings() -> bool:
    """
    Reload embeddings and QA data from files, keeping the old state on failure.

    Returns:
        True if all three data sources were reloaded, False otherwise. On
        failure the module-level caches are restored to what they held
        before the call instead of being overwritten with ``None``.
    """
    global _question_embeddings, _answer_embeddings, _qa_data

    # load_embeddings may write the globals itself, so snapshot them to be
    # able to roll back if the reload does not fully succeed.
    previous = (_question_embeddings, _answer_embeddings, _qa_data)

    try:
        q_emb, a_emb, qa_data = load_embeddings()
    except Exception as e:
        _question_embeddings, _answer_embeddings, _qa_data = previous
        print(f"Error reloading embeddings: {str(e)}")
        return False

    # load_embeddings signals failure with (None, None, None) rather than
    # raising; check explicitly instead of relying on len(None) blowing up.
    if q_emb is None or a_emb is None or qa_data is None:
        _question_embeddings, _answer_embeddings, _qa_data = previous
        print(
            "Error reloading embeddings: data files could not be loaded; "
            "keeping previous state."
        )
        return False

    _question_embeddings, _answer_embeddings, _qa_data = q_emb, a_emb, qa_data
    print(f"Embeddings reloaded successfully. {len(_qa_data)} QA pairs available.")
    return True