Merge pull request !6153 from lichen/add_checks_for_init
@@ -62,8 +62,8 @@ def init(backend_name=None):
     """
     if _is_role_pserver() or _is_role_sched():
         return
+    device_target = context.get_context("device_target")
     if backend_name is None:
-        device_target = context.get_context("device_target")
         if device_target == "Ascend":
             backend_name = "hccl"
         elif device_target == "GPU":
@@ -74,6 +74,8 @@ def init(backend_name=None):
         raise TypeError("Backend name must be a string, but got {}".format(type(backend_name)))
     if backend_name == "hccl":
+        if device_target != "Ascend":
+            raise RuntimeError("Device target should be 'Ascend' to init hccl, but got {}".format(device_target))
         init_hccl()
         GlobalComm.BACKEND = Backend("hccl")
         GlobalComm.WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
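With this check in place, requesting the hccl backend on a non-Ascend device target fails fast with a descriptive RuntimeError instead of failing inside HCCL initialization. A minimal usage sketch, assuming an Ascend environment with the distributed runtime already configured (e.g. a rank table and device id):

from mindspore import context
from mindspore.communication.management import init

context.set_context(device_target="Ascend")  # must be "Ascend" when requesting hccl
init("hccl")  # would raise RuntimeError if device_target were, say, "GPU"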
@@ -380,7 +380,8 @@ def set_auto_parallel_context(**kwargs):
         full_batch (bool): Whether to load the whole batch on each device. Default: False.
         enable_parallel_optimizer (bool): This is a developing feature, which shards the weight update computation in
             data parallel training in the benefit of time and memory saving.
-        all_reduce_fusion_config (list): Set allreduce fusion strategy by parameters indices.
+        all_reduce_fusion_config (list): Set allreduce fusion strategy by parameters indices. Only support ReduceOp.SUM
+            and HCCL_WORLD_GROUP/NCCL_WORLD_GROUP.
 
     Raises:
         ValueError: If input key is not attribute in auto parallel context.
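The extended docstring narrows the documented scope of all_reduce_fusion_config: fusion is only applied to ReduceOp.SUM allreduce over the default world communication group. A hedged usage sketch; the split indices are illustrative values, not a recommendation:

from mindspore import context

# Split gradient allreduce into fusion groups at parameter indices 20 and 35.
# Per the updated docstring, fusion only covers ReduceOp.SUM over
# HCCL_WORLD_GROUP/NCCL_WORLD_GROUP.
context.set_auto_parallel_context(all_reduce_fusion_config=[20, 35])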
@@ -33,6 +33,7 @@ from mindspore.ops.operations.comm_ops import Broadcast
 tag = 0
+context.set_context(device_target="Ascend")
 init("hccl")