Browse Source

Add Group Normalization

tags/v0.2.0-alpha
zhaojichen 5 years ago
parent
commit
04c522d0c6
1 changed files with 2 additions and 0 deletions
  1. +2
    -0
      tests/ut/python/nn/test_batchnorm.py

+ 2
- 0
tests/ut/python/nn/test_batchnorm.py View File

@@ -57,6 +57,7 @@ def test_compile():
input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]).astype(np.float32))
_executor.compile(net, input_data)


class GroupNet(nn.Cell):
def __init__(self):
super(GroupNet, self).__init__()
@@ -64,6 +65,7 @@ class GroupNet(nn.Cell):
def construct(self, x):
return self.group_bn(x)


def test_compile_groupnorm():
net = nn.GroupNorm(16, 64)
input_data = Tensor(np.random.rand(1,64,256,256).astype(np.float32))


Loading…
Cancel
Save