diff --git a/mindspore/ccsrc/debug/anf_ir_utils.cc b/mindspore/ccsrc/debug/anf_ir_utils.cc index 0af24206de..e9d54ac6eb 100644 --- a/mindspore/ccsrc/debug/anf_ir_utils.cc +++ b/mindspore/ccsrc/debug/anf_ir_utils.cc @@ -618,6 +618,9 @@ void AnfExporter::ExportOneFuncGraph(std::ofstream &ofs, const FuncGraphPtr &fun std::vector<AnfNodePtr> parameters = func_graph->parameters(); OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual> param_map; + if (*(func_graph->switch_input())) { + ofs << "switch_input: " << *(func_graph->switch_input()) << "\n"; + } if (*(func_graph->switch_layer_input())) { ofs << "switch_layer_input: " << *(func_graph->switch_layer_input()) << "\n"; } diff --git a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc index 8ca769a6d8..0db90f5863 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc +++ b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc @@ -45,7 +45,8 @@ DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBas TraceGuard guard(std::make_shared<TraceGradFg>(primal_graph->debug_info())); k_graph_ = std::make_shared<FuncGraph>(); } - // To keep switch_layer's inputs from being inlined + // To keep switch or switch_layer's inputs from being inlined + k_graph_->set_switch_input(primal_graph->switch_input()); k_graph_->set_switch_layer_input(primal_graph->switch_layer_input()); k_graph_->set_stage(primal_graph->stage()); diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.cc b/mindspore/ccsrc/frontend/optimizer/irpass.cc index 944cc3398a..f6505d4d6a 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -47,7 +47,7 @@ #include "frontend/optimizer/opt.h" #include "frontend/optimizer/irpass/row_tensor_eliminate.h" #include "frontend/optimizer/irpass/sparse_tensor_eliminate.h" -#include "frontend/optimizer/irpass/switch_layer_defer_inline.h" +#include "frontend/optimizer/irpass/switch_or_switch_layer_defer_inline.h" #include "frontend/optimizer/irpass/call_graph_tuple_transform.h" namespace mindspore { @@ -231,6 +231,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() { value_based_eliminate_ = MakeSubstitution(std::make_shared<ValueBasedEliminate>(), "value_based_eliminate", {prim::kPrimSelect, prim::kPrimMinimum, prim::kPrimMaximum}); + // switch defer inline + switch_defer_inline_ = + MakeSubstitution(std::make_shared<SwitchDeferInline>(), "switch_defer_inline", prim::kPrimSwitch); + // switch_layer defer inline switch_layer_defer_inline_ = MakeSubstitution(std::make_shared<SwitchLayerDeferInline>(), "switch_layer_defer_inline", prim::kPrimSwitchLayer); diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.h b/mindspore/ccsrc/frontend/optimizer/irpass.h index aa01f7c20e..3d5423b9f6 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -139,6 +139,9 @@ class OptimizeIRPassLib { // Value_Based Eliminate SubstitutionPtr value_based_eliminate_; + // Switch defer inline + SubstitutionPtr switch_defer_inline_; + // SwitchLayer defer inline SubstitutionPtr switch_layer_defer_inline_; diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/inline.h b/mindspore/ccsrc/frontend/optimizer/irpass/inline.h index 49955cec22..c1e991f289 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/inline.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/inline.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -41,7 +41,8 @@ class ReplaceApplicator : public AnfVisitor { } auto fg = GetValueNode<FuncGraphPtr>(node); - if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stage() != -1 || fg->stub() || *(fg->switch_layer_input())) { + if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stage() != -1 || fg->stub() || *(fg->switch_input()) || + *(fg->switch_layer_input())) { return nullptr; } diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/switch_layer_defer_inline.h b/mindspore/ccsrc/frontend/optimizer/irpass/switch_or_switch_layer_defer_inline.h similarity index 64% rename from mindspore/ccsrc/frontend/optimizer/irpass/switch_layer_defer_inline.h rename to mindspore/ccsrc/frontend/optimizer/irpass/switch_or_switch_layer_defer_inline.h index e360e016cc..1f74686634 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/switch_layer_defer_inline.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/switch_or_switch_layer_defer_inline.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -28,12 +28,29 @@ namespace mindspore { namespace opt { namespace irpass { -// {prim::kPrimSwitchLayer, {Index, layers}} +// {prim::kPrimSwitch, cond, true_branch, false_branch} +class SwitchDeferInline : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + auto cnode = node->cast<CNodePtr>(); + auto true_abstract = dyn_cast<FuncGraphAbstractClosure>(cnode->input(2)->abstract()); + if (true_abstract != nullptr) { + *(true_abstract->func_graph()->switch_input()) = true; + } + auto false_abstract = dyn_cast<FuncGraphAbstractClosure>(cnode->input(3)->abstract()); + if (false_abstract != nullptr) { + *(false_abstract->func_graph()->switch_input()) = true; + } + return nullptr; + } +}; + +// {prim::kPrimSwitchLayer, Index, layers} class SwitchLayerDeferInline : public AnfVisitor { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { auto cnode = node->cast<CNodePtr>(); - auto tuple = dyn_cast<abstract::AbstractTuple>(cnode->inputs()[2]->abstract()); + auto tuple = dyn_cast<abstract::AbstractTuple>(cnode->input(2)->abstract()); for (auto elem : tuple->elements()) { auto abstract = dyn_cast<FuncGraphAbstractClosure>(elem); if (abstract != nullptr) { diff --git a/mindspore/ccsrc/pipeline/jit/pass.cc b/mindspore/ccsrc/pipeline/jit/pass.cc index a7c09a63c6..f1130d2987 100644 --- a/mindspore/ccsrc/pipeline/jit/pass.cc +++ b/mindspore/ccsrc/pipeline/jit/pass.cc @@ -97,6 +97,7 @@ bool ReAutoMonadWrapper(const FuncGraphPtr &root, const opt::OptimizerPtr &) { r OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { opt::OptPassConfig a_1 = opt::OptPassConfig({ + irpass.switch_defer_inline_, irpass.switch_layer_defer_inline_, irpass.switch_simplify_, irpass.exchange_switch_depend_value_, diff --git a/mindspore/core/ir/func_graph.cc b/mindspore/core/ir/func_graph.cc index aae51466d1..1bea9e30cf 100644 --- a/mindspore/core/ir/func_graph.cc +++ b/mindspore/core/ir/func_graph.cc @@ -50,6 +50,7 @@ FuncGraph::FuncGraph() stub_(false), stage_(-1) { debug_info_ = std::make_shared<GraphDebugInfo>(); + switch_input_ = 
std::make_shared<bool>(false); switch_layer_input_ = std::make_shared<bool>(false); } diff --git a/mindspore/core/ir/func_graph.h b/mindspore/core/ir/func_graph.h index 2fe94451e9..7f1faa9a0f 100644 --- a/mindspore/core/ir/func_graph.h +++ b/mindspore/core/ir/func_graph.h @@ -381,6 +381,8 @@ class FuncGraph : public FuncGraphBase, public EffectInfoHolder { bool stub() const { return stub_; } void set_stub(bool stub) { stub_ = stub; } static void set_drawer(Drawer drawer) { drawer_ = drawer; } + std::shared_ptr<bool> switch_input() const { return switch_input_; } + void set_switch_input(std::shared_ptr<bool> switch_input) { switch_input_ = switch_input; } std::shared_ptr<bool> switch_layer_input() const { return switch_layer_input_; } void set_switch_layer_input(std::shared_ptr<bool> switch_layer_input) { switch_layer_input_ = switch_layer_input; } bool ContainMultiTarget() const; @@ -462,8 +464,9 @@ class FuncGraph : public FuncGraphBase, public EffectInfoHolder { OrderedSet<CNodePtr> order_; bool stub_; inline static Drawer drawer_ = nullptr; - // Design switch_layer_input as a ptr to + // Design switch_input and switch_layer_input as a ptr to // share between derived backpropagator and cloned graphs. 
+ std::shared_ptr<bool> switch_input_; std::shared_ptr<bool> switch_layer_input_; int64_t stage_; std::unordered_mapset_hyper_param_count(func_graph->hyper_param_count()); (*target_func_graph)->set_is_generate(func_graph->is_generated()); (*target_func_graph)->set_stub(func_graph->stub()); + (*target_func_graph)->set_switch_input(func_graph->switch_input()); (*target_func_graph)->set_switch_layer_input(func_graph->switch_layer_input()); } @@ -680,6 +681,7 @@ FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoP new_func_graph->set_hyper_param_count(func_graph->hyper_param_count()); new_func_graph->set_is_generate(func_graph->is_generated()); new_func_graph->set_stub(func_graph->stub()); + new_func_graph->set_switch_input(func_graph->switch_input()); new_func_graph->set_switch_layer_input(func_graph->switch_layer_input()); for (auto &item : func_graph->parameter_default_value()) { new_func_graph->set_param_default_value(item.first, cloner[item.second]);