Browse Source

update apply_momentum_weight_scale_fusion pass

tags/v1.3.0
huangbingjian 5 years ago
parent
commit
9424aedec3
2 changed files with 37 additions and 6 deletions
  1. +34
    -4
      mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_weight_scale_fusion.cc
  2. +3
    -2
      mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_weight_scale_fusion.h

+ 34
- 4
mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_weight_scale_fusion.cc View File

@@ -26,6 +26,8 @@

namespace mindspore {
namespace opt {
constexpr size_t kInputIndex = 1;

bool ApplyMomentumWeightDecayScaleFusion::IsScalar(const BaseRef &n) {
if (utils::isa<AnfNodePtr>(n)) {
AnfNodePtr in = utils::cast<AnfNodePtr>(n);
@@ -48,10 +50,36 @@ bool ApplyMomentumWeightDecayScaleFusion::IsScalar(const BaseRef &n) {
return false;
}

bool ApplyMomentumWeightDecayScaleFusion::IsCast(const BaseRef &n) {
if (utils::isa<AnfNodePtr>(n)) {
AnfNodePtr in = utils::cast<AnfNodePtr>(n);
MS_EXCEPTION_IF_NULL(in);
if (IsPrimitiveCNode(in, prim::kPrimCast) ||
(IsPrimitiveCNode(in, prim::kPrimDepend) &&
IsPrimitiveCNode(in->cast<CNodePtr>()->input(kInputIndex), prim::kPrimCast))) {
return true;
}
}
return false;
}

AnfNodePtr GetCastInput(const AnfNodePtr &node) {
if (IsPrimitiveCNode(node, prim::kPrimCast)) {
return node->cast<CNodePtr>()->input(kInputIndex);
}
if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
auto cast_node = node->cast<CNodePtr>()->input(kInputIndex);
if (IsPrimitiveCNode(cast_node, prim::kPrimCast)) {
return cast_node->cast<CNodePtr>()->input(kInputIndex);
}
}
return nullptr;
}

const BaseRef ApplyMomentumWeightDecayScaleFusion::DefinePattern() const {
VectorRef load_para = VectorRef({prim::kPrimLoad, variable_, monad_});
VectorRef weight = VectorRef(
{prim::kPrimAddN, VectorRef({prim::kPrimMul, load_para, weight_decay_}), VectorRef({prim::kPrimCast, gradient_})});
VectorRef weight =
VectorRef({prim::kPrimAddN, VectorRef({prim::kPrimMul, load_para, weight_decay_}), cast_gradient_});
VectorRef scale = VectorRef({prim::kPrimMul, weight, scale_});
VectorRef apply_momentum =
VectorRef({prim::kPrimApplyMomentum, variable_, accumulation_, learning_rate_, scale, momentum_, monad_state_});
@@ -68,7 +96,7 @@ const AnfNodePtr ApplyMomentumWeightDecayScaleFusion::Process(const FuncGraphPtr
auto variable = utils::cast<AnfNodePtr>((*equiv)[variable_]);
auto accumulation = utils::cast<AnfNodePtr>((*equiv)[accumulation_]);
auto learning_rate = utils::cast<AnfNodePtr>((*equiv)[learning_rate_]);
auto gradient = utils::cast<AnfNodePtr>((*equiv)[gradient_]);
auto cast_gradient = utils::cast<AnfNodePtr>((*equiv)[cast_gradient_]);
auto momentum = utils::cast<AnfNodePtr>((*equiv)[momentum_]);
auto monad_state = utils::cast<AnfNodePtr>((*equiv)[monad_state_]);

@@ -77,12 +105,14 @@ const AnfNodePtr ApplyMomentumWeightDecayScaleFusion::Process(const FuncGraphPtr
MS_EXCEPTION_IF_NULL(variable);
MS_EXCEPTION_IF_NULL(accumulation);
MS_EXCEPTION_IF_NULL(learning_rate);
MS_EXCEPTION_IF_NULL(gradient);
MS_EXCEPTION_IF_NULL(cast_gradient);
MS_EXCEPTION_IF_NULL(momentum);
MS_EXCEPTION_IF_NULL(monad_state);

auto prim = std::make_shared<Primitive>(kFusedWeightScaleApplyMomentum);
MS_EXCEPTION_IF_NULL(prim);
auto gradient = GetCastInput(cast_gradient);
MS_EXCEPTION_IF_NULL(gradient);
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), weight_decay, scale, variable, accumulation,
learning_rate, gradient, momentum, monad_state};
auto replace_node = graph->NewCNode(inputs);


+ 3
- 2
mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_weight_scale_fusion.h View File

@@ -31,7 +31,7 @@ class ApplyMomentumWeightDecayScaleFusion : public PatternProcessPass {
variable_ = std::make_shared<Var>();
accumulation_ = std::make_shared<Var>();
learning_rate_ = std::make_shared<Var>();
gradient_ = std::make_shared<Var>();
cast_gradient_ = std::make_shared<CondVar>(IsCast);
momentum_ = std::make_shared<Var>();
monad_state_ = std::make_shared<Var>();
}
@@ -41,6 +41,7 @@ class ApplyMomentumWeightDecayScaleFusion : public PatternProcessPass {

private:
static bool IsScalar(const BaseRef &n);
static bool IsCast(const BaseRef &n);

VarPtr monad_;
VarPtr weight_decay_;
@@ -48,7 +49,7 @@ class ApplyMomentumWeightDecayScaleFusion : public PatternProcessPass {
VarPtr variable_;
VarPtr accumulation_;
VarPtr learning_rate_;
VarPtr gradient_;
VarPtr cast_gradient_;
VarPtr momentum_;
VarPtr monad_state_;
};


Loading…
Cancel
Save