You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

optimizer.py 34 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """optimizer"""
  16. import inspect
  17. from typing import Iterable
  18. import numpy as np
  19. import mindspore
  20. from mindspore.ops import functional as F, composite as C, operations as P
  21. from mindspore.ops.operations import _inner_ops as inner
  22. from mindspore.nn.cell import Cell
  23. from mindspore.nn.layer.container import CellList
  24. from mindspore.common.parameter import Parameter, ParameterTuple
  25. from mindspore.common.initializer import initializer
  26. from mindspore.common.tensor import Tensor, RowTensor
  27. import mindspore.common.dtype as mstype
  28. from mindspore._checkparam import Validator as validator
  29. from mindspore import log as logger
  30. from mindspore.parallel._utils import _get_global_rank, _get_device_num, _get_parallel_mode
  31. from mindspore.context import ParallelMode
  32. from mindspore import context
  33. from mindspore.nn.learning_rate_schedule import LearningRateSchedule
  34. __all__ = ['Optimizer', 'opt_init_args_register']
  35. def opt_init_args_register(fn):
  36. def deco(self, *args, **kwargs):
  37. bound_args = inspect.signature(fn).bind(self, *args, **kwargs)
  38. bound_args.apply_defaults()
  39. arguments = bound_args.arguments
  40. arguments.pop('self')
  41. arguments.pop('params')
  42. setattr(self, 'init_args', arguments)
  43. fn(self, *args, **kwargs)
  44. return deco
  45. class Optimizer(Cell):
  46. """
  47. Base class for all optimizers.
  48. Note:
  49. This class defines the API to add Ops to train a model. Never use
  50. this class directly, but instead instantiate one of its subclasses.
  51. Different parameter groups can set different `learning_rate`, `weight_decay` and `grad_centralization`.
  52. When separating parameter groups, the weight decay in each group will be applied on the parameters if the
  53. weight_decay is positive. For most optimizer, when not separating parameters, the `weight_decay` in the API will
  54. be applied on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.
  55. When separating parameter groups, if you want to centralize the gradient, set grad_centralization to True,
  56. but the gradient centralization can only be applied to the parameters of the convolution layer.
  57. If the parameters of the non convolution layer are set to True, an error will be reported.
  58. To improve parameter groups performance, the customized order of parameters can be supported.
  59. Args:
  60. learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning
  61. rate. When the learning_rate is an Iterable or a Tensor in a 1D dimension, use dynamic learning rate, then
  62. the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule,
  63. use dynamic learning rate, the i-th learning rate will be calculated during the process of training
  64. according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor in a zero
  65. dimension, use fixed learning rate. Other cases are not supported. The float learning rate must be
  66. equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float.
  67. parameters (Union[list[Parameter], list[dict]]): When the `parameters` is a list of `Parameter` which will be
  68. updated, the element in `parameters` must be class `Parameter`. When the `parameters` is a list of `dict`,
  69. the "params", "lr", "weight_decay" and "order_params" are the keys can be parsed.
  70. - params: Required. The value must be a list of `Parameter`.
  71. - lr: Optional. If "lr" in the keys, the value of corresponding learning rate will be used.
  72. If not, the `learning_rate` in the API will be used.
  73. - weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay
  74. will be used. If not, the `weight_decay` in the API will be used.
  75. - order_params: Optional. If "order_params" in the keys, the value must be the order of parameters and
  76. the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which
  77. in the value of 'order_params' must be in one of group parameters.
  78. - grad_centralization: Optional. The data type of "grad_centralization" is Bool. If "grad_centralization"
  79. is in the keys, the set value will be used. If not, the `grad_centralization` is False by default.
  80. This parameter only works on the convolution layer.
  81. weight_decay (Union[float, int]): An int or a floating point value for the weight decay.
  82. It must be equal to or greater than 0.
  83. If the type of `weight_decay` input is int, it will be converted to float. Default: 0.0.
  84. loss_scale (float): A floating point value for the loss scale. It must be greater than 0. If the
  85. type of `loss_scale` input is int, it will be converted to float. In general, use the default value. Only
  86. when `FixedLossScaleManager` is used for training and the `drop_overflow_update` in
  87. `FixedLossScaleManager` is set to False, then this value needs to be the same as the `loss_scale` in
  88. `FixedLossScaleManager`. Refer to class :class:`mindspore.FixedLossScaleManager` for more details.
  89. Default: 1.0.
  90. Raises:
  91. TypeError: If `learning_rate` is not one of int, float, Tensor, Iterable, LearningRateSchedule.
  92. TypeError: If element of `parameters` is neither Parameter nor dict.
  93. TypeError: If `loss_scale` is not a float.
  94. TypeError: If `weight_decay` is neither float nor int.
  95. ValueError: If `loss_scale` is less than or equal to 0.
  96. ValueError: If `weight_decay` is less than 0.
  97. ValueError: If `learning_rate` is a Tensor, but the dimension of tensor is greater than 1.
  98. Supported Platforms:
  99. ``Ascend`` ``GPU``
  100. """
  101. def __init__(self, learning_rate, parameters, weight_decay=0.0, loss_scale=1.0):
  102. super(Optimizer, self).__init__(auto_prefix=False)
  103. if parameters is not None and not isinstance(parameters, list):
  104. parameters = list(parameters)
  105. if not parameters:
  106. raise ValueError("Optimizer got an empty parameter list.")
  107. if not isinstance(parameters[0], (dict, Parameter)):
  108. raise TypeError("Only a list of Parameter or dict can be supported.")
  109. if isinstance(loss_scale, int):
  110. loss_scale = float(loss_scale)
  111. validator.check_value_type("loss_scale", loss_scale, [float], self.cls_name)
  112. validator.check_positive_float(loss_scale, "loss_scale", self.cls_name)
  113. self.loss_scale = loss_scale
  114. weight_decay = self._preprocess_weight_decay(weight_decay)
  115. self.grad_centralization = False
  116. self._unique = True
  117. self._target = context.get_context("device_target")
  118. self.dynamic_lr = False
  119. self.assignadd = None
  120. self.global_step = None
  121. self.is_group = False
  122. self.is_group_lr = False
  123. self.is_group_params_ordered = False
  124. learning_rate = self._preprocess_single_lr(learning_rate)
  125. if isinstance(parameters[0], dict):
  126. self.is_group = True
  127. self.group_params = []
  128. self.group_lr = []
  129. self.group_weight_decay = []
  130. self.group_grad_centralization = []
  131. self._init_group_params(parameters, learning_rate, weight_decay, self.grad_centralization)
  132. # The final value of dynamic_lr can be determined after the process of parse_single_lr and init_group_params
  133. if self.dynamic_lr:
  134. self.assignadd = P.AssignAdd()
  135. self.global_step = Parameter(initializer(0, [1], mindspore.int32), name='global_step')
  136. if self.is_group_lr:
  137. self.learning_rate = CellList(self.group_lr, auto_prefix=False) if self.dynamic_lr \
  138. else ParameterTuple(self.group_lr)
  139. else:
  140. self.learning_rate = self._build_single_lr(learning_rate, 'learning_rate')
  141. if self.is_group:
  142. self.parameters = ParameterTuple(self.group_params)
  143. self.weight_decay = tuple(self.group_weight_decay)
  144. self.weight_decay_tensor_tuple = tuple(Tensor(x, mstype.float32) for x in self.group_weight_decay)
  145. decay_filter = lambda x: x > 0
  146. self.decay_flags = tuple(decay_filter(x) for x in self.weight_decay)
  147. self.exec_weight_decay = any(self.decay_flags)
  148. self.grad_centralization_flags = tuple(self.group_grad_centralization)
  149. else:
  150. self.parameters = ParameterTuple(parameters)
  151. self.weight_decay = weight_decay * loss_scale
  152. self.weight_decay_tensor = Tensor(self.weight_decay, mstype.float32)
  153. decay_filter = lambda x: 'beta' not in x.name and 'gamma' not in x.name
  154. self.decay_flags = tuple(decay_filter(x) for x in self.parameters)
  155. self.exec_weight_decay = self.weight_decay > 0
  156. # when a parameter has been unique, there is no need do another unique in optimizer.
  157. for param in self.parameters:
  158. if param.unique:
  159. self._unique = False
  160. break
  161. ps_filter = lambda x: x.is_param_ps
  162. self.ps_parameters = tuple(ps_filter(x) for x in self.parameters)
  163. cache_filter = lambda x: x.cache_enable
  164. self.cache_enable = tuple(cache_filter(x) for x in self.parameters)
  165. self.reciprocal_scale = Tensor(1.0 / loss_scale, mstype.float32)
  166. self.need_scale = loss_scale != 1.0
  167. self.global_step_increase_tensor = Tensor(1, mstype.int32)
  168. self.param_length = len(self.parameters)
  169. self.map_ = C.Map()
  170. self._use_parallel_optimizer()
  171. def _use_parallel_optimizer(self):
  172. """Indicates whether to use automatic parallelism."""
  173. if context.get_auto_parallel_context("enable_parallel_optimizer"):
  174. if _get_parallel_mode() == ParallelMode.DATA_PARALLEL and context.get_context("device_target") == "Ascend":
  175. self.use_parallel = True
  176. elif _get_parallel_mode() == ParallelMode.DATA_PARALLEL \
  177. and context.get_context("device_target") != "Ascend":
  178. raise RuntimeError("Parallel optimizer only supports Ascend in data parallel mode.")
  179. elif _get_parallel_mode() in (ParallelMode.STAND_ALONE, ParallelMode.HYBRID_PARALLEL):
  180. raise RuntimeError("Parallel optimizer is not supported in {}.".format(_get_parallel_mode()))
  181. else:
  182. self.use_parallel = False
  183. else:
  184. self.use_parallel = False
  185. if self.use_parallel:
  186. if self.cls_name not in ["Lamb", "AdamWeightDecay"]:
  187. raise RuntimeError("Parallel optimizer does not support optimizer {}".format(self.cls_name))
  188. self.dev_num = _get_device_num()
  189. if self.dev_num > self.param_length:
  190. raise RuntimeError("Parallel optimizer can not be applied when the number of parameters {} is"
  191. " less than the number of devices {}".format(self.param_length, self.dev_num))
  192. self.param_rank = self._get_parameter_group_id()
  193. self.optim_filter = tuple(map(lambda x: x == _get_global_rank(), self.param_rank))
  194. self.param_names = []
  195. for param in self.parameters:
  196. self.param_names.append(param.name)
  197. else:
  198. self.optim_filter = (True,) * self.param_length
  199. @property
  200. def unique(self):
  201. """The method is to see whether to make unique. The input type is bool. The method is read-only."""
  202. return self._unique
  203. @unique.setter
  204. def unique(self, value):
  205. """Set whether the input value is unique."""
  206. if not isinstance(value, bool):
  207. raise TypeError("The value type must be bool, but got value type is {}".format(type(value)))
  208. self._unique = value
  209. @property
  210. def target(self):
  211. """The method is used to determine whether the parameter is updated on host or device. The input type is str
  212. and can only be 'CPU', 'Ascend' or 'GPU'."""
  213. return self._target
  214. @target.setter
  215. def target(self, value):
  216. """If the input value is set to "CPU", the parameters will be updated on the host using the Fused
  217. optimizer operation."""
  218. raise NotImplementedError
  219. def decay_weight(self, gradients):
  220. """
  221. Weight decay.
  222. An approach to reduce the overfitting of a deep learning neural network model.
  223. Args:
  224. gradients (tuple[Tensor]): The gradients of `self.parameters`, and have the same shape as
  225. `self.parameters`.
  226. Returns:
  227. tuple[Tensor], The gradients after weight decay.
  228. """
  229. if self.exec_weight_decay:
  230. params = self.parameters
  231. if self.is_group:
  232. gradients = self.map_(F.partial(_apply_decay), self.weight_decay_tensor_tuple, self.decay_flags,
  233. params, gradients)
  234. else:
  235. gradients = self.map_(F.partial(_apply_decay, self.weight_decay_tensor), self.decay_flags,
  236. params, gradients)
  237. return gradients
  238. def gradients_centralization(self, gradients):
  239. """
  240. Gradients centralization.
  241. A method for optimizing convolutional layer parameters to impore the training speed of a deep learning neural
  242. network model.
  243. Args:
  244. gradients (tuple[Tensor]): The gradients of `self.parameters`, and have the same shape as
  245. `self.parameters`.
  246. Returns:
  247. tuple[Tensor], The gradients after gradients centralization.
  248. """
  249. if self.is_group:
  250. gradients = self.map_(F.partial(_apply_grad_centralization), self.grad_centralization_flags, gradients)
  251. return gradients
  252. def scale_grad(self, gradients):
  253. """
  254. Loss scale for mixed precision.
  255. An approach of mixed precision training to improve the speed and energy efficiency of training deep neural
  256. network.
  257. Args:
  258. gradients (tuple[Tensor]): The gradients of `self.parameters`, and have the same shape as
  259. `self.parameters`.
  260. Returns:
  261. tuple[Tensor], The gradients after loss scale.
  262. """
  263. if self.need_scale:
  264. gradients = self.map_(F.partial(_grad_scale, self.reciprocal_scale), gradients)
  265. return gradients
  266. def _grad_sparse_indices_deduplicate(self, gradients):
  267. """ In the case of using big operators, deduplicate the 'indexes' in gradients."""
  268. if self._target != 'CPU' and self._unique:
  269. gradients = self.map_(F.partial(_indices_deduplicate), gradients)
  270. return gradients
  271. def _preprocess_weight_decay(self, weight_decay):
  272. """Check weight decay, and convert int to float."""
  273. if isinstance(weight_decay, (float, int)):
  274. weight_decay = float(weight_decay)
  275. validator.check_non_negative_float(weight_decay, "weight_decay", self.cls_name)
  276. return weight_decay
  277. raise TypeError("Weight decay should be int or float.")
  278. def _preprocess_grad_centralization(self, grad_centralization):
  279. if not isinstance(grad_centralization, bool):
  280. raise TypeError("The gradients centralization should be bool")
  281. return grad_centralization
  282. def _preprocess_single_lr(self, learning_rate):
  283. """Check lr value, and convert lr to a float, a Tensor or a LearningRateSchedule."""
  284. if isinstance(learning_rate, (float, int)):
  285. learning_rate = float(learning_rate)
  286. validator.check_non_negative_float(learning_rate, "learning rate", self.cls_name)
  287. return learning_rate
  288. if isinstance(learning_rate, Tensor) and learning_rate.ndim == 0:
  289. return learning_rate
  290. self.dynamic_lr = True
  291. if isinstance(learning_rate, Iterable):
  292. return Tensor(np.array(list(learning_rate)).astype(np.float32))
  293. if isinstance(learning_rate, Tensor):
  294. if learning_rate.ndim > 1:
  295. raise ValueError("The dim of `Tensor` type Learning rate should be a 0 or 1,"
  296. f"but got {learning_rate.ndim}.")
  297. if learning_rate.ndim == 1 and learning_rate.size < 2:
  298. logger.warning("If use `Tensor` type dynamic learning rate, please make sure that the number"
  299. "of elements in the tensor passed is greater than 1.")
  300. return learning_rate
  301. if isinstance(learning_rate, LearningRateSchedule):
  302. return learning_rate
  303. raise TypeError("Learning rate should be int, float, Tensor, Iterable or LearningRateSchedule.")
  304. def _build_single_lr(self, learning_rate, name):
  305. """Build learning rate value, convert learning rate to a Parameter or a LearningRateSchedule."""
  306. if isinstance(learning_rate, float):
  307. learning_rate = Parameter(Tensor(learning_rate, mstype.float32), name)
  308. if self.is_group_lr and self.dynamic_lr:
  309. learning_rate = _ConvertToCell(learning_rate)
  310. return learning_rate
  311. if isinstance(learning_rate, Tensor) and learning_rate.ndim == 0:
  312. learning_rate = Parameter(learning_rate, name)
  313. if self.is_group_lr and self.dynamic_lr:
  314. learning_rate = _ConvertToCell(learning_rate)
  315. return learning_rate
  316. if isinstance(learning_rate, Tensor) and learning_rate.ndim == 1:
  317. return _IteratorLearningRate(learning_rate, name)
  318. return learning_rate
  319. def _check_group_params(self, parameters):
  320. """Check group params."""
  321. parse_keys = ['params', 'lr', 'weight_decay', 'order_params', 'grad_centralization']
  322. for group_param in parameters:
  323. invalid_key = list(filter(lambda x: x not in parse_keys, group_param.keys()))
  324. if invalid_key:
  325. raise KeyError(f'The key "{invalid_key}" cannot be recognized in group params.')
  326. if 'order_params' in group_param.keys():
  327. if len(group_param.keys()) > 1:
  328. raise ValueError("The order params dict in group parameters should "
  329. "only include the 'order_params' key.")
  330. if not isinstance(group_param['order_params'], Iterable):
  331. raise TypeError("The value of 'order_params' should be an Iterable type.")
  332. continue
  333. if not group_param['params']:
  334. raise ValueError("Optimizer got an empty group parameter list.")
  335. for param in group_param['params']:
  336. if not isinstance(param, Parameter):
  337. raise TypeError("The group param should be an iterator of Parameter type.")
  338. def _parse_group_params(self, parameters, learning_rate):
  339. """Parse group params."""
  340. self._check_group_params(parameters)
  341. if isinstance(learning_rate, Tensor) and learning_rate.ndim == 1:
  342. tensor_lr_length = learning_rate.size
  343. else:
  344. tensor_lr_length = 0
  345. for group_param in parameters:
  346. if 'order_params' in group_param.keys():
  347. if len(group_param.keys()) > 1:
  348. raise ValueError("The order params dict in group parameters should "
  349. "only include the 'order_params' key.")
  350. if not isinstance(group_param['order_params'], Iterable):
  351. raise TypeError("The value of 'order_params' should be an Iterable type.")
  352. self.is_group_params_ordered = True
  353. continue
  354. if 'lr' in group_param.keys():
  355. self.is_group_lr = True
  356. group_lr = self._preprocess_single_lr(group_param['lr'])
  357. if isinstance(group_lr, Tensor) and group_lr.ndim == 1:
  358. group_lr_length = group_lr.size
  359. if tensor_lr_length == 0:
  360. tensor_lr_length = group_lr_length
  361. elif group_lr_length != tensor_lr_length:
  362. raise ValueError("The Tensor type dynamic learning rate in group should be the same size.")
  363. def _init_group_params(self, parameters, learning_rate, weight_decay, grad_centralization):
  364. """Initialize learning rate, weight decay or grad centralization in group params."""
  365. self._parse_group_params(parameters, learning_rate)
  366. default_lr = self._build_single_lr(learning_rate, 'learning_rate')
  367. params_store = []
  368. for group_num, group_param in enumerate(parameters):
  369. if 'order_params' in group_param.keys():
  370. ordered_parameters = group_param['order_params']
  371. continue
  372. self.group_params += group_param['params']
  373. if 'lr' in group_param.keys():
  374. lr_param_name = 'learning_rate_group_' + str(group_num)
  375. lr = self._preprocess_single_lr(group_param['lr'])
  376. lr = self._build_single_lr(lr, lr_param_name)
  377. else:
  378. lr = default_lr
  379. if 'weight_decay' in group_param.keys():
  380. cur_weight_decay = self._preprocess_weight_decay(group_param['weight_decay'])
  381. weight_decay_ = cur_weight_decay * self.loss_scale
  382. else:
  383. weight_decay_ = weight_decay * self.loss_scale
  384. if 'grad_centralization' in group_param.keys():
  385. self.grad_centralization = self._preprocess_grad_centralization(group_param['grad_centralization'])
  386. for param in group_param['params']:
  387. validator.check_value_type("parameter", param, [Parameter], self.cls_name)
  388. grad_centralization_ = self.grad_centralization
  389. else:
  390. grad_centralization_ = grad_centralization
  391. for key in group_param.keys():
  392. if key not in ('params', 'lr', 'weight_decay', 'grad_centralization'):
  393. logger.warning(f"The optimizer cannot parse '{key}' when setting parameter groups.")
  394. for param in group_param['params']:
  395. validator.check_value_type("parameter", param, [Parameter], self.cls_name)
  396. if param.name in params_store:
  397. raise RuntimeError(f"The {param.name} parameter has appeared in parameter groups.")
  398. params_store.append(param.name)
  399. self.group_lr.append(lr)
  400. self.group_weight_decay.append(weight_decay_)
  401. self.group_grad_centralization.append(grad_centralization_)
  402. if self.is_group_params_ordered:
  403. self._order_and_adjust_group_params(ordered_parameters)
  404. def _order_and_adjust_group_params(self, ordered_parameters):
  405. """
  406. Order group parameter, learning rate, weight decay and grad centralization in group params.
  407. """
  408. params_length = len(self.group_params)
  409. if len(ordered_parameters) != len(self.group_params):
  410. raise ValueError(f"The value of 'order_params' should be same with all group parameters.")
  411. ordered_params = [None] * params_length
  412. ordered_learning_rate = [None] * params_length
  413. ordered_weight_decay = [None] * params_length
  414. ordered_grad_centralization = [None] * params_length
  415. params_name = [param.name for param in ordered_parameters]
  416. for param, lr, wd, gc in zip(self.group_params, self.group_lr, self.group_weight_decay,
  417. self.group_grad_centralization):
  418. index = params_name.index(param.name)
  419. ordered_params[index] = param
  420. ordered_learning_rate[index] = lr
  421. ordered_weight_decay[index] = wd
  422. ordered_grad_centralization[index] = gc
  423. self.group_params = ordered_params
  424. self.group_lr = ordered_learning_rate
  425. self.group_weight_decay = ordered_weight_decay
  426. self.group_grad_centralization = ordered_grad_centralization
  427. def get_lr(self):
  428. """
  429. Get the learning rate of current step.
  430. Returns:
  431. float, the learning rate of current step.
  432. """
  433. lr = self.learning_rate
  434. if self.dynamic_lr:
  435. if self.is_group_lr:
  436. lr = ()
  437. for learning_rate in self.learning_rate:
  438. current_dynamic_lr = learning_rate(self.global_step)
  439. lr += (current_dynamic_lr,)
  440. else:
  441. lr = self.learning_rate(self.global_step)
  442. self.assignadd(self.global_step, self.global_step_increase_tensor)
  443. return lr
  444. def get_lr_parameter(self, param):
  445. """
  446. Get the learning rate of parameter.
  447. Args:
  448. param (Union[Parameter, list[Parameter]]): The `Parameter` or list of `Parameter`.
  449. Returns:
  450. Parameter, single `Parameter` or `list[Parameter]` according to the input type.
  451. """
  452. def get_lr_value(learning_rate):
  453. if isinstance(learning_rate, (_ConvertToCell, _IteratorLearningRate)):
  454. return learning_rate.learning_rate
  455. return learning_rate
  456. if isinstance(param, Parameter):
  457. param_list = [param]
  458. elif isinstance(param, list):
  459. param_list = param
  460. else:
  461. raise TypeError(f"The parameter only support 'Parameter' or 'list' type.")
  462. lr = []
  463. ids = [id(p) for p in self.parameters]
  464. for p in param_list:
  465. validator.check_value_type("parameter", p, [Parameter], self.cls_name)
  466. if id(p) not in ids:
  467. raise ValueError(f"The parameter {p.name} is not in optimizer.")
  468. if self.is_group_lr:
  469. index = ids.index(id(p))
  470. lr.append(get_lr_value(self.learning_rate[index]))
  471. else:
  472. lr.append(get_lr_value(self.learning_rate))
  473. return lr if isinstance(param, list) else lr[0]
  474. def _get_parameter_group_id(self):
  475. """
  476. Get the parameter partition group id, which is less than the number of devices.
  477. Returns:
  478. tuple, the group id tuple of parameters.
  479. """
  480. rank_list = ()
  481. count = 0
  482. for _ in range(self.param_length):
  483. rank_list = rank_list + (count,)
  484. count = count + 1
  485. if count == self.dev_num:
  486. count = 0
  487. return rank_list
  488. def broadcast_params(self, optim_result):
  489. """
  490. Apply Broadcast operations in the sequential order of parameter groups.
  491. Returns:
  492. bool, the status flag.
  493. """
  494. param_group = []
  495. key_group = []
  496. for _ in range(self.dev_num):
  497. param_group.append(F.make_tuple())
  498. key_group.append(F.make_tuple())
  499. for i in range(self.param_length):
  500. param_group[self.param_rank[i]] = param_group[self.param_rank[i]] + (self.parameters[i],)
  501. key = P.MakeRefKey(self.param_names[i])()
  502. key_group[self.param_rank[i]] = key_group[self.param_rank[i]] + (key,)
  503. new_param_group = []
  504. for root in range(self.dev_num):
  505. ops = P.Broadcast(root)
  506. if root > 0:
  507. param_group[root] = F.depend(param_group[root], new_param_group[root-1])
  508. else:
  509. param_group[root] = F.depend(param_group[root], optim_result)
  510. next_params = ops(param_group[root])
  511. new_param_group.append(next_params)
  512. for i in range(F.tuple_len(next_params)):
  513. F.assign(key_group[root][i], next_params[i])
  514. return new_param_group
  515. def construct(self, *hyper_params):
  516. raise NotImplementedError
