File size: 12,718 Bytes
af7c0ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# A script to run multinode training with submitit.
# --------------------------------------------------------

import argparse
import os.path as osp
import submitit
import itertools

from omegaconf import OmegaConf
from paintmind.engine.util import instantiate_from_config
from paintmind.utils.device_utils import configure_compute_backend


def parse_args(argv=None):
    """Parse command-line options for the submitit FID test launcher.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse reads ``sys.argv[1:]`` as before.

    Returns:
        argparse.Namespace. Multi-valued options (``--model``, ``--step`` and
        the CFG/sampling options) are always lists; ``--model`` and the
        CFG/sampling options default to ``[None]``, meaning "not given".
    """
    parser = argparse.ArgumentParser("Submitit for accelerator training")
    # Slurm configuration
    parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node")
    parser.add_argument("--nodes", default=1, type=int, help="Number of nodes to request")
    # 7000 minutes is ~4.86 days; the value is passed to submitit as timeout_min.
    parser.add_argument("--timeout", default=7000, type=int,
                        help="Duration of the job in minutes (default: 7000, about 4.9 days)")
    parser.add_argument("--qos", default="normal", type=str, help="QOS to request")
    parser.add_argument("--job_dir", default="", type=str, help="Job dir. Leave empty for automatic.")
    parser.add_argument("--partition", default="h100-camera-train", type=str, help="Partition where to submit")
    parser.add_argument("--exclude", default="", type=str, help="Exclude nodes from the partition")
    parser.add_argument("--nodelist", default="", type=str, help="Nodelist to request")
    parser.add_argument('--comment', default="", type=str, help="Comment to pass to scheduler")

    # Model and testing configuration
    parser.add_argument('--model', type=str, nargs='+', default=[None], help="Path to model(s)")
    parser.add_argument('--step', type=int, nargs='+', default=[250000], help="Step number(s)")
    parser.add_argument('--cfg', type=str, default=None, help="Path to config file")
    parser.add_argument('--dataset', type=str, default='imagenet', help="Dataset to use")

    # Legacy parameter (preserved for backward compatibility)
    parser.add_argument('--cfg_value', type=float, nargs='+', default=[None],
                        help='Legacy parameter for GPT classifier-free guidance scale')

    # CFG-related parameters - all with nargs='+' so several values can be
    # swept; the Cartesian product of all given values is evaluated.
    parser.add_argument('--ae_cfg', type=float, nargs='+', default=[None],
                        help="Autoencoder classifier-free guidance scale")
    parser.add_argument('--diff_cfg', type=float, nargs='+', default=[None],
                        help="Diffusion classifier-free guidance scale")
    parser.add_argument('--cfg_schedule', type=str, nargs='+', default=[None],
                        help="CFG schedule type (e.g., constant, linear)")
    parser.add_argument('--diff_cfg_schedule', type=str, nargs='+', default=[None],
                        help="Diffusion CFG schedule type (e.g., constant, inv_linear)")
    parser.add_argument('--test_num_slots', type=int, nargs='+', default=[None],
                        help="Number of slots to use for inference")
    parser.add_argument('--temperature', type=float, nargs='+', default=[None],
                        help="Temperature for sampling")

    return parser.parse_args(argv)


def load_config(model_path, cfg_path=None):
    """Load an OmegaConf config, preferring an explicit file over the model dir.

    Args:
        model_path: Model directory that may contain ``config.yaml``; may be
            None or empty, in which case it is not consulted.
        cfg_path: Optional explicit config file; used first when it exists.

    Returns:
        The config loaded with ``OmegaConf.load``.

    Raises:
        ValueError: If neither ``cfg_path`` nor ``<model_path>/config.yaml``
            exists.
    """
    if cfg_path is not None and osp.exists(cfg_path):
        return OmegaConf.load(cfg_path)

    if model_path:
        model_config = osp.join(model_path, 'config.yaml')
        if osp.exists(model_config):
            return OmegaConf.load(model_config)

    raise ValueError(f"No config file found at {model_path} or {cfg_path}")


