File size: 7,731 Bytes
b410583 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 |
#!/usr/bin/env python
import os
import argparse
def get_cmd(task, sub_task, model_tag, gpu, data_num, bs, lr, source_length, target_length, patience, epoch, warmup,
            model_dir, summary_dir, res_fn, max_steps=None, save_steps=None, log_steps=None):
    """Assemble the shell command that launches exp_with_args.sh.

    String fields are passed through as-is; numeric fields are rendered with
    '%d'. When max_steps is given (multi-task training), the three step
    counters are appended as trailing positional arguments.
    """
    fields = ['bash exp_with_args.sh', task, sub_task, model_tag]
    fields += ['%d' % v for v in (gpu, data_num, bs, lr, source_length,
                                  target_length, patience, epoch, warmup)]
    fields += [model_dir, summary_dir, res_fn]
    if max_steps is not None:
        # Multi-task runs pass step-based limits instead of epoch-based ones.
        fields += ['%d' % v for v in (max_steps, save_steps, log_steps)]
    return ' '.join(fields)
def get_args_by_task_model(task, sub_task, model_tag):
    """Return per-task hyperparameters as (bs, lr, src_len, trg_len, patience, epoch).

    Values are hand-tuned per task (lengths chosen from the tokenized length
    statistics recorded in the comments below). Batch size additionally
    depends on the model size, and learning rate (in units of 1e-5,
    presumably — confirm against exp_with_args.sh) on the task.

    Raises:
        ValueError: if `task` (or, for 'refine', `sub_task`) is not recognized.
            The original code would instead fail later with UnboundLocalError.
    """
    if task == 'translate':
        # java-cs: Read 10300 examples, avg src len: 13, avg trg len: 15, max src len: 136, max trg len: 118
        # [TOKENIZE] avg src len: 45, avg trg len: 56, max src len: 391, max trg len: 404
        src_len = 320
        trg_len = 256
        epoch = 100
        patience = 5
    elif task == 'summarize':
        # ruby: Read 24927 examples, avg src len: 66, avg trg len: 12, max src len: 501, max trg len: 146
        # [TOKENIZE] avg src len: 100, avg trg len: 13, max src len: 1250, max trg len: 161
        # Python: Read 251820 examples, avg src len: 100, avg trg len: 11, max src len: 512, max trg len: 222
        # [TOKENIZE] avg src len: 142, avg trg len: 12, max src len: 2016, max trg len: 245
        # Javascript: Read 58025 examples, avg src len: 114, avg trg len: 11, max src len: 512, max trg len: 165
        # [TOKENIZE] avg src len: 136, avg trg len: 12, max src len: 3016, max trg len: 177
        src_len = 256
        trg_len = 128
        epoch = 15
        patience = 2
    elif task == 'refine':
        # small: Read 46680 examples, avg src len: 31, avg trg len: 28, max src len: 50, max trg len: 50
        # [TOKENIZE] avg src len: 50, avg trg len: 45, max src len: 129, max trg len: 121
        # medium: Read 52364 examples, avg src len: 74, avg trg len: 73, max src len: 100, max trg len: 100
        # [TOKENIZE] avg src len: 117, avg trg len: 114, max src len: 238, max trg len: 238
        if sub_task == 'small':
            src_len = 130
            trg_len = 120
        elif sub_task == 'medium':
            src_len = 240
            trg_len = 240
        else:
            raise ValueError("Unknown sub_task for 'refine': %s" % sub_task)
        epoch = 50
        patience = 5
    elif task == 'concode':
        # Read 100000 examples, avg src len: 71, avg trg len: 26, max src len: 567, max trg len: 140
        # [TOKENIZE] avg src len: 213, avg trg len: 33, max src len: 2246, max trg len: 264
        src_len = 320
        trg_len = 150
        epoch = 30
        patience = 3
    elif task == 'defect':
        # Read 21854 examples, avg src len: 187, avg trg len: 1, max src len: 12195, max trg len: 1
        # [TOKENIZE] avg src len: 597, avg trg len: 1, max src len: 41447, max trg len: 1
        src_len = 512
        trg_len = 3
        epoch = 10
        patience = 2
    elif task == 'clone':
        # Read 901028 examples, avg src len: 120, avg trg len: 123, max src len: 5270, max trg len: 5270
        # [TOKENIZE] avg src len: 318, avg trg len: 323, max src len: 15111, max trg len: 15111
        src_len = 400
        trg_len = 400
        epoch = 1
        patience = 2
    else:
        raise ValueError('Unknown task: %s' % task)

    # Batch size: primarily a function of model size, with task-specific
    # overrides for the memory-heavy tasks (long sequences / clone pairs).
    if 'codet5_small' in model_tag:
        bs = 32
        if task == 'summarize' or task == 'translate' or (task == 'refine' and sub_task == 'small'):
            bs = 64
        elif task == 'clone':
            bs = 25
    elif 'codet5_large' in model_tag:
        bs = 8
    else:
        bs = 32
        if task == 'translate':
            bs = 25
        elif task == 'summarize':
            bs = 48
        elif task == 'clone':
            if model_tag in ['codebert', 'roberta']:
                bs = 16
            else:
                bs = 10

    # Learning rate, with task-specific overrides.
    lr = 5
    if task == 'concode':
        lr = 10
    elif task == 'defect':
        lr = 2
    return bs, lr, src_len, trg_len, patience, epoch
