File size: 6,931 Bytes
b7de90d ff418f7 b7de90d 6a04235 b7de90d deb0770 b7de90d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 |
import streamlit as st
import torch
from transformers import RobertaTokenizer, T5ForConditionalGeneration
import pickle
import os
import time
from torch.serialization import safe_globals, add_safe_globals
# Allow-list the T5 model class for weights_only unpickling. add_safe_globals
# expects the class object itself: torch derives the "module.qualname" key from
# the object's attributes, so a plain string is not a valid allow-list entry.
add_safe_globals([T5ForConditionalGeneration])
# Browser-tab title and icon, plus the wide page layout.
_PAGE_CONFIG = dict(
    page_title="CodeT5 Query Generator",
    page_icon="🤖",
    layout="wide",
)
st.set_page_config(**_PAGE_CONFIG)
# Custom CSS for the header, sub-header, response box, buttons and footer.
# Injected as raw HTML, hence unsafe_allow_html=True.
_CUSTOM_CSS = """
<style>
.main-header {
font-size: 2.5rem;
color: #4527A0;
text-align: center;
margin-bottom: 1rem;
}
.sub-header {
font-size: 1.5rem;
color: #5E35B1;
margin-bottom: 1rem;
}
.response-container {
background-color: #f0f2f6;
border-radius: 10px;
padding: 20px;
margin-top: 20px;
}
.stButton>button {
background-color: #673AB7;
color: white;
}
.stButton>button:hover {
background-color: #5E35B1;
color: white;
}
.footer {
text-align: center;
margin-top: 3rem;
color: #9575CD;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# Page title and tagline, rendered as raw HTML so the .main-header and
# .sub-header CSS classes apply.
for _header_html in (
    "<h1 class='main-header'>Network Query Generator</h1>",
    "<h2 class='sub-header'>Ask questions and get specialized network related queries</h2>",
):
    st.markdown(_header_html, unsafe_allow_html=True)
# Sidebar: app description, generation controls, and model metadata.
with st.sidebar:
    st.title("About")
    st.info("This app uses a fine-tuned CodeT5 model to generate specialized queries from natural language questions.")

    st.title("Model Settings")
    max_length = st.slider(
        "Maximum output length", 32, 256, 128,
        help="Upper bound on the length of the generated query, in tokens.",
    )
    num_beams = st.slider(
        "Number of beams", 1, 10, 4,
        help="Beam-search width. 1 means plain greedy decoding.",
    )
    # Temperature only influences sampling-based decoding; the generation
    # call in this app uses greedy/beam search, so say so instead of
    # presenting a control that silently does nothing.
    temperature = st.slider(
        "Temperature", 0.0, 1.0, 0.7,
        help="Only affects sampling-based decoding; greedy and beam search ignore it.",
    )

    st.title("Model Info")
    st.markdown("**Base model:** Salesforce/codet5-small")
    st.markdown("**Fine-tuned on:** Custom dataset")

# Location of the fine-tuned checkpoint, relative to the working directory.
MODEL_PATH = "finetuned_codet5_small_01.pkl"
def _select_device():
    """Return the CUDA device when available, otherwise the CPU."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _prepare_for_inference(model, device):
    """Move *model* to *device*, put it in eval mode, and return it."""
    model = model.to(device)
    model.eval()
    return model


# Function to load the model
@st.cache_resource
def load_model(file_path):
    """Load the CodeT5 tokenizer and the fine-tuned model stored at *file_path*.

    Two loading strategies are tried in order:

    1. ``torch.load`` the file as a fully pickled model object.
    2. Build the base ``Salesforce/codet5-small`` model and load the file into
       it as a state dict (unwrapping it first if the file holds a module).

    Args:
        file_path: Path to the checkpoint file.

    Returns:
        ``(tokenizer, model, device, error)``. On success ``error`` is ``None``.
        If both strategies fail, the first three items are ``None`` and
        ``error`` describes why each strategy failed. Errors while fetching the
        tokenizer are not caught and propagate to the caller.

    Note:
        The result, including a failure tuple, is presumably cached by
        ``st.cache_resource`` like any return value; clear the cache or
        restart after fixing the checkpoint.
    """
    model_name = "Salesforce/codet5-small"
    tokenizer = RobertaTokenizer.from_pretrained(model_name)
    device = _select_device()

    # Strategy 1: the checkpoint is a whole pickled model.
    # SECURITY: weights_only=False executes arbitrary pickle code; only load
    # checkpoints from a trusted source. The safe-globals allow-list is not
    # consulted when weights_only=False, so no allow-list is set here.
    try:
        model = torch.load(file_path, map_location=device, weights_only=False)
        return tokenizer, _prepare_for_inference(model, device), device, None
    except Exception as exc:
        # Keep the reason; it is reported if the fallback fails too.
        full_model_error = exc

    # Strategy 2: the checkpoint is (or wraps) a state dict for the base model.
    try:
        base_model = T5ForConditionalGeneration.from_pretrained(model_name)
        # The allow-list takes class objects, not dotted-name strings.
        with safe_globals([T5ForConditionalGeneration]):
            state_dict = torch.load(file_path, map_location="cpu")
        # If the loaded object is a model, extract just its state dict.
        if hasattr(state_dict, "state_dict"):
            state_dict = state_dict.state_dict()
        base_model.load_state_dict(state_dict)
        return tokenizer, _prepare_for_inference(base_model, device), device, None
    except Exception as state_dict_error:
        return None, None, None, (
            f"Error loading model: {state_dict_error} "
            f"(loading it as a full pickled model also failed: {full_model_error})"
        )
