|
from models import * |
|
from utils import * |
|
from .knowledge_base import schema_repository |
|
from langchain_core.output_parsers import JsonOutputParser |
|
|
|
class SchemaAnalyzer:
    """LLM-backed helper that serializes schemas and deduces them from text.

    All model access goes through ``self.llm.get_chat_response(prompt)``.
    Prompt templates and parsing helpers (``text_analysis_instruction``,
    ``extract_json_dict``, ``example_wrapper``, ...) are expected to be in
    module scope; presumably they arrive via the file's star imports.
    """

    def __init__(self, llm: "BaseEngine"):
        """Store the chat engine used for every prompt sent by this object."""
        self.llm = llm

    def serialize_schema(self, schema):
        """Turn a schema object into prompt-ready text where possible.

        Args:
            schema: Either an already-usable value (``str``, ``list``,
                ``dict``, ``set``, ``tuple``) or an object accepted by
                ``JsonOutputParser(pydantic_object=...)``.

        Returns:
            The input unchanged when it is one of the built-in types above or
            when serialization fails; otherwise a string made of the fenced
            part(s) of the parser's format instructions followed by a fixed
            worked example.
        """
        if isinstance(schema, (str, list, dict, set, tuple)):
            return schema
        try:
            parser = JsonOutputParser(pydantic_object=schema)
            schema_description = parser.get_format_instructions()
            # NOTE: ``re.findall`` yields a list, so the f-string below embeds
            # the list's repr. Kept as-is so generated prompts do not change.
            schema_content = re.findall(r'```(.*?)```', schema_description, re.DOTALL)
            explanation = "For example, for the schema {\"properties\": {\"foo\": {\"title\": \"Foo\", \"description\": \"a list of strings\", \"type\": \"array\", \"items\": {\"type\": \"string\"}}}}, the object {\"foo\": [\"bar\", \"baz\"]} is a well-formatted instance."
            return f"{schema_content}\n\n{explanation}"
        except Exception:
            # Best effort: hand the raw schema back rather than failing, but
            # let KeyboardInterrupt / SystemExit propagate.
            return schema

    def redefine_text(self, text_analysis):
        """Condense a text-analysis mapping into a one-sentence description.

        Args:
            text_analysis: Ideally a mapping with ``'field'`` and ``'genre'``
                keys. Anything else (a plain string, ``None``, a mapping
                missing a key) is tolerated.

        Returns:
            The descriptive sentence when both keys can be read, otherwise
            ``text_analysis`` unchanged.
        """
        try:
            field = text_analysis['field']
            genre = text_analysis['genre']
        except (LookupError, TypeError):
            # Missing key, or a value that cannot be indexed by a string key.
            return text_analysis
        return f"This text is from the field of {field} and represents the genre of {genre}."

    def get_text_analysis(self, text: str):
        """Ask the LLM to describe ``text`` and condense the answer.

        Returns:
            The result of ``redefine_text`` applied to the JSON extracted from
            the model response.
        """
        output_schema = self.serialize_schema(schema_repository.TextDescription)
        prompt = text_analysis_instruction.format(examples="", text=text, schema=output_schema)
        response = self.llm.get_chat_response(prompt)
        response = extract_json_dict(response)
        return self.redefine_text(response)

    def get_deduced_schema_json(self, instruction: str, text: str, distilled_text: str):
        """Ask the LLM for a schema expressed as JSON.

        Returns:
            A ``(code, schema)`` pair in which both items are the same value:
            the output of ``extract_json_dict`` on the model response.
        """
        prompt = deduced_schema_json_instruction.format(
            examples=example_wrapper(json_schema_examples),
            instruction=instruction,
            distilled_text=distilled_text,
            text=text,
        )
        response = self.llm.get_chat_response(prompt)
        response = extract_json_dict(response)
        print(f"Deduced Schema in Json: \n{response}\n\n")
        return response, response

    @staticmethod
    def _schema_from_code_response(response):
        """Run the last fenced code block of ``response`` and fetch its schema.

        Returns:
            ``(code, schema)`` where ``schema`` is the ``ExtractionTarget``
            object defined by the block and ``code`` is the block text from
            its first ``"class"`` occurrence onward; ``None`` when there is no
            fenced block or the block defines no ``ExtractionTarget``.

        Raises:
            Whatever the executed block raises.
        """
        code_blocks = re.findall(r'```[^\n]*\n(.*?)\n```', response, re.DOTALL)
        if not code_blocks:
            return None
        code_block = code_blocks[-1]
        namespace = {}
        # SECURITY: this executes model-generated code in-process with no
        # sandbox. Only acceptable when the model and its inputs are trusted.
        exec(code_block, namespace)
        schema = namespace.get('ExtractionTarget')
        if schema is None:
            return None
        index = code_block.find("class")
        # ``find`` returns -1 when "class" is absent; slicing from -1 would
        # keep only the final character, so fall back to the whole block.
        code = code_block[index:] if index >= 0 else code_block
        return code, schema

    def get_deduced_schema_code(self, instruction: str, text: str, distilled_text: str):
        """Ask the LLM for a schema expressed as Python code.

        The last fenced code block of the response is executed and must define
        ``ExtractionTarget``. If it does not, or if anything in that step
        raises, the JSON-based deduction is used instead.

        Returns:
            ``(code, schema)``: the displayed source and the serialized
            schema, or the result of ``get_deduced_schema_json`` on fallback.
        """
        prompt = deduced_schema_code_instruction.format(
            examples=example_wrapper(code_schema_examples),
            instruction=instruction,
            distilled_text=distilled_text,
            text=text,
        )
        response = self.llm.get_chat_response(prompt)
        try:
            deduced = self._schema_from_code_response(response)
            if deduced is not None:
                code, schema = deduced
                print(f"Deduced Schema in Code: \n{code}\n\n")
                return code, self.serialize_schema(schema)
        except Exception as e:
            # Generated code is unreliable; report and use the JSON route.
            print(e)
        return self.get_deduced_schema_json(instruction, text, distilled_text)
