import os
import tempfile
import itertools
from typing import List, Optional, Tuple

import gradio as gr
from llama_cpp import Llama
from chromadb.config import Settings
from langchain.vectorstores import Chroma
from langchain.docstore.document import Document
from huggingface_hub.file_download import http_get
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Project-level constants (MODELS_DIR, DICT_REPO_AND_MODELS, EMBEDDER_NAME, LOADER_MAPPING,
# ROLE_TOKENS, LINEBREAK_TOKEN, BOT_TOKEN, SYSTEM_PROMPT, MAX_NEW_TOKENS, FAVICON_PATH, BLOCK_CSS)
# are defined in the package's __init__.py.
from __init__ import *
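
# A minimal sketch of what __init__.py is assumed to provide; the values below are illustrative
# placeholders, not the project's actual configuration:
#
#   MODELS_DIR = "models"
#   EMBEDDER_NAME = "sentence-transformers/all-MiniLM-L6-v2"   # any HuggingFace embedding model
#   DICT_REPO_AND_MODELS = {"<direct GGUF/GGML download URL>": "model-q4_K.bin"}
#   LOADER_MAPPING = {".txt": (TextLoader, {"encoding": "utf-8"}), ...}
#   SYSTEM_PROMPT = "You are a helpful assistant ..."
#   ROLE_TOKENS = {"user": ..., "bot": ..., "system": ...}     # special token ids of the chat model
#   LINEBREAK_TOKEN, BOT_TOKEN = ..., ...
#   MAX_NEW_TOKENS = 1500
#   FAVICON_PATH, BLOCK_CSS = "favicon.png", ""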

class LocalChatGPT:
    def __init__(self):
        self.llama_model: Optional[Llama] = None
        self.embeddings: HuggingFaceEmbeddings = self.initialize_app()

    def initialize_app(self) -> HuggingFaceEmbeddings:
        """
        Download the default GGUF model if it is not cached yet, load it into llama.cpp
        and return the embedding model used for the vector store.
        :return: the HuggingFaceEmbeddings instance shared by the whole app
        """
        os.makedirs(MODELS_DIR, exist_ok=True)
        model_url, model_name = list(DICT_REPO_AND_MODELS.items())[0]
        final_model_path = os.path.join(MODELS_DIR, model_name)
        os.makedirs(os.path.dirname(final_model_path), exist_ok=True)

        if not os.path.exists(final_model_path):
            with open(final_model_path, "wb") as f:
                http_get(model_url, f)

        self.llama_model = Llama(
            model_path=final_model_path,
            n_ctx=2000,
            n_parts=1,
        )
        return HuggingFaceEmbeddings(model_name=EMBEDDER_NAME, cache_folder=MODELS_DIR)

    def load_model(self, model_name):
        """
        Download the selected model if it is not cached yet and reload llama.cpp with it.
        :param model_name: file name of the model chosen in the dropdown
        :return: the loaded model name (echoed back to the dropdown)
        """
        final_model_path = os.path.join(MODELS_DIR, model_name)
        os.makedirs(os.path.dirname(final_model_path), exist_ok=True)

        if not os.path.exists(final_model_path):
            with open(final_model_path, "wb") as f:
                if model_url := [url for url, name in DICT_REPO_AND_MODELS.items() if name == model_name]:
                    http_get(model_url[0], f)

        self.llama_model = Llama(
            model_path=final_model_path,
            n_ctx=2000,
            n_parts=1,
        )
        return model_name

    @staticmethod
    def load_single_document(file_path: str) -> Document:
        """
        Load one document with the loader registered for its file extension.
        :param file_path: path to the uploaded file
        :return: the parsed Document
        """
        ext: str = "." + file_path.rsplit(".", 1)[-1]
        assert ext in LOADER_MAPPING, f"Unsupported file extension: {ext}"
        loader_class, loader_args = LOADER_MAPPING[ext]
        loader = loader_class(file_path, **loader_args)
        return loader.load()[0]
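
    # The two helpers below build the chat prompt at the token level, in the style of Saiga-like
    # llama.cpp chat models: each message is encoded as
    #     [BOS, <role token>, <linebreak token>, ...content tokens..., EOS]
    # using the ROLE_TOKENS, LINEBREAK_TOKEN and SYSTEM_PROMPT constants from __init__.py.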
    @staticmethod
    def get_message_tokens(model: Llama, role: str, content: str) -> list:
        """
        Tokenize a single chat message and wrap it with the role and separator tokens.
        :param model: the loaded llama.cpp model (used as tokenizer)
        :param role: message role, a key of ROLE_TOKENS ("user", "system", ...)
        :param content: message text
        :return: token ids of the formatted message
        """
        message_tokens: list = model.tokenize(content.encode("utf-8"))
        message_tokens.insert(1, ROLE_TOKENS[role])
        message_tokens.insert(2, LINEBREAK_TOKEN)
        message_tokens.append(model.token_eos())
        return message_tokens

    def get_system_tokens(self, model: Llama) -> list:
        """
        Build the token sequence of the system prompt.
        :param model: the loaded llama.cpp model
        :return: token ids of the SYSTEM_PROMPT message
        """
        system_message: dict = {"role": "system", "content": SYSTEM_PROMPT}
        return self.get_message_tokens(model, **system_message)

    @staticmethod
    def upload_files(files: List[tempfile.TemporaryFile]) -> List[str]:
        """
        :param files: files received from the Gradio upload component
        :return: their paths on disk
        """
        return [f.name for f in files]

    @staticmethod
    def process_text(text: str) -> Optional[str]:
        """
        Drop lines shorter than three characters and discard texts that end up shorter
        than ten characters.
        :param text: raw chunk text
        :return: the cleaned text, or None if too little is left
        """
        lines: list = text.split("\n")
        lines = [line for line in lines if len(line.strip()) > 2]
        text = "\n".join(lines).strip()
        return None if len(text) < 10 else text
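
    # Re-indexing helper: when the same set of file names is uploaded again, the previously
    # stored fragments are deleted and replaced under fresh ids instead of creating a new
    # Chroma collection.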
    @staticmethod
    def update_text_db(
        db: Optional[Chroma],
        fixed_documents: List[Document],
        ids: List[str]
    ) -> Optional[Tuple[Chroma, str]]:
        """
        Replace the stored fragments when the same set of files is uploaded again.
        :return: the updated db and a status message, or None if no in-place update was done
        """
        if db:
            data: dict = db.get()
            files_db = {dict_data["source"].split("/")[-1] for dict_data in data["metadatas"]}
            files_load = {doc.metadata["source"].split("/")[-1] for doc in fixed_documents}
            if files_load == files_db:
                # db.delete([item for item in data["ids"] if item not in ids])
                # db.update_documents(ids, fixed_documents)
                db.delete(data["ids"])
                db.add_texts(
                    texts=[doc.page_content for doc in fixed_documents],
                    metadatas=[doc.metadata for doc in fixed_documents],
                    ids=ids
                )
                file_warning = f"Uploaded {len(fixed_documents)} fragments! You can ask questions."
                return db, file_warning
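
    # Indexing pipeline: load each uploaded file, split it into overlapping character chunks,
    # drop near-empty chunks, then either refresh the existing index (same files re-uploaded)
    # or build a new Chroma collection from the cleaned fragments.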
    def build_index(
        self,
        file_paths: List[str],
        db: Optional[Chroma],
        chunk_size: int,
        chunk_overlap: int
    ):
        """
        Split the uploaded files into chunks and (re)build the vector index.
        :param file_paths: paths of the uploaded files
        :param db: the current Chroma index, if any
        :param chunk_size: chunk size in characters
        :param chunk_overlap: overlap between neighbouring chunks in characters
        :return: the (re)built db and a status message
        """
        documents: List[Document] = [self.load_single_document(path) for path in file_paths]
        text_splitter: RecursiveCharacterTextSplitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )
        documents = text_splitter.split_documents(documents)

        fixed_documents: List[Document] = []
        for doc in documents:
            doc.page_content = self.process_text(doc.page_content)
            if not doc.page_content:
                continue
            fixed_documents.append(doc)

        # One id per fragment, derived from its source file name and a running index.
        ids: List[str] = [
            f"{doc.metadata['source'].split('/')[-1].replace('.txt', '')}{i}"
            for i, doc in enumerate(fixed_documents, start=1)
        ]

        # If the same files are re-uploaded, refresh the existing collection instead of
        # creating a new one.
        updated = self.update_text_db(db, fixed_documents, ids)
        if updated:
            return updated

        db = Chroma.from_documents(
            documents=fixed_documents,
            embedding=self.embeddings,
            ids=ids,
            client_settings=Settings(
                anonymized_telemetry=False,
                persist_directory="db"
            )
        )
        file_warning = f"Uploaded {len(fixed_documents)} fragments! You can ask questions."
        return db, file_warning
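
    # Gradio callbacks. Every turn runs the same chain: user() appends the question to the chat
    # history, retrieve() looks up the most similar fragments in the index, and bot() streams
    # the model's answer token by token.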
    @staticmethod
    def user(message, history):
        """Append the new user message to the chat history and clear the input box."""
        new_history = history + [[message, None]]
        return "", new_history

    @staticmethod
    def regenerate_response(history):
        """
        Trigger regeneration of the last answer; retrieval and generation are re-run
        by the chained callbacks.
        :param history: current chat history
        :return: an empty message box and the unchanged history
        """
        return "", history

    @staticmethod
    def retrieve(history, db: Optional[Chroma], retrieved_docs):
        """
        Find the fragments most similar to the last user message.
        :param history: current chat history
        :param db: the Chroma index, if any
        :param retrieved_docs: current contents of the "Extracted fragments" box
        :return: the fragments to show and to prepend to the prompt
        """
        if db:
            last_user_message = history[-1][0]
            try:
                docs = db.similarity_search(last_user_message, k=4)
                # retriever = db.as_retriever(search_kwargs={"k": k_documents})
                # docs = retriever.get_relevant_documents(last_user_message)
            except RuntimeError:
                docs = db.similarity_search(last_user_message, k=1)
                # retriever = db.as_retriever(search_kwargs={"k": 1})
                # docs = retriever.get_relevant_documents(last_user_message)
            source_docs = set()
            for doc in docs:
                for content in doc.metadata.values():
                    source_docs.add(content.split("/")[-1])
            retrieved_docs = "\n\n".join(doc.page_content for doc in docs)
            retrieved_docs = f"Source documents: {', '.join(source_docs)}.\n\n{retrieved_docs}"
        return retrieved_docs
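
    # Generation: the prompt is assembled from the system prompt, the previous user turns and the
    # last question prefixed with the retrieved context, then completed under the bot role tokens
    # and streamed back into the chat window.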
    def bot(self, history, retrieved_docs):
        """
        Generate the assistant's reply for the last message in the history.
        :param history: current chat history
        :param retrieved_docs: context fragments returned by retrieve()
        :return: yields the history with the answer filled in incrementally
        """
        if not history:
            return

        tokens = self.get_system_tokens(self.llama_model)[:]
        tokens.append(LINEBREAK_TOKEN)

        # Encode the previous user turns.
        for user_message, bot_message in history[:-1]:
            message_tokens = self.get_message_tokens(model=self.llama_model, role="user", content=user_message)
            tokens.extend(message_tokens)

        # Prepend the retrieved context to the last question.
        last_user_message = history[-1][0]
        if retrieved_docs:
            last_user_message = f"Context: {retrieved_docs}\n\nUsing the context, answer the question: " \
                                f"{last_user_message}"
        message_tokens = self.get_message_tokens(model=self.llama_model, role="user", content=last_user_message)
        tokens.extend(message_tokens)

        role_tokens = [self.llama_model.token_bos(), BOT_TOKEN, LINEBREAK_TOKEN]
        tokens.extend(role_tokens)

        generator = self.llama_model.generate(
            tokens,
            top_k=30,
            top_p=0.9,
            temp=0.1
        )

        # Stream the answer token by token until EOS or the MAX_NEW_TOKENS limit.
        partial_text = ""
        for i, token in enumerate(generator):
            if token == self.llama_model.token_eos() or (MAX_NEW_TOKENS is not None and i >= MAX_NEW_TOKENS):
                break
            partial_text += self.llama_model.detokenize([token]).decode("utf-8", "ignore")
            history[-1][1] = partial_text
            yield history
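
    # Web UI: model selector, chat window with a panel for the retrieved fragments, message box,
    # and Send / Stop / Repeat / Clear controls wired to the callback chain above.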
    def run(self):
        """
        Build the Gradio interface and launch the web server.
        """
        with gr.Blocks(theme=gr.themes.Soft(), css=BLOCK_CSS) as demo:
            db: Optional[Chroma] = gr.State(None)
            favicon = f'<img src="{FAVICON_PATH}" width="48px" style="display: inline">'
            gr.Markdown(
                f"""<h1><center>{favicon} GPT-based text assistant</center></h1>"""
            )

            with gr.Row(elem_id="model_selector_row"):
                models: list = list(DICT_REPO_AND_MODELS.values())
                model_selector = gr.Dropdown(
                    choices=models,
                    value=models[0] if models else "",
                    interactive=True,
                    show_label=False,
                    container=False,
                )

            with gr.Row():
                with gr.Column(scale=5):
                    chatbot = gr.Chatbot(label="Dialogue", height=400)
                with gr.Column(min_width=200, scale=4):
                    retrieved_docs = gr.Textbox(
                        label="Extracted fragments",
                        placeholder="Will appear after you ask a question",
                        interactive=False
                    )

            with gr.Row():
                with gr.Column(scale=20):
                    msg = gr.Textbox(
                        label="Send a message",
                        show_label=False,
                        placeholder="Send a message",
                        container=False
                    )
                with gr.Column(scale=3, min_width=100):
                    submit = gr.Button("📤 Send", variant="primary")

            with gr.Row():
                # gr.Button(value="👍 Liked")
                # gr.Button(value="👎 Disliked")
                stop = gr.Button(value="⛔ Stop")
                regenerate = gr.Button(value="🔄 Repeat")
                clear = gr.Button(value="🗑️ Clear")

            # # Upload files
            # file_output.upload(
            #     fn=self.upload_files,
            #     inputs=[file_output],
            #     outputs=[file_paths],
            #     queue=True,
            # ).success(
            #     fn=self.build_index,
            #     inputs=[file_paths, db, chunk_size, chunk_overlap],
            #     outputs=[db, file_warning],
            #     queue=True
            # )

            # Switch the model when the dropdown changes
            model_selector.change(
                fn=self.load_model,
                inputs=[model_selector],
                outputs=[model_selector]
            )

            # Pressing Enter
            submit_event = msg.submit(
                fn=self.user,
                inputs=[msg, chatbot],
                outputs=[msg, chatbot],
                queue=False,
            ).success(
                fn=self.retrieve,
                inputs=[chatbot, db, retrieved_docs],
                outputs=[retrieved_docs],
                queue=True,
            ).success(
                fn=self.bot,
                inputs=[chatbot, retrieved_docs],
                outputs=chatbot,
                queue=True,
            )

            # Pressing the Send button
            submit_click_event = submit.click(
                fn=self.user,
                inputs=[msg, chatbot],
                outputs=[msg, chatbot],
                queue=False,
            ).success(
                fn=self.retrieve,
                inputs=[chatbot, db, retrieved_docs],
                outputs=[retrieved_docs],
                queue=True,
            ).success(
                fn=self.bot,
                inputs=[chatbot, retrieved_docs],
                outputs=chatbot,
                queue=True,
            )

            # Stop generation
            stop.click(
                fn=None,
                inputs=None,
                outputs=None,
                cancels=[submit_event, submit_click_event],
                queue=False,
            )

            # Regenerate the last answer
            regenerate.click(
                fn=self.regenerate_response,
                inputs=[chatbot],
                outputs=[msg, chatbot],
                queue=False,
            ).success(
                fn=self.retrieve,
                inputs=[chatbot, db, retrieved_docs],
                outputs=[retrieved_docs],
                queue=True,
            ).success(
                fn=self.bot,
                inputs=[chatbot, retrieved_docs],
                outputs=chatbot,
                queue=True,
            )

            # Clear history
            clear.click(lambda: None, None, chatbot, queue=False)

        demo.queue(max_size=128, default_concurrency_limit=10, api_open=False)
        demo.launch(server_name="0.0.0.0", max_threads=200)


if __name__ == "__main__":
    local_chat_gpt = LocalChatGPT()
    local_chat_gpt.run()