diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py
index 9595018859..8a11a4162b 100644
--- a/mindspore/nn/layer/normalization.py
+++ b/mindspore/nn/layer/normalization.py
@@ -32,6 +32,7 @@ from mindspore.communication.management import get_group_size, get_rank
 from mindspore.communication import management
 from mindspore.ops import _selected_ops
 from mindspore.common import dtype as mstype
+from mindspore.parallel._utils import _is_in_auto_parallel_mode
 from ..cell import Cell
 
 __all__ = ['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'LayerNorm', 'GroupNorm',
@@ -146,9 +147,12 @@ class _BatchNorm(Cell):
                                                 device_num=self.group_device_num)
 
         self.bn_infer = P.BatchNorm(is_training=False, epsilon=self.eps, data_format=self.format)
-
-        data_parallel_strategy = ((1,), (1,))
-        data_parallel_strategy_one = ((1,), ())
+        if _is_in_auto_parallel_mode():
+            data_parallel_strategy = ((1,), (1,))
+            data_parallel_strategy_one = ((1,), ())
+        else:
+            data_parallel_strategy = None
+            data_parallel_strategy_one = None
         self.sub_mean = P.Sub().shard(data_parallel_strategy)
         self.sub_var = P.Sub().shard(data_parallel_strategy)
         self.mul_mean = P.Mul().shard(data_parallel_strategy_one)
diff --git a/mindspore/nn/layer/quant.py b/mindspore/nn/layer/quant.py
index 6b20e25f3e..2d7ba786fb 100644
--- a/mindspore/nn/layer/quant.py
+++ b/mindspore/nn/layer/quant.py
@@ -548,14 +548,12 @@ class Conv2dBnFoldQuantOneConv(Cell):
                                     momentum=self.momentum,
                                     data_format=self.format)
         self.bn_infer = P.BatchNorm(is_training=False, epsilon=self.eps, data_format=self.format)
-        data_parallel_strategy = ((1,), (1,))
-        data_parallel_strategy_one = ((1,), ())
-        self.sub_mean = P.Sub().shard(data_parallel_strategy)
-        self.sub_var = P.Sub().shard(data_parallel_strategy)
-        self.mul_mean = P.Mul().shard(data_parallel_strategy_one)
-        self.mul_var = P.Mul().shard(data_parallel_strategy_one)
-        self.assign_sub_mean = P.AssignSub().shard(data_parallel_strategy)
-        self.assign_sub_var = P.AssignSub().shard(data_parallel_strategy)
+        self.sub_mean = P.Sub()
+        self.sub_var = P.Sub()
+        self.mul_mean = P.Mul()
+        self.mul_var = P.Mul()
+        self.assign_sub_mean = P.AssignSub()
+        self.assign_sub_var = P.AssignSub()
         self.reshape = P.Reshape()
 
     def extend_repr(self):
diff --git a/mindspore/ops/primitive.py b/mindspore/ops/primitive.py
index 53a7009a92..a8e30528c9 100644
--- a/mindspore/ops/primitive.py
+++ b/mindspore/ops/primitive.py
@@ -18,6 +18,7 @@ import inspect
 import copy
 from mindspore.common.api import _wrap_func
 from mindspore import context, log as logger
+from mindspore.parallel._utils import _is_in_auto_parallel_mode
 from .._c_expression import Primitive_, real_run_op, prim_type
 from .._checkparam import Validator
 from . import signature as sig
@@ -143,7 +144,7 @@ class Primitive(Primitive_):
             strategy (tuple): Strategy describes the distributed parallel mode of the current primitive.
         """
         mode = context.get_auto_parallel_context("parallel_mode")
-        if mode not in [context.ParallelMode.AUTO_PARALLEL, context.ParallelMode.SEMI_AUTO_PARALLEL]:
+        if not _is_in_auto_parallel_mode() and strategy:
             logger.warning(f"The shard strategy {strategy} of {self.name} is not valid in {mode}. "
                            f"Please use semi auto or auto parallel mode.")
         self.add_prim_attr("strategy", strategy)
diff --git a/mindspore/parallel/_utils.py b/mindspore/parallel/_utils.py
index 9b45e5e2b0..088ed6fd38 100644
--- a/mindspore/parallel/_utils.py
+++ b/mindspore/parallel/_utils.py
@@ -30,6 +30,8 @@ def _get_parallel_mode():
     """Get parallel mode."""
     return auto_parallel_context().get_parallel_mode()
 
+def _is_in_auto_parallel_mode():
+    return _get_parallel_mode() in [ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL]
 
 def _get_full_batch():
     """Get whether to use full_batch."""