def setup_checkpoint_path(model_path, step, config):
    """Resolve the checkpoint directory for ``step`` and record it in ``config``.

    The checkpoint lives at ``<base>/models/step<step>``, where ``<base>`` is
    ``model_path`` when given, otherwise ``config.trainer.params.result_folder``.
    Only the ``model_path`` case is checked for existence on disk. The path is
    stored in ``model.params.ckpt_path`` when the trainer params have a
    ``model`` entry, else in ``gpt_model.params.ckpt_path``.

    Returns:
        The checkpoint path, or None when ``model_path`` was given but the
        checkpoint directory is missing (``config`` is then left untouched).
    """
    base_dir = model_path or config.trainer.params.result_folder
    ckpt_path = osp.join(base_dir, 'models', f'step{step}')

    if model_path and not osp.exists(ckpt_path):
        print(f"Skipping non-existent checkpoint: {ckpt_path}")
        return None

    params = config.trainer.params
    target = params.model if hasattr(params, 'model') else params.gpt_model
    target.params.ckpt_path = ckpt_path
    return ckpt_path


def setup_test_config(config, use_coco=False):
    """Turn a training config into an evaluation-only config, in place.

    ``trainer.params.dataset`` is reused as ``trainer.params.test_dataset``.
    Without ``use_coco`` only its split is changed to ``'val'``; with
    ``use_coco`` it is retargeted to ``paintmind.utils.datasets.COCO`` rooted
    at ``./dataset/coco`` with split ``'val2017'``. The trainer is then put in
    test-only mode with FID evaluation enabled and compilation disabled, and
    ``num_sampling_steps`` is set on ``model`` (if present) or ``ae_model``.

    Args:
        config: OmegaConf trainer config; modified in place.
        use_coco: Evaluate on COCO val2017 instead of the training dataset's
            validation split.
    """
    params = config.trainer.params
    params.test_dataset = params.dataset
    test_dataset = params.test_dataset

    if use_coco:
        test_dataset.target = 'paintmind.utils.datasets.COCO'
        test_dataset.params.root = './dataset/coco'
        test_dataset.params.split = 'val2017'
    else:
        test_dataset.params.split = 'val'

    params.test_only = True
    params.compile = False
    params.eval_fid = True
    # NOTE(review): the same FID reference-statistics file is used for COCO
    # runs as well; confirm that is intended.
    params.fid_stats = 'fid_stats/adm_in256_stats.npz'

    sampler = params.model if hasattr(params, 'model') else params.ae_model
    # Stored as the string '250'; presumably the consumer parses it — confirm
    # before changing the type.
    sampler.params.num_sampling_steps = '250'

def apply_cfg_params(config, param_dict):
    """Copy sweep values from ``param_dict`` onto ``config.trainer.params``.

    Every recognised key whose value is not None is written to the trainer
    params and reported with a ``Setting <name> to <value>`` line. The legacy
    ``cfg_value`` key is stored under the attribute name ``cfg``. Keys that are
    missing or None leave the config untouched.
    """
    # (key in param_dict, attribute on config.trainer.params), in the order
    # the overrides are applied and printed.
    overrides = (
        ('cfg_value', 'cfg'),
        ('ae_cfg', 'ae_cfg'),
        ('diff_cfg', 'diff_cfg'),
        ('cfg_schedule', 'cfg_schedule'),
        ('diff_cfg_schedule', 'diff_cfg_schedule'),
        ('test_num_slots', 'test_num_slots'),
        ('temperature', 'temperature'),
    )
    for key, attr in overrides:
        value = param_dict.get(key)
        if value is None:
            continue
        setattr(config.trainer.params, attr, value)
        print(f"Setting {attr} to {value}")


def run_test(config):
    """Build the trainer described by ``config.trainer`` and call its ``train()``.

    Presumably the trainer only evaluates because ``test_only`` is set in the
    config before this is called — confirm in the trainer implementation.
    """
    instantiate_from_config(config.trainer).train()


def generate_param_combinations(args):
    """Yield one parameter dict per point of the sweep's Cartesian product.

    Each sweepable option on ``args`` is a list, ``[None]`` when not given.
    Every yielded dict contains all keys below in this order, with unset
    options mapped to None. When no option is given, exactly one all-None dict
    is yielded.
    """
    names = (
        'cfg_value',
        'ae_cfg',
        'diff_cfg',
        'cfg_schedule',
        'diff_cfg_schedule',
        'test_num_slots',
        'temperature',
    )
    # An unset option contributes its single None value, so it neither
    # multiplies the number of combinations nor needs special-casing.
    value_lists = [getattr(args, name) for name in names]
    for combination in itertools.product(*value_lists):
        yield dict(zip(names, combination))


