From 614258dc269b4e0f2aea8f2e65b1455dd26c24af Mon Sep 17 00:00:00 2001 From: Margaret_wangrui Date: Sat, 15 Aug 2020 16:46:08 +0800 Subject: [PATCH] refactor CreateNewCNode --- .../ccsrc/backend/session/kernel_graph.cc | 7 +- .../ccsrc/backend/session/session_basic.cc | 67 ++++++++++++------- .../ccsrc/backend/session/session_basic.h | 2 + 3 files changed, 51 insertions(+), 25 deletions(-) diff --git a/mindspore/ccsrc/backend/session/kernel_graph.cc b/mindspore/ccsrc/backend/session/kernel_graph.cc index 5f7c37fa18..6a9ac497ec 100644 --- a/mindspore/ccsrc/backend/session/kernel_graph.cc +++ b/mindspore/ccsrc/backend/session/kernel_graph.cc @@ -563,7 +563,12 @@ void KernelGraph::FrontBackendlMapAdd(const AnfNodePtr &front_anf, const AnfNode MS_LOG(EXCEPTION) << "Anf " << front_anf->DebugString() << " has been exist in the front_backend_anf_map_"; } if (backend_front_anf_map_.find(backend_anf) != backend_front_anf_map_.end()) { - MS_LOG(EXCEPTION) << "Kernel " << backend_anf->DebugString() << "has been exist in the backend_front_anf_map_"; + auto front_node = front_anf->cast(); + MS_EXCEPTION_IF_NULL(front_node); + auto attr_input = front_node->input(kAnfPrimitiveIndex); + if (!attr_input->isa()) { + MS_LOG(EXCEPTION) << "Kernel " << backend_anf->DebugString() << "has been exist in the backend_front_anf_map_"; + } } front_backend_anf_map_[front_anf] = backend_anf; backend_front_anf_map_[backend_anf] = front_anf; diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index 5344defef8..c8c172532f 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -606,6 +606,10 @@ std::vector SessionBasic::CreateSwitchOrPartialNode(const CNodePtr & } else if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) { auto switch_cnode = cnode_input->cast(); MS_EXCEPTION_IF_NULL(switch_cnode); + if (cnode->inputs().size() < 2) { + cnode_inputs = switch_cnode->inputs(); + return cnode_inputs; + } std::vector switch_inputs = {switch_cnode->input(kAnfPrimitiveIndex), switch_cnode->input(kFirstDataInputIndex)}; for (size_t index = kFirstBranchInSwitch; index < switch_cnode->inputs().size(); index++) { @@ -630,7 +634,7 @@ std::vector SessionBasic::CreateSwitchOrPartialNode(const CNodePtr & MS_LOG(EXCEPTION) << "CNode input[0] must be partial or switch."; } -CNodePtr SessionBasic::CreateNewCNode(CNodePtr cnode, KernelGraph *graph) { +std::vector SessionBasic::CreateValueNode(const CNodePtr &cnode, KernelGraph *graph) { MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(graph); std::vector cnode_inputs; @@ -641,7 +645,7 @@ CNodePtr SessionBasic::CreateNewCNode(CNodePtr cnode, KernelGraph *graph) { MS_EXCEPTION_IF_NULL(fg); auto new_fg = BasicClone(fg); cnode_inputs.push_back(std::make_shared(new_fg)); - } else if (IsValueNode(attr_input)) { + } else { // create primitive of cnode:call cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared(prim::kPrimCall->name())))}; // create a ValueNode as input of cnode:call @@ -653,38 +657,27 @@ CNodePtr SessionBasic::CreateNewCNode(CNodePtr cnode, KernelGraph *graph) { cnode_inputs.emplace_back(new_value_node); } } - } else if (attr_input->isa()) { - auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input); - if (cnode->inputs().size() < 2 && AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) { - auto switch_cnode = cnode_input->cast(); - MS_EXCEPTION_IF_NULL(switch_cnode); - cnode_inputs = switch_cnode->inputs(); - } else { - cnode_inputs = CreateSwitchOrPartialNode(cnode, graph); - } - } else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) { - cnode_inputs = {graph->GetBackendAnfByFrontAnf(cnode->input(kAnfPrimitiveIndex)), - graph->GetBackendAnfByFrontAnf(cnode->input(kFirstDataInputIndex))}; + } + return cnode_inputs; +} + +void SessionBasic::CreateCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector *cnode_inputs) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(graph); + if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) { + cnode_inputs->emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(kFirstDataInputIndex))); for (size_t index = kFirstBranchInSwitch; index < cnode->inputs().size(); index++) { auto node_input = cnode->input(index); auto switch_input = CreateSwitchInput(node_input, graph); - cnode_inputs.emplace_back(switch_input); + cnode_inputs->emplace_back(switch_input); } } else { - // get primitive of old node - auto prim = AnfAlgo::GetCNodePrimitive(cnode); - MS_EXCEPTION_IF_NULL(prim); - // push attr to inputs[0] of new cnode - cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared(*prim)))}; - } - - if (!AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) { for (size_t input_idx = kFirstDataInputIndex; input_idx < cnode->inputs().size(); input_idx++) { auto anf = cnode->input(input_idx); MS_EXCEPTION_IF_NULL(anf); // anf has been created before if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) { - cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf)); + cnode_inputs->emplace_back(graph->GetBackendAnfByFrontAnf(anf)); continue; } else if (IsValueNode(anf)) { continue; @@ -692,6 +685,32 @@ CNodePtr SessionBasic::CreateNewCNode(CNodePtr cnode, KernelGraph *graph) { MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]"; } } +} + +CNodePtr SessionBasic::CreateNewCNode(CNodePtr cnode, KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(graph); + std::vector cnode_inputs; + auto attr_input = cnode->input(kAnfPrimitiveIndex); + MS_EXCEPTION_IF_NULL(attr_input); + if (IsValueNode(attr_input)) { + // cnode is a graph or a call + cnode_inputs = CreateValueNode(cnode, graph); + } else if (attr_input->isa()) { + // cnode ia a call (partial/switch/switch_layer) + // 1. take the args of call to the partial node, as the real_args to call switch's or switch_layer's child graph + // 2. the call in frontend is map to the partial/switch/switch_layer in backend and haven't been created + cnode_inputs = CreateSwitchOrPartialNode(cnode, graph); + } else { + // get primitive of old node + auto prim = AnfAlgo::GetCNodePrimitive(cnode); + MS_EXCEPTION_IF_NULL(prim); + // push attr to inputs[0] of new cnode + cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared(*prim)))}; + } + // handle inputs of cnode except primitive + CreateCNodeInputs(cnode, graph, &cnode_inputs); + TraceManager::DebugTrace(std::make_shared(cnode->debug_info())); auto new_cnode = graph->NewCNode(cnode_inputs); TraceManager::EndTrace(); diff --git a/mindspore/ccsrc/backend/session/session_basic.h b/mindspore/ccsrc/backend/session/session_basic.h index d3e94473a4..b70e57452a 100644 --- a/mindspore/ccsrc/backend/session/session_basic.h +++ b/mindspore/ccsrc/backend/session/session_basic.h @@ -111,6 +111,8 @@ class SessionBasic { private: CNodePtr CreateSwitchInput(const AnfNodePtr &node_input, KernelGraph *graph); std::vector CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph); + std::vector CreateValueNode(const CNodePtr &cnode, KernelGraph *graph); + void CreateCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector *cnode_inputs); protected: virtual void SetSummaryNodes(KernelGraph *graph);