kevin1911 committed · verified
Commit 6bd894d · 1 Parent(s): 270f8d5

Create app.py

Files changed (1): app.py (+154, −0)
app.py ADDED
@@ -0,0 +1,154 @@
# ====================================
# 1) Install Dependencies (One-Time)
# ====================================
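# Not included in this file; a minimal sketch, assuming the packages below
# are the full dependency list (versions unpinned):
#   pip install torch transformers gradio matplotlib numpy pillow
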
# ====================================
# 2) Imports
# ====================================
import torch
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from io import BytesIO
from PIL import Image  # wraps rendered plots so gr.Image can display them

# We'll use a "tiny" BERT model to reduce loading/inference time:
from transformers import AutoTokenizer, AutoModel

# Check if a GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# ====================================
# 3) Load Tiny BERT Model + Tokenizer
# ====================================
# The "prajjwal1/bert-tiny" model has just 2 Transformer layers
# (and ~4 million parameters), so it loads and runs quickly.
tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-tiny")
model = AutoModel.from_pretrained(
    "prajjwal1/bert-tiny",
    output_attentions=True  # so we can visualize attention
).to(device)
model.eval()  # inference mode: disables dropout
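
# Hypothetical sanity check (not part of the app): bert-tiny uses the standard
# uncased WordPiece vocabulary, so common words stay whole while rarer ones
# split into subword pieces:
#   print(tokenizer.tokenize("Transformers parallelize attention"))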

# ====================================
# 4) Helper Functions
# ====================================

def plot_heatmap(matrix, tokens, title="Attention", cmap="Blues"):
    """
    Creates a heatmap from a 2D matrix of shape [seq_len, seq_len]
    with tokens on both axes. Returns a PIL image of the rendered plot.
    """
    fig, ax = plt.subplots(figsize=(6, 5))
    cax = ax.imshow(matrix, interpolation="nearest", cmap=cmap)
    ax.set_title(title)

    # Show tokens on the x and y axes
    ax.set_xticks(range(len(tokens)))
    ax.set_xticklabels(tokens, rotation=90)
    ax.set_yticks(range(len(tokens)))
    ax.set_yticklabels(tokens)

    fig.colorbar(cax, ax=ax)
    plt.tight_layout()

    # Render the figure to an in-memory PNG, then hand it to PIL
    buf = BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight")
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
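
# Hypothetical usage: an identity matrix depicts each token attending only to itself.
#   plot_heatmap(np.eye(2), ["hello", "world"]).show()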

def simulate_rnn_hidden_states(tokens):
    """
    Simulate how an RNN processes tokens one by one.
    We just create random hidden states for illustration.
    Returns a PIL image of a heatmap with shape [seq_len, hidden_dim].
    """
    seq_len = len(tokens)
    hidden_dim = 8  # small dimension for the "hidden state"

    # Create random hidden states: shape [seq_len, hidden_dim]
    random_states = np.random.rand(seq_len, hidden_dim)

    fig, ax = plt.subplots(figsize=(6, 3))
    cax = ax.imshow(random_states, interpolation="nearest", aspect="auto", cmap="viridis")
    ax.set_title("Simulated RNN Hidden States")
    ax.set_xlabel("Hidden Dim")
    ax.set_ylabel("Token Index")

    fig.colorbar(cax, ax=ax)
    plt.tight_layout()

    buf = BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight")
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
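
# For comparison, a real RNN would build each hidden state sequentially,
# roughly h_t = tanh(W_x @ x_t + W_h @ h_{t-1} + b), so token t only sees
# earlier tokens; self-attention (below) looks at all tokens at once.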

# ====================================
# 5) Gradio Inference Function
# ====================================

def compare_rnn_transformer(input_text):
    """
    - Tokenize input_text
    - Simulate an RNN's hidden states
    - Show Tiny BERT attention (averaged over heads of the last layer)
    - Return two images: RNN hidden states, Transformer attention map
    """

    # 1) Tokenize the input text
    inputs = tokenizer(
        input_text,
        return_tensors="pt",
        truncation=True,
        max_length=50
    )
    # Move tensors to the GPU if available
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Convert IDs to tokens (just for axis labels)
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist())

    # 2) Simulate RNN hidden states
    rnn_heatmap = simulate_rnn_hidden_states(tokens)

    # 3) Forward pass through Tiny BERT
    with torch.no_grad():
        outputs = model(**inputs)
        # outputs.attentions is a tuple of n_layers tensors,
        # each of shape [batch_size, n_heads, seq_len, seq_len]
        attentions = outputs.attentions
        # Take the last layer's attention
        last_layer_attention = attentions[-1].squeeze(0)  # [n_heads, seq_len, seq_len]
        # Average across heads; each row still sums to 1 -> [seq_len, seq_len]
        avg_attention = last_layer_attention.mean(dim=0).cpu().numpy()

    # 4) Create a heatmap for the attention weights
    transformer_heatmap = plot_heatmap(avg_attention, tokens, title="Transformer Attention")

    return rnn_heatmap, transformer_heatmap

# ====================================
# 6) Create and Launch Gradio Interface
# ====================================

interface = gr.Interface(
    fn=compare_rnn_transformer,
    inputs=gr.Textbox(
        lines=3,
        label="Enter a sentence to see RNN vs. Transformer visualization"
    ),
    outputs=[
        gr.Image(label="RNN Hidden States"),
        gr.Image(label="Transformer Attention Map")
    ],
    title="RNN vs. Tiny BERT Demo",
    description=(
        "Type in a sentence and see how a simulated RNN processes tokens step-by-step "
        "vs. how a real (tiny) Transformer computes attention across all tokens in parallel.\n\n"
        "For best performance, enable GPU under Runtime > Change runtime type > GPU."
    )
)

interface.launch()
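
# On a Hugging Face Space the default launch() is all that's needed; running
# locally, interface.launch(share=True) would also create a temporary public link.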