|
import streamlit as st |
|
import torch |
|
from transformers import RobertaTokenizer, T5ForConditionalGeneration |
|
import pickle |
|
import os |
|
import time |
|
from torch.serialization import safe_globals, add_safe_globals |
|
|
|
# Allow-list the T5 model class for torch.load(..., weights_only=True).
# torch expects the class object itself (or an (obj, "module.Name") tuple);
# a dotted-name string is not a valid allow-list entry.
# NOTE(review): a fully pickled model also references its submodule/config
# classes, so this entry alone is likely not enough to load a whole pickled
# model with weights_only=True.
add_safe_globals([T5ForConditionalGeneration])
|
|
|
|
|
# Browser-tab title and icon, plus the wide page layout.
st.set_page_config(layout="wide", page_icon="🤖", page_title="CodeT5 Query Generator")
|
|
|
|
|
# Custom CSS injected once into the page: header colours, the response box,
# purple buttons, and the footer.
_CUSTOM_CSS = """
<style>
.main-header {
    font-size: 2.5rem;
    color: #4527A0;
    text-align: center;
    margin-bottom: 1rem;
}
.sub-header {
    font-size: 1.5rem;
    color: #5E35B1;
    margin-bottom: 1rem;
}
.response-container {
    background-color: #f0f2f6;
    border-radius: 10px;
    padding: 20px;
    margin-top: 20px;
}
.stButton>button {
    background-color: #673AB7;
    color: white;
}
.stButton>button:hover {
    background-color: #5E35B1;
    color: white;
}
.footer {
    text-align: center;
    margin-top: 3rem;
    color: #9575CD;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
|
|
|
|
|
# Page title and subtitle, rendered as raw HTML so the .main-header and
# .sub-header CSS classes apply.
_MAIN_HEADER_HTML = "<h1 class='main-header'>Network Query Generator</h1>"
_SUB_HEADER_HTML = "<h2 class='sub-header'>Ask questions and get specialized network related queries</h2>"
st.markdown(_MAIN_HEADER_HTML, unsafe_allow_html=True)
st.markdown(_SUB_HEADER_HTML, unsafe_allow_html=True)
|
|
|
|
|
# Sidebar: app description, generation controls, and static model metadata.
st.sidebar.title("About")
st.sidebar.info("This app uses a fine-tuned CodeT5 model to generate specialized queries from natural language questions.")

st.sidebar.title("Model Settings")
max_length = st.sidebar.slider("Maximum output length", min_value=32, max_value=256, value=128)
num_beams = st.sidebar.slider("Number of beams", min_value=1, max_value=10, value=4)
temperature = st.sidebar.slider("Temperature", min_value=0.0, max_value=1.0, value=0.7)

st.sidebar.title("Model Info")
for _info_line in ("**Base model:** Salesforce/codet5-small", "**Fine-tuned on:** Custom dataset"):
    st.sidebar.markdown(_info_line)

# Checkpoint path handed to load_model().
MODEL_PATH = "finetuned_codet5_small_01.pkl"
|
|
|
|
|
@st.cache_resource
def load_model(file_path):
    """Load the CodeT5 tokenizer and the fine-tuned model checkpoint.

    The checkpoint may be either a fully pickled model object or a state dict
    (or any object exposing ``.state_dict()``) for the
    ``Salesforce/codet5-small`` architecture. The full-model format is tried
    first; if anything in that path fails, the state-dict path is tried.

    Args:
        file_path: Path to the checkpoint written with ``torch.save``.

    Returns:
        ``(tokenizer, model, device, error)``. On success ``error`` is None and
        the model has been moved to ``device`` and put in eval mode. On failure
        the first three items are None and ``error`` reports why BOTH load
        attempts failed.

    Raises:
        Whatever ``RobertaTokenizer.from_pretrained`` raises: the tokenizer is
        loaded before, and outside, the error-tuple handling.

    Note:
        ``st.cache_resource`` caches the returned tuple, including the error
        tuple, so a failed load is not retried until the cache is cleared.
    """
    model_name = "Salesforce/codet5-small"
    tokenizer = RobertaTokenizer.from_pretrained(model_name)
    # Choose the device once so loading and placement always agree.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Attempt 1: the file holds a fully pickled model object.
    # SECURITY: weights_only=False lets the pickle execute arbitrary code. This
    # is only acceptable because file_path is a local, trusted artifact; never
    # point it at user-supplied files.
    try:
        model = torch.load(file_path, map_location=device, weights_only=False)
        # Fails (AttributeError) when the file is a plain state dict instead.
        model = model.to(device)
        model.eval()
        return tokenizer, model, device, None
    except Exception as exc:
        # Keep the error: `exc` itself is unbound when the except block ends.
        full_model_error = exc

    # Attempt 2: build the base architecture and load saved weights into it.
    try:
        base_model = T5ForConditionalGeneration.from_pretrained(model_name)
        # safe_globals needs the class object, not its dotted-name string.
        # weights_only is left at torch's default here.
        with safe_globals([T5ForConditionalGeneration]):
            state_dict = torch.load(file_path, map_location="cpu")
        # A whole module was saved: extract its parameter dict.
        if hasattr(state_dict, "state_dict"):
            state_dict = state_dict.state_dict()
        base_model.load_state_dict(state_dict)
        base_model = base_model.to(device)
        base_model.eval()
        return tokenizer, base_model, device, None
    except Exception as state_dict_error:
        # Report both failures; the first is often the real root cause.
        return None, None, None, (
            f"Error loading model: {state_dict_error} "
            f"(full-model load failed first with: {full_model_error})"
        )
|
|
|
|
|
def generate_query(question, tokenizer, model, device, max_length=128, num_beams=4,
                   temperature=0.7, *, do_sample=False, max_input_length=128):
    """Generate a query string for a natural-language question.

    Args:
        question: The user's question text.
        tokenizer: Tokenizer used to encode ``question`` and decode the output.
        model: Model exposing a Hugging Face style ``generate`` method.
        device: Torch device the model lives on; encoded inputs are moved here.
        max_length: Maximum length, in tokens, of the generated output.
        num_beams: Beam width; 1 means greedy decoding.
        temperature: Sampling temperature. Only forwarded when ``do_sample`` is
            True and ``temperature > 0``. Beam search and greedy decoding ignore
            it, and passing it there only triggers a transformers warning.
        do_sample: Use sampling instead of deterministic beam/greedy search.
            With ``temperature <= 0`` decoding stays deterministic.
        max_input_length: Truncation length, in tokens, for the encoded question.

    Returns:
        The decoded text of the first (highest-ranked) generated sequence,
        with special tokens removed.
    """
    encoded = tokenizer(question, return_tensors="pt", padding=True,
                        truncation=True, max_length=max_input_length)
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    generation_kwargs = {
        "max_length": max_length,
        "num_beams": num_beams,
        # early_stopping only affects beam search; False avoids a warning for
        # greedy decoding without changing its output.
        "early_stopping": num_beams > 1,
    }
    if do_sample and temperature > 0:
        generation_kwargs["do_sample"] = True
        generation_kwargs["temperature"] = temperature

    with torch.no_grad():
        generated_ids = model.generate(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            **generation_kwargs,
        )

    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
|
|
|
|
# Load (or fetch from the resource cache) the tokenizer/model pair once per run.
with st.spinner("Loading model... (this may take a moment)"):
    tokenizer, model, device, error_message = load_model(MODEL_PATH)

# Report the load outcome in the sidebar.
if model is None:
    st.sidebar.error(f"Failed to load model: {error_message}")
else:
    st.sidebar.success("Model loaded successfully!")
|
|
|
|
|
# Free-text question from the user; the placeholder shows a sample request.
_QUESTION_PLACEHOLDER = (
    "Example: Show the total data transferred, grouped by user department and VPN type, "
    "excluding 'Guest' users, for VPN sessions that lasted longer than 2 hours"
)
question = st.text_area(
    "Enter your question here:",
    height=100,
    placeholder=_QUESTION_PLACEHOLDER,
)
|
|
|
|
|
# Three equal-width columns; the button sits in the middle one so it is centred.
col1, col2, col3 = st.columns(3)
generate_button = col2.button("Generate Query", use_container_width=True)
|
|
|
|
|
# Handle a click on "Generate Query": validate input, run the model, and show
# the result (or an error) on the page.
if generate_button:
    if not question.strip():
        st.warning("Please enter a question first.")
    elif model is None or tokenizer is None:
        st.error("Model could not be loaded. Please check if the model file exists at the correct path.")
    else:
        response = None
        with st.spinner("Generating response..."):
            try:
                response = generate_query(
                    question,
                    tokenizer,
                    model,
                    device,
                    max_length=max_length,
                    num_beams=num_beams,
                    temperature=temperature,
                )
            except Exception as exc:
                # UI boundary: show the failure instead of crashing the page.
                st.error(f"Failed to generate query: {exc}")

        if response is not None:
            # Each st.markdown call renders as its own element, so an opening
            # <div> in one call cannot wrap output from later calls. Group the
            # result with a real container instead of split HTML tags.
            with st.container():
                st.markdown("### Generated Query:")
                st.code(response, language="sql")
|
|
|
|
|
# Collapsible list of sample questions shown to the user.
with st.expander("Example Questions"):
    example_questions = [
        "Show me the current configuration of the router.",
        "Get the total number of ['firewall', 'router', 'switch', 'server', 'access_point'] with high CPU usage.",
        "Get the total bandwidth usage of access_point FW1.",
        "Get the total uptime of all ['firewall', 'router', 'switch', 'server', 'access_point'].",
        "Get the total number of ['firewall', 'router', 'switch', 'server', 'access_point'].",
        "Find the top 5 devices with the highest data usage, grouped by region and filtering for data usage greater than 10GB."
    ]

    # Iterate the items directly rather than indexing via range(len(...)).
    for example in example_questions:
        st.write(example)
|
|
|
|
|
# Page footer, styled by the .footer CSS class.
_FOOTER_HTML = "<p class='footer'>Powered by CodeT5 - Fine-tuned for specialized queries</p>"
st.markdown(_FOOTER_HTML, unsafe_allow_html=True)