iisadia committed on
Commit
6cf26d6
·
verified ·
1 Parent(s): 6906b73

Update app.py

Files changed (1)
  1. app.py +37 -157
app.py CHANGED
@@ -1,8 +1,10 @@
  import streamlit as st
  import matplotlib.pyplot as plt
  import pandas as pd
  import torch
- from transformers import AutoConfig

  # Page configuration
  st.set_page_config(
@@ -12,47 +14,12 @@ st.set_page_config(
      initial_sidebar_state="expanded"
  )

- # Custom CSS styling
- st.markdown("""
- <style>
- .reportview-container {
-     background: linear-gradient(45deg, #1a1a1a, #4a4a4a);
- }
- .sidebar .sidebar-content {
-     background: #2c2c2c !important;
- }
- h1, h2, h3, h4, h5, h6 {
-     color: #00ff00 !important;
- }
- .stMetric {
-     background-color: #333333;
-     border-radius: 10px;
-     padding: 15px;
- }
- .architecture {
-     font-family: monospace;
-     color: #00ff00;
-     white-space: pre-wrap;
-     background-color: #1a1a1a;
-     padding: 20px;
-     border-radius: 10px;
-     border: 1px solid #00ff00;
- }
- </style>
- """, unsafe_allow_html=True)

- # Model database
  MODELS = {
-     "BERT": {"model_name": "bert-base-uncased", "type": "Encoder", "layers": 12, "heads": 12, "params": 109.48},
-     "GPT-2": {"model_name": "gpt2", "type": "Decoder", "layers": 12, "heads": 12, "params": 117},
-     "T5-Small": {"model_name": "t5-small", "type": "Seq2Seq", "layers": 6, "heads": 8, "params": 60},
-     "RoBERTa": {"model_name": "roberta-base", "type": "Encoder", "layers": 12, "heads": 12, "params": 125},
-     "DistilBERT": {"model_name": "distilbert-base-uncased", "type": "Encoder", "layers": 6, "heads": 12, "params": 66},
-     "ALBERT": {"model_name": "albert-base-v2", "type": "Encoder", "layers": 12, "heads": 12, "params": 11.8},
-     "ELECTRA": {"model_name": "google/electra-small-discriminator", "type": "Encoder", "layers": 12, "heads": 12, "params": 13.5},
-     "XLNet": {"model_name": "xlnet-base-cased", "type": "AutoRegressive", "layers": 12, "heads": 12, "params": 110},
-     "BART": {"model_name": "facebook/bart-base", "type": "Seq2Seq", "layers": 6, "heads": 16, "params": 139},
-     "DeBERTa": {"model_name": "microsoft/deberta-base", "type": "Encoder", "layers": 12, "heads": 12, "params": 139}
  }

  def get_model_config(model_name):
@@ -60,95 +27,13 @@ def get_model_config(model_name):
      return config

  def plot_model_comparison(selected_model):
-     model_names = list(MODELS.keys())
-     params = [m["params"] for m in MODELS.values()]
-
-     fig, ax = plt.subplots(figsize=(10, 6))
-     bars = ax.bar(model_names, params)
-
-     index = list(MODELS.keys()).index(selected_model)
-     bars[index].set_color('#00ff00')
-
-     ax.set_ylabel('Parameters (Millions)', color='white')
-     ax.set_title('Model Size Comparison', color='white')
-     ax.tick_params(axis='x', rotation=45, colors='white')
-     ax.tick_params(axis='y', colors='white')
-     ax.set_facecolor('#2c2c2c')
-     fig.patch.set_facecolor('#2c2c2c')
-
-     st.pyplot(fig)

  def visualize_architecture(model_info):
-     architecture = []
-     model_type = model_info["type"]
-     layers = model_info["layers"]
-     heads = model_info["heads"]
-
-     architecture.append("Input")
-     architecture.append("│")
-     architecture.append("▼")
-
-     if model_type == "Encoder":
-         architecture.append("[Embedding Layer]")
-         for i in range(layers):
-             architecture.extend([
-                 f"Encoder Layer {i+1}",
-                 "├─ Multi-Head Attention",
-                 f"│ └─ {heads} Heads",
-                 "├─ Layer Normalization",
-                 "└─ Feed Forward Network",
-                 "│",
-                 "▼"
-             ])
-         architecture.append("[Output]")
-
-     elif model_type == "Decoder":
-         architecture.append("[Embedding Layer]")
-         for i in range(layers):
-             architecture.extend([
-                 f"Decoder Layer {i+1}",
-                 "├─ Masked Multi-Head Attention",
-                 f"│ └─ {heads} Heads",
-                 "├─ Layer Normalization",
-                 "└─ Feed Forward Network",
-                 "│",
-                 "▼"
-             ])
-         architecture.append("[Output]")
-
-     elif model_type == "Seq2Seq":
-         architecture.append("Encoder Stack")
-         for i in range(layers):
-             architecture.extend([
-                 f"Encoder Layer {i+1}",
-                 "├─ Self-Attention",
-                 "└─ Feed Forward Network",
-                 "│",
-                 "▼"
-             ])
-         architecture.append("→→→ [Context] →→→")
-         architecture.append("Decoder Stack")
-         for i in range(layers):
-             architecture.extend([
-                 f"Decoder Layer {i+1}",
-                 "├─ Masked Self-Attention",
-                 "├─ Encoder-Decoder Attention",
-                 "└─ Feed Forward Network",
-                 "│",
-                 "▼"
-             ])
-         architecture.append("[Output]")
-
-     return "\n".join(architecture)

  def visualize_attention_patterns():
-     fig, ax = plt.subplots(figsize=(8, 6))
-     data = torch.randn(5, 5)
-     ax.imshow(data, cmap='viridis')
-     ax.set_title('Attention Patterns Example', color='white')
-     ax.set_facecolor('#2c2c2c')
-     fig.patch.set_facecolor('#2c2c2c')
-     st.pyplot(fig)

  def main():
      st.title("🧠 Transformer Model Visualizer")
@@ -157,42 +42,37 @@ def main():
      model_info = MODELS[selected_model]
      config = get_model_config(selected_model)

      col1, col2, col3, col4 = st.columns(4)
-     with col1:
-         st.metric("Model Type", model_info["type"])
-     with col2:
-         st.metric("Layers", model_info["layers"])
-     with col3:
-         st.metric("Attention Heads", model_info["heads"])
-     with col4:
-         st.metric("Parameters", f"{model_info['params']}M")

-     tab1, tab2, tab3 = st.tabs(["Model Structure", "Comparison", "Model Attention"])

-     with tab1:
-         st.subheader("Architecture Diagram")
-         architecture = visualize_architecture(model_info)
-         st.markdown(f"<div class='architecture'>{architecture}</div>", unsafe_allow_html=True)
-
-         st.markdown("""
-         **Legend:**
-         - **Multi-Head Attention**: Self-attention mechanism with multiple parallel heads
-         - **Layer Normalization**: Normalization operation between layers
-         - **Feed Forward Network**: Position-wise fully connected network
-         - **Masked Attention**: Attention with future token masking
-         """)
-
-     with tab2:
-         st.subheader("Model Size Comparison")
-         plot_model_comparison(selected_model)

-     with tab3:
-         st.subheader("Model-specific Visualizations")
-         visualize_attention_patterns()
-         if selected_model == "BERT":
-             st.write("BERT-specific visualization example")
-         elif selected_model == "GPT-2":
-             st.write("GPT-2 attention mask visualization")

  if __name__ == "__main__":
      main()
 
+ [file name] updated_code.py
+ [file content]
  import streamlit as st
  import matplotlib.pyplot as plt
  import pandas as pd
  import torch
+ from transformers import AutoConfig, AutoTokenizer  # Added AutoTokenizer

  # Page configuration
  st.set_page_config(

      initial_sidebar_state="expanded"
  )

+ # Custom CSS styling (unchanged)
+ # ... [same CSS styles as original] ...

+ # Model database (unchanged)
  MODELS = {
+     # ... [same model database as original] ...
  }

  def get_model_config(model_name):

      return config

  def plot_model_comparison(selected_model):
+     # ... [same comparison function as original] ...

  def visualize_architecture(model_info):
+     # ... [same architecture function as original] ...

  def visualize_attention_patterns():
+     # ... [same attention patterns function as original] ...

  def main():
      st.title("🧠 Transformer Model Visualizer")

      model_info = MODELS[selected_model]
      config = get_model_config(selected_model)

+     # Metrics columns (unchanged)
      col1, col2, col3, col4 = st.columns(4)
+     # ... [same metrics code as original] ...

+     # Added 4th tab
+     tab1, tab2, tab3, tab4 = st.tabs(["Model Structure", "Comparison", "Model Attention", "Model Tokenization"])

+     # Existing tabs (unchanged)
+     # ... [same tab1, tab2, tab3 code as original] ...

+     # New Tokenization Tab
+     with tab4:
+         st.subheader("Text Tokenization")
+         user_input = st.text_input("Enter Text:", value="My name is Sadia!", key="tokenizer_input")
+
+         if st.button("Tokenize", key="tokenize_button"):
+             try:
+                 tokenizer = AutoTokenizer.from_pretrained(MODELS[selected_model]["model_name"])
+                 tokens = tokenizer.tokenize(user_input)
+
+                 # Format output similar to reference image
+                 tokenized_output = "- [ \n"
+                 for idx, token in enumerate(tokens):
+                     tokenized_output += f" {idx} : \"{token}\" \n"
+                 tokenized_output += "]"
+
+                 st.markdown("**Tokenized Output:**")
+                 st.markdown(f"```\n{tokenized_output}\n```", unsafe_allow_html=True)
+
+             except Exception as e:
+                 st.error(f"Error in tokenization: {str(e)}")

  if __name__ == "__main__":
      main()
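
The new "Model Tokenization" tab boils down to loading `AutoTokenizer.from_pretrained` for the selected entry in `MODELS` and listing the output of `tokenizer.tokenize`. Below is a minimal standalone sketch of that same logic outside Streamlit; the model name and sample text are the defaults from the diff, and the tokens shown in the comment are only indicative since they depend on the tokenizer files that get downloaded.

```python
from transformers import AutoTokenizer

# Defaults taken from the committed app.py: the BERT entry in MODELS
# and the example text used in the tokenizer input box.
model_name = "bert-base-uncased"
user_input = "My name is Sadia!"

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokens = tokenizer.tokenize(user_input)

# Rebuild the same bracketed listing that the tab renders via st.markdown
tokenized_output = "- [ \n"
for idx, token in enumerate(tokens):
    tokenized_output += f" {idx} : \"{token}\" \n"
tokenized_output += "]"
print(tokenized_output)
# Roughly:
# - [
#  0 : "my"
#  1 : "name"
#  2 : "is"
#  ...
# ]
```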