You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

sgd.py 5.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """sgd"""
  16. from mindspore.ops import functional as F, composite as C, operations as P
  17. from mindspore.common.parameter import Parameter
  18. from mindspore.common.tensor import Tensor
  19. import mindspore.common.dtype as mstype
  20. from mindspore._checkparam import Validator as validator
  21. from .optimizer import Optimizer
  22. sgd_opt = C.MultitypeFuncGraph("sgd_opt")
  23. @sgd_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
  24. def _tensor_run_opt_ext(opt, learning_rate, momentum, gradient, weight, accum, stat):
  25. """Apply sgd optimizer to the weight parameter using Tensor."""
  26. success = True
  27. success = F.depend(success, opt(weight, gradient, learning_rate, accum, momentum, stat))
  28. return success
  29. class SGD(Optimizer):
  30. """
  31. Implements stochastic gradient descent (optionally with momentum).
  32. Introduction to SGD can be found at https://en.wikipedia.org/wiki/Stochastic_gradient_descent.
  33. Nesterov momentum is based on the formula from paper `On the importance of initialization and
  34. momentum in deep learning <http://proceedings.mlr.press/v28/sutskever13.html>`_.
  35. Args:
  36. params (list[Parameter]): A list of parameter, which will be updated. The element in `params`
  37. should be class mindspore.Parameter.
  38. learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is
  39. Iterable or a Tensor and the dims of the Tensor is 1,
  40. use dynamic learning rate, then the i-th step will
  41. take the i-th value as the learning rate.
  42. When the learning_rate is float or learning_rate is a Tensor
  43. but the dims of the Tensor is 0, use fixed learning rate.
  44. Other cases are not supported. Default: 0.1.
  45. momentum (float): A floating point value the momentum. Default: 0.
  46. dampening (float): A floating point value of dampening for momentum. Default: 0.
  47. weight_decay (float): Weight decay (L2 penalty). Default: 0.
  48. nesterov (bool): Enables the Nesterov momentum. Default: False.
  49. loss_scale (float): A floating point value for the loss scale, which should be larger
  50. than 0.0. Default: 1.0.
  51. Inputs:
  52. - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
  53. Outputs:
  54. Tensor[bool], the value is True.
  55. Raises:
  56. ValueError: If the momentum, dampening or weight_decay value is less than 0.0.
  57. Examples:
  58. >>> net = Net()
  59. >>> loss = nn.SoftmaxCrossEntropyWithLogits()
  60. >>> optim = nn.SGD(params=net.trainable_params())
  61. >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
  62. """
  63. def __init__(self, params, learning_rate=0.1, momentum=0.0, dampening=0.0, weight_decay=0.0, nesterov=False,
  64. loss_scale=1.0):
  65. super(SGD, self).__init__(learning_rate, params, weight_decay, loss_scale)
  66. if not isinstance(momentum, float):
  67. raise TypeError("momentum should be float number!")
  68. if isinstance(momentum, float) and momentum < 0.0:
  69. raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
  70. if not isinstance(dampening, float):
  71. raise TypeError("dampening should be float number")
  72. if isinstance(dampening, int):
  73. dampening = float(dampening)
  74. if dampening < 0.0:
  75. raise ValueError("dampening should be at least 0.0, but got dampening {}".format(dampening))
  76. self.dampening = dampening
  77. validator.check_value_type("nesterov", nesterov, [bool], self.cls_name)
  78. self.nesterov = nesterov
  79. self.opt = P.SGD(dampening, weight_decay, nesterov)
  80. self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
  81. self.accum = self.parameters.clone(prefix="accum", init='zeros')
  82. self.stat = self.parameters.clone(prefix="stat", init='ones')
  83. self.hyper_map = C.HyperMap()
  84. def construct(self, gradients):
  85. params = self.parameters
  86. accum = self.accum
  87. stat = self.stat
  88. gradients = self.decay_weight(gradients)
  89. gradients = self.scale_grad(gradients)
  90. lr = self.get_lr()
  91. success = self.hyper_map(F.partial(sgd_opt, self.opt, lr, self.momentum), gradients, params, accum, stat)
  92. return success