Browse Source

fix shard strategy for batchnorm

pull/15337/head
Ziyan 4 years ago
parent
commit
b3eebea4de
4 changed files with 17 additions and 12 deletions
  1. +7
    -3
      mindspore/nn/layer/normalization.py
  2. +6
    -8
      mindspore/nn/layer/quant.py
  3. +2
    -1
      mindspore/ops/primitive.py
  4. +2
    -0
      mindspore/parallel/_utils.py

+ 7
- 3
mindspore/nn/layer/normalization.py View File

@@ -32,6 +32,7 @@ from mindspore.communication.management import get_group_size, get_rank
from mindspore.communication import management
from mindspore.ops import _selected_ops
from mindspore.common import dtype as mstype
from mindspore.parallel._utils import _is_in_auto_parallel_mode
from ..cell import Cell

__all__ = ['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'LayerNorm', 'GroupNorm',
@@ -146,9 +147,12 @@ class _BatchNorm(Cell):
device_num=self.group_device_num)

self.bn_infer = P.BatchNorm(is_training=False, epsilon=self.eps, data_format=self.format)

data_parallel_strategy = ((1,), (1,))
data_parallel_strategy_one = ((1,), ())
if _is_in_auto_parallel_mode():
data_parallel_strategy = ((1,), (1,))
data_parallel_strategy_one = ((1,), ())
else:
data_parallel_strategy = None
data_parallel_strategy_one = None
self.sub_mean = P.Sub().shard(data_parallel_strategy)
self.sub_var = P.Sub().shard(data_parallel_strategy)
self.mul_mean = P.Mul().shard(data_parallel_strategy_one)


+ 6
- 8
mindspore/nn/layer/quant.py View File

@@ -548,14 +548,12 @@ class Conv2dBnFoldQuantOneConv(Cell):
momentum=self.momentum, data_format=self.format)

self.bn_infer = P.BatchNorm(is_training=False, epsilon=self.eps, data_format=self.format)
data_parallel_strategy = ((1,), (1,))
data_parallel_strategy_one = ((1,), ())
self.sub_mean = P.Sub().shard(data_parallel_strategy)
self.sub_var = P.Sub().shard(data_parallel_strategy)
self.mul_mean = P.Mul().shard(data_parallel_strategy_one)
self.mul_var = P.Mul().shard(data_parallel_strategy_one)
self.assign_sub_mean = P.AssignSub().shard(data_parallel_strategy)
self.assign_sub_var = P.AssignSub().shard(data_parallel_strategy)
self.sub_mean = P.Sub()
self.sub_var = P.Sub()
self.mul_mean = P.Mul()
self.mul_var = P.Mul()
self.assign_sub_mean = P.AssignSub()
self.assign_sub_var = P.AssignSub()
self.reshape = P.Reshape()

def extend_repr(self):


+ 2
- 1
mindspore/ops/primitive.py View File

@@ -18,6 +18,7 @@ import inspect
import copy
from mindspore.common.api import _wrap_func
from mindspore import context, log as logger
from mindspore.parallel._utils import _is_in_auto_parallel_mode
from .._c_expression import Primitive_, real_run_op, prim_type
from .._checkparam import Validator
from . import signature as sig
@@ -143,7 +144,7 @@ class Primitive(Primitive_):
strategy (tuple): Strategy describes the distributed parallel mode of the current primitive.
"""
mode = context.get_auto_parallel_context("parallel_mode")
if mode not in [context.ParallelMode.AUTO_PARALLEL, context.ParallelMode.SEMI_AUTO_PARALLEL]:
if not _is_in_auto_parallel_mode() and strategy:
logger.warning(f"The shard strategy {strategy} of {self.name} is not valid in {mode}. "
f"Please use semi auto or auto parallel mode.")
self.add_prim_attr("strategy", strategy)


+ 2
- 0
mindspore/parallel/_utils.py View File

@@ -30,6 +30,8 @@ def _get_parallel_mode():
"""Get parallel mode."""
return auto_parallel_context().get_parallel_mode()

def _is_in_auto_parallel_mode():
    """Return True when the current parallel mode is semi-auto or auto parallel."""
    current_mode = _get_parallel_mode()
    return current_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)

def _get_full_batch():
"""Get whether to use full_batch."""


Loading…
Cancel
Save