# Primitive instances shared by the registered graph functions below.
op_add = P.AddN()
op_gather = P.Gather()
op_mul = P.Mul()
op_gc = inner.Centralization()
# Multitype function graphs: implementations are registered below with `.register(...)` for specific
# argument type signatures, and they are mapped over the gradients in `Optimizer.decay_weight` and
# `Optimizer.gradients_centralization`.
_apply_decay = C.MultitypeFuncGraph("apply_decay")
_apply_grad_centralization = C.MultitypeFuncGraph("apply_grad_centralization")
@_apply_decay.register("Tensor", "Bool", "Tensor", "RowTensor")
def _tensor_apply_decay_with_sparse(weight_decay, if_apply, weight, gradient):
    """
    Get grad with weight_decay for a sparse (RowTensor) gradient.

    When `if_apply` is true, the rows of `weight` selected by `gradient.indices` (gathered along axis 0) are
    multiplied by `weight_decay` (cast to the dtype of `weight`) and added to `gradient.values`; the indices and
    dense shape are kept. Otherwise `gradient` is returned unchanged.
    """
    if if_apply:
        indices = gradient.indices
        values = op_add((op_gather(weight, indices, 0) * F.cast(weight_decay, F.dtype(weight)), gradient.values))
        shape = gradient.dense_shape
        return RowTensor(indices, values, shape)
    return gradient
