Jechen00 committed on
Commit
b1814f3
·
1 Parent(s): 70fea38

Changed the model used by the app to one that requires less compute

Browse files
app.py CHANGED
@@ -204,7 +204,7 @@ def create_app():
204
  '''
205
  # Used to serve with panel serve in command line
206
  save_dir = FILE_PATH + '/saved_models'
207
- base_name = 'tiny_vgg'
208
 
209
  mod_path = f'{save_dir}/{base_name}_model.pth' # Path to the saved model state dict
210
  settings_path = f'{save_dir}/{base_name}_settings.yaml' # Path to the saved model kwargs
 
204
  '''
205
  # Used to serve with panel serve in command line
206
  save_dir = FILE_PATH + '/saved_models'
207
+ base_name = 'tiny_vgg_less_compute'
208
 
209
  mod_path = f'{save_dir}/{base_name}_model.pth' # Path to the saved model state dict
210
  settings_path = f'{save_dir}/{base_name}_settings.yaml' # Path to the saved model kwargs
model_training/__pycache__/utils.cpython-313.pyc DELETED
Binary file (2.55 kB)
 
model_training/args.txt CHANGED
@@ -1,12 +1,12 @@
1
  --num-workers
2
  0
3
  --num-epochs
4
- 100
5
  --batch-size
6
  100
7
  --learning-rate
8
  0.001
9
  --patience
10
- 20
11
  --min-delta
12
- 0.0005
 
1
  --num-workers
2
  0
3
  --num-epochs
4
+ 300
5
  --batch-size
6
  100
7
  --learning-rate
8
  0.001
9
  --patience
10
+ 50
11
  --min-delta
12
+ 0.0001
model_training/data_setup.py CHANGED
@@ -47,17 +47,19 @@ def get_dataloaders(root: str,
47
  num_workers (int): Number of workers to use for multiprocessing. Default is 0.
48
  '''
49
 
50
- # Get training and testing MNIST data
51
  mnist_train = datasets.MNIST(root, download = True, train = True,
52
- transform = TRAIN_TRANSFORMS)
53
  mnist_test = datasets.MNIST(root, download = True, train = False,
54
  transform = BASE_TRANSFORMS)
55
 
56
  # Create dataloaders
57
  if num_workers > 0:
58
  mp_context = utils.MP_CONTEXT
 
59
  else:
60
  mp_context = None
 
61
 
62
  train_dl = DataLoader(
63
  dataset = mnist_train,
@@ -65,7 +67,8 @@ def get_dataloaders(root: str,
65
  shuffle = True,
66
  num_workers = num_workers,
67
  multiprocessing_context = mp_context,
68
- pin_memory = True
 
69
  )
70
 
71
  test_dl = DataLoader(
@@ -74,7 +77,8 @@ def get_dataloaders(root: str,
74
  shuffle = False,
75
  num_workers = num_workers,
76
  multiprocessing_context = mp_context,
77
- pin_memory = True
 
78
  )
79
 
80
  return train_dl, test_dl
 
47
  num_workers (int): Number of workers to use for multiprocessing. Default is 0.
48
  '''
49
 
50
+ # Get training and testing MNIST data
51
  mnist_train = datasets.MNIST(root, download = True, train = True,
52
+ transform = TRAIN_TRANSFORMS)
53
  mnist_test = datasets.MNIST(root, download = True, train = False,
54
  transform = BASE_TRANSFORMS)
55
 
56
  # Create dataloaders
57
  if num_workers > 0:
58
  mp_context = utils.MP_CONTEXT
59
+ persistent_workers = True
60
  else:
61
  mp_context = None
62
+ persistent_workers = False
63
 
64
  train_dl = DataLoader(
65
  dataset = mnist_train,
 
67
  shuffle = True,
68
  num_workers = num_workers,
69
  multiprocessing_context = mp_context,
70
+ pin_memory = utils.PIN_MEM,
71
+ persistent_workers = persistent_workers
72
  )
73
 
74
  test_dl = DataLoader(
 
77
  shuffle = False,
78
  num_workers = num_workers,
79
  multiprocessing_context = mp_context,
80
+ pin_memory = utils.PIN_MEM,
81
+ persistent_workers = persistent_workers
82
  )
83
 
84
  return train_dl, test_dl
model_training/run_training.py CHANGED
@@ -53,7 +53,7 @@ if __name__ == '__main__':
53
  # Set up saving directory and file name
54
  save_dir = '../saved_models'
55
 
56
- base_name = 'tiny_vgg'
57
  mod_name = f'{base_name}_model.pth'
58
 
59
  # Get TinyVGG model
@@ -61,12 +61,13 @@ if __name__ == '__main__':
61
  'num_blks': 2,
62
  'num_convs': 2,
63
  'in_channels': 1,
64
- 'hidden_channels': 10,
65
- 'fc_hidden_dim': 64,
66
  'num_classes': len(train_dl.dataset.classes)
67
  }
