iisadia commited on
Commit
cf92986
Β·
verified Β·
1 Parent(s): 6cf26d6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +201 -35
app.py CHANGED
@@ -1,10 +1,8 @@
1
- [file name] updated_code.py
2
- [file content]
3
  import streamlit as st
4
  import matplotlib.pyplot as plt
5
  import pandas as pd
6
  import torch
7
- from transformers import AutoConfig, AutoTokenizer # Added AutoTokenizer
8
 
9
  # Page configuration
10
  st.set_page_config(
@@ -14,12 +12,52 @@ st.set_page_config(
14
  initial_sidebar_state="expanded"
15
  )
16
 
17
- # Custom CSS styling (unchanged)
18
- # ... [same CSS styles as original] ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
- # Model database (unchanged)
21
  MODELS = {
22
- # ... [same model database as original] ...
 
 
 
 
 
 
 
 
 
23
  }
24
 
25
  def get_model_config(model_name):
@@ -27,13 +65,95 @@ def get_model_config(model_name):
27
  return config
28
 
29
  def plot_model_comparison(selected_model):
30
- # ... [same comparison function as original] ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  def visualize_architecture(model_info):
33
- # ... [same architecture function as original] ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  def visualize_attention_patterns():
36
- # ... [same attention patterns function as original] ...
 
 
 
 
 
 
37
 
38
  def main():
39
  st.title("🧠 Transformer Model Visualizer")
@@ -41,38 +161,84 @@ def main():
41
  selected_model = st.sidebar.selectbox("Select Model", list(MODELS.keys()))
42
  model_info = MODELS[selected_model]
43
  config = get_model_config(selected_model)
 
44
 
45
- # Metrics columns (unchanged)
46
  col1, col2, col3, col4 = st.columns(4)
47
- # ... [same metrics code as original] ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
- # Added 4th tab
50
- tab1, tab2, tab3, tab4 = st.tabs(["Model Structure", "Comparison", "Model Attention", "Model Tokenization"])
 
51
 
52
- # Existing tabs (unchanged)
53
- # ... [same tab1, tab2, tab3 code as original] ...
 
 
 
 
 
54
 
55
- # New Tokenization Tab
56
  with tab4:
57
- st.subheader("Text Tokenization")
58
- user_input = st.text_input("Enter Text:", value="My name is Sadia!", key="tokenizer_input")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
- if st.button("Tokenize", key="tokenize_button"):
61
- try:
62
- tokenizer = AutoTokenizer.from_pretrained(MODELS[selected_model]["model_name"])
63
- tokens = tokenizer.tokenize(user_input)
64
-
65
- # Format output similar to reference image
66
- tokenized_output = "- [ \n"
67
- for idx, token in enumerate(tokens):
68
- tokenized_output += f" {idx} : \"{token}\" \n"
69
- tokenized_output += "]"
70
-
71
- st.markdown("**Tokenized Output:**")
72
- st.markdown(f"```\n{tokenized_output}\n```", unsafe_allow_html=True)
73
-
74
- except Exception as e:
75
- st.error(f"Error in tokenization: {str(e)}")
76
 
77
  if __name__ == "__main__":
78
  main()
 
 
 
1
  import streamlit as st
2
  import matplotlib.pyplot as plt
3
  import pandas as pd
4
  import torch
5
+ from transformers import AutoConfig, AutoTokenizer
6
 
7
  # Page configuration
8
  st.set_page_config(
 
12
  initial_sidebar_state="expanded"
13
  )
14
 
15
+ # Custom CSS styling
16
+ st.markdown("""
17
+ <style>
18
+ .reportview-container {
19
+ background: linear-gradient(45deg, #1a1a1a, #4a4a4a);
20
+ }
21
+ .sidebar .sidebar-content {
22
+ background: #2c2c2c !important;
23
+ }
24
+ h1, h2, h3, h4, h5, h6 {
25
+ color: #00ff00 !important;
26
+ }
27
+ .stMetric {
28
+ background-color: #333333;
29
+ border-radius: 10px;
30
+ padding: 15px;
31
+ }
32
+ .architecture {
33
+ font-family: monospace;
34
+ color: #00ff00;
35
+ white-space: pre-wrap;
36
+ background-color: #1a1a1a;
37
+ padding: 20px;
38
+ border-radius: 10px;
39
+ border: 1px solid #00ff00;
40
+ }
41
+ .token-table {
42
+ margin-top: 20px;
43
+ border: 1px solid #00ff00;
44
+ border-radius: 5px;
45
+ }
46
+ </style>
47
+ """, unsafe_allow_html=True)
48
 
49
+ # Model database
50
  MODELS = {
51
+ "BERT": {"model_name": "bert-base-uncased", "type": "Encoder", "layers": 12, "heads": 12, "params": 109.48},
52
+ "GPT-2": {"model_name": "gpt2", "type": "Decoder", "layers": 12, "heads": 12, "params": 117},
53
+ "T5-Small": {"model_name": "t5-small", "type": "Seq2Seq", "layers": 6, "heads": 8, "params": 60},
54
+ "RoBERTa": {"model_name": "roberta-base", "type": "Encoder", "layers": 12, "heads": 12, "params": 125},
55
+ "DistilBERT": {"model_name": "distilbert-base-uncased", "type": "Encoder", "layers": 6, "heads": 12, "params": 66},
56
+ "ALBERT": {"model_name": "albert-base-v2", "type": "Encoder", "layers": 12, "heads": 12, "params": 11.8},
57
+ "ELECTRA": {"model_name": "google/electra-small-discriminator", "type": "Encoder", "layers": 12, "heads": 12, "params": 13.5},
58
+ "XLNet": {"model_name": "xlnet-base-cased", "type": "AutoRegressive", "layers": 12, "heads": 12, "params": 110},
59
+ "BART": {"model_name": "facebook/bart-base", "type": "Seq2Seq", "layers": 6, "heads": 16, "params": 139},
60
+ "DeBERTa": {"model_name": "microsoft/deberta-base", "type": "Encoder", "layers": 12, "heads": 12, "params": 139}
61
  }
62
 
63
  def get_model_config(model_name):
 
65
  return config
66
 
67
  def plot_model_comparison(selected_model):
68
+ model_names = list(MODELS.keys())
69
+ params = [m["params"] for m in MODELS.values()]
70
+
71
+ fig, ax = plt.subplots(figsize=(10, 6))
72
+ bars = ax.bar(model_names, params)
73
+
74
+ index = list(MODELS.keys()).index(selected_model)
75
+ bars[index].set_color('#00ff00')
76
+
77
+ ax.set_ylabel('Parameters (Millions)', color='white')
78
+ ax.set_title('Model Size Comparison', color='white')
79
+ ax.tick_params(axis='x', rotation=45, colors='white')
80
+ ax.tick_params(axis='y', colors='white')
81
+ ax.set_facecolor('#2c2c2c')
82
+ fig.patch.set_facecolor('#2c2c2c')
83
+
84
+ st.pyplot(fig)
85
 
86
  def visualize_architecture(model_info):
87
+ architecture = []
88
+ model_type = model_info["type"]
89
+ layers = model_info["layers"]
90
+ heads = model_info["heads"]
91
+
92
+ architecture.append("Input")
93
+ architecture.append("β”‚")
94
+ architecture.append("β–Ό")
95
+
96
+ if model_type == "Encoder":
97
+ architecture.append("[Embedding Layer]")
98
+ for i in range(layers):
99
+ architecture.extend([
100
+ f"Encoder Layer {i+1}",
101
+ "β”œβ”€ Multi-Head Attention",
102
+ f"β”‚ └─ {heads} Heads",
103
+ "β”œβ”€ Layer Normalization",
104
+ "└─ Feed Forward Network",
105
+ "β”‚",
106
+ "β–Ό"
107
+ ])
108
+ architecture.append("[Output]")
109
+
110
+ elif model_type == "Decoder":
111
+ architecture.append("[Embedding Layer]")
112
+ for i in range(layers):
113
+ architecture.extend([
114
+ f"Decoder Layer {i+1}",
115
+ "β”œβ”€ Masked Multi-Head Attention",
116
+ f"β”‚ └─ {heads} Heads",
117
+ "β”œβ”€ Layer Normalization",
118
+ "└─ Feed Forward Network",
119
+ "β”‚",
120
+ "β–Ό"
121
+ ])
122
+ architecture.append("[Output]")
123
+
124
+ elif model_type == "Seq2Seq":
125
+ architecture.append("Encoder Stack")
126
+ for i in range(layers):
127
+ architecture.extend([
128
+ f"Encoder Layer {i+1}",
129
+ "β”œβ”€ Self-Attention",
130
+ "└─ Feed Forward Network",
131
+ "β”‚",
132
+ "β–Ό"
133
+ ])
134
+ architecture.append("β†’β†’β†’ [Context] β†’β†’β†’")
135
+ architecture.append("Decoder Stack")
136
+ for i in range(layers):
137
+ architecture.extend([
138
+ f"Decoder Layer {i+1}",
139
+ "β”œβ”€ Masked Self-Attention",
140
+ "β”œβ”€ Encoder-Decoder Attention",
141
+ "└─ Feed Forward Network",
142
+ "β”‚",
143
+ "β–Ό"
144
+ ])
145
+ architecture.append("[Output]")
146
+
147
+ return "\n".join(architecture)
148
 
149
  def visualize_attention_patterns():
150
+ fig, ax = plt.subplots(figsize=(8, 6))
151
+ data = torch.randn(5, 5)
152
+ ax.imshow(data, cmap='viridis')
153
+ ax.set_title('Attention Patterns Example', color='white')
154
+ ax.set_facecolor('#2c2c2c')
155
+ fig.patch.set_facecolor('#2c2c2c')
156
+ st.pyplot(fig)
157
 
158
  def main():
159
  st.title("🧠 Transformer Model Visualizer")
 
161
  selected_model = st.sidebar.selectbox("Select Model", list(MODELS.keys()))
162
  model_info = MODELS[selected_model]
163
  config = get_model_config(selected_model)
164
+ tokenizer = AutoTokenizer.from_pretrained(model_info["model_name"])
165
 
 
166
  col1, col2, col3, col4 = st.columns(4)
167
+ with col1:
168
+ st.metric("Model Type", model_info["type"])
169
+ with col2:
170
+ st.metric("Layers", model_info["layers"])
171
+ with col3:
172
+ st.metric("Attention Heads", model_info["heads"])
173
+ with col4:
174
+ st.metric("Parameters", f"{model_info['params']}M")
175
+
176
+ tab1, tab2, tab3, tab4 = st.tabs(["Model Structure", "Comparison", "Model Attention", "Tokenization"])
177
+
178
+ with tab1:
179
+ st.subheader("Architecture Diagram")
180
+ architecture = visualize_architecture(model_info)
181
+ st.markdown(f"<div class='architecture'>{architecture}</div>", unsafe_allow_html=True)
182
+
183
+ st.markdown("""
184
+ **Legend:**
185
+ - **Multi-Head Attention**: Self-attention mechanism with multiple parallel heads
186
+ - **Layer Normalization**: Normalization operation between layers
187
+ - **Feed Forward Network**: Position-wise fully connected network
188
+ - **Masked Attention**: Attention with future token masking
189
+ """)
190
 
191
+ with tab2:
192
+ st.subheader("Model Size Comparison")
193
+ plot_model_comparison(selected_model)
194
 
195
+ with tab3:
196
+ st.subheader("Model-specific Visualizations")
197
+ visualize_attention_patterns()
198
+ if selected_model == "BERT":
199
+ st.write("BERT-specific visualization example")
200
+ elif selected_model == "GPT-2":
201
+ st.write("GPT-2 attention mask visualization")
202
 
 
203
  with tab4:
204
+ st.subheader("πŸ“ Tokenization Visualization")
205
+
206
+ input_text = st.text_input("Enter Text:", "Hello, how are you?")
207
+
208
+ col1, col2 = st.columns(2)
209
+
210
+ with col1:
211
+ st.markdown("**Tokenized Output**")
212
+ tokens = tokenizer.tokenize(input_text)
213
+ st.write(tokens)
214
+
215
+ with col2:
216
+ st.markdown("**Token IDs**")
217
+ encoded_ids = tokenizer.encode(input_text)
218
+ st.write(encoded_ids)
219
+
220
+ st.markdown("**Token-ID Mapping**")
221
+ token_data = pd.DataFrame({
222
+ "Token": tokens,
223
+ "ID": encoded_ids[1:-1] if tokenizer.cls_token else encoded_ids
224
+ })
225
+ st.dataframe(
226
+ token_data,
227
+ height=150,
228
+ use_container_width=True,
229
+ column_config={
230
+ "Token": "Token",
231
+ "ID": {"header": "ID", "help": "Numerical representation of the token"}
232
+ }
233
+ )
234
 
235
+ st.markdown(f"""
236
+ **Tokenizer Info:**
237
+ - Vocabulary size: `{tokenizer.vocab_size}`
238
+ - Special tokens: `{tokenizer.all_special_tokens}`
239
+ - Padding token: `{tokenizer.pad_token}`
240
+ - Max length: `{tokenizer.model_max_length}`
241
+ """)
 
 
 
 
 
 
 
 
 
242
 
243
  if __name__ == "__main__":
244
  main()