from models import *
from utils import *
from .knowledge_base import schema_repository
from langchain_core.output_parsers import JsonOutputParser


class SchemaAnalyzer:
    def __init__(self, llm: BaseEngine):
        self.llm = llm

    def serialize_schema(self, schema) -> str:
        # Plain containers and strings are passed through unchanged.
        if isinstance(schema, (str, list, dict, set, tuple)):
            return schema
        try:
            parser = JsonOutputParser(pydantic_object=schema)
            schema_description = parser.get_format_instructions()
            # The format instructions wrap the JSON schema in a ``` fence; pull out that content.
            schema_content = re.findall(r'```(.*?)```', schema_description, re.DOTALL)
            explanation = "For example, for the schema {\"properties\": {\"foo\": {\"title\": \"Foo\", \"description\": \"a list of strings\", \"type\": \"array\", \"items\": {\"type\": \"string\"}}}}, the object {\"foo\": [\"bar\", \"baz\"]} is a well-formatted instance."
            schema = f"{schema_content}\n\n{explanation}"
        except Exception:
            return schema
        return schema

    def redefine_text(self, text_analysis):
        # Expects a dict with 'field' and 'genre'; anything else is returned unchanged.
        try:
            field = text_analysis['field']
            genre = text_analysis['genre']
        except Exception:
            return text_analysis
        prompt = f"This text is from the field of {field} and represents the genre of {genre}."
        return prompt

    def get_text_analysis(self, text: str):
        output_schema = self.serialize_schema(schema_repository.TextDescription)
        prompt = text_analysis_instruction.format(examples="", text=text, schema=output_schema)
        response = self.llm.get_chat_response(prompt)
        response = extract_json_dict(response)
        response = self.redefine_text(response)
        return response

    def get_deduced_schema_json(self, instruction: str, text: str, distilled_text: str):
        prompt = deduced_schema_json_instruction.format(examples=example_wrapper(json_schema_examples), instruction=instruction, distilled_text=distilled_text, text=text)
        response = self.llm.get_chat_response(prompt)
        response = extract_json_dict(response)
        code = response
        print(f"Deduced Schema in Json: \n{response}\n\n")
        return code, response

    def get_deduced_schema_code(self, instruction: str, text: str, distilled_text: str):
        prompt = deduced_schema_code_instruction.format(examples=example_wrapper(code_schema_examples), instruction=instruction, distilled_text=distilled_text, text=text)
        response = self.llm.get_chat_response(prompt)
        code_blocks = re.findall(r'```[^\n]*\n(.*?)\n```', response, re.DOTALL)
        if code_blocks:
            try:
                # Execute the last generated code block and look for the expected class.
                code_block = code_blocks[-1]
                namespace = {}
                exec(code_block, namespace)
                schema = namespace.get('ExtractionTarget')
                if schema is not None:
                    # Keep only the class definition portion for display.
                    index = code_block.find("class")
                    code = code_block[index:]
                    print(f"Deduced Schema in Code: \n{code}\n\n")
                    schema = self.serialize_schema(schema)
                    return code, schema
            except Exception as e:
                print(e)
                return self.get_deduced_schema_json(instruction, text, distilled_text)
        # Fall back to JSON-based schema deduction if no usable code block was produced.
        return self.get_deduced_schema_json(instruction, text, distilled_text)


class SchemaAgent:
    def __init__(self, llm: BaseEngine):
        self.llm = llm
        self.module = SchemaAnalyzer(llm=llm)
        self.schema_repo = schema_repository
        self.methods = ["get_default_schema", "get_retrieved_schema", "get_deduced_schema"]

    def __preprocess_text(self, data: DataPoint):
        # Chunk the input and attach a human-readable schema description for the task type.
        if data.use_file:
            data.chunk_text_list = chunk_file(data.file_path)
        else:
            data.chunk_text_list = chunk_str(data.text)
        if data.task == "NER":
            data.print_schema = """
class Entity(BaseModel):
    name : str = Field(description="The specific name of the entity.")
    type : str = Field(description="The type or category that the entity belongs to.")
class EntityList(BaseModel):
    entity_list : List[Entity] = Field(description="Named entities appearing in the text.")
"""
        elif data.task == "RE":
            data.print_schema = """
class Relation(BaseModel):
    head : str = Field(description="The starting entity in the relationship.")
    tail : str = Field(description="The ending entity in the relationship.")
    relation : str = Field(description="The predicate that defines the relationship between the two entities.")
class RelationList(BaseModel):
    relation_list : List[Relation] = Field(description="The collection of relationships between various entities.")
"""
        elif data.task == "EE":
            data.print_schema = """
class Event(BaseModel):
    event_type : str = Field(description="The type of the event.")
    event_trigger : str = Field(description="A specific word or phrase that indicates the occurrence of the event.")
    event_argument : dict = Field(description="The arguments or participants involved in the event.")
class EventList(BaseModel):
    event_list : List[Event] = Field(description="The events presented in the text.")
"""
        elif data.task == "Triple":
            data.print_schema = """
class Triple(BaseModel):
    head: str = Field(description="The subject or head of the triple.")
    head_type: str = Field(description="The type of the subject entity.")
    relation: str = Field(description="The predicate or relation between the entities.")
    relation_type: str = Field(description="The type of the relation.")
    tail: str = Field(description="The object or tail of the triple.")
    tail_type: str = Field(description="The type of the object entity.")
class TripleList(BaseModel):
    triple_list: List[Triple] = Field(description="The collection of triples and their types presented in the text.")
"""
        return data

    def get_default_schema(self, data: DataPoint):
        data = self.__preprocess_text(data)
        default_schema = config['agent']['default_schema']
        data.set_schema(default_schema)
        function_name = current_function_name()
        data.update_trajectory(function_name, default_schema)
        return data

    def get_retrieved_schema(self, data: DataPoint):
        self.__preprocess_text(data)
        schema_name = data.output_schema
        schema_class = getattr(self.schema_repo, schema_name, None)
        if schema_class is not None:
            schema = self.module.serialize_schema(schema_class)
            default_schema = config['agent']['default_schema']
            data.set_schema(f"{default_schema}\n{schema}")
            function_name = current_function_name()
            data.update_trajectory(function_name, schema)
        else:
            # Fall back to the default schema when the requested class is not in the repository.
            return self.get_default_schema(data)
        return data

    def get_deduced_schema(self, data: DataPoint):
        self.__preprocess_text(data)
        target_text = data.chunk_text_list[0]
        analysed_text = self.module.get_text_analysis(target_text)
        if len(data.chunk_text_list) > 1:
            prefix = "Below is a portion of the text to be extracted. "
            analysed_text = f"{prefix}\n{target_text}"
        distilled_text = self.module.redefine_text(analysed_text)
        code, deduced_schema = self.module.get_deduced_schema_code(data.instruction, target_text, distilled_text)
        data.print_schema = code
        data.set_distilled_text(distilled_text)
        default_schema = config['agent']['default_schema']
        data.set_schema(f"{default_schema}\n{deduced_schema}")
        function_name = current_function_name()
        data.update_trajectory(function_name, deduced_schema)
        return data