From 9dd0282a1797edf60fb749fc81e06b98f752c58f Mon Sep 17 00:00:00 2001 From: VectorSL Date: Sat, 15 Aug 2020 17:14:57 +0800 Subject: [PATCH] git update relplace_bn_grad_cast --- .../optimizer/gpu/replace_addn_fusion.cc | 2 - .../optimizer/gpu/replace_bn_cast_fusion.cc | 12 ---- .../gpu/replace_bn_grad_cast_fusion.cc | 57 +++++++++---------- 3 files changed, 27 insertions(+), 44 deletions(-) diff --git a/mindspore/ccsrc/backend/optimizer/gpu/replace_addn_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/replace_addn_fusion.cc index 575a01cc24..87973a8f3c 100644 --- a/mindspore/ccsrc/backend/optimizer/gpu/replace_addn_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/gpu/replace_addn_fusion.cc @@ -36,13 +36,11 @@ const AnfNodePtr ReplaceAddNFusion::Process(const FuncGraphPtr &graph, const Anf MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(equiv); - auto A = AnfAlgo::GetInputNode(utils::cast(node), 0); auto B = AnfAlgo::GetInputNode(utils::cast(node), 1); MS_EXCEPTION_IF_NULL(A); MS_EXCEPTION_IF_NULL(B); int num_input = AnfAlgo::GetNodeAttr(node, "n"); - if (num_input == 2) { auto prim = std::make_shared(prim::kPrimTensorAdd->name()); MS_EXCEPTION_IF_NULL(prim); diff --git a/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_cast_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_cast_fusion.cc index 2d48e5b002..2483e8171a 100644 --- a/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_cast_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_cast_fusion.cc @@ -38,27 +38,16 @@ const AnfNodePtr ReplaceBNCastFusion::Process(const FuncGraphPtr &graph, const A MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(equiv); - auto fbn2 = AnfAlgo::GetInputNode(utils::cast(node), 0); auto x_after = AnfAlgo::GetInputNode(utils::cast(fbn2), 0); auto x_before = AnfAlgo::GetInputNode(utils::cast(x_after), 0); - auto scale = AnfAlgo::GetInputNode(utils::cast(fbn2), 1); - auto bias = AnfAlgo::GetInputNode(utils::cast(fbn2), 2); - auto mean = AnfAlgo::GetInputNode(utils::cast(fbn2), 3); - auto var = AnfAlgo::GetInputNode(utils::cast(fbn2), 4); - MS_EXCEPTION_IF_NULL(fbn2); MS_EXCEPTION_IF_NULL(x_after); MS_EXCEPTION_IF_NULL(x_before); - MS_EXCEPTION_IF_NULL(scale); - MS_EXCEPTION_IF_NULL(bias); - MS_EXCEPTION_IF_NULL(mean); - MS_EXCEPTION_IF_NULL(var); std::vector outputs_type; std::vector> outputs_shape; auto manager = graph->manager(); MS_EXCEPTION_IF_NULL(manager); - auto outlist = GetRealNodeUsedList(graph, fbn2); for (size_t i = 0; i < outlist->size(); i++) { auto index_node = AnfAlgo::GetInputNode(utils::cast(outlist->at(i).first), 1); @@ -76,7 +65,6 @@ const AnfNodePtr ReplaceBNCastFusion::Process(const FuncGraphPtr &graph, const A AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, outlist->at(i).first.get()); } } - manager->Replace(utils::cast(x_after), utils::cast(x_before)); outputs_type.clear(); outputs_shape.clear(); diff --git a/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_grad_cast_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_grad_cast_fusion.cc index 4e1be81ab7..eb78e7280f 100644 --- a/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_grad_cast_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_grad_cast_fusion.cc @@ -33,6 +33,30 @@ const BaseRef ReplaceBNGradCastFusion::DefinePattern() const { return tupleget; } +const void HandleOutput(const FuncGraphPtr &graph, const mindspore::CNodePtr &kernel) { + auto outlist = GetRealNodeUsedList(graph, kernel); + auto manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + for (size_t j = 0; j < outlist->size(); j++) { + std::vector outputs_type; + std::vector> outputs_shape; + auto index_node = AnfAlgo::GetInputNode(utils::cast(outlist->at(j).first), 1); + auto value_node = index_node->cast(); + MS_EXCEPTION_IF_NULL(value_node); + int item_idx = GetValue(value_node->value()); + if (item_idx == 0) { + auto cast = GetRealNodeUsedList(graph, outlist->at(j).first); + if (AnfAlgo::GetCNodeName(cast->at(0).first) != "Cast") { + continue; + } + manager->Replace(utils::cast(cast->at(0).first), utils::cast(outlist->at(j).first)); + outputs_type.push_back(kNumberTypeFloat16); + outputs_shape.push_back(AnfAlgo::GetOutputInferShape(outlist->at(j).first, 0)); + AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, outlist->at(j).first.get()); + } + } +} + const AnfNodePtr ReplaceBNGradCastFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &equiv) const { MS_EXCEPTION_IF_NULL(graph); @@ -40,26 +64,17 @@ const AnfNodePtr ReplaceBNGradCastFusion::Process(const FuncGraphPtr &graph, con MS_EXCEPTION_IF_NULL(equiv); auto fbn2g = AnfAlgo::GetInputNode(utils::cast(node), 0); - auto dy_after = AnfAlgo::GetInputNode(utils::cast(fbn2g), 0); auto dy_before = AnfAlgo::GetInputNode(utils::cast(dy_after), 0); auto x_ = AnfAlgo::GetInputNode(utils::cast(fbn2g), 1); - auto x_type = AnfAlgo::GetOutputInferDataType(x_, 0); + MS_EXCEPTION_IF_NULL(x_); // if x_type is fp32, the cast is necessary. - if (x_type == kNumberTypeFloat32) { + if (AnfAlgo::GetOutputInferDataType(x_, 0) == kNumberTypeFloat32) { return nullptr; } - auto scale = AnfAlgo::GetInputNode(utils::cast(fbn2g), 2); - auto mean = AnfAlgo::GetInputNode(utils::cast(fbn2g), 3); - auto var = AnfAlgo::GetInputNode(utils::cast(fbn2g), 4); - MS_EXCEPTION_IF_NULL(fbn2g); MS_EXCEPTION_IF_NULL(dy_after); MS_EXCEPTION_IF_NULL(dy_before); - MS_EXCEPTION_IF_NULL(scale); - MS_EXCEPTION_IF_NULL(x_); - MS_EXCEPTION_IF_NULL(mean); - MS_EXCEPTION_IF_NULL(var); std::vector outputs_type; std::vector> outputs_shape; auto manager = graph->manager(); @@ -83,25 +98,7 @@ const AnfNodePtr ReplaceBNGradCastFusion::Process(const FuncGraphPtr &graph, con AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, kernel.get()); } // 3. handle the output of fusedbatchnormgrad: tuplegetitem - auto outlist = GetRealNodeUsedList(graph, kernel); - for (size_t j = 0; j < outlist->size(); j++) { - outputs_type.clear(); - outputs_shape.clear(); - auto index_node = AnfAlgo::GetInputNode(utils::cast(outlist->at(j).first), 1); - auto value_node = index_node->cast(); - MS_EXCEPTION_IF_NULL(value_node); - int item_idx = GetValue(value_node->value()); - if (item_idx == 0) { - auto cast = GetRealNodeUsedList(graph, outlist->at(j).first); - if (AnfAlgo::GetCNodeName(cast->at(0).first) != "Cast") { - continue; - } - manager->Replace(utils::cast(cast->at(0).first), utils::cast(outlist->at(j).first)); - outputs_type.push_back(kNumberTypeFloat16); - outputs_shape.push_back(AnfAlgo::GetOutputInferShape(outlist->at(j).first, 0)); - AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, outlist->at(j).first.get()); - } - } + HandleOutput(graph, kernel); } manager->Replace(utils::cast(dy_after), utils::cast(dy_before)); return node;