import re
from collections import Counter

import gradio as gr
import torch
import torch.nn as nn
from datasets import load_dataset
from torch.utils.data import random_split

class LSTMClassifier(nn.Module):
    """Two-layer bidirectional LSTM with a small feed-forward head for binary (ham/spam) classification."""

    def __init__(self, vocab_size, embedding_dim=200, hidden_dim=256):
        super(LSTMClassifier, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
            dropout=0.3,
        )
        # Dropout layer
        self.dropout = nn.Dropout(0.4)
        # Additional dense layers
        self.fc1 = nn.Linear(hidden_dim * 2, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 2)

    def forward(self, x):
        embedded = self.embedding(x)
        lstm_out, (hidden, cell) = self.lstm(embedded)
        # Concatenate the last forward and backward hidden states
        hidden = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)
        hidden = self.dropout(hidden)
        # Additional layer with ReLU activation
        hidden = torch.relu(self.fc1(hidden))
        hidden = self.dropout(hidden)
        # Final classification layer
        out = self.fc2(hidden)
        return out

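# Quick shape sanity check (illustrative sketch only, not part of the app): with the
# default sizes above, a dummy batch of 4 padded sequences of length 50 should
# produce logits of shape (4, 2).
#
#   dummy = torch.zeros(4, 50, dtype=torch.long)
#   print(LSTMClassifier(vocab_size=10000)(dummy).shape)  # torch.Size([4, 2])
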
def create_vocabulary(ds, max_words=10000):
    # Reserve index 0 for padding and index 1 for out-of-vocabulary words
    word2idx = {
        "<PAD>": 0,
        "<UNK>": 1,
    }
    words = []
    for example in ds:
        text = example["sms"]
        text = text.lower()
        text = re.sub(r"[^\w\s]", "", text)  # strip punctuation
        words.extend(text.split())
    word_counts = Counter(words)
    # Keep the most frequent words, leaving room for the two special tokens
    common_words = word_counts.most_common(max_words - 2)
    for word, _ in common_words:
        word2idx[word] = len(word2idx)
    return word2idx

def create_splits(ds):
    # 80/20 split with a fixed seed for reproducibility
    full_dataset = ds["train"]
    train_size = int(0.8 * len(full_dataset))
    test_size = len(full_dataset) - train_size
    train_dataset, test_dataset = random_split(
        full_dataset,
        [train_size, test_size],
        generator=torch.Generator().manual_seed(42),
    )
    return train_dataset, test_dataset

ds = load_dataset("ucirvine/sms_spam") | |
train_dataset, test_dataset = create_splits(ds) | |
vocab = create_vocabulary(train_dataset) | |
# First recreate the model architecture | |
model = LSTMClassifier(len(vocab), 100) | |
# Load the saved state dict | |
model.load_state_dict(torch.load('best_model.pth')) | |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
model = model.to(device) | |
def predict_text(model, text, word2idx, device, max_length=50):
    # Set model to evaluation mode
    model.eval()
    # Preprocess the text the same way as during training
    text = text.lower()
    text = re.sub(r"[^\w\s]", "", text)  # strip punctuation, as in create_vocabulary
    words = text.split()
    # Convert words to indices, mapping unknown words to <UNK>
    indices = [word2idx.get(word, word2idx["<UNK>"]) for word in words]
    # Pad or truncate to a fixed length
    if len(indices) < max_length:
        indices += [word2idx["<PAD>"]] * (max_length - len(indices))
    else:
        indices = indices[:max_length]
    # Convert to tensor and run the model without tracking gradients
    with torch.no_grad():
        input_tensor = torch.tensor(indices).unsqueeze(0).to(device)  # add batch dimension
        outputs = model(input_tensor)
        probabilities = torch.softmax(outputs, dim=1)
        prediction = torch.argmax(outputs, dim=1)
    return {
        "prediction": "spam" if prediction.item() == 1 else "ham",
        "confidence": probabilities[0][prediction].item(),
    }

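# Example call outside the Gradio UI (illustrative only; the message below is made
# up, and the actual output depends on the weights in best_model.pth):
#
#   print(predict_text(model, "Free entry in a weekly prize draw!", vocab, device))
#   # -> {'prediction': 'spam' or 'ham', 'confidence': <float between 0 and 1>}
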
interface = gr.Interface(
    fn=lambda text: predict_text(model, text, vocab, device),
    inputs=gr.Textbox(lines=2, placeholder="Enter your text here..."),
    outputs=gr.Textbox(),
    title="SMS Spam Classifier",
    description="Enter a text message to predict if it's spam or ham.",
)
interface.launch(share=True)