You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

initializer.py 14 kB

6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Initializer for cell parameters."""
  16. import numbers
  17. import math
  18. from functools import reduce
  19. import numpy as np
  20. from scipy.stats import truncnorm
  21. from mindspore import log as logger
  22. from . import dtype as mstype
  23. from .tensor import Tensor
  24. from .._c_expression import random_normal
  25. _INITIALIZER_ALIAS = dict()
  26. class Initializer:
  27. """
  28. The base class of the initializer.
  29. Args:
  30. kwargs (dict): Keyword arguments for Initializer.
  31. Returns:
  32. Array, assigned array.
  33. """
  34. def __init__(self, **kwargs):
  35. self._kwargs = kwargs
  36. self.shape = None
  37. self.dtype = None
  38. def _initialize(self, *kwargs):
  39. raise NotImplementedError('Must be overridden!')
  40. def __call__(self, arr):
  41. return self._initialize(arr)
  42. @property
  43. def shape(self):
  44. return self._shape
  45. @shape.setter
  46. def shape(self, shape):
  47. self._shape = shape
  48. @property
  49. def dtype(self):
  50. return self._dtype
  51. @dtype.setter
  52. def dtype(self, dtype):
  53. self._dtype = dtype
  54. def to_tensor(self, slice_index=None, shape=None):
  55. """
  56. Get the tensor format data of this Initializer.
  57. Args:
  58. slice_index (int): Slice index of a parameter's slices.
  59. Used when initialize a slice of a parameter, it guarantee that
  60. devices use the same slice can generate the same tensor.
  61. shape (list[int]): Shape of the slice, used when initialize a slice of the parameter.
  62. """
  63. arr = None
  64. if shape is None:
  65. shape = self.shape
  66. try:
  67. arr = np.ndarray(shape)
  68. except ValueError:
  69. msg = "Error shape={}".format(shape)
  70. logger.error(msg)
  71. raise ValueError(msg)
  72. if slice_index is not None:
  73. np.random.seed(slice_index)
  74. self.__call__(arr)
  75. return Tensor(arr, dtype=self.dtype)
  76. def _register(*aliases):
  77. """Return the alias register."""
  78. def alias_reg(cls):
  79. name = cls.__name__
  80. name = name.lower()
  81. if name not in _INITIALIZER_ALIAS:
  82. _INITIALIZER_ALIAS[name] = cls
  83. for alias in aliases:
  84. if alias not in _INITIALIZER_ALIAS:
  85. _INITIALIZER_ALIAS[alias] = cls
  86. return cls
  87. return alias_reg
  88. def _assignment(arr, num):
  89. """Assign the value of `num` to `arr`."""
  90. if arr.shape == ():
  91. arr = arr.reshape((1))
  92. arr[:] = num
  93. arr = arr.reshape(())
  94. else:
  95. if isinstance(num, np.ndarray):
  96. arr[:] = num[:]
  97. else:
  98. arr[:] = num
  99. return arr
  100. @_register('zeros')
  101. class Zero(Initializer):
  102. """
  103. Initialize the array to zero.
  104. Args:
  105. arr (Array): The array to be assigned.
  106. Returns:
  107. Array, assigned array.
  108. """
  109. def _initialize(self, arr):
  110. _assignment(arr, 0)
  111. @_register('ones')
  112. class One(Initializer):
  113. """
  114. Initialize the array to one.
  115. Args:
  116. arr (Array): The array to be assigned.
  117. Returns:
  118. Array, assigned array.
  119. """
  120. def _initialize(self, arr):
  121. _assignment(arr, 1)
  122. def _calculate_fan_in_and_fan_out(shape):
  123. """
  124. calculate fan_in and fan_out
  125. Args:
  126. shape (tuple): input shape.
  127. Returns:
  128. Tuple, a tuple with two elements, the first element is `n_in` and the second element is `n_out`.
  129. """
  130. dimensions = len(shape)
  131. if dimensions < 2:
  132. raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")
  133. if dimensions == 2: # Linear
  134. fan_in = shape[1]
  135. fan_out = shape[0]
  136. else:
  137. num_input_fmaps = shape[1]
  138. num_output_fmaps = shape[0]
  139. receptive_field_size = 1
  140. if dimensions > 2:
  141. receptive_field_size = shape[2] * shape[3]
  142. fan_in = num_input_fmaps * receptive_field_size
  143. fan_out = num_output_fmaps * receptive_field_size
  144. return fan_in, fan_out
  145. def _calculate_correct_fan(shape, mode):
  146. """
  147. Calculate fan.
  148. Args:
  149. shape (tuple): input shape.
  150. mode (str): only support fan_in and fan_out.
  151. Returns:
  152. fan_in or fan_out.
  153. """
  154. mode = mode.lower()
  155. valid_modes = ['fan_in', 'fan_out']
  156. if mode not in valid_modes:
  157. raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
  158. fan_in, fan_out = _calculate_fan_in_and_fan_out(shape)
  159. return fan_in if mode == 'fan_in' else fan_out
  160. def _calculate_gain(nonlinearity, param=None):
  161. """
  162. Calculate gain.
  163. Args:
  164. nonlinearity (str): nonlinearity function.
  165. param (str): used to calculate negative_slope.
  166. Returns:
  167. number.
  168. """
  169. linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
  170. if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
  171. res = 1
  172. elif nonlinearity == 'tanh':
  173. res = 5.0 / 3
  174. elif nonlinearity == 'relu':
  175. res = math.sqrt(2.0)
  176. elif nonlinearity == 'leaky_relu':
  177. if param is None:
  178. negative_slope = 0.01
  179. elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
  180. # True/False are instances of int, hence check above
  181. negative_slope = param
  182. else:
  183. raise ValueError("negative_slope {} not a valid number".format(param))
  184. res = math.sqrt(2.0 / (1 + negative_slope ** 2))
  185. else:
  186. raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
  187. return res
  188. def _calculate_in_and_out(arr):
  189. """
  190. Calculate n_in and n_out.
  191. Args:
  192. arr (Array): Input array.
  193. Returns:
  194. Tuple, a tuple with two elements, the first element is `n_in` and the second element is `n_out`.
  195. """
  196. dim = len(arr.shape)
  197. if dim < 2:
  198. raise ValueError("If initialize data with xavier uniform, the dimension of data must be greater than 1.")
  199. n_in = arr.shape[1]
  200. n_out = arr.shape[0]
  201. if dim > 2:
  202. counter = reduce(lambda x, y: x * y, arr.shape[2:])
  203. n_in *= counter
  204. n_out *= counter
  205. return n_in, n_out
  206. @_register('xavier_uniform')
  207. class XavierUniform(Initializer):
  208. r"""
  209. Initialize the array with xavier uniform algorithm, and from a uniform distribution collect samples within
  210. U[-boundary, boundary] where :math:`boundary = gain * \sqrt{\frac{6}{n_{in} + n_{out}}}`.
  211. Args:
  212. gain (Array): The array to be assigned. Default: 1.
  213. Returns:
  214. Array, assigned array.
  215. """
  216. def __init__(self, gain=1):
  217. super(XavierUniform, self).__init__(gain=gain)
  218. self.gain = gain
  219. def _initialize(self, arr):
  220. n_in, n_out = _calculate_in_and_out(arr)
  221. boundary = self.gain * math.sqrt(6.0 / (n_in + n_out))
  222. data = np.random.uniform(-boundary, boundary, arr.shape)
  223. _assignment(arr, data)
  224. @_register('he_uniform')
  225. class HeUniform(Initializer):
  226. r"""
  227. Initialize the array with He kaiming uniform algorithm, and from a uniform distribution collect samples within
  228. U[-boundary, boundary] where :math:`boundary = \sqrt{\frac{6}{n_{in}}}` where :math:`n_{in}` is the number of
  229. input units in the weight tensor.
  230. Args:
  231. arr (Array): The array to be assigned.
  232. Returns:
  233. Array, assigned array.
  234. """
  235. def _initialize(self, arr):
  236. n_in, _ = _calculate_in_and_out(arr)
  237. boundary = math.sqrt(6.0 / n_in)
  238. data = np.random.uniform(-boundary, boundary, arr.shape)
  239. _assignment(arr, data)
  240. @_register('he_normal')
  241. class HeNormal(Initializer):
  242. r"""
  243. Initialize the array with He kaiming Normal algorithm, and from a normal distribution collect samples within
  244. N(0, sigma).
  245. Args:
  246. negative_slope (int, float, bool): Default: 0, used when nonlinearity is 'leaky_relu'.
  247. mode (str): Default: fan_in.
  248. nonlinearity (str): Default: leaky_relu.
  249. Returns:
  250. Array, assigned array.
  251. """
  252. def __init__(self, negative_slope=0, mode='fan_in', nonlinearity='leaky_relu'):
  253. super(HeNormal, self).__init__(negative_slope=negative_slope, mode=mode, nonlinearity=nonlinearity)
  254. self.negative_slope = negative_slope
  255. self.mode = mode
  256. self.nonlinearity = nonlinearity
  257. def _initialize(self, arr):
  258. fan = _calculate_correct_fan(arr.shape, self.mode)
  259. gain = _calculate_gain(self.nonlinearity, self.negative_slope)
  260. std = gain / math.sqrt(fan)
  261. data = np.random.normal(0, std, arr.shape)
  262. _assignment(arr, data)
  263. class Constant(Initializer):
  264. """
  265. Initialize a constant.
  266. Args:
  267. value (Union[int, numpy.ndarray]): The value to initialize.
  268. Returns:
  269. Array, initialize array.
  270. """
  271. def __init__(self, value):
  272. super(Constant, self).__init__(value=value)
  273. self.value = value
  274. def _initialize(self, arr):
  275. _assignment(arr, self.value)
  276. @_register()
  277. class Uniform(Initializer):
  278. """
  279. Initialize a uniform array, and obtain values U(-scale, scale) from the uniform distribution
  280. to fill the input tensor.
  281. Args:
  282. scale (float): The scale of the array. Default: 0.07.
  283. Returns:
  284. Array, uniform array.
  285. """
  286. def __init__(self, scale=0.07):
  287. super(Uniform, self).__init__(scale=scale)
  288. self.scale = scale
  289. def _initialize(self, arr):
  290. tmp = np.random.uniform(-self.scale, self.scale, arr.shape)
  291. _assignment(arr, tmp)
@_register()
class Normal(Initializer):
    """
    Initialize a normal array: values drawn from the normal distribution
    N(0, sigma) fill the input tensor.

    Args:
        sigma (float): The sigma of the array. Default: 0.01.

    Returns:
        Array, normal array.
    """
    def __init__(self, sigma=0.01):
        super(Normal, self).__init__(sigma=sigma)
        self.sigma = sigma

    def _initialize(self, arr):
        # Reuse the current numpy seed so slice initialization stays
        # reproducible (to_tensor seeds numpy with the slice index).
        seed = np.random.get_state()[1][0]
        output_tensor = Tensor(np.zeros(arr.shape, dtype=np.float32))
        # Delegates sampling to the C++ backend, writing into output_tensor.
        random_normal(0, self.sigma, arr.shape, seed, output_tensor)
        output_data = output_tensor.asnumpy()
        # NOTE(review): sigma is passed to `random_normal` above and applied
        # again here — if `random_normal` already scales by sigma this is a
        # double scaling; confirm against the _c_expression.random_normal
        # contract before changing.
        output_data *= self.sigma
        _assignment(arr, output_data)
  312. @_register()
  313. class TruncatedNormal(Initializer):
  314. """
  315. Initialize a truncated normal distribution which is a bounded normal distribution within N(low, high).
  316. Args:
  317. sigma (float): The sigma of the array. Default: 0.01.
  318. Returns:
  319. Array, truncated normal array.
  320. """
  321. def __init__(self, sigma=0.01):
  322. super(TruncatedNormal, self).__init__(sigma=sigma)
  323. self.sigma = sigma
  324. def _initialize(self, arr):
  325. tmp = truncnorm.rvs(-2, 2, loc=0, scale=self.sigma, size=arr.shape, random_state=None)
  326. _assignment(arr, tmp)
  327. def initializer(init, shape=None, dtype=mstype.float32):
  328. """
  329. Create and initialize a tensor.
  330. Args:
  331. init (Union[Tensor, str, Initializer, numbers.Number]): Initialize value.
  332. - `str`: The `init` should be the alias of the class inheriting from `Initializer` and the corresponding
  333. class will be called.
  334. - `Initializer`: The `init` should be the class inheriting from `Initializer` to initialize tensor.
  335. - `numbers.Number`: The `Constant` will be called to initialize tensor.
  336. shape (Union[tuple, list, int]): A list of integers, a tuple of integers or an integer as the shape of
  337. output. Default: None.
  338. dtype (:class:`mindspore.dtype`): The type of data in initialized tensor. Default: mindspore.float32.
  339. Returns:
  340. Union[Tensor, Initializer], When `init` is Tensor, the return is Tensor object,
  341. otherwise the return is Initialize object.
  342. Examples:
  343. >>> tensor = initializer('ones', [1, 2, 3], mindspore.float32)
  344. >>> tensor = initializer(One(), [1, 2, 3], mindspore.float32)
  345. >>> tensor = initializer(0, [1, 2, 3], mindspore.float32)
  346. """
  347. if not isinstance(init, (Tensor, numbers.Number, str, Initializer)):
  348. raise TypeError("Unsupported init type '{}'.".format(type(init)))
  349. if isinstance(init, Tensor):
  350. init_shape = init.shape
  351. shape = shape if isinstance(shape, (tuple, list)) else [shape]
  352. if shape is not None and init_shape != tuple(shape):
  353. raise ValueError("The shape of init should be same as variable shape, but got the shape of init {} and "
  354. "the variable shape {}.".format(list(init.shape), shape))
  355. return init
  356. if isinstance(shape, list):
  357. shape = tuple(shape)
  358. elif isinstance(shape, numbers.Number):
  359. shape = (shape,)
  360. if isinstance(init, Initializer):
  361. init.shape = init.shape if init.shape is not None else shape
  362. init.dtype = init.dtype if init.dtype is not None else dtype
  363. return init
  364. if isinstance(init, str):
  365. init_obj = _INITIALIZER_ALIAS[init.lower()]()
  366. if init_obj is None:
  367. raise ValueError("The class corresponding to '{}' was not found.".format(init))
  368. init = init_obj
  369. init.shape = shape
  370. init.dtype = dtype
  371. return init
  372. if isinstance(init, numbers.Number):
  373. init_obj = Constant(init)
  374. init_obj.shape = shape
  375. init_obj.dtype = dtype
  376. return init_obj
  377. raise TypeError("Unsupported init type '{}'.".format(type(init)))
# Public API of this module.
__all__ = [
    'Initializer',
    'initializer',
    'TruncatedNormal',
    'Normal',
    'Uniform',
    'HeUniform',
    'HeNormal',
    'XavierUniform',
    'One',
    'Zero',
    'Constant']