| @@ -26,6 +26,29 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| namespace { | |||||
| bool ApplyMomentumScaleFusion::IsScalar(const BaseRef &n) { | |||||
| if (utils::isa<AnfNodePtr>(n)) { | |||||
| AnfNodePtr in = utils::cast<AnfNodePtr>(n); | |||||
| MS_EXCEPTION_IF_NULL(in); | |||||
| auto shape = in->Shape()->cast<abstract::ShapePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(shape); | |||||
| if (shape->shape().size() != 0) { | |||||
| return false; | |||||
| } | |||||
| auto dtype = in->Type(); | |||||
| if (dtype->type_id() != kObjectTypeTensorType) { | |||||
| return false; | |||||
| } | |||||
| auto element_type = dyn_cast<TensorType>(dtype)->element()->type_id(); | |||||
| if (element_type != kNumberTypeFloat32) { | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| const BaseRef ApplyMomentumScaleFusion::DefinePattern() const { | const BaseRef ApplyMomentumScaleFusion::DefinePattern() const { | ||||
| VectorRef scale = VectorRef({prim::kPrimMul, gradient_, scale_}); | VectorRef scale = VectorRef({prim::kPrimMul, gradient_, scale_}); | ||||
| VectorRef apply_momentum = | VectorRef apply_momentum = | ||||
| @@ -63,5 +86,5 @@ const AnfNodePtr ApplyMomentumScaleFusion::Process(const FuncGraphPtr &graph, co | |||||
| replace_node->set_scope(node->scope()); | replace_node->set_scope(node->scope()); | ||||
| return replace_node; | return replace_node; | ||||
| } | } | ||||
| } // namespace | |||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | |||||
| @@ -18,13 +18,14 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include "backend/optimizer/common/optimizer.h" | #include "backend/optimizer/common/optimizer.h" | ||||
| #include "backend/session/anf_runtime_algorithm.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| class ApplyMomentumScaleFusion : public PatternProcessPass { | class ApplyMomentumScaleFusion : public PatternProcessPass { | ||||
| public: | public: | ||||
| explicit ApplyMomentumScaleFusion(bool multigraph = true) : PatternProcessPass("momentum_scale_fusion", multigraph) { | explicit ApplyMomentumScaleFusion(bool multigraph = true) : PatternProcessPass("momentum_scale_fusion", multigraph) { | ||||
| scale_ = std::make_shared<Var>(); | |||||
| scale_ = std::make_shared<CondVar>(IsScalar); | |||||
| variable_ = std::make_shared<Var>(); | variable_ = std::make_shared<Var>(); | ||||
| accumulation_ = std::make_shared<Var>(); | accumulation_ = std::make_shared<Var>(); | ||||
| learning_rate_ = std::make_shared<Var>(); | learning_rate_ = std::make_shared<Var>(); | ||||
| @@ -36,6 +37,27 @@ class ApplyMomentumScaleFusion : public PatternProcessPass { | |||||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | ||||
| private: | private: | ||||
| static bool IsScalar(const BaseRef &n) { | |||||
| if (utils::isa<AnfNodePtr>(n)) { | |||||
| AnfNodePtr in = utils::cast<AnfNodePtr>(n); | |||||
| MS_EXCEPTION_IF_NULL(in); | |||||
| auto shape = in->Shape()->cast<abstract::ShapePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(shape); | |||||
| if (shape->shape().size() != 0) { | |||||
| return false; | |||||
| } | |||||
| auto dtype = in->Type(); | |||||
| if (dtype->type_id() != kObjectTypeTensorType) { | |||||
| return false; | |||||
| } | |||||
| auto element_type = dyn_cast<TensorType>(dtype)->element()->type_id(); | |||||
| if (element_type != kNumberTypeFloat32) { | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| VarPtr scale_; | VarPtr scale_; | ||||
| VarPtr variable_; | VarPtr variable_; | ||||
| VarPtr accumulation_; | VarPtr accumulation_; | ||||
| @@ -21,12 +21,36 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| namespace { | |||||
| bool IsScalar(const BaseRef &n) { | |||||
| if (utils::isa<AnfNodePtr>(n)) { | |||||
| AnfNodePtr in = utils::cast<AnfNodePtr>(n); | |||||
| MS_EXCEPTION_IF_NULL(in); | |||||
| auto shape = in->Shape()->cast<abstract::ShapePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(shape); | |||||
| if (shape->shape().size() != 0) { | |||||
| return false; | |||||
| } | |||||
| auto dtype = in->Type(); | |||||
| if (dtype->type_id() != kObjectTypeTensorType) { | |||||
| return false; | |||||
| } | |||||
| auto element_type = dyn_cast<TensorType>(dtype)->element()->type_id(); | |||||
| if (element_type != kNumberTypeFloat32) { | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| } // namespace | |||||
| class ApplyMomentumWeightDecayScaleFusion : public PatternProcessPass { | class ApplyMomentumWeightDecayScaleFusion : public PatternProcessPass { | ||||
| public: | public: | ||||
| explicit ApplyMomentumWeightDecayScaleFusion(bool multigraph = true) | explicit ApplyMomentumWeightDecayScaleFusion(bool multigraph = true) | ||||
| : PatternProcessPass("momentum_weightdecay_scale_fusion", multigraph) { | : PatternProcessPass("momentum_weightdecay_scale_fusion", multigraph) { | ||||
| weight_decay_ = std::make_shared<Var>(); | weight_decay_ = std::make_shared<Var>(); | ||||
| scale_ = std::make_shared<Var>(); | |||||
| scale_ = std::make_shared<CondVar>(IsScalar); | |||||
| variable_ = std::make_shared<Var>(); | variable_ = std::make_shared<Var>(); | ||||
| accumulation_ = std::make_shared<Var>(); | accumulation_ = std::make_shared<Var>(); | ||||
| learning_rate_ = std::make_shared<Var>(); | learning_rate_ = std::make_shared<Var>(); | ||||