Merge pull request !4188 from riemann_penn/fix_switch_layer_sigle_prim_celltags/v0.7.0-beta
| @@ -607,6 +607,9 @@ void AnfExporter::ExportOneFuncGraph(std::ofstream &ofs, const FuncGraphPtr &fun | |||||
| std::vector<AnfNodePtr> parameters = func_graph->parameters(); | std::vector<AnfNodePtr> parameters = func_graph->parameters(); | ||||
| OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual> param_map; | OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual> param_map; | ||||
| if (*(func_graph->switch_layer_input())) { | |||||
| ofs << "switch_layer_input: " << *(func_graph->switch_layer_input()) << "\n"; | |||||
| } | |||||
| ofs << "# [No." << (exported.size() + 1) << "] " << func_graph->DumpText() << "." | ofs << "# [No." << (exported.size() + 1) << "] " << func_graph->DumpText() << "." | ||||
| << func_graph->debug_info()->get_id() << "\n"; | << func_graph->debug_info()->get_id() << "\n"; | ||||
| if (label_manage::GetGlobalTraceLabelType() == label_manage::TraceLabelType::kWithUniqueId) { | if (label_manage::GetGlobalTraceLabelType() == label_manage::TraceLabelType::kWithUniqueId) { | ||||
| @@ -49,6 +49,8 @@ DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBas | |||||
| std::string grad_op_name = GetValue<std::string>(primal_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)); | std::string grad_op_name = GetValue<std::string>(primal_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)); | ||||
| k_graph_->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(grad_op_name)); | k_graph_->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(grad_op_name)); | ||||
| } | } | ||||
| // To keep switch_layer's inputs from being inlined | |||||
| k_graph_->set_switch_layer_input(primal_graph->switch_layer_input()); | |||||
| TraceManager::EndTrace(); | TraceManager::EndTrace(); | ||||
| TraceManager::DebugTrace(std::make_shared<TraceGradBprop>(primal_graph->debug_info())); | TraceManager::DebugTrace(std::make_shared<TraceGradBprop>(primal_graph->debug_info())); | ||||
| @@ -45,6 +45,7 @@ | |||||
| #include "frontend/optimizer/opt.h" | #include "frontend/optimizer/opt.h" | ||||
| #include "frontend/optimizer/irpass/row_tensor_eliminate.h" | #include "frontend/optimizer/irpass/row_tensor_eliminate.h" | ||||
| #include "frontend/optimizer/irpass/sparse_tensor_eliminate.h" | #include "frontend/optimizer/irpass/sparse_tensor_eliminate.h" | ||||
| #include "frontend/optimizer/irpass/switch_layer_defer_inline.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| @@ -170,6 +171,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() { | |||||
| // Value_Based Eliminate | // Value_Based Eliminate | ||||
| value_based_eliminate_ = MakeSubstitution(std::make_shared<ValueBasedEliminate>(), "value_based_eliminate", | value_based_eliminate_ = MakeSubstitution(std::make_shared<ValueBasedEliminate>(), "value_based_eliminate", | ||||
| {prim::kPrimSelect, prim::kPrimMinimum, prim::kPrimMaximum}); | {prim::kPrimSelect, prim::kPrimMinimum, prim::kPrimMaximum}); | ||||
| // switch_layer defer inline | |||||
| switch_layer_defer_inline_ = | |||||
| MakeSubstitution(std::make_shared<SwitchLayerDeferInline>(), "switch_layer_defer_inline", prim::kPrimSwitchLayer); | |||||
| } | } | ||||
| ResolveIRPassLib::ResolveIRPassLib() { | ResolveIRPassLib::ResolveIRPassLib() { | ||||
| @@ -113,6 +113,9 @@ class OptimizeIRPassLib { | |||||
| // Value_Based Eliminate | // Value_Based Eliminate | ||||
| SubstitutionPtr value_based_eliminate_; | SubstitutionPtr value_based_eliminate_; | ||||
| // SwitchLayer defer inline | |||||
| SubstitutionPtr switch_layer_defer_inline_; | |||||
| }; | }; | ||||
| // the collection of irpass for resolve action | // the collection of irpass for resolve action | ||||
| @@ -39,7 +39,7 @@ class ReplaceApplicator : public AnfVisitor { | |||||
| } | } | ||||
| auto fg = GetValueNode<FuncGraphPtr>(node); | auto fg = GetValueNode<FuncGraphPtr>(node); | ||||
| if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stub()) { | |||||
| if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stub() || *(fg->switch_layer_input())) { | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -0,0 +1,47 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SWITCH_LAYER_DEFER_INLINE_H_ | |||||
| #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SWITCH_LAYER_DEFER_INLINE_H_ | |||||
| #include <vector> | |||||
| #include <algorithm> | |||||
| #include "frontend/optimizer/irpass.h" | |||||
| #include "frontend/optimizer/optimizer.h" | |||||
| #include "frontend/optimizer/anf_visitor.h" | |||||
| #include "frontend/operator/ops.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| namespace irpass { | |||||
| // {prim::kPrimSwitchLayer, {Index, layers}} | |||||
| class SwitchLayerDeferInline : public AnfVisitor { | |||||
| public: | |||||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| auto tuple = dyn_cast<abstract::AbstractTuple>(cnode->inputs()[2]->abstract()); | |||||
| for (auto elem : tuple->elements()) { | |||||
| auto abstract = dyn_cast<abstract::FuncGraphAbstractClosure>(elem); | |||||
| *(abstract->func_graph()->switch_layer_input()) = true; | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
| }; | |||||
| } // namespace irpass | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SWITCH_LAYER_DEFER_INLINE_H_ | |||||
| @@ -90,6 +90,7 @@ bool CleanAfterOptAPass(const ResourcePtr &res) { | |||||
| namespace { | namespace { | ||||
| OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { | OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { | ||||
| opt::OptPassConfig a_1 = opt::OptPassConfig({ | opt::OptPassConfig a_1 = opt::OptPassConfig({ | ||||
| irpass.switch_layer_defer_inline_, | |||||
| irpass.switch_simplify_, | irpass.switch_simplify_, | ||||
| // Safe inlining | // Safe inlining | ||||
| @@ -48,6 +48,7 @@ FuncGraph::FuncGraph() | |||||
| manager_(std::weak_ptr<FuncGraphManager>()), | manager_(std::weak_ptr<FuncGraphManager>()), | ||||
| stub_(false) { | stub_(false) { | ||||
| debug_info_ = std::make_shared<GraphDebugInfo>(); | debug_info_ = std::make_shared<GraphDebugInfo>(); | ||||
| switch_layer_input_ = std::make_shared<bool>(false); | |||||
| } | } | ||||
| abstract::AbstractBasePtr FuncGraph::ToAbstract() { | abstract::AbstractBasePtr FuncGraph::ToAbstract() { | ||||
| @@ -353,6 +353,8 @@ class FuncGraph : public FuncGraphBase { | |||||
| bool stub() const { return stub_; } | bool stub() const { return stub_; } | ||||
| void set_stub(bool stub) { stub_ = stub; } | void set_stub(bool stub) { stub_ = stub; } | ||||
| static void set_drawer(Drawer drawer) { drawer_ = drawer; } | static void set_drawer(Drawer drawer) { drawer_ = drawer; } | ||||
| std::shared_ptr<bool> switch_layer_input() const { return switch_layer_input_; } | |||||
| void set_switch_layer_input(std::shared_ptr<bool> switch_layer_input) { switch_layer_input_ = switch_layer_input; } | |||||
| private: | private: | ||||
| // graph is manipulated by manager and others | // graph is manipulated by manager and others | ||||
| @@ -414,6 +416,9 @@ class FuncGraph : public FuncGraphBase { | |||||
| std::list<CNodePtr> order_; | std::list<CNodePtr> order_; | ||||
| bool stub_; | bool stub_; | ||||
| inline static Drawer drawer_ = nullptr; | inline static Drawer drawer_ = nullptr; | ||||
| // Design switch_layer_input as a ptr to | |||||
| // share between derived backpropagator and cloned graphs | |||||
| std::shared_ptr<bool> switch_layer_input_; | |||||
| }; | }; | ||||
| inline CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &fg) { | inline CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &fg) { | ||||
| @@ -228,6 +228,7 @@ void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *cons | |||||
| (*target_func_graph)->set_hyper_param_count(func_graph->hyper_param_count()); | (*target_func_graph)->set_hyper_param_count(func_graph->hyper_param_count()); | ||||
| (*target_func_graph)->set_is_generate(func_graph->is_generated()); | (*target_func_graph)->set_is_generate(func_graph->is_generated()); | ||||
| (*target_func_graph)->set_stub(func_graph->stub()); | (*target_func_graph)->set_stub(func_graph->stub()); | ||||
| (*target_func_graph)->set_switch_layer_input(func_graph->switch_layer_input()); | |||||
| TraceManager::EndTrace(); | TraceManager::EndTrace(); | ||||
| } | } | ||||
| @@ -645,6 +646,7 @@ FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoP | |||||
| new_func_graph->set_hyper_param_count(func_graph->hyper_param_count()); | new_func_graph->set_hyper_param_count(func_graph->hyper_param_count()); | ||||
| new_func_graph->set_is_generate(func_graph->is_generated()); | new_func_graph->set_is_generate(func_graph->is_generated()); | ||||
| new_func_graph->set_stub(func_graph->stub()); | new_func_graph->set_stub(func_graph->stub()); | ||||
| new_func_graph->set_switch_layer_input(func_graph->switch_layer_input()); | |||||
| for (auto &item : func_graph->parameter_default_value()) { | for (auto &item : func_graph->parameter_default_value()) { | ||||
| new_func_graph->set_param_default_value(item.first, cloner[item.second]); | new_func_graph->set_param_default_value(item.first, cloner[item.second]); | ||||
| } | } | ||||
| @@ -444,6 +444,26 @@ def test_index_to_switch_layer(): | |||||
| C.grad_all(net)(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32))) | C.grad_all(net)(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32))) | ||||
def test_switch_layer_with_single_prim():
    """Run forward and gradient passes through a switch_layer whose branches are single ReLU cells."""

    def make_input():
        # Fresh [128, 96] float32 tensor filled with 0.6 for every use.
        return Tensor(np.full([128, 96], 0.6, dtype=np.float32))

    class SwitchLayerCell(nn.Cell):
        def __init__(self):
            super().__init__()
            self.layers = (nn.ReLU(), nn.ReLU())
            self.z3 = Parameter(make_input(), name='z3')

        def construct(self, index, x):
            ret = self.layers[index](x) * self.z3
            return ret

    index = Tensor(0, dtype=mstype.int32)
    net = SwitchLayerCell()

    # Forward pass.
    net(index, make_input())

    # Gradient with respect to the trainable parameters.
    weights = ParameterTuple(net.trainable_params())
    grad_by_weights = C.grad_by_list(net, weights)
    grad_by_weights(index, make_input())

    # Gradient with respect to all inputs.
    grad_by_inputs = C.grad_all(net)
    grad_by_inputs(index, make_input())
| def test_control_depend_check(): | def test_control_depend_check(): | ||||
| with pytest.raises(TypeError) as e: | with pytest.raises(TypeError) as e: | ||||
| P.ControlDepend(0.0) | P.ControlDepend(0.0) | ||||