Jechen00 commited on
Commit
e260219
·
1 Parent(s): 2565d4d

Removed optional plotting from model_training/data_setup.py. Matplotlib no longer required.

Browse files
.gitignore ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ __pycache__/
2
+ .DS_Store
3
+ .ipynb_checkpoints
4
+ *.ipynb
5
+ mnist_data/
model_training/__pycache__/data_setup.cpython-313.pyc DELETED
Binary file (6.32 kB)
 
model_training/data_setup.py CHANGED
@@ -13,7 +13,6 @@ import io
13
  import base64
14
  from PIL import Image
15
  import numpy as np
16
- import matplotlib.pyplot as plt
17
 
18
  # Transformations applied to each image
19
  BASE_TRANSFORMS = v2.Compose([
@@ -80,7 +79,7 @@ def get_dataloaders(root: str,
80
 
81
  return train_dl, test_dl
82
 
83
- def mnist_preprocess(uri: str, plot: bool = False):
84
  '''
85
  Preprocesses a data URI representing a handwritten digit image according to the pipeline used in the MNIST dataset.
86
  The pipeline includes:
@@ -93,8 +92,7 @@ def mnist_preprocess(uri: str, plot: bool = False):
93
 
94
  Args:
95
  uri (str): A string representing the full data URI.
96
- plot (bool, optional): If True, the resized 20x20 image is plotted alongside the final 28x28 image (pre-normalization).
97
- The red lines on these plots intersect at the COM position. Default is False.
98
  Returns:
99
  Tensor: A tensor of shape (1, 28, 28) representing the preprocessed image, normalized using MNIST statistics.
100
  '''
@@ -134,22 +132,6 @@ def mnist_preprocess(uri: str, plot: bool = False):
134
 
135
  # Paste cropped image into 28x28 field such that the old COM (com_y, com_x), is at the center (14, 14)
136
  new_img[new_slice_y, new_slice_x] = img[old_slice_y, old_slice_x]
137
-
138
- if plot:
139
- fig, axes = plt.subplots(nrows = 1, ncols = 2, figsize = (12, 6))
140
-
141
- axes[0].imshow(img, cmap = 'grey')
142
- axes[0].axhline(com_y, c = 'red')
143
- axes[0].axvline(com_x, c = 'red')
144
-
145
- axes[1].imshow(new_img, cmap = 'grey')
146
- axes[1].axhline(new_com_y, c = 'red')
147
- axes[1].axvline(new_com_x, c = 'red')
148
-
149
- axes[0].set_title(f'Original Resized {img.shape[0]}x{img.shape[1]} Image')
150
- axes[1].set_title('New Centered 28x28 Image')
151
-
152
- plt.tight_layout()
153
 
154
  # Return transformed tensor of new image. This includes normalizing to MNIST stats
155
- return BASE_TRANSFORMS(new_img)
 
13
  import base64
14
  from PIL import Image
15
  import numpy as np
 
16
 
17
  # Transformations applied to each image
18
  BASE_TRANSFORMS = v2.Compose([
 
79
 
80
  return train_dl, test_dl
81
 
82
+ def mnist_preprocess(uri: str):
83
  '''
84
  Preprocesses a data URI representing a handwritten digit image according to the pipeline used in the MNIST dataset.
85
  The pipeline includes:
 
92
 
93
  Args:
94
  uri (str): A string representing the full data URI.
95
+
 
96
  Returns:
97
  Tensor: A tensor of shape (1, 28, 28) representing the preprocessed image, normalized using MNIST statistics.
98
  '''
 
132
 
133
  # Paste cropped image into 28x28 field such that the old COM (com_y, com_x), is at the center (14, 14)
134
  new_img[new_slice_y, new_slice_x] = img[old_slice_y, old_slice_x]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
  # Return transformed tensor of new image. This includes normalizing to MNIST stats
137
+ return BASE_TRANSFORMS(new_img)
requirements.txt CHANGED
@@ -1,5 +1,4 @@
1
  numpy==2.2.4
2
- matplotlib==3.10.1
3
  panel==1.4.5
4
  param==2.1.1
5
  plotly==6.0.1
 
1
  numpy==2.2.4
 
2
  panel==1.4.5
3
  param==2.1.1
4
  plotly==6.0.1