|
|
|
@@ -60,32 +60,36 @@ int TransposeInferShape(const TensorC *const *inputs, size_t inputs_size, Tensor |
|
|
|
for (size_t i = 0; i < perms_num; i++) { |
|
|
|
ShapePush(perm, &perm_size, perm_data[i]); |
|
|
|
} |
|
|
|
int out_shape[MAX_TRANSPOSE_DIM_SIZE] = {0}; |
|
|
|
if (input->shape_size_ != 4 && perms_num == 4) { |
|
|
|
for (size_t i = 0; i < input->shape_size_; ++i) { |
|
|
|
out_shape[i] = input->shape_[i]; |
|
|
|
} |
|
|
|
SetShapeArray(output, out_shape, input->shape_size_); |
|
|
|
return NNACL_OK; |
|
|
|
} |
|
|
|
const int nchw2nhwc[4] = {0, 2, 3, 1}; |
|
|
|
const int nhwc2nchw[4] = {0, 3, 1, 2}; |
|
|
|
const int trans3d[3] = {0, 2, 1}; |
|
|
|
if (perms_num == 4) { |
|
|
|
if (input->format_ == Format_NCHW && CheckPermTransFormat(perm, nchw2nhwc, perms_num)) { |
|
|
|
output->format_ = Format_NHWC; |
|
|
|
} else if (input->format_ == Format_NHWC && CheckPermTransFormat(perm, nhwc2nchw, perms_num)) { |
|
|
|
output->format_ = Format_NCHW; |
|
|
|
} |
|
|
|
// though the perm is 4d in default, the input can be a 3d tensor. The op implementation should be adapted to this. |
|
|
|
if (input->shape_size_ == 3) { |
|
|
|
ShapeSet(perm, &perm_size, trans3d, 3); |
|
|
|
} |
|
|
|
} |
|
|
|
output->shape_size_ = perm_size; |
|
|
|
for (size_t i = 0; i < perm_size; ++i) { |
|
|
|
out_shape[i] = input->shape_[perm[i]]; |
|
|
|
} |
|
|
|
// set output shape |
|
|
|
int out_shape[MAX_TRANSPOSE_DIM_SIZE] = {0}; |
|
|
|
size_t in_shape_size = input->shape_size_; |
|
|
|
output->shape_size_ = in_shape_size; |
|
|
|
if (perm_size == 0) { |
|
|
|
size_t shape_size = input->shape_size_; |
|
|
|
output->shape_size_ = shape_size; |
|
|
|
for (size_t i = 0; i < shape_size; ++i) { |
|
|
|
out_shape[shape_size - i - 1] = input->shape_[i]; |
|
|
|
for (size_t i = 0; i < in_shape_size; ++i) { |
|
|
|
out_shape[in_shape_size - i - 1] = input->shape_[i]; |
|
|
|
} |
|
|
|
} else if (perm_size != in_shape_size) { |
|
|
|
for (size_t i = 0; i < in_shape_size; ++i) { |
|
|
|
out_shape[i] = input->shape_[i]; |
|
|
|
} |
|
|
|
} else { |
|
|
|
output->shape_size_ = perm_size; |
|
|
|
for (size_t i = 0; i < perm_size; ++i) { |
|
|
|
out_shape[i] = input->shape_[perm[i]]; |
|
|
|
} |
|
|
|
} |
|
|
|
SetShapeArray(output, out_shape, output->shape_size_); |
|
|
|
|