<!--
modernbert-viz / index.html
ucalyptus's picture
Update index.html
85f61a7 verified
-->
<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <title>ModernBERT Features Visualization</title>
  <meta name="description" content="Interactive canvas animations of ModernBERT efficiency features: padding, unpadding, sequence packing, and local versus global attention.">
  <!-- NOTE(review): the Tailwind Play CDN builds utility classes in the browser at runtime; Tailwind documents it as a development tool, so consider a compiled stylesheet for production. -->
  <script src="https://cdn.tailwindcss.com"></script>
<style>
/* Custom styles for better visualization */
body {
font-family: 'Inter', sans-serif; /* Using Inter font as standard */
}
canvas {
display: block;
background-color: #f3f4f6; /* Light gray background */
border-radius: 0.5rem; /* Rounded corners */
border: 1px solid #d1d5db; /* Gray border */
margin: 1rem auto; /* Center canvas with margin */
}
.token {
display: inline-block; /* Align tokens nicely */
padding: 0.3rem 0.6rem;
margin: 0.2rem;
border-radius: 0.375rem; /* Rounded corners for tokens */
font-size: 0.875rem; /* Smaller font size */
font-weight: 500;
min-width: 30px; /* Ensure minimum width */
text-align: center;
}
.token-real { background-color: #60a5fa; color: white; } /* Blue for real tokens */
.token-pad { background-color: #e5e7eb; color: #6b7280; } /* Gray for padding */
.token-attend { border: 2px solid #f87171; } /* Red border for attending token */
.token-attended { background-color: #fbbf24; color: white; } /* Amber for attended tokens */
.token-local { background-color: #a78bfa; color: white; } /* Violet for local attention window */
/* Ensure canvas is responsive */
#animationCanvas {
width: 100%;
max-width: 800px; /* Limit max width for larger screens */
height: auto; /* Adjust height automatically */
aspect-ratio: 16 / 9; /* Maintain aspect ratio, adjust as needed */
}
/* Style buttons */
button {
transition: all 0.2s ease-in-out; /* Smooth transitions */
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1); /* Subtle shadow */
}
button:hover {
transform: translateY(-1px); /* Slight lift on hover */
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.15); /* Enhanced shadow on hover */
}
button:active {
transform: translateY(0px); /* Press effect */
box-shadow: 0 1px 2px rgba(0, 0, 0, 0.1);
}
</style>
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;700&display=swap" rel="stylesheet">
</head>
<body class="bg-gray-100 p-4 md:p-8">
  <main class="max-w-4xl mx-auto bg-white p-6 rounded-lg shadow-md">
    <h1 class="text-2xl md:text-3xl font-bold text-center text-gray-800 mb-6">Visualizing ModernBERT Efficiency Features</h1>
    <!-- The canvas is purely visual; the live region below carries the textual explanation. -->
    <canvas id="animationCanvas" role="img" aria-label="Animated token diagram for the selected ModernBERT feature. The explanation below describes it."></canvas>
    <div class="flex flex-wrap justify-center gap-3 mt-6 mb-4">
      <button id="btnPadding" type="button" class="bg-blue-500 hover:bg-blue-600 text-white font-semibold py-2 px-4 rounded-lg shadow">Padding</button>
      <button id="btnUnpadding" type="button" class="bg-green-500 hover:bg-green-600 text-white font-semibold py-2 px-4 rounded-lg shadow">Unpadding</button>
      <button id="btnPacking" type="button" class="bg-purple-500 hover:bg-purple-600 text-white font-semibold py-2 px-4 rounded-lg shadow">Sequence Packing</button>
      <button id="btnLocalAttn" type="button" class="bg-indigo-500 hover:bg-indigo-600 text-white font-semibold py-2 px-4 rounded-lg shadow">Local Attention</button>
      <button id="btnGlobalAttn" type="button" class="bg-red-500 hover:bg-red-600 text-white font-semibold py-2 px-4 rounded-lg shadow">Global Attention</button>
      <button id="btnReset" type="button" class="bg-gray-500 hover:bg-gray-600 text-white font-semibold py-2 px-4 rounded-lg shadow">Reset</button>
    </div>
    <!-- Rewritten by the script on every button press; aria-live lets screen readers announce the new text. -->
    <div id="explanation" class="mt-4 p-4 bg-gray-50 border border-gray-200 rounded-lg text-gray-700 min-h-[100px]" aria-live="polite">
      Click a button above to see the animation and explanation.
    </div>
    <div class="mt-6 p-4 bg-gray-50 border border-gray-200 rounded-lg">
      <h2 class="font-semibold mb-2 text-gray-800">Legend:</h2>
      <div class="flex flex-wrap gap-2 items-center">
        <span class="token token-real">Token</span>
        <span class="token token-pad">Pad</span>
        <span class="token token-attend">Attending</span>
        <span class="token token-attended">Attended</span>
        <span class="token token-local">Local Window</span>
      </div>
    </div>
  </main>
<script>
// --- DOM handles ---
const canvas = document.querySelector('#animationCanvas');
const ctx = canvas.getContext('2d');
const explanationDiv = document.querySelector('#explanation');

// --- Layout (canvas drawing units) ---
const tokenWidth = 45;  // width of one token box
const tokenHeight = 30; // height of one token box
const padding = 10;     // gap between neighbouring tokens; row spacing and margins use multiples of it

// --- Timing ---
const animationSpeed = 2; // scales every per-frame increment; larger values finish sooner

// --- Demo data ---
// Variable-length example sequences. Every real label starts with 'T', which is
// what drawSequence uses to tell real tokens apart from padding.
let sequences = [
  ['T1', 'T2', 'T3', 'T4', 'T5'],
  ['T1', 'T2', 'T3'],
  ['T1', 'T2', 'T3', 'T4', 'T5', 'T6', 'T7'],
  ['T1', 'T2']
];
// NOTE(review): maxSeqLen is not referenced anywhere in this file; animatePadding
// pads to the longest sequence instead.
const maxSeqLen = 8;
const packMaxLen = 10;     // token budget for one packed row in the packing demo
const localWindowSize = 3; // local attention reaches floor(3 / 2) = 1 neighbour on each side

// --- Animation state ---
let animationFrameId = null; // id returned by the most recent requestAnimationFrame call
// --- Drawing Functions ---
// Trace a rounded-rectangle path on ctx. Uses the native roundRect when the browser
// provides it; otherwise builds the same outline from four arcTo corners, because
// calling a missing ctx.roundRect would throw and abort the animation frame.
function traceRoundRect(x, y, w, h, r) {
  if (typeof ctx.roundRect === 'function') {
    ctx.roundRect(x, y, w, h, r);
    return;
  }
  ctx.moveTo(x + r, y);
  ctx.arcTo(x + w, y, x + w, y + h, r);
  ctx.arcTo(x + w, y + h, x, y + h, r);
  ctx.arcTo(x, y + h, x, y, r);
  ctx.arcTo(x, y, x + w, y, r);
  ctx.closePath();
}

/**
 * Draw one token box (tokenWidth x tokenHeight) with its label centred inside.
 * @param {number} x - left edge in canvas units
 * @param {number} y - top edge in canvas units
 * @param {string} text - label drawn inside the box
 * @param {string} [type='real'] - fill colour: 'pad' (gray, gray text), 'attended' (amber),
 *   'local' (violet); 'real' or any other value gives blue with white text
 * @param {string} [highlight='none'] - 'attending' adds a 2px red outline; other values add nothing
 */
function drawToken(x, y, text, type = 'real', highlight = 'none') {
  ctx.font = '12px Inter';
  ctx.textAlign = 'center';
  ctx.textBaseline = 'middle';
  let bgColor = '#60a5fa'; // token-real (blue)
  let textColor = 'white';
  let borderColor = null;
  if (type === 'pad') {
    bgColor = '#e5e7eb'; // token-pad (gray)
    textColor = '#6b7280';
  } else if (type === 'attended') {
    bgColor = '#fbbf24'; // token-attended (amber)
  } else if (type === 'local') {
    bgColor = '#a78bfa'; // token-local (violet)
  }
  if (highlight === 'attending') {
    borderColor = '#f87171'; // token-attend (red)
  }
  // Background rectangle with rounded corners
  ctx.fillStyle = bgColor;
  ctx.beginPath();
  traceRoundRect(x, y, tokenWidth, tokenHeight, 5);
  ctx.fill();
  // Outline only for the attending token; it reuses the path traced above
  if (borderColor) {
    ctx.strokeStyle = borderColor;
    ctx.lineWidth = 2;
    ctx.stroke();
  }
  // Label
  ctx.fillStyle = textColor;
  ctx.fillText(text, x + tokenWidth / 2, y + tokenHeight / 2);
}
/**
 * Draw a horizontal row of tokens starting at (x, y).
 * Labels starting with 'T' are real tokens; anything else is drawn as padding.
 * @param {number} x - left edge of the first token
 * @param {number} y - top edge of the row
 * @param {Array<string>} sequence - token labels to draw, left to right
 * @param {Object<number, string>} [highlightMap={}] - per-index highlight:
 *   'attending' (red outline), 'attended' (amber fill) or 'local' (violet fill)
 */
function drawSequence(x, y, sequence, highlightMap = {}) {
  sequence.forEach((token, index) => {
    const tokenX = x + index * (tokenWidth + padding);
    const highlight = highlightMap[index] || 'none';
    let tokenType = (typeof token === 'string' && token.startsWith('T')) ? 'real' : 'pad';
    // drawToken picks the fill colour from `type` and only reads `highlight` for the
    // 'attending' outline, so fill-style highlights must be forwarded as the type.
    // Otherwise attended / local-window tokens would stay plain blue.
    if (highlight === 'attended' || highlight === 'local') {
      tokenType = highlight;
    }
    drawToken(tokenX, y, token, tokenType, highlight);
  });
}
/** Erase the canvas from the origin out to its backing-store width and height. */
function clearCanvas() {
  const { width, height } = canvas;
  ctx.clearRect(0, 0, width, height);
}
// --- Animation Logic ---
// 1. Padding Animation
/**
 * Padding demo: extend every sequence with 'Pad' tokens up to the longest
 * sequence's length, then reveal all rows column by column.
 * Cancels any running animation and replaces the explanation panel text.
 */
function animatePadding() {
  cancelAnimationFrame(animationFrameId); // Stop previous animation
  explanationDiv.innerHTML = `
<h3 class="font-semibold mb-1">Padding</h3>
Traditional processing requires all sequences in a batch to have the same length. Shorter sequences are "padded" with special (PAD) tokens to match the longest sequence. This wastes computation on meaningless tokens.
`;
  const targetMaxLength = Math.max(...sequences.map(s => s.length));
  const paddedSequences = sequences.map(seq => {
    const padsNeeded = targetMaxLength - seq.length;
    return [...seq, ...Array(padsNeeded).fill('Pad')];
  });
  const startX = padding * 2;
  const startY = padding * 3; // Start a little below the top edge
  const rowHeight = tokenHeight + padding * 2; // Extra vertical gap between rows
  const totalSteps = targetMaxLength;
  let progress = 0; // Fractional number of token columns revealed so far

  function step() {
    clearCanvas();
    // Every padded row holds exactly totalSteps tokens, so never reveal more columns than that.
    const visibleColumns = Math.min(totalSteps, Math.ceil(progress));
    paddedSequences.forEach((seq, i) => {
      drawSequence(startX, startY + i * rowHeight, seq.slice(0, visibleColumns));
    });
    if (progress < totalSteps) {
      progress += animationSpeed / 10; // Adjust speed
      animationFrameId = requestAnimationFrame(step);
    }
  }
  step();
}
// 2. Unpadding Animation
/**
 * Unpadding demo: start from the padded rows, fade only the 'Pad' tokens out,
 * then show just the original (unpadded) sequences.
 * Cancels any running animation and replaces the explanation panel text.
 */
function animateUnpadding() {
  cancelAnimationFrame(animationFrameId);
  explanationDiv.innerHTML = `
<h3 class="font-semibold mb-1">Unpadding</h3>
Unpadding removes the PAD tokens before processing. Sequences are treated individually (conceptually), avoiding wasted computation. ModernBERT concatenates these unpadded sequences.
`;
  const targetMaxLength = Math.max(...sequences.map(s => s.length));
  const paddedSequences = sequences.map(seq => [
    ...seq,
    ...Array(targetMaxLength - seq.length).fill('Pad'),
  ]);
  const startX = padding * 2;
  const startY = padding * 3;
  const rowHeight = tokenHeight + padding * 2;
  let fadeAmount = 1; // Opacity of the pad tokens: 1 = fully visible, 0 = invisible

  function step() {
    clearCanvas();
    // Check for completion before drawing, so a globalAlpha outside [0, 1] is never
    // assigned. Canvas ignores such values and would keep whatever alpha was last set.
    if (fadeAmount <= 0) {
      sequences.forEach((seq, i) => {
        drawSequence(startX, startY + i * rowHeight, seq);
      });
      return;
    }
    paddedSequences.forEach((seq, i) => {
      const y = startY + i * rowHeight;
      seq.forEach((token, index) => {
        const isPad = token === 'Pad';
        // Only the padding fades; real tokens stay fully opaque
        ctx.globalAlpha = isPad ? fadeAmount : 1;
        drawToken(startX + index * (tokenWidth + padding), y, token, isPad ? 'pad' : 'real');
      });
    });
    ctx.globalAlpha = 1; // Reset alpha for other drawing code
    fadeAmount -= 0.02 * animationSpeed; // Fade out speed
    animationFrameId = requestAnimationFrame(step);
  }
  step();
}
// 3. Sequence Packing Animation
// Greedily concatenate whole sequences, in order, into packs of at most `maxLen` tokens.
// A sequence longer than `maxLen` gets a pack of its own; empty packs are never emitted.
function packSequences(seqs, maxLen) {
  const packs = [];
  let current = [];
  seqs.forEach(seq => {
    if (current.length > 0 && current.length + seq.length > maxLen) {
      packs.push(current);
      current = [];
    }
    current.push(...seq);
  });
  if (current.length > 0) {
    packs.push(current);
  }
  return packs;
}

/**
 * Sequence-packing demo: cross-fade from the original rows to the packed rows
 * produced by packSequences(sequences, packMaxLen), then show only the packs.
 * Cancels any running animation and replaces the explanation panel text.
 */
function animateSequencePacking() {
  cancelAnimationFrame(animationFrameId);
  explanationDiv.innerHTML = `
<h3 class="font-semibold mb-1">Sequence Packing</h3>
After unpadding, sequences are concatenated (packed) together into longer sequences, up to the model's maximum length (e.g., ${packMaxLen} here). This maximizes GPU utilization by processing more real tokens per batch. Careful masking ensures tokens only attend within their original sequence.
`;
  const packedSequences = packSequences(sequences, packMaxLen);
  const totalSequences = packedSequences.length;
  const startX = padding * 2;
  const startY = padding * 3;
  const rowHeight = tokenHeight + padding * 2;
  let progress = 0; // Fractional number of packed rows shown so far

  function step() {
    clearCanvas();
    if (progress >= totalSequences) {
      // Final state: only the packed sequences
      packedSequences.forEach((pack, i) => {
        drawSequence(startX, startY + i * rowHeight, pack);
      });
      return;
    }
    // Original sequences fade out...
    ctx.globalAlpha = Math.max(0, 1 - progress / totalSequences);
    sequences.forEach((seq, i) => {
      drawSequence(startX, startY + i * rowHeight, seq);
    });
    // ...while the packed rows fade in (twice as fast), drawn over the first original rows.
    ctx.globalAlpha = Math.min(1, (progress / totalSequences) * 2);
    packedSequences.slice(0, Math.ceil(progress)).forEach((pack, i) => {
      drawSequence(startX, startY + i * rowHeight, pack);
    });
    ctx.globalAlpha = 1.0;
    progress += animationSpeed / 20; // Adjust speed
    animationFrameId = requestAnimationFrame(step);
  }
  step();
}
// 4. Attention Animations (Local/Global)
// Indices (never `midIndex` itself) that the token at `midIndex` attends to:
// every position for global attention, or the +/- floor(windowSize / 2)
// neighbourhood, clipped to the sequence, for local attention.
function attentionTargets(seqLength, midIndex, isGlobal, windowSize) {
  const halfWindow = Math.floor(windowSize / 2);
  const first = isGlobal ? 0 : Math.max(0, midIndex - halfWindow);
  const last = isGlobal ? seqLength - 1 : Math.min(seqLength - 1, midIndex + halfWindow);
  const targets = [];
  for (let i = first; i <= last; i++) {
    if (i !== midIndex) {
      targets.push(i);
    }
  }
  return targets;
}

/**
 * Attention demo on the longest sequence. The middle token is outlined as the
 * attending token, and its targets light up outward by distance: amber for
 * global attention, violet for the local window.
 * @param {boolean} isGlobal - true for global attention, false for local
 */
function animateAttention(isGlobal) {
  cancelAnimationFrame(animationFrameId);
  // Use the longest sequence for the demo (the first one wins ties)
  const seq = sequences.reduce(
    (longest, s) => (s.length > longest.length ? s : longest),
    sequences[0]
  );
  if (!seq) return; // Nothing to show without any sequences
  const midIndex = Math.floor(seq.length / 2); // Token that will 'attend'
  explanationDiv.innerHTML = `
<h3 class="font-semibold mb-1">${isGlobal ? 'Global' : 'Local'} Attention</h3>
Attention allows tokens to "look" at other tokens to understand context.
<b>${isGlobal ? 'Global Attention:' : 'Local Attention:'}</b>
${isGlobal
    ? 'Every token attends to every other token in the sequence. Powerful but computationally expensive, especially for long sequences.'
    : `Each token attends only to a fixed-size window of nearby tokens (e.g., +/- ${Math.floor(localWindowSize / 2)} here). Much faster for long sequences.`
} ModernBERT alternates between these layers.
`;
  const targets = attentionTargets(seq.length, midIndex, isGlobal, localWindowSize);
  const targetType = isGlobal ? 'attended' : 'local';
  // A target at distance d appears once progress >= d / reach. Local keeps the
  // window/2 pacing; global ripples out so the farthest token appears last.
  // Deterministic, so tokens no longer flicker from frame to frame.
  const reach = isGlobal
    ? Math.max(1, midIndex, seq.length - 1 - midIndex)
    : localWindowSize / 2;
  const startX = padding * 2;
  const startY = padding * 3;
  let highlightProgress = 0; // 0 to 1

  function step() {
    const finished = highlightProgress >= 1;
    const current = Math.min(1, highlightProgress);
    const highlightMap = { [midIndex]: 'attending' };
    targets.forEach(i => {
      // The final frame highlights every target regardless of pacing
      if (finished || current >= Math.abs(i - midIndex) / reach) {
        highlightMap[i] = targetType;
      }
    });
    clearCanvas();
    drawSequence(startX, startY, seq, highlightMap);
    if (!finished) {
      highlightProgress += 0.01 * animationSpeed;
      animationFrameId = requestAnimationFrame(step);
    }
  }
  step();
}
// --- Reset Function ---
/**
 * Return to the idle state: stop the pending animation frame, blank the
 * canvas, and restore the initial prompt in the explanation panel.
 */
function resetVisualization() {
  const pendingFrame = animationFrameId;
  cancelAnimationFrame(pendingFrame);
  clearCanvas();
  const idlePrompt = 'Click a button above to see the animation and explanation.';
  explanationDiv.innerHTML = idlePrompt;
}
// --- Event Listeners ---
// Wire each control button to its animation, in page order. Local and global
// attention share one animator, so they get small wrappers that pass the mode flag.
for (const [buttonId, onClick] of [
  ['btnPadding', animatePadding],
  ['btnUnpadding', animateUnpadding],
  ['btnPacking', animateSequencePacking],
  ['btnLocalAttn', () => animateAttention(false)],
  ['btnGlobalAttn', () => animateAttention(true)],
  ['btnReset', resetVisualization],
]) {
  document.getElementById(buttonId).addEventListener('click', onClick);
}
// --- Initial Setup & Resize ---
/**
 * Match the canvas backing store to its on-screen (CSS) size multiplied by
 * window.devicePixelRatio, so text and edges stay sharp on high-DPI screens.
 * The context is scaled by the same ratio, so all drawing code keeps using
 * CSS-pixel coordinates. When the size changes the visualization is reset,
 * because assigning canvas.width/height clears the bitmap and the context state.
 */
function resizeCanvas() {
  const dpr = window.devicePixelRatio || 1;
  const displayWidth = canvas.clientWidth;
  const displayHeight = canvas.clientHeight; // Follows the CSS aspect-ratio
  // Not laid out yet: keep the current backing store rather than collapsing it to 0x0.
  if (displayWidth === 0 || displayHeight === 0) return;
  const backingWidth = Math.round(displayWidth * dpr);
  const backingHeight = Math.round(displayHeight * dpr);
  if (canvas.width !== backingWidth || canvas.height !== backingHeight) {
    canvas.width = backingWidth;
    canvas.height = backingHeight;
    // Resizing reset the transform; re-apply the DPR scale.
    ctx.setTransform(dpr, 0, 0, dpr, 0, 0);
    // A running animation's frame would now be mis-sized, so start over.
    resetVisualization();
  }
}
// Initial resize and setup listener
// Re-fit the canvas whenever the window changes size.
window.addEventListener('resize', resizeCanvas);
// Take the first measurement only after the window load event, when layout has
// settled, and then start from a blank slate.
window.addEventListener('load', function onPageLoad() {
  resizeCanvas();
  resetVisualization();
});
</script>
</body>
</html>