ShawnRu committed on
Commit 009d93e · 1 Parent(s): 36208ce
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. src/__pycache__/pipeline.cpython-311.pyc +0 -0
  2. src/__pycache__/pipeline.cpython-39.pyc +0 -0
  3. src/config.yaml +19 -0
  4. src/generate_memory.py +181 -0
  5. src/main.py +233 -0
  6. src/models/__init__.py +3 -0
  7. src/models/__pycache__/__init__.cpython-311.pyc +0 -0
  8. src/models/__pycache__/__init__.cpython-37.pyc +0 -0
  9. src/models/__pycache__/__init__.cpython-39.pyc +0 -0
  10. src/models/__pycache__/llm_def.cpython-311.pyc +0 -0
  11. src/models/__pycache__/llm_def.cpython-37.pyc +0 -0
  12. src/models/__pycache__/llm_def.cpython-39.pyc +0 -0
  13. src/models/__pycache__/prompt_example.cpython-311.pyc +0 -0
  14. src/models/__pycache__/prompt_example.cpython-39.pyc +0 -0
  15. src/models/__pycache__/prompt_template.cpython-311.pyc +0 -0
  16. src/models/__pycache__/prompt_template.cpython-39.pyc +0 -0
  17. src/models/llm_def.py +212 -0
  18. src/models/prompt_example.py +137 -0
  19. src/models/prompt_template.py +174 -0
  20. src/modules/__init__.py +4 -0
  21. src/modules/__pycache__/__init__.cpython-311.pyc +0 -0
  22. src/modules/__pycache__/__init__.cpython-39.pyc +0 -0
  23. src/modules/__pycache__/extraction_agent.cpython-311.pyc +0 -0
  24. src/modules/__pycache__/extraction_agent.cpython-39.pyc +0 -0
  25. src/modules/__pycache__/reflection_agent.cpython-311.pyc +0 -0
  26. src/modules/__pycache__/reflection_agent.cpython-39.pyc +0 -0
  27. src/modules/__pycache__/schema_agent.cpython-311.pyc +0 -0
  28. src/modules/__pycache__/schema_agent.cpython-39.pyc +0 -0
  29. src/modules/extraction_agent.py +85 -0
  30. src/modules/knowledge_base/__pycache__/case_repository.cpython-311.pyc +0 -0
  31. src/modules/knowledge_base/__pycache__/case_repository.cpython-39.pyc +0 -0
  32. src/modules/knowledge_base/__pycache__/schema_repository.cpython-311.pyc +0 -0
  33. src/modules/knowledge_base/__pycache__/schema_repository.cpython-39.pyc +0 -0
  34. src/modules/knowledge_base/case_repository.json +0 -0
  35. src/modules/knowledge_base/case_repository.py +391 -0
  36. src/modules/knowledge_base/schema_repository.py +91 -0
  37. src/modules/reflection_agent.py +74 -0
  38. src/modules/schema_agent.py +151 -0
  39. src/pipeline.py +98 -0
  40. src/run.py +88 -0
  41. src/utils/__init__.py +3 -0
  42. src/utils/__pycache__/__init__.cpython-311.pyc +0 -0
  43. src/utils/__pycache__/__init__.cpython-39.pyc +0 -0
  44. src/utils/__pycache__/data_def.cpython-311.pyc +0 -0
  45. src/utils/__pycache__/data_def.cpython-39.pyc +0 -0
  46. src/utils/__pycache__/process.cpython-311.pyc +0 -0
  47. src/utils/__pycache__/process.cpython-39.pyc +0 -0
  48. src/utils/data_def.py +59 -0
  49. src/utils/process.py +183 -0
  50. src/webui/__init__.py +1 -0
src/__pycache__/pipeline.cpython-311.pyc ADDED
Binary file (5.34 kB)
src/__pycache__/pipeline.cpython-39.pyc ADDED
Binary file (3.56 kB)
src/config.yaml ADDED
@@ -0,0 +1,19 @@
+agent:
+  default_schema: The final extraction result should be formatted as a JSON object.
+  default_ner: Extract the Named Entities in the given text.
+  default_re: Extract Relationships between Named Entities in the given text.
+  default_ee: Extract the Events in the given text.
+  chunk_token_limit: 1024
+  mode:
+    quick:
+      schema_agent: get_deduced_schema
+      extraction_agent: extract_information_direct
+    standard:
+      schema_agent: get_deduced_schema
+      extraction_agent: extract_information_with_case
+      reflection_agent: reflect_with_case
+    customized:
+      schema_agent: get_retrieved_schema
+      extraction_agent: extract_information_direct
+
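For reference, a minimal sketch of how this configuration might be consumed (assuming PyYAML, and that `mode` nests under `agent`, which matches the `config['agent'][...]` lookups elsewhere in this commit; the snippet is illustrative, not part of the commit):

```python
import yaml  # PyYAML is an assumed dependency here

with open("src/config.yaml") as f:
    config = yaml.safe_load(f)

print(config["agent"]["default_ner"])           # default NER instruction
standard = config["agent"]["mode"]["standard"]  # agent methods used in "standard" mode
print(standard["extraction_agent"])             # -> "extract_information_with_case"
```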
src/generate_memory.py ADDED
@@ -0,0 +1,181 @@
+from typing import Literal
+from models import *
+from utils import *
+from modules import *
+
+
+class Pipeline:
+    def __init__(self, llm: BaseEngine):
+        self.llm = llm
+        self.case_repo = CaseRepositoryHandler(llm=llm)
+        self.schema_agent = SchemaAgent(llm=llm)
+        self.extraction_agent = ExtractionAgent(llm=llm, case_repo=self.case_repo)
+        self.reflection_agent = ReflectionAgent(llm=llm, case_repo=self.case_repo)
+
+    def __init_method(self, data: DataPoint, process_method):
+        default_order = ["schema_agent", "extraction_agent", "reflection_agent"]
+        if "schema_agent" not in process_method:
+            process_method["schema_agent"] = "get_default_schema"
+        if data.task != "Base":
+            process_method["schema_agent"] = "get_retrieved_schema"
+        if "extraction_agent" not in process_method:
+            process_method["extraction_agent"] = "extract_information_direct"
+        sorted_process_method = {key: process_method[key] for key in default_order if key in process_method}
+        return sorted_process_method
+
+    def __init_data(self, data: DataPoint):
+        if data.task == "NER":
+            data.instruction = config['agent']['default_ner']
+            data.output_schema = "EntityList"
+        elif data.task == "RE":
+            data.instruction = config['agent']['default_re']
+            data.output_schema = "RelationList"
+        elif data.task == "EE":
+            data.instruction = config['agent']['default_ee']
+            data.output_schema = "EventList"
+        return data
+
+    # main entry
+    def get_extract_result(self,
+                           task: TaskType,
+                           instruction: str = "",
+                           text: str = "",
+                           output_schema: str = "",
+                           constraint: str = "",
+                           use_file: bool = False,
+                           truth: str = "",
+                           mode: str = "quick",
+                           update_case: bool = False
+                           ):
+
+        data = DataPoint(task=task, instruction=instruction, text=text, output_schema=output_schema, constraint=constraint, use_file=use_file, truth=truth)
+        data = self.__init_data(data)
+        data.instruction = "In the tranquil seaside town, the summer evening cast a golden glow over everything. The townsfolk gathered at the café by the pier, enjoying the sea breeze while eagerly anticipating the annual Ocean Festival's opening ceremony. \nFirst to arrive was Mayor William, dressed in a deep blue suit, holding a roll of his speech. He smiled and greeted the residents, who held deep respect for their community-minded mayor. Beside him trotted Max, his loyal golden retriever, wagging his tail excitedly at every familiar face he saw. \nFollowing closely was Emily, the town’s high school teacher, accompanied by a group of students ready to perform a musical piece they'd rehearsed. One of the girls carried Polly, a vibrant green parrot, on her shoulder. Polly occasionally chimed in with cheerful squawks, adding to the lively atmosphere. \nNot far away, Captain Jack, with his trusty pipe in hand, chatted with old friends about this year's catch. His fleet was the town’s economic backbone, and his seasoned face and towering presence were complemented by the presence of Whiskers, his orange tabby cat, who loved lounging on the dock, attentively watching the gentle waves. \nInside the café, Kate was bustling about, serving guests. As the owner, with her fiery red curls and vivacious spirit, she was the heart of the place. Her friend Susan, an artist living in a tiny cottage nearby, was helping her prepare refreshing beverages. Slinky, Susan's mischievous ferret, darted playfully between the tables, much to the delight of the children present. \nLeaning on the café's railing, a young boy named Tommy watched the sea with wide, gleaming eyes, filled with dreams of the future. By his side sat Daisy, a spirited little dachshund, barking excitedly at the seagulls flying overhead. Tommy's mother, Lucy, stood beside him, smiling softly as she held a seashell he had just found on the beach. \nAmong the crowd, a group of unnamed tourists snapped photos, capturing memories of the charming festival. Street vendors called out, selling their wares—handmade jewelry and sweet confections—as the scent of grilled seafood wafted through the air. \nSuddenly, a burst of laughter erupted—it was James and his band making their grand entrance. Accompanying them was Benny, a friendly border collie who \"performed\" with the band, delighting the crowd with his antics. Set to play a big concert after the opening ceremony, James, the town's star musician, had won the hearts of locals with his soulful tunes. \nAs dusk settled, lights were strung across the streets, casting a magical glow over the town. Mayor William took the stage to deliver his speech, with Max sitting proudly by his side. The festival atmosphere reached its vibrant peak, and in this small town, each person—and animal—carried their own dreams and stories, yet at this moment, they were united by the shared celebration."
+        data.chunk_text_list.append("In the tranquil seaside town, the summer evening cast a golden glow over everything. The townsfolk gathered at the café by the pier, enjoying the sea breeze while eagerly anticipating the annual Ocean Festival's opening ceremony. \nFirst to arrive was Mayor William, dressed in a deep blue suit, holding a roll of his speech. He smiled and greeted the residents, who held deep respect for their community-minded mayor. Beside him trotted Max, his loyal golden retriever, wagging his tail excitedly at every familiar face he saw. \nFollowing closely was Emily, the town’s high school teacher, accompanied by a group of students ready to perform a musical piece they'd rehearsed. One of the girls carried Polly, a vibrant green parrot, on her shoulder. Polly occasionally chimed in with cheerful squawks, adding to the lively atmosphere. \nNot far away, Captain Jack, with his trusty pipe in hand, chatted with old friends about this year's catch. His fleet was the town’s economic backbone, and his seasoned face and towering presence were complemented by the presence of Whiskers, his orange tabby cat, who loved lounging on the dock, attentively watching the gentle waves. \nInside the café, Kate was bustling about, serving guests. As the owner, with her fiery red curls and vivacious spirit, she was the heart of the place. Her friend Susan, an artist living in a tiny cottage nearby, was helping her prepare refreshing beverages. Slinky, Susan's mischievous ferret, darted playfully between the tables, much to the delight of the children present. \nLeaning on the café's railing, a young boy named Tommy watched the sea with wide, gleaming eyes, filled with dreams of the future. By his side sat Daisy, a spirited little dachshund, barking excitedly at the seagulls flying overhead. Tommy's mother, Lucy, stood beside him, smiling softly as she held a seashell he had just found on the beach. \nAmong the crowd, a group of unnamed tourists snapped photos, capturing memories of the charming festival. Street vendors called out, selling their wares—handmade jewelry and sweet confections—as the scent of grilled seafood wafted through the air. \nSuddenly, a burst of laughter erupted—it was James and his band making their grand entrance. Accompanying them was Benny, a friendly border collie who \"performed\" with the band, delighting the crowd with his antics. Set to play a big concert after the opening ceremony, James, the town's star musician, had won the hearts of locals with his soulful tunes. \nAs dusk settled, lights were strung across the streets, casting a magical glow over the town. Mayor William took the stage to deliver his speech, with Max sitting proudly by his side. The festival atmosphere reached its vibrant peak, and in this small town, each person—and animal—carried their own dreams and stories, yet at this moment, they were united by the shared celebration.")
+        data.distilled_text = "This text is from the field of Slice of Life and represents the genre of Novel."
+        data.pred = {
+            "characters": [
+                {
+                    "name": "Mayor William",
+                    "role": "Mayor"
+                },
+                {
+                    "name": "Max",
+                    "role": "Golden Retriever, Mayor William's dog"
+                },
+                {
+                    "name": "Emily",
+                    "role": "High school teacher"
+                },
+                {
+                    "name": "Polly",
+                    "role": "Parrot, accompanying a student"
+                },
+                {
+                    "name": "Captain Jack",
+                    "role": "Captain"
+                },
+                {
+                    "name": "Whiskers",
+                    "role": "Orange tabby cat, Captain Jack's pet"
+                },
+                {
+                    "name": "Kate",
+                    "role": "Café owner"
+                },
+                {
+                    "name": "Susan",
+                    "role": "Artist, Kate's friend"
+                },
+                {
+                    "name": "Slinky",
+                    "role": "Ferret, Susan's pet"
+                },
+                {
+                    "name": "Tommy",
+                    "role": "Young boy"
+                },
+                {
+                    "name": "Daisy",
+                    "role": "Dachshund, Tommy's pet"
+                },
+                {
+                    "name": "Lucy",
+                    "role": "Tommy's mother"
+                },
+                {
+                    "name": "James",
+                    "role": "Musician, band leader"
+                },
+                {
+                    "name": "Benny",
+                    "role": "Border Collie, accompanying James and his band"
+                },
+                {
+                    "name": "Unnamed Tourists",
+                    "role": "Visitors at the festival"
+                },
+                {
+                    "name": "Street Vendors",
+                    "role": "Sellers at the festival"
+                }
+            ]
+        }
+
+        data.truth = {
+            "characters": [
+                {
+                    "name": "Mayor William",
+                    "role": "The friendly and respected mayor of the seaside town."
+                },
+                {
+                    "name": "Emily",
+                    "role": "A high school teacher guiding students in a festival performance."
+                },
+                {
+                    "name": "Captain Jack",
+                    "role": "A seasoned sailor whose fleet supports the town."
+                },
+                {
+                    "name": "Kate",
+                    "role": "The welcoming owner of the local café."
+                },
+                {
+                    "name": "Susan",
+                    "role": "An artist known for her ocean-themed paintings."
+                },
+                {
+                    "name": "Tommy",
+                    "role": "A young boy with dreams of the sea."
+                },
+                {
+                    "name": "Lucy",
+                    "role": "Tommy's caring and supportive mother."
+                },
+                {
+                    "name": "James",
+                    "role": "A charismatic musician and band leader."
+                }
+            ]
+        }
+
+        # Case Update
+        if update_case:
+            if data.truth == "":
+                truth = input("Please enter the correct answer you prefer, or press Enter to accept the current answer: ")
+                if truth.strip() == "":
+                    data.truth = data.pred
+                else:
+                    data.truth = extract_json_dict(truth)
+            self.case_repo.update_case(data)
+
+        # return result
+        result = data.pred
+        trajectory = data.get_result_trajectory()
+
+        return result, trajectory, "a", "b"
+
+
+model = DeepSeek(model_name_or_path="deepseek-chat", api_key="")
+pipeline = Pipeline(model)
+result, trajectory, *_ = pipeline.get_extract_result(update_case=True, task="Base")
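Note that `api_key=""` makes the `DeepSeek` engine fall back to the environment (see `src/models/llm_def.py`), so the key can be supplied without editing this script; a small sketch of that setup step:

```python
import os

# Hypothetical setup: the engine reads DEEPSEEK_API_KEY when api_key == "".
os.environ["DEEPSEEK_API_KEY"] = "sk-..."  # placeholder value, not a real key
```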
src/main.py ADDED
@@ -0,0 +1,233 @@
+import random
+import json
+import gradio as gr
+
+from pipeline import Pipeline
+from models import *
+
+
+examples = [
+    {
+        "task": "NER",
+        "use_file": False,
+        "text": "Finally, every other year , ELRA organizes a major conference LREC , the International Language Resources and Evaluation Conference .",
+        "instruction": "",
+        "constraint": """["nationality", "country capital", "place of death", "children", "location contains", "place of birth", "place lived", "administrative division of country", "country of administrative divisions", "company", "neighborhood of", "company founders"]""",
+        "file_path": None,
+    },
+    {
+        "task": "RE",
+        "use_file": False,
+        "text": "The aid group Doctors Without Borders said that since Saturday , more than 275 wounded people had been admitted and treated at Donka Hospital in the capital of Guinea , Conakry .",
+        "instruction": "",
+        "constraint": """["nationality", "country capital", "place of death", "children", "location contains", "place of birth", "place lived", "administrative division of country", "country of administrative divisions", "company", "neighborhood of", "company founders"]""",
+        "file_path": None,
+    },
+    {
+        "task": "EE",
+        "use_file": False,
+        "text": "The file suggested to the user contains no software related to video streaming and simply carries the malicious payload that later compromises victim \u2019s account and sends out the deceptive messages to all victim \u2019s contacts .",
+        "instruction": "",
+        "constraint": """{"phishing": ["damage amount", "attack pattern", "tool", "victim", "place", "attacker", "purpose", "trusted entity", "time"], "data breach": ["damage amount", "attack pattern", "number of data", "number of victim", "tool", "compromised data", "victim", "place", "attacker", "purpose", "time"], "ransom": ["damage amount", "attack pattern", "payment method", "tool", "victim", "place", "attacker", "price", "time"], "discover vulnerability": ["vulnerable system", "vulnerability", "vulnerable system owner", "vulnerable system version", "supported platform", "common vulnerabilities and exposures", "capabilities", "time", "discoverer"], "patch vulnerability": ["vulnerable system", "vulnerability", "issues addressed", "vulnerable system version", "releaser", "supported platform", "common vulnerabilities and exposures", "patch number", "time", "patch"]}""",
+        "file_path": None,
+    },
+    # {
+    #     "task": "Base",
+    #     "use_file": True,
+    #     "file_path": "data/Harry_Potter_Chapter_1.pdf",
+    #     "instruction": "Extract main characters and the background setting from this chapter.",
+    #     "constraint": "",
+    #     "text": "",
+    # },
+    # {
+    #     "task": "Base",
+    #     "use_file": True,
+    #     "file_path": "data/Tulsi_Gabbard_News.html",
+    #     "instruction": "Extract key information from the given text.",
+    #     "constraint": "",
+    #     "text": "",
+    # },
+]
+
+
+def create_interface():
+    with gr.Blocks(title="OneKE Demo") as demo:
+        gr.HTML("""
+        <div style="text-align:center;">
+            <p align="center">
+                <a href="https://github.com/zjunlp/DeepKE/blob/main/example/llm/assets/oneke_logo.png">
+                    <img src="https://raw.githubusercontent.com/zjunlp/DeepKE/refs/heads/main/example/llm/assets/oneke_logo.png" width="240"/>
+                </a>
+            </p>
+            <h1>OneKE: A Dockerized Schema-Guided LLM Agent-based Knowledge Extraction System</h1>
+            <p>
+                🌐[<a href="https://oneke.openkg.cn/" target="_blank">Web</a>]
+                ⌨️[<a href="https://github.com/zjunlp/OneKE" target="_blank">Code</a>]
+                📹[<a href="http://oneke.openkg.cn/demo.mp4" target="_blank">Video</a>]
+            </p>
+        </div>
+        """)
+
+        example_button_gr = gr.Button("🎲 Quick Start with an Example 🎲")
+
+        with gr.Row():
+            with gr.Column():
+                model_gr = gr.Dropdown(choices=["gpt-3.5-turbo", "gpt-4o", "gpt-4o-mini"], label="🤖 Select your Model")
+                api_key_gr = gr.Textbox(label="🔑 Enter your API-Key")
+            with gr.Column():
+                task_gr = gr.Dropdown(choices=["Base", "NER", "RE", "EE"], label="🎯 Select your Task")
+                use_file_gr = gr.Checkbox(label="📂 Use File", value=True)
+
+        file_path_gr = gr.File(label="📖 Upload a File", visible=True)
+        text_gr = gr.Textbox(label="📖 Text", placeholder="Enter your Text", visible=False)
+        instruction_gr = gr.Textbox(label="🕹️ Instruction", visible=True)
+        constraint_gr = gr.Textbox(label="🕹️ Constraint", visible=False)
+
+        def update_fields(task):
+            if task == "Base":
+                return gr.update(visible=True, label="🕹️ Instruction", placeholder="Enter your Instruction"), gr.update(visible=False)
+            elif task == "NER":
+                return gr.update(visible=False), gr.update(visible=True, label="🕹️ Constraint", placeholder="Enter your NER Constraint")
+            elif task == "RE":
+                return gr.update(visible=False), gr.update(visible=True, label="🕹️ Constraint", placeholder="Enter your RE Constraint")
+            elif task == "EE":
+                return gr.update(visible=False), gr.update(visible=True, label="🕹️ Constraint", placeholder="Enter your EE Constraint")
+
+        def update_input_fields(use_file):
+            if use_file:
+                return gr.update(visible=False), gr.update(visible=True)
+            else:
+                return gr.update(visible=True), gr.update(visible=False)
+
+        def start_with_example():
+            example_index = random.randint(0, len(examples) - 1)
+            example = examples[example_index]
+            return (
+                gr.update(value=example["task"]),
+                gr.update(value=example["use_file"]),
+                gr.update(value=example["file_path"], visible=example["use_file"]),
+                gr.update(value=example["text"], visible=not example["use_file"]),
+                gr.update(value=example["instruction"], visible=example["task"] == "Base"),
+                gr.update(value=example["constraint"], visible=example["task"] in ["NER", "RE", "EE"]),
+            )
+
+        def submit(model, api_key, task, instruction, constraint, text, use_file, file_path):
+            try:
+                # Create the Pipeline instance
+                pipeline = Pipeline(ChatGPT(model_name_or_path=model, api_key=api_key))
+                # Only "Base" tasks use a free-form instruction; the others use a constraint.
+                if task == "Base":
+                    constraint = ""
+                else:
+                    instruction = ""
+                # File input and raw-text input are mutually exclusive.
+                if use_file:
+                    text = ""
+                else:
+                    file_path = None
+
+                # Invoke the Pipeline
+                _, _, ger_frontend_schema, ger_frontend_res = pipeline.get_extract_result(
+                    task=task,
+                    instruction=instruction,
+                    constraint=constraint,
+                    use_file=use_file,
+                    file_path=file_path,
+                    text=text,
+                )
+
+                ger_frontend_schema = str(ger_frontend_schema)
+                ger_frontend_res = json.dumps(ger_frontend_res, ensure_ascii=False, indent=4) if isinstance(ger_frontend_res, dict) else str(ger_frontend_res)
+                return ger_frontend_schema, ger_frontend_res, gr.update(value="", visible=False)
+
+            except Exception as e:
+                error_message = f"⚠️ Error:\n {str(e)}"
+                return "", "", gr.update(value=error_message, visible=True)
+
+        def clear_all():
+            return (
+                gr.update(value=""),                  # model
+                gr.update(value=""),                  # API Key
+                gr.update(value=""),                  # task
+                gr.update(value="", visible=False),   # instruction
+                gr.update(value="", visible=False),   # constraint
+                gr.update(value=True),                # use_file
+                gr.update(value="", visible=False),   # text
+                gr.update(value=None, visible=True),  # file_path
+                gr.update(value=""),                  # schema output
+                gr.update(value=""),                  # answer output
+                gr.update(value="", visible=False),   # error_output
+            )
+
+        with gr.Row():
+            submit_button_gr = gr.Button("Submit", variant="primary", scale=8)
+            clear_button = gr.Button("Clear", scale=5)
+        gr.HTML("""
+        <div style="width: 100%; text-align: center; font-size: 16px; font-weight: bold; position: relative; margin: 20px 0;">
+            <span style="position: absolute; left: 0; top: 50%; transform: translateY(-50%); width: 45%; border-top: 1px solid #ccc;"></span>
+            <span style="position: relative; z-index: 1; background-color: white; padding: 0 10px;">Output:</span>
+            <span style="position: absolute; right: 0; top: 50%; transform: translateY(-50%); width: 45%; border-top: 1px solid #ccc;"></span>
+        </div>
+        """)
+        error_output_gr = gr.Textbox(label="😵‍💫 Oops, an Error Occurred", visible=False)
+        with gr.Row():
+            with gr.Column(scale=1):
+                py_output_gr = gr.Code(label="🤔 Generated Schema", language="python", lines=10, interactive=False)
+            with gr.Column(scale=1):
+                json_output_gr = gr.Code(label="😉 Final Answer", language="json", lines=10, interactive=False)
+
+        task_gr.change(fn=update_fields, inputs=task_gr, outputs=[instruction_gr, constraint_gr])
+        use_file_gr.change(fn=update_input_fields, inputs=use_file_gr, outputs=[text_gr, file_path_gr])
+
+        example_button_gr.click(
+            fn=start_with_example,
+            inputs=[],
+            outputs=[
+                task_gr,
+                use_file_gr,
+                file_path_gr,
+                text_gr,
+                instruction_gr,
+                constraint_gr,
+            ],
+        )
+        submit_button_gr.click(
+            fn=submit,
+            inputs=[
+                model_gr,
+                api_key_gr,
+                task_gr,
+                instruction_gr,
+                constraint_gr,
+                text_gr,
+                use_file_gr,
+                file_path_gr,
+            ],
+            outputs=[py_output_gr, json_output_gr, error_output_gr],
+            show_progress=True,
+        )
+        clear_button.click(
+            fn=clear_all,
+            outputs=[
+                model_gr,
+                api_key_gr,
+                task_gr,
+                instruction_gr,
+                constraint_gr,
+                use_file_gr,
+                text_gr,
+                file_path_gr,
+                py_output_gr,
+                json_output_gr,
+                error_output_gr,
+            ],
+        )
+
+    return demo
+
+
+interface = create_interface()
+interface.launch()
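Since the module launches the app at import time, a guarded variant may be preferable when importing `create_interface` elsewhere; a minimal sketch (the `server_name`/`server_port` values are assumptions, not taken from this commit):

```python
# Sketch: launch only when run as a script, so imports stay side-effect free.
if __name__ == "__main__":
    demo = create_interface()
    demo.launch(server_name="0.0.0.0", server_port=7860)  # assumed defaults
```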
src/models/__init__.py ADDED
@@ -0,0 +1,3 @@
+from .llm_def import BaseEngine, LLaMA, Qwen, MiniCPM, ChatGLM, ChatGPT, DeepSeek
+from .prompt_example import *
+from .prompt_template import *
src/models/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (434 Bytes)
src/models/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (315 Bytes)
src/models/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (359 Bytes)
src/models/__pycache__/llm_def.cpython-311.pyc ADDED
Binary file (11.8 kB)
src/models/__pycache__/llm_def.cpython-37.pyc ADDED
Binary file (7.14 kB)
src/models/__pycache__/llm_def.cpython-39.pyc ADDED
Binary file (6.8 kB)
src/models/__pycache__/prompt_example.cpython-311.pyc ADDED
Binary file (5.67 kB)
src/models/__pycache__/prompt_example.cpython-39.pyc ADDED
Binary file (5.66 kB)
src/models/__pycache__/prompt_template.cpython-311.pyc ADDED
Binary file (5.42 kB)
src/models/__pycache__/prompt_template.cpython-39.pyc ADDED
Binary file (4.95 kB)
src/models/llm_def.py ADDED
@@ -0,0 +1,212 @@
+"""
+Supported models:
+- Open Source: LLaMA3, Qwen2.5, MiniCPM3, ChatGLM4
+- Closed Source: ChatGPT, DeepSeek
+"""
+
+from transformers import pipeline
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch
+import openai
+import os
+from openai import OpenAI
+
+# The inference code follows each model's official documentation.
+
+class BaseEngine:
+    def __init__(self, model_name_or_path: str):
+        self.name = None
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
+        self.temperature = 0.2
+        self.top_p = 0.9
+        self.max_tokens = 1024
+
+    def get_chat_response(self, prompt):
+        raise NotImplementedError
+
+    def set_hyperparameter(self, temperature: float = 0.2, top_p: float = 0.9, max_tokens: int = 1024):
+        self.temperature = temperature
+        self.top_p = top_p
+        self.max_tokens = max_tokens
+
+class LLaMA(BaseEngine):
+    def __init__(self, model_name_or_path: str):
+        super().__init__(model_name_or_path)
+        self.name = "LLaMA"
+        self.model_id = model_name_or_path
+        self.pipeline = pipeline(
+            "text-generation",
+            model=self.model_id,
+            model_kwargs={"torch_dtype": torch.bfloat16},
+            device_map="auto",
+        )
+        self.terminators = [
+            self.pipeline.tokenizer.eos_token_id,
+            self.pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")
+        ]
+
+    def get_chat_response(self, prompt):
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": prompt},
+        ]
+        outputs = self.pipeline(
+            messages,
+            max_new_tokens=self.max_tokens,
+            eos_token_id=self.terminators,
+            do_sample=True,
+            temperature=self.temperature,
+            top_p=self.top_p,
+        )
+        return outputs[0]["generated_text"][-1]['content'].strip()
+
+class Qwen(BaseEngine):
+    def __init__(self, model_name_or_path: str):
+        super().__init__(model_name_or_path)
+        self.name = "Qwen"
+        self.model_id = model_name_or_path
+        self.model = AutoModelForCausalLM.from_pretrained(
+            self.model_id,
+            torch_dtype="auto",
+            device_map="auto"
+        )
+
+    def get_chat_response(self, prompt):
+        messages = [
+            {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
+            {"role": "user", "content": prompt}
+        ]
+        text = self.tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True
+        )
+        model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
+        generated_ids = self.model.generate(
+            **model_inputs,
+            temperature=self.temperature,
+            top_p=self.top_p,
+            max_new_tokens=self.max_tokens
+        )
+        generated_ids = [
+            output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
+        ]
+        response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
+
+        return response
+
+class MiniCPM(BaseEngine):
+    def __init__(self, model_name_or_path: str):
+        super().__init__(model_name_or_path)
+        self.name = "MiniCPM"
+        self.model_id = model_name_or_path
+        self.model = AutoModelForCausalLM.from_pretrained(
+            self.model_id,
+            torch_dtype=torch.bfloat16,
+            device_map="auto",
+            trust_remote_code=True
+        )
+
+    def get_chat_response(self, prompt):
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": prompt}
+        ]
+        model_inputs = self.tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to(self.model.device)
+        model_outputs = self.model.generate(
+            model_inputs,
+            temperature=self.temperature,
+            top_p=self.top_p,
+            max_new_tokens=self.max_tokens
+        )
+        output_token_ids = [
+            model_outputs[i][len(model_inputs[i]):] for i in range(len(model_inputs))
+        ]
+        response = self.tokenizer.batch_decode(output_token_ids, skip_special_tokens=True)[0].strip()
+
+        return response
+
+class ChatGLM(BaseEngine):
+    def __init__(self, model_name_or_path: str):
+        super().__init__(model_name_or_path)
+        self.name = "ChatGLM"
+        self.model_id = model_name_or_path
+        self.model = AutoModelForCausalLM.from_pretrained(
+            self.model_id,
+            torch_dtype=torch.bfloat16,
+            device_map="auto",
+            low_cpu_mem_usage=True,
+            trust_remote_code=True
+        )
+
+    def get_chat_response(self, prompt):
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": prompt}
+        ]
+        model_inputs = self.tokenizer.apply_chat_template(messages, return_tensors="pt", return_dict=True, add_generation_prompt=True, tokenize=True).to(self.model.device)
+        model_outputs = self.model.generate(
+            **model_inputs,
+            temperature=self.temperature,
+            top_p=self.top_p,
+            max_new_tokens=self.max_tokens
+        )
+        model_outputs = model_outputs[:, model_inputs['input_ids'].shape[1]:]
+        response = self.tokenizer.batch_decode(model_outputs, skip_special_tokens=True)[0].strip()
+
+        return response
+
+class ChatGPT(BaseEngine):
+    def __init__(self, model_name_or_path: str, api_key: str, base_url=openai.base_url):
+        self.name = "ChatGPT"
+        self.model = model_name_or_path
+        self.base_url = base_url
+        self.temperature = 0.2
+        self.top_p = 0.9
+        self.max_tokens = 1024
+        if api_key != "":
+            self.api_key = api_key
+        else:
+            self.api_key = os.environ["OPENAI_API_KEY"]
+        self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
+
+    def get_chat_response(self, input):
+        response = self.client.chat.completions.create(
+            model=self.model,
+            messages=[
+                {"role": "user", "content": input},
+            ],
+            stream=False,
+            temperature=self.temperature,
+            max_tokens=self.max_tokens,
+            stop=None
+        )
+        return response.choices[0].message.content
+
+class DeepSeek(BaseEngine):
+    def __init__(self, model_name_or_path: str, api_key: str, base_url="https://api.deepseek.com"):
+        self.name = "DeepSeek"
+        self.model = model_name_or_path
+        self.base_url = base_url
+        self.temperature = 0.2
+        self.top_p = 0.9
+        self.max_tokens = 1024
+        if api_key != "":
+            self.api_key = api_key
+        else:
+            self.api_key = os.environ["DEEPSEEK_API_KEY"]
+        self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
+
+    def get_chat_response(self, input):
+        response = self.client.chat.completions.create(
+            model=self.model,
+            messages=[
+                {"role": "user", "content": input},
+            ],
+            stream=False,
+            temperature=self.temperature,
+            max_tokens=self.max_tokens,
+            stop=None
+        )
+        return response.choices[0].message.content
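A minimal usage sketch for these engines (the model name and key are placeholders; only methods defined above are used):

```python
from models import ChatGPT  # DeepSeek works the same way

engine = ChatGPT(model_name_or_path="gpt-4o-mini", api_key="sk-...")  # placeholder key
engine.set_hyperparameter(temperature=0.0, max_tokens=512)
print(engine.get_chat_response("Extract the Named Entities in: 'ELRA organizes LREC.'"))
```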
src/models/prompt_example.py ADDED
@@ -0,0 +1,137 @@
+json_schema_examples = """
+Example1:
+**Task**: Please extract all economic policies affecting the stock market between 2015 and 2023 and the exact dates of their implementation.
+**Text**: This text is from the field of Economics and represents the genre of Article.
+...(example text)...
+**Output Schema**:
+{
+    "economic_policies": [
+        {
+            "name": null,
+            "implementation_date": null
+        }
+    ]
+}
+
+Example2:
+**Task**: Tell me the main content of papers related to NLP between 2022 and 2023.
+**Text**: This text is from the field of AI and represents the genre of Research Paper.
+...(example text)...
+**Output Schema**:
+{
+    "papers": [
+        {
+            "title": null,
+            "content": null
+        }
+    ]
+}
+
+Example3:
+**Task**: Extract all the information in the given text.
+**Text**: This text is from the field of Political and represents the genre of News Report.
+...(example text)...
+**Output Schema**:
+{
+    "news_report":
+    {
+        "title": null,
+        "summary": null,
+        "publication_date": null,
+        "keywords": [],
+        "events": [
+            {
+                "name": null,
+                "time": null,
+                "people_involved": [],
+                "cause": null,
+                "process": null,
+                "result": null
+            }
+        ],
+        "quotes": [],
+        "viewpoints": []
+    }
+}
+"""
+
+code_schema_examples = """
+Example1:
+**Task**: Extract all the entities in the given text.
+**Text**:
+...(example text)...
+**Output Schema**:
+```python
+from typing import List, Optional
+from pydantic import BaseModel, Field
+
+class Entity(BaseModel):
+    label: str = Field(description="The type or category of the entity, such as 'Process', 'Technique', 'Data Structure', 'Methodology', 'Person', etc.")
+    name: str = Field(description="The specific name of the entity. It should represent a single, distinct concept and must not be an empty string. For example, if the entity is a 'Technique', the name could be 'Neural Networks'.")
+
+class ExtractionTarget(BaseModel):
+    entity_list: List[Entity] = Field(description="All the entities presented in the context. The entities should encode ONE concept.")
+```
+
+Example2:
+**Task**: Extract all the information in the given text.
+**Text**: This text is from the field of Political and represents the genre of News Article.
+...(example text)...
+**Output Schema**:
+```python
+from typing import List, Optional
+from pydantic import BaseModel, Field
+
+class Person(BaseModel):
+    name: str = Field(description="The name of the person")
+    identity: Optional[str] = Field(description="The occupation, status or characteristics of the person.")
+    role: Optional[str] = Field(description="The role or function the person plays in an event.")
+
+class Event(BaseModel):
+    name: str = Field(description="Name of the event")
+    time: Optional[str] = Field(description="Time when the event took place")
+    people_involved: Optional[List[Person]] = Field(description="People involved in the event")
+    cause: Optional[str] = Field(default=None, description="Reason for the event, if applicable")
+    process: Optional[str] = Field(description="Details of the event process")
+    result: Optional[str] = Field(default=None, description="Result or outcome of the event")
+
+class ExtractionTarget(BaseModel):
+    title: str = Field(description="The title or headline of the news article")
+    summary: str = Field(description="A brief summary of the news article")
+    publication_date: Optional[str] = Field(description="The publication date of the article")
+    keywords: Optional[List[str]] = Field(description="List of keywords or topics covered in the article")
+    events: List[Event] = Field(description="Events covered in the article")
+    quotes: Optional[List[str]] = Field(default=None, description="Quotes related to the news, if any")
+    viewpoints: Optional[List[str]] = Field(default=None, description="Different viewpoints regarding the news")
+```
+
+Example3:
+**Task**: Extract the key information in the given text.
+**Text**: This text is from the field of AI and represents the genre of Research Paper.
+...(example text)...
+**Output Schema**:
+```python
+from typing import List, Optional
+from pydantic import BaseModel, Field
+
+class MetaData(BaseModel):
+    title: str = Field(description="The title of the article")
+    authors: List[str] = Field(description="The list of the article's authors")
+    abstract: str = Field(description="The article's abstract")
+    key_words: List[str] = Field(description="The key words associated with the article")
+
+class Baseline(BaseModel):
+    method_name: str = Field(description="The name of the baseline method")
+    proposed_solution: str = Field(description="The proposed solution in detail")
+    performance_metrics: str = Field(description="The performance metrics of the method and comparative analysis")
+
+class ExtractionTarget(BaseModel):
+    key_contributions: List[str] = Field(description="The key contributions of the article")
+    limitation_of_sota: str = Field(description="The summarized limitations of the existing work")
+    proposed_solution: str = Field(description="The proposed solution in detail")
+    baselines: List[Baseline] = Field(description="The list of baseline methods and their details")
+    performance_metrics: str = Field(description="The performance metrics of the method and comparative analysis")
+    paper_limitations: str = Field(description="The limitations of the proposed solution of the paper")
+```
+"""
src/models/prompt_template.py ADDED
@@ -0,0 +1,174 @@
+from langchain.prompts import PromptTemplate
+from .prompt_example import *
+
+# ==================================================================== #
+#                            SCHEMA AGENT                              #
+# ==================================================================== #
+
+# Get Text Analysis
+TEXT_ANALYSIS_INSTRUCTION = """
+**Instruction**: Please analyze and categorize the given text.
+{examples}
+**Text**: {text}
+
+**Output Schema**: {schema}
+"""
+
+text_analysis_instruction = PromptTemplate(
+    input_variables=["examples", "text", "schema"],
+    template=TEXT_ANALYSIS_INSTRUCTION,
+)
+
+# Get Deduced Schema (JSON)
+DEDUCE_SCHEMA_JSON_INSTRUCTION = """
+**Instruction**: Generate an output format that meets the requirements as described in the task. Pay attention to the following requirements:
+- Format: Return your responses in dictionary format as a JSON object.
+- Content: Do not include any actual data; all attribute values should be set to None.
+- Note: Attributes not mentioned in the task description should be ignored.
+{examples}
+**Task**: {instruction}
+
+**Text**: {distilled_text}
+{text}
+
+Now please deduce the output schema in JSON format. All attribute values should be set to None.
+**Output Schema**:
+"""
+
+deduced_schema_json_instruction = PromptTemplate(
+    input_variables=["examples", "instruction", "distilled_text", "text", "schema"],
+    template=DEDUCE_SCHEMA_JSON_INSTRUCTION,
+)
+
+# Get Deduced Schema (Code)
+DEDUCE_SCHEMA_CODE_INSTRUCTION = """
+**Instruction**: Based on the provided text and task description, define the output schema in Python using Pydantic. Name the final extraction target class 'ExtractionTarget'.
+{examples}
+**Task**: {instruction}
+
+**Text**: {distilled_text}
+{text}
+
+Now please deduce the output schema. Ensure that the output code snippet is wrapped in '```', and can be directly parsed by the Python interpreter.
+**Output Schema**: """
+deduced_schema_code_instruction = PromptTemplate(
+    input_variables=["examples", "instruction", "distilled_text", "text"],
+    template=DEDUCE_SCHEMA_CODE_INSTRUCTION,
+)
+
+
+# ==================================================================== #
+#                          EXTRACTION AGENT                            #
+# ==================================================================== #
+
+EXTRACT_INSTRUCTION = """
+**Instruction**: You are an agent skilled in information extraction. {instruction}
+{examples}
+**Text**: {text}
+{additional_info}
+**Output Schema**: {schema}
+
+Now please extract the corresponding information from the text. Ensure that the information you extract has a clear reference in the given text. Set any property not explicitly mentioned in the text to null.
+"""
+
+extract_instruction = PromptTemplate(
+    input_variables=["instruction", "examples", "text", "schema", "additional_info"],
+    template=EXTRACT_INSTRUCTION,
+)
+
+SUMMARIZE_INSTRUCTION = """
+**Instruction**: Below is a list of results obtained after segmenting and extracting information from a long article. Please consolidate all the answers to generate a final response.
+{examples}
+**Task**: {instruction}
+
+**Result List**: {answer_list}
+
+**Output Schema**: {schema}
+Now summarize all the information from the Result List.
+"""
+summarize_instruction = PromptTemplate(
+    input_variables=["instruction", "examples", "answer_list", "schema"],
+    template=SUMMARIZE_INSTRUCTION,
+)
+
+
+# ==================================================================== #
+#                          REFLECTION AGENT                            #
+# ==================================================================== #
+REFLECT_INSTRUCTION = """**Instruction**: You are an agent skilled in reflection and optimization based on the original result. Refer to **Reflection Reference** to identify potential issues in the current extraction results.
+
+**Reflection Reference**: {examples}
+
+Now please review each element in the extraction result. Identify and improve any potential issues in the result based on the reflection. NOTE: If the original result is correct, no modifications are needed!
+
+**Task**: {instruction}
+
+**Text**: {text}
+
+**Output Schema**: {schema}
+
+**Original Result**: {result}
+
+"""
+reflect_instruction = PromptTemplate(
+    input_variables=["instruction", "examples", "text", "schema", "result"],
+    template=REFLECT_INSTRUCTION,
+)
+
+# NOTE: this second definition adds {additional_info} and supersedes the
+# summarize_instruction defined above.
+SUMMARIZE_INSTRUCTION = """
+**Instruction**: Below is a list of results obtained after segmenting and extracting information from a long article. Please consolidate all the answers to generate a final response.
+
+**Task**: {instruction}
+
+**Result List**: {answer_list}
+{additional_info}
+**Output Schema**: {schema}
+Now summarize the information from the Result List.
+"""
+summarize_instruction = PromptTemplate(
+    input_variables=["instruction", "answer_list", "additional_info", "schema"],
+    template=SUMMARIZE_INSTRUCTION,
+)
+
+
+# ==================================================================== #
+#                           CASE REPOSITORY                            #
+# ==================================================================== #
+
+GOOD_CASE_ANALYSIS_INSTRUCTION = """
+**Instruction**: Below is an information extraction task and its corresponding correct answer. Provide the reasoning steps that led to the correct answer, along with a brief explanation of the answer. Your response should be brief and organized.
+
+**Task**: {instruction}
+
+**Text**: {text}
+{additional_info}
+**Correct Answer**: {result}
+
+Now please generate the reasoning steps and a brief analysis of the **Correct Answer** given above. DO NOT generate your own extraction result.
+**Analysis**:
+"""
+good_case_analysis_instruction = PromptTemplate(
+    input_variables=["instruction", "text", "result", "additional_info"],
+    template=GOOD_CASE_ANALYSIS_INSTRUCTION,
+)
+
+BAD_CASE_REFLECTION_INSTRUCTION = """
+**Instruction**: Based on the task description, compare the original answer with the correct one. Your output should be a brief reflection or concise summarized rules.
+
+**Task**: {instruction}
+
+**Text**: {text}
+{additional_info}
+**Original Answer**: {original_answer}
+
+**Correct Answer**: {correct_answer}
+
+Now please generate a brief and organized reflection. DO NOT generate your own extraction result.
+**Reflection**:
+"""
+
+bad_case_reflection_instruction = PromptTemplate(
+    input_variables=["instruction", "text", "original_answer", "correct_answer", "additional_info"],
+    template=BAD_CASE_REFLECTION_INSTRUCTION,
+)
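A quick sketch of how one of these templates is filled (illustrative values; `extract_instruction` is the template defined above):

```python
from models.prompt_template import extract_instruction

prompt = extract_instruction.format(
    instruction="Extract the Named Entities in the given text.",
    examples="",  # good cases would be injected here in "standard" mode
    text="Conakry is the capital of Guinea.",
    additional_info="",
    schema="EntityList",
)
print(prompt)
```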
src/modules/__init__.py ADDED
@@ -0,0 +1,4 @@
+from .schema_agent import SchemaAgent
+from .extraction_agent import ExtractionAgent
+from .reflection_agent import ReflectionAgent
+from .knowledge_base.case_repository import CaseRepositoryHandler
src/modules/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (459 Bytes)
src/modules/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (392 Bytes)
src/modules/__pycache__/extraction_agent.cpython-311.pyc ADDED
Binary file (6.66 kB)
src/modules/__pycache__/extraction_agent.cpython-39.pyc ADDED
Binary file (4.06 kB)
src/modules/__pycache__/reflection_agent.cpython-311.pyc ADDED
Binary file (6.98 kB)
src/modules/__pycache__/reflection_agent.cpython-39.pyc ADDED
Binary file (4.01 kB)
src/modules/__pycache__/schema_agent.cpython-311.pyc ADDED
Binary file (10.7 kB)
src/modules/__pycache__/schema_agent.cpython-39.pyc ADDED
Binary file (6.64 kB)
src/modules/extraction_agent.py ADDED
@@ -0,0 +1,85 @@
+from models import *
+from utils import *
+from .knowledge_base.case_repository import CaseRepositoryHandler
+
+class InformationExtractor:
+    def __init__(self, llm: BaseEngine):
+        self.llm = llm
+
+    def extract_information(self, instruction="", text="", examples="", schema="", additional_info=""):
+        examples = good_case_wrapper(examples)
+        prompt = extract_instruction.format(instruction=instruction, examples=examples, text=text, additional_info=additional_info, schema=schema)
+        response = self.llm.get_chat_response(prompt)
+        response = extract_json_dict(response)
+        print(f"prompt: {prompt}")
+        print("========================================")
+        print(f"response: {response}")
+        return response
+
+    def summarize_answer(self, instruction="", answer_list="", schema="", additional_info=""):
+        prompt = summarize_instruction.format(instruction=instruction, answer_list=answer_list, schema=schema, additional_info=additional_info)
+        response = self.llm.get_chat_response(prompt)
+        response = extract_json_dict(response)
+        return response
+
+class ExtractionAgent:
+    def __init__(self, llm: BaseEngine, case_repo: CaseRepositoryHandler):
+        self.llm = llm
+        self.module = InformationExtractor(llm=llm)
+        self.case_repo = case_repo
+        self.methods = ["extract_information_direct", "extract_information_with_case"]
+
+    def __get_constraint(self, data: DataPoint):
+        if data.constraint == "":
+            return data
+        if data.task == "NER":
+            constraint = json.dumps(data.constraint)
+            if "**Entity Type Constraint**" in constraint:
+                return data
+            data.constraint = f"\n**Entity Type Constraint**: The type of entities must be chosen from the following list.\n{constraint}\n"
+        elif data.task == "RE":
+            constraint = json.dumps(data.constraint)
+            if "**Relation Type Constraint**" in constraint:
+                return data
+            data.constraint = f"\n**Relation Type Constraint**: The type of relations must be chosen from the following list.\n{constraint}\n"
+        elif data.task == "EE":
+            constraint = json.dumps(data.constraint)
+            if "**Event Extraction Constraint**" in constraint:
+                return data
+            data.constraint = f"\n**Event Extraction Constraint**: The event type must be selected from the following dictionary keys, and its event arguments should be chosen from its corresponding dictionary values.\n{constraint}\n"
+        return data
+
+    def extract_information_direct(self, data: DataPoint):
+        data = self.__get_constraint(data)
+        result_list = []
+        for chunk_text in data.chunk_text_list:
+            extract_direct_result = self.module.extract_information(instruction=data.instruction, text=chunk_text, schema=data.output_schema, examples="", additional_info=data.constraint)
+            result_list.append(extract_direct_result)
+        function_name = current_function_name()
+        data.set_result_list(result_list)
+        data.update_trajectory(function_name, result_list)
+        return data
+
+    def extract_information_with_case(self, data: DataPoint):
+        data = self.__get_constraint(data)
+        result_list = []
+        for chunk_text in data.chunk_text_list:
+            examples = self.case_repo.query_good_case(data)
+            extract_case_result = self.module.extract_information(instruction=data.instruction, text=chunk_text, schema=data.output_schema, examples=examples, additional_info=data.constraint)
+            result_list.append(extract_case_result)
+        function_name = current_function_name()
+        data.set_result_list(result_list)
+        data.update_trajectory(function_name, result_list)
+        return data
+
+    def summarize_answer(self, data: DataPoint):
+        if len(data.result_list) == 0:
+            return data
+        if len(data.result_list) == 1:
+            data.set_pred(data.result_list[0])
+            return data
+        summarized_result = self.module.summarize_answer(instruction=data.instruction, answer_list=data.result_list, schema=data.output_schema, additional_info=data.constraint)
+        function_name = current_function_name()
+        data.set_pred(summarized_result)
+        data.update_trajectory(function_name, summarized_result)
+        return data
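For illustration, the NER branch of `__get_constraint` wraps a plain type list like this (a standalone sketch of the same string manipulation; the type list is illustrative):

```python
import json

constraint = ["person", "organization", "location"]  # illustrative entity types
wrapped = (
    "\n**Entity Type Constraint**: The type of entities must be chosen from the following list.\n"
    f"{json.dumps(constraint)}\n"
)
print(wrapped)  # this string is passed to the prompt as additional_info
```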
src/modules/knowledge_base/__pycache__/case_repository.cpython-311.pyc ADDED
Binary file (4.64 kB)
src/modules/knowledge_base/__pycache__/case_repository.cpython-39.pyc ADDED
Binary file (3.8 kB)
src/modules/knowledge_base/__pycache__/schema_repository.cpython-311.pyc ADDED
Binary file (9.25 kB)
src/modules/knowledge_base/__pycache__/schema_repository.cpython-39.pyc ADDED
Binary file (5.94 kB)
src/modules/knowledge_base/case_repository.json ADDED
The diff for this file is too large to render.
src/modules/knowledge_base/case_repository.py ADDED
@@ -0,0 +1,391 @@
+# import json
+# import os
+# import torch
+# import numpy as np
+# from utils import *
+# from sentence_transformers import SentenceTransformer
+# from rapidfuzz import process
+# from models import *
+# import copy
+
+# import warnings
+# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+# warnings.filterwarnings("ignore", category=FutureWarning, message=r".*clean_up_tokenization_spaces*")
+
+# class CaseRepository:
+#     def __init__(self):
+#         self.embedder = SentenceTransformer("all-MiniLM-L6-v2")
+#         self.embedder.to(device)
+#         self.corpus = self.load_corpus()
+#         self.embedded_corpus = self.embed_corpus()
+
+#     def load_corpus(self):
+#         with open(os.path.join(os.path.dirname(__file__), "case_repository.json")) as file:
+#             corpus = json.load(file)
+#         return corpus
+
+#     def update_corpus(self):
+#         try:
+#             with open(os.path.join(os.path.dirname(__file__), "case_repository.json"), "w") as file:
+#                 json.dump(self.corpus, file, indent=2)
+#         except Exception as e:
+#             print(f"Error when updating corpus: {e}")
+
+#     def embed_corpus(self):
+#         embedded_corpus = {}
+#         for key, content in self.corpus.items():
+#             good_index = [item['index']['embed_index'] for item in content['good']]
+#             encoded_good_index = self.embedder.encode(good_index, convert_to_tensor=True).to(device)
+#             bad_index = [item['index']['embed_index'] for item in content['bad']]
+#             encoded_bad_index = self.embedder.encode(bad_index, convert_to_tensor=True).to(device)
+#             embedded_corpus[key] = {"good": encoded_good_index, "bad": encoded_bad_index}
+#         return embedded_corpus
+
+#     def get_similarity_scores(self, task: TaskType, embed_index="", str_index="", case_type="", top_k=2):
+#         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+#         # Embedding similarity match
+#         encoded_embed_query = self.embedder.encode(embed_index, convert_to_tensor=True).to(device)
+#         embedding_similarity_matrix = self.embedder.similarity(encoded_embed_query, self.embedded_corpus[task][case_type])
+#         embedding_similarity_scores = embedding_similarity_matrix[0].to(device)
+
+#         # String similarity match
+#         str_match_corpus = [item['index']['str_index'] for item in self.corpus[task][case_type]]
+#         str_similarity_results = process.extract(str_index, str_match_corpus, limit=len(str_match_corpus))
+#         scores_dict = {match[0]: match[1] for match in str_similarity_results}
+#         scores_in_order = [scores_dict[candidate] for candidate in str_match_corpus]
+#         str_similarity_scores = torch.tensor(scores_in_order, dtype=torch.float32).to(device)
+
+#         # Normalize scores
+#         embedding_score_range = embedding_similarity_scores.max() - embedding_similarity_scores.min()
+#         str_score_range = str_similarity_scores.max() - str_similarity_scores.min()
+#         if embedding_score_range > 0:
+#             embed_norm_scores = (embedding_similarity_scores - embedding_similarity_scores.min()) / embedding_score_range
+#         else:
+#             embed_norm_scores = embedding_similarity_scores
+#         if str_score_range > 0:
+#             str_norm_scores = (str_similarity_scores - str_similarity_scores.min()) / str_score_range
+#         else:
+#             str_norm_scores = str_similarity_scores / 100
+
+#         # Combine the scores with weights
+#         combined_scores = 0.5 * embed_norm_scores + 0.5 * str_norm_scores
+#         original_combined_scores = 0.5 * embedding_similarity_scores + 0.5 * str_similarity_scores / 100
+
+#         scores, indices = torch.topk(combined_scores, k=min(top_k, combined_scores.size(0)))
+#         original_scores, original_indices = torch.topk(original_combined_scores, k=min(top_k, original_combined_scores.size(0)))
+#         return scores, indices, original_scores, original_indices
+
+#     def query_case(self, task: TaskType, embed_index="", str_index="", case_type="", top_k=2) -> list:
+#         _, indices, _, _ = self.get_similarity_scores(task, embed_index, str_index, case_type, top_k)
+#         top_matches = [self.corpus[task][case_type][idx]["content"] for idx in indices]
+#         return top_matches
+
+#     def update_case(self, task: TaskType, embed_index="", str_index="", content="", case_type=""):
+#         self.corpus[task][case_type].append({"index": {"embed_index": embed_index, "str_index": str_index}, "content": content})
+#         self.embedded_corpus[task][case_type] = torch.cat([self.embedded_corpus[task][case_type], self.embedder.encode([embed_index], convert_to_tensor=True).to(device)], dim=0)
+#         print(f"Case updated for {task} task.")
+
+# class CaseRepositoryHandler:
+#     def __init__(self, llm: BaseEngine):
+#         self.repository = CaseRepository()
+#         self.llm = llm
+
+#     def __get_good_case_analysis(self, instruction="", text="", result="", additional_info=""):
+#         prompt = good_case_analysis_instruction.format(
+#             instruction=instruction, text=text, result=result, additional_info=additional_info
+#         )
+#         for _ in range(3):
+#             response = self.llm.get_chat_response(prompt)
+#             response = extract_json_dict(response)
+#             if not isinstance(response, dict):
+#                 return response
+#         return None
+
+#     def __get_bad_case_reflection(self, instruction="", text="", original_answer="", correct_answer="", additional_info=""):
+#         prompt = bad_case_reflection_instruction.format(
+#             instruction=instruction, text=text, original_answer=original_answer, correct_answer=correct_answer, additional_info=additional_info
+#         )
+#         for _ in range(3):
+#             response = self.llm.get_chat_response(prompt)
+#             response = extract_json_dict(response)
+#             if not isinstance(response, dict):
+#                 return response
+#         return None
+
+#     def __get_index(self, data: DataPoint, case_type: str):
+#         # set embed_index
117
+ # embed_index = f"**Text**: {data.distilled_text}\n{data.chunk_text_list[0]}"
118
+
119
+ # # set str_index
120
+ # if data.task == "Base":
121
+ # str_index = f"**Task**: {data.instruction}"
122
+ # else:
123
+ # str_index = f"{data.constraint}"
124
+
125
+ # if case_type == "bad":
126
+ # str_index += f"\n\n**Original Result**: {json.dumps(data.pred)}"
127
+
128
+ # return embed_index, str_index
129
+
130
+ # def query_good_case(self, data: DataPoint):
131
+ # embed_index, str_index = self.__get_index(data, "good")
132
+ # return self.repository.query_case(task=data.task, embed_index=embed_index, str_index=str_index, case_type="good")
133
+
134
+ # def query_bad_case(self, data: DataPoint):
135
+ # embed_index, str_index = self.__get_index(data, "bad")
136
+ # return self.repository.query_case(task=data.task, embed_index=embed_index, str_index=str_index, case_type="bad")
137
+
138
+ # def update_good_case(self, data: DataPoint):
139
+ # if data.truth == "" :
140
+ # print("No truth value provided.")
141
+ # return
142
+ # embed_index, str_index = self.__get_index(data, "good")
143
+ # _, _, original_scores, _ = self.repository.get_similarity_scores(data.task, embed_index, str_index, "good", 1)
144
+ # original_scores = original_scores.tolist()
145
+ # if original_scores[0] >= 0.9:
146
+ # print("The similar good case is already in the corpus. Similarity Score: ", original_scores[0])
147
+ # return
148
+ # good_case_alaysis = self.__get_good_case_analysis(instruction=data.instruction, text=data.distilled_text, result=data.truth, additional_info=data.constraint)
149
+ # wrapped_good_case_analysis = f"**Analysis**: {good_case_alaysis}"
150
+ # wrapped_instruction = f"**Task**: {data.instruction}"
151
+ # wrapped_text = f"**Text**: {data.distilled_text}\n{data.chunk_text_list[0]}"
152
+ # wrapped_answer = f"**Correct Answer**: {json.dumps(data.truth)}"
153
+ # if data.task == "Base":
154
+ # content = f"{wrapped_instruction}\n\n{wrapped_text}\n\n{wrapped_good_case_analysis}\n\n{wrapped_answer}"
155
+ # else:
156
+ # content = f"{wrapped_text}\n\n{data.constraint}\n\n{wrapped_good_case_analysis}\n\n{wrapped_answer}"
157
+ # self.repository.update_case(data.task, embed_index, str_index, content, "good")
158
+
159
+ # def update_bad_case(self, data: DataPoint):
160
+ # if data.truth == "" :
161
+ # print("No truth value provided.")
162
+ # return
163
+ # if normalize_obj(data.pred) == normalize_obj(data.truth):
164
+ # return
165
+ # embed_index, str_index = self.__get_index(data, "bad")
166
+ # _, _, original_scores, _ = self.repository.get_similarity_scores(data.task, embed_index, str_index, "bad", 1)
167
+ # original_scores = original_scores.tolist()
168
+ # if original_scores[0] >= 0.9:
169
+ # print("The similar bad case is already in the corpus. Similarity Score: ", original_scores[0])
170
+ # return
171
+ # bad_case_reflection = self.__get_bad_case_reflection(instruction=data.instruction, text=data.distilled_text, original_answer=data.pred, correct_answer=data.truth, additional_info=data.constraint)
172
+ # wrapped_bad_case_reflection = f"**Reflection**: {bad_case_reflection}"
173
+ # wrapper_original_answer = f"**Original Answer**: {json.dumps(data.pred)}"
174
+ # wrapper_correct_answer = f"**Correct Answer**: {json.dumps(data.truth)}"
175
+ # wrapped_instruction = f"**Task**: {data.instruction}"
176
+ # wrapped_text = f"**Text**: {data.distilled_text}\n{data.chunk_text_list[0]}"
177
+ # if data.task == "Base":
178
+ # content = f"{wrapped_instruction}\n\n{wrapped_text}\n\n{wrapper_original_answer}\n\n{wrapped_bad_case_reflection}\n\n{wrapper_correct_answer}"
179
+ # else:
180
+ # content = f"{wrapped_text}\n\n{data.constraint}\n\n{wrapper_original_answer}\n\n{wrapped_bad_case_reflection}\n\n{wrapper_correct_answer}"
181
+ # self.repository.update_case(data.task, embed_index, str_index, content, "bad")
182
+
183
+ # def update_case(self, data: DataPoint):
184
+ # self.update_good_case(data)
185
+ # self.update_bad_case(data)
186
+ # self.repository.update_corpus()
187
+
188
+
189
+
+ import json
+ import os
+ import torch
+ import numpy as np
+ from utils import *
+ from sentence_transformers import SentenceTransformer
+ from rapidfuzz import process
+ from models import *
+ import copy
+
+ import warnings
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ warnings.filterwarnings("ignore", category=FutureWarning, message=r".*clean_up_tokenization_spaces*")
+
+ # NOTE: the case repository is disabled in this commit; every method below keeps
+ # its original implementation as comments and currently just returns None.
+ class CaseRepository:
+     def __init__(self):
+         # self.embedder = SentenceTransformer("all-MiniLM-L6-v2")
+         # self.embedder.to(device)
+         # self.corpus = self.load_corpus()
+         # self.embedded_corpus = self.embed_corpus()
+         pass
+
+     def load_corpus(self):
+         # with open(os.path.join(os.path.dirname(__file__), "case_repository.json")) as file:
+         #     corpus = json.load(file)
+         # return corpus
+         pass
+
+     def update_corpus(self):
+         # try:
+         #     with open(os.path.join(os.path.dirname(__file__), "case_repository.json"), "w") as file:
+         #         json.dump(self.corpus, file, indent=2)
+         # except Exception as e:
+         #     print(f"Error when updating corpus: {e}")
+         pass
+
+     def embed_corpus(self):
+         # embedded_corpus = {}
+         # for key, content in self.corpus.items():
+         #     good_index = [item['index']['embed_index'] for item in content['good']]
+         #     encoded_good_index = self.embedder.encode(good_index, convert_to_tensor=True).to(device)
+         #     bad_index = [item['index']['embed_index'] for item in content['bad']]
+         #     encoded_bad_index = self.embedder.encode(bad_index, convert_to_tensor=True).to(device)
+         #     embedded_corpus[key] = {"good": encoded_good_index, "bad": encoded_bad_index}
+         # return embedded_corpus
+         pass
+
+     def get_similarity_scores(self, task: TaskType, embed_index="", str_index="", case_type="", top_k=2):
+         # # Embedding similarity match
+         # encoded_embed_query = self.embedder.encode(embed_index, convert_to_tensor=True).to(device)
+         # embedding_similarity_matrix = self.embedder.similarity(encoded_embed_query, self.embedded_corpus[task][case_type])
+         # embedding_similarity_scores = embedding_similarity_matrix[0].to(device)
+
+         # # String similarity match
+         # str_match_corpus = [item['index']['str_index'] for item in self.corpus[task][case_type]]
+         # str_similarity_results = process.extract(str_index, str_match_corpus, limit=len(str_match_corpus))
+         # scores_dict = {match[0]: match[1] for match in str_similarity_results}
+         # scores_in_order = [scores_dict[candidate] for candidate in str_match_corpus]
+         # str_similarity_scores = torch.tensor(scores_in_order, dtype=torch.float32).to(device)
+
+         # # Normalize scores, then combine embedding and string similarity 50/50
+         # embedding_score_range = embedding_similarity_scores.max() - embedding_similarity_scores.min()
+         # str_score_range = str_similarity_scores.max() - str_similarity_scores.min()
+         # if embedding_score_range > 0:
+         #     embed_norm_scores = (embedding_similarity_scores - embedding_similarity_scores.min()) / embedding_score_range
+         # else:
+         #     embed_norm_scores = embedding_similarity_scores
+         # if str_score_range > 0:
+         #     str_norm_scores = (str_similarity_scores - str_similarity_scores.min()) / str_score_range
+         # else:
+         #     str_norm_scores = str_similarity_scores / 100
+
+         # combined_scores = 0.5 * embed_norm_scores + 0.5 * str_norm_scores
+         # original_combined_scores = 0.5 * embedding_similarity_scores + 0.5 * str_similarity_scores / 100
+
+         # scores, indices = torch.topk(combined_scores, k=min(top_k, combined_scores.size(0)))
+         # original_scores, original_indices = torch.topk(original_combined_scores, k=min(top_k, original_combined_scores.size(0)))
+         # return scores, indices, original_scores, original_indices
+         pass
+
+     def query_case(self, task: TaskType, embed_index="", str_index="", case_type="", top_k=2) -> list:
+         # _, indices, _, _ = self.get_similarity_scores(task, embed_index, str_index, case_type, top_k)
+         # top_matches = [self.corpus[task][case_type][idx]["content"] for idx in indices]
+         # return top_matches
+         pass
+
+     def update_case(self, task: TaskType, embed_index="", str_index="", content="", case_type=""):
+         # self.corpus[task][case_type].append({"index": {"embed_index": embed_index, "str_index": str_index}, "content": content})
+         # self.embedded_corpus[task][case_type] = torch.cat([self.embedded_corpus[task][case_type], self.embedder.encode([embed_index], convert_to_tensor=True).to(device)], dim=0)
+         # print(f"Case updated for {task} task.")
+         pass
+
+ class CaseRepositoryHandler:
+     def __init__(self, llm: BaseEngine):
+         self.repository = CaseRepository()
+         self.llm = llm
+
+     def __get_good_case_analysis(self, instruction="", text="", result="", additional_info=""):
+         # prompt = good_case_analysis_instruction.format(
+         #     instruction=instruction, text=text, result=result, additional_info=additional_info
+         # )
+         # for _ in range(3):
+         #     response = self.llm.get_chat_response(prompt)
+         #     response = extract_json_dict(response)
+         #     if not isinstance(response, dict):
+         #         return response
+         # return None
+         pass
+
+     def __get_bad_case_reflection(self, instruction="", text="", original_answer="", correct_answer="", additional_info=""):
+         # prompt = bad_case_reflection_instruction.format(
+         #     instruction=instruction, text=text, original_answer=original_answer, correct_answer=correct_answer, additional_info=additional_info
+         # )
+         # for _ in range(3):
+         #     response = self.llm.get_chat_response(prompt)
+         #     response = extract_json_dict(response)
+         #     if not isinstance(response, dict):
+         #         return response
+         # return None
+         pass
+
+     def __get_index(self, data: DataPoint, case_type: str):
+         # set embed_index
+         # embed_index = f"**Text**: {data.distilled_text}\n{data.chunk_text_list[0]}"
+
+         # # set str_index
+         # if data.task == "Base":
+         #     str_index = f"**Task**: {data.instruction}"
+         # else:
+         #     str_index = f"{data.constraint}"
+
+         # if case_type == "bad":
+         #     str_index += f"\n\n**Original Result**: {json.dumps(data.pred)}"
+
+         # return embed_index, str_index
+         pass
+
+     def query_good_case(self, data: DataPoint):
+         # embed_index, str_index = self.__get_index(data, "good")
+         # return self.repository.query_case(task=data.task, embed_index=embed_index, str_index=str_index, case_type="good")
+         pass
+
+     def query_bad_case(self, data: DataPoint):
+         # embed_index, str_index = self.__get_index(data, "bad")
+         # return self.repository.query_case(task=data.task, embed_index=embed_index, str_index=str_index, case_type="bad")
+         pass
+
+     def update_good_case(self, data: DataPoint):
+         # if data.truth == "":
+         #     print("No truth value provided.")
+         #     return
+         # embed_index, str_index = self.__get_index(data, "good")
+         # _, _, original_scores, _ = self.repository.get_similarity_scores(data.task, embed_index, str_index, "good", 1)
+         # original_scores = original_scores.tolist()
+         # if original_scores[0] >= 0.9:
+         #     print("The similar good case is already in the corpus. Similarity Score: ", original_scores[0])
+         #     return
+         # good_case_analysis = self.__get_good_case_analysis(instruction=data.instruction, text=data.distilled_text, result=data.truth, additional_info=data.constraint)
+         # wrapped_good_case_analysis = f"**Analysis**: {good_case_analysis}"
+         # wrapped_instruction = f"**Task**: {data.instruction}"
+         # wrapped_text = f"**Text**: {data.distilled_text}\n{data.chunk_text_list[0]}"
+         # wrapped_answer = f"**Correct Answer**: {json.dumps(data.truth)}"
+         # if data.task == "Base":
+         #     content = f"{wrapped_instruction}\n\n{wrapped_text}\n\n{wrapped_good_case_analysis}\n\n{wrapped_answer}"
+         # else:
+         #     content = f"{wrapped_text}\n\n{data.constraint}\n\n{wrapped_good_case_analysis}\n\n{wrapped_answer}"
+         # self.repository.update_case(data.task, embed_index, str_index, content, "good")
+         pass
+
+     def update_bad_case(self, data: DataPoint):
+         # if data.truth == "":
+         #     print("No truth value provided.")
+         #     return
+         # if normalize_obj(data.pred) == normalize_obj(data.truth):
+         #     return
+         # embed_index, str_index = self.__get_index(data, "bad")
+         # _, _, original_scores, _ = self.repository.get_similarity_scores(data.task, embed_index, str_index, "bad", 1)
+         # original_scores = original_scores.tolist()
+         # if original_scores[0] >= 0.9:
+         #     print("The similar bad case is already in the corpus. Similarity Score: ", original_scores[0])
+         #     return
+         # bad_case_reflection = self.__get_bad_case_reflection(instruction=data.instruction, text=data.distilled_text, original_answer=data.pred, correct_answer=data.truth, additional_info=data.constraint)
+         # wrapped_bad_case_reflection = f"**Reflection**: {bad_case_reflection}"
+         # wrapped_original_answer = f"**Original Answer**: {json.dumps(data.pred)}"
+         # wrapped_correct_answer = f"**Correct Answer**: {json.dumps(data.truth)}"
+         # wrapped_instruction = f"**Task**: {data.instruction}"
+         # wrapped_text = f"**Text**: {data.distilled_text}\n{data.chunk_text_list[0]}"
+         # if data.task == "Base":
+         #     content = f"{wrapped_instruction}\n\n{wrapped_text}\n\n{wrapped_original_answer}\n\n{wrapped_bad_case_reflection}\n\n{wrapped_correct_answer}"
+         # else:
+         #     content = f"{wrapped_text}\n\n{data.constraint}\n\n{wrapped_original_answer}\n\n{wrapped_bad_case_reflection}\n\n{wrapped_correct_answer}"
+         # self.repository.update_case(data.task, embed_index, str_index, content, "bad")
+         pass
+
+     def update_case(self, data: DataPoint):
+         # self.update_good_case(data)
+         # self.update_bad_case(data)
+         # self.repository.update_corpus()
+         pass
+
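For reference, the commented-out `get_similarity_scores` ranks stored cases by a 50/50 blend of min-max-normalized embedding similarity and rapidfuzz string similarity (the latter on a 0-100 scale). A toy illustration of that blending, with made-up scores:

import torch
embed = torch.tensor([0.82, 0.40, 0.91])   # cosine similarities against the query
fuzzy = torch.tensor([55.0, 90.0, 70.0])   # rapidfuzz scores on a 0-100 scale
embed_norm = (embed - embed.min()) / (embed.max() - embed.min())
fuzzy_norm = (fuzzy - fuzzy.min()) / (fuzzy.max() - fuzzy.min())
combined = 0.5 * embed_norm + 0.5 * fuzzy_norm   # tensor([0.4118, 0.5000, 0.7143])
top2 = torch.topk(combined, k=2).indices         # tensor([2, 1])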
src/modules/knowledge_base/schema_repository.py ADDED
@@ -0,0 +1,91 @@
+ from typing import List, Optional
+ from pydantic import BaseModel, Field
+ from langchain_core.output_parsers import JsonOutputParser
+
+ # ==================================================================== #
+ #                              NER TASK                                #
+ # ==================================================================== #
+ class Entity(BaseModel):
+     name: str = Field(description="The specific name of the entity.")
+     type: str = Field(description="The type or category that the entity belongs to.")
+
+ class EntityList(BaseModel):
+     entity_list: List[Entity] = Field(description="Named entities appearing in the text.")
+
+ # ==================================================================== #
+ #                               RE TASK                                #
+ # ==================================================================== #
+ class Relation(BaseModel):
+     head: str = Field(description="The starting entity in the relationship.")
+     tail: str = Field(description="The ending entity in the relationship.")
+     relation: str = Field(description="The predicate that defines the relationship between the two entities.")
+
+ class RelationList(BaseModel):
+     relation_list: List[Relation] = Field(description="The collection of relationships between various entities.")
+
+ # ==================================================================== #
+ #                               EE TASK                                #
+ # ==================================================================== #
+ class Event(BaseModel):
+     event_type: str = Field(description="The type of the event.")
+     event_trigger: str = Field(description="A specific word or phrase that indicates the occurrence of the event.")
+     event_argument: dict = Field(description="The arguments or participants involved in the event.")
+
+ class EventList(BaseModel):
+     event_list: List[Event] = Field(description="The events presented in the text.")
+
+ # ==================================================================== #
+ #                          TEXT DESCRIPTION                            #
+ # ==================================================================== #
+ class TextDescription(BaseModel):
+     field: str = Field(description="The field of the given text, such as 'Science', 'Literature', 'Business', 'Medicine', 'Entertainment', etc.")
+     genre: str = Field(description="The genre of the given text, such as 'Article', 'Novel', 'Dialog', 'Blog', 'Manual', 'Expository', 'News Report', 'Research Paper', etc.")
+
+ # ==================================================================== #
+ #                         USER DEFINED SCHEMA                          #
+ # ==================================================================== #
+
+ # --------------------------- Research Paper ----------------------- #
+ class MetaData(BaseModel):
+     title: str = Field(description="The title of the article")
+     authors: List[str] = Field(description="The list of the article's authors")
+     abstract: str = Field(description="The article's abstract")
+     key_words: List[str] = Field(description="The key words associated with the article")
+
+ class Baseline(BaseModel):
+     method_name: str = Field(description="The name of the baseline method")
+     proposed_solution: str = Field(description="The proposed solution in detail")
+     performance_metrics: str = Field(description="The performance metrics of the method and comparative analysis")
+
+ class ExtractionTarget(BaseModel):
+     key_contributions: List[str] = Field(description="The key contributions of the article")
+     limitation_of_sota: str = Field(description="The summarized limitations of existing work")
+     proposed_solution: str = Field(description="The proposed solution in detail")
+     baselines: List[Baseline] = Field(description="The list of baseline methods and their details")
+     performance_metrics: str = Field(description="The performance metrics of the method and comparative analysis")
+     paper_limitations: str = Field(description="The limitations of the paper's proposed solution")
+
+ # --------------------------- News ----------------------- #
+ class Person(BaseModel):
+     name: str = Field(description="The name of the person")
+     identity: Optional[str] = Field(description="The occupation, status or characteristics of the person.")
+     role: Optional[str] = Field(description="The role or function the person plays in an event.")
+
+ class NewsEvent(BaseModel):  # renamed from `Event` to avoid shadowing the EE-task class above
+     name: str = Field(description="Name of the event")
+     time: Optional[str] = Field(description="Time when the event took place")
+     people_involved: Optional[List[Person]] = Field(description="People involved in the event")
+     cause: Optional[str] = Field(default=None, description="Reason for the event, if applicable")
+     process: Optional[str] = Field(description="Details of the event process")
+     result: Optional[str] = Field(default=None, description="Result or outcome of the event")
+
+ class NewsReport(BaseModel):
+     title: str = Field(description="The title or headline of the news report")
+     summary: str = Field(description="A brief summary of the news report")
+     publication_date: Optional[str] = Field(description="The publication date of the report")
+     keywords: Optional[List[str]] = Field(description="List of keywords or topics covered in the news report")
+     events: List[NewsEvent] = Field(description="Events covered in the news report")
+     quotes: Optional[List[str]] = Field(default=None, description="Quotes related to the news, if any")
+     viewpoints: Optional[List[str]] = Field(default=None, description="Different viewpoints regarding the news")
+
+ # --------- You can customize new extraction schemas below -------- #
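As the closing comment says, new extraction targets are added here as ordinary pydantic models; `SchemaAgent.get_retrieved_schema` looks them up by class name via the `output_schema` field. A minimal sketch of such a custom schema (the `Contract` model is illustrative, not part of the commit):

class Contract(BaseModel):
    parties: List[str] = Field(description="The parties signing the contract")
    effective_date: Optional[str] = Field(default=None, description="When the contract takes effect")
    obligations: List[str] = Field(description="The obligations each party agrees to")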
src/modules/reflection_agent.py ADDED
@@ -0,0 +1,74 @@
+ from models import *
+ from utils import *
+ from .extraction_agent import ExtractionAgent
+ from .knowledge_base.case_repository import CaseRepositoryHandler
+
+ class ReflectionGenerator:
+     def __init__(self, llm: BaseEngine):
+         self.llm = llm
+
+     def get_reflection(self, instruction="", examples="", text="", schema="", result=""):
+         result = json.dumps(result)
+         examples = bad_case_wrapper(examples)
+         prompt = reflect_instruction.format(instruction=instruction, examples=examples, text=text, schema=schema, result=result)
+         response = self.llm.get_chat_response(prompt)
+         response = extract_json_dict(response)
+         return response
+
+ class ReflectionAgent:
+     def __init__(self, llm: BaseEngine, case_repo: CaseRepositoryHandler):
+         self.llm = llm
+         self.module = ReflectionGenerator(llm=llm)
+         self.extractor = ExtractionAgent(llm=llm, case_repo=case_repo)
+         self.case_repo = case_repo
+         self.methods = ["reflect_with_case"]
+
+     def __select_result(self, result_list):
+         dict_objects = [obj for obj in result_list if isinstance(obj, dict)]  # prefer the longest dict result
+         if dict_objects:
+             selected_obj = max(dict_objects, key=lambda d: len(json.dumps(d)))
+         else:
+             selected_obj = max(result_list, key=lambda o: len(json.dumps(o)))
+         return selected_obj
+
+     def __self_consistency_check(self, data: DataPoint):
+         extract_func = list(data.result_trajectory.keys())[-1]
+         if hasattr(self.extractor, extract_func):
+             result_trials = []
+             result_trials.append(data.result_list)
+             extract_func = getattr(self.extractor, extract_func)
+             temperature = [0.5, 1]  # re-sample twice at higher temperatures for the vote
+             for index in range(2):
+                 self.module.llm.set_hyperparameter(temperature=temperature[index])
+                 data = extract_func(data)
+                 result_trials.append(data.result_list)
+             self.module.llm.set_hyperparameter()
+             consistent_result = []
+             reflect_index = []  # chunks with no majority result are sent to reflection
+             for index, elements in enumerate(zip(*result_trials)):
+                 normalized_elements = [normalize_obj(e) for e in elements]
+                 element_counts = Counter(normalized_elements)
+                 selected_element = next((elements[i] for i, element in enumerate(normalized_elements)
+                                          if element_counts[element] >= 2), None)
+                 if selected_element is None:
+                     selected_element = self.__select_result(elements)
+                     reflect_index.append(index)
+                 consistent_result.append(selected_element)
+             data.set_result_list(consistent_result)
+             return reflect_index
+
+     def reflect_with_case(self, data: DataPoint):
+         if not data.result_list:
+             return data
+         reflect_index = self.__self_consistency_check(data)
+         reflected_result_list = data.result_list
+         for idx in reflect_index:
+             text = data.chunk_text_list[idx]
+             result = data.result_list[idx]
+             examples = json.dumps(self.case_repo.query_bad_case(data))
+             reflected_res = self.module.get_reflection(instruction=data.instruction, examples=examples, text=text, schema=data.output_schema, result=result)
+             reflected_result_list[idx] = reflected_res
+         data.set_result_list(reflected_result_list)
+         function_name = current_function_name()
+         data.update_trajectory(function_name, data.result_list)
+         return data
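The self-consistency check above runs the same extraction three times (the two re-runs at temperatures 0.5 and 1), normalizes each chunk's result with `normalize_obj`, and keeps any result that appears at least twice; only chunks with no such majority are reflected on. A rough sketch of the vote for a single chunk (values illustrative):

from collections import Counter

trials = [{"name": "Apple"}, {"name": "apple"}, {"name": "Banana"}]   # one chunk, three runs
normalized = [normalize_obj(t) for t in trials]                       # case/space-insensitive keys
counts = Counter(normalized)
majority = next((trials[i] for i, n in enumerate(normalized) if counts[n] >= 2), None)
# majority == {"name": "Apple"}, since "Apple" and "apple" normalize identically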
src/modules/schema_agent.py ADDED
@@ -0,0 +1,148 @@
+ from models import *
+ from utils import *
+ from .knowledge_base import schema_repository
+ from langchain_core.output_parsers import JsonOutputParser
+
+ class SchemaAnalyzer:
+     def __init__(self, llm: BaseEngine):
+         self.llm = llm
+
+     def serialize_schema(self, schema) -> str:
+         if isinstance(schema, (str, list, dict, set, tuple)):
+             return schema
+         try:
+             parser = JsonOutputParser(pydantic_object=schema)
+             schema_description = parser.get_format_instructions()
+             schema_content = re.search(r'```(.*?)```', schema_description, re.DOTALL).group(1)  # take the schema block itself, not the findall list
+             explanation = "For example, for the schema {\"properties\": {\"foo\": {\"title\": \"Foo\", \"description\": \"a list of strings\", \"type\": \"array\", \"items\": {\"type\": \"string\"}}}}, the object {\"foo\": [\"bar\", \"baz\"]} is a well-formatted instance."
+             schema = f"{schema_content}\n\n{explanation}"
+         except Exception:
+             return schema
+         return schema
+
+     def redefine_text(self, text_analysis):
+         try:
+             field = text_analysis['field']
+             genre = text_analysis['genre']
+         except Exception:
+             return text_analysis
+         prompt = f"This text is from the field of {field} and represents the genre of {genre}."
+         return prompt
+
+     def get_text_analysis(self, text: str):
+         output_schema = self.serialize_schema(schema_repository.TextDescription)
+         prompt = text_analysis_instruction.format(examples="", text=text, schema=output_schema)
+         response = self.llm.get_chat_response(prompt)
+         response = extract_json_dict(response)
+         response = self.redefine_text(response)
+         return response
+
+     def get_deduced_schema_json(self, instruction: str, text: str, distilled_text: str):
+         prompt = deduced_schema_json_instruction.format(examples=example_wrapper(json_schema_examples), instruction=instruction, distilled_text=distilled_text, text=text)
+         response = self.llm.get_chat_response(prompt)
+         response = extract_json_dict(response)
+         code = response
+         print(f"Deduced Schema in JSON: \n{response}\n\n")
+         return code, response
+
+     def get_deduced_schema_code(self, instruction: str, text: str, distilled_text: str):
+         prompt = deduced_schema_code_instruction.format(examples=example_wrapper(code_schema_examples), instruction=instruction, distilled_text=distilled_text, text=text)
+         response = self.llm.get_chat_response(prompt)
+         code_blocks = re.findall(r'```[^\n]*\n(.*?)\n```', response, re.DOTALL)
+         if code_blocks:
+             try:
+                 code_block = code_blocks[-1]
+                 namespace = {}
+                 exec(code_block, namespace)  # note: executes model-generated code in a fresh namespace
+                 schema = namespace.get('ExtractionTarget')
+                 if schema is not None:
+                     index = code_block.find("class")
+                     code = code_block[index:]
+                     print(f"Deduced Schema in Code: \n{code}\n\n")
+                     schema = self.serialize_schema(schema)
+                     return code, schema
+             except Exception as e:
+                 print(e)
+                 return self.get_deduced_schema_json(instruction, text, distilled_text)
+         return self.get_deduced_schema_json(instruction, text, distilled_text)
+
+ class SchemaAgent:
+     def __init__(self, llm: BaseEngine):
+         self.llm = llm
+         self.module = SchemaAnalyzer(llm=llm)
+         self.schema_repo = schema_repository
+         self.methods = ["get_default_schema", "get_retrieved_schema", "get_deduced_schema"]
+
+     def __preprocess_text(self, data: DataPoint):
+         if data.use_file:
+             data.chunk_text_list = chunk_file(data.file_path)
+         else:
+             data.chunk_text_list = chunk_str(data.text)
+         if data.task == "NER":
+             data.print_schema = """
+             class Entity(BaseModel):
+                 name : str = Field(description="The specific name of the entity.")
+                 type : str = Field(description="The type or category that the entity belongs to.")
+             class EntityList(BaseModel):
+                 entity_list : List[Entity] = Field(description="Named entities appearing in the text.")
+             """
+         elif data.task == "RE":
+             data.print_schema = """
+             class Relation(BaseModel):
+                 head : str = Field(description="The starting entity in the relationship.")
+                 tail : str = Field(description="The ending entity in the relationship.")
+                 relation : str = Field(description="The predicate that defines the relationship between the two entities.")
+
+             class RelationList(BaseModel):
+                 relation_list : List[Relation] = Field(description="The collection of relationships between various entities.")
+             """
+         elif data.task == "EE":
+             data.print_schema = """
+             class Event(BaseModel):
+                 event_type : str = Field(description="The type of the event.")
+                 event_trigger : str = Field(description="A specific word or phrase that indicates the occurrence of the event.")
+                 event_argument : dict = Field(description="The arguments or participants involved in the event.")
+
+             class EventList(BaseModel):
+                 event_list : List[Event] = Field(description="The events presented in the text.")
+             """
+         return data
+
+     def get_default_schema(self, data: DataPoint):
+         data = self.__preprocess_text(data)
+         default_schema = config['agent']['default_schema']
+         data.set_schema(default_schema)
+         function_name = current_function_name()
+         data.update_trajectory(function_name, default_schema)
+         return data
+
+     def get_retrieved_schema(self, data: DataPoint):
+         data = self.__preprocess_text(data)
+         schema_name = data.output_schema
+         schema_class = getattr(self.schema_repo, schema_name, None)
+         if schema_class is not None:
+             schema = self.module.serialize_schema(schema_class)
+             default_schema = config['agent']['default_schema']
+             data.set_schema(f"{default_schema}\n{schema}")
+             function_name = current_function_name()
+             data.update_trajectory(function_name, schema)
+         else:
+             return self.get_default_schema(data)
+         return data
+
+     def get_deduced_schema(self, data: DataPoint):
+         data = self.__preprocess_text(data)
+         target_text = data.chunk_text_list[0]
+         analysed_text = self.module.get_text_analysis(target_text)
+         if len(data.chunk_text_list) > 1:
+             prefix = "Below is a portion of the text to be extracted. "
+             analysed_text = f"{prefix}\n{target_text}"
+         distilled_text = analysed_text  # get_text_analysis already returns the distilled description
+         code, deduced_schema = self.module.get_deduced_schema_code(data.instruction, target_text, distilled_text)
+         data.print_schema = code
+         data.set_distilled_text(distilled_text)
+         default_schema = config['agent']['default_schema']
+         data.set_schema(f"{default_schema}\n{deduced_schema}")
+         function_name = current_function_name()
+         data.update_trajectory(function_name, deduced_schema)
+         return data
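`get_deduced_schema_code` expects the model's reply to contain a triple-backtick code block whose last occurrence defines a pydantic class named `ExtractionTarget`; the block is exec'd in a fresh namespace (so it must carry its own imports) and the class is then serialized into format instructions, falling back to the JSON variant on any failure. An illustrative reply body that would parse correctly:

from pydantic import BaseModel, Field
from typing import List

class ExtractionTarget(BaseModel):
    product_name: str = Field(description="The product being reviewed")
    pros: List[str] = Field(description="Positive points mentioned in the review")
    cons: List[str] = Field(description="Negative points mentioned in the review")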
src/pipeline.py ADDED
@@ -0,0 +1,92 @@
+ from typing import Literal
+ from models import *
+ from utils import *
+ from modules import *
+
+ class Pipeline:
+     def __init__(self, llm: BaseEngine):
+         self.llm = llm
+         self.case_repo = CaseRepositoryHandler(llm=llm)
+         self.schema_agent = SchemaAgent(llm=llm)
+         self.extraction_agent = ExtractionAgent(llm=llm, case_repo=self.case_repo)
+         self.reflection_agent = ReflectionAgent(llm=llm, case_repo=self.case_repo)
+
+     def __init_method(self, data: DataPoint, process_method):
+         default_order = ["schema_agent", "extraction_agent", "reflection_agent"]
+         # The schema method follows the task type: open-ended "Base" tasks deduce
+         # a schema, while predefined tasks (NER/RE/EE) retrieve their preset one.
+         if data.task == "Base":
+             process_method["schema_agent"] = "get_deduced_schema"
+         else:
+             process_method["schema_agent"] = "get_retrieved_schema"
+         if "extraction_agent" not in process_method:
+             process_method["extraction_agent"] = "extract_information_direct"
+         sorted_process_method = {key: process_method[key] for key in default_order if key in process_method}
+         return sorted_process_method
+
+     def __init_data(self, data: DataPoint):
+         if data.task == "NER":
+             data.instruction = config['agent']['default_ner']
+             data.output_schema = "EntityList"
+         elif data.task == "RE":
+             data.instruction = config['agent']['default_re']
+             data.output_schema = "RelationList"
+         elif data.task == "EE":
+             data.instruction = config['agent']['default_ee']
+             data.output_schema = "EventList"
+         return data
+
+     # main entry
+     def get_extract_result(self,
+                            task: TaskType,
+                            instruction: str = "",
+                            text: str = "",
+                            output_schema: str = "",
+                            constraint: str = "",
+                            use_file: bool = False,
+                            file_path: str = "",
+                            truth: str = "",
+                            mode: str = "quick",  # a mode name from config.yaml, or a dict mapping agents to methods
+                            update_case: bool = False
+                            ):
+         data = DataPoint(task=task, instruction=instruction, text=text, output_schema=output_schema, constraint=constraint, use_file=use_file, file_path=file_path, truth=truth)
+         data = self.__init_data(data)
+         if mode in config['agent']['mode'].keys():
+             process_method = config['agent']['mode'][mode]
+         else:
+             process_method = mode
+         sorted_process_method = self.__init_method(data, process_method)
+         print_schema = False
+         frontend_schema = ""
+         frontend_res = ""
+         # Information extraction
+         for agent_name, method_name in sorted_process_method.items():
+             agent = getattr(self, agent_name, None)
+             if not agent:
+                 raise AttributeError(f"{agent_name} does not exist.")
+             method = getattr(agent, method_name, None)
+             if not method:
+                 raise AttributeError(f"Method '{method_name}' not found in {agent_name}.")
+             data = method(data)
+             if not print_schema and data.print_schema:
+                 print("Schema: \n", data.print_schema)
+                 frontend_schema = data.print_schema
+                 print_schema = True
+         data = self.extraction_agent.summarize_answer(data)
+         print("Extraction Result: \n", json.dumps(data.pred, indent=2))
+         frontend_res = data.pred
+         # Case update
+         if update_case:
+             if data.truth == "":
+                 truth = input("Please enter the correct answer you prefer, or press Enter to accept the current answer: ")
+                 if truth.strip() == "":
+                     data.truth = data.pred
+                 else:
+                     data.truth = extract_json_dict(truth)
+             self.case_repo.update_case(data)
+
+         # return result
+         result = data.pred
+         trajectory = data.get_result_trajectory()
+
+         return result, trajectory, frontend_schema, frontend_res
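A minimal usage sketch of the pipeline (the engine class name `ChatGPT` is hypothetical; as `run.py` below shows, the real class is resolved from `models` by the config's `category` field):

import models
from pipeline import Pipeline

llm = getattr(models, "ChatGPT")("gpt-4o", "sk-...")   # hypothetical engine class and credentials
pipeline = Pipeline(llm)
result, trajectory, schema, res = pipeline.get_extract_result(
    task="NER",
    text="Steve Jobs co-founded Apple in Cupertino.",
)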
src/run.py ADDED
@@ -0,0 +1,90 @@
+ import argparse
+ import os
+ import yaml
+ from pipeline import Pipeline
+ from typing import Literal
+ import models
+ from models import *
+ from utils import *
+ from modules import *
+
+ def load_extraction_config(yaml_path):
+     # Read the YAML content from the given file path
+     if not os.path.exists(yaml_path):
+         print(f"Error: The config file '{yaml_path}' does not exist.")
+         return {}
+
+     with open(yaml_path, 'r') as file:
+         config = yaml.safe_load(file)
+
+     # Pull out the 'model' and 'extraction' sections of the config
+     model_config = config.get('model', {})
+     extraction_config = config.get('extraction', {})
+     # model config
+     model_name_or_path = model_config.get('model_name_or_path', "")
+     model_category = model_config.get('category', "")
+     api_key = model_config.get('api_key', "")
+     base_url = model_config.get('base_url', "")
+
+     # extraction config
+     task = extraction_config.get('task', "")
+     instruction = extraction_config.get('instruction', "")
+     text = extraction_config.get('text', "")
+     output_schema = extraction_config.get('output_schema', "")
+     constraint = extraction_config.get('constraint', "")
+     truth = extraction_config.get('truth', "")
+     use_file = extraction_config.get('use_file', False)
+     file_path = extraction_config.get('file_path', "")  # file to extract from when use_file is true
+     mode = extraction_config.get('mode', "quick")
+     update_case = extraction_config.get('update_case', False)
+
+     # Return a dict containing these values
+     return {
+         "model": {
+             "model_name_or_path": model_name_or_path,
+             "category": model_category,
+             "api_key": api_key,
+             "base_url": base_url
+         },
+         "extraction": {
+             "task": task,
+             "instruction": instruction,
+             "text": text,
+             "output_schema": output_schema,
+             "constraint": constraint,
+             "truth": truth,
+             "use_file": use_file,
+             "file_path": file_path,
+             "mode": mode,
+             "update_case": update_case
+         }
+     }
+
+
+ def main():
+     # Create the command-line argument parser
+     parser = argparse.ArgumentParser(description='Run the extraction model.')
+     parser.add_argument('--config', type=str, required=True,
+                         help='Path to the YAML configuration file.')
+
+     # Parse command-line arguments
+     args = parser.parse_args()
+
+     # Load the configuration
+     config = load_extraction_config(args.config)
+     model_config = config['model']
+     extraction_config = config['extraction']
+     clazz = getattr(models, model_config['category'], None)
+     if clazz is None:
+         print(f"Error: The model category '{model_config['category']}' is not supported.")
+         return
+     if model_config['api_key'] == "":
+         model = clazz(model_config['model_name_or_path'])
+     else:
+         model = clazz(model_config['model_name_or_path'], model_config['api_key'], model_config['base_url'])
+     pipeline = Pipeline(model)
+     result, trajectory, *_ = pipeline.get_extract_result(task=extraction_config['task'], instruction=extraction_config['instruction'], text=extraction_config['text'], output_schema=extraction_config['output_schema'], constraint=extraction_config['constraint'], use_file=extraction_config['use_file'], file_path=extraction_config['file_path'], truth=extraction_config['truth'], mode=extraction_config['mode'], update_case=extraction_config['update_case'])
+     return
+
+ if __name__ == "__main__":
+     main()
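For reference, a YAML file consumed by `load_extraction_config` would look like this (values illustrative; only the keys read above are used, and the `category` name is hypothetical):

model:
  model_name_or_path: gpt-4o
  category: ChatGPT          # resolved via getattr(models, category)
  api_key: sk-...
  base_url: ""
extraction:
  task: NER
  text: "Steve Jobs co-founded Apple in Cupertino."
  use_file: false
  mode: quick
  update_case: false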
src/utils/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .process import *
+ from .data_def import DataPoint, TaskType
+
src/utils/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (274 Bytes). View file
 
src/utils/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (239 Bytes). View file
 
src/utils/__pycache__/data_def.cpython-311.pyc ADDED
Binary file (3.07 kB). View file
 
src/utils/__pycache__/data_def.cpython-39.pyc ADDED
Binary file (2.3 kB). View file
 
src/utils/__pycache__/process.cpython-311.pyc ADDED
Binary file (10.7 kB). View file
 
src/utils/__pycache__/process.cpython-39.pyc ADDED
Binary file (5.98 kB). View file
 
src/utils/data_def.py ADDED
@@ -0,0 +1,59 @@
+ from typing import Literal
+ from models import *
+ from .process import *
+ # predefined processing logic for routine extraction tasks
+ TaskType = Literal["NER", "RE", "EE", "Base"]
+ ModelType = Literal["gpt-3.5-turbo", "gpt-4o"]
+
+ class DataPoint:
+     def __init__(self,
+                  task: TaskType = "Base",
+                  instruction: str = "",
+                  text: str = "",
+                  output_schema: str = "",
+                  constraint: str = "",
+                  use_file: bool = False,
+                  file_path: str = "",
+                  truth: str = ""):
+         """
+         Initialize a DataPoint instance.
+         """
+         # task information
+         self.task = task
+         self.instruction = instruction
+         self.text = text
+         self.output_schema = output_schema
+         self.constraint = constraint
+         self.use_file = use_file
+         self.file_path = file_path
+         self.truth = extract_json_dict(truth)
+         # temp storage
+         self.print_schema = ""
+         self.distilled_text = ""
+         self.chunk_text_list = []
+         # result feedback
+         self.result_list = []
+         self.result_trajectory = {}
+         self.pred = ""
+
+     def set_constraint(self, constraint):
+         self.constraint = constraint
+
+     def set_schema(self, output_schema):
+         self.output_schema = output_schema
+
+     def set_pred(self, pred):
+         self.pred = pred
+
+     def set_result_list(self, result_list):
+         self.result_list = result_list
+
+     def set_distilled_text(self, distilled_text):
+         self.distilled_text = distilled_text
+
+     def update_trajectory(self, function, result):
+         if function not in self.result_trajectory:
+             self.result_trajectory.update({function: result})
+
+     def get_result_trajectory(self):
+         return {"instruction": self.instruction, "text": self.text, "constraint": self.constraint, "trajectory": self.result_trajectory, "pred": self.pred}
src/utils/process.py ADDED
@@ -0,0 +1,181 @@
+ """
+ Data Processing Functions.
+ Supports:
+ - Segmentation of long text
+ - Segmentation of file content
+ """
+ from langchain_community.document_loaders import TextLoader, PyPDFLoader, Docx2txtLoader, BSHTMLLoader, JSONLoader
+ # NOTE: sent_tokenize requires the NLTK "punkt" tokenizer data (nltk.download('punkt')).
+ from nltk.tokenize import sent_tokenize
+ from collections import Counter
+ import re
+ import json
+ import yaml
+ import os
+ import inspect
+ with open(os.path.join(os.path.dirname(__file__), "..", "config.yaml")) as file:
+     config = yaml.safe_load(file)
+
+ # Split a long string into sentence-aligned chunks under the configured token limit
+ def chunk_str(text):
+     sentences = sent_tokenize(text)
+     chunks = []
+     current_chunk = []
+     current_length = 0
+
+     for sentence in sentences:
+         token_count = len(sentence.split())
+         if current_length + token_count <= config['agent']['chunk_token_limit']:
+             current_chunk.append(sentence)
+             current_length += token_count
+         else:
+             if current_chunk:
+                 chunks.append(' '.join(current_chunk))
+             current_chunk = [sentence]
+             current_length = token_count
+     if current_chunk:
+         chunks.append(' '.join(current_chunk))
+     return chunks
+
+ # Load and split the content of a file
+ def chunk_file(file_path):
+     pages = []
+
+     if file_path.endswith(".pdf"):
+         loader = PyPDFLoader(file_path)
+     elif file_path.endswith(".txt"):
+         loader = TextLoader(file_path)
+     elif file_path.endswith(".docx"):
+         loader = Docx2txtLoader(file_path)
+     elif file_path.endswith(".html"):
+         loader = BSHTMLLoader(file_path)
+     elif file_path.endswith(".json"):
+         loader = JSONLoader(file_path)
+     else:
+         raise ValueError("Unsupported file format")
+
+     pages = loader.load_and_split()
+     docs = ""
+     for item in pages:
+         docs += item.page_content
+     pages = chunk_str(docs)
+
+     return pages
+
+ def process_single_quotes(text):
+     result = re.sub(r"(?<!\w)'|'(?!\w)", '"', text)
+     return result
+
+ def remove_empty_values(data):
+     def is_empty(value):
+         return value is None or value == [] or value == "" or value == {}
+     if isinstance(data, dict):
+         return {
+             k: remove_empty_values(v)
+             for k, v in data.items()
+             if not is_empty(v)
+         }
+     elif isinstance(data, list):
+         return [
+             remove_empty_values(item)
+             for item in data
+             if not is_empty(item)
+         ]
+     else:
+         return data
+
+ def extract_json_dict(text):
+     if isinstance(text, dict):
+         return text
+     pattern = r'\{(?:[^{}]|(?:\{(?:[^{}]|(?:\{[^{}]*\})*)*\})*)*\}'
+     matches = re.findall(pattern, text)
+     if matches:
+         json_string = matches[-1]
+         json_string = process_single_quotes(json_string)
+         try:
+             json_dict = json.loads(json_string)
+             json_dict = remove_empty_values(json_dict)
+             if json_dict is None:
+                 return "No valid information found."
+             return json_dict
+         except json.JSONDecodeError:
+             return json_string
+     else:
+         return text
+
+ def good_case_wrapper(example: str):
+     if example is None or example == "":
+         return ""
+     example = f"\nHere are some examples:\n{example}\n(END OF EXAMPLES)\nRefer to the reasoning steps and analysis in the examples to help complete the extraction task below.\n\n"
+     return example
+
+ def bad_case_wrapper(example: str):
+     if example is None or example == "":
+         return ""
+     example = f"\nHere are some examples of bad cases:\n{example}\n(END OF EXAMPLES)\nRefer to the reflection rules and reflection steps in the examples to help optimize the original result below.\n\n"
+     return example
+
+ def example_wrapper(example: str):
+     if example is None or example == "":
+         return ""
+     example = f"\nHere are some examples:\n{example}\n(END OF EXAMPLES)\n\n"
+     return example
+
+ def remove_redundant_space(s):
+     s = ' '.join(s.split())
+     s = re.sub(r"\s*(,|:|\(|\)|\.|_|;|'|-)\s*", r'\1', s)
+     return s
+
+ def format_string(s):
+     s = remove_redundant_space(s)
+     s = s.lower()
+     s = s.replace('{', '').replace('}', '')
+     s = re.sub(r',+', ',', s)
+     s = re.sub(r'\.+', '.', s)
+     s = re.sub(r';+', ';', s)
+     s = s.replace('’', "'")
+     return s
+
+ def calculate_metrics(y_truth: set, y_pred: set):
+     TP = len(y_truth & y_pred)
+     FN = len(y_truth - y_pred)
+     FP = len(y_pred - y_truth)
+     precision = TP / (TP + FP) if (TP + FP) > 0 else 0
+     recall = TP / (TP + FN) if (TP + FN) > 0 else 0
+     f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
+     return precision, recall, f1_score
+
+ def current_function_name():
+     try:
+         stack = inspect.stack()
+         if len(stack) > 1:
+             outer_func_name = stack[1].function
+             return outer_func_name
+         else:
+             print("No caller function found")
+             return None
+
+     except Exception as e:
+         print(f"An error occurred: {e}")
+         return None
+
+ def normalize_obj(value):
+     if isinstance(value, dict):
+         return frozenset((k, normalize_obj(v)) for k, v in value.items())
+     elif isinstance(value, (list, set, tuple)):
+         # Convert the Counter to a tuple so that it can be hashed
+         return tuple(Counter(map(normalize_obj, value)).items())
+     elif isinstance(value, str):
+         return format_string(value)
+     return value
+
+ def dict_list_to_set(data_list):
+     result_set = set()
+     try:
+         for dictionary in data_list:
+             value_tuple = tuple(format_string(value) for value in dictionary.values())
+             result_set.add(value_tuple)
+         return result_set
+     except Exception as e:
+         print(f"Failed to convert dictionary list to set: {data_list}")
+         return result_set
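As a quick illustration of `extract_json_dict` on a typical model reply: it grabs the last brace-balanced object, converts single-quoted keys and values to double quotes, and drops empty values.

reply = "Here is the result: {'entity_list': [{'name': 'Apple', 'type': 'Company'}], 'note': ''}"
extract_json_dict(reply)
# -> {'entity_list': [{'name': 'Apple', 'type': 'Company'}]}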
src/webui/__init__.py ADDED
@@ -0,0 +1 @@
+ from .interface import InterFace