Browse Source

!2886 only save ckpt in rank0 for Transformer

Merge pull request !2886 from yuchaojie/transformer
tags/v0.6.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
62a75eddea
1 changed file with 5 additions and 4 deletions
  1. +5
    -4
      model_zoo/Transformer/train.py

+ 5
- 4
model_zoo/Transformer/train.py View File

@@ -147,10 +147,11 @@ def run_transformer_train():

callbacks = [TimeMonitor(dataset.get_dataset_size()), LossCallBack()]
if args.enable_save_ckpt == "true":
ckpt_config = CheckpointConfig(save_checkpoint_steps=args.save_checkpoint_steps,
keep_checkpoint_max=args.save_checkpoint_num)
ckpoint_cb = ModelCheckpoint(prefix='transformer', directory=args.save_checkpoint_path, config=ckpt_config)
callbacks.append(ckpoint_cb)
if device_num == 1 or (device_num > 1 and rank_id == 0):
ckpt_config = CheckpointConfig(save_checkpoint_steps=args.save_checkpoint_steps,
keep_checkpoint_max=args.save_checkpoint_num)
ckpoint_cb = ModelCheckpoint(prefix='transformer', directory=args.save_checkpoint_path, config=ckpt_config)
callbacks.append(ckpoint_cb)

if args.enable_lossscale == "true":
scale_manager = DynamicLossScaleManager(init_loss_scale=cfg.init_loss_scale_value,


Loading…
Cancel
Save