[file name] updated_code.py
[file content]
import streamlit as st
import matplotlib.pyplot as plt
import pandas as pd
import torch
from transformers import AutoConfig, AutoTokenizer  # Added AutoTokenizer

# Page configuration
st.set_page_config(
    page_title="Transformer Visualizer",
    page_icon="🧠",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Custom CSS styling (unchanged)
# ... [same CSS styles as original] ...

# Model database (unchanged)
MODELS = {
    # ... [same model database as original] ...
}

def get_model_config(model_name):
    config = AutoConfig.from_pretrained(MODELS[model_name]["model_name"])
    return config

def plot_model_comparison(selected_model):
    # ... [same comparison function as original] ...
    pass  # placeholder so the elided body still parses

def visualize_architecture(model_info):
    # ... [same architecture function as original] ...
    pass  # placeholder so the elided body still parses

def visualize_attention_patterns():
    # ... [same attention patterns function as original] ...
    pass  # placeholder so the elided body still parses

def main():
    st.title("🧠 Transformer Model Visualizer")
    
    selected_model = st.sidebar.selectbox("Select Model", list(MODELS.keys()))
    model_info = MODELS[selected_model]
    config = get_model_config(selected_model)
    
    # Metrics columns (unchanged)
    col1, col2, col3, col4 = st.columns(4)
    # ... [same metrics code as original] ...
    
    # Added 4th tab
    tab1, tab2, tab3, tab4 = st.tabs(["Model Structure", "Comparison", "Model Attention", "Model Tokenization"])
    
    # Existing tabs (unchanged)
    # ... [same tab1, tab2, tab3 code as original] ...
    
    # New Tokenization Tab
    with tab4:
        st.subheader("Text Tokenization")
        user_input = st.text_input("Enter Text:", value="My name is Sadia!", key="tokenizer_input")
        
        if st.button("Tokenize", key="tokenize_button"):
            try:
                tokenizer = AutoTokenizer.from_pretrained(MODELS[selected_model]["model_name"])
                tokens = tokenizer.tokenize(user_input)
                
                # Format output similar to reference image
                tokenized_output = "- [  \n"
                for idx, token in enumerate(tokens):
                    tokenized_output += f"  {idx} : \"{token}\"  \n"
                tokenized_output += "]"
                
                st.markdown("**Tokenized Output:**")
                st.markdown(f"```\n{tokenized_output}\n```", unsafe_allow_html=True)
                
            except Exception as e:
                st.error(f"Error in tokenization: {str(e)}")

if __name__ == "__main__":
    main()
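
For reference, here is a minimal standalone sketch of what the new tokenization tab does, outside of Streamlit. The checkpoint `bert-base-uncased` is only an assumption for illustration; in the app the checkpoint comes from `MODELS[selected_model]["model_name"]`.

```python
# Minimal sketch of the tokenization step behind the new tab.
# Assumption: "bert-base-uncased" stands in for MODELS[selected_model]["model_name"].
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
tokens = tokenizer.tokenize("My name is Sadia!")

# Reproduce the same "index : token" listing the tab renders.
lines = [f'{idx} : "{token}"' for idx, token in enumerate(tokens)]
print("[\n  " + "\n  ".join(lines) + "\n]")
```

Note that `AutoTokenizer.from_pretrained` downloads and caches the tokenizer files on first use, so the first click on "Tokenize" for a given model may be noticeably slower than subsequent clicks.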