From: @zhaosida_hw Reviewed-by: @kisnwang, @zhoufeng54 Signed-off-by: @zhoufeng54
@@ -119,6 +119,7 @@ checkopts()
  ANDROID_STL="c++_shared"
  ENABLE_MAKE_CLEAN="off"
  X86_64_SIMD="off"
  ARM_SIMD="off"
  DEVICE_VERSION=""
  DEVICE=""
  ENABLE_NPU="off"
@@ -331,6 +332,9 @@ checkopts()
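        # Map the SIMD option value onto the matching architecture flag:
        # sse/avx select the x86_64 path, neon selects the ARM path.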
| if [[ "$OPTARG" == "sse" || "$OPTARG" == "avx" ]]; then | |||
| X86_64_SIMD="$OPTARG" | |||
| fi | |||
| if [[ "$OPTARG" == "neon" ]]; then | |||
| ARM_SIMD="$OPTARG" | |||
| fi | |||
| ;; | |||
| H) | |||
| check_on_off $OPTARG H | |||
| @@ -474,7 +478,7 @@ build_mindspore() | |||
| CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_GPU=ON -DUSE_CUDA=ON -DCUDA_PATH=$CUDA_PATH -DMS_REQUIRE_CUDA_VERSION=${CUDA_VERSION}" | |||
| fi | |||
| if [[ "X$ENABLE_CPU" = "Xon" ]]; then | |||
| CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_CPU=ON -DX86_64_SIMD=${X86_64_SIMD}" | |||
| CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_CPU=ON -DX86_64_SIMD=${X86_64_SIMD} -DARM_SIMD=${ARM_SIMD}" | |||
| fi | |||
| if [[ "X$COMPILE_MINDDATA" = "Xon" ]]; then | |||
| CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_MINDDATA=ON" | |||
| @@ -61,6 +61,18 @@ if(ENABLE_CPU) | |||
| message("not compiled quantum kernel_compiler") | |||
| set(QUANTUM_SRC_LIST "") | |||
| endif() | |||
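    # Build the AdamWeightDecay kernel source with SIMD-specific flags when a
    # SIMD target was requested, and expose the matching ENABLE_* define to the code.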
| if("${ARM_SIMD}" STREQUAL "neon") | |||
| set(CPU_SIMD_SRC "${CMAKE_CURRENT_SOURCE_DIR}/cpu/adam_weight_decay_cpu_kernel.cc") | |||
| add_compile_definitions(ENABLE_NEON) | |||
| set_property(SOURCE ${CPU_SIMD_SRC} PROPERTY COMPILE_OPTIONS -O3 -ffast-math) | |||
| endif() | |||
| if("${X86_64_SIMD}" STREQUAL "avx") | |||
| set(CPU_SIMD_SRC "${CMAKE_CURRENT_SOURCE_DIR}/cpu/adam_weight_decay_cpu_kernel.cc") | |||
| add_compile_definitions(ENABLE_AVX512) | |||
| set_property(SOURCE ${CPU_SIMD_SRC} PROPERTY COMPILE_OPTIONS -O3 -fopenmp -mavx512f -ffast-math) | |||
| endif() | |||
| endif() | |||
| if(NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))) | |||
@@ -0,0 +1,146 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/kernel_compiler/cpu/adam_weight_decay_cpu_kernel.h"
#include <cmath>
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "utils/ms_utils.h"

namespace mindspore {
namespace kernel {
template <typename T>
void AdamWeightDecayCPUKernel::LaunchAdamWeightDecay(T *var, T *m, T *v, float lr, float beta1, float beta2,
                                                     float epsilon, T *decay, const T *gradient, size_t size) {
  float beta1_minus = 1 - beta1;
  float beta2_minus = 1 - beta2;
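  // Element-wise AdamWeightDecay update implemented below (scalar form):
  //   m   <- m + (1 - beta1) * (g - m)
  //   v   <- v + (1 - beta2) * (g * g - v)
  //   var <- var - lr * (m / (sqrt(v) + epsilon) + decay * var)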
#if defined(ENABLE_AVX512)
  MS_FLOAT32X16 beta1_16 = MS_MOV512_F32(beta1);
  MS_FLOAT32X16 beta2_16 = MS_MOV512_F32(beta2);
  MS_FLOAT32X16 beta1_minus_16 = MS_MOV512_F32(beta1_minus);
  MS_FLOAT32X16 beta2_minus_16 = MS_MOV512_F32(beta2_minus);
  MS_FLOAT32X16 lr_neg_16 = MS_MOV512_F32(-lr);
  MS_FLOAT32X16 epsilon_16 = MS_MOV512_F32(epsilon);
  MS_FLOAT32X16 decay_16 = MS_MOV512_F32(*decay);
#endif
#if defined(ENABLE_NEON)
  MS_FLOAT32X4 epsilon_4 = MS_MOVQ_F32(epsilon);
  float lr_neg = -lr;
#endif
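  // Each worker processes [start, end): full SIMD-width chunks first, then a scalar tail.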
| auto task = [&](size_t start, size_t end) { | |||
| size_t i = start; | |||
| #if defined(ENABLE_AVX512) | |||
| if (end >= MS_AVX512_WIDTH) { | |||
| for (; i <= end - MS_AVX512_WIDTH; i += MS_AVX512_WIDTH) { | |||
| MS_FLOAT32X16 var_16 = MS_LD512_F32(var + i); | |||
| MS_FLOAT32X16 m_16 = MS_LD512_F32(m + i); | |||
| MS_FLOAT32X16 v_16 = MS_LD512_F32(v + i); | |||
| MS_FLOAT32X16 g_16 = MS_LD512_F32(gradient + i); | |||
| m_16 = MS_MUL512_F32(m_16, beta1_16); | |||
| m_16 = MS_FMA512_F32(g_16, beta1_minus_16, m_16); | |||
| v_16 = MS_MUL512_F32(v_16, beta2_16); | |||
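        // Square the gradient into g_16 (it is reused as scratch from here on).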
        g_16 = MS_MUL512_F32(g_16, g_16);
        v_16 = MS_FMA512_F32(g_16, beta2_minus_16, v_16);
        g_16 = MS_SQRT512_F32(v_16);
        g_16 = MS_DIV512_F32(m_16, MS_ADD512_F32(g_16, epsilon_16));
        g_16 = MS_FMA512_F32(var_16, decay_16, g_16);
        var_16 = MS_FMA512_F32(g_16, lr_neg_16, var_16);
        MS_ST512_F32(var + i, var_16);
        MS_ST512_F32(m + i, m_16);
        MS_ST512_F32(v + i, v_16);
      }
    }
#endif
#if defined(ENABLE_NEON)
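    // Same update as the AVX-512 path above, using 4-lane NEON vectors.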
    if (end >= MS_NEON_WIDTH) {
      for (; i <= end - MS_NEON_WIDTH; i += MS_NEON_WIDTH) {
        MS_FLOAT32X4 var_4 = MS_LDQ_F32(var + i);
        MS_FLOAT32X4 m_4 = MS_LDQ_F32(m + i);
        MS_FLOAT32X4 v_4 = MS_LDQ_F32(v + i);
        MS_FLOAT32X4 g_4 = MS_LDQ_F32(gradient + i);
        m_4 = MS_MULQ_N_F32(m_4, beta1);
        m_4 = MS_MLAQ_N_F32(m_4, g_4, beta1_minus);
        v_4 = MS_MULQ_N_F32(v_4, beta2);
        g_4 = MS_MULQ_F32(g_4, g_4);
        v_4 = MS_MLAQ_N_F32(v_4, g_4, beta2_minus);
        g_4 = MS_SQRT_F32(v_4);
        g_4 = MS_DIVQ_F32(m_4, MS_ADDQ_F32(g_4, epsilon_4));
        g_4 = MS_MLAQ_N_F32(g_4, var_4, *decay);
        var_4 = MS_MLAQ_N_F32(var_4, g_4, lr_neg);
        MS_STQ_F32(var + i, var_4);
        MS_STQ_F32(m + i, m_4);
        MS_STQ_F32(v + i, v_4);
      }
    }
#endif
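    // Scalar fallback for the remaining elements that do not fill a full SIMD vector.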
    for (; i < end; i++) {
      m[i] += (gradient[i] - m[i]) * beta1_minus;
      v[i] += (gradient[i] * gradient[i] - v[i]) * beta2_minus;
      T update = m[i] / (std::sqrt(v[i]) + epsilon);
      update += decay[0] * var[i];
      var[i] -= lr * update;
    }
  };
  CPUKernelUtils::ParallelFor(task, size);
}

void AdamWeightDecayCPUKernel::InitKernel(const CNodePtr &kernel_node) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
  if (input_num != 9) {
    MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but AdamWeightDecay needs 9 inputs.";
  }
  size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
  if (output_num != 3) {
    MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but AdamWeightDecay needs 3 outputs.";
  }
}

bool AdamWeightDecayCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
                                      const std::vector<kernel::AddressPtr> & /*workspace*/,
                                      const std::vector<kernel::AddressPtr> &outputs) {
  if (inputs.size() != 9) {
    MS_LOG(EXCEPTION) << "Input number is " << inputs.size() << ", but AdamWeightDecay needs 9 inputs.";
  }
  if (outputs.size() != 3) {
    MS_LOG(EXCEPTION) << "Output number is " << outputs.size() << ", but AdamWeightDecay needs 3 outputs.";
  }
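  // Expected input layout: [0] var, [1] m, [2] v, [3] lr, [4] beta1, [5] beta2,
  // [6] epsilon, [7] decay, [8] gradient.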
  if (inputs[0]->size != inputs[1]->size || inputs[0]->size != inputs[2]->size || inputs[0]->size != inputs[8]->size) {
    MS_LOG(EXCEPTION) << "The sizes of var, m, v and gradient must be equal!";
  }
  size_t f_size = sizeof(float);
  if (inputs[3]->size != f_size || inputs[4]->size != f_size || inputs[5]->size != f_size ||
      inputs[6]->size != f_size || inputs[7]->size != f_size) {
    MS_LOG(EXCEPTION) << "The inputs lr, beta1, beta2, epsilon and decay must each be a single float!";
  }
  auto var = reinterpret_cast<float *>(inputs[0]->addr);
  auto m = reinterpret_cast<float *>(inputs[1]->addr);
  auto v = reinterpret_cast<float *>(inputs[2]->addr);
  float lr = reinterpret_cast<float *>(inputs[3]->addr)[0];
  float beta1 = reinterpret_cast<float *>(inputs[4]->addr)[0];
  float beta2 = reinterpret_cast<float *>(inputs[5]->addr)[0];
  float epsilon = reinterpret_cast<float *>(inputs[6]->addr)[0];
  auto decay = reinterpret_cast<float *>(inputs[7]->addr);
  auto gradient = reinterpret_cast<float *>(inputs[8]->addr);
  // Split the element-wise update across threads.
  size_t lens = inputs[0]->size > 0 ? static_cast<size_t>(inputs[0]->size / sizeof(float)) : 1;
  LaunchAdamWeightDecay<float>(var, m, v, lr, beta1, beta2, epsilon, decay, gradient, lens);
  return true;
}
}  // namespace kernel
}  // namespace mindspore
@@ -0,0 +1,97 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ADAM_WEIGHT_DECAY_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ADAM_WEIGHT_DECAY_CPU_KERNEL_H_

#include <vector>
#include <memory>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
#ifdef ENABLE_NEON
#include <arm_neon.h>
#endif
#if defined(ENABLE_AVX512)
#include <x86intrin.h>
#endif
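// The MS_* macros below are thin aliases over the NEON / AVX-512 float intrinsics,
// letting the kernel body use one vocabulary for both instruction sets.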
#ifdef ENABLE_NEON
#define MS_FLOAT32X4 float32x4_t
#define MS_LDQ_F32 vld1q_f32
#define MS_MOVQ_F32 vmovq_n_f32
#define MS_STQ_F32 vst1q_f32
#define MS_ADDQ_F32(src1, src2) vaddq_f32(src1, src2)
#define MS_MULQ_F32(src1, src2) vmulq_f32(src1, src2)
#define MS_MULQ_N_F32(src1, src2) vmulq_n_f32(src1, src2)
#define MS_DIVQ_F32(src1, src2) vdivq_f32(src1, src2)
#define MS_MLAQ_F32(src1, src2, src3) vmlaq_f32(src1, src2, src3)
#define MS_MLAQ_N_F32(src1, src2, src3) vmlaq_n_f32(src1, src2, src3)
#define MS_SQRT_F32(src) vsqrtq_f32(src)
#define MS_CAST_F32_F16(src) vreinterpretq_f32_f16(src)
#define MS_NEON_WIDTH 4
#endif

#if defined(ENABLE_AVX512)
#define MS_FLOAT32X16 __m512
#define MS_LD512_F32 _mm512_loadu_ps
#define MS_ST512_F32 _mm512_storeu_ps
#define MS_MOV512_F32 _mm512_set1_ps
#define MS_ADD512_F32(src1, src2) _mm512_add_ps(src1, src2)
#define MS_MUL512_F32(src1, src2) _mm512_mul_ps(src1, src2)
#define MS_DIV512_F32(src1, src2) _mm512_div_ps(src1, src2)
#define MS_FMA512_F32(src1, src2, src3) _mm512_fmadd_ps(src1, src2, src3)
#define MS_SQRT512_F32(src) _mm512_sqrt_ps(src)
#define MS_CAST512_F32_S32(src) _mm512_castsi512_ps(src)
#define MS_AVX512_WIDTH 16
#endif

namespace mindspore {
namespace kernel {
class AdamWeightDecayCPUKernel : public CPUKernel {
 public:
  AdamWeightDecayCPUKernel() = default;
  ~AdamWeightDecayCPUKernel() override = default;

  template <typename T>
  void LaunchAdamWeightDecay(T *var, T *m, T *v, float lr, float beta1, float beta2, float epsilon, T *decay,
                             const T *gradient, size_t size);
  void InitKernel(const CNodePtr &kernel_node) override;
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;
};

MS_REG_CPU_KERNEL(AdamWeightDecay,
                  KernelAttr()
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddOutputAttr(kNumberTypeFloat32)
                    .AddOutputAttr(kNumberTypeFloat32)
                    .AddOutputAttr(kNumberTypeFloat32),
                  AdamWeightDecayCPUKernel)
}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ADAM_WEIGHT_DECAY_CPU_KERNEL_H_
@@ -495,7 +495,7 @@ class AdamWeightDecay(PrimitiveWithInfer):
        - **v** (Tensor) - The same shape and data type as the input `v`.

    Supported Platforms:
-        ``GPU``
+        ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
@@ -0,0 +1,162 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""AdamWeightDecay, a customized Adam optimizer with weight decay for pangu1. Input: gradient."""
import numpy as np

from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.common.tensor import Tensor
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from mindspore.nn.optim.optimizer import Optimizer

_adam_opt = C.MultitypeFuncGraph("adam_opt")
_scaler_one = Tensor(1, mstype.int32)
_scaler_ten = Tensor(10, mstype.float32)
@_adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
                    "Tensor", "Bool", "Bool")
def _update_run_kernel(opt, beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, decay_flags, optim_filter):
    """Update parameters by the fused AdamWeightDecay op."""
    if optim_filter:
        op_cast = P.Cast()
        gradient_fp32 = op_cast(gradient, mstype.float32)
        if decay_flags:
            next_param = opt(param, m, v, lr, beta1, beta2, eps, F.cast(weight_decay, mstype.float32), gradient_fp32)
        else:
            next_param = opt(param, m, v, lr, beta1, beta2, eps, F.cast(0.0, mstype.float32), gradient_fp32)
        return next_param
    return gradient


def _check_param_value(beta1, beta2, eps, prim_name):
    """Check the types and ranges of beta1, beta2 and eps."""
    validator.check_value_type("beta1", beta1, [float], prim_name)
    validator.check_value_type("beta2", beta2, [float], prim_name)
    validator.check_value_type("eps", eps, [float], prim_name)
    validator.check_float_range(beta1, 0.0, 1.0, Rel.INC_NEITHER, "beta1", prim_name)
    validator.check_float_range(beta2, 0.0, 1.0, Rel.INC_NEITHER, "beta2", prim_name)
    validator.check_positive_float(eps, "eps", prim_name)
class AdamWeightDecayOp(Optimizer):
    """
    Implements the Adam algorithm with weight decay. It is a single fused operator, not a combination of other ops.

    Note:
        When separating parameter groups, the weight decay in each group will be applied on the parameters if the
        weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
        on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.

        To improve the performance of parameter groups, a customized order of parameters is supported.

    Args:
        params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
            the element in `params` must be class `Parameter`. When the `params` is a list of `dict`, the keys
            "params", "lr", "weight_decay" and "order_params" can be parsed.

            - params: Required. The value must be a list of `Parameter`.

            - lr: Optional. If "lr" is in the keys, the value of the corresponding learning rate will be used.
              If not, the `learning_rate` in the API will be used.

            - weight_decay: Optional. If "weight_decay" is in the keys, the value of the corresponding weight decay
              will be used. If not, the `weight_decay` in the API will be used.

            - order_params: Optional. If "order_params" is in the keys, the value must be the order of parameters and
              this order will be followed in the optimizer. There are no other keys in this `dict`, and the
              parameters listed in 'order_params' must be in one of the parameter groups.

        learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning
            rate. When the learning_rate is an Iterable or a one-dimensional Tensor, a dynamic learning rate is used,
            and the i-th step will take the i-th value as the learning rate. When the learning_rate is a
            LearningRateSchedule, a dynamic learning rate is used, and the i-th learning rate will be calculated
            during training according to the formula of the LearningRateSchedule. When the learning_rate is a float
            or a zero-dimensional Tensor, a fixed learning rate is used. Other cases are not supported. The float
            learning rate must be equal to or greater than 0. If the type of `learning_rate` is int, it will be
            converted to float. Default: 1e-3.
        beta1 (float): The exponential decay rate for the 1st moment estimates. Default: 0.9.
            Should be in range (0.0, 1.0).
        beta2 (float): The exponential decay rate for the 2nd moment estimates. Default: 0.999.
            Should be in range (0.0, 1.0).
        eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6.
            Should be greater than 0.
        weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0.

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of `params`, with the same shape as `params`.

    Outputs:
        tuple[bool], all elements are True.

    Supported Platforms:
        ``CPU``

    Examples:
        >>> net = Net()
        >>> #1) All parameters use the same learning rate and weight decay
        >>> optim = AdamWeightDecayOp(params=net.trainable_params())
        >>>
        >>> #2) Use parameter groups and set different values
        >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
        >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
        >>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
        ...                 {'params': no_conv_params, 'lr': 0.01},
        ...                 {'order_params': net.trainable_params()}]
        >>> optim = AdamWeightDecayOp(group_params, learning_rate=0.1, weight_decay=0.0)
        >>> # The conv_params group will use the default learning rate of 0.1 and its own weight decay of 0.01.
        >>> # The no_conv_params group will use a learning rate of 0.01 and the default weight decay of 0.0.
        >>> # The optimizer will follow the parameter order given by 'order_params'.
        >>>
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> model = Model(net, loss_fn=loss, optimizer=optim)
    """
    def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0):
        super(AdamWeightDecayOp, self).__init__(learning_rate, params, weight_decay)
        _check_param_value(beta1, beta2, eps, self.cls_name)
        self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
        self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
        self.eps = Tensor(np.array([eps]).astype(np.float32))
        self.moments1 = self.parameters.clone(prefix="adam_m", init='zeros')
        self.moments2 = self.parameters.clone(prefix="adam_v", init='zeros')
        self.hyper_map = C.HyperMap()
        self.opt = P.AdamWeightDecay()
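        # Pin the fused primitive to CPU so it runs there even when the rest of
        # the network targets another device.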
        self.opt.add_prim_attr("primitive_target", "CPU")

    def construct(self, gradients):
        """AdamWeightDecayOp"""
        lr = self.get_lr()
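        # Three dispatch paths: per-group learning rate, per-group weight decay
        # with a shared lr, or fully shared hyper-parameters.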
        if self.is_group:
            if self.is_group_lr:
                optim_result = self.map_(F.partial(_adam_opt, self.opt, self.beta1, self.beta2, self.eps),
                                         lr, self.weight_decay, self.parameters, self.moments1, self.moments2,
                                         gradients, self.decay_flags, self.optim_filter)
            else:
                optim_result = self.map_(F.partial(_adam_opt, self.opt, self.beta1, self.beta2, self.eps, lr),
                                         self.weight_decay, self.parameters, self.moments1, self.moments2,
                                         gradients, self.decay_flags, self.optim_filter)
        else:
            optim_result = self.map_(F.partial(_adam_opt, self.opt, self.beta1, self.beta2, self.eps, lr,
                                               self.weight_decay),
                                     self.parameters, self.moments1, self.moments2, gradients, self.decay_flags,
                                     self.optim_filter)
        if self.use_parallel:
            self.broadcast_params(optim_result)
        return optim_result
@@ -0,0 +1,66 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.nn import Dense
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.ops import operations as P
from model_zoo.official.nlp.gpt.src.adam import AdamWeightDecayOp

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
class NetAdamWeightDecay(nn.Cell):
    def __init__(self):
        super(NetAdamWeightDecay, self).__init__()
        self.batch_size = 1
        self.reshape = P.Reshape()
        weight = Tensor(np.ones([10, 16]).astype(np.float32) * 0.01)
        self.fc1 = Dense(16, 10, weight_init=weight)

    def construct(self, input_x):
        output = self.reshape(input_x, (self.batch_size, -1))
        output = self.fc1(output)
        return output


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adam_weight_decay():
    epoch = 3
    net = NetAdamWeightDecay()
    optimizer = AdamWeightDecayOp(filter(lambda x: x.requires_grad,
                                         net.get_parameters()), learning_rate=0.01)
    criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    net_with_criterion = WithLossCell(net, criterion)
    train_network = TrainOneStepCell(net_with_criterion, optimizer)
    train_network.set_train()

    losses1 = []
    for _ in range(epoch):
        data = Tensor(np.arange(0, 16).reshape(1, 1, 4, 4).astype(np.float32) * 0.01)
        label = Tensor(np.array([0]).astype(np.int32))
        loss = train_network(data, label)
        losses1.append(loss.asnumpy())
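    # With a fixed input and label, each optimizer step should strictly reduce the loss.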
    assert losses1[0] > losses1[1]
    assert losses1[1] > losses1[2]