|
|
|
@@ -27,6 +27,34 @@ |
|
|
|
namespace mindspore { |
|
|
|
namespace opt { |
|
|
|
namespace { |
|
|
|
int64_t SplitTupleInputs(const FuncGraphPtr &graph, const AnfNodePtr &tuple_input, |
|
|
|
std::vector<AnfNodePtr> *plant_inputs) { |
|
|
|
if (!AnfAlgo::IsTupleOutput(tuple_input)) { |
|
|
|
auto abs = tuple_input->abstract(); |
|
|
|
MS_LOG(WARNING) << "The Function only split the output type is tuple type but got" << abs->ToString(); |
|
|
|
return -1; |
|
|
|
} |
|
|
|
MS_EXCEPTION_IF_NULL(plant_inputs); |
|
|
|
auto input_size = AnfAlgo::GetOutputTensorNum(tuple_input); |
|
|
|
if (tuple_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(tuple_input, prim::kPrimMakeTuple)) { |
|
|
|
auto make_tuple = tuple_input->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(make_tuple); |
|
|
|
size_t tuple_input_num = AnfAlgo::GetInputTensorNum(make_tuple); |
|
|
|
for (size_t j = 0; j < tuple_input_num; ++j) { |
|
|
|
// using for graph kernel |
|
|
|
auto dyn_input_node = AnfAlgo::GetInputNode(make_tuple, j); |
|
|
|
MS_EXCEPTION_IF_NULL(dyn_input_node); |
|
|
|
plant_inputs->emplace_back(dyn_input_node); |
|
|
|
} |
|
|
|
return input_size; |
|
|
|
} |
|
|
|
for (size_t index = 0; index < input_size; ++index) { |
|
|
|
auto dyn_input_node = CreatTupleGetItemNode(graph, tuple_input, index); |
|
|
|
plant_inputs->emplace_back(dyn_input_node); |
|
|
|
} |
|
|
|
return input_size; |
|
|
|
} |
|
|
|
|
|
|
|
void ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const CNodePtr &cnode_ptr) { |
|
|
|
MS_EXCEPTION_IF_NULL(cnode_ptr); |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
@@ -41,25 +69,8 @@ void ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const CNodePt |
|
|
|
for (size_t i = 0; i < input_num; ++i) { |
|
|
|
auto input_node = AnfAlgo::GetInputNode(cnode_ptr, i); |
|
|
|
MS_EXCEPTION_IF_NULL(input_node); |
|
|
|
if (input_node->isa<CNode>() && AnfAlgo::CheckPrimitiveType(input_node, prim::kPrimMakeTuple)) { |
|
|
|
auto input_size = AnfAlgo::GetOutputTensorNum(input_node); |
|
|
|
dyn_input_sizes.push_back(input_size); |
|
|
|
auto make_tuple = input_node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(make_tuple); |
|
|
|
size_t tuple_input_num = AnfAlgo::GetInputTensorNum(make_tuple); |
|
|
|
for (size_t j = 0; j < tuple_input_num; ++j) { |
|
|
|
auto dyn_input_node = AnfAlgo::GetInputNode(make_tuple, j); |
|
|
|
MS_EXCEPTION_IF_NULL(dyn_input_node); |
|
|
|
if (IsValueNode<tensor::Tensor>(dyn_input_node)) { |
|
|
|
auto kernel_graph = graph->cast<KernelGraphPtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_graph); |
|
|
|
auto success = kernel_graph->NewValueNode(dyn_input_node->cast<ValueNodePtr>()); |
|
|
|
if (!success) { |
|
|
|
MS_LOG(WARNING) << "Make value node failed, " << dyn_input_node->DebugString(); |
|
|
|
} |
|
|
|
} |
|
|
|
plant_inputs.push_back(dyn_input_node); |
|
|
|
} |
|
|
|
if (AnfAlgo::IsTupleOutput(input_node)) { |
|
|
|
dyn_input_sizes.emplace_back(SplitTupleInputs(graph, input_node, &plant_inputs)); |
|
|
|
} else { |
|
|
|
dyn_input_sizes.push_back(-1); |
|
|
|
plant_inputs.push_back(input_node); |
|
|
|
|