diff --git a/mindspore/parallel/_auto_parallel_context.py b/mindspore/parallel/_auto_parallel_context.py index 3f6ce21cb9..7fd366c2f0 100644 --- a/mindspore/parallel/_auto_parallel_context.py +++ b/mindspore/parallel/_auto_parallel_context.py @@ -20,6 +20,8 @@ from mindspore._c_expression import AutoParallelContext from mindspore._checkparam import args_type_check _MAX_GROUP_NAME_LEN = 127 +_DEFAULT_HCCL_FUSION_GROUP_NAME = "hccl_world_groupsum1" +_DEFAULT_NCCL_FUSION_GROUP_NAME = "nccl_world_groupsum1" class _AutoParallelContext: @@ -267,7 +269,7 @@ class _AutoParallelContext: self.check_context_handle() return self._context_handle.get_parameter_broadcast_is_set() - def set_all_reduce_fusion_split_indices(self, indices, group="hccl_world_groupsum1"): + def set_all_reduce_fusion_split_indices(self, indices, group=""): """ Set allreduce fusion strategy by parameters indices. @@ -294,11 +296,17 @@ class _AutoParallelContext: else: raise TypeError('Group must be a python str') + if group == "": + if context.get_context("device_target") == "Ascend": + group = _DEFAULT_HCCL_FUSION_GROUP_NAME + else: + group = _DEFAULT_NCCL_FUSION_GROUP_NAME + self._context_handle.set_all_reduce_fusion_split_indices(indices, group) if context.get_context("device_target") == "Ascend": _set_fusion_strategy_by_idx(indices) - def get_all_reduce_fusion_split_indices(self, group="hccl_world_groupsum1"): + def get_all_reduce_fusion_split_indices(self, group=""): """ Get allreduce fusion split indices. @@ -318,9 +326,15 @@ class _AutoParallelContext: raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}') else: raise TypeError('Group must be a python str') + + if group == "": + if context.get_context("device_target") == "Ascend": + group = _DEFAULT_HCCL_FUSION_GROUP_NAME + else: + group = _DEFAULT_NCCL_FUSION_GROUP_NAME return self._context_handle.get_all_reduce_fusion_split_indices(group) - def set_all_reduce_fusion_split_sizes(self, sizes, group="hccl_world_groupsum1"): + def set_all_reduce_fusion_split_sizes(self, sizes, group=""): """ Set allreduce fusion strategy by parameters data sizes. @@ -347,11 +361,17 @@ class _AutoParallelContext: else: raise TypeError('Group must be a python str') + if group == "": + if context.get_context("device_target") == "Ascend": + group = _DEFAULT_HCCL_FUSION_GROUP_NAME + else: + group = _DEFAULT_NCCL_FUSION_GROUP_NAME + self._context_handle.set_all_reduce_fusion_split_sizes(sizes, group) if context.get_context("device_target") == "Ascend": _set_fusion_strategy_by_size(sizes) - def get_all_reduce_fusion_split_sizes(self, group="hccl_world_groupsum1"): + def get_all_reduce_fusion_split_sizes(self, group=""): """ Get allreduce fusion split sizes. @@ -371,6 +391,12 @@ class _AutoParallelContext: raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}') else: raise TypeError('Group must be a python str') + + if group == "": + if context.get_context("device_target") == "Ascend": + group = _DEFAULT_HCCL_FUSION_GROUP_NAME + else: + group = _DEFAULT_NCCL_FUSION_GROUP_NAME return self._context_handle.get_all_reduce_fusion_split_sizes(group) def set_enable_all_reduce_fusion(self, enable_all_reduce_fusion):