Merge pull request !3092 from VectorSL/momentumtags/v0.6.0-beta
| @@ -15,9 +15,9 @@ | |||
| */ | |||
| #include "momentum_impl.cuh" | |||
| template <typename T, typename S> | |||
| template <typename T, typename S, typename G> | |||
| __global__ void MomentumUpdateVariableKernel(const size_t size, T *variable, T *accumulation, const S *learning_rate, | |||
| const T *gradient, const S *momentum) { | |||
| const G *gradient, const S *momentum) { | |||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) { | |||
| accumulation[i] = momentum[0] * accumulation[i] + gradient[i]; | |||
| variable[i] -= learning_rate[0] * accumulation[i]; | |||
| @@ -34,19 +34,32 @@ __global__ void MomentumUpdateVariableKernel(const size_t size, half *variable, | |||
| } | |||
| return; | |||
| } | |||
| template <typename T, typename S> | |||
| void MomentumUpdateVariable(const size_t size, T *variable, T *accumulation, const S *learning_rate, const T *gradient, | |||
| template <> | |||
| __global__ void MomentumUpdateVariableKernel(const size_t size, float *variable, float *accumulation, | |||
| const float *learning_rate, const half *gradient, | |||
| const float *momentum) { | |||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) { | |||
| accumulation[i] = momentum[0] * accumulation[i] + __half2float(gradient[i]); | |||
| variable[i] -= learning_rate[0] * accumulation[i]; | |||
| } | |||
| return; | |||
| } | |||
| template <typename T, typename S, typename G> | |||
| void MomentumUpdateVariable(const size_t size, T *variable, T *accumulation, const S *learning_rate, const G *gradient, | |||
| const S *momentum, cudaStream_t cuda_stream) { | |||
| MomentumUpdateVariableKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, variable, accumulation, | |||
| learning_rate, gradient, momentum); | |||
| return; | |||
| } | |||
| template void MomentumUpdateVariable<float, float>(const size_t size, float *variable, float *accumulation, | |||
| template void MomentumUpdateVariable<float, float, float>(const size_t size, float *variable, float *accumulation, | |||
| const float *learning_rate, const float *gradient, | |||
| const float *momentum, cudaStream_t cuda_stream); | |||
| template void MomentumUpdateVariable<half, half>(const size_t size, half *variable, half *accumulation, | |||
| template void MomentumUpdateVariable<half, half, half>(const size_t size, half *variable, half *accumulation, | |||
| const half *learning_rate, const half *gradient, | |||
| const half *momentum, cudaStream_t cuda_stream); | |||
| template void MomentumUpdateVariable<half, float>(const size_t size, half *variable, half *accumulation, | |||
| template void MomentumUpdateVariable<half, float, half>(const size_t size, half *variable, half *accumulation, | |||
| const float *learning_rate, const half *gradient, | |||
| const float *momentum, cudaStream_t cuda_stream); | |||
| template void MomentumUpdateVariable<float, float, half>(const size_t size, float *variable, float *accumulation, | |||
| const float *learning_rate, const half *gradient, | |||
| const float *momentum, cudaStream_t cuda_stream); | |||
| @@ -18,8 +18,8 @@ | |||
| #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MOMENTUMIMPL_H_ | |||
| #include "runtime/device/gpu/cuda_common.h" | |||
| template <typename T, typename S> | |||
| void MomentumUpdateVariable(const size_t size, T *variable, T *accumulation, const S *learning_rate, const T *gradient, | |||
| template <typename T, typename S, typename G> | |||
| void MomentumUpdateVariable(const size_t size, T *variable, T *accumulation, const S *learning_rate, const G *gradient, | |||
| const S *momentum, cudaStream_t cuda_stream); | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MOMENTUMIMPL_H_ | |||
| @@ -88,6 +88,12 @@ class GpuKernelRegister { | |||
| static_assert(std::is_base_of<GpuKernel, OPCLASS<T, S>>::value, " must be base of GpuKernel"); \ | |||
| static const GpuKernelRegister g_##OPNAME##_##T##_##S##_gpu_kernel_reg(#OPNAME, ATTR, \ | |||
| []() { return new OPCLASS<T, S>(); }); | |||
| // register of mixed accuracy kernels which use template and maintain three typename | |||
| #define MS_REG_GPU_KERNEL_THREE(OPNAME, ATTR, OPCLASS, T, S, G) \ | |||
| static_assert(std::is_base_of<GpuKernel, OPCLASS<T, S, G>>::value, " must be base of GpuKernel"); \ | |||
| static const GpuKernelRegister g_##OPNAME##_##T##_##S##_##G##_gpu_kernel_reg( \ | |||
| #OPNAME, ATTR, []() { return new OPCLASS<T, S, G>(); }); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_GPUKERNELFACTORY_H_ | |||
| @@ -34,15 +34,15 @@ MS_REG_GPU_KERNEL_ONE(FusedBatchNorm, | |||
| MS_REG_GPU_KERNEL_ONE(FusedBatchNorm, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddOutputAttr(kNumberTypeFloat16) | |||
| .AddOutputAttr(kNumberTypeFloat16) | |||
| .AddOutputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat16) | |||
| .AddOutputAttr(kNumberTypeFloat16), | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| FusedBatchNormGpuKernel, half) | |||
| MS_REG_GPU_KERNEL_ONE(BatchNorm, | |||
| KernelAttr() | |||
| @@ -60,15 +60,15 @@ MS_REG_GPU_KERNEL_ONE(BatchNorm, | |||
| MS_REG_GPU_KERNEL_ONE(BatchNorm, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddOutputAttr(kNumberTypeFloat16) | |||
| .AddOutputAttr(kNumberTypeFloat16) | |||
| .AddOutputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat16) | |||
| .AddOutputAttr(kNumberTypeFloat16), | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| FusedBatchNormGpuKernel, half) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -56,17 +56,17 @@ class FusedBatchNormGpuKernel : public GpuKernel { | |||
| return true; | |||
| } | |||
| auto x = GetDeviceAddress<T>(inputs, 0); | |||
| auto scale = GetDeviceAddress<T>(inputs, 1); | |||
| auto bias = GetDeviceAddress<T>(inputs, 2); | |||
| auto runing_mean = GetDeviceAddress<T>(inputs, 3); | |||
| auto runnig_variance = GetDeviceAddress<T>(inputs, 4); | |||
| auto scale = GetDeviceAddress<float>(inputs, 1); | |||
| auto bias = GetDeviceAddress<float>(inputs, 2); | |||
| auto runing_mean = GetDeviceAddress<float>(inputs, 3); | |||
| auto runnig_variance = GetDeviceAddress<float>(inputs, 4); | |||
| auto y = GetDeviceAddress<T>(outputs, 0); | |||
| const float alpha = 1; | |||
| const float beta = 0; | |||
| if (is_train_) { | |||
| auto save_mean = GetDeviceAddress<T>(outputs, 3); | |||
| auto save_variance = GetDeviceAddress<T>(outputs, 4); | |||
| auto save_mean = GetDeviceAddress<float>(outputs, 3); | |||
| auto save_variance = GetDeviceAddress<float>(outputs, 4); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnBatchNormalizationForwardTraining(handle_, mode_, &alpha, &beta, x_desc_, x, y_desc_, y, | |||
| scale_bias_mean_var_desc_, scale, bias, exp_avg_factor_, runing_mean, | |||
| @@ -33,12 +33,12 @@ MS_REG_GPU_KERNEL_ONE(FusedBatchNormGrad, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddOutputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat16) | |||
| .AddOutputAttr(kNumberTypeFloat16), | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| FusedBatchNormGradGpuKernel, half) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -55,12 +55,12 @@ class FusedBatchNormGradGpuKernel : public GpuKernel { | |||
| } | |||
| auto dy = GetDeviceAddress<T>(inputs, 0); | |||
| auto x = GetDeviceAddress<T>(inputs, 1); | |||
| auto scale = GetDeviceAddress<T>(inputs, 2); | |||
| auto save_mean = GetDeviceAddress<T>(inputs, 3); | |||
| auto save_variance = GetDeviceAddress<T>(inputs, 4); | |||
| auto scale = GetDeviceAddress<float>(inputs, 2); | |||
| auto save_mean = GetDeviceAddress<float>(inputs, 3); | |||
| auto save_variance = GetDeviceAddress<float>(inputs, 4); | |||
| auto dx = GetDeviceAddress<T>(outputs, 0); | |||
| auto bn_scale = GetDeviceAddress<T>(outputs, 1); | |||
| auto bn_bias = GetDeviceAddress<T>(outputs, 2); | |||
| auto bn_scale = GetDeviceAddress<float>(outputs, 1); | |||
| auto bn_bias = GetDeviceAddress<float>(outputs, 2); | |||
| const float alpha_data_diff = 1; | |||
| const float beta_data_diff = 0; | |||
| @@ -18,32 +18,41 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_TWO(ApplyMomentum, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| MomentumGpuKernel, float, float) | |||
| MS_REG_GPU_KERNEL_TWO(ApplyMomentum, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddOutputAttr(kNumberTypeFloat16), | |||
| MomentumGpuKernel, half, half) | |||
| MS_REG_GPU_KERNEL_TWO(ApplyMomentum, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat16), | |||
| MomentumGpuKernel, half, float) | |||
| MS_REG_GPU_KERNEL_THREE(ApplyMomentum, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| MomentumGpuKernel, float, float, float) | |||
| MS_REG_GPU_KERNEL_THREE(ApplyMomentum, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddOutputAttr(kNumberTypeFloat16), | |||
| MomentumGpuKernel, half, half, half) | |||
| MS_REG_GPU_KERNEL_THREE(ApplyMomentum, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat16), | |||
| MomentumGpuKernel, half, float, half) | |||
| MS_REG_GPU_KERNEL_THREE(ApplyMomentum, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| MomentumGpuKernel, float, float, half) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -23,7 +23,7 @@ | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T, typename S> | |||
| template <typename T, typename S, typename G> | |||
| class MomentumGpuKernel : public GpuKernel { | |||
| public: | |||
| MomentumGpuKernel() | |||
| @@ -38,7 +38,7 @@ class MomentumGpuKernel : public GpuKernel { | |||
| T *variable = GetDeviceAddress<T>(inputs, 0); | |||
| T *accumulation = GetDeviceAddress<T>(inputs, 1); | |||
| S *learning_rate = GetDeviceAddress<S>(inputs, 2); | |||
| T *gradient = GetDeviceAddress<T>(inputs, 3); | |||
| G *gradient = GetDeviceAddress<G>(inputs, 3); | |||
| S *momentum = GetDeviceAddress<S>(inputs, 4); | |||
| MomentumUpdateVariable(inputs[0]->size / sizeof(T), variable, accumulation, learning_rate, gradient, momentum, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| @@ -54,7 +54,7 @@ class MomentumGpuKernel : public GpuKernel { | |||
| variable_size_ = sizeof(T); | |||
| accumulation_size_ = sizeof(T); | |||
| learning_rate_size_ = sizeof(S); | |||
| gradient_size_ = sizeof(T); | |||
| gradient_size_ = sizeof(G); | |||
| momentum_size_ = sizeof(S); | |||
| auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| @@ -0,0 +1,63 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/optimizer/gpu/replace_momentum_cast_fusion.h" | |||
| #include <memory> | |||
| #include <vector> | |||
| #include <string> | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "ir/primitive.h" | |||
| #include "utils/utils.h" | |||
| #include "backend/optimizer/common/helper.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| const BaseRef ReplaceMomentumCastFusion::DefinePattern() const { | |||
| VectorRef grad_cast = VectorRef({prim::kPrimCast, grad_}); | |||
| VectorRef momentum = VectorRef({prim::kPrimApplyMomentum, var_, acc_, lr_, grad_cast, mom_}); | |||
| return momentum; | |||
| } | |||
| const AnfNodePtr ReplaceMomentumCastFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, | |||
| const EquivPtr &equiv) const { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| MS_EXCEPTION_IF_NULL(equiv); | |||
| auto grad_cast = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 3); | |||
| auto grad = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(grad_cast), 0); | |||
| MS_EXCEPTION_IF_NULL(grad_cast); | |||
| MS_EXCEPTION_IF_NULL(grad); | |||
| auto manager = graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| manager->Replace(utils::cast<CNodePtr>(grad_cast), utils::cast<CNodePtr>(grad)); | |||
| std::vector<TypeId> outputs_type; | |||
| std::vector<std::vector<size_t>> outputs_shape; | |||
| auto output_num = AnfAlgo::GetOutputTensorNum(node); | |||
| for (size_t i = 0; i < output_num; i++) { | |||
| outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, i)); | |||
| outputs_shape.push_back(AnfAlgo::GetOutputInferShape(node, i)); | |||
| } | |||
| outputs_type[3] = AnfAlgo::GetPrevNodeOutputInferDataType(grad_cast, 0); | |||
| AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, node.get()); | |||
| return node; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,46 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REPLACE_MOMENTUM_CAST_FUSION_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REPLACE_MOMENTUM_CAST_FUSION_H_ | |||
| #include <memory> | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class ReplaceMomentumCastFusion : public PatternProcessPass { | |||
| public: | |||
| explicit ReplaceMomentumCastFusion(bool multigraph = true) : PatternProcessPass("replace_momentum_cast", multigraph) { | |||
| var_ = std::make_shared<Var>(); | |||
| acc_ = std::make_shared<Var>(); | |||
| lr_ = std::make_shared<Var>(); | |||
| grad_ = std::make_shared<Var>(); | |||
| mom_ = std::make_shared<Var>(); | |||
| } | |||
| ~ReplaceMomentumCastFusion() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| private: | |||
| VarPtr var_; | |||
| VarPtr acc_; | |||
| VarPtr lr_; | |||
| VarPtr grad_; | |||
| VarPtr mom_; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REPLACE_MOMENTUM_CAST_FUSION_H_ | |||
| @@ -25,6 +25,11 @@ | |||
| #include "backend/optimizer/pass/getitem_tuple.h" | |||
| #include "backend/optimizer/gpu/adam_weight_decay_fusion.h" | |||
| #include "backend/optimizer/gpu/adam_fusion.h" | |||
| #include "backend/optimizer/gpu/replace_bn_cast_fusion.h" | |||
| #include "backend/optimizer/gpu/replace_bn_grad_cast_fusion.h" | |||
| #include "backend/optimizer/gpu/replace_bn_grad_cast2_fusion.h" | |||
| #include "backend/optimizer/gpu/replace_momentum_cast_fusion.h" | |||
| #include "backend/optimizer/gpu/replace_addn_fusion.h" | |||
| #include "runtime/device/kernel_runtime_manager.h" | |||
| #include "predict/predict.h" | |||
| #include "common/utils.h" | |||
| @@ -59,6 +64,11 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) { | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| pm->AddPass(std::make_shared<opt::AdamWeightDecayFusion>()); | |||
| pm->AddPass(std::make_shared<opt::AdamFusion>()); | |||
| pm->AddPass(std::make_shared<opt::ReplaceBNCastFusion>()); | |||
| pm->AddPass(std::make_shared<opt::ReplaceBNGradCastFusion>()); | |||
| pm->AddPass(std::make_shared<opt::ReplaceBNGradCast2Fusion>()); | |||
| pm->AddPass(std::make_shared<opt::ReplaceMomentumCastFusion>()); | |||
| pm->AddPass(std::make_shared<opt::ReplaceAddNFusion>()); | |||
| optimizer->AddPassManager(pm); | |||
| (void)optimizer->Optimize(kernel_graph); | |||
| kernel_graph->SetExecOrderByDefault(); | |||