|
import json |
|
import os |
|
import torch |
|
import numpy as np |
|
from utils import * |
|
from sentence_transformers import SentenceTransformer |
|
from rapidfuzz import process |
|
from models import * |
|
import copy |
|
|
|
import warnings |
|
# Local directory holding the sentence-embedding model (presumably baked into
# the Docker image); CaseRepository falls back to the configured model name
# when loading from here fails.
docker_model_path = "/app/model/all-MiniLM-L6-v2"

# Run on the GPU when torch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Hide FutureWarnings whose message mentions `clean_up_tokenization_spaces`.
warnings.filterwarnings(
    "ignore",
    category=FutureWarning,
    message=r".*clean_up_tokenization_spaces*",
)
|
|
|
class CaseRepository:
    """In-memory store of "good" and "bad" example cases, retrievable by similarity.

    ``self.corpus`` is loaded from ``case_repository.json`` next to this module
    and is shaped ``{task: {"good": [case, ...], "bad": [case, ...]}}``, where
    each case is ``{"index": {"embed_index": ..., "str_index": ...},
    "content": ...}``. ``self.embedded_corpus`` mirrors it as
    ``{task: {"good": tensor, "bad": tensor}}`` with one embedding row per
    case, so a query only needs to embed the query text itself.
    """

    def __init__(self):
        try:
            self.embedder = SentenceTransformer(docker_model_path)
        except Exception:
            # The local model directory is presumably absent outside the
            # Docker image; fall back to the model named in ``config``.
            self.embedder = SentenceTransformer(config['model']['embedding_model'])
        self.embedder.to(device)
        self.corpus = self.load_corpus()
        self.embedded_corpus = self.embed_corpus()

    @staticmethod
    def _corpus_path():
        """Return the path of the JSON file that backs the corpus."""
        return os.path.join(os.path.dirname(__file__), "case_repository.json")

    def load_corpus(self):
        """Read and return the corpus from ``case_repository.json`` (UTF-8)."""
        with open(self._corpus_path(), encoding="utf-8") as file:
            return json.load(file)

    def update_corpus(self):
        """Write the in-memory corpus back to ``case_repository.json``.

        Best effort: failures are printed, not raised. The JSON text is fully
        serialised before the file is opened for writing, so a serialisation
        error cannot leave a truncated repository file behind.
        """
        try:
            payload = json.dumps(self.corpus, indent=2)
            with open(self._corpus_path(), "w", encoding="utf-8") as file:
                file.write(payload)
        except Exception as e:
            print(f"Error when updating corpus: {e}")

    def embed_corpus(self):
        """Embed every case's ``embed_index``, grouped as ``{task: {"good": t, "bad": t}}``."""
        embedded_corpus = {}
        for task, cases in self.corpus.items():
            embedded_corpus[task] = {
                case_type: self.embedder.encode(
                    [item['index']['embed_index'] for item in cases[case_type]],
                    convert_to_tensor=True,
                ).to(device)
                for case_type in ("good", "bad")
            }
        return embedded_corpus

    @staticmethod
    def _min_max_normalize(scores, fallback):
        """Rescale ``scores`` to [0, 1]; return ``fallback`` when all scores are equal."""
        score_range = scores.max() - scores.min()
        if score_range > 0:
            return (scores - scores.min()) / score_range
        return fallback

    def get_similarity_scores(self, task: TaskType, embed_index="", str_index="", case_type="", top_k=2):
        """Rank the stored cases of ``corpus[task][case_type]`` against a query.

        Each candidate gets two scores: the embedder's similarity between the
        ``embed_index`` query embedding and the stored embeddings, and the
        rapidfuzz string score (0-100 scale) between ``str_index`` and each
        stored ``str_index``. They are averaged 50/50 twice:

        * ``scores`` — after min-max normalising each score list per query;
        * ``original_scores`` — on the raw scales, with the string score
          divided by 100, so the result can be compared to an absolute threshold.

        Returns:
            ``(scores, indices, original_scores, original_indices)``: the
            ``min(top_k, n)`` best values and their corpus indices for each
            combination, as tensors on ``device``. All four are empty when the
            category holds no cases.
        """
        candidates = self.corpus[task][case_type]
        if not candidates:
            # similarity(), max() and topk() fail on empty tensors; report
            # "no matches" instead so the first case of a category can be added.
            empty_scores = torch.empty(0, device=device)
            empty_indices = torch.empty(0, dtype=torch.long, device=device)
            return empty_scores, empty_indices, empty_scores.clone(), empty_indices.clone()

        # Semantic similarity against the cached case embeddings.
        query_embedding = self.embedder.encode(embed_index, convert_to_tensor=True).to(device)
        embedding_scores = self.embedder.similarity(
            query_embedding, self.embedded_corpus[task][case_type]
        )[0].to(device)

        # Fuzzy string similarity; extract() returns matches in score order,
        # so map them back to corpus order.
        str_candidates = [item['index']['str_index'] for item in candidates]
        matches = process.extract(str_index, str_candidates, limit=len(str_candidates))
        score_by_candidate = {match[0]: match[1] for match in matches}
        str_scores = torch.tensor(
            [score_by_candidate[candidate] for candidate in str_candidates],
            dtype=torch.float32,
        ).to(device)

        embed_norm_scores = self._min_max_normalize(embedding_scores, embedding_scores)
        str_norm_scores = self._min_max_normalize(str_scores, str_scores / 100)

        combined_scores = 0.5 * embed_norm_scores + 0.5 * str_norm_scores
        original_combined_scores = 0.5 * embedding_scores + 0.5 * str_scores / 100

        k = min(top_k, combined_scores.size(0))
        scores, indices = torch.topk(combined_scores, k=k)
        original_scores, original_indices = torch.topk(original_combined_scores, k=k)
        return scores, indices, original_scores, original_indices

    def query_case(self, task: TaskType, embed_index="", str_index="", case_type="", top_k=2) -> list:
        """Return the ``content`` of up to ``top_k`` best-matching cases, best first.

        Ranking uses the per-query normalised combined score. Returns ``[]``
        when the category is empty.
        """
        _, indices, _, _ = self.get_similarity_scores(task, embed_index, str_index, case_type, top_k)
        cases = self.corpus[task][case_type]
        return [cases[idx]["content"] for idx in indices.tolist()]

    def update_case(self, task: TaskType, embed_index="", str_index="", content="", case_type=""):
        """Append a case in memory and extend the cached embeddings to match.

        Only the in-memory state changes; call ``update_corpus`` to persist it.
        """
        # Encode first so a failing encoder leaves corpus and embeddings in sync.
        new_embedding = self.embedder.encode([embed_index], convert_to_tensor=True).to(device)
        self.corpus[task][case_type].append(
            {"index": {"embed_index": embed_index, "str_index": str_index}, "content": content}
        )
        self.embedded_corpus[task][case_type] = torch.cat(
            [self.embedded_corpus[task][case_type], new_embedding], dim=0
        )
        print(f"A {case_type} case updated for {task} task.")
