"""Streamlit front end for a fine-tuned CodeT5 model.

The page takes a natural-language question, feeds it to a CodeT5 checkpoint
loaded from ``MODEL_PATH`` and displays the generated query.
"""

import os
import pickle
import time

import streamlit as st
import torch
from torch.serialization import add_safe_globals, safe_globals
from transformers import RobertaTokenizer, T5ForConditionalGeneration

# The torch allow-list API takes the classes themselves, not dotted-path
# strings, so the imported class object is registered here.
add_safe_globals([T5ForConditionalGeneration])

# Set page configuration
st.set_page_config(
    page_title="CodeT5 Query Generator",
    page_icon="🤖",
    layout="wide",
)

# CSS styling
# NOTE(review): this string is effectively empty; the stylesheet that
# presumably lived here is not present in the file - restore if available.
st.markdown(""" """, unsafe_allow_html=True)

# App header
# NOTE(review): only the text of the header/sub-header is present; any HTML
# wrapper that may have surrounded it is not in the file - confirm and restore.
st.markdown("\n\nNetwork Query Generator\n\n", unsafe_allow_html=True)
st.markdown(
    "\n\nAsk questions and get specialized network related queries\n\n",
    unsafe_allow_html=True,
)

# Sidebar for model information and settings
with st.sidebar:
    st.title("About")
    st.info("This app uses a fine-tuned CodeT5 model to generate specialized queries from natural language questions.")

    st.title("Model Settings")
    max_length = st.slider("Maximum output length", 32, 256, 128)
    num_beams = st.slider("Number of beams", 1, 10, 4)
    temperature = st.slider("Temperature", 0.0, 1.0, 0.7)
    # Temperature only influences generation when sampling is enabled, so the
    # choice is exposed explicitly; it is off by default (plain beam search).
    do_sample = st.checkbox("Enable sampling (applies temperature)", value=False)

    st.title("Model Info")
    st.markdown("**Base model:** Salesforce/codet5-small")
    st.markdown("**Fine-tuned on:** Custom dataset")

MODEL_PATH = "finetuned_codet5_small_01.pkl"


def _load_pickled_model(file_path, device):
    """Load ``file_path`` as a fully pickled model and move it to ``device``."""
    # SECURITY: weights_only=False runs the full pickle machinery and can
    # execute arbitrary code; only use this with checkpoint files you trust.
    with safe_globals([T5ForConditionalGeneration]):
        model = torch.load(file_path, map_location=device, weights_only=False)
    # If the file holds something without .to()/.eval() (e.g. a bare state
    # dict) this raises, and the caller falls back to the state-dict path.
    model = model.to(device)
    model.eval()
    return model


def _load_state_dict_model(file_path, model_name, device):
    """Build the base model and load the weights stored in ``file_path`` into it."""
    base_model = T5ForConditionalGeneration.from_pretrained(model_name)
    with safe_globals([T5ForConditionalGeneration]):
        state_dict = torch.load(file_path, map_location="cpu")
    # If the loaded object is already a model, extract just the state dict
    if hasattr(state_dict, "state_dict"):
        state_dict = state_dict.state_dict()
    base_model.load_state_dict(state_dict)
    base_model = base_model.to(device)
    base_model.eval()
    return base_model


# Function to load the model
@st.cache_resource
def load_model(file_path):
    """Load the tokenizer and the fine-tuned model.

    The checkpoint is first loaded as a complete pickled model; if that fails,
    it is loaded as a state dict into a fresh ``Salesforce/codet5-small``.

    Args:
        file_path: Path to the saved checkpoint.

    Returns:
        A ``(tokenizer, model, device, error_message)`` tuple. On success
        ``error_message`` is ``None``; on failure the first three items are
        ``None`` and ``error_message`` describes what went wrong.
    """
    model_name = "Salesforce/codet5-small"

    # Fail fast with a clear message instead of fetching the tokenizer and
    # base model only to discover that the checkpoint is missing.
    if not os.path.isfile(file_path):
        return None, None, None, f"Model file not found: {file_path}"

    tokenizer = RobertaTokenizer.from_pretrained(model_name)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    try:
        model = _load_pickled_model(file_path, device)
    except Exception as full_model_error:
        try:
            model = _load_state_dict_model(file_path, model_name, device)
        except Exception as state_dict_error:
            # Report both failures; the first one is often the real cause.
            return None, None, None, (
                f"Error loading model: {state_dict_error} "
                f"(loading as a full model failed first with: {full_model_error})"
            )
    return tokenizer, model, device, None


