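"""Streamlit app that turns natural-language questions into specialized
network-related queries with a fine-tuned CodeT5 (Salesforce/codet5-small)
checkpoint.

Run with (assuming this file is saved as app.py):

    streamlit run app.py
"""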
import time

import streamlit as st
import torch
from torch.serialization import safe_globals, add_safe_globals
from transformers import RobertaTokenizer, T5ForConditionalGeneration

# Allow-list the fine-tuned model's class so torch.load can unpickle it.
# add_safe_globals expects the actual objects, not dotted-path strings; this
# allow-list is what the default weights_only=True loader (PyTorch >= 2.6)
# checks against.
add_safe_globals([T5ForConditionalGeneration])

# Set page configuration
st.set_page_config(
    page_title="CodeT5 Query Generator",
    page_icon="🤖",
    layout="wide",
)

# CSS styling
st.markdown("""
<style>
    .main-header {
        font-size: 2.5rem;
        color: #4527A0;
        text-align: center;
        margin-bottom: 1rem;
    }
    .sub-header {
        font-size: 1.5rem;
        color: #5E35B1;
        margin-bottom: 1rem;
    }
    .response-container {
        background-color: #f0f2f6;
        border-radius: 10px;
        padding: 20px;
        margin-top: 20px;
    }
    .stButton>button {
        background-color: #673AB7;
        color: white;
    }
    .stButton>button:hover {
        background-color: #5E35B1;
        color: white;
    }
    .footer {
        text-align: center;
        margin-top: 3rem;
        color: #9575CD;
    }
</style>
""", unsafe_allow_html=True)

# App header
st.markdown("<h1 class='main-header'>Network Query Generator</h1>", unsafe_allow_html=True)
st.markdown("<h2 class='sub-header'>Ask questions and get specialized network related queries</h2>", unsafe_allow_html=True)

# Sidebar for model information and settings
with st.sidebar:
    st.title("About")
    st.info("This app uses a fine-tuned CodeT5 model to generate specialized queries from natural language questions.")
    
    st.title("Model Settings")
    max_length = st.slider("Maximum output length", 32, 256, 128)
    num_beams = st.slider("Number of beams", 1, 10, 4)
    # Minimum 0.1: sampling (see generate_query) needs a strictly positive temperature.
    temperature = st.slider("Temperature", 0.1, 1.0, 0.7)
    
    st.title("Model Info")
    st.markdown("**Base model:** Salesforce/codet5-small")
    st.markdown("**Fine-tuned on:** Custom dataset")

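# Path to the fine-tuned checkpoint, expected next to this script; adjust if
# the file lives elsewhere.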
MODEL_PATH = "finetuned_codet5_small_01.pkl"

# Function to load the model
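# st.cache_resource keeps one tokenizer/model pair alive across Streamlit
# reruns, so the checkpoint is deserialized only once per server process.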
@st.cache_resource
def load_model(file_path):
    """Load the tokenizer and model using a safe approach"""
    model_name = "Salesforce/codet5-small"
    tokenizer = RobertaTokenizer.from_pretrained(model_name)
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    try:
        # First attempt: unpickle the checkpoint as a complete model object.
        # weights_only=False executes arbitrary pickle code, so this should
        # only be used with checkpoints from a trusted source.
        with safe_globals([T5ForConditionalGeneration]):
            model = torch.load(file_path, map_location=device, weights_only=False)

        model = model.to(device)
        model.eval()
        return tokenizer, model, device, None
    except Exception as direct_load_error:
        # Fall back to loading the checkpoint as a state dict into a fresh base model.
        try:
            # Initialize base model
            base_model = T5ForConditionalGeneration.from_pretrained(model_name)
            
            # Load the checkpoint contents; the allow-listed class lets the
            # default weights_only=True unpickler accept a pickled full model.
            with safe_globals([T5ForConditionalGeneration]):
                state_dict = torch.load(file_path, map_location="cpu")
            
            # If the loaded object is already a model, extract just the state dict
            if hasattr(state_dict, 'state_dict'):
                state_dict = state_dict.state_dict()
                
            # Load the state dict into the base model
            base_model.load_state_dict(state_dict)
            
            base_model = base_model.to(device)
            base_model.eval()
            return tokenizer, base_model, device, None
        except Exception as e2:
            return None, None, None, (
                f"Error loading model: {e2} "
                f"(direct load had failed with: {direct_load_error})"
            )

# Function to generate query
def generate_query(question, tokenizer, model, device, max_length=128, num_beams=4, temperature=0.7):
    """Generate a query based on the user's question"""
    inputs = tokenizer(question, return_tensors="pt", padding=True, truncation=True, max_length=128)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=max_length,
            num_beams=num_beams,
            # temperature only takes effect when sampling; plain beam search
            # ignores it (and recent transformers versions warn about it), so
            # sampling is enabled whenever temperature is off the neutral 1.0.
            do_sample=temperature != 1.0,
            temperature=temperature,
            early_stopping=True,
        )

    generated_query = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    return generated_query
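# Illustrative call (names assume a successfully loaded model; the output
# text depends entirely on the fine-tuned weights):
#   query = generate_query("List all firewalls with high CPU usage",
#                          tokenizer, model, device, num_beams=4)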

# Load the model at startup
with st.spinner("Loading model... (this may take a moment)"):
    tokenizer, model, device, error_message = load_model(MODEL_PATH)
    
    if model is not None:
        st.sidebar.success("Model loaded successfully!")
    else:
        st.sidebar.error(f"Failed to load model: {error_message}")

# Main app area
question = st.text_area("Enter your question here:", height=100, placeholder="Example: Show the total data transferred, grouped by user department and VPN type, excluding 'Guest' users, for VPN sessions that lasted longer than 2 hours")

# Center the Generate button by placing it in the middle of three equal columns
col1, col2, col3 = st.columns([1, 1, 1])
with col2:
    generate_button = st.button("Generate Query", use_container_width=True)

# Display generation result
if generate_button and question:
    if model is not None and tokenizer is not None:
        with st.spinner("Generating response..."):
            # Add a slight delay for user experience
            time.sleep(0.5)
            response = generate_query(
                question, 
                tokenizer, 
                model, 
                device, 
                max_length=max_length, 
                num_beams=num_beams,
                temperature=temperature
            )
        
        st.markdown("<div class='response-container'>", unsafe_allow_html=True)
        st.markdown("### Generated Query:")
        st.code(response, language="sql")
        st.markdown("</div>", unsafe_allow_html=True)
    else:
        st.error("Model could not be loaded. Please check if the model file exists at the correct path.")

# Example questions
with st.expander("Example Questions"):
    example_questions = [
        "Show me the current configuration of the router.",
        "Get the total number of ['firewall', 'router', 'switch', 'server', 'access_point'] with high CPU usage.",
        "Get the total bandwidth usage of access_point FW1.",
        "Get the total uptime of all ['firewall', 'router', 'switch', 'server', 'access_point'].",
        "Get the total number of ['firewall', 'router', 'switch', 'server', 'access_point'].",
        "Find the top 5 devices with the highest data usage, grouped by region and filtering for data usage greater than 10GB."
    ]
    
    for example in example_questions:
        st.write(example)

# Footer
st.markdown("<p class='footer'>Powered by CodeT5 - Fine-tuned for specialized queries</p>", unsafe_allow_html=True)