import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import os
from threading import Thread
import requests
import time
# Define model path for caching (Avoids reloading every app restart)
MODEL_PATH = "/mnt/data/Phi-4-Hindi"
TOKEN = os.environ.get("HF_TOKEN")
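# HF_TOKEN is read from the environment (e.g. a Space secret); it is only needed if the model repo is gated or private.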
MODEL_NAME = "DrishtiSharma/Phi-4-Hindi-quantized"
# Load Model & Tokenizer Once
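# st.cache_resource keeps the loaded model and tokenizer in memory across Streamlit reruns, so they are loaded only once.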
@st.cache_resource()
def load_model():
    with st.spinner("Loading model... Please wait ⏳"):
        try:
            if not os.path.exists(MODEL_PATH):
                model = AutoModelForCausalLM.from_pretrained(
                    MODEL_NAME, token=TOKEN, trust_remote_code=True, torch_dtype=torch.bfloat16
                )
                tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=TOKEN)
                model.save_pretrained(MODEL_PATH)
                tokenizer.save_pretrained(MODEL_PATH)
            else:
                model = AutoModelForCausalLM.from_pretrained(MODEL_PATH)
                tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
        except requests.exceptions.ConnectionError:
            st.error("⚠️ Connection error! Unable to download the model. Please check your internet connection and try again.")
            return None, None
        except requests.exceptions.ReadTimeout:
            st.error("⚠️ Read Timeout! The request took too long. Please try again later.")
            return None, None
    return model, tokenizer
# Load and move model to appropriate device
model, tok = load_model()
if model is None or tok is None:
    st.stop()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
try:
    model = model.to(device)
except torch.cuda.OutOfMemoryError:
    st.error("⚠️ CUDA Out of Memory! Running on CPU instead.")
    device = torch.device("cpu")
    model = model.to(device)
terminators = [tok.eos_token_id]
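# Generation stops when the tokenizer's end-of-sequence token is produced.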
# Initialize session state if not set
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []
# Chat function
def chat(message, temperature, do_sample, max_tokens):
    """Processes chat input and generates a response using the model."""
    # Append new message to history
    st.session_state.chat_history.append({"role": "user", "content": message})
    # Convert chat history into model-friendly format
    messages = tok.apply_chat_template(st.session_state.chat_history, tokenize=False, add_generation_prompt=True)
    model_inputs = tok([messages], return_tensors="pt").to(device)
    # Initialize streamer for token-wise response
    streamer = TextIteratorStreamer(tok, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
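    # The streamer yields text chunks as they are generated; iterating it raises an
    # error if no new token arrives within the 20-second timeout.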
    # Define generation parameters
    generate_kwargs = {
        "inputs": model_inputs["input_ids"],
        "streamer": streamer,
        "max_new_tokens": max_tokens,
        "do_sample": do_sample,
        "temperature": temperature,
        "eos_token_id": terminators,
    }
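    # A temperature of 0 is not valid for sampling, so fall back to greedy decoding.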
    if temperature == 0:
        generate_kwargs["do_sample"] = False
    # Generate response asynchronously
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    # Collect response as it streams
    response_text = ""
    for new_text in streamer:
        response_text += new_text
        yield response_text
    # Save the assistant's response to session history
    st.session_state.chat_history.append({"role": "assistant", "content": response_text})
# UI Setup
st.title("💬 Chat With Phi-4-Hindi")
st.success("✅ Model is READY to chat!")
st.markdown("Chat with [large-traversaal/Phi-4-Hindi](https://huggingface.co./large-traversaal/Phi-4-Hindi)")
# Sidebar Chat Settings
temperature = st.sidebar.slider("Temperature", 0.0, 1.0, 0.3, 0.1)
do_sample = st.sidebar.checkbox("Use Sampling", value=True)
max_tokens = st.sidebar.slider("Max Tokens", 128, 4096, 512, 1)
text_color = st.sidebar.selectbox("Text Color", ["Red", "Black", "Blue", "Green", "Purple"], index=0)
dark_mode = st.sidebar.checkbox("🌙 Dark Mode", value=False)
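# NOTE: dark_mode is read here but not yet applied to the page styling.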
# Function to format chat messages
def get_html_text(text, color):
    return f'<p style="color: {color.lower()}; font-size: 16px;">{text}</p>'
# Display chat history
for msg in st.session_state.chat_history:
    role = "👤" if msg["role"] == "user" else "🤖"
    st.markdown(get_html_text(f"**{role}:** {msg['content']}", text_color if role == "👤" else "black"), unsafe_allow_html=True)
# User Input Handling
user_input = st.text_input("Type your message:", "")
if st.button("Send"):
if user_input.strip():
st.session_state.chat_history.append({"role": "user", "content": user_input})
# Display chatbot response
with st.spinner("Generating response... π€π"):
response_generator = chat(user_input, temperature, do_sample, max_tokens)
final_response = ""
for output in response_generator:
final_response = output # Store latest output
st.success("β
Response generated!")
# Add generated response to session state
st.rerun()
if st.button("🧹 Clear Chat"):
    with st.spinner("Clearing chat history..."):
        st.session_state.chat_history = []
    st.success("✅ Chat history cleared!")
    st.rerun()