Merge pull request !5612 from chenweifeng/BatchNormAddReluGradtags/v1.0.0
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -26,8 +26,7 @@ __global__ void MomentumUpdateVariableKernel(const size_t size, T *variable, T * | |||
| } | |||
| template <> | |||
| __global__ void MomentumUpdateVariableKernel(const size_t size, half *variable, half *accumulation, | |||
| const float *learning_rate, const half *gradient, | |||
| const float *momentum) { | |||
| const float *learning_rate, const half *gradient, const float *momentum) { | |||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) { | |||
| accumulation[i] = __float2half(momentum[0]) * accumulation[i] + gradient[i]; | |||
| variable[i] -= __float2half(learning_rate[0]) * accumulation[i]; | |||
| @@ -36,8 +35,7 @@ __global__ void MomentumUpdateVariableKernel(const size_t size, half *variable, | |||
| } | |||
| template <> | |||
| __global__ void MomentumUpdateVariableKernel(const size_t size, float *variable, float *accumulation, | |||
| const float *learning_rate, const half *gradient, | |||
| const float *momentum) { | |||
| const float *learning_rate, const half *gradient, const float *momentum) { | |||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) { | |||
| accumulation[i] = momentum[0] * accumulation[i] + __half2float(gradient[i]); | |||
| variable[i] -= learning_rate[0] * accumulation[i]; | |||
| @@ -51,15 +49,68 @@ void MomentumUpdateVariable(const size_t size, T *variable, T *accumulation, con | |||
| learning_rate, gradient, momentum); | |||
| return; | |||
| } | |||
// Fused weight-decay + scale + momentum update, applied element-wise.
// For every i in [0, element_num):
//   grad            = (variable[i] * weight_decay[0] + T(gradient[i])) * scale[0]
//   accumulation[i] = momentum[0] * accumulation[i] + grad
//   variable[i]    -= learning_rate[0] * accumulation[i]
// weight_decay, scale, learning_rate and momentum are only read at index 0.
// Each thread starts at its global 1-D index and advances by the total number of launched threads,
// so the whole range is covered for any 1-D launch configuration.
template <typename T, typename S>
__global__ void FusedMomentumWeightDecayScaleMomentum(const size_t element_num, T *weight_decay, T *scale, T *variable,
                                                      T *accumulation, const T *learning_rate, const S *gradient,
                                                      const T *momentum) {
  // Widen before multiplying: the built-in index variables are 32-bit, so the original products
  // wrapped once the global index (or the total thread count) reached 2^32, which could even
  // produce a zero stride and a loop that never terminates.
  const size_t first = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const size_t step = static_cast<size_t>(blockDim.x) * gridDim.x;
  for (size_t i = first; i < element_num; i += step) {
    T grad = (variable[i] * weight_decay[0] + static_cast<T>(gradient[i])) * scale[0];
    accumulation[i] = momentum[0] * accumulation[i] + grad;
    variable[i] -= learning_rate[0] * accumulation[i];
  }
}
// Host-side launcher for FusedMomentumWeightDecayScaleMomentum.
// Enqueues the kernel on cuda_stream with 256 threads per block and enough blocks to give each of the
// element_num elements its own thread, up to a fixed cap. The call is asynchronous: it returns as soon as
// the launch has been queued. All pointer arguments are forwarded to the kernel unchanged.
template <typename T, typename S>
void FusedWeightDecayScaleMomentum(const size_t element_num, T *weight_decay, T *scale, T *variable, T *accumulation,
                                   const T *learning_rate, const S *gradient, const T *momentum,
                                   cudaStream_t cuda_stream) {
  // Nothing to update; a launch with zero blocks would be rejected as an invalid configuration.
  if (element_num == 0) {
    return;
  }
  constexpr unsigned int kThreadsPerBlock = 256;
  // Keep the total thread count below 2^32 so 32-bit index arithmetic inside the kernel cannot wrap.
  // The kernel loops with a stride of the total thread count, so a capped grid still covers every element.
  constexpr unsigned int kMaxBlocks = 0xFFFFFFFFu / kThreadsPerBlock;
  size_t block_per_grid = (element_num + kThreadsPerBlock - 1) / kThreadsPerBlock;
  if (block_per_grid > kMaxBlocks) {
    block_per_grid = kMaxBlocks;
  }
  FusedMomentumWeightDecayScaleMomentum<<<static_cast<unsigned int>(block_per_grid), kThreadsPerBlock, 0,
                                          cuda_stream>>>(element_num, weight_decay, scale, variable, accumulation,
                                                         learning_rate, gradient, momentum);
}
// Fused scale + momentum update, applied element-wise.
// For every i in [0, element_num):
//   accumulation[i] = momentum[0] * accumulation[i] + T(gradient[i]) * scale[0]
//   variable[i]    -= learning_rate[0] * accumulation[i]
// scale, learning_rate and momentum are only read at index 0.
// Each thread starts at its global 1-D index and advances by the total number of launched threads,
// so the whole range is covered for any 1-D launch configuration.
template <typename T, typename S>
__global__ void FusedMomentumScaleMomentum(const size_t element_num, T *scale, T *variable, T *accumulation,
                                           const T *learning_rate, const S *gradient, const T *momentum) {
  // Widen before multiplying: the built-in index variables are 32-bit and their product can wrap.
  const size_t first = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const size_t step = static_cast<size_t>(blockDim.x) * gridDim.x;
  for (size_t i = first; i < element_num; i += step) {
    // This kernel stands in for Mul(gradient, scale) followed by ApplyMomentum, so the scale factor has
    // to be applied here; the previous body accepted `scale` but never read it.
    accumulation[i] = momentum[0] * accumulation[i] + static_cast<T>(gradient[i]) * scale[0];
    variable[i] -= learning_rate[0] * accumulation[i];
  }
}
// Host-side launcher for FusedMomentumScaleMomentum.
// Enqueues the kernel on cuda_stream with 256 threads per block and enough blocks to give each of the
// element_num elements its own thread, up to a fixed cap. The call is asynchronous: it returns as soon as
// the launch has been queued. All pointer arguments are forwarded to the kernel unchanged.
template <typename T, typename S>
void FusedScaleMomentum(const size_t element_num, T *scale, T *variable, T *accumulation, const T *learning_rate,
                        const S *gradient, const T *momentum, cudaStream_t cuda_stream) {
  // Nothing to update; a launch with zero blocks would be rejected as an invalid configuration.
  if (element_num == 0) {
    return;
  }
  constexpr unsigned int kThreadsPerBlock = 256;
  // Keep the total thread count below 2^32 so 32-bit index arithmetic inside the kernel cannot wrap.
  // The kernel loops with a stride of the total thread count, so a capped grid still covers every element.
  constexpr unsigned int kMaxBlocks = 0xFFFFFFFFu / kThreadsPerBlock;
  size_t block_per_grid = (element_num + kThreadsPerBlock - 1) / kThreadsPerBlock;
  if (block_per_grid > kMaxBlocks) {
    block_per_grid = kMaxBlocks;
  }
  FusedMomentumScaleMomentum<<<static_cast<unsigned int>(block_per_grid), kThreadsPerBlock, 0, cuda_stream>>>(
    element_num, scale, variable, accumulation, learning_rate, gradient, momentum);
}
// Explicit instantiations: these are the only type combinations other translation units can link against.

// MomentumUpdateVariable<T, S, G>: T = variable/accumulation, S = learning_rate/momentum, G = gradient.
template void MomentumUpdateVariable<float, float, float>(const size_t size, float *variable, float *accumulation,
                                                          const float *learning_rate, const float *gradient,
                                                          const float *momentum, cudaStream_t cuda_stream);
template void MomentumUpdateVariable<half, half, half>(const size_t size, half *variable, half *accumulation,
                                                       const half *learning_rate, const half *gradient,
                                                       const half *momentum, cudaStream_t cuda_stream);
template void MomentumUpdateVariable<half, float, half>(const size_t size, half *variable, half *accumulation,
                                                        const float *learning_rate, const half *gradient,
                                                        const float *momentum, cudaStream_t cuda_stream);
template void MomentumUpdateVariable<float, float, half>(const size_t size, float *variable, float *accumulation,
                                                         const float *learning_rate, const half *gradient,
                                                         const float *momentum, cudaStream_t cuda_stream);

// FusedWeightDecayScaleMomentum<T, S>: T = parameter/scalar type, S = gradient type.
template void FusedWeightDecayScaleMomentum<float, float>(const size_t element_num, float *weight_decay, float *scale,
                                                          float *variable, float *accumulation,
                                                          const float *learning_rate, const float *gradient,
                                                          const float *momentum, cudaStream_t cuda_stream);
template void FusedWeightDecayScaleMomentum<float, half>(const size_t element_num, float *weight_decay, float *scale,
                                                         float *variable, float *accumulation,
                                                         const float *learning_rate, const half *gradient,
                                                         const float *momentum, cudaStream_t cuda_stream);

// FusedScaleMomentum<T, S>: T = parameter/scalar type, S = gradient type.
template void FusedScaleMomentum<float, float>(const size_t element_num, float *scale, float *variable,
                                               float *accumulation, const float *learning_rate,
                                               const float *gradient, const float *momentum,
                                               cudaStream_t cuda_stream);
template void FusedScaleMomentum<float, half>(const size_t element_num, float *scale, float *variable,
                                              float *accumulation, const float *learning_rate, const half *gradient,
                                              const float *momentum, cudaStream_t cuda_stream);
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -21,5 +21,12 @@ | |||
// Launches the momentum update over `size` elements on cuda_stream.
// T is the variable/accumulation type, S the learning_rate/momentum type, G the gradient type.
template <typename T, typename S, typename G>
void MomentumUpdateVariable(const size_t size,
                            T *variable,
                            T *accumulation,
                            const S *learning_rate,
                            const G *gradient,
                            const S *momentum,
                            cudaStream_t cuda_stream);

// Launches the fused weight-decay / scale / momentum update over `element_num` elements on cuda_stream.
// T is the type of every parameter and scalar buffer, S the gradient type.
template <typename T, typename S>
void FusedWeightDecayScaleMomentum(const size_t element_num,
                                   T *weight_decay,
                                   T *scale,
                                   T *variable,
                                   T *accumulation,
                                   const T *learning_rate,
                                   const S *gradient,
                                   const T *momentum,
                                   cudaStream_t cuda_stream);

// Launches the fused scale / momentum update over `element_num` elements on cuda_stream.
// T is the type of every parameter and scalar buffer, S the gradient type.
template <typename T, typename S>
void FusedScaleMomentum(const size_t element_num,
                        T *scale,
                        T *variable,
                        T *accumulation,
                        const T *learning_rate,
                        const S *gradient,
                        const T *momentum,
                        cudaStream_t cuda_stream);
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MOMENTUMIMPL_H_ | |||
| @@ -0,0 +1,44 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/gpu/nn/fused_scale_momentum_gpu_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_TWO(FusedScaleApplyMomentum, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) // scale | |||
| .AddInputAttr(kNumberTypeFloat32) // variable | |||
| .AddInputAttr(kNumberTypeFloat32) // accumulation | |||
| .AddInputAttr(kNumberTypeFloat32) // learning_rate | |||
| .AddInputAttr(kNumberTypeFloat32) // gradient | |||
| .AddInputAttr(kNumberTypeFloat32) // momentum | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| FusedScaleMomentumGpuKernel, float, float) | |||
| MS_REG_GPU_KERNEL_TWO(FusedScaleApplyMomentum, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) // scale | |||
| .AddInputAttr(kNumberTypeFloat32) // variable | |||
| .AddInputAttr(kNumberTypeFloat32) // accumulation | |||
| .AddInputAttr(kNumberTypeFloat32) // variable | |||
| .AddInputAttr(kNumberTypeFloat32) // accumulation | |||
| .AddInputAttr(kNumberTypeFloat32) // learning_rate | |||
| .AddInputAttr(kNumberTypeFloat16) // gradient | |||
| .AddInputAttr(kNumberTypeFloat32) // momentum | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| FusedScaleMomentumGpuKernel, float, half) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,85 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_SCALE_MOMENTUM_GPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_SCALE_MOMENTUM_GPU_KERNEL_H_ | |||
| #include <vector> | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T, typename S> | |||
| class FusedScaleMomentumGpuKernel : public GpuKernel { | |||
| public: | |||
| FusedScaleMomentumGpuKernel() : element_num_(1) {} | |||
| ~FusedScaleMomentumGpuKernel() override = default; | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &, | |||
| void *stream_ptr) override { | |||
| T *scale = GetDeviceAddress<T>(inputs, 0); | |||
| T *variable = GetDeviceAddress<T>(inputs, 1); | |||
| T *accumulation = GetDeviceAddress<T>(inputs, 2); | |||
| T *learning_rate = GetDeviceAddress<T>(inputs, 3); | |||
| S *gradient = GetDeviceAddress<S>(inputs, 4); | |||
| T *momentum = GetDeviceAddress<T>(inputs, 5); | |||
| FusedScaleMomentum(element_num_, scale, variable, accumulation, learning_rate, gradient, momentum, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != 6) { | |||
| MS_LOG(ERROR) << "Input number is " << input_num << ", but FusedMomentum needs 6 inputs."; | |||
| return false; | |||
| } | |||
| auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| for (size_t i = 0; i < variable_shape.size(); i++) { | |||
| element_num_ *= variable_shape[i]; | |||
| } | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| protected: | |||
| void InitSizeLists() override { | |||
| input_size_list_.push_back(sizeof(T)); | |||
| input_size_list_.push_back(element_num_ * sizeof(T)); | |||
| input_size_list_.push_back(element_num_ * sizeof(T)); | |||
| input_size_list_.push_back(sizeof(T)); | |||
| input_size_list_.push_back(element_num_ * sizeof(S)); | |||
| input_size_list_.push_back(sizeof(T)); | |||
| output_size_list_.push_back(element_num_ * sizeof(T)); | |||
| } | |||
| private: | |||
| size_t element_num_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_SCALE_MOMENTUM_GPU_KERNEL_H_ | |||
| @@ -0,0 +1,44 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/gpu/nn/fused_weightdecay_scale_momentum_gpu_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_TWO(FusedWeightScaleApplyMomentum, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) // weight decay | |||
| .AddInputAttr(kNumberTypeFloat32) // scale | |||
| .AddInputAttr(kNumberTypeFloat32) // variable | |||
| .AddInputAttr(kNumberTypeFloat32) // accumulation | |||
| .AddInputAttr(kNumberTypeFloat32) // learning_rate | |||
| .AddInputAttr(kNumberTypeFloat32) // gradient | |||
| .AddInputAttr(kNumberTypeFloat32) // momentum | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| FusedWeightDecayScaleMomentumGpuKernel, float, float) | |||
| MS_REG_GPU_KERNEL_TWO(FusedWeightScaleApplyMomentum, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) // variable | |||
| .AddInputAttr(kNumberTypeFloat32) // accumulation | |||
| .AddInputAttr(kNumberTypeFloat32) // variable | |||
| .AddInputAttr(kNumberTypeFloat32) // accumulation | |||
| .AddInputAttr(kNumberTypeFloat32) // learning_rate | |||
| .AddInputAttr(kNumberTypeFloat16) // gradient | |||
| .AddInputAttr(kNumberTypeFloat32) // momentum | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| FusedWeightDecayScaleMomentumGpuKernel, float, half) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,87 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_WEIGHTDECAY_SCALE_MOMENTUM_GPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_WEIGHTDECAY_SCALE_MOMENTUM_GPU_KERNEL_H_ | |||
| #include <vector> | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T, typename S> | |||
| class FusedWeightDecayScaleMomentumGpuKernel : public GpuKernel { | |||
| public: | |||
| FusedWeightDecayScaleMomentumGpuKernel() : element_num_(1) {} | |||
| ~FusedWeightDecayScaleMomentumGpuKernel() override = default; | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &, | |||
| void *stream_ptr) override { | |||
| T *weight_decay = GetDeviceAddress<T>(inputs, 0); | |||
| T *scale = GetDeviceAddress<T>(inputs, 1); | |||
| T *variable = GetDeviceAddress<T>(inputs, 2); | |||
| T *accumulation = GetDeviceAddress<T>(inputs, 3); | |||
| T *learning_rate = GetDeviceAddress<T>(inputs, 4); | |||
| S *gradient = GetDeviceAddress<S>(inputs, 5); | |||
| T *momentum = GetDeviceAddress<T>(inputs, 6); | |||
| FusedWeightDecayScaleMomentum(element_num_, weight_decay, scale, variable, accumulation, learning_rate, gradient, | |||
| momentum, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != 7) { | |||
| MS_LOG(ERROR) << "Input number is " << input_num << ", but FusedMomentum needs 7 inputs."; | |||
| return false; | |||
| } | |||
| auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| for (size_t i = 0; i < variable_shape.size(); i++) { | |||
| element_num_ *= variable_shape[i]; | |||
| } | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| protected: | |||
| void InitSizeLists() override { | |||
| input_size_list_.push_back(sizeof(T)); | |||
| input_size_list_.push_back(sizeof(T)); | |||
| input_size_list_.push_back(element_num_ * sizeof(T)); | |||
| input_size_list_.push_back(element_num_ * sizeof(T)); | |||
| input_size_list_.push_back(sizeof(T)); | |||
| input_size_list_.push_back(element_num_ * sizeof(S)); | |||
| input_size_list_.push_back(sizeof(T)); | |||
| output_size_list_.push_back(element_num_ * sizeof(T)); | |||
| } | |||
| private: | |||
| size_t element_num_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_WEIGHTDECAY_SCALE_MOMENTUM_GPU_KERNEL_H_ | |||
| @@ -0,0 +1,67 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/optimizer/gpu/apply_momentum_scale_fusion.h" | |||
| #include <memory> | |||
| #include <vector> | |||
| #include <string> | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "ir/primitive.h" | |||
| #include "utils/utils.h" | |||
| #include "backend/optimizer/common/helper.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| const BaseRef ApplyMomentumScaleFusion::DefinePattern() const { | |||
| VectorRef scale = VectorRef({prim::kPrimMul, gradient_, scale_}); | |||
| VectorRef apply_momentum = | |||
| VectorRef({prim::kPrimApplyMomentum, variable_, accumulation_, learning_rate_, scale, momentum_}); | |||
| return apply_momentum; | |||
| } | |||
| const AnfNodePtr ApplyMomentumScaleFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, | |||
| const EquivPtr &equiv) const { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| MS_EXCEPTION_IF_NULL(equiv); | |||
| auto scale = utils::cast<AnfNodePtr>((*equiv)[scale_]); | |||
| auto variable = utils::cast<AnfNodePtr>((*equiv)[variable_]); | |||
| auto accumulation = utils::cast<AnfNodePtr>((*equiv)[accumulation_]); | |||
| auto learning_rate = utils::cast<AnfNodePtr>((*equiv)[learning_rate_]); | |||
| auto gradient = utils::cast<AnfNodePtr>((*equiv)[gradient_]); | |||
| auto momentum = utils::cast<AnfNodePtr>((*equiv)[momentum_]); | |||
| MS_EXCEPTION_IF_NULL(scale); | |||
| MS_EXCEPTION_IF_NULL(variable); | |||
| MS_EXCEPTION_IF_NULL(accumulation); | |||
| MS_EXCEPTION_IF_NULL(learning_rate); | |||
| MS_EXCEPTION_IF_NULL(gradient); | |||
| MS_EXCEPTION_IF_NULL(momentum); | |||
| auto prim = std::make_shared<Primitive>(kFusedScaleApplyMomentum); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| std::vector<AnfNodePtr> inputs = {NewValueNode(prim), scale, variable, accumulation, | |||
| learning_rate, gradient, momentum}; | |||
| auto replace_node = graph->NewCNode(inputs); | |||
| MS_EXCEPTION_IF_NULL(replace_node); | |||
| auto types = {AnfAlgo::GetOutputInferDataType(node, 0)}; | |||
| auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)}; | |||
| AnfAlgo::SetOutputInferTypeAndShape(types, shapes, replace_node.get()); | |||
| replace_node->set_scope(node->scope()); | |||
| return replace_node; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,48 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_APPLY_MOMENTUM_SCALE_FUSION_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_APPLY_MOMENTUM_SCALE_FUSION_H_ | |||
| #include <memory> | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class ApplyMomentumScaleFusion : public PatternProcessPass { | |||
| public: | |||
| explicit ApplyMomentumScaleFusion(bool multigraph = true) : PatternProcessPass("momentum_scale_fusion", multigraph) { | |||
| scale_ = std::make_shared<Var>(); | |||
| variable_ = std::make_shared<Var>(); | |||
| accumulation_ = std::make_shared<Var>(); | |||
| learning_rate_ = std::make_shared<Var>(); | |||
| gradient_ = std::make_shared<Var>(); | |||
| momentum_ = std::make_shared<Var>(); | |||
| } | |||
| ~ApplyMomentumScaleFusion() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| private: | |||
| VarPtr scale_; | |||
| VarPtr variable_; | |||
| VarPtr accumulation_; | |||
| VarPtr learning_rate_; | |||
| VarPtr gradient_; | |||
| VarPtr momentum_; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_APPLY_MOMENTUM_SCALE_FUSION_H_ | |||
| @@ -0,0 +1,71 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/optimizer/gpu/apply_momentum_weight_scale_fusion.h" | |||
| #include <memory> | |||
| #include <vector> | |||
| #include <string> | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "ir/primitive.h" | |||
| #include "utils/utils.h" | |||
| #include "backend/optimizer/common/helper.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| const BaseRef ApplyMomentumWeightDecayScaleFusion::DefinePattern() const { | |||
| VectorRef weight = VectorRef( | |||
| {prim::kPrimAddN, VectorRef({prim::kPrimMul, variable_, weight_decay_}), VectorRef({prim::kPrimCast, gradient_})}); | |||
| VectorRef scale = VectorRef({prim::kPrimMul, weight, scale_}); | |||
| VectorRef apply_momentum = | |||
| VectorRef({prim::kPrimApplyMomentum, variable_, accumulation_, learning_rate_, scale, momentum_}); | |||
| return apply_momentum; | |||
| } | |||
| const AnfNodePtr ApplyMomentumWeightDecayScaleFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, | |||
| const EquivPtr &equiv) const { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| MS_EXCEPTION_IF_NULL(equiv); | |||
| auto weight_decay = utils::cast<AnfNodePtr>((*equiv)[weight_decay_]); | |||
| auto scale = utils::cast<AnfNodePtr>((*equiv)[scale_]); | |||
| auto variable = utils::cast<AnfNodePtr>((*equiv)[variable_]); | |||
| auto accumulation = utils::cast<AnfNodePtr>((*equiv)[accumulation_]); | |||
| auto learning_rate = utils::cast<AnfNodePtr>((*equiv)[learning_rate_]); | |||
| auto gradient = utils::cast<AnfNodePtr>((*equiv)[gradient_]); | |||
| auto momentum = utils::cast<AnfNodePtr>((*equiv)[momentum_]); | |||
| MS_EXCEPTION_IF_NULL(weight_decay); | |||
| MS_EXCEPTION_IF_NULL(scale); | |||
| MS_EXCEPTION_IF_NULL(variable); | |||
| MS_EXCEPTION_IF_NULL(accumulation); | |||
| MS_EXCEPTION_IF_NULL(learning_rate); | |||
| MS_EXCEPTION_IF_NULL(gradient); | |||
| MS_EXCEPTION_IF_NULL(momentum); | |||
| auto prim = std::make_shared<Primitive>(kFusedWeightScaleApplyMomentum); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| std::vector<AnfNodePtr> inputs = {NewValueNode(prim), weight_decay, scale, variable, | |||
| accumulation, learning_rate, gradient, momentum}; | |||
| auto replace_node = graph->NewCNode(inputs); | |||
| MS_EXCEPTION_IF_NULL(replace_node); | |||
| auto types = {AnfAlgo::GetOutputInferDataType(node, 0)}; | |||
| auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)}; | |||
| AnfAlgo::SetOutputInferTypeAndShape(types, shapes, replace_node.get()); | |||
| replace_node->set_scope(node->scope()); | |||
| return replace_node; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,52 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_APPLY_MOMENTUM_WEIGHT_DECAY_SCALE_FUSION_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_APPLY_MOMENTUM_WEIGHT_DECAY_SCALE_FUSION_H_ | |||
| #include <memory> | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class ApplyMomentumWeightDecayScaleFusion : public PatternProcessPass { | |||
| public: | |||
| explicit ApplyMomentumWeightDecayScaleFusion(bool multigraph = true) | |||
| : PatternProcessPass("momentum_weightdecay_scale_fusion", multigraph) { | |||
| weight_decay_ = std::make_shared<Var>(); | |||
| scale_ = std::make_shared<Var>(); | |||
| variable_ = std::make_shared<Var>(); | |||
| accumulation_ = std::make_shared<Var>(); | |||
| learning_rate_ = std::make_shared<Var>(); | |||
| gradient_ = std::make_shared<Var>(); | |||
| momentum_ = std::make_shared<Var>(); | |||
| } | |||
| ~ApplyMomentumWeightDecayScaleFusion() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| private: | |||
| VarPtr weight_decay_; | |||
| VarPtr scale_; | |||
| VarPtr variable_; | |||
| VarPtr accumulation_; | |||
| VarPtr learning_rate_; | |||
| VarPtr gradient_; | |||
| VarPtr momentum_; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_APPLY_MOMENTUM_WEIGHT_DECAY_SCALE_FUSION_H_ | |||
| @@ -0,0 +1,175 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.h" | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include <string> | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "ir/primitive.h" | |||
| #include "utils/utils.h" | |||
| #include "backend/optimizer/common/helper.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
// TupleGetItem indices that GetBatchNormOutputs accepts on users of the batch-norm node.
const std::vector<int> kOutputIndex = {0, 1, 2};
// Number of TupleGetItem users GetBatchNormOutputs requires before it reports success.
constexpr size_t kBNGradOutputNum = 3;
constexpr size_t kBNAddReluGradOutputNum = 4;
| bool GetBatchNormOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, std::vector<AnfNodePtr> *bn_outputs) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(bn_outputs); | |||
| auto manager = func_graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| if (manager->node_users().find(bn) == manager->node_users().end()) { | |||
| return false; | |||
| } | |||
| size_t output_num = 0; | |||
| for (const auto &node_index : manager->node_users()[bn]) { | |||
| const AnfNodePtr &output = node_index.first; | |||
| MS_EXCEPTION_IF_NULL(output); | |||
| if (!IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) { | |||
| continue; | |||
| } | |||
| auto tuple_getiterm_cnode = output->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(tuple_getiterm_cnode); | |||
| auto index_node = tuple_getiterm_cnode->input(kInputNodeOutputIndexInTupleGetItem); | |||
| MS_EXCEPTION_IF_NULL(index_node); | |||
| auto value_node = index_node->cast<ValueNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(value_node); | |||
| int index = GetValue<int>(value_node->value()); | |||
| if (std::find(kOutputIndex.begin(), kOutputIndex.end(), index) == kOutputIndex.end()) { | |||
| return false; | |||
| } | |||
| bn_outputs->push_back(output); | |||
| output_num++; | |||
| } | |||
| return output_num == kBNGradOutputNum; | |||
| } | |||
| void SetShapeAndType(const CNodePtr &bn_add_relu_grad, const AnfNodePtr &bn_grad, const AnfNodePtr &relu_grad) { | |||
| // set output shape and dtype | |||
| std::vector<TypeId> outputs_type; | |||
| std::vector<std::vector<size_t>> outputs_shape; | |||
| auto output_num = AnfAlgo::GetOutputTensorNum(bn_grad); | |||
| for (size_t i = 0; i < output_num; ++i) { | |||
| outputs_type.push_back(AnfAlgo::GetOutputInferDataType(bn_grad, i)); | |||
| outputs_shape.push_back(AnfAlgo::GetOutputInferShape(bn_grad, i)); | |||
| } | |||
| outputs_type.push_back(AnfAlgo::GetOutputInferDataType(relu_grad, 0)); | |||
| outputs_shape.push_back(AnfAlgo::GetOutputInferShape(relu_grad, 0)); | |||
| AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, bn_add_relu_grad.get()); | |||
| } | |||
| void ReplaceOutput(const FuncGraphPtr &graph, const AnfNodePtr &bn_grad, const AnfNodePtr &relu_grad, | |||
| const CNodePtr &bn_add_relu_grad) { | |||
| // Create outputs | |||
| std::vector<AnfNodePtr> bn_add_relu_grad_output; | |||
| CreateMultipleOutputsOfAnfNode(graph, bn_add_relu_grad, kBNAddReluGradOutputNum, &bn_add_relu_grad_output); | |||
| if (bn_add_relu_grad_output.size() != kBNAddReluGradOutputNum) { | |||
| MS_LOG(EXCEPTION) << "The output size of node " << kFusedBatchNormGradExWithAddAndActivation << " must be " | |||
| << kBNAddReluGradOutputNum << ", but it is " << bn_add_relu_grad_output.size(); | |||
| } | |||
| // Get bn outputs | |||
| std::vector<AnfNodePtr> bn_outputs; | |||
| if (!GetBatchNormOutputs(graph, bn_grad, &bn_outputs)) { | |||
| MS_LOG(INFO) << "The " << prim::kPrimFusedBatchNormGradEx | |||
| << " node should only have output 0, 1 and 2. The node should not be changed"; | |||
| return; | |||
| } | |||
| // Replace orignal output | |||
| auto manager = graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| sort(bn_outputs.begin(), bn_outputs.end(), CompareTupleGetitem); | |||
| size_t output_index = 0; | |||
| for (const auto &output : bn_outputs) { | |||
| (void)manager->Replace(output, bn_add_relu_grad_output[output_index]); | |||
| output_index++; | |||
| } | |||
| manager->Replace(relu_grad, bn_add_relu_grad_output[kBNAddReluGradOutputNum - 1]); | |||
| return; | |||
| } | |||
| } // namespace | |||
| const BaseRef BatchNormAddReluGradFusion::DefinePattern() const { | |||
| VectorRef relu_grad = VectorRef({prim::kPrimReluGrad, dy_, y_}); | |||
| VectorRef batch_norm_grad = | |||
| VectorRef({prim::kPrimFusedBatchNormGradEx, relu_grad, x_, scale_, save_mean_, save_var_, reserve_}); | |||
| return batch_norm_grad; | |||
| } | |||
| const AnfNodePtr BatchNormAddReluGradFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, | |||
| const EquivPtr &) const { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (AnfAlgo::GetOutputInferDataType(node, 0) != kNumberTypeFloat16) { | |||
| return nullptr; | |||
| } | |||
| auto relu_grad = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0); | |||
| MS_EXCEPTION_IF_NULL(relu_grad); | |||
| auto relu_users = GetRealNodeUsedList(graph, relu_grad); | |||
| if (relu_users->size() != 2) { | |||
| return nullptr; | |||
| } | |||
| // process pattern as Relu(TensorAdd(BN#0, BN#1)) | |||
| auto tuple_getitem = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 5); | |||
| MS_EXCEPTION_IF_NULL(tuple_getitem); | |||
| auto forward_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple_getitem), 0); | |||
| if (AnfAlgo::GetCNodeName(forward_node) != kFusedBatchNormExWithAddAndActivation) { | |||
| return nullptr; | |||
| } | |||
| auto dy = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(relu_grad), 0); | |||
| MS_EXCEPTION_IF_NULL(dy); | |||
| auto y = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(relu_grad), 1); | |||
| MS_EXCEPTION_IF_NULL(y); | |||
| auto x = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 1); | |||
| MS_EXCEPTION_IF_NULL(x); | |||
| auto scale = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 2); | |||
| MS_EXCEPTION_IF_NULL(scale); | |||
| auto save_mean = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 3); | |||
| MS_EXCEPTION_IF_NULL(save_mean); | |||
| auto save_var = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 4); | |||
| MS_EXCEPTION_IF_NULL(save_var); | |||
| auto reserve = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 5); | |||
| MS_EXCEPTION_IF_NULL(reserve); | |||
| auto batch_norm = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(save_mean), 0); | |||
| MS_EXCEPTION_IF_NULL(batch_norm); | |||
| auto bias = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 2); | |||
| MS_EXCEPTION_IF_NULL(bias); | |||
| auto prim = std::make_shared<Primitive>(kFusedBatchNormGradExWithAddAndActivation); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| std::vector<AnfNodePtr> inputs = {NewValueNode(prim), dy, x, scale, save_mean, save_var, reserve, bias, y}; | |||
| auto fused_batch_norm_add_relu_grad = graph->NewCNode(inputs); | |||
| MS_EXCEPTION_IF_NULL(fused_batch_norm_add_relu_grad); | |||
| AnfAlgo::CopyNodeAttrs(node, fused_batch_norm_add_relu_grad); | |||
| SetShapeAndType(fused_batch_norm_add_relu_grad, node, relu_grad); | |||
| ReplaceOutput(graph, node, relu_grad, fused_batch_norm_add_relu_grad); | |||
| return nullptr; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,57 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_BATCH_NORM_ADD_RELU_GRAD_FUSION_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_BATCH_NORM_ADD_RELU_GRAD_FUSION_H_ | |||
| #include <memory> | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class BatchNormAddReluGradFusion : public PatternProcessPass { | |||
| public: | |||
| explicit BatchNormAddReluGradFusion(bool multigraph = true) | |||
| : PatternProcessPass("batch_norm_add_relu_grad_fusion", multigraph) { | |||
| dy_ = std::make_shared<Var>(); | |||
| y_ = std::make_shared<Var>(); | |||
| x_ = std::make_shared<Var>(); | |||
| scale_ = std::make_shared<Var>(); | |||
| bias_ = std::make_shared<Var>(); | |||
| mean_ = std::make_shared<Var>(); | |||
| var_ = std::make_shared<Var>(); | |||
| save_mean_ = std::make_shared<Var>(); | |||
| save_var_ = std::make_shared<Var>(); | |||
| reserve_ = std::make_shared<Var>(); | |||
| } | |||
| ~BatchNormAddReluGradFusion() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| private: | |||
| VarPtr dy_; | |||
| VarPtr y_; | |||
| VarPtr x_; | |||
| VarPtr scale_; | |||
| VarPtr bias_; | |||
| VarPtr mean_; | |||
| VarPtr var_; | |||
| VarPtr save_mean_; | |||
| VarPtr save_var_; | |||
| VarPtr reserve_; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_BATCH_NORM_ADD_RELU_GRAD_FUSION_H_ | |||
| @@ -26,11 +26,14 @@ | |||
| #include "backend/optimizer/pass/getitem_tuple.h" | |||
| #include "backend/optimizer/gpu/adam_weight_decay_fusion.h" | |||
| #include "backend/optimizer/gpu/adam_fusion.h" | |||
| #include "backend/optimizer/gpu/apply_momentum_weight_scale_fusion.h" | |||
| #include "backend/optimizer/gpu/apply_momentum_scale_fusion.h" | |||
| #include "backend/optimizer/gpu/replace_bn_cast_fusion.h" | |||
| #include "backend/optimizer/gpu/replace_bn_grad_cast_fusion.h" | |||
| #include "backend/optimizer/gpu/batch_norm_relu_fusion.h" | |||
| #include "backend/optimizer/gpu/batch_norm_relu_grad_fusion.h" | |||
| #include "backend/optimizer/gpu/batch_norm_add_relu_fusion.h" | |||
| #include "backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.h" | |||
| #include "backend/optimizer/gpu/replace_momentum_cast_fusion.h" | |||
| #include "backend/optimizer/gpu/replace_addn_fusion.h" | |||
| #include "backend/optimizer/gpu/insert_format_transform_op.h" | |||
| @@ -73,6 +76,8 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) { | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| pm->AddPass(std::make_shared<opt::AdamWeightDecayFusion>()); | |||
| pm->AddPass(std::make_shared<opt::AdamFusion>()); | |||
| // pm->AddPass(std::make_shared<opt::ApplyMomentumWeightDecayScaleFusion>()); | |||
| // pm->AddPass(std::make_shared<opt::ApplyMomentumScaleFusion>()); | |||
| pm->AddPass(std::make_shared<opt::ReplaceBNCastFusion>()); | |||
| pm->AddPass(std::make_shared<opt::ReplaceBNGradCastFusion>()); | |||
| pm->AddPass(std::make_shared<opt::ReplaceMomentumCastFusion>()); | |||
| @@ -81,6 +86,7 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) { | |||
| pm->AddPass(std::make_shared<opt::BatchNormReluFusion>()); | |||
| pm->AddPass(std::make_shared<opt::BatchNormReluGradFusion>()); | |||
| pm->AddPass(std::make_shared<opt::BatchNormAddReluFusion>()); | |||
| // pm->AddPass(std::make_shared<opt::BatchNormAddReluGradFusion>()); | |||
| } | |||
| optimizer->AddPassManager(pm); | |||
| (void)optimizer->Optimize(kernel_graph); | |||
| @@ -193,6 +193,8 @@ constexpr auto kPaddingOpName = "Padding"; | |||
| constexpr auto kAvgPoolOpName = "AvgPool"; | |||
| constexpr auto kAvgPoolGradGpuOpName = "AvgPoolGradGpu"; | |||
| constexpr auto kTensorAddOpName = "TensorAdd"; | |||
| constexpr auto kFusedWeightScaleApplyMomentum = "FusedWeightScaleApplyMomentum"; | |||
| constexpr auto kFusedScaleApplyMomentum = "FusedScaleApplyMomentum"; | |||
| // attr key name | |||
| constexpr auto kAttrInputNames = "input_names"; | |||