File size: 2,669 Bytes
ffea4fb 3e9eb6b ffea4fb 0cdb8e0 ffea4fb 0cdb8e0 ffea4fb 0cdb8e0 ffea4fb 0cdb8e0 3272087 ffea4fb 3e9eb6b ffea4fb 3e9eb6b ffea4fb 3e9eb6b ffea4fb 3e9eb6b ffea4fb 0cdb8e0 ffea4fb 3e9eb6b 46d2acf 3e9eb6b ffea4fb 0cdb8e0 3e9eb6b 0cdb8e0 3e9eb6b 0cdb8e0 ffea4fb 0cdb8e0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 |
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import seaborn as sns
import streamlit as st
import torch
from PIL import Image
from sklearn.decomposition import PCA
from transformers import AutoImageProcessor, AutoModel
from transformers import BeitFeatureExtractor, BeitModel
# Page header: the app title followed by a one-sentence description.
# NOTE(review): the leading "π" in the title looks like a mis-decoded emoji;
# confirm the intended glyph against the original file before changing it.
APP_TITLE = "π Vision Transformer Explorer"
APP_INTRO = (
    "\n"
    "Explore Vision Transformers, their architectures, and tokenization mechanisms.\n"
)

st.title(APP_TITLE)
st.markdown(APP_INTRO)
# Model Selection
# Each option is passed verbatim to `from_pretrained`, so it has to be the
# exact Hugging Face Hub repository id. The Swin and ViT base checkpoints are
# published with a "-224" input-resolution suffix; without it the download
# fails because no such repository exists.
model_name = st.selectbox(
    "Choose a Vision Transformer Model:",
    [
        "microsoft/beit-base-patch16-224",
        "microsoft/swin-base-patch4-window7-224",
        "google/vit-base-patch16-224",
    ],
)
# Load Feature Extractor & Model
@st.cache_resource
def _load_model(checkpoint: str):
    """Load the image processor and backbone model for one checkpoint.

    The ``Auto*`` classes choose the architecture from the checkpoint's own
    configuration, so BEiT, Swin and ViT checkpoints each get their matching
    classes instead of all being forced through the BEiT ones. Streamlit
    re-executes this script on every widget interaction; caching the result
    avoids reloading the weights on each rerun.

    Args:
        checkpoint: Hugging Face Hub model id (or local path) to load.

    Returns:
        A ``(image_processor, model)`` tuple.
    """
    processor = AutoImageProcessor.from_pretrained(checkpoint)
    backbone = AutoModel.from_pretrained(checkpoint)
    return processor, backbone


st.write(f"Loading model: `{model_name}`...")
feature_extractor, model = _load_model(model_name)
# Display Model Details
st.subheader("π Model Details")
# Derive the architecture label from the loaded config rather than
# hard-coding it, so it stays correct whichever checkpoint is selected.
# Unknown types fall back to the raw `model_type` string.
_MODEL_TYPE_LABELS = {"beit": "BeIT", "swin": "Swin", "vit": "ViT"}
_model_type = getattr(model.config, "model_type", "") or type(model).__name__
st.write(f"Model Type: {_MODEL_TYPE_LABELS.get(_model_type, _model_type)}")
st.write(f"Number of Layers: {model.config.num_hidden_layers}")
st.write(f"Number of Attention Heads: {model.config.num_attention_heads}")
# Parameter count is measured from the instantiated model, reported in millions.
_param_count = sum(p.numel() for p in model.parameters())
st.write(f"Total Parameters: {_param_count / 1e6:.2f}M")
# Model Size Comparison
st.subheader("π Model Size Comparison")
# Hard-coded parameter counts in millions (not measured at runtime).
# Keys are the checkpoint ids as published on the Hugging Face Hub, including
# the "-224" resolution suffix, so the chart labels name real checkpoints.
model_sizes = {
    "microsoft/beit-base-patch16-224": 86,
    "microsoft/swin-base-patch4-window7-224": 87,
    "google/vit-base-patch16-224": 86,
}
# Materialise the items view into a list of (model, size) rows before handing
# it to pandas; one row per checkpoint.
df_size = pd.DataFrame(
    list(model_sizes.items()),
    columns=["Model", "Size (Million Parameters)"],
)
fig = px.bar(
    df_size,
    x="Model",
    y="Size (Million Parameters)",
    title="Model Size Comparison",
)
st.plotly_chart(fig)
# Image Processing Section
st.subheader("📸 Image Processing Visualization")
uploaded_image = st.file_uploader("Upload an image:", type=['png', 'jpg', 'jpeg'])
if uploaded_image is not None:
    # Normalise to three channels: PNG uploads can be RGBA, palette or
    # greyscale, and the processor is handed the image as-is.
    image = Image.open(uploaded_image).convert("RGB")
    inputs = feature_extractor(images=image, return_tensors="pt")
    # Inference only. Without no_grad the output tensor requires grad and
    # `.numpy()` below raises a RuntimeError.
    with torch.no_grad():
        outputs = model(**inputs)
    # Visualize patch embeddings
    # A single image was passed, so squeeze(0) drops the leading batch axis
    # and leaves one row per output token.
    embeddings = outputs.last_hidden_state.squeeze(0).numpy()
    # Project every token embedding onto its first two principal components.
    pca = PCA(n_components=2)
    reduced_embeddings = pca.fit_transform(embeddings)
    df_embeddings = pd.DataFrame(reduced_embeddings, columns=["PCA1", "PCA2"])
    fig = px.scatter(
        df_embeddings,
        x="PCA1",
        y="PCA2",
        title="Patch Embeddings (PCA Projection)",
    )
    st.plotly_chart(fig)
# Attention Visualization
st.subheader("π Attention Map")
if uploaded_image is not None:
    # `inputs` was prepared in the image-processing section above, under the
    # same `uploaded_image is not None` condition.
    # Inference only, so skip building an autograd graph.
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)
    # Take the attention weights of the last layer and drop size-1 axes.
    attention = outputs.attentions[-1].squeeze().detach().numpy()
    fig, ax = plt.subplots(figsize=(10, 5))
    # NOTE(review): attention[0] is presumably the first head's
    # (tokens x tokens) matrix for a single image; confirm for Swin outputs.
    sns.heatmap(attention[0], cmap="viridis", ax=ax)
    st.pyplot(fig)
    # pyplot keeps every figure it creates registered until it is closed;
    # close it so figures do not pile up across Streamlit reruns.
    plt.close(fig)
# Footer tagline. The stray trailing "|" after the call was a syntax error,
# and the mis-decoded light-bulb emoji is restored.
st.markdown("💡 *Explore Vision Transformers!*\n")