You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

ftrl.py 6.9 kB

6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """FTRL"""
  16. from mindspore.ops import functional as F, composite as C, operations as P
  17. from mindspore.common import Tensor
  18. import mindspore.common.dtype as mstype
  19. from mindspore.ops.operations import _inner_ops as inner
  20. from mindspore._checkparam import Validator as validator
  21. from mindspore._checkparam import Rel
  22. from .optimizer import Optimizer, _apply_decay, _grad_scale
  23. _ftrl_opt = C.MultitypeFuncGraph("ftrl_opt")
  24. @_ftrl_opt.register("Function", "Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tuple", "Tensor",
  25. "Tensor")
  26. def _tensor_run_opt_with_sparse(opt, spars_opt, learning_rate, l1, l2, lr_power, linear, gradient, weight, moment):
  27. """Apply sparse ftrl optimizer to the weight parameter when the gradient is sparse."""
  28. success = True
  29. success = F.depend(success, spars_opt(weight, moment, linear, gradient[1], gradient[0]))
  30. return success
  31. @_ftrl_opt.register("Function", "Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor",
  32. "Tensor")
  33. def _tensor_run_opt(opt, spars_opt, learning_rate, l1, l2, lr_power, linear, gradient, weight, moment):
  34. """Apply ftrl optimizer to the weight parameter."""
  35. success = True
  36. success = F.depend(success, opt(weight, moment, linear, gradient, learning_rate, l1, l2, lr_power))
  37. return success
  38. def _check_param(initial_accum, lr_power, l1, l2, use_locking, weight_decay=0.0, prim_name=None):
  39. """Check param."""
  40. validator.check_value_type("initial_accum", initial_accum, [float], prim_name)
  41. validator.check_number("initial_accum", initial_accum, 0.0, Rel.GE, prim_name)
  42. validator.check_value_type("lr_power", lr_power, [float], prim_name)
  43. validator.check_number("lr_power", lr_power, 0.0, Rel.LE, prim_name)
  44. validator.check_value_type("l1", l1, [float], prim_name)
  45. validator.check_number("l1", l1, 0.0, Rel.GE, prim_name)
  46. validator.check_value_type("l2", l2, [float], prim_name)
  47. validator.check_number("l2", l2, 0.0, Rel.GE, prim_name)
  48. validator.check_value_type("use_locking", use_locking, [bool], prim_name)
  49. validator.check_value_type("weight_decay", weight_decay, [float], prim_name)
  50. validator.check_number("weight_decay", weight_decay, 0.0, Rel.GE, prim_name)
  51. class FTRL(Optimizer):
  52. """
  53. Implement the FTRL algorithm with ApplyFtrl Operator.
  54. FTRL is an online convex optimization algorithm that adaptively chooses its regularization function
  55. based on the loss functions. Refer to paper `Adaptive Bound Optimization for Online Convex Optimization
  56. <https://arxiv.org/abs/1002.4908>`_. Refer to paper `Ad Click Prediction: a View from the Trenches
  57. <https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf>`_ for engineering document.
  58. Note:
  59. The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the
  60. `sparse_grad` of `Parameter` being set. The sparse feature is under continuous development. The sparse
  61. behavior is currently performed on the CPU.
  62. Args:
  63. params (list[Parameter]): A list of parameter, which will be updated. The element in `params`
  64. should be Parameter.
  65. initial_accum (float): The starting value for accumulators, must be zero or positive values. Default: 0.1.
  66. learning_rate (float): The learning rate value, should be positive. Default: 0.001.
  67. lr_power (float): Learning rate power controls how the learning rate decreases during training, must be less
  68. than or equal to zero. Use fixed learning rate if lr_power is zero. Default: -0.5.
  69. l1 (float): l1 regularization strength, must be greater than or equal to zero. Default: 0.0.
  70. l2 (float): l2 regularization strength, must be greater than or equal to zero. Default: 0.0.
  71. use_locking (bool): If True use locks for update operation. Default: False.
  72. loss_scale (float): Value for the loss scale. It should be equal to or greater than 1.0. Default: 1.0.
  73. wegith_decay (float): Weight decay value to multiply weight, must be zero or positive value. Default: 0.0.
  74. Inputs:
  75. - **grads** (tuple[Tensor]) - The gradients of `params` in optimizer, the shape is as same as the `params`
  76. in optimizer.
  77. Outputs:
  78. tuple[Parameter], the updated parameters, the shape is the same as `params`.
  79. Examples:
  80. >>> net = Net()
  81. >>> loss = nn.SoftmaxCrossEntropyWithLogits()
  82. >>> opt = nn.FTRL(net.trainable_params())
  83. >>> model = Model(net, loss_fn=loss, optimizer=opt, metrics=None)
  84. """
  85. def __init__(self, params, initial_accum=0.1, learning_rate=0.001, lr_power=-0.5, l1=0.0, l2=0.0,
  86. use_locking=False, loss_scale=1.0, weight_decay=0.0):
  87. super(FTRL, self).__init__(learning_rate, params, loss_scale=loss_scale)
  88. if self.is_group:
  89. raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.")
  90. _check_param(initial_accum, lr_power, l1, l2, use_locking, weight_decay, self.cls_name)
  91. self.moments = self.parameters.clone(prefix="moments", init=initial_accum)
  92. self.linear = self.parameters.clone(prefix="linear", init='zeros')
  93. self.l1 = l1
  94. self.l2 = l2
  95. self.lr_power = lr_power
  96. self.weight_decay = weight_decay
  97. self.decay_tf = tuple((lambda: True)() for x in self.parameters)
  98. self.hyper_map = C.HyperMap()
  99. self.opt = P.ApplyFtrl(use_locking=use_locking)
  100. self.sparse_opt = inner.SparseApplyFtrlNoReturn(learning_rate, l1, l2, lr_power, use_locking=use_locking)
  101. def construct(self, grads):
  102. params = self.parameters
  103. moments = self.moments
  104. linear = self.linear
  105. lr = self.learning_rate
  106. if self.weight_decay > 0.0:
  107. grads = self.map_(F.partial(_apply_decay, self.weight_decay), self.decay_tf, params, grads)
  108. grads = self.scale_grad(grads)
  109. success = self.map_(F.partial(_ftrl_opt, self.opt, self.sparse_opt, lr, self.l1, self.l2, self.lr_power),
  110. linear, grads, params, moments)
  111. return success