def run_one_exp(args):
    """Launch a single fine-tuning run for (args.task, args.sub_task)."""
    hyper = get_args_by_task_model(args.task, args.sub_task, args.model_tag)
    bs, lr, src_len, trg_len, patience, epoch = hyper
    result_file = '{}/{}_{}.txt'.format(args.res_dir, args.task, args.model_tag)
    print('============================Start Running==========================')
    cmd_str = get_cmd(task=args.task, sub_task=args.sub_task, model_tag=args.model_tag, gpu=args.gpu,
                      data_num=args.data_num, bs=bs, lr=lr, source_length=src_len, target_length=trg_len,
                      patience=patience, epoch=epoch, warmup=1000,
                      model_dir=args.model_dir, summary_dir=args.summary_dir,
                      res_fn=result_file)
    print('%s\n' % cmd_str)
    os.system(cmd_str)
def run_multi_task_exp(args):
    """Launch a multi-task training run covering all five tasks jointly."""
    # Total train data num = 1149722 (for all five tasks)
    is_small = 'codet5_small' in args.model_tag
    bs = 60 if is_small else 25
    lr = 5
    max_steps = 600000 if is_small else 800000
    save_steps, log_steps = 20000, 100
    if args.data_num != -1:
        # Subsampled data: shrink the schedule for a quick debug run.
        max_steps, save_steps, log_steps = 1000, 200, 50
    print('============================Start Running==========================')
    cmd_str = get_cmd(task='multi_task', sub_task='none', model_tag=args.model_tag, gpu=args.gpu,
                      data_num=args.data_num, bs=bs, lr=lr, source_length=-1, target_length=-1,
                      patience=-1, epoch=-1, warmup=1000,
                      model_dir=args.model_dir, summary_dir=args.summary_dir,
                      res_fn='{}/multi_task_{}.txt'.format(args.res_dir, args.model_tag),
                      max_steps=max_steps, save_steps=save_steps, log_steps=log_steps)
    print('%s\n' % cmd_str)
    os.system(cmd_str)
def get_sub_tasks(task):
    """Return the list of valid sub_task names for `task`.

    Tasks with no meaningful sub-task division map to ['none'].

    Raises:
        ValueError: if `task` is not recognized (the original if/elif chain
            had no else branch and would fail with UnboundLocalError).
    """
    sub_tasks_by_task = {
        'summarize': ['ruby', 'javascript', 'go', 'python', 'java', 'php'],
        'translate': ['java-cs', 'cs-java'],
        'refine': ['small', 'medium'],
        'concode': ['none'],
        'defect': ['none'],
        'clone': ['none'],
        'multi_task': ['none'],
    }
    try:
        return sub_tasks_by_task[task]
    except KeyError:
        raise ValueError('Unknown task: %s' % task)
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_tag", type=str, default='codet5_base',
                        choices=['roberta', 'codebert', 'bart_base', 'codet5_small', 'codet5_base', 'codet5_large'])
    parser.add_argument("--task", type=str, default='summarize', choices=['summarize', 'concode', 'translate',
                                                                          'refine', 'defect', 'clone', 'multi_task'])
    parser.add_argument("--sub_task", type=str, default='ruby')
    parser.add_argument("--res_dir", type=str, default='results', help='directory to save fine-tuning results')
    parser.add_argument("--model_dir", type=str, default='saved_models', help='directory to save fine-tuned models')
    parser.add_argument("--summary_dir", type=str, default='tensorboard', help='directory to save tensorboard summary')
    parser.add_argument("--data_num", type=int, default=-1, help='number of data instances to use, -1 for full data')
    parser.add_argument("--gpu", type=int, default=0, help='index of the gpu to use in a cluster')
    args = parser.parse_args()
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(args.res_dir, exist_ok=True)
    # Validate with parser.error rather than assert: asserts are stripped
    # under `python -O`, and this gives the user a proper usage message.
    if args.sub_task not in get_sub_tasks(args.task):
        parser.error('invalid --sub_task %r for --task %r (choose from %s)'
                     % (args.sub_task, args.task, get_sub_tasks(args.task)))
    if args.task != 'multi_task':
        run_one_exp(args)
    else:
        run_multi_task_exp(args)
|