import streamlit as st
from utils import (
load_model,
load_finetuned_model,
generate_response,
get_hf_token
)
import os
import json
from datetime import datetime
st.set_page_config(page_title="Gemma Chat", layout="wide")
# -------------------------------
# 💡 Theme Toggle
# -------------------------------
dark_mode = st.sidebar.toggle("🌙 Dark Mode", value=False)
if dark_mode:
    st.markdown(
        """
        <style>
        body { background-color: #1e1e1e; color: #ffffff; }
        .stTextInput, .stTextArea, .stSelectbox, .stSlider { color: #ffffff !important; }
        </style>
        """, unsafe_allow_html=True
    )
st.title("💬 Chat with Gemma Model")
# -------------------------------
# 🔍 Model Source Selection
# -------------------------------
model_source = st.sidebar.radio("🔍 Select Model Source", ["Local (.pt)", "Hugging Face"])
# -------------------------------
# 📥 Dynamic Model List
# -------------------------------
if model_source == "Local (.pt)":
    model_dir = "models"
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    local_models = [f for f in os.listdir(model_dir) if f.endswith(".pt")]
    if local_models:
        selected_model = st.sidebar.selectbox("🛠️ Select Local Model", local_models)
        model_path = os.path.join(model_dir, selected_model)
    else:
        st.warning("⚠️ No fine-tuned models found. Fine-tune a model first.")
        st.stop()
else:
    hf_models = [
        "google/gemma-3-1b-it",
        "google/gemma-3-4b-pt",
        "google/gemma-3-4b-it",
        "google/gemma-3-12b-pt",
        "google/gemma-3-12b-it",
        "google/gemma-3-27b-pt",
        "google/gemma-3-27b-it"
    ]
    selected_model = st.sidebar.selectbox("🛠️ Select Hugging Face Model", hf_models)
    model_path = None
# -------------------------------
# 📥 Model Loading
# -------------------------------
hf_token = get_hf_token()
if model_source == "Local (.pt)":
    tokenizer, model = load_model("google/gemma-3-1b-it", hf_token)  # Base model first
    model = load_finetuned_model(model, model_path)
    if model:
        st.success(f"✅ Local fine-tuned model loaded: `{selected_model}`")
    else:
        st.error("❌ Failed to load local model.")
        st.stop()
else:
    tokenizer, model = load_model(selected_model, hf_token)
    if model:
        st.success(f"✅ Hugging Face model loaded: `{selected_model}`")
    else:
        st.error("❌ Failed to load Hugging Face model.")
        st.stop()
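# Streamlit reruns this entire script on every widget interaction, so the
# model above is reloaded each time. One possible optimization (a sketch,
# not wired in here) is to memoize the loader with st.cache_resource:
#
#     @st.cache_resource
#     def cached_load_model(model_name: str, token: str):
#         return load_model(model_name, token)
#
#     tokenizer, model = cached_load_model(selected_model, hf_token)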
# -------------------------------
# ⚙️ Model Configuration Panel
# -------------------------------
st.sidebar.header("⚙️ Model Configuration")
temperature = st.sidebar.slider("🔥 Temperature", 0.1, 1.5, 0.7, 0.1)
top_p = st.sidebar.slider("🎯 Top-p", 0.1, 1.0, 0.9, 0.1)
repetition_penalty = st.sidebar.slider("🔁 Repetition Penalty", 0.5, 2.0, 1.0, 0.1)
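# NOTE: these sampling values are collected here but are not passed into
# generate_response() below; forwarding them would only work if
# utils.generate_response accepts matching keyword arguments (an
# assumption about that helper's signature, which is not shown here).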
# -------------------------------
# 💬 Chat Interface
# -------------------------------
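# st.session_state persists across Streamlit's script reruns, so the chat
# history survives each widget interaction instead of being reset.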
if "conversation" not in st.session_state:
st.session_state.conversation = []
prompt = st.text_area("💬 Enter your message:", "Hello, how are you?", key="prompt", height=100)
max_length = st.slider("📏 Max Response Length", min_value=50, max_value=1000, value=300, step=50)
# -------------------------------
# 🔄 Response Generation
# -------------------------------
def stream_response():
    """
    Generates a response for the current prompt and appends both sides of
    the exchange to the conversation history.
    """
    response = generate_response(prompt, model, tokenizer, max_length)
    if response:
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        st.session_state.conversation.append({"sender": "👤 You", "message": prompt, "timestamp": timestamp})
        st.session_state.conversation.append({"sender": "🤖 AI", "message": response, "timestamp": timestamp})
        return response
    else:
        st.error("❌ Failed to generate response.")
        return None
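# stream_response() returns the full text in one call rather than streaming
# tokens. Genuine token-by-token streaming would look roughly like this
# sketch (assumes direct access to the underlying Hugging Face model and
# tokenizer; untested here):
#
#     from threading import Thread
#     from transformers import TextIteratorStreamer
#
#     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
#     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
#     Thread(target=model.generate, kwargs={**inputs, "streamer": streamer,
#                                           "max_new_tokens": max_length}).start()
#     st.write_stream(streamer)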
# -------------------------------
# 🎯 Conversation Controls
# -------------------------------
col1, col2, col3 = st.columns([1, 1, 1])
if col1.button("🚀 Generate (CTRL+Enter)", help="Use CTRL + Enter to generate"):
    stream_response()
if col2.button("🗑️ Clear Conversation"):
    st.session_state.conversation = []
# Export & Import
if col3.download_button("π₯ Export Chat", json.dumps(st.session_state.conversation, indent=4), "chat_history.json"):
st.success("β
Chat exported successfully!")
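# Exports are a JSON list of {"sender", "message", "timestamp"} objects,
# the same shape the import path below reads back in.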
uploaded_file = st.file_uploader("📤 Import Conversation", type=["json"])
if uploaded_file is not None:
    try:
        st.session_state.conversation = json.load(uploaded_file)
        st.success("✅ Conversation imported successfully!")
    except json.JSONDecodeError:
        st.error("❌ Invalid JSON file.")
# -------------------------------
# 🛠️ Display Conversation
# -------------------------------
st.subheader("📜 Conversation History")
for msg in st.session_state.conversation:
    with st.container():
        st.markdown(f"**{msg['sender']}** \n🕒 {msg['timestamp']}")
        st.write(msg['message'])