Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ====================================
|
2 |
+
# 1) Install Dependencies (One-Time)
|
3 |
+
# ====================================
|
4 |
+
# ====================================
|
5 |
+
# 2) Imports
|
6 |
+
# ====================================
|
7 |
+
import torch
|
8 |
+
import gradio as gr
|
9 |
+
import matplotlib.pyplot as plt
|
10 |
+
import numpy as np
|
11 |
+
import base64
|
12 |
+
from io import BytesIO
|
13 |
+
|
14 |
+
# We'll use a "tiny" BERT model to reduce loading/inference time:
|
15 |
+
from transformers import AutoTokenizer, AutoModel
|
16 |
+
|
17 |
+
# Check if GPU is available
|
18 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
19 |
+
print("Using device:", device)
|
20 |
+
|
21 |
+
# ====================================
|
22 |
+
# 3) Load Tiny BERT Model + Tokenizer
|
23 |
+
# ====================================
|
24 |
+
# This "prajjwal1/bert-tiny" model is just 2 Transformer layers
|
25 |
+
# (and ~4 million parameters), so it loads and runs faster.
|
26 |
+
tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-tiny")
|
27 |
+
model = AutoModel.from_pretrained(
|
28 |
+
"prajjwal1/bert-tiny",
|
29 |
+
output_attentions=True # so we can visualize attention
|
30 |
+
).to(device)
|
31 |
+
model.eval()
|
32 |
+
|
33 |
+
# ====================================
|
34 |
+
# 4) Helper Functions
|
35 |
+
# ====================================
|
36 |
+
|
37 |
+
def plot_heatmap(matrix, tokens, title="Attention", cmap="Blues"):
|
38 |
+
"""
|
39 |
+
Creates a heatmap from a 2D matrix: [seq_len, seq_len]
|
40 |
+
with tokens on both axes. Returns a base64-encoded PNG.
|
41 |
+
"""
|
42 |
+
fig, ax = plt.subplots(figsize=(6, 5))
|
43 |
+
cax = ax.imshow(matrix, interpolation='nearest', cmap=cmap)
|
44 |
+
ax.set_title(title)
|
45 |
+
|
46 |
+
# Show tokens on x and y axis
|
47 |
+
ax.set_xticks(range(len(tokens)))
|
48 |
+
ax.set_xticklabels(tokens, rotation=90)
|
49 |
+
ax.set_yticks(range(len(tokens)))
|
50 |
+
ax.set_yticklabels(tokens)
|
51 |
+
|
52 |
+
fig.colorbar(cax, ax=ax)
|
53 |
+
plt.tight_layout()
|
54 |
+
|
55 |
+
# Convert plot to base64-encoded PNG
|
56 |
+
buf = BytesIO()
|
57 |
+
plt.savefig(buf, format='png', bbox_inches="tight")
|
58 |
+
plt.close(fig)
|
59 |
+
buf.seek(0)
|
60 |
+
return "data:image/png;base64," + base64.b64encode(buf.read()).decode("utf-8")
|
61 |
+
|
62 |
+
def simulate_rnn_hidden_states(tokens):
|
63 |
+
"""
|
64 |
+
Simulate how an RNN processes tokens one-by-one.
|
65 |
+
We'll just create random hidden states for illustration.
|
66 |
+
Returns a base64-encoded PNG heatmap of shape [seq_len, hidden_dim].
|
67 |
+
"""
|
68 |
+
seq_len = len(tokens)
|
69 |
+
hidden_dim = 8 # small dimension for the "hidden state"
|
70 |
+
|
71 |
+
# Create random hidden states: shape [seq_len, hidden_dim]
|
72 |
+
random_states = np.random.rand(seq_len, hidden_dim)
|
73 |
+
|
74 |
+
fig, ax = plt.subplots(figsize=(6, 3))
|
75 |
+
cax = ax.imshow(random_states, interpolation='nearest', aspect='auto', cmap="viridis")
|
76 |
+
ax.set_title("Simulated RNN Hidden States")
|
77 |
+
ax.set_xlabel("Hidden Dim")
|
78 |
+
ax.set_ylabel("Token Index")
|
79 |
+
|
80 |
+
fig.colorbar(cax, ax=ax)
|
81 |
+
plt.tight_layout()
|
82 |
+
|
83 |
+
buf = BytesIO()
|
84 |
+
plt.savefig(buf, format='png', bbox_inches="tight")
|
85 |
+
plt.close(fig)
|
86 |
+
buf.seek(0)
|
87 |
+
return "data:image/png;base64," + base64.b64encode(buf.read()).decode("utf-8")
|
88 |
+
|
89 |
+
# ====================================
|
90 |
+
# 5) Gradio Inference Function
|
91 |
+
# ====================================
|
92 |
+
|
93 |
+
def compare_rnn_transformer(input_text):
|
94 |
+
"""
|
95 |
+
- Tokenize input_text
|
96 |
+
- Simulate an RNN's hidden states
|
97 |
+
- Show Tiny BERT attention (averaged over heads from last layer)
|
98 |
+
- Return two images: RNN hidden states, Transformer attention map
|
99 |
+
"""
|
100 |
+
|
101 |
+
# 1) Tokenize input text
|
102 |
+
inputs = tokenizer.encode_plus(
|
103 |
+
input_text,
|
104 |
+
return_tensors="pt",
|
105 |
+
truncation=True,
|
106 |
+
max_length=50
|
107 |
+
)
|
108 |
+
# Move to GPU if available
|
109 |
+
inputs = {k: v.to(device) for k, v in inputs.items()}
|
110 |
+
|
111 |
+
# Convert IDs to tokens (just for axis labels)
|
112 |
+
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist())
|
113 |
+
|
114 |
+
# 2) Simulate RNN hidden states
|
115 |
+
rnn_heatmap = simulate_rnn_hidden_states(tokens)
|
116 |
+
|
117 |
+
# 3) Forward pass through Tiny BERT
|
118 |
+
with torch.no_grad():
|
119 |
+
outputs = model(**inputs)
|
120 |
+
# outputs.attentions: [n_layers, batch_size, n_heads, seq_len, seq_len]
|
121 |
+
attentions = outputs.attentions
|
122 |
+
# Take the last layer's attention
|
123 |
+
last_layer_attention = attentions[-1].squeeze(0) # shape: [n_heads, seq_len, seq_len]
|
124 |
+
# Average across heads -> [seq_len, seq_len]
|
125 |
+
avg_attention = last_layer_attention.mean(dim=0).cpu().numpy()
|
126 |
+
|
127 |
+
# 4) Create a heatmap for attention
|
128 |
+
transformer_heatmap = plot_heatmap(avg_attention, tokens, title="Transformer Attention")
|
129 |
+
|
130 |
+
return (rnn_heatmap, transformer_heatmap)
|
131 |
+
|
132 |
+
# ====================================
|
133 |
+
# 6) Create and Launch Gradio Interface
|
134 |
+
# ====================================
|
135 |
+
|
136 |
+
interface = gr.Interface(
|
137 |
+
fn=compare_rnn_transformer,
|
138 |
+
inputs=gr.Textbox(
|
139 |
+
lines=3,
|
140 |
+
label="Enter a sentence to see RNN vs. Transformer visualization"
|
141 |
+
),
|
142 |
+
outputs=[
|
143 |
+
gr.Image(label="RNN Hidden States"),
|
144 |
+
gr.Image(label="Transformer Attention Map")
|
145 |
+
],
|
146 |
+
title="RNN vs. Tiny BERT Demo",
|
147 |
+
description=(
|
148 |
+
"Type in a sentence and see how a simulated RNN processes tokens step-by-step "
|
149 |
+
"vs. how a real (tiny) Transformer computes attention across all tokens in parallel.\n\n"
|
150 |
+
"For best performance, enable GPU under Runtime > Change runtime type > GPU."
|
151 |
+
)
|
152 |
+
)
|
153 |
+
|
154 |
+
interface.launch()
|