Removed optional plotting from model_training/data_setup.py. Matplotlib no longer required.
Browse files- .gitignore +5 -0
- model_training/__pycache__/data_setup.cpython-313.pyc +0 -0
- model_training/data_setup.py +3 -21
- requirements.txt +0 -1
.gitignore
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__pycache__/
|
2 |
+
.DS_Store
|
3 |
+
.ipynb_checkpoints
|
4 |
+
*.ipynb
|
5 |
+
mnist_data/
|
model_training/__pycache__/data_setup.cpython-313.pyc
DELETED
Binary file (6.32 kB)
|
|
model_training/data_setup.py
CHANGED
@@ -13,7 +13,6 @@ import io
|
|
13 |
import base64
|
14 |
from PIL import Image
|
15 |
import numpy as np
|
16 |
-
import matplotlib.pyplot as plt
|
17 |
|
18 |
# Transformations applied to each image
|
19 |
BASE_TRANSFORMS = v2.Compose([
|
@@ -80,7 +79,7 @@ def get_dataloaders(root: str,
|
|
80 |
|
81 |
return train_dl, test_dl
|
82 |
|
83 |
-
def mnist_preprocess(uri: str
|
84 |
'''
|
85 |
Preprocesses a data URI representing a handwritten digit image according to the pipeline used in the MNIST dataset.
|
86 |
The pipeline includes:
|
@@ -93,8 +92,7 @@ def mnist_preprocess(uri: str, plot: bool = False):
|
|
93 |
|
94 |
Args:
|
95 |
uri (str): A string representing the full data URI.
|
96 |
-
|
97 |
-
The red lines on these plots intersect at the COM position. Default is False.
|
98 |
Returns:
|
99 |
Tensor: A tensor of shape (1, 28, 28) representing the preprocessed image, normalized using MNIST statistics.
|
100 |
'''
|
@@ -134,22 +132,6 @@ def mnist_preprocess(uri: str, plot: bool = False):
|
|
134 |
|
135 |
# Paste cropped image into 28x28 field such that the old COM (com_y, com_x), is at the center (14, 14)
|
136 |
new_img[new_slice_y, new_slice_x] = img[old_slice_y, old_slice_x]
|
137 |
-
|
138 |
-
if plot:
|
139 |
-
fig, axes = plt.subplots(nrows = 1, ncols = 2, figsize = (12, 6))
|
140 |
-
|
141 |
-
axes[0].imshow(img, cmap = 'grey')
|
142 |
-
axes[0].axhline(com_y, c = 'red')
|
143 |
-
axes[0].axvline(com_x, c = 'red')
|
144 |
-
|
145 |
-
axes[1].imshow(new_img, cmap = 'grey')
|
146 |
-
axes[1].axhline(new_com_y, c = 'red')
|
147 |
-
axes[1].axvline(new_com_x, c = 'red')
|
148 |
-
|
149 |
-
axes[0].set_title(f'Original Resized {img.shape[0]}x{img.shape[1]} Image')
|
150 |
-
axes[1].set_title('New Centered 28x28 Image')
|
151 |
-
|
152 |
-
plt.tight_layout()
|
153 |
|
154 |
# Return transformed tensor of new image. This includes normalizing to MNIST stats
|
155 |
-
return BASE_TRANSFORMS(new_img)
|
|
|
13 |
import base64
|
14 |
from PIL import Image
|
15 |
import numpy as np
|
|
|
16 |
|
17 |
# Transformations applied to each image
|
18 |
BASE_TRANSFORMS = v2.Compose([
|
|
|
79 |
|
80 |
return train_dl, test_dl
|
81 |
|
82 |
+
def mnist_preprocess(uri: str):
|
83 |
'''
|
84 |
Preprocesses a data URI representing a handwritten digit image according to the pipeline used in the MNIST dataset.
|
85 |
The pipeline includes:
|
|
|
92 |
|
93 |
Args:
|
94 |
uri (str): A string representing the full data URI.
|
95 |
+
|
|
|
96 |
Returns:
|
97 |
Tensor: A tensor of shape (1, 28, 28) representing the preprocessed image, normalized using MNIST statistics.
|
98 |
'''
|
|
|
132 |
|
133 |
# Paste cropped image into 28x28 field such that the old COM (com_y, com_x), is at the center (14, 14)
|
134 |
new_img[new_slice_y, new_slice_x] = img[old_slice_y, old_slice_x]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
135 |
|
136 |
# Return transformed tensor of new image. This includes normalizing to MNIST stats
|
137 |
+
return BASE_TRANSFORMS(new_img)
|
requirements.txt
CHANGED
@@ -1,5 +1,4 @@
|
|
1 |
numpy==2.2.4
|
2 |
-
matplotlib==3.10.1
|
3 |
panel==1.4.5
|
4 |
param==2.1.1
|
5 |
plotly==6.0.1
|
|
|
1 |
numpy==2.2.4
|
|
|
2 |
panel==1.4.5
|
3 |
param==2.1.1
|
4 |
plotly==6.0.1
|