You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

normalization.py 26 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """normalization"""
  16. from mindspore.ops import operations as P
  17. from mindspore.ops import functional as F
  18. from mindspore.common.parameter import Parameter
  19. from mindspore.common.initializer import initializer
  20. from mindspore.common.tensor import Tensor
  21. import mindspore.common.dtype as mstype
  22. import mindspore.context as context
  23. from mindspore._checkparam import check_bool, check_typename
  24. from mindspore._extends import cell_attr_register
  25. from mindspore.communication.management import get_group_size, get_rank
  26. from mindspore.communication import management
  27. from mindspore._checkparam import check_int_positive
  28. from ..cell import Cell
  29. class _BatchNorm(Cell):
  30. """Batch Normalization base class."""
  31. @cell_attr_register
  32. def __init__(self,
  33. num_features,
  34. eps=1e-5,
  35. momentum=0.9,
  36. affine=True,
  37. gamma_init='ones',
  38. beta_init='zeros',
  39. moving_mean_init='zeros',
  40. moving_var_init='ones',
  41. use_batch_statistics=True,
  42. group=1):
  43. super(_BatchNorm, self).__init__()
  44. if num_features < 1:
  45. raise ValueError("num_features must be at least 1")
  46. if momentum < 0 or momentum > 1:
  47. raise ValueError("momentum should be a number in range [0, 1], but got {}".format(momentum))
  48. self.use_batch_statistics = use_batch_statistics
  49. self.num_features = num_features
  50. self.eps = eps
  51. self.moving_mean = Parameter(initializer(
  52. moving_mean_init, num_features), name="mean", requires_grad=False)
  53. self.moving_variance = Parameter(initializer(
  54. moving_var_init, num_features), name="variance", requires_grad=False)
  55. self.gamma = Parameter(initializer(
  56. gamma_init, num_features), name="gamma", requires_grad=affine)
  57. self.beta = Parameter(initializer(
  58. beta_init, num_features), name="beta", requires_grad=affine)
  59. self.group = check_int_positive(group)
  60. if self.group != 1:
  61. self.rank_id = get_rank()
  62. self.rank_size = get_group_size()
  63. self.device_list = [i for i in range(0, self.rank_size)]
  64. self.rank_list = self.list_group(self.device_list, self.group)
  65. self.rank_list_idx = len(self.rank_list)
  66. for i in range(self.rank_list_idx):
  67. if self.rank_id in self.rank_list[i] and self.group != 1:
  68. self.is_global = True
  69. management.create_group('group' + str(i), self.rank_list[i])
  70. self.all_reduce = P.AllReduce(P.ReduceOp.SUM, 'group' + str(i)).add_prim_attr('fusion', 1)
  71. self.shape = P.Shape()
  72. self.reduce_mean = P.ReduceMean()
  73. self.square = P.Square()
  74. if context.get_context("enable_ge"):
  75. self.is_ge_backend = True
  76. self.momentum = Tensor(1.0 - momentum, mstype.float32)
  77. self.bn_train = P.BatchNorm(is_training=True,
  78. epsilon=self.eps)
  79. else:
  80. self.is_ge_backend = False
  81. self.momentum = 1.0 - momentum
  82. self.bn_train = P.FusedBatchNorm(mode=1,
  83. epsilon=self.eps,
  84. momentum=self.momentum)
  85. self.bn_infer = P.BatchNorm(is_training=False, epsilon=self.eps)
  86. data_parallel_strategy = ((1,), (1,))
  87. data_parallel_strategy_one = ((1,), ())
  88. self.sub_mean = P.Sub().set_strategy(data_parallel_strategy)
  89. self.sub_var = P.Sub().set_strategy(data_parallel_strategy)
  90. self.mul_mean = P.Mul().set_strategy(data_parallel_strategy_one)
  91. self.mul_var = P.Mul().set_strategy(data_parallel_strategy_one)
  92. self.assign_sub_mean = P.AssignSub().set_strategy(data_parallel_strategy)
  93. self.assign_sub_var = P.AssignSub().set_strategy(data_parallel_strategy)
  94. def _check_data_dim(self, x):
  95. raise NotImplementedError
  96. def list_group(self, world_rank, group_size):
  97. if group_size > get_group_size():
  98. raise ValueError("group size can not be greater than local rank size, group size is {}, "
  99. "local_rank_size is {}".format(group_size, get_group_size()))
  100. if len(world_rank) % group_size != 0:
  101. raise ValueError("please make your group size correct.")
  102. world_rank_list = zip(*(iter(world_rank),) *group_size)
  103. group_list = [list(i) for i in world_rank_list]
  104. return group_list
  105. def construct(self, x):
  106. if self.training and self.use_batch_statistics:
  107. if self.is_ge_backend:
  108. if self.is_global:
  109. x_mean = self.reduce_mean(x)
  110. x_mean_square = self.reduce_mean(self.square(x))
  111. global_batch_mean = self.all_reduce(x_mean) / self.group
  112. global_batch_mean_square = self.all_reduce(x_mean_square) / self.group
  113. global_mean = global_batch_mean
  114. global_var = global_batch_mean_square - self.square(global_batch_mean)
  115. y, batch_mean, batch_var, _, _ = \
  116. self.bn_train(x,
  117. self.gamma,
  118. self.beta,
  119. None,
  120. None)
  121. mean_sub = self.sub_mean(self.moving_mean, global_mean)
  122. temp_mean = self.mul_mean(mean_sub, self.momentum)
  123. mean_sub2 = self.sub_var(self.moving_variance, global_var)
  124. temp_variance = self.mul_var(mean_sub2, self.momentum)
  125. y = F.depend(y, self.assign_sub_mean(self.moving_mean, temp_mean))
  126. y = F.depend(y, self.assign_sub_var(self.moving_variance, temp_variance))
  127. else:
  128. y, batch_mean, batch_var, _, _ = \
  129. self.bn_train(x,
  130. self.gamma,
  131. self.beta,
  132. None,
  133. None)
  134. mean_sub = self.sub_mean(self.moving_mean, batch_mean)
  135. temp_mean = self.mul_mean(mean_sub, self.momentum)
  136. mean_sub2 = self.sub_var(self.moving_variance, batch_var)
  137. temp_variance = self.mul_var(mean_sub2, self.momentum)
  138. y = F.depend(y, self.assign_sub_mean(self.moving_mean, temp_mean))
  139. y = F.depend(y, self.assign_sub_var(self.moving_variance, temp_variance))
  140. else:
  141. y = self.bn_train(x,
  142. self.gamma,
  143. self.beta,
  144. self.moving_mean,
  145. self.moving_variance)[0]
  146. else:
  147. y = self.bn_infer(x,
  148. self.gamma,
  149. self.beta,
  150. self.moving_mean,
  151. self.moving_variance)[0]
  152. return y
  153. def extend_repr(self):
  154. return 'num_features={}, eps={}, momentum={}, gamma={}, beta={}, moving_mean={}, moving_variance={}'.format(
  155. self.num_features, self.eps, self.momentum, self.gamma, self.beta, self.moving_mean, self.moving_variance)
  156. class BatchNorm1d(_BatchNorm):
  157. r"""
  158. Batch normalization layer over a 2D input.
  159. Batch Normalization is widely used in convolutional networks. This layer
  160. applies Batch Normalization over a 2D input (a mini-batch of 1D inputs) to
  161. reduce internal covariate shift as described in the paper
  162. `Batch Normalization: Accelerating Deep Network Training by
  163. Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It
  164. rescales and recenters the feature using a mini-batch of data and
  165. the learned parameters which can be described in the following formula.
  166. .. math::
  167. y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
  168. Args:
  169. num_features (int): `C` from an expected input of size (N, C).
  170. eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
  171. momentum (float): A floating hyperparameter of the momentum for the
  172. running_mean and running_var computation. Default: 0.9.
  173. affine (bool): A bool value when set to True, gamma and beta can be learnable. Default: True.
  174. gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
  175. The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
  176. 'he_uniform', etc. Default: 'ones'.
  177. beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
  178. The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
  179. 'he_uniform', etc. Default: 'zeros'.
  180. moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
  181. The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
  182. 'he_uniform', etc. Default: 'zeros'.
  183. moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
  184. The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
  185. 'he_uniform', etc. Default: 'ones'.
  186. use_batch_statistics (bool): If true, use the mean value and variance value of current batch data, else use
  187. the mean value and variance value of specified value. Default: True.
  188. Inputs:
  189. - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
  190. Outputs:
  191. Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
  192. Examples:
  193. >>> net = nn.BatchNorm1d(num_features=16)
  194. >>> input = Tensor(np.random.randint(0, 255, [3, 16]), mindspore.float32)
  195. >>> net(input)
  196. """
  197. def __init__(self,
  198. num_features,
  199. eps=1e-5,
  200. momentum=0.9,
  201. affine=True,
  202. gamma_init='ones',
  203. beta_init='zeros',
  204. moving_mean_init='zeros',
  205. moving_var_init='ones',
  206. use_batch_statistics=True):
  207. super(BatchNorm1d, self).__init__(num_features,
  208. eps,
  209. momentum,
  210. affine,
  211. gamma_init,
  212. beta_init,
  213. moving_mean_init,
  214. moving_var_init,
  215. use_batch_statistics)
  216. def _check_data_dim(self, x):
  217. if x.dim() != 2:
  218. pass
  219. class BatchNorm2d(_BatchNorm):
  220. r"""
  221. Batch normalization layer over a 4D input.
  222. Batch Normalization is widely used in convolutional networks. This layer
  223. applies Batch Normalization over a 4D input (a mini-batch of 2D inputs with
  224. additional channel dimension) to avoid internal covariate shift as described
  225. in the paper `Batch Normalization: Accelerating Deep Network Training by
  226. Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It
  227. rescales and recenters the feature using a mini-batch of data and
  228. the learned parameters which can be described in the following formula.
  229. .. math::
  230. y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
  231. Args:
  232. num_features (int): `C` from an expected input of size (N, C, H, W).
  233. eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
  234. momentum (float): A floating hyperparameter of the momentum for the
  235. running_mean and running_var computation. Default: 0.9.
  236. affine (bool): A bool value when set to True, gamma and beta can be learnable. Default: True.
  237. gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
  238. The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
  239. 'he_uniform', etc. Default: 'ones'.
  240. beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
  241. The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
  242. 'he_uniform', etc. Default: 'zeros'.
  243. moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
  244. The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
  245. 'he_uniform', etc. Default: 'zeros'.
  246. moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
  247. The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
  248. 'he_uniform', etc. Default: 'ones'.
  249. use_batch_statistics (bool): If true, use the mean value and variance value of current batch data, else use
  250. the mean value and variance value of specified value. Default: True.
  251. Inputs:
  252. - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
  253. Outputs:
  254. Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
  255. Examples:
  256. >>> net = nn.BatchNorm2d(num_features=3)
  257. >>> input = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32)
  258. >>> net(input)
  259. """
  260. def __init__(self,
  261. num_features,
  262. eps=1e-5,
  263. momentum=0.9,
  264. affine=True,
  265. gamma_init='ones',
  266. beta_init='zeros',
  267. moving_mean_init='zeros',
  268. moving_var_init='ones',
  269. use_batch_statistics=True):
  270. super(BatchNorm2d, self).__init__(num_features,
  271. eps,
  272. momentum,
  273. affine,
  274. gamma_init,
  275. beta_init,
  276. moving_mean_init,
  277. moving_var_init,
  278. use_batch_statistics)
  279. def _check_data_dim(self, x):
  280. if x.dim() != 4:
  281. pass
  282. class GlobalBatchNorm(_BatchNorm):
  283. r"""
  284. Global normalization layer over a N-dimension input.
  285. Global Normalization is cross device synchronized batch normalization. Batch Normalization implementation
  286. only normalize the data within each device. Global normalization will normalize the input within the group.
  287. It has been described in the paper `Batch Normalization: Accelerating Deep Network Training by
  288. Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It rescales and recenters the
  289. feature using a mini-batch of data and the learned parameters which can be described in the following formula.
  290. .. math::
  291. y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
  292. Args:
  293. num_features (int): `C` from an expected input of size (N, C, H, W).
  294. group (int): The number of device in each group.
  295. eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
  296. momentum (float): A floating hyperparameter of the momentum for the
  297. running_mean and running_var computation. Default: 0.9.
  298. gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
  299. The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
  300. 'he_uniform', etc. Default: 'ones'.
  301. beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
  302. The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
  303. 'he_uniform', etc. Default: 'zeros'.
  304. moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
  305. The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
  306. 'he_uniform', etc. Default: 'zeros'.
  307. moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
  308. The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
  309. 'he_uniform', etc. Default: 'ones'.
  310. use_batch_statistics (bool): If true, use the mean value and variance value of current batch data, else use
  311. the mean value and variance value of specified value. Default: True.
  312. Inputs:
  313. - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
  314. Outputs:
  315. Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
  316. Examples:
  317. >>> global_bn_op = nn.GlobalBatchNorm(num_features=3, group=4)
  318. >>> input = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32)
  319. >>> global_bn_op(input)
  320. """
  321. def __init__(self,
  322. num_features,
  323. eps=1e-5,
  324. momentum=0.9,
  325. affine=True,
  326. gamma_init='ones',
  327. beta_init='zeros',
  328. moving_mean_init='zeros',
  329. moving_var_init='ones',
  330. use_batch_statistics=True,
  331. group=1):
  332. super(GlobalBatchNorm, self).__init__(num_features,
  333. eps,
  334. momentum,
  335. affine,
  336. gamma_init,
  337. beta_init,
  338. moving_mean_init,
  339. moving_var_init,
  340. use_batch_statistics,
  341. group)
  342. self.group = check_int_positive(group)
  343. if self.group <= 1:
  344. raise ValueError("the number of group must be greater than 1.")
  345. def _check_data_dim(self, x):
  346. if x.dim == 0:
  347. pass
  348. class LayerNorm(Cell):
  349. r"""
  350. Applies Layer Normalization over a mini-batch of inputs.
  351. Layer normalization is widely used in recurrent neural networks. It applies
  352. normalization over a mini-batch of inputs for each single training case as described
  353. in the paper `Layer Normalization <https://arxiv.org/pdf/1607.06450.pdf>`_. Unlike batch
  354. normalization, layer normalization performs exactly the same computation at training and
  355. testing times. It can be described using the following formula. It is applied across all channels
  356. and pixel but only one batch size.
  357. .. math::
  358. y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
  359. Args:
  360. normalized_shape (Union(tuple[int], list[int]): The normalization is performed over axes
  361. `begin_norm_axis ... R - 1` and centering and scaling parameters are calculated over
  362. `begin_params_axis ... R - 1`.
  363. begin_norm_axis (int): It first normalization dimension: normalization will be performed along dimensions
  364. `begin_norm_axis: rank(inputs)`, the value should be in [-1, rank(input)). Default: -1.
  365. begin_params_axis (int): The first parameter(beta, gamma)dimension: scale and centering parameters
  366. will have dimensions `begin_params_axis: rank(inputs)` and will be broadcast with
  367. the normalized inputs accordingly, the value should be in [-1, rank(input)). Default: -1.
  368. gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
  369. The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
  370. 'he_uniform', etc. Default: 'ones'.
  371. beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
  372. The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
  373. 'he_uniform', etc. Default: 'zeros'.
  374. Inputs:
  375. - **input_x** (Tensor) - The shape of 'input_x' is :math:`(x_1, x_2, ..., x_R)`,
  376. and `input_shape[begin_norm_axis:]` is equal to `normalized_shape`.
  377. Outputs:
  378. Tensor, the normalized and scaled offset tensor, has the same shape and data type as the `input_x`.
  379. Examples:
  380. >>> x = Tensor(np.ones([20, 5, 10, 10]), mindspore.float32)
  381. >>> shape1 = x.shape()[1:]
  382. >>> m = nn.LayerNorm(shape1, begin_norm_axis=1, begin_params_axis=1)
  383. >>> m(x)
  384. """
  385. def __init__(self,
  386. normalized_shape,
  387. begin_norm_axis=-1,
  388. begin_params_axis=-1,
  389. gamma_init='ones',
  390. beta_init='zeros',
  391. ):
  392. super(LayerNorm, self).__init__()
  393. self.normalized_shape = normalized_shape
  394. self.begin_norm_axis = begin_norm_axis
  395. self.begin_params_axis = begin_params_axis
  396. self.gamma = Parameter(initializer(
  397. gamma_init, normalized_shape), name="gamma")
  398. self.beta = Parameter(initializer(
  399. beta_init, normalized_shape), name="beta")
  400. self.layer_norm = P.LayerNorm(begin_norm_axis=self.begin_norm_axis, begin_params_axis=self.begin_params_axis)
  401. def construct(self, input_x):
  402. y, _, _ = self.layer_norm(input_x, self.gamma, self.beta)
  403. return y
  404. def extend_repr(self):
  405. """Display instance object as string."""
  406. s = 'normalized_shape={}, begin_norm_axis={}, begin_params_axis={}, gamma{}, beta={}'.format(
  407. self.normalized_shape, self.begin_norm_axis, self.begin_params_axis, self.gamma, self.beta)
  408. return s
  409. class GroupNorm(Cell):
  410. r"""
  411. Group Normalization over a mini-batch of inputs.
  412. Group normalization is widely used in recurrent neural networks. It applies
  413. normalization over a mini-batch of inputs for each single training case as described
  414. in the paper `Group Normalization <https://arxiv.org/pdf/1803.08494.pdf>`_. Group normalization
  415. divides the channels into groups and computes within each group the mean and variance for normalization,
  416. and it performs very stable over a wide range of batch size. It can be described using the following formula.
  417. .. math::
  418. y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
  419. Args:
  420. num_groups (int): The number of groups to be divided along the channel dimension.
  421. num_channels (int): The number of channels per group.
  422. eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
  423. affine (bool): A bool value, this layer will has learnable affine parameters when set to true. Default: True.
  424. Inputs:
  425. - **input_x** (Tensor) - The input feature with shape [N, C, H, W].
  426. Outputs:
  427. Tensor, the normalized and scaled offset tensor, has the same shape and data type as the `input_x`.
  428. Examples:
  429. >>> goup_norm_op = nn.GroupNorm(16, 64)
  430. >>> x = Tensor(np.ones([1, 64, 256, 256], np.float32))
  431. >>> goup_norm_op(x)
  432. """
  433. def __init__(self, num_groups, num_channels, eps=1e-05, affine=True):
  434. super(GroupNorm, self).__init__()
  435. self.num_groups = check_int_positive(num_groups)
  436. self.num_channels = check_int_positive(num_channels)
  437. if num_channels % num_groups != 0:
  438. raise ValueError("num_channels should be divided by num_groups")
  439. self.eps = Tensor(check_typename('eps', eps, (float,)), mstype.float32)
  440. self.affine = check_bool(affine)
  441. gamma = initializer('ones', [num_channels, 1, 1], mstype.float32)
  442. beta = initializer('zeros', [num_channels, 1, 1], mstype.float32)
  443. if self.affine:
  444. self.gamma = Parameter(gamma, name='gamma')
  445. self.beta = Parameter(beta, name='beta')
  446. else:
  447. self.gamma = gamma
  448. self.beta = beta
  449. self.shape = F.shape
  450. self.reshape = F.reshape
  451. self.reduce_mean = P.ReduceMean(keep_dims=True)
  452. self.square = F.square
  453. self.reduce_sum = P.ReduceSum(keep_dims=True)
  454. self.sqrt = P.Sqrt()
  455. def construct(self, x):
  456. batch, channel, height, width = self.shape(x)
  457. x = self.reshape(x, (batch, self.num_groups, channel*height*width/self.num_groups))
  458. mean = self.reduce_mean(x, 2)
  459. var = self.reduce_sum(self.square(x - mean), 2) / (channel * height * width / self.num_groups - 1)
  460. std = self.sqrt(var + self.eps)
  461. x = (x - mean) / std
  462. x = self.reshape(x, (batch, channel, height, width))
  463. output = x * self.gamma + self.beta
  464. return output
  465. def extend_repr(self):
  466. """Display instance object as string."""
  467. s = 'num_groups={}, num_channels={}'.format(self.num_groups, self.num_channels)
  468. return s