| @@ -245,7 +245,7 @@ class _AutoParallelContext: | |||||
| self.check_context_handle() | self.check_context_handle() | ||||
| return self._context_handle.get_parameter_broadcast_is_set() | return self._context_handle.get_parameter_broadcast_is_set() | ||||
| def set_all_reduce_fusion_split_indices(self, indices, group=""): | |||||
| def set_all_reduce_fusion_split_indices(self, indices, group="hccl_world_groupsum1"): | |||||
| """ | """ | ||||
| Set allreduce fusion strategy by parameters indices. | Set allreduce fusion strategy by parameters indices. | ||||
| @@ -279,7 +279,7 @@ class _AutoParallelContext: | |||||
| else: | else: | ||||
| _set_fusion_strategy_by_idx(indices, group) | _set_fusion_strategy_by_idx(indices, group) | ||||
| def get_all_reduce_fusion_split_indices(self, group=""): | |||||
| def get_all_reduce_fusion_split_indices(self, group="hccl_world_groupsum1"): | |||||
| """ | """ | ||||
| Get allreduce fusion split indices. | Get allreduce fusion split indices. | ||||
| @@ -301,7 +301,7 @@ class _AutoParallelContext: | |||||
| raise TypeError('Group must be a python str') | raise TypeError('Group must be a python str') | ||||
| return self._context_handle.get_all_reduce_fusion_split_indices(group) | return self._context_handle.get_all_reduce_fusion_split_indices(group) | ||||
| def set_all_reduce_fusion_split_sizes(self, sizes, group=""): | |||||
| def set_all_reduce_fusion_split_sizes(self, sizes, group="hccl_world_groupsum1"): | |||||
| """ | """ | ||||
| Set allreduce fusion strategy by parameters data sizes. | Set allreduce fusion strategy by parameters data sizes. | ||||
| @@ -335,7 +335,7 @@ class _AutoParallelContext: | |||||
| else: | else: | ||||
| _set_fusion_strategy_by_size(sizes, group) | _set_fusion_strategy_by_size(sizes, group) | ||||
| def get_all_reduce_fusion_split_sizes(self, group=""): | |||||
| def get_all_reduce_fusion_split_sizes(self, group="hccl_world_groupsum1"): | |||||
| """ | """ | ||||
| Get allreduce fusion split sizes. | Get allreduce fusion split sizes. | ||||