Merge pull request !3092 from VectorSL/momentumtags/v0.6.0-beta
| @@ -15,9 +15,9 @@ | |||||
| */ | */ | ||||
| #include "momentum_impl.cuh" | #include "momentum_impl.cuh" | ||||
| template <typename T, typename S> | |||||
| template <typename T, typename S, typename G> | |||||
| __global__ void MomentumUpdateVariableKernel(const size_t size, T *variable, T *accumulation, const S *learning_rate, | __global__ void MomentumUpdateVariableKernel(const size_t size, T *variable, T *accumulation, const S *learning_rate, | ||||
| const T *gradient, const S *momentum) { | |||||
| const G *gradient, const S *momentum) { | |||||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) { | for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) { | ||||
| accumulation[i] = momentum[0] * accumulation[i] + gradient[i]; | accumulation[i] = momentum[0] * accumulation[i] + gradient[i]; | ||||
| variable[i] -= learning_rate[0] * accumulation[i]; | variable[i] -= learning_rate[0] * accumulation[i]; | ||||
| @@ -34,19 +34,32 @@ __global__ void MomentumUpdateVariableKernel(const size_t size, half *variable, | |||||
| } | } | ||||
| return; | return; | ||||
| } | } | ||||
| template <typename T, typename S> | |||||
| void MomentumUpdateVariable(const size_t size, T *variable, T *accumulation, const S *learning_rate, const T *gradient, | |||||
// Mixed-precision specialization of the momentum update: variable, accumulation, learning rate and
// momentum are float while the gradient is half.
//
// Each thread starts at its global thread index and advances by the total number of launched threads
// until it passes `size`, applying per element:
//   accumulation[i] = momentum[0] * accumulation[i] + float(gradient[i])
//   variable[i]    -= learning_rate[0] * accumulation[i]
//
// size          - number of elements to update.
// variable      - updated in place.
// accumulation  - updated in place.
// learning_rate - only element 0 is read.
// gradient      - half values, converted with __half2float before use.
// momentum      - only element 0 is read.
template <>
__global__ void MomentumUpdateVariableKernel(const size_t size, float *variable, float *accumulation,
                                             const float *learning_rate, const half *gradient,
                                             const float *momentum) {
  // blockIdx, blockDim and gridDim hold 32-bit unsigned components, so widen to size_t before multiplying;
  // otherwise the products can wrap before they are converted to the loop index type.
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  for (size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x; i < size; i += stride) {
    accumulation[i] = momentum[0] * accumulation[i] + __half2float(gradient[i]);
    variable[i] -= learning_rate[0] * accumulation[i];
  }
  return;
}
// Host-side launcher for MomentumUpdateVariableKernel.
//
// Enqueues the kernel on `cuda_stream` with GET_BLOCKS(size) blocks of GET_THREADS threads and no dynamic
// shared memory. The function itself performs no synchronization and does not inspect the launch status.
//
// T - element type of variable and accumulation.
// S - element type of learning_rate and momentum.
// G - element type of gradient.
template <typename T, typename S, typename G>
void MomentumUpdateVariable(const size_t size, T *variable, T *accumulation, const S *learning_rate, const G *gradient,
                            const S *momentum, cudaStream_t cuda_stream) {
  // Nothing to update for an empty tensor, so skip the launch entirely.
  // NOTE(review): GET_BLOCKS is defined elsewhere; if it yields 0 for size == 0 the launch would be rejected
  // as an invalid configuration, which this guard also avoids - confirm against its definition.
  if (size == 0) {
    return;
  }
  MomentumUpdateVariableKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, variable, accumulation,
                                                                                  learning_rate, gradient, momentum);
  return;
}
| template void MomentumUpdateVariable<float, float>(const size_t size, float *variable, float *accumulation, | |||||
| template void MomentumUpdateVariable<float, float, float>(const size_t size, float *variable, float *accumulation, | |||||
| const float *learning_rate, const float *gradient, | const float *learning_rate, const float *gradient, | ||||
| const float *momentum, cudaStream_t cuda_stream); | const float *momentum, cudaStream_t cuda_stream); | ||||
| template void MomentumUpdateVariable<half, half>(const size_t size, half *variable, half *accumulation, | |||||
| template void MomentumUpdateVariable<half, half, half>(const size_t size, half *variable, half *accumulation, | |||||
| const half *learning_rate, const half *gradient, | const half *learning_rate, const half *gradient, | ||||
| const half *momentum, cudaStream_t cuda_stream); | const half *momentum, cudaStream_t cuda_stream); | ||||
| template void MomentumUpdateVariable<half, float>(const size_t size, half *variable, half *accumulation, | |||||
| template void MomentumUpdateVariable<half, float, half>(const size_t size, half *variable, half *accumulation, | |||||
| const float *learning_rate, const half *gradient, | |||||
| const float *momentum, cudaStream_t cuda_stream); | |||||
| template void MomentumUpdateVariable<float, float, half>(const size_t size, float *variable, float *accumulation, | |||||
| const float *learning_rate, const half *gradient, | const float *learning_rate, const half *gradient, | ||||
| const float *momentum, cudaStream_t cuda_stream); | const float *momentum, cudaStream_t cuda_stream); | ||||
| @@ -18,8 +18,8 @@ | |||||
| #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MOMENTUMIMPL_H_ | #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MOMENTUMIMPL_H_ | ||||
| #include "runtime/device/gpu/cuda_common.h" | #include "runtime/device/gpu/cuda_common.h" | ||||
| template <typename T, typename S> | |||||
| void MomentumUpdateVariable(const size_t size, T *variable, T *accumulation, const S *learning_rate, const T *gradient, | |||||
| template <typename T, typename S, typename G> | |||||
| void MomentumUpdateVariable(const size_t size, T *variable, T *accumulation, const S *learning_rate, const G *gradient, | |||||
| const S *momentum, cudaStream_t cuda_stream); | const S *momentum, cudaStream_t cuda_stream); | ||||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MOMENTUMIMPL_H_ | #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MOMENTUMIMPL_H_ | ||||
| @@ -88,6 +88,12 @@ class GpuKernelRegister { | |||||
| static_assert(std::is_base_of<GpuKernel, OPCLASS<T, S>>::value, " must be base of GpuKernel"); \ | static_assert(std::is_base_of<GpuKernel, OPCLASS<T, S>>::value, " must be base of GpuKernel"); \ | ||||
| static const GpuKernelRegister g_##OPNAME##_##T##_##S##_gpu_kernel_reg(#OPNAME, ATTR, \ | static const GpuKernelRegister g_##OPNAME##_##T##_##S##_gpu_kernel_reg(#OPNAME, ATTR, \ | ||||
| []() { return new OPCLASS<T, S>(); }); | []() { return new OPCLASS<T, S>(); }); | ||||
| // Registers a mixed-precision GPU kernel whose class template takes three type parameters (T, S, G). | |||||
| #define MS_REG_GPU_KERNEL_THREE(OPNAME, ATTR, OPCLASS, T, S, G) \ | |||||
| static_assert(std::is_base_of<GpuKernel, OPCLASS<T, S, G>>::value, " must be base of GpuKernel"); \ | |||||
| static const GpuKernelRegister g_##OPNAME##_##T##_##S##_##G##_gpu_kernel_reg( \ | |||||
| #OPNAME, ATTR, []() { return new OPCLASS<T, S, G>(); }); | |||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_GPUKERNELFACTORY_H_ | #endif // MINDSPORE_CCSRC_KERNEL_GPU_GPUKERNELFACTORY_H_ | ||||
| @@ -34,15 +34,15 @@ MS_REG_GPU_KERNEL_ONE(FusedBatchNorm, | |||||
| MS_REG_GPU_KERNEL_ONE(FusedBatchNorm, | MS_REG_GPU_KERNEL_ONE(FusedBatchNorm, | ||||
| KernelAttr() | KernelAttr() | ||||
| .AddInputAttr(kNumberTypeFloat16) | .AddInputAttr(kNumberTypeFloat16) | ||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddOutputAttr(kNumberTypeFloat16) | |||||
| .AddOutputAttr(kNumberTypeFloat16) | |||||
| .AddOutputAttr(kNumberTypeFloat16) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddOutputAttr(kNumberTypeFloat16) | .AddOutputAttr(kNumberTypeFloat16) | ||||
| .AddOutputAttr(kNumberTypeFloat16), | |||||
| .AddOutputAttr(kNumberTypeFloat32) | |||||
| .AddOutputAttr(kNumberTypeFloat32) | |||||
| .AddOutputAttr(kNumberTypeFloat32) | |||||
| .AddOutputAttr(kNumberTypeFloat32), | |||||
| FusedBatchNormGpuKernel, half) | FusedBatchNormGpuKernel, half) | ||||
| MS_REG_GPU_KERNEL_ONE(BatchNorm, | MS_REG_GPU_KERNEL_ONE(BatchNorm, | ||||
| KernelAttr() | KernelAttr() | ||||
| @@ -60,15 +60,15 @@ MS_REG_GPU_KERNEL_ONE(BatchNorm, | |||||
| MS_REG_GPU_KERNEL_ONE(BatchNorm, | MS_REG_GPU_KERNEL_ONE(BatchNorm, | ||||
| KernelAttr() | KernelAttr() | ||||
| .AddInputAttr(kNumberTypeFloat16) | .AddInputAttr(kNumberTypeFloat16) | ||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddOutputAttr(kNumberTypeFloat16) | |||||
| .AddOutputAttr(kNumberTypeFloat16) | |||||
| .AddOutputAttr(kNumberTypeFloat16) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddOutputAttr(kNumberTypeFloat16) | .AddOutputAttr(kNumberTypeFloat16) | ||||
| .AddOutputAttr(kNumberTypeFloat16), | |||||
| .AddOutputAttr(kNumberTypeFloat32) | |||||
| .AddOutputAttr(kNumberTypeFloat32) | |||||
| .AddOutputAttr(kNumberTypeFloat32) | |||||
| .AddOutputAttr(kNumberTypeFloat32), | |||||
| FusedBatchNormGpuKernel, half) | FusedBatchNormGpuKernel, half) | ||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -56,17 +56,17 @@ class FusedBatchNormGpuKernel : public GpuKernel { | |||||
| return true; | return true; | ||||
| } | } | ||||
| auto x = GetDeviceAddress<T>(inputs, 0); | auto x = GetDeviceAddress<T>(inputs, 0); | ||||
| auto scale = GetDeviceAddress<T>(inputs, 1); | |||||
| auto bias = GetDeviceAddress<T>(inputs, 2); | |||||
| auto runing_mean = GetDeviceAddress<T>(inputs, 3); | |||||
| auto runnig_variance = GetDeviceAddress<T>(inputs, 4); | |||||
| auto scale = GetDeviceAddress<float>(inputs, 1); | |||||
| auto bias = GetDeviceAddress<float>(inputs, 2); | |||||
| auto runing_mean = GetDeviceAddress<float>(inputs, 3); | |||||
| auto runnig_variance = GetDeviceAddress<float>(inputs, 4); | |||||
| auto y = GetDeviceAddress<T>(outputs, 0); | auto y = GetDeviceAddress<T>(outputs, 0); | ||||
| const float alpha = 1; | const float alpha = 1; | ||||
| const float beta = 0; | const float beta = 0; | ||||
| if (is_train_) { | if (is_train_) { | ||||
| auto save_mean = GetDeviceAddress<T>(outputs, 3); | |||||
| auto save_variance = GetDeviceAddress<T>(outputs, 4); | |||||
| auto save_mean = GetDeviceAddress<float>(outputs, 3); | |||||
| auto save_variance = GetDeviceAddress<float>(outputs, 4); | |||||
| CHECK_CUDNN_RET_WITH_EXCEPT( | CHECK_CUDNN_RET_WITH_EXCEPT( | ||||
| cudnnBatchNormalizationForwardTraining(handle_, mode_, &alpha, &beta, x_desc_, x, y_desc_, y, | cudnnBatchNormalizationForwardTraining(handle_, mode_, &alpha, &beta, x_desc_, x, y_desc_, y, | ||||
| scale_bias_mean_var_desc_, scale, bias, exp_avg_factor_, runing_mean, | scale_bias_mean_var_desc_, scale, bias, exp_avg_factor_, runing_mean, | ||||
| @@ -33,12 +33,12 @@ MS_REG_GPU_KERNEL_ONE(FusedBatchNormGrad, | |||||
| KernelAttr() | KernelAttr() | ||||
| .AddInputAttr(kNumberTypeFloat16) | .AddInputAttr(kNumberTypeFloat16) | ||||
| .AddInputAttr(kNumberTypeFloat16) | .AddInputAttr(kNumberTypeFloat16) | ||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddOutputAttr(kNumberTypeFloat16) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddOutputAttr(kNumberTypeFloat16) | .AddOutputAttr(kNumberTypeFloat16) | ||||
| .AddOutputAttr(kNumberTypeFloat16), | |||||
| .AddOutputAttr(kNumberTypeFloat32) | |||||
| .AddOutputAttr(kNumberTypeFloat32), | |||||
| FusedBatchNormGradGpuKernel, half) | FusedBatchNormGradGpuKernel, half) | ||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -55,12 +55,12 @@ class FusedBatchNormGradGpuKernel : public GpuKernel { | |||||
| } | } | ||||
| auto dy = GetDeviceAddress<T>(inputs, 0); | auto dy = GetDeviceAddress<T>(inputs, 0); | ||||
| auto x = GetDeviceAddress<T>(inputs, 1); | auto x = GetDeviceAddress<T>(inputs, 1); | ||||
| auto scale = GetDeviceAddress<T>(inputs, 2); | |||||
| auto save_mean = GetDeviceAddress<T>(inputs, 3); | |||||
| auto save_variance = GetDeviceAddress<T>(inputs, 4); | |||||
| auto scale = GetDeviceAddress<float>(inputs, 2); | |||||
| auto save_mean = GetDeviceAddress<float>(inputs, 3); | |||||
| auto save_variance = GetDeviceAddress<float>(inputs, 4); | |||||
| auto dx = GetDeviceAddress<T>(outputs, 0); | auto dx = GetDeviceAddress<T>(outputs, 0); | ||||
| auto bn_scale = GetDeviceAddress<T>(outputs, 1); | |||||
| auto bn_bias = GetDeviceAddress<T>(outputs, 2); | |||||
| auto bn_scale = GetDeviceAddress<float>(outputs, 1); | |||||
| auto bn_bias = GetDeviceAddress<float>(outputs, 2); | |||||
| const float alpha_data_diff = 1; | const float alpha_data_diff = 1; | ||||
| const float beta_data_diff = 0; | const float beta_data_diff = 0; | ||||
| @@ -18,32 +18,41 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace kernel { | namespace kernel { | ||||
| MS_REG_GPU_KERNEL_TWO(ApplyMomentum, | |||||
| KernelAttr() | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddOutputAttr(kNumberTypeFloat32), | |||||
| MomentumGpuKernel, float, float) | |||||
| MS_REG_GPU_KERNEL_TWO(ApplyMomentum, | |||||
| KernelAttr() | |||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddOutputAttr(kNumberTypeFloat16), | |||||
| MomentumGpuKernel, half, half) | |||||
| MS_REG_GPU_KERNEL_TWO(ApplyMomentum, | |||||
| KernelAttr() | |||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddOutputAttr(kNumberTypeFloat16), | |||||
| MomentumGpuKernel, half, float) | |||||
| MS_REG_GPU_KERNEL_THREE(ApplyMomentum, | |||||
| KernelAttr() | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddOutputAttr(kNumberTypeFloat32), | |||||
| MomentumGpuKernel, float, float, float) | |||||
| MS_REG_GPU_KERNEL_THREE(ApplyMomentum, | |||||
| KernelAttr() | |||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddOutputAttr(kNumberTypeFloat16), | |||||
| MomentumGpuKernel, half, half, half) | |||||
| MS_REG_GPU_KERNEL_THREE(ApplyMomentum, | |||||
| KernelAttr() | |||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddOutputAttr(kNumberTypeFloat16), | |||||
| MomentumGpuKernel, half, float, half) | |||||
| MS_REG_GPU_KERNEL_THREE(ApplyMomentum, | |||||
| KernelAttr() | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddOutputAttr(kNumberTypeFloat32), | |||||
| MomentumGpuKernel, float, float, half) | |||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -23,7 +23,7 @@ | |||||
| #include "backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh" | #include "backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace kernel { | namespace kernel { | ||||
| template <typename T, typename S> | |||||
| template <typename T, typename S, typename G> | |||||
| class MomentumGpuKernel : public GpuKernel { | class MomentumGpuKernel : public GpuKernel { | ||||
| public: | public: | ||||
| MomentumGpuKernel() | MomentumGpuKernel() | ||||
| @@ -38,7 +38,7 @@ class MomentumGpuKernel : public GpuKernel { | |||||
| T *variable = GetDeviceAddress<T>(inputs, 0); | T *variable = GetDeviceAddress<T>(inputs, 0); | ||||
| T *accumulation = GetDeviceAddress<T>(inputs, 1); | T *accumulation = GetDeviceAddress<T>(inputs, 1); | ||||
| S *learning_rate = GetDeviceAddress<S>(inputs, 2); | S *learning_rate = GetDeviceAddress<S>(inputs, 2); | ||||
| T *gradient = GetDeviceAddress<T>(inputs, 3); | |||||
| G *gradient = GetDeviceAddress<G>(inputs, 3); | |||||
| S *momentum = GetDeviceAddress<S>(inputs, 4); | S *momentum = GetDeviceAddress<S>(inputs, 4); | ||||
| MomentumUpdateVariable(inputs[0]->size / sizeof(T), variable, accumulation, learning_rate, gradient, momentum, | MomentumUpdateVariable(inputs[0]->size / sizeof(T), variable, accumulation, learning_rate, gradient, momentum, | ||||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | reinterpret_cast<cudaStream_t>(stream_ptr)); | ||||
| @@ -54,7 +54,7 @@ class MomentumGpuKernel : public GpuKernel { | |||||
| variable_size_ = sizeof(T); | variable_size_ = sizeof(T); | ||||
| accumulation_size_ = sizeof(T); | accumulation_size_ = sizeof(T); | ||||
| learning_rate_size_ = sizeof(S); | learning_rate_size_ = sizeof(S); | ||||
| gradient_size_ = sizeof(T); | |||||
| gradient_size_ = sizeof(G); | |||||
| momentum_size_ = sizeof(S); | momentum_size_ = sizeof(S); | ||||
| auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | ||||
| @@ -0,0 +1,63 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "backend/optimizer/gpu/replace_momentum_cast_fusion.h" | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include <string> | |||||
| #include "backend/session/anf_runtime_algorithm.h" | |||||
| #include "ir/primitive.h" | |||||
| #include "utils/utils.h" | |||||
| #include "backend/optimizer/common/helper.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| const BaseRef ReplaceMomentumCastFusion::DefinePattern() const { | |||||
| VectorRef grad_cast = VectorRef({prim::kPrimCast, grad_}); | |||||
| VectorRef momentum = VectorRef({prim::kPrimApplyMomentum, var_, acc_, lr_, grad_cast, mom_}); | |||||
| return momentum; | |||||
| } | |||||
| const AnfNodePtr ReplaceMomentumCastFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, | |||||
| const EquivPtr &equiv) const { | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| MS_EXCEPTION_IF_NULL(equiv); | |||||
| auto grad_cast = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 3); | |||||
| auto grad = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(grad_cast), 0); | |||||
| MS_EXCEPTION_IF_NULL(grad_cast); | |||||
| MS_EXCEPTION_IF_NULL(grad); | |||||
| auto manager = graph->manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| manager->Replace(utils::cast<CNodePtr>(grad_cast), utils::cast<CNodePtr>(grad)); | |||||
| std::vector<TypeId> outputs_type; | |||||
| std::vector<std::vector<size_t>> outputs_shape; | |||||
| auto output_num = AnfAlgo::GetOutputTensorNum(node); | |||||
| for (size_t i = 0; i < output_num; i++) { | |||||
| outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, i)); | |||||
| outputs_shape.push_back(AnfAlgo::GetOutputInferShape(node, i)); | |||||
| } | |||||
| outputs_type[3] = AnfAlgo::GetPrevNodeOutputInferDataType(grad_cast, 0); | |||||
| AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, node.get()); | |||||
| return node; | |||||
| } | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,46 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REPLACE_MOMENTUM_CAST_FUSION_H_ | |||||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REPLACE_MOMENTUM_CAST_FUSION_H_ | |||||
| #include <memory> | |||||
| #include "backend/optimizer/common/optimizer.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| class ReplaceMomentumCastFusion : public PatternProcessPass { | |||||
| public: | |||||
| explicit ReplaceMomentumCastFusion(bool multigraph = true) : PatternProcessPass("replace_momentum_cast", multigraph) { | |||||
| var_ = std::make_shared<Var>(); | |||||
| acc_ = std::make_shared<Var>(); | |||||
| lr_ = std::make_shared<Var>(); | |||||
| grad_ = std::make_shared<Var>(); | |||||
| mom_ = std::make_shared<Var>(); | |||||
| } | |||||
| ~ReplaceMomentumCastFusion() override = default; | |||||
| const BaseRef DefinePattern() const override; | |||||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||||
| private: | |||||
| VarPtr var_; | |||||
| VarPtr acc_; | |||||
| VarPtr lr_; | |||||
| VarPtr grad_; | |||||
| VarPtr mom_; | |||||
| }; | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REPLACE_MOMENTUM_CAST_FUSION_H_ | |||||
| @@ -25,6 +25,11 @@ | |||||
| #include "backend/optimizer/pass/getitem_tuple.h" | #include "backend/optimizer/pass/getitem_tuple.h" | ||||
| #include "backend/optimizer/gpu/adam_weight_decay_fusion.h" | #include "backend/optimizer/gpu/adam_weight_decay_fusion.h" | ||||
| #include "backend/optimizer/gpu/adam_fusion.h" | #include "backend/optimizer/gpu/adam_fusion.h" | ||||
| #include "backend/optimizer/gpu/replace_bn_cast_fusion.h" | |||||
| #include "backend/optimizer/gpu/replace_bn_grad_cast_fusion.h" | |||||
| #include "backend/optimizer/gpu/replace_bn_grad_cast2_fusion.h" | |||||
| #include "backend/optimizer/gpu/replace_momentum_cast_fusion.h" | |||||
| #include "backend/optimizer/gpu/replace_addn_fusion.h" | |||||
| #include "runtime/device/kernel_runtime_manager.h" | #include "runtime/device/kernel_runtime_manager.h" | ||||
| #include "predict/predict.h" | #include "predict/predict.h" | ||||
| #include "common/utils.h" | #include "common/utils.h" | ||||
| @@ -59,6 +64,11 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) { | |||||
| auto pm = std::make_shared<opt::PassManager>(); | auto pm = std::make_shared<opt::PassManager>(); | ||||
| pm->AddPass(std::make_shared<opt::AdamWeightDecayFusion>()); | pm->AddPass(std::make_shared<opt::AdamWeightDecayFusion>()); | ||||
| pm->AddPass(std::make_shared<opt::AdamFusion>()); | pm->AddPass(std::make_shared<opt::AdamFusion>()); | ||||
| pm->AddPass(std::make_shared<opt::ReplaceBNCastFusion>()); | |||||
| pm->AddPass(std::make_shared<opt::ReplaceBNGradCastFusion>()); | |||||
| pm->AddPass(std::make_shared<opt::ReplaceBNGradCast2Fusion>()); | |||||
| pm->AddPass(std::make_shared<opt::ReplaceMomentumCastFusion>()); | |||||
| pm->AddPass(std::make_shared<opt::ReplaceAddNFusion>()); | |||||
| optimizer->AddPassManager(pm); | optimizer->AddPassManager(pm); | ||||
| (void)optimizer->Optimize(kernel_graph); | (void)optimizer->Optimize(kernel_graph); | ||||
| kernel_graph->SetExecOrderByDefault(); | kernel_graph->SetExecOrderByDefault(); | ||||