|
|
|
@@ -18,7 +18,7 @@ |
|
|
|
#include <vector> |
|
|
|
#include <unordered_map> |
|
|
|
#include <memory> |
|
|
|
#include "tools/converter/legacy_optimizer/fusion/format_trans_permute_fusion_pass.h" |
|
|
|
#include "tools/converter/legacy_optimizer/fusion/format_trans_transpose_fusion_pass.h" |
|
|
|
#include "utils/log_adapter.h" |
|
|
|
#include "securec/include/securec.h" |
|
|
|
#include "tools/common/graph_util.h" |
|
|
|
@@ -27,7 +27,7 @@ |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace lite { |
|
|
|
#define kFormatTransPermuteMatchPathLen 2 |
|
|
|
#define kFormatTransTransposeMatchPathLen 2 |
|
|
|
|
|
|
|
STATUS FormatTransPermuteFusionPass::DefinePattern() { |
|
|
|
// format trans + permute |
|
|
|
@@ -35,42 +35,42 @@ STATUS FormatTransPermuteFusionPass::DefinePattern() { |
|
|
|
auto formatTransOp = std::make_shared<PatternOp>(); |
|
|
|
formatTransOp->id = kFormatTransformOp; |
|
|
|
formatTransOp->types = {PrimitiveType_Nchw2Nhwc, PrimitiveType_Nhwc2Nchw}; |
|
|
|
auto permuteOp = std::make_shared<PatternOp>(); |
|
|
|
permuteOp->id = kPermuteOp; |
|
|
|
permuteOp->types = {PrimitiveType_Permute}; |
|
|
|
auto transposeOp = std::make_shared<PatternOp>(); |
|
|
|
transposeOp->id = kPermuteOp; |
|
|
|
transposeOp->types = {PrimitiveType_Transpose}; |
|
|
|
|
|
|
|
permuteOp->left = formatTransOp; |
|
|
|
std::unique_ptr<FusionPattern> formatTransPermuteFusionPattern(new (std::nothrow) |
|
|
|
FusionPattern(kFormatTrans2PermuteFusionPattern)); |
|
|
|
if (formatTransPermuteFusionPattern == nullptr) { |
|
|
|
MS_LOG(ERROR) << "new " << kFormatTrans2PermuteFusionPattern << " failed"; |
|
|
|
transposeOp->left = formatTransOp; |
|
|
|
std::unique_ptr<FusionPattern> formatTransTransposeFusionPattern( |
|
|
|
new (std::nothrow) FusionPattern(kFormatTrans2TransposeFusionPattern)); |
|
|
|
if (formatTransTransposeFusionPattern == nullptr) { |
|
|
|
MS_LOG(ERROR) << "new " << kFormatTrans2TransposeFusionPattern << " failed"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
formatTransPermuteFusionPattern->AddPatternOp(formatTransOp); |
|
|
|
formatTransPermuteFusionPattern->AddPatternOp(permuteOp); |
|
|
|
formatTransPermuteFusionPattern->Finish(); |
|
|
|
this->patterns.emplace_back(formatTransPermuteFusionPattern.release()); |
|
|
|
formatTransTransposeFusionPattern->AddPatternOp(formatTransOp); |
|
|
|
formatTransTransposeFusionPattern->AddPatternOp(transposeOp); |
|
|
|
formatTransTransposeFusionPattern->Finish(); |
|
|
|
this->patterns.emplace_back(formatTransTransposeFusionPattern.release()); |
|
|
|
} |
|
|
|
// permute + format trans |
|
|
|
{ |
|
|
|
auto formatTransOp = std::make_shared<PatternOp>(); |
|
|
|
formatTransOp->id = kFormatTransformOp; |
|
|
|
formatTransOp->types = {PrimitiveType_Nchw2Nhwc, PrimitiveType_Nhwc2Nchw}; |
|
|
|
auto permuteOp = std::make_shared<PatternOp>(); |
|
|
|
permuteOp->id = kPermuteOp; |
|
|
|
permuteOp->types = {PrimitiveType_Permute}; |
|
|
|
auto transposeOp = std::make_shared<PatternOp>(); |
|
|
|
transposeOp->id = kPermuteOp; |
|
|
|
transposeOp->types = {PrimitiveType_Permute}; |
|
|
|
|
|
|
|
formatTransOp->left = permuteOp; |
|
|
|
std::unique_ptr<FusionPattern> permuteFormatTransFusionPattern(new (std::nothrow) |
|
|
|
FusionPattern(kPermute2FormatTransFusionPattern)); |
|
|
|
if (permuteFormatTransFusionPattern == nullptr) { |
|
|
|
MS_LOG(ERROR) << "new " << kPermute2FormatTransFusionPattern << " failed"; |
|
|
|
formatTransOp->left = transposeOp; |
|
|
|
std::unique_ptr<FusionPattern> transposeFormatTransFusionPattern( |
|
|
|
new (std::nothrow) FusionPattern(kTranspose2FormatTransFusionPattern)); |
|
|
|
if (transposeFormatTransFusionPattern == nullptr) { |
|
|
|
MS_LOG(ERROR) << "new " << kTranspose2FormatTransFusionPattern << " failed"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
permuteFormatTransFusionPattern->AddPatternOp(formatTransOp); |
|
|
|
permuteFormatTransFusionPattern->AddPatternOp(permuteOp); |
|
|
|
permuteFormatTransFusionPattern->Finish(); |
|
|
|
this->patterns.emplace_back(permuteFormatTransFusionPattern.release()); |
|
|
|
transposeFormatTransFusionPattern->AddPatternOp(formatTransOp); |
|
|
|
transposeFormatTransFusionPattern->AddPatternOp(transposeOp); |
|
|
|
transposeFormatTransFusionPattern->Finish(); |
|
|
|
this->patterns.emplace_back(transposeFormatTransFusionPattern.release()); |
|
|
|
} |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
@@ -80,26 +80,26 @@ STATUS FormatTransPermuteFusionPass::Run(schema::MetaGraphT *graph) { return Fus |
|
|
|
STATUS FormatTransPermuteFusionPass::DoFusion(schema::MetaGraphT *graph, const std::string &patternName, |
|
|
|
std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) { |
|
|
|
MS_ASSERT(graph != nullptr); |
|
|
|
if (matchedPath.size() != kFormatTransPermuteMatchPathLen) { |
|
|
|
MS_LOG(ERROR) << "Format-Transform-Permute-Fusion should have " << kFormatTransPermuteMatchPathLen |
|
|
|
if (matchedPath.size() != kFormatTransTransposeMatchPathLen) { |
|
|
|
MS_LOG(ERROR) << "Format-Transform-Transpose-Fusion should have " << kFormatTransTransposeMatchPathLen |
|
|
|
<< " NodeIndex in matchedPair"; |
|
|
|
return RET_PARAM_INVALID; |
|
|
|
} |
|
|
|
|
|
|
|
std::shared_ptr<Path> formatTransPath = matchedPath[kFormatTransformOp]; |
|
|
|
std::shared_ptr<Path> permutePath = matchedPath[kPermuteOp]; |
|
|
|
std::shared_ptr<Path> transposePath = matchedPath[kPermuteOp]; |
|
|
|
if (formatTransPath == nullptr) { |
|
|
|
MS_LOG(ERROR) << "formatTransPath is failed to get"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
if (permutePath == nullptr) { |
|
|
|
if (transposePath == nullptr) { |
|
|
|
MS_LOG(ERROR) << "permutePath is failed to get"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
auto &formatTransNode = graph->nodes.at(formatTransPath->nodeIdx); |
|
|
|
auto &permuteNode = graph->nodes.at(permutePath->nodeIdx); |
|
|
|
auto &transposeNode = graph->nodes.at(transposePath->nodeIdx); |
|
|
|
MS_ASSERT(formatTransNode != nullptr); |
|
|
|
MS_ASSERT(permuteNode != nullptr); |
|
|
|
MS_ASSERT(transposeNode != nullptr); |
|
|
|
auto formatTransType = formatTransNode->primitive->value.type; |
|
|
|
if (formatTransType != PrimitiveType_Nhwc2Nchw && formatTransType != PrimitiveType_Nchw2Nhwc) { |
|
|
|
MS_LOG(ERROR) << "FormatTransNode should be " << EnumNamePrimitiveType(PrimitiveType_Nhwc2Nchw) << " or " |
|
|
|
@@ -107,15 +107,15 @@ STATUS FormatTransPermuteFusionPass::DoFusion(schema::MetaGraphT *graph, const s |
|
|
|
<< EnumNamePrimitiveType(formatTransType); |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
MS_ASSERT(permuteNode->primitive != nullptr); |
|
|
|
auto permPrimitive = permuteNode->primitive->value.AsPermute(); |
|
|
|
MS_ASSERT(permPrimitive != nullptr); |
|
|
|
auto perm = permPrimitive->order; |
|
|
|
MS_ASSERT(transposeNode->primitive != nullptr); |
|
|
|
auto transposePrimitive = transposeNode->primitive->value.AsTranspose(); |
|
|
|
MS_ASSERT(transposePrimitive != nullptr); |
|
|
|
auto perm = transposePrimitive->perm; |
|
|
|
if (perm.size() != 4) { |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
std::vector<int64_t> nchw2nhwcPerm = {0, 2, 3, 1}; |
|
|
|
std::vector<int64_t> nhwc2nchwPerm = {0, 3, 1, 2}; |
|
|
|
std::vector<int32_t> nchw2nhwcPerm = {0, 2, 3, 1}; |
|
|
|
std::vector<int32_t> nhwc2nchwPerm = {0, 3, 1, 2}; |
|
|
|
if ((perm == nchw2nhwcPerm && formatTransType == PrimitiveType_Nhwc2Nchw) || |
|
|
|
(perm == nhwc2nchwPerm && formatTransType == PrimitiveType_Nchw2Nhwc)) { |
|
|
|
auto status = IsolateOneWayNode(graph, formatTransPath->nodeIdx); |
|
|
|
@@ -124,9 +124,9 @@ STATUS FormatTransPermuteFusionPass::DoFusion(schema::MetaGraphT *graph, const s |
|
|
|
return status; |
|
|
|
} |
|
|
|
|
|
|
|
status = IsolateOneWayNode(graph, permutePath->nodeIdx); |
|
|
|
status = IsolateOneWayNode(graph, transposePath->nodeIdx); |
|
|
|
if (status != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << permuteNode->name << ", error: " << status; |
|
|
|
MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << transposeNode->name << ", error: " << status; |
|
|
|
return status; |
|
|
|
} |
|
|
|
} |