From f329bdd77be7a3bf21790d8039fcf0566a39a97d Mon Sep 17 00:00:00 2001 From: zhengjun10 Date: Thu, 24 Sep 2020 16:53:24 +0800 Subject: [PATCH] transformat optimize for slice and crop --- mindspore/lite/test/models_onnx.cfg | 2 +- mindspore/lite/tools/common/node_util.cc | 2 +- .../graph/trans_format_insert_pass.cc | 30 +++++++++++++++++++ 3 files changed, 32 insertions(+), 2 deletions(-) diff --git a/mindspore/lite/test/models_onnx.cfg b/mindspore/lite/test/models_onnx.cfg index ab17ccd60c..da693fe0a2 100644 --- a/mindspore/lite/test/models_onnx.cfg +++ b/mindspore/lite/test/models_onnx.cfg @@ -1,4 +1,4 @@ mtk_detect-mbv2-shortcut-400-400-simplified.onnx mtk_emotions-d2012-75.8%.onnx -# mtk_face_features_v3.onnx +mtk_face_features_v3.onnx ml_face_3d.onnx diff --git a/mindspore/lite/tools/common/node_util.cc b/mindspore/lite/tools/common/node_util.cc index 7d98d1d05d..c0fd102a51 100644 --- a/mindspore/lite/tools/common/node_util.cc +++ b/mindspore/lite/tools/common/node_util.cc @@ -84,7 +84,7 @@ static const std::vector int8OpList = { static const std::vector needInsertOpList = { schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat, schema::PrimitiveType_Power, schema::PrimitiveType_StridedSlice, schema::PrimitiveType_Add, - schema::PrimitiveType_Split}; + schema::PrimitiveType_Split, schema::PrimitiveType_Slice, schema::PrimitiveType_Crop}; static const std::unordered_map nc2NhAxisMap = {{0, 0}, {1, -1}, {2, 1}, {3, 2}}; diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.cc index dca8150222..7a0534231f 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.cc @@ -147,6 +147,36 @@ STATUS TransOpInsertPass::ChangeOpAxis(schema::MetaGraphT *graph, const std::uni auto axis_map = GetNc2NhAxisMap(); node->primitive->value.AsSplit()->splitDim = axis_map[origin_axis]; } + if (type == PrimitiveType_Crop) { + auto origin_axis = node->primitive->value.AsCrop()->axis; + auto offsets = node->primitive->value.AsCrop()->offsets; + auto axis_map = GetNc2NhAxisMap(); + node->primitive->value.AsCrop()->axis = axis_map[origin_axis]; + // nchw->nhwc,offsets need pad 0; + if (axis_map[origin_axis] == 0) { + offsets = {offsets[0], offsets[2], offsets[3], offsets[1]}; + } else if (axis_map[origin_axis] == 1 || axis_map[origin_axis] == 2) { + // orgin_axis = 2 or orgin_axis = 3 + offsets.push_back(0); + } else if (axis_map[origin_axis] == -1) { + // origin_axis = 1 + offsets = {offsets[1], offsets[2], offsets[0]}; + } else { + // axis error + MS_LOG(ERROR) << "Crop error"; + return RET_ERROR; + } + node->primitive->value.AsCrop()->offsets = offsets; + } + if (type == PrimitiveType_Slice) { + auto attr = node->primitive->value.AsSlice(); + auto origin_begin = attr->begin; + attr->begin = {origin_begin[NCHW_N], origin_begin[NCHW_H], origin_begin[NCHW_W], origin_begin[NCHW_C]}; + auto origin_end = attr->axes; + attr->axes = {origin_end[NCHW_N], origin_end[NCHW_H], origin_end[NCHW_W], origin_end[NCHW_C]}; + auto origin_stride = attr->size; + attr->size = {origin_stride[NCHW_N], origin_stride[NCHW_H], origin_stride[NCHW_W], origin_stride[NCHW_C]}; + } return RET_OK; }