| @@ -40,12 +40,47 @@ __global__ void ApplyAdamKernel(const size_t size, const T *gradient, const T *b | |||||
| } | } | ||||
| } | } | ||||
| template <typename T> | |||||
| __global__ void AdamWeightDecayKernel(const size_t size, const T *gradient, const float *learning_rate, | |||||
| const float *beta1, const float *beta2, const float *epsilon, const float *decay, | |||||
| T *variable, T *m, T *v) { | |||||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) { | |||||
| T next_m = beta1[0] * m[i] + (1 - beta1[0]) * gradient[i]; | |||||
| T next_v = beta2[0] * v[i] + (1 - beta2[0]) * gradient[i] * gradient[i]; | |||||
| T update = next_m / (sqrt(next_v) + epsilon[0]); | |||||
| update += decay[0] * variable[i]; | |||||
| variable[i] -= learning_rate[0] * update; | |||||
| m[i] = next_m; | |||||
| v[i] = next_v; | |||||
| } | |||||
| } | |||||
| template <> | |||||
| __global__ void AdamWeightDecayKernel(const size_t size, const half *gradient, const float *learning_rate, | |||||
| const float *beta1, const float *beta2, const float *epsilon, const float *decay, | |||||
| half *variable, half *m, half *v) { | |||||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) { | |||||
| half next_m = __float2half(beta1[0]) * m[i] + __float2half(1 - beta1[0]) * gradient[i]; | |||||
| half next_v = __float2half(beta2[0]) * v[i] + __float2half(1 - beta2[0]) * gradient[i] * gradient[i]; | |||||
| half update = next_m / (hsqrt(next_v) + __float2half(epsilon[0])); | |||||
| update += __float2half(decay[0]) * variable[i]; | |||||
| variable[i] -= __float2half(learning_rate[0]) * update; | |||||
| m[i] = next_m; | |||||
| v[i] = next_v; | |||||
| } | |||||
| } | |||||
| template <typename T> | template <typename T> | ||||
| void ApplyAdam(const size_t size, const T *gradient, const T *beta1_power, const T *beta2_power, const T *learning_rate, | void ApplyAdam(const size_t size, const T *gradient, const T *beta1_power, const T *beta2_power, const T *learning_rate, | ||||
| const T *beta1, const T *beta2, const T *epsilon, T *variable, T *m, T *v, cudaStream_t cuda_stream) { | const T *beta1, const T *beta2, const T *epsilon, T *variable, T *m, T *v, cudaStream_t cuda_stream) { | ||||
| ApplyAdamKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>( | ApplyAdamKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>( | ||||
| size, gradient, beta1_power, beta2_power, learning_rate, beta1, beta2, epsilon, variable, m, v); | size, gradient, beta1_power, beta2_power, learning_rate, beta1, beta2, epsilon, variable, m, v); | ||||
| } | } | ||||
| template <typename T> | |||||
| void AdamWeightDecayOp(const size_t size, const T *gradient, const float *learning_rate, const float *beta1, | |||||
| const float *beta2, const float *epsilon, const float *decay, T *variable, T *m, T *v, | |||||
| cudaStream_t cuda_stream) { | |||||
| AdamWeightDecayKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, gradient, learning_rate, beta1, beta2, | |||||
| epsilon, decay, variable, m, v); | |||||
| } | |||||
| template void ApplyAdam<float>(const size_t size, const float *gradient, const float *beta1_power, | template void ApplyAdam<float>(const size_t size, const float *gradient, const float *beta1_power, | ||||
| const float *beta2_power, const float *learning_rate, const float *beta1, | const float *beta2_power, const float *learning_rate, const float *beta1, | ||||
| @@ -54,3 +89,9 @@ template void ApplyAdam<float>(const size_t size, const float *gradient, const f | |||||
| template void ApplyAdam<half>(const size_t size, const half *gradient, const half *beta1_power, const half *beta2_power, | template void ApplyAdam<half>(const size_t size, const half *gradient, const half *beta1_power, const half *beta2_power, | ||||
| const half *learning_rate, const half *beta1, const half *beta2, const half *epsilon, | const half *learning_rate, const half *beta1, const half *beta2, const half *epsilon, | ||||
| half *variable, half *m, half *v, cudaStream_t cuda_stream); | half *variable, half *m, half *v, cudaStream_t cuda_stream); | ||||
| template void AdamWeightDecayOp<float>(const size_t size, const float *gradient, const float *learning_rate, | |||||
| const float *beta1, const float *beta2, const float *epsilon, const float *decay, | |||||
| float *variable, float *m, float *v, cudaStream_t cuda_stream); | |||||
| template void AdamWeightDecayOp<half>(const size_t size, const half *gradient, const float *learning_rate, | |||||
| const float *beta1, const float *beta2, const float *epsilon, const float *decay, | |||||
| half *variable, half *m, half *v, cudaStream_t cuda_stream); | |||||
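The kernels above compute the fused AdamWeightDecay step element-wise; the `half` specialization keeps the hyper-parameters (`learning_rate`, `beta1`, `beta2`, `epsilon`, `decay`) in float32 and casts them per element. As a sanity reference, here is a minimal NumPy sketch of the same math; the function name `np_adam_weight_decay` and the driver values are illustrative, not part of the patch.

```python
import numpy as np

def np_adam_weight_decay(var, m, v, grad, lr, beta1, beta2, epsilon, decay):
    """Host-side sketch of the update performed by AdamWeightDecayKernel."""
    next_m = beta1 * m + (1.0 - beta1) * grad          # 1st moment
    next_v = beta2 * v + (1.0 - beta2) * grad * grad   # 2nd moment
    update = next_m / (np.sqrt(next_v) + epsilon)      # Adam step
    update += decay * var                              # decoupled weight decay
    return var - lr * update, next_m, next_v

# Tiny usage example with made-up values.
var = np.ones(4, np.float32)
m = np.zeros(4, np.float32)
v = np.zeros(4, np.float32)
grad = np.full(4, 0.1, np.float32)
new_var, new_m, new_v = np_adam_weight_decay(var, m, v, grad,
                                             lr=1e-3, beta1=0.9, beta2=0.999,
                                             epsilon=1e-6, decay=1e-2)
```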
| @@ -21,5 +21,9 @@ | |||||
| template <typename T> | template <typename T> | ||||
| void ApplyAdam(const size_t size, const T *gradient, const T *beta1_power, const T *beta2_power, const T *learning_rate, | void ApplyAdam(const size_t size, const T *gradient, const T *beta1_power, const T *beta2_power, const T *learning_rate, | ||||
| const T *beta1, const T *beta2, const T *epsilon, T *variable, T *m, T *v, cudaStream_t cuda_stream); | const T *beta1, const T *beta2, const T *epsilon, T *variable, T *m, T *v, cudaStream_t cuda_stream); | ||||
| template <typename T> | |||||
| void AdamWeightDecayOp(const size_t size, const T *gradient, const float *learning_rate, const float *beta1, | |||||
| const float *beta2, const float *epsilon, const float *decay, T *variable, T *m, T *v, | |||||
| cudaStream_t cuda_stream); | |||||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADAM_IMPL_H_ | #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADAM_IMPL_H_ | ||||
| @@ -0,0 +1,52 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "backend/kernel_compiler/gpu/nn/adam_weight_decay_gpu_kernel.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| MS_REG_GPU_KERNEL_ONE(AdamWeightDecay, | |||||
| KernelAttr() | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddOutputAttr(kNumberTypeFloat32) | |||||
| .AddOutputAttr(kNumberTypeFloat32) | |||||
| .AddOutputAttr(kNumberTypeFloat32), | |||||
| AdamWeightDecayGpuKernel, float) | |||||
| MS_REG_GPU_KERNEL_ONE(AdamWeightDecay, | |||||
| KernelAttr() | |||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddOutputAttr(kNumberTypeFloat16) | |||||
| .AddOutputAttr(kNumberTypeFloat16) | |||||
| .AddOutputAttr(kNumberTypeFloat16), | |||||
| AdamWeightDecayGpuKernel, half) | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,137 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_ADAM_WEIGHT_DECAY_GPU_KERNEL_H_ | |||||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_ADAM_WEIGHT_DECAY_GPU_KERNEL_H_ | |||||
| #include <vector> | |||||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||||
| #include "backend/kernel_compiler/gpu/cuda_impl/adam_impl.cuh" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| template <typename T> | |||||
| class AdamWeightDecayGpuKernel : public GpuKernel { | |||||
| public: | |||||
| AdamWeightDecayGpuKernel() | |||||
| : variable_size_(0), | |||||
| m_size_(0), | |||||
| v_size_(0), | |||||
| learning_rate_size_(0), | |||||
| beta1_size_(0), | |||||
| beta2_size_(0), | |||||
| epsilon_size_(0), | |||||
| decay_size_(0), | |||||
| gradient_size_(0) {} | |||||
| ~AdamWeightDecayGpuKernel() override = default; | |||||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &, | |||||
| void *stream_ptr) override { | |||||
| T *variable = GetDeviceAddress<T>(inputs, 0); | |||||
| T *m = GetDeviceAddress<T>(inputs, 1); | |||||
| T *v = GetDeviceAddress<T>(inputs, 2); | |||||
| float *lr = GetDeviceAddress<float>(inputs, 3); | |||||
| float *beta1 = GetDeviceAddress<float>(inputs, 4); | |||||
| float *beta2 = GetDeviceAddress<float>(inputs, 5); | |||||
| float *epsilon = GetDeviceAddress<float>(inputs, 6); | |||||
| float *decay = GetDeviceAddress<float>(inputs, 7); | |||||
| T *gradient = GetDeviceAddress<T>(inputs, 8); | |||||
| AdamWeightDecayOp(inputs[0]->size / sizeof(T), gradient, lr, beta1, beta2, epsilon, decay, variable, m, v, | |||||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||||
| return true; | |||||
| } | |||||
| bool Init(const CNodePtr &kernel_node) override { | |||||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||||
| if (input_num != 9) { | |||||
| MS_LOG(ERROR) << "Input number is " << input_num << ", but adam needs 9 inputs."; | |||||
| return false; | |||||
| } | |||||
| variable_size_ = sizeof(T); | |||||
| m_size_ = sizeof(T); | |||||
| v_size_ = sizeof(T); | |||||
| learning_rate_size_ = sizeof(float); | |||||
| beta1_size_ = sizeof(float); | |||||
| beta2_size_ = sizeof(float); | |||||
| epsilon_size_ = sizeof(float); | |||||
| decay_size_ = sizeof(float); | |||||
| gradient_size_ = sizeof(T); | |||||
| auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||||
| for (size_t i = 0; i < variable_shape.size(); i++) { | |||||
| variable_size_ *= variable_shape[i]; | |||||
| } | |||||
| auto m_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); | |||||
| for (size_t i = 0; i < m_shape.size(); i++) { | |||||
| m_size_ *= m_shape[i]; | |||||
| } | |||||
| auto v_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); | |||||
| for (size_t i = 0; i < v_shape.size(); i++) { | |||||
| v_size_ *= v_shape[i]; | |||||
| } | |||||
| auto gradient_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 8); | |||||
| for (size_t i = 0; i < gradient_shape.size(); i++) { | |||||
| gradient_size_ *= gradient_shape[i]; | |||||
| } | |||||
| InitSizeLists(); | |||||
| return true; | |||||
| } | |||||
| protected: | |||||
| void InitSizeLists() override { | |||||
| input_size_list_.push_back(variable_size_); | |||||
| input_size_list_.push_back(m_size_); | |||||
| input_size_list_.push_back(v_size_); | |||||
| input_size_list_.push_back(learning_rate_size_); | |||||
| input_size_list_.push_back(beta1_size_); | |||||
| input_size_list_.push_back(beta2_size_); | |||||
| input_size_list_.push_back(epsilon_size_); | |||||
| input_size_list_.push_back(decay_size_); | |||||
| input_size_list_.push_back(gradient_size_); | |||||
| output_size_list_.push_back(0); | |||||
| output_size_list_.push_back(0); | |||||
| output_size_list_.push_back(0); | |||||
| } | |||||
| private: | |||||
| size_t variable_size_; | |||||
| size_t m_size_; | |||||
| size_t v_size_; | |||||
| size_t learning_rate_size_; | |||||
| size_t beta1_size_; | |||||
| size_t beta2_size_; | |||||
| size_t epsilon_size_; | |||||
| size_t decay_size_; | |||||
| size_t gradient_size_; | |||||
| std::vector<size_t> input_size_list_; | |||||
| std::vector<size_t> output_size_list_; | |||||
| std::vector<size_t> workspace_size_list_; | |||||
| }; | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_ADAM_WEIGHT_DECAY_GPU_KERNEL_H_ | |||||
| @@ -42,7 +42,7 @@ from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSumm | |||||
| TensorSummary, HistogramSummary, Print, Assert) | TensorSummary, HistogramSummary, Print, Assert) | ||||
| from .control_ops import ControlDepend, GeSwitch, Merge | from .control_ops import ControlDepend, GeSwitch, Merge | ||||
| from .inner_ops import (ScalarCast, Randperm, NoRepeatNGram, LambApplyOptimizerAssign, LambApplyWeightAssign, MakeRefKey, | from .inner_ops import (ScalarCast, Randperm, NoRepeatNGram, LambApplyOptimizerAssign, LambApplyWeightAssign, MakeRefKey, | ||||
| FusedWeightScaleApplyMomentum) | |||||
| FusedWeightScaleApplyMomentum, AdamWeightDecay) | |||||
| from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, AssignSub, Atan2, BatchMatMul, | from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, AssignSub, Atan2, BatchMatMul, | ||||
| BitwiseAnd, BitwiseOr, | BitwiseAnd, BitwiseOr, | ||||
| @@ -149,6 +149,7 @@ __all__ = [ | |||||
| 'TopK', | 'TopK', | ||||
| 'LinSpace', | 'LinSpace', | ||||
| 'Adam', | 'Adam', | ||||
| 'AdamWeightDecay', | |||||
| 'FusedSparseAdam', | 'FusedSparseAdam', | ||||
| 'FusedSparseLazyAdam', | 'FusedSparseLazyAdam', | ||||
| 'AdamNoUpdateParam', | 'AdamNoUpdateParam', | ||||
| @@ -441,3 +441,93 @@ class FusedWeightScaleApplyMomentum(PrimitiveWithInfer): | |||||
| validator.check_scalar_or_tensor_types_same({"d_dtype": d_dtype}, valid_dtypes, self.name) | validator.check_scalar_or_tensor_types_same({"d_dtype": d_dtype}, valid_dtypes, self.name) | ||||
| validator.check_scalar_or_tensor_types_same({"s_dtype": s_dtype}, valid_dtypes, self.name) | validator.check_scalar_or_tensor_types_same({"s_dtype": s_dtype}, valid_dtypes, self.name) | ||||
| return v_dtype | return v_dtype | ||||
| class AdamWeightDecay(PrimitiveWithInfer): | |||||
| r""" | |||||
| Updates the parameters using the Adaptive Moment Estimation with weight decay (AdamWeightDecay) algorithm. | |||||
| The Adam algorithm is proposed in `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_. | |||||
| The updating formulas are as follows, | |||||
| .. math:: | |||||
| \begin{array}{ll} \\ | |||||
| m = \beta_1 * m + (1 - \beta_1) * g \\ | |||||
| v = \beta_2 * v + (1 - \beta_2) * g * g \\ | |||||
| update = \frac{m}{\sqrt{v} + \epsilon} + weight\_decay * w \\ | |||||
| w = w - lr * update | |||||
| \end{array} | |||||
| :math:`m` represents the 1st moment vector, :math:`v` represents the 2nd moment vector, :math:`g` represents | |||||
| `gradient`, :math:`\beta_1, \beta_2` represent `beta1` and `beta2`, :math:`lr` represents `learning_rate`, | |||||
| :math:`w` represents `var`, :math:`weight\_decay` represents `decay`, and :math:`\epsilon` represents | |||||
| `epsilon`. | |||||
| Args: | |||||
| use_locking (bool): Whether to enable a lock to protect variable tensors from being updated. | |||||
| If true, updates of the var, m, and v tensors will be protected by a lock. | |||||
| If false, the result is unpredictable. Default: False. | |||||
| Inputs: | |||||
| - **var** (Tensor) - Weights to be updated. | |||||
| - **m** (Tensor) - The 1st moment vector in the updating formula, has the same type as `var`. | |||||
| - **v** (Tensor) - The 2nd moment vector in the updating formula. | |||||
| Mean square gradients with the same type as `var`. | |||||
| - **lr** (float) - :math:`lr` in the updating formula. | |||||
| - **beta1** (float) - The exponential decay rate for the 1st moment estimations. | |||||
| - **beta2** (float) - The exponential decay rate for the 2nd moment estimations. | |||||
| - **epsilon** (float) - Term added to the denominator to improve numerical stability. | |||||
| - **decay** (float) - The weight decay value, :math:`weight\_decay` in the updating formula. | |||||
| - **gradient** (Tensor) - Gradient, has the same type as `var`. | |||||
| Outputs: | |||||
| Tuple of 3 Tensor, the updated parameters. | |||||
| - **var** (Tensor) - The same shape and data type as `var`. | |||||
| - **m** (Tensor) - The same shape and data type as `m`. | |||||
| - **v** (Tensor) - The same shape and data type as `v`. | |||||
| Supported Platforms: | |||||
| ``GPU`` | |||||
| Examples: | |||||
| >>> import numpy as np | |||||
| >>> import mindspore.nn as nn | |||||
| >>> from mindspore import Tensor, Parameter | |||||
| >>> from mindspore.ops import operations as ops | |||||
| >>> class Net(nn.Cell): | |||||
| ... def __init__(self): | |||||
| ... super(Net, self).__init__() | |||||
| ... self.adam_weight_decay = ops.AdamWeightDecay() | |||||
| ... self.var = Parameter(Tensor(np.ones([2, 2]).astype(np.float32)), name="var") | |||||
| ... self.m = Parameter(Tensor(np.ones([2, 2]).astype(np.float32)), name="m") | |||||
| ... self.v = Parameter(Tensor(np.ones([2, 2]).astype(np.float32)), name="v") | |||||
| ... def construct(self, lr, beta1, beta2, epsilon, decay, grad): | |||||
| ... out = self.adam_weight_decay(self.var, self.m, self.v, lr, beta1, beta2, | |||||
| ... epsilon, decay, grad) | |||||
| ... return out | |||||
| >>> np.random.seed(0) | |||||
| >>> net = Net() | |||||
| >>> gradient = Tensor(np.random.rand(2, 2).astype(np.float32)) | |||||
| >>> output = net(0.9, 0.9, 0.999, 1e-8, 1e-5, gradient) | |||||
| """ | |||||
| @prim_attr_register | |||||
| def __init__(self, use_locking=False): | |||||
| validator.check_value_type("use_locking", use_locking, [bool], self.name) | |||||
| def infer_shape(self, var_shape, m_shape, v_shape, lr_shape, beta1_shape, beta2_shape, | |||||
| epsilon_shape, decay_shape, grad_shape): | |||||
| validator.check("var_shape", var_shape, "m_shape", m_shape, Rel.EQ, self.name) | |||||
| validator.check("var_shape", var_shape, "v_shape", v_shape, Rel.EQ, self.name) | |||||
| validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name) | |||||
| return var_shape, m_shape, v_shape | |||||
| def infer_dtype(self, var_dtype, m_dtype, v_dtype, lr_dtype, beta1_dtype, beta2_dtype, | |||||
| epsilon_dtype, decay, grad_dtype): | |||||
| args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": grad_dtype} | |||||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) | |||||
| args = {"lr": lr_dtype, "beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype, | |||||
| "decay": decay} | |||||
| validator.check_scalar_or_tensor_types_same(args, [mstype.float32], self.name, True) | |||||
| return var_dtype, m_dtype, v_dtype | |||||
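The GPU registrations above pair a float16 variant with float32 hyper-parameters: `var`, `m`, `v` and `gradient` are float16 while `lr`, `beta1`, `beta2`, `epsilon` and `decay` stay float32. Below is a minimal sketch of a call consistent with that attribute list; shapes and values are illustrative, and it assumes a GPU context where the kernel is available.

```python
import numpy as np
import mindspore as ms
from mindspore import Tensor, Parameter, context
from mindspore.ops import operations as ops

context.set_context(device_target="GPU")

adam_weight_decay = ops.AdamWeightDecay()
var = Parameter(Tensor(np.ones([2, 2], np.float16)), name="var")
m = Parameter(Tensor(np.zeros([2, 2], np.float16)), name="m")
v = Parameter(Tensor(np.zeros([2, 2], np.float16)), name="v")
grad = Tensor(np.full([2, 2], 0.1, np.float16))

# Hyper-parameters stay float32, matching the float16 kernel registration.
out_var, out_m, out_v = adam_weight_decay(var, m, v,
                                          Tensor(1e-3, ms.float32),   # lr
                                          Tensor(0.9, ms.float32),    # beta1
                                          Tensor(0.999, ms.float32),  # beta2
                                          Tensor(1e-6, ms.float32),   # epsilon
                                          Tensor(1e-2, ms.float32),   # decay
                                          grad)
```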
| @@ -36,7 +36,7 @@ from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithL | |||||
| BertTrainAccumulationAllReduceEachWithLossScaleCell, \ | BertTrainAccumulationAllReduceEachWithLossScaleCell, \ | ||||
| BertTrainAccumulationAllReducePostWithLossScaleCell, \ | BertTrainAccumulationAllReducePostWithLossScaleCell, \ | ||||
| BertTrainOneStepWithLossScaleCellForAdam, \ | BertTrainOneStepWithLossScaleCellForAdam, \ | ||||
| AdamWeightDecayForBert | |||||
| AdamWeightDecayForBert, AdamWeightDecayOp | |||||
| from src.dataset import create_bert_dataset | from src.dataset import create_bert_dataset | ||||
| from src.config import cfg, bert_net_cfg | from src.config import cfg, bert_net_cfg | ||||
| from src.utils import LossCallBack, BertLearningRate | from src.utils import LossCallBack, BertLearningRate | ||||
| @@ -96,6 +96,8 @@ def _get_optimizer(args_opt, network): | |||||
| {'order_params': params}] | {'order_params': params}] | ||||
| if args_opt.enable_lossscale == "true" and args_opt.device_target == 'GPU': | if args_opt.enable_lossscale == "true" and args_opt.device_target == 'GPU': | ||||
| optimizer = AdamWeightDecayForBert(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps) | optimizer = AdamWeightDecayForBert(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps) | ||||
| elif context.get_context("mode") == context.PYNATIVE_MODE and args_opt.device_target == 'GPU': | |||||
| optimizer = AdamWeightDecayOp(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps) | |||||
| else: | else: | ||||
| optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps) | optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps) | ||||
| elif cfg.optimizer == "Thor": | elif cfg.optimizer == "Thor": | ||||
| @@ -23,7 +23,7 @@ from .bert_model import BertAttention, BertConfig, BertEncoderCell, BertModel, \ | |||||
| BertOutput, BertSelfAttention, BertTransformer, EmbeddingLookup, \ | BertOutput, BertSelfAttention, BertTransformer, EmbeddingLookup, \ | ||||
| EmbeddingPostprocessor, RelaPosEmbeddingsGenerator, RelaPosMatrixGenerator, \ | EmbeddingPostprocessor, RelaPosEmbeddingsGenerator, RelaPosMatrixGenerator, \ | ||||
| SaturateCast, CreateAttentionMaskFromInputMask | SaturateCast, CreateAttentionMaskFromInputMask | ||||
| from .adam import AdamWeightDecayForBert | |||||
| from .adam import AdamWeightDecayForBert, AdamWeightDecayOp | |||||
| __all__ = [ | __all__ = [ | ||||
| "BertNetworkWithLoss", "BertPreTraining", "BertPretrainingLoss", | "BertNetworkWithLoss", "BertPreTraining", "BertPretrainingLoss", | ||||
| "GetMaskedLMOutput", "GetNextSentenceOutput", "BertTrainOneStepCell", | "GetMaskedLMOutput", "GetNextSentenceOutput", "BertTrainOneStepCell", | ||||
| @@ -33,5 +33,5 @@ __all__ = [ | |||||
| "BertSelfAttention", "BertTransformer", "EmbeddingLookup", | "BertSelfAttention", "BertTransformer", "EmbeddingLookup", | ||||
| "EmbeddingPostprocessor", "RelaPosEmbeddingsGenerator", "AdamWeightDecayForBert", | "EmbeddingPostprocessor", "RelaPosEmbeddingsGenerator", "AdamWeightDecayForBert", | ||||
| "RelaPosMatrixGenerator", "SaturateCast", "CreateAttentionMaskFromInputMask", | "RelaPosMatrixGenerator", "SaturateCast", "CreateAttentionMaskFromInputMask", | ||||
| "BertTrainOneStepWithLossScaleCellForAdam" | |||||
| "BertTrainOneStepWithLossScaleCellForAdam", "AdamWeightDecayOp" | |||||
| ] | ] | ||||
| @@ -28,6 +28,20 @@ _adam_opt = C.MultitypeFuncGraph("adam_opt") | |||||
| _scaler_one = Tensor(1, mstype.int32) | _scaler_one = Tensor(1, mstype.int32) | ||||
| _scaler_ten = Tensor(10, mstype.float32) | _scaler_ten = Tensor(10, mstype.float32) | ||||
| @_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor", | |||||
| "Tensor", "Bool", "Bool") | |||||
| def _update_run_kernel(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, decay_flags, optim_filter): | |||||
| """ | |||||
| Update parameters by AdamWeightDecay op. | |||||
| """ | |||||
| if optim_filter: | |||||
| adam = P.AdamWeightDecay() | |||||
| if decay_flags: | |||||
| next_param = adam(param, m, v, lr, beta1, beta2, eps, Tensor(weight_decay, mstype.float32), gradient) | |||||
| else: | |||||
| next_param = adam(param, m, v, lr, beta1, beta2, eps, Tensor(0.0, mstype.float32), gradient) | |||||
| return next_param | |||||
| return gradient | |||||
| @_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor", | @_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor", | ||||
| "Tensor", "Bool", "Bool") | "Tensor", "Bool", "Bool") | ||||
| @@ -252,7 +266,7 @@ class AdamWeightDecayForBert(Optimizer): | |||||
| Examples: | Examples: | ||||
| >>> net = Net() | >>> net = Net() | ||||
| >>> #1) All parameters use the same learning rate and weight decay | >>> #1) All parameters use the same learning rate and weight decay | ||||
| >>> optim = nn.AdamWeightDecay(params=net.trainable_params()) | |||||
| >>> optim = AdamWeightDecay(params=net.trainable_params()) | |||||
| >>> | >>> | ||||
| >>> #2) Use parameter groups and set different values | >>> #2) Use parameter groups and set different values | ||||
| >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) | >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) | ||||
| @@ -260,7 +274,7 @@ class AdamWeightDecayForBert(Optimizer): | |||||
| >>> group_params = [{'params': conv_params, 'weight_decay': 0.01}, | >>> group_params = [{'params': conv_params, 'weight_decay': 0.01}, | ||||
| ... {'params': no_conv_params, 'lr': 0.01}, | ... {'params': no_conv_params, 'lr': 0.01}, | ||||
| ... {'order_params': net.trainable_params()}] | ... {'order_params': net.trainable_params()}] | ||||
| >>> optim = nn.AdamWeightDecay(group_params, learning_rate=0.1, weight_decay=0.0) | |||||
| >>> optim = AdamWeightDecay(group_params, learning_rate=0.1, weight_decay=0.0) | |||||
| >>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01. | >>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01. | ||||
| >>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0. | >>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0. | ||||
| >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'. | >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'. | ||||
| @@ -305,3 +319,105 @@ class AdamWeightDecayForBert(Optimizer): | |||||
| if self.use_parallel: | if self.use_parallel: | ||||
| self.broadcast_params(optim_result) | self.broadcast_params(optim_result) | ||||
| return optim_result | return optim_result | ||||
| class AdamWeightDecayOp(Optimizer): | |||||
| """ | |||||
| Implements the Adam algorithm with weight decay. It is implemented as a single fused operator rather than a combination of smaller ops. | |||||
| Note: | |||||
| When separating parameter groups, the weight decay in each group will be applied on the parameters if the | |||||
| weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied | |||||
| on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive. | |||||
| To improve the performance of parameter groups, a customized order of parameters is supported. | |||||
| Args: | |||||
| params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated, | |||||
| the element in `params` must be class `Parameter`. When the `params` is a list of `dict`, the "params", | |||||
| "lr", "weight_decay" and "order_params" are the keys can be parsed. | |||||
| - params: Required. The value must be a list of `Parameter`. | |||||
| - lr: Optional. If "lr" is in the keys, the value of the corresponding learning rate will be used. | |||||
| If not, the `learning_rate` in the API will be used. | |||||
| - weight_decay: Optional. If "weight_decay" is in the keys, the value of the corresponding weight decay | |||||
| will be used. If not, the `weight_decay` in the API will be used. | |||||
| - order_params: Optional. If "order_params" is in the keys, the value must be the order of parameters and | |||||
| this order will be followed in the optimizer. There should be no other keys in this `dict`, and the | |||||
| parameters listed in 'order_params' must be included in one of the parameter groups. | |||||
| learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate. | |||||
| When the learning_rate is an Iterable or a one-dimensional Tensor, a dynamic learning rate is used, and | |||||
| the i-th step takes the i-th value as the learning rate. When the learning_rate is a LearningRateSchedule, | |||||
| a dynamic learning rate is also used, and the learning rate of each step is computed during training | |||||
| according to the formula of the LearningRateSchedule. When the learning_rate is a float or a | |||||
| zero-dimensional Tensor, a fixed learning rate is used. Other cases are not supported. The float learning rate must be | |||||
| equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float. | |||||
| Default: 1e-3. | |||||
| beta1 (float): The exponential decay rate for the 1st moment estimations. Default: 0.9. | |||||
| Should be in range (0.0, 1.0). | |||||
| beta2 (float): The exponential decay rate for the 2nd moment estimations. Default: 0.999. | |||||
| Should be in range (0.0, 1.0). | |||||
| eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6. | |||||
| Should be greater than 0. | |||||
| weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0. | |||||
| Inputs: | |||||
| - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`. | |||||
| Outputs: | |||||
| tuple[bool], all elements are True. | |||||
| Supported Platforms: | |||||
| ``GPU`` | |||||
| Examples: | |||||
| >>> net = Net() | |||||
| >>> #1) All parameters use the same learning rate and weight decay | |||||
| >>> optim = AdamWeightDecayOp(params=net.trainable_params()) | |||||
| >>> | |||||
| >>> #2) Use parameter groups and set different values | |||||
| >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) | |||||
| >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) | |||||
| >>> group_params = [{'params': conv_params, 'weight_decay': 0.01}, | |||||
| ... {'params': no_conv_params, 'lr': 0.01}, | |||||
| ... {'order_params': net.trainable_params()}] | |||||
| >>> optim = AdamWeightDecayOp(group_params, learning_rate=0.1, weight_decay=0.0) | |||||
| >>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01. | |||||
| >>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0. | |||||
| >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'. | |||||
| >>> | |||||
| >>> loss = nn.SoftmaxCrossEntropyWithLogits() | |||||
| >>> model = Model(net, loss_fn=loss, optimizer=optim) | |||||
| """ | |||||
| def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0): | |||||
| super(AdamWeightDecayOp, self).__init__(learning_rate, params, weight_decay) | |||||
| _check_param_value(beta1, beta2, eps, self.cls_name) | |||||
| self.beta1 = Tensor(np.array([beta1]).astype(np.float32)) | |||||
| self.beta2 = Tensor(np.array([beta2]).astype(np.float32)) | |||||
| self.eps = Tensor(np.array([eps]).astype(np.float32)) | |||||
| self.moments1 = self.parameters.clone(prefix="adam_m", init='zeros') | |||||
| self.moments2 = self.parameters.clone(prefix="adam_v", init='zeros') | |||||
| self.hyper_map = C.HyperMap() | |||||
| def construct(self, gradients): | |||||
| """AdamWeightDecayOp""" | |||||
| lr = self.get_lr() | |||||
| if self.is_group: | |||||
| if self.is_group_lr: | |||||
| optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps), | |||||
| lr, self.weight_decay, self.parameters, self.moments1, self.moments2, | |||||
| gradients, self.decay_flags, self.optim_filter) | |||||
| else: | |||||
| optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr), | |||||
| self.weight_decay, self.parameters, self.moments1, self.moments2, | |||||
| gradients, self.decay_flags, self.optim_filter) | |||||
| else: | |||||
| optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, self.weight_decay), | |||||
| self.parameters, self.moments1, self.moments2, | |||||
| gradients, self.decay_flags, self.optim_filter) | |||||
| if self.use_parallel: | |||||
| self.broadcast_params(optim_result) | |||||
| return optim_result | |||||
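For clarity, the `HyperMap` dispatch in `construct` together with `_update_run_kernel` amounts to the following per-parameter loop: parameters filtered out by `optim_filter` pass their gradient through unchanged, and `decay_flags` decides whether the weight decay is applied or replaced by zero. This is a plain-Python sketch under those assumptions; `apply_adam_weight_decay` is a stand-in name for the fused `P.AdamWeightDecay` primitive, not an API in the patch.

```python
def apply_group(params, moments1, moments2, grads, decay_flags, optim_filter,
                lr, beta1, beta2, eps, weight_decay, apply_adam_weight_decay):
    """Illustrative equivalent of the per-parameter dispatch done by the HyperMap."""
    results = []
    for param, m, v, grad, decay_flag, keep in zip(
            params, moments1, moments2, grads, decay_flags, optim_filter):
        if keep:
            # Apply weight decay only where the flag is set (e.g. not for 'beta'/'gamma').
            decay = weight_decay if decay_flag else 0.0
            results.append(apply_adam_weight_decay(param, m, v, lr, beta1, beta2,
                                                   eps, decay, grad))
        else:
            # Parameters not handled by this optimizer: return the gradient unchanged.
            results.append(grad)
    return results
```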