Browse Source

!1885 splitsort reorder getitem

Merge pull request !1885 from kisnwang/splitsort-reorder-getitem
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
5958c4abc6
2 changed files with 61 additions and 5 deletions
  1. +3
    -3
      mindspore/ccsrc/vm/backend.cc
  2. +58
    -2
      mindspore/ccsrc/vm/transform.cc

+ 3
- 3
mindspore/ccsrc/vm/backend.cc View File

@@ -65,7 +65,7 @@ LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst, const std::stri
result.outputs = outputs; result.outputs = outputs;
result.graph_id = kInvalidGraphId; result.graph_id = kInvalidGraphId;
GraphId graph_id = kInvalidGraphId; GraphId graph_id = kInvalidGraphId;
if (target != target_device_ && target != "") {
if (target != target_device_ && !target.empty()) {
CreateOtherSession(target); CreateOtherSession(target);
graph_id = other_sess_->CompileGraph(lst, outputs); graph_id = other_sess_->CompileGraph(lst, outputs);
} else { } else {
@@ -76,7 +76,7 @@ LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst, const std::stri
MS_LOG(INFO) << "PrecompileOnly, stop run graph"; MS_LOG(INFO) << "PrecompileOnly, stop run graph";
return result; return result;
} }
if (target != target_device_ && target != "") {
if (target != target_device_ && !target.empty()) {
other_sess_->BuildGraph(graph_id); other_sess_->BuildGraph(graph_id);
} else if (!is_multi_graph_sink_) { } else if (!is_multi_graph_sink_) {
target_sess_->BuildGraph(graph_id); target_sess_->BuildGraph(graph_id);
@@ -279,7 +279,7 @@ VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const s


VectorRef outputs; VectorRef outputs;
// call ms rungraph (graphId, input ,output) // call ms rungraph (graphId, input ,output)
if (target != target_device_ && target != "") {
if (target != target_device_ && !target.empty()) {
other_sess_->RunGraph(g, inputs, &outputs); other_sess_->RunGraph(g, inputs, &outputs);
} else { } else {
target_sess_->RunGraph(g, inputs, &outputs); target_sess_->RunGraph(g, inputs, &outputs);


+ 58
- 2
mindspore/ccsrc/vm/transform.cc View File

@@ -129,6 +129,62 @@ void CalcNodeRefCount(const FuncGraphPtr &graph, std::map<AnfNodePtr, size_t> *n
} }
} }


bool IsGetItemNode(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>();
auto &inputs = cnode->inputs();
if (inputs.empty()) {
MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
}
if (!IsValueNode<Primitive>(inputs[0])) {
return true;
}
PrimitivePtr node_prim = GetValueNode<PrimitivePtr>(inputs[0]);
return node_prim->name() == prim::kPrimTupleGetItem->name();
}
return false;
}

std::vector<AnfNodePtr> ReorderGetItemNode(const std::vector<AnfNodePtr> &nodes) {
std::vector<AnfNodePtr> result;
std::map<size_t, std::vector<AnfNodePtr>> insert_positions;
std::map<AnfNodePtr, size_t> node_positions;
for (auto &node : nodes) {
if (IsGetItemNode(node)) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto &inputs = cnode->inputs();
if (inputs.size() < 2) {
MS_LOG(EXCEPTION) << "Invalid get item node";
}
auto &parent = inputs[1];
auto iter = node_positions.find(parent);
if (iter != node_positions.end()) {
size_t position = iter->second;
auto iter_nodes = insert_positions.find(position);
if (iter_nodes != insert_positions.end()) {
iter_nodes->second.push_back(node);
} else {
(void)insert_positions.insert(
std::pair<size_t, std::vector<AnfNodePtr>>(position, std::vector<AnfNodePtr>{node}));
}
continue;
}
}
result.emplace_back(node);
node_positions[node] = result.size();
}

size_t insert_num = 0;
for (auto &item : insert_positions) {
size_t position = item.first + insert_num;
(void)result.insert(result.begin() + position, item.second.begin(), item.second.end());
insert_num += item.second.size();
}
return result;
}

std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string &default_target) { std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string &default_target) {
std::vector<AnfNodePtr> result; std::vector<AnfNodePtr> result;
std::stack<AnfNodePtr> to_visit; std::stack<AnfNodePtr> to_visit;
@@ -144,8 +200,8 @@ std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string &
handle_target = next_target; handle_target = next_target;
} }
auto &node = to_visit.top(); auto &node = to_visit.top();
to_visit.pop();
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
to_visit.pop();
result.emplace_back(node); result.emplace_back(node);
if (!node->isa<CNode>()) { if (!node->isa<CNode>()) {
continue; continue;
@@ -178,7 +234,7 @@ std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string &
} }
} }
std::reverse(result.begin(), result.end()); std::reverse(result.begin(), result.end());
return result;
return ReorderGetItemNode(result);
} }
} // namespace } // namespace




Loading…
Cancel
Save