GitOrigin-RevId: fe7b335545
tags/v1.10.0
| @@ -15,7 +15,7 @@ | |||
| #include "src/arm_common/elemwise_helper/op_binary.h" | |||
| #include "src/arm_common/elemwise_helper/op_ternary.h" | |||
| #include "src/arm_common/elemwise_helper/op_unary.h" | |||
| #include "src/fallback/elemwise_helper/elemwise_op.h" | |||
| #include "src/fallback/elemwise_helper/op_common.h" | |||
| namespace megdnn { | |||
| namespace elemwise { | |||
| @@ -364,17 +364,9 @@ void ElemwiseMultiTypeImpl::on_quantized_mode( | |||
| } | |||
| #define DISPATCH() \ | |||
| if (param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS8 && \ | |||
| dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { \ | |||
| DISPATCH_QUANTIZED_MODE(dtype::QuantizedS8, dtype::QuantizedS8) \ | |||
| } else if ( \ | |||
| param[0].layout.dtype.enumv() == DTypeEnum::Quantized8Asymm && \ | |||
| dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) { \ | |||
| if (param[0].layout.dtype.enumv() == DTypeEnum::Quantized8Asymm && \ | |||
| dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) { \ | |||
| DISPATCH_QUANTIZED_MODE(dtype::Quantized8Asymm, dtype::Quantized8Asymm) \ | |||
| } else if ( \ | |||
| param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS32 && \ | |||
| dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { \ | |||
| DISPATCH_MODE(dtype::QuantizedS32, dtype::QuantizedS8) \ | |||
| } else if ( \ | |||
| param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS32 && \ | |||
| dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) { \ | |||
| @@ -467,16 +459,8 @@ void ElemwiseMultiTypeImpl::on_quantized_mode( | |||
| #define DISPATCH() \ | |||
| if (param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS32 && \ | |||
| dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { \ | |||
| DISPATCH_MODE(dtype::QuantizedS32, dtype::QuantizedS8) \ | |||
| } else if ( \ | |||
| param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS32 && \ | |||
| dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) { \ | |||
| dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) { \ | |||
| DISPATCH_MODE(dtype::QuantizedS32, dtype::Quantized8Asymm) \ | |||
| } else if ( \ | |||
| param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS8 && \ | |||
| dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { \ | |||
| DISPATCH_QUANTIZED_MODE(dtype::QuantizedS8, dtype::QuantizedS8) \ | |||
| } else if ( \ | |||
| param[0].layout.dtype.enumv() == DTypeEnum::Quantized8Asymm && \ | |||
| dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) { \ | |||
| @@ -701,12 +685,8 @@ void ElemwiseMultiTypeImpl::on_quantized_mode( | |||
| } | |||
| #define DISPATCH() \ | |||
| if (param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS8 && \ | |||
| dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { \ | |||
| DISPATCH_QUANTIZED_MODE(dtype::QuantizedS8, dtype::QuantizedS8) \ | |||
| } else if ( \ | |||
| param[0].layout.dtype.enumv() == DTypeEnum::Quantized8Asymm && \ | |||
| dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) { \ | |||
| if (param[0].layout.dtype.enumv() == DTypeEnum::Quantized8Asymm && \ | |||
| dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) { \ | |||
| DISPATCH_QUANTIZED_MODE(dtype::Quantized8Asymm, dtype::Quantized8Asymm) \ | |||
| } | |||
| @@ -12,61 +12,4 @@ | |||
| #include "src/fallback/general_intrinsic/gi_float.h" | |||
| #include "src/fallback/general_intrinsic/gi_int.h" | |||
| namespace megdnn { | |||
| namespace elemwise { | |||
| ///////////////////////////////// ParamElemVisitor /////////////////////////// | |||
| #define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix) \ | |||
| template <> \ | |||
| struct ParamElemVisitor<_ctype> { \ | |||
| _simd_type operator()(const _ctype* src) const { \ | |||
| return GiLoad##_fun_suffix(src); \ | |||
| } \ | |||
| }; \ | |||
| template <> \ | |||
| struct ParamElemVisitorDup<_ctype> { \ | |||
| _simd_type operator()(const _ctype* src) const { \ | |||
| return GiBroadcast##_fun_suffix( \ | |||
| *reinterpret_cast<const _inner_ctype*>(src)); \ | |||
| } \ | |||
| } | |||
| cb(dt_qint32, int32_t, GI_INT32_t, Int32); | |||
| cb(dt_qint8, int8_t, GI_INT8_t, Int8); | |||
| cb(dt_float32, float, GI_FLOAT32_t, Float32); | |||
| cb(dt_int32, int32_t, GI_INT32_t, Int32); | |||
| cb(dt_int8, int8_t, GI_INT8_t, Int8); | |||
| #undef cb | |||
| template <typename ctype> | |||
| struct ParamElemVisitorBcast101x4; | |||
| #define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix, rel_suffix) \ | |||
| template <> \ | |||
| struct ParamElemVisitorBcast101x4<_ctype> { \ | |||
| _simd_type operator()(const _ctype* src) const { \ | |||
| return GiReinter##rel_suffix##To##_fun_suffix(GiBroadcast##rel_suffix( \ | |||
| *reinterpret_cast<const _inner_ctype*>(src))); \ | |||
| } \ | |||
| } | |||
| cb(dt_qint8, int32_t, GI_INT8_t, Int8, Int32); | |||
| cb(dt_int8, int32_t, GI_INT8_t, Int8, Int32); | |||
| #undef cb | |||
| #define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix) \ | |||
| template <> \ | |||
| struct ParamElemVisitorBcast101x4<_ctype> { \ | |||
| _simd_type operator()(const _ctype* src) const { \ | |||
| return GiLoad##_fun_suffix(src); \ | |||
| } \ | |||
| } | |||
| cb(dt_qint32, int32_t, GI_INT32_t, Int32); | |||
| cb(dt_float32, float, GI_FLOAT32_t, Float32); | |||
| cb(dt_int32, int32_t, GI_INT32_t, Int32); | |||
| #undef cb | |||
| } // namespace elemwise | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -87,7 +87,7 @@ template <> | |||
| struct FuseAddHSwishOp<dt_qint32, dt_qint8> : FuseAddHSwishOpBase<dt_qint32, dt_qint8> { | |||
| using FuseAddHSwishOpBase::FuseAddHSwishOpBase; | |||
| using FuseAddHSwishOpBase::operator(); | |||
| constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t); | |||
| constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int32_t); | |||
| void operator()( | |||
| const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1, | |||
| dt_qint8* dst) const { | |||
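A minimal sketch (not part of the diff) of how the two SIMD_WIDTH expressions above differ in practice, assuming a 128-bit GI target; GI_SIMD_LEN_BYTE is platform-defined, so the value 16 used here is only an assumption:

#include <cstddef>
#include <cstdint>
// kSimdLenByte stands in for GI_SIMD_LEN_BYTE on an assumed 128-bit SIMD target.
constexpr std::size_t kSimdLenByte = 16;
static_assert(kSimdLenByte / sizeof(std::int8_t) == 16, "16 int8 lanes per vector");
static_assert(kSimdLenByte / sizeof(std::int32_t) == 4, "4 int32 lanes per vector");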
| @@ -41,7 +41,7 @@ struct UnaryOpBase : OpBase<src_ctype, dst_ctype> { | |||
| GiStoreLowInt8( \ | |||
| reinterpret_cast<int8_t*>(dst + 8), \ | |||
| operator()({{GiMoveLowLongInt16(vsrct1), GiMoveHighLongInt16(vsrct1)}})); \ | |||
| GI_INT16_t vsrct2 = GiMoveHighLongInt8(vsrc.val[1]); \ | |||
| GI_INT16_t vsrct2 = GiMoveLowLongInt8(vsrc.val[1]); \ | |||
| GiStoreLowInt8( \ | |||
| reinterpret_cast<int8_t*>(dst + 16), \ | |||
| operator()({{GiMoveLowLongInt16(vsrct2), GiMoveHighLongInt16(vsrct2)}})); \ | |||
| @@ -330,7 +330,7 @@ struct UnaryQuantizationOp; | |||
| template <typename Op> | |||
| struct UnaryQuantizationOp<dt_qint8, dt_qint8, Op> : UnaryOpBase<dt_qint8, dt_qint8> { | |||
| using UnaryOpBase<dt_qint8, dt_qint8>::UnaryOpBase; | |||
| constexpr static size_t SIMD_WIDTH = 16; | |||
| constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t); | |||
| Op op; | |||
| void operator()(const dt_qint8& src, dt_qint8* dst) const { | |||
| @@ -354,7 +354,7 @@ struct UnaryQuantizationOp<dt_qint8, dt_qint8, Op> : UnaryOpBase<dt_qint8, dt_qi | |||
| auto val = this->op({{vitem0, vitem1}}); | |||
| val.val[0] = GiMultiplyFloat32(val.val[0], this->vscale_dst); | |||
| val.val[1] = GiMultiplyFloat32(val.val[1], this->vscale_dst); | |||
| return QConverter::convert<GI_INT8_t, GI_FLOAT32_V4_t>(val); | |||
| return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>(val); | |||
| } | |||
| }; | |||
| @@ -364,7 +364,7 @@ struct BinaryQuantizationOp; | |||
| template <typename Op> | |||
| struct BinaryQuantizationOp<dt_qint8, dt_qint8, Op> : BinaryOpBase<dt_qint8, dt_qint8> { | |||
| using BinaryOpBase<dt_qint8, dt_qint8>::BinaryOpBase; | |||
| constexpr static size_t SIMD_WIDTH = 16; | |||
| constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t); | |||
| Op op; | |||
| void operator()(const dt_qint8& src0, const dt_qint8& src1, dt_qint8* dst) const { | |||
| @@ -403,7 +403,7 @@ template <typename Op> | |||
| struct TernaryQuantizationOp<dt_qint8, dt_qint8, Op> | |||
| : TernaryOpBase<dt_qint8, dt_qint8> { | |||
| using TernaryOpBase<dt_qint8, dt_qint8>::TernaryOpBase; | |||
| constexpr static size_t SIMD_WIDTH = 16; | |||
| constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t); | |||
| Op op; | |||
| void operator()( | |||
| @@ -69,7 +69,7 @@ struct ReluOpBase<dt_qint8, dt_qint8> : UnaryOpBase<dt_qint8, dt_qint8> { | |||
| template <> | |||
| struct ReluOp<dt_qint8, dt_qint8> : ReluOpBase<dt_qint8, dt_qint8> { | |||
| using ReluOpBase::ReluOpBase; | |||
| constexpr static size_t SIMD_WIDTH = 16; | |||
| constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t); | |||
| using ReluOpBase::operator(); | |||
| void operator()(const GI_INT8_V2_t& vsrc, dt_qint8* dst) const { | |||
| @@ -8,6 +8,7 @@ | |||
| namespace megdnn { | |||
| namespace elemwise { | |||
| /*! | |||
| * \brief broadcast type | |||
| * BCAST_x[0]x[1]...: x[i] == !stride[i] | |||
| @@ -49,6 +50,55 @@ struct ParamElemVisitorDup; | |||
| template <typename ctype> | |||
| struct ParamElemVisitorBcast101x4; | |||
| #define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix) \ | |||
| template <> \ | |||
| struct ParamElemVisitor<_ctype> { \ | |||
| _simd_type operator()(const _ctype* src) const { \ | |||
| return GiLoad##_fun_suffix(src); \ | |||
| } \ | |||
| }; \ | |||
| template <> \ | |||
| struct ParamElemVisitorDup<_ctype> { \ | |||
| _simd_type operator()(const _ctype* src) const { \ | |||
| return GiBroadcast##_fun_suffix( \ | |||
| *reinterpret_cast<const _inner_ctype*>(src)); \ | |||
| } \ | |||
| } | |||
| cb(dt_qint32, int32_t, GI_INT32_t, Int32); | |||
| cb(dt_qint8, int8_t, GI_INT8_t, Int8); | |||
| cb(dt_float32, float, GI_FLOAT32_t, Float32); | |||
| cb(dt_int32, int32_t, GI_INT32_t, Int32); | |||
| cb(dt_int8, int8_t, GI_INT8_t, Int8); | |||
| #undef cb | |||
| template <typename ctype> | |||
| struct ParamElemVisitorBcast101x4; | |||
| #define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix, rel_suffix) \ | |||
| template <> \ | |||
| struct ParamElemVisitorBcast101x4<_ctype> { \ | |||
| _simd_type operator()(const _ctype* src) const { \ | |||
| return GiReinter##rel_suffix##To##_fun_suffix(GiBroadcast##rel_suffix( \ | |||
| *reinterpret_cast<const _inner_ctype*>(src))); \ | |||
| } \ | |||
| } | |||
| cb(dt_qint8, int32_t, GI_INT8_t, Int8, Int32); | |||
| cb(dt_int8, int32_t, GI_INT8_t, Int8, Int32); | |||
| #undef cb | |||
| #define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix) \ | |||
| template <> \ | |||
| struct ParamElemVisitorBcast101x4<_ctype> { \ | |||
| _simd_type operator()(const _ctype* src) const { \ | |||
| return GiLoad##_fun_suffix(src); \ | |||
| } \ | |||
| } | |||
| cb(dt_qint32, int32_t, GI_INT32_t, Int32); | |||
| cb(dt_float32, float, GI_FLOAT32_t, Float32); | |||
| cb(dt_int32, int32_t, GI_INT32_t, Int32); | |||
| #undef cb | |||
| ///////////////////////////////// OpCaller ///////////////////////////// | |||
| template <typename Op, BcastType bcast_type> | |||
| struct OpCallerUnary; | |||
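A small usage sketch of the visitors declared above (illustrative only; it assumes the GI vector types and the cb() specializations from this header, and the function name is a placeholder, not part of the diff): ParamElemVisitor performs a contiguous SIMD load, while ParamElemVisitorDup broadcasts a single element across all lanes.

// Compiles only in this header's context (GI types and visitor specializations above).
template <typename ctype>
void load_operands_example(const ctype* contiguous_src, const ctype* scalar_src) {
    ParamElemVisitor<ctype> vis;     // e.g. GiLoadFloat32 / GiLoadInt8 for dt_float32 / dt_qint8
    ParamElemVisitorDup<ctype> dup;  // e.g. GiBroadcastFloat32 / GiBroadcastInt8
    auto vec = vis(contiguous_src);  // full-width vector loaded from memory
    auto bcast = dup(scalar_src);    // one element splatted across every lane
    (void)vec;
    (void)bcast;
}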
| @@ -50,6 +50,18 @@ protected: | |||
| void on_fuse_mul_add3_uint8xf32xf32xf32( | |||
| const ElemwiseOpParamN<3>& param, const TensorND& dst) override; | |||
| void on_quantized_mode( | |||
| const ElemwiseOpParamN<1>& param, const TensorND& dst, | |||
| Elemwise::Mode mode) override; | |||
| void on_quantized_mode( | |||
| const ElemwiseOpParamN<2>& param, const TensorND& dst, | |||
| Elemwise::Mode mode) override; | |||
| void on_quantized_mode( | |||
| const ElemwiseOpParamN<3>& param, const TensorND& dst, | |||
| Elemwise::Mode mode) override; | |||
| public: | |||
| using naive::ElemwiseMultiTypeImpl::ElemwiseMultiTypeImpl; | |||
| }; | |||
| @@ -0,0 +1,499 @@ | |||
| /** | |||
| * \file dnn/src/fallback/elemwise_multi_type/quantized_impl.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "megdnn/tensor_iter.h" | |||
| #include "src/fallback/elemwise_helper/elemwise_op.h" | |||
| #include "src/fallback/elemwise_multi_type/opr_impl.h" | |||
| #include "src/naive/handle.h" | |||
| using namespace megdnn; | |||
| using namespace fallback; | |||
| using namespace elemwise; | |||
| void ElemwiseMultiTypeImpl::on_quantized_mode( | |||
| const ElemwiseOpParamN<1>& param, const TensorND& dst, Elemwise::Mode mode) { | |||
| megdnn_assert(param[0].layout.dtype.category() == DTypeCategory::QUANTIZED); | |||
| megdnn_assert(dst.layout.dtype.category() == DTypeCategory::QUANTIZED); | |||
| #define DISPATCH_MODE(_src_dt, _dst_dt) \ | |||
| switch (mode) { \ | |||
| DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::RELU, ReluOp) \ | |||
| DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::H_SWISH, HSwishOp) \ | |||
| default: \ | |||
| break; \ | |||
| } | |||
| #define DISPATCH_QUANTIZED_MODE(_src_dt, _dst_dt) \ | |||
| switch (mode) { \ | |||
| DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::RELU, ReluOp) \ | |||
| DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::ABS, AbsOp) \ | |||
| DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::SIGMOID, SigmoidOp) \ | |||
| DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::EXP, ExpOp) \ | |||
| DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::TANH, TanhOp) \ | |||
| DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::FAST_TANH, FastTanhOp) \ | |||
| DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::H_SWISH, HSwishOp) \ | |||
| default: \ | |||
| break; \ | |||
| } | |||
| #define DISPATCH() \ | |||
| if (param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS8 && \ | |||
| dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { \ | |||
| DISPATCH_QUANTIZED_MODE(dtype::QuantizedS8, dtype::QuantizedS8) \ | |||
| } else if ( \ | |||
| param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS32 && \ | |||
| dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { \ | |||
| DISPATCH_MODE(dtype::QuantizedS32, dtype::QuantizedS8) \ | |||
| } | |||
| TensorND src = param[0]; | |||
| size_t nr_elems = src.layout.total_nr_elems(); | |||
| #define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ | |||
| case _mode: { \ | |||
| using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ | |||
| using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ | |||
| thin_function<void(const src_ctype*, dst_ctype*, DType, DType, size_t)> run = \ | |||
| OpCallerUnary<_op<src_ctype, dst_ctype>, VEC>::run; \ | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR( \ | |||
| run(src.ptr<src_ctype>(), dst.ptr<dst_ctype>(), src.layout.dtype, \ | |||
| dst.layout.dtype, nr_elems)); \ | |||
| return; \ | |||
| } | |||
| DISPATCH() | |||
| naive::ElemwiseMultiTypeImpl::on_quantized_mode(param, dst, mode); | |||
| #undef DISPATCH_SINGLE_MODE | |||
| #undef DISPATCH | |||
| #undef DISPATCH_QUANTIZED_MODE | |||
| #undef DISPATCH_MODE | |||
| } | |||
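Hand-expanded for illustration (not generated by the diff), one branch of the DISPATCH() above, the QuantizedS8 RELU case, roughly becomes the following case inside the switch produced by DISPATCH_QUANTIZED_MODE, with dt_qint8 standing for DTypeTrait<dtype::QuantizedS8>::ctype:

// Rough expansion of DISPATCH_SINGLE_MODE(dtype::QuantizedS8, dtype::QuantizedS8,
//                                         Elemwise::Mode::RELU, ReluOp)
case Elemwise::Mode::RELU: {
    thin_function<void(const dt_qint8*, dt_qint8*, DType, DType, size_t)> run =
            OpCallerUnary<ReluOp<dt_qint8, dt_qint8>, VEC>::run;
    MEGDNN_DISPATCH_CPU_KERN_OPR(
            run(src.ptr<dt_qint8>(), dst.ptr<dt_qint8>(), src.layout.dtype,
                dst.layout.dtype, nr_elems));
    return;
}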
| void ElemwiseMultiTypeImpl::on_quantized_mode( | |||
| const ElemwiseOpParamN<2>& param, const TensorND& dst, Elemwise::Mode mode) { | |||
| megdnn_assert( | |||
| param[0].layout.dtype.enumv() == param[1].layout.dtype.enumv() && | |||
| param[0].layout.dtype.category() == DTypeCategory::QUANTIZED); | |||
| megdnn_assert(dst.layout.dtype.category() == DTypeCategory::QUANTIZED); | |||
| #define DISPATCH_MODE(_src_dt, _dst_dt) \ | |||
| switch (mode) { \ | |||
| DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::ADD, AddOp) \ | |||
| DISPATCH_SINGLE_MODE( \ | |||
| _src_dt, _dst_dt, Elemwise::Mode::FUSE_ADD_RELU, FuseAddReluOp) \ | |||
| DISPATCH_SINGLE_MODE( \ | |||
| _src_dt, _dst_dt, Elemwise::Mode::FUSE_ADD_H_SWISH, FuseAddHSwishOp) \ | |||
| default: \ | |||
| break; \ | |||
| } | |||
| #define DISPATCH_QUANTIZED_MODE(_src_dt, _dst_dt) \ | |||
| switch (mode) { \ | |||
| DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::ADD, AddOp) \ | |||
| DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::MIN, MinOp) \ | |||
| DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::MAX, MaxOp) \ | |||
| DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::SUB, SubOp) \ | |||
| DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::MUL, MulOp) \ | |||
| DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::TRUE_DIV, TrueDivOp) \ | |||
| DISPATCH_SINGLE_MODE( \ | |||
| _src_dt, _dst_dt, Elemwise::Mode::FUSE_ADD_RELU, FuseAddReluOp) \ | |||
| DISPATCH_SINGLE_MODE( \ | |||
| _src_dt, _dst_dt, Elemwise::Mode::FUSE_ADD_SIGMOID, FuseAddSigmoidOp) \ | |||
| DISPATCH_SINGLE_MODE( \ | |||
| _src_dt, _dst_dt, Elemwise::Mode::FUSE_ADD_TANH, FuseAddTanhOp) \ | |||
| DISPATCH_SINGLE_MODE( \ | |||
| _src_dt, _dst_dt, Elemwise::Mode::FUSE_ADD_H_SWISH, FuseAddHSwishOp) \ | |||
| default: \ | |||
| break; \ | |||
| } | |||
| #define DISPATCH() \ | |||
| if (param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS32 && \ | |||
| dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { \ | |||
| DISPATCH_MODE(dtype::QuantizedS32, dtype::QuantizedS8) \ | |||
| } else if ( \ | |||
| param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS8 && \ | |||
| dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { \ | |||
| DISPATCH_QUANTIZED_MODE(dtype::QuantizedS8, dtype::QuantizedS8) \ | |||
| } | |||
| TensorND src0 = param[0]; | |||
| TensorND src1 = param[1]; | |||
| //! VEC + VEC | |||
| if (is_vector(src0.layout) && is_vector(src1.layout)) { | |||
| size_t nr_elems = src0.layout.total_nr_elems(); | |||
| #define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ | |||
| case _mode: { \ | |||
| using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ | |||
| using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ | |||
| thin_function<void( \ | |||
| const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \ | |||
| size_t)> \ | |||
| run = OpCallerBinary<_op<src_ctype, dst_ctype>, VEC_VEC>::run; \ | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR(run( \ | |||
| src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), dst.ptr<dst_ctype>(), \ | |||
| src0.layout.dtype, src1.layout.dtype, dst.layout.dtype, nr_elems)); \ | |||
| return; \ | |||
| } | |||
| DISPATCH() | |||
| #undef DISPATCH_SINGLE_MODE | |||
| } | |||
| //! VEC + SCALAR | |||
| { | |||
| bool normal_case = is_vector(src0.layout) && is_broadcasted_scalar(src1.layout); | |||
| bool swap_case = false; | |||
| bool commutable = false; | |||
| if (mode != Elemwise::Mode::SUB && mode != Elemwise::Mode::TRUE_DIV) | |||
| commutable = true; | |||
| if (!normal_case && commutable) { | |||
| swap_case = is_vector(src1.layout) && is_broadcasted_scalar(src0.layout); | |||
| } | |||
| if (normal_case || swap_case) { | |||
| auto &lhs = src0, &rhs = src1; | |||
| if (swap_case) { | |||
| std::swap(lhs, rhs); | |||
| } | |||
| #define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ | |||
| case _mode: { \ | |||
| using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ | |||
| using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ | |||
| thin_function<void( \ | |||
| const src_ctype*, const src_ctype, dst_ctype*, DType, DType, DType, \ | |||
| size_t)> \ | |||
| run = OpCallerBinary<_op<src_ctype, dst_ctype>, VEC_SCALAR>::run; \ | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR( \ | |||
| run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>()[0], \ | |||
| dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \ | |||
| dst.layout.dtype, src0.layout.total_nr_elems())); \ | |||
| return; \ | |||
| } | |||
| DISPATCH() | |||
| #undef DISPATCH_SINGLE_MODE | |||
| } | |||
| //! SCALAR + VEC | |||
| if (!commutable && is_vector(src1.layout) && | |||
| is_broadcasted_scalar(src0.layout)) { | |||
| #define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ | |||
| case _mode: { \ | |||
| using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ | |||
| using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ | |||
| thin_function<void( \ | |||
| const src_ctype, const src_ctype*, dst_ctype*, DType, DType, DType, \ | |||
| size_t)> \ | |||
| run = OpCallerBinary<_op<src_ctype, dst_ctype>, SCALAR_VEC>::run; \ | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR( \ | |||
| run(src0.ptr<src_ctype>()[0], src1.ptr<src_ctype>(), \ | |||
| dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \ | |||
| dst.layout.dtype, src1.layout.total_nr_elems())); \ | |||
| return; \ | |||
| } | |||
| DISPATCH() | |||
| #undef DISPATCH_SINGLE_MODE | |||
| } | |||
| } | |||
| //! VEC + BCAST101 | |||
| { | |||
| BroadcastChannelInfo binfo; | |||
| bool normal_case = is_vector(src0.layout) && | |||
| is_broadcasted_channel_like(src1.layout, binfo); | |||
| bool swap_case = false; | |||
| bool commutable = false; | |||
| if (mode != Elemwise::Mode::SUB && mode != Elemwise::Mode::TRUE_DIV) | |||
| commutable = true; | |||
| if (!normal_case && commutable) { | |||
| swap_case = is_vector(src1.layout) && | |||
| is_broadcasted_channel_like(src0.layout, binfo); | |||
| } | |||
| if (normal_case || swap_case) { | |||
| auto &lhs = src0, &rhs = src1; | |||
| if (swap_case) | |||
| std::swap(lhs, rhs); | |||
| #define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ | |||
| case _mode: { \ | |||
| using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ | |||
| using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ | |||
| thin_function<void( \ | |||
| const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \ | |||
| size_t, size_t, size_t)> \ | |||
| run = OpCallerBinary<_op<src_ctype, dst_ctype>, VEC_BCAST101>::run; \ | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR( \ | |||
| run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \ | |||
| dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \ | |||
| dst.layout.dtype, binfo.x, binfo.y, binfo.z)); \ | |||
| return; \ | |||
| } | |||
| DISPATCH() | |||
| #undef DISPATCH_SINGLE_MODE | |||
| } | |||
| //! BCAST101 + VEC : only for SUB or TRUE_DIV | |||
| if (!commutable && is_vector(src1.layout) && | |||
| is_broadcasted_channel_like(src0.layout, binfo)) { | |||
| #define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ | |||
| case _mode: { \ | |||
| using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ | |||
| using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ | |||
| thin_function<void( \ | |||
| const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \ | |||
| size_t, size_t, size_t)> \ | |||
| run = OpCallerBinary<_op<src_ctype, dst_ctype>, BCAST101_VEC>::run; \ | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR( \ | |||
| run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \ | |||
| dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \ | |||
| dst.layout.dtype, binfo.x, binfo.y, binfo.z)); \ | |||
| return; \ | |||
| } | |||
| DISPATCH() | |||
| #undef DISPATCH_SINGLE_MODE | |||
| } | |||
| } | |||
| //! VEC + BCAST101x4 | |||
| { | |||
| BroadcastChannelInfo binfo; | |||
| if (is_vector(src0.layout) && | |||
| (is_broadcastedx_channel_like<4>(src1.layout, binfo) || | |||
| is_broadcastedx_channel_like<8>(src1.layout, binfo))) { | |||
| #define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ | |||
| case _mode: { \ | |||
| using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ | |||
| using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ | |||
| thin_function<void( \ | |||
| const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \ | |||
| size_t, size_t, size_t, size_t)> \ | |||
| run = OpCallerBinary<_op<src_ctype, dst_ctype>, VEC_BCAST101xX>::run; \ | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR( \ | |||
| run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \ | |||
| dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \ | |||
| dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \ | |||
| return; \ | |||
| } | |||
| size_t batch_size = src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z); | |||
| DISPATCH() | |||
| #undef DISPATCH_SINGLE_MODE | |||
| } | |||
| //! BCAST101x + VEC | |||
| if (is_vector(src1.layout) && | |||
| is_broadcastedx_channel_like<4>(src0.layout, binfo)) { | |||
| #define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ | |||
| case _mode: { \ | |||
| using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ | |||
| using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ | |||
| thin_function<void( \ | |||
| const src_ctype*, const src_ctype*, dst_ctype*, DType, DType, DType, \ | |||
| size_t, size_t, size_t, size_t)> \ | |||
| run = OpCallerBinary<_op<src_ctype, dst_ctype>, BCAST101xX_VEC>::run; \ | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR( \ | |||
| run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \ | |||
| dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \ | |||
| dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \ | |||
| return; \ | |||
| } | |||
| size_t batch_size = src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z); | |||
| DISPATCH() | |||
| #undef DISPATCH_SINGLE_MODE | |||
| } | |||
| } | |||
| naive::ElemwiseMultiTypeImpl::on_quantized_mode(param, dst, mode); | |||
| #undef DISPATCH_MODE | |||
| #undef DISPATCH_QUANTIZED_MODE | |||
| #undef DISPATCH | |||
| } | |||
| void ElemwiseMultiTypeImpl::on_quantized_mode( | |||
| const ElemwiseOpParamN<3>& param, const TensorND& dst, Elemwise::Mode mode) { | |||
| megdnn_assert( | |||
| param[0].layout.dtype.enumv() == param[1].layout.dtype.enumv() && | |||
| param[0].layout.dtype.enumv() == param[2].layout.dtype.enumv() && | |||
| param[0].layout.dtype.category() == DTypeCategory::QUANTIZED); | |||
| megdnn_assert(dst.layout.dtype.category() == DTypeCategory::QUANTIZED); | |||
| #define DISPATCH_QUANTIZED_MODE(_src_dt, _dst_dt) \ | |||
| switch (mode) { \ | |||
| DISPATCH_SINGLE_MODE( \ | |||
| _src_dt, _dst_dt, Elemwise::Mode::FUSE_MUL_ADD3, FuseMulAdd3Op) \ | |||
| default: \ | |||
| break; \ | |||
| } | |||
| #define DISPATCH() \ | |||
| if (param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS8 && \ | |||
| dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { \ | |||
| DISPATCH_QUANTIZED_MODE(dtype::QuantizedS8, dtype::QuantizedS8) \ | |||
| } | |||
| TensorND src0 = param[0]; | |||
| TensorND src1 = param[1]; | |||
| TensorND src2 = param[2]; | |||
| //! VEC + VEC + VEC | |||
| if (is_vector(src0.layout) && is_vector(src1.layout) && is_vector(src2.layout)) { | |||
| size_t nr_elems = src0.layout.total_nr_elems(); | |||
| #define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ | |||
| case _mode: { \ | |||
| using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ | |||
| using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ | |||
| thin_function<void( \ | |||
| const src_ctype*, const src_ctype*, const src_ctype*, dst_ctype*, \ | |||
| DType, DType, DType, DType, size_t)> \ | |||
| run = OpCallerTernary<_op<src_ctype, dst_ctype>, VEC_VEC_VEC>::run; \ | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR(run( \ | |||
| src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), src2.ptr<src_ctype>(), \ | |||
| dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \ | |||
| src2.layout.dtype, dst.layout.dtype, nr_elems)); \ | |||
| return; \ | |||
| } | |||
| DISPATCH() | |||
| #undef DISPATCH_SINGLE_MODE | |||
| } | |||
| //! VEC + VEC + SCALAR | |||
| if (is_vector(src0.layout) && is_vector(src1.layout) && | |||
| is_broadcasted_scalar(src2.layout)) { | |||
| #define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ | |||
| case _mode: { \ | |||
| using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ | |||
| using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ | |||
| thin_function<void( \ | |||
| const src_ctype*, const src_ctype*, const src_ctype, dst_ctype*, \ | |||
| DType, DType, DType, DType, size_t)> \ | |||
| run = OpCallerTernary<_op<src_ctype, dst_ctype>, VEC_VEC_SCALAR>::run; \ | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR( \ | |||
| run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \ | |||
| src2.ptr<src_ctype>()[0], dst.ptr<dst_ctype>(), src0.layout.dtype, \ | |||
| src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \ | |||
| src0.layout.total_nr_elems())); \ | |||
| return; \ | |||
| } | |||
| DISPATCH() | |||
| #undef DISPATCH_SINGLE_MODE | |||
| } | |||
| //! BCAST101 + VEC + BCAST101 | |||
| { | |||
| BroadcastChannelInfo binfo; | |||
| bool normal_case = is_vector(src1.layout) && | |||
| is_broadcasted_channel_like(src0.layout, binfo) && | |||
| src0.layout.eq_shape(src2.layout); | |||
| if (normal_case) { | |||
| #define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ | |||
| case _mode: { \ | |||
| using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ | |||
| using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ | |||
| thin_function<void( \ | |||
| const src_ctype*, const src_ctype*, const src_ctype*, dst_ctype*, \ | |||
| DType, DType, DType, DType, size_t, size_t, size_t, size_t)> \ | |||
| run = OpCallerTernary< \ | |||
| _op<src_ctype, dst_ctype>, BCAST101_VEC_BCAST101>::run; \ | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR( \ | |||
| run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \ | |||
| src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), src0.layout.dtype, \ | |||
| src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, binfo.x, \ | |||
| binfo.y, binfo.z, binfo.y* binfo.z)); \ | |||
| return; \ | |||
| } | |||
| DISPATCH() | |||
| #undef DISPATCH_SINGLE_MODE | |||
| } | |||
| } | |||
| //! VEC + BCAST101x4 + VEC | |||
| { | |||
| BroadcastChannelInfo binfo; | |||
| if (is_vector(src0.layout) && | |||
| (is_broadcastedx_channel_like<4>(src1.layout, binfo) || | |||
| is_broadcastedx_channel_like<8>(src1.layout, binfo)) && | |||
| src0.layout.eq_shape(src2.layout)) { | |||
| #define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ | |||
| case _mode: { \ | |||
| using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ | |||
| using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ | |||
| thin_function<void( \ | |||
| const src_ctype*, const src_ctype*, const src_ctype*, dst_ctype*, \ | |||
| DType, DType, DType, DType, size_t, size_t, size_t, size_t)> \ | |||
| run = OpCallerTernary< \ | |||
| _op<src_ctype, dst_ctype>, VEC_BCAST101xX_VEC>::run; \ | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR( \ | |||
| run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \ | |||
| src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), src0.layout.dtype, \ | |||
| src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \ | |||
| batch_size, binfo.x, binfo.y, binfo.z)); \ | |||
| return; \ | |||
| } | |||
| size_t batch_size = src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z); | |||
| DISPATCH() | |||
| #undef DISPATCH_SINGLE_MODE | |||
| } | |||
| //! BCAST101x + VEC + BCAST101x | |||
| if (is_vector(src1.layout) && | |||
| (is_broadcastedx_channel_like<4>(src0.layout, binfo) || | |||
| is_broadcastedx_channel_like<8>(src0.layout, binfo)) && | |||
| src0.layout.eq_shape(src2.layout)) { | |||
| #define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ | |||
| case _mode: { \ | |||
| using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ | |||
| using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ | |||
| thin_function<void( \ | |||
| const src_ctype*, const src_ctype*, const src_ctype*, dst_ctype*, \ | |||
| DType, DType, DType, DType, size_t, size_t, size_t, size_t)> \ | |||
| run = OpCallerTernary< \ | |||
| _op<src_ctype, dst_ctype>, BCAST101xX_VEC_BCAST101xX>::run; \ | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR( \ | |||
| run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \ | |||
| src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), src0.layout.dtype, \ | |||
| src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \ | |||
| batch_size, binfo.x, binfo.y, binfo.z)); \ | |||
| return; \ | |||
| } | |||
| size_t batch_size = src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z); | |||
| DISPATCH() | |||
| #undef DISPATCH_SINGLE_MODE | |||
| } | |||
| } | |||
| naive::ElemwiseMultiTypeImpl::on_quantized_mode(param, dst, mode); | |||
| #undef DISPATCH | |||
| #undef DISPATCH_QUANTIZED_MODE | |||
| } | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -60,6 +60,7 @@ | |||
| #define GI_NEON_INTRINSICS | |||
| #if defined(__aarch64__) | |||
| #define GI_NEON64_INTRINSICS | |||
| #define GI_NEON32_INTRINSICS | |||
| #else | |||
| #define GI_NEON32_INTRINSICS | |||
| #endif | |||
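An illustrative preprocessor sketch (not part of the diff) of the macro state this hunk produces: with the added define, an __aarch64__ build now sees GI_NEON32_INTRINSICS as well, so code guarded by it is also compiled on 64-bit ARM.

#if defined(GI_NEON64_INTRINSICS) && defined(GI_NEON32_INTRINSICS)
// aarch64 after this change: both guards are active
#elif defined(GI_NEON32_INTRINSICS)
// armv7: only the 32-bit NEON guard is active
#endif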
| @@ -11,8 +11,10 @@ | |||
| */ | |||
| #include "test/common/elemwise_multi_type.h" | |||
| #include "megdnn/opr_param_defs.h" | |||
| #include "megdnn/oprs.h" | |||
| #include "test/arm_common/fixture.h" | |||
| #include "test/common/benchmarker.h" | |||
| #include "test/common/checker.h" | |||
| #include "test/common/task_record_check.h" | |||
| #include "test/common/timer.h" | |||
| @@ -559,4 +561,95 @@ TEST_F(ARM_COMMON, ELEMWISE_FMA3_UINT8xF32xF32xF32_RECORD) { | |||
| .execs({{16, 128, 16, 16}, {1, 1, 1, 1}, {1, 1, 1, 1}, {}}); | |||
| } | |||
| #if MEGDNN_WITH_BENCHMARK | |||
| namespace { | |||
| void run_elemwise_benchmark( | |||
| const TensorShapeArray& shapes, ElemwiseMultiType::Param::Mode mode, | |||
| const char* mode_str, std::vector<DType> types, Handle* handle_bench) { | |||
| auto handle_fallback = create_cpu_handle(1); | |||
| Benchmarker<ElemwiseMultiType> benchmarker_bench(handle_bench); | |||
| Benchmarker<ElemwiseMultiType> benchmarker_fallback(handle_fallback.get()); | |||
| float throughput = 0; | |||
| SmallVector<TensorLayout> layouts; | |||
| std::string src_strs; | |||
| for (size_t i = 0; i < shapes.size(); i++) { | |||
| layouts.emplace_back(shapes[i], types[i]); | |||
| throughput += layouts.back().span().dist_byte(); | |||
| src_strs += layouts.back().to_string(); | |||
| if (i != shapes.size() - 1) { | |||
| src_strs += ","; | |||
| } | |||
| } | |||
| constexpr size_t RUN = 50; | |||
| benchmarker_fallback.set_times(RUN).set_display(false); | |||
| benchmarker_bench.set_times(RUN).set_display(false); | |||
| benchmarker_fallback.set_param(mode); | |||
| benchmarker_bench.set_param(mode); | |||
| TensorLayout dst_layout; | |||
| dst_layout.dtype = types.back(); | |||
| auto opr = handle_bench->create_operator<ElemwiseMultiType>(); | |||
| opr->param() = mode; | |||
| opr->deduce_layout(layouts, dst_layout); | |||
| float computations = | |||
| dst_layout.total_nr_elems() * (std::max<size_t>(shapes.size(), 2) - 1); | |||
| throughput += dst_layout.span().dist_byte(); | |||
| computations *= (1e3 / (1024.0 * 1024)); | |||
| throughput *= (1e3 / (1024.0 * 1024)); | |||
| layouts.emplace_back(dst_layout); | |||
| auto fallback_time = benchmarker_fallback.execl(layouts) / RUN; | |||
| auto bench_time = benchmarker_bench.execl(layouts) / RUN; | |||
| float fallback_flops = computations / fallback_time; | |||
| float bench_flops = computations / bench_time; | |||
| float fallback_thr = throughput / fallback_time; | |||
| float bench_thr = throughput / bench_time; | |||
| printf("%s = %s (mode: %s) cpu=%fMFLOPS %fMB/s, bench=%fMFLOPS " | |||
| "%fMB/s " | |||
| "computations: %fx, throughput: %fx\n", | |||
| src_strs.c_str(), dst_layout.to_string().c_str(), mode_str, fallback_flops, | |||
| fallback_thr, bench_flops, bench_thr, bench_flops / fallback_flops, | |||
| bench_thr / fallback_thr); | |||
| } | |||
| } // namespace | |||
| #define RUN_WITH_MODE(shape, mode, types) \ | |||
| run_elemwise_benchmark(shape, mode, #mode, types, handle()); | |||
| TEST_F(ARM_COMMON, BENCHMARK_UNARY_MULTI_TYPE) { | |||
| using Mode = ElemwiseMultiType::Param::Mode; | |||
| for (auto mode : | |||
| {Mode::QRELU, Mode::QABS, Mode::QSIGMOID, Mode::QEXP, Mode::QTANH, | |||
| Mode::QFAST_TANH, Mode::QH_SWISH}) { | |||
| std::vector<DType> types = {dtype::QuantizedS8(1.4f), dtype::QuantizedS8(3.4f)}; | |||
| TensorShapeArray shapes = {{10000}}; | |||
| RUN_WITH_MODE(shapes, mode, types); | |||
| std::vector<DType> types2 = { | |||
| dtype::QuantizedS32(1.4f), dtype::QuantizedS8(3.4f)}; | |||
| RUN_WITH_MODE(shapes, mode, types2); | |||
| } | |||
| } | |||
| TEST_F(ARM_COMMON, BENCHMARK_BINARY_MULTI_TYPE) { | |||
| using Mode = ElemwiseMultiType::Param::Mode; | |||
| for (auto mode : {Mode::QADD, Mode::QFUSE_ADD_RELU, Mode::QFUSE_ADD_H_SWISH}) { | |||
| std::vector<DType> types = { | |||
| dtype::QuantizedS8(1.4f), dtype::QuantizedS8(3.4f), | |||
| dtype::QuantizedS8(1.6f)}; | |||
| TensorShapeArray shapes = {{10000}, {10000}}; | |||
| RUN_WITH_MODE(shapes, mode, types); | |||
| std::vector<DType> types2 = { | |||
| dtype::QuantizedS32(1.4f), dtype::QuantizedS32(3.4f), | |||
| dtype::QuantizedS8(1.6f)}; | |||
| RUN_WITH_MODE(shapes, mode, types2); | |||
| } | |||
| } | |||
| #endif | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
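A self-contained sketch of the metric conversion used by run_elemwise_benchmark above; the tensor sizes follow the binary QADD case in the test, while the per-run time is an assumed value chosen purely for illustration.

#include <cstdio>

int main() {
    // Two {10000} QuantizedS8 inputs and one QuantizedS8 output, as in the QADD case.
    double computations = 10000.0 * (2 - 1);  // dst elements * (nr_inputs - 1)
    double throughput = 10000.0 * 3;          // bytes touched: 2 inputs + 1 output (int8)
    computations *= 1e3 / (1024.0 * 1024);    // so dividing by ms yields Mi-ops per second
    throughput *= 1e3 / (1024.0 * 1024);      // so dividing by ms yields MiB per second
    double time_ms = 0.05;                    // assumed per-run time, illustration only
    std::printf("%f MFLOPS, %f MB/s\n", computations / time_ms, throughput / time_ms);
    return 0;
}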
| @@ -26,6 +26,175 @@ TYPED_TEST(FALLBACK_ELEMWISE_MULTI_TYPE, run) { | |||
| elemwise_multi_type::run_test<TypeParam>(this->handle()); | |||
| } | |||
| TEST_F(FALLBACK, ELEMWISE_QUANTIZED_MODE_UNARY) { | |||
| using Mode = ElemwiseMultiType::Param::Mode; | |||
| Checker<ElemwiseMultiType> checker(handle()); | |||
| std::unique_ptr<RNG> rng; | |||
| for (auto mode : | |||
| {Mode::QRELU, Mode::QABS, Mode::QSIGMOID, Mode::QEXP, Mode::QTANH, | |||
| Mode::QFAST_TANH, Mode::QH_SWISH}) { | |||
| checker.set_param({mode}); | |||
| for (DType src_type : | |||
| std::vector<DType>{dtype::QuantizedS8(1.4f), dtype::QuantizedS32(1.3f)}) { | |||
| checker.set_dtype(0, src_type); | |||
| if (src_type.enumv() == DTypeEnum::QuantizedS8) { | |||
| rng = std::make_unique<UniformIntRNG>(-127, 127); | |||
| checker.set_dtype(1, dtype::QuantizedS8(1.7f)); | |||
| } else { | |||
| rng = std::make_unique<UniformIntRNG>(INT16_MIN >> 1, INT16_MAX >> 1); | |||
| } | |||
| checker.set_rng(0, rng.get()); | |||
| auto run = [&]() { | |||
| checker.execs({{3, 4, 5, 6}, {}}); | |||
| checker.execs({{3}, {}}); | |||
| checker.execs({{9}, {}}); | |||
| checker.execs({{17}, {}}); | |||
| }; | |||
| if (src_type.enumv() == DTypeEnum::QuantizedS32) { | |||
| for (DType dst_type : | |||
| std::vector<DType>{dtype::QuantizedS8(32718.6f)}) { | |||
| checker.set_dtype(1, dst_type); | |||
| run(); | |||
| } | |||
| } else { | |||
| run(); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| TEST_F(FALLBACK, ELEMWISE_QUANTIZED_MODE_BINARY) { | |||
| using Mode = ElemwiseMultiType::Param::Mode; | |||
| Checker<ElemwiseMultiType> checker(handle()); | |||
| auto run = [&]() { | |||
| //! nchw44 | |||
| checker.execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}}); | |||
| checker.execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}}); | |||
| checker.execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}}); | |||
| checker.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}}); | |||
| checker.execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}}); | |||
| checker.execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}}); | |||
| checker.execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}}); | |||
| checker.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}}); | |||
| checker.execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}}); | |||
| //! VEC + SCALAR | |||
| checker.execs({{3, 4, 5, 6}, {1, 1, 1, 1}, {}}); | |||
| checker.execs({{1, 1, 1, 1}, {3, 4, 5, 6}, {}}); | |||
| checker.execs({{3, 4, 5, 6}, {1}, {}}); | |||
| checker.execs({{1}, {3, 4, 5, 6}, {}}); | |||
| //! VEC + 1C11 | |||
| checker.execs({{3, 4, 5, 6}, {1, 4, 1, 1}, {}}); | |||
| checker.execs({{1, 4, 1, 1}, {3, 4, 5, 6}, {}}); | |||
| //! VEC + VEC | |||
| checker.execs({{3}, {3}, {}}); | |||
| checker.execs({{9}, {9}, {}}); | |||
| checker.execs({{17}, {17}, {}}); | |||
| }; | |||
| // qint32 to qint8/quint8 | |||
| for (auto mode : {Mode::QADD, Mode::QFUSE_ADD_RELU, Mode::QFUSE_ADD_H_SWISH}) { | |||
| checker.set_param({mode}); | |||
| UniformIntRNG rng{INT16_MIN >> 1, INT16_MAX >> 1}; | |||
| checker.set_rng(0, &rng) | |||
| .set_rng(1, &rng) | |||
| .set_dtype(0, dtype::QuantizedS32(1.3f)) | |||
| .set_dtype(1, dtype::QuantizedS32(1.2f)); | |||
| for (DType dst_type : std::vector<DType>{dtype::QuantizedS8(32718.6f)}) { | |||
| checker.set_dtype(2, dst_type); | |||
| run(); | |||
| } | |||
| } | |||
| for (auto mode : | |||
| {Mode::QMUL, Mode::QADD, Mode::QMIN, Mode::QMAX, Mode::QSUB, | |||
| Mode::QFUSE_ADD_RELU, Mode::QFUSE_ADD_SIGMOID, Mode::QFUSE_ADD_H_SWISH}) { | |||
| checker.set_param({mode}); | |||
| // qint8 to qint8 | |||
| UniformIntRNG rng_int8{-127, 127}; | |||
| checker.set_rng(0, &rng_int8) | |||
| .set_rng(1, &rng_int8) | |||
| .set_dtype(0, dtype::QuantizedS8(1.35f)) | |||
| .set_dtype(1, dtype::QuantizedS8(1.15f)) | |||
| .set_dtype(2, dtype::QuantizedS8(1.75f)); | |||
| run(); | |||
| } | |||
| //! TRUE_DIV : 0.0 / 0.0 will fail | |||
| checker.set_param({Mode::QTRUE_DIV}); | |||
| UniformIntRNG rng_int8_1{-127, 127}; | |||
| UniformIntRNG rng_int8_2{-127, -1}; | |||
| checker.set_rng(0, &rng_int8_1) | |||
| .set_rng(1, &rng_int8_2) | |||
| .set_dtype(0, dtype::QuantizedS8(1.4f)) | |||
| .set_dtype(1, dtype::QuantizedS8(1.1f)) | |||
| .set_dtype(2, dtype::QuantizedS8(1.7f)); | |||
| run(); | |||
| //! TANH | |||
| checker.set_param({Mode::QFUSE_ADD_TANH}); | |||
| UniformIntRNG rng_int8{-5, 5}; | |||
| checker.set_rng(0, &rng_int8) | |||
| .set_rng(1, &rng_int8) | |||
| .set_dtype(0, dtype::QuantizedS8(1.1f)) | |||
| .set_dtype(1, dtype::QuantizedS8(1.4f)) | |||
| .set_dtype(2, dtype::QuantizedS8(1.7f)); | |||
| run(); | |||
| } | |||
| TEST_F(FALLBACK, ELEMWISE_QUANTIZED_MODE_TERNARY) { | |||
| using Mode = ElemwiseMultiType::Param::Mode; | |||
| Checker<ElemwiseMultiType> checker(handle()); | |||
| auto run = [&]() { | |||
| //! nchw44 | |||
| checker.execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}}); | |||
| checker.execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}}); | |||
| checker.execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}}); | |||
| checker.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}}); | |||
| checker.execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}}); | |||
| //! nchw44 | |||
| checker.execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}}); | |||
| checker.execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}}); | |||
| checker.execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}}); | |||
| checker.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}}); | |||
| checker.execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}}); | |||
| checker.execs({{3, 4, 5, 6}, {3, 4, 5, 6}, {1, 1, 1, 1}, {}}); | |||
| checker.execs({{1, 4, 1, 1}, {3, 4, 5, 6}, {1, 4, 1, 1}, {}}); | |||
| checker.execs({{3}, {3}, {3}, {}}); | |||
| checker.execs({{9}, {9}, {9}, {}}); | |||
| checker.execs({{17}, {17}, {17}, {}}); | |||
| checker.execs({{3, 4, 5, 6}, {3, 4, 5, 6}, {3, 4, 5, 6}, {}}); | |||
| }; | |||
| for (auto mode : {Mode::QFUSE_MUL_ADD3}) { | |||
| checker.set_param({mode}); | |||
| // qint8 to qint8 | |||
| UniformIntRNG rng_int8{-127, 127}; | |||
| checker.set_rng(0, &rng_int8) | |||
| .set_rng(1, &rng_int8) | |||
| .set_rng(2, &rng_int8) | |||
| .set_dtype(0, dtype::QuantizedS8(1.45f)) | |||
| .set_dtype(1, dtype::QuantizedS8(1.15f)) | |||
| .set_dtype(2, dtype::QuantizedS8(1.75f)) | |||
| .set_dtype(3, dtype::QuantizedS8(1.35f)); | |||
| run(); | |||
| } | |||
| } | |||
| TEST_F(FALLBACK, ELEMWISE_MULTI_TYPE_RECORD_FMA3_INT16x32x32x32) { | |||
| TaskRecordChecker<ElemwiseMultiType> checker{1}; | |||
| checker.set_param({ElemwiseMultiType::Mode::FUSE_MUL_ADD3_INT16x32x32x32}); | |||