You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.cpp 6.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. /**
  2. * \file dnn/src/arm_common/utils.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "src/common/utils.h"
  12. #include <cstring>
  13. #include "src/arm_common/simd_macro/marm_neon.h"
  14. using namespace megdnn;
  15. namespace {
  16. template <typename dtype>
  17. void transpose_naive(const dtype *src, dtype *dst,
  18. int lda, int ldb, int n, int m)
  19. {
  20. rep(i, n) rep(j, m) {
  21. dst[i*ldb + j] = src[j*lda + i];
  22. }
  23. }
  24. void transpose_4x4_neon(const float *src, float *dst, int lda, int ldb)
  25. {
  26. float32x4x2_t a0, a1;
  27. a0.val[0] = vld1q_f32(src + 0*lda);
  28. a0.val[1] = vld1q_f32(src + 1*lda);
  29. a1.val[0] = vld1q_f32(src + 2*lda);
  30. a1.val[1] = vld1q_f32(src + 3*lda);
  31. float32x4x2_t b0 = vzipq_f32(a0.val[0], a1.val[0]);
  32. float32x4x2_t b1 = vzipq_f32(a0.val[1], a1.val[1]);
  33. float32x4x2_t c0 = vzipq_f32(b0.val[0], b1.val[0]);
  34. float32x4x2_t c1 = vzipq_f32(b0.val[1], b1.val[1]);
  35. vst1q_f32(dst + 0*ldb, c0.val[0]);
  36. vst1q_f32(dst + 1*ldb, c0.val[1]);
  37. vst1q_f32(dst + 2*ldb, c1.val[0]);
  38. vst1q_f32(dst + 3*ldb, c1.val[1]);
  39. }
  40. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  41. void transpose_8x8_neon(const dt_float16 *src, dt_float16 *dst, int lda, int ldb)
  42. {
  43. const __fp16* src_ptr = reinterpret_cast<const __fp16*>(src);
  44. __fp16* dst_ptr = reinterpret_cast<__fp16*>(dst);
  45. float16x8x4_t a0, a1;
  46. a0.val[0] = vld1q_f16(src_ptr + 0*lda); // A0A1A2A3A4A5A6A7
  47. a0.val[1] = vld1q_f16(src_ptr + 1*lda); // B0B1B2B3B4B5B6B7
  48. a0.val[2] = vld1q_f16(src_ptr + 2*lda); // C0C1C2C3C4C5C6C7
  49. a0.val[3] = vld1q_f16(src_ptr + 3*lda); // D0D1D2D3D4D5D6D7
  50. a1.val[0] = vld1q_f16(src_ptr + 4*lda); // E0E1E2E3E4E5E6E7
  51. a1.val[1] = vld1q_f16(src_ptr + 5*lda); // F0F1F2F3F4F5F6F7
  52. a1.val[2] = vld1q_f16(src_ptr + 6*lda); // G0G1G2G3G4G5G6G7
  53. a1.val[3] = vld1q_f16(src_ptr + 7*lda); // H0H1H2H3H4H5H6H7
  54. float16x8x2_t b0 = vzipq_f16(a0.val[0], a1.val[0]); // A0E0A1E1A2E2A3E3 A4E4A5E5A6E6A7E7
  55. float16x8x2_t b1 = vzipq_f16(a0.val[2], a1.val[2]); // C0G0C1G1C2G2C3G3 C4G4C5G5C6G6C7G7
  56. float16x8x2_t c0 = vzipq_f16(a0.val[1], a1.val[1]); // B0F0B1F1B2F2B3F3 B4F4B5F5B6F6B7F7
  57. float16x8x2_t c1 = vzipq_f16(a0.val[3], a1.val[3]); // D0H0D1H1D2H2D3H3 D4H4D5H5D6H6D7H7
  58. float16x8x2_t d0 = vzipq_f16(b0.val[0], b1.val[0]); // A0C0E0G0A1C1E1G1 A2C2E2G2A3C3E3G3
  59. float16x8x2_t d1 = vzipq_f16(c0.val[0], c1.val[0]); // B0D0F0H0B1D1F1H1 B2D2F2H2B3D3F3H3
  60. float16x8x2_t e0 = vzipq_f16(d0.val[0], d1.val[0]); // A0B0C0D0E0F0G0H0 A1B1C1D1E1F1G1H1
  61. float16x8x2_t e1 = vzipq_f16(d0.val[1], d1.val[1]); // A2B2C2D2E2F2G2H2 A3B3C3D3E3F3G3H3
  62. float16x8x2_t f0 = vzipq_f16(b0.val[1], b1.val[1]); // A4C4E4G4A5C5E5G5 A6C6E6G6A7C7E7G7
  63. float16x8x2_t f1 = vzipq_f16(c0.val[1], c1.val[1]); // B4D4F4H4B5D5F5H5 B6D6E6G6B7D7E7H7
  64. float16x8x2_t g0 = vzipq_f16(f0.val[0], f1.val[0]); // A4B4C4D4E4F4G4H4 A5B5C5D5E5F5G5H5
  65. float16x8x2_t g1 = vzipq_f16(f0.val[1], f1.val[1]); // A6B6C6D6E6F6G6H6 A7B7C7D7E7F7G7H7
  66. vst1q_f16(dst_ptr + 0*ldb, e0.val[0]);
  67. vst1q_f16(dst_ptr + 1*ldb, e0.val[1]);
  68. vst1q_f16(dst_ptr + 2*ldb, e1.val[0]);
  69. vst1q_f16(dst_ptr + 3*ldb, e1.val[1]);
  70. vst1q_f16(dst_ptr + 4*ldb, g0.val[0]);
  71. vst1q_f16(dst_ptr + 5*ldb, g0.val[1]);
  72. vst1q_f16(dst_ptr + 6*ldb, g1.val[0]);
  73. vst1q_f16(dst_ptr + 7*ldb, g1.val[1]);
  74. }
  75. #endif
  76. } // anonymous namespace
  77. namespace megdnn {
  78. template <>
  79. void transpose(const float* src, float* dst, size_t m, size_t n, ptrdiff_t lds,
  80. ptrdiff_t ldd) {
  81. if (lds == -1) {
  82. lds = n;
  83. }
  84. if (ldd == -1) {
  85. ldd = m;
  86. }
  87. for (size_t is = 0; is < n; is += 16) {
  88. for (size_t js = 0; js < m; js += 16) {
  89. auto ie = std::min(is + 16, n), je = std::min(js + 16, m), i = is;
  90. for (; i + 4 <= ie; i += 4) {
  91. auto j = js;
  92. for (; j + 4 <= je; j += 4) {
  93. transpose_4x4_neon(src + j * lds + i, dst + i * ldd + j,
  94. lds, ldd);
  95. }
  96. if (j < je) {
  97. transpose_naive(src + j * lds + i, dst + i * ldd + j, lds,
  98. ldd, 4, je - j);
  99. }
  100. }
  101. if (i < ie) {
  102. transpose_naive(src + js * lds + i, dst + i * ldd + js, lds,
  103. ldd, ie - i, je - js);
  104. }
  105. }
  106. }
  107. }
  108. template<typename dtype>
  109. void transpose_knc2nsck_helper(const dtype *src, dtype *dst,
  110. size_t k, size_t n, size_t c, size_t n_stride) {
  111. if (n_stride == k * c) {
  112. // dst is contiguous
  113. transpose(src, dst, k, n * c);
  114. } else {
  115. for (size_t i = 0; i < n; ++ i) {
  116. transpose(src + i * c, dst + i * n_stride,
  117. k, c, n * c);
  118. }
  119. }
  120. }
  121. template <>
  122. void transpose_knc2nsck(const float *src, float *dst,
  123. size_t k, size_t n, size_t c, size_t n_stride) {
  124. transpose_knc2nsck_helper(src, dst, k, n, c, n_stride);
  125. }
  126. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  127. template <>
  128. void transpose(const dt_float16* src, dt_float16* dst, size_t m, size_t n,
  129. ptrdiff_t lds, ptrdiff_t ldd) {
  130. if (lds == -1) {
  131. lds = n;
  132. }
  133. if (ldd == -1) {
  134. ldd = m;
  135. }
  136. for (size_t is = 0; is < n; is += 16) {
  137. for (size_t js = 0; js < m; js += 16) {
  138. auto ie = std::min(is + 16, n), je = std::min(js + 16, m), i = is;
  139. for (; i + 8 <= ie; i += 8) {
  140. auto j = js;
  141. for (; j + 8 <= je; j += 8) {
  142. transpose_8x8_neon(src + j * lds + i, dst + i * ldd + j,
  143. lds, ldd);
  144. }
  145. if (j < je) {
  146. transpose_naive(src + j * lds + i, dst + i * ldd + j, lds,
  147. ldd, 8, je - j);
  148. }
  149. }
  150. if (i < ie) {
  151. transpose_naive(src + js * lds + i, dst + i * ldd + js, lds,
  152. ldd, ie - i, je - js);
  153. }
  154. }
  155. }
  156. }
  157. template <>
  158. void transpose_knc2nsck(const dt_float16* src, dt_float16* dst, size_t k,
  159. size_t n, size_t c, size_t n_stride) {
  160. transpose_knc2nsck_helper(src, dst, k, n, c, n_stride);
  161. }
  162. #endif
  163. } // namespace megdnn
  164. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，无需区分 CPU 版和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并已安装好驱动。如果你想体验在云端 GPU 算力平台上进行深度学习开发，欢迎访问 MegStudio 平台。