[file name] updated_code.py
[file content]
import streamlit as st
import matplotlib.pyplot as plt
import pandas as pd
import torch
from transformers import AutoConfig, AutoTokenizer  # AutoTokenizer added for the tokenization tab
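
# Launch locally with: streamlit run updated_code.py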

# Page configuration
st.set_page_config(
    page_title="Transformer Visualizer",
    page_icon="🧠",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Custom CSS styling (unchanged)
# ... [same CSS styles as original] ...

# Model database (unchanged)
MODELS = {
    # ... [same model database as original] ...
}
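# NOTE (assumption): the elided database is expected to map a display name to a
# dict holding at least a "model_name" key, since the rest of the app reads
# MODELS[name]["model_name"]. A hypothetical entry would look like:
#   "BERT (base, uncased)": {"model_name": "bert-base-uncased"},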

def get_model_config(model_name):
    config = AutoConfig.from_pretrained(MODELS[model_name]["model_name"])
    return config
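
# Sketch (assumption): with Streamlit >= 1.18, st.cache_resource can cache the
# tokenizer so it is not re-downloaded from the Hugging Face Hub on every
# rerun. get_cached_tokenizer is a hypothetical helper, not in the original
# app; the tokenization tab below could call it instead of calling
# AutoTokenizer.from_pretrained directly on each button press.
@st.cache_resource
def get_cached_tokenizer(hf_model_name):
    # Loaded once per model name, then reused across Streamlit reruns.
    return AutoTokenizer.from_pretrained(hf_model_name)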

def plot_model_comparison(selected_model):
    # ... [same comparison function as original] ...
    pass  # body elided

def visualize_architecture(model_info):
    # ... [same architecture function as original] ...
    pass  # body elided

def visualize_attention_patterns():
    # ... [same attention patterns function as original] ...
    pass  # body elided

def main():
    st.title("🧠 Transformer Model Visualizer")

    selected_model = st.sidebar.selectbox("Select Model", list(MODELS.keys()))
    model_info = MODELS[selected_model]
    config = get_model_config(selected_model)

    # Metrics columns (unchanged)
    col1, col2, col3, col4 = st.columns(4)
    # ... [same metrics code as original] ...

    # A 4th tab added for tokenization
    tab1, tab2, tab3, tab4 = st.tabs(["Model Structure", "Comparison", "Model Attention", "Model Tokenization"])

    # Existing tabs (unchanged)
    # ... [same tab1, tab2, tab3 code as original] ...

    # New tokenization tab
    with tab4:
        st.subheader("Text Tokenization")
        user_input = st.text_input("Enter Text:", value="My name is Sadia!", key="tokenizer_input")
        if st.button("Tokenize", key="tokenize_button"):
            try:
                tokenizer = AutoTokenizer.from_pretrained(MODELS[selected_model]["model_name"])
                tokens = tokenizer.tokenize(user_input)
                # Format output similar to the reference image
                tokenized_output = "- [ \n"
                for idx, token in enumerate(tokens):
                    tokenized_output += f' {idx} : "{token}" \n'
                tokenized_output += "]"
                st.markdown("**Tokenized Output:**")
                # st.code renders preformatted text; no unsafe HTML is needed
                st.code(tokenized_output, language=None)
            except Exception as e:
                st.error(f"Error in tokenization: {e}")
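            # Illustrative output (hypothetical) for "My name is Sadia!" with a
            # WordPiece model such as bert-base-uncased; exact subword splits
            # depend on the selected model's vocabulary:
            # - [
            #  0 : "my"
            #  1 : "name"
            #  2 : "is"
            #  3 : "sad"
            #  4 : "##ia"
            #  5 : "!"
            # ]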

if __name__ == "__main__":
    main()