Merge pull request !5666 from limingqi107/master (tags/v1.0.0)
| @@ -0,0 +1,59 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "backend/optimizer/gpu/remove_redundant_format_transform.h" | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "backend/session/anf_runtime_algorithm.h" | |||||
| #include "ir/primitive.h" | |||||
| #include "utils/utils.h" | |||||
| #include "backend/optimizer/common/helper.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| const BaseRef RemoveRedundantFormatTransform::DefinePattern() const { | |||||
| VarPtr X = std::make_shared<Var>(); | |||||
| MS_EXCEPTION_IF_NULL(X); | |||||
| VectorRef transpose = VectorRef({prim::kPrimTranspose, X}); | |||||
| return transpose; | |||||
| } | |||||
| const AnfNodePtr RemoveRedundantFormatTransform::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, | |||||
| const EquivPtr &equiv) const { | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| MS_EXCEPTION_IF_NULL(equiv); | |||||
| MS_LOG(DEBUG) << "Process node:" << node->fullname_with_scope(); | |||||
| auto input_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0); | |||||
| MS_EXCEPTION_IF_NULL(input_node); | |||||
| AnfNodePtr first_transpose = nullptr; | |||||
| auto used_node_list = GetRealNodeUsedList(graph, input_node); | |||||
| for (size_t j = 0; j < used_node_list->size(); j++) { | |||||
| auto used_node = used_node_list->at(j).first; | |||||
| if (AnfAlgo::GetCNodeName(used_node) == prim::kPrimTranspose->name()) { | |||||
| first_transpose = used_node; | |||||
| break; | |||||
| } | |||||
| } | |||||
| auto first_transpose_perm = AnfAlgo::GetNodeAttr<std::vector<int>>(first_transpose, "perm"); | |||||
| auto node_perm = AnfAlgo::GetNodeAttr<std::vector<int>>(node, "perm"); | |||||
| if ((first_transpose != node) && (first_transpose_perm == node_perm)) { | |||||
| return first_transpose; | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,34 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REMOVE_REDUNDANT_FORMAT_TRANSFORM_H_ | |||||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REMOVE_REDUNDANT_FORMAT_TRANSFORM_H_ | |||||
| #include <memory> | |||||
| #include "backend/optimizer/common/optimizer.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| class RemoveRedundantFormatTransform : public PatternProcessPass { | |||||
| public: | |||||
| explicit RemoveRedundantFormatTransform(bool multigraph = true) | |||||
| : PatternProcessPass("remove_redundant_format_transform", multigraph) {} | |||||
| ~RemoveRedundantFormatTransform() override = default; | |||||
| const BaseRef DefinePattern() const override; | |||||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||||
| }; | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REMOVE_REDUNDANT_FORMAT_TRANSFORM_H_ | |||||
| @@ -35,6 +35,7 @@ | |||||
| #include "backend/optimizer/gpu/replace_addn_fusion.h" | #include "backend/optimizer/gpu/replace_addn_fusion.h" | ||||
| #include "backend/optimizer/gpu/insert_format_transform_op.h" | #include "backend/optimizer/gpu/insert_format_transform_op.h" | ||||
| #include "backend/optimizer/gpu/remove_format_transform_pair.h" | #include "backend/optimizer/gpu/remove_format_transform_pair.h" | ||||
| #include "backend/optimizer/gpu/remove_redundant_format_transform.h" | |||||
| #include "runtime/device/kernel_runtime_manager.h" | #include "runtime/device/kernel_runtime_manager.h" | ||||
| #include "utils/ms_utils.h" | #include "utils/ms_utils.h" | ||||
| #include "common/trans.h" | #include "common/trans.h" | ||||
| @@ -91,6 +92,7 @@ void GPUSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_gra | |||||
| auto pm = std::make_shared<opt::PassManager>(); | auto pm = std::make_shared<opt::PassManager>(); | ||||
| pm->AddPass(std::make_shared<opt::InsertFormatTransformOp>()); | pm->AddPass(std::make_shared<opt::InsertFormatTransformOp>()); | ||||
| pm->AddPass(std::make_shared<opt::RemoveFormatTransformPair>()); | pm->AddPass(std::make_shared<opt::RemoveFormatTransformPair>()); | ||||
| pm->AddPass(std::make_shared<opt::RemoveRedundantFormatTransform>()); | |||||
| pm->AddPass(std::make_shared<opt::AllReduceFusion>()); | pm->AddPass(std::make_shared<opt::AllReduceFusion>()); | ||||
| pm->AddPass(std::make_shared<opt::GetitemTuple>()); | pm->AddPass(std::make_shared<opt::GetitemTuple>()); | ||||
| optimizer->AddPassManager(pm); | optimizer->AddPassManager(pm); | ||||