|
|
|
@@ -274,10 +274,7 @@ class _AutoParallelContext: |
|
|
|
|
|
|
|
self._context_handle.set_all_reduce_fusion_split_indices(indices, group) |
|
|
|
if context.get_context("device_target") == "Ascend": |
|
|
|
if group == "": |
|
|
|
_set_fusion_strategy_by_idx(indices) |
|
|
|
else: |
|
|
|
_set_fusion_strategy_by_idx(indices, group) |
|
|
|
_set_fusion_strategy_by_idx(indices) |
|
|
|
|
|
|
|
def get_all_reduce_fusion_split_indices(self, group="hccl_world_groupsum1"): |
|
|
|
""" |
|
|
|
@@ -330,10 +327,7 @@ class _AutoParallelContext: |
|
|
|
|
|
|
|
self._context_handle.set_all_reduce_fusion_split_sizes(sizes, group) |
|
|
|
if context.get_context("device_target") == "Ascend": |
|
|
|
if group == "": |
|
|
|
_set_fusion_strategy_by_size(sizes) |
|
|
|
else: |
|
|
|
_set_fusion_strategy_by_size(sizes, group) |
|
|
|
_set_fusion_strategy_by_size(sizes) |
|
|
|
|
|
|
|
def get_all_reduce_fusion_split_sizes(self, group="hccl_world_groupsum1"): |
|
|
|
""" |
|
|
|
|