diff --git a/mindspore/lite/tools/converter/graphdef_transform.cc b/mindspore/lite/tools/converter/graphdef_transform.cc index ada89972b1..312bbfdb9c 100644 --- a/mindspore/lite/tools/converter/graphdef_transform.cc +++ b/mindspore/lite/tools/converter/graphdef_transform.cc @@ -30,7 +30,7 @@ #include "tools/converter/legacy_optimizer/fusion/conv_biasadd_fusion_pass.h" // #include "tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.h" #include "tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.h" -#include "tools/converter/legacy_optimizer/fusion/format_trans_permute_fusion_pass.h" +#include "tools/converter/legacy_optimizer/fusion/format_trans_transpose_fusion_pass.h" #include "tools/converter/legacy_optimizer/fusion/quant_cast_fusion_pass.h" // #include "tools/converter/legacy_optimizer/fusion/batchnorm_fold_fusion_pass.h" // diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/CMakeLists.txt b/mindspore/lite/tools/converter/legacy_optimizer/fusion/CMakeLists.txt index 95d7817f14..3706eccda5 100755 --- a/mindspore/lite/tools/converter/legacy_optimizer/fusion/CMakeLists.txt +++ b/mindspore/lite/tools/converter/legacy_optimizer/fusion/CMakeLists.txt @@ -12,7 +12,7 @@ add_library(fusion_mid OBJECT ${CMAKE_CURRENT_SOURCE_DIR}/quant_cast_fusion_pass.cc ${CMAKE_CURRENT_SOURCE_DIR}/batchnorm_fold_fusion_pass.cc ${CMAKE_CURRENT_SOURCE_DIR}/format_trans_fusion_pass.cc - ${CMAKE_CURRENT_SOURCE_DIR}/format_trans_permute_fusion_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/format_trans_transpose_fusion_pass.cc ) target_link_libraries(fusion_mid securec) diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_permute_fusion_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_transpose_fusion_pass.cc similarity index 57% rename from mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_permute_fusion_pass.cc rename to mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_transpose_fusion_pass.cc index 81b85037ac..9a81e60dce 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_permute_fusion_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_transpose_fusion_pass.cc @@ -18,7 +18,7 @@ #include #include #include -#include "tools/converter/legacy_optimizer/fusion/format_trans_permute_fusion_pass.h" +#include "tools/converter/legacy_optimizer/fusion/format_trans_transpose_fusion_pass.h" #include "utils/log_adapter.h" #include "securec/include/securec.h" #include "tools/common/graph_util.h" @@ -27,7 +27,7 @@ namespace mindspore { namespace lite { -#define kFormatTransPermuteMatchPathLen 2 +#define kFormatTransTransposeMatchPathLen 2 STATUS FormatTransPermuteFusionPass::DefinePattern() { // format trans + permute @@ -35,42 +35,42 @@ STATUS FormatTransPermuteFusionPass::DefinePattern() { auto formatTransOp = std::make_shared(); formatTransOp->id = kFormatTransformOp; formatTransOp->types = {PrimitiveType_Nchw2Nhwc, PrimitiveType_Nhwc2Nchw}; - auto permuteOp = std::make_shared(); - permuteOp->id = kPermuteOp; - permuteOp->types = {PrimitiveType_Permute}; + auto transposeOp = std::make_shared(); + transposeOp->id = kPermuteOp; + transposeOp->types = {PrimitiveType_Transpose}; - permuteOp->left = formatTransOp; - std::unique_ptr formatTransPermuteFusionPattern(new (std::nothrow) - FusionPattern(kFormatTrans2PermuteFusionPattern)); - if (formatTransPermuteFusionPattern == nullptr) { - MS_LOG(ERROR) << "new " << kFormatTrans2PermuteFusionPattern << " failed"; + transposeOp->left = formatTransOp; + std::unique_ptr formatTransTransposeFusionPattern( + new (std::nothrow) FusionPattern(kFormatTrans2TransposeFusionPattern)); + if (formatTransTransposeFusionPattern == nullptr) { + MS_LOG(ERROR) << "new " << kFormatTrans2TransposeFusionPattern << " failed"; return RET_ERROR; } - formatTransPermuteFusionPattern->AddPatternOp(formatTransOp); - formatTransPermuteFusionPattern->AddPatternOp(permuteOp); - formatTransPermuteFusionPattern->Finish(); - this->patterns.emplace_back(formatTransPermuteFusionPattern.release()); + formatTransTransposeFusionPattern->AddPatternOp(formatTransOp); + formatTransTransposeFusionPattern->AddPatternOp(transposeOp); + formatTransTransposeFusionPattern->Finish(); + this->patterns.emplace_back(formatTransTransposeFusionPattern.release()); } // permute + format trans { auto formatTransOp = std::make_shared(); formatTransOp->id = kFormatTransformOp; formatTransOp->types = {PrimitiveType_Nchw2Nhwc, PrimitiveType_Nhwc2Nchw}; - auto permuteOp = std::make_shared(); - permuteOp->id = kPermuteOp; - permuteOp->types = {PrimitiveType_Permute}; + auto transposeOp = std::make_shared(); + transposeOp->id = kPermuteOp; + transposeOp->types = {PrimitiveType_Permute}; - formatTransOp->left = permuteOp; - std::unique_ptr permuteFormatTransFusionPattern(new (std::nothrow) - FusionPattern(kPermute2FormatTransFusionPattern)); - if (permuteFormatTransFusionPattern == nullptr) { - MS_LOG(ERROR) << "new " << kPermute2FormatTransFusionPattern << " failed"; + formatTransOp->left = transposeOp; + std::unique_ptr transposeFormatTransFusionPattern( + new (std::nothrow) FusionPattern(kTranspose2FormatTransFusionPattern)); + if (transposeFormatTransFusionPattern == nullptr) { + MS_LOG(ERROR) << "new " << kTranspose2FormatTransFusionPattern << " failed"; return RET_ERROR; } - permuteFormatTransFusionPattern->AddPatternOp(formatTransOp); - permuteFormatTransFusionPattern->AddPatternOp(permuteOp); - permuteFormatTransFusionPattern->Finish(); - this->patterns.emplace_back(permuteFormatTransFusionPattern.release()); + transposeFormatTransFusionPattern->AddPatternOp(formatTransOp); + transposeFormatTransFusionPattern->AddPatternOp(transposeOp); + transposeFormatTransFusionPattern->Finish(); + this->patterns.emplace_back(transposeFormatTransFusionPattern.release()); } return RET_OK; } @@ -80,26 +80,26 @@ STATUS FormatTransPermuteFusionPass::Run(schema::MetaGraphT *graph) { return Fus STATUS FormatTransPermuteFusionPass::DoFusion(schema::MetaGraphT *graph, const std::string &patternName, std::unordered_map> &matchedPath) { MS_ASSERT(graph != nullptr); - if (matchedPath.size() != kFormatTransPermuteMatchPathLen) { - MS_LOG(ERROR) << "Format-Transform-Permute-Fusion should have " << kFormatTransPermuteMatchPathLen + if (matchedPath.size() != kFormatTransTransposeMatchPathLen) { + MS_LOG(ERROR) << "Format-Transform-Transpose-Fusion should have " << kFormatTransTransposeMatchPathLen << " NodeIndex in matchedPair"; return RET_PARAM_INVALID; } std::shared_ptr formatTransPath = matchedPath[kFormatTransformOp]; - std::shared_ptr permutePath = matchedPath[kPermuteOp]; + std::shared_ptr transposePath = matchedPath[kPermuteOp]; if (formatTransPath == nullptr) { MS_LOG(ERROR) << "formatTransPath is failed to get"; return RET_ERROR; } - if (permutePath == nullptr) { + if (transposePath == nullptr) { MS_LOG(ERROR) << "permutePath is failed to get"; return RET_ERROR; } auto &formatTransNode = graph->nodes.at(formatTransPath->nodeIdx); - auto &permuteNode = graph->nodes.at(permutePath->nodeIdx); + auto &transposeNode = graph->nodes.at(transposePath->nodeIdx); MS_ASSERT(formatTransNode != nullptr); - MS_ASSERT(permuteNode != nullptr); + MS_ASSERT(transposeNode != nullptr); auto formatTransType = formatTransNode->primitive->value.type; if (formatTransType != PrimitiveType_Nhwc2Nchw && formatTransType != PrimitiveType_Nchw2Nhwc) { MS_LOG(ERROR) << "FormatTransNode should be " << EnumNamePrimitiveType(PrimitiveType_Nhwc2Nchw) << " or " @@ -107,15 +107,15 @@ STATUS FormatTransPermuteFusionPass::DoFusion(schema::MetaGraphT *graph, const s << EnumNamePrimitiveType(formatTransType); return RET_ERROR; } - MS_ASSERT(permuteNode->primitive != nullptr); - auto permPrimitive = permuteNode->primitive->value.AsPermute(); - MS_ASSERT(permPrimitive != nullptr); - auto perm = permPrimitive->order; + MS_ASSERT(transposeNode->primitive != nullptr); + auto transposePrimitive = transposeNode->primitive->value.AsTranspose(); + MS_ASSERT(transposePrimitive != nullptr); + auto perm = transposePrimitive->perm; if (perm.size() != 4) { return RET_OK; } - std::vector nchw2nhwcPerm = {0, 2, 3, 1}; - std::vector nhwc2nchwPerm = {0, 3, 1, 2}; + std::vector nchw2nhwcPerm = {0, 2, 3, 1}; + std::vector nhwc2nchwPerm = {0, 3, 1, 2}; if ((perm == nchw2nhwcPerm && formatTransType == PrimitiveType_Nhwc2Nchw) || (perm == nhwc2nchwPerm && formatTransType == PrimitiveType_Nchw2Nhwc)) { auto status = IsolateOneWayNode(graph, formatTransPath->nodeIdx); @@ -124,9 +124,9 @@ STATUS FormatTransPermuteFusionPass::DoFusion(schema::MetaGraphT *graph, const s return status; } - status = IsolateOneWayNode(graph, permutePath->nodeIdx); + status = IsolateOneWayNode(graph, transposePath->nodeIdx); if (status != RET_OK) { - MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << permuteNode->name << ", error: " << status; + MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << transposeNode->name << ", error: " << status; return status; } } diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_permute_fusion_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_transpose_fusion_pass.h similarity index 89% rename from mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_permute_fusion_pass.h rename to mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_transpose_fusion_pass.h index f56e94969f..722227f19a 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_permute_fusion_pass.h +++ b/mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_transpose_fusion_pass.h @@ -26,8 +26,8 @@ namespace mindspore { namespace lite { constexpr const char *kFormatTransformOp = "FormatTransOp"; constexpr const char *kPermuteOp = "PermuteOp"; -constexpr const char *kFormatTrans2PermuteFusionPattern = "Nc2NhAndNh2NcFusionPattern"; -constexpr const char *kPermute2FormatTransFusionPattern = "Nc2NhAndNh2NcPassFusionPattern"; +constexpr const char *kFormatTrans2TransposeFusionPattern = "Nc2NhAndNh2NcFusionPattern"; +constexpr const char *kTranspose2FormatTransFusionPattern = "Nc2NhAndNh2NcPassFusionPattern"; class FormatTransPermuteFusionPass : public FusionPass { public: @@ -46,4 +46,3 @@ class FormatTransPermuteFusionPass : public FusionPass { } // namespace mindspore #endif // MINDSPORE_PREDICT_FORMAT_TRANS_PERMUTE_FUSION_PASS_H -