add allreduce group for resnet gpu version

tags/v0.7.0-beta
yuchaojie, 5 years ago
commit 64a1560f1a
3 changed files with 8 additions and 6 deletions:

1. mindspore/parallel/_auto_parallel_context.py (+4, -4)
2. model_zoo/official/cv/resnet/README.md (+1, -1)
3. model_zoo/official/cv/resnet/train.py (+3, -1)

mindspore/parallel/_auto_parallel_context.py (+4, -4)

@@ -275,7 +275,7 @@ class _AutoParallelContext:

         Args:
             indices (list): Indices list.
-            group (str): The hccl communication group.
+            group (str): The communication group of hccl/nccl.

         Raises:
             TypeError: If type of indices item is not int.
@@ -311,7 +311,7 @@ class _AutoParallelContext:
         Get allreduce fusion split indices.

         Args:
-            group (str): The hccl communication group.
+            group (str): The communication group of hccl/nccl.

         Returns:
             Return split sizes list according to the group.
@@ -340,7 +340,7 @@ class _AutoParallelContext:

         Args:
             sizes (list): Sizes list.
-            group (str): The hccl communication group.
+            group (str): The communication group of hccl/nccl.

         Raises:
             TypeError: If type of sizes item is not int.
@@ -376,7 +376,7 @@ class _AutoParallelContext:
         Get allreduce fusion split sizes.

         Args:
-            group (str): The hccl communication group.
+            group (str): The communication group of hccl/nccl.

         Returns:
             Return split sizes list according to the group.
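
The group argument updated in these four docstrings is what the commit title refers to: the same fusion APIs now cover both HCCL (Ascend) and NCCL (GPU) communication groups. Below is a minimal sketch, not part of this commit, of setting and reading split indices for an explicit group; the "nccl_world_group" name is an assumed default group name and should be verified for your MindSpore build.

# Minimal sketch (not from this commit): configuring allreduce fusion for an
# explicit communication group. "nccl_world_group" is an assumed default name.
from mindspore.parallel._auto_parallel_context import auto_parallel_context

ctx = auto_parallel_context()
# Fuse gradient allreduce into buckets split at parameter indices 85 and 160.
ctx.set_all_reduce_fusion_split_indices([85, 160], "nccl_world_group")
# Read the same group's configuration back.
print(ctx.get_all_reduce_fusion_split_indices("nccl_world_group"))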


model_zoo/official/cv/resnet/README.md (+1, -1)

@@ -44,7 +44,7 @@ ImageNet2012
 ├── run_distribute_train.sh # launch distributed training(8 pcs)
 ├── run_parameter_server_train.sh # launch Ascend parameter server training(8 pcs)
 ├── run_eval.sh # launch evaluation
-── run_standalone_train.sh # launch standalone training(1 pcs)
+── run_standalone_train.sh # launch standalone training(1 pcs)
 ├── run_distribute_train_gpu.sh # launch gpu distributed training(8 pcs)
 ├── run_parameter_server_train_gpu.sh # launch gpu parameter server training(8 pcs)
 ├── run_eval_gpu.sh # launch gpu evaluation


model_zoo/official/cv/resnet/train.py (+3, -1)

@@ -81,9 +81,11 @@ if __name__ == '__main__':
         init()
     # GPU target
     else:
-        init("nccl")
         context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL,
                                           mirror_mean=True)
+        if args_opt.net == "resnet50":
+            auto_parallel_context().set_all_reduce_fusion_split_indices([85, 160])
+        init("nccl")
         ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/"

     # create dataset
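
For intuition on the [85, 160] values added above: split indices partition the model's gradient tensors, in parameter order, into fusion buckets, and each bucket is allreduced as one fused tensor. A small self-contained illustration follows; the 161 gradient-tensor count for ResNet-50 is an assumption made only for this example.

# Self-contained illustration of what fusion split indices mean (the 161
# gradient-tensor count for ResNet-50 is an assumption for this example).
def fusion_buckets(num_grads, split_indices):
    """Return (start, end) ranges of gradients fused into one allreduce each."""
    bounds = [0] + list(split_indices) + [num_grads]
    return [(bounds[i], bounds[i + 1]) for i in range(len(bounds) - 1)]

print(fusion_buckets(161, [85, 160]))  # -> [(0, 85), (85, 160), (160, 161)]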

