File size: 3,103 Bytes
9d0184c
 
 
 
 
 
 
 
 
1398c16
9d0184c
76269ea
 
9d0184c
 
 
1398c16
 
9d0184c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76269ea
 
9d0184c
 
 
 
 
76269ea
 
9d0184c
76269ea
9d0184c
 
76269ea
 
9d0184c
 
 
 
 
 
76269ea
 
9d0184c
 
 
76269ea
9d0184c
76269ea
9d0184c
 
76269ea
9d0184c
 
76269ea
9d0184c
 
 
 
 
 
 
 
 
 
 
 
76269ea
 
 
9d0184c
 
 
76269ea
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# 🧠 Summary of Features:
# πŸ”½ Dropdown to choose from 7 prebuilt models
# πŸ”„ Updates attention layer & head slider limits based on the model
# πŸ“Š Visualizes attention maps for selected head/layer
# 🧩 Tokenized words preview
# πŸ“‰ Embeddings reduced to 2D using PCA
# πŸ“¦ Model size bar chart across models
# πŸ“‹ Full model config info in JSON viewer

import gradio as gr
from model_utils import MODEL_OPTIONS, load_model, get_model_info
from visualize import (
    visualize_attention,
    show_tokenization,
    show_embeddings,
    compare_model_sizes,
)

# Eagerly load the first model so the app has a working tokenizer/model
# pair before any dropdown interaction happens.
DEFAULT_MODEL_NAME = list(MODEL_OPTIONS.values())[0]
tokenizer, model = load_model(DEFAULT_MODEL_NAME)

# Module-level shared state; rebound by update_model() on dropdown change.
current_tokenizer, current_model = tokenizer, model


def update_model(selected_model_name):
    """Swap the active model/tokenizer and resize the layer/head sliders.

    Args:
        selected_model_name: Display name chosen in the dropdown; used as a
            key into MODEL_OPTIONS to resolve the actual model identifier.

    Returns:
        A 3-tuple of (model info dict, gr.update for the layer slider,
        gr.update for the head slider). Slider maxima are zero-based, hence
        the ``- 1``; both sliders are reset to 0.
    """
    global current_tokenizer, current_model
    repo_id = MODEL_OPTIONS[selected_model_name]
    current_tokenizer, current_model = load_model(repo_id)
    details = get_model_info(current_model)

    # Fall back to 1 so a missing key still yields a valid (0..0) slider.
    layer_count = details.get("Number of Layers", 1)
    head_count = details.get("Number of Attention Heads", 1)

    layer_update = gr.update(maximum=layer_count - 1, value=0)
    head_update = gr.update(maximum=head_count - 1, value=0)
    return (details, layer_update, head_update)


def run_all_visualizations(text, layer, head):
    """Produce all three figures for the current model from one text input.

    Args:
        text: Input sentence to analyze.
        layer: Zero-based attention layer index.
        head: Zero-based attention head index.

    Returns:
        Tuple of (attention figure, tokenization figure, embedding figure),
        matching the order of the bound output plots.
    """
    # Tuple elements evaluate left-to-right, preserving the original
    # attention -> tokenization -> embeddings call order.
    return (
        visualize_attention(current_tokenizer, current_model, text, layer, head),
        show_tokenization(current_tokenizer, text),
        show_embeddings(current_tokenizer, current_model, text),
    )


# UI
with gr.Blocks() as demo:
    gr.Markdown("## 🔍 Transformer Explorer")
    gr.Markdown("Explore attention, tokenization, and embedding visualizations for various transformer models.")

    # Derive initial slider bounds and the info panel from the preloaded
    # default model instead of hard-coding maximum=11: a default model with
    # fewer layers/heads would otherwise allow out-of-range selections until
    # the dropdown is changed once, and the JSON panel would start empty.
    _default_info = get_model_info(model)
    _max_layer = _default_info.get("Number of Layers", 12) - 1
    _max_head = _default_info.get("Number of Attention Heads", 12) - 1

    with gr.Row():
        model_dropdown = gr.Dropdown(
            label="Choose a model",
            choices=list(MODEL_OPTIONS.keys()),
            value=list(MODEL_OPTIONS.keys())[0],
        )
        model_info = gr.JSON(label="Model Info", value=_default_info)

    with gr.Row():
        text_input = gr.Textbox(label="Enter text", value="The quick brown fox jumps over the lazy dog.")
        # Zero-based indices; maxima track the currently selected model.
        layer_slider = gr.Slider(label="Layer", minimum=0, maximum=_max_layer, step=1, value=0)
        head_slider = gr.Slider(label="Head", minimum=0, maximum=_max_head, step=1, value=0)

    run_button = gr.Button("Run Visualizations")

    with gr.Tab("📊 Attention"):
        attention_plot = gr.Plot()

    with gr.Tab("🧩 Tokenization"):
        token_plot = gr.Plot()

    with gr.Tab("📉 Embeddings"):
        embedding_plot = gr.Plot()

    with gr.Tab("📦 Model Size Comparison"):
        # Static chart: computed once at build time, not per interaction.
        model_compare_plot = gr.Plot(value=compare_model_sizes())

    # Event binding: dropdown swaps the model and rescales the sliders;
    # the button re-renders the three per-text figures.
    model_dropdown.change(fn=update_model, inputs=[model_dropdown], outputs=[model_info, layer_slider, head_slider])
    run_button.click(
        fn=run_all_visualizations,
        inputs=[text_input, layer_slider, head_slider],
        outputs=[attention_plot, token_plot, embedding_plot],
    )


# Launch the Gradio server only when run as a script (not on import).
if __name__ == "__main__":
    demo.launch()