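"""Elasticsearch-backed vector store utilities built on LangChain.

Provides the ElsStore class, which wraps ElasticsearchStore with Ollama
embeddings and handles file parsing, chunking, index creation, incremental
adds, and index management.
"""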
from elasticsearch import Elasticsearch
from langchain_elasticsearch.vectorstores import ElasticsearchStore
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.document_loaders import TextLoader, UnstructuredCSVLoader, UnstructuredPDFLoader, \
    UnstructuredWordDocumentLoader, UnstructuredExcelLoader, UnstructuredMarkdownLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from .asr_utils import get_spk_txt
import requests


class ElsStore:
    """Vector store backed by Elasticsearch, using Ollama embeddings."""

    def __init__(self, embedding="mofanke/acge_text_embedding:latest", es_url="http://localhost:9200",
                 index_name='test_index'):
        self.embedding = OllamaEmbeddings(model=embedding)
        self.es_url = es_url
        self.elastic_vector_search = ElasticsearchStore(
            es_url=self.es_url,
            index_name=index_name,
            embedding=self.embedding
        )

    def parse_data(self, file):
        """Load a file into LangChain documents, picking a loader by extension."""
        name = file.lower()
        data = []
        if name.endswith((".txt", ".csv")):
            try:
                data = UnstructuredCSVLoader(file).load()
            except Exception:
                # Fall back to a plain-text loader if CSV parsing fails.
                data = TextLoader(file, encoding="utf-8").load()
        elif name.endswith((".doc", ".docx")):
            data = UnstructuredWordDocumentLoader(file).load()
        elif name.endswith(".pdf"):
            data = UnstructuredPDFLoader(file).load()
        elif name.endswith(".xlsx"):
            data = UnstructuredExcelLoader(file).load()
        elif name.endswith(".md"):
            data = UnstructuredMarkdownLoader(file).load()
        elif name.endswith((".mp3", ".mp4", ".wav")):
            # Transcribe the audio/video to text, then load the transcript.
            fw = get_spk_txt(file)
            data = UnstructuredCSVLoader(fw).load()
            # Point each chunk's source back at the original media file.
            for doc in data:
                doc.metadata["source"] = file
        return data

    def get_count(self, c_name):
        # Return the number of stored chunks (documents) in the given index.
        es = Elasticsearch(self.es_url)
        response = es.count(index=c_name)
        return response['count']

    # Create a new index from the given files and populate it.
    def create_collection(self, files, c_name, chunk_size=200, chunk_overlap=50):
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        print("开始创建数据库 ....")
        tmps = []
        for file in files:
            data = self.parse_data(file)
            tmps.extend(data)

        splits = self.text_splitter.split_documents(tmps)

        self.elastic_vector_search = ElasticsearchStore.from_documents(
            documents=splits,
            embedding=self.embedding,
            es_url=self.es_url,
            index_name=c_name,
        )

        self.elastic_vector_search.client.indices.refresh(index=c_name)

        print("数据块总量:", self.get_count(c_name))

        return self.elastic_vector_search

    # Add data to an existing index. (The name appears to be a holdover from
    # a Chroma-based version; it is kept for compatibility with callers.)
    def add_chroma(self, files, c_name, chunk_size=200, chunk_overlap=50):
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        print("开始添加文件...")
        tmps = []
        for file in files:
            data = self.parse_data(file)
            tmps.extend(data)

        splits = self.text_splitter.split_documents(tmps)

        self.elastic_vector_search = ElasticsearchStore(
            es_url=self.es_url,
            index_name=c_name,
            embedding=self.embedding
        )
        self.elastic_vector_search.add_documents(splits)
        self.elastic_vector_search.client.indices.refresh(index=c_name)
        print("数据块总量:", self.get_count(c_name))

        return self.elastic_vector_search

    # Delete a knowledge-base collection (index).
    def delete_collection(self, c_name):
        url = self.es_url + "/" + c_name
        # Send a DELETE request against the index endpoint.
        response = requests.delete(url)

        # Report the outcome based on the response status code.
        if response.status_code == 200:
            return f"Index '{c_name}' deleted successfully."
        elif response.status_code == 404:
            return f"Index '{c_name}' does not exist."
        else:
            return f"Failed to delete index: {response.status_code}, {response.text}"

    # List all current index names.
    def get_all_collections_name(self):
        indices = self.elastic_vector_search.client.indices.get_alias()
        index_names = list(indices.keys())

        return index_names

    # List the source files stored in a collection (not implemented yet).
    def get_collection_content_files(self, c_name):
        return []

    # Delete a specific file's chunks from a collection (not implemented yet).
    def del_files(self, del_files_name, c_name):
        return None
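

# Minimal usage sketch, assuming a local Elasticsearch at http://localhost:9200
# and a running Ollama server with the embedding model already pulled; the
# file path and index name below are hypothetical placeholders. Because of the
# relative import above, run this as a module (python -m <package>.<module>).
if __name__ == "__main__":
    store = ElsStore()
    # Build an index from one or more files, then run a similarity search.
    store.create_collection(["docs/manual.pdf"], "demo_index")
    hits = store.elastic_vector_search.similarity_search("What does the manual cover?", k=3)
    for hit in hits:
        print(hit.metadata.get("source"), hit.page_content[:80])
    print(store.delete_collection("demo_index"))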