Browse Source

add switch_defer_inline

pull/15459/head
huangbingjian 4 years ago
parent
commit
e5d32c9ff8
10 changed files with 46 additions and 10 deletions
  1. +3
    -0
      mindspore/ccsrc/debug/anf_ir_utils.cc
  2. +2
    -1
      mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc
  3. +6
    -2
      mindspore/ccsrc/frontend/optimizer/irpass.cc
  4. +4
    -1
      mindspore/ccsrc/frontend/optimizer/irpass.h
  5. +3
    -2
      mindspore/ccsrc/frontend/optimizer/irpass/inline.h
  6. +20
    -3
      mindspore/ccsrc/frontend/optimizer/irpass/switch_or_switch_layer_defer_inline.h
  7. +1
    -0
      mindspore/ccsrc/pipeline/jit/pass.cc
  8. +1
    -0
      mindspore/core/ir/func_graph.cc
  9. +4
    -1
      mindspore/core/ir/func_graph.h
  10. +2
    -0
      mindspore/core/ir/func_graph_cloner.cc

+ 3
- 0
mindspore/ccsrc/debug/anf_ir_utils.cc View File

@@ -618,6 +618,9 @@ void AnfExporter::ExportOneFuncGraph(std::ofstream &ofs, const FuncGraphPtr &fun
std::vector<AnfNodePtr> parameters = func_graph->parameters();
OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual> param_map;

if (*(func_graph->switch_input())) {
ofs << "switch_input: " << *(func_graph->switch_input()) << "\n";
}
if (*(func_graph->switch_layer_input())) {
ofs << "switch_layer_input: " << *(func_graph->switch_layer_input()) << "\n";
}


+ 2
- 1
mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc View File

@@ -45,7 +45,8 @@ DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBas
TraceGuard guard(std::make_shared<TraceGradFprop>(primal_graph->debug_info()));
k_graph_ = std::make_shared<FuncGraph>();
}
// To keep switch_layer's inputs from being inlined
// To keep switch or switch_layer's inputs from being inlined
k_graph_->set_switch_input(primal_graph->switch_input());
k_graph_->set_switch_layer_input(primal_graph->switch_layer_input());
k_graph_->set_stage(primal_graph->stage());



+ 6
- 2
mindspore/ccsrc/frontend/optimizer/irpass.cc View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -47,7 +47,7 @@
#include "frontend/optimizer/opt.h"
#include "frontend/optimizer/irpass/row_tensor_eliminate.h"
#include "frontend/optimizer/irpass/sparse_tensor_eliminate.h"
#include "frontend/optimizer/irpass/switch_layer_defer_inline.h"
#include "frontend/optimizer/irpass/switch_or_switch_layer_defer_inline.h"
#include "frontend/optimizer/irpass/call_graph_tuple_transform.h"

