Browse Source

!594 Fix wizard template module to fit new 'set_auto_parallel_context' API

Merge pull request !594 from moran/wizard_dev
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
81250b1650
3 changed files with 5 additions and 5 deletions
  1. +2
    -2
      mindinsight/wizard/conf/templates/network/alexnet/train.py-tpl
  2. +1
    -1
      mindinsight/wizard/conf/templates/network/lenet/train.py-tpl
  3. +2
    -2
      mindinsight/wizard/conf/templates/network/resnet50/train.py-tpl

+ 2
- 2
mindinsight/wizard/conf/templates/network/alexnet/train.py-tpl View File

@@ -60,14 +60,14 @@ if __name__ == "__main__":
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(device_id=device_id, enable_auto_mixed_precision=True)
context.set_auto_parallel_context(device_num=args.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
- mirror_mean=True)
+ gradients_mean=True)

init()
# GPU target
else:
init("nccl")
context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL,
- mirror_mean=True)
+ gradients_mean=True)
ckpt_save_dir = cfg.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/"




+ 1
- 1
mindinsight/wizard/conf/templates/network/lenet/train.py-tpl View File

@@ -60,7 +60,7 @@ if __name__ == "__main__":
raise ValueError('Distribute running is no supported on %s' % args.device_target)
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=args.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
- mirror_mean=True)
+ gradients_mean=True)

data_path = args.dataset_path
do_train = True


+ 2
- 2
mindinsight/wizard/conf/templates/network/resnet50/train.py-tpl View File

@@ -65,7 +65,7 @@ if __name__ == '__main__':
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(device_id=device_id, enable_auto_mixed_precision=True)
context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
- mirror_mean=True)
+ gradients_mean=True)
auto_parallel_context().set_all_reduce_fusion_split_indices([107, 160])

init()
@@ -73,7 +73,7 @@ if __name__ == '__main__':
else:
init("nccl")
context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL,
- mirror_mean=True)
+ gradients_mean=True)
ckpt_save_dir = cfg.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/"

# create dataset


Loading…
Cancel
Save