#include "src/fallback/matrix_mul/gi/fp32/exec_sgemv.h"
#include "include/megdnn/oprs.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"
#include "src/fallback/general_intrinsic/gi_float.h"

#include "midout.h"

MIDOUT_DECL(megdnn_fp32_gi_sgemv)

using namespace megdnn;
using namespace fallback;

namespace {

/**
 * Naive GI kernel for the MK4-packed matrix-vector product (N == 1).
 *
 * Traversal performed below (PACK_SIZE == 4):
 *  - Output rows are processed in groups of 4. Consecutive groups start
 *    `Astride` floats apart in A and `Cstride` floats apart in C.
 *  - Inside one row group, A is walked as K/4 contiguous tiles of 16 floats.
 *    The 4 floats at tile offset 4*j are combined (via GiSimdFmaLane) with
 *    lane j of the matching 4-float chunk of B.
 *  - B is walked as K/4 chunks of 4 floats, `Bstride` floats apart.
 *  - Each row group writes exactly one 4-float vector to C.
 *
 * NOTE(review): assuming GiSimdFmaLane(acc, a, b, j) computes acc + a * b[j],
 * each output vector is the sum over all tiles of tile_column_j * B_chunk[j].
 *
 * Preconditions (enforced with megdnn_assert): N == 1, Bstride == 4, and
 * M, K both multiples of 4.
 */
void sgemv_gi_naive_n_mk4(
        const float* __restrict A, const float* __restrict B, float* __restrict C,
        size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride) {
    constexpr size_t PACK_SIZE = 4;
    // Keep this expression token-for-token: megdnn_assert may report its text.
    megdnn_assert(
            N == 1 && Bstride == PACK_SIZE && M % PACK_SIZE == 0 &&
            K % PACK_SIZE == 0);

    const float* a_group = A;
    float* c_out = C;
    for (size_t row = 0; row < M; row += PACK_SIZE) {
        // One partial-sum vector per column of the 4x4 tile; they are only
        // folded together after the whole K dimension has been consumed.
        GI_FLOAT32_t acc0 = GiBroadcastFloat32(0.0f);
        GI_FLOAT32_t acc1 = GiBroadcastFloat32(0.0f);
        GI_FLOAT32_t acc2 = GiBroadcastFloat32(0.0f);
        GI_FLOAT32_t acc3 = GiBroadcastFloat32(0.0f);

        const float* tile = a_group;
        const float* b_chunk = B;
        for (size_t col = 0; col < K; col += PACK_SIZE) {
            GI_FLOAT32_t b_vec = GiLoadFloat32(b_chunk);
            // Lane indices must be compile-time constants, hence the explicit
            // unrolling instead of a loop over j.
            acc0 = GiSimdFmaLane(acc0, GiLoadFloat32(tile), b_vec, 0);
            acc1 = GiSimdFmaLane(acc1, GiLoadFloat32(tile + 4), b_vec, 1);
            acc2 = GiSimdFmaLane(acc2, GiLoadFloat32(tile + 8), b_vec, 2);
            acc3 = GiSimdFmaLane(acc3, GiLoadFloat32(tile + 12), b_vec, 3);

            tile += PACK_SIZE * PACK_SIZE;
            b_chunk += Bstride;
        }

        // Pairwise reduction in the order (acc0 + acc2) + (acc1 + acc3).
        GI_FLOAT32_t sum_even = GiAddFloat32(acc0, acc2);
        GI_FLOAT32_t sum_odd = GiAddFloat32(acc1, acc3);
        GiStoreFloat32(c_out, GiAddFloat32(sum_even, sum_odd));

        a_group += Astride;
        c_out += Cstride;
    }
}

}  // namespace

namespace megdnn {
namespace fallback {

/**
 * fp32 GI GEMV entry point for MK4-packed operands with a single B column.
 *
 * Asserts N == 1 and Bstride == 4, then runs the naive GI kernel inside the
 * MIDOUT_BEGIN/MIDOUT_END section tagged "F32_GEMV_NCHW_GI_44_N". The kernel
 * additionally asserts that M and K are multiples of 4.
 */
void gi_gemv_like_mk4(
        const float* __restrict A, const float* __restrict B, float* __restrict C,
        size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride) {
    megdnn_assert(N == 1 && Bstride == 4);
    MIDOUT_BEGIN(megdnn_fp32_gi_sgemv, midout_iv("F32_GEMV_NCHW_GI_44_N"_hash)) {
        sgemv_gi_naive_n_mk4(A, B, C, M, N, K, Astride, Bstride, Cstride);
        // Leave from inside the midout section, exactly as the kernel call did.
        return;
    }
    MIDOUT_END();
}

}  // namespace fallback
}  // namespace megdnn

// vim: syntax=cpp.doxygen