@@ -32,6 +32,7 @@ from mindspore.communication.management import get_group_size, get_rank
 from mindspore.communication import management
 from mindspore.ops import _selected_ops
 from mindspore.common import dtype as mstype
+from mindspore.parallel._utils import _is_in_auto_parallel_mode
 from ..cell import Cell
 
 __all__ = ['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'LayerNorm', 'GroupNorm',

@@ -146,9 +147,12 @@ class _BatchNorm(Cell):
                                                      device_num=self.group_device_num)
         self.bn_infer = P.BatchNorm(is_training=False, epsilon=self.eps, data_format=self.format)
-        data_parallel_strategy = ((1,), (1,))
-        data_parallel_strategy_one = ((1,), ())
+        if _is_in_auto_parallel_mode():
+            data_parallel_strategy = ((1,), (1,))
+            data_parallel_strategy_one = ((1,), ())
+        else:
+            data_parallel_strategy = None
+            data_parallel_strategy_one = None
         self.sub_mean = P.Sub().shard(data_parallel_strategy)
         self.sub_var = P.Sub().shard(data_parallel_strategy)
         self.mul_mean = P.Mul().shard(data_parallel_strategy_one)
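
Note: the effect of this hunk is that an explicit layout is attached only under (semi) auto parallel; in every other mode the strategies fall back to None, so the chained .shard(None) calls leave the primitives unconstrained. A minimal sketch of the same gating pattern outside the _BatchNorm class (the context setup below is an assumption for illustration; the operator and helper names come from the diff):

    from mindspore import context
    from mindspore.ops import operations as P
    from mindspore.parallel._utils import _is_in_auto_parallel_mode

    # Hypothetical setup for the sketch: plain data parallel training.
    context.set_auto_parallel_context(parallel_mode=context.ParallelMode.DATA_PARALLEL)

    if _is_in_auto_parallel_mode():
        data_parallel_strategy = ((1,), (1,))   # explicit layout for (semi) auto parallel
    else:
        data_parallel_strategy = None           # no layout requested in other modes

    sub_mean = P.Sub().shard(data_parallel_strategy)  # shard(None) attaches no meaningful strategy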

@@ -548,14 +548,12 @@ class Conv2dBnFoldQuantOneConv(Cell):
                                   momentum=self.momentum, data_format=self.format)
         self.bn_infer = P.BatchNorm(is_training=False, epsilon=self.eps, data_format=self.format)
-        data_parallel_strategy = ((1,), (1,))
-        data_parallel_strategy_one = ((1,), ())
-        self.sub_mean = P.Sub().shard(data_parallel_strategy)
-        self.sub_var = P.Sub().shard(data_parallel_strategy)
-        self.mul_mean = P.Mul().shard(data_parallel_strategy_one)
-        self.mul_var = P.Mul().shard(data_parallel_strategy_one)
-        self.assign_sub_mean = P.AssignSub().shard(data_parallel_strategy)
-        self.assign_sub_var = P.AssignSub().shard(data_parallel_strategy)
+        self.sub_mean = P.Sub()
+        self.sub_var = P.Sub()
+        self.mul_mean = P.Mul()
+        self.mul_var = P.Mul()
+        self.assign_sub_mean = P.AssignSub()
+        self.assign_sub_var = P.AssignSub()
         self.reshape = P.Reshape()
 
     def extend_repr(self):

@@ -18,6 +18,7 @@ import inspect
 import copy
 from mindspore.common.api import _wrap_func
 from mindspore import context, log as logger
+from mindspore.parallel._utils import _is_in_auto_parallel_mode
 from .._c_expression import Primitive_, real_run_op, prim_type
 from .._checkparam import Validator
 from . import signature as sig

@@ -143,7 +144,7 @@ class Primitive(Primitive_):
             strategy (tuple): Strategy describes the distributed parallel mode of the current primitive.
         """
         mode = context.get_auto_parallel_context("parallel_mode")
-        if mode not in [context.ParallelMode.AUTO_PARALLEL, context.ParallelMode.SEMI_AUTO_PARALLEL]:
+        if not _is_in_auto_parallel_mode() and strategy:
             logger.warning(f"The shard strategy {strategy} of {self.name} is not valid in {mode}. "
                            f"Please use semi auto or auto parallel mode.")
         self.add_prim_attr("strategy", strategy)
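
Note: with this change the warning fires only when a non-empty strategy is supplied outside (semi) auto parallel, so the shard(None) calls introduced in the normalization layer above pass silently while a real tuple still gets flagged. A hedged illustration (the standalone context is assumed for the example):

    from mindspore import context
    from mindspore.ops import operations as P

    # Hypothetical setup: no auto parallel configured.
    context.set_auto_parallel_context(parallel_mode=context.ParallelMode.STAND_ALONE)

    P.Sub().shard(((1,), (1,)))  # still warns: explicit strategy outside auto parallel
    P.Sub().shard(None)          # no warning: strategy is falsy, attribute is simply set to None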

@@ -30,6 +30,8 @@ def _get_parallel_mode():
     """Get parallel mode."""
     return auto_parallel_context().get_parallel_mode()
 
+def _is_in_auto_parallel_mode():
+    return _get_parallel_mode() in [ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL]
 
 def _get_full_batch():
     """Get whether to use full_batch."""