GitOrigin-RevId: 87046b8197
tags/v1.10.0
| @@ -12,7 +12,7 @@ | |||
| #include "src/aarch64/conv_bias/int8/algos.h" | |||
| #include "src/aarch64/conv_bias/int8/strategy.h" | |||
| #include "src/arm_common/convolution/img2col_helper.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
| #include "src/common/opr_delegate.h" | |||
| #include "src/fallback/conv_bias/common.h" | |||
| #include "src/fallback/matrix_mul/gemm_impl.h" | |||
| @@ -14,7 +14,7 @@ | |||
| #include "src/aarch64/matrix_mul/quint8_dot/gemv.h" | |||
| #include "src/aarch64/matrix_mul/quint8_dot/strategy.h" | |||
| #include "src/arm_common/convolution/img2col_helper.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
| #include "src/common/opr_delegate.h" | |||
| #include "src/fallback/conv_bias/common.h" | |||
| #include "src/fallback/matrix_mul/gemm_impl.h" | |||
| @@ -11,7 +11,7 @@ | |||
| */ | |||
| #include "src/arm_common/conv_bias/f16/channel_wise_3x3_s1p1_nchw88_kern.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
| #include "src/arm_common/simd_macro/marm_neon.h" | |||
| #include "src/arm_common/utils.h" | |||
| #include "src/common/utils.h" | |||
| @@ -12,7 +12,7 @@ | |||
| #include "src/arm_common/conv_bias/f16/algos.h" | |||
| #include "src/arm_common/conv_bias/f16/channel_wise_nchw88_kern.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
| #include "midout.h" | |||
| @@ -12,7 +12,7 @@ | |||
| #include "src/arm_common/conv_bias/f16/channel_wise_nchw88_kern.h" | |||
| #include "src/arm_common/conv_bias/f16/channel_wise_3x3_s1p1_nchw88_kern.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
| #include "src/arm_common/simd_macro/marm_neon.h" | |||
| #include "src/arm_common/utils.h" | |||
| #include "src/common/unroll_macro.h" | |||
| @@ -15,7 +15,7 @@ | |||
| #include "src/arm_common/conv_bias/f16/algos.h" | |||
| #include "src/arm_common/conv_bias/f16/direct_nchw88_kern.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
| #include "midout.h" | |||
| @@ -13,7 +13,7 @@ | |||
| #include "src/arm_common/conv_bias/f16/direct_nchw88_kern.h" | |||
| #include "src/arm_common/conv_bias/opr_impl.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
| #include "src/arm_common/simd_macro/marm_neon.h" | |||
| #include "src/fallback/conv_bias/common.h" | |||
| @@ -11,7 +11,7 @@ | |||
| */ | |||
| #include "src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
| #include "src/arm_common/simd_macro/marm_neon.h" | |||
| #include "src/arm_common/utils.h" | |||
| #include "src/common/utils.h" | |||
| @@ -11,7 +11,7 @@ | |||
| */ | |||
| #include "src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
| #include "src/arm_common/simd_macro/marm_neon.h" | |||
| #include "src/arm_common/utils.h" | |||
| #include "src/common/utils.h" | |||
| @@ -12,7 +12,7 @@ | |||
| #include "src/arm_common/conv_bias/fp32/algos.h" | |||
| #include "src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
| #include "midout.h" | |||
| @@ -13,7 +13,7 @@ | |||
| #include "src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.h" | |||
| #include "src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.h" | |||
| #include "src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
| #include "src/arm_common/simd_macro/marm_neon.h" | |||
| #include "src/arm_common/utils.h" | |||
| #include "src/common/unroll_macro.h" | |||
| @@ -14,7 +14,7 @@ | |||
| #include "megdnn/arch.h" | |||
| #include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h" | |||
| #include "src/arm_common/conv_bias/intrinsic_helper.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
| #include "src/arm_common/simd_macro/marm_neon.h" | |||
| #include "src/common/unroll_macro.h" | |||
| #include "src/common/utils.h" | |||
| @@ -14,7 +14,7 @@ | |||
| #include "megdnn/arch.h" | |||
| #include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h" | |||
| #include "src/arm_common/conv_bias/intrinsic_helper.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
| #include "src/arm_common/simd_macro/marm_neon.h" | |||
| #include "src/common/unroll_macro.h" | |||
| #include "src/common/utils.h" | |||
| @@ -14,7 +14,7 @@ | |||
| #include "src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h" | |||
| #include "src/arm_common/conv_bias/intrinsic_helper.h" | |||
| #include "src/arm_common/conv_bias/opr_impl.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
| #include "src/arm_common/simd_macro/marm_neon.h" | |||
| #include "src/common/unroll_macro.h" | |||
| #include "src/common/utils.h" | |||
| @@ -15,7 +15,7 @@ | |||
| #include "src/arm_common/conv_bias/fp32/algos.h" | |||
| #include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
| #include "midout.h" | |||
| @@ -15,7 +15,7 @@ | |||
| #include "src/arm_common/conv_bias/fp32/algos.h" | |||
| #include "src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h" | |||
| #include "src/arm_common/conv_bias/fp32/strategy.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
| #include "src/common/nchw_nchwxx_valid.h" | |||
| #include "src/common/opr_delegate.h" | |||
| @@ -13,7 +13,7 @@ | |||
| #include "megdnn/arch.h" | |||
| #include "src/arm_common/conv_bias/intrinsic_helper.h" | |||
| #include "src/arm_common/conv_bias/opr_impl.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
| #include "src/arm_common/simd_macro/marm_neon.h" | |||
| #include "src/common/unroll_macro.h" | |||
| #include "src/common/utils.h" | |||
| @@ -17,7 +17,7 @@ | |||
| #include "src/arm_common/conv_bias/int8/stride1_dotprod.h" | |||
| #include "src/arm_common/conv_bias/int8/stride2.h" | |||
| #include "src/arm_common/conv_bias/int8/stride2_dotprod.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
| #include "src/fallback/conv_bias/common.h" | |||
| #include "midout.h" | |||
| @@ -11,7 +11,7 @@ | |||
| */ | |||
| #include "src/arm_common/conv_bias/int8/channel_wise_kernel.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
| #include "src/arm_common/simd_macro/marm_neon.h" | |||
| #include "src/common/unroll_macro.h" | |||
| #include "src/common/utils.h" | |||
| @@ -12,7 +12,7 @@ | |||
| #include "src/arm_common/conv_bias/int8/channel_wise_nchw44.h" | |||
| #include "megdnn/oprs.h" | |||
| #include "src/arm_common/conv_bias/int8/channel_wise_kernel.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
| #include "src/common/opr_delegate.h" | |||
| #include "midout.h" | |||
| @@ -10,7 +10,7 @@ | |||
| */ | |||
| #include "src/arm_common/conv_bias/int8/direct.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
| #include "src/arm_common/simd_macro/marm_neon.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/fallback/conv_bias/common.h" | |||
| @@ -11,7 +11,7 @@ | |||
| #include "src/arm_common/conv_bias/int8/direct_dotprod.h" | |||
| #if MGB_ENABLE_DOT | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
| #include "src/arm_common/simd_macro/marm_neon.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/fallback/conv_bias/common.h" | |||
| @@ -14,7 +14,7 @@ | |||
| #if MGB_ENABLE_DOT | |||
| #include "src/arm_common/conv_bias/int8/algos.h" | |||
| #include "src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
| #include "midout.h" | |||
| @@ -14,7 +14,7 @@ | |||
| #include "megdnn/arch.h" | |||
| #include "src/arm_common/conv_bias/intrinsic_helper.h" | |||
| #if MGB_ENABLE_DOT | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
| #include "src/arm_common/intrinsic_helper.h" | |||
| #include "src/arm_common/neon_struct.h" | |||
| #include "src/common/unroll_macro.h" | |||
| @@ -13,7 +13,7 @@ | |||
| #include "src/arm_common/conv_bias/int8/direct.h" | |||
| #include "src/arm_common/conv_bias/int8/direct_nchw44_kern.h" | |||
| #include "src/arm_common/conv_bias/intrinsic_helper.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
| #include "src/arm_common/simd_macro/marm_neon.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/fallback/conv_bias/common.h" | |||
| @@ -14,7 +14,7 @@ | |||
| #include "src/arm_common/conv_bias/int8/direct.h" | |||
| #include "src/arm_common/conv_bias/int8/direct_nchw44_kern.h" | |||
| #include "src/arm_common/conv_bias/intrinsic_helper.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
| #include "src/arm_common/simd_macro/marm_neon.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/fallback/conv_bias/common.h" | |||
| @@ -14,7 +14,7 @@ | |||
| #include "src/arm_common/conv_bias/int8/algos.h" | |||
| #include "src/arm_common/conv_bias/int8/direct.h" | |||
| #include "src/arm_common/conv_bias/int8/direct_nchw44_kern.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
| #include "src/common/opr_delegate.h" | |||
| #include "midout.h" | |||
| @@ -12,7 +12,7 @@ | |||
| #pragma once | |||
| #include "src/arm_common/conv_bias/int8/direct.h" | |||
| #include "src/arm_common/conv_bias/intrinsic_helper.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
| #include "src/arm_common/simd_macro/marm_neon.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/fallback/conv_bias/common.h" | |||
| @@ -14,7 +14,7 @@ | |||
| #include "src/arm_common/conv_bias/int8/algos.h" | |||
| #include "src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h" | |||
| #include "src/arm_common/conv_bias/int8/strategy.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
| #include "src/common/nchw_nchwxx_valid.h" | |||
| #include "src/common/opr_delegate.h" | |||
| @@ -13,7 +13,7 @@ | |||
| #include "megdnn/arch.h" | |||
| #include "src/arm_common/conv_bias/intrinsic_helper.h" | |||
| #include "src/arm_common/conv_bias/opr_impl.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
| #include "src/arm_common/simd_macro/marm_neon.h" | |||
| #include "src/common/unroll_macro.h" | |||
| #include "src/common/utils.h" | |||
| @@ -15,7 +15,7 @@ | |||
| #include "src/arm_common/conv_bias/block_helper.h" | |||
| #include "src/arm_common/conv_bias/int8/algos.h" | |||
| #include "src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
| #include "src/common/nchw_nchwxx_valid.h" | |||
| #include "midout.h" | |||
| @@ -14,7 +14,7 @@ | |||
| #include "src/arm_common/conv_bias/intrinsic_helper.h" | |||
| #if MGB_ENABLE_DOT | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
| #include "src/arm_common/simd_macro/marm_neon.h" | |||
| #include "src/common/unroll_macro.h" | |||
| #include "src/common/utils.h" | |||
| @@ -14,7 +14,7 @@ | |||
| #include "megdnn/oprs.h" | |||
| #include "src/arm_common/conv_bias/int8/direct.h" | |||
| #include "src/arm_common/conv_bias/int8/strategy.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
| #include "src/common/opr_delegate.h" | |||
| using namespace megdnn; | |||
| @@ -14,7 +14,7 @@ | |||
| #include "megdnn/oprs.h" | |||
| #include "src/arm_common/conv_bias/int8/direct_dotprod.h" | |||
| #include "src/arm_common/conv_bias/int8/strategy.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
| #include "src/common/opr_delegate.h" | |||
| using namespace megdnn; | |||
| @@ -14,7 +14,7 @@ | |||
| #include "megdnn/oprs.h" | |||
| #include "src/arm_common/conv_bias/int8/direct.h" | |||
| #include "src/arm_common/conv_bias/int8/strategy.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
| #include "src/common/opr_delegate.h" | |||
| using namespace megdnn; | |||
| @@ -14,7 +14,7 @@ | |||
| #include "megdnn/oprs.h" | |||
| #include "src/arm_common/conv_bias/int8/direct_dotprod.h" | |||
| #include "src/arm_common/conv_bias/int8/strategy.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
| #include "src/common/opr_delegate.h" | |||
| using namespace megdnn; | |||
| @@ -11,7 +11,7 @@ | |||
| */ | |||
| #include "src/arm_common/conv_bias/int8x8x16/channel_wise_kernel.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
| #include "src/arm_common/simd_macro/marm_neon.h" | |||
| #include "src/common/unroll_macro.h" | |||
| #include "src/common/utils.h" | |||
| @@ -15,7 +15,7 @@ | |||
| #include "src/arm_common/conv_bias/block_helper.h" | |||
| #include "src/arm_common/conv_bias/int8x8x16/algos.h" | |||
| #include "src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_kern.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
| #include "src/common/nchw_nchwxx_valid.h" | |||
| #include "src/common/opr_delegate.h" | |||
| @@ -13,7 +13,7 @@ | |||
| #include "megdnn/arch.h" | |||
| #include "src/arm_common/conv_bias/intrinsic_helper.h" | |||
| #include "src/arm_common/conv_bias/opr_impl.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
| #include "src/arm_common/simd_macro/marm_neon.h" | |||
| #include "src/common/unroll_macro.h" | |||
| #include "src/common/utils.h" | |||
| @@ -16,7 +16,7 @@ | |||
| #include "src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_kern.h" | |||
| #include "src/arm_common/conv_bias/intrinsic_helper.h" | |||
| #include "src/arm_common/conv_bias/opr_impl.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
| #include "src/arm_common/simd_macro/marm_neon.h" | |||
| #include "src/common/unroll_macro.h" | |||
| #include "src/common/utils.h" | |||
| @@ -12,7 +12,7 @@ | |||
| #include "megdnn/dtype.h" | |||
| #include "megdnn/oprs.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
| #include "src/arm_common/simd_macro/marm_neon.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/fallback/conv_bias/common.h" | |||
| @@ -13,8 +13,8 @@ | |||
| #pragma once | |||
| #include "megdnn/basic_types.h" | |||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
| #include "src/arm_common/elemwise_helper/kimpl/op_base.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/fallback/conv_bias/opr_impl.h" | |||
| #include "midout.h" | |||
| @@ -44,29 +44,29 @@ namespace { | |||
| break; | |||
| #define FOR_NONLINEAR_UNARY(_op) \ | |||
| megdnn::arm_common::OpCallerUnary<_op<ctype>, megdnn::VEC>::run( \ | |||
| megdnn::elemwise::OpCallerUnary<_op<ctype>, megdnn::elemwise::VEC>::run( \ | |||
| static_cast<ctype*>(conv_dst_ptr), reinterpret_cast<ctype*>(dst_ptr), \ | |||
| bias_type, dst_type, N* OC* OH* OW* pack_oc_size); | |||
| #define FOR_NONLINEAR_BINARY_BROADCAST(_op) \ | |||
| megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::VEC_BCAST101>::run( \ | |||
| static_cast<ctype*>(conv_dst_ptr), \ | |||
| reinterpret_cast<const ctype*>(bias_ptr), \ | |||
| reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, OC, \ | |||
| #define FOR_NONLINEAR_BINARY_BROADCAST(_op) \ | |||
| megdnn::elemwise::OpCallerBinary<_op<ctype>, megdnn::elemwise::VEC_BCAST101>::run( \ | |||
| static_cast<ctype*>(conv_dst_ptr), \ | |||
| reinterpret_cast<const ctype*>(bias_ptr), \ | |||
| reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, OC, \ | |||
| OH* OW); | |||
| #define FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX(_op) \ | |||
| megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::VEC_BCAST101xX>::run( \ | |||
| static_cast<ctype*>(conv_dst_ptr), \ | |||
| reinterpret_cast<const ctype*>(bias_ptr), \ | |||
| reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, OC, \ | |||
| OH* OW, pack_oc_size); | |||
| #define FOR_NONLINEAR_BINARY(_op) \ | |||
| megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::VEC_VEC>::run( \ | |||
| static_cast<ctype*>(conv_dst_ptr), \ | |||
| reinterpret_cast<const ctype*>(bias_ptr), \ | |||
| reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, \ | |||
| megdnn::elemwise::OpCallerBinary<_op<ctype>, megdnn::elemwise::VEC_BCAST101xX>:: \ | |||
| run(static_cast<ctype*>(conv_dst_ptr), \ | |||
| reinterpret_cast<const ctype*>(bias_ptr), \ | |||
| reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, \ | |||
| OC, OH* OW, pack_oc_size); | |||
| #define FOR_NONLINEAR_BINARY(_op) \ | |||
| megdnn::elemwise::OpCallerBinary<_op<ctype>, megdnn::elemwise::VEC_VEC>::run( \ | |||
| static_cast<ctype*>(conv_dst_ptr), \ | |||
| reinterpret_cast<const ctype*>(bias_ptr), \ | |||
| reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, \ | |||
| N* OC* OH* OW* pack_oc_size); | |||
| #define FOR_BIAS(_mode) \ | |||
| @@ -167,33 +167,35 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> { | |||
| #undef FOR_BIAS | |||
| #undef HANDLE_IDENTITY | |||
| #define FOR_NONLINEAR_UNARY(_op) \ | |||
| megdnn::arm_common::OpCallerUnary<_op<opctype, opdtype>, megdnn::VEC>::run( \ | |||
| static_cast<opctype*>(conv_dst_ptr), reinterpret_cast<opdtype*>(dst_ptr), \ | |||
| bias_type, dst_type, N* OC* OH* OW* pack_oc_size); | |||
| #define FOR_NONLINEAR_BINARY_BROADCAST(_op) \ | |||
| megdnn::arm_common::OpCallerBinary<_op<opctype, opdtype>, megdnn::VEC_BCAST101>:: \ | |||
| run(static_cast<opctype*>(conv_dst_ptr), \ | |||
| reinterpret_cast<const opctype*>(bias_ptr), \ | |||
| reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, dst_type, \ | |||
| #define FOR_NONLINEAR_UNARY(_op) \ | |||
| megdnn::elemwise::OpCallerUnary<_op<opctype, opdtype>, megdnn::elemwise::VEC>:: \ | |||
| run(static_cast<opctype*>(conv_dst_ptr), \ | |||
| reinterpret_cast<opdtype*>(dst_ptr), bias_type, dst_type, \ | |||
| N* OC* OH* OW* pack_oc_size); | |||
| #define FOR_NONLINEAR_BINARY_BROADCAST(_op) \ | |||
| megdnn::elemwise::OpCallerBinary< \ | |||
| _op<opctype, opdtype>, megdnn::elemwise::VEC_BCAST101>:: \ | |||
| run(static_cast<opctype*>(conv_dst_ptr), \ | |||
| reinterpret_cast<const opctype*>(bias_ptr), \ | |||
| reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, dst_type, \ | |||
| N, OC, OH* OW); | |||
| #define FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX(_op) \ | |||
| megdnn::arm_common:: \ | |||
| OpCallerBinary<_op<opctype, opdtype>, megdnn::VEC_BCAST101xX>::run( \ | |||
| static_cast<opctype*>(conv_dst_ptr), \ | |||
| reinterpret_cast<const opctype*>(bias_ptr), \ | |||
| reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, \ | |||
| dst_type, N, OC, OH* OW, pack_oc_size); | |||
| #define FOR_NONLINEAR_BINARY_BROADCAST_NCHW88(_op) \ | |||
| megdnn::arm_common:: \ | |||
| OpCallerBinary<_op<opctype, opdtype>, megdnn::VEC_BCAST101xX>::run( \ | |||
| static_cast<opctype*>(conv_dst_ptr), \ | |||
| reinterpret_cast<const opctype*>(bias_ptr), \ | |||
| reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, \ | |||
| dst_type, N, OC, OH* OW, pack_oc_size); | |||
| #define FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX(_op) \ | |||
| megdnn::elemwise::OpCallerBinary< \ | |||
| _op<opctype, opdtype>, megdnn::elemwise::VEC_BCAST101xX>:: \ | |||
| run(static_cast<opctype*>(conv_dst_ptr), \ | |||
| reinterpret_cast<const opctype*>(bias_ptr), \ | |||
| reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, dst_type, \ | |||
| N, OC, OH* OW, pack_oc_size); | |||
| #define FOR_NONLINEAR_BINARY_BROADCAST_NCHW88(_op) \ | |||
| megdnn::elemwise::OpCallerBinary< \ | |||
| _op<opctype, opdtype>, megdnn::elemwise::VEC_BCAST101xX>:: \ | |||
| run(static_cast<opctype*>(conv_dst_ptr), \ | |||
| reinterpret_cast<const opctype*>(bias_ptr), \ | |||
| reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, dst_type, \ | |||
| N, OC, OH* OW, pack_oc_size); | |||
| #define HANDLE_IDENTITY(_caller, _op) \ | |||
| case megdnn::NonlineMode::IDENTITY: \ | |||
| @@ -267,25 +269,25 @@ struct PostProcess<opctype, opdtype, megdnn::PostprocessMode::QUANTIZED> { | |||
| #undef FOR_NONLINEAR | |||
| #undef FOR_BIAS | |||
| #define FOR_BINARY_BROADCAST(_op) \ | |||
| megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::VEC_BCAST101>::run( \ | |||
| static_cast<ctype*>(conv_dst_ptr), \ | |||
| reinterpret_cast<const ctype*>(bias_ptr), \ | |||
| reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, OC, \ | |||
| #define FOR_BINARY_BROADCAST(_op) \ | |||
| megdnn::elemwise::OpCallerBinary<_op<ctype>, megdnn::elemwise::VEC_BCAST101>::run( \ | |||
| static_cast<ctype*>(conv_dst_ptr), \ | |||
| reinterpret_cast<const ctype*>(bias_ptr), \ | |||
| reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, OC, \ | |||
| OH* OW); | |||
| #define FOR_BINARY_BROADCAST_NCHWXX(_op) \ | |||
| megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::VEC_BCAST101xX>::run( \ | |||
| static_cast<ctype*>(conv_dst_ptr), \ | |||
| reinterpret_cast<const ctype*>(bias_ptr), \ | |||
| reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, OC, \ | |||
| OH* OW, pack_oc_size); | |||
| #define FOR_BINARY(_op) \ | |||
| megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::VEC_VEC>::run( \ | |||
| static_cast<ctype*>(conv_dst_ptr), \ | |||
| reinterpret_cast<const ctype*>(bias_ptr), \ | |||
| reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, \ | |||
| megdnn::elemwise::OpCallerBinary<_op<ctype>, megdnn::elemwise::VEC_BCAST101xX>:: \ | |||
| run(static_cast<ctype*>(conv_dst_ptr), \ | |||
| reinterpret_cast<const ctype*>(bias_ptr), \ | |||
| reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, \ | |||
| OC, OH* OW, pack_oc_size); | |||
| #define FOR_BINARY(_op) \ | |||
| megdnn::elemwise::OpCallerBinary<_op<ctype>, megdnn::elemwise::VEC_VEC>::run( \ | |||
| static_cast<ctype*>(conv_dst_ptr), \ | |||
| reinterpret_cast<const ctype*>(bias_ptr), \ | |||
| reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, \ | |||
| N* OC* OH* OW* pack_oc_size); | |||
| #define FOR_BIAS(_bias_mode, OH, OW) \ | |||
| @@ -15,7 +15,7 @@ | |||
| #include "src/arm_common/conv_bias/quint8/stride1_dotprod.h" | |||
| #include "src/arm_common/conv_bias/quint8/stride2.h" | |||
| #include "src/arm_common/conv_bias/quint8/stride2_dotprod.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
| #include "src/fallback/conv_bias/common.h" | |||
| #include "midout.h" | |||
| @@ -10,7 +10,7 @@ | |||
| */ | |||
| #include "src/arm_common/conv_bias/quint8/direct.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
| #include "src/arm_common/simd_macro/marm_neon.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/fallback/conv_bias/common.h" | |||
| @@ -11,7 +11,7 @@ | |||
| #include "src/arm_common/conv_bias/quint8/direct_dotprod.h" | |||
| #if MGB_ENABLE_DOT | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
| #include "src/arm_common/simd_macro/marm_neon.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/fallback/conv_bias/common.h" | |||
| @@ -12,7 +12,7 @@ | |||
| #include "src/arm_common/conv_bias/quint8/stride1.h" | |||
| #include "megdnn/oprs.h" | |||
| #include "src/arm_common/conv_bias/quint8/direct.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
| #include "src/common/opr_delegate.h" | |||
| using namespace megdnn; | |||
| @@ -12,7 +12,7 @@ | |||
| #if MGB_ENABLE_DOT | |||
| #include "megdnn/oprs.h" | |||
| #include "src/arm_common/conv_bias/quint8/direct_dotprod.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
| #include "src/common/opr_delegate.h" | |||
| using namespace megdnn; | |||
| @@ -12,7 +12,7 @@ | |||
| #include "src/arm_common/conv_bias/quint8/stride2.h" | |||
| #include "megdnn/oprs.h" | |||
| #include "src/arm_common/conv_bias/quint8/direct.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
| #include "src/common/opr_delegate.h" | |||
| using namespace megdnn; | |||
| @@ -12,7 +12,7 @@ | |||
| #if MGB_ENABLE_DOT | |||
| #include "megdnn/oprs.h" | |||
| #include "src/arm_common/conv_bias/quint8/direct_dotprod.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
| #include "src/common/opr_delegate.h" | |||
| using namespace megdnn; | |||
| @@ -10,7 +10,7 @@ | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/elemwise/binary/algo.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/naive/handle.h" | |||
| @@ -20,6 +20,7 @@ | |||
| MIDOUT_DECL(megdnn_arm_common_elemwise_binary) | |||
| using namespace megdnn; | |||
| using namespace elemwise; | |||
| using namespace arm_common; | |||
| namespace { | |||
| @@ -160,7 +161,7 @@ bool ElemwiseImpl::AlgoBinaryVecBcast101xX::is_available( | |||
| DISPATCH_BINARY(ADD, _case, _type, _type_midout_id, AddOp); \ | |||
| DISPATCH_BINARY(SUB, _case, _type, _type_midout_id, SubOp); \ | |||
| DISPATCH_BINARY(MUL, _case, _type, _type_midout_id, MulOp); \ | |||
| DISPATCH_BINARY(POW, _case, _type, _type_midout_id, PowOp); \ | |||
| DISPATCH_BINARY(POW, _case, _type, _type_midout_id, fallback::PowOp); \ | |||
| DISPATCH_BINARY(TRUE_DIV, _case, _type, _type_midout_id, TrueDivOp); \ | |||
| DISPATCH_BINARY(FUSE_ADD_RELU, _case, _type, _type_midout_id, FuseAddReluOp); \ | |||
| DISPATCH_BINARY( \ | |||
| @@ -178,7 +179,7 @@ bool ElemwiseImpl::AlgoBinaryVecBcast101xX::is_available( | |||
| DISPATCH_BINARY(ADD, _case, _type, _type_midout_id, AddOp); \ | |||
| DISPATCH_BINARY(SUB, _case, _type, _type_midout_id, SubOp); \ | |||
| DISPATCH_BINARY(MUL, _case, _type, _type_midout_id, MulOp); \ | |||
| DISPATCH_BINARY(POW, _case, _type, _type_midout_id, PowOp); \ | |||
| DISPATCH_BINARY(POW, _case, _type, _type_midout_id, fallback::PowOp); \ | |||
| DISPATCH_BINARY(FUSE_ADD_RELU, _case, _type, _type_midout_id, FuseAddReluOp); \ | |||
| DISPATCH_BINARY( \ | |||
| FUSE_ADD_H_SWISH, _case, _type, _type_midout_id, FuseAddHSwishOp); \ | |||
| @@ -13,7 +13,7 @@ | |||
| #include "src/arm_common/elemwise/binary/algo.h" | |||
| #include "src/arm_common/elemwise/ternary/algo.h" | |||
| #include "src/arm_common/elemwise/unary/algo.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
| #include "src/common/metahelper.h" | |||
| #include "src/common/utils.h" | |||
| @@ -12,7 +12,7 @@ | |||
| #pragma once | |||
| #include "src/fallback/elemwise/opr_impl.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
| namespace megdnn { | |||
| namespace arm_common { | |||
| @@ -10,7 +10,7 @@ | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/elemwise/ternary/algo.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/naive/handle.h" | |||
| @@ -20,6 +20,7 @@ | |||
| MIDOUT_DECL(megdnn_arm_common_elemwise_ternary) | |||
| using namespace megdnn; | |||
| using namespace elemwise; | |||
| using namespace arm_common; | |||
| #define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \ | |||
| @@ -10,7 +10,7 @@ | |||
| * implied. | |||
| */ | |||
| #include "src/arm_common/elemwise/unary/algo.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/naive/handle.h" | |||
| @@ -20,6 +20,7 @@ | |||
| MIDOUT_DECL(megdnn_arm_common_elemwise_unary) | |||
| using namespace megdnn; | |||
| using namespace elemwise; | |||
| using namespace arm_common; | |||
| bool ElemwiseImpl::AlgoUnary::is_available(const KernParam& kern_param) const { | |||
| @@ -0,0 +1,151 @@ | |||
| /** | |||
| * \file dnn/src/arm_common/elemwise_helper/elemwise_op.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include "src/arm_common/elemwise_helper/op_binary.h" | |||
| #include "src/arm_common/elemwise_helper/op_ternary.h" | |||
| #include "src/arm_common/elemwise_helper/op_unary.h" | |||
| #include "src/fallback/elemwise_helper/elemwise_op.h" | |||
| namespace megdnn { | |||
| namespace elemwise { | |||
| using BcastType = megdnn::elemwise::BcastType; | |||
| ///////////////////////////////// ParamElemVistor /////////////////////////// | |||
| #define cb(_ctype, _inner_ctype, _neon_type, _fun_suffix) \ | |||
| template <> \ | |||
| struct ParamElemVisitor<_ctype> { \ | |||
| _neon_type operator()(const _ctype* src) const { \ | |||
| return vld1q_##_fun_suffix(reinterpret_cast<const _inner_ctype*>(src)); \ | |||
| } \ | |||
| }; \ | |||
| template <> \ | |||
| struct ParamElemVisitorDup<_ctype> { \ | |||
| _neon_type operator()(const _ctype* src) const { \ | |||
| return vdupq_n_##_fun_suffix(*reinterpret_cast<const _inner_ctype*>(src)); \ | |||
| } \ | |||
| } | |||
| cb(dt_quint8, uint8_t, uint8x16_t, u8); | |||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
| cb(__fp16, __fp16, float16x8_t, f16); | |||
| #endif | |||
| cb(dt_int16, int16_t, int16x8_t, s16); | |||
| #undef cb | |||
| template <typename ctype> | |||
| struct ParamElemVisitorBcast101x4; | |||
| #define cb(_ctype, _inner_ctype, _neon_type, _fun_suffix, rel_suffix) \ | |||
| template <> \ | |||
| struct ParamElemVisitorBcast101x4<_ctype> { \ | |||
| _neon_type operator()(const _ctype* src) const { \ | |||
| return vreinterpretq_##_fun_suffix##_##rel_suffix(vld1q_dup_##rel_suffix( \ | |||
| reinterpret_cast<const _inner_ctype*>(src))); \ | |||
| } \ | |||
| } | |||
| cb(dt_quint8, uint32_t, uint8x16_t, u8, u32); | |||
| cb(dt_int16, int64_t, int16x8_t, s16, s64); | |||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
| cb(__fp16, uint64_t, float16x8_t, f16, u64); | |||
| #endif | |||
| #undef cb | |||
| template <typename ctype> | |||
| struct ParamElemVisitorBcast101x8; | |||
| #define cb(_ctype, _inner_ctype, _neon_type, _fun_suffix) \ | |||
| template <> \ | |||
| struct ParamElemVisitorBcast101x8<_ctype> { \ | |||
| _neon_type operator()(const _ctype* src) const { \ | |||
| return vld1q_##_fun_suffix(reinterpret_cast<const _inner_ctype*>(src)); \ | |||
| } \ | |||
| } | |||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
| cb(__fp16, __fp16, float16x8_t, f16); | |||
| #endif | |||
| #undef cb | |||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
| template <> | |||
| struct OpCallerBinaryBcast101xXVec<__fp16, 8> { | |||
| using src_ctype = __fp16; | |||
| template <typename Op> | |||
| static void run( | |||
| const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst, | |||
| const Op& op, size_t batch, size_t nr_channel_blocks, | |||
| size_t channel_stride) { | |||
| ParamElemVisitorBcast101x8<src_ctype> vis0; | |||
| ParamElemVisitor<src_ctype> vis1; | |||
| OpCallerBinaryBcast101xDVec<src_ctype, 8>::run( | |||
| src0, src1, dst, op, vis0, vis1, batch, nr_channel_blocks, | |||
| channel_stride); | |||
| } | |||
| }; | |||
| template <> | |||
| struct OpCallerBinaryVecBcast101xX<__fp16, 8> { | |||
| using src_ctype = __fp16; | |||
| template <typename Op> | |||
| static void run( | |||
| const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst, | |||
| const Op& op, size_t batch, size_t nr_channel_blocks, | |||
| size_t channel_stride) { | |||
| ParamElemVisitor<src_ctype> vis0; | |||
| ParamElemVisitorBcast101x8<src_ctype> vis1; | |||
| OpCallerBinaryVecBcast101xD<src_ctype, 8>::run( | |||
| src0, src1, dst, op, vis0, vis1, batch, nr_channel_blocks, | |||
| channel_stride); | |||
| } | |||
| }; | |||
| template <> | |||
| struct OpCallerTernaryBcast101xXVecBcast101xX<__fp16, 8> { | |||
| using src_ctype = __fp16; | |||
| template <typename Op> | |||
| static void run( | |||
| const src_ctype* src0, const src_ctype* src1, const src_ctype* src2, | |||
| typename Op::dst_ctype* dst, const Op& op, size_t batch, | |||
| size_t nr_channel_blocks, size_t channel_stride) { | |||
| ParamElemVisitorBcast101x8<src_ctype> vis0; | |||
| ParamElemVisitor<src_ctype> vis1; | |||
| ParamElemVisitorBcast101x8<src_ctype> vis2; | |||
| OpCallerTernaryBcast101xDVecBcast101xD<src_ctype, 8>::run( | |||
| src0, src1, src2, dst, op, vis0, vis1, vis2, batch, nr_channel_blocks, | |||
| channel_stride); | |||
| } | |||
| }; | |||
| template <> | |||
| struct OpCallerTernaryVecBcast101xXVec<__fp16, 8> { | |||
| using src_ctype = __fp16; | |||
| template <typename Op> | |||
| static void run( | |||
| const src_ctype* src0, const src_ctype* src1, const src_ctype* src2, | |||
| typename Op::dst_ctype* dst, const Op& op, size_t batch, | |||
| size_t nr_channel_blocks, size_t channel_stride) { | |||
| ParamElemVisitor<src_ctype> vis0; | |||
| ParamElemVisitorBcast101x8<src_ctype> vis1; | |||
| ParamElemVisitor<src_ctype> vis2; | |||
| OpCallerTernaryVecBcast101xDVec<src_ctype, 8>::run( | |||
| src0, src1, src2, dst, op, vis0, vis1, vis2, batch, nr_channel_blocks, | |||
| channel_stride); | |||
| } | |||
| }; | |||
| #endif | |||
| } // namespace elemwise | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -1,36 +0,0 @@ | |||
| /** | |||
| * \file dnn/src/arm_common/elemwise_helper/kimpl/pow.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include "src/arm_common/elemwise_helper/kimpl/op_base.h" | |||
| namespace megdnn { | |||
| namespace arm_common { | |||
| // when __fp16 is avaliable POW is very slow, so add there | |||
| /////////////////////// POW float only //////////////////////////// | |||
| template <typename src_ctype, typename dst_ctype = src_ctype> | |||
| struct PowOp : BinaryOpBase<src_ctype, dst_ctype> { | |||
| using BinaryOpBase<src_ctype, dst_ctype>::BinaryOpBase; | |||
| constexpr static size_t SIMD_WIDTH = 1; | |||
| void operator()( | |||
| const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const { | |||
| *dst = operator()(src0, src1); | |||
| } | |||
| dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const { | |||
| return powf(src0, src1); | |||
| } | |||
| }; | |||
| } // namespace arm_common | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -18,7 +18,6 @@ | |||
| #include "src/arm_common/elemwise_helper/kimpl/max.h" | |||
| #include "src/arm_common/elemwise_helper/kimpl/min.h" | |||
| #include "src/arm_common/elemwise_helper/kimpl/mul.h" | |||
| #include "src/arm_common/elemwise_helper/kimpl/pow.h" | |||
| #include "src/arm_common/elemwise_helper/kimpl/rmulh.h" | |||
| #include "src/arm_common/elemwise_helper/kimpl/sub.h" | |||
| #include "src/arm_common/elemwise_helper/kimpl/true_div.h" | |||
| @@ -15,7 +15,7 @@ | |||
| #include "src/common/elemwise_multi_type/kern_defs.cuh" | |||
| #include "src/naive/handle.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
| #include "src/arm_common/simd_macro/marm_neon.h" | |||
| namespace { | |||
| @@ -46,6 +46,8 @@ void neon_round_shr_saturate_int16_static_k( | |||
| } // namespace | |||
| using namespace elemwise; | |||
| namespace megdnn { | |||
| namespace arm_common { | |||
| @@ -2,7 +2,7 @@ | |||
| * \file dnn/src/fallback/elemwise/gi_impl/binary/algo.cpp | |||
| */ | |||
| #include "src/fallback/elemwise/gi_impl/binary/algo.h" | |||
| #include "src/fallback/elemwise_op.h" | |||
| #include "src/fallback/elemwise_helper/elemwise_op.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/naive/handle.h" | |||
| @@ -12,6 +12,7 @@ | |||
| MIDOUT_DECL(megdnn_fallback_elemwise_binary) | |||
| using namespace megdnn; | |||
| using namespace elemwise; | |||
| using namespace fallback; | |||
| namespace { | |||
| @@ -3,7 +3,7 @@ | |||
| */ | |||
| #include "src/fallback/elemwise/gi_impl/ternary/algo.h" | |||
| #include "src/fallback/elemwise_op.h" | |||
| #include "src/fallback/elemwise_helper/elemwise_op.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/naive/handle.h" | |||
| @@ -13,6 +13,7 @@ | |||
| MIDOUT_DECL(megdnn_fallback_elemwise_ternary) | |||
| using namespace megdnn; | |||
| using namespace elemwise; | |||
| using namespace fallback; | |||
| #define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \ | |||
| @@ -2,7 +2,7 @@ | |||
| * \file dnn/src/fallback/elemwise/gi_impl/unary/algo.cpp | |||
| */ | |||
| #include "src/fallback/elemwise/gi_impl/unary/algo.h" | |||
| #include "src/fallback/elemwise_op.h" | |||
| #include "src/fallback/elemwise_helper/elemwise_op.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/naive/handle.h" | |||
| @@ -12,6 +12,7 @@ | |||
| MIDOUT_DECL(megdnn_fallback_elemwise_unary) | |||
| using namespace megdnn; | |||
| using namespace elemwise; | |||
| using namespace fallback; | |||
| bool ElemwiseImpl::AlgoUnary::is_available(const KernParam& kern_param) const { | |||
| @@ -25,6 +25,7 @@ MIDOUT_DECL(megdnn_fallback_elemwise_exec_BINARY_INT) | |||
| MIDOUT_DECL(megdnn_fallback_elemwise_exec_BINARY_FLOAT) | |||
| using namespace megdnn; | |||
| using namespace elemwise; | |||
| using namespace fallback; | |||
| void ElemwiseImpl::exec(const TensorNDArray& srcs, _megdnn_tensor_out dst) { | |||
| @@ -9,7 +9,7 @@ | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include "src/fallback/elemwise_op.h" | |||
| #include "src/fallback/elemwise_helper/elemwise_op.h" | |||
| #include "src/naive/elemwise/opr_impl.h" | |||
| namespace megdnn { | |||
| @@ -60,7 +60,7 @@ private: | |||
| public: | |||
| class AlgoBase; | |||
| struct KernParam { | |||
| BcastType broad_cast_type; | |||
| elemwise::BcastType broad_cast_type; | |||
| Mode mode; | |||
| const TensorND* m_dst; | |||
| Handle* handle; | |||
| @@ -0,0 +1,72 @@ | |||
| /** | |||
| * \file dnn/src/fallback/elemwise_helper/elemwise_op.h | |||
| */ | |||
| #pragma once | |||
| #include "src/fallback/elemwise_helper/op_binary.h" | |||
| #include "src/fallback/elemwise_helper/op_common.h" | |||
| #include "src/fallback/elemwise_helper/op_ternary.h" | |||
| #include "src/fallback/elemwise_helper/op_unary.h" | |||
| #include "src/fallback/general_intrinsic/gi_float.h" | |||
| #include "src/fallback/general_intrinsic/gi_int.h" | |||
| namespace megdnn { | |||
| namespace elemwise { | |||
| ///////////////////////////////// ParamElemVistor /////////////////////////// | |||
| #define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix) \ | |||
| template <> \ | |||
| struct ParamElemVisitor<_ctype> { \ | |||
| _simd_type operator()(const _ctype* src) const { \ | |||
| return GiLoad##_fun_suffix(src); \ | |||
| } \ | |||
| }; \ | |||
| template <> \ | |||
| struct ParamElemVisitorDup<_ctype> { \ | |||
| _simd_type operator()(const _ctype* src) const { \ | |||
| return GiBroadcast##_fun_suffix( \ | |||
| *reinterpret_cast<const _inner_ctype*>(src)); \ | |||
| } \ | |||
| } | |||
| cb(dt_qint32, int32_t, GI_INT32_t, Int32); | |||
| cb(dt_qint8, int8_t, GI_INT8_t, Int8); | |||
| cb(dt_float32, float, GI_FLOAT32_t, Float32); | |||
| cb(dt_int32, int32_t, GI_INT32_t, Int32); | |||
| cb(dt_int8, int8_t, GI_INT8_t, Int8); | |||
| #undef cb | |||
| template <typename ctype> | |||
| struct ParamElemVisitorBcast101x4; | |||
| #define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix, rel_suffix) \ | |||
| template <> \ | |||
| struct ParamElemVisitorBcast101x4<_ctype> { \ | |||
| _simd_type operator()(const _ctype* src) const { \ | |||
| return GiReinter##rel_suffix##To##_fun_suffix(GiBroadcast##rel_suffix( \ | |||
| *reinterpret_cast<const _inner_ctype*>(src))); \ | |||
| } \ | |||
| } | |||
| cb(dt_qint8, int32_t, GI_INT8_t, Int8, Int32); | |||
| cb(dt_int8, int32_t, GI_INT8_t, Int8, Int32); | |||
| #undef cb | |||
| #define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix) \ | |||
| template <> \ | |||
| struct ParamElemVisitorBcast101x4<_ctype> { \ | |||
| _simd_type operator()(const _ctype* src) const { \ | |||
| return GiLoad##_fun_suffix(src); \ | |||
| } \ | |||
| } | |||
| cb(dt_qint32, int32_t, GI_INT32_t, Int32); | |||
| cb(dt_float32, float, GI_FLOAT32_t, Float32); | |||
| cb(dt_int32, int32_t, GI_INT32_t, Int32); | |||
| #undef cb | |||
| } // namespace elemwise | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -58,7 +58,7 @@ struct AbsOpBase<dt_qint8, dt_qint8> : UnaryOpBase<dt_qint8, dt_qint8> { | |||
| template <> | |||
| struct AbsOp<dt_qint8, dt_qint8> : AbsOpBase<dt_qint8, dt_qint8> { | |||
| using AbsOpBase::AbsOpBase; | |||
| constexpr static size_t SIMD_WIDTH = 16; | |||
| constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t); | |||
| using AbsOpBase::operator(); | |||
| void operator()(const GI_INT8_V2_t& vsrc, dt_qint8* dst) const { | |||
| OPERATOR_UNARY_QINT8_FALLBACK; | |||
| @@ -87,7 +87,7 @@ template <> | |||
| struct FuseAddHSwishOp<dt_qint32, dt_qint8> : FuseAddHSwishOpBase<dt_qint32, dt_qint8> { | |||
| using FuseAddHSwishOpBase::FuseAddHSwishOpBase; | |||
| using FuseAddHSwishOpBase::operator(); | |||
| constexpr static size_t SIMD_WIDTH = 4; | |||
| constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t); | |||
| void operator()( | |||
| const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1, | |||
| dt_qint8* dst) const { | |||
| @@ -83,7 +83,7 @@ template <> | |||
| struct HSwishOp<dt_qint32, dt_qint8> : HSwishOpBase<dt_qint32, dt_qint8> { | |||
| using HSwishOpBase::HSwishOpBase; | |||
| using HSwishOpBase::operator(); | |||
| constexpr static size_t SIMD_WIDTH = 4; | |||
| constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int32_t); | |||
| void operator()(const GI_INT32_V2_t& vsrc, dt_qint8* dst) const { | |||
| GiStoreLowInt8(reinterpret_cast<int8_t*>(dst), operator()(vsrc)); | |||
| @@ -77,7 +77,7 @@ struct MaxOpBase<dt_qint8, dt_qint8> : BinaryOpBase<dt_qint8, dt_qint8> { | |||
| template <> | |||
| struct MaxOp<dt_qint8, dt_qint8> : MaxOpBase<dt_qint8, dt_qint8> { | |||
| using MaxOpBase::MaxOpBase; | |||
| constexpr static size_t SIMD_WIDTH = 16; | |||
| constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t); | |||
| using MaxOpBase::operator(); | |||
| void operator()( | |||
| @@ -74,7 +74,7 @@ struct MinOpBase<dt_qint8, dt_qint8> : BinaryOpBase<dt_qint8, dt_qint8> { | |||
| template <> | |||
| struct MinOp<dt_qint8, dt_qint8> : MinOpBase<dt_qint8, dt_qint8> { | |||
| using MinOpBase::MinOpBase; | |||
| constexpr static size_t SIMD_WIDTH = 16; | |||
| constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t); | |||
| using MinOpBase::operator(); | |||
| void operator()( | |||
| @@ -73,7 +73,7 @@ struct MulOpBase<dt_qint8, dt_qint8> : BinaryOpBase<dt_qint8, dt_qint8> { | |||
| template <> | |||
| struct MulOp<dt_qint8, dt_qint8> : MulOpBase<dt_qint8, dt_qint8> { | |||
| using MulOpBase::MulOpBase; | |||
| constexpr static size_t SIMD_WIDTH = 16; | |||
| constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t); | |||
| using MulOpBase::operator(); | |||
| void operator()( | |||
| @@ -54,8 +54,6 @@ struct NoneOpBase<dt_qint32, dt_qint8> : UnaryOpBase<dt_qint32, dt_qint8> { | |||
| } | |||
| }; | |||
| #pragma GCC diagnostic ignored "-Waddress-of-packed-member" | |||
| template <> | |||
| struct NoneOp<dt_qint32, dt_qint8> : NoneOpBase<dt_qint32, dt_qint8> { | |||
| using NoneOpBase::NoneOpBase; | |||
| @@ -63,11 +61,11 @@ struct NoneOp<dt_qint32, dt_qint8> : NoneOpBase<dt_qint32, dt_qint8> { | |||
| constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int32_t); | |||
| void operator()(const GI_INT32_V2_t& vsrc, dt_qint8* dst) const { | |||
| GiStoreInt32(reinterpret_cast<int32_t*>(dst), vsrc.val[0]); | |||
| GiStoreInt32(reinterpret_cast<int32_t*>(dst + 16), vsrc.val[1]); | |||
| GiStoreInt32(dst, vsrc.val[0]); | |||
| GiStoreInt32(dst + 16, vsrc.val[1]); | |||
| } | |||
| void operator()(const GI_INT32_t& src, dt_qint8* dst) const { | |||
| GiStoreInt32(reinterpret_cast<int32_t*>(dst), src); | |||
| GiStoreInt32(dst, src); | |||
| } | |||
| }; | |||
| @@ -112,36 +112,38 @@ struct ReluOp<dt_qint32, dt_qint8> : ReluOpBase<dt_qint32, dt_qint8>, FixupBase | |||
| : ReluOpBase(src_scale, dst_scale), FixupBase(scale) {} | |||
| void operator()(const int32x4x2_t& vsrc, dt_qint8* dst) const { | |||
| vst1_s8(reinterpret_cast<int8_t*>(dst), operator()(vsrc)); | |||
| vst1_s8(reinterpret_cast<int8_t*>(dst), vget_low_s8(operator()(vsrc))); | |||
| } | |||
| int8x8_t operator()(const int32x4x2_t& vsrc) const { | |||
| int8x16_t operator()(const int32x4x2_t& vsrc) const { | |||
| int32x4_t vitem0 = vqrdmulhq_s32(vsrc.val[0], vmultiplier); | |||
| int32x4_t vitem1 = vqrdmulhq_s32(vsrc.val[1], vmultiplier); | |||
| vitem0 = vmaxq_s32(vitem0, QConverterBase::vzero()); | |||
| vitem1 = vmaxq_s32(vitem1, QConverterBase::vzero()); | |||
| return vqmovn_s16(vcombine_s16( | |||
| auto tmp = vqmovn_s16(vcombine_s16( | |||
| vqmovn_s32(vrshlq_s32(vitem0, vshift)), | |||
| vqmovn_s32(vrshlq_s32(vitem1, vshift)))); | |||
| return vcombine_s8(tmp, tmp); | |||
| } | |||
| int8x8_t operator()(const float32x4_t& vsrc) const { | |||
| int8x16_t operator()(const float32x4_t& vsrc) const { | |||
| int32x4_t vitem0 = vqrdmulhq_s32(vcvtq_s32_f32(vsrc), vmultiplier); | |||
| vitem0 = vmaxq_s32(vitem0, QConverterBase::vzero()); | |||
| vitem0 = vrshlq_s32(vitem0, vshift); | |||
| int16x4_t vitem = vqmovn_s32(vitem0); | |||
| return vqmovn_s16(vcombine_s16(vitem, vitem)); | |||
| auto tmp = vqmovn_s16(vcombine_s16(vitem, vitem)); | |||
| return vcombine_s8(tmp, tmp); | |||
| } | |||
| void operator()(const int32x4_t& src, dt_qint8* dst) const { | |||
| auto vitem0 = vmulq_f32(vcvtq_f32_s32(src), this->vscale); | |||
| vitem0 = vmaxq_f32(vitem0, QConverterBase::vfzero()); | |||
| auto result = QConverter::convert<int8x8_t, float32x4_t>(vitem0); | |||
| vst1_lane_s32(reinterpret_cast<int32_t*>(dst), (int32x2_t)result, 0); | |||
| auto result = QConverter::convert<int8x16_t, float32x4_t>(vitem0); | |||
| vst1q_lane_s32(reinterpret_cast<int32_t*>(dst), (int32x4_t)result, 0); | |||
| } | |||
| void operator()(const float32x4_t& src, dt_qint8* dst) const { | |||
| auto vitem0 = vmulq_f32(src, this->vscale); | |||
| vitem0 = vmaxq_f32(vitem0, QConverterBase::vfzero()); | |||
| auto result = QConverter::convert<int8x8_t, float32x4_t>(vitem0); | |||
| vst1_lane_s32(reinterpret_cast<int32_t*>(dst), (int32x2_t)result, 0); | |||
| auto result = QConverter::convert<int8x16_t, float32x4_t>(vitem0); | |||
| vst1q_lane_s32(reinterpret_cast<int32_t*>(dst), (int32x4_t)result, 0); | |||
| } | |||
| }; | |||
| @@ -73,7 +73,7 @@ struct SubOpBase<dt_qint8, dt_qint8> : BinaryOpBase<dt_qint8, dt_qint8> { | |||
| template <> | |||
| struct SubOp<dt_qint8, dt_qint8> : SubOpBase<dt_qint8, dt_qint8> { | |||
| using SubOpBase::SubOpBase; | |||
| constexpr static size_t SIMD_WIDTH = 16; | |||
| constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t); | |||
| using SubOpBase::operator(); | |||
| void operator()( | |||
| @@ -13,6 +13,7 @@ | |||
| #include "math.h" | |||
| #include "stdint.h" | |||
| #include "string.h" | |||
| #if defined(_WIN32) | |||
| #include <intrin.h> | |||
| @@ -132,6 +133,18 @@ typedef uint32_t GI_UINT32_t __attribute__((vector_size(16))); | |||
| #define Max(a, b) (a) > (b) ? (a) : (b) | |||
| #define Min(a, b) (a) < (b) ? (a) : (b) | |||
| #if defined(GI_NEON_INTRINSICS) | |||
| #if defined(__ARM_FEATURE_FMA) && defined(GI_NEON64_INTRINSICS) | |||
| #define v_fma_ps_f32(c, b, a) vfmaq_f32((c), (b), (a)) | |||
| #define v_fma_n_f32(c, b, a) vfmaq_n_f32((c), (b), (a)) | |||
| #define v_fma_lane_f32(c, b, a, lane) vfmaq_lane_f32((c), (b), (a), (lane)) | |||
| #else | |||
| #define v_fma_ps_f32(c, b, a) vmlaq_f32((c), (b), (a)) | |||
| #define v_fma_n_f32(c, b, a) vmlaq_n_f32((c), (b), (a)) | |||
| #define v_fma_lane_f32(c, b, a, lane) vmlaq_lane_f32((c), (b), (a), (lane)) | |||
| #endif | |||
| #endif | |||
| typedef struct { | |||
| GI_INT32_t val[2]; | |||
| } GI_INT32_V2_t; | |||
| @@ -20,7 +20,9 @@ GI_INT32_t GiReinterpretAsInt32(GI_FLOAT32_t In) { | |||
| #elif defined(GI_SSE2_INTRINSICS) | |||
| return _mm_castps_si128(In); | |||
| #else | |||
| return *(GI_INT32_t*)(&In); | |||
| GI_INT32_t ret; | |||
| memcpy(&ret, &In, GI_SIMD_LEN_BYTE); | |||
| return ret; | |||
| #endif | |||
| } | |||
| @@ -31,7 +33,9 @@ GI_UINT32_t GiReinterpretAsUint32(GI_FLOAT32_t In) { | |||
| #elif defined(GI_SSE2_INTRINSICS) | |||
| return _mm_castps_si128(In); | |||
| #else | |||
| return *(GI_UINT32_t*)(&In); | |||
| GI_UINT32_t ret; | |||
| memcpy(&ret, &In, GI_SIMD_LEN_BYTE); | |||
| return ret; | |||
| #endif | |||
| } | |||
| @@ -42,7 +46,9 @@ GI_FLOAT32_t GiReintInt32ToFloat32(GI_INT32_t Vector) { | |||
| #elif defined(GI_SSE2_INTRINSICS) | |||
| return _mm_castsi128_ps(Vector); | |||
| #else | |||
| return *(GI_FLOAT32_t*)(&Vector); | |||
| GI_FLOAT32_t ret; | |||
| memcpy(&ret, &Vector, GI_SIMD_LEN_BYTE); | |||
| return ret; | |||
| #endif | |||
| } | |||
| @@ -53,7 +59,9 @@ GI_FLOAT32_t GiReintUint32ToFloat32(GI_UINT32_t Vector) { | |||
| #elif defined(GI_SSE2_INTRINSICS) | |||
| return _mm_castsi128_ps(Vector); | |||
| #else | |||
| return *(GI_FLOAT32_t*)(&Vector); | |||
| GI_FLOAT32_t ret; | |||
| memcpy(&ret, &Vector, GI_SIMD_LEN_BYTE); | |||
| return ret; | |||
| #endif | |||
| } | |||
| @@ -69,7 +77,7 @@ GI_INT32_t GiRoundAsInt32(GI_FLOAT32_t Vector) { | |||
| float32x4_t vinc0 = vbslq_f32(vcgeq_f32(Vector, vzero), vfhalf, vfneg_half); | |||
| return vcvtq_s32_f32(vaddq_f32(Vector, vinc0)); | |||
| #endif | |||
| #elif defined(GI_SSE2_INTRINSICS) | |||
| #elif defined(GI_SSE42_INTRINSICS) | |||
| __m128 vfzero = _mm_set1_ps(0.f); | |||
| __m128 vfhalf = _mm_set1_ps(0.5f); | |||
| __m128 vfneg_half = _mm_set1_ps(-0.5f); | |||
| @@ -322,11 +330,7 @@ GI_FORCEINLINE | |||
| GI_FLOAT32_t GiMultiplyAddFloat32( | |||
| GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { | |||
| #if defined(GI_NEON_INTRINSICS) | |||
| #if defined(__ARM_FEATURE_FMA) | |||
| return vfmaq_f32(VectorSum, Vector1, Vector2); | |||
| #else | |||
| return vmlaq_f32(VectorSum, Vector1, Vector2); | |||
| #endif | |||
| return v_fma_ps_f32(VectorSum, Vector1, Vector2); | |||
| #elif defined(GI_FMA3_INTRINSICS) | |||
| return _mm_fmadd_ps(Vector1, Vector2, VectorSum); | |||
| #elif defined(GI_SSE2_INTRINSICS) | |||
| @@ -352,11 +356,7 @@ GI_FORCEINLINE | |||
| GI_FLOAT32_t GiMultiplyAddScalarFloat32( | |||
| GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector, float Scalar) { | |||
| #if defined(GI_NEON_INTRINSICS) | |||
| #if defined(__ARM_FEATURE_FMA) | |||
| return vfmaq_n_f32(VectorSum, Vector, Scalar); | |||
| #else | |||
| return vfmla_n_f32(VectorSum, Vector, Scalar); | |||
| #endif | |||
| return v_fma_n_f32(VectorSum, Vector, Scalar); | |||
| #elif defined(GI_SSE2_INTRINSICS) | |||
| return GiMultiplyAddFloat32(VectorSum, GiBroadcastFloat32(Scalar), Vector); | |||
| #else | |||
| @@ -365,27 +365,10 @@ GI_FLOAT32_t GiMultiplyAddScalarFloat32( | |||
| } | |||
| #if defined(GI_NEON_INTRINSICS) | |||
| #if defined(__ARM_FEATURE_FMA) | |||
| #define GIMULTIPLYADDLANFLOAT32(i) \ | |||
| GI_FORCEINLINE GI_FLOAT32_t GiMultiplyAddLan##i##Float32( \ | |||
| GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { \ | |||
| return vfmaq_lane_f32(VectorSum, Vector1, vget_low_f32(Vector2), i); \ | |||
| } | |||
| GIMULTIPLYADDLANFLOAT32(0) | |||
| GIMULTIPLYADDLANFLOAT32(1) | |||
| #undef GIMULTIPLYADDLANFLOAT32 | |||
| #define GIMULTIPLYADDLANFLOAT32(i) \ | |||
| GI_FORCEINLINE GI_FLOAT32_t GiMultiplyAddLan##i##Float32( \ | |||
| GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { \ | |||
| return vfmaq_lane_f32(VectorSum, Vector1, vget_high_f32(Vector2), i - 2); \ | |||
| } | |||
| GIMULTIPLYADDLANFLOAT32(2) | |||
| GIMULTIPLYADDLANFLOAT32(3) | |||
| #else | |||
| #define GIMULTIPLYADDLANFLOAT32(i) \ | |||
| GI_FORCEINLINE GI_FLOAT32_t GiMultiplyAddLan##i##Float32( \ | |||
| GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { \ | |||
| return vmlaq_lane_f32(VectorSum, Vector1, vget_low_f32(Vector2), i); \ | |||
| return v_fma_lane_f32(VectorSum, Vector1, vget_low_f32(Vector2), i); \ | |||
| } | |||
| GIMULTIPLYADDLANFLOAT32(0) | |||
| GIMULTIPLYADDLANFLOAT32(1) | |||
| @@ -393,11 +376,10 @@ GIMULTIPLYADDLANFLOAT32(1) | |||
| #define GIMULTIPLYADDLANFLOAT32(i) \ | |||
| GI_FORCEINLINE GI_FLOAT32_t GiMultiplyAddLan##i##Float32( \ | |||
| GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { \ | |||
| return vmlaq_lane_f32(VectorSum, Vector1, vget_high_f32(Vector2), i - 2); \ | |||
| return v_fma_lane_f32(VectorSum, Vector1, vget_high_f32(Vector2), i - 2); \ | |||
| } | |||
| GIMULTIPLYADDLANFLOAT32(2) | |||
| GIMULTIPLYADDLANFLOAT32(3) | |||
| #endif | |||
| #undef GIMULTIPLYADDLANFLOAT32 | |||
| #elif defined(GI_SSE2_INTRINSICS) | |||
| @@ -59,66 +59,69 @@ GI_INT8_t GiBroadcastInt8(int8_t Value) { | |||
| } | |||
| GI_FORCEINLINE | |||
| GI_INT32_t GiLoadInt32(const int32_t* Buffer) { | |||
| GI_INT32_t GiLoadInt32(const void* Buffer) { | |||
| #if defined(GI_NEON_INTRINSICS) | |||
| return vld1q_s32(Buffer); | |||
| return vld1q_s32((int32_t*)Buffer); | |||
| #elif defined(GI_SSE2_INTRINSICS) | |||
| return _mm_loadu_si128((const __m128i*)Buffer); | |||
| #else | |||
| GI_INT32_t ret; | |||
| const int32_t* ptr = (int32_t*)Buffer; | |||
| for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int32_t); i++) { | |||
| ret[i] = Buffer[i]; | |||
| ret[i] = ptr[i]; | |||
| } | |||
| return ret; | |||
| #endif | |||
| } | |||
| GI_FORCEINLINE | |||
| GI_INT8_t GiLoadInt8(const int8_t* Buffer) { | |||
| GI_INT8_t GiLoadInt8(const void* Buffer) { | |||
| #if defined(GI_NEON_INTRINSICS) | |||
| return vld1q_s8(Buffer); | |||
| return vld1q_s8((int8_t*)Buffer); | |||
| #elif defined(GI_SSE2_INTRINSICS) | |||
| return _mm_loadu_si128((const __m128i*)Buffer); | |||
| #else | |||
| GI_INT8_t ret; | |||
| const int8_t* ptr = (int8_t*)Buffer; | |||
| for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int8_t); i++) { | |||
| ret[i] = Buffer[i]; | |||
| ret[i] = ptr[i]; | |||
| } | |||
| return ret; | |||
| #endif | |||
| } | |||
| GI_FORCEINLINE | |||
| void GiStoreInt32(int32_t* Buffer, GI_INT32_t Vector) { | |||
| void GiStoreInt32(void* Buffer, GI_INT32_t Vector) { | |||
| #if defined(GI_NEON_INTRINSICS) | |||
| vst1q_s32(Buffer, Vector); | |||
| vst1q_s32((int32_t*)Buffer, Vector); | |||
| #elif defined(GI_SSE2_INTRINSICS) | |||
| _mm_storeu_si128((__m128i*)Buffer, Vector); | |||
| #else | |||
| int32_t* ptr = (int32_t*)Buffer; | |||
| for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int32_t); i++) { | |||
| Buffer[i] = Vector[i]; | |||
| ptr[i] = Vector[i]; | |||
| } | |||
| #endif | |||
| } | |||
| #if defined(GI_NEON_INTRINSICS) | |||
| #define GISTORELANEINT32(i) \ | |||
| GI_FORCEINLINE void GiStoreLane##i##Int32(int32_t* Buffer, GI_INT32_t Vector) { \ | |||
| vst1q_lane_s32(Buffer, Vector, i); \ | |||
| #define GISTORELANEINT32(i) \ | |||
| GI_FORCEINLINE void GiStoreLane##i##Int32(void* Buffer, GI_INT32_t Vector) { \ | |||
| vst1q_lane_s32((int32_t*)Buffer, Vector, i); \ | |||
| } | |||
| #elif defined(GI_SSE2_INTRINSICS) | |||
| #define GISTORELANEINT32(i) \ | |||
| GI_FORCEINLINE void GiStoreLane##i##Int32(int32_t* Buffer, GI_INT32_t Vector) { \ | |||
| GI_FORCEINLINE void GiStoreLane##i##Int32(void* Buffer, GI_INT32_t Vector) { \ | |||
| GI_FLOAT32_t tmp = _mm_castsi128_ps(Vector); \ | |||
| _mm_store_ss( \ | |||
| (float*)Buffer, _mm_shuffle_ps(tmp, tmp, _MM_SHUFFLE(i, i, i, i))); \ | |||
| } | |||
| #else | |||
| #define GISTORELANEINT32(i) \ | |||
| GI_FORCEINLINE void GiStoreLane##i##Int32(int32_t* Buffer, GI_INT32_t Vector) { \ | |||
| *Buffer = Vector[i]; \ | |||
| #define GISTORELANEINT32(i) \ | |||
| GI_FORCEINLINE void GiStoreLane##i##Int32(void* Buffer, GI_INT32_t Vector) { \ | |||
| *((int32_t*)Buffer) = Vector[i]; \ | |||
| } | |||
| #endif | |||
| @@ -141,53 +144,57 @@ GI_INT8_t GiReinterInt32ToInt8(GI_INT32_t Vector) { | |||
| } | |||
| GI_FORCEINLINE | |||
| void GiStoreInt16(int16_t* Buffer, GI_INT16_t Vector) { | |||
| void GiStoreInt16(void* Buffer, GI_INT16_t Vector) { | |||
| #if defined(GI_NEON_INTRINSICS) | |||
| vst1q_s16(Buffer, Vector); | |||
| vst1q_s16((int16_t*)Buffer, Vector); | |||
| #elif defined(GI_SSE2_INTRINSICS) | |||
| _mm_storeu_si128((__m128i*)Buffer, Vector); | |||
| #else | |||
| int16_t* ptr = (int16_t*)Buffer; | |||
| for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int16_t); i++) { | |||
| Buffer[i] = Vector[i]; | |||
| ptr[i] = Vector[i]; | |||
| } | |||
| #endif | |||
| } | |||
| GI_FORCEINLINE | |||
| void GiStoreInt8(int8_t* Buffer, GI_INT8_t Vector) { | |||
| void GiStoreInt8(void* Buffer, GI_INT8_t Vector) { | |||
| #if defined(GI_NEON_INTRINSICS) | |||
| vst1q_s8(Buffer, Vector); | |||
| vst1q_s8((int8_t*)Buffer, Vector); | |||
| #elif defined(GI_SSE2_INTRINSICS) | |||
| _mm_storeu_si128((__m128i*)Buffer, Vector); | |||
| #else | |||
| int8_t* ptr = (int8_t*)Buffer; | |||
| for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int8_t); i++) { | |||
| Buffer[i] = Vector[i]; | |||
| ptr[i] = Vector[i]; | |||
| } | |||
| #endif | |||
| } | |||
| GI_FORCEINLINE | |||
| void GiStoreLowInt8(int8_t* Buffer, GI_INT8_t Vector) { | |||
| void GiStoreLowInt8(void* Buffer, GI_INT8_t Vector) { | |||
| #if defined(GI_NEON_INTRINSICS) | |||
| vst1_s8(Buffer, vget_low_s8(Vector)); | |||
| vst1_s8((int8_t*)Buffer, vget_low_s8(Vector)); | |||
| #elif defined(GI_SSE2_INTRINSICS) | |||
| _mm_storel_epi64((__m128i*)Buffer, Vector); | |||
| #else | |||
| int8_t* ptr = (int8_t*)Buffer; | |||
| for (size_t i = 0; i < GI_SIMD_LEN_BYTE / 2 / sizeof(int8_t); i++) { | |||
| Buffer[i] = Vector[i]; | |||
| ptr[i] = Vector[i]; | |||
| } | |||
| #endif | |||
| } | |||
| GI_FORCEINLINE | |||
| void GiStoreHihgInt8(int8_t* Buffer, GI_INT8_t Vector) { | |||
| void GiStoreHihgInt8(void* Buffer, GI_INT8_t Vector) { | |||
| #if defined(GI_NEON_INTRINSICS) | |||
| vst1_s8(Buffer, vget_high_s8(Vector)); | |||
| vst1_s8((int8_t*)Buffer, vget_high_s8(Vector)); | |||
| #elif defined(GI_SSE2_INTRINSICS) | |||
| _mm_storel_epi64((__m128i*)Buffer, _mm_unpackhi_epi64(Vector, Vector)); | |||
| #else | |||
| int8_t* ptr = (int8_t*)Buffer; | |||
| for (size_t i = 0; i < GI_SIMD_LEN_BYTE / 2 / sizeof(int8_t); i++) { | |||
| Buffer[i] = Vector[GI_SIMD_LEN_BYTE / 2 + i]; | |||
| ptr[i] = Vector[GI_SIMD_LEN_BYTE / 2 + i]; | |||
| } | |||
| #endif | |||
| } | |||
| @@ -39,7 +39,6 @@ TEST_F(FALLBACK, ELEMWISE_RECORD) { | |||
| checker.execs({{10, 10, 32}, {10, 10, 32}, {}}); | |||
| } | |||
| TEST_F(FALLBACK, ELEMWISE_FORWARD_TERNARY) { | |||
| using Mode = ElemwiseForward::Param::Mode; | |||
| Checker<ElemwiseForward> checker(handle()); | |||