GitOrigin-RevId: 78c3e72218
tags/v1.3.0
| @@ -67,6 +67,23 @@ WorkspaceBundle ConvBiasImpl::AlgoS8MatrixMul::get_bundle( | |||||
| size_t K = IC * FH * FW; | size_t K = IC * FH * FW; | ||||
| size_t N = OH * OW; | size_t N = OH * OW; | ||||
| #if MGB_ENABLE_DOT | |||||
| #define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ | |||||
| _bias_midout_enum, _nonline, \ | |||||
| _nonline_midout_enum) \ | |||||
| matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | |||||
| M, N, K, param.filter_type, param.src_type, param.dst_type); \ | |||||
| part2 = megdnn::matmul::GemmInterleaved< \ | |||||
| matmul::gemm_##_gemm##_##_bias##_##_nonline>( \ | |||||
| M, N, K, false, false, strategy) \ | |||||
| .get_workspace_size(); | |||||
| if (cpuinfo_has_arm_neon_dot()) { | |||||
| DISPATCH_GEMM_BIAS(s8_8x12, 1) | |||||
| } else { | |||||
| DISPATCH_GEMM_BIAS(s8_4x4, 0) | |||||
| } | |||||
| #else | |||||
| #define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ | #define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ | ||||
| _bias_midout_enum, _nonline, \ | _bias_midout_enum, _nonline, \ | ||||
| _nonline_midout_enum) \ | _nonline_midout_enum) \ | ||||
| @@ -80,11 +97,7 @@ WorkspaceBundle ConvBiasImpl::AlgoS8MatrixMul::get_bundle( | |||||
| .get_workspace_size(); \ | .get_workspace_size(); \ | ||||
| } \ | } \ | ||||
| MIDOUT_END() | MIDOUT_END() | ||||
| #if !(__ARM_FEATURE_DOTPROD) | |||||
| DISPATCH_GEMM_BIAS(s8_4x4, 0) | DISPATCH_GEMM_BIAS(s8_4x4, 0) | ||||
| #else | |||||
| DISPATCH_GEMM_BIAS(s8_8x12, 1) | |||||
| #endif | #endif | ||||
| #undef DISPATCH_GEMM_STRATEGY | #undef DISPATCH_GEMM_STRATEGY | ||||
| } | } | ||||
| @@ -158,6 +171,23 @@ void ConvBiasImpl::AlgoS8MatrixMul::kimpl(const NCBKernParam& param, | |||||
| size_t K = IC * FH * FW; | size_t K = IC * FH * FW; | ||||
| size_t N = OH * OW; | size_t N = OH * OW; | ||||
| #if MGB_ENABLE_DOT | |||||
| #define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ | |||||
| _bias_midout_enum, _nonline, \ | |||||
| _nonline_midout_enum) \ | |||||
| matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | |||||
| M, N, K, param.filter_type, param.src_type, param.dst_type); \ | |||||
| megdnn::matmul::GemmInterleaved< \ | |||||
| matmul::gemm_##_gemm##_##_bias##_##_nonline> \ | |||||
| gemm_interleaved(M, N, K, false, false, strategy); \ | |||||
| gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, bias); | |||||
| if (cpuinfo_has_arm_neon_dot()) { | |||||
| DISPATCH_GEMM_BIAS(s8_8x12, 1) | |||||
| } else { | |||||
| DISPATCH_GEMM_BIAS(s8_4x4, 0) | |||||
| } | |||||
| #else | |||||
| #define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ | #define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ | ||||
| _bias_midout_enum, _nonline, \ | _bias_midout_enum, _nonline, \ | ||||
| _nonline_midout_enum) \ | _nonline_midout_enum) \ | ||||
| @@ -172,11 +202,7 @@ void ConvBiasImpl::AlgoS8MatrixMul::kimpl(const NCBKernParam& param, | |||||
| bias); \ | bias); \ | ||||
| } \ | } \ | ||||
| MIDOUT_END() | MIDOUT_END() | ||||
| #if !(__ARM_FEATURE_DOTPROD) | |||||
| DISPATCH_GEMM_BIAS(s8_4x4, 0) | DISPATCH_GEMM_BIAS(s8_4x4, 0) | ||||
| #else | |||||
| DISPATCH_GEMM_BIAS(s8_8x12, 1) | |||||
| #endif | #endif | ||||
| #undef DISPATCH_GEMM_STRATEGY | #undef DISPATCH_GEMM_STRATEGY | ||||
| } | } | ||||
| @@ -26,7 +26,7 @@ namespace impl { | |||||
| template <BiasMode bmode, typename Op, int block_m, int block_n> | template <BiasMode bmode, typename Op, int block_m, int block_n> | ||||
| struct KernCaller; | struct KernCaller; | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| template <BiasMode bmode, typename Op> | template <BiasMode bmode, typename Op> | ||||
| struct KernCaller<bmode, Op, 8, 12> { | struct KernCaller<bmode, Op, 8, 12> { | ||||
| static void run(const dt_int8* packA, const dt_int8* packB, size_t M, | static void run(const dt_int8* packA, const dt_int8* packB, size_t M, | ||||
| @@ -118,7 +118,7 @@ struct KernCaller<bmode, Op, 8, 12> { | |||||
| } | } | ||||
| }; | }; | ||||
| #else | |||||
| #endif | |||||
| template <BiasMode bmode, typename Op> | template <BiasMode bmode, typename Op> | ||||
| struct KernCaller<bmode, Op, 4, 4> { | struct KernCaller<bmode, Op, 4, 4> { | ||||
| @@ -196,10 +196,8 @@ struct KernCaller<bmode, Op, 4, 4> { | |||||
| } | } | ||||
| }; | }; | ||||
| #endif | |||||
| } // namespace impl | } // namespace impl | ||||
| #if !(__ARM_FEATURE_DOTPROD) | |||||
| MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_4x4_nobias_identity) | MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_4x4_nobias_identity) | ||||
| void gemm_s8_4x4_nobias_identity::pack_A(dt_int8* outptr, const dt_int8* inptr, | void gemm_s8_4x4_nobias_identity::pack_A(dt_int8* outptr, const dt_int8* inptr, | ||||
| @@ -227,7 +225,8 @@ void gemm_s8_4x4_nobias_identity::pack_B(dt_int8* out, const dt_int8* in, | |||||
| size_t gemm_s8_4x4_nobias_identity::get_workspace_size() const { | size_t gemm_s8_4x4_nobias_identity::get_workspace_size() const { | ||||
| return 4 * 4 * sizeof(dt_int32); | return 4 * 4 * sizeof(dt_int32); | ||||
| } | } | ||||
| #else | |||||
| #if MGB_ENABLE_DOT | |||||
| MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_8x12_nobias_identity) | MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_8x12_nobias_identity) | ||||
| void gemm_s8_8x12_nobias_identity::pack_A(dt_int8* outptr, const dt_int8* inptr, | void gemm_s8_8x12_nobias_identity::pack_A(dt_int8* outptr, const dt_int8* inptr, | ||||
| @@ -277,11 +276,10 @@ size_t gemm_s8_8x12_nobias_identity::get_workspace_size() const { | |||||
| #define DEFINE_OP(_Op) \ | #define DEFINE_OP(_Op) \ | ||||
| arm_common::_Op<dt_qint32, dt_qint8> op(scale_A* scale_B, scale_C); | arm_common::_Op<dt_qint32, dt_qint8> op(scale_A* scale_B, scale_C); | ||||
| #if !(__ARM_FEATURE_DOTPROD) | |||||
| KERN(4, 4, nobias, BiasMode::NO_BIAS, identity, TypeCvtOp) | KERN(4, 4, nobias, BiasMode::NO_BIAS, identity, TypeCvtOp) | ||||
| KERN(4, 4, nobias, BiasMode::NO_BIAS, relu, ReluOp) | KERN(4, 4, nobias, BiasMode::NO_BIAS, relu, ReluOp) | ||||
| KERN(4, 4, nobias, BiasMode::NO_BIAS, hswish, HSwishOp) | KERN(4, 4, nobias, BiasMode::NO_BIAS, hswish, HSwishOp) | ||||
| #else | |||||
| #if MGB_ENABLE_DOT | |||||
| KERN(8, 12, nobias, BiasMode::NO_BIAS, identity, TypeCvtOp) | KERN(8, 12, nobias, BiasMode::NO_BIAS, identity, TypeCvtOp) | ||||
| KERN(8, 12, nobias, BiasMode::NO_BIAS, relu, ReluOp) | KERN(8, 12, nobias, BiasMode::NO_BIAS, relu, ReluOp) | ||||
| KERN(8, 12, nobias, BiasMode::NO_BIAS, hswish, HSwishOp) | KERN(8, 12, nobias, BiasMode::NO_BIAS, hswish, HSwishOp) | ||||
| @@ -291,12 +289,11 @@ KERN(8, 12, nobias, BiasMode::NO_BIAS, hswish, HSwishOp) | |||||
| #define DEFINE_OP(_Op) \ | #define DEFINE_OP(_Op) \ | ||||
| arm_common::_Op<dt_qint32, dt_qint8> op(scale_A* scale_B, \ | arm_common::_Op<dt_qint32, dt_qint8> op(scale_A* scale_B, \ | ||||
| scale_A* scale_B, scale_C); | scale_A* scale_B, scale_C); | ||||
| #if !(__ARM_FEATURE_DOTPROD) | |||||
| KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp) | KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp) | ||||
| KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp) | KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp) | ||||
| KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, | KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, | ||||
| FuseAddHSwishOp) | FuseAddHSwishOp) | ||||
| #else | |||||
| #if MGB_ENABLE_DOT | |||||
| KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp) | KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp) | ||||
| KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp) | KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp) | ||||
| KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, | KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, | ||||
| @@ -15,7 +15,6 @@ namespace megdnn { | |||||
| namespace aarch64 { | namespace aarch64 { | ||||
| namespace matmul { | namespace matmul { | ||||
| #if !(__ARM_FEATURE_DOTPROD) | |||||
| /** | /** | ||||
| * \brief base strategy of gemm. | * \brief base strategy of gemm. | ||||
| * | * | ||||
| @@ -39,8 +38,7 @@ MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_bias_channel_relu, | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_bias_channel_hswish, | MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_bias_channel_hswish, | ||||
| gemm_s8_4x4_nobias_identity); | gemm_s8_4x4_nobias_identity); | ||||
| #else | |||||
| #if MGB_ENABLE_DOT | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_int8, dt_int8, dt_int32, 8, 12, 4, | MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_int8, dt_int8, dt_int32, 8, 12, 4, | ||||
| false, true, | false, true, | ||||
| gemm_s8_8x12_nobias_identity); | gemm_s8_8x12_nobias_identity); | ||||
| @@ -59,7 +57,6 @@ MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_bias_channel_relu, | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_bias_channel_hswish, | MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_bias_channel_hswish, | ||||
| gemm_s8_8x12_nobias_identity); | gemm_s8_8x12_nobias_identity); | ||||
| #endif | #endif | ||||
| } // namespace matmul | } // namespace matmul | ||||
| @@ -69,6 +69,23 @@ WorkspaceBundle ConvBiasImpl::AlgoQU8MatrixMul::get_bundle( | |||||
| size_t K = IC * FH * FW; | size_t K = IC * FH * FW; | ||||
| size_t N = OH * OW; | size_t N = OH * OW; | ||||
| #if MGB_ENABLE_DOT | |||||
| #define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ | |||||
| _bias_midout_enum, _nonline, \ | |||||
| _nonline_midout_enum) \ | |||||
| matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | |||||
| M, N, K, param.filter_type, param.src_type, param.dst_type); \ | |||||
| part2 = megdnn::matmul::GemmInterleaved< \ | |||||
| matmul::gemm_##_gemm##_##_bias##_##_nonline>( \ | |||||
| M, N, K, false, false, strategy) \ | |||||
| .get_workspace_size(); | |||||
| if (cpuinfo_has_arm_neon_dot()) { | |||||
| DISPATCH_GEMM_BIAS(u8_8x8_dot, 1); | |||||
| } else { | |||||
| DISPATCH_GEMM_BIAS(u8_8x8_nodot, 0); | |||||
| } | |||||
| #else | |||||
| #define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ | #define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ | ||||
| _bias_midout_enum, _nonline, \ | _bias_midout_enum, _nonline, \ | ||||
| _nonline_midout_enum) \ | _nonline_midout_enum) \ | ||||
| @@ -82,8 +99,8 @@ WorkspaceBundle ConvBiasImpl::AlgoQU8MatrixMul::get_bundle( | |||||
| .get_workspace_size(); \ | .get_workspace_size(); \ | ||||
| } \ | } \ | ||||
| MIDOUT_END() | MIDOUT_END() | ||||
| DISPATCH_GEMM_BIAS(u8_8x8, 0) | |||||
| DISPATCH_GEMM_BIAS(u8_8x8_nodot, 0) | |||||
| #endif | |||||
| #undef DISPATCH_GEMM_STRATEGY | #undef DISPATCH_GEMM_STRATEGY | ||||
| } | } | ||||
| return {nullptr, {part0, part1, part2}}; | return {nullptr, {part0, part1, part2}}; | ||||
| @@ -157,6 +174,23 @@ void ConvBiasImpl::AlgoQU8MatrixMul::kimpl(const NCBKernParam& param, | |||||
| size_t K = IC * FH * FW; | size_t K = IC * FH * FW; | ||||
| size_t N = OH * OW; | size_t N = OH * OW; | ||||
| #if MGB_ENABLE_DOT | |||||
| #define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ | |||||
| _bias_midout_enum, _nonline, \ | |||||
| _nonline_midout_enum) \ | |||||
| matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | |||||
| M, N, K, param.filter_type, param.src_type, param.dst_type); \ | |||||
| megdnn::matmul::GemmInterleaved< \ | |||||
| matmul::gemm_##_gemm##_##_bias##_##_nonline> \ | |||||
| gemm_interleaved(M, N, K, false, false, strategy); \ | |||||
| gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, bias); | |||||
| if (cpuinfo_has_arm_neon_dot()) { | |||||
| DISPATCH_GEMM_BIAS(u8_8x8_dot, 1) | |||||
| } else { | |||||
| DISPATCH_GEMM_BIAS(u8_8x8_nodot, 0) | |||||
| } | |||||
| #else | |||||
| #define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ | #define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ | ||||
| _bias_midout_enum, _nonline, \ | _bias_midout_enum, _nonline, \ | ||||
| _nonline_midout_enum) \ | _nonline_midout_enum) \ | ||||
| @@ -172,7 +206,9 @@ void ConvBiasImpl::AlgoQU8MatrixMul::kimpl(const NCBKernParam& param, | |||||
| } \ | } \ | ||||
| MIDOUT_END() | MIDOUT_END() | ||||
| DISPATCH_GEMM_BIAS(u8_8x8, 0) | |||||
| DISPATCH_GEMM_BIAS(u8_8x8_nodot, 0) | |||||
| #endif | |||||
| #undef DISPATCH_GEMM_STRATEGY | #undef DISPATCH_GEMM_STRATEGY | ||||
| } | } | ||||
| } | } | ||||
| @@ -23,12 +23,12 @@ using namespace aarch64; | |||||
| using namespace aarch64::matmul; | using namespace aarch64::matmul; | ||||
| namespace impl { | namespace impl { | ||||
| template <BiasMode bmode, typename Op, int block_m, int block_n> | |||||
| template <BiasMode bmode, typename Op, int block_m, int block_n, bool dot> | |||||
| struct KernCaller; | struct KernCaller; | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| template <BiasMode bmode, typename Op> | template <BiasMode bmode, typename Op> | ||||
| struct KernCaller<bmode, Op, 8, 8> { | |||||
| struct KernCaller<bmode, Op, 8, 8, true> { | |||||
| static void run(const dt_uint8* packA, const dt_uint8* packB, size_t M, | static void run(const dt_uint8* packA, const dt_uint8* packB, size_t M, | ||||
| size_t N, size_t K, dt_uint8* C, size_t LDC, | size_t N, size_t K, dt_uint8* C, size_t LDC, | ||||
| bool is_first_k, Op op, const dt_int32* bias, | bool is_first_k, Op op, const dt_int32* bias, | ||||
| @@ -120,10 +120,10 @@ struct KernCaller<bmode, Op, 8, 8> { | |||||
| } | } | ||||
| }; | }; | ||||
| #else | |||||
| #endif | |||||
| template <BiasMode bmode, typename Op> | template <BiasMode bmode, typename Op> | ||||
| struct KernCaller<bmode, Op, 8, 8> { | |||||
| struct KernCaller<bmode, Op, 8, 8, false> { | |||||
| static void run(const dt_uint8* packA, const dt_uint8* packB, size_t M, | static void run(const dt_uint8* packA, const dt_uint8* packB, size_t M, | ||||
| size_t N, size_t K, dt_uint8* C, size_t LDC, | size_t N, size_t K, dt_uint8* C, size_t LDC, | ||||
| bool is_first_k, Op op, const dt_int32* bias, | bool is_first_k, Op op, const dt_int32* bias, | ||||
| @@ -215,13 +215,11 @@ struct KernCaller<bmode, Op, 8, 8> { | |||||
| } | } | ||||
| }; | }; | ||||
| #endif | |||||
| } // namespace impl | } // namespace impl | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8_nobias_identity) | |||||
| #if MGB_ENABLE_DOT | |||||
| MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8_dot_nobias_identity) | |||||
| void gemm_u8_8x8_nobias_identity::pack_A(uint8_t* outptr, const uint8_t* inptr, | |||||
| void gemm_u8_8x8_dot_nobias_identity::pack_A(uint8_t* outptr, const uint8_t* inptr, | |||||
| int ldin, int y0, int ymax, int k0, | int ldin, int y0, int ymax, int k0, | ||||
| int kmax, bool transpose) const { | int kmax, bool transpose) const { | ||||
| if (transpose) { | if (transpose) { | ||||
| @@ -233,7 +231,7 @@ void gemm_u8_8x8_nobias_identity::pack_A(uint8_t* outptr, const uint8_t* inptr, | |||||
| } | } | ||||
| } | } | ||||
| void gemm_u8_8x8_nobias_identity::pack_B(uint8_t* out, const uint8_t* in, | |||||
| void gemm_u8_8x8_dot_nobias_identity::pack_B(uint8_t* out, const uint8_t* in, | |||||
| int ldin, int x0, int xmax, int k0, | int ldin, int x0, int xmax, int k0, | ||||
| int kmax, bool transpose) const { | int kmax, bool transpose) const { | ||||
| if (transpose) { | if (transpose) { | ||||
| @@ -245,10 +243,13 @@ void gemm_u8_8x8_nobias_identity::pack_B(uint8_t* out, const uint8_t* in, | |||||
| } | } | ||||
| } | } | ||||
| #else | |||||
| size_t gemm_u8_8x8_dot_nobias_identity::get_workspace_size() const { | |||||
| return 8 * 8 * sizeof(dt_int32); | |||||
| } | |||||
| MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8_nobias_identity) | |||||
| void gemm_u8_8x8_nobias_identity::pack_A(dt_uint8* outptr, | |||||
| #endif | |||||
| MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8_nodot_nobias_identity) | |||||
| void gemm_u8_8x8_nodot_nobias_identity::pack_A(dt_uint8* outptr, | |||||
| const dt_uint8* inptr, int ldin, | const dt_uint8* inptr, int ldin, | ||||
| int y0, int ymax, int k0, int kmax, | int y0, int ymax, int k0, int kmax, | ||||
| bool transpose) const { | bool transpose) const { | ||||
| @@ -262,7 +263,7 @@ void gemm_u8_8x8_nobias_identity::pack_A(dt_uint8* outptr, | |||||
| } | } | ||||
| } | } | ||||
| void gemm_u8_8x8_nobias_identity::pack_B(dt_uint8* out, const dt_uint8* in, | |||||
| void gemm_u8_8x8_nodot_nobias_identity::pack_B(dt_uint8* out, const dt_uint8* in, | |||||
| int ldin, int x0, int xmax, int k0, | int ldin, int x0, int xmax, int k0, | ||||
| int kmax, bool transpose) const { | int kmax, bool transpose) const { | ||||
| uint8_t zB = B_dtype.param<dtype::Quantized8Asymm>().zero_point; | uint8_t zB = B_dtype.param<dtype::Quantized8Asymm>().zero_point; | ||||
| @@ -275,43 +276,52 @@ void gemm_u8_8x8_nobias_identity::pack_B(dt_uint8* out, const dt_uint8* in, | |||||
| } | } | ||||
| } | } | ||||
| #endif | |||||
| size_t gemm_u8_8x8_nobias_identity::get_workspace_size() const { | |||||
| size_t gemm_u8_8x8_nodot_nobias_identity::get_workspace_size() const { | |||||
| return 8 * 8 * sizeof(dt_int32); | return 8 * 8 * sizeof(dt_int32); | ||||
| } | } | ||||
| #define KERN(_block_m, _block_n, _bias, _BIAS, _nonline, _OP) \ | |||||
| void gemm_u8_##_block_m##x##_block_n##_##_bias##_##_nonline::kern( \ | |||||
| const dt_uint8* packA, const dt_uint8* packB, size_t M, size_t N, \ | |||||
| size_t K, dt_uint8* C, size_t LDC, bool is_first_k, \ | |||||
| const dt_int32* bias, dt_int32* workspace) const { \ | |||||
| float scale_A = A_dtype.param<dtype::Quantized8Asymm>().scale; \ | |||||
| uint8_t zp_A = A_dtype.param<dtype::Quantized8Asymm>().zero_point; \ | |||||
| float scale_B = B_dtype.param<dtype::Quantized8Asymm>().scale; \ | |||||
| uint8_t zp_B = B_dtype.param<dtype::Quantized8Asymm>().zero_point; \ | |||||
| float scale_C = C_dtype.param<dtype::Quantized8Asymm>().scale; \ | |||||
| uint8_t zp_C = C_dtype.param<dtype::Quantized8Asymm>().zero_point; \ | |||||
| DEFINE_OP(_OP); \ | |||||
| impl::KernCaller<_BIAS, decltype(op), _block_m, _block_n>::run( \ | |||||
| packA, packB, M, N, K, C, LDC, is_first_k, op, bias, \ | |||||
| workspace, zp_A, zp_B); \ | |||||
| #define KERN(_block_m, _block_n, _dot, _suffix, _bias, _BIAS, _nonline, \ | |||||
| _OP) \ | |||||
| void gemm_u8_##_block_m##x##_block_n##_suffix##_##_bias##_##_nonline:: \ | |||||
| kern(const dt_uint8* packA, const dt_uint8* packB, size_t M, \ | |||||
| size_t N, size_t K, dt_uint8* C, size_t LDC, bool is_first_k, \ | |||||
| const dt_int32* bias, dt_int32* workspace) const { \ | |||||
| float scale_A = A_dtype.param<dtype::Quantized8Asymm>().scale; \ | |||||
| uint8_t zp_A = A_dtype.param<dtype::Quantized8Asymm>().zero_point; \ | |||||
| float scale_B = B_dtype.param<dtype::Quantized8Asymm>().scale; \ | |||||
| uint8_t zp_B = B_dtype.param<dtype::Quantized8Asymm>().zero_point; \ | |||||
| float scale_C = C_dtype.param<dtype::Quantized8Asymm>().scale; \ | |||||
| uint8_t zp_C = C_dtype.param<dtype::Quantized8Asymm>().zero_point; \ | |||||
| DEFINE_OP(_OP); \ | |||||
| impl::KernCaller<_BIAS, decltype(op), _block_m, _block_n, _dot>::run( \ | |||||
| packA, packB, M, N, K, C, LDC, is_first_k, op, bias, \ | |||||
| workspace, zp_A, zp_B); \ | |||||
| } | } | ||||
| #define DEFINE_OP(_Op) \ | #define DEFINE_OP(_Op) \ | ||||
| arm_common::_Op<dt_qint32, dt_quint8> op(scale_A* scale_B, scale_C, zp_C); | arm_common::_Op<dt_qint32, dt_quint8> op(scale_A* scale_B, scale_C, zp_C); | ||||
| KERN(8, 8, nobias, BiasMode::NO_BIAS, identity, TypeCvtOp) | |||||
| KERN(8, 8, nobias, BiasMode::NO_BIAS, relu, ReluOp) | |||||
| KERN(8, 8, nobias, BiasMode::NO_BIAS, hswish, HSwishOp) | |||||
| #if MGB_ENABLE_DOT | |||||
| KERN(8, 8, true, _dot, nobias, BiasMode::NO_BIAS, identity, TypeCvtOp) | |||||
| KERN(8, 8, true, _dot, nobias, BiasMode::NO_BIAS, relu, ReluOp) | |||||
| KERN(8, 8, true, _dot, nobias, BiasMode::NO_BIAS, hswish, HSwishOp) | |||||
| #endif | |||||
| KERN(8, 8, false, _nodot, nobias, BiasMode::NO_BIAS, identity, TypeCvtOp) | |||||
| KERN(8, 8, false, _nodot, nobias, BiasMode::NO_BIAS, relu, ReluOp) | |||||
| KERN(8, 8, false, _nodot, nobias, BiasMode::NO_BIAS, hswish, HSwishOp) | |||||
| #undef DEFINE_OP | #undef DEFINE_OP | ||||
| #define DEFINE_OP(_Op) \ | #define DEFINE_OP(_Op) \ | ||||
| arm_common::_Op<dt_qint32, dt_quint8> op(scale_A* scale_B, \ | arm_common::_Op<dt_qint32, dt_quint8> op(scale_A* scale_B, \ | ||||
| scale_A* scale_B, scale_C, zp_C); | scale_A* scale_B, scale_C, zp_C); | ||||
| KERN(8, 8, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp) | |||||
| KERN(8, 8, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp) | |||||
| KERN(8, 8, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, | |||||
| FuseAddHSwishOp) | |||||
| #if MGB_ENABLE_DOT | |||||
| KERN(8, 8, true, _dot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp) | |||||
| KERN(8, 8, true, _dot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp) | |||||
| KERN(8, 8, true, _dot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, FuseAddHSwishOp) | |||||
| #endif | |||||
| KERN(8, 8, false, _nodot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp) | |||||
| KERN(8, 8, false, _nodot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp) | |||||
| KERN(8, 8, false, _nodot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, FuseAddHSwishOp) | |||||
| #undef DEFINE_OP | #undef DEFINE_OP | ||||
| #undef KERN | #undef KERN | ||||
| @@ -15,30 +15,46 @@ namespace megdnn { | |||||
| namespace aarch64 { | namespace aarch64 { | ||||
| namespace matmul { | namespace matmul { | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_uint8, dt_uint8, dt_int32, 8, 8, 4, | MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_uint8, dt_uint8, dt_int32, 8, 8, 4, | ||||
| false, true, | false, true, | ||||
| gemm_u8_8x8_nobias_identity); | |||||
| #else | |||||
| gemm_u8_8x8_dot_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_dot_nobias_relu, | |||||
| gemm_u8_8x8_dot_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_dot_nobias_hswish, | |||||
| gemm_u8_8x8_dot_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_dot_bias_channel_identity, | |||||
| gemm_u8_8x8_dot_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_dot_bias_channel_relu, | |||||
| gemm_u8_8x8_dot_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_dot_bias_channel_hswish, | |||||
| gemm_u8_8x8_dot_nobias_identity); | |||||
| #endif | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_uint8, dt_uint8, dt_int32, 8, 8, 8, | MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_uint8, dt_uint8, dt_int32, 8, 8, 8, | ||||
| false, true, | false, true, | ||||
| gemm_u8_8x8_nobias_identity); | |||||
| #endif | |||||
| gemm_u8_8x8_nodot_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nobias_relu, | |||||
| gemm_u8_8x8_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nodot_nobias_relu, | |||||
| gemm_u8_8x8_nodot_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nobias_hswish, | |||||
| gemm_u8_8x8_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nodot_nobias_hswish, | |||||
| gemm_u8_8x8_nodot_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_bias_channel_identity, | |||||
| gemm_u8_8x8_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nodot_bias_channel_identity, | |||||
| gemm_u8_8x8_nodot_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_bias_channel_relu, | |||||
| gemm_u8_8x8_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nodot_bias_channel_relu, | |||||
| gemm_u8_8x8_nodot_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_bias_channel_hswish, | |||||
| gemm_u8_8x8_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nodot_bias_channel_hswish, | |||||
| gemm_u8_8x8_nodot_nobias_identity); | |||||
| } // namespace matmul | } // namespace matmul | ||||
| @@ -24,9 +24,6 @@ | |||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include "src/fallback/matrix_mul/gemm_impl.h" | #include "src/fallback/matrix_mul/gemm_impl.h" | ||||
| #if MGB_ENABLE_CPUINFO | |||||
| #include "cpuinfo.h" | |||||
| #endif | |||||
| #include "midout.h" | #include "midout.h" | ||||
| MIDOUT_DECL(megdnn_aarch64_matmul_kern) | MIDOUT_DECL(megdnn_aarch64_matmul_kern) | ||||
| @@ -394,7 +391,7 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16MK8_8x8::get_kern( | |||||
| #endif | #endif | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| /* ==================== Int8x8x32 K8x12x4 Dotprod algo ==================== */ | /* ==================== Int8x8x32 K8x12x4 Dotprod algo ==================== */ | ||||
| namespace { | namespace { | ||||
| void int8x8x32_k8x12x4_dotprod_kern( | void int8x8x32_k8x12x4_dotprod_kern( | ||||
| @@ -422,6 +419,9 @@ void int8x8x32_k8x12x4_dotprod_kern( | |||||
| bool MatrixMulImpl::AlgoInt8x8x32K8x12x4DotProd::usable( | bool MatrixMulImpl::AlgoInt8x8x32K8x12x4DotProd::usable( | ||||
| const KernSizeParam& kern_size_param) const { | const KernSizeParam& kern_size_param) const { | ||||
| if (!cpuinfo_has_arm_neon_dot()){ | |||||
| return false; | |||||
| } | |||||
| return can_be_treated_as_int8x8x32(kern_size_param); | return can_be_treated_as_int8x8x32(kern_size_param); | ||||
| } | } | ||||
| @@ -484,6 +484,11 @@ void int8x8x32_mk4_8x12x4_dotprod_kern( | |||||
| bool MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd::usable( | bool MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd::usable( | ||||
| const KernSizeParam& kern_size_param) const { | const KernSizeParam& kern_size_param) const { | ||||
| if (!cpuinfo_has_arm_neon_dot()){ | |||||
| return false; | |||||
| } | |||||
| return kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() && | return kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() && | ||||
| (kern_size_param.A_type.enumv() == DTypeEnum::Int8 || | (kern_size_param.A_type.enumv() == DTypeEnum::Int8 || | ||||
| kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8) && | kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8) && | ||||
| @@ -527,7 +532,7 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32MK4_8x12x4DotProd, | |||||
| aarch64::matmul::gemm_mk4_s8_8x12, int8_t, | aarch64::matmul::gemm_mk4_s8_8x12, int8_t, | ||||
| int32_t, AlgoDataType::QINT8X8X32, | int32_t, AlgoDataType::QINT8X8X32, | ||||
| MK4_DOT); | MK4_DOT); | ||||
| #else | |||||
| #endif | |||||
| /* ===================== Int8x8x32 MK4 4x4x16 algo ===================== */ | /* ===================== Int8x8x32 MK4 4x4x16 algo ===================== */ | ||||
| namespace { | namespace { | ||||
| @@ -727,7 +732,6 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K8x8x8, | |||||
| aarch64::matmul::gemm_s8_8x8, int8_t, | aarch64::matmul::gemm_s8_8x8, int8_t, | ||||
| int32_t, AlgoDataType::QINT8X8X32, | int32_t, AlgoDataType::QINT8X8X32, | ||||
| DEFAULT); | DEFAULT); | ||||
| #endif | |||||
| /* ===================== Int8x8x16 K8x8x8 algo ===================== */ | /* ===================== Int8x8x16 K8x8x8 algo ===================== */ | ||||
| namespace { | namespace { | ||||
| @@ -1151,7 +1155,7 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt16x16x32MK8_8x8::get_kern( | |||||
| return kern_mk8_8x8; | return kern_mk8_8x8; | ||||
| } | } | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| /* ==================== Quint8 K8x8x4 Dotprod algo ==================== */ | /* ==================== Quint8 K8x8x4 Dotprod algo ==================== */ | ||||
| namespace { | namespace { | ||||
| void quint8_k8x8x4_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) { | void quint8_k8x8x4_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) { | ||||
| @@ -1166,8 +1170,8 @@ void quint8_k8x8x4_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) { | |||||
| Bptr = kern_param.B<dt_uint8>(); | Bptr = kern_param.B<dt_uint8>(); | ||||
| auto Cptr = kern_param.C<dt_int32>(); | auto Cptr = kern_param.C<dt_int32>(); | ||||
| aarch64::matmul::gemm_u8_8x8 strategy(M, N, K, A_type, B_type, C_type); | |||||
| megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_u8_8x8>( | |||||
| aarch64::matmul::gemm_u8_8x8_dot strategy(M, N, K, A_type, B_type, C_type); | |||||
| megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_u8_8x8_dot>( | |||||
| M, N, K, trA, trB, strategy) | M, N, K, trA, trB, strategy) | ||||
| .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, | .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, | ||||
| kern_param.workspace_ptr); | kern_param.workspace_ptr); | ||||
| @@ -1178,6 +1182,9 @@ void quint8_k8x8x4_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) { | |||||
| bool MatrixMulImpl::AlgoQuint8K8x8x4DotProd::usable( | bool MatrixMulImpl::AlgoQuint8K8x8x4DotProd::usable( | ||||
| const KernSizeParam& kern_size_param) const { | const KernSizeParam& kern_size_param) const { | ||||
| if (!cpuinfo_has_arm_neon_dot()){ | |||||
| return false; | |||||
| } | |||||
| return kern_size_param.A_type.enumv() == DTypeEnum::Quantized8Asymm && | return kern_size_param.A_type.enumv() == DTypeEnum::Quantized8Asymm && | ||||
| kern_size_param.B_type.enumv() == DTypeEnum::Quantized8Asymm && | kern_size_param.B_type.enumv() == DTypeEnum::Quantized8Asymm && | ||||
| kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32 && | kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32 && | ||||
| @@ -1195,8 +1202,8 @@ size_t MatrixMulImpl::AlgoQuint8K8x8x4DotProd::get_workspace( | |||||
| auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, | auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, | ||||
| C_type = kern_size_param.C_type; | C_type = kern_size_param.C_type; | ||||
| aarch64::matmul::gemm_u8_8x8 strategy(M, N, K, A_type, B_type, C_type); | |||||
| return megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_u8_8x8>( | |||||
| aarch64::matmul::gemm_u8_8x8_dot strategy(M, N, K, A_type, B_type, C_type); | |||||
| return megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_u8_8x8_dot>( | |||||
| M, N, K, trA, trB, strategy) | M, N, K, trA, trB, strategy) | ||||
| .get_workspace_size(); | .get_workspace_size(); | ||||
| } | } | ||||
| @@ -1212,7 +1219,7 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoQuint8K8x8x4DotProd::get_kern( | |||||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8K8x8x4DotProd, | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8K8x8x4DotProd, | ||||
| megdnn_aarch64_matmul_kern, | megdnn_aarch64_matmul_kern, | ||||
| "AlgoQuint8K8x8x4DotProdImpl"_hash, | "AlgoQuint8K8x8x4DotProdImpl"_hash, | ||||
| aarch64::matmul::gemm_u8_8x8, uint8_t, | |||||
| aarch64::matmul::gemm_u8_8x8_dot, uint8_t, | |||||
| int32_t, AlgoDataType::QUINT8X8X32, | int32_t, AlgoDataType::QUINT8X8X32, | ||||
| DEFAULT); | DEFAULT); | ||||
| /* ===================== Quint8 Gemv DotProd algo ===================== */ | /* ===================== Quint8 Gemv DotProd algo ===================== */ | ||||
| @@ -1238,6 +1245,9 @@ void quint8_gemv_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) { | |||||
| bool MatrixMulImpl::AlgoQuint8GemvDotProd::usable( | bool MatrixMulImpl::AlgoQuint8GemvDotProd::usable( | ||||
| const KernSizeParam& kern_size_param) const { | const KernSizeParam& kern_size_param) const { | ||||
| if (!cpuinfo_has_arm_neon_dot()){ | |||||
| return false; | |||||
| } | |||||
| return kern_size_param.A_type.enumv() == DTypeEnum::Quantized8Asymm && | return kern_size_param.A_type.enumv() == DTypeEnum::Quantized8Asymm && | ||||
| kern_size_param.B_type.enumv() == DTypeEnum::Quantized8Asymm && | kern_size_param.B_type.enumv() == DTypeEnum::Quantized8Asymm && | ||||
| kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32 && | kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32 && | ||||
| @@ -1257,7 +1267,7 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoQuint8GemvDotProd::get_kern( | |||||
| const KernSizeParam&) const { | const KernSizeParam&) const { | ||||
| return quint8_gemv_dotprod_kern; | return quint8_gemv_dotprod_kern; | ||||
| } | } | ||||
| #else | |||||
| #endif | |||||
| /* ===================== Quint8 K8x8x8 algo ===================== */ | /* ===================== Quint8 K8x8x8 algo ===================== */ | ||||
| namespace { | namespace { | ||||
| @@ -1322,7 +1332,6 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8K8x8x8, | |||||
| aarch64::matmul::gemm_u8_8x8, uint8_t, | aarch64::matmul::gemm_u8_8x8, uint8_t, | ||||
| int32_t, AlgoDataType::QUINT8X8X32, | int32_t, AlgoDataType::QUINT8X8X32, | ||||
| DEFAULT); | DEFAULT); | ||||
| #endif | |||||
| /* ===================== Int8x8x16 K8x8x8 algo ===================== */ | /* ===================== Int8x8x16 K8x8x8 algo ===================== */ | ||||
| namespace { | namespace { | ||||
| @@ -111,7 +111,7 @@ public: | |||||
| #endif | #endif | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| class MatrixMulImpl::AlgoInt8x8x32K8x12x4DotProd final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32K8x12x4DotProd final : public AlgoBase { | ||||
| public: | public: | ||||
| AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
| @@ -141,7 +141,7 @@ public: | |||||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD) | MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD) | ||||
| }; | }; | ||||
| #else | |||||
| #endif | |||||
| class MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16 final : public AlgoBase { | ||||
| public: | public: | ||||
| @@ -187,7 +187,6 @@ public: | |||||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X32_K8X8X8) | MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X32_K8X8X8) | ||||
| }; | }; | ||||
| #endif | |||||
| class MatrixMulImpl::AlgoInt8x8x16K8x8x8 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x16K8x8x8 final : public AlgoBase { | ||||
| public: | public: | ||||
| @@ -313,7 +312,7 @@ public: | |||||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_INT16X16X32_MK8_8X8) | MEGDNN_DECL_ALGO_TYPE(AARCH64_INT16X16X32_MK8_8X8) | ||||
| }; | }; | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| class MatrixMulImpl::AlgoQuint8K8x8x4DotProd final : public AlgoBase { | class MatrixMulImpl::AlgoQuint8K8x8x4DotProd final : public AlgoBase { | ||||
| public: | public: | ||||
| AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
| @@ -328,7 +327,6 @@ public: | |||||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_QUINT8_K8X8X4_DOTPROD) | MEGDNN_DECL_ALGO_TYPE(AARCH64_QUINT8_K8X8X4_DOTPROD) | ||||
| }; | }; | ||||
| class MatrixMulImpl::AlgoQuint8GemvDotProd final : public AlgoBase { | class MatrixMulImpl::AlgoQuint8GemvDotProd final : public AlgoBase { | ||||
| public: | public: | ||||
| AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
| @@ -344,8 +342,7 @@ public: | |||||
| MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QUINT8X8X32, DEFAULT) | MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QUINT8X8X32, DEFAULT) | ||||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_QUINT8_GEMV_DOTPROD) | MEGDNN_DECL_ALGO_TYPE(AARCH64_QUINT8_GEMV_DOTPROD) | ||||
| }; | }; | ||||
| #else | |||||
| #endif | |||||
| class MatrixMulImpl::AlgoQuint8K8x8x8 final : public AlgoBase { | class MatrixMulImpl::AlgoQuint8K8x8x8 final : public AlgoBase { | ||||
| public: | public: | ||||
| AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
| @@ -358,7 +355,6 @@ public: | |||||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_QUINT8_K8X8X8) | MEGDNN_DECL_ALGO_TYPE(AARCH64_QUINT8_K8X8X8) | ||||
| }; | }; | ||||
| #endif | |||||
| } // namespace aarch64 | } // namespace aarch64 | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| @@ -20,9 +20,6 @@ | |||||
| #include "src/aarch64/matrix_mul/fp32/strategy.h" | #include "src/aarch64/matrix_mul/fp32/strategy.h" | ||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #if MGB_ENABLE_CPUINFO | |||||
| #include "cpuinfo.h" | |||||
| #endif | |||||
| using namespace megdnn; | using namespace megdnn; | ||||
| using namespace aarch64; | using namespace aarch64; | ||||
| @@ -9,7 +9,6 @@ | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| */ | */ | ||||
| #if !(__ARM_FEATURE_DOTPROD) | |||||
| #include "src/aarch64/matrix_mul/asm/common.h" | #include "src/aarch64/matrix_mul/asm/common.h" | ||||
| #include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
| @@ -851,6 +850,5 @@ static void gemm_s8_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||||
| } // namespace matmul_4x4x16 | } // namespace matmul_4x4x16 | ||||
| } // namespace aarch64 | } // namespace aarch64 | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| #endif | |||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -9,7 +9,6 @@ | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| */ | */ | ||||
| #if !(__ARM_FEATURE_DOTPROD) | |||||
| #include "src/aarch64/matrix_mul/asm/common.h" | #include "src/aarch64/matrix_mul/asm/common.h" | ||||
| #include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
| @@ -1372,4 +1371,3 @@ static void gemm_s8_8x8_transpose_pack_B_n(int8_t* outptr, const int8_t* inptr, | |||||
| } // namespace megdnn | } // namespace megdnn | ||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| #endif | |||||
| @@ -10,8 +10,6 @@ | |||||
| * implied. | * implied. | ||||
| */ | */ | ||||
| #include <cstring> | |||||
| #if !(__ARM_FEATURE_DOTPROD) | |||||
| #include "src/aarch64/matrix_mul/asm/common.h" | #include "src/aarch64/matrix_mul/asm/common.h" | ||||
| #include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
| @@ -887,6 +885,5 @@ static void gemm_mk4_s8_4x4_pack_B(dt_int8* out, const dt_int8* in, int ldin, | |||||
| } // namespace matmul_4x4x16 | } // namespace matmul_4x4x16 | ||||
| } // namespace aarch64 | } // namespace aarch64 | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| #endif | |||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -9,7 +9,6 @@ | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| */ | */ | ||||
| #if !(__ARM_FEATURE_DOTPROD) | |||||
| #include "src/aarch64/matrix_mul/int8/strategy.h" | #include "src/aarch64/matrix_mul/int8/strategy.h" | ||||
| #include "src/aarch64/matrix_mul/asm/common.h" | #include "src/aarch64/matrix_mul/asm/common.h" | ||||
| #include "src/aarch64/matrix_mul/int8/kernel_4x4x16.h" | #include "src/aarch64/matrix_mul/int8/kernel_4x4x16.h" | ||||
| @@ -105,7 +104,6 @@ void gemm_s8_4x4::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||||
| packA += K4; | packA += K4; | ||||
| } | } | ||||
| } | } | ||||
| ///////////////////////// gemm_mk4_s8_4x4 //////////////////////////////////// | ///////////////////////// gemm_mk4_s8_4x4 //////////////////////////////////// | ||||
| MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_mk4_s8_4x4); | MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_mk4_s8_4x4); | ||||
| @@ -258,6 +256,5 @@ void gemm_s8_8x8::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||||
| packA += K4; | packA += K4; | ||||
| } | } | ||||
| } | } | ||||
| #endif | |||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -10,7 +10,6 @@ | |||||
| */ | */ | ||||
| #pragma once | #pragma once | ||||
| #if !(__ARM_FEATURE_DOTPROD) | |||||
| #include "src/fallback/matrix_mul/gemm_common.h" | #include "src/fallback/matrix_mul/gemm_common.h" | ||||
| namespace megdnn { | namespace megdnn { | ||||
| @@ -30,5 +29,4 @@ MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 8, 8, 8, false, true, | |||||
| } // namespace aarch64 | } // namespace aarch64 | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| #endif | |||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -9,8 +9,7 @@ | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| */ | */ | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| #include "src/aarch64/matrix_mul/asm/common.h" | #include "src/aarch64/matrix_mul/asm/common.h" | ||||
| #include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
| @@ -50,7 +49,9 @@ namespace matmul_8x12x4 { | |||||
| * same, I test in kirin980 with small and big core, here i just keep both the | * same, I test in kirin980 with small and big core, here i just keep both the | ||||
| * implementation. | * implementation. | ||||
| */ | */ | ||||
| #if 1 | #if 1 | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| static void kern_8x12(const int8_t* packA, const int8_t* packB, int K, | static void kern_8x12(const int8_t* packA, const int8_t* packB, int K, | ||||
| int32_t* output, int LDC, bool is_first_k) { | int32_t* output, int LDC, bool is_first_k) { | ||||
| K /= 4; | K /= 4; | ||||
| @@ -408,6 +409,7 @@ static void kern_8x12(const int8_t* packA, const int8_t* packB, int K, | |||||
| ); | ); | ||||
| } | } | ||||
| #else | #else | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| static void kern_8x12(const int8_t* packA, const int8_t* packB, int K, | static void kern_8x12(const int8_t* packA, const int8_t* packB, int K, | ||||
| int32_t* output, int LDC, bool is_first_k) { | int32_t* output, int LDC, bool is_first_k) { | ||||
| K /= 4; | K /= 4; | ||||
| @@ -650,7 +652,7 @@ static void kern_8x12(const int8_t* packA, const int8_t* packB, int K, | |||||
| // +-------+-------+ - - - - +--------+--------+--------+ | // +-------+-------+ - - - - +--------+--------+--------+ | ||||
| // | // | ||||
| // Accumulator | // Accumulator | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| static void kern_4x12(const int8_t* packA, const int8_t* packB, int K, | static void kern_4x12(const int8_t* packA, const int8_t* packB, int K, | ||||
| int32_t* output, int LDC, bool is_first_k, int m_remain) { | int32_t* output, int LDC, bool is_first_k, int m_remain) { | ||||
| K /= 4; | K /= 4; | ||||
| @@ -837,7 +839,7 @@ static void kern_4x12(const int8_t* packA, const int8_t* packB, int K, | |||||
| // +-------+-------+ - - - - +---------+ | // +-------+-------+ - - - - +---------+ | ||||
| // | // | ||||
| // Accumulator | // Accumulator | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | ||||
| int32_t* output, int LDC, bool is_first_k, int n_remain) { | int32_t* output, int LDC, bool is_first_k, int n_remain) { | ||||
| K /= 4; | K /= 4; | ||||
| @@ -1038,7 +1040,7 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | |||||
| // +-------+-------+ - - - - +--------+ | // +-------+-------+ - - - - +--------+ | ||||
| // | // | ||||
| // Accumulator | // Accumulator | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | ||||
| int32_t* output, int LDC, bool is_first_k, int m_remain, | int32_t* output, int LDC, bool is_first_k, int m_remain, | ||||
| int n_remain) { | int n_remain) { | ||||
| @@ -10,8 +10,7 @@ | |||||
| * implied. | * implied. | ||||
| */ | */ | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| #include "src/aarch64/matrix_mul/asm/common.h" | #include "src/aarch64/matrix_mul/asm/common.h" | ||||
| #include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
| @@ -40,6 +39,7 @@ namespace matmul_mk4_8x12x4 { | |||||
| // | // | ||||
| // Accumulator | // Accumulator | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| static void kern_8x12(const int8_t* packA, const int8_t* packB, int K, | static void kern_8x12(const int8_t* packA, const int8_t* packB, int K, | ||||
| int32_t* output, int LDC, bool is_first_k) { | int32_t* output, int LDC, bool is_first_k) { | ||||
| K /= 4; | K /= 4; | ||||
| @@ -60,7 +60,6 @@ static void kern_8x12(const int8_t* packA, const int8_t* packB, int K, | |||||
| int32_t* outptr0 = output; | int32_t* outptr0 = output; | ||||
| int32_t* outptr1; | int32_t* outptr1; | ||||
| asm volatile ( | asm volatile ( | ||||
| // load accumulator C | // load accumulator C | ||||
| "add %[outptr1], %[outptr0], %x[LDC]\n" | "add %[outptr1], %[outptr0], %x[LDC]\n" | ||||
| @@ -397,6 +396,7 @@ static void kern_8x12(const int8_t* packA, const int8_t* packB, int K, | |||||
| // | // | ||||
| // Accumulator | // Accumulator | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| static void kern_4x12(const int8_t* packA, const int8_t* packB, int K, | static void kern_4x12(const int8_t* packA, const int8_t* packB, int K, | ||||
| int32_t* output, int LDC, bool is_first_k) { | int32_t* output, int LDC, bool is_first_k) { | ||||
| K /= 4; | K /= 4; | ||||
| @@ -543,6 +543,7 @@ static void kern_4x12(const int8_t* packA, const int8_t* packB, int K, | |||||
| // +--------+--------+ - - - - +------------+ | // +--------+--------+ - - - - +------------+ | ||||
| // Accumulator | // Accumulator | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | ||||
| int32_t* output, int LDC, bool is_first_k, int n_remain) { | int32_t* output, int LDC, bool is_first_k, int n_remain) { | ||||
| K /= 4; | K /= 4; | ||||
| @@ -718,6 +719,7 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | |||||
| // +--------+--------+ - - - - +------------+ | // +--------+--------+ - - - - +------------+ | ||||
| // Accumulator | // Accumulator | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | ||||
| int32_t* output, int LDC, bool is_first_k, int n_remain) { | int32_t* output, int LDC, bool is_first_k, int n_remain) { | ||||
| K /= 4; | K /= 4; | ||||
| @@ -928,6 +930,5 @@ static void gemm_mk4_s8_8x12_pack_B(dt_int8* out, const dt_int8* in, int ldin, | |||||
| } // namespace matmul_mk4_8x12x4 | } // namespace matmul_mk4_8x12x4 | ||||
| } // namespace aarch64 | } // namespace aarch64 | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| #endif | #endif | ||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -10,13 +10,13 @@ | |||||
| */ | */ | ||||
| #include "src/aarch64/matrix_mul/int8_dot/strategy.h" | #include "src/aarch64/matrix_mul/int8_dot/strategy.h" | ||||
| #if MGB_ENABLE_DOT | |||||
| #include "src/aarch64/matrix_mul/asm/common.h" | #include "src/aarch64/matrix_mul/asm/common.h" | ||||
| #include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include "src/aarch64/matrix_mul/int8_dot/kernel_8x12x4.h" | #include "src/aarch64/matrix_mul/int8_dot/kernel_8x12x4.h" | ||||
| #include "src/aarch64/matrix_mul/int8_dot/kernel_mk4_8x12x4.h" | #include "src/aarch64/matrix_mul/int8_dot/kernel_mk4_8x12x4.h" | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| using namespace megdnn; | using namespace megdnn; | ||||
| using namespace aarch64; | using namespace aarch64; | ||||
| using namespace aarch64::matmul; | using namespace aarch64::matmul; | ||||
| @@ -11,7 +11,7 @@ | |||||
| #pragma once | #pragma once | ||||
| #include "src/fallback/matrix_mul/gemm_common.h" | #include "src/fallback/matrix_mul/gemm_common.h" | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace aarch64 { | namespace aarch64 { | ||||
| namespace matmul { | namespace matmul { | ||||
| @@ -27,14 +27,13 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { | |||||
| AlgoF16K8x24x1 f16_k8x24x1; | AlgoF16K8x24x1 f16_k8x24x1; | ||||
| AlgoF16MK8_8x8 f16_mk8_8x8; | AlgoF16MK8_8x8 f16_mk8_8x8; | ||||
| #endif | #endif | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| AlgoInt8x8x32K8x12x4DotProd int8x8x32_k8x12x4_dotprod; | AlgoInt8x8x32K8x12x4DotProd int8x8x32_k8x12x4_dotprod; | ||||
| AlgoInt8x8x32MK4_8x12x4DotProd int8x8x32_mk4_8x12x4_dotprod; | AlgoInt8x8x32MK4_8x12x4DotProd int8x8x32_mk4_8x12x4_dotprod; | ||||
| #else | |||||
| #endif | |||||
| AlgoInt8x8x32MK4_4x4x16 int8x8x32_mk4_4x4x16; | AlgoInt8x8x32MK4_4x4x16 int8x8x32_mk4_4x4x16; | ||||
| AlgoInt8x8x32K4x4x16 int8x8x32_k4x4x16; | AlgoInt8x8x32K4x4x16 int8x8x32_k4x4x16; | ||||
| AlgoInt8x8x32K8x8x8 int8x8x32_k8x8x8; | AlgoInt8x8x32K8x8x8 int8x8x32_k8x8x8; | ||||
| #endif | |||||
| AlgoInt8x8x16K8x8x8 int8x8x16_k8x8x8; | AlgoInt8x8x16K8x8x8 int8x8x16_k8x8x8; | ||||
| AlgoInt8x8x16K4x4x16 int8x8x16_k4x4x16; | AlgoInt8x8x16K4x4x16 int8x8x16_k4x4x16; | ||||
| AlgoInt8x8x16MK4_16x12x4 int8x8x16_mk4_16x12x4; | AlgoInt8x8x16MK4_16x12x4 int8x8x16_mk4_16x12x4; | ||||
| @@ -44,12 +43,11 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { | |||||
| AlgoInt16x16x32K12x8x1 int16x16x32_k12x8x1; | AlgoInt16x16x32K12x8x1 int16x16x32_k12x8x1; | ||||
| AlgoInt16x16x32MK8_8x8 int16x16x32_mk8_8x8; | AlgoInt16x16x32MK8_8x8 int16x16x32_mk8_8x8; | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| AlgoQuint8K8x8x4DotProd quint8_k8x8x4_dotprod; | AlgoQuint8K8x8x4DotProd quint8_k8x8x4_dotprod; | ||||
| AlgoQuint8GemvDotProd quint8_gemv_dotprod; | AlgoQuint8GemvDotProd quint8_gemv_dotprod; | ||||
| #else | |||||
| AlgoQuint8K8x8x8 quint8_k8x8x8; | |||||
| #endif | #endif | ||||
| AlgoQuint8K8x8x8 quint8_k8x8x8; | |||||
| AlgoInt4x4x16K8x8x8 int4x4x16_k8x8x8; | AlgoInt4x4x16K8x8x8 int4x4x16_k8x8x8; | ||||
| SmallVector<fallback::MatrixMulImpl::AlgoBase*> m_all_algos; | SmallVector<fallback::MatrixMulImpl::AlgoBase*> m_all_algos; | ||||
| @@ -66,14 +64,13 @@ public: | |||||
| m_all_algos.emplace_back(&f16_k8x24x1); | m_all_algos.emplace_back(&f16_k8x24x1); | ||||
| m_all_algos.emplace_back(&f16_mk8_8x8); | m_all_algos.emplace_back(&f16_mk8_8x8); | ||||
| #endif | #endif | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| m_all_algos.emplace_back(&int8x8x32_k8x12x4_dotprod); | m_all_algos.emplace_back(&int8x8x32_k8x12x4_dotprod); | ||||
| m_all_algos.emplace_back(&int8x8x32_mk4_8x12x4_dotprod); | m_all_algos.emplace_back(&int8x8x32_mk4_8x12x4_dotprod); | ||||
| #else | |||||
| #endif | |||||
| m_all_algos.emplace_back(&int8x8x32_k4x4x16); | m_all_algos.emplace_back(&int8x8x32_k4x4x16); | ||||
| m_all_algos.emplace_back(&int8x8x32_k8x8x8); | m_all_algos.emplace_back(&int8x8x32_k8x8x8); | ||||
| m_all_algos.emplace_back(&int8x8x32_mk4_4x4x16); | m_all_algos.emplace_back(&int8x8x32_mk4_4x4x16); | ||||
| #endif | |||||
| m_all_algos.emplace_back(&int8x8x16_k4x4x16); | m_all_algos.emplace_back(&int8x8x16_k4x4x16); | ||||
| m_all_algos.emplace_back(&int8x8x16_k8x8x8); | m_all_algos.emplace_back(&int8x8x16_k8x8x8); | ||||
| m_all_algos.emplace_back(&int8x8x16_mk4_k8x8x8); | m_all_algos.emplace_back(&int8x8x16_mk4_k8x8x8); | ||||
| @@ -82,12 +79,11 @@ public: | |||||
| m_all_algos.emplace_back(&int16x16x32_k12x8x1); | m_all_algos.emplace_back(&int16x16x32_k12x8x1); | ||||
| m_all_algos.emplace_back(&int16x16x32_mk8_8x8); | m_all_algos.emplace_back(&int16x16x32_mk8_8x8); | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| m_all_algos.emplace_back(&quint8_gemv_dotprod); | m_all_algos.emplace_back(&quint8_gemv_dotprod); | ||||
| m_all_algos.emplace_back(&quint8_k8x8x4_dotprod); | m_all_algos.emplace_back(&quint8_k8x8x4_dotprod); | ||||
| #else | |||||
| m_all_algos.emplace_back(&quint8_k8x8x8); | |||||
| #endif | #endif | ||||
| m_all_algos.emplace_back(&quint8_k8x8x8); | |||||
| m_all_algos.emplace_back(&int4x4x16_k8x8x8); | m_all_algos.emplace_back(&int4x4x16_k8x8x8); | ||||
| for (auto&& algo : m_all_algos) { | for (auto&& algo : m_all_algos) { | ||||
| @@ -41,16 +41,15 @@ private: | |||||
| class AlgoF16MK8_8x8; // Aarch64 F16 Format MK8 block 16x8 | class AlgoF16MK8_8x8; // Aarch64 F16 Format MK8 block 16x8 | ||||
| #endif | #endif | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| class AlgoInt8x8x32K8x12x4DotProd; // Aarch64 Int8x8x32 Kernel | class AlgoInt8x8x32K8x12x4DotProd; // Aarch64 Int8x8x32 Kernel | ||||
| // 8x12x4 DotProduct | // 8x12x4 DotProduct | ||||
| class AlgoInt8x8x32MK4_8x12x4DotProd; // Aarch64 nchw44 Int8x8x32 Kernel | class AlgoInt8x8x32MK4_8x12x4DotProd; // Aarch64 nchw44 Int8x8x32 Kernel | ||||
| // 8x12x4 DotProduct | // 8x12x4 DotProduct | ||||
| #else | |||||
| #endif | |||||
| class AlgoInt8x8x32MK4_4x4x16; // Aarch64 nchw44 Int8x8x32 Kernel 4x4x16 | class AlgoInt8x8x32MK4_4x4x16; // Aarch64 nchw44 Int8x8x32 Kernel 4x4x16 | ||||
| class AlgoInt8x8x32K4x4x16; // Aarch64 Int8x8x32 Kernel 4x4x16 | class AlgoInt8x8x32K4x4x16; // Aarch64 Int8x8x32 Kernel 4x4x16 | ||||
| class AlgoInt8x8x32K8x8x8; // Aarch64 Int8x8x32 Kernel 8x8x8 | class AlgoInt8x8x32K8x8x8; // Aarch64 Int8x8x32 Kernel 8x8x8 | ||||
| #endif | |||||
| class AlgoInt8x8x16K8x8x8; // Aarch64 Int8x8x16 Kernel 8x8x8 | class AlgoInt8x8x16K8x8x8; // Aarch64 Int8x8x16 Kernel 8x8x8 | ||||
| class AlgoInt8x8x16K4x4x16; // Aarch64 Int8x8x16 Kernel 4x4x16 | class AlgoInt8x8x16K4x4x16; // Aarch64 Int8x8x16 Kernel 4x4x16 | ||||
| class AlgoInt8x8x16MK4_16x12x4; // Aarch64 Int8x8x16 Kernel 16x12x16 | class AlgoInt8x8x16MK4_16x12x4; // Aarch64 Int8x8x16 Kernel 16x12x16 | ||||
| @@ -59,13 +58,12 @@ private: | |||||
| class AlgoInt16x16x32K12x8x1; // Aarch64 Int16x16x32 Kernel 12x8x1 | class AlgoInt16x16x32K12x8x1; // Aarch64 Int16x16x32 Kernel 12x8x1 | ||||
| class AlgoInt16x16x32MK8_8x8; // Aarch64 Int16x16x32 Format MK8 block 8x8 | class AlgoInt16x16x32MK8_8x8; // Aarch64 Int16x16x32 Format MK8 block 8x8 | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| class AlgoQuint8K8x8x4DotProd; // Aarch64 Quint8 Kernel | class AlgoQuint8K8x8x4DotProd; // Aarch64 Quint8 Kernel | ||||
| // 8x8x4 DotProduct | // 8x8x4 DotProduct | ||||
| class AlgoQuint8GemvDotProd; // Aarch64 Quint8 Gemv DotProduct | class AlgoQuint8GemvDotProd; // Aarch64 Quint8 Gemv DotProduct | ||||
| #else | |||||
| class AlgoQuint8K8x8x8; // Aarch64 Quint8 Kernel 8x8x8 | |||||
| #endif | #endif | ||||
| class AlgoQuint8K8x8x8; // Aarch64 Quint8 Kernel 8x8x8 | |||||
| class AlgoInt8x8x16MK4_K8x8x8; // Aarch64 Int8x8x16 Kernel 4x4x16 | class AlgoInt8x8x16MK4_K8x8x8; // Aarch64 Int8x8x16 Kernel 4x4x16 | ||||
| class AlgoInt4x4x16K8x8x8; // Aarch64 Int4x4x16 Kernel 4x4x16 | class AlgoInt4x4x16K8x8x8; // Aarch64 Int4x4x16 Kernel 4x4x16 | ||||
| class AlgoPack; | class AlgoPack; | ||||
| @@ -9,7 +9,6 @@ | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| */ | */ | ||||
| #if !(__ARM_FEATURE_DOTPROD) | |||||
| #include "src/aarch64/matrix_mul/asm/common.h" | #include "src/aarch64/matrix_mul/asm/common.h" | ||||
| #include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
| @@ -1395,4 +1394,3 @@ static void gemm_u8_8x8_transpose_pack_B_n(dt_uint8* outptr, | |||||
| } // namespace megdnn | } // namespace megdnn | ||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| #endif | |||||
| @@ -9,7 +9,6 @@ | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| */ | */ | ||||
| #if !(__ARM_FEATURE_DOTPROD) | |||||
| #include "src/aarch64/matrix_mul/quint8/strategy.h" | #include "src/aarch64/matrix_mul/quint8/strategy.h" | ||||
| #include "src/aarch64/matrix_mul/asm/common.h" | #include "src/aarch64/matrix_mul/asm/common.h" | ||||
| #include "src/aarch64/matrix_mul/quint8/kernel_8x8x8.h" | #include "src/aarch64/matrix_mul/quint8/kernel_8x8x8.h" | ||||
| @@ -108,6 +107,5 @@ void gemm_u8_8x8::kern(const dt_uint8* packA, const dt_uint8* packB, size_t M, | |||||
| packA += K4; | packA += K4; | ||||
| } | } | ||||
| } | } | ||||
| #endif | |||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -10,7 +10,6 @@ | |||||
| */ | */ | ||||
| #pragma once | #pragma once | ||||
| #if !(__ARM_FEATURE_DOTPROD) | |||||
| #include "src/fallback/matrix_mul/gemm_common.h" | #include "src/fallback/matrix_mul/gemm_common.h" | ||||
| namespace megdnn { | namespace megdnn { | ||||
| @@ -23,6 +22,5 @@ MEGDNN_REG_GEMM_STRATEGY(dt_uint8, dt_int32, dt_int32, 8, 8, 8, false, true, | |||||
| } // namespace matmul | } // namespace matmul | ||||
| } // namespace aarch64 | } // namespace aarch64 | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| #endif | |||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -10,15 +10,13 @@ | |||||
| */ | */ | ||||
| #include "src/aarch64/matrix_mul/quint8_dot/gemv.h" | #include "src/aarch64/matrix_mul/quint8_dot/gemv.h" | ||||
| #include <cstddef> | |||||
| #if MGB_ENABLE_DOT | |||||
| #include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| namespace { | namespace { | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| void gemv_naive_n(const uint8_t* __restrict A, const uint8_t* __restrict B, | void gemv_naive_n(const uint8_t* __restrict A, const uint8_t* __restrict B, | ||||
| int32_t* __restrict C, size_t M, size_t N, size_t K, | int32_t* __restrict C, size_t M, size_t N, size_t K, | ||||
| size_t Astride, size_t Bstride, size_t Cstride, | size_t Astride, size_t Bstride, size_t Cstride, | ||||
| @@ -146,7 +144,6 @@ void gemv_naive_n(const uint8_t* __restrict A, const uint8_t* __restrict B, | |||||
| acc[0] + acc[1] + acc[2] + acc[3] + zAB - acc_zA - acc_zB; | acc[0] + acc[1] + acc[2] + acc[3] + zAB - acc_zA - acc_zB; | ||||
| } | } | ||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| bool megdnn::aarch64::matmul::is_gemv_like_preferred_quint8( | bool megdnn::aarch64::matmul::is_gemv_like_preferred_quint8( | ||||
| @@ -171,7 +168,5 @@ void megdnn::aarch64::matmul::gemv_like_quint8( | |||||
| return gemv_naive_n(A, B, C, M, N, K, Astride, Bstride, Cstride, | return gemv_naive_n(A, B, C, M, N, K, Astride, Bstride, Cstride, | ||||
| zero_point_A, zero_point_B); | zero_point_A, zero_point_B); | ||||
| } | } | ||||
| #endif | #endif | ||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -10,10 +10,9 @@ | |||||
| */ | */ | ||||
| #pragma once | #pragma once | ||||
| #include <cstddef> | |||||
| #include <cstdint> | |||||
| #include "src/common/utils.h" | |||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace aarch64 { | namespace aarch64 { | ||||
| namespace matmul { | namespace matmul { | ||||
| @@ -9,8 +9,7 @@ | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| */ | */ | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| #include "src/aarch64/matrix_mul/asm/common.h" | #include "src/aarch64/matrix_mul/asm/common.h" | ||||
| #include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
| @@ -56,7 +55,7 @@ namespace matmul_8x8x4 { | |||||
| // C = sum((A - zA) * (B - zB)) = sum(A * B) - sum(A) * zB - sum(B) * zA + zA * | // C = sum((A - zA) * (B - zB)) = sum(A * B) - sum(A) * zB - sum(B) * zA + zA * | ||||
| // zB * k | // zB * k | ||||
| // A -> v27, v28 | B -> v29, v30 | zA * zB * k -> v26 | // A -> v27, v28 | B -> v29, v30 | zA * zB * k -> v26 | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| static void kern_8x8(const uint8_t* packA, const uint8_t* packB, int K, | static void kern_8x8(const uint8_t* packA, const uint8_t* packB, int K, | ||||
| int32_t* output, int LDC, bool is_first_k, | int32_t* output, int LDC, bool is_first_k, | ||||
| uint8_t zero_point_A, uint8_t zero_point_B, uint32_t zAB) { | uint8_t zero_point_A, uint8_t zero_point_B, uint32_t zAB) { | ||||
| @@ -293,6 +292,7 @@ static void kern_8x8(const uint8_t* packA, const uint8_t* packB, int K, | |||||
| // zB * k | // zB * k | ||||
| // A -> v28 | B -> v29, v30 | zA * zB * k -> v26 | // A -> v28 | B -> v29, v30 | zA * zB * k -> v26 | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| static void kern_4x8(const uint8_t* packA, const uint8_t* packB, int K, | static void kern_4x8(const uint8_t* packA, const uint8_t* packB, int K, | ||||
| int32_t* output, int LDC, bool is_first_k, int m_remain, | int32_t* output, int LDC, bool is_first_k, int m_remain, | ||||
| uint8_t zero_point_A, uint8_t zero_point_B, uint32_t zAB) { | uint8_t zero_point_A, uint8_t zero_point_B, uint32_t zAB) { | ||||
| @@ -495,6 +495,7 @@ static void kern_4x8(const uint8_t* packA, const uint8_t* packB, int K, | |||||
| // zB * k | // zB * k | ||||
| // A -> v27, v28 | B -> v29 | zA * zB * k -> v26 | // A -> v27, v28 | B -> v29 | zA * zB * k -> v26 | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| static void kern_8x4(const uint8_t* packA, const uint8_t* packB, int K, | static void kern_8x4(const uint8_t* packA, const uint8_t* packB, int K, | ||||
| int32_t* output, int LDC, bool is_first_k, int n_remain, | int32_t* output, int LDC, bool is_first_k, int n_remain, | ||||
| uint8_t zero_point_A, uint8_t zero_point_B, uint32_t zAB) { | uint8_t zero_point_A, uint8_t zero_point_B, uint32_t zAB) { | ||||
| @@ -733,6 +734,7 @@ static void kern_8x4(const uint8_t* packA, const uint8_t* packB, int K, | |||||
| // zB * k | // zB * k | ||||
| // A -> v28 | B -> v29 | zA * zB * k -> v26 | // A -> v28 | B -> v29 | zA * zB * k -> v26 | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| static void kern_4x4(const uint8_t* packA, const uint8_t* packB, int K, | static void kern_4x4(const uint8_t* packA, const uint8_t* packB, int K, | ||||
| int32_t* output, int LDC, bool is_first_k, int m_remain, | int32_t* output, int LDC, bool is_first_k, int m_remain, | ||||
| int n_remain, uint8_t zero_point_A, uint8_t zero_point_B, | int n_remain, uint8_t zero_point_A, uint8_t zero_point_B, | ||||
| @@ -16,14 +16,14 @@ | |||||
| #include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| using namespace megdnn; | using namespace megdnn; | ||||
| using namespace aarch64; | using namespace aarch64; | ||||
| using namespace aarch64::matmul; | using namespace aarch64::matmul; | ||||
| MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8); | |||||
| MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8_dot); | |||||
| void gemm_u8_8x8::pack_A(uint8_t* outptr, const uint8_t* inptr, int ldin, | |||||
| void gemm_u8_8x8_dot::pack_A(uint8_t* outptr, const uint8_t* inptr, int ldin, | |||||
| int y0, int ymax, int k0, int kmax, | int y0, int ymax, int k0, int kmax, | ||||
| bool transpose) const { | bool transpose) const { | ||||
| if (transpose) { | if (transpose) { | ||||
| @@ -35,7 +35,7 @@ void gemm_u8_8x8::pack_A(uint8_t* outptr, const uint8_t* inptr, int ldin, | |||||
| } | } | ||||
| } | } | ||||
| void gemm_u8_8x8::pack_B(uint8_t* out, const uint8_t* in, int ldin, int x0, | |||||
| void gemm_u8_8x8_dot::pack_B(uint8_t* out, const uint8_t* in, int ldin, int x0, | |||||
| int xmax, int k0, int kmax, bool transpose) const { | int xmax, int k0, int kmax, bool transpose) const { | ||||
| if (transpose) { | if (transpose) { | ||||
| matmul_8x8x4::gemm_u8_8x8_interleave_pack_helper(out, in, ldin, x0, | matmul_8x8x4::gemm_u8_8x8_interleave_pack_helper(out, in, ldin, x0, | ||||
| @@ -46,7 +46,7 @@ void gemm_u8_8x8::pack_B(uint8_t* out, const uint8_t* in, int ldin, int x0, | |||||
| } | } | ||||
| } | } | ||||
| void gemm_u8_8x8::kern(const uint8_t* packA, const uint8_t* packB, size_t M, | |||||
| void gemm_u8_8x8_dot::kern(const uint8_t* packA, const uint8_t* packB, size_t M, | |||||
| size_t N, size_t K, dt_int32* C, size_t LDC, | size_t N, size_t K, dt_int32* C, size_t LDC, | ||||
| bool is_first_k, const dt_int32*, dt_int32*) const { | bool is_first_k, const dt_int32*, dt_int32*) const { | ||||
| megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | ||||
| @@ -11,13 +11,13 @@ | |||||
| #pragma once | #pragma once | ||||
| #include "src/fallback/matrix_mul/gemm_common.h" | #include "src/fallback/matrix_mul/gemm_common.h" | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace aarch64 { | namespace aarch64 { | ||||
| namespace matmul { | namespace matmul { | ||||
| MEGDNN_REG_GEMM_STRATEGY(uint8_t, int32_t, int32_t, 8, 8, 4, false, true, | MEGDNN_REG_GEMM_STRATEGY(uint8_t, int32_t, int32_t, 8, 8, 4, false, true, | ||||
| gemm_u8_8x8); | |||||
| gemm_u8_8x8_dot); | |||||
| } // namespace aarch64 | } // namespace aarch64 | ||||
| } // namespace matmul | } // namespace matmul | ||||
| @@ -23,9 +23,6 @@ | |||||
| #include "src/armv7/matrix_mul/asm/common.h" | #include "src/armv7/matrix_mul/asm/common.h" | ||||
| #endif | #endif | ||||
| #if MGB_ENABLE_CPUINFO | |||||
| #include "cpuinfo.h" | |||||
| #endif | |||||
| using namespace megdnn; | using namespace megdnn; | ||||
| using namespace arm_common; | using namespace arm_common; | ||||
| @@ -161,10 +161,13 @@ ConvBiasImpl::AlgoS8DirectStride2::dispatch_kerns( | |||||
| return {}; | return {}; | ||||
| } | } | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| /* ===================== dot stride1 algo ======================== */ | /* ===================== dot stride1 algo ======================== */ | ||||
| bool ConvBiasImpl::AlgoDotS8DirectStride1::usable(const NCBKernSizeParam& param, | bool ConvBiasImpl::AlgoDotS8DirectStride1::usable(const NCBKernSizeParam& param, | ||||
| AlgoSelectionStrategy) const { | AlgoSelectionStrategy) const { | ||||
| if (!cpuinfo_has_arm_neon_dot()) { | |||||
| return false; | |||||
| } | |||||
| return direct_dotprod_int8_stride1::can_conv_direct_stride1_int8(param); | return direct_dotprod_int8_stride1::can_conv_direct_stride1_int8(param); | ||||
| } | } | ||||
| @@ -195,6 +198,9 @@ ConvBiasImpl::AlgoDotS8DirectStride1::dispatch_kerns( | |||||
| /* ===================== dot stride2 algo ======================== */ | /* ===================== dot stride2 algo ======================== */ | ||||
| bool ConvBiasImpl::AlgoDotS8DirectStride2::usable(const NCBKernSizeParam& param, | bool ConvBiasImpl::AlgoDotS8DirectStride2::usable(const NCBKernSizeParam& param, | ||||
| AlgoSelectionStrategy) const { | AlgoSelectionStrategy) const { | ||||
| if (!cpuinfo_has_arm_neon_dot()){ | |||||
| return false; | |||||
| } | |||||
| return direct_dotprod_int8_stride2::can_conv_direct_stride2_int8(param); | return direct_dotprod_int8_stride2::can_conv_direct_stride2_int8(param); | ||||
| } | } | ||||
| @@ -129,7 +129,7 @@ public: | |||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_CHANWISE_STRD2_NCHW44_S8) | MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_CHANWISE_STRD2_NCHW44_S8) | ||||
| }; | }; | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| class ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44 final : public AlgoBase { | class ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44 final : public AlgoBase { | ||||
| public: | public: | ||||
| @@ -9,8 +9,8 @@ | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| */ | */ | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #include "src/arm_common/conv_bias/int8/direct_dotprod.h" | #include "src/arm_common/conv_bias/int8/direct_dotprod.h" | ||||
| #if MGB_ENABLE_DOT | |||||
| #include "src/arm_common/elemwise_op.h" | #include "src/arm_common/elemwise_op.h" | ||||
| #include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| @@ -90,6 +90,7 @@ inline int8x16_t vqtbl1q_s8_v7(int8x16_t a, uint8x16_t index) { | |||||
| _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k2_idx, _elem); | _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k2_idx, _elem); | ||||
| template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op> | template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op> | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| void conv_bias::conv_direct_stride1_2x2_int8_dot( | void conv_bias::conv_direct_stride1_2x2_int8_dot( | ||||
| const int8_t* src, const int8_t* filter, const int32_t* bias, | const int8_t* src, const int8_t* filter, const int32_t* bias, | ||||
| int32_t* temp, int8_t* dst, const size_t IH, const size_t IW, | int32_t* temp, int8_t* dst, const size_t IH, const size_t IW, | ||||
| @@ -325,6 +326,7 @@ void conv_bias::conv_direct_stride1_2x2_int8_dot( | |||||
| } | } | ||||
| template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op> | template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op> | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| void conv_bias::conv_direct_stride1_3x3_int8_dot( | void conv_bias::conv_direct_stride1_3x3_int8_dot( | ||||
| const int8_t* src, const int8_t* filter, const int32_t* bias, | const int8_t* src, const int8_t* filter, const int32_t* bias, | ||||
| int32_t* temp, int8_t* dst, const size_t IH, const size_t IW, | int32_t* temp, int8_t* dst, const size_t IH, const size_t IW, | ||||
| @@ -560,6 +562,7 @@ void conv_bias::conv_direct_stride1_3x3_int8_dot( | |||||
| } | } | ||||
| template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op> | template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op> | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| void conv_bias::conv_direct_stride2_2x2_int8_dot( | void conv_bias::conv_direct_stride2_2x2_int8_dot( | ||||
| const int8_t* src, const int8_t* filter, const int32_t* bias, | const int8_t* src, const int8_t* filter, const int32_t* bias, | ||||
| int32_t* temp, int8_t* dst, const size_t IH, const size_t IW, | int32_t* temp, int8_t* dst, const size_t IH, const size_t IW, | ||||
| @@ -655,6 +658,7 @@ void conv_bias::conv_direct_stride2_2x2_int8_dot( | |||||
| } | } | ||||
| template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op> | template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op> | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| void conv_bias::conv_direct_stride2_3x3_int8_dot( | void conv_bias::conv_direct_stride2_3x3_int8_dot( | ||||
| const int8_t* src, const int8_t* filter, const int32_t* bias, | const int8_t* src, const int8_t* filter, const int32_t* bias, | ||||
| int32_t* temp, int8_t* dst, const size_t IH, const size_t IW, | int32_t* temp, int8_t* dst, const size_t IH, const size_t IW, | ||||
| @@ -810,6 +814,7 @@ void conv_bias::conv_direct_stride2_3x3_int8_dot( | |||||
| _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k11_idx, _elem); | _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k11_idx, _elem); | ||||
| template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op> | template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op> | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| void conv_bias::conv_direct_stride2_5x5_int8_dot( | void conv_bias::conv_direct_stride2_5x5_int8_dot( | ||||
| const int8_t* src, const int8_t* filter, const int32_t* bias, | const int8_t* src, const int8_t* filter, const int32_t* bias, | ||||
| int32_t* temp, int8_t* dst, const size_t IH, const size_t IW, | int32_t* temp, int8_t* dst, const size_t IH, const size_t IW, | ||||
| @@ -1108,6 +1113,7 @@ void conv_bias::conv_direct_stride2_5x5_int8_dot( | |||||
| } | } | ||||
| template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op> | template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op> | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| void conv_bias::conv_direct_stride2_7x7_int8_dot( | void conv_bias::conv_direct_stride2_7x7_int8_dot( | ||||
| const int8_t* src, const int8_t* filter, const int32_t* bias, | const int8_t* src, const int8_t* filter, const int32_t* bias, | ||||
| int32_t* temp, int8_t* dst, const size_t IH, const size_t IW, | int32_t* temp, int8_t* dst, const size_t IH, const size_t IW, | ||||
| @@ -1470,6 +1476,7 @@ void conv_bias::conv_direct_stride2_7x7_int8_dot( | |||||
| } | } | ||||
| template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op> | template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op> | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| void conv_bias::conv_direct_stride1_5x5_int8_dot( | void conv_bias::conv_direct_stride1_5x5_int8_dot( | ||||
| const int8_t* src, const int8_t* filter, const int32_t* bias, | const int8_t* src, const int8_t* filter, const int32_t* bias, | ||||
| int32_t* temp, int8_t* dst, const size_t IH, const size_t IW, | int32_t* temp, int8_t* dst, const size_t IH, const size_t IW, | ||||
| @@ -1770,6 +1777,7 @@ void conv_bias::conv_direct_stride1_5x5_int8_dot( | |||||
| } | } | ||||
| template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op> | template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op> | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| void conv_bias::conv_direct_stride1_7x7_int8_dot( | void conv_bias::conv_direct_stride1_7x7_int8_dot( | ||||
| const int8_t* src, const int8_t* filter, const int32_t* bias, | const int8_t* src, const int8_t* filter, const int32_t* bias, | ||||
| int32_t* temp, int8_t* dst, const size_t IH, const size_t IW, | int32_t* temp, int8_t* dst, const size_t IH, const size_t IW, | ||||
| @@ -2115,6 +2123,7 @@ void conv_bias::conv_direct_stride1_7x7_int8_dot( | |||||
| #undef ST1_S32X4 | #undef ST1_S32X4 | ||||
| #undef ST2_S32X4X2 | #undef ST2_S32X4X2 | ||||
| #define INSTANTIATION(stride, i, first_ic, last_ic, bias, Op) \ | #define INSTANTIATION(stride, i, first_ic, last_ic, bias, Op) \ | ||||
| template void conv_bias::conv_direct_##stride##_##i##x##i##_int8_dot< \ | template void conv_bias::conv_direct_##stride##_##i##x##i##_int8_dot< \ | ||||
| first_ic, last_ic, bias, Op>( \ | first_ic, last_ic, bias, Op>( \ | ||||
| @@ -8,8 +8,8 @@ | |||||
| * software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| */ | */ | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #include "src/arm_common/conv_bias/opr_impl.h" | #include "src/arm_common/conv_bias/opr_impl.h" | ||||
| #if MGB_ENABLE_DOT | |||||
| #include "src/fallback/conv_bias/common.h" | #include "src/fallback/conv_bias/common.h" | ||||
| namespace megdnn { | namespace megdnn { | ||||
| @@ -10,9 +10,8 @@ | |||||
| * implied. | * implied. | ||||
| */ | */ | ||||
| #ifdef __ARM_FEATURE_DOTPROD | |||||
| #include "src/arm_common/elemwise_helper/kimpl/typecvt.h" | #include "src/arm_common/elemwise_helper/kimpl/typecvt.h" | ||||
| #if MGB_ENABLE_DOT | |||||
| #include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include "src/fallback/conv_bias/common.h" | #include "src/fallback/conv_bias/common.h" | ||||
| @@ -10,11 +10,10 @@ | |||||
| * implied. | * implied. | ||||
| */ | */ | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #pragma once | #pragma once | ||||
| #include "src/arm_common/conv_bias/opr_impl.h" | #include "src/arm_common/conv_bias/opr_impl.h" | ||||
| #if MGB_ENABLE_DOT | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace arm_common { | namespace arm_common { | ||||
| @@ -78,4 +77,4 @@ void copy_packed_src_int8_nchw44(int8_t* dst, const int dst_step, | |||||
| #endif | #endif | ||||
| // vim: syntax=cpp.doxygen | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -10,9 +10,8 @@ | |||||
| * implied. | * implied. | ||||
| */ | */ | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #include "src/arm_common/conv_bias/block_helper.h" | #include "src/arm_common/conv_bias/block_helper.h" | ||||
| #if MGB_ENABLE_DOT | |||||
| #include "src/arm_common/conv_bias/int8/algos.h" | #include "src/arm_common/conv_bias/int8/algos.h" | ||||
| #include "src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h" | #include "src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h" | ||||
| #include "src/arm_common/elemwise_op.h" | #include "src/arm_common/elemwise_op.h" | ||||
| @@ -159,6 +158,9 @@ static void conv_kern(const WorkspaceBundle& bundle, | |||||
| bool ConvBiasImpl::AlgoDotS8Direct_NCHW44::usable( | bool ConvBiasImpl::AlgoDotS8Direct_NCHW44::usable( | ||||
| const NCBKernSizeParam& param, | const NCBKernSizeParam& param, | ||||
| AlgoSelectionStrategy algo_selection_strategy) const { | AlgoSelectionStrategy algo_selection_strategy) const { | ||||
| if (!cpuinfo_has_arm_neon_dot()){ | |||||
| return false; | |||||
| } | |||||
| MEGDNN_MARK_USED_VAR(algo_selection_strategy); | MEGDNN_MARK_USED_VAR(algo_selection_strategy); | ||||
| auto&& fm = param.filter_meta; | auto&& fm = param.filter_meta; | ||||
| auto FH = fm.spatial[0]; | auto FH = fm.spatial[0]; | ||||
| @@ -11,9 +11,9 @@ | |||||
| * implied. | * implied. | ||||
| */ | */ | ||||
| #pragma once | #pragma once | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #include "megdnn/arch.h" | #include "megdnn/arch.h" | ||||
| #include "src/arm_common/conv_bias/intrinsic_helper.h" | #include "src/arm_common/conv_bias/intrinsic_helper.h" | ||||
| #if MGB_ENABLE_DOT | |||||
| #include "src/arm_common/elemwise_op.h" | #include "src/arm_common/elemwise_op.h" | ||||
| #include "src/arm_common/intrinsic_helper.h" | #include "src/arm_common/intrinsic_helper.h" | ||||
| #include "src/arm_common/neon_struct.h" | #include "src/arm_common/neon_struct.h" | ||||
| @@ -208,6 +208,7 @@ MEGDNN_ALWAYS_INLINE void store_ocx_owx_remain_static(int32x4_t res[][8], | |||||
| template <int res_row, int src_row, int src_start_idx, int weight_idx, | template <int res_row, int src_row, int src_start_idx, int weight_idx, | ||||
| typename T, typename T2, typename T3> | typename T, typename T2, typename T3> | ||||
| struct ShiftCalHelper { | struct ShiftCalHelper { | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| static MEGDNN_ALWAYS_INLINE void impl(T& res, T2& src, T3& weight) { | static MEGDNN_ALWAYS_INLINE void impl(T& res, T2& src, T3& weight) { | ||||
| #define cb(step) \ | #define cb(step) \ | ||||
| res[res_row][step] = \ | res[res_row][step] = \ | ||||
| @@ -221,6 +222,7 @@ struct ShiftCalHelper { | |||||
| template <int res_row, int src_row, int src_start_idx, int weight_idx, | template <int res_row, int src_row, int src_start_idx, int weight_idx, | ||||
| typename T, typename T2, typename T3> | typename T, typename T2, typename T3> | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| MEGDNN_ALWAYS_INLINE void cal_helper(T& res, T2& src, T3& weight) { | MEGDNN_ALWAYS_INLINE void cal_helper(T& res, T2& src, T3& weight) { | ||||
| ShiftCalHelper<res_row, src_row, src_start_idx, weight_idx, T, T2, | ShiftCalHelper<res_row, src_row, src_start_idx, weight_idx, T, T2, | ||||
| T3>::impl(res, src, weight); | T3>::impl(res, src, weight); | ||||
| @@ -242,4 +244,4 @@ struct KernNeonSdotNCHW44 { | |||||
| } // namespace arm_common | } // namespace arm_common | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| #endif | #endif | ||||
| // vim: syntax=cpp.doxygen | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -10,8 +10,8 @@ | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
| * implied. | * implied. | ||||
| */ | */ | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_common.h" | #include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_common.h" | ||||
| #if MGB_ENABLE_DOT | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace arm_common { | namespace arm_common { | ||||
| @@ -20,6 +20,7 @@ template <typename dst_type, BiasMode bias_mode, typename Op, int ow_remain, | |||||
| int filter_size, int oc_interval, int ow_interval> | int filter_size, int oc_interval, int ow_interval> | ||||
| struct KernNeonSdotNCHW44<dst_type, 1, bias_mode, Op, ow_remain, filter_size, | struct KernNeonSdotNCHW44<dst_type, 1, bias_mode, Op, ow_remain, filter_size, | ||||
| oc_interval, ow_interval> { | oc_interval, ow_interval> { | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| static void impl(dst_type* dst, const int dst_step, const int8_t* src, | static void impl(dst_type* dst, const int dst_step, const int8_t* src, | ||||
| const int ih, const int iw, const int8_t* filter, | const int ih, const int iw, const int8_t* filter, | ||||
| const int32_t* bias, const int ic, const Op& op) { | const int32_t* bias, const int ic, const Op& op) { | ||||
| @@ -109,6 +110,7 @@ struct KernNeonSdotNCHW44<dst_type, 1, bias_mode, Op, ow_remain, filter_size, | |||||
| template <typename dst_type, int stride, BiasMode bias_mode, typename Op, | template <typename dst_type, int stride, BiasMode bias_mode, typename Op, | ||||
| int filter_size> | int filter_size> | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| void conv_direct_sdot_int8_nchw44(dst_type* dst, const int oh, const int ow, | void conv_direct_sdot_int8_nchw44(dst_type* dst, const int oh, const int ow, | ||||
| const int8_t* src, const int ih, const int iw, | const int8_t* src, const int ih, const int iw, | ||||
| const int8_t* filter, const int32_t* bias, | const int8_t* filter, const int32_t* bias, | ||||
| @@ -317,4 +319,4 @@ FOR_FILTER(1) | |||||
| } // namespace arm_common | } // namespace arm_common | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| #endif | #endif | ||||
| // vim: syntax=cpp.doxygen | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -10,9 +10,9 @@ | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
| * implied. | * implied. | ||||
| */ | */ | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_common.h" | #include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_common.h" | ||||
| #if MGB_ENABLE_DOT | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace arm_common { | namespace arm_common { | ||||
| namespace direct_dotprod_nchw44 { | namespace direct_dotprod_nchw44 { | ||||
| @@ -20,6 +20,7 @@ template <typename dst_type, BiasMode bias_mode, typename Op, int ow_remain, | |||||
| int filter_size, int oc_interval, int ow_interval> | int filter_size, int oc_interval, int ow_interval> | ||||
| struct KernNeonSdotNCHW44<dst_type, 2, bias_mode, Op, ow_remain, filter_size, | struct KernNeonSdotNCHW44<dst_type, 2, bias_mode, Op, ow_remain, filter_size, | ||||
| oc_interval, ow_interval> { | oc_interval, ow_interval> { | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| static void impl(dst_type* dst, const int dst_step, const int8_t* src, | static void impl(dst_type* dst, const int dst_step, const int8_t* src, | ||||
| const int ih, const int iw, const int8_t* filter, | const int ih, const int iw, const int8_t* filter, | ||||
| const int32_t* bias, const int ic, const Op& op) { | const int32_t* bias, const int ic, const Op& op) { | ||||
| @@ -110,6 +111,7 @@ struct KernNeonSdotNCHW44<dst_type, 2, bias_mode, Op, ow_remain, filter_size, | |||||
| template <typename dst_type, int stride, BiasMode bias_mode, typename Op, | template <typename dst_type, int stride, BiasMode bias_mode, typename Op, | ||||
| int filter_size> | int filter_size> | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| void conv_direct_sdot_int8_nchw44(dst_type* dst, const int oh, const int ow, | void conv_direct_sdot_int8_nchw44(dst_type* dst, const int oh, const int ow, | ||||
| const int8_t* src, const int ih, const int iw, | const int8_t* src, const int ih, const int iw, | ||||
| const int8_t* filter, const int32_t* bias, | const int8_t* filter, const int32_t* bias, | ||||
| @@ -319,4 +321,4 @@ FOR_FILTER(2) | |||||
| } // namespace megdnn | } // namespace megdnn | ||||
| #endif | #endif | ||||
| // vim: syntax=cpp.doxygen | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -11,8 +11,8 @@ | |||||
| * implied. | * implied. | ||||
| */ | */ | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #include "src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h" | #include "src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h" | ||||
| #if MGB_ENABLE_DOT | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace arm_common { | namespace arm_common { | ||||
| namespace dot_direct_nchw_nchw44 { | namespace dot_direct_nchw_nchw44 { | ||||
| @@ -20,6 +20,7 @@ namespace dot_direct_nchw_nchw44 { | |||||
| template <int src_idx, int weight_idx, typename Func, typename T, typename T2, | template <int src_idx, int weight_idx, typename Func, typename T, typename T2, | ||||
| typename T3, typename T4> | typename T3, typename T4> | ||||
| struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 8, 1, T, T2, T3, T4> { | struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 8, 1, T, T2, T3, T4> { | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| static void impl(T& c, T2& src, T3& weight) { | static void impl(T& c, T2& src, T3& weight) { | ||||
| #define cb(step) \ | #define cb(step) \ | ||||
| c[0][step] = Func::template impl<(src_idx + step) % 4>( \ | c[0][step] = Func::template impl<(src_idx + step) % 4>( \ | ||||
| @@ -35,6 +36,7 @@ struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 8, 1, T, T2, T3, T4> { | |||||
| template <int src_idx, int weight_idx, typename Func, typename T, typename T2, | template <int src_idx, int weight_idx, typename Func, typename T, typename T2, | ||||
| typename T3, typename T4> | typename T3, typename T4> | ||||
| struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, 1, T, T2, T3, T4> { | struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, 1, T, T2, T3, T4> { | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| static void impl(T& c, T2& src, T3& weight) { | static void impl(T& c, T2& src, T3& weight) { | ||||
| #define cb(step) \ | #define cb(step) \ | ||||
| c[0][step] = Func::template impl<(src_idx + step) % 4>( \ | c[0][step] = Func::template impl<(src_idx + step) % 4>( \ | ||||
| @@ -49,6 +51,7 @@ template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, | |||||
| int ow_block> | int ow_block> | ||||
| struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 2, oc_block, ow_block, | struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 2, oc_block, ow_block, | ||||
| 1> { | 1> { | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | ||||
| const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, | const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, | ||||
| int iw, int ld_dst_oc, const Op& op) { | int iw, int ld_dst_oc, const Op& op) { | ||||
| @@ -97,6 +100,7 @@ template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, | |||||
| int ow_block> | int ow_block> | ||||
| struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 3, oc_block, ow_block, | struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 3, oc_block, ow_block, | ||||
| 1> { | 1> { | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | ||||
| const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, | const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, | ||||
| int iw, int ld_dst_oc, const Op& op) { | int iw, int ld_dst_oc, const Op& op) { | ||||
| @@ -151,6 +155,7 @@ template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, | |||||
| int ow_block> | int ow_block> | ||||
| struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 5, oc_block, ow_block, | struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 5, oc_block, ow_block, | ||||
| 1> { | 1> { | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | ||||
| const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, | const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, | ||||
| int iw, int ld_dst_oc, const Op& op) { | int iw, int ld_dst_oc, const Op& op) { | ||||
| @@ -200,6 +205,7 @@ template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, | |||||
| int ow_block> | int ow_block> | ||||
| struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 7, oc_block, ow_block, | struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 7, oc_block, ow_block, | ||||
| 1> { | 1> { | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | ||||
| const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, | const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, | ||||
| int iw, int ld_dst_oc, const Op& op) { | int iw, int ld_dst_oc, const Op& op) { | ||||
| @@ -302,6 +308,7 @@ void pack_src_int8_nchw_nchw44_dot<1>(int8_t* sptr_base, | |||||
| } | } | ||||
| template <BiasMode bias_mode, typename Op, int filter_size, int stride> | template <BiasMode bias_mode, typename Op, int filter_size, int stride> | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| void conv_direct_int8_nchw_nchw44_dot(const int8_t* src, const int8_t* filter, | void conv_direct_int8_nchw_nchw44_dot(const int8_t* src, const int8_t* filter, | ||||
| const int32_t* bias, int32_t* temp, | const int32_t* bias, int32_t* temp, | ||||
| int8_t* dst, const int oc, const int ic, | int8_t* dst, const int oc, const int ic, | ||||
| @@ -445,4 +452,4 @@ DISPATCH_CONV_KERN(1); | |||||
| } // namespace arm_common | } // namespace arm_common | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| #endif | #endif | ||||
| // vim: syntax=cpp.doxygen | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -10,8 +10,8 @@ | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
| * implied. | * implied. | ||||
| */ | */ | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #include "src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h" | #include "src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h" | ||||
| #if MGB_ENABLE_DOT | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace arm_common { | namespace arm_common { | ||||
| namespace dot_direct_nchw_nchw44 { | namespace dot_direct_nchw_nchw44 { | ||||
| @@ -19,6 +19,7 @@ namespace dot_direct_nchw_nchw44 { | |||||
| template <int src_idx, int weight_idx, typename Func, typename T, typename T2, | template <int src_idx, int weight_idx, typename Func, typename T, typename T2, | ||||
| typename T3, typename T4> | typename T3, typename T4> | ||||
| struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 8, 2, T, T2, T3, T4> { | struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 8, 2, T, T2, T3, T4> { | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| static void impl(T& c, T2& src, T3& weight) { | static void impl(T& c, T2& src, T3& weight) { | ||||
| #define cb(step) \ | #define cb(step) \ | ||||
| c[0][step * 2] = Func::template impl<(src_idx + step) % 4>( \ | c[0][step * 2] = Func::template impl<(src_idx + step) % 4>( \ | ||||
| @@ -42,6 +43,7 @@ struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 8, 2, T, T2, T3, T4> { | |||||
| template <int src_idx, int weight_idx, typename Func, typename T, typename T2, | template <int src_idx, int weight_idx, typename Func, typename T, typename T2, | ||||
| typename T3, typename T4> | typename T3, typename T4> | ||||
| struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, 2, T, T2, T3, T4> { | struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, 2, T, T2, T3, T4> { | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| static void impl(T& c, T2& src, T3& weight) { | static void impl(T& c, T2& src, T3& weight) { | ||||
| #define cb(step) \ | #define cb(step) \ | ||||
| c[0][step * 2] = Func::template impl<(src_idx + step) % 4>( \ | c[0][step * 2] = Func::template impl<(src_idx + step) % 4>( \ | ||||
| @@ -60,6 +62,7 @@ template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, | |||||
| int ow_block> | int ow_block> | ||||
| struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 2, oc_block, ow_block, | struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 2, oc_block, ow_block, | ||||
| 2> { | 2> { | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | ||||
| const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, | const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, | ||||
| int iw, int ld_dst_oc, const Op& op) { | int iw, int ld_dst_oc, const Op& op) { | ||||
| @@ -111,6 +114,7 @@ template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, | |||||
| int ow_block> | int ow_block> | ||||
| struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 3, oc_block, ow_block, | struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 3, oc_block, ow_block, | ||||
| 2> { | 2> { | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | ||||
| const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, | const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, | ||||
| int iw, int ld_dst_oc, const Op& op) { | int iw, int ld_dst_oc, const Op& op) { | ||||
| @@ -169,6 +173,7 @@ template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, | |||||
| int ow_block> | int ow_block> | ||||
| struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 5, oc_block, ow_block, | struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 5, oc_block, ow_block, | ||||
| 2> { | 2> { | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | ||||
| const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, | const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, | ||||
| int iw, int ld_dst_oc, const Op& op) { | int iw, int ld_dst_oc, const Op& op) { | ||||
| @@ -224,6 +229,7 @@ template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, | |||||
| int ow_block> | int ow_block> | ||||
| struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 7, oc_block, ow_block, | struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 7, oc_block, ow_block, | ||||
| 2> { | 2> { | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | ||||
| const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, | const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, | ||||
| int iw, int ld_dst_oc, const Op& op) { | int iw, int ld_dst_oc, const Op& op) { | ||||
| @@ -289,6 +295,7 @@ void pack_src_int8_nchw_nchw44_dot<2>( | |||||
| } | } | ||||
| template <BiasMode bias_mode, typename Op, int filter_size, int stride> | template <BiasMode bias_mode, typename Op, int filter_size, int stride> | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| void conv_direct_int8_nchw_nchw44_dot(const int8_t* src, const int8_t* filter, | void conv_direct_int8_nchw_nchw44_dot(const int8_t* src, const int8_t* filter, | ||||
| const int32_t* bias, int32_t* temp, | const int32_t* bias, int32_t* temp, | ||||
| int8_t* dst, const int oc, const int ic, | int8_t* dst, const int oc, const int ic, | ||||
| @@ -434,4 +441,4 @@ DISPATCH_CONV_KERN(2); | |||||
| } // namespace megdnn | } // namespace megdnn | ||||
| #endif | #endif | ||||
| // vim: syntax=cpp.doxygen | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -10,8 +10,8 @@ | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express | ||||
| * or implied. | * or implied. | ||||
| */ | */ | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
| #if MGB_ENABLE_DOT | |||||
| #include "src/arm_common/conv_bias/block_helper.h" | #include "src/arm_common/conv_bias/block_helper.h" | ||||
| #include "src/arm_common/conv_bias/int8/algos.h" | #include "src/arm_common/conv_bias/int8/algos.h" | ||||
| #include "src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h" | #include "src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h" | ||||
| @@ -175,6 +175,9 @@ static void do_conv_kern(const WorkspaceBundle& bundle, | |||||
| bool ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::usable( | bool ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::usable( | ||||
| const NCBKernSizeParam& param, AlgoSelectionStrategy) const { | const NCBKernSizeParam& param, AlgoSelectionStrategy) const { | ||||
| if (!cpuinfo_has_arm_neon_dot()){ | |||||
| return false; | |||||
| } | |||||
| return nchw_nchwxx_valid<NchwNchwxxType::NCHW44_INT8_DOT>( | return nchw_nchwxx_valid<NchwNchwxxType::NCHW44_INT8_DOT>( | ||||
| param.src_type.enumv(), param.filter_type.enumv(), | param.src_type.enumv(), param.filter_type.enumv(), | ||||
| param.dst_type.enumv(), param.filter_meta, param.bias_mode, | param.dst_type.enumv(), param.filter_meta, param.bias_mode, | ||||
| @@ -11,9 +11,9 @@ | |||||
| * implied. | * implied. | ||||
| */ | */ | ||||
| #pragma once | #pragma once | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #include "src/arm_common/conv_bias/intrinsic_helper.h" | #include "src/arm_common/conv_bias/intrinsic_helper.h" | ||||
| #if MGB_ENABLE_DOT | |||||
| #include "src/arm_common/elemwise_op.h" | #include "src/arm_common/elemwise_op.h" | ||||
| #include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
| #include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
| @@ -8,9 +8,9 @@ | |||||
| * software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| */ | */ | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #include "src/arm_common/conv_bias/int8/stride1_dotprod.h" | #include "src/arm_common/conv_bias/int8/stride1_dotprod.h" | ||||
| #if MGB_ENABLE_DOT | |||||
| #include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
| #include "src/arm_common/conv_bias/int8/direct_dotprod.h" | #include "src/arm_common/conv_bias/int8/direct_dotprod.h" | ||||
| #include "src/arm_common/conv_bias/int8/strategy.h" | #include "src/arm_common/conv_bias/int8/strategy.h" | ||||
| @@ -8,10 +8,10 @@ | |||||
| * software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| */ | */ | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #pragma once | #pragma once | ||||
| #include "src/arm_common/conv_bias/opr_impl.h" | #include "src/arm_common/conv_bias/opr_impl.h" | ||||
| #if MGB_ENABLE_DOT | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace arm_common { | namespace arm_common { | ||||
| namespace direct_dotprod_int8_stride1 { | namespace direct_dotprod_int8_stride1 { | ||||
| @@ -9,8 +9,8 @@ | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| */ | */ | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #include "src/arm_common/conv_bias/int8/stride2_dotprod.h" | #include "src/arm_common/conv_bias/int8/stride2_dotprod.h" | ||||
| #if MGB_ENABLE_DOT | |||||
| #include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
| #include "src/arm_common/conv_bias/int8/direct_dotprod.h" | #include "src/arm_common/conv_bias/int8/direct_dotprod.h" | ||||
| #include "src/arm_common/conv_bias/int8/strategy.h" | #include "src/arm_common/conv_bias/int8/strategy.h" | ||||
| @@ -9,9 +9,9 @@ | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| */ | */ | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #pragma once | #pragma once | ||||
| #include "src/arm_common/conv_bias/opr_impl.h" | #include "src/arm_common/conv_bias/opr_impl.h" | ||||
| #if MGB_ENABLE_DOT | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace arm_common { | namespace arm_common { | ||||
| @@ -60,7 +60,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { | |||||
| AlgoS8x8x16ChanWiseStride1Stride2NCHW44 | AlgoS8x8x16ChanWiseStride1Stride2NCHW44 | ||||
| s8x8x16_channel_wise_stride1_stride2_nchw44; | s8x8x16_channel_wise_stride1_stride2_nchw44; | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| AlgoDotS8DirectStride1 ds8_direct_stride1; | AlgoDotS8DirectStride1 ds8_direct_stride1; | ||||
| AlgoDotS8DirectStride2 ds8_direct_stride2; | AlgoDotS8DirectStride2 ds8_direct_stride2; | ||||
| AlgoDotU8DirectStride1 du8_direct_stride1; | AlgoDotU8DirectStride1 du8_direct_stride1; | ||||
| @@ -94,7 +94,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { | |||||
| public: | public: | ||||
| AlgoPack() { | AlgoPack() { | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| m_direct_algos.emplace_back(&ds8_direct_stride1); | m_direct_algos.emplace_back(&ds8_direct_stride1); | ||||
| m_direct_algos.emplace_back(&ds8_direct_stride2); | m_direct_algos.emplace_back(&ds8_direct_stride2); | ||||
| m_direct_algos.emplace_back(&du8_direct_stride1); | m_direct_algos.emplace_back(&du8_direct_stride1); | ||||
| @@ -70,7 +70,7 @@ private: | |||||
| class AlgoFP16WinogradF63; | class AlgoFP16WinogradF63; | ||||
| class AlgoFP16WinogradF23_8x8; | class AlgoFP16WinogradF23_8x8; | ||||
| #endif | #endif | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| class AlgoDotS8DirectNCHWNCHW44; | class AlgoDotS8DirectNCHWNCHW44; | ||||
| class AlgoDotS8DirectStride1; | class AlgoDotS8DirectStride1; | ||||
| class AlgoDotS8DirectStride2; | class AlgoDotS8DirectStride2; | ||||
| @@ -11,7 +11,6 @@ | |||||
| */ | */ | ||||
| #include "src/arm_common/conv_bias/quint8/algos.h" | #include "src/arm_common/conv_bias/quint8/algos.h" | ||||
| #include "midout.h" | |||||
| #include "src/arm_common/conv_bias/quint8/stride1.h" | #include "src/arm_common/conv_bias/quint8/stride1.h" | ||||
| #include "src/arm_common/conv_bias/quint8/stride1_dotprod.h" | #include "src/arm_common/conv_bias/quint8/stride1_dotprod.h" | ||||
| #include "src/arm_common/conv_bias/quint8/stride2.h" | #include "src/arm_common/conv_bias/quint8/stride2.h" | ||||
| @@ -19,6 +18,8 @@ | |||||
| #include "src/arm_common/elemwise_op.h" | #include "src/arm_common/elemwise_op.h" | ||||
| #include "src/fallback/conv_bias/common.h" | #include "src/fallback/conv_bias/common.h" | ||||
| #include "midout.h" | |||||
| MIDOUT_DECL(megdnn_arm_common_conv_bias_quint8) | MIDOUT_DECL(megdnn_arm_common_conv_bias_quint8) | ||||
| using namespace megdnn; | using namespace megdnn; | ||||
| @@ -84,10 +85,13 @@ ConvBiasImpl::AlgoQU8DirectStride2::dispatch_kerns( | |||||
| MIDOUT_END(); | MIDOUT_END(); | ||||
| return {}; | return {}; | ||||
| } | } | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| /* ===================== stride1 algo ===================== */ | /* ===================== stride1 algo ===================== */ | ||||
| bool ConvBiasImpl::AlgoDotU8DirectStride1::usable(const NCBKernSizeParam& param, | bool ConvBiasImpl::AlgoDotU8DirectStride1::usable(const NCBKernSizeParam& param, | ||||
| AlgoSelectionStrategy) const { | AlgoSelectionStrategy) const { | ||||
| if (!cpuinfo_has_arm_neon_dot()){ | |||||
| return false; | |||||
| } | |||||
| return direct_dotprod_quint8_stride1::can_conv_direct_stride1_quint8(param); | return direct_dotprod_quint8_stride1::can_conv_direct_stride1_quint8(param); | ||||
| } | } | ||||
| @@ -118,6 +122,9 @@ ConvBiasImpl::AlgoDotU8DirectStride1::dispatch_kerns( | |||||
| /* ===================== stride2 algo ===================== */ | /* ===================== stride2 algo ===================== */ | ||||
| bool ConvBiasImpl::AlgoDotU8DirectStride2::usable(const NCBKernSizeParam& param, | bool ConvBiasImpl::AlgoDotU8DirectStride2::usable(const NCBKernSizeParam& param, | ||||
| AlgoSelectionStrategy) const { | AlgoSelectionStrategy) const { | ||||
| if (!cpuinfo_has_arm_neon_dot()){ | |||||
| return false; | |||||
| } | |||||
| return direct_dotprod_quint8_stride2::can_conv_direct_stride2_quint8(param); | return direct_dotprod_quint8_stride2::can_conv_direct_stride2_quint8(param); | ||||
| } | } | ||||
| @@ -55,7 +55,7 @@ public: | |||||
| } | } | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_QU8) | MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_QU8) | ||||
| }; | }; | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| class ConvBiasImpl::AlgoDotU8DirectStride1 final : public AlgoBase { | class ConvBiasImpl::AlgoDotU8DirectStride1 final : public AlgoBase { | ||||
| public: | public: | ||||
| @@ -9,8 +9,8 @@ | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| */ | */ | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #include "src/arm_common/conv_bias/quint8/direct_dotprod.h" | #include "src/arm_common/conv_bias/quint8/direct_dotprod.h" | ||||
| #if MGB_ENABLE_DOT | |||||
| #include "src/arm_common/elemwise_op.h" | #include "src/arm_common/elemwise_op.h" | ||||
| #include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| @@ -120,6 +120,7 @@ inline int8x16_t vqtbl1q_s8_v7(int8x16_t a, uint8x16_t index){ | |||||
| template <bool first_ic, bool last_ic, bool fused_kern, BiasMode bias_mode, | template <bool first_ic, bool last_ic, bool fused_kern, BiasMode bias_mode, | ||||
| typename Op> | typename Op> | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| void conv_bias::conv_direct_stride1_2x2_quint8_dot( | void conv_bias::conv_direct_stride1_2x2_quint8_dot( | ||||
| const uint8_t* src, const uint8_t* filter, const int32_t* bias, | const uint8_t* src, const uint8_t* filter, const int32_t* bias, | ||||
| int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, | int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, | ||||
| @@ -452,6 +453,7 @@ void conv_bias::conv_direct_stride1_2x2_quint8_dot( | |||||
| template <bool first_ic, bool last_ic, bool fused_kern, BiasMode bias_mode, | template <bool first_ic, bool last_ic, bool fused_kern, BiasMode bias_mode, | ||||
| typename Op> | typename Op> | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| void conv_bias::conv_direct_stride1_3x3_quint8_dot( | void conv_bias::conv_direct_stride1_3x3_quint8_dot( | ||||
| const uint8_t* src, const uint8_t* filter, const int32_t* bias, | const uint8_t* src, const uint8_t* filter, const int32_t* bias, | ||||
| int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, | int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, | ||||
| @@ -691,6 +693,7 @@ void conv_bias::conv_direct_stride1_3x3_quint8_dot( | |||||
| template <bool first_ic, bool last_ic, bool fused_kern, BiasMode bias_mode, | template <bool first_ic, bool last_ic, bool fused_kern, BiasMode bias_mode, | ||||
| typename Op> | typename Op> | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| void conv_bias::conv_direct_stride2_2x2_quint8_dot( | void conv_bias::conv_direct_stride2_2x2_quint8_dot( | ||||
| const uint8_t* src, const uint8_t* filter, const int32_t* bias, | const uint8_t* src, const uint8_t* filter, const int32_t* bias, | ||||
| int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, | int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, | ||||
| @@ -801,6 +804,7 @@ void conv_bias::conv_direct_stride2_2x2_quint8_dot( | |||||
| template <bool first_ic, bool last_ic, bool fused_kern, BiasMode bias_mode, | template <bool first_ic, bool last_ic, bool fused_kern, BiasMode bias_mode, | ||||
| typename Op> | typename Op> | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| void conv_bias::conv_direct_stride2_3x3_quint8_dot( | void conv_bias::conv_direct_stride2_3x3_quint8_dot( | ||||
| const uint8_t* src, const uint8_t* filter, const int32_t* bias, | const uint8_t* src, const uint8_t* filter, const int32_t* bias, | ||||
| int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, | int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, | ||||
| @@ -1135,6 +1139,7 @@ void conv_bias::conv_direct_stride2_3x3_quint8_dot( | |||||
| template <bool first_ic, bool last_ic, bool fused_kern, BiasMode bias_mode, | template <bool first_ic, bool last_ic, bool fused_kern, BiasMode bias_mode, | ||||
| typename Op> | typename Op> | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| void conv_bias::conv_direct_stride1_5x5_quint8_dot( | void conv_bias::conv_direct_stride1_5x5_quint8_dot( | ||||
| const uint8_t* src, const uint8_t* filter, const int32_t* bias, | const uint8_t* src, const uint8_t* filter, const int32_t* bias, | ||||
| int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, | int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, | ||||
| @@ -1443,6 +1448,7 @@ void conv_bias::conv_direct_stride1_5x5_quint8_dot( | |||||
| template <bool first_ic, bool last_ic, bool fused_kern, BiasMode bias_mode, | template <bool first_ic, bool last_ic, bool fused_kern, BiasMode bias_mode, | ||||
| typename Op> | typename Op> | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| void conv_bias::conv_direct_stride1_7x7_quint8_dot( | void conv_bias::conv_direct_stride1_7x7_quint8_dot( | ||||
| const uint8_t* src, const uint8_t* filter, const int32_t* bias, | const uint8_t* src, const uint8_t* filter, const int32_t* bias, | ||||
| int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, | int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, | ||||
| @@ -1785,6 +1791,7 @@ void conv_bias::conv_direct_stride1_7x7_quint8_dot( | |||||
| template <bool first_ic, bool last_ic, bool fused_kern, BiasMode bias_mode, | template <bool first_ic, bool last_ic, bool fused_kern, BiasMode bias_mode, | ||||
| typename Op> | typename Op> | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| void conv_bias::conv_direct_stride2_5x5_quint8_dot( | void conv_bias::conv_direct_stride2_5x5_quint8_dot( | ||||
| const uint8_t* src, const uint8_t* filter, const int32_t* bias, | const uint8_t* src, const uint8_t* filter, const int32_t* bias, | ||||
| int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, | int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, | ||||
| @@ -2090,6 +2097,7 @@ void conv_bias::conv_direct_stride2_5x5_quint8_dot( | |||||
| template <bool first_ic, bool last_ic, bool fused_kern, BiasMode bias_mode, | template <bool first_ic, bool last_ic, bool fused_kern, BiasMode bias_mode, | ||||
| typename Op> | typename Op> | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| void conv_bias::conv_direct_stride2_7x7_quint8_dot( | void conv_bias::conv_direct_stride2_7x7_quint8_dot( | ||||
| const uint8_t* src, const uint8_t* filter, const int32_t* bias, | const uint8_t* src, const uint8_t* filter, const int32_t* bias, | ||||
| int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, | int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, | ||||
| @@ -8,9 +8,9 @@ | |||||
| * software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| */ | */ | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #include "src/arm_common/conv_bias/opr_impl.h" | #include "src/arm_common/conv_bias/opr_impl.h" | ||||
| #if MGB_ENABLE_DOT | |||||
| #include "src/fallback/conv_bias/common.h" | #include "src/fallback/conv_bias/common.h" | ||||
| namespace megdnn { | namespace megdnn { | ||||
| @@ -8,8 +8,8 @@ | |||||
| * software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| */ | */ | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #include "src/arm_common/conv_bias/quint8/stride1_dotprod.h" | #include "src/arm_common/conv_bias/quint8/stride1_dotprod.h" | ||||
| #if MGB_ENABLE_DOT | |||||
| #include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
| #include "src/arm_common/conv_bias/quint8/direct_dotprod.h" | #include "src/arm_common/conv_bias/quint8/direct_dotprod.h" | ||||
| #include "src/arm_common/elemwise_op.h" | #include "src/arm_common/elemwise_op.h" | ||||
| @@ -8,10 +8,10 @@ | |||||
| * software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| */ | */ | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #pragma once | #pragma once | ||||
| #include "src/arm_common/conv_bias/opr_impl.h" | #include "src/arm_common/conv_bias/opr_impl.h" | ||||
| #if MGB_ENABLE_DOT | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace arm_common { | namespace arm_common { | ||||
| @@ -8,8 +8,8 @@ | |||||
| * software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| */ | */ | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #include "src/arm_common/conv_bias/quint8/stride2_dotprod.h" | #include "src/arm_common/conv_bias/quint8/stride2_dotprod.h" | ||||
| #if MGB_ENABLE_DOT | |||||
| #include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
| #include "src/arm_common/conv_bias/quint8/direct_dotprod.h" | #include "src/arm_common/conv_bias/quint8/direct_dotprod.h" | ||||
| #include "src/arm_common/elemwise_op.h" | #include "src/arm_common/elemwise_op.h" | ||||
| @@ -8,10 +8,10 @@ | |||||
| * software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| */ | */ | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #pragma once | #pragma once | ||||
| #include "src/arm_common/conv_bias/opr_impl.h" | #include "src/arm_common/conv_bias/opr_impl.h" | ||||
| #if MGB_ENABLE_DOT | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace arm_common { | namespace arm_common { | ||||
| @@ -13,21 +13,24 @@ | |||||
| #include "src/arm_common/convolution/int8x8x32/algos.h" | #include "src/arm_common/convolution/int8x8x32/algos.h" | ||||
| #include "src/arm_common/convolution/int8x8x32/conv_backdata_stride1.h" | #include "src/arm_common/convolution/int8x8x32/conv_backdata_stride1.h" | ||||
| #include "src/arm_common/convolution/int8x8x32/conv_backdata_stride2.h" | #include "src/arm_common/convolution/int8x8x32/conv_backdata_stride2.h" | ||||
| #include "src/common/opr_delegate.h" | |||||
| #include "midout.h" | #include "midout.h" | ||||
| #include "src/common/opr_delegate.h" | |||||
| MIDOUT_DECL(megdnn_arm_conv_int8832_kimpl) | MIDOUT_DECL(megdnn_arm_conv_int8832_kimpl) | ||||
| using namespace megdnn; | using namespace megdnn; | ||||
| using namespace arm_common; | using namespace arm_common; | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| /* ===================== ConvolutionBackwardData ===================== */ | /* ===================== ConvolutionBackwardData ===================== */ | ||||
| /* ===================== direct stride 1 algo ===================== */ | /* ===================== direct stride 1 algo ===================== */ | ||||
| bool ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1::usable( | bool ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1::usable( | ||||
| fallback::ConvolutionBackwardDataImpl*, | fallback::ConvolutionBackwardDataImpl*, | ||||
| const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
| if (!cpuinfo_has_arm_neon_dot()){ | |||||
| return false; | |||||
| } | |||||
| return deconv::can_stride1_int8x8x32_dot(param); | return deconv::can_stride1_int8x8x32_dot(param); | ||||
| } | } | ||||
| @@ -57,6 +60,9 @@ ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1::dispatch_kern( | |||||
| bool ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2::usable( | bool ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2::usable( | ||||
| fallback::ConvolutionBackwardDataImpl*, | fallback::ConvolutionBackwardDataImpl*, | ||||
| const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
| if (!cpuinfo_has_arm_neon_dot()){ | |||||
| return false; | |||||
| } | |||||
| return deconv::can_stride2_int8x8x32_dot(param); | return deconv::can_stride2_int8x8x32_dot(param); | ||||
| } | } | ||||
| @@ -17,7 +17,7 @@ | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace arm_common { | namespace arm_common { | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| /* ===================== ConvolutionBackwardData ===================== */ | /* ===================== ConvolutionBackwardData ===================== */ | ||||
| class ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1 final | class ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1 final | ||||
| @@ -9,11 +9,9 @@ | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| */ | */ | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #include "src/arm_common/convolution/int8x8x32/conv_backdata_stride1.h" | #include "src/arm_common/convolution/int8x8x32/conv_backdata_stride1.h" | ||||
| #if MGB_ENABLE_DOT | |||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include <cstring> | |||||
| #include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
| using namespace megdnn; | using namespace megdnn; | ||||
| @@ -94,6 +92,7 @@ inline int8x16_t vqtbl1q_s8_common(int8x16_t a, uint8x16_t index) { | |||||
| _sum0##_c_idx = vdotq_s32(_sum0##_c_idx, _k##_k1_idx, _elem); \ | _sum0##_c_idx = vdotq_s32(_sum0##_c_idx, _k##_k1_idx, _elem); \ | ||||
| _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k2_idx, _elem); | _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k2_idx, _elem); | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| void deconv_direct_2x2(const int8_t* src, const int8_t* filter, int32_t* dst, | void deconv_direct_2x2(const int8_t* src, const int8_t* filter, int32_t* dst, | ||||
| size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) { | size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) { | ||||
| MEGDNN_MARK_USED_VAR(IH); | MEGDNN_MARK_USED_VAR(IH); | ||||
| @@ -328,6 +327,7 @@ void deconv_direct_2x2(const int8_t* src, const int8_t* filter, int32_t* dst, | |||||
| } | } | ||||
| } | } | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| void deconv_direct_3x3(const int8_t* src, const int8_t* filter, int32_t* dst, | void deconv_direct_3x3(const int8_t* src, const int8_t* filter, int32_t* dst, | ||||
| size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) { | size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) { | ||||
| MEGDNN_MARK_USED_VAR(IH); | MEGDNN_MARK_USED_VAR(IH); | ||||
| @@ -530,6 +530,7 @@ void deconv_direct_3x3(const int8_t* src, const int8_t* filter, int32_t* dst, | |||||
| _sum0##_c_idx = vdotq_s32(_sum0##_c_idx, _k##_k01_idx, _elem); \ | _sum0##_c_idx = vdotq_s32(_sum0##_c_idx, _k##_k01_idx, _elem); \ | ||||
| _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k11_idx, _elem); | _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k11_idx, _elem); | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| void deconv_direct_5x5(const int8_t* src, const int8_t* filter, int32_t* dst, | void deconv_direct_5x5(const int8_t* src, const int8_t* filter, int32_t* dst, | ||||
| size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) { | size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) { | ||||
| MEGDNN_MARK_USED_VAR(IH); | MEGDNN_MARK_USED_VAR(IH); | ||||
| @@ -777,6 +778,7 @@ void deconv_direct_5x5(const int8_t* src, const int8_t* filter, int32_t* dst, | |||||
| } | } | ||||
| } | } | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| void deconv_direct_7x7(const int8_t* src, const int8_t* filter, int32_t* dst, | void deconv_direct_7x7(const int8_t* src, const int8_t* filter, int32_t* dst, | ||||
| size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) { | size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) { | ||||
| MEGDNN_MARK_USED_VAR(IH); | MEGDNN_MARK_USED_VAR(IH); | ||||
| @@ -1070,6 +1072,7 @@ void deconv_direct_7x7(const int8_t* src, const int8_t* filter, int32_t* dst, | |||||
| } // anonymous namespace | } // anonymous namespace | ||||
| size_t deconv::get_workspace_in_bytes_stride1_int8x8x32_dot( | size_t deconv::get_workspace_in_bytes_stride1_int8x8x32_dot( | ||||
| const NCBKernSizeParam& param) { | const NCBKernSizeParam& param) { | ||||
| return get_bundle(param).total_size_in_bytes(); | return get_bundle(param).total_size_in_bytes(); | ||||
| @@ -10,8 +10,8 @@ | |||||
| */ | */ | ||||
| #pragma once | #pragma once | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #include "src/arm_common/convolution/opr_impl.h" | #include "src/arm_common/convolution/opr_impl.h" | ||||
| #if MGB_ENABLE_DOT | |||||
| #include <cstddef> | #include <cstddef> | ||||
| #include <cstdint> | #include <cstdint> | ||||
| @@ -9,11 +9,9 @@ | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| */ | */ | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #include "src/arm_common/convolution/int8x8x32/conv_backdata_stride2.h" | #include "src/arm_common/convolution/int8x8x32/conv_backdata_stride2.h" | ||||
| #if MGB_ENABLE_DOT | |||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include <cstring> | |||||
| #include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
| using namespace megdnn; | using namespace megdnn; | ||||
| @@ -83,6 +81,7 @@ inline int8x16_t vqtbl1q_s8_common(int8x16_t a, uint8x16_t index) { | |||||
| _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k2_idx, _elem); | _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k2_idx, _elem); | ||||
| template <bool even> | template <bool even> | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| void deconv_direct_2x2(const int8_t* src, const int8_t* filter, int32_t* dst, | void deconv_direct_2x2(const int8_t* src, const int8_t* filter, int32_t* dst, | ||||
| size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) { | size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) { | ||||
| MEGDNN_MARK_USED_VAR(IH); | MEGDNN_MARK_USED_VAR(IH); | ||||
| @@ -334,6 +333,7 @@ void deconv_direct_2x2(const int8_t* src, const int8_t* filter, int32_t* dst, | |||||
| } | } | ||||
| template <bool even> | template <bool even> | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| void deconv_direct_3x3(const int8_t* src, const int8_t* filter, int32_t* dst, | void deconv_direct_3x3(const int8_t* src, const int8_t* filter, int32_t* dst, | ||||
| size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) { | size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) { | ||||
| MEGDNN_MARK_USED_VAR(IH); | MEGDNN_MARK_USED_VAR(IH); | ||||
| @@ -558,6 +558,7 @@ void deconv_direct_3x3(const int8_t* src, const int8_t* filter, int32_t* dst, | |||||
| _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k11_idx, _elem); | _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k11_idx, _elem); | ||||
| template <bool even> | template <bool even> | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| void deconv_direct_5x5(const int8_t* src, const int8_t* filter, int32_t* dst, | void deconv_direct_5x5(const int8_t* src, const int8_t* filter, int32_t* dst, | ||||
| size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) { | size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) { | ||||
| MEGDNN_MARK_USED_VAR(IH); | MEGDNN_MARK_USED_VAR(IH); | ||||
| @@ -835,6 +836,7 @@ void deconv_direct_5x5(const int8_t* src, const int8_t* filter, int32_t* dst, | |||||
| } | } | ||||
| template <bool even> | template <bool even> | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| void deconv_direct_7x7(const int8_t* src, const int8_t* filter, int32_t* dst, | void deconv_direct_7x7(const int8_t* src, const int8_t* filter, int32_t* dst, | ||||
| size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) { | size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) { | ||||
| MEGDNN_MARK_USED_VAR(IH); | MEGDNN_MARK_USED_VAR(IH); | ||||
| @@ -10,8 +10,8 @@ | |||||
| */ | */ | ||||
| #pragma once | #pragma once | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #include "src/arm_common/convolution/opr_impl.h" | #include "src/arm_common/convolution/opr_impl.h" | ||||
| #if MGB_ENABLE_DOT | |||||
| #include <cstddef> | #include <cstddef> | ||||
| #include <cstdint> | #include <cstdint> | ||||
| @@ -24,7 +24,7 @@ using namespace arm_common; | |||||
| /* ===================== ConvolutionBackwardData ===================== */ | /* ===================== ConvolutionBackwardData ===================== */ | ||||
| class ConvolutionBackwardDataImpl::AlgoPack : NonCopyableObj { | class ConvolutionBackwardDataImpl::AlgoPack : NonCopyableObj { | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| AlgoSdot8DirectStride1 i8x8x32_direct_stride1_sdot; | AlgoSdot8DirectStride1 i8x8x32_direct_stride1_sdot; | ||||
| AlgoSdot8DirectStride2 i8x8x32_direct_stride2_sdot; | AlgoSdot8DirectStride2 i8x8x32_direct_stride2_sdot; | ||||
| AlgoUdot8DirectStride1 quint8_direct_stride1_udot; | AlgoUdot8DirectStride1 quint8_direct_stride1_udot; | ||||
| @@ -37,7 +37,7 @@ class ConvolutionBackwardDataImpl::AlgoPack : NonCopyableObj { | |||||
| public: | public: | ||||
| AlgoPack() { | AlgoPack() { | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| m_all_algos.emplace_back(&i8x8x32_direct_stride1_sdot); | m_all_algos.emplace_back(&i8x8x32_direct_stride1_sdot); | ||||
| m_all_algos.emplace_back(&i8x8x32_direct_stride2_sdot); | m_all_algos.emplace_back(&i8x8x32_direct_stride2_sdot); | ||||
| m_all_algos.emplace_back(&quint8_direct_stride1_udot); | m_all_algos.emplace_back(&quint8_direct_stride1_udot); | ||||
| @@ -56,7 +56,7 @@ public: | |||||
| MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(ConvolutionBackwardDataImpl); | MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(ConvolutionBackwardDataImpl); | ||||
| private: | private: | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| class AlgoSdot8DirectStride1; | class AlgoSdot8DirectStride1; | ||||
| class AlgoSdot8DirectStride2; | class AlgoSdot8DirectStride2; | ||||
| class AlgoUdot8DirectStride1; | class AlgoUdot8DirectStride1; | ||||
| @@ -14,6 +14,7 @@ | |||||
| #include "src/arm_common/convolution/quint8/conv_backdata_stride1.h" | #include "src/arm_common/convolution/quint8/conv_backdata_stride1.h" | ||||
| #include "src/arm_common/convolution/quint8/conv_backdata_stride2.h" | #include "src/arm_common/convolution/quint8/conv_backdata_stride2.h" | ||||
| #include "src/common/opr_delegate.h" | #include "src/common/opr_delegate.h" | ||||
| #include "midout.h" | #include "midout.h" | ||||
| MIDOUT_DECL(megdnn_arm_conv_quint8_kimpl) | MIDOUT_DECL(megdnn_arm_conv_quint8_kimpl) | ||||
| @@ -21,7 +22,7 @@ MIDOUT_DECL(megdnn_arm_conv_quint8_kimpl) | |||||
| using namespace megdnn; | using namespace megdnn; | ||||
| using namespace arm_common; | using namespace arm_common; | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| /* ===================== ConvolutionBackwardData ===================== */ | /* ===================== ConvolutionBackwardData ===================== */ | ||||
| @@ -29,6 +30,10 @@ using namespace arm_common; | |||||
| bool ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1::usable( | bool ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1::usable( | ||||
| fallback::ConvolutionBackwardDataImpl*, | fallback::ConvolutionBackwardDataImpl*, | ||||
| const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
| if (!cpuinfo_has_arm_neon_dot()){ | |||||
| return false; | |||||
| } | |||||
| return deconv::can_stride1_quint8_dot(param); | return deconv::can_stride1_quint8_dot(param); | ||||
| } | } | ||||
| @@ -58,6 +63,9 @@ ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1::dispatch_kern( | |||||
| bool ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2::usable( | bool ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2::usable( | ||||
| fallback::ConvolutionBackwardDataImpl*, | fallback::ConvolutionBackwardDataImpl*, | ||||
| const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
| if (!cpuinfo_has_arm_neon_dot()){ | |||||
| return false; | |||||
| } | |||||
| return deconv::can_stride2_quint8_dot(param); | return deconv::can_stride2_quint8_dot(param); | ||||
| } | } | ||||
| @@ -17,7 +17,7 @@ | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace arm_common { | namespace arm_common { | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| /* ===================== ConvolutionBackwardData ===================== */ | /* ===================== ConvolutionBackwardData ===================== */ | ||||
| class ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1 final | class ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1 final | ||||
| : public AlgoBase { | : public AlgoBase { | ||||
| @@ -9,11 +9,9 @@ | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| */ | */ | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #include "src/arm_common/convolution/quint8/conv_backdata_stride1.h" | #include "src/arm_common/convolution/quint8/conv_backdata_stride1.h" | ||||
| #if MGB_ENABLE_DOT | |||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include <cstring> | |||||
| #include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
| using namespace megdnn; | using namespace megdnn; | ||||
| @@ -109,6 +107,7 @@ inline uint8x16_t vqtbl1q_u8_common(uint8x16_t a, uint8x16_t index) { | |||||
| _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_filter_zp, _elem)); | _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_filter_zp, _elem)); | ||||
| template <bool last_oc = false> | template <bool last_oc = false> | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| void deconv_direct_2x2(const uint8_t* src, const uint8_t* filter, int32_t* dst, | void deconv_direct_2x2(const uint8_t* src, const uint8_t* filter, int32_t* dst, | ||||
| size_t IH, size_t IW, size_t OH, size_t OW, size_t IC, | size_t IH, size_t IW, size_t OH, size_t OW, size_t IC, | ||||
| uint8_t src_zp, uint8_t filter_zp, | uint8_t src_zp, uint8_t filter_zp, | ||||
| @@ -385,6 +384,7 @@ void deconv_direct_2x2(const uint8_t* src, const uint8_t* filter, int32_t* dst, | |||||
| } | } | ||||
| template <bool last_oc = false> | template <bool last_oc = false> | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| void deconv_direct_3x3(const uint8_t* src, const uint8_t* filter, int32_t* dst, | void deconv_direct_3x3(const uint8_t* src, const uint8_t* filter, int32_t* dst, | ||||
| size_t IH, size_t IW, size_t OH, size_t OW, size_t IC, | size_t IH, size_t IW, size_t OH, size_t OW, size_t IC, | ||||
| uint8_t src_zp, uint8_t filter_zp, | uint8_t src_zp, uint8_t filter_zp, | ||||
| @@ -636,6 +636,7 @@ void deconv_direct_3x3(const uint8_t* src, const uint8_t* filter, int32_t* dst, | |||||
| _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, _elem2); | _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, _elem2); | ||||
| template <bool last_oc = false> | template <bool last_oc = false> | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| void deconv_direct_5x5(const uint8_t* src, const uint8_t* filter, int32_t* dst, | void deconv_direct_5x5(const uint8_t* src, const uint8_t* filter, int32_t* dst, | ||||
| size_t IH, size_t IW, size_t OH, size_t OW, size_t IC, | size_t IH, size_t IW, size_t OH, size_t OW, size_t IC, | ||||
| uint8_t src_zp, uint8_t filter_zp, | uint8_t src_zp, uint8_t filter_zp, | ||||
| @@ -907,6 +908,7 @@ void deconv_direct_5x5(const uint8_t* src, const uint8_t* filter, int32_t* dst, | |||||
| } | } | ||||
| template <bool last_oc = false> | template <bool last_oc = false> | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| void deconv_direct_7x7(const uint8_t* src, const uint8_t* filter, int32_t* dst, | void deconv_direct_7x7(const uint8_t* src, const uint8_t* filter, int32_t* dst, | ||||
| size_t IH, size_t IW, size_t OH, size_t OW, size_t IC, | size_t IH, size_t IW, size_t OH, size_t OW, size_t IC, | ||||
| uint8_t src_zp, uint8_t filter_zp, | uint8_t src_zp, uint8_t filter_zp, | ||||
| @@ -1220,6 +1222,7 @@ void deconv_direct_7x7(const uint8_t* src, const uint8_t* filter, int32_t* dst, | |||||
| } // anonymous namespace | } // anonymous namespace | ||||
| size_t deconv::get_workspace_in_bytes_stride1_quint8_dot( | size_t deconv::get_workspace_in_bytes_stride1_quint8_dot( | ||||
| const NCBKernSizeParam& param) { | const NCBKernSizeParam& param) { | ||||
| return get_bundle(param).total_size_in_bytes(); | return get_bundle(param).total_size_in_bytes(); | ||||
| @@ -10,11 +10,8 @@ | |||||
| */ | */ | ||||
| #pragma once | #pragma once | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #include "src/arm_common/convolution/opr_impl.h" | #include "src/arm_common/convolution/opr_impl.h" | ||||
| #include <cstddef> | |||||
| #include <cstdint> | |||||
| #if MGB_ENABLE_DOT | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace arm_common { | namespace arm_common { | ||||
| @@ -9,11 +9,9 @@ | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| */ | */ | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #include "src/arm_common/convolution/quint8/conv_backdata_stride2.h" | #include "src/arm_common/convolution/quint8/conv_backdata_stride2.h" | ||||
| #if MGB_ENABLE_DOT | |||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include <cstring> | |||||
| #include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
| using namespace megdnn; | using namespace megdnn; | ||||
| @@ -110,6 +108,7 @@ inline uint8x16_t vqtbx1q_u8_common(uint8x16_t a, uint8x16_t t, | |||||
| _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_filter_zp, _elem)); | _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_filter_zp, _elem)); | ||||
| template <bool even, bool last_oc = false> | template <bool even, bool last_oc = false> | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| void deconv_direct_2x2(const uint8_t* src, const uint8_t* filter, int32_t* dst, | void deconv_direct_2x2(const uint8_t* src, const uint8_t* filter, int32_t* dst, | ||||
| size_t IH, size_t IW, size_t OH, size_t OW, size_t IC, | size_t IH, size_t IW, size_t OH, size_t OW, size_t IC, | ||||
| uint8_t src_zp, uint8_t filter_zp, | uint8_t src_zp, uint8_t filter_zp, | ||||
| @@ -402,6 +401,7 @@ void deconv_direct_2x2(const uint8_t* src, const uint8_t* filter, int32_t* dst, | |||||
| } | } | ||||
| template <bool even, bool last_oc = false> | template <bool even, bool last_oc = false> | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| void deconv_direct_3x3(const uint8_t* src, const uint8_t* filter, int32_t* dst, | void deconv_direct_3x3(const uint8_t* src, const uint8_t* filter, int32_t* dst, | ||||
| size_t IH, size_t IW, size_t OH, size_t OW, size_t IC, | size_t IH, size_t IW, size_t OH, size_t OW, size_t IC, | ||||
| uint8_t src_zp, uint8_t filter_zp, | uint8_t src_zp, uint8_t filter_zp, | ||||
| @@ -673,6 +673,7 @@ void deconv_direct_3x3(const uint8_t* src, const uint8_t* filter, int32_t* dst, | |||||
| _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, _elem2); | _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, _elem2); | ||||
| template <bool even, bool last_oc = false> | template <bool even, bool last_oc = false> | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| void deconv_direct_5x5(const uint8_t* src, const uint8_t* filter, int32_t* dst, | void deconv_direct_5x5(const uint8_t* src, const uint8_t* filter, int32_t* dst, | ||||
| size_t IH, size_t IW, size_t OH, size_t OW, size_t IC, | size_t IH, size_t IW, size_t OH, size_t OW, size_t IC, | ||||
| uint8_t src_zp, uint8_t filter_zp, | uint8_t src_zp, uint8_t filter_zp, | ||||
| @@ -972,6 +973,7 @@ void deconv_direct_5x5(const uint8_t* src, const uint8_t* filter, int32_t* dst, | |||||
| } | } | ||||
| template <bool even, bool last_oc = false> | template <bool even, bool last_oc = false> | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| void deconv_direct_7x7(const uint8_t* src, const uint8_t* filter, int32_t* dst, | void deconv_direct_7x7(const uint8_t* src, const uint8_t* filter, int32_t* dst, | ||||
| size_t IH, size_t IW, size_t OH, size_t OW, size_t IC, | size_t IH, size_t IW, size_t OH, size_t OW, size_t IC, | ||||
| uint8_t src_zp, uint8_t filter_zp, | uint8_t src_zp, uint8_t filter_zp, | ||||
| @@ -10,11 +10,8 @@ | |||||
| */ | */ | ||||
| #pragma once | #pragma once | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #include "src/arm_common/convolution/opr_impl.h" | #include "src/arm_common/convolution/opr_impl.h" | ||||
| #include <cstddef> | |||||
| #include <cstdint> | |||||
| #if MGB_ENABLE_DOT | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace arm_common { | namespace arm_common { | ||||
| @@ -14,8 +14,10 @@ | |||||
| #include "src/arm_common/matrix_mul/fp16/hgemv.h" | #include "src/arm_common/matrix_mul/fp16/hgemv.h" | ||||
| #include "src/arm_common/matrix_mul/fp32/exec_sgemv.h" | #include "src/arm_common/matrix_mul/fp32/exec_sgemv.h" | ||||
| #include "src/arm_common/matrix_mul/int8/gemv.h" | #include "src/arm_common/matrix_mul/int8/gemv.h" | ||||
| #include "midout.h" | #include "midout.h" | ||||
| MIDOUT_DECL(megdnn_arm_hgemv) | MIDOUT_DECL(megdnn_arm_hgemv) | ||||
| MIDOUT_DECL(megdnn_arm_exec_int8816) | MIDOUT_DECL(megdnn_arm_exec_int8816) | ||||
| MIDOUT_DECL(megdnn_arm_exec_int8832) | MIDOUT_DECL(megdnn_arm_exec_int8832) | ||||
| @@ -158,7 +160,7 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32GemvMK4::get_kern( | |||||
| return int8x8x32_gemv_mk4_kern; | return int8x8x32_gemv_mk4_kern; | ||||
| } | } | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| /* =================== Int8x8x32 Gemv MK4_DOT algo ==================== */ | /* =================== Int8x8x32 Gemv MK4_DOT algo ==================== */ | ||||
| namespace { | namespace { | ||||
| void int8x8x32_gemv_mk4_dot_kern(const MatrixMulImpl::KernParam& kern_param) { | void int8x8x32_gemv_mk4_dot_kern(const MatrixMulImpl::KernParam& kern_param) { | ||||
| @@ -176,6 +178,10 @@ void int8x8x32_gemv_mk4_dot_kern(const MatrixMulImpl::KernParam& kern_param) { | |||||
| bool MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot::usable( | bool MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot::usable( | ||||
| const KernSizeParam& kern_size_param) const { | const KernSizeParam& kern_size_param) const { | ||||
| if (!cpuinfo_has_arm_neon_dot()){ | |||||
| return false; | |||||
| } | |||||
| auto M = kern_size_param.M; | auto M = kern_size_param.M; | ||||
| auto N = kern_size_param.N; | auto N = kern_size_param.N; | ||||
| auto K = kern_size_param.K; | auto K = kern_size_param.K; | ||||
| @@ -63,7 +63,7 @@ public: | |||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_INT8X8X32_GEMV_MK4) | MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_INT8X8X32_GEMV_MK4) | ||||
| }; | }; | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| class MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot : public AlgoBase { | ||||
| public: | public: | ||||
| AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
| @@ -9,7 +9,6 @@ | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| */ | */ | ||||
| #include <cstddef> | |||||
| #include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
| #include "src/arm_common/matrix_mul/int8/gemv.h" | #include "src/arm_common/matrix_mul/int8/gemv.h" | ||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| @@ -21,7 +20,6 @@ MIDOUT_DECL(megdnn_arm_common_int8_gemv) | |||||
| using namespace megdnn; | using namespace megdnn; | ||||
| using namespace arm_common; | using namespace arm_common; | ||||
| #if !__ARM_FEATURE_DOTPROD | |||||
| namespace { | namespace { | ||||
| @@ -170,12 +168,11 @@ void gemv_naive_n_mk4(const int8_t* __restrict A, const int8_t* __restrict B, | |||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| #endif | |||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| namespace { | namespace { | ||||
| void gemv_naive_n(const int8_t* __restrict A, const int8_t* __restrict B, | |||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| void gemv_naive_n_dot(const int8_t* __restrict A, const int8_t* __restrict B, | |||||
| int32_t* __restrict C, size_t M, size_t N, size_t K, | int32_t* __restrict C, size_t M, size_t N, size_t K, | ||||
| size_t Astride, size_t Bstride, size_t Cstride) { | size_t Astride, size_t Bstride, size_t Cstride) { | ||||
| megdnn_assert(N == 1 && Bstride == 1); | megdnn_assert(N == 1 && Bstride == 1); | ||||
| @@ -244,7 +241,8 @@ void gemv_naive_n(const int8_t* __restrict A, const int8_t* __restrict B, | |||||
| } | } | ||||
| } | } | ||||
| void gemv_naive_n_mk4(const int8_t* __restrict A, const int8_t* __restrict B, | |||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| void gemv_naive_n_mk4_dotprod(const int8_t* __restrict A, const int8_t* __restrict B, | |||||
| int32_t* __restrict C, size_t M, size_t N, size_t K, | int32_t* __restrict C, size_t M, size_t N, size_t K, | ||||
| size_t Astride, size_t Bstride, size_t Cstride) { | size_t Astride, size_t Bstride, size_t Cstride) { | ||||
| constexpr size_t PACK_SIZE = 4; | constexpr size_t PACK_SIZE = 4; | ||||
| @@ -323,6 +321,7 @@ void gemv_naive_n_mk4(const int8_t* __restrict A, const int8_t* __restrict B, | |||||
| } | } | ||||
| } | } | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| void gemv_naive_n_mk4_dot(const int8_t* __restrict A, | void gemv_naive_n_mk4_dot(const int8_t* __restrict A, | ||||
| const int8_t* __restrict B, int32_t* __restrict C, | const int8_t* __restrict B, int32_t* __restrict C, | ||||
| size_t M, size_t N, size_t K, size_t Astride, | size_t M, size_t N, size_t K, size_t Astride, | ||||
| @@ -403,7 +402,16 @@ void arm_common::gemv_like(const int8_t* __restrict A, | |||||
| megdnn_assert(N == 1); | megdnn_assert(N == 1); | ||||
| MIDOUT_BEGIN(megdnn_arm_common_int8_gemv, | MIDOUT_BEGIN(megdnn_arm_common_int8_gemv, | ||||
| midout_iv("INT8_gemv_like"_hash)) { | midout_iv("INT8_gemv_like"_hash)) { | ||||
| #if MGB_ENABLE_DOT | |||||
| if (cpuinfo_has_arm_neon_dot()) { | |||||
| return gemv_naive_n_dot(A, B, C, M, N, K, Astride, Bstride, | |||||
| Cstride); | |||||
| } else { | |||||
| return gemv_naive_n(A, B, C, M, N, K, Astride, Bstride, Cstride); | |||||
| } | |||||
| #else | |||||
| return gemv_naive_n(A, B, C, M, N, K, Astride, Bstride, Cstride); | return gemv_naive_n(A, B, C, M, N, K, Astride, Bstride, Cstride); | ||||
| #endif | |||||
| } | } | ||||
| MIDOUT_END(); | MIDOUT_END(); | ||||
| } | } | ||||
| @@ -416,12 +424,22 @@ void arm_common::gemv_like_mk4(const int8_t* __restrict A, | |||||
| megdnn_assert(N == 1); | megdnn_assert(N == 1); | ||||
| MIDOUT_BEGIN(megdnn_arm_common_int8_gemv, | MIDOUT_BEGIN(megdnn_arm_common_int8_gemv, | ||||
| midout_iv("INT8_gemv_like_mk4"_hash)) { | midout_iv("INT8_gemv_like_mk4"_hash)) { | ||||
| #if MGB_ENABLE_DOT | |||||
| if (cpuinfo_has_arm_neon_dot()) { | |||||
| return gemv_naive_n_mk4_dotprod(A, B, C, M, N, K, Astride, Bstride, | |||||
| Cstride); | |||||
| } else { | |||||
| return gemv_naive_n_mk4(A, B, C, M, N, K, Astride, Bstride, | |||||
| Cstride); | |||||
| } | |||||
| #else | |||||
| return gemv_naive_n_mk4(A, B, C, M, N, K, Astride, Bstride, Cstride); | return gemv_naive_n_mk4(A, B, C, M, N, K, Astride, Bstride, Cstride); | ||||
| #endif | |||||
| } | } | ||||
| MIDOUT_END(); | MIDOUT_END(); | ||||
| } | } | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| void arm_common::gemv_like_mk4_dot(const int8_t* __restrict A, | void arm_common::gemv_like_mk4_dot(const int8_t* __restrict A, | ||||
| const int8_t* __restrict B, | const int8_t* __restrict B, | ||||
| int32_t* __restrict C, size_t M, size_t N, | int32_t* __restrict C, size_t M, size_t N, | ||||
| @@ -437,4 +455,5 @@ void arm_common::gemv_like_mk4_dot(const int8_t* __restrict A, | |||||
| } | } | ||||
| #endif | #endif | ||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -28,7 +28,7 @@ void gemv_like_mk4(const int8_t* __restrict A, const int8_t* __restrict B, | |||||
| int32_t* __restrict C, size_t M, size_t N, size_t K, | int32_t* __restrict C, size_t M, size_t N, size_t K, | ||||
| size_t Astride, size_t Bstride, size_t Cstride); | size_t Astride, size_t Bstride, size_t Cstride); | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| void gemv_like_mk4_dot(const int8_t* __restrict A, const int8_t* __restrict B, | void gemv_like_mk4_dot(const int8_t* __restrict A, const int8_t* __restrict B, | ||||
| int32_t* __restrict C, size_t M, size_t N, size_t K, | int32_t* __restrict C, size_t M, size_t N, size_t K, | ||||
| size_t Astride, size_t Bstride, size_t Cstride); | size_t Astride, size_t Bstride, size_t Cstride); | ||||
| @@ -22,7 +22,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { | |||||
| #endif | #endif | ||||
| AlgoInt8x8x32Gemv int8x8x32_gemv; | AlgoInt8x8x32Gemv int8x8x32_gemv; | ||||
| AlgoInt8x8x32GemvMK4 int8x8x32_gemv_mk4; | AlgoInt8x8x32GemvMK4 int8x8x32_gemv_mk4; | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| AlgoInt8x8x32GemvMK4Dot int8x8x32_gemv_mk4_dot; | AlgoInt8x8x32GemvMK4Dot int8x8x32_gemv_mk4_dot; | ||||
| #endif | #endif | ||||
| AlgoGevm gevm; | AlgoGevm gevm; | ||||
| @@ -37,7 +37,7 @@ public: | |||||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
| m_all_algos.emplace_back(&f16gemv); | m_all_algos.emplace_back(&f16gemv); | ||||
| #endif | #endif | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| m_all_algos.emplace_back(&int8x8x32_gemv_mk4_dot); | m_all_algos.emplace_back(&int8x8x32_gemv_mk4_dot); | ||||
| #endif | #endif | ||||
| m_all_algos.emplace_back(&int8x8x32_gemv); | m_all_algos.emplace_back(&int8x8x32_gemv); | ||||
| @@ -42,7 +42,7 @@ protected: | |||||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
| class AlgoF16Gemv; | class AlgoF16Gemv; | ||||
| #endif | #endif | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| class AlgoInt8x8x32GemvMK4Dot;// Arm_common Int8x8x32 Gemv NCHW44_DOT | class AlgoInt8x8x32GemvMK4Dot;// Arm_common Int8x8x32 Gemv NCHW44_DOT | ||||
| #endif | #endif | ||||
| class AlgoInt8x8x16; // Arm_common Int 8x8x16 | class AlgoInt8x8x16; // Arm_common Int 8x8x16 | ||||
| @@ -69,9 +69,10 @@ struct Vfmaq_laneq_f32 { | |||||
| return vfmaq_laneq_f32(a, b, v, lane); | return vfmaq_laneq_f32(a, b, v, lane); | ||||
| } | } | ||||
| }; | }; | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| struct Vdotq_laneq_s32 { | struct Vdotq_laneq_s32 { | ||||
| template <const int lane> | template <const int lane> | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| static __ai int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) { | static __ai int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) { | ||||
| return vdotq_laneq_s32(a, b, v, lane); | return vdotq_laneq_s32(a, b, v, lane); | ||||
| } | } | ||||
| @@ -82,4 +83,4 @@ struct Vdotq_laneq_s32 { | |||||
| } // namespace megdnn | } // namespace megdnn | ||||
| #undef __ai | #undef __ai | ||||
| // vim: syntax=cpp.doxygen | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -10,7 +10,12 @@ | |||||
| * implied. | * implied. | ||||
| */ | */ | ||||
| #pragma once | #pragma once | ||||
| #if MGB_ENABLE_DOT | |||||
| #if defined(__ARM_FEATURE_DOTPROD) | |||||
| #undef __ARM_FEATURE_DOTPROD | |||||
| #endif | |||||
| #define __ARM_FEATURE_DOTPROD 1 | |||||
| #endif | |||||
| #include <arm_neon.h> | #include <arm_neon.h> | ||||
| #include "megdnn/arch.h" | #include "megdnn/arch.h" | ||||
| #include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
| @@ -249,13 +254,14 @@ __ai float16x8_t vdupq_n_f16(__fp16 a) { | |||||
| #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| __ai int32x4_t vdotq2_s32(int8x16_t a, int8x16_t b) { | __ai int32x4_t vdotq2_s32(int8x16_t a, int8x16_t b) { | ||||
| int32x4_t c = vdupq_n_s32(0); | int32x4_t c = vdupq_n_s32(0); | ||||
| return vdotq_s32(c, a, b); | return vdotq_s32(c, a, b); | ||||
| } | } | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| __ai uint32x4_t vdotq2_u32(uint8x16_t a, uint8x16_t b) { | __ai uint32x4_t vdotq2_u32(uint8x16_t a, uint8x16_t b) { | ||||
| uint32x4_t c = vdupq_n_u32(0); | uint32x4_t c = vdupq_n_u32(0); | ||||
| return vdotq_u32(c, a, b); | return vdotq_u32(c, a, b); | ||||
| @@ -275,11 +281,13 @@ __ai uint32x4_t vdotq2_u32(uint8x16_t a, uint8x16_t b) { | |||||
| c; \ | c; \ | ||||
| }) | }) | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| __ai int32x2_t vdot2_s32(int8x8_t a, int8x8_t b) { | __ai int32x2_t vdot2_s32(int8x8_t a, int8x8_t b) { | ||||
| int32x2_t c = vdup_n_s32(0); | int32x2_t c = vdup_n_s32(0); | ||||
| return vdot_s32(c, a, b); | return vdot_s32(c, a, b); | ||||
| } | } | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| __ai uint32x2_t vdot2_u8(uint8x8_t a, uint8x8_t b) { | __ai uint32x2_t vdot2_u8(uint8x8_t a, uint8x8_t b) { | ||||
| uint32x2_t c = vdup_n_u32(0); | uint32x2_t c = vdup_n_u32(0); | ||||
| return vdot_u32(c, a, b); | return vdot_u32(c, a, b); | ||||
| @@ -298,8 +306,7 @@ __ai uint32x2_t vdot2_u8(uint8x8_t a, uint8x8_t b) { | |||||
| c = vdot_lane_u32(c, a, b, lane); \ | c = vdot_lane_u32(c, a, b, lane); \ | ||||
| c; \ | c; \ | ||||
| }) | }) | ||||
| #endif // __ARM_FEATURE_DOTPROD | |||||
| #endif // MGB_ENABLE_DOT | |||||
| #if __GNUC__ < 8 | #if __GNUC__ < 8 | ||||
| #undef vld1q_f32_x2 | #undef vld1q_f32_x2 | ||||
| @@ -575,7 +582,7 @@ struct Vfmsq_laneq_f32_armv7<3> { | |||||
| #define vfmsq_laneq_f32(a, b, v, lane) \ | #define vfmsq_laneq_f32(a, b, v, lane) \ | ||||
| Vfmsq_laneq_f32_armv7<lane>::impl(a, b, v) | Vfmsq_laneq_f32_armv7<lane>::impl(a, b, v) | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| namespace { | namespace { | ||||
| template <int lane> | template <int lane> | ||||
| struct Vdotq_laneq_s32_armv7 { | struct Vdotq_laneq_s32_armv7 { | ||||
| @@ -583,24 +590,28 @@ struct Vdotq_laneq_s32_armv7 { | |||||
| }; | }; | ||||
| template <> | template <> | ||||
| struct Vdotq_laneq_s32_armv7<0> { | struct Vdotq_laneq_s32_armv7<0> { | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| __ai int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) { | __ai int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) { | ||||
| return vdotq_lane_s32(a, b, vget_low_s32(v), 0); | return vdotq_lane_s32(a, b, vget_low_s32(v), 0); | ||||
| } | } | ||||
| }; | }; | ||||
| template <> | template <> | ||||
| struct Vdotq_laneq_s32_armv7<1> { | struct Vdotq_laneq_s32_armv7<1> { | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| __ai int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) { | __ai int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) { | ||||
| return vdotq_lane_s32(a, b, vget_low_s32(v), 1); | return vdotq_lane_s32(a, b, vget_low_s32(v), 1); | ||||
| } | } | ||||
| }; | }; | ||||
| template <> | template <> | ||||
| struct Vdotq_laneq_s32_armv7<2> { | struct Vdotq_laneq_s32_armv7<2> { | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| __ai int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) { | __ai int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) { | ||||
| return vdotq_lane_s32(a, b, vget_high_s32(v), 0); | return vdotq_lane_s32(a, b, vget_high_s32(v), 0); | ||||
| } | } | ||||
| }; | }; | ||||
| template <> | template <> | ||||
| struct Vdotq_laneq_s32_armv7<3> { | struct Vdotq_laneq_s32_armv7<3> { | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| __ai int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) { | __ai int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) { | ||||
| return vdotq_lane_s32(a, b, vget_high_f32(v), 1); | return vdotq_lane_s32(a, b, vget_high_f32(v), 1); | ||||
| } | } | ||||
| @@ -765,7 +776,9 @@ __ai float32x4_t Vfmsq_f32(float32x4_t& a, float32x4_t& b, float32x4_t& v) { | |||||
| :); | :); | ||||
| return a; | return a; | ||||
| } | } | ||||
| #if MGB_ENABLE_DOT | |||||
| #undef __ARM_FEATURE_DOTPROD | |||||
| #endif | |||||
| #undef __ai | #undef __ai | ||||
| #pragma GCC diagnostic pop | #pragma GCC diagnostic pop | ||||
| @@ -19,6 +19,9 @@ | |||||
| #include "src/armv7/matrix_mul/quint8/strategy.h" | #include "src/armv7/matrix_mul/quint8/strategy.h" | ||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include "src/fallback/matrix_mul/gemm_impl.h" | #include "src/fallback/matrix_mul/gemm_impl.h" | ||||
| #if MGB_ENABLE_CPUINFO | |||||
| #include "cpuinfo.h" | |||||
| #endif | |||||
| #include "midout.h" | #include "midout.h" | ||||
| @@ -744,7 +747,7 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt16x16x32K12x4x1, | |||||
| armv7::matmul::gemm_s16x16x32_12x4, | armv7::matmul::gemm_s16x16x32_12x4, | ||||
| int16_t, int32_t, | int16_t, int32_t, | ||||
| AlgoDataType::INT16X16X32, DEFAULT); | AlgoDataType::INT16X16X32, DEFAULT); | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| /* ===================== Int8 K6x8x4 algo ===================== */ | /* ===================== Int8 K6x8x4 algo ===================== */ | ||||
| namespace { | namespace { | ||||
| void int8_k6x8x4_kern(const MatrixMulImpl::KernParam& kern_param) { | void int8_k6x8x4_kern(const MatrixMulImpl::KernParam& kern_param) { | ||||
| @@ -769,6 +772,9 @@ void int8_k6x8x4_kern(const MatrixMulImpl::KernParam& kern_param) { | |||||
| bool MatrixMulImpl::AlgoInt8x8x32K6x8x4::usable( | bool MatrixMulImpl::AlgoInt8x8x32K6x8x4::usable( | ||||
| const KernSizeParam& kern_size_param) const { | const KernSizeParam& kern_size_param) const { | ||||
| if (!cpuinfo_has_arm_neon_dot()){ | |||||
| return false; | |||||
| } | |||||
| return can_be_treated_as_int8x8x32(kern_size_param); | return can_be_treated_as_int8x8x32(kern_size_param); | ||||
| } | } | ||||
| @@ -827,6 +833,9 @@ void quint8_dot_k4x8x4_kern(const MatrixMulImpl::KernParam& kern_param) { | |||||
| bool MatrixMulImpl::AlgoQuint8DotK4x8x4::usable( | bool MatrixMulImpl::AlgoQuint8DotK4x8x4::usable( | ||||
| const KernSizeParam& kern_size_param) const { | const KernSizeParam& kern_size_param) const { | ||||
| if (!cpuinfo_has_arm_neon_dot()){ | |||||
| return false; | |||||
| } | |||||
| return kern_size_param.A_type.enumv() == DTypeEnum::Quantized8Asymm && | return kern_size_param.A_type.enumv() == DTypeEnum::Quantized8Asymm && | ||||
| kern_size_param.B_type.enumv() == DTypeEnum::Quantized8Asymm && | kern_size_param.B_type.enumv() == DTypeEnum::Quantized8Asymm && | ||||
| kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32 && | kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32 && | ||||
| @@ -891,6 +900,9 @@ void int8_mk4_8x4x4_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) { | |||||
| bool MatrixMulImpl::AlgoInt8x8x32MK4_8x4x4DotProd::usable( | bool MatrixMulImpl::AlgoInt8x8x32MK4_8x4x4DotProd::usable( | ||||
| const KernSizeParam& kern_size_param) const { | const KernSizeParam& kern_size_param) const { | ||||
| if (!cpuinfo_has_arm_neon_dot()){ | |||||
| return false; | |||||
| } | |||||
| return kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() && | return kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() && | ||||
| (kern_size_param.A_type.enumv() == DTypeEnum::Int8 || | (kern_size_param.A_type.enumv() == DTypeEnum::Int8 || | ||||
| kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8) && | kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8) && | ||||
| @@ -86,7 +86,7 @@ public: | |||||
| MEGDNN_DECL_ALGO_TYPE(ARMV7_F16_MK8_4X8) | MEGDNN_DECL_ALGO_TYPE(ARMV7_F16_MK8_4X8) | ||||
| }; | }; | ||||
| #endif | #endif | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| class MatrixMulImpl::AlgoInt8x8x32K6x8x4 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32K6x8x4 final : public AlgoBase { | ||||
| public: | public: | ||||
| AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
| @@ -10,7 +10,6 @@ | |||||
| * implied. | * implied. | ||||
| */ | */ | ||||
| #pragma once | #pragma once | ||||
| #include <arm_neon.h> | |||||
| #include <cmath> | #include <cmath> | ||||
| #include <cstdint> | #include <cstdint> | ||||
| #include <type_traits> | #include <type_traits> | ||||
| @@ -10,7 +10,6 @@ | |||||
| * implied. | * implied. | ||||
| */ | */ | ||||
| #include "src/arm_common/simd_macro/marm_neon.h" | |||||
| #include "src/armv7/matrix_mul/asm/common.h" | #include "src/armv7/matrix_mul/asm/common.h" | ||||
| #include "src/armv7/matrix_mul/fp32/strategy.h" | #include "src/armv7/matrix_mul/fp32/strategy.h" | ||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| @@ -9,7 +9,7 @@ | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| */ | */ | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| #include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
| #include "src/armv7/matrix_mul/asm/common.h" | #include "src/armv7/matrix_mul/asm/common.h" | ||||
| @@ -43,6 +43,7 @@ namespace matmul_dot_6x8x4 { | |||||
| // | // | ||||
| // Accumulator | // Accumulator | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| static void kern_6x8(const int8_t* packA, const int8_t* packB, int K, | static void kern_6x8(const int8_t* packA, const int8_t* packB, int K, | ||||
| int32_t* output, int LDC, bool is_first_k, | int32_t* output, int LDC, bool is_first_k, | ||||
| size_t m_remain = 6) { | size_t m_remain = 6) { | ||||
| @@ -274,6 +275,7 @@ static void kern_6x8(const int8_t* packA, const int8_t* packB, int K, | |||||
| // | // | ||||
| // Accumulator | // Accumulator | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| static void kern_6x4(const int8_t* packA, const int8_t* packB, int K, | static void kern_6x4(const int8_t* packA, const int8_t* packB, int K, | ||||
| int32_t* output, int LDC, bool is_first_k, | int32_t* output, int LDC, bool is_first_k, | ||||
| size_t n_remain = 8, size_t m_remain = 6) { | size_t n_remain = 8, size_t m_remain = 6) { | ||||
| @@ -10,7 +10,7 @@ | |||||
| * implied. | * implied. | ||||
| */ | */ | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| #include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
| #include "src/armv7/matrix_mul/asm/common.h" | #include "src/armv7/matrix_mul/asm/common.h" | ||||
| @@ -42,7 +42,7 @@ namespace matmul_mk4_dot_8x4x4 { | |||||
| // |q14[0-4]| | // |q14[0-4]| | ||||
| // +--------+ | // +--------+ | ||||
| // Accumulator | // Accumulator | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | ||||
| int32_t* output, int LDC, bool is_first_k, int n_remain) { | int32_t* output, int LDC, bool is_first_k, int n_remain) { | ||||
| K /= 4; | K /= 4; | ||||
| @@ -211,6 +211,7 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | |||||
| // +--------+ | // +--------+ | ||||
| // Accumulator | // Accumulator | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | ||||
| int32_t* output, int LDC, bool is_first_k, int n_remain) { | int32_t* output, int LDC, bool is_first_k, int n_remain) { | ||||
| K /= 4; | K /= 4; | ||||
| @@ -175,7 +175,7 @@ void gemm_s8_4x8::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||||
| } | } | ||||
| } | } | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| // ===========================gemm_s8_6x8====================================== | // ===========================gemm_s8_6x8====================================== | ||||
| MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_dots8_6x8); | MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_dots8_6x8); | ||||
| void gemm_dots8_6x8::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0, | void gemm_dots8_6x8::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0, | ||||
| @@ -23,7 +23,7 @@ MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 4, 8, 8, false, true, | |||||
| MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 4, 2, 16, false, false, | MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 4, 2, 16, false, false, | ||||
| gemm_mk4_s8_4x2); | gemm_mk4_s8_4x2); | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 6, 8, 4, false, false, | MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 6, 8, 4, false, false, | ||||
| gemm_dots8_6x8); | gemm_dots8_6x8); | ||||
| @@ -27,7 +27,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { | |||||
| AlgoF16K4x16x1 f16_k4x16x1; | AlgoF16K4x16x1 f16_k4x16x1; | ||||
| AlgoF16MK8_4x8 f16_mk8_4x8; | AlgoF16MK8_4x8 f16_mk8_4x8; | ||||
| #endif | #endif | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| AlgoInt8x8x32K6x8x4 int8_k6x8x4; | AlgoInt8x8x32K6x8x4 int8_k6x8x4; | ||||
| AlgoQuint8DotK4x8x4 quint8_k4x8x4; | AlgoQuint8DotK4x8x4 quint8_k4x8x4; | ||||
| AlgoInt8x8x32MK4_8x4x4DotProd int8x8x32_mk4_8x4x4_dotprod; | AlgoInt8x8x32MK4_8x4x4DotProd int8x8x32_mk4_8x4x4_dotprod; | ||||
| @@ -57,7 +57,7 @@ public: | |||||
| m_all_algos.emplace_back(&f16_k4x16x1); | m_all_algos.emplace_back(&f16_k4x16x1); | ||||
| m_all_algos.emplace_back(&f16_mk8_4x8); | m_all_algos.emplace_back(&f16_mk8_4x8); | ||||
| #endif | #endif | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| m_all_algos.emplace_back(&int8x8x32_mk4_8x4x4_dotprod); | m_all_algos.emplace_back(&int8x8x32_mk4_8x4x4_dotprod); | ||||
| m_all_algos.emplace_back(&int8_k6x8x4); | m_all_algos.emplace_back(&int8_k6x8x4); | ||||
| m_all_algos.emplace_back(&quint8_k4x8x4); | m_all_algos.emplace_back(&quint8_k4x8x4); | ||||
| @@ -49,7 +49,7 @@ private: | |||||
| class AlgoF16K4x16x1; // Armv7 F16 Kernel 4x16x1 | class AlgoF16K4x16x1; // Armv7 F16 Kernel 4x16x1 | ||||
| class AlgoF16MK8_4x8; // Armv7 F16 MK8 Format block 4x8 | class AlgoF16MK8_4x8; // Armv7 F16 MK8 Format block 4x8 | ||||
| #endif | #endif | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| class AlgoInt8x8x32K6x8x4; // Armv7 Int8 Kernel 6x8x4 | class AlgoInt8x8x32K6x8x4; // Armv7 Int8 Kernel 6x8x4 | ||||
| class AlgoQuint8DotK4x8x4; // Armv7 Quint8 Kernel 6x8x4 | class AlgoQuint8DotK4x8x4; // Armv7 Quint8 Kernel 6x8x4 | ||||
| class AlgoInt8x8x32MK4_8x4x4DotProd; // Armv7 nchw44 Int8x8x32 Kernel 8x4x4 | class AlgoInt8x8x32MK4_8x4x4DotProd; // Armv7 nchw44 Int8x8x32 Kernel 8x4x4 | ||||
| @@ -9,7 +9,7 @@ | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| */ | */ | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| #include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
| #include "src/armv7/matrix_mul/asm/common.h" | #include "src/armv7/matrix_mul/asm/common.h" | ||||
| @@ -41,7 +41,7 @@ namespace matmul_dot_4x8x4 { | |||||
| // +-------+-------+ - - - - +--------+--------+--------+ | // +-------+-------+ - - - - +--------+--------+--------+ | ||||
| // | // | ||||
| // Accumulator | // Accumulator | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| static void kern_4x8(const uint8_t* packA, const uint8_t* packB, int K, | static void kern_4x8(const uint8_t* packA, const uint8_t* packB, int K, | ||||
| int32_t* output, int LDC, bool is_first_k, uint8_t zA, | int32_t* output, int LDC, bool is_first_k, uint8_t zA, | ||||
| uint8_t zB, uint32_t zAB, size_t m_remain = 4) { | uint8_t zB, uint32_t zAB, size_t m_remain = 4) { | ||||
| @@ -257,6 +257,7 @@ static void kern_4x8(const uint8_t* packA, const uint8_t* packB, int K, | |||||
| // +-------+-------+ - - - - +--------+--------+--------+ | // +-------+-------+ - - - - +--------+--------+--------+ | ||||
| // | // | ||||
| // Accumulator | // Accumulator | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| static void kern_4x4(const uint8_t* packA, const uint8_t* packB, int K, | static void kern_4x4(const uint8_t* packA, const uint8_t* packB, int K, | ||||
| int32_t* output, int LDC, bool is_first_k, uint8_t zA, | int32_t* output, int LDC, bool is_first_k, uint8_t zA, | ||||
| uint8_t zB, uint32_t zAB, size_t m_remain = 4, | uint8_t zB, uint32_t zAB, size_t m_remain = 4, | ||||
| @@ -88,7 +88,7 @@ void gemm_u8_4x8::kern(const dt_uint8* packA, const dt_uint8* packB, size_t M, | |||||
| } | } | ||||
| } | } | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| // ===========================gemm_dot_quint8_4x8====================================== | // ===========================gemm_dot_quint8_4x8====================================== | ||||
| MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_dot_quint8_4x8); | MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_dot_quint8_4x8); | ||||
| void gemm_dot_quint8_4x8::pack_A(dt_uint8* out, const dt_uint8* in, int ldin, | void gemm_dot_quint8_4x8::pack_A(dt_uint8* out, const dt_uint8* in, int ldin, | ||||
| @@ -17,7 +17,7 @@ namespace matmul { | |||||
| MEGDNN_REG_GEMM_STRATEGY(dt_uint8, dt_int32, dt_int32, 4, 8, 8, false, true, | MEGDNN_REG_GEMM_STRATEGY(dt_uint8, dt_int32, dt_int32, 4, 8, 8, false, true, | ||||
| gemm_u8_4x8); | gemm_u8_4x8); | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| MEGDNN_REG_GEMM_STRATEGY(dt_uint8, dt_int32, dt_int32, 4, 8, 4, false, false, | MEGDNN_REG_GEMM_STRATEGY(dt_uint8, dt_int32, dt_int32, 4, 8, 4, false, false, | ||||
| gemm_dot_quint8_4x8); | gemm_dot_quint8_4x8); | ||||
| #endif | #endif | ||||
| @@ -60,6 +60,13 @@ | |||||
| #include <windows.h> | #include <windows.h> | ||||
| #endif | #endif | ||||
| #if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||||
| #if MGB_ENABLE_CPUINFO | |||||
| #include "cpuinfo.h" | |||||
| #endif | |||||
| #endif | |||||
| #if __cplusplus >= 201703L || __clang_major__ >= 4 | #if __cplusplus >= 201703L || __clang_major__ >= 4 | ||||
| #define MEGDNN_FALLTHRU [[fallthrough]]; | #define MEGDNN_FALLTHRU [[fallthrough]]; | ||||
| #elif __GNUC__ >= 7 | #elif __GNUC__ >= 7 | ||||
| @@ -148,7 +148,7 @@ struct GemvLike<stype, btype, param::ConvBias::Format::NCHW44> { | |||||
| } | } | ||||
| }; | }; | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| template <typename stype, typename btype> | template <typename stype, typename btype> | ||||
| struct GemvLike<stype, btype, param::ConvBias::Format::NCHW44_DOT> { | struct GemvLike<stype, btype, param::ConvBias::Format::NCHW44_DOT> { | ||||
| inline static void do_gemv(const stype* A, const stype* B, btype* C, | inline static void do_gemv(const stype* A, const stype* B, btype* C, | ||||
| @@ -87,7 +87,7 @@ TEST_F(AARCH64, MATRIX_MUL_F16_MK8) { | |||||
| } | } | ||||
| #endif | #endif | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| TEST_F(AARCH64, MATRIX_MUL_INT8X8X32_K8X12X4_DOTPROD) { | TEST_F(AARCH64, MATRIX_MUL_INT8X8X32_K8X12X4_DOTPROD) { | ||||
| matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{}, | matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{}, | ||||
| handle(), "AARCH64_INT8X8X32_K8X12X4_DOTPROD"); | handle(), "AARCH64_INT8X8X32_K8X12X4_DOTPROD"); | ||||
| @@ -690,7 +690,7 @@ TEST_F(AARCH64, BENCHMARK_GEMV) { | |||||
| run(M, K, N); | run(M, K, N); | ||||
| } | } | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| TEST_F(AARCH64, BENCHMARK_TRANSPOSED_MATRIX_MUL_INT_8X8X32) { | TEST_F(AARCH64, BENCHMARK_TRANSPOSED_MATRIX_MUL_INT_8X8X32) { | ||||
| constexpr size_t RUNS = 50; | constexpr size_t RUNS = 50; | ||||
| param::MatrixMul param; | param::MatrixMul param; | ||||
| @@ -803,7 +803,7 @@ TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_INT8X8X32_MK4_8X12X4) { | |||||
| std::cout << std::endl; | std::cout << std::endl; | ||||
| } | } | ||||
| } | } | ||||
| #endif // __ARM_FEATURE_DOTPROD | |||||
| #endif // MGB_ENABLE_DOT | |||||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
| TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_F16_MK8) { | TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_F16_MK8) { | ||||
| @@ -166,7 +166,7 @@ static void benchmark_convbias(Handle* handle, std::string int_name, | |||||
| .set_display(false); | .set_display(false); | ||||
| } | } | ||||
| auto nchw44_algo_regx = ".*(DIRECT|NCHW_NCHW44).*"; | auto nchw44_algo_regx = ".*(DIRECT|NCHW_NCHW44).*"; | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| if (!is_fp32) { | if (!is_fp32) { | ||||
| nchw44_algo_regx = ".*DOT.*"; | nchw44_algo_regx = ".*DOT.*"; | ||||
| } | } | ||||
| @@ -1852,7 +1852,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_QINT8_STRIDE1_NCHW44) { | |||||
| #endif | #endif | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| #if MEGDNN_WITH_BENCHMARK | #if MEGDNN_WITH_BENCHMARK | ||||
| TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_INT8_STRIDE1_WITHDOTPROD) { | TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_INT8_STRIDE1_WITHDOTPROD) { | ||||
| // have to remove preferred restrict in usable func before run the benchmark | // have to remove preferred restrict in usable func before run the benchmark | ||||
| @@ -2440,7 +2440,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_CONV1X1_S1_QUANTIZEDSYM) { | |||||
| dtype::QuantizedS8 stype(2.5f); | dtype::QuantizedS8 stype(2.5f); | ||||
| dtype::QuantizedS32 dtype(6.25f); | dtype::QuantizedS32 dtype(6.25f); | ||||
| #if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| benchmark_conv1x1("AARCH64_INT8X8X32_K8X12X4_DOTPROD", handle(), stype, | benchmark_conv1x1("AARCH64_INT8X8X32_K8X12X4_DOTPROD", handle(), stype, | ||||
| dtype, dtype, dtype); | dtype, dtype, dtype); | ||||
| #else | #else | ||||
| @@ -2460,7 +2460,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_CONV1X1_S1_QUANTIZEDASYM) { | |||||
| dtype::QuantizedS32 dtype(1.2 * 1.2); | dtype::QuantizedS32 dtype(1.2 * 1.2); | ||||
| #if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| benchmark_conv1x1("AARCH64_QUINT8_K8X8X4_DOTPROD", handle(), stype, dtype, | benchmark_conv1x1("AARCH64_QUINT8_K8X8X4_DOTPROD", handle(), stype, dtype, | ||||
| dtype, dtype); | dtype, dtype); | ||||
| #else | #else | ||||
| @@ -2565,7 +2565,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_CONV1X1_GEMV_FP32) { | |||||
| } | } | ||||
| } | } | ||||
| #ifndef __ARM_FEATURE_DOTPROD | |||||
| //! enable none dot algo now | |||||
| TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_1X1_S1_NCHW_VS_NCHW44_INT8x8x32) { | TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_1X1_S1_NCHW_VS_NCHW44_INT8x8x32) { | ||||
| std::vector<TestArg> conv_bias_1x1_args_nchw44 = | std::vector<TestArg> conv_bias_1x1_args_nchw44 = | ||||
| get_conv_bias_1x1_benchmark_args(4); | get_conv_bias_1x1_benchmark_args(4); | ||||
| @@ -2634,7 +2634,6 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_1X1_S1_NCHW_VS_NCHW44_INT8x8x32) { | |||||
| computations / conv1x1_nchw44, conv1x1_nchw / conv1x1_nchw44); | computations / conv1x1_nchw44, conv1x1_nchw / conv1x1_nchw44); | ||||
| } | } | ||||
| } | } | ||||
| #endif | |||||
| TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_WINOGRAD_VS_IM2COL_INT8) { | TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_WINOGRAD_VS_IM2COL_INT8) { | ||||
| auto&& args = get_winograd_benchmark_args(3, 8); | auto&& args = get_winograd_benchmark_args(3, 8); | ||||
| @@ -500,7 +500,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2) { | |||||
| } | } | ||||
| /****************************dot qint8 direct*************************/ | /****************************dot qint8 direct*************************/ | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_NCHW_NCHW44) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_NCHW_NCHW44) { | ||||
| auto args = get_nchw44_conv_bias_args({2, 3, 5, 7}, QUAN_NLMODE, | auto args = get_nchw44_conv_bias_args({2, 3, 5, 7}, QUAN_NLMODE, | ||||
| BR_AND_NO_BIASMODE, 2, false, true); | BR_AND_NO_BIASMODE, 2, false, true); | ||||
| @@ -655,7 +655,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_INT8_NCHW44) { | |||||
| bench_case(1, 512, 256, 28, 28, 3, 4, 1, 2); | bench_case(1, 512, 256, 28, 28, 3, 4, 1, 2); | ||||
| } | } | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_INT8_NCHW44_DOT) { | TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_INT8_NCHW44_DOT) { | ||||
| constexpr size_t RUNS = 40; | constexpr size_t RUNS = 40; | ||||
| std::vector<DType> data_type = { | std::vector<DType> data_type = { | ||||
| @@ -892,7 +892,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | |||||
| benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, | benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, | ||||
| {1, {4}}, data_type); | {1, {4}}, data_type); | ||||
| } | } | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | ||||
| BENCHMARK_CONVBIAS_INT8_INT8_INT8_STRIDE1_WITHDOTPROD) { | BENCHMARK_CONVBIAS_INT8_INT8_INT8_STRIDE1_WITHDOTPROD) { | ||||
| constexpr size_t RUNS = 50; | constexpr size_t RUNS = 50; | ||||
| @@ -1157,7 +1157,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | |||||
| benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, | benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, | ||||
| {1, {4}}, data_type); | {1, {4}}, data_type); | ||||
| } | } | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | ||||
| BENCHMARK_CONVBIAS_QUINT8_QUINT8_QUINT8_STRIDE1_WITHDOTPROD) { | BENCHMARK_CONVBIAS_QUINT8_QUINT8_QUINT8_STRIDE1_WITHDOTPROD) { | ||||
| constexpr size_t RUNS = 50; | constexpr size_t RUNS = 50; | ||||
| @@ -1977,7 +1977,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | |||||
| dtype::QuantizedS32 btype(0.04f); | dtype::QuantizedS32 btype(0.04f); | ||||
| dtype::Quantized8Asymm dtype(1.4f, 110); | dtype::Quantized8Asymm dtype(1.4f, 110); | ||||
| #if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| conv1x1_multithread_benchmark("CONV1x1:AARCH64_QUINT8_K8X8X4_DOTPROD:8", | conv1x1_multithread_benchmark("CONV1x1:AARCH64_QUINT8_K8X8X4_DOTPROD:8", | ||||
| stype, ftype, btype, dtype); | stype, ftype, btype, dtype); | ||||
| #else | #else | ||||
| @@ -20,7 +20,7 @@ using namespace megdnn; | |||||
| using namespace test; | using namespace test; | ||||
| using namespace conv_bias; | using namespace conv_bias; | ||||
| #ifdef __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CONV1x1_QUANTIZEDSYM_MK4_DOT) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CONV1x1_QUANTIZEDSYM_MK4_DOT) { | ||||
| UniformIntRNG rng{-50, 50}; | UniformIntRNG rng{-50, 50}; | ||||
| @@ -138,7 +138,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUANTIZEDSYM) { | |||||
| dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \ | dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \ | ||||
| dtype::QuantizedS8(60.25f), name); | dtype::QuantizedS8(60.25f), name); | ||||
| #if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| cb("CONV1x1:AARCH64_INT8X8X32_K8X12X4_DOTPROD:24"); | cb("CONV1x1:AARCH64_INT8X8X32_K8X12X4_DOTPROD:24"); | ||||
| #else | #else | ||||
| cb("CONV1x1:AARCH64_INT8X8X32_K8X8X8:24"); | cb("CONV1x1:AARCH64_INT8X8X32_K8X8X8:24"); | ||||
| @@ -174,7 +174,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUANTIZEDASYM) { | |||||
| name); | name); | ||||
| float epsilon = 0.001; | float epsilon = 0.001; | ||||
| #if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| cb("CONV1x1:AARCH64_QUINT8_K8X8X4_DOTPROD:48"); | cb("CONV1x1:AARCH64_QUINT8_K8X8X4_DOTPROD:48"); | ||||
| #else | #else | ||||
| cb("CONV1x1:AARCH64_QUINT8_K8X8X8:24"); | cb("CONV1x1:AARCH64_QUINT8_K8X8X8:24"); | ||||
| @@ -210,13 +210,13 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUINT8x8x32) { | |||||
| dtype::QuantizedS32(1.2 * 1.3), {}, name); | dtype::QuantizedS32(1.2 * 1.3), {}, name); | ||||
| #if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| cb("CONV1x1:AARCH64_QUINT8_K8X8X4_DOTPROD:24"); | cb("CONV1x1:AARCH64_QUINT8_K8X8X4_DOTPROD:24"); | ||||
| #else | #else | ||||
| cb("CONV1x1:AARCH64_QUINT8_K8X8X8:48"); | cb("CONV1x1:AARCH64_QUINT8_K8X8X8:48"); | ||||
| #endif | #endif | ||||
| #elif MEGDNN_ARMV7 | #elif MEGDNN_ARMV7 | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| cb("CONV1x1:AARCH32_QUINT8_K4X8X4:48"); | cb("CONV1x1:AARCH32_QUINT8_K4X8X4:48"); | ||||
| #endif | #endif | ||||
| cb("CONV1x1:ARMV7_QUINT8_K4X8X8:24"); | cb("CONV1x1:ARMV7_QUINT8_K4X8X8:24"); | ||||
| @@ -287,14 +287,14 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32) { | |||||
| #define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name); | #define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name); | ||||
| #if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| cb("CONV1x1:AARCH64_INT8X8X32_K8X12X4_DOTPROD:48"); | cb("CONV1x1:AARCH64_INT8X8X32_K8X12X4_DOTPROD:48"); | ||||
| #else | #else | ||||
| cb("CONV1x1:AARCH64_INT8X8X32_K8X8X8:24"); | cb("CONV1x1:AARCH64_INT8X8X32_K8X8X8:24"); | ||||
| cb("CONV1x1:AARCH64_INT8X8X32_K4X4X16:24"); | cb("CONV1x1:AARCH64_INT8X8X32_K4X4X16:24"); | ||||
| #endif | #endif | ||||
| #elif MEGDNN_ARMV7 | #elif MEGDNN_ARMV7 | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| cb("CONV1x1:AARCH32_INT8_K6X8X4:48"); | cb("CONV1x1:AARCH32_INT8_K6X8X4:48"); | ||||
| #endif | #endif | ||||
| cb("CONV1x1:ARMV7_INT8X8X32_K4X8X8:24"); | cb("CONV1x1:ARMV7_INT8X8X32_K4X8X8:24"); | ||||
| @@ -312,8 +312,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32) { | |||||
| } | } | ||||
| checker_conv_bias_mul_int8x8x32(gemv_args, handle(), "CONV1x1_GEMV"); | checker_conv_bias_mul_int8x8x32(gemv_args, handle(), "CONV1x1_GEMV"); | ||||
| } | } | ||||
| #ifndef __ARM_FEATURE_DOTPROD | |||||
| //! enable none dot algo now | |||||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_MK4) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_MK4) { | ||||
| using namespace conv_bias; | using namespace conv_bias; | ||||
| std::vector<conv_bias::TestArg> args = | std::vector<conv_bias::TestArg> args = | ||||
| @@ -345,7 +344,6 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_MK4) { | |||||
| #endif | #endif | ||||
| #undef cb | #undef cb | ||||
| } | } | ||||
| #endif | |||||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_NCHW44) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_NCHW44) { | ||||
| using namespace conv_bias; | using namespace conv_bias; | ||||
| @@ -364,7 +362,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_NCHW44) { | |||||
| "CONV1x1_GEMV"); | "CONV1x1_GEMV"); | ||||
| } | } | ||||
| #ifdef __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_NCHW44_DOT) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_NCHW44_DOT) { | ||||
| using namespace conv_bias; | using namespace conv_bias; | ||||
| std::vector<conv_bias::TestArg> args = get_nchw44_conv_bias_args( | std::vector<conv_bias::TestArg> args = get_nchw44_conv_bias_args( | ||||
| @@ -135,7 +135,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM) { | |||||
| float epsilon = 0.001; | float epsilon = 0.001; | ||||
| #if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| cb("IM2COLMATMUL:AARCH64_INT8X8X32_K8X12X4_DOTPROD"); | cb("IM2COLMATMUL:AARCH64_INT8X8X32_K8X12X4_DOTPROD"); | ||||
| #else | #else | ||||
| cb("IM2COLMATMUL:AARCH64_INT8X8X32_K8X8X8"); | cb("IM2COLMATMUL:AARCH64_INT8X8X32_K8X8X8"); | ||||
| @@ -148,7 +148,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM) { | |||||
| #undef cb | #undef cb | ||||
| } | } | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_MK4_DOT) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_MK4_DOT) { | ||||
| UniformIntRNG rng{-50, 50}; | UniformIntRNG rng{-50, 50}; | ||||
| @@ -173,6 +173,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_MK4_DOT) { | |||||
| #if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
| cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD:96"); | cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD:96"); | ||||
| #elif MEGDNN_ARMV7 | #elif MEGDNN_ARMV7 | ||||
| epsilon = 1; | |||||
| cb("IM2COLMATMUL:AARCH32_INT8_MK4_8X4X4_DOTPROD:96"); | cb("IM2COLMATMUL:AARCH32_INT8_MK4_8X4X4_DOTPROD:96"); | ||||
| #endif | #endif | ||||
| #undef cb | #undef cb | ||||
| @@ -194,6 +195,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, | |||||
| #if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
| cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD:96"); | cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD:96"); | ||||
| #elif MEGDNN_ARMV7 | #elif MEGDNN_ARMV7 | ||||
| epsilon = 1; | |||||
| cb("IM2COLMATMUL:AARCH32_INT8_MK4_8X4X4_DOTPROD:96"); | cb("IM2COLMATMUL:AARCH32_INT8_MK4_8X4X4_DOTPROD:96"); | ||||
| #endif | #endif | ||||
| #undef cb | #undef cb | ||||
| @@ -273,7 +275,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDASYM) { | |||||
| dtype::Quantized8Asymm(50.3f, (uint8_t)120), name); | dtype::Quantized8Asymm(50.3f, (uint8_t)120), name); | ||||
| float epsilon = 0.001; | float epsilon = 0.001; | ||||
| #if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| cb("IM2COLMATMUL:AARCH64_QUINT8_K8X8X4_DOTPROD"); | cb("IM2COLMATMUL:AARCH64_QUINT8_K8X8X4_DOTPROD"); | ||||
| #else | #else | ||||
| cb("IM2COLMATMUL:AARCH64_QUINT8_K8X8X8"); | cb("IM2COLMATMUL:AARCH64_QUINT8_K8X8X8"); | ||||
| @@ -305,13 +307,13 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUINT8x8x32) { | |||||
| dtype::QuantizedS32(1.2 * 1.3), {}, name); | dtype::QuantizedS32(1.2 * 1.3), {}, name); | ||||
| #if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| cb("IM2COLMATMUL:AARCH64_QUINT8_K8X8X4_DOTPROD"); | cb("IM2COLMATMUL:AARCH64_QUINT8_K8X8X4_DOTPROD"); | ||||
| #else | #else | ||||
| cb("IM2COLMATMUL:AARCH64_QUINT8_K8X8X8"); | cb("IM2COLMATMUL:AARCH64_QUINT8_K8X8X8"); | ||||
| #endif | #endif | ||||
| #elif MEGDNN_ARMV7 | #elif MEGDNN_ARMV7 | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| cb("IM2COLMATMUL:AARCH32_QUINT8_K4X8X4"); | cb("IM2COLMATMUL:AARCH32_QUINT8_K4X8X4"); | ||||
| #endif | #endif | ||||
| cb("IM2COLMATMUL:ARMV7_QUINT8_K4X8X8"); | cb("IM2COLMATMUL:ARMV7_QUINT8_K4X8X8"); | ||||
| @@ -392,7 +394,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_FP16) { | |||||
| #endif | #endif | ||||
| #if MEGDNN_AARCH64 || MEGDNN_ARMV7 | #if MEGDNN_AARCH64 || MEGDNN_ARMV7 | ||||
| #if !__ARM_FEATURE_DOTPROD | |||||
| //! enable none dot algo now | |||||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S2) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S2) { | ||||
| using namespace conv_bias; | using namespace conv_bias; | ||||
| std::vector<conv_bias::TestArg> args = get_nchw44_conv_bias_args( | std::vector<conv_bias::TestArg> args = get_nchw44_conv_bias_args( | ||||
| @@ -481,12 +483,11 @@ TEST_F(ARM_COMMON_MULTI_THREADS, | |||||
| #undef cb | #undef cb | ||||
| } | } | ||||
| #endif | |||||
| #endif | #endif | ||||
| #endif | #endif | ||||
| #if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| TEST_F(ARM_COMMON_MULTI_THREADS, | TEST_F(ARM_COMMON_MULTI_THREADS, | ||||
| CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_NCHW44DOT_FUSE) { | CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_NCHW44DOT_FUSE) { | ||||
| UniformIntRNG rng{-50, 50}; | UniformIntRNG rng{-50, 50}; | ||||
| @@ -516,14 +517,14 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32) { | |||||
| #define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name); | #define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name); | ||||
| #if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| cb("IM2COLMATMUL:AARCH64_INT8X8X32_K8X12X4_DOTPROD"); | cb("IM2COLMATMUL:AARCH64_INT8X8X32_K8X12X4_DOTPROD"); | ||||
| #else | #else | ||||
| cb("IM2COLMATMUL:AARCH64_INT8X8X32_K8X8X8"); | cb("IM2COLMATMUL:AARCH64_INT8X8X32_K8X8X8"); | ||||
| cb("IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16"); | cb("IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16"); | ||||
| #endif | #endif | ||||
| #elif MEGDNN_ARMV7 | #elif MEGDNN_ARMV7 | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #if MGB_ENABLE_DOT | |||||
| cb("IM2COLMATMUL:AARCH32_INT8_K6X8X4"); | cb("IM2COLMATMUL:AARCH32_INT8_K6X8X4"); | ||||
| #endif | #endif | ||||
| cb("IM2COLMATMUL:ARMV7_INT8X8X32_K4X8X8"); | cb("IM2COLMATMUL:ARMV7_INT8X8X32_K4X8X8"); | ||||