|
|
|
@@ -147,6 +147,36 @@ STATUS TransOpInsertPass::ChangeOpAxis(schema::MetaGraphT *graph, const std::uni |
|
|
|
auto axis_map = GetNc2NhAxisMap(); |
|
|
|
node->primitive->value.AsSplit()->splitDim = axis_map[origin_axis]; |
|
|
|
} |
|
|
|
if (type == PrimitiveType_Crop) { |
|
|
|
auto origin_axis = node->primitive->value.AsCrop()->axis; |
|
|
|
auto offsets = node->primitive->value.AsCrop()->offsets; |
|
|
|
auto axis_map = GetNc2NhAxisMap(); |
|
|
|
node->primitive->value.AsCrop()->axis = axis_map[origin_axis]; |
|
|
|
// nchw->nhwc,offsets need pad 0; |
|
|
|
if (axis_map[origin_axis] == 0) { |
|
|
|
offsets = {offsets[0], offsets[2], offsets[3], offsets[1]}; |
|
|
|
} else if (axis_map[origin_axis] == 1 || axis_map[origin_axis] == 2) { |
|
|
|
// orgin_axis = 2 or orgin_axis = 3 |
|
|
|
offsets.push_back(0); |
|
|
|
} else if (axis_map[origin_axis] == -1) { |
|
|
|
// origin_axis = 1 |
|
|
|
offsets = {offsets[1], offsets[2], offsets[0]}; |
|
|
|
} else { |
|
|
|
// axis error |
|
|
|
MS_LOG(ERROR) << "Crop error"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
node->primitive->value.AsCrop()->offsets = offsets; |
|
|
|
} |
|
|
|
if (type == PrimitiveType_Slice) { |
|
|
|
auto attr = node->primitive->value.AsSlice(); |
|
|
|
auto origin_begin = attr->begin; |
|
|
|
attr->begin = {origin_begin[NCHW_N], origin_begin[NCHW_H], origin_begin[NCHW_W], origin_begin[NCHW_C]}; |
|
|
|
auto origin_end = attr->axes; |
|
|
|
attr->axes = {origin_end[NCHW_N], origin_end[NCHW_H], origin_end[NCHW_W], origin_end[NCHW_C]}; |
|
|
|
auto origin_stride = attr->size; |
|
|
|
attr->size = {origin_stride[NCHW_N], origin_stride[NCHW_H], origin_stride[NCHW_W], origin_stride[NCHW_C]}; |
|
|
|
} |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
|