# File size: 5,167 Bytes (revision 6bd894d)
# ====================================
# 1) Install Dependencies (One-Time)
# ====================================
# pip install torch transformers gradio matplotlib numpy
# ====================================
# 2) Imports
# ====================================
import torch
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import base64
from io import BytesIO
# We'll use a "tiny" BERT model to reduce loading/inference time:
from transformers import AutoTokenizer, AutoModel
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

# ====================================
# 3) Load Tiny BERT Model + Tokenizer
# ====================================
# "prajjwal1/bert-tiny" has only 2 Transformer layers (~4 million parameters),
# which keeps both loading and inference fast for an interactive demo.
tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-tiny")
model = AutoModel.from_pretrained(
    "prajjwal1/bert-tiny",
    output_attentions=True,  # return per-layer attention weights for plotting
)
model.to(device)  # nn.Module.to moves the parameters in place
model.eval()      # evaluation mode (e.g. disables dropout)
# ====================================
# 4) Helper Functions
# ====================================
def _figure_to_data_uri(fig):
    """
    Render a Matplotlib figure to PNG and return it as a data URI.

    Args:
        fig: the Matplotlib Figure to render.

    Returns:
        str: "data:image/png;base64,<payload>".
    """
    buf = BytesIO()
    # Operate on this Figure explicitly rather than on pyplot's implicit
    # "current figure", which is process-global state that another plot
    # (e.g. a concurrent request) could change underneath us.
    fig.tight_layout()
    fig.savefig(buf, format="png", bbox_inches="tight")
    return "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode("utf-8")


def plot_heatmap(matrix, tokens, title="Attention", cmap="Blues"):
    """
    Draw a square matrix as a heatmap labelled with tokens on both axes.

    Args:
        matrix: 2D array-like of shape [len(tokens), len(tokens)].
        tokens: labels for the columns (x axis, rotated 90 degrees) and the
            rows (y axis).
        title: figure title.
        cmap: Matplotlib colormap name.

    Returns:
        str: the rendered PNG as a "data:image/png;base64,..." URI.
    """
    fig, ax = plt.subplots(figsize=(6, 5))
    try:
        image = ax.imshow(matrix, interpolation="nearest", cmap=cmap)
        ax.set_title(title)
        # One tick per token on each axis.
        positions = range(len(tokens))
        ax.set_xticks(positions)
        ax.set_xticklabels(tokens, rotation=90)
        ax.set_yticks(positions)
        ax.set_yticklabels(tokens)
        fig.colorbar(image, ax=ax)
        return _figure_to_data_uri(fig)
    finally:
        # Release the figure even if drawing fails, so a long-running server
        # does not accumulate open figures.
        plt.close(fig)


def simulate_rnn_hidden_states(tokens, hidden_dim=8, rng=None):
    """
    Render a placeholder "RNN hidden state" heatmap for a token sequence.

    No RNN is actually run: each row is filled with uniform random numbers in
    [0, 1), one row per token, so only ``len(tokens)`` affects the result. The
    picture is purely illustrative of "one hidden state per time step".

    Args:
        tokens: sequence of tokens; only its length is used.
        hidden_dim: number of columns (size of the simulated hidden state).
        rng: optional ``numpy.random.Generator`` for reproducible output.
            When None, NumPy's global random state (``np.random.rand``) is used.

    Returns:
        str: PNG data URI of a [len(tokens), hidden_dim] heatmap.
    """
    seq_len = len(tokens)
    if rng is None:
        states = np.random.rand(seq_len, hidden_dim)
    else:
        states = rng.random((seq_len, hidden_dim))

    fig, ax = plt.subplots(figsize=(6, 3))
    try:
        image = ax.imshow(states, interpolation="nearest", aspect="auto", cmap="viridis")
        ax.set_title("Simulated RNN Hidden States")
        ax.set_xlabel("Hidden Dim")
        ax.set_ylabel("Token Index")
        fig.colorbar(image, ax=ax)
        return _figure_to_data_uri(fig)
    finally:
        plt.close(fig)
# ====================================
# 5) Gradio Inference Function
# ====================================
def compare_rnn_transformer(input_text):
    """
    Build the two images shown by the Gradio app for one input text.

    Steps:
      1. Tokenize ``input_text`` with the Tiny BERT tokenizer (truncated to
         at most 50 tokens, special tokens included).
      2. Render the simulated RNN hidden-state heatmap for those tokens.
      3. Run Tiny BERT and render its last-layer attention, averaged over heads.

    Args:
        input_text: the sentence typed into the textbox. ``None`` is treated
            as an empty string.

    Returns:
        tuple: (RNN hidden-state image, Transformer attention image), each as
        returned by the plotting helpers (a PNG data URI string).
    """
    if input_text is None:
        input_text = ""

    # 1) Tokenize. Calling the tokenizer directly is the supported API;
    #    encode_plus() is deprecated in transformers.
    encoded = tokenizer(
        input_text,
        return_tensors="pt",
        truncation=True,
        max_length=50,
    )
    # Token strings for the axis labels, read from the CPU tensor before the
    # device move so no GPU->CPU copy is needed.
    tokens = tokenizer.convert_ids_to_tokens(encoded["input_ids"][0].tolist())
    inputs = {name: tensor.to(device) for name, tensor in encoded.items()}

    # 2) Simulated RNN hidden states.
    rnn_heatmap = simulate_rnn_hidden_states(tokens)

    # 3) Forward pass through Tiny BERT (no gradients needed for visualization).
    with torch.no_grad():
        outputs = model(**inputs)

    # outputs.attentions is a tuple with one tensor per layer, each shaped
    # [batch_size, n_heads, seq_len, seq_len]; the batch holds one text here.
    last_layer_attention = outputs.attentions[-1][0]  # [n_heads, seq_len, seq_len]
    avg_attention = last_layer_attention.mean(dim=0).cpu().numpy()  # [seq_len, seq_len]

    # 4) Heatmap of the head-averaged attention.
    transformer_heatmap = plot_heatmap(avg_attention, tokens, title="Transformer Attention")
    return rnn_heatmap, transformer_heatmap
# ====================================
# 6) Create and Launch Gradio Interface
# ====================================
# One text input -> two images (simulated RNN states, Tiny BERT attention).
# NOTE(review): compare_rnn_transformer returns "data:image/png;base64,..."
# strings; confirm the installed Gradio version's gr.Image accepts that output
# type (numpy arrays and PIL images are the commonly documented ones).
interface = gr.Interface(
    fn=compare_rnn_transformer,
    inputs=gr.Textbox(
        lines=3,
        label="Enter a sentence to see RNN vs. Transformer visualization",
    ),
    outputs=[
        gr.Image(label="RNN Hidden States"),
        gr.Image(label="Transformer Attention Map"),
    ],
    title="RNN vs. Tiny BERT Demo",
    description=(
        "Type in a sentence and see how a simulated RNN processes tokens step-by-step "
        "vs. how a real (tiny) Transformer computes attention across all tokens in parallel.\n\n"
        "For best performance, enable GPU under Runtime > Change runtime type > GPU."
    ),
)

# Only start the server when run as a script (python app.py, or a notebook
# cell, both of which execute as __main__); importing this module merely
# builds `interface` without launching it.
if __name__ == "__main__":
    interface.launch()