
optimizer.py

# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""optimizer"""
import inspect
from typing import Iterable

import numpy as np

import mindspore
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.ops.operations import _inner_ops as inner
from mindspore.nn.cell import Cell
from mindspore.nn.layer.container import CellList
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common.initializer import initializer
from mindspore.common.tensor import Tensor, RowTensor
import mindspore.common.dtype as mstype
from mindspore._checkparam import Validator as validator
from mindspore import log as logger
from mindspore.parallel._utils import _get_global_rank, _get_device_num, _get_parallel_mode
from mindspore.context import ParallelMode
from mindspore import context
from mindspore.nn.learning_rate_schedule import LearningRateSchedule

__all__ = ['Optimizer', 'opt_init_args_register']


def opt_init_args_register(fn):
    """Register optimizer init args."""

    def deco(self, *args, **kwargs):
        bound_args = inspect.signature(fn).bind(self, *args, **kwargs)
        bound_args.apply_defaults()
        arguments = bound_args.arguments
        arguments.pop('self')
        if 'params' in arguments.keys():
            setattr(self, 'init_params', dict({"params": arguments['params']}))
            arguments.pop('params')
        if 'optimizer' in arguments.keys():
            setattr(self, 'init_params', dict({"params": arguments['optimizer'].init_params["params"]}))
            arguments.pop('optimizer')
        setattr(self, 'init_args', arguments)
        fn(self, *args, **kwargs)

    return deco
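
# Illustrative sketch (not part of the original file): a subclass would typically apply the
# decorator above to its __init__ so that the constructor arguments are recorded, e.g.
#
#     class MyMomentum(Optimizer):
#         @opt_init_args_register
#         def __init__(self, params, learning_rate=0.1, momentum=0.9):
#             super(MyMomentum, self).__init__(learning_rate, params)
#             ...
#
# After construction, `optim.init_params` holds {"params": ...} and `optim.init_args` holds the
# remaining constructor arguments; the names MyMomentum and momentum here are hypothetical.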


class Optimizer(Cell):
    """
    Base class for updating parameters. Never use this class directly, but instantiate one of its subclasses instead.

    Grouping parameters is supported. If parameters are grouped, different strategies of `learning_rate`,
    `weight_decay` and `grad_centralization` can be applied to each group.

    Note:
        If parameters are not grouped, the `weight_decay` in the optimizer will be applied to network parameters
        that do not have 'beta' or 'gamma' in their names. Users can group parameters to change the weight decay
        strategy. When parameters are grouped, each group can set `weight_decay`; if it is not set, the
        `weight_decay` in the optimizer will be applied.

    Args:
        learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule]):

            - float: The fixed learning rate value. Must be equal to or greater than 0.
            - int: The fixed learning rate value. Must be equal to or greater than 0. It will be converted to float.
            - Tensor: Its value should be a scalar or a 1-D vector. For a scalar, a fixed learning rate will be
              applied. For a vector, the learning rate is dynamic and the i-th step will take the i-th value as
              the learning rate.
            - Iterable: Learning rate is dynamic. The i-th step will take the i-th value as the learning rate.
            - LearningRateSchedule: Learning rate is dynamic. During training, the optimizer calls the instance of
              LearningRateSchedule with the step as the input to get the learning rate of the current step.

        parameters (Union[list[Parameter], list[dict]]): Must be a list of `Parameter` or a list of `dict`. When
            `parameters` is a list of `dict`, the strings "params", "lr", "weight_decay", "grad_centralization" and
            "order_params" are the keys that can be parsed. A hedged usage sketch follows this docstring.

            - params: Required. Parameters in the current group. The value must be a list of `Parameter`.
            - lr: Optional. If "lr" is in the keys, the value of the corresponding learning rate will be used.
              If not, the `learning_rate` in the optimizer will be used. Fixed and dynamic learning rates are
              supported.
            - weight_decay: Optional. If "weight_decay" is in the keys, the value of the corresponding weight decay
              will be used. If not, the `weight_decay` in the optimizer will be used.
            - grad_centralization: Optional. Must be Boolean. If "grad_centralization" is in the keys, the set value
              will be used. If not, `grad_centralization` is False by default. This configuration only works on
              convolution layers.
            - order_params: Optional. When parameters are grouped, this is usually used to maintain the order of
              parameters that appear in the network in order to improve performance. The value should be the
              parameters whose order will be followed in the optimizer.
              If `order_params` is in the keys, other keys will be ignored and the elements of 'order_params' must
              be in one group of `params`.

        weight_decay (Union[float, int]): An int or a floating point value for the weight decay.
            It must be equal to or greater than 0.
            If the type of `weight_decay` input is int, it will be converted to float. Default: 0.0.
        loss_scale (float): A floating point value for the loss scale. It must be greater than 0. If the
            type of `loss_scale` input is int, it will be converted to float. In general, use the default value. Only
            when `FixedLossScaleManager` is used for training and the `drop_overflow_update` in
            `FixedLossScaleManager` is set to False, this value needs to be the same as the `loss_scale` in
            `FixedLossScaleManager`. Refer to class :class:`mindspore.FixedLossScaleManager` for more details.
            Default: 1.0.

    Raises:
        TypeError: If `learning_rate` is not one of int, float, Tensor, Iterable, LearningRateSchedule.
        TypeError: If an element of `parameters` is neither Parameter nor dict.
        TypeError: If `loss_scale` is not a float.
        TypeError: If `weight_decay` is neither float nor int.
        ValueError: If `loss_scale` is less than or equal to 0.
        ValueError: If `weight_decay` is less than 0.
        ValueError: If `learning_rate` is a Tensor, but the dimension of the tensor is greater than 1.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """

    def __init__(self, learning_rate, parameters, weight_decay=0.0, loss_scale=1.0):
        super(Optimizer, self).__init__(auto_prefix=False)
        parameters = self._parameters_base_check(parameters, "parameters")
        if not all(isinstance(x, Parameter) for x in parameters) and not all(isinstance(x, dict) for x in parameters):
            raise TypeError("For 'Optimizer', all elements of the argument 'parameters' must be "
                            "'Parameter' or 'dict'.")

        if isinstance(loss_scale, int):
            loss_scale = float(loss_scale)
        validator.check_value_type("loss_scale", loss_scale, [float], self.cls_name)
        validator.check_positive_float(loss_scale, "loss_scale", self.cls_name)
        self.loss_scale = loss_scale

        weight_decay = self._preprocess_weight_decay(weight_decay)
        self.grad_centralization = False

        self._unique = True
        self._target = context.get_context("device_target")
        self.dynamic_lr = False
        self.assignadd = None
        self.global_step = None
        self.is_group = False
        self.is_group_lr = False
        self.is_group_params_ordered = False
        learning_rate = self._preprocess_single_lr(learning_rate)
        if isinstance(parameters[0], dict):
            self.is_group = True
            self.group_params = []
            self.group_lr = []
            self.group_weight_decay = []
            self.group_grad_centralization = []
            self._init_group_params(parameters, learning_rate, weight_decay, self.grad_centralization)

        # The final value of dynamic_lr can be determined after the process of parse_single_lr and init_group_params
        if self.dynamic_lr:
            self.assignadd = P.AssignAdd()
            self.global_step = Parameter(initializer(0, [1], mindspore.int32), name='global_step')

        if self.is_group_lr:
            self.learning_rate = CellList(self.group_lr, auto_prefix=False) if self.dynamic_lr \
                else ParameterTuple(self.group_lr)
        else:
            self.learning_rate = self._build_single_lr(learning_rate, 'learning_rate')

        if self.is_group:
            self.parameters = ParameterTuple(self.group_params)
            self.weight_decay = tuple(self.group_weight_decay)
            self.weight_decay_tensor_tuple = tuple(Tensor(x, mstype.float32) for x in self.group_weight_decay)
            decay_filter = lambda x: x > 0
            self.decay_flags = tuple(decay_filter(x) for x in self.weight_decay)
            self.exec_weight_decay = any(self.decay_flags)
            self.grad_centralization_flags = tuple(self.group_grad_centralization)
        else:
            self.parameters = ParameterTuple(parameters)
            self.weight_decay = weight_decay * loss_scale
            self.weight_decay_tensor = Tensor(self.weight_decay, mstype.float32)
            decay_filter = lambda x: 'beta' not in x.name and 'gamma' not in x.name
            self.decay_flags = tuple(decay_filter(x) for x in self.parameters)
            self.exec_weight_decay = self.weight_decay > 0

        # When a parameter has already been made unique, there is no need to do another unique in the optimizer.
        for param in self.parameters:
            if param.unique:
                self._unique = False
                break
        ps_filter = lambda x: x.is_param_ps
        self.ps_parameters = tuple(ps_filter(x) for x in self.parameters)
        cache_filter = lambda x: x.cache_enable
        self.cache_enable = tuple(cache_filter(x) for x in self.parameters)
        self.reciprocal_scale = Tensor(1.0 / loss_scale, mstype.float32)
        self.need_scale = loss_scale != 1.0
        self.global_step_increase_tensor = Tensor(1, mstype.int32)
        self.param_length = len(self.parameters)
        self.map_ = C.Map()
        self.map_reverse = C.Map(None, True)
        self.hyper_map = C.HyperMap()
        self.hyper_map_reverse = C.HyperMap(None, True)
        self._use_parallel_optimizer()
        self.enable_tuple_broaden = True

    def _use_parallel_optimizer(self):
        """Indicates whether to use automatic parallelism."""
        if context.get_auto_parallel_context("enable_parallel_optimizer"):
            if _get_parallel_mode() == ParallelMode.DATA_PARALLEL and context.get_context("device_target") == "Ascend":
                self.use_parallel = True
            elif _get_parallel_mode() == ParallelMode.DATA_PARALLEL \
                    and context.get_context("device_target") != "Ascend":
                raise RuntimeError("Parallel optimizer only supports 'Ascend' in data parallel mode, "
                                   "but got {}.".format(context.get_context("device_target")))
            elif _get_parallel_mode() in (ParallelMode.STAND_ALONE, ParallelMode.HYBRID_PARALLEL):
                raise RuntimeError("Parallel optimizer is not supported in {}.".format(_get_parallel_mode()))
            else:
                self.use_parallel = False
        else:
            self.use_parallel = False

        if self.use_parallel:
            if self.cls_name not in ["Lamb", "AdamWeightDecay", "AdaFactor"]:
                raise RuntimeError("Parallel optimizer only supports the optimizers 'Lamb', 'AdamWeightDecay' and "
                                   "'AdaFactor', but got {}.".format(self.cls_name))
            self.dev_num = _get_device_num()
            if self.dev_num > self.param_length:
                raise RuntimeError("Parallel optimizer can not be applied when the number of parameters {} is"
                                   " less than the number of devices {}.".format(self.param_length, self.dev_num))
            self.param_rank = self._get_parameter_group_id()
            self.optim_filter = tuple(map(lambda x: x == _get_global_rank(), self.param_rank))
            self.param_names = []
            for param in self.parameters:
                self.param_names.append(param.name)
        else:
            self.optim_filter = (True,) * self.param_length

    @property
    def unique(self):
        """
        Whether to make the gradients unique in the optimizer. Generally, it is used in sparse networks. Set to True
        if the gradients of the optimizer are sparse. Set to False if the forward network has made the parameters
        unique, that is, the gradients of the optimizer are no longer sparse.
        The default value is True when it is not set.
        """
        return self._unique

    @unique.setter
    def unique(self, value):
        """Set the `unique` attribute."""
        if not isinstance(value, bool):
            raise TypeError("For 'Optimizer', the property 'unique' must be bool, "
                            "but got {}".format(type(value)))
        self._unique = value

    @property
    def target(self):
        """
        The property is used to determine whether the parameter is updated on host or device. The input type is str
        and can only be 'CPU', 'Ascend' or 'GPU'.
        """
        return self._target

    @target.setter
    def target(self, value):
        """
        If the input value is set to "CPU", the parameters will be updated on the host using the Fused
        optimizer operation.
        """
        raise NotImplementedError

    def _set_base_target(self, value):
        """
        If the input value is set to "CPU", the parameters will be updated on the host using the Fused
        optimizer operation.
        """
        if not isinstance(value, str):
            raise TypeError("For 'Optimizer', the property 'target' must be string, but got {}".format(type(value)))

        if value not in ('CPU', 'Ascend', 'GPU'):
            raise ValueError("For 'Optimizer', the property 'target' must be one of ['CPU', 'Ascend', 'GPU'], "
                             "but got {}".format(value))

        if self._target == "CPU" and value in ('Ascend', 'GPU'):
            raise ValueError("For 'Optimizer', the property 'target' cannot be set to 'GPU' or 'Ascend' "
                             "in the 'CPU' environment.")

        if self._target == "Ascend" and value == 'GPU':
            raise ValueError("For 'Optimizer', the property 'target' cannot be set to 'GPU' "
                             "in the 'Ascend' environment.")

        if self._target == "GPU" and value == 'Ascend':
            raise ValueError("For 'Optimizer', the property 'target' cannot be set to 'Ascend' "
                             "in the 'GPU' environment.")

        self._is_device = (value != 'CPU')
        self._target = value

    def decay_weight(self, gradients):
        """
        Weight decay.

        An approach to reduce the overfitting of a deep learning neural network model. User-defined optimizers based
        on :class:`mindspore.nn.Optimizer` can also call this interface to apply weight decay.

        Args:
            gradients (tuple[Tensor]): The gradients of network parameters, with the same shape as the parameters.

        Returns:
            tuple[Tensor], the gradients after weight decay.
        """
        if self.exec_weight_decay:
            params = self.parameters
            if self.is_group:
                gradients = self.map_(F.partial(_apply_decay), self.weight_decay_tensor_tuple, self.decay_flags,
                                      params, gradients)
            else:
                gradients = self.map_(F.partial(_apply_decay, self.weight_decay_tensor), self.decay_flags,
                                      params, gradients)

        return gradients

    def gradients_centralization(self, gradients):
        """
        Gradients centralization.

        A method for optimizing convolutional layer parameters to improve the training speed of a deep learning
        neural network model. User-defined optimizers based on :class:`mindspore.nn.Optimizer` can also call this
        interface to centralize gradients.

        Args:
            gradients (tuple[Tensor]): The gradients of network parameters, with the same shape as the parameters.

        Returns:
            tuple[Tensor], the gradients after gradients centralization.
        """
        if self.is_group:
            gradients = self.map_(F.partial(_apply_grad_centralization), self.grad_centralization_flags, gradients)

        return gradients

    def scale_grad(self, gradients):
        """
        Restore gradients for mixed precision.

        User-defined optimizers based on :class:`mindspore.nn.Optimizer` can also call this interface to restore
        gradients.

        Args:
            gradients (tuple[Tensor]): The gradients of network parameters, with the same shape as the parameters.

        Returns:
            tuple[Tensor], the gradients after loss scale.
        """
        if self.need_scale:
            gradients = self.map_(F.partial(_grad_scale, self.reciprocal_scale), gradients)

        return gradients

    def _grad_sparse_indices_deduplicate(self, gradients):
        """In the case of using big operators, deduplicate the 'indices' in gradients."""
        if self._target != 'CPU' and self._unique:
            gradients = self.map_(F.partial(_indices_deduplicate), gradients)

        return gradients

    def _preprocess_weight_decay(self, weight_decay):
        """Check weight decay, and convert int to float."""
        if isinstance(weight_decay, (float, int)):
            weight_decay = float(weight_decay)
            validator.check_non_negative_float(weight_decay, "weight_decay", self.cls_name)
            return weight_decay
        raise TypeError("Weight decay should be int or float.")

    def _preprocess_grad_centralization(self, grad_centralization):
        if not isinstance(grad_centralization, bool):
            raise TypeError("For 'Optimizer', the 'gradients_centralization' should be bool type, "
                            "but got {}.".format(type(grad_centralization)))
        return grad_centralization

    def _preprocess_single_lr(self, learning_rate):
        """Check lr value, and convert lr to a float, a Tensor or a LearningRateSchedule."""
        if isinstance(learning_rate, (float, int)):
            learning_rate = float(learning_rate)
            validator.check_non_negative_float(learning_rate, "learning rate", self.cls_name)
            return learning_rate
        if isinstance(learning_rate, Tensor) and learning_rate.ndim == 0:
            return learning_rate

        self.dynamic_lr = True
        if isinstance(learning_rate, Iterable):
            return Tensor(np.array(list(learning_rate)).astype(np.float32))
        if isinstance(learning_rate, Tensor):
            if learning_rate.ndim > 1:
                raise ValueError(f"For 'Optimizer', the dim of Tensor type 'learning_rate' should be 0 or 1, "
                                 f"but got {learning_rate.ndim}.")
            if learning_rate.ndim == 1 and learning_rate.size < 2:
                logger.warning("If a `Tensor` type dynamic learning rate is used, please make sure that the number "
                               "of elements in the tensor is greater than 1.")
            return learning_rate
        if isinstance(learning_rate, LearningRateSchedule):
            return learning_rate
        raise TypeError("For 'Optimizer', 'learning_rate' should be int, float, Tensor, Iterable or "
                        "LearningRateSchedule, but got {}.".format(type(learning_rate)))

    def _build_single_lr(self, learning_rate, name):
        """Build learning rate value, convert learning rate to a Parameter or a LearningRateSchedule."""
        if isinstance(learning_rate, float):
            learning_rate = Parameter(Tensor(learning_rate, mstype.float32), name)
            if self.is_group_lr and self.dynamic_lr:
                learning_rate = _ConvertToCell(learning_rate)
            return learning_rate
        if isinstance(learning_rate, Tensor) and learning_rate.ndim == 0:
            learning_rate = Parameter(learning_rate, name)
            if self.is_group_lr and self.dynamic_lr:
                learning_rate = _ConvertToCell(learning_rate)
            return learning_rate
        if isinstance(learning_rate, Tensor) and learning_rate.ndim == 1:
            return _IteratorLearningRate(learning_rate, name)
        return learning_rate

    def _parameters_base_check(self, parameters, param_info):
        if parameters is None:
            raise ValueError(f"Optimizer {param_info} can not be None.")
        if not isinstance(parameters, Iterable):
            raise TypeError(f"Optimizer {param_info} must be Iterable.")
        parameters = list(parameters)

        if not parameters:
            raise ValueError(f"Optimizer got an empty {param_info} list.")
        return parameters

    def _check_group_params(self, parameters):
        """Check group params."""
        parse_keys = ['params', 'lr', 'weight_decay', 'order_params', 'grad_centralization']
        for group_param in parameters:
            invalid_key = list(filter(lambda x: x not in parse_keys, group_param.keys()))
            if invalid_key:
                raise KeyError(f'The key "{invalid_key}" cannot be recognized in group params.')

            if 'order_params' in group_param.keys():
                if len(group_param.keys()) > 1:
                    raise ValueError("The order params dict in group parameters should "
                                     "only include the 'order_params' key.")
                if not isinstance(group_param['order_params'], Iterable):
                    raise TypeError("The value of 'order_params' should be an Iterable type.")
                continue

            parameters = self._parameters_base_check(group_param['params'], "group `params`")
            if not all(isinstance(x, Parameter) for x in parameters):
                raise TypeError("The group `params` should be an iterator of Parameter type.")

    def _parse_group_params(self, parameters, learning_rate):
        """Parse group params."""
        self._check_group_params(parameters)
        if isinstance(learning_rate, Tensor) and learning_rate.ndim == 1:
            tensor_lr_length = learning_rate.size
        else:
            tensor_lr_length = 0

        for group_param in parameters:
            if 'order_params' in group_param.keys():
                if len(group_param.keys()) > 1:
                    raise ValueError("The order params dict in group parameters should "
                                     "only include the 'order_params' key.")
                if not isinstance(group_param['order_params'], Iterable):
                    raise TypeError("The value of 'order_params' should be an Iterable type.")
                self.is_group_params_ordered = True
                continue

            if 'lr' in group_param.keys():
                self.is_group_lr = True
                group_lr = self._preprocess_single_lr(group_param['lr'])

                if isinstance(group_lr, Tensor) and group_lr.ndim == 1:
                    group_lr_length = group_lr.size
                    if tensor_lr_length == 0:
                        tensor_lr_length = group_lr_length
                    elif group_lr_length != tensor_lr_length:
                        raise ValueError("The Tensor type dynamic learning rate in group should be the same size.")

    def _init_group_params(self, parameters, learning_rate, weight_decay, grad_centralization):
        """Initialize learning rate, weight decay or grad centralization in group params."""
        self._parse_group_params(parameters, learning_rate)
        default_lr = self._build_single_lr(learning_rate, 'learning_rate')

        params_store = []
        for group_num, group_param in enumerate(parameters):
            if 'order_params' in group_param.keys():
                ordered_parameters = group_param['order_params']
                continue

            self.group_params += group_param['params']

            if 'lr' in group_param.keys():
                lr_param_name = 'learning_rate_group_' + str(group_num)
                lr = self._preprocess_single_lr(group_param['lr'])
                lr = self._build_single_lr(lr, lr_param_name)
            else:
                lr = default_lr

            if 'weight_decay' in group_param.keys():
                cur_weight_decay = self._preprocess_weight_decay(group_param['weight_decay'])
                weight_decay_ = cur_weight_decay * self.loss_scale
            else:
                weight_decay_ = weight_decay * self.loss_scale

            if 'grad_centralization' in group_param.keys():
                self.grad_centralization = self._preprocess_grad_centralization(group_param['grad_centralization'])
                for param in group_param['params']:
                    validator.check_value_type("parameter", param, [Parameter], self.cls_name)
                grad_centralization_ = self.grad_centralization
            else:
                grad_centralization_ = grad_centralization

            for key in group_param.keys():
                if key not in ('params', 'lr', 'weight_decay', 'grad_centralization'):
                    logger.warning(f"The optimizer cannot parse '{key}' when setting parameter groups.")

            for param in group_param['params']:
                validator.check_value_type("parameter", param, [Parameter], self.cls_name)
                if param.name in params_store:
                    raise RuntimeError(f"The {param.name} parameter already exists, it does not support "
                                       f"repeated setting. Please check whether the optimizer parameter "
                                       f"has been set multiple times.")

                params_store.append(param.name)
                self.group_lr.append(lr)
                self.group_weight_decay.append(weight_decay_)
                self.group_grad_centralization.append(grad_centralization_)

        if self.is_group_params_ordered:
            self._order_and_adjust_group_params(ordered_parameters)

    def _order_and_adjust_group_params(self, ordered_parameters):
        """
        Order group parameter, learning rate, weight decay and grad centralization in group params.
        """
        params_length = len(self.group_params)
        if len(ordered_parameters) != len(self.group_params):
            raise ValueError(f"The value of 'order_params' should be same with all group parameters.")

        ordered_params = [None] * params_length
        ordered_learning_rate = [None] * params_length
        ordered_weight_decay = [None] * params_length
        ordered_grad_centralization = [None] * params_length
        params_name = [param.name for param in ordered_parameters]

        for param, lr, wd, gc in zip(self.group_params, self.group_lr, self.group_weight_decay,
                                     self.group_grad_centralization):
            index = params_name.index(param.name)
            ordered_params[index] = param
            ordered_learning_rate[index] = lr
            ordered_weight_decay[index] = wd
            ordered_grad_centralization[index] = gc

        self.group_params = ordered_params
        self.group_lr = ordered_learning_rate
        self.group_weight_decay = ordered_weight_decay
        self.group_grad_centralization = ordered_grad_centralization

    def get_lr(self):
        """
        The optimizer calls this interface to get the learning rate for the current step. User-defined optimizers
        based on :class:`mindspore.nn.Optimizer` can also call this interface before updating the parameters.

        Returns:
            float, the learning rate of the current step.
        """
        lr = self.learning_rate
        if self.dynamic_lr:
            if self.is_group_lr:
                lr = ()
                for learning_rate in self.learning_rate:
                    current_dynamic_lr = learning_rate(self.global_step)
                    lr += (current_dynamic_lr,)
            else:
                lr = self.learning_rate(self.global_step)

            self.assignadd(self.global_step, self.global_step_increase_tensor)
        return lr
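
    # Note (descriptive, not in the original file): when `dynamic_lr` is True, get_lr() above also
    # increments `global_step` as a side effect, so subclasses should call it exactly once per step.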

    def get_lr_parameter(self, param):
        """
        When parameters are grouped and the learning rate is different for each group, get the learning rate of the
        specified `param`.

        Args:
            param (Union[Parameter, list[Parameter]]): The `Parameter` or list of `Parameter`.

        Returns:
            Parameter, a single `Parameter` or `list[Parameter]` according to the input type. If the learning rate
            is dynamic, the `LearningRateSchedule` or `list[LearningRateSchedule]` used to calculate the learning
            rate will be returned.

        Examples:
            >>> from mindspore import nn
            >>> net = Net()
            >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
            >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
            >>> group_params = [{'params': conv_params, 'lr': 0.05},
            ...                 {'params': no_conv_params, 'lr': 0.01}]
            >>> optim = nn.Momentum(group_params, learning_rate=0.1, momentum=0.9, weight_decay=0.0)
            >>> conv_lr = optim.get_lr_parameter(conv_params)
            >>> print(conv_lr[0].asnumpy())
            0.05
        """
        def get_lr_value(learning_rate):
            if isinstance(learning_rate, (_ConvertToCell, _IteratorLearningRate)):
                return learning_rate.learning_rate

            return learning_rate

        if isinstance(param, Parameter):
            param_list = [param]
        elif isinstance(param, list):
            param_list = param
        else:
            raise TypeError(f"For 'get_lr_parameter', the 'param' must be 'Parameter' or 'list' type, "
                            f"but got {type(param)}.")

        lr = []
        ids = [id(p) for p in self.parameters]
        for p in param_list:
            validator.check_value_type("parameter", p, [Parameter], self.cls_name)
            if id(p) not in ids:
                raise ValueError(f"The parameter {p.name} is not in optimizer, please check whether the "
                                 f"argument 'param' is correct.")
            if self.is_group_lr:
                index = ids.index(id(p))
                lr.append(get_lr_value(self.learning_rate[index]))
            else:
                lr.append(get_lr_value(self.learning_rate))

        return lr if isinstance(param, list) else lr[0]

    def _get_parameter_group_id(self):
        """
        Get the parameter partition group id, which is less than the number of devices.

        Returns:
            tuple, the group id tuple of parameters.
        """
        rank_list = ()
        count = 0
        for _ in range(self.param_length):
            rank_list = rank_list + (count,)
            count = count + 1
            if count == self.dev_num:
                count = 0
        return rank_list
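
    # Illustrative note (not in the original file): with, say, 5 parameters and 2 devices, the
    # round-robin assignment above yields param_rank == (0, 1, 0, 1, 0).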

    def broadcast_params(self, optim_result):
        """
        Apply Broadcast operations in the sequential order of parameter groups.

        Args:
            optim_result (bool): The results of updating parameters. This input is used to ensure that the
                parameters are updated before they are broadcast.

        Returns:
            bool, the status flag.
        """
        param_group = []
        key_group = []
        for _ in range(self.dev_num):
            param_group.append(F.make_tuple())
            key_group.append(F.make_tuple())
        for i in range(self.param_length):
            param_group[self.param_rank[i]] = param_group[self.param_rank[i]] + (self.parameters[i],)
            key = P.MakeRefKey(self.param_names[i])()
            key_group[self.param_rank[i]] = key_group[self.param_rank[i]] + (key,)
        new_param_group = []
        for root in range(self.dev_num):
            ops = P.Broadcast(root)
            if root > 0:
                param_group[root] = F.depend(param_group[root], new_param_group[root - 1])
            else:
                param_group[root] = F.depend(param_group[root], optim_result)
            next_params = ops(param_group[root])
            new_param_group.append(next_params)
            for i in range(F.tuple_len(next_params)):
                F.assign(key_group[root][i], next_params[i])
        return new_param_group

    def construct(self, *hyper_params):
        raise NotImplementedError
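
# Illustrative sketch (not part of the original file): a typical subclass construct() composes the
# helpers defined above roughly as follows; the update operation `_my_update_op` is hypothetical.
#
#     def construct(self, gradients):
#         gradients = self.decay_weight(gradients)
#         gradients = self.gradients_centralization(gradients)
#         gradients = self.scale_grad(gradients)
#         lr = self.get_lr()
#         return self.hyper_map(F.partial(_my_update_op, lr), self.parameters, gradients)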


op_add = P.AddN()
op_gather = P.Gather()
op_mul = P.Mul()
op_gc = inner.Centralization()

_apply_decay = C.MultitypeFuncGraph("apply_decay")
_apply_grad_centralization = C.MultitypeFuncGraph("apply_grad_centralization")


@_apply_decay.register("Tensor", "Bool", "Tensor", "RowTensor")
def _tensor_apply_decay_with_sparse(weight_decay, if_apply, weight, gradient):
    """Get grad with weight_decay."""
    if if_apply:
        indices = gradient.indices
        values = op_add((op_gather(weight, indices, 0) * F.cast(weight_decay, F.dtype(weight)), gradient.values))
        shape = gradient.dense_shape
        return RowTensor(indices, values, shape)
    return gradient


@_apply_decay.register("Tensor", "Bool", "Tensor", "Tensor")
def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
    """Get grad with weight_decay."""
    if if_apply:
        return op_add((op_mul(weight, F.cast(weight_decay, F.dtype(weight))), gradient))
    return gradient


@_apply_grad_centralization.register("Bool", "RowTensor")
def _tensor_apply_grad_centralization_with_sparse(if_apply, gradient):
    """Get grad with grad_centralization."""
    if if_apply:
        indices = gradient.indices
        shape = gradient.dense_shape
        grad_shape = F.shape(gradient)
        axis = []
        for i in range(1, len(grad_shape)):
            axis.append(i)
        if len(axis) >= 1:
            if grad_shape[1] % 16 != 0:
                return gradient
            values = op_gc(gradient.values, axis)
            return RowTensor(indices, values, shape)
    return gradient


@_apply_grad_centralization.register("Bool", "Tensor")
def _tensor_apply_grad_centralization(if_apply, gradient):
    """Get grad with grad_centralization."""
    if if_apply:
        axis = []
        grad_shape = F.shape(gradient)
        for i in range(1, len(grad_shape)):
            axis.append(i)
        if len(axis) >= 1:
            if grad_shape[1] % 16 != 0:
                return gradient
            return op_gc(gradient, axis)
    return gradient
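
# Descriptive note (not in the original file): as applied above, gradient centralization subtracts
# the mean of each gradient over all axes except axis 0, and the `% 16` check skips tensors whose
# second dimension is not a multiple of 16.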


_grad_scale = C.MultitypeFuncGraph("grad_scale")
_indices_deduplicate = C.MultitypeFuncGraph("indices_deduplicate")


@_grad_scale.register("Number", "Tensor")
def tensor_grad_scale(scale, grad):
    """Get grad with scale."""
    if scale == 1.0:
        return grad
    return op_mul(grad, F.cast(scale, F.dtype(grad)))


@_grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale_with_tensor(scale, grad):
    """Get grad with scale."""
    return op_mul(grad, F.cast(scale, F.dtype(grad)))


@_grad_scale.register("Tensor", "RowTensor")
def tensor_grad_scale_with_sparse(scale, grad):
    """Get grad with scale."""
    return RowTensor(grad.indices, grad.values * F.cast(scale, F.dtype(grad.values)), grad.dense_shape)


@_indices_deduplicate.register("RowTensor")
def rowtensor_deduplicate_indices_slices(grad):
    """Unique the indices and sum the 'values' corresponding to the duplicate indices."""
    indices = grad.indices
    values = grad.values
    unique_indices, index_position = P.Unique()(indices)
    summed_values = P.UnsortedSegmentSum()(values, index_position, P.DynamicShape()(unique_indices)[0])
    return RowTensor(unique_indices, summed_values, grad.dense_shape)


@_indices_deduplicate.register("Tensor")
def tensor_deduplicate_indice_slices(grad):
    """Return the input gradient directly in the dense case."""
    return grad


class _ConvertToCell(LearningRateSchedule):
    """Inner api, convert learning rate of scalar to LearningRateSchedule."""
    def __init__(self, learning_rate):
        super(_ConvertToCell, self).__init__()
        if not isinstance(learning_rate, Parameter):
            raise TypeError("For '_ConvertToCell', the argument 'learning_rate' must be Parameter, "
                            "but got {}.".format(type(learning_rate)))
        self.learning_rate = learning_rate

    def construct(self, global_step):
        return self.learning_rate + 1.0 - 1.0
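        # Descriptive note (not in the original file): the '+ 1.0 - 1.0' above makes the cell return
        # a Tensor computed from the Parameter rather than the Parameter object itself.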


class _IteratorLearningRate(LearningRateSchedule):
    """Inner api, convert learning rate of Tensor(list) to LearningRateSchedule."""
    def __init__(self, learning_rate, name):
        super(_IteratorLearningRate, self).__init__()
        if isinstance(learning_rate, Tensor):
            if learning_rate.ndim != 1:
                raise ValueError(f"For '_IteratorLearningRate', the dimension of the argument 'learning_rate' "
                                 f"should be 1, but got {learning_rate.ndim}.")
        else:
            raise TypeError("For '_IteratorLearningRate', the argument 'learning_rate' should be Tensor, "
                            "but got {}.".format(type(learning_rate)))

        self.learning_rate = Parameter(learning_rate, name)
        self.gather = P.Gather()

    def construct(self, global_step):
        return self.gather(self.learning_rate, global_step, 0)