|
|
@@ -294,6 +294,12 @@ class _AutoParallelContext: |
|
|
else: |
|
|
else: |
|
|
raise TypeError('indices must be a python list') |
|
|
raise TypeError('indices must be a python list') |
|
|
|
|
|
|
|
|
|
|
|
if len(set(indices)) != len(indices): |
|
|
|
|
|
raise ValueError('indices has duplicate elements') |
|
|
|
|
|
|
|
|
|
|
|
if sorted(indices) != indices: |
|
|
|
|
|
raise ValueError('elements in indices must be sorted in ascending order') |
|
|
|
|
|
|
|
|
if isinstance(group, (str)): |
|
|
if isinstance(group, (str)): |
|
|
group_len = len(group) |
|
|
group_len = len(group) |
|
|
if group_len > _MAX_GROUP_NAME_LEN: |
|
|
if group_len > _MAX_GROUP_NAME_LEN: |
|
|
@@ -308,7 +314,7 @@ class _AutoParallelContext: |
|
|
group = _DEFAULT_NCCL_FUSION_GROUP_NAME |
|
|
group = _DEFAULT_NCCL_FUSION_GROUP_NAME |
|
|
|
|
|
|
|
|
self._context_handle.set_all_reduce_fusion_split_indices(indices, group) |
|
|
self._context_handle.set_all_reduce_fusion_split_indices(indices, group) |
|
|
if context.get_context("device_target") == "Ascend": |
|
|
|
|
|
|
|
|
if context.get_context("device_target") == "Ascend" and context.get_context("enable_ge"): |
|
|
_set_fusion_strategy_by_idx(indices) |
|
|
_set_fusion_strategy_by_idx(indices) |
|
|
|
|
|
|
|
|
def get_all_reduce_fusion_split_indices(self, group=""): |
|
|
def get_all_reduce_fusion_split_indices(self, group=""): |
|
|
|