class Trainer(object):
    """Callable handed to submitit that runs the FID test sweep.

    When called, it exports the torch.distributed environment, configures the
    compute backend, then evaluates every requested (step, model,
    parameter-combination) triple.
    """

    def __init__(self, args):
        # Parsed command-line namespace; ``output_dir`` is expected to be set
        # on it before the job runs (read in ``_setup_gpu_args``).
        self.args = args

    def __call__(self):
        """Main entry point for the submitit job."""
        self._setup_gpu_args()
        configure_compute_backend()
        self._run_tests()

    def _use_coco(self):
        """Return True if ``--dataset`` names COCO, compared case-insensitively."""
        dataset = self.args.dataset
        return isinstance(dataset, str) and dataset.casefold() == 'coco'

    def _run_tests(self):
        """Run tests for all specified models and steps.

        For each step and model, loads the config, points it at the checkpoint
        (skipping missing ones), switches it to test mode, then runs one
        evaluation per CFG/sampling parameter combination.
        """
        # Loop-invariant: depends only on the CLI arguments.
        use_coco = self._use_coco()
        for step in self.args.step:
            for model in self.args.model:
                print(f"Testing model: {model} at step: {step}")

                # Load configuration
                config = load_config(model, self.args.cfg)

                # Setup checkpoint path; None means the checkpoint is missing.
                ckpt_path = setup_checkpoint_path(model, step, config)
                if ckpt_path is None:
                    continue

                # Setup test configuration
                setup_test_config(config, use_coco)

                # Generate and apply all parameter combinations
                for param_dict in generate_param_combinations(self.args):
                    # Rebuild from a resolved plain container so each
                    # combination's overrides start from the same base config.
                    current_config = OmegaConf.create(OmegaConf.to_container(config, resolve=True))

                    param_str = ", ".join(f"{k}={v}" for k, v in param_dict.items() if v is not None)
                    print(f"Testing with parameters: {param_str}")

                    # Apply parameters and run test
                    apply_cfg_params(current_config, param_dict)
                    run_test(current_config)

    def _setup_gpu_args(self):
        """Set up GPU and distributed environment variables."""
        import submitit

        print("Exporting PyTorch distributed environment variables")
        dist_env = submitit.helpers.TorchDistributedEnvironment().export(set_cuda_visible_devices=False)
        print(f"Master: {dist_env.master_addr}:{dist_env.master_port}")
        print(f"Rank: {dist_env.rank}")
        print(f"World size: {dist_env.world_size}")
        print(f"Local rank: {dist_env.local_rank}")
        print(f"Local world size: {dist_env.local_world_size}")

        job_env = submitit.JobEnvironment()
        # Substitute the "%j" placeholder in the output dir with the job id.
        self.args.output_dir = str(self.args.output_dir).replace("%j", str(job_env.job_id))
        self.args.log_dir = self.args.output_dir
        print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}")


def main():
    """Parse the CLI, configure a submitit executor and submit the FID test job.

    The config is resolved with ``load_config``: ``--cfg`` if it exists,
    otherwise ``<first --model>/config.yaml``. The submitit folder is
    ``--job_dir`` when given, else the config's ``trainer.params.result_folder``.

    Raises:
        ValueError: If no config file can be found, including when neither
            ``--cfg`` nor ``--model`` is given.
    """
    args = parse_args()

    # load_config accepts a missing --model (None) and raises ValueError,
    # rather than failing inside osp.join with a TypeError.
    config = load_config(args.model[0], args.cfg)

    # Determine job directory: an explicit --job_dir takes precedence; when it
    # is left empty ("automatic"), use the config's result folder.
    if not args.job_dir:
        args.job_dir = config.trainer.params.result_folder

    # Set up the executor
    executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30)

    # Configure slurm parameters
    slurm_kwargs = {
        'slurm_signal_delay_s': 120,
        'slurm_qos': args.qos,
    }

    # Forward optional scheduler settings only when they were provided.
    if args.comment:
        slurm_kwargs['slurm_comment'] = args.comment
    if args.exclude:
        slurm_kwargs['slurm_exclude'] = args.exclude
    if args.nodelist:
        slurm_kwargs['slurm_nodelist'] = args.nodelist

    # Update executor parameters
    executor.update_parameters(
        gpus_per_node=args.ngpus,
        tasks_per_node=args.ngpus,  # one task per GPU
        nodes=args.nodes,
        timeout_min=args.timeout,
        slurm_partition=args.partition,
        name="fid",
        **slurm_kwargs
    )

    # Trainer._setup_gpu_args expands "%j" in this path with the job id.
    args.output_dir = args.job_dir

    # Submit the job
    trainer = Trainer(args)
    job = executor.submit(trainer)

    print("Submitted job_id:", job.job_id)


if __name__ == '__main__':
    # Submit the job only when run as a script, not when imported.
    main()