@@ -75,9 +75,9 @@ void MomentumUpdateVariable(const size_t size, T *variable, T *accumulation, con
 }
 template <typename T, typename S>
-__global__ void FusedMomentumWeightDecayScaleMomentum(const size_t element_num, T *weight_decay, T *scale, T *variable,
-                                                      T *accumulation, const T *learning_rate, const S *gradient,
-                                                      const T *momentum) {
+__global__ void FusedMomentumWeightDecayScaleKernel(const size_t element_num, T *weight_decay, T *scale, T *variable,
+                                                    T *accumulation, const T *learning_rate, const S *gradient,
+                                                    const T *momentum) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (element_num); i += blockDim.x * gridDim.x) {
     T grad = (variable[i] * weight_decay[0] + static_cast<T>(gradient[i])) * scale[0];
     accumulation[i] = momentum[0] * accumulation[i] + grad;
@@ -91,13 +91,13 @@ void FusedWeightDecayScaleMomentum(const size_t element_num, T *weight_decay, T
                                    cudaStream_t cuda_stream) {
   size_t thread_per_block = 256;
   size_t block_per_grid = (element_num + thread_per_block - 1) / thread_per_block;
-  FusedMomentumWeightDecayScaleMomentum<<<block_per_grid, thread_per_block, 0, cuda_stream>>>(
+  FusedMomentumWeightDecayScaleKernel<<<block_per_grid, thread_per_block, 0, cuda_stream>>>(
     element_num, weight_decay, scale, variable, accumulation, learning_rate, gradient, momentum);
 }
 template <typename T, typename S>
-__global__ void FusedMomentumScaleMomentum(const size_t element_num, T *scale, T *variable, T *accumulation,
-                                           const T *learning_rate, const S *gradient, const T *momentum) {
+__global__ void FusedMomentumScaleKernel(const size_t element_num, T *scale, T *variable, T *accumulation,
+                                         const T *learning_rate, const S *gradient, const T *momentum) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (element_num); i += blockDim.x * gridDim.x) {
     accumulation[i] = momentum[0] * accumulation[i] + static_cast<T>(gradient[i]) * scale[0];
     variable[i] -= learning_rate[0] * accumulation[i];
@@ -109,15 +109,33 @@ void FusedScaleMomentum(const size_t element_num, T *scale, T *variable, T *accu
                         const S *gradient, const T *momentum, cudaStream_t cuda_stream) {
   size_t thread_per_block = 256;
   size_t block_per_grid = (element_num + thread_per_block - 1) / thread_per_block;
-  FusedMomentumScaleMomentum<<<block_per_grid, thread_per_block, 0, cuda_stream>>>(
+  FusedMomentumScaleKernel<<<block_per_grid, thread_per_block, 0, cuda_stream>>>(
     element_num, scale, variable, accumulation, learning_rate, gradient, momentum);
 }
+template <typename T, typename S>
+__global__ void FusedWeightDecayMomentumKernel(const size_t element_num, T *weight_decay, T *variable,
+                                               T *accumulation, const T *learning_rate, const S *gradient,
+                                               const T *momentum) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (element_num); i += blockDim.x * gridDim.x) {
+    T grad = variable[i] * weight_decay[0] + static_cast<T>(gradient[i]);
+    accumulation[i] = momentum[0] * accumulation[i] + grad;
+    variable[i] -= learning_rate[0] * accumulation[i];
+  }
+}
+template <typename T, typename S>
+void FusedWeightDecayMomentum(const size_t element_num, T *weight_decay, T *variable, T *accumulation,
+                              const T *learning_rate, const S *gradient, const T *momentum,
+                              cudaStream_t cuda_stream) {
+  size_t thread_per_block = 256;
+  size_t block_per_grid = (element_num + thread_per_block - 1) / thread_per_block;
+  FusedWeightDecayMomentumKernel<<<block_per_grid, thread_per_block, 0, cuda_stream>>>(
+    element_num, weight_decay, variable, accumulation, learning_rate, gradient, momentum);
+}
 // CombineFusedScaleMomentum
 template <typename T, typename S>
-__global__ void CombineFusedMomentumScaleMomentum(const size_t num, const size_t *element_num,
-                                                  T **scale, T **variable, T **accumulation,
-                                                  T **learning_rate, S **gradient, T **momentum) {
+__global__ void CombineFusedMomentumScaleKernel(const size_t num, const size_t *element_num, T **scale, T **variable,
+                                                T **accumulation, T **learning_rate, S **gradient, T **momentum) {
   for (size_t idx = 0; idx < num; idx++) {
     for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (element_num[idx]); i += blockDim.x * gridDim.x) {
       accumulation[idx][i] = momentum[idx][0] * accumulation[idx][i] + static_cast<T>(gradient[idx][i]) * scale[idx][0];
@@ -127,22 +145,21 @@ __global__ void CombineFusedMomentumScaleMomentum(const size_t num, const size_t
 }
 template <typename T, typename S>
-void CombineFusedScaleMomentum(const size_t max, const size_t num, const size_t *elements, T **scale,
-                               T **variable, T **accumulation, T **learning_rate, S **gradient,
-                               T **momentum, cudaStream_t cuda_stream) {
+void CombineFusedScaleMomentum(const size_t max, const size_t num, const size_t *elements, T **scale, T **variable,
+                               T **accumulation, T **learning_rate, S **gradient, T **momentum,
+                               cudaStream_t cuda_stream) {
   size_t thread_per_block = 256;
   size_t block_per_grid = (max + thread_per_block - 1) / thread_per_block;
-  CombineFusedMomentumScaleMomentum<<<block_per_grid, thread_per_block, 0, cuda_stream>>>(
+  CombineFusedMomentumScaleKernel<<<block_per_grid, thread_per_block, 0, cuda_stream>>>(
     num, elements, scale, variable, accumulation, learning_rate, gradient, momentum);
 }
 // end CombineFusedScaleMomentum
 // CombineFusedWeightDecayScaleMomentum
 template <typename T, typename S>
-__global__ void CombineFusedMomentumWeightDecayScaleMomentum(const size_t num, const size_t *element_num,
-                                                             T **weight_decay, T **scale, T **variable,
-                                                             T **accumulation, T **learning_rate, S **gradient,
-                                                             T **momentum) {
+__global__ void CombineFusedMomentumWeightDecayScaleKernel(const size_t num, const size_t *element_num,
+                                                           T **weight_decay, T **scale, T **variable, T **accumulation,
+                                                           T **learning_rate, S **gradient, T **momentum) {
   for (size_t idx = 0; idx < num; idx++) {
     for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (element_num[idx]); i += blockDim.x * gridDim.x) {
       T grad = (variable[idx][i] * weight_decay[idx][0] + static_cast<T>(gradient[idx][i])) * scale[idx][0];
@@ -155,11 +172,10 @@ __global__ void CombineFusedMomentumWeightDecayScaleMomentum(const size_t num, c
 template <typename T, typename S>
 void CombineFusedWeightDecayScaleMomentum(const size_t max, const size_t num, const size_t *element_num,
                                           T **weight_decay, T **scale, T **variable, T **accumulation,
-                                          T **learning_rate, S **gradient, T **momentum,
-                                          cudaStream_t cuda_stream) {
+                                          T **learning_rate, S **gradient, T **momentum, cudaStream_t cuda_stream) {
   size_t thread_per_block = 256;
   size_t block_per_grid = (max + thread_per_block - 1) / thread_per_block;
-  CombineFusedMomentumWeightDecayScaleMomentum<<<block_per_grid, thread_per_block, 0, cuda_stream>>>(
+  CombineFusedMomentumWeightDecayScaleKernel<<<block_per_grid, thread_per_block, 0, cuda_stream>>>(
     num, element_num, weight_decay, scale, variable, accumulation, learning_rate, gradient, momentum);
 }
 // end CombineFusedWeightDecayScaleMomentum
@@ -186,6 +202,12 @@ template void FusedWeightDecayScaleMomentum(const size_t element_num, float *wei
 template void FusedWeightDecayScaleMomentum(const size_t element_num, float *weight_decay, float *scale,
                                             float *variable, float *accumulation, const float *learning_rate,
                                             const half *gradient, const float *momentum, cudaStream_t cuda_stream);
+template void FusedWeightDecayMomentum(const size_t element_num, float *weight_decay, float *variable,
+                                       float *accumulation, const float *learning_rate, const float *gradient,
+                                       const float *momentum, cudaStream_t cuda_stream);
+template void FusedWeightDecayMomentum(const size_t element_num, float *weight_decay, float *variable,
+                                       float *accumulation, const float *learning_rate, const half *gradient,
+                                       const float *momentum, cudaStream_t cuda_stream);
 template void FusedScaleMomentum(const size_t element_num, float *scale, float *variable, float *accumulation,
                                  const float *learning_rate, const float *gradient, const float *momentum,
                                  cudaStream_t cuda_stream);
@@ -193,16 +215,16 @@ template void FusedScaleMomentum(const size_t element_num, float *scale, float *
                                  const float *learning_rate, const half *gradient, const float *momentum,
                                  cudaStream_t cuda_stream);
 template void CombineFusedWeightDecayScaleMomentum(const size_t max, const size_t num, const size_t *elements,
-                                                   float **weight_decay, float **scale, float **variable,
-                                                   float **accumulation, float **learning_rate, float **gradient,
-                                                   float **momentum, cudaStream_t cuda_stream);
+                                                   float **weight_decay, float **scale, float **variable,
+                                                   float **accumulation, float **learning_rate, float **gradient,
+                                                   float **momentum, cudaStream_t cuda_stream);
 template void CombineFusedWeightDecayScaleMomentum(const size_t max, const size_t num, const size_t *elements,
-                                                   float **weight_decay, float **scale, float **variable,
-                                                   float **accumulation, float **learning_rate, half **gradient,
-                                                   float **momentum, cudaStream_t cuda_stream);
+                                                   float **weight_decay, float **scale, float **variable,
+                                                   float **accumulation, float **learning_rate, half **gradient,
+                                                   float **momentum, cudaStream_t cuda_stream);
 template void CombineFusedScaleMomentum(const size_t max, const size_t num, const size_t *elements, float **scale,
-                                        float **variable, float **accumulation, float **learning_rate,
-                                        float **gradient, float **momentum, cudaStream_t cuda_stream);
+                                        float **variable, float **accumulation, float **learning_rate, float **gradient,
+                                        float **momentum, cudaStream_t cuda_stream);
 template void CombineFusedScaleMomentum(const size_t max, const size_t num, const size_t *elements, float **scale,
-                                        float **variable, float **accumulation, float **learning_rate,
-                                        half **gradient, float **momentum, cudaStream_t cuda_stream);
+                                        float **variable, float **accumulation, float **learning_rate, half **gradient,
+                                        float **momentum, cudaStream_t cuda_stream);
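The new FusedWeightDecayMomentum path added above folds weight decay into the momentum step in a single pass: grad = var * weight_decay + gradient, accum = momentum * accum + grad, var -= lr * accum. The snippet below is a plain CPU restatement of that per-element update; it is not part of the patch, and the helper name is illustrative. It is only a sketch for cross-checking kernel output on small inputs.

#include <cstddef>

// CPU reference for the fused weight-decay + momentum update performed by
// FusedWeightDecayMomentumKernel (hypothetical helper, for verification only).
template <typename T, typename S>
void FusedWeightDecayMomentumRef(size_t n, T weight_decay, T *variable, T *accumulation,
                                 T learning_rate, const S *gradient, T momentum) {
  for (size_t i = 0; i < n; ++i) {
    T grad = variable[i] * weight_decay + static_cast<T>(gradient[i]);  // weight decay folded into the gradient
    accumulation[i] = momentum * accumulation[i] + grad;                // momentum accumulation
    variable[i] -= learning_rate * accumulation[i];                     // parameter update
  }
}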
@@ -26,6 +26,9 @@ void FusedWeightDecayScaleMomentum(const size_t element_num, T *weight_decay, T
                                    const T *learning_rate, const S *gradient, const T *momentum,
                                    cudaStream_t cuda_stream);
 template <typename T, typename S>
+void FusedWeightDecayMomentum(const size_t element_num, T *weight_decay, T *variable, T *accumulation,
+                              const T *learning_rate, const S *gradient, const T *momentum, cudaStream_t cuda_stream);
+template <typename T, typename S>
 void FusedScaleMomentum(const size_t element_num, T *scale, T *variable, T *accumulation, const T *learning_rate,
                         const S *gradient, const T *momentum, cudaStream_t cuda_stream);
 template <typename T, typename S>
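The declaration above is what the GPU kernel class calls at launch time. A point worth noting about the calling convention: every argument, including the scalar hyper-parameters, is a device pointer, since the kernel dereferences weight_decay[0], learning_rate[0] and momentum[0] on the GPU. The following stand-alone driver is a minimal sketch of assumed usage, not part of the patch; the function name is illustrative.

#include <cuda_runtime.h>
#include "backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh"

// Illustrative host-side driver (hypothetical, not part of the patch): uploads the
// scalars to device memory and runs one fused pass over n elements.
void LaunchFusedWeightDecayMomentumExample(size_t n, const float *h_grad, float *h_var, float *h_accum,
                                           float weight_decay, float lr, float momentum) {
  float *d_var, *d_accum, *d_grad, *d_wd, *d_lr, *d_mom;
  cudaMalloc(&d_var, n * sizeof(float));
  cudaMalloc(&d_accum, n * sizeof(float));
  cudaMalloc(&d_grad, n * sizeof(float));
  cudaMalloc(&d_wd, sizeof(float));
  cudaMalloc(&d_lr, sizeof(float));
  cudaMalloc(&d_mom, sizeof(float));
  cudaMemcpy(d_var, h_var, n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_accum, h_accum, n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_grad, h_grad, n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_wd, &weight_decay, sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_lr, &lr, sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_mom, &momentum, sizeof(float), cudaMemcpyHostToDevice);
  // Single fused update on the default stream; argument order matches the declaration above.
  FusedWeightDecayMomentum(n, d_wd, d_var, d_accum, d_lr, d_grad, d_mom, nullptr);
  cudaMemcpy(h_var, d_var, n * sizeof(float), cudaMemcpyDeviceToHost);
  cudaMemcpy(h_accum, d_accum, n * sizeof(float), cudaMemcpyDeviceToHost);
  cudaFree(d_var);
  cudaFree(d_accum);
  cudaFree(d_grad);
  cudaFree(d_wd);
  cudaFree(d_lr);
  cudaFree(d_mom);
}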
@@ -0,0 +1,43 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/nn/fused_weightdecay_momentum_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_TWO(FusedWeightApplyMomentum,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat32)  // weight decay
+                        .AddInputAttr(kNumberTypeFloat32)  // variable
+                        .AddInputAttr(kNumberTypeFloat32)  // accumulation
+                        .AddInputAttr(kNumberTypeFloat32)  // learning_rate
+                        .AddInputAttr(kNumberTypeFloat32)  // gradient
+                        .AddInputAttr(kNumberTypeFloat32)  // momentum
+                        .AddOutputAttr(kNumberTypeFloat32),
+                      FusedWeightDecayMomentumGpuKernel, float, float)
+MS_REG_GPU_KERNEL_TWO(FusedWeightApplyMomentum,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat32)  // weight decay
+                        .AddInputAttr(kNumberTypeFloat32)  // variable
+                        .AddInputAttr(kNumberTypeFloat32)  // accumulation
+                        .AddInputAttr(kNumberTypeFloat32)  // learning_rate
+                        .AddInputAttr(kNumberTypeFloat16)  // gradient
+                        .AddInputAttr(kNumberTypeFloat32)  // momentum
+                        .AddOutputAttr(kNumberTypeFloat32),
+                      FusedWeightDecayMomentumGpuKernel, float, half)
+}  // namespace kernel
+}  // namespace mindspore
@@ -0,0 +1,85 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_WEIGHTDECAY_MOMENTUM_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_WEIGHTDECAY_MOMENTUM_GPU_KERNEL_H_
+
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh"
+
+namespace mindspore {
+namespace kernel {
+template <typename T, typename S>
+class FusedWeightDecayMomentumGpuKernel : public GpuKernel {
+ public:
+  FusedWeightDecayMomentumGpuKernel() : element_num_(1) {}
+  ~FusedWeightDecayMomentumGpuKernel() override = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
+              void *stream_ptr) override {
+    T *weight_decay = GetDeviceAddress<T>(inputs, 0);
+    T *variable = GetDeviceAddress<T>(inputs, 1);
+    T *accumulation = GetDeviceAddress<T>(inputs, 2);
+    T *learning_rate = GetDeviceAddress<T>(inputs, 3);
+    S *gradient = GetDeviceAddress<S>(inputs, 4);
+    T *momentum = GetDeviceAddress<T>(inputs, 5);
+
+    FusedWeightDecayMomentum(element_num_, weight_decay, variable, accumulation, learning_rate, gradient, momentum,
+                             reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 6) {
+      MS_LOG(ERROR) << "Input number is " << input_num << ", but FusedMomentum needs 6 inputs.";
+      return false;
+    }
+
+    auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
+    for (size_t i = 0; i < variable_shape.size(); i++) {
+      element_num_ *= variable_shape[i];
+    }
+
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(sizeof(T));                  // weight decay (scalar)
+    input_size_list_.push_back(element_num_ * sizeof(T));   // variable
+    input_size_list_.push_back(element_num_ * sizeof(T));   // accumulation
+    input_size_list_.push_back(sizeof(T));                  // learning_rate (scalar)
+    input_size_list_.push_back(element_num_ * sizeof(S));   // gradient
+    input_size_list_.push_back(sizeof(T));                  // momentum (scalar)
+    output_size_list_.push_back(element_num_ * sizeof(T));  // updated variable
+  }
+
+ private:
+  size_t element_num_;
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_WEIGHTDECAY_MOMENTUM_GPU_KERNEL_H_
@@ -0,0 +1,90 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/optimizer/gpu/apply_momentum_weight_fusion.h"
+
+#include <memory>
+#include <vector>
+#include <string>
+
+#include "backend/session/anf_runtime_algorithm.h"
+#include "ir/primitive.h"
+#include "utils/utils.h"
+#include "backend/optimizer/common/helper.h"
+
+namespace mindspore {
+namespace opt {
+bool ApplyMomentumWeightDecayFusion::IsScalar(const BaseRef &n) {
+  if (utils::isa<AnfNodePtr>(n)) {
+    AnfNodePtr in = utils::cast<AnfNodePtr>(n);
+    MS_EXCEPTION_IF_NULL(in);
+    auto shape = in->Shape()->cast<abstract::ShapePtr>();
+    MS_EXCEPTION_IF_NULL(shape);
+    if (shape->shape().size() != 0) {
+      return false;
+    }
+    auto dtype = in->Type();
+    if (dtype->type_id() != kObjectTypeTensorType) {
+      return false;
+    }
+    auto element_type = dyn_cast<TensorType>(dtype)->element()->type_id();
+    if (element_type != kNumberTypeFloat32) {
+      return false;
+    }
+    return true;
+  }
+  return false;
+}
+
+const BaseRef ApplyMomentumWeightDecayFusion::DefinePattern() const {
+  VectorRef weight_decay =
+    VectorRef({prim::kPrimAddN, VectorRef({prim::kPrimMul, variable_, weight_decay_}), gradient_});
+  VectorRef apply_momentum =
+    VectorRef({prim::kPrimApplyMomentum, variable_, accumulation_, learning_rate_, weight_decay, momentum_});
+  return apply_momentum;
+}
+
+const AnfNodePtr ApplyMomentumWeightDecayFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
+                                                         const EquivPtr &equiv) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(equiv);
+
+  auto weight_decay = utils::cast<AnfNodePtr>((*equiv)[weight_decay_]);
+  auto variable = utils::cast<AnfNodePtr>((*equiv)[variable_]);
+  auto accumulation = utils::cast<AnfNodePtr>((*equiv)[accumulation_]);
+  auto learning_rate = utils::cast<AnfNodePtr>((*equiv)[learning_rate_]);
+  auto gradient = utils::cast<AnfNodePtr>((*equiv)[gradient_]);
+  auto momentum = utils::cast<AnfNodePtr>((*equiv)[momentum_]);
+  MS_EXCEPTION_IF_NULL(weight_decay);
+  MS_EXCEPTION_IF_NULL(variable);
+  MS_EXCEPTION_IF_NULL(accumulation);
+  MS_EXCEPTION_IF_NULL(learning_rate);
+  MS_EXCEPTION_IF_NULL(gradient);
+  MS_EXCEPTION_IF_NULL(momentum);
+
+  auto prim = std::make_shared<Primitive>(kFusedWeightApplyMomentum);
+  MS_EXCEPTION_IF_NULL(prim);
+  std::vector<AnfNodePtr> inputs = {NewValueNode(prim), weight_decay, variable, accumulation,
+                                    learning_rate, gradient, momentum};
+  auto replace_node = graph->NewCNode(inputs);
+  MS_EXCEPTION_IF_NULL(replace_node);
+  auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};
+  auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)};
+  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, replace_node.get());
+  replace_node->set_scope(node->scope());
+  return replace_node;
+}
+}  // namespace opt
+}  // namespace mindspore
@@ -0,0 +1,51 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_APPLY_MOMENTUM_WEIGHT_DECAY_FUSION_H_
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_APPLY_MOMENTUM_WEIGHT_DECAY_FUSION_H_
+
+#include <memory>
+#include "backend/optimizer/common/optimizer.h"
+
+namespace mindspore {
+namespace opt {
+class ApplyMomentumWeightDecayFusion : public PatternProcessPass {
+ public:
+  explicit ApplyMomentumWeightDecayFusion(bool multigraph = true)
+      : PatternProcessPass("momentum_weightdecay_fusion", multigraph) {
+    weight_decay_ = std::make_shared<Var>();
+    variable_ = std::make_shared<Var>();
+    accumulation_ = std::make_shared<Var>();
+    learning_rate_ = std::make_shared<Var>();
+    gradient_ = std::make_shared<Var>();
+    momentum_ = std::make_shared<Var>();
+  }
+  ~ApplyMomentumWeightDecayFusion() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+
+ private:
+  static bool IsScalar(const BaseRef &n);
+
+  VarPtr weight_decay_;
+  VarPtr variable_;
+  VarPtr accumulation_;
+  VarPtr learning_rate_;
+  VarPtr gradient_;
+  VarPtr momentum_;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_APPLY_MOMENTUM_WEIGHT_DECAY_FUSION_H_
@@ -22,6 +22,7 @@
 #include "backend/optimizer/gpu/adam_fusion.h"
 #include "backend/optimizer/gpu/apply_momentum_weight_scale_fusion.h"
 #include "backend/optimizer/gpu/apply_momentum_scale_fusion.h"
+#include "backend/optimizer/gpu/apply_momentum_weight_fusion.h"
 #include "backend/optimizer/gpu/batch_norm_relu_fusion.h"
 #include "backend/optimizer/gpu/batch_norm_relu_grad_fusion.h"
 #include "backend/optimizer/gpu/batch_norm_add_relu_fusion.h"
@@ -125,6 +126,7 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
   pm->AddPass(std::make_shared<opt::AdamFusion>());
   pm->AddPass(std::make_shared<opt::ApplyMomentumWeightDecayScaleFusion>());
   pm->AddPass(std::make_shared<opt::ApplyMomentumScaleFusion>());
+  pm->AddPass(std::make_shared<opt::ApplyMomentumWeightDecayFusion>());
   if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) {
     pm->AddPass(std::make_shared<opt::CastAllFusion>("cast_all"));
   }
@@ -234,6 +234,7 @@ constexpr auto kReduceMeanOpName = "ReduceMean";
 constexpr auto kReduceAnyOpName = "ReduceAny";
 constexpr auto kReduceAllOpName = "ReduceAll";
 constexpr auto kFusedWeightScaleApplyMomentum = "FusedWeightScaleApplyMomentum";
+constexpr auto kFusedWeightApplyMomentum = "FusedWeightApplyMomentum";
 constexpr auto kFusedScaleApplyMomentum = "FusedScaleApplyMomentum";
 constexpr auto kBasicLSTMCellWeightGradOpName = "BasicLSTMCellWeightGrad";
 constexpr auto kBasicLSTMCellInputGradOpName = "BasicLSTMCellInputGrad";
@@ -0,0 +1,61 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore.common.parameter import Parameter
+from mindspore import Tensor
+from mindspore.ops import operations as P
+
+
+class MomentumFusionNet(nn.Cell):
+    def __init__(self, var, accum):
+        super(MomentumFusionNet, self).__init__()
+        self.op = P.ApplyMomentum()
+        self.add = P.AddN()
+        self.mul = P.Mul()
+        self.var = Parameter(var, name="variable")
+        self.accum = Parameter(accum, name="accumulate")
+        self.lr = 0.1
+        self.weight_decay = 0.002
+        self.moment = 0.98
+
+    def construct(self, grad):
+        wd = self.mul(self.var, self.weight_decay)
+        g = self.add((wd, grad))
+        return self.op(self.var, self.accum, self.lr, g, self.moment)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_momentum_fusion():
+    np.random.seed(42)
+    var = Tensor(np.random.randn(10, 20).astype(np.float32))
+    accum = Tensor(np.random.randn(10, 20).astype(np.float32))
+    grad = Tensor(np.random.randn(10, 20).astype(np.float32))
+
+    context.set_context(device_target='GPU', mode=context.GRAPH_MODE)
+    net1 = MomentumFusionNet(var, accum)
+    _ = net1(grad)
+
+    context.set_context(device_target='GPU', mode=context.PYNATIVE_MODE)
+    net2 = MomentumFusionNet(var, accum)
+    _ = net2(grad)
+
+    assert np.allclose(net1.var.data.asnumpy(), net2.var.data.asnumpy(), atol=1e-5)
+    assert np.allclose(net1.accum.data.asnumpy(), net2.accum.data.asnumpy(), atol=1e-5)