# LLM / app.py
# Created by kevin1911 in verified commit 6bd894d ("Create app.py").
# ====================================
# 1) Install Dependencies (One-Time)
# ====================================
# pip install torch transformers gradio matplotlib numpy
# ====================================
# 2) Imports
# ====================================
import torch
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import base64
from io import BytesIO
# We'll use a "tiny" BERT model to reduce loading/inference time:
from transformers import AutoTokenizer, AutoModel
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")
# ====================================
# 3) Load Tiny BERT Model + Tokenizer
# ====================================
# "prajjwal1/bert-tiny" has just 2 Transformer layers (~4 million
# parameters), so downloading, loading and inference stay fast.
MODEL_NAME = "prajjwal1/bert-tiny"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# output_attentions=True asks the model to return its per-layer attention
# weights on every forward pass; the attention heatmap is built from them.
model = AutoModel.from_pretrained(MODEL_NAME, output_attentions=True)
model = model.to(device)
# Inference only: put the model in evaluation mode.
model.eval()
# ====================================
# 4) Helper Functions
# ====================================
def plot_heatmap(matrix, tokens, title="Attention", cmap="Blues"):
    """
    Render a square matrix as a heatmap with tokens labelling both axes.

    Args:
        matrix: 2D array-like of shape [seq_len, seq_len] (in this app, the
            head-averaged attention weights); row i / column j correspond to
            tokens[i] / tokens[j].
        tokens: Sequence of token strings used as tick labels on both axes.
        title: Plot title.
        cmap: Matplotlib colormap name.

    Returns:
        str: The rendered PNG as a ``data:image/png;base64,...`` URI.
    """
    fig, ax = plt.subplots(figsize=(6, 5))
    try:
        image = ax.imshow(matrix, interpolation="nearest", cmap=cmap)
        ax.set_title(title)
        # One tick per token on each axis; x labels are rotated vertically
        # so neighbouring tokens don't overlap.
        positions = range(len(tokens))
        ax.set_xticks(positions)
        ax.set_xticklabels(tokens, rotation=90)
        ax.set_yticks(positions)
        ax.set_yticklabels(tokens)
        fig.colorbar(image, ax=ax)
        # Work on `fig` explicitly: plt.tight_layout()/plt.savefig() act on
        # pyplot's implicit "current figure", which can be a different figure
        # if another thread (e.g. a concurrent Gradio request) created one.
        fig.tight_layout()
        buf = BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight")
    finally:
        # Always release the figure, even on error, so figures don't pile up
        # in pyplot's registry in a long-running server process.
        plt.close(fig)
    encoded = base64.b64encode(buf.getvalue()).decode("utf-8")
    return "data:image/png;base64," + encoded
def simulate_rnn_hidden_states(tokens, hidden_dim=8):
    """
    Illustrate the per-token hidden states an RNN would produce.

    No recurrent model is actually run: every token gets a row of uniform
    random values in [0, 1), so the picture only conveys the *shape* of an
    RNN's output (one hidden-state vector per token, in order).

    Args:
        tokens: Sequence of tokens; only its length is used.
        hidden_dim: Width of the simulated hidden state (default 8).

    Returns:
        str: A PNG heatmap of shape [len(tokens), hidden_dim], encoded as a
        ``data:image/png;base64,...`` URI.
    """
    seq_len = len(tokens)
    # Random placeholder states: shape [seq_len, hidden_dim].
    random_states = np.random.rand(seq_len, hidden_dim)
    fig, ax = plt.subplots(figsize=(6, 3))
    try:
        image = ax.imshow(
            random_states, interpolation="nearest", aspect="auto", cmap="viridis"
        )
        ax.set_title("Simulated RNN Hidden States")
        ax.set_xlabel("Hidden Dim")
        ax.set_ylabel("Token Index")
        fig.colorbar(image, ax=ax)
        # Work on `fig` explicitly rather than pyplot's implicit current
        # figure, which another thread may have replaced in the meantime.
        fig.tight_layout()
        buf = BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight")
    finally:
        # Release the figure even if rendering fails, to avoid leaking it.
        plt.close(fig)
    encoded = base64.b64encode(buf.getvalue()).decode("utf-8")
    return "data:image/png;base64," + encoded
# ====================================
# 5) Gradio Inference Function
# ====================================
def compare_rnn_transformer(input_text):
    """
    Build both visualizations for one input sentence.

    Steps:
      1. Tokenize ``input_text`` (truncated to max_length=50 tokens).
      2. Draw simulated (random) RNN hidden states for those tokens.
      3. Run Tiny BERT and average the last layer's attention over heads.
      4. Draw that [seq_len, seq_len] attention matrix as a heatmap.

    Args:
        input_text: The sentence typed into the Gradio textbox.

    Returns:
        tuple[str, str]: (RNN hidden-state heatmap, Transformer attention
        heatmap), each a ``data:image/png;base64,...`` URI.

    Raises:
        RuntimeError: If the model returns no attention weights.
    """
    # 1) Tokenize. Calling the tokenizer directly is the supported API;
    # encode_plus() is deprecated and returns the same BatchEncoding.
    inputs = tokenizer(
        input_text,
        return_tensors="pt",
        truncation=True,
        max_length=50,
    )
    # Move every input tensor to the model's device.
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
    # Token strings (including any special tokens the tokenizer adds) are
    # only used as axis labels.
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist())

    # 2) Simulated RNN hidden states.
    rnn_heatmap = simulate_rnn_hidden_states(tokens)

    # 3) Forward pass through Tiny BERT (no gradients needed for inference).
    with torch.no_grad():
        outputs = model(**inputs)
    # outputs.attentions is a tuple with one tensor per layer, each shaped
    # [batch_size, n_heads, seq_len, seq_len]; batch_size is 1 here.
    attentions = outputs.attentions
    if attentions is None:
        raise RuntimeError(
            "The model did not return attention weights; "
            "it must be loaded with output_attentions=True."
        )
    last_layer_attention = attentions[-1][0]  # [n_heads, seq_len, seq_len]
    # Average across heads -> [seq_len, seq_len].
    avg_attention = last_layer_attention.mean(dim=0).cpu().numpy()

    # 4) Heatmap of the averaged attention.
    transformer_heatmap = plot_heatmap(
        avg_attention, tokens, title="Transformer Attention"
    )
    return rnn_heatmap, transformer_heatmap
# ====================================
# 6) Create and Launch Gradio Interface
# ====================================
# One text input; two image outputs (simulated RNN states, BERT attention).
interface = gr.Interface(
    fn=compare_rnn_transformer,
    inputs=gr.Textbox(
        lines=3,
        label="Enter a sentence to see RNN vs. Transformer visualization",
    ),
    outputs=[
        gr.Image(label="RNN Hidden States"),
        gr.Image(label="Transformer Attention Map"),
    ],
    title="RNN vs. Tiny BERT Demo",
    description=(
        "Type in a sentence and see how a simulated RNN processes tokens step-by-step "
        "vs. how a real (tiny) Transformer computes attention across all tokens in parallel.\n\n"
        "For best performance, run this app on a machine with a GPU; "
        "it falls back to the CPU automatically."
    ),
)

# Start the server only when this file is executed as a script, so importing
# the module (e.g. from tests or another app) doesn't block on a web server.
if __name__ == "__main__":
    interface.launch()