From f57bd919e08432e7fe4a6177371936cd4680c34c Mon Sep 17 00:00:00 2001 From: yao_yf Date: Thu, 9 Apr 2020 15:02:33 +0800 Subject: [PATCH] Integrate two allreduce fusion set interfaces into one --- mindspore/parallel/__init__.py | 4 +--- mindspore/parallel/_auto_parallel_context.py | 14 ++++++++++---- ...allreduce_fusion.py => _dp_allreduce_fusion.py} | 4 ++-- 3 files changed, 13 insertions(+), 9 deletions(-) rename mindspore/parallel/{dp_allreduce_fusion.py => _dp_allreduce_fusion.py} (94%) diff --git a/mindspore/parallel/__init__.py b/mindspore/parallel/__init__.py index c79704f110..79d8e67a8d 100644 --- a/mindspore/parallel/__init__.py +++ b/mindspore/parallel/__init__.py @@ -15,9 +15,7 @@ """ This interface is ONLY used in Auto-parallel procedure. """ -from .dp_allreduce_fusion import set_fusion_strategy_by_idx, set_fusion_strategy_by_size from .algo_parameter_config import get_algo_parameters, reset_algo_parameters, \ set_algo_parameters -__all__ = ["set_fusion_strategy_by_idx", "set_fusion_strategy_by_size", "get_algo_parameters", - "reset_algo_parameters", "set_algo_parameters"] +__all__ = ["get_algo_parameters", "reset_algo_parameters", "set_algo_parameters"] diff --git a/mindspore/parallel/_auto_parallel_context.py b/mindspore/parallel/_auto_parallel_context.py index 3564ad4395..c99ac4a3c7 100644 --- a/mindspore/parallel/_auto_parallel_context.py +++ b/mindspore/parallel/_auto_parallel_context.py @@ -14,6 +14,8 @@ # ============================================================================ """Context of auto parallel""" import threading +import mindspore.context as context +from mindspore.parallel._dp_allreduce_fusion import _set_fusion_strategy_by_idx, _set_fusion_strategy_by_size from mindspore._c_expression import AutoParallelContext from mindspore._extends.pynative_helper import args_type_check @@ -219,13 +221,15 @@ class _AutoParallelContext: indices (list): Indices list. Raises: - ValueError: If type of indices item is not int. + TypeError: If type of indices item is not int. """ self.check_context_handle() for index in indices: if not isinstance(index, int): raise TypeError('indices has invalid value') - return self._context_handle.set_all_reduce_fusion_split_indices(indices) + self._context_handle.set_all_reduce_fusion_split_indices(indices) + if context.get_context("device_target") == "Ascend": + _set_fusion_strategy_by_idx(indices) def get_all_reduce_fusion_split_indices(self): """Get allreduce fusion split indices.""" @@ -240,13 +244,15 @@ class _AutoParallelContext: sizes (list): Sizes list. Raises: - ValueError: If type of sizes item is not int. + TypeError: If type of sizes item is not int. """ self.check_context_handle() for size in sizes: if not isinstance(size, int): raise TypeError('sizes has invalid value') - return self._context_handle.set_all_reduce_fusion_split_sizes(sizes) + self._context_handle.set_all_reduce_fusion_split_sizes(sizes) + if context.get_context("device_target") == "Ascend": + _set_fusion_strategy_by_size(sizes) def get_all_reduce_fusion_split_sizes(self): """Get allreduce fusion split sizes.""" diff --git a/mindspore/parallel/dp_allreduce_fusion.py b/mindspore/parallel/_dp_allreduce_fusion.py similarity index 94% rename from mindspore/parallel/dp_allreduce_fusion.py rename to mindspore/parallel/_dp_allreduce_fusion.py index 979823bd80..3c7039dbd6 100644 --- a/mindspore/parallel/dp_allreduce_fusion.py +++ b/mindspore/parallel/_dp_allreduce_fusion.py @@ -43,7 +43,7 @@ def _c_array(ctype, values): return (ctype * len(values))(*values) -def set_fusion_strategy_by_idx(idxList, group="hccl_world_group"): +def _set_fusion_strategy_by_idx(idxList, group="hccl_world_group"): """ A function set gradient segment strategy according to the index list. @@ -100,7 +100,7 @@ def set_fusion_strategy_by_idx(idxList, group="hccl_world_group"): raise RuntimeError('Allreduce split error') -def set_fusion_strategy_by_size(dataSizeList, group="hccl_world_group"): +def _set_fusion_strategy_by_size(dataSizeList, group="hccl_world_group"): """ A function set gradient segment strategy according to the data size percentage list.