diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cu
index 21c7e4b786..b4afb20823 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cu
@@ -75,9 +75,9 @@ void MomentumUpdateVariable(const size_t size, T *variable, T *accumulation, con
 }
 
 template <typename T, typename S>
-__global__ void FusedMomentumWeightDecayScaleMomentum(const size_t element_num, T *weight_decay, T *scale, T *variable,
-                                                      T *accumulation, const T *learning_rate, const S *gradient,
-                                                      const T *momentum) {
+__global__ void FusedMomentumWeightDecayScaleKernel(const size_t element_num, T *weight_decay, T *scale, T *variable,
+                                                    T *accumulation, const T *learning_rate, const S *gradient,
+                                                    const T *momentum) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (element_num); i += blockDim.x * gridDim.x) {
     T grad = (variable[i] * weight_decay[0] + static_cast<T>(gradient[i])) * scale[0];
     accumulation[i] = momentum[0] * accumulation[i] + grad;
@@ -91,13 +91,13 @@ void FusedWeightDecayScaleMomentum(const size_t element_num, T *weight_decay, T
                                    cudaStream_t cuda_stream) {
   size_t thread_per_block = 256;
   size_t block_per_grid = (element_num + thread_per_block - 1) / thread_per_block;
-  FusedMomentumWeightDecayScaleMomentum<<<block_per_grid, thread_per_block, 0, cuda_stream>>>(
+  FusedMomentumWeightDecayScaleKernel<<<block_per_grid, thread_per_block, 0, cuda_stream>>>(
     element_num, weight_decay, scale, variable, accumulation, learning_rate, gradient, momentum);
 }
 
 template <typename T, typename S>
-__global__ void FusedMomentumScaleMomentum(const size_t element_num, T *scale, T *variable, T *accumulation,
-                                           const T *learning_rate, const S *gradient, const T *momentum) {
+__global__ void FusedMomentumScaleKernel(const size_t element_num, T *scale, T *variable, T *accumulation,
+                                         const T *learning_rate, const S *gradient, const T *momentum) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (element_num); i += blockDim.x * gridDim.x) {
     accumulation[i] = momentum[0] * accumulation[i] + static_cast<T>(gradient[i]) * scale[0];
     variable[i] -= learning_rate[0] * accumulation[i];
@@ -109,15 +109,33 @@ void FusedScaleMomentum(const size_t element_num, T *scale, T *variable, T *accu
                         const S *gradient, const T *momentum, cudaStream_t cuda_stream) {
   size_t thread_per_block = 256;
   size_t block_per_grid = (element_num + thread_per_block - 1) / thread_per_block;
-  FusedMomentumScaleMomentum<<<block_per_grid, thread_per_block, 0, cuda_stream>>>(
+  FusedMomentumScaleKernel<<<block_per_grid, thread_per_block, 0, cuda_stream>>>(
     element_num, scale, variable, accumulation, learning_rate, gradient, momentum);
 }
 
+template <typename T, typename S>
+__global__ void FusedWeightDecayMomentumKernel(const size_t element_num, T *weight_decay, T *variable, T *accumulation,
+                                               const T *learning_rate, const S *gradient, const T *momentum) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (element_num); i += blockDim.x * gridDim.x) {
+    T grad = variable[i] * weight_decay[0] + static_cast<T>(gradient[i]);
+    accumulation[i] = momentum[0] * accumulation[i] + grad;
+    variable[i] -= learning_rate[0] * accumulation[i];
+  }
+}
+
+template <typename T, typename S>
+void FusedWeightDecayMomentum(const size_t element_num, T *weight_decay, T *variable, T *accumulation,
+                              const T *learning_rate, const S *gradient, const T *momentum, cudaStream_t cuda_stream) {
+  size_t thread_per_block = 256;
+  size_t block_per_grid = (element_num + thread_per_block - 1) / thread_per_block;
+  FusedWeightDecayMomentumKernel<<<block_per_grid, thread_per_block, 0, cuda_stream>>>(
+    element_num, weight_decay, variable, accumulation, learning_rate, gradient, momentum);
+}
+
 // CombineFusedScaleMomentum
 template <typename T, typename S>
-__global__ void CombineFusedMomentumScaleMomentum(const size_t num, const size_t *element_num,
-                                                  T **scale, T **variable, T **accumulation,
-                                                  T **learning_rate, S **gradient, T **momentum) {
+__global__ void CombineFusedMomentumScaleKernel(const size_t num, const size_t *element_num, T **scale, T **variable,
+                                                T **accumulation, T **learning_rate, S **gradient, T **momentum) {
   for (size_t idx = 0; idx < num; idx++) {
     for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (element_num[idx]); i += blockDim.x * gridDim.x) {
       accumulation[idx][i] = momentum[idx][0] * accumulation[idx][i] + static_cast<T>(gradient[idx][i]) * scale[idx][0];
@@ -127,22 +145,21 @@ __global__ void CombineFusedMomentumScaleMomentum(const size_t num, const size_t
 }
 
 template <typename T, typename S>
-void CombineFusedScaleMomentum(const size_t max, const size_t num, const size_t *elements, T **scale,
-                               T **variable, T **accumulation, T **learning_rate, S **gradient,
-                               T **momentum, cudaStream_t cuda_stream) {
+void CombineFusedScaleMomentum(const size_t max, const size_t num, const size_t *elements, T **scale, T **variable,
+                               T **accumulation, T **learning_rate, S **gradient, T **momentum,
+                               cudaStream_t cuda_stream) {
   size_t thread_per_block = 256;
   size_t block_per_grid = (max + thread_per_block - 1) / thread_per_block;
-  CombineFusedMomentumScaleMomentum<<<block_per_grid, thread_per_block, 0, cuda_stream>>>(
+  CombineFusedMomentumScaleKernel<<<block_per_grid, thread_per_block, 0, cuda_stream>>>(
     num, elements, scale, variable, accumulation, learning_rate, gradient, momentum);
 }
 // end CombineFusedScaleMomentum
 
 // CombineFusedWeightDecayScaleMomentum
 template <typename T, typename S>
-__global__ void CombineFusedMomentumWeightDecayScaleMomentum(const size_t num, const size_t *element_num,
-                                                             T **weight_decay, T **scale, T **variable,
-                                                             T **accumulation, T **learning_rate, S **gradient,
-                                                             T **momentum) {
+__global__ void CombineFusedMomentumWeightDecayScaleKernel(const size_t num, const size_t *element_num,
+                                                           T **weight_decay, T **scale, T **variable, T **accumulation,
+                                                           T **learning_rate, S **gradient, T **momentum) {
   for (size_t idx = 0; idx < num; idx++) {
     for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (element_num[idx]); i += blockDim.x * gridDim.x) {
       T grad = (variable[idx][i] * weight_decay[idx][0] + static_cast<T>(gradient[idx][i])) * scale[idx][0];
@@ -155,11 +172,10 @@ __global__ void CombineFusedMomentumWeightDecayScaleMomentum(const size_t num, c
 template <typename T, typename S>
 void CombineFusedWeightDecayScaleMomentum(const size_t max, const size_t num, const size_t *element_num,
                                           T **weight_decay, T **scale, T **variable, T **accumulation,
-                                          T **learning_rate, S **gradient, T **momentum,
-                                          cudaStream_t cuda_stream) {
+                                          T **learning_rate, S **gradient, T **momentum, cudaStream_t cuda_stream) {
   size_t thread_per_block = 256;
   size_t block_per_grid = (max + thread_per_block - 1) / thread_per_block;
-  CombineFusedMomentumWeightDecayScaleMomentum<<<block_per_grid, thread_per_block, 0, cuda_stream>>>(
+  CombineFusedMomentumWeightDecayScaleKernel<<<block_per_grid, thread_per_block, 0, cuda_stream>>>(
     num, element_num, weight_decay, scale, variable, accumulation, learning_rate, gradient, momentum);
 }
 // end CombineFusedWeightDecayScaleMomentum
@@ -186,6 +202,12 @@ template void FusedWeightDecayScaleMomentum(const size_t element_num, float *wei
 template void FusedWeightDecayScaleMomentum(const size_t element_num, float *weight_decay, float *scale,
                                             float *variable, float *accumulation, const float *learning_rate,
                                             const half *gradient, const float *momentum, cudaStream_t cuda_stream);
+template void FusedWeightDecayMomentum(const size_t element_num, float *weight_decay, float *variable,
+                                       float *accumulation, const float *learning_rate, const float *gradient,
+                                       const float *momentum, cudaStream_t cuda_stream);
+template void FusedWeightDecayMomentum(const size_t element_num, float *weight_decay, float *variable,
+                                       float *accumulation, const float *learning_rate, const half *gradient,
+                                       const float *momentum, cudaStream_t cuda_stream);
 template void FusedScaleMomentum(const size_t element_num, float *scale, float *variable, float *accumulation,
                                  const float *learning_rate, const float *gradient, const float *momentum,
                                  cudaStream_t cuda_stream);
@@ -193,16 +215,16 @@ template void FusedScaleMomentum(const size_t element_num, float *scale, float *
                                  const float *learning_rate, const half *gradient, const float *momentum,
                                  cudaStream_t cuda_stream);
 template void CombineFusedWeightDecayScaleMomentum(const size_t max, const size_t num, const size_t *elements,
-                                                   float **weight_decay, float **scale, float **variable,
-                                                   float **accumulation, float **learning_rate, float **gradient,
-                                                   float **momentum, cudaStream_t cuda_stream);
+                                                    float **weight_decay, float **scale, float **variable,
+                                                    float **accumulation, float **learning_rate, float **gradient,
+                                                    float **momentum, cudaStream_t cuda_stream);
 template void CombineFusedWeightDecayScaleMomentum(const size_t max, const size_t num, const size_t *elements,
-                                                   float **weight_decay, float **scale, float **variable,
-                                                   float **accumulation, float **learning_rate, half **gradient,
-                                                   float **momentum, cudaStream_t cuda_stream);
+                                                    float **weight_decay, float **scale, float **variable,
+                                                    float **accumulation, float **learning_rate, half **gradient,
+                                                    float **momentum, cudaStream_t cuda_stream);
 template void CombineFusedScaleMomentum(const size_t max, const size_t num, const size_t *elements, float **scale,
-                                        float **variable, float **accumulation, float **learning_rate,
-                                        float **gradient, float **momentum, cudaStream_t cuda_stream);
+                                        float **variable, float **accumulation, float **learning_rate, float **gradient,
+                                        float **momentum, cudaStream_t cuda_stream);
 template void CombineFusedScaleMomentum(const size_t max, const size_t num, const size_t *elements, float **scale,
-                                        float **variable, float **accumulation, float **learning_rate,
-                                        half **gradient, float **momentum, cudaStream_t cuda_stream);
+                                        float **variable, float **accumulation, float **learning_rate, half **gradient,
+                                        float **momentum, cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh
index 7c9ac40bbd..57b271e9d1 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh
@@ -26,6 +26,9 @@ void FusedWeightDecayScaleMomentum(const size_t element_num, T *weight_decay, T
                                    const T *learning_rate, const S *gradient, const T *momentum,
                                    cudaStream_t cuda_stream);
 template <typename T, typename S>
+void FusedWeightDecayMomentum(const size_t element_num, T *weight_decay, T *variable, T *accumulation,
+                              const T *learning_rate, const S *gradient, const T *momentum, cudaStream_t cuda_stream);
+template <typename T, typename S>
 void FusedScaleMomentum(const size_t element_num, T *scale, T *variable, T *accumulation, const T *learning_rate,
                         const S *gradient, const T *momentum, cudaStream_t cuda_stream);
 template <typename T, typename S>
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_weightdecay_momentum_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_weightdecay_momentum_gpu_kernel.cc
new file mode 100644
index 0000000000..37e4452198
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_weightdecay_momentum_gpu_kernel.cc
@@ -0,0 +1,42 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/nn/fused_weightdecay_momentum_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_TWO(FusedWeightApplyMomentum,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat32)  // weight decay
+                        .AddInputAttr(kNumberTypeFloat32)  // variable
+                        .AddInputAttr(kNumberTypeFloat32)  // accumulation
+                        .AddInputAttr(kNumberTypeFloat32)  // learning_rate
+                        .AddInputAttr(kNumberTypeFloat32)  // gradient
+                        .AddInputAttr(kNumberTypeFloat32)  // momentum
+                        .AddOutputAttr(kNumberTypeFloat32),
+                      FusedWeightDecayMomentumGpuKernel, float, float)
+MS_REG_GPU_KERNEL_TWO(FusedWeightApplyMomentum,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat32)  // weight decay
+                        .AddInputAttr(kNumberTypeFloat32)  // variable
+                        .AddInputAttr(kNumberTypeFloat32)  // accumulation
+                        .AddInputAttr(kNumberTypeFloat32)  // learning_rate
+                        .AddInputAttr(kNumberTypeFloat16)  // gradient
+                        .AddInputAttr(kNumberTypeFloat32)  // momentum
+                        .AddOutputAttr(kNumberTypeFloat32),
+                      FusedWeightDecayMomentumGpuKernel, float, half)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_weightdecay_momentum_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_weightdecay_momentum_gpu_kernel.h
new file mode 100644
index 0000000000..1a252079bd
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_weightdecay_momentum_gpu_kernel.h
@@ -0,0 +1,85 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_WEIGHTDECAY_MOMENTUM_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_WEIGHTDECAY_MOMENTUM_GPU_KERNEL_H_
+
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh"
+namespace mindspore {
+namespace kernel {
+template <typename T, typename S>
+class FusedWeightDecayMomentumGpuKernel : public GpuKernel {
+ public:
+  FusedWeightDecayMomentumGpuKernel() : element_num_(1) {}
+  ~FusedWeightDecayMomentumGpuKernel() override = default;
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
+              void *stream_ptr) override {
+    T *weight_decay = GetDeviceAddress<T>(inputs, 0);
+    T *variable = GetDeviceAddress<T>(inputs, 1);
+    T *accumulation = GetDeviceAddress<T>(inputs, 2);
+    T *learning_rate = GetDeviceAddress<T>(inputs, 3);
+    S *gradient = GetDeviceAddress<S>(inputs, 4);
+    T *momentum = GetDeviceAddress<T>(inputs, 5);
+
+    FusedWeightDecayMomentum(element_num_, weight_decay, variable, accumulation, learning_rate, gradient, momentum,
+                             reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+  bool Init(const CNodePtr &kernel_node) override {
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 6) {
+      MS_LOG(ERROR) << "Input number is " << input_num << ", but FusedWeightApplyMomentum needs 6 inputs.";
+      return false;
+    }
+
+    auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
+    for (size_t i = 0; i < variable_shape.size(); i++) {
+      element_num_ *= variable_shape[i];
+    }
+
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(sizeof(T));
+    input_size_list_.push_back(element_num_ * sizeof(T));
+    input_size_list_.push_back(element_num_ * sizeof(T));
+    input_size_list_.push_back(sizeof(T));
+    input_size_list_.push_back(element_num_ * sizeof(S));
+    input_size_list_.push_back(sizeof(T));
+    output_size_list_.push_back(element_num_ * sizeof(T));
+  }
+
+ private:
+  size_t element_num_;
+
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_WEIGHTDECAY_MOMENTUM_GPU_KERNEL_H_
diff --git a/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_weight_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_weight_fusion.cc
new file mode 100644
index 0000000000..3c851bcde4
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_weight_fusion.cc
@@ -0,0 +1,90 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/gpu/apply_momentum_weight_fusion.h"
+
+#include <memory>
+#include <vector>
+#include <string>
+
+#include "backend/session/anf_runtime_algorithm.h"
+#include "ir/primitive.h"
+#include "utils/utils.h"
+#include "backend/optimizer/common/helper.h"
+
+namespace mindspore {
+namespace opt {
+bool ApplyMomentumWeightDecayFusion::IsScalar(const BaseRef &n) {
+  if (utils::isa<AnfNodePtr>(n)) {
+    AnfNodePtr in = utils::cast<AnfNodePtr>(n);
+    MS_EXCEPTION_IF_NULL(in);
+    auto shape = in->Shape()->cast<abstract::ShapePtr>();
+    MS_EXCEPTION_IF_NULL(shape);
+    if (shape->shape().size() != 0) {
+      return false;
+    }
+    auto dtype = in->Type();
+    if (dtype->type_id() != kObjectTypeTensorType) {
+      return false;
+    }
+    auto element_type = dyn_cast<TensorType>(dtype)->element()->type_id();
+    if (element_type != kNumberTypeFloat32) {
+      return false;
+    }
+    return true;
+  }
+  return false;
+}
+
+const BaseRef ApplyMomentumWeightDecayFusion::DefinePattern() const {
+  VectorRef weight_decay =
+    VectorRef({prim::kPrimAddN, VectorRef({prim::kPrimMul, variable_, weight_decay_}), gradient_});
+  VectorRef apply_momentum =
+    VectorRef({prim::kPrimApplyMomentum, variable_, accumulation_, learning_rate_, weight_decay, momentum_});
+  return apply_momentum;
+}
+
+const AnfNodePtr ApplyMomentumWeightDecayFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
+                                                         const EquivPtr &equiv) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(equiv);
+  auto weight_decay = utils::cast<AnfNodePtr>((*equiv)[weight_decay_]);
+  auto variable = utils::cast<AnfNodePtr>((*equiv)[variable_]);
+  auto accumulation = utils::cast<AnfNodePtr>((*equiv)[accumulation_]);
+  auto learning_rate = utils::cast<AnfNodePtr>((*equiv)[learning_rate_]);
+  auto gradient = utils::cast<AnfNodePtr>((*equiv)[gradient_]);
+  auto momentum = utils::cast<AnfNodePtr>((*equiv)[momentum_]);
+  MS_EXCEPTION_IF_NULL(weight_decay);
+  MS_EXCEPTION_IF_NULL(variable);
+  MS_EXCEPTION_IF_NULL(accumulation);
+  MS_EXCEPTION_IF_NULL(learning_rate);
+  MS_EXCEPTION_IF_NULL(gradient);
+  MS_EXCEPTION_IF_NULL(momentum);
+
+  auto prim = std::make_shared<Primitive>(kFusedWeightApplyMomentum);
+  MS_EXCEPTION_IF_NULL(prim);
+  std::vector<AnfNodePtr> inputs = {NewValueNode(prim), weight_decay, variable, accumulation,
+                                    learning_rate, gradient, momentum};
+  auto replace_node = graph->NewCNode(inputs);
+  MS_EXCEPTION_IF_NULL(replace_node);
+  auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};
+  auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)};
+  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, replace_node.get());
+  replace_node->set_scope(node->scope());
+  return replace_node;
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_weight_fusion.h b/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_weight_fusion.h
new file mode 100644
index 0000000000..1b49394e43
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_weight_fusion.h
@@ -0,0 +1,51 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_APPLY_MOMENTUM_WEIGHT_DECAY_FUSION_H_
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_APPLY_MOMENTUM_WEIGHT_DECAY_FUSION_H_
+
+#include <memory>
+#include "backend/optimizer/common/optimizer.h"
+
+namespace mindspore {
+namespace opt {
+class ApplyMomentumWeightDecayFusion : public PatternProcessPass {
+ public:
+  explicit ApplyMomentumWeightDecayFusion(bool multigraph = true)
+      : PatternProcessPass("momentum_weightdecay_fusion", multigraph) {
+    weight_decay_ = std::make_shared<Var>();
+    variable_ = std::make_shared<Var>();
+    accumulation_ = std::make_shared<Var>();
+    learning_rate_ = std::make_shared<Var>();
+    gradient_ = std::make_shared<Var>();
+    momentum_ = std::make_shared<Var>();
+  }
+  ~ApplyMomentumWeightDecayFusion() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+
+ private:
+  static bool IsScalar(const BaseRef &n);
+
+  VarPtr weight_decay_;
+  VarPtr variable_;
+  VarPtr accumulation_;
+  VarPtr learning_rate_;
+  VarPtr gradient_;
+  VarPtr momentum_;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_APPLY_MOMENTUM_WEIGHT_DECAY_FUSION_H_
diff --git a/mindspore/ccsrc/backend/session/gpu_session.cc b/mindspore/ccsrc/backend/session/gpu_session.cc
index a7dcc161ff..58143233d5 100644
--- a/mindspore/ccsrc/backend/session/gpu_session.cc
+++ b/mindspore/ccsrc/backend/session/gpu_session.cc
@@ -22,6 +22,7 @@
 #include "backend/optimizer/gpu/adam_fusion.h"
 #include "backend/optimizer/gpu/apply_momentum_weight_scale_fusion.h"
 #include "backend/optimizer/gpu/apply_momentum_scale_fusion.h"
+#include "backend/optimizer/gpu/apply_momentum_weight_fusion.h"
 #include "backend/optimizer/gpu/batch_norm_relu_fusion.h"
 #include "backend/optimizer/gpu/batch_norm_relu_grad_fusion.h"
 #include "backend/optimizer/gpu/batch_norm_add_relu_fusion.h"
@@ -125,6 +126,7 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
   pm->AddPass(std::make_shared<opt::AdamFusion>());
   pm->AddPass(std::make_shared<opt::ApplyMomentumWeightDecayScaleFusion>());
   pm->AddPass(std::make_shared<opt::ApplyMomentumScaleFusion>());
+  pm->AddPass(std::make_shared<opt::ApplyMomentumWeightDecayFusion>());
   if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) {
     pm->AddPass(std::make_shared<opt::CastAllFusion>("cast_all"));
   }
diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h
index 8f39918609..1438543186 100644
--- a/mindspore/ccsrc/utils/utils.h
+++ b/mindspore/ccsrc/utils/utils.h
@@ -234,6 +234,7 @@ constexpr auto kReduceMeanOpName = "ReduceMean";
 constexpr auto kReduceAnyOpName = "ReduceAny";
 constexpr auto kReduceAllOpName = "ReduceAll";
 constexpr auto kFusedWeightScaleApplyMomentum = "FusedWeightScaleApplyMomentum";
+constexpr auto kFusedWeightApplyMomentum = "FusedWeightApplyMomentum";
 constexpr auto kFusedScaleApplyMomentum = "FusedScaleApplyMomentum";
 constexpr auto kBasicLSTMCellWeightGradOpName = "BasicLSTMCellWeightGrad";
 constexpr auto kBasicLSTMCellInputGradOpName = "BasicLSTMCellInputGrad";
diff --git a/tests/st/ops/gpu/test_momentum_fusion.py b/tests/st/ops/gpu/test_momentum_fusion.py
new file mode 100644
index 0000000000..61d4c0f408
--- /dev/null
+++ b/tests/st/ops/gpu/test_momentum_fusion.py
@@ -0,0 +1,61 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore.common.parameter import Parameter
+from mindspore import Tensor
+from mindspore.ops import operations as P
+
+class MomentumFusionNet(nn.Cell):
+    def __init__(self, var, accum):
+        super(MomentumFusionNet, self).__init__()
+        self.op = P.ApplyMomentum()
+        self.add = P.AddN()
+        self.mul = P.Mul()
+        self.var = Parameter(var, name="variable")
+        self.accum = Parameter(accum, name="accumulate")
+        self.lr = 0.1
+        self.weight_decay = 0.002
+        self.moment = 0.98
+
+    def construct(self, grad):
+        wd = self.mul(self.var, self.weight_decay)
+        g = self.add((wd, grad))
+        return self.op(self.var, self.accum, self.lr, g, self.moment)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_momentum_fusion():
+    np.random.seed(42)
+    var = Tensor(np.random.randn(10, 20).astype(np.float32))
+    accum = Tensor(np.random.randn(10, 20).astype(np.float32))
+    grad = Tensor(np.random.randn(10, 20).astype(np.float32))
+
+    context.set_context(device_target='GPU', mode=context.GRAPH_MODE)
+    net1 = MomentumFusionNet(var, accum)
+    _ = net1(grad)
+
+    context.set_context(device_target='GPU', mode=context.PYNATIVE_MODE)
+    net2 = MomentumFusionNet(var, accum)
+    _ = net2(grad)
+
+    assert np.allclose(net1.var.data.asnumpy(), net2.var.data.asnumpy(), atol=1e-5)
+    assert np.allclose(net1.accum.data.asnumpy(), net2.accum.data.asnumpy(), atol=1e-5)
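
Note: the update rule that the new FusedWeightApplyMomentum kernel implements, and that the unfused Mul/AddN/ApplyMomentum graph in the test above reproduces, can be sketched in NumPy as follows. The helper name and the NumPy formulation are illustrative only and are not part of the patch.

    import numpy as np

    def fused_weight_decay_momentum(var, accum, lr, grad, momentum, weight_decay):
        # Fold the weight-decay term into the gradient, then apply the momentum update.
        g = var * weight_decay + grad
        accum = momentum * accum + g
        var = var - lr * accum
        return var, accum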