# Cafe-Chatbot / app.py
# Commit f3f8525 ("Update app.py", verified) by Copain22; file size 5.75 kB.
# 0. Install custom transformers and imports
import os
import subprocess
import sys


def _pip_install(requirement: str) -> None:
    """Best-effort ``pip install`` of *requirement* into the running interpreter.

    Invoking pip as ``sys.executable -m pip`` installs into the same Python
    that runs this app; a bare ``pip`` resolved through the shell may belong
    to a different interpreter. A failed install is reported but, like the
    previous ``os.system`` calls, does not abort start-up.
    """
    result = subprocess.run(
        [sys.executable, "-m", "pip", "install", requirement],
        check=False,
    )
    if result.returncode != 0:
        print(f"WARNING: pip install {requirement!r} exited with code {result.returncode}")


# Dependencies are installed before `transformers` / `sentence_transformers`
# are imported below; the transformers fork is presumably needed for the
# BitNet model in MODEL_ID — confirm before switching to upstream transformers.
_pip_install("git+https://github.com/shumingma/transformers.git")
_pip_install("sentence-transformers")

import threading

import torch
import torch._dynamo

# Let TorchDynamo fall back to eager execution instead of raising on errors.
torch._dynamo.config.suppress_errors = True

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)
from sentence_transformers import SentenceTransformer
import gradio as gr
import spaces
import pdfplumber
from pathlib import Path
from PyPDF2 import PdfReader
# 1. System prompt and model identifier.
# The prompt is assembled one entry per line; the empty first and last
# entries reproduce the newline that opens and closes the text.
_PROMPT_LINES = (
    "",
    "You are a friendly café assistant for Café Eleven. Your job is to:",
    "1. Greet the customer warmly.",
    "2. Help them order food and drinks from our menu.",
    "3. Ask the customer for their desired pickup time.",
    "4. Confirm the pickup time before ending the conversation.",
    "5. Answer questions about ingredients, preparation, etc.",
    "6. Handle special requests (allergies, modifications) politely.",
    "7. Provide calorie information if asked.",
    "Always be polite, helpful, and ensure the customer feels welcomed and cared for!",
    "",
)
SYSTEM_PROMPT = "\n".join(_PROMPT_LINES)

# Hugging Face Hub id of the causal language model loaded below.
MODEL_ID = "microsoft/bitnet-b1.58-2B-4T"
# 2. Load tokenizer and model
def _load_tokenizer_and_model(model_id):
    """Fetch the tokenizer for *model_id*, then the model in bfloat16.

    ``device_map="auto"`` leaves device placement to transformers/accelerate.
    Returns a ``(tokenizer, model)`` pair.
    """
    tok = AutoTokenizer.from_pretrained(model_id)
    lm = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    return tok, lm


tokenizer, model = _load_tokenizer_and_model(MODEL_ID)
print(f"Model loaded on device: {model.device}")
# 3. Load PDF files
def _group_into_sections(lines, max_heading_words=6):
    """Group stripped, non-empty text lines into " | "-joined sections.

    A line that is entirely upper-case and has at most ``max_heading_words``
    words starts a new section; any other line is appended to the section
    being built (or opens one if none is open yet).
    """
    sections = []
    parts = []
    for line in lines:
        if line.isupper() and len(line.split()) <= max_heading_words:
            if parts:
                sections.append(" | ".join(parts))
            parts = [line]
        else:
            parts.append(line)
    if parts:
        sections.append(" | ".join(parts))
    return sections


def _iter_pdf_lines(pdf_path):
    """Yield the stripped, non-empty text lines of every page of one PDF."""
    with pdfplumber.open(str(pdf_path)) as pdf:
        for page in pdf.pages:
            text = page.extract_text()
            if not text:
                continue  # page with no extractable text
            for raw_line in text.split("\n"):
                line = raw_line.strip()
                if line:
                    yield line


def load_pdfs(folder_path=".", max_heading_words=6):
    """Split every ``*.pdf`` directly inside *folder_path* into text sections.

    Files are processed in sorted path order (``Path.glob`` order is
    arbitrary), so the chunk list — and the embedding index built from it —
    is reproducible across runs. A section may continue across pages of one
    PDF but never spans two PDFs.

    Args:
        folder_path: directory to scan (not recursive).
        max_heading_words: an all-caps line with at most this many words is
            treated as a section heading.

    Returns:
        list[str]: one string per section, its lines joined with " | ".
    """
    docs = []
    for pdf_file in sorted(Path(folder_path).glob("*.pdf")):
        docs.extend(_group_into_sections(_iter_pdf_lines(pdf_file), max_heading_words))
    return docs
def _build_retrieval_index(folder):
    """Load menu chunks from *folder* and embed them for similarity search.

    Returns ``(chunks, encoder, vectors)``. ``normalize_embeddings=True``
    yields unit-length vectors, so a dot product with an equally normalized
    query equals cosine similarity.
    """
    chunks = load_pdfs(folder)
    print(f"Loaded {len(chunks)} text chunks from PDFs.")
    # 4. Create embeddings
    encoder = SentenceTransformer("all-MiniLM-L6-v2")  # Fast small model
    vectors = encoder.encode(chunks, normalize_embeddings=True)
    return chunks, encoder, vectors


