From ef6507a44dfdd12a362ca4b32937316e82827d87 Mon Sep 17 00:00:00 2001 From: yangwei Date: Tue, 9 Mar 2021 09:52:16 +0800 Subject: [PATCH] fix switch_layer --- .../ccsrc/backend/session/session_basic.cc | 101 ++++++++++-------- .../ccsrc/backend/session/session_basic.h | 3 +- 2 files changed, 58 insertions(+), 46 deletions(-) diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index e8c840046b..db072d707e 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -934,46 +934,60 @@ std::vector SessionBasic::CreateCallSwitchInputs(const CNodePtr &cno return cnode_inputs; } -void SessionBasic::CreateCallNodeReturnFunction(const CNodePtr &cnode, const std::vector &real_inputs) { +void SessionBasic::CreateCallNodeReturnFunction(const CNodePtr &cnode, KernelGraph *graph, + const std::vector &real_inputs) { MS_EXCEPTION_IF_NULL(cnode); - if (!(AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimPartial))) { - MS_LOG(EXCEPTION) << "Node: " << cnode->DebugString() << "is not a partial node."; + // func1 =switch(branch1, branch2) + // func2 = func1(param1) + // out = func2(param2) + // process the last cnode(func2), not func1 which abstract is AbstractFunction + if (cnode->abstract()->isa()) { + return; } - auto partial_input = cnode->input(kFirstDataInputIndex); - KernelGraphPtr partial_kernel_graph = GetValueNode(partial_input); - MS_EXCEPTION_IF_NULL(partial_kernel_graph); - auto ret = partial_kernel_graph->get_return(); + MS_EXCEPTION_IF_NULL(graph); + auto ret = graph->get_return(); MS_EXCEPTION_IF_NULL(ret); auto return_input = ret->input(kFirstDataInputIndex); // return node is a function std::vector call_inputs = { - partial_kernel_graph->NewValueNode(NewValueNode(std::make_shared(prim::kPrimCall->name())))}; - AnfNodePtr real_kernel_graph; + graph->NewValueNode(NewValueNode(std::make_shared(prim::kPrimCall->name())))}; if (AnfAlgo::CheckPrimitiveType(return_input, prim::kPrimPartial)) { auto return_input_cnode = return_input->cast(); auto partial_inputs = return_input_cnode->inputs(); call_inputs.insert(call_inputs.end(), partial_inputs.begin() + kFirstDataInputIndex, partial_inputs.end()); - real_kernel_graph = partial_inputs[kFirstDataInputIndex]; - } else { // return node is kernel graph + } else if (IsValueNode(return_input)) { // return node is kernel graph call_inputs.emplace_back(return_input); - real_kernel_graph = return_input; + } else { // return node is value node + KernelGraphPtr kernel_graph = NewKernelGraph(); + auto valid_inputs = kernel_graph->MutableValidInputs(); + MS_EXCEPTION_IF_NULL(valid_inputs); + auto graph_inputs = kernel_graph->MutableInputs(); + MS_EXCEPTION_IF_NULL(graph_inputs); + std::vector cnode_inputs = {return_input}; + for (auto real_input : real_inputs) { + auto new_parameter = kernel_graph->NewParameter(real_input->abstract()); + valid_inputs->push_back(true); + graph_inputs->push_back(new_parameter); + cnode_inputs.push_back(new_parameter); + } + auto new_cnode = kernel_graph->NewCNode(cnode_inputs); + new_cnode->set_abstract(cnode->abstract()); + std::vector return_inputs = { + kernel_graph->NewValueNode(NewValueNode(std::make_shared(prim::kPrimReturn->name()))), new_cnode}; + auto return_node = kernel_graph->NewCNode(return_inputs); + return_node->set_abstract(cnode->abstract()); + kernel_graph->set_return(return_node); + call_inputs.push_back(std::make_shared(kernel_graph)); } // new call node inputs for (auto real_input : real_inputs) { - auto parameter_for_input = CreateNewParameterFromCNode(real_input, partial_kernel_graph.get()); + auto parameter_for_input = CreateNewParameterFromCNode(real_input, graph); call_inputs.emplace_back(parameter_for_input); } - auto call_node = partial_kernel_graph->NewCNode(call_inputs); - // update abstract - MS_EXCEPTION_IF_NULL(real_kernel_graph); - if (real_kernel_graph->isa() && IsValueNode(real_kernel_graph)) { - KernelGraphPtr sub_partial_kernel_graph = GetValueNode(real_kernel_graph); - MS_EXCEPTION_IF_NULL(sub_partial_kernel_graph); - auto ret_partial = sub_partial_kernel_graph->get_return(); - call_node->set_abstract(ret_partial->abstract()); - } + auto call_node = graph->NewCNode(call_inputs); + call_node->set_abstract(cnode->abstract()); // update return input ret->set_input(kFirstDataInputIndex, call_node); } @@ -1005,36 +1019,33 @@ std::vector SessionBasic::CreateCallSwitchLayerInputs(const CNodePtr for (size_t idx = kFirstDataInputIndex; idx < make_tuple_inputs.size(); idx++) { auto partial_idx = make_tuple_inputs[idx]; MS_EXCEPTION_IF_NULL(cnode->abstract()); + std::vector new_partial_inputs; + KernelGraphPtr partial_kernel_graph; // switch_layer node input is partial cnode if (AnfAlgo::CheckPrimitiveType(partial_idx, prim::kPrimPartial)) { auto partial_node = partial_idx->cast(); MS_EXCEPTION_IF_NULL(partial_node); - // update kernel graph when switch_layer node return function auto partial_input = partial_node->input(kFirstDataInputIndex); - KernelGraphPtr partial_kernel_graph = GetValueNode(partial_input); - MS_EXCEPTION_IF_NULL(partial_kernel_graph); - auto ret = partial_kernel_graph->get_return(); - MS_EXCEPTION_IF_NULL(ret); - auto return_input = ret->input(kFirstDataInputIndex); - if (AnfAlgo::CheckPrimitiveType(return_input, prim::kPrimPartial) || IsValueNode(return_input)) { - CreateCallNodeReturnFunction(partial_node, real_inputs); - } - - std::vector new_partial_inputs = partial_node->inputs(); - new_partial_inputs.insert(new_partial_inputs.end(), real_inputs.begin(), real_inputs.end()); - auto new_partial = graph->NewCNode(new_partial_inputs); - new_make_tuple_inputs.emplace_back(new_partial); - } - // switch_layer node input is kernel graph value node - if (IsValueNode(partial_idx)) { - // make_tuple inputs is KernelGraph - std::vector new_partial_inputs; + partial_kernel_graph = GetValueNode(partial_input); + new_partial_inputs = partial_node->inputs(); + } else if (IsValueNode(partial_idx)) { // switch_layer node input is kernel graph value node new_partial_inputs.emplace_back(NewValueNode(std::make_shared(prim::kPrimPartial->name()))); new_partial_inputs.emplace_back(partial_idx); - new_partial_inputs.insert(new_partial_inputs.end(), real_inputs.begin(), real_inputs.end()); - auto new_partial = graph->NewCNode(new_partial_inputs); - new_make_tuple_inputs.emplace_back(new_partial); - } + partial_kernel_graph = GetValueNode(partial_idx); + } + // when branch in swich_layer return function + MS_EXCEPTION_IF_NULL(partial_kernel_graph); + auto ret = partial_kernel_graph->get_return(); + MS_EXCEPTION_IF_NULL(ret); + auto return_input = ret->input(kFirstDataInputIndex); + if (AnfAlgo::CheckPrimitiveType(return_input, prim::kPrimPartial) || return_input->isa()) { + CreateCallNodeReturnFunction(cnode, partial_kernel_graph.get(), real_inputs); + } + // partial node add input args + new_partial_inputs.insert(new_partial_inputs.end(), real_inputs.begin(), real_inputs.end()); + // create new partial node + auto new_partial = graph->NewCNode(new_partial_inputs); + new_make_tuple_inputs.emplace_back(new_partial); } auto new_make_tuple = graph->NewCNode(new_make_tuple_inputs); switch_layer_inputs.emplace_back(new_make_tuple); diff --git a/mindspore/ccsrc/backend/session/session_basic.h b/mindspore/ccsrc/backend/session/session_basic.h index f446b67bf3..ea42c1db09 100644 --- a/mindspore/ccsrc/backend/session/session_basic.h +++ b/mindspore/ccsrc/backend/session/session_basic.h @@ -148,7 +148,8 @@ class SessionBasic : public std::enable_shared_from_this { void GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector *cnode_inputs, std::unordered_map *other_graph_cnode); std::vector CreateCallSwitchLayerInputs(const CNodePtr &cnode, KernelGraph *graph); - void CreateCallNodeReturnFunction(const CNodePtr &cnode, const std::vector &real_inputs); + void CreateCallNodeReturnFunction(const CNodePtr &cnode, KernelGraph *graph, + const std::vector &real_inputs); protected: friend class Executor;