Browse Source

fix doc problems

tags/v0.3.0-alpha
zhaojichen 5 years ago
parent
commit
4a79dde736
1 changed files with 5 additions and 1 deletions
  1. +5
    -1
      mindspore/nn/layer/normalization.py

+ 5
- 1
mindspore/nn/layer/normalization.py View File

@@ -17,6 +17,7 @@ from mindspore.ops import operations as P
from mindspore.ops import functional as F from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer from mindspore.common.initializer import initializer
from mindspore.ops.primitive import constexpr
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
import mindspore.context as context import mindspore.context as context
@@ -165,7 +166,9 @@ class _BatchNorm(Cell):
def extend_repr(self): def extend_repr(self):
return 'num_features={}, eps={}, momentum={}, gamma={}, beta={}, moving_mean={}, moving_variance={}'.format( return 'num_features={}, eps={}, momentum={}, gamma={}, beta={}, moving_mean={}, moving_variance={}'.format(
self.num_features, self.eps, self.momentum, self.gamma, self.beta, self.moving_mean, self.moving_variance) self.num_features, self.eps, self.momentum, self.gamma, self.beta, self.moving_mean, self.moving_variance)

def _channel_check(channel, num_channel):
if channel != num_channel:
raise ValueError("the input channel is not equal with num_channels")


class BatchNorm1d(_BatchNorm): class BatchNorm1d(_BatchNorm):
r""" r"""
@@ -508,6 +511,7 @@ class GroupNorm(Cell):


def construct(self, x): def construct(self, x):
batch, channel, height, width = self.shape(x) batch, channel, height, width = self.shape(x)
_channel_check(channel, self.num_channels)
x = self.reshape(x, (batch, self.num_groups, channel*height*width/self.num_groups)) x = self.reshape(x, (batch, self.num_groups, channel*height*width/self.num_groups))
mean = self.reduce_mean(x, 2) mean = self.reduce_mean(x, 2)
var = self.reduce_sum(self.square(x - mean), 2) / (channel * height * width / self.num_groups - 1) var = self.reduce_sum(self.square(x - mean), 2) / (channel * height * width / self.num_groups - 1)


Loading…
Cancel
Save