From add322b1c34c79a00bcd4d9c088ae0c0d1fde2e8 Mon Sep 17 00:00:00 2001 From: yangwei Date: Mon, 1 Mar 2021 11:23:54 +0800 Subject: [PATCH] fix switchlayer --- .../backend/session/ascend_auto_monad.cc | 74 +++++++++++++++++-- .../ccsrc/backend/session/session_basic.cc | 56 +++++++++----- .../ccsrc/backend/session/session_basic.h | 2 +- 3 files changed, 107 insertions(+), 25 deletions(-) diff --git a/mindspore/ccsrc/backend/session/ascend_auto_monad.cc b/mindspore/ccsrc/backend/session/ascend_auto_monad.cc index 0e741dc78a..1d2630f28b 100644 --- a/mindspore/ccsrc/backend/session/ascend_auto_monad.cc +++ b/mindspore/ccsrc/backend/session/ascend_auto_monad.cc @@ -270,8 +270,9 @@ class AscendAutoMonadConverter { MS_LOG(EXCEPTION) << "Invalid CNode: " << cnode->DebugString() << std::endl; } if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall) || - AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) { - // Found call/switch node, set it as the tail call node. + AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch) || + AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer)) { + // Found call/switch/switchlayer node, set it as the tail call node. 
tail_call_node_ = cnode; call_switch_nodes_.emplace_back(cnode); monad_map_.emplace(cnode, last_monad); @@ -292,8 +293,10 @@ class AscendAutoMonadConverter { HandleCall(cnode); } else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) { HandleSwitch(cnode); + } else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer)) { + HandleSwitchLayer(cnode); } else { - MS_LOG(EXCEPTION) << "Not a call/switch node: " << cnode->DebugString(); + MS_LOG(EXCEPTION) << "Not a call/switch/switchlayer node: " << cnode->DebugString(); } } // If no tail call, assign output value to output parameter, @@ -413,6 +416,60 @@ class AscendAutoMonadConverter { } } + // + // Convert switch layer node: + // branch1 = Partial(graph1, arg) + // branch2 = Partial(graph2, arg) + // out = SwitchLayer(index, branch1, branch2) + // to: + // r = link_args(graph1, arg) + // c = UpdateState(c, r) + // r = link_args(graph2, arg) + // c = UpdateState(c, r) + // c = LabelSwitch(index, c) : L1, L2 + // c = LabelSet(c) : + // + void HandleSwitchLayer(const CNodePtr &cnode) { + // Update last_monad_. + last_monad_ = monad_map_[cnode]; + + // Get all branches of the switch layer, in branch index order. + auto branches = GetSwitchBranches(cnode); + + // Link arguments and generate labels for branches. + std::vector graphes; + std::vector labels; + graphes.reserve(branches.size()); + labels.reserve(graphes.size()); + for (auto &[graph, args] : branches) { + if (graph == nullptr) { + MS_LOG(EXCEPTION) << "Invalid switch: " << cnode->DebugString(); + } + auto linked_args = LinkArguments(args, graph); + if (linked_args != nullptr) { + monad_ = UpdateState(GetMonad(), linked_args); + } + graphes.push_back(graph); + labels.push_back(GetOrCreateGraphLabel(graph)); + } + + // Add LabelSwitch node. + auto switch_node = LabelSwitch(cnode->input(1), labels); + + // Set child graph attribute for switch node. + SetChildGrapAttr(switch_node, graphes); + + // Setup return label if required. 
+ const bool is_tail_call = (cnode == tail_call_node_); + const bool need_return = (return_label_ == kNoLabel || !is_tail_call); + auto [para_pool, output_para, return_label] = MakeReturn(cnode, need_return); + + // Handle sub-graphs recursively. + for (auto &graph : graphes) { + HandleSubGraph(graph, para_pool, output_para, return_label); + } + } + ParameterPoolPtr GetParameterPool(bool is_last_call) { if (!is_last_call) { // There are multiple calls in this graph, use a new parameter pool @@ -483,10 +540,13 @@ class AscendAutoMonadConverter { } std::vector GetSwitchBranches(const CNodePtr &cnode) { - constexpr size_t true_index = 2; - constexpr size_t false_index = 3; - // True branch first, then false branch. - return {GetSwitchBranch(cnode, true_index), GetSwitchBranch(cnode, false_index)}; + constexpr size_t cond_start_index = 2; + // switch branches + std::vector switch_branches; + for (size_t index = cond_start_index; index < cnode->inputs().size(); ++index) { + switch_branches.emplace_back(GetSwitchBranch(cnode, index)); + } + return switch_branches; } // diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index c648063c63..82b9ddfc9b 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -928,9 +928,8 @@ std::vector SessionBasic::CreateCallSwitchInputs(const CNodePtr &cno return cnode_inputs; } -void SessionBasic::CreateCallNodeReturnFunction(const CNodePtr &cnode, const AnfNodePtr &real_input) { +void SessionBasic::CreateCallNodeReturnFunction(const CNodePtr &cnode, const std::vector &real_inputs) { MS_EXCEPTION_IF_NULL(cnode); - MS_EXCEPTION_IF_NULL(real_input); if (!(AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimPartial))) { MS_LOG(EXCEPTION) << "Node: " << cnode->DebugString() << "is not a partial node."; } @@ -940,24 +939,37 @@ void SessionBasic::CreateCallNodeReturnFunction(const CNodePtr &cnode, const Anf auto ret = 
partial_kernel_graph->get_return(); MS_EXCEPTION_IF_NULL(ret); auto return_input = ret->input(kFirstDataInputIndex); - // if kernel graph return node is a function + // return node is a function + std::vector call_inputs = { + partial_kernel_graph->NewValueNode(NewValueNode(std::make_shared(prim::kPrimCall->name())))}; + AnfNodePtr real_kernel_graph; if (AnfAlgo::CheckPrimitiveType(return_input, prim::kPrimPartial)) { - std::vector call_inputs = { - partial_kernel_graph->NewValueNode(NewValueNode(std::make_shared(prim::kPrimCall->name())))}; auto return_input_cnode = return_input->cast(); - auto partial_inputs = return_input_cnode->inputs(); call_inputs.insert(call_inputs.end(), partial_inputs.begin() + kFirstDataInputIndex, partial_inputs.end()); + real_kernel_graph = partial_inputs[kFirstDataInputIndex]; + } else { // return node is kernel graph + call_inputs.emplace_back(return_input); + real_kernel_graph = return_input; + } + + // new call node inputs + for (auto real_input : real_inputs) { auto parameter_for_input = CreateNewParameterFromCNode(real_input, partial_kernel_graph.get()); call_inputs.emplace_back(parameter_for_input); - auto call_node = partial_kernel_graph->NewCNode(call_inputs); - // update abstract - KernelGraphPtr sub_partial_kernel_graph = GetValueNode(partial_inputs[kFirstDataInputIndex]); + } + + auto call_node = partial_kernel_graph->NewCNode(call_inputs); + // update abstract + MS_EXCEPTION_IF_NULL(real_kernel_graph); + if (real_kernel_graph->isa() && IsValueNode(real_kernel_graph)) { + KernelGraphPtr sub_partial_kernel_graph = GetValueNode(real_kernel_graph); + MS_EXCEPTION_IF_NULL(sub_partial_kernel_graph); auto ret_partial = sub_partial_kernel_graph->get_return(); call_node->set_abstract(ret_partial->abstract()); - // update return input - ret->set_input(kFirstDataInputIndex, call_node); } + // update return input + ret->set_input(kFirstDataInputIndex, call_node); } std::vector SessionBasic::CreateCallSwitchLayerInputs(const CNodePtr 
&cnode, KernelGraph *graph) { @@ -977,9 +989,11 @@ std::vector SessionBasic::CreateCallSwitchLayerInputs(const CNodePtr auto node = make_tuple_node->cast(); MS_EXCEPTION_IF_NULL(node); auto make_tuple_inputs = node->inputs(); - // there is real input in call, should put it to make_tuple in switch_layer - auto real_input = cnode->input(kFirstDataInputIndex); - auto real_input_back = graph->GetBackendAnfByFrontAnf(real_input); + // there are real inputs in call, should put them into make_tuple in switch_layer + std::vector real_inputs; + for (size_t idx = kFirstDataInputIndex; idx < cnode->inputs().size(); ++idx) { + real_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(idx))); + } std::vector new_make_tuple_inputs = { graph->NewValueNode(NewValueNode(std::make_shared(prim::kPrimMakeTuple->name())))}; for (size_t idx = kFirstDataInputIndex; idx < make_tuple_inputs.size(); idx++) { @@ -990,10 +1004,18 @@ std::vector SessionBasic::CreateCallSwitchLayerInputs(const CNodePtr auto partial_node = partial_idx->cast(); MS_EXCEPTION_IF_NULL(partial_node); // update kernel graph when switch_layer node return function - CreateCallNodeReturnFunction(partial_node, real_input_back); + auto partial_input = partial_node->input(kFirstDataInputIndex); + KernelGraphPtr partial_kernel_graph = GetValueNode(partial_input); + MS_EXCEPTION_IF_NULL(partial_kernel_graph); + auto ret = partial_kernel_graph->get_return(); + MS_EXCEPTION_IF_NULL(ret); + auto return_input = ret->input(kFirstDataInputIndex); + if (AnfAlgo::CheckPrimitiveType(return_input, prim::kPrimPartial) || IsValueNode(return_input)) { + CreateCallNodeReturnFunction(partial_node, real_inputs); + } std::vector new_partial_inputs = partial_node->inputs(); - new_partial_inputs.emplace_back(real_input_back); + new_partial_inputs.insert(new_partial_inputs.end(), real_inputs.begin(), real_inputs.end()); auto new_partial = graph->NewCNode(new_partial_inputs); new_make_tuple_inputs.emplace_back(new_partial); } @@ -1003,7 
+1025,7 @@ std::vector SessionBasic::CreateCallSwitchLayerInputs(const CNodePtr std::vector new_partial_inputs; new_partial_inputs.emplace_back(NewValueNode(std::make_shared(prim::kPrimPartial->name()))); new_partial_inputs.emplace_back(partial_idx); - new_partial_inputs.emplace_back(real_input_back); + new_partial_inputs.insert(new_partial_inputs.end(), real_inputs.begin(), real_inputs.end()); auto new_partial = graph->NewCNode(new_partial_inputs); new_make_tuple_inputs.emplace_back(new_partial); } diff --git a/mindspore/ccsrc/backend/session/session_basic.h b/mindspore/ccsrc/backend/session/session_basic.h index 96b274b1d2..19ff2927f5 100644 --- a/mindspore/ccsrc/backend/session/session_basic.h +++ b/mindspore/ccsrc/backend/session/session_basic.h @@ -147,7 +147,7 @@ class SessionBasic : public std::enable_shared_from_this { void GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector *cnode_inputs, std::unordered_map *other_graph_cnode); std::vector CreateCallSwitchLayerInputs(const CNodePtr &cnode, KernelGraph *graph); - void CreateCallNodeReturnFunction(const CNodePtr &cnode, const AnfNodePtr &real_input); + void CreateCallNodeReturnFunction(const CNodePtr &cnode, const std::vector &real_inputs); protected: friend class Executor;