import pickle
import re

import faiss

from llama_index.core import (
    QueryBundle,
    Settings,
    SimpleDirectoryReader,
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
)
from llama_index.core.node_parser import SentenceWindowNodeParser, SimpleNodeParser
from llama_index.core.postprocessor import MetadataReplacementPostProcessor
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.response.notebook_utils import display_source_node
from llama_index.core.retrievers import RecursiveRetriever
from llama_index.core.schema import IndexNode, MetadataMode
from llama_index.embeddings.huggingface import HuggingFaceEmbedding  # for the commented-out local setup below
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.groq import Groq  # for the commented-out local setup below
from llama_index.llms.openai import OpenAI
from llama_index.postprocessor.cohere_rerank import CohereRerank
from llama_index.readers.file import DocxReader
from llama_index.vector_stores.faiss import FaissVectorStore

from core.config import settings
from prompt.prompt import qa_prompt_tmpl, refine_prompt_tmpl

# Global LlamaIndex settings. Uncomment to use a local HuggingFace embedding
# model and a Groq-hosted LLM instead of OpenAI:
# Settings.embed_model = HuggingFaceEmbedding(model_name=settings.EMBEDDING_MODEL)
# Settings.llm = Groq(model=settings.MODEL_ID, api_key=settings.MODEL_API_KEY)
Settings.embed_model = OpenAIEmbedding(
    model=settings.OPENAI_EMBEDDING_MODEL,
    api_key=settings.OPENAI_API_KEY,
)
Settings.llm = OpenAI(
    model=settings.OPENAI_MODEL,
    api_key=settings.OPENAI_API_KEY,
    max_tokens=512,
)
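# Note: the FAISS index dimension used in the setup functions below
# (settings.OPENAI_EMBEDDING_MODEL_DIMS) must match the embedding model
# configured here, e.g. 1536 for text-embedding-ada-002 or text-embedding-3-small.
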
def windows_parser(documents: list) -> None:
    """Chunk documents into per-sentence nodes with surrounding context
    windows, embed them into a FAISS index, and persist it to ./storage."""
    d = settings.OPENAI_EMBEDDING_MODEL_DIMS
    # d = settings.EMBEDDING_MODEL_DIMENSIONS  # for the HuggingFace embedder
    faiss_index = faiss.IndexFlatL2(d)
    # Assign FAISS as the vector store for the index.
    vector_store = FaissVectorStore(faiss_index=faiss_index)
    storage_context = StorageContext.from_defaults(vector_store=vector_store)
    # Sentence-window parser: each node is a single sentence, with 50
    # sentences of context on each side stored under the "window" key.
    node_parser = SentenceWindowNodeParser.from_defaults(
        window_size=50,
        window_metadata_key="window",
        original_text_metadata_key="original_text",
    )
    sentence_nodes = node_parser.get_nodes_from_documents(documents)
    sentence_index = VectorStoreIndex(
        sentence_nodes,
        storage_context=storage_context,
        show_progress=True,
    )
    sentence_index.storage_context.persist()

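# Typical one-time setup for the sentence-window strategy (a sketch; it
# mirrors the commented-out calls in the __main__ block below):
# documents = document_prepare(settings.RAW_DATA_DIR)
# windows_parser(documents)  # persists to ./storage, where window_query() looks
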
def window_query(query: str) -> str:
    """Query the persisted sentence-window index, swapping each hit's
    sentence for its full window before synthesis, then reranking."""
    vector_store = FaissVectorStore.from_persist_dir("./storage")
    storage_context = StorageContext.from_defaults(
        vector_store=vector_store, persist_dir="./storage"
    )
    sentence_index = load_index_from_storage(storage_context=storage_context)
    query_engine = sentence_index.as_query_engine(
        similarity_top_k=3,
        # The target key defaults to `window` to match the node parser's default.
        node_postprocessors=[
            MetadataReplacementPostProcessor(target_metadata_key="window"),
            CohereRerank(api_key=settings.COHERE_API_KEY, top_n=2),
        ],
        verbose=True,
    )
    query_engine.update_prompts(
        {
            "response_synthesizer:text_qa_template": qa_prompt_tmpl,
            "response_synthesizer:refine_template": refine_prompt_tmpl,
        }
    )
    response = query_engine.query(query)
    if response.source_nodes:  # guard against an empty result set
        window = response.source_nodes[0].node.metadata["window"][:500]
        sentence = response.source_nodes[0].node.metadata["original_text"][:500]
        print(f"Window: {window}")
        print("------------------")
        print(f"Original Sentence: {sentence}")
    return str(response)

def document_prepare(path: str) -> list:
    """Load all documents under `path`, using DocxReader for .docx files."""
    documents = SimpleDirectoryReader(
        path, file_extractor={".docx": DocxReader()}
    ).load_data()
    print(len(documents))
    # Optionally extract metadata and hide file-level keys from the LLM/embedder:
    # extract_metadata(documents)
    # documents[0].excluded_llm_metadata_keys = ["law_number", "file_name", "file_type", "file_size", "creation_date", "last_modified_date"]
    # documents[0].excluded_embed_metadata_keys = ["law_number", "law_name", "file_name", "file_type", "file_size", "creation_date", "last_modified_date"]
    # print("LLM: ", documents[0].get_content(metadata_mode=MetadataMode.LLM)[:500])
    # print("Embed: ", documents[0].get_content(metadata_mode=MetadataMode.EMBED)[:500])
    return documents

def extract_metadata(docs: list) -> None:
    """Extract the law number and law name from each document's text and
    store them as metadata."""
    for doc in docs:
        text = doc.text
        # "số: <number>" (case-insensitive), e.g. "Số: 59/2020/QH14"
        pattern_laws_number = r"(?i)số[:\s]+([^\s.,]+)"
        # Title between the document-type keyword (NGHỊ ĐỊNH = Decree,
        # LUẬT = Law) and the first "Căn cứ" (legal-basis) clause.
        pattern_laws_name = r"(NGHỊ ĐỊNH|LUẬT)\s+(.*?)\s+Căn cứ"
        match_laws_number = re.search(pattern_laws_number, text)
        match_laws_name = re.search(pattern_laws_name, text)
        if match_laws_number:
            doc.metadata["law_number"] = match_laws_number.group(1)  # e.g. "59/2020/QH14"
        if match_laws_name:
            # e.g. "Luật doanh nghiệp"
            doc.metadata["law_name"] = f"{match_laws_name.group(1)} {match_laws_name.group(2)}"

def faiss_setup(documents: list) -> None:
    """Embed whole documents into a flat-L2 FAISS index and persist it."""
    d = settings.OPENAI_EMBEDDING_MODEL_DIMS
    faiss_index = faiss.IndexFlatL2(d)
    # Assign FAISS as the vector store for the index.
    vector_store = FaissVectorStore(faiss_index=faiss_index)
    storage_context = StorageContext.from_defaults(vector_store=vector_store)
    index = VectorStoreIndex.from_documents(documents, storage_context=storage_context)
    # Persist to ./storage so faiss_load() below can find the index.
    index.storage_context.persist()

def faiss_load(query: str) -> str:
    """Load the persisted FAISS index and answer `query` against it."""
    vector_store = FaissVectorStore.from_persist_dir("./storage")
    storage_context = StorageContext.from_defaults(
        vector_store=vector_store, persist_dir="./storage"
    )
    index = load_index_from_storage(storage_context=storage_context)
    query_engine = index.as_query_engine()
    vector_retriever = index.as_retriever(similarity_top_k=2)
    response = query_engine.query(query)
    retrieved_nodes = vector_retriever.retrieve(query)
    if retrieved_nodes:
        print(retrieved_nodes[0])
    return str(response)

def get_all_nodes(documents: list) -> list:
    """Build the node set for small-to-big retrieval: large base chunks plus
    smaller sub-chunks that each point back to their parent node."""
    node_parser = SimpleNodeParser.from_defaults(
        chunk_size=settings.MAX_NEW_TOKENS, chunk_overlap=settings.MAX_OVERLAPS
    )
    base_nodes = node_parser.get_nodes_from_documents(documents)
    # Give the base nodes stable, deterministic ids.
    for idx, node in enumerate(base_nodes):
        node.id_ = f"node-{idx}"
    # Split each base chunk (originally 1024) into sub-chunks at 1/8, 1/4 and
    # 1/2 of the base size (128, 256, 512). Integer division: chunk_size must
    # be an int.
    sub_chunk_sizes = [settings.MAX_NEW_TOKENS // 8, settings.MAX_NEW_TOKENS // 4, settings.MAX_NEW_TOKENS // 2]
    sub_overlap_sizes = [settings.MAX_OVERLAPS // 8, settings.MAX_OVERLAPS // 4, settings.MAX_OVERLAPS // 2]
    sub_node_parsers = [
        SimpleNodeParser.from_defaults(chunk_size=c, chunk_overlap=o)
        for c, o in zip(sub_chunk_sizes, sub_overlap_sizes)
    ]
    all_nodes = []
    for base_node in base_nodes:
        for parser in sub_node_parsers:
            sub_nodes = parser.get_nodes_from_documents([base_node])
            sub_inodes = [
                IndexNode.from_text_node(sn, base_node.node_id) for sn in sub_nodes
            ]
            all_nodes.extend(sub_inodes)
        # Also index the original base node under its own id.
        original_node = IndexNode.from_text_node(base_node, base_node.node_id)
        all_nodes.append(original_node)
    return all_nodes

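# Minimal sketch for persisting all_nodes between the setup and query steps,
# using the pickle import above (the "./all_nodes.pkl" path is an assumption,
# not a convention fixed anywhere in this project):
def save_all_nodes(all_nodes: list, path: str = "./all_nodes.pkl") -> None:
    with open(path, "wb") as f:
        pickle.dump(all_nodes, f)


def load_all_nodes(path: str = "./all_nodes.pkl") -> list:
    with open(path, "rb") as f:
        return pickle.load(f)
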
def sub_chunk_setup(all_nodes: list) -> None:
    """Embed all base and sub-chunk nodes into a FAISS index and persist it."""
    # The index dimension must match the active embed model (OpenAI here).
    d = settings.OPENAI_EMBEDDING_MODEL_DIMS
    # d = settings.EMBEDDING_MODEL_DIMENSIONS  # for the HuggingFace embedder
    faiss_index = faiss.IndexFlatL2(d)
    # Assign FAISS as the vector store for the index.
    vector_store = FaissVectorStore(faiss_index=faiss_index)
    storage_context = StorageContext.from_defaults(vector_store=vector_store)
    index = VectorStoreIndex(
        all_nodes,
        storage_context=storage_context,
        show_progress=True,
    )
    print("done setup")
    index.storage_context.persist()

def sub_chunk_query(all_nodes: list, query: str) -> str:
    """Answer `query` via recursive retrieval: match against the small
    sub-chunks, then follow their references back to the parent chunks."""
    all_nodes_dict = {n.node_id: n for n in all_nodes}
    vector_store = FaissVectorStore.from_persist_dir("./storage")
    storage_context = StorageContext.from_defaults(
        vector_store=vector_store, persist_dir="./storage"
    )
    index = load_index_from_storage(storage_context=storage_context)
    vector_retriever_chunk = index.as_retriever(similarity_top_k=3)
    retriever_chunk = RecursiveRetriever(
        "vector",
        retriever_dict={"vector": vector_retriever_chunk},
        node_dict=all_nodes_dict,
        verbose=True,
    )
    nodes = retriever_chunk.retrieve(QueryBundle(query))
    for node in nodes:
        display_source_node(node, source_length=2000)
    # from_args builds the synthesizer from the global Settings; it does not
    # take a storage_context argument.
    query_engine = RetrieverQueryEngine.from_args(retriever_chunk)
    response = str(query_engine.query(query))
    return response

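# End-to-end sketch of the small-to-big flow, using the hypothetical pickle
# helpers above (save_all_nodes / load_all_nodes are illustrative additions,
# not part of the original pipeline):
# all_nodes = get_all_nodes(documents)
# sub_chunk_setup(all_nodes)
# save_all_nodes(all_nodes)
# ...in a later session...
# print(sub_chunk_query(load_all_nodes(), "Công ty cổ phần là gì?"))
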
if __name__ == "__main__": | |
documents = document_prepare(settings.RAW_DATA_DIR) | |
# all_nodes = get_all_nodes(documents) | |
# faiss_setup(documents) | |
# sub_chunk_setup(all_nodes) | |
# windows_parser(documents) | |
# examples=[ | |
# 'Chào bán cổ phần cho cổ đông hiện hữu của công ty cổ phần không phải là công ty đại chúng được thực hiện ra sao ?', | |
# 'Quyền của doanh nghiệp là những quyền nào?', | |
# 'Các trường hợp nào được coi là tên gây nhầm lẫn ?', | |
# 'Các quy định về chào bán trái phiếu riêng lẻ', | |
# 'Doanh nghiệp có quyền và nghĩa vụ như thế nào?', | |
# 'Thành lập công ty TNHH thì quy trình như thế nào?' | |
# ] | |
examples = [ | |
"Công ty cổ phần là gì?", | |
"Định nghĩa về “góp vốn” trong Luật Doanh nghiệp là gì?", | |
"Khái niệm “cổ đông” được hiểu như thế nào?", | |
"Thế nào là “vốn điều lệ” trong doanh nghiệp?", | |
"“Doanh nghiệp có vốn đầu tư nước ngoài” là gì?" | |
] | |
for example in examples: | |
# query = examples[3] | |
query = example | |
print("///////////////////////////////") | |
print(query) | |
# print(faiss_load(query)) | |
# print(sub_chunk_query(all_nodes, query)) | |
print("Answer:", window_query(query)) | |
print("\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\") | |