# Copyright (c) OpenMMLab. All rights reserved.
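"""Fine-tune ViTPose on a COCO-format keypoint dataset.

Example invocation (assuming this script is saved as ``train.py``):

    python train.py --config-path config.yaml --model-name b

Each run writes its logs and training outputs to a fresh, zero-padded
session directory (001, 002, ...) under ``runs/train`` next to this file,
or under ``cfg.work_dir`` if the config defines one.
"""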
import os
import os.path as osp
import time
import warnings
from glob import glob

import click
import torch
import yaml

import configs.ViTPose_base_coco_256x192 as b_cfg
import configs.ViTPose_huge_coco_256x192 as h_cfg
import configs.ViTPose_large_coco_256x192 as l_cfg
import configs.ViTPose_small_coco_256x192 as s_cfg
from datasets.COCO import COCODataset
from vit_models.model import ViTPose
from vit_utils.dist_util import get_dist_info, init_dist
from vit_utils.logging import get_root_logger
from vit_utils.train_valid_fn import train_model
from vit_utils.util import init_random_seed, set_random_seed
CUR_PATH = osp.dirname(__file__)


@click.command()
@click.option('--config-path', type=click.Path(exists=True), default='config.yaml',
              help='train config file path')
@click.option('--model-name', type=str, default='b',
              help='[s: ViT-S, b: ViT-B, l: ViT-L, h: ViT-H]')
def main(config_path, model_name):
    # Pick the config module that matches the requested model size
    cfg = {'s': s_cfg,
           'b': b_cfg,
           'l': l_cfg,
           'h': h_cfg}.get(model_name.lower())
    if cfg is None:
        raise ValueError(f"Unknown model name '{model_name}'; expected one of s/b/l/h")

    # Load config.yaml and merge its keys into the config module
    with open(config_path, 'r') as f:
        cfg_yaml = yaml.load(f, Loader=yaml.SafeLoader)

    for k, v in cfg_yaml.items():
        if hasattr(cfg, k):
            raise ValueError(f"Key '{k}' already exists in config")
        setattr(cfg, k, v)
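    # A sketch of what config.yaml might provide (illustrative only: the key
    # names are inferred from the attribute accesses below, and any key the
    # chosen model config module already defines would trigger the
    # ValueError above):
    #
    #   launcher: none
    #   dist_params: {}
    #   gpu_ids: [0]
    #   seed: 0
    #   deterministic: false
    #   cudnn_benchmark: true
    #   autoscale_lr: false
    #   resume_from: vitpose-b.pth
    #   data_root: ./data/coco
    #   validate: true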
    # set cudnn_benchmark
    if cfg.cudnn_benchmark:
        torch.backends.cudnn.benchmark = True

    # Set work directory (session-level)
    if not hasattr(cfg, 'work_dir'):
        setattr(cfg, 'work_dir', f"{CUR_PATH}/runs/train")
    if not osp.exists(cfg.work_dir):
        os.makedirs(cfg.work_dir)

    # Every run gets a fresh zero-padded session directory (001, 002, ...)
    session_list = sorted(glob(f"{cfg.work_dir}/*"))
    if len(session_list) == 0:
        session = 1
    else:
        session = int(os.path.basename(session_list[-1])) + 1
    session_dir = osp.join(cfg.work_dir, str(session).zfill(3))
    os.makedirs(session_dir)
    setattr(cfg, 'work_dir', session_dir)
    if cfg.autoscale_lr:
        # apply the linear scaling rule (https://arxiv.org/abs/1706.02677)
        cfg.optimizer['lr'] = cfg.optimizer['lr'] * len(cfg.gpu_ids) / 8
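        # e.g. a base lr of 5e-4 tuned for 8 GPUs becomes 5e-4 * 4 / 8 = 2.5e-4
        # when training on 4 GPUs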
    # init distributed env first, since logger depends on the dist info.
    if cfg.launcher == 'none':
        distributed = False
        if len(cfg.gpu_ids) > 1:
            warnings.warn(
                f"We treat {cfg.gpu_ids} as gpu-ids, and reset to "
                f"{cfg.gpu_ids[0:1]} as gpu-ids to avoid potential errors "
                "in non-distributed training.")
            cfg.gpu_ids = cfg.gpu_ids[0:1]
    else:
        distributed = True
        init_dist(cfg.launcher, **cfg.dist_params)
        # re-set gpu_ids with distributed training mode
        _, world_size = get_dist_info()
        cfg.gpu_ids = range(world_size)
    # init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(session_dir, f'{timestamp}.log')
    logger = get_root_logger(log_file=log_file)

    # init the meta dict to record some important information such as
    # environment info and seed, which will be logged
    meta = dict()

    # log some basic info
    logger.info(f'Distributed training: {distributed}')

    # set random seeds
    seed = init_random_seed(cfg.seed)
    logger.info(f"Set random seed to {seed}, "
                f"deterministic: {cfg.deterministic}")
    set_random_seed(seed, deterministic=cfg.deterministic)
    meta['seed'] = seed
    # Set model
    model = ViTPose(cfg.model)
    if cfg.resume_from:
        # Load the checkpoint partially: drop the final prediction layer so
        # the keypoint head is re-initialized for the new keypoint set
        ckpt_state = torch.load(cfg.resume_from)['state_dict']
        ckpt_state.pop('keypoint_head.final_layer.bias')
        ckpt_state.pop('keypoint_head.final_layer.weight')
        model.load_state_dict(ckpt_state, strict=False)

        # freeze the backbone, leave the head to be finetuned
        model.backbone.frozen_stages = model.backbone.depth - 1
        model.backbone.freeze_ffn = True
        model.backbone.freeze_attn = True
        model.backbone._freeze_stages()
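        # With all stages, FFN blocks, and attention blocks frozen, only the
        # keypoint head (plus whatever _freeze_stages() leaves trainable)
        # receives gradient updates during fine-tuning.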
    # Set dataset
    datasets_train = COCODataset(
        root_path=cfg.data_root,
        data_version="feet_train",
        is_train=True,
        use_gt_bboxes=True,
        image_width=192,
        image_height=256,
        scale=True,
        scale_factor=0.35,
        flip_prob=0.5,
        rotate_prob=0.5,
        rotation_factor=45.,
        half_body_prob=0.3,
        use_different_joints_weight=True,
        heatmap_sigma=3,
        soft_nms=False
    )
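    # The validation split reuses the same constructor with is_train=False;
    # the augmentation arguments below mirror the training values and are
    # presumably ignored by COCODataset when not training.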
    datasets_valid = COCODataset(
        root_path=cfg.data_root,
        data_version="feet_val",
        is_train=False,
        use_gt_bboxes=True,
        image_width=192,
        image_height=256,
        scale=False,
        scale_factor=0.35,
        flip_prob=0.5,
        rotate_prob=0.5,
        rotation_factor=45.,
        half_body_prob=0.3,
        use_different_joints_weight=True,
        heatmap_sigma=3,
        soft_nms=False
    )
    train_model(
        model=model,
        datasets_train=datasets_train,
        datasets_valid=datasets_valid,
        cfg=cfg,
        distributed=distributed,
        validate=cfg.validate,
        timestamp=timestamp,
        meta=meta
    )


if __name__ == '__main__':
    main()