Browse Source

add global batch normalization

tags/v0.2.0-alpha
zhaojichen 6 years ago
parent
commit
27c3076849
2 changed files with 25 additions and 13 deletions
  1. +16
    -11
      mindspore/nn/layer/normalization.py
  2. +9
    -2
      tests/ut/python/nn/test_batchnorm.py

+ 16
- 11
mindspore/nn/layer/normalization.py View File

@@ -69,16 +69,17 @@ class _BatchNorm(Cell):
self.beta = Parameter(initializer(
beta_init, num_features), name="beta", requires_grad=affine)
self.group = check_int_positive(group)
self.rank_id = get_rank()
self.rank_size = get_local_rank_size()
self.device_list = [i for i in range(0, self.rank_size)]
self.rank_list = self.list_group(self.device_list, self.group)
self.rank_list_idx = len(self.rank_list)
for i in range(self.rank_list_idx):
if self.rank_id in self.rank_list[i] and self.group != 1:
self.is_global = True
management.create_group('group' + str(i), self.rank_list[i])
self.all_reduce = _GlobalBNHelper('group' + str(i))
if self.group != 1:
self.rank_id = get_rank()
self.rank_size = get_local_rank_size()
self.device_list = [i for i in range(0, self.rank_size)]
self.rank_list = self.list_group(self.device_list, self.group)
self.rank_list_idx = len(self.rank_list)
for i in range(self.rank_list_idx):
if self.rank_id in self.rank_list[i] and self.group != 1:
self.is_global = True
management.create_group('group' + str(i), self.rank_list[i])
self.all_reduce = _GlobalBNHelper('group' + str(i))
self.shape = P.Shape()
self.reduce_mean = P.ReduceMean()
self.square = P.Square()
@@ -110,7 +111,8 @@ class _BatchNorm(Cell):

def list_group(self, world_rank, group_size):
if group_size > get_local_rank_size():
raise ValueError("group size can not be greater than local rank size, group size is {}, local_rank_size is {}".format(group_size, get_local_rank_size()))
raise ValueError("group size can not be greater than local rank size, group size is {}, local_rank_size is {}".format(
group_size, get_local_rank_size()))
if len(world_rank) % group_size != 0:
raise ValueError("please make your group size correct.")
world_rank_list = zip(*(iter(world_rank),) *group_size)
@@ -322,6 +324,9 @@ class GlobalBatchNorm(_BatchNorm):
>>> input = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32)
>>> global_bn_op(input)
"""
def _check_data_dim(self, x):
if x.dim == 0:
pass

class LayerNorm(Cell):
r"""


+ 9
- 2
tests/ut/python/nn/test_batchnorm.py View File

@@ -20,6 +20,8 @@ import mindspore.nn as nn
from mindspore.common.api import _executor
from mindspore import Tensor, Parameter
from mindspore.communication.management import init
from mindspore import context
from mindspore import ParallelMode


def test_bn_pars_valid1():
@@ -75,12 +77,17 @@ def test_compile_groupnorm():
class GlobalBNNet(nn.Cell):
    """Minimal network wrapping a single GlobalBatchNorm layer.

    Used by the unit test to exercise compilation of cross-device batch
    normalization (2 features, devices split into groups of 2).
    """

    def __init__(self):
        super(GlobalBNNet, self).__init__()
        self.bn = nn.GlobalBatchNorm(num_features=2, group=2)

    def construct(self, x):
        """Apply global batch normalization to *x* and return the result."""
        out = self.bn(x)
        return out

def test_gloabl_bn():
def test_global_bn():
init("hccl")
size = 4
context.set_context(mode=context.GRAPH_MODE)
context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
device_num=size, parameter_broadcast=True)
net = GlobalBNNet()
input_data = Tensor(np.array([[2.4, 2.1], [3.2, 5.4]], dtype=np.float32))
net.set_train()


Loading…
Cancel
Save