You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

normalization.py 27 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """normalization"""
  16. from mindspore.ops import operations as P
  17. from mindspore.ops import functional as F
  18. from mindspore.common.parameter import Parameter
  19. from mindspore.common.initializer import initializer
  20. from mindspore.ops.primitive import constexpr
  21. from mindspore.common.tensor import Tensor
  22. import mindspore.common.dtype as mstype
  23. import mindspore.context as context
  24. from mindspore._checkparam import check_bool, check_typename
  25. from mindspore._extends import cell_attr_register
  26. from mindspore.communication.management import get_group_size, get_rank
  27. from mindspore.communication import management
  28. from mindspore._checkparam import check_int_positive
  29. from ..cell import Cell
  30. class _BatchNorm(Cell):
  31. """Batch Normalization base class."""
  32. @cell_attr_register
  33. def __init__(self,
  34. num_features,
  35. eps=1e-5,
  36. momentum=0.9,
  37. affine=True,
  38. gamma_init='ones',
  39. beta_init='zeros',
  40. moving_mean_init='zeros',
  41. moving_var_init='ones',
  42. use_batch_statistics=True,
  43. device_num_each_group=1):
  44. super(_BatchNorm, self).__init__()
  45. if num_features < 1:
  46. raise ValueError("num_features must be at least 1")
  47. if momentum < 0 or momentum > 1:
  48. raise ValueError("momentum should be a number in range [0, 1], but got {}".format(momentum))
  49. self.use_batch_statistics = use_batch_statistics
  50. self.num_features = num_features
  51. self.eps = eps
  52. self.moving_mean = Parameter(initializer(
  53. moving_mean_init, num_features), name="mean", requires_grad=False)
  54. self.moving_variance = Parameter(initializer(
  55. moving_var_init, num_features), name="variance", requires_grad=False)
  56. self.gamma = Parameter(initializer(
  57. gamma_init, num_features), name="gamma", requires_grad=affine)
  58. self.beta = Parameter(initializer(
  59. beta_init, num_features), name="beta", requires_grad=affine)
  60. self.group = check_int_positive(device_num_each_group)
  61. if self.group != 1:
  62. self.rank_id = get_rank()
  63. self.rank_size = get_group_size()
  64. self.device_list = [i for i in range(0, self.rank_size)]
  65. self.rank_list = self.list_group(self.device_list, self.group)
  66. self.rank_list_idx = len(self.rank_list)
  67. for i in range(self.rank_list_idx):
  68. if self.rank_id in self.rank_list[i] and self.group != 1:
  69. self.is_global = True
  70. management.create_group('group' + str(i), self.rank_list[i])
  71. self.all_reduce = P.AllReduce(P.ReduceOp.SUM, 'group' + str(i)).add_prim_attr('fusion', 1)
  72. self.shape = P.Shape()
  73. self.reduce_mean = P.ReduceMean()
  74. self.square = P.Square()
  75. if context.get_context("enable_ge"):
  76. self.is_ge_backend = True
  77. self.momentum = Tensor(1.0 - momentum, mstype.float32)
  78. self.bn_train = P.BatchNorm(is_training=True,
  79. epsilon=self.eps)
  80. else:
  81. self.is_ge_backend = False
  82. self.momentum = 1.0 - momentum
  83. self.bn_train = P.FusedBatchNorm(mode=1,
  84. epsilon=self.eps,
  85. momentum=self.momentum)
  86. self.bn_infer = P.BatchNorm(is_training=False, epsilon=self.eps)
  87. data_parallel_strategy = ((1,), (1,))
  88. data_parallel_strategy_one = ((1,), ())
  89. self.sub_mean = P.Sub().set_strategy(data_parallel_strategy)
  90. self.sub_var = P.Sub().set_strategy(data_parallel_strategy)
  91. self.mul_mean = P.Mul().set_strategy(data_parallel_strategy_one)
  92. self.mul_var = P.Mul().set_strategy(data_parallel_strategy_one)
  93. self.assign_sub_mean = P.AssignSub().set_strategy(data_parallel_strategy)
  94. self.assign_sub_var = P.AssignSub().set_strategy(data_parallel_strategy)
  95. def _check_data_dim(self, x):
  96. raise NotImplementedError
  97. def list_group(self, world_rank, group_size):
  98. if group_size > get_group_size():
  99. raise ValueError("group size can not be greater than local rank size, group size is {}, "
  100. "local_rank_size is {}".format(group_size, get_group_size()))
  101. if len(world_rank) % group_size != 0:
  102. raise ValueError("please make your group size correct.")
  103. world_rank_list = zip(*(iter(world_rank),) *group_size)
  104. group_list = [list(i) for i in world_rank_list]
  105. return group_list
  106. def construct(self, x):
  107. if self.training and self.use_batch_statistics:
  108. if self.is_ge_backend:
  109. if self.is_global:
  110. x_mean = self.reduce_mean(x)
  111. x_mean_square = self.reduce_mean(self.square(x))
  112. global_batch_mean = self.all_reduce(x_mean) / self.group
  113. global_batch_mean_square = self.all_reduce(x_mean_square) / self.group
  114. global_mean = global_batch_mean
  115. global_var = global_batch_mean_square - self.square(global_batch_mean)
  116. y, batch_mean, batch_var, _, _ = \
  117. self.bn_train(x,
  118. self.gamma,
  119. self.beta,
  120. None,
  121. None)
  122. mean_sub = self.sub_mean(self.moving_mean, global_mean)
  123. temp_mean = self.mul_mean(mean_sub, self.momentum)
  124. mean_sub2 = self.sub_var(self.moving_variance, global_var)
  125. temp_variance = self.mul_var(mean_sub2, self.momentum)
  126. y = F.depend(y, self.assign_sub_mean(self.moving_mean, temp_mean))
  127. y = F.depend(y, self.assign_sub_var(self.moving_variance, temp_variance))
  128. else:
  129. y, batch_mean, batch_var, _, _ = \
  130. self.bn_train(x,
  131. self.gamma,
  132. self.beta,
  133. None,
  134. None)
  135. mean_sub = self.sub_mean(self.moving_mean, batch_mean)
  136. temp_mean = self.mul_mean(mean_sub, self.momentum)
  137. mean_sub2 = self.sub_var(self.moving_variance, batch_var)
  138. temp_variance = self.mul_var(mean_sub2, self.momentum)
  139. y = F.depend(y, self.assign_sub_mean(self.moving_mean, temp_mean))
  140. y = F.depend(y, self.assign_sub_var(self.moving_variance, temp_variance))
  141. else:
  142. y = self.bn_train(x,
  143. self.gamma,
  144. self.beta,
  145. self.moving_mean,
  146. self.moving_variance)[0]
  147. else:
  148. y = self.bn_infer(x,
  149. self.gamma,
  150. self.beta,
  151. self.moving_mean,
  152. self.moving_variance)[0]
  153. return y
  154. def extend_repr(self):
  155. return 'num_features={}, eps={}, momentum={}, gamma={}, beta={}, moving_mean={}, moving_variance={}'.format(
  156. self.num_features, self.eps, self.momentum, self.gamma, self.beta, self.moving_mean, self.moving_variance)
  157. @constexpr
  158. def _channel_check(channel, num_channel):
  159. if channel != num_channel:
  160. raise ValueError("the input channel is not equal with num_channel")
  161. class BatchNorm1d(_BatchNorm):
  162. r"""
  163. Batch normalization layer over a 2D input.
  164. Batch Normalization is widely used in convolutional networks. This layer
  165. applies Batch Normalization over a 2D input (a mini-batch of 1D inputs) to
  166. reduce internal covariate shift as described in the paper
  167. `Batch Normalization: Accelerating Deep Network Training by
  168. Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It
  169. rescales and recenters the feature using a mini-batch of data and
  170. the learned parameters which can be described in the following formula.
  171. .. math::
  172. y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
  173. Args:
  174. num_features (int): `C` from an expected input of size (N, C).
  175. eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
  176. momentum (float): A floating hyperparameter of the momentum for the
  177. running_mean and running_var computation. Default: 0.9.
  178. affine (bool): A bool value when set to True, gamma and beta can be learnable. Default: True.
  179. gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
  180. The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
  181. 'he_uniform', etc. Default: 'ones'.
  182. beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
  183. The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
  184. 'he_uniform', etc. Default: 'zeros'.
  185. moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
  186. The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
  187. 'he_uniform', etc. Default: 'zeros'.
  188. moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
  189. The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
  190. 'he_uniform', etc. Default: 'ones'.
  191. use_batch_statistics (bool): If true, use the mean value and variance value of current batch data, else use
  192. the mean value and variance value of specified value. Default: True.
  193. Inputs:
  194. - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
  195. Outputs:
  196. Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
  197. Examples:
  198. >>> net = nn.BatchNorm1d(num_features=16)
  199. >>> input = Tensor(np.random.randint(0, 255, [3, 16]), mindspore.float32)
  200. >>> net(input)
  201. """
  202. def __init__(self,
  203. num_features,
  204. eps=1e-5,
  205. momentum=0.9,
  206. affine=True,
  207. gamma_init='ones',
  208. beta_init='zeros',
  209. moving_mean_init='zeros',
  210. moving_var_init='ones',
  211. use_batch_statistics=True):
  212. super(BatchNorm1d, self).__init__(num_features,
  213. eps,
  214. momentum,
  215. affine,
  216. gamma_init,
  217. beta_init,
  218. moving_mean_init,
  219. moving_var_init,
  220. use_batch_statistics)
  221. def _check_data_dim(self, x):
  222. if x.dim() != 2:
  223. pass
  224. class BatchNorm2d(_BatchNorm):
  225. r"""
  226. Batch normalization layer over a 4D input.
  227. Batch Normalization is widely used in convolutional networks. This layer
  228. applies Batch Normalization over a 4D input (a mini-batch of 2D inputs with
  229. additional channel dimension) to avoid internal covariate shift as described
  230. in the paper `Batch Normalization: Accelerating Deep Network Training by
  231. Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It
  232. rescales and recenters the feature using a mini-batch of data and
  233. the learned parameters which can be described in the following formula.
  234. .. math::
  235. y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
  236. Args:
  237. num_features (int): `C` from an expected input of size (N, C, H, W).
  238. eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
  239. momentum (float): A floating hyperparameter of the momentum for the
  240. running_mean and running_var computation. Default: 0.9.
  241. affine (bool): A bool value when set to True, gamma and beta can be learnable. Default: True.
  242. gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
  243. The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
  244. 'he_uniform', etc. Default: 'ones'.
  245. beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
  246. The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
  247. 'he_uniform', etc. Default: 'zeros'.
  248. moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
  249. The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
  250. 'he_uniform', etc. Default: 'zeros'.
  251. moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
  252. The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
  253. 'he_uniform', etc. Default: 'ones'.
  254. use_batch_statistics (bool): If true, use the mean value and variance value of current batch data, else use
  255. the mean value and variance value of specified value. Default: True.
  256. Inputs:
  257. - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
  258. Outputs:
  259. Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
  260. Examples:
  261. >>> net = nn.BatchNorm2d(num_features=3)
  262. >>> input = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32)
  263. >>> net(input)
  264. """
  265. def __init__(self,
  266. num_features,
  267. eps=1e-5,
  268. momentum=0.9,
  269. affine=True,
  270. gamma_init='ones',
  271. beta_init='zeros',
  272. moving_mean_init='zeros',
  273. moving_var_init='ones',
  274. use_batch_statistics=True):
  275. super(BatchNorm2d, self).__init__(num_features,
  276. eps,
  277. momentum,
  278. affine,
  279. gamma_init,
  280. beta_init,
  281. moving_mean_init,
  282. moving_var_init,
  283. use_batch_statistics)
  284. def _check_data_dim(self, x):
  285. if x.dim() != 4:
  286. pass
  287. class GlobalBatchNorm(_BatchNorm):
  288. r"""
  289. Global normalization layer over a N-dimension input.
  290. Global Normalization is cross device synchronized batch normalization. Batch Normalization implementation
  291. only normalize the data within each device. Global normalization will normalize the input within the group.
  292. It has been described in the paper `Batch Normalization: Accelerating Deep Network Training by
  293. Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It rescales and recenters the
  294. feature using a mini-batch of data and the learned parameters which can be described in the following formula.
  295. .. math::
  296. y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
  297. Args:
  298. num_features (int): `C` from an expected input of size (N, C, H, W).
  299. device_num_each_group (int): The number of devices in each group.
  300. eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
  301. momentum (float): A floating hyperparameter of the momentum for the
  302. running_mean and running_var computation. Default: 0.9.
  303. gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
  304. The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
  305. 'he_uniform', etc. Default: 'ones'.
  306. beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
  307. The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
  308. 'he_uniform', etc. Default: 'zeros'.
  309. moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
  310. The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
  311. 'he_uniform', etc. Default: 'zeros'.
  312. moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
  313. The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
  314. 'he_uniform', etc. Default: 'ones'.
  315. use_batch_statistics (bool): If true, use the mean value and variance value of current batch data, else use
  316. the mean value and variance value of specified value. Default: True.
  317. Inputs:
  318. - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
  319. Outputs:
  320. Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
  321. Examples:
  322. >>> global_bn_op = nn.GlobalBatchNorm(num_features=3, device_num_each_group=4)
  323. >>> input = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32)
  324. >>> global_bn_op(input)
  325. """
  326. def __init__(self,
  327. num_features,
  328. eps=1e-5,
  329. momentum=0.9,
  330. affine=True,
  331. gamma_init='ones',
  332. beta_init='zeros',
  333. moving_mean_init='zeros',
  334. moving_var_init='ones',
  335. use_batch_statistics=True,
  336. device_num_each_group=1):
  337. super(GlobalBatchNorm, self).__init__(num_features,
  338. eps,
  339. momentum,
  340. affine,
  341. gamma_init,
  342. beta_init,
  343. moving_mean_init,
  344. moving_var_init,
  345. use_batch_statistics,
  346. device_num_each_group)
  347. self.group = check_int_positive(device_num_each_group)
  348. if self.group <= 1:
  349. raise ValueError("the number of group must be greater than 1.")
  350. def _check_data_dim(self, x):
  351. if x.dim == 0:
  352. pass
  353. class LayerNorm(Cell):
  354. r"""
  355. Applies Layer Normalization over a mini-batch of inputs.
  356. Layer normalization is widely used in recurrent neural networks. It applies
  357. normalization over a mini-batch of inputs for each single training case as described
  358. in the paper `Layer Normalization <https://arxiv.org/pdf/1607.06450.pdf>`_. Unlike batch
  359. normalization, layer normalization performs exactly the same computation at training and
  360. testing times. It can be described using the following formula. It is applied across all channels
  361. and pixel but only one batch size.
  362. .. math::
  363. y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
  364. Args:
  365. normalized_shape (Union(tuple[int], list[int]): The normalization is performed over axis
  366. `begin_norm_axis ... R - 1`.
  367. begin_norm_axis (int): It first normalization dimension: normalization will be performed along dimensions
  368. `begin_norm_axis: rank(inputs)`, the value should be in [-1, rank(input)). Default: -1.
  369. begin_params_axis (int): The first parameter(beta, gamma)dimension: scale and centering parameters
  370. will have dimensions `begin_params_axis: rank(inputs)` and will be broadcast with
  371. the normalized inputs accordingly, the value should be in [-1, rank(input)). Default: -1.
  372. gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
  373. The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
  374. 'he_uniform', etc. Default: 'ones'.
  375. beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
  376. The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
  377. 'he_uniform', etc. Default: 'zeros'.
  378. Inputs:
  379. - **input_x** (Tensor) - The shape of 'input_x' is :math:`(x_1, x_2, ..., x_R)`,
  380. and `input_shape[begin_norm_axis:]` is equal to `normalized_shape`.
  381. Outputs:
  382. Tensor, the normalized and scaled offset tensor, has the same shape and data type as the `input_x`.
  383. Examples:
  384. >>> x = Tensor(np.ones([20, 5, 10, 10]), mindspore.float32)
  385. >>> shape1 = x.shape()[1:]
  386. >>> m = nn.LayerNorm(shape1, begin_norm_axis=1, begin_params_axis=1)
  387. >>> m(x)
  388. """
  389. def __init__(self,
  390. normalized_shape,
  391. begin_norm_axis=-1,
  392. begin_params_axis=-1,
  393. gamma_init='ones',
  394. beta_init='zeros',
  395. ):
  396. super(LayerNorm, self).__init__()
  397. self.normalized_shape = normalized_shape
  398. self.begin_norm_axis = begin_norm_axis
  399. self.begin_params_axis = begin_params_axis
  400. self.gamma = Parameter(initializer(
  401. gamma_init, normalized_shape), name="gamma")
  402. self.beta = Parameter(initializer(
  403. beta_init, normalized_shape), name="beta")
  404. self.layer_norm = P.LayerNorm(begin_norm_axis=self.begin_norm_axis, begin_params_axis=self.begin_params_axis)
  405. def construct(self, input_x):
  406. y, _, _ = self.layer_norm(input_x, self.gamma, self.beta)
  407. return y
  408. def extend_repr(self):
  409. """Display instance object as string."""
  410. s = 'normalized_shape={}, begin_norm_axis={}, begin_params_axis={}, gamma{}, beta={}'.format(
  411. self.normalized_shape, self.begin_norm_axis, self.begin_params_axis, self.gamma, self.beta)
  412. return s
  413. class GroupNorm(Cell):
  414. r"""
  415. Group Normalization over a mini-batch of inputs.
  416. Group normalization is widely used in recurrent neural networks. It applies
  417. normalization over a mini-batch of inputs for each single training case as described
  418. in the paper `Group Normalization <https://arxiv.org/pdf/1803.08494.pdf>`_. Group normalization
  419. divides the channels into groups and computes within each group the mean and variance for normalization,
  420. and it performs very stable over a wide range of batch size. It can be described using the following formula.
  421. .. math::
  422. y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
  423. Args:
  424. num_groups (int): The number of groups to be divided along the channel dimension.
  425. num_channels (int): The number of channels per group.
  426. eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
  427. affine (bool): A bool value, this layer will has learnable affine parameters when set to true. Default: True.
  428. Inputs:
  429. - **input_x** (Tensor) - The input feature with shape [N, C, H, W].
  430. Outputs:
  431. Tensor, the normalized and scaled offset tensor, has the same shape and data type as the `input_x`.
  432. Examples:
  433. >>> goup_norm_op = nn.GroupNorm(16, 64)
  434. >>> x = Tensor(np.ones([1, 64, 256, 256], np.float32))
  435. >>> goup_norm_op(x)
  436. """
  437. def __init__(self, num_groups, num_channels, eps=1e-05, affine=True, gamma_init='ones', beta_init='zeros'):
  438. super(GroupNorm, self).__init__()
  439. self.num_groups = check_int_positive(num_groups)
  440. self.num_channels = check_int_positive(num_channels)
  441. if num_channels % num_groups != 0:
  442. raise ValueError("num_channels should be divided by num_groups")
  443. self.eps = check_typename('eps', eps, (float,))
  444. self.affine = check_bool(affine)
  445. gamma = initializer(gamma_init, [num_channels, 1, 1])
  446. beta = initializer(beta_init, [num_channels, 1, 1])
  447. if self.affine:
  448. self.gamma = Parameter(gamma, name='gamma')
  449. self.beta = Parameter(beta, name='beta')
  450. else:
  451. self.gamma = gamma
  452. self.beta = beta
  453. self.shape = F.shape
  454. self.reshape = F.reshape
  455. self.reduce_mean = P.ReduceMean(keep_dims=True)
  456. self.square = F.square
  457. self.reduce_sum = P.ReduceSum(keep_dims=True)
  458. self.sqrt = P.Sqrt()
  459. def construct(self, x):
  460. batch, channel, height, width = self.shape(x)
  461. _channel_check(channel, self.num_channels)
  462. x = self.reshape(x, (batch, self.num_groups, channel*height*width/self.num_groups))
  463. mean = self.reduce_mean(x, 2)
  464. var = self.reduce_sum(self.square(x - mean), 2) / (channel * height * width / self.num_groups - 1)
  465. std = self.sqrt(var + self.eps)
  466. x = (x - mean) / std
  467. x = self.reshape(x, (batch, channel, height, width))
  468. output = x * self.gamma + self.beta
  469. return output
  470. def extend_repr(self):
  471. """Display instance object as string."""
  472. s = 'num_groups={}, num_channels={}'.format(self.num_groups, self.num_channels)
  473. return s