Browse Source

!7925 fix gpu momentum fusion

Merge pull request !7925 from chenweifeng/momentum-fusion-fix
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
ea10c7a146
4 changed files with 51 additions and 3 deletions
  1. +22
    -0
      mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_scale_fusion.cc
  2. +4
    -1
      mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_scale_fusion.h
  3. +22
    -0
      mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_weight_scale_fusion.cc
  4. +3
    -2
      mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_weight_scale_fusion.h

+ 22
- 0
mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_scale_fusion.cc View File

@@ -26,6 +26,28 @@

namespace mindspore {
namespace opt {
bool ApplyMomentumScaleFusion::IsScalar(const BaseRef &n) {
if (utils::isa<AnfNodePtr>(n)) {
AnfNodePtr in = utils::cast<AnfNodePtr>(n);
MS_EXCEPTION_IF_NULL(in);
auto shape = in->Shape()->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(shape);
if (shape->shape().size() != 0) {
return false;
}
auto dtype = in->Type();
if (dtype->type_id() != kObjectTypeTensorType) {
return false;
}
auto element_type = dyn_cast<TensorType>(dtype)->element()->type_id();
if (element_type != kNumberTypeFloat32) {
return false;
}
return true;
}
return false;
}

const BaseRef ApplyMomentumScaleFusion::DefinePattern() const {
VectorRef scale = VectorRef({prim::kPrimMul, gradient_, scale_});
VectorRef apply_momentum =


+ 4
- 1
mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_scale_fusion.h View File

@@ -18,13 +18,14 @@

#include <memory>
#include "backend/optimizer/common/optimizer.h"
#include "backend/session/anf_runtime_algorithm.h"

namespace mindspore {
namespace opt {
class ApplyMomentumScaleFusion : public PatternProcessPass {
public:
explicit ApplyMomentumScaleFusion(bool multigraph = true) : PatternProcessPass("momentum_scale_fusion", multigraph) {
scale_ = std::make_shared<Var>();
scale_ = std::make_shared<CondVar>(IsScalar);
variable_ = std::make_shared<Var>();
accumulation_ = std::make_shared<Var>();
learning_rate_ = std::make_shared<Var>();
@@ -36,6 +37,8 @@ class ApplyMomentumScaleFusion : public PatternProcessPass {
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

private:
static bool IsScalar(const BaseRef &n);

VarPtr scale_;
VarPtr variable_;
VarPtr accumulation_;


+ 22
- 0
mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_weight_scale_fusion.cc View File

@@ -26,6 +26,28 @@

namespace mindspore {
namespace opt {
bool ApplyMomentumWeightDecayScaleFusion::IsScalar(const BaseRef &n) {
if (utils::isa<AnfNodePtr>(n)) {
AnfNodePtr in = utils::cast<AnfNodePtr>(n);
MS_EXCEPTION_IF_NULL(in);
auto shape = in->Shape()->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(shape);
if (shape->shape().size() != 0) {
return false;
}
auto dtype = in->Type();
if (dtype->type_id() != kObjectTypeTensorType) {
return false;
}
auto element_type = dyn_cast<TensorType>(dtype)->element()->type_id();
if (element_type != kNumberTypeFloat32) {
return false;
}
return true;
}
return false;
}

const BaseRef ApplyMomentumWeightDecayScaleFusion::DefinePattern() const {
VectorRef weight = VectorRef(
{prim::kPrimAddN, VectorRef({prim::kPrimMul, variable_, weight_decay_}), VectorRef({prim::kPrimCast, gradient_})});


+ 3
- 2
mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_weight_scale_fusion.h View File

@@ -26,7 +26,7 @@ class ApplyMomentumWeightDecayScaleFusion : public PatternProcessPass {
explicit ApplyMomentumWeightDecayScaleFusion(bool multigraph = true)
: PatternProcessPass("momentum_weightdecay_scale_fusion", multigraph) {
weight_decay_ = std::make_shared<Var>();
scale_ = std::make_shared<Var>();
scale_ = std::make_shared<CondVar>(IsScalar);
variable_ = std::make_shared<Var>();
accumulation_ = std::make_shared<Var>();
learning_rate_ = std::make_shared<Var>();
@@ -38,9 +38,10 @@ class ApplyMomentumWeightDecayScaleFusion : public PatternProcessPass {
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

private:
static bool IsScalar(const BaseRef &n);

VarPtr weight_decay_;
VarPtr scale_;

VarPtr variable_;
VarPtr accumulation_;
VarPtr learning_rate_;


Loading…
Cancel
Save