
!6153 [AutoParallel]Check device target in init

Merge pull request !6153 from lichen/add_checks_for_init
Tag: v1.0.0
Author: mindspore-ci-bot (Gitee), 5 years ago
Commit: 2b05208a6f
3 changed files with 6 additions and 2 deletions:
  1. mindspore/communication/management.py  +3 -1
  2. mindspore/context.py  +2 -1
  3. tests/ut/python/communication/test_comm.py  +1 -0

mindspore/communication/management.py (+3 -1)

@@ -62,8 +62,8 @@ def init(backend_name=None):
     """
     if _is_role_pserver() or _is_role_sched():
         return
+    device_target = context.get_context("device_target")
     if backend_name is None:
-        device_target = context.get_context("device_target")
         if device_target == "Ascend":
             backend_name = "hccl"
         elif device_target == "GPU":
@@ -74,6 +74,8 @@ def init(backend_name=None):
         raise TypeError("Backend name must be a string, but got {}".format(type(backend_name)))

     if backend_name == "hccl":
+        if device_target != "Ascend":
+            raise RuntimeError("Device target should be 'Ascend' to init hccl, but got {}".format(device_target))
         init_hccl()
         GlobalComm.BACKEND = Backend("hccl")
         GlobalComm.WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
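
With this hunk, init("hccl") fails fast when the configured device target is not Ascend, instead of attempting an HCCL initialization that can only fail later. A minimal sketch of the new behaviour, assuming a GPU-configured context; the set_context call and the try/except are illustrative and not part of this patch:

    from mindspore import context
    from mindspore.communication.management import init

    # Hypothetical misconfiguration: HCCL requested while the device target is GPU.
    context.set_context(device_target="GPU")
    try:
        init("hccl")
    except RuntimeError as err:
        # Expected with this patch:
        # "Device target should be 'Ascend' to init hccl, but got GPU"
        print(err)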


mindspore/context.py (+2 -1)

@@ -380,7 +380,8 @@ def set_auto_parallel_context(**kwargs):
         full_batch (bool): Whether to load the whole batch on each device. Default: False.
         enable_parallel_optimizer (bool): This is a developing feature, which shards the weight update computation in
             data parallel training in the benefit of time and memory saving.
-        all_reduce_fusion_config (list): Set allreduce fusion strategy by parameters indices.
+        all_reduce_fusion_config (list): Set allreduce fusion strategy by parameters indices. Only support ReduceOp.SUM
+            and HCCL_WORLD_GROUP/NCCL_WORLD_GROUP.

     Raises:
         ValueError: If input key is not attribute in auto parallel context.
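
The amended docstring describes all_reduce_fusion_config as a list of parameter indices at which gradient allreduce is split into fusion groups. A minimal usage sketch, assuming data-parallel training; the split indices [20, 35] are illustrative:

    from mindspore import context

    # Illustrative split points: parameters [0, 20) and [20, 35) form two fused
    # allreduce groups; any remaining parameters form a third group.
    context.set_auto_parallel_context(parallel_mode="data_parallel",
                                      all_reduce_fusion_config=[20, 35])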


tests/ut/python/communication/test_comm.py (+1 -0)

@@ -33,6 +33,7 @@ from mindspore.ops.operations.comm_ops import Broadcast

 tag = 0

+context.set_context(device_target="Ascend")
 init("hccl")





