initial commit with Panel app
- Dockerfile +1 -1
- app.py +229 -146
- app_components/__pycache__/canvas.cpython-313.pyc +0 -0
- app_components/__pycache__/plots.cpython-313.pyc +0 -0
- app_components/canvas.py +84 -0
- app_components/plots.py +225 -0
- app_utils/__pycache__/styles.cpython-313.pyc +0 -0
- app_utils/styles.py +53 -0
- assets/github-mark-white.png +0 -0
- model_training/__pycache__/data_setup.cpython-313.pyc +0 -0
- model_training/__pycache__/model.cpython-313.pyc +0 -0
- model_training/__pycache__/utils.cpython-313.pyc +0 -0
- model_training/args.txt +12 -0
- model_training/data_setup.py +151 -0
- model_training/engine.py +195 -0
- model_training/model.py +122 -0
- model_training/run_training.py +90 -0
- model_training/utils.py +64 -0
- requirements.txt +7 -6
- saved_models/tiny_vgg_less_compute_model.pth +3 -0
- saved_models/tiny_vgg_less_compute_settings.yaml +13 -0
- saved_models/tiny_vgg_model.pth +3 -0
- saved_models/tiny_vgg_settings.yaml +13 -0
Dockerfile
CHANGED
@@ -8,7 +8,7 @@ RUN python3 -m pip install --no-cache-dir --upgrade -r /code/requirements.txt
 
 COPY . .
 
-CMD ["panel", "serve", "/code/app.py", "--address", "0.0.0.0", "--port", "7860", "--allow-websocket-origin", "*"]
+CMD ["panel", "serve", "/code/app.py", "--address", "0.0.0.0", "--port", "7860", "--allow-websocket-origin", "*", "--num-procs", "2", "--num_threads", "4"]
 
 RUN mkdir /.cache
 RUN chmod 777 /.cache
app.py
CHANGED
@@ -1,147 +1,230 @@
(The previous 147-line version of app.py is replaced in full; apart from stray fragments such as `import aiohttp`, its contents are not shown in this view. The new 230-line version follows.)

#####################################
# Packages & Dependencies
#####################################
import panel as pn
import os, yaml
from panel.viewable import Viewer

from app_components import canvas, plots
from app_utils import styles

pn.extension('plotly')
FILE_PATH = os.path.dirname(__file__)


################################################
# Digit Classifier Layout
################################################
class DigitClassifier(Viewer):
    '''
    Builds and displays the UI for the classifier application.

    Args:
        mod_path (str): The absolute path to the saved TinyVGG model.
        mod_kwargs (dict): A dictionary containing the keyword-arguments for the TinyVGG model.
            This should have the keys: num_blks, num_convs, in_channels, hidden_channels, and num_classes.
    '''

    def __init__(self, mod_path: str, mod_kwargs: dict, **params):
        self.canvas = canvas.Canvas(sizing_mode = 'stretch_both',
                                    styles = {'border':'black solid 0.15rem'})

        self.clear_btn = pn.widgets.Button(name = 'Clear',
                                           sizing_mode = 'stretch_width',
                                           stylesheets = [styles.BTN_STYLESHEET])

        self.plot_panels = plots.PlotPanels(canvas_info = self.canvas, mod_path = mod_path, mod_kwargs = mod_kwargs)

        super().__init__(**params)
        self.github_logo = pn.pane.PNG(
            object = FILE_PATH + '/assets/github-mark-white.png',
            alt_text = 'GitHub Repo',
            link_url = 'https://github.com/Jechen00/digit-classifier-app',
            height = 70,
            styles = {'margin':'0'}
        )
        self.controls_col = pn.FlexBox(
            self.github_logo,
            self.clear_btn,
            self.plot_panels.pred_txt,
            gap = '60px',
            flex_direction = 'column',
            justify_content = 'center',
            align_items = 'center',
            flex_wrap = 'nowrap',
            styles = {'width':'40%', 'height':'100%'}
        )

        self.mod_input_txt = pn.pane.HTML(
            object = '''
            <div>
                <b>MODEL INPUT</b>
            </div>
            ''',
            styles = {'margin':'0rem', 'padding-left':'0.15rem', 'color':'white',
                      'font-size':styles.FONTSIZES['mod_input_txt'],
                      'font-family':styles.FONTFAMILY,
                      'position':'absolute', 'z-index':'100'}
        )

        self.img_row = pn.FlexBox(
            self.canvas,
            self.controls_col,
            pn.FlexBox(self.mod_input_txt,
                       self.plot_panels.img_pane,
                       sizing_mode = 'stretch_both',
                       styles = {'border':'solid 0.15rem white'}),
            gap = '1%',
            flex_wrap = 'nowrap',
            flex_direction = 'row',
            justify_content = 'center',
            sizing_mode = 'stretch_width',
            styles = {'height':'60%'}
        )

        self.prob_row = pn.FlexBox(self.plot_panels.prob_pane,
                                   sizing_mode = 'stretch_width',
                                   styles = {'height':'40%',
                                             'border':'solid 0.15rem black'})

        self.page_info = pn.pane.HTML(
            object = f'''
            <style>
                .link {{
                    color: rgb(29, 161, 242);
                    text-decoration: none;
                    transition: text-decoration 0.2s ease;
                }}

                .link:hover {{
                    text-decoration: underline;
                }}
            </style>

            <div style="text-align:center; font-size:{styles.FONTSIZES['sidebar_title']};margin-top:0.2rem">
                <b>Digit Classifier</b>
            </div>

            <div style="padding:0 2.5% 0 2.5%; text-align:left; font-size:{styles.FONTSIZES['sidebar_txt']}; width: 100%;">
                <hr style="height:2px; background-color:rgb(200, 200, 200); border:none; margin-top:0">

                <p style="margin:0">
                    This is a handwritten digit classifier that uses a <i>convolutional neural network (CNN)</i>
                    to make predictions. The architecture of the model is a scaled-down version of
                    the <i>Visual Geometry Group (VGG)</i> architecture from the paper:
                    <a href="https://arxiv.org/pdf/1409.1556"
                       class="link"
                       target="_blank"
                       rel="noopener noreferrer">
                       Very Deep Convolutional Networks for Large-Scale Image Recognition</a>.
                </p>
                </br>
                <p style="margin:0">
                    <b>How To Use:</b> Draw a digit (0-9) on the canvas
                    and the model will produce a prediction for it in real time.
                    Prediction probabilities (or confidences) for each digit are displayed in the bar chart,
                    reflecting the model's softmax output distribution.
                    To the right of the canvas, you'll also find the transformed input image, i.e. the canvas drawing after undergoing
                    <a href="https://paperswithcode.com/dataset/mnist"
                       class="link"
                       target="_blank"
                       rel="noopener noreferrer">
                       MNIST preprocessing</a>.
                    This input image represents what the model receives prior to feature extraction and classification.
                </p>
            </div>
            <div style="margin-left: 5px; margin-top: 72px">
                <a href="https://github.com/Jechen00"
                   class="link"
                   target="blank"
                   rel="noopener noreferrer"
                   style="font-size: {styles.FONTSIZES['made_by_txt']}; color: {styles.CLRS['made_by_txt']};">
                   Made by Jeff Chen
                </a>
            </div>
            ''',
            styles = {'margin':'0rem', 'color': styles.CLRS['sidebar_txt'],
                      'width': '19.7%', 'height': '100%',
                      'font-family': styles.FONTFAMILY,
                      'background-color': styles.CLRS['sidebar'],
                      'overflow-y':'scroll',
                      'border': 'solid 0.15rem black'}
        )

        self.classifier_content = pn.FlexBox(
            self.img_row,
            self.prob_row,
            gap = '0.5%',
            flex_direction = 'column',
            flex_wrap = 'nowrap',
            sizing_mode = 'stretch_height',
            styles = {'width': '80%'}
        )

        self.page_content = pn.FlexBox(
            self.page_info,
            self.classifier_content,
            gap = '0.3%',
            flex_direction = 'row',
            justify_content = 'space-around',
            align_items = 'center',
            flex_wrap = 'nowrap',
            styles = {
                'height':'100%',
                'width':'100vw',
                'padding': '1%',
                'min-width': '1200px',
                'min-height': '600px',
                'max-width': '3600px',
                'max-height': '1800px',
                'background-color': styles.CLRS['page_bg']
            },
        )

        # This is mainly used to ensure there is always a grey background
        self.page_layout = pn.FlexBox(
            self.page_content,
            justify_content = 'center',
            flex_wrap = 'nowrap',
            sizing_mode = 'stretch_both',
            styles = {
                'min-width': 'max-content',
                'background-color': styles.CLRS['page_bg'],
            }
        )
        # Set up on-click event with clear button and the canvas
        self.clear_btn.on_click(self.canvas.toggle_clear)

    def __panel__(self):
        '''
        Returns the main layout of the application to be rendered by Panel.
        '''
        return self.page_layout


def create_app():
    '''
    Creates the application, ensuring that each user gets a different instance of digit_classifier.
    Mostly used to keep things away from a global scope.
    '''
    save_dir = FILE_PATH + '/saved_models'
    base_name = 'tiny_vgg_less_compute'

    mod_path = f'{save_dir}/{base_name}_model.pth' # Path to the saved model state dict
    settings_path = f'{save_dir}/{base_name}_settings.yaml' # Path to the saved model kwargs

    # Load in model kwargs
    with open(settings_path, 'r') as f:
        loaded_settings = yaml.load(f, Loader = yaml.FullLoader)

    mod_kwargs = loaded_settings['mod_kwargs']

    digit_classifier = DigitClassifier(mod_path = mod_path, mod_kwargs = mod_kwargs)
    return digit_classifier

################################################
# Serve App
################################################
# Used to serve with panel serve in command line
create_app().servable(title = 'CNN Digit Classifier')
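For local development without Docker, a minimal sketch (not part of the commit; assumes the pinned dependencies in requirements.txt are installed) that serves the same factory through Panel's Python API instead of the panel serve CLI. The file name serve_local.py is a placeholder:

# serve_local.py -- hedged sketch; pn.serve accepts a callable and builds one instance per session
import panel as pn
from app import create_app

pn.serve(create_app, port = 7860, show = True, title = 'CNN Digit Classifier')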
app_components/__pycache__/canvas.cpython-313.pyc
ADDED
Binary file (2.96 kB)
app_components/__pycache__/plots.cpython-313.pyc
ADDED
Binary file (11.2 kB)
app_components/canvas.py
ADDED
@@ -0,0 +1,84 @@
#####################################
# Packages & Dependencies
#####################################
import param
from panel.reactive import ReactiveHTML


#####################################
# Canvas
#####################################
class Canvas(ReactiveHTML):
    '''
    The HTML canvas panel used for drawing digits (0-9) in the application.
    Reference: https://panel.holoviz.org/how_to/custom_components/examples/canvas_draw.html
    '''
    uri = param.String()
    clear = param.Boolean(default = False)

    _template = '''
    <canvas
      id="canvas"
      style="width: 100%; height: 100%"
      height=400px
      width=400px
      onmousedown="${script('start')}"
      onmousemove="${script('draw')}"
      onmouseup="${script('end')}"
      onmouseleave="${script('end')}">
    </canvas>
    '''

    _scripts = {
        'render': '''
            state.ctx = canvas.getContext('2d');
            state.ctx.fillStyle = '#FFFFFF';
            state.ctx.fillRect(0, 0, canvas.width, canvas.height);
            state.ctx.lineWidth = 30;
            state.ctx.strokeStyle = '#000000';
            state.ctx.lineJoin = 'round';
            state.ctx.lineCap = 'round';

            // Helper to normalize mouse coordinates
            state.getCoords = function(e) {
                const rect = canvas.getBoundingClientRect();
                return {
                    x: (e.clientX - rect.left) * (canvas.width / rect.width),
                    y: (e.clientY - rect.top) * (canvas.height / rect.height)
                };
            };
        ''',

        'start': '''
            if (state.isDrawing) return;
            state.isDrawing = true;
            const pos = state.getCoords(event);
            state.ctx.beginPath();
            state.ctx.moveTo(pos.x, pos.y);
        ''',

        'draw': '''
            if (!state.isDrawing) return;
            const pos = state.getCoords(event);
            state.ctx.lineTo(pos.x, pos.y);
            state.ctx.stroke();
            data.uri = canvas.toDataURL('image/png');
        ''',

        'end': '''
            if (!state.isDrawing) return; // Early return if already not drawing
            state.isDrawing = false;
        ''',

        'clear': '''
            state.ctx.fillStyle = '#FFFFFF';
            state.ctx.fillRect(0, 0, canvas.width, canvas.height);
            data.uri = '';
        '''
    }

    def toggle_clear(self, *event):
        '''
        Toggles the value of self.clear to trigger the JS 'clear' function.
        '''
        self.clear = not self.clear
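A small standalone sketch (not part of the commit) showing how the component's two parameters are used: uri updates on every stroke, and toggle_clear flips clear to fire the JS 'clear' callback:

import panel as pn
from app_components.canvas import Canvas

pn.extension()

drawing = Canvas(width = 400, height = 400, styles = {'border': 'black solid 0.15rem'})

def on_stroke(event):
    # event.new holds the PNG data URI produced by canvas.toDataURL
    print(f'uri updated ({len(event.new)} chars)')

drawing.param.watch(on_stroke, 'uri')

clear_btn = pn.widgets.Button(name = 'Clear')
clear_btn.on_click(drawing.toggle_clear)

pn.Column(drawing, clear_btn).servable()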
app_components/plots.py
ADDED
@@ -0,0 +1,225 @@
#####################################
# Packages & Dependencies
#####################################
import param
import panel as pn

import torch
import numpy as np
import plotly.graph_objects as go

from . import canvas
from app_utils import styles

import sys, os
APP_PATH = os.path.dirname(os.path.dirname(__file__)) # Path to the digit-classifier-app directory
sys.path.append(APP_PATH + '/model_training')

# Imports from model_training
import data_setup, model


#####################################
# Plotly Panels
#####################################
PLOTLY_CONFIGS = {
    'displayModeBar': True, 'displaylogo': False,
    'modeBarButtonsToRemove': ['autoScale', 'lasso', 'select',
                               'toImage', 'pan', 'zoom', 'zoomIn', 'zoomOut']
}

class PlotPanels(param.Parameterized):
    '''
    Contains all Plotly pane objects for the application.
    This includes the probability bar chart and the MNIST preprocessed image heat map.

    Args:
        canvas_info (param.ClassSelector): A Canvas class object to get the data URI of the drawn image.
        mod_path (str): The absolute path to the saved TinyVGG model.
        mod_kwargs (dict): A dictionary containing the keyword-arguments for the TinyVGG model.
            This should have the keys: num_blks, num_convs, in_channels, hidden_channels, and num_classes.
    '''

    canvas_info = param.ClassSelector(class_ = canvas.Canvas) # Canvas object to get the data URI

    def __init__(self, mod_path: str, mod_kwargs: dict, **params):
        super().__init__(**params)
        self.class_labels = np.arange(0, 10)
        self.cnn_mod = model.TinyVGG(**mod_kwargs)
        self.cnn_mod.load_state_dict(torch.load(mod_path, map_location = 'cpu'))

        self.img_pane = pn.pane.Plotly(
            name = 'image_plot',
            config = PLOTLY_CONFIGS,
            sizing_mode = 'stretch_both',
            margin = 0,
        )

        self.prob_pane = pn.pane.Plotly(
            name = 'prob_plot',
            config = PLOTLY_CONFIGS,
            sizing_mode = 'stretch_both',
            margin = 0
        )

        self.pred_txt = pn.pane.HTML(
            styles = {'margin':'0rem', 'color':styles.CLRS['pred_txt'],
                      'font-size':styles.FONTSIZES['pred_txt'],
                      'font-family':styles.FONTFAMILY}
        )

        # Initialize plotly figures
        self._update_prediction()

        # Set up watchers that update based on data URI changes
        self.canvas_info.param.watch(self._update_prediction, 'uri')

    def _update_prediction(self, *event):
        '''
        Performs all prediction-related updates for the application.
        This function is connected to the URI parameter of canvas_info through a watcher.
        Any time the URI changes, a class prediction is immediately made.
        Following this, the probability bar chart and model input heatmap are updated as well.
        '''
        self._update_preprocessed_tensor()
        self._update_pred_txt()
        self._update_img_plot()
        self._update_prob_plot()

    def _update_preprocessed_tensor(self):
        '''
        Transforms the data URI (string) from canvas_info into a preprocessed tensor.
        This is done by having it undergo the MNIST preprocessing pipeline (see mnist_preprocess in data_setup for details).
        Additionally, a prediction is made for the preprocessed tensor to get its class label.
        The corresponding set of prediction probabilities is stored.
        '''
        # Check if uri is non-empty
        if self.canvas_info.uri:
            self.input_img = data_setup.mnist_preprocess(self.canvas_info.uri)

            self.cnn_mod.eval() # Set CNN to eval & inference mode
            with torch.inference_mode():
                pred_logits = self.cnn_mod(self.input_img.unsqueeze(0))
                self.pred_probs = torch.softmax(pred_logits, dim = 1)[0].numpy()
                self.pred_label = np.argmax(self.pred_probs)
        else:
            self.input_img = torch.zeros((28, 28))
            self.pred_probs = np.zeros(10)
            self.pred_label = None

    def _update_pred_txt(self):
        '''
        Updates the prediction and probability HTML text to reflect the current data URI.
        '''
        if self.canvas_info.uri:
            pred, prob = self.pred_label, f'{self.pred_probs[self.pred_label]:.3f}'
        else:
            pred, prob = 'N/A', 'N/A'

        self.pred_txt.object = f'''
        <div style="text-align: left;">
            <b>Prediction:</b> {pred}
            </br>
            <b>Probability:</b> {prob}
        </div>
        '''

    def _update_prob_plot(self):
        '''
        Updates the probability bar chart to showcase the softmax output probability distribution
        obtained from the prediction in _update_preprocessed_tensor.
        '''
        # Marker fill and outline color for bar plot
        mkr_clrs = [styles.CLRS['base_bar']] * len(self.class_labels)
        mkr_line_clrs = [styles.CLRS['base_bar_line']] * len(self.class_labels)
        if self.pred_label is not None:
            mkr_clrs[self.pred_label] = styles.CLRS['pred_bar']
            mkr_line_clrs[self.pred_label] = styles.CLRS['pred_bar_line']

        fig = go.Figure()
        # Bar plot
        fig.add_trace(
            go.Bar(x = self.class_labels, y = self.pred_probs,
                   marker_color = mkr_clrs, marker_line_color = mkr_line_clrs,
                   marker_line_width = 1.5, showlegend = False,
                   text = self.pred_probs, textposition = 'outside',
                   textfont = dict(color = styles.CLRS['plot_txt'],
                                   size = styles.FONTSIZES['plot_bar_txt'], family = styles.FONTFAMILY),
                   texttemplate = '%{text:.3f}',
                   customdata = self.pred_probs * 100,
                   hoverlabel_font = dict(family = styles.FONTFAMILY),
                   hovertemplate = '<b>Class Label:</b> %{x}' +
                                   '<br><b>Probability:</b> %{customdata:.2f} %' +
                                   '<extra></extra>'
            )
        )
        # Used to fix axis limits
        fig.add_trace(
            go.Scatter(
                x = [0.5, 0.5], y = [0.1, 1],
                marker = dict(color = 'rgba(0, 0, 0, 0)', size = 10),
                mode = 'markers',
                hoverinfo = 'skip',
                showlegend = False
            )
        )
        fig.update_yaxes(
            title = dict(text = 'Prediction Probability', standoff = 0,
                         font = dict(color = styles.CLRS['plot_txt'],
                                     size = styles.FONTSIZES['plot_labels'],
                                     family = styles.FONTFAMILY)),
            tickfont = dict(size = styles.FONTSIZES['plot_ticks'],
                            family = styles.FONTFAMILY),
            dtick = 0.1, ticks = 'outside', ticklen = 0,
            gridcolor = styles.CLRS['prob_plot_grid']
        )
        fig.update_xaxes(
            title = dict(text = 'Class Label', standoff = 6,
                         font = dict(color = styles.CLRS['plot_txt'],
                                     size = styles.FONTSIZES['plot_labels'],
                                     family = styles.FONTFAMILY)),
            dtick = 1, tickfont = dict(size = styles.FONTSIZES['plot_ticks'],
                                       family = styles.FONTFAMILY),
        )
        fig.update_layout(
            paper_bgcolor = styles.CLRS['prob_plot_bg'],
            plot_bgcolor = styles.CLRS['prob_plot_bg'],
            margin = dict(l = 60, r = 0, t = 5, b = 45),
        )

        self.prob_pane.object = fig

    def _update_img_plot(self):
        '''
        Updates the heat map to showcase the current model input, i.e. the preprocessed canvas drawing.
        '''
        img_np = self.input_img.squeeze().numpy()

        if self.pred_label is not None:
            zmin, zmax = np.min(img_np), np.max(img_np)
        else:
            zmin, zmax = 0, 1

        fig = go.Figure(
            data = go.Heatmap(
                z = img_np,
                colorscale = 'gray',
                showscale = False,
                zmin = zmin,
                zmax = zmax,
                hoverlabel_font = dict(family = styles.FONTFAMILY),
                hovertemplate = '<b>Pixel Position:</b> (%{x}, %{y})' +
                                '<br><b>Pixel Value:</b> %{z:.3f}' +
                                '<extra></extra>'
            )
        )

        fig.update_yaxes(autorange = 'reversed')
        fig.update_layout(
            plot_bgcolor = styles.CLRS['img_plot_bg'],
            margin = dict(l = 0, r = 0, t = 0, b = 0),
            xaxis = dict(showticklabels = False),
            yaxis = dict(showticklabels = False),
        )

        self.img_pane.object = fig
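A minimal headless sketch (not part of the commit) that drives PlotPanels without a browser by assigning a synthetic data URI to the Canvas parameter. The working directory (repository root) and use of the less-compute checkpoint are assumptions:

import base64, io, yaml
from PIL import Image, ImageDraw

from app_components import canvas, plots

# Assumed layout: run from the repository root so saved_models/ and app_utils/ resolve
with open('saved_models/tiny_vgg_less_compute_settings.yaml') as f:
    mod_kwargs = yaml.safe_load(f)['mod_kwargs']

drawing = canvas.Canvas()
panels = plots.PlotPanels(canvas_info = drawing,
                          mod_path = 'saved_models/tiny_vgg_less_compute_model.pth',
                          mod_kwargs = mod_kwargs)

# Fake a canvas stroke: white background, thick black line, encoded as a PNG data URI
img = Image.new('L', (400, 400), 255)
ImageDraw.Draw(img).line([(120, 300), (280, 100)], fill = 0, width = 30)
buf = io.BytesIO()
img.save(buf, format = 'PNG')
drawing.uri = 'data:image/png;base64,' + base64.b64encode(buf.getvalue()).decode()

print(panels.pred_label, panels.pred_probs.round(3))  # watcher fired -> prediction updated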
app_utils/__pycache__/styles.cpython-313.pyc
ADDED
Binary file (1.31 kB)
app_utils/styles.py
ADDED
@@ -0,0 +1,53 @@
#####################################
# Fonts & Colors
#####################################
FONTFAMILY = 'Helvetica'

FONTSIZES = {
    'pred_txt': '1.2rem',
    'mod_input_txt': '0.8rem',
    'plot_ticks': 14,
    'plot_labels': 16,
    'plot_bar_txt': 14,
    'btn': '1rem',
    'sidebar_txt': '0.95rem',
    'sidebar_title': '1.8rem',
    'made_by_txt': '0.75rem'
}

CLRS = {
    'pred_txt': 'white',
    'sidebar': 'white',
    'sidebar_txt': 'black',
    'base_bar': 'rgb(158, 202, 225)',
    'base_bar_line': 'rgb(8, 48, 107)',
    'pred_bar': 'rgb(240, 140, 140)',
    'pred_bar_line': 'rgb(180, 0, 0)',
    'plot_txt': 'black',
    'prob_plot_bg': 'white',
    'prob_plot_grid': 'rgb(225, 225, 225)',
    'img_plot_bg': 'black',
    'btn_base': 'white',
    'btn_hover': 'rgb(200, 200, 200)',
    'page_bg': 'rgb(150, 150, 150)',
    'made_by_txt': 'rgb(180, 180, 180)'
}


#####################################
# Stylesheets
#####################################
BTN_STYLESHEET = f'''
:host(.solid) .bk-btn {{
    background-color: {CLRS['btn_base']};
    border: black solid 0.1rem;
    border-radius: 0.8rem;
    font-size: {FONTSIZES['btn']};
    padding-top: 0.3rem;
    padding-bottom: 0.3rem;
}}

:host(.solid) .bk-btn:hover {{
    background-color: {CLRS['btn_hover']};
}}
'''
assets/github-mark-white.png
ADDED
model_training/__pycache__/data_setup.cpython-313.pyc
ADDED
Binary file (6.32 kB)
model_training/__pycache__/model.cpython-313.pyc
ADDED
Binary file (6.18 kB)
model_training/__pycache__/utils.cpython-313.pyc
ADDED
Binary file (2.55 kB)
model_training/args.txt
ADDED
@@ -0,0 +1,12 @@
--num-workers
0
--num-epochs
25
--batch-size
100
--learning-rate
0.001
--patience
10
--min-delta
0.001
model_training/data_setup.py
ADDED
@@ -0,0 +1,151 @@
#####################################
# Packages & Dependencies
#####################################
from torchvision import transforms, datasets
from torch.utils.data import DataLoader

import utils
from typing import Tuple

import io
import base64
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

# Transformations applied to each image
BASE_TRANSFORMS = transforms.Compose([
    transforms.ToTensor(), # Convert to tensor and rescale pixel values to within [0, 1]
    transforms.Normalize(mean = [0.1307], std = [0.3081]) # Normalize with MNIST stats
])

TRAIN_TRANSFORMS = transforms.Compose([
    transforms.RandomAffine(degrees = 15, # Rotate up to -/+ 15 degrees
                            scale = (0.8, 1.2), # Scale between 80 and 120 percent
                            translate = (0.08, 0.08), # Translate up to -/+ 8 percent in both x and y
                            shear = 10), # Shear up to -/+ 10 degrees
    transforms.ToTensor(), # Convert to tensor and rescale pixel values to within [0, 1]
    transforms.Normalize(mean = [0.1307], std = [0.3081]), # Normalize with MNIST stats
])


#####################################
# Functions
#####################################
def get_dataloaders(root: str,
                    batch_size: int,
                    num_workers: int = 0) -> Tuple[DataLoader, DataLoader]:
    '''
    Creates training and testing dataloaders for the MNIST dataset.

    Args:
        root (str): Path to download MNIST data.
        batch_size (int): Size used to split training and testing datasets into batches.
        num_workers (int): Number of workers to use for multiprocessing. Default is 0.
    '''

    # Get training and testing MNIST data
    mnist_train = datasets.MNIST(root, download = True, train = True,
                                 transform = TRAIN_TRANSFORMS)
    mnist_test = datasets.MNIST(root, download = True, train = False,
                                transform = BASE_TRANSFORMS)

    # Create dataloaders
    if num_workers > 0:
        mp_context = utils.MP_CONTEXT
    else:
        mp_context = None

    train_dl = DataLoader(
        dataset = mnist_train,
        batch_size = batch_size,
        shuffle = True,
        num_workers = num_workers,
        multiprocessing_context = mp_context,
        pin_memory = True
    )

    test_dl = DataLoader(
        dataset = mnist_test,
        batch_size = batch_size,
        shuffle = False,
        num_workers = num_workers,
        multiprocessing_context = mp_context,
        pin_memory = True
    )

    return train_dl, test_dl

def mnist_preprocess(uri: str, plot: bool = False):
    '''
    Preprocesses a data URI representing a handwritten digit image according to the pipeline used in the MNIST dataset.
    The pipeline includes:
        1. Converting the image to grayscale.
        2. Resizing the image to 20x20, preserving the aspect ratio, and using anti-aliasing.
        3. Centering the resized image in a 28x28 image based on the center of mass (COM).
        4. Converting the image to a tensor (pixel values between 0 and 1) and normalizing it using MNIST statistics.

    Reference: https://paperswithcode.com/dataset/mnist

    Args:
        uri (str): A string representing the full data URI.
        plot (bool, optional): If True, the resized 20x20 image is plotted alongside the final 28x28 image (pre-normalization).
            The red lines on these plots intersect at the COM position. Default is False.
    Returns:
        Tensor: A tensor of shape (1, 28, 28) representing the preprocessed image, normalized using MNIST statistics.
    '''
    encoded_img = uri.split(',', 1)[1]
    image_bytes = io.BytesIO(base64.b64decode(encoded_img))
    pil_img = Image.open(image_bytes).convert('L') # Gray scale

    # Resize to 20x20, preserving aspect ratio, and using anti-aliasing
    pil_img.thumbnail((20, 20), Image.Resampling.LANCZOS)

    # Convert to numpy and invert image
    img = 255 - np.array(pil_img)

    # Get image indices for y-axis (rows) and x-axis (columns)
    img_idxs = np.indices(img.shape)
    tot_mass = img.sum()

    # This represents the indices of the center of masses (COMs)
    com_x = np.round((img_idxs[1] * img).sum() / tot_mass).astype(int)
    com_y = np.round((img_idxs[0] * img).sum() / tot_mass).astype(int)

    dist_com_end_x = img.shape[1] - com_x # number of column indices from com_x to last index
    dist_com_end_y = img.shape[0] - com_y # number of row indices from com_y to last index

    new_img = np.zeros((28, 28), dtype = np.uint8)
    new_com_x, new_com_y = 14, 14 # Indices of the COMs for the new 28x28 image

    valid_start_x = min(new_com_x, com_x)
    valid_end_x = min(14, dist_com_end_x) # 14 is index distance from new COM to 28-th index
    valid_start_y = min(new_com_y, com_y)
    valid_end_y = min(14, dist_com_end_y) # 14 is index distance from new COM to 28-th index

    old_slice_x = slice(com_x - valid_start_x, com_x + valid_end_x)
    old_slice_y = slice(com_y - valid_start_y, com_y + valid_end_y)
    new_slice_x = slice(new_com_x - valid_start_x, new_com_x + valid_end_x)
    new_slice_y = slice(new_com_y - valid_start_y, new_com_y + valid_end_y)

    # Paste cropped image into 28x28 field such that the old COM (com_y, com_x), is at the center (14, 14)
    new_img[new_slice_y, new_slice_x] = img[old_slice_y, old_slice_x]

    if plot:
        fig, axes = plt.subplots(nrows = 1, ncols = 2, figsize = (12, 6))

        axes[0].imshow(img, cmap = 'grey')
        axes[0].axhline(com_y, c = 'red')
        axes[0].axvline(com_x, c = 'red')

        axes[1].imshow(new_img, cmap = 'grey')
        axes[1].axhline(new_com_y, c = 'red')
        axes[1].axvline(new_com_x, c = 'red')

        axes[0].set_title(f'Original Resized {img.shape[0]}x{img.shape[1]} Image')
        axes[1].set_title('New Centered 28x28 Image')

        plt.tight_layout()

    # Return transformed tensor of new image. This includes normalizing to MNIST stats
    return BASE_TRANSFORMS(new_img)
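A minimal sketch (not part of the commit) that exercises mnist_preprocess without the browser canvas. The white-background/black-stroke image mirrors what the Canvas component produces; running from model_training/ so the local utils import resolves is an assumption:

import base64, io
from PIL import Image, ImageDraw

import data_setup  # local module; assumes the working directory is model_training/

img = Image.new('L', (400, 400), color = 255)                             # white 400x400 "canvas"
ImageDraw.Draw(img).line([(150, 80), (170, 320)], fill = 0, width = 30)   # crude hand-drawn "1"

buf = io.BytesIO()
img.save(buf, format = 'PNG')
uri = 'data:image/png;base64,' + base64.b64encode(buf.getvalue()).decode()

tensor = data_setup.mnist_preprocess(uri)
print(tensor.shape)  # torch.Size([1, 28, 28])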
model_training/engine.py
ADDED
@@ -0,0 +1,195 @@
#####################################
# Packages
#####################################
import torch

from typing import Tuple, Dict, List
import utils

#####################################
# Functions
#####################################
def train_step(model: torch.nn.Module,
               dataloader: torch.utils.data.DataLoader,
               loss_fn: torch.nn.Module,
               optimizer: torch.optim.Optimizer,
               device: torch.device) -> Tuple[float, float]:

    '''
    Performs a training step for a PyTorch model.

    Args:
        model (torch.nn.Module): PyTorch model that will be trained.
        dataloader (torch.utils.data.DataLoader): Dataloader containing data to train on.
        loss_fn (torch.nn.Module): Loss function used as the error metric.
        optimizer (torch.optim.Optimizer): Optimization method used to update model parameters per batch.
        device (torch.device): Device to train on.

    Returns:
        train_loss (float): The average loss calculated over the training set.
        train_acc (float): The accuracy calculated over the training set.
    '''

    model.train()
    train_loss = torch.tensor(0.0, device = device)
    train_acc = torch.tensor(0.0, device = device)
    num_samps = len(dataloader.dataset)

    # Loop through all batches in the dataloader
    for X, y in dataloader:

        optimizer.zero_grad() # Clear old accumulated gradients

        X, y = X.to(device), y.to(device)

        y_logits = model(X) # Get logits

        loss = loss_fn(y_logits, y)
        train_loss += loss.detach() * X.shape[0] # Calculate total loss for batch

        loss.backward() # Perform backpropagation
        optimizer.step() # Update parameters

        y_pred = y_logits.argmax(dim = 1) # No softmax needed for argmax (b/c preserves order)

        train_acc += (y_pred == y).sum() # Calculate total accuracy for batch

    # Get average loss and accuracy per sample
    train_loss = train_loss.item() / num_samps
    train_acc = train_acc.item() / num_samps

    return train_loss, train_acc


def test_step(model: torch.nn.Module,
              dataloader: torch.utils.data.DataLoader,
              loss_fn: torch.nn.Module,
              device: torch.device) -> Tuple[float, float]:

    '''
    Performs a testing step for a PyTorch model.

    Args:
        model (torch.nn.Module): PyTorch model that will be tested.
        dataloader (torch.utils.data.DataLoader): Dataloader containing data to test on.
        loss_fn (torch.nn.Module): Loss function used as the error metric.
        device (torch.device): Device to compute on.

    Returns:
        test_loss (float): The average loss calculated over batches.
        test_acc (float): The average accuracy calculated over batches.
    '''

    model.eval()
    test_loss = torch.tensor(0.0, device = device)
    test_acc = torch.tensor(0.0, device = device)
    num_samps = len(dataloader.dataset)

    with torch.inference_mode():
        # Loop through all batches in the dataloader
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)

            y_logits = model(X) # Get logits

            test_loss += loss_fn(y_logits, y) * X.shape[0] # Calculate total loss for batch

            y_pred = y_logits.argmax(dim = 1) # No softmax needed for argmax (b/c preserves order)

            test_acc += (y_pred == y).sum() # Calculate total accuracy for batch

    # Get average loss and accuracy
    test_loss = test_loss.item() / num_samps
    test_acc = test_acc.item() / num_samps

    return test_loss, test_acc


def train(model: torch.nn.Module,
          train_dl: torch.utils.data.DataLoader,
          test_dl: torch.utils.data.DataLoader,
          loss_fn: torch.nn.Module,
          optimizer: torch.optim.Optimizer,
          num_epochs: int,
          patience: int,
          min_delta: float,
          device: torch.device,
          save_mod: bool = True,
          save_dir: str = '',
          mod_name: str = '') -> Dict[str, List[float]]:
    '''
    Performs the training and testing steps for a PyTorch model,
    with early stopping applied for test loss.

    Args:
        model (torch.nn.Module): PyTorch model to train.
        train_dl (torch.utils.data.DataLoader): DataLoader for training.
        test_dl (torch.utils.data.DataLoader): DataLoader for testing.
        loss_fn (torch.nn.Module): Loss function used as the error metric.
        optimizer (torch.optim.Optimizer): Optimizer used to update model parameters per batch.

        num_epochs (int): Max number of epochs to train.
        patience (int): Number of epochs to wait before early stopping.
        min_delta (float): Minimum decrease in loss to reset counter.

        device (torch.device): Device to train on.
        save_mod (bool, optional): If True, saves the model whenever the test loss adequately improves. Default is True.
        save_dir (str, optional): Directory to save the model to. Must be nonempty if save_mod is True.
        mod_name (str, optional): Filename for the saved model. Must be nonempty if save_mod is True.

    Returns:
        res (dict): A results dictionary containing lists of train and test metrics for each epoch.
    '''

    bold_start, bold_end = '\033[1m', '\033[0m'

    if save_mod:
        assert save_dir, 'save_dir cannot be None or empty.'
        assert mod_name, 'mod_name cannot be None or empty.'

    # Initialize results dictionary
    res = {'train_loss': [],
           'train_acc': [],
           'test_loss': [],
           'test_acc': []
    }

    # Initialize best_loss and counter for early stopping
    best_loss, counter = None, 0

    for epoch in range(num_epochs):
        # Perform training and testing step
        train_loss, train_acc = train_step(model, train_dl, loss_fn, optimizer, device)
        test_loss, test_acc = test_step(model, test_dl, loss_fn, device)

        # Store loss and accuracy values
        res['train_loss'].append(train_loss)
        res['train_acc'].append(train_acc)
        res['test_loss'].append(test_loss)
        res['test_acc'].append(test_acc)

        print(f'Epoch: {epoch + 1} | ' +
              f'train_loss = {train_loss:.4f} | train_acc = {train_acc:.4f} | ' +
              f'test_loss = {test_loss:.4f} | test_acc = {test_acc:.4f}')

        # Check for improvement
        if best_loss is None:
            best_loss = test_loss
            if save_mod:
                utils.save_model(model, save_dir, mod_name)

        elif test_loss < best_loss - min_delta:
            best_loss = test_loss
            counter = 0

            if save_mod:
                utils.save_model(model, save_dir, mod_name)
                print(f'{bold_start}[SAVED]{bold_end} Adequate improvement in test loss; model saved.')

        else:
            counter += 1
            if counter > patience:
                print(f'{bold_start}[ALERT]{bold_end} No improvement in test loss after {counter} epochs; early stopping triggered.')
                break

    return res
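A minimal smoke-test sketch (not part of the commit) for the step functions, using a tiny random dataset and a throwaway linear classifier in place of MNIST and TinyVGG; it assumes the working directory is model_training/ so the local module imports resolve:

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

import engine  # local module; assumes the working directory is model_training/

# 64 fake grayscale 28x28 images with random labels 0-9
X = torch.randn(64, 1, 28, 28)
y = torch.randint(0, 10, (64,))
dl = DataLoader(TensorDataset(X, y), batch_size = 16)

net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))  # stand-in model
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr = 1e-3)

device = torch.device('cpu')
loss, acc = engine.train_step(net, dl, loss_fn, optimizer, device)
print(f'train_step -> loss {loss:.4f}, acc {acc:.4f}')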
model_training/model.py
ADDED
@@ -0,0 +1,122 @@
#####################################
# Packages & Dependencies
#####################################
import torch
from torch import nn


#####################################
# VGG Model Class
#####################################
class VGGBlock(nn.Module):
    '''
    Defines a modified block in the VGG architecture,
    which includes batch normalization between convolutional layers and ReLU activations.

    Reference: https://poloclub.github.io/cnn-explainer/
    Reference: https://d2l.ai/chapter_convolutional-modern/vgg.html

    Args:
        num_convs (int): Number of consecutive convolutional layers + ReLU activations.
        in_channels (int): Number of channels in the input.
        hidden_channels (int): Number of hidden channels between convolutional layers.
        out_channels (int): Number of channels in the output.
    '''
    def __init__(self,
                 num_convs: int,
                 in_channels: int,
                 hidden_channels: int,
                 out_channels: int):
        super().__init__()

        self.layers = []

        for i in range(num_convs):
            conv_in = in_channels if i == 0 else hidden_channels
            conv_out = out_channels if i == num_convs-1 else hidden_channels

            self.layers += [
                nn.Conv2d(conv_in, conv_out, kernel_size = 3, stride = 1, padding = 1),
                nn.BatchNorm2d(conv_out),
                nn.ReLU()
            ]

        self.layers.append(nn.MaxPool2d(kernel_size = 2, stride = 2))

        self.vgg_blk = nn.Sequential(*self.layers)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        '''
        Forward pass of VGG block.

        Args:
            X (torch.Tensor): Input tensor of shape (batch_size, in_channels, height, width)
        Returns:
            torch.Tensor: Output tensor of shape (batch_size, out_channels, new_height, new_width)
        '''

        return self.vgg_blk(X)

class TinyVGG(nn.Module):
    '''
    Creates a simplified version of a VGG model, adapted from
    https://github.com/poloclub/cnn-explainer/blob/master/tiny-vgg/tiny-vgg.py.
    The main difference is that the hidden dimensions and number of convolutional layers
    remain the same across VGG blocks, and the classifier's linear layers output fewer features.

    Args:
        num_blks (int): Number of VGG blocks to put in the model.
        num_convs (int): Number of consecutive convolutional layers + ReLU activations in each VGG block.
        in_channels (int): Number of channels in the input.
        hidden_channels (int): Number of hidden channels between convolutional layers.
        num_classes (int): Number of class labels.

    '''
    def __init__(self,
                 num_blks: int,
                 num_convs: int,
                 in_channels: int,
                 hidden_channels: int,
                 num_classes: int):
        super().__init__()

        self.all_blks = []
        for i in range(num_blks):
            conv_in = in_channels if i == 0 else hidden_channels
            self.all_blks.append(
                VGGBlock(num_convs, conv_in, hidden_channels, hidden_channels)
            )

        self.vgg_body = nn.Sequential(*self.all_blks)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.LazyLinear(4096), nn.ReLU(), nn.Dropout(0.5),
            nn.LazyLinear(2048), nn.ReLU(), nn.Dropout(0.5),
            nn.LazyLinear(num_classes)
        )

        self.vgg_body.apply(self._custom_init)
        self.classifier.apply(self._custom_init)

    def _custom_init(self, module):
        '''
        Initializes convolutional layer weights with the Xavier initialization method.
        Initializes convolutional layer biases to zero.
        '''
        if isinstance(module, nn.Conv2d):
            nn.init.xavier_uniform_(module.weight)
            nn.init.zeros_(module.bias)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        '''
        Forward pass of the TinyVGG model.

        Args:
            X (torch.Tensor): Input tensor of shape (batch_size, in_channels, height, width).

        Returns:
            torch.Tensor: Logits of shape (batch_size, num_classes).
        '''

        X = self.vgg_body(X)
        return self.classifier(X)
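A quick shape check (not part of the commit) using the same kwargs stored in saved_models/tiny_vgg_settings.yaml; the LazyLinear layers materialize their weights on the first forward pass:

import torch
import model  # local module; assumes the working directory is model_training/

mod_kwargs = {'num_blks': 2, 'num_convs': 2, 'in_channels': 1,
              'hidden_channels': 10, 'num_classes': 10}

net = model.TinyVGG(**mod_kwargs)
logits = net(torch.randn(8, 1, 28, 28))  # batch of 8 fake 28x28 grayscale digits
print(logits.shape)  # torch.Size([8, 10])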
model_training/run_training.py
ADDED
@@ -0,0 +1,90 @@
#####################################
# Packages & Dependencies
#####################################
import argparse
import torch
from torch import nn

import utils, data_setup, model, engine
import yaml

# Setup random seeds
utils.set_seed(0)

# Setup hyperparameters
parser = argparse.ArgumentParser(fromfile_prefix_chars = '@')

parser.add_argument('-nw', '--num-workers', help = 'Number of workers for dataloaders.',
                    type = int, default = 0)
parser.add_argument('-ne', '--num-epochs', help = 'Number of epochs to train model for.',
                    type = int, default = 15)
parser.add_argument('-bs', '--batch-size', help = 'Size of batches to split training set.',
                    type = int, default = 100)
parser.add_argument('-lr', '--learning-rate', help = 'Learning rate for the optimizer.',
                    type = float, default = 0.001)
parser.add_argument('-p', '--patience', help = 'Number of epochs to wait before early stopping.',
                    type = int, default = 5)
parser.add_argument('-md', '--min-delta', help = 'Minimum decrease in loss to reset patience.',
                    type = float, default = 0.001)

args = parser.parse_args()


#####################################
# Training Code
#####################################
if __name__ == '__main__':

    print(f'{'#' * 50}\n'
          f'\033[1mTraining hyperparameters:\033[0m \n'
          f' - num-workers: {args.num_workers} \n'
          f' - num-epochs: {args.num_epochs} \n'
          f' - batch-size: {args.batch_size} \n'
          f' - learning-rate: {args.learning_rate} \n'
          f' - patience: {args.patience} \n'
          f' - min-delta: {args.min_delta} \n'
          f'{'#' * 50}')

    # Get dataloaders
    train_dl, test_dl = data_setup.get_dataloaders(root = './mnist_data',
                                                   batch_size = args.batch_size,
                                                   num_workers = args.num_workers)

    # Set up saving directory and file name
    save_dir = '../saved_models'

    base_name = 'tiny_vgg'
    mod_name = f'{base_name}_model.pth'

    # Get TinyVGG model
    mod_kwargs = {
        'num_blks': 2,
        'num_convs': 2,
        'in_channels': 1,
        'hidden_channels': 10,
        'num_classes': len(train_dl.dataset.classes)
    }

    vgg_mod = model.TinyVGG(**mod_kwargs).to(utils.DEVICE)

    # Save model kwargs and train settings
    with open(f'{save_dir}/{base_name}_settings.yaml', 'w') as f:
        yaml.dump({'train_kwargs': vars(args), 'mod_kwargs': mod_kwargs}, f)

    # Get loss function and optimizer
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(params = vgg_mod.parameters(), lr = args.learning_rate)

    # Train model
    mod_res = engine.train(model = vgg_mod,
                           train_dl = train_dl,
                           test_dl = test_dl,
                           loss_fn = loss_fn,
                           optimizer = optimizer,
                           num_epochs = args.num_epochs,
                           patience = args.patience,
                           min_delta = args.min_delta,
                           device = utils.DEVICE,
                           save_mod = True,
                           save_dir = save_dir,
                           mod_name = mod_name)
model_training/utils.py
ADDED
@@ -0,0 +1,64 @@
#####################################
# Packages & Dependencies
#####################################
import torch
import random
import numpy as np
import os

# Setup device and multiprocessing context
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
    MP_CONTEXT = None
elif torch.backends.mps.is_available():
    DEVICE = torch.device('mps')
    MP_CONTEXT = 'forkserver'
else:
    DEVICE = torch.device('cpu')
    MP_CONTEXT = None


#####################################
# Functions
#####################################
def set_seed(seed: int = 0):
    '''
    Sets random seed and deterministic settings for reproducibility across:
        - PyTorch
        - NumPy
        - Python's random module

    Args:
        seed (int): The seed value to set.
    '''
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.cuda.manual_seed_all(seed)

    torch.use_deterministic_algorithms(True)

def save_model(model: torch.nn.Module,
               save_dir: str,
               mod_name: str):
    '''
    Saves the state_dict() of a model to the directory 'save_dir'.

    Args:
        model (torch.nn.Module): The PyTorch model whose state dict and keyword arguments will be saved.
        save_dir (str): Directory to save the model to.
        mod_name (str): Filename for the saved model. If this doesn't end with '.pth' or '.pt', it will be added on for the state_dict.

    '''
    # Create directory if it doesn't exist
    os.makedirs(save_dir, exist_ok = True)

    # Add .pth if it is not in mod_name
    if not mod_name.endswith('.pth') and not mod_name.endswith('.pt'):
        mod_name += '.pth'

    # Create save path
    save_path = os.path.join(save_dir, mod_name)

    # Save model's state dict
    torch.save(obj = model.state_dict(), f = save_path)
requirements.txt
CHANGED
@@ -1,6 +1,7 @@
(The previous six pinned requirements are replaced; their contents are not shown in this view.)
+numpy==2.2.4
+matplotlib==3.10.1
+panel==1.4.5
+param==2.1.1
+plotly==6.0.1
+torch==2.6.0
+torchvision==0.21.0
saved_models/tiny_vgg_less_compute_model.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:94a16b55d2a65b58c30bcad6dcee77d7e45e15221577795aa4de97a508fddced
size 38494248
saved_models/tiny_vgg_less_compute_settings.yaml
ADDED
@@ -0,0 +1,13 @@
mod_kwargs:
  hidden_channels: 6
  in_channels: 1
  num_blks: 2
  num_classes: 10
  num_convs: 2
train_kwargs:
  batch_size: 100
  learning_rate: 0.001
  min_delta: 0.0005
  num_epochs: 50
  num_workers: 0
  patience: 10
saved_models/tiny_vgg_model.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:10b1913e0c2c44d5a76624196371ae83381a13a939afe9e9e9146354206997e9
size 41711994
saved_models/tiny_vgg_settings.yaml
ADDED
@@ -0,0 +1,13 @@
mod_kwargs:
  hidden_channels: 10
  in_channels: 1
  num_blks: 2
  num_classes: 10
  num_convs: 2
train_kwargs:
  batch_size: 100
  learning_rate: 0.001
  min_delta: 0.001
  num_epochs: 25
  num_workers: 0
  patience: 5