From c1a619ccfe17b79995c52e69f493b7122aadb988 Mon Sep 17 00:00:00 2001
From: VectorSL
Date: Wed, 3 Mar 2021 15:28:44 +0800
Subject: [PATCH] add AdamWeightDecayOp

---
 .../gpu/cuda_impl/adam_impl.cu                |  41 ++++++
 .../gpu/cuda_impl/adam_impl.cuh               |   4 +
 .../gpu/nn/adam_weight_decay_gpu_kernel.cc    |  52 +++++++
 .../gpu/nn/adam_weight_decay_gpu_kernel.h     | 137 ++++++++++++++++++
 mindspore/ops/operations/__init__.py          |   3 +-
 mindspore/ops/operations/inner_ops.py         |  90 ++++++++++++
 model_zoo/official/nlp/bert/run_pretrain.py   |   4 +-
 model_zoo/official/nlp/bert/src/__init__.py   |   4 +-
 model_zoo/official/nlp/bert/src/adam.py       | 120 ++++++++++++++-
 9 files changed, 449 insertions(+), 6 deletions(-)
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adam_weight_decay_gpu_kernel.cc
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adam_weight_decay_gpu_kernel.h

diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adam_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adam_impl.cu
index 615b94723d..ad73438807 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adam_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adam_impl.cu
@@ -40,12 +40,47 @@ __global__ void ApplyAdamKernel(const size_t size, const T *gradient, const T *b
   }
 }
 
+template <typename T>
+__global__ void AdamWeightDecayKernel(const size_t size, const T *gradient, const float *learning_rate,
+                                      const float *beta1, const float *beta2, const float *epsilon, const float *decay,
+                                      T *variable, T *m, T *v) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
+    T next_m = beta1[0] * m[i] + (1 - beta1[0]) * gradient[i];
+    T next_v = beta2[0] * v[i] + (1 - beta2[0]) * gradient[i] * gradient[i];
+    T update = next_m / (sqrt(next_v) + epsilon[0]);
+    update += decay[0] * variable[i];
+    variable[i] -= learning_rate[0] * update;
+    m[i] = next_m;
+    v[i] = next_v;
+  }
+}
+template <>
+__global__ void AdamWeightDecayKernel(const size_t size, const half *gradient, const float *learning_rate,
+                                      const float *beta1, const float *beta2, const float *epsilon, const float *decay,
+                                      half *variable, half *m, half *v) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
+    half next_m = __float2half(beta1[0]) * m[i] + __float2half(1 - beta1[0]) * gradient[i];
+    half next_v = __float2half(beta2[0]) * v[i] + __float2half(1 - beta2[0]) * gradient[i] * gradient[i];
+    half update = next_m / (hsqrt(next_v) + __float2half(epsilon[0]));
+    update += __float2half(decay[0]) * variable[i];
+    variable[i] -= __float2half(learning_rate[0]) * update;
+    m[i] = next_m;
+    v[i] = next_v;
+  }
+}
 template <typename T>
 void ApplyAdam(const size_t size, const T *gradient, const T *beta1_power, const T *beta2_power,
                const T *learning_rate, const T *beta1, const T *beta2, const T *epsilon, T *variable, T *m, T *v,
                cudaStream_t cuda_stream) {
   ApplyAdamKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(
     size, gradient, beta1_power, beta2_power, learning_rate, beta1, beta2, epsilon, variable, m, v);
 }
+template <typename T>
+void AdamWeightDecayOp(const size_t size, const T *gradient, const float *learning_rate, const float *beta1,
+                       const float *beta2, const float *epsilon, const float *decay, T *variable, T *m, T *v,
+                       cudaStream_t cuda_stream) {
+  AdamWeightDecayKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, gradient, learning_rate, beta1, beta2,
+                                                                           epsilon, decay, variable, m, v);
+}
 template void ApplyAdam<float>(const size_t size, const float *gradient, const float *beta1_power,
                                const float *beta2_power, const float *learning_rate, const float *beta1,
@@ -54,3 +89,9 @@ template void ApplyAdam(const size_t size, const f
 template void ApplyAdam<half>(const size_t size, const half *gradient, const half *beta1_power,
                               const half *beta2_power, const half *learning_rate, const half *beta1, const half *beta2,
                               const half *epsilon, half *variable, half *m, half *v, cudaStream_t cuda_stream);
+template void AdamWeightDecayOp<float>(const size_t size, const float *gradient, const float *learning_rate,
+                                       const float *beta1, const float *beta2, const float *epsilon,
+                                       const float *decay, float *variable, float *m, float *v,
+                                       cudaStream_t cuda_stream);
+template void AdamWeightDecayOp<half>(const size_t size, const half *gradient, const float *learning_rate,
+                                      const float *beta1, const float *beta2, const float *epsilon, const float *decay,
+                                      half *variable, half *m, half *v, cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adam_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adam_impl.cuh
index 7fc4a3e949..a88d54e517 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adam_impl.cuh
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adam_impl.cuh
@@ -21,5 +21,9 @@ template <typename T>
 void ApplyAdam(const size_t size, const T *gradient, const T *beta1_power, const T *beta2_power,
                const T *learning_rate, const T *beta1, const T *beta2, const T *epsilon, T *variable, T *m, T *v,
                cudaStream_t cuda_stream);
+template <typename T>
+void AdamWeightDecayOp(const size_t size, const T *gradient, const float *learning_rate, const float *beta1,
+                       const float *beta2, const float *epsilon, const float *decay, T *variable, T *m, T *v,
+                       cudaStream_t cuda_stream);
 
 #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADAM_IMPL_H_
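For reference, the update that AdamWeightDecayKernel computes is plain Adam without bias correction of the moments, plus decoupled weight decay: the decay term is added to the normalized update before the learning-rate step. A minimal NumPy sketch of that math follows; the helper name is illustrative, not part of the patch or of MindSpore.

import numpy as np

def adam_weight_decay_reference(param, m, v, gradient, lr, beta1, beta2, eps, decay):
    # Mirrors AdamWeightDecayKernel: no bias correction of the moments,
    # weight decay added to the normalized update before the lr step.
    next_m = beta1 * m + (1 - beta1) * gradient
    next_v = beta2 * v + (1 - beta2) * gradient * gradient
    update = next_m / (np.sqrt(next_v) + eps) + decay * param
    next_param = param - lr * update
    return next_param, next_m, next_v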
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adam_weight_decay_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adam_weight_decay_gpu_kernel.cc
new file mode 100644
index 0000000000..ff5cb3ed4b
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adam_weight_decay_gpu_kernel.cc
@@ -0,0 +1,52 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/nn/adam_weight_decay_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(AdamWeightDecay,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32),
+                      AdamWeightDecayGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(AdamWeightDecay,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16),
+                      AdamWeightDecayGpuKernel, half)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adam_weight_decay_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adam_weight_decay_gpu_kernel.h
new file mode 100644
index 0000000000..2bd85c58c0
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adam_weight_decay_gpu_kernel.h
@@ -0,0 +1,137 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_ADAM_WEIGHT_DECAY_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_ADAM_WEIGHT_DECAY_GPU_KERNEL_H_
+
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/adam_impl.cuh"
+namespace mindspore {
+namespace kernel {
+template <typename T>
+class AdamWeightDecayGpuKernel : public GpuKernel {
+ public:
+  AdamWeightDecayGpuKernel()
+      : variable_size_(0),
+        m_size_(0),
+        v_size_(0),
+        learning_rate_size_(0),
+        beta1_size_(0),
+        beta2_size_(0),
+        epsilon_size_(0),
+        decay_size_(0),
+        gradient_size_(0) {}
+
+  ~AdamWeightDecayGpuKernel() override = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
+              void *stream_ptr) override {
+    T *variable = GetDeviceAddress<T>(inputs, 0);
+    T *m = GetDeviceAddress<T>(inputs, 1);
+    T *v = GetDeviceAddress<T>(inputs, 2);
+    float *lr = GetDeviceAddress<float>(inputs, 3);
+    float *beta1 = GetDeviceAddress<float>(inputs, 4);
+    float *beta2 = GetDeviceAddress<float>(inputs, 5);
+    float *epsilon = GetDeviceAddress<float>(inputs, 6);
+    float *decay = GetDeviceAddress<float>(inputs, 7);
+    T *gradient = GetDeviceAddress<T>(inputs, 8);
+    AdamWeightDecayOp(inputs[0]->size / sizeof(T), gradient, lr, beta1, beta2, epsilon, decay, variable, m, v,
+                      reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 9) {
+      MS_LOG(ERROR) << "Input number is " << input_num << ", but AdamWeightDecay needs 9 inputs.";
+      return false;
+    }
+
+    variable_size_ = sizeof(T);
+    m_size_ = sizeof(T);
+    v_size_ = sizeof(T);
+    learning_rate_size_ = sizeof(float);
+    beta1_size_ = sizeof(float);
+    beta2_size_ = sizeof(float);
+    epsilon_size_ = sizeof(float);
+    decay_size_ = sizeof(float);
+    gradient_size_ = sizeof(T);
+
+    auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    for (size_t i = 0; i < variable_shape.size(); i++) {
+      variable_size_ *= variable_shape[i];
+    }
+
+    auto m_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
+    for (size_t i = 0; i < m_shape.size(); i++) {
+      m_size_ *= m_shape[i];
+    }
+
+    auto v_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
+    for (size_t i = 0; i < v_shape.size(); i++) {
+      v_size_ *= v_shape[i];
+    }
+
+    auto gradient_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 8);
+    for (size_t i = 0; i < gradient_shape.size(); i++) {
+      gradient_size_ *= gradient_shape[i];
+    }
+
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(variable_size_);
+    input_size_list_.push_back(m_size_);
+    input_size_list_.push_back(v_size_);
+    input_size_list_.push_back(learning_rate_size_);
+    input_size_list_.push_back(beta1_size_);
+    input_size_list_.push_back(beta2_size_);
+    input_size_list_.push_back(epsilon_size_);
+    input_size_list_.push_back(decay_size_);
+    input_size_list_.push_back(gradient_size_);
+    output_size_list_.push_back(0);
+    output_size_list_.push_back(0);
+    output_size_list_.push_back(0);
+  }
+
+ private:
+  size_t variable_size_;
+  size_t m_size_;
+  size_t v_size_;
+  size_t learning_rate_size_;
+  size_t beta1_size_;
+  size_t beta2_size_;
+  size_t epsilon_size_;
+  size_t decay_size_;
+  size_t gradient_size_;
+
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_ADAM_WEIGHT_DECAY_GPU_KERNEL_H_
diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py
index 9d1dafbf78..d624e0c774 100644
--- a/mindspore/ops/operations/__init__.py
+++ b/mindspore/ops/operations/__init__.py
@@ -42,7 +42,7 @@ from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSumm
                         TensorSummary, HistogramSummary, Print, Assert)
 from .control_ops import ControlDepend, GeSwitch, Merge
 from .inner_ops import (ScalarCast, Randperm, NoRepeatNGram, LambApplyOptimizerAssign, LambApplyWeightAssign, MakeRefKey,
-                        FusedWeightScaleApplyMomentum)
+                        FusedWeightScaleApplyMomentum, AdamWeightDecay)
 from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, AssignSub, Atan2, BatchMatMul,
                        BitwiseAnd, BitwiseOr,
@@ -149,6 +149,7 @@ __all__ = [
     'TopK',
     'LinSpace',
     'Adam',
+    'AdamWeightDecay',
     'FusedSparseAdam',
     'FusedSparseLazyAdam',
     'AdamNoUpdateParam',
diff --git a/mindspore/ops/operations/inner_ops.py b/mindspore/ops/operations/inner_ops.py
index 640b951eae..d8bf114578 100644
--- a/mindspore/ops/operations/inner_ops.py
+++ b/mindspore/ops/operations/inner_ops.py
@@ -441,3 +441,93 @@ class FusedWeightScaleApplyMomentum(PrimitiveWithInfer):
         validator.check_scalar_or_tensor_types_same({"d_dtype": d_dtype}, valid_dtypes, self.name)
         validator.check_scalar_or_tensor_types_same({"s_dtype": s_dtype}, valid_dtypes, self.name)
         return v_dtype
+
+
+class AdamWeightDecay(PrimitiveWithInfer):
+    r"""
+    Updates gradients by the Adaptive Moment Estimation (Adam) algorithm with weight decay.
+
+    The Adam algorithm is proposed in `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_.
+
+    The updating formulas are as follows,
+
+    .. math::
+        \begin{array}{ll} \\
+            m = \beta_1 * m + (1 - \beta_1) * g \\
+            v = \beta_2 * v + (1 - \beta_2) * g * g \\
+            update = \frac{m}{\sqrt{v} + \epsilon} + decay * w \\
+            w = w - lr * update
+        \end{array}
+
+    :math:`m` represents the 1st moment vector, :math:`v` represents the 2nd moment vector, :math:`g` represents
+    `gradient`, :math:`\beta_1, \beta_2` represent `beta1` and `beta2`, :math:`lr` represents `learning_rate`,
+    :math:`w` represents `var`, :math:`decay` represents `decay`, and :math:`\epsilon` represents `epsilon`.
+
+    Args:
+        use_locking (bool): Whether to enable a lock to protect variable tensors from being updated.
+            If true, updates of the var, m, and v tensors will be protected by a lock.
+            If false, the result is unpredictable. Default: False.
+
+    Inputs:
+        - **var** (Tensor) - Weights to be updated.
+        - **m** (Tensor) - The 1st moment vector in the updating formula, has the same type as `var`.
+        - **v** (Tensor) - The 2nd moment vector in the updating formula.
+          Mean square gradients with the same type as `var`.
+        - **lr** (float) - :math:`lr` in the updating formula.
+        - **beta1** (float) - The exponential decay rate for the 1st moment estimations.
+        - **beta2** (float) - The exponential decay rate for the 2nd moment estimations.
+        - **epsilon** (float) - Term added to the denominator to improve numerical stability.
+        - **decay** (float) - The weight decay value applied to `var`.
+        - **gradient** (Tensor) - Gradient, has the same type as `var`.
+
+    Outputs:
+        Tuple of 3 Tensor, the updated parameters.
+
+        - **var** (Tensor) - The same shape and data type as `var`.
+        - **m** (Tensor) - The same shape and data type as `m`.
+        - **v** (Tensor) - The same shape and data type as `v`.
+
+    Supported Platforms:
+        ``GPU``
+
+    Examples:
+        >>> import numpy as np
+        >>> import mindspore.nn as nn
+        >>> from mindspore import Tensor, Parameter
+        >>> from mindspore.ops import operations as ops
+        >>> class Net(nn.Cell):
+        ...     def __init__(self):
+        ...         super(Net, self).__init__()
+        ...         self.adam_weight_decay = ops.AdamWeightDecay()
+        ...         self.var = Parameter(Tensor(np.ones([2, 2]).astype(np.float32)), name="var")
+        ...         self.m = Parameter(Tensor(np.ones([2, 2]).astype(np.float32)), name="m")
+        ...         self.v = Parameter(Tensor(np.ones([2, 2]).astype(np.float32)), name="v")
+        ...     def construct(self, lr, beta1, beta2, epsilon, decay, grad):
+        ...         out = self.adam_weight_decay(self.var, self.m, self.v, lr, beta1, beta2,
+        ...                                      epsilon, decay, grad)
+        ...         return out
+        >>> np.random.seed(0)
+        >>> net = Net()
+        >>> gradient = Tensor(np.random.rand(2, 2).astype(np.float32))
+        >>> output = net(0.9, 0.9, 0.999, 1e-8, 1e-5, gradient)
+    """
+
+    @prim_attr_register
+    def __init__(self, use_locking=False):
+        validator.check_value_type("use_locking", use_locking, [bool], self.name)
+
+    def infer_shape(self, var_shape, m_shape, v_shape, lr_shape, beta1_shape, beta2_shape,
+                    epsilon_shape, decay_shape, grad_shape):
+        validator.check("var_shape", var_shape, "m_shape", m_shape, Rel.EQ, self.name)
+        validator.check("var_shape", var_shape, "v_shape", v_shape, Rel.EQ, self.name)
+        validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name)
+        return var_shape, m_shape, v_shape
+
+    def infer_dtype(self, var_dtype, m_dtype, v_dtype, lr_dtype, beta1_dtype, beta2_dtype,
+                    epsilon_dtype, decay_dtype, grad_dtype):
+        args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": grad_dtype}
+        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
+
+        args = {"lr": lr_dtype, "beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype,
+                "decay": decay_dtype}
+        validator.check_scalar_or_tensor_types_same(args, [mstype.float32], self.name, True)
+        return var_dtype, m_dtype, v_dtype
diff --git a/model_zoo/official/nlp/bert/run_pretrain.py b/model_zoo/official/nlp/bert/run_pretrain.py
index 0338fe196a..a4d7d17ff6 100644
--- a/model_zoo/official/nlp/bert/run_pretrain.py
+++ b/model_zoo/official/nlp/bert/run_pretrain.py
@@ -36,7 +36,7 @@ from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithL
     BertTrainAccumulationAllReduceEachWithLossScaleCell, \
     BertTrainAccumulationAllReducePostWithLossScaleCell, \
     BertTrainOneStepWithLossScaleCellForAdam, \
-    AdamWeightDecayForBert
+    AdamWeightDecayForBert, AdamWeightDecayOp
 from src.dataset import create_bert_dataset
 from src.config import cfg, bert_net_cfg
 from src.utils import LossCallBack, BertLearningRate
@@ -96,6 +96,8 @@ def _get_optimizer(args_opt, network):
                         {'order_params': params}]
         if args_opt.enable_lossscale == "true" and args_opt.device_target == 'GPU':
             optimizer = AdamWeightDecayForBert(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
+        elif context.get_context("mode") == context.PYNATIVE_MODE and args_opt.device_target == 'GPU':
+            optimizer = AdamWeightDecayOp(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
         else:
             optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
     elif cfg.optimizer == "Thor":
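Outside of BERT pretraining, the new optimizer can be exercised with a standard train-one-step wrapper. A minimal sketch, assuming a GPU build of MindSpore that includes the kernel above; the toy network, data, and hyperparameters are illustrative, and the import path mirrors the one used by run_pretrain.py:

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context
from src.adam import AdamWeightDecayOp  # same import path as run_pretrain.py

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")

# Toy network and data standing in for the BERT network and dataset.
net = nn.Dense(4, 2)
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
optimizer = AdamWeightDecayOp(net.trainable_params(), learning_rate=1e-3, weight_decay=1e-5)
train_step = nn.TrainOneStepCell(nn.WithLossCell(net, loss_fn), optimizer)

data = Tensor(np.random.rand(8, 4).astype(np.float32))
label = Tensor(np.random.randint(0, 2, (8,)).astype(np.int32))
loss = train_step(data, label)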
diff --git a/model_zoo/official/nlp/bert/src/__init__.py b/model_zoo/official/nlp/bert/src/__init__.py
index 72046b3461..0e3f1ab8d0 100644
--- a/model_zoo/official/nlp/bert/src/__init__.py
+++ b/model_zoo/official/nlp/bert/src/__init__.py
@@ -23,7 +23,7 @@ from .bert_model import BertAttention, BertConfig, BertEncoderCell, BertModel, \
     BertOutput, BertSelfAttention, BertTransformer, EmbeddingLookup, \
     EmbeddingPostprocessor, RelaPosEmbeddingsGenerator, RelaPosMatrixGenerator, \
     SaturateCast, CreateAttentionMaskFromInputMask
-from .adam import AdamWeightDecayForBert
+from .adam import AdamWeightDecayForBert, AdamWeightDecayOp
 __all__ = [
     "BertNetworkWithLoss", "BertPreTraining", "BertPretrainingLoss",
     "GetMaskedLMOutput", "GetNextSentenceOutput", "BertTrainOneStepCell",
@@ -33,5 +33,5 @@ __all__ = [
     "BertSelfAttention", "BertTransformer", "EmbeddingLookup",
     "EmbeddingPostprocessor", "RelaPosEmbeddingsGenerator", "AdamWeightDecayForBert",
     "RelaPosMatrixGenerator", "SaturateCast", "CreateAttentionMaskFromInputMask",
-    "BertTrainOneStepWithLossScaleCellForAdam"
+    "BertTrainOneStepWithLossScaleCellForAdam", "AdamWeightDecayOp"
 ]
diff --git a/model_zoo/official/nlp/bert/src/adam.py b/model_zoo/official/nlp/bert/src/adam.py
index c7a952e2bb..7ca2c5b852 100644
--- a/model_zoo/official/nlp/bert/src/adam.py
+++ b/model_zoo/official/nlp/bert/src/adam.py
@@ -28,6 +28,20 @@ _adam_opt = C.MultitypeFuncGraph("adam_opt")
 _scaler_one = Tensor(1, mstype.int32)
 _scaler_ten = Tensor(10, mstype.float32)
 
+@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
+                    "Tensor", "Bool", "Bool")
+def _update_run_kernel(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, decay_flags, optim_filter):
+    """
+    Update parameters by the fused AdamWeightDecay op.
+    """
+    if optim_filter:
+        adam = P.AdamWeightDecay()
+        if decay_flags:
+            next_param = adam(param, m, v, lr, beta1, beta2, eps, Tensor(weight_decay, mstype.float32), gradient)
+        else:
+            next_param = adam(param, m, v, lr, beta1, beta2, eps, Tensor(0.0, mstype.float32), gradient)
+        return next_param
+    return gradient
 
 @_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
                     "Tensor", "Bool", "Bool")
@@ -252,7 +266,7 @@ class AdamWeightDecayForBert(Optimizer):
     Examples:
         >>> net = Net()
        >>> #1) All parameters use the same learning rate and weight decay
-        >>> optim = nn.AdamWeightDecay(params=net.trainable_params())
+        >>> optim = AdamWeightDecay(params=net.trainable_params())
         >>>
         >>> #2) Use parameter groups and set different values
         >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
         >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
@@ -260,7 +274,7 @@
         >>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
         ...                 {'params': no_conv_params, 'lr': 0.01},
         ...                 {'order_params': net.trainable_params()}]
-        >>> optim = nn.AdamWeightDecay(group_params, learning_rate=0.1, weight_decay=0.0)
+        >>> optim = AdamWeightDecay(group_params, learning_rate=0.1, weight_decay=0.0)
         >>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01.
         >>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0.
         >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
@@ -305,3 +319,105 @@ class AdamWeightDecayForBert(Optimizer):
         if self.use_parallel:
             self.broadcast_params(optim_result)
         return optim_result
+
+class AdamWeightDecayOp(Optimizer):
+    """
+    Implements the Adam algorithm with decoupled weight decay. The update is computed by the single fused
+    AdamWeightDecay operator rather than by a combination of small operators.
+
+    Note:
+        When separating parameter groups, the weight decay in each group will be applied on the parameters if the
+        weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
+        on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.
+
+        To improve the performance of parameter groups, a customized order of parameters is supported.
+
+    Args:
+        params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
+            the element in `params` must be class `Parameter`. When the `params` is a list of `dict`, the "params",
+            "lr", "weight_decay" and "order_params" are the keys that can be parsed.
+
+            - params: Required. The value must be a list of `Parameter`.
+
+            - lr: Optional. If "lr" is in the keys, the value of the corresponding learning rate will be used.
+              If not, the `learning_rate` in the API will be used.
+
+            - weight_decay: Optional. If "weight_decay" is in the keys, the value of the corresponding weight decay
+              will be used. If not, the `weight_decay` in the API will be used.
+
+            - order_params: Optional. If "order_params" is in the keys, the value must be the order of parameters and
+              the order will be followed in the optimizer. There are no other keys in the `dict`, and the parameters
+              which are in 'order_params' must be in one of the group parameters.
+
+        learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
+            When the learning_rate is an Iterable or a Tensor in a 1D dimension, use the dynamic learning rate, then
+            the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule,
+            use dynamic learning rate, the i-th learning rate will be calculated during the process of training
+            according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor in a zero
+            dimension, use fixed learning rate. Other cases are not supported. The float learning rate must be
+            equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float.
+            Default: 1e-3.
+        beta1 (float): The exponential decay rate for the 1st moment estimations. Default: 0.9.
+            Should be in range (0.0, 1.0).
+        beta2 (float): The exponential decay rate for the 2nd moment estimations. Default: 0.999.
+            Should be in range (0.0, 1.0).
+        eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6.
+            Should be greater than 0.
+        weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0.
+
+    Inputs:
+        - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
+
+    Outputs:
+        tuple[bool], all elements are True.
+
+    Supported Platforms:
+        ``GPU``
+
+    Examples:
+        >>> net = Net()
+        >>> #1) All parameters use the same learning rate and weight decay
+        >>> optim = AdamWeightDecayOp(params=net.trainable_params())
+        >>>
+        >>> #2) Use parameter groups and set different values
+        >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
+        >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
+        >>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
+        ...                 {'params': no_conv_params, 'lr': 0.01},
+        ...                 {'order_params': net.trainable_params()}]
+        >>> optim = AdamWeightDecayOp(group_params, learning_rate=0.1, weight_decay=0.0)
+        >>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01.
+        >>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0.
+        >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
+        >>>
+        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
+        >>> model = Model(net, loss_fn=loss, optimizer=optim)
+    """
+    def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0):
+        super(AdamWeightDecayOp, self).__init__(learning_rate, params, weight_decay)
+        _check_param_value(beta1, beta2, eps, self.cls_name)
+        self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
+        self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
+        self.eps = Tensor(np.array([eps]).astype(np.float32))
+        self.moments1 = self.parameters.clone(prefix="adam_m", init='zeros')
+        self.moments2 = self.parameters.clone(prefix="adam_v", init='zeros')
+        self.hyper_map = C.HyperMap()
+
+    def construct(self, gradients):
+        """AdamWeightDecayOp"""
+        lr = self.get_lr()
+        if self.is_group:
+            if self.is_group_lr:
+                optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps),
+                                              lr, self.weight_decay, self.parameters, self.moments1, self.moments2,
+                                              gradients, self.decay_flags, self.optim_filter)
+            else:
+                optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr),
+                                              self.weight_decay, self.parameters, self.moments1, self.moments2,
+                                              gradients, self.decay_flags, self.optim_filter)
+        else:
+            optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, self.weight_decay),
+                                          self.parameters, self.moments1, self.moments2,
+                                          gradients, self.decay_flags, self.optim_filter)
+        if self.use_parallel:
+            self.broadcast_params(optim_result)
+        return optim_result
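A quick way to sanity-check the new primitive end to end is to compare one update against the same math in NumPy. A minimal parity-check sketch, assuming a GPU build that includes this patch; the wrapper cell name, shapes, and tolerances are illustrative.

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, Parameter, context
from mindspore.ops import operations as ops

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

class AdamWeightDecayNet(nn.Cell):
    # Thin wrapper around the new primitive, mirroring the docstring example.
    def __init__(self, var, m, v):
        super(AdamWeightDecayNet, self).__init__()
        self.op = ops.AdamWeightDecay()
        self.var = Parameter(Tensor(var), name="var")
        self.m = Parameter(Tensor(m), name="m")
        self.v = Parameter(Tensor(v), name="v")

    def construct(self, lr, beta1, beta2, epsilon, decay, grad):
        return self.op(self.var, self.m, self.v, lr, beta1, beta2, epsilon, decay, grad)

np.random.seed(0)
var = np.random.rand(2, 2).astype(np.float32)
m = np.zeros((2, 2), np.float32)
v = np.zeros((2, 2), np.float32)
grad = np.random.rand(2, 2).astype(np.float32)
lr, beta1, beta2, eps, decay = 1e-3, 0.9, 0.999, 1e-8, 1e-5

net = AdamWeightDecayNet(var, m, v)
out_var, out_m, out_v = net(lr, beta1, beta2, eps, decay, Tensor(grad))

# NumPy reference of the kernel math (no bias correction, decoupled decay).
ref_m = beta1 * m + (1 - beta1) * grad
ref_v = beta2 * v + (1 - beta2) * grad * grad
ref_var = var - lr * (ref_m / (np.sqrt(ref_v) + eps) + decay * var)

assert np.allclose(out_var.asnumpy(), ref_var, rtol=1e-5, atol=1e-6)
assert np.allclose(out_m.asnumpy(), ref_m, rtol=1e-5, atol=1e-6)
assert np.allclose(out_v.asnumpy(), ref_v, rtol=1e-5, atol=1e-6)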