@@ -92,8 +92,10 @@ def run_predistill():
dataset_size = dataset.get_dataset_size()
dataset_size = dataset.get_dataset_size()
if args_opt.enable_data_sink == 'true':
if args_opt.enable_data_sink == 'true':
repeat_count = args_opt.td_phase1_epoch_size * dataset.get_dataset_size() // args_opt.data_sink_steps
repeat_count = args_opt.td_phase1_epoch_size * dataset.get_dataset_size() // args_opt.data_sink_steps
time_monitor_steps = args_opt.data_sink_steps
else:
else:
repeat_count = args_opt.td_phase1_epoch_size
repeat_count = args_opt.td_phase1_epoch_size
time_monitor_steps = dataset_size
optimizer_cfg = cfg.optimizer_cfg
optimizer_cfg = cfg.optimizer_cfg
@@ -110,10 +112,10 @@ def run_predistill():
{'order_params': params}]
{'order_params': params}]
optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=optimizer_cfg.AdamWeightDecay.eps)
optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=optimizer_cfg.AdamWeightDecay.eps)
callback = [TimeMonitor(dataset_size ), LossCallBack(), ModelSaveCkpt(netwithloss.bert,
args_opt.save_ckpt_step,
args_opt.max_ckpt_num,
td_phase1_save_ckpt_dir)]
callback = [TimeMonitor(time_monitor_steps ), LossCallBack(), ModelSaveCkpt(netwithloss.bert,
args_opt.save_ckpt_step,
args_opt.max_ckpt_num,
td_phase1_save_ckpt_dir)]
update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
scale_factor=cfg.scale_factor,
scale_factor=cfg.scale_factor,
scale_window=cfg.scale_window)
scale_window=cfg.scale_window)
@@ -147,8 +149,10 @@ def run_task_distill(ckpt_file):
dataset_size = train_dataset.get_dataset_size()
dataset_size = train_dataset.get_dataset_size()
if args_opt.enable_data_sink == 'true':
if args_opt.enable_data_sink == 'true':
repeat_count = args_opt.td_phase2_epoch_size * train_dataset.get_dataset_size() // args_opt.data_sink_steps
repeat_count = args_opt.td_phase2_epoch_size * train_dataset.get_dataset_size() // args_opt.data_sink_steps
time_monitor_steps = args_opt.data_sink_steps
else:
else:
repeat_count = args_opt.td_phase2_epoch_size
repeat_count = args_opt.td_phase2_epoch_size
time_monitor_steps = dataset_size
optimizer_cfg = cfg.optimizer_cfg
optimizer_cfg = cfg.optimizer_cfg
@@ -170,14 +174,14 @@ def run_task_distill(ckpt_file):
device_num, rank, args_opt.do_shuffle,
device_num, rank, args_opt.do_shuffle,
args_opt.eval_data_dir, args_opt.schema_dir)
args_opt.eval_data_dir, args_opt.schema_dir)
if args_opt.do_eval.lower() == "true":
if args_opt.do_eval.lower() == "true":
callback = [TimeMonitor(dataset_size ), LossCallBack(),
callback = [TimeMonitor(time_monitor_steps ), LossCallBack(),
ModelSaveCkpt(netwithloss.bert,
ModelSaveCkpt(netwithloss.bert,
args_opt.save_ckpt_step,
args_opt.save_ckpt_step,
args_opt.max_ckpt_num,
args_opt.max_ckpt_num,
td_phase2_save_ckpt_dir),
td_phase2_save_ckpt_dir),
EvalCallBack(netwithloss.bert, eval_dataset)]
EvalCallBack(netwithloss.bert, eval_dataset)]
else:
else:
callback = [TimeMonitor(dataset_size ), LossCallBack(),
callback = [TimeMonitor(time_monitor_steps ), LossCallBack(),
ModelSaveCkpt(netwithloss.bert,
ModelSaveCkpt(netwithloss.bert,
args_opt.save_ckpt_step,
args_opt.save_ckpt_step,
args_opt.max_ckpt_num,
args_opt.max_ckpt_num,