"""Streamlit app for exploring Vision Transformer checkpoints.

The page lets the user pick a checkpoint, shows basic architecture facts,
compares nominal model sizes, and, for an uploaded image, plots a 2-D PCA
projection of the token embeddings and a heatmap of one attention head.
"""
import streamlit as st
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
import plotly.express as px
from PIL import Image
from sklearn.decomposition import PCA
from transformers import AutoImageProcessor, AutoModel, BeitFeatureExtractor, BeitModel

# Nominal sizes (millions of parameters) shown in the comparison chart.
# The keys double as the choices offered in the model selector, so the two
# can never drift apart.
MODEL_SIZES = {
    "microsoft/beit-base-patch16-224": 86,
    "microsoft/swin-base-patch4-window7-224": 87,
    "google/vit-base-patch16-224": 86,
}


@st.cache_resource
def load_model(checkpoint):
    """Load the image processor and backbone for *checkpoint*.

    The Auto* classes pick the matching architecture from the checkpoint's
    config, so non-BEiT checkpoints are not forced through BEiT classes.
    Cached so that Streamlit reruns (triggered by any widget interaction)
    do not reload the weights.

    Args:
        checkpoint: Hugging Face Hub model id.

    Returns:
        A ``(processor, model)`` tuple with the model in eval mode.

    Raises:
        OSError: If the checkpoint cannot be found or downloaded.
    """
    processor = AutoImageProcessor.from_pretrained(checkpoint)
    backbone = AutoModel.from_pretrained(checkpoint)
    backbone.eval()
    return processor, backbone


# App Title
st.title("🚀 Vision Transformer Explorer")
st.markdown("""
Explore Vision Transformers, their architectures, and tokenization mechanisms.
""")

# Model Selection
model_name = st.selectbox(
    "Choose a Vision Transformer Model:",
    list(MODEL_SIZES),
)

# Load Feature Extractor & Model
st.write(f"Loading model: `{model_name}`...")
try:
    feature_extractor, model = load_model(model_name)
except OSError as err:
    # Surface download / lookup failures in the UI instead of a traceback.
    st.error(f"Could not load `{model_name}`: {err}")
    st.stop()

# Display Model Details
st.subheader("🛠 Model Details")
st.write(f"Model Type: {model.config.model_type}")
st.write(f"Number of Layers: {model.config.num_hidden_layers}")
st.write(f"Number of Attention Heads: {model.config.num_attention_heads}")
total_params = sum(p.numel() for p in model.parameters())
st.write(f"Total Parameters: {total_params/1e6:.2f}M")

# Model Size Comparison
st.subheader("📊 Model Size Comparison")
df_size = pd.DataFrame(
    MODEL_SIZES.items(),
    columns=["Model", "Size (Million Parameters)"],
)
fig = px.bar(
    df_size,
    x="Model",
    y="Size (Million Parameters)",
    title="Model Size Comparison",
)
st.plotly_chart(fig)

# Image Processing Section
st.subheader("📸 Image Processing Visualization")
uploaded_image = st.file_uploader("Upload an image:", type=['png', 'jpg', 'jpeg'])

outputs = None
if uploaded_image is not None:
    # Normalize to 3-channel RGB so RGBA / grayscale uploads are accepted.
    image = Image.open(uploaded_image).convert("RGB")
    inputs = feature_extractor(images=image, return_tensors="pt")

    # One forward pass serves both visualizations; no_grad avoids building an
    # autograd graph and lets the outputs be converted to NumPy directly.
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)

    # Visualize patch embeddings
    embeddings = outputs.last_hidden_state[0].numpy()
    pca = PCA(n_components=2)
    reduced_embeddings = pca.fit_transform(embeddings)
    df_embeddings = pd.DataFrame(reduced_embeddings, columns=["PCA1", "PCA2"])
    fig = px.scatter(
        df_embeddings,
        x="PCA1",
        y="PCA2",
        title="Patch Embeddings (PCA Projection)",
    )
    st.plotly_chart(fig)

# Attention Visualization
st.subheader("🔍 Attention Map")
if outputs is not None:
    attentions = outputs.attentions
    if not attentions:
        st.info("This model did not return attention weights.")
    else:
        # Last layer, first leading index, first head. Assumes the tensor is
        # laid out as (batch, heads, query, key) -- TODO confirm for models
        # with windowed attention.
        attention = attentions[-1][0, 0].numpy()
        fig, ax = plt.subplots(figsize=(10, 5))
        sns.heatmap(attention, cmap="viridis", ax=ax)
        st.pyplot(fig)
        # Release the figure; otherwise each rerun leaks another one.
        plt.close(fig)

st.markdown("💡 *Explore Vision Transformers!*\n")