From 6bf7534cb2f92f26e2b3f17ec19b6abb281b4fd6 Mon Sep 17 00:00:00 2001 From: wsc Date: Tue, 11 Aug 2020 11:31:09 +0800 Subject: [PATCH] Modify implementation of 'Transpose' to support the 5-dimension input. --- .../src/runtime/kernel/arm/nnacl/transpose.cc | 42 ++++++++++++++++++- .../src/runtime/kernel/arm/nnacl/transpose.h | 1 + 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/transpose.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/transpose.cc index c8f860812a..bda88804d0 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/transpose.cc +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/transpose.cc @@ -85,6 +85,44 @@ void TransposeDim4(float *in_data, float *out_data, int *strides, int *out_strid } } +/* Permutes a 5-axis float array. For every output index (i, j, k, m, n) with each component bounded by output_shape[0..4], copies in_data[i * strides[perm[0]] + j * strides[perm[1]] + k * strides[perm[2]] + m * strides[perm[3]] + n * strides[perm[4]]] into out_data[i * out_strides[0] + j * out_strides[1] + k * out_strides[2] + m * out_strides[3] + n]. out_strides[4] is never read: the innermost output axis is addressed with an implicit stride of 1. This function performs no null or bounds checks itself. */ void TransposeDim5(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape) { + const int stride0 = strides[perm[0]]; + const int stride1 = strides[perm[1]]; + const int stride2 = strides[perm[2]]; + const int stride3 = strides[perm[3]]; + const int stride4 = strides[perm[4]]; + const int out_stride0 = out_strides[0]; + const int out_stride1 = out_strides[1]; + const int out_stride2 = out_strides[2]; + const int out_stride3 = out_strides[3]; + const int output0 = output_shape[0]; + const int output1 = output_shape[1]; + const int output2 = output_shape[2]; + const int output3 = output_shape[3]; + const int output4 = output_shape[4]; + + for (int i = 0; i < output0; i++) { + int out_stride0_i = i * out_stride0; + int stride0_i = i * stride0; + for (int j = 0; j < output1; j++) { + int out_stride1_j = j * out_stride1; + int stride1_j = j * stride1; + for (int k = 0; k < output2; k++) { + int out_stride2_k = k * out_stride2; + int stride2_k = k * stride2; + for (int m = 0; m < output3; m++) { + int out_stride3_m = m * out_stride3; + int stride3_m = m * stride3; + for (int n = 0; n < output4; n++) { + out_data[out_stride0_i + out_stride1_j + out_stride2_k + 
out_stride3_m + n] = + in_data[stride0_i + stride1_j + stride2_k + stride3_m + n * stride4]; + } + } + } + } + } +} + int DoTranspose(float *in_data, float *out_data, int *input_shape, int *output_shape, TransposeParameter *transpose_param) { if (in_data == nullptr || out_data == nullptr) { @@ -96,7 +134,7 @@ int DoTranspose(float *in_data, float *out_data, int *input_shape, int *output_s int data_size = transpose_param->data_size_; int num_axes = transpose_param->num_axes_; - if (num_axes < 2 || num_axes > 4) { + if (num_axes < 2 || num_axes > 5) { return NNACL_ERR; } @@ -119,6 +157,8 @@ int DoTranspose(float *in_data, float *out_data, int *input_shape, int *output_s TransposeDim3(in_data, out_data, strides, out_strides, perm, output_shape); } else if (num_axes == 4) { TransposeDim4(in_data, out_data, strides, out_strides, perm, output_shape); + } else if (num_axes == 5) { + TransposeDim5(in_data, out_data, strides, out_strides, perm, output_shape); } return NNACL_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/transpose.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/transpose.h index 9220bfbd7e..2d3b2d0c50 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/transpose.h +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/transpose.h @@ -34,6 +34,7 @@ int DoTranspose(float *in_data, float *out_data, int *input_shape, int *output_s void TransposeDim2(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape); void TransposeDim3(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape); void TransposeDim4(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape); +void TransposeDim5(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape); #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_TRANSPOSE_H_