|
|
|
class SchemaAgent:
    """Chooses and attaches an output schema to a data point.

    Three strategies are exposed, named in ``self.methods``: a configured
    default schema, a schema looked up by name in the schema repository, and a
    schema deduced by the LLM from the text itself. Module-level helpers such
    as ``config``, ``chunk_file``, ``chunk_str`` and ``current_function_name``
    are expected to be in scope (presumably via the file's star imports).
    """

    # Display-only schema source shown for the built-in tasks, keyed by
    # ``data.task``. Tasks not listed here leave ``data.print_schema`` alone.
    _TASK_PRINT_SCHEMAS = {
        "NER": (
            '\n'
            'class Entity(BaseModel):\n'
            '    name : str = Field(description="The specific name of the entity. ")\n'
            '    type : str = Field(description="The type or category that the entity belongs to.")\n'
            'class EntityList(BaseModel):\n'
            '    entity_list : List[Entity] = Field(description="Named entities appearing in the text.")\n'
        ),
        "RE": (
            '\n'
            'class Relation(BaseModel):\n'
            '    head : str = Field(description="The starting entity in the relationship.")\n'
            '    tail : str = Field(description="The ending entity in the relationship.")\n'
            '    relation : str = Field(description="The predicate that defines the relationship between the two entities.")\n'
            '\n'
            'class RelationList(BaseModel):\n'
            '    relation_list : List[Relation] = Field(description="The collection of relationships between various entities.")\n'
        ),
        "EE": (
            '\n'
            'class Event(BaseModel):\n'
            '    event_type : str = Field(description="The type of the event.")\n'
            '    event_trigger : str = Field(description="A specific word or phrase that indicates the occurrence of the event.")\n'
            '    event_argument : dict = Field(description="The arguments or participants involved in the event.")\n'
            '\n'
            'class EventList(BaseModel):\n'
            '    event_list : List[Event] = Field(description="The events presented in the text.")\n'
        ),
        "Triple": (
            '\n'
            'class Triple(BaseModel):\n'
            '    head: str = Field(description="The subject or head of the triple.")\n'
            '    head_type: str = Field(description="The type of the subject entity.")\n'
            '    relation: str = Field(description="The predicate or relation between the entities.")\n'
            '    relation_type: str = Field(description="The type of the relation.")\n'
            '    tail: str = Field(description="The object or tail of the triple.")\n'
            '    tail_type: str = Field(description="The type of the object entity.")\n'
            'class TripleList(BaseModel):\n'
            '    triple_list: List[Triple] = Field(description="The collection of triples and their types presented in the text.")\n'
        ),
    }

    def __init__(self, llm: "BaseEngine"):
        """Keep the engine, build the analyzer on it and bind the repository."""
        self.llm = llm
        self.module = SchemaAnalyzer(llm=llm)
        self.schema_repo = schema_repository
        self.methods = ["get_default_schema", "get_retrieved_schema", "get_deduced_schema"]

    def __preprocess_text(self, data: "DataPoint"):
        """Chunk the input and set the display schema for built-in tasks.

        Fills ``data.chunk_text_list`` from ``data.file_path`` when
        ``data.use_file`` is truthy, otherwise from ``data.text``. When
        ``data.task`` is one of ``"NER"``, ``"RE"``, ``"EE"`` or ``"Triple"``,
        ``data.print_schema`` is set to the matching template.

        Returns:
            The same ``data`` object, mutated in place.
        """
        if data.use_file:
            data.chunk_text_list = chunk_file(data.file_path)
        else:
            data.chunk_text_list = chunk_str(data.text)
        print_schema = self._TASK_PRINT_SCHEMAS.get(data.task)
        if print_schema is not None:
            data.print_schema = print_schema
        return data

    def get_default_schema(self, data: "DataPoint"):
        """Attach the configured default schema to ``data`` and return it."""
        data = self.__preprocess_text(data)
        default_schema = config['agent']['default_schema']
        data.set_schema(default_schema)
        # Must be called directly in this method: the helper presumably
        # reports the name of its caller for the trajectory log.
        function_name = current_function_name()
        data.update_trajectory(function_name, default_schema)
        return data

    def get_retrieved_schema(self, data: "DataPoint"):
        """Attach the repository schema named by ``data.output_schema``.

        Falls back to ``get_default_schema`` when the name is not a string or
        the repository has no attribute of that name.

        Returns:
            The same ``data`` object, mutated in place.
        """
        schema_name = data.output_schema
        schema_class = None
        # ``getattr`` raises TypeError for a non-string name; treat that the
        # same as "not found".
        if isinstance(schema_name, str):
            schema_class = getattr(self.schema_repo, schema_name, None)
        if schema_class is None:
            # The default path preprocesses on its own, so doing it here first
            # would chunk the input twice.
            return self.get_default_schema(data)

        self.__preprocess_text(data)
        schema = self.module.serialize_schema(schema_class)
        default_schema = config['agent']['default_schema']
        data.set_schema(f"{default_schema}\n{schema}")
        function_name = current_function_name()
        data.update_trajectory(function_name, schema)
        return data

    def get_deduced_schema(self, data: "DataPoint"):
        """Let the LLM deduce a schema from the first chunk of the input.

        With a single chunk, the text is first analysed by the LLM and the
        analysis becomes the distilled text. With several chunks, the
        distilled text is a fixed prefix followed by the first chunk.

        Returns:
            The same ``data`` object, mutated in place.

        Raises:
            IndexError: If preprocessing produced an empty chunk list
                (assuming it is a list; nothing here guards against that).
        """
        self.__preprocess_text(data)
        target_text = data.chunk_text_list[0]
        if len(data.chunk_text_list) > 1:
            # The analysis result was never used on this path, so the LLM
            # round-trip is skipped; the resulting text is the same as before.
            prefix = "Below is a portion of the text to be extracted. "
            analysed_text = f"{prefix}\n{target_text}"
        else:
            analysed_text = self.module.get_text_analysis(target_text)
        distilled_text = self.module.redefine_text(analysed_text)
        code, deduced_schema = self.module.get_deduced_schema_code(data.instruction, target_text, distilled_text)
        data.print_schema = code
        data.set_distilled_text(distilled_text)
        default_schema = config['agent']['default_schema']
        data.set_schema(f"{default_schema}\n{deduced_schema}")
        function_name = current_function_name()
        data.update_trajectory(function_name, deduced_schema)
        return data
|
|