Transformer / app.py
rahideer's picture
Update app.py
b1bec5c verified
import streamlit as st
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification, pipeline
import torch
import plotly.express as px
import numpy as np
from sklearn.decomposition import PCA
from utils import visualize_attention, list_supported_models, plot_token_embeddings
# Configure the page before any other element is drawn.
st.set_page_config(layout="wide", page_title="Transformer Visualizer")
st.title("🧠 Transformer Visualizer")
st.markdown("Explore how Transformer models process and understand language.")

# Sidebar: the chosen task determines which model names are offered.
sidebar = st.sidebar
task = sidebar.selectbox(
    "Select Task",
    ["Text Classification", "Text Generation", "Question Answering"],
)
model_choices = list_supported_models(task)
model_name = sidebar.selectbox("Select Model", model_choices)

# Main area: free-form text that will be fed to the selected model.
text_input = st.text_area(
    "Enter input text",
    "The quick brown fox jumps over the lazy dog.",
)
# "Run" handler: load the selected model, run one forward pass over the
# user's text, then render tokenization, input-embedding PCA, attention
# maps and (for classification) a prediction.
if st.button("Run"):
    if not text_input.strip():
        # An empty prompt can yield a zero-length input_ids tensor for
        # tokenizers that add no special tokens, which breaks the forward pass.
        st.warning("Please enter some text before running the model.")
    else:
        st.info(f"Loading model: `{model_name}`...")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        if task == "Text Classification":
            # Classification head so a prediction can be shown at the end.
            model = AutoModelForSequenceClassification.from_pretrained(model_name, output_attentions=True)
        else:
            model = AutoModel.from_pretrained(model_name, output_attentions=True)

        # NOTE(review): token_type_ids are presumably dropped so that models
        # whose forward() does not accept them still work — confirm.
        inputs = tokenizer(text_input, return_tensors="pt", return_token_type_ids=False)

        # Inference only: no_grad avoids building an autograd graph, and the
        # returned attention tensors then do not require grad (so downstream
        # plotting can convert them without .detach()).
        with torch.no_grad():
            outputs = model(**inputs)
        attentions = outputs.attentions
        st.success("Model inference complete!")

        # Tokenization Visualization
        st.subheader("🔠 Tokenization")
        token_ids = inputs["input_ids"][0].tolist()
        tokens = tokenizer.convert_ids_to_tokens(token_ids)
        st.write(list(zip(tokens, token_ids)))

        # Token Embeddings Visualization
        st.subheader("🌐 Token Embedding Space (PCA)")
        with torch.no_grad():
            # Despite the name, these are the input (word) embeddings of each
            # token, not transformer hidden states. get_input_embeddings() is
            # the architecture-independent accessor; the previous
            # base_model.embeddings.word_embeddings path only exists on
            # BERT-style encoders (GPT-2, for instance, calls it `wte`).
            hidden_states = model.get_input_embeddings()(inputs["input_ids"]).squeeze(0)
        fig_embed = plot_token_embeddings(hidden_states, tokens)
        st.plotly_chart(fig_embed, use_container_width=True)

        # Attention Visualization
        if attentions:
            st.subheader("👁️ Attention Visualization")
            fig = visualize_attention(attentions, tokenizer, inputs)
            st.plotly_chart(fig, use_container_width=True)
        else:
            st.warning("This model does not return attention weights.")

        if task == "Text Classification":
            st.subheader("✅ Prediction")
            pipe = pipeline("text-classification", model=model, tokenizer=tokenizer)
            prediction = pipe(text_input)
            st.write(prediction)
# Sidebar footer: a top-level statement, so it is drawn on every rerun
# whether or not "Run" was pressed.
st.sidebar.markdown("---")
st.sidebar.write("App by Rahiya Esar 💖")