import streamlit as st
from utils import (
    load_model,
    load_finetuned_model,
    generate_response,
    get_hf_token
)
import os
import json
from datetime import datetime
st.set_page_config(page_title="Gemma Chat", layout="wide")
# -------------------------------
# 💡 Theme Toggle
# -------------------------------
dark_mode = st.sidebar.toggle("🌙 Dark Mode", value=False)
if dark_mode:
    st.markdown(
        """
        <style>
        body { background-color: #1e1e1e; color: #ffffff; }
        .stTextInput, .stTextArea, .stSelectbox, .stSlider { color: #ffffff !important; }
        </style>
        """, unsafe_allow_html=True
    )
st.title("💬 Chat with Gemma Model")
# -------------------------------
# 📌 Model Source Selection
# -------------------------------
model_source = st.sidebar.radio("📌 Select Model Source", ["Local (.pt)", "Hugging Face"])
# -------------------------------
# 🔥 Dynamic Model List
# -------------------------------
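# Local mode lists fine-tuned .pt checkpoints from the models/ directory;
# Hugging Face mode offers the official Gemma 3 checkpoints.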
if model_source == "Local (.pt)":
model_dir = "models"
if not os.path.exists(model_dir):
os.makedirs(model_dir)
local_models = [f for f in os.listdir(model_dir) if f.endswith(".pt")]
if local_models:
selected_model = st.sidebar.selectbox("πŸ› οΈ Select Local Model", local_models)
model_path = os.path.join(model_dir, selected_model)
else:
st.warning("⚠️ No fine-tuned models found. Fine-tune a model first.")
st.stop()
else:
    hf_models = [
        "google/gemma-3-1b-it",
        "google/gemma-3-4b-pt",
        "google/gemma-3-4b-it",
        "google/gemma-3-12b-pt",
        "google/gemma-3-12b-it",
        "google/gemma-3-27b-pt",
        "google/gemma-3-27b-it"
    ]
    selected_model = st.sidebar.selectbox("🛠️ Select Hugging Face Model", hf_models)
    model_path = None
# -------------------------------
# 🔥 Model Loading
# -------------------------------
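# Hugging Face access token (the Gemma checkpoints are gated and require authentication).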
hf_token = get_hf_token()
if model_source == "Local (.pt)":
tokenizer, model = load_model("google/gemma-3-1b-it", hf_token) # Base model first
model = load_finetuned_model(model, model_path)
if model:
st.success(f"βœ… Local fine-tuned model loaded: `{selected_model}`")
else:
st.error("❌ Failed to load local model.")
st.stop()
else:
    tokenizer, model = load_model(selected_model, hf_token)
    if model:
        st.success(f"✅ Hugging Face model loaded: `{selected_model}`")
    else:
        st.error("❌ Failed to load Hugging Face model.")
        st.stop()
# -------------------------------
# ⚙️ Model Configuration Panel
# -------------------------------
st.sidebar.header("⚙️ Model Configuration")
temperature = st.sidebar.slider("🔥 Temperature", 0.1, 1.5, 0.7, 0.1)
top_p = st.sidebar.slider("🎯 Top-p", 0.1, 1.0, 0.9, 0.1)
repetition_penalty = st.sidebar.slider("🔁 Repetition Penalty", 0.5, 2.0, 1.0, 0.1)
# -------------------------------
# 💬 Chat Interface
# -------------------------------
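# Conversation history lives in st.session_state so it persists across Streamlit reruns.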
if "conversation" not in st.session_state:
    st.session_state.conversation = []
prompt = st.text_area("💬 Enter your message:", "Hello, how are you?", key="prompt", height=100)
max_length = st.slider("📏 Max Response Length", min_value=50, max_value=1000, value=300, step=50)
# -------------------------------
# 🚀 Response Generation
# -------------------------------
def stream_response():
"""
Streams the response token by token.
"""
response = generate_response(prompt, model, tokenizer, max_length)
if response:
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
st.session_state.conversation.append({"sender": "πŸ‘€ You", "message": prompt, "timestamp": timestamp})
st.session_state.conversation.append({"sender": "πŸ€– AI", "message": response, "timestamp": timestamp})
return response
else:
st.error("❌ Failed to generate response.")
return None
# -------------------------------
# 🎯 Conversation Controls
# -------------------------------
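# Three equal-width columns: generate, clear conversation, export chat.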
col1, col2, col3 = st.columns([1, 1, 1])
if col1.button("🚀 Generate (CTRL+Enter)", help="Use CTRL + Enter to generate"):
    stream_response()
if col2.button("🗑️ Clear Conversation"):
    st.session_state.conversation = []
# Export & Import
if col3.download_button("📥 Export Chat", json.dumps(st.session_state.conversation, indent=4), "chat_history.json"):
    st.success("✅ Chat exported successfully!")
uploaded_file = st.file_uploader("📤 Import Conversation", type=["json"])
if uploaded_file is not None:
    st.session_state.conversation = json.load(uploaded_file)
    st.success("✅ Conversation imported successfully!")
# -------------------------------
# 🛠️ Display Conversation
# -------------------------------
st.subheader("📜 Conversation History")
for msg in st.session_state.conversation:
    with st.container():
        st.markdown(f"**{msg['sender']}**  \n🕒 {msg['timestamp']}")
        st.write(msg['message'])