import re
from collections import Counter

import gradio as gr
import torch
import torch.nn as nn
from datasets import load_dataset
from torch.utils.data import random_split


class LSTMClassifier(nn.Module):
    """Two-layer bidirectional LSTM with a small classification head."""

    def __init__(self, vocab_size, embedding_dim=200, hidden_dim=256):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
            dropout=0.3,
        )
        # Dropout layer
        self.dropout = nn.Dropout(0.4)
        # Additional dense layers
        self.fc1 = nn.Linear(hidden_dim * 2, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 2)

    def forward(self, x):
        embedded = self.embedding(x)
        _, (hidden, _) = self.lstm(embedded)
        # Concatenate the final forward and backward hidden states
        hidden = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)
        hidden = self.dropout(hidden)
        # Additional layer with ReLU activation
        hidden = torch.relu(self.fc1(hidden))
        hidden = self.dropout(hidden)
        # Final classification layer (2 classes: ham / spam)
        return self.fc2(hidden)


def create_vocabulary(ds, max_words=10000):
    # Reserve index 0 for padding and index 1 for out-of-vocabulary words
    word2idx = {"<PAD>": 0, "<UNK>": 1}
    words = []
    for example in ds:
        text = example["sms"].lower()
        text = re.sub(r"[^\w\s]", "", text)  # strip punctuation
        words.extend(text.split())
    word_counts = Counter(words)
    # Keep the most frequent words, leaving room for the two special tokens
    common_words = word_counts.most_common(max_words - 2)
    for word, _ in common_words:
        word2idx[word] = len(word2idx)
    return word2idx


def create_splits(ds):
    # 80/20 train/test split with a fixed seed for reproducibility
    full_dataset = ds["train"]
    train_size = int(0.8 * len(full_dataset))
    test_size = len(full_dataset) - train_size
    train_dataset, test_dataset = random_split(
        full_dataset,
        [train_size, test_size],
        generator=torch.Generator().manual_seed(42),
    )
    return train_dataset, test_dataset


ds = load_dataset("ucirvine/sms_spam")
train_dataset, test_dataset = create_splits(ds)
vocab = create_vocabulary(train_dataset)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# First recreate the model architecture. The embedding size passed here (100)
# overrides the class default and must match the dimensions used when
# best_model.pth was trained, or loading the state dict will fail.
model = LSTMClassifier(len(vocab), 100)

# Load the saved state dict onto whichever device is available
model.load_state_dict(torch.load("best_model.pth", map_location=device))
model = model.to(device)


def predict_text(model, text, word2idx, device, max_length=50):
    # Set model to evaluation mode (disables dropout)
    model.eval()

    # Preprocess the text the same way as during training:
    # lowercase, strip punctuation, split on whitespace
    text = text.lower()
    text = re.sub(r"[^\w\s]", "", text)
    words = text.split()

    # Convert words to indices, mapping unknown words to <UNK>
    indices = [word2idx.get(word, word2idx["<UNK>"]) for word in words]

    # Pad or truncate to a fixed length
    if len(indices) < max_length:
        indices += [word2idx["<PAD>"]] * (max_length - len(indices))
    else:
        indices = indices[:max_length]

    # Run a forward pass without tracking gradients
    with torch.no_grad():
        input_tensor = torch.tensor(indices).unsqueeze(0).to(device)  # add batch dimension
        outputs = model(input_tensor)
        probabilities = torch.softmax(outputs, dim=1)
        prediction = torch.argmax(outputs, dim=1)

    return {
        "prediction": "spam" if prediction.item() == 1 else "ham",
        "confidence": probabilities[0][prediction].item(),
    }


interface = gr.Interface(
    fn=lambda text: predict_text(model, text, vocab, device),
    inputs=gr.Textbox(lines=2, placeholder="Enter your text here..."),
    outputs=gr.Textbox(),
    title="SMS Spam Classifier",
    description="Enter a text message to predict if it's spam or ham.",
)

interface.launch(share=True)
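
# Optional sanity check: predict_text can be called directly, without the
# Gradio UI, to confirm the checkpoint loaded correctly. A minimal sketch;
# the sample messages below are made up for illustration and not drawn from
# the dataset. Run these lines before interface.launch(), since launch()
# blocks until the server is stopped.
#
#     print(predict_text(model, "Congratulations! You won a free prize, call now!", vocab, device))
#     print(predict_text(model, "Are we still meeting for lunch today?", vocab, device))
#
# Each call returns a dict such as {'prediction': 'spam', 'confidence': 0.97}.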