Browse Source

Modify implementation of 'Transpose' to support the 5-dimension input.

tags/v0.7.0-beta
wsc 5 years ago
parent
commit
6bf7534cb2
2 changed files with 42 additions and 1 deletions
  1. +41
    -1
      mindspore/lite/src/runtime/kernel/arm/nnacl/transpose.cc
  2. +1
    -0
      mindspore/lite/src/runtime/kernel/arm/nnacl/transpose.h

+ 41
- 1
mindspore/lite/src/runtime/kernel/arm/nnacl/transpose.cc View File

@@ -85,6 +85,44 @@ void TransposeDim4(float *in_data, float *out_data, int *strides, int *out_strid
}
}

void TransposeDim5(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape) {
const int stride0 = strides[perm[0]];
const int stride1 = strides[perm[1]];
const int stride2 = strides[perm[2]];
const int stride3 = strides[perm[3]];
const int stride4 = strides[perm[4]];
const int out_stride0 = out_strides[0];
const int out_stride1 = out_strides[1];
const int out_stride2 = out_strides[2];
const int out_stride3 = out_strides[3];
const int output0 = output_shape[0];
const int output1 = output_shape[1];
const int output2 = output_shape[2];
const int output3 = output_shape[3];
const int output4 = output_shape[4];

for (int i = 0; i < output0; i++) {
int out_stride0_i = i * out_stride0;
int stride0_i = i * stride0;
for (int j = 0; j < output1; j++) {
int out_stride1_j = j * out_stride1;
int stride1_j = j * stride1;
for (int k = 0; k < output2; k++) {
int out_stride2_k = k * out_stride2;
int stride2_k = k * stride2;
for (int m = 0; m < output3; m++) {
int out_stride3_m = m * out_stride3;
int stride3_m = m * stride3;
for (int n = 0; n < output4; n++) {
out_data[out_stride0_i + out_stride1_j + out_stride2_k + out_stride3_m + n] =
in_data[stride0_i + stride1_j + stride2_k + stride3_m + n * stride4];
}
}
}
}
}
}

int DoTranspose(float *in_data, float *out_data, int *input_shape, int *output_shape,
TransposeParameter *transpose_param) {
if (in_data == nullptr || out_data == nullptr) {
@@ -96,7 +134,7 @@ int DoTranspose(float *in_data, float *out_data, int *input_shape, int *output_s
int data_size = transpose_param->data_size_;
int num_axes = transpose_param->num_axes_;

if (num_axes < 2 || num_axes > 4) {
if (num_axes < 2 || num_axes > 5) {
return NNACL_ERR;
}

@@ -119,6 +157,8 @@ int DoTranspose(float *in_data, float *out_data, int *input_shape, int *output_s
TransposeDim3(in_data, out_data, strides, out_strides, perm, output_shape);
} else if (num_axes == 4) {
TransposeDim4(in_data, out_data, strides, out_strides, perm, output_shape);
} else if (num_axes == 5) {
TransposeDim5(in_data, out_data, strides, out_strides, perm, output_shape);
}
return NNACL_OK;
}


+ 1
- 0
mindspore/lite/src/runtime/kernel/arm/nnacl/transpose.h View File

@@ -34,6 +34,7 @@ int DoTranspose(float *in_data, float *out_data, int *input_shape, int *output_s
void TransposeDim2(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape);
void TransposeDim3(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape);
void TransposeDim4(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape);
void TransposeDim5(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape);

#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_TRANSPOSE_H_


Loading…
Cancel
Save