diff --git a/tests/ut/python/nn/test_batchnorm.py b/tests/ut/python/nn/test_batchnorm.py index 4bd8c996d6..e73b7ebbf0 100644 --- a/tests/ut/python/nn/test_batchnorm.py +++ b/tests/ut/python/nn/test_batchnorm.py @@ -57,6 +57,7 @@ def test_compile(): input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]).astype(np.float32)) _executor.compile(net, input_data) + class GroupNet(nn.Cell): def __init__(self): super(GroupNet, self).__init__() @@ -64,6 +65,7 @@ class GroupNet(nn.Cell): def construct(self, x): return self.group_bn(x) + def test_compile_groupnorm(): net = nn.GroupNorm(16, 64) input_data = Tensor(np.random.rand(1,64,256,256).astype(np.float32))