File size: 5,050 Bytes
e8f2571 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
# Parent configs: the HSI detection dataset settings and the default runtime.
# Values defined below override the inherited ones.
_base_ = [
    './_base_/datasets/hsi_detection.py',
    './_base_/default_runtime.py',
]
# ---------------------------------------------------------------------------
# Shared hyperparameters (referenced by several sub-configs of `model` below)
# ---------------------------------------------------------------------------
# NOTE(review): `fp16 = dict(loss_scale=512.)` is the MMDet 2.x way to enable
# mixed precision. With MMEngine-style `optim_wrapper` configs, AMP is
# presumably enabled via `optim_wrapper.type='AmpOptimWrapper'` (with its own
# `loss_scale`) instead. Confirm against the installed MMEngine version.
# fp16 = dict(loss_scale=512.)
norm = 'LN'  # norm layer type used everywhere below; alternatives noted: 'IN1d', 'BN1d'
num_levels = 2  # feature levels; shared by backbone, encoder and decoder attention
in_channels = 30  # input channels passed to the backbone
embed_dims = 256  # transformer embedding width
query_initial = 'one'

model = dict(
    type='SpecDetr',
    num_queries=900,  # number of matching queries
    num_query_per_cat=5,
    num_fix_query=0,
    with_box_refine=True,
    as_two_stage=True,
    num_feature_levels=num_levels,
    candidate_bboxes_size=0.01,  # initial size of candidate bboxes produced after the encoder
    scale_gt_bboxes_size=0,  # noted range [0, 0.5); 0.25 noted as an alternative
    training_dn=True,  # use denoising (DN) queries during training
    # dn_only_pos=False,
    dn_type='CDN',  # alternatives noted: 'DN', 'CDNV1'
    query_initial=query_initial,
    remove_last_candidate=True,  # intended for when the last backbone feature map has size 1
    data_preprocessor=dict(
        type='HSIDetDataPreprocessor'),
    backbone=dict(
        type='No_backbone_ST',
        in_channels=in_channels,
        embed_dims=embed_dims,
        # Same `num_levels` variable as `num_feature_levels` and the
        # deformable-attention `num_levels` below, so they stay in sync.
        num_levels=num_levels,
        norm_cfg=dict(type=norm),
    ),
    encoder=dict(
        num_layers=6,
        layer_cfg=dict(
            self_attn_cfg=dict(
                embed_dims=embed_dims,
                num_levels=num_levels,
                num_points=4,
                dropout=0.0),  # 0.1 for DeformDETR
            ffn_cfg=dict(
                embed_dims=embed_dims,
                feedforward_channels=embed_dims * 8,  # 1024 for DeformDETR
                ffn_drop=0.0),  # 0.1 for DeformDETR
            norm_cfg=dict(type=norm))),
    decoder=dict(
        num_layers=6,
        return_intermediate=True,
        layer_cfg=dict(
            self_attn_cfg=dict(
                embed_dims=embed_dims,
                num_heads=8,
                dropout=0.0),  # 0.1 for DeformDETR
            cross_attn_cfg=dict(
                embed_dims=embed_dims,
                num_levels=num_levels,
                num_points=4,
                dropout=0.0),  # 0.1 for DeformDETR
            ffn_cfg=dict(
                embed_dims=embed_dims,
                # 1024 for DeformDETR, 2048 for DINO
                feedforward_channels=embed_dims * 8,
                ffn_drop=0.0),  # 0.1 for DeformDETR
            # The original note records dict(type='LN') as the reference setting.
            norm_cfg=dict(type=norm)),
        post_norm_cfg=None),
    positional_encoding=dict(
        num_feats=embed_dims // 2,
        normalize=True,
        offset=0.0,  # -0.5 for DeformDETR
        temperature=20),  # 10000 for DeformDETR
    bbox_head=dict(
        type='SpecDetrHead',
        num_classes=8,
        sync_cls_avg_factor=True,
        pre_bboxes_round=False,
        use_nms=True,
        iou_threshold=0.01,
        embed_dims=embed_dims,
        # neg_cls=True,
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),  # 2.0 in DeformDETR
        loss_bbox=dict(type='L1Loss', loss_weight=5.0),
        loss_iou=dict(type='GIoULoss', loss_weight=2.0)),
    dn_cfg=dict(  # TODO: Move to model.train_cfg ?
        label_noise_scale=0.5,  # original note: "center 0.1-0.5" (confirm meaning)
        box_noise_scale=1.5,  # width/height noise (original note: "wh noise 1")
        # Inactive alternatives:
        # group_cfg=dict(dynamic=False, num_groups=30, num_dn_queries=200),
        # group_cfg=dict(dynamic=False, num_groups=10, num_dn_queries=None),
        group_cfg=dict(dynamic=True, num_groups=None,
                       num_dn_queries=200),
    ),  # TODO: half num_dn_queries
    # training and testing settings
    train_cfg=dict(
        assigner=dict(
            type='DynamicIOUHungarianAssigner',
            match_costs=[
                dict(type='FocalLossCost', weight=2.0),
                dict(type='BBoxL1Cost', weight=5.0, box_format='xywh'),
                dict(type='IoUCost', iou_mode='giou', weight=2.0),
                dict(type='IoULossCost', iou_mode='iou', weight=1.0),
            ],
            match_num=10,  # 1 and 5 noted as alternatives
            base_match_num=1,
            iou_loss_th=0.05,
            dynamic_match=True)),
    test_cfg=dict(max_per_img=300))  # 100 for DeformDETR
# ---------------------------------------------------------------------------
# Optimizer: AdamW with gradient-norm clipping (max_norm=0.1, L2 norm).
# The 'backbone' custom key gets lr_mult=0.1, i.e. a reduced learning rate.
# ---------------------------------------------------------------------------
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='AdamW', lr=0.0001, weight_decay=0.0001),  # DeformDETR: lr=0.0002
    clip_grad=dict(max_norm=0.1, norm_type=2),
    # DeformDETR additionally lists sampling_offsets and reference_points
    # under custom_keys.
    paramwise_cfg=dict(custom_keys={'backbone': dict(lr_mult=0.1)}),
)  # noqa
# ---------------------------------------------------------------------------
# Learning policy
# ---------------------------------------------------------------------------
max_epochs = 100

# Epoch-based training loop that runs validation every 20 epochs.
train_cfg = dict(
    type='EpochBasedTrainLoop',
    max_epochs=max_epochs,
    val_interval=20,
)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

# One step decay over [begin=0, end=max_epochs]: LR is scaled by gamma=0.1
# at epoch 90.
# NOTE: the milestone is a fixed epoch number, not derived from max_epochs;
# revisit it whenever the schedule length changes.
param_scheduler = [
    dict(type='MultiStepLR', begin=0, end=max_epochs, by_epoch=True,
         milestones=[90], gamma=0.1),
]
# NOTE: `auto_scale_lr` is for automatically scaling LR,
# USER SHOULD NOT CHANGE ITS VALUES.
# base_batch_size is presumably the total batch size the learning rate above
# was tuned for — confirm before relying on automatic LR scaling.
auto_scale_lr = dict(base_batch_size=4)