Merge pull request !459 from liubuyu/mastertags/v0.2.0-alpha
| @@ -45,6 +45,7 @@ | |||
| #include "pre_activate/ascend/ir_fusion/mul_add_fusion.h" | |||
| #include "pre_activate/ascend/ir_fusion/mul_addn_fusion.h" | |||
| #include "pre_activate/ascend/ir_fusion/matmul_biasadd_fusion.h" | |||
| #include "pre_activate/ascend/ir_fusion/remove_reshape_pair.h" | |||
| #include "pre_activate/ascend/format_type/insert_trans_op.h" | |||
| #include "pre_activate/pass/getitem_tuple.h" | |||
| #include "pre_activate/pass/optimize_dependence.h" | |||
| @@ -113,6 +114,7 @@ void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph) | |||
| data_layout_pm->AddPass(std::make_shared<InsertTransOp>()); | |||
| data_layout_pm->AddPass(std::make_shared<GetitemTuple>()); | |||
| data_layout_pm->AddPass(std::make_shared<CommonSubexpressionElimination>()); | |||
| data_layout_pm->AddPass(std::make_shared<RemoveReshapePair>()); | |||
| data_layout_pm->AddPass(std::make_shared<EliminateRedundantOp>()); | |||
| data_layout_pm->AddPass(std::make_shared<OptimizeDependence>()); | |||
| data_layout_pm->AddPass(std::make_shared<TransDataSplit>()); | |||
| @@ -0,0 +1,55 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "pre_activate/ascend/ir_fusion/remove_reshape_pair.h" | |||
| #include <memory> | |||
| #include "session/anf_runtime_algorithm.h" | |||
| #include "utils/utils.h" | |||
| #include "operator/ops.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| const BaseRef RemoveReshapePair::DefinePattern() const { | |||
| const auto prim_reshape = std::make_shared<Primitive>(prim::kPrimReshape->name()); | |||
| VectorRef reshape({prim_reshape, input_varptr_}); | |||
| return VectorRef({prim::kPrimReshape, reshape}); | |||
| } | |||
| const AnfNodePtr RemoveReshapePair::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const EquivPtr &equiv) const { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(equiv); | |||
| auto manager = func_graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| auto reshape_op_1 = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputNum); | |||
| MS_EXCEPTION_IF_NULL(reshape_op_1); | |||
| // If reshape operator used by more than one other operators, reshape operator cant not be deleted directly | |||
| auto users = manager->node_users()[reshape_op_1]; | |||
| if (users.size() > 1) { | |||
| return nullptr; | |||
| } | |||
| auto reshape_op_2 = CheckAnfNodeIfCNodeAndInputSize(reshape_op_1->input(1), kBackendReshapeInputNum); | |||
| MS_EXCEPTION_IF_NULL(reshape_op_2); | |||
| users = manager->node_users()[reshape_op_2]; | |||
| if (users.size() > 1) { | |||
| return nullptr; | |||
| } | |||
| auto input_node = reshape_op_2->input(1); | |||
| return input_node; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,43 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_REMOVE_RESHAPE_PAIR_H_ | |||
| #define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_REMOVE_RESHAPE_PAIR_H_ | |||
| #include <utility> | |||
| #include <memory> | |||
| #include "ir/anf.h" | |||
| #include "pre_activate/common/pattern_engine.h" | |||
| #include "pre_activate/common/helper.h" | |||
| #include "pre_activate/common/optimizer.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class RemoveReshapePair : public PatternProcessPass { | |||
| public: | |||
| explicit RemoveReshapePair(bool multigraph = true) : PatternProcessPass("remove_reshape_pair", multigraph) { | |||
| input_varptr_ = std::make_shared<Var>(); | |||
| } | |||
| ~RemoveReshapePair() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| private: | |||
| VarPtr input_varptr_; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_REMOVE_RESHAPE_PAIR_H_ | |||