|
|
|
@@ -85,6 +85,44 @@ void TransposeDim4(float *in_data, float *out_data, int *strides, int *out_strid |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void TransposeDim5(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape) { |
|
|
|
const int stride0 = strides[perm[0]]; |
|
|
|
const int stride1 = strides[perm[1]]; |
|
|
|
const int stride2 = strides[perm[2]]; |
|
|
|
const int stride3 = strides[perm[3]]; |
|
|
|
const int stride4 = strides[perm[4]]; |
|
|
|
const int out_stride0 = out_strides[0]; |
|
|
|
const int out_stride1 = out_strides[1]; |
|
|
|
const int out_stride2 = out_strides[2]; |
|
|
|
const int out_stride3 = out_strides[3]; |
|
|
|
const int output0 = output_shape[0]; |
|
|
|
const int output1 = output_shape[1]; |
|
|
|
const int output2 = output_shape[2]; |
|
|
|
const int output3 = output_shape[3]; |
|
|
|
const int output4 = output_shape[4]; |
|
|
|
|
|
|
|
for (int i = 0; i < output0; i++) { |
|
|
|
int out_stride0_i = i * out_stride0; |
|
|
|
int stride0_i = i * stride0; |
|
|
|
for (int j = 0; j < output1; j++) { |
|
|
|
int out_stride1_j = j * out_stride1; |
|
|
|
int stride1_j = j * stride1; |
|
|
|
for (int k = 0; k < output2; k++) { |
|
|
|
int out_stride2_k = k * out_stride2; |
|
|
|
int stride2_k = k * stride2; |
|
|
|
for (int m = 0; m < output3; m++) { |
|
|
|
int out_stride3_m = m * out_stride3; |
|
|
|
int stride3_m = m * stride3; |
|
|
|
for (int n = 0; n < output4; n++) { |
|
|
|
out_data[out_stride0_i + out_stride1_j + out_stride2_k + out_stride3_m + n] = |
|
|
|
in_data[stride0_i + stride1_j + stride2_k + stride3_m + n * stride4]; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
int DoTranspose(float *in_data, float *out_data, int *input_shape, int *output_shape, |
|
|
|
TransposeParameter *transpose_param) { |
|
|
|
if (in_data == nullptr || out_data == nullptr) { |
|
|
|
@@ -96,7 +134,7 @@ int DoTranspose(float *in_data, float *out_data, int *input_shape, int *output_s |
|
|
|
int data_size = transpose_param->data_size_; |
|
|
|
int num_axes = transpose_param->num_axes_; |
|
|
|
|
|
|
|
if (num_axes < 2 || num_axes > 4) { |
|
|
|
if (num_axes < 2 || num_axes > 5) { |
|
|
|
return NNACL_ERR; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -119,6 +157,8 @@ int DoTranspose(float *in_data, float *out_data, int *input_shape, int *output_s |
|
|
|
TransposeDim3(in_data, out_data, strides, out_strides, perm, output_shape); |
|
|
|
} else if (num_axes == 4) { |
|
|
|
TransposeDim4(in_data, out_data, strides, out_strides, perm, output_shape); |
|
|
|
} else if (num_axes == 5) { |
|
|
|
TransposeDim5(in_data, out_data, strides, out_strides, perm, output_shape); |
|
|
|
} |
|
|
|
return NNACL_OK; |
|
|
|
} |
|
|
|
|