# Spaces: Sleeping
# 🧠 Summary of Features:
# 🔽 Dropdown to choose from 7 prebuilt models
# 🔄 Updates attention layer & head slider limits based on the model
# 📊 Visualizes attention maps for selected head/layer
# 🧩 Tokenized words preview
# 📈 Embeddings reduced to 2D using PCA
# 📦 Model size bar chart across models
# 📄 Full model config info in JSON viewer
import gradio as gr

from model_utils import MODEL_OPTIONS, load_model, get_model_info
from visualize import (
    visualize_attention,
    show_tokenization,
    show_embeddings,
    compare_model_sizes,
)
# Initial load: bring up the first entry of MODEL_OPTIONS at import time so the
# visualizations work before the user ever touches the dropdown.
# NOTE: despite its name, DEFAULT_MODEL_NAME holds a model *id* (a value of
# MODEL_OPTIONS, which is what load_model() receives), not a dropdown label.
DEFAULT_MODEL_NAME = tuple(MODEL_OPTIONS.values())[0]
tokenizer, model = load_model(DEFAULT_MODEL_NAME)

# Shared state: the tokenizer/model pair the event handlers read and that
# update_model() replaces when a different model is selected.
# NOTE(review): these are plain module globals, so they appear to be shared by
# every session served by this process; per-user selection would need gr.State.
current_tokenizer, current_model = tokenizer, model


def _slider_max_index(info, key):
    """Return the last valid 0-based index for the count stored under *key*.

    A missing key is treated as a count of 1, and the result is clamped at 0
    so a Gradio slider never receives a negative ``maximum``.
    """
    return max(int(info.get(key, 1)) - 1, 0)


def update_model(selected_model_name):
    """Switch the app to the model chosen in the dropdown.

    Loads the tokenizer/model for the selected label, stores them in the
    module-level ``current_tokenizer`` / ``current_model`` globals that the
    visualization handler reads, and resizes the layer/head sliders to match
    the new model.

    Args:
        selected_model_name: The dropdown label, i.e. a key of ``MODEL_OPTIONS``.

    Returns:
        A 3-tuple ``(info, layer_update, head_update)``: *info* is the
        ``get_model_info`` result shown in the JSON panel; the two
        ``gr.update`` objects set each slider's maximum to the last valid
        index and reset its value to 0.

    Raises:
        KeyError: If *selected_model_name* is not a key of ``MODEL_OPTIONS``.
    """
    global current_tokenizer, current_model
    model_id = MODEL_OPTIONS[selected_model_name]
    current_tokenizer, current_model = load_model(model_id)
    info = get_model_info(current_model)

    # Slider positions are 0-based indices, so the maximum is count - 1
    # (clamped so a zero/absent count still gives a valid single-position slider).
    return (
        info,
        gr.update(maximum=_slider_max_index(info, "Number of Layers"), value=0),
        gr.update(maximum=_slider_max_index(info, "Number of Attention Heads"), value=0),
    )


def run_all_visualizations(text, layer, head):
    """Render the attention, tokenization and embedding plots for *text*.

    Args:
        text: Sentence entered in the textbox.
        layer: Value of the layer slider (selected attention layer).
        head: Value of the head slider (selected attention head).

    Returns:
        ``(attention_fig, token_fig, embedding_fig)`` as produced by the
        ``visualize`` helpers, in the order of the three ``gr.Plot`` outputs.

    Raises:
        gr.Error: If *text* is empty or whitespace-only, so the UI shows a
            message instead of trying to plot an empty input.
    """
    if not text or not text.strip():
        raise gr.Error("Please enter some text to visualize.")

    # Gradio documents Slider values as floats; layer/head are presumably used
    # as indices inside visualize_attention, so hand over plain ints.
    layer_idx = int(layer)
    head_idx = int(head)

    # Read the shared globals once so all three helper calls below receive the
    # same tokenizer/model objects, even if update_model() rebinds them meanwhile.
    active_tokenizer, active_model = current_tokenizer, current_model

    attention_fig = visualize_attention(active_tokenizer, active_model, text, layer_idx, head_idx)
    token_fig = show_tokenization(active_tokenizer, text)
    embedding_fig = show_embeddings(active_tokenizer, active_model, text)
    return attention_fig, token_fig, embedding_fig


# UI
# The dropdown starts on the first label of MODEL_OPTIONS, which maps to the
# model loaded at import time. Seed the Model Info panel and the slider limits
# from that model up front; previously the sliders were hard-coded to 0-11 and
# the info panel stayed empty until the dropdown was changed, so a default
# model with fewer layers/heads could be asked for an index it doesn't have.
_model_labels = list(MODEL_OPTIONS.keys())
_initial_info = get_model_info(current_model)
_initial_max_layer = max(int(_initial_info.get("Number of Layers", 1)) - 1, 0)
_initial_max_head = max(int(_initial_info.get("Number of Attention Heads", 1)) - 1, 0)

with gr.Blocks() as demo:
    gr.Markdown("## 🔍 Transformer Explorer")
    gr.Markdown("Explore attention, tokenization, and embedding visualizations for various transformer models.")

    with gr.Row():
        model_dropdown = gr.Dropdown(
            label="Choose a model",
            choices=_model_labels,
            value=_model_labels[0],
        )
        model_info = gr.JSON(label="Model Info", value=_initial_info)

    with gr.Row():
        text_input = gr.Textbox(label="Enter text", value="The quick brown fox jumps over the lazy dog.")
        layer_slider = gr.Slider(label="Layer", minimum=0, maximum=_initial_max_layer, step=1, value=0)
        head_slider = gr.Slider(label="Head", minimum=0, maximum=_initial_max_head, step=1, value=0)

    run_button = gr.Button("Run Visualizations")

    with gr.Tab("📊 Attention"):
        attention_plot = gr.Plot()
    with gr.Tab("🧩 Tokenization"):
        token_plot = gr.Plot()
    with gr.Tab("📈 Embeddings"):
        embedding_plot = gr.Plot()
    with gr.Tab("📦 Model Size Comparison"):
        # Computed once when the UI is built; it does not depend on the selection.
        model_compare_plot = gr.Plot(value=compare_model_sizes())

    # Event binding (must happen inside the Blocks context).
    # Changing the model reloads it and resizes both index sliders.
    model_dropdown.change(fn=update_model, inputs=[model_dropdown], outputs=[model_info, layer_slider, head_slider])
    run_button.click(
        fn=run_all_visualizations,
        inputs=[text_input, layer_slider, head_slider],
        outputs=[attention_plot, token_plot, embedding_plot],
    )


def _launch_app():
    """Start the Gradio web server for the ``demo`` Blocks app."""
    demo.launch()


# Only serve the app when this file is executed directly; importing the
# module merely builds ``demo`` without starting a server.
if __name__ == "__main__":
    _launch_app()