<!doctype html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <title>ModernBERT Features Visualization</title>
  <!-- Open connections to the font origins early so the Inter stylesheet and font files load sooner. -->
  <link rel="preconnect" href="https://fonts.googleapis.com">
  <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
  <link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;700&amp;display=swap" rel="stylesheet">
  <script src="https://cdn.tailwindcss.com"></script>
  <style>
    /* Custom styles for better visualization */
    body {
      font-family: 'Inter', sans-serif; /* Using Inter font as standard */
    }

    canvas {
      display: block;
      background-color: #f3f4f6; /* Light gray background */
      border-radius: 0.5rem; /* Rounded corners */
      border: 1px solid #d1d5db; /* Gray border */
      margin: 1rem auto; /* Center canvas with margin */
    }

    /* Legend chips; the colours below are mirrored by the canvas drawing code. */
    .token {
      display: inline-block; /* Align tokens nicely */
      padding: 0.3rem 0.6rem;
      margin: 0.2rem;
      border-radius: 0.375rem; /* Rounded corners for tokens */
      font-size: 0.875rem; /* Smaller font size */
      font-weight: 500;
      min-width: 30px; /* Ensure minimum width */
      text-align: center;
    }
    .token-real { background-color: #60a5fa; color: white; } /* Blue for real tokens */
    .token-pad { background-color: #e5e7eb; color: #6b7280; } /* Gray for padding */
    .token-attend { border: 2px solid #f87171; } /* Red border for attending token */
    .token-attended { background-color: #fbbf24; color: white; } /* Amber for attended tokens */
    .token-local { background-color: #a78bfa; color: white; } /* Violet for local attention window */

    /* Ensure canvas is responsive */
    #animationCanvas {
      width: 100%;
      max-width: 800px; /* Limit max width for larger screens */
      height: auto; /* Adjust height automatically */
      aspect-ratio: 16 / 9; /* Maintain aspect ratio, adjust as needed */
    }

    /* Style buttons */
    button {
      transition: all 0.2s ease-in-out; /* Smooth transitions */
      box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1); /* Subtle shadow */
    }
    button:hover {
      transform: translateY(-1px); /* Slight lift on hover */
      box-shadow: 0 4px 8px rgba(0, 0, 0, 0.15); /* Enhanced shadow on hover */
    }
    button:active {
      transform: translateY(0px); /* Press effect */
      box-shadow: 0 1px 2px rgba(0, 0, 0, 0.1);
    }

    /* Respect users who ask for less motion: keep the shadows, drop the lift animation. */
    @media (prefers-reduced-motion: reduce) {
      button {
        transition: none;
      }
      button:hover,
      button:active {
        transform: none;
      }
    }
  </style>
</head>
<body class="bg-gray-100 p-4 md:p-8">
<main class="max-w-4xl mx-auto bg-white p-6 rounded-lg shadow-md">
  <h1 class="text-2xl md:text-3xl font-bold text-center text-gray-800 mb-6">Visualizing ModernBERT Efficiency Features</h1>

  <!-- The canvas is purely visual; give assistive technology a name and a text fallback. -->
  <canvas id="animationCanvas" role="img" aria-label="Animated diagram of token sequences for the selected ModernBERT feature">
    Your browser does not support the canvas element. The text explanation below the buttons describes each feature.
  </canvas>

  <!-- Each button starts one animation; ids are looked up by the page script. -->
  <div class="flex flex-wrap justify-center gap-3 mt-6 mb-4">
    <button type="button" id="btnPadding" class="bg-blue-500 hover:bg-blue-600 text-white font-semibold py-2 px-4 rounded-lg shadow">Padding</button>
    <button type="button" id="btnUnpadding" class="bg-green-500 hover:bg-green-600 text-white font-semibold py-2 px-4 rounded-lg shadow">Unpadding</button>
    <button type="button" id="btnPacking" class="bg-purple-500 hover:bg-purple-600 text-white font-semibold py-2 px-4 rounded-lg shadow">Sequence Packing</button>
    <button type="button" id="btnLocalAttn" class="bg-indigo-500 hover:bg-indigo-600 text-white font-semibold py-2 px-4 rounded-lg shadow">Local Attention</button>
    <button type="button" id="btnGlobalAttn" class="bg-red-500 hover:bg-red-600 text-white font-semibold py-2 px-4 rounded-lg shadow">Global Attention</button>
    <button type="button" id="btnReset" class="bg-gray-500 hover:bg-gray-600 text-white font-semibold py-2 px-4 rounded-lg shadow">Reset</button>
  </div>

  <!-- The script replaces this content on every button click; announce the change politely. -->
  <div id="explanation" class="mt-4 p-4 bg-gray-50 border border-gray-200 rounded-lg text-gray-700 min-h-[100px]" aria-live="polite">
    Click a button above to see the animation and explanation.
  </div>

  <section class="mt-6 p-4 bg-gray-50 border border-gray-200 rounded-lg">
    <h2 class="font-semibold mb-2 text-gray-800">Legend:</h2>
    <div class="flex flex-wrap gap-2 items-center">
      <span class="token token-real">Token</span>
      <span class="token token-pad">Pad</span>
      <span class="token token-attend">Attending</span>
      <span class="token token-attended">Attended</span>
      <span class="token token-local">Local Window</span>
    </div>
  </section>
</main>
<script>
// --- DOM handles shared by every drawing and animation function ---
const canvas = document.getElementById('animationCanvas');
const ctx = canvas.getContext('2d');
const explanationDiv = document.getElementById('explanation');

// --- Configuration ---
// Token box geometry, in canvas units.
const tokenWidth = 45;
const tokenHeight = 30;
// Gap used around tokens and (in multiples) between rows.
const padding = 10;
// Scales every per-frame increment; larger values finish the animations sooner.
const animationSpeed = 2;

// Demo batch: four sequences of 5, 3, 7 and 2 tokens, each labelled T1..Tn.
let sequences = [5, 3, 7, 2].map(
  (length) => Array.from({ length }, (_, i) => `T${i + 1}`)
);

const maxSeqLen = 8; // Max length for the padding example
const packMaxLen = 10; // Token budget of one packed row in the packing example
const localWindowSize = 3; // Local attention window width: the token itself plus one neighbour per side

// Handle of the most recently requested animation frame, so a new animation can cancel the old one.
let animationFrameId = null;
// --- Drawing Functions --- | |
// Draw a single token box with a centred label on the shared canvas context.
//   x, y      - top-left corner of the box
//   text      - label drawn in the middle of the box
//   type      - base kind of the token: 'real' (blue), 'pad' (gray),
//               'attended' (amber) or 'local' (violet); anything else is drawn as 'real'
//   highlight - 'attending' adds a red border; 'attended' / 'local' recolour the box.
//               Callers that go through drawSequence can only express attention state via
//               this argument, so it has to affect the fill colour as well as the border.
function drawToken(x, y, text, type = 'real', highlight = 'none') {
  ctx.font = '12px Inter';
  ctx.textAlign = 'center';
  ctx.textBaseline = 'middle';

  // An attention highlight takes precedence over the token's base type when choosing colours.
  const colorKey = (highlight === 'attended' || highlight === 'local') ? highlight : type;

  let bgColor = '#60a5fa'; // token-real (blue)
  let textColor = 'white';
  switch (colorKey) {
    case 'pad':
      bgColor = '#e5e7eb'; // token-pad (gray)
      textColor = '#6b7280';
      break;
    case 'attended':
      bgColor = '#fbbf24'; // token-attended (amber)
      break;
    case 'local':
      bgColor = '#a78bfa'; // token-local (violet)
      break;
    default:
      break;
  }

  const borderColor = highlight === 'attending' ? '#f87171' : null; // token-attend (red)

  // Draw background rectangle
  ctx.fillStyle = bgColor;
  ctx.beginPath();
  if (typeof ctx.roundRect === 'function') {
    ctx.roundRect(x, y, tokenWidth, tokenHeight, 5);
  } else {
    // Browsers without roundRect still get a (square-cornered) box instead of a TypeError.
    ctx.rect(x, y, tokenWidth, tokenHeight);
  }
  ctx.fill();

  // Draw border if needed; the path built above is reused for the stroke.
  if (borderColor) {
    ctx.strokeStyle = borderColor;
    ctx.lineWidth = 2;
    ctx.stroke();
  }

  // Draw text
  ctx.fillStyle = textColor;
  ctx.fillText(text, x + tokenWidth / 2, y + tokenHeight / 2);
}
// Draw one row of tokens starting at (x, y), left to right.
//   sequence     - array of token labels; string labels starting with 'T' are
//                  passed to drawToken as 'real', everything else as 'pad'
//   highlightMap - optional map from token index to the highlight value handed
//                  to drawToken; indices without an entry use 'none'
function drawSequence(x, y, sequence, highlightMap = {}) {
  const stride = tokenWidth + padding; // horizontal distance between neighbouring tokens
  sequence.forEach(function (label, position) {
    const looksReal = typeof label === 'string' && label.startsWith('T');
    drawToken(
      x + position * stride,
      y,
      label,
      looksReal ? 'real' : 'pad',
      highlightMap[position] || 'none'
    );
  });
}
// Erase the whole drawing surface, using the canvas's current width and height.
function clearCanvas() {
  const { width, height } = canvas;
  ctx.clearRect(0, 0, width, height);
}
// --- Animation Logic --- | |
// 1. Padding Animation | |
// Padding demo: pad every sequence with 'Pad' tokens up to the length of the
// longest one, then reveal the padded batch column by column.
function animatePadding() {
  cancelAnimationFrame(animationFrameId); // Stop previous animation
  explanationDiv.innerHTML = `
<h3 class="font-semibold mb-1">Padding</h3>
Traditional processing requires all sequences in a batch to have the same length. Shorter sequences are "padded" with special (PAD) tokens to match the longest sequence. This wastes computation on meaningless tokens.
`;

  const targetMaxLength = Math.max(...sequences.map(s => s.length));
  const paddedSequences = sequences.map(seq => {
    const padsNeeded = targetMaxLength - seq.length;
    return [...seq, ...Array(padsNeeded).fill('Pad')];
  });

  let progress = 0; // Number of token columns revealed so far (fractional between frames)
  const totalSteps = targetMaxLength;

  function step() {
    clearCanvas();
    const startY = padding * 3; // Leave some room above the first row
    paddedSequences.forEach((seq, i) => {
      const y = startY + i * (tokenHeight + padding * 2); // One row per sequence
      // ceil() makes a column appear as soon as progress passes the previous whole number.
      const displayLength = Math.min(seq.length, Math.ceil(progress));
      drawSequence(padding * 2, y, seq.slice(0, displayLength));
    });
    if (progress < totalSteps) {
      progress += animationSpeed / 10; // Adjust speed
      animationFrameId = requestAnimationFrame(step);
    }
  }

  step();
}
// 2. Unpadding Animation | |
// Unpadding demo: start from the padded batch, fade the 'Pad' tokens out, and
// finish with only the original sequences on screen.
function animateUnpadding() {
  cancelAnimationFrame(animationFrameId);
  explanationDiv.innerHTML = `
<h3 class="font-semibold mb-1">Unpadding</h3>
Unpadding removes the PAD tokens before processing. Sequences are treated individually (conceptually), avoiding wasted computation. ModernBERT concatenates these unpadded sequences.
`;

  const targetMaxLength = Math.max(...sequences.map(s => s.length));
  const paddedSequences = sequences.map(seq => {
    const padsNeeded = targetMaxLength - seq.length;
    return [...seq, ...Array(padsNeeded).fill('Pad')];
  });

  const startX = padding * 2;
  const startY = padding * 3;
  const rowPitch = tokenHeight + padding * 2; // vertical distance between rows

  let fadeAmount = 1; // Opacity of the pad tokens: 1 = fully visible, 0 = invisible

  function step() {
    clearCanvas();

    if (fadeAmount <= 0) {
      // Fade finished: show only the original, unpadded sequences and stop.
      sequences.forEach((seq, i) => {
        drawSequence(startX, startY + i * rowPitch, seq);
      });
      return;
    }

    paddedSequences.forEach((seq, i) => {
      const y = startY + i * rowPitch;
      seq.forEach((token, index) => {
        const tokenX = startX + index * (tokenWidth + padding);
        const isPad = token === 'Pad';
        // Only the padding tokens fade; real tokens stay fully opaque.
        ctx.globalAlpha = isPad ? fadeAmount : 1;
        drawToken(tokenX, y, token, isPad ? 'pad' : 'real');
      });
    });
    ctx.globalAlpha = 1; // Reset alpha so later drawing is unaffected

    // Clamp at 0: assigning a negative globalAlpha is ignored by the canvas and
    // would leave whatever alpha was set previously.
    fadeAmount = Math.max(0, fadeAmount - 0.02 * animationSpeed);
    animationFrameId = requestAnimationFrame(step);
  }

  step();
}
// 3. Sequence Packing Animation | |
// Sequence packing demo: greedily concatenate the sequences, in order, into
// rows of at most packMaxLen tokens, then cross-fade from the original rows to
// the packed rows (both are drawn at the same position).
function animateSequencePacking() {
  cancelAnimationFrame(animationFrameId);
  explanationDiv.innerHTML = `
<h3 class="font-semibold mb-1">Sequence Packing</h3>
After unpadding, sequences are concatenated (packed) together into longer sequences, up to the model's maximum length (e.g., ${packMaxLen} here). This maximizes GPU utilization by processing more real tokens per batch. Careful masking ensures tokens only attend within their original sequence.
`;

  // Build the packed rows. A row is closed only when it already holds tokens
  // and the next sequence would overflow it, so no empty row is ever emitted.
  // A single sequence longer than packMaxLen still gets a row of its own.
  const packedSequences = [];
  let currentPack = [];
  sequences.forEach(seq => {
    if (currentPack.length > 0 && currentPack.length + seq.length > packMaxLen) {
      packedSequences.push(currentPack);
      currentPack = [];
    }
    currentPack.push(...seq);
  });
  if (currentPack.length > 0) {
    packedSequences.push(currentPack);
  }

  const totalSequences = packedSequences.length;
  const startX = padding * 2;
  const startY = padding * 3;
  const rowPitch = tokenHeight + padding * 2; // vertical distance between rows

  // Draw a list of rows, one below the other, starting at the top-left origin.
  const drawRows = rows => {
    rows.forEach((row, i) => {
      drawSequence(startX, startY + i * rowPitch, row);
    });
  };

  let progress = 0; // Number of packed rows revealed so far (fractional between frames)

  function step() {
    clearCanvas();

    if (progress >= totalSequences) {
      // Final state: only packed sequences. Also reached immediately when there
      // is nothing to pack, which avoids dividing by zero below.
      drawRows(packedSequences);
      return;
    }

    const fraction = progress / totalSequences;

    // Original sequences fade out...
    ctx.globalAlpha = Math.max(0, 1 - fraction);
    drawRows(sequences);

    // ...while the packed rows revealed so far fade in twice as fast.
    ctx.globalAlpha = Math.min(1, fraction * 2);
    drawRows(packedSequences.slice(0, Math.ceil(progress)));

    ctx.globalAlpha = 1.0;

    progress += animationSpeed / 20; // Adjust speed
    animationFrameId = requestAnimationFrame(step);
  }

  step();
}
// 4. Attention Animations (Local/Global) | |
// Attention demo on sequences[2]: the middle token is marked 'attending' and
// the tokens it attends to are marked progressively in the highlight map that
// is handed to drawSequence.
//   isGlobal - true: every other token is marked 'attended' (chosen at random
//              per frame with a probability that grows to 1);
//              false: only tokens inside the local window are marked 'local',
//              nearer tokens first.
function animateAttention(isGlobal) {
  cancelAnimationFrame(animationFrameId);

  const seq = sequences[2]; // Use the longest sequence for demo
  const midIndex = Math.floor(seq.length / 2); // Token that will 'attend'

  explanationDiv.innerHTML = `
<h3 class="font-semibold mb-1">${isGlobal ? 'Global' : 'Local'} Attention</h3>
Attention allows tokens to "look" at other tokens to understand context.
<b>${isGlobal ? 'Global Attention:' : 'Local Attention:'}</b>
${isGlobal
    ? 'Every token attends to every other token in the sequence. Powerful but computationally expensive, especially for long sequences.'
    : `Each token attends only to a fixed-size window of nearby tokens (e.g., +/- ${Math.floor(localWindowSize / 2)} here). Much faster for long sequences.`
} ModernBERT alternates between these layers.
`;

  const originX = padding * 2;
  const originY = padding * 3;

  // Inclusive bounds of the local window around the attending token, clipped to the sequence.
  const reach = Math.floor(localWindowSize / 2);
  const firstInWindow = Math.max(0, midIndex - reach);
  const lastInWindow = Math.min(seq.length - 1, midIndex + reach);

  // Highlight map for the settled end state: every eligible token is marked.
  const settledHighlights = () => {
    const marks = { [midIndex]: 'attending' };
    if (isGlobal) {
      for (let i = 0; i < seq.length; i++) {
        if (i !== midIndex) marks[i] = 'attended';
      }
    } else {
      for (let i = firstInWindow; i <= lastInWindow; i++) {
        if (i !== midIndex) marks[i] = 'local';
      }
    }
    return marks;
  };

  let highlightProgress = 0; // 0 to 1

  const step = () => {
    clearCanvas();

    const shown = Math.min(1, highlightProgress);
    const marks = { [midIndex]: 'attending' }; // The token doing the attending

    if (isGlobal) {
      for (let i = 0; i < seq.length; i++) {
        if (i === midIndex) continue;
        // Each token is marked with probability equal to the current progress.
        if (Math.random() < shown) {
          marks[i] = 'attended';
        }
      }
    } else {
      for (let i = firstInWindow; i <= lastInWindow; i++) {
        if (i === midIndex) continue;
        // Closer tokens have a lower threshold, so they are marked sooner.
        const threshold = Math.abs(i - midIndex) / (localWindowSize / 2);
        if (shown >= threshold) {
          marks[i] = 'local';
        }
      }
    }

    drawSequence(originX, originY, seq, marks);

    if (highlightProgress < 1) {
      highlightProgress += 0.01 * animationSpeed;
      animationFrameId = requestAnimationFrame(step);
      return;
    }

    // Animation finished: redraw once with the complete highlight set.
    clearCanvas();
    drawSequence(originX, originY, seq, settledHighlights());
  };

  step();
}
// --- Reset Function --- | |
// Return the page to its idle state: cancel the pending animation frame,
// blank the canvas and restore the default hint in the explanation box.
function resetVisualization() {
  const idleMessage = 'Click a button above to see the animation and explanation.';

  cancelAnimationFrame(animationFrameId);
  clearCanvas();
  explanationDiv.innerHTML = idleMessage;
}
// --- Event Listeners --- | |
// Wire each control button (by element id) to the function it triggers.
for (const [buttonId, onClick] of [
  ['btnPadding', animatePadding],
  ['btnUnpadding', animateUnpadding],
  ['btnPacking', animateSequencePacking],
  ['btnLocalAttn', () => animateAttention(false)],
  ['btnGlobalAttn', () => animateAttention(true)],
  ['btnReset', resetVisualization],
]) {
  document.getElementById(buttonId).addEventListener('click', onClick);
}
// --- Initial Setup & Resize --- | |
// Keep the canvas backing store in step with its displayed (CSS) size.
// The backing store is sized in device pixels so drawing stays sharp on
// high-density screens, and the context is scaled so all drawing code can keep
// working in CSS pixels. Any change of size resets the visualization.
function resizeCanvas() {
  // Never scale below 1: clearCanvas() clears canvas.width x canvas.height in
  // drawing units, which only covers the whole surface when the scale is >= 1.
  const pixelRatio = Math.max(1, window.devicePixelRatio || 1);
  const displayWidth = Math.round(canvas.clientWidth * pixelRatio);
  const displayHeight = Math.round(canvas.clientHeight * pixelRatio);

  if (canvas.width !== displayWidth || canvas.height !== displayHeight) {
    // Assigning width/height clears the canvas and resets the context state,
    // so the scale has to be re-applied afterwards.
    canvas.width = displayWidth;
    canvas.height = displayHeight;
    ctx.setTransform(pixelRatio, 0, 0, pixelRatio, 0, 0);
    // For simplicity, a running animation is not resumed; the view is reset.
    resetVisualization();
  }
}
// Initial resize and setup listener | |
// Re-fit the canvas whenever the window changes size.
window.addEventListener('resize', resizeCanvas);
// Once the page has loaded and layout is final, size the canvas and start from a clean state.
window.addEventListener('load', function onPageLoad() {
  resizeCanvas();
  resetVisualization();
});
</script>
</body>
</html>