Browse Source

!28278 Fix nee stack control node check.

Merge pull request !28278 from gaoyong10/runtime_second12
tags/v1.6.0
i-robot Gitee 4 years ago
parent
commit
92b0d33f8e
1 changed files with 13 additions and 7 deletions
  1. +13
    -7
      mindspore/ccsrc/runtime/framework/control_node_parser.cc

+ 13
- 7
mindspore/ccsrc/runtime/framework/control_node_parser.cc View File

@@ -562,11 +562,12 @@ std::vector<KernelWithIndex> FetchInputNodeByNode(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(real_node);
std::vector<KernelWithIndex> results;
// 2. MakeTuple.
if (AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimMakeTuple)) {
const auto &make_tuple_cnode = real_node->cast<CNodePtr>();
const auto &make_tuple_inputs = make_tuple_cnode->inputs();
for (size_t i = kMakeTupleInputStartPos; i < make_tuple_inputs.size(); ++i) {
const auto &sub_results = FetchInputNodeByNode(make_tuple_inputs[i]);
if (AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimMakeTuple) ||
AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimMakeCSRTensor)) {
const auto &cnode = real_node->cast<CNodePtr>();
const auto &inputs = cnode->inputs();
for (size_t i = kMakeTupleInputStartPos; i < inputs.size(); ++i) {
const auto &sub_results = FetchInputNodeByNode(inputs[i]);
results.insert(results.end(), sub_results.begin(), sub_results.end());
}
return results;
@@ -647,8 +648,10 @@ std::vector<KernelWithIndex> FetchInputNodeByNode(const AnfNodePtr &node) {
}

size_t output_num = AnfAlgo::GetOutputNumByAbstract(abstract);
if (output_num == 1) {
results.emplace_back(real_node, 0);
if (!abstract->isa<abstract::AbstractTuple>()) {
for (size_t i = 0; i < output_num; ++i) {
results.emplace_back(real_node, i);
}
return results;
}

@@ -1748,6 +1751,9 @@ void ControlNodeParser::ParseNeedStackControlNode(const std::vector<AnfNodePtr>
MS_EXCEPTION_IF_NULL(input_with_index.first);
// If the call node has call or recursion graph input, a stack created for the call node is required.
if (!AnfAlgo::IsCallNode(input_with_index.first)) {
if (!input_with_index.first->isa<CNode>()) {
continue;
}
const auto &graph = FetchKernelGraphByFrontNode(input_with_index.first);
if (graph == nullptr || (!IsRecursionKernelGraph(graph))) {
continue;


Loading…
Cancel
Save