import streamlit as st
from utils import (
    load_model,
    load_finetuned_model,
    generate_response,
    get_hf_token,
)
import os
import json
from datetime import datetime

st.set_page_config(page_title="Gemma Chat", layout="wide")

# -------------------------------
# πŸ’‘ Theme Toggle
# -------------------------------
dark_mode = st.sidebar.toggle("πŸŒ™ Dark Mode", value=False)

if dark_mode:
    # Streamlit paints the app background on the .stApp container; styling
    # body alone is usually overridden by the built-in theme.
    st.markdown(
        """
        <style>
        .stApp { background-color: #1e1e1e; color: #ffffff; }
        .stTextInput, .stTextArea, .stSelectbox, .stSlider { color: #ffffff !important; }
        </style>
        """,
        unsafe_allow_html=True,
    )

st.title("πŸ’¬ Chat with Gemma Model")

# -------------------------------
# πŸ“Œ Model Source Selection
# -------------------------------
model_source = st.sidebar.radio("πŸ“Œ Select Model Source", ["Local (.pt)", "Hugging Face"])

# -------------------------------
# πŸ”₯ Dynamic Model List
# -------------------------------
if model_source == "Local (.pt)":
    model_dir = "models"
    os.makedirs(model_dir, exist_ok=True)

    local_models = [f for f in os.listdir(model_dir) if f.endswith(".pt")]

    if local_models:
        selected_model = st.sidebar.selectbox("πŸ› οΈ Select Local Model", local_models)
        model_path = os.path.join(model_dir, selected_model)
    else:
        st.warning("⚠️ No fine-tuned models found. Fine-tune a model first.")
        st.stop()

else:
    hf_models = [
        "google/gemma-3-1b-it", 
        "google/gemma-3-4b-pt", 
        "google/gemma-3-4b-it",
        "google/gemma-3-12b-pt", 
        "google/gemma-3-12b-it", 
        "google/gemma-3-27b-pt", 
        "google/gemma-3-27b-it"
    ]
    selected_model = st.sidebar.selectbox("πŸ› οΈ Select Hugging Face Model", hf_models)
    model_path = None
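
# The Hugging Face list above is hardcoded even though the section is headed
# "Dynamic Model List". A minimal sketch of discovering Gemma 3 checkpoints at
# runtime with huggingface_hub instead (not wired in; each call hits the Hub,
# so cache the result in practice):
def discover_gemma_models(limit: int = 20):
    from huggingface_hub import list_models
    return [m.id for m in list_models(author="google", search="gemma-3", limit=limit)]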

# -------------------------------
# πŸ”₯ Model Loading
# -------------------------------
hf_token = get_hf_token()

if model_source == "Local (.pt)":
    # Load the base checkpoint first, then apply the fine-tuned weights on top.
    # Note: the base is hardcoded to gemma-3-1b-it, so a .pt file fine-tuned
    # from a different Gemma size will not load correctly.
    tokenizer, model = load_model("google/gemma-3-1b-it", hf_token)
    model = load_finetuned_model(model, model_path)
    if model:
        st.success(f"βœ… Local fine-tuned model loaded: `{selected_model}`")
    else:
        st.error("❌ Failed to load local model.")
        st.stop()

else:
    tokenizer, model = load_model(selected_model, hf_token)
    if model:
        st.success(f"βœ… Hugging Face model loaded: `{selected_model}`")
    else:
        st.error("❌ Failed to load Hugging Face model.")
        st.stop()
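
# Both branches above reload weights from scratch on every Streamlit rerun
# (i.e. on every widget interaction). A minimal caching sketch, assuming
# load_model is deterministic for a given model id and token:
# st.cache_resource keeps the tokenizer/model pair in memory across reruns.
# Not wired into the branches above.
@st.cache_resource(show_spinner="Loading model...")
def load_model_cached(model_id: str, token: str):
    return load_model(model_id, token)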

# -------------------------------
# βš™οΈ Model Configuration Panel
# -------------------------------
st.sidebar.header("βš™οΈ Model Configuration")
temperature = st.sidebar.slider("πŸ”₯ Temperature", 0.1, 1.5, 0.7, 0.1)
top_p = st.sidebar.slider("🎯 Top-p", 0.1, 1.0, 0.9, 0.1)
repetition_penalty = st.sidebar.slider("πŸ” Repetition Penalty", 0.5, 2.0, 1.0, 0.1)
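
# The three sliders above are never passed to generate_response below (its
# call site in stream_response() only supplies prompt, model, tokenizer and
# max_length), so they currently have no effect. A hedged sketch of threading
# them through, assuming utils.generate_response forwards extra keyword
# arguments to model.generate (not called anywhere; adjust utils if it
# does not):
def generate_with_sampling(prompt_text: str):
    return generate_response(
        prompt_text, model, tokenizer, max_length,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
    )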

# -------------------------------
# πŸ’¬ Chat Interface
# -------------------------------
if "conversation" not in st.session_state:
    st.session_state.conversation = []

prompt = st.text_area("πŸ’¬ Enter your message:", "Hello, how are you?", key="prompt", height=100)
max_length = st.slider("πŸ“ Max Response Length", min_value=50, max_value=1000, value=300, step=50)

# -------------------------------
# πŸš€ Streaming Response Function
# -------------------------------
def stream_response():
    """
    Streams the response token by token.
    """
    response = generate_response(prompt, model, tokenizer, max_length)

    if response:
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        st.session_state.conversation.append({"sender": "πŸ‘€ You", "message": prompt, "timestamp": timestamp})
        st.session_state.conversation.append({"sender": "πŸ€– AI", "message": response, "timestamp": timestamp})
        return response
    else:
        st.error("❌ Failed to generate response.")
        return None
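
# True token-by-token streaming, which stream_response() above does not do.
# A minimal sketch assuming a standard transformers model/tokenizer pair:
# TextIteratorStreamer runs generate() in a worker thread and yields text as
# it is produced, and st.write_stream (Streamlit >= 1.31) can render the
# generator live. Shown as an optional alternative, not wired into the UI.
def stream_response_tokens(prompt_text: str):
    from threading import Thread
    from transformers import TextIteratorStreamer

    inputs = tokenizer(prompt_text, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    Thread(
        target=model.generate,
        kwargs=dict(**inputs, streamer=streamer, max_new_tokens=max_length),
    ).start()
    yield from streamer  # e.g. st.write_stream(stream_response_tokens(prompt))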

# -------------------------------
# 🎯 Conversation Controls
# -------------------------------
col1, col2, col3 = st.columns([1, 1, 1])

if col1.button("πŸš€ Generate", help="Ctrl+Enter applies the text box; click this button to generate"):
    stream_response()

if col2.button("πŸ—‘οΈ Clear Conversation"):
    st.session_state.conversation = []

# Export & Import
if col3.download_button("πŸ“₯ Export Chat", json.dumps(st.session_state.conversation, indent=4), "chat_history.json", mime="application/json"):
    st.success("βœ… Chat exported successfully!")

uploaded_file = st.file_uploader("πŸ“€ Import Conversation", type=["json"])

if uploaded_file is not None:
    # Note: the uploaded file persists across reruns, so this import re-applies
    # on every rerun until the file is removed from the uploader.
    try:
        st.session_state.conversation = json.load(uploaded_file)
        st.success("βœ… Conversation imported successfully!")
    except (json.JSONDecodeError, UnicodeDecodeError):
        st.error("❌ Invalid JSON file.")

# -------------------------------
# πŸ› οΈ Display Conversation
# -------------------------------
st.subheader("πŸ“œ Conversation History")

for msg in st.session_state.conversation:
    with st.container():
        st.markdown(f"**{msg['sender']}**  \nπŸ•’ {msg['timestamp']}")
        st.write(msg['message'])
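
# Optional: Streamlit's chat primitives (st.chat_message, available since
# 1.24) render the same history as chat bubbles. A minimal sketch over the
# same st.session_state.conversation structure; call it in place of the loop
# above rather than in addition to it:
def render_chat_bubbles():
    for entry in st.session_state.conversation:
        role = "user" if entry["sender"].endswith("You") else "assistant"
        with st.chat_message(role):
            st.caption(entry["timestamp"])
            st.write(entry["message"])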