File size: 10,156 Bytes
009d93e
 
 
 
 
 
 
 
 
 
 
 
4754e33
009d93e
 
 
 
4754e33
 
 
 
e6e7506
4754e33
 
009d93e
 
4754e33
 
 
e6e7506
009d93e
4754e33
 
 
 
 
e6e7506
009d93e
4754e33
 
 
 
 
 
 
e6e7506
 
009d93e
4754e33
 
 
 
 
 
 
 
 
 
 
 
e6e7506
4754e33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e6e7506
4754e33
 
 
e6e7506
009d93e
4754e33
 
 
e6e7506
009d93e
4754e33
 
 
009d93e
 
 
 
 
 
 
4754e33
 
 
 
 
 
 
e6e7506
4754e33
e6e7506
009d93e
4754e33
 
 
 
 
 
 
 
 
e6e7506
009d93e
 
4754e33
e6e7506
4754e33
 
 
 
 
e6e7506
4754e33
 
e6e7506
4754e33
e6e7506
009d93e
4754e33
 
e6e7506
009d93e
4754e33
 
e6e7506
009d93e
4754e33
 
e6e7506
4754e33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e6e7506
009d93e
4754e33
 
e6e7506
4754e33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e6e7506
009d93e
4754e33
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
import json
import os
import torch
import numpy as np
from utils import *
from sentence_transformers import SentenceTransformer
from rapidfuzz import process
from models import *
import copy

import warnings
# Silence the tokenizer FutureWarning about `clean_up_tokenization_spaces`
# that sentence-transformers models emit on load/encode.
warnings.filterwarnings("ignore", category=FutureWarning, message=r".*clean_up_tokenization_spaces*")

# Embedding model bundled into the Docker image; CaseRepository tries this
# path first and falls back to the configured model name if loading fails.
docker_model_path = "/app/model/all-MiniLM-L6-v2"

# Run embeddings on the GPU when one is visible to torch, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

