# Source: Hugging Face Space "app.py" by iisadia, commit 6cf26d6 (verified),
# 2.56 kB — pasted file name: updated_code.py.  The lines above were web-page
# residue (raw / history / blame links) and have been folded into this comment
# so the module is valid Python.
import streamlit as st
import matplotlib.pyplot as plt
import pandas as pd
import torch
from transformers import AutoConfig, AutoTokenizer # Added AutoTokenizer
# Page configuration — must run before any other Streamlit call.
_PAGE_CONFIG = dict(
    page_title="Transformer Visualizer",
    page_icon="🧠",
    layout="wide",
    initial_sidebar_state="expanded",
)
st.set_page_config(**_PAGE_CONFIG)
# Custom CSS styling (unchanged)
# ... [same CSS styles as original] ...
# Model database (unchanged)
# Maps a human-readable model label -> metadata dict.  Each entry must at
# least provide a "model_name" key holding the Hugging Face hub id — both
# get_model_config() and the tokenizer tab in main() index it that way.
# NOTE(review): the actual entries were elided from this paste ("same model
# database as original"); restore them or every MODELS[...] lookup KeyErrors.
MODELS = {
# ... [same model database as original] ...
}
def get_model_config(model_name):
    """Download and return the Hugging Face config for *model_name*.

    *model_name* is a key of the MODELS database; the hub identifier is
    read from that entry's "model_name" field and handed to AutoConfig.
    """
    hub_id = MODELS[model_name]["model_name"]
    return AutoConfig.from_pretrained(hub_id)
def plot_model_comparison(selected_model):
    """Render a chart comparing *selected_model* against the other MODELS.

    Placeholder body: the original implementation was elided from this
    paste ("same comparison function as original").  A comment-only body
    is a SyntaxError in Python, so this docstring stands in as the single
    required statement until the real code is restored.
    """
    # TODO(review): re-insert the original comparison-plot code here.
def visualize_architecture(model_info):
    """Draw the layer/architecture diagram for one MODELS entry.

    Placeholder body: the original implementation was elided from this
    paste ("same architecture function as original").  A comment-only body
    is a SyntaxError in Python, so this docstring stands in as the single
    required statement until the real code is restored.
    """
    # TODO(review): re-insert the original architecture-drawing code here.
def visualize_attention_patterns():
    """Render the attention-pattern visualization for the attention tab.

    Placeholder body: the original implementation was elided from this
    paste ("same attention patterns function as original").  A comment-only
    body is a SyntaxError in Python, so this docstring stands in as the
    single required statement until the real code is restored.
    """
    # TODO(review): re-insert the original attention-visualization code here.
def main():
    """Streamlit entry point: sidebar model picker plus visualization tabs."""
    st.title("🧠 Transformer Model Visualizer")

    selected_model = st.sidebar.selectbox("Select Model", list(MODELS.keys()))
    model_info = MODELS[selected_model]
    config = get_model_config(selected_model)

    # Metrics columns (unchanged)
    col1, col2, col3, col4 = st.columns(4)
    # ... [same metrics code as original] ...

    # Added 4th tab
    tab1, tab2, tab3, tab4 = st.tabs(
        ["Model Structure", "Comparison", "Model Attention", "Model Tokenization"]
    )
    # Existing tabs (unchanged)
    # ... [same tab1, tab2, tab3 code as original] ...

    # New Tokenization Tab: show how the selected model splits free text.
    with tab4:
        st.subheader("Text Tokenization")
        user_input = st.text_input("Enter Text:", value="My name is Sadia!", key="tokenizer_input")
        if st.button("Tokenize", key="tokenize_button"):
            try:
                tokenizer = AutoTokenizer.from_pretrained(MODELS[selected_model]["model_name"])
                tokens = tokenizer.tokenize(user_input)
                # Format output similar to reference image; build the body
                # with one join instead of quadratic `+=` concatenation.
                body = "".join(
                    f" {idx} : \"{token}\" \n" for idx, token in enumerate(tokens)
                )
                tokenized_output = "- [ \n" + body + "]"
                st.markdown("**Tokenized Output:**")
                # Security: tokens derive from user input, so HTML must NOT
                # be allowed through — the original passed
                # unsafe_allow_html=True for no benefit inside a code fence.
                st.markdown(f"```\n{tokenized_output}\n```")
            except Exception as e:
                # Broad catch is deliberate: hub download / tokenizer errors
                # vary widely and should surface in the UI, not crash the app.
                st.error(f"Error in tokenization: {str(e)}")


if __name__ == "__main__":
    main()