iisadia committed on
Commit 6906b73 · verified · 1 Parent(s): 9410bc8

Update app.py

Files changed (1)
  1. app.py +83 -10
app.py CHANGED
@@ -29,6 +29,15 @@ st.markdown("""
         border-radius: 10px;
         padding: 15px;
     }
+    .architecture {
+        font-family: monospace;
+        color: #00ff00;
+        white-space: pre-wrap;
+        background-color: #1a1a1a;
+        padding: 20px;
+        border-radius: 10px;
+        border: 1px solid #00ff00;
+    }
     </style>
     """, unsafe_allow_html=True)
 
@@ -57,7 +66,6 @@ def plot_model_comparison(selected_model):
     fig, ax = plt.subplots(figsize=(10, 6))
     bars = ax.bar(model_names, params)
 
-    # Highlight selected model
     index = list(MODELS.keys()).index(selected_model)
     bars[index].set_color('#00ff00')
 
@@ -70,8 +78,70 @@ def plot_model_comparison(selected_model):
 
     st.pyplot(fig)
 
+def visualize_architecture(model_info):
+    architecture = []
+    model_type = model_info["type"]
+    layers = model_info["layers"]
+    heads = model_info["heads"]
+
+    architecture.append("Input")
+    architecture.append("│")
+    architecture.append("▼")
+
+    if model_type == "Encoder":
+        architecture.append("[Embedding Layer]")
+        for i in range(layers):
+            architecture.extend([
+                f"Encoder Layer {i+1}",
+                "├─ Multi-Head Attention",
+                f"│  └─ {heads} Heads",
+                "├─ Layer Normalization",
+                "└─ Feed Forward Network",
+                "│",
+                "▼"
+            ])
+        architecture.append("[Output]")
+
+    elif model_type == "Decoder":
+        architecture.append("[Embedding Layer]")
+        for i in range(layers):
+            architecture.extend([
+                f"Decoder Layer {i+1}",
+                "├─ Masked Multi-Head Attention",
+                f"│  └─ {heads} Heads",
+                "├─ Layer Normalization",
+                "└─ Feed Forward Network",
+                "│",
+                "▼"
+            ])
+        architecture.append("[Output]")
+
+    elif model_type == "Seq2Seq":
+        architecture.append("Encoder Stack")
+        for i in range(layers):
+            architecture.extend([
+                f"Encoder Layer {i+1}",
+                "├─ Self-Attention",
+                "└─ Feed Forward Network",
+                "│",
+                "▼"
+            ])
+        architecture.append("→→→ [Context] →→→")
+        architecture.append("Decoder Stack")
+        for i in range(layers):
+            architecture.extend([
+                f"Decoder Layer {i+1}",
+                "├─ Masked Self-Attention",
+                "├─ Encoder-Decoder Attention",
+                "└─ Feed Forward Network",
+                "│",
+                "▼"
+            ])
+        architecture.append("[Output]")
+
+    return "\n".join(architecture)
+
 def visualize_attention_patterns():
-    # Simplified attention patterns visualization
     fig, ax = plt.subplots(figsize=(8, 6))
     data = torch.randn(5, 5)
     ax.imshow(data, cmap='viridis')
@@ -83,14 +153,10 @@ def visualize_attention_patterns():
 def main():
     st.title("🧠 Transformer Model Visualizer")
 
-    # Model selection
     selected_model = st.sidebar.selectbox("Select Model", list(MODELS.keys()))
-
-    # Model details
     model_info = MODELS[selected_model]
     config = get_model_config(selected_model)
 
-    # Display metrics
     col1, col2, col3, col4 = st.columns(4)
     with col1:
         st.metric("Model Type", model_info["type"])
@@ -101,13 +167,20 @@ def main():
     with col4:
         st.metric("Parameters", f"{model_info['params']}M")
 
-    # Visualization tabs
-    tab1, tab2, tab3 = st.tabs(["Model Structure", "Comparison", "Model Specific"])
+    tab1, tab2, tab3 = st.tabs(["Model Structure", "Comparison", "Model Attention"])
 
     with tab1:
         st.subheader("Architecture Diagram")
-        st.image("https://jalammar.github.io/images/t/transformer.png",
-                 use_container_width=True)
+        architecture = visualize_architecture(model_info)
+        st.markdown(f"<div class='architecture'>{architecture}</div>", unsafe_allow_html=True)
+
+        st.markdown("""
+        **Legend:**
+        - **Multi-Head Attention**: Self-attention mechanism with multiple parallel heads
+        - **Layer Normalization**: Normalization operation between layers
+        - **Feed Forward Network**: Position-wise fully connected network
+        - **Masked Attention**: Attention with future token masking
+        """)
 
     with tab2:
         st.subheader("Model Size Comparison")
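
Note that the new visualize_architecture helper relies on a MODELS registry (and a get_model_config helper) defined elsewhere in app.py; neither appears in these hunks. A minimal sketch of the dict shape the diff assumes, with hypothetical entry names and illustrative values rather than anything taken from the commit:

# Hypothetical stand-in for the MODELS registry referenced by the diff;
# the real app.py defines its own entries with the same keys:
# "type", "layers", "heads", and "params".
MODELS = {
    "BERT":  {"type": "Encoder", "layers": 12, "heads": 12, "params": 110},
    "GPT-2": {"type": "Decoder", "layers": 12, "heads": 12, "params": 117},
    "T5":    {"type": "Seq2Seq", "layers": 12, "heads": 12, "params": 220},
}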
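
Tracing the helper on a hypothetical two-layer encoder entry shows the kind of diagram the Model Structure tab now renders instead of the remote transformer.png image:

# Hypothetical entry, chosen small so the output stays short.
print(visualize_architecture({"type": "Encoder", "layers": 2, "heads": 8}))

Input
│
▼
[Embedding Layer]
Encoder Layer 1
├─ Multi-Head Attention
│  └─ 8 Heads
├─ Layer Normalization
└─ Feed Forward Network
│
▼
Encoder Layer 2
├─ Multi-Head Attention
│  └─ 8 Heads
├─ Layer Normalization
└─ Feed Forward Network
│
▼
[Output]

The white-space: pre-wrap rule added in the first hunk is what preserves these line breaks when the joined string is injected as raw HTML via st.markdown(..., unsafe_allow_html=True).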