# transformer-visualizer / visualize.py
# Source: marianvd-01, commit 38d431d ("Update visualize.py")
# visualize.py - functions that draw:
#   - the attention matrix of one layer/head
#   - a tokenization preview
#   - token embeddings projected to 2-D with PCA
#   - a model size (parameter count) comparison chart
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import torch
from sklearn.decomposition import PCA
def plot_attention(tokens, attn_matrix):
    """Draw an attention matrix as a labelled heatmap with a colorbar.

    Args:
        tokens: Token strings used as tick labels on both axes; row/column
            ``i`` of ``attn_matrix`` is labelled ``tokens[i]``.
        attn_matrix: 2-D array-like of attention weights, presumably of shape
            ``(len(tokens), len(tokens))`` -- the caller must keep it in sync
            with ``tokens``, it is not checked here.

    Returns:
        The matplotlib ``Figure`` holding the heatmap.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    image = ax.matshow(attn_matrix, cmap="viridis")
    fig.colorbar(image)
    positions = range(len(tokens))
    ax.set_xticks(positions)
    ax.set_yticks(positions)
    ax.set_xticklabels(tokens, rotation=90)
    ax.set_yticklabels(tokens)
    ax.set_title("Attention Map")
    # Lay out *this* figure explicitly: plt.tight_layout() acts on pyplot's
    # "current" figure, which another plotting call (e.g. a concurrent
    # request in a web UI) may have switched since plt.subplots() above.
    fig.tight_layout()
    return fig
def visualize_attention(tokenizer, model, text, layer_index, head_index):
    """Run ``model`` on ``text`` and plot the attention weights of one head.

    Args:
        tokenizer: Hugging Face-style tokenizer; called with
            ``return_tensors="pt"`` and used to turn input ids back into tokens.
        model: Model called with the tokenizer output; its result must expose
            ``.attentions`` (indexed by layer).
        text: Input string.
        layer_index: Index into ``outputs.attentions``.
        head_index: Head index within that layer; batch element 0 is used.

    Returns:
        The figure produced by ``plot_attention``.
    """
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        # Attention weights are only returned when requested; without this
        # flag ``outputs.attentions`` is None unless the model was loaded with
        # ``output_attentions=True`` in its config.
        outputs = model(**inputs, output_attentions=True)
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    # .float().cpu(): Tensor.numpy() rejects bfloat16 and non-CPU tensors.
    attn = outputs.attentions[layer_index][0, head_index].float().cpu().numpy()
    return plot_attention(tokens, attn)
def show_tokenization(tokenizer, text):
    """Draw how ``tokenizer`` splits ``text`` as a strip of colored cells.

    Args:
        tokenizer: Object with a ``tokenize(text)`` method returning a list of
            token strings.
        text: Input string.

    Returns:
        A matplotlib ``Figure`` with one cell per token, each labelled with its
        token. Neighbouring cells get different colors so token boundaries are
        visible. If ``text`` yields no tokens, a placeholder message is shown.
    """
    tokens = tokenizer.tokenize(text)
    fig, ax = plt.subplots(figsize=(8, 1))
    ax.set_title("Tokenization")
    ax.set_yticks([])
    if not tokens:
        # Nothing to draw: say so rather than rendering an empty strip.
        ax.set_xticks([])
        ax.text(0.5, 0.5, "(no tokens)", ha="center", va="center",
                transform=ax.transAxes)
        return fig
    # The original all-zero data painted every cell the same color. Cycle
    # through the qualitative colormap's entries instead, pinning the color
    # scale so that index k maps to the colormap's k-th color.
    cmap = plt.get_cmap("Pastel2")
    cell_colors = np.arange(len(tokens)) % cmap.N
    ax.imshow(cell_colors[np.newaxis, :], cmap=cmap, aspect="auto",
              vmin=0, vmax=cmap.N - 1)
    ax.set_xticks(range(len(tokens)))
    ax.set_xticklabels(tokens, rotation=90)
    return fig
def show_embeddings(tokenizer, model, text):
    """Scatter-plot each token's final hidden state, projected to 2-D by PCA.

    Args:
        tokenizer: Hugging Face-style tokenizer; called with
            ``return_tensors="pt"`` and used to turn input ids back into tokens.
        model: Model called with the tokenizer output; its result must expose
            ``last_hidden_state``.
        text: Input string.

    Returns:
        A matplotlib ``Figure`` with one annotated point per token.
    """
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # Batch element 0: one row per token. .float().cpu() because
    # Tensor.numpy() rejects bfloat16 and non-CPU tensors.
    embeddings = outputs.last_hidden_state[0].float().cpu().numpy()
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    if embeddings.shape[0] < 2:
        # PCA(n_components=2) raises ValueError with fewer than two samples;
        # a lone token has no spread to project, so place it at the origin.
        reduced = np.zeros((embeddings.shape[0], 2))
    else:
        reduced = PCA(n_components=2).fit_transform(embeddings)
    fig, ax = plt.subplots()
    ax.scatter(reduced[:, 0], reduced[:, 1])
    for token, (x, y) in zip(tokens, reduced):
        ax.annotate(token, (x, y))
    ax.set_title("Token Embeddings (PCA)")
    return fig
def get_token_list(tokenizer, text):
    """Return the token strings that ``tokenizer.tokenize`` yields for ``text``.

    Args:
        tokenizer: Any object exposing a ``tokenize(text)`` method.
        text: Input string.

    Returns:
        Whatever ``tokenizer.tokenize(text)`` returns, unchanged.
    """
    token_list = tokenizer.tokenize(text)
    return token_list
def _param_count_millions(model_name):
    """Load ``model_name`` with ``AutoModel`` and return its parameter count in millions."""
    from transformers import AutoModel

    # The model is local to this call, so its weights are released when the
    # function returns -- before the caller loads the next model.
    model = AutoModel.from_pretrained(model_name)
    return sum(p.numel() for p in model.parameters()) / 1e6


def compare_model_sizes():
    """Plot a bar chart of parameter counts (in millions) for ``MODEL_OPTIONS``.

    ``model_utils.MODEL_OPTIONS`` maps a display label (used as the bar label)
    to a model identifier passed to ``AutoModel.from_pretrained``. A model that
    fails to load gets no bar and an "n/a" marker instead of aborting the
    whole chart.

    Returns:
        A matplotlib ``Figure``.
    """
    from model_utils import MODEL_OPTIONS

    labels = []
    sizes = []
    for label, model_name in MODEL_OPTIONS.items():
        labels.append(label)
        try:
            sizes.append(_param_count_millions(model_name))
        except Exception:
            # Best-effort chart: a download/config/architecture failure for
            # one model should not hide the others. NaN (not None) so that
            # matplotlib simply skips the bar instead of raising TypeError.
            sizes.append(float("nan"))

    fig, ax = plt.subplots()
    ax.bar(labels, sizes, color="skyblue")
    for label, size in zip(labels, sizes):
        if np.isnan(size):
            ax.text(label, 0, "n/a", ha="center", va="bottom")
    ax.set_ylabel("Parameters (Millions)")
    ax.set_title("Model Size Comparison")
    ax.tick_params(axis='x', rotation=45)
    # Lay out this figure explicitly rather than pyplot's "current" figure.
    fig.tight_layout()
    return fig