File size: 3,486 Bytes
44a025a
 
 
 
 
 
 
 
 
11aa943
 
 
44a025a
11aa943
44a025a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11aa943
 
44a025a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11aa943
 
 
 
 
 
44a025a
 
11aa943
44a025a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import pandas as pd
import json
import re
import numpy as np
import os
from typing import List, Dict, Tuple, Any

from app.services.model_service import get_model, reload_embeddings

# Directory that receives every artifact this module writes
# (raw.json, qa_data.json and the two embedding .npy files).
DATA_DIR = "/app/data"

# Create the directory as a side effect of importing this module, so the
# save functions can open files in it without checking first.
# exist_ok=True makes this a no-op when the directory is already there.
os.makedirs(DATA_DIR, exist_ok=True)


def remove_prefix(text: str, prefix_pattern: str) -> str:
    """
    Strip text matching a regular expression and trim surrounding whitespace.

    Every match of ``prefix_pattern`` in ``text`` is deleted, not only a
    leading one; anchor the pattern with ``^`` to restrict the removal to an
    actual prefix. Leading and trailing whitespace is then removed from the
    result.

    Args:
        text: The string to clean.
        prefix_pattern: Regular expression describing the text to remove.

    Returns:
        The cleaned string. ``text`` without matches comes back unchanged
        apart from the whitespace trim.

    Raises:
        re.error: If ``prefix_pattern`` is not a valid regular expression.
    """
    compiled = re.compile(prefix_pattern)
    without_matches = compiled.sub("", text)
    return without_matches.strip()


def process_file(file_path: str, file_type: str) -> List[Dict[str, str]]:
    """
    Process Excel or CSV file and extract question-answer pairs.

    The file must have a question column named '質問' and an answer column
    named '回答'. Rows in which either of those two cells is empty are
    skipped; empty cells in any other column do not affect the result.
    A leading "Q<number>." is removed from each question and a leading "A."
    from each answer.

    Args:
        file_path: Path of the file to read.
        file_type: Either "excel" or "csv"; selects the pandas reader.

    Returns:
        One ``{"question": ..., "answer": ...}`` dict per usable row, in file
        order.

    Raises:
        ValueError: If ``file_type`` is not supported or a required column is
            missing.
    """
    question_col = "質問"
    answer_col = "回答"

    if file_type == "excel":
        df = pd.read_excel(file_path)
    elif file_type == "csv":
        df = pd.read_csv(file_path)
    else:
        raise ValueError("Unsupported file type. Use 'excel' or 'csv'.")

    # Check if the necessary columns exist
    if question_col not in df.columns or answer_col not in df.columns:
        raise ValueError(
            f"The file must contain '{question_col}' and '{answer_col}' columns."
        )

    # Only the two columns that are actually read decide whether a row is
    # usable. A blanket dropna() would also discard complete Q/A pairs whose
    # unrelated columns (notes, categories, ...) happen to be empty.
    usable = df.dropna(subset=[question_col, answer_col])

    # str() because pandas may have parsed a cell as a number or a date.
    return [
        {
            "question": remove_prefix(str(raw_question), r"^Q\d+\.\s*"),
            "answer": remove_prefix(str(raw_answer), r"^A\.\s*"),
        }
        for raw_question, raw_answer in zip(usable[question_col], usable[answer_col])
    ]


def save_raw_data(qa_list: List[Dict[str, str]]) -> None:
    """
    Write the question-answer pairs to ``raw.json`` inside ``DATA_DIR``.

    The file is overwritten on every call. Output is UTF-8 JSON indented by
    two spaces, with non-ASCII characters written as-is rather than escaped.

    Args:
        qa_list: The question-answer dicts to serialize.
    """
    destination = os.path.join(DATA_DIR, "raw.json")
    with open(destination, mode="w", encoding="utf-8") as out_file:
        json.dump(qa_list, fp=out_file, indent=2, ensure_ascii=False)


def create_and_save_embeddings(qa_list: List[Dict[str, str]]) -> None:
    """
    Encode all questions and answers and persist the results in ``DATA_DIR``.

    Three files are (over)written:
      * ``question_embeddings.npy`` - encoder output for the questions
      * ``answer_embeddings.npy``   - encoder output for the answers
      * ``qa_data.json``            - ``qa_list`` itself as UTF-8 JSON

    Args:
        qa_list: Dicts that each carry a "question" and an "answer" entry.
    """
    question_texts = [pair["question"] for pair in qa_list]
    answer_texts = [pair["answer"] for pair in qa_list]

    # Shared model instance supplied by the model service.
    encoder = get_model()

    # Both encodings are computed before anything is written to disk.
    encoded = (
        ("question_embeddings.npy", encoder.encode(question_texts, convert_to_numpy=True)),
        ("answer_embeddings.npy", encoder.encode(answer_texts, convert_to_numpy=True)),
    )

    for filename, vectors in encoded:
        np.save(os.path.join(DATA_DIR, filename), vectors)

    # Keep the source pairs next to the embeddings they were computed from.
    pairs_path = os.path.join(DATA_DIR, "qa_data.json")
    with open(pairs_path, mode="w", encoding="utf-8") as pairs_file:
        json.dump(qa_list, fp=pairs_file, indent=2, ensure_ascii=False)


def process_and_create_embeddings(file_path: str, file_type: str) -> Dict[str, Any]:
    """
    Run the whole pipeline: parse the file, save the pairs, build embeddings.

    Failures are never raised to the caller; any exception from a pipeline
    step is converted into an error result instead.

    Args:
        file_path: Path of the Excel or CSV file to ingest.
        file_type: "excel" or "csv", forwarded to ``process_file``.

    Returns:
        On success ``{"status": "success", "message": ..., "data":
        {"total_qa_pairs": <count>}}``; on failure ``{"status": "error",
        "message": <exception text>}``.
    """
    try:
        pairs = process_file(file_path, file_type)
        save_raw_data(pairs)
        create_and_save_embeddings(pairs)

        # Have the model service reload its embeddings now that new files
        # have been written.
        reload_embeddings()
    except Exception as exc:
        return {"status": "error", "message": str(exc)}
    else:
        summary = {"total_qa_pairs": len(pairs)}
        return {
            "status": "success",
            "message": "Embeddings created successfully",
            "data": summary,
        }