File size: 2,405 Bytes
1398c16
 
76269ea
 
 
 
 
 
1398c16
 
76269ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# app.py
import gradio as gr
from model_utils import load_model_info, get_model_stats
from visualize import (
    visualize_attention,
    visualize_token_embeddings,
    plot_tokenization,
    compare_model_sizes
)

# Maps human-friendly dropdown labels to Hugging Face Hub model identifiers.
# Keys are shown in the UI; values are what the model-loading utilities consume.
MODEL_CHOICES = {
    "BERT (base)": "bert-base-uncased",
    "DistilBERT": "distilbert-base-uncased",
    "RoBERTa": "roberta-base",
    "GPT-2": "gpt2",
    "Electra": "google/electra-base-discriminator",
    "ALBERT": "albert-base-v2",
    "XLNet": "xlnet-base-cased"
}

def run_visualizer(model_name, text, layer, head):
    """Produce all analysis artifacts for one model/text pair.

    Parameters
    ----------
    model_name : str
        Display label selected in the dropdown (a key of ``MODEL_CHOICES``),
        or a raw Hugging Face model id (passed through unchanged).
    text : str
        Input text to tokenize and analyze.
    layer : int
        Transformer layer index for the attention visualization.
    head : int
        Attention head index within the chosen layer.

    Returns
    -------
    tuple
        ``(attention_plot, token_heatmap, token_plot, model_stats)`` suitable
        for the four Gradio output components.
    """
    # BUG FIX: gr.Dropdown passes the selected *label* (e.g. "BERT (base)"),
    # not the hub id. Resolve it via MODEL_CHOICES so load_model_info receives
    # a loadable identifier; unknown values fall through untouched so callers
    # that already pass a raw model id keep working.
    hub_id = MODEL_CHOICES.get(model_name, model_name)

    model_info = load_model_info(hub_id)
    attention_plot = visualize_attention(model_info, text, layer, head)
    token_heatmap = visualize_token_embeddings(model_info, text)
    token_plot = plot_tokenization(model_info, text)
    model_stats = get_model_stats(model_info)

    return attention_plot, token_heatmap, token_plot, model_stats

def run_comparison_chart():
    """Build a chart comparing the sizes of every model in MODEL_CHOICES."""
    model_ids = MODEL_CHOICES.values()
    return compare_model_sizes(model_ids)

with gr.Blocks() as demo:
    gr.Markdown("""
    # πŸ€– Transformer Model Visualizer
    Explore attention heads, token embeddings, and tokenizer behavior across popular transformer models.
    """)

    with gr.Row():
        model_selector = gr.Dropdown(label="Choose Model", choices=list(MODEL_CHOICES.keys()), value="BERT (base)")
        input_text = gr.Textbox(label="Input Text", placeholder="Enter text to analyze")

    with gr.Row():
        layer_slider = gr.Slider(minimum=0, maximum=11, step=1, value=0, label="Layer")
        head_slider = gr.Slider(minimum=0, maximum=11, step=1, value=0, label="Attention Head")

    run_btn = gr.Button("Run Analysis")

    with gr.Row():
        attention_output = gr.Plot(label="Self-Attention Visualization")
        embedding_output = gr.Plot(label="Token Embedding Heatmap")

    with gr.Row():
        token_output = gr.Plot(label="Tokenization Overview")
        model_output = gr.JSON(label="Model Details")

    run_btn.click(
        fn=run_visualizer,
        inputs=[model_selector, input_text, layer_slider, head_slider],
        outputs=[attention_output, embedding_output, token_output, model_output]
    )

    with gr.Accordion("πŸ“Š Compare Model Sizes", open=False):
        compare_btn = gr.Button("Generate Comparison Chart")
        comparison_output = gr.Plot()
        compare_btn.click(fn=run_comparison_chart, outputs=comparison_output)

demo.launch()