You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

initializer.py 10 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Initializer for cell parameters."""
  16. import numbers
  17. import math
  18. from functools import reduce
  19. import numpy as np
  20. from scipy.stats import truncnorm
  21. from mindspore import log as logger
  22. from . import dtype as mstype
  23. from .tensor import Tensor
  24. _INITIALIZER_ALIAS = dict()
  25. class Initializer:
  26. """
  27. The base class of the initializer.
  28. Args:
  29. kwargs (dict): Keyword arguments for Initializer.
  30. Returns:
  31. Array, assigned array.
  32. """
  33. def __init__(self, **kwargs):
  34. self._kwargs = kwargs
  35. self.shape = None
  36. self.dtype = None
  37. def _initialize(self, *kwargs):
  38. raise NotImplementedError('Must be overridden!')
  39. def __call__(self, arr):
  40. return self._initialize(arr)
  41. @property
  42. def shape(self):
  43. return self._shape
  44. @shape.setter
  45. def shape(self, shape):
  46. self._shape = shape
  47. @property
  48. def dtype(self):
  49. return self._dtype
  50. @dtype.setter
  51. def dtype(self, dtype):
  52. self._dtype = dtype
  53. def to_tensor(self, slice_index=None, shape=None):
  54. """
  55. Get the tensor format data of this Initializer.
  56. Args:
  57. slice_index (int): Slice index of a parameter's slices.
  58. Used when initialize a slice of a parameter, it guarantee that
  59. devices use the same slice can generate the same tensor.
  60. shape (list[int]): Shape of the slice, used when initialize a slice of the parameter.
  61. """
  62. arr = None
  63. if shape is None:
  64. shape = self.shape
  65. try:
  66. arr = np.ndarray(shape)
  67. except ValueError:
  68. msg = "Error shape={}".format(shape)
  69. logger.error(msg)
  70. raise ValueError(msg)
  71. if slice_index is not None:
  72. np.random.seed(slice_index)
  73. self.__call__(arr)
  74. return Tensor(arr, dtype=self.dtype)
  75. def _register(*aliases):
  76. """Return the alias register."""
  77. def alias_reg(cls):
  78. name = cls.__name__
  79. name = name.lower()
  80. if name not in _INITIALIZER_ALIAS:
  81. _INITIALIZER_ALIAS[name] = cls
  82. for alias in aliases:
  83. if alias not in _INITIALIZER_ALIAS:
  84. _INITIALIZER_ALIAS[alias] = cls
  85. return cls
  86. return alias_reg
  87. def _assignment(arr, num):
  88. """Assign the value of `num` to `arr`."""
  89. if arr.shape == ():
  90. arr = arr.reshape((1))
  91. arr[:] = num
  92. arr = arr.reshape(())
  93. else:
  94. if isinstance(num, np.ndarray):
  95. arr[:] = num[:]
  96. else:
  97. arr[:] = num
  98. return arr
  99. @_register('zeros')
  100. class Zero(Initializer):
  101. """
  102. Initialize the array to zero.
  103. Args:
  104. arr (Array): The array to be assigned.
  105. Returns:
  106. Array, assigned array.
  107. """
  108. def _initialize(self, arr):
  109. _assignment(arr, 0)
  110. @_register('ones')
  111. class One(Initializer):
  112. """
  113. Initialize the array to one.
  114. Args:
  115. arr (Array): The array to be assigned.
  116. Returns:
  117. Array, assigned array.
  118. """
  119. def _initialize(self, arr):
  120. _assignment(arr, 1)
  121. def _calculate_in_and_out(arr):
  122. """
  123. Calculate n_in and n_out.
  124. Args:
  125. arr (Array): Input array.
  126. Returns:
  127. Tuple, a tuple with two elements, the first element is `n_in` and the second element is `n_out`.
  128. """
  129. dim = len(arr.shape)
  130. if dim < 2:
  131. raise ValueError("If initialize data with xavier uniform, the dimension of data must greater than 1.")
  132. n_in = arr.shape[1]
  133. n_out = arr.shape[0]
  134. if dim > 2:
  135. counter = reduce(lambda x, y: x * y, arr.shape[2:])
  136. n_in *= counter
  137. n_out *= counter
  138. return n_in, n_out
  139. @_register('xavier_uniform')
  140. class XavierUniform(Initializer):
  141. r"""
  142. Initialize the array with xavier uniform algorithm, and from a uniform distribution collect samples within
  143. U[-boundary, boundary] where :math:`boundary = gain * \sqrt{\frac{6}{n_{in} + n_{out}}}`.
  144. Args:
  145. gain (Array): The array to be assigned. Default: 1.
  146. Returns:
  147. Array, assigned array.
  148. """
  149. def __init__(self, gain=1):
  150. super(XavierUniform, self).__init__(gain=gain)
  151. self.gain = gain
  152. def _initialize(self, arr):
  153. n_in, n_out = _calculate_in_and_out(arr)
  154. boundary = self.gain * math.sqrt(6.0 / (n_in + n_out))
  155. data = np.random.uniform(-boundary, boundary, arr.shape)
  156. _assignment(arr, data)
  157. @_register('he_uniform')
  158. class HeUniform(Initializer):
  159. r"""
  160. Initialize the array with He kaiming uniform algorithm, and from a uniform distribution collect samples within
  161. U[-boundary, boundary] where :math:`boundary = \sqrt{\frac{6}{n_{in}}}` where :math:`n_{in}` is the number of
  162. input units in the weight tensor.
  163. Args:
  164. arr (Array): The array to be assigned.
  165. Returns:
  166. Array, assigned array.
  167. """
  168. def _initialize(self, arr):
  169. n_in, _ = _calculate_in_and_out(arr)
  170. boundary = math.sqrt(6.0 / n_in)
  171. data = np.random.uniform(-boundary, boundary, arr.shape)
  172. _assignment(arr, data)
  173. class Constant(Initializer):
  174. """
  175. Initialize a constant.
  176. Args:
  177. value (Union[int, numpy.ndarray]): The value to initialize.
  178. Returns:
  179. Array, initialize array.
  180. """
  181. def __init__(self, value):
  182. super(Constant, self).__init__(value=value)
  183. self.value = value
  184. def _initialize(self, arr):
  185. _assignment(arr, self.value)
  186. @_register()
  187. class Uniform(Initializer):
  188. """
  189. Initialize a uniform array, and obtain values U(-scale, scale) from the uniform distribution
  190. to fill the input tensor.
  191. Args:
  192. scale (float): The scale of the array. Default: 0.07.
  193. Returns:
  194. Array, uniform array.
  195. """
  196. def __init__(self, scale=0.07):
  197. super(Uniform, self).__init__(scale=scale)
  198. self.scale = scale
  199. def _initialize(self, arr):
  200. tmp = np.random.uniform(-self.scale, self.scale, arr.shape)
  201. _assignment(arr, tmp)
  202. @_register()
  203. class Normal(Initializer):
  204. """
  205. Initialize a normal array, and obtain values N(0, sigma) from the uniform distribution
  206. to fill the input tensor.
  207. Args:
  208. sigma (float): The sigma of the array. Default: 0.01.
  209. Returns:
  210. Array, normal array.
  211. """
  212. def __init__(self, sigma=0.01):
  213. super(Normal, self).__init__(sigma=sigma)
  214. self.sigma = sigma
  215. def _initialize(self, arr):
  216. tmp = np.random.normal(0, self.sigma, arr.shape)
  217. _assignment(arr, tmp)
  218. @_register()
  219. class TruncatedNormal(Initializer):
  220. """
  221. Initialize a truncated normal distribution which is a bounded normal distribution within N(low, high).
  222. Args:
  223. sigma (float): The sigma of the array. Default: 0.01.
  224. Returns:
  225. Array, truncated normal array.
  226. """
  227. def __init__(self, sigma=0.01):
  228. super(TruncatedNormal, self).__init__(sigma=sigma)
  229. self.sigma = sigma
  230. def _initialize(self, arr):
  231. tmp = truncnorm.rvs(-2, 2, loc=0, scale=self.sigma, size=arr.shape, random_state=None)
  232. _assignment(arr, tmp)
  233. def initializer(init, shape=None, dtype=mstype.float32):
  234. """
  235. Create and initialize a tensor.
  236. Args:
  237. init (Union[Tensor, str, Initializer, numbers.Number]): Initialize value.
  238. - `str`: The `init` should be the alias of the class inheriting from `Initializer` and the corresponding
  239. class will be called.
  240. - `Initializer`: The `init` should be the class inheriting from `Initializer` to initialize tensor.
  241. - `numbers.Number`: The `Constant` will be called to initialize tensor.
  242. shape (Union[tuple, list, int]): A list of integers, a tuple of integers or an integer as the shape of
  243. output. Default: None.
  244. dtype (:class:`mindspore.dtype`): The type of data in initialized tensor. Default: mindspore.float32.
  245. Returns:
  246. Union[Tensor, Initializer], When `init` is Tensor, the return is Tensor object,
  247. otherwise the return is Initialize object.
  248. Examples:
  249. >>> tensor = initializer('ones', [1, 2, 3], mindspore.float32)
  250. """
  251. if not isinstance(init, (Tensor, numbers.Number, str, Initializer)):
  252. raise TypeError("Unsupported init type '{}'.".format(type(init)))
  253. if isinstance(init, Tensor):
  254. init_shape = init.shape
  255. shape = shape if isinstance(shape, (tuple, list)) else [shape]
  256. if shape is not None and init_shape != tuple(shape):
  257. raise ValueError("The shape of init should be same as variable shape, but got the shape of init {} and "
  258. "the variable shape {}.".format(list(init.shape), shape))
  259. return init
  260. if isinstance(shape, list):
  261. shape = tuple(shape)
  262. elif isinstance(shape, numbers.Number):
  263. shape = (shape,)
  264. if isinstance(init, Initializer):
  265. init.shape = init.shape if init.shape is not None else shape
  266. init.dtype = init.dtype if init.dtype is not None else dtype
  267. return init
  268. if isinstance(init, str):
  269. init_obj = _INITIALIZER_ALIAS[init.lower()]()
  270. if init_obj is None:
  271. raise ValueError("The class corresponding to '{}' was not found.".format(init))
  272. init = init_obj
  273. init.shape = shape
  274. init.dtype = dtype
  275. return init
  276. if isinstance(init, numbers.Number):
  277. init_obj = Constant(init)
  278. init_obj.shape = shape
  279. init_obj.dtype = dtype
  280. return init_obj
  281. raise TypeError("Unsupported init type '{}'.".format(type(init)))
  282. __all__ = [
  283. 'Initializer',
  284. 'initializer',
  285. 'TruncatedNormal',
  286. 'Normal',
  287. 'Uniform',
  288. 'HeUniform',
  289. 'XavierUniform',
  290. 'One',
  291. 'Zero',
  292. 'Constant']