|
|
|
@@ -107,6 +107,53 @@ std::vector<size_t> TransShapeToSizet(const abstract::ShapePtr &shape) { |
|
|
|
} |
|
|
|
|
|
|
|
enum ShapeType { kMaxShape, kMinShape }; |
|
|
|
|
|
|
|
void GetRealOutputRecursively(const AnfNodePtr &node, size_t output_index, |
|
|
|
std::vector<session::KernelWithIndex> *inputs) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
if (node->isa<ValueNode>() || node->isa<Parameter>()) { |
|
|
|
return inputs->push_back(std::make_pair(node, 0)); |
|
|
|
} |
|
|
|
|
|
|
|
// Skip control node |
|
|
|
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimLoad) || |
|
|
|
AnfAlgo::CheckPrimitiveType(node, prim::kPrimUpdateState)) { |
|
|
|
return GetRealOutputRecursively(node->cast<CNodePtr>()->input(kRealInputIndexInDepend), 0, inputs); |
|
|
|
} |
|
|
|
|
|
|
|
// Bypass TupleGetItem |
|
|
|
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) { |
|
|
|
auto tuple_get_item = node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(tuple_get_item); |
|
|
|
auto input = AnfAlgo::GetTupleGetItemRealInput(tuple_get_item); |
|
|
|
auto index = AnfAlgo::GetTupleGetItemOutIndex(tuple_get_item); |
|
|
|
|
|
|
|
// Conceal MakeTuple + TupleGetItem pair. |
|
|
|
if (AnfAlgo::CheckPrimitiveType(input, prim::kPrimMakeTuple)) { |
|
|
|
auto make_tuple = input->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(make_tuple); |
|
|
|
auto real_input = AnfAlgo::GetInputNode(make_tuple, index); |
|
|
|
return GetRealOutputRecursively(real_input, 0, inputs); |
|
|
|
} |
|
|
|
|
|
|
|
// Skip TupleGetItem. |
|
|
|
return GetRealOutputRecursively(input, index, inputs); |
|
|
|
} |
|
|
|
|
|
|
|
// Flatten MakeTuple inputs. |
|
|
|
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) { |
|
|
|
auto make_tuple = node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(make_tuple); |
|
|
|
size_t input_num = AnfAlgo::GetInputTensorNum(make_tuple); |
|
|
|
for (size_t input_index = 0; input_index < input_num; ++input_index) { |
|
|
|
auto input_node = AnfAlgo::GetInputNode(make_tuple, input_index); |
|
|
|
GetRealOutputRecursively(input_node, 0, inputs); |
|
|
|
} |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
return inputs->push_back(std::make_pair(node, output_index)); |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
|
|
|
|
AnfNodePtr AnfRuntimeAlgorithm::MakeMonadValueNode(const KernelGraphPtr &kg) { |
|
|
|
@@ -1956,5 +2003,13 @@ AnfNodeIndexSet AnfRuntimeAlgorithm::GetUpdateStateUsers(const FuncGraphManagerP |
|
|
|
} |
|
|
|
return update_states; |
|
|
|
} |
|
|
|
|
|
|
|
void AnfRuntimeAlgorithm::GetRealInputs(const AnfNodePtr &node, std::vector<session::KernelWithIndex> *inputs) { |
|
|
|
size_t input_num = AnfAlgo::GetInputTensorNum(node); |
|
|
|
for (size_t input_index = 0; input_index < input_num; ++input_index) { |
|
|
|
auto input_node = AnfAlgo::GetInputNode(node->cast<CNodePtr>(), input_index); |
|
|
|
GetRealOutputRecursively(input_node, 0, inputs); |
|
|
|
} |
|
|
|
} |
|
|
|
} // namespace session |
|
|
|
} // namespace mindspore |