iisadia committed on
Commit
c443e62
·
verified ·
1 Parent(s): 172b89c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +125 -0
app.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import matplotlib.pyplot as plt
3
+ import pandas as pd
4
+ import torch
5
+ from transformers import AutoConfig
6
+
7
# --- Streamlit page setup -------------------------------------------------
# Collect the page options in one mapping so they read as configuration data.
_PAGE_OPTIONS = {
    "page_title": "Transformer Visualizer",
    "page_icon": "🧠",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_PAGE_OPTIONS)
14
+
15
# --- Custom CSS styling ---------------------------------------------------
# Dark theme overrides injected once at startup; kept in a module-level
# constant so the markup is separated from the render call.
_CUSTOM_CSS = """
    <style>
    .reportview-container {
        background: linear-gradient(45deg, #1a1a1a, #4a4a4a);
    }
    .sidebar .sidebar-content {
        background: #2c2c2c !important;
    }
    h1, h2, h3, h4, h5, h6 {
        color: #00ff00 !important;
    }
    .stMetric {
        background-color: #333333;
        border-radius: 10px;
        padding: 15px;
    }
    </style>
    """
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
34
+
35
# Model database: display name -> Hugging Face checkpoint plus headline
# architecture stats. "params" is the approximate parameter count in millions.
# Built from flat rows so every entry is guaranteed to carry the same keys.
_MODEL_ROWS = [
    ("BERT", "bert-base-uncased", "Encoder", 12, 12, 109.48),
    ("GPT-2", "gpt2", "Decoder", 12, 12, 117),
    ("T5-Small", "t5-small", "Seq2Seq", 6, 8, 60),
    ("RoBERTa", "roberta-base", "Encoder", 12, 12, 125),
    ("DistilBERT", "distilbert-base-uncased", "Encoder", 6, 12, 66),
    ("ALBERT", "albert-base-v2", "Encoder", 12, 12, 11.8),
    ("ELECTRA", "google/electra-small-discriminator", "Encoder", 12, 12, 13.5),
    ("XLNet", "xlnet-base-cased", "AutoRegressive", 12, 12, 110),
    ("BART", "facebook/bart-base", "Seq2Seq", 6, 16, 139),
    ("DeBERTa", "microsoft/deberta-base", "Encoder", 12, 12, 139),
]
MODELS = {
    name: {
        "model_name": checkpoint,
        "type": arch_type,
        "layers": layers,
        "heads": heads,
        "params": params,
    }
    for name, checkpoint, arch_type, layers, heads, params in _MODEL_ROWS
}
48
+
49
def get_model_config(model_name):
    """Fetch the Hugging Face config object for the registered *model_name*.

    *model_name* is a display key of ``MODELS`` (e.g. ``"BERT"``); the actual
    checkpoint id stored under ``"model_name"`` is what gets downloaded.
    """
    checkpoint = MODELS[model_name]["model_name"]
    return AutoConfig.from_pretrained(checkpoint)
52
+
53
def plot_model_comparison(selected_model):
    """Render a bar chart comparing parameter counts across all models.

    The bar belonging to *selected_model* (a key of ``MODELS``) is
    highlighted in green; all styling matches the app's dark theme.
    """
    model_names = list(MODELS.keys())
    params = [m["params"] for m in MODELS.values()]

    fig, ax = plt.subplots(figsize=(10, 6))
    bars = ax.bar(model_names, params)

    # Highlight the selected model; reuse model_names instead of rebuilding
    # the key list a second time.
    bars[model_names.index(selected_model)].set_color('#00ff00')

    ax.set_ylabel('Parameters (Millions)', color='white')
    ax.set_title('Model Size Comparison', color='white')
    ax.tick_params(axis='x', rotation=45, colors='white')
    ax.tick_params(axis='y', colors='white')
    ax.set_facecolor('#2c2c2c')
    fig.patch.set_facecolor('#2c2c2c')

    st.pyplot(fig)
    # Close the figure explicitly: Streamlit reruns the script on every
    # interaction, and unclosed matplotlib figures accumulate in memory.
    plt.close(fig)
72
+
73
def visualize_attention_patterns():
    """Show a toy 5x5 attention-pattern heatmap (random data, for illustration).

    Note: the data is ``torch.randn`` output, so the image differs on every
    rerun — this is a placeholder visualization, not real model attention.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    data = torch.randn(5, 5)
    ax.imshow(data, cmap='viridis')
    ax.set_title('Attention Patterns Example', color='white')
    ax.set_facecolor('#2c2c2c')
    fig.patch.set_facecolor('#2c2c2c')
    st.pyplot(fig)
    # Close the figure so repeated Streamlit reruns don't leak figures.
    plt.close(fig)
82
+
83
def main():
    """App entry point: model selector, headline metrics, visualization tabs."""
    st.title("🧠 Transformer Model Visualizer")

    # Sidebar model picker drives everything below.
    selected_model = st.sidebar.selectbox("Select Model", list(MODELS.keys()))

    model_info = MODELS[selected_model]
    # Fetched for its download side effect; the config object itself is
    # currently unused in the UI.
    config = get_model_config(selected_model)

    # Headline metrics, one per column.
    metric_rows = [
        ("Model Type", model_info["type"]),
        ("Layers", model_info["layers"]),
        ("Attention Heads", model_info["heads"]),
        ("Parameters", f"{model_info['params']}M"),
    ]
    for column, (label, value) in zip(st.columns(4), metric_rows):
        with column:
            st.metric(label, value)

    tab1, tab2, tab3 = st.tabs(["Model Structure", "Comparison", "Model Specific"])

    with tab1:
        st.subheader("Architecture Diagram")
        st.image(
            "https://upload.wikimedia.org/wikipedia/commons/thumb/8/8a/Transformer_model.svg/1200px-Transformer_model.svg.png",
            use_container_width=True,
        )

    with tab2:
        st.subheader("Model Size Comparison")
        plot_model_comparison(selected_model)

    with tab3:
        st.subheader("Model-specific Visualizations")
        visualize_attention_patterns()
        if selected_model == "BERT":
            st.write("BERT-specific visualization example")
        elif selected_model == "GPT-2":
            st.write("GPT-2 attention mask visualization")


if __name__ == "__main__":
    main()