diff --git a/ge/graph/passes/link_gen_mask_nodes_pass.cc b/ge/graph/passes/link_gen_mask_nodes_pass.cc index 14f5dfc3..e00ede45 100755 --- a/ge/graph/passes/link_gen_mask_nodes_pass.cc +++ b/ge/graph/passes/link_gen_mask_nodes_pass.cc @@ -107,6 +107,16 @@ void LinkGenMaskNodesPass::GetAllGenMaskNodes(ComputeGraphPtr graph, vectorGetInDataNodes(); if (in_data_nodes.size() > kGenMaskInputIndex) { NodePtr &gen_mask = in_data_nodes.at(kGenMaskInputIndex); + for (auto &in_data_node : in_data_nodes) { + // node gen_mask is located at different place in the fused node + if (in_data_node->GetName().find(DROPOUTGENMASK) != in_data_node->GetName().npos) { + gen_mask = in_data_node; + GELOGD("The fused node type [%s], paired with the input node name [%s].", + node->GetType().c_str(), gen_mask->GetName().c_str()); + break; + } + } + if ((gen_mask->GetOpDesc() == nullptr) || (gen_mask->GetOpDesc()->HasAttr(ATTR_NAME_STREAM_LABEL))) { continue; } diff --git a/tests/ut/ge/graph/passes/link_gen_mask_nodes_pass_unittest.cc b/tests/ut/ge/graph/passes/link_gen_mask_nodes_pass_unittest.cc index 511ddece..716cc91d 100644 --- a/tests/ut/ge/graph/passes/link_gen_mask_nodes_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/link_gen_mask_nodes_pass_unittest.cc @@ -43,7 +43,7 @@ ut::GraphBuilder Graph1Builder() { ut::GraphBuilder builder = ut::GraphBuilder("g1"); auto const1 = builder.AddNode("const1", "Const", 0, 1); auto const2 = builder.AddNode("const2", "Const", 0, 1); - auto gen_mask1 = builder.AddNode("gen_mask1", "DropOutGenMask", 2, 1); + auto gen_mask1 = builder.AddNode("gen_mask1_DropOutGenMask", "DropOutGenMask", 2, 1); auto gen_mask2 = builder.AddNode("gen_mask2", "DropOutGenMaskV3", 2, 1); auto gen_mask3 = builder.AddNode("gen_mask3", "DropOutGenMaskV3D", 2, 1); auto do_mask1 = builder.AddNode("do_mask1", "DropOutDoMask", 3, 1); @@ -106,6 +106,6 @@ TEST_F(UtestLinkGenMaskNodesPass, link_gen_mask_nodes_pass_success) { auto out_ctrl_nodes = gen_mask2->GetOutControlNodes(); EXPECT_EQ(out_ctrl_nodes.size(), 1); auto out_ctrl_node = out_ctrl_nodes.at(0); - EXPECT_EQ(out_ctrl_node->GetName(), "gen_mask1"); + EXPECT_EQ(out_ctrl_node->GetName(), "gen_mask1_DropOutGenMask"); } } // namespace ge