@_apply_decay.register("Tensor", "Bool", "Tensor", "Tensor")
def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
    """
    Get grad with weight_decay for a dense gradient.

    Returns ``weight * weight_decay + gradient`` (with `weight_decay` cast to the dtype of `weight`) when
    `if_apply` is true, otherwise `gradient` unchanged.
    """
    if if_apply:
        return op_add((op_mul(weight, F.cast(weight_decay, F.dtype(weight))), gradient))
    return gradient
@_apply_grad_centralization.register("Bool", "RowTensor")
def _tensor_apply_grad_centralization_with_sparse(if_apply, gradient):
    """
    Get grad with grad_centralization for a sparse (RowTensor) gradient.

    When `if_apply` is true and the gradient shape has more than one dimension, `gradient.values` is passed to
    `op_gc` (Centralization) with axes ``1 .. rank-1`` and a new RowTensor with the same indices and dense shape
    is returned. If the second dimension is not a multiple of 16, if the shape has a single dimension, or if
    `if_apply` is false, `gradient` is returned unchanged.
    """
    if if_apply:
        indices = gradient.indices
        shape = gradient.dense_shape
        # NOTE(review): the shape is taken from the RowTensor itself rather than from `gradient.values`;
        # confirm which shape F.shape reports for a RowTensor in graph mode.
        grad_shape = F.shape(gradient)
        # Built with a loop + append; presumably kept this way for graph-mode compatibility.
        axis = []
        for i in range(1, len(grad_shape)):
            axis.append(i)
        if len(axis) >= 1:
            # NOTE(review): the multiple-of-16 requirement presumably comes from the Centralization kernel — confirm.
            if grad_shape[1] % 16 != 0:
                return gradient
            values = op_gc(gradient.values, axis)
            return RowTensor(indices, values, shape)
    return gradient
