You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

utils.cpp 6.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. /**
  2. * \file dnn/src/x86/utils.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "src/x86/utils.h"
  12. #include <xmmintrin.h>
  13. #include "src/common/utils.h"
  14. #ifdef _WIN32
  15. // For __cpuid
  16. #include <intrin.h>
  17. #endif
  18. #if MEGDNN_X86_WITH_MKL || MEGDNN_X86_WITH_OPENBLAS
  19. #include <pmmintrin.h>
  20. #endif
  21. using namespace megdnn;
  22. using namespace x86;
  23. namespace {
  24. struct CPUID {
  25. uint32_t eax, ebx, ecx, edx;
  26. CPUID() {
  27. #if defined(_WIN32)
  28. int cpuInfo[4];
  29. __cpuid(cpuInfo, 1);
  30. eax = cpuInfo[0];
  31. ebx = cpuInfo[1];
  32. ecx = cpuInfo[2];
  33. edx = cpuInfo[3];
  34. #else
  35. asm volatile("cpuid\n"
  36. : "=a"(eax), "=b"(ebx), "=c"(ecx), "=d"(edx)
  37. : "a"(1)
  38. : "cc");
  39. #endif
  40. }
  41. } cpuid;
  42. bool bit(unsigned x, unsigned y) {
  43. return (x >> y) & 1;
  44. }
  45. MEGDNN_ATTRIBUTE_TARGET("sse")
  46. void transpose4x4_sse(const float* src, float* dst, ptrdiff_t lda, ptrdiff_t ldb) {
  47. __m128 row0 = _mm_loadu_ps(src + 0 * lda);
  48. __m128 row1 = _mm_loadu_ps(src + 1 * lda);
  49. __m128 row2 = _mm_loadu_ps(src + 2 * lda);
  50. __m128 row3 = _mm_loadu_ps(src + 3 * lda);
  51. _MM_TRANSPOSE4_PS(row0, row1, row2, row3);
  52. _mm_storeu_ps(dst + 0 * ldb, row0);
  53. _mm_storeu_ps(dst + 1 * ldb, row1);
  54. _mm_storeu_ps(dst + 2 * ldb, row2);
  55. _mm_storeu_ps(dst + 3 * ldb, row3);
  56. }
  57. void transpose_naive(
  58. const float* src, float* dst, ptrdiff_t lda, ptrdiff_t ldb, size_t n,
  59. size_t m) {
  60. rep(i, n) rep(j, m) { dst[i * ldb + j] = src[j * lda + i]; }
  61. }
  62. bool feature_detect_avx2() {
  63. uint32_t eax, ebx, ecx, edx;
  64. // check cpu support
  65. #if defined(_WIN32)
  66. int cpuInfo[4];
  67. __cpuid(cpuInfo, 7);
  68. eax = cpuInfo[0];
  69. ebx = cpuInfo[1];
  70. ecx = cpuInfo[2];
  71. edx = cpuInfo[3];
  72. #else
  73. asm volatile("cpuid\n"
  74. : "=a"(eax), "=b"(ebx), "=c"(ecx), "=d"(edx)
  75. : "a"(7), "c"(0)
  76. : "cc");
  77. #endif
  78. if (!(bit(ebx, 3) && bit(ebx, 5) && bit(ebx, 8)))
  79. return false;
  80. // check os support
  81. asm volatile("xgetbv" : "=a"(eax), "=d"(edx) : "c"(0));
  82. return (eax & 6) == 6;
  83. }
  84. bool feature_detect_vnni() {
  85. uint32_t eax, ebx, ecx, edx;
  86. // check cpu support
  87. #if defined(_WIN32)
  88. int cpuInfo[4];
  89. __cpuid(cpuInfo, 7);
  90. eax = cpuInfo[0];
  91. ebx = cpuInfo[1];
  92. ecx = cpuInfo[2];
  93. edx = cpuInfo[3];
  94. #else
  95. asm volatile("cpuid\n"
  96. : "=a"(eax), "=b"(ebx), "=c"(ecx), "=d"(edx)
  97. : "a"(7), "c"(0)
  98. : "cc");
  99. #endif
  100. // avx512f ---> 16 ebx
  101. // avx512dq ---> 17 ebx
  102. // avx512bw ---> 30 ebx
  103. // avx512vl ---> 31 ebx
  104. // avx512vnni --->11 ecx
  105. if (!(bit(ebx, 16) && bit(ebx, 17) && bit(ebx, 30) && bit(ebx, 31) && bit(ecx, 11)))
  106. return false;
  107. // check os support
  108. asm volatile("xgetbv" : "=a"(eax), "=d"(edx) : "c"(0));
  109. return (eax & 6) == 6;
  110. }
  111. bool feature_detect_avx_fma(int ftr) {
  112. // see Detecting Availability and Support in
  113. // https://software.intel.com/en-us/articles/introduction-to-intel-advanced-vector-extensions
  114. // check CPU support
  115. if (!(bit(cpuid.ecx, 27) && bit(cpuid.ecx, ftr)))
  116. return false;
  117. // check OS support
  118. uint32_t edx, eax;
  119. asm volatile("xgetbv" : "=a"(eax), "=d"(edx) : "c"(0));
  120. return (eax & 6) == 6;
  121. }
  122. bool is_avx_supported = feature_detect_avx_fma(28);
  123. bool is_fma_supported = feature_detect_avx_fma(12);
  124. bool is_avx2_supported = feature_detect_avx2();
  125. bool is_vnni_supported = feature_detect_vnni();
  126. SIMDType disabled_simd_type_thresh = SIMDType::__NR_SIMD_TYPE;
  127. } // namespace
  128. namespace megdnn {
  129. #ifndef __SSE2__
  130. #error "megdnn at least needs sse2, please compile with -msse2 or higher"
  131. #endif
  132. bool x86::is_supported(SIMDType type) {
  133. if (type >= disabled_simd_type_thresh)
  134. return false;
  135. switch (type) {
  136. case SIMDType::SSE:
  137. return bit(cpuid.edx, 25);
  138. case SIMDType::SSE2:
  139. return bit(cpuid.edx, 26);
  140. case SIMDType::SSE3:
  141. return bit(cpuid.ecx, 0);
  142. case SIMDType::SSE4_1:
  143. return bit(cpuid.ecx, 19);
  144. case SIMDType::SSE4_2:
  145. return bit(cpuid.ecx, 20);
  146. case SIMDType::AVX:
  147. return is_avx_supported;
  148. case SIMDType::FMA:
  149. return is_fma_supported;
  150. case SIMDType::AVX2:
  151. return is_avx2_supported;
  152. case SIMDType::VNNI:
  153. return is_vnni_supported;
  154. default:
  155. break;
  156. }
  157. megdnn_throw("unknown cpu feature");
  158. }
  159. void x86::disable_simd_type(SIMDType type) {
  160. disabled_simd_type_thresh = type;
  161. }
  162. template <>
  163. void transpose(
  164. const float* src, float* dst, size_t m, size_t n, ptrdiff_t lds,
  165. ptrdiff_t ldd) {
  166. if (lds == -1) {
  167. lds = n;
  168. }
  169. if (ldd == -1) {
  170. ldd = m;
  171. }
  172. for (size_t is = 0; is < n; is += 16) {
  173. for (size_t js = 0; js < m; js += 16) {
  174. auto ie = std::min(is + 16, n), je = std::min(js + 16, m), i = is;
  175. for (; i + 4 <= ie; i += 4) {
  176. auto j = js;
  177. for (; j + 4 <= je; j += 4) {
  178. transpose4x4_sse(src + j * lds + i, dst + i * ldd + j, lds, ldd);
  179. }
  180. if (j < je) {
  181. transpose_naive(
  182. src + j * lds + i, dst + i * ldd + j, lds, ldd, 4, je - j);
  183. }
  184. }
  185. if (i < ie) {
  186. transpose_naive(
  187. src + js * lds + i, dst + i * ldd + js, lds, ldd, ie - i,
  188. je - js);
  189. }
  190. }
  191. }
  192. }
  193. template <>
  194. void transpose_knc2nsck(
  195. const float* src, float* dst, size_t k, size_t n, size_t c, size_t n_stride) {
  196. if (n_stride == k * c) {
  197. // dst is contiguous
  198. transpose(src, dst, k, n * c);
  199. } else {
  200. for (size_t i = 0; i < n; ++i) {
  201. transpose(src + i * c, dst + i * n_stride, k, c, n * c);
  202. }
  203. }
  204. }
  205. MEGDNN_ATTRIBUTE_TARGET("sse")
  206. void x86::disable_denorm() {
  207. _mm_setcsr(_mm_getcsr() | (_MM_FLUSH_ZERO_ON | _MM_DENORMALS_ZERO_ON));
  208. }
  209. } // namespace megdnn
  210. // vim: syntax=cpp.doxygen