# ====================================
# 1) Install Dependencies (One-Time)
# ====================================
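# Run once per environment (e.g., in a Colab cell); versions are left
# unpinned here, so pin them if you need reproducibility.
# !pip install torch transformers gradio matplotlib pillow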
# ====================================
# 2) Imports
# ====================================
import torch
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from io import BytesIO
from PIL import Image  # Gradio's Image component accepts PIL images directly

# We'll use a "tiny" BERT model to reduce loading/inference time:
from transformers import AutoTokenizer, AutoModel

# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# ====================================
# 3) Load Tiny BERT Model + Tokenizer
# ====================================
# This "prajjwal1/bert-tiny" model is just 2 Transformer layers
# (and ~4 million parameters), so it loads and runs faster.
tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-tiny")
model = AutoModel.from_pretrained(
    "prajjwal1/bert-tiny",
    output_attentions=True  # so we can visualize attention
).to(device)
model.eval()
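
# Optional smoke test (a minimal sketch, not part of the original demo):
# Tiny BERT should return one attention tensor per layer.
_enc = tokenizer("hello world", return_tensors="pt").to(device)
with torch.no_grad():
    _out = model(**_enc)
print(f"attention layers: {len(_out.attentions)}, "
      f"shape per layer: {tuple(_out.attentions[0].shape)}")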

# ====================================
# 4) Helper Functions
# ====================================

def plot_heatmap(matrix, tokens, title="Attention", cmap="Blues"):
    """
    Creates a heatmap from a 2D matrix: [seq_len, seq_len]
    with tokens on both axes. Returns a base64-encoded PNG.
    """
    fig, ax = plt.subplots(figsize=(6, 5))
    cax = ax.imshow(matrix, interpolation='nearest', cmap=cmap)
    ax.set_title(title)

    # Show tokens on x and y axis
    ax.set_xticks(range(len(tokens)))
    ax.set_xticklabels(tokens, rotation=90)
    ax.set_yticks(range(len(tokens)))
    ax.set_yticklabels(tokens)

    fig.colorbar(cax, ax=ax)
    plt.tight_layout()

    # Render the figure to an in-memory PNG and return it as a PIL Image
    buf = BytesIO()
    fig.savefig(buf, format='png', bbox_inches="tight")
    plt.close(fig)
    buf.seek(0)
    img = Image.open(buf)
    img.load()  # read eagerly so the buffer can be freed
    return img

def simulate_rnn_hidden_states(tokens):
    """
    Simulate how an RNN processes tokens one by one.
    The hidden states are random values, purely for illustration.
    Returns a PIL Image heatmap of shape [seq_len, hidden_dim].
    """
    seq_len = len(tokens)
    hidden_dim = 8  # small dimension for the "hidden state"

    # Create random hidden states: shape [seq_len, hidden_dim]
    random_states = np.random.rand(seq_len, hidden_dim)

    fig, ax = plt.subplots(figsize=(6, 3))
    cax = ax.imshow(random_states, interpolation='nearest', aspect='auto', cmap="viridis")
    ax.set_title("Simulated RNN Hidden States")
    ax.set_xlabel("Hidden Dim")
    ax.set_ylabel("Token Index")

    fig.colorbar(cax, ax=ax)
    plt.tight_layout()

    buf = BytesIO()
    fig.savefig(buf, format='png', bbox_inches="tight")
    plt.close(fig)
    buf.seek(0)
    img = Image.open(buf)
    img.load()
    return img

# ====================================
# 5) Gradio Inference Function
# ====================================

def compare_rnn_transformer(input_text):
    """
    - Tokenize input_text
    - Simulate an RNN's hidden states
    - Show Tiny BERT attention (averaged over heads from last layer)
    - Return two PIL images: the RNN hidden-state map and the Transformer attention map
    """

    # 1) Tokenize input text (the tokenizer's __call__ supersedes encode_plus)
    inputs = tokenizer(
        input_text,
        return_tensors="pt",
        truncation=True,
        max_length=50
    )
    # Move to GPU if available
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Convert IDs to tokens (just for axis labels)
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist())

    # 2) Simulate RNN hidden states
    rnn_heatmap = simulate_rnn_hidden_states(tokens)

    # 3) Forward pass through Tiny BERT
    with torch.no_grad():
        outputs = model(**inputs)
        # outputs.attentions is a tuple with one tensor per layer,
        # each of shape [batch_size, n_heads, seq_len, seq_len]
        attentions = outputs.attentions
        # Take the last layer's attention
        last_layer_attention = attentions[-1].squeeze(0)  # shape: [n_heads, seq_len, seq_len]
        # Average across heads -> [seq_len, seq_len]
        avg_attention = last_layer_attention.mean(dim=0).cpu().numpy()

    # 4) Create a heatmap for attention
    transformer_heatmap = plot_heatmap(avg_attention, tokens, title="Transformer Attention")

    return (rnn_heatmap, transformer_heatmap)
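
# Quick headless check (illustrative), bypassing the UI entirely:
# rnn_img, attn_img = compare_rnn_transformer("The cat sat on the mat.")
# rnn_img.save("rnn.png"); attn_img.save("attention.png")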

# ====================================
# 6) Create and Launch Gradio Interface
# ====================================

interface = gr.Interface(
    fn=compare_rnn_transformer,
    inputs=gr.Textbox(
        lines=3,
        label="Enter a sentence to see RNN vs. Transformer visualization"
    ),
    outputs=[
        gr.Image(label="RNN Hidden States"),
        gr.Image(label="Transformer Attention Map")
    ],
    title="RNN vs. Tiny BERT Demo",
    description=(
        "Type in a sentence and see how a simulated RNN processes tokens step-by-step "
        "vs. how a real (tiny) Transformer computes attention across all tokens in parallel.\n\n"
        "For best performance, enable GPU under Runtime > Change runtime type > GPU."
    )
)

interface.launch()
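# If the local port isn't reachable (e.g., in some hosted notebooks),
# interface.launch(share=True) serves a temporary public link instead.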