|
|
|
@@ -31,8 +31,9 @@ session::KernelWithIndex FindRefOriginNode(const AnfNodePtr &node) { |
|
|
|
session::KernelWithIndex kernel_with_index = AnfAlgo::VisitKernel(node, 0); |
|
|
|
AnfNodePtr cur_node = kernel_with_index.first; |
|
|
|
size_t cur_out_index = kernel_with_index.second; |
|
|
|
MS_EXCEPTION_IF_NULL(cur_node); |
|
|
|
if (cur_node->isa<CNode>()) { |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
auto cnode = cur_node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
std::string op_name = AnfAlgo::GetCNodeName(cnode); |
|
|
|
auto op_info = mindspore::kernel::OpLib::FindOp(op_name, kernel::kTBE); |
|
|
|
@@ -88,7 +89,7 @@ AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodeP |
|
|
|
size_t input_index, const AnfNodePtr &get_item) { |
|
|
|
AnfNodePtr final_node = (get_item == nullptr ? cnode : get_item); |
|
|
|
size_t final_index = output_index; |
|
|
|
AnfNodePtr input_node = cnode->input(input_index + 1); |
|
|
|
AnfNodePtr input_node = AnfAlgo::GetInputNode(cnode, input_index); |
|
|
|
session::KernelWithIndex origin_pair; |
|
|
|
origin_pair = FindRefOriginNode(input_node); |
|
|
|
MS_EXCEPTION_IF_NULL(origin_pair.first); |
|
|
|
@@ -133,6 +134,7 @@ AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodeP |
|
|
|
} |
|
|
|
AnfNodePtr DealRefForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, |
|
|
|
const std::shared_ptr<kernel::OpInfo> &op_info) { |
|
|
|
MS_EXCEPTION_IF_NULL(op_info); |
|
|
|
auto ref_infos = op_info->ref_infos(); |
|
|
|
std::vector<AnfNodePtr> make_tuple_inputs; |
|
|
|
AbstractBasePtrList abstract_list; |
|
|
|
@@ -144,9 +146,11 @@ AnfNodePtr DealRefForMultipleOutput(const FuncGraphPtr &func_graph, const CNodeP |
|
|
|
auto input_index = ref_infos.at(output_index); |
|
|
|
final_node = AddAdditionalToRefOutput(func_graph, cnode, output_index, input_index, final_node); |
|
|
|
} |
|
|
|
MS_EXCEPTION_IF_NULL(final_node); |
|
|
|
abstract_list.push_back(final_node->abstract()); |
|
|
|
make_tuple_inputs.push_back(final_node); |
|
|
|
} |
|
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
|
AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs); |
|
|
|
MS_EXCEPTION_IF_NULL(make_tuple); |
|
|
|
make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list)); |
|
|
|
@@ -155,6 +159,8 @@ AnfNodePtr DealRefForMultipleOutput(const FuncGraphPtr &func_graph, const CNodeP |
|
|
|
|
|
|
|
AnfNodePtr DealRefSigleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, |
|
|
|
const std::shared_ptr<kernel::OpInfo> &op_info) { |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
MS_EXCEPTION_IF_NULL(op_info); |
|
|
|
auto ref_infos = op_info->ref_infos(); |
|
|
|
for (const auto &ref_info : ref_infos) { |
|
|
|
if (ref_info.second > cnode->inputs().size()) { |
|
|
|
@@ -206,7 +212,9 @@ const AnfNodePtr DealRefTransAndCast::Process(const FuncGraphPtr &graph, const A |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
if (op_info->is_ref()) { |
|
|
|
if (!cnode->Type()->isa<Tuple>()) { |
|
|
|
auto type = cnode->Type(); |
|
|
|
MS_EXCEPTION_IF_NULL(type); |
|
|
|
if (!type->isa<Tuple>()) { |
|
|
|
return DealRefSigleOutput(graph, cnode, op_info); |
|
|
|
} else { |
|
|
|
return DealRefForMultipleOutput(graph, cnode, op_info); |
|
|
|
|