File size: 7,767 Bytes
009d93e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e6e7506
 
 
 
 
 
 
 
 
 
 
 
009d93e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
from models import *
from utils import *
from .knowledge_base import schema_repository
from langchain_core.output_parsers import JsonOutputParser

class SchemaAnalyzer:
    """LLM-backed helper that analyzes text and derives extraction schemas.

    Wraps a ``BaseEngine`` LLM and provides utilities to serialize pydantic
    schema classes into prompt-ready text, summarize a text's field/genre,
    and deduce an output schema (preferring generated code, falling back to
    JSON) for an information-extraction task.
    """

    def __init__(self, llm: "BaseEngine"):
        self.llm = llm

    def serialize_schema(self, schema) -> str:
        """Render *schema* as prompt-ready text.

        Plain strings/containers are returned unchanged. A pydantic model
        class is converted to JSON format instructions via LangChain's
        ``JsonOutputParser``. On any conversion failure the input is
        returned as-is (deliberate best-effort behavior).
        """
        if isinstance(schema, (str, list, dict, set, tuple)):
            return schema
        try:
            parser = JsonOutputParser(pydantic_object=schema)
            schema_description = parser.get_format_instructions()
            # re.findall returns a LIST of matches; the original interpolated
            # the list itself, yielding "['...']" with escaped newlines.
            # Use the first fenced block's content instead.
            matches = re.findall(r'```(.*?)```', schema_description, re.DOTALL)
            schema_content = matches[0] if matches else schema_description
            explanation = "For example, for the schema {\"properties\": {\"foo\": {\"title\": \"Foo\", \"description\": \"a list of strings\", \"type\": \"array\", \"items\": {\"type\": \"string\"}}}}, the object {\"foo\": [\"bar\", \"baz\"]} is a well-formatted instance."
            schema = f"{schema_content}\n\n{explanation}"
        except Exception:
            # Narrowed from a bare except; conversion is best-effort.
            return schema
        return schema

    def redefine_text(self, text_analysis):
        """Turn a field/genre analysis mapping into a one-line description.

        Returns *text_analysis* unchanged when it is not a mapping with
        'field' and 'genre' keys (e.g. it is already a plain string).
        """
        try:
            field = text_analysis['field']
            genre = text_analysis['genre']
        except (TypeError, KeyError, IndexError):
            # Narrowed from a bare except: non-dict input or missing keys.
            return text_analysis
        prompt = f"This text is from the field of {field} and represents the genre of {genre}."
        return prompt

    def get_text_analysis(self, text: str):
        """Ask the LLM to classify *text* and return a one-line description.

        The raw LLM response is parsed as JSON and, when it contains the
        expected keys, rewritten by :meth:`redefine_text`.
        """
        output_schema = self.serialize_schema(schema_repository.TextDescription)
        prompt = text_analysis_instruction.format(examples="", text=text, schema=output_schema)
        response = self.llm.get_chat_response(prompt)
        response = extract_json_dict(response)
        response = self.redefine_text(response)
        return response

    def get_deduced_schema_json(self, instruction: str, text: str, distilled_text: str):
        """Deduce an output schema as a JSON dict; returns ``(code, schema)``.

        Both tuple members are the same parsed dict — kept for interface
        symmetry with :meth:`get_deduced_schema_code`.
        """
        prompt = deduced_schema_json_instruction.format(examples=example_wrapper(json_schema_examples), instruction=instruction, distilled_text=distilled_text, text=text)
        response = self.llm.get_chat_response(prompt)
        response = extract_json_dict(response)
        code = response
        print(f"Deduced Schema in Json: \n{response}\n\n")
        return code, response

    def get_deduced_schema_code(self, instruction: str, text: str, distilled_text: str):
        """Deduce an output schema as generated pydantic code.

        Executes the last fenced code block from the LLM response and looks
        for a class named ``ExtractionTarget``. Falls back to
        :meth:`get_deduced_schema_json` when no usable code is produced.

        Returns:
            tuple: ``(code_text, serialized_schema)``.
        """
        prompt = deduced_schema_code_instruction.format(examples=example_wrapper(code_schema_examples), instruction=instruction, distilled_text=distilled_text, text=text)
        response = self.llm.get_chat_response(prompt)
        code_blocks = re.findall(r'```[^\n]*\n(.*?)\n```', response, re.DOTALL)
        if code_blocks:
            try:
                code_block = code_blocks[-1]
                namespace = {}
                # SECURITY: exec() of LLM-generated code is inherently
                # unsafe; acceptable only because the model output is
                # treated as trusted here. Do not expose to untrusted input.
                exec(code_block, namespace)
                schema = namespace.get('ExtractionTarget')
                if schema is not None:
                    # Strip any leading imports so only the class source is shown.
                    index = code_block.find("class")
                    code = code_block[index:]
                    print(f"Deduced Schema in Code: \n{code}\n\n")
                    schema = self.serialize_schema(schema)
                    return code, schema
            except Exception as e:
                print(e)
                return self.get_deduced_schema_json(instruction, text, distilled_text)
        return self.get_deduced_schema_json(instruction, text, distilled_text)

