ShawnRu committed on
Commit
e6e7506
·
1 Parent(s): c376793
.gitattributes CHANGED
@@ -32,5 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.xz filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
- data/Harry_Potter_Chapter1.pdf filter=lfs diff=lfs merge=lfs -text
.gitignore CHANGED
@@ -1,3 +1,3 @@
- local
  **/__pycache__
  *.pyc
+ dev
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 ZJUNLP
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
data/Artificial_Intelligence_Wikipedia.txt DELETED
@@ -1,46 +0,0 @@
- In the 22nd century, rising sea levels from global warming
- have wiped out coastal cities and altered the world's climate.
- With the human population in decline, nations have created
- humanoid robots called mechas to fulfill various roles.
-
- In Madison, New Jersey, David, an 11-year-old prototype mecha
- child capable of love, is given to Henry Swinton and his wife
- Monica, whose son Martin is in suspended animation. Monica
- initially feels uncomfortable but warms to David after he is
- activated and imprinted. David befriends Teddy, Martin's robotic
- teddy bear.
-
- After Martin is cured and brought home, he goads David into
- cutting off a piece of Monica's hair. Later, David accidentally
- pokes Monica's eye with scissors. During a pool party, David
- reacts to being poked with a knife and both he and Martin fall
- into the pool. Martin is saved, but David is blamed.
-
- Henry convinces Monica to return David to his creators for
- destruction, but instead, she abandons him in the woods with
- Teddy. David, believing that becoming human will regain Monica's
- love, decides to find the Blue Fairy.
-
- David and Teddy are captured by the "Flesh Fair", where obsolete
- mechas are destroyed. David pleads for his life, and the audience
- allows him to escape with Gigolo Joe, a mecha framed for murder.
- They travel to Rouge City, where "Dr. Know", a holographic answer
- engine, directs them to the ruins of New York City and suggests
- that the Blue Fairy may help.
-
- David meets Professor Hobby, who shows him copies of himself,
- including female variants. Disheartened, David attempts suicide,
- but Joe rescues him. They find the Blue Fairy, which turns out
- to be a statue. David repeatedly asks the statue to turn him
- into a real boy until his power source is depleted.
-
- Two thousand years later, humanity is extinct and Manhattan is
- buried under ice. Mechas have evolved, and a group called the
- Specialists resurrect David and Teddy. They reconstruct the Swinton
- home from David's memories and explain that he cannot become human.
- However, they recreate Monica using genetic material from the
- strand of hair Teddy kept. Monica can live for only one day.
-
- David spends his happiest day with Monica, and as she falls asleep,
- she tells him she has always loved him. David lies down next to her
- and closes his eyes.
data/Harry_Potter_Chapter1.pdf DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:b9eb5104658f4d6ef8ff9b457f28f188b6aa1b201443719c501e462072eacf57
- size 163709
data/Tulsi_Gabbard_News.html DELETED
The diff for this file is too large to render. See raw diff
 
examples/config/BookExtraction.yaml ADDED
@@ -0,0 +1,15 @@
+ model:
+   # Recommend using ChatGPT or DeepSeek APIs for complex IE tasks.
+   category: ChatGPT # model category, chosen from ChatGPT, DeepSeek, LLaMA, Qwen, ChatGLM, MiniCPM, OneKE.
+   model_name_or_path: gpt-4o-mini # model name, chosen from the model list of the selected category.
+   api_key: your_api_key # your API key for the model with API service. No need for open-source models.
+   base_url: https://api.openai.com/v1 # base URL for the API service. No need for open-source models.
+
+ extraction:
+   task: Base # task type, chosen from Base, NER, RE, EE.
+   instruction: Extract main characters and background setting from this chapter. # description for the task. No need for NER, RE, EE tasks.
+   use_file: true # whether to use a file for the input text. Default set to false.
+   file_path: ./data/input_files/Harry_Potter_Chapter1.pdf # path to the input file. No need if use_file is set to false.
+   mode: quick # extraction mode, chosen from quick, detailed, customized. Default set to quick. See src/config.yaml for more details.
+   update_case: false # whether to update the case repository. Default set to false.
+   show_trajectory: false # whether to display the extracted intermediate steps.
examples/config/EE.yaml ADDED
@@ -0,0 +1,14 @@
+ model:
+   category: DeepSeek # model category, chosen from ChatGPT, DeepSeek, LLaMA, Qwen, ChatGLM, MiniCPM, OneKE.
+   model_name_or_path: deepseek-chat # model name, chosen from the model list of the selected category.
+   api_key: your_api_key # your API key for the model with API service. No need for open-source models.
+   base_url: https://api.deepseek.com # base URL for the API service. No need for open-source models.
+
+ extraction:
+   task: EE # task type, chosen from Base, NER, RE, EE.
+   text: UConn Health , an academic medical center , says in a media statement that it identified approximately 326,000 potentially impacted individuals whose personal information was contained in the compromised email accounts. # input text for the extraction task. No need if use_file is set to true.
+   constraint: {"phishing": ["damage amount", "attack pattern", "tool", "victim", "place", "attacker", "purpose", "trusted entity", "time"], "data breach": ["damage amount", "attack pattern", "number of data", "number of victim", "tool", "compromised data", "victim", "place", "attacker", "purpose", "time"], "ransom": ["damage amount", "attack pattern", "payment method", "tool", "victim", "place", "attacker", "price", "time"], "discover vulnerability": ["vulnerable system", "vulnerability", "vulnerable system owner", "vulnerable system version", "supported platform", "common vulnerabilities and exposures", "capabilities", "time", "discoverer"], "patch vulnerability": ["vulnerable system", "vulnerability", "issues addressed", "vulnerable system version", "releaser", "supported platform", "common vulnerabilities and exposures", "patch number", "time", "patch"]} # Specified event types and their corresponding arguments for the event extraction task. Structured as a dictionary with the event type as the key and the list of arguments as the value. Default set to empty.
+   use_file: false # whether to use a file for the input text.
+   mode: standard # extraction mode, chosen from quick, detailed, customized. Default set to quick. See src/config.yaml for more details.
+   update_case: false # whether to update the case repository. Default set to false.
+   show_trajectory: false # whether to display the extracted intermediate steps.
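
The `constraint` field for EE maps each event type to the argument roles to fill. A minimal shape check for that field could look like the sketch below (`validate_ee_constraint` is a hypothetical helper for illustration, not part of this repository):

```python
def validate_ee_constraint(constraint):
    """Return True if constraint is a dict mapping event-type strings
    to lists of argument-role strings, as the EE task expects."""
    if not isinstance(constraint, dict):
        return False
    for event_type, roles in constraint.items():
        if not isinstance(event_type, str):
            return False
        if not isinstance(roles, list) or not all(isinstance(r, str) for r in roles):
            return False
    return True

print(validate_ee_constraint({"data breach": ["victim", "compromised data"]}))  # True
print(validate_ee_constraint({"ransom": "victim"}))  # False: roles must be a list
```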
examples/config/NER.yaml ADDED
@@ -0,0 +1,13 @@
+ model:
+   category: LLaMA # model category, chosen from ChatGPT, DeepSeek, LLaMA, Qwen, ChatGLM, MiniCPM, OneKE.
+   model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct # model name to download from Hugging Face, or a local model path.
+   vllm_serve: false # whether to use vLLM. Default set to false.
+
+ extraction:
+   task: NER # task type, chosen from Base, NER, RE, EE.
+   text: Finally , every other year , ELRA organizes a major conference LREC , the International Language Resources and Evaluation Conference . # input text for the extraction task. No need if use_file is set to true.
+   constraint: ["algorithm", "conference", "else", "product", "task", "field", "metrics", "organization", "researcher", "program language", "country", "location", "person", "university"] # Specified entity types for the named entity recognition task. Default set to empty.
+   use_file: false # whether to use a file for the input text.
+   mode: quick # extraction mode, chosen from quick, detailed, customized. Default set to quick. See src/config.yaml for more details.
+   update_case: false # whether to update the case repository. Default set to false.
+   show_trajectory: false # whether to display the extracted intermediate steps.
examples/config/NewsExtraction.yaml ADDED
@@ -0,0 +1,15 @@
+ model:
+   category: DeepSeek # model category, chosen from ChatGPT, DeepSeek, LLaMA, Qwen, ChatGLM, MiniCPM, OneKE.
+   model_name_or_path: deepseek-chat # model name, chosen from the model list of the selected category.
+   api_key: your_api_key # your API key for the model with API service. No need for open-source models.
+   base_url: https://api.deepseek.com # base URL for the API service. No need for open-source models.
+
+ extraction:
+   task: Base # task type, chosen from Base, NER, RE, EE.
+   instruction: Extract key information from the given text. # description for the task. No need for NER, RE, EE tasks.
+   use_file: true # whether to use a file for the input text. Default set to false.
+   file_path: ./data/input_files/Tulsi_Gabbard_News.html # path to the input file. No need if use_file is set to false.
+   output_schema: NewsReport # output schema for the extraction task. Selected from the schema repository.
+   mode: customized # extraction mode, chosen from quick, detailed, customized. Default set to quick. See src/config.yaml for more details.
+   update_case: false # whether to update the case repository. Default set to false.
+   show_trajectory: false # whether to display the extracted intermediate steps.
examples/config/RE.yaml ADDED
@@ -0,0 +1,15 @@
+ model:
+   category: ChatGPT # model category, chosen from ChatGPT, DeepSeek, LLaMA, Qwen, ChatGLM, MiniCPM, OneKE.
+   model_name_or_path: gpt-4o-mini # model name, chosen from the model list of the selected category.
+   api_key: your_api_key # your API key for the model with API service. No need for open-source models.
+   base_url: https://api.openai.com/v1 # base URL for the API service. No need for open-source models.
+
+ extraction:
+   task: RE # task type, chosen from Base, NER, RE, EE.
+   text: The aid group Doctors Without Borders said that since Saturday , more than 275 wounded people had been admitted and treated at Donka Hospital in the capital of Guinea , Conakry . # input text for the extraction task. No need if use_file is set to true.
+   constraint: ["nationality", "country capital", "place of death", "children", "location contains", "place of birth", "place lived", "administrative division of country", "country of administrative divisions", "company", "neighborhood of", "company founders"] # Specified relation types for the relation extraction task. Default set to empty.
+   truth: {"relation_list": [{"head": "Guinea", "tail": "Conakry", "relation": "country capital"}]} # Truth data for the relation extraction task. Structured as a dictionary with the list of relation tuples as the value. Required if update_case is set to true.
+   use_file: false # whether to use a file for the input text.
+   mode: quick # extraction mode, chosen from quick, detailed, customized. Default set to quick. See src/config.yaml for more details.
+   update_case: true # whether to update the case repository. Default set to false.
+   show_trajectory: false # whether to display the extracted intermediate steps.
examples/config/Triple2KG.yaml ADDED
@@ -0,0 +1,21 @@
+ model:
+   # Recommend using ChatGPT or DeepSeek APIs for complex Triple tasks.
+   category: ChatGPT # model category, chosen from ChatGPT, DeepSeek, LLaMA, Qwen, ChatGLM, MiniCPM, OneKE.
+   model_name_or_path: gpt-4o-mini # model name, chosen from the model list of the selected category.
+   api_key: your_api_key # your API key for the model with API service. No need for open-source models.
+   base_url: https://api.openai.com/v1 # base URL for the API service. No need for open-source models.
+
+ extraction:
+   mode: quick # extraction mode, chosen from quick, detailed, customized. Default set to quick. See src/config.yaml for more details.
+   task: Triple # task type, chosen from Base, NER, RE, EE, and the newly added Triple.
+   use_file: true # whether to use a file for the input text. Default set to false.
+   file_path: ./data/input_files/Artificial_Intelligence_Wikipedia.txt # path to the input file. No need if use_file is set to false.
+   constraint: [["Person", "Place", "Event", "Property"], ["Interpersonal", "Located", "Ownership", "Action"]] # Specified entity or relation types for the Triple extraction task. You can write 3 lists for subject, relation, and object types; 2 lists for entity and relation types; or 1 list for entity types only.
+   update_case: false # whether to update the case repository. Default set to false.
+   show_trajectory: false # whether to display the extracted intermediate steps.
+
+ # construct: # (Optional) To construct a Knowledge Graph, set the construct field; otherwise delete it.
+ #   database: Neo4j # database type; currently only Neo4j is supported.
+ #   url: neo4j://localhost:7687 # your database URL; Neo4j's default port is 7687.
+ #   username: your_username # your database username.
+ #   password: "your_password" # your database password.
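
The commented-out `construct` block targets Neo4j (note `neo4j==5.28.1` in `requirements.txt`). As a sketch, assuming extracted triples follow the shape shown in `examples/results/TripleExtraction.json`, one way to turn a triple into a parameterized Cypher statement is shown below; the helper name and the single-relationship-type graph model are illustrative, not the repository's actual implementation:

```python
def triple_to_cypher(triple):
    """Build a parameterized Cypher MERGE statement for one extracted triple.
    Node labels come from head_type/tail_type; entity names go in as parameters."""
    head_label = triple.get("head_type", "Entity")
    tail_label = triple.get("tail_type", "Entity")
    query = (
        f"MERGE (h:{head_label} {{name: $head}}) "
        f"MERGE (t:{tail_label} {{name: $tail}}) "
        "MERGE (h)-[:RELATION {type: $relation}]->(t)"
    )
    params = {"head": triple["head"], "tail": triple["tail"], "relation": triple["relation"]}
    return query, params

query, params = triple_to_cypher(
    {"head": "David", "head_type": "Person", "relation": "befriends",
     "relation_type": "Interpersonal", "tail": "Teddy", "tail_type": "Person"}
)
```

With the official `neo4j` Python driver, each statement could then be executed via `session.run(query, **params)` inside a transaction; MERGE keeps node and relationship creation idempotent across repeated runs.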
examples/example.py ADDED
@@ -0,0 +1,17 @@
+ import sys
+ sys.path.append("./src")
+ from models import *
+ from pipeline import *
+ import json
+
+ # model configuration
+ model = ChatGPT(model_name_or_path="your_model_name_or_path", api_key="your_api_key")
+ pipeline = Pipeline(model)
+
+ # extraction configuration
+ Task = "NER"
+ Text = "Finally , every other year , ELRA organizes a major conference LREC , the International Language Resources and Evaluation Conference."
+ Constraint = ["algorithm", "conference", "else", "product", "task", "field", "metrics", "organization", "researcher", "program language", "country", "location", "person", "university"]  # entity types for the NER task (matches examples/config/NER.yaml)
+
+ # get extraction result
+ result, trajectory, frontend_schema, frontend_res = pipeline.get_extract_result(task=Task, text=Text, constraint=Constraint, show_trajectory=True)
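
Depending on the task, `result` is expected to follow the shapes shown under `examples/results/`. A small sketch (assuming the NER `entity_list` shape; `entities_to_tuples` is an illustrative helper, not part of the repository) for flattening a result into hashable tuples:

```python
def entities_to_tuples(result):
    """Flatten an {'entity_list': [...]} result into a set of (name, type) tuples."""
    return {(e["name"], e["type"]) for e in result.get("entity_list", [])}

sample = {"entity_list": [{"name": "ELRA", "type": "organization"},
                          {"name": "LREC", "type": "conference"}]}
print(sorted(entities_to_tuples(sample)))  # [('ELRA', 'organization'), ('LREC', 'conference')]
```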
examples/results/BookExtraction.json ADDED
@@ -0,0 +1,48 @@
+ {
+     "main_characters": [
+         {
+             "name": "Mr. Dursley",
+             "description": "The director of a firm called Grunnings, a big, beefy man with hardly any neck and a large mustache."
+         },
+         {
+             "name": "Mrs. Dursley",
+             "description": "Thin and blonde, with nearly twice the usual amount of neck, spends time spying on neighbors."
+         },
+         {
+             "name": "Dudley Dursley",
+             "description": "The small son of Mr. and Mrs. Dursley, considered by them to be the finest boy anywhere."
+         },
+         {
+             "name": "Albus Dumbledore",
+             "description": "A tall, thin, and very old man with long silver hair and a purple cloak, who arrives mysteriously."
+         },
+         {
+             "name": "Professor McGonagall",
+             "description": "A severe-looking woman who can transform into a cat, wearing an emerald cloak."
+         },
+         {
+             "name": "Voldemort",
+             "description": "The dark wizard who has caused fear and chaos, but has mysteriously disappeared."
+         },
+         {
+             "name": "Harry Potter",
+             "description": "The young boy who survived Voldemort's attack, becoming a significant figure in the wizarding world."
+         },
+         {
+             "name": "Lily Potter",
+             "description": "Harry's mother, who is mentioned as having been killed by Voldemort."
+         },
+         {
+             "name": "James Potter",
+             "description": "Harry's father, who is mentioned as having been killed by Voldemort."
+         },
+         {
+             "name": "Hagrid",
+             "description": "A giant man who is caring and emotional about Harry's situation."
+         }
+     ],
+     "background_setting": {
+         "location": "Number four, Privet Drive, Suburban",
+         "time_period": "A dull, gray Tuesday morning, Late 20th Century"
+     }
+ }
examples/results/EE.json ADDED
@@ -0,0 +1,13 @@
+ {
+     "event_list": [
+         {
+             "event_type": "data breach",
+             "event_trigger": "compromised",
+             "event_argument": {
+                 "number of victim": 326000,
+                 "compromised data": "personal information contained in email accounts",
+                 "victim": "individuals whose personal information was compromised"
+             }
+         }
+     ]
+ }
examples/results/NER.json ADDED
@@ -0,0 +1,16 @@
+ {
+     "entity_list": [
+         {
+             "name": "ELRA",
+             "type": "organization"
+         },
+         {
+             "name": "LREC",
+             "type": "conference"
+         },
+         {
+             "name": "International Language Resources and Evaluation Conference",
+             "type": "conference"
+         }
+     ]
+ }
examples/results/NewsExtraction.json ADDED
@@ -0,0 +1,51 @@
+ {
+     "title": "Who is Tulsi Gabbard? Meet Trump's pick for director of national intelligence",
+     "summary": "Tulsi Gabbard, President-elect Donald Trump\u2019s choice for director of national intelligence, could face a challenging Senate confirmation battle due to her lack of intelligence experience and controversial views.",
+     "publication_date": "December 4, 2024",
+     "keywords": [
+         "Tulsi Gabbard",
+         "Donald Trump",
+         "director of national intelligence",
+         "confirmation battle",
+         "intelligence agencies",
+         "Russia",
+         "Syria",
+         "Bashar al-Assad"
+     ],
+     "events": [
+         {
+             "name": "Tulsi Gabbard's nomination for director of national intelligence",
+             "people_involved": [
+                 {
+                     "name": "Tulsi Gabbard",
+                     "identity": "Former U.S. Representative",
+                     "role": "Nominee for director of national intelligence"
+                 },
+                 {
+                     "name": "Donald Trump",
+                     "identity": "President-elect",
+                     "role": "Nominator"
+                 },
+                 {
+                     "name": "Tammy Duckworth",
+                     "identity": "Democratic Senator",
+                     "role": "Critic of Gabbard's nomination"
+                 },
+                 {
+                     "name": "Olivia Troye",
+                     "identity": "Former national security official",
+                     "role": "Commentator on Gabbard's potential impact"
+                 }
+             ],
+             "process": "Gabbard's nomination is expected to lead to a Senate confirmation battle."
+         }
+     ],
+     "quotes": {
+         "Tammy Duckworth": "The U.S. intelligence community has identified her as having troubling relationships with America\u2019s foes, and so my worry is that she couldn\u2019t pass a background check.",
+         "Olivia Troye": "If Gabbard is confirmed, America\u2019s allies may not share as much information with the U.S."
+     },
+     "viewpoints": [
+         "Gabbard's lack of intelligence experience raises concerns about her ability to oversee 18 intelligence agencies.",
+         "Her past comments and meetings with foreign adversaries have led to accusations of being a national security risk."
+     ]
+ }
examples/results/RE.json ADDED
@@ -0,0 +1,9 @@
+ {
+     "relation_list": [
+         {
+             "head": "Guinea",
+             "tail": "Conakry",
+             "relation": "country capital"
+         }
+     ]
+ }
examples/results/TripleExtraction.json ADDED
@@ -0,0 +1,156 @@
+ {
+     "triple_list": [
+         {
+             "head": "sea levels",
+             "head_type": "Property",
+             "relation": "wiped out",
+             "relation_type": "Action",
+             "tail": "coastal cities",
+             "tail_type": "Place"
+         },
+         {
+             "head": "nations",
+             "head_type": "Person",
+             "relation": "created",
+             "relation_type": "Action",
+             "tail": "mechas",
+             "tail_type": "Property"
+         },
+         {
+             "head": "David",
+             "head_type": "Person",
+             "relation": "given to",
+             "relation_type": "Ownership",
+             "tail": "Henry and Monica",
+             "tail_type": "Person"
+         },
+         {
+             "head": "Monica",
+             "head_type": "Person",
+             "relation": "feels uncomfortable",
+             "relation_type": "Interpersonal",
+             "tail": "David",
+             "tail_type": "Person"
+         },
+         {
+             "head": "David",
+             "head_type": "Person",
+             "relation": "befriends",
+             "relation_type": "Interpersonal",
+             "tail": "Teddy",
+             "tail_type": "Person"
+         },
+         {
+             "head": "Martin",
+             "head_type": "Person",
+             "relation": "goads",
+             "relation_type": "Action",
+             "tail": "David",
+             "tail_type": "Person"
+         },
+         {
+             "head": "David",
+             "head_type": "Person",
+             "relation": "blamed for",
+             "relation_type": "Action",
+             "tail": "incident",
+             "tail_type": "Event"
+         },
+         {
+             "head": "Monica",
+             "head_type": "Person",
+             "relation": "returns David to",
+             "relation_type": "Ownership",
+             "tail": "creators",
+             "tail_type": "Person"
+         },
+         {
+             "head": "David",
+             "head_type": "Person",
+             "relation": "decides to find",
+             "relation_type": "Action",
+             "tail": "Blue Fairy",
+             "tail_type": "Property"
+         },
+         {
+             "head": "David",
+             "head_type": "Person",
+             "relation": "pleads for",
+             "relation_type": "Action",
+             "tail": "his life",
+             "tail_type": "Event"
+         },
+         {
+             "head": "David",
+             "head_type": "Person",
+             "relation": "meets",
+             "relation_type": "Interpersonal",
+             "tail": "Professor Hobby",
+             "tail_type": "Person"
+         },
+         {
+             "head": "David",
+             "head_type": "Person",
+             "relation": "attempts",
+             "relation_type": "Action",
+             "tail": "suicide",
+             "tail_type": "Event"
+         },
+         {
+             "head": "Joe",
+             "head_type": "Person",
+             "relation": "rescues",
+             "relation_type": "Action",
+             "tail": "David",
+             "tail_type": "Person"
+         },
+         {
+             "head": "David",
+             "head_type": "Person",
+             "relation": "asks statue to turn him into",
+             "relation_type": "Action",
+             "tail": "real boy",
+             "tail_type": "Property"
+         },
+         {
+             "head": "humanity",
+             "head_type": "Person",
+             "relation": "is extinct",
+             "relation_type": "Action",
+             "tail": "future",
+             "tail_type": "Event"
+         },
+         {
+             "head": "Specialists",
+             "head_type": "Person",
+             "relation": "resurrect",
+             "relation_type": "Action",
+             "tail": "David and Teddy",
+             "tail_type": "Person"
+         },
+         {
+             "head": "Monica",
+             "head_type": "Person",
+             "relation": "can live for",
+             "relation_type": "Property",
+             "tail": "one day",
+             "tail_type": "Property"
+         },
+         {
+             "head": "David",
+             "head_type": "Person",
+             "relation": "spends",
+             "relation_type": "Action",
+             "tail": "happiest day with Monica",
+             "tail_type": "Event"
+         },
+         {
+             "head": "Monica",
+             "head_type": "Person",
+             "relation": "tells",
+             "relation_type": "Interpersonal",
+             "tail": "David",
+             "tail_type": "Person"
+         }
+     ]
+ }
experiments/dataset_def.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import random
4
+ from utils import *
5
+ from pipeline import *
6
+ current_dir = os.path.dirname(os.path.abspath(__file__))
7
+ DATA_DIR = os.path.join(current_dir, "../data/datasets")
8
+ OUTPUT_DIR = os.path.join(current_dir, "results")
9
+
10
+ class BaseDataset:
11
+ def __init__(self):
12
+ pass
13
+
14
+ def __getitem__(self, idx):
15
+ return None
16
+
17
+ def __len__(self):
18
+ return None
19
+
20
+ def evaluate(self, idx, answer):
21
+ return None
22
+
23
+ class NERDataset(BaseDataset):
24
+ def __init__(self, name=None, task="NER", data_dir = f"{DATA_DIR}/CrossNER", output_dir = f"{OUTPUT_DIR}", train=False):
25
+ self.name = name
26
+ self.task = task
27
+ self.data_dir = data_dir
28
+ self.output_dir = output_dir
29
+ self.test_file = json.load(open(f"{data_dir}/train.json")) if train else json.load(open(f"{data_dir}/test.json"))
30
+ self.schema = str(json.load(open(f"{data_dir}/class.json")))
31
+ self.retry = 2
32
+
33
+ def evaluate(self, llm: BaseEngine, mode="", sample=None, random_sample=False, update_case=False):
34
+ # initialize
35
+ sample = len(self.test_file) if sample is None else sample
36
+ if random_sample:
37
+ test_file = random.sample(self.test_file, sample)
38
+ else:
39
+ test_file = self.test_file[:sample]
40
+ total_precision, total_recall, total_f1 = 0, 0, 0
41
+ num_items = 0
42
+ output_path = f"{self.output_dir}/{self.name}_{self.task}_{mode}_{llm.name}_sample{sample}.jsonl"
43
+ print("Results will be saved to: ", output_path)
44
+
45
+ # predict and evaluate
46
+ pipeline = Pipeline(llm=llm)
47
+ for item in test_file:
48
+ try:
49
+ # get prediction
50
+ num_items += 1
51
+ truth = list(item.items())[1]
52
+ truth = {truth[0]: truth[1]}
53
+ pred_set = set()
54
+ for attempt in range(self.retry):
55
+ pred_result, pred_detailed, _, _ = pipeline.get_extract_result(task=self.task, text=item['sentence'], constraint=self.schema, mode=mode, truth=truth, update_case=update_case)
56
+ try:
57
+ pred_result = pred_result['entity_list']
58
+ pred_set = dict_list_to_set(pred_result)
59
+ break
60
+ except Exception as e:
61
+ print(f"Failed to parse result: {pred_result}, retrying... Exception: {e}")
62
+
63
+ # evaluate
64
+ truth_result = item["entity_list"]
65
+ truth_set = dict_list_to_set(truth_result)
66
+ print(truth_set)
67
+ print(pred_set)
68
+
69
+ precision, recall, f1_score = calculate_metrics(truth_set, pred_set)
70
+ total_precision += precision
71
+ total_recall += recall
72
+ total_f1 += f1_score
73
+
74
+ pred_detailed["pred"] = pred_result
75
+ pred_detailed["truth"] = truth_result
76
+ pred_detailed["metrics"] = {"precision": precision, "recall": recall, "f1_score": f1_score}
77
+ res_detailed = {"id": num_items}
78
+ res_detailed.update(pred_detailed)
79
+ with open(output_path, 'a') as file:
80
+ file.write(json.dumps(res_detailed) + '\n')
81
+ except Exception as e:
82
+ print(f"Exception occured: {e}")
83
+ print(f"idx: {num_items}")
84
+ pass
85
+
86
+ # calculate overall metrics
87
+ if num_items > 0:
88
+ avg_precision = total_precision / num_items
89
+ avg_recall = total_recall / num_items
90
+ avg_f1 = total_f1 / num_items
91
+ overall_metrics = {
92
+ "total_items": num_items,
93
+ "average_precision": avg_precision,
94
+ "average_recall": avg_recall,
95
+ "average_f1_score": avg_f1
96
+ }
97
+ with open(output_path, 'a') as file:
98
+ file.write(json.dumps(overall_metrics) + '\n\n')
99
+ print(f"Overall Metrics:\nTotal Items: {num_items}\nAverage Precision: {avg_precision:.4f}\nAverage Recall: {avg_recall:.4f}\nAverage F1 Score: {avg_f1:.4f}")
100
+ else:
101
+ print("No items processed.")
102
+
103
+ class REDataset(BaseDataset):
104
+ def __init__(self, name=None, task="RE", data_dir = f"{DATA_DIR}/NYT11", output_dir = f"{OUTPUT_DIR}", train=False):
105
+ self.name = name
106
+ self.task = task
107
+ self.data_dir = data_dir
108
+ self.output_dir = output_dir
109
+ self.test_file = json.load(open(f"{data_dir}/train.json")) if train else json.load(open(f"{data_dir}/test.json"))
110
+ self.schema = str(json.load(open(f"{data_dir}/class.json")))
111
+ self.retry = 2
112
+
113
+ def evaluate(self, llm: BaseEngine, mode="", sample=None, random_sample=False, update_case=False):
114
+ # initialize
115
+ sample = len(self.test_file) if sample is None else sample
116
+ if random_sample:
117
+ test_file = random.sample(self.test_file, sample)
118
+ else:
119
+ test_file = self.test_file[:sample]
120
+        total_precision, total_recall, total_f1 = 0, 0, 0
+        num_items = 0
+        output_path = f"{self.output_dir}/{self.name}_{self.task}_{mode}_{llm.name}_sample{sample}.jsonl"
+        print("Results will be saved to: ", output_path)
+
+        # predict and evaluate
+        pipeline = Pipeline(llm=llm)
+        for item in test_file:
+            try:
+                # get prediction
+                num_items += 1
+                truth = list(item.items())[1]
+                truth = {truth[0]: truth[1]}
+                pred_set = set()
+                for attempt in range(self.retry):
+                    pred_result, pred_detailed, _, _ = pipeline.get_extract_result(task=self.task, text=item['sentence'], constraint=self.schema, mode=mode, truth=truth, update_case=update_case)
+                    try:
+                        pred_result = pred_result['relation_list']
+                        pred_set = dict_list_to_set(pred_result)
+                        break
+                    except Exception as e:
+                        print(f"Failed to parse result: {pred_result}, retrying... Exception: {e}")
+
+                # evaluate
+                truth_result = item["relation_list"]
+                truth_set = dict_list_to_set(truth_result)
+                print(truth_set)
+                print(pred_set)
+
+                precision, recall, f1_score = calculate_metrics(truth_set, pred_set)
+                total_precision += precision
+                total_recall += recall
+                total_f1 += f1_score
+
+                pred_detailed["pred"] = pred_result
+                pred_detailed["truth"] = truth_result
+                pred_detailed["metrics"] = {"precision": precision, "recall": recall, "f1_score": f1_score}
+                res_detailed = {"id": num_items}
+                res_detailed.update(pred_detailed)
+                with open(output_path, 'a') as file:
+                    file.write(json.dumps(res_detailed) + '\n')
+            except Exception as e:
+                print(f"Exception occurred: {e}")
+                print(f"idx: {num_items}")
+                pass
+
+        # calculate overall metrics
+        if num_items > 0:
+            avg_precision = total_precision / num_items
+            avg_recall = total_recall / num_items
+            avg_f1 = total_f1 / num_items
+            overall_metrics = {
+                "total_items": num_items,
+                "average_precision": avg_precision,
+                "average_recall": avg_recall,
+                "average_f1_score": avg_f1
+            }
+            with open(output_path, 'a') as file:
+                file.write(json.dumps(overall_metrics) + '\n\n')
+            print(f"Overall Metrics:\nTotal Items: {num_items}\nAverage Precision: {avg_precision:.4f}\nAverage Recall: {avg_recall:.4f}\nAverage F1 Score: {avg_f1:.4f}")
+        else:
+            print("No items processed.")
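The evaluation loop above calls `dict_list_to_set` and `calculate_metrics`, which are defined elsewhere in the repository; a minimal sketch consistent with the set-based usage in this hunk (the exact implementations here are assumptions, not code from the diff):

```python
def dict_list_to_set(dict_list):
    """Flatten a list of relation dicts into a set of hashable tuples."""
    return {tuple(sorted(d.items())) for d in dict_list}

def calculate_metrics(truth_set, pred_set):
    """Set-based precision, recall, and F1 over extracted relations."""
    tp = len(truth_set & pred_set)  # true positives: exact-match triples
    precision = tp / len(pred_set) if pred_set else 0.0
    recall = tp / len(truth_set) if truth_set else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1
```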
experiments/run_ner.py ADDED
@@ -0,0 +1,15 @@
+import sys
+sys.path.append("./src")
+from models import *
+from dataset_def import *
+name = "crossner-"
+data_dir = "./data/datasets/CrossNER/"
+model = ChatGPT(model_name_or_path="gpt-4o-mini", api_key="your_api_key", base_url="https://api.openai.com/v1")
+tasklist = ["ai", "literature", "music", "politics", "science"]
+for task in tasklist:
+    task_name = name + task
+    task_data_dir = data_dir + task
+    dataset = NERDataset(name=task_name, data_dir=task_data_dir)
+    mode = "quick"
+    f1_score = dataset.evaluate(llm=model, mode=mode)
+    print(f"Task: {task_name}, f1_score: {f1_score}")
experiments/run_re.py ADDED
@@ -0,0 +1,10 @@
+import sys
+sys.path.append("./src")
+from models import *
+from dataset_def import *
+data_dir = "./data/datasets/NYT11/"
+model = LLaMA("meta-llama/Meta-Llama-3-8B-Instruct")
+dataset = REDataset(name="NYT11", data_dir=data_dir)
+f1_score = dataset.evaluate(llm=model, mode="quick")
+print("f1_score: ", f1_score)
+
figs/logo.png ADDED
figs/main.png ADDED
requirements.txt CHANGED
@@ -15,4 +15,5 @@ sentencepiece==0.2.0
 protobuf==5.29.3
 bitsandbytes==0.45.0
 vllm==0.6.0
-gradio==4.44.0
+gradio==4.44.0
+neo4j==5.28.1
src/config.yaml CHANGED
@@ -6,6 +6,7 @@ agent:
   default_ner: Extract the Named Entities in the given text.
   default_re: Extract Relationships between Named Entities in the given text.
   default_ee: Extract the Events in the given text.
+  default_triple: Extract the Triples (subject, relation, object) from the given text, aiming to extract all the relationships for each entity.
   chunk_token_limit: 1024
   mode:
     quick:
@@ -16,5 +17,5 @@ agent:
       extraction_agent: extract_information_with_case
       reflection_agent: reflect_with_case
     customized:
-      schema_agent: get_retrieved_schema
+      schema_agent: get_retrieved_schema
       extraction_agent: extract_information_direct
src/construct/__init__.py ADDED
@@ -0,0 +1 @@
+from .convert import *
src/construct/convert.py ADDED
@@ -0,0 +1,201 @@
+import json
+import re
+from neo4j import GraphDatabase
+
+
+def sanitize_string(input_str, max_length=255):
+    """
+    Process the input string to ensure it meets the database requirements.
+    """
+    # step1: Replace invalid characters
+    input_str = re.sub(r'[^a-zA-Z0-9_]', '_', input_str)
+
+    # step2: Add prefix if it starts with a digit
+    if input_str[0].isdigit():
+        input_str = 'num' + input_str
+
+    # step3: Limit length
+    if len(input_str) > max_length:
+        input_str = input_str[:max_length]
+
+    return input_str
+
+
+def generate_cypher_statements(data):
+    """
+    Generates Cypher query statements based on the provided JSON data.
+    """
+    cypher_statements = []
+    parsed_data = json.loads(data)
+
+    def create_statement(triple):
+        head = triple.get("head")
+        head_type = triple.get("head_type")
+        relation = triple.get("relation")
+        relation_type = triple.get("relation_type")
+        tail = triple.get("tail")
+        tail_type = triple.get("tail_type")
+
+        # head_safe = sanitize_string(head) if head else None
+        head_type_safe = sanitize_string(head_type) if head_type else None
+        # relation_safe = sanitize_string(relation) if relation else None
+        relation_type_safe = sanitize_string(relation_type) if relation_type else None
+        # tail_safe = sanitize_string(tail) if tail else None
+        tail_type_safe = sanitize_string(tail_type) if tail_type else None
+
+        statement = ""
+        if head:
+            if head_type_safe:
+                statement += f'MERGE (a:{head_type_safe} {{name: "{head}"}}) '
+            else:
+                statement += f'MERGE (a:UNTYPED {{name: "{head}"}}) '
+        if tail:
+            if tail_type_safe:
+                statement += f'MERGE (b:{tail_type_safe} {{name: "{tail}"}}) '
+            else:
+                statement += f'MERGE (b:UNTYPED {{name: "{tail}"}}) '
+        if relation:
+            if head and tail:  # Only create relation if head and tail exist.
+                if relation_type_safe:
+                    statement += f'MERGE (a)-[:{relation_type_safe} {{name: "{relation}"}}]->(b);'
+                else:
+                    statement += f'MERGE (a)-[:UNTYPED {{name: "{relation}"}}]->(b);'
+            else:
+                statement += ';' if statement != "" else ''
+        else:
+            if relation_type_safe:  # if relation is not provided, create relation by `relation_type`.
+                statement += f'MERGE (a)-[:{relation_type_safe} {{name: "{relation_type_safe}"}}]->(b);'
+            else:
+                statement += ';' if statement != "" else ''
+        return statement
+
+    if "triple_list" in parsed_data:
+        for triple in parsed_data["triple_list"]:
+            cypher_statements.append(create_statement(triple))
+    else:
+        cypher_statements.append(create_statement(parsed_data))
+
+    return cypher_statements
+
+
+def execute_cypher_statements(uri, user, password, cypher_statements):
+    """
+    Executes the generated Cypher query statements.
+    """
+    driver = GraphDatabase.driver(uri, auth=(user, password))
+
+    with driver.session() as session:
+        for statement in cypher_statements:
+            session.run(statement)
+            print(f"Executed: {statement}")
+
+    # Write executed cypher statements to a text file if you want.
+    # with open("executed_statements.txt", 'a') as f:
+    #     for statement in cypher_statements:
+    #         f.write(statement + '\n')
+    #     f.write('\n')
+
+    driver.close()
+
+
+# Here is a test of your database connection:
+if __name__ == "__main__":
+    # test_data 1: Contains a list of triples
+    test_data = '''
+    {
+        "triple_list": [
+            {
+                "head": "J.K. Rowling",
+                "head_type": "Person",
+                "relation": "wrote",
+                "relation_type": "Actions",
+                "tail": "Fantastic Beasts and Where to Find Them",
+                "tail_type": "Book"
+            },
+            {
+                "head": "Fantastic Beasts and Where to Find Them",
+                "head_type": "Book",
+                "relation": "extra section of",
+                "relation_type": "Affiliation",
+                "tail": "Harry Potter Series",
+                "tail_type": "Book"
+            },
+            {
+                "head": "J.K. Rowling",
+                "head_type": "Person",
+                "relation": "wrote",
+                "relation_type": "Actions",
+                "tail": "Harry Potter Series",
+                "tail_type": "Book"
+            },
+            {
+                "head": "Harry Potter Series",
+                "head_type": "Book",
+                "relation": "create",
+                "relation_type": "Actions",
+                "tail": "Dumbledore",
+                "tail_type": "Person"
+            },
+            {
+                "head": "Fantastic Beasts and Where to Find Them",
+                "head_type": "Book",
+                "relation": "mention",
+                "relation_type": "Actions",
+                "tail": "Dumbledore",
+                "tail_type": "Person"
+            },
+            {
+                "head": "Voldemort",
+                "head_type": "Person",
+                "relation": "afraid",
+                "relation_type": "Emotion",
+                "tail": "Dumbledore",
+                "tail_type": "Person"
+            },
+            {
+                "head": "Voldemort",
+                "head_type": "Person",
+                "relation": "robs",
+                "relation_type": "Actions",
+                "tail": "the Elder Wand",
+                "tail_type": "Weapon"
+            },
+            {
+                "head": "the Elder Wand",
+                "head_type": "Weapon",
+                "relation": "belong to",
+                "relation_type": "Affiliation",
+                "tail": "Dumbledore",
+                "tail_type": "Person"
+            }
+        ]
+    }
+    '''
+
+    # test_data 2: Contains a single triple
+    # test_data = '''
+    # {
+    #     "head": "Christopher Nolan",
+    #     "head_type": "Person",
+    #     "relation": "directed",
+    #     "relation_type": "Action",
+    #     "tail": "Inception",
+    #     "tail_type": "Movie"
+    # }
+    # '''
+
+    # Generate Cypher query statements
+    cypher_statements = generate_cypher_statements(test_data)
+
+    # Print the generated Cypher query statements
+    for statement in cypher_statements:
+        print(statement)
+    print("\n")
+
+    # Execute the generated Cypher query statements
+    execute_cypher_statements(
+        uri="neo4j://localhost:7687",  # your URI
+        user="your_username",          # your username
+        password="your_password",      # your password
+        cypher_statements=cypher_statements,
+    )
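`generate_cypher_statements` interpolates entity names directly into the query string, so a name containing a double quote would break the generated statement. A sketch of an alternative that binds names as driver parameters instead (the label still has to be a sanitized literal, since Cypher does not parameterize labels); this is an illustrative refactor, not code from the repository:

```python
import re

def _sanitize(label, max_length=255):
    # Same idea as sanitize_string above: restrict labels to [A-Za-z0-9_].
    label = re.sub(r'[^a-zA-Z0-9_]', '_', label)
    if label[0].isdigit():
        label = 'num' + label
    return label[:max_length]

def build_parameterized_statement(triple):
    """Return (query, params); entity names are bound as parameters, not inlined."""
    head_label = _sanitize(triple.get("head_type") or "UNTYPED")
    tail_label = _sanitize(triple.get("tail_type") or "UNTYPED")
    rel_label = _sanitize(triple.get("relation_type") or "UNTYPED")
    query = (
        f'MERGE (a:{head_label} {{name: $head}}) '
        f'MERGE (b:{tail_label} {{name: $tail}}) '
        f'MERGE (a)-[:{rel_label} {{name: $relation}}]->(b)'
    )
    params = {
        "head": triple.get("head"),
        "tail": triple.get("tail"),
        "relation": triple.get("relation"),
    }
    return query, params
```

With the neo4j driver this would be run as `session.run(query, **params)`, letting the server handle quoting.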
src/models/llm_def.py CHANGED
@@ -22,7 +22,7 @@ class BaseEngine:
         self.top_p = 0.9
         self.max_tokens = 1024
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
+
     def get_chat_response(self, prompt):
         raise NotImplementedError
@@ -30,7 +30,7 @@ class BaseEngine:
         self.temperature = temperature
         self.top_p = top_p
         self.max_tokens = max_tokens
-
+
 class LLaMA(BaseEngine):
     def __init__(self, model_name_or_path: str):
         super().__init__(model_name_or_path)
@@ -61,7 +61,7 @@ class LLaMA(BaseEngine):
             top_p=self.top_p,
         )
         return outputs[0]["generated_text"][-1]['content'].strip()
-
+
 class Qwen(BaseEngine):
     def __init__(self, model_name_or_path: str):
         super().__init__(model_name_or_path)
@@ -72,7 +72,7 @@ class Qwen(BaseEngine):
             torch_dtype="auto",
             device_map="auto"
         )
-
+
     def get_chat_response(self, prompt):
         messages = [
             {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
@@ -94,7 +94,7 @@ class Qwen(BaseEngine):
             output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
         ]
         response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
-
+
         return response

 class MiniCPM(BaseEngine):
@@ -125,7 +125,7 @@ class MiniCPM(BaseEngine):
             model_outputs[i][len(model_inputs[i]):] for i in range(len(model_inputs))
         ]
         response = self.tokenizer.batch_decode(output_token_ids, skip_special_tokens=True)[0].strip()
-
+
         return response

 class ChatGLM(BaseEngine):
@@ -155,7 +155,7 @@ class ChatGLM(BaseEngine):
         )
         model_outputs = model_outputs[:, model_inputs['input_ids'].shape[1]:]
         response = self.tokenizer.batch_decode(model_outputs, skip_special_tokens=True)[0].strip()
-
+
         return response

 class OneKE(BaseEngine):
@@ -164,7 +164,7 @@ class OneKE(BaseEngine):
         self.name = "OneKE"
         self.model_id = model_name_or_path
         config = AutoConfig.from_pretrained(self.model_id, trust_remote_code=True)
-        quantization_config=BitsAndBytesConfig(
+        quantization_config=BitsAndBytesConfig(
             load_in_4bit=True,
             llm_int8_threshold=6.0,
             llm_int8_has_fp16_weight=False,
@@ -175,12 +175,12 @@ class OneKE(BaseEngine):
         self.model = AutoModelForCausalLM.from_pretrained(
             self.model_id,
             config=config,
-            device_map="auto",
+            device_map="auto",
             quantization_config=quantization_config,
             torch_dtype=torch.bfloat16,
             trust_remote_code=True,
         )
-
+
     def get_chat_response(self, prompt):
         system_prompt = '<<SYS>>\nYou are a helpful assistant. δ½ ζ˜―δΈ€δΈͺδΉδΊŽεŠ©δΊΊηš„εŠ©ζ‰‹γ€‚\n<</SYS>>\n\n'
         sintruct = '[INST] ' + system_prompt + prompt + '[/INST]'
@@ -191,9 +191,9 @@ class OneKE(BaseEngine):
         generation_output = generation_output.sequences[0]
         generation_output = generation_output[input_length:]
         response = self.tokenizer.decode(generation_output, skip_special_tokens=True)
-
+
         return response
-
+
 class ChatGPT(BaseEngine):
     def __init__(self, model_name_or_path: str, api_key: str, base_url=openai.base_url):
         self.name = "ChatGPT"
@@ -207,7 +207,7 @@ class ChatGPT(BaseEngine):
         else:
             self.api_key = os.environ["OPENAI_API_KEY"]
         self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
-
+
     def get_chat_response(self, input):
         response = self.client.chat.completions.create(
             model=self.model,
@@ -234,7 +234,7 @@ class DeepSeek(BaseEngine):
         else:
             self.api_key = os.environ["DEEPSEEK_API_KEY"]
         self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
-
+
     def get_chat_response(self, input):
         response = self.client.chat.completions.create(
             model=self.model,
@@ -258,7 +258,7 @@ class LocalServer(BaseEngine):
         self.max_tokens = 1024
         self.api_key = "EMPTY_API_KEY"
         self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
-
+
     def get_chat_response(self, input):
         try:
             response = self.client.chat.completions.create(
@@ -276,4 +276,3 @@ class LocalServer(BaseEngine):
             print("Error: Unable to connect to the server. Please check if the vllm service is running and the port is 8080.")
         except Exception as e:
             print(f"Error: {e}")
-
src/models/prompt_example.py CHANGED
@@ -2,7 +2,7 @@ json_schema_examples = """
 **Task**: Please extract all economic policies affecting the stock market between 2015 and 2023 and the exact dates of their implementation.
 **Text**: This text is from the field of Economics and represents the genre of Article.
 ...(example text)...
-**Output Schema**:
+**Output Schema**:
 {
     "economic_policies": [
         {
@@ -31,9 +31,9 @@ Example3:
 **Text**: This text is from the field of Political and represents the genre of News Report.
 ...(example text)...
 **Output Schema**:
-Answer:
+Answer:
 {
-"news_report":
+"news_report":
     {
     "title": null,
     "summary": null,
@@ -56,9 +56,9 @@ Answer:
 """

 code_schema_examples = """
-Example1:
+Example1:
 **Task**: Extract all the entities in the given text.
-**Text**:
+**Text**:
 ...(example text)...
 **Output Schema**:
 ```python
@@ -68,12 +68,12 @@ from pydantic import BaseModel, Field
 class Entity(BaseModel):
     label : str = Field(description="The type or category of the entity, such as 'Process', 'Technique', 'Data Structure', 'Methodology', 'Person', etc. ")
     name : str = Field(description="The specific name of the entity. It should represent a single, distinct concept and must not be an empty string. For example, if the entity is a 'Technique', the name could be 'Neural Networks'.")
-
+
 class ExtractionTarget(BaseModel):
     entity_list : List[Entity] = Field(description="All the entities presented in the context. The entities should encode ONE concept.")
 ```

-Example2:
+Example2:
 **Task**: Extract all the information in the given text.
 **Text**: This text is from the field of Political and represents the genre of News Article.
 ...(example text)...
@@ -95,7 +95,7 @@ class Event(BaseModel):
     process: Optional[str] = Field(description="Details of the event process")
     result: Optional[str] = Field(default=None, description="Result or outcome of the event")

-class NewsReport(BaseModel):
+class NewsReport(BaseModel):
     title: str = Field(description="The title or headline of the news report")
     summary: str = Field(description="A brief summary of the news report")
     publication_date: Optional[str] = Field(description="The publication date of the report")
@@ -116,16 +116,16 @@ from pydantic import BaseModel, Field
 class MetaData(BaseModel):
     title : str = Field(description="The title of the article")
     authors : List[str] = Field(description="The list of the article's authors")
-    abstract: str = Field(description="The article's abstract")
+    abstract: str = Field(description="The article's abstract")
     key_words: List[str] = Field(description="The key words associated with the article")
-
+
 class Baseline(BaseModel):
     method_name : str = Field(description="The name of the baseline method")
     proposed_solution : str = Field(description="the proposed solution in details")
     performance_metrics : str = Field(description="The performance metrics of the method and comparative analysis")
-
+
 class ExtractionTarget(BaseModel):
-
+
     key_contributions: List[str] = Field(description="The key contributions of the article")
     limitation_of_sota : str=Field(description="the summary limitation of the existing work")
     proposed_solution : str = Field(description="the proposed solution in details")
src/models/prompt_template.py CHANGED
@@ -1,9 +1,9 @@
 from langchain.prompts import PromptTemplate
 from .prompt_example import *

-# ==================================================================== #
-#                            SCHEMA AGENT                              #
-# ==================================================================== #
+# ==================================================================== #
+#                            SCHEMA AGENT                              #
+# ==================================================================== #

 # Get Text Analysis
 TEXT_ANALYSIS_INSTRUCTION = """
@@ -22,9 +22,9 @@ text_analysis_instruction = PromptTemplate(
 # Get Deduced Schema Json
 DEDUCE_SCHEMA_JSON_INSTRUCTION = """
 **Instruction**: Generate an output format that meets the requirements as described in the task. Pay attention to the following requirements:
-    - Format: Return your responses in dictionary format as a JSON object.
+    - Format: Return your responses in dictionary format as a JSON object.
     - Content: Do not include any actual data; all attributes values should be set to None.
-    - Note: Attributes not mentioned in the task description should be ignored.
+    - Note: Attributes not mentioned in the task description should be ignored.
 {examples}
 **Task**: {instruction}
@@ -57,9 +57,9 @@ deduced_schema_code_instruction = PromptTemplate(
 )


-# ==================================================================== #
-#                           EXTRACTION AGENT                           #
-# ==================================================================== #
+# ==================================================================== #
+#                           EXTRACTION AGENT                           #
+# ==================================================================== #

 EXTRACT_INSTRUCTION = """
 **Instruction**: You are an agent skilled in information extraction. {instruction}
@@ -113,9 +113,9 @@ summarize_instruction = PromptTemplate(



-# ==================================================================== #
-#                           REFLECTION AGENT                           #
-# ==================================================================== #
+# ==================================================================== #
+#                           REFLECTION AGENT                           #
+# ==================================================================== #
 REFLECT_INSTRUCTION = """**Instruction**: You are an agent skilled in reflection and optimization based on the original result. Refer to **Reflection Reference** to identify potential issues in the current extraction results.

 **Reflection Reference**: {examples}
@@ -153,9 +153,9 @@ summarize_instruction = PromptTemplate(



-# ==================================================================== #
-#                           CASE REPOSITORY                            #
-# ==================================================================== #
+# ==================================================================== #
+#                           CASE REPOSITORY                            #
+# ==================================================================== #

 GOOD_CASE_ANALYSIS_INSTRUCTION = """
 **Instruction**: Below is an information extraction task and its corresponding correct answer. Provide the reasoning steps that led to the correct answer, along with a brief explanation of the answer. Your response should be brief and organized.
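The constants in this file are wrapped in langchain `PromptTemplate` objects (as the `text_analysis_instruction = PromptTemplate(...)` context lines show), which substitute named placeholders like `{examples}` and `{instruction}` at call time. A minimal stand-in using plain `str.format` illustrates the same substitution, useful for seeing what the rendered prompt looks like without langchain installed (the template text below is an abbreviated example, not the repository's full string):

```python
DEDUCE_SCHEMA_JSON_INSTRUCTION = (
    "**Instruction**: Generate an output format that meets the requirements as described in the task.\n"
    "{examples}\n"
    "**Task**: {instruction}\n"
)

def render(template: str, **variables) -> str:
    # PromptTemplate.format performs the same named substitution.
    return template.format(**variables)

prompt = render(
    DEDUCE_SCHEMA_JSON_INSTRUCTION,
    examples="(few-shot examples here)",
    instruction="Extract all entities.",
)
```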
src/models/vllm_serve.py CHANGED
@@ -9,13 +9,13 @@ from utils import *
 def main():
     # Create command-line argument parser
     parser = argparse.ArgumentParser(description='Run the extraction model.')
-    parser.add_argument('--config', type=str, required=True,
+    parser.add_argument('--config', type=str, required=True,
                         help='Path to the YAML configuration file.')
     parser.add_argument('--tensor-parallel-size', type=int, default=2,
                         help='Tensor parallel size for the VLLM server.')
     parser.add_argument('--max-model-len', type=int, default=32768,
                         help='Maximum model length for the VLLM server.')
-
+
     # Parse command-line arguments
     args = parser.parse_args()
@@ -31,4 +31,3 @@ def main():

 if __name__ == "__main__":
     main()
-
src/modules/extraction_agent.py CHANGED
@@ -65,6 +65,34 @@ class ExtractionAgent:
65
  data.constraint = json.dumps(result)
66
  except:
67
  print("Invalid Constraint: Event Extraction constraint must be a dictionary with event types as keys and lists of arguments as values.", data.constraint)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  return data
69
 
70
  def extract_information_direct(self, data: DataPoint):
 
65
  data.constraint = json.dumps(result)
66
  except:
67
  print("Invalid Constraint: Event Extraction constraint must be a dictionary with event types as keys and lists of arguments as values.", data.constraint)
68
+ elif data.task == "Triple":
69
+ constraint = json.dumps(data.constraint)
70
+ if "**Triple Extraction Constraint**" in constraint:
71
+ return data
72
+ if self.llm.name != "OneKE":
73
+ if len(data.constraint) == 1: # 1 list means entity
74
+ data.constraint = f"\n**Triple Extraction Constraint**: Entity types must be chosen from the following list:\n{constraint}\n"
75
+ elif len(data.constraint) == 2: # 2 list means entity and relation
76
+ if data.constraint[0] == []:
77
+ data.constraint = f"\n**Triple Extraction Constraint**: Relation types must be chosen from the following list:\n{data.constraint[1]}\n"
78
+ elif data.constraint[1] == []:
79
+ data.constraint = f"\n**Triple Extraction Constraint**: Entity types must be chosen from the following list:\n{data.constraint[0]}\n"
80
+ else:
81
+ data.constraint = f"\n**Triple Extraction Constraint**: Entity types must be chosen from the following list:\n{data.constraint[0]}\nRelation types must be chosen from the following list:\n{data.constraint[1]}\n"
82
+ elif len(data.constraint) == 3: # 3 list means entity, relation and object
83
+ if data.constraint[0] == []:
84
+ data.constraint = f"\n**Triple Extraction Constraint**: Relation types must be chosen from the following list:\n{data.constraint[1]}\nObject entities must be chosen from the following list:\n{data.constraint[2]}\n"
85
+ elif data.constraint[1] == []:
86
+ data.constraint = f"\n**Triple Extraction Constraint**: Subject entities must be chosen from the following list:\n{data.constraint[0]}\nObject entities must be chosen from the following list:\n{data.constraint[2]}\n"
87
+ elif data.constraint[2] == []:
88
+ data.constraint = f"\n**Triple Extraction Constraint**: Subject entities must be chosen from the following list:\n{data.constraint[0]}\nRelation types must be chosen from the following list:\n{data.constraint[1]}\n"
89
+ else:
90
+ data.constraint = f"\n**Triple Extraction Constraint**: Subject entities must be chosen from the following list:\n{data.constraint[0]}\nRelation types must be chosen from the following list:\n{data.constraint[1]}\nObject entities must be chosen from the following list:\n{data.constraint[2]}\n"
91
+ else:
92
+ data.constraint = f"\n**Triple Extraction Constraint**: The type of entities must be chosen from the following list:\n{constraint}\n"
93
+ else:
94
+ print("OneKE does not support the Triple extraction task yet; please wait for the next version.")
95
+ # print("data.constraint", data.constraint)
96
  return data
97
 
98
  def extract_information_direct(self, data: DataPoint):
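The branching added above maps the number of candidate lists in `data.constraint` (one, two, or three) onto which slots of a triple they constrain. A minimal standalone sketch of that mapping, with a hypothetical helper name and simplified wording:

```python
def format_triple_constraint(constraint: list) -> str:
    """Mimic the branching above: the number of lists decides which
    slot (entity/subject, relation, object) each list constrains."""
    header = "\n**Triple Extraction Constraint**: "
    if len(constraint) == 1:          # one list: entity types only
        return f"{header}Entity types must be chosen from the following list:\n{constraint[0]}\n"
    if len(constraint) == 2:          # two lists: entity and relation types
        labels = ["Entity types", "Relation types"]
    else:                             # three lists: subject, relation, object
        labels = ["Subject entities", "Relation types", "Object entities"]
    # Empty lists are skipped, as in the elif chain above.
    parts = [f"{label} must be chosen from the following list:\n{values}"
             for label, values in zip(labels, constraint) if values]
    return header + "\n".join(parts) + "\n"

print(format_triple_constraint([["Person", "Place"], ["born_in"]]))
```

The empty-list checks reproduce the special cases in the diff, where an empty sub-list leaves that slot unconstrained.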
src/modules/knowledge_base/case_repository.py CHANGED
@@ -19,7 +19,7 @@ class CaseRepository:
19
  self.embedder = SentenceTransformer(docker_model_path)
20
  except:
21
  self.embedder = SentenceTransformer(config['model']['embedding_model'])
22
- self.embedder.to(device)
23
  self.corpus = self.load_corpus()
24
  self.embedded_corpus = self.embed_corpus()
25
 
@@ -27,14 +27,14 @@ class CaseRepository:
27
  with open(os.path.join(os.path.dirname(__file__), "case_repository.json")) as file:
28
  corpus = json.load(file)
29
  return corpus
30
-
31
  def update_corpus(self):
32
  try:
33
  with open(os.path.join(os.path.dirname(__file__), "case_repository.json"), "w") as file:
34
  json.dump(self.corpus, file, indent=2)
35
  except Exception as e:
36
  print(f"Error when updating corpus: {e}")
37
-
38
  def embed_corpus(self):
39
  embedded_corpus = {}
40
  for key, content in self.corpus.items():
@@ -43,8 +43,8 @@ class CaseRepository:
43
  bad_index = [item['index']['embed_index'] for item in content['bad']]
44
  encoded_bad_index = self.embedder.encode(bad_index, convert_to_tensor=True).to(device)
45
  embedded_corpus[key] = {"good": encoded_good_index, "bad": encoded_bad_index}
46
- return embedded_corpus
47
-
48
  def get_similarity_scores(self, task: TaskType, embed_index="", str_index="", case_type="", top_k=2):
49
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
50
  # Embedding similarity match
@@ -58,7 +58,7 @@ class CaseRepository:
58
  scores_dict = {match[0]: match[1] for match in str_similarity_results}
59
  scores_in_order = [scores_dict[candidate] for candidate in str_match_corpus]
60
  str_similarity_scores = torch.tensor(scores_in_order, dtype=torch.float32).to(device)
61
-
62
  # Normalize scores
63
  embedding_score_range = embedding_similarity_scores.max() - embedding_similarity_scores.min()
64
  str_score_range = str_similarity_scores.max() - str_similarity_scores.min()
@@ -74,16 +74,16 @@ class CaseRepository:
74
  # Combine the scores with weights
75
  combined_scores = 0.5 * embed_norm_scores + 0.5 * str_norm_scores
76
  original_combined_scores = 0.5 * embedding_similarity_scores + 0.5 * str_similarity_scores / 100
77
-
78
  scores, indices = torch.topk(combined_scores, k=min(top_k, combined_scores.size(0)))
79
  original_scores, original_indices = torch.topk(original_combined_scores, k=min(top_k, original_combined_scores.size(0)))
80
  return scores, indices, original_scores, original_indices
81
-
82
  def query_case(self, task: TaskType, embed_index="", str_index="", case_type="", top_k=2) -> list:
83
  _, indices, _, _ = self.get_similarity_scores(task, embed_index, str_index, case_type, top_k)
84
  top_matches = [self.corpus[task][case_type][idx]["content"] for idx in indices]
85
  return top_matches
86
-
87
  def update_case(self, task: TaskType, embed_index="", str_index="", content="" ,case_type=""):
88
  self.corpus[task][case_type].append({"index": {"embed_index": embed_index, "str_index": str_index}, "content": content})
89
  self.embedded_corpus[task][case_type] = torch.cat([self.embedded_corpus[task][case_type], self.embedder.encode([embed_index], convert_to_tensor=True).to(device)], dim=0)
@@ -102,9 +102,9 @@ class CaseRepositoryHandler:
102
  response = self.llm.get_chat_response(prompt)
103
  response = extract_json_dict(response)
104
  if not isinstance(response, dict):
105
- return response
106
  return None
107
-
108
  def __get_bad_case_reflection(self, instruction="", text="", original_answer="", correct_answer="", additional_info=""):
109
  prompt = bad_case_reflection_instruction.format(
110
  instruction=instruction, text=text, original_answer=original_answer, correct_answer=correct_answer, additional_info=additional_info
@@ -115,34 +115,34 @@ class CaseRepositoryHandler:
115
  if not isinstance(response, dict):
116
  return response
117
  return None
118
-
119
  def __get_index(self, data: DataPoint, case_type: str):
120
  # set embed_index
121
  embed_index = f"**Text**: {data.distilled_text}\n{data.chunk_text_list[0]}"
122
-
123
  # set str_index
124
  if data.task == "Base":
125
  str_index = f"**Task**: {data.instruction}"
126
  else:
127
  str_index = f"{data.constraint}"
128
-
129
  if case_type == "bad":
130
  str_index += f"\n\n**Original Result**: {json.dumps(data.pred)}"
131
-
132
  return embed_index, str_index
133
-
134
  def query_good_case(self, data: DataPoint):
135
  embed_index, str_index = self.__get_index(data, "good")
136
  return self.repository.query_case(task=data.task, embed_index=embed_index, str_index=str_index, case_type="good")
137
-
138
  def query_bad_case(self, data: DataPoint):
139
  embed_index, str_index = self.__get_index(data, "bad")
140
  return self.repository.query_case(task=data.task, embed_index=embed_index, str_index=str_index, case_type="bad")
141
-
142
  def update_good_case(self, data: DataPoint):
143
  if data.truth == "" :
144
  print("No truth value provided.")
145
- return
146
  embed_index, str_index = self.__get_index(data, "good")
147
  _, _, original_scores, _ = self.repository.get_similarity_scores(data.task, embed_index, str_index, "good", 1)
148
  original_scores = original_scores.tolist()
@@ -159,11 +159,11 @@ class CaseRepositoryHandler:
159
  else:
160
  content = f"{wrapped_text}\n\n{data.constraint}\n\n{wrapped_good_case_analysis}\n\n{wrapped_answer}"
161
  self.repository.update_case(data.task, embed_index, str_index, content, "good")
162
-
163
  def update_bad_case(self, data: DataPoint):
164
  if data.truth == "" :
165
  print("No truth value provided.")
166
- return
167
  if normalize_obj(data.pred) == normalize_obj(data.truth):
168
  return
169
  embed_index, str_index = self.__get_index(data, "bad")
@@ -183,7 +183,7 @@ class CaseRepositoryHandler:
183
  else:
184
  content = f"{wrapped_text}\n\n{data.constraint}\n\n{wrapper_original_answer}\n\n{wrapped_bad_case_reflection}\n\n{wrapper_correct_answer}"
185
  self.repository.update_case(data.task, embed_index, str_index, content, "bad")
186
-
187
  def update_case(self, data: DataPoint):
188
  self.update_good_case(data)
189
  self.update_bad_case(data)
 
19
  self.embedder = SentenceTransformer(docker_model_path)
20
  except:
21
  self.embedder = SentenceTransformer(config['model']['embedding_model'])
22
+ self.embedder.to(device)
23
  self.corpus = self.load_corpus()
24
  self.embedded_corpus = self.embed_corpus()
25
 
 
27
  with open(os.path.join(os.path.dirname(__file__), "case_repository.json")) as file:
28
  corpus = json.load(file)
29
  return corpus
30
+
31
  def update_corpus(self):
32
  try:
33
  with open(os.path.join(os.path.dirname(__file__), "case_repository.json"), "w") as file:
34
  json.dump(self.corpus, file, indent=2)
35
  except Exception as e:
36
  print(f"Error when updating corpus: {e}")
37
+
38
  def embed_corpus(self):
39
  embedded_corpus = {}
40
  for key, content in self.corpus.items():
 
43
  bad_index = [item['index']['embed_index'] for item in content['bad']]
44
  encoded_bad_index = self.embedder.encode(bad_index, convert_to_tensor=True).to(device)
45
  embedded_corpus[key] = {"good": encoded_good_index, "bad": encoded_bad_index}
46
+ return embedded_corpus
47
+
48
  def get_similarity_scores(self, task: TaskType, embed_index="", str_index="", case_type="", top_k=2):
49
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
50
  # Embedding similarity match
 
58
  scores_dict = {match[0]: match[1] for match in str_similarity_results}
59
  scores_in_order = [scores_dict[candidate] for candidate in str_match_corpus]
60
  str_similarity_scores = torch.tensor(scores_in_order, dtype=torch.float32).to(device)
61
+
62
  # Normalize scores
63
  embedding_score_range = embedding_similarity_scores.max() - embedding_similarity_scores.min()
64
  str_score_range = str_similarity_scores.max() - str_similarity_scores.min()
 
74
  # Combine the scores with weights
75
  combined_scores = 0.5 * embed_norm_scores + 0.5 * str_norm_scores
76
  original_combined_scores = 0.5 * embedding_similarity_scores + 0.5 * str_similarity_scores / 100
77
+
78
  scores, indices = torch.topk(combined_scores, k=min(top_k, combined_scores.size(0)))
79
  original_scores, original_indices = torch.topk(original_combined_scores, k=min(top_k, original_combined_scores.size(0)))
80
  return scores, indices, original_scores, original_indices
81
+
82
  def query_case(self, task: TaskType, embed_index="", str_index="", case_type="", top_k=2) -> list:
83
  _, indices, _, _ = self.get_similarity_scores(task, embed_index, str_index, case_type, top_k)
84
  top_matches = [self.corpus[task][case_type][idx]["content"] for idx in indices]
85
  return top_matches
86
+
87
  def update_case(self, task: TaskType, embed_index="", str_index="", content="" ,case_type=""):
88
  self.corpus[task][case_type].append({"index": {"embed_index": embed_index, "str_index": str_index}, "content": content})
89
  self.embedded_corpus[task][case_type] = torch.cat([self.embedded_corpus[task][case_type], self.embedder.encode([embed_index], convert_to_tensor=True).to(device)], dim=0)
 
102
  response = self.llm.get_chat_response(prompt)
103
  response = extract_json_dict(response)
104
  if not isinstance(response, dict):
105
+ return response
106
  return None
107
+
108
  def __get_bad_case_reflection(self, instruction="", text="", original_answer="", correct_answer="", additional_info=""):
109
  prompt = bad_case_reflection_instruction.format(
110
  instruction=instruction, text=text, original_answer=original_answer, correct_answer=correct_answer, additional_info=additional_info
 
115
  if not isinstance(response, dict):
116
  return response
117
  return None
118
+
119
  def __get_index(self, data: DataPoint, case_type: str):
120
  # set embed_index
121
  embed_index = f"**Text**: {data.distilled_text}\n{data.chunk_text_list[0]}"
122
+
123
  # set str_index
124
  if data.task == "Base":
125
  str_index = f"**Task**: {data.instruction}"
126
  else:
127
  str_index = f"{data.constraint}"
128
+
129
  if case_type == "bad":
130
  str_index += f"\n\n**Original Result**: {json.dumps(data.pred)}"
131
+
132
  return embed_index, str_index
133
+
134
  def query_good_case(self, data: DataPoint):
135
  embed_index, str_index = self.__get_index(data, "good")
136
  return self.repository.query_case(task=data.task, embed_index=embed_index, str_index=str_index, case_type="good")
137
+
138
  def query_bad_case(self, data: DataPoint):
139
  embed_index, str_index = self.__get_index(data, "bad")
140
  return self.repository.query_case(task=data.task, embed_index=embed_index, str_index=str_index, case_type="bad")
141
+
142
  def update_good_case(self, data: DataPoint):
143
  if data.truth == "" :
144
  print("No truth value provided.")
145
+ return
146
  embed_index, str_index = self.__get_index(data, "good")
147
  _, _, original_scores, _ = self.repository.get_similarity_scores(data.task, embed_index, str_index, "good", 1)
148
  original_scores = original_scores.tolist()
 
159
  else:
160
  content = f"{wrapped_text}\n\n{data.constraint}\n\n{wrapped_good_case_analysis}\n\n{wrapped_answer}"
161
  self.repository.update_case(data.task, embed_index, str_index, content, "good")
162
+
163
  def update_bad_case(self, data: DataPoint):
164
  if data.truth == "" :
165
  print("No truth value provided.")
166
+ return
167
  if normalize_obj(data.pred) == normalize_obj(data.truth):
168
  return
169
  embed_index, str_index = self.__get_index(data, "bad")
 
183
  else:
184
  content = f"{wrapped_text}\n\n{data.constraint}\n\n{wrapper_original_answer}\n\n{wrapped_bad_case_reflection}\n\n{wrapper_correct_answer}"
185
  self.repository.update_case(data.task, embed_index, str_index, content, "bad")
186
+
187
  def update_case(self, data: DataPoint):
188
  self.update_good_case(data)
189
  self.update_bad_case(data)
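The retrieval logic in `get_similarity_scores` min-max normalizes the embedding scores and the 0-100 fuzzy string scores separately, then combines them with equal weights before taking the top-k. A torch-free sketch of that scheme with hypothetical scores:

```python
def minmax(scores):
    """Min-max normalize to [0, 1]; a degenerate (all-equal) range maps to 0."""
    lo, hi = min(scores), max(scores)
    if hi == lo:
        return [0.0 for _ in scores]
    return [(s - lo) / (hi - lo) for s in scores]

embed_scores = [0.20, 0.80, 0.50]   # hypothetical cosine similarities
str_scores = [40.0, 90.0, 70.0]     # hypothetical fuzzy-match scores (0-100 scale)

# Equal-weight combination of the two normalized signals, as in the source.
combined = [0.5 * e + 0.5 * s
            for e, s in zip(minmax(embed_scores), minmax(str_scores))]
best = max(range(len(combined)), key=combined.__getitem__)
print(best)  # case 1 dominates on both signals
```

Normalizing each signal first keeps the 0-100 fuzzy scores from drowning out the cosine similarities in the weighted sum.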
src/modules/knowledge_base/schema_repository.py CHANGED
@@ -33,6 +33,19 @@ class Event(BaseModel):
33
  class EventList(BaseModel):
34
  event_list : List[Event] = Field(description="The events presented in the text.")
35
 
36
  # ==================================================================== #
37
  # TEXT DESCRIPTION #
38
  # ==================================================================== #
 
33
  class EventList(BaseModel):
34
  event_list : List[Event] = Field(description="The events presented in the text.")
35
 
36
+ # ==================================================================== #
37
+ # Triple TASK #
38
+ # ==================================================================== #
39
+ class Triple(BaseModel):
40
+ head: str = Field(description="The subject or head of the triple.")
41
+ head_type: str = Field(description="The type of the subject entity.")
42
+ relation: str = Field(description="The predicate or relation between the entities.")
43
+ relation_type: str = Field(description="The type of the relation.")
44
+ tail: str = Field(description="The object or tail of the triple.")
45
+ tail_type: str = Field(description="The type of the object entity.")
46
+ class TripleList(BaseModel):
47
+ triple_list: List[Triple] = Field(description="The collection of triples and their types presented in the text.")
48
+
49
  # ==================================================================== #
50
  # TEXT DESCRIPTION #
51
  # ==================================================================== #
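An extraction result conforming to the new `TripleList` schema serializes to a JSON object with a `triple_list` key; each entry carries the head, relation, and tail plus their types. A small sketch with hypothetical data showing how such a result can be flattened into plain (head, relation, tail) tuples:

```python
import json

# Hypothetical extraction result shaped like the TripleList schema above.
result = json.dumps({
    "triple_list": [
        {"head": "Ada Lovelace", "head_type": "Person",
         "relation": "collaborated_with", "relation_type": "Social",
         "tail": "Charles Babbage", "tail_type": "Person"},
    ]
})

# Downstream consumers often only need the (head, relation, tail) skeleton.
triples = [(t["head"], t["relation"], t["tail"])
           for t in json.loads(result)["triple_list"]]
print(triples)  # [('Ada Lovelace', 'collaborated_with', 'Charles Babbage')]
```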
src/modules/reflection_agent.py CHANGED
@@ -5,7 +5,7 @@ from .knowledge_base.case_repository import CaseRepositoryHandler
5
  class ReflectionGenerator:
6
  def __init__(self, llm: BaseEngine):
7
  self.llm = llm
8
-
9
  def get_reflection(self, instruction="", examples="", text="",schema="", result=""):
10
  result = json.dumps(result)
11
  examples = bad_case_wrapper(examples)
@@ -13,7 +13,7 @@ class ReflectionGenerator:
13
  response = self.llm.get_chat_response(prompt)
14
  response = extract_json_dict(response)
15
  return response
16
-
17
  class ReflectionAgent:
18
  def __init__(self, llm: BaseEngine, case_repo: CaseRepositoryHandler):
19
  self.llm = llm
@@ -29,7 +29,7 @@ class ReflectionAgent:
29
  else:
30
  selected_obj = max(result_list, key=lambda o: len(json.dumps(o)))
31
  return selected_obj
32
-
33
  def __self_consistance_check(self, data: DataPoint):
34
  extract_func = list(data.result_trajectory.keys())[-1]
35
  if hasattr(self.extractor, extract_func):
@@ -55,7 +55,7 @@ class ReflectionAgent:
55
  consistant_result.append(selected_element)
56
  data.set_result_list(consistant_result)
57
  return reflect_index
58
-
59
  def reflect_with_case(self, data: DataPoint):
60
  if data.result_list == []:
61
  return data
@@ -71,4 +71,3 @@ class ReflectionAgent:
71
  function_name = current_function_name()
72
  data.update_trajectory(function_name, data.result_list)
73
  return data
74
-
 
5
  class ReflectionGenerator:
6
  def __init__(self, llm: BaseEngine):
7
  self.llm = llm
8
+
9
  def get_reflection(self, instruction="", examples="", text="",schema="", result=""):
10
  result = json.dumps(result)
11
  examples = bad_case_wrapper(examples)
 
13
  response = self.llm.get_chat_response(prompt)
14
  response = extract_json_dict(response)
15
  return response
16
+
17
  class ReflectionAgent:
18
  def __init__(self, llm: BaseEngine, case_repo: CaseRepositoryHandler):
19
  self.llm = llm
 
29
  else:
30
  selected_obj = max(result_list, key=lambda o: len(json.dumps(o)))
31
  return selected_obj
32
+
33
  def __self_consistance_check(self, data: DataPoint):
34
  extract_func = list(data.result_trajectory.keys())[-1]
35
  if hasattr(self.extractor, extract_func):
 
55
  consistant_result.append(selected_element)
56
  data.set_result_list(consistant_result)
57
  return reflect_index
58
+
59
  def reflect_with_case(self, data: DataPoint):
60
  if data.result_list == []:
61
  return data
 
71
  function_name = current_function_name()
72
  data.update_trajectory(function_name, data.result_list)
73
  return data
 
src/modules/schema_agent.py CHANGED
@@ -106,6 +106,18 @@ class Event(BaseModel):
106
  class EventList(BaseModel):
107
  event_list : List[Event] = Field(description="The events presented in the text.")
108
  """
109
  return data
110
 
111
  def get_default_schema(self, data: DataPoint):
 
106
  class EventList(BaseModel):
107
  event_list : List[Event] = Field(description="The events presented in the text.")
108
  """
109
+ elif data.task == "Triple":
110
+ data.print_schema = """
111
+ class Triple(BaseModel):
112
+ head: str = Field(description="The subject or head of the triple.")
113
+ head_type: str = Field(description="The type of the subject entity.")
114
+ relation: str = Field(description="The predicate or relation between the entities.")
115
+ relation_type: str = Field(description="The type of the relation.")
116
+ tail: str = Field(description="The object or tail of the triple.")
117
+ tail_type: str = Field(description="The type of the object entity.")
118
+ class TripleList(BaseModel):
119
+ triple_list: List[Triple] = Field(description="The collection of triples and their types presented in the text.")
120
+ """
121
  return data
122
 
123
  def get_default_schema(self, data: DataPoint):
src/pipeline.py CHANGED
@@ -2,6 +2,7 @@ from typing import Literal
2
  from models import *
3
  from utils import *
4
  from modules import *
5
 
6
 
7
  class Pipeline:
@@ -14,7 +15,7 @@ class Pipeline:
14
 
15
  def __check_consistancy(self, llm, task, mode, update_case):
16
  if llm.name == "OneKE":
17
- if task == "Base":
18
  raise ValueError("The finetuned OneKE only supports quick extraction mode for NER, RE and EE Task.")
19
  else:
20
  mode = "quick"
@@ -44,12 +45,16 @@ class Pipeline:
44
  elif data.task == "EE":
45
  data.instruction = config['agent']['default_ee']
46
  data.output_schema = "EventList"
47
  return data
48
 
49
  # main entry
50
  def get_extract_result(self,
51
  task: TaskType,
52
  three_agents = {},
53
  instruction: str = "",
54
  text: str = "",
55
  output_schema: str = "",
@@ -61,6 +66,7 @@ class Pipeline:
61
  update_case: bool = False,
62
  show_trajectory: bool = False,
63
  isgui: bool = False,
64
  ):
65
  # for key, value in locals().items():
66
  # print(f"{key}: {value}")
@@ -105,7 +111,17 @@ class Pipeline:
105
  # show result
106
  if show_trajectory:
107
  print("Extraction Trajectory: \n", json.dumps(data.get_result_trajectory(), indent=2))
108
- print("Extraction Result: \n", json.dumps(data.pred, indent=2))
109
 
110
  frontend_res = data.pred #
111
 
 
2
  from models import *
3
  from utils import *
4
  from modules import *
5
+ from construct import *
6
 
7
 
8
  class Pipeline:
 
15
 
16
  def __check_consistancy(self, llm, task, mode, update_case):
17
  if llm.name == "OneKE":
18
+ if task == "Base" or task == "Triple":
19
  raise ValueError("The finetuned OneKE only supports quick extraction mode for NER, RE and EE Task.")
20
  else:
21
  mode = "quick"
 
45
  elif data.task == "EE":
46
  data.instruction = config['agent']['default_ee']
47
  data.output_schema = "EventList"
48
+ elif data.task == "Triple":
49
+ data.instruction = config['agent']['default_triple']
50
+ data.output_schema = "TripleList"
51
  return data
52
 
53
  # main entry
54
  def get_extract_result(self,
55
  task: TaskType,
56
  three_agents = {},
57
+ construct = {},
58
  instruction: str = "",
59
  text: str = "",
60
  output_schema: str = "",
 
66
  update_case: bool = False,
67
  show_trajectory: bool = False,
68
  isgui: bool = False,
69
+ iskg: bool = False,
70
  ):
71
  # for key, value in locals().items():
72
  # print(f"{key}: {value}")
 
111
  # show result
112
  if show_trajectory:
113
  print("Extraction Trajectory: \n", json.dumps(data.get_result_trajectory(), indent=2))
114
+ extraction_result = json.dumps(data.pred, indent=2)
115
+ print("Extraction Result: \n", extraction_result)
116
+
117
+ # construct KG
118
+ if iskg:
119
+ myurl = construct['url']
120
+ myusername = construct['username']
121
+ mypassword = construct['password']
122
+ print(f"Constructing the KG in your {construct['database']} now...")
123
+ cypher_statements = generate_cypher_statements(extraction_result)
124
+ execute_cypher_statements(uri=myurl, user=myusername, password=mypassword, cypher_statements=cypher_statements)
125
 
126
  frontend_res = data.pred #
127
 
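The `iskg` branch above reads four connection fields out of the `construct` dict (`url`, `username`, `password`, plus `database` for the log message) before generating and executing Cypher statements. A minimal validation sketch with placeholder values (real values come from the YAML `construct` block):

```python
# Placeholder connection settings for illustration only.
construct = {
    "database": "Neo4j",
    "url": "neo4j://localhost:7687",
    "username": "neo4j",
    "password": "your_password",
}

# The pipeline indexes these keys directly, so missing keys would raise
# a KeyError at runtime; checking up front gives a clearer error.
required = {"database", "url", "username", "password"}
missing = required - construct.keys()
assert not missing, f"construct config is missing keys: {sorted(missing)}"
print("construct config ok")
```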
src/run.py CHANGED
@@ -11,9 +11,9 @@ from modules import *
11
  def main():
12
  # Create command-line argument parser
13
  parser = argparse.ArgumentParser(description='Run the extraction framework.')
14
- parser.add_argument('--config', type=str, required=True,
15
  help='Path to the YAML configuration file.')
16
-
17
  # Parse command-line arguments
18
  args = parser.parse_args()
19
 
@@ -35,6 +35,15 @@ def main():
35
  pipeline = Pipeline(model)
36
  # Extraction config
37
  extraction_config = config['extraction']
38
  result, trajectory, _, _ = pipeline.get_extract_result(task=extraction_config['task'], instruction=extraction_config['instruction'], text=extraction_config['text'], output_schema=extraction_config['output_schema'], constraint=extraction_config['constraint'], use_file=extraction_config['use_file'], file_path=extraction_config['file_path'], truth=extraction_config['truth'], mode=extraction_config['mode'], update_case=extraction_config['update_case'], show_trajectory=extraction_config['show_trajectory'])
39
  return
40
 
 
11
  def main():
12
  # Create command-line argument parser
13
  parser = argparse.ArgumentParser(description='Run the extraction framework.')
14
+ parser.add_argument('--config', type=str, required=True,
15
  help='Path to the YAML configuration file.')
16
+
17
  # Parse command-line arguments
18
  args = parser.parse_args()
19
 
 
35
  pipeline = Pipeline(model)
36
  # Extraction config
37
  extraction_config = config['extraction']
38
+ # Construct config
39
+ if 'construct' in config:
40
+ construct_config = config['construct']
41
+ result, trajectory, _, _ = pipeline.get_extract_result(task=extraction_config['task'], instruction=extraction_config['instruction'], text=extraction_config['text'], output_schema=extraction_config['output_schema'], constraint=extraction_config['constraint'], use_file=extraction_config['use_file'], file_path=extraction_config['file_path'], truth=extraction_config['truth'], mode=extraction_config['mode'], update_case=extraction_config['update_case'], show_trajectory=extraction_config['show_trajectory'],
42
+ construct=construct_config, iskg=True) # When 'construct' is provided, 'iskg' should be True to construct the knowledge graph.
43
+ return
44
+ else:
45
+ print("No 'construct' config found in the YAML file; skipping knowledge graph construction.")
46
+
47
  result, trajectory, _, _ = pipeline.get_extract_result(task=extraction_config['task'], instruction=extraction_config['instruction'], text=extraction_config['text'], output_schema=extraction_config['output_schema'], constraint=extraction_config['constraint'], use_file=extraction_config['use_file'], file_path=extraction_config['file_path'], truth=extraction_config['truth'], mode=extraction_config['mode'], update_case=extraction_config['update_case'], show_trajectory=extraction_config['show_trajectory'])
48
  return
49
 
src/utils/__init__.py CHANGED
@@ -1,3 +1,2 @@
1
  from .process import *
2
  from .data_def import DataPoint, TaskType
3
-
 
1
  from .process import *
2
  from .data_def import DataPoint, TaskType
 
src/utils/process.py CHANGED
@@ -17,28 +17,28 @@ import inspect
17
  import ast
18
  with open(os.path.join(os.path.dirname(__file__), "..", "config.yaml")) as file:
19
  config = yaml.safe_load(file)
20
-
21
- # Load configuration
22
  def load_extraction_config(yaml_path):
23
  # Read YAML content from the file path
24
  if not os.path.exists(yaml_path):
25
  print(f"Error: The config file '{yaml_path}' does not exist.")
26
  return {}
27
-
28
  with open(yaml_path, 'r') as file:
29
  config = yaml.safe_load(file)
30
 
31
  # Extract the 'extraction' configuration dictionary
32
  model_config = config.get('model', {})
33
  extraction_config = config.get('extraction', {})
34
-
35
  # Model config
36
  model_name_or_path = model_config.get('model_name_or_path', "")
37
  model_category = model_config.get('category', "")
38
  api_key = model_config.get('api_key', "")
39
  base_url = model_config.get('base_url', "")
40
  vllm_serve = model_config.get('vllm_serve', False)
41
-
42
  # Extraction config
43
  task = extraction_config.get('task', "")
44
  instruction = extraction_config.get('instruction', "")
@@ -52,6 +52,43 @@ def load_extraction_config(yaml_path):
52
  update_case = extraction_config.get('update_case', False)
53
  show_trajectory = extraction_config.get('show_trajectory', False)
54
 
55
  # Return a dictionary containing these variables
56
  return {
57
  "model": {
@@ -75,7 +112,7 @@ def load_extraction_config(yaml_path):
75
  "show_trajectory": show_trajectory
76
  }
77
  }
78
-
79
  # Split the string text into chunks
80
  def chunk_str(text):
81
  sentences = sent_tokenize(text)
@@ -102,24 +139,24 @@ def chunk_file(file_path):
102
  pages = []
103
 
104
  if file_path.endswith(".pdf"):
105
- loader = PyPDFLoader(file_path)
106
  elif file_path.endswith(".txt"):
107
- loader = TextLoader(file_path)
108
  elif file_path.endswith(".docx"):
109
- loader = Docx2txtLoader(file_path)
110
  elif file_path.endswith(".html"):
111
- loader = BSHTMLLoader(file_path)
112
  elif file_path.endswith(".json"):
113
- loader = JSONLoader(file_path)
114
  else:
115
  raise ValueError("Unsupported file format") # Inform that the format is unsupported
116
-
117
- pages = loader.load_and_split()
118
  docs = ""
119
  for item in pages:
120
  docs += item.page_content
121
  pages = chunk_str(docs)
122
-
123
  return pages
124
 
125
  def process_single_quotes(text):
@@ -147,11 +184,11 @@ def remove_empty_values(data):
147
  def extract_json_dict(text):
148
  if isinstance(text, dict):
149
  return text
150
- pattern = r'\{(?:[^{}]|(?:\{(?:[^{}]|(?:\{[^{}]*\})*)*\})*)*\}'
151
- matches = re.findall(pattern, text)
152
  if matches:
153
- json_string = matches[-1]
154
- json_string = process_single_quotes(json_string)
155
  try:
156
  json_dict = json.loads(json_string)
157
  json_dict = remove_empty_values(json_dict)
@@ -159,9 +196,9 @@ def extract_json_dict(text):
159
  return "No valid information found."
160
  return json_dict
161
  except json.JSONDecodeError:
162
- return json_string
163
  else:
164
- return text
165
 
166
  def good_case_wrapper(example: str):
167
  if example is None or example == "":
@@ -182,10 +219,10 @@ def example_wrapper(example: str):
182
  return example
183
 
184
  def remove_redundant_space(s):
185
- s = ' '.join(s.split())
186
- s = re.sub(r"\s*(,|:|\(|\)|\.|_|;|'|-)\s*", r'\1', s)
187
  return s
188
-
189
  def format_string(s):
190
  s = remove_redundant_space(s)
191
  s = s.lower()
@@ -197,9 +234,9 @@ def format_string(s):
197
  return s
198
 
199
  def calculate_metrics(y_truth: set, y_pred: set):
200
- TP = len(y_truth & y_pred)
201
- FN = len(y_truth - y_pred)
202
- FP = len(y_pred - y_truth)
203
  precision = TP / (TP + FP) if (TP + FP) > 0 else 0
204
  recall = TP / (TP + FN) if (TP + FN) > 0 else 0
205
  f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
@@ -214,11 +251,11 @@ def current_function_name():
214
  else:
215
  print("No caller function found")
216
  return None
217
-
218
  except Exception as e:
219
  print(f"An error occurred: {e}")
220
- pass
221
-
222
  def normalize_obj(value):
223
  if isinstance(value, dict):
224
  return frozenset((k, normalize_obj(v)) for k, v in value.items())
 
17
  import ast
18
  with open(os.path.join(os.path.dirname(__file__), "..", "config.yaml")) as file:
19
  config = yaml.safe_load(file)
20
+
21
+ # Load configuration
22
  def load_extraction_config(yaml_path):
23
  # Read YAML content from the file path
24
  if not os.path.exists(yaml_path):
25
  print(f"Error: The config file '{yaml_path}' does not exist.")
26
  return {}
27
+
28
  with open(yaml_path, 'r') as file:
29
  config = yaml.safe_load(file)
30
 
31
  # Extract the 'extraction' configuration dictionary
32
  model_config = config.get('model', {})
33
  extraction_config = config.get('extraction', {})
34
+
35
  # Model config
36
  model_name_or_path = model_config.get('model_name_or_path', "")
37
  model_category = model_config.get('category', "")
38
  api_key = model_config.get('api_key', "")
39
  base_url = model_config.get('base_url', "")
40
  vllm_serve = model_config.get('vllm_serve', False)
41
+
42
  # Extraction config
43
  task = extraction_config.get('task', "")
44
  instruction = extraction_config.get('instruction', "")
 
52
  update_case = extraction_config.get('update_case', False)
53
  show_trajectory = extraction_config.get('show_trajectory', False)
54
 
55
+ # Construct config (optional: for constructing your knowledge graph)
56
+ if 'construct' in config:
57
+ construct_config = config.get('construct', {})
58
+ database = construct_config.get('database', "")
59
+ url = construct_config.get('url', "")
60
+ username = construct_config.get('username', "")
61
+ password = construct_config.get('password', "")
62
+ # Return a dictionary containing these variables
63
+ return {
64
+ "model": {
65
+ "model_name_or_path": model_name_or_path,
66
+ "category": model_category,
67
+ "api_key": api_key,
68
+ "base_url": base_url,
69
+ "vllm_serve": vllm_serve
70
+ },
71
+ "extraction": {
72
+ "task": task,
73
+ "instruction": instruction,
74
+ "text": text,
75
+ "output_schema": output_schema,
76
+ "constraint": constraint,
77
+ "truth": truth,
78
+ "use_file": use_file,
79
+ "file_path": file_path,
80
+ "mode": mode,
81
+ "update_case": update_case,
82
+ "show_trajectory": show_trajectory
83
+ },
84
+ "construct": {
85
+ "database": database,
86
+ "url": url,
87
+ "username": username,
88
+ "password": password
89
+ }
90
+ }
91
+
92
  # Return a dictionary containing these variables
93
  return {
94
  "model": {
 
112
  "show_trajectory": show_trajectory
113
  }
114
  }
115
+
    # Split the string text into chunks
    def chunk_str(text):
        sentences = sent_tokenize(text)

        pages = []

        if file_path.endswith(".pdf"):
+           loader = PyPDFLoader(file_path)
        elif file_path.endswith(".txt"):
+           loader = TextLoader(file_path)
        elif file_path.endswith(".docx"):
+           loader = Docx2txtLoader(file_path)
        elif file_path.endswith(".html"):
+           loader = BSHTMLLoader(file_path)
        elif file_path.endswith(".json"):
+           loader = JSONLoader(file_path)
        else:
            raise ValueError("Unsupported file format")  # Inform that the format is unsupported
+
+       pages = loader.load_and_split()
        docs = ""
        for item in pages:
            docs += item.page_content
        pages = chunk_str(docs)
+
        return pages
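The `if/elif` chain above dispatches on the file suffix to pick a document loader. The same dispatch can be sketched as a suffix-to-loader table, which is easy to extend and to test in isolation (here plain strings stand in for the LangChain loader classes the real code instantiates):

```python
from pathlib import Path

# Hypothetical table-driven version of the suffix dispatch above; the values
# would be loader classes (PyPDFLoader, TextLoader, ...) in the real code,
# but any callables or sentinels work for illustrating the pattern.
def pick_loader(file_path: str, loaders: dict):
    suffix = Path(file_path).suffix.lower()  # normalize ".PDF" -> ".pdf"
    try:
        return loaders[suffix]
    except KeyError:
        # Same behavior as the else branch in the diff.
        raise ValueError("Unsupported file format")
```

A table keeps the unsupported-format error in one place and makes adding a new format a one-line change instead of another `elif`.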

def process_single_quotes(text):

def extract_json_dict(text):
    if isinstance(text, dict):
        return text
+   pattern = r'\{(?:[^{}]|(?:\{(?:[^{}]|(?:\{[^{}]*\})*)*\})*)*\}'
+   matches = re.findall(pattern, text)
    if matches:
+       json_string = matches[-1]
+       json_string = process_single_quotes(json_string)
        try:
            json_dict = json.loads(json_string)
            json_dict = remove_empty_values(json_dict)

                return "No valid information found."
            return json_dict
        except json.JSONDecodeError:
+           return json_string
    else:
+       return text
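The brace pattern added here matches JSON objects with up to two levels of nesting and the code takes the last match, which is useful when an LLM reply wraps its final JSON answer in prose and earlier drafts. A small standalone check of that behavior (the `last_json_object` wrapper is illustrative; the regex is the one from the diff):

```python
import json
import re

# Same brace-matching pattern as in the diff; all groups are non-capturing,
# so re.findall returns full matches. It handles two levels of nested braces.
PATTERN = r'\{(?:[^{}]|(?:\{(?:[^{}]|(?:\{[^{}]*\})*)*\})*)*\}'

def last_json_object(text: str):
    """Return the last JSON object embedded in `text`, or None."""
    matches = re.findall(PATTERN, text)
    return json.loads(matches[-1]) if matches else None
```

Taking `matches[-1]` favors the model's final answer over any intermediate JSON it emitted earlier in the reply.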

def good_case_wrapper(example: str):
    if example is None or example == "":

    return example

def remove_redundant_space(s):
+   s = ' '.join(s.split())
+   s = re.sub(r"\s*(,|:|\(|\)|\.|_|;|'|-)\s*", r'\1', s)
    return s
+
def format_string(s):
    s = remove_redundant_space(s)
    s = s.lower()

    return s

def calculate_metrics(y_truth: set, y_pred: set):
+   TP = len(y_truth & y_pred)
+   FN = len(y_truth - y_pred)
+   FP = len(y_pred - y_truth)
    precision = TP / (TP + FP) if (TP + FP) > 0 else 0
    recall = TP / (TP + FN) if (TP + FN) > 0 else 0
    f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
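The set operations added here give the standard counts directly: the intersection is the true positives, and the two set differences are the misses and false alarms. A self-contained restatement of the same formulas, usable as a quick sanity check:

```python
def calculate_metrics(y_truth: set, y_pred: set):
    # True positives are the overlap; misses (FN) and false alarms (FP)
    # are the two set differences, exactly as in the diff above.
    TP = len(y_truth & y_pred)
    FN = len(y_truth - y_pred)
    FP = len(y_pred - y_truth)
    precision = TP / (TP + FP) if (TP + FP) > 0 else 0
    recall = TP / (TP + FN) if (TP + FN) > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    return precision, recall, f1
```

For example, truth `{a, b, c, d}` against prediction `{a, b, e}` gives TP=2, FN=2, FP=1, so precision 2/3, recall 1/2, and F1 4/7.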
 
        else:
            print("No caller function found")
            return None
+
    except Exception as e:
        print(f"An error occurred: {e}")
+       pass
+
def normalize_obj(value):
    if isinstance(value, dict):
        return frozenset((k, normalize_obj(v)) for k, v in value.items())
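The `normalize_obj` dict branch shown above converts a dict into a `frozenset` of `(key, normalized value)` pairs, making nested objects hashable and order-insensitive so two extraction results can be compared as sets. Only the dict branch appears in the diff; the list/set/tuple branch below is an assumption added so the recursion is runnable end to end:

```python
# Only the dict branch appears in the diff; the sequence branch here is an
# assumed completion so the recursion runs end to end.
def normalize_obj(value):
    if isinstance(value, dict):
        # frozenset of (key, normalized value) pairs: hashable and
        # independent of key order.
        return frozenset((k, normalize_obj(v)) for k, v in value.items())
    if isinstance(value, (list, set, tuple)):
        # Assumed: treat sequences as unordered collections as well.
        return frozenset(normalize_obj(v) for v in value)
    return value
```

With this normalization, two dicts that differ only in key order (or, under the assumed branch, list order) normalize to equal values.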
src/webui.py CHANGED
@@ -1,9 +1,15 @@
- import random
- import json

  import gradio as gr

- from pipeline import Pipeline
  from models import *

  examples = [
      {
@@ -43,7 +49,7 @@ examples = [
        "task": "Base",
        "mode": "quick",
        "use_file": True,
-       "file_path": "data/Harry_Potter_Chapter1.pdf",
        "instruction": "Extract main characters and the background setting from this chapter.",
        "constraint": "",
        "text": "",
@@ -54,13 +60,24 @@ examples = [
        "task": "Base",
        "mode": "quick",
        "use_file": True,
-       "file_path": "data/Tulsi_Gabbard_News.html",
        "instruction": "Extract key information from the given text.",
        "constraint": "",
        "text": "",
        "update_case": False,
        "truth": "",
    },
  ]
65
 
66
 
@@ -75,16 +92,16 @@ def create_interface():
          </p>
          <h1>OneKE: A Dockerized Schema-Guided LLM Agent-based Knowledge Extraction System</h1>
          <p>
-             🌐[<a href="https://oneke.openkg.cn/" target="_blank">Web</a>]
-             ⌨️[<a href="https://github.com/zjunlp/OneKE" target="_blank">Code</a>]
              📹[<a href="http://oneke.openkg.cn/demo.mp4" target="_blank">Video</a>]
          </p>
      </div>
      """)

      example_button_gr = gr.Button("🎲 Quick Start with an Example 🎲")

-
      with gr.Row():
          with gr.Column():
              model_gr = gr.Dropdown(
@@ -103,7 +120,7 @@ def create_interface():
      with gr.Column():
          task_gr = gr.Dropdown(
              label="🎯 Select your Task",
-             choices=["Base", "NER", "RE", "EE"],
              value="Base",
          )
          mode_gr = gr.Dropdown(
@@ -139,6 +156,8 @@ def create_interface():
              return gr.update(visible=False), gr.update(visible=True, label="🕹️ Constraint", placeholder="Enter your RE Constraint")
          elif task == "EE":
              return gr.update(visible=False), gr.update(visible=True, label="🕹️ Constraint", placeholder="Enter your EE Constraint")

      def update_input_fields(use_file):
          if use_file:
@@ -162,7 +181,7 @@ def create_interface():
          gr.update(value=example["file_path"], visible=example["use_file"]),
          gr.update(value=example["text"], visible=not example["use_file"]),
          gr.update(value=example["instruction"], visible=example["task"] == "Base"),
-         gr.update(value=example["constraint"], visible=example["task"] in ["NER", "RE", "EE"]),
          gr.update(value=example["update_case"]),
          gr.update(value=example["truth"]),
          gr.update(value="NOT REQUIRED", visible=False),
@@ -207,7 +226,7 @@ def create_interface():
      if reflection_agent not in ["", "NOT REQUIRED"]:
          agent3["reflection_agent"] = reflection_agent

-     # Call the Pipeline
      _, _, ger_frontend_schema, ger_frontend_res = pipeline.get_extract_result(
          task=task,
          text=text,
@@ -336,6 +355,8 @@ def create_interface():

      return demo

  if __name__ == "__main__":
      interface = create_interface()
-     interface.launch()
 
+ """
+ ....../OneKE$ python src/webui.py
+ """
+
+
  import gradio as gr
+ import json
+ import random

  from models import *
+ from pipeline import Pipeline
+

  examples = [
      {
 
        "task": "Base",
        "mode": "quick",
        "use_file": True,
+       "file_path": "data/input_files/Harry_Potter_Chapter1.pdf",
        "instruction": "Extract main characters and the background setting from this chapter.",
        "constraint": "",
        "text": "",

        "task": "Base",
        "mode": "quick",
        "use_file": True,
+       "file_path": "data/input_files/Tulsi_Gabbard_News.html",
        "instruction": "Extract key information from the given text.",
        "constraint": "",
        "text": "",
        "update_case": False,
        "truth": "",
    },
+   {
+       "task": "Triple",
+       "mode": "quick",
+       "use_file": True,
+       "file_path": "data/input_files/Artificial_Intelligence_Wikipedia.txt",
+       "instruction": "",
+       "constraint": """[["Person", "Place", "Event", "property"], ["Interpersonal", "Located", "Ownership", "Action"]]""",
+       "text": "",
+       "update_case": False,
+       "truth": "",
+   }
  ]
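The new Triple example passes its constraint as a stringified two-list schema: the first list holds the allowed entity types, the second the allowed relation types. A minimal check of that shape (the `parse_triple_constraint` helper is hypothetical, shown only to make the expected format concrete):

```python
import json

# Hypothetical validator for the two-part Triple constraint used in the
# example above: [entity_types, relation_types], serialized as JSON.
def parse_triple_constraint(constraint: str):
    parsed = json.loads(constraint)
    if not (isinstance(parsed, list) and len(parsed) == 2):
        raise ValueError("Triple constraint must be [entity_types, relation_types]")
    entity_types, relation_types = parsed
    return list(entity_types), list(relation_types)
```

Serializing the schema as a JSON string keeps the Gradio textbox input round-trippable: the UI stores a plain string and the backend recovers the two type lists with one `json.loads` call.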
          </p>
          <h1>OneKE: A Dockerized Schema-Guided LLM Agent-based Knowledge Extraction System</h1>
          <p>
+             🌐[<a href="https://oneke.openkg.cn/" target="_blank">Home</a>]
              📹[<a href="http://oneke.openkg.cn/demo.mp4" target="_blank">Video</a>]
+             📝[<a href="https://arxiv.org/abs/2209.10707" target="_blank">Paper</a>]
+             💻[<a href="https://github.com/zjunlp/OneKE" target="_blank">Code</a>]
          </p>
      </div>
      """)

      example_button_gr = gr.Button("🎲 Quick Start with an Example 🎲")

      with gr.Row():
          with gr.Column():
              model_gr = gr.Dropdown(
 
      with gr.Column():
          task_gr = gr.Dropdown(
              label="🎯 Select your Task",
+             choices=["Base", "NER", "RE", "EE", "Triple"],
              value="Base",
          )
          mode_gr = gr.Dropdown(
 
              return gr.update(visible=False), gr.update(visible=True, label="🕹️ Constraint", placeholder="Enter your RE Constraint")
          elif task == "EE":
              return gr.update(visible=False), gr.update(visible=True, label="🕹️ Constraint", placeholder="Enter your EE Constraint")
+         elif task == "Triple":
+             return gr.update(visible=False), gr.update(visible=True, label="🕹️ Constraint", placeholder="Enter your Triple Constraint")

      def update_input_fields(use_file):
          if use_file:
 
          gr.update(value=example["file_path"], visible=example["use_file"]),
          gr.update(value=example["text"], visible=not example["use_file"]),
          gr.update(value=example["instruction"], visible=example["task"] == "Base"),
+         gr.update(value=example["constraint"], visible=example["task"] in ["NER", "RE", "EE", "Triple"]),
          gr.update(value=example["update_case"]),
          gr.update(value=example["truth"]),
          gr.update(value="NOT REQUIRED", visible=False),
 
      if reflection_agent not in ["", "NOT REQUIRED"]:
          agent3["reflection_agent"] = reflection_agent

+     # Use the Pipeline to run extraction
      _, _, ger_frontend_schema, ger_frontend_res = pipeline.get_extract_result(
          task=task,
          text=text,
 

      return demo

+
+ # Launch the front-end interface
  if __name__ == "__main__":
      interface = create_interface()
+     interface.launch()  # the default Gradio URL is usually 127.0.0.1:7860