You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

normalization.py 55 kB

4 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """normalization"""
  16. import itertools
  17. import numbers
  18. from mindspore.ops import operations as P
  19. from mindspore.ops import functional as F
  20. from mindspore.ops.operations import _inner_ops as inner
  21. from mindspore.common.parameter import Parameter
  22. from mindspore.common.initializer import initializer, Initializer
  23. from mindspore.common.tensor import Tensor
  24. from mindspore.common._decorator import deprecated
  25. from mindspore.ops.primitive import constexpr
  26. import mindspore.context as context
  27. from mindspore._checkparam import Rel
  28. from mindspore._checkparam import Validator as validator
  29. from mindspore._extends import cell_attr_register
  30. from mindspore.communication.management import get_group_size, get_rank
  31. from mindspore.communication import management
  32. from mindspore.common import dtype as mstype
  33. from mindspore.parallel._utils import _is_in_auto_parallel_mode
  34. from ..cell import Cell
# Public API of this module.
__all__ = ['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'LayerNorm', 'GroupNorm',
           'GlobalBatchNorm', 'SyncBatchNorm', 'InstanceNorm2d']
# Name of the communication group used for synchronized batch normalization.
# Set once, by the first _BatchNorm instance that creates a group, and then
# shared by every later instance in this process (see _BatchNorm.__init__).
SYNC_BN_GROUP_NAME = ""
class _BatchNorm(Cell):
    """Batch Normalization base class.

    Holds the learnable affine parameters (gamma/beta), the running statistics
    (moving_mean/moving_variance) and, when configured for GlobalBatchNorm or
    SyncBatchNorm, the cross-device communication group used to synchronize
    the batch statistics.
    """

    @cell_attr_register
    def __init__(self,
                 num_features,
                 eps=1e-5,
                 momentum=0.9,
                 affine=True,
                 gamma_init='ones',
                 beta_init='zeros',
                 moving_mean_init='zeros',
                 moving_var_init='ones',
                 use_batch_statistics=None,
                 device_num_each_group=1,
                 process_groups=0,
                 input_dims='2d',
                 data_format='NCHW'):
        super(_BatchNorm, self).__init__()
        validator.check_value_type('num_features', num_features, [int], self.cls_name)
        if num_features < 1:
            raise ValueError("num_features must be at least 1")

        if momentum < 0 or momentum > 1:
            raise ValueError("momentum should be a number in range [0, 1], but got {}".format(momentum))
        self.input_dims = input_dims
        self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name)
        # NHWC layout is only implemented by the GPU kernels.
        if context.get_context("device_target") != "GPU" and self.format == "NHWC":
            raise ValueError("NHWC format only support in GPU target.")
        self.use_batch_statistics = use_batch_statistics
        if self.use_batch_statistics is not None and not isinstance(self.use_batch_statistics, bool):
            raise ValueError("use_batch_statistics should be a boolean value or None.")
        self.num_features = num_features
        self.eps = eps
        # Running statistics are buffers, not trainable parameters.
        self.moving_mean = Parameter(initializer(
            moving_mean_init, num_features), name="mean", requires_grad=False)
        self.moving_variance = Parameter(initializer(
            moving_var_init, num_features), name="variance", requires_grad=False)
        # gamma/beta participate in training only when affine=True.
        self.gamma = Parameter(initializer(
            gamma_init, num_features), name="gamma", requires_grad=affine)
        self.beta = Parameter(initializer(
            beta_init, num_features), name="beta", requires_grad=affine)
        self.group_device_num = validator.check_positive_int(device_num_each_group)
        self.process_groups = process_groups
        self.is_global = False
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        # Module-level global: the first instance that creates a group names it,
        # and every later instance reuses that same group name.
        global SYNC_BN_GROUP_NAME
        # for GlobalBatchNorm
        if self.group_device_num != 1:
            self.rank_id = get_rank()
            self.rank_size = get_group_size()
            self.device_list = [i for i in range(0, self.rank_size)]
            # Partition all ranks into consecutive chunks of group_device_num.
            self.rank_list = self.list_group(self.device_list, self.group_device_num)
            self.rank_list_idx = len(self.rank_list)
            for i in range(self.rank_list_idx):
                if self.rank_id in self.rank_list[i]:
                    self.is_global = True
                    if SYNC_BN_GROUP_NAME == "":
                        SYNC_BN_GROUP_NAME = "sync_bn_group" + str(i)
                        management.create_group(SYNC_BN_GROUP_NAME, self.rank_list[i])
        # for SyncBatchNorm
        if self.process_groups != 0:
            self.rank_id = get_rank()
            self.rank_size = get_group_size()
            if self.process_groups is not None:
                # Explicit list of rank groups supplied by the caller.
                validator.check_isinstance("process_groups", self.process_groups, list)
                self._check_rank_ids(self.process_groups, self.rank_size)
                for i in range(len(self.process_groups)):
                    validator.check_isinstance("process_groups[" + str(i) + "]", self.process_groups[i], list)
                    self.group_device_num = len(self.process_groups[i])
                    if self.rank_id in self.process_groups[i] and self.group_device_num > 1:
                        self.is_global = True
                        if SYNC_BN_GROUP_NAME == "":
                            SYNC_BN_GROUP_NAME = "sync_bn_group" + str(i)
                            management.create_group(SYNC_BN_GROUP_NAME, self.process_groups[i])
            elif self.rank_size > 1:
                # process_groups=None: synchronize across all visible ranks.
                self.is_global = True
                self.group_device_num = self.rank_size
                self.device_list = [i for i in range(0, self.rank_size)]
                if SYNC_BN_GROUP_NAME == "":
                    SYNC_BN_GROUP_NAME = "sync_bn_group0"
                    management.create_group(SYNC_BN_GROUP_NAME, self.device_list)
        self.shape = P.Shape()
        self.reduce_mean = P.ReduceMean(keep_dims=True)
        self.square = P.Square()
        self.sqrt = P.Sqrt()
        self.cast = P.Cast()
        self.dtype = P.DType()
        self.reshape = P.Reshape()
        self._target = context.get_context("device_target")
        self.is_graph_mode = context.get_context("mode") == context.GRAPH_MODE
        # The backend BatchNorm primitive uses the complementary momentum
        # convention, hence the 1 - momentum conversion here.
        self.momentum = 1.0 - momentum
        if context.get_context("enable_ge"):
            self.is_ge_backend = True
        else:
            self.is_ge_backend = False
        self.bn_train = P.BatchNorm(is_training=True,
                                    epsilon=self.eps,
                                    momentum=self.momentum,
                                    data_format=self.format)
        if self.is_global:
            # Replace the per-device kernel with the cross-device synchronized one.
            self.bn_train = inner.SyncBatchNorm(epsilon=self.eps,
                                                momentum=self.momentum,
                                                group=SYNC_BN_GROUP_NAME,
                                                device_num=self.group_device_num)
        self.bn_infer = P.BatchNorm(is_training=False, epsilon=self.eps, data_format=self.format)
        if _is_in_auto_parallel_mode():
            data_parallel_strategy = ((1,), (1,))
            data_parallel_strategy_one = ((1,), ())
        else:
            # None lets the framework pick the default shard strategy.
            data_parallel_strategy = None
            data_parallel_strategy_one = None
        self.sub_mean = P.Sub().shard(data_parallel_strategy)
        self.sub_var = P.Sub().shard(data_parallel_strategy)
        self.mul_mean = P.Mul().shard(data_parallel_strategy_one)
        self.mul_var = P.Mul().shard(data_parallel_strategy_one)
        self.assign_sub_mean = P.AssignSub().shard(data_parallel_strategy)
        self.assign_sub_var = P.AssignSub().shard(data_parallel_strategy)

    def _check_data_dim(self, x):
        """Subclasses override this to validate the input rank."""
        raise NotImplementedError

    def list_group(self, world_rank, group_size):
        """Split `world_rank` into consecutive, non-overlapping chunks of `group_size` ranks."""
        if group_size > get_group_size():
            raise ValueError("group size can not be greater than local rank size, group size is {}, "
                             "local_rank_size is {}".format(group_size, get_group_size()))
        if len(world_rank) % group_size != 0:
            raise ValueError("please make your group size correct.")
        # zip over the same iterator object yields non-overlapping chunks.
        world_rank_list = zip(*(iter(world_rank),) * group_size)
        group_list = [list(i) for i in world_rank_list]
        return group_list

    def _check_rank_ids(self, process_groups, rank_size):
        """Validate that every rank id is in [0, rank_size) and appears at most once."""
        seen = set()
        for rid in itertools.chain(*process_groups):
            validator.check_int_range(rid, 0, rank_size, Rel.INC_LEFT, "rank id in process_groups")
            if rid in seen:
                raise ValueError("rank id in process_groups should not be duplicated.")
            seen.add(rid)

    def construct(self, x):
        """Apply batch normalization, selecting the train/infer kernel.

        With use_batch_statistics=None the cell's `training` flag decides;
        otherwise the explicit boolean wins regardless of mode.
        """
        _shape_check_bn(self.shape(x), self.input_dims)
        if self.use_batch_statistics is None:
            if self.training:
                return self.bn_train(x,
                                     self.gamma,
                                     self.beta,
                                     self.moving_mean,
                                     self.moving_variance)[0]
            if not self.training:
                return self.bn_infer(x,
                                     self.gamma,
                                     self.beta,
                                     self.moving_mean,
                                     self.moving_variance)[0]

        if self.use_batch_statistics is True:
            return self.bn_train(x,
                                 self.gamma,
                                 self.beta,
                                 self.moving_mean,
                                 self.moving_variance)[0]
        return self.bn_infer(x,
                             self.gamma,
                             self.beta,
                             self.moving_mean,
                             self.moving_variance)[0]

    def extend_repr(self):
        """Extra repr text shown for this cell."""
        return 'num_features={}, eps={}, momentum={}, gamma={}, beta={}, moving_mean={}, moving_variance={}'.format(
            self.num_features, self.eps, self.momentum, self.gamma, self.beta, self.moving_mean, self.moving_variance)
  201. @constexpr
  202. def _channel_check(channel, num_channel):
  203. if channel != num_channel:
  204. raise ValueError("the input channel is not equal with num_channel")
  205. @constexpr
  206. def _shape_check(in_shape):
  207. if len(in_shape) != 4:
  208. raise ValueError("The input must has 4 dims.")
  209. @constexpr
  210. def _shape_check_bn(in_shape, in_dims):
  211. """check input dims of batch norm."""
  212. dim = len(in_shape)
  213. if in_dims == '1d' and dim != 2:
  214. raise ValueError("The input must has 2 dims.")
  215. if in_dims == '2d' and dim != 4:
  216. raise ValueError("The input must has 4 dims.")
  217. if in_dims == '3d' and dim != 5:
  218. raise ValueError("The input must has 5 dims.")
  219. if in_dims == 'both' and dim != 2 and dim != 4:
  220. raise ValueError("The input must has 2 dims or 4 dims.")
  221. @constexpr
  222. def _shape_infer(x_shape, num_feature):
  223. """global Batch Normalization shape and axes infer"""
  224. if len(x_shape) == 4:
  225. axes = (0, 2, 3)
  226. re_shape = (1, num_feature, 1, 1)
  227. else:
  228. axes = (0,)
  229. re_shape = (1, num_feature)
  230. return axes, re_shape
  231. class BatchNorm1d(_BatchNorm):
  232. r"""
  233. Batch Normalization layer over a 2D input.
  234. Batch Normalization is widely used in convolutional networks. This layer
  235. applies Batch Normalization over a 2D input (a mini-batch of 1D inputs) to
  236. reduce internal covariate shift as described in the paper
  237. `Batch Normalization: Accelerating Deep Network Training by
  238. Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It
  239. rescales and recenters the feature using a mini-batch of data and
  240. the learned parameters which can be described in the following formula.
  241. .. math::
  242. y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
  243. Note:
  244. The implementation of BatchNorm is different in graph mode and pynative mode, therefore the mode is not
  245. recommended to be changed after net was initialized.
  246. Args:
  247. num_features (int): `C` from an expected input of size (N, C).
  248. eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
  249. momentum (float): A floating hyperparameter of the momentum for the
  250. running_mean and running_var computation. Default: 0.9.
  251. affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True.
  252. gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
  253. The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
  254. beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
  255. The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
  256. moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
  257. The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
  258. moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
  259. The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
  260. use_batch_statistics (bool): If true, use the mean value and variance value of current batch data. If false,
  261. use the mean value and variance value of specified value. If None, the training process will use the mean
  262. and variance of current batch data and track the running mean and variance, the evaluation process will use
  263. the running mean and variance. Default: None.
  264. Inputs:
  265. - **input** (Tensor) - Tensor of shape :math:`(N, C_{in})`.
  266. Outputs:
  267. Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out})`.
  268. Supported Platforms:
  269. ``Ascend`` ``GPU``
  270. Raises:
  271. TypeError: If `num_features` is not an int.
  272. TypeError: If `eps` is not a float.
  273. ValueError: If `num_features` is less than 1.
  274. ValueError: If `momentum` is not in range [0, 1].
  275. Examples:
  276. >>> net = nn.BatchNorm1d(num_features=4)
  277. >>> np.random.seed(0)
  278. >>> input = Tensor(np.random.randint(0, 255, [2, 4]), mindspore.float32)
  279. >>> output = net(input)
  280. >>> print(output)
  281. [[171.99915 46.999763 116.99941 191.99904 ]
  282. [ 66.999664 250.99875 194.99902 102.99948 ]]
  283. """
  284. def __init__(self,
  285. num_features,
  286. eps=1e-5,
  287. momentum=0.9,
  288. affine=True,
  289. gamma_init='ones',
  290. beta_init='zeros',
  291. moving_mean_init='zeros',
  292. moving_var_init='ones',
  293. use_batch_statistics=None):
  294. super(BatchNorm1d, self).__init__(num_features,
  295. eps,
  296. momentum,
  297. affine,
  298. gamma_init,
  299. beta_init,
  300. moving_mean_init,
  301. moving_var_init,
  302. use_batch_statistics,
  303. input_dims='1d')
  304. def _check_data_dim(self, x):
  305. if x.ndim != 2:
  306. pass
  307. class BatchNorm2d(_BatchNorm):
  308. r"""
  309. Batch Normalization layer over a 4D input.
  310. Batch Normalization is widely used in convolutional networks. This layer
  311. applies Batch Normalization over a 4D input (a mini-batch of 2D inputs with
  312. additional channel dimension) to avoid internal covariate shift as described
  313. in the paper `Batch Normalization: Accelerating Deep Network Training by
  314. Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It
  315. rescales and recenters the feature using a mini-batch of data and
  316. the learned parameters which can be described in the following formula.
  317. .. math::
  318. y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
  319. Note:
  320. The implementation of BatchNorm is different in graph mode and pynative mode, therefore that mode can not be
  321. changed after net was initialized.
  322. Note that the formula for updating the running_mean and running_var is
  323. :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times x_t + \text{momentum} \times \hat{x}`,
  324. where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the new observed value.
  325. Args:
  326. num_features (int): `C` from an expected input of size (N, C, H, W).
  327. eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
  328. momentum (float): A floating hyperparameter of the momentum for the
  329. running_mean and running_var computation. Default: 0.9.
  330. affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True.
  331. gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
  332. The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
  333. beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
  334. The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
  335. moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
  336. The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
  337. moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
  338. The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
  339. use_batch_statistics (bool):
  340. - If true, use the mean value and variance value of current batch data and track running mean
  341. and running varance.
  342. - If false, use the mean value and variance value of specified value, and not track statistical value.
  343. - If None, The use_batch_statistics is automatically assigned process according to
  344. the training and eval mode. During training, batchnorm2d process will be the same
  345. with use_batch_statistics=True. Contrarily, in eval, batchnorm2d process will be the same
  346. with use_batch_statistics=False. Default: None.
  347. data_format (str): The optional value for data format, is 'NHWC' or 'NCHW'.
  348. Default: 'NCHW'.
  349. Inputs:
  350. - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
  351. Outputs:
  352. Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
  353. Raises:
  354. TypeError: If `num_features` is not an int.
  355. TypeError: If `eps` is not a float.
  356. ValueError: If `num_features` is less than 1.
  357. ValueError: If `momentum` is not in range [0, 1].
  358. ValueError: If `data_format` is neither 'NHWC' not 'NCHW'.
  359. Supported Platforms:
  360. ``Ascend`` ``GPU`` ``CPU``
  361. Examples:
  362. >>> net = nn.BatchNorm2d(num_features=3)
  363. >>> np.random.seed(0)
  364. >>> input = Tensor(np.random.randint(0, 255, [1, 3, 2, 2]), mindspore.float32)
  365. >>> output = net(input)
  366. >>> print(output)
  367. [[[[171.99915 46.999763 ]
  368. [116.99941 191.99904 ]]
  369. [[ 66.999664 250.99875 ]
  370. [194.99902 102.99948 ]]
  371. [[ 8.999955 210.99895 ]
  372. [ 20.999895 241.9988 ]]]]
  373. """
  374. def __init__(self,
  375. num_features,
  376. eps=1e-5,
  377. momentum=0.9,
  378. affine=True,
  379. gamma_init='ones',
  380. beta_init='zeros',
  381. moving_mean_init='zeros',
  382. moving_var_init='ones',
  383. use_batch_statistics=None,
  384. data_format='NCHW'):
  385. super(BatchNorm2d, self).__init__(num_features,
  386. eps,
  387. momentum,
  388. affine,
  389. gamma_init,
  390. beta_init,
  391. moving_mean_init,
  392. moving_var_init,
  393. use_batch_statistics,
  394. input_dims='2d',
  395. data_format=data_format)
  396. def _check_data_dim(self, x):
  397. if x.ndim != 4:
  398. pass
  399. @constexpr
  400. def _check_3d_shape(input_shape):
  401. if len(input_shape) != 5:
  402. raise ValueError("For BatchNorm3d, input data must be 5-dimensional.")
  403. class BatchNorm3d(Cell):
  404. r"""
  405. Batch Normalization layer over a 5D input.
  406. Batch Normalization is widely used in convolutional networks. This layer
  407. applies Batch Normalization over a 5D input (a mini-batch of 3D inputs with
  408. additional channel dimension) to avoid internal covariate shift.
  409. .. math::
  410. y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
  411. Note:
  412. The implementation of BatchNorm is different in graph mode and pynative mode, therefore that mode can not be
  413. changed after net was initialized.
  414. Note that the formula for updating the running_mean and running_var is
  415. :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times x_t + \text{momentum} \times \hat{x}`,
  416. where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the new observed value.
  417. Args:
  418. num_features (int): `C` from an expected input of size (N, C, D, H, W).
  419. eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
  420. momentum (float): A floating hyperparameter of the momentum for the
  421. running_mean and running_var computation. Default: 0.9.
  422. affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True.
  423. gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
  424. The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
  425. beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
  426. The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
  427. moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
  428. The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
  429. moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
  430. The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
  431. use_batch_statistics (bool): If true, use the mean value and variance value of current batch data. If false,
  432. use the mean value and variance value of specified value. If None, the training process will use the mean
  433. and variance of current batch data and track the running mean and variance, the evaluation process will use
  434. the running mean and variance. Default: None.
  435. data_format (str): The optional value for data format is 'NCDHW'. Default: 'NCDHW'.
  436. Inputs:
  437. - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`.
  438. Outputs:
  439. Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out}, D_{out},H_{out}, W_{out})`.
  440. Raises:
  441. TypeError: If `num_features` is not an int.
  442. TypeError: If `eps` is not a float.
  443. ValueError: If `num_features` is less than 1.
  444. ValueError: If `momentum` is not in range [0, 1].
  445. ValueError: If `data_format` is not 'NCDHW'.
  446. Supported Platforms:
  447. ``Ascend`` ``GPU`` ``CPU``
  448. Examples:
  449. >>> net = nn.BatchNorm3d(num_features=3)
  450. >>> np.random.seed(0)
  451. >>> input = Tensor(np.random.randint(0, 255, [16, 3, 10, 32, 32]), mindspore.float32)
  452. >>> output = net(input)
  453. >>> print(output.shape)
  454. (16, 3, 10, 32, 32)
  455. """
  456. def __init__(self,
  457. num_features,
  458. eps=1e-5,
  459. momentum=0.9,
  460. affine=True,
  461. gamma_init='ones',
  462. beta_init='zeros',
  463. moving_mean_init='zeros',
  464. moving_var_init='ones',
  465. use_batch_statistics=None,
  466. data_format='NCDHW'):
  467. super(BatchNorm3d, self).__init__()
  468. self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.cls_name)
  469. self.reshape = P.Reshape()
  470. self.bn2d = BatchNorm2d(num_features=num_features,
  471. eps=eps,
  472. momentum=momentum,
  473. affine=affine,
  474. gamma_init=gamma_init,
  475. beta_init=beta_init,
  476. moving_mean_init=moving_mean_init,
  477. moving_var_init=moving_var_init,
  478. use_batch_statistics=use_batch_statistics,
  479. data_format="NCHW")
  480. def construct(self, input_x):
  481. x_shape = F.shape(input_x)
  482. _check_3d_shape(x_shape)
  483. input_x = self.reshape(input_x, (x_shape[0], x_shape[1], x_shape[2] * x_shape[3], x_shape[4]))
  484. bn2d_out = self.bn2d(input_x)
  485. bn3d_out = self.reshape(bn2d_out, x_shape)
  486. return bn3d_out
  487. class GlobalBatchNorm(_BatchNorm):
  488. r"""
  489. Global Batch Normalization layer over a N-dimension input.
  490. Global Batch Normalization is cross device synchronized Batch Normalization. The implementation of
  491. Batch Normalization only normalizes the data within each device. Global Normalization will normalize
  492. the input within the group.It has been described in the paper `Batch Normalization: Accelerating Deep Network
  493. Training by Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It rescales and recenters the
  494. feature using a mini-batch of data and the learned parameters which can be described in the following formula.
  495. .. math::
  496. y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
  497. Note:
  498. Currently, GlobalBatchNorm only supports 2D and 4D inputs.
  499. Args:
  500. num_features (int): `C` from an expected input of size (N, C, H, W).
  501. device_num_each_group (int): The number of devices in each group. Default: 2.
  502. eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
  503. momentum (float): A floating hyperparameter of the momentum for the
  504. running_mean and running_var computation. Default: 0.9.
  505. affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True.
  506. gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
  507. The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
  508. 'he_uniform', etc. Default: 'ones'.
  509. beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
  510. The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
  511. 'he_uniform', etc. Default: 'zeros'.
  512. moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
  513. The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
  514. 'he_uniform', etc. Default: 'zeros'.
  515. moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
  516. The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
  517. 'he_uniform', etc. Default: 'ones'.
  518. use_batch_statistics (bool): If true, use the mean value and variance value of current batch data. If false,
  519. use the mean value and variance value of specified value. If None, training process will use the mean and
  520. variance of current batch data and track the running mean and variance, eval process will use the running
  521. mean and variance. Default: None.
  522. Inputs:
  523. - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
  524. Outputs:
  525. Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
  526. Raises:
  527. TypeError: If `num_features` or `device_num_each_group` is not an int.
  528. TypeError: If `eps` is not a float.
  529. ValueError: If `num_features` is less than 1.
  530. ValueError: If `momentum` is not in range [0, 1].
  531. ValueError: If `device_num_each_group` is less than 2.
  532. Supported Platforms:
  533. ``Ascend``
  534. Examples:
  535. >>> # This example should be run with multiple processes.
  536. >>> # Please refer to the tutorial > Distributed Training on mindspore.cn.
  537. >>> import numpy as np
  538. >>> from mindspore.communication import init
  539. >>> from mindspore import context
  540. >>> from mindspore.context import ParallelMode
  541. >>> from mindspore import nn, Tensor
  542. >>> from mindspore.common import dtype as mstype
  543. >>>
  544. >>> context.set_context(mode=context.GRAPH_MODE)
  545. >>> init()
  546. >>> context.reset_auto_parallel_context()
  547. >>> context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL)
  548. >>> np.random.seed(0)
  549. >>> global_bn_op = nn.GlobalBatchNorm(num_features=3, device_num_each_group=2)
  550. >>> input = Tensor(np.random.randint(0, 255, [1, 3, 2, 2]), mstype.float32)
  551. >>> output = global_bn_op(input)
  552. >>> print(output)
  553. [[[[171.99915 46.999763]
  554. [116.99941 191.99904 ]]
  555. [[ 66.999664 250.99875 ]
  556. [194.99902 102.99948 ]]
  557. [[ 8.999955 210.99895 ]
  558. [ 20.9999895 241.9988 ]]]]
  559. """
  560. @deprecated("1.2", "SyncBatchNorm", True)
  561. def __init__(self,
  562. num_features,
  563. eps=1e-5,
  564. momentum=0.9,
  565. affine=True,
  566. gamma_init='ones',
  567. beta_init='zeros',
  568. moving_mean_init='zeros',
  569. moving_var_init='ones',
  570. use_batch_statistics=None,
  571. device_num_each_group=2):
  572. super(GlobalBatchNorm, self).__init__(num_features,
  573. eps,
  574. momentum,
  575. affine,
  576. gamma_init,
  577. beta_init,
  578. moving_mean_init,
  579. moving_var_init,
  580. use_batch_statistics,
  581. device_num_each_group,
  582. input_dims='both')
  583. self.group_device_num = validator.check_positive_int(device_num_each_group)
  584. if self.group_device_num <= 1:
  585. raise ValueError("the number of group must be greater than 1.")
    def _check_data_dim(self, x):
        """Hook for validating the dimensionality of input `x`.

        NOTE(review): `x.dim` is referenced without being called, so this
        compares an attribute (or bound method) against 0. If `dim` is a
        method, the condition is never true and the whole check is a no-op —
        confirm the intended behavior (possibly `x.dim() == 0`) upstream.
        """
        if x.dim == 0:
            pass
  589. class SyncBatchNorm(_BatchNorm):
  590. r"""
  591. Sync Batch Normalization layer over a N-dimension input.
  592. Sync Batch Normalization is cross device synchronized Batch Normalization. The implementation of Batch
  593. Normalization only normalizes the data within each device. Sync Batch Normalization will normalize the input
  594. within the group. It has been described in the paper `Batch Normalization: Accelerating Deep Network Training by
  595. Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It rescales and recenters the
  596. feature using a mini-batch of data and the learned parameters which can be described in the following formula.
  597. .. math::
  598. y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
  599. Note:
  600. Currently, SyncBatchNorm only supports 2D and 4D inputs.
  601. Args:
  602. num_features (int): `C` from an expected input of size (N, C, H, W).
  603. eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
  604. momentum (float): A floating hyperparameter of the momentum for the
  605. running_mean and running_var computation. Default: 0.9.
  606. affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True.
  607. gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
  608. The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
  609. 'he_uniform', etc. Default: 'ones'.
  610. beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
  611. The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
  612. 'he_uniform', etc. Default: 'zeros'.
  613. moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
  614. The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
  615. 'he_uniform', etc. Default: 'zeros'.
  616. moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
  617. The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
  618. 'he_uniform', etc. Default: 'ones'.
  619. use_batch_statistics (bool): If true, use the mean value and variance value of current batch data. If false,
  620. use the mean value and variance value of specified value. If None, training process will use the mean and
  621. variance of current batch data and track the running mean and variance, eval process will use the running
  622. mean and variance. Default: None.
  623. process_groups (list): A list to divide devices into different sync groups, containing N subtraction lists.
  624. Each subtraction list contains int numbers identifying rank ids which need to be synchronized in the same
  625. group. All int values must be in [0, rank_size) and different from each other. Default: None, indicating
  626. synchronization across all devices.
  627. Inputs:
  628. - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
  629. Outputs:
  630. Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
  631. Raises:
  632. TypeError: If `num_features` is not an int.
  633. TypeError: If `eps` is not a float.
  634. TypeError: If `process_groups` is not a list.
  635. ValueError: If `num_features` is less than 1.
  636. ValueError: If `momentum` is not in range [0, 1].
  637. ValueError: If rank_id in `process_groups` is not in range [0, rank_size).
  638. Supported Platforms:
  639. ``Ascend``
  640. Examples:
  641. >>> # This example should be run with multiple processes.
  642. >>> # Please refer to the tutorial > Distributed Training on mindspore.cn.
  643. >>> import numpy as np
  644. >>> from mindspore.communication import init
  645. >>> from mindspore import context
  646. >>> from mindspore.context import ParallelMode
  647. >>> from mindspore import nn, Tensor
  648. >>> from mindspore.common import dtype as mstype
  649. >>>
  650. >>> context.set_context(mode=context.GRAPH_MODE)
  651. >>> init()
  652. >>> context.reset_auto_parallel_context()
  653. >>> context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL)
  654. >>> np.random.seed(0)
  655. >>> sync_bn_op = nn.SyncBatchNorm(num_features=3, process_groups=[[0, 1], [2, 3]])
  656. >>> input = Tensor(np.random.randint(0, 255, [1, 3, 2, 2]), mstype.float32)
  657. >>> output = sync_bn_op(input)
  658. >>> print(output)
  659. [[[[171.99915 46.999763]
  660. [116.99941 191.99904 ]]
  661. [[ 66.999664 250.99875 ]
  662. [194.99902 102.99948 ]]
  663. [[ 8.999955 210.99895 ]
  664. [ 20.9999895 241.9988 ]]]]
  665. """
  666. def __init__(self,
  667. num_features,
  668. eps=1e-5,
  669. momentum=0.9,
  670. affine=True,
  671. gamma_init='ones',
  672. beta_init='zeros',
  673. moving_mean_init='zeros',
  674. moving_var_init='ones',
  675. use_batch_statistics=None,
  676. process_groups=None):
  677. super(SyncBatchNorm, self).__init__(num_features,
  678. eps,
  679. momentum,
  680. affine,
  681. gamma_init,
  682. beta_init,
  683. moving_mean_init,
  684. moving_var_init,
  685. use_batch_statistics,
  686. process_groups=process_groups,
  687. input_dims='both')
  688. def _check_data_dim(self, x):
  689. if x.dim == 0:
  690. pass
  691. class LayerNorm(Cell):
  692. r"""
  693. Applies Layer Normalization over a mini-batch of inputs.
  694. Layer Normalization is widely used in recurrent neural networks. It applies
  695. normalization on a mini-batch of inputs for each single training case as described
  696. in the paper `Layer Normalization <https://arxiv.org/pdf/1607.06450.pdf>`_. Unlike Batch
  697. Normalization, Layer Normalization performs exactly the same computation at training and
  698. testing time. It can be described using the following formula. It is applied across all channels
  699. and pixel but only one batch size.
  700. .. math::
  701. y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
  702. Args:
  703. normalized_shape (Union(tuple[int], list[int]): The normalization is performed over axis
  704. `begin_norm_axis ... R - 1`.
  705. begin_norm_axis (int): The first normalization dimension: normalization will be performed along dimensions
  706. `begin_norm_axis: rank(inputs)`, the value should be in [-1, rank(input)). Default: -1.
  707. begin_params_axis (int): The first parameter(beta, gamma)dimension: scale and centering parameters
  708. will have dimensions `begin_params_axis: rank(inputs)` and will be broadcast with
  709. the normalized inputs accordingly, the value should be in [-1, rank(input)). Default: -1.
  710. gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
  711. The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
  712. 'he_uniform', etc. Default: 'ones'.
  713. beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
  714. The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
  715. 'he_uniform', etc. Default: 'zeros'.
  716. epsilon (float): A value added to the denominator for numerical stability. Default: 1e-7.
  717. Inputs:
  718. - **input_x** (Tensor) - The shape of 'input_x' is :math:`(x_1, x_2, ..., x_R)`,
  719. and `input_shape[begin_norm_axis:]` is equal to `normalized_shape`.
  720. Outputs:
  721. Tensor, the normalized and scaled offset tensor, has the same shape and data type as the `input_x`.
  722. Raises:
  723. TypeError: If `normalized_shape` is neither a list nor tuple.
  724. TypeError: If `begin_norm_axis` or `begin_params_axis` is not an int.
  725. TypeError: If `epsilon` is not a float.
  726. Supported Platforms:
  727. ``Ascend`` ``GPU`` ``CPU``
  728. Examples:
  729. >>> x = Tensor(np.ones([20, 5, 10, 10]), mindspore.float32)
  730. >>> shape1 = x.shape[1:]
  731. >>> m = nn.LayerNorm(shape1, begin_norm_axis=1, begin_params_axis=1)
  732. >>> output = m(x).shape
  733. >>> print(output)
  734. (20, 5, 10, 10)
  735. """
  736. def __init__(self,
  737. normalized_shape,
  738. begin_norm_axis=-1,
  739. begin_params_axis=-1,
  740. gamma_init='ones',
  741. beta_init='zeros',
  742. epsilon=1e-7
  743. ):
  744. super(LayerNorm, self).__init__()
  745. if not isinstance(normalized_shape, (tuple, list)):
  746. raise TypeError("The type of 'normalized_shape' should be tuple[int] or list[int], but '{}' type is {}."
  747. .format(normalized_shape, type(normalized_shape)))
  748. self.normalized_shape = normalized_shape
  749. self.begin_norm_axis = begin_norm_axis
  750. self.begin_params_axis = begin_params_axis
  751. self.epsilon = epsilon
  752. self.gamma = Parameter(initializer(
  753. gamma_init, normalized_shape), name="gamma")
  754. self.beta = Parameter(initializer(
  755. beta_init, normalized_shape), name="beta")
  756. self.layer_norm = P.LayerNorm(begin_norm_axis=self.begin_norm_axis,
  757. begin_params_axis=self.begin_params_axis,
  758. epsilon=self.epsilon)
  759. def construct(self, input_x):
  760. y, _, _ = self.layer_norm(input_x, self.gamma, self.beta)
  761. return y
  762. def extend_repr(self):
  763. """Display instance object as string."""
  764. return 'normalized_shape={}, begin_norm_axis={}, begin_params_axis={}, gamma{}, beta={}'.format(
  765. self.normalized_shape, self.begin_norm_axis, self.begin_params_axis, self.gamma, self.beta)
  766. class InstanceNorm2d(Cell):
  767. r"""
  768. Instance Normalization layer over a 4D input.
  769. This layer applies Instance Normalization over a 4D input (a mini-batch of 2D inputs with
  770. additional channel dimension) as described in the paper `Instance Normalization: The Missing Ingredient for
  771. Fast Stylization <https://arxiv.org/abs/1607.08022>`_. It rescales and recenters the feature using a mini-batch
  772. of data and the learned parameters which can be described in the following formula.
  773. .. math::
  774. y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
  775. \gamma and \beta are learnable parameter vectors of size num_features if affine is True. The standard-deviation
  776. is calculated via the biased estimator.
  777. By default, this layer uses instance statistics computed from input data in both training and evaluation modes.
  778. If use_batch_statistics is set to True, it means training phases, and this layer keeps running estimates of its
  779. computed mean and variance, which are then used for normalization during evaluation. The running estimates are
  780. kept with a default momentum of 0.1.
  781. InstanceNorm2d and BatchNorm2d are very similar, but have some differences. InstanceNorm2d is applied on each
  782. channel of channeled data like RGB images, but BatchNorm2d is usually applied on each batch of batched data.
  783. Note:
  784. Note that the formula for updating the running_mean and running_var is
  785. :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times x_t + \text{momentum} \times \hat{x}`,
  786. where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the new observed value.
  787. Args:
  788. num_features (int): `C` from an expected input of size (N, C, H, W).
  789. eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
  790. momentum (float): A floating hyperparameter of the momentum for the
  791. running_mean and running_var computation. Default: 0.1.
  792. affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True.
  793. gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
  794. The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
  795. beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
  796. The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
  797. moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
  798. The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
  799. moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
  800. The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
  801. use_batch_statistics (bool): If true, use the mean value and variance value of current batch data. If false,
  802. use the mean value and variance value of specified value. Default: True.
  803. Inputs:
  804. - **input** (Tensor) - Tensor of shape :math:`(N, C, H, W)`. Data type: float16 or float32.
  805. Outputs:
  806. Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C, H, W)`. Same type and
  807. shape as the `input_x`.
  808. Supported Platforms:
  809. ``GPU``
  810. Raises:
  811. TypeError: If `num_features` is not an int.
  812. TypeError: If `eps` is not a float.
  813. TypeError: If `momentum` is not a float.
  814. TypeError: If `affine` is not a bool.
  815. TypeError: If the type of `gamma_init`/`beta_init`/`moving_mean_init`/`moving_var_init` is not same, or if
  816. the initialized element type is not float32.
  817. ValueError: If `num_features` is less than 1.
  818. ValueError: If `momentum` is not in range [0, 1].
  819. KeyError: If any of `gamma_init`/`beta_init`/`moving_mean_init`/`moving_var_init` is str and the homonymous
  820. class inheriting from `Initializer` not exists.
  821. Examples:
  822. >>> net = nn.InstanceNorm2d(3)
  823. >>> np.random.seed(0)
  824. >>> input = Tensor(np.random.randint(0, 255, [2, 3, 2, 2]), mindspore.float32)
  825. >>> output = net(input)
  826. >>> print(output.shape)
  827. (2, 3, 2, 2)
  828. """
  829. @cell_attr_register
  830. def __init__(self,
  831. num_features,
  832. eps=1e-5,
  833. momentum=0.1,
  834. affine=True,
  835. gamma_init='ones',
  836. beta_init='zeros',
  837. moving_mean_init='zeros',
  838. moving_var_init='ones',
  839. use_batch_statistics=True):
  840. super(InstanceNorm2d, self).__init__()
  841. validator.check_value_type('num_features', num_features, [int], self.cls_name)
  842. validator.check_value_type('eps', eps, [float], self.cls_name)
  843. validator.check_value_type('momentum', momentum, [float], self.cls_name)
  844. validator.check_value_type('affine', affine, [bool], self.cls_name)
  845. args_input = {"gamma_init": gamma_init, "beta_init": beta_init,
  846. "moving_mean_init": moving_mean_init, "moving_var_init": moving_var_init}
  847. self.check_types_valid(args_input, 'InstanceNorm2d')
  848. if num_features < 1:
  849. raise ValueError("num_features must be at least 1")
  850. if momentum < 0 or momentum > 1:
  851. raise ValueError("momentum should be a number in range [0, 1], but got {}".format(momentum))
  852. self.use_batch_statistics = use_batch_statistics
  853. self.num_features = num_features
  854. self.eps = eps
  855. self.input_dims = '2d'
  856. self.moving_mean = Parameter(initializer(
  857. moving_mean_init, num_features), name="mean", requires_grad=False)
  858. self.moving_variance = Parameter(initializer(
  859. moving_var_init, num_features), name="variance", requires_grad=False)
  860. self.gamma = Parameter(initializer(
  861. gamma_init, num_features), name="gamma", requires_grad=affine)
  862. self.beta = Parameter(initializer(
  863. beta_init, num_features), name="beta", requires_grad=affine)
  864. self.shape = P.Shape()
  865. self.momentum = momentum
  866. self.instance_bn = P.InstanceNorm(is_training=self.use_batch_statistics,
  867. epsilon=self.eps,
  868. momentum=self.momentum)
  869. def _check_data_dim(self, x):
  870. raise NotImplementedError
  871. def construct(self, x):
  872. _shape_check_bn(self.shape(x), self.input_dims)
  873. return self.instance_bn(x,
  874. self.gamma,
  875. self.beta,
  876. self.moving_mean,
  877. self.moving_variance)[0]
  878. def extend_repr(self):
  879. return 'num_features={}, eps={}, momentum={}, gamma={}, beta={}, moving_mean={}, moving_variance={}'.format(
  880. self.num_features, self.eps, self.momentum, self.gamma, self.beta, self.moving_mean, self.moving_variance)
  881. def check_types_valid(self, args_dict, name):
  882. for key, _ in args_dict.items():
  883. val = args_dict[key]
  884. if not isinstance(val, (Tensor, numbers.Number, str, Initializer)):
  885. raise TypeError(f"[{name}]Supported type for arg {key} is [Tensor, numbers.Number, str, Initializer],"
  886. f"but got {type(val)}")
  887. if isinstance(val, Tensor) and val.dtype != mstype.float32:
  888. raise TypeError(f"[{name}]The type of arg {key} should be float32, but got {val.dtype}")
  889. class GroupNorm(Cell):
  890. r"""
  891. Group Normalization over a mini-batch of inputs.
  892. Group Normalization is widely used in recurrent neural networks. It applies
  893. normalization on a mini-batch of inputs for each single training case as described
  894. in the paper `Group Normalization <https://arxiv.org/pdf/1803.08494.pdf>`_. Group Normalization
  895. divides the channels into groups and computes within each group the mean and variance for normalization,
  896. and it performs very stable over a wide range of batch size. It can be described using the following formula.
  897. .. math::
  898. y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
  899. Args:
  900. num_groups (int): The number of groups to be divided along the channel dimension.
  901. num_channels (int): The number of channels per group.
  902. eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
  903. affine (bool): A bool value, this layer will have learnable affine parameters when set to true. Default: True.
  904. gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
  905. The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
  906. 'he_uniform', etc. Default: 'ones'. If gamma_init is a Tensor, the shape must be [num_channels].
  907. beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
  908. The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
  909. 'he_uniform', etc. Default: 'zeros'. If beta_init is a Tensor, the shape must be [num_channels].
  910. Inputs:
  911. - **input_x** (Tensor) - The input feature with shape [N, C, H, W].
  912. Outputs:
  913. Tensor, the normalized and scaled offset tensor, has the same shape and data type as the `input_x`.
  914. Raises:
  915. TypeError: If `num_groups` or `num_channels` is not an int.
  916. TypeError: If `eps` is not a float.
  917. TypeError: If `affine` is not a bool.
  918. ValueError: If `num_groups` or `num_channels` is less than 1.
  919. ValueError: If `num_channels` is not divided by `num_groups`.
  920. Supported Platforms:
  921. ``Ascend`` ``GPU`` ``CPU``
  922. Examples:
  923. >>> goup_norm_op = nn.GroupNorm(2, 2)
  924. >>> x = Tensor(np.ones([1, 2, 4, 4], np.float32))
  925. >>> output = goup_norm_op(x)
  926. >>> print(output)
  927. [[[[0. 0. 0. 0.]
  928. [0. 0. 0. 0.]
  929. [0. 0. 0. 0.]
  930. [0. 0. 0. 0.]]
  931. [[0. 0. 0. 0.]
  932. [0. 0. 0. 0.]
  933. [0. 0. 0. 0.]
  934. [0. 0. 0. 0.]]]]
  935. """
  936. def __init__(self, num_groups, num_channels, eps=1e-05, affine=True, gamma_init='ones', beta_init='zeros'):
  937. super(GroupNorm, self).__init__()
  938. self.num_groups = validator.check_positive_int(num_groups)
  939. self.num_channels = validator.check_positive_int(num_channels)
  940. if num_channels % num_groups != 0:
  941. raise ValueError("num_channels should be divided by num_groups")
  942. self.eps = validator.check_value_type('eps', eps, (float,), type(self).__name__)
  943. self.affine = validator.check_bool(affine)
  944. gamma = initializer(gamma_init, num_channels)
  945. beta = initializer(beta_init, num_channels)
  946. if self.affine:
  947. self.gamma = Parameter(gamma, name='gamma')
  948. self.beta = Parameter(beta, name='beta')
  949. else:
  950. self.gamma = gamma
  951. self.beta = beta
  952. self.shape = F.shape
  953. self.reshape = F.reshape
  954. self.reduce_mean = P.ReduceMean(keep_dims=True)
  955. self.square = F.square
  956. self.reduce_sum = P.ReduceSum(keep_dims=True)
  957. self.sqrt = P.Sqrt()
  958. def _cal_output(self, x):
  959. """calculate groupnorm output"""
  960. batch, channel, height, width = self.shape(x)
  961. _channel_check(channel, self.num_channels)
  962. x = self.reshape(x, (batch, self.num_groups, -1))
  963. mean = self.reduce_mean(x, 2)
  964. var = self.reduce_sum(self.square(x - mean), 2) / (channel * height * width / self.num_groups)
  965. std = self.sqrt(var + self.eps)
  966. x = (x - mean) / std
  967. x = self.reshape(x, (batch, channel, height, width))
  968. output = x * self.reshape(self.gamma, (-1, 1, 1)) + self.reshape(self.beta, (-1, 1, 1))
  969. return output
  970. def construct(self, x):
  971. _shape_check(self.shape(x))
  972. output = self._cal_output(x)
  973. return output
  974. def extend_repr(self):
  975. """Display instance object as string."""
  976. return 'num_groups={}, num_channels={}'.format(self.num_groups, self.num_channels)