diff --git a/ge/hybrid/model/hybrid_model_builder.cc b/ge/hybrid/model/hybrid_model_builder.cc index e5450375..edf9eb92 100755 --- a/ge/hybrid/model/hybrid_model_builder.cc +++ b/ge/hybrid/model/hybrid_model_builder.cc @@ -1540,14 +1540,20 @@ Status HybridModelBuilder::IdentifyVariableOutputs(NodeItem &node_item) { in_data_anchor->GetIdx(), src_node->GetName().c_str(), src_op_type.c_str()); + uint32_t parent_index = 0; + GE_CHK_STATUS_RET_NOLOG(GetParentNodeOutputIndex(*net_output_desc, in_data_anchor->GetIdx(), parent_index)); + GELOGD("Got parent output index = %u", parent_index); + if (src_op_type == DATA) { + int ref_i = 0; + (void)AttrUtils::GetInt(src_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, ref_i); + node_item.reuse_inputs.emplace(static_cast(parent_index), ref_i); + GELOGD("[%s] output[%u] resues input[%d]", node_item.NodeName().c_str(), parent_index, ref_i); + } if (src_op_type != CONSTANTOP && src_op_type != CONSTANT && src_op_type != VARIABLE) { continue; } - uint32_t parent_index = 0; - GE_CHK_STATUS_RET_NOLOG(GetParentNodeOutputIndex(*net_output_desc, in_data_anchor->GetIdx(), parent_index)); - GELOGD("Got parent output index = %u", parent_index); GE_CHECK_LE(parent_index, INT32_MAX); node_item.ref_outputs.emplace(static_cast(parent_index), src_node); if (src_op_type == CONSTANTOP || src_op_type == CONSTANT) { diff --git a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc index ffc8c972..25115340 100644 --- a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc +++ b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc @@ -39,7 +39,7 @@ #include "hybrid/common/npu_memory_allocator.h" #include "graph/types.h" #include "graph/utils/tensor_utils.h" - +#include "graph/testcase/ge_graph/graph_builder_utils.h" #undef private #undef protected @@ -173,6 +173,36 @@ TEST_F(UtestGeHybrid, parse_force_infershape_nodes) { HybridModelBuilder hybrid_model_builder(hybrid_model); ASSERT_EQ(hybrid_model_builder.ParseForceInfershapeNodes(node, *new_node), SUCCESS); } +static ComputeGraphPtr BuildDataDirectConnectGraph() { + const char *kRefIndex = "_parent_node_index"; + ge::ut::GraphBuilder builder("subgraph"); + auto data = builder.AddNode("Data", "Data", 1, 1); + auto netoutput = builder.AddNode("NetOutput", "NetOutput", 1, 1); + (void)AttrUtils::SetInt(netoutput->GetOpDesc()->MutableInputDesc(0), kRefIndex, 0); + + builder.AddDataEdge(data, 0, netoutput, 0); + return builder.GetGraph(); +} +TEST_F(UtestGeHybrid, data_direct_connect) { + std::unique_ptr node_item; + auto root_graph = make_shared("root_graph"); + OpDescPtr op_desc = CreateOpDesc("PartitionedCall", "PartitionedCall"); + auto node = root_graph->AddNode(op_desc); + node->SetOwnerComputeGraph(root_graph); + auto sub_graph = BuildDataDirectConnectGraph(); + sub_graph->SetParentGraph(root_graph); + sub_graph->SetParentNode(node); + node->GetOpDesc()->AddSubgraphName("subgraph"); + node->GetOpDesc()->SetSubgraphInstanceName(0, "subgraph"); + root_graph->AddSubgraph("subgraph", sub_graph); + std::unique_ptr new_node; + NodeItem::Create(node, new_node); + GeRootModelPtr ge_root_model = make_shared(root_graph); + HybridModel hybrid_model(ge_root_model); + HybridModelBuilder hybrid_model_builder(hybrid_model); + auto ret = hybrid_model_builder.IdentifyVariableOutputs(*new_node.get()); + ASSERT_EQ(ret, SUCCESS); +} TEST_F(UtestGeHybrid, index_taskdefs_success) { // build aicore task