Browse Source

Fix segment input/output caculation

pull/16073/head
He Wei 4 years ago
parent
commit
98bab1eb87
3 changed files with 36 additions and 29 deletions
  1. +3
    -2
      mindspore/ccsrc/backend/session/session_basic.cc
  2. +32
    -26
      mindspore/ccsrc/vm/segment_runner.cc
  3. +1
    -1
      mindspore/core/ir/anf.cc

+ 3
- 2
mindspore/ccsrc/backend/session/session_basic.cc View File

@@ -727,7 +727,8 @@ void SessionBasic::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph,
MS_EXCEPTION_IF_NULL(other_graph_cnode); MS_EXCEPTION_IF_NULL(other_graph_cnode);
MS_EXCEPTION_IF_NULL(cnode_inputs); MS_EXCEPTION_IF_NULL(cnode_inputs);
auto origin_inputs = cnode->inputs(); auto origin_inputs = cnode->inputs();
bool optimize_depend = IsPrimitiveCNode(cnode, prim::kPrimDepend) && origin_inputs.size() >= 3;
const bool is_depend = IsPrimitiveCNode(cnode, prim::kPrimDepend);
const bool is_updatestate = IsPrimitiveCNode(cnode, prim::kPrimUpdateState);
// if has multiple depends,only select first depend as parameter // if has multiple depends,only select first depend as parameter
for (size_t input_idx = 1; input_idx < origin_inputs.size(); input_idx++) { for (size_t input_idx = 1; input_idx < origin_inputs.size(); input_idx++) {
auto anf = origin_inputs[input_idx]; auto anf = origin_inputs[input_idx];
@@ -736,7 +737,7 @@ void SessionBasic::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph,
if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) { if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
(void)cnode_inputs->emplace_back(graph->GetBackendAnfByFrontAnf(anf)); (void)cnode_inputs->emplace_back(graph->GetBackendAnfByFrontAnf(anf));
continue; continue;
} else if (optimize_depend && input_idx > 1) {
} else if ((is_depend && input_idx > 1) || (is_updatestate && input_idx > 2)) {
cnode_inputs->push_back(NewValueNode(MakeValue(SizeToInt(input_idx)))); cnode_inputs->push_back(NewValueNode(MakeValue(SizeToInt(input_idx))));
continue; continue;
} else if (other_graph_cnode->find(anf) != other_graph_cnode->end()) { } else if (other_graph_cnode->find(anf) != other_graph_cnode->end()) {


+ 32
- 26
mindspore/ccsrc/vm/segment_runner.cc View File

@@ -22,6 +22,7 @@
#include <functional> #include <functional>
#include <memory> #include <memory>
#include <set> #include <set>
#include <unordered_set>
#include <tuple> #include <tuple>
#include <unordered_map> #include <unordered_map>
#include <utility> #include <utility>
@@ -39,43 +40,39 @@ namespace compile {
ConvertCache g_ConvertCache; ConvertCache g_ConvertCache;
void ClearConvertCache() { g_ConvertCache.clear(); } void ClearConvertCache() { g_ConvertCache.clear(); }


namespace {
// Return the list of nodes whose values are required beyond this segment. // Return the list of nodes whose values are required beyond this segment.
// Arguments: // Arguments:
// lst: list of nodes (the segment)
// nodes: list of nodes in the segment
// users: dict mapping each node to its users (globally) // users: dict mapping each node to its users (globally)
// seen: set of nodes that are part of the segment // seen: set of nodes that are part of the segment
AnfNodePtrList GetOutput(const AnfNodePtrList &lst, const NodeUsersMap &users, const std::vector<AnfNodePtr> &seen) {
AnfNodePtrList GetOutput(const AnfNodePtrList &nodes, const NodeUsersMap &users,
const std::unordered_set<AnfNodePtr> &seen) {
AnfNodePtrList output; AnfNodePtrList output;
if (users.size() == 0) { if (users.size() == 0) {
return output; return output;
} }

(void)std::transform(
std::begin(lst), std::end(lst), std::back_inserter(output), [&users, &seen](AnfNodePtr n) -> AnfNodePtr {
auto usersn = users.find(n);
bool is_referred_out_of_segment = std::any_of(
std::begin(usersn->second), std::end(usersn->second), [&seen](const std::pair<AnfNodePtr, int64_t> &u) -> bool {
return std::find(std::begin(seen), std::end(seen), u.first) == std::end(seen);
});
if (n->isa<CNode>() && is_referred_out_of_segment) {
return n;
}
return nullptr;
});

// remove nullptr
for (auto it = output.begin(); it != output.end();) {
if (*it == nullptr) {
it = output.erase(it);
} else {
++it;
for (auto &node : nodes) {
if (!node->isa<CNode>()) {
continue;
}
auto iter = users.find(node);
if (iter == users.end()) {
continue;
}
auto &node_users = iter->second;
const bool has_outer_user = std::any_of(
std::begin(node_users), std::end(node_users), [&seen](const std::pair<AnfNodePtr, int64_t> &u) -> bool {
const bool is_outer_user = (seen.find(u.first) == seen.end());
return is_outer_user && !(IsPrimitiveCNode(u.first, prim::kPrimUpdateState) && u.second > 2);
});
if (has_outer_user) {
output.emplace_back(node);
} }
} }

return output; return output;
} }


namespace {
AnfNodePtr RefSubGraphNode(const FuncGraphPtr &fg, const AnfNodePtr &node, AnfNodePtrList *const inputs_ptr, AnfNodePtr RefSubGraphNode(const FuncGraphPtr &fg, const AnfNodePtr &node, AnfNodePtrList *const inputs_ptr,
AnfNodePtrToAnfNodePtrMap *eqv_ptr) { AnfNodePtrToAnfNodePtrMap *eqv_ptr) {
MS_EXCEPTION_IF_NULL(fg); MS_EXCEPTION_IF_NULL(fg);
@@ -129,6 +126,15 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr
for (size_t i = 2; i < inps.size(); ++i) { for (size_t i = 2; i < inps.size(); ++i) {
args.emplace_back(NewValueNode(MakeValue(0))); args.emplace_back(NewValueNode(MakeValue(0)));
} }
} else if (IsPrimitive(fn, prim::kPrimUpdateState)) {
args.emplace_back(RefSubGraphNode(fg, inps[1], &inputs, &eqv));
args.emplace_back(RefSubGraphNode(fg, inps[2], &inputs, &eqv));
for (size_t i = 3; i < inps.size(); ++i) {
auto &input = inps[i];
if (eqv.find(input) != eqv.end()) {
args.emplace_back(RefSubGraphNode(fg, input, &inputs, &eqv));
}
}
} else { } else {
(void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args), (void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args),
[&fg, &inputs, &eqv](const AnfNodePtr &a) { return RefSubGraphNode(fg, a, &inputs, &eqv); }); [&fg, &inputs, &eqv](const AnfNodePtr &a) { return RefSubGraphNode(fg, a, &inputs, &eqv); });
@@ -138,8 +144,8 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr
eqv[n]->set_abstract(n->abstract()); eqv[n]->set_abstract(n->abstract());
eqv[n]->set_kernel_info(n->kernel_info_ptr()); eqv[n]->set_kernel_info(n->kernel_info_ptr());
} }
std::vector<AnfNodePtr> eqv_keys;
(void)std::transform(std::begin(eqv), std::end(eqv), std::back_inserter(eqv_keys),
std::unordered_set<AnfNodePtr> eqv_keys;
(void)std::transform(std::begin(eqv), std::end(eqv), std::inserter(eqv_keys, eqv_keys.end()),
[](const std::pair<AnfNodePtr, AnfNodePtr> &elem) -> AnfNodePtr { return elem.first; }); [](const std::pair<AnfNodePtr, AnfNodePtr> &elem) -> AnfNodePtr { return elem.first; });
auto outputs = GetOutput(lst, lst[0]->func_graph()->manager()->node_users(), eqv_keys); auto outputs = GetOutput(lst, lst[0]->func_graph()->manager()->node_users(), eqv_keys);
AnfNodePtr fg_output; AnfNodePtr fg_output;


+ 1
- 1
mindspore/core/ir/anf.cc View File

@@ -465,7 +465,7 @@ std::string GetCNodeTarget(const AnfNodePtr &node) {
} }
} else if (IsPrimitiveCNode(node, prim::kPrimUpdateState)) { } else if (IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
auto &inputs = cnode->inputs(); auto &inputs = cnode->inputs();
if (inputs.size() == 3 && !IsPrimitiveCNode(inputs[2], prim::kPrimMakeTuple)) {
if (inputs.size() >= 3 && !IsPrimitiveCNode(inputs[2], prim::kPrimMakeTuple)) {
return GetCNodeTarget(inputs[2]); return GetCNodeTarget(inputs[2]);
} }
} else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) { } else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {


Loading…
Cancel
Save