|
|
|
@@ -32,6 +32,7 @@ from mindspore.communication.management import get_group_size, get_rank |
|
|
|
from mindspore.communication import management |
|
|
|
from mindspore.ops import _selected_ops |
|
|
|
from mindspore.common import dtype as mstype |
|
|
|
from mindspore.parallel._utils import _is_in_auto_parallel_mode |
|
|
|
from ..cell import Cell |
|
|
|
|
|
|
|
__all__ = ['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'LayerNorm', 'GroupNorm', |
|
|
|
@@ -146,9 +147,12 @@ class _BatchNorm(Cell): |
|
|
|
device_num=self.group_device_num) |
|
|
|
|
|
|
|
self.bn_infer = P.BatchNorm(is_training=False, epsilon=self.eps, data_format=self.format) |
|
|
|
|
|
|
|
data_parallel_strategy = ((1,), (1,)) |
|
|
|
data_parallel_strategy_one = ((1,), ()) |
|
|
|
if _is_in_auto_parallel_mode(): |
|
|
|
data_parallel_strategy = ((1,), (1,)) |
|
|
|
data_parallel_strategy_one = ((1,), ()) |
|
|
|
else: |
|
|
|
data_parallel_strategy = None |
|
|
|
data_parallel_strategy_one = None |
|
|
|
self.sub_mean = P.Sub().shard(data_parallel_strategy) |
|
|
|
self.sub_var = P.Sub().shard(data_parallel_strategy) |
|
|
|
self.mul_mean = P.Mul().shard(data_parallel_strategy_one) |
|
|
|
|