# Page header captured with this file: borutokarma123 | "Update app.py" | commit 3e9eb6b (verified)
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import seaborn as sns
import streamlit as st
import torch
from PIL import Image
from sklearn.decomposition import PCA
from transformers import (
    AutoImageProcessor,
    AutoModel,
    BeitFeatureExtractor,
    BeitModel,
)
# App Title
# The leading emoji was stored as mis-decoded UTF-8 bytes; restored to the
# intended rocket character so the page title renders correctly.
st.title("🚀 Vision Transformer Explorer")
st.markdown("""
Explore Vision Transformers, their architectures, and tokenization mechanisms.
""")
# Model Selection
# Every option must be a checkpoint id that actually exists on the Hugging
# Face Hub; the Swin and ViT entries need their "-224" resolution suffix.
model_name = st.selectbox(
    "Choose a Vision Transformer Model:",
    [
        "microsoft/beit-base-patch16-224",
        "microsoft/swin-base-patch4-window7-224",
        "google/vit-base-patch16-224",
    ],
)


# Load Feature Extractor & Model
@st.cache_resource
def _load_processor_and_model(checkpoint):
    """Return the ``(image processor, model)`` pair for *checkpoint*.

    The Auto classes pick the processor/model implementation that matches
    the checkpoint's architecture, so BEiT, Swin and ViT checkpoints are each
    loaded with their own classes instead of all being forced through the
    BEiT ones.

    The result is cached with ``st.cache_resource`` because Streamlit re-runs
    this script on every widget interaction; without the cache the weights
    would be reloaded each time.
    """
    processor = AutoImageProcessor.from_pretrained(checkpoint)
    net = AutoModel.from_pretrained(checkpoint)
    return processor, net


st.write(f"Loading model: `{model_name}`...")
# Names kept as `feature_extractor` / `model`: later sections of the script
# read them.
feature_extractor, model = _load_processor_and_model(model_name)
# Display Model Details
st.subheader("🛠 Model Details")
# Report the architecture recorded in the loaded config rather than a
# hard-coded label, so the text stays correct for whichever model is loaded.
st.write(f"Model Type: {model.config.model_type}")
st.write(f"Number of Layers: {model.config.num_hidden_layers}")
st.write(f"Number of Attention Heads: {model.config.num_attention_heads}")
# Count every parameter tensor's elements and report the total in millions.
total_params = sum(p.numel() for p in model.parameters())
st.write(f"Total Parameters: {total_params / 1e6:.2f}M")
# Model Size Comparison
st.subheader("📊 Model Size Comparison")
# Hard-coded reference sizes in millions of parameters (not measured from the
# loaded model). Keys are the Hugging Face Hub checkpoint ids.
model_sizes = {
    "microsoft/beit-base-patch16-224": 86,
    "microsoft/swin-base-patch4-window7-224": 87,
    "google/vit-base-patch16-224": 86,
}
# Materialise the items as a list of (model, size) rows: a raw dict view is
# not accepted by every pandas version's DataFrame constructor.
df_size = pd.DataFrame(
    list(model_sizes.items()),
    columns=["Model", "Size (Million Parameters)"],
)
fig = px.bar(
    df_size,
    x="Model",
    y="Size (Million Parameters)",
    title="Model Size Comparison",
)
st.plotly_chart(fig)
# Image Processing Section
st.subheader("📸 Image Processing Visualization")
uploaded_image = st.file_uploader("Upload an image:", type=['png', 'jpg', 'jpeg'])
if uploaded_image is not None:
    # PNG uploads can be RGBA, palette or grayscale; normalise to 3-channel
    # RGB before handing the image to the processor.
    image = Image.open(uploaded_image).convert("RGB")
    inputs = feature_extractor(images=image, return_tensors="pt")

    # One forward pass yields both the hidden states and the attention
    # weights. Running it under no_grad keeps the outputs detached from
    # autograd, which is required before converting them to NumPy.
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)

    # Visualize patch embeddings
    # [0] selects the single image in the batch -> one row per token.
    embeddings = outputs.last_hidden_state[0].numpy()
    pca = PCA(n_components=2)
    reduced_embeddings = pca.fit_transform(embeddings)
    df_embeddings = pd.DataFrame(reduced_embeddings, columns=["PCA1", "PCA2"])
    fig = px.scatter(
        df_embeddings,
        x="PCA1",
        y="PCA2",
        title="Patch Embeddings (PCA Projection)",
    )
    st.plotly_chart(fig)

# Attention Visualization
st.subheader("🔍 Attention Map")
if uploaded_image is not None:
    # Last layer's attention for the single image in the batch; assumed to be
    # laid out as (heads, tokens, tokens) -- TODO confirm for every model type.
    attention = outputs.attentions[-1][0].numpy()
    fig, ax = plt.subplots(figsize=(10, 5))
    # Only the first attention head is drawn.
    sns.heatmap(attention[0], cmap="viridis", ax=ax)
    st.pyplot(fig)
    # Release the figure: the script re-runs on every interaction and open
    # matplotlib figures would otherwise accumulate.
    plt.close(fig)
# Footer. The leading emoji was stored as mis-decoded UTF-8 bytes; restored
# to the intended light-bulb character.
st.markdown("💡 *Explore Vision Transformers!*\n")