GitOrigin-RevId: fb4300004c
tags/v1.7.0
| @@ -104,6 +104,21 @@ bool ElemwiseImpl::AlgoBinaryVecBcast101::is_available( | |||
| return false; | |||
| } | |||
| bool ElemwiseImpl::AlgoBinaryVecBcast111C::is_available( | |||
| const KernParam& kern_param) const { | |||
| if (!is_available_common(kern_param.mode) || | |||
| ((BcastType::VEC_BCAST111C != kern_param.broad_cast_type) && | |||
| (BcastType::BCAST111C_VEC != kern_param.broad_cast_type))) | |||
| return false; | |||
| auto& elparam = kern_param.binary_elparam; | |||
| auto& src0 = elparam[0]; | |||
| DISPATCH_TYPE("AlgoBinaryVecBcast111C::is_available"_hash); | |||
| return false; | |||
| } | |||
| bool ElemwiseImpl::AlgoBinaryVecBcast101xX::is_available( | |||
| const KernParam& kern_param) const { | |||
| if (!is_available_common(kern_param.mode) || | |||
| @@ -333,6 +348,72 @@ void ElemwiseImpl::AlgoBinaryVecBcast101::exec(const KernParam& kern_param) cons | |||
| return; | |||
| } | |||
| void ElemwiseImpl::AlgoBinaryVecBcast111C::exec(const KernParam& kern_param) const { | |||
| auto& elparam = kern_param.binary_elparam; | |||
| auto &src0 = elparam[0], &src1 = elparam[1]; | |||
| auto&& dst = *(kern_param.m_dst); | |||
| BroadcastChannelInfo binfo; | |||
| // Case extra: BcastType::VEC + BCAST_111C | |||
| if (BcastType::VEC_BCAST111C == kern_param.broad_cast_type && | |||
| is_NHWC_broadcasted_channel_like(src1.layout, binfo)) { | |||
| #define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \ | |||
| case Mode::_mode: \ | |||
| MIDOUT_BEGIN( \ | |||
| megdnn_arm_common_elemwise_binary, midout_iv(_case), \ | |||
| midout_iv(Mode::_mode), _type_midout_id) { \ | |||
| thin_function<void( \ | |||
| const _type*, const _type*, _type*, DType, DType, DType, size_t, \ | |||
| size_t, size_t)> \ | |||
| run = OpCallerBinary< \ | |||
| _op<_type, _type>, BcastType::VEC_BCAST111C>::run; \ | |||
| MEGDNN_DISPATCH_CPU_KERN( \ | |||
| static_cast<naive::HandleImpl*>(kern_param.handle), \ | |||
| run(static_cast<const _type*>(src0.raw_ptr), \ | |||
| static_cast<const _type*>(src1.raw_ptr), \ | |||
| static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ | |||
| src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \ | |||
| binfo.z)); \ | |||
| } \ | |||
| MIDOUT_END(); \ | |||
| return | |||
| DISPATCH_TYPE("AlgoBinaryVecBcast111C::exec_vec_b"_hash); | |||
| #undef DISPATCH_BINARY | |||
| } | |||
| // BCAST_111C + BcastType::VEC | |||
| if (BcastType::BCAST111C_VEC == kern_param.broad_cast_type && | |||
| is_NHWC_broadcasted_channel_like(src0.layout, binfo)) { | |||
| #define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \ | |||
| case Mode::_mode: \ | |||
| MIDOUT_BEGIN( \ | |||
| megdnn_arm_common_elemwise_binary, midout_iv(_case), \ | |||
| midout_iv(Mode::_mode), _type_midout_id) { \ | |||
| thin_function<void( \ | |||
| const _type*, const _type*, _type*, DType, DType, DType, size_t, \ | |||
| size_t, size_t)> \ | |||
| run = OpCallerBinary< \ | |||
| _op<_type, _type>, BcastType::BCAST111C_VEC>::run; \ | |||
| MEGDNN_DISPATCH_CPU_KERN( \ | |||
| static_cast<naive::HandleImpl*>(kern_param.handle), \ | |||
| run(static_cast<const _type*>(src0.raw_ptr), \ | |||
| static_cast<const _type*>(src1.raw_ptr), \ | |||
| static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ | |||
| src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \ | |||
| binfo.z)); \ | |||
| } \ | |||
| MIDOUT_END(); \ | |||
| return | |||
| DISPATCH_TYPE("AlgoBinaryVecBcast111C::exec_b_vec"_hash); | |||
| #undef DISPATCH_BINARY | |||
| } | |||
| return; | |||
| } | |||
| void ElemwiseImpl::AlgoBinaryVecBcast101xX::exec(const KernParam& kern_param) const { | |||
| auto& elparam = kern_param.binary_elparam; | |||
| auto &src0 = elparam[0], &src1 = elparam[1]; | |||
| @@ -33,6 +33,7 @@ namespace arm_common { | |||
| DECL_CB(VecVec); | |||
| DECL_CB(VecScalar); | |||
| DECL_CB(VecBcast101); | |||
| DECL_CB(VecBcast111C); | |||
| DECL_CB(VecBcast101xX); | |||
| #undef DECL_CB | |||
| } // namespace arm_common | |||
| @@ -27,12 +27,15 @@ class ElemwiseImpl::AlgoPack { | |||
| AlgoBinaryVecVec algo_binary_vec_vec; | |||
| AlgoBinaryVecScalar algo_binary_vec_sca; | |||
| AlgoBinaryVecBcast101 algo_binary_vec_bcast101; | |||
| AlgoBinaryVecBcast111C algo_binary_vec_bcast110; | |||
| AlgoBinaryVecBcast101xX algo_binary_VEC_BCAST101xX; | |||
| AlgoTernaryFma3VecVecVec algo_ternaryfma3_vec_vec_vec; | |||
| AlgoTernaryFma3VecVecScalar algo_ternaryfma3_vec_vecsca; | |||
| AlgoTernaryFma3Bcast101VecBcast101 algo_ternaryfma3_bcast101_vec_bcast101; | |||
| AlgoTernaryFma3Bcast111CVecBcast111C algo_ternaryfma3_bcast110_vec_bcast110; | |||
| AlgoTernaryFma3Bcast101xXVecBcast101xX algo_ternaryfma3_bcast101xX_vec_bcast101xX; | |||
| AlgoTernaryFma3VecBcast101Vec algo_ternaryfma3_vec_bcast101_vec; | |||
| AlgoTernaryFma3VecBcast111CVec algo_ternaryfma3_vec_bcast110_vec; | |||
| AlgoTernaryFma3VecBcast101xXVec algo_ternaryfma3_vec_bcast101xX_vec; | |||
| AlgoTernaryFma3VecScalarVec algo_ternaryfma3_vec_sca_vec; | |||
| AlgoTernaryFma3VecScalarScalar algo_ternaryfma3_vec_sca_sca; | |||
| @@ -43,12 +46,15 @@ public: | |||
| all_algos.emplace_back(&algo_binary_vec_vec); | |||
| all_algos.emplace_back(&algo_binary_vec_sca); | |||
| all_algos.emplace_back(&algo_binary_vec_bcast101); | |||
| all_algos.emplace_back(&algo_binary_vec_bcast110); | |||
| all_algos.emplace_back(&algo_binary_VEC_BCAST101xX); | |||
| all_algos.emplace_back(&algo_ternaryfma3_vec_vec_vec); | |||
| all_algos.emplace_back(&algo_ternaryfma3_vec_vecsca); | |||
| all_algos.emplace_back(&algo_ternaryfma3_bcast101_vec_bcast101); | |||
| all_algos.emplace_back(&algo_ternaryfma3_bcast110_vec_bcast110); | |||
| all_algos.emplace_back(&algo_ternaryfma3_bcast101xX_vec_bcast101xX); | |||
| all_algos.emplace_back(&algo_ternaryfma3_vec_bcast101_vec); | |||
| all_algos.emplace_back(&algo_ternaryfma3_vec_bcast110_vec); | |||
| all_algos.emplace_back(&algo_ternaryfma3_vec_bcast101xX_vec); | |||
| all_algos.emplace_back(&algo_ternaryfma3_vec_sca_vec); | |||
| all_algos.emplace_back(&algo_ternaryfma3_vec_sca_sca); | |||
| @@ -87,6 +93,14 @@ ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) { | |||
| kern_param.mode = opr->param().mode; | |||
| kern_param.handle = opr->handle(); | |||
| auto is_legal_layout_for_nhwc = [](const TensorLayout& l) { | |||
| if (is_vector(l)) | |||
| return true; | |||
| if (l.ndim == 2 && l.stride[1] == 1) | |||
| return true; | |||
| return false; | |||
| }; | |||
| if ((opr->m_src->size() == 3) && (opr->param().mode == Mode::FUSE_MUL_ADD3)) { | |||
| kern_param.ternary_elparam = opr->make_elemwise_op_param<3>(); | |||
| bool c_is_scalar; | |||
| @@ -127,6 +141,20 @@ ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) { | |||
| return kern_param; | |||
| } | |||
| if (is_legal_layout_for_nhwc(src1.layout) && | |||
| is_NHWC_broadcasted_channel_like(src0.layout, binfo) && | |||
| src0.layout.eq_layout(src2.layout)) { | |||
| kern_param.broad_cast_type = BcastType::BCAST111C_VEC_BCAST111C; | |||
| return kern_param; | |||
| } | |||
| if (is_legal_layout_for_nhwc(src0.layout) && | |||
| src2.layout.eq_layout(src0.layout) && | |||
| is_NHWC_broadcasted_channel_like(src1.layout, binfo)) { | |||
| kern_param.broad_cast_type = BcastType::VEC_BCAST111C_VEC; | |||
| return kern_param; | |||
| } | |||
| if (is_vector(src0.layout) && src0.layout.eq_layout(src2.layout) && | |||
| (is_broadcastedx_channel_like<4>(src1.layout, binfo) || | |||
| is_broadcastedx_channel_like<8>(src1.layout, binfo))) { | |||
| @@ -174,6 +202,18 @@ ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) { | |||
| return kern_param; | |||
| } | |||
| if (is_legal_layout_for_nhwc(src1.layout) && | |||
| is_NHWC_broadcasted_channel_like(src0.layout, binfo)) { | |||
| kern_param.broad_cast_type = BcastType::BCAST111C_VEC; | |||
| return kern_param; | |||
| } | |||
| if (is_legal_layout_for_nhwc(src0.layout) && | |||
| is_NHWC_broadcasted_channel_like(src1.layout, binfo)) { | |||
| kern_param.broad_cast_type = BcastType::VEC_BCAST111C; | |||
| return kern_param; | |||
| } | |||
| if (is_vector(src0.layout) && | |||
| (is_broadcastedx_channel_like<4>(src1.layout, binfo) || | |||
| is_broadcastedx_channel_like<8>(src1.layout, binfo))) { | |||
| @@ -38,12 +38,15 @@ private: | |||
| class AlgoBinaryVecVec; | |||
| class AlgoBinaryVecScalar; | |||
| class AlgoBinaryVecBcast101; | |||
| class AlgoBinaryVecBcast111C; | |||
| class AlgoBinaryVecBcast101xX; | |||
| class AlgoTernaryFma3VecVecVec; | |||
| class AlgoTernaryFma3VecVecScalar; | |||
| class AlgoTernaryFma3Bcast101VecBcast101; | |||
| class AlgoTernaryFma3Bcast111CVecBcast111C; | |||
| class AlgoTernaryFma3Bcast101xXVecBcast101xX; | |||
| class AlgoTernaryFma3VecBcast101Vec; | |||
| class AlgoTernaryFma3VecBcast111CVec; | |||
| class AlgoTernaryFma3VecBcast101xXVec; | |||
| class AlgoTernaryFma3VecScalarVec; | |||
| class AlgoTernaryFma3VecScalarScalar; | |||
| @@ -42,8 +42,10 @@ using namespace arm_common; | |||
| DECL_AVAILABLE(VecVecVec, BcastType::VEC_VEC_VEC); | |||
| DECL_AVAILABLE(VecVecScalar, BcastType::VEC_VEC_SCALAR); | |||
| DECL_AVAILABLE(Bcast101VecBcast101, BcastType::BCAST101_VEC_BCAST101); | |||
| DECL_AVAILABLE(Bcast111CVecBcast111C, BcastType::BCAST111C_VEC_BCAST111C); | |||
| DECL_AVAILABLE(Bcast101xXVecBcast101xX, BcastType::BCAST101xX_VEC_BCAST101xX); | |||
| DECL_AVAILABLE(VecBcast101Vec, BcastType::VEC_BCAST101_VEC); | |||
| DECL_AVAILABLE(VecBcast111CVec, BcastType::VEC_BCAST111C_VEC); | |||
| DECL_AVAILABLE(VecBcast101xXVec, BcastType::VEC_BCAST101xX_VEC); | |||
| DECL_AVAILABLE(VecScalarVec, BcastType::VEC_SCALAR_VEC); | |||
| DECL_AVAILABLE(VecScalarScalar, BcastType::VEC_SCALAR_SCALAR); | |||
| @@ -164,6 +166,45 @@ void ElemwiseImpl::AlgoTernaryFma3Bcast101VecBcast101::exec( | |||
| return; | |||
| } | |||
| void ElemwiseImpl::AlgoTernaryFma3Bcast111CVecBcast111C::exec( | |||
| const KernParam& kern_param) const { | |||
| auto& elparam = kern_param.ternary_elparam; | |||
| auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2]; | |||
| // Case 3: shape of src0 and src2 is {1, 1, 1, C} | |||
| BroadcastChannelInfo binfo; | |||
| is_NHWC_broadcasted_channel_like(src0.layout, binfo); | |||
| #define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ | |||
| case Mode::_mode: \ | |||
| MIDOUT_BEGIN( \ | |||
| megdnn_arm_common_elemwise_ternary, midout_iv(_case), \ | |||
| midout_iv(Mode::_mode), _type_midout_id) { \ | |||
| thin_function<void( \ | |||
| const _type*, const _type*, size_t, const _type*, _type*, DType, \ | |||
| DType, DType, DType, size_t, size_t, size_t)> \ | |||
| run = OpCallerTernary< \ | |||
| _op<_type, _type>, \ | |||
| BcastType::BCAST111C_VEC_BCAST111C>::run; \ | |||
| MEGDNN_DISPATCH_CPU_KERN( \ | |||
| static_cast<naive::HandleImpl*>(kern_param.handle), \ | |||
| run(static_cast<const _type*>(src0.raw_ptr), \ | |||
| static_cast<const _type*>(src1.raw_ptr), \ | |||
| is_vector(src1.layout) ? 0 : src1.layout.stride[0] - binfo.z, \ | |||
| static_cast<const _type*>(src2.raw_ptr), \ | |||
| static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ | |||
| src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \ | |||
| binfo.x, binfo.y, binfo.z)); \ | |||
| } \ | |||
| MIDOUT_END(); \ | |||
| return | |||
| auto&& dst = *(kern_param.m_dst); | |||
| DISPATCH_TYPE("AlgoTernaryFma3Bcast111CVecBcast111C::exec"_hash); | |||
| #undef DISPATCH_TERNARY | |||
| return; | |||
| } | |||
| void ElemwiseImpl::AlgoTernaryFma3Bcast101xXVecBcast101xX::exec( | |||
| const KernParam& kern_param) const { | |||
| auto& elparam = kern_param.ternary_elparam; | |||
| @@ -282,6 +323,45 @@ void ElemwiseImpl::AlgoTernaryFma3VecBcast101Vec::exec( | |||
| return; | |||
| } | |||
| void ElemwiseImpl::AlgoTernaryFma3VecBcast111CVec::exec( | |||
| const KernParam& kern_param) const { | |||
| auto& elparam = kern_param.ternary_elparam; | |||
| auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2]; | |||
| // Case 4: shape of src1 is {1, 1, 1, C}, and src0 and src2 are contig | |||
| BroadcastChannelInfo binfo; | |||
| is_NHWC_broadcasted_channel_like(src1.layout, binfo); | |||
| #define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ | |||
| case Mode::_mode: \ | |||
| MIDOUT_BEGIN( \ | |||
| megdnn_arm_common_elemwise_ternary, midout_iv(_case), \ | |||
| midout_iv(Mode::_mode), _type_midout_id) { \ | |||
| thin_function<void( \ | |||
| const _type*, size_t, const _type*, const _type*, size_t, _type*, \ | |||
| DType, DType, DType, DType, size_t, size_t, size_t)> \ | |||
| run = OpCallerTernary< \ | |||
| _op<_type, _type>, BcastType::VEC_BCAST111C_VEC>::run; \ | |||
| MEGDNN_DISPATCH_CPU_KERN( \ | |||
| static_cast<naive::HandleImpl*>(kern_param.handle), \ | |||
| run(static_cast<const _type*>(src0.raw_ptr), \ | |||
| is_vector(src0.layout) ? 0 : src0.layout.stride[0] - binfo.z, \ | |||
| static_cast<const _type*>(src1.raw_ptr), \ | |||
| static_cast<const _type*>(src2.raw_ptr), \ | |||
| is_vector(src2.layout) ? 0 : src2.layout.stride[0] - binfo.z, \ | |||
| static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ | |||
| src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \ | |||
| binfo.x, binfo.y, binfo.z)); \ | |||
| } \ | |||
| MIDOUT_END(); \ | |||
| return | |||
| auto&& dst = *(kern_param.m_dst); | |||
| DISPATCH_TYPE("AlgoTernaryFma3VecBcast111CVec::exec"_hash); | |||
| #undef DISPATCH_TERNARY | |||
| return; | |||
| } | |||
| void ElemwiseImpl::AlgoTernaryFma3VecScalarVec::exec( | |||
| const KernParam& kern_param) const { | |||
| auto& elparam = kern_param.ternary_elparam; | |||
| @@ -33,8 +33,10 @@ namespace arm_common { | |||
| DECL_CB(VecVecVec); | |||
| DECL_CB(VecVecScalar); | |||
| DECL_CB(Bcast101VecBcast101); | |||
| DECL_CB(Bcast111CVecBcast111C); | |||
| DECL_CB(Bcast101xXVecBcast101xX); | |||
| DECL_CB(VecBcast101Vec); | |||
| DECL_CB(VecBcast111CVec); | |||
| DECL_CB(VecBcast101xXVec); | |||
| DECL_CB(VecScalarVec); | |||
| DECL_CB(VecScalarScalar); | |||
| @@ -107,16 +107,20 @@ enum BcastType { | |||
| VEC, | |||
| VEC_VEC, | |||
| VEC_BCAST101, | |||
| VEC_BCAST111C, | |||
| VEC_BCAST101xX, | |||
| VEC_SCALAR, | |||
| SCALAR_VEC, | |||
| BCAST101_VEC, | |||
| BCAST111C_VEC, | |||
| BCAST101xX_VEC, | |||
| VEC_VEC_VEC, | |||
| VEC_VEC_SCALAR, | |||
| BCAST101_VEC_BCAST101, | |||
| BCAST111C_VEC_BCAST111C, | |||
| BCAST101xX_VEC_BCAST101xX, | |||
| VEC_BCAST101_VEC, | |||
| VEC_BCAST111C_VEC, | |||
| VEC_BCAST101xX_VEC, | |||
| VEC_SCALAR_VEC, | |||
| VEC_SCALAR_SCALAR, | |||
| @@ -226,6 +230,60 @@ struct OpCallerBinary<PowOp<ctype, ctype>, VEC_BCAST101> { | |||
| } | |||
| }; | |||
| template <typename ctype> | |||
| struct OpCallerBinary<PowOp<ctype, ctype>, VEC_BCAST111C> { | |||
| using Op = PowOp<ctype, ctype>; | |||
| static void run( | |||
| const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, | |||
| typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, | |||
| DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { | |||
| Op op(src0_dtype, src1_dtype, dst_dtype); | |||
| for (size_t b = 0; b < batch; b++) { | |||
| for (size_t c = 0; c < channel; c++) { | |||
| size_t i = 0; | |||
| const typename Op::src_ctype* src1_ptr = src1; | |||
| #if MEGDNN_FIX_AARCH32_BUG | |||
| // FIXME: as llvm may cause cannot select error if enable vectorize | |||
| #pragma clang loop vectorize(disable) | |||
| #endif | |||
| for (; i < channel_stride; i++) { | |||
| op(*src0, *src1_ptr, dst); | |||
| src0++; | |||
| src1_ptr++; | |||
| dst++; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| }; | |||
| template <typename ctype> | |||
| struct OpCallerBinary<PowOp<ctype, ctype>, BCAST111C_VEC> { | |||
| using Op = PowOp<ctype, ctype>; | |||
| static void run( | |||
| const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, | |||
| typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, | |||
| DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { | |||
| Op op(src0_dtype, src1_dtype, dst_dtype); | |||
| for (size_t b = 0; b < batch; b++) { | |||
| for (size_t c = 0; c < channel; c++) { | |||
| size_t i = 0; | |||
| const typename Op::src_ctype* src0_ptr = src0; | |||
| #if MEGDNN_FIX_AARCH32_BUG | |||
| // FIXME: as llvm may cause cannot select error if enable vectorize | |||
| #pragma clang loop vectorize(disable) | |||
| #endif | |||
| for (; i < channel_stride; i++) { | |||
| op(*src0_ptr, *src1, dst); | |||
| src0_ptr++; | |||
| src1++; | |||
| dst++; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| }; | |||
| template <typename ctype> | |||
| struct OpCallerBinary<PowOp<ctype, ctype>, SCALAR_VEC> { | |||
| using Op = PowOp<ctype, ctype>; | |||
| @@ -340,6 +398,84 @@ struct OpCallerBinary<Op, VEC_BCAST101> { | |||
| } | |||
| }; | |||
| template <typename Op> | |||
| struct OpCallerBinary<Op, VEC_BCAST111C> { | |||
| static void run( | |||
| const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, | |||
| typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, | |||
| DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { | |||
| Op op(src0_dtype, src1_dtype, dst_dtype); | |||
| ParamElemVisitor<typename Op::src_ctype> vis; | |||
| for (size_t b = 0; b < batch; b++) { | |||
| for (size_t c = 0; c < channel; c++) { | |||
| size_t rest = channel_stride; | |||
| const typename Op::src_ctype* src1_ptr = src1; | |||
| while (rest >= Op::SIMD_WIDTH * 2) { | |||
| auto src0_neon0 = vis(src0); | |||
| auto src0_neon1 = vis(src0 + Op::SIMD_WIDTH); | |||
| auto src1_neon0 = vis(src1_ptr); | |||
| auto src1_neon1 = vis(src1_ptr + Op::SIMD_WIDTH); | |||
| src0 += Op::SIMD_WIDTH * 2; | |||
| src1_ptr += Op::SIMD_WIDTH * 2; | |||
| op({{src0_neon0, src0_neon1}}, {{src1_neon0, src1_neon1}}, dst); | |||
| dst += Op::SIMD_WIDTH * 2; | |||
| rest -= Op::SIMD_WIDTH * 2; | |||
| } | |||
| #if MEGDNN_FIX_AARCH32_BUG | |||
| // FIXME: as llvm may cause cannot select error if enable vectorize | |||
| #pragma clang loop vectorize(disable) | |||
| #endif | |||
| while (rest > 0) { | |||
| op(*src0, *src1_ptr, dst); | |||
| dst++; | |||
| src0++; | |||
| src1_ptr++; | |||
| rest--; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| }; | |||
| template <typename Op> | |||
| struct OpCallerBinary<Op, BCAST111C_VEC> { | |||
| static void run( | |||
| const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, | |||
| typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, | |||
| DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { | |||
| Op op(src0_dtype, src1_dtype, dst_dtype); | |||
| ParamElemVisitor<typename Op::src_ctype> vis; | |||
| for (size_t b = 0; b < batch; b++) { | |||
| for (size_t c = 0; c < channel; c++) { | |||
| size_t rest = channel_stride; | |||
| const typename Op::src_ctype* src0_ptr = src0; | |||
| while (rest >= Op::SIMD_WIDTH * 2) { | |||
| auto src0_neon0 = vis(src0_ptr); | |||
| auto src0_neon1 = vis(src0_ptr + Op::SIMD_WIDTH); | |||
| auto src1_neon0 = vis(src1); | |||
| auto src1_neon1 = vis(src1 + Op::SIMD_WIDTH); | |||
| src0_ptr += Op::SIMD_WIDTH * 2; | |||
| src1 += Op::SIMD_WIDTH * 2; | |||
| op({{src0_neon0, src0_neon1}}, {{src1_neon0, src1_neon1}}, dst); | |||
| dst += Op::SIMD_WIDTH * 2; | |||
| rest -= Op::SIMD_WIDTH * 2; | |||
| } | |||
| #if MEGDNN_FIX_AARCH32_BUG | |||
| // FIXME: as llvm may cause cannot select error if enable vectorize | |||
| #pragma clang loop vectorize(disable) | |||
| #endif | |||
| while (rest > 0) { | |||
| op(*src0_ptr, *src1, dst); | |||
| dst++; | |||
| src0_ptr++; | |||
| src1++; | |||
| rest--; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| }; | |||
| template <typename ctype> | |||
| struct OpCallerBinary<PowOp<ctype, ctype>, BCAST101xX_VEC> { | |||
| using Op = PowOp<ctype, ctype>; | |||
| @@ -824,6 +960,54 @@ struct OpCallerTernary<Op, BCAST101_VEC_BCAST101> { | |||
| } | |||
| }; | |||
| //! src0: 111C, src1: vector, src2: 111C, src1 may not be contig | |||
| template <typename Op> | |||
| struct OpCallerTernary<Op, BCAST111C_VEC_BCAST111C> { | |||
| static void run( | |||
| const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, | |||
| size_t src1_offset, const typename Op::src_ctype* src2, | |||
| typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, | |||
| DType src2_dtype, DType dst_dtype, size_t batch_size, size_t channel_size, | |||
| size_t channel_stride) { | |||
| Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); | |||
| ParamElemVisitor<typename Op::src_ctype> vis; | |||
| for (size_t batch = 0; batch < batch_size; batch++) { | |||
| for (size_t channel = 0; channel < channel_size; channel++) { | |||
| auto src0_ptr = src0; | |||
| auto src2_ptr = src2; | |||
| size_t i = 0; | |||
| for (; i + Op::SIMD_WIDTH * 2 <= channel_stride; | |||
| i += Op::SIMD_WIDTH * 2) { | |||
| auto src0_neon0 = vis(src0_ptr); | |||
| auto src0_neon1 = vis(src0_ptr + Op::SIMD_WIDTH); | |||
| auto src1_neon0 = vis(src1); | |||
| auto src1_neon1 = vis(src1 + Op::SIMD_WIDTH); | |||
| auto src2_neon0 = vis(src2_ptr); | |||
| auto src2_neon1 = vis(src2_ptr + Op::SIMD_WIDTH); | |||
| op({{src0_neon0, src0_neon1}}, {{src1_neon0, src1_neon1}}, | |||
| {{src2_neon0, src2_neon1}}, dst); | |||
| src0_ptr += Op::SIMD_WIDTH * 2; | |||
| src1 += Op::SIMD_WIDTH * 2; | |||
| src2_ptr += Op::SIMD_WIDTH * 2; | |||
| dst += Op::SIMD_WIDTH * 2; | |||
| } | |||
| #if MEGDNN_FIX_AARCH32_BUG | |||
| // FIXME: as llvm may cause cannot select error if enable vectorize | |||
| #pragma clang loop vectorize(disable) | |||
| #endif | |||
| for (; i < channel_stride; i++) { | |||
| op(*src0_ptr, *src1, *src2_ptr, dst); | |||
| src0_ptr++; | |||
| src1++; | |||
| src2_ptr++; | |||
| dst++; | |||
| } | |||
| src1 += src1_offset; | |||
| } | |||
| } | |||
| } | |||
| }; | |||
| template <typename src_ctype, size_t channel_block_dim> | |||
| struct OpCallerTernaryBcast101xXVecBcast101xX { | |||
| template <typename Op> | |||
| @@ -992,6 +1176,51 @@ struct OpCallerTernary<Op, VEC_BCAST101_VEC> { | |||
| } | |||
| }; | |||
| //! src1: 111C, src0 and src2 may not be contig | |||
| template <typename Op> | |||
| struct OpCallerTernary<Op, VEC_BCAST111C_VEC> { | |||
| static void run( | |||
| const typename Op::src_ctype* src0, size_t src0_offset, | |||
| const typename Op::src_ctype* src1, const typename Op::src_ctype* src2, | |||
| size_t src2_offset, typename Op::dst_ctype* dst, DType src0_dtype, | |||
| DType src1_dtype, DType src2_dtype, DType dst_dtype, size_t batch_size, | |||
| size_t channel_size, size_t channel_stride) { | |||
| Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); | |||
| ParamElemVisitor<typename Op::src_ctype> vis0; | |||
| ParamElemVisitor<typename Op::src_ctype> vis1; | |||
| ParamElemVisitor<typename Op::src_ctype> vis2; | |||
| for (size_t batch = 0; batch < batch_size; batch++) { | |||
| for (size_t channel = 0; channel < channel_size; channel++) { | |||
| auto src1_ptr = src1; | |||
| size_t i = 0; | |||
| for (; i + Op::SIMD_WIDTH * 2 <= channel_stride; | |||
| i += Op::SIMD_WIDTH * 2) { | |||
| op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, | |||
| {{vis1(src1_ptr), vis1(src1_ptr + Op::SIMD_WIDTH)}}, | |||
| {{vis2(src2), vis2(src2 + Op::SIMD_WIDTH)}}, dst); | |||
| src0 += Op::SIMD_WIDTH * 2; | |||
| src1_ptr += Op::SIMD_WIDTH * 2; | |||
| src2 += Op::SIMD_WIDTH * 2; | |||
| dst += Op::SIMD_WIDTH * 2; | |||
| } | |||
| #if MEGDNN_FIX_AARCH32_BUG | |||
| // FIXME: as llvm may cause cannot select error if enable vectorize | |||
| #pragma clang loop vectorize(disable) | |||
| #endif | |||
| for (; i < channel_stride; i++) { | |||
| op(*src0, *src1_ptr, *src2, dst); | |||
| src0++; | |||
| src1_ptr++; | |||
| src2++; | |||
| dst++; | |||
| } | |||
| src0 += src0_offset; | |||
| src2 += src2_offset; | |||
| } | |||
| } | |||
| } | |||
| }; | |||
| template <typename src_ctype, size_t channel_block_dim> | |||
| struct OpCallerTernaryVecBcast101xXVec { | |||
| template <typename Op> | |||
| @@ -50,6 +50,20 @@ inline dt_qint32 QConverter::convert(const float& src) { | |||
| saturate<int32_t, float>(std::round(src), -2147483648, 2147483647)); | |||
| } | |||
| template <> | |||
| inline float32x4x2_t QConverter::convert(const int16x8_t& vsrc) { | |||
| int32x4_t vhi = vmovl_s16(vget_high_s16(vsrc)); | |||
| int32x4_t vlo = vmovl_s16(vget_low_s16(vsrc)); | |||
| return {{vcvtq_f32_s32(vlo), vcvtq_f32_s32(vhi)}}; | |||
| } | |||
| template <> | |||
| inline float32x4x2_t QConverter::convert(const uint16x8_t& vsrc) { | |||
| uint32x4_t vhi = vmovl_u16(vget_high_u16(vsrc)); | |||
| uint32x4_t vlo = vmovl_u16(vget_low_u16(vsrc)); | |||
| return {{vcvtq_f32_u32(vlo), vcvtq_f32_u32(vhi)}}; | |||
| } | |||
| #if __ARM_ARCH >= 8 | |||
| template <> | |||
| inline int8x8_t QConverter::convert(const float32x4x2_t& vsrc) { | |||
| @@ -17,6 +17,7 @@ | |||
| #include "src/common/utils.h" | |||
| #include "src/naive/handle.h" | |||
| MIDOUT_DECL(megdnn_arm_typecvt_fix2float) | |||
| MIDOUT_DECL(megdnn_arm_typecvt_quantized) | |||
| MIDOUT_DECL(megdnn_arm_typecvt_float) | |||
| @@ -325,6 +326,48 @@ struct FloatTypeCvter<float, __fp16> { | |||
| }; | |||
| #endif | |||
| template <typename ctype, typename dtype> | |||
| struct Fix2FloatTypeCvter; | |||
| template <> | |||
| struct Fix2FloatTypeCvter<int16_t, float> { | |||
| using stype = int16_t; | |||
| using dst_type = float; | |||
| static constexpr size_t SIMD_WIDTH = 8; | |||
| Fix2FloatTypeCvter(DType src_dtype, DType dst_dtype) { | |||
| MEGDNN_MARK_USED_VAR(src_dtype); | |||
| MEGDNN_MARK_USED_VAR(dst_dtype); | |||
| } | |||
| void cvt(const int16_t* src, float* dst) { | |||
| int16x8_t vitem = vld1q_s16(src); | |||
| auto vres = QConverter::convert<float32x4x2_t, int16x8_t>(vitem); | |||
| vst1q_f32_x2(dst, vres); | |||
| } | |||
| void cvt_remain(const int16_t* src, float* dst) { *dst = *src; } | |||
| }; | |||
| template <> | |||
| struct Fix2FloatTypeCvter<uint16_t, float> { | |||
| using stype = uint16_t; | |||
| using dst_type = float; | |||
| static constexpr size_t SIMD_WIDTH = 8; | |||
| Fix2FloatTypeCvter(DType src_dtype, DType dst_dtype) { | |||
| MEGDNN_MARK_USED_VAR(src_dtype); | |||
| MEGDNN_MARK_USED_VAR(dst_dtype); | |||
| } | |||
| void cvt(const uint16_t* src, float* dst) { | |||
| uint16x8_t vitem = vld1q_u16(src); | |||
| auto vres = QConverter::convert<float32x4x2_t, uint16x8_t>(vitem); | |||
| vst1q_f32_x2(dst, vres); | |||
| } | |||
| void cvt_remain(const uint16_t* src, float* dst) { *dst = *src; } | |||
| }; | |||
| template <typename TypeCvter> | |||
| void do_typecvt( | |||
| const typename TypeCvter::stype* src, typename TypeCvter::dst_type* dst, | |||
| @@ -347,6 +390,43 @@ void do_typecvt( | |||
| } | |||
| } | |||
| template <typename TypeCvter> | |||
| void do_typecvt( | |||
| const typename TypeCvter::stype* src, typename TypeCvter::dst_type* dst, | |||
| DType src_dtype, DType dst_dtype, const TensorLayout& src_layout) { | |||
| TypeCvter typecvt(src_dtype, dst_dtype); | |||
| size_t calc_num = 1; | |||
| size_t nr_elems = src_layout.total_nr_elems(); | |||
| size_t src_stride = nr_elems; | |||
| //! adjust calc_num nr_elems and src_stride according to src_collapse_layout | |||
| auto src_collapse_layout = src_layout.collapse_contiguous(); | |||
| if (src_collapse_layout.ndim == 2) { | |||
| calc_num = src_collapse_layout.shape[0]; | |||
| nr_elems = src_collapse_layout.shape[1]; | |||
| src_stride = src_collapse_layout.stride[0]; | |||
| } | |||
| for (size_t c = 0; c < calc_num; ++c) { | |||
| size_t i = 0; | |||
| for (; i + TypeCvter::SIMD_WIDTH <= nr_elems; i += TypeCvter::SIMD_WIDTH) { | |||
| typecvt.cvt(src, dst); | |||
| src += TypeCvter::SIMD_WIDTH; | |||
| dst += TypeCvter::SIMD_WIDTH; | |||
| } | |||
| #if MEGDNN_FIX_AARCH32_BUG | |||
| // FIXME: as llvm may cause cannot select error if enable vectorize | |||
| #pragma clang loop vectorize(disable) | |||
| #endif | |||
| for (; i < nr_elems; i++) { | |||
| typecvt.cvt_remain(src, dst); | |||
| src++; | |||
| dst++; | |||
| } | |||
| src += src_stride - nr_elems; | |||
| } | |||
| } | |||
| } // anonymous namespace | |||
| void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { | |||
| @@ -354,7 +434,30 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { | |||
| DType dst_dtype = dst.layout.dtype; | |||
| size_t nr_elems = src.layout.total_nr_elems(); | |||
| bool execed = false; | |||
| if (src.layout.is_contiguous()) { | |||
| auto src_collapse_layout = src.layout.collapse_contiguous(); | |||
| bool has_int16_special_impl = | |||
| (src.layout.dtype.enumv() == DTypeEnum::Int16 || | |||
| src.layout.dtype.enumv() == DTypeEnum::Uint16) && | |||
| (src.layout.is_contiguous() || src_collapse_layout.ndim == 2) && | |||
| dst.layout.is_contiguous(); | |||
| if (has_int16_special_impl) { | |||
| using namespace dtype; | |||
| #define DISPATCH_FIX2FLOAT(_stype_enumv, _stype, _dtype_enumv, _dtype, _midout_iv) \ | |||
| if (src_dtype.enumv() == DTypeTrait<_stype_enumv>::enumv && \ | |||
| dst_dtype.enumv() == DTypeTrait<_dtype_enumv>::enumv) { \ | |||
| MIDOUT_BEGIN(megdnn_arm_typecvt_fix2float, midout_iv(_midout_iv)) { \ | |||
| using _TypeCvter = Fix2FloatTypeCvter<_stype, _dtype>; \ | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR(do_typecvt<_TypeCvter>( \ | |||
| src.compatible_ptr<_stype>(), dst.compatible_ptr<_dtype>(), \ | |||
| src_dtype, dst_dtype, src.layout)); \ | |||
| execed = true; \ | |||
| } \ | |||
| MIDOUT_END(); \ | |||
| } | |||
| DISPATCH_FIX2FLOAT(Int16, int16_t, Float32, float, 0); | |||
| DISPATCH_FIX2FLOAT(Uint16, uint16_t, Float32, float, 1); | |||
| #undef DISPATCH_FIX2FLOAT | |||
| } else if (src.layout.is_contiguous()) { | |||
| using namespace dtype; | |||
| #define DISPATCH_QUANTIZED(_stype_enumv, _stype, _dtype_enumv, _dtype, _midout_iv) \ | |||
| if (src_dtype.enumv() == DTypeTrait<_stype_enumv>::enumv && \ | |||
| @@ -377,6 +480,7 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { | |||
| DISPATCH_QUANTIZED(QuantizedS32, int32_t, QuantizedS32, int32_t, 5); | |||
| DISPATCH_QUANTIZED(float, float, QuantizedS8, int8_t, 6); | |||
| DISPATCH_QUANTIZED(float, float, Quantized8Asymm, uint8_t, 7); | |||
| #undef DISPATCH_QUANTIZED | |||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
| #define DISPATCH_FLOAT(_stype_enumv, _stype, _dtype_enumv, _dtype, _midout_iv) \ | |||
| @@ -394,6 +498,7 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { | |||
| } | |||
| DISPATCH_FLOAT(dt_float16, __fp16, float, float, 0); | |||
| DISPATCH_FLOAT(float, float, dt_float16, __fp16, 1); | |||
| #undef DISPATCH_FLOAT | |||
| #endif | |||
| } | |||
| if (!execed) { | |||
| @@ -150,6 +150,19 @@ bool ElemwiseLayoutHelper::is_broadcasted_channel_like( | |||
| return false; | |||
| } | |||
| bool ElemwiseLayoutHelper::is_NHWC_broadcasted_channel_like( | |||
| const TensorLayout& layout, BroadcastChannelInfo& info) { | |||
| if (layout.format.type() == TensorFormat::Type::DEFAULT) { | |||
| if (layout.ndim == 2 && layout.stride[1] == 1 && layout.stride[0] == 0) { | |||
| info.x = 1; | |||
| info.y = layout.shape[0]; | |||
| info.z = layout.shape[1]; | |||
| return true; | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| bool ElemwiseLayoutHelper::is_broadcasted_1x( | |||
| const TensorLayout& layout, Broadcast1xInfo& binfo) { | |||
| if (layout.ndim == 2 && layout.stride[0] == 0 && layout.stride[1] == 1) { | |||
| @@ -80,6 +80,16 @@ public: | |||
| static bool is_broadcasted_channel_like( | |||
| const TensorLayout& layout, BroadcastChannelInfo& info); | |||
| /*! | |||
| * \brief check whether layout matches BroadcastChannelInfo under NHWC | |||
| * layout | |||
| * | |||
| * Note that Input must be 2-dimensional, and must be [1, y] broadacsted | |||
| * into [z, y] and x would be set to 1. | |||
| */ | |||
| static bool is_NHWC_broadcasted_channel_like( | |||
| const TensorLayout& layout, BroadcastChannelInfo& info); | |||
| /*! | |||
| * \brief check whether layout matches BroadcastChannelInfo | |||
| * | |||
| @@ -309,7 +309,8 @@ void on_dest_ctype(_megdnn_tensor_in src, _megdnn_tensor_out dst) { | |||
| break; \ | |||
| } | |||
| MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | |||
| cb(::megdnn::dtype::Bool) case DTypeEnum::QuantizedS8 | |||
| cb(::megdnn::dtype::Bool) | |||
| cb(::megdnn::dtype::Uint16) case DTypeEnum::QuantizedS8 | |||
| : MIDOUT_BEGIN( | |||
| megdnn_fb_typecvt_src_dtype, | |||
| midout_iv(DTypeEnum::QuantizedS8)) { | |||
| @@ -467,7 +468,8 @@ void run_contiguous(_megdnn_tensor_in src, _megdnn_tensor_out dst) { | |||
| } | |||
| MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | |||
| cb(::megdnn::dtype::Bool) case DTypeEnum::QuantizedS8 | |||
| cb(::megdnn::dtype::Bool) | |||
| cb(::megdnn::dtype::Uint16) case DTypeEnum::QuantizedS8 | |||
| : MIDOUT_BEGIN( | |||
| megdnn_fb_typecvt_dst_dtype, | |||
| midout_iv(DTypeEnum::QuantizedS8)) { | |||
| @@ -78,7 +78,7 @@ void on_dest_ctype(HandleImpl* handle, const TensorND& dest, const TensorND& src | |||
| MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | |||
| MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) | |||
| MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) | |||
| cb(::megdnn::dtype::Bool) | |||
| cb(::megdnn::dtype::Bool) cb(::megdnn::dtype::Uint16) | |||
| #undef cb | |||
| default : megdnn_throw("bad dtype"); | |||
| } | |||
| @@ -99,7 +99,7 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { | |||
| MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | |||
| MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) | |||
| MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) | |||
| cb(::megdnn::dtype::Bool) | |||
| cb(::megdnn::dtype::Bool) cb(::megdnn::dtype::Uint16) | |||
| #undef cb | |||
| default : megdnn_throw("bad dtype"); | |||
| } | |||
| @@ -14,6 +14,7 @@ | |||
| #include "test/common/benchmarker.h" | |||
| #include "test/common/checker.h" | |||
| #include "megdnn/opr_param_defs.h" | |||
| #include "megdnn/oprs/general.h" | |||
| using namespace megdnn; | |||
| @@ -298,6 +299,63 @@ TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NCHW88_FP) { | |||
| #endif | |||
| } | |||
| TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NHWC_FP32_BCAST) { | |||
| using Mode = ElemwiseForward::Param::Mode; | |||
| Checker<ElemwiseForward> checker(handle()); | |||
| UniformFloatRNG rng(1e-5, 7e1); | |||
| checker.set_rng(0, &rng); | |||
| checker.set_epsilon(1e-5); | |||
| checker.set_dtype(0, dtype::Float32()); | |||
| checker.set_dtype(1, dtype::Float32()); | |||
| //! 2 dim | |||
| auto run = [&](Mode mode) { | |||
| // VEC_BCAST111C | |||
| checker.set_param(mode).execs({{1, 2, 2, 12}, {1, 1, 1, 12}, {}}); | |||
| checker.set_param(mode).execs({{2, 5, 3, 28}, {1, 1, 1, 28}, {}}); | |||
| checker.set_param(mode).execs({{3, 5, 8, 32}, {1, 1, 1, 32}, {}}); | |||
| // BCAST111C_VEC | |||
| checker.set_param(mode).execs({{1, 1, 1, 12}, {1, 2, 2, 12}, {}}); | |||
| checker.set_param(mode).execs({{1, 1, 1, 28}, {2, 5, 3, 28}, {}}); | |||
| checker.set_param(mode).execs({{1, 1, 1, 32}, {3, 5, 8, 32}, {}}); | |||
| }; | |||
| run(Mode::ADD); | |||
| run(Mode::MUL); | |||
| run(Mode::SUB); | |||
| //! 3 dim contig | |||
| auto run_3d_contig = [&](Mode mode) { | |||
| // BCAST111C_VEC_BCAST111C | |||
| checker.set_param(mode).execs( | |||
| {{1, 1, 1, 12}, {1, 2, 2, 12}, {1, 1, 1, 12}, {}}); | |||
| checker.set_param(mode).execs( | |||
| {{1, 1, 1, 28}, {2, 5, 3, 28}, {1, 1, 1, 28}, {}}); | |||
| checker.set_param(mode).execs( | |||
| {{1, 1, 1, 32}, {3, 5, 8, 32}, {1, 1, 1, 32}, {}}); | |||
| // VEC_BCAST111C_VEC | |||
| checker.set_param(mode).execs( | |||
| {{1, 2, 2, 12}, {1, 1, 1, 12}, {1, 2, 2, 12}, {}}); | |||
| checker.set_param(mode).execs( | |||
| {{2, 5, 3, 28}, {1, 1, 1, 28}, {2, 5, 3, 28}, {}}); | |||
| checker.set_param(mode).execs( | |||
| {{3, 5, 8, 32}, {1, 1, 1, 32}, {3, 5, 8, 32}, {}}); | |||
| }; | |||
| run_3d_contig(Mode::FUSE_MUL_ADD3); | |||
| //! 3 dim incontig | |||
| auto run_3d_incontig = [&](Mode mode) { | |||
| megdnn::TensorLayout src0({1, 1, 1, 12}, dtype::Float32()); | |||
| megdnn::TensorLayout src1({1, 2, 2, 12}, {80, 40, 20, 1}, dtype::Float32()); | |||
| // BCAST111C_VEC_BCAST111C | |||
| checker.set_param(mode).execl({src0, src1, src0, {}}); | |||
| // VEC_BCAST111C_VEC | |||
| checker.set_param(mode).execl({src1, src0, src1, {}}); | |||
| }; | |||
| run_3d_incontig(Mode::FUSE_MUL_ADD3); | |||
| } | |||
| #if MEGDNN_WITH_BENCHMARK | |||
| namespace { | |||
| void run_elemwise_benchmark( | |||
| @@ -354,6 +412,39 @@ void run_elemwise_benchmark( | |||
| } | |||
| } // namespace | |||
| TEST_F(ARM_COMMON, BENCHMARK_NCHW_VS_NHWC) { | |||
| Benchmarker<Elemwise> benchmarker(handle()); | |||
| constexpr size_t RUN = 50; | |||
| benchmarker.set_times(RUN).set_display(false); | |||
| auto run = [&](size_t N, size_t C, size_t H, size_t W, param::Elemwise::Mode mode, | |||
| const char* mode_name) { | |||
| megdnn::param::Elemwise param; | |||
| param.mode = mode; | |||
| benchmarker.set_param(param); | |||
| megdnn::TensorShape nhwc_src0{N, H, W, C}; | |||
| megdnn::TensorShape nhwc_src1{1, 1, 1, C}; | |||
| megdnn::TensorShape nchw_src0{N, C, H, W}; | |||
| megdnn::TensorShape nchw_src1{1, C, 1, 1}; | |||
| float computations = N * C * H * W; | |||
| auto nhwc_time = benchmarker.execs({nhwc_src1, nhwc_src0, {}}) / RUN; | |||
| auto nchw_time = benchmarker.execs({nchw_src1, nchw_src0, {}}) / RUN; | |||
| auto perf_nhwc = computations / nhwc_time / 1e6; | |||
| auto perf_nchw = computations / nchw_time / 1e6; | |||
| printf("Elemwise Mode : %s\nNHWC : %fms %fGflops\nNCHW : %fms " | |||
| "%fGflops\n", | |||
| mode_name, nhwc_time, perf_nhwc, nchw_time, perf_nchw); | |||
| }; | |||
| run(1, 120, 16, 24, param::Elemwise::Mode::ADD, "ADD"); | |||
| run(1, 120, 16, 24, param::Elemwise::Mode::MUL, "MUL"); | |||
| run(1, 120, 32, 48, param::Elemwise::Mode::ADD, "ADD"); | |||
| run(1, 120, 32, 48, param::Elemwise::Mode::MUL, "MUL"); | |||
| run(1, 120, 64, 96, param::Elemwise::Mode::ADD, "ADD"); | |||
| run(1, 120, 64, 96, param::Elemwise::Mode::MUL, "MUL"); | |||
| } | |||
| #define INT_RUN(shape, mode) \ | |||
| run_elemwise_benchmark(shape, mode, #mode, dtype::Int8{}, handle()); \ | |||
| run_elemwise_benchmark(shape, mode, #mode, dtype::Int16{}, handle()); \ | |||
| @@ -88,6 +88,26 @@ TEST_F(ARM_COMMON, TYPE_CVT) { | |||
| .execs({{1, 32, 24, 128}, {1, 32, 24, 128}}); | |||
| } | |||
| TEST_F(ARM_COMMON, TYPE_CVT_16_F32) { | |||
| Checker<TypeCvt> checker(handle()); | |||
| UniformIntRNG rng{INT16_MIN >> 1, INT16_MAX >> 1}; | |||
| for (size_t size : {3, 7, 15, 33, 10000}) { | |||
| checker.set_rng(0, &rng); | |||
| checker.set_dtype(0, dtype::Int16()).execs({{size}, {size}}); | |||
| checker.set_dtype(0, dtype::Uint16()).execs({{size}, {size}}); | |||
| } | |||
| TensorLayout src_int16{ | |||
| {1, 96, 64, 120}, {128 * 64 * 96, 128 * 64, 128, 1}, dtype::Int16()}; | |||
| TensorLayout dst_int16{{1, 96, 64, 120}, dtype::Float32()}; | |||
| checker.execl({src_int16, dst_int16}); | |||
| TensorLayout src_uint16{ | |||
| {1, 96, 64, 120}, {128 * 64 * 96, 128 * 64, 128, 1}, dtype::Uint16()}; | |||
| TensorLayout dst_uint16{{1, 96, 64, 120}, dtype::Float32()}; | |||
| checker.execl({src_uint16, dst_uint16}); | |||
| } | |||
| #if MEGDNN_WITH_BENCHMARK | |||
| TEST_F(ARM_COMMON, BENCHMARK_TYPE_CVT) { | |||
| auto run = [&](const TensorShapeArray& shapes) { | |||
| @@ -158,8 +158,9 @@ void copy_tensors( | |||
| //! In order to avoid an unnecessary increase in binary size, we just | |||
| //! use QuantizedS16 dtype in winograd_filter_preprocess now. | |||
| cb(::megdnn::dtype::QuantizedS16) MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) | |||
| cb(::megdnn::dtype::Uint16) | |||
| #undef cb | |||
| default : megdnn_trap(); | |||
| default : megdnn_trap(); | |||
| } | |||
| } | |||
| @@ -202,6 +202,9 @@ void IIDRNG::gen(const TensorND& tensor) { | |||
| memset(tensor.raw_ptr, 0, tensor.layout.access_bytes()); | |||
| return; | |||
| } | |||
| if (tensor.layout.dtype.enumv() == DTypeEnum::Uint16) { | |||
| return; | |||
| } | |||
| megdnn_assert( | |||
| 0, "IIDRNG does not know how to generate value for DType %s", | |||
| tensor.layout.dtype.name()); | |||
| @@ -25,6 +25,11 @@ TEST_F(CUDA, TYPE_CVT) { | |||
| TensorLayout src({10, 10}, sdtype), dst({10, 10}, ddtype); | |||
| Checker<TypeCvt> checker(handle_cuda()); | |||
| checker.set_rng(0, &init).exec(TensorLayoutArray{src, dst}); | |||
| TensorLayout non_contig_src( | |||
| {1, 96, 64, 120}, {96 * 64 * 128, 64 * 128, 128, 1}, sdtype); | |||
| TensorLayout non_contig_dst({1, 96, 64, 120}, ddtype); | |||
| checker.exec(TensorLayoutArray{non_contig_src, non_contig_dst}); | |||
| } | |||
| } | |||
| @@ -37,8 +37,22 @@ TEST_F(X86, TYPE_CVT) { | |||
| for (auto ddtype : dtypes) { | |||
| checker.set_dtype(0, sdtype).set_dtype(1, ddtype).execs( | |||
| {{size}, {size}}); | |||
| TensorLayout non_contig_src( | |||
| {1, 10, 10, 12}, {10 * 10 * 18, 10 * 18, 18, 1}, sdtype); | |||
| TensorLayout non_contig_dst({1, 10, 10, 12}, ddtype); | |||
| checker.exec(TensorLayoutArray{non_contig_src, non_contig_dst}); | |||
| } | |||
| } | |||
| for (size_t size : {1, 7, 15, 33}) { | |||
| checker.set_dtype(0, dtype::Uint16()) | |||
| .set_dtype(1, dtype::Float32()) | |||
| .execs({{size}, {size}}); | |||
| } | |||
| TensorLayout non_contig_src( | |||
| {1, 10, 10, 12}, {10 * 10 * 18, 10 * 18, 18, 1}, dtype::Uint16()); | |||
| TensorLayout non_contig_dst({1, 10, 10, 12}, dtype::Float32()); | |||
| checker.exec(TensorLayoutArray{non_contig_src, non_contig_dst}); | |||
| } | |||
| TEST_F(X86, TYPE_CVT_NO_CONTIGUOUS) { | |||
| @@ -772,8 +772,10 @@ void TypeCvt::perform( | |||
| } | |||
| void TypeCvt::add_input_layout_constraint() { | |||
| //! Because the implementation of typecvt on arm/x86/cuda/opencl support | |||
| //! non-contiguous memory. So we change constraint of typecvt to monotone | |||
| for (auto i : input()) { | |||
| i->add_layout_constraint_contiguous(); | |||
| i->add_layout_constraint_monotone(); | |||
| } | |||
| } | |||