You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; can include dashes ('-'); and can be up to 35 characters long.

initializer.py 18 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Initializer for cell parameters."""
  16. import numbers
  17. import math
  18. from functools import reduce
  19. import numpy as np
  20. from scipy.stats import truncnorm
  21. from .seed import get_seed, _get_graph_seed
  22. from . import dtype as mstype
  23. from .tensor import Tensor
  24. from .._c_expression import random_normal
  25. _INITIALIZER_ALIAS = dict()
  26. class Initializer:
  27. """
  28. The base class of the initializer.
  29. Initialization of tensor basic attributes and model weight values.
  30. Args:
  31. kwargs (dict): Keyword arguments for Initializer.
  32. Returns:
  33. Array, an array after being initialized.
  34. """
  35. def __init__(self, **kwargs):
  36. self._kwargs = kwargs
  37. self._seed = None
  38. @property
  39. def seed(self):
  40. if self._seed is None:
  41. seed, seed2 = _get_graph_seed(get_seed(), "init")
  42. else:
  43. seed, seed2 = self._seed + 1, 0
  44. return seed, seed2
  45. @seed.setter
  46. def seed(self, value):
  47. self._seed = value
  48. def _initialize(self, *kwargs):
  49. raise NotImplementedError('Must be overridden!')
  50. def __call__(self, arr):
  51. return self._initialize(arr)
  52. def _register(*aliases):
  53. """Return the alias register."""
  54. def alias_reg(cls):
  55. name = cls.__name__
  56. name = name.lower()
  57. if name not in _INITIALIZER_ALIAS:
  58. _INITIALIZER_ALIAS[name] = cls
  59. for alias in aliases:
  60. if alias not in _INITIALIZER_ALIAS:
  61. _INITIALIZER_ALIAS[alias] = cls
  62. return cls
  63. return alias_reg
  64. def _assignment(arr, num):
  65. """Assign the value of `num` to `arr`."""
  66. if arr.shape == ():
  67. arr = arr.reshape((1))
  68. arr[:] = num
  69. arr = arr.reshape(())
  70. else:
  71. if isinstance(num, np.ndarray):
  72. arr[:] = num[:]
  73. else:
  74. arr[:] = num
  75. return arr
  76. @_register('zeros')
  77. class Zero(Initializer):
  78. """
  79. Initialize the array to zero.
  80. Args:
  81. arr (Array): The array to be assigned.
  82. Returns:
  83. Array, an array after being assigned.
  84. Examples:
  85. >>> import mindspore
  86. >>> from mindspore.common.initializer import initializer, Zero
  87. >>> tensor1 = initializer(Zero(), [1, 2, 3], mindspore.float32)
  88. >>> tensor2 = initializer('zeros', [1, 2, 3], mindspore.float32)
  89. """
  90. def _initialize(self, arr):
  91. _assignment(arr, 0)
  92. @_register('ones')
  93. class One(Initializer):
  94. """
  95. Initialize the array to one.
  96. Args:
  97. arr (Array): The array to be assigned.
  98. Returns:
  99. Array, assigned array.
  100. Examples:
  101. >>> import mindspore
  102. >>> from mindspore.common.initializer import initializer, One
  103. >>> tensor1 = initializer(One(), [1, 2, 3], mindspore.float32)
  104. >>> tensor2 = initializer('ones', [1, 2, 3], mindspore.float32)
  105. """
  106. def _initialize(self, arr):
  107. _assignment(arr, 1)
  108. def _calculate_fan_in_and_fan_out(shape):
  109. """
  110. calculate fan_in and fan_out
  111. Args:
  112. shape (tuple): input shape.
  113. Returns:
  114. Tuple, a tuple with two elements, the first element is `n_in` and the second element is `n_out`.
  115. """
  116. dimensions = len(shape)
  117. if dimensions < 2:
  118. raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")
  119. if dimensions == 2: # Linear
  120. fan_in = shape[1]
  121. fan_out = shape[0]
  122. else:
  123. num_input_fmaps = shape[1]
  124. num_output_fmaps = shape[0]
  125. receptive_field_size = 1
  126. if dimensions > 2:
  127. for i in range(2, dimensions):
  128. receptive_field_size *= shape[i]
  129. fan_in = num_input_fmaps * receptive_field_size
  130. fan_out = num_output_fmaps * receptive_field_size
  131. return fan_in, fan_out
  132. def _calculate_correct_fan(shape, mode):
  133. """
  134. Calculate fan.
  135. Args:
  136. shape (tuple): input shape.
  137. mode (str): only support fan_in and fan_out.
  138. Returns:
  139. fan_in or fan_out.
  140. """
  141. mode = mode.lower()
  142. valid_modes = ['fan_in', 'fan_out']
  143. if mode not in valid_modes:
  144. raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
  145. fan_in, fan_out = _calculate_fan_in_and_fan_out(shape)
  146. return fan_in if mode == 'fan_in' else fan_out
  147. def _calculate_gain(nonlinearity, param=None):
  148. """
  149. Calculate gain.
  150. Args:
  151. nonlinearity (str): nonlinearity function.
  152. param (str): used to calculate negative_slope.
  153. Returns:
  154. number.
  155. """
  156. linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
  157. if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
  158. res = 1
  159. elif nonlinearity == 'tanh':
  160. res = 5.0 / 3
  161. elif nonlinearity == 'relu':
  162. res = math.sqrt(2.0)
  163. elif nonlinearity == 'leaky_relu':
  164. if param is None:
  165. negative_slope = 0.01
  166. elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
  167. # True/False are instances of int, hence check above
  168. negative_slope = param
  169. else:
  170. raise ValueError("negative_slope {} not a valid number".format(param))
  171. res = math.sqrt(2.0 / (1 + negative_slope ** 2))
  172. else:
  173. raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
  174. return res
  175. def _calculate_in_and_out(arr):
  176. """
  177. Calculate n_in and n_out.
  178. Args:
  179. arr (Array): Input array.
  180. Returns:
  181. Tuple, a tuple with two elements, the first element is `n_in` and the second element is `n_out`.
  182. """
  183. dim = len(arr.shape)
  184. if dim < 2:
  185. raise ValueError("If initialize data with xavier uniform, the dimension of data must be greater than 1.")
  186. n_in = arr.shape[1]
  187. n_out = arr.shape[0]
  188. if dim > 2:
  189. counter = reduce(lambda x, y: x * y, arr.shape[2:])
  190. n_in *= counter
  191. n_out *= counter
  192. return n_in, n_out
  193. @_register('xavier_uniform')
  194. class XavierUniform(Initializer):
  195. r"""
  196. Initialize the array with xavier uniform algorithm, and from a uniform distribution collect samples within
  197. U[-boundary, boundary] The boundary is defined as:
  198. .. math::
  199. boundary = gain * \sqrt{\frac{6}{n_{in} + n_{out}}}
  200. - where :math:`n_{in}` is the number of input units in the weight tensor.
  201. - where :math:`n_{out}` is the number of output units in the weight tensor.
  202. For details of XavierUniform algorithm, please check
  203. `<http://proceedings.mlr.press/v9/glorot10a.html>`_.
  204. Args:
  205. gain (float): An optional scaling factor. Default: 1.
  206. Returns:
  207. Array, assigned array.
  208. Examples:
  209. >>> import mindspore
  210. >>> from mindspore.common.initializer import initializer, XavierUniform
  211. >>> tensor1 = initializer(XavierUniform(), [1, 2, 3], mindspore.float32)
  212. >>> tensor2 = initializer('xavier_uniform', [1, 2, 3], mindspore.float32)
  213. """
  214. def __init__(self, gain=1):
  215. super(XavierUniform, self).__init__(gain=gain)
  216. self.gain = gain
  217. def _initialize(self, arr):
  218. n_in, n_out = _calculate_fan_in_and_fan_out(arr.shape)
  219. boundary = self.gain * math.sqrt(6.0 / (n_in + n_out))
  220. data = np.random.uniform(-boundary, boundary, arr.shape)
  221. _assignment(arr, data)
  222. @_register('he_uniform')
  223. class HeUniform(Initializer):
  224. r"""
  225. Initialize the array with He kaiming uniform algorithm, and from a uniform distribution collect samples within
  226. U[-boundary, boundary] The boundary is defined as:
  227. .. math::
  228. boundary = \sqrt{\frac{6}{(1 + a^2) \times \text{fan_in}}}
  229. Args:
  230. negative_slope (int, float, bool): The negative slope of the rectifier used after this layer
  231. (only used when `nonlinearity` is 'leaky_relu'). Default: 0.
  232. mode (str): Either 'fan_in' or 'fan_out'. Choosing 'fan_in' preserves the magnitude of the
  233. variance of the weights in the forward pass. Choosing 'fan_out' preserves the magnitudes
  234. in the backwards pass. Default: fan_in.
  235. nonlinearity (str): The non-linear function, recommended to use only with 'relu' or 'leaky_relu'.
  236. Default: leaky_relu.
  237. Returns:
  238. Array, assigned array.
  239. Examples:
  240. >>> import mindspore
  241. >>> from mindspore.common.initializer import initializer, HeUniform
  242. >>> tensor1 = initializer(HeUniform(), [1, 2, 3], mindspore.float32)
  243. >>> tensor2 = initializer('he_uniform', [1, 2, 3], mindspore.float32)
  244. """
  245. def __init__(self, negative_slope=0, mode='fan_in', nonlinearity='leaky_relu'):
  246. super(HeUniform, self).__init__(negative_slope=negative_slope, mode=mode, nonlinearity=nonlinearity)
  247. self.negative_slope = negative_slope
  248. self.mode = mode
  249. self.nonlinearity = nonlinearity
  250. def _initialize(self, arr):
  251. fan = _calculate_correct_fan(arr.shape, self.mode)
  252. gain = _calculate_gain(self.nonlinearity, self.negative_slope)
  253. std = gain / math.sqrt(fan)
  254. boundary = math.sqrt(3.0) * std
  255. data = np.random.uniform(-boundary, boundary, arr.shape)
  256. _assignment(arr, data)
  257. @_register('he_normal')
  258. class HeNormal(Initializer):
  259. r"""
  260. Initialize the array with He kaiming Normal algorithm, and from a normal distribution collect samples within
  261. N(0, sigma).
  262. Args:
  263. negative_slope (int, float, bool): The negative slope of the rectifier used after this layer
  264. (only used when `nonlinearity` is 'leaky_relu'). Default: 0.
  265. mode (str): Either 'fan_in' or 'fan_out'. Choosing 'fan_in' preserves the magnitude of the
  266. variance of the weights in the forward pass. Choosing 'fan_out' preserves the magnitudes
  267. in the backwards pass. Default: fan_in.
  268. nonlinearity (str): The non-linear function, recommended to use only with 'relu' or 'leaky_relu'.
  269. Default: leaky_relu.
  270. Returns:
  271. Array, assigned array.
  272. Examples:
  273. >>> import mindspore
  274. >>> from mindspore.common.initializer import initializer, HeNormal
  275. >>> tensor1 = initializer(HeNormal(), [1, 2, 3], mindspore.float32)
  276. >>> tensor2 = initializer('he_normal', [1, 2, 3], mindspore.float32)
  277. """
  278. def __init__(self, negative_slope=0, mode='fan_in', nonlinearity='leaky_relu'):
  279. super(HeNormal, self).__init__(negative_slope=negative_slope, mode=mode, nonlinearity=nonlinearity)
  280. self.negative_slope = negative_slope
  281. self.mode = mode
  282. self.nonlinearity = nonlinearity
  283. def _initialize(self, arr):
  284. fan = _calculate_correct_fan(arr.shape, self.mode)
  285. gain = _calculate_gain(self.nonlinearity, self.negative_slope)
  286. std = gain / math.sqrt(fan)
  287. data = np.random.normal(0, std, arr.shape)
  288. _assignment(arr, data)
  289. class Constant(Initializer):
  290. """
  291. Initialize a constant.
  292. Args:
  293. value (Union[int, numpy.ndarray]): The value to initialize.
  294. Returns:
  295. Array, an array after being assigned.
  296. Examples:
  297. >>> import mindspore
  298. >>> from mindspore.common.initializer import initializer
  299. >>> tensor1 = initializer(0, [1, 2, 3], mindspore.float32)
  300. >>> tensor2 = initializer(5, [1, 2, 3], mindspore.float32)
  301. """
  302. def __init__(self, value):
  303. super(Constant, self).__init__(value=value)
  304. self.value = value
  305. def _initialize(self, arr):
  306. _assignment(arr, self.value)
  307. @_register()
  308. class Uniform(Initializer):
  309. """
  310. Initialize a uniform array, and obtain values U(-scale, scale) from the uniform distribution
  311. to fill the input tensor.
  312. Args:
  313. scale (float): The scale of the array. Default: 0.07.
  314. Returns:
  315. Array, uniform array.
  316. Examples:
  317. >>> import mindspore
  318. >>> from mindspore.common.initializer import initializer, Uniform
  319. >>> tensor1 = initializer(Uniform(), [1, 2, 3], mindspore.float32)
  320. >>> tensor2 = initializer('uniform', [1, 2, 3], mindspore.float32)
  321. """
  322. def __init__(self, scale=0.07):
  323. super(Uniform, self).__init__(scale=scale)
  324. self.scale = scale
  325. def _initialize(self, arr):
  326. tmp = np.random.uniform(-self.scale, self.scale, arr.shape)
  327. _assignment(arr, tmp)
  328. @_register()
  329. class Normal(Initializer):
  330. """
  331. Initialize a normal array, and obtain values N(sigma, mean) from the normal distribution
  332. to fill the input tensor.
  333. Args:
  334. sigma (float): The sigma of the array. Default: 0.01.
  335. mean (float): The mean of the array. Default: 0.0.
  336. Returns:
  337. Array, normal array.
  338. Examples:
  339. >>> import mindspore
  340. >>> from mindspore.common.initializer import initializer, Normal
  341. >>> tensor1 = initializer(Normal(), [1, 2, 3], mindspore.float32)
  342. >>> tensor2 = initializer('normal', [1, 2, 3], mindspore.float32)
  343. """
  344. def __init__(self, sigma=0.01, mean=0.0):
  345. super(Normal, self).__init__(sigma=sigma, mean=mean)
  346. self.sigma = sigma
  347. self.mean = mean
  348. def _initialize(self, arr):
  349. seed, seed2 = self.seed
  350. output_tensor = Tensor(np.zeros(arr.shape, dtype=np.float32))
  351. random_normal(arr.shape, seed, seed2, output_tensor)
  352. output_data = output_tensor.asnumpy()
  353. output_data = output_data * self.sigma + self.mean
  354. _assignment(arr, output_data)
  355. @_register()
  356. class TruncatedNormal(Initializer):
  357. """
  358. Initialize a truncated normal distribution which is a bounded normal distribution within N(low, high).
  359. Args:
  360. sigma (float): The sigma of the array. Default: 0.01.
  361. Returns:
  362. Array, truncated normal array.
  363. Examples:
  364. >>> import mindspore
  365. >>> from mindspore.common.initializer import initializer, TruncatedNormal
  366. >>> tensor1 = initializer(TruncatedNormal(), [1, 2, 3], mindspore.float32)
  367. >>> tensor2 = initializer('truncatedNormal', [1, 2, 3], mindspore.float32)
  368. """
  369. def __init__(self, sigma=0.01):
  370. super(TruncatedNormal, self).__init__(sigma=sigma)
  371. self.sigma = sigma
  372. def _initialize(self, arr):
  373. tmp = truncnorm.rvs(-2, 2, loc=0, scale=self.sigma, size=arr.shape, random_state=None)
  374. _assignment(arr, tmp)
  375. def initializer(init, shape=None, dtype=mstype.float32):
  376. """
  377. Create and initialize a tensor.
  378. Args:
  379. init (Union[Tensor, str, Initializer, numbers.Number]): Initialize value.
  380. - `str`: The `init` should be the alias of the class inheriting from `Initializer` and the corresponding
  381. class will be called.
  382. - `Initializer`: The `init` should be the class inheriting from `Initializer` to initialize tensor.
  383. - `numbers.Number`: The `Constant` will be called to initialize tensor.
  384. shape (Union[tuple, list, int]): A list of integers, a tuple of integers or an integer as the shape of
  385. output. Default: None.
  386. dtype (:class:`mindspore.dtype`): The type of data in initialized tensor. Default: mindspore.float32.
  387. Returns:
  388. Union[Tensor], return is Tensor object.
  389. Examples:
  390. >>> import mindspore
  391. >>> from mindspore.common.initializer import initializer, One
  392. >>> tensor1 = initializer('ones', [1, 2, 3], mindspore.float32)
  393. >>> tensor2 = initializer(One(), [1, 2, 3], mindspore.float32)
  394. >>> tensor3 = initializer(0, [1, 2, 3], mindspore.float32)
  395. """
  396. if not isinstance(init, (Tensor, numbers.Number, str, Initializer)):
  397. raise TypeError("Unsupported init type '{}'.".format(type(init)))
  398. if isinstance(init, Tensor):
  399. init_shape = init.shape
  400. shape = shape if isinstance(shape, (tuple, list)) else [shape]
  401. if shape is not None and init_shape != tuple(shape):
  402. raise ValueError("The shape of init should be same as variable shape, but got the shape of init {} and "
  403. "the variable shape {}.".format(list(init.shape), shape))
  404. return init
  405. if isinstance(shape, list):
  406. shape = tuple(shape)
  407. elif isinstance(shape, numbers.Number):
  408. shape = (shape,)
  409. for value in shape if shape is not None else ():
  410. if not isinstance(value, int) or value <= 0:
  411. raise ValueError(f"shape is invalid, shape value must be positive integer, shape:{shape}")
  412. if isinstance(init, str):
  413. init = _INITIALIZER_ALIAS[init.lower()]()
  414. if init is None:
  415. raise ValueError("The class corresponding to '{}' was not found.".format(init))
  416. elif isinstance(init, numbers.Number):
  417. init = Constant(init)
  418. shape = shape if shape is not None else init.shape
  419. init_obj = Tensor(dtype=dtype, shape=shape, init=init)
  420. return init_obj
  421. __all__ = [
  422. 'Initializer',
  423. 'initializer',
  424. 'TruncatedNormal',
  425. 'Normal',
  426. 'Uniform',
  427. 'HeUniform',
  428. 'HeNormal',
  429. 'XavierUniform',
  430. 'One',
  431. 'Zero',
  432. 'Constant']