"""Streamlit chat UI for conversing with a base or fine-tuned Gemma model.

Lets the user pick a model source (local fine-tuned ``.pt`` checkpoint or a
Hugging Face hub model), tune generation settings, chat, and export/import
the conversation history as JSON.
"""

import json
import os
from datetime import datetime

import streamlit as st

from utils import (
    load_model,
    load_finetuned_model,
    generate_response,
    get_hf_token,
)

st.set_page_config(page_title="Gemma Chat", layout="wide")

# -------------------------------
# šŸ’” Theme Toggle
# -------------------------------
dark_mode = st.sidebar.toggle("šŸŒ™ Dark Mode", value=False)

if dark_mode:
    # FIX: the original injected an empty string here, so the toggle had no
    # visible effect. Inject a minimal dark stylesheet instead.
    st.markdown(
        """
        <style>
        .stApp { background-color: #0e1117; color: #fafafa; }
        section[data-testid="stSidebar"] { background-color: #262730; }
        </style>
        """,
        unsafe_allow_html=True,
    )

st.title("šŸ’¬ Chat with Gemma Model")

# -------------------------------
# šŸ“Œ Model Source Selection
# -------------------------------
model_source = st.sidebar.radio("šŸ“Œ Select Model Source", ["Local (.pt)", "Hugging Face"])

# -------------------------------
# šŸ”„ Dynamic Model List
# -------------------------------
if model_source == "Local (.pt)":
    model_dir = "models"
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    local_models = [f for f in os.listdir(model_dir) if f.endswith(".pt")]

    if local_models:
        selected_model = st.sidebar.selectbox("šŸ› ļø Select Local Model", local_models)
        model_path = os.path.join(model_dir, selected_model)
    else:
        # FIX: this message was split by a raw newline inside a single-quoted
        # string (a syntax error); restored as one literal.
        st.warning("āš ļø No fine-tuned models found. Fine-tune a model first.")
        st.stop()
else:
    hf_models = [
        "google/gemma-3-1b-it",
        "google/gemma-3-4b-pt",
        "google/gemma-3-4b-it",
        "google/gemma-3-12b-pt",
        "google/gemma-3-12b-it",
        "google/gemma-3-27b-pt",
        "google/gemma-3-27b-it",
    ]
    selected_model = st.sidebar.selectbox("šŸ› ļø Select Hugging Face Model", hf_models)
    model_path = None

# -------------------------------
# šŸ”„ Model Loading
# -------------------------------
hf_token = get_hf_token()

if model_source == "Local (.pt)":
    # Local checkpoints are LoRA/fine-tune weights applied on top of the
    # 1B instruction-tuned base model, so the base is loaded first.
    tokenizer, model = load_model("google/gemma-3-1b-it", hf_token)
    model = load_finetuned_model(model, model_path)
    if model:
        st.success(f"āœ… Local fine-tuned model loaded: `{selected_model}`")
    else:
        st.error("āŒ Failed to load local model.")
        st.stop()
else:
    tokenizer, model = load_model(selected_model, hf_token)
    if model:
        st.success(f"āœ… Hugging Face model loaded: `{selected_model}`")
    else:
        st.error("āŒ Failed to load Hugging Face model.")
        st.stop()

# -------------------------------
# āš™ļø Model Configuration Panel
# -------------------------------
st.sidebar.header("āš™ļø Model Configuration")
temperature = st.sidebar.slider("šŸ”„ Temperature", 0.1, 1.5, 0.7, 0.1)
top_p = st.sidebar.slider("šŸŽÆ Top-p", 0.1, 1.0, 0.9, 0.1)
repetition_penalty = st.sidebar.slider("šŸ” Repetition Penalty", 0.5, 2.0, 1.0, 0.1)

# -------------------------------
# šŸ’¬ Chat Interface
# -------------------------------
if "conversation" not in st.session_state:
    st.session_state.conversation = []

prompt = st.text_area("šŸ’¬ Enter your message:", "Hello, how are you?", key="prompt", height=100)
max_length = st.slider("šŸ“ Max Response Length", min_value=50, max_value=1000, value=300, step=50)


# -------------------------------
# šŸš€ Response Generation
# -------------------------------
def stream_response():
    """Generate a full reply for the current prompt and record the exchange.

    NOTE: despite the name, this does NOT stream token by token — it calls
    ``generate_response`` once and appends the complete reply. The name is
    kept for backward compatibility with existing callers.

    Returns:
        str | None: The generated reply, or ``None`` on failure (an error
        message is shown in the UI in that case).
    """
    response = generate_response(prompt, model, tokenizer, max_length)
    if response:
        # Both sides of the exchange share one timestamp for grouping.
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        st.session_state.conversation.append(
            {"sender": "šŸ‘¤ You", "message": prompt, "timestamp": timestamp}
        )
        st.session_state.conversation.append(
            {"sender": "šŸ¤– AI", "message": response, "timestamp": timestamp}
        )
        return response
    else:
        st.error("āŒ Failed to generate response.")
        return None


# -------------------------------
# šŸŽÆ Conversation Controls
# -------------------------------
col1, col2, col3 = st.columns([1, 1, 1])

if col1.button("šŸš€ Generate (CTRL+Enter)", help="Use CTRL + Enter to generate"):
    stream_response()

if col2.button("šŸ—‘ļø Clear Conversation"):
    st.session_state.conversation = []

# Export & Import
if col3.download_button(
    "šŸ“„ Export Chat",
    json.dumps(st.session_state.conversation, indent=4),
    "chat_history.json",
):
    st.success("āœ… Chat exported successfully!")

uploaded_file = st.file_uploader("šŸ“¤ Import Conversation", type=["json"])
if uploaded_file is not None:
    # FIX: validate the upload instead of letting a malformed file crash
    # the app (the original called json.load with no error handling).
    try:
        imported = json.load(uploaded_file)
    except json.JSONDecodeError:
        st.error("āŒ Failed to import: file is not valid JSON.")
    else:
        if isinstance(imported, list):
            st.session_state.conversation = imported
            st.success("āœ… Conversation imported successfully!")
        else:
            st.error("āŒ Failed to import: expected a JSON list of messages.")

# -------------------------------
# šŸ› ļø Display Conversation
# -------------------------------
st.subheader("šŸ“œ Conversation History")

for msg in st.session_state.conversation:
    with st.container():
        st.markdown(f"**{msg['sender']}**  \nšŸ•’ {msg['timestamp']}")
        st.write(msg['message'])