Merge pull request !459 from liubuyu/mastertags/v0.2.0-alpha
| @@ -45,6 +45,7 @@ | |||||
| #include "pre_activate/ascend/ir_fusion/mul_add_fusion.h" | #include "pre_activate/ascend/ir_fusion/mul_add_fusion.h" | ||||
| #include "pre_activate/ascend/ir_fusion/mul_addn_fusion.h" | #include "pre_activate/ascend/ir_fusion/mul_addn_fusion.h" | ||||
| #include "pre_activate/ascend/ir_fusion/matmul_biasadd_fusion.h" | #include "pre_activate/ascend/ir_fusion/matmul_biasadd_fusion.h" | ||||
| #include "pre_activate/ascend/ir_fusion/remove_reshape_pair.h" | |||||
| #include "pre_activate/ascend/format_type/insert_trans_op.h" | #include "pre_activate/ascend/format_type/insert_trans_op.h" | ||||
| #include "pre_activate/pass/getitem_tuple.h" | #include "pre_activate/pass/getitem_tuple.h" | ||||
| #include "pre_activate/pass/optimize_dependence.h" | #include "pre_activate/pass/optimize_dependence.h" | ||||
| @@ -113,6 +114,7 @@ void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph) | |||||
| data_layout_pm->AddPass(std::make_shared<InsertTransOp>()); | data_layout_pm->AddPass(std::make_shared<InsertTransOp>()); | ||||
| data_layout_pm->AddPass(std::make_shared<GetitemTuple>()); | data_layout_pm->AddPass(std::make_shared<GetitemTuple>()); | ||||
| data_layout_pm->AddPass(std::make_shared<CommonSubexpressionElimination>()); | data_layout_pm->AddPass(std::make_shared<CommonSubexpressionElimination>()); | ||||
| data_layout_pm->AddPass(std::make_shared<RemoveReshapePair>()); | |||||
| data_layout_pm->AddPass(std::make_shared<EliminateRedundantOp>()); | data_layout_pm->AddPass(std::make_shared<EliminateRedundantOp>()); | ||||
| data_layout_pm->AddPass(std::make_shared<OptimizeDependence>()); | data_layout_pm->AddPass(std::make_shared<OptimizeDependence>()); | ||||
| data_layout_pm->AddPass(std::make_shared<TransDataSplit>()); | data_layout_pm->AddPass(std::make_shared<TransDataSplit>()); | ||||
| @@ -0,0 +1,55 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "pre_activate/ascend/ir_fusion/remove_reshape_pair.h" | |||||
| #include <memory> | |||||
| #include "session/anf_runtime_algorithm.h" | |||||
| #include "utils/utils.h" | |||||
| #include "operator/ops.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| const BaseRef RemoveReshapePair::DefinePattern() const { | |||||
| const auto prim_reshape = std::make_shared<Primitive>(prim::kPrimReshape->name()); | |||||
| VectorRef reshape({prim_reshape, input_varptr_}); | |||||
| return VectorRef({prim::kPrimReshape, reshape}); | |||||
| } | |||||
| const AnfNodePtr RemoveReshapePair::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||||
| const EquivPtr &equiv) const { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | |||||
| MS_EXCEPTION_IF_NULL(equiv); | |||||
| auto manager = func_graph->manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| auto reshape_op_1 = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputNum); | |||||
| MS_EXCEPTION_IF_NULL(reshape_op_1); | |||||
| // If reshape operator used by more than one other operators, reshape operator cant not be deleted directly | |||||
| auto users = manager->node_users()[reshape_op_1]; | |||||
| if (users.size() > 1) { | |||||
| return nullptr; | |||||
| } | |||||
| auto reshape_op_2 = CheckAnfNodeIfCNodeAndInputSize(reshape_op_1->input(1), kBackendReshapeInputNum); | |||||
| MS_EXCEPTION_IF_NULL(reshape_op_2); | |||||
| users = manager->node_users()[reshape_op_2]; | |||||
| if (users.size() > 1) { | |||||
| return nullptr; | |||||
| } | |||||
| auto input_node = reshape_op_2->input(1); | |||||
| return input_node; | |||||
| } | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,43 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_REMOVE_RESHAPE_PAIR_H_ | |||||
| #define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_REMOVE_RESHAPE_PAIR_H_ | |||||
| #include <utility> | |||||
| #include <memory> | |||||
| #include "ir/anf.h" | |||||
| #include "pre_activate/common/pattern_engine.h" | |||||
| #include "pre_activate/common/helper.h" | |||||
| #include "pre_activate/common/optimizer.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| class RemoveReshapePair : public PatternProcessPass { | |||||
| public: | |||||
| explicit RemoveReshapePair(bool multigraph = true) : PatternProcessPass("remove_reshape_pair", multigraph) { | |||||
| input_varptr_ = std::make_shared<Var>(); | |||||
| } | |||||
| ~RemoveReshapePair() override = default; | |||||
| const BaseRef DefinePattern() const override; | |||||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||||
| private: | |||||
| VarPtr input_varptr_; | |||||
| }; | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_REMOVE_RESHAPE_PAIR_H_ | |||||