Spaces:
Sleeping
Sleeping
File size: 2,625 Bytes
44a025a 11aa943 44a025a 11aa943 44a025a 11aa943 44a025a 11aa943 44a025a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
import json
import numpy as np
import os
from sentence_transformers import SentenceTransformer
from typing import List, Dict, Tuple, Any, Optional
# Directory that load_embeddings() reads the saved embedding arrays
# (question_embeddings.npy, answer_embeddings.npy) and qa_data.json from.
DATA_DIR = "/app/data"
# Module-level caches. Each one stays None until the matching loader in this
# module has populated it, so the model and the embedding files are only
# loaded once per process.
_model: Optional[SentenceTransformer] = None
_question_embeddings: Optional[np.ndarray] = None
_answer_embeddings: Optional[np.ndarray] = None
_qa_data: Optional[List[Dict[str, str]]] = None
def initialize_model() -> "SentenceTransformer":
    """
    Load the sentence-embedding model once and cache it in ``_model``.

    The model is only constructed while the module-level cache is empty, so
    repeated calls hand back the same instance instead of loading it again.

    Returns:
        The cached model instance. Callers such as ``get_model`` use this
        return value, so the function is annotated with the model type
        rather than ``None``.
    """
    global _model
    if _model is None:
        _model = SentenceTransformer("pkshatech/GLuCoSE-base-ja")
    return _model
def get_model() -> SentenceTransformer:
    """
    Return the shared model, initializing it on first use.
    """
    global _model
    # Fast path: the model has already been loaded by an earlier call.
    if _model is not None:
        return _model
    _model = initialize_model()
    return _model
def load_embeddings() -> (
    Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[List[Dict[str, str]]]]
):
    """
    Load the question/answer embeddings and the QA data from ``DATA_DIR``.

    Reads ``question_embeddings.npy``, ``answer_embeddings.npy`` and
    ``qa_data.json``. The module-level caches are updated only after all
    three files have been read successfully, so a failure part-way through
    never leaves the caches holding a mix of new and old data.

    Returns:
        ``(question_embeddings, answer_embeddings, qa_data)`` on success, or
        ``(None, None, None)`` when a file is missing or cannot be read.
        Failures are reported with ``print`` and never raised.
    """
    global _question_embeddings, _answer_embeddings, _qa_data
    try:
        question_embeddings = np.load(
            os.path.join(DATA_DIR, "question_embeddings.npy")
        )
        answer_embeddings = np.load(os.path.join(DATA_DIR, "answer_embeddings.npy"))
        with open(os.path.join(DATA_DIR, "qa_data.json"), "r", encoding="utf-8") as f:
            qa_data = json.load(f)
    except FileNotFoundError as e:
        print(f"Warning: Embeddings not found. {str(e)}")
        return None, None, None
    except Exception as e:
        # Deliberately broad: this loader is best-effort and signals failure
        # through its return value instead of propagating the error.
        print(f"Error loading embeddings: {str(e)}")
        return None, None, None

    # Publish to the shared caches only once every file has loaded.
    _question_embeddings = question_embeddings
    _answer_embeddings = answer_embeddings
    _qa_data = qa_data
    return _question_embeddings, _answer_embeddings, _qa_data
def get_embeddings() -> (
    Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[List[Dict[str, str]]]]
):
    """
    Return the cached embeddings and QA data, loading them when needed.

    A load is triggered whenever any of the three cached values is still
    ``None``; whatever ``load_embeddings`` returns then replaces the caches.
    """
    global _question_embeddings, _answer_embeddings, _qa_data
    cached = (_question_embeddings, _answer_embeddings, _qa_data)
    if any(part is None for part in cached):
        cached = load_embeddings()
        _question_embeddings, _answer_embeddings, _qa_data = cached
    return _question_embeddings, _answer_embeddings, _qa_data
def reload_embeddings() -> bool:
    """
    Reload the embeddings and QA data from disk into the module caches.

    ``load_embeddings`` reports failure by returning ``None`` values rather
    than raising, so the result is checked explicitly instead of relying on
    ``len(None)`` raising inside a ``try`` block.

    Returns:
        True when all three values were loaded, False otherwise. On failure
        the caches are set to whatever ``load_embeddings`` returned.
    """
    global _question_embeddings, _answer_embeddings, _qa_data
    try:
        loaded = load_embeddings()
    except Exception as e:
        print(f"Error reloading embeddings: {str(e)}")
        return False

    _question_embeddings, _answer_embeddings, _qa_data = loaded
    if any(part is None for part in loaded):
        print("Error reloading embeddings: embedding files could not be loaded.")
        return False

    print(f"Embeddings reloaded successfully. {len(_qa_data)} QA pairs available.")
    return True
|