GitOrigin-RevId: 2895135801
tags/v1.7.0
| @@ -104,6 +104,21 @@ bool ElemwiseImpl::AlgoBinaryVecBcast101::is_available( | |||
| return false; | |||
| } | |||
| bool ElemwiseImpl::AlgoBinaryVecBcastX0X::is_available( | |||
| const KernParam& kern_param) const { | |||
| if (!is_available_common(kern_param.mode) || | |||
| ((BcastType::VEC_BCASTX0X != kern_param.broad_cast_type) && | |||
| (BcastType::BCASTX0X_VEC != kern_param.broad_cast_type))) | |||
| return false; | |||
| auto& elparam = kern_param.binary_elparam; | |||
| auto& src0 = elparam[0]; | |||
| DISPATCH_TYPE("AlgoBinaryVecBcastX0X::is_available"_hash); | |||
| return false; | |||
| } | |||
| bool ElemwiseImpl::AlgoBinaryVecBcast111C::is_available( | |||
| const KernParam& kern_param) const { | |||
| if (!is_available_common(kern_param.mode) || | |||
| @@ -348,6 +363,72 @@ void ElemwiseImpl::AlgoBinaryVecBcast101::exec(const KernParam& kern_param) cons | |||
| return; | |||
| } | |||
| void ElemwiseImpl::AlgoBinaryVecBcastX0X::exec(const KernParam& kern_param) const { | |||
| auto& elparam = kern_param.binary_elparam; | |||
| auto &src0 = elparam[0], &src1 = elparam[1]; | |||
| auto&& dst = *(kern_param.m_dst); | |||
| BroadcastChannelInfo binfo; | |||
| // Case: BcastType::VEC + BCAST_X0X | |||
| if (BcastType::VEC_BCASTX0X == kern_param.broad_cast_type && | |||
| is_broadcasted_3dim_like(src1.layout, binfo)) { | |||
| #define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \ | |||
| case Mode::_mode: \ | |||
| MIDOUT_BEGIN( \ | |||
| megdnn_arm_common_elemwise_binary, midout_iv(_case), \ | |||
| midout_iv(Mode::_mode), _type_midout_id) { \ | |||
| thin_function<void( \ | |||
| const _type*, const _type*, _type*, DType, DType, DType, size_t, \ | |||
| size_t, size_t)> \ | |||
| run = OpCallerBinary< \ | |||
| _op<_type, _type>, BcastType::VEC_BCASTX0X>::run; \ | |||
| MEGDNN_DISPATCH_CPU_KERN( \ | |||
| static_cast<naive::HandleImpl*>(kern_param.handle), \ | |||
| run(static_cast<const _type*>(src0.raw_ptr), \ | |||
| static_cast<const _type*>(src1.raw_ptr), \ | |||
| static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ | |||
| src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \ | |||
| binfo.z)); \ | |||
| } \ | |||
| MIDOUT_END(); \ | |||
| return | |||
| DISPATCH_TYPE("AlgoBinaryVecBcastX0X::exec_vec_b"_hash); | |||
| #undef DISPATCH_BINARY | |||
| } | |||
| // BCAST_X0X + BcastType::VEC | |||
| if (BcastType::BCASTX0X_VEC == kern_param.broad_cast_type && | |||
| is_broadcasted_3dim_like(src0.layout, binfo)) { | |||
| #define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \ | |||
| case Mode::_mode: \ | |||
| MIDOUT_BEGIN( \ | |||
| megdnn_arm_common_elemwise_binary, midout_iv(_case), \ | |||
| midout_iv(Mode::_mode), _type_midout_id) { \ | |||
| thin_function<void( \ | |||
| const _type*, const _type*, _type*, DType, DType, DType, size_t, \ | |||
| size_t, size_t)> \ | |||
| run = OpCallerBinary< \ | |||
| _op<_type, _type>, BcastType::BCASTX0X_VEC>::run; \ | |||
| MEGDNN_DISPATCH_CPU_KERN( \ | |||
| static_cast<naive::HandleImpl*>(kern_param.handle), \ | |||
| run(static_cast<const _type*>(src0.raw_ptr), \ | |||
| static_cast<const _type*>(src1.raw_ptr), \ | |||
| static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ | |||
| src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \ | |||
| binfo.z)); \ | |||
| } \ | |||
| MIDOUT_END(); \ | |||
| return | |||
| DISPATCH_TYPE("AlgoBinaryVecBcastX0X::exec_b_vec"_hash); | |||
| #undef DISPATCH_BINARY | |||
| } | |||
| return; | |||
| } | |||
| void ElemwiseImpl::AlgoBinaryVecBcast111C::exec(const KernParam& kern_param) const { | |||
| auto& elparam = kern_param.binary_elparam; | |||
| auto &src0 = elparam[0], &src1 = elparam[1]; | |||
| @@ -33,6 +33,7 @@ namespace arm_common { | |||
| DECL_CB(VecVec); | |||
| DECL_CB(VecScalar); | |||
| DECL_CB(VecBcast101); | |||
| DECL_CB(VecBcastX0X); | |||
| DECL_CB(VecBcast111C); | |||
| DECL_CB(VecBcast101xX); | |||
| #undef DECL_CB | |||
| @@ -27,6 +27,7 @@ class ElemwiseImpl::AlgoPack { | |||
| AlgoBinaryVecVec algo_binary_vec_vec; | |||
| AlgoBinaryVecScalar algo_binary_vec_sca; | |||
| AlgoBinaryVecBcast101 algo_binary_vec_bcast101; | |||
| AlgoBinaryVecBcastX0X algo_binary_vec_bcastX0X; | |||
| AlgoBinaryVecBcast111C algo_binary_vec_bcast110; | |||
| AlgoBinaryVecBcast101xX algo_binary_VEC_BCAST101xX; | |||
| AlgoTernaryFma3VecVecVec algo_ternaryfma3_vec_vec_vec; | |||
| @@ -46,6 +47,7 @@ public: | |||
| all_algos.emplace_back(&algo_binary_vec_vec); | |||
| all_algos.emplace_back(&algo_binary_vec_sca); | |||
| all_algos.emplace_back(&algo_binary_vec_bcast101); | |||
| all_algos.emplace_back(&algo_binary_vec_bcastX0X); | |||
| all_algos.emplace_back(&algo_binary_vec_bcast110); | |||
| all_algos.emplace_back(&algo_binary_VEC_BCAST101xX); | |||
| all_algos.emplace_back(&algo_ternaryfma3_vec_vec_vec); | |||
| @@ -202,6 +204,16 @@ ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) { | |||
| return kern_param; | |||
| } | |||
| if (is_vector(src0.layout) && is_broadcasted_3dim_like(src1.layout, binfo)) { | |||
| kern_param.broad_cast_type = BcastType::VEC_BCASTX0X; | |||
| return kern_param; | |||
| } | |||
| if (is_vector(src1.layout) && is_broadcasted_3dim_like(src0.layout, binfo)) { | |||
| kern_param.broad_cast_type = BcastType::BCASTX0X_VEC; | |||
| return kern_param; | |||
| } | |||
| if (is_legal_layout_for_nhwc(src1.layout) && | |||
| is_NHWC_broadcasted_channel_like(src0.layout, binfo)) { | |||
| kern_param.broad_cast_type = BcastType::BCAST111C_VEC; | |||
| @@ -38,6 +38,7 @@ private: | |||
| class AlgoBinaryVecVec; | |||
| class AlgoBinaryVecScalar; | |||
| class AlgoBinaryVecBcast101; | |||
| class AlgoBinaryVecBcastX0X; | |||
| class AlgoBinaryVecBcast111C; | |||
| class AlgoBinaryVecBcast101xX; | |||
| class AlgoTernaryFma3VecVecVec; | |||
| @@ -107,11 +107,13 @@ enum BcastType { | |||
| VEC, | |||
| VEC_VEC, | |||
| VEC_BCAST101, | |||
| VEC_BCASTX0X, | |||
| VEC_BCAST111C, | |||
| VEC_BCAST101xX, | |||
| VEC_SCALAR, | |||
| SCALAR_VEC, | |||
| BCAST101_VEC, | |||
| BCASTX0X_VEC, | |||
| BCAST111C_VEC, | |||
| BCAST101xX_VEC, | |||
| VEC_VEC_VEC, | |||
| @@ -230,6 +232,34 @@ struct OpCallerBinary<PowOp<ctype, ctype>, VEC_BCAST101> { | |||
| } | |||
| }; | |||
| template <typename ctype> | |||
| struct OpCallerBinary<PowOp<ctype, ctype>, VEC_BCASTX0X> { | |||
| using Op = PowOp<ctype, ctype>; | |||
| static void run( | |||
| const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, | |||
| typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, | |||
| DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { | |||
| Op op(src0_dtype, src1_dtype, dst_dtype); | |||
| for (size_t b = 0; b < batch; b++) { | |||
| const typename Op::src_ctype* src1_ptr_base = src1 + b * channel_stride; | |||
| for (size_t c = 0; c < channel; c++) { | |||
| size_t i = 0; | |||
| auto src1_ptr = src1_ptr_base; | |||
| #if MEGDNN_FIX_AARCH32_BUG | |||
| // FIXME: as llvm may cause cannot select error if enable vectorize | |||
| #pragma clang loop vectorize(disable) | |||
| #endif | |||
| for (; i < channel_stride; i++) { | |||
| op(*src0, *src1_ptr, dst); | |||
| src0++; | |||
| src1_ptr++; | |||
| dst++; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| }; | |||
| template <typename ctype> | |||
| struct OpCallerBinary<PowOp<ctype, ctype>, VEC_BCAST111C> { | |||
| using Op = PowOp<ctype, ctype>; | |||
| @@ -332,6 +362,34 @@ struct OpCallerBinary<PowOp<ctype, ctype>, BCAST101_VEC> { | |||
| } | |||
| }; | |||
| template <typename ctype> | |||
| struct OpCallerBinary<PowOp<ctype, ctype>, BCASTX0X_VEC> { | |||
| using Op = PowOp<ctype, ctype>; | |||
| static void run( | |||
| const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, | |||
| typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, | |||
| DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { | |||
| Op op(src0_dtype, src1_dtype, dst_dtype); | |||
| for (size_t b = 0; b < batch; b++) { | |||
| auto src0_ptr_base = src0 + b * channel_stride; | |||
| for (size_t c = 0; c < channel; c++) { | |||
| size_t i = 0; | |||
| auto src0_ptr = src0_ptr_base; | |||
| #if MEGDNN_FIX_AARCH32_BUG | |||
| // FIXME: as llvm may cause cannot select error if enable vectorize | |||
| #pragma clang loop vectorize(disable) | |||
| #endif | |||
| for (; i < channel_stride; i++) { | |||
| op(*src0_ptr, *src1, dst); | |||
| src0_ptr++; | |||
| src1++; | |||
| dst++; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| }; | |||
| template <typename Op> | |||
| struct OpCallerBinary<Op, VEC_VEC> { | |||
| static void run( | |||
| @@ -398,6 +456,45 @@ struct OpCallerBinary<Op, VEC_BCAST101> { | |||
| } | |||
| }; | |||
| template <typename Op> | |||
| struct OpCallerBinary<Op, VEC_BCASTX0X> { | |||
| static void run( | |||
| const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, | |||
| typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, | |||
| DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { | |||
| Op op(src0_dtype, src1_dtype, dst_dtype); | |||
| ParamElemVisitor<typename Op::src_ctype> vis; | |||
| for (size_t b = 0; b < batch; b++) { | |||
| const typename Op::src_ctype* src1_ptr_base = src1 + b * channel_stride; | |||
| for (size_t c = 0; c < channel; c++) { | |||
| size_t i = 0; | |||
| auto src1_ptr = src1_ptr_base; | |||
| for (; i + Op::SIMD_WIDTH * 2 <= channel_stride; | |||
| i += Op::SIMD_WIDTH * 2) { | |||
| auto src0_neon0 = vis(src0); | |||
| auto src0_neon1 = vis(src0 + Op::SIMD_WIDTH); | |||
| auto src1_neon0 = vis(src1_ptr); | |||
| auto src1_neon1 = vis(src1_ptr + Op::SIMD_WIDTH); | |||
| op({{src0_neon0, src0_neon1}}, {{src1_neon0, src1_neon1}}, dst); | |||
| src0 += Op::SIMD_WIDTH * 2; | |||
| src1_ptr += Op::SIMD_WIDTH * 2; | |||
| dst += Op::SIMD_WIDTH * 2; | |||
| } | |||
| #if MEGDNN_FIX_AARCH32_BUG | |||
| // FIXME: as llvm may cause cannot select error if enable vectorize | |||
| #pragma clang loop vectorize(disable) | |||
| #endif | |||
| for (; i < channel_stride; i++) { | |||
| op(*src0, *src1_ptr, dst); | |||
| src0++; | |||
| src1_ptr++; | |||
| dst++; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| }; | |||
| template <typename Op> | |||
| struct OpCallerBinary<Op, VEC_BCAST111C> { | |||
| static void run( | |||
| @@ -844,6 +941,45 @@ struct OpCallerBinary<Op, BCAST101_VEC> { | |||
| } | |||
| }; | |||
| template <typename Op> | |||
| struct OpCallerBinary<Op, BCASTX0X_VEC> { | |||
| static void run( | |||
| const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, | |||
| typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, | |||
| DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { | |||
| Op op(src0_dtype, src1_dtype, dst_dtype); | |||
| ParamElemVisitor<typename Op::src_ctype> vis; | |||
| for (size_t b = 0; b < batch; b++) { | |||
| auto src0_ptr_base = src0 + b * channel_stride; | |||
| for (size_t c = 0; c < channel; c++) { | |||
| auto src0_ptr = src0_ptr_base; | |||
| size_t i = 0; | |||
| for (; i + Op::SIMD_WIDTH * 2 <= channel_stride; | |||
| i += Op::SIMD_WIDTH * 2) { | |||
| auto src0_neon0 = vis(src0_ptr); | |||
| auto src0_neon1 = vis(src0_ptr + Op::SIMD_WIDTH); | |||
| auto src1_neon0 = vis(src1); | |||
| auto src1_neon1 = vis(src1 + Op::SIMD_WIDTH); | |||
| op({{src0_neon0, src0_neon1}}, {{src1_neon0, src1_neon1}}, dst); | |||
| src0_ptr += Op::SIMD_WIDTH * 2; | |||
| src1 += Op::SIMD_WIDTH * 2; | |||
| dst += Op::SIMD_WIDTH * 2; | |||
| } | |||
| #if MEGDNN_FIX_AARCH32_BUG | |||
| // FIXME: as llvm may cause cannot select error if enable vectorize | |||
| #pragma clang loop vectorize(disable) | |||
| #endif | |||
| for (; i < channel_stride; i++) { | |||
| op(*src0_ptr, *src1, dst); | |||
| src0_ptr++; | |||
| src1++; | |||
| dst++; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| }; | |||
| template <typename Op, BcastType bcast_type> | |||
| struct OpCallerTernary; | |||
| @@ -150,6 +150,20 @@ bool ElemwiseLayoutHelper::is_broadcasted_channel_like( | |||
| return false; | |||
| } | |||
| bool ElemwiseLayoutHelper::is_broadcasted_3dim_like( | |||
| const TensorLayout& layout, BroadcastChannelInfo& info) { | |||
| if (layout.format.type() == TensorFormat::Type::DEFAULT) { | |||
| if (layout.ndim == 3 && (layout.stride[0] - layout.shape[2]) == 0 && | |||
| layout.stride[1] == 0 && layout.stride[2] == 1) { | |||
| info.x = layout.shape[0]; | |||
| info.y = layout.shape[1]; | |||
| info.z = layout.shape[2]; | |||
| return true; | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| bool ElemwiseLayoutHelper::is_NHWC_broadcasted_channel_like( | |||
| const TensorLayout& layout, BroadcastChannelInfo& info) { | |||
| if (layout.format.type() == TensorFormat::Type::DEFAULT) { | |||
| @@ -80,6 +80,14 @@ public: | |||
| static bool is_broadcasted_channel_like( | |||
| const TensorLayout& layout, BroadcastChannelInfo& info); | |||
| /*! | |||
| * \brief check whether layout matches BroadcastChannelInfo like N1HW | |||
| * | |||
| * Note layout should be [N, 1, H*W] like | |||
| */ | |||
| static bool is_broadcasted_3dim_like( | |||
| const TensorLayout& layout, BroadcastChannelInfo& info); | |||
| /*! | |||
| * \brief check whether layout matches BroadcastChannelInfo under NHWC | |||
| * layout | |||
| @@ -356,6 +356,30 @@ TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NHWC_FP32_BCAST) { | |||
| run_3d_incontig(Mode::FUSE_MUL_ADD3); | |||
| } | |||
| TEST_F(ARM_COMMON, ELEMWISE_FORWARD_N1HW_FP32_BCAST) { | |||
| using Mode = ElemwiseForward::Param::Mode; | |||
| Checker<ElemwiseForward> checker(handle()); | |||
| UniformFloatRNG rng(1e-5, 7e1); | |||
| checker.set_rng(0, &rng); | |||
| checker.set_epsilon(1e-5); | |||
| checker.set_dtype(0, dtype::Float32()); | |||
| checker.set_dtype(1, dtype::Float32()); | |||
| //! 2 dim | |||
| auto run = [&](Mode mode) { | |||
| // VEC_BCASTX0X | |||
| checker.set_param(mode).execs({{2, 8, 4, 4}, {2, 1, 4, 4}, {}}); | |||
| checker.set_param(mode).execs({{4, 21, 78}, {4, 1, 78}, {}}); | |||
| // BCASTX0X_VEC | |||
| checker.set_param(mode).execs({{2, 1, 4, 4}, {2, 8, 4, 4}, {}}); | |||
| checker.set_param(mode).execs({{4, 1, 78}, {4, 21, 78}, {}}); | |||
| }; | |||
| run(Mode::ADD); | |||
| run(Mode::MUL); | |||
| run(Mode::SUB); | |||
| } | |||
| #if MEGDNN_WITH_BENCHMARK | |||
| namespace { | |||
| void run_elemwise_benchmark( | |||