
sgd.py 3.3 kB

# -*- coding: utf-8 -*-
import os
from typing import Iterable, Union

from ..functional.inplace import _inplace_add_
from ..tensor import Parameter, tensor
from .optimizer import Optimizer


class SGD(Optimizer):
    r"""Implements stochastic gradient descent.

    Nesterov momentum is based on the formula from
    `"On the importance of initialization and momentum in deep learning" <http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf>`_ .

    Args:
        params: iterable of parameters to optimize or dicts defining
            parameter groups.
        lr: learning rate.
        momentum: momentum factor. Default: 0.0
        nesterov: enables Nesterov momentum. Default: False
        weight_decay: weight decay (L2 penalty). Default: 0.0
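
    Given learning rate :math:`\eta`, momentum :math:`\mu` and weight decay
    :math:`\lambda`, the update performed by ``_updates`` below for each
    parameter :math:`w` with gradient :math:`g` is:

    .. math::

        g &\leftarrow g + \lambda w \\
        v &\leftarrow \mu v + g \\
        w &\leftarrow w - \eta \, (g + \mu v) \quad \text{(nesterov)} \\
        w &\leftarrow w - \eta \, v \quad \text{(otherwise)}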
    """
    def __init__(
        self,
        params: Union[Iterable[Parameter], dict],
        lr: float,
        momentum: float = 0.0,
        nesterov: bool = False,
        weight_decay: float = 0.0,
    ):
        assert lr >= 0.0, "Invalid learning rate: {}".format(lr)
        assert momentum >= 0.0, "Invalid momentum value: {}".format(momentum)
        assert weight_decay >= 0.0, "Invalid weight_decay value: {}".format(
            weight_decay
        )
        assert not nesterov or momentum > 0.0, "Nesterov momentum requires a positive momentum"

        defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay)
        super().__init__(params, defaults)
        self.nesterov = nesterov
        # Skip automatic scalar-to-tensor conversion during parameter updates;
        # scalars are converted explicitly in ``_updates``.
        self._disable_type_convert = True

    def _create_state(self, param_group):
        # A momentum buffer per parameter is only needed when momentum is enabled.
        if param_group["momentum"] != 0.0:
            for param in param_group["params"]:
                self._add_state(param, "momentum_buffer")

    def _updates(self, param_group):
        lr = param_group["lr"]
        weight_decay = param_group["weight_decay"]
        momentum = param_group["momentum"]

        # Since ``convert_inputs`` is disabled for parameter updates,
        # scalars must be explicitly converted to tensors.
        _lr = tensor(lr, dtype="float32")
        _weight_decay = tensor(weight_decay, dtype="float32")
        _momentum = tensor(momentum, dtype="float32")

        inplace_mode = int(os.getenv("MEGENGINE_INPLACE_UPDATE", "0"))
        if inplace_mode:
            _neg_lr = tensor(-lr, dtype="float32")
            c1 = tensor(1.0)
        for param in param_group["params"]:
            if param.grad is None:
                continue

            grad = param.grad
            if weight_decay != 0.0:
                grad = grad + param * _weight_decay

            if inplace_mode:
                if momentum != 0.0:
                    v = self._state[param]["momentum_buffer"]
                    # v = momentum * v + grad, computed in place
                    _inplace_add_(v, grad, alpha=_momentum, beta=c1)
                    if self.nesterov:
                        grad = grad + v * _momentum
                    else:
                        grad = v
                # param = param - lr * grad, computed in place
                _inplace_add_(param, grad, alpha=c1, beta=_neg_lr)
                continue

            if momentum != 0.0:
                v = self._state[param]["momentum_buffer"]
                v *= _momentum
                v += grad
                if self.nesterov:
                    grad = grad + v * _momentum
                else:
                    grad = v
            param -= _lr * grad
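
A minimal usage sketch of this optimizer. It assumes the public MegEngine training API (megengine.autodiff.GradManager, megengine.module, megengine.functional.loss.cross_entropy); everything outside SGD itself is part of that assumed API, not of this file:

import numpy as np

import megengine as mge
import megengine.functional as F
import megengine.module as M
from megengine.autodiff import GradManager
from megengine.optimizer import SGD

model = M.Linear(4, 2)
gm = GradManager().attach(model.parameters())
opt = SGD(model.parameters(), lr=0.01, momentum=0.9, nesterov=True, weight_decay=1e-4)

x = mge.tensor(np.random.randn(8, 4).astype("float32"))
y = mge.tensor(np.random.randint(2, size=(8,)).astype("int32"))

for _ in range(3):
    with gm:
        loss = F.loss.cross_entropy(model(x), y)
        gm.backward(loss)      # fills param.grad for attached parameters
    opt.step().clear_grad()    # one SGD update, then reset gradients

Setting the environment variable MEGENGINE_INPLACE_UPDATE=1 makes _updates take the _inplace_add_ branch above, so the momentum buffer and the parameters are modified in place rather than rebound to new tensors.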