|
|
|
@@ -26,7 +26,7 @@ class ApplyMomentumWeightDecayScaleFusion : public PatternProcessPass { |
|
|
|
explicit ApplyMomentumWeightDecayScaleFusion(bool multigraph = true) |
|
|
|
: PatternProcessPass("momentum_weightdecay_scale_fusion", multigraph) { |
|
|
|
weight_decay_ = std::make_shared<Var>(); |
|
|
|
scale_ = std::make_shared<Var>(); |
|
|
|
scale_ = std::make_shared<CondVar>(IsScalar); |
|
|
|
variable_ = std::make_shared<Var>(); |
|
|
|
accumulation_ = std::make_shared<Var>(); |
|
|
|
learning_rate_ = std::make_shared<Var>(); |
|
|
|
@@ -38,9 +38,10 @@ class ApplyMomentumWeightDecayScaleFusion : public PatternProcessPass { |
|
|
|
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; |
|
|
|
|
|
|
|
private: |
|
|
|
static bool IsScalar(const BaseRef &n); |
|
|
|
|
|
|
|
VarPtr weight_decay_; |
|
|
|
VarPtr scale_; |
|
|
|
|
|
|
|
VarPtr variable_; |
|
|
|
VarPtr accumulation_; |
|
|
|
VarPtr learning_rate_; |
|
|
|
|