
optimizer.py 7.6 kB

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""optimizer"""
from typing import Iterable

import numpy as np

import mindspore
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.nn.cell import Cell
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common.initializer import initializer
import mindspore.common.dtype as mstype
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from mindspore.common.tensor import Tensor
from mindspore import log as logger

__all__ = ['Optimizer']


class Optimizer(Cell):
    """
    Base class for all optimizers.

    This class defines the API to add Ops to train a model.

    Note:
        This class defines the API to add Ops to train a model. Never use
        this class directly, but instead instantiate one of its subclasses.

    Args:
        learning_rate (float): A floating point value for the learning rate. Should be greater than 0.
        parameters (list): A list of parameters to be updated. The elements in `parameters`
            should be of class mindspore.Parameter.
        weight_decay (float): A floating point value for the weight decay. If the type of `weight_decay`
            input is int, it will be converted to float. Default: 0.0.
        loss_scale (float): A floating point value for the loss scale. It should be greater than 0. If the
            type of `loss_scale` input is int, it will be converted to float. Default: 1.0.
        decay_filter (Function): A function to determine whether to apply weight decay on a parameter.
            Default: lambda x: 'beta' not in x.name and 'gamma' not in x.name.

    Raises:
        ValueError: If `learning_rate` is a Tensor, but the dimension of the tensor is greater than 1.
        TypeError: If `learning_rate` is not any of the three types: float, Tensor, Iterable.
    """

    def __init__(self, learning_rate, parameters, weight_decay=0.0, loss_scale=1.0,
                 decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
        super(Optimizer, self).__init__(auto_prefix=False)
        if isinstance(learning_rate, float):
            self.dynamic_lr = False
            self.gather = None
            self.assignadd = None
            self.global_step = None
            validator.check_number_range("learning rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT,
                                         self.cls_name)
            learning_rate = Tensor(learning_rate, mstype.float32)
        else:
            self.dynamic_lr = True
            self.gather = P.GatherV2()
            self.assignadd = P.AssignAdd()
            self.global_step = Parameter(initializer(0, [1], mindspore.int32), name='global_step')
            if isinstance(learning_rate, Iterable):
                learning_rate = Tensor(np.array(list(learning_rate)).astype(np.float32))
            elif isinstance(learning_rate, Tensor):
                if learning_rate.dim() > 1:
                    raise ValueError("Learning rate should be a 0 or 1 dim `Tensor`,"
                                     f"but got {learning_rate.dim()}.")
                if learning_rate.dim() == 1 and learning_rate.size() < 2:
                    logger.warning("If want to use the dynamic learning rate, please make sure that the number "
                                   "of elements in the list, tuple or tensor passed is greater than 1.")
            else:
                raise TypeError("Learning rate should be float, Tensor or Iterable.")

        if isinstance(weight_decay, int):
            weight_decay = float(weight_decay)
        validator.check_float_legal_value('weight_decay', weight_decay, None)

        if isinstance(loss_scale, int):
            loss_scale = float(loss_scale)
        validator.check_float_legal_value('loss_scale', loss_scale, None)

        if loss_scale <= 0.0:
            raise ValueError("Loss scale should be greater than 0, but got {}".format(loss_scale))
        self.loss_scale = loss_scale

        if weight_decay < 0.0:
            raise ValueError("Weight decay should be equal or greater than 0, but got {}".format(weight_decay))

        self.learning_rate = Parameter(learning_rate, name="learning_rate")
        self.parameters = ParameterTuple(parameters)
        self.reciprocal_scale = 1.0 / loss_scale
        self.weight_decay = weight_decay * loss_scale
        self.decay_flags = tuple(decay_filter(x) for x in self.parameters)

        if not self.parameters:
            raise ValueError("optimizer got an empty parameter list.")

    def decay_weight(self, gradients):
        """
        Weight decay.

        An approach to reduce the overfitting of a deep learning neural network model.

        Args:
            gradients (tuple[Tensor]): The gradients of `self.parameters`, with the same shape as
                `self.parameters`.

        Returns:
            tuple[Tensor], the gradients after weight decay.
        """
        if self.weight_decay > 0:
            params = self.parameters
            gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags, params, gradients)

        return gradients

    def scale_grad(self, gradients):
        """
        Loss scale for mixed precision.

        An approach of mixed precision training to improve the speed and energy efficiency of training a deep
        neural network.

        Args:
            gradients (tuple[Tensor]): The gradients of `self.parameters`, with the same shape as
                `self.parameters`.

        Returns:
            tuple[Tensor], the gradients after loss scaling.
        """
        if self.reciprocal_scale != 1.0:
            gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients)

        return gradients

    def get_lr(self):
        """
        Get the learning rate of the current step.

        Returns:
            float, the learning rate of the current step.
        """
        lr = self.learning_rate
        if self.dynamic_lr:
            lr = self.gather(self.learning_rate, self.global_step, 0)
            F.control_depend(lr, self.assignadd(self.global_step, 1))

        return lr

    def construct(self, *hyper_params):
        raise NotImplementedError


op_add = P.AddN()

apply_decay = C.MultitypeFuncGraph("apply_decay")


@apply_decay.register("Number", "Bool", "Tensor", "Tensor")
def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
    """Get grad with weight_decay."""
    if if_apply:
        return op_add((weight * weight_decay, gradient))
    return gradient


grad_scale = C.MultitypeFuncGraph("grad_scale")


@grad_scale.register("Number", "Tensor")
def tensor_grad_scale(scale, grad):
    """Get grad with scale."""
    if scale == 1.0:
        return grad
    cast_op = P.Cast()
    type_op = P.DType()
    return grad * cast_op(F.scalar_to_array(scale), type_op(grad))