update

- .gitattributes +1 -2
- .gitignore +2 -2
- LICENSE +21 -0
- data/Artificial_Intelligence_Wikipedia.txt +0 -46
- data/Harry_Potter_Chapter1.pdf +0 -3
- data/Tulsi_Gabbard_News.html +0 -0
- examples/config/BookExtraction.yaml +15 -0
- examples/config/EE.yaml +14 -0
- examples/config/NER.yaml +13 -0
- examples/config/NewsExtraction.yaml +15 -0
- examples/config/RE.yaml +15 -0
- examples/config/Triple2KG.yaml +21 -0
- examples/example.py +17 -0
- examples/results/BookExtraction.json +48 -0
- examples/results/EE.json +13 -0
- examples/results/NER.json +16 -0
- examples/results/NewsExtraction.json +51 -0
- examples/results/RE.json +9 -0
- examples/results/TripleExtraction.json +156 -0
- experiments/dataset_def.py +181 -0
- experiments/run_ner.py +15 -0
- experiments/run_re.py +10 -0
- figs/logo.png +0 -0
- figs/main.png +0 -0
- requirements.txt +2 -1
- src/config.yaml +2 -1
- src/construct/__init__.py +1 -0
- src/construct/convert.py +201 -0
- src/models/llm_def.py +15 -16
- src/models/prompt_example.py +12 -12
- src/models/prompt_template.py +14 -14
- src/models/vllm_serve.py +2 -3
- src/modules/extraction_agent.py +28 -0
- src/modules/knowledge_base/case_repository.py +22 -22
- src/modules/knowledge_base/schema_repository.py +13 -0
- src/modules/reflection_agent.py +4 -5
- src/modules/schema_agent.py +12 -0
- src/pipeline.py +18 -2
- src/run.py +11 -2
- src/utils/__init__.py +0 -1
- src/utils/process.py +66 -29
- src/webui.py +33 -12
.gitattributes
CHANGED
@@ -32,5 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.xz filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
-data/Harry_Potter_Chapter1.pdf filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore
CHANGED
@@ -1,3 +1,3 @@
-local
 **/__pycache__
-*.pyc
+*.pyc
+dev
LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2025 ZJUNLP
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
data/Artificial_Intelligence_Wikipedia.txt
DELETED
@@ -1,46 +0,0 @@
-In the 22nd century, rising sea levels from global warming
-have wiped out coastal cities and altered the world's climate.
-With the human population in decline, nations have created
-humanoid robots called mechas to fulfill various roles.
-
-In Madison, New Jersey, David, an 11-year-old prototype mecha
-child capable of love, is given to Henry Swinton and his wife
-Monica, whose son Martin is in suspended animation. Monica
-initially feels uncomfortable but warms to David after he is
-activated and imprinted. David befriends Teddy, Martin's robotic
-teddy bear.
-
-After Martin is cured and brought home, he goads David into
-cutting off a piece of Monica's hair. Later, David accidentally
-pokes Monica's eye with scissors. During a pool party, David
-reacts to being poked with a knife and both he and Martin fall
-into the pool. Martin is saved, but David is blamed.
-
-Henry convinces Monica to return David to his creators for
-destruction, but instead, she abandons him in the woods with
-Teddy. David, believing that becoming human will regain Monica's
-love, decides to find the Blue Fairy.
-
-David and Teddy are captured by the "Flesh Fair", where obsolete
-mechas are destroyed. David pleads for his life, and the audience
-allows him to escape with Gigolo Joe, a mecha framed for murder.
-They travel to Rouge City, where "Dr. Know", a holographic answer
-engine, directs them to the ruins of New York City and suggests
-that the Blue Fairy may help.
-
-David meets Professor Hobby, who shows him copies of himself,
-including female variants. Disheartened, David attempts suicide,
-but Joe rescues him. They find the Blue Fairy, which turns out
-to be a statue. David repeatedly asks the statue to turn him
-into a real boy until his power source is depleted.
-
-Two thousand years later, humanity is extinct and Manhattan is
-buried under ice. Mechas have evolved, and a group called the
-Specialists resurrect David and Teddy. They reconstruct the Swinton
-home from David's memories and explain that he cannot become human.
-However, they recreate Monica using genetic material from the
-strand of hair Teddy kept. Monica can live for only one day.
-
-David spends his happiest day with Monica, and as she falls asleep,
-she tells him she has always loved him. David lies down next to her
-and closes his eyes.
data/Harry_Potter_Chapter1.pdf
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:b9eb5104658f4d6ef8ff9b457f28f188b6aa1b201443719c501e462072eacf57
-size 163709
data/Tulsi_Gabbard_News.html
DELETED
(The diff for this file is too large to render.)
examples/config/BookExtraction.yaml
ADDED
@@ -0,0 +1,15 @@
+model:
+  # Recommended: use the ChatGPT or DeepSeek APIs for complex IE tasks.
+  category: ChatGPT # model category, chosen from ChatGPT, DeepSeek, LLaMA, Qwen, ChatGLM, MiniCPM, OneKE.
+  model_name_or_path: gpt-4o-mini # model name, chosen from the model list of the selected category.
+  api_key: your_api_key # your API key for the model with API service. Not needed for open-source models.
+  base_url: https://api.openai.com/v1 # base URL for the API service. Not needed for open-source models.
+
+extraction:
+  task: Base # task type, chosen from Base, NER, RE, EE.
+  instruction: Extract main characters and background setting from this chapter. # description of the task. Not needed for the NER, RE, EE tasks.
+  use_file: true # whether to use a file for the input text. Defaults to false.
+  file_path: ./data/input_files/Harry_Potter_Chapter1.pdf # path to the input file. Not needed if use_file is set to false.
+  mode: quick # extraction mode, chosen from quick, detailed, customized. Defaults to quick. See src/config.yaml for more details.
+  update_case: false # whether to update the case repository. Defaults to false.
+  show_trajectory: false # whether to display the extracted intermediate steps.
examples/config/EE.yaml
ADDED
@@ -0,0 +1,14 @@
+model:
+  category: DeepSeek # model category, chosen from ChatGPT, DeepSeek, LLaMA, Qwen, ChatGLM, MiniCPM, OneKE.
+  model_name_or_path: deepseek-chat # model name, chosen from the model list of the selected category.
+  api_key: your_api_key # your API key for the model with API service. Not needed for open-source models.
+  base_url: https://api.deepseek.com # base URL for the API service. Not needed for open-source models.
+
+extraction:
+  task: EE # task type, chosen from Base, NER, RE, EE.
+  text: UConn Health , an academic medical center , says in a media statement that it identified approximately 326,000 potentially impacted individuals whose personal information was contained in the compromised email accounts. # input text for the extraction task. Not needed if use_file is set to true.
+  constraint: {"phishing": ["damage amount", "attack pattern", "tool", "victim", "place", "attacker", "purpose", "trusted entity", "time"], "data breach": ["damage amount", "attack pattern", "number of data", "number of victim", "tool", "compromised data", "victim", "place", "attacker", "purpose", "time"], "ransom": ["damage amount", "attack pattern", "payment method", "tool", "victim", "place", "attacker", "price", "time"], "discover vulnerability": ["vulnerable system", "vulnerability", "vulnerable system owner", "vulnerable system version", "supported platform", "common vulnerabilities and exposures", "capabilities", "time", "discoverer"], "patch vulnerability": ["vulnerable system", "vulnerability", "issues addressed", "vulnerable system version", "releaser", "supported platform", "common vulnerabilities and exposures", "patch number", "time", "patch"]} # Specified event types and their corresponding arguments for the event extraction task. Structured as a dictionary with the event type as the key and the list of arguments as the value. Defaults to empty.
+  use_file: false # whether to use a file for the input text.
+  mode: standard # extraction mode, chosen from quick, detailed, customized. Defaults to quick. See src/config.yaml for more details.
+  update_case: false # whether to update the case repository. Defaults to false.
+  show_trajectory: false # whether to display the extracted intermediate steps.
examples/config/NER.yaml
ADDED
@@ -0,0 +1,13 @@
+model:
+  category: LLaMA # model category, chosen from ChatGPT, DeepSeek, LLaMA, Qwen, ChatGLM, MiniCPM, OneKE.
+  model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct # model name to download from Hugging Face, or a local model path.
+  vllm_serve: false # whether to use vllm serving. Defaults to false.
+
+extraction:
+  task: NER # task type, chosen from Base, NER, RE, EE.
+  text: Finally , every other year , ELRA organizes a major conference LREC , the International Language Resources and Evaluation Conference . # input text for the extraction task. Not needed if use_file is set to true.
+  constraint: ["algorithm", "conference", "else", "product", "task", "field", "metrics", "organization", "researcher", "program language", "country", "location", "person", "university"] # Specified entity types for the named entity recognition task. Defaults to empty.
+  use_file: false # whether to use a file for the input text.
+  mode: quick # extraction mode, chosen from quick, detailed, customized. Defaults to quick. See src/config.yaml for more details.
+  update_case: false # whether to update the case repository. Defaults to false.
+  show_trajectory: false # whether to display the extracted intermediate steps.
examples/config/NewsExtraction.yaml
ADDED
@@ -0,0 +1,15 @@
+model:
+  category: DeepSeek # model category, chosen from ChatGPT, DeepSeek, LLaMA, Qwen, ChatGLM, MiniCPM, OneKE.
+  model_name_or_path: deepseek-chat # model name, chosen from the model list of the selected category.
+  api_key: your_api_key # your API key for the model with API service. Not needed for open-source models.
+  base_url: https://api.deepseek.com # base URL for the API service. Not needed for open-source models.
+
+extraction:
+  task: Base # task type, chosen from Base, NER, RE, EE.
+  instruction: Extract key information from the given text. # description of the task. Not needed for the NER, RE, EE tasks.
+  use_file: true # whether to use a file for the input text. Defaults to false.
+  file_path: ./data/input_files/Tulsi_Gabbard_News.html # path to the input file. Not needed if use_file is set to false.
+  output_schema: NewsReport # output schema for the extraction task, selected from the schema repository.
+  mode: customized # extraction mode, chosen from quick, detailed, customized. Defaults to quick. See src/config.yaml for more details.
+  update_case: false # whether to update the case repository. Defaults to false.
+  show_trajectory: false # whether to display the extracted intermediate steps.
examples/config/RE.yaml
ADDED
@@ -0,0 +1,15 @@
+model:
+  category: ChatGPT # model category, chosen from ChatGPT, DeepSeek, LLaMA, Qwen, ChatGLM, MiniCPM, OneKE.
+  model_name_or_path: gpt-4o-mini # model name, chosen from the model list of the selected category.
+  api_key: your_api_key # your API key for the model with API service. Not needed for open-source models.
+  base_url: https://api.openai.com/v1 # base URL for the API service. Not needed for open-source models.
+
+extraction:
+  task: RE # task type, chosen from Base, NER, RE, EE.
+  text: The aid group Doctors Without Borders said that since Saturday , more than 275 wounded people had been admitted and treated at Donka Hospital in the capital of Guinea , Conakry . # input text for the extraction task. Not needed if use_file is set to true.
+  constraint: ["nationality", "country capital", "place of death", "children", "location contains", "place of birth", "place lived", "administrative division of country", "country of administrative divisions", "company", "neighborhood of", "company founders"] # Specified relation types for the relation extraction task. Defaults to empty.
+  truth: {"relation_list": [{"head": "Guinea", "tail": "Conakry", "relation": "country capital"}]} # Ground-truth data for the relation extraction task, structured as a dictionary with a list of relation triples as the value. Required if update_case is set to true.
+  use_file: false # whether to use a file for the input text.
+  mode: quick # extraction mode, chosen from quick, detailed, customized. Defaults to quick. See src/config.yaml for more details.
+  update_case: true # whether to update the case repository. Defaults to false.
+  show_trajectory: false # whether to display the extracted intermediate steps.
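
The truth/update_case pair in this config can also be driven from Python. A minimal sketch, assuming the get_extract_result signature shown in examples/example.py and experiments/dataset_def.py below; the model name and API key are placeholders, and the case-repository behavior is inferred from the config comments:

```python
import sys
sys.path.append("./src")
from models import ChatGPT
from pipeline import Pipeline

model = ChatGPT(model_name_or_path="gpt-4o-mini", api_key="your_api_key")
pipeline = Pipeline(model)

text = ("The aid group Doctors Without Borders said that since Saturday , "
        "more than 275 wounded people had been admitted and treated at "
        "Donka Hospital in the capital of Guinea , Conakry .")
truth = {"relation_list": [{"head": "Guinea", "tail": "Conakry", "relation": "country capital"}]}

# With update_case=True, the reflection step can record this labeled example
# in the case repository for later retrieval.
result, _, _, _ = pipeline.get_extract_result(
    task="RE", text=text,
    constraint=["nationality", "country capital", "place of death"],
    truth=truth, update_case=True,
)
```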
examples/config/Triple2KG.yaml
ADDED
@@ -0,0 +1,21 @@
+model:
+  # Recommended: use the ChatGPT or DeepSeek APIs for the complex Triple task.
+  category: ChatGPT # model category, chosen from ChatGPT, DeepSeek, LLaMA, Qwen, ChatGLM, MiniCPM, OneKE.
+  model_name_or_path: gpt-4o-mini # model name, chosen from the model list of the selected category.
+  api_key: your_api_key # your API key for the model with API service. Not needed for open-source models.
+  base_url: https://api.openai.com/v1 # base URL for the API service. Not needed for open-source models.
+
+extraction:
+  mode: quick # extraction mode, chosen from quick, detailed, customized. Defaults to quick. See src/config.yaml for more details.
+  task: Triple # task type, chosen from Base, NER, RE, EE, and the newly added Triple.
+  use_file: true # whether to use a file for the input text. Defaults to false.
+  file_path: ./data/input_files/Artificial_Intelligence_Wikipedia.txt # path to the input file. Not needed if use_file is set to false.
+  constraint: [["Person", "Place", "Event", "Property"], ["Interpersonal", "Located", "Ownership", "Action"]] # Specified entity or relation types for the Triple extraction task. You can give three lists (subject, relation, and object types), two lists (entity and relation types), or one list (entity types only).
+  update_case: false # whether to update the case repository. Defaults to false.
+  show_trajectory: false # whether to display the extracted intermediate steps.
+
+# construct: # (Optional) Set the construct field to build a Knowledge Graph; otherwise delete this field.
+#   database: Neo4j # database type; currently only Neo4j is supported.
+#   url: neo4j://localhost:7687 # your database URL; Neo4j's default port is 7687.
+#   username: your_username # your database username.
+#   password: "your_password" # your database password.
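
The same Triple-to-KG flow can be scripted directly instead of via the config file. A minimal sketch, assuming the repository root as working directory and a reachable Neo4j instance; it reuses get_extract_result (see examples/example.py below) and the converter functions from src/construct/convert.py, with placeholder credentials:

```python
import sys, json
sys.path.append("./src")
from models import ChatGPT
from pipeline import Pipeline
from construct import generate_cypher_statements, execute_cypher_statements

model = ChatGPT(model_name_or_path="gpt-4o-mini", api_key="your_api_key")
pipeline = Pipeline(model)

# Two lists: entity types and relation types, as in the YAML above.
constraint = [["Person", "Place", "Event", "Property"],
              ["Interpersonal", "Located", "Ownership", "Action"]]
result, _, _, _ = pipeline.get_extract_result(
    task="Triple",
    text="David spends his happiest day with Monica.",
    constraint=constraint,
)

# generate_cypher_statements expects a JSON string, so serialize the result.
statements = generate_cypher_statements(json.dumps(result))
execute_cypher_statements(
    uri="neo4j://localhost:7687",  # placeholder connection details
    user="your_username",
    password="your_password",
    cypher_statements=statements,
)
```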
examples/example.py
ADDED
@@ -0,0 +1,17 @@
+import sys
+sys.path.append("./src")
+from models import *
+from pipeline import *
+import json
+
+# model configuration
+model = ChatGPT(model_name_or_path="your_model_name_or_path", api_key="your_api_key")
+pipeline = Pipeline(model)
+
+# extraction configuration
+Task = "NER"
+Text = "Finally , every other year , ELRA organizes a major conference LREC , the International Language Resources and Evaluation Conference."
+Constraint = ["algorithm", "conference", "else", "product", "task", "field", "metrics", "organization", "researcher", "program language", "country", "location", "person", "university"]  # entity types for the NER task (see examples/config/NER.yaml)
+
+# get extraction result
+result, trajectory, frontend_schema, frontend_res = pipeline.get_extract_result(task=Task, text=Text, constraint=Constraint, show_trajectory=True)
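
As a usage note (a sketch; it assumes the script is run from the repository root): result is a plain dictionary, so it can be serialized with the json module already imported above, which is presumably how the files under examples/results/ below were produced.

```python
# Persist the extraction result; examples/results/NER.json has this shape.
with open("examples/results/NER.json", "w") as f:
    json.dump(result, f, indent=4)
```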
examples/results/BookExtraction.json
ADDED
@@ -0,0 +1,48 @@
+{
+    "main_characters": [
+        {
+            "name": "Mr. Dursley",
+            "description": "The director of a firm called Grunnings, a big, beefy man with hardly any neck and a large mustache."
+        },
+        {
+            "name": "Mrs. Dursley",
+            "description": "Thin and blonde, with nearly twice the usual amount of neck, spends time spying on neighbors."
+        },
+        {
+            "name": "Dudley Dursley",
+            "description": "The small son of Mr. and Mrs. Dursley, considered by them to be the finest boy anywhere."
+        },
+        {
+            "name": "Albus Dumbledore",
+            "description": "A tall, thin, and very old man with long silver hair and a purple cloak, who arrives mysteriously."
+        },
+        {
+            "name": "Professor McGonagall",
+            "description": "A severe-looking woman who can transform into a cat, wearing an emerald cloak."
+        },
+        {
+            "name": "Voldemort",
+            "description": "The dark wizard who has caused fear and chaos, but has mysteriously disappeared."
+        },
+        {
+            "name": "Harry Potter",
+            "description": "The young boy who survived Voldemort's attack, becoming a significant figure in the wizarding world."
+        },
+        {
+            "name": "Lily Potter",
+            "description": "Harry's mother, who is mentioned as having been killed by Voldemort."
+        },
+        {
+            "name": "James Potter",
+            "description": "Harry's father, who is mentioned as having been killed by Voldemort."
+        },
+        {
+            "name": "Hagrid",
+            "description": "A giant man who is caring and emotional about Harry's situation."
+        }
+    ],
+    "background_setting": {
+        "location": "Number four, Privet Drive, Suburban",
+        "time_period": "A dull, gray Tuesday morning, Late 20th Century"
+    }
+}
examples/results/EE.json
ADDED
@@ -0,0 +1,13 @@
+{
+    "event_list": [
+        {
+            "event_type": "data breach",
+            "event_trigger": "compromised",
+            "event_argument": {
+                "number of victim": 326000,
+                "compromised data": "personal information contained in email accounts",
+                "victim": "individuals whose personal information was compromised"
+            }
+        }
+    ]
+}
examples/results/NER.json
ADDED
@@ -0,0 +1,16 @@
+{
+    "entity_list": [
+        {
+            "name": "ELRA",
+            "type": "organization"
+        },
+        {
+            "name": "LREC",
+            "type": "conference"
+        },
+        {
+            "name": "International Language Resources and Evaluation Conference",
+            "type": "conference"
+        }
+    ]
+}
examples/results/NewsExtraction.json
ADDED
@@ -0,0 +1,51 @@
+{
+    "title": "Who is Tulsi Gabbard? Meet Trump's pick for director of national intelligence",
+    "summary": "Tulsi Gabbard, President-elect Donald Trump\u2019s choice for director of national intelligence, could face a challenging Senate confirmation battle due to her lack of intelligence experience and controversial views.",
+    "publication_date": "December 4, 2024",
+    "keywords": [
+        "Tulsi Gabbard",
+        "Donald Trump",
+        "director of national intelligence",
+        "confirmation battle",
+        "intelligence agencies",
+        "Russia",
+        "Syria",
+        "Bashar al-Assad"
+    ],
+    "events": [
+        {
+            "name": "Tulsi Gabbard's nomination for director of national intelligence",
+            "people_involved": [
+                {
+                    "name": "Tulsi Gabbard",
+                    "identity": "Former U.S. Representative",
+                    "role": "Nominee for director of national intelligence"
+                },
+                {
+                    "name": "Donald Trump",
+                    "identity": "President-elect",
+                    "role": "Nominator"
+                },
+                {
+                    "name": "Tammy Duckworth",
+                    "identity": "Democratic Senator",
+                    "role": "Critic of Gabbard's nomination"
+                },
+                {
+                    "name": "Olivia Troye",
+                    "identity": "Former national security official",
+                    "role": "Commentator on Gabbard's potential impact"
+                }
+            ],
+            "process": "Gabbard's nomination is expected to lead to a Senate confirmation battle."
+        }
+    ],
+    "quotes": {
+        "Tammy Duckworth": "The U.S. intelligence community has identified her as having troubling relationships with America\u2019s foes, and so my worry is that she couldn\u2019t pass a background check.",
+        "Olivia Troye": "If Gabbard is confirmed, America\u2019s allies may not share as much information with the U.S."
+    },
+    "viewpoints": [
+        "Gabbard's lack of intelligence experience raises concerns about her ability to oversee 18 intelligence agencies.",
+        "Her past comments and meetings with foreign adversaries have led to accusations of being a national security risk."
+    ]
+}
examples/results/RE.json
ADDED
@@ -0,0 +1,9 @@
+{
+    "relation_list": [
+        {
+            "head": "Guinea",
+            "tail": "Conakry",
+            "relation": "country capital"
+        }
+    ]
+}
examples/results/TripleExtraction.json
ADDED
@@ -0,0 +1,156 @@
+{
+    "triple_list": [
+        {
+            "head": "sea levels",
+            "head_type": "Property",
+            "relation": "wiped out",
+            "relation_type": "Action",
+            "tail": "coastal cities",
+            "tail_type": "Place"
+        },
+        {
+            "head": "nations",
+            "head_type": "Person",
+            "relation": "created",
+            "relation_type": "Action",
+            "tail": "mechas",
+            "tail_type": "Property"
+        },
+        {
+            "head": "David",
+            "head_type": "Person",
+            "relation": "given to",
+            "relation_type": "Ownership",
+            "tail": "Henry and Monica",
+            "tail_type": "Person"
+        },
+        {
+            "head": "Monica",
+            "head_type": "Person",
+            "relation": "feels uncomfortable",
+            "relation_type": "Interpersonal",
+            "tail": "David",
+            "tail_type": "Person"
+        },
+        {
+            "head": "David",
+            "head_type": "Person",
+            "relation": "befriends",
+            "relation_type": "Interpersonal",
+            "tail": "Teddy",
+            "tail_type": "Person"
+        },
+        {
+            "head": "Martin",
+            "head_type": "Person",
+            "relation": "goads",
+            "relation_type": "Action",
+            "tail": "David",
+            "tail_type": "Person"
+        },
+        {
+            "head": "David",
+            "head_type": "Person",
+            "relation": "blamed for",
+            "relation_type": "Action",
+            "tail": "incident",
+            "tail_type": "Event"
+        },
+        {
+            "head": "Monica",
+            "head_type": "Person",
+            "relation": "returns David to",
+            "relation_type": "Ownership",
+            "tail": "creators",
+            "tail_type": "Person"
+        },
+        {
+            "head": "David",
+            "head_type": "Person",
+            "relation": "decides to find",
+            "relation_type": "Action",
+            "tail": "Blue Fairy",
+            "tail_type": "Property"
+        },
+        {
+            "head": "David",
+            "head_type": "Person",
+            "relation": "pleads for",
+            "relation_type": "Action",
+            "tail": "his life",
+            "tail_type": "Event"
+        },
+        {
+            "head": "David",
+            "head_type": "Person",
+            "relation": "meets",
+            "relation_type": "Interpersonal",
+            "tail": "Professor Hobby",
+            "tail_type": "Person"
+        },
+        {
+            "head": "David",
+            "head_type": "Person",
+            "relation": "attempts",
+            "relation_type": "Action",
+            "tail": "suicide",
+            "tail_type": "Event"
+        },
+        {
+            "head": "Joe",
+            "head_type": "Person",
+            "relation": "rescues",
+            "relation_type": "Action",
+            "tail": "David",
+            "tail_type": "Person"
+        },
+        {
+            "head": "David",
+            "head_type": "Person",
+            "relation": "asks statue to turn him into",
+            "relation_type": "Action",
+            "tail": "real boy",
+            "tail_type": "Property"
+        },
+        {
+            "head": "humanity",
+            "head_type": "Person",
+            "relation": "is extinct",
+            "relation_type": "Action",
+            "tail": "future",
+            "tail_type": "Event"
+        },
+        {
+            "head": "Specialists",
+            "head_type": "Person",
+            "relation": "resurrect",
+            "relation_type": "Action",
+            "tail": "David and Teddy",
+            "tail_type": "Person"
+        },
+        {
+            "head": "Monica",
+            "head_type": "Person",
+            "relation": "can live for",
+            "relation_type": "Property",
+            "tail": "one day",
+            "tail_type": "Property"
+        },
+        {
+            "head": "David",
+            "head_type": "Person",
+            "relation": "spends",
+            "relation_type": "Action",
+            "tail": "happiest day with Monica",
+            "tail_type": "Event"
+        },
+        {
+            "head": "Monica",
+            "head_type": "Person",
+            "relation": "tells",
+            "relation_type": "Interpersonal",
+            "tail": "David",
+            "tail_type": "Person"
+        }
+    ]
+}
experiments/dataset_def.py
ADDED
@@ -0,0 +1,181 @@
+import os
+import json
+import random
+from utils import *
+from pipeline import *
+current_dir = os.path.dirname(os.path.abspath(__file__))
+DATA_DIR = os.path.join(current_dir, "../data/datasets")
+OUTPUT_DIR = os.path.join(current_dir, "results")
+
+class BaseDataset:
+    def __init__(self):
+        pass
+
+    def __getitem__(self, idx):
+        return None
+
+    def __len__(self):
+        return None
+
+    def evaluate(self, idx, answer):
+        return None
+
+class NERDataset(BaseDataset):
+    def __init__(self, name=None, task="NER", data_dir=f"{DATA_DIR}/CrossNER", output_dir=f"{OUTPUT_DIR}", train=False):
+        self.name = name
+        self.task = task
+        self.data_dir = data_dir
+        self.output_dir = output_dir
+        self.test_file = json.load(open(f"{data_dir}/train.json")) if train else json.load(open(f"{data_dir}/test.json"))
+        self.schema = str(json.load(open(f"{data_dir}/class.json")))
+        self.retry = 2
+
+    def evaluate(self, llm: BaseEngine, mode="", sample=None, random_sample=False, update_case=False):
+        # initialize
+        sample = len(self.test_file) if sample is None else sample
+        if random_sample:
+            test_file = random.sample(self.test_file, sample)
+        else:
+            test_file = self.test_file[:sample]
+        total_precision, total_recall, total_f1 = 0, 0, 0
+        num_items = 0
+        output_path = f"{self.output_dir}/{self.name}_{self.task}_{mode}_{llm.name}_sample{sample}.jsonl"
+        print("Results will be saved to: ", output_path)
+
+        # predict and evaluate
+        pipeline = Pipeline(llm=llm)
+        for item in test_file:
+            try:
+                # get prediction
+                num_items += 1
+                truth = list(item.items())[1]
+                truth = {truth[0]: truth[1]}
+                pred_set = set()
+                for attempt in range(self.retry):
+                    pred_result, pred_detailed, _, _ = pipeline.get_extract_result(task=self.task, text=item['sentence'], constraint=self.schema, mode=mode, truth=truth, update_case=update_case)
+                    try:
+                        pred_result = pred_result['entity_list']
+                        pred_set = dict_list_to_set(pred_result)
+                        break
+                    except Exception as e:
+                        print(f"Failed to parse result: {pred_result}, retrying... Exception: {e}")
+
+                # evaluate
+                truth_result = item["entity_list"]
+                truth_set = dict_list_to_set(truth_result)
+                print(truth_set)
+                print(pred_set)
+
+                precision, recall, f1_score = calculate_metrics(truth_set, pred_set)
+                total_precision += precision
+                total_recall += recall
+                total_f1 += f1_score
+
+                pred_detailed["pred"] = pred_result
+                pred_detailed["truth"] = truth_result
+                pred_detailed["metrics"] = {"precision": precision, "recall": recall, "f1_score": f1_score}
+                res_detailed = {"id": num_items}
+                res_detailed.update(pred_detailed)
+                with open(output_path, 'a') as file:
+                    file.write(json.dumps(res_detailed) + '\n')
+            except Exception as e:
+                print(f"Exception occurred: {e}")
+                print(f"idx: {num_items}")
+
+        # calculate overall metrics
+        if num_items > 0:
+            avg_precision = total_precision / num_items
+            avg_recall = total_recall / num_items
+            avg_f1 = total_f1 / num_items
+            overall_metrics = {
+                "total_items": num_items,
+                "average_precision": avg_precision,
+                "average_recall": avg_recall,
+                "average_f1_score": avg_f1
+            }
+            with open(output_path, 'a') as file:
+                file.write(json.dumps(overall_metrics) + '\n\n')
+            print(f"Overall Metrics:\nTotal Items: {num_items}\nAverage Precision: {avg_precision:.4f}\nAverage Recall: {avg_recall:.4f}\nAverage F1 Score: {avg_f1:.4f}")
+            return avg_f1  # return the average F1 so callers can report it
+        else:
+            print("No items processed.")
+
+class REDataset(BaseDataset):
+    def __init__(self, name=None, task="RE", data_dir=f"{DATA_DIR}/NYT11", output_dir=f"{OUTPUT_DIR}", train=False):
+        self.name = name
+        self.task = task
+        self.data_dir = data_dir
+        self.output_dir = output_dir
+        self.test_file = json.load(open(f"{data_dir}/train.json")) if train else json.load(open(f"{data_dir}/test.json"))
+        self.schema = str(json.load(open(f"{data_dir}/class.json")))
+        self.retry = 2
+
+    def evaluate(self, llm: BaseEngine, mode="", sample=None, random_sample=False, update_case=False):
+        # initialize
+        sample = len(self.test_file) if sample is None else sample
+        if random_sample:
+            test_file = random.sample(self.test_file, sample)
+        else:
+            test_file = self.test_file[:sample]
+        total_precision, total_recall, total_f1 = 0, 0, 0
+        num_items = 0
+        output_path = f"{self.output_dir}/{self.name}_{self.task}_{mode}_{llm.name}_sample{sample}.jsonl"
+        print("Results will be saved to: ", output_path)
+
+        # predict and evaluate
+        pipeline = Pipeline(llm=llm)
+        for item in test_file:
+            try:
+                # get prediction
+                num_items += 1
+                truth = list(item.items())[1]
+                truth = {truth[0]: truth[1]}
+                pred_set = set()
+                for attempt in range(self.retry):
+                    pred_result, pred_detailed, _, _ = pipeline.get_extract_result(task=self.task, text=item['sentence'], constraint=self.schema, mode=mode, truth=truth, update_case=update_case)
+                    try:
+                        pred_result = pred_result['relation_list']
+                        pred_set = dict_list_to_set(pred_result)
+                        break
+                    except Exception as e:
+                        print(f"Failed to parse result: {pred_result}, retrying... Exception: {e}")
+
+                # evaluate
+                truth_result = item["relation_list"]
+                truth_set = dict_list_to_set(truth_result)
+                print(truth_set)
+                print(pred_set)
+
+                precision, recall, f1_score = calculate_metrics(truth_set, pred_set)
+                total_precision += precision
+                total_recall += recall
+                total_f1 += f1_score
+
+                pred_detailed["pred"] = pred_result
+                pred_detailed["truth"] = truth_result
+                pred_detailed["metrics"] = {"precision": precision, "recall": recall, "f1_score": f1_score}
+                res_detailed = {"id": num_items}
+                res_detailed.update(pred_detailed)
+                with open(output_path, 'a') as file:
+                    file.write(json.dumps(res_detailed) + '\n')
+            except Exception as e:
+                print(f"Exception occurred: {e}")
+                print(f"idx: {num_items}")
+
+        # calculate overall metrics
+        if num_items > 0:
+            avg_precision = total_precision / num_items
+            avg_recall = total_recall / num_items
+            avg_f1 = total_f1 / num_items
+            overall_metrics = {
+                "total_items": num_items,
+                "average_precision": avg_precision,
+                "average_recall": avg_recall,
+                "average_f1_score": avg_f1
+            }
+            with open(output_path, 'a') as file:
+                file.write(json.dumps(overall_metrics) + '\n\n')
+            print(f"Overall Metrics:\nTotal Items: {num_items}\nAverage Precision: {avg_precision:.4f}\nAverage Recall: {avg_recall:.4f}\nAverage F1 Score: {avg_f1:.4f}")
+            return avg_f1  # return the average F1 so callers can report it
+        else:
+            print("No items processed.")
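
The evaluate methods lean on dict_list_to_set and calculate_metrics from src/utils, which this diff does not show. A minimal sketch of the semantics they are assumed to have, namely set-based precision/recall/F1 over hashable tuples; the real helpers may differ:

```python
def dict_list_to_set(dict_list):
    # Turn a list of dicts into a set of key-sorted value tuples so that
    # predictions and ground truth can be compared with set operations.
    return {tuple(str(d[k]) for k in sorted(d)) for d in dict_list}

def calculate_metrics(truth_set, pred_set):
    # Standard precision/recall/F1 between two sets of tuples.
    tp = len(truth_set & pred_set)
    precision = tp / len(pred_set) if pred_set else 0.0
    recall = tp / len(truth_set) if truth_set else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    return precision, recall, f1
```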
experiments/run_ner.py
ADDED
@@ -0,0 +1,15 @@
+import sys
+sys.path.append("./src")
+from models import *
+from dataset_def import *
+name = "crossner-"
+data_dir = "./data/datasets/CrossNER/"
+model = ChatGPT(model_name_or_path="gpt-4o-mini", api_key="your_api_key", base_url="https://api.openai.com/v1")
+tasklist = ["ai", "literature", "music", "politics", "science"]
+for task in tasklist:
+    task_name = name + task
+    task_data_dir = data_dir + task
+    dataset = NERDataset(name=task_name, data_dir=task_data_dir)
+    mode = "quick"
+    f1_score = dataset.evaluate(llm=model, mode=mode)
+    print(f"Task: {task_name}, f1_score: {f1_score}")
experiments/run_re.py
ADDED
@@ -0,0 +1,10 @@
+import sys
+sys.path.append("./src")
+from models import *
+from dataset_def import *
+data_dir = "./data/datasets/NYT11/"
+model = LLaMA("meta-llama/Meta-Llama-3-8B-Instruct")
+dataset = REDataset(name="NYT11", data_dir=data_dir)
+f1_score = dataset.evaluate(llm=model, mode="quick")
+print("f1_score: ", f1_score)
+
figs/logo.png
ADDED
figs/main.png
ADDED
requirements.txt
CHANGED
@@ -15,4 +15,5 @@ sentencepiece==0.2.0
 protobuf==5.29.3
 bitsandbytes==0.45.0
 vllm==0.6.0
-gradio==4.44.0
+gradio==4.44.0
+neo4j==5.28.1
src/config.yaml
CHANGED
@@ -6,6 +6,7 @@ agent:
   default_ner: Extract the Named Entities in the given text.
   default_re: Extract Relationships between Named Entities in the given text.
   default_ee: Extract the Events in the given text.
+  default_triple: Extract the Triples (subject, relation, object) from the given text, aiming to extract all the relationships for each entity.
   chunk_token_limit: 1024
   mode:
     quick:
@@ -16,5 +17,5 @@ agent:
       extraction_agent: extract_information_with_case
       reflection_agent: reflect_with_case
     customized:
-      schema_agent:
+      schema_agent: get_retrieved_schema
       extraction_agent: extract_information_direct
src/construct/__init__.py
ADDED
@@ -0,0 +1 @@
+from .convert import *
src/construct/convert.py
ADDED
@@ -0,0 +1,201 @@
+import json
+import re
+from neo4j import GraphDatabase
+
+
+def sanitize_string(input_str, max_length=255):
+    """
+    Process the input string to ensure it meets the database requirements.
+    """
+    # step 1: Replace invalid characters
+    input_str = re.sub(r'[^a-zA-Z0-9_]', '_', input_str)
+
+    # step 2: Add a prefix if it starts with a digit
+    if input_str[0].isdigit():
+        input_str = 'num' + input_str
+
+    # step 3: Limit length
+    if len(input_str) > max_length:
+        input_str = input_str[:max_length]
+
+    return input_str
+
+
+def generate_cypher_statements(data):
+    """
+    Generates Cypher query statements based on the provided JSON data.
+    """
+    cypher_statements = []
+    parsed_data = json.loads(data)
+
+    def create_statement(triple):
+        head = triple.get("head")
+        head_type = triple.get("head_type")
+        relation = triple.get("relation")
+        relation_type = triple.get("relation_type")
+        tail = triple.get("tail")
+        tail_type = triple.get("tail_type")
+
+        # head_safe = sanitize_string(head) if head else None
+        head_type_safe = sanitize_string(head_type) if head_type else None
+        # relation_safe = sanitize_string(relation) if relation else None
+        relation_type_safe = sanitize_string(relation_type) if relation_type else None
+        # tail_safe = sanitize_string(tail) if tail else None
+        tail_type_safe = sanitize_string(tail_type) if tail_type else None
+
+        statement = ""
+        if head:
+            if head_type_safe:
+                statement += f'MERGE (a:{head_type_safe} {{name: "{head}"}}) '
+            else:
+                statement += f'MERGE (a:UNTYPED {{name: "{head}"}}) '
+        if tail:
+            if tail_type_safe:
+                statement += f'MERGE (b:{tail_type_safe} {{name: "{tail}"}}) '
+            else:
+                statement += f'MERGE (b:UNTYPED {{name: "{tail}"}}) '
+        if relation:
+            if head and tail:  # Only create the relation if both head and tail exist.
+                if relation_type_safe:
+                    statement += f'MERGE (a)-[:{relation_type_safe} {{name: "{relation}"}}]->(b);'
+                else:
+                    statement += f'MERGE (a)-[:UNTYPED {{name: "{relation}"}}]->(b);'
+            else:
+                statement += ';' if statement != "" else ''
+        else:
+            if relation_type_safe:  # If no relation is provided, create the relation from `relation_type`.
+                statement += f'MERGE (a)-[:{relation_type_safe} {{name: "{relation_type_safe}"}}]->(b);'
+            else:
+                statement += ';' if statement != "" else ''
+        return statement
+
+    if "triple_list" in parsed_data:
+        for triple in parsed_data["triple_list"]:
+            cypher_statements.append(create_statement(triple))
+    else:
+        cypher_statements.append(create_statement(parsed_data))
+
+    return cypher_statements
+
+
+def execute_cypher_statements(uri, user, password, cypher_statements):
+    """
+    Executes the generated Cypher query statements.
+    """
+    driver = GraphDatabase.driver(uri, auth=(user, password))
+
+    with driver.session() as session:
+        for statement in cypher_statements:
+            session.run(statement)
+            print(f"Executed: {statement}")
+
+    # Write the executed Cypher statements to a text file if you want.
+    # with open("executed_statements.txt", 'a') as f:
+    #     for statement in cypher_statements:
+    #         f.write(statement + '\n')
+    #     f.write('\n')
+
+    driver.close()
+
+
+# Here is a test of your database connection:
+if __name__ == "__main__":
+    # test_data 1: contains a list of triples
+    test_data = '''
+    {
+        "triple_list": [
+            {
+                "head": "J.K. Rowling",
+                "head_type": "Person",
+                "relation": "wrote",
+                "relation_type": "Actions",
+                "tail": "Fantastic Beasts and Where to Find Them",
+                "tail_type": "Book"
+            },
+            {
+                "head": "Fantastic Beasts and Where to Find Them",
+                "head_type": "Book",
+                "relation": "extra section of",
+                "relation_type": "Affiliation",
+                "tail": "Harry Potter Series",
+                "tail_type": "Book"
+            },
+            {
+                "head": "J.K. Rowling",
+                "head_type": "Person",
+                "relation": "wrote",
+                "relation_type": "Actions",
+                "tail": "Harry Potter Series",
+                "tail_type": "Book"
+            },
+            {
+                "head": "Harry Potter Series",
+                "head_type": "Book",
+                "relation": "create",
+                "relation_type": "Actions",
+                "tail": "Dumbledore",
+                "tail_type": "Person"
+            },
+            {
+                "head": "Fantastic Beasts and Where to Find Them",
+                "head_type": "Book",
+                "relation": "mention",
+                "relation_type": "Actions",
+                "tail": "Dumbledore",
+                "tail_type": "Person"
+            },
+            {
+                "head": "Voldemort",
+                "head_type": "Person",
+                "relation": "afraid",
+                "relation_type": "Emotion",
+                "tail": "Dumbledore",
+                "tail_type": "Person"
+            },
+            {
+                "head": "Voldemort",
+                "head_type": "Person",
+                "relation": "robs",
+                "relation_type": "Actions",
+                "tail": "the Elder Wand",
+                "tail_type": "Weapon"
+            },
+            {
+                "head": "the Elder Wand",
+                "head_type": "Weapon",
+                "relation": "belong to",
+                "relation_type": "Affiliation",
+                "tail": "Dumbledore",
+                "tail_type": "Person"
+            }
+        ]
+    }
+    '''
+
+    # test_data 2: contains a single triple
+    # test_data = '''
+    # {
+    #     "head": "Christopher Nolan",
+    #     "head_type": "Person",
+    #     "relation": "directed",
+    #     "relation_type": "Action",
+    #     "tail": "Inception",
+    #     "tail_type": "Movie"
+    # }
+    # '''
+
+    # Generate Cypher query statements
+    cypher_statements = generate_cypher_statements(test_data)
+
+    # Print the generated Cypher query statements
+    for statement in cypher_statements:
+        print(statement)
+    print("\n")
+
+    # Execute the generated Cypher query statements
+    execute_cypher_statements(
+        uri="neo4j://localhost:7687",  # your URI
+        user="your_username",          # your username
+        password="your_password",      # your password
+        cypher_statements=cypher_statements,
+    )
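
One design caveat: create_statement interpolates entity names straight into the Cypher text, so a name containing a double quote would break the statement. A hedged alternative sketch that passes the name values as driver parameters instead; Cypher cannot parameterize labels or relationship types, so sanitize_string from convert.py above is still needed for those, and the triple fields are assumed complete:

```python
from neo4j import GraphDatabase
from construct.convert import sanitize_string  # assumes ./src is on sys.path

def write_triple(session, triple):
    # Labels and relationship types must be sanitized identifiers;
    # the name values travel safely as query parameters.
    head_label = sanitize_string(triple["head_type"])
    tail_label = sanitize_string(triple["tail_type"])
    rel_type = sanitize_string(triple["relation_type"])
    query = (
        f'MERGE (a:{head_label} {{name: $head}}) '
        f'MERGE (b:{tail_label} {{name: $tail}}) '
        f'MERGE (a)-[:{rel_type} {{name: $relation}}]->(b)'
    )
    session.run(query, head=triple["head"], tail=triple["tail"],
                relation=triple["relation"])

# Usage (placeholder credentials):
# driver = GraphDatabase.driver("neo4j://localhost:7687", auth=("user", "pass"))
# with driver.session() as session:
#     write_triple(session, {"head": "David", "head_type": "Person",
#                            "relation": "befriends", "relation_type": "Interpersonal",
#                            "tail": "Teddy", "tail_type": "Person"})
```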
src/models/llm_def.py
CHANGED
@@ -22,7 +22,7 @@ class BaseEngine:
|
|
22 |
self.top_p = 0.9
|
23 |
self.max_tokens = 1024
|
24 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
25 |
-
|
26 |
def get_chat_response(self, prompt):
|
27 |
raise NotImplementedError
|
28 |
|
@@ -30,7 +30,7 @@ class BaseEngine:
|
|
30 |
self.temperature = temperature
|
31 |
self.top_p = top_p
|
32 |
self.max_tokens = max_tokens
|
33 |
-
|
34 |
class LLaMA(BaseEngine):
|
35 |
def __init__(self, model_name_or_path: str):
|
36 |
super().__init__(model_name_or_path)
|
@@ -61,7 +61,7 @@ class LLaMA(BaseEngine):
|
|
61 |
top_p=self.top_p,
|
62 |
)
|
63 |
return outputs[0]["generated_text"][-1]['content'].strip()
|
64 |
-
|
65 |
class Qwen(BaseEngine):
|
66 |
def __init__(self, model_name_or_path: str):
|
67 |
super().__init__(model_name_or_path)
|
@@ -72,7 +72,7 @@ class Qwen(BaseEngine):
|
|
72 |
torch_dtype="auto",
|
73 |
device_map="auto"
|
74 |
)
|
75 |
-
|
76 |
def get_chat_response(self, prompt):
|
77 |
messages = [
|
78 |
{"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
|
@@ -94,7 +94,7 @@ class Qwen(BaseEngine):
|
|
94 |
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
|
95 |
]
|
96 |
response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
|
97 |
-
|
98 |
return response
|
99 |
|
100 |
class MiniCPM(BaseEngine):
|
@@ -125,7 +125,7 @@ class MiniCPM(BaseEngine):
|
|
125 |
model_outputs[i][len(model_inputs[i]):] for i in range(len(model_inputs))
|
126 |
]
|
127 |
response = self.tokenizer.batch_decode(output_token_ids, skip_special_tokens=True)[0].strip()
|
128 |
-
|
129 |
return response
|
130 |
|
131 |
class ChatGLM(BaseEngine):
|
@@ -155,7 +155,7 @@ class ChatGLM(BaseEngine):
|
|
155 |
)
|
156 |
model_outputs = model_outputs[:, model_inputs['input_ids'].shape[1]:]
|
157 |
response = self.tokenizer.batch_decode(model_outputs, skip_special_tokens=True)[0].strip()
|
158 |
-
|
159 |
return response
|
160 |
|
161 |
class OneKE(BaseEngine):
|
@@ -164,7 +164,7 @@ class OneKE(BaseEngine):
|
|
164 |
self.name = "OneKE"
|
165 |
self.model_id = model_name_or_path
|
166 |
config = AutoConfig.from_pretrained(self.model_id, trust_remote_code=True)
|
167 |
-
quantization_config=BitsAndBytesConfig(
|
168 |
load_in_4bit=True,
|
169 |
llm_int8_threshold=6.0,
|
170 |
llm_int8_has_fp16_weight=False,
|
@@ -175,12 +175,12 @@ class OneKE(BaseEngine):
|
|
175 |
self.model = AutoModelForCausalLM.from_pretrained(
|
176 |
self.model_id,
|
177 |
config=config,
|
178 |
-
device_map="auto",
|
179 |
quantization_config=quantization_config,
|
180 |
torch_dtype=torch.bfloat16,
|
181 |
trust_remote_code=True,
|
182 |
)
|
183 |
-
|
184 |
def get_chat_response(self, prompt):
|
185 |
system_prompt = '<<SYS>>\nYou are a helpful assistant. δ½ ζ―δΈδΈͺδΉδΊε©δΊΊηε©ζγ\n<</SYS>>\n\n'
|
186 |
sintruct = '[INST] ' + system_prompt + prompt + '[/INST]'
|
@@ -191,9 +191,9 @@ class OneKE(BaseEngine):
|
|
191 |
generation_output = generation_output.sequences[0]
|
192 |
generation_output = generation_output[input_length:]
|
193 |
response = self.tokenizer.decode(generation_output, skip_special_tokens=True)
|
194 |
-
|
195 |
return response
|
196 |
-
|
197 |
class ChatGPT(BaseEngine):
|
198 |
def __init__(self, model_name_or_path: str, api_key: str, base_url=openai.base_url):
|
199 |
self.name = "ChatGPT"
|
@@ -207,7 +207,7 @@ class ChatGPT(BaseEngine):
|
|
207 |
else:
|
208 |
self.api_key = os.environ["OPENAI_API_KEY"]
|
209 |
self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
|
210 |
-
|
211 |
def get_chat_response(self, input):
|
212 |
response = self.client.chat.completions.create(
|
213 |
model=self.model,
|
@@ -234,7 +234,7 @@ class DeepSeek(BaseEngine):
|
|
234 |
else:
|
235 |
self.api_key = os.environ["DEEPSEEK_API_KEY"]
|
236 |
self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
|
237 |
-
|
238 |
def get_chat_response(self, input):
|
239 |
response = self.client.chat.completions.create(
|
240 |
model=self.model,
|
@@ -258,7 +258,7 @@ class LocalServer(BaseEngine):
|
|
258 |
self.max_tokens = 1024
|
259 |
self.api_key = "EMPTY_API_KEY"
|
260 |
self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
|
261 |
-
|
262 |
def get_chat_response(self, input):
|
263 |
try:
|
264 |
response = self.client.chat.completions.create(
|
@@ -276,4 +276,3 @@ class LocalServer(BaseEngine):
|
|
276 |
print("Error: Unable to connect to the server. Please check if the vllm service is running and the port is 8080.")
|
277 |
except Exception as e:
|
278 |
print(f"Error: {e}")
|
279 |
-
|
|
|
22 |
self.top_p = 0.9
|
23 |
self.max_tokens = 1024
|
24 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
25 |
+
|
26 |
    def get_chat_response(self, prompt):
        raise NotImplementedError

@@ ... @@
        self.temperature = temperature
        self.top_p = top_p
        self.max_tokens = max_tokens
+
class LLaMA(BaseEngine):
    def __init__(self, model_name_or_path: str):
        super().__init__(model_name_or_path)
@@ ... @@
            top_p=self.top_p,
        )
        return outputs[0]["generated_text"][-1]['content'].strip()
+
class Qwen(BaseEngine):
    def __init__(self, model_name_or_path: str):
        super().__init__(model_name_or_path)
@@ ... @@
            torch_dtype="auto",
            device_map="auto"
        )
+
    def get_chat_response(self, prompt):
        messages = [
            {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
@@ ... @@
            output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]
        response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
+
        return response

class MiniCPM(BaseEngine):
@@ ... @@
            model_outputs[i][len(model_inputs[i]):] for i in range(len(model_inputs))
        ]
        response = self.tokenizer.batch_decode(output_token_ids, skip_special_tokens=True)[0].strip()
+
        return response

class ChatGLM(BaseEngine):
@@ ... @@
        )
        model_outputs = model_outputs[:, model_inputs['input_ids'].shape[1]:]
        response = self.tokenizer.batch_decode(model_outputs, skip_special_tokens=True)[0].strip()
+
        return response

class OneKE(BaseEngine):
@@ ... @@
        self.name = "OneKE"
        self.model_id = model_name_or_path
        config = AutoConfig.from_pretrained(self.model_id, trust_remote_code=True)
+        quantization_config=BitsAndBytesConfig(
            load_in_4bit=True,
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
@@ ... @@
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_id,
            config=config,
+            device_map="auto",
            quantization_config=quantization_config,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
+
    def get_chat_response(self, prompt):
        system_prompt = '<<SYS>>\nYou are a helpful assistant. 你是一个乐于助人的助手。\n<</SYS>>\n\n'
        sintruct = '[INST] ' + system_prompt + prompt + '[/INST]'
@@ ... @@
        generation_output = generation_output.sequences[0]
        generation_output = generation_output[input_length:]
        response = self.tokenizer.decode(generation_output, skip_special_tokens=True)
+
        return response
+
class ChatGPT(BaseEngine):
    def __init__(self, model_name_or_path: str, api_key: str, base_url=openai.base_url):
        self.name = "ChatGPT"
@@ ... @@
        else:
            self.api_key = os.environ["OPENAI_API_KEY"]
        self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
+
    def get_chat_response(self, input):
        response = self.client.chat.completions.create(
            model=self.model,
@@ ... @@
        else:
            self.api_key = os.environ["DEEPSEEK_API_KEY"]
        self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
+
    def get_chat_response(self, input):
        response = self.client.chat.completions.create(
            model=self.model,
@@ ... @@
        self.max_tokens = 1024
        self.api_key = "EMPTY_API_KEY"
        self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
+
    def get_chat_response(self, input):
        try:
            response = self.client.chat.completions.create(
@@ ... @@
            print("Error: Unable to connect to the server. Please check if the vllm service is running and the port is 8080.")
        except Exception as e:
            print(f"Error: {e}")
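Aside (illustration only, not part of the diff): every engine class above exposes the same `get_chat_response` interface, which is what lets the pipeline swap backends freely. A minimal sketch, assuming the constructors shown in this file and the `src/` import layout:

```python
# Sketch only -- assumes the engine classes defined in src/models/llm_def.py.
from models.llm_def import ChatGPT, Qwen

# API-backed engine; model name and key are placeholders, not real credentials.
engine = ChatGPT(model_name_or_path="gpt-4o-mini", api_key="sk-...")
print(engine.get_chat_response("Extract all person names: Alan Turing met Alonzo Church."))

# Local HuggingFace engine; weights are downloaded on first use.
local_engine = Qwen(model_name_or_path="Qwen/Qwen2.5-7B-Instruct")
print(local_engine.get_chat_response("Hello!"))
```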
src/models/prompt_example.py
CHANGED
@@ -2,7 +2,7 @@ json_schema_examples = """
 **Task**: Please extract all economic policies affecting the stock market between 2015 and 2023 and the exact dates of their implementation.
 **Text**: This text is from the field of Economics and represents the genre of Article.
 ...(example text)...
-**Output Schema**:
+**Output Schema**:
 {
     "economic_policies": [
         {
@@ -31,9 +31,9 @@ Example3:
 **Text**: This text is from the field of Political and represents the genre of News Report.
 ...(example text)...
 **Output Schema**:
-Answer:
+Answer:
 {
-    "news_report":
+    "news_report":
     {
         "title": null,
         "summary": null,
@@ -56,9 +56,9 @@ Answer:
 """

 code_schema_examples = """
-Example1:
+Example1:
 **Task**: Extract all the entities in the given text.
-**Text**:
+**Text**:
 ...(example text)...
 **Output Schema**:
 ```python
@@ -68,12 +68,12 @@ from pydantic import BaseModel, Field
 class Entity(BaseModel):
     label : str = Field(description="The type or category of the entity, such as 'Process', 'Technique', 'Data Structure', 'Methodology', 'Person', etc. ")
     name : str = Field(description="The specific name of the entity. It should represent a single, distinct concept and must not be an empty string. For example, if the entity is a 'Technique', the name could be 'Neural Networks'.")
-
+
 class ExtractionTarget(BaseModel):
     entity_list : List[Entity] = Field(description="All the entities presented in the context. The entities should encode ONE concept.")
 ```

-Example2:
+Example2:
 **Task**: Extract all the information in the given text.
 **Text**: This text is from the field of Political and represents the genre of News Article.
 ...(example text)...
@@ -95,7 +95,7 @@ class Event(BaseModel):
     process: Optional[str] = Field(description="Details of the event process")
     result: Optional[str] = Field(default=None, description="Result or outcome of the event")

-class NewsReport(BaseModel):
+class NewsReport(BaseModel):
     title: str = Field(description="The title or headline of the news report")
     summary: str = Field(description="A brief summary of the news report")
     publication_date: Optional[str] = Field(description="The publication date of the report")
@@ -116,16 +116,16 @@ from pydantic import BaseModel, Field
 class MetaData(BaseModel):
     title : str = Field(description="The title of the article")
     authors : List[str] = Field(description="The list of the article's authors")
-    abstract: str = Field(description="The article's abstract")
+    abstract: str = Field(description="The article's abstract")
     key_words: List[str] = Field(description="The key words associated with the article")
-
+
 class Baseline(BaseModel):
     method_name : str = Field(description="The name of the baseline method")
     proposed_solution : str = Field(description="the proposed solution in details")
     performance_metrics : str = Field(description="The performance metrics of the method and comparative analysis")
-
+
 class ExtractionTarget(BaseModel):
-
+
     key_contributions: List[str] = Field(description="The key contributions of the article")
     limitation_of_sota : str=Field(description="the summary limitation of the existing work")
     proposed_solution : str = Field(description="the proposed solution in details")
src/models/prompt_template.py
CHANGED
@@ -1,9 +1,9 @@
 from langchain.prompts import PromptTemplate
 from .prompt_example import *

-# ==================================================================== #
-#                            SCHEMA AGENT                              #
-# ==================================================================== #
+# ==================================================================== #
+#                            SCHEMA AGENT                              #
+# ==================================================================== #

 # Get Text Analysis
 TEXT_ANALYSIS_INSTRUCTION = """
@@ -22,9 +22,9 @@ text_analysis_instruction = PromptTemplate(
 # Get Deduced Schema Json
 DEDUCE_SCHEMA_JSON_INSTRUCTION = """
 **Instruction**: Generate an output format that meets the requirements as described in the task. Pay attention to the following requirements:
-- Format: Return your responses in dictionary format as a JSON object.
+- Format: Return your responses in dictionary format as a JSON object.
 - Content: Do not include any actual data; all attributes values should be set to None.
-- Note: Attributes not mentioned in the task description should be ignored.
+- Note: Attributes not mentioned in the task description should be ignored.
 {examples}
 **Task**: {instruction}

@@ -57,9 +57,9 @@ deduced_schema_code_instruction = PromptTemplate(
 )


-# ==================================================================== #
-#                          EXTRACTION AGENT                            #
-# ==================================================================== #
+# ==================================================================== #
+#                          EXTRACTION AGENT                            #
+# ==================================================================== #

 EXTRACT_INSTRUCTION = """
 **Instruction**: You are an agent skilled in information extarction. {instruction}
@@ -113,9 +113,9 @@ summarize_instruction = PromptTemplate(



-# ==================================================================== #
-#                          REFLECION AGENT                             #
-# ==================================================================== #
+# ==================================================================== #
+#                          REFLECION AGENT                             #
+# ==================================================================== #
REFLECT_INSTRUCTION = """**Instruction**: You are an agent skilled in reflection and optimization based on the original result. Refer to **Reflection Reference** to identify potential issues in the current extraction results.

 **Reflection Reference**: {examples}
@@ -153,9 +153,9 @@ summarize_instruction = PromptTemplate(



-# ==================================================================== #
-#                          CASE REPOSITORY                             #
-# ==================================================================== #
+# ==================================================================== #
+#                          CASE REPOSITORY                             #
+# ==================================================================== #

 GOOD_CASE_ANALYSIS_INSTRUCTION = """
 **Instruction**: Below is an information extraction task and its corresponding correct answer. Provide the reasoning steps that led to the correct answer, along with brief explanation of the answer. Your response should be brief and organized.
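Aside (illustration only, not part of the diff): these instruction strings are wrapped in langchain `PromptTemplate` objects, so filling them is a plain `.format()` call. A minimal sketch using the `{examples}` and `{instruction}` placeholders visible above; the variable name and the abridged template are stand-ins, not the file's actual definitions:

```python
# Sketch only -- mirrors how the templates above are consumed.
from langchain.prompts import PromptTemplate

DEDUCE_SCHEMA_JSON_INSTRUCTION = "{examples}\n**Task**: {instruction}\n"  # abridged stand-in

deduce_schema_json_instruction = PromptTemplate(
    input_variables=["examples", "instruction"],
    template=DEDUCE_SCHEMA_JSON_INSTRUCTION,
)
prompt = deduce_schema_json_instruction.format(
    examples="(few-shot examples here)",
    instruction="Extract all economic policies and their dates.",
)
print(prompt)
```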
src/models/vllm_serve.py
CHANGED
@@ -9,13 +9,13 @@ from utils import *
 def main():
     # Create command-line argument parser
     parser = argparse.ArgumentParser(description='Run the extraction model.')
-    parser.add_argument('--config', type=str, required=True,
+    parser.add_argument('--config', type=str, required=True,
                         help='Path to the YAML configuration file.')
     parser.add_argument('--tensor-parallel-size', type=int, default=2,
                         help='Tensor parallel size for the VLLM server.')
     parser.add_argument('--max-model-len', type=int, default=32768,
                         help='Maximum model length for the VLLM server.')
-
+
     # Parse command-line arguments
     args = parser.parse_args()

@@ -31,4 +31,3 @@ def main():

 if __name__ == "__main__":
     main()
-
src/modules/extraction_agent.py
CHANGED
@@ -65,6 +65,34 @@ class ExtractionAgent:
                 data.constraint = json.dumps(result)
             except:
                 print("Invalid Constraint: Event Extraction constraint must be a dictionary with event types as keys and lists of arguments as values.", data.constraint)
+        elif data.task == "Triple":
+            constraint = json.dumps(data.constraint)
+            if "**Triple Extraction Constraint**" in constraint:
+                return data
+            if self.llm.name != "OneKE":
+                if len(data.constraint) == 1: # 1 list means entity
+                    data.constraint = f"\n**Triple Extraction Constraint**: Entities type must chosen from following list:\n{constraint}\n"
+                elif len(data.constraint) == 2: # 2 list means entity and relation
+                    if data.constraint[0] == []:
+                        data.constraint = f"\n**Triple Extraction Constraint**: Relation type must chosen from following list:\n{data.constraint[1]}\n"
+                    elif data.constraint[1] == []:
+                        data.constraint = f"\n**Triple Extraction Constraint**: Entities type must chosen from following list:\n{data.constraint[0]}\n"
+                    else:
+                        data.constraint = f"\n**Triple Extraction Constraint**: Entities type must chosen from following list:\n{data.constraint[0]}\nRelation type must chosen from following list:\n{data.constraint[1]}\n"
+                elif len(data.constraint) == 3: # 3 list means entity, relation and object
+                    if data.constraint[0] == []:
+                        data.constraint = f"\n**Triple Extraction Constraint**: Relation type must chosen from following list:\n{data.constraint[1]}\nObject Entities must chosen from following list:\n{data.constraint[2]}\n"
+                    elif data.constraint[1] == []:
+                        data.constraint = f"\n**Triple Extraction Constraint**: Subject Entities must chosen from following list:\n{data.constraint[0]}\nObject Entities must chosen from following list:\n{data.constraint[2]}\n"
+                    elif data.constraint[2] == []:
+                        data.constraint = f"\n**Triple Extraction Constraint**: Subject Entities must chosen from following list:\n{data.constraint[0]}\nRelation type must chosen from following list:\n{data.constraint[1]}\n"
+                    else:
+                        data.constraint = f"\n**Triple Extraction Constraint**: Subject Entities must chosen from following list:\n{data.constraint[0]}\nRelation type must chosen from following list:\n{data.constraint[1]}\nObject Entities must chosen from following list:\n{data.constraint[2]}\n"
+                else:
+                    data.constraint = f"\n**Triple Extraction Constraint**: The type of entities must be chosen from the following list:\n{constraint}\n"
+            else:
+                print("OneKE does not support Triple Extraction task now, please wait for the next version.")
+            # print("data.constraint", data.constraint)
         return data

     def extract_information_direct(self, data: DataPoint):
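To make the branching above concrete (illustration only, not part of the diff): a two-list constraint where neither sub-list is empty is rewritten into a prompt fragment naming both the allowed entity types and relation types. A sketch:

```python
# Sketch only -- reproduces the 2-list branch of the Triple constraint logic above.
constraint = [["Person", "Place"], ["Located", "Ownership"]]

# Neither sub-list is empty, so both entity and relation types are constrained:
prompt_fragment = (
    "\n**Triple Extraction Constraint**: Entities type must chosen from following list:\n"
    f"{constraint[0]}\n"
    f"Relation type must chosen from following list:\n{constraint[1]}\n"
)
print(prompt_fragment)
# **Triple Extraction Constraint**: Entities type must chosen from following list:
# ['Person', 'Place']
# Relation type must chosen from following list:
# ['Located', 'Ownership']
```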
src/modules/knowledge_base/case_repository.py
CHANGED
@@ -19,7 +19,7 @@ class CaseRepository:
             self.embedder = SentenceTransformer(docker_model_path)
         except:
             self.embedder = SentenceTransformer(config['model']['embedding_model'])
-        self.embedder.to(device)
+        self.embedder.to(device)
         self.corpus = self.load_corpus()
         self.embedded_corpus = self.embed_corpus()

@@ -27,14 +27,14 @@ class CaseRepository:
         with open(os.path.join(os.path.dirname(__file__), "case_repository.json")) as file:
             corpus = json.load(file)
         return corpus
-
+
     def update_corpus(self):
         try:
             with open(os.path.join(os.path.dirname(__file__), "case_repository.json"), "w") as file:
                 json.dump(self.corpus, file, indent=2)
         except Exception as e:
             print(f"Error when updating corpus: {e}")
-
+
     def embed_corpus(self):
         embedded_corpus = {}
         for key, content in self.corpus.items():
@@ -43,8 +43,8 @@ class CaseRepository:
             bad_index = [item['index']['embed_index'] for item in content['bad']]
             encoded_bad_index = self.embedder.encode(bad_index, convert_to_tensor=True).to(device)
             embedded_corpus[key] = {"good": encoded_good_index, "bad": encoded_bad_index}
-        return embedded_corpus
-
+        return embedded_corpus
+
     def get_similarity_scores(self, task: TaskType, embed_index="", str_index="", case_type="", top_k=2):
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         # Embedding similarity match
@@ -58,7 +58,7 @@ class CaseRepository:
         scores_dict = {match[0]: match[1] for match in str_similarity_results}
         scores_in_order = [scores_dict[candidate] for candidate in str_match_corpus]
         str_similarity_scores = torch.tensor(scores_in_order, dtype=torch.float32).to(device)
-
+
         # Normalize scores
         embedding_score_range = embedding_similarity_scores.max() - embedding_similarity_scores.min()
         str_score_range = str_similarity_scores.max() - str_similarity_scores.min()
@@ -74,16 +74,16 @@ class CaseRepository:
         # Combine the scores with weights
         combined_scores = 0.5 * embed_norm_scores + 0.5 * str_norm_scores
         original_combined_scores = 0.5 * embedding_similarity_scores + 0.5 * str_similarity_scores / 100
-
+
         scores, indices = torch.topk(combined_scores, k=min(top_k, combined_scores.size(0)))
         original_scores, original_indices = torch.topk(original_combined_scores, k=min(top_k, original_combined_scores.size(0)))
         return scores, indices, original_scores, original_indices
-
+
     def query_case(self, task: TaskType, embed_index="", str_index="", case_type="", top_k=2) -> list:
         _, indices, _, _ = self.get_similarity_scores(task, embed_index, str_index, case_type, top_k)
         top_matches = [self.corpus[task][case_type][idx]["content"] for idx in indices]
         return top_matches
-
+
     def update_case(self, task: TaskType, embed_index="", str_index="", content="" ,case_type=""):
         self.corpus[task][case_type].append({"index": {"embed_index": embed_index, "str_index": str_index}, "content": content})
         self.embedded_corpus[task][case_type] = torch.cat([self.embedded_corpus[task][case_type], self.embedder.encode([embed_index], convert_to_tensor=True).to(device)], dim=0)
@@ -102,9 +102,9 @@ class CaseRepositoryHandler:
         response = self.llm.get_chat_response(prompt)
         response = extract_json_dict(response)
         if not isinstance(response, dict):
-            return response
+            return response
         return None
-
+
     def __get_bad_case_reflection(self, instruction="", text="", original_answer="", correct_answer="", additional_info=""):
         prompt = bad_case_reflection_instruction.format(
             instruction=instruction, text=text, original_answer=original_answer, correct_answer=correct_answer, additional_info=additional_info
@@ -115,34 +115,34 @@ class CaseRepositoryHandler:
         if not isinstance(response, dict):
             return response
         return None
-
+
     def __get_index(self, data: DataPoint, case_type: str):
         # set embed_index
         embed_index = f"**Text**: {data.distilled_text}\n{data.chunk_text_list[0]}"
-
+
         # set str_index
         if data.task == "Base":
             str_index = f"**Task**: {data.instruction}"
         else:
             str_index = f"{data.constraint}"
-
+
         if case_type == "bad":
             str_index += f"\n\n**Original Result**: {json.dumps(data.pred)}"
-
+
         return embed_index, str_index
-
+
     def query_good_case(self, data: DataPoint):
         embed_index, str_index = self.__get_index(data, "good")
         return self.repository.query_case(task=data.task, embed_index=embed_index, str_index=str_index, case_type="good")
-
+
     def query_bad_case(self, data: DataPoint):
         embed_index, str_index = self.__get_index(data, "bad")
         return self.repository.query_case(task=data.task, embed_index=embed_index, str_index=str_index, case_type="bad")
-
+
     def update_good_case(self, data: DataPoint):
         if data.truth == "" :
             print("No truth value provided.")
-            return
+            return
         embed_index, str_index = self.__get_index(data, "good")
         _, _, original_scores, _ = self.repository.get_similarity_scores(data.task, embed_index, str_index, "good", 1)
         original_scores = original_scores.tolist()
@@ -159,11 +159,11 @@ class CaseRepositoryHandler:
         else:
             content = f"{wrapped_text}\n\n{data.constraint}\n\n{wrapped_good_case_analysis}\n\n{wrapped_answer}"
         self.repository.update_case(data.task, embed_index, str_index, content, "good")
-
+
     def update_bad_case(self, data: DataPoint):
         if data.truth == "" :
             print("No truth value provided.")
-            return
+            return
         if normalize_obj(data.pred) == normalize_obj(data.truth):
             return
         embed_index, str_index = self.__get_index(data, "bad")
@@ -183,7 +183,7 @@ class CaseRepositoryHandler:
         else:
             content = f"{wrapped_text}\n\n{data.constraint}\n\n{wrapper_original_answer}\n\n{wrapped_bad_case_reflection}\n\n{wrapper_correct_answer}"
         self.repository.update_case(data.task, embed_index, str_index, content, "bad")
-
+
     def update_case(self, data: DataPoint):
         self.update_good_case(data)
         self.update_bad_case(data)
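A quick illustration of the score blending above (not part of the diff): embedding similarity is already roughly in [0, 1] while fuzzy string-match scores run 0-100, hence the `/ 100` in the unnormalized combination. A sketch:

```python
# Sketch only -- mirrors the weighting in get_similarity_scores above.
import torch

embedding_similarity_scores = torch.tensor([0.82, 0.40, 0.65])  # cosine-style, 0..1
str_similarity_scores = torch.tensor([90.0, 30.0, 60.0])        # fuzzy match, 0..100

# Min-max normalize each signal so both live in 0..1.
def minmax(t):
    r = t.max() - t.min()
    return (t - t.min()) / r if r > 0 else torch.zeros_like(t)

combined = 0.5 * minmax(embedding_similarity_scores) + 0.5 * minmax(str_similarity_scores)
original_combined = 0.5 * embedding_similarity_scores + 0.5 * str_similarity_scores / 100

print(torch.topk(combined, k=2).indices)  # tensor([0, 2])
```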
src/modules/knowledge_base/schema_repository.py
CHANGED
@@ -33,6 +33,19 @@ class Event(BaseModel):
 class EventList(BaseModel):
     event_list : List[Event] = Field(description="The events presented in the text.")

+# ==================================================================== #
+#                            Triple TASK                               #
+# ==================================================================== #
+class Triple(BaseModel):
+    head: str = Field(description="The subject or head of the triple.")
+    head_type: str = Field(description="The type of the subject entity.")
+    relation: str = Field(description="The predicate or relation between the entities.")
+    relation_type: str = Field(description="The type of the relation.")
+    tail: str = Field(description="The object or tail of the triple.")
+    tail_type: str = Field(description="The type of the object entity.")
+class TripleList(BaseModel):
+    triple_list: List[Triple] = Field(description="The collection of triples and their types presented in the text.")
+
 # ==================================================================== #
 #                          TEXT DESCRIPTION                            #
 # ==================================================================== #
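For reference (illustration only, not part of the diff): the new pydantic models validate extraction output shaped like the following. A minimal sketch; the sample values echo the Triple example constraint used in webui.py:

```python
# Sketch only -- instantiates the Triple/TripleList models added above.
from typing import List
from pydantic import BaseModel, Field

class Triple(BaseModel):
    head: str = Field(description="The subject or head of the triple.")
    head_type: str = Field(description="The type of the subject entity.")
    relation: str = Field(description="The predicate or relation between the entities.")
    relation_type: str = Field(description="The type of the relation.")
    tail: str = Field(description="The object or tail of the triple.")
    tail_type: str = Field(description="The type of the object entity.")

class TripleList(BaseModel):
    triple_list: List[Triple] = Field(description="The collection of triples and their types presented in the text.")

result = TripleList(triple_list=[Triple(
    head="Alan Turing", head_type="Person",
    relation="worked at", relation_type="Action",
    tail="Bletchley Park", tail_type="Place",
)])
print(result.model_dump())  # pydantic v2; use .dict() on v1
```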
src/modules/reflection_agent.py
CHANGED
@@ -5,7 +5,7 @@ from .knowledge_base.case_repository import CaseRepositoryHandler
 class ReflectionGenerator:
     def __init__(self, llm: BaseEngine):
         self.llm = llm
-
+
     def get_reflection(self, instruction="", examples="", text="",schema="", result=""):
         result = json.dumps(result)
         examples = bad_case_wrapper(examples)
@@ -13,7 +13,7 @@ class ReflectionGenerator:
         response = self.llm.get_chat_response(prompt)
         response = extract_json_dict(response)
         return response
-
+
 class ReflectionAgent:
     def __init__(self, llm: BaseEngine, case_repo: CaseRepositoryHandler):
         self.llm = llm
@@ -29,7 +29,7 @@ class ReflectionAgent:
         else:
             selected_obj = max(result_list, key=lambda o: len(json.dumps(o)))
         return selected_obj
-
+
     def __self_consistance_check(self, data: DataPoint):
         extract_func = list(data.result_trajectory.keys())[-1]
         if hasattr(self.extractor, extract_func):
@@ -55,7 +55,7 @@ class ReflectionAgent:
             consistant_result.append(selected_element)
         data.set_result_list(consistant_result)
         return reflect_index
-
+
     def reflect_with_case(self, data: DataPoint):
         if data.result_list == []:
             return data
@@ -71,4 +71,3 @@ class ReflectionAgent:
         function_name = current_function_name()
         data.update_trajectory(function_name, data.result_list)
         return data
-
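A small aside (illustration only, not part of the diff): the `max(result_list, key=lambda o: len(json.dumps(o)))` line above breaks ties between candidate extractions by preferring the most detailed, i.e. longest-serialized, object. For example:

```python
# Sketch only -- the tie-break rule used in ReflectionAgent above.
import json

result_list = [
    {"entity_list": [{"name": "Alan Turing"}]},
    {"entity_list": [{"name": "Alan Turing"}, {"name": "Bletchley Park"}]},
]
selected_obj = max(result_list, key=lambda o: len(json.dumps(o)))
print(selected_obj)  # the richer, two-entity candidate wins
```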
src/modules/schema_agent.py
CHANGED
@@ -106,6 +106,18 @@ class Event(BaseModel):
 class EventList(BaseModel):
     event_list : List[Event] = Field(description="The events presented in the text.")
 """
+        elif data.task == "Triple":
+            data.print_schema = """
+class Triple(BaseModel):
+    head: str = Field(description="The subject or head of the triple.")
+    head_type: str = Field(description="The type of the subject entity.")
+    relation: str = Field(description="The predicate or relation between the entities.")
+    relation_type: str = Field(description="The type of the relation.")
+    tail: str = Field(description="The object or tail of the triple.")
+    tail_type: str = Field(description="The type of the object entity.")
+class TripleList(BaseModel):
+    triple_list: List[Triple] = Field(description="The collection of triples and their types presented in the text.")
+"""
         return data

     def get_default_schema(self, data: DataPoint):
src/pipeline.py
CHANGED
@@ -2,6 +2,7 @@ from typing import Literal
 from models import *
 from utils import *
 from modules import *
+from construct import *


 class Pipeline:
@@ -14,7 +15,7 @@ class Pipeline:

     def __check_consistancy(self, llm, task, mode, update_case):
         if llm.name == "OneKE":
-            if task == "Base":
+            if task == "Base" or task == "Triple":
                 raise ValueError("The finetuned OneKE only supports quick extraction mode for NER, RE and EE Task.")
             else:
                 mode = "quick"
@@ -44,12 +45,16 @@ class Pipeline:
         elif data.task == "EE":
             data.instruction = config['agent']['default_ee']
             data.output_schema = "EventList"
+        elif data.task == "Triple":
+            data.instruction = config['agent']['default_triple']
+            data.output_schema = "TripleList"
         return data

     # main entry
     def get_extract_result(self,
                            task: TaskType,
                            three_agents = {},
+                           construct = {},
                            instruction: str = "",
                            text: str = "",
                            output_schema: str = "",
@@ -61,6 +66,7 @@ class Pipeline:
                            update_case: bool = False,
                            show_trajectory: bool = False,
                            isgui: bool = False,
+                           iskg: bool = False,
                            ):
         # for key, value in locals().items():
         #     print(f"{key}: {value}")
@@ -105,7 +111,17 @@ class Pipeline:
         # show result
         if show_trajectory:
             print("Extraction Trajectory: \n", json.dumps(data.get_result_trajectory(), indent=2))
-
+        extraction_result = json.dumps(data.pred, indent=2)
+        print("Extraction Result: \n", extraction_result)
+
+        # construct KG
+        if iskg:
+            myurl = construct['url']
+            myusername = construct['username']
+            mypassword = construct['password']
+            print(f"Construct KG in your {construct['database']} now...")
+            cypher_statements = generate_cypher_statements(extraction_result)
+            execute_cypher_statements(uri=myurl, user=myusername, password=mypassword, cypher_statements=cypher_statements)

         frontend_res = data.pred #
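For intuition (illustration only: `generate_cypher_statements` and `execute_cypher_statements` live in the new src/construct/convert.py, whose body is not shown in this section, so the exact statements it emits are an assumption): an extracted triple typically maps to Cypher MERGE statements along these lines:

```python
# Sketch only -- the kind of Cypher a triple-to-KG converter would emit;
# the real logic is in src/construct/convert.py (not shown in this diff).
triple = {"head": "Alan Turing", "relation": "worked at", "tail": "Bletchley Park"}

cypher = (
    f"MERGE (h:Entity {{name: '{triple['head']}'}}) "
    f"MERGE (t:Entity {{name: '{triple['tail']}'}}) "
    f"MERGE (h)-[:`{triple['relation']}`]->(t)"
)
print(cypher)

# Executing such statements against Neo4j would use the official driver, e.g.:
# from neo4j import GraphDatabase
# with GraphDatabase.driver(uri, auth=(user, password)) as driver:
#     driver.execute_query(cypher)
```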
src/run.py
CHANGED
@@ -11,9 +11,9 @@ from modules import *
 def main():
     # Create command-line argument parser
     parser = argparse.ArgumentParser(description='Run the extraction framefork.')
-    parser.add_argument('--config', type=str, required=True,
+    parser.add_argument('--config', type=str, required=True,
                         help='Path to the YAML configuration file.')
-
+
     # Parse command-line arguments
     args = parser.parse_args()

@@ -35,6 +35,15 @@ def main():
     pipeline = Pipeline(model)
     # Extraction config
     extraction_config = config['extraction']
+    # constuct config
+    if 'construct' in config:
+        construct_config = config['construct']
+        result, trajectory, _, _ = pipeline.get_extract_result(task=extraction_config['task'], instruction=extraction_config['instruction'], text=extraction_config['text'], output_schema=extraction_config['output_schema'], constraint=extraction_config['constraint'], use_file=extraction_config['use_file'], file_path=extraction_config['file_path'], truth=extraction_config['truth'], mode=extraction_config['mode'], update_case=extraction_config['update_case'], show_trajectory=extraction_config['show_trajectory'],
+                                                               construct=construct_config, iskg=True) # When 'construct' is provided, 'iskg' should be True to construct the knowledge graph.
+        return
+    else:
+        print("please provide construct config in the yaml file.")
+
     result, trajectory, _, _ = pipeline.get_extract_result(task=extraction_config['task'], instruction=extraction_config['instruction'], text=extraction_config['text'], output_schema=extraction_config['output_schema'], constraint=extraction_config['constraint'], use_file=extraction_config['use_file'], file_path=extraction_config['file_path'], truth=extraction_config['truth'], mode=extraction_config['mode'], update_case=extraction_config['update_case'], show_trajectory=extraction_config['show_trajectory'])
     return
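Aside (illustration only; the keys follow the `construct` handling added above and the new examples/config/Triple2KG.yaml in this update, whose exact contents are not shown here): a config enabling KG construction adds a `construct` section next to `model` and `extraction`. Sketched as the parsed Python dict that run.py receives after `yaml.safe_load`:

```python
# Sketch only -- the parsed shape src/run.py expects; values are placeholders,
# not working credentials.
config = {
    "model": {
        "category": "ChatGPT",                # engine class to instantiate
        "model_name_or_path": "gpt-4o-mini",
        "api_key": "sk-...",                  # placeholder
    },
    "extraction": {
        "task": "Triple",
        "use_file": True,
        "file_path": "data/input_files/Artificial_Intelligence_Wikipedia.txt",
        "constraint": [["Person", "Place", "Event", "property"],
                       ["Interpersonal", "Located", "Ownership", "Action"]],
        "mode": "quick",
    },
    # Presence of this section makes run.py call get_extract_result(..., iskg=True).
    "construct": {
        "database": "Neo4j",
        "url": "neo4j://localhost:7687",      # assumption: default Neo4j bolt port
        "username": "neo4j",
        "password": "your-password",
    },
}
```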
src/utils/__init__.py
CHANGED
@@ -1,3 +1,2 @@
 from .process import *
 from .data_def import DataPoint, TaskType
-
src/utils/process.py
CHANGED
@@ -17,28 +17,28 @@ import inspect
 import ast
 with open(os.path.join(os.path.dirname(__file__), "..", "config.yaml")) as file:
     config = yaml.safe_load(file)
-
-# Load configuration
+
+# Load configuration
 def load_extraction_config(yaml_path):
     # Read YAML content from the file path
     if not os.path.exists(yaml_path):
         print(f"Error: The config file '{yaml_path}' does not exist.")
         return {}
-
+
     with open(yaml_path, 'r') as file:
         config = yaml.safe_load(file)

     # Extract the 'extraction' configuration dictionary
     model_config = config.get('model', {})
     extraction_config = config.get('extraction', {})
-
+
     # Model config
     model_name_or_path = model_config.get('model_name_or_path', "")
     model_category = model_config.get('category', "")
     api_key = model_config.get('api_key', "")
     base_url = model_config.get('base_url', "")
     vllm_serve = model_config.get('vllm_serve', False)
-
+
     # Extraction config
     task = extraction_config.get('task', "")
     instruction = extraction_config.get('instruction', "")
@@ -52,6 +52,43 @@ def load_extraction_config(yaml_path):
     update_case = extraction_config.get('update_case', False)
     show_trajectory = extraction_config.get('show_trajectory', False)

+    # Construct config (optional: for constructing your knowledge graph)
+    if 'construct' in config:
+        construct_config = config.get('construct', {})
+        database = construct_config.get('database', "")
+        url = construct_config.get('url', "")
+        username = construct_config.get('username', "")
+        password = construct_config.get('password', "")
+        # Return a dictionary containing these variables
+        return {
+            "model": {
+                "model_name_or_path": model_name_or_path,
+                "category": model_category,
+                "api_key": api_key,
+                "base_url": base_url,
+                "vllm_serve": vllm_serve
+            },
+            "extraction": {
+                "task": task,
+                "instruction": instruction,
+                "text": text,
+                "output_schema": output_schema,
+                "constraint": constraint,
+                "truth": truth,
+                "use_file": use_file,
+                "file_path": file_path,
+                "mode": mode,
+                "update_case": update_case,
+                "show_trajectory": show_trajectory
+            },
+            "construct": {
+                "database": database,
+                "url": url,
+                "username": username,
+                "password": password
+            }
+        }
+
     # Return a dictionary containing these variables
     return {
         "model": {
@@ -75,7 +112,7 @@ def load_extraction_config(yaml_path):
             "show_trajectory": show_trajectory
         }
     }
-
+
 # Split the string text into chunks
 def chunk_str(text):
     sentences = sent_tokenize(text)
@@ -102,24 +139,24 @@ def chunk_file(file_path):
     pages = []

     if file_path.endswith(".pdf"):
-        loader = PyPDFLoader(file_path)
+        loader = PyPDFLoader(file_path)
     elif file_path.endswith(".txt"):
-        loader = TextLoader(file_path)
+        loader = TextLoader(file_path)
     elif file_path.endswith(".docx"):
-        loader = Docx2txtLoader(file_path)
+        loader = Docx2txtLoader(file_path)
     elif file_path.endswith(".html"):
-        loader = BSHTMLLoader(file_path)
+        loader = BSHTMLLoader(file_path)
     elif file_path.endswith(".json"):
-        loader = JSONLoader(file_path)
+        loader = JSONLoader(file_path)
     else:
         raise ValueError("Unsupported file format") # Inform that the format is unsupported
-
-    pages = loader.load_and_split()
+
+    pages = loader.load_and_split()
     docs = ""
     for item in pages:
         docs += item.page_content
     pages = chunk_str(docs)
-
+
     return pages

 def process_single_quotes(text):
@@ -147,11 +184,11 @@ def remove_empty_values(data):
 def extract_json_dict(text):
     if isinstance(text, dict):
         return text
-    pattern = r'\{(?:[^{}]|(?:\{(?:[^{}]|(?:\{[^{}]*\})*)*\})*)*\}'
-    matches = re.findall(pattern, text)
+    pattern = r'\{(?:[^{}]|(?:\{(?:[^{}]|(?:\{[^{}]*\})*)*\})*)*\}'
+    matches = re.findall(pattern, text)
     if matches:
-        json_string = matches[-1]
-        json_string = process_single_quotes(json_string)
+        json_string = matches[-1]
+        json_string = process_single_quotes(json_string)
         try:
             json_dict = json.loads(json_string)
             json_dict = remove_empty_values(json_dict)
@@ -159,9 +196,9 @@ def extract_json_dict(text):
                 return "No valid information found."
             return json_dict
         except json.JSONDecodeError:
-            return json_string
+            return json_string
     else:
-        return text
+        return text

 def good_case_wrapper(example: str):
     if example is None or example == "":
@@ -182,10 +219,10 @@ def example_wrapper(example: str):
     return example

 def remove_redundant_space(s):
-    s = ' '.join(s.split())
-    s = re.sub(r"\s*(,|:|\(|\)|\.|_|;|'|-)\s*", r'\1', s)
+    s = ' '.join(s.split())
+    s = re.sub(r"\s*(,|:|\(|\)|\.|_|;|'|-)\s*", r'\1', s)
     return s
-
+
 def format_string(s):
     s = remove_redundant_space(s)
     s = s.lower()
@@ -197,9 +234,9 @@ def format_string(s):
     return s

 def calculate_metrics(y_truth: set, y_pred: set):
-    TP = len(y_truth & y_pred)
-    FN = len(y_truth - y_pred)
-    FP = len(y_pred - y_truth)
+    TP = len(y_truth & y_pred)
+    FN = len(y_truth - y_pred)
+    FP = len(y_pred - y_truth)
     precision = TP / (TP + FP) if (TP + FP) > 0 else 0
     recall = TP / (TP + FN) if (TP + FN) > 0 else 0
     f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
@@ -214,11 +251,11 @@ def current_function_name():
         else:
             print("No caller function found")
             return None
-
+
     except Exception as e:
         print(f"An error occurred: {e}")
-        pass
-
+        pass
+
 def normalize_obj(value):
     if isinstance(value, dict):
         return frozenset((k, normalize_obj(v)) for k, v in value.items())
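A worked example of `calculate_metrics` above (illustration only): with 2 true positives, 1 missed item, and 1 spurious item, precision and recall are both 2/3, so F1 is also 2/3.

```python
# Sketch only -- exercises the set arithmetic in calculate_metrics above.
y_truth = {"Alan Turing", "Bletchley Park", "Enigma"}
y_pred = {"Alan Turing", "Bletchley Park", "London"}

TP = len(y_truth & y_pred)   # 2
FN = len(y_truth - y_pred)   # 1 ("Enigma" missed)
FP = len(y_pred - y_truth)   # 1 ("London" spurious)
precision = TP / (TP + FP)   # 0.667
recall = TP / (TP + FN)      # 0.667
f1 = 2 * precision * recall / (precision + recall)  # 0.667
```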
src/webui.py
CHANGED
@@ -1,9 +1,15 @@
-
-
+"""
+....../OneKE$ python src/webui.py
+"""
+
+
 import gradio as gr
+import json
+import random

-from pipeline import Pipeline
 from models import *
+from pipeline import Pipeline
+

 examples = [
     {
@@ -43,7 +49,7 @@ examples = [
         "task": "Base",
         "mode": "quick",
         "use_file": True,
-        "file_path": "data/Harry_Potter_Chapter1.pdf",
+        "file_path": "data/input_files/Harry_Potter_Chapter1.pdf",
         "instruction": "Extract main characters and the background setting from this chapter.",
         "constraint": "",
         "text": "",
@@ -54,13 +60,24 @@ examples = [
         "task": "Base",
         "mode": "quick",
         "use_file": True,
-        "file_path": "data/Tulsi_Gabbard_News.html",
+        "file_path": "data/input_files/Tulsi_Gabbard_News.html",
         "instruction": "Extract key information from the given text.",
         "constraint": "",
         "text": "",
        "update_case": False,
        "truth": "",
    },
+    {
+        "task": "Triple",
+        "mode": "quick",
+        "use_file": True,
+        "file_path": "data/input_files/Artificial_Intelligence_Wikipedia.txt",
+        "instruction": "",
+        "constraint": """[["Person", "Place", "Event", "property"], ["Interpersonal", "Located", "Ownership", "Action"]]""",
+        "text": "",
+        "update_case": False,
+        "truth": "",
+    }
 ]


@@ -75,16 +92,16 @@ def create_interface():
         </p>
         <h1>OneKE: A Dockerized Schema-Guided LLM Agent-based Knowledge Extraction System</h1>
         <p>
-            🏠[<a href="https://oneke.openkg.cn/" target="_blank">
-            ⌨️[<a href="https://github.com/zjunlp/OneKE" target="_blank">Code</a>]
+            🏠[<a href="https://oneke.openkg.cn/" target="_blank">Home</a>]
             📹[<a href="http://oneke.openkg.cn/demo.mp4" target="_blank">Video</a>]
+            📄[<a href="https://arxiv.org/abs/2209.10707" target="_blank">Paper</a>]
+            💻[<a href="https://github.com/zjunlp/OneKE" target="_blank">Code</a>]
         </p>
     </div>
     """)

     example_button_gr = gr.Button("🎲 Quick Start with an Example 🎲")

-
     with gr.Row():
         with gr.Column():
             model_gr = gr.Dropdown(
@@ -103,7 +120,7 @@ def create_interface():
         with gr.Column():
             task_gr = gr.Dropdown(
                 label="🎯 Select your Task",
-                choices=["Base", "NER", "RE", "EE"],
+                choices=["Base", "NER", "RE", "EE", "Triple"],
                 value="Base",
             )
             mode_gr = gr.Dropdown(
@@ -139,6 +156,8 @@ def create_interface():
             return gr.update(visible=False), gr.update(visible=True, label="🕹️ Constraint", placeholder="Enter your RE Constraint")
         elif task == "EE":
             return gr.update(visible=False), gr.update(visible=True, label="🕹️ Constraint", placeholder="Enter your EE Constraint")
+        elif task == "Triple":
+            return gr.update(visible=False), gr.update(visible=True, label="🕹️ Constraint", placeholder="Enter your Triple Constraint")

     def update_input_fields(use_file):
         if use_file:
@@ -162,7 +181,7 @@ def create_interface():
             gr.update(value=example["file_path"], visible=example["use_file"]),
             gr.update(value=example["text"], visible=not example["use_file"]),
             gr.update(value=example["instruction"], visible=example["task"] == "Base"),
-            gr.update(value=example["constraint"], visible=example["task"] in ["NER", "RE", "EE"]),
+            gr.update(value=example["constraint"], visible=example["task"] in ["NER", "RE", "EE", "Triple"]),
             gr.update(value=example["update_case"]),
             gr.update(value=example["truth"]),
             gr.update(value="NOT REQUIRED", visible=False),
@@ -207,7 +226,7 @@ def create_interface():
         if reflection_agent not in ["", "NOT REQUIRED"]:
             agent3["reflection_agent"] = reflection_agent

-        #
+        # use 'Pipeline'
         _, _, ger_frontend_schema, ger_frontend_res = pipeline.get_extract_result(
             task=task,
             text=text,
@@ -336,6 +355,8 @@ def create_interface():

     return demo

+
+# Launch the front-end interface
 if __name__ == "__main__":
     interface = create_interface()
-    interface.launch()
+    interface.launch() # the Gradio defalut URL usually is: 127.0.0.1:7860