@_apply_grad_centralization.register("Bool", "Tensor")
def _tensor_apply_grad_centralization(if_apply, gradient):
    """
    Get grad with grad_centralization for a dense gradient.

    When `if_apply` is true and `gradient` has more than one dimension, returns ``op_gc(gradient, axis)`` with
    axes ``1 .. rank-1``. If the second dimension is not a multiple of 16, if `gradient` has a single dimension,
    or if `if_apply` is false, `gradient` is returned unchanged.
    """
    if if_apply:
        # Built with a loop + append; presumably kept this way for graph-mode compatibility.
        axis = []
        grad_shape = F.shape(gradient)
        for i in range(1, len(grad_shape)):
            axis.append(i)
        if len(axis) >= 1:
            # NOTE(review): the multiple-of-16 requirement presumably comes from the Centralization kernel — confirm.
            if grad_shape[1] % 16 != 0:
                return gradient
            return op_gc(gradient, axis)
    return gradient
# Multitype function graphs for loss-scale correction (`Optimizer.scale_grad`) and sparse-index
# deduplication (`Optimizer._grad_sparse_indices_deduplicate`); implementations are registered below.
_grad_scale = C.MultitypeFuncGraph("grad_scale")
_indices_deduplicate = C.MultitypeFuncGraph("indices_deduplicate")
@_grad_scale.register("Number", "Tensor")
def tensor_grad_scale(scale, grad):
    """
    Get grad with scale, for a scalar `scale` and a dense gradient.

    Returns `grad` unchanged when `scale` is 1.0, otherwise ``grad * scale`` with `scale` cast to the dtype
    of `grad`.
    """
    if scale == 1.0:
        return grad
    return op_mul(grad, F.cast(scale, F.dtype(grad)))
