Spaces:
Sleeping
Sleeping
# app.py | |
import gradio as gr | |
from model_utils import load_model_info, get_model_stats | |
from visualize import ( | |
visualize_attention, | |
visualize_token_embeddings, | |
plot_tokenization, | |
compare_model_sizes | |
) | |
# Maps each dropdown label to its model identifier.
# The keys (in this order) populate the "Choose Model" dropdown.
# The values are the identifiers handed to the model-size comparison.
MODEL_CHOICES = dict(
    [
        ("BERT (base)", "bert-base-uncased"),
        ("DistilBERT", "distilbert-base-uncased"),
        ("RoBERTa", "roberta-base"),
        ("GPT-2", "gpt2"),
        ("Electra", "google/electra-base-discriminator"),
        ("ALBERT", "albert-base-v2"),
        ("XLNet", "xlnet-base-cased"),
    ]
)
def run_visualizer(model_name, text, layer, head):
    """Run every per-model visualization for one piece of input text.

    Args:
        model_name: Value of the "Choose Model" dropdown, i.e. a display label
            such as ``"BERT (base)"`` (a key of ``MODEL_CHOICES``). A string that
            is not a known label is passed through unchanged, so callers that
            already hold a model identifier keep working.
        text: The text to tokenize and visualize.
        layer: Layer number from the "Layer" slider.
        head: Head number from the "Attention Head" slider.

    Returns:
        ``(attention_plot, token_heatmap, token_plot, model_stats)``, one value
        per output component wired up in ``run_btn.click``.

    Raises:
        gr.Error: If ``text`` is empty or whitespace-only, so the user sees a
            message in the UI instead of an error from the helpers.
    """
    if not text or not text.strip():
        raise gr.Error("Please enter some text to analyze.")

    # The dropdown offers the human-readable labels (MODEL_CHOICES keys).
    # The size comparison hands the helpers the identifiers (MODEL_CHOICES
    # values), so translate the label here to give load_model_info an
    # identifier as well.
    model_id = MODEL_CHOICES.get(model_name, model_name)

    # gr.Slider can deliver its value as a float (e.g. 0.0). layer/head are
    # presumably used as indices by visualize_attention, so normalize to int.
    layer_idx = int(layer)
    head_idx = int(head)

    model_info = load_model_info(model_id)
    attention_plot = visualize_attention(model_info, text, layer_idx, head_idx)
    token_heatmap = visualize_token_embeddings(model_info, text)
    token_plot = plot_tokenization(model_info, text)
    model_stats = get_model_stats(model_info)
    return attention_plot, token_heatmap, token_plot, model_stats
def run_comparison_chart():
    """Build the model-size comparison chart for every model in MODEL_CHOICES.

    Returns:
        Whatever ``compare_model_sizes`` produces; it is displayed in the
        accordion's ``gr.Plot`` output.
    """
    # Materialize the dict_values view into a list. A real sequence supports
    # indexing, slicing and serialization, which a view does not, so the helper
    # is not constrained by how it consumes the identifiers.
    model_ids = list(MODEL_CHOICES.values())
    return compare_model_sizes(model_ids)
# The layer slider is built for 12 layers (indices 0-11). distilbert-base-uncased
# has only 6 layers, so without narrowing the range, layers 6-11 would request
# layers that do not exist for it.
_DEFAULT_NUM_LAYERS = 12
_NUM_LAYERS_OVERRIDES = {"DistilBERT": 6}


def _sync_layer_slider(model_label, current_layer):
    """Return a slider update limiting "Layer" to the selected model's depth.

    The current value is clamped so it stays inside the new range.
    """
    num_layers = _NUM_LAYERS_OVERRIDES.get(model_label, _DEFAULT_NUM_LAYERS)
    max_layer = num_layers - 1
    return gr.update(maximum=max_layer, value=min(int(current_layer), max_layer))


with gr.Blocks() as demo:
    gr.Markdown(
        "# 🤗 Transformer Model Visualizer\n"
        "Explore attention heads, token embeddings, and tokenizer behavior "
        "across popular transformer models."
    )

    with gr.Row():
        model_selector = gr.Dropdown(
            label="Choose Model",
            choices=list(MODEL_CHOICES.keys()),
            value="BERT (base)",
        )
        input_text = gr.Textbox(label="Input Text", placeholder="Enter text to analyze")

    with gr.Row():
        # 0-11 covers 12 layers / 12 heads. The layer range is narrowed per
        # model by _sync_layer_slider (wired to the dropdown below).
        layer_slider = gr.Slider(minimum=0, maximum=11, step=1, value=0, label="Layer")
        head_slider = gr.Slider(minimum=0, maximum=11, step=1, value=0, label="Attention Head")

    run_btn = gr.Button("Run Analysis")

    with gr.Row():
        attention_output = gr.Plot(label="Self-Attention Visualization")
        embedding_output = gr.Plot(label="Token Embedding Heatmap")

    with gr.Row():
        token_output = gr.Plot(label="Tokenization Overview")
        model_output = gr.JSON(label="Model Details")

    # Keep the layer slider within the selected model's number of layers.
    model_selector.change(
        fn=_sync_layer_slider,
        inputs=[model_selector, layer_slider],
        outputs=layer_slider,
    )

    run_btn.click(
        fn=run_visualizer,
        inputs=[model_selector, input_text, layer_slider, head_slider],
        outputs=[attention_output, embedding_output, token_output, model_output],
    )

    with gr.Accordion("📊 Compare Model Sizes", open=False):
        compare_btn = gr.Button("Generate Comparison Chart")
        comparison_output = gr.Plot(label="Model Size Comparison")
        compare_btn.click(fn=run_comparison_chart, outputs=comparison_output)

demo.launch()