# Function to generate query
def generate_query(question, tokenizer, model, device, max_length=128, num_beams=4,
                   temperature=0.7, *, do_sample=False, max_input_length=128):
    """Generate a query for *question* with a seq2seq model's ``generate``.

    Args:
        question: Natural-language question to translate.
        tokenizer: Tokenizer called on the question (returning ``input_ids``
            and ``attention_mask`` tensors) and used to decode the output ids.
        model: Model exposing ``generate(**kwargs)``.
        device: Device the input tensors are moved to; should match *model*.
        max_length: Maximum length of the generated sequence.
        num_beams: Beam-search width; ``1`` means no beam search.
        temperature: Sampling temperature. Forwarded only when sampling is
            enabled, because non-sampling decoding ignores it.
        do_sample: Enable sampling. Ignored when *temperature* <= 0, which
            sampling does not accept; decoding then stays greedy/beam search.
        max_input_length: Token limit the question is truncated to.

    Returns:
        The decoded text of the first returned sequence, without special tokens.
    """
    inputs = tokenizer(
        question,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_input_length,
    )
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    generation_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "max_length": max_length,
        "num_beams": num_beams,
    }
    # early_stopping is a beam-search option; passing it with a single beam
    # has no effect and only triggers a warning.
    if num_beams > 1:
        generation_kwargs["early_stopping"] = True
    # temperature only matters when sampling, and sampling needs it > 0.
    if do_sample and temperature > 0:
        generation_kwargs["do_sample"] = True
        generation_kwargs["temperature"] = temperature

    with torch.no_grad():
        generated_ids = model.generate(**generation_kwargs)
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)
# Load the model at startup. load_model is wrapped in st.cache_resource, so
# script reruns reuse the already-loaded objects.
with st.spinner("Loading model... (this may take a moment)"):
    tokenizer, model, device, error_message = load_model(MODEL_PATH)

if model is not None:
    st.sidebar.success("Model loaded successfully!")
else:
    # load_model's message already states that loading failed; show it as-is
    # rather than stacking a second "Failed to load model:" prefix on it.
    st.sidebar.error(error_message or "Failed to load model.")
# Main area: free-text question input.
question = st.text_area(
    "Enter your question here:",
    height=100,
    placeholder="Example: Show the total data transferred, grouped by user department and VPN type, excluding 'Guest' users, for VPN sessions that lasted longer than 2 hours",
)

# Centre the button by placing it in the middle of three equal-width columns.
_left_column, middle_column, _right_column = st.columns([1, 1, 1])
with middle_column:
    generate_button = st.button("Generate Query", use_container_width=True)
# Display generation result
if generate_button:
    if not (question or "").strip():
        st.warning("Please enter a question before generating a query.")
    elif model is None or tokenizer is None:
        st.error("Model could not be loaded. Please check if the model file exists at the correct path.")
    else:
        response = None
        with st.spinner("Generating response..."):
            # Add a slight delay for user experience
            time.sleep(0.5)
            try:
                response = generate_query(
                    question,
                    tokenizer,
                    model,
                    device,
                    max_length=max_length,
                    num_beams=num_beams,
                    temperature=temperature,
                )
            except Exception as exc:  # UI boundary: report instead of crashing the page
                st.error(f"Query generation failed: {exc}")
        if response is not None:
            # Each st.markdown call renders as its own element, so an opening
            # <div> in one call cannot wrap elements emitted by later calls.
            # Group the heading and the code block with a container instead.
            with st.container():
                st.markdown("### Generated Query:")
                st.code(response, language="sql")
# Example questions, shown in a collapsible section as prompts users can copy.
with st.expander("Example Questions"):
    example_questions = [
        "Show me the current configuration of the router.",
        "Get the total number of ['firewall', 'router', 'switch', 'server', 'access_point'] with high CPU usage.",
        "Get the total bandwidth usage of access_point FW1.",
        "Get the total uptime of all ['firewall', 'router', 'switch', 'server', 'access_point'].",
        "Get the total number of ['firewall', 'router', 'switch', 'server', 'access_point'].",
        "Find the top 5 devices with the highest data usage, grouped by region and filtering for data usage greater than 10GB.",
    ]
    for example in example_questions:
        st.write(example)
# Footer, styled by the .footer CSS class.
# (The trailing " |" copy/paste artifact after this call made it a SyntaxError.)
st.markdown("<p class='footer'>Powered by CodeT5 - Fine-tuned for specialized queries</p>", unsafe_allow_html=True)