
optimizer.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""optimizer"""
from typing import Iterable

import numpy as np

import mindspore
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.nn.cell import Cell
from mindspore.nn.layer.container import CellList
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common.initializer import initializer
from mindspore.common.tensor import Tensor, RowTensor
import mindspore.common.dtype as mstype
from mindspore._checkparam import Validator as validator
from mindspore import log as logger
from mindspore.parallel._utils import _get_global_rank, _get_device_num, _get_parallel_mode
from mindspore.context import ParallelMode
from mindspore import context
from mindspore.nn.learning_rate_schedule import LearningRateSchedule

__all__ = ['Optimizer']


class Optimizer(Cell):
    """
    Base class for all optimizers.

    Note:
        This class defines the API to add Ops to train a model. Never use
        this class directly, but instead instantiate one of its subclasses.

        Different parameter groups can set different `learning_rate` and `weight_decay`.

        When separating parameter groups, the weight decay in each group will be applied on the parameters if the
        weight decay is positive. For most optimizers, when not separating parameters, the `weight_decay` in the API
        will be applied on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.

        To improve the performance of parameter groups, a customized order of parameters is supported.

    Args:
        learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning
            rate. When the learning_rate is an Iterable or a 1-D Tensor, a dynamic learning rate is used, and the
            i-th step will take the i-th value as the learning rate. When the learning_rate is a
            LearningRateSchedule, a dynamic learning rate is also used, and the i-th learning rate will be
            calculated during training according to the formula of LearningRateSchedule. When the learning_rate is
            a float or a zero-dimensional Tensor, a fixed learning rate is used. Other cases are not supported.
            The float learning rate must be equal to or greater than 0. If the type of `learning_rate` is int,
            it will be converted to float.
        parameters (Union[list[Parameter], list[dict]]): When the `parameters` is a list of `Parameter` which will be
            updated, the elements in `parameters` must be class `Parameter`. When the `parameters` is a list of
            `dict`, the "params", "lr", "weight_decay" and "order_params" keys can be parsed.

            - params: Required. The value must be a list of `Parameter`.

            - lr: Optional. If "lr" is in the keys, the corresponding learning rate will be used.
              If not, the `learning_rate` in the API will be used.

            - weight_decay: Optional. If "weight_decay" is in the keys, the corresponding weight decay
              will be used. If not, the `weight_decay` in the API will be used.

            - order_params: Optional. If "order_params" is in the keys, the value must be the order of parameters,
              and this order will be followed in the optimizer. There must be no other keys in this `dict`, and the
              parameters in the value of 'order_params' must be in one of the group parameters.

        weight_decay (float): A floating point value for the weight decay. It must be equal to or greater than 0.
            If the type of `weight_decay` input is int, it will be converted to float. Default: 0.0.
        loss_scale (float): A floating point value for the loss scale. It must be greater than 0. If the
            type of `loss_scale` input is int, it will be converted to float. Default: 1.0.

    Raises:
        ValueError: If the learning_rate is a Tensor, but the dimension of the tensor is greater than 1.
        TypeError: If the learning_rate is not any of the three types: float, Tensor or Iterable.

    Supported Platforms:
        ``Ascend`` ``GPU``
    """

    def __init__(self, learning_rate, parameters, weight_decay=0.0, loss_scale=1.0):
        super(Optimizer, self).__init__(auto_prefix=False)
        if parameters is not None and not isinstance(parameters, list):
            parameters = list(parameters)

        if not parameters:
            raise ValueError("Optimizer got an empty parameter list.")

        if not isinstance(parameters[0], (dict, Parameter)):
            raise TypeError("Only a list of Parameter or dict can be supported.")

        if isinstance(loss_scale, int):
            loss_scale = float(loss_scale)
        validator.check_value_type("loss_scale", loss_scale, [float], self.cls_name)
        validator.check_positive_float(loss_scale, "loss_scale", self.cls_name)
        self.loss_scale = loss_scale

        weight_decay = self._preprocess_weight_decay(weight_decay)

        self._unique = True
        self._target = context.get_context("device_target")
        self.dynamic_lr = False
        self.assignadd = None
        self.global_step = None
        self.is_group = False
        self.is_group_lr = False
        self.is_group_params_ordered = False
        learning_rate = self._preprocess_single_lr(learning_rate)
        if isinstance(parameters[0], dict):
            self.is_group = True
            self.group_params = []
            self.group_lr = []
            self.group_weight_decay = []
            self._init_group_params(parameters, learning_rate, weight_decay)

        # The final value of dynamic_lr can only be determined after _preprocess_single_lr and _init_group_params.
        if self.dynamic_lr:
            self.assignadd = P.AssignAdd()
            self.global_step = Parameter(initializer(0, [1], mindspore.int32), name='global_step')

        if self.is_group_lr:
            if self.dynamic_lr:
                self.learning_rate = CellList(self.group_lr)
            else:
                self.learning_rate = ParameterTuple(self.group_lr)
        else:
            self.learning_rate = self._build_single_lr(learning_rate, 'learning_rate')

        if self.is_group:
            self.parameters = ParameterTuple(self.group_params)
            self.weight_decay = tuple(self.group_weight_decay)
            self.weight_decay_tensor_tuple = tuple(Tensor(x, mstype.float32) for x in self.group_weight_decay)
            decay_filter = lambda x: x > 0
            self.decay_flags = tuple(decay_filter(x) for x in self.weight_decay)
            self.exec_weight_decay = any(self.decay_flags)
        else:
            self.parameters = ParameterTuple(parameters)
            self.weight_decay = weight_decay * loss_scale
            self.weight_decay_tensor = Tensor(self.weight_decay, mstype.float32)
            decay_filter = lambda x: 'beta' not in x.name and 'gamma' not in x.name
            self.decay_flags = tuple(decay_filter(x) for x in self.parameters)
            self.exec_weight_decay = self.weight_decay > 0

        # When a parameter has already been made unique, there is no need to do another unique in the optimizer.
        for param in self.parameters:
            if param.unique:
                self._unique = False
                break

        ps_filter = lambda x: x.is_param_ps
        self.ps_parameters = tuple(ps_filter(x) for x in self.parameters)
        cache_filter = lambda x: x.cache_enable
        self.cache_enable = tuple(cache_filter(x) for x in self.parameters)
        self.reciprocal_scale = Tensor(1.0 / loss_scale, mstype.float32)
        self.need_scale = loss_scale != 1.0
        self.global_step_increase_tensor = Tensor(1, mstype.int32)
        self.param_length = len(self.parameters)
        self.map_ = C.Map()

        if context.get_auto_parallel_context("enable_parallel_optimizer"):
            if _get_parallel_mode() == ParallelMode.DATA_PARALLEL and context.get_context("device_target") == "Ascend":
                self.use_parallel = True
            elif _get_parallel_mode() == ParallelMode.DATA_PARALLEL \
                    and context.get_context("device_target") != "Ascend":
                raise RuntimeError("Parallel optimizer only supports Ascend in data parallel mode.")
            elif _get_parallel_mode() in (ParallelMode.STAND_ALONE, ParallelMode.HYBRID_PARALLEL):
                raise RuntimeError("Parallel optimizer is not supported in {}.".format(_get_parallel_mode()))
            else:
                self.use_parallel = False
        else:
            self.use_parallel = False

        if self.use_parallel:
            if self.cls_name not in ["Lamb", "AdamWeightDecay"]:
                raise RuntimeError("Parallel optimizer does not support optimizer {}".format(self.cls_name))
            self.dev_num = _get_device_num()
            if self.dev_num > self.param_length:
                raise RuntimeError("Parallel optimizer can not be applied when the number of parameters {} is"
                                   " less than the number of devices {}".format(self.param_length, self.dev_num))
            self.param_rank = self._get_parameter_group_id()
            self.optim_filter = tuple(map(lambda x: x == _get_global_rank(), self.param_rank))
            self.param_names = []
            for param in self.parameters:
                self.param_names.append(param.name)
        else:
            self.optim_filter = (True,) * self.param_length

    @property
    def unique(self):
        """The method is to see whether to make unique. The input type is bool. The method is read-only."""
        return self._unique

    @unique.setter
    def unique(self, value):
        """Set whether the input value is unique."""
        if not isinstance(value, bool):
            raise TypeError("The value type must be bool, but got value type is {}".format(type(value)))
        self._unique = value

    @property
    def target(self):
        """The method is used to determine whether the parameter is updated on host or device. The input type is str
        and can only be 'CPU', 'Ascend' or 'GPU'."""
        return self._target

    @target.setter
    def target(self, value):
        """If the input value is set to "CPU", the parameters will be updated on the host using the Fused
        optimizer operation."""
        raise NotImplementedError

    def decay_weight(self, gradients):
        """
        Weight decay.

        An approach to reduce the overfitting of a deep learning neural network model.

        Args:
            gradients (tuple[Tensor]): The gradients of `self.parameters`, and have the same shape as
                `self.parameters`.

        Returns:
            tuple[Tensor], the gradients after weight decay.
        """
        if self.exec_weight_decay:
            params = self.parameters
            if self.is_group:
                gradients = self.map_(F.partial(_apply_decay), self.weight_decay_tensor_tuple, self.decay_flags,
                                      params, gradients)
            else:
                gradients = self.map_(F.partial(_apply_decay, self.weight_decay_tensor), self.decay_flags,
                                      params, gradients)

        return gradients
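
    # A minimal sketch of what the dense branch of `_apply_decay` computes (see the
    # registrations at the bottom of this file): for a weight w, gradient g and decay
    # coefficient d, the decayed gradient is g + d * w. For example, with w = 2.0,
    # g = 0.5 and d = 0.01, the returned gradient is 0.5 + 0.01 * 2.0 = 0.52.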

    def scale_grad(self, gradients):
        """
        Loss scale for mixed precision.

        An approach of mixed precision training to improve the speed and energy efficiency of training deep neural
        networks.

        Args:
            gradients (tuple[Tensor]): The gradients of `self.parameters`, and have the same shape as
                `self.parameters`.

        Returns:
            tuple[Tensor], the gradients after loss scale.
        """
        if self.need_scale:
            gradients = self.map_(F.partial(_grad_scale, self.reciprocal_scale), gradients)

        return gradients
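
    # A minimal sketch of the scaling, assuming loss_scale was set to 1024.0 in
    # __init__: reciprocal_scale is 1 / 1024, so each gradient g becomes g * (1 / 1024),
    # undoing the scaling that was applied to the loss during mixed precision training.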

    def _grad_sparse_indices_deduplicate(self, gradients):
        """In the case of using big operators, deduplicate the 'indices' in gradients."""
        if self._target != 'CPU' and self._unique:
            gradients = self.map_(F.partial(_indices_deduplicate), gradients)

        return gradients

    def _preprocess_weight_decay(self, weight_decay):
        """Check weight decay, and convert int to float."""
        if isinstance(weight_decay, (float, int)):
            weight_decay = float(weight_decay)
            validator.check_non_negative_float(weight_decay, "weight_decay", self.cls_name)
            return weight_decay

        raise TypeError("Weight decay should be int or float.")

    def _preprocess_single_lr(self, learning_rate):
        """Check lr value, and convert lr to a float, a Tensor or a LearningRateSchedule."""
        if isinstance(learning_rate, (float, int)):
            learning_rate = float(learning_rate)
            validator.check_non_negative_float(learning_rate, "learning rate", self.cls_name)
            return learning_rate
        if isinstance(learning_rate, Tensor) and learning_rate.ndim == 0:
            return learning_rate

        self.dynamic_lr = True
        if isinstance(learning_rate, Iterable):
            return Tensor(np.array(list(learning_rate)).astype(np.float32))
        if isinstance(learning_rate, Tensor):
            if learning_rate.ndim > 1:
                raise ValueError("The dim of `Tensor` type learning rate should be 0 or 1, "
                                 f"but got {learning_rate.ndim}.")
            if learning_rate.ndim == 1 and learning_rate.size < 2:
                logger.warning("If a `Tensor` type dynamic learning rate is used, please make sure that the "
                               "number of elements in the tensor passed is greater than 1.")
            return learning_rate
        if isinstance(learning_rate, LearningRateSchedule):
            return learning_rate

        raise TypeError("Learning rate should be int, float, Tensor, Iterable or LearningRateSchedule.")

    def _build_single_lr(self, learning_rate, name):
        """Build learning rate value, convert learning rate to a Parameter or a LearningRateSchedule."""
        if isinstance(learning_rate, float):
            learning_rate = Parameter(Tensor(learning_rate, mstype.float32), name)
            if self.is_group_lr and self.dynamic_lr:
                learning_rate = _ConvertToCell(learning_rate)
            return learning_rate
        if isinstance(learning_rate, Tensor) and learning_rate.ndim == 0:
            learning_rate = Parameter(learning_rate, name)
            if self.is_group_lr and self.dynamic_lr:
                learning_rate = _ConvertToCell(learning_rate)
            return learning_rate
        if isinstance(learning_rate, Tensor) and learning_rate.ndim == 1:
            return _IteratorLearningRate(learning_rate, name)
        return learning_rate

    def _check_group_params(self, parameters):
        """Check group params."""
        parse_keys = ['params', 'lr', 'weight_decay', 'order_params']
        for group_param in parameters:
            invalid_key = list(filter(lambda x: x not in parse_keys, group_param.keys()))
            if invalid_key:
                raise KeyError(f'The key "{invalid_key}" cannot be recognized in group params.')

            if 'order_params' in group_param.keys():
                if len(group_param.keys()) > 1:
                    raise ValueError("The order params dict in group parameters should "
                                     "only include the 'order_params' key.")
                if not isinstance(group_param['order_params'], Iterable):
                    raise TypeError("The value of 'order_params' should be an Iterable type.")
                continue

            if not group_param['params']:
                raise ValueError("Optimizer got an empty group parameter list.")

            for param in group_param['params']:
                if not isinstance(param, Parameter):
                    raise TypeError("The group param should be an iterator of Parameter type.")

    def _parse_group_params(self, parameters, learning_rate):
        """Parse group params."""
        self._check_group_params(parameters)
        if isinstance(learning_rate, Tensor) and learning_rate.ndim == 1:
            tensor_lr_length = learning_rate.size
        else:
            tensor_lr_length = 0

        for group_param in parameters:
            if 'order_params' in group_param.keys():
                if len(group_param.keys()) > 1:
                    raise ValueError("The order params dict in group parameters should "
                                     "only include the 'order_params' key.")
                if not isinstance(group_param['order_params'], Iterable):
                    raise TypeError("The value of 'order_params' should be an Iterable type.")
                self.is_group_params_ordered = True
                continue

            if 'lr' in group_param.keys():
                self.is_group_lr = True
                group_lr = self._preprocess_single_lr(group_param['lr'])

                if isinstance(group_lr, Tensor) and group_lr.ndim == 1:
                    group_lr_length = group_lr.size
                    if tensor_lr_length == 0:
                        tensor_lr_length = group_lr_length
                    elif group_lr_length != tensor_lr_length:
                        raise ValueError("The Tensor type dynamic learning rate in group should be the same size.")

    def _init_group_params(self, parameters, learning_rate, weight_decay):
        """Initialize learning rate or weight decay in group params."""
        self._parse_group_params(parameters, learning_rate)
        default_lr = self._build_single_lr(learning_rate, 'learning_rate')

        params_store = []
        for group_num, group_param in enumerate(parameters):
            if 'order_params' in group_param.keys():
                ordered_parameters = group_param['order_params']
                continue

            self.group_params += group_param['params']

            if 'lr' in group_param.keys():
                lr_param_name = 'learning_rate_group_' + str(group_num)
                lr = self._preprocess_single_lr(group_param['lr'])
                lr = self._build_single_lr(lr, lr_param_name)
            else:
                lr = default_lr

            if 'weight_decay' in group_param.keys():
                cur_weight_decay = self._preprocess_weight_decay(group_param['weight_decay'])
                weight_decay_ = cur_weight_decay * self.loss_scale
            else:
                weight_decay_ = weight_decay * self.loss_scale

            for key in group_param.keys():
                if key not in ('params', 'lr', 'weight_decay'):
                    logger.warning(f"The optimizer cannot parse '{key}' when setting parameter groups.")

            for param in group_param['params']:
                validator.check_value_type("parameter", param, [Parameter], self.cls_name)
                if param.name in params_store:
                    raise RuntimeError(f"The {param.name} parameter has appeared in parameter groups.")

                params_store.append(param.name)
                self.group_lr.append(lr)
                self.group_weight_decay.append(weight_decay_)

        if self.is_group_params_ordered:
            self._order_and_adjust_group_params(ordered_parameters)

    def _order_and_adjust_group_params(self, ordered_parameters):
        """
        Order group parameter, learning rate and weight decay in group params.
        """
        params_length = len(self.group_params)
        if len(ordered_parameters) != len(self.group_params):
            raise ValueError("The value of 'order_params' should be the same as all group parameters.")

        ordered_params = [None] * params_length
        ordered_learning_rate = [None] * params_length
        ordered_weight_decay = [None] * params_length
        params_name = [param.name for param in ordered_parameters]

        for param, lr, wd in zip(self.group_params, self.group_lr, self.group_weight_decay):
            index = params_name.index(param.name)
            ordered_params[index] = param
            ordered_learning_rate[index] = lr
            ordered_weight_decay[index] = wd

        self.group_params = ordered_params
        self.group_lr = ordered_learning_rate
        self.group_weight_decay = ordered_weight_decay

    def get_lr(self):
        """
        Get the learning rate of the current step.

        Returns:
            float, the learning rate of the current step.
        """
        lr = self.learning_rate
        if self.dynamic_lr:
            if self.is_group_lr:
                lr = ()
                for learning_rate in self.learning_rate:
                    current_dynamic_lr = learning_rate(self.global_step)
                    lr += (current_dynamic_lr,)
            else:
                lr = self.learning_rate(self.global_step)

            self.assignadd(self.global_step, self.global_step_increase_tensor)

        return lr
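
    # A minimal sketch of the dynamic case, assuming a 1-D learning rate such as
    # Tensor([0.1, 0.01, 0.001]): at global_step 0 this returns 0.1, at step 1 it
    # returns 0.01, and so on, and global_step is incremented by 1 after each call.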

    def get_lr_parameter(self, param):
        """
        Get the learning rate of the parameter.

        Args:
            param (Union[Parameter, list[Parameter]]): The `Parameter` or list of `Parameter`.

        Returns:
            Parameter, single `Parameter` or `list[Parameter]` according to the input type.
        """
        def get_lr_value(learning_rate):
            if isinstance(learning_rate, (_ConvertToCell, _IteratorLearningRate)):
                return learning_rate.learning_rate

            return learning_rate

        if isinstance(param, Parameter):
            param_list = [param]
        elif isinstance(param, list):
            param_list = param
        else:
            raise TypeError("The parameter only supports 'Parameter' or 'list' type.")

        lr = []
        ids = [id(p) for p in self.parameters]
        for p in param_list:
            validator.check_value_type("parameter", p, [Parameter], self.cls_name)
            if id(p) not in ids:
                raise ValueError(f"The parameter {p.name} is not in the optimizer.")
            if self.is_group_lr:
                index = ids.index(id(p))
                lr.append(get_lr_value(self.learning_rate[index]))
            else:
                lr.append(get_lr_value(self.learning_rate))

        return lr if isinstance(param, list) else lr[0]

    def _get_parameter_group_id(self):
        """
        Get the parameter partition group id, which is less than the number of devices.

        Returns:
            tuple, the group id tuple of parameters.
        """
        rank_list = ()
        count = 0
        for _ in range(self.param_length):
            rank_list = rank_list + (count,)
            count = count + 1
            if count == self.dev_num:
                count = 0
        return rank_list
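
    # A minimal sketch of the round-robin assignment above, assuming param_length = 5
    # and dev_num = 2: the returned rank_list is (0, 1, 0, 1, 0), so parameters 0, 2
    # and 4 are updated on rank 0 and parameters 1 and 3 on rank 1.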

    def broadcast_params(self, optim_result):
        """
        Apply Broadcast operations in the sequential order of parameter groups.

        Returns:
            list, the broadcast parameters of each group.
        """
        param_group = []
        key_group = []
        for _ in range(self.dev_num):
            param_group.append(F.make_tuple())
            key_group.append(F.make_tuple())
        for i in range(self.param_length):
            param_group[self.param_rank[i]] = param_group[self.param_rank[i]] + (self.parameters[i],)
            key = P.MakeRefKey(self.param_names[i])()
            key_group[self.param_rank[i]] = key_group[self.param_rank[i]] + (key,)
        new_param_group = []
        for root in range(self.dev_num):
            ops = P.Broadcast(root)
            next_params = ops(param_group[root])
            new_param_group.append(next_params)
            for i in range(F.tuple_len(next_params)):
                F.assign(key_group[root][i], next_params[i])
        return new_param_group

    def construct(self, *hyper_params):
        raise NotImplementedError


op_add = P.AddN()
op_gather = P.Gather()
op_mul = P.Mul()

_apply_decay = C.MultitypeFuncGraph("apply_decay")


@_apply_decay.register("Tensor", "Bool", "Tensor", "RowTensor")
def _tensor_apply_decay_with_sparse(weight_decay, if_apply, weight, gradient):
    """Get grad with weight_decay."""
    if if_apply:
        indices = gradient.indices
        values = op_add((op_gather(weight, indices, 0) * F.cast(weight_decay, F.dtype(weight)), gradient.values))
        shape = gradient.dense_shape
        return RowTensor(indices, values, shape)
    return gradient


@_apply_decay.register("Tensor", "Bool", "Tensor", "Tensor")
def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
    """Get grad with weight_decay."""
    if if_apply:
        return op_add((op_mul(weight, F.cast(weight_decay, F.dtype(weight))), gradient))
    return gradient


_grad_scale = C.MultitypeFuncGraph("grad_scale")
_indices_deduplicate = C.MultitypeFuncGraph("indices_deduplicate")


@_grad_scale.register("Number", "Tensor")
def tensor_grad_scale(scale, grad):
    """Get grad with scale."""
    if scale == 1.0:
        return grad
    return op_mul(grad, F.cast(scale, F.dtype(grad)))


@_grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale_with_tensor(scale, grad):
    """Get grad with scale."""
    return op_mul(grad, F.cast(scale, F.dtype(grad)))


@_grad_scale.register("Tensor", "RowTensor")
def tensor_grad_scale_with_sparse(scale, grad):
    """Get grad with scale."""
    return RowTensor(grad.indices, grad.values * F.cast(scale, F.dtype(grad.values)), grad.dense_shape)


@_indices_deduplicate.register("RowTensor")
def rowtensor_deduplicate_indices_slices(grad):
    """Unique the indices and sum the 'values' corresponding to the duplicate indices."""
    indices = grad.indices
    values = grad.values

    unique_indices, index_position = P.Unique()(indices)
    summed_values = P.UnsortedSegmentSum()(values, index_position, P.DynamicShape()(unique_indices)[0])

    return RowTensor(unique_indices, summed_values, grad.dense_shape)
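

# A minimal sketch of the deduplication above, assuming indices = [0, 2, 0] with
# row values [v0, v1, v2]: Unique yields unique_indices = [0, 2], and
# UnsortedSegmentSum produces values [v0 + v2, v1], so duplicate rows are merged.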


@_indices_deduplicate.register("Tensor")
def tensor_deduplicate_indice_slices(grad):
    """Return the input gradient directly in the dense scenes."""
    return grad


class _ConvertToCell(LearningRateSchedule):
    """Inner api, convert learning rate of scalar to LearningRateSchedule."""
    def __init__(self, learning_rate):
        super(_ConvertToCell, self).__init__()
        if not isinstance(learning_rate, Parameter):
            raise TypeError('Learning rate must be Parameter.')
        self.learning_rate = learning_rate

    def construct(self, global_step):
        return self.learning_rate + 1.0 - 1.0


class _IteratorLearningRate(LearningRateSchedule):
    """Inner api, convert learning rate of Tensor(list) to LearningRateSchedule."""
    def __init__(self, learning_rate, name):
        super(_IteratorLearningRate, self).__init__()
        if isinstance(learning_rate, Tensor):
            if learning_rate.ndim != 1:
                raise ValueError("The dim of `Tensor` type dynamic learning rate should be 1, "
                                 f"but got {learning_rate.ndim}.")
        else:
            raise TypeError("Learning rate should be Tensor.")

        self.learning_rate = Parameter(learning_rate, name)
        self.gather = P.Gather()

    def construct(self, global_step):
        return self.gather(self.learning_rate, global_step, 0)
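

# A minimal sketch of _IteratorLearningRate, assuming learning_rate = Tensor([0.1, 0.01])
# and global_step = 1: the Gather op returns learning_rate[1] = 0.01, i.e. the per-step
# value of the dynamic learning rate schedule.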