import json
import os
import re
from typing import Iterator

import hnswlib
import numpy as np
import pandas as pd
import streamlit as st
import torch
from sentence_transformers import CrossEncoder, SentenceTransformer
from transformers import AutoTokenizer

from easyllm.clients import huggingface

huggingface.prompt_builder = "llama2"
huggingface.api_key = os.environ["HUGGINGFACE_TOKEN"]

MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = 4000
EMBED_DIM = 1024  # output dimension of the e5-large-v2 bi-encoder
K = 10  # number of nearest neighbors to retrieve
EF = 100  # hnswlib query-time ef parameter
SEARCH_INDEX = "search_index.bin"
EMBEDDINGS_FILE = "embeddings.npy"
DOCUMENT_DATASET = "chunked_data.parquet"
COSINE_THRESHOLD = 0.7  # minimum cosine similarity for a retrieved chunk to be kept
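
# DEFAULT_SYSTEM_PROMPT is referenced by the UI below but was missing from this
# listing; the stock Llama-2 chat system prompt is a reasonable stand-in.
DEFAULT_SYSTEM_PROMPT = (
    "You are a helpful, respectful and honest assistant. Always answer as "
    "helpfully as possible, while being safe. If you don't know the answer "
    "to a question, say so rather than sharing false information."
)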

torch_device = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on device:", torch_device)
print("CPU threads:", torch.get_num_threads())

model_id = "meta-llama/Llama-2-70b-chat-hf"
biencoder = SentenceTransformer("intfloat/e5-large-v2", device=torch_device)
cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-12-v2", max_length=512, device=torch_device)

tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=os.environ["HUGGINGFACE_TOKEN"])
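

# get_completion() is used throughout but was missing from this listing. This is
# a minimal sketch assuming easyllm's OpenAI-style ChatCompletion interface: it
# returns the chunk iterator when stream=True and the answer string otherwise.
def get_completion(
    prompt,
    system_prompt=None,
    model=model_id,
    max_new_tokens=DEFAULT_MAX_NEW_TOKENS,
    temperature=0.2,
    top_p=0.95,
    top_k=50,
    stream=False,
):
    # Text-generation-inference rejects temperature == 0, so clamp it.
    temperature = max(temperature, 1e-2)
    messages = []
    if system_prompt is not None:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": prompt})
    response = huggingface.ChatCompletion.create(
        model=model,
        messages=messages,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        max_tokens=max_new_tokens,
        stream=stream,
    )
    if stream:
        return response
    return response["choices"][0]["message"]["content"]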


st.title("PEFT Docs QA Chatbot")


def create_qa_prompt(query, relevant_chunks):
    stuffed_context = " ".join(relevant_chunks)
    return f"""\
Use the following pieces of context to answer the question at the end. \
If you don't know the answer, just say that you don't know, don't try to make up an answer. \
Keep the answer short and succinct.

Context: {stuffed_context}
Question: {query}
Helpful Answer: \
"""


def generate_response(message, history_with_input, system_prompt, max_new_tokens, temperature, top_p, top_k):
    if max_new_tokens > MAX_MAX_NEW_TOKENS:
        raise ValueError(f"max_new_tokens must not exceed {MAX_MAX_NEW_TOKENS}")
    history = history_with_input[:-1]
    # Condense the follow-up into a standalone question before retrieval.
    if len(history) > 0:
        condensed_query = generate_condensed_query(message, history)
        print(f"{condensed_query=}")
    else:
        condensed_query = message
    query_embedding = create_query_embedding(condensed_query)
    relevant_chunks = find_nearest_neighbors(query_embedding)
    reranked_relevant_chunks = rerank_chunks_with_cross_encoder(condensed_query, relevant_chunks)
    qa_prompt = create_qa_prompt(condensed_query, reranked_relevant_chunks)
    print(f"{qa_prompt=}")
    generator = get_completion(
        qa_prompt,
        system_prompt=system_prompt,
        stream=True,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
    )

    # Accumulate the streamed tokens into the latest history turn.
    output = ""
    for idx, response in enumerate(generator):
        token = response["choices"][0]["delta"].get("content", "") or ""
        output += token
        if idx == 0:
            history.append((message, output))
        else:
            history[-1] = (message, output)

    history = [
        (wrap_html_code(user.strip()), wrap_html_code(assistant.strip()))
        for user, assistant in history
    ]
    return history
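

# get_prompt() is needed by get_input_token_length() below but was missing from
# this listing. A sketch that leans on easyllm's llama2 prompt builder (assumed
# available) so the token count matches the prompt the client actually sends.
def get_prompt(message, chat_history, system_prompt):
    from easyllm.prompt_utils import build_llama2_prompt

    messages = [{"role": "system", "content": system_prompt}]
    for user_turn, assistant_turn in chat_history:
        messages.append({"role": "user", "content": user_turn})
        messages.append({"role": "assistant", "content": assistant_turn})
    messages.append({"role": "user", "content": message})
    return build_llama2_prompt(messages)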


def get_input_token_length(message, chat_history, system_prompt):
    prompt = get_prompt(message, chat_history, system_prompt)
    input_ids = tokenizer([prompt], return_tensors="np", add_special_tokens=False)["input_ids"]
    return input_ids.shape[-1]


def generate_condensed_query(query, history):
    chat_history = ""
    for turn in history:
        chat_history += f"Human: {turn[0]}\n"
        chat_history += f"Assistant: {turn[1]}\n"

    condense_question_prompt = create_condense_question_prompt(query, chat_history)
    # The model is instructed to reply with a JSON object holding a "question" key.
    condensed_question = json.loads(get_completion(condense_question_prompt, max_new_tokens=64, temperature=0))
    return condensed_question["question"]
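

# create_condense_question_prompt() is called above but was missing from this
# listing. A minimal sketch: the model is asked to rewrite the follow-up as a
# standalone question and to reply with a JSON object, since
# generate_condensed_query() parses the completion with json.loads().
def create_condense_question_prompt(question, chat_history):
    return f"""\
Given the following conversation and a follow-up question, rephrase the \
follow-up question to be a standalone question. Output only a JSON object \
with a single key "question" whose value is the standalone question.

Chat History:
{chat_history}
Follow-up question: {question}
Standalone question (JSON): \
"""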


def load_hnsw_index(index_file):
    # Load a previously persisted hnswlib index from disk.
    index = hnswlib.Index(space="ip", dim=EMBED_DIM)
    index.load_index(index_file)
    return index


def create_hnsw_index(embeddings_file, M=16, efC=100):
    # Build an inner-product hnswlib index over the precomputed chunk embeddings.
    embeddings = np.load(embeddings_file)
    num_dim = embeddings.shape[1]
    ids = np.arange(embeddings.shape[0])
    index = hnswlib.Index(space="ip", dim=num_dim)
    index.init_index(max_elements=embeddings.shape[0], ef_construction=efC, M=M)
    index.add_items(embeddings, ids)
    return index


def create_query_embedding(query):
    # Normalized embeddings make inner product equivalent to cosine similarity.
    embedding = biencoder.encode([query], normalize_embeddings=True)[0]
    return embedding


def find_nearest_neighbors(query_embedding):
    search_index.set_ef(EF)
    labels, distances = search_index.knn_query(query_embedding, k=K)
    # With normalized embeddings, the ip distance is 1 - cosine similarity,
    # so (1 - distance) recovers the similarity for thresholding.
    labels = [label for label, distance in zip(labels[0], distances[0]) if (1 - distance) >= COSINE_THRESHOLD]
    relevant_chunks = data_df.iloc[labels]["chunk_content"].tolist()
    return relevant_chunks


def rerank_chunks_with_cross_encoder(query, chunks):
    # Score each (query, chunk) pair and sort chunks from most to least relevant.
    pairs = [(query, chunk) for chunk in chunks]
    scores = cross_encoder.predict(pairs)
    sorted_chunks = [chunk for _, chunk in sorted(zip(scores, chunks), key=lambda pair: pair[0], reverse=True)]
    return sorted_chunks


def wrap_html_code(text):
    # Fence responses containing HTML-like tags so they render literally
    # instead of as markup.
    pattern = r"<.*?>"
    matches = re.findall(pattern, text)
    if len(matches) > 0:
        return f"```{text}```"
    else:
        return text


# Reuse a persisted index when available; otherwise build one from the embeddings.
if os.path.exists(SEARCH_INDEX):
    search_index = load_hnsw_index(SEARCH_INDEX)
else:
    search_index = create_hnsw_index(EMBEDDINGS_FILE)
data_df = pd.read_parquet(DOCUMENT_DATASET).reset_index()


st.markdown("Welcome to the PEFT Docs QA Chatbot.")

# Persist the conversation across reruns; Streamlit re-executes the whole
# script on every interaction, so module-level variables would be reset.
if "history" not in st.session_state:
    st.session_state.history = []

message = st.text_input("You:", "")
system_prompt = st.text_area("System prompt", DEFAULT_SYSTEM_PROMPT)
max_new_tokens = st.slider("Max new tokens", 1, MAX_MAX_NEW_TOKENS, DEFAULT_MAX_NEW_TOKENS)
temperature = st.slider("Temperature", 0.1, 4.0, 0.2, 0.1)
top_p = st.slider("Top-p (nucleus sampling)", 0.05, 1.0, 0.95, 0.05)
top_k = st.slider("Top-k", 1, 1000, 50)

if st.button("Submit"):
    if message:
        try:
            # generate_response() expects the new turn as the last element of
            # history_with_input, so append it before calling.
            st.session_state.history = generate_response(
                message,
                st.session_state.history + [(message, "")],
                system_prompt,
                max_new_tokens,
                temperature,
                top_p,
                top_k,
            )
            st.write("Chatbot:", st.session_state.history[-1][1])
        except Exception as e:
            st.error(f"An error occurred: {e}")
    else:
        st.warning("Please enter a message.")

if st.button("Retry"):
    if st.session_state.history:
        # Re-answer the most recent user message with the current settings.
        last_message = st.session_state.history[-1][0]
        st.session_state.history = generate_response(
            last_message, st.session_state.history, system_prompt, max_new_tokens, temperature, top_p, top_k
        )
        st.write("Chatbot:", st.session_state.history[-1][1])
    else:
        st.warning("No previous message to retry.")

if st.button("Undo"):
    if st.session_state.history:
        # Remove the last turn and surface the user's message for editing.
        last_message, _ = st.session_state.history.pop()
        st.text_area("You:", last_message, height=50)
    else:
        st.warning("No previous message to undo.")

if st.button("Clear"):
    # Widget values are re-created from their defaults on the next rerun, so
    # only the persisted conversation needs to be reset here.
    st.session_state.history = []

st.sidebar.markdown(
    "This is a Streamlit app for the PEFT Docs QA Chatbot. Enter your message, "
    "configure advanced options, and interact with the chatbot."
)