You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.cpp 6.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235
  1. #include "src/x86/utils.h"
  2. #include <xmmintrin.h>
  3. #include "src/common/utils.h"
  4. #ifdef _WIN32
  5. // For __cpuid
  6. #include <intrin.h>
  7. #endif
  8. #if MEGDNN_X86_WITH_MKL || MEGDNN_X86_WITH_OPENBLAS
  9. #include <pmmintrin.h>
  10. #endif
  11. using namespace megdnn;
  12. using namespace x86;
  13. namespace {
  14. struct CPUID {
  15. uint32_t eax, ebx, ecx, edx;
  16. CPUID() {
  17. #if defined(_WIN32)
  18. int cpuInfo[4];
  19. __cpuid(cpuInfo, 1);
  20. eax = cpuInfo[0];
  21. ebx = cpuInfo[1];
  22. ecx = cpuInfo[2];
  23. edx = cpuInfo[3];
  24. #else
  25. asm volatile("cpuid\n"
  26. : "=a"(eax), "=b"(ebx), "=c"(ecx), "=d"(edx)
  27. : "a"(1)
  28. : "cc");
  29. #endif
  30. }
  31. } cpuid;
  32. bool bit(unsigned x, unsigned y) {
  33. return (x >> y) & 1;
  34. }
  35. MEGDNN_ATTRIBUTE_TARGET("sse")
  36. void transpose4x4_sse(const float* src, float* dst, ptrdiff_t lda, ptrdiff_t ldb) {
  37. __m128 row0 = _mm_loadu_ps(src + 0 * lda);
  38. __m128 row1 = _mm_loadu_ps(src + 1 * lda);
  39. __m128 row2 = _mm_loadu_ps(src + 2 * lda);
  40. __m128 row3 = _mm_loadu_ps(src + 3 * lda);
  41. _MM_TRANSPOSE4_PS(row0, row1, row2, row3);
  42. _mm_storeu_ps(dst + 0 * ldb, row0);
  43. _mm_storeu_ps(dst + 1 * ldb, row1);
  44. _mm_storeu_ps(dst + 2 * ldb, row2);
  45. _mm_storeu_ps(dst + 3 * ldb, row3);
  46. }
  47. void transpose_naive(
  48. const float* src, float* dst, ptrdiff_t lda, ptrdiff_t ldb, size_t n,
  49. size_t m) {
  50. rep(i, n) rep(j, m) { dst[i * ldb + j] = src[j * lda + i]; }
  51. }
  52. bool feature_detect_avx2() {
  53. uint32_t eax, ebx, ecx, edx;
  54. // check cpu support
  55. #if defined(_WIN32)
  56. int cpuInfo[4];
  57. __cpuid(cpuInfo, 7);
  58. eax = cpuInfo[0];
  59. ebx = cpuInfo[1];
  60. ecx = cpuInfo[2];
  61. edx = cpuInfo[3];
  62. #else
  63. asm volatile("cpuid\n"
  64. : "=a"(eax), "=b"(ebx), "=c"(ecx), "=d"(edx)
  65. : "a"(7), "c"(0)
  66. : "cc");
  67. #endif
  68. if (!(bit(ebx, 3) && bit(ebx, 5) && bit(ebx, 8)))
  69. return false;
  70. // check os support
  71. asm volatile("xgetbv" : "=a"(eax), "=d"(edx) : "c"(0));
  72. return (eax & 6) == 6;
  73. }
  74. bool feature_detect_vnni() {
  75. uint32_t eax, ebx, ecx, edx;
  76. // check cpu support
  77. #if defined(_WIN32)
  78. int cpuInfo[4];
  79. __cpuid(cpuInfo, 7);
  80. eax = cpuInfo[0];
  81. ebx = cpuInfo[1];
  82. ecx = cpuInfo[2];
  83. edx = cpuInfo[3];
  84. #else
  85. asm volatile("cpuid\n"
  86. : "=a"(eax), "=b"(ebx), "=c"(ecx), "=d"(edx)
  87. : "a"(7), "c"(0)
  88. : "cc");
  89. #endif
  90. // avx512f ---> 16 ebx
  91. // avx512dq ---> 17 ebx
  92. // avx512bw ---> 30 ebx
  93. // avx512vl ---> 31 ebx
  94. // avx512vnni --->11 ecx
  95. if (!(bit(ebx, 16) && bit(ebx, 17) && bit(ebx, 30) && bit(ebx, 31) && bit(ecx, 11)))
  96. return false;
  97. // check os support
  98. asm volatile("xgetbv" : "=a"(eax), "=d"(edx) : "c"(0));
  99. return (eax & 6) == 6;
  100. }
  101. bool feature_detect_avx_fma(int ftr) {
  102. // see Detecting Availability and Support in
  103. // https://software.intel.com/en-us/articles/introduction-to-intel-advanced-vector-extensions
  104. // check CPU support
  105. if (!(bit(cpuid.ecx, 27) && bit(cpuid.ecx, ftr)))
  106. return false;
  107. // check OS support
  108. uint32_t edx, eax;
  109. asm volatile("xgetbv" : "=a"(eax), "=d"(edx) : "c"(0));
  110. return (eax & 6) == 6;
  111. }
  112. bool is_avx_supported = feature_detect_avx_fma(28);
  113. bool is_fma_supported = feature_detect_avx_fma(12);
  114. bool is_avx2_supported = feature_detect_avx2();
  115. bool is_vnni_supported = feature_detect_vnni();
  116. SIMDType disabled_simd_type_thresh = SIMDType::__NR_SIMD_TYPE;
  117. } // namespace
  118. namespace megdnn {
  119. #ifndef __SSE2__
  120. #error "megdnn at least needs sse2, please compile with -msse2 or higher"
  121. #endif
  122. bool x86::is_supported(SIMDType type) {
  123. if (type >= disabled_simd_type_thresh)
  124. return false;
  125. switch (type) {
  126. case SIMDType::SSE:
  127. return bit(cpuid.edx, 25);
  128. case SIMDType::SSE2:
  129. return bit(cpuid.edx, 26);
  130. case SIMDType::SSE3:
  131. return bit(cpuid.ecx, 0);
  132. case SIMDType::SSE4_1:
  133. return bit(cpuid.ecx, 19);
  134. case SIMDType::SSE4_2:
  135. return bit(cpuid.ecx, 20);
  136. case SIMDType::AVX:
  137. return is_avx_supported;
  138. case SIMDType::FMA:
  139. return is_fma_supported;
  140. case SIMDType::AVX2:
  141. return is_avx2_supported;
  142. case SIMDType::VNNI:
  143. return is_vnni_supported;
  144. default:
  145. break;
  146. }
  147. megdnn_throw("unknown cpu feature");
  148. }
  149. void x86::disable_simd_type(SIMDType type) {
  150. disabled_simd_type_thresh = type;
  151. }
  152. template <>
  153. void transpose(
  154. const float* src, float* dst, size_t m, size_t n, ptrdiff_t lds,
  155. ptrdiff_t ldd) {
  156. if (lds == -1) {
  157. lds = n;
  158. }
  159. if (ldd == -1) {
  160. ldd = m;
  161. }
  162. for (size_t is = 0; is < n; is += 16) {
  163. for (size_t js = 0; js < m; js += 16) {
  164. auto ie = std::min(is + 16, n), je = std::min(js + 16, m), i = is;
  165. for (; i + 4 <= ie; i += 4) {
  166. auto j = js;
  167. for (; j + 4 <= je; j += 4) {
  168. transpose4x4_sse(src + j * lds + i, dst + i * ldd + j, lds, ldd);
  169. }
  170. if (j < je) {
  171. transpose_naive(
  172. src + j * lds + i, dst + i * ldd + j, lds, ldd, 4, je - j);
  173. }
  174. }
  175. if (i < ie) {
  176. transpose_naive(
  177. src + js * lds + i, dst + i * ldd + js, lds, ldd, ie - i,
  178. je - js);
  179. }
  180. }
  181. }
  182. }
  183. template <>
  184. void transpose_knc2nsck(
  185. const float* src, float* dst, size_t k, size_t n, size_t c, size_t n_stride) {
  186. if (n_stride == k * c) {
  187. // dst is contiguous
  188. transpose(src, dst, k, n * c);
  189. } else {
  190. for (size_t i = 0; i < n; ++i) {
  191. transpose(src + i * c, dst + i * n_stride, k, c, n * c);
  192. }
  193. }
  194. }
  195. MEGDNN_ATTRIBUTE_TARGET("sse")
  196. void x86::disable_denorm() {
  197. _mm_setcsr(_mm_getcsr() | (_MM_FLUSH_ZERO_ON | _MM_DENORMALS_ZERO_ON));
  198. }
  199. } // namespace megdnn
  200. // vim: syntax=cpp.doxygen