|
|
|
@@ -562,11 +562,12 @@ std::vector<KernelWithIndex> FetchInputNodeByNode(const AnfNodePtr &node) { |
|
|
|
MS_EXCEPTION_IF_NULL(real_node); |
|
|
|
std::vector<KernelWithIndex> results; |
|
|
|
// 2. MakeTuple. |
|
|
|
if (AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimMakeTuple)) { |
|
|
|
const auto &make_tuple_cnode = real_node->cast<CNodePtr>(); |
|
|
|
const auto &make_tuple_inputs = make_tuple_cnode->inputs(); |
|
|
|
for (size_t i = kMakeTupleInputStartPos; i < make_tuple_inputs.size(); ++i) { |
|
|
|
const auto &sub_results = FetchInputNodeByNode(make_tuple_inputs[i]); |
|
|
|
if (AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimMakeTuple) || |
|
|
|
AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimMakeCSRTensor)) { |
|
|
|
const auto &cnode = real_node->cast<CNodePtr>(); |
|
|
|
const auto &inputs = cnode->inputs(); |
|
|
|
for (size_t i = kMakeTupleInputStartPos; i < inputs.size(); ++i) { |
|
|
|
const auto &sub_results = FetchInputNodeByNode(inputs[i]); |
|
|
|
results.insert(results.end(), sub_results.begin(), sub_results.end()); |
|
|
|
} |
|
|
|
return results; |
|
|
|
@@ -647,8 +648,10 @@ std::vector<KernelWithIndex> FetchInputNodeByNode(const AnfNodePtr &node) { |
|
|
|
} |
|
|
|
|
|
|
|
size_t output_num = AnfAlgo::GetOutputNumByAbstract(abstract); |
|
|
|
if (output_num == 1) { |
|
|
|
results.emplace_back(real_node, 0); |
|
|
|
if (!abstract->isa<abstract::AbstractTuple>()) { |
|
|
|
for (size_t i = 0; i < output_num; ++i) { |
|
|
|
results.emplace_back(real_node, i); |
|
|
|
} |
|
|
|
return results; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -1748,6 +1751,9 @@ void ControlNodeParser::ParseNeedStackControlNode(const std::vector<AnfNodePtr> |
|
|
|
MS_EXCEPTION_IF_NULL(input_with_index.first); |
|
|
|
// If the call node has call or recursion graph input, a stack created for the call node is required. |
|
|
|
if (!AnfAlgo::IsCallNode(input_with_index.first)) { |
|
|
|
if (!input_with_index.first->isa<CNode>()) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
const auto &graph = FetchKernelGraphByFrontNode(input_with_index.first); |
|
|
|
if (graph == nullptr || (!IsRecursionKernelGraph(graph))) { |
|
|
|
continue; |
|
|
|
|