File size: 3,530 Bytes
009d93e e6e7506 009d93e e6e7506 009d93e e6e7506 009d93e e6e7506 009d93e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
from models import *
from utils import *
from .extraction_agent import ExtractionAgent
from .knowledge_base.case_repository import CaseRepositoryHandler
class ReflectionGenerator:
    """Ask the LLM to reflect on, and re-generate, a single extraction result."""

    def __init__(self, llm: BaseEngine):
        self.llm = llm

    def get_reflection(self, instruction="", examples="", text="", schema="", result=""):
        """Build a reflection prompt for one result and return the parsed LLM answer.

        Args:
            instruction: The extraction instruction.
            examples: Bad-case examples; wrapped with ``bad_case_wrapper``
                before being placed in the prompt.
            text: The source text the result was extracted from.
            schema: The output schema.
            result: The previous extraction result; serialized to JSON.

        Returns:
            Whatever ``extract_json_dict`` parses from the LLM response
            (presumably a dict -- TODO confirm behavior on unparsable output).
        """
        # ensure_ascii=False keeps non-ASCII characters (e.g. CJK entity names)
        # literal, so the serialized result matches how `text` appears in the
        # prompt instead of being rendered as \uXXXX escapes.
        serialized_result = json.dumps(result, ensure_ascii=False)
        wrapped_examples = bad_case_wrapper(examples)
        prompt = reflect_instruction.format(
            instruction=instruction,
            examples=wrapped_examples,
            text=text,
            schema=schema,
            result=serialized_result,
        )
        response = self.llm.get_chat_response(prompt)
        return extract_json_dict(response)
class ReflectionAgent:
    """Refine extraction results via self-consistency voting plus LLM reflection.

    ``reflect_with_case`` works in three steps:

    1. Re-run the most recent extraction step twice, at temperatures 0.5 and 1.
    2. For each result position, keep an answer that at least two of the
       three runs agree on.
    3. Ask the LLM to reflect on the positions where no two runs agree, using
       bad cases retrieved from the case repository as examples.
    """

    def __init__(self, llm: BaseEngine, case_repo: CaseRepositoryHandler):
        self.llm = llm
        # Reflection prompts and extraction re-runs share the same engine.
        self.module = ReflectionGenerator(llm = llm)
        self.extractor = ExtractionAgent(llm = llm, case_repo = case_repo)
        self.case_repo = case_repo
        # Names of this agent's entry points (presumably dispatched by name by
        # the caller -- confirm).
        self.methods = ["reflect_with_case"]

    def __select_result(self, result_list):
        """Pick the candidate whose JSON serialization is longest.

        If any candidate is a dict, only dicts are considered. Ties go to the
        earliest candidate, because ``max`` returns the first maximal element.

        Raises:
            ValueError: If ``result_list`` is empty.
        """
        candidates = [obj for obj in result_list if isinstance(obj, dict)] or list(result_list)
        return max(candidates, key=lambda obj: len(json.dumps(obj)))

    def __self_consistance_check(self, data: DataPoint):
        """Vote on per-position results across the original run and two re-runs.

        The extraction method named by the last key of
        ``data.result_trajectory`` is re-run at temperatures 0.5 and 1. For
        each position, the first candidate whose ``normalize_obj`` form occurs
        at least twice is kept. Positions without such agreement fall back to
        ``__select_result`` and are reported for reflection.

        The voted list replaces ``data.result_list`` via ``set_result_list``.

        Note:
            Agreement is detected with ``next(..., None)``. A result that is
            itself ``None`` is therefore treated as "no agreement" even when
            two runs match on it. Such positions fall through to
            ``__select_result`` and reflection, as in the original
            implementation.

        Returns:
            Indices of positions where no two runs agreed. The list is empty
            when there is no trajectory, or when the last step is not an
            extractor method (in which case nothing is re-run).
        """
        trajectory = data.result_trajectory
        if not trajectory:
            return []
        extract_func = getattr(self.extractor, list(trajectory.keys())[-1], None)
        if extract_func is None:
            # The last step was not an extraction step: nothing to re-run.
            return []

        # Snapshot each run's results so that a later run mutating the list in
        # place cannot rewrite an earlier snapshot.
        result_trails = [list(data.result_list)]
        try:
            for temperature in (0.5, 1):
                self.module.llm.set_hyperparameter(temperature=temperature)
                data = extract_func(data)
                result_trails.append(list(data.result_list))
        finally:
            # Always reset the hyperparameters (the no-argument call), even if
            # a re-run raised.
            self.module.llm.set_hyperparameter()

        consistent_result = []
        reflect_index = []
        # NOTE: zip stops at the shortest run; this assumes every run yields
        # the same number of results.
        for index, candidates in enumerate(zip(*result_trails)):
            normalized = [normalize_obj(candidate) for candidate in candidates]
            counts = Counter(normalized)
            selected = next(
                (candidates[i] for i, key in enumerate(normalized) if counts[key] >= 2),
                None,
            )
            if selected is None:
                selected = self.__select_result(candidates)
                reflect_index.append(index)
            consistent_result.append(selected)
        data.set_result_list(consistent_result)
        return reflect_index

    def reflect_with_case(self, data: DataPoint):
        """Refine ``data.result_list`` with self-consistency voting and LLM reflection.

        If ``data.result_list`` is empty, ``data`` is returned unchanged and no
        trajectory entry is recorded. Otherwise the results are voted on by
        ``__self_consistance_check``. Each position without agreement is then
        replaced by the LLM's reflection on that chunk, using the case
        repository's bad cases as examples.

        Returns:
            The same ``data`` object, with its result list replaced and a
            trajectory entry recorded under this method's name.
        """
        if not data.result_list:
            return data
        reflect_index = self.__self_consistance_check(data)

        # Build the refined list on a copy so data.result_list is not mutated
        # while reflections are being generated; it is replaced wholesale below.
        reflected_result_list = list(data.result_list)
        if reflect_index:
            # data does not change inside the loop, so fetch the bad-case
            # examples once instead of once per index.
            examples = json.dumps(self.case_repo.query_bad_case(data), ensure_ascii=False)
            for idx in reflect_index:
                reflected_result_list[idx] = self.module.get_reflection(
                    instruction=data.instruction,
                    examples=examples,
                    text=data.chunk_text_list[idx],
                    schema=data.output_schema,
                    result=data.result_list[idx],
                )
        data.set_result_list(reflected_result_list)
        function_name = current_function_name()
        data.update_trajectory(function_name, data.result_list)
        return data
|