Browse Source

!4542 change formatTrans+permute fusion to formatTrans+transpose fusion

Merge pull request !4542 from hangq/master
tags/v0.7.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
2e3f598726
4 changed files with 44 additions and 45 deletions
  1. +1
    -1
      mindspore/lite/tools/converter/graphdef_transform.cc
  2. +1
    -1
      mindspore/lite/tools/converter/legacy_optimizer/fusion/CMakeLists.txt
  3. +40
    -40
      mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_transpose_fusion_pass.cc
  4. +2
    -3
      mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_transpose_fusion_pass.h

+ 1
- 1
mindspore/lite/tools/converter/graphdef_transform.cc View File

@@ -30,7 +30,7 @@
#include "tools/converter/legacy_optimizer/fusion/conv_biasadd_fusion_pass.h"
// #include "tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.h"
#include "tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.h"
#include "tools/converter/legacy_optimizer/fusion/format_trans_permute_fusion_pass.h"
#include "tools/converter/legacy_optimizer/fusion/format_trans_transpose_fusion_pass.h"
#include "tools/converter/legacy_optimizer/fusion/quant_cast_fusion_pass.h"
// #include "tools/converter/legacy_optimizer/fusion/batchnorm_fold_fusion_pass.h"
//


+ 1
- 1
mindspore/lite/tools/converter/legacy_optimizer/fusion/CMakeLists.txt View File

@@ -12,7 +12,7 @@ add_library(fusion_mid OBJECT
${CMAKE_CURRENT_SOURCE_DIR}/quant_cast_fusion_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/batchnorm_fold_fusion_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/format_trans_fusion_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/format_trans_permute_fusion_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/format_trans_transpose_fusion_pass.cc
)

target_link_libraries(fusion_mid securec)

mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_permute_fusion_pass.cc → mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_transpose_fusion_pass.cc View File

@@ -18,7 +18,7 @@
#include <vector>
#include <unordered_map>
#include <memory>
#include "tools/converter/legacy_optimizer/fusion/format_trans_permute_fusion_pass.h"
#include "tools/converter/legacy_optimizer/fusion/format_trans_transpose_fusion_pass.h"
#include "utils/log_adapter.h"
#include "securec/include/securec.h"
#include "tools/common/graph_util.h"
@@ -27,7 +27,7 @@