68
 
69
  vgg_mod = model.TinyVGG(**mod_kwargs).to(utils.DEVICE)
 
70
 
71
  # Save model kwargs and train settings
72
  with open(f'{save_dir}/{base_name}_settings.yaml', 'w') as f:
 
53
  # Set up saving directory and file name
54
  save_dir = '../saved_models'
55
 
56
+ base_name = 'tiny_vgg_less_compute'
57
  mod_name = f'{base_name}_model.pth'
58
 
59
  # Get TinyVGG model
 
61
  'num_blks': 2,
62
  'num_convs': 2,
63
  'in_channels': 1,
64
+ 'hidden_channels': 5,
65
+ 'fc_hidden_dim': 128,
66
  'num_classes': len(train_dl.dataset.classes)
67
  }
68
 
69
  vgg_mod = model.TinyVGG(**mod_kwargs).to(utils.DEVICE)
70
+ torch.compile(vgg_mod)
71
 
72
  # Save model kwargs and train settings
73
  with open(f'{save_dir}/{base_name}_settings.yaml', 'w') as f:
model_training/utils.py CHANGED
@@ -10,12 +10,15 @@ import os
10
  if torch.cuda.is_available():
11
  DEVICE = torch.device('cuda')
12
  MP_CONTEXT = None
 
13
  elif torch.backends.mps.is_available():
14
  DEVICE = torch.device('mps')
15
  MP_CONTEXT = 'forkserver'
 
16
  else:
17
  DEVICE = torch.device('cpu')
18
  MP_CONTEXT = None
 
19
 
20
 
21
  #####################################
@@ -37,6 +40,7 @@ def set_seed(seed: int = 0):
37
  torch.cuda.manual_seed_all(seed)
38
 
39
  torch.use_deterministic_algorithms(True)
 
40
 
41
  def save_model(model: torch.nn.Module,
42
  save_dir: str,
 
10
  if torch.cuda.is_available():
11
  DEVICE = torch.device('cuda')
12
  MP_CONTEXT = None
13
+ PIN_MEM = True
14
  elif torch.backends.mps.is_available():
15
  DEVICE = torch.device('mps')
16
  MP_CONTEXT = 'forkserver'
17
+ PIN_MEM = False
18
  else:
19
  DEVICE = torch.device('cpu')
20
  MP_CONTEXT = None
21
+ PIN_MEM = False
22
 
23
 
24
  #####################################
 
40
  torch.cuda.manual_seed_all(seed)
41
 
42
  torch.use_deterministic_algorithms(True)
43
+ os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
44
 
45
  def save_model(model: torch.nn.Module,
46
  save_dir: str,
saved_models/tiny_vgg_less_compute_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:72c2bf04c913dd30f77bb7dde8e4f9bd253533dbd7349475dbc2d0775b875a5e
3
+ size 145606
saved_models/tiny_vgg_less_compute_settings.yaml ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ mod_kwargs:
2
+ fc_hidden_dim: 128
3
+ hidden_channels: 5
4
+ in_channels: 1
5
+ num_blks: 2
6
+ num_classes: 10
7
+ num_convs: 2
8
+ train_kwargs:
9
+ batch_size: 100
10
+ learning_rate: 0.001
11
+ min_delta: 0.0001
12
+ num_epochs: 300
13
+ num_workers: 0
14
+ patience: 50