You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

initializer.py 15 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Initializer for cell parameters."""
  16. import numbers
  17. import math
  18. import copy
  19. from functools import reduce
  20. import numpy as np
  21. from scipy.stats import truncnorm
  22. from mindspore import log as logger
  23. from . import dtype as mstype
  24. from .tensor import Tensor
  25. from .seed import get_global_seed
  26. from .._c_expression import random_normal
  27. _INITIALIZER_ALIAS = dict()
  28. class Initializer:
  29. """
  30. The base class of the initializer.
  31. Initialization of tensor basic attributes and model weight values.
  32. Args:
  33. kwargs (dict): Keyword arguments for Initializer.
  34. Returns:
  35. Array, an array after being initialized.
  36. """
  37. def __init__(self, **kwargs):
  38. self._kwargs = kwargs
  39. self.shape = None
  40. self.dtype = None
  41. def _initialize(self, *kwargs):
  42. raise NotImplementedError('Must be overridden!')
  43. def __call__(self, arr):
  44. return self._initialize(arr)
  45. @property
  46. def shape(self):
  47. return self._shape
  48. @shape.setter
  49. def shape(self, shape):
  50. self._shape = shape
  51. @property
  52. def dtype(self):
  53. return self._dtype
  54. @dtype.setter
  55. def dtype(self, dtype):
  56. self._dtype = dtype
  57. def to_tensor(self, slice_index=None, shape=None):
  58. """
  59. Get the tensor format data of this Initializer.
  60. Args:
  61. slice_index (int): Slice index of a parameter's slices.
  62. It is used when initialize a slice of a parameter, it guarantees that devices
  63. using the same slice can generate the same tensor.
  64. shape (list[int]): Shape of the slice, it is used when initialize a slice of the parameter.
  65. """
  66. arr = None
  67. if shape is None:
  68. shape = self.shape
  69. try:
  70. arr = np.ndarray(shape, dtype=mstype.dtype_to_nptype(self.dtype))
  71. except ValueError:
  72. msg = "Error shape={}".format(shape)
  73. logger.error(msg)
  74. raise ValueError(msg)
  75. global_seed = get_global_seed()
  76. need_set_seed = ((slice_index is not None) and (global_seed is None))
  77. seed_saved = np.random.get_state()[1][0]
  78. if need_set_seed:
  79. np.random.seed(slice_index)
  80. self.__call__(arr)
  81. if need_set_seed:
  82. np.random.seed(seed_saved)
  83. return Tensor(arr, dtype=self.dtype)
  84. def _register(*aliases):
  85. """Return the alias register."""
  86. def alias_reg(cls):
  87. name = cls.__name__
  88. name = name.lower()
  89. if name not in _INITIALIZER_ALIAS:
  90. _INITIALIZER_ALIAS[name] = cls
  91. for alias in aliases:
  92. if alias not in _INITIALIZER_ALIAS:
  93. _INITIALIZER_ALIAS[alias] = cls
  94. return cls
  95. return alias_reg
  96. def _assignment(arr, num):
  97. """Assign the value of `num` to `arr`."""
  98. if arr.shape == ():
  99. arr = arr.reshape((1))
  100. arr[:] = num
  101. arr = arr.reshape(())
  102. else:
  103. if isinstance(num, np.ndarray):
  104. arr[:] = num[:]
  105. else:
  106. arr[:] = num
  107. return arr
  108. @_register('zeros')
  109. class Zero(Initializer):
  110. """
  111. Initialize the array to zero.
  112. Args:
  113. arr (Array): The array to be assigned.
  114. Returns:
  115. Array, an array after being assigned.
  116. """
  117. def _initialize(self, arr):
  118. _assignment(arr, 0)
  119. @_register('ones')
  120. class One(Initializer):
  121. """
  122. Initialize the array to one.
  123. Args:
  124. arr (Array): The array to be assigned.
  125. Returns:
  126. Array, assigned array.
  127. """
  128. def _initialize(self, arr):
  129. _assignment(arr, 1)
  130. def _calculate_fan_in_and_fan_out(shape):
  131. """
  132. calculate fan_in and fan_out
  133. Args:
  134. shape (tuple): input shape.
  135. Returns:
  136. Tuple, a tuple with two elements, the first element is `n_in` and the second element is `n_out`.
  137. """
  138. dimensions = len(shape)
  139. if dimensions < 2:
  140. raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")
  141. if dimensions == 2: # Linear
  142. fan_in = shape[1]
  143. fan_out = shape[0]
  144. else:
  145. num_input_fmaps = shape[1]
  146. num_output_fmaps = shape[0]
  147. receptive_field_size = 1
  148. if dimensions > 2:
  149. receptive_field_size = shape[2] * shape[3]
  150. fan_in = num_input_fmaps * receptive_field_size
  151. fan_out = num_output_fmaps * receptive_field_size
  152. return fan_in, fan_out
  153. def _calculate_correct_fan(shape, mode):
  154. """
  155. Calculate fan.
  156. Args:
  157. shape (tuple): input shape.
  158. mode (str): only support fan_in and fan_out.
  159. Returns:
  160. fan_in or fan_out.
  161. """
  162. mode = mode.lower()
  163. valid_modes = ['fan_in', 'fan_out']
  164. if mode not in valid_modes:
  165. raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
  166. fan_in, fan_out = _calculate_fan_in_and_fan_out(shape)
  167. return fan_in if mode == 'fan_in' else fan_out
  168. def _calculate_gain(nonlinearity, param=None):
  169. """
  170. Calculate gain.
  171. Args:
  172. nonlinearity (str): nonlinearity function.
  173. param (str): used to calculate negative_slope.
  174. Returns:
  175. number.
  176. """
  177. linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
  178. if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
  179. res = 1
  180. elif nonlinearity == 'tanh':
  181. res = 5.0 / 3
  182. elif nonlinearity == 'relu':
  183. res = math.sqrt(2.0)
  184. elif nonlinearity == 'leaky_relu':
  185. if param is None:
  186. negative_slope = 0.01
  187. elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
  188. # True/False are instances of int, hence check above
  189. negative_slope = param
  190. else:
  191. raise ValueError("negative_slope {} not a valid number".format(param))
  192. res = math.sqrt(2.0 / (1 + negative_slope ** 2))
  193. else:
  194. raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
  195. return res
  196. def _calculate_in_and_out(arr):
  197. """
  198. Calculate n_in and n_out.
  199. Args:
  200. arr (Array): Input array.
  201. Returns:
  202. Tuple, a tuple with two elements, the first element is `n_in` and the second element is `n_out`.
  203. """
  204. dim = len(arr.shape)
  205. if dim < 2:
  206. raise ValueError("If initialize data with xavier uniform, the dimension of data must be greater than 1.")
  207. n_in = arr.shape[1]
  208. n_out = arr.shape[0]
  209. if dim > 2:
  210. counter = reduce(lambda x, y: x * y, arr.shape[2:])
  211. n_in *= counter
  212. n_out *= counter
  213. return n_in, n_out
  214. @_register('xavier_uniform')
  215. class XavierUniform(Initializer):
  216. r"""
  217. Initialize the array with xavier uniform algorithm, and from a uniform distribution collect samples within
  218. U[-boundary, boundary] The boundary is defined as :
  219. where :math:`boundary = gain * \sqrt{\frac{6}{n_{in} + n_{out}}}`.
  220. where :math:`n_{in}` is the number of input units in the weight tensor.
  221. where :math:`n_{out}` is the number of output units in the weight tensor.
  222. Args:
  223. gain (Array): The array to be assigned. Default: 1.
  224. Returns:
  225. Array, assigned array.
  226. """
  227. def __init__(self, gain=1):
  228. super(XavierUniform, self).__init__(gain=gain)
  229. self.gain = gain
  230. def _initialize(self, arr):
  231. n_in, n_out = _calculate_in_and_out(arr)
  232. boundary = self.gain * math.sqrt(6.0 / (n_in + n_out))
  233. data = np.random.uniform(-boundary, boundary, arr.shape)
  234. _assignment(arr, data)
  235. @_register('he_uniform')
  236. class HeUniform(Initializer):
  237. r"""
  238. Initialize the array with He kaiming uniform algorithm, and from a uniform distribution collect samples within
  239. U[-boundary, boundary] The boundary is defined as :
  240. where :math:`boundary = \sqrt{\frac{6}{n_{in}}}`.
  241. where :math:`n_{in}` is the number of input units in the weight tensor.
  242. Args:
  243. arr (Array): The array to be assigned.
  244. Returns:
  245. Array, assigned array.
  246. """
  247. def _initialize(self, arr):
  248. n_in, _ = _calculate_in_and_out(arr)
  249. boundary = math.sqrt(6.0 / n_in)
  250. data = np.random.uniform(-boundary, boundary, arr.shape)
  251. _assignment(arr, data)
  252. @_register('he_normal')
  253. class HeNormal(Initializer):
  254. r"""
  255. Initialize the array with He kaiming Normal algorithm, and from a normal distribution collect samples within
  256. N(0, sigma).
  257. Args:
  258. negative_slope (int, float, bool): Default: 0, used when nonlinearity is 'leaky_relu'.
  259. mode (str): Default: fan_in.
  260. nonlinearity (str): Default: leaky_relu.
  261. Returns:
  262. Array, assigned array.
  263. """
  264. def __init__(self, negative_slope=0, mode='fan_in', nonlinearity='leaky_relu'):
  265. super(HeNormal, self).__init__(negative_slope=negative_slope, mode=mode, nonlinearity=nonlinearity)
  266. self.negative_slope = negative_slope
  267. self.mode = mode
  268. self.nonlinearity = nonlinearity
  269. def _initialize(self, arr):
  270. fan = _calculate_correct_fan(arr.shape, self.mode)
  271. gain = _calculate_gain(self.nonlinearity, self.negative_slope)
  272. std = gain / math.sqrt(fan)
  273. data = np.random.normal(0, std, arr.shape)
  274. _assignment(arr, data)
  275. class Constant(Initializer):
  276. """
  277. Initialize a constant.
  278. Args:
  279. value (Union[int, numpy.ndarray]): The value to initialize.
  280. Returns:
  281. Array, an array after being assigned.
  282. """
  283. def __init__(self, value):
  284. super(Constant, self).__init__(value=value)
  285. self.value = value
  286. def _initialize(self, arr):
  287. _assignment(arr, self.value)
  288. @_register()
  289. class Uniform(Initializer):
  290. """
  291. Initialize a uniform array, and obtain values U(-scale, scale) from the uniform distribution
  292. to fill the input tensor.
  293. Args:
  294. scale (float): The scale of the array. Default: 0.07.
  295. Returns:
  296. Array, uniform array.
  297. """
  298. def __init__(self, scale=0.07):
  299. super(Uniform, self).__init__(scale=scale)
  300. self.scale = scale
  301. def _initialize(self, arr):
  302. tmp = np.random.uniform(-self.scale, self.scale, arr.shape)
  303. _assignment(arr, tmp)
  304. @_register()
  305. class Normal(Initializer):
  306. """
  307. Initialize a normal array, and obtain values N(0, sigma) from the uniform distribution
  308. to fill the input tensor.
  309. Args:
  310. sigma (float): The sigma of the array. Default: 0.01.
  311. Returns:
  312. Array, normal array.
  313. """
  314. def __init__(self, sigma=0.01):
  315. super(Normal, self).__init__(sigma=sigma)
  316. self.sigma = sigma
  317. def _initialize(self, arr):
  318. seed = np.random.get_state()[1][0]
  319. output_tensor = Tensor(np.zeros(arr.shape, dtype=np.float32))
  320. random_normal(0, self.sigma, arr.shape, seed, output_tensor)
  321. output_data = output_tensor.asnumpy()
  322. output_data *= self.sigma
  323. _assignment(arr, output_data)
  324. @_register()
  325. class TruncatedNormal(Initializer):
  326. """
  327. Initialize a truncated normal distribution which is a bounded normal distribution within N(low, high).
  328. Args:
  329. sigma (float): The sigma of the array. Default: 0.01.
  330. Returns:
  331. Array, truncated normal array.
  332. """
  333. def __init__(self, sigma=0.01):
  334. super(TruncatedNormal, self).__init__(sigma=sigma)
  335. self.sigma = sigma
  336. def _initialize(self, arr):
  337. tmp = truncnorm.rvs(-2, 2, loc=0, scale=self.sigma, size=arr.shape, random_state=None)
  338. _assignment(arr, tmp)
  339. def initializer(init, shape=None, dtype=mstype.float32):
  340. """
  341. Create and initialize a tensor.
  342. Args:
  343. init (Union[Tensor, str, Initializer, numbers.Number]): Initialize value.
  344. - `str`: The `init` should be the alias of the class inheriting from `Initializer` and the corresponding
  345. class will be called.
  346. - `Initializer`: The `init` should be the class inheriting from `Initializer` to initialize tensor.
  347. - `numbers.Number`: The `Constant` will be called to initialize tensor.
  348. shape (Union[tuple, list, int]): A list of integers, a tuple of integers or an integer as the shape of
  349. output. Default: None.
  350. dtype (:class:`mindspore.dtype`): The type of data in initialized tensor. Default: mindspore.float32.
  351. Returns:
  352. Union[Tensor, Initializer], When `init` is Tensor, the return is Tensor object,
  353. otherwise the return is Initialize object.
  354. Examples:
  355. >>> tensor = initializer('ones', [1, 2, 3], mindspore.float32)
  356. >>> tensor = initializer(One(), [1, 2, 3], mindspore.float32)
  357. >>> tensor = initializer(0, [1, 2, 3], mindspore.float32)
  358. """
  359. if not isinstance(init, (Tensor, numbers.Number, str, Initializer)):
  360. raise TypeError("Unsupported init type '{}'.".format(type(init)))
  361. if isinstance(init, Tensor):
  362. init_shape = init.shape
  363. shape = shape if isinstance(shape, (tuple, list)) else [shape]
  364. if shape is not None and init_shape != tuple(shape):
  365. raise ValueError("The shape of init should be same as variable shape, but got the shape of init {} and "
  366. "the variable shape {}.".format(list(init.shape), shape))
  367. return init
  368. if isinstance(shape, list):
  369. shape = tuple(shape)
  370. elif isinstance(shape, numbers.Number):
  371. shape = (shape,)
  372. for value in shape if shape is not None else ():
  373. if not isinstance(value, int) or value <= 0:
  374. raise ValueError(f"shape is invalid, shape value must be positive integer, shape:{shape}")
  375. if isinstance(init, Initializer):
  376. init_copy = copy.deepcopy(init)
  377. init_copy.shape = shape if shape is not None else init.shape
  378. init_copy.dtype = init.dtype if init.dtype is not None else dtype
  379. return init_copy
  380. if isinstance(init, str):
  381. init_obj = _INITIALIZER_ALIAS[init.lower()]()
  382. if init_obj is None:
  383. raise ValueError("The class corresponding to '{}' was not found.".format(init))
  384. init = init_obj
  385. init.shape = shape
  386. init.dtype = dtype
  387. return init
  388. if isinstance(init, numbers.Number):
  389. init_obj = Constant(init)
  390. init_obj.shape = shape
  391. init_obj.dtype = dtype
  392. return init_obj
  393. raise TypeError("Unsupported init type '{}'.".format(type(init)))
# Names exported by `from <module> import *`; private helpers (leading underscore) are left out.
__all__ = [
    'Initializer',
    'initializer',
    'TruncatedNormal',
    'Normal',
    'Uniform',
    'HeUniform',
    'HeNormal',
    'XavierUniform',
    'One',
    'Zero',
    'Constant']