You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

exec_sgemv.cpp 2.9 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. #include "src/fallback/matrix_mul/gi/fp32/exec_sgemv.h"
  2. #include "include/megdnn/oprs.h"
  3. #include "src/common/unroll_macro.h"
  4. #include "src/common/utils.h"
  5. #include "src/fallback/general_intrinsic/gi_float.h"
  6. #include "midout.h"
// midout region tag for this file; gi_gemv_like_mk4 wraps its kernel call in
// MIDOUT_BEGIN(megdnn_fp32_gi_sgemv, ...) / MIDOUT_END().
MIDOUT_DECL(megdnn_fp32_gi_sgemv)
using namespace megdnn;
using namespace fallback;
  10. namespace {
  11. void sgemv_gi_naive_n_mk4(
  12. const float* __restrict A, const float* __restrict B, float* __restrict C,
  13. size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride) {
  14. constexpr size_t PACK_SIZE = 4;
  15. megdnn_assert(
  16. N == 1 && Bstride == PACK_SIZE && M % PACK_SIZE == 0 && K % PACK_SIZE == 0);
  17. auto Aptr = A;
  18. auto Cptr = C;
  19. size_t m = 0;
  20. while (m < M) {
  21. auto Aptr0 = Aptr;
  22. auto Cptr0 = Cptr;
  23. GI_FLOAT32_V4_t c;
  24. #define INIT(step) GiSetSubVectorFloat32V4(c, step, GiBroadcastFloat32(0.0f));
  25. UNROLL_CALL_RAW(4, INIT)
  26. #undef INIT
  27. auto Bptr = B;
  28. size_t k = 0;
  29. while (k < K) {
  30. GI_FLOAT32_t b = GiLoadFloat32(Bptr);
  31. GI_FLOAT32_V4_t a;
  32. #define LOAD_A(step) GiSetSubVectorFloat32V4(a, step, GiLoadFloat32(Aptr0 + step * 4));
  33. UNROLL_CALL_RAW(4, LOAD_A)
  34. #undef LOAD_A
  35. #define COMPT(step) \
  36. t = GiSimdFmaLane( \
  37. GiGetSubVectorFloat32V4(c, step), GiGetSubVectorFloat32V4(a, step), b, \
  38. step % 4); \
  39. GiSetSubVectorFloat32V4(c, step, t);
  40. GI_FLOAT32_t t;
  41. UNROLL_CALL_RAW(4, COMPT)
  42. #undef COMPT
  43. Bptr += Bstride;
  44. Aptr0 += PACK_SIZE * PACK_SIZE;
  45. k += PACK_SIZE;
  46. }
  47. #define ADD_C(step, stride) \
  48. t = GiAddFloat32( \
  49. GiGetSubVectorFloat32V4(c, step), \
  50. GiGetSubVectorFloat32V4(c, step + stride)); \
  51. GiSetSubVectorFloat32V4(c, step, t);
  52. GI_FLOAT32_t t;
  53. UNROLL_CALL_RAW(2, ADD_C, 2)
  54. UNROLL_CALL_RAW(1, ADD_C, 1)
  55. #undef ADD_C
  56. GiStoreFloat32(Cptr0, GiGetSubVectorFloat32V4(c, 0));
  57. Aptr += Astride;
  58. Cptr += Cstride;
  59. m += PACK_SIZE;
  60. }
  61. }
  62. } // namespace
  63. namespace megdnn {
  64. namespace fallback {
  65. void gi_gemv_like_mk4(
  66. const float* __restrict A, const float* __restrict B, float* __restrict C,
  67. size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride) {
  68. megdnn_assert(N == 1 && Bstride == 4);
  69. MIDOUT_BEGIN(megdnn_fp32_gi_sgemv, midout_iv("F32_GEMV_NCHW_GI_44_N"_hash)) {
  70. return sgemv_gi_naive_n_mk4(A, B, C, M, N, K, Astride, Bstride, Cstride);
  71. }
  72. MIDOUT_END();
  73. }
  74. } // namespace fallback
  75. } // namespace megdnn
  76. // vim: syntax=cpp.doxygen