# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
import time
import warnings
from glob import glob

import click
import yaml
import torch

from vit_utils.util import init_random_seed, set_random_seed
from vit_utils.dist_util import get_dist_info, init_dist
from vit_utils.logging import get_root_logger

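# Per-size ViTPose configs (COCO dataset, 256x192 input), selected via --model-name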
import configs.ViTPose_small_coco_256x192 as s_cfg
import configs.ViTPose_base_coco_256x192 as b_cfg
import configs.ViTPose_large_coco_256x192 as l_cfg
import configs.ViTPose_huge_coco_256x192 as h_cfg

from vit_models.model import ViTPose
from datasets.COCO import COCODataset
from vit_utils.train_valid_fn import train_model

CUR_PATH = osp.dirname(__file__)

@click.command()
@click.option('--config-path', type=click.Path(exists=True), default='config.yaml',
              help='train config file path')
@click.option('--model-name', type=str, default='b',
              help='model size: [s: ViT-S, b: ViT-B, l: ViT-L, h: ViT-H]')
def main(config_path, model_name):
    cfg = {'s': s_cfg,
           'b': b_cfg,
           'l': l_cfg,
           'h': h_cfg}.get(model_name.lower())
    if cfg is None:
        raise ValueError(f"Unknown model name: {model_name}")

    # Load config.yaml and merge its keys into the selected model config
    with open(config_path, 'r') as f:
        cfg_yaml = yaml.load(f, Loader=yaml.SafeLoader)

    for k, v in cfg_yaml.items():
        if hasattr(cfg, k):
            raise ValueError(f"Key '{k}' already exists in the model config")
        setattr(cfg, k, v)

    # set cudnn_benchmark
    if cfg.cudnn_benchmark:
        torch.backends.cudnn.benchmark = True
    
    # Set work directory (session-level)
    if not hasattr(cfg, 'work_dir'):
        setattr(cfg, 'work_dir', f"{CUR_PATH}/runs/train")

    os.makedirs(cfg.work_dir, exist_ok=True)
    session_list = sorted(glob(f"{cfg.work_dir}/*"))
    if len(session_list) == 0:
        session = 1
    else:
        session = int(osp.basename(session_list[-1])) + 1
    session_dir = osp.join(cfg.work_dir, str(session).zfill(3))
    os.makedirs(session_dir)
    setattr(cfg, 'work_dir', session_dir)
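    # e.g. the first run writes to runs/train/001, the next to runs/train/002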
        

    if cfg.autoscale_lr:
        # apply the linear scaling rule (https://arxiv.org/abs/1706.02677)
        cfg.optimizer['lr'] = cfg.optimizer['lr'] * len(cfg.gpu_ids) / 8
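        # e.g. a configured lr of 1e-3 with 4 GPUs becomes 1e-3 * 4 / 8 = 5e-4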

    # init distributed env first, since the logger depends on the dist info.
    if cfg.launcher == 'none':
        distributed = False
        if len(cfg.gpu_ids) > 1:
            warnings.warn(
                f"Got gpu_ids={cfg.gpu_ids}; resetting to "
                f"{cfg.gpu_ids[0:1]} to avoid potential errors in "
                "non-distributed training.")
            cfg.gpu_ids = cfg.gpu_ids[0:1]
    else:
        distributed = True
        init_dist(cfg.launcher, **cfg.dist_params)
        # re-set gpu_ids with distributed training mode
        _, world_size = get_dist_info()
        cfg.gpu_ids = range(world_size)

    # init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(session_dir, f'{timestamp}.log')
    logger = get_root_logger(log_file=log_file)

    # init the meta dict to record some important information such as
    # environment info and seed, which will be logged
    meta = dict()

    # log some basic info
    logger.info(f'Distributed training: {distributed}')

    # set random seeds
    seed = init_random_seed(cfg.seed)
    logger.info(f"Set random seed to {seed}, "
                f"deterministic: {cfg.deterministic}")
    set_random_seed(seed, deterministic=cfg.deterministic)
    meta['seed'] = seed

    # Set model
    model = ViTPose(cfg.model)
    if cfg.resume_from:
        # Load the checkpoint partially: drop the final prediction layer so a
        # freshly initialized head can be trained; strict=False tolerates the
        # missing keys.
        ckpt_state = torch.load(cfg.resume_from, map_location='cpu')['state_dict']
        ckpt_state.pop('keypoint_head.final_layer.bias')
        ckpt_state.pop('keypoint_head.final_layer.weight')
        model.load_state_dict(ckpt_state, strict=False)

        # freeze the backbone, leave the head to be finetuned
        model.backbone.frozen_stages = model.backbone.depth - 1
        model.backbone.freeze_ffn = True
        model.backbone.freeze_attn = True
        model.backbone._freeze_stages()
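        # With the final layer re-initialized and the backbone frozen above,
        # training effectively updates only the keypoint head.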
    
    # Set dataset
    datasets_train = COCODataset(
        root_path=cfg.data_root, 
        data_version="feet_train",
        is_train=True, 
        use_gt_bboxes=True,
        image_width=192, 
        image_height=256,
        scale=True, 
        scale_factor=0.35, 
        flip_prob=0.5, 
        rotate_prob=0.5, 
        rotation_factor=45., 
        half_body_prob=0.3,
        use_different_joints_weight=True, 
        heatmap_sigma=3, 
        soft_nms=False
        )
    
    datasets_valid = COCODataset(
        root_path=cfg.data_root, 
        data_version="feet_val",
        is_train=False, 
        use_gt_bboxes=True,
        image_width=192, 
        image_height=256,
        scale=False, 
        scale_factor=0.35, 
        flip_prob=0.5, 
        rotate_prob=0.5, 
        rotation_factor=45., 
        half_body_prob=0.3,
        use_different_joints_weight=True, 
        heatmap_sigma=3, 
        soft_nms=False
        )
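    # NOTE: both datasets use 192x256 (width x height) inputs, matching the
    # *_coco_256x192 configs imported above. With is_train=False, the
    # augmentation arguments on the validation set are expected to be inert.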

    train_model(
        model=model,
        datasets_train=datasets_train,
        datasets_valid=datasets_valid,
        cfg=cfg,
        distributed=distributed,
        validate=cfg.validate,
        timestamp=timestamp,
        meta=meta
        )


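# Example invocation (script and config file names are illustrative):
#   python train.py --config-path config.yaml --model-name b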
if __name__ == '__main__':
    main()