From cbdd658e24f1fac3bb9dd2e36911d346cd495a96 Mon Sep 17 00:00:00 2001
From: wilfChen
Date: Wed, 28 Oct 2020 20:48:45 +0800
Subject: [PATCH] fix momentum fusion pass

---
 .../gpu/apply_momentum_scale_fusion.cc        | 25 +++++++++++++++++-
 .../gpu/apply_momentum_scale_fusion.h         | 24 ++++++++++++++++-
 .../gpu/apply_momentum_weight_scale_fusion.h  | 26 ++++++++++++++++++-
 3 files changed, 72 insertions(+), 3 deletions(-)

diff --git a/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_scale_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_scale_fusion.cc
index 6ce2d3a72a..c21f107805 100644
--- a/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_scale_fusion.cc
+++ b/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_scale_fusion.cc
@@ -26,6 +26,29 @@
 namespace mindspore {
 namespace opt {
+namespace {
+bool ApplyMomentumScaleFusion::IsScalar(const BaseRef &n) {
+  if (utils::isa<AnfNodePtr>(n)) {
+    AnfNodePtr in = utils::cast<AnfNodePtr>(n);
+    MS_EXCEPTION_IF_NULL(in);
+    auto shape = in->Shape()->cast<abstract::ShapePtr>();
+    MS_EXCEPTION_IF_NULL(shape);
+    if (shape->shape().size() != 0) {
+      return false;
+    }
+    auto dtype = in->Type();
+    if (dtype->type_id() != kObjectTypeTensorType) {
+      return false;
+    }
+    auto element_type = dyn_cast<TensorType>(dtype)->element()->type_id();
+    if (element_type != kNumberTypeFloat32) {
+      return false;
+    }
+    return true;
+  }
+  return false;
+}
+
 const BaseRef ApplyMomentumScaleFusion::DefinePattern() const {
   VectorRef scale = VectorRef({prim::kPrimMul, gradient_, scale_});
   VectorRef apply_momentum =
@@ -63,5 +86,5 @@ const AnfNodePtr ApplyMomentumScaleFusion::Process(const FuncGraphPtr &graph, co
   replace_node->set_scope(node->scope());
   return replace_node;
 }
+}  // namespace
 }  // namespace opt
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_scale_fusion.h b/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_scale_fusion.h
index c9112ab6e9..349c6be338 100644
--- a/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_scale_fusion.h
+++
b/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_scale_fusion.h
@@ -18,13 +18,14 @@
 #include <memory>
 #include "backend/optimizer/common/optimizer.h"
+#include "backend/session/anf_runtime_algorithm.h"
 namespace mindspore {
 namespace opt {
 class ApplyMomentumScaleFusion : public PatternProcessPass {
  public:
  explicit ApplyMomentumScaleFusion(bool multigraph = true) : PatternProcessPass("momentum_scale_fusion", multigraph) {
-    scale_ = std::make_shared<Var>();
+    scale_ = std::make_shared<CondVar>(IsScalar);
     variable_ = std::make_shared<Var>();
     accumulation_ = std::make_shared<Var>();
     learning_rate_ = std::make_shared<Var>();
@@ -36,6 +37,27 @@ class ApplyMomentumScaleFusion : public PatternProcessPass {
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

  private:
+  static bool IsScalar(const BaseRef &n) {
+    if (utils::isa<AnfNodePtr>(n)) {
+      AnfNodePtr in = utils::cast<AnfNodePtr>(n);
+      MS_EXCEPTION_IF_NULL(in);
+      auto shape = in->Shape()->cast<abstract::ShapePtr>();
+      MS_EXCEPTION_IF_NULL(shape);
+      if (shape->shape().size() != 0) {
+        return false;
+      }
+      auto dtype = in->Type();
+      if (dtype->type_id() != kObjectTypeTensorType) {
+        return false;
+      }
+      auto element_type = dyn_cast<TensorType>(dtype)->element()->type_id();
+      if (element_type != kNumberTypeFloat32) {
+        return false;
+      }
+      return true;
+    }
+    return false;
+  }
   VarPtr scale_;
   VarPtr variable_;
   VarPtr accumulation_;
diff --git a/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_weight_scale_fusion.h b/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_weight_scale_fusion.h
index f047881d81..d8fb16e7b9 100644
--- a/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_weight_scale_fusion.h
+++ b/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_weight_scale_fusion.h
@@ -21,12 +21,36 @@
 namespace mindspore {
 namespace opt {
+namespace {
+bool IsScalar(const BaseRef &n) {
+  if (utils::isa<AnfNodePtr>(n)) {
+    AnfNodePtr in = utils::cast<AnfNodePtr>(n);
+    MS_EXCEPTION_IF_NULL(in);
+    auto shape = in->Shape()->cast<abstract::ShapePtr>();
+    MS_EXCEPTION_IF_NULL(shape);
+    if 
(shape->shape().size() != 0) {
+      return false;
+    }
+    auto dtype = in->Type();
+    if (dtype->type_id() != kObjectTypeTensorType) {
+      return false;
+    }
+    auto element_type = dyn_cast<TensorType>(dtype)->element()->type_id();
+    if (element_type != kNumberTypeFloat32) {
+      return false;
+    }
+    return true;
+  }
+  return false;
+}
+}  // namespace
+
 class ApplyMomentumWeightDecayScaleFusion : public PatternProcessPass {
  public:
  explicit ApplyMomentumWeightDecayScaleFusion(bool multigraph = true)
      : PatternProcessPass("momentum_weightdecay_scale_fusion", multigraph) {
     weight_decay_ = std::make_shared<Var>();
-    scale_ = std::make_shared<Var>();
+    scale_ = std::make_shared<CondVar>(IsScalar);
     variable_ = std::make_shared<Var>();
     accumulation_ = std::make_shared<Var>();
     learning_rate_ = std::make_shared<Var>();