| @@ -20,6 +20,8 @@ from mindspore._c_expression import AutoParallelContext | |||
| from mindspore._checkparam import args_type_check | |||
| _MAX_GROUP_NAME_LEN = 127 | |||
| _DEFAULT_HCCL_FUSION_GROUP_NAME = "hccl_world_groupsum1" | |||
| _DEFAULT_NCCL_FUSION_GROUP_NAME = "nccl_world_groupsum1" | |||
| class _AutoParallelContext: | |||
| @@ -267,7 +269,7 @@ class _AutoParallelContext: | |||
| self.check_context_handle() | |||
| return self._context_handle.get_parameter_broadcast_is_set() | |||
| def set_all_reduce_fusion_split_indices(self, indices, group="hccl_world_groupsum1"): | |||
| def set_all_reduce_fusion_split_indices(self, indices, group=""): | |||
| """ | |||
| Set allreduce fusion strategy by parameters indices. | |||
| @@ -294,11 +296,17 @@ class _AutoParallelContext: | |||
| else: | |||
| raise TypeError('Group must be a python str') | |||
| if group == "": | |||
| if context.get_context("device_target") == "Ascend": | |||
| group = _DEFAULT_HCCL_FUSION_GROUP_NAME | |||
| else: | |||
| group = _DEFAULT_NCCL_FUSION_GROUP_NAME | |||
| self._context_handle.set_all_reduce_fusion_split_indices(indices, group) | |||
| if context.get_context("device_target") == "Ascend": | |||
| _set_fusion_strategy_by_idx(indices) | |||
| def get_all_reduce_fusion_split_indices(self, group="hccl_world_groupsum1"): | |||
| def get_all_reduce_fusion_split_indices(self, group=""): | |||
| """ | |||
| Get allreduce fusion split indices. | |||
| @@ -318,9 +326,15 @@ class _AutoParallelContext: | |||
| raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}') | |||
| else: | |||
| raise TypeError('Group must be a python str') | |||
| if group == "": | |||
| if context.get_context("device_target") == "Ascend": | |||
| group = _DEFAULT_HCCL_FUSION_GROUP_NAME | |||
| else: | |||
| group = _DEFAULT_NCCL_FUSION_GROUP_NAME | |||
| return self._context_handle.get_all_reduce_fusion_split_indices(group) | |||
| def set_all_reduce_fusion_split_sizes(self, sizes, group="hccl_world_groupsum1"): | |||
| def set_all_reduce_fusion_split_sizes(self, sizes, group=""): | |||
| """ | |||
| Set allreduce fusion strategy by parameters data sizes. | |||
| @@ -347,11 +361,17 @@ class _AutoParallelContext: | |||
| else: | |||
| raise TypeError('Group must be a python str') | |||
| if group == "": | |||
| if context.get_context("device_target") == "Ascend": | |||
| group = _DEFAULT_HCCL_FUSION_GROUP_NAME | |||
| else: | |||
| group = _DEFAULT_NCCL_FUSION_GROUP_NAME | |||
| self._context_handle.set_all_reduce_fusion_split_sizes(sizes, group) | |||
| if context.get_context("device_target") == "Ascend": | |||
| _set_fusion_strategy_by_size(sizes) | |||
| def get_all_reduce_fusion_split_sizes(self, group="hccl_world_groupsum1"): | |||
| def get_all_reduce_fusion_split_sizes(self, group=""): | |||
| """ | |||
| Get allreduce fusion split sizes. | |||
| @@ -371,6 +391,12 @@ class _AutoParallelContext: | |||
| raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}') | |||
| else: | |||
| raise TypeError('Group must be a python str') | |||
| if group == "": | |||
| if context.get_context("device_target") == "Ascend": | |||
| group = _DEFAULT_HCCL_FUSION_GROUP_NAME | |||
| else: | |||
| group = _DEFAULT_NCCL_FUSION_GROUP_NAME | |||
| return self._context_handle.get_all_reduce_fusion_split_sizes(group) | |||
| def set_enable_all_reduce_fusion(self, enable_all_reduce_fusion): | |||