# spam-LSTM-model / app.py
# k-code's picture
# init
# 87476a9
import re
import gradio as gr
from datasets import load_dataset
import torch
from torch.utils.data import random_split
from collections import Counter
import torch.nn as nn
class LSTMClassifier(nn.Module):
    """Stacked bidirectional LSTM that maps a token-id sequence to two logits.

    Pipeline: embedding -> 2-layer BiLSTM -> dropout -> linear + ReLU ->
    dropout -> linear with two outputs.

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_dim: Width of each token embedding.
        hidden_dim: Hidden size of each LSTM direction.
    """

    def __init__(self, vocab_size, embedding_dim=200, hidden_dim=256):
        super().__init__()
        # Row 0 of the embedding table is reserved for padding tokens.
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim,
            padding_idx=0,
        )
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=2,
            dropout=0.3,
            bidirectional=True,
            batch_first=True,
        )
        # One dropout module, applied twice in forward().
        self.dropout = nn.Dropout(p=0.4)
        # The first dense layer takes both directions' final states side by side.
        self.fc1 = nn.Linear(in_features=2 * hidden_dim, out_features=hidden_dim)
        self.fc2 = nn.Linear(in_features=hidden_dim, out_features=2)

    def forward(self, x):
        """Return raw class logits for a batch of token-id sequences.

        Args:
            x: Integer token ids; the LSTM is batch-first, so the layout is
                (batch, sequence).

        Returns:
            Tensor with one row per batch element and two columns (unnormalised
            scores, no softmax applied).
        """
        _, (final_states, _) = self.lstm(self.embedding(x))
        # The last two entries are the top layer's forward and backward states.
        features = torch.cat([final_states[-2], final_states[-1]], dim=1)
        features = self.dropout(features)
        features = self.dropout(torch.relu(self.fc1(features)))
        return self.fc2(features)
def create_vocabulary(ds, max_words=10000):
    """Build a word -> index mapping from the ``"sms"`` field of each example.

    Text is lower-cased, every character that is neither a word character nor
    whitespace is removed, and the result is split on whitespace. Indices 0
    and 1 are reserved for ``<PAD>`` and ``<UNK>``; the most frequent words
    fill the remaining ``max_words - 2`` slots in descending frequency order.

    Args:
        ds: Iterable of mappings that each contain an ``"sms"`` string.
        max_words: Upper bound on the vocabulary size, special tokens included.

    Returns:
        Dict mapping each kept token to its integer index.
    """
    frequencies = Counter()
    for record in ds:
        cleaned = re.sub(r"[^\w\s]", "", record["sms"].lower())
        frequencies.update(cleaned.split())

    word2idx = {"<PAD>": 0, "<UNK>": 1}
    ranked = frequencies.most_common(max_words - 2)
    # Real words are numbered consecutively after the two special tokens.
    for position, (token, _count) in enumerate(ranked, start=len(word2idx)):
        word2idx[token] = position
    return word2idx
def create_splits(ds):
    """Split ``ds['train']`` into reproducible 80/20 train and test subsets.

    The shuffle uses a dedicated generator seeded with 42, so repeated calls
    on the same data produce the same partition.

    Args:
        ds: Mapping with a ``'train'`` entry that supports ``len()`` and
            integer indexing.

    Returns:
        Tuple ``(train_dataset, test_dataset)`` of subsets.
    """
    full_dataset = ds['train']
    total = len(full_dataset)
    n_train = int(0.8 * total)
    # The test part takes whatever remains, so the two sizes always sum to total.
    lengths = [n_train, total - n_train]

    rng = torch.Generator()
    rng.manual_seed(42)

    train_dataset, test_dataset = random_split(full_dataset, lengths, generator=rng)
    return train_dataset, test_dataset
# Rebuild the vocabulary from the seeded 80/20 split of the SMS spam dataset.
# NOTE(review): the word -> index mapping presumably has to match the one used
# when best_model.pth was trained; confirm against the training script.
ds = load_dataset("ucirvine/sms_spam")
train_dataset, test_dataset = create_splits(ds)
vocab = create_vocabulary(train_dataset)

# Choose the device before loading, so the checkpoint can be mapped onto it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Recreate the model architecture (embedding_dim=100, default hidden_dim).
model = LSTMClassifier(len(vocab), 100)

# Load the saved state dict. map_location lets a checkpoint saved from a GPU
# load on a CPU-only host instead of failing during deserialisation.
model.load_state_dict(torch.load('best_model.pth', map_location=device))
model = model.to(device)
def predict_text(model, text, word2idx, device, max_length=50):
    """Classify a single message as spam or ham.

    The text is normalised the same way ``create_vocabulary`` builds its
    keys: lower-cased, with every character that is neither a word character
    nor whitespace removed, then split on whitespace. Without the punctuation
    step a token such as ``"free!!!"`` could never match a vocabulary entry
    and would always be mapped to ``<UNK>``.

    Args:
        model: Callable module with ``eval()`` that returns two logits per
            batch row for a batch of token ids.
        text: Raw message text.
        word2idx: Token -> index mapping containing ``<PAD>`` and ``<UNK>``.
        device: Device the input tensor is moved to before the forward pass.
        max_length: Fixed sequence length; shorter inputs are right-padded
            with ``<PAD>`` and longer ones are truncated.

    Returns:
        Dict with ``'prediction'`` (``'spam'`` when class 1 scores highest,
        otherwise ``'ham'``) and ``'confidence'`` (softmax probability of the
        predicted class as a float).
    """
    # Evaluation mode disables dropout so predictions are deterministic.
    model.eval()

    cleaned = re.sub(r"[^\w\s]", "", text.lower())
    indices = [
        word2idx.get(token, word2idx['<UNK>']) for token in cleaned.split()
    ][:max_length]
    if len(indices) < max_length:
        indices += [word2idx['<PAD>']] * (max_length - len(indices))

    with torch.no_grad():
        # unsqueeze(0) adds the batch dimension the model expects.
        input_tensor = torch.tensor(indices, dtype=torch.long).unsqueeze(0).to(device)
        outputs = model(input_tensor)
        probabilities = torch.softmax(outputs, dim=1)
        predicted_class = int(torch.argmax(outputs, dim=1).item())

    return {
        'prediction': 'spam' if predicted_class == 1 else 'ham',
        'confidence': probabilities[0, predicted_class].item(),
    }
# Input and output widgets for the demo page.
message_box = gr.Textbox(lines=2, placeholder="Enter your text here...")
result_box = gr.Textbox()

# The callback binds the module-level model, vocabulary and device, so the UI
# only has to supply the message text.
interface = gr.Interface(
    fn=lambda text: predict_text(model, text, vocab, device),
    inputs=message_box,
    outputs=result_box,
    title="SMS Spam Classifier",
    description="Enter a text message to predict if it's spam or ham.",
)

# share=True asks Gradio for a public share link in addition to the local server.
interface.launch(share=True)