# File size: 3,803 Bytes
# 87476a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import re
import gradio as gr
from datasets import load_dataset
import torch
from torch.utils.data import random_split
from collections import Counter
import torch.nn as nn


class LSTMClassifier(nn.Module):
    """Two-class text classifier built on a stacked bidirectional LSTM.

    Pipeline: token embedding (index 0 is the padding index) -> 2-layer
    bidirectional LSTM (dropout 0.3 between layers) -> dropout 0.4 ->
    linear ``2 * hidden_dim -> hidden_dim`` with ReLU -> dropout 0.4 ->
    linear ``hidden_dim -> 2``.

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_dim: Size of each token embedding vector.
        hidden_dim: Hidden size of each LSTM direction.
    """

    def __init__(self, vocab_size, embedding_dim=200, hidden_dim=256):
        super().__init__()

        # Token lookup table; padding_idx=0 pins index 0 as the padding slot.
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim,
            padding_idx=0,
        )

        # Sequence encoder; batch_first=True means inputs are (batch, seq, feature).
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
            dropout=0.3,
        )

        # One dropout module, applied twice in forward().
        self.dropout = nn.Dropout(p=0.4)

        # Classification head: both directions concatenated -> hidden -> 2 logits.
        self.fc1 = nn.Linear(in_features=2 * hidden_dim, out_features=hidden_dim)
        self.fc2 = nn.Linear(in_features=hidden_dim, out_features=2)

    def forward(self, x):
        """Return unnormalised class scores, one row of 2 logits per batch item.

        Args:
            x: Tensor of token indices; batch dimension first
               (presumably shape ``(batch, seq_len)`` -- confirm with callers).
        """
        token_vectors = self.embedding(x)

        # Only the final hidden states are needed; per-step outputs and the
        # cell state are discarded.
        _, (final_states, _) = self.lstm(token_vectors)

        # The last two entries of the final hidden state belong to the top
        # layer (one per direction); join them along the feature axis.
        features = torch.cat((final_states[-2], final_states[-1]), dim=1)
        features = self.dropout(features)

        # Hidden dense layer with ReLU, followed by the same dropout module.
        features = self.dropout(torch.relu(self.fc1(features)))

        # Final projection to the two class logits.
        return self.fc2(features)


def create_vocabulary(ds, max_words=10000):
    """Build a word -> index mapping from the ``"sms"`` field of a dataset.

    Each message is lower-cased, stripped of every character that is neither
    a word character nor whitespace, and split on whitespace. The most
    frequent tokens receive consecutive indices after the two reserved
    entries ``"<PAD>"`` (0) and ``"<UNK>"`` (1).

    Args:
        ds: Iterable of examples, each indexable by the key ``"sms"``.
        max_words: Upper bound on the vocabulary size, reserved entries
            included.

    Returns:
        dict mapping each token to its integer index.
    """
    frequencies = Counter()
    for record in ds:
        normalized = re.sub(r"[^\w\s]", "", record["sms"].lower())
        frequencies.update(normalized.split())

    vocabulary = {"<PAD>": 0, "<UNK>": 1}
    # Two slots are already taken by the reserved tokens above.
    ranked = frequencies.most_common(max_words - 2)
    for index, (token, _count) in enumerate(ranked, start=len(vocabulary)):
        vocabulary[token] = index

    return vocabulary


def create_splits(ds):
    """Split ``ds['train']`` into reproducible 80/20 train and test subsets.

    The train size is ``int(0.8 * len(ds['train']))``; the remainder goes to
    the test subset. The shuffle uses a generator seeded with 42, so repeated
    calls on the same data produce the same partition.

    Args:
        ds: Mapping whose ``'train'`` entry is the full dataset to split.

    Returns:
        ``(train_dataset, test_dataset)`` as produced by ``random_split``.
    """
    corpus = ds['train']
    total = len(corpus)
    n_train = int(0.8 * total)
    lengths = [n_train, total - n_train]

    # Fixed seed keeps the partition identical from run to run.
    rng = torch.Generator().manual_seed(42)

    train_dataset, test_dataset = random_split(corpus, lengths, generator=rng)
    return train_dataset, test_dataset


# Rebuild the vocabulary from the seeded 80/20 train split so that the token
# indices used here come from the same procedure as in this file's helpers.
ds = load_dataset("ucirvine/sms_spam")
train_dataset, test_dataset = create_splits(ds)
vocab = create_vocabulary(train_dataset)

# Choose the device first so the checkpoint can be remapped onto it on load.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# First recreate the model architecture (embedding_dim=100; presumably the
# value used when the checkpoint was trained -- must match its shapes).
model = LSTMClassifier(len(vocab), 100)
# Load the saved state dict. map_location lets a checkpoint that was saved
# from a GPU be loaded on a CPU-only host instead of raising at load time.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
model.load_state_dict(torch.load('best_model.pth', map_location=device))

model = model.to(device)


def predict_text(model, text, word2idx, device, max_length=50):
    """Classify one message as spam or ham.

    The text is normalised the same way ``create_vocabulary`` normalises its
    input (lower-cased, every character that is neither a word character nor
    whitespace removed, split on whitespace). Without the punctuation
    stripping, a token such as ``"free!"`` could never match a vocabulary
    entry and would always fall back to ``<UNK>``.

    Args:
        model: Callable returning a ``(1, 2)`` tensor of class logits for a
            ``(1, max_length)`` tensor of token indices; must provide
            ``eval()``.
        text: Raw message text.
        word2idx: Token -> index mapping containing ``'<PAD>'`` and
            ``'<UNK>'``.
        device: Device on which the input tensor is created.
        max_length: Fixed sequence length; shorter inputs are right-padded
            with ``<PAD>``, longer ones are truncated.

    Returns:
        dict with ``'prediction'`` (``'spam'`` when class 1 has the highest
        score, otherwise ``'ham'``) and ``'confidence'`` (softmax probability
        of the predicted class, as a float).
    """
    # Set model to evaluation mode
    model.eval()

    # Preprocess the text with the same normalisation as the vocabulary builder.
    cleaned = re.sub(r"[^\w\s]", "", text.lower())

    # Convert words to indices, truncating to max_length.
    unk_index = word2idx['<UNK>']
    indices = [word2idx.get(token, unk_index) for token in cleaned.split()]
    indices = indices[:max_length]

    # Right-pad up to max_length (adds nothing when already long enough).
    indices += [word2idx['<PAD>']] * (max_length - len(indices))

    with torch.no_grad():
        # Explicit integer dtype; unsqueeze adds the batch dimension.
        input_tensor = torch.tensor(
            indices, dtype=torch.long, device=device).unsqueeze(0)
        outputs = model(input_tensor)
        probabilities = torch.softmax(outputs, dim=1)
        predicted_class = int(torch.argmax(outputs, dim=1).item())

    return {
        'prediction': 'spam' if predicted_class == 1 else 'ham',
        'confidence': probabilities[0, predicted_class].item()
    }


# Web UI: a two-line text input for the message and a text output that shows
# the result returned by predict_text.
message_input = gr.Textbox(lines=2, placeholder="Enter your text here...")
result_output = gr.Textbox()

interface = gr.Interface(
    # Bind the module-level model, vocabulary and device; the UI supplies
    # only the message text.
    fn=lambda text: predict_text(model, text, vocab, device),
    inputs=message_input,
    outputs=result_output,
    title="SMS Spam Classifier",
    description="Enter a text message to predict if it's spam or ham.",
)

# Start the app; share=True asks Gradio to also create a public share link.
interface.launch(share=True)