namespace mindspore {
@@ -231,6 +231,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
value_based_eliminate_ = MakeSubstitution(std::make_shared<ValueBasedEliminate>(), "value_based_eliminate",
{prim::kPrimSelect, prim::kPrimMinimum, prim::kPrimMaximum});

// switch defer inline
switch_defer_inline_ =
MakeSubstitution(std::make_shared<SwitchDeferInline>(), "switch_defer_inline", prim::kPrimSwitch);

// switch_layer defer inline
switch_layer_defer_inline_ =
MakeSubstitution(std::make_shared<SwitchLayerDeferInline>(), "switch_layer_defer_inline", prim::kPrimSwitchLayer);


+ 4
- 1
mindspore/ccsrc/frontend/optimizer/irpass.h View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -139,6 +139,9 @@ class OptimizeIRPassLib {
// Value_Based Eliminate
SubstitutionPtr value_based_eliminate_;

// Switch defer inline
SubstitutionPtr switch_defer_inline_;

// SwitchLayer defer inline
SubstitutionPtr switch_layer_defer_inline_;



+ 3
- 2
mindspore/ccsrc/frontend/optimizer/irpass/inline.h View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -41,7 +41,8 @@ class ReplaceApplicator : public AnfVisitor {
}

auto fg = GetValueNode<FuncGraphPtr>(node);
if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stage() != -1 || fg->stub() || *(fg->switch_layer_input())) {
if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stage() != -1 || fg->stub() || *(fg->switch_input()) ||
*(fg->switch_layer_input())) {
return nullptr;
}



mindspore/ccsrc/frontend/optimizer/irpass/switch_layer_defer_inline.h → mindspore/ccsrc/frontend/optimizer/irpass/switch_or_switch_layer_defer_inline.h View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -28,12 +28,29 @@
namespace mindspore {
namespace opt {
namespace irpass {
// {prim::kPrimSwitchLayer, {Index, layers}}
// {prim::kPrimSwitch, cond, true_branch, false_branch}
class SwitchDeferInline : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
auto cnode = node->cast<CNodePtr>();
auto true_abstract = dyn_cast<abstract::FuncGraphAbstractClosure>(cnode->input(2)->abstract());
if (true_abstract != nullptr) {
*(true_abstract->func_graph()->switch_input()) = true;
}
auto false_abstract = dyn_cast<abstract::FuncGraphAbstractClosure>(cnode->input(3)->abstract());
if (false_abstract != nullptr) {
*(false_abstract->func_graph()->switch_input()) = true;
}
return nullptr;
}
};

// {prim::kPrimSwitchLayer, Index, layers}
class SwitchLayerDeferInline : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
auto cnode = node->cast<CNodePtr>();
auto tuple = dyn_cast<abstract::AbstractTuple>(cnode->inputs()[2]->abstract());
auto tuple = dyn_cast<abstract::AbstractTuple>(cnode->input(2)->abstract());
for (auto elem : tuple->elements()) {
auto abstract = dyn_cast<abstract::FuncGraphAbstractClosure>(elem);
if (abstract != nullptr) {

+ 1
- 0
mindspore/ccsrc/pipeline/jit/pass.cc View File

@@ -97,6 +97,7 @@ bool ReAutoMonadWrapper(const FuncGraphPtr &root, const opt::OptimizerPtr &) { r

OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
opt::OptPassConfig a_1 = opt::OptPassConfig({
irpass.switch_defer_inline_,
irpass.switch_layer_defer_inline_,
irpass.switch_simplify_,
irpass.exchange_switch_depend_value_,


+ 1
- 0
mindspore/core/ir/func_graph.cc View File

@@ -50,6 +50,7 @@ FuncGraph::FuncGraph()
stub_(false),
stage_(-1) {
debug_info_ = std::make_shared<GraphDebugInfo>();
switch_input_ = std::make_shared<bool>(false);
switch_layer_input_ = std::make_shared<bool>(false);
}



+ 4
- 1
mindspore/core/ir/func_graph.h View File

@@ -381,6 +381,8 @@ class FuncGraph : public FuncGraphBase, public EffectInfoHolder {
bool stub() const { return stub_; }
void set_stub(bool stub) { stub_ = stub; }
static void set_drawer(Drawer drawer) { drawer_ = drawer; }
std::shared_ptr<bool> switch_input() const { return switch_input_; }
void set_switch_input(std::shared_ptr<bool> switch_input) { switch_input_ = switch_input; }
std::shared_ptr<bool> switch_layer_input() const { return switch_layer_input_; }
void set_switch_layer_input(std::shared_ptr<bool> switch_layer_input) { switch_layer_input_ = switch_layer_input; }
bool ContainMultiTarget() const;
@@ -462,8 +464,9 @@ class FuncGraph : public FuncGraphBase, public EffectInfoHolder {
OrderedSet<CNodePtr> order_;
bool stub_;
inline static Drawer drawer_ = nullptr;
// Design switch_layer_input as a ptr to
// Design switch_input and switch_layer_input as a ptr to
// share between derived backpropagator and cloned graphs.
std::shared_ptr<bool> switch_input_;
std::shared_ptr<bool> switch_layer_input_;
int64_t stage_;
std::unordered_map<AbstractBasePtrList, FuncGraphPtr, abstract::AbstractBasePtrListHasher,


+ 2
- 0
mindspore/core/ir/func_graph_cloner.cc View File

@@ -233,6 +233,7 @@ void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *cons
(*target_func_graph)->set_hyper_param_count(func_graph->hyper_param_count());
(*target_func_graph)->set_is_generate(func_graph->is_generated());
(*target_func_graph)->set_stub(func_graph->stub());
(*target_func_graph)->set_switch_input(func_graph->switch_input());
(*target_func_graph)->set_switch_layer_input(func_graph->switch_layer_input());
}

@@ -680,6 +681,7 @@ FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoP
new_func_graph->set_hyper_param_count(func_graph->hyper_param_count());
new_func_graph->set_is_generate(func_graph->is_generated());
new_func_graph->set_stub(func_graph->stub());
new_func_graph->set_switch_input(func_graph->switch_input());
new_func_graph->set_switch_layer_input(func_graph->switch_layer_input());
for (auto &item : func_graph->parameter_default_value()) {
new_func_graph->set_param_default_value(item.first, cloner[item.second]);


Loading…
Cancel
Save