import os
import json
import re

import streamlit as st
import pandas as pd
import numpy as np
import hnswlib
import torch  # used for device detection below; missing from the original import list
from typing import Iterator
from transformers import AutoTokenizer

# Importing Hugging Face models and libraries
from sentence_transformers import SentenceTransformer, CrossEncoder
from easyllm.clients import huggingface

# Set Hugging Face API key and prompt builder for the easyllm client
huggingface.prompt_builder = "llama2"
huggingface.api_key = os.environ["HUGGINGFACE_TOKEN"]

# Constants
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = 4000
EMBED_DIM = 1024
K = 10
EF = 100
SEARCH_INDEX = "search_index.bin"
EMBEDDINGS_FILE = "embeddings.npy"
DOCUMENT_DATASET = "chunked_data.parquet"
COSINE_THRESHOLD = 0.7

torch_device = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on device:", torch_device)
print("CPU threads:", torch.get_num_threads())

model_id = "meta-llama/Llama-2-70b-chat-hf"
biencoder = SentenceTransformer("intfloat/e5-large-v2", device=torch_device)
cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-12-v2", max_length=512, device=torch_device)
tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=os.environ["HUGGINGFACE_TOKEN"])

# Initialize Streamlit app
st.title("PEFT Docs QA Chatbot")


# Function to create the QA prompt by stuffing the retrieved chunks into the context
def create_qa_prompt(query, relevant_chunks):
    stuffed_context = " ".join(relevant_chunks)
    return f"""\
Use the following pieces of context to answer the question at the end. \
If you don't know the answer, just say that you don't know, don't try to make up an answer. \
Keep the answer short and succinct.

Context: {stuffed_context}
Question: {query}
Helpful Answer: \
"""


# Function to generate a response: condense the query, retrieve and rerank chunks,
# then stream the answer from the LLM
def generate_response(message, history_with_input, system_prompt, max_new_tokens, temperature, top_p, top_k):
    if max_new_tokens > MAX_MAX_NEW_TOKENS:
        raise ValueError(f"max_new_tokens must be <= {MAX_MAX_NEW_TOKENS}")
    history = history_with_input[:-1]
    if len(history) > 0:
        condensed_query = generate_condensed_query(message, history)
        print(f"{condensed_query=}")
    else:
        condensed_query = message
    query_embedding = create_query_embedding(condensed_query)
    relevant_chunks = find_nearest_neighbors(query_embedding)
    reranked_relevant_chunks = rerank_chunks_with_cross_encoder(condensed_query, relevant_chunks)
    qa_prompt = create_qa_prompt(condensed_query, reranked_relevant_chunks)
    print(f"{qa_prompt=}")
    generator = get_completion(
        qa_prompt,
        system_prompt=system_prompt,
        stream=True,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
    )
    output = ""
    # enumerate() was missing in the original, which made the loop unpack the
    # streamed chunks incorrectly
    for idx, response in enumerate(generator):
        token = response["choices"][0]["delta"].get("content", "") or ""
        output += token
        if idx == 0:
            history.append((message, output))
        else:
            history[-1] = (message, output)
    history = [
        (wrap_html_code(user.strip()), wrap_html_code(bot.strip()))
        for user, bot in history
    ]
    return history


# Function to get the input token length
def get_input_token_length(message, chat_history, system_prompt):
    prompt = get_prompt(message, chat_history, system_prompt)
    input_ids = tokenizer([prompt], return_tensors="np", add_special_tokens=False)["input_ids"]
    return input_ids.shape[-1]


# Function to create a condensed, standalone query from the chat history
def generate_condensed_query(query, history):
    chat_history = ""
    for turn in history:
        chat_history += f"Human: {turn[0]}\n"
        chat_history += f"Assistant: {turn[1]}\n"
    condense_question_prompt = create_condense_question_prompt(query, chat_history)
    condensed_question = json.loads(get_completion(condense_question_prompt, max_new_tokens=64, temperature=0))
    return condensed_question["question"]
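
# --- Helper definitions missing from this snippet ---
# DEFAULT_SYSTEM_PROMPT, create_condense_question_prompt, get_prompt, and
# get_completion are all referenced above or in the UI code below but were not
# included in the original listing. The sketches here are reconstructed from
# the call sites and are assumptions, not the author's implementations: the
# system-prompt wording, the exact condense-question prompt, and the
# get_completion wrapper over easyllm may all differ from the originals.

# Assumed default system prompt (original value not shown in this snippet)
DEFAULT_SYSTEM_PROMPT = (
    "You are a helpful assistant that answers questions about the PEFT library, "
    "using only the context provided to you."
)


def create_condense_question_prompt(question, chat_history):
    # Asks the model to rewrite a follow-up question as a standalone question and
    # return a JSON object with a single `question` field, which
    # generate_condensed_query() above parses with json.loads()
    return f"""\
Given the following conversation and a follow-up question, rephrase the follow-up \
question to be a standalone question. Output a JSON object with a single field \
`question` whose value is the standalone question.

Chat History:
{chat_history}
Follow Up Input: {question}
"""


def get_prompt(message, chat_history, system_prompt):
    # Builds a Llama-2 chat-format prompt string; referenced by
    # get_input_token_length() above
    texts = [f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"]
    for user_input, response in chat_history:
        texts.append(f"{user_input} [/INST] {response} </s><s>[INST] ")
    texts.append(f"{message} [/INST]")
    return "".join(texts)


def get_completion(
    prompt,
    system_prompt=None,
    model=model_id,
    max_new_tokens=DEFAULT_MAX_NEW_TOKENS,
    temperature=0.2,
    top_p=0.95,
    top_k=50,
    stream=False,
):
    # Thin wrapper around easyllm's OpenAI-style ChatCompletion client. With
    # stream=True it returns an iterator of chunks whose ["choices"][0]["delta"]
    # carries the incremental "content", matching how generate_response()
    # consumes it above.
    if temperature < 1e-2:
        temperature = 1e-2  # text-generation backends reject temperature == 0
    messages = []
    if system_prompt is not None:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": prompt})
    response = huggingface.ChatCompletion.create(
        model=model,
        messages=messages,
        temperature=temperature,
        max_tokens=max_new_tokens,
        top_p=top_p,
        top_k=top_k,
        stream=stream,
    )
    if stream:
        return response
    return response["choices"][0]["message"]["content"]
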
# Function to load a previously saved HNSW index
def load_hnsw_index(index_file):
    index = hnswlib.Index(space="ip", dim=EMBED_DIM)
    index.load_index(index_file)
    return index


# Function to build the HNSW index from the saved embeddings
def create_hnsw_index(embeddings_file, M=16, efC=100):
    embeddings = np.load(embeddings_file)
    num_dim = embeddings.shape[1]
    ids = np.arange(embeddings.shape[0])
    index = hnswlib.Index(space="ip", dim=num_dim)
    index.init_index(max_elements=embeddings.shape[0], ef_construction=efC, M=M)
    index.add_items(embeddings, ids)
    return index


# Function to create a query embedding
def create_query_embedding(query):
    embedding = biencoder.encode([query], normalize_embeddings=True)[0]
    return embedding


# Function to find nearest neighbors above the cosine-similarity threshold
def find_nearest_neighbors(query_embedding):
    search_index.set_ef(EF)
    labels, distances = search_index.knn_query(query_embedding, k=K)
    # With normalized embeddings in inner-product space, 1 - distance is the cosine similarity
    labels = [label for label, distance in zip(labels[0], distances[0]) if (1 - distance) >= COSINE_THRESHOLD]
    relevant_chunks = data_df.iloc[labels]["chunk_content"].tolist()
    return relevant_chunks


# Function to rerank chunks with the cross-encoder
def rerank_chunks_with_cross_encoder(query, chunks):
    pairs = [(query, chunk) for chunk in chunks]
    scores = cross_encoder.predict(pairs)
    sorted_chunks = [chunk for _, chunk in sorted(zip(scores, chunks), reverse=True)]
    return sorted_chunks


# Function to wrap text containing HTML tags in code fences
def wrap_html_code(text):
    pattern = r"<.*?>"
    matches = re.findall(pattern, text)
    if len(matches) > 0:
        return f"```{text}```"
    else:
        return text


# Load the HNSW index and the chunked dataset for the PEFT docs
search_index = create_hnsw_index(EMBEDDINGS_FILE)  # or: load_hnsw_index(SEARCH_INDEX)
data_df = pd.read_parquet(DOCUMENT_DATASET).reset_index()

# Streamlit UI
st.markdown("Welcome to the PEFT Docs QA Chatbot.")
message = st.text_input("You:", "")
# Note: this list is re-created on every Streamlit rerun; persisting the chat
# history across reruns would require st.session_state
history_with_input = []
system_prompt = st.text_area("System prompt", DEFAULT_SYSTEM_PROMPT)
max_new_tokens = st.slider("Max new tokens", 1, MAX_MAX_NEW_TOKENS, DEFAULT_MAX_NEW_TOKENS)
temperature = st.slider("Temperature", 0.1, 4.0, 0.2, 0.1)
# Default 0.95 to match the "Clear" defaults below; the original passed 0.05,
# which set the default equal to the slider minimum
top_p = st.slider("Top-p (nucleus sampling)", 0.05, 1.0, 0.95, 0.05)
top_k = st.slider("Top-k", 1, 1000, 50)

if st.button("Submit"):
    if message:
        try:
            # generate_response returns the updated history; the original unpacked
            # it into two values, which would raise a ValueError at runtime
            history_with_input = generate_response(
                message, history_with_input, system_prompt, max_new_tokens, temperature, top_p, top_k
            )
            st.write("Chatbot:", history_with_input[-1][1])
        except Exception as e:
            st.error(f"An error occurred: {e}")
    else:
        st.warning("Please enter a message.")

if st.button("Retry"):
    if history_with_input:
        history_with_input = generate_response(
            message, history_with_input, system_prompt, max_new_tokens, temperature, top_p, top_k
        )
        st.write("Chatbot:", history_with_input[-1][1])
    else:
        st.warning("No previous message to retry.")

if st.button("Undo"):
    if history_with_input:
        # Each history entry is (user_message, bot_reply); recover the user's
        # message, not the bot reply as the original did
        last_message, _ = history_with_input.pop()
        st.text_area("You:", last_message, height=50)
    else:
        st.warning("No previous message to undo.")

if st.button("Clear"):
    message = ""
    history_with_input = []
    system_prompt = DEFAULT_SYSTEM_PROMPT
    max_new_tokens = DEFAULT_MAX_NEW_TOKENS
    temperature = 0.2
    top_p = 0.95
    top_k = 50

st.sidebar.markdown(
    "This is a Streamlit app for the PEFT Docs QA Chatbot. "
    "Enter your message, configure advanced options, and interact with the chatbot."
)
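
# Usage: save this script as e.g. app.py and launch it with `streamlit run app.py`.
# It expects embeddings.npy and chunked_data.parquet (with a "chunk_content"
# column) in the working directory, and a HUGGINGFACE_TOKEN environment variable
# whose account has access to meta-llama/Llama-2-70b-chat-hf.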