Browse Source

fix momentum fusion pass

tags/v1.1.0
wilfChen 5 years ago
parent
commit
cbdd658e24
3 changed files with 72 additions and 3 deletions
  1. +24
    -1
      mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_scale_fusion.cc
  2. +23
    -1
      mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_scale_fusion.h
  3. +25
    -1
      mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_weight_scale_fusion.h

+ 24
- 1
mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_scale_fusion.cc View File

@@ -26,6 +26,29 @@

namespace mindspore {
namespace opt {
namespace {
bool ApplyMomentumScaleFusion::IsScalar(const BaseRef &n) {
if (utils::isa<AnfNodePtr>(n)) {
AnfNodePtr in = utils::cast<AnfNodePtr>(n);
MS_EXCEPTION_IF_NULL(in);
auto shape = in->Shape()->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(shape);
if (shape->shape().size() != 0) {
return false;
}
auto dtype = in->Type();
if (dtype->type_id() != kObjectTypeTensorType) {
return false;
}
auto element_type = dyn_cast<TensorType>(dtype)->element()->type_id();
if (element_type != kNumberTypeFloat32) {
return false;
}
return true;
}
return false;
}

const BaseRef ApplyMomentumScaleFusion::DefinePattern() const {
VectorRef scale = VectorRef({prim::kPrimMul, gradient_, scale_});
VectorRef apply_momentum =
@@ -63,5 +86,5 @@ const AnfNodePtr ApplyMomentumScaleFusion::Process(const FuncGraphPtr &graph, co
replace_node->set_scope(node->scope());
return replace_node;
}
} // namespace
} // namespace opt
} // namespace mindspore

+ 23
- 1
mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_scale_fusion.h View File

@@ -18,13 +18,14 @@

#include <memory>
#include "backend/optimizer/common/optimizer.h"
#include "backend/session/anf_runtime_algorithm.h"

namespace mindspore {
namespace opt {
class ApplyMomentumScaleFusion : public PatternProcessPass {
public:
explicit ApplyMomentumScaleFusion(bool multigraph = true) : PatternProcessPass("momentum_scale_fusion", multigraph) {
scale_ = std::make_shared<Var>();
scale_ = std::make_shared<CondVar>(IsScalar);
variable_ = std::make_shared<Var>();
accumulation_ = std::make_shared<Var>();
learning_rate_ = std::make_shared<Var>();
@@ -36,6 +37,27 @@ class ApplyMomentumScaleFusion : public PatternProcessPass {
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

private:
static bool IsScalar(const BaseRef &n) {
if (utils::isa<AnfNodePtr>(n)) {
AnfNodePtr in = utils::cast<AnfNodePtr>(n);
MS_EXCEPTION_IF_NULL(in);
auto shape = in->Shape()->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(shape);
if (shape->shape().size() != 0) {
return false;
}
auto dtype = in->Type();
if (dtype->type_id() != kObjectTypeTensorType) {
return false;
}
auto element_type = dyn_cast<TensorType>(dtype)->element()->type_id();
if (element_type != kNumberTypeFloat32) {
return false;
}
return true;
}
return false;
}
VarPtr scale_;
VarPtr variable_;
VarPtr accumulation_;


+ 25
- 1
mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_weight_scale_fusion.h View File

@@ -21,12 +21,36 @@

namespace mindspore {
namespace opt {
namespace {
bool IsScalar(const BaseRef &n) {
if (utils::isa<AnfNodePtr>(n)) {
AnfNodePtr in = utils::cast<AnfNodePtr>(n);
MS_EXCEPTION_IF_NULL(in);
auto shape = in->Shape()->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(shape);
if (shape->shape().size() != 0) {
return false;
}
auto dtype = in->Type();
if (dtype->type_id() != kObjectTypeTensorType) {
return false;
}
auto element_type = dyn_cast<TensorType>(dtype)->element()->type_id();
if (element_type != kNumberTypeFloat32) {
return false;
}
return true;
}
return false;
}
} // namespace

class ApplyMomentumWeightDecayScaleFusion : public PatternProcessPass {
public:
explicit ApplyMomentumWeightDecayScaleFusion(bool multigraph = true)
: PatternProcessPass("momentum_weightdecay_scale_fusion", multigraph) {
weight_decay_ = std::make_shared<Var>();
scale_ = std::make_shared<Var>();
scale_ = std::make_shared<CondVar>(IsScalar);
variable_ = std::make_shared<Var>();
accumulation_ = std::make_shared<Var>();
learning_rate_ = std::make_shared<Var>();


Loading…
Cancel
Save