You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

transpose.c 7.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "nnacl/transpose.h"
  17. #include <string.h>
  18. #include "nnacl/errorcode.h"
  /*
   * Transposes a 2-D float tensor.
   *
   * For every output coordinate (row, col) with row in [0, output_shape[0]) and
   * col in [0, output_shape[1]), copies
   *   in_data[row * strides[perm[0]] + col * strides[perm[1]]]
   * into
   *   out_data[row * output_shape[1] + col].
   *
   * out_strides, h_start and h_end are accepted but not read by this function.
   */
  void TransposeDim2(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape,
                     int h_start, int h_end) {
    const int src_step_row = strides[perm[0]];
    const int src_step_col = strides[perm[1]];
    const int rows = output_shape[0];
    const int cols = output_shape[1];

    for (int row = 0; row < rows; ++row) {
      /* The output row pitch is the output row length, not out_strides[0]. */
      const int dst_row = row * cols;
      const int src_row = row * src_step_row;
      for (int col = 0; col < cols; ++col) {
        const int src_index = src_row + col * src_step_col;
        out_data[dst_row + col] = in_data[src_index];
      }
    }
  }
  /*
   * Transposes a 3-D float tensor.
   *
   * For every output coordinate (a, b, c) within output_shape[0..2], copies
   *   in_data[a * strides[perm[0]] + b * strides[perm[1]] + c * strides[perm[2]]]
   * into
   *   out_data[a * out_strides[0] + b * out_strides[1] + c].
   *
   * h_start and h_end are accepted but not read by this function.
   */
  void TransposeDim3(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape,
                     int h_start, int h_end) {
    const int src_step0 = strides[perm[0]];
    const int src_step1 = strides[perm[1]];
    const int src_step2 = strides[perm[2]];
    const int dst_step0 = out_strides[0];
    const int dst_step1 = out_strides[1];
    const int extent0 = output_shape[0];
    const int extent1 = output_shape[1];
    const int extent2 = output_shape[2];

    for (int a = 0; a < extent0; ++a) {
      const int dst_a = a * dst_step0;
      const int src_a = a * src_step0;
      for (int b = 0; b < extent1; ++b) {
        const int dst_ab = dst_a + b * dst_step1;
        const int src_ab = src_a + b * src_step1;
        /* The innermost output axis is written contiguously (unit stride). */
        for (int c = 0; c < extent2; ++c) {
          out_data[dst_ab + c] = in_data[src_ab + c * src_step2];
        }
      }
    }
  }
  /*
   * Transposes a 4-D float tensor.
   *
   * For every output coordinate (a, b, c, d) within output_shape[0..3], copies
   *   in_data[a * strides[perm[0]] + b * strides[perm[1]] +
   *           c * strides[perm[2]] + d * strides[perm[3]]]
   * into
   *   out_data[a * out_strides[0] + b * out_strides[1] + c * out_strides[2] + d].
   *
   * h_start and h_end are accepted but not read by this function.
   */
  void TransposeDim4(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape,
                     int h_start, int h_end) {
    const int src_step0 = strides[perm[0]];
    const int src_step1 = strides[perm[1]];
    const int src_step2 = strides[perm[2]];
    const int src_step3 = strides[perm[3]];
    const int dst_step0 = out_strides[0];
    const int dst_step1 = out_strides[1];
    const int dst_step2 = out_strides[2];
    const int extent0 = output_shape[0];
    const int extent1 = output_shape[1];
    const int extent2 = output_shape[2];
    const int extent3 = output_shape[3];

    for (int a = 0; a < extent0; ++a) {
      const int dst_a = a * dst_step0;
      const int src_a = a * src_step0;
      for (int b = 0; b < extent1; ++b) {
        const int dst_ab = dst_a + b * dst_step1;
        const int src_ab = src_a + b * src_step1;
        for (int c = 0; c < extent2; ++c) {
          const int dst_abc = dst_ab + c * dst_step2;
          const int src_abc = src_ab + c * src_step2;
          /* The innermost output axis is written contiguously (unit stride). */
          for (int d = 0; d < extent3; ++d) {
            out_data[dst_abc + d] = in_data[src_abc + d * src_step3];
          }
        }
      }
    }
  }
  /*
   * Transposes a 5-D float tensor.
   *
   * For every output coordinate (a, b, c, d, e) within output_shape[0..4],
   * copies
   *   in_data[a * strides[perm[0]] + b * strides[perm[1]] + c * strides[perm[2]] +
   *           d * strides[perm[3]] + e * strides[perm[4]]]
   * into
   *   out_data[a * out_strides[0] + b * out_strides[1] + c * out_strides[2] +
   *            d * out_strides[3] + e].
   *
   * h_start and h_end are accepted but not read by this function.
   */
  void TransposeDim5(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape,
                     int h_start, int h_end) {
    const int src_step0 = strides[perm[0]];
    const int src_step1 = strides[perm[1]];
    const int src_step2 = strides[perm[2]];
    const int src_step3 = strides[perm[3]];
    const int src_step4 = strides[perm[4]];
    const int dst_step0 = out_strides[0];
    const int dst_step1 = out_strides[1];
    const int dst_step2 = out_strides[2];
    const int dst_step3 = out_strides[3];
    const int extent0 = output_shape[0];
    const int extent1 = output_shape[1];
    const int extent2 = output_shape[2];
    const int extent3 = output_shape[3];
    const int extent4 = output_shape[4];

    for (int a = 0; a < extent0; ++a) {
      const int dst_a = a * dst_step0;
      const int src_a = a * src_step0;
      for (int b = 0; b < extent1; ++b) {
        const int dst_ab = dst_a + b * dst_step1;
        const int src_ab = src_a + b * src_step1;
        for (int c = 0; c < extent2; ++c) {
          const int dst_abc = dst_ab + c * dst_step2;
          const int src_abc = src_ab + c * src_step2;
          for (int d = 0; d < extent3; ++d) {
            const int dst_abcd = dst_abc + d * dst_step3;
            const int src_abcd = src_abc + d * src_step3;
            /* The innermost output axis is written contiguously (unit stride). */
            for (int e = 0; e < extent4; ++e) {
              out_data[dst_abcd + e] = in_data[src_abcd + e * src_step4];
            }
          }
        }
      }
    }
  }
  /*
   * Transposes a float tensor of arbitrary rank `dims`.
   *
   * size and position are caller-provided scratch arrays with at least `dims`
   * ints each:
   *   - size[i] is filled with the product of output_shape[i + 1 .. dims - 1]
   *     (size[dims - 1] is 1), i.e. the number of output elements spanned by
   *     one step along output axis i.
   *   - position[i] holds the coordinate along output axis i of the element
   *     currently being copied.
   *
   * Each flat output counter is decomposed into coordinates using size[], then
   * the element is read from sum(position[i] * strides[perm[i]]) and written
   * to sum(position[i] * out_strides[i]), with the last axis using stride 1
   * (out_strides[dims - 1] is never read).
   *
   * The call is a no-op when dims < 1 or when size/position is NULL; this
   * avoids writing before the start of the scratch buffer.
   *
   * h_start and h_end are accepted but not read by this function.
   */
  void TransposeDims(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape,
                     int h_start, int h_end, int dims, int *size, int *position) {
    if (dims < 1 || size == NULL || position == NULL) {
      return;
    }

    size[dims - 1] = 1;
    for (int i = dims - 1; i > 0; --i) {
      size[i - 1] = size[i] * output_shape[i];
    }

    /*
     * Hoisted out of the loop condition and kept signed: a non-positive element
     * count now means "nothing to copy" rather than being converted to a huge
     * unsigned bound.
     */
    const int total = size[0] * output_shape[0];
    for (int idx = 0; idx < total; ++idx) {
      int remaining = idx;
      int output_idx = 0;
      int input_idx = 0;
      for (int i = 0; i < dims; ++i) {
        const int coord = remaining / size[i];
        const int out_stride = (i < dims - 1) ? out_strides[i] : 1;
        position[i] = coord;
        output_idx += coord * out_stride;
        input_idx += coord * strides[perm[i]];
        remaining -= coord * size[i];
      }
      out_data[output_idx] = in_data[input_idx];
    }
  }
  142. int DoTranspose(float *in_data, float *out_data, int *input_shape, int *output_shape,
  143. TransposeParameter *transpose_param, int h_start, int h_end, int *size, int *position) {
  144. if (in_data == NULL || out_data == NULL) {
  145. return NNACL_ERR;
  146. }
  147. int *perm = transpose_param->perm_;
  148. int *strides = transpose_param->strides_;
  149. int *out_strides = transpose_param->out_strides_;
  150. int data_size = transpose_param->data_size_;
  151. int num_axes = transpose_param->num_axes_;
  152. if (num_axes < 2) {
  153. return NNACL_ERR;
  154. }
  155. // check if transpose is needed
  156. bool needTranspose = false;
  157. for (int i = 1; i < num_axes; ++i) {
  158. if (perm[i] - perm[i - 1] != 1) {
  159. needTranspose = true;
  160. break;
  161. }
  162. }
  163. if (!needTranspose) {
  164. (void)memcpy(out_data, in_data, data_size);
  165. return NNACL_OK;
  166. }
  167. if (num_axes == 2) {
  168. TransposeDim2(in_data, out_data, strides, out_strides, perm, output_shape, h_start, h_end);
  169. } else if (num_axes == 3) {
  170. TransposeDim3(in_data, out_data, strides, out_strides, perm, output_shape, h_start, h_end);
  171. } else if (num_axes == 4) {
  172. TransposeDim4(in_data, out_data, strides, out_strides, perm, output_shape, h_start, h_end);
  173. } else if (num_axes == 5) {
  174. TransposeDim5(in_data, out_data, strides, out_strides, perm, output_shape, h_start, h_end);
  175. } else {
  176. TransposeDims(in_data, out_data, strides, out_strides, perm, output_shape, h_start, h_end, num_axes, size,
  177. position);
  178. }
  179. return NNACL_OK;
  180. }