
batchnorm.py 11 kB

# -*- coding: utf-8 -*-
from typing import Optional

import numpy as np

from ..distributed.group import WORLD, Group
from ..functional.nn import batch_norm, sync_batch_norm
from ..tensor import Parameter, Tensor
from . import init
from .module import Module


class _BatchNorm(Module):
    def __init__(
        self,
        num_features,
        eps=1e-5,
        momentum=0.9,
        affine=True,
        track_running_stats=True,
        freeze=False,
        compute_mode="default",
        param_dim="dim_1c11",
        **kwargs
    ):
        super(_BatchNorm, self).__init__(**kwargs)
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        self._track_running_stats_saved = track_running_stats
        self.freeze = freeze
        self.compute_mode = compute_mode
        self.param_dim = param_dim
        if self.freeze:
            assert (
                self._track_running_stats_saved
            ), "track_running_stats must be initialized to True if freeze is True"
        tshape = (1, self.num_features, 1, 1)
        if self.affine:
            self.weight = Parameter(np.ones(tshape, dtype=np.float32))
            self.bias = Parameter(np.zeros(tshape, dtype=np.float32))
        else:
            self.weight = None
            self.bias = None

        if self.track_running_stats:
            self.running_mean = Tensor(np.zeros(tshape, dtype=np.float32))
            self.running_var = Tensor(np.ones(tshape, dtype=np.float32))
        else:
            self.running_mean = None
            self.running_var = None

    def reset_running_stats(self) -> None:
        if self.track_running_stats:
            init.zeros_(self.running_mean)
            init.ones_(self.running_var)

    def reset_parameters(self) -> None:
        self.reset_running_stats()
        if self.affine:
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def _check_input_ndim(self, inp):
        raise NotImplementedError

    def forward(self, inp):
        self._check_input_ndim(inp)
        if self._track_running_stats_saved == False:
            assert (
                self.track_running_stats == False
            ), "track_running_stats can not be initialized to False and changed to True later"

        _weight = self.weight
        _bias = self.bias

        if self.freeze:
            if _weight is not None:
                _weight = _weight.detach()
            if _bias is not None:
                _bias = _bias.detach()

            # fast path execution for freeze (see the note after this class)
            scale = (self.running_var + self.eps) ** (-0.5)
            if _weight is not None:
                scale *= _weight
            bias = -self.running_mean * scale
            if _bias is not None:
                bias += _bias
            return inp * scale + bias

        if self.training and self.track_running_stats:
            exponential_average_factor = self.momentum
        else:
            exponential_average_factor = 0.0  # useless

        output = batch_norm(
            inp,
            self.running_mean if self.track_running_stats else None,
            self.running_var if self.track_running_stats else None,
            _weight,
            _bias,
            training=self.training
            or ((self.running_mean is None) and (self.running_var is None)),
            momentum=exponential_average_factor,
            eps=self.eps,
            compute_mode=self.compute_mode,
            param_dim=self.param_dim,
        )
        return output

    def _module_info_string(self) -> str:
        s = (
            "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, "
            "track_running_stats={track_running_stats}"
        )
        return s.format(**self.__dict__)


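# Note on the ``freeze`` fast path in ``_BatchNorm.forward`` above: with the
# running statistics frozen, normalization collapses into a single per-channel
# affine transform. A sketch of the algebra (``weight``/``beta`` denote the
# affine parameters; the corresponding terms are skipped when ``affine`` is
# False):
#
#     scale = weight * (running_var + eps) ** -0.5
#     bias  = beta - running_mean * scale
#     out   = inp * scale + bias
#
# which equals (inp - running_mean) / sqrt(running_var + eps) * weight + beta.

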
class SyncBatchNorm(_BatchNorm):
    r"""Applies Synchronized Batch Normalization for distributed training.

    Args:
        num_features: usually :math:`C` from an input of shape
            :math:`(N, C, H, W)` or the highest ranked dimension of an input
            less than 4D.
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the ``running_mean`` and ``running_var`` computation.
            Default: 0.9
        affine: a boolean value that when set to True, this module has
            learnable affine parameters. Default: True
        track_running_stats: when set to True, this module tracks the
            running mean and variance. When set to False, this module does not
            track such statistics and always uses batch statistics in both training
            and eval modes. Default: True
        freeze: when set to True, this module does not update the
            running mean and variance, and uses the running mean and variance instead of
            the batch mean and batch variance to normalize the input. The parameter takes effect
            only when the module is initialized with track_running_stats as True.
            Default: False
        group: communication group over which the mean and variance are calculated.
            Default: :obj:`~.distributed.WORLD`

    A brief usage sketch is given in the comment following this class.
    """

    def __init__(
        self,
        num_features,
        eps=1e-5,
        momentum=0.9,
        affine=True,
        track_running_stats=True,
        freeze=False,
        group: Optional[Group] = WORLD,
        **kwargs
    ) -> None:
        super().__init__(
            num_features, eps, momentum, affine, track_running_stats, freeze, **kwargs
        )
        self.group = group

    def _check_input_ndim(self, inp):
        if len(inp.shape) not in {2, 3, 4}:
            raise ValueError(
                "expected 2D, 3D or 4D input (got {}D input)".format(len(inp.shape))
            )

    def forward(self, inp):
        self._check_input_ndim(inp)

        inp_shape = inp.shape
        _ndims = len(inp_shape)
        if _ndims != 4:
            new_shape = Tensor([1, 1, 1, 1], device=inp.device)
            origin_shape = inp_shape
            if _ndims == 2:
                new_shape[:2] = origin_shape[:2]
            elif _ndims == 3:
                new_shape[:3] = origin_shape[:3]
            else:
                raise ValueError(
                    "expected 2D, 3D or 4D input (got {}D input)".format(len(inp_shape))
                )
            inp = inp.reshape(new_shape)

        if self.training and self.track_running_stats:
            exponential_average_factor = self.momentum
        else:
            exponential_average_factor = 0.0  # useless

        _weight = self.weight
        _bias = self.bias

        if self.freeze:
            if _weight is not None:
                _weight = _weight.detach()
            if _bias is not None:
                _bias = _bias.detach()

        output = sync_batch_norm(
            inp,
            self.running_mean,
            self.running_var,
            _weight,
            _bias,
            training=(self.training and not self.freeze)
            or ((self.running_mean is None) and (self.running_var is None)),
            momentum=exponential_average_factor,
            eps=self.eps,
            group=self.group,
        )

        if _ndims != 4:
            output = output.reshape(origin_shape)

        return output


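# Typical ``SyncBatchNorm`` usage (a minimal sketch; it assumes a multi-process
# distributed context has already been initialized elsewhere, and the variable
# names below are illustrative only):
#
#     norm = SyncBatchNorm(64)                       # statistics reduced over WORLD
#     # norm = SyncBatchNorm(64, group=some_group)   # or over a custom group
#     out = norm(feature_map)                        # feature_map: (N, C, H, W)

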
class BatchNorm1d(_BatchNorm):
    r"""Applies Batch Normalization over a 2D/3D tensor.

    Refer to :class:`~.BatchNorm2d` for more information.
    """

    def _check_input_ndim(self, inp):
        if len(inp.shape) not in {2, 3}:
            raise ValueError(
                "expected 2D or 3D input (got {}D input)".format(len(inp.shape))
            )


class BatchNorm2d(_BatchNorm):
    r"""Applies Batch Normalization over a 4D tensor.

    .. math::

        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable
    parameter vectors.

    By default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.9.

    If :attr:`track_running_stats` is set to ``False``, this layer does not
    keep running estimates, and batch statistics are used at evaluation time
    instead.

    Because the Batch Normalization is done over the `C` dimension, computing
    statistics on `(N, H, W)` slices, it's common terminology to call this
    Spatial Batch Normalization.

    .. note::

        The update formula for ``running_mean`` and ``running_var`` (taking ``running_mean`` as an example) is

        .. math::

            \textrm{running\_mean} = \textrm{momentum} \times \textrm{running\_mean} + (1 - \textrm{momentum}) \times \textrm{batch\_mean}

        which could be defined differently in other frameworks. Most notably, a ``momentum`` of 0.1 in PyTorch
        is equivalent to a ``momentum`` of 0.9 here (a worked numeric example is given in the comment at the
        end of this file).

    Args:
        num_features: usually :math:`C` from an input of shape
            :math:`(N, C, H, W)` or the highest ranked dimension of an input
            less than 4D.
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the ``running_mean`` and ``running_var`` computation.
            Default: 0.9
        affine: a boolean value that when set to True, this module has
            learnable affine parameters. Default: True
        track_running_stats: when set to True, this module tracks the
            running mean and variance. When set to False, this module does not
            track such statistics and always uses batch statistics in both training
            and eval modes. Default: True
        freeze: when set to True, this module does not update the
            running mean and variance, and uses the running mean and variance instead of
            the batch mean and batch variance to normalize the input. The parameter takes effect
            only when the module is initialized with track_running_stats as True.
            Default: False

    Examples:
        >>> import numpy as np
        >>> # With Learnable Parameters
        >>> m = M.BatchNorm2d(4)
        >>> inp = mge.tensor(np.random.rand(1, 4, 3, 3).astype("float32"))
        >>> oup = m(inp)
        >>> print(m.weight.numpy().flatten(), m.bias.numpy().flatten())
        [1. 1. 1. 1.] [0. 0. 0. 0.]
        >>> # Without Learnable Parameters
        >>> m = M.BatchNorm2d(4, affine=False)
        >>> oup = m(inp)
        >>> print(m.weight, m.bias)
        None None
    """

    def _check_input_ndim(self, inp):
        if len(inp.shape) != 4:
            raise ValueError("expected 4D input (got {}D input)".format(len(inp.shape)))
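

# Worked example for the ``momentum`` note in ``BatchNorm2d``: with the update
#     running_mean = momentum * running_mean + (1 - momentum) * batch_mean
# and momentum=0.9, a running mean of 0.0 and a batch mean of 1.0 give
#     0.9 * 0.0 + 0.1 * 1.0 = 0.1,
# i.e. the running estimate moves 10% of the way toward the batch statistic on
# each step; PyTorch expresses the same behaviour with momentum=0.1.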