# QnA Chatbot — Gradio Space: document Q&A over a Chroma vector store
# using HuggingFace Hub embeddings and LLM.
import os

import gradio as gr
from langchain.chains import ConversationalRetrievalChain
from langchain.document_loaders import TextLoader
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.llms import HuggingFaceHub
from langchain.memory import ConversationBufferMemory
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma
def load_embeddings():
    """Instantiate the instruct-embeddings model.

    The model name is read from the HUGGINGFACEHUB_EMBEDDINGS_MODEL_NAME
    environment variable (raises KeyError when unset).
    """
    print("Loading embeddings...")
    return HuggingFaceInstructEmbeddings(
        model_name=os.environ['HUGGINGFACEHUB_EMBEDDINGS_MODEL_NAME'])
def split_file(file, chunk_size, chunk_overlap):
    """Load an uploaded text file and split it into overlapping chunks.

    Args:
        file: uploaded file object exposing a ``.name`` filesystem path
            (Gradio File component value).
        chunk_size: maximum characters per chunk.
        chunk_overlap: characters shared between consecutive chunks.

    Returns:
        The list of langchain Documents produced by CharacterTextSplitter.
    """
    # Fixed typo in the log message ('spliting' -> 'splitting').
    print('splitting file', file.name)
    documents = TextLoader(file.name).load()
    text_splitter = CharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    return text_splitter.split_documents(documents)
def get_persist_directory(file_name):
    """Return the per-document Chroma directory under CHROMADB_PERSIST_DIRECTORY."""
    root = os.environ['CHROMADB_PERSIST_DIRECTORY']
    return os.path.join(root, file_name)
def process_file(file, chunk_size, chunk_overlap):
    """Split an uploaded file, embed its chunks, and persist them to Chroma.

    The collection is named after the uploaded file (extension stripped)
    and stored in its own persist directory.  Returns a status string for
    the UI label.
    """
    chunks = split_file(file, chunk_size, chunk_overlap)
    embeddings = load_embeddings()
    collection_name, _ = os.path.splitext(os.path.basename(file.name))
    persist_directory = get_persist_directory(collection_name)
    print("persist directory", persist_directory)
    vectordb = Chroma.from_documents(
        documents=chunks,
        embedding=embeddings,
        collection_name=collection_name,
        persist_directory=persist_directory,
    )
    print(vectordb._client.list_collections())
    vectordb.persist()
    return 'Done!'
def is_dir(root, name): | |
path = os.path.join(root, name) | |
return os.path.isdir(path) | |
def get_vector_dbs(): | |
root = os.environ['CHROMADB_PERSIST_DIRECTORY'] | |
if not os.path.exists(root): | |
return [] | |
files = os.listdir(root) | |
dirs = filter(lambda x: is_dir(root, x), files) | |
print(dirs) | |
return dirs | |
def load_vectordb(file_name):
    """Open the persisted Chroma collection for *file_name*.

    Re-creates the embedding function and points Chroma at the collection's
    persist directory.
    """
    embeddings = load_embeddings()
    persist_directory = get_persist_directory(file_name)
    print(persist_directory)
    vectordb = Chroma(
        collection_name=file_name,
        embedding_function=embeddings,
        persist_directory=persist_directory,
    )
    print(vectordb._client.list_collections())
    return vectordb
def create_qa_chain(collection_name, temperature, max_length):
    """Assemble a ConversationalRetrievalChain over the chosen collection.

    Builds a buffer memory keyed as 'chat_history', a HuggingFace Hub LLM
    (repo id from HUGGINGFACEHUB_LLM_REPO_ID), and a retriever over the
    persisted vector store for *collection_name*.
    """
    print('creating qa chain...')
    memory = ConversationBufferMemory(memory_key="chat_history",
                                      return_messages=True)
    llm = HuggingFaceHub(
        repo_id=os.environ["HUGGINGFACEHUB_LLM_REPO_ID"],
        model_kwargs={"temperature": temperature, "max_length": max_length},
    )
    retriever = load_vectordb(collection_name).as_retriever()
    return ConversationalRetrievalChain.from_llm(
        llm=llm, retriever=retriever, memory=memory)
def submit_message(bot_history, text):
    """Append the user's message (answer pending) and clear the textbox."""
    return bot_history + [(text, None)], ""
def bot(bot_history, collection_name, temperature, max_length):
    """Answer the latest user message through the retrieval QA chain.

    The last history entry arrives as ``(user_message, None)``; it is
    replaced with ``(user_message, answer)`` so the Chatbot component can
    render the reply.
    """
    qa = create_qa_chain(collection_name, temperature, max_length)
    question = bot_history[-1][0]
    print(qa, question)
    answer = qa.run(question)
    # Rebuild the last pair: history entries are tuples, so the previous
    # item assignment (``bot_history[-1][1] = ...``) raised TypeError.
    # Also surface the chain's actual answer instead of a hard-coded
    # placeholder string that discarded the qa.run() result.
    bot_history[-1] = (question, answer)
    return bot_history
def clear_bot():
    """Reset the Chatbot component (Gradio clears it when handed None)."""
    return None
title = "QnA Chatbot"

with gr.Blocks() as demo:
    gr.Markdown(f"# {title}")

    # --- Tab 1: upload and index a document -------------------------------
    with gr.Tab("File"):
        upload = gr.File(file_types=["text"], label="Upload File")
        chunk_size = gr.Slider(500, 5000, value=1000, step=100,
                               label="Chunk Size")
        chunk_overlap = gr.Slider(0, 30, value=20, label="Chunk Overlap")
        process = gr.Button("Process")
        result = gr.Label()

    # --- Tab 2: chat over an indexed document -----------------------------
    with gr.Tab("Bot"):
        with gr.Row():
            with gr.Column(scale=0.5):
                collection = gr.Dropdown(choices=get_vector_dbs(),
                                         label="Document")
                temperature = gr.Slider(0.0, 1.0, value=0.5, step=0.05,
                                        label="Temperature")
                max_length = gr.Slider(20, 1000, value=64, label="Max Length")
            with gr.Column(scale=0.5):
                chatbot = gr.Chatbot([], elem_id="chatbot").style(height=550)
                message = gr.Textbox(show_label=False,
                                     placeholder="Ask me anything!")
                clear = gr.Button("Clear")

    # Event wiring: index the uploaded file; on submit, echo the user
    # message into the history, then let the bot fill in the answer.
    process.click(process_file, [upload, chunk_size, chunk_overlap], result)
    message.submit(submit_message, [chatbot, message], [chatbot, message]).then(
        bot, [chatbot, collection, temperature, max_length], chatbot
    )
    clear.click(clear_bot, None, chatbot)

demo.title = title
demo.launch()