|
|
|
@@ -20,6 +20,8 @@ from mindspore._c_expression import AutoParallelContext |
|
|
|
from mindspore._checkparam import args_type_check |
|
|
|
|
|
|
|
_MAX_GROUP_NAME_LEN = 127 |
|
|
|
_DEFAULT_HCCL_FUSION_GROUP_NAME = "hccl_world_groupsum1" |
|
|
|
_DEFAULT_NCCL_FUSION_GROUP_NAME = "nccl_world_groupsum1" |
|
|
|
|
|
|
|
|
|
|
|
class _AutoParallelContext: |
|
|
|
@@ -267,7 +269,7 @@ class _AutoParallelContext: |
|
|
|
self.check_context_handle() |
|
|
|
return self._context_handle.get_parameter_broadcast_is_set() |
|
|
|
|
|
|
|
def set_all_reduce_fusion_split_indices(self, indices, group="hccl_world_groupsum1"): |
|
|
|
def set_all_reduce_fusion_split_indices(self, indices, group=""): |
|
|
|
""" |
|
|
|
Set allreduce fusion strategy by parameters indices. |
|
|
|
|
|
|
|
@@ -294,11 +296,17 @@ class _AutoParallelContext: |
|
|
|
else: |
|
|
|
raise TypeError('Group must be a python str') |
|
|
|
|
|
|
|
if group == "": |
|
|
|
if context.get_context("device_target") == "Ascend": |
|
|
|
group = _DEFAULT_HCCL_FUSION_GROUP_NAME |
|
|
|
else: |
|
|
|
group = _DEFAULT_NCCL_FUSION_GROUP_NAME |
|
|
|
|
|
|
|
self._context_handle.set_all_reduce_fusion_split_indices(indices, group) |
|
|
|
if context.get_context("device_target") == "Ascend": |
|
|
|
_set_fusion_strategy_by_idx(indices) |
|
|
|
|
|
|
|
def get_all_reduce_fusion_split_indices(self, group="hccl_world_groupsum1"): |
|
|
|
def get_all_reduce_fusion_split_indices(self, group=""): |
|
|
|
""" |
|
|
|
Get allreduce fusion split indices. |
|
|
|
|
|
|
|
@@ -318,9 +326,15 @@ class _AutoParallelContext: |
|
|
|
raise ValueError(f'Group name len is out of range {_MAX_GROUP_NAME_LEN}')
|
|
|
else: |
|
|
|
raise TypeError('Group must be a python str') |
|
|
|
|
|
|
|
if group == "": |
|
|
|
if context.get_context("device_target") == "Ascend": |
|
|
|
group = _DEFAULT_HCCL_FUSION_GROUP_NAME |
|
|
|
else: |
|
|
|
group = _DEFAULT_NCCL_FUSION_GROUP_NAME |
|
|
|
return self._context_handle.get_all_reduce_fusion_split_indices(group) |
|
|
|
|
|
|
|
def set_all_reduce_fusion_split_sizes(self, sizes, group="hccl_world_groupsum1"): |
|
|
|
def set_all_reduce_fusion_split_sizes(self, sizes, group=""): |
|
|
|
""" |
|
|
|
Set allreduce fusion strategy by parameters data sizes. |
|
|
|
|
|
|
|
@@ -347,11 +361,17 @@ class _AutoParallelContext: |
|
|
|
else: |
|
|
|
raise TypeError('Group must be a python str') |
|
|
|
|
|
|
|
if group == "": |
|
|
|
if context.get_context("device_target") == "Ascend": |
|
|
|
group = _DEFAULT_HCCL_FUSION_GROUP_NAME |
|
|
|
else: |
|
|
|
group = _DEFAULT_NCCL_FUSION_GROUP_NAME |
|
|
|
|
|
|
|
self._context_handle.set_all_reduce_fusion_split_sizes(sizes, group) |
|
|
|
if context.get_context("device_target") == "Ascend": |
|
|
|
_set_fusion_strategy_by_size(sizes) |
|
|
|
|
|
|
|
def get_all_reduce_fusion_split_sizes(self, group="hccl_world_groupsum1"): |
|
|
|
def get_all_reduce_fusion_split_sizes(self, group=""): |
|
|
|
""" |
|
|
|
Get allreduce fusion split sizes. |
|
|
|
|
|
|
|
@@ -371,6 +391,12 @@ class _AutoParallelContext: |
|
|
|
raise ValueError(f'Group name len is out of range {_MAX_GROUP_NAME_LEN}')
|
|
|
else: |
|
|
|
raise TypeError('Group must be a python str') |
|
|
|
|
|
|
|
if group == "": |
|
|
|
if context.get_context("device_target") == "Ascend": |
|
|
|
group = _DEFAULT_HCCL_FUSION_GROUP_NAME |
|
|
|
else: |
|
|
|
group = _DEFAULT_NCCL_FUSION_GROUP_NAME |
|
|
|
return self._context_handle.get_all_reduce_fusion_split_sizes(group) |
|
|
|
|
|
|
|
def set_enable_all_reduce_fusion(self, enable_all_reduce_fusion): |
|
|
|
|