|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
import gradio as gr |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import base64 |
|
from io import BytesIO |
|
|
|
|
|
from transformers import AutoTokenizer, AutoModel |
|
|
|
|
|
# Run on the GPU when PyTorch reports one is available; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print("Using device:", device)
|
|
|
|
|
|
|
|
|
|
|
|
|
# Load the Tiny BERT tokenizer and encoder. The model is configured to return
# its attention weights so they can be visualized later.
tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-tiny")
model = AutoModel.from_pretrained("prajjwal1/bert-tiny", output_attentions=True)

# Move the weights to the selected device and switch to inference mode.
model = model.to(device)
model.eval()
|
|
|
|
|
|
|
|
|
|
|
def plot_heatmap(matrix, tokens, title="Attention", cmap="Blues"):
    """
    Render a 2D matrix as a heatmap with tokens labelling both axes.

    Args:
        matrix: 2D array-like of shape [seq_len, seq_len], drawn with
            ``imshow``.
        tokens: Sequence of token strings used as tick labels on both the
            x and y axes, one label per row/column.
        title: Title shown above the heatmap.
        cmap: Name of the Matplotlib colormap to use.

    Returns:
        A ``data:image/png;base64,...`` string containing the rendered PNG.
    """
    fig, ax = plt.subplots(figsize=(6, 5))
    try:
        image = ax.imshow(matrix, interpolation="nearest", cmap=cmap)
        ax.set_title(title)

        # One tick per token on each axis; x labels are rotated so they do
        # not overlap.
        tick_positions = range(len(tokens))
        ax.set_xticks(tick_positions)
        ax.set_xticklabels(tokens, rotation=90)
        ax.set_yticks(tick_positions)
        ax.set_yticklabels(tokens)

        fig.colorbar(image, ax=ax)

        # Call layout/save on this figure object rather than through pyplot,
        # which operates on the global "current figure" and could pick up a
        # different figure if several plots are being built at once.
        fig.tight_layout()

        buf = BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight")
    finally:
        # Always release the figure, even if rendering failed, so pyplot
        # does not keep accumulating open figures.
        plt.close(fig)

    encoded_png = base64.b64encode(buf.getvalue()).decode("utf-8")
    return "data:image/png;base64," + encoded_png
|
|
|
def simulate_rnn_hidden_states(tokens, hidden_dim=8):
    """
    Draw an illustrative heatmap of RNN hidden states, one row per token.

    No real RNN is run: the "hidden states" are uniform random values,
    meant only to illustrate token-by-token processing.

    Args:
        tokens: Sequence of tokens; only its length is used (number of rows).
        hidden_dim: Number of simulated hidden units (columns). Defaults to 8.

    Returns:
        A ``data:image/png;base64,...`` string containing a PNG heatmap of
        shape [seq_len, hidden_dim].
    """
    seq_len = len(tokens)

    # Purely illustrative values; a fresh random matrix is drawn on each call.
    random_states = np.random.rand(seq_len, hidden_dim)

    fig, ax = plt.subplots(figsize=(6, 3))
    try:
        image = ax.imshow(
            random_states,
            interpolation="nearest",
            aspect="auto",
            cmap="viridis",
        )
        ax.set_title("Simulated RNN Hidden States")
        ax.set_xlabel("Hidden Dim")
        ax.set_ylabel("Token Index")

        fig.colorbar(image, ax=ax)

        # Call layout/save on this figure object rather than through pyplot,
        # which operates on the global "current figure" and could pick up a
        # different figure if several plots are being built at once.
        fig.tight_layout()

        buf = BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight")
    finally:
        # Always release the figure, even if rendering failed.
        plt.close(fig)

    encoded_png = base64.b64encode(buf.getvalue()).decode("utf-8")
    return "data:image/png;base64," + encoded_png
|
|
|
|
|
|
|
|
|
|
|
def compare_rnn_transformer(input_text):
    """
    Build the two visualizations shown by the demo for one input sentence.

    Steps:
      - Tokenize ``input_text`` (truncated to at most 50 tokens).
      - Render simulated RNN hidden states for those tokens.
      - Run Tiny BERT and average the last layer's attention over its heads.

    Args:
        input_text: The sentence typed by the user. ``None`` is treated as
            an empty string.

    Returns:
        A tuple ``(rnn_heatmap, transformer_heatmap)`` of base64 PNG data-URI
        strings as produced by ``simulate_rnn_hidden_states`` and
        ``plot_heatmap``.
    """
    # Make sure the tokenizer always receives a string.
    text = "" if input_text is None else input_text

    # Call the tokenizer directly; ``encode_plus`` is deprecated in favour of
    # ``__call__`` and takes the same arguments here.
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=50,
    )

    # Move every input tensor to the same device as the model.
    inputs = {name: tensor.to(device) for name, tensor in encoded.items()}

    # Human-readable tokens (including any special tokens the tokenizer
    # added) for labelling the plots.
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist())

    rnn_heatmap = simulate_rnn_hidden_states(tokens)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model(**inputs)

    # ``attentions`` holds one tensor per layer. Take the last layer, select
    # the single batch item, then average over the leading (head) dimension
    # to get a 2D map for plotting.
    last_layer_attention = outputs.attentions[-1][0]
    avg_attention = last_layer_attention.mean(dim=0).cpu().numpy()

    transformer_heatmap = plot_heatmap(
        avg_attention,
        tokens,
        title="Transformer Attention",
    )

    return (rnn_heatmap, transformer_heatmap)
|
|
|
|
|
|
|
|
|
|
|
# Gradio UI: one text box in, two images out (simulated RNN states and the
# Tiny BERT attention map), both produced by compare_rnn_transformer.
# NOTE(review): the callback returns base64 data-URI strings; confirm the
# installed Gradio version's Image component accepts that format.
interface = gr.Interface(
    title="RNN vs. Tiny BERT Demo",
    description=(
        "Type in a sentence and see how a simulated RNN processes tokens step-by-step "
        "vs. how a real (tiny) Transformer computes attention across all tokens in parallel.\n\n"
        "For best performance, enable GPU under Runtime > Change runtime type > GPU."
    ),
    fn=compare_rnn_transformer,
    inputs=gr.Textbox(
        label="Enter a sentence to see RNN vs. Transformer visualization",
        lines=3,
    ),
    outputs=[
        gr.Image(label="RNN Hidden States"),
        gr.Image(label="Transformer Attention Map"),
    ],
)

# Start serving the demo.
interface.launch()
|
|