From 8d2cb9bc2ab3cae22ee36e9e51b666a73dbbd548 Mon Sep 17 00:00:00 2001 From: kswang Date: Fri, 5 Jun 2020 19:19:07 +0800 Subject: [PATCH] splitsort reorder getitem --- mindspore/ccsrc/vm/backend.cc | 6 ++-- mindspore/ccsrc/vm/transform.cc | 60 +++++++++++++++++++++++++++++++-- 2 files changed, 61 insertions(+), 5 deletions(-) diff --git a/mindspore/ccsrc/vm/backend.cc b/mindspore/ccsrc/vm/backend.cc index ee4d67d5af..fe84bfbff3 100644 --- a/mindspore/ccsrc/vm/backend.cc +++ b/mindspore/ccsrc/vm/backend.cc @@ -65,7 +65,7 @@ LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst, const std::stri result.outputs = outputs; result.graph_id = kInvalidGraphId; GraphId graph_id = kInvalidGraphId; - if (target != target_device_ && target != "") { + if (target != target_device_ && !target.empty()) { CreateOtherSession(target); graph_id = other_sess_->CompileGraph(lst, outputs); } else { @@ -76,7 +76,7 @@ LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst, const std::stri MS_LOG(INFO) << "PrecompileOnly, stop run graph"; return result; } - if (target != target_device_ && target != "") { + if (target != target_device_ && !target.empty()) { other_sess_->BuildGraph(graph_id); } else if (!is_multi_graph_sink_) { target_sess_->BuildGraph(graph_id); @@ -279,7 +279,7 @@ VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const s VectorRef outputs; // call ms rungraph (graphId, input ,output) - if (target != target_device_ && target != "") { + if (target != target_device_ && !target.empty()) { other_sess_->RunGraph(g, inputs, &outputs); } else { target_sess_->RunGraph(g, inputs, &outputs); diff --git a/mindspore/ccsrc/vm/transform.cc b/mindspore/ccsrc/vm/transform.cc index 732107beb4..922fea81dd 100644 --- a/mindspore/ccsrc/vm/transform.cc +++ b/mindspore/ccsrc/vm/transform.cc @@ -129,6 +129,62 @@ void CalcNodeRefCount(const FuncGraphPtr &graph, std::map *n } } +bool IsGetItemNode(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + if (node->isa()) { + auto cnode = node->cast(); + auto &inputs = cnode->inputs(); + if (inputs.empty()) { + MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; + } + if (!IsValueNode(inputs[0])) { + return true; + } + PrimitivePtr node_prim = GetValueNode(inputs[0]); + return node_prim->name() == prim::kPrimTupleGetItem->name(); + } + return false; +} + +std::vector ReorderGetItemNode(const std::vector &nodes) { + std::vector result; + std::map> insert_positions; + std::map node_positions; + for (auto &node : nodes) { + if (IsGetItemNode(node)) { + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto &inputs = cnode->inputs(); + if (inputs.size() < 2) { + MS_LOG(EXCEPTION) << "Invalid get item node"; + } + auto &parent = inputs[1]; + auto iter = node_positions.find(parent); + if (iter != node_positions.end()) { + size_t position = iter->second; + auto iter_nodes = insert_positions.find(position); + if (iter_nodes != insert_positions.end()) { + iter_nodes->second.push_back(node); + } else { + (void)insert_positions.insert( + std::pair>(position, std::vector{node})); + } + continue; + } + } + result.emplace_back(node); + node_positions[node] = result.size(); + } + + size_t insert_num = 0; + for (auto &item : insert_positions) { + size_t position = item.first + insert_num; + (void)result.insert(result.begin() + position, item.second.begin(), item.second.end()); + insert_num += item.second.size(); + } + return result; +} + std::vector SplitSort(const FuncGraphPtr &graph, const std::string &default_target) { std::vector result; std::stack to_visit; @@ -144,8 +200,8 @@ std::vector SplitSort(const FuncGraphPtr &graph, const std::string & handle_target = next_target; } auto &node = to_visit.top(); - to_visit.pop(); MS_EXCEPTION_IF_NULL(node); + to_visit.pop(); result.emplace_back(node); if (!node->isa()) { continue; @@ -178,7 +234,7 @@ std::vector SplitSort(const FuncGraphPtr &graph, const std::string & } } std::reverse(result.begin(), result.end()); - return result; + return ReorderGetItemNode(result); } } // namespace