Browse Source

!6846 [MSLITE]transformat optimize for slice and crop

Merge pull request !6846 from zhengjun10/stride
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
f7db856b69
3 changed files with 32 additions and 2 deletions
  1. +1
    -1
      mindspore/lite/test/models_onnx.cfg
  2. +1
    -1
      mindspore/lite/tools/common/node_util.cc
  3. +30
    -0
      mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.cc

+ 1
- 1
mindspore/lite/test/models_onnx.cfg View File

@@ -1,4 +1,4 @@
mtk_detect-mbv2-shortcut-400-400-simplified.onnx
mtk_emotions-d2012-75.8%.onnx
# mtk_face_features_v3.onnx
mtk_face_features_v3.onnx
ml_face_3d.onnx

+ 1
- 1
mindspore/lite/tools/common/node_util.cc View File

@@ -84,7 +84,7 @@ static const std::vector<schema::PrimitiveType> int8OpList = {
static const std::vector<schema::PrimitiveType> needInsertOpList = {
schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat,
schema::PrimitiveType_Power, schema::PrimitiveType_StridedSlice, schema::PrimitiveType_Add,
schema::PrimitiveType_Split};
schema::PrimitiveType_Split, schema::PrimitiveType_Slice, schema::PrimitiveType_Crop};

static const std::unordered_map<int, int> nc2NhAxisMap = {{0, 0}, {1, -1}, {2, 1}, {3, 2}};



+ 30
- 0
mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.cc View File

@@ -147,6 +147,36 @@ STATUS TransOpInsertPass::ChangeOpAxis(schema::MetaGraphT *graph, const std::uni
auto axis_map = GetNc2NhAxisMap();
node->primitive->value.AsSplit()->splitDim = axis_map[origin_axis];
}
if (type == PrimitiveType_Crop) {
auto origin_axis = node->primitive->value.AsCrop()->axis;
auto offsets = node->primitive->value.AsCrop()->offsets;
auto axis_map = GetNc2NhAxisMap();
node->primitive->value.AsCrop()->axis = axis_map[origin_axis];
// nchw->nhwc,offsets need pad 0;
if (axis_map[origin_axis] == 0) {
offsets = {offsets[0], offsets[2], offsets[3], offsets[1]};
} else if (axis_map[origin_axis] == 1 || axis_map[origin_axis] == 2) {
// orgin_axis = 2 or orgin_axis = 3
offsets.push_back(0);
} else if (axis_map[origin_axis] == -1) {
// origin_axis = 1
offsets = {offsets[1], offsets[2], offsets[0]};
} else {
// axis error
MS_LOG(ERROR) << "Crop error";
return RET_ERROR;
}
node->primitive->value.AsCrop()->offsets = offsets;
}
if (type == PrimitiveType_Slice) {
auto attr = node->primitive->value.AsSlice();
auto origin_begin = attr->begin;
attr->begin = {origin_begin[NCHW_N], origin_begin[NCHW_H], origin_begin[NCHW_W], origin_begin[NCHW_C]};
auto origin_end = attr->axes;
attr->axes = {origin_end[NCHW_N], origin_end[NCHW_H], origin_end[NCHW_W], origin_end[NCHW_C]};
auto origin_stride = attr->size;
attr->size = {origin_stride[NCHW_N], origin_stride[NCHW_H], origin_stride[NCHW_W], origin_stride[NCHW_C]};
}
return RET_OK;
}



Loading…
Cancel
Save