From 0fd57bd15eb272486dfa79d8569b008c99fe62f3 Mon Sep 17 00:00:00 2001 From: liubuyu Date: Wed, 10 Jun 2020 11:31:50 +0800 Subject: [PATCH] fix remove reshape pair pass --- .../format_type/deal_ref_trans_and_cast.cc | 2 +- .../ascend/ir_fusion/remove_reshape_pair.cc | 24 +++++++++---------- .../ascend/ir_fusion/remove_reshape_pair.h | 7 +----- 3 files changed, 14 insertions(+), 19 deletions(-) diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc b/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc index 43857dddfd..c1cb308338 100644 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc +++ b/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc @@ -37,7 +37,7 @@ session::KernelWithIndex FindRefOriginNode(const AnfNodePtr &node) { std::string op_name = AnfAlgo::GetCNodeName(cnode); auto op_info = mindspore::kernel::OpLib::FindOp(op_name, kernel::kTBE); // deal ref op - if (op_info->is_ref()) { + if (op_info != nullptr && op_info->is_ref()) { auto ref_infos = op_info->ref_infos(); if (ref_infos.count(cur_out_index) != 0) { auto in_index = ref_infos.at(cur_out_index); diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/remove_reshape_pair.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/remove_reshape_pair.cc index 5e265f2cf1..fa2815ff62 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/remove_reshape_pair.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/remove_reshape_pair.cc @@ -23,33 +23,33 @@ namespace mindspore { namespace opt { const BaseRef RemoveReshapePair::DefinePattern() const { - const auto prim_reshape = std::make_shared(prim::kPrimReshape->name()); - VectorRef reshape({prim_reshape, input_varptr_}); - - return VectorRef({prim::kPrimReshape, reshape}); + VarPtr X = std::make_shared(); + MS_EXCEPTION_IF_NULL(X); + return VectorRef({prim::kPrimReshape, VectorRef({prim::kPrimReshape, X})}); } const AnfNodePtr RemoveReshapePair::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv) const { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(equiv); - auto manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); auto reshape_op_1 = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputNum); MS_EXCEPTION_IF_NULL(reshape_op_1); // If reshape operator used by more than one other operators, reshape operator cant not be deleted directly - auto users = manager->node_users()[reshape_op_1]; - if (users.size() > 1) { + if (IsUsedByOthers(func_graph, reshape_op_1)) { return nullptr; } auto reshape_op_2 = CheckAnfNodeIfCNodeAndInputSize(reshape_op_1->input(1), kBackendReshapeInputNum); MS_EXCEPTION_IF_NULL(reshape_op_2); - users = manager->node_users()[reshape_op_2]; - if (users.size() > 1) { + if (IsUsedByOthers(func_graph, reshape_op_2)) { return nullptr; } - auto input_node = reshape_op_2->input(1); - return input_node; + auto output_shape = AnfAlgo::GetOutputDeviceShape(reshape_op_2, 0); + auto input_shape = AnfAlgo::GetInputDeviceShape(reshape_op_1, 0); + if (input_shape == output_shape) { + auto input_node = reshape_op_2->input(1); + return input_node; + } + return nullptr; } } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/remove_reshape_pair.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/remove_reshape_pair.h index a284f4eaa9..ddb25df70c 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/remove_reshape_pair.h +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/remove_reshape_pair.h @@ -28,15 +28,10 @@ namespace mindspore { namespace opt { class RemoveReshapePair : public PatternProcessPass { public: - explicit RemoveReshapePair(bool multigraph = true) : PatternProcessPass("remove_reshape_pair", multigraph) { - input_varptr_ = std::make_shared(); - } + explicit RemoveReshapePair(bool multigraph = true) : PatternProcessPass("remove_reshape_pair", multigraph) {} ~RemoveReshapePair() override = default; const BaseRef DefinePattern() const override; const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - private: - VarPtr input_varptr_; }; } // namespace opt } // namespace mindspore