| @@ -26,6 +26,29 @@ | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| bool ApplyMomentumScaleFusion::IsScalar(const BaseRef &n) { | |||
| if (utils::isa<AnfNodePtr>(n)) { | |||
| AnfNodePtr in = utils::cast<AnfNodePtr>(n); | |||
| MS_EXCEPTION_IF_NULL(in); | |||
| auto shape = in->Shape()->cast<abstract::ShapePtr>(); | |||
| MS_EXCEPTION_IF_NULL(shape); | |||
| if (shape->shape().size() != 0) { | |||
| return false; | |||
| } | |||
| auto dtype = in->Type(); | |||
| if (dtype->type_id() != kObjectTypeTensorType) { | |||
| return false; | |||
| } | |||
| auto element_type = dyn_cast<TensorType>(dtype)->element()->type_id(); | |||
| if (element_type != kNumberTypeFloat32) { | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| const BaseRef ApplyMomentumScaleFusion::DefinePattern() const { | |||
| VectorRef scale = VectorRef({prim::kPrimMul, gradient_, scale_}); | |||
| VectorRef apply_momentum = | |||
| @@ -63,5 +86,5 @@ const AnfNodePtr ApplyMomentumScaleFusion::Process(const FuncGraphPtr &graph, co | |||
| replace_node->set_scope(node->scope()); | |||
| return replace_node; | |||
| } | |||
| } // namespace | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -18,13 +18,14 @@ | |||
| #include <memory> | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class ApplyMomentumScaleFusion : public PatternProcessPass { | |||
| public: | |||
| explicit ApplyMomentumScaleFusion(bool multigraph = true) : PatternProcessPass("momentum_scale_fusion", multigraph) { | |||
| scale_ = std::make_shared<Var>(); | |||
| scale_ = std::make_shared<CondVar>(IsScalar); | |||
| variable_ = std::make_shared<Var>(); | |||
| accumulation_ = std::make_shared<Var>(); | |||
| learning_rate_ = std::make_shared<Var>(); | |||
| @@ -36,6 +37,27 @@ class ApplyMomentumScaleFusion : public PatternProcessPass { | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| private: | |||
| static bool IsScalar(const BaseRef &n) { | |||
| if (utils::isa<AnfNodePtr>(n)) { | |||
| AnfNodePtr in = utils::cast<AnfNodePtr>(n); | |||
| MS_EXCEPTION_IF_NULL(in); | |||
| auto shape = in->Shape()->cast<abstract::ShapePtr>(); | |||
| MS_EXCEPTION_IF_NULL(shape); | |||
| if (shape->shape().size() != 0) { | |||
| return false; | |||
| } | |||
| auto dtype = in->Type(); | |||
| if (dtype->type_id() != kObjectTypeTensorType) { | |||
| return false; | |||
| } | |||
| auto element_type = dyn_cast<TensorType>(dtype)->element()->type_id(); | |||
| if (element_type != kNumberTypeFloat32) { | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| VarPtr scale_; | |||
| VarPtr variable_; | |||
| VarPtr accumulation_; | |||
| @@ -21,12 +21,36 @@ | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| bool IsScalar(const BaseRef &n) { | |||
| if (utils::isa<AnfNodePtr>(n)) { | |||
| AnfNodePtr in = utils::cast<AnfNodePtr>(n); | |||
| MS_EXCEPTION_IF_NULL(in); | |||
| auto shape = in->Shape()->cast<abstract::ShapePtr>(); | |||
| MS_EXCEPTION_IF_NULL(shape); | |||
| if (shape->shape().size() != 0) { | |||
| return false; | |||
| } | |||
| auto dtype = in->Type(); | |||
| if (dtype->type_id() != kObjectTypeTensorType) { | |||
| return false; | |||
| } | |||
| auto element_type = dyn_cast<TensorType>(dtype)->element()->type_id(); | |||
| if (element_type != kNumberTypeFloat32) { | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| } // namespace | |||
| class ApplyMomentumWeightDecayScaleFusion : public PatternProcessPass { | |||
| public: | |||
| explicit ApplyMomentumWeightDecayScaleFusion(bool multigraph = true) | |||
| : PatternProcessPass("momentum_weightdecay_scale_fusion", multigraph) { | |||
| weight_decay_ = std::make_shared<Var>(); | |||
| scale_ = std::make_shared<Var>(); | |||
| scale_ = std::make_shared<CondVar>(IsScalar); | |||
| variable_ = std::make_shared<Var>(); | |||
| accumulation_ = std::make_shared<Var>(); | |||
| learning_rate_ = std::make_shared<Var>(); | |||