|
|
@@ -14,6 +14,8 @@ |
|
|
# ============================================================================ |
|
|
# ============================================================================ |
|
|
"""Context of auto parallel""" |
|
|
"""Context of auto parallel""" |
|
|
import threading |
|
|
import threading |
|
|
|
|
|
import mindspore.context as context |
|
|
|
|
|
from mindspore.parallel._dp_allreduce_fusion import _set_fusion_strategy_by_idx, _set_fusion_strategy_by_size |
|
|
from mindspore._c_expression import AutoParallelContext |
|
|
from mindspore._c_expression import AutoParallelContext |
|
|
from mindspore._extends.pynative_helper import args_type_check |
|
|
from mindspore._extends.pynative_helper import args_type_check |
|
|
|
|
|
|
|
|
@@ -219,13 +221,15 @@ class _AutoParallelContext: |
|
|
indices (list): Indices list. |
|
|
indices (list): Indices list. |
|
|
|
|
|
|
|
|
Raises: |
|
|
Raises: |
|
|
ValueError: If type of indices item is not int. |
|
|
|
|
|
|
|
|
TypeError: If type of indices item is not int. |
|
|
""" |
|
|
""" |
|
|
self.check_context_handle() |
|
|
self.check_context_handle() |
|
|
for index in indices: |
|
|
for index in indices: |
|
|
if not isinstance(index, int): |
|
|
if not isinstance(index, int): |
|
|
raise TypeError('indices has invalid value') |
|
|
raise TypeError('indices has invalid value') |
|
|
return self._context_handle.set_all_reduce_fusion_split_indices(indices) |
|
|
|
|
|
|
|
|
self._context_handle.set_all_reduce_fusion_split_indices(indices) |
|
|
|
|
|
if context.get_context("device_target") == "Ascend": |
|
|
|
|
|
_set_fusion_strategy_by_idx(indices) |
|
|
|
|
|
|
|
|
def get_all_reduce_fusion_split_indices(self): |
|
|
def get_all_reduce_fusion_split_indices(self): |
|
|
"""Get allreduce fusion split indices.""" |
|
|
"""Get allreduce fusion split indices.""" |
|
|
@@ -240,13 +244,15 @@ class _AutoParallelContext: |
|
|
sizes (list): Sizes list. |
|
|
sizes (list): Sizes list. |
|
|
|
|
|
|
|
|
Raises: |
|
|
Raises: |
|
|
ValueError: If type of sizes item is not int. |
|
|
|
|
|
|
|
|
TypeError: If type of sizes item is not int. |
|
|
""" |
|
|
""" |
|
|
self.check_context_handle() |
|
|
self.check_context_handle() |
|
|
for size in sizes: |
|
|
for size in sizes: |
|
|
if not isinstance(size, int): |
|
|
if not isinstance(size, int): |
|
|
raise TypeError('sizes has invalid value') |
|
|
raise TypeError('sizes has invalid value') |
|
|
return self._context_handle.set_all_reduce_fusion_split_sizes(sizes) |
|
|
|
|
|
|
|
|
self._context_handle.set_all_reduce_fusion_split_sizes(sizes) |
|
|
|
|
|
if context.get_context("device_target") == "Ascend": |
|
|
|
|
|
_set_fusion_strategy_by_size(sizes) |
|
|
|
|
|
|
|
|
def get_all_reduce_fusion_split_sizes(self): |
|
|
def get_all_reduce_fusion_split_sizes(self): |
|
|
"""Get allreduce fusion split sizes.""" |
|
|
"""Get allreduce fusion split sizes.""" |
|
|
|