namespace mindspore {
namespace lite {
#define kFormatTransPermuteMatchPathLen 2
#define kFormatTransTransposeMatchPathLen 2

STATUS FormatTransPermuteFusionPass::DefinePattern() {
// format trans + permute
@@ -35,42 +35,42 @@ STATUS FormatTransPermuteFusionPass::DefinePattern() {
auto formatTransOp = std::make_shared<PatternOp>();
formatTransOp->id = kFormatTransformOp;
formatTransOp->types = {PrimitiveType_Nchw2Nhwc, PrimitiveType_Nhwc2Nchw};
auto permuteOp = std::make_shared<PatternOp>();
permuteOp->id = kPermuteOp;
permuteOp->types = {PrimitiveType_Permute};
auto transposeOp = std::make_shared<PatternOp>();
transposeOp->id = kPermuteOp;
transposeOp->types = {PrimitiveType_Transpose};

permuteOp->left = formatTransOp;
std::unique_ptr<FusionPattern> formatTransPermuteFusionPattern(new (std::nothrow)
FusionPattern(kFormatTrans2PermuteFusionPattern));
if (formatTransPermuteFusionPattern == nullptr) {
MS_LOG(ERROR) << "new " << kFormatTrans2PermuteFusionPattern << " failed";
transposeOp->left = formatTransOp;
std::unique_ptr<FusionPattern> formatTransTransposeFusionPattern(
new (std::nothrow) FusionPattern(kFormatTrans2TransposeFusionPattern));
if (formatTransTransposeFusionPattern == nullptr) {
MS_LOG(ERROR) << "new " << kFormatTrans2TransposeFusionPattern << " failed";
return RET_ERROR;
}
formatTransPermuteFusionPattern->AddPatternOp(formatTransOp);
formatTransPermuteFusionPattern->AddPatternOp(permuteOp);
formatTransPermuteFusionPattern->Finish();
this->patterns.emplace_back(formatTransPermuteFusionPattern.release());
formatTransTransposeFusionPattern->AddPatternOp(formatTransOp);
formatTransTransposeFusionPattern->AddPatternOp(transposeOp);
formatTransTransposeFusionPattern->Finish();
this->patterns.emplace_back(formatTransTransposeFusionPattern.release());
}
// permute + format trans
{
auto formatTransOp = std::make_shared<PatternOp>();
formatTransOp->id = kFormatTransformOp;
formatTransOp->types = {PrimitiveType_Nchw2Nhwc, PrimitiveType_Nhwc2Nchw};
auto permuteOp = std::make_shared<PatternOp>();
permuteOp->id = kPermuteOp;
permuteOp->types = {PrimitiveType_Permute};
auto transposeOp = std::make_shared<PatternOp>();
transposeOp->id = kPermuteOp;
transposeOp->types = {PrimitiveType_Permute};

formatTransOp->left = permuteOp;
std::unique_ptr<FusionPattern> permuteFormatTransFusionPattern(new (std::nothrow)
FusionPattern(kPermute2FormatTransFusionPattern));
if (permuteFormatTransFusionPattern == nullptr) {
MS_LOG(ERROR) << "new " << kPermute2FormatTransFusionPattern << " failed";
formatTransOp->left = transposeOp;
std::unique_ptr<FusionPattern> transposeFormatTransFusionPattern(
new (std::nothrow) FusionPattern(kTranspose2FormatTransFusionPattern));
if (transposeFormatTransFusionPattern == nullptr) {
MS_LOG(ERROR) << "new " << kTranspose2FormatTransFusionPattern << " failed";
return RET_ERROR;
}
permuteFormatTransFusionPattern->AddPatternOp(formatTransOp);
permuteFormatTransFusionPattern->AddPatternOp(permuteOp);
permuteFormatTransFusionPattern->Finish();
this->patterns.emplace_back(permuteFormatTransFusionPattern.release());
transposeFormatTransFusionPattern->AddPatternOp(formatTransOp);
transposeFormatTransFusionPattern->AddPatternOp(transposeOp);
transposeFormatTransFusionPattern->Finish();
this->patterns.emplace_back(transposeFormatTransFusionPattern.release());
}
return RET_OK;
}
@@ -80,26 +80,26 @@ STATUS FormatTransPermuteFusionPass::Run(schema::MetaGraphT *graph) { return Fus
STATUS FormatTransPermuteFusionPass::DoFusion(schema::MetaGraphT *graph, const std::string &patternName,
std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) {
MS_ASSERT(graph != nullptr);
if (matchedPath.size() != kFormatTransPermuteMatchPathLen) {
MS_LOG(ERROR) << "Format-Transform-Permute-Fusion should have " << kFormatTransPermuteMatchPathLen
if (matchedPath.size() != kFormatTransTransposeMatchPathLen) {
MS_LOG(ERROR) << "Format-Transform-Transpose-Fusion should have " << kFormatTransTransposeMatchPathLen
<< " NodeIndex in matchedPair";
return RET_PARAM_INVALID;
}

std::shared_ptr<Path> formatTransPath = matchedPath[kFormatTransformOp];
std::shared_ptr<Path> permutePath = matchedPath[kPermuteOp];
std::shared_ptr<Path> transposePath = matchedPath[kPermuteOp];
if (formatTransPath == nullptr) {
MS_LOG(ERROR) << "formatTransPath is failed to get";
return RET_ERROR;
}
if (permutePath == nullptr) {
if (transposePath == nullptr) {
MS_LOG(ERROR) << "permutePath is failed to get";
return RET_ERROR;
}
auto &formatTransNode = graph->nodes.at(formatTransPath->nodeIdx);
auto &permuteNode = graph->nodes.at(permutePath->nodeIdx);
auto &transposeNode = graph->nodes.at(transposePath->nodeIdx);
MS_ASSERT(formatTransNode != nullptr);
MS_ASSERT(permuteNode != nullptr);
MS_ASSERT(transposeNode != nullptr);
auto formatTransType = formatTransNode->primitive->value.type;
if (formatTransType != PrimitiveType_Nhwc2Nchw && formatTransType != PrimitiveType_Nchw2Nhwc) {
MS_LOG(ERROR) << "FormatTransNode should be " << EnumNamePrimitiveType(PrimitiveType_Nhwc2Nchw) << " or "
@@ -107,15 +107,15 @@ STATUS FormatTransPermuteFusionPass::DoFusion(schema::MetaGraphT *graph, const s
<< EnumNamePrimitiveType(formatTransType);
return RET_ERROR;
}
MS_ASSERT(permuteNode->primitive != nullptr);
auto permPrimitive = permuteNode->primitive->value.AsPermute();
MS_ASSERT(permPrimitive != nullptr);
auto perm = permPrimitive->order;
MS_ASSERT(transposeNode->primitive != nullptr);
auto transposePrimitive = transposeNode->primitive->value.AsTranspose();
MS_ASSERT(transposePrimitive != nullptr);
auto perm = transposePrimitive->perm;
if (perm.size() != 4) {
return RET_OK;
}
std::vector<int64_t> nchw2nhwcPerm = {0, 2, 3, 1};
std::vector<int64_t> nhwc2nchwPerm = {0, 3, 1, 2};
std::vector<int32_t> nchw2nhwcPerm = {0, 2, 3, 1};
std::vector<int32_t> nhwc2nchwPerm = {0, 3, 1, 2};
if ((perm == nchw2nhwcPerm && formatTransType == PrimitiveType_Nhwc2Nchw) ||
(perm == nhwc2nchwPerm && formatTransType == PrimitiveType_Nchw2Nhwc)) {
auto status = IsolateOneWayNode(graph, formatTransPath->nodeIdx);
@@ -124,9 +124,9 @@ STATUS FormatTransPermuteFusionPass::DoFusion(schema::MetaGraphT *graph, const s
return status;
}

status = IsolateOneWayNode(graph, permutePath->nodeIdx);
status = IsolateOneWayNode(graph, transposePath->nodeIdx);
if (status != RET_OK) {
MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << permuteNode->name << ", error: " << status;
MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << transposeNode->name << ", error: " << status;
return status;
}
}

mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_permute_fusion_pass.h → mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_transpose_fusion_pass.h View File

@@ -26,8 +26,8 @@ namespace mindspore {
namespace lite {
constexpr const char *kFormatTransformOp = "FormatTransOp";
constexpr const char *kPermuteOp = "PermuteOp";
constexpr const char *kFormatTrans2PermuteFusionPattern = "Nc2NhAndNh2NcFusionPattern";
constexpr const char *kPermute2FormatTransFusionPattern = "Nc2NhAndNh2NcPassFusionPattern";
constexpr const char *kFormatTrans2TransposeFusionPattern = "Nc2NhAndNh2NcFusionPattern";
constexpr const char *kTranspose2FormatTransFusionPattern = "Nc2NhAndNh2NcPassFusionPattern";

class FormatTransPermuteFusionPass : public FusionPass {
public:
@@ -46,4 +46,3 @@ class FormatTransPermuteFusionPass : public FusionPass {
} // namespace mindspore

#endif // MINDSPORE_PREDICT_FORMAT_TRANS_PERMUTE_FUSION_PASS_H


Loading…
Cancel
Save