|
|
|
class CaseRepositoryHandler:
    """Bridge between ``DataPoint`` results, an LLM and a ``CaseRepository``.

    Builds retrieval keys from a data point and queries similar good/bad
    cases. When a ground truth is available, it asks the LLM for an analysis
    or reflection and stores the data point as a new case.
    """

    def __init__(self, llm: BaseEngine):
        self.repository = CaseRepository()
        self.llm = llm

    def __request_text_response(self, prompt, attempts=3):
        """Query the LLM up to ``attempts`` times and return the first non-dict reply.

        Each reply is passed through ``extract_json_dict``; a ``dict`` result
        is rejected and the prompt is re-sent. Returns ``None`` if every
        attempt yielded a dict.
        """
        # NOTE(review): presumably the analysis/reflection is meant to be free
        # text (it is embedded into a string below), so a parsed JSON dict means
        # the model returned structured output instead — confirm against
        # extract_json_dict and the prompt templates.
        for _ in range(attempts):
            response = extract_json_dict(self.llm.get_chat_response(prompt))
            if not isinstance(response, dict):
                return response
        return None

    def __get_good_case_analysis(self, instruction="", text="", result="", additional_info=""):
        """Ask the LLM, via ``good_case_analysis_instruction``, to analyse a correct result."""
        prompt = good_case_analysis_instruction.format(
            instruction=instruction, text=text, result=result, additional_info=additional_info
        )
        return self.__request_text_response(prompt)

    def __get_bad_case_reflection(self, instruction="", text="", original_answer="", correct_answer="", additional_info=""):
        """Ask the LLM, via ``bad_case_reflection_instruction``, to reflect on a wrong answer."""
        prompt = bad_case_reflection_instruction.format(
            instruction=instruction,
            text=text,
            original_answer=original_answer,
            correct_answer=correct_answer,
            additional_info=additional_info,
        )
        return self.__request_text_response(prompt)

    @staticmethod
    def __wrap_text(data):
        """Render the ``**Text**`` section used both as embed key and in stored content."""
        chunks = data.chunk_text_list
        # An empty chunk list would otherwise raise IndexError.
        first_chunk = chunks[0] if chunks else ""
        return f"**Text**: {data.distilled_text}\n{first_chunk}"

    @staticmethod
    def __has_no_truth(data):
        """Print a notice and return True when ``data`` has no ground truth to learn from."""
        if data.truth is None or data.truth == "":
            print("No truth value provided.")
            return True
        return False

    def __has_similar_case(self, task, embed_index, str_index, case_type, threshold=0.9):
        """Return True (and print the score) if a stored case's raw combined score >= ``threshold``."""
        _, _, original_scores, _ = self.repository.get_similarity_scores(task, embed_index, str_index, case_type, 1)
        top_scores = original_scores.tolist()
        # An empty result means the category holds no cases yet.
        if top_scores and top_scores[0] >= threshold:
            print(f"The similar {case_type} case is already in the corpus. Similarity Score: ", top_scores[0])
            return True
        return False

    def __get_index(self, data: DataPoint, case_type: str):
        """Build the ``(embed_index, str_index)`` retrieval keys for ``data``.

        ``embed_index`` is the formatted text, matched by embedding.
        ``str_index`` is matched by fuzzy string score. It is the instruction
        for the "Base" task and the constraint otherwise; for bad cases the
        JSON-encoded prediction is appended.
        """
        embed_index = self.__wrap_text(data)
        if data.task == "Base":
            str_index = f"**Task**: {data.instruction}"
        else:
            str_index = f"{data.constraint}"
        if case_type == "bad":
            str_index += f"\n\n**Original Result**: {json.dumps(data.pred)}"
        return embed_index, str_index

    def query_good_case(self, data: DataPoint):
        """Return the content of the good cases most similar to ``data``."""
        embed_index, str_index = self.__get_index(data, "good")
        return self.repository.query_case(task=data.task, embed_index=embed_index, str_index=str_index, case_type="good")

    def query_bad_case(self, data: DataPoint):
        """Return the content of the bad cases most similar to ``data`` (including its prediction)."""
        embed_index, str_index = self.__get_index(data, "bad")
        return self.repository.query_case(task=data.task, embed_index=embed_index, str_index=str_index, case_type="bad")

    def update_good_case(self, data: DataPoint):
        """Store ``data`` and its truth as a good case, with an LLM analysis.

        Skipped when there is no truth or a near-duplicate good case exists.
        Only the in-memory repository is updated.
        """
        if self.__has_no_truth(data):
            return
        embed_index, str_index = self.__get_index(data, "good")
        if self.__has_similar_case(data.task, embed_index, str_index, "good"):
            return
        good_case_analysis = self.__get_good_case_analysis(
            instruction=data.instruction,
            text=data.distilled_text,
            result=data.truth,
            additional_info=data.constraint,
        )
        wrapped_analysis = f"**Analysis**: {good_case_analysis}"
        wrapped_instruction = f"**Task**: {data.instruction}"
        wrapped_text = self.__wrap_text(data)
        wrapped_answer = f"**Correct Answer**: {json.dumps(data.truth)}"
        if data.task == "Base":
            content = f"{wrapped_instruction}\n\n{wrapped_text}\n\n{wrapped_analysis}\n\n{wrapped_answer}"
        else:
            content = f"{wrapped_text}\n\n{data.constraint}\n\n{wrapped_analysis}\n\n{wrapped_answer}"
        self.repository.update_case(data.task, embed_index, str_index, content, "good")

    def update_bad_case(self, data: DataPoint):
        """Store ``data`` as a bad case with an LLM reflection, if its prediction is wrong.

        Skipped when there is no truth, when the normalised prediction equals
        the truth, or when a near-duplicate bad case exists. Only the in-memory
        repository is updated.
        """
        if self.__has_no_truth(data):
            return
        if normalize_obj(data.pred) == normalize_obj(data.truth):
            return
        embed_index, str_index = self.__get_index(data, "bad")
        if self.__has_similar_case(data.task, embed_index, str_index, "bad"):
            return
        bad_case_reflection = self.__get_bad_case_reflection(
            instruction=data.instruction,
            text=data.distilled_text,
            original_answer=data.pred,
            correct_answer=data.truth,
            additional_info=data.constraint,
        )
        wrapped_reflection = f"**Reflection**: {bad_case_reflection}"
        wrapped_original_answer = f"**Original Answer**: {json.dumps(data.pred)}"
        wrapped_correct_answer = f"**Correct Answer**: {json.dumps(data.truth)}"
        wrapped_instruction = f"**Task**: {data.instruction}"
        wrapped_text = self.__wrap_text(data)
        if data.task == "Base":
            content = f"{wrapped_instruction}\n\n{wrapped_text}\n\n{wrapped_original_answer}\n\n{wrapped_reflection}\n\n{wrapped_correct_answer}"
        else:
            content = f"{wrapped_text}\n\n{data.constraint}\n\n{wrapped_original_answer}\n\n{wrapped_reflection}\n\n{wrapped_correct_answer}"
        self.repository.update_case(data.task, embed_index, str_index, content, "bad")

    def update_case(self, data: DataPoint):
        """Try to record ``data`` as both a good and a bad case, then persist the corpus."""
        self.update_good_case(data)
        self.update_bad_case(data)
        self.repository.update_corpus()