From 6cdea5a3d91514ac1951c56b1c3a03d6c7a9f7b9 Mon Sep 17 00:00:00 2001 From: dingshihao2 Date: Fri, 16 Apr 2021 19:21:06 +0800 Subject: [PATCH 1/6] fix gen_mask control-edges bug --- ge/graph/passes/link_gen_mask_nodes_pass.cc | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/ge/graph/passes/link_gen_mask_nodes_pass.cc b/ge/graph/passes/link_gen_mask_nodes_pass.cc index 14f5dfc3..27b12ffc 100755 --- a/ge/graph/passes/link_gen_mask_nodes_pass.cc +++ b/ge/graph/passes/link_gen_mask_nodes_pass.cc @@ -28,7 +28,10 @@ using std::vector; namespace ge { namespace { +<<<<<<< Updated upstream const size_t kGenMaskInputIndex = 1; +======= +>>>>>>> Stashed changes const size_t kDefaultMaxParallelNum = 1; } // namespace @@ -105,8 +108,18 @@ void LinkGenMaskNodesPass::GetAllGenMaskNodes(ComputeGraphPtr graph, vectorGetInDataNodes(); +<<<<<<< Updated upstream if (in_data_nodes.size() > kGenMaskInputIndex) { NodePtr &gen_mask = in_data_nodes.at(kGenMaskInputIndex); +======= + for (auto &in_data_node : in_data_nodes) { + // node gen_mask is located at different place in the fused node + if (in_data_node->GetName().find(DROPOUTGENMASK) == in_data_node->GetName().npos) { + continue; + } + NodePtr &gen_mask = in_data_node; + +>>>>>>> Stashed changes if ((gen_mask->GetOpDesc() == nullptr) || (gen_mask->GetOpDesc()->HasAttr(ATTR_NAME_STREAM_LABEL))) { continue; } From 1a492ec8927417d7ed6d9232a816583fa1ca3859 Mon Sep 17 00:00:00 2001 From: dingshihao2 Date: Fri, 16 Apr 2021 19:23:20 +0800 Subject: [PATCH 2/6] fix gen_mask control-edges bug --- ge/graph/passes/link_gen_mask_nodes_pass.cc | 9 --------- 1 file changed, 9 deletions(-) diff --git a/ge/graph/passes/link_gen_mask_nodes_pass.cc b/ge/graph/passes/link_gen_mask_nodes_pass.cc index 27b12ffc..2788bc43 100755 --- a/ge/graph/passes/link_gen_mask_nodes_pass.cc +++ b/ge/graph/passes/link_gen_mask_nodes_pass.cc @@ -28,10 +28,6 @@ using std::vector; namespace ge { namespace { -<<<<<<< Updated upstream -const size_t kGenMaskInputIndex = 1; -======= ->>>>>>> Stashed changes const size_t kDefaultMaxParallelNum = 1; } // namespace @@ -108,10 +104,6 @@ void LinkGenMaskNodesPass::GetAllGenMaskNodes(ComputeGraphPtr graph, vectorGetInDataNodes(); -<<<<<<< Updated upstream - if (in_data_nodes.size() > kGenMaskInputIndex) { - NodePtr &gen_mask = in_data_nodes.at(kGenMaskInputIndex); -======= for (auto &in_data_node : in_data_nodes) { // node gen_mask is located at different place in the fused node if (in_data_node->GetName().find(DROPOUTGENMASK) == in_data_node->GetName().npos) { @@ -119,7 +111,6 @@ void LinkGenMaskNodesPass::GetAllGenMaskNodes(ComputeGraphPtr graph, vector>>>>>> Stashed changes if ((gen_mask->GetOpDesc() == nullptr) || (gen_mask->GetOpDesc()->HasAttr(ATTR_NAME_STREAM_LABEL))) { continue; } From dfd18d346574172ac74a7337eba462ea42ac3080 Mon Sep 17 00:00:00 2001 From: dingshihao2 Date: Fri, 16 Apr 2021 20:16:52 +0800 Subject: [PATCH 3/6] fix gen_mask control-edges bug --- ge/graph/passes/link_gen_mask_nodes_pass.cc | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/ge/graph/passes/link_gen_mask_nodes_pass.cc b/ge/graph/passes/link_gen_mask_nodes_pass.cc index 2788bc43..8dfd447d 100755 --- a/ge/graph/passes/link_gen_mask_nodes_pass.cc +++ b/ge/graph/passes/link_gen_mask_nodes_pass.cc @@ -28,6 +28,7 @@ using std::vector; namespace ge { namespace { +const size_t kGenMaskInputIndex = 1; const size_t kDefaultMaxParallelNum = 1; } // namespace @@ -104,12 +105,14 @@ void LinkGenMaskNodesPass::GetAllGenMaskNodes(ComputeGraphPtr graph, vectorGetInDataNodes(); - for (auto &in_data_node : in_data_nodes) { - // node gen_mask is located at different place in the fused node - if (in_data_node->GetName().find(DROPOUTGENMASK) == in_data_node->GetName().npos) { - continue; + if (in_data_nodes.size() > kGenMaskInputIndex) { + NodePtr &gen_mask = in_data_nodes.at(kGenMaskInputIndex); + for (auto &in_data_node : in_data_nodes) { + // node gen_mask is located at different place in the fused node + if (in_data_node->GetName().find(DROPOUTGENMASK) != in_data_node->GetName().npos) { + gen_mask = in_data_node; + } } - NodePtr &gen_mask = in_data_node; if ((gen_mask->GetOpDesc() == nullptr) || (gen_mask->GetOpDesc()->HasAttr(ATTR_NAME_STREAM_LABEL))) { continue; From 8858e2280704ea632da5df8df079f56b9825bc8f Mon Sep 17 00:00:00 2001 From: dingshihao2 Date: Sat, 17 Apr 2021 14:30:42 +0800 Subject: [PATCH 4/6] fix gen_mask control-edges bug --- ge/graph/passes/link_gen_mask_nodes_pass.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ge/graph/passes/link_gen_mask_nodes_pass.cc b/ge/graph/passes/link_gen_mask_nodes_pass.cc index 8dfd447d..b9adc51b 100755 --- a/ge/graph/passes/link_gen_mask_nodes_pass.cc +++ b/ge/graph/passes/link_gen_mask_nodes_pass.cc @@ -111,6 +111,8 @@ void LinkGenMaskNodesPass::GetAllGenMaskNodes(ComputeGraphPtr graph, vectorGetName().find(DROPOUTGENMASK) != in_data_node->GetName().npos) { gen_mask = in_data_node; + GELOGD("The fused node type [%s], paired with input node name [%s].", node->GetType(), gen_mask->GetName()); + break; } } From 0eacf6dfcdc13fbbfc6c575376c44a2f5ad524f4 Mon Sep 17 00:00:00 2001 From: dingshihao2 Date: Sat, 17 Apr 2021 14:59:16 +0800 Subject: [PATCH 5/6] fix gen_mask control-edges bug --- ge/graph/passes/link_gen_mask_nodes_pass.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ge/graph/passes/link_gen_mask_nodes_pass.cc b/ge/graph/passes/link_gen_mask_nodes_pass.cc index b9adc51b..e00ede45 100755 --- a/ge/graph/passes/link_gen_mask_nodes_pass.cc +++ b/ge/graph/passes/link_gen_mask_nodes_pass.cc @@ -111,7 +111,8 @@ void LinkGenMaskNodesPass::GetAllGenMaskNodes(ComputeGraphPtr graph, vectorGetName().find(DROPOUTGENMASK) != in_data_node->GetName().npos) { gen_mask = in_data_node; - GELOGD("The fused node type [%s], paired with input node name [%s].", node->GetType(), gen_mask->GetName()); + GELOGD("The fused node type [%s], paired with the input node name [%s].", + node->GetType().c_str(), gen_mask->GetName().c_str()); break; } } From 2865bcff6cb69c0c6664f0690ab6e9e3ff002f33 Mon Sep 17 00:00:00 2001 From: stormchasingg <837008578@qq.com> Date: Sun, 18 Apr 2021 21:34:11 +0800 Subject: [PATCH 6/6] fix gen_mask control-edges bug --- tests/ut/ge/graph/passes/link_gen_mask_nodes_pass_unittest.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/ut/ge/graph/passes/link_gen_mask_nodes_pass_unittest.cc b/tests/ut/ge/graph/passes/link_gen_mask_nodes_pass_unittest.cc index 511ddece..716cc91d 100644 --- a/tests/ut/ge/graph/passes/link_gen_mask_nodes_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/link_gen_mask_nodes_pass_unittest.cc @@ -43,7 +43,7 @@ ut::GraphBuilder Graph1Builder() { ut::GraphBuilder builder = ut::GraphBuilder("g1"); auto const1 = builder.AddNode("const1", "Const", 0, 1); auto const2 = builder.AddNode("const2", "Const", 0, 1); - auto gen_mask1 = builder.AddNode("gen_mask1", "DropOutGenMask", 2, 1); + auto gen_mask1 = builder.AddNode("gen_mask1_DropOutGenMask", "DropOutGenMask", 2, 1); auto gen_mask2 = builder.AddNode("gen_mask2", "DropOutGenMaskV3", 2, 1); auto gen_mask3 = builder.AddNode("gen_mask3", "DropOutGenMaskV3D", 2, 1); auto do_mask1 = builder.AddNode("do_mask1", "DropOutDoMask", 3, 1); @@ -106,6 +106,6 @@ TEST_F(UtestLinkGenMaskNodesPass, link_gen_mask_nodes_pass_success) { auto out_ctrl_nodes = gen_mask2->GetOutControlNodes(); EXPECT_EQ(out_ctrl_nodes.size(), 1); auto out_ctrl_node = out_ctrl_nodes.at(0); - EXPECT_EQ(out_ctrl_node->GetName(), "gen_mask1"); + EXPECT_EQ(out_ctrl_node->GetName(), "gen_mask1_DropOutGenMask"); } } // namespace ge