File size: 7,767 Bytes
009d93e e6e7506 009d93e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
from models import *
from utils import *
from .knowledge_base import schema_repository
from langchain_core.output_parsers import JsonOutputParser
class SchemaAnalyzer:
    """Schema-construction helper around an LLM engine.

    Serializes pydantic schema classes into prompt-ready format
    instructions, analyses raw input text, and deduces extraction
    schemas (as Python code or as JSON) from an instruction plus a
    sample of the text.
    """

    def __init__(self, llm: BaseEngine):
        # Engine used for every chat completion issued below.
        self.llm = llm

    def serialize_schema(self, schema) -> str:
        """Render a pydantic schema class as format instructions for a prompt.

        Plain strings/containers are passed through untouched. Any failure
        while rendering falls back to returning the input unchanged
        (best-effort; this method never raises).
        """
        if isinstance(schema, (str, list, dict, set, tuple)):
            return schema
        try:
            parser = JsonOutputParser(pydantic_object=schema)
            schema_description = parser.get_format_instructions()
            # Extract the fenced block holding the JSON schema itself.
            matches = re.findall(r'```(.*?)```', schema_description, re.DOTALL)
            # BUG FIX: the original interpolated the whole findall() *list*,
            # embedding a Python list repr (brackets, quotes, escapes) into
            # the prompt. Use the first fenced block's text instead and fall
            # back to the full description when no fence is present.
            schema_content = matches[0].strip() if matches else schema_description
            explanation = "For example, for the schema {\"properties\": {\"foo\": {\"title\": \"Foo\", \"description\": \"a list of strings\", \"type\": \"array\", \"items\": {\"type\": \"string\"}}}}, the object {\"foo\": [\"bar\", \"baz\"]} is a well-formatted instance."
            schema = f"{schema_content}\n\n{explanation}"
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
            # still propagate; keeps the original best-effort fallback.
            return schema
        return schema

    def redefine_text(self, text_analysis):
        """Collapse a text-analysis dict into a one-sentence description.

        Returns the input unchanged when it is not subscriptable or lacks
        the 'field'/'genre' keys (e.g. when it is already a plain string).
        """
        try:
            field = text_analysis['field']
            genre = text_analysis['genre']
        except (TypeError, KeyError):
            # Not an analysis dict (or missing keys) -- pass through as-is.
            return text_analysis
        prompt = f"This text is from the field of {field} and represents the genre of {genre}."
        return prompt

    def get_text_analysis(self, text: str):
        """Ask the LLM to classify `text`; returns a one-line description
        (or the raw parsed dict when it lacks the expected keys)."""
        output_schema = self.serialize_schema(schema_repository.TextDescription)
        prompt = text_analysis_instruction.format(examples="", text=text, schema=output_schema)
        response = self.llm.get_chat_response(prompt)
        response = extract_json_dict(response)
        response = self.redefine_text(response)
        return response

    def get_deduced_schema_json(self, instruction: str, text: str, distilled_text: str):
        """Deduce an extraction schema as a JSON dict.

        Returns a (code, schema) pair to mirror get_deduced_schema_code();
        on this fallback path both members are the same JSON object.
        """
        prompt = deduced_schema_json_instruction.format(examples=example_wrapper(json_schema_examples), instruction=instruction, distilled_text=distilled_text, text=text)
        response = self.llm.get_chat_response(prompt)
        response = extract_json_dict(response)
        print(f"Deduced Schema in Json: \n{response}\n\n")
        # The original aliased `code = response`; both tuple slots carry the
        # same object, so return it directly.
        return response, response

    def get_deduced_schema_code(self, instruction: str, text: str, distilled_text: str):
        """Deduce an extraction schema as pydantic code; returns (code, schema).

        Falls back to get_deduced_schema_json() when the response contains
        no code block, the code fails to execute, or it does not define an
        `ExtractionTarget` class.
        """
        prompt = deduced_schema_code_instruction.format(examples=example_wrapper(code_schema_examples), instruction=instruction, distilled_text=distilled_text, text=text)
        response = self.llm.get_chat_response(prompt)
        code_blocks = re.findall(r'```[^\n]*\n(.*?)\n```', response, re.DOTALL)
        if code_blocks:
            try:
                # Use the last code block -- models often emit scratch work first.
                code_block = code_blocks[-1]
                namespace = {}
                # SECURITY: exec() runs LLM-generated code verbatim. This is
                # acceptable only in a trusted-operator setting; do not expose
                # this path to untrusted end users.
                exec(code_block, namespace)
                schema = namespace.get('ExtractionTarget')
                if schema is not None:
                    # Strip leading imports so only the class source is shown.
                    index = code_block.find("class")
                    code = code_block[index:]
                    print(f"Deduced Schema in Code: \n{code}\n\n")
                    schema = self.serialize_schema(schema)
                    return code, schema
            except Exception as e:
                print(e)
                return self.get_deduced_schema_json(instruction, text, distilled_text)
        return self.get_deduced_schema_json(instruction, text, distilled_text)