@_grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale_with_tensor(scale, grad):
    """Get grad with scale: ``grad * scale`` for a Tensor `scale`, cast to the dtype of the dense `grad`."""
    return op_mul(grad, F.cast(scale, F.dtype(grad)))
@_grad_scale.register("Tensor", "RowTensor")
def tensor_grad_scale_with_sparse(scale, grad):
    """Get grad with scale for a sparse gradient: only `grad.values` is multiplied; indices and shape are kept."""
    return RowTensor(grad.indices, grad.values * F.cast(scale, F.dtype(grad.values)), grad.dense_shape)
@_indices_deduplicate.register("RowTensor")
def rowtensor_deduplicate_indices_slices(grad):
    """
    Unique the indices and sum the 'values' corresponding to the duplicate indices.

    Returns a RowTensor with the unique indices, the per-index summed values and the original dense shape.
    """
    indices = grad.indices
    values = grad.values

    # `index_position` maps every original index to its slot in `unique_indices`.
    unique_indices, index_position = P.Unique()(indices)
    # One output segment per unique index; the count is only known at run time, hence DynamicShape.
    summed_values = P.UnsortedSegmentSum()(values, index_position, P.DynamicShape()(unique_indices)[0])

    return RowTensor(unique_indices, summed_values, grad.dense_shape)
