GitOrigin-RevId: 87046b8197
tags/v1.10.0
| @@ -12,7 +12,7 @@ | |||||
| #include "src/aarch64/conv_bias/int8/algos.h" | #include "src/aarch64/conv_bias/int8/algos.h" | ||||
| #include "src/aarch64/conv_bias/int8/strategy.h" | #include "src/aarch64/conv_bias/int8/strategy.h" | ||||
| #include "src/arm_common/convolution/img2col_helper.h" | #include "src/arm_common/convolution/img2col_helper.h" | ||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
| #include "src/common/opr_delegate.h" | #include "src/common/opr_delegate.h" | ||||
| #include "src/fallback/conv_bias/common.h" | #include "src/fallback/conv_bias/common.h" | ||||
| #include "src/fallback/matrix_mul/gemm_impl.h" | #include "src/fallback/matrix_mul/gemm_impl.h" | ||||
| @@ -14,7 +14,7 @@ | |||||
| #include "src/aarch64/matrix_mul/quint8_dot/gemv.h" | #include "src/aarch64/matrix_mul/quint8_dot/gemv.h" | ||||
| #include "src/aarch64/matrix_mul/quint8_dot/strategy.h" | #include "src/aarch64/matrix_mul/quint8_dot/strategy.h" | ||||
| #include "src/arm_common/convolution/img2col_helper.h" | #include "src/arm_common/convolution/img2col_helper.h" | ||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
| #include "src/common/opr_delegate.h" | #include "src/common/opr_delegate.h" | ||||
| #include "src/fallback/conv_bias/common.h" | #include "src/fallback/conv_bias/common.h" | ||||
| #include "src/fallback/matrix_mul/gemm_impl.h" | #include "src/fallback/matrix_mul/gemm_impl.h" | ||||
| @@ -11,7 +11,7 @@ | |||||
| */ | */ | ||||
| #include "src/arm_common/conv_bias/f16/channel_wise_3x3_s1p1_nchw88_kern.h" | #include "src/arm_common/conv_bias/f16/channel_wise_3x3_s1p1_nchw88_kern.h" | ||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
| #include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
| #include "src/arm_common/utils.h" | #include "src/arm_common/utils.h" | ||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| @@ -12,7 +12,7 @@ | |||||
| #include "src/arm_common/conv_bias/f16/algos.h" | #include "src/arm_common/conv_bias/f16/algos.h" | ||||
| #include "src/arm_common/conv_bias/f16/channel_wise_nchw88_kern.h" | #include "src/arm_common/conv_bias/f16/channel_wise_nchw88_kern.h" | ||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
| #include "midout.h" | #include "midout.h" | ||||
| @@ -12,7 +12,7 @@ | |||||
| #include "src/arm_common/conv_bias/f16/channel_wise_nchw88_kern.h" | #include "src/arm_common/conv_bias/f16/channel_wise_nchw88_kern.h" | ||||
| #include "src/arm_common/conv_bias/f16/channel_wise_3x3_s1p1_nchw88_kern.h" | #include "src/arm_common/conv_bias/f16/channel_wise_3x3_s1p1_nchw88_kern.h" | ||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
| #include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
| #include "src/arm_common/utils.h" | #include "src/arm_common/utils.h" | ||||
| #include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
| @@ -15,7 +15,7 @@ | |||||
| #include "src/arm_common/conv_bias/f16/algos.h" | #include "src/arm_common/conv_bias/f16/algos.h" | ||||
| #include "src/arm_common/conv_bias/f16/direct_nchw88_kern.h" | #include "src/arm_common/conv_bias/f16/direct_nchw88_kern.h" | ||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
| #include "midout.h" | #include "midout.h" | ||||
| @@ -13,7 +13,7 @@ | |||||
| #include "src/arm_common/conv_bias/f16/direct_nchw88_kern.h" | #include "src/arm_common/conv_bias/f16/direct_nchw88_kern.h" | ||||
| #include "src/arm_common/conv_bias/opr_impl.h" | #include "src/arm_common/conv_bias/opr_impl.h" | ||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
| #include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
| #include "src/fallback/conv_bias/common.h" | #include "src/fallback/conv_bias/common.h" | ||||
| @@ -11,7 +11,7 @@ | |||||
| */ | */ | ||||
| #include "src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.h" | #include "src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.h" | ||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
| #include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
| #include "src/arm_common/utils.h" | #include "src/arm_common/utils.h" | ||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| @@ -11,7 +11,7 @@ | |||||
| */ | */ | ||||
| #include "src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.h" | #include "src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.h" | ||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
| #include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
| #include "src/arm_common/utils.h" | #include "src/arm_common/utils.h" | ||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| @@ -12,7 +12,7 @@ | |||||
| #include "src/arm_common/conv_bias/fp32/algos.h" | #include "src/arm_common/conv_bias/fp32/algos.h" | ||||
| #include "src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.h" | #include "src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.h" | ||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
| #include "midout.h" | #include "midout.h" | ||||
| @@ -13,7 +13,7 @@ | |||||
| #include "src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.h" | #include "src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.h" | ||||
| #include "src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.h" | #include "src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.h" | ||||
| #include "src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.h" | #include "src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.h" | ||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
| #include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
| #include "src/arm_common/utils.h" | #include "src/arm_common/utils.h" | ||||
| #include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
| @@ -14,7 +14,7 @@ | |||||
| #include "megdnn/arch.h" | #include "megdnn/arch.h" | ||||
| #include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h" | #include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h" | ||||
| #include "src/arm_common/conv_bias/intrinsic_helper.h" | #include "src/arm_common/conv_bias/intrinsic_helper.h" | ||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
| #include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
| #include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| @@ -14,7 +14,7 @@ | |||||
| #include "megdnn/arch.h" | #include "megdnn/arch.h" | ||||
| #include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h" | #include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h" | ||||
| #include "src/arm_common/conv_bias/intrinsic_helper.h" | #include "src/arm_common/conv_bias/intrinsic_helper.h" | ||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
| #include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
| #include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| @@ -14,7 +14,7 @@ | |||||
| #include "src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h" | #include "src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h" | ||||
| #include "src/arm_common/conv_bias/intrinsic_helper.h" | #include "src/arm_common/conv_bias/intrinsic_helper.h" | ||||
| #include "src/arm_common/conv_bias/opr_impl.h" | #include "src/arm_common/conv_bias/opr_impl.h" | ||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
| #include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
| #include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| @@ -15,7 +15,7 @@ | |||||
| #include "src/arm_common/conv_bias/fp32/algos.h" | #include "src/arm_common/conv_bias/fp32/algos.h" | ||||
| #include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h" | #include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h" | ||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
| #include "midout.h" | #include "midout.h" | ||||
| @@ -15,7 +15,7 @@ | |||||
| #include "src/arm_common/conv_bias/fp32/algos.h" | #include "src/arm_common/conv_bias/fp32/algos.h" | ||||
| #include "src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h" | #include "src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h" | ||||
| #include "src/arm_common/conv_bias/fp32/strategy.h" | #include "src/arm_common/conv_bias/fp32/strategy.h" | ||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
| #include "src/common/nchw_nchwxx_valid.h" | #include "src/common/nchw_nchwxx_valid.h" | ||||
| #include "src/common/opr_delegate.h" | #include "src/common/opr_delegate.h" | ||||
| @@ -13,7 +13,7 @@ | |||||
| #include "megdnn/arch.h" | #include "megdnn/arch.h" | ||||
| #include "src/arm_common/conv_bias/intrinsic_helper.h" | #include "src/arm_common/conv_bias/intrinsic_helper.h" | ||||
| #include "src/arm_common/conv_bias/opr_impl.h" | #include "src/arm_common/conv_bias/opr_impl.h" | ||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
| #include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
| #include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| @@ -17,7 +17,7 @@ | |||||
| #include "src/arm_common/conv_bias/int8/stride1_dotprod.h" | #include "src/arm_common/conv_bias/int8/stride1_dotprod.h" | ||||
| #include "src/arm_common/conv_bias/int8/stride2.h" | #include "src/arm_common/conv_bias/int8/stride2.h" | ||||
| #include "src/arm_common/conv_bias/int8/stride2_dotprod.h" | #include "src/arm_common/conv_bias/int8/stride2_dotprod.h" | ||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
| #include "src/fallback/conv_bias/common.h" | #include "src/fallback/conv_bias/common.h" | ||||
| #include "midout.h" | #include "midout.h" | ||||
| @@ -11,7 +11,7 @@ | |||||
| */ | */ | ||||
| #include "src/arm_common/conv_bias/int8/channel_wise_kernel.h" | #include "src/arm_common/conv_bias/int8/channel_wise_kernel.h" | ||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
| #include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
| #include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| @@ -12,7 +12,7 @@ | |||||
| #include "src/arm_common/conv_bias/int8/channel_wise_nchw44.h" | #include "src/arm_common/conv_bias/int8/channel_wise_nchw44.h" | ||||
| #include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
| #include "src/arm_common/conv_bias/int8/channel_wise_kernel.h" | #include "src/arm_common/conv_bias/int8/channel_wise_kernel.h" | ||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
| #include "src/common/opr_delegate.h" | #include "src/common/opr_delegate.h" | ||||
| #include "midout.h" | #include "midout.h" | ||||
| @@ -10,7 +10,7 @@ | |||||
| */ | */ | ||||
| #include "src/arm_common/conv_bias/int8/direct.h" | #include "src/arm_common/conv_bias/int8/direct.h" | ||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
| #include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include "src/fallback/conv_bias/common.h" | #include "src/fallback/conv_bias/common.h" | ||||
| @@ -11,7 +11,7 @@ | |||||
| #include "src/arm_common/conv_bias/int8/direct_dotprod.h" | #include "src/arm_common/conv_bias/int8/direct_dotprod.h" | ||||
| #if MGB_ENABLE_DOT | #if MGB_ENABLE_DOT | ||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
| #include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include "src/fallback/conv_bias/common.h" | #include "src/fallback/conv_bias/common.h" | ||||
| @@ -14,7 +14,7 @@ | |||||
| #if MGB_ENABLE_DOT | #if MGB_ENABLE_DOT | ||||
| #include "src/arm_common/conv_bias/int8/algos.h" | #include "src/arm_common/conv_bias/int8/algos.h" | ||||
| #include "src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h" | #include "src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h" | ||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
| #include "midout.h" | #include "midout.h" | ||||
| @@ -14,7 +14,7 @@ | |||||
| #include "megdnn/arch.h" | #include "megdnn/arch.h" | ||||
| #include "src/arm_common/conv_bias/intrinsic_helper.h" | #include "src/arm_common/conv_bias/intrinsic_helper.h" | ||||
| #if MGB_ENABLE_DOT | #if MGB_ENABLE_DOT | ||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
| #include "src/arm_common/intrinsic_helper.h" | #include "src/arm_common/intrinsic_helper.h" | ||||
| #include "src/arm_common/neon_struct.h" | #include "src/arm_common/neon_struct.h" | ||||
| #include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
| @@ -13,7 +13,7 @@ | |||||
| #include "src/arm_common/conv_bias/int8/direct.h" | #include "src/arm_common/conv_bias/int8/direct.h" | ||||
| #include "src/arm_common/conv_bias/int8/direct_nchw44_kern.h" | #include "src/arm_common/conv_bias/int8/direct_nchw44_kern.h" | ||||
| #include "src/arm_common/conv_bias/intrinsic_helper.h" | #include "src/arm_common/conv_bias/intrinsic_helper.h" | ||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
| #include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include "src/fallback/conv_bias/common.h" | #include "src/fallback/conv_bias/common.h" | ||||
| @@ -14,7 +14,7 @@ | |||||
| #include "src/arm_common/conv_bias/int8/direct.h" | #include "src/arm_common/conv_bias/int8/direct.h" | ||||
| #include "src/arm_common/conv_bias/int8/direct_nchw44_kern.h" | #include "src/arm_common/conv_bias/int8/direct_nchw44_kern.h" | ||||
| #include "src/arm_common/conv_bias/intrinsic_helper.h" | #include "src/arm_common/conv_bias/intrinsic_helper.h" | ||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
| #include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include "src/fallback/conv_bias/common.h" | #include "src/fallback/conv_bias/common.h" | ||||
| @@ -14,7 +14,7 @@ | |||||
| #include "src/arm_common/conv_bias/int8/algos.h" | #include "src/arm_common/conv_bias/int8/algos.h" | ||||
| #include "src/arm_common/conv_bias/int8/direct.h" | #include "src/arm_common/conv_bias/int8/direct.h" | ||||
| #include "src/arm_common/conv_bias/int8/direct_nchw44_kern.h" | #include "src/arm_common/conv_bias/int8/direct_nchw44_kern.h" | ||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
| #include "src/common/opr_delegate.h" | #include "src/common/opr_delegate.h" | ||||
| #include "midout.h" | #include "midout.h" | ||||
| @@ -12,7 +12,7 @@ | |||||
| #pragma once | #pragma once | ||||
| #include "src/arm_common/conv_bias/int8/direct.h" | #include "src/arm_common/conv_bias/int8/direct.h" | ||||
| #include "src/arm_common/conv_bias/intrinsic_helper.h" | #include "src/arm_common/conv_bias/intrinsic_helper.h" | ||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
| #include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include "src/fallback/conv_bias/common.h" | #include "src/fallback/conv_bias/common.h" | ||||
| @@ -14,7 +14,7 @@ | |||||
| #include "src/arm_common/conv_bias/int8/algos.h" | #include "src/arm_common/conv_bias/int8/algos.h" | ||||
| #include "src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h" | #include "src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h" | ||||
| #include "src/arm_common/conv_bias/int8/strategy.h" | #include "src/arm_common/conv_bias/int8/strategy.h" | ||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
| #include "src/common/nchw_nchwxx_valid.h" | #include "src/common/nchw_nchwxx_valid.h" | ||||
| #include "src/common/opr_delegate.h" | #include "src/common/opr_delegate.h" | ||||
| @@ -13,7 +13,7 @@ | |||||
| #include "megdnn/arch.h" | #include "megdnn/arch.h" | ||||
| #include "src/arm_common/conv_bias/intrinsic_helper.h" | #include "src/arm_common/conv_bias/intrinsic_helper.h" | ||||
| #include "src/arm_common/conv_bias/opr_impl.h" | #include "src/arm_common/conv_bias/opr_impl.h" | ||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
| #include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
| #include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| @@ -15,7 +15,7 @@ | |||||
| #include "src/arm_common/conv_bias/block_helper.h" | #include "src/arm_common/conv_bias/block_helper.h" | ||||
| #include "src/arm_common/conv_bias/int8/algos.h" | #include "src/arm_common/conv_bias/int8/algos.h" | ||||
| #include "src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h" | #include "src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h" | ||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
| #include "src/common/nchw_nchwxx_valid.h" | #include "src/common/nchw_nchwxx_valid.h" | ||||
| #include "midout.h" | #include "midout.h" | ||||
| @@ -14,7 +14,7 @@ | |||||
| #include "src/arm_common/conv_bias/intrinsic_helper.h" | #include "src/arm_common/conv_bias/intrinsic_helper.h" | ||||
| #if MGB_ENABLE_DOT | #if MGB_ENABLE_DOT | ||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
| #include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
| #include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| @@ -14,7 +14,7 @@ | |||||
| #include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
| #include "src/arm_common/conv_bias/int8/direct.h" | #include "src/arm_common/conv_bias/int8/direct.h" | ||||
| #include "src/arm_common/conv_bias/int8/strategy.h" | #include "src/arm_common/conv_bias/int8/strategy.h" | ||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
| #include "src/common/opr_delegate.h" | #include "src/common/opr_delegate.h" | ||||
| using namespace megdnn; | using namespace megdnn; | ||||
| @@ -14,7 +14,7 @@ | |||||
| #include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
| #include "src/arm_common/conv_bias/int8/direct_dotprod.h" | #include "src/arm_common/conv_bias/int8/direct_dotprod.h" | ||||
| #include "src/arm_common/conv_bias/int8/strategy.h" | #include "src/arm_common/conv_bias/int8/strategy.h" | ||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
| #include "src/common/opr_delegate.h" | #include "src/common/opr_delegate.h" | ||||
| using namespace megdnn; | using namespace megdnn; | ||||
| @@ -14,7 +14,7 @@ | |||||
| #include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
| #include "src/arm_common/conv_bias/int8/direct.h" | #include "src/arm_common/conv_bias/int8/direct.h" | ||||
| #include "src/arm_common/conv_bias/int8/strategy.h" | #include "src/arm_common/conv_bias/int8/strategy.h" | ||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
| #include "src/common/opr_delegate.h" | #include "src/common/opr_delegate.h" | ||||
| using namespace megdnn; | using namespace megdnn; | ||||
| @@ -14,7 +14,7 @@ | |||||
| #include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
| #include "src/arm_common/conv_bias/int8/direct_dotprod.h" | #include "src/arm_common/conv_bias/int8/direct_dotprod.h" | ||||
| #include "src/arm_common/conv_bias/int8/strategy.h" | #include "src/arm_common/conv_bias/int8/strategy.h" | ||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
| #include "src/common/opr_delegate.h" | #include "src/common/opr_delegate.h" | ||||
| using namespace megdnn; | using namespace megdnn; | ||||
| @@ -11,7 +11,7 @@ | |||||
| */ | */ | ||||
| #include "src/arm_common/conv_bias/int8x8x16/channel_wise_kernel.h" | #include "src/arm_common/conv_bias/int8x8x16/channel_wise_kernel.h" | ||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
| #include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
| #include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| @@ -15,7 +15,7 @@ | |||||
| #include "src/arm_common/conv_bias/block_helper.h" | #include "src/arm_common/conv_bias/block_helper.h" | ||||
| #include "src/arm_common/conv_bias/int8x8x16/algos.h" | #include "src/arm_common/conv_bias/int8x8x16/algos.h" | ||||
| #include "src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_kern.h" | #include "src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_kern.h" | ||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
| #include "src/common/nchw_nchwxx_valid.h" | #include "src/common/nchw_nchwxx_valid.h" | ||||
| #include "src/common/opr_delegate.h" | #include "src/common/opr_delegate.h" | ||||
| @@ -13,7 +13,7 @@ | |||||
| #include "megdnn/arch.h" | #include "megdnn/arch.h" | ||||
| #include "src/arm_common/conv_bias/intrinsic_helper.h" | #include "src/arm_common/conv_bias/intrinsic_helper.h" | ||||
| #include "src/arm_common/conv_bias/opr_impl.h" | #include "src/arm_common/conv_bias/opr_impl.h" | ||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
| #include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
| #include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| @@ -16,7 +16,7 @@ | |||||
| #include "src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_kern.h" | #include "src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_kern.h" | ||||
| #include "src/arm_common/conv_bias/intrinsic_helper.h" | #include "src/arm_common/conv_bias/intrinsic_helper.h" | ||||
| #include "src/arm_common/conv_bias/opr_impl.h" | #include "src/arm_common/conv_bias/opr_impl.h" | ||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
| #include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
| #include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| @@ -12,7 +12,7 @@ | |||||
| #include "megdnn/dtype.h" | #include "megdnn/dtype.h" | ||||
| #include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
| #include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include "src/fallback/conv_bias/common.h" | #include "src/fallback/conv_bias/common.h" | ||||
| @@ -13,8 +13,8 @@ | |||||
| #pragma once | #pragma once | ||||
| #include "megdnn/basic_types.h" | #include "megdnn/basic_types.h" | ||||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
| #include "src/arm_common/elemwise_helper/kimpl/op_base.h" | #include "src/arm_common/elemwise_helper/kimpl/op_base.h" | ||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/fallback/conv_bias/opr_impl.h" | #include "src/fallback/conv_bias/opr_impl.h" | ||||
| #include "midout.h" | #include "midout.h" | ||||
| @@ -44,29 +44,29 @@ namespace { | |||||
| break; | break; | ||||
| #define FOR_NONLINEAR_UNARY(_op) \ | #define FOR_NONLINEAR_UNARY(_op) \ | ||||
| megdnn::arm_common::OpCallerUnary<_op<ctype>, megdnn::VEC>::run( \ | |||||
| megdnn::elemwise::OpCallerUnary<_op<ctype>, megdnn::elemwise::VEC>::run( \ | |||||
| static_cast<ctype*>(conv_dst_ptr), reinterpret_cast<ctype*>(dst_ptr), \ | static_cast<ctype*>(conv_dst_ptr), reinterpret_cast<ctype*>(dst_ptr), \ | ||||
| bias_type, dst_type, N* OC* OH* OW* pack_oc_size); | bias_type, dst_type, N* OC* OH* OW* pack_oc_size); | ||||
| #define FOR_NONLINEAR_BINARY_BROADCAST(_op) \ | |||||
| megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::VEC_BCAST101>::run( \ | |||||
| static_cast<ctype*>(conv_dst_ptr), \ | |||||
| reinterpret_cast<const ctype*>(bias_ptr), \ | |||||
| reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, OC, \ | |||||
| #define FOR_NONLINEAR_BINARY_BROADCAST(_op) \ | |||||
| megdnn::elemwise::OpCallerBinary<_op<ctype>, megdnn::elemwise::VEC_BCAST101>::run( \ | |||||
| static_cast<ctype*>(conv_dst_ptr), \ | |||||
| reinterpret_cast<const ctype*>(bias_ptr), \ | |||||
| reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, OC, \ | |||||
| OH* OW); | OH* OW); | ||||
| #define FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX(_op) \ | #define FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX(_op) \ | ||||
| megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::VEC_BCAST101xX>::run( \ | |||||
| static_cast<ctype*>(conv_dst_ptr), \ | |||||
| reinterpret_cast<const ctype*>(bias_ptr), \ | |||||
| reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, OC, \ | |||||
| OH* OW, pack_oc_size); | |||||
| #define FOR_NONLINEAR_BINARY(_op) \ | |||||
| megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::VEC_VEC>::run( \ | |||||
| static_cast<ctype*>(conv_dst_ptr), \ | |||||
| reinterpret_cast<const ctype*>(bias_ptr), \ | |||||
| reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, \ | |||||
| megdnn::elemwise::OpCallerBinary<_op<ctype>, megdnn::elemwise::VEC_BCAST101xX>:: \ | |||||
| run(static_cast<ctype*>(conv_dst_ptr), \ | |||||
| reinterpret_cast<const ctype*>(bias_ptr), \ | |||||
| reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, \ | |||||
| OC, OH* OW, pack_oc_size); | |||||
| #define FOR_NONLINEAR_BINARY(_op) \ | |||||
| megdnn::elemwise::OpCallerBinary<_op<ctype>, megdnn::elemwise::VEC_VEC>::run( \ | |||||
| static_cast<ctype*>(conv_dst_ptr), \ | |||||
| reinterpret_cast<const ctype*>(bias_ptr), \ | |||||
| reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, \ | |||||
| N* OC* OH* OW* pack_oc_size); | N* OC* OH* OW* pack_oc_size); | ||||
| #define FOR_BIAS(_mode) \ | #define FOR_BIAS(_mode) \ | ||||
| @@ -167,33 +167,35 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> { | |||||
| #undef FOR_BIAS | #undef FOR_BIAS | ||||
| #undef HANDLE_IDENTITY | #undef HANDLE_IDENTITY | ||||
| #define FOR_NONLINEAR_UNARY(_op) \ | |||||
| megdnn::arm_common::OpCallerUnary<_op<opctype, opdtype>, megdnn::VEC>::run( \ | |||||
| static_cast<opctype*>(conv_dst_ptr), reinterpret_cast<opdtype*>(dst_ptr), \ | |||||
| bias_type, dst_type, N* OC* OH* OW* pack_oc_size); | |||||
| #define FOR_NONLINEAR_BINARY_BROADCAST(_op) \ | |||||
| megdnn::arm_common::OpCallerBinary<_op<opctype, opdtype>, megdnn::VEC_BCAST101>:: \ | |||||
| run(static_cast<opctype*>(conv_dst_ptr), \ | |||||
| reinterpret_cast<const opctype*>(bias_ptr), \ | |||||
| reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, dst_type, \ | |||||
| #define FOR_NONLINEAR_UNARY(_op) \ | |||||
| megdnn::elemwise::OpCallerUnary<_op<opctype, opdtype>, megdnn::elemwise::VEC>:: \ | |||||
| run(static_cast<opctype*>(conv_dst_ptr), \ | |||||
| reinterpret_cast<opdtype*>(dst_ptr), bias_type, dst_type, \ | |||||
| N* OC* OH* OW* pack_oc_size); | |||||
| #define FOR_NONLINEAR_BINARY_BROADCAST(_op) \ | |||||
| megdnn::elemwise::OpCallerBinary< \ | |||||
| _op<opctype, opdtype>, megdnn::elemwise::VEC_BCAST101>:: \ | |||||
| run(static_cast<opctype*>(conv_dst_ptr), \ | |||||
| reinterpret_cast<const opctype*>(bias_ptr), \ | |||||
| reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, dst_type, \ | |||||
| N, OC, OH* OW); | N, OC, OH* OW); | ||||
| #define FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX(_op) \ | |||||
| megdnn::arm_common:: \ | |||||
| OpCallerBinary<_op<opctype, opdtype>, megdnn::VEC_BCAST101xX>::run( \ | |||||
| static_cast<opctype*>(conv_dst_ptr), \ | |||||
| reinterpret_cast<const opctype*>(bias_ptr), \ | |||||
| reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, \ | |||||
| dst_type, N, OC, OH* OW, pack_oc_size); | |||||
| #define FOR_NONLINEAR_BINARY_BROADCAST_NCHW88(_op) \ | |||||
| megdnn::arm_common:: \ | |||||
| OpCallerBinary<_op<opctype, opdtype>, megdnn::VEC_BCAST101xX>::run( \ | |||||
| static_cast<opctype*>(conv_dst_ptr), \ | |||||
| reinterpret_cast<const opctype*>(bias_ptr), \ | |||||
| reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, \ | |||||
| dst_type, N, OC, OH* OW, pack_oc_size); | |||||
| #define FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX(_op) \ | |||||
| megdnn::elemwise::OpCallerBinary< \ | |||||
| _op<opctype, opdtype>, megdnn::elemwise::VEC_BCAST101xX>:: \ | |||||
| run(static_cast<opctype*>(conv_dst_ptr), \ | |||||
| reinterpret_cast<const opctype*>(bias_ptr), \ | |||||
| reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, dst_type, \ | |||||
| N, OC, OH* OW, pack_oc_size); | |||||
| #define FOR_NONLINEAR_BINARY_BROADCAST_NCHW88(_op) \ | |||||
| megdnn::elemwise::OpCallerBinary< \ | |||||
| _op<opctype, opdtype>, megdnn::elemwise::VEC_BCAST101xX>:: \ | |||||
| run(static_cast<opctype*>(conv_dst_ptr), \ | |||||
| reinterpret_cast<const opctype*>(bias_ptr), \ | |||||
| reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, dst_type, \ | |||||
| N, OC, OH* OW, pack_oc_size); | |||||
| #define HANDLE_IDENTITY(_caller, _op) \ | #define HANDLE_IDENTITY(_caller, _op) \ | ||||
| case megdnn::NonlineMode::IDENTITY: \ | case megdnn::NonlineMode::IDENTITY: \ | ||||
| @@ -267,25 +269,25 @@ struct PostProcess<opctype, opdtype, megdnn::PostprocessMode::QUANTIZED> { | |||||
| #undef FOR_NONLINEAR | #undef FOR_NONLINEAR | ||||
| #undef FOR_BIAS | #undef FOR_BIAS | ||||
| #define FOR_BINARY_BROADCAST(_op) \ | |||||
| megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::VEC_BCAST101>::run( \ | |||||
| static_cast<ctype*>(conv_dst_ptr), \ | |||||
| reinterpret_cast<const ctype*>(bias_ptr), \ | |||||
| reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, OC, \ | |||||
| #define FOR_BINARY_BROADCAST(_op) \ | |||||
| megdnn::elemwise::OpCallerBinary<_op<ctype>, megdnn::elemwise::VEC_BCAST101>::run( \ | |||||
| static_cast<ctype*>(conv_dst_ptr), \ | |||||
| reinterpret_cast<const ctype*>(bias_ptr), \ | |||||
| reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, OC, \ | |||||
| OH* OW); | OH* OW); | ||||
| #define FOR_BINARY_BROADCAST_NCHWXX(_op) \ | #define FOR_BINARY_BROADCAST_NCHWXX(_op) \ | ||||
| megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::VEC_BCAST101xX>::run( \ | |||||
| static_cast<ctype*>(conv_dst_ptr), \ | |||||
| reinterpret_cast<const ctype*>(bias_ptr), \ | |||||
| reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, OC, \ | |||||
| OH* OW, pack_oc_size); | |||||
| #define FOR_BINARY(_op) \ | |||||
| megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::VEC_VEC>::run( \ | |||||
| static_cast<ctype*>(conv_dst_ptr), \ | |||||
| reinterpret_cast<const ctype*>(bias_ptr), \ | |||||
| reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, \ | |||||
| megdnn::elemwise::OpCallerBinary<_op<ctype>, megdnn::elemwise::VEC_BCAST101xX>:: \ | |||||
| run(static_cast<ctype*>(conv_dst_ptr), \ | |||||
| reinterpret_cast<const ctype*>(bias_ptr), \ | |||||
| reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, \ | |||||
| OC, OH* OW, pack_oc_size); | |||||
| #define FOR_BINARY(_op) \ | |||||
| megdnn::elemwise::OpCallerBinary<_op<ctype>, megdnn::elemwise::VEC_VEC>::run( \ | |||||
| static_cast<ctype*>(conv_dst_ptr), \ | |||||
| reinterpret_cast<const ctype*>(bias_ptr), \ | |||||
| reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, \ | |||||
| N* OC* OH* OW* pack_oc_size); | N* OC* OH* OW* pack_oc_size); | ||||
| #define FOR_BIAS(_bias_mode, OH, OW) \ | #define FOR_BIAS(_bias_mode, OH, OW) \ | ||||
| @@ -15,7 +15,7 @@ | |||||
| #include "src/arm_common/conv_bias/quint8/stride1_dotprod.h" | #include "src/arm_common/conv_bias/quint8/stride1_dotprod.h" | ||||
| #include "src/arm_common/conv_bias/quint8/stride2.h" | #include "src/arm_common/conv_bias/quint8/stride2.h" | ||||
| #include "src/arm_common/conv_bias/quint8/stride2_dotprod.h" | #include "src/arm_common/conv_bias/quint8/stride2_dotprod.h" | ||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
| #include "src/fallback/conv_bias/common.h" | #include "src/fallback/conv_bias/common.h" | ||||
| #include "midout.h" | #include "midout.h" | ||||
| @@ -10,7 +10,7 @@ | |||||
| */ | */ | ||||
| #include "src/arm_common/conv_bias/quint8/direct.h" | #include "src/arm_common/conv_bias/quint8/direct.h" | ||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
| #include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include "src/fallback/conv_bias/common.h" | #include "src/fallback/conv_bias/common.h" | ||||
| @@ -11,7 +11,7 @@ | |||||
| #include "src/arm_common/conv_bias/quint8/direct_dotprod.h" | #include "src/arm_common/conv_bias/quint8/direct_dotprod.h" | ||||
| #if MGB_ENABLE_DOT | #if MGB_ENABLE_DOT | ||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
| #include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include "src/fallback/conv_bias/common.h" | #include "src/fallback/conv_bias/common.h" | ||||
| @@ -12,7 +12,7 @@ | |||||
| #include "src/arm_common/conv_bias/quint8/stride1.h" | #include "src/arm_common/conv_bias/quint8/stride1.h" | ||||
| #include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
| #include "src/arm_common/conv_bias/quint8/direct.h" | #include "src/arm_common/conv_bias/quint8/direct.h" | ||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
| #include "src/common/opr_delegate.h" | #include "src/common/opr_delegate.h" | ||||
| using namespace megdnn; | using namespace megdnn; | ||||
| @@ -12,7 +12,7 @@ | |||||
| #if MGB_ENABLE_DOT | #if MGB_ENABLE_DOT | ||||
| #include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
| #include "src/arm_common/conv_bias/quint8/direct_dotprod.h" | #include "src/arm_common/conv_bias/quint8/direct_dotprod.h" | ||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
| #include "src/common/opr_delegate.h" | #include "src/common/opr_delegate.h" | ||||
| using namespace megdnn; | using namespace megdnn; | ||||
| @@ -12,7 +12,7 @@ | |||||
| #include "src/arm_common/conv_bias/quint8/stride2.h" | #include "src/arm_common/conv_bias/quint8/stride2.h" | ||||
| #include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
| #include "src/arm_common/conv_bias/quint8/direct.h" | #include "src/arm_common/conv_bias/quint8/direct.h" | ||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
| #include "src/common/opr_delegate.h" | #include "src/common/opr_delegate.h" | ||||
| using namespace megdnn; | using namespace megdnn; | ||||
| @@ -12,7 +12,7 @@ | |||||
| #if MGB_ENABLE_DOT | #if MGB_ENABLE_DOT | ||||
| #include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
| #include "src/arm_common/conv_bias/quint8/direct_dotprod.h" | #include "src/arm_common/conv_bias/quint8/direct_dotprod.h" | ||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
| #include "src/common/opr_delegate.h" | #include "src/common/opr_delegate.h" | ||||
| using namespace megdnn; | using namespace megdnn; | ||||
| @@ -10,7 +10,7 @@ | |||||
| * implied. | * implied. | ||||
| */ | */ | ||||
| #include "src/arm_common/elemwise/binary/algo.h" | #include "src/arm_common/elemwise/binary/algo.h" | ||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include "src/naive/handle.h" | #include "src/naive/handle.h" | ||||
| @@ -20,6 +20,7 @@ | |||||
| MIDOUT_DECL(megdnn_arm_common_elemwise_binary) | MIDOUT_DECL(megdnn_arm_common_elemwise_binary) | ||||
| using namespace megdnn; | using namespace megdnn; | ||||
| using namespace elemwise; | |||||
| using namespace arm_common; | using namespace arm_common; | ||||
| namespace { | namespace { | ||||
| @@ -160,7 +161,7 @@ bool ElemwiseImpl::AlgoBinaryVecBcast101xX::is_available( | |||||
| DISPATCH_BINARY(ADD, _case, _type, _type_midout_id, AddOp); \ | DISPATCH_BINARY(ADD, _case, _type, _type_midout_id, AddOp); \ | ||||
| DISPATCH_BINARY(SUB, _case, _type, _type_midout_id, SubOp); \ | DISPATCH_BINARY(SUB, _case, _type, _type_midout_id, SubOp); \ | ||||
| DISPATCH_BINARY(MUL, _case, _type, _type_midout_id, MulOp); \ | DISPATCH_BINARY(MUL, _case, _type, _type_midout_id, MulOp); \ | ||||
| DISPATCH_BINARY(POW, _case, _type, _type_midout_id, PowOp); \ | |||||
| DISPATCH_BINARY(POW, _case, _type, _type_midout_id, fallback::PowOp); \ | |||||
| DISPATCH_BINARY(TRUE_DIV, _case, _type, _type_midout_id, TrueDivOp); \ | DISPATCH_BINARY(TRUE_DIV, _case, _type, _type_midout_id, TrueDivOp); \ | ||||
| DISPATCH_BINARY(FUSE_ADD_RELU, _case, _type, _type_midout_id, FuseAddReluOp); \ | DISPATCH_BINARY(FUSE_ADD_RELU, _case, _type, _type_midout_id, FuseAddReluOp); \ | ||||
| DISPATCH_BINARY( \ | DISPATCH_BINARY( \ | ||||
| @@ -178,7 +179,7 @@ bool ElemwiseImpl::AlgoBinaryVecBcast101xX::is_available( | |||||
| DISPATCH_BINARY(ADD, _case, _type, _type_midout_id, AddOp); \ | DISPATCH_BINARY(ADD, _case, _type, _type_midout_id, AddOp); \ | ||||
| DISPATCH_BINARY(SUB, _case, _type, _type_midout_id, SubOp); \ | DISPATCH_BINARY(SUB, _case, _type, _type_midout_id, SubOp); \ | ||||
| DISPATCH_BINARY(MUL, _case, _type, _type_midout_id, MulOp); \ | DISPATCH_BINARY(MUL, _case, _type, _type_midout_id, MulOp); \ | ||||
| DISPATCH_BINARY(POW, _case, _type, _type_midout_id, PowOp); \ | |||||
| DISPATCH_BINARY(POW, _case, _type, _type_midout_id, fallback::PowOp); \ | |||||
| DISPATCH_BINARY(FUSE_ADD_RELU, _case, _type, _type_midout_id, FuseAddReluOp); \ | DISPATCH_BINARY(FUSE_ADD_RELU, _case, _type, _type_midout_id, FuseAddReluOp); \ | ||||
| DISPATCH_BINARY( \ | DISPATCH_BINARY( \ | ||||
| FUSE_ADD_H_SWISH, _case, _type, _type_midout_id, FuseAddHSwishOp); \ | FUSE_ADD_H_SWISH, _case, _type, _type_midout_id, FuseAddHSwishOp); \ | ||||
| @@ -13,7 +13,7 @@ | |||||
| #include "src/arm_common/elemwise/binary/algo.h" | #include "src/arm_common/elemwise/binary/algo.h" | ||||
| #include "src/arm_common/elemwise/ternary/algo.h" | #include "src/arm_common/elemwise/ternary/algo.h" | ||||
| #include "src/arm_common/elemwise/unary/algo.h" | #include "src/arm_common/elemwise/unary/algo.h" | ||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
| #include "src/common/metahelper.h" | #include "src/common/metahelper.h" | ||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| @@ -12,7 +12,7 @@ | |||||
| #pragma once | #pragma once | ||||
| #include "src/fallback/elemwise/opr_impl.h" | #include "src/fallback/elemwise/opr_impl.h" | ||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace arm_common { | namespace arm_common { | ||||
| @@ -10,7 +10,7 @@ | |||||
| * implied. | * implied. | ||||
| */ | */ | ||||
| #include "src/arm_common/elemwise/ternary/algo.h" | #include "src/arm_common/elemwise/ternary/algo.h" | ||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include "src/naive/handle.h" | #include "src/naive/handle.h" | ||||
| @@ -20,6 +20,7 @@ | |||||
| MIDOUT_DECL(megdnn_arm_common_elemwise_ternary) | MIDOUT_DECL(megdnn_arm_common_elemwise_ternary) | ||||
| using namespace megdnn; | using namespace megdnn; | ||||
| using namespace elemwise; | |||||
| using namespace arm_common; | using namespace arm_common; | ||||
| #define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \ | #define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \ | ||||
| @@ -10,7 +10,7 @@ | |||||
| * implied. | * implied. | ||||
| */ | */ | ||||
| #include "src/arm_common/elemwise/unary/algo.h" | #include "src/arm_common/elemwise/unary/algo.h" | ||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include "src/naive/handle.h" | #include "src/naive/handle.h" | ||||
| @@ -20,6 +20,7 @@ | |||||
| MIDOUT_DECL(megdnn_arm_common_elemwise_unary) | MIDOUT_DECL(megdnn_arm_common_elemwise_unary) | ||||
| using namespace megdnn; | using namespace megdnn; | ||||
| using namespace elemwise; | |||||
| using namespace arm_common; | using namespace arm_common; | ||||
| bool ElemwiseImpl::AlgoUnary::is_available(const KernParam& kern_param) const { | bool ElemwiseImpl::AlgoUnary::is_available(const KernParam& kern_param) const { | ||||
| @@ -0,0 +1,151 @@ | |||||
| /** | |||||
| * \file dnn/src/arm_common/elemwise_helper/elemwise_op.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include "src/arm_common/elemwise_helper/op_binary.h" | |||||
| #include "src/arm_common/elemwise_helper/op_ternary.h" | |||||
| #include "src/arm_common/elemwise_helper/op_unary.h" | |||||
| #include "src/fallback/elemwise_helper/elemwise_op.h" | |||||
| namespace megdnn { | |||||
| namespace elemwise { | |||||
| using BcastType = megdnn::elemwise::BcastType; | |||||
| ///////////////////////////////// ParamElemVisitor /////////////////////////// | |||||
| #define cb(_ctype, _inner_ctype, _neon_type, _fun_suffix) \ | |||||
| template <> \ | |||||
| struct ParamElemVisitor<_ctype> { \ | |||||
| _neon_type operator()(const _ctype* src) const { \ | |||||
| return vld1q_##_fun_suffix(reinterpret_cast<const _inner_ctype*>(src)); \ | |||||
| } \ | |||||
| }; \ | |||||
| template <> \ | |||||
| struct ParamElemVisitorDup<_ctype> { \ | |||||
| _neon_type operator()(const _ctype* src) const { \ | |||||
| return vdupq_n_##_fun_suffix(*reinterpret_cast<const _inner_ctype*>(src)); \ | |||||
| } \ | |||||
| } | |||||
| cb(dt_quint8, uint8_t, uint8x16_t, u8); | |||||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
| cb(__fp16, __fp16, float16x8_t, f16); | |||||
| #endif | |||||
| cb(dt_int16, int16_t, int16x8_t, s16); | |||||
| #undef cb | |||||
| template <typename ctype> | |||||
| struct ParamElemVisitorBcast101x4; | |||||
| #define cb(_ctype, _inner_ctype, _neon_type, _fun_suffix, rel_suffix) \ | |||||
| template <> \ | |||||
| struct ParamElemVisitorBcast101x4<_ctype> { \ | |||||
| _neon_type operator()(const _ctype* src) const { \ | |||||
| return vreinterpretq_##_fun_suffix##_##rel_suffix(vld1q_dup_##rel_suffix( \ | |||||
| reinterpret_cast<const _inner_ctype*>(src))); \ | |||||
| } \ | |||||
| } | |||||
| cb(dt_quint8, uint32_t, uint8x16_t, u8, u32); | |||||
| cb(dt_int16, int64_t, int16x8_t, s16, s64); | |||||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
| cb(__fp16, uint64_t, float16x8_t, f16, u64); | |||||
| #endif | |||||
| #undef cb | |||||
| template <typename ctype> | |||||
| struct ParamElemVisitorBcast101x8; | |||||
| #define cb(_ctype, _inner_ctype, _neon_type, _fun_suffix) \ | |||||
| template <> \ | |||||
| struct ParamElemVisitorBcast101x8<_ctype> { \ | |||||
| _neon_type operator()(const _ctype* src) const { \ | |||||
| return vld1q_##_fun_suffix(reinterpret_cast<const _inner_ctype*>(src)); \ | |||||
| } \ | |||||
| } | |||||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
| cb(__fp16, __fp16, float16x8_t, f16); | |||||
| #endif | |||||
| #undef cb | |||||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
| template <> | |||||
| struct OpCallerBinaryBcast101xXVec<__fp16, 8> { | |||||
| using src_ctype = __fp16; | |||||
| template <typename Op> | |||||
| static void run( | |||||
| const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst, | |||||
| const Op& op, size_t batch, size_t nr_channel_blocks, | |||||
| size_t channel_stride) { | |||||
| ParamElemVisitorBcast101x8<src_ctype> vis0; | |||||
| ParamElemVisitor<src_ctype> vis1; | |||||
| OpCallerBinaryBcast101xDVec<src_ctype, 8>::run( | |||||
| src0, src1, dst, op, vis0, vis1, batch, nr_channel_blocks, | |||||
| channel_stride); | |||||
| } | |||||
| }; | |||||
| template <> | |||||
| struct OpCallerBinaryVecBcast101xX<__fp16, 8> { | |||||
| using src_ctype = __fp16; | |||||
| template <typename Op> | |||||
| static void run( | |||||
| const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst, | |||||
| const Op& op, size_t batch, size_t nr_channel_blocks, | |||||
| size_t channel_stride) { | |||||
| ParamElemVisitor<src_ctype> vis0; | |||||
| ParamElemVisitorBcast101x8<src_ctype> vis1; | |||||
| OpCallerBinaryVecBcast101xD<src_ctype, 8>::run( | |||||
| src0, src1, dst, op, vis0, vis1, batch, nr_channel_blocks, | |||||
| channel_stride); | |||||
| } | |||||
| }; | |||||
| template <> | |||||
| struct OpCallerTernaryBcast101xXVecBcast101xX<__fp16, 8> { | |||||
| using src_ctype = __fp16; | |||||
| template <typename Op> | |||||
| static void run( | |||||
| const src_ctype* src0, const src_ctype* src1, const src_ctype* src2, | |||||
| typename Op::dst_ctype* dst, const Op& op, size_t batch, | |||||
| size_t nr_channel_blocks, size_t channel_stride) { | |||||
| ParamElemVisitorBcast101x8<src_ctype> vis0; | |||||
| ParamElemVisitor<src_ctype> vis1; | |||||
| ParamElemVisitorBcast101x8<src_ctype> vis2; | |||||
| OpCallerTernaryBcast101xDVecBcast101xD<src_ctype, 8>::run( | |||||
| src0, src1, src2, dst, op, vis0, vis1, vis2, batch, nr_channel_blocks, | |||||
| channel_stride); | |||||
| } | |||||
| }; | |||||
| template <> | |||||
| struct OpCallerTernaryVecBcast101xXVec<__fp16, 8> { | |||||
| using src_ctype = __fp16; | |||||
| template <typename Op> | |||||
| static void run( | |||||
| const src_ctype* src0, const src_ctype* src1, const src_ctype* src2, | |||||
| typename Op::dst_ctype* dst, const Op& op, size_t batch, | |||||
| size_t nr_channel_blocks, size_t channel_stride) { | |||||
| ParamElemVisitor<src_ctype> vis0; | |||||
| ParamElemVisitorBcast101x8<src_ctype> vis1; | |||||
| ParamElemVisitor<src_ctype> vis2; | |||||
| OpCallerTernaryVecBcast101xDVec<src_ctype, 8>::run( | |||||
| src0, src1, src2, dst, op, vis0, vis1, vis2, batch, nr_channel_blocks, | |||||
| channel_stride); | |||||
| } | |||||
| }; | |||||
| #endif | |||||
| } // namespace elemwise | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -1,36 +0,0 @@ | |||||
| /** | |||||
| * \file dnn/src/arm_common/elemwise_helper/kimpl/pow.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include "src/arm_common/elemwise_helper/kimpl/op_base.h" | |||||
| namespace megdnn { | |||||
| namespace arm_common { | |||||
| // when __fp16 is available, POW is very slow, so it is added here | |||||
| /////////////////////// POW float only //////////////////////////// | |||||
| template <typename src_ctype, typename dst_ctype = src_ctype> | |||||
| struct PowOp : BinaryOpBase<src_ctype, dst_ctype> { | |||||
| using BinaryOpBase<src_ctype, dst_ctype>::BinaryOpBase; | |||||
| constexpr static size_t SIMD_WIDTH = 1; | |||||
| void operator()( | |||||
| const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const { | |||||
| *dst = operator()(src0, src1); | |||||
| } | |||||
| dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const { | |||||
| return powf(src0, src1); | |||||
| } | |||||
| }; | |||||
| } // namespace arm_common | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -18,7 +18,6 @@ | |||||
| #include "src/arm_common/elemwise_helper/kimpl/max.h" | #include "src/arm_common/elemwise_helper/kimpl/max.h" | ||||
| #include "src/arm_common/elemwise_helper/kimpl/min.h" | #include "src/arm_common/elemwise_helper/kimpl/min.h" | ||||
| #include "src/arm_common/elemwise_helper/kimpl/mul.h" | #include "src/arm_common/elemwise_helper/kimpl/mul.h" | ||||
| #include "src/arm_common/elemwise_helper/kimpl/pow.h" | |||||
| #include "src/arm_common/elemwise_helper/kimpl/rmulh.h" | #include "src/arm_common/elemwise_helper/kimpl/rmulh.h" | ||||
| #include "src/arm_common/elemwise_helper/kimpl/sub.h" | #include "src/arm_common/elemwise_helper/kimpl/sub.h" | ||||
| #include "src/arm_common/elemwise_helper/kimpl/true_div.h" | #include "src/arm_common/elemwise_helper/kimpl/true_div.h" | ||||
| @@ -15,7 +15,7 @@ | |||||
| #include "src/common/elemwise_multi_type/kern_defs.cuh" | #include "src/common/elemwise_multi_type/kern_defs.cuh" | ||||
| #include "src/naive/handle.h" | #include "src/naive/handle.h" | ||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
| #include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
| namespace { | namespace { | ||||
| @@ -46,6 +46,8 @@ void neon_round_shr_saturate_int16_static_k( | |||||
| } // namespace | } // namespace | ||||
| using namespace elemwise; | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace arm_common { | namespace arm_common { | ||||
| @@ -2,7 +2,7 @@ | |||||
| * \file dnn/src/fallback/elemwise/gi_impl/binary/algo.cpp | * \file dnn/src/fallback/elemwise/gi_impl/binary/algo.cpp | ||||
| */ | */ | ||||
| #include "src/fallback/elemwise/gi_impl/binary/algo.h" | #include "src/fallback/elemwise/gi_impl/binary/algo.h" | ||||
| #include "src/fallback/elemwise_op.h" | |||||
| #include "src/fallback/elemwise_helper/elemwise_op.h" | |||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include "src/naive/handle.h" | #include "src/naive/handle.h" | ||||
| @@ -12,6 +12,7 @@ | |||||
| MIDOUT_DECL(megdnn_fallback_elemwise_binary) | MIDOUT_DECL(megdnn_fallback_elemwise_binary) | ||||
| using namespace megdnn; | using namespace megdnn; | ||||
| using namespace elemwise; | |||||
| using namespace fallback; | using namespace fallback; | ||||
| namespace { | namespace { | ||||
| @@ -3,7 +3,7 @@ | |||||
| */ | */ | ||||
| #include "src/fallback/elemwise/gi_impl/ternary/algo.h" | #include "src/fallback/elemwise/gi_impl/ternary/algo.h" | ||||
| #include "src/fallback/elemwise_op.h" | |||||
| #include "src/fallback/elemwise_helper/elemwise_op.h" | |||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include "src/naive/handle.h" | #include "src/naive/handle.h" | ||||
| @@ -13,6 +13,7 @@ | |||||
| MIDOUT_DECL(megdnn_fallback_elemwise_ternary) | MIDOUT_DECL(megdnn_fallback_elemwise_ternary) | ||||
| using namespace megdnn; | using namespace megdnn; | ||||
| using namespace elemwise; | |||||
| using namespace fallback; | using namespace fallback; | ||||
| #define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \ | #define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \ | ||||
| @@ -2,7 +2,7 @@ | |||||
| * \file dnn/src/fallback/elemwise/gi_impl/unary/algo.cpp | * \file dnn/src/fallback/elemwise/gi_impl/unary/algo.cpp | ||||
| */ | */ | ||||
| #include "src/fallback/elemwise/gi_impl/unary/algo.h" | #include "src/fallback/elemwise/gi_impl/unary/algo.h" | ||||
| #include "src/fallback/elemwise_op.h" | |||||
| #include "src/fallback/elemwise_helper/elemwise_op.h" | |||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include "src/naive/handle.h" | #include "src/naive/handle.h" | ||||
| @@ -12,6 +12,7 @@ | |||||
| MIDOUT_DECL(megdnn_fallback_elemwise_unary) | MIDOUT_DECL(megdnn_fallback_elemwise_unary) | ||||
| using namespace megdnn; | using namespace megdnn; | ||||
| using namespace elemwise; | |||||
| using namespace fallback; | using namespace fallback; | ||||
| bool ElemwiseImpl::AlgoUnary::is_available(const KernParam& kern_param) const { | bool ElemwiseImpl::AlgoUnary::is_available(const KernParam& kern_param) const { | ||||
| @@ -25,6 +25,7 @@ MIDOUT_DECL(megdnn_fallback_elemwise_exec_BINARY_INT) | |||||
| MIDOUT_DECL(megdnn_fallback_elemwise_exec_BINARY_FLOAT) | MIDOUT_DECL(megdnn_fallback_elemwise_exec_BINARY_FLOAT) | ||||
| using namespace megdnn; | using namespace megdnn; | ||||
| using namespace elemwise; | |||||
| using namespace fallback; | using namespace fallback; | ||||
| void ElemwiseImpl::exec(const TensorNDArray& srcs, _megdnn_tensor_out dst) { | void ElemwiseImpl::exec(const TensorNDArray& srcs, _megdnn_tensor_out dst) { | ||||
| @@ -9,7 +9,7 @@ | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| */ | */ | ||||
| #pragma once | #pragma once | ||||
| #include "src/fallback/elemwise_op.h" | |||||
| #include "src/fallback/elemwise_helper/elemwise_op.h" | |||||
| #include "src/naive/elemwise/opr_impl.h" | #include "src/naive/elemwise/opr_impl.h" | ||||
| namespace megdnn { | namespace megdnn { | ||||
| @@ -60,7 +60,7 @@ private: | |||||
| public: | public: | ||||
| class AlgoBase; | class AlgoBase; | ||||
| struct KernParam { | struct KernParam { | ||||
| BcastType broad_cast_type; | |||||
| elemwise::BcastType broad_cast_type; | |||||
| Mode mode; | Mode mode; | ||||
| const TensorND* m_dst; | const TensorND* m_dst; | ||||
| Handle* handle; | Handle* handle; | ||||
| @@ -0,0 +1,72 @@ | |||||
| /** | |||||
| * \file dnn/src/fallback/elemwise_helper/elemwise_op.h | |||||
| */ | |||||
| #pragma once | |||||
| #include "src/fallback/elemwise_helper/op_binary.h" | |||||
| #include "src/fallback/elemwise_helper/op_common.h" | |||||
| #include "src/fallback/elemwise_helper/op_ternary.h" | |||||
| #include "src/fallback/elemwise_helper/op_unary.h" | |||||
| #include "src/fallback/general_intrinsic/gi_float.h" | |||||
| #include "src/fallback/general_intrinsic/gi_int.h" | |||||
| namespace megdnn { | |||||
| namespace elemwise { | |||||
| ///////////////////////////////// ParamElemVisitor /////////////////////////// | |||||
| #define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix) \ | |||||
| template <> \ | |||||
| struct ParamElemVisitor<_ctype> { \ | |||||
| _simd_type operator()(const _ctype* src) const { \ | |||||
| return GiLoad##_fun_suffix(src); \ | |||||
| } \ | |||||
| }; \ | |||||
| template <> \ | |||||
| struct ParamElemVisitorDup<_ctype> { \ | |||||
| _simd_type operator()(const _ctype* src) const { \ | |||||
| return GiBroadcast##_fun_suffix( \ | |||||
| *reinterpret_cast<const _inner_ctype*>(src)); \ | |||||
| } \ | |||||
| } | |||||
| cb(dt_qint32, int32_t, GI_INT32_t, Int32); | |||||
| cb(dt_qint8, int8_t, GI_INT8_t, Int8); | |||||
| cb(dt_float32, float, GI_FLOAT32_t, Float32); | |||||
| cb(dt_int32, int32_t, GI_INT32_t, Int32); | |||||
| cb(dt_int8, int8_t, GI_INT8_t, Int8); | |||||
| #undef cb | |||||
| template <typename ctype> | |||||
| struct ParamElemVisitorBcast101x4; | |||||
| #define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix, rel_suffix) \ | |||||
| template <> \ | |||||
| struct ParamElemVisitorBcast101x4<_ctype> { \ | |||||
| _simd_type operator()(const _ctype* src) const { \ | |||||
| return GiReinter##rel_suffix##To##_fun_suffix(GiBroadcast##rel_suffix( \ | |||||
| *reinterpret_cast<const _inner_ctype*>(src))); \ | |||||
| } \ | |||||
| } | |||||
| cb(dt_qint8, int32_t, GI_INT8_t, Int8, Int32); | |||||
| cb(dt_int8, int32_t, GI_INT8_t, Int8, Int32); | |||||
| #undef cb | |||||
| #define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix) \ | |||||
| template <> \ | |||||
| struct ParamElemVisitorBcast101x4<_ctype> { \ | |||||
| _simd_type operator()(const _ctype* src) const { \ | |||||
| return GiLoad##_fun_suffix(src); \ | |||||
| } \ | |||||
| } | |||||
| cb(dt_qint32, int32_t, GI_INT32_t, Int32); | |||||
| cb(dt_float32, float, GI_FLOAT32_t, Float32); | |||||
| cb(dt_int32, int32_t, GI_INT32_t, Int32); | |||||
| #undef cb | |||||
| } // namespace elemwise | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -58,7 +58,7 @@ struct AbsOpBase<dt_qint8, dt_qint8> : UnaryOpBase<dt_qint8, dt_qint8> { | |||||
| template <> | template <> | ||||
| struct AbsOp<dt_qint8, dt_qint8> : AbsOpBase<dt_qint8, dt_qint8> { | struct AbsOp<dt_qint8, dt_qint8> : AbsOpBase<dt_qint8, dt_qint8> { | ||||
| using AbsOpBase::AbsOpBase; | using AbsOpBase::AbsOpBase; | ||||
| constexpr static size_t SIMD_WIDTH = 16; | |||||
| constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t); | |||||
| using AbsOpBase::operator(); | using AbsOpBase::operator(); | ||||
| void operator()(const GI_INT8_V2_t& vsrc, dt_qint8* dst) const { | void operator()(const GI_INT8_V2_t& vsrc, dt_qint8* dst) const { | ||||
| OPERATOR_UNARY_QINT8_FALLBACK; | OPERATOR_UNARY_QINT8_FALLBACK; | ||||
| @@ -87,7 +87,7 @@ template <> | |||||
| struct FuseAddHSwishOp<dt_qint32, dt_qint8> : FuseAddHSwishOpBase<dt_qint32, dt_qint8> { | struct FuseAddHSwishOp<dt_qint32, dt_qint8> : FuseAddHSwishOpBase<dt_qint32, dt_qint8> { | ||||
| using FuseAddHSwishOpBase::FuseAddHSwishOpBase; | using FuseAddHSwishOpBase::FuseAddHSwishOpBase; | ||||
| using FuseAddHSwishOpBase::operator(); | using FuseAddHSwishOpBase::operator(); | ||||
| constexpr static size_t SIMD_WIDTH = 4; | |||||
| constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t); | |||||
| void operator()( | void operator()( | ||||
| const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1, | const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1, | ||||
| dt_qint8* dst) const { | dt_qint8* dst) const { | ||||
| @@ -83,7 +83,7 @@ template <> | |||||
| struct HSwishOp<dt_qint32, dt_qint8> : HSwishOpBase<dt_qint32, dt_qint8> { | struct HSwishOp<dt_qint32, dt_qint8> : HSwishOpBase<dt_qint32, dt_qint8> { | ||||
| using HSwishOpBase::HSwishOpBase; | using HSwishOpBase::HSwishOpBase; | ||||
| using HSwishOpBase::operator(); | using HSwishOpBase::operator(); | ||||
| constexpr static size_t SIMD_WIDTH = 4; | |||||
| constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int32_t); | |||||
| void operator()(const GI_INT32_V2_t& vsrc, dt_qint8* dst) const { | void operator()(const GI_INT32_V2_t& vsrc, dt_qint8* dst) const { | ||||
| GiStoreLowInt8(reinterpret_cast<int8_t*>(dst), operator()(vsrc)); | GiStoreLowInt8(reinterpret_cast<int8_t*>(dst), operator()(vsrc)); | ||||
| @@ -77,7 +77,7 @@ struct MaxOpBase<dt_qint8, dt_qint8> : BinaryOpBase<dt_qint8, dt_qint8> { | |||||
| template <> | template <> | ||||
| struct MaxOp<dt_qint8, dt_qint8> : MaxOpBase<dt_qint8, dt_qint8> { | struct MaxOp<dt_qint8, dt_qint8> : MaxOpBase<dt_qint8, dt_qint8> { | ||||
| using MaxOpBase::MaxOpBase; | using MaxOpBase::MaxOpBase; | ||||
| constexpr static size_t SIMD_WIDTH = 16; | |||||
| constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t); | |||||
| using MaxOpBase::operator(); | using MaxOpBase::operator(); | ||||
| void operator()( | void operator()( | ||||
| @@ -74,7 +74,7 @@ struct MinOpBase<dt_qint8, dt_qint8> : BinaryOpBase<dt_qint8, dt_qint8> { | |||||
| template <> | template <> | ||||
| struct MinOp<dt_qint8, dt_qint8> : MinOpBase<dt_qint8, dt_qint8> { | struct MinOp<dt_qint8, dt_qint8> : MinOpBase<dt_qint8, dt_qint8> { | ||||
| using MinOpBase::MinOpBase; | using MinOpBase::MinOpBase; | ||||
| constexpr static size_t SIMD_WIDTH = 16; | |||||
| constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t); | |||||
| using MinOpBase::operator(); | using MinOpBase::operator(); | ||||
| void operator()( | void operator()( | ||||
| @@ -73,7 +73,7 @@ struct MulOpBase<dt_qint8, dt_qint8> : BinaryOpBase<dt_qint8, dt_qint8> { | |||||
| template <> | template <> | ||||
| struct MulOp<dt_qint8, dt_qint8> : MulOpBase<dt_qint8, dt_qint8> { | struct MulOp<dt_qint8, dt_qint8> : MulOpBase<dt_qint8, dt_qint8> { | ||||
| using MulOpBase::MulOpBase; | using MulOpBase::MulOpBase; | ||||
| constexpr static size_t SIMD_WIDTH = 16; | |||||
| constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t); | |||||
| using MulOpBase::operator(); | using MulOpBase::operator(); | ||||
| void operator()( | void operator()( | ||||
| @@ -54,8 +54,6 @@ struct NoneOpBase<dt_qint32, dt_qint8> : UnaryOpBase<dt_qint32, dt_qint8> { | |||||
| } | } | ||||
| }; | }; | ||||
| #pragma GCC diagnostic ignored "-Waddress-of-packed-member" | |||||
| template <> | template <> | ||||
| struct NoneOp<dt_qint32, dt_qint8> : NoneOpBase<dt_qint32, dt_qint8> { | struct NoneOp<dt_qint32, dt_qint8> : NoneOpBase<dt_qint32, dt_qint8> { | ||||
| using NoneOpBase::NoneOpBase; | using NoneOpBase::NoneOpBase; | ||||
| @@ -63,11 +61,11 @@ struct NoneOp<dt_qint32, dt_qint8> : NoneOpBase<dt_qint32, dt_qint8> { | |||||
| constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int32_t); | constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int32_t); | ||||
| void operator()(const GI_INT32_V2_t& vsrc, dt_qint8* dst) const { | void operator()(const GI_INT32_V2_t& vsrc, dt_qint8* dst) const { | ||||
| GiStoreInt32(reinterpret_cast<int32_t*>(dst), vsrc.val[0]); | |||||
| GiStoreInt32(reinterpret_cast<int32_t*>(dst + 16), vsrc.val[1]); | |||||
| GiStoreInt32(dst, vsrc.val[0]); | |||||
| GiStoreInt32(dst + 16, vsrc.val[1]); | |||||
| } | } | ||||
| void operator()(const GI_INT32_t& src, dt_qint8* dst) const { | void operator()(const GI_INT32_t& src, dt_qint8* dst) const { | ||||
| GiStoreInt32(reinterpret_cast<int32_t*>(dst), src); | |||||
| GiStoreInt32(dst, src); | |||||
| } | } | ||||
| }; | }; | ||||
| @@ -112,36 +112,38 @@ struct ReluOp<dt_qint32, dt_qint8> : ReluOpBase<dt_qint32, dt_qint8>, FixupBase | |||||
| : ReluOpBase(src_scale, dst_scale), FixupBase(scale) {} | : ReluOpBase(src_scale, dst_scale), FixupBase(scale) {} | ||||
| void operator()(const int32x4x2_t& vsrc, dt_qint8* dst) const { | void operator()(const int32x4x2_t& vsrc, dt_qint8* dst) const { | ||||
| vst1_s8(reinterpret_cast<int8_t*>(dst), operator()(vsrc)); | |||||
| vst1_s8(reinterpret_cast<int8_t*>(dst), vget_low_s8(operator()(vsrc))); | |||||
| } | } | ||||
| int8x8_t operator()(const int32x4x2_t& vsrc) const { | |||||
| int8x16_t operator()(const int32x4x2_t& vsrc) const { | |||||
| int32x4_t vitem0 = vqrdmulhq_s32(vsrc.val[0], vmultiplier); | int32x4_t vitem0 = vqrdmulhq_s32(vsrc.val[0], vmultiplier); | ||||
| int32x4_t vitem1 = vqrdmulhq_s32(vsrc.val[1], vmultiplier); | int32x4_t vitem1 = vqrdmulhq_s32(vsrc.val[1], vmultiplier); | ||||
| vitem0 = vmaxq_s32(vitem0, QConverterBase::vzero()); | vitem0 = vmaxq_s32(vitem0, QConverterBase::vzero()); | ||||
| vitem1 = vmaxq_s32(vitem1, QConverterBase::vzero()); | vitem1 = vmaxq_s32(vitem1, QConverterBase::vzero()); | ||||
| return vqmovn_s16(vcombine_s16( | |||||
| auto tmp = vqmovn_s16(vcombine_s16( | |||||
| vqmovn_s32(vrshlq_s32(vitem0, vshift)), | vqmovn_s32(vrshlq_s32(vitem0, vshift)), | ||||
| vqmovn_s32(vrshlq_s32(vitem1, vshift)))); | vqmovn_s32(vrshlq_s32(vitem1, vshift)))); | ||||
| return vcombine_s8(tmp, tmp); | |||||
| } | } | ||||
| int8x8_t operator()(const float32x4_t& vsrc) const { | |||||
| int8x16_t operator()(const float32x4_t& vsrc) const { | |||||
| int32x4_t vitem0 = vqrdmulhq_s32(vcvtq_s32_f32(vsrc), vmultiplier); | int32x4_t vitem0 = vqrdmulhq_s32(vcvtq_s32_f32(vsrc), vmultiplier); | ||||
| vitem0 = vmaxq_s32(vitem0, QConverterBase::vzero()); | vitem0 = vmaxq_s32(vitem0, QConverterBase::vzero()); | ||||
| vitem0 = vrshlq_s32(vitem0, vshift); | vitem0 = vrshlq_s32(vitem0, vshift); | ||||
| int16x4_t vitem = vqmovn_s32(vitem0); | int16x4_t vitem = vqmovn_s32(vitem0); | ||||
| return vqmovn_s16(vcombine_s16(vitem, vitem)); | |||||
| auto tmp = vqmovn_s16(vcombine_s16(vitem, vitem)); | |||||
| return vcombine_s8(tmp, tmp); | |||||
| } | } | ||||
| void operator()(const int32x4_t& src, dt_qint8* dst) const { | void operator()(const int32x4_t& src, dt_qint8* dst) const { | ||||
| auto vitem0 = vmulq_f32(vcvtq_f32_s32(src), this->vscale); | auto vitem0 = vmulq_f32(vcvtq_f32_s32(src), this->vscale); | ||||
| vitem0 = vmaxq_f32(vitem0, QConverterBase::vfzero()); | vitem0 = vmaxq_f32(vitem0, QConverterBase::vfzero()); | ||||
| auto result = QConverter::convert<int8x8_t, float32x4_t>(vitem0); | |||||
| vst1_lane_s32(reinterpret_cast<int32_t*>(dst), (int32x2_t)result, 0); | |||||
| auto result = QConverter::convert<int8x16_t, float32x4_t>(vitem0); | |||||
| vst1q_lane_s32(reinterpret_cast<int32_t*>(dst), (int32x4_t)result, 0); | |||||
| } | } | ||||
| void operator()(const float32x4_t& src, dt_qint8* dst) const { | void operator()(const float32x4_t& src, dt_qint8* dst) const { | ||||
| auto vitem0 = vmulq_f32(src, this->vscale); | auto vitem0 = vmulq_f32(src, this->vscale); | ||||
| vitem0 = vmaxq_f32(vitem0, QConverterBase::vfzero()); | vitem0 = vmaxq_f32(vitem0, QConverterBase::vfzero()); | ||||
| auto result = QConverter::convert<int8x8_t, float32x4_t>(vitem0); | |||||
| vst1_lane_s32(reinterpret_cast<int32_t*>(dst), (int32x2_t)result, 0); | |||||
| auto result = QConverter::convert<int8x16_t, float32x4_t>(vitem0); | |||||
| vst1q_lane_s32(reinterpret_cast<int32_t*>(dst), (int32x4_t)result, 0); | |||||
| } | } | ||||
| }; | }; | ||||
| @@ -73,7 +73,7 @@ struct SubOpBase<dt_qint8, dt_qint8> : BinaryOpBase<dt_qint8, dt_qint8> { | |||||
| template <> | template <> | ||||
| struct SubOp<dt_qint8, dt_qint8> : SubOpBase<dt_qint8, dt_qint8> { | struct SubOp<dt_qint8, dt_qint8> : SubOpBase<dt_qint8, dt_qint8> { | ||||
| using SubOpBase::SubOpBase; | using SubOpBase::SubOpBase; | ||||
| constexpr static size_t SIMD_WIDTH = 16; | |||||
| constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t); | |||||
| using SubOpBase::operator(); | using SubOpBase::operator(); | ||||
| void operator()( | void operator()( | ||||
| @@ -13,6 +13,7 @@ | |||||
| #include "math.h" | #include "math.h" | ||||
| #include "stdint.h" | #include "stdint.h" | ||||
| #include "string.h" | |||||
| #if defined(_WIN32) | #if defined(_WIN32) | ||||
| #include <intrin.h> | #include <intrin.h> | ||||
| @@ -132,6 +133,18 @@ typedef uint32_t GI_UINT32_t __attribute__((vector_size(16))); | |||||
| #define Max(a, b) (a) > (b) ? (a) : (b) | #define Max(a, b) (a) > (b) ? (a) : (b) | ||||
| #define Min(a, b) (a) < (b) ? (a) : (b) | #define Min(a, b) (a) < (b) ? (a) : (b) | ||||
| #if defined(GI_NEON_INTRINSICS) | |||||
| #if defined(__ARM_FEATURE_FMA) && defined(GI_NEON64_INTRINSICS) | |||||
| #define v_fma_ps_f32(c, b, a) vfmaq_f32((c), (b), (a)) | |||||
| #define v_fma_n_f32(c, b, a) vfmaq_n_f32((c), (b), (a)) | |||||
| #define v_fma_lane_f32(c, b, a, lane) vfmaq_lane_f32((c), (b), (a), (lane)) | |||||
| #else | |||||
| #define v_fma_ps_f32(c, b, a) vmlaq_f32((c), (b), (a)) | |||||
| #define v_fma_n_f32(c, b, a) vmlaq_n_f32((c), (b), (a)) | |||||
| #define v_fma_lane_f32(c, b, a, lane) vmlaq_lane_f32((c), (b), (a), (lane)) | |||||
| #endif | |||||
| #endif | |||||
| typedef struct { | typedef struct { | ||||
| GI_INT32_t val[2]; | GI_INT32_t val[2]; | ||||
| } GI_INT32_V2_t; | } GI_INT32_V2_t; | ||||
| @@ -20,7 +20,9 @@ GI_INT32_t GiReinterpretAsInt32(GI_FLOAT32_t In) { | |||||
| #elif defined(GI_SSE2_INTRINSICS) | #elif defined(GI_SSE2_INTRINSICS) | ||||
| return _mm_castps_si128(In); | return _mm_castps_si128(In); | ||||
| #else | #else | ||||
| return *(GI_INT32_t*)(&In); | |||||
| GI_INT32_t ret; | |||||
| memcpy(&ret, &In, GI_SIMD_LEN_BYTE); | |||||
| return ret; | |||||
| #endif | #endif | ||||
| } | } | ||||
| @@ -31,7 +33,9 @@ GI_UINT32_t GiReinterpretAsUint32(GI_FLOAT32_t In) { | |||||
| #elif defined(GI_SSE2_INTRINSICS) | #elif defined(GI_SSE2_INTRINSICS) | ||||
| return _mm_castps_si128(In); | return _mm_castps_si128(In); | ||||
| #else | #else | ||||
| return *(GI_UINT32_t*)(&In); | |||||
| GI_UINT32_t ret; | |||||
| memcpy(&ret, &In, GI_SIMD_LEN_BYTE); | |||||
| return ret; | |||||
| #endif | #endif | ||||
| } | } | ||||
| @@ -42,7 +46,9 @@ GI_FLOAT32_t GiReintInt32ToFloat32(GI_INT32_t Vector) { | |||||
| #elif defined(GI_SSE2_INTRINSICS) | #elif defined(GI_SSE2_INTRINSICS) | ||||
| return _mm_castsi128_ps(Vector); | return _mm_castsi128_ps(Vector); | ||||
| #else | #else | ||||
| return *(GI_FLOAT32_t*)(&Vector); | |||||
| GI_FLOAT32_t ret; | |||||
| memcpy(&ret, &Vector, GI_SIMD_LEN_BYTE); | |||||
| return ret; | |||||
| #endif | #endif | ||||
| } | } | ||||
| @@ -53,7 +59,9 @@ GI_FLOAT32_t GiReintUint32ToFloat32(GI_UINT32_t Vector) { | |||||
| #elif defined(GI_SSE2_INTRINSICS) | #elif defined(GI_SSE2_INTRINSICS) | ||||
| return _mm_castsi128_ps(Vector); | return _mm_castsi128_ps(Vector); | ||||
| #else | #else | ||||
| return *(GI_FLOAT32_t*)(&Vector); | |||||
| GI_FLOAT32_t ret; | |||||
| memcpy(&ret, &Vector, GI_SIMD_LEN_BYTE); | |||||
| return ret; | |||||
| #endif | #endif | ||||
| } | } | ||||
| @@ -69,7 +77,7 @@ GI_INT32_t GiRoundAsInt32(GI_FLOAT32_t Vector) { | |||||
| float32x4_t vinc0 = vbslq_f32(vcgeq_f32(Vector, vzero), vfhalf, vfneg_half); | float32x4_t vinc0 = vbslq_f32(vcgeq_f32(Vector, vzero), vfhalf, vfneg_half); | ||||
| return vcvtq_s32_f32(vaddq_f32(Vector, vinc0)); | return vcvtq_s32_f32(vaddq_f32(Vector, vinc0)); | ||||
| #endif | #endif | ||||
| #elif defined(GI_SSE2_INTRINSICS) | |||||
| #elif defined(GI_SSE42_INTRINSICS) | |||||
| __m128 vfzero = _mm_set1_ps(0.f); | __m128 vfzero = _mm_set1_ps(0.f); | ||||
| __m128 vfhalf = _mm_set1_ps(0.5f); | __m128 vfhalf = _mm_set1_ps(0.5f); | ||||
| __m128 vfneg_half = _mm_set1_ps(-0.5f); | __m128 vfneg_half = _mm_set1_ps(-0.5f); | ||||
| @@ -322,11 +330,7 @@ GI_FORCEINLINE | |||||
| GI_FLOAT32_t GiMultiplyAddFloat32( | GI_FLOAT32_t GiMultiplyAddFloat32( | ||||
| GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { | GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { | ||||
| #if defined(GI_NEON_INTRINSICS) | #if defined(GI_NEON_INTRINSICS) | ||||
| #if defined(__ARM_FEATURE_FMA) | |||||
| return vfmaq_f32(VectorSum, Vector1, Vector2); | |||||
| #else | |||||
| return vmlaq_f32(VectorSum, Vector1, Vector2); | |||||
| #endif | |||||
| return v_fma_ps_f32(VectorSum, Vector1, Vector2); | |||||
| #elif defined(GI_FMA3_INTRINSICS) | #elif defined(GI_FMA3_INTRINSICS) | ||||
| return _mm_fmadd_ps(Vector1, Vector2, VectorSum); | return _mm_fmadd_ps(Vector1, Vector2, VectorSum); | ||||
| #elif defined(GI_SSE2_INTRINSICS) | #elif defined(GI_SSE2_INTRINSICS) | ||||
| @@ -352,11 +356,7 @@ GI_FORCEINLINE | |||||
| GI_FLOAT32_t GiMultiplyAddScalarFloat32( | GI_FLOAT32_t GiMultiplyAddScalarFloat32( | ||||
| GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector, float Scalar) { | GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector, float Scalar) { | ||||
| #if defined(GI_NEON_INTRINSICS) | #if defined(GI_NEON_INTRINSICS) | ||||
| #if defined(__ARM_FEATURE_FMA) | |||||
| return vfmaq_n_f32(VectorSum, Vector, Scalar); | |||||
| #else | |||||
| return vfmla_n_f32(VectorSum, Vector, Scalar); | |||||
| #endif | |||||
| return v_fma_n_f32(VectorSum, Vector, Scalar); | |||||
| #elif defined(GI_SSE2_INTRINSICS) | #elif defined(GI_SSE2_INTRINSICS) | ||||
| return GiMultiplyAddFloat32(VectorSum, GiBroadcastFloat32(Scalar), Vector); | return GiMultiplyAddFloat32(VectorSum, GiBroadcastFloat32(Scalar), Vector); | ||||
| #else | #else | ||||
| @@ -365,27 +365,10 @@ GI_FLOAT32_t GiMultiplyAddScalarFloat32( | |||||
| } | } | ||||
| #if defined(GI_NEON_INTRINSICS) | #if defined(GI_NEON_INTRINSICS) | ||||
| #if defined(__ARM_FEATURE_FMA) | |||||
| #define GIMULTIPLYADDLANFLOAT32(i) \ | |||||
| GI_FORCEINLINE GI_FLOAT32_t GiMultiplyAddLan##i##Float32( \ | |||||
| GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { \ | |||||
| return vfmaq_lane_f32(VectorSum, Vector1, vget_low_f32(Vector2), i); \ | |||||
| } | |||||
| GIMULTIPLYADDLANFLOAT32(0) | |||||
| GIMULTIPLYADDLANFLOAT32(1) | |||||
| #undef GIMULTIPLYADDLANFLOAT32 | |||||
| #define GIMULTIPLYADDLANFLOAT32(i) \ | |||||
| GI_FORCEINLINE GI_FLOAT32_t GiMultiplyAddLan##i##Float32( \ | |||||
| GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { \ | |||||
| return vfmaq_lane_f32(VectorSum, Vector1, vget_high_f32(Vector2), i - 2); \ | |||||
| } | |||||
| GIMULTIPLYADDLANFLOAT32(2) | |||||
| GIMULTIPLYADDLANFLOAT32(3) | |||||
| #else | |||||
| #define GIMULTIPLYADDLANFLOAT32(i) \ | #define GIMULTIPLYADDLANFLOAT32(i) \ | ||||
| GI_FORCEINLINE GI_FLOAT32_t GiMultiplyAddLan##i##Float32( \ | GI_FORCEINLINE GI_FLOAT32_t GiMultiplyAddLan##i##Float32( \ | ||||
| GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { \ | GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { \ | ||||
| return vmlaq_lane_f32(VectorSum, Vector1, vget_low_f32(Vector2), i); \ | |||||
| return v_fma_lane_f32(VectorSum, Vector1, vget_low_f32(Vector2), i); \ | |||||
| } | } | ||||
| GIMULTIPLYADDLANFLOAT32(0) | GIMULTIPLYADDLANFLOAT32(0) | ||||
| GIMULTIPLYADDLANFLOAT32(1) | GIMULTIPLYADDLANFLOAT32(1) | ||||
| @@ -393,11 +376,10 @@ GIMULTIPLYADDLANFLOAT32(1) | |||||
| #define GIMULTIPLYADDLANFLOAT32(i) \ | #define GIMULTIPLYADDLANFLOAT32(i) \ | ||||
| GI_FORCEINLINE GI_FLOAT32_t GiMultiplyAddLan##i##Float32( \ | GI_FORCEINLINE GI_FLOAT32_t GiMultiplyAddLan##i##Float32( \ | ||||
| GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { \ | GI_FLOAT32_t VectorSum, GI_FLOAT32_t Vector1, GI_FLOAT32_t Vector2) { \ | ||||
| return vmlaq_lane_f32(VectorSum, Vector1, vget_high_f32(Vector2), i - 2); \ | |||||
| return v_fma_lane_f32(VectorSum, Vector1, vget_high_f32(Vector2), i - 2); \ | |||||
| } | } | ||||
| GIMULTIPLYADDLANFLOAT32(2) | GIMULTIPLYADDLANFLOAT32(2) | ||||
| GIMULTIPLYADDLANFLOAT32(3) | GIMULTIPLYADDLANFLOAT32(3) | ||||
| #endif | |||||
| #undef GIMULTIPLYADDLANFLOAT32 | #undef GIMULTIPLYADDLANFLOAT32 | ||||
| #elif defined(GI_SSE2_INTRINSICS) | #elif defined(GI_SSE2_INTRINSICS) | ||||
| @@ -59,66 +59,69 @@ GI_INT8_t GiBroadcastInt8(int8_t Value) { | |||||
| } | } | ||||
| GI_FORCEINLINE | GI_FORCEINLINE | ||||
| GI_INT32_t GiLoadInt32(const int32_t* Buffer) { | |||||
| GI_INT32_t GiLoadInt32(const void* Buffer) { | |||||
| #if defined(GI_NEON_INTRINSICS) | #if defined(GI_NEON_INTRINSICS) | ||||
| return vld1q_s32(Buffer); | |||||
| return vld1q_s32((int32_t*)Buffer); | |||||
| #elif defined(GI_SSE2_INTRINSICS) | #elif defined(GI_SSE2_INTRINSICS) | ||||
| return _mm_loadu_si128((const __m128i*)Buffer); | return _mm_loadu_si128((const __m128i*)Buffer); | ||||
| #else | #else | ||||
| GI_INT32_t ret; | GI_INT32_t ret; | ||||
| const int32_t* ptr = (int32_t*)Buffer; | |||||
| for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int32_t); i++) { | for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int32_t); i++) { | ||||
| ret[i] = Buffer[i]; | |||||
| ret[i] = ptr[i]; | |||||
| } | } | ||||
| return ret; | return ret; | ||||
| #endif | #endif | ||||
| } | } | ||||
| GI_FORCEINLINE | GI_FORCEINLINE | ||||
| GI_INT8_t GiLoadInt8(const int8_t* Buffer) { | |||||
| GI_INT8_t GiLoadInt8(const void* Buffer) { | |||||
| #if defined(GI_NEON_INTRINSICS) | #if defined(GI_NEON_INTRINSICS) | ||||
| return vld1q_s8(Buffer); | |||||
| return vld1q_s8((int8_t*)Buffer); | |||||
| #elif defined(GI_SSE2_INTRINSICS) | #elif defined(GI_SSE2_INTRINSICS) | ||||
| return _mm_loadu_si128((const __m128i*)Buffer); | return _mm_loadu_si128((const __m128i*)Buffer); | ||||
| #else | #else | ||||
| GI_INT8_t ret; | GI_INT8_t ret; | ||||
| const int8_t* ptr = (int8_t*)Buffer; | |||||
| for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int8_t); i++) { | for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int8_t); i++) { | ||||
| ret[i] = Buffer[i]; | |||||
| ret[i] = ptr[i]; | |||||
| } | } | ||||
| return ret; | return ret; | ||||
| #endif | #endif | ||||
| } | } | ||||
| GI_FORCEINLINE | GI_FORCEINLINE | ||||
| void GiStoreInt32(int32_t* Buffer, GI_INT32_t Vector) { | |||||
| void GiStoreInt32(void* Buffer, GI_INT32_t Vector) { | |||||
| #if defined(GI_NEON_INTRINSICS) | #if defined(GI_NEON_INTRINSICS) | ||||
| vst1q_s32(Buffer, Vector); | |||||
| vst1q_s32((int32_t*)Buffer, Vector); | |||||
| #elif defined(GI_SSE2_INTRINSICS) | #elif defined(GI_SSE2_INTRINSICS) | ||||
| _mm_storeu_si128((__m128i*)Buffer, Vector); | _mm_storeu_si128((__m128i*)Buffer, Vector); | ||||
| #else | #else | ||||
| int32_t* ptr = (int32_t*)Buffer; | |||||
| for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int32_t); i++) { | for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int32_t); i++) { | ||||
| Buffer[i] = Vector[i]; | |||||
| ptr[i] = Vector[i]; | |||||
| } | } | ||||
| #endif | #endif | ||||
| } | } | ||||
| #if defined(GI_NEON_INTRINSICS) | #if defined(GI_NEON_INTRINSICS) | ||||
| #define GISTORELANEINT32(i) \ | |||||
| GI_FORCEINLINE void GiStoreLane##i##Int32(int32_t* Buffer, GI_INT32_t Vector) { \ | |||||
| vst1q_lane_s32(Buffer, Vector, i); \ | |||||
| #define GISTORELANEINT32(i) \ | |||||
| GI_FORCEINLINE void GiStoreLane##i##Int32(void* Buffer, GI_INT32_t Vector) { \ | |||||
| vst1q_lane_s32((int32_t*)Buffer, Vector, i); \ | |||||
| } | } | ||||
| #elif defined(GI_SSE2_INTRINSICS) | #elif defined(GI_SSE2_INTRINSICS) | ||||
| #define GISTORELANEINT32(i) \ | #define GISTORELANEINT32(i) \ | ||||
| GI_FORCEINLINE void GiStoreLane##i##Int32(int32_t* Buffer, GI_INT32_t Vector) { \ | |||||
| GI_FORCEINLINE void GiStoreLane##i##Int32(void* Buffer, GI_INT32_t Vector) { \ | |||||
| GI_FLOAT32_t tmp = _mm_castsi128_ps(Vector); \ | GI_FLOAT32_t tmp = _mm_castsi128_ps(Vector); \ | ||||
| _mm_store_ss( \ | _mm_store_ss( \ | ||||
| (float*)Buffer, _mm_shuffle_ps(tmp, tmp, _MM_SHUFFLE(i, i, i, i))); \ | (float*)Buffer, _mm_shuffle_ps(tmp, tmp, _MM_SHUFFLE(i, i, i, i))); \ | ||||
| } | } | ||||
| #else | #else | ||||
| #define GISTORELANEINT32(i) \ | |||||
| GI_FORCEINLINE void GiStoreLane##i##Int32(int32_t* Buffer, GI_INT32_t Vector) { \ | |||||
| *Buffer = Vector[i]; \ | |||||
| #define GISTORELANEINT32(i) \ | |||||
| GI_FORCEINLINE void GiStoreLane##i##Int32(void* Buffer, GI_INT32_t Vector) { \ | |||||
| *((int32_t*)Buffer) = Vector[i]; \ | |||||
| } | } | ||||
| #endif | #endif | ||||
| @@ -141,53 +144,57 @@ GI_INT8_t GiReinterInt32ToInt8(GI_INT32_t Vector) { | |||||
| } | } | ||||
| GI_FORCEINLINE | GI_FORCEINLINE | ||||
| void GiStoreInt16(int16_t* Buffer, GI_INT16_t Vector) { | |||||
| void GiStoreInt16(void* Buffer, GI_INT16_t Vector) { | |||||
| #if defined(GI_NEON_INTRINSICS) | #if defined(GI_NEON_INTRINSICS) | ||||
| vst1q_s16(Buffer, Vector); | |||||
| vst1q_s16((int16_t*)Buffer, Vector); | |||||
| #elif defined(GI_SSE2_INTRINSICS) | #elif defined(GI_SSE2_INTRINSICS) | ||||
| _mm_storeu_si128((__m128i*)Buffer, Vector); | _mm_storeu_si128((__m128i*)Buffer, Vector); | ||||
| #else | #else | ||||
| int16_t* ptr = (int16_t*)Buffer; | |||||
| for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int16_t); i++) { | for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int16_t); i++) { | ||||
| Buffer[i] = Vector[i]; | |||||
| ptr[i] = Vector[i]; | |||||
| } | } | ||||
| #endif | #endif | ||||
| } | } | ||||
| GI_FORCEINLINE | GI_FORCEINLINE | ||||
| void GiStoreInt8(int8_t* Buffer, GI_INT8_t Vector) { | |||||
| void GiStoreInt8(void* Buffer, GI_INT8_t Vector) { | |||||
| #if defined(GI_NEON_INTRINSICS) | #if defined(GI_NEON_INTRINSICS) | ||||
| vst1q_s8(Buffer, Vector); | |||||
| vst1q_s8((int8_t*)Buffer, Vector); | |||||
| #elif defined(GI_SSE2_INTRINSICS) | #elif defined(GI_SSE2_INTRINSICS) | ||||
| _mm_storeu_si128((__m128i*)Buffer, Vector); | _mm_storeu_si128((__m128i*)Buffer, Vector); | ||||
| #else | #else | ||||
| int8_t* ptr = (int8_t*)Buffer; | |||||
| for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int8_t); i++) { | for (size_t i = 0; i < GI_SIMD_LEN_BYTE / sizeof(int8_t); i++) { | ||||
| Buffer[i] = Vector[i]; | |||||
| ptr[i] = Vector[i]; | |||||
| } | } | ||||
| #endif | #endif | ||||
| } | } | ||||
| GI_FORCEINLINE | GI_FORCEINLINE | ||||
| void GiStoreLowInt8(int8_t* Buffer, GI_INT8_t Vector) { | |||||
| void GiStoreLowInt8(void* Buffer, GI_INT8_t Vector) { | |||||
| #if defined(GI_NEON_INTRINSICS) | #if defined(GI_NEON_INTRINSICS) | ||||
| vst1_s8(Buffer, vget_low_s8(Vector)); | |||||
| vst1_s8((int8_t*)Buffer, vget_low_s8(Vector)); | |||||
| #elif defined(GI_SSE2_INTRINSICS) | #elif defined(GI_SSE2_INTRINSICS) | ||||
| _mm_storel_epi64((__m128i*)Buffer, Vector); | _mm_storel_epi64((__m128i*)Buffer, Vector); | ||||
| #else | #else | ||||
| int8_t* ptr = (int8_t*)Buffer; | |||||
| for (size_t i = 0; i < GI_SIMD_LEN_BYTE / 2 / sizeof(int8_t); i++) { | for (size_t i = 0; i < GI_SIMD_LEN_BYTE / 2 / sizeof(int8_t); i++) { | ||||
| Buffer[i] = Vector[i]; | |||||
| ptr[i] = Vector[i]; | |||||
| } | } | ||||
| #endif | #endif | ||||
| } | } | ||||
| GI_FORCEINLINE | GI_FORCEINLINE | ||||
| void GiStoreHihgInt8(int8_t* Buffer, GI_INT8_t Vector) { | |||||
| void GiStoreHihgInt8(void* Buffer, GI_INT8_t Vector) { | |||||
| #if defined(GI_NEON_INTRINSICS) | #if defined(GI_NEON_INTRINSICS) | ||||
| vst1_s8(Buffer, vget_high_s8(Vector)); | |||||
| vst1_s8((int8_t*)Buffer, vget_high_s8(Vector)); | |||||
| #elif defined(GI_SSE2_INTRINSICS) | #elif defined(GI_SSE2_INTRINSICS) | ||||
| _mm_storel_epi64((__m128i*)Buffer, _mm_unpackhi_epi64(Vector, Vector)); | _mm_storel_epi64((__m128i*)Buffer, _mm_unpackhi_epi64(Vector, Vector)); | ||||
| #else | #else | ||||
| int8_t* ptr = (int8_t*)Buffer; | |||||
| for (size_t i = 0; i < GI_SIMD_LEN_BYTE / 2 / sizeof(int8_t); i++) { | for (size_t i = 0; i < GI_SIMD_LEN_BYTE / 2 / sizeof(int8_t); i++) { | ||||
| Buffer[i] = Vector[GI_SIMD_LEN_BYTE / 2 + i]; | |||||
| ptr[i] = Vector[GI_SIMD_LEN_BYTE / 2 + i]; | |||||
| } | } | ||||
| #endif | #endif | ||||
| } | } | ||||
| @@ -39,7 +39,6 @@ TEST_F(FALLBACK, ELEMWISE_RECORD) { | |||||
| checker.execs({{10, 10, 32}, {10, 10, 32}, {}}); | checker.execs({{10, 10, 32}, {10, 10, 32}, {}}); | ||||
| } | } | ||||
| TEST_F(FALLBACK, ELEMWISE_FORWARD_TERNARY) { | TEST_F(FALLBACK, ELEMWISE_FORWARD_TERNARY) { | ||||
| using Mode = ElemwiseForward::Param::Mode; | using Mode = ElemwiseForward::Param::Mode; | ||||
| Checker<ElemwiseForward> checker(handle()); | Checker<ElemwiseForward> checker(handle()); | ||||