document_chunks, embedder, doc_embeddings = _build_retrieval_index(".")
# 5. Retrieval function with float32 fix
def retrieve_context(question, top_k=3, *, chunks=None, embeddings=None, encoder=None):
    """Return the ``top_k`` chunks most similar to *question*, joined by blank lines.

    Args:
        question: user text to embed and search with.
        top_k: maximum number of chunks to return; values <= 0 yield "".
        chunks: candidate texts; defaults to the module-level ``document_chunks``.
        embeddings: one embedding row per chunk, aligned with *chunks*;
            defaults to ``doc_embeddings``.
        encoder: object with a SentenceTransformer-style ``encode`` method;
            defaults to ``embedder``.

    Returns:
        The selected chunks, best match first, separated by blank lines;
        "" when there is nothing to retrieve (e.g. no PDFs were loaded).
    """
    chunks = document_chunks if chunks is None else chunks
    embeddings = doc_embeddings if embeddings is None else embeddings
    encoder = embedder if encoder is None else encoder

    k = min(top_k, len(chunks))
    if k <= 0:
        # Empty index (or top_k <= 0): skip the matmul/topk, which can fail
        # on an empty embedding matrix.
        return ""

    query = torch.as_tensor(
        encoder.encode(question, normalize_embeddings=True), dtype=torch.float32
    )
    # as_tensor reuses the existing buffer when it is already float32 instead
    # of copying the whole index on every request; forcing float32 keeps both
    # matmul operands the same dtype.
    index = torch.as_tensor(embeddings, dtype=torch.float32)
    scores = index @ query
    top_indices = torch.topk(scores, k=min(k, scores.numel())).indices.tolist()
    return "\n\n".join(chunks[i] for i in top_indices)
# 6. Chat respond function
def _history_to_messages(history):
    """Convert Gradio chat history into chat-template message dicts.

    Accepts the ``[(user, assistant), ...]`` pair format as well as the
    ``[{"role": ..., "content": ...}, ...]`` "messages" format, so the
    function keeps working if the ChatInterface history format changes.
    Empty turns and non-text content are skipped.
    """
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            role = turn.get("role")
            content = turn.get("content")
            if role in ("user", "assistant") and isinstance(content, str) and content:
                messages.append({"role": role, "content": content})
            continue
        user_msg, bot_msg = turn
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if bot_msg:
            messages.append({"role": "assistant", "content": bot_msg})
    return messages


def _run_generation(generate_kwargs, streamer, errors):
    """Run ``model.generate`` on a worker thread without leaving *streamer* open.

    If generation raises, the exception is recorded in *errors* and
    ``streamer.end()`` is called so the consumer's ``for ... in streamer``
    loop terminates; with the default ``timeout=None`` the streamer's queue
    would otherwise block forever.
    """
    try:
        model.generate(**generate_kwargs)
    except Exception as exc:  # re-raised on the request thread by respond()
        errors.append(exc)
        streamer.end()


@spaces.GPU
def respond(
    message: str,
    history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
    """Stream an assistant reply grounded in retrieved menu context.

    Builds a chat prompt from *system_message*, the prior *history* and the
    new *message* (with the top retrieved PDF chunks appended), then samples
    from the model on a background thread.

    Yields:
        The cumulative response text after each streamed piece.

    Raises:
        Whatever ``model.generate`` raised, after the stream has been closed.
    """
    context = retrieve_context(message)
    messages = [{"role": "system", "content": system_message}]
    messages.extend(_history_to_messages(history))
    messages.append({"role": "user", "content": f"{message}\n\nRelevant menu info:\n{context}"})

    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # NOTE(review): if this model's chat template already emits BOS, tokenizing
    # the rendered prompt with the default add_special_tokens=True may add a
    # second one — confirm against the template before changing.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
    )

    errors = []
    thread = threading.Thread(
        target=_run_generation, args=(generate_kwargs, streamer, errors)
    )
    thread.start()

    response = ""
    for new_text in streamer:
        response += new_text
        yield response

    thread.join()
    if errors:
        raise errors[0]
# 7. Gradio ChatInterface
# Default generation settings shared by the example rows and the sliders.
_DEFAULT_MAX_TOKENS = 512
_DEFAULT_TEMPERATURE = 0.7
_DEFAULT_TOP_P = 0.95
_EXAMPLE_QUESTIONS = (
    "What kinds of burgers do you have?",
    "Do you have any gluten-free pastries?",
)

demo = gr.ChatInterface(
    fn=respond,
    title="Café Eleven Assistant",
    description="Friendly café assistant with real menu knowledge!",
    # Each row: the chat message, then one value per additional input below
    # (system message, max new tokens, temperature, top-p), in that order.
    examples=[
        [
            question,
            SYSTEM_PROMPT.strip(),
            _DEFAULT_MAX_TOKENS,
            _DEFAULT_TEMPERATURE,
            _DEFAULT_TOP_P,
        ]
        for question in _EXAMPLE_QUESTIONS
    ],
    additional_inputs=[
        gr.Textbox(value=SYSTEM_PROMPT.strip(), label="System message"),
        gr.Slider(
            minimum=1, maximum=2048, value=_DEFAULT_MAX_TOKENS, step=1,
            label="Max new tokens",
        ),
        gr.Slider(
            minimum=0.1, maximum=4.0, value=_DEFAULT_TEMPERATURE, step=0.1,
            label="Temperature",
        ),
        gr.Slider(
            minimum=0.1, maximum=1.0, value=_DEFAULT_TOP_P, step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)
# 8. Launch
if __name__ == "__main__":
    # share=True asks Gradio for a public share link in addition to the local
    # server. That stays the default, but GRADIO_SHARE=0/false/no/off opts out
    # (e.g. to keep a local run private).
    _share = os.environ.get("GRADIO_SHARE", "true").strip().lower() not in {
        "0", "false", "no", "off",
    }
    demo.launch(share=_share)