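"""Streamlit app that wraps a fine-tuned CodeT5 model to turn natural-language
questions into specialized network-related queries.

Run locally (assuming Streamlit and the model checkpoint are available) with:
    streamlit run app.py
"""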
import time

import streamlit as st
import torch
from torch.serialization import safe_globals, add_safe_globals
from transformers import RobertaTokenizer, T5ForConditionalGeneration

# Allowlist the model class for torch.load: recent PyTorch releases (2.6+)
# default to weights_only=True and refuse to unpickle classes that have not
# been registered as safe globals. The API expects the class object itself,
# not its dotted-path string.
add_safe_globals([T5ForConditionalGeneration])
# Set page configuration
st.set_page_config(
page_title="CodeT5 Query Generator",
page_icon="🤖",
layout="wide",
)
# CSS styling
st.markdown("""
<style>
.main-header {
font-size: 2.5rem;
color: #4527A0;
text-align: center;
margin-bottom: 1rem;
}
.sub-header {
font-size: 1.5rem;
color: #5E35B1;
margin-bottom: 1rem;
}
.response-container {
background-color: #f0f2f6;
border-radius: 10px;
padding: 20px;
margin-top: 20px;
}
.stButton>button {
background-color: #673AB7;
color: white;
}
.stButton>button:hover {
background-color: #5E35B1;
color: white;
}
.footer {
text-align: center;
margin-top: 3rem;
color: #9575CD;
}
</style>
""", unsafe_allow_html=True)
# App header
st.markdown("<h1 class='main-header'>Network Query Generator</h1>", unsafe_allow_html=True)
st.markdown("<h2 class='sub-header'>Ask questions and get specialized network related queries</h2>", unsafe_allow_html=True)
# Sidebar for model information and settings
with st.sidebar:
st.title("About")
st.info("This app uses a fine-tuned CodeT5 model to generate specialized queries from natural language questions.")
st.title("Model Settings")
max_length = st.slider("Maximum output length", 32, 256, 128)
num_beams = st.slider("Number of beams", 1, 10, 4)
    # Minimum of 0.1: a temperature of 0 is invalid once sampling is enabled.
    temperature = st.slider("Temperature", 0.1, 1.0, 0.7)
st.title("Model Info")
st.markdown("**Base model:** Salesforce/codet5-small")
st.markdown("**Fine-tuned on:** Custom dataset")
MODEL_PATH = "finetuned_codet5_small_01.pkl"
# Function to load the model (cached by st.cache_resource so it loads only once)
@st.cache_resource
def load_model(file_path):
"""Load the tokenizer and model using a safe approach"""
model_name = "Salesforce/codet5-small"
tokenizer = RobertaTokenizer.from_pretrained(model_name)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    try:
        # First attempt: the checkpoint is a fully pickled model object.
        # weights_only=False is required to unpickle it; only do this with
        # checkpoints from a trusted source, since unpickling can execute code.
        with safe_globals([T5ForConditionalGeneration]):
            model = torch.load(file_path, map_location=device, weights_only=False)
        model = model.to(device)
        model.eval()
        return tokenizer, model, device, None
    except Exception:
        try:
            # Fallback: treat the checkpoint as a state dict and load it into
            # a freshly initialized base model.
            base_model = T5ForConditionalGeneration.from_pretrained(model_name)
            with safe_globals([T5ForConditionalGeneration]):
                state_dict = torch.load(file_path, map_location="cpu", weights_only=False)
            # If the loaded object is itself a model, extract just its state dict.
            if hasattr(state_dict, "state_dict"):
                state_dict = state_dict.state_dict()
            base_model.load_state_dict(state_dict)
            base_model = base_model.to(device)
            base_model.eval()
            return tokenizer, base_model, device, None
        except Exception as e2:
            return None, None, None, f"Error loading model: {e2}"
# Function to generate a query from the user's question
def generate_query(question, tokenizer, model, device, max_length=128, num_beams=4, temperature=0.7):
    """Generate a query based on the user's question."""
    # Inputs longer than 128 tokens are truncated.
    inputs = tokenizer(question, return_tensors="pt", padding=True, truncation=True, max_length=128)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=max_length,
            num_beams=num_beams,
            do_sample=True,  # temperature has no effect unless sampling is enabled
            temperature=temperature,
            early_stopping=True,
        )
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)
# Load the model at startup
with st.spinner("Loading model... (this may take a moment)"):
tokenizer, model, device, error_message = load_model(MODEL_PATH)
if model is not None:
    st.sidebar.success("Model loaded successfully!")
else:
st.sidebar.error(f"Failed to load model: {error_message}")
# Main app area
question = st.text_area("Enter your question here:", height=100, placeholder="Example: How can I secure my network against DDoS attacks?")
# Center the Generate button using a three-column layout
col1, col2, col3 = st.columns([1, 1, 1])
with col2:
generate_button = st.button("Generate Query", use_container_width=True)
# Display generation result
if generate_button and question:
if model is not None and tokenizer is not None:
with st.spinner("Generating response..."):
# Add a slight delay for user experience
time.sleep(0.5)
response = generate_query(
question,
tokenizer,
model,
device,
max_length=max_length,
num_beams=num_beams,
temperature=temperature
)
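            # Note: each st.markdown call renders as its own element, so this
            # wrapper div may not actually enclose the st.code block below.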
st.markdown("<div class='response-container'>", unsafe_allow_html=True)
st.markdown("### Generated Query:")
st.code(response, language="sql")
st.markdown("</div>", unsafe_allow_html=True)
else:
st.error("Model could not be loaded. Please check if the model file exists at the correct path.")
# Example questions
with st.expander("Example Questions"):
example_questions = [
"Show me the current configuration of the router.",
"Get the total number of ['firewall', 'router', 'switch', 'server', 'access_point'] with high CPU usage.",
"Get the total bandwidth usage of access_point FW1.",
"Get the total uptime of all ['firewall', 'router', 'switch', 'server', 'access_point'].",
"Get the total number of ['firewall', 'router', 'switch', 'server', 'access_point'].",
"Find the top 5 devices with the highest data usage, grouped by region and filtering for data usage greater than 10GB."
]
    for q in example_questions:
        st.write(q)
# Footer
st.markdown("<p class='footer'>Powered by CodeT5 - Fine-tuned for specialized queries</p>", unsafe_allow_html=True)