Spaces:
Sleeping
Sleeping
from llama_index.core import StorageContext, load_index_from_storage | |
from llama_index.embeddings.huggingface import HuggingFaceEmbedding | |
from llama_index.core import Settings | |
from llama_index.llms.groq import Groq | |
from llama_index.llms.openai import OpenAI | |
from core.config import settings | |
from llama_index.vector_stores.faiss import FaissVectorStore | |
from prompt.prompt import qa_prompt_tmpl, refine_prompt_tmpl | |
from IPython.display import Markdown, display | |
import re | |
from llama_index.core.retrievers import RecursiveRetriever | |
import string | |
from llama_index.postprocessor.cohere_rerank import CohereRerank | |
from llama_index.core.query_engine import RetrieverQueryEngine | |
import pickle | |
from loader import get_all_nodes, document_prepare | |
from llama_index.embeddings.openai import OpenAIEmbedding | |
from llama_index.core.query_engine import MultiStepQueryEngine | |
from llama_index.core.indices.query.query_transform.base import ( | |
StepDecomposeQueryTransform, | |
) | |
from llama_index.core.indices.postprocessor import MetadataReplacementPostProcessor | |
#Settings | |
# Settings.embed_model = HuggingFaceEmbedding( | |
# model_name= settings.EMBEDDING_MODEL | |
# ) | |
Settings.embed_model = OpenAIEmbedding( | |
model_name= settings.OPENAI_EMBEDDING_MODEL | |
) | |
Settings.llm = OpenAI(model = settings.OPENAI_MODEL, | |
api_key = settings.OPENAI_API_KEY, temperature=0) | |
step_decompose_transform = StepDecomposeQueryTransform( | |
llm=Settings.llm, verbose=True | |
) | |
# Settings.llm = Groq(model=settings.MODEL_ID, api_key= settings.MODEL_API_KEY) | |
# print(Settings.llm.max_tokens) | |
# all_nodes = get_all_nodes(document_prepare(settings.RAW_DATA_DIR)) | |
# define prompt viewing function | |
def display_prompt_dict(prompts_dict): | |
for k, p in prompts_dict.items(): | |
text_md = f"**Prompt Key**: {k}<br>" f"**Text:** <br>" | |
display(Markdown(text_md)) | |
print(p.get_template()) | |
display(Markdown("<br><br>")) | |
def preprocessing_text(query: str) -> str: | |
text = query | |
abbreviations = { | |
'tnhh': 'Trách nhiệm hữu hạn', # Công ty Trách nhiệm Hữu hạn | |
'Tnhh': 'Trách nhiệm hữu hạn', # Công ty Trách nhiệm Hữu hạn | |
'TNHH': 'Trách nhiệm hữu hạn', # Công ty Trách nhiệm Hữu hạn | |
'cp': 'Cổ phần', # Công ty Cổ phần | |
'CP': 'Cổ phần', | |
'mtv': 'Một thành viên', # Công ty Trách nhiệm Hữu hạn Một Thành Viên | |
'MTV': 'Một thành viên', | |
'công ty hd': 'công ty Hợp danh', # Công ty Hợp danh | |
'công ty HD': 'công ty Hợp danh', | |
'dn': 'doanh nghiệp', # Doanh nghiệp | |
'DN': 'Doanh nghiệp', | |
'DNTN': 'Doanh nghiệp tư nhân', | |
'dntn': 'Doanh nghiệp tư nhân', | |
'Dntn': 'Doanh nghiệp tư nhân', | |
'vốn đl': 'Vốn điều lệ', # Vốn Điều lệ | |
'gpkd': 'Giấy phép kinh doanh', # Giấy Phép Kinh Doanh | |
'GPKD': 'Giấy phép kinh doanh', | |
'dkdn': 'Đăng ký doanh nghiệp', # Đăng Ký Doanh Nghiệp | |
'tldn': 'Thành lập doanh nghiệp', # Thành lập Doanh nghiệp | |
'hdqt': 'Hội đồng quản trị', # Hội Đồng Quản Trị | |
'vốn góp': 'Vốn góp', # Vốn Góp | |
'tct': 'Tổng công ty', # Tổng Công ty | |
'kv': 'Khu vực', # Khu Vực | |
'htx': 'Hợp tác xã', # Hợp Tác Xã | |
'lds': 'Liên doanh', # Liên Doanh | |
'sở hđt': 'Sở hữu đầu tư', # Sở Hữu Đầu Tư | |
'nlđ': 'Người lao động', # Người Lao Động | |
'đt': 'Đầu tư', # Đầu Tư | |
'kt': 'Kinh tế', # Kinh Tế | |
'kte': 'Kinh tế', | |
'hđ': 'hợp đồng', | |
'hdong': 'hợp đồng', | |
'gd': 'Giám đốc', | |
'đtdnnn': 'Đầu tư doanh nghiệp nước ngoài' # Đầu Tư Doanh Nghiệp Nước Ngoài | |
} | |
for k,v in abbreviations.items(): | |
text = text.replace(k,v) | |
text = re.sub(r'(.)\1{2,}', r'\1', text) #Removes trailing | |
text = re.sub(r"(\w)\s*([{}])\s*(\w)".format(re.escape(string.punctuation)), r"\1 \3", text) # Removes punctuation after word characters | |
text = re.sub(r"(\w)([" + string.punctuation + "])", r"\1", text) # Removes punctuation after word characters | |
text = re.sub(f"([{string.punctuation}])([{string.punctuation}])+", r"\1", text) # Remove repeated consecutive punctuation marks | |
text = text.strip() # Remove leading and trailing whitespaces | |
# While loops to remove leading and trailing punctuation and whitespace characters. | |
while text.endswith(tuple(string.punctuation + string.whitespace)): | |
text = text[:-1] | |
while text.startswith(tuple(string.punctuation + string.whitespace)): | |
text = text[1:] | |
text = re.sub(r"\s+", " ", text) # Replace multiple consecutive whitespaces with a single space | |
return text | |
def response_faiss(query:str, history: str) -> str: | |
message = preprocessing_text(query) | |
vector_store = FaissVectorStore.from_persist_dir("./storage") | |
storage_context = StorageContext.from_defaults( | |
vector_store=vector_store, persist_dir="./storage" | |
) | |
index = load_index_from_storage(storage_context=storage_context) | |
vector_retriever = index.as_retriever(similarity_top_k=2) | |
query_engine = index.as_query_engine() | |
query_engine.update_prompts( | |
{"response_synthesizer:text_qa_template": qa_prompt_tmpl, | |
"response_synthesizer:refine_template": refine_prompt_tmpl,} | |
) | |
# display_prompt_dict(query_engine.get_prompts()) | |
response = str(query_engine.query(f"{message}")) | |
retrieved_nodes = vector_retriever.retrieve(message) | |
print(retrieved_nodes[0].metadata) | |
print(response) | |
return response | |
def sub_chunk_query(query: str, history: str) -> str: | |
query = preprocessing_text(query) | |
all_nodes_dict = {n.node_id: n for n in all_nodes} | |
vector_store = FaissVectorStore.from_persist_dir("./storage") | |
storage_context = StorageContext.from_defaults( | |
vector_store=vector_store, persist_dir="./storage" | |
) | |
index = load_index_from_storage(storage_context=storage_context) | |
vector_retriever_chunk = index.as_retriever(similarity_top_k=2) | |
retriever_chunk = RecursiveRetriever( | |
"vector", | |
retriever_dict={"vector": vector_retriever_chunk}, | |
node_dict=all_nodes_dict, | |
verbose=True, | |
) | |
nodes = retriever_chunk.retrieve(query) | |
print(nodes[0].text[:500]) | |
query_engine = MultiStepQueryEngine( | |
retriever_chunk, | |
storage_context = storage_context, | |
similarity_top_k=5, | |
query_transform=step_decompose_transform, | |
node_postprocessors=[ | |
CohereRerank(api_key=settings.COHERE_API_KEY, top_n=3) | |
], | |
) | |
query_engine.update_prompts( | |
{"response_synthesizer:text_qa_template": qa_prompt_tmpl, | |
"response_synthesizer:refine_template": refine_prompt_tmpl,} | |
) | |
response = str(query_engine.query(f"{query}")) | |
print(query) | |
print(response) | |
return response | |
def window_query(query: str, history: str): | |
query = preprocessing_text(query) | |
vector_store = FaissVectorStore.from_persist_dir("./storage") | |
storage_context = StorageContext.from_defaults( | |
vector_store=vector_store, persist_dir="./storage" | |
) | |
sentence_index = load_index_from_storage(storage_context=storage_context) | |
query_engine = sentence_index.as_query_engine( | |
similarity_top_k=3, | |
# the target key defaults to `window` to match the node_parser's default | |
node_postprocessors=[ | |
MetadataReplacementPostProcessor(target_metadata_key="window"), | |
CohereRerank(api_key=settings.COHERE_API_KEY, top_n=2), | |
], | |
verbose=True, | |
) | |
query_engine.update_prompts( | |
{"response_synthesizer:text_qa_template": qa_prompt_tmpl, | |
"response_synthesizer:refine_template": refine_prompt_tmpl,} | |
) | |
response = query_engine.query(f"{query}") | |
window = response.source_nodes[0].node.metadata["window"][:500] | |
sentence = response.source_nodes[0].node.metadata["original_text"][:500] | |
print(f"Window: {window}") | |
print("------------------") | |
print(f"Original Sentence: {sentence}") | |
return str(response) | |
examples=[ | |
'Chào bán cổ phần cho cổ đông hiện hữu của công ty cổ phần không phải là công ty đại chúng được thực hiện ra sao ?', | |
'Quyền của doanh nghiệp là những quyền nào?', | |
'Các trường hợp nào được coi là tên gây nhầm lẫn ?', | |
'Các quy định về chào bán trái phiếu riêng lẻ', | |
'Doanh nghiệp có quyền và nghĩa vụ như thế nào?', | |
'Thành lập công ty TNHH thì quy trình như thế nào?' | |
] | |
# query = examples[1] | |
# print(query) | |
# print(sub_chunk_query(query, "")) |