File size: 3,978 Bytes
c604980
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import os

from py2neo import Graph, Node, Relationship
from langchain_community.document_loaders import TextLoader,UnstructuredCSVLoader, UnstructuredPDFLoader,UnstructuredWordDocumentLoader,UnstructuredExcelLoader,UnstructuredMarkdownLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter



class KnowledgeGraph:
    """Helper around a py2neo ``Graph``: loads documents from files and
    performs get-or-create style node and relationship management."""

    def __init__(self, uri, user, password):
        """Connect to the graph database at *uri* using basic-auth credentials."""
        self.graph = Graph(uri, auth=(user, password))

    def parse_data(self, file):
        """Load *file* into a list of LangChain documents, choosing the loader by extension.

        Supported extensions (case-insensitive): .txt, .csv, .doc, .docx, .pdf,
        .xls, .xlsx, .md.

        Args:
            file: path to the file to load (``str`` or path-like).

        Returns:
            The documents returned by the selected loader's ``load()``.

        Raises:
            ValueError: if the file extension is not supported.
        """
        # Match on the real extension instead of a substring of the whole path,
        # so e.g. "pdf_reports/a.txt" is not also handed to the PDF loader.
        ext = os.path.splitext(file)[1].lower()
        if ext in (".txt", ".csv"):
            # Try the CSV loader first and fall back to plain UTF-8 text if it
            # fails. NOTE(review): .txt files also go through the CSV loader
            # first (kept from the original) — confirm this is intended.
            try:
                return UnstructuredCSVLoader(file).load()
            except Exception:
                return TextLoader(file, encoding="utf-8").load()
        if ext in (".doc", ".docx"):
            return UnstructuredWordDocumentLoader(file).load()
        if ext == ".pdf":
            return UnstructuredPDFLoader(file).load()
        if ext in (".xls", ".xlsx"):
            return UnstructuredExcelLoader(file).load()
        if ext == ".md":
            return UnstructuredMarkdownLoader(file).load()
        raise ValueError(f"Unsupported file type: {file!r}")

    # Split data
    def split_files(self, files, chunk_size=500, chunk_overlap=100):
        """Load every file in *files* and split the combined documents into chunks.

        Args:
            files: iterable of file paths accepted by ``parse_data``.
            chunk_size: chunk size passed to ``RecursiveCharacterTextSplitter``.
            chunk_overlap: overlap between consecutive chunks, passed to the splitter.

        Returns:
            The list of split documents.
        """
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        print("开始创建数据库 ....")
        documents = []
        for file in files:
            documents.extend(self.parse_data(file))
        return text_splitter.split_documents(documents)

    def create_node(self, label, properties):
        """Return the first node matching *label* and *properties*, creating it if none exists."""
        # Query once; ``first()`` returns None when nothing matches. Compare with
        # ``is not None`` so an existing node that is falsy is still reused.
        existing = self.graph.nodes.match(label, **properties).first()
        if existing is not None:
            return existing
        node = Node(label, **properties)
        self.graph.create(node)
        return node

    def create_relationship(self, label1, properties1, label2, properties2, relationship_type,
                            relationship_properties=None):
        """Get or create a relationship of *relationship_type* from one node to another.

        Both endpoint nodes are obtained via ``create_node`` (get-or-create).
        An existing relationship between them of the same type whose
        properties include all of *relationship_properties* is returned
        instead of creating a new one.

        Args:
            label1, properties1: label and properties of the start node.
            label2, properties2: label and properties of the end node.
            relationship_type: relationship type name.
            relationship_properties: optional dict of relationship properties
                (defaults to no properties).

        Returns:
            The existing or newly created relationship.
        """
        if relationship_properties is None:
            relationship_properties = {}
        node1 = self.create_node(label1, properties1)
        node2 = self.create_node(label2, properties2)

        for rel in self.graph.match((node1, node2), r_type=relationship_type):
            if all(rel[key] == value for key, value in relationship_properties.items()):
                return rel

        relationship = Relationship(node1, relationship_type, node2, **relationship_properties)
        self.graph.create(relationship)
        return relationship

    def delete_node(self, label, properties):
        """Delete the first node matching *label*/*properties*.

        Returns:
            True if a node was found and deleted, False otherwise.
        """
        node = self.graph.nodes.match(label, **properties).first()
        if node is None:
            return False
        self.graph.delete(node)
        return True

    def update_node(self, label, identifier, updates):
        """Set each key/value in *updates* on the first node matching *label*/*identifier*.

        Returns:
            The updated node, or None if no node matched.
        """
        node = self.graph.nodes.match(label, **identifier).first()
        if node is None:
            return None
        for key, value in updates.items():
            node[key] = value
        self.graph.push(node)
        return node

    def find_node(self, label, properties):
        """Return a list of all nodes matching *label* and *properties*."""
        return list(self.graph.nodes.match(label, **properties))

    def create_nodes(self, label, properties_list):
        """Get or create one node per properties dict in *properties_list*, in order."""
        return [self.create_node(label, properties) for properties in properties_list]

    def create_relationships(self, relationships):
        """Get or create each relationship described in *relationships*.

        Each item is ``(label1, properties1, label2, properties2, relationship_type)``
        with an optional sixth element holding the relationship-properties dict.

        Returns:
            The relationships in the same order as the input.

        Raises:
            ValueError: if an item does not have 5 or 6 elements.
        """
        created_relationships = []
        for rel in relationships:
            label1, properties1, label2, properties2, relationship_type, *extra = rel
            if len(extra) > 1:
                raise ValueError(f"Relationship tuple must have 5 or 6 elements, got {len(rel)}")
            relationship_properties = extra[0] if extra else None
            created_relationships.append(
                self.create_relationship(label1, properties1, label2, properties2,
                                         relationship_type, relationship_properties)
            )
        return created_relationships