From 367a31fa040a9f5311f3b9a03ec889d3fe58c730 Mon Sep 17 00:00:00 2001 From: kswang Date: Wed, 2 Dec 2020 10:52:03 +0800 Subject: [PATCH] split sort visit switch partial first --- mindspore/ccsrc/backend/session/session_basic.cc | 4 ++-- mindspore/ccsrc/vm/graph_partition.cc | 4 +++- mindspore/ccsrc/vm/segment_runner.cc | 6 ++++-- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index 11b72e49d9..2d284efe78 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -605,7 +605,7 @@ void SessionBasic::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, MS_EXCEPTION_IF_NULL(other_graph_cnode); MS_EXCEPTION_IF_NULL(cnode_inputs); auto origin_inputs = cnode->inputs(); - bool optimize_depend = IsPrimitiveCNode(cnode, prim::kPrimDepend) && origin_inputs.size() == 3; + bool optimize_depend = IsPrimitiveCNode(cnode, prim::kPrimDepend) && origin_inputs.size() >= 3; bool optimize_control_depend = IsPrimitiveCNode(cnode, prim::kPrimControlDepend) && origin_inputs.size() == 3; // if has multiple depends,only select first depend as parameter for (size_t input_idx = 1; input_idx < origin_inputs.size(); input_idx++) { @@ -615,7 +615,7 @@ void SessionBasic::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) { cnode_inputs->emplace_back(graph->GetBackendAnfByFrontAnf(anf)); continue; - } else if (optimize_depend && input_idx == kDependAttachNodeIndex) { + } else if (optimize_depend && input_idx > 1) { cnode_inputs->push_back(NewValueNode(MakeValue(SizeToInt(input_idx)))); continue; } else if (other_graph_cnode->find(anf) != other_graph_cnode->end()) { diff --git a/mindspore/ccsrc/vm/graph_partition.cc b/mindspore/ccsrc/vm/graph_partition.cc index 0eac22774e..74fa47d9e1 100644 --- a/mindspore/ccsrc/vm/graph_partition.cc +++ b/mindspore/ccsrc/vm/graph_partition.cc @@ -214,7 +214,9 @@ std::vector SplitSort(const FuncGraphPtr &graph, const std::string & auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); auto node_inputs = cnode->inputs(); - std::reverse(node_inputs.begin(), node_inputs.end()); + if (!IsPrimitiveCNode(cnode, prim::kPrimSwitch)) { + std::reverse(node_inputs.begin(), node_inputs.end()); + } auto ctrl_inputs = control_edges.find(node); if (ctrl_inputs != control_edges.end()) { node_inputs.insert(node_inputs.end(), ctrl_inputs->second.begin(), ctrl_inputs->second.end()); diff --git a/mindspore/ccsrc/vm/segment_runner.cc b/mindspore/ccsrc/vm/segment_runner.cc index af49da3a51..b79d67cc9f 100644 --- a/mindspore/ccsrc/vm/segment_runner.cc +++ b/mindspore/ccsrc/vm/segment_runner.cc @@ -139,9 +139,11 @@ std::tuple TransformSegmentToAnfGr } auto fn = inps[0]; std::vector args{fn}; - if (IsPrimitive(fn, prim::kPrimDepend) && inps.size() == 3 && eqv.find(inps[kDependAttachNodeIndex]) == eqv.end()) { + if (IsPrimitive(fn, prim::kPrimDepend) && inps.size() >= 3 && eqv.find(inps[kDependAttachNodeIndex]) == eqv.end()) { args.emplace_back(RefSubGraphNode(fg, inps[kRealInputIndexInDepend], &inputs, &eqv)); - args.emplace_back(NewValueNode(MakeValue(0))); + for (size_t i = 2; i < inps.size(); ++i) { + args.emplace_back(NewValueNode(MakeValue(0))); + } } else if (IsPrimitive(fn, prim::kPrimControlDepend) && inps.size() == 3) { for (size_t i = 1; i < inps.size(); ++i) { if (inps[i]->isa() && std::find(lst.begin(), lst.end(), inps[i]) == lst.end()) {