# OneKE/src/modules/knowledge_base/case_repository.py
import json
import os
import torch
import numpy as np
from utils import *
from sentence_transformers import SentenceTransformer
from rapidfuzz import process
from models import *
import copy
import warnings
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
docker_model_path = "/app/model/all-MiniLM-L6-v2"
warnings.filterwarnings("ignore", category=FutureWarning, message=r".*clean_up_tokenization_spaces.*")
class CaseRepository:
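    """In-memory store of good/bad extraction cases, indexed for both embedding and string similarity search."""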
def __init__(self):
        try:
            # Prefer the locally bundled model path used inside the Docker image.
            self.embedder = SentenceTransformer(docker_model_path)
        except Exception:
            # Fall back to the embedding model named in the config.
            self.embedder = SentenceTransformer(config['model']['embedding_model'])
self.embedder.to(device)
self.corpus = self.load_corpus()
self.embedded_corpus = self.embed_corpus()
def load_corpus(self):
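        """Read the case corpus from case_repository.json next to this module."""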
with open(os.path.join(os.path.dirname(__file__), "case_repository.json")) as file:
corpus = json.load(file)
return corpus
def update_corpus(self):
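        """Persist the in-memory corpus back to case_repository.json."""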
try:
with open(os.path.join(os.path.dirname(__file__), "case_repository.json"), "w") as file:
json.dump(self.corpus, file, indent=2)
except Exception as e:
print(f"Error when updating corpus: {e}")
def embed_corpus(self):
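        """Encode the embed_index of every good and bad case into tensors, grouped by task."""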
embedded_corpus = {}
for key, content in self.corpus.items():
good_index = [item['index']['embed_index'] for item in content['good']]
encoded_good_index = self.embedder.encode(good_index, convert_to_tensor=True).to(device)
bad_index = [item['index']['embed_index'] for item in content['bad']]
encoded_bad_index = self.embedder.encode(bad_index, convert_to_tensor=True).to(device)
embedded_corpus[key] = {"good": encoded_good_index, "bad": encoded_bad_index}
return embedded_corpus
def get_similarity_scores(self, task: TaskType, embed_index="", str_index="", case_type="", top_k=2):
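        """Score a query against the stored cases for a task.

        Combines embedding cosine similarity with RapidFuzz string similarity
        (weighted 50/50) and returns the top-k min-max normalized scores and
        indices, plus the unnormalized scores used for duplicate detection.
        """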
# Embedding similarity match
encoded_embed_query = self.embedder.encode(embed_index, convert_to_tensor=True).to(device)
embedding_similarity_matrix = self.embedder.similarity(encoded_embed_query, self.embedded_corpus[task][case_type])
embedding_similarity_scores = embedding_similarity_matrix[0].to(device)
        # String similarity match (RapidFuzz returns (choice, score, index) tuples)
        str_match_corpus = [item['index']['str_index'] for item in self.corpus[task][case_type]]
        str_similarity_results = process.extract(str_index, str_match_corpus, limit=len(str_match_corpus))
        # Map scores back by corpus position rather than by string value
        scores_in_order = [0.0] * len(str_match_corpus)
        for _, score, idx in str_similarity_results:
            scores_in_order[idx] = score
        str_similarity_scores = torch.tensor(scores_in_order, dtype=torch.float32).to(device)
# Normalize scores
embedding_score_range = embedding_similarity_scores.max() - embedding_similarity_scores.min()
str_score_range = str_similarity_scores.max() - str_similarity_scores.min()
if embedding_score_range > 0:
embed_norm_scores = (embedding_similarity_scores - embedding_similarity_scores.min()) / embedding_score_range
else:
embed_norm_scores = embedding_similarity_scores
if str_score_range > 0:
str_norm_scores = (str_similarity_scores - str_similarity_scores.min()) / str_score_range
else:
str_norm_scores = str_similarity_scores / 100
# Combine the scores with weights
combined_scores = 0.5 * embed_norm_scores + 0.5 * str_norm_scores
original_combined_scores = 0.5 * embedding_similarity_scores + 0.5 * str_similarity_scores / 100
scores, indices = torch.topk(combined_scores, k=min(top_k, combined_scores.size(0)))
original_scores, original_indices = torch.topk(original_combined_scores, k=min(top_k, original_combined_scores.size(0)))
return scores, indices, original_scores, original_indices
def query_case(self, task: TaskType, embed_index="", str_index="", case_type="", top_k=2) -> list:
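        """Return the content of the top-k stored cases most similar to the query."""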
_, indices, _, _ = self.get_similarity_scores(task, embed_index, str_index, case_type, top_k)
top_matches = [self.corpus[task][case_type][idx]["content"] for idx in indices]
return top_matches
    def update_case(self, task: TaskType, embed_index="", str_index="", content="", case_type=""):
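        """Append a new case to the corpus and extend its cached embeddings."""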
self.corpus[task][case_type].append({"index": {"embed_index": embed_index, "str_index": str_index}, "content": content})
self.embedded_corpus[task][case_type] = torch.cat([self.embedded_corpus[task][case_type], self.embedder.encode([embed_index], convert_to_tensor=True).to(device)], dim=0)
print(f"A {case_type} case updated for {task} task.")
class CaseRepositoryHandler:
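    """Mediates between the LLM and the case repository: queries similar cases and writes back new good/bad cases with LLM-generated analysis."""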
def __init__(self, llm: BaseEngine):
self.repository = CaseRepository()
self.llm = llm
def __get_good_case_analysis(self, instruction="", text="", result="", additional_info=""):
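        """Ask the LLM to analyze why a good case is correct.

        The analysis is returned as-is when it is not a JSON dict; if the
        model keeps producing a dict, retry up to three times and return None.
        """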
prompt = good_case_analysis_instruction.format(
instruction=instruction, text=text, result=result, additional_info=additional_info
)
for _ in range(3):
response = self.llm.get_chat_response(prompt)
response = extract_json_dict(response)
if not isinstance(response, dict):
return response
return None
def __get_bad_case_reflection(self, instruction="", text="", original_answer="", correct_answer="", additional_info=""):
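        """Ask the LLM to reflect on why a prediction was wrong.

        Same retry contract as __get_good_case_analysis: a non-dict response
        is returned immediately, with up to three attempts and None on failure.
        """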
prompt = bad_case_reflection_instruction.format(
instruction=instruction, text=text, original_answer=original_answer, correct_answer=correct_answer, additional_info=additional_info
)
for _ in range(3):
response = self.llm.get_chat_response(prompt)
response = extract_json_dict(response)
if not isinstance(response, dict):
return response
return None
def __get_index(self, data: DataPoint, case_type: str):
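        """Build the indices used to match a data point against stored cases: the embed_index is the distilled text plus the first chunk; the str_index is the instruction (Base task) or the constraint, with the original prediction appended for bad cases."""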
# set embed_index
embed_index = f"**Text**: {data.distilled_text}\n{data.chunk_text_list[0]}"
# set str_index
if data.task == "Base":
str_index = f"**Task**: {data.instruction}"
else:
str_index = f"{data.constraint}"
if case_type == "bad":
str_index += f"\n\n**Original Result**: {json.dumps(data.pred)}"
return embed_index, str_index
def query_good_case(self, data: DataPoint):
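        """Retrieve good cases similar to the given data point."""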
embed_index, str_index = self.__get_index(data, "good")
return self.repository.query_case(task=data.task, embed_index=embed_index, str_index=str_index, case_type="good")
def query_bad_case(self, data: DataPoint):
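        """Retrieve bad cases similar to the given data point."""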
embed_index, str_index = self.__get_index(data, "bad")
return self.repository.query_case(task=data.task, embed_index=embed_index, str_index=str_index, case_type="bad")
def update_good_case(self, data: DataPoint):
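        """Store the data point as a good case with an LLM analysis, unless no truth is provided or a near-duplicate (similarity >= 0.9) already exists."""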
        if data.truth == "":
print("No truth value provided.")
return
embed_index, str_index = self.__get_index(data, "good")
_, _, original_scores, _ = self.repository.get_similarity_scores(data.task, embed_index, str_index, "good", 1)
original_scores = original_scores.tolist()
if original_scores[0] >= 0.9:
print("The similar good case is already in the corpus. Similarity Score: ", original_scores[0])
return
        good_case_analysis = self.__get_good_case_analysis(instruction=data.instruction, text=data.distilled_text, result=data.truth, additional_info=data.constraint)
        wrapped_good_case_analysis = f"**Analysis**: {good_case_analysis}"
wrapped_instruction = f"**Task**: {data.instruction}"
wrapped_text = f"**Text**: {data.distilled_text}\n{data.chunk_text_list[0]}"
wrapped_answer = f"**Correct Answer**: {json.dumps(data.truth)}"
if data.task == "Base":
content = f"{wrapped_instruction}\n\n{wrapped_text}\n\n{wrapped_good_case_analysis}\n\n{wrapped_answer}"
else:
content = f"{wrapped_text}\n\n{data.constraint}\n\n{wrapped_good_case_analysis}\n\n{wrapped_answer}"
self.repository.update_case(data.task, embed_index, str_index, content, "good")
def update_bad_case(self, data: DataPoint):
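        """Store the data point as a bad case with an LLM reflection, skipping when no truth is provided, the prediction already matches the truth, or a near-duplicate (similarity >= 0.9) exists."""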
        if data.truth == "":
print("No truth value provided.")
return
if normalize_obj(data.pred) == normalize_obj(data.truth):
return
embed_index, str_index = self.__get_index(data, "bad")
_, _, original_scores, _ = self.repository.get_similarity_scores(data.task, embed_index, str_index, "bad", 1)
original_scores = original_scores.tolist()
if original_scores[0] >= 0.9:
print("The similar bad case is already in the corpus. Similarity Score: ", original_scores[0])
return
bad_case_reflection = self.__get_bad_case_reflection(instruction=data.instruction, text=data.distilled_text, original_answer=data.pred, correct_answer=data.truth, additional_info=data.constraint)
wrapped_bad_case_reflection = f"**Reflection**: {bad_case_reflection}"
        wrapped_original_answer = f"**Original Answer**: {json.dumps(data.pred)}"
        wrapped_correct_answer = f"**Correct Answer**: {json.dumps(data.truth)}"
wrapped_instruction = f"**Task**: {data.instruction}"
wrapped_text = f"**Text**: {data.distilled_text}\n{data.chunk_text_list[0]}"
if data.task == "Base":
content = f"{wrapped_instruction}\n\n{wrapped_text}\n\n{wrapper_original_answer}\n\n{wrapped_bad_case_reflection}\n\n{wrapper_correct_answer}"
else:
content = f"{wrapped_text}\n\n{data.constraint}\n\n{wrapper_original_answer}\n\n{wrapped_bad_case_reflection}\n\n{wrapper_correct_answer}"
self.repository.update_case(data.task, embed_index, str_index, content, "bad")
def update_case(self, data: DataPoint):
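        """Update both good and bad cases for the data point, then persist the corpus to disk."""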
self.update_good_case(data)
self.update_bad_case(data)
self.repository.update_corpus()