|
|
@@ -245,7 +245,7 @@ class _AutoParallelContext: |
|
|
self.check_context_handle() |
|
|
self.check_context_handle() |
|
|
return self._context_handle.get_parameter_broadcast_is_set() |
|
|
return self._context_handle.get_parameter_broadcast_is_set() |
|
|
|
|
|
|
|
|
def set_all_reduce_fusion_split_indices(self, indices, group=""): |
|
|
|
|
|
|
|
|
def set_all_reduce_fusion_split_indices(self, indices, group="hccl_world_groupsum1"): |
|
|
""" |
|
|
""" |
|
|
Set allreduce fusion strategy by parameters indices. |
|
|
Set allreduce fusion strategy by parameters indices. |
|
|
|
|
|
|
|
|
@@ -279,7 +279,7 @@ class _AutoParallelContext: |
|
|
else: |
|
|
else: |
|
|
_set_fusion_strategy_by_idx(indices, group) |
|
|
_set_fusion_strategy_by_idx(indices, group) |
|
|
|
|
|
|
|
|
def get_all_reduce_fusion_split_indices(self, group=""): |
|
|
|
|
|
|
|
|
def get_all_reduce_fusion_split_indices(self, group="hccl_world_groupsum1"): |
|
|
""" |
|
|
""" |
|
|
Get allreduce fusion split indices. |
|
|
Get allreduce fusion split indices. |
|
|
|
|
|
|
|
|
@@ -301,7 +301,7 @@ class _AutoParallelContext: |
|
|
raise TypeError('Group must be a python str') |
|
|
raise TypeError('Group must be a python str') |
|
|
return self._context_handle.get_all_reduce_fusion_split_indices(group) |
|
|
return self._context_handle.get_all_reduce_fusion_split_indices(group) |
|
|
|
|
|
|
|
|
def set_all_reduce_fusion_split_sizes(self, sizes, group=""): |
|
|
|
|
|
|
|
|
def set_all_reduce_fusion_split_sizes(self, sizes, group="hccl_world_groupsum1"): |
|
|
""" |
|
|
""" |
|
|
Set allreduce fusion strategy by parameters data sizes. |
|
|
Set allreduce fusion strategy by parameters data sizes. |
|
|
|
|
|
|
|
|
@@ -335,7 +335,7 @@ class _AutoParallelContext: |
|
|
else: |
|
|
else: |
|
|
_set_fusion_strategy_by_size(sizes, group) |
|
|
_set_fusion_strategy_by_size(sizes, group) |
|
|
|
|
|
|
|
|
def get_all_reduce_fusion_split_sizes(self, group=""): |
|
|
|
|
|
|
|
|
def get_all_reduce_fusion_split_sizes(self, group="hccl_world_groupsum1"): |
|
|
""" |
|
|
""" |
|
|
Get allreduce fusion split sizes. |
|
|
Get allreduce fusion split sizes. |
|
|
|
|
|
|
|
|
|