import streamlit as st
import pandas as pd
import json
import os
from datetime import datetime
from utils import (
    load_model,
    get_hf_token,
    simulate_training,
    plot_training_metrics,
    load_finetuned_model,
    save_model
)

st.title("🔥 Fine-tune the Gemma Model")

# -------------------------------
# Finetuning Option Selection
# -------------------------------
finetune_option = st.radio(
    "Select Finetuning Option",
    ["Fine-tune from scratch", "Refinetune existing model"]
)

# -------------------------------
# Model Selection Logic
# -------------------------------
selected_model = None
saved_model_path = None

if finetune_option == "Fine-tune from scratch":
    # Display Hugging Face model list
    model_list = [
        "google/gemma-3-1b-pt", "google/gemma-3-1b-it",
        "google/gemma-3-4b-pt", "google/gemma-3-4b-it",
        "google/gemma-3-12b-pt", "google/gemma-3-12b-it",
        "google/gemma-3-27b-pt", "google/gemma-3-27b-it"
    ]
    selected_model = st.selectbox("🛠️ Select Gemma Model to Fine-tune", model_list)

elif finetune_option == "Refinetune existing model":
    # Dynamically list all saved models from the /models folder
    model_dir = "models"
    if os.path.exists(model_dir):
        saved_models = [f for f in os.listdir(model_dir) if f.endswith(".pt")]
    else:
        saved_models = []

    if saved_models:
        saved_model_path = st.selectbox("Select a saved model to re-finetune", saved_models)
        saved_model_path = os.path.join(model_dir, saved_model_path)
        st.success(f"✅ Selected model for refinement: `{saved_model_path}`")
    else:
        st.warning("⚠️ No saved models found! Switching to fine-tuning from scratch.")
        finetune_option = "Fine-tune from scratch"

# -------------------------------
# Dataset Selection
# -------------------------------
st.subheader("📚 Dataset Selection")
dataset_option = st.radio(
    "Choose dataset:",
    ["Upload New Dataset", "Use Existing Dataset (`train_data.csv`)"]
)
dataset_path = "datasets/train_data.csv"

if dataset_option == "Upload New Dataset":
    uploaded_file = st.file_uploader("📤 Upload Dataset (CSV or JSON)", type=["csv", "json"])
    if uploaded_file is not None:
        if uploaded_file.name.endswith(".csv"):
            new_data = pd.read_csv(uploaded_file)
        elif uploaded_file.name.endswith(".json"):
            json_data = json.load(uploaded_file)
            new_data = pd.json_normalize(json_data)
        else:
            st.error("❌ Unsupported file format. Please upload CSV or JSON.")
            st.stop()

        if os.path.exists(dataset_path):
            new_data.to_csv(dataset_path, mode='a', index=False, header=False)
            st.success(f"✅ Data appended to `{dataset_path}`!")
        else:
            new_data.to_csv(dataset_path, index=False)
            st.success(f"✅ Dataset saved as `{dataset_path}`!")

elif dataset_option == "Use Existing Dataset (`train_data.csv`)":
    if os.path.exists(dataset_path):
        st.success("✅ Using existing `train_data.csv` for fine-tuning.")
    else:
        st.error("❌ `train_data.csv` not found! Please upload a new dataset.")
        st.stop()

# -------------------------------
# Hyperparameters Configuration
# -------------------------------
st.subheader("🔧 Hyperparameter Configuration")
learning_rate = st.number_input("📊 Learning Rate", value=1e-4, format="%.5f")
batch_size = st.number_input("🛠️ Batch Size", value=16, step=1)
epochs = st.number_input("⏱️ Epochs", value=3, step=1)

# -------------------------------
# Fine-tuning Execution with Real-Time Visualization
# -------------------------------
if st.button("🚀 Start Fine-tuning"):
    st.info("Fine-tuning process initiated...")
    hf_token = get_hf_token()

    # Model loading logic
    if finetune_option == "Refinetune existing model" and saved_model_path:
        tokenizer, model = load_model("google/gemma-3-1b-it", hf_token)
        model = load_finetuned_model(model, saved_model_path)
        if model:
            st.success(f"✅ Loaded saved model: `{saved_model_path}` for refinement!")
        else:
            st.error("❌ Failed to load the saved model. Aborting.")
            st.stop()
    else:
        if not selected_model:
            st.error("❌ Please select a model to fine-tune.")
            st.stop()
        tokenizer, model = load_model(selected_model, hf_token)
        if model:
            st.success(f"✅ Base model loaded: `{selected_model}`")
        else:
            st.error("❌ Failed to load the base model. Aborting.")
            st.stop()

    # Create placeholders for training progress
    loss_chart = st.line_chart()  # Loss curve
    acc_chart = st.line_chart()   # Accuracy curve
    progress_text = st.empty()

    # Simulate training loop with real-time visualization
    losses_over_epochs = []
    accuracies_over_epochs = []
    for epoch, losses, accs in simulate_training(epochs, learning_rate, batch_size):
        # Update training text
        progress_text.text(f"Epoch {epoch}/{epochs} in progress...")
        # Assume simulate_training returns overall average loss and accuracy per epoch
        losses_over_epochs.append(losses)      # e.g., average loss of the epoch
        accuracies_over_epochs.append(accs)    # e.g., average accuracy of the epoch
        # Update real-time charts
        loss_chart.add_rows(pd.DataFrame({"Loss": [losses]}))
        acc_chart.add_rows(pd.DataFrame({"Accuracy": [accs]}))

    progress_text.text("Fine-tuning completed!")

    # Save fine-tuned model with timestamp
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    model_identifier = selected_model if selected_model else os.path.basename(saved_model_path)
    new_model_name = f"models/fine_tuned_model_{model_identifier.replace('/', '_')}_{timestamp}.pt"
    saved_model_path = save_model(model, new_model_name)

    if saved_model_path:
        st.success(f"✅ Fine-tuning completed! Model saved as `{saved_model_path}`")
        model = load_finetuned_model(model, saved_model_path)
        if model:
            st.success("🛠️ Fine-tuned model loaded and ready for inference!")
        else:
            st.error("❌ Failed to load the fine-tuned model for inference.")
    else:
        st.error("❌ Failed to save the fine-tuned model.")