marianvd-01 committed
Commit 9d0184c · verified · 1 Parent(s): 251e282

Update app.py

Files changed (1)
app.py +74 -47
app.py CHANGED
@@ -1,69 +1,96 @@
- # app.py
  import gradio as gr
- from model_utils import load_model_info, get_model_stats
  from visualize import (
      visualize_attention,
-     visualize_token_embeddings,
-     plot_tokenization,
-     compare_model_sizes
  )

- MODEL_CHOICES = {
-     "BERT (base)": "bert-base-uncased",
-     "DistilBERT": "distilbert-base-uncased",
-     "RoBERTa": "roberta-base",
-     "GPT-2": "gpt2",
-     "Electra": "google/electra-base-discriminator",
-     "ALBERT": "albert-base-v2",
-     "XLNet": "xlnet-base-cased"
- }

- def run_visualizer(model_name, text, layer, head):
-     model_info = load_model_info(model_name)
-     attention_plot = visualize_attention(model_info, text, layer, head)
-     token_heatmap = visualize_token_embeddings(model_info, text)
-     token_plot = plot_tokenization(model_info, text)
-     model_stats = get_model_stats(model_info)

-     return attention_plot, token_heatmap, token_plot, model_stats

- def run_comparison_chart():
-     return compare_model_sizes(MODEL_CHOICES.values())

  with gr.Blocks() as demo:
-     gr.Markdown("""
-     # 🤖 Transformer Model Visualizer
-     Explore attention heads, token embeddings, and tokenizer behavior across popular transformer models.
-     """)

      with gr.Row():
-         model_selector = gr.Dropdown(label="Choose Model", choices=list(MODEL_CHOICES.keys()), value="BERT (base)")
-         input_text = gr.Textbox(label="Input Text", placeholder="Enter text to analyze")

      with gr.Row():
-         layer_slider = gr.Slider(minimum=0, maximum=11, step=1, value=0, label="Layer")
-         head_slider = gr.Slider(minimum=0, maximum=11, step=1, value=0, label="Attention Head")

-     run_btn = gr.Button("Run Analysis")

-     with gr.Row():
-         attention_output = gr.Plot(label="Self-Attention Visualization")
-         embedding_output = gr.Plot(label="Token Embedding Heatmap")

-     with gr.Row():
-         token_output = gr.Plot(label="Tokenization Overview")
-         model_output = gr.JSON(label="Model Details")

-     run_btn.click(
-         fn=run_visualizer,
-         inputs=[model_selector, input_text, layer_slider, head_slider],
-         outputs=[attention_output, embedding_output, token_output, model_output]
      )

-     with gr.Accordion("📊 Compare Model Sizes", open=False):
-         compare_btn = gr.Button("Generate Comparison Chart")
-         comparison_output = gr.Plot()
-         compare_btn.click(fn=run_comparison_chart, outputs=comparison_output)

- demo.launch()

+ # 🧠 Summary of Features:
+ # 🔽 Dropdown to choose from 7 prebuilt models
+ # 🔄 Updates attention layer & head slider limits based on the model
+ # 📊 Visualizes attention maps for selected head/layer
+ # 🧩 Tokenized words preview
+ # 📉 Embeddings reduced to 2D using PCA
+ # 📦 Model size bar chart across models
+ # 📋 Full model config info in JSON viewer
+
  import gradio as gr
+ from model_utils import MODEL_OPTIONS, load_model, get_model_info
  from visualize import (
      visualize_attention,
+     show_tokenization,
+     show_embeddings,
+     compare_model_sizes,
  )

+ # Initial load
+ DEFAULT_MODEL_NAME = list(MODEL_OPTIONS.values())[0]
+ tokenizer, model = load_model(DEFAULT_MODEL_NAME)
+
+ # Shared state
+ current_tokenizer = tokenizer
+ current_model = model
+
+
+ def update_model(selected_model_name):
+     global current_tokenizer, current_model
+     model_id = MODEL_OPTIONS[selected_model_name]
+     current_tokenizer, current_model = load_model(model_id)
+     info = get_model_info(current_model)
+
+     # Update layer/head sliders based on model
+     num_layers = info.get("Number of Layers", 1)
+     num_heads = info.get("Number of Attention Heads", 1)
+
+     return (
+         info,
+         gr.update(maximum=num_layers - 1, value=0),
+         gr.update(maximum=num_heads - 1, value=0),
+     )

+ def run_all_visualizations(text, layer, head):
+     attention_fig = visualize_attention(current_tokenizer, current_model, text, layer, head)
+     token_fig = show_tokenization(current_tokenizer, text)
+     embedding_fig = show_embeddings(current_tokenizer, current_model, text)
+     return attention_fig, token_fig, embedding_fig

+ # UI
  with gr.Blocks() as demo:
+     gr.Markdown("## 🔍 Transformer Explorer")
+     gr.Markdown("Explore attention, tokenization, and embedding visualizations for various transformer models.")

      with gr.Row():
+         model_dropdown = gr.Dropdown(
+             label="Choose a model",
+             choices=list(MODEL_OPTIONS.keys()),
+             value=list(MODEL_OPTIONS.keys())[0],
+         )
+         model_info = gr.JSON(label="Model Info")

      with gr.Row():
+         text_input = gr.Textbox(label="Enter text", value="The quick brown fox jumps over the lazy dog.")
+         layer_slider = gr.Slider(label="Layer", minimum=0, maximum=11, step=1, value=0)
+         head_slider = gr.Slider(label="Head", minimum=0, maximum=11, step=1, value=0)

+     run_button = gr.Button("Run Visualizations")

+     with gr.Tab("📊 Attention"):
+         attention_plot = gr.Plot()

+     with gr.Tab("🧩 Tokenization"):
+         token_plot = gr.Plot()

+     with gr.Tab("📉 Embeddings"):
+         embedding_plot = gr.Plot()
+
+     with gr.Tab("📦 Model Size Comparison"):
+         model_compare_plot = gr.Plot(value=compare_model_sizes())
+
+     # Event binding
+     model_dropdown.change(fn=update_model, inputs=[model_dropdown], outputs=[model_info, layer_slider, head_slider])
+     run_button.click(
+         fn=run_all_visualizations,
+         inputs=[text_input, layer_slider, head_slider],
+         outputs=[attention_plot, token_plot, embedding_plot],
      )

+ if __name__ == "__main__":
+     demo.launch()
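
The new app.py imports MODEL_OPTIONS, load_model, and get_model_info from model_utils, which is not included in this commit. A minimal sketch of what that module might look like, assuming the same seven checkpoints as the old MODEL_CHOICES dict and the Hugging Face transformers Auto classes; the key names returned by get_model_info ("Number of Layers", "Number of Attention Heads") are inferred from how update_model reads them:

# model_utils.py (hypothetical sketch, not part of this commit)
from transformers import AutoModel, AutoTokenizer

# Display name -> Hugging Face model id (same seven models as the old MODEL_CHOICES dict)
MODEL_OPTIONS = {
    "BERT (base)": "bert-base-uncased",
    "DistilBERT": "distilbert-base-uncased",
    "RoBERTa": "roberta-base",
    "GPT-2": "gpt2",
    "Electra": "google/electra-base-discriminator",
    "ALBERT": "albert-base-v2",
    "XLNet": "xlnet-base-cased",
}


def load_model(model_id):
    """Return (tokenizer, model); attentions enabled so attention maps can be plotted."""
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModel.from_pretrained(model_id, output_attentions=True)
    model.eval()
    return tokenizer, model


def get_model_info(model):
    """Dict shown in the JSON viewer and used by update_model to resize the sliders."""
    cfg = model.config
    return {
        "Model Type": cfg.model_type,
        "Number of Layers": getattr(cfg, "num_hidden_layers", getattr(cfg, "n_layer", 1)),
        "Number of Attention Heads": getattr(cfg, "num_attention_heads", getattr(cfg, "n_head", 1)),
        "Hidden Size": getattr(cfg, "hidden_size", getattr(cfg, "n_embd", None)),
        "Number of Parameters": sum(p.numel() for p in model.parameters()),
    }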
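
show_tokenization, show_embeddings, and compare_model_sizes come from visualize, which is also outside this diff. Below is a rough sketch of two of those helpers, inferred only from the call sites and the feature summary (embeddings reduced to 2D with PCA, model-size bar chart); the signatures, the plotting choices, and the reuse of the model_utils sketch above are assumptions, not the repository's actual code:

# visualize.py (hypothetical sketch of two helpers, not part of this commit)
import matplotlib.pyplot as plt
import torch
from sklearn.decomposition import PCA


def show_embeddings(tokenizer, model, text):
    """Scatter plot of token embeddings reduced to 2D with PCA."""
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        hidden = model(**inputs).last_hidden_state[0]           # (seq_len, hidden_size)
    coords = PCA(n_components=2).fit_transform(hidden.numpy())   # (seq_len, 2)
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist())

    fig, ax = plt.subplots()
    ax.scatter(coords[:, 0], coords[:, 1])
    for (x, y), tok in zip(coords, tokens):
        ax.annotate(tok, (x, y), fontsize=8)
    ax.set_title("Token embeddings (PCA, 2D)")
    return fig


def compare_model_sizes():
    """Bar chart of parameter counts; called once to pre-fill the comparison tab."""
    from model_utils import MODEL_OPTIONS, load_model  # hypothetical helpers sketched above

    names, sizes = [], []
    for name, model_id in MODEL_OPTIONS.items():
        _, model = load_model(model_id)
        names.append(name)
        sizes.append(sum(p.numel() for p in model.parameters()) / 1e6)

    fig, ax = plt.subplots(figsize=(8, 4))
    ax.bar(names, sizes)
    ax.set_ylabel("Parameters (millions)")
    ax.set_title("Model size comparison")
    plt.xticks(rotation=30, ha="right")
    fig.tight_layout()
    return fig

Note that a sketch like this downloads and instantiates all seven models just to count parameters, which is slow on first run; the actual module may cache or hard-code the sizes instead.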