update
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- src/__pycache__/pipeline.cpython-311.pyc +0 -0
- src/__pycache__/pipeline.cpython-39.pyc +0 -0
- src/config.yaml +19 -0
- src/generate_memory.py +181 -0
- src/main.py +233 -0
- src/models/__init__.py +3 -0
- src/models/__pycache__/__init__.cpython-311.pyc +0 -0
- src/models/__pycache__/__init__.cpython-37.pyc +0 -0
- src/models/__pycache__/__init__.cpython-39.pyc +0 -0
- src/models/__pycache__/llm_def.cpython-311.pyc +0 -0
- src/models/__pycache__/llm_def.cpython-37.pyc +0 -0
- src/models/__pycache__/llm_def.cpython-39.pyc +0 -0
- src/models/__pycache__/prompt_example.cpython-311.pyc +0 -0
- src/models/__pycache__/prompt_example.cpython-39.pyc +0 -0
- src/models/__pycache__/prompt_template.cpython-311.pyc +0 -0
- src/models/__pycache__/prompt_template.cpython-39.pyc +0 -0
- src/models/llm_def.py +212 -0
- src/models/prompt_example.py +137 -0
- src/models/prompt_template.py +174 -0
- src/modules/__init__.py +4 -0
- src/modules/__pycache__/__init__.cpython-311.pyc +0 -0
- src/modules/__pycache__/__init__.cpython-39.pyc +0 -0
- src/modules/__pycache__/extraction_agent.cpython-311.pyc +0 -0
- src/modules/__pycache__/extraction_agent.cpython-39.pyc +0 -0
- src/modules/__pycache__/reflection_agent.cpython-311.pyc +0 -0
- src/modules/__pycache__/reflection_agent.cpython-39.pyc +0 -0
- src/modules/__pycache__/schema_agent.cpython-311.pyc +0 -0
- src/modules/__pycache__/schema_agent.cpython-39.pyc +0 -0
- src/modules/extraction_agent.py +85 -0
- src/modules/knowledge_base/__pycache__/case_repository.cpython-311.pyc +0 -0
- src/modules/knowledge_base/__pycache__/case_repository.cpython-39.pyc +0 -0
- src/modules/knowledge_base/__pycache__/schema_repository.cpython-311.pyc +0 -0
- src/modules/knowledge_base/__pycache__/schema_repository.cpython-39.pyc +0 -0
- src/modules/knowledge_base/case_repository.json +0 -0
- src/modules/knowledge_base/case_repository.py +391 -0
- src/modules/knowledge_base/schema_repository.py +91 -0
- src/modules/reflection_agent.py +74 -0
- src/modules/schema_agent.py +151 -0
- src/pipeline.py +98 -0
- src/run.py +88 -0
- src/utils/__init__.py +3 -0
- src/utils/__pycache__/__init__.cpython-311.pyc +0 -0
- src/utils/__pycache__/__init__.cpython-39.pyc +0 -0
- src/utils/__pycache__/data_def.cpython-311.pyc +0 -0
- src/utils/__pycache__/data_def.cpython-39.pyc +0 -0
- src/utils/__pycache__/process.cpython-311.pyc +0 -0
- src/utils/__pycache__/process.cpython-39.pyc +0 -0
- src/utils/data_def.py +59 -0
- src/utils/process.py +183 -0
- src/webui/__init__.py +1 -0
src/__pycache__/pipeline.cpython-311.pyc
ADDED
Binary file (5.34 kB). View file
|
|
src/__pycache__/pipeline.cpython-39.pyc
ADDED
Binary file (3.56 kB). View file
|
|
src/config.yaml
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
agent:
|
2 |
+
default_schema: The final extraction result should be formatted as a JSON object.
|
3 |
+
default_ner: Extract the Named Entities in the given text.
|
4 |
+
default_re: Extract Relationships between Named Entities in the given text.
|
5 |
+
default_ee: Extract the Events in the given text.
|
6 |
+
chunk_token_limit: 1024
|
7 |
+
mode:
|
8 |
+
quick:
|
9 |
+
schema_agent: get_deduced_schema
|
10 |
+
extraction_agent: extract_information_direct
|
11 |
+
standard:
|
12 |
+
schema_agent: get_deduced_schema
|
13 |
+
extraction_agent: extract_information_with_case
|
14 |
+
reflection_agent: reflect_with_case
|
15 |
+
customized:
|
16 |
+
schema_agent: get_retrieved_schema
|
17 |
+
extraction_agent: extract_information_direct
|
18 |
+
|
19 |
+
|
src/generate_memory.py
ADDED
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Literal
|
2 |
+
from models import *
|
3 |
+
from utils import *
|
4 |
+
from modules import *
|
5 |
+
|
6 |
+
|
7 |
+
class Pipeline:
|
8 |
+
def __init__(self, llm: BaseEngine):
|
9 |
+
self.llm = llm
|
10 |
+
self.case_repo = CaseRepositoryHandler(llm = llm)
|
11 |
+
self.schema_agent = SchemaAgent(llm = llm)
|
12 |
+
self.extraction_agent = ExtractionAgent(llm = llm, case_repo = self.case_repo)
|
13 |
+
self.reflection_agent = ReflectionAgent(llm = llm, case_repo = self.case_repo)
|
14 |
+
|
15 |
+
def __init_method(self, data: DataPoint, process_method):
|
16 |
+
default_order = ["schema_agent", "extraction_agent", "reflection_agent"]
|
17 |
+
if "schema_agent" not in process_method:
|
18 |
+
process_method["schema_agent"] = "get_default_schema"
|
19 |
+
if data.task != "Base":
|
20 |
+
process_method["schema_agent"] = "get_retrieved_schema"
|
21 |
+
if "extraction_agent" not in process_method:
|
22 |
+
process_method["extraction_agent"] = "extract_information_direct"
|
23 |
+
sorted_process_method = {key: process_method[key] for key in default_order if key in process_method}
|
24 |
+
return sorted_process_method
|
25 |
+
|
26 |
+
def __init_data(self, data: DataPoint):
|
27 |
+
if data.task == "NER":
|
28 |
+
data.instruction = config['agent']['default_ner']
|
29 |
+
data.output_schema = "EntityList"
|
30 |
+
elif data.task == "RE":
|
31 |
+
data.instruction = config['agent']['default_re']
|
32 |
+
data.output_schema = "RelationList"
|
33 |
+
elif data.task == "EE":
|
34 |
+
data.instruction = config['agent']['default_ee']
|
35 |
+
data.output_schema = "EventList"
|
36 |
+
return data
|
37 |
+
|
38 |
+
# main entry
|
39 |
+
def get_extract_result(self,
|
40 |
+
task: TaskType,
|
41 |
+
instruction: str = "",
|
42 |
+
text: str = "",
|
43 |
+
output_schema: str = "",
|
44 |
+
constraint: str = "",
|
45 |
+
use_file: bool = False,
|
46 |
+
truth: str = "",
|
47 |
+
mode: str = "quick",
|
48 |
+
update_case: bool = False
|
49 |
+
):
|
50 |
+
|
51 |
+
data = DataPoint(task=task, instruction=instruction, text=text, output_schema=output_schema, constraint=constraint, use_file=use_file, truth=truth)
|
52 |
+
data = self.__init_data(data)
|
53 |
+
data.instruction = "In the tranquil seaside town, the summer evening cast a golden glow over everything. The townsfolk gathered at the café by the pier, enjoying the sea breeze while eagerly anticipating the annual Ocean Festival's opening ceremony. \nFirst to arrive was Mayor William, dressed in a deep blue suit, holding a roll of his speech. He smiled and greeted the residents, who held deep respect for their community-minded mayor. Beside him trotted Max, his loyal golden retriever, wagging his tail excitedly at every familiar face he saw. \nFollowing closely was Emily, the town’s high school teacher, accompanied by a group of students ready to perform a musical piece they'd rehearsed. One of the girls carried Polly, a vibrant green parrot, on her shoulder. Polly occasionally chimed in with cheerful squawks, adding to the lively atmosphere. \nNot far away, Captain Jack, with his trusty pipe in hand, chatted with old friends about this year's catch. His fleet was the town’s economic backbone, and his seasoned face and towering presence were complemented by the presence of Whiskers, his orange tabby cat, who loved lounging on the dock, attentively watching the gentle waves. \nInside the café, Kate was bustling about, serving guests. As the owner, with her fiery red curls and vivacious spirit, she was the heart of the place. Her friend Susan, an artist living in a tiny cottage nearby, was helping her prepare refreshing beverages. Slinky, Susan's mischievous ferret, darted playfully between the tables, much to the delight of the children present. \nLeaning on the café's railing, a young boy named Tommy watched the sea with wide, gleaming eyes, filled with dreams of the future. By his side sat Daisy, a spirited little dachshund, barking excitedly at the seagulls flying overhead. Tommy's mother, Lucy, stood beside him, smiling softly as she held a seashell he had just found on the beach. \nAmong the crowd, a group of unnamed tourists snapped photos, capturing memories of the charming festival. Street vendors called out, selling their wares—handmade jewelry and sweet confections—as the scent of grilled seafood wafted through the air. \nSuddenly, a burst of laughter erupted—it was James and his band making their grand entrance. Accompanying them was Benny, a friendly border collie who \"performed\" with the band, delighting the crowd with his antics. Set to play a big concert after the opening ceremony, James, the town's star musician, had won the hearts of locals with his soulful tunes. \nAs dusk settled, lights were strung across the streets, casting a magical glow over the town. Mayor William took the stage to deliver his speech, with Max sitting proudly by his side. The festival atmosphere reached its vibrant peak, and in this small town, each person—and animal—carried their own dreams and stories, yet at this moment, they were united by the shared celebration."
|
54 |
+
data.chunk_text_list.append("In the tranquil seaside town, the summer evening cast a golden glow over everything. The townsfolk gathered at the café by the pier, enjoying the sea breeze while eagerly anticipating the annual Ocean Festival's opening ceremony. \nFirst to arrive was Mayor William, dressed in a deep blue suit, holding a roll of his speech. He smiled and greeted the residents, who held deep respect for their community-minded mayor. Beside him trotted Max, his loyal golden retriever, wagging his tail excitedly at every familiar face he saw. \nFollowing closely was Emily, the town’s high school teacher, accompanied by a group of students ready to perform a musical piece they'd rehearsed. One of the girls carried Polly, a vibrant green parrot, on her shoulder. Polly occasionally chimed in with cheerful squawks, adding to the lively atmosphere. \nNot far away, Captain Jack, with his trusty pipe in hand, chatted with old friends about this year's catch. His fleet was the town’s economic backbone, and his seasoned face and towering presence were complemented by the presence of Whiskers, his orange tabby cat, who loved lounging on the dock, attentively watching the gentle waves. \nInside the café, Kate was bustling about, serving guests. As the owner, with her fiery red curls and vivacious spirit, she was the heart of the place. Her friend Susan, an artist living in a tiny cottage nearby, was helping her prepare refreshing beverages. Slinky, Susan's mischievous ferret, darted playfully between the tables, much to the delight of the children present. \nLeaning on the café's railing, a young boy named Tommy watched the sea with wide, gleaming eyes, filled with dreams of the future. By his side sat Daisy, a spirited little dachshund, barking excitedly at the seagulls flying overhead. Tommy's mother, Lucy, stood beside him, smiling softly as she held a seashell he had just found on the beach. \nAmong the crowd, a group of unnamed tourists snapped photos, capturing memories of the charming festival. Street vendors called out, selling their wares—handmade jewelry and sweet confections—as the scent of grilled seafood wafted through the air. \nSuddenly, a burst of laughter erupted—it was James and his band making their grand entrance. Accompanying them was Benny, a friendly border collie who \"performed\" with the band, delighting the crowd with his antics. Set to play a big concert after the opening ceremony, James, the town's star musician, had won the hearts of locals with his soulful tunes. \nAs dusk settled, lights were strung across the streets, casting a magical glow over the town. Mayor William took the stage to deliver his speech, with Max sitting proudly by his side. The festival atmosphere reached its vibrant peak, and in this small town, each person—and animal—carried their own dreams and stories, yet at this moment, they were united by the shared celebration.")
|
55 |
+
data.distilled_text = "This text is from the field of Slice of Life and represents the genre of Novel."
|
56 |
+
data.pred = {
|
57 |
+
"characters": [
|
58 |
+
{
|
59 |
+
"name": "Mayor William",
|
60 |
+
"role": "Mayor"
|
61 |
+
},
|
62 |
+
{
|
63 |
+
"name": "Max",
|
64 |
+
"role": "Golden Retriever, Mayor William's dog"
|
65 |
+
},
|
66 |
+
{
|
67 |
+
"name": "Emily",
|
68 |
+
"role": "High school teacher"
|
69 |
+
},
|
70 |
+
{
|
71 |
+
"name": "Polly",
|
72 |
+
"role": "Parrot, accompanying a student"
|
73 |
+
},
|
74 |
+
{
|
75 |
+
"name": "Captain Jack",
|
76 |
+
"role": "Captain"
|
77 |
+
},
|
78 |
+
{
|
79 |
+
"name": "Whiskers",
|
80 |
+
"role": "Orange tabby cat, Captain Jack's pet"
|
81 |
+
},
|
82 |
+
{
|
83 |
+
"name": "Kate",
|
84 |
+
"role": "Café owner"
|
85 |
+
},
|
86 |
+
{
|
87 |
+
"name": "Susan",
|
88 |
+
"role": "Artist, Kate's friend"
|
89 |
+
},
|
90 |
+
{
|
91 |
+
"name": "Slinky",
|
92 |
+
"role": "Ferret, Susan's pet"
|
93 |
+
},
|
94 |
+
{
|
95 |
+
"name": "Tommy",
|
96 |
+
"role": "Young boy"
|
97 |
+
},
|
98 |
+
{
|
99 |
+
"name": "Daisy",
|
100 |
+
"role": "Dachshund, Tommy's pet"
|
101 |
+
},
|
102 |
+
{
|
103 |
+
"name": "Lucy",
|
104 |
+
"role": "Tommy's mother"
|
105 |
+
},
|
106 |
+
{
|
107 |
+
"name": "James",
|
108 |
+
"role": "Musician, band leader"
|
109 |
+
},
|
110 |
+
{
|
111 |
+
"name": "Benny",
|
112 |
+
"role": "Border Collie, accompanying James and his band"
|
113 |
+
},
|
114 |
+
{
|
115 |
+
"name": "Unnamed Tourists",
|
116 |
+
"role": "Visitors at the festival"
|
117 |
+
},
|
118 |
+
{
|
119 |
+
"name": "Street Vendors",
|
120 |
+
"role": "Sellers at the festival"
|
121 |
+
}
|
122 |
+
]
|
123 |
+
}
|
124 |
+
|
125 |
+
data.truth = {
|
126 |
+
"characters": [
|
127 |
+
{
|
128 |
+
"name": "Mayor William",
|
129 |
+
"role": "The friendly and respected mayor of the seaside town."
|
130 |
+
},
|
131 |
+
{
|
132 |
+
"name": "Emily",
|
133 |
+
"role": "A high school teacher guiding students in a festival performance."
|
134 |
+
},
|
135 |
+
{
|
136 |
+
"name": "Captain Jack",
|
137 |
+
"role": "A seasoned sailor whose fleet supports the town."
|
138 |
+
},
|
139 |
+
{
|
140 |
+
"name": "Kate",
|
141 |
+
"role": "The welcoming owner of the local café."
|
142 |
+
},
|
143 |
+
{
|
144 |
+
"name": "Susan",
|
145 |
+
"role": "An artist known for her ocean-themed paintings."
|
146 |
+
},
|
147 |
+
{
|
148 |
+
"name": "Tommy",
|
149 |
+
"role": "A young boy with dreams of the sea."
|
150 |
+
},
|
151 |
+
{
|
152 |
+
"name": "Lucy",
|
153 |
+
"role": "Tommy's caring and supportive mother."
|
154 |
+
},
|
155 |
+
{
|
156 |
+
"name": "James",
|
157 |
+
"role": "A charismatic musician and band leader."
|
158 |
+
}
|
159 |
+
]
|
160 |
+
}
|
161 |
+
|
162 |
+
|
163 |
+
# Case Update
|
164 |
+
if update_case:
|
165 |
+
if (data.truth == ""):
|
166 |
+
truth = input("Please enter the correct answer you prefer, or press Enter to accept the current answer: ")
|
167 |
+
if truth.strip() == "":
|
168 |
+
data.truth = data.pred
|
169 |
+
else:
|
170 |
+
data.truth = extract_json_dict(truth)
|
171 |
+
self.case_repo.update_case(data)
|
172 |
+
|
173 |
+
# return result
|
174 |
+
result = data.pred
|
175 |
+
trajectory = data.get_result_trajectory()
|
176 |
+
|
177 |
+
return result, trajectory, "a", "b"
|
178 |
+
|
179 |
+
model = DeepSeek(model_name_or_path="deepseek-chat", api_key="")
|
180 |
+
pipeline = Pipeline(model)
|
181 |
+
result, trajectory, *_ = pipeline.get_extract_result(update_case=True, task="Base")
|
src/main.py
ADDED
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import json
|
3 |
+
import gradio as gr
|
4 |
+
|
5 |
+
from pipeline import Pipeline
|
6 |
+
from models import *
|
7 |
+
|
8 |
+
|
9 |
+
examples = [
|
10 |
+
{
|
11 |
+
"task": "NER",
|
12 |
+
"use_file": False,
|
13 |
+
"text": "Finally, every other year , ELRA organizes a major conference LREC , the International Language Resources and Evaluation Conference .",
|
14 |
+
"instruction": "",
|
15 |
+
"constraint": """["nationality", "country capital", "place of death", "children", "location contains", "place of birth", "place lived", "administrative division of country", "country of administrative divisions", "company", "neighborhood of", "company founders"]""",
|
16 |
+
"file_path": None,
|
17 |
+
},
|
18 |
+
{
|
19 |
+
"task": "RE",
|
20 |
+
"use_file": False,
|
21 |
+
"text": "The aid group Doctors Without Borders said that since Saturday , more than 275 wounded people had been admitted and treated at Donka Hospital in the capital of Guinea , Conakry .",
|
22 |
+
"instruction": "",
|
23 |
+
"constraint": """["nationality", "country capital", "place of death", "children", "location contains", "place of birth", "place lived", "administrative division of country", "country of administrative divisions", "company", "neighborhood of", "company founders"]""",
|
24 |
+
"file_path": None,
|
25 |
+
},
|
26 |
+
{
|
27 |
+
"task": "EE",
|
28 |
+
"use_file": False,
|
29 |
+
"text": "The file suggested to the user contains no software related to video streaming and simply carries the malicious payload that later compromises victim \u2019s account and sends out the deceptive messages to all victim \u2019s contacts .",
|
30 |
+
"instruction": "",
|
31 |
+
"constraint": """{"phishing": ["damage amount", "attack pattern", "tool", "victim", "place", "attacker", "purpose", "trusted entity", "time"], "data breach": ["damage amount", "attack pattern", "number of data", "number of victim", "tool", "compromised data", "victim", "place", "attacker", "purpose", "time"], "ransom": ["damage amount", "attack pattern", "payment method", "tool", "victim", "place", "attacker", "price", "time"], "discover vulnerability": ["vulnerable system", "vulnerability", "vulnerable system owner", "vulnerable system version", "supported platform", "common vulnerabilities and exposures", "capabilities", "time", "discoverer"], "patch vulnerability": ["vulnerable system", "vulnerability", "issues addressed", "vulnerable system version", "releaser", "supported platform", "common vulnerabilities and exposures", "patch number", "time", "patch"]}""",
|
32 |
+
"file_path": None,
|
33 |
+
},
|
34 |
+
# {
|
35 |
+
# "task": "Base",
|
36 |
+
# "use_file": True,
|
37 |
+
# "file_path": "data/Harry_Potter_Chapter_1.pdf",
|
38 |
+
# "instruction": "Extract main characters and the background setting from this chapter.",
|
39 |
+
# "constraint": "",
|
40 |
+
# "text": "",
|
41 |
+
# },
|
42 |
+
# {
|
43 |
+
# "task": "Base",
|
44 |
+
# "use_file": True,
|
45 |
+
# "file_path": "data/Tulsi_Gabbard_News.html",
|
46 |
+
# "instruction": "Extract key information from the given text.",
|
47 |
+
# "constraint": "",
|
48 |
+
# "text": "",
|
49 |
+
# },
|
50 |
+
]
|
51 |
+
|
52 |
+
|
53 |
+
def create_interface():
|
54 |
+
with gr.Blocks(title="OneKE Demo") as demo:
|
55 |
+
gr.HTML("""
|
56 |
+
<div style="text-align:center;">
|
57 |
+
<p align="center">
|
58 |
+
<a href="https://github.com/zjunlp/DeepKE/blob/main/example/llm/assets/oneke_logo.png">
|
59 |
+
<img src="https://raw.githubusercontent.com/zjunlp/DeepKE/refs/heads/main/example/llm/assets/oneke_logo.png" width="240"/>
|
60 |
+
</a>
|
61 |
+
</p>
|
62 |
+
<h1>OneKE: A Dockerized Schema-Guided LLM Agent-based Knowledge Extraction System</h1>
|
63 |
+
<p>
|
64 |
+
🌐[<a href="https://oneke.openkg.cn/" target="_blank">Web</a>]
|
65 |
+
⌨️[<a href="https://github.com/zjunlp/OneKE" target="_blank">Code</a>]
|
66 |
+
📹[<a href="http://oneke.openkg.cn/demo.mp4" target="_blank">Video</a>]
|
67 |
+
</p>
|
68 |
+
</div>
|
69 |
+
""")
|
70 |
+
|
71 |
+
example_button_gr = gr.Button("🎲 Quick Start with an Example 🎲")
|
72 |
+
|
73 |
+
|
74 |
+
with gr.Row():
|
75 |
+
with gr.Column():
|
76 |
+
model_gr = gr.Dropdown(choices=["gpt-3.5-turbo", "gpt-4o", "gpt-4o-mini"], label="🤖 Select your Model")
|
77 |
+
api_key_gr = gr.Textbox(label="🔑 Enter your API-Key")
|
78 |
+
with gr.Column():
|
79 |
+
task_gr = gr.Dropdown(choices=["Base", "NER", "RE", "EE"], label="🎯 Select your Task")
|
80 |
+
use_file_gr = gr.Checkbox(label="📂 Use File", value=True)
|
81 |
+
|
82 |
+
file_path_gr = gr.File(label="📖 Upload a File", visible=True)
|
83 |
+
text_gr = gr.Textbox(label="📖 Text", placeholder="Enter your Text", visible=False)
|
84 |
+
instruction_gr = gr.Textbox(label="🕹️ Instruction", visible=True)
|
85 |
+
constraint_gr = gr.Textbox(label="🕹️ Constraint", visible=False)
|
86 |
+
|
87 |
+
def update_fields(task):
|
88 |
+
if task == "Base":
|
89 |
+
return gr.update(visible=True, label="🕹️ Instruction", placeholder="Enter your Instruction"), gr.update(visible=False)
|
90 |
+
elif task == "NER":
|
91 |
+
return gr.update(visible=False), gr.update(visible=True, label="🕹️ Constraint", placeholder="Enter your NER Constraint")
|
92 |
+
elif task == "RE":
|
93 |
+
return gr.update(visible=False), gr.update(visible=True, label="🕹️ Constraint", placeholder="Enter your RE Constraint")
|
94 |
+
elif task == "EE":
|
95 |
+
return gr.update(visible=False), gr.update(visible=True, label="🕹️ Constraint", placeholder="Enter your EE Constraint")
|
96 |
+
|
97 |
+
def update_input_fields(use_file):
|
98 |
+
if use_file:
|
99 |
+
return gr.update(visible=False), gr.update(visible=True)
|
100 |
+
else:
|
101 |
+
return gr.update(visible=True), gr.update(visible=False)
|
102 |
+
|
103 |
+
def start_with_example():
|
104 |
+
example_index = random.randint(0, len(examples) - 1)
|
105 |
+
example = examples[example_index]
|
106 |
+
return (
|
107 |
+
gr.update(value=example["task"]),
|
108 |
+
gr.update(value=example["use_file"]),
|
109 |
+
gr.update(value=example["file_path"], visible=example["use_file"]),
|
110 |
+
gr.update(value=example["text"], visible=not example["use_file"]),
|
111 |
+
gr.update(value=example["instruction"], visible=example["task"] == "Base"),
|
112 |
+
gr.update(value=example["constraint"], visible=example["task"] in ["NER", "RE", "EE"]),
|
113 |
+
)
|
114 |
+
|
115 |
+
def submit(model, api_key, task, instruction, constraint, text, use_file, file_path):
|
116 |
+
try:
|
117 |
+
# 创建 Pipeline 实例
|
118 |
+
pipeline = Pipeline(ChatGPT(model_name_or_path=model, api_key=api_key))
|
119 |
+
if task == "Base":
|
120 |
+
instruction = instruction
|
121 |
+
constraint = ""
|
122 |
+
else:
|
123 |
+
instruction = ""
|
124 |
+
constraint = constraint
|
125 |
+
if use_file:
|
126 |
+
text = ""
|
127 |
+
file_path = file_path
|
128 |
+
else:
|
129 |
+
text = text
|
130 |
+
file_path = None
|
131 |
+
|
132 |
+
# 调用 Pipeline
|
133 |
+
_, _, ger_frontend_schema, ger_frontend_res = pipeline.get_extract_result(
|
134 |
+
task=task,
|
135 |
+
instruction=instruction,
|
136 |
+
constraint=constraint,
|
137 |
+
use_file=use_file,
|
138 |
+
file_path=file_path,
|
139 |
+
text=text,
|
140 |
+
)
|
141 |
+
|
142 |
+
ger_frontend_schema = str(ger_frontend_schema)
|
143 |
+
ger_frontend_res = json.dumps(ger_frontend_res, ensure_ascii=False, indent=4) if isinstance(ger_frontend_res, dict) else str(ger_frontend_res)
|
144 |
+
return ger_frontend_schema, ger_frontend_res, gr.update(value="", visible=False)
|
145 |
+
|
146 |
+
except Exception as e:
|
147 |
+
error_message = f"⚠️ Error:\n {str(e)}"
|
148 |
+
return "", "", gr.update(value=error_message, visible=True)
|
149 |
+
|
150 |
+
def clear_all():
|
151 |
+
return (
|
152 |
+
gr.update(value=""), # model
|
153 |
+
gr.update(value=""), # API Key
|
154 |
+
gr.update(value=""), # task
|
155 |
+
gr.update(value="", visible=False), # instruction
|
156 |
+
gr.update(value="", visible=False), # constraint
|
157 |
+
gr.update(value=True), # use_file
|
158 |
+
gr.update(value="", visible=False), # text
|
159 |
+
gr.update(value=None, visible=True), # file_path
|
160 |
+
gr.update(value=""),
|
161 |
+
gr.update(value=""),
|
162 |
+
gr.update(value="", visible=False), # error_output
|
163 |
+
)
|
164 |
+
|
165 |
+
with gr.Row():
|
166 |
+
submit_button_gr = gr.Button("Submit", variant="primary", scale=8)
|
167 |
+
clear_button = gr.Button("Clear", scale=5)
|
168 |
+
gr.HTML("""
|
169 |
+
<div style="width: 100%; text-align: center; font-size: 16px; font-weight: bold; position: relative; margin: 20px 0;">
|
170 |
+
<span style="position: absolute; left: 0; top: 50%; transform: translateY(-50%); width: 45%; border-top: 1px solid #ccc;"></span>
|
171 |
+
<span style="position: relative; z-index: 1; background-color: white; padding: 0 10px;">Output:</span>
|
172 |
+
<span style="position: absolute; right: 0; top: 50%; transform: translateY(-50%); width: 45%; border-top: 1px solid #ccc;"></span>
|
173 |
+
</div>
|
174 |
+
""")
|
175 |
+
error_output_gr = gr.Textbox(label="😵💫 Ops, an Error Occurred", visible=False)
|
176 |
+
with gr.Row():
|
177 |
+
with gr.Column(scale=1):
|
178 |
+
py_output_gr = gr.Code(label="🤔 Generated Schema", language="python", lines=10, interactive=False)
|
179 |
+
with gr.Column(scale=1):
|
180 |
+
json_output_gr = gr.Code(label="😉 Final Answer", language="json", lines=10, interactive=False)
|
181 |
+
|
182 |
+
task_gr.change(fn=update_fields, inputs=task_gr, outputs=[instruction_gr, constraint_gr])
|
183 |
+
use_file_gr.change(fn=update_input_fields, inputs=use_file_gr, outputs=[text_gr, file_path_gr])
|
184 |
+
|
185 |
+
example_button_gr.click(
|
186 |
+
fn=start_with_example,
|
187 |
+
inputs=[],
|
188 |
+
outputs=[
|
189 |
+
task_gr,
|
190 |
+
use_file_gr,
|
191 |
+
file_path_gr,
|
192 |
+
text_gr,
|
193 |
+
instruction_gr,
|
194 |
+
constraint_gr,
|
195 |
+
],
|
196 |
+
)
|
197 |
+
submit_button_gr.click(
|
198 |
+
fn=submit,
|
199 |
+
inputs=[
|
200 |
+
model_gr,
|
201 |
+
api_key_gr,
|
202 |
+
task_gr,
|
203 |
+
instruction_gr,
|
204 |
+
constraint_gr,
|
205 |
+
text_gr,
|
206 |
+
use_file_gr,
|
207 |
+
file_path_gr,
|
208 |
+
],
|
209 |
+
outputs=[py_output_gr, json_output_gr, error_output_gr],
|
210 |
+
show_progress=True,
|
211 |
+
)
|
212 |
+
clear_button.click(
|
213 |
+
fn=clear_all,
|
214 |
+
outputs=[
|
215 |
+
model_gr,
|
216 |
+
api_key_gr,
|
217 |
+
task_gr,
|
218 |
+
instruction_gr,
|
219 |
+
constraint_gr,
|
220 |
+
use_file_gr,
|
221 |
+
text_gr,
|
222 |
+
file_path_gr,
|
223 |
+
py_output_gr,
|
224 |
+
json_output_gr,
|
225 |
+
error_output_gr,
|
226 |
+
],
|
227 |
+
)
|
228 |
+
|
229 |
+
return demo
|
230 |
+
|
231 |
+
|
232 |
+
interface = create_interface()
|
233 |
+
interface.launch()
|
src/models/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .llm_def import BaseEngine, LLaMA, Qwen, MiniCPM, ChatGLM, ChatGPT, DeepSeek
|
2 |
+
from .prompt_example import *
|
3 |
+
from .prompt_template import *
|
src/models/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (434 Bytes). View file
|
|
src/models/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (315 Bytes). View file
|
|
src/models/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (359 Bytes). View file
|
|
src/models/__pycache__/llm_def.cpython-311.pyc
ADDED
Binary file (11.8 kB). View file
|
|
src/models/__pycache__/llm_def.cpython-37.pyc
ADDED
Binary file (7.14 kB). View file
|
|
src/models/__pycache__/llm_def.cpython-39.pyc
ADDED
Binary file (6.8 kB). View file
|
|
src/models/__pycache__/prompt_example.cpython-311.pyc
ADDED
Binary file (5.67 kB). View file
|
|
src/models/__pycache__/prompt_example.cpython-39.pyc
ADDED
Binary file (5.66 kB). View file
|
|
src/models/__pycache__/prompt_template.cpython-311.pyc
ADDED
Binary file (5.42 kB). View file
|
|
src/models/__pycache__/prompt_template.cpython-39.pyc
ADDED
Binary file (4.95 kB). View file
|
|
src/models/llm_def.py
ADDED
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Surpported Models.
|
3 |
+
Supports:
|
4 |
+
- Open Source:LLaMA3, Qwen2.5, MiniCPM3, ChatGLM4
|
5 |
+
- Closed Source: ChatGPT, DeepSeek
|
6 |
+
"""
|
7 |
+
|
8 |
+
from transformers import pipeline
|
9 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoTokenizer
|
10 |
+
import torch
|
11 |
+
import openai
|
12 |
+
import os
|
13 |
+
from openai import OpenAI
|
14 |
+
|
15 |
+
# The inferencing code is taken from the official documentation
|
16 |
+
|
17 |
+
class BaseEngine:
|
18 |
+
def __init__(self, model_name_or_path: str):
|
19 |
+
self.name = None
|
20 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
|
21 |
+
self.temperature = 0.2
|
22 |
+
self.top_p = 0.9
|
23 |
+
self.max_tokens = 1024
|
24 |
+
|
25 |
+
def get_chat_response(self, prompt):
|
26 |
+
raise NotImplementedError
|
27 |
+
|
28 |
+
def set_hyperparameter(self, temperature: float = 0.2, top_p: float = 0.9, max_tokens: int = 1024):
|
29 |
+
self.temperature = temperature
|
30 |
+
self.top_p = top_p
|
31 |
+
self.max_tokens = max_tokens
|
32 |
+
|
33 |
+
class LLaMA(BaseEngine):
|
34 |
+
def __init__(self, model_name_or_path: str):
|
35 |
+
super().__init__(model_name_or_path)
|
36 |
+
self.name = "LLaMA"
|
37 |
+
self.model_id = model_name_or_path
|
38 |
+
self.pipeline = pipeline(
|
39 |
+
"text-generation",
|
40 |
+
model=self.model_id,
|
41 |
+
model_kwargs={"torch_dtype": torch.bfloat16},
|
42 |
+
device_map="auto",
|
43 |
+
)
|
44 |
+
self.terminators = [
|
45 |
+
self.pipeline.tokenizer.eos_token_id,
|
46 |
+
self.pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")
|
47 |
+
]
|
48 |
+
|
49 |
+
def get_chat_response(self, prompt):
|
50 |
+
messages = [
|
51 |
+
{"role": "system", "content": "You are a helpful assistant."},
|
52 |
+
{"role": "user", "content": prompt},
|
53 |
+
]
|
54 |
+
outputs = self.pipeline(
|
55 |
+
messages,
|
56 |
+
max_new_tokens=self.max_tokens,
|
57 |
+
eos_token_id=self.terminators,
|
58 |
+
do_sample=True,
|
59 |
+
temperature=self.temperature,
|
60 |
+
top_p=self.top_p,
|
61 |
+
)
|
62 |
+
return outputs[0]["generated_text"][-1]['content'].strip()
|
63 |
+
|
64 |
+
class Qwen(BaseEngine):
|
65 |
+
def __init__(self, model_name_or_path: str):
|
66 |
+
super().__init__(model_name_or_path)
|
67 |
+
self.name = "Qwen"
|
68 |
+
self.model_id = model_name_or_path
|
69 |
+
self.model = AutoModelForCausalLM.from_pretrained(
|
70 |
+
self.model_id,
|
71 |
+
torch_dtype="auto",
|
72 |
+
device_map="auto"
|
73 |
+
)
|
74 |
+
|
75 |
+
def get_chat_response(self, prompt):
|
76 |
+
messages = [
|
77 |
+
{"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
|
78 |
+
{"role": "user", "content": prompt}
|
79 |
+
]
|
80 |
+
text = self.tokenizer.apply_chat_template(
|
81 |
+
messages,
|
82 |
+
tokenize=False,
|
83 |
+
add_generation_prompt=True
|
84 |
+
)
|
85 |
+
model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
|
86 |
+
generated_ids = self.model.generate(
|
87 |
+
**model_inputs,
|
88 |
+
temperature=self.temperature,
|
89 |
+
top_p=self.top_p,
|
90 |
+
max_new_tokens=self.max_tokens
|
91 |
+
)
|
92 |
+
generated_ids = [
|
93 |
+
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
|
94 |
+
]
|
95 |
+
response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
|
96 |
+
|
97 |
+
return response
|
98 |
+
|
99 |
+
class MiniCPM(BaseEngine):
|
100 |
+
def __init__(self, model_name_or_path: str):
|
101 |
+
super().__init__(model_name_or_path)
|
102 |
+
self.name = "MiniCPM"
|
103 |
+
self.model_id = model_name_or_path
|
104 |
+
self.model = AutoModelForCausalLM.from_pretrained(
|
105 |
+
self.model_id,
|
106 |
+
torch_dtype=torch.bfloat16,
|
107 |
+
device_map="auto",
|
108 |
+
trust_remote_code=True
|
109 |
+
)
|
110 |
+
|
111 |
+
def get_chat_response(self, prompt):
|
112 |
+
messages = [
|
113 |
+
{"role": "system", "content": "You are a helpful assistant."},
|
114 |
+
{"role": "user", "content": prompt}
|
115 |
+
]
|
116 |
+
model_inputs = self.tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to(self.model.device)
|
117 |
+
model_outputs = self.model.generate(
|
118 |
+
model_inputs,
|
119 |
+
temperature=self.temperature,
|
120 |
+
top_p=self.top_p,
|
121 |
+
max_new_tokens=self.max_tokens
|
122 |
+
)
|
123 |
+
output_token_ids = [
|
124 |
+
model_outputs[i][len(model_inputs[i]):] for i in range(len(model_inputs))
|
125 |
+
]
|
126 |
+
response = self.tokenizer.batch_decode(output_token_ids, skip_special_tokens=True)[0].strip()
|
127 |
+
|
128 |
+
return response
|
129 |
+
|
130 |
+
class ChatGLM(BaseEngine):
|
131 |
+
def __init__(self, model_name_or_path: str):
|
132 |
+
super().__init__(model_name_or_path)
|
133 |
+
self.name = "ChatGLM"
|
134 |
+
self.model_id = model_name_or_path
|
135 |
+
self.model = AutoModelForCausalLM.from_pretrained(
|
136 |
+
self.model_id,
|
137 |
+
torch_dtype=torch.bfloat16,
|
138 |
+
device_map="auto",
|
139 |
+
low_cpu_mem_usage=True,
|
140 |
+
trust_remote_code=True
|
141 |
+
)
|
142 |
+
|
143 |
+
def get_chat_response(self, prompt):
|
144 |
+
messages = [
|
145 |
+
{"role": "system", "content": "You are a helpful assistant."},
|
146 |
+
{"role": "user", "content": prompt}
|
147 |
+
]
|
148 |
+
model_inputs = self.tokenizer.apply_chat_template(messages, return_tensors="pt", return_dict=True, add_generation_prompt=True, tokenize=True).to(self.model.device)
|
149 |
+
model_outputs = self.model.generate(
|
150 |
+
**model_inputs,
|
151 |
+
temperature=self.temperature,
|
152 |
+
top_p=self.top_p,
|
153 |
+
max_new_tokens=self.max_tokens
|
154 |
+
)
|
155 |
+
model_outputs = model_outputs[:, model_inputs['input_ids'].shape[1]:]
|
156 |
+
response = self.tokenizer.batch_decode(model_outputs, skip_special_tokens=True)[0].strip()
|
157 |
+
|
158 |
+
return response
|
159 |
+
|
160 |
+
class ChatGPT(BaseEngine):
|
161 |
+
def __init__(self, model_name_or_path: str, api_key: str, base_url=openai.base_url):
|
162 |
+
self.name = "ChatGPT"
|
163 |
+
self.model = model_name_or_path
|
164 |
+
self.base_url = base_url
|
165 |
+
self.temperature = 0.2
|
166 |
+
self.top_p = 0.9
|
167 |
+
self.max_tokens = 1024
|
168 |
+
if api_key != "":
|
169 |
+
self.api_key = api_key
|
170 |
+
else:
|
171 |
+
self.api_key = os.environ["OPENAI_API_KEY"]
|
172 |
+
self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
|
173 |
+
|
174 |
+
def get_chat_response(self, input):
|
175 |
+
response = self.client.chat.completions.create(
|
176 |
+
model=self.model,
|
177 |
+
messages=[
|
178 |
+
{"role": "user", "content": input},
|
179 |
+
],
|
180 |
+
stream=False,
|
181 |
+
temperature=self.temperature,
|
182 |
+
max_tokens=self.max_tokens,
|
183 |
+
stop=None
|
184 |
+
)
|
185 |
+
return response.choices[0].message.content
|
186 |
+
|
187 |
+
class DeepSeek(BaseEngine):
|
188 |
+
def __init__(self, model_name_or_path: str, api_key: str, base_url="https://api.deepseek.com"):
|
189 |
+
self.name = "DeepSeek"
|
190 |
+
self.model = model_name_or_path
|
191 |
+
self.base_url = base_url
|
192 |
+
self.temperature = 0.2
|
193 |
+
self.top_p = 0.9
|
194 |
+
self.max_tokens = 1024
|
195 |
+
if api_key != "":
|
196 |
+
self.api_key = api_key
|
197 |
+
else:
|
198 |
+
self.api_key = os.environ["DEEPSEEK_API_KEY"]
|
199 |
+
self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
|
200 |
+
|
201 |
+
def get_chat_response(self, input):
|
202 |
+
response = self.client.chat.completions.create(
|
203 |
+
model=self.model,
|
204 |
+
messages=[
|
205 |
+
{"role": "user", "content": input},
|
206 |
+
],
|
207 |
+
stream=False,
|
208 |
+
temperature=self.temperature,
|
209 |
+
max_tokens=self.max_tokens,
|
210 |
+
stop=None
|
211 |
+
)
|
212 |
+
return response.choices[0].message.content
|
src/models/prompt_example.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
json_schema_examples = """
|
2 |
+
**Task**: Please extract all economic policies affecting the stock market between 2015 and 2023 and the exact dates of their implementation.
|
3 |
+
**Text**: This text is from the field of Economics and represents the genre of Article.
|
4 |
+
...(example text)...
|
5 |
+
**Output Schema**:
|
6 |
+
{
|
7 |
+
"economic_policies": [
|
8 |
+
{
|
9 |
+
"name": null,
|
10 |
+
"implementation_date": null
|
11 |
+
}
|
12 |
+
]
|
13 |
+
}
|
14 |
+
|
15 |
+
Example2:
|
16 |
+
**Task**: Tell me the main content of papers related to NLP between 2022 and 2023.
|
17 |
+
**Text**: This text is from the field of AI and represents the genre of Research Paper.
|
18 |
+
...(example text)...
|
19 |
+
**Output Schema**:
|
20 |
+
{
|
21 |
+
"papers": [
|
22 |
+
{
|
23 |
+
"title": null,
|
24 |
+
"content": null
|
25 |
+
}
|
26 |
+
]
|
27 |
+
}
|
28 |
+
|
29 |
+
Example3:
|
30 |
+
**Task**: Extract all the information in the given text.
|
31 |
+
**Text**: This text is from the field of Political and represents the genre of News Report.
|
32 |
+
...(example text)...
|
33 |
+
**Output Schema**:
|
34 |
+
Answer:
|
35 |
+
{
|
36 |
+
"news_report":
|
37 |
+
{
|
38 |
+
"title": null,
|
39 |
+
"summary": null,
|
40 |
+
"publication_date": null,
|
41 |
+
"keywords": [],
|
42 |
+
"events": [
|
43 |
+
{
|
44 |
+
"name": null,
|
45 |
+
"time": null,
|
46 |
+
"people_involved": [],
|
47 |
+
"cause": null,
|
48 |
+
"process": null,
|
49 |
+
"result": null
|
50 |
+
}
|
51 |
+
],
|
52 |
+
quotes: [],
|
53 |
+
viewpoints: []
|
54 |
+
}
|
55 |
+
}
|
56 |
+
"""
|
57 |
+
|
58 |
+
code_schema_examples = """
|
59 |
+
Example1:
|
60 |
+
**Task**: Extract all the entities in the given text.
|
61 |
+
**Text**:
|
62 |
+
...(example text)...
|
63 |
+
**Output Schema**:
|
64 |
+
```python
|
65 |
+
from typing import List, Optional
|
66 |
+
from pydantic import BaseModel, Field
|
67 |
+
|
68 |
+
class Entity(BaseModel):
|
69 |
+
label : str = Field(description="The type or category of the entity, such as 'Process', 'Technique', 'Data Structure', 'Methodology', 'Person', etc. ")
|
70 |
+
name : str = Field(description="The specific name of the entity. It should represent a single, distinct concept and must not be an empty string. For example, if the entity is a 'Technique', the name could be 'Neural Networks'.")
|
71 |
+
|
72 |
+
class ExtractionTarget(BaseModel):
|
73 |
+
entity_list : List[Entity] = Field(description="All the entities presented in the context. The entities should encode ONE concept.")
|
74 |
+
```
|
75 |
+
|
76 |
+
Example2:
|
77 |
+
**Task**: Extract all the information in the given text.
|
78 |
+
**Text**: This text is from the field of Political and represents the genre of News Article.
|
79 |
+
...(example text)...
|
80 |
+
**Output Schema**:
|
81 |
+
```python
|
82 |
+
from typing import List, Optional
|
83 |
+
from pydantic import BaseModel, Field
|
84 |
+
|
85 |
+
class Person(BaseModel):
|
86 |
+
name: str = Field(description="The name of the person")
|
87 |
+
identity: Optional[str] = Field(description="The occupation, status or characteristics of the person.")
|
88 |
+
role: Optional[str] = Field(description="The role or function the person plays in an event.")
|
89 |
+
|
90 |
+
class Event(BaseModel):
|
91 |
+
name: str = Field(description="Name of the event")
|
92 |
+
time: Optional[str] = Field(description="Time when the event took place")
|
93 |
+
people_involved: Optional[List[Person]] = Field(description="People involved in the event")
|
94 |
+
cause: Optional[str] = Field(default=None, description="Reason for the event, if applicable")
|
95 |
+
process: Optional[str] = Field(description="Details of the event process")
|
96 |
+
result: Optional[str] = Field(default=None, description="Result or outcome of the event")
|
97 |
+
|
98 |
+
class ExtractionTarget(BaseModel):
|
99 |
+
title: str = Field(description="The title or headline of the news article")
|
100 |
+
summary: str = Field(description="A brief summary of the news article")
|
101 |
+
publication_date: Optional[str] = Field(description="The publication date of the article")
|
102 |
+
keywords: Optional[List[str]] = Field(description="List of keywords or topics covered in the article")
|
103 |
+
events: List[Event] = Field(description="Events covered in the article")
|
104 |
+
quotes: Optional[List[str]] = Field(default=None, description="Quotes related to the news, if any")
|
105 |
+
viewpoints: Optional[List[str]] = Field(default=None, description="Different viewpoints regarding the news")
|
106 |
+
```
|
107 |
+
|
108 |
+
Example3:
|
109 |
+
**Task**: Extract the key information in the given text.
|
110 |
+
**Text**: This text is from the field of AI and represents the genre of Research Paper.
|
111 |
+
...(example text)...
|
112 |
+
```python
|
113 |
+
from typing import List, Optional
|
114 |
+
from pydantic import BaseModel, Field
|
115 |
+
|
116 |
+
class MetaData(BaseModel):
|
117 |
+
title : str = Field(description="The title of the article")
|
118 |
+
authors : List[str] = Field(description="The list of the article's authors")
|
119 |
+
abstract: str = Field(description="The article's abstract")
|
120 |
+
key_words: List[str] = Field(description="The key words associated with the article")
|
121 |
+
|
122 |
+
class Baseline(BaseModel):
|
123 |
+
method_name : str = Field(description="The name of the baseline method")
|
124 |
+
proposed_solution : str = Field(description="the proposed solution in details")
|
125 |
+
performance_metrics : str = Field(description="The performance metrics of the method and comparative analysis")
|
126 |
+
|
127 |
+
class ExtractionTarget(BaseModel):
|
128 |
+
|
129 |
+
key_contributions: List[str] = Field(description="The key contributions of the article")
|
130 |
+
limitation_of_sota : str=Field(description="the summary limitation of the existing work")
|
131 |
+
proposed_solution : str = Field(description="the proposed solution in details")
|
132 |
+
baselines : List[Baseline] = Field(description="The list of baseline methods and their details")
|
133 |
+
performance_metrics : str = Field(description="The performance metrics of the method and comparative analysis")
|
134 |
+
paper_limitations : str=Field(description="The limitations of the proposed solution of the paper")
|
135 |
+
```
|
136 |
+
|
137 |
+
"""
|
src/models/prompt_template.py
ADDED
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain.prompts import PromptTemplate
|
2 |
+
from .prompt_example import *
|
3 |
+
|
4 |
+
# ==================================================================== #
|
5 |
+
# SCHEMA AGENT #
|
6 |
+
# ==================================================================== #
|
7 |
+
|
8 |
+
# Get Text Analysis
|
9 |
+
TEXT_ANALYSIS_INSTRUCTION = """
|
10 |
+
**Instruction**: Please analyze and categorize the given text.
|
11 |
+
{examples}
|
12 |
+
**Text**: {text}
|
13 |
+
|
14 |
+
**Output Shema**: {schema}
|
15 |
+
"""
|
16 |
+
|
17 |
+
text_analysis_instruction = PromptTemplate(
|
18 |
+
input_variables=["examples", "text", "schema"],
|
19 |
+
template=TEXT_ANALYSIS_INSTRUCTION,
|
20 |
+
)
|
21 |
+
|
22 |
+
# Get Deduced Schema Json
|
23 |
+
DEDUCE_SCHEMA_JSON_INSTRUCTION = """
|
24 |
+
**Instruction**: Generate an output format that meets the requirements as described in the task. Pay attention to the following requirements:
|
25 |
+
- Format: Return your responses in dictionary format as a JSON object.
|
26 |
+
- Content: Do not include any actual data; all attributes values should be set to None.
|
27 |
+
- Note: Attributes not mentioned in the task description should be ignored.
|
28 |
+
{examples}
|
29 |
+
**Task**: {instruction}
|
30 |
+
|
31 |
+
**Text**: {distilled_text}
|
32 |
+
{text}
|
33 |
+
|
34 |
+
Now please deduce the output schema in json format. All attributes values should be set to None.
|
35 |
+
**Output Schema**:
|
36 |
+
"""
|
37 |
+
|
38 |
+
deduced_schema_json_instruction = PromptTemplate(
|
39 |
+
input_variables=["examples", "instruction", "distilled_text", "text", "schema"],
|
40 |
+
template=DEDUCE_SCHEMA_JSON_INSTRUCTION,
|
41 |
+
)
|
42 |
+
|
43 |
+
# Get Deduced Schema Code
|
44 |
+
DEDUCE_SCHEMA_CODE_INSTRUCTION = """
|
45 |
+
**Instruction**: Based on the provided text and task description, Define the output schema in Python using Pydantic. Name the final extraction target class as 'ExtractionTarget'.
|
46 |
+
{examples}
|
47 |
+
**Task**: {instruction}
|
48 |
+
|
49 |
+
**Text**: {distilled_text}
|
50 |
+
{text}
|
51 |
+
|
52 |
+
Now please deduce the output schema. Ensure that the output code snippet is wrapped in '```',and can be directly parsed by the Python interpreter.
|
53 |
+
**Output Schema**: """
|
54 |
+
deduced_schema_code_instruction = PromptTemplate(
|
55 |
+
input_variables=["examples", "instruction", "distilled_text", "text"],
|
56 |
+
template=DEDUCE_SCHEMA_CODE_INSTRUCTION,
|
57 |
+
)
|
58 |
+
|
59 |
+
|
60 |
+
# ==================================================================== #
|
61 |
+
# EXTRACTION AGENT #
|
62 |
+
# ==================================================================== #
|
63 |
+
|
64 |
+
EXTRACT_INSTRUCTION = """
|
65 |
+
**Instruction**: You are an agent skilled in information extarction. {instruction}
|
66 |
+
{examples}
|
67 |
+
**Text**: {text}
|
68 |
+
{additional_info}
|
69 |
+
**Output Schema**: {schema}
|
70 |
+
|
71 |
+
Now please extract the corresponding information from the text. Ensure that the information you extract has a clear reference in the given text. Set any property not explicitly mentioned in the text to null.
|
72 |
+
"""
|
73 |
+
|
74 |
+
extract_instruction = PromptTemplate(
|
75 |
+
input_variables=["instruction", "examples", "text", "schema", "additional_info"],
|
76 |
+
template=EXTRACT_INSTRUCTION,
|
77 |
+
)
|
78 |
+
|
79 |
+
SUMMARIZE_INSTRUCTION = """
|
80 |
+
**Instruction**: Below is a list of results obtained after segmenting and extracting information from a long article. Please consolidate all the answers to generate a final response.
|
81 |
+
{examples}
|
82 |
+
**Task**: {instruction}
|
83 |
+
|
84 |
+
**Result List**: {answer_list}
|
85 |
+
|
86 |
+
**Output Schema**: {schema}
|
87 |
+
Now summarize all the information from the Result List.
|
88 |
+
"""
|
89 |
+
summarize_instruction = PromptTemplate(
|
90 |
+
input_variables=["instruction", "examples", "answer_list", "schema"],
|
91 |
+
template=SUMMARIZE_INSTRUCTION,
|
92 |
+
)
|
93 |
+
|
94 |
+
|
95 |
+
# ==================================================================== #
|
96 |
+
# REFLECION AGENT #
|
97 |
+
# ==================================================================== #
|
98 |
+
REFLECT_INSTRUCTION = """**Instruction**: You are an agent skilled in reflection and optimization based on the original result. Refer to **Reflection Reference** to identify potential issues in the current extraction results.
|
99 |
+
|
100 |
+
**Reflection Reference**: {examples}
|
101 |
+
|
102 |
+
Now please review each element in the extraction result. Identify and improve any potential issues in the result based on the reflection. NOTE: If the original result is correct, no modifications are needed!
|
103 |
+
|
104 |
+
**Task**: {instruction}
|
105 |
+
|
106 |
+
**Text**: {text}
|
107 |
+
|
108 |
+
**Output Schema**: {schema}
|
109 |
+
|
110 |
+
**Original Result**: {result}
|
111 |
+
|
112 |
+
"""
|
113 |
+
reflect_instruction = PromptTemplate(
|
114 |
+
input_variables=["instruction", "examples", "text", "schema", "result"],
|
115 |
+
template=REFLECT_INSTRUCTION,
|
116 |
+
)
|
117 |
+
|
118 |
+
SUMMARIZE_INSTRUCTION = """
|
119 |
+
**Instruction**: Below is a list of results obtained after segmenting and extracting information from a long article. Please consolidate all the answers to generate a final response.
|
120 |
+
|
121 |
+
**Task**: {instruction}
|
122 |
+
|
123 |
+
**Result List**: {answer_list}
|
124 |
+
{additional_info}
|
125 |
+
**Output Schema**: {schema}
|
126 |
+
Now summarize the information from the Result List.
|
127 |
+
"""
|
128 |
+
summarize_instruction = PromptTemplate(
|
129 |
+
input_variables=["instruction", "answer_list", "additional_info", "schema"],
|
130 |
+
template=SUMMARIZE_INSTRUCTION,
|
131 |
+
)
|
132 |
+
|
133 |
+
|
134 |
+
|
135 |
+
# ==================================================================== #
|
136 |
+
# CASE REPOSITORY #
|
137 |
+
# ==================================================================== #
|
138 |
+
|
139 |
+
GOOD_CASE_ANALYSIS_INSTRUCTION = """
|
140 |
+
**Instruction**: Below is an information extraction task and its corresponding correct answer. Provide the reasoning steps that led to the correct answer, along with brief explanation of the answer. Your response should be brief and organized.
|
141 |
+
|
142 |
+
**Task**: {instruction}
|
143 |
+
|
144 |
+
**Text**: {text}
|
145 |
+
{additional_info}
|
146 |
+
**Correct Answer**: {result}
|
147 |
+
|
148 |
+
Now please generate the reasoning steps and breif analysis of the **Correct Answer** given above. DO NOT generate your own extraction result.
|
149 |
+
**Analysis**:
|
150 |
+
"""
|
151 |
+
good_case_analysis_instruction = PromptTemplate(
|
152 |
+
input_variables=["instruction", "text", "result", "additional_info"],
|
153 |
+
template=GOOD_CASE_ANALYSIS_INSTRUCTION,
|
154 |
+
)
|
155 |
+
|
156 |
+
BAD_CASE_REFLECTION_INSTRUCTION = """
|
157 |
+
**Instruction**: Based on the task description, compare the original answer with the correct one. Your output should be a brief reflection or concise summarized rules.
|
158 |
+
|
159 |
+
**Task**: {instruction}
|
160 |
+
|
161 |
+
**Text**: {text}
|
162 |
+
{additional_info}
|
163 |
+
**Original Answer**: {original_answer}
|
164 |
+
|
165 |
+
**Correct Answer**: {correct_answer}
|
166 |
+
|
167 |
+
Now please generate a brief and organized reflection. DO NOT generate your own extraction result.
|
168 |
+
**Reflection**:
|
169 |
+
"""
|
170 |
+
|
171 |
+
bad_case_reflection_instruction = PromptTemplate(
|
172 |
+
input_variables=["instruction", "text", "original_answer", "correct_answer", "additional_info"],
|
173 |
+
template=BAD_CASE_REFLECTION_INSTRUCTION,
|
174 |
+
)
|
src/modules/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .schema_agent import SchemaAgent
|
2 |
+
from .extraction_agent import ExtractionAgent
|
3 |
+
from .reflection_agent import ReflectionAgent
|
4 |
+
from .knowledge_base.case_repository import CaseRepositoryHandler
|
src/modules/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (459 Bytes). View file
|
|
src/modules/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (392 Bytes). View file
|
|
src/modules/__pycache__/extraction_agent.cpython-311.pyc
ADDED
Binary file (6.66 kB). View file
|
|
src/modules/__pycache__/extraction_agent.cpython-39.pyc
ADDED
Binary file (4.06 kB). View file
|
|
src/modules/__pycache__/reflection_agent.cpython-311.pyc
ADDED
Binary file (6.98 kB). View file
|
|
src/modules/__pycache__/reflection_agent.cpython-39.pyc
ADDED
Binary file (4.01 kB). View file
|
|
src/modules/__pycache__/schema_agent.cpython-311.pyc
ADDED
Binary file (10.7 kB). View file
|
|
src/modules/__pycache__/schema_agent.cpython-39.pyc
ADDED
Binary file (6.64 kB). View file
|
|
src/modules/extraction_agent.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from models import *
|
2 |
+
from utils import *
|
3 |
+
from .knowledge_base.case_repository import CaseRepositoryHandler
|
4 |
+
|
5 |
+
class InformationExtractor:
|
6 |
+
def __init__(self, llm: BaseEngine):
|
7 |
+
self.llm = llm
|
8 |
+
|
9 |
+
def extract_information(self, instruction="", text="", examples="", schema="", additional_info=""):
|
10 |
+
examples = good_case_wrapper(examples)
|
11 |
+
prompt = extract_instruction.format(instruction=instruction, examples=examples, text=text, additional_info=additional_info, schema=schema)
|
12 |
+
response = self.llm.get_chat_response(prompt)
|
13 |
+
response = extract_json_dict(response)
|
14 |
+
print(f"prompt: {prompt}")
|
15 |
+
print("========================================")
|
16 |
+
print(f"response: {response}")
|
17 |
+
return response
|
18 |
+
|
19 |
+
def summarize_answer(self, instruction="", answer_list="", schema="", additional_info=""):
|
20 |
+
prompt = summarize_instruction.format(instruction=instruction, answer_list=answer_list, schema=schema, additional_info=additional_info)
|
21 |
+
response = self.llm.get_chat_response(prompt)
|
22 |
+
response = extract_json_dict(response)
|
23 |
+
return response
|
24 |
+
|
25 |
+
class ExtractionAgent:
|
26 |
+
def __init__(self, llm: BaseEngine, case_repo: CaseRepositoryHandler):
|
27 |
+
self.llm = llm
|
28 |
+
self.module = InformationExtractor(llm = llm)
|
29 |
+
self.case_repo = case_repo
|
30 |
+
self.methods = ["extract_information_direct", "extract_information_with_case"]
|
31 |
+
|
32 |
+
def __get_constraint(self, data: DataPoint):
|
33 |
+
if data.constraint == "":
|
34 |
+
return data
|
35 |
+
if data.task == "NER":
|
36 |
+
constraint = json.dumps(data.constraint)
|
37 |
+
if "**Entity Type Constraint**" in constraint:
|
38 |
+
return data
|
39 |
+
data.constraint = f"\n**Entity Type Constraint**: The type of entities must be chosen from the following list.\n{constraint}\n"
|
40 |
+
elif data.task == "RE":
|
41 |
+
constraint = json.dumps(data.constraint)
|
42 |
+
if "**Relation Type Constraint**" in constraint:
|
43 |
+
return data
|
44 |
+
data.constraint = f"\n**Relation Type Constraint**: The type of relations must be chosen from the following list.\n{constraint}\n"
|
45 |
+
elif data.task == "EE":
|
46 |
+
constraint = json.dumps(data.constraint)
|
47 |
+
if "**Event Extraction Constraint**" in constraint:
|
48 |
+
return data
|
49 |
+
data.constraint = f"\n**Event Extraction Constraint**: The event type must be selected from the following dictionary keys, and its event arguments should be chosen from its corresponding dictionary values. \n{constraint}\n"
|
50 |
+
return data
|
51 |
+
|
52 |
+
def extract_information_direct(self, data: DataPoint):
|
53 |
+
data = self.__get_constraint(data)
|
54 |
+
result_list = []
|
55 |
+
for chunk_text in data.chunk_text_list:
|
56 |
+
extract_direct_result = self.module.extract_information(instruction=data.instruction, text=chunk_text, schema=data.output_schema, examples="", additional_info=data.constraint)
|
57 |
+
result_list.append(extract_direct_result)
|
58 |
+
function_name = current_function_name()
|
59 |
+
data.set_result_list(result_list)
|
60 |
+
data.update_trajectory(function_name, result_list)
|
61 |
+
return data
|
62 |
+
|
63 |
+
def extract_information_with_case(self, data: DataPoint):
|
64 |
+
data = self.__get_constraint(data)
|
65 |
+
result_list = []
|
66 |
+
for chunk_text in data.chunk_text_list:
|
67 |
+
examples = self.case_repo.query_good_case(data)
|
68 |
+
extract_case_result = self.module.extract_information(instruction=data.instruction, text=chunk_text, schema=data.output_schema, examples=examples, additional_info=data.constraint)
|
69 |
+
result_list.append(extract_case_result)
|
70 |
+
function_name = current_function_name()
|
71 |
+
data.set_result_list(result_list)
|
72 |
+
data.update_trajectory(function_name, result_list)
|
73 |
+
return data
|
74 |
+
|
75 |
+
def summarize_answer(self, data: DataPoint):
|
76 |
+
if len(data.result_list) == 0:
|
77 |
+
return data
|
78 |
+
if len(data.result_list) == 1:
|
79 |
+
data.set_pred(data.result_list[0])
|
80 |
+
return data
|
81 |
+
summarized_result = self.module.summarize_answer(instruction=data.instruction, answer_list=data.result_list, schema=data.output_schema, additional_info=data.constraint)
|
82 |
+
funtion_name = current_function_name()
|
83 |
+
data.set_pred(summarized_result)
|
84 |
+
data.update_trajectory(funtion_name, summarized_result)
|
85 |
+
return data
|
src/modules/knowledge_base/__pycache__/case_repository.cpython-311.pyc
ADDED
Binary file (4.64 kB). View file
|
|
src/modules/knowledge_base/__pycache__/case_repository.cpython-39.pyc
ADDED
Binary file (3.8 kB). View file
|
|
src/modules/knowledge_base/__pycache__/schema_repository.cpython-311.pyc
ADDED
Binary file (9.25 kB). View file
|
|
src/modules/knowledge_base/__pycache__/schema_repository.cpython-39.pyc
ADDED
Binary file (5.94 kB). View file
|
|
src/modules/knowledge_base/case_repository.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
src/modules/knowledge_base/case_repository.py
ADDED
@@ -0,0 +1,391 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# import json
|
2 |
+
# import os
|
3 |
+
# import torch
|
4 |
+
# import numpy as np
|
5 |
+
# from utils import *
|
6 |
+
# from sentence_transformers import SentenceTransformer
|
7 |
+
# from rapidfuzz import process
|
8 |
+
# from models import *
|
9 |
+
# import copy
|
10 |
+
|
11 |
+
# import warnings
|
12 |
+
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
13 |
+
# warnings.filterwarnings("ignore", category=FutureWarning, message=r".*clean_up_tokenization_spaces*")
|
14 |
+
|
15 |
+
# class CaseRepository:
|
16 |
+
# def __init__(self):
|
17 |
+
# self.embedder = SentenceTransformer("all-MiniLM-L6-v2")
|
18 |
+
# self.embedder.to(device)
|
19 |
+
# self.corpus = self.load_corpus()
|
20 |
+
# self.embedded_corpus = self.embed_corpus()
|
21 |
+
|
22 |
+
# def load_corpus(self):
|
23 |
+
# with open(os.path.join(os.path.dirname(__file__), "case_repository.json")) as file:
|
24 |
+
# corpus = json.load(file)
|
25 |
+
# return corpus
|
26 |
+
|
27 |
+
# def update_corpus(self):
|
28 |
+
# try:
|
29 |
+
# with open(os.path.join(os.path.dirname(__file__), "case_repository.json"), "w") as file:
|
30 |
+
# json.dump(self.corpus, file, indent=2)
|
31 |
+
# except Exception as e:
|
32 |
+
# print(f"Error when updating corpus: {e}")
|
33 |
+
|
34 |
+
# def embed_corpus(self):
|
35 |
+
# embedded_corpus = {}
|
36 |
+
# for key, content in self.corpus.items():
|
37 |
+
# good_index = [item['index']['embed_index'] for item in content['good']]
|
38 |
+
# encoded_good_index = self.embedder.encode(good_index, convert_to_tensor=True).to(device)
|
39 |
+
# bad_index = [item['index']['embed_index'] for item in content['bad']]
|
40 |
+
# encoded_bad_index = self.embedder.encode(bad_index, convert_to_tensor=True).to(device)
|
41 |
+
# embedded_corpus[key] = {"good": encoded_good_index, "bad": encoded_bad_index}
|
42 |
+
# return embedded_corpus
|
43 |
+
|
44 |
+
# def get_similarity_scores(self, task: TaskType, embed_index="", str_index="", case_type="", top_k=2):
|
45 |
+
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
46 |
+
# # Embedding similarity match
|
47 |
+
# encoded_embed_query = self.embedder.encode(embed_index, convert_to_tensor=True).to(device)
|
48 |
+
# embedding_similarity_matrix = self.embedder.similarity(encoded_embed_query, self.embedded_corpus[task][case_type])
|
49 |
+
# embedding_similarity_scores = embedding_similarity_matrix[0].to(device)
|
50 |
+
|
51 |
+
# # String similarity match
|
52 |
+
# str_match_corpus = [item['index']['str_index'] for item in self.corpus[task][case_type]]
|
53 |
+
# str_similarity_results = process.extract(str_index, str_match_corpus, limit=len(str_match_corpus))
|
54 |
+
# scores_dict = {match[0]: match[1] for match in str_similarity_results}
|
55 |
+
# scores_in_order = [scores_dict[candidate] for candidate in str_match_corpus]
|
56 |
+
# str_similarity_scores = torch.tensor(scores_in_order, dtype=torch.float32).to(device)
|
57 |
+
|
58 |
+
# # Normalize scores
|
59 |
+
# embedding_score_range = embedding_similarity_scores.max() - embedding_similarity_scores.min()
|
60 |
+
# str_score_range = str_similarity_scores.max() - str_similarity_scores.min()
|
61 |
+
# if embedding_score_range > 0:
|
62 |
+
# embed_norm_scores = (embedding_similarity_scores - embedding_similarity_scores.min()) / embedding_score_range
|
63 |
+
# else:
|
64 |
+
# embed_norm_scores = embedding_similarity_scores
|
65 |
+
# if str_score_range > 0:
|
66 |
+
# str_norm_scores = (str_similarity_scores - str_similarity_scores.min()) / str_score_range
|
67 |
+
# else:
|
68 |
+
# str_norm_scores = str_similarity_scores / 100
|
69 |
+
|
70 |
+
# # Combine the scores with weights
|
71 |
+
# combined_scores = 0.5 * embed_norm_scores + 0.5 * str_norm_scores
|
72 |
+
# original_combined_scores = 0.5 * embedding_similarity_scores + 0.5 * str_similarity_scores / 100
|
73 |
+
|
74 |
+
# scores, indices = torch.topk(combined_scores, k=min(top_k, combined_scores.size(0)))
|
75 |
+
# original_scores, original_indices = torch.topk(original_combined_scores, k=min(top_k, original_combined_scores.size(0)))
|
76 |
+
# return scores, indices, original_scores, original_indices
|
77 |
+
|
78 |
+
# def query_case(self, task: TaskType, embed_index="", str_index="", case_type="", top_k=2) -> list:
|
79 |
+
# _, indices, _, _ = self.get_similarity_scores(task, embed_index, str_index, case_type, top_k)
|
80 |
+
# top_matches = [self.corpus[task][case_type][idx]["content"] for idx in indices]
|
81 |
+
# return top_matches
|
82 |
+
|
83 |
+
# def update_case(self, task: TaskType, embed_index="", str_index="", content="" ,case_type=""):
|
84 |
+
# self.corpus[task][case_type].append({"index": {"embed_index": embed_index, "str_index": str_index}, "content": content})
|
85 |
+
# self.embedded_corpus[task][case_type] = torch.cat([self.embedded_corpus[task][case_type], self.embedder.encode([embed_index], convert_to_tensor=True).to(device)], dim=0)
|
86 |
+
# print(f"Case updated for {task} task.")
|
87 |
+
|
88 |
+
# class CaseRepositoryHandler:
|
89 |
+
# def __init__(self, llm: BaseEngine):
|
90 |
+
# self.repository = CaseRepository()
|
91 |
+
# self.llm = llm
|
92 |
+
|
93 |
+
# def __get_good_case_analysis(self, instruction="", text="", result="", additional_info=""):
|
94 |
+
# prompt = good_case_analysis_instruction.format(
|
95 |
+
# instruction=instruction, text=text, result=result, additional_info=additional_info
|
96 |
+
# )
|
97 |
+
# for _ in range(3):
|
98 |
+
# response = self.llm.get_chat_response(prompt)
|
99 |
+
# response = extract_json_dict(response)
|
100 |
+
# if not isinstance(response, dict):
|
101 |
+
# return response
|
102 |
+
# return None
|
103 |
+
|
104 |
+
# def __get_bad_case_reflection(self, instruction="", text="", original_answer="", correct_answer="", additional_info=""):
|
105 |
+
# prompt = bad_case_reflection_instruction.format(
|
106 |
+
# instruction=instruction, text=text, original_answer=original_answer, correct_answer=correct_answer, additional_info=additional_info
|
107 |
+
# )
|
108 |
+
# for _ in range(3):
|
109 |
+
# response = self.llm.get_chat_response(prompt)
|
110 |
+
# response = extract_json_dict(response)
|
111 |
+
# if not isinstance(response, dict):
|
112 |
+
# return response
|
113 |
+
# return None
|
114 |
+
|
115 |
+
# def __get_index(self, data: DataPoint, case_type: str):
|
116 |
+
# # set embed_index
|
117 |
+
# embed_index = f"**Text**: {data.distilled_text}\n{data.chunk_text_list[0]}"
|
118 |
+
|
119 |
+
# # set str_index
|
120 |
+
# if data.task == "Base":
|
121 |
+
# str_index = f"**Task**: {data.instruction}"
|
122 |
+
# else:
|
123 |
+
# str_index = f"{data.constraint}"
|
124 |
+
|
125 |
+
# if case_type == "bad":
|
126 |
+
# str_index += f"\n\n**Original Result**: {json.dumps(data.pred)}"
|
127 |
+
|
128 |
+
# return embed_index, str_index
|
129 |
+
|
130 |
+
# def query_good_case(self, data: DataPoint):
|
131 |
+
# embed_index, str_index = self.__get_index(data, "good")
|
132 |
+
# return self.repository.query_case(task=data.task, embed_index=embed_index, str_index=str_index, case_type="good")
|
133 |
+
|
134 |
+
# def query_bad_case(self, data: DataPoint):
|
135 |
+
# embed_index, str_index = self.__get_index(data, "bad")
|
136 |
+
# return self.repository.query_case(task=data.task, embed_index=embed_index, str_index=str_index, case_type="bad")
|
137 |
+
|
138 |
+
# def update_good_case(self, data: DataPoint):
|
139 |
+
# if data.truth == "" :
|
140 |
+
# print("No truth value provided.")
|
141 |
+
# return
|
142 |
+
# embed_index, str_index = self.__get_index(data, "good")
|
143 |
+
# _, _, original_scores, _ = self.repository.get_similarity_scores(data.task, embed_index, str_index, "good", 1)
|
144 |
+
# original_scores = original_scores.tolist()
|
145 |
+
# if original_scores[0] >= 0.9:
|
146 |
+
# print("The similar good case is already in the corpus. Similarity Score: ", original_scores[0])
|
147 |
+
# return
|
148 |
+
# good_case_alaysis = self.__get_good_case_analysis(instruction=data.instruction, text=data.distilled_text, result=data.truth, additional_info=data.constraint)
|
149 |
+
# wrapped_good_case_analysis = f"**Analysis**: {good_case_alaysis}"
|
150 |
+
# wrapped_instruction = f"**Task**: {data.instruction}"
|
151 |
+
# wrapped_text = f"**Text**: {data.distilled_text}\n{data.chunk_text_list[0]}"
|
152 |
+
# wrapped_answer = f"**Correct Answer**: {json.dumps(data.truth)}"
|
153 |
+
# if data.task == "Base":
|
154 |
+
# content = f"{wrapped_instruction}\n\n{wrapped_text}\n\n{wrapped_good_case_analysis}\n\n{wrapped_answer}"
|
155 |
+
# else:
|
156 |
+
# content = f"{wrapped_text}\n\n{data.constraint}\n\n{wrapped_good_case_analysis}\n\n{wrapped_answer}"
|
157 |
+
# self.repository.update_case(data.task, embed_index, str_index, content, "good")
|
158 |
+
|
159 |
+
# def update_bad_case(self, data: DataPoint):
|
160 |
+
# if data.truth == "" :
|
161 |
+
# print("No truth value provided.")
|
162 |
+
# return
|
163 |
+
# if normalize_obj(data.pred) == normalize_obj(data.truth):
|
164 |
+
# return
|
165 |
+
# embed_index, str_index = self.__get_index(data, "bad")
|
166 |
+
# _, _, original_scores, _ = self.repository.get_similarity_scores(data.task, embed_index, str_index, "bad", 1)
|
167 |
+
# original_scores = original_scores.tolist()
|
168 |
+
# if original_scores[0] >= 0.9:
|
169 |
+
# print("The similar bad case is already in the corpus. Similarity Score: ", original_scores[0])
|
170 |
+
# return
|
171 |
+
# bad_case_reflection = self.__get_bad_case_reflection(instruction=data.instruction, text=data.distilled_text, original_answer=data.pred, correct_answer=data.truth, additional_info=data.constraint)
|
172 |
+
# wrapped_bad_case_reflection = f"**Reflection**: {bad_case_reflection}"
|
173 |
+
# wrapper_original_answer = f"**Original Answer**: {json.dumps(data.pred)}"
|
174 |
+
# wrapper_correct_answer = f"**Correct Answer**: {json.dumps(data.truth)}"
|
175 |
+
# wrapped_instruction = f"**Task**: {data.instruction}"
|
176 |
+
# wrapped_text = f"**Text**: {data.distilled_text}\n{data.chunk_text_list[0]}"
|
177 |
+
# if data.task == "Base":
|
178 |
+
# content = f"{wrapped_instruction}\n\n{wrapped_text}\n\n{wrapper_original_answer}\n\n{wrapped_bad_case_reflection}\n\n{wrapper_correct_answer}"
|
179 |
+
# else:
|
180 |
+
# content = f"{wrapped_text}\n\n{data.constraint}\n\n{wrapper_original_answer}\n\n{wrapped_bad_case_reflection}\n\n{wrapper_correct_answer}"
|
181 |
+
# self.repository.update_case(data.task, embed_index, str_index, content, "bad")
|
182 |
+
|
183 |
+
# def update_case(self, data: DataPoint):
|
184 |
+
# self.update_good_case(data)
|
185 |
+
# self.update_bad_case(data)
|
186 |
+
# self.repository.update_corpus()
|
187 |
+
|
188 |
+
|
189 |
+
|
190 |
+
import json
|
191 |
+
import os
|
192 |
+
import torch
|
193 |
+
import numpy as np
|
194 |
+
from utils import *
|
195 |
+
from sentence_transformers import SentenceTransformer
|
196 |
+
from rapidfuzz import process
|
197 |
+
from models import *
|
198 |
+
import copy
|
199 |
+
|
200 |
+
import warnings
|
201 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
202 |
+
warnings.filterwarnings("ignore", category=FutureWarning, message=r".*clean_up_tokenization_spaces*")
|
203 |
+
|
204 |
+
class CaseRepository:
|
205 |
+
def __init__(self):
|
206 |
+
# self.embedder = SentenceTransformer("all-MiniLM-L6-v2")
|
207 |
+
# self.embedder.to(device)
|
208 |
+
# self.corpus = self.load_corpus()
|
209 |
+
# self.embedded_corpus = self.embed_corpus()
|
210 |
+
pass
|
211 |
+
|
212 |
+
def load_corpus(self):
|
213 |
+
# with open(os.path.join(os.path.dirname(__file__), "case_repository.json")) as file:
|
214 |
+
# corpus = json.load(file)
|
215 |
+
# return corpus
|
216 |
+
pass
|
217 |
+
|
218 |
+
def update_corpus(self):
|
219 |
+
# try:
|
220 |
+
# with open(os.path.join(os.path.dirname(__file__), "case_repository.json"), "w") as file:
|
221 |
+
# json.dump(self.corpus, file, indent=2)
|
222 |
+
# except Exception as e:
|
223 |
+
# print(f"Error when updating corpus: {e}")
|
224 |
+
pass
|
225 |
+
|
226 |
+
def embed_corpus(self):
|
227 |
+
# embedded_corpus = {}
|
228 |
+
# for key, content in self.corpus.items():
|
229 |
+
# good_index = [item['index']['embed_index'] for item in content['good']]
|
230 |
+
# encoded_good_index = self.embedder.encode(good_index, convert_to_tensor=True).to(device)
|
231 |
+
# bad_index = [item['index']['embed_index'] for item in content['bad']]
|
232 |
+
# encoded_bad_index = self.embedder.encode(bad_index, convert_to_tensor=True).to(device)
|
233 |
+
# embedded_corpus[key] = {"good": encoded_good_index, "bad": encoded_bad_index}
|
234 |
+
# return embedded_corpus
|
235 |
+
pass
|
236 |
+
|
237 |
+
def get_similarity_scores(self, task: TaskType, embed_index="", str_index="", case_type="", top_k=2):
|
238 |
+
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
239 |
+
# # Embedding similarity match
|
240 |
+
# encoded_embed_query = self.embedder.encode(embed_index, convert_to_tensor=True).to(device)
|
241 |
+
# embedding_similarity_matrix = self.embedder.similarity(encoded_embed_query, self.embedded_corpus[task][case_type])
|
242 |
+
# embedding_similarity_scores = embedding_similarity_matrix[0].to(device)
|
243 |
+
|
244 |
+
# # String similarity match
|
245 |
+
# str_match_corpus = [item['index']['str_index'] for item in self.corpus[task][case_type]]
|
246 |
+
# str_similarity_results = process.extract(str_index, str_match_corpus, limit=len(str_match_corpus))
|
247 |
+
# scores_dict = {match[0]: match[1] for match in str_similarity_results}
|
248 |
+
# scores_in_order = [scores_dict[candidate] for candidate in str_match_corpus]
|
249 |
+
# str_similarity_scores = torch.tensor(scores_in_order, dtype=torch.float32).to(device)
|
250 |
+
|
251 |
+
# # Normalize scores
|
252 |
+
# embedding_score_range = embedding_similarity_scores.max() - embedding_similarity_scores.min()
|
253 |
+
# str_score_range = str_similarity_scores.max() - str_similarity_scores.min()
|
254 |
+
# if embedding_score_range > 0:
|
255 |
+
# embed_norm_scores = (embedding_similarity_scores - embedding_similarity_scores.min()) / embedding_score_range
|
256 |
+
# else:
|
257 |
+
# embed_norm_scores = embedding_similarity_scores
|
258 |
+
# if str_score_range > 0:
|
259 |
+
# str_norm_scores = (str_similarity_scores - str_similarity_scores.min()) / str_score_range
|
260 |
+
# else:
|
261 |
+
# str_norm_scores = str_similarity_scores / 100
|
262 |
+
|
263 |
+
# # Combine the scores with weights
|
264 |
+
# combined_scores = 0.5 * embed_norm_scores + 0.5 * str_norm_scores
|
265 |
+
# original_combined_scores = 0.5 * embedding_similarity_scores + 0.5 * str_similarity_scores / 100
|
266 |
+
|
267 |
+
# scores, indices = torch.topk(combined_scores, k=min(top_k, combined_scores.size(0)))
|
268 |
+
# original_scores, original_indices = torch.topk(original_combined_scores, k=min(top_k, original_combined_scores.size(0)))
|
269 |
+
# return scores, indices, original_scores, original_indices
|
270 |
+
pass
|
271 |
+
|
272 |
+
def query_case(self, task: TaskType, embed_index="", str_index="", case_type="", top_k=2) -> list:
|
273 |
+
# _, indices, _, _ = self.get_similarity_scores(task, embed_index, str_index, case_type, top_k)
|
274 |
+
# top_matches = [self.corpus[task][case_type][idx]["content"] for idx in indices]
|
275 |
+
# return top_matches
|
276 |
+
pass
|
277 |
+
|
278 |
+
def update_case(self, task: TaskType, embed_index="", str_index="", content="" ,case_type=""):
|
279 |
+
# self.corpus[task][case_type].append({"index": {"embed_index": embed_index, "str_index": str_index}, "content": content})
|
280 |
+
# self.embedded_corpus[task][case_type] = torch.cat([self.embedded_corpus[task][case_type], self.embedder.encode([embed_index], convert_to_tensor=True).to(device)], dim=0)
|
281 |
+
# print(f"Case updated for {task} task.")
|
282 |
+
pass
|
283 |
+
|
284 |
+
class CaseRepositoryHandler:
|
285 |
+
def __init__(self, llm: BaseEngine):
|
286 |
+
self.repository = CaseRepository()
|
287 |
+
self.llm = llm
|
288 |
+
|
289 |
+
def __get_good_case_analysis(self, instruction="", text="", result="", additional_info=""):
|
290 |
+
# prompt = good_case_analysis_instruction.format(
|
291 |
+
# instruction=instruction, text=text, result=result, additional_info=additional_info
|
292 |
+
# )
|
293 |
+
# for _ in range(3):
|
294 |
+
# response = self.llm.get_chat_response(prompt)
|
295 |
+
# response = extract_json_dict(response)
|
296 |
+
# if not isinstance(response, dict):
|
297 |
+
# return response
|
298 |
+
# return None
|
299 |
+
pass
|
300 |
+
|
301 |
+
def __get_bad_case_reflection(self, instruction="", text="", original_answer="", correct_answer="", additional_info=""):
|
302 |
+
# prompt = bad_case_reflection_instruction.format(
|
303 |
+
# instruction=instruction, text=text, original_answer=original_answer, correct_answer=correct_answer, additional_info=additional_info
|
304 |
+
# )
|
305 |
+
# for _ in range(3):
|
306 |
+
# response = self.llm.get_chat_response(prompt)
|
307 |
+
# response = extract_json_dict(response)
|
308 |
+
# if not isinstance(response, dict):
|
309 |
+
# return response
|
310 |
+
# return None
|
311 |
+
pass
|
312 |
+
|
313 |
+
def __get_index(self, data: DataPoint, case_type: str):
|
314 |
+
# set embed_index
|
315 |
+
# embed_index = f"**Text**: {data.distilled_text}\n{data.chunk_text_list[0]}"
|
316 |
+
|
317 |
+
# # set str_index
|
318 |
+
# if data.task == "Base":
|
319 |
+
# str_index = f"**Task**: {data.instruction}"
|
320 |
+
# else:
|
321 |
+
# str_index = f"{data.constraint}"
|
322 |
+
|
323 |
+
# if case_type == "bad":
|
324 |
+
# str_index += f"\n\n**Original Result**: {json.dumps(data.pred)}"
|
325 |
+
|
326 |
+
# return embed_index, str_index
|
327 |
+
pass
|
328 |
+
|
329 |
+
def query_good_case(self, data: DataPoint):
|
330 |
+
# embed_index, str_index = self.__get_index(data, "good")
|
331 |
+
# return self.repository.query_case(task=data.task, embed_index=embed_index, str_index=str_index, case_type="good")
|
332 |
+
pass
|
333 |
+
|
334 |
+
def query_bad_case(self, data: DataPoint):
|
335 |
+
# embed_index, str_index = self.__get_index(data, "bad")
|
336 |
+
# return self.repository.query_case(task=data.task, embed_index=embed_index, str_index=str_index, case_type="bad")
|
337 |
+
pass
|
338 |
+
|
339 |
+
def update_good_case(self, data: DataPoint):
|
340 |
+
# if data.truth == "" :
|
341 |
+
# print("No truth value provided.")
|
342 |
+
# return
|
343 |
+
# embed_index, str_index = self.__get_index(data, "good")
|
344 |
+
# _, _, original_scores, _ = self.repository.get_similarity_scores(data.task, embed_index, str_index, "good", 1)
|
345 |
+
# original_scores = original_scores.tolist()
|
346 |
+
# if original_scores[0] >= 0.9:
|
347 |
+
# print("The similar good case is already in the corpus. Similarity Score: ", original_scores[0])
|
348 |
+
# return
|
349 |
+
# good_case_alaysis = self.__get_good_case_analysis(instruction=data.instruction, text=data.distilled_text, result=data.truth, additional_info=data.constraint)
|
350 |
+
# wrapped_good_case_analysis = f"**Analysis**: {good_case_alaysis}"
|
351 |
+
# wrapped_instruction = f"**Task**: {data.instruction}"
|
352 |
+
# wrapped_text = f"**Text**: {data.distilled_text}\n{data.chunk_text_list[0]}"
|
353 |
+
# wrapped_answer = f"**Correct Answer**: {json.dumps(data.truth)}"
|
354 |
+
# if data.task == "Base":
|
355 |
+
# content = f"{wrapped_instruction}\n\n{wrapped_text}\n\n{wrapped_good_case_analysis}\n\n{wrapped_answer}"
|
356 |
+
# else:
|
357 |
+
# content = f"{wrapped_text}\n\n{data.constraint}\n\n{wrapped_good_case_analysis}\n\n{wrapped_answer}"
|
358 |
+
# self.repository.update_case(data.task, embed_index, str_index, content, "good")
|
359 |
+
pass
|
360 |
+
|
361 |
+
def update_bad_case(self, data: DataPoint):
|
362 |
+
# if data.truth == "" :
|
363 |
+
# print("No truth value provided.")
|
364 |
+
# return
|
365 |
+
# if normalize_obj(data.pred) == normalize_obj(data.truth):
|
366 |
+
# return
|
367 |
+
# embed_index, str_index = self.__get_index(data, "bad")
|
368 |
+
# _, _, original_scores, _ = self.repository.get_similarity_scores(data.task, embed_index, str_index, "bad", 1)
|
369 |
+
# original_scores = original_scores.tolist()
|
370 |
+
# if original_scores[0] >= 0.9:
|
371 |
+
# print("The similar bad case is already in the corpus. Similarity Score: ", original_scores[0])
|
372 |
+
# return
|
373 |
+
# bad_case_reflection = self.__get_bad_case_reflection(instruction=data.instruction, text=data.distilled_text, original_answer=data.pred, correct_answer=data.truth, additional_info=data.constraint)
|
374 |
+
# wrapped_bad_case_reflection = f"**Reflection**: {bad_case_reflection}"
|
375 |
+
# wrapper_original_answer = f"**Original Answer**: {json.dumps(data.pred)}"
|
376 |
+
# wrapper_correct_answer = f"**Correct Answer**: {json.dumps(data.truth)}"
|
377 |
+
# wrapped_instruction = f"**Task**: {data.instruction}"
|
378 |
+
# wrapped_text = f"**Text**: {data.distilled_text}\n{data.chunk_text_list[0]}"
|
379 |
+
# if data.task == "Base":
|
380 |
+
# content = f"{wrapped_instruction}\n\n{wrapped_text}\n\n{wrapper_original_answer}\n\n{wrapped_bad_case_reflection}\n\n{wrapper_correct_answer}"
|
381 |
+
# else:
|
382 |
+
# content = f"{wrapped_text}\n\n{data.constraint}\n\n{wrapper_original_answer}\n\n{wrapped_bad_case_reflection}\n\n{wrapper_correct_answer}"
|
383 |
+
# self.repository.update_case(data.task, embed_index, str_index, content, "bad")
|
384 |
+
pass
|
385 |
+
|
386 |
+
def update_case(self, data: DataPoint):
|
387 |
+
# self.update_good_case(data)
|
388 |
+
# self.update_bad_case(data)
|
389 |
+
# self.repository.update_corpus()
|
390 |
+
pass
|
391 |
+
|
src/modules/knowledge_base/schema_repository.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Optional
|
2 |
+
from pydantic import BaseModel, Field
|
3 |
+
from langchain_core.output_parsers import JsonOutputParser
|
4 |
+
|
5 |
+
# ==================================================================== #
|
6 |
+
# NER TASK #
|
7 |
+
# ==================================================================== #
|
8 |
+
class Entity(BaseModel):
|
9 |
+
name : str = Field(description="The specific name of the entity. ")
|
10 |
+
type : str = Field(description="The type or category that the entity belongs to.")
|
11 |
+
class EntityList(BaseModel):
|
12 |
+
entity_list : List[Entity] = Field(description="Named entities appearing in the text.")
|
13 |
+
|
14 |
+
# ==================================================================== #
|
15 |
+
# RE TASK #
|
16 |
+
# ==================================================================== #
|
17 |
+
class Relation(BaseModel):
|
18 |
+
head : str = Field(description="The starting entity in the relationship.")
|
19 |
+
tail : str = Field(description="The ending entity in the relationship.")
|
20 |
+
relation : str = Field(description="The predicate that defines the relationship between the two entities.")
|
21 |
+
|
22 |
+
class RelationList(BaseModel):
|
23 |
+
relation_list : List[Relation] = Field(description="The collection of relationships between various entities.")
|
24 |
+
|
25 |
+
# ==================================================================== #
|
26 |
+
# EE TASK #
|
27 |
+
# ==================================================================== #
|
28 |
+
class Event(BaseModel):
|
29 |
+
event_type : str = Field(description="The type of the event.")
|
30 |
+
event_trigger : str = Field(description="A specific word or phrase that indicates the occurrence of the event.")
|
31 |
+
event_argument : dict = Field(description="The arguments or participants involved in the event.")
|
32 |
+
|
33 |
+
class EventList(BaseModel):
|
34 |
+
event_list : List[Event] = Field(description="The events presented in the text.")
|
35 |
+
|
36 |
+
# ==================================================================== #
|
37 |
+
# TEXT DESCRIPTION #
|
38 |
+
# ==================================================================== #
|
39 |
+
class TextDescription(BaseModel):
|
40 |
+
field: str = Field(description="The field of the given text, such as 'Science', 'Literature', 'Business', 'Medicine', 'Entertainment', etc.")
|
41 |
+
genre: str = Field(description="The genre of the given text, such as 'Article', 'Novel', 'Dialog', 'Blog', 'Manual','Expository', 'News Report', 'Research Paper', etc.")
|
42 |
+
|
43 |
+
# ==================================================================== #
|
44 |
+
# USER DEFINED SCHEMA #
|
45 |
+
# ==================================================================== #
|
46 |
+
|
47 |
+
# --------------------------- Research Paper ----------------------- #
|
48 |
+
class MetaData(BaseModel):
|
49 |
+
title : str = Field(description="The title of the article")
|
50 |
+
authors : List[str] = Field(description="The list of the article's authors")
|
51 |
+
abstract: str = Field(description="The article's abstract")
|
52 |
+
key_words: List[str] = Field(description="The key words associated with the article")
|
53 |
+
|
54 |
+
class Baseline(BaseModel):
|
55 |
+
method_name : str = Field(description="The name of the baseline method")
|
56 |
+
proposed_solution : str = Field(description="the proposed solution in details")
|
57 |
+
performance_metrics : str = Field(description="The performance metrics of the method and comparative analysis")
|
58 |
+
|
59 |
+
class ExtractionTarget(BaseModel):
|
60 |
+
|
61 |
+
key_contributions: List[str] = Field(description="The key contributions of the article")
|
62 |
+
limitation_of_sota : str=Field(description="the summary limitation of the existing work")
|
63 |
+
proposed_solution : str = Field(description="the proposed solution in details")
|
64 |
+
baselines : List[Baseline] = Field(description="The list of baseline methods and their details")
|
65 |
+
performance_metrics : str = Field(description="The performance metrics of the method and comparative analysis")
|
66 |
+
paper_limitations : str=Field(description="The limitations of the proposed solution of the paper")
|
67 |
+
|
68 |
+
# --------------------------- News ----------------------- #
|
69 |
+
class Person(BaseModel):
|
70 |
+
name: str = Field(description="The name of the person")
|
71 |
+
identity: Optional[str] = Field(description="The occupation, status or characteristics of the person.")
|
72 |
+
role: Optional[str] = Field(description="The role or function the person plays in an event.")
|
73 |
+
|
74 |
+
class Event(BaseModel):
|
75 |
+
name: str = Field(description="Name of the event")
|
76 |
+
time: Optional[str] = Field(description="Time when the event took place")
|
77 |
+
people_involved: Optional[List[Person]] = Field(description="People involved in the event")
|
78 |
+
cause: Optional[str] = Field(default=None, description="Reason for the event, if applicable")
|
79 |
+
process: Optional[str] = Field(description="Details of the event process")
|
80 |
+
result: Optional[str] = Field(default=None, description="Result or outcome of the event")
|
81 |
+
|
82 |
+
class NewsReport(BaseModel):
|
83 |
+
title: str = Field(description="The title or headline of the news report")
|
84 |
+
summary: str = Field(description="A brief summary of the news report")
|
85 |
+
publication_date: Optional[str] = Field(description="The publication date of the report")
|
86 |
+
keywords: Optional[List[str]] = Field(description="List of keywords or topics covered in the news report")
|
87 |
+
events: List[Event] = Field(description="Events covered in the news report")
|
88 |
+
quotes: Optional[List[str]] = Field(default=None, description="Quotes related to the news, if any")
|
89 |
+
viewpoints: Optional[List[str]] = Field(default=None, description="Different viewpoints regarding the news")
|
90 |
+
|
91 |
+
# --------- You can customize new extraction schemas below -------- #
|
src/modules/reflection_agent.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from models import *
|
2 |
+
from utils import *
|
3 |
+
from .extraction_agent import ExtractionAgent
|
4 |
+
from .knowledge_base.case_repository import CaseRepositoryHandler
|
5 |
+
class ReflectionGenerator:
|
6 |
+
def __init__(self, llm: BaseEngine):
|
7 |
+
self.llm = llm
|
8 |
+
|
9 |
+
def get_reflection(self, instruction="", examples="", text="",schema="", result=""):
|
10 |
+
result = json.dumps(result)
|
11 |
+
examples = bad_case_wrapper(examples)
|
12 |
+
prompt = reflect_instruction.format(instruction=instruction, examples=examples, text=text, schema=schema, result=result)
|
13 |
+
response = self.llm.get_chat_response(prompt)
|
14 |
+
response = extract_json_dict(response)
|
15 |
+
return response
|
16 |
+
|
17 |
+
class ReflectionAgent:
|
18 |
+
def __init__(self, llm: BaseEngine, case_repo: CaseRepositoryHandler):
|
19 |
+
self.llm = llm
|
20 |
+
self.module = ReflectionGenerator(llm = llm)
|
21 |
+
self.extractor = ExtractionAgent(llm = llm, case_repo = case_repo)
|
22 |
+
self.case_repo = case_repo
|
23 |
+
self.methods = ["reflect_with_case"]
|
24 |
+
|
25 |
+
def __select_result(self, result_list):
|
26 |
+
dict_objects = [obj for obj in result_list if isinstance(obj, dict)]
|
27 |
+
if dict_objects:
|
28 |
+
selected_obj = max(dict_objects, key=lambda d: len(json.dumps(d)))
|
29 |
+
else:
|
30 |
+
selected_obj = max(result_list, key=lambda o: len(json.dumps(o)))
|
31 |
+
return selected_obj
|
32 |
+
|
33 |
+
def __self_consistance_check(self, data: DataPoint):
|
34 |
+
extract_func = list(data.result_trajectory.keys())[-1]
|
35 |
+
if hasattr(self.extractor, extract_func):
|
36 |
+
result_trails = []
|
37 |
+
result_trails.append(data.result_list)
|
38 |
+
extract_func = getattr(self.extractor, extract_func)
|
39 |
+
temperature = [0.5, 1]
|
40 |
+
for index in range(2):
|
41 |
+
self.module.llm.set_hyperparameter(temperature=temperature[index])
|
42 |
+
data = extract_func(data)
|
43 |
+
result_trails.append(data.result_list)
|
44 |
+
self.module.llm.set_hyperparameter()
|
45 |
+
consistant_result = []
|
46 |
+
reflect_index = []
|
47 |
+
for index, elements in enumerate(zip(*result_trails)):
|
48 |
+
normalized_elements = [normalize_obj(e) for e in elements]
|
49 |
+
element_counts = Counter(normalized_elements)
|
50 |
+
selected_element = next((elements[i] for i, element in enumerate(normalized_elements)
|
51 |
+
if element_counts[element] >= 2), None)
|
52 |
+
if selected_element is None:
|
53 |
+
selected_element = self.__select_result(elements)
|
54 |
+
reflect_index.append(index)
|
55 |
+
consistant_result.append(selected_element)
|
56 |
+
data.set_result_list(consistant_result)
|
57 |
+
return reflect_index
|
58 |
+
|
59 |
+
def reflect_with_case(self, data: DataPoint):
|
60 |
+
if data.result_list == []:
|
61 |
+
return data
|
62 |
+
reflect_index = self.__self_consistance_check(data)
|
63 |
+
reflected_result_list = data.result_list
|
64 |
+
for idx in reflect_index:
|
65 |
+
text = data.chunk_text_list[idx]
|
66 |
+
result = data.result_list[idx]
|
67 |
+
examples = json.dumps(self.case_repo.query_bad_case(data))
|
68 |
+
reflected_res = self.module.get_reflection(instruction=data.instruction, examples=examples, text=text, schema=data.output_schema, result=result)
|
69 |
+
reflected_result_list[idx] = reflected_res
|
70 |
+
data.set_result_list(reflected_result_list)
|
71 |
+
function_name = current_function_name()
|
72 |
+
data.update_trajectory(function_name, data.result_list)
|
73 |
+
return data
|
74 |
+
|
src/modules/schema_agent.py
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from models import *
|
2 |
+
from utils import *
|
3 |
+
from .knowledge_base import schema_repository
|
4 |
+
from langchain_core.output_parsers import JsonOutputParser
|
5 |
+
|
6 |
+
class SchemaAnalyzer:
|
7 |
+
def __init__(self, llm: BaseEngine):
|
8 |
+
self.llm = llm
|
9 |
+
|
10 |
+
def serialize_schema(self, schema) -> str:
|
11 |
+
if isinstance(schema, (str, list, dict, set, tuple)):
|
12 |
+
return schema
|
13 |
+
try:
|
14 |
+
parser = JsonOutputParser(pydantic_object = schema)
|
15 |
+
schema_description = parser.get_format_instructions()
|
16 |
+
schema_content = re.findall(r'```(.*?)```', schema_description, re.DOTALL)
|
17 |
+
explanation = "For example, for the schema {\"properties\": {\"foo\": {\"title\": \"Foo\", \"description\": \"a list of strings\", \"type\": \"array\", \"items\": {\"type\": \"string\"}}}}, the object {\"foo\": [\"bar\", \"baz\"]} is a well-formatted instance."
|
18 |
+
schema = f"{schema_content}\n\n{explanation}"
|
19 |
+
except:
|
20 |
+
return schema
|
21 |
+
return schema
|
22 |
+
|
23 |
+
def redefine_text(self, text_analysis):
|
24 |
+
try:
|
25 |
+
field = text_analysis['field']
|
26 |
+
genre = text_analysis['genre']
|
27 |
+
except:
|
28 |
+
return text_analysis
|
29 |
+
prompt = f"This text is from the field of {field} and represents the genre of {genre}."
|
30 |
+
return prompt
|
31 |
+
|
32 |
+
def get_text_analysis(self, text: str):
|
33 |
+
output_schema = self.serialize_schema(schema_repository.TextDescription)
|
34 |
+
prompt = text_analysis_instruction.format(examples="", text=text, schema=output_schema)
|
35 |
+
response = self.llm.get_chat_response(prompt)
|
36 |
+
response = extract_json_dict(response)
|
37 |
+
response = self.redefine_text(response)
|
38 |
+
return response
|
39 |
+
|
40 |
+
def get_deduced_schema_json(self, instruction: str, text: str, distilled_text: str):
|
41 |
+
prompt = deduced_schema_json_instruction.format(examples=example_wrapper(json_schema_examples), instruction=instruction, distilled_text=distilled_text, text=text)
|
42 |
+
response = self.llm.get_chat_response(prompt)
|
43 |
+
response = extract_json_dict(response)
|
44 |
+
code = response
|
45 |
+
print(f"Deduced Schema in Json: \n{response}\n\n")
|
46 |
+
return code, response
|
47 |
+
|
48 |
+
def get_deduced_schema_code(self, instruction: str, text: str, distilled_text: str):
|
49 |
+
prompt = deduced_schema_code_instruction.format(examples=example_wrapper(code_schema_examples), instruction=instruction, distilled_text=distilled_text, text=text)
|
50 |
+
response = self.llm.get_chat_response(prompt)
|
51 |
+
print(f"schema prompt: {prompt}")
|
52 |
+
print("========================================")
|
53 |
+
print(f"schema response: {response}")
|
54 |
+
code_blocks = re.findall(r'```[^\n]*\n(.*?)\n```', response, re.DOTALL)
|
55 |
+
if code_blocks:
|
56 |
+
try:
|
57 |
+
code_block = code_blocks[-1]
|
58 |
+
namespace = {}
|
59 |
+
exec(code_block, namespace)
|
60 |
+
schema = namespace.get('ExtractionTarget')
|
61 |
+
if schema is not None:
|
62 |
+
index = code_block.find("class")
|
63 |
+
code = code_block[index:]
|
64 |
+
print(f"Deduced Schema in Code: \n{code}\n\n")
|
65 |
+
schema = self.serialize_schema(schema)
|
66 |
+
return code, schema
|
67 |
+
except Exception as e:
|
68 |
+
print(e)
|
69 |
+
return self.get_deduced_schema_json(instruction, text, distilled_text)
|
70 |
+
return self.get_deduced_schema_json(instruction, text, distilled_text)
|
71 |
+
|
72 |
+
class SchemaAgent:
|
73 |
+
def __init__(self, llm: BaseEngine):
|
74 |
+
self.llm = llm
|
75 |
+
self.module = SchemaAnalyzer(llm = llm)
|
76 |
+
self.schema_repo = schema_repository
|
77 |
+
self.methods = ["get_default_schema", "get_retrieved_schema", "get_deduced_schema"]
|
78 |
+
|
79 |
+
def __preprocess_text(self, data: DataPoint):
|
80 |
+
if data.use_file:
|
81 |
+
data.chunk_text_list = chunk_file(data.file_path)
|
82 |
+
else:
|
83 |
+
data.chunk_text_list = chunk_str(data.text)
|
84 |
+
if data.task == "NER":
|
85 |
+
data.print_schema = """
|
86 |
+
class Entity(BaseModel):
|
87 |
+
name : str = Field(description="The specific name of the entity. ")
|
88 |
+
type : str = Field(description="The type or category that the entity belongs to.")
|
89 |
+
class EntityList(BaseModel):
|
90 |
+
entity_list : List[Entity] = Field(description="Named entities appearing in the text.")
|
91 |
+
"""
|
92 |
+
elif data.task == "RE":
|
93 |
+
data.print_schema = """
|
94 |
+
class Relation(BaseModel):
|
95 |
+
head : str = Field(description="The starting entity in the relationship.")
|
96 |
+
tail : str = Field(description="The ending entity in the relationship.")
|
97 |
+
relation : str = Field(description="The predicate that defines the relationship between the two entities.")
|
98 |
+
|
99 |
+
class RelationList(BaseModel):
|
100 |
+
relation_list : List[Relation] = Field(description="The collection of relationships between various entities.")
|
101 |
+
"""
|
102 |
+
elif data.task == "EE":
|
103 |
+
data.print_schema = """
|
104 |
+
class Event(BaseModel):
|
105 |
+
event_type : str = Field(description="The type of the event.")
|
106 |
+
event_trigger : str = Field(description="A specific word or phrase that indicates the occurrence of the event.")
|
107 |
+
event_argument : dict = Field(description="The arguments or participants involved in the event.")
|
108 |
+
|
109 |
+
class EventList(BaseModel):
|
110 |
+
event_list : List[Event] = Field(description="The events presented in the text.")
|
111 |
+
"""
|
112 |
+
return data
|
113 |
+
|
114 |
+
def get_default_schema(self, data: DataPoint):
|
115 |
+
data = self.__preprocess_text(data)
|
116 |
+
default_schema = config['agent']['default_schema']
|
117 |
+
data.set_schema(default_schema)
|
118 |
+
function_name = current_function_name()
|
119 |
+
data.update_trajectory(function_name, default_schema)
|
120 |
+
return data
|
121 |
+
|
122 |
+
def get_retrieved_schema(self, data: DataPoint):
|
123 |
+
self.__preprocess_text(data)
|
124 |
+
schema_name = data.output_schema
|
125 |
+
schema_class = getattr(self.schema_repo, schema_name, None)
|
126 |
+
if schema_class is not None:
|
127 |
+
schema = self.module.serialize_schema(schema_class)
|
128 |
+
default_schema = config['agent']['default_schema']
|
129 |
+
data.set_schema(f"{default_schema}\n{schema}")
|
130 |
+
function_name = current_function_name()
|
131 |
+
data.update_trajectory(function_name, schema)
|
132 |
+
else:
|
133 |
+
return self.get_default_schema(data)
|
134 |
+
return data
|
135 |
+
|
136 |
+
def get_deduced_schema(self, data: DataPoint):
|
137 |
+
self.__preprocess_text(data)
|
138 |
+
target_text = data.chunk_text_list[0]
|
139 |
+
analysed_text = self.module.get_text_analysis(target_text)
|
140 |
+
if len(data.chunk_text_list) > 1:
|
141 |
+
prefix = "Below is a portion of the text to be extracted. "
|
142 |
+
analysed_text = f"{prefix}\n{target_text}"
|
143 |
+
distilled_text = self.module.redefine_text(analysed_text)
|
144 |
+
code, deduced_schema = self.module.get_deduced_schema_code(data.instruction, target_text, distilled_text)
|
145 |
+
data.print_schema = code
|
146 |
+
data.set_distilled_text(distilled_text)
|
147 |
+
default_schema = config['agent']['default_schema']
|
148 |
+
data.set_schema(f"{default_schema}\n{deduced_schema}")
|
149 |
+
function_name = current_function_name()
|
150 |
+
data.update_trajectory(function_name, deduced_schema)
|
151 |
+
return data
|
src/pipeline.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Literal
|
2 |
+
from models import *
|
3 |
+
from utils import *
|
4 |
+
from modules import *
|
5 |
+
|
6 |
+
class Pipeline:
|
7 |
+
def __init__(self, llm: BaseEngine):
|
8 |
+
self.llm = llm
|
9 |
+
self.case_repo = CaseRepositoryHandler(llm = llm)
|
10 |
+
self.schema_agent = SchemaAgent(llm = llm)
|
11 |
+
self.extraction_agent = ExtractionAgent(llm = llm, case_repo = self.case_repo)
|
12 |
+
self.reflection_agent = ReflectionAgent(llm = llm, case_repo = self.case_repo)
|
13 |
+
|
14 |
+
def __init_method(self, data: DataPoint, process_method):
|
15 |
+
default_order = ["schema_agent", "extraction_agent", "reflection_agent"]
|
16 |
+
if "schema_agent" not in process_method:
|
17 |
+
process_method["schema_agent"] = "get_default_schema"
|
18 |
+
if data.task == "Base":
|
19 |
+
process_method["schema_agent"] = "get_deduced_schema"
|
20 |
+
if data.task != "Base":
|
21 |
+
process_method["schema_agent"] = "get_retrieved_schema"
|
22 |
+
if "extraction_agent" not in process_method:
|
23 |
+
process_method["extraction_agent"] = "extract_information_direct"
|
24 |
+
sorted_process_method = {key: process_method[key] for key in default_order if key in process_method}
|
25 |
+
return sorted_process_method
|
26 |
+
|
27 |
+
def __init_data(self, data: DataPoint):
|
28 |
+
if data.task == "NER":
|
29 |
+
data.instruction = config['agent']['default_ner']
|
30 |
+
data.output_schema = "EntityList"
|
31 |
+
elif data.task == "RE":
|
32 |
+
data.instruction = config['agent']['default_re']
|
33 |
+
data.output_schema = "RelationList"
|
34 |
+
elif data.task == "EE":
|
35 |
+
data.instruction = config['agent']['default_ee']
|
36 |
+
data.output_schema = "EventList"
|
37 |
+
return data
|
38 |
+
|
39 |
+
|
40 |
+
|
41 |
+
# main entry
|
42 |
+
def get_extract_result(self,
|
43 |
+
task: TaskType,
|
44 |
+
instruction: str = "",
|
45 |
+
text: str = "",
|
46 |
+
output_schema: str = "",
|
47 |
+
constraint: str = "",
|
48 |
+
use_file: bool = False,
|
49 |
+
file_path: str = "",
|
50 |
+
truth: str = "",
|
51 |
+
mode: str = "quick",
|
52 |
+
update_case: bool = False
|
53 |
+
):
|
54 |
+
print(f" task: {task},\n instruction: {instruction},\n text: {text},\n output_schema: {output_schema},\n constraint: {constraint},\n use_file: {use_file},\n file_path: {file_path},\n truth: {truth},\n mode: {mode},\n update_case: {update_case}")
|
55 |
+
data = DataPoint(task=task, instruction=instruction, text=text, output_schema=output_schema, constraint=constraint, use_file=use_file, file_path=file_path, truth=truth)
|
56 |
+
data = self.__init_data(data)
|
57 |
+
if mode in config['agent']['mode'].keys():
|
58 |
+
process_method = config['agent']['mode'][mode]
|
59 |
+
else:
|
60 |
+
process_method = mode
|
61 |
+
print(f"data=================: {data.task}")
|
62 |
+
print(f"process_method=================: {process_method}")
|
63 |
+
sorted_process_method = self.__init_method(data, process_method)
|
64 |
+
print_schema = False
|
65 |
+
frontend_schema = ""
|
66 |
+
frontend_res = ""
|
67 |
+
# Information Extract
|
68 |
+
print(f"sorted_process_method=================: {sorted_process_method}")
|
69 |
+
for agent_name, method_name in sorted_process_method.items():
|
70 |
+
agent = getattr(self, agent_name, None)
|
71 |
+
if not agent:
|
72 |
+
raise AttributeError(f"{agent_name} does not exist.")
|
73 |
+
method = getattr(agent, method_name, None)
|
74 |
+
if not method:
|
75 |
+
raise AttributeError(f"Method '{method_name}' not found in {agent_name}.")
|
76 |
+
data = method(data)
|
77 |
+
if not print_schema and data.print_schema:
|
78 |
+
print("Schema: \n", data.print_schema)
|
79 |
+
frontend_schema = data.print_schema
|
80 |
+
print_schema = True
|
81 |
+
data = self.extraction_agent.summarize_answer(data)
|
82 |
+
print("Extraction Result: \n", json.dumps(data.pred, indent=2))
|
83 |
+
frontend_res = data.pred
|
84 |
+
# Case Update
|
85 |
+
if update_case:
|
86 |
+
if (data.truth == ""):
|
87 |
+
truth = input("Please enter the correct answer you prefer, or press Enter to accept the current answer: ")
|
88 |
+
if truth.strip() == "":
|
89 |
+
data.truth = data.pred
|
90 |
+
else:
|
91 |
+
data.truth = extract_json_dict(truth)
|
92 |
+
self.case_repo.update_case(data)
|
93 |
+
|
94 |
+
# return result
|
95 |
+
result = data.pred
|
96 |
+
trajectory = data.get_result_trajectory()
|
97 |
+
|
98 |
+
return result, trajectory, frontend_schema, frontend_res
|
src/run.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import yaml
|
4 |
+
from pipeline import Pipeline
|
5 |
+
from typing import Literal
|
6 |
+
import models
|
7 |
+
from models import *
|
8 |
+
from utils import *
|
9 |
+
from modules import *
|
10 |
+
|
11 |
+
def load_extraction_config(yaml_path):
|
12 |
+
# 从文件路径读取 YAML 内容
|
13 |
+
if not os.path.exists(yaml_path):
|
14 |
+
print(f"Error: The config file '{yaml_path}' does not exist.")
|
15 |
+
return {}
|
16 |
+
|
17 |
+
with open(yaml_path, 'r') as file:
|
18 |
+
config = yaml.safe_load(file)
|
19 |
+
|
20 |
+
# 提取'extraction'配置的字典
|
21 |
+
model_config = config.get('model', {})
|
22 |
+
extraction_config = config.get('extraction', {})
|
23 |
+
# model config
|
24 |
+
model_name_or_path = model_config.get('model_name_or_path', "")
|
25 |
+
model_category = model_config.get('category', "")
|
26 |
+
api_key = model_config.get('api_key', "")
|
27 |
+
base_url = model_config.get('base_url', "")
|
28 |
+
|
29 |
+
# extraction config
|
30 |
+
task = extraction_config.get('task', "")
|
31 |
+
instruction = extraction_config.get('instruction', "")
|
32 |
+
text = extraction_config.get('text', "")
|
33 |
+
output_schema = extraction_config.get('output_schema', "")
|
34 |
+
constraint = extraction_config.get('constraint', "")
|
35 |
+
truth = extraction_config.get('truth', "")
|
36 |
+
use_file = extraction_config.get('use_file', False)
|
37 |
+
mode = extraction_config.get('mode', "quick")
|
38 |
+
update_case = extraction_config.get('update_case', False)
|
39 |
+
|
40 |
+
# 返回一个包含这些变量的字典
|
41 |
+
return {
|
42 |
+
"model": {
|
43 |
+
"model_name_or_path": model_name_or_path,
|
44 |
+
"category": model_category,
|
45 |
+
"api_key": api_key,
|
46 |
+
"base_url": base_url
|
47 |
+
},
|
48 |
+
"extraction": {
|
49 |
+
"task": task,
|
50 |
+
"instruction": instruction,
|
51 |
+
"text": text,
|
52 |
+
"output_schema": output_schema,
|
53 |
+
"constraint": constraint,
|
54 |
+
"truth": truth,
|
55 |
+
"use_file": use_file,
|
56 |
+
"mode": mode,
|
57 |
+
"update_case": update_case
|
58 |
+
}
|
59 |
+
}
|
60 |
+
|
61 |
+
|
62 |
+
def main():
|
63 |
+
# 创建命令行参数解析器
|
64 |
+
parser = argparse.ArgumentParser(description='Run the extraction model.')
|
65 |
+
parser.add_argument('--config', type=str, required=True,
|
66 |
+
help='Path to the YAML configuration file.')
|
67 |
+
|
68 |
+
# 解析命令行参数
|
69 |
+
args = parser.parse_args()
|
70 |
+
|
71 |
+
# 加载配置
|
72 |
+
config = load_extraction_config(args.config)
|
73 |
+
model_config = config['model']
|
74 |
+
extraction_config = config['extraction']
|
75 |
+
clazz = getattr(models, model_config['category'], None)
|
76 |
+
if clazz is None:
|
77 |
+
print(f"Error: The model category '{model_config['category']}' is not supported.")
|
78 |
+
return
|
79 |
+
if model_config['api_key'] == "":
|
80 |
+
model = clazz(model_config['model_name_or_path'])
|
81 |
+
else:
|
82 |
+
model = clazz(model_config['model_name_or_path'], model_config['api_key'], model_config['base_url'])
|
83 |
+
pipeline = Pipeline(model)
|
84 |
+
result, trajectory, *_ = pipeline.get_extract_result(task=extraction_config['task'], instruction=extraction_config['instruction'], text=extraction_config['text'], output_schema=extraction_config['output_schema'], constraint=extraction_config['constraint'], use_file=extraction_config['use_file'], truth=extraction_config['truth'], mode=extraction_config['mode'], update_case=extraction_config['update_case'])
|
85 |
+
return
|
86 |
+
|
87 |
+
if __name__ == "__main__":
|
88 |
+
main()
|
src/utils/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .process import *
|
2 |
+
from .data_def import DataPoint, TaskType
|
3 |
+
|
src/utils/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (274 Bytes). View file
|
|
src/utils/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (239 Bytes). View file
|
|
src/utils/__pycache__/data_def.cpython-311.pyc
ADDED
Binary file (3.07 kB). View file
|
|
src/utils/__pycache__/data_def.cpython-39.pyc
ADDED
Binary file (2.3 kB). View file
|
|
src/utils/__pycache__/process.cpython-311.pyc
ADDED
Binary file (10.7 kB). View file
|
|
src/utils/__pycache__/process.cpython-39.pyc
ADDED
Binary file (5.98 kB). View file
|
|
src/utils/data_def.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Literal
|
2 |
+
from models import *
|
3 |
+
from .process import *
|
4 |
+
# predefined processing logic for routine extraction tasks
|
5 |
+
TaskType = Literal["NER", "RE", "EE", "Base"]
|
6 |
+
ModelType = Literal["gpt-3.5-turbo", "gpt-4o"]
|
7 |
+
|
8 |
+
class DataPoint:
|
9 |
+
def __init__(self,
|
10 |
+
task: TaskType = "Base",
|
11 |
+
instruction: str = "",
|
12 |
+
text: str = "",
|
13 |
+
output_schema: str = "",
|
14 |
+
constraint: str = "",
|
15 |
+
use_file: bool = False,
|
16 |
+
file_path: str = "",
|
17 |
+
truth: str = ""):
|
18 |
+
"""
|
19 |
+
Initialize a DataPoint instance.
|
20 |
+
"""
|
21 |
+
# task information
|
22 |
+
self.task = task
|
23 |
+
self.instruction = instruction
|
24 |
+
self.text = text
|
25 |
+
self.output_schema = output_schema
|
26 |
+
self.constraint = constraint
|
27 |
+
self.use_file = use_file
|
28 |
+
self.file_path = file_path
|
29 |
+
self.truth = extract_json_dict(truth)
|
30 |
+
# temp storage
|
31 |
+
self.print_schema = ""
|
32 |
+
self.distilled_text = ""
|
33 |
+
self.chunk_text_list = []
|
34 |
+
# result feedback
|
35 |
+
self.result_list = []
|
36 |
+
self.result_trajectory = {}
|
37 |
+
self.pred = ""
|
38 |
+
|
39 |
+
def set_constraint(self, constraint):
|
40 |
+
self.constraint = constraint
|
41 |
+
|
42 |
+
def set_schema(self, output_schema):
|
43 |
+
self.output_schema = output_schema
|
44 |
+
|
45 |
+
def set_pred(self, pred):
|
46 |
+
self.pred = pred
|
47 |
+
|
48 |
+
def set_result_list(self, result_list):
|
49 |
+
self.result_list = result_list
|
50 |
+
|
51 |
+
def set_distilled_text(self, distilled_text):
|
52 |
+
self.distilled_text = distilled_text
|
53 |
+
|
54 |
+
def update_trajectory(self, function, result):
|
55 |
+
if function not in self.result_trajectory:
|
56 |
+
self.result_trajectory.update({function: result})
|
57 |
+
|
58 |
+
def get_result_trajectory(self):
|
59 |
+
return {"instruction": self.instruction, "text": self.text, "constraint": self.constraint, "trajectory": self.result_trajectory, "pred": self.pred}
|
src/utils/process.py
ADDED
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Data Processing Functions.
|
3 |
+
Supports:
|
4 |
+
- Segmentation of long text
|
5 |
+
- Segmentation of file content
|
6 |
+
"""
|
7 |
+
from langchain_community.document_loaders import TextLoader, PyPDFLoader, Docx2txtLoader, BSHTMLLoader, JSONLoader
|
8 |
+
from nltk.tokenize import sent_tokenize
|
9 |
+
from collections import Counter
|
10 |
+
import re
|
11 |
+
import json
|
12 |
+
import yaml
|
13 |
+
import os
|
14 |
+
import yaml
|
15 |
+
import os
|
16 |
+
import inspect
|
17 |
+
import ast
|
18 |
+
with open(os.path.join(os.path.dirname(__file__), "..", "config.yaml")) as file:
|
19 |
+
config = yaml.safe_load(file)
|
20 |
+
|
21 |
+
# Split the string text into chunks
|
22 |
+
def chunk_str(text):
|
23 |
+
sentences = sent_tokenize(text)
|
24 |
+
chunks = []
|
25 |
+
current_chunk = []
|
26 |
+
current_length = 0
|
27 |
+
|
28 |
+
for sentence in sentences:
|
29 |
+
token_count = len(sentence.split())
|
30 |
+
if current_length + token_count <= config['agent']['chunk_token_limit']:
|
31 |
+
current_chunk.append(sentence)
|
32 |
+
current_length += token_count
|
33 |
+
else:
|
34 |
+
if current_chunk:
|
35 |
+
chunks.append(' '.join(current_chunk))
|
36 |
+
current_chunk = [sentence]
|
37 |
+
current_length = token_count
|
38 |
+
if current_chunk:
|
39 |
+
chunks.append(' '.join(current_chunk))
|
40 |
+
return chunks
|
41 |
+
|
42 |
+
# Load and split the content of a file
|
43 |
+
def chunk_file(file_path):
|
44 |
+
pages = []
|
45 |
+
|
46 |
+
if file_path.endswith(".pdf"):
|
47 |
+
loader = PyPDFLoader(file_path)
|
48 |
+
elif file_path.endswith(".txt"):
|
49 |
+
loader = TextLoader(file_path)
|
50 |
+
elif file_path.endswith(".docx"):
|
51 |
+
loader = Docx2txtLoader(file_path)
|
52 |
+
elif file_path.endswith(".html"):
|
53 |
+
loader = BSHTMLLoader(file_path)
|
54 |
+
elif file_path.endswith(".json"):
|
55 |
+
loader = JSONLoader(file_path)
|
56 |
+
else:
|
57 |
+
raise ValueError("Unsupported file format") # Inform that the format is unsupported
|
58 |
+
|
59 |
+
pages = loader.load_and_split()
|
60 |
+
docs = ""
|
61 |
+
for item in pages:
|
62 |
+
docs += item.page_content
|
63 |
+
pages = chunk_str(docs)
|
64 |
+
|
65 |
+
return pages
|
66 |
+
|
67 |
+
def process_single_quotes(text):
|
68 |
+
result = re.sub(r"(?<!\w)'|'(?!\w)", '"', text)
|
69 |
+
return result
|
70 |
+
|
71 |
+
def remove_empty_values(data):
|
72 |
+
def is_empty(value):
|
73 |
+
return value is None or value == [] or value == "" or value == {}
|
74 |
+
if isinstance(data, dict):
|
75 |
+
return {
|
76 |
+
k: remove_empty_values(v)
|
77 |
+
for k, v in data.items()
|
78 |
+
if not is_empty(v)
|
79 |
+
}
|
80 |
+
elif isinstance(data, list):
|
81 |
+
return [
|
82 |
+
remove_empty_values(item)
|
83 |
+
for item in data
|
84 |
+
if not is_empty(item)
|
85 |
+
]
|
86 |
+
else:
|
87 |
+
return data
|
88 |
+
|
89 |
+
def extract_json_dict(text):
|
90 |
+
if isinstance(text, dict):
|
91 |
+
return text
|
92 |
+
pattern = r'\{(?:[^{}]|(?:\{(?:[^{}]|(?:\{[^{}]*\})*)*\})*)*\}'
|
93 |
+
matches = re.findall(pattern, text)
|
94 |
+
if matches:
|
95 |
+
json_string = matches[-1]
|
96 |
+
json_string = process_single_quotes(json_string)
|
97 |
+
try:
|
98 |
+
json_dict = json.loads(json_string)
|
99 |
+
json_dict = remove_empty_values(json_dict)
|
100 |
+
if json_dict is None:
|
101 |
+
return "No valid information found."
|
102 |
+
return json_dict
|
103 |
+
except json.JSONDecodeError:
|
104 |
+
return json_string
|
105 |
+
else:
|
106 |
+
return text
|
107 |
+
|
108 |
+
def good_case_wrapper(example: str):
|
109 |
+
if example is None or example == "":
|
110 |
+
return ""
|
111 |
+
example = f"\nHere are some examples:\n{example}\n(END OF EXAMPLES)\nRefer to the reasoning steps and analysis in the examples to help complete the extraction task below.\n\n"
|
112 |
+
return example
|
113 |
+
|
114 |
+
def bad_case_wrapper(example: str):
|
115 |
+
if example is None or example == "":
|
116 |
+
return ""
|
117 |
+
example = f"\nHere are some examples of bad cases:\n{example}\n(END OF EXAMPLES)\nRefer to the reflection rules and reflection steps in the examples to help optimize the original result below.\n\n"
|
118 |
+
return example
|
119 |
+
|
120 |
+
def example_wrapper(example: str):
|
121 |
+
if example is None or example == "":
|
122 |
+
return ""
|
123 |
+
example = f"\nHere are some examples:\n{example}\n(END OF EXAMPLES)\n\n"
|
124 |
+
return example
|
125 |
+
|
126 |
+
def remove_redundant_space(s):
|
127 |
+
s = ' '.join(s.split())
|
128 |
+
s = re.sub(r"\s*(,|:|\(|\)|\.|_|;|'|-)\s*", r'\1', s)
|
129 |
+
return s
|
130 |
+
|
131 |
+
def format_string(s):
|
132 |
+
s = remove_redundant_space(s)
|
133 |
+
s = s.lower()
|
134 |
+
s = s.replace('{','').replace('}','')
|
135 |
+
s = re.sub(',+', ',', s)
|
136 |
+
s = re.sub('\.+', '.', s)
|
137 |
+
s = re.sub(';+', ';', s)
|
138 |
+
s = s.replace('’', "'")
|
139 |
+
return s
|
140 |
+
|
141 |
+
def calculate_metrics(y_truth: set, y_pred: set):
|
142 |
+
TP = len(y_truth & y_pred)
|
143 |
+
FN = len(y_truth - y_pred)
|
144 |
+
FP = len(y_pred - y_truth)
|
145 |
+
precision = TP / (TP + FP) if (TP + FP) > 0 else 0
|
146 |
+
recall = TP / (TP + FN) if (TP + FN) > 0 else 0
|
147 |
+
f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
|
148 |
+
return precision, recall, f1_score
|
149 |
+
|
150 |
+
def current_function_name():
|
151 |
+
try:
|
152 |
+
stack = inspect.stack()
|
153 |
+
if len(stack) > 1:
|
154 |
+
outer_func_name = stack[1].function
|
155 |
+
return outer_func_name
|
156 |
+
else:
|
157 |
+
print("No caller function found")
|
158 |
+
return None
|
159 |
+
|
160 |
+
except Exception as e:
|
161 |
+
print(f"An error occurred: {e}")
|
162 |
+
pass
|
163 |
+
|
164 |
+
def normalize_obj(value):
|
165 |
+
if isinstance(value, dict):
|
166 |
+
return frozenset((k, normalize_obj(v)) for k, v in value.items())
|
167 |
+
elif isinstance(value, (list, set, tuple)):
|
168 |
+
# 将 Counter 转换为元组以便于被哈希
|
169 |
+
return tuple(Counter(map(normalize_obj, value)).items())
|
170 |
+
elif isinstance(value, str):
|
171 |
+
return format_string(value)
|
172 |
+
return value
|
173 |
+
|
174 |
+
def dict_list_to_set(data_list):
|
175 |
+
result_set = set()
|
176 |
+
try:
|
177 |
+
for dictionary in data_list:
|
178 |
+
value_tuple = tuple(format_string(value) for value in dictionary.values())
|
179 |
+
result_set.add(value_tuple)
|
180 |
+
return result_set
|
181 |
+
except Exception as e:
|
182 |
+
print (f"Failed to convert dictionary list to set: {data_list}")
|
183 |
+
return result_set
|
src/webui/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .interface import InterFace
|