class SchemaAgent:
    """Agent that attaches an output schema to a DataPoint.

    Exposes three strategies (listed in `self.methods`): use the
    configured default schema, retrieve a named schema from the schema
    repository, or deduce a schema from the text via SchemaAnalyzer.
    """

    def __init__(self, llm: BaseEngine):
        self.llm = llm
        self.module = SchemaAnalyzer(llm = llm)
        self.schema_repo = schema_repository
        # Public strategy names, presumably used by a caller to route
        # requests to one of the methods below -- TODO confirm against caller.
        self.methods = ["get_default_schema", "get_retrieved_schema", "get_deduced_schema"]

    def __preprocess_text(self, data: DataPoint):
        """Chunk the input (file or raw text) into data.chunk_text_list and,
        for the built-in tasks, set a display-only schema string.

        Mutates `data` in place and also returns it.
        """
        if data.use_file:
            data.chunk_text_list = chunk_file(data.file_path)
        else:
            data.chunk_text_list = chunk_str(data.text)
        # The print_schema strings below are display text only; they are
        # never executed as code by this class.
        if data.task == "NER":
            data.print_schema = """
class Entity(BaseModel):
name : str = Field(description="The specific name of the entity. ")
type : str = Field(description="The type or category that the entity belongs to.")
class EntityList(BaseModel):
entity_list : List[Entity] = Field(description="Named entities appearing in the text.")
"""
        elif data.task == "RE":
            data.print_schema = """
class Relation(BaseModel):
head : str = Field(description="The starting entity in the relationship.")
tail : str = Field(description="The ending entity in the relationship.")
relation : str = Field(description="The predicate that defines the relationship between the two entities.")
class RelationList(BaseModel):
relation_list : List[Relation] = Field(description="The collection of relationships between various entities.")
"""
        elif data.task == "EE":
            data.print_schema = """
class Event(BaseModel):
event_type : str = Field(description="The type of the event.")
event_trigger : str = Field(description="A specific word or phrase that indicates the occurrence of the event.")
event_argument : dict = Field(description="The arguments or participants involved in the event.")
class EventList(BaseModel):
event_list : List[Event] = Field(description="The events presented in the text.")
"""
        elif data.task == "Triple":
            data.print_schema = """
class Triple(BaseModel):
head: str = Field(description="The subject or head of the triple.")
head_type: str = Field(description="The type of the subject entity.")
relation: str = Field(description="The predicate or relation between the entities.")
relation_type: str = Field(description="The type of the relation.")
tail: str = Field(description="The object or tail of the triple.")
tail_type: str = Field(description="The type of the object entity.")
class TripleList(BaseModel):
triple_list: List[Triple] = Field(description="The collection of triples and their types presented in the text.")
"""
        return data

    def get_default_schema(self, data: DataPoint):
        """Attach the default schema from config and record the step."""
        data = self.__preprocess_text(data)
        default_schema = config['agent']['default_schema']
        data.set_schema(default_schema)
        function_name = current_function_name()
        data.update_trajectory(function_name, default_schema)
        return data

    def get_retrieved_schema(self, data: DataPoint):
        """Look up data.output_schema by name in the schema repository.

        Falls back to get_default_schema() when no class of that name
        exists in the repository.
        """
        self.__preprocess_text(data)
        schema_name = data.output_schema
        schema_class = getattr(self.schema_repo, schema_name, None)
        if schema_class is not None:
            schema = self.module.serialize_schema(schema_class)
            default_schema = config['agent']['default_schema']
            # The retrieved schema is appended after the default one.
            data.set_schema(f"{default_schema}\n{schema}")
            function_name = current_function_name()
            data.update_trajectory(function_name, schema)
        else:
            return self.get_default_schema(data)
        return data

    def get_deduced_schema(self, data: DataPoint):
        """Deduce a schema from the first text chunk via the LLM."""
        self.__preprocess_text(data)
        target_text = data.chunk_text_list[0]
        analysed_text = self.module.get_text_analysis(target_text)
        if len(data.chunk_text_list) > 1:
            # NOTE(review): with multiple chunks this overwrites the LLM
            # analysis with the raw first chunk plus a prefix, discarding
            # get_text_analysis()'s result -- confirm this is intended.
            prefix = "Below is a portion of the text to be extracted. "
            analysed_text = f"{prefix}\n{target_text}"
        # redefine_text() passes plain strings through unchanged, so this
        # only reshapes analysed_text when it is still an analysis dict.
        distilled_text = self.module.redefine_text(analysed_text)
        code, deduced_schema = self.module.get_deduced_schema_code(data.instruction, target_text, distilled_text)
        data.print_schema = code
        data.set_distilled_text(distilled_text)
        default_schema = config['agent']['default_schema']
        data.set_schema(f"{default_schema}\n{deduced_schema}")
        function_name = current_function_name()
        data.update_trajectory(function_name, deduced_schema)
        return data
|