|
from models import * |
|
from utils import * |
|
from .extraction_agent import ExtractionAgent |
|
from .knowledge_base.case_repository import CaseRepositoryHandler |
|
class ReflectionGenerator:
    """Ask the LLM to reflect on, and return a revised version of, an extraction result."""

    def __init__(self, llm: BaseEngine):
        # Engine used to answer the reflection prompt.
        self.llm = llm

    def get_reflection(self, instruction="", examples="", text="", schema="", result=""):
        """Build the reflection prompt, query the LLM and parse its JSON reply.

        Args:
            instruction: task instruction inserted into the prompt template.
            examples: bad-case examples; passed through ``bad_case_wrapper``
                before being inserted into the prompt.
            text: source text the result was extracted from.
            schema: output schema inserted into the prompt.
            result: the extraction result to reflect on; serialized with
                ``json.dumps`` before being inserted into the prompt.

        Returns:
            Whatever ``extract_json_dict`` parses out of the LLM response.
        """
        # ensure_ascii=False keeps non-ASCII characters verbatim, so the result
        # appears in the prompt in the same surface form as `text`/`schema`
        # instead of as \uXXXX escapes.
        serialized_result = json.dumps(result, ensure_ascii=False)
        prompt = reflect_instruction.format(
            instruction=instruction,
            examples=bad_case_wrapper(examples),
            text=text,
            schema=schema,
            result=serialized_result,
        )
        response = self.llm.get_chat_response(prompt)
        return extract_json_dict(response)
|
|
|
class ReflectionAgent:
    """Refine per-chunk extraction results with self-consistency voting and LLM reflection.

    The last extraction step recorded in the data point's trajectory is re-run
    twice at different temperatures. For each chunk, a result that appears at
    least twice across the three runs is kept. Chunks without such a majority
    are sent to ``ReflectionGenerator`` together with bad-case examples from
    the case repository.
    """

    def __init__(self, llm: BaseEngine, case_repo: CaseRepositoryHandler):
        self.llm = llm
        # Builds and sends the reflection prompt.
        self.module = ReflectionGenerator(llm=llm)
        # Used to re-run the most recent extraction method for voting.
        self.extractor = ExtractionAgent(llm=llm, case_repo=case_repo)
        self.case_repo = case_repo
        # Names of the reflection methods this agent offers
        # (presumably looked up by name by the caller).
        self.methods = ["reflect_with_case"]

    def __select_result(self, result_list):
        """Pick a fallback result when the runs disagree.

        Dict candidates are preferred. Among the preferred candidates, the one
        with the longest ``json.dumps`` serialization wins; ``max`` keeps the
        first on ties.
        """
        dict_objects = [obj for obj in result_list if isinstance(obj, dict)]
        candidates = dict_objects or result_list
        return max(candidates, key=lambda obj: len(json.dumps(obj)))

    def __self_consistance_check(self, data: DataPoint):
        """Vote over three extraction runs and store the winners on ``data``.

        The method named by the last key of ``data.result_trajectory`` is looked
        up on the extractor. It is re-run at temperature 0.5 and then 1. The
        original results plus these two runs are compared chunk by chunk after
        ``normalize_obj``.

        Returns:
            list[int]: indices of chunks where no result appeared at least
            twice. Those chunks hold the ``__select_result`` fallback and
            should be reflected on. The list is empty when there is no
            trajectory or the last step is not an extractor method; in that
            case ``data`` is left untouched.
        """
        trajectory = data.result_trajectory
        if not trajectory:
            return []
        extract_name = list(trajectory.keys())[-1]
        if not hasattr(self.extractor, extract_name):
            # Previously this path fell through and returned None, which made
            # the caller's `for idx in reflect_index` raise TypeError.
            return []
        extract_func = getattr(self.extractor, extract_name)

        result_trails = [data.result_list]
        try:
            for temperature in (0.5, 1):
                self.module.llm.set_hyperparameter(temperature=temperature)
                # NOTE(review): the caller keeps using its own `data` reference,
                # so this assumes extract_func updates and returns the same
                # DataPoint object — confirm against ExtractionAgent.
                data = extract_func(data)
                result_trails.append(data.result_list)
        finally:
            # Called with no arguments even when a re-extraction raises,
            # presumably restoring the engine's default settings.
            self.module.llm.set_hyperparameter()

        consistant_result = []
        reflect_index = []
        for index, elements in enumerate(zip(*result_trails)):
            normalized_elements = [normalize_obj(e) for e in elements]
            element_counts = Counter(normalized_elements)
            # First original (un-normalized) result whose normalized form occurs >= 2 times.
            selected_element = next(
                (elements[i] for i, element in enumerate(normalized_elements)
                 if element_counts[element] >= 2),
                None,
            )
            if selected_element is None:
                selected_element = self.__select_result(elements)
                reflect_index.append(index)
            consistant_result.append(selected_element)

        data.set_result_list(consistant_result)
        return reflect_index

    def reflect_with_case(self, data: DataPoint):
        """Run self-consistency voting, then reflect on chunks that had no majority.

        Args:
            data: data point whose ``result_list`` holds one result per entry
                of ``chunk_text_list``.

        Returns:
            The same ``data`` object, with its result list updated and a
            trajectory entry recorded. If the result list is empty, ``data``
            is returned unchanged.
        """
        if not data.result_list:
            return data

        reflect_index = self.__self_consistance_check(data)

        # Same list object as data.result_list; entries are replaced in place.
        reflected_result_list = data.result_list
        for idx in reflect_index:
            text = data.chunk_text_list[idx]
            result = data.result_list[idx]
            examples = json.dumps(self.case_repo.query_bad_case(data), ensure_ascii=False)
            reflected_res = self.module.get_reflection(
                instruction=data.instruction,
                examples=examples,
                text=text,
                schema=data.output_schema,
                result=result,
            )
            reflected_result_list[idx] = reflected_res

        data.set_result_list(reflected_result_list)
        # Called directly in this method so it can report this method's name
        # (presumably "reflect_with_case").
        function_name = current_function_name()
        data.update_trajectory(function_name, data.result_list)
        return data
|
|