class SchemaAgent:
    """Agent that attaches an output schema to a :class:`DataPoint`.

    Offers three strategies (see ``self.methods``): the configured default
    schema, a schema retrieved by name from the schema repository, or a
    schema deduced from the text by :class:`SchemaAnalyzer`.
    """

    def __init__(self, llm: BaseEngine):
        self.llm = llm
        self.module = SchemaAnalyzer(llm = llm)
        self.schema_repo = schema_repository
        self.methods = ["get_default_schema", "get_retrieved_schema", "get_deduced_schema"]

    def __preprocess_text(self, data: DataPoint):
        """Chunk the input text/file and set a task-specific display schema.

        Mutates *data* in place and also returns it for call-site chaining.
        """
        if data.use_file:
            data.chunk_text_list = chunk_file(data.file_path)
        else:
            data.chunk_text_list = chunk_str(data.text)
        # print_schema is display-only; the literals below are emitted
        # verbatim and must not be reformatted.
        if data.task == "NER":
            data.print_schema = """
class Entity(BaseModel):
    name : str = Field(description="The specific name of the entity. ")
    type : str = Field(description="The type or category that the entity belongs to.")
class EntityList(BaseModel):
    entity_list : List[Entity] = Field(description="Named entities appearing in the text.")
            """
        elif data.task == "RE":
            data.print_schema = """
class Relation(BaseModel):
    head : str = Field(description="The starting entity in the relationship.")
    tail : str = Field(description="The ending entity in the relationship.")
    relation : str = Field(description="The predicate that defines the relationship between the two entities.")

class RelationList(BaseModel):
    relation_list : List[Relation] = Field(description="The collection of relationships between various entities.")
            """
        elif data.task == "EE":
            data.print_schema = """
class Event(BaseModel):
    event_type : str = Field(description="The type of the event.")
    event_trigger : str = Field(description="A specific word or phrase that indicates the occurrence of the event.")
    event_argument : dict = Field(description="The arguments or participants involved in the event.")

class EventList(BaseModel):
    event_list : List[Event] = Field(description="The events presented in the text.")
            """
        elif data.task == "Triple":
            data.print_schema = """
class Triple(BaseModel):
    head: str = Field(description="The subject or head of the triple.")
    head_type: str = Field(description="The type of the subject entity.")
    relation: str = Field(description="The predicate or relation between the entities.")
    relation_type: str = Field(description="The type of the relation.")
    tail: str = Field(description="The object or tail of the triple.")
    tail_type: str = Field(description="The type of the object entity.")
class TripleList(BaseModel):
    triple_list: List[Triple] = Field(description="The collection of triples and their types presented in the text.")
"""
        return data

    def get_default_schema(self, data: DataPoint):
        """Attach the configured default schema to *data*."""
        data = self.__preprocess_text(data)
        default_schema = config['agent']['default_schema']
        data.set_schema(default_schema)
        function_name = current_function_name()
        data.update_trajectory(function_name, default_schema)
        return data

    def get_retrieved_schema(self, data: DataPoint):
        """Attach a named schema from the repository; fall back to default.

        ``data.output_schema`` names the schema class; an unknown name
        delegates to :meth:`get_default_schema`.
        """
        data = self.__preprocess_text(data)
        schema_name = data.output_schema
        schema_class = getattr(self.schema_repo, schema_name, None)
        if schema_class is None:
            return self.get_default_schema(data)
        schema = self.module.serialize_schema(schema_class)
        default_schema = config['agent']['default_schema']
        data.set_schema(f"{default_schema}\n{schema}")
        function_name = current_function_name()
        data.update_trajectory(function_name, schema)
        return data

    def get_deduced_schema(self, data: DataPoint):
        """Deduce and attach a schema from the first chunk of the text."""
        data = self.__preprocess_text(data)
        target_text = data.chunk_text_list[0]
        if len(data.chunk_text_list) > 1:
            # Multi-chunk input: the original code ran the (costly) LLM text
            # analysis and then discarded its result here; skip it entirely
            # and describe the excerpt directly — identical output.
            prefix = "Below is a portion of the text to be extracted. "
            analysed_text = f"{prefix}\n{target_text}"
        else:
            analysed_text = self.module.get_text_analysis(target_text)
        distilled_text = self.module.redefine_text(analysed_text)
        code, deduced_schema = self.module.get_deduced_schema_code(data.instruction, target_text, distilled_text)
        data.print_schema = code
        data.set_distilled_text(distilled_text)
        default_schema = config['agent']['default_schema']
        data.set_schema(f"{default_schema}\n{deduced_schema}")
        function_name = current_function_name()
        data.update_trajectory(function_name, deduced_schema)
        return data