"""Gradio QnA chatbot.

Upload a text file, split and embed it into a persisted Chroma collection,
then chat over it with a HuggingFace-hosted LLM through a LangChain
ConversationalRetrievalChain.

Required environment variables:
    HUGGINGFACEHUB_EMBEDDINGS_MODEL_NAME  embeddings model name
    HUGGINGFACEHUB_LLM_REPO_ID            HuggingFace Hub LLM repo id
    CHROMADB_PERSIST_DIRECTORY            root dir for persisted collections
"""

import os

import gradio as gr
from langchain.chains import ConversationalRetrievalChain
from langchain.document_loaders import TextLoader
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.llms import HuggingFaceHub
from langchain.memory import ConversationBufferMemory
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma

# Lazily-initialized singletons shared across Gradio callbacks.
embeddings = None
qa_chain = None


def load_embeddings():
    """Return the shared embeddings model, creating it on first use."""
    global embeddings
    if embeddings is None:  # model load is expensive; do it once
        print("loading embeddings...")
        model_name = os.environ['HUGGINGFACEHUB_EMBEDDINGS_MODEL_NAME']
        embeddings = HuggingFaceInstructEmbeddings(model_name=model_name)
    return embeddings


def split_file(file, chunk_size, chunk_overlap):
    """Load the uploaded file and split it into overlapping text chunks.

    Args:
        file: Gradio file object (has a ``.name`` path on disk).
        chunk_size: target characters per chunk.
        chunk_overlap: characters shared between consecutive chunks.

    Returns:
        A list of LangChain ``Document`` chunks.
    """
    print('spliting file...', file.name, chunk_size, chunk_overlap)
    loader = TextLoader(file.name)
    documents = loader.load()
    text_splitter = CharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    return text_splitter.split_documents(documents)


def get_persist_directory(file_name):
    """Return the on-disk Chroma directory for the given collection name."""
    return os.path.join(os.environ['CHROMADB_PERSIST_DIRECTORY'], file_name)


def process_file(file, chunk_size, chunk_overlap):
    """Split, embed, and persist the uploaded file as its own collection.

    Returns:
        A status string for the Label, and a Dropdown update that refreshes
        the choices and selects the newly created collection.
    """
    docs = split_file(file, chunk_size, chunk_overlap)
    embeddings = load_embeddings()
    # The collection is named after the file's basename without extension.
    file_name, _ = os.path.splitext(os.path.basename(file.name))
    persist_directory = get_persist_directory(file_name)
    print("initializing vector store...", persist_directory)
    vectordb = Chroma.from_documents(
        documents=docs,
        embedding=embeddings,
        collection_name=file_name,
        persist_directory=persist_directory)
    print("persisting...", vectordb._client.list_collections())
    vectordb.persist()
    return 'Done!', gr.Dropdown.update(
        choices=get_vector_dbs(), value=file_name)


def is_dir(root, name):
    """Return True if *name* inside *root* is a directory."""
    path = os.path.join(root, name)
    return os.path.isdir(path)


def get_vector_dbs():
    """List persisted collection names (sub-directories of the DB root)."""
    root = os.environ['CHROMADB_PERSIST_DIRECTORY']
    if not os.path.exists(root):
        return []
    print('get vector dbs...', root)
    # Only directories are collections; ignore stray files.
    dirs = [name for name in os.listdir(root) if is_dir(root, name)]
    print(dirs)
    return dirs


def load_vectordb(file_name):
    """Open the persisted Chroma collection for *file_name*."""
    embeddings = load_embeddings()
    persist_directory = get_persist_directory(file_name)
    print(persist_directory)
    vectordb = Chroma(
        collection_name=file_name,
        embedding_function=embeddings,
        persist_directory=persist_directory)
    print(vectordb._client.list_collections())
    return vectordb


def create_qa_chain(collection_name, temperature, max_length):
    """(Re)build the global retrieval QA chain for the chosen collection.

    No-op when *collection_name* is empty (nothing selected yet).
    """
    print('creating qa chain...', collection_name, temperature, max_length)
    if not collection_name:
        return
    global qa_chain
    memory = ConversationBufferMemory(
        memory_key="chat_history", return_messages=True)
    llm = HuggingFaceHub(
        repo_id=os.environ["HUGGINGFACEHUB_LLM_REPO_ID"],
        model_kwargs={"temperature": temperature, "max_length": max_length}
    )
    vectordb = load_vectordb(collection_name)
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm=llm, retriever=vectordb.as_retriever(), memory=memory)


def refresh_collection():
    """Re-scan the persist directory and refresh the Dropdown choices."""
    choices = get_vector_dbs()
    return gr.Dropdown.update(
        choices=choices, value=choices[0] if choices else None)


def submit_message(bot_history, text):
    """Append the user's message (answer pending) and clear the textbox."""
    # Use a mutable [user, bot] pair: bot() fills in index 1 in place.
    # (The original appended an immutable tuple, which bot() cannot mutate.)
    bot_history = bot_history + [[text, None]]
    return bot_history, ""


def bot(bot_history):
    """Answer the latest user message using the global QA chain."""
    global qa_chain
    if qa_chain is None:
        # No collection processed/selected yet — explain instead of crashing.
        bot_history[-1][1] = "Please select a document first."
        return bot_history
    print(qa_chain, bot_history[-1][0])
    result = qa_chain.run(bot_history[-1][0])
    print(result)
    bot_history[-1][1] = result
    return bot_history


def clear_bot():
    """Reset the chatbot display."""
    return None


title = "QnA Chatbot"

with gr.Blocks() as demo:
    gr.Markdown(f"# {title}")

    with gr.Tab("File"):
        upload = gr.File(file_types=["text"], label="Upload File")
        chunk_size = gr.Slider(
            500, 5000, value=1000, step=100, label="Chunk Size")
        chunk_overlap = gr.Slider(0, 30, value=20, label="Chunk Overlap")
        process = gr.Button("Process")
        result = gr.Label()

    with gr.Tab("Bot"):
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    with gr.Column(scale=3):
                        choices = get_vector_dbs()
                        collection = gr.Dropdown(
                            choices,
                            value=choices[0] if choices else None,
                            label="Document",
                            allow_custom_value=True)
                    with gr.Column():
                        refresh = gr.Button("Refresh")
                temperature = gr.Slider(
                    0.0, 1.0, value=0.5, step=0.05, label="Temperature")
                max_length = gr.Slider(
                    20, 1000, value=100, step=10, label="Max Length")
            with gr.Column():
                chatbot = gr.Chatbot([], elem_id="chatbot").style(height=550)
                message = gr.Textbox(
                    show_label=False, placeholder="Ask me anything!")
                clear = gr.Button("Clear")

    process.click(
        process_file, [upload, chunk_size, chunk_overlap], [result, collection]
    )

    # Build the initial chain from the default widget values, then rebuild
    # whenever the collection or a generation setting changes.
    create_qa_chain(collection.value, temperature.value, max_length.value)
    collection.change(create_qa_chain, [collection, temperature, max_length])
    temperature.change(create_qa_chain, [collection, temperature, max_length])
    max_length.change(create_qa_chain, [collection, temperature, max_length])

    refresh.click(refresh_collection, None, collection)
    message.submit(submit_message, [chatbot, message], [chatbot, message]).then(
        bot, chatbot, chatbot
    )
    clear.click(clear_bot, None, chatbot)

demo.title = title
demo.launch()