marianvd-01's picture
Update app.py
76269ea verified
raw
history blame
2.41 kB
# app.py
import gradio as gr
from model_utils import load_model_info, get_model_stats
from visualize import (
visualize_attention,
visualize_token_embeddings,
plot_tokenization,
compare_model_sizes
)
# Display label (shown in the "Choose Model" dropdown) -> model identifier.
# Insertion order is the order in which the labels are listed in the dropdown.
MODEL_CHOICES = {
    "BERT (base)": "bert-base-uncased",
    "DistilBERT": "distilbert-base-uncased",
    "RoBERTa": "roberta-base",
    "GPT-2": "gpt2",
    "Electra": "google/electra-base-discriminator",
    "ALBERT": "albert-base-v2",
    "XLNet": "xlnet-base-cased",
}
def run_visualizer(model_name, text, layer, head):
    """Produce every visualisation for one model and one input text.

    Args:
        model_name: A display label from ``MODEL_CHOICES`` (which is what the
            "Choose Model" dropdown sends) or a model identifier. A known
            label is translated to its identifier; any other value is
            forwarded unchanged.
        text: The text to analyse.
        layer: Layer index forwarded to ``visualize_attention``.
        head: Attention-head index forwarded to ``visualize_attention``.

    Returns:
        The tuple ``(attention_plot, token_heatmap, token_plot, model_stats)``,
        in the same order as the outputs wired to the "Run Analysis" button.
    """
    # The dropdown is filled with the *keys* of MODEL_CHOICES, so the value
    # arriving here is a label such as "BERT (base)", not an identifier such
    # as "bert-base-uncased". Resolve it; fall back to the raw value so that
    # callers already passing an identifier keep working.
    # NOTE(review): assumes load_model_info expects the identifier, as
    # compare_model_sizes is given MODEL_CHOICES.values() -- confirm.
    model_id = MODEL_CHOICES.get(model_name, model_name)
    model_info = load_model_info(model_id)

    # Slider values may arrive as floats (e.g. 3.0); coerce them so they can
    # be used as integer indices downstream.
    layer_index = int(layer)
    head_index = int(head)

    attention_plot = visualize_attention(model_info, text, layer_index, head_index)
    token_heatmap = visualize_token_embeddings(model_info, text)
    token_plot = plot_tokenization(model_info, text)
    model_stats = get_model_stats(model_info)
    return attention_plot, token_heatmap, token_plot, model_stats
def run_comparison_chart():
    """Return the size-comparison chart for every model in ``MODEL_CHOICES``.

    Takes no arguments; the values of ``MODEL_CHOICES`` are handed to
    ``compare_model_sizes`` and its result is returned as-is.
    """
    model_ids = MODEL_CHOICES.values()
    chart = compare_model_sizes(model_ids)
    return chart
# Build the Gradio interface. Components are created in the order they should
# appear on the page; each `gr.Row()` places its children side by side.
with gr.Blocks() as demo:
    # Page header. "\U0001F916" is the robot-face emoji; it is written as an
    # escape so the title cannot be garbled by a wrong file encoding.
    gr.Markdown(
        "\n"
        "# \U0001F916 Transformer Model Visualizer\n"
        "Explore attention heads, token embeddings, and tokenizer behavior"
        " across popular transformer models.\n"
    )

    # Model picker and the text to analyse.
    with gr.Row():
        model_selector = gr.Dropdown(
            label="Choose Model",
            choices=list(MODEL_CHOICES),
            value="BERT (base)",
        )
        input_text = gr.Textbox(
            label="Input Text",
            placeholder="Enter text to analyze",
        )

    # Which attention layer / head to display.
    # NOTE(review): both ranges are fixed at 0-11; presumably some of the
    # listed models have fewer layers or heads -- confirm how out-of-range
    # indices are handled by visualize_attention.
    with gr.Row():
        layer_slider = gr.Slider(minimum=0, maximum=11, step=1, value=0, label="Layer")
        head_slider = gr.Slider(minimum=0, maximum=11, step=1, value=0, label="Attention Head")

    run_btn = gr.Button("Run Analysis")

    # Result panels, two per row.
    with gr.Row():
        attention_output = gr.Plot(label="Self-Attention Visualization")
        embedding_output = gr.Plot(label="Token Embedding Heatmap")

    with gr.Row():
        token_output = gr.Plot(label="Tokenization Overview")
        model_output = gr.JSON(label="Model Details")

    # The output order must match the tuple returned by run_visualizer.
    run_btn.click(
        fn=run_visualizer,
        inputs=[model_selector, input_text, layer_slider, head_slider],
        outputs=[attention_output, embedding_output, token_output, model_output],
    )

    # Collapsed-by-default section with the model size comparison.
    # "\U0001F4CA" is the bar-chart emoji.
    with gr.Accordion("\U0001F4CA Compare Model Sizes", open=False):
        compare_btn = gr.Button("Generate Comparison Chart")
        comparison_output = gr.Plot()
        compare_btn.click(fn=run_comparison_chart, outputs=comparison_output)

# Start the server only when the file is executed as a script, so importing
# this module (e.g. to reuse `demo`) does not launch it as a side effect.
if __name__ == "__main__":
    demo.launch()