From c48d58daa88a41d479d4b1c669d467763aa58dc1 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Thu, 28 Oct 2021 16:23:02 +0800
Subject: [PATCH] feat(dnn/arm_common): add N1HW like elemwise broadcast mode

GitOrigin-RevId: 28951358012c2d085f68260fd723797f943138ca
---
 dnn/src/arm_common/elemwise/binary/algo.cpp |  81 ++++++++++++
 dnn/src/arm_common/elemwise/binary/algo.h   |   1 +
 dnn/src/arm_common/elemwise/opr_impl.cpp    |  12 ++
 dnn/src/arm_common/elemwise/opr_impl.h      |   1 +
 dnn/src/arm_common/elemwise_op.h            | 136 ++++++++++++++++++++
 dnn/src/common/elemwise/opr_impl_helper.cpp |  14 ++
 dnn/src/common/elemwise/opr_impl_helper.h   |   8 ++
 dnn/test/arm_common/elemwise.cpp            |  24 ++++
 8 files changed, 277 insertions(+)

diff --git a/dnn/src/arm_common/elemwise/binary/algo.cpp b/dnn/src/arm_common/elemwise/binary/algo.cpp
index cf7e086b..2bdd0ea1 100644
--- a/dnn/src/arm_common/elemwise/binary/algo.cpp
+++ b/dnn/src/arm_common/elemwise/binary/algo.cpp
@@ -104,6 +104,21 @@ bool ElemwiseImpl::AlgoBinaryVecBcast101::is_available(
     return false;
 }
 
+bool ElemwiseImpl::AlgoBinaryVecBcastX0X::is_available(
+        const KernParam& kern_param) const {
+    if (!is_available_common(kern_param.mode) ||
+        ((BcastType::VEC_BCASTX0X != kern_param.broad_cast_type) &&
+         (BcastType::BCASTX0X_VEC != kern_param.broad_cast_type)))
+        return false;
+
+    auto& elparam = kern_param.binary_elparam;
+    auto& src0 = elparam[0];
+
+    DISPATCH_TYPE("AlgoBinaryVecBcastX0X::is_available"_hash);
+
+    return false;
+}
+
 bool ElemwiseImpl::AlgoBinaryVecBcast111C::is_available(
         const KernParam& kern_param) const {
     if (!is_available_common(kern_param.mode) ||
@@ -348,6 +363,72 @@ void ElemwiseImpl::AlgoBinaryVecBcast101::exec(const KernParam& kern_param) cons
     return;
 }
 
+void ElemwiseImpl::AlgoBinaryVecBcastX0X::exec(const KernParam& kern_param) const {
+    auto& elparam = kern_param.binary_elparam;
+    auto &src0 = elparam[0], &src1 = elparam[1];
+    auto&& dst = *(kern_param.m_dst);
+    BroadcastChannelInfo binfo;
+
+    // Case: BcastType::VEC + BCAST_X0X
+    if (BcastType::VEC_BCASTX0X == kern_param.broad_cast_type &&
+        is_broadcasted_3dim_like(src1.layout, binfo)) {
+#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op)             \
+    case Mode::_mode:                                                          \
+        MIDOUT_BEGIN(                                                          \
+                megdnn_arm_common_elemwise_binary, midout_iv(_case),           \
+                midout_iv(Mode::_mode), _type_midout_id) {                     \
+            thin_function<void(                                                \
+                    const _type*, const _type*, _type*, DType, DType, DType,   \
+                    size_t, size_t, size_t)>                                   \
+                    run = OpCallerBinary<                                      \
+                            _op<_type, _type>, BcastType::VEC_BCASTX0X>::run;  \
+            MEGDNN_DISPATCH_CPU_KERN(                                          \
+                    static_cast<naive::HandleImpl*>(kern_param.handle),        \
+                    run(static_cast<const _type*>(src0.raw_ptr),               \
+                        static_cast<const _type*>(src1.raw_ptr),               \
+                        static_cast<_type*>(dst.raw_ptr), src0.layout.dtype,   \
+                        src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \
+                        binfo.z));                                             \
+        }                                                                      \
+        MIDOUT_END();                                                          \
+        return
+
+        DISPATCH_TYPE("AlgoBinaryVecBcastX0X::exec_vec_b"_hash);
+
+#undef DISPATCH_BINARY
+    }
+
+    // BCAST_X0X + BcastType::VEC
+    if (BcastType::BCASTX0X_VEC == kern_param.broad_cast_type &&
+        is_broadcasted_3dim_like(src0.layout, binfo)) {
+#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op)             \
+    case Mode::_mode:                                                          \
+        MIDOUT_BEGIN(                                                          \
+                megdnn_arm_common_elemwise_binary, midout_iv(_case),           \
+                midout_iv(Mode::_mode), _type_midout_id) {                     \
+            thin_function<void(                                                \
+                    const _type*, const _type*, _type*, DType, DType, DType,   \
+                    size_t, size_t, size_t)>                                   \
+                    run = OpCallerBinary<                                      \
+                            _op<_type, _type>, BcastType::BCASTX0X_VEC>::run;  \
+            MEGDNN_DISPATCH_CPU_KERN(                                          \
+                    static_cast<naive::HandleImpl*>(kern_param.handle),        \
+                    run(static_cast<const _type*>(src0.raw_ptr),               \
+                        static_cast<const _type*>(src1.raw_ptr),               \
+                        static_cast<_type*>(dst.raw_ptr), src0.layout.dtype,   \
+                        src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \
+                        binfo.z));                                             \
+        }                                                                      \
+        MIDOUT_END();                                                          \
+        return
+
+        DISPATCH_TYPE("AlgoBinaryVecBcastX0X::exec_b_vec"_hash);
+
+#undef DISPATCH_BINARY
+    }
+    return;
+}
+
 void ElemwiseImpl::AlgoBinaryVecBcast111C::exec(const KernParam& kern_param) const {
     auto& elparam = kern_param.binary_elparam;
     auto &src0 = elparam[0], &src1 = elparam[1];
diff --git a/dnn/src/arm_common/elemwise/binary/algo.h b/dnn/src/arm_common/elemwise/binary/algo.h
index f4462183..12eaa119 100644
--- a/dnn/src/arm_common/elemwise/binary/algo.h
+++ b/dnn/src/arm_common/elemwise/binary/algo.h
@@ -33,6 +33,7 @@ namespace arm_common {
 DECL_CB(VecVec);
 DECL_CB(VecScalar);
 DECL_CB(VecBcast101);
+DECL_CB(VecBcastX0X);
 DECL_CB(VecBcast111C);
 DECL_CB(VecBcast101xX);
 #undef DECL_CB
diff --git a/dnn/src/arm_common/elemwise/opr_impl.cpp b/dnn/src/arm_common/elemwise/opr_impl.cpp
index 75c94358..e144834a 100644
--- a/dnn/src/arm_common/elemwise/opr_impl.cpp
+++ b/dnn/src/arm_common/elemwise/opr_impl.cpp
@@ -27,6 +27,7 @@ class ElemwiseImpl::AlgoPack {
     AlgoBinaryVecVec algo_binary_vec_vec;
     AlgoBinaryVecScalar algo_binary_vec_sca;
     AlgoBinaryVecBcast101 algo_binary_vec_bcast101;
+    AlgoBinaryVecBcastX0X algo_binary_vec_bcastX0X;
     AlgoBinaryVecBcast111C algo_binary_vec_bcast110;
     AlgoBinaryVecBcast101xX algo_binary_VEC_BCAST101xX;
     AlgoTernaryFma3VecVecVec algo_ternaryfma3_vec_vec_vec;
@@ -46,6 +47,7 @@ public:
         all_algos.emplace_back(&algo_binary_vec_vec);
         all_algos.emplace_back(&algo_binary_vec_sca);
         all_algos.emplace_back(&algo_binary_vec_bcast101);
+        all_algos.emplace_back(&algo_binary_vec_bcastX0X);
         all_algos.emplace_back(&algo_binary_vec_bcast110);
         all_algos.emplace_back(&algo_binary_VEC_BCAST101xX);
         all_algos.emplace_back(&algo_ternaryfma3_vec_vec_vec);
@@ -202,6 +204,16 @@ ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) {
         return kern_param;
     }
 
+    if (is_vector(src0.layout) && is_broadcasted_3dim_like(src1.layout, binfo)) {
+        kern_param.broad_cast_type = BcastType::VEC_BCASTX0X;
+        return kern_param;
+    }
+
+    if (is_vector(src1.layout) && is_broadcasted_3dim_like(src0.layout, binfo)) {
+        kern_param.broad_cast_type = BcastType::BCASTX0X_VEC;
+        return kern_param;
+    }
+
     if (is_legal_layout_for_nhwc(src1.layout) &&
         is_NHWC_broadcasted_channel_like(src0.layout, binfo)) {
         kern_param.broad_cast_type = BcastType::BCAST111C_VEC;
diff --git a/dnn/src/arm_common/elemwise/opr_impl.h b/dnn/src/arm_common/elemwise/opr_impl.h
index 33fa1c55..8f528a4d 100644
--- a/dnn/src/arm_common/elemwise/opr_impl.h
+++ b/dnn/src/arm_common/elemwise/opr_impl.h
@@ -38,6 +38,7 @@ private:
     class AlgoBinaryVecVec;
     class AlgoBinaryVecScalar;
    class AlgoBinaryVecBcast101;
+    class AlgoBinaryVecBcastX0X;
     class AlgoBinaryVecBcast111C;
     class AlgoBinaryVecBcast101xX;
     class AlgoTernaryFma3VecVecVec;
diff --git a/dnn/src/arm_common/elemwise_op.h b/dnn/src/arm_common/elemwise_op.h
index db18f422..38f7731b 100644
--- a/dnn/src/arm_common/elemwise_op.h
+++ b/dnn/src/arm_common/elemwise_op.h
@@ -107,11 +107,13 @@ enum BcastType {
     VEC,
     VEC_VEC,
     VEC_BCAST101,
+    VEC_BCASTX0X,
     VEC_BCAST111C,
     VEC_BCAST101xX,
     VEC_SCALAR,
     SCALAR_VEC,
     BCAST101_VEC,
+    BCASTX0X_VEC,
     BCAST111C_VEC,
     BCAST101xX_VEC,
     VEC_VEC_VEC,
@@ -230,6 +232,34 @@ struct OpCallerBinary<PowOp<ctype, ctype>, VEC_BCAST101> {
     }
 };
 
+template <typename ctype>
+struct OpCallerBinary<PowOp<ctype, ctype>, VEC_BCASTX0X> {
+    using Op = PowOp<ctype, ctype>;
+    static void run(
+            const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
+            typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype,
+            DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) {
+        Op op(src0_dtype, src1_dtype, dst_dtype);
+        for (size_t b = 0; b < batch; b++) {
+            const typename Op::src_ctype* src1_ptr_base = src1 + b * channel_stride;
+            for (size_t c = 0; c < channel; c++) {
+                size_t i = 0;
+                auto src1_ptr = src1_ptr_base;
+#if MEGDNN_FIX_AARCH32_BUG
+// FIXME: as llvm may cause cannot select error if enable vectorize
+#pragma clang loop vectorize(disable)
+#endif
+                for (; i < channel_stride; i++) {
+                    op(*src0, *src1_ptr, dst);
+                    src0++;
+                    src1_ptr++;
+                    dst++;
+                }
+            }
+        }
+    }
+};
+
 template <typename ctype>
 struct OpCallerBinary<PowOp<ctype, ctype>, VEC_BCAST111C> {
     using Op = PowOp<ctype, ctype>;
@@ -332,6 +362,34 @@ struct OpCallerBinary<PowOp<ctype, ctype>, BCAST101_VEC> {
     }
 };
 
+template <typename ctype>
+struct OpCallerBinary<PowOp<ctype, ctype>, BCASTX0X_VEC> {
+    using Op = PowOp<ctype, ctype>;
+    static void run(
+            const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
+            typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype,
+            DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) {
+        Op op(src0_dtype, src1_dtype, dst_dtype);
+        for (size_t b = 0; b < batch; b++) {
+            auto src0_ptr_base = src0 + b * channel_stride;
+            for (size_t c = 0; c < channel; c++) {
+                size_t i = 0;
+                auto src0_ptr = src0_ptr_base;
+#if MEGDNN_FIX_AARCH32_BUG
+// FIXME: as llvm may cause cannot select error if enable vectorize
+#pragma clang loop vectorize(disable)
+#endif
+                for (; i < channel_stride; i++) {
+                    op(*src0_ptr, *src1, dst);
+                    src0_ptr++;
+                    src1++;
+                    dst++;
+                }
+            }
+        }
+    }
+};
+
 template <typename Op>
 struct OpCallerBinary {
     static void run(
@@ -398,6 +456,45 @@ struct OpCallerBinary {
     }
 };
 
+template <typename Op>
+struct OpCallerBinary<Op, VEC_BCASTX0X> {
+    static void run(
+            const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
+            typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype,
+            DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) {
+        Op op(src0_dtype, src1_dtype, dst_dtype);
+        ParamElemVisitor<typename Op::src_ctype> vis;
+        for (size_t b = 0; b < batch; b++) {
+            const typename Op::src_ctype* src1_ptr_base = src1 + b * channel_stride;
+            for (size_t c = 0; c < channel; c++) {
+                size_t i = 0;
+                auto src1_ptr = src1_ptr_base;
+                for (; i + Op::SIMD_WIDTH * 2 <= channel_stride;
+                     i += Op::SIMD_WIDTH * 2) {
+                    auto src0_neon0 = vis(src0);
+                    auto src0_neon1 = vis(src0 + Op::SIMD_WIDTH);
+                    auto src1_neon0 = vis(src1_ptr);
+                    auto src1_neon1 = vis(src1_ptr + Op::SIMD_WIDTH);
+                    op({{src0_neon0, src0_neon1}}, {{src1_neon0, src1_neon1}}, dst);
+                    src0 += Op::SIMD_WIDTH * 2;
+                    src1_ptr += Op::SIMD_WIDTH * 2;
+                    dst += Op::SIMD_WIDTH * 2;
+                }
+#if MEGDNN_FIX_AARCH32_BUG
+// FIXME: as llvm may cause cannot select error if enable vectorize
+#pragma clang loop vectorize(disable)
+#endif
+                for (; i < channel_stride; i++) {
+                    op(*src0, *src1_ptr, dst);
+                    src0++;
+                    src1_ptr++;
+                    dst++;
+                }
+            }
+        }
+    }
+};
+
 template <typename Op>
 struct OpCallerBinary {
     static void run(
@@ -844,6 +941,45 @@ struct OpCallerBinary {
     }
 };
 
+template <typename Op>
+struct OpCallerBinary<Op, BCASTX0X_VEC> {
+    static void run(
+            const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
+            typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype,
+            DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) {
+        Op op(src0_dtype, src1_dtype, dst_dtype);
+        ParamElemVisitor<typename Op::src_ctype> vis;
+        for (size_t b = 0; b < batch; b++) {
+            auto src0_ptr_base = src0 + b * channel_stride;
+            for (size_t c = 0; c < channel; c++) {
+                auto src0_ptr = src0_ptr_base;
+                size_t i = 0;
+                for (; i + Op::SIMD_WIDTH * 2 <= channel_stride;
+                     i += Op::SIMD_WIDTH * 2) {
+                    auto src0_neon0 = vis(src0_ptr);
+                    auto src0_neon1 = vis(src0_ptr + Op::SIMD_WIDTH);
+                    auto src1_neon0 = vis(src1);
+                    auto src1_neon1 = vis(src1 + Op::SIMD_WIDTH);
+                    op({{src0_neon0, src0_neon1}}, {{src1_neon0, src1_neon1}}, dst);
+                    src0_ptr += Op::SIMD_WIDTH * 2;
+                    src1 += Op::SIMD_WIDTH * 2;
+                    dst += Op::SIMD_WIDTH * 2;
+                }
+#if MEGDNN_FIX_AARCH32_BUG
+// FIXME: as llvm may cause cannot select error if enable vectorize
+#pragma clang loop vectorize(disable)
+#endif
+                for (; i < channel_stride; i++) {
+                    op(*src0_ptr, *src1, dst);
+                    src0_ptr++;
+                    src1++;
+                    dst++;
+                }
+            }
+        }
+    }
+};
+
 template <typename Op, BcastType bcast_type>
 struct OpCallerTernary;
 
diff --git a/dnn/src/common/elemwise/opr_impl_helper.cpp b/dnn/src/common/elemwise/opr_impl_helper.cpp
index bee1cce2..9e14908d 100644
--- a/dnn/src/common/elemwise/opr_impl_helper.cpp
+++ b/dnn/src/common/elemwise/opr_impl_helper.cpp
@@ -150,6 +150,20 @@ bool ElemwiseLayoutHelper::is_broadcasted_channel_like(
     return false;
 }
 
+bool ElemwiseLayoutHelper::is_broadcasted_3dim_like(
+        const TensorLayout& layout, BroadcastChannelInfo& info) {
+    if (layout.format.type() == TensorFormat::Type::DEFAULT) {
+        if (layout.ndim == 3 && (layout.stride[0] - layout.shape[2]) == 0 &&
+            layout.stride[1] == 0 && layout.stride[2] == 1) {
+            info.x = layout.shape[0];
+            info.y = layout.shape[1];
+            info.z = layout.shape[2];
+            return true;
+        }
+    }
+    return false;
+}
+
 bool ElemwiseLayoutHelper::is_NHWC_broadcasted_channel_like(
         const TensorLayout& layout, BroadcastChannelInfo& info) {
     if (layout.format.type() == TensorFormat::Type::DEFAULT) {
diff --git a/dnn/src/common/elemwise/opr_impl_helper.h b/dnn/src/common/elemwise/opr_impl_helper.h
index fde396af..c038829d 100644
--- a/dnn/src/common/elemwise/opr_impl_helper.h
+++ b/dnn/src/common/elemwise/opr_impl_helper.h
@@ -80,6 +80,14 @@ public:
     static bool is_broadcasted_channel_like(
             const TensorLayout& layout, BroadcastChannelInfo& info);
 
+    /*!
+     * \brief check whether the layout matches BroadcastChannelInfo like N1HW
+     *
+     * Note that the layout should be [N, 1, H*W]-like
+     */
+    static bool is_broadcasted_3dim_like(
+            const TensorLayout& layout, BroadcastChannelInfo& info);
+
     /*!
      * \brief check whether layout matches BroadcastChannelInfo under NHWC
      * layout
diff --git a/dnn/test/arm_common/elemwise.cpp b/dnn/test/arm_common/elemwise.cpp
index f144a46f..aeca923d 100644
--- a/dnn/test/arm_common/elemwise.cpp
+++ b/dnn/test/arm_common/elemwise.cpp
@@ -356,6 +356,30 @@ TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NHWC_FP32_BCAST) {
     run_3d_incontig(Mode::FUSE_MUL_ADD3);
 }
 
+TEST_F(ARM_COMMON, ELEMWISE_FORWARD_N1HW_FP32_BCAST) {
+    using Mode = ElemwiseForward::Param::Mode;
+    Checker<Elemwise> checker(handle());
+
+    UniformFloatRNG rng(1e-5, 7e1);
+    checker.set_rng(0, &rng);
+    checker.set_epsilon(1e-5);
+    checker.set_dtype(0, dtype::Float32());
+    checker.set_dtype(1, dtype::Float32());
+
+    //! 2 dim
+    auto run = [&](Mode mode) {
+        // VEC_BCASTX0X
+        checker.set_param(mode).execs({{2, 8, 4, 4}, {2, 1, 4, 4}, {}});
+        checker.set_param(mode).execs({{4, 21, 78}, {4, 1, 78}, {}});
+        // BCASTX0X_VEC
+        checker.set_param(mode).execs({{2, 1, 4, 4}, {2, 8, 4, 4}, {}});
+        checker.set_param(mode).execs({{4, 1, 78}, {4, 21, 78}, {}});
+    };
+    run(Mode::ADD);
+    run(Mode::MUL);
+    run(Mode::SUB);
+}
+
 #if MEGDNN_WITH_BENCHMARK
 namespace {
 void run_elemwise_benchmark(
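For context, the new mode targets operand pairs like {N, C, H, W} with {N, 1, H, W}: the broadcast operand is expanded to a 3-dim layout {N, C, H*W} with strides {H*W, 0, 1}, which is what the new is_broadcasted_3dim_like() check recognizes, and the kernels then reuse one H*W row per batch across all C channels. Below is a minimal, self-contained scalar sketch of the semantics that the scalar tail of OpCallerBinary<Op, VEC_BCASTX0X>::run implements. It is not part of the patch; the helper name vec_bcast_x0x_reference, the ADD lambda, and the concrete sizes (mirroring the new test case {2, 8, 4, 4} with {2, 1, 4, 4}) are illustrative assumptions only.

    #include <cstddef>
    #include <vector>

    // Scalar reference for the VEC_BCASTX0X case: src0 is the full tensor
    // collapsed to (batch, channel, channel_stride); src1 holds one row of
    // channel_stride elements per batch and is reused for every channel.
    // Hypothetical helper; it mirrors the per-batch/per-channel loop structure
    // of OpCallerBinary<Op, VEC_BCASTX0X>::run in the patch.
    template <typename T, typename BinOp>
    void vec_bcast_x0x_reference(
            const T* src0, const T* src1, T* dst, std::size_t batch,
            std::size_t channel, std::size_t channel_stride, BinOp op) {
        for (std::size_t b = 0; b < batch; ++b) {
            const T* src1_row = src1 + b * channel_stride;  // broadcast row of this batch
            for (std::size_t c = 0; c < channel; ++c) {
                for (std::size_t i = 0; i < channel_stride; ++i) {
                    *dst++ = op(*src0++, src1_row[i]);  // src1 broadcast over channels
                }
            }
        }
    }

    int main() {
        // Mirrors the new test shape {2, 8, 4, 4} + {2, 1, 4, 4}:
        // batch = 2, channel = 8, channel_stride = 4 * 4 = 16.
        const std::size_t batch = 2, channel = 8, stride = 16;
        std::vector<float> a(batch * channel * stride, 1.f);
        std::vector<float> b(batch * stride, 2.f);
        std::vector<float> out(a.size());
        vec_bcast_x0x_reference(
                a.data(), b.data(), out.data(), batch, channel, stride,
                [](float x, float y) { return x + y; });  // Mode::ADD
        return out[0] == 3.f ? 0 : 1;
    }

The BCASTX0X_VEC variant is simply the mirror image: the first operand supplies the per-batch broadcast row and the second operand is read linearly.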