Spaces:
Sleeping
Sleeping
import json | |
import numpy as np | |
import os | |
from sentence_transformers import SentenceTransformer | |
from typing import List, Dict, Tuple, Any, Optional | |
# Directory that load_embeddings() reads question_embeddings.npy,
# answer_embeddings.npy and qa_data.json from.
DATA_DIR = "/app/data"

# Module-level caches. Each starts as None and is filled on first use so the
# model and the embedding files are only loaded once per process.
# SentenceTransformer instance created by initialize_model().
_model: Optional[SentenceTransformer] = None
# Arrays read from question_embeddings.npy / answer_embeddings.npy.
_question_embeddings: Optional[np.ndarray] = None
_answer_embeddings: Optional[np.ndarray] = None
# Parsed contents of qa_data.json.
_qa_data: Optional[List[Dict[str, str]]] = None
def initialize_model() -> SentenceTransformer:
    """
    Create the sentence-embedding model once and cache it in ``_model``.

    Subsequent calls reuse the cached instance instead of constructing the
    model again.

    Returns:
        The cached ``SentenceTransformer`` instance. (The original annotation
        said ``None`` even though the model is returned and ``get_model``
        depends on that return value.)
    """
    global _model
    if _model is None:
        _model = SentenceTransformer("pkshatech/GLuCoSE-base-ja")
    return _model
def get_model() -> SentenceTransformer:
    """
    Return the cached model, initializing it first when nothing is cached.
    """
    global _model
    # Fast path: the model has already been created.
    if _model is not None:
        return _model
    _model = initialize_model()
    return _model
def load_embeddings(
    data_dir: Optional[str] = None,
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[List[Dict[str, str]]]]:
    """
    Load question/answer embeddings and QA data from files.

    Args:
        data_dir: Directory containing ``question_embeddings.npy``,
            ``answer_embeddings.npy`` and ``qa_data.json``. When omitted, the
            module-level ``DATA_DIR`` is used.

    Returns:
        ``(question_embeddings, answer_embeddings, qa_data)`` on success, or
        ``(None, None, None)`` when a file is missing or cannot be read. A
        message is printed in the failure case; no exception is propagated.

    Side effects:
        On success the three module-level caches are replaced together. On
        failure they are left exactly as they were, so a partially successful
        read can never leave mismatched embeddings and QA data behind.
    """
    global _question_embeddings, _answer_embeddings, _qa_data
    try:
        base_dir = DATA_DIR if data_dir is None else data_dir
        q_emb_path = os.path.join(base_dir, "question_embeddings.npy")
        a_emb_path = os.path.join(base_dir, "answer_embeddings.npy")
        qa_data_path = os.path.join(base_dir, "qa_data.json")

        # Read everything into locals first; the caches are only touched
        # once all three reads have succeeded.
        question_embeddings = np.load(q_emb_path)
        answer_embeddings = np.load(a_emb_path)
        with open(qa_data_path, "r", encoding="utf-8") as f:
            qa_data = json.load(f)
    except FileNotFoundError as e:
        print(f"Warning: Embeddings not found. {str(e)}")
        return None, None, None
    except Exception as e:
        # Best-effort loader: report the problem and signal failure through
        # the return value rather than raising.
        print(f"Error loading embeddings: {str(e)}")
        return None, None, None

    _question_embeddings = question_embeddings
    _answer_embeddings = answer_embeddings
    _qa_data = qa_data
    return _question_embeddings, _answer_embeddings, _qa_data
def get_embeddings() -> (
    Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[List[Dict[str, str]]]]
):
    """
    Return the cached embeddings and QA data, loading them when any is missing.
    """
    global _question_embeddings, _answer_embeddings, _qa_data
    cached = (_question_embeddings, _answer_embeddings, _qa_data)
    # A single missing piece triggers a full load of all three.
    if any(part is None for part in cached):
        _question_embeddings, _answer_embeddings, _qa_data = load_embeddings()
        return _question_embeddings, _answer_embeddings, _qa_data
    return cached
def reload_embeddings() -> bool:
    """
    Reload embeddings from files and refresh the module-level caches.

    Returns:
        ``True`` when all three pieces were loaded, ``False`` otherwise.

    Note:
        ``load_embeddings`` reports failure by returning ``None`` values
        rather than raising, so the result is checked explicitly here. As
        before, the caches are overwritten with whatever the load returned,
        so after a failed reload they are ``None``.
    """
    global _question_embeddings, _answer_embeddings, _qa_data
    try:
        loaded = load_embeddings()
        _question_embeddings, _answer_embeddings, _qa_data = loaded
        if any(part is None for part in loaded):
            # Previously this case was only caught by len(None) raising a
            # TypeError, which produced a misleading error message.
            print("Error reloading embeddings: embedding files could not be loaded.")
            return False
        print(f"Embeddings reloaded successfully. {len(_qa_data)} QA pairs available.")
        return True
    except Exception as e:
        print(f"Error reloading embeddings: {str(e)}")
        return False