GitOrigin-RevId: ef19a636ba
tags/v0.5.0
| @@ -31,7 +31,10 @@ class ElemwiseImpl::AlgoPack { | |||||
| AlgoTernaryFma3VecVecVec algo_ternaryfma3_vec_vec_vec; | AlgoTernaryFma3VecVecVec algo_ternaryfma3_vec_vec_vec; | ||||
| AlgoTernaryFma3VecVecScalar algo_ternaryfma3_vec_vecsca; | AlgoTernaryFma3VecVecScalar algo_ternaryfma3_vec_vecsca; | ||||
| AlgoTernaryFma3Bcast101VecBcast101 algo_ternaryfma3_bcast101_vec_bcast101; | AlgoTernaryFma3Bcast101VecBcast101 algo_ternaryfma3_bcast101_vec_bcast101; | ||||
| AlgoTernaryFma3Bcast101x4VecBcast101x4 | |||||
| algo_ternaryfma3_bcast101x4_vec_bcast101x4; | |||||
| AlgoTernaryFma3VecBcast101Vec algo_ternaryfma3_vec_bcast101_vec; | AlgoTernaryFma3VecBcast101Vec algo_ternaryfma3_vec_bcast101_vec; | ||||
| AlgoTernaryFma3VecBcast101x4Vec algo_ternaryfma3_vec_bcast101x4_vec; | |||||
| AlgoTernaryFma3VecScalarVec algo_ternaryfma3_vec_sca_vec; | AlgoTernaryFma3VecScalarVec algo_ternaryfma3_vec_sca_vec; | ||||
| AlgoTernaryFma3VecScalarScalar algo_ternaryfma3_vec_sca_sca; | AlgoTernaryFma3VecScalarScalar algo_ternaryfma3_vec_sca_sca; | ||||
| @@ -45,7 +48,9 @@ public: | |||||
| all_algos.emplace_back(&algo_ternaryfma3_vec_vec_vec); | all_algos.emplace_back(&algo_ternaryfma3_vec_vec_vec); | ||||
| all_algos.emplace_back(&algo_ternaryfma3_vec_vecsca); | all_algos.emplace_back(&algo_ternaryfma3_vec_vecsca); | ||||
| all_algos.emplace_back(&algo_ternaryfma3_bcast101_vec_bcast101); | all_algos.emplace_back(&algo_ternaryfma3_bcast101_vec_bcast101); | ||||
| all_algos.emplace_back(&algo_ternaryfma3_bcast101x4_vec_bcast101x4); | |||||
| all_algos.emplace_back(&algo_ternaryfma3_vec_bcast101_vec); | all_algos.emplace_back(&algo_ternaryfma3_vec_bcast101_vec); | ||||
| all_algos.emplace_back(&algo_ternaryfma3_vec_bcast101x4_vec); | |||||
| all_algos.emplace_back(&algo_ternaryfma3_vec_sca_vec); | all_algos.emplace_back(&algo_ternaryfma3_vec_sca_vec); | ||||
| all_algos.emplace_back(&algo_ternaryfma3_vec_sca_sca); | all_algos.emplace_back(&algo_ternaryfma3_vec_sca_sca); | ||||
| } | } | ||||
| @@ -112,12 +117,25 @@ ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) { | |||||
| return kern_param; | return kern_param; | ||||
| } | } | ||||
| if (is_vector(src1.layout) && | |||||
| is_broadcastedx_channel_like<4>(src0.layout, binfo) && | |||||
| src0.layout.eq_layout(src2.layout)) { | |||||
| kern_param.broad_cast_type = BcastType::BCAST101x4_VEC_BCAST101x4; | |||||
| return kern_param; | |||||
| } | |||||
| if (is_vector(src0.layout) && src0.layout.eq_layout(src2.layout) && | if (is_vector(src0.layout) && src0.layout.eq_layout(src2.layout) && | ||||
| is_broadcasted_channel_like(src1.layout, binfo)) { | is_broadcasted_channel_like(src1.layout, binfo)) { | ||||
| kern_param.broad_cast_type = BcastType::VEC_BCAST101_VEC; | kern_param.broad_cast_type = BcastType::VEC_BCAST101_VEC; | ||||
| return kern_param; | return kern_param; | ||||
| } | } | ||||
| if (is_vector(src0.layout) && src0.layout.eq_layout(src2.layout) && | |||||
| is_broadcastedx_channel_like<4>(src1.layout, binfo)) { | |||||
| kern_param.broad_cast_type = BcastType::VEC_BCAST101x4_VEC; | |||||
| return kern_param; | |||||
| } | |||||
| if (is_vector(src0.layout) && is_vector(src2.layout) && | if (is_vector(src0.layout) && is_vector(src2.layout) && | ||||
| is_broadcasted_scalar(src1.layout)) { | is_broadcasted_scalar(src1.layout)) { | ||||
| kern_param.broad_cast_type = BcastType::VEC_SCALAR_VEC; | kern_param.broad_cast_type = BcastType::VEC_SCALAR_VEC; | ||||
| @@ -41,7 +41,9 @@ private: | |||||
| class AlgoTernaryFma3VecVecVec; | class AlgoTernaryFma3VecVecVec; | ||||
| class AlgoTernaryFma3VecVecScalar; | class AlgoTernaryFma3VecVecScalar; | ||||
| class AlgoTernaryFma3Bcast101VecBcast101; | class AlgoTernaryFma3Bcast101VecBcast101; | ||||
| class AlgoTernaryFma3Bcast101x4VecBcast101x4; | |||||
| class AlgoTernaryFma3VecBcast101Vec; | class AlgoTernaryFma3VecBcast101Vec; | ||||
| class AlgoTernaryFma3VecBcast101x4Vec; | |||||
| class AlgoTernaryFma3VecScalarVec; | class AlgoTernaryFma3VecScalarVec; | ||||
| class AlgoTernaryFma3VecScalarScalar; | class AlgoTernaryFma3VecScalarScalar; | ||||
| class AlgoPack; | class AlgoPack; | ||||
| @@ -42,7 +42,9 @@ using namespace arm_common; | |||||
| DECL_AVAILABLE(VecVecVec, BcastType::VEC_VEC_VEC); | DECL_AVAILABLE(VecVecVec, BcastType::VEC_VEC_VEC); | ||||
| DECL_AVAILABLE(VecVecScalar, BcastType::VEC_VEC_SCALAR); | DECL_AVAILABLE(VecVecScalar, BcastType::VEC_VEC_SCALAR); | ||||
| DECL_AVAILABLE(Bcast101VecBcast101, BcastType::BCAST101_VEC_BCAST101); | DECL_AVAILABLE(Bcast101VecBcast101, BcastType::BCAST101_VEC_BCAST101); | ||||
| DECL_AVAILABLE(Bcast101x4VecBcast101x4, BcastType::BCAST101x4_VEC_BCAST101x4); | |||||
| DECL_AVAILABLE(VecBcast101Vec, BcastType::VEC_BCAST101_VEC); | DECL_AVAILABLE(VecBcast101Vec, BcastType::VEC_BCAST101_VEC); | ||||
| DECL_AVAILABLE(VecBcast101x4Vec, BcastType::VEC_BCAST101x4_VEC); | |||||
| DECL_AVAILABLE(VecScalarVec, BcastType::VEC_SCALAR_VEC); | DECL_AVAILABLE(VecScalarVec, BcastType::VEC_SCALAR_VEC); | ||||
| DECL_AVAILABLE(VecScalarScalar, BcastType::VEC_SCALAR_SCALAR); | DECL_AVAILABLE(VecScalarScalar, BcastType::VEC_SCALAR_SCALAR); | ||||
| #undef DECL_CB | #undef DECL_CB | ||||
| @@ -158,6 +160,82 @@ void ElemwiseImpl::AlgoTernaryFma3Bcast101VecBcast101::exec( | |||||
| return; | return; | ||||
| } | } | ||||
//! Execute the ternary elemwise kernel for BcastType::BCAST101x4_VEC_BCAST101x4.
//! The three operands are read from kern_param.ternary_elparam and the result
//! is written to *kern_param.m_dst. The broadcast geometry is derived from
//! src0's layout; src1 supplies the size used to compute batch_size.
| void ElemwiseImpl::AlgoTernaryFma3Bcast101x4VecBcast101x4::exec( | |||||
| const KernParam& kern_param) const { | |||||
| auto& elparam = kern_param.ternary_elparam; | |||||
| auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2]; | |||||
| BroadcastChannelInfo binfo; | |||||
// Called only to fill the binfo out-parameter; the boolean result is dropped.
// NOTE(review): presumably the availability check for this algo guarantees
// the call succeeds -- confirm, otherwise binfo is unset when it is used below.
| is_broadcastedx_channel_like<4>(src0.layout, binfo); | |||||
// Case body for one (mode, dtype) pair: binds OpCallerTernary<...>::run to a
// thin_function and passes the call to MEGDNN_DISPATCH_CPU_KERN. The macro
// refers to `dst` and `batch_size`, which are declared after this definition
// but before DISPATCH_TYPE is invoked.
| #define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ | |||||
| case Mode::_mode: \ | |||||
| MIDOUT_BEGIN(megdnn_arm_common_elemwise_ternary, midout_iv(_case), \ | |||||
| midout_iv(Mode::_mode), _type_midout_id) { \ | |||||
| thin_function<void(const _type*, const _type*, const _type*, \ | |||||
| _type*, DType, DType, DType, DType, size_t, \ | |||||
| size_t, size_t, size_t)> \ | |||||
| run = OpCallerTernary< \ | |||||
| _op<_type, _type>, \ | |||||
| BcastType::BCAST101x4_VEC_BCAST101x4>::run; \ | |||||
| MEGDNN_DISPATCH_CPU_KERN( \ | |||||
| static_cast<naive::HandleImpl*>(kern_param.handle), \ | |||||
| run(static_cast<const _type*>(src0.raw_ptr), \ | |||||
| static_cast<const _type*>(src1.raw_ptr), \ | |||||
| static_cast<const _type*>(src2.raw_ptr), \ | |||||
| static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ | |||||
| src1.layout.dtype, src2.layout.dtype, \ | |||||
| dst.layout.dtype, batch_size, binfo.x, binfo.y, \ | |||||
| binfo.z)); \ | |||||
| } \ | |||||
| MIDOUT_END(); \ | |||||
| return | |||||
// binfo.x, binfo.y and binfo.z are forwarded as the last three size_t
// arguments of run(); src1's shape[0] divided by their product is the batch.
// NOTE(review): assumes src1's layout is collapsed so that shape[0] is its
// total element count -- TODO confirm.
| size_t batch_size = src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z); | |||||
| auto&& dst = *(kern_param.m_dst); | |||||
// DISPATCH_TYPE is defined outside this view; presumably it selects on dtype
// and mode and expands DISPATCH_TERNARY for each supported pair.
| DISPATCH_TYPE("AlgoTernaryFma3Bcast101x4VecBcast101x4::exec"_hash); | |||||
| #undef DISPATCH_TERNARY | |||||
| return; | |||||
| } | |||||
//! Execute the ternary elemwise kernel for BcastType::VEC_BCAST101x4_VEC.
//! The three operands are read from kern_param.ternary_elparam and the result
//! is written to *kern_param.m_dst. The broadcast geometry is derived from
//! src1's layout; src0 supplies the size used to compute batch_size.
| void ElemwiseImpl::AlgoTernaryFma3VecBcast101x4Vec::exec( | |||||
| const KernParam& kern_param) const { | |||||
| auto& elparam = kern_param.ternary_elparam; | |||||
| auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2]; | |||||
| BroadcastChannelInfo binfo; | |||||
// Called only to fill the binfo out-parameter; the boolean result is dropped.
// NOTE(review): presumably the availability check for this algo guarantees
// the call succeeds -- confirm, otherwise binfo is unset when it is used below.
| is_broadcastedx_channel_like<4>(src1.layout, binfo); | |||||
// Case body for one (mode, dtype) pair: binds OpCallerTernary<...>::run to a
// thin_function and passes the call to MEGDNN_DISPATCH_CPU_KERN. The macro
// refers to `dst` and `batch_size`, which are declared after this definition
// but before DISPATCH_TYPE is invoked.
| #define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ | |||||
| case Mode::_mode: \ | |||||
| MIDOUT_BEGIN(megdnn_arm_common_elemwise_ternary, midout_iv(_case), \ | |||||
| midout_iv(Mode::_mode), _type_midout_id) { \ | |||||
| thin_function<void(const _type*, const _type*, const _type*, \ | |||||
| _type*, DType, DType, DType, DType, size_t, \ | |||||
| size_t, size_t, size_t)> \ | |||||
| run = OpCallerTernary<_op<_type, _type>, \ | |||||
| BcastType::VEC_BCAST101x4_VEC>::run; \ | |||||
| MEGDNN_DISPATCH_CPU_KERN( \ | |||||
| static_cast<naive::HandleImpl*>(kern_param.handle), \ | |||||
| run(static_cast<const _type*>(src0.raw_ptr), \ | |||||
| static_cast<const _type*>(src1.raw_ptr), \ | |||||
| static_cast<const _type*>(src2.raw_ptr), \ | |||||
| static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ | |||||
| src1.layout.dtype, src2.layout.dtype, \ | |||||
| dst.layout.dtype, batch_size, binfo.x, binfo.y, \ | |||||
| binfo.z)); \ | |||||
| } \ | |||||
| MIDOUT_END(); \ | |||||
| return | |||||
// binfo.x, binfo.y and binfo.z are forwarded as the last three size_t
// arguments of run(); src0's shape[0] divided by their product is the batch.
// NOTE(review): assumes src0's layout is collapsed so that shape[0] is its
// total element count -- TODO confirm.
| size_t batch_size = src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z); | |||||
| auto&& dst = *(kern_param.m_dst); | |||||
// DISPATCH_TYPE is defined outside this view; presumably it selects on dtype
// and mode and expands DISPATCH_TERNARY for each supported pair.
| DISPATCH_TYPE("AlgoTernaryFma3VecBcast101x4Vec::exec"_hash); | |||||
| #undef DISPATCH_TERNARY | |||||
| return; | |||||
| } | |||||
| void ElemwiseImpl::AlgoTernaryFma3VecBcast101Vec::exec( | void ElemwiseImpl::AlgoTernaryFma3VecBcast101Vec::exec( | ||||
| const KernParam& kern_param) const { | const KernParam& kern_param) const { | ||||
| auto& elparam = kern_param.ternary_elparam; | auto& elparam = kern_param.ternary_elparam; | ||||
| @@ -193,6 +271,7 @@ void ElemwiseImpl::AlgoTernaryFma3VecBcast101Vec::exec( | |||||
| return; | return; | ||||
| } | } | ||||
| void ElemwiseImpl::AlgoTernaryFma3VecScalarVec::exec( | void ElemwiseImpl::AlgoTernaryFma3VecScalarVec::exec( | ||||
| const KernParam& kern_param) const { | const KernParam& kern_param) const { | ||||
| auto& elparam = kern_param.ternary_elparam; | auto& elparam = kern_param.ternary_elparam; | ||||
| @@ -33,7 +33,9 @@ namespace arm_common { | |||||
| DECL_CB(VecVecVec); | DECL_CB(VecVecVec); | ||||
| DECL_CB(VecVecScalar); | DECL_CB(VecVecScalar); | ||||
| DECL_CB(Bcast101VecBcast101); | DECL_CB(Bcast101VecBcast101); | ||||
| DECL_CB(Bcast101x4VecBcast101x4); | |||||
| DECL_CB(VecBcast101Vec); | DECL_CB(VecBcast101Vec); | ||||
| DECL_CB(VecBcast101x4Vec); | |||||
| DECL_CB(VecScalarVec); | DECL_CB(VecScalarVec); | ||||
| DECL_CB(VecScalarScalar); | DECL_CB(VecScalarScalar); | ||||
| #undef DECL_CB | #undef DECL_CB | ||||
| @@ -810,6 +810,65 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<3>& param, | |||||
| DISPATCH() | DISPATCH() | ||||
| #undef DISPATCH_SINGLE_MODE | |||||
| } | |||||
| } | |||||
| //! VEC + BCAST101x4 + VEC | |||||
| { | |||||
| BroadcastChannelInfo binfo; | |||||
| if (is_vector(src0.layout) && | |||||
| is_broadcastedx_channel_like<4>(src1.layout, binfo) && | |||||
| src0.layout.eq_shape(src2.layout)) { | |||||
| #define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ | |||||
| case _mode: { \ | |||||
| using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ | |||||
| using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ | |||||
| thin_function<void(const src_ctype*, const src_ctype*, \ | |||||
| const src_ctype*, dst_ctype*, DType, DType, DType, \ | |||||
| DType, size_t, size_t, size_t, size_t)> \ | |||||
| run = OpCallerTernary<_op<src_ctype, dst_ctype>, \ | |||||
| VEC_BCAST101x4_VEC>::run; \ | |||||
| MEGDNN_DISPATCH_CPU_KERN_OPR( \ | |||||
| run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \ | |||||
| src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), \ | |||||
| src0.layout.dtype, src1.layout.dtype, src2.layout.dtype, \ | |||||
| dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \ | |||||
| return; \ | |||||
| } | |||||
| size_t batch_size = | |||||
| src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z); | |||||
| DISPATCH() | |||||
| #undef DISPATCH_SINGLE_MODE | |||||
| } | |||||
| //! BCAST101x4 + VEC + BCAST101x4 | |||||
| if (is_vector(src1.layout) && | |||||
| is_broadcastedx_channel_like<4>(src0.layout, binfo) && | |||||
| src0.layout.eq_shape(src2.layout)) { | |||||
| #define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ | |||||
| case _mode: { \ | |||||
| using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ | |||||
| using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ | |||||
| thin_function<void(const src_ctype*, const src_ctype*, \ | |||||
| const src_ctype*, dst_ctype*, DType, DType, DType, \ | |||||
| DType, size_t, size_t, size_t, size_t)> \ | |||||
| run = OpCallerTernary<_op<src_ctype, dst_ctype>, \ | |||||
| BCAST101x4_VEC_BCAST101x4>::run; \ | |||||
| MEGDNN_DISPATCH_CPU_KERN_OPR( \ | |||||
| run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \ | |||||
| src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), \ | |||||
| src0.layout.dtype, src1.layout.dtype, src2.layout.dtype, \ | |||||
| dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \ | |||||
| return; \ | |||||
| } | |||||
| size_t batch_size = | |||||
| src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z); | |||||
| DISPATCH() | |||||
| #undef DISPATCH_SINGLE_MODE | #undef DISPATCH_SINGLE_MODE | ||||
| } | } | ||||
| } | } | ||||
| @@ -105,7 +105,9 @@ enum BcastType { | |||||
| VEC_VEC_VEC, | VEC_VEC_VEC, | ||||
| VEC_VEC_SCALAR, | VEC_VEC_SCALAR, | ||||
| BCAST101_VEC_BCAST101, | BCAST101_VEC_BCAST101, | ||||
| BCAST101x4_VEC_BCAST101x4, | |||||
| VEC_BCAST101_VEC, | VEC_BCAST101_VEC, | ||||
| VEC_BCAST101x4_VEC, | |||||
| VEC_SCALAR_VEC, | VEC_SCALAR_VEC, | ||||
| VEC_SCALAR_SCALAR, | VEC_SCALAR_SCALAR, | ||||
| UNKNOWN_BCAST_TYPE | UNKNOWN_BCAST_TYPE | ||||
| @@ -681,6 +683,54 @@ struct OpCallerTernary<Op, BCAST101_VEC_BCAST101> { | |||||
| } | } | ||||
| }; | }; | ||||
| //! src0: CHW44, src1: vector, src2: CHW44 | |||||
//! Ternary caller in which src0 and src2 are read through
//! ParamElemVisitorBcast101x4 and src1 through the plain ParamElemVisitor.
//!
//! \param src0,src2          broadcast operands; restarted from their base
//!                           pointer on every batch and indexed by
//!                           cb * channel_block_dim
//! \param src1               dense operand, consumed linearly by all loops
//! \param dst                output, advanced in step with src1
//! \param batch              trip count of the outermost loop
//! \param nr_channel_blocks  number of channel blocks visited per batch
//! \param channel_stride     positions per channel block; each position holds
//!                           channel_block_dim elements
//! \param channel_block_dim  elements per channel block; asserted to be 4
| template <typename Op> | |||||
| struct OpCallerTernary<Op, BCAST101x4_VEC_BCAST101x4> { | |||||
| static void run(const typename Op::src_ctype* src0, | |||||
| const typename Op::src_ctype* src1, | |||||
| const typename Op::src_ctype* src2, | |||||
| typename Op::dst_ctype* dst, DType src0_dtype, | |||||
| DType src1_dtype, DType src2_dtype, DType dst_dtype, | |||||
| size_t batch, size_t nr_channel_blocks, | |||||
| size_t channel_stride, size_t channel_block_dim) { | |||||
| megdnn_assert(channel_block_dim == 4, "only imp for nchw44"); | |||||
| Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); | |||||
| ParamElemVisitorBcast101x4<typename Op::src_ctype> vis0; | |||||
| ParamElemVisitor<typename Op::src_ctype> vis1; | |||||
| ParamElemVisitorBcast101x4<typename Op::src_ctype> vis2; | |||||
| for (size_t b = 0; b < batch; b++) { | |||||
// src0/src2 are not advanced across batches: every batch re-reads the same
// channel blocks, while src1 and dst keep moving forward.
| auto src0_ptr = src0; | |||||
| auto src2_ptr = src2; | |||||
| for (size_t cb = 0; cb < nr_channel_blocks; cb++) { | |||||
| auto src0_block_ptr = src0_ptr + cb * channel_block_dim; | |||||
| auto src2_block_ptr = src2_ptr + cb * channel_block_dim; | |||||
// NOTE(review): the Bcast101x4 visitor is defined outside this view;
// presumably it replicates the channel block across a SIMD vector.
| auto channel_block_vec0 = vis0(src0_block_ptr); | |||||
| auto channel_block_vec2 = vis2(src2_block_ptr); | |||||
| size_t img_index = 0; | |||||
// Number of positions covered by one SIMD vector.
| auto src1_offset = Op::SIMD_WIDTH / channel_block_dim; | |||||
// Main loop: two SIMD vectors of src1 per iteration, reusing the same
// channel-block vectors for both halves.
| for (; img_index + 2 * src1_offset <= channel_stride; | |||||
| img_index += 2 * src1_offset) { | |||||
| op({{channel_block_vec0, channel_block_vec0}}, | |||||
| {{vis1(src1), vis1(src1 + Op::SIMD_WIDTH)}}, | |||||
| {{channel_block_vec2, channel_block_vec2}}, dst); | |||||
| src1 += Op::SIMD_WIDTH * 2; | |||||
| dst += Op::SIMD_WIDTH * 2; | |||||
| } | |||||
| // TODO: all elemwise_multi_type ops should implement a one-SIMD-vector mode | |||||
// Tail: remaining positions are handled one element at a time, pairing each
// src1 element with element c_iter of the src0/src2 channel blocks.
| for (; img_index < channel_stride; img_index++) { | |||||
| for (size_t c_iter = 0; c_iter < channel_block_dim; | |||||
| c_iter++) { | |||||
| op(*(src0_block_ptr + c_iter), *src1, | |||||
| *(src2_block_ptr + c_iter), dst); | |||||
| src1++; | |||||
| dst++; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| }; | |||||
| //! src1: 1C11, src0 and src2 are contig | //! src1: 1C11, src0 and src2 are contig | ||||
| template <typename Op> | template <typename Op> | ||||
| struct OpCallerTernary<Op, VEC_BCAST101_VEC> { | struct OpCallerTernary<Op, VEC_BCAST101_VEC> { | ||||
| @@ -725,6 +775,52 @@ struct OpCallerTernary<Op, VEC_BCAST101_VEC> { | |||||
| } | } | ||||
| }; | }; | ||||
| //! src1: CHW44, src0 and src2 are contig | |||||
//! Ternary caller in which src1 is read through ParamElemVisitorBcast101x4
//! and src0/src2 through the plain ParamElemVisitor.
//!
//! \param src0,src2          dense operands, consumed linearly by all loops
//! \param src1               broadcast operand; restarted from its base
//!                           pointer on every batch and indexed by
//!                           cb * channel_block_dim
//! \param dst                output, advanced in step with src0 and src2
//! \param batch              trip count of the outermost loop
//! \param nr_channel_blocks  number of channel blocks visited per batch
//! \param channel_stride     positions per channel block; each position holds
//!                           channel_block_dim elements
//! \param channel_block_dim  elements per channel block; asserted to be 4
| template <typename Op> | |||||
| struct OpCallerTernary<Op, VEC_BCAST101x4_VEC> { | |||||
| static void run(const typename Op::src_ctype* src0, | |||||
| const typename Op::src_ctype* src1, | |||||
| const typename Op::src_ctype* src2, | |||||
| typename Op::dst_ctype* dst, DType src0_dtype, | |||||
| DType src1_dtype, DType src2_dtype, DType dst_dtype, | |||||
| size_t batch, size_t nr_channel_blocks, | |||||
| size_t channel_stride, size_t channel_block_dim) { | |||||
| megdnn_assert(channel_block_dim == 4, "only imp for nchw44"); | |||||
| Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); | |||||
| ParamElemVisitor<typename Op::src_ctype> vis0; | |||||
| ParamElemVisitorBcast101x4<typename Op::src_ctype> vis1; | |||||
| ParamElemVisitor<typename Op::src_ctype> vis2; | |||||
| for (size_t b = 0; b < batch; b++) { | |||||
// src1 is not advanced across batches: every batch re-reads the same channel
// blocks, while src0, src2 and dst keep moving forward.
| auto src1_ptr = src1; | |||||
| for (size_t cb = 0; cb < nr_channel_blocks; cb++) { | |||||
| auto src1_block_ptr = src1_ptr + cb * channel_block_dim; | |||||
// NOTE(review): the Bcast101x4 visitor is defined outside this view;
// presumably it replicates the channel block across a SIMD vector.
| auto channel_block_vec = vis1(src1_block_ptr); | |||||
| size_t img_index = 0; | |||||
// Number of positions covered by one SIMD vector.
| auto offset = Op::SIMD_WIDTH / channel_block_dim; | |||||
// Main loop: two SIMD vectors of src0/src2 per iteration, reusing the same
// channel-block vector for both halves.
| for (; img_index + 2 * offset <= channel_stride; | |||||
| img_index += 2 * offset) { | |||||
| op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, | |||||
| {{channel_block_vec, channel_block_vec}}, | |||||
| {{vis2(src2), vis2(src2 + Op::SIMD_WIDTH)}}, dst); | |||||
| src0 += Op::SIMD_WIDTH * 2; | |||||
| src2 += Op::SIMD_WIDTH * 2; | |||||
| dst += Op::SIMD_WIDTH * 2; | |||||
| } | |||||
| // TODO: all elemwise_multi_type ops should implement a one-SIMD-vector mode | |||||
// Tail: remaining positions are handled one element at a time, pairing each
// src0/src2 element with element c_iter of the src1 channel block.
| for (; img_index < channel_stride; img_index++) { | |||||
| for (size_t c_iter = 0; c_iter < channel_block_dim; | |||||
| c_iter++) { | |||||
| op(*src0, *(src1_block_ptr + c_iter), *src2, dst); | |||||
| src0++; | |||||
| src2++; | |||||
| dst++; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| }; | |||||
| //! src1: scalar, src0 and src2 has the same shape | //! src1: scalar, src0 and src2 has the same shape | ||||
| template <typename Op> | template <typename Op> | ||||
| struct OpCallerTernary<Op, VEC_SCALAR_VEC> { | struct OpCallerTernary<Op, VEC_SCALAR_VEC> { | ||||
| @@ -26,50 +26,53 @@ TYPED_TEST(ARM_ELEMWISE, run) { | |||||
| elemwise::run_test<TypeParam>(this->handle()); | elemwise::run_test<TypeParam>(this->handle()); | ||||
| } | } | ||||
| #define TERNARY_COMPLATE_TEST_CASE(_optr) \ | |||||
| printf("Check binary optr %s by all cases.\n", #_optr); \ | |||||
| checker.set_param(Mode::_optr) \ | |||||
| .execs({{3, 4, 7}, {3, 4, 7}, {3, 4, 7}, {}}); \ | |||||
| checker.set_param(Mode::_optr) \ | |||||
| .execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}, {}}); \ | |||||
| checker.set_param(Mode::_optr) \ | |||||
| .execs({{1, 4, 1}, {3, 4, 7}, {1, 4, 1}, {}}); \ | |||||
| checker.set_param(Mode::_optr) \ | |||||
| .execs({{3, 4, 5, 7}, {3, 4, 5, 7}, {1, 1, 1, 1}, {}}); \ | |||||
| checker.set_param(Mode::_optr).execs({{1, 7}, {1, 7}, {1, 7}, {}}); \ | |||||
| checker.set_param(Mode::_optr) \ | |||||
| .execs({{1, 2, 1}, {1, 2, 2}, {1, 2, 1}, {}}); \ | |||||
| checker.set_param(Mode::_optr) \ | |||||
| .execs({{1, 2, 2}, {1, 2, 2}, {1, 1, 1}, {}}); \ | |||||
| checker.set_param(Mode::_optr) \ | |||||
| .execs({{3, 4, 1}, {3, 4, 1}, {3, 4, 1}, {}}); \ | |||||
| checker.set_param(Mode::_optr).execs({{3, 4, 5}, {1}, {1}, {}}); \ | |||||
| checker.set_param(Mode::_optr).execs({{1}, {3, 4, 5}, {1}, {}}); | |||||
| #define BUILD_TERNARY_COMPLATE_TEST_CASE \ | |||||
| TERNARY_COMPLATE_TEST_CASE(FUSE_MUL_ADD3) | |||||
| TEST_F(ARM_COMMON, ELEMWISE_FORWARD_TERNARY) { | TEST_F(ARM_COMMON, ELEMWISE_FORWARD_TERNARY) { | ||||
| using Mode = ElemwiseForward::Param::Mode; | using Mode = ElemwiseForward::Param::Mode; | ||||
| Checker<ElemwiseForward> checker(handle()); | Checker<ElemwiseForward> checker(handle()); | ||||
| checker.set_param(Mode::FUSE_MUL_ADD3); | |||||
| auto run = [&] { | |||||
| //! nchw44: BCAST101x4 + VEC + BCAST101x4 (plus one all-VEC case) | |||||
| checker.execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}}); | |||||
| checker.execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}}); | |||||
| checker.execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}}); | |||||
| checker.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}}); | |||||
| checker.execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}}); | |||||
| //! nchw44: VEC + BCAST101x4 + VEC (plus one all-VEC case) | |||||
| checker.execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}}); | |||||
| checker.execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}}); | |||||
| checker.execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}}); | |||||
| checker.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}}); | |||||
| checker.execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}}); | |||||
| checker.execs({{3, 4, 7}, {3, 4, 7}, {3, 4, 7}, {}}); | |||||
| checker.execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}, {}}); | |||||
| checker.execs({{1, 4, 1}, {3, 4, 7}, {1, 4, 1}, {}}); | |||||
| checker.execs({{3, 4, 5, 7}, {3, 4, 5, 7}, {1, 1, 1, 1}, {}}); | |||||
| checker.execs({{1, 7}, {1, 7}, {1, 7}, {}}); | |||||
| checker.execs({{1, 2, 1}, {1, 2, 2}, {1, 2, 1}, {}}); | |||||
| checker.execs({{1, 2, 2}, {1, 2, 2}, {1, 1, 1}, {}}); | |||||
| checker.execs({{3, 4, 1}, {3, 4, 1}, {3, 4, 1}, {}}); | |||||
| checker.execs({{3, 4, 5}, {1}, {1}, {}}); | |||||
| checker.execs({{1}, {3, 4, 5}, {1}, {}}); | |||||
| }; | |||||
| // case int | // case int | ||||
| checker.set_dtype(0, dtype::Int8()); | checker.set_dtype(0, dtype::Int8()); | ||||
| checker.set_dtype(1, dtype::Int8()); | checker.set_dtype(1, dtype::Int8()); | ||||
| checker.set_dtype(2, dtype::Int8()); | checker.set_dtype(2, dtype::Int8()); | ||||
| // BUILD_TERNARY_TEST_CASE | |||||
| BUILD_TERNARY_COMPLATE_TEST_CASE | |||||
| run(); | |||||
| checker.set_dtype(0, dtype::Int16()); | checker.set_dtype(0, dtype::Int16()); | ||||
| checker.set_dtype(1, dtype::Int16()); | checker.set_dtype(1, dtype::Int16()); | ||||
| checker.set_dtype(2, dtype::Int16()); | checker.set_dtype(2, dtype::Int16()); | ||||
| // BUILD_TERNARY_TEST_CASE | |||||
| BUILD_TERNARY_COMPLATE_TEST_CASE | |||||
| run(); | |||||
| checker.set_dtype(0, dtype::Int32()); | checker.set_dtype(0, dtype::Int32()); | ||||
| checker.set_dtype(1, dtype::Int32()); | checker.set_dtype(1, dtype::Int32()); | ||||
| checker.set_dtype(2, dtype::Int32()); | checker.set_dtype(2, dtype::Int32()); | ||||
| // BUILD_TERNARY_TEST_CASE | |||||
| BUILD_TERNARY_COMPLATE_TEST_CASE | |||||
| run(); | |||||
| // case float | // case float | ||||
| UniformFloatRNG rng(1e-5, 7e1); | UniformFloatRNG rng(1e-5, 7e1); | ||||
| @@ -78,9 +81,7 @@ TEST_F(ARM_COMMON, ELEMWISE_FORWARD_TERNARY) { | |||||
| checker.set_dtype(0, dtype::Float32()); | checker.set_dtype(0, dtype::Float32()); | ||||
| checker.set_dtype(1, dtype::Float32()); | checker.set_dtype(1, dtype::Float32()); | ||||
| checker.set_dtype(2, dtype::Float32()); | checker.set_dtype(2, dtype::Float32()); | ||||
| // BUILD_TERNARY_TEST_CASE | |||||
| BUILD_TERNARY_COMPLATE_TEST_CASE | |||||
| run(); | |||||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
| // case half | // case half | ||||
| @@ -90,9 +91,7 @@ TEST_F(ARM_COMMON, ELEMWISE_FORWARD_TERNARY) { | |||||
| checker.set_dtype(0, dtype::Float16()); | checker.set_dtype(0, dtype::Float16()); | ||||
| checker.set_dtype(1, dtype::Float16()); | checker.set_dtype(1, dtype::Float16()); | ||||
| checker.set_dtype(2, dtype::Float16()); | checker.set_dtype(2, dtype::Float16()); | ||||
| // BUILD_TERNARY_TEST_CASE | |||||
| BUILD_TERNARY_COMPLATE_TEST_CASE | |||||
| run(); | |||||
| #endif | #endif | ||||
| } | } | ||||
| @@ -214,6 +214,30 @@ TEST_F(ARM_COMMON, ELEMWISE_QUANTIZED_MODE_TERNARY) { | |||||
| using Mode = ElemwiseMultiType::Param::Mode; | using Mode = ElemwiseMultiType::Param::Mode; | ||||
| Checker<ElemwiseMultiType> checker(handle()); | Checker<ElemwiseMultiType> checker(handle()); | ||||
| auto run = [&]() { | |||||
| //! nchw44: BCAST101x4 + VEC + BCAST101x4 (plus one all-VEC case) | |||||
| checker.execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}}); | |||||
| checker.execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}}); | |||||
| checker.execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}}); | |||||
| checker.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}}); | |||||
| checker.execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}}); | |||||
| //! nchw44: VEC + BCAST101x4 + VEC (plus one all-VEC case) | |||||
| checker.execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}}); | |||||
| checker.execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}}); | |||||
| checker.execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}}); | |||||
| checker.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}}); | |||||
| checker.execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}}); | |||||
| checker.execs({{3, 4, 5, 6}, {3, 4, 5, 6}, {1, 1, 1, 1}, {}}); | |||||
| checker.execs({{1, 4, 1, 1}, {3, 4, 5, 6}, {1, 4, 1, 1}, {}}); | |||||
| checker.execs({{3}, {3}, {3}, {}}); | |||||
| checker.execs({{9}, {9}, {9}, {}}); | |||||
| checker.execs({{17}, {17}, {17}, {}}); | |||||
| checker.execs({{3, 4, 5, 6}, {3, 4, 5, 6}, {3, 4, 5, 6}, {}}); | |||||
| }; | |||||
| for (auto mode : {Mode::QFUSE_MUL_ADD3}) { | for (auto mode : {Mode::QFUSE_MUL_ADD3}) { | ||||
| checker.set_param({mode}); | checker.set_param({mode}); | ||||
| @@ -226,14 +250,7 @@ TEST_F(ARM_COMMON, ELEMWISE_QUANTIZED_MODE_TERNARY) { | |||||
| .set_dtype(1, dtype::QuantizedS8(1.15f)) | .set_dtype(1, dtype::QuantizedS8(1.15f)) | ||||
| .set_dtype(2, dtype::QuantizedS8(1.75f)) | .set_dtype(2, dtype::QuantizedS8(1.75f)) | ||||
| .set_dtype(3, dtype::QuantizedS8(1.35f)); | .set_dtype(3, dtype::QuantizedS8(1.35f)); | ||||
| checker.execs({{3, 4, 5, 6}, {3, 4, 5, 6}, {1, 1, 1, 1}, {}}); | |||||
| checker.execs({{1, 4, 1, 1}, {3, 4, 5, 6}, {1, 4, 1, 1}, {}}); | |||||
| checker.execs({{3}, {3}, {3}, {}}); | |||||
| checker.execs({{9}, {9}, {9}, {}}); | |||||
| checker.execs({{17}, {17}, {17}, {}}); | |||||
| checker.execs({{3, 4, 5, 6}, {3, 4, 5, 6}, {3, 4, 5, 6}, {}}); | |||||
| run(); | |||||
| // quint8 to quint8 | // quint8 to quint8 | ||||
| UniformIntRNG rng_uint8{0, 225}; | UniformIntRNG rng_uint8{0, 225}; | ||||
| @@ -248,14 +265,7 @@ TEST_F(ARM_COMMON, ELEMWISE_QUANTIZED_MODE_TERNARY) { | |||||
| static_cast<uint8_t>(128))) | static_cast<uint8_t>(128))) | ||||
| .set_dtype(3, dtype::Quantized8Asymm( | .set_dtype(3, dtype::Quantized8Asymm( | ||||
| 1.45f, static_cast<uint8_t>(128))); | 1.45f, static_cast<uint8_t>(128))); | ||||
| checker.execs({{3, 4, 5, 6}, {3, 4, 5, 6}, {1, 1, 1, 1}, {}}); | |||||
| checker.execs({{1, 4, 1, 1}, {3, 4, 5, 6}, {1, 4, 1, 1}, {}}); | |||||
| checker.execs({{3}, {3}, {3}, {}}); | |||||
| checker.execs({{9}, {9}, {9}, {}}); | |||||
| checker.execs({{17}, {17}, {17}, {}}); | |||||
| checker.execs({{3, 4, 5, 6}, {3, 4, 5, 6}, {3, 4, 5, 6}, {}}); | |||||
| run(); | |||||
| } | } | ||||
| } | } | ||||