# Function to generate query
def generate_query(question, tokenizer, model, device, max_length=128,
                   num_beams=4, temperature=0.7, do_sample=False):
    """Generate a query based on the user's question.

    Args:
        question: Natural-language question text.
        tokenizer: Tokenizer used to encode the question and decode the output.
        model: Model exposing ``generate``.
        device: Device the encoded inputs are moved to.
        max_length: ``max_length`` forwarded to ``model.generate``.
        num_beams: ``num_beams`` forwarded to ``model.generate``.
        temperature: Sampling temperature; forwarded only when ``do_sample``
            is true and the value is greater than 0.
        do_sample: Enable sampling. Defaults to ``False`` (no sampling).

    Returns:
        The decoded text of the first generated sequence, with special tokens
        skipped.
    """
    inputs = tokenizer(
        question,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=128,
    )
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    generation_kwargs = {
        "max_length": max_length,
        "num_beams": num_beams,
        "early_stopping": True,
    }
    # Temperature is only meaningful when sampling, and a temperature of 0 is
    # not a valid sampling setting, so it is passed along only in that case.
    if do_sample and temperature > 0:
        generation_kwargs["do_sample"] = True
        generation_kwargs["temperature"] = temperature

    with torch.no_grad():
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            **generation_kwargs,
        )

    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)


# Load the model at startup
with st.spinner("Loading model... (this may take a moment)"):
    tokenizer, model, device, error_message = load_model(MODEL_PATH)

if model is not None:
    st.sidebar.success("Model loaded successfully!")
else:
    st.sidebar.error(f"Failed to load model: {error_message}")

# Main app area
question = st.text_area(
    "Enter your question here:",
    height=100,
    placeholder="Example: Show the total data transferred, grouped by user department and VPN type, excluding 'Guest' users, for VPN sessions that lasted longer than 2 hours",
)

# a button to generate the response
col1, col2, col3 = st.columns([1, 1, 1])
with col2:
    generate_button = st.button("Generate Query", use_container_width=True)

# Display generation result
if generate_button and not question.strip():
    # Tell the user why nothing happened instead of silently ignoring the click.
    st.warning("Please enter a question first.")
elif generate_button:
    if model is not None and tokenizer is not None:
        with st.spinner("Generating response..."):
            # Add a slight delay for user experience
            time.sleep(0.5)
            response = generate_query(
                question,
                tokenizer,
                model,
                device,
                max_length=max_length,
                num_beams=num_beams,
                temperature=temperature,
                do_sample=do_sample,
            )
        # NOTE(review): the two newline-only markdown calls presumably once
        # carried wrapper markup around the result - confirm and restore.
        st.markdown("\n", unsafe_allow_html=True)
        st.markdown("### Generated Query:")
        st.code(response, language="sql")
        st.markdown("\n", unsafe_allow_html=True)
    else:
        st.error("Model could not be loaded. Please check if the model file exists at the correct path.")

# Example questions
with st.expander("Example Questions"):
    example_questions = [
        "Show me the current configuration of the router.",
        "Get the total number of ['firewall', 'router', 'switch', 'server', 'access_point'] with high CPU usage.",
        "Get the total bandwidth usage of access_point FW1.",
        "Get the total uptime of all ['firewall', 'router', 'switch', 'server', 'access_point'].",
        "Get the total number of ['firewall', 'router', 'switch', 'server', 'access_point'].",
        "Find the top 5 devices with the highest data usage, grouped by region and filtering for data usage greater than 10GB.",
    ]
    for example in example_questions:
        st.write(example)

# Footer
# NOTE(review): the footer string is empty; its markup is not present in the file.
st.markdown("", unsafe_allow_html=True)