@_indices_deduplicate.register("Tensor")
def tensor_deduplicate_indice_slices(grad):
    """Return the input gradient directly in the dense case (there are no indices to deduplicate)."""
    return grad
class _ConvertToCell(LearningRateSchedule):
    """
    Inner api, convert learning rate of scalar to LearningRateSchedule.

    Used when a fixed learning rate must be stored alongside dynamic ones, so every entry can be called with
    `global_step`.

    Args:
        learning_rate (Parameter): The fixed learning rate.

    Raises:
        TypeError: If `learning_rate` is not a Parameter.
    """
    def __init__(self, learning_rate):
        super(_ConvertToCell, self).__init__()
        if not isinstance(learning_rate, Parameter):
            raise TypeError('Learning rate must be Parameter.')
        self.learning_rate = learning_rate

    def construct(self, global_step):
        # `global_step` is ignored: the learning rate is constant.
        # NOTE(review): `+ 1.0 - 1.0` presumably forces a computed Tensor instead of the Parameter itself — confirm.
        return self.learning_rate + 1.0 - 1.0
  604. class _IteratorLearningRate(LearningRateSchedule):
  605. """Inner api, convert learning rate of Tensor(list) to LearningRateSchedule."""
  606. def __init__(self, learning_rate, name):
  607. super(_IteratorLearningRate, self).__init__()
  608. if isinstance(learning_rate, Tensor):
  609. if learning_rate.ndim != 1:
  610. raise ValueError("The dim of `Tensor` type dynamic learning rate should be a 1,"
  611. f"but got {learning_rate.ndim}.")
  612. else:
  613. raise TypeError("Learning rate should be Tensor.")
  614. self.learning_rate = Parameter(learning_rate, name)
  615. self.gather = P.Gather()
  616. def construct(self, global_step):
  617. return self.gather(self.learning_rate, global_step, 0)