You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

initializer.py 8.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Initializer for cell parameters."""
  16. import numbers
  17. import math
  18. from functools import reduce
  19. import numpy as np
  20. from scipy.stats import truncnorm
  21. from mindspore import log as logger
  22. from . import dtype as mstype
  23. from .tensor import Tensor
  24. _INITIALIZER_ALIAS = dict()
  25. class Initializer:
  26. """
  27. The base class of the initializer.
  28. Args:
  29. kwargs (dict): Keyword arguments for Initializer.
  30. Returns:
  31. Array, assigned array.
  32. """
  33. def __init__(self, **kwargs):
  34. self._kwargs = kwargs
  35. def _initialize(self, *kwargs):
  36. raise NotImplementedError('Must be overridden!')
  37. def __call__(self, arr):
  38. return self._initialize(arr)
  39. def _register(*aliases):
  40. """Return the alias register."""
  41. def alias_reg(cls):
  42. name = cls.__name__
  43. name = name.lower()
  44. if name not in _INITIALIZER_ALIAS:
  45. _INITIALIZER_ALIAS[name] = cls
  46. for alias in aliases:
  47. if alias not in _INITIALIZER_ALIAS:
  48. _INITIALIZER_ALIAS[alias] = cls
  49. return cls
  50. return alias_reg
  51. def _assignment(arr, num):
  52. """Assign the value of `num` to `arr`."""
  53. if arr.shape == ():
  54. arr = arr.reshape((1))
  55. arr[:] = num
  56. arr = arr.reshape(())
  57. else:
  58. if isinstance(num, np.ndarray):
  59. arr[:] = num[:]
  60. else:
  61. arr[:] = num
  62. return arr
  63. @_register('zeros')
  64. class Zero(Initializer):
  65. """
  66. Initialize the array to zero.
  67. Args:
  68. arr (Array): The array to be assigned.
  69. Returns:
  70. Array, assigned array.
  71. """
  72. def _initialize(self, arr):
  73. _assignment(arr, 0)
  74. @_register('ones')
  75. class One(Initializer):
  76. """
  77. Initialize the array to one.
  78. Args:
  79. arr (Array): The array to be assigned.
  80. Returns:
  81. Array, assigned array.
  82. """
  83. def _initialize(self, arr):
  84. _assignment(arr, 1)
  85. def _calculate_in_and_out(arr):
  86. """
  87. Calculate n_in and n_out.
  88. Args:
  89. arr (Array): Input array.
  90. Returns:
  91. Tuple, a tuple with two elements, the first element is `n_in` and the second element is `n_out`.
  92. """
  93. dim = len(arr.shape)
  94. if dim < 2:
  95. raise ValueError("If initialize data with xavier uniform, the dimension of data must greater than 1.")
  96. n_in = arr.shape[1]
  97. n_out = arr.shape[0]
  98. if dim > 2:
  99. counter = reduce(lambda x, y: x * y, arr.shape[2:])
  100. n_in *= counter
  101. n_out *= counter
  102. return n_in, n_out
  103. @_register('xavier_uniform')
  104. class XavierUniform(Initializer):
  105. r"""
  106. Initialize the array with xavier uniform algorithm, and from a uniform distribution collect samples within
  107. U[-boundary, boundary] where :math:`boundary = gain * \sqrt{\frac{6}{n_{in} + n_{out}}}`.
  108. Args:
  109. gain (Array): The array to be assigned. Default: 1.
  110. Returns:
  111. Array, assigned array.
  112. """
  113. def __init__(self, gain=1):
  114. super(XavierUniform, self).__init__(gain=gain)
  115. self.gain = gain
  116. def _initialize(self, arr):
  117. n_in, n_out = _calculate_in_and_out(arr)
  118. boundary = self.gain * math.sqrt(6.0 / (n_in + n_out))
  119. data = np.random.uniform(-boundary, boundary, arr.shape)
  120. _assignment(arr, data)
  121. @_register('he_uniform')
  122. class HeUniform(Initializer):
  123. r"""
  124. Initialize the array with He kaiming uniform algorithm, and from a uniform distribution collect samples within
  125. U[-boundary, boundary] where :math:`boundary = \sqrt{\frac{6}{n_{in}}}` where :math:`n_{in}` is the number of
  126. input units in the weight tensor.
  127. Args:
  128. arr (Array): The array to be assigned.
  129. Returns:
  130. Array, assigned array.
  131. """
  132. def _initialize(self, arr):
  133. n_in, _ = _calculate_in_and_out(arr)
  134. boundary = math.sqrt(6.0 / n_in)
  135. data = np.random.uniform(-boundary, boundary, arr.shape)
  136. _assignment(arr, data)
  137. class Constant(Initializer):
  138. """
  139. Initialize a constant.
  140. Args:
  141. value (Union[int, numpy.ndarray]): The value to initialize.
  142. Returns:
  143. Array, initialize array.
  144. """
  145. def __init__(self, value):
  146. super(Constant, self).__init__(value=value)
  147. self.value = value
  148. def _initialize(self, arr):
  149. _assignment(arr, self.value)
  150. @_register()
  151. class Uniform(Initializer):
  152. """
  153. Initialize a uniform array, and obtain values U(-scale, scale) from the uniform distribution
  154. to fill the input tensor.
  155. Args:
  156. scale (float): The scale of the array. Default: 0.07.
  157. Returns:
  158. Array, uniform array.
  159. """
  160. def __init__(self, scale=0.07):
  161. super(Uniform, self).__init__(scale=scale)
  162. self.scale = scale
  163. def _initialize(self, arr):
  164. tmp = np.random.uniform(-self.scale, self.scale, arr.shape)
  165. _assignment(arr, tmp)
  166. @_register()
  167. class Normal(Initializer):
  168. """
  169. Initialize a normal array, and obtain values N(0, sigma) from the uniform distribution
  170. to fill the input tensor.
  171. Args:
  172. sigma (float): The sigma of the array. Default: 0.01.
  173. Returns:
  174. Array, normal array.
  175. """
  176. def __init__(self, sigma=0.01):
  177. super(Normal, self).__init__(sigma=sigma)
  178. self.sigma = sigma
  179. def _initialize(self, arr):
  180. tmp = np.random.normal(0, self.sigma, arr.shape)
  181. _assignment(arr, tmp)
  182. @_register()
  183. class TruncatedNormal(Initializer):
  184. """
  185. Initialize a truncated normal distribution which is a bounded normal distribution within N(low, high).
  186. Args:
  187. sigma (float): The sigma of the array. Default: 0.01.
  188. Returns:
  189. Array, truncated normal array.
  190. """
  191. def __init__(self, sigma=0.01):
  192. super(TruncatedNormal, self).__init__(sigma=sigma)
  193. self.sigma = sigma
  194. def _initialize(self, arr):
  195. tmp = truncnorm.rvs(-2, 2, loc=0, scale=self.sigma, size=arr.shape, random_state=None)
  196. _assignment(arr, tmp)
  197. def initializer(init, shape=None, dtype=mstype.float32):
  198. """
  199. Create and initialize a tensor.
  200. Args:
  201. init (Union[Tensor, str, Initializer, numbers.Number]): Initialize value.
  202. - `str`: The `init` should be the alias of the class inheriting from `Initializer` and the corresponding
  203. class will be called.
  204. - `Initializer`: The `init` should be the class inheriting from `Initializer` to initialize tensor.
  205. - `numbers.Number`: The `Constant` will be called to initialize tensor.
  206. shape (Union[tuple, list, int]): A list of integers, a tuple of integers or an integer as the shape of
  207. output. Default: None.
  208. dtype (:class:`mindspore.dtype`): The type of data in initialized tensor. Default: mindspore.float32.
  209. Returns:
  210. Tensor, initialized tensor.
  211. Examples:
  212. >>> tensor = initializer('ones', [1, 2, 3], mindspore.float32)
  213. """
  214. if not isinstance(init, (Tensor, numbers.Number, str, Initializer)):
  215. raise TypeError('Unsupported init type.')
  216. if isinstance(init, Tensor):
  217. init_shape = init.shape()
  218. shape = shape if isinstance(shape, (tuple, list)) else [shape]
  219. if shape is not None and init_shape != tuple(shape):
  220. raise ValueError("The shape of init should be same as variable shape, but got the shape of init {} and "
  221. "the variable shape {}.".format(list(init.shape()), shape))
  222. return init
  223. try:
  224. arr = np.ndarray(shape)
  225. except ValueError:
  226. msg = "Error shape={}".format(shape)
  227. logger.error(msg)
  228. raise ValueError(msg)
  229. if isinstance(init, numbers.Number):
  230. init_obj = Constant(init)
  231. elif isinstance(init, str):
  232. init_obj = _INITIALIZER_ALIAS[init.lower()]()
  233. else:
  234. init_obj = init
  235. init_obj(arr)
  236. return Tensor(arr, dtype=dtype)
  237. __all__ = [
  238. 'Initializer',
  239. 'initializer',
  240. 'TruncatedNormal',
  241. 'Normal',
  242. 'Uniform',
  243. 'HeUniform',
  244. 'XavierUniform',
  245. 'One',
  246. 'Zero',
  247. 'Constant']