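"""Streamlit app: retrieval-augmented QA over the Hugging Face PEFT docs.

Pipeline: condense the follow-up question with Llama 2, embed it with an E5
bi-encoder, retrieve candidate chunks from an HNSW index, rerank them with a
cross-encoder, and stream the final answer via easyllm.
"""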
import os
import json
import re
import streamlit as st
import torch  # used below for device detection; was missing from the original imports
from transformers import AutoTokenizer
import pandas as pd

# Importing Hugging Face models and libraries
from sentence_transformers import SentenceTransformer, CrossEncoder
import hnswlib
import numpy as np
from typing import Iterator

from easyllm.clients import huggingface

# Configure the easyllm client: Llama 2 prompt format and Hugging Face API key
huggingface.prompt_builder = "llama2"
huggingface.api_key = os.environ["HUGGINGFACE_TOKEN"]

# Constants
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = 4000
EMBED_DIM = 1024
K = 10
EF = 100
SEARCH_INDEX = "search_index.bin"
EMBEDDINGS_FILE = "embeddings.npy"
DOCUMENT_DATASET = "chunked_data.parquet"
COSINE_THRESHOLD = 0.7
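
# DEFAULT_SYSTEM_PROMPT is used by the UI below but was never defined in this
# file; the wording here is a placeholder assumption, not the original prompt.
DEFAULT_SYSTEM_PROMPT = (
    "You are a helpful assistant that answers questions about the Hugging Face "
    "PEFT library using only the provided documentation context."
)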

torch_device = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on device:", torch_device)
print("CPU threads:", torch.get_num_threads())

model_id = "meta-llama/Llama-2-70b-chat-hf"
biencoder = SentenceTransformer("intfloat/e5-large-v2", device=torch_device)
cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-12-v2", max_length=512, device=torch_device)

tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.environ["HUGGINGFACE_TOKEN"])

# Initialize Streamlit app
st.title("PEFT Docs QA Chatbot")

# Function to create QA prompt
def create_qa_prompt(query, relevant_chunks):
    stuffed_context = " ".join(relevant_chunks)
    return f"""\
Use the following pieces of context to answer the question at the end. \
If you don't know the answer, just say that you don't know, don't try to make up an answer. \
Keep the answer short and succinct.

Context: {stuffed_context}
Question: {query}
Helpful Answer: \
"""

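# get_completion is called below but was never defined in this file. The sketch
# here uses easyllm's OpenAI-style client; the exact signature and defaults of
# the original helper are assumptions.
def get_completion(
    prompt,
    system_prompt=None,
    model=model_id,
    max_new_tokens=1024,
    temperature=0.2,
    top_p=0.95,
    top_k=50,
    stream=False,
):
    messages = []
    if system_prompt is not None:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": prompt})
    response = huggingface.ChatCompletion.create(
        model=model,
        messages=messages,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        max_tokens=max_new_tokens,
        stream=stream,
    )
    if stream:
        # An iterator of OpenAI-style chunks: chunk["choices"][0]["delta"]["content"]
        return response
    return response["choices"][0]["message"]["content"]
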
# Function to generate a Streamlit app response
def generate_response(message, history_with_input, system_prompt, max_new_tokens, temperature, top_p, top_k):
    if max_new_tokens > MAX_MAX_NEW_TOKENS:
        raise ValueError(f"max_new_tokens must be <= {MAX_MAX_NEW_TOKENS}")
    # The current message arrives separately, so the whole list is prior history
    history = list(history_with_input)
    if len(history) > 0:
        condensed_query = generate_condensed_query(message, history)
        print(f"{condensed_query=}")
    else:
        condensed_query = message
    query_embedding = create_query_embedding(condensed_query)
    relevant_chunks = find_nearest_neighbors(query_embedding)
    reranked_relevant_chunks = rerank_chunks_with_cross_encoder(condensed_query, relevant_chunks)
    qa_prompt = create_qa_prompt(condensed_query, reranked_relevant_chunks)
    print(f"{qa_prompt=}")
    generator = get_completion(
        qa_prompt,
        system_prompt=system_prompt,
        stream=True,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
    )

    output = ""
    for idx, response in enumerate(generator):
        token = response["choices"][0]["delta"].get("content", "") or ""
        output += token
        if idx == 0:
            history.append((message, output))
        else:
            history[-1] = (message, output)

    # Wrap HTML-looking turns in code fences once streaming has finished
    history = [
        (wrap_html_code(user.strip()), wrap_html_code(bot.strip()))
        for user, bot in history
    ]
    return history

# Function to get input token length
def get_input_token_length(message, chat_history, system_prompt):
    prompt = get_prompt(message, chat_history, system_prompt)
    input_ids = tokenizer([prompt], return_tensors="np", add_special_tokens=False)["input_ids"]
    return input_ids.shape[-1]
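
# get_prompt is referenced above but was never defined in this file. The sketch
# below follows the standard Llama 2 chat template; treating that as the
# intended format is an assumption.
def get_prompt(message, chat_history, system_prompt):
    texts = [f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"]
    for user_input, response in chat_history:
        texts.append(f"{user_input} [/INST] {response} </s><s>[INST] ")
    texts.append(f"{message} [/INST]")
    return "".join(texts)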

# Function to create a condensed query
def generate_condensed_query(query, history):
    chat_history = ""
    for turn in history:
        chat_history += f"Human: {turn[0]}\n"
        chat_history += f"Assistant: {turn[1]}\n"

    condense_question_prompt = create_condense_question_prompt(query, chat_history)
    condensed_question = json.loads(get_completion(condense_question_prompt, max_new_tokens=64, temperature=0))
    return condensed_question["question"]
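
# create_condense_question_prompt is called above but was never defined in this
# file. This sketch asks for a JSON object with a single "question" field so the
# json.loads call succeeds; the exact wording is an assumption.
def create_condense_question_prompt(question, chat_history):
    return f"""\
Given the following conversation and a follow-up question, rephrase the \
follow-up question to be a standalone question. Output only a JSON object with \
a single field `question` whose value is the rephrased standalone question.

Chat History:
{chat_history}
Follow Up Input: {question}
"""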

# Function to load the HNSW index
def load_hnsw_index(index_file):
    index = hnswlib.Index(space="ip", dim=EMBED_DIM)
    index.load_index(index_file)
    return index

# Function to create the HNSW index
def create_hnsw_index(embeddings_file, M=16, efC=100):
    embeddings = np.load(embeddings_file)
    num_dim = embeddings.shape[1]
    ids = np.arange(embeddings.shape[0])
    index = hnswlib.Index(space="ip", dim=num_dim)
    index.init_index(max_elements=embeddings.shape[0], ef_construction=efC, M=M)
    index.add_items(embeddings, ids)
    return index

# Function to create a query embedding
def create_query_embedding(query):
    embedding = biencoder.encode([query], normalize_embeddings=True)[0]
    return embedding

# Function to find nearest neighbors
def find_nearest_neighbors(query_embedding):
    search_index.set_ef(EF)
    labels, distances = search_index.knn_query(query_embedding, k=K)
    labels = [label for label, distance in zip(labels[0], distances[0]) if (1 - distance) >= COSINE_THRESHOLD]
    relevant_chunks = data_df.iloc[labels]["chunk_content"].tolist()
    return relevant_chunks

# Function to rerank chunks with the cross encoder
def rerank_chunks_with_cross_encoder(query, chunks):
    pairs = [(query, chunk) for chunk in chunks]
    scores = cross_encoder.predict(pairs)
    sorted_chunks = [chunk for _, chunk in sorted(zip(scores, chunks), reverse=True)]
    return sorted_chunks

# Function to wrap HTML code
def wrap_html_code(text):
    pattern = r"<.*?>"
    matches = re.findall(pattern, text)
    if len(matches) > 0:
        return f"```{text}```"
    else:
        return text

# Load the HNSW index for the PEFT docs
search_index = create_hnsw_index(EMBEDDINGS_FILE)  # load_hnsw_index(SEARCH_INDEX)
data_df = pd.read_parquet(DOCUMENT_DATASET).reset_index()

# Streamlit UI
st.markdown("Welcome to the PEFT Docs QA Chatbot.")
message = st.text_input("You:", "")
# Streamlit reruns this script on every interaction, so the chat history must
# live in st.session_state rather than in a plain local variable.
if "history" not in st.session_state:
    st.session_state.history = []
system_prompt = st.text_area("System prompt", DEFAULT_SYSTEM_PROMPT)
max_new_tokens = st.slider("Max new tokens", 1, MAX_MAX_NEW_TOKENS, DEFAULT_MAX_NEW_TOKENS)
temperature = st.slider("Temperature", 0.1, 4.0, 0.2, 0.1)
top_p = st.slider("Top-p (nucleus sampling)", 0.05, 1.0, 0.95, 0.05)
top_k = st.slider("Top-k", 1, 1000, 50)

if st.button("Submit"):
    if message:
        try:
            st.session_state.history = generate_response(
                message, st.session_state.history, system_prompt, max_new_tokens, temperature, top_p, top_k
            )
            st.write("Chatbot:", st.session_state.history[-1][1])
        except Exception as e:
            st.error(f"An error occurred: {e}")
    else:
        st.warning("Please enter a message.")

if st.button("Retry"):
    if st.session_state.history:
        # Re-ask the most recent question, discarding its previous answer
        last_message, _ = st.session_state.history.pop()
        st.session_state.history = generate_response(
            last_message, st.session_state.history, system_prompt, max_new_tokens, temperature, top_p, top_k
        )
        st.write("Chatbot:", st.session_state.history[-1][1])
    else:
        st.warning("No previous message to retry.")

if st.button("Undo"):
    if st.session_state.history:
        # pop() returns (user_message, bot_reply); show the user's message again
        last_message, _ = st.session_state.history.pop()
        st.text_area("You:", last_message, height=50)
    else:
        st.warning("No previous message to undo.")

if st.button("Clear"):
    # Reassigning local variables has no effect across Streamlit reruns; clearing
    # the stored history is what actually resets the conversation.
    st.session_state.history = []

st.sidebar.markdown(
    "This is a Streamlit app for the PEFT Docs QA Chatbot. Enter your message, configure advanced options, and interact with the chatbot."
)