Spaces:
Sleeping
Sleeping
File size: 2,405 Bytes
1398c16 76269ea 1398c16 76269ea |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
# app.py
import gradio as gr
from model_utils import load_model_info, get_model_stats
from visualize import (
visualize_attention,
visualize_token_embeddings,
plot_tokenization,
compare_model_sizes
)
# Maps the display label shown in the "Choose Model" dropdown to the model
# identifier handed to the model/plot helpers. Insertion order is preserved,
# so the list below is also the order of the dropdown entries.
MODEL_CHOICES = dict(
    [
        ("BERT (base)", "bert-base-uncased"),
        ("DistilBERT", "distilbert-base-uncased"),
        ("RoBERTa", "roberta-base"),
        ("GPT-2", "gpt2"),
        ("Electra", "google/electra-base-discriminator"),
        ("ALBERT", "albert-base-v2"),
        ("XLNet", "xlnet-base-cased"),
    ]
)
def run_visualizer(model_name, text, layer, head):
    """Run every per-model visualization for one input text.

    Args:
        model_name: The dropdown selection. Normally a display label (a key of
            ``MODEL_CHOICES``), which is translated to its model id; any other
            value is passed through unchanged, so callers that already supply
            a model id keep working.
        text: Input text to tokenize and analyze.
        layer: Index of the layer whose attention is visualized.
        head: Index of the attention head within ``layer``.

    Returns:
        Tuple ``(attention_plot, token_heatmap, token_plot, model_stats)``,
        in the same order as the outputs wired to ``run_btn.click``.

    Raises:
        gr.Error: If ``text`` is empty or whitespace-only, so Gradio shows a
            readable message instead of a traceback.
    """
    if not text or not text.strip():
        raise gr.Error("Please enter some text to analyze.")

    # The dropdown yields display labels such as "BERT (base)", which are not
    # model ids; run_comparison_chart hands the ids (the dict values) to the
    # helpers, so translate the label the same way here.
    model_id = MODEL_CHOICES.get(model_name, model_name)

    # Slider values may arrive as floats (e.g. 3.0); they are used as indices.
    layer = int(layer)
    head = int(head)

    model_info = load_model_info(model_id)
    attention_plot = visualize_attention(model_info, text, layer, head)
    token_heatmap = visualize_token_embeddings(model_info, text)
    token_plot = plot_tokenization(model_info, text)
    model_stats = get_model_stats(model_info)
    return attention_plot, token_heatmap, token_plot, model_stats
def run_comparison_chart():
    """Build the size-comparison chart for every model in ``MODEL_CHOICES``.

    Returns:
        Whatever ``compare_model_sizes`` returns; it is displayed in a
        ``gr.Plot`` component.
    """
    # Pass a real list rather than a live dict view: a view is not indexable,
    # and array-conversion helpers such as numpy.asarray wrap it as a single
    # opaque object instead of treating it as a sequence of model ids.
    return compare_model_sizes(list(MODEL_CHOICES.values()))
# Layer counts for dropdown labels that differ from the 12-layer default
# assumed by the layer slider (distilbert-base-uncased has 6 layers).
_NUM_LAYERS = {"DistilBERT": 6}

with gr.Blocks() as demo:
    # Layout follows the nesting and order of the statements below.
    gr.Markdown("""
# 🤗 Transformer Model Visualizer
Explore attention heads, token embeddings, and tokenizer behavior across popular transformer models.
""")
    with gr.Row():
        model_selector = gr.Dropdown(label="Choose Model", choices=list(MODEL_CHOICES.keys()), value="BERT (base)")
        input_text = gr.Textbox(label="Input Text", placeholder="Enter text to analyze")
    with gr.Row():
        # Ranges start at 12 layers/heads; the layer maximum is adjusted per
        # model by _sync_layer_slider below.
        layer_slider = gr.Slider(minimum=0, maximum=11, step=1, value=0, label="Layer")
        head_slider = gr.Slider(minimum=0, maximum=11, step=1, value=0, label="Attention Head")
    run_btn = gr.Button("Run Analysis")
    with gr.Row():
        attention_output = gr.Plot(label="Self-Attention Visualization")
        embedding_output = gr.Plot(label="Token Embedding Heatmap")
    with gr.Row():
        token_output = gr.Plot(label="Tokenization Overview")
        model_output = gr.JSON(label="Model Details")

    def _sync_layer_slider(model_label):
        """Limit the layer slider to the selected model's depth.

        The value is reset to 0 so a previously chosen layer can never lie
        beyond the new maximum (e.g. layer 10 after switching to DistilBERT).
        """
        num_layers = _NUM_LAYERS.get(model_label, 12)
        return gr.update(maximum=num_layers - 1, value=0)

    model_selector.change(fn=_sync_layer_slider, inputs=model_selector, outputs=layer_slider)

    # Output order must match the tuple returned by run_visualizer.
    run_btn.click(
        fn=run_visualizer,
        inputs=[model_selector, input_text, layer_slider, head_slider],
        outputs=[attention_output, embedding_output, token_output, model_output],
    )
    with gr.Accordion("📊 Compare Model Sizes", open=False):
        compare_btn = gr.Button("Generate Comparison Chart")
        comparison_output = gr.Plot()
        compare_btn.click(fn=run_comparison_chart, outputs=comparison_output)

demo.launch()
|