From cfef4722859e767d032653a6acb6e176bfe06f98 Mon Sep 17 00:00:00 2001 From: hangq Date: Sun, 16 Aug 2020 16:40:00 +0800 Subject: [PATCH] add format_transform_op and permute_op fusion --- .../tools/converter/graphdef_transform.cc | 12 ++ .../legacy_optimizer/fusion/CMakeLists.txt | 1 + .../format_trans_permute_fusion_pass.cc | 137 ++++++++++++++++++ .../fusion/format_trans_permute_fusion_pass.h | 49 +++++++ 4 files changed, 199 insertions(+) create mode 100644 mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_permute_fusion_pass.cc create mode 100644 mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_permute_fusion_pass.h diff --git a/mindspore/lite/tools/converter/graphdef_transform.cc b/mindspore/lite/tools/converter/graphdef_transform.cc index 8363734624..ada89972b1 100644 --- a/mindspore/lite/tools/converter/graphdef_transform.cc +++ b/mindspore/lite/tools/converter/graphdef_transform.cc @@ -30,6 +30,7 @@ #include "tools/converter/legacy_optimizer/fusion/conv_biasadd_fusion_pass.h" // #include "tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.h" #include "tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.h" +#include "tools/converter/legacy_optimizer/fusion/format_trans_permute_fusion_pass.h" #include "tools/converter/legacy_optimizer/fusion/quant_cast_fusion_pass.h" // #include "tools/converter/legacy_optimizer/fusion/batchnorm_fold_fusion_pass.h" // @@ -96,6 +97,17 @@ void GraphDefTransform::CreateQuantizer(const converter::Flags *flags) { int GraphDefTransform::Transform(const converter::Flags &ctx) { STATUS status; + { + Optimizer fusionOptimizer; + fusionOptimizer.AddPass(new (std::nothrow) FormatTransPermuteFusionPass()); + fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); + status = fusionOptimizer.Run(graphDefT); + if (status != RET_OK && status != RET_NO_CHANGE) { + MS_LOG(ERROR) << "Run fusionOptimizer graphPasses Failed"; + return status; + } + } + // weight format trans if (ctx.formatTrans) { Optimizer weightFormatOptimizer; diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/CMakeLists.txt b/mindspore/lite/tools/converter/legacy_optimizer/fusion/CMakeLists.txt index 32aa9d4dac..95d7817f14 100755 --- a/mindspore/lite/tools/converter/legacy_optimizer/fusion/CMakeLists.txt +++ b/mindspore/lite/tools/converter/legacy_optimizer/fusion/CMakeLists.txt @@ -12,6 +12,7 @@ add_library(fusion_mid OBJECT ${CMAKE_CURRENT_SOURCE_DIR}/quant_cast_fusion_pass.cc ${CMAKE_CURRENT_SOURCE_DIR}/batchnorm_fold_fusion_pass.cc ${CMAKE_CURRENT_SOURCE_DIR}/format_trans_fusion_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/format_trans_permute_fusion_pass.cc ) target_link_libraries(fusion_mid securec) diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_permute_fusion_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_permute_fusion_pass.cc new file mode 100644 index 0000000000..81b85037ac --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_permute_fusion_pass.cc @@ -0,0 +1,137 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include "tools/converter/legacy_optimizer/fusion/format_trans_permute_fusion_pass.h" +#include "utils/log_adapter.h" +#include "securec/include/securec.h" +#include "tools/common/graph_util.h" +#include "include/errorcode.h" +#include "mindspore/lite/schema/inner/model_generated.h" + +namespace mindspore { +namespace lite { +#define kFormatTransPermuteMatchPathLen 2 + +STATUS FormatTransPermuteFusionPass::DefinePattern() { + // format trans + permute + { + auto formatTransOp = std::make_shared(); + formatTransOp->id = kFormatTransformOp; + formatTransOp->types = {PrimitiveType_Nchw2Nhwc, PrimitiveType_Nhwc2Nchw}; + auto permuteOp = std::make_shared(); + permuteOp->id = kPermuteOp; + permuteOp->types = {PrimitiveType_Permute}; + + permuteOp->left = formatTransOp; + std::unique_ptr formatTransPermuteFusionPattern(new (std::nothrow) + FusionPattern(kFormatTrans2PermuteFusionPattern)); + if (formatTransPermuteFusionPattern == nullptr) { + MS_LOG(ERROR) << "new " << kFormatTrans2PermuteFusionPattern << " failed"; + return RET_ERROR; + } + formatTransPermuteFusionPattern->AddPatternOp(formatTransOp); + formatTransPermuteFusionPattern->AddPatternOp(permuteOp); + formatTransPermuteFusionPattern->Finish(); + this->patterns.emplace_back(formatTransPermuteFusionPattern.release()); + } + // permute + format trans + { + auto formatTransOp = std::make_shared(); + formatTransOp->id = kFormatTransformOp; + formatTransOp->types = {PrimitiveType_Nchw2Nhwc, PrimitiveType_Nhwc2Nchw}; + auto permuteOp = std::make_shared(); + permuteOp->id = kPermuteOp; + permuteOp->types = {PrimitiveType_Permute}; + + formatTransOp->left = permuteOp; + std::unique_ptr permuteFormatTransFusionPattern(new (std::nothrow) + FusionPattern(kPermute2FormatTransFusionPattern)); + if (permuteFormatTransFusionPattern == nullptr) { + MS_LOG(ERROR) << "new " << kPermute2FormatTransFusionPattern << " failed"; + return RET_ERROR; + } + permuteFormatTransFusionPattern->AddPatternOp(formatTransOp); + permuteFormatTransFusionPattern->AddPatternOp(permuteOp); + permuteFormatTransFusionPattern->Finish(); + this->patterns.emplace_back(permuteFormatTransFusionPattern.release()); + } + return RET_OK; +} + +STATUS FormatTransPermuteFusionPass::Run(schema::MetaGraphT *graph) { return FusionPass::Run(graph); } + +STATUS FormatTransPermuteFusionPass::DoFusion(schema::MetaGraphT *graph, const std::string &patternName, + std::unordered_map> &matchedPath) { + MS_ASSERT(graph != nullptr); + if (matchedPath.size() != kFormatTransPermuteMatchPathLen) { + MS_LOG(ERROR) << "Format-Transform-Permute-Fusion should have " << kFormatTransPermuteMatchPathLen + << " NodeIndex in matchedPair"; + return RET_PARAM_INVALID; + } + + std::shared_ptr formatTransPath = matchedPath[kFormatTransformOp]; + std::shared_ptr permutePath = matchedPath[kPermuteOp]; + if (formatTransPath == nullptr) { + MS_LOG(ERROR) << "formatTransPath is failed to get"; + return RET_ERROR; + } + if (permutePath == nullptr) { + MS_LOG(ERROR) << "permutePath is failed to get"; + return RET_ERROR; + } + auto &formatTransNode = graph->nodes.at(formatTransPath->nodeIdx); + auto &permuteNode = graph->nodes.at(permutePath->nodeIdx); + MS_ASSERT(formatTransNode != nullptr); + MS_ASSERT(permuteNode != nullptr); + auto formatTransType = formatTransNode->primitive->value.type; + if (formatTransType != PrimitiveType_Nhwc2Nchw && formatTransType != PrimitiveType_Nchw2Nhwc) { + MS_LOG(ERROR) << "FormatTransNode should be " << EnumNamePrimitiveType(PrimitiveType_Nhwc2Nchw) << " or " + << EnumNamePrimitiveType(PrimitiveType_Nchw2Nhwc) << ", but got " + << EnumNamePrimitiveType(formatTransType); + return RET_ERROR; + } + MS_ASSERT(permuteNode->primitive != nullptr); + auto permPrimitive = permuteNode->primitive->value.AsPermute(); + MS_ASSERT(permPrimitive != nullptr); + auto perm = permPrimitive->order; + if (perm.size() != 4) { + return RET_OK; + } + std::vector nchw2nhwcPerm = {0, 2, 3, 1}; + std::vector nhwc2nchwPerm = {0, 3, 1, 2}; + if ((perm == nchw2nhwcPerm && formatTransType == PrimitiveType_Nhwc2Nchw) || + (perm == nhwc2nchwPerm && formatTransType == PrimitiveType_Nchw2Nhwc)) { + auto status = IsolateOneWayNode(graph, formatTransPath->nodeIdx); + if (status != RET_OK) { + MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << formatTransNode->name << ", error: " << status; + return status; + } + + status = IsolateOneWayNode(graph, permutePath->nodeIdx); + if (status != RET_OK) { + MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << permuteNode->name << ", error: " << status; + return status; + } + } + + return RET_OK; +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_permute_fusion_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_permute_fusion_pass.h new file mode 100644 index 0000000000..f56e94969f --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_permute_fusion_pass.h @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_FORMAT_TRANS_PERMUTE_FUSION_PASS_H +#define MINDSPORE_PREDICT_FORMAT_TRANS_PERMUTE_FUSION_PASS_H + +#include +#include +#include +#include "tools/converter/legacy_optimizer/fusion/fusion_pass.h" + +namespace mindspore { +namespace lite { +constexpr const char *kFormatTransformOp = "FormatTransOp"; +constexpr const char *kPermuteOp = "PermuteOp"; +constexpr const char *kFormatTrans2PermuteFusionPattern = "Nc2NhAndNh2NcFusionPattern"; +constexpr const char *kPermute2FormatTransFusionPattern = "Nc2NhAndNh2NcPassFusionPattern"; + +class FormatTransPermuteFusionPass : public FusionPass { + public: + FormatTransPermuteFusionPass() = default; + + ~FormatTransPermuteFusionPass() override = default; + + STATUS DefinePattern() override; + + STATUS DoFusion(schema::MetaGraphT *graph, const std::string &patternName, + std::unordered_map> &matchedPath) override; + + STATUS Run(schema::MetaGraphT *graph) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_FORMAT_TRANS_PERMUTE_FUSION_PASS_H +