| @@ -618,6 +618,9 @@ void AnfExporter::ExportOneFuncGraph(std::ofstream &ofs, const FuncGraphPtr &fun | |||||
| std::vector<AnfNodePtr> parameters = func_graph->parameters(); | std::vector<AnfNodePtr> parameters = func_graph->parameters(); | ||||
| OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual> param_map; | OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual> param_map; | ||||
| if (*(func_graph->switch_input())) { | |||||
| ofs << "switch_input: " << *(func_graph->switch_input()) << "\n"; | |||||
| } | |||||
| if (*(func_graph->switch_layer_input())) { | if (*(func_graph->switch_layer_input())) { | ||||
| ofs << "switch_layer_input: " << *(func_graph->switch_layer_input()) << "\n"; | ofs << "switch_layer_input: " << *(func_graph->switch_layer_input()) << "\n"; | ||||
| } | } | ||||
| @@ -45,7 +45,8 @@ DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBas | |||||
| TraceGuard guard(std::make_shared<TraceGradFprop>(primal_graph->debug_info())); | TraceGuard guard(std::make_shared<TraceGradFprop>(primal_graph->debug_info())); | ||||
| k_graph_ = std::make_shared<FuncGraph>(); | k_graph_ = std::make_shared<FuncGraph>(); | ||||
| } | } | ||||
| // To keep switch_layer's inputs from being inlined | |||||
| // To keep switch or switch_layer's inputs from being inlined | |||||
| k_graph_->set_switch_input(primal_graph->switch_input()); | |||||
| k_graph_->set_switch_layer_input(primal_graph->switch_layer_input()); | k_graph_->set_switch_layer_input(primal_graph->switch_layer_input()); | ||||
| k_graph_->set_stage(primal_graph->stage()); | k_graph_->set_stage(primal_graph->stage()); | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -47,7 +47,7 @@ | |||||
| #include "frontend/optimizer/opt.h" | #include "frontend/optimizer/opt.h" | ||||
| #include "frontend/optimizer/irpass/row_tensor_eliminate.h" | #include "frontend/optimizer/irpass/row_tensor_eliminate.h" | ||||
| #include "frontend/optimizer/irpass/sparse_tensor_eliminate.h" | #include "frontend/optimizer/irpass/sparse_tensor_eliminate.h" | ||||
| #include "frontend/optimizer/irpass/switch_layer_defer_inline.h" | |||||
| #include "frontend/optimizer/irpass/switch_or_switch_layer_defer_inline.h" | |||||
| #include "frontend/optimizer/irpass/call_graph_tuple_transform.h" | #include "frontend/optimizer/irpass/call_graph_tuple_transform.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -231,6 +231,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() { | |||||
| value_based_eliminate_ = MakeSubstitution(std::make_shared<ValueBasedEliminate>(), "value_based_eliminate", | value_based_eliminate_ = MakeSubstitution(std::make_shared<ValueBasedEliminate>(), "value_based_eliminate", | ||||
| {prim::kPrimSelect, prim::kPrimMinimum, prim::kPrimMaximum}); | {prim::kPrimSelect, prim::kPrimMinimum, prim::kPrimMaximum}); | ||||
| // switch defer inline | |||||
| switch_defer_inline_ = | |||||
| MakeSubstitution(std::make_shared<SwitchDeferInline>(), "switch_defer_inline", prim::kPrimSwitch); | |||||
| // switch_layer defer inline | // switch_layer defer inline | ||||
| switch_layer_defer_inline_ = | switch_layer_defer_inline_ = | ||||
| MakeSubstitution(std::make_shared<SwitchLayerDeferInline>(), "switch_layer_defer_inline", prim::kPrimSwitchLayer); | MakeSubstitution(std::make_shared<SwitchLayerDeferInline>(), "switch_layer_defer_inline", prim::kPrimSwitchLayer); | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -139,6 +139,9 @@ class OptimizeIRPassLib { | |||||
| // Value_Based Eliminate | // Value_Based Eliminate | ||||
| SubstitutionPtr value_based_eliminate_; | SubstitutionPtr value_based_eliminate_; | ||||
| // Switch defer inline | |||||
| SubstitutionPtr switch_defer_inline_; | |||||
| // SwitchLayer defer inline | // SwitchLayer defer inline | ||||
| SubstitutionPtr switch_layer_defer_inline_; | SubstitutionPtr switch_layer_defer_inline_; | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -41,7 +41,8 @@ class ReplaceApplicator : public AnfVisitor { | |||||
| } | } | ||||
| auto fg = GetValueNode<FuncGraphPtr>(node); | auto fg = GetValueNode<FuncGraphPtr>(node); | ||||
| if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stage() != -1 || fg->stub() || *(fg->switch_layer_input())) { | |||||
| if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stage() != -1 || fg->stub() || *(fg->switch_input()) || | |||||
| *(fg->switch_layer_input())) { | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -28,12 +28,29 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| namespace irpass { | namespace irpass { | ||||
| // {prim::kPrimSwitchLayer, {Index, layers}} | |||||
| // {prim::kPrimSwitch, cond, true_branch, false_branch} | |||||
| class SwitchDeferInline : public AnfVisitor { | |||||
| public: | |||||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| auto true_abstract = dyn_cast<abstract::FuncGraphAbstractClosure>(cnode->input(2)->abstract()); | |||||
| if (true_abstract != nullptr) { | |||||
| *(true_abstract->func_graph()->switch_input()) = true; | |||||
| } | |||||
| auto false_abstract = dyn_cast<abstract::FuncGraphAbstractClosure>(cnode->input(3)->abstract()); | |||||
| if (false_abstract != nullptr) { | |||||
| *(false_abstract->func_graph()->switch_input()) = true; | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
| }; | |||||
| // {prim::kPrimSwitchLayer, Index, layers} | |||||
| class SwitchLayerDeferInline : public AnfVisitor { | class SwitchLayerDeferInline : public AnfVisitor { | ||||
| public: | public: | ||||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | ||||
| auto cnode = node->cast<CNodePtr>(); | auto cnode = node->cast<CNodePtr>(); | ||||
| auto tuple = dyn_cast<abstract::AbstractTuple>(cnode->inputs()[2]->abstract()); | |||||
| auto tuple = dyn_cast<abstract::AbstractTuple>(cnode->input(2)->abstract()); | |||||
| for (auto elem : tuple->elements()) { | for (auto elem : tuple->elements()) { | ||||
| auto abstract = dyn_cast<abstract::FuncGraphAbstractClosure>(elem); | auto abstract = dyn_cast<abstract::FuncGraphAbstractClosure>(elem); | ||||
| if (abstract != nullptr) { | if (abstract != nullptr) { | ||||
| @@ -97,6 +97,7 @@ bool ReAutoMonadWrapper(const FuncGraphPtr &root, const opt::OptimizerPtr &) { r | |||||
| OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { | OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { | ||||
| opt::OptPassConfig a_1 = opt::OptPassConfig({ | opt::OptPassConfig a_1 = opt::OptPassConfig({ | ||||
| irpass.switch_defer_inline_, | |||||
| irpass.switch_layer_defer_inline_, | irpass.switch_layer_defer_inline_, | ||||
| irpass.switch_simplify_, | irpass.switch_simplify_, | ||||
| irpass.exchange_switch_depend_value_, | irpass.exchange_switch_depend_value_, | ||||
| @@ -50,6 +50,7 @@ FuncGraph::FuncGraph() | |||||
| stub_(false), | stub_(false), | ||||
| stage_(-1) { | stage_(-1) { | ||||
| debug_info_ = std::make_shared<GraphDebugInfo>(); | debug_info_ = std::make_shared<GraphDebugInfo>(); | ||||
| switch_input_ = std::make_shared<bool>(false); | |||||
| switch_layer_input_ = std::make_shared<bool>(false); | switch_layer_input_ = std::make_shared<bool>(false); | ||||
| } | } | ||||
| @@ -381,6 +381,8 @@ class FuncGraph : public FuncGraphBase, public EffectInfoHolder { | |||||
| bool stub() const { return stub_; } | bool stub() const { return stub_; } | ||||
| void set_stub(bool stub) { stub_ = stub; } | void set_stub(bool stub) { stub_ = stub; } | ||||
| static void set_drawer(Drawer drawer) { drawer_ = drawer; } | static void set_drawer(Drawer drawer) { drawer_ = drawer; } | ||||
| std::shared_ptr<bool> switch_input() const { return switch_input_; } | |||||
| void set_switch_input(std::shared_ptr<bool> switch_input) { switch_input_ = switch_input; } | |||||
| std::shared_ptr<bool> switch_layer_input() const { return switch_layer_input_; } | std::shared_ptr<bool> switch_layer_input() const { return switch_layer_input_; } | ||||
| void set_switch_layer_input(std::shared_ptr<bool> switch_layer_input) { switch_layer_input_ = switch_layer_input; } | void set_switch_layer_input(std::shared_ptr<bool> switch_layer_input) { switch_layer_input_ = switch_layer_input; } | ||||
| bool ContainMultiTarget() const; | bool ContainMultiTarget() const; | ||||
| @@ -462,8 +464,9 @@ class FuncGraph : public FuncGraphBase, public EffectInfoHolder { | |||||
| OrderedSet<CNodePtr> order_; | OrderedSet<CNodePtr> order_; | ||||
| bool stub_; | bool stub_; | ||||
| inline static Drawer drawer_ = nullptr; | inline static Drawer drawer_ = nullptr; | ||||
| // Design switch_layer_input as a ptr to | |||||
| // Design switch_input and switch_layer_input as a ptr to | |||||
| // share between derived backpropagator and cloned graphs. | // share between derived backpropagator and cloned graphs. | ||||
| std::shared_ptr<bool> switch_input_; | |||||
| std::shared_ptr<bool> switch_layer_input_; | std::shared_ptr<bool> switch_layer_input_; | ||||
| int64_t stage_; | int64_t stage_; | ||||
| std::unordered_map<AbstractBasePtrList, FuncGraphPtr, abstract::AbstractBasePtrListHasher, | std::unordered_map<AbstractBasePtrList, FuncGraphPtr, abstract::AbstractBasePtrListHasher, | ||||
| @@ -233,6 +233,7 @@ void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *cons | |||||
| (*target_func_graph)->set_hyper_param_count(func_graph->hyper_param_count()); | (*target_func_graph)->set_hyper_param_count(func_graph->hyper_param_count()); | ||||
| (*target_func_graph)->set_is_generate(func_graph->is_generated()); | (*target_func_graph)->set_is_generate(func_graph->is_generated()); | ||||
| (*target_func_graph)->set_stub(func_graph->stub()); | (*target_func_graph)->set_stub(func_graph->stub()); | ||||
| (*target_func_graph)->set_switch_input(func_graph->switch_input()); | |||||
| (*target_func_graph)->set_switch_layer_input(func_graph->switch_layer_input()); | (*target_func_graph)->set_switch_layer_input(func_graph->switch_layer_input()); | ||||
| } | } | ||||
| @@ -680,6 +681,7 @@ FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoP | |||||
| new_func_graph->set_hyper_param_count(func_graph->hyper_param_count()); | new_func_graph->set_hyper_param_count(func_graph->hyper_param_count()); | ||||
| new_func_graph->set_is_generate(func_graph->is_generated()); | new_func_graph->set_is_generate(func_graph->is_generated()); | ||||
| new_func_graph->set_stub(func_graph->stub()); | new_func_graph->set_stub(func_graph->stub()); | ||||
| new_func_graph->set_switch_input(func_graph->switch_input()); | |||||
| new_func_graph->set_switch_layer_input(func_graph->switch_layer_input()); | new_func_graph->set_switch_layer_input(func_graph->switch_layer_input()); | ||||
| for (auto &item : func_graph->parameter_default_value()) { | for (auto &item : func_graph->parameter_default_value()) { | ||||
| new_func_graph->set_param_default_value(item.first, cloner[item.second]); | new_func_graph->set_param_default_value(item.first, cloner[item.second]); | ||||