|
|
|
@@ -147,10 +147,11 @@ def run_transformer_train(): |
|
|
|
|
|
|
|
callbacks = [TimeMonitor(dataset.get_dataset_size()), LossCallBack()] |
|
|
|
if args.enable_save_ckpt == "true": |
|
|
|
ckpt_config = CheckpointConfig(save_checkpoint_steps=args.save_checkpoint_steps, |
|
|
|
keep_checkpoint_max=args.save_checkpoint_num) |
|
|
|
ckpoint_cb = ModelCheckpoint(prefix='transformer', directory=args.save_checkpoint_path, config=ckpt_config) |
|
|
|
callbacks.append(ckpoint_cb) |
|
|
|
if device_num == 1 or (device_num > 1 and rank_id == 0): |
|
|
|
ckpt_config = CheckpointConfig(save_checkpoint_steps=args.save_checkpoint_steps, |
|
|
|
keep_checkpoint_max=args.save_checkpoint_num) |
|
|
|
ckpoint_cb = ModelCheckpoint(prefix='transformer', directory=args.save_checkpoint_path, config=ckpt_config) |
|
|
|
callbacks.append(ckpoint_cb) |
|
|
|
|
|
|
|
if args.enable_lossscale == "true": |
|
|
|
scale_manager = DynamicLossScaleManager(init_loss_scale=cfg.init_loss_scale_value, |
|
|
|
|