
optimizer.py 25 kB

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""optimizer"""
from typing import Iterable

import numpy as np

import mindspore
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.nn.cell import Cell
from mindspore.nn.layer.container import CellList
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common.initializer import initializer
from mindspore.common.tensor import Tensor, IndexedSlices
import mindspore.common.dtype as mstype
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from mindspore import log as logger
from mindspore.parallel._utils import _get_global_rank, _get_device_num, _get_parallel_mode
from mindspore.train.parallel_utils import ParallelMode
from mindspore import context
from mindspore.nn.learning_rate_schedule import LearningRateSchedule

__all__ = ['Optimizer']

class Optimizer(Cell):
    """
    Base class for all optimizers.

    This class defines the API to add Ops to train a model.

    Note:
        This class defines the API to add Ops to train a model. Never use
        this class directly, but instead instantiate one of its subclasses.

        Different parameter groups can set different `learning_rate` and `weight_decay`.

        When separating parameter groups, the weight decay in each group will be applied on the parameters if the
        weight decay is positive. For most optimizers, when not separating parameters, the `weight_decay` in the API
        will be applied on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.

        To improve performance when grouping parameters, a customized order of parameters is supported.

    Args:
        learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning
            rate. When the learning_rate is an Iterable or a Tensor with a dimension of 1, the learning rate is
            dynamic and the i-th step takes the i-th value as the learning rate. When the learning_rate is a
            LearningRateSchedule, the learning rate is dynamic and the i-th learning rate is calculated during
            training according to the formula of the LearningRateSchedule. When the learning_rate is a float or a
            Tensor with a dimension of 0, the learning rate is fixed. Other cases are not supported. A float
            learning rate should be equal to or greater than 0. If the type of `learning_rate` is int, it will be
            converted to float.
        parameters (Union[list[Parameter], list[dict]]): When the `parameters` is a list of `Parameter` which will
            be updated, the element in `parameters` should be class `Parameter`. When the `parameters` is a list of
            `dict`, the keys "params", "lr", "weight_decay" and "order_params" can be parsed.

            - params: Required. The value should be a list of `Parameter`.

            - lr: Optional. If "lr" is in the keys, the value of the corresponding learning rate will be used.
              If not, the `learning_rate` in the API will be used.

            - weight_decay: Optional. If "weight_decay" is in the keys, the value of the corresponding weight decay
              will be used. If not, the `weight_decay` in the API will be used.

            - order_params: Optional. If "order_params" is in the keys, the value should be the order of parameters
              and the order will be followed in the optimizer. There should be no other keys in the `dict`, and the
              parameters in the value of 'order_params' should be in one of the group parameters.

        weight_decay (float): A floating point value for the weight decay. It should be in range [0.0, 1.0].
            If the type of `weight_decay` input is int, it will be converted to float. Default: 0.0.
        loss_scale (float): A floating point value for the loss scale. It should be not less than 1.0. If the
            type of `loss_scale` input is int, it will be converted to float. Default: 1.0.

    Raises:
        ValueError: If the learning_rate is a Tensor, but the dimension of the tensor is greater than 1.
        TypeError: If the learning_rate is not any of the three types: float, Tensor, Iterable.
    """
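    # A minimal sketch (not from this file) of the group-parameter `dict` format described
    # above; `net`, `conv_params`, `no_conv_params` and `SubclassOptimizer` are hypothetical
    # names used only for illustration:
    #
    #     conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
    #     no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
    #     group_params = [{'params': conv_params, 'lr': 0.05, 'weight_decay': 0.01},
    #                     {'params': no_conv_params},
    #                     {'order_params': net.trainable_params()}]
    #     optim = SubclassOptimizer(learning_rate=0.1, parameters=group_params)
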
    def __init__(self, learning_rate, parameters, weight_decay=0.0, loss_scale=1.0):
        super(Optimizer, self).__init__(auto_prefix=False)
        if parameters and not isinstance(parameters, list):
            parameters = list(parameters)

        if not parameters:
            raise ValueError("Optimizer got an empty parameter list.")

        if not isinstance(parameters[0], (dict, Parameter)):
            raise TypeError("Only a list of Parameter or dict can be supported.")

        if isinstance(loss_scale, int):
            loss_scale = float(loss_scale)
        validator.check_value_type("loss_scale", loss_scale, [float], self.cls_name)
        validator.check_number_range("loss_scale", loss_scale, 1.0, float("inf"), Rel.INC_LEFT, self.cls_name)
        self.loss_scale = loss_scale

        weight_decay = self._preprocess_weight_decay(weight_decay)

        self.dynamic_lr = False
        self.assignadd = None
        self.global_step = None
        self.is_group = False
        self.is_group_lr = False
        self.is_group_params_ordered = False
        learning_rate = self._preprocess_single_lr(learning_rate)
        if isinstance(parameters[0], dict):
            self.is_group = True
            self.group_params = []
            self.group_lr = []
            self.group_weight_decay = []
            self._init_group_params(parameters, learning_rate, weight_decay)

        # The final value of dynamic_lr can be determined after the process of parse_single_lr and init_group_params
        if self.dynamic_lr:
            self.assignadd = P.AssignAdd()
            self.global_step = Parameter(initializer(0, [1], mindspore.int32), name='global_step')

        if self.is_group_lr:
            if self.dynamic_lr:
                self.learning_rate = CellList(self.group_lr)
            else:
                self.learning_rate = ParameterTuple(self.group_lr)
        else:
            self.learning_rate = self._build_single_lr(learning_rate, 'learning_rate')
        if self.is_group:
            self.parameters = ParameterTuple(self.group_params)
            self.weight_decay = tuple(self.group_weight_decay)
            decay_filter = lambda x: x > 0
            self.decay_flags = tuple(decay_filter(x) for x in self.weight_decay)
            self.exec_weight_decay = any(self.decay_flags)
        else:
            self.parameters = ParameterTuple(parameters)
            self.weight_decay = weight_decay * loss_scale
            decay_filter = lambda x: 'beta' not in x.name and 'gamma' not in x.name
            self.decay_flags = tuple(decay_filter(x) for x in self.parameters)
            self.exec_weight_decay = self.weight_decay > 0
        ps_filter = lambda x: x.is_param_ps
        self.ps_parameters = tuple(ps_filter(x) for x in self.parameters)
        self.reciprocal_scale = 1.0 / loss_scale
        self.param_length = len(self.parameters)
        self.map_ = C.Map()

        use_parallel = context.get_auto_parallel_context("enable_parallel_optimizer")
        self.use_parallel = use_parallel
        if use_parallel:
            if self.cls_name not in ["Lamb", "AdamWeightDecay"]:
                raise RuntimeError("Optimizer segmentation does not support optimizer {}".format(self.cls_name))
            if _get_parallel_mode() != ParallelMode.DATA_PARALLEL:
                raise RuntimeError("Optimizer segmentation does not support parallel mode {}".format(
                    _get_parallel_mode()))
            self.dev_num = _get_device_num()
            if self.dev_num > self.param_length:
                raise RuntimeError("Optimizer segmentation can not be applied when the number of parameters {} is"
                                   " less than the number of devices {}".format(self.param_length, self.dev_num))
            self.param_rank = self._get_parameter_group_id()
            self.optim_filter = tuple(map(lambda x: x == _get_global_rank(), self.param_rank))
            self.param_names = []
            for param in self.parameters:
                self.param_names.append(param.name)
        else:
            self.optim_filter = (True,) * self.param_length

    def decay_weight(self, gradients):
        """
        Weight decay.

        An approach to reduce the overfitting of a deep learning neural network model.

        Args:
            gradients (tuple[Tensor]): The gradients of `self.parameters`, which have the same shape as
                `self.parameters`.

        Returns:
            tuple[Tensor], the gradients after weight decay.
        """
        if self.exec_weight_decay:
            params = self.parameters
            if self.is_group:
                gradients = self.map_(F.partial(_apply_decay), self.weight_decay, self.decay_flags,
                                      params, gradients)
            else:
                gradients = self.map_(F.partial(_apply_decay, self.weight_decay), self.decay_flags,
                                      params, gradients)

        return gradients
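    # Illustrative note (not from this file): for dense gradients the registered
    # `_tensor_apply_decay` below computes new_grad = grad + weight_decay * weight;
    # e.g. with weight_decay=0.01, weight=2.0 and grad=0.1 the decayed grad is 0.12.
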
    def scale_grad(self, gradients):
        """
        Loss scale for mixed precision.

        An approach of mixed precision training to improve the speed and energy efficiency of training deep neural
        networks.

        Args:
            gradients (tuple[Tensor]): The gradients of `self.parameters`, which have the same shape as
                `self.parameters`.

        Returns:
            tuple[Tensor], the gradients after loss scale.
        """
        if self.reciprocal_scale != 1.0:
            gradients = self.map_(F.partial(_grad_scale, self.reciprocal_scale), gradients)

        return gradients
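    # Illustrative note (not from this file): with loss_scale=1024.0 the constructor sets
    # reciprocal_scale = 1.0 / 1024.0, so a scaled gradient of 10.24 is mapped back to
    # roughly 0.01 here, undoing the loss scaling that was applied to the forward loss.
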
    def _preprocess_weight_decay(self, weight_decay):
        """Check weight decay, and convert int to float."""
        if isinstance(weight_decay, (float, int)):
            weight_decay = float(weight_decay)
            validator.check_number_range("weight_decay", weight_decay, 0.0, 1.0, Rel.INC_BOTH, self.cls_name)
            return weight_decay

        raise TypeError("Weight decay should be int or float.")

    def _preprocess_single_lr(self, learning_rate):
        """Check lr value, and convert lr to a float, a Tensor or a LearningRateSchedule."""
        if isinstance(learning_rate, (float, int)):
            learning_rate = float(learning_rate)
            validator.check_number_range("learning rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT,
                                         self.cls_name)
            return learning_rate
        if isinstance(learning_rate, Tensor) and learning_rate.dim() == 0:
            return learning_rate

        self.dynamic_lr = True
        if isinstance(learning_rate, Iterable):
            return Tensor(np.array(list(learning_rate)).astype(np.float32))
        if isinstance(learning_rate, Tensor):
            if learning_rate.dim() > 1:
                raise ValueError("The dim of `Tensor` type learning rate should be 0 or 1, "
                                 f"but got {learning_rate.dim()}.")
            if learning_rate.dim() == 1 and learning_rate.size() < 2:
                logger.warning("If a `Tensor` type dynamic learning rate is used, please make sure that the "
                               "number of elements in the tensor is greater than 1.")
            return learning_rate
        if isinstance(learning_rate, LearningRateSchedule):
            return learning_rate

        raise TypeError("Learning rate should be int, float, Tensor, Iterable or LearningRateSchedule.")
    def _build_single_lr(self, learning_rate, name):
        """Build learning rate value, convert learning rate to a Parameter or a LearningRateSchedule."""
        if isinstance(learning_rate, float):
            learning_rate = Parameter(Tensor(learning_rate, mstype.float32), name)
            if self.is_group_lr and self.dynamic_lr:
                learning_rate = _ConvertToCell(learning_rate)
            return learning_rate
        if isinstance(learning_rate, Tensor) and learning_rate.dim() == 0:
            learning_rate = Parameter(learning_rate, name)
            if self.is_group_lr and self.dynamic_lr:
                learning_rate = _ConvertToCell(learning_rate)
            return learning_rate
        if isinstance(learning_rate, Tensor) and learning_rate.dim() == 1:
            return _IteratorLearningRate(learning_rate, name)
        return learning_rate

    def _check_group_params(self, parameters):
        """Check group params."""
        parse_keys = ['params', 'lr', 'weight_decay', 'order_params']
        for group_param in parameters:
            invalid_key = list(filter(lambda x: x not in parse_keys, group_param.keys()))
            if invalid_key:
                raise KeyError(f'The key "{invalid_key}" cannot be recognized in group params.')

            if 'order_params' in group_param.keys():
                if len(group_param.keys()) > 1:
                    raise ValueError("The order params dict in group parameters should "
                                     "only include the 'order_params' key.")
                if not isinstance(group_param['order_params'], Iterable):
                    raise TypeError("The value of 'order_params' should be an Iterable type.")
                continue

            if not group_param['params']:
                raise ValueError("Optimizer got an empty group parameter list.")

            for param in group_param['params']:
                if not isinstance(param, Parameter):
                    raise TypeError("The group param should be an iterator of Parameter type.")
    def _parse_group_params(self, parameters, learning_rate):
        """Parse group params."""
        self._check_group_params(parameters)
        if isinstance(learning_rate, Tensor) and learning_rate.dim() == 1:
            tensor_lr_length = learning_rate.size()
        else:
            tensor_lr_length = 0

        for group_param in parameters:
            if 'order_params' in group_param.keys():
                if len(group_param.keys()) > 1:
                    raise ValueError("The order params dict in group parameters should "
                                     "only include the 'order_params' key.")
                if not isinstance(group_param['order_params'], Iterable):
                    raise TypeError("The value of 'order_params' should be an Iterable type.")
                self.is_group_params_ordered = True
                continue

            if 'lr' in group_param.keys():
                self.is_group_lr = True
                group_lr = self._preprocess_single_lr(group_param['lr'])

                if isinstance(group_lr, Tensor) and group_lr.dim() == 1:
                    group_lr_length = group_lr.size()
                    if tensor_lr_length == 0:
                        tensor_lr_length = group_lr_length
                    elif group_lr_length != tensor_lr_length:
                        raise ValueError("The Tensor type dynamic learning rate in group should be the same size.")

    def _init_group_params(self, parameters, learning_rate, weight_decay):
        """Init learning rate or weight decay in group params."""
        self._parse_group_params(parameters, learning_rate)
        default_lr = self._build_single_lr(learning_rate, 'learning_rate')

        params_store = []
        for group_num, group_param in enumerate(parameters):
            if 'order_params' in group_param.keys():
                ordered_parameters = group_param['order_params']
                continue

            self.group_params += group_param['params']

            if 'lr' in group_param.keys():
                lr_param_name = 'learning_rate_group_' + str(group_num)
                lr = self._preprocess_single_lr(group_param['lr'])
                lr = self._build_single_lr(lr, lr_param_name)
            else:
                lr = default_lr

            if 'weight_decay' in group_param.keys():
                cur_weight_decay = self._preprocess_weight_decay(group_param['weight_decay'])
                weight_decay_ = cur_weight_decay * self.loss_scale
            else:
                weight_decay_ = weight_decay * self.loss_scale

            for key in group_param.keys():
                if key not in ('params', 'lr', 'weight_decay'):
                    logger.warning(f"The optimizer cannot parse '{key}' when setting parameter groups.")

            for param in group_param['params']:
                validator.check_value_type("parameter", param, [Parameter], self.cls_name)
                if param.name in params_store:
                    raise RuntimeError(f"The {param.name} parameter has appeared in parameter groups.")

                params_store.append(param.name)
                self.group_lr.append(lr)
                self.group_weight_decay.append(weight_decay_)

        if self.is_group_params_ordered:
            self._order_and_adjust_group_params(ordered_parameters)
    def _order_and_adjust_group_params(self, ordered_parameters):
        """
        Order group parameter, learning rate and weight decay in group params.
        """
        params_length = len(self.group_params)
        if len(ordered_parameters) != len(self.group_params):
            raise ValueError("The value of 'order_params' should be the same as all group parameters.")

        ordered_params = [None] * params_length
        ordered_learning_rate = [None] * params_length
        ordered_weight_decay = [None] * params_length
        params_name = [param.name for param in ordered_parameters]
        for param, lr, wd in zip(self.group_params, self.group_lr, self.group_weight_decay):
            index = params_name.index(param.name)
            ordered_params[index] = param
            ordered_learning_rate[index] = lr
            ordered_weight_decay[index] = wd

        self.group_params = ordered_params
        self.group_lr = ordered_learning_rate
        self.group_weight_decay = ordered_weight_decay
    def get_lr(self):
        """
        Get the learning rate of the current step.

        Returns:
            float, the learning rate of the current step.
        """
        lr = self.learning_rate
        if self.dynamic_lr:
            if self.is_group_lr:
                lr = ()
                for learning_rate in self.learning_rate:
                    current_dynamic_lr = learning_rate(self.global_step)
                    lr += (current_dynamic_lr,)
            else:
                lr = self.learning_rate(self.global_step)

            F.control_depend(lr, self.assignadd(self.global_step, 1))

        return lr
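    # Illustrative note (not from this file): with a dynamic learning rate such as
    # Tensor(np.array([0.1, 0.01, 0.001]).astype(np.float32)), get_lr() gathers the entry
    # at index `global_step` and then increments `global_step`, so successive steps see
    # 0.1, 0.01 and 0.001 in turn.
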
    def get_lr_parameter(self, param):
        """
        Get the learning rate of the given parameter.

        Args:
            param (Union[Parameter, list[Parameter]]): The `Parameter` or list of `Parameter`.

        Returns:
            Parameter, a single `Parameter` or a `list[Parameter]` according to the input type.
        """
        def get_lr_value(learning_rate):
            if isinstance(learning_rate, (_ConvertToCell, _IteratorLearningRate)):
                return learning_rate.learning_rate

            return learning_rate

        if isinstance(param, Parameter):
            param_list = [param]
        elif isinstance(param, list):
            param_list = param
        else:
            raise TypeError("The parameter only supports 'Parameter' or 'list' type.")

        lr = []
        for p in param_list:
            validator.check_value_type("parameter", p, [Parameter], self.cls_name)
            if p not in self.parameters:
                raise ValueError(f"The parameter {p.name} is not in the optimizer.")
            if self.is_group_lr:
                index = self.parameters.index(p)
                lr.append(get_lr_value(self.learning_rate[index]))
            else:
                lr.append(get_lr_value(self.learning_rate))

        return lr if isinstance(param, list) else lr[0]
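    # Usage sketch (assumed, not from this file), where `optim` is an optimizer instance and
    # `net` is a hypothetical network:
    #
    #     optim.get_lr_parameter(net.trainable_params()[0])     # single Parameter -> single value
    #     optim.get_lr_parameter(list(net.trainable_params()))  # list -> list of values
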
    def _get_parameter_group_id(self):
        """
        Get the parameter partition group id, which is less than the number of devices.

        Returns:
            tuple, the group id tuple of parameters.
        """
        rank_list = ()
        count = 0
        for _ in range(self.param_length):
            rank_list = rank_list + (count,)
            count = count + 1
            if count == self.dev_num:
                count = 0
        return rank_list
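    # Illustrative note (not from this file): with param_length=5 and dev_num=2 this
    # returns (0, 1, 0, 1, 0), i.e. parameters are assigned to devices round-robin.
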
    def broadcast_params(self, optim_result):
        """
        Apply Broadcast operations in the sequential order of parameter groups.

        Returns:
            bool, the status flag.
        """
        param_group = []
        key_group = []
        for _ in range(self.dev_num):
            param_group.append(F.make_tuple())
            key_group.append(F.make_tuple())
        for i in range(self.param_length):
            param_group[self.param_rank[i]] = param_group[self.param_rank[i]] + (optim_result[i],)
            key = P.MakeRefKey(self.param_names[i])()
            key_group[self.param_rank[i]] = key_group[self.param_rank[i]] + (key,)
        new_param_group = []
        for root in range(self.dev_num):
            ops = P.Broadcast(root)
            next_params = ops(param_group[root])
            new_param_group.append(next_params)
            for i in range(F.tuple_len(next_params)):
                F.assign(key_group[root][i], next_params[i])
        status = True
        for i in range(self.dev_num - 1):
            status = F.control_depend(new_param_group[i][0], new_param_group[i + 1])

        return status

    def construct(self, *hyper_params):
        raise NotImplementedError


op_add = P.AddN()
op_gather = P.GatherV2()

_apply_decay = C.MultitypeFuncGraph("apply_decay")


@_apply_decay.register("Number", "Bool", "Tensor", "IndexedSlices")
def _tensor_apply_decay_with_sparse(weight_decay, if_apply, weight, gradient):
    """Get grad with weight_decay."""
    if if_apply:
        indices = gradient.indices()
        values = op_add((op_gather(weight, indices, 0) * weight_decay, gradient.values()))
        shape = gradient.dense_shape()
        return IndexedSlices(indices, values, shape)
    return gradient


@_apply_decay.register("Number", "Bool", "Tensor", "Tensor")
def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
    """Get grad with weight_decay."""
    if if_apply:
        return op_add((weight * weight_decay, gradient))
    return gradient


_grad_scale = C.MultitypeFuncGraph("grad_scale")


@_grad_scale.register("Number", "Tensor")
def tensor_grad_scale(scale, grad):
    """Get grad with scale."""
    if scale == 1.0:
        return grad
    return grad * scale


@_grad_scale.register("Number", "IndexedSlices")
def tensor_grad_scale_with_sparse(scale, grad):
    """Get grad with scale."""
    if scale == 1.0:
        return grad
    return IndexedSlices(grad.indices(), grad.values() * scale, grad.dense_shape())


class _ConvertToCell(LearningRateSchedule):
    """Inner api, convert a scalar learning rate to a LearningRateSchedule."""
    def __init__(self, learning_rate):
        super(_ConvertToCell, self).__init__()
        if not isinstance(learning_rate, Parameter):
            raise TypeError('Learning rate must be Parameter.')
        self.learning_rate = learning_rate

    def construct(self, global_step):
        return self.learning_rate + 1.0 - 1.0


class _IteratorLearningRate(LearningRateSchedule):
    """Inner api, convert a Tensor (list) learning rate to a LearningRateSchedule."""
    def __init__(self, learning_rate, name):
        super(_IteratorLearningRate, self).__init__()
        if isinstance(learning_rate, Tensor):
            if learning_rate.dim() != 1:
                raise ValueError("The dim of `Tensor` type dynamic learning rate should be 1, "
                                 f"but got {learning_rate.dim()}.")
        else:
            raise TypeError("Learning rate should be Tensor.")

        self.learning_rate = Parameter(learning_rate, name)
        self.gather = P.GatherV2()

    def construct(self, global_step):
        return self.gather(self.learning_rate, global_step, 0)
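

# Minimal usage sketch (illustrative only, not part of this file). A concrete optimizer is
# expected to subclass Optimizer and implement construct(); `_sgd_update` is a hypothetical
# update function used only to show where decay_weight, scale_grad and get_lr fit in:
#
#     class _PlainSGD(Optimizer):
#         def __init__(self, params, learning_rate=0.1, weight_decay=0.0, loss_scale=1.0):
#             super(_PlainSGD, self).__init__(learning_rate, params, weight_decay, loss_scale)
#             self.hyper_map = C.HyperMap()
#
#         def construct(self, gradients):
#             gradients = self.decay_weight(gradients)
#             gradients = self.scale_grad(gradients)
#             lr = self.get_lr()
#             return self.hyper_map(F.partial(_sgd_update, lr), self.parameters, gradients)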