class CaseRepository:
    """Persistent store of good/bad example cases, retrievable by hybrid similarity.

    The corpus is loaded from ``case_repository.json`` next to this module. Based on
    how it is indexed below, it maps each task to ``{"good": [...], "bad": [...]}``
    lists whose items look like
    ``{"index": {"embed_index": str, "str_index": str}, "content": ...}``.
    ``embedded_corpus`` mirrors that layout with one embedding tensor per list,
    where row ``i`` embeds case ``i`` of the corresponding list.
    """

    # JSON file (in this module's directory) that persists the corpus.
    CORPUS_FILENAME = "case_repository.json"

    def __init__(self):
        try:
            # Prefer the model copy bundled into the Docker image.
            self.embedder = SentenceTransformer(docker_model_path)
        except Exception:
            # Fall back to the configured model name. `config` is not defined in
            # this block; presumably it comes from one of the star imports.
            self.embedder = SentenceTransformer(config['model']['embedding_model'])
        self.embedder.to(device)
        self.corpus = self.load_corpus()
        self.embedded_corpus = self.embed_corpus()

    def _corpus_path(self):
        """Return the absolute path of the JSON file backing the corpus."""
        return os.path.join(os.path.dirname(__file__), self.CORPUS_FILENAME)

    def _encode(self, texts):
        """Embed ``texts`` (a string or list of strings) as a tensor on ``device``."""
        return self.embedder.encode(texts, convert_to_tensor=True).to(device)

    def load_corpus(self):
        """Read and return the corpus dict from the JSON file."""
        with open(self._corpus_path(), encoding="utf-8") as file:
            corpus = json.load(file)
        return corpus

    def update_corpus(self):
        """Persist ``self.corpus`` to the JSON file (best effort).

        The data is first written to a temporary sibling file and then moved over
        the original with ``os.replace``. A failure mid-write therefore cannot
        leave a truncated corpus behind. Errors are printed and swallowed.
        """
        path = self._corpus_path()
        tmp_path = f"{path}.tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as file:
                json.dump(self.corpus, file, indent=2)
            os.replace(tmp_path, path)
        except Exception as e:
            print(f"Error when updating corpus: {e}")
            # Remove any partial temp file; the original corpus file is untouched.
            try:
                os.remove(tmp_path)
            except OSError:
                pass

    def embed_corpus(self):
        """Embed every stored case's ``embed_index``, grouped like the corpus.

        Returns:
            ``{task: {"good": Tensor, "bad": Tensor}}``. Tensor rows follow the
            order of the corresponding case lists.
        """
        embedded_corpus = {}
        for key, content in self.corpus.items():
            embedded_corpus[key] = {
                case_type: self._encode([item['index']['embed_index'] for item in content[case_type]])
                for case_type in ("good", "bad")
            }
        return embedded_corpus

    def get_similarity_scores(self, task: TaskType, embed_index="", str_index="", case_type="", top_k=2):
        """Rank the stored ``case_type`` cases of ``task`` against a query.

        Two signals are averaged with equal weight:
        - the embedder's ``similarity`` between ``embed_index`` and the cached
          case embeddings;
        - the rapidfuzz ``process.extract`` score (0-100) between ``str_index``
          and each case's ``str_index``.

        Returns:
            ``(scores, indices, original_scores, original_indices)``:
            - ``scores``/``indices``: the top ``min(top_k, n_cases)`` entries of
              the min-max normalised combination.
            - ``original_scores``/``original_indices``: the same for the raw
              combination, with the string score scaled by 1/100.
            When no cases are stored, all four are empty tensors, so callers
            must check for emptiness before indexing.
        """
        cases = self.corpus[task][case_type]
        if not cases:
            # Nothing to compare against; .max()/.min() below would raise on an
            # empty tensor.
            empty_scores = torch.empty(0, device=device)
            empty_indices = torch.empty(0, dtype=torch.long, device=device)
            return empty_scores, empty_indices, empty_scores.clone(), empty_indices.clone()

        # Embedding similarity match
        encoded_embed_query = self._encode(embed_index)
        embedding_similarity_matrix = self.embedder.similarity(encoded_embed_query, self.embedded_corpus[task][case_type])
        embedding_similarity_scores = embedding_similarity_matrix[0].to(device)

        # String similarity match. extract() returns results sorted by score, so
        # they are mapped back to corpus order via the matched string.
        str_match_corpus = [item['index']['str_index'] for item in cases]
        str_similarity_results = process.extract(str_index, str_match_corpus, limit=len(str_match_corpus))
        scores_dict = {match[0]: match[1] for match in str_similarity_results}
        scores_in_order = [scores_dict[candidate] for candidate in str_match_corpus]
        str_similarity_scores = torch.tensor(scores_in_order, dtype=torch.float32).to(device)

        # Min-max normalise each signal. When a signal is constant, the embedding
        # scores are kept as-is and the string scores are only rescaled to 0-1.
        embedding_score_range = embedding_similarity_scores.max() - embedding_similarity_scores.min()
        str_score_range = str_similarity_scores.max() - str_similarity_scores.min()
        if embedding_score_range > 0:
            embed_norm_scores = (embedding_similarity_scores - embedding_similarity_scores.min()) / embedding_score_range
        else:
            embed_norm_scores = embedding_similarity_scores
        if str_score_range > 0:
            str_norm_scores = (str_similarity_scores - str_similarity_scores.min()) / str_score_range
        else:
            str_norm_scores = str_similarity_scores / 100

        # Combine the scores with equal weights (normalised and raw variants).
        combined_scores = 0.5 * embed_norm_scores + 0.5 * str_norm_scores
        original_combined_scores = 0.5 * embedding_similarity_scores + 0.5 * str_similarity_scores / 100

        k = min(top_k, combined_scores.size(0))
        scores, indices = torch.topk(combined_scores, k=k)
        original_scores, original_indices = torch.topk(original_combined_scores, k=k)
        return scores, indices, original_scores, original_indices

    def query_case(self, task: TaskType, embed_index="", str_index="", case_type="", top_k=2) -> list:
        """Return the ``content`` of up to ``top_k`` best-matching cases (normalised ranking)."""
        _, indices, _, _ = self.get_similarity_scores(task, embed_index, str_index, case_type, top_k)
        cases = self.corpus[task][case_type]
        return [cases[idx]["content"] for idx in indices.tolist()]

    def update_case(self, task: TaskType, embed_index="", str_index="", content="" ,case_type=""):
        """Append a case in memory and extend the cached embeddings to match.

        Only the in-memory corpus changes; call ``update_corpus`` to persist it.
        """
        self.corpus[task][case_type].append({"index": {"embed_index": embed_index, "str_index": str_index}, "content": content})
        self.embedded_corpus[task][case_type] = torch.cat(
            [self.embedded_corpus[task][case_type], self._encode([embed_index])], dim=0
        )
        print(f"A {case_type} case updated for {task} task.")

class CaseRepositoryHandler:
    """Query and grow a ``CaseRepository``, using an LLM to annotate new cases.

    Good cases are stored with an LLM-written analysis; bad cases (prediction
    differs from truth) are stored with an LLM-written reflection.
    """

    # Raw combined similarity (see CaseRepository.get_similarity_scores) at or
    # above which a new case counts as a duplicate and is not stored.
    DUPLICATE_THRESHOLD = 0.9
    # Number of LLM attempts before giving up on an analysis/reflection.
    MAX_LLM_ATTEMPTS = 3

    def __init__(self, llm: BaseEngine):
        self.repository = CaseRepository()
        self.llm = llm

    def _request_free_text(self, prompt):
        """Query the LLM up to MAX_LLM_ATTEMPTS times.

        Returns:
            The first post-processed reply that is not a dict, or None if every
            attempt produced a dict.
        """
        for _ in range(self.MAX_LLM_ATTEMPTS):
            response = extract_json_dict(self.llm.get_chat_response(prompt))
            # NOTE(review): a reply is accepted only when extract_json_dict does
            # NOT yield a dict, i.e. the analysis is expected as plain text and a
            # JSON reply triggers a retry. Confirm this is intended.
            if not isinstance(response, dict):
                return response
        return None

    def __get_good_case_analysis(self, instruction="", text="", result="", additional_info=""):
        """Ask the LLM why ``result`` is correct for the task; None on failure."""
        prompt = good_case_analysis_instruction.format(
            instruction=instruction, text=text, result=result, additional_info=additional_info
        )
        return self._request_free_text(prompt)

    def __get_bad_case_reflection(self, instruction="", text="", original_answer="", correct_answer="", additional_info=""):
        """Ask the LLM to reflect on a wrong answer vs. the correct one; None on failure."""
        prompt = bad_case_reflection_instruction.format(
            instruction=instruction, text=text, original_answer=original_answer, correct_answer=correct_answer, additional_info=additional_info
        )
        return self._request_free_text(prompt)

    def __get_index(self, data: DataPoint, case_type: str):
        """Build the ``(embed_index, str_index)`` retrieval keys for ``data``.

        - ``embed_index``: the distilled text plus the first text chunk.
        - ``str_index``: the task instruction (for the "Base" task) or the
          constraint. For bad cases, the JSON-encoded prediction is appended.
        """
        # set embed_index
        embed_index = f"**Text**: {data.distilled_text}\n{data.chunk_text_list[0]}"

        # set str_index
        if data.task == "Base":
            str_index = f"**Task**: {data.instruction}"
        else:
            str_index = f"{data.constraint}"

        if case_type == "bad":
            str_index += f"\n\n**Original Result**: {json.dumps(data.pred)}"

        return embed_index, str_index

    def _is_duplicate(self, task, embed_index, str_index, case_type):
        """Return True (and report it) if a near-identical case is already stored."""
        _, _, original_scores, _ = self.repository.get_similarity_scores(task, embed_index, str_index, case_type, 1)
        original_scores = original_scores.tolist()
        # An empty list means no cases of this type exist yet.
        if original_scores and original_scores[0] >= self.DUPLICATE_THRESHOLD:
            print(f"The similar {case_type} case is already in the corpus. Similarity Score: ", original_scores[0])
            return True
        return False

    def query_good_case(self, data: DataPoint):
        """Return the contents of the best-matching stored good cases for ``data``."""
        embed_index, str_index = self.__get_index(data, "good")
        return self.repository.query_case(task=data.task, embed_index=embed_index, str_index=str_index, case_type="good")

    def query_bad_case(self, data: DataPoint):
        """Return the contents of the best-matching stored bad cases for ``data``."""
        embed_index, str_index = self.__get_index(data, "bad")
        return self.repository.query_case(task=data.task, embed_index=embed_index, str_index=str_index, case_type="bad")

    def update_good_case(self, data: DataPoint):
        """Store ``data`` with its truth and an LLM analysis as a good case.

        Skipped when there is no truth value, a near-duplicate already exists,
        or the LLM fails to produce an analysis.
        """
        if data.truth == "" :
            print("No truth value provided.")
            return
        embed_index, str_index = self.__get_index(data, "good")
        if self._is_duplicate(data.task, embed_index, str_index, "good"):
            return
        good_case_analysis = self.__get_good_case_analysis(instruction=data.instruction, text=data.distilled_text, result=data.truth, additional_info=data.constraint)
        if good_case_analysis is None:
            # Don't pollute the repository with a placeholder "None" analysis.
            print("Failed to obtain a good case analysis; case not stored.")
            return
        wrapped_good_case_analysis = f"**Analysis**: {good_case_analysis}"
        wrapped_instruction = f"**Task**: {data.instruction}"
        wrapped_text = embed_index  # identical "**Text**: ..." rendering
        wrapped_answer = f"**Correct Answer**: {json.dumps(data.truth)}"
        if data.task == "Base":
            parts = [wrapped_instruction, wrapped_text, wrapped_good_case_analysis, wrapped_answer]
        else:
            parts = [wrapped_text, data.constraint, wrapped_good_case_analysis, wrapped_answer]
        content = "\n\n".join(f"{part}" for part in parts)
        self.repository.update_case(data.task, embed_index, str_index, content, "good")

    def update_bad_case(self, data: DataPoint):
        """Store ``data`` with its wrong prediction and an LLM reflection as a bad case.

        Skipped when there is no truth value, the normalised prediction equals
        the truth, a near-duplicate already exists, or the LLM fails to produce
        a reflection.
        """
        if data.truth == "" :
            print("No truth value provided.")
            return
        if normalize_obj(data.pred) == normalize_obj(data.truth):
            return
        embed_index, str_index = self.__get_index(data, "bad")
        if self._is_duplicate(data.task, embed_index, str_index, "bad"):
            return
        bad_case_reflection = self.__get_bad_case_reflection(instruction=data.instruction, text=data.distilled_text, original_answer=data.pred, correct_answer=data.truth, additional_info=data.constraint)
        if bad_case_reflection is None:
            # Don't pollute the repository with a placeholder "None" reflection.
            print("Failed to obtain a bad case reflection; case not stored.")
            return
        wrapped_bad_case_reflection = f"**Reflection**: {bad_case_reflection}"
        wrapped_original_answer = f"**Original Answer**: {json.dumps(data.pred)}"
        wrapped_correct_answer = f"**Correct Answer**: {json.dumps(data.truth)}"
        wrapped_instruction = f"**Task**: {data.instruction}"
        wrapped_text = embed_index  # identical "**Text**: ..." rendering
        if data.task == "Base":
            parts = [wrapped_instruction, wrapped_text, wrapped_original_answer, wrapped_bad_case_reflection, wrapped_correct_answer]
        else:
            parts = [wrapped_text, data.constraint, wrapped_original_answer, wrapped_bad_case_reflection, wrapped_correct_answer]
        content = "\n\n".join(f"{part}" for part in parts)
        self.repository.update_case(data.task, embed_index, str_index, content, "bad")

    def update_case(self, data: DataPoint):
        """Try to record ``data`` as both a good and a bad case, then persist the corpus."""
        self.update_good_case(data)
        self.update_bad_case(data)
        self.repository.update_corpus()