| @@ -20,6 +20,8 @@ from mindspore._c_expression import AutoParallelContext | |||||
| from mindspore._checkparam import args_type_check | from mindspore._checkparam import args_type_check | ||||
| _MAX_GROUP_NAME_LEN = 127 | _MAX_GROUP_NAME_LEN = 127 | ||||
| _DEFAULT_HCCL_FUSION_GROUP_NAME = "hccl_world_groupsum1" | |||||
| _DEFAULT_NCCL_FUSION_GROUP_NAME = "nccl_world_groupsum1" | |||||
| class _AutoParallelContext: | class _AutoParallelContext: | ||||
| @@ -267,7 +269,7 @@ class _AutoParallelContext: | |||||
| self.check_context_handle() | self.check_context_handle() | ||||
| return self._context_handle.get_parameter_broadcast_is_set() | return self._context_handle.get_parameter_broadcast_is_set() | ||||
| def set_all_reduce_fusion_split_indices(self, indices, group="hccl_world_groupsum1"): | |||||
| def set_all_reduce_fusion_split_indices(self, indices, group=""): | |||||
| """ | """ | ||||
| Set allreduce fusion strategy by parameters indices. | Set allreduce fusion strategy by parameters indices. | ||||
| @@ -294,11 +296,17 @@ class _AutoParallelContext: | |||||
| else: | else: | ||||
| raise TypeError('Group must be a python str') | raise TypeError('Group must be a python str') | ||||
| if group == "": | |||||
| if context.get_context("device_target") == "Ascend": | |||||
| group = _DEFAULT_HCCL_FUSION_GROUP_NAME | |||||
| else: | |||||
| group = _DEFAULT_NCCL_FUSION_GROUP_NAME | |||||
| self._context_handle.set_all_reduce_fusion_split_indices(indices, group) | self._context_handle.set_all_reduce_fusion_split_indices(indices, group) | ||||
| if context.get_context("device_target") == "Ascend": | if context.get_context("device_target") == "Ascend": | ||||
| _set_fusion_strategy_by_idx(indices) | _set_fusion_strategy_by_idx(indices) | ||||
| def get_all_reduce_fusion_split_indices(self, group="hccl_world_groupsum1"): | |||||
| def get_all_reduce_fusion_split_indices(self, group=""): | |||||
| """ | """ | ||||
| Get allreduce fusion split indices. | Get allreduce fusion split indices. | ||||
| @@ -318,9 +326,15 @@ class _AutoParallelContext: | |||||
| raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}') | raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}') | ||||
| else: | else: | ||||
| raise TypeError('Group must be a python str') | raise TypeError('Group must be a python str') | ||||
| if group == "": | |||||
| if context.get_context("device_target") == "Ascend": | |||||
| group = _DEFAULT_HCCL_FUSION_GROUP_NAME | |||||
| else: | |||||
| group = _DEFAULT_NCCL_FUSION_GROUP_NAME | |||||
| return self._context_handle.get_all_reduce_fusion_split_indices(group) | return self._context_handle.get_all_reduce_fusion_split_indices(group) | ||||
| def set_all_reduce_fusion_split_sizes(self, sizes, group="hccl_world_groupsum1"): | |||||
| def set_all_reduce_fusion_split_sizes(self, sizes, group=""): | |||||
| """ | """ | ||||
| Set allreduce fusion strategy by parameters data sizes. | Set allreduce fusion strategy by parameters data sizes. | ||||
| @@ -347,11 +361,17 @@ class _AutoParallelContext: | |||||
| else: | else: | ||||
| raise TypeError('Group must be a python str') | raise TypeError('Group must be a python str') | ||||
| if group == "": | |||||
| if context.get_context("device_target") == "Ascend": | |||||
| group = _DEFAULT_HCCL_FUSION_GROUP_NAME | |||||
| else: | |||||
| group = _DEFAULT_NCCL_FUSION_GROUP_NAME | |||||
| self._context_handle.set_all_reduce_fusion_split_sizes(sizes, group) | self._context_handle.set_all_reduce_fusion_split_sizes(sizes, group) | ||||
| if context.get_context("device_target") == "Ascend": | if context.get_context("device_target") == "Ascend": | ||||
| _set_fusion_strategy_by_size(sizes) | _set_fusion_strategy_by_size(sizes) | ||||
| def get_all_reduce_fusion_split_sizes(self, group="hccl_world_groupsum1"): | |||||
| def get_all_reduce_fusion_split_sizes(self, group=""): | |||||
| """ | """ | ||||
| Get allreduce fusion split sizes. | Get allreduce fusion split sizes. | ||||
| @@ -371,6 +391,12 @@ class _AutoParallelContext: | |||||
| raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}') | raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}') | ||||
| else: | else: | ||||
| raise TypeError('Group must be a python str') | raise TypeError('Group must be a python str') | ||||
| if group == "": | |||||
| if context.get_context("device_target") == "Ascend": | |||||
| group = _DEFAULT_HCCL_FUSION_GROUP_NAME | |||||
| else: | |||||
| group = _DEFAULT_NCCL_FUSION_GROUP_NAME | |||||
| return self._context_handle.get_all_reduce_fusion_split_sizes(group) | return self._context_handle.get_all_reduce_fusion_split_sizes(group) | ||||
| def set_enable_all_reduce_fusion(self, enable_all_reduce_fusion): | def set_enable_all_reduce_fusion(self, enable_all_reduce_fusion): | ||||