Jechen00 committed
Commit 1078e59 · 1 Parent(s): 740510f

initial commit with Panel app
Dockerfile CHANGED
@@ -8,7 +8,7 @@ RUN python3 -m pip install --no-cache-dir --upgrade -r /code/requirements.txt
 
 COPY . .
 
-CMD ["panel", "serve", "/code/app.py", "--address", "0.0.0.0", "--port", "7860", "--allow-websocket-origin", "*"]
+CMD ["panel", "serve", "/code/app.py", "--address", "0.0.0.0", "--port", "7860", "--allow-websocket-origin", "*", "--num-procs", "2", "--num-threads", "4"]
 
 RUN mkdir /.cache
 RUN chmod 777 /.cache
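Serving note: `--num-procs 2` asks Panel to fork two worker processes and `--num-threads 4` gives each a thread pool, so one session's model inference is less likely to block another session's canvas events. For local testing outside Docker, a roughly equivalent invocation (assuming app.py sits in the working directory) would be:

    panel serve app.py --address 0.0.0.0 --port 7860 --allow-websocket-origin "*" --num-procs 2 --num-threads 4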
app.py CHANGED
@@ -1,147 +1,230 @@
-import io
-import random
-from typing import List, Tuple
-
-import aiohttp
+#####################################
+# Packages & Dependencies
+#####################################
  import panel as pn
-from PIL import Image
-from transformers import CLIPModel, CLIPProcessor
-
-pn.extension(design="bootstrap", sizing_mode="stretch_width")
-
-ICON_URLS = {
-    "brand-github": "https://github.com/holoviz/panel",
-    "brand-twitter": "https://twitter.com/Panel_Org",
-    "brand-linkedin": "https://www.linkedin.com/company/panel-org",
-    "message-circle": "https://discourse.holoviz.org/",
-    "brand-discord": "https://discord.gg/AXRHnJU6sP",
-}
-
-
-async def random_url(_):
-    pet = random.choice(["cat", "dog"])
-    api_url = f"https://api.the{pet}api.com/v1/images/search"
-    async with aiohttp.ClientSession() as session:
-        async with session.get(api_url) as resp:
-            return (await resp.json())[0]["url"]
-
-
-@pn.cache
-def load_processor_model(
-    processor_name: str, model_name: str
-) -> Tuple[CLIPProcessor, CLIPModel]:
-    processor = CLIPProcessor.from_pretrained(processor_name)
-    model = CLIPModel.from_pretrained(model_name)
-    return processor, model
-
-
-async def open_image_url(image_url: str) -> Image:
-    async with aiohttp.ClientSession() as session:
-        async with session.get(image_url) as resp:
-            return Image.open(io.BytesIO(await resp.read()))
-
-
-def get_similarity_scores(class_items: List[str], image: Image) -> List[float]:
-    processor, model = load_processor_model(
-        "openai/clip-vit-base-patch32", "openai/clip-vit-base-patch32"
-    )
-    inputs = processor(
-        text=class_items,
-        images=[image],
-        return_tensors="pt",  # pytorch tensors
-    )
-    outputs = model(**inputs)
-    logits_per_image = outputs.logits_per_image
-    class_likelihoods = logits_per_image.softmax(dim=1).detach().numpy()
-    return class_likelihoods[0]
-
-
-async def process_inputs(class_names: List[str], image_url: str):
-    """
-    High level function that takes in the user inputs and returns the
-    classification results as panel objects.
-    """
-    try:
-        main.disabled = True
-        if not image_url:
-            yield "##### ⚠️ Provide an image URL"
-            return
-
-        yield "##### ⚙ Fetching image and running model..."
-        try:
-            pil_img = await open_image_url(image_url)
-            img = pn.pane.Image(pil_img, height=400, align="center")
-        except Exception as e:
-            yield f"##### 😔 Something went wrong, please try a different URL!"
-            return
-
-        class_items = class_names.split(",")
-        class_likelihoods = get_similarity_scores(class_items, pil_img)
-
-        # build the results column
-        results = pn.Column("##### 🎉 Here are the results!", img)
-
-        for class_item, class_likelihood in zip(class_items, class_likelihoods):
-            row_label = pn.widgets.StaticText(
-                name=class_item.strip(), value=f"{class_likelihood:.2%}", align="center"
-            )
-            row_bar = pn.indicators.Progress(
-                value=int(class_likelihood * 100),
-                sizing_mode="stretch_width",
-                bar_color="secondary",
-                margin=(0, 10),
-                design=pn.theme.Material,
-            )
-            results.append(pn.Column(row_label, row_bar))
-        yield results
-    finally:
-        main.disabled = False
-
-
-# create widgets
-randomize_url = pn.widgets.Button(name="Randomize URL", align="end")
-
-image_url = pn.widgets.TextInput(
-    name="Image URL to classify",
-    value=pn.bind(random_url, randomize_url),
-)
-class_names = pn.widgets.TextInput(
-    name="Comma separated class names",
-    placeholder="Enter possible class names, e.g. cat, dog",
-    value="cat, dog, parrot",
-)
-
-input_widgets = pn.Column(
-    "##### 😊 Click randomize or paste a URL to start classifying!",
-    pn.Row(image_url, randomize_url),
-    class_names,
-)
-
-# add interactivity
-interactive_result = pn.panel(
-    pn.bind(process_inputs, image_url=image_url, class_names=class_names),
-    height=600,
-)
-
-# add footer
-footer_row = pn.Row(pn.Spacer(), align="center")
-for icon, url in ICON_URLS.items():
-    href_button = pn.widgets.Button(icon=icon, width=35, height=35)
-    href_button.js_on_click(code=f"window.open('{url}')")
-    footer_row.append(href_button)
-footer_row.append(pn.Spacer())
-
-# create dashboard
-main = pn.WidgetBox(
-    input_widgets,
-    interactive_result,
-    footer_row,
-)
-
-title = "Panel Demo - Image Classification"
-pn.template.BootstrapTemplate(
-    title=title,
-    main=main,
-    main_max_width="min(50%, 698px)",
-    header_background="#F08080",
-).servable(title=title)
+import os, yaml
+from panel.viewable import Viewer
+
+from app_components import canvas, plots
+from app_utils import styles
+
+pn.extension('plotly')
+FILE_PATH = os.path.dirname(__file__)
+
+
+################################################
+# Digit Classifier Layout
+################################################
+class DigitClassifier(Viewer):
+    '''
+    Builds and displays the UI for the classifier application.
+
+    Args:
+        mod_path (str): The absolute path to the saved TinyVGG model.
+        mod_kwargs (dict): A dictionary containing the keyword arguments for the TinyVGG model.
+            This should have the keys: num_blks, num_convs, in_channels, hidden_channels, and num_classes.
+    '''
+
+    def __init__(self, mod_path: str, mod_kwargs: dict, **params):
+        self.canvas = canvas.Canvas(sizing_mode = 'stretch_both',
+                                    styles = {'border':'black solid 0.15rem'})
+
+        self.clear_btn = pn.widgets.Button(name = 'Clear',
+                                           sizing_mode = 'stretch_width',
+                                           stylesheets = [styles.BTN_STYLESHEET])
+
+        self.plot_panels = plots.PlotPanels(canvas_info = self.canvas, mod_path = mod_path, mod_kwargs = mod_kwargs)
+
+        super().__init__(**params)
+        self.github_logo = pn.pane.PNG(
+            object = FILE_PATH + '/assets/github-mark-white.png',
+            alt_text = 'GitHub Repo',
+            link_url = 'https://github.com/Jechen00/digit-classifier-app',
+            height = 70,
+            styles = {'margin':'0'}
+        )
+        self.controls_col = pn.FlexBox(
+            self.github_logo,
+            self.clear_btn,
+            self.plot_panels.pred_txt,
+            gap = '60px',
+            flex_direction = 'column',
+            justify_content = 'center',
+            align_items = 'center',
+            flex_wrap = 'nowrap',
+            styles = {'width':'40%', 'height':'100%'}
+        )
+
+        self.mod_input_txt = pn.pane.HTML(
+            object = '''
+                <div>
+                    <b>MODEL INPUT</b>
+                </div>
+            ''',
+            styles = {'margin':'0rem', 'padding-left':'0.15rem', 'color':'white',
+                      'font-size':styles.FONTSIZES['mod_input_txt'],
+                      'font-family':styles.FONTFAMILY,
+                      'position':'absolute', 'z-index':'100'}
+        )
+
+        self.img_row = pn.FlexBox(
+            self.canvas,
+            self.controls_col,
+            pn.FlexBox(self.mod_input_txt,
+                       self.plot_panels.img_pane,
+                       sizing_mode = 'stretch_both',
+                       styles = {'border':'solid 0.15rem white'}),
+            gap = '1%',
+            flex_wrap = 'nowrap',
+            flex_direction = 'row',
+            justify_content = 'center',
+            sizing_mode = 'stretch_width',
+            styles = {'height':'60%'}
+        )
+
+        self.prob_row = pn.FlexBox(self.plot_panels.prob_pane,
+                                   sizing_mode = 'stretch_width',
+                                   styles = {'height':'40%',
+                                             'border':'solid 0.15rem black'})
+
+        self.page_info = pn.pane.HTML(
+            object = f'''
+                <style>
+                    .link {{
+                        color: rgb(29, 161, 242);
+                        text-decoration: none;
+                        transition: text-decoration 0.2s ease;
+                    }}
+
+                    .link:hover {{
+                        text-decoration: underline;
+                    }}
+                </style>
+
+                <div style="text-align:center; font-size:{styles.FONTSIZES['sidebar_title']}; margin-top:0.2rem">
+                    <b>Digit Classifier</b>
+                </div>
+
+                <div style="padding:0 2.5% 0 2.5%; text-align:left; font-size:{styles.FONTSIZES['sidebar_txt']}; width: 100%;">
+                    <hr style="height:2px; background-color:rgb(200, 200, 200); border:none; margin-top:0">
+
+                    <p style="margin:0">
+                        This is a handwritten digit classifier that uses a <i>convolutional neural network (CNN)</i>
+                        to make predictions. The architecture of the model is a scaled-down version of
+                        the <i>Visual Geometry Group (VGG)</i> architecture from the paper:
+                        <a href="https://arxiv.org/pdf/1409.1556"
+                           class="link"
+                           target="_blank"
+                           rel="noopener noreferrer">
+                           Very Deep Convolutional Networks for Large-Scale Image Recognition</a>.
+                    </p>
+                    <br>
+                    <p style="margin:0">
+                        <b>How To Use:</b> Draw a digit (0-9) on the canvas
+                        and the model will produce a prediction for it in real time.
+                        Prediction probabilities (or confidences) for each digit are displayed in the bar chart,
+                        reflecting the model's softmax output distribution.
+                        To the right of the canvas, you'll also find the transformed input image, i.e. the canvas drawing after undergoing
+                        <a href="https://paperswithcode.com/dataset/mnist"
+                           class="link"
+                           target="_blank"
+                           rel="noopener noreferrer">
+                           MNIST preprocessing</a>.
+                        This input image represents what the model receives prior to feature extraction and classification.
+                    </p>
+                </div>
+                <div style="margin-left: 5px; margin-top: 72px">
+                    <a href="https://github.com/Jechen00"
+                       class="link"
+                       target="_blank"
+                       rel="noopener noreferrer"
+                       style="font-size: {styles.FONTSIZES['made_by_txt']}; color: {styles.CLRS['made_by_txt']};">
+                       Made by Jeff Chen
+                    </a>
+                </div>
+            ''',
+            styles = {'margin':'0rem', 'color': styles.CLRS['sidebar_txt'],
+                      'width': '19.7%', 'height': '100%',
+                      'font-family': styles.FONTFAMILY,
+                      'background-color': styles.CLRS['sidebar'],
+                      'overflow-y':'scroll',
+                      'border': 'solid 0.15rem black'}
+        )
+
+        self.classifier_content = pn.FlexBox(
+            self.img_row,
+            self.prob_row,
+            gap = '0.5%',
+            flex_direction = 'column',
+            flex_wrap = 'nowrap',
+            sizing_mode = 'stretch_height',
+            styles = {'width': '80%'}
+        )
+
+        self.page_content = pn.FlexBox(
+            self.page_info,
+            self.classifier_content,
+            gap = '0.3%',
+            flex_direction = 'row',
+            justify_content = 'space-around',
+            align_items = 'center',
+            flex_wrap = 'nowrap',
+            styles = {
+                'height':'100%',
+                'width':'100vw',
+                'padding': '1%',
+                'min-width': '1200px',
+                'min-height': '600px',
+                'max-width': '3600px',
+                'max-height': '1800px',
+                'background-color': styles.CLRS['page_bg']
+            },
+        )
+
+        # This is mainly used to ensure there is always a grey background
+        self.page_layout = pn.FlexBox(
+            self.page_content,
+            justify_content = 'center',
+            flex_wrap = 'nowrap',
+            sizing_mode = 'stretch_both',
+            styles = {
+                'min-width': 'max-content',
+                'background-color': styles.CLRS['page_bg'],
+            }
+        )
+        # Set up the on-click event with the clear button and the canvas
+        self.clear_btn.on_click(self.canvas.toggle_clear)
+
+    def __panel__(self):
+        '''
+        Returns the main layout of the application to be rendered by Panel.
+        '''
+        return self.page_layout
+
+
+def create_app():
+    '''
+    Creates the application, ensuring that each user gets a separate instance of DigitClassifier.
+    Mostly used to keep things out of the global scope.
+    '''
+    # Used to serve with panel serve in the command line
+    save_dir = FILE_PATH + '/saved_models'
+    base_name = 'tiny_vgg_less_compute'
+
+    mod_path = f'{save_dir}/{base_name}_model.pth'          # Path to the saved model state dict
+    settings_path = f'{save_dir}/{base_name}_settings.yaml' # Path to the saved model kwargs
+
+    # Load in model kwargs
+    with open(settings_path, 'r') as f:
+        loaded_settings = yaml.load(f, Loader = yaml.FullLoader)
+
+    mod_kwargs = loaded_settings['mod_kwargs']
+
+    digit_classifier = DigitClassifier(mod_path = mod_path, mod_kwargs = mod_kwargs)
+    return digit_classifier
+
+################################################
+# Serve App
+################################################
+# Used to serve with panel serve in the command line
+create_app().servable(title = 'CNN Digit Classifier')
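Because `panel serve` re-executes the script for each session, `create_app().servable()` hands every visitor their own `DigitClassifier` (and therefore their own canvas state). A minimal sketch of launching the same factory programmatically instead of via the CLI, assuming it is run from the repo root (the port is illustrative):

    # sketch: programmatic equivalent of `panel serve app.py`
    import panel as pn
    from app import create_app

    pn.extension('plotly')
    pn.serve(create_app, port = 7860)  # Panel calls create_app once per session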
app_components/__pycache__/canvas.cpython-313.pyc ADDED
Binary file (2.96 kB)
 
app_components/__pycache__/plots.cpython-313.pyc ADDED
Binary file (11.2 kB)
 
app_components/canvas.py ADDED
@@ -0,0 +1,84 @@
+#####################################
+# Packages & Dependencies
+#####################################
+import param
+from panel.reactive import ReactiveHTML
+
+
+#####################################
+# Canvas
+#####################################
+class Canvas(ReactiveHTML):
+    '''
+    The HTML canvas panel used for drawing digits (0-9) in the application.
+    Reference: https://panel.holoviz.org/how_to/custom_components/examples/canvas_draw.html
+    '''
+    uri = param.String()
+    clear = param.Boolean(default = False)
+
+    _template = '''
+        <canvas
+            id="canvas"
+            style="width: 100%; height: 100%"
+            height=400px
+            width=400px
+            onmousedown="${script('start')}"
+            onmousemove="${script('draw')}"
+            onmouseup="${script('end')}"
+            onmouseleave="${script('end')}">
+        </canvas>
+    '''
+
+    _scripts = {
+        'render': '''
+            state.ctx = canvas.getContext('2d');
+            state.ctx.fillStyle = '#FFFFFF';
+            state.ctx.fillRect(0, 0, canvas.width, canvas.height);
+            state.ctx.lineWidth = 30;
+            state.ctx.strokeStyle = '#000000';
+            state.ctx.lineJoin = 'round';
+            state.ctx.lineCap = 'round';
+
+            // Helper to normalize mouse coordinates
+            state.getCoords = function(e) {
+                const rect = canvas.getBoundingClientRect();
+                return {
+                    x: (e.clientX - rect.left) * (canvas.width / rect.width),
+                    y: (e.clientY - rect.top) * (canvas.height / rect.height)
+                };
+            };
+        ''',
+
+        'start': '''
+            if (state.isDrawing) return;
+            state.isDrawing = true;
+            const pos = state.getCoords(event);
+            state.ctx.beginPath();
+            state.ctx.moveTo(pos.x, pos.y);
+        ''',
+
+        'draw': '''
+            if (!state.isDrawing) return;
+            const pos = state.getCoords(event);
+            state.ctx.lineTo(pos.x, pos.y);
+            state.ctx.stroke();
+            data.uri = canvas.toDataURL('image/png');
+        ''',
+
+        'end': '''
+            if (!state.isDrawing) return; // Early return if already not drawing
+            state.isDrawing = false;
+        ''',
+
+        'clear': '''
+            state.ctx.fillStyle = '#FFFFFF';
+            state.ctx.fillRect(0, 0, canvas.width, canvas.height);
+            data.uri = '';
+        '''
+    }
+
+    def toggle_clear(self, *event):
+        '''
+        Toggles the value of self.clear to trigger the JS 'clear' script.
+        '''
+        self.clear = not self.clear
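Everything Python-side hangs off the `uri` parameter: each `draw` stroke pushes a fresh `data:image/png;base64,...` string, and the JS `clear` script resets it to `''`. A small sketch of subscribing to it outside the app (the callback name here is hypothetical):

    from app_components.canvas import Canvas

    canvas = Canvas()

    def on_uri_change(event):
        # event.new is the PNG data URI while drawing, '' right after a clear
        print(f'uri now has {len(event.new)} characters')

    canvas.param.watch(on_uri_change, 'uri')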
app_components/plots.py ADDED
@@ -0,0 +1,225 @@
+#####################################
+# Packages & Dependencies
+#####################################
+import param
+import panel as pn
+
+import torch
+import numpy as np
+import plotly.graph_objects as go
+
+from . import canvas
+from app_utils import styles
+
+import sys, os
+APP_PATH = os.path.dirname(os.path.dirname(__file__)) # Path to the digit-classifier-app directory
+sys.path.append(APP_PATH + '/model_training')
+
+# Imports from model_training
+import data_setup, model
+
+
+#####################################
+# Plotly Panels
+#####################################
+PLOTLY_CONFIGS = {
+    'displayModeBar': True, 'displaylogo': False,
+    'modeBarButtonsToRemove': ['autoScale', 'lasso', 'select',
+                               'toImage', 'pan', 'zoom', 'zoomIn', 'zoomOut']
+}
+
+class PlotPanels(param.Parameterized):
+    '''
+    Contains all Plotly pane objects for the application.
+    This includes the probability bar chart and the MNIST preprocessed image heat map.
+
+    Args:
+        canvas_info (param.ClassSelector): A Canvas class object to get the data URI of the drawn image.
+        mod_path (str): The absolute path to the saved TinyVGG model.
+        mod_kwargs (dict): A dictionary containing the keyword arguments for the TinyVGG model.
+            This should have the keys: num_blks, num_convs, in_channels, hidden_channels, and num_classes.
+    '''
+
+    canvas_info = param.ClassSelector(class_ = canvas.Canvas) # Canvas object to get the data URI
+
+    def __init__(self, mod_path: str, mod_kwargs: dict, **params):
+        super().__init__(**params)
+        self.class_labels = np.arange(0, 10)
+        self.cnn_mod = model.TinyVGG(**mod_kwargs)
+        self.cnn_mod.load_state_dict(torch.load(mod_path, map_location = 'cpu'))
+
+        self.img_pane = pn.pane.Plotly(
+            name = 'image_plot',
+            config = PLOTLY_CONFIGS,
+            sizing_mode = 'stretch_both',
+            margin = 0,
+        )
+
+        self.prob_pane = pn.pane.Plotly(
+            name = 'prob_plot',
+            config = PLOTLY_CONFIGS,
+            sizing_mode = 'stretch_both',
+            margin = 0
+        )
+
+        self.pred_txt = pn.pane.HTML(
+            styles = {'margin':'0rem', 'color':styles.CLRS['pred_txt'],
+                      'font-size':styles.FONTSIZES['pred_txt'],
+                      'font-family':styles.FONTFAMILY}
+        )
+
+        # Initialize plotly figures
+        self._update_prediction()
+
+        # Set up watchers that update based on data URI changes
+        self.canvas_info.param.watch(self._update_prediction, 'uri')
+
+    def _update_prediction(self, *event):
+        '''
+        Performs all prediction-related updates for the application.
+        This function is connected to the URI parameter of canvas_info through a watcher.
+        Any time the URI changes, a class prediction is made immediately.
+        Following this, the probability bar chart and model input heatmap are updated as well.
+        '''
+        self._update_preprocessed_tensor()
+        self._update_pred_txt()
+        self._update_img_plot()
+        self._update_prob_plot()
+
+    def _update_preprocessed_tensor(self):
+        '''
+        Transforms the data URI (string) from canvas_info into a preprocessed tensor.
+        This is done by having it undergo the MNIST preprocessing pipeline (see mnist_preprocess in data_setup for details).
+        Additionally, a prediction is made for the preprocessed tensor to get its class label.
+        The corresponding set of prediction probabilities is stored.
+        '''
+        # Check if uri is non-empty
+        if self.canvas_info.uri:
+            self.input_img = data_setup.mnist_preprocess(self.canvas_info.uri)
+
+            self.cnn_mod.eval() # Set CNN to eval & inference mode
+            with torch.inference_mode():
+                pred_logits = self.cnn_mod(self.input_img.unsqueeze(0))
+                self.pred_probs = torch.softmax(pred_logits, dim = 1)[0].numpy()
+                self.pred_label = np.argmax(self.pred_probs)
+        else:
+            self.input_img = torch.zeros((28, 28))
+            self.pred_probs = np.zeros(10)
+            self.pred_label = None
+
+    def _update_pred_txt(self):
+        '''
+        Updates the prediction and probability HTML text to reflect the current data URI.
+        '''
+        if self.canvas_info.uri:
+            pred, prob = self.pred_label, f'{self.pred_probs[self.pred_label]:.3f}'
+        else:
+            pred, prob = 'N/A', 'N/A'
+
+        self.pred_txt.object = f'''
+            <div style="text-align: left;">
+                <b>Prediction:</b> {pred}
+                <br>
+                <b>Probability:</b> {prob}
+            </div>
+        '''
+
+    def _update_prob_plot(self):
+        '''
+        Updates the probability bar chart to showcase the softmax output probability distribution
+        obtained from the prediction in _update_preprocessed_tensor.
+        '''
+        # Marker fill and outline colors for the bar plot
+        mkr_clrs = [styles.CLRS['base_bar']] * len(self.class_labels)
+        mkr_line_clrs = [styles.CLRS['base_bar_line']] * len(self.class_labels)
+        if self.pred_label is not None:
+            mkr_clrs[self.pred_label] = styles.CLRS['pred_bar']
+            mkr_line_clrs[self.pred_label] = styles.CLRS['pred_bar_line']
+
+        fig = go.Figure()
+        # Bar plot
+        fig.add_trace(
+            go.Bar(x = self.class_labels, y = self.pred_probs,
+                   marker_color = mkr_clrs, marker_line_color = mkr_line_clrs,
+                   marker_line_width = 1.5, showlegend = False,
+                   text = self.pred_probs, textposition = 'outside',
+                   textfont = dict(color = styles.CLRS['plot_txt'],
+                                   size = styles.FONTSIZES['plot_bar_txt'], family = styles.FONTFAMILY),
+                   texttemplate = '%{text:.3f}',
+                   customdata = self.pred_probs * 100,
+                   hoverlabel_font = dict(family = styles.FONTFAMILY),
+                   hovertemplate = '<b>Class Label:</b> %{x}' +
+                                   '<br><b>Probability:</b> %{customdata:.2f} %' +
+                                   '<extra></extra>'
+            )
+        )
+        # Used to fix axis limits
+        fig.add_trace(
+            go.Scatter(
+                x = [0.5, 0.5], y = [0.1, 1],
+                marker = dict(color = 'rgba(0, 0, 0, 0)', size = 10),
+                mode = 'markers',
+                hoverinfo = 'skip',
+                showlegend = False
+            )
+        )
+        fig.update_yaxes(
+            title = dict(text = 'Prediction Probability', standoff = 0,
+                         font = dict(color = styles.CLRS['plot_txt'],
+                                     size = styles.FONTSIZES['plot_labels'],
+                                     family = styles.FONTFAMILY)),
+            tickfont = dict(size = styles.FONTSIZES['plot_ticks'],
+                            family = styles.FONTFAMILY),
+            dtick = 0.1, ticks = 'outside', ticklen = 0,
+            gridcolor = styles.CLRS['prob_plot_grid']
+        )
+        fig.update_xaxes(
+            title = dict(text = 'Class Label', standoff = 6,
+                         font = dict(color = styles.CLRS['plot_txt'],
+                                     size = styles.FONTSIZES['plot_labels'],
+                                     family = styles.FONTFAMILY)),
+            dtick = 1, tickfont = dict(size = styles.FONTSIZES['plot_ticks'],
+                                       family = styles.FONTFAMILY),
+        )
+        fig.update_layout(
+            paper_bgcolor = styles.CLRS['prob_plot_bg'],
+            plot_bgcolor = styles.CLRS['prob_plot_bg'],
+            margin = dict(l = 60, r = 0, t = 5, b = 45),
+        )
+
+        self.prob_pane.object = fig
+
+    def _update_img_plot(self):
+        '''
+        Updates the heat map to showcase the current model input, i.e. the preprocessed canvas drawing.
+        '''
+        img_np = self.input_img.squeeze().numpy()
+
+        if self.pred_label is not None:
+            zmin, zmax = np.min(img_np), np.max(img_np)
+        else:
+            zmin, zmax = 0, 1
+
+        fig = go.Figure(
+            data = go.Heatmap(
+                z = img_np,
+                colorscale = 'gray',
+                showscale = False,
+                zmin = zmin,
+                zmax = zmax,
+                hoverlabel_font = dict(family = styles.FONTFAMILY),
+                hovertemplate = '<b>Pixel Position:</b> (%{x}, %{y})' +
+                                '<br><b>Pixel Value:</b> %{z:.3f}' +
+                                '<extra></extra>'
+            )
+        )
+
+        fig.update_yaxes(autorange = 'reversed')
+        fig.update_layout(
+            plot_bgcolor = styles.CLRS['img_plot_bg'],
+            margin = dict(l = 0, r = 0, t = 0, b = 0),
+            xaxis = dict(showticklabels = False),
+            yaxis = dict(showticklabels = False),
+        )
+
+        self.img_pane.object = fig
app_utils/__pycache__/styles.cpython-313.pyc ADDED
Binary file (1.31 kB)
 
app_utils/styles.py ADDED
@@ -0,0 +1,53 @@
+#####################################
+# Fonts & Colors
+#####################################
+FONTFAMILY = 'Helvetica'
+
+FONTSIZES = {
+    'pred_txt': '1.2rem',
+    'mod_input_txt': '0.8rem',
+    'plot_ticks': 14,
+    'plot_labels': 16,
+    'plot_bar_txt': 14,
+    'btn': '1rem',
+    'sidebar_txt': '0.95rem',
+    'sidebar_title': '1.8rem',
+    'made_by_txt': '0.75rem'
+}
+
+CLRS = {
+    'pred_txt': 'white',
+    'sidebar': 'white',
+    'sidebar_txt': 'black',
+    'base_bar': 'rgb(158, 202, 225)',
+    'base_bar_line': 'rgb(8, 48, 107)',
+    'pred_bar': 'rgb(240, 140, 140)',
+    'pred_bar_line': 'rgb(180, 0, 0)',
+    'plot_txt': 'black',
+    'prob_plot_bg': 'white',
+    'prob_plot_grid': 'rgb(225, 225, 225)',
+    'img_plot_bg': 'black',
+    'btn_base': 'white',
+    'btn_hover': 'rgb(200, 200, 200)',
+    'page_bg': 'rgb(150, 150, 150)',
+    'made_by_txt': 'rgb(180, 180, 180)'
+}
+
+
+#####################################
+# Stylesheets
+#####################################
+BTN_STYLESHEET = f'''
+    :host(.solid) .bk-btn {{
+        background-color: {CLRS['btn_base']};
+        border: black solid 0.1rem;
+        border-radius: 0.8rem;
+        font-size: {FONTSIZES['btn']};
+        padding-top: 0.3rem;
+        padding-bottom: 0.3rem;
+    }}
+
+    :host(.solid) .bk-btn:hover {{
+        background-color: {CLRS['btn_hover']};
+    }}
+'''
assets/github-mark-white.png ADDED
model_training/__pycache__/data_setup.cpython-313.pyc ADDED
Binary file (6.32 kB)
 
model_training/__pycache__/model.cpython-313.pyc ADDED
Binary file (6.18 kB)
 
model_training/__pycache__/utils.cpython-313.pyc ADDED
Binary file (2.55 kB)
 
model_training/args.txt ADDED
@@ -0,0 +1,12 @@
+--num-workers
+0
+--num-epochs
+25
+--batch-size
+100
+--learning-rate
+0.001
+--patience
+10
+--min-delta
+0.001
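Since run_training.py builds its parser with `fromfile_prefix_chars = '@'`, this file can supply every hyperparameter at once when run from the model_training directory:

    python run_training.py @args.txt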
model_training/data_setup.py ADDED
@@ -0,0 +1,151 @@
+#####################################
+# Packages & Dependencies
+#####################################
+from torchvision import transforms, datasets
+from torch.utils.data import DataLoader
+
+import utils
+from typing import Tuple
+
+import io
+import base64
+from PIL import Image
+import numpy as np
+import matplotlib.pyplot as plt
+
+# Transformations applied to each image
+BASE_TRANSFORMS = transforms.Compose([
+    transforms.ToTensor(), # Convert to tensor and rescale pixel values to within [0, 1]
+    transforms.Normalize(mean = [0.1307], std = [0.3081]) # Normalize with MNIST stats
+])
+
+TRAIN_TRANSFORMS = transforms.Compose([
+    transforms.RandomAffine(degrees = 15,             # Rotate up to -/+ 15 degrees
+                            scale = (0.8, 1.2),       # Scale between 80 and 120 percent
+                            translate = (0.08, 0.08), # Translate up to -/+ 8 percent in both x and y
+                            shear = 10),              # Shear up to -/+ 10 degrees
+    transforms.ToTensor(), # Convert to tensor and rescale pixel values to within [0, 1]
+    transforms.Normalize(mean = [0.1307], std = [0.3081]), # Normalize with MNIST stats
+])
+
+
+#####################################
+# Functions
+#####################################
+def get_dataloaders(root: str,
+                    batch_size: int,
+                    num_workers: int = 0) -> Tuple[DataLoader, DataLoader]:
+    '''
+    Creates training and testing dataloaders for the MNIST dataset.
+
+    Args:
+        root (str): Path to download MNIST data.
+        batch_size (int): Size used to split training and testing datasets into batches.
+        num_workers (int): Number of workers to use for multiprocessing. Default is 0.
+    '''
+
+    # Get training and testing MNIST data
+    mnist_train = datasets.MNIST(root, download = True, train = True,
+                                 transform = TRAIN_TRANSFORMS)
+    mnist_test = datasets.MNIST(root, download = True, train = False,
+                                transform = BASE_TRANSFORMS)
+
+    # Create dataloaders
+    if num_workers > 0:
+        mp_context = utils.MP_CONTEXT
+    else:
+        mp_context = None
+
+    train_dl = DataLoader(
+        dataset = mnist_train,
+        batch_size = batch_size,
+        shuffle = True,
+        num_workers = num_workers,
+        multiprocessing_context = mp_context,
+        pin_memory = True
+    )
+
+    test_dl = DataLoader(
+        dataset = mnist_test,
+        batch_size = batch_size,
+        shuffle = False,
+        num_workers = num_workers,
+        multiprocessing_context = mp_context,
+        pin_memory = True
+    )
+
+    return train_dl, test_dl
+
+def mnist_preprocess(uri: str, plot: bool = False):
+    '''
+    Preprocesses a data URI representing a handwritten digit image according to the pipeline used in the MNIST dataset.
+    The pipeline includes:
+        1. Converting the image to grayscale.
+        2. Resizing the image to 20x20, preserving the aspect ratio, and using anti-aliasing.
+        3. Centering the resized image in a 28x28 image based on the center of mass (COM).
+        4. Converting the image to a tensor (pixel values between 0 and 1) and normalizing it using MNIST statistics.
+
+    Reference: https://paperswithcode.com/dataset/mnist
+
+    Args:
+        uri (str): A string representing the full data URI.
+        plot (bool, optional): If True, the resized 20x20 image is plotted alongside the final 28x28 image (pre-normalization).
+            The red lines on these plots intersect at the COM position. Default is False.
+    Returns:
+        Tensor: A tensor of shape (1, 28, 28) representing the preprocessed image, normalized using MNIST statistics.
+    '''
+    encoded_img = uri.split(',', 1)[1]
+    image_bytes = io.BytesIO(base64.b64decode(encoded_img))
+    pil_img = Image.open(image_bytes).convert('L') # Grayscale
+
+    # Resize to 20x20, preserving the aspect ratio and using anti-aliasing
+    pil_img.thumbnail((20, 20), Image.Resampling.LANCZOS)
+
+    # Convert to numpy and invert the image
+    img = 255 - np.array(pil_img)
+
+    # Get image indices for the y-axis (rows) and x-axis (columns)
+    img_idxs = np.indices(img.shape)
+    tot_mass = img.sum()
+
+    # These represent the indices of the center of mass (COM)
+    com_x = np.round((img_idxs[1] * img).sum() / tot_mass).astype(int)
+    com_y = np.round((img_idxs[0] * img).sum() / tot_mass).astype(int)
+
+    dist_com_end_x = img.shape[1] - com_x # Number of column indices from com_x to the last index
+    dist_com_end_y = img.shape[0] - com_y # Number of row indices from com_y to the last index
+
+    new_img = np.zeros((28, 28), dtype = np.uint8)
+    new_com_x, new_com_y = 14, 14 # Indices of the COM for the new 28x28 image
+
+    valid_start_x = min(new_com_x, com_x)
+    valid_end_x = min(14, dist_com_end_x) # 14 is the index distance from the new COM to the 28th index
+    valid_start_y = min(new_com_y, com_y)
+    valid_end_y = min(14, dist_com_end_y) # 14 is the index distance from the new COM to the 28th index
+
+    old_slice_x = slice(com_x - valid_start_x, com_x + valid_end_x)
+    old_slice_y = slice(com_y - valid_start_y, com_y + valid_end_y)
+    new_slice_x = slice(new_com_x - valid_start_x, new_com_x + valid_end_x)
+    new_slice_y = slice(new_com_y - valid_start_y, new_com_y + valid_end_y)
+
+    # Paste the cropped image into the 28x28 field such that the old COM (com_y, com_x) is at the center (14, 14)
+    new_img[new_slice_y, new_slice_x] = img[old_slice_y, old_slice_x]
+
+    if plot:
+        fig, axes = plt.subplots(nrows = 1, ncols = 2, figsize = (12, 6))
+
+        axes[0].imshow(img, cmap = 'grey')
+        axes[0].axhline(com_y, c = 'red')
+        axes[0].axvline(com_x, c = 'red')
+
+        axes[1].imshow(new_img, cmap = 'grey')
+        axes[1].axhline(new_com_y, c = 'red')
+        axes[1].axvline(new_com_x, c = 'red')
+
+        axes[0].set_title(f'Original Resized {img.shape[0]}x{img.shape[1]} Image')
+        axes[1].set_title('New Centered 28x28 Image')
+
+        plt.tight_layout()
+
+    # Return the transformed tensor of the new image. This includes normalizing with MNIST stats
+    return BASE_TRANSFORMS(new_img)
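`mnist_preprocess` expects the same kind of data URI the Canvas emits: a black stroke on a white field. A self-contained round-trip sketch (the drawn stroke here is arbitrary):

    import base64, io
    from PIL import Image, ImageDraw
    import data_setup

    # Fake a canvas drawing: black stroke on a white 400x400 field
    img = Image.new('RGB', (400, 400), 'white')
    ImageDraw.Draw(img).line([(150, 100), (160, 300)], fill = 'black', width = 30)

    buf = io.BytesIO()
    img.save(buf, format = 'PNG')
    uri = 'data:image/png;base64,' + base64.b64encode(buf.getvalue()).decode()

    tensor = data_setup.mnist_preprocess(uri)
    print(tensor.shape)  # torch.Size([1, 28, 28]), normalized with MNIST stats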
model_training/engine.py ADDED
@@ -0,0 +1,195 @@
+#####################################
+# Packages
+#####################################
+import torch
+
+from typing import Tuple, Dict, List
+import utils
+
+#####################################
+# Functions
+#####################################
+def train_step(model: torch.nn.Module,
+               dataloader: torch.utils.data.DataLoader,
+               loss_fn: torch.nn.Module,
+               optimizer: torch.optim.Optimizer,
+               device: torch.device) -> Tuple[float, float]:
+
+    '''
+    Performs a training step for a PyTorch model.
+
+    Args:
+        model (torch.nn.Module): PyTorch model that will be trained.
+        dataloader (torch.utils.data.DataLoader): Dataloader containing data to train on.
+        loss_fn (torch.nn.Module): Loss function used as the error metric.
+        optimizer (torch.optim.Optimizer): Optimization method used to update model parameters per batch.
+        device (torch.device): Device to train on.
+
+    Returns:
+        train_loss (float): The average loss calculated over the training set.
+        train_acc (float): The accuracy calculated over the training set.
+    '''
+
+    model.train()
+    train_loss = torch.tensor(0.0, device = device)
+    train_acc = torch.tensor(0.0, device = device)
+    num_samps = len(dataloader.dataset)
+
+    # Loop through all batches in the dataloader
+    for X, y in dataloader:
+
+        optimizer.zero_grad() # Clear old accumulated gradients
+
+        X, y = X.to(device), y.to(device)
+
+        y_logits = model(X) # Get logits
+
+        loss = loss_fn(y_logits, y)
+        train_loss += loss.detach() * X.shape[0] # Accumulate total loss for the batch
+
+        loss.backward() # Perform backpropagation
+        optimizer.step() # Update parameters
+
+        y_pred = y_logits.argmax(dim = 1) # No softmax needed before argmax (it preserves order)
+
+        train_acc += (y_pred == y).sum() # Accumulate correct predictions for the batch
+
+    # Get average loss and accuracy per sample
+    train_loss = train_loss.item() / num_samps
+    train_acc = train_acc.item() / num_samps
+
+    return train_loss, train_acc
+
+
+def test_step(model: torch.nn.Module,
+              dataloader: torch.utils.data.DataLoader,
+              loss_fn: torch.nn.Module,
+              device: torch.device) -> Tuple[float, float]:
+
+    '''
+    Performs a testing step for a PyTorch model.
+
+    Args:
+        model (torch.nn.Module): PyTorch model that will be tested.
+        dataloader (torch.utils.data.DataLoader): Dataloader containing data to test on.
+        loss_fn (torch.nn.Module): Loss function used as the error metric.
+        device (torch.device): Device to compute on.
+
+    Returns:
+        test_loss (float): The average loss calculated over batches.
+        test_acc (float): The average accuracy calculated over batches.
+    '''
+
+    model.eval()
+    test_loss = torch.tensor(0.0, device = device)
+    test_acc = torch.tensor(0.0, device = device)
+    num_samps = len(dataloader.dataset)
+
+    with torch.inference_mode():
+        # Loop through all batches in the dataloader
+        for X, y in dataloader:
+            X, y = X.to(device), y.to(device)
+
+            y_logits = model(X) # Get logits
+
+            test_loss += loss_fn(y_logits, y) * X.shape[0] # Accumulate total loss for the batch
+
+            y_pred = y_logits.argmax(dim = 1) # No softmax needed before argmax (it preserves order)
+
+            test_acc += (y_pred == y).sum() # Accumulate correct predictions for the batch
+
+    # Get average loss and accuracy
+    test_loss = test_loss.item() / num_samps
+    test_acc = test_acc.item() / num_samps
+
+    return test_loss, test_acc
+
+
+def train(model: torch.nn.Module,
+          train_dl: torch.utils.data.DataLoader,
+          test_dl: torch.utils.data.DataLoader,
+          loss_fn: torch.nn.Module,
+          optimizer: torch.optim.Optimizer,
+          num_epochs: int,
+          patience: int,
+          min_delta: float,
+          device: torch.device,
+          save_mod: bool = True,
+          save_dir: str = '',
+          mod_name: str = '') -> Dict[str, List[float]]:
+    '''
+    Performs the training and testing steps for a PyTorch model,
+    with early stopping applied based on the test loss.
+
+    Args:
+        model (torch.nn.Module): PyTorch model to train.
+        train_dl (torch.utils.data.DataLoader): DataLoader for training.
+        test_dl (torch.utils.data.DataLoader): DataLoader for testing.
+        loss_fn (torch.nn.Module): Loss function used as the error metric.
+        optimizer (torch.optim.Optimizer): Optimizer used to update model parameters per batch.
+
+        num_epochs (int): Max number of epochs to train.
+        patience (int): Number of epochs to wait before early stopping.
+        min_delta (float): Minimum decrease in loss to reset the counter.
+
+        device (torch.device): Device to train on.
+        save_mod (bool, optional): If True, saves the model after each epoch in which the test loss adequately improves. Default is True.
+        save_dir (str, optional): Directory to save the model to. Must be nonempty if save_mod is True.
+        mod_name (str, optional): Filename for the saved model. Must be nonempty if save_mod is True.
+
+    Returns:
+        res (dict): A results dictionary containing lists of train and test metrics for each epoch.
+    '''
+
+    bold_start, bold_end = '\033[1m', '\033[0m'
+
+    if save_mod:
+        assert save_dir, 'save_dir cannot be None or empty.'
+        assert mod_name, 'mod_name cannot be None or empty.'
+
+    # Initialize results dictionary
+    res = {'train_loss': [],
+           'train_acc': [],
+           'test_loss': [],
+           'test_acc': []
+    }
+
+    # Initialize best_loss and counter for early stopping
+    best_loss, counter = None, 0
+
+    for epoch in range(num_epochs):
+        # Perform training and testing steps
+        train_loss, train_acc = train_step(model, train_dl, loss_fn, optimizer, device)
+        test_loss, test_acc = test_step(model, test_dl, loss_fn, device)
+
+        # Store loss and accuracy values
+        res['train_loss'].append(train_loss)
+        res['train_acc'].append(train_acc)
+        res['test_loss'].append(test_loss)
+        res['test_acc'].append(test_acc)
+
+        print(f'Epoch: {epoch + 1} | ' +
+              f'train_loss = {train_loss:.4f} | train_acc = {train_acc:.4f} | ' +
+              f'test_loss = {test_loss:.4f} | test_acc = {test_acc:.4f}')
+
+        # Check for improvement
+        if best_loss is None:
+            best_loss = test_loss
+            if save_mod:
+                utils.save_model(model, save_dir, mod_name)
+
+        elif test_loss < best_loss - min_delta:
+            best_loss = test_loss
+            counter = 0
+
+            if save_mod:
+                utils.save_model(model, save_dir, mod_name)
+                print(f'{bold_start}[SAVED]{bold_end} Adequate improvement in test loss; model saved.')
+
+        else:
+            counter += 1
+            if counter > patience:
+                print(f'{bold_start}[ALERT]{bold_end} No improvement in test loss after {counter} epochs; early stopping triggered.')
+                break
+
+    return res
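Note that the stopping check is `counter > patience`, so with `patience = 5` the loop tolerates six consecutive epochs without a `min_delta` improvement before halting, and the checkpoint on disk always corresponds to the best test loss observed so far.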
model_training/model.py ADDED
@@ -0,0 +1,122 @@
+#####################################
+# Packages & Dependencies
+#####################################
+import torch
+from torch import nn
+
+
+#####################################
+# VGG Model Class
+#####################################
+class VGGBlock(nn.Module):
+    '''
+    Defines a modified block in the VGG architecture,
+    which includes batch normalization between convolutional layers and ReLU activations.
+
+    Reference: https://poloclub.github.io/cnn-explainer/
+    Reference: https://d2l.ai/chapter_convolutional-modern/vgg.html
+
+    Args:
+        num_convs (int): Number of consecutive convolutional layers + ReLU activations.
+        in_channels (int): Number of channels in the input.
+        hidden_channels (int): Number of hidden channels between convolutional layers.
+        out_channels (int): Number of channels in the output.
+    '''
+    def __init__(self,
+                 num_convs: int,
+                 in_channels: int,
+                 hidden_channels: int,
+                 out_channels: int):
+        super().__init__()
+
+        self.layers = []
+
+        for i in range(num_convs):
+            conv_in = in_channels if i == 0 else hidden_channels
+            conv_out = out_channels if i == num_convs - 1 else hidden_channels
+
+            self.layers += [
+                nn.Conv2d(conv_in, conv_out, kernel_size = 3, stride = 1, padding = 1),
+                nn.BatchNorm2d(conv_out),
+                nn.ReLU()
+            ]
+
+        self.layers.append(nn.MaxPool2d(kernel_size = 2, stride = 2))
+
+        self.vgg_blk = nn.Sequential(*self.layers)
+
+    def forward(self, X: torch.Tensor) -> torch.Tensor:
+        '''
+        Forward pass of the VGG block.
+
+        Args:
+            X (torch.Tensor): Input tensor of shape (batch_size, in_channels, height, width).
+        Returns:
+            torch.Tensor: Output tensor of shape (batch_size, out_channels, new_height, new_width).
+        '''
+
+        return self.vgg_blk(X)
+
+class TinyVGG(nn.Module):
+    '''
+    Creates a simplified version of a VGG model, adapted from
+    https://github.com/poloclub/cnn-explainer/blob/master/tiny-vgg/tiny-vgg.py.
+    The main differences are that the hidden dimensions and number of convolutional layers
+    remain the same across VGG blocks, and the classifier's linear layers output fewer features.
+
+    Args:
+        num_blks (int): Number of VGG blocks to put in the model.
+        num_convs (int): Number of consecutive convolutional layers + ReLU activations in each VGG block.
+        in_channels (int): Number of channels in the input.
+        hidden_channels (int): Number of hidden channels between convolutional layers.
+        num_classes (int): Number of class labels.
+
+    '''
+    def __init__(self,
+                 num_blks: int,
+                 num_convs: int,
+                 in_channels: int,
+                 hidden_channels: int,
+                 num_classes: int):
+        super().__init__()
+
+        self.all_blks = []
+        for i in range(num_blks):
+            conv_in = in_channels if i == 0 else hidden_channels
+            self.all_blks.append(
+                VGGBlock(num_convs, conv_in, hidden_channels, hidden_channels)
+            )
+
+        self.vgg_body = nn.Sequential(*self.all_blks)
+        self.classifier = nn.Sequential(
+            nn.Flatten(),
+            nn.LazyLinear(4096), nn.ReLU(), nn.Dropout(0.5),
+            nn.LazyLinear(2048), nn.ReLU(), nn.Dropout(0.5),
+            nn.LazyLinear(num_classes)
+        )
+
+        self.vgg_body.apply(self._custom_init)
+        self.classifier.apply(self._custom_init)
+
+    def _custom_init(self, module):
+        '''
+        Initializes convolutional layer weights with the Xavier initialization method.
+        Initializes convolutional layer biases to zero.
+        '''
+        if isinstance(module, nn.Conv2d):
+            nn.init.xavier_uniform_(module.weight)
+            nn.init.zeros_(module.bias)
+
+    def forward(self, X: torch.Tensor) -> torch.Tensor:
+        '''
+        Forward pass of the TinyVGG model.
+
+        Args:
+            X (torch.Tensor): Input tensor of shape (batch_size, in_channels, height, width).
+
+        Returns:
+            torch.Tensor: Logits of shape (batch_size, num_classes).
+        '''
+
+        X = self.vgg_body(X)
+        return self.classifier(X)
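Because the classifier uses `nn.LazyLinear`, the linear layers' input sizes are only materialized on the first forward pass (or when a saved state dict is loaded, as PlotPanels does). A quick shape-check sketch using the kwargs from saved_models/tiny_vgg_settings.yaml:

    import torch
    import model

    net = model.TinyVGG(num_blks = 2, num_convs = 2, in_channels = 1,
                        hidden_channels = 10, num_classes = 10)
    logits = net(torch.zeros(1, 1, 28, 28))  # first pass materializes the lazy layers
    print(logits.shape)  # torch.Size([1, 10])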
model_training/run_training.py ADDED
@@ -0,0 +1,90 @@
+#####################################
+# Packages & Dependencies
+#####################################
+import argparse
+import torch
+from torch import nn
+
+import utils, data_setup, model, engine
+import yaml
+
+# Set up random seeds
+utils.set_seed(0)
+
+# Set up hyperparameters
+parser = argparse.ArgumentParser(fromfile_prefix_chars = '@')
+
+parser.add_argument('-nw', '--num-workers', help = 'Number of workers for dataloaders.',
+                    type = int, default = 0)
+parser.add_argument('-ne', '--num-epochs', help = 'Number of epochs to train the model for.',
+                    type = int, default = 15)
+parser.add_argument('-bs', '--batch-size', help = 'Size of batches to split the training set.',
+                    type = int, default = 100)
+parser.add_argument('-lr', '--learning-rate', help = 'Learning rate for the optimizer.',
+                    type = float, default = 0.001)
+parser.add_argument('-p', '--patience', help = 'Number of epochs to wait before early stopping.',
+                    type = int, default = 5)
+parser.add_argument('-md', '--min-delta', help = 'Minimum decrease in loss to reset patience.',
+                    type = float, default = 0.001)
+
+args = parser.parse_args()
+
+
+#####################################
+# Training Code
+#####################################
+if __name__ == '__main__':
+
+    print(f'{'#' * 50}\n'
+          f'\033[1mTraining hyperparameters:\033[0m \n'
+          f' - num-workers: {args.num_workers} \n'
+          f' - num-epochs: {args.num_epochs} \n'
+          f' - batch-size: {args.batch_size} \n'
+          f' - learning-rate: {args.learning_rate} \n'
+          f' - patience: {args.patience} \n'
+          f' - min-delta: {args.min_delta} \n'
+          f'{'#' * 50}')
+
+    # Get dataloaders
+    train_dl, test_dl = data_setup.get_dataloaders(root = './mnist_data',
+                                                   batch_size = args.batch_size,
+                                                   num_workers = args.num_workers)
+
+    # Set up the saving directory and file name
+    save_dir = '../saved_models'
+
+    base_name = 'tiny_vgg'
+    mod_name = f'{base_name}_model.pth'
+
+    # Get TinyVGG model
+    mod_kwargs = {
+        'num_blks': 2,
+        'num_convs': 2,
+        'in_channels': 1,
+        'hidden_channels': 10,
+        'num_classes': len(train_dl.dataset.classes)
+    }
+
+    vgg_mod = model.TinyVGG(**mod_kwargs).to(utils.DEVICE)
+
+    # Save model kwargs and train settings
+    with open(f'{save_dir}/{base_name}_settings.yaml', 'w') as f:
+        yaml.dump({'train_kwargs': vars(args), 'mod_kwargs': mod_kwargs}, f)
+
+    # Get the loss function and optimizer
+    loss_fn = nn.CrossEntropyLoss()
+    optimizer = torch.optim.Adam(params = vgg_mod.parameters(), lr = args.learning_rate)
+
+    # Train model
+    mod_res = engine.train(model = vgg_mod,
+                           train_dl = train_dl,
+                           test_dl = test_dl,
+                           loss_fn = loss_fn,
+                           optimizer = optimizer,
+                           num_epochs = args.num_epochs,
+                           patience = args.patience,
+                           min_delta = args.min_delta,
+                           device = utils.DEVICE,
+                           save_mod = True,
+                           save_dir = save_dir,
+                           mod_name = mod_name)
model_training/utils.py ADDED
@@ -0,0 +1,64 @@
+#####################################
+# Packages & Dependencies
+#####################################
+import torch
+import random
+import numpy as np
+import os
+
+# Set up device and multiprocessing context
+if torch.cuda.is_available():
+    DEVICE = torch.device('cuda')
+    MP_CONTEXT = None
+elif torch.backends.mps.is_available():
+    DEVICE = torch.device('mps')
+    MP_CONTEXT = 'forkserver'
+else:
+    DEVICE = torch.device('cpu')
+    MP_CONTEXT = None
+
+
+#####################################
+# Functions
+#####################################
+def set_seed(seed: int = 0):
+    '''
+    Sets the random seed and deterministic settings for reproducibility across:
+        - PyTorch
+        - NumPy
+        - Python's random module
+
+    Args:
+        seed (int): The seed value to set.
+    '''
+    torch.manual_seed(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+    torch.cuda.manual_seed_all(seed)
+
+    torch.use_deterministic_algorithms(True)
+
+def save_model(model: torch.nn.Module,
+               save_dir: str,
+               mod_name: str):
+    '''
+    Saves the `state_dict()` of a model to the directory 'save_dir'.
+
+    Args:
+        model (torch.nn.Module): The PyTorch model whose state dict will be saved.
+        save_dir (str): Directory to save the model to.
+        mod_name (str): Filename for the saved model. If this doesn't end with '.pth' or '.pt', '.pth' is appended.
+
+    '''
+    # Create the directory if it doesn't exist
+    os.makedirs(save_dir, exist_ok = True)
+
+    # Add .pth if it is not in mod_name
+    if not mod_name.endswith('.pth') and not mod_name.endswith('.pt'):
+        mod_name += '.pth'
+
+    # Create the save path
+    save_path = os.path.join(save_dir, mod_name)
+
+    # Save the model's state dict
+    torch.save(obj = model.state_dict(), f = save_path)
requirements.txt CHANGED
@@ -1,6 +1,7 @@
-panel
-jupyter
-transformers
-numpy
-torch
-aiohttp
+numpy==2.2.4
+matplotlib==3.10.1
+panel==1.4.5
+param==2.1.1
+plotly==6.0.1
+torch==2.6.0
+torchvision==0.21.0
saved_models/tiny_vgg_less_compute_model.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:94a16b55d2a65b58c30bcad6dcee77d7e45e15221577795aa4de97a508fddced
+size 38494248
saved_models/tiny_vgg_less_compute_settings.yaml ADDED
@@ -0,0 +1,13 @@
+mod_kwargs:
+  hidden_channels: 6
+  in_channels: 1
+  num_blks: 2
+  num_classes: 10
+  num_convs: 2
+train_kwargs:
+  batch_size: 100
+  learning_rate: 0.001
+  min_delta: 0.0005
+  num_epochs: 50
+  num_workers: 0
+  patience: 10
saved_models/tiny_vgg_model.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:10b1913e0c2c44d5a76624196371ae83381a13a939afe9e9e9146354206997e9
+size 41711994
saved_models/tiny_vgg_settings.yaml ADDED
@@ -0,0 +1,13 @@
+mod_kwargs:
+  hidden_channels: 10
+  in_channels: 1
+  num_blks: 2
+  num_classes: 10
+  num_convs: 2
+train_kwargs:
+  batch_size: 100
+  learning_rate: 0.001
+  min_delta: 0.001
+  num_epochs: 25
+  num_workers: 0
+  patience: 5