| @@ -62,7 +62,7 @@ namespace megdnn { | |||||
| #define MEGDNN_FOREACH_PARAMETERIZED_DTYPE_OTHERS(cb) \ | #define MEGDNN_FOREACH_PARAMETERIZED_DTYPE_OTHERS(cb) \ | ||||
| cb(QuantizedS32) cb(QuantizedS8) cb(Quantized4Asymm) cb(QuantizedS4) \ | cb(QuantizedS32) cb(QuantizedS8) cb(Quantized4Asymm) cb(QuantizedS4) \ | ||||
| cb(QuantizedS16) | |||||
| cb(QuantizedS16) cb(QuantizedS1) | |||||
| #define MEGDNN_FOREACH_PARAMETERIZED_DTYPE_2(cb_first, cb_others) \ | #define MEGDNN_FOREACH_PARAMETERIZED_DTYPE_2(cb_first, cb_others) \ | ||||
| MEGDNN_FOREACH_PARAMETERIZED_DTYPE_FIRST(cb_first) \ | MEGDNN_FOREACH_PARAMETERIZED_DTYPE_FIRST(cb_first) \ | ||||
| @@ -112,7 +112,7 @@ namespace megdnn { | |||||
| #define MEGDNN_FOREACH_QUANTIZED_DTYPE_SYMM(cb) \ | #define MEGDNN_FOREACH_QUANTIZED_DTYPE_SYMM(cb) \ | ||||
| cb(::megdnn::dtype::QuantizedS32) cb(::megdnn::dtype::QuantizedS8) \ | cb(::megdnn::dtype::QuantizedS32) cb(::megdnn::dtype::QuantizedS8) \ | ||||
| cb(::megdnn::dtype::QuantizedS4) | |||||
| cb(::megdnn::dtype::QuantizedS4) cb(::megdnn::dtype::QuantizedS1) | |||||
| #define MEGDNN_FOREACH_QUANTIZED_DTYPE_ASYMM(cb) \ | #define MEGDNN_FOREACH_QUANTIZED_DTYPE_ASYMM(cb) \ | ||||
| cb(::megdnn::dtype::Quantized8Asymm) cb(::megdnn::dtype::Quantized4Asymm) | cb(::megdnn::dtype::Quantized8Asymm) cb(::megdnn::dtype::Quantized4Asymm) | ||||
| @@ -292,10 +292,27 @@ public: | |||||
| }; | }; | ||||
| using dt_qint4 = dt_qlowbit<4>; | using dt_qint4 = dt_qlowbit<4>; | ||||
| class dt_qint1 { | |||||
| int8_t _; | |||||
| public: | |||||
| MEGDNN_DEVICE int8_t as_int8() const { return _; } | |||||
| MEGDNN_HOST MEGDNN_DEVICE explicit dt_qint1(int8_t val) : _(val) {} | |||||
| #ifdef MEGDNN_CC_HOST | |||||
| explicit operator int8_t() { return _; } | |||||
| #endif | |||||
| bool operator<(const dt_qint1& b) const { return _ < b._; } | |||||
| bool operator>(const dt_qint1& b) const { return _ > b._; } | |||||
| bool operator==(const dt_qint1& b) const { return _ == b._; } | |||||
| bool operator!=(const dt_qint1& b) const { return _ != b._; } | |||||
| } MEGDNN_PACKED; | |||||
| #ifdef __clang__ | #ifdef __clang__ | ||||
| #pragma clang diagnostic pop | #pragma clang diagnostic pop | ||||
| #endif | #endif | ||||
| MEGDNN_STATIC_ASSERT(sizeof(dt_byte) == 1, "bad dt_byte size"); | MEGDNN_STATIC_ASSERT(sizeof(dt_byte) == 1, "bad dt_byte size"); | ||||
| MEGDNN_STATIC_ASSERT(sizeof(dt_qint1) == 1, "bad dt_qint1 size"); | |||||
| MEGDNN_STATIC_ASSERT(sizeof(dt_quint8) == 1, "bad dt_quint8 size"); | MEGDNN_STATIC_ASSERT(sizeof(dt_quint8) == 1, "bad dt_quint8 size"); | ||||
| MEGDNN_STATIC_ASSERT(sizeof(dt_qint16) == 2, "bad dt_qint16 size"); | MEGDNN_STATIC_ASSERT(sizeof(dt_qint16) == 2, "bad dt_qint16 size"); | ||||
| MEGDNN_STATIC_ASSERT(sizeof(dt_qint32) == 4, "bad dt_qint32 size"); | MEGDNN_STATIC_ASSERT(sizeof(dt_qint32) == 4, "bad dt_qint32 size"); | ||||
| @@ -677,7 +694,7 @@ MEGDNN_FOREACH_LOWBIT_DTYPE(MEGDNN_DEF_FRACTION_DT) | |||||
| return static_cast<_itype>(_maxval); \ | return static_cast<_itype>(_maxval); \ | ||||
| } \ | } \ | ||||
| }; | }; | ||||
| MEGDNN_DEF_PARAMETERIZED_DT(QuantizedS1, dt_qint1, int8_t, QUANTIZED, SIGNED, 0, 1, 0); | |||||
| MEGDNN_DEF_PARAMETERIZED_DT( | MEGDNN_DEF_PARAMETERIZED_DT( | ||||
| Quantized4Asymm, dt_quint4, uint8_t, QUANTIZED, SIGNED, 0, 15, 4); | Quantized4Asymm, dt_quint4, uint8_t, QUANTIZED, SIGNED, 0, 15, 4); | ||||
| MEGDNN_DEF_PARAMETERIZED_DT(QuantizedS4, dt_qint4, int8_t, QUANTIZED, SIGNED, -8, 7, 4); | MEGDNN_DEF_PARAMETERIZED_DT(QuantizedS4, dt_qint4, int8_t, QUANTIZED, SIGNED, -8, 7, 4); | ||||
| @@ -876,6 +893,26 @@ struct DTypeParamImpl<dt_quint4> { | |||||
| } | } | ||||
| }; | }; | ||||
| template <> | |||||
| struct DTypeParamImpl<dt_qint1> { | |||||
| float scale; | |||||
| DTypeParamImpl<dt_qint1>() = default; | |||||
| MGE_WIN_DECLSPEC_FUC DTypeParamImpl<dt_qint1>(float scale); | |||||
| #ifdef MEGDNN_CC_HOST | |||||
| std::size_t hash() const; | |||||
| #endif | |||||
| bool operator==(const DTypeParam<dt_qint1>& rhs) const; | |||||
| MEGDNN_DEVICE dt_qint1 quantize(float in) const { | |||||
| float v = in / scale; | |||||
| v = roundf(v); | |||||
| v = fmin(fmax(0.f, v), 1.f); | |||||
| return static_cast<dt_qint1>(v); | |||||
| } | |||||
| MEGDNN_DEVICE float dequantize(int8_t in) const { return in * scale; } | |||||
| MEGDNN_DEVICE float dequantize(dt_qint1 in) const { return in.as_int8() * scale; } | |||||
| }; | |||||
| template <> | template <> | ||||
| struct DTypeParamImpl<dt_qint4> { | struct DTypeParamImpl<dt_qint4> { | ||||
| float scale; | float scale; | ||||
| @@ -142,6 +142,19 @@ inline bool DTypeParam<dt_qint32>::operator==(const DTypeParam<dt_qint32>& rhs) | |||||
| return scale == rhs.scale; | return scale == rhs.scale; | ||||
| } | } | ||||
| DTypeParam<dt_qint1>::DTypeParamImpl(float scale) : scale{scale} { | |||||
| //! As the nan is not equal to any value | |||||
| megdnn_assert(!std::isnan(scale), "nan number compare is not support"); | |||||
| } | |||||
| inline std::size_t DTypeParam<dt_qint1>::hash() const { | |||||
| return std::hash<float>()(scale); | |||||
| } | |||||
| inline bool DTypeParam<dt_qint1>::operator==(const DTypeParam<dt_qint1>& rhs) const { | |||||
| return scale == rhs.scale; | |||||
| } | |||||
| DTypeParam<dt_quint4>::DTypeParamImpl(float scale, uint8_t zero_point) | DTypeParam<dt_quint4>::DTypeParamImpl(float scale, uint8_t zero_point) | ||||
| : scale{scale}, zero_point{zero_point} { | : scale{scale}, zero_point{zero_point} { | ||||
| //! As the nan is not equal to any value | //! As the nan is not equal to any value | ||||
| @@ -241,6 +241,7 @@ float megdnn::mul_scale(DType lhs, DType rhs) { | |||||
| return lhs.param<dt>().scale * rhs.param<dt>().scale; | return lhs.param<dt>().scale * rhs.param<dt>().scale; | ||||
| MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) | MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) | ||||
| MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) | MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) | ||||
| cb(::megdnn::dtype::QuantizedS1) | |||||
| #undef cb | #undef cb | ||||
| megdnn_assert_internal(0); | megdnn_assert_internal(0); | ||||
| } | } | ||||
| @@ -253,8 +254,9 @@ float megdnn::get_scale(DType dt) { | |||||
| return dt.param<_dt>().scale; | return dt.param<_dt>().scale; | ||||
| MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) | MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) | ||||
| MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) | MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) | ||||
| cb(::megdnn::dtype::QuantizedS1) | |||||
| #undef cb | #undef cb | ||||
| megdnn_assert_internal(0); | |||||
| megdnn_assert_internal(0); | |||||
| } | } | ||||
| bool megdnn::dtype_almost_equal(DType lhs, DType rhs) { | bool megdnn::dtype_almost_equal(DType lhs, DType rhs) { | ||||
| @@ -160,6 +160,9 @@ INST_FOR_CTYPE | |||||
| #define ct dt_bool | #define ct dt_bool | ||||
| INST_FOR_CTYPE | INST_FOR_CTYPE | ||||
| #undef ct | #undef ct | ||||
| #define ct dt_qint1 | |||||
| INST_FOR_CTYPE | |||||
| #undef ct | |||||
| #undef INST_FOR_CTYPE | #undef INST_FOR_CTYPE | ||||
| #undef INST | #undef INST | ||||
| @@ -210,6 +213,9 @@ INST_FOR_CTYPE | |||||
| #define ct dt_bool | #define ct dt_bool | ||||
| INST_FOR_CTYPE | INST_FOR_CTYPE | ||||
| #undef ct | #undef ct | ||||
| #define ct dt_qint1 | |||||
| INST_FOR_CTYPE | |||||
| #undef ct | |||||
| #undef ndim_cb | #undef ndim_cb | ||||
| @@ -221,6 +227,7 @@ INST(dt_int8); | |||||
| INST(dt_uint8); | INST(dt_uint8); | ||||
| INST(dt_bool); | INST(dt_bool); | ||||
| INST(dt_qint8); | INST(dt_qint8); | ||||
| INST(dt_qint1); | |||||
| INST(dt_quint8); | INST(dt_quint8); | ||||
| #undef dt_ibyte | #undef dt_ibyte | ||||
| @@ -96,6 +96,7 @@ INST(dt_bool, uchar4); | |||||
| #undef as_raw | #undef as_raw | ||||
| #define as_raw(x) x.as_int8() | #define as_raw(x) x.as_int8() | ||||
| INST(dt_qint8, char4); | INST(dt_qint8, char4); | ||||
| INST(dt_qint1, char4); | |||||
| #undef as_raw | #undef as_raw | ||||
| #define as_raw(x) x.as_uint8() | #define as_raw(x) x.as_uint8() | ||||
| INST(dt_quint8, uchar4); | INST(dt_quint8, uchar4); | ||||
| @@ -466,6 +467,7 @@ INST_PARAM_VECT_VISITOR; | |||||
| INST_DT_IBYTE(dt_int8); | INST_DT_IBYTE(dt_int8); | ||||
| INST_DT_IBYTE(dt_uint8); | INST_DT_IBYTE(dt_uint8); | ||||
| INST_DT_IBYTE(dt_qint8); | INST_DT_IBYTE(dt_qint8); | ||||
| INST_DT_IBYTE(dt_qint1); | |||||
| INST_DT_IBYTE(dt_quint8); | INST_DT_IBYTE(dt_quint8); | ||||
| INST_DT_IBYTE(dt_bool); | INST_DT_IBYTE(dt_bool); | ||||
| #undef INST_DT_IBYTE | #undef INST_DT_IBYTE | ||||
| @@ -1299,6 +1301,7 @@ private: | |||||
| INST_DT_IBYTE(dt_int8); | INST_DT_IBYTE(dt_int8); | ||||
| INST_DT_IBYTE(dt_uint8); | INST_DT_IBYTE(dt_uint8); | ||||
| INST_DT_IBYTE(dt_qint8); | INST_DT_IBYTE(dt_qint8); | ||||
| INST_DT_IBYTE(dt_qint1); | |||||
| INST_DT_IBYTE(dt_quint8); | INST_DT_IBYTE(dt_quint8); | ||||
| INST_DT_IBYTE(dt_bool); | INST_DT_IBYTE(dt_bool); | ||||
| #undef INST_DT_IBYTE | #undef INST_DT_IBYTE | ||||
| @@ -1649,6 +1652,7 @@ public: | |||||
| INST_DT_IBYTE(dt_int8); | INST_DT_IBYTE(dt_int8); | ||||
| INST_DT_IBYTE(dt_uint8); | INST_DT_IBYTE(dt_uint8); | ||||
| INST_DT_IBYTE(dt_qint8); | INST_DT_IBYTE(dt_qint8); | ||||
| INST_DT_IBYTE(dt_qint1); | |||||
| INST_DT_IBYTE(dt_quint8); | INST_DT_IBYTE(dt_quint8); | ||||
| INST_DT_IBYTE(dt_bool); | INST_DT_IBYTE(dt_bool); | ||||
| #undef INST_DT_IBYTE | #undef INST_DT_IBYTE | ||||
| @@ -88,6 +88,7 @@ struct TypeCvtOpToQuantized< | |||||
| typename std::enable_if< | typename std::enable_if< | ||||
| std::is_same<ctype_src, dt_int8>::value || | std::is_same<ctype_src, dt_int8>::value || | ||||
| std::is_same<ctype_src, dt_uint8>::value || | std::is_same<ctype_src, dt_uint8>::value || | ||||
| std::is_same<ctype_src, dt_qint1>::value || | |||||
| std::is_same<ctype_src, dt_bool>::value>::type> { | std::is_same<ctype_src, dt_bool>::value>::type> { | ||||
| ctype_dest* dest; | ctype_dest* dest; | ||||
| CudaDTypeParam<ctype_dest> param; | CudaDTypeParam<ctype_dest> param; | ||||
| @@ -111,6 +112,7 @@ struct TypeCvtOpFromQuantized< | |||||
| ctype_dest, ctype_src, | ctype_dest, ctype_src, | ||||
| typename std::enable_if< | typename std::enable_if< | ||||
| std::is_same<ctype_src, dt_qint8>::value || | std::is_same<ctype_src, dt_qint8>::value || | ||||
| std::is_same<ctype_src, dt_qint1>::value || | |||||
| std::is_same<ctype_src, dt_quint8>::value>::type> { | std::is_same<ctype_src, dt_quint8>::value>::type> { | ||||
| ctype_dest* dest; | ctype_dest* dest; | ||||
| CudaDTypeParam<ctype_src> param; | CudaDTypeParam<ctype_src> param; | ||||
| @@ -134,7 +136,8 @@ struct TypeCvtOpBetweenQuantized< | |||||
| ctype_dest, ctype_src, | ctype_dest, ctype_src, | ||||
| typename std::enable_if< | typename std::enable_if< | ||||
| (std::is_same<ctype_src, dt_qint8>::value || | (std::is_same<ctype_src, dt_qint8>::value || | ||||
| std::is_same<ctype_src, dt_quint8>::value) && | |||||
| std::is_same<ctype_src, dt_quint8>::value || | |||||
| std::is_same<ctype_src, dt_qint1>::value) && | |||||
| IsNotTypeQ4<ctype_dest>::value>::type> { | IsNotTypeQ4<ctype_dest>::value>::type> { | ||||
| ctype_dest* dest; | ctype_dest* dest; | ||||
| CudaDTypeParam<ctype_src> src_param; | CudaDTypeParam<ctype_src> src_param; | ||||
| @@ -306,6 +309,7 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src, cudaStream_t st | |||||
| cb(dtype_src, dt_quint8) \ | cb(dtype_src, dt_quint8) \ | ||||
| cb(dtype_src, dt_qint32) \ | cb(dtype_src, dt_qint32) \ | ||||
| cb(dtype_src, dt_qint8) \ | cb(dtype_src, dt_qint8) \ | ||||
| cb(dtype_src, dt_qint1) \ | |||||
| #define INST_SRC_QUANTIZED(dtype_src) \ | #define INST_SRC_QUANTIZED(dtype_src) \ | ||||
| MEGDNN_FOREACH_COMPUTING_DTYPE_WITH_DTYPE_SRC(dtype_src, INST_Q2N) \ | MEGDNN_FOREACH_COMPUTING_DTYPE_WITH_DTYPE_SRC(dtype_src, INST_Q2N) \ | ||||
| @@ -330,7 +334,8 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src, cudaStream_t st | |||||
| cb(dt_qint32) \ | cb(dt_qint32) \ | ||||
| cb(dt_qint8) \ | cb(dt_qint8) \ | ||||
| cb(dt_qint4) \ | cb(dt_qint4) \ | ||||
| cb(dt_quint4) | |||||
| cb(dt_quint4) \ | |||||
| cb(dt_qint1) | |||||
| MEGDNN_FOREACH_QUANTIZED_CTYPE(INST_SRC_QUANTIZED) | MEGDNN_FOREACH_QUANTIZED_CTYPE(INST_SRC_QUANTIZED) | ||||
| MEGDNN_FOREACH_COMPUTING_CTYPE(INST_SRC_NORMAL) | MEGDNN_FOREACH_COMPUTING_CTYPE(INST_SRC_NORMAL) | ||||
| @@ -50,6 +50,7 @@ void exec_src_quantized( | |||||
| return; \ | return; \ | ||||
| } | } | ||||
| MEGDNN_FOREACH_QUANTIZED_DTYPE(cb); | MEGDNN_FOREACH_QUANTIZED_DTYPE(cb); | ||||
| cb(::megdnn::dtype::QuantizedS1); | |||||
| default: | default: | ||||
| megdnn_assert_internal(0); | megdnn_assert_internal(0); | ||||
| #undef cb | #undef cb | ||||
| @@ -101,6 +102,7 @@ void exec_src_normal(const TensorND& dst, const TensorND& src, cudaStream_t stre | |||||
| return; \ | return; \ | ||||
| } | } | ||||
| MEGDNN_FOREACH_QUANTIZED_DTYPE(cb); | MEGDNN_FOREACH_QUANTIZED_DTYPE(cb); | ||||
| cb(::megdnn::dtype::QuantizedS1); | |||||
| #undef cb | #undef cb | ||||
| default: | default: | ||||
| megdnn_assert_internal(0); | megdnn_assert_internal(0); | ||||
| @@ -150,9 +152,9 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { | |||||
| } | } | ||||
| MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) | MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) | ||||
| MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) | MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) | ||||
| cb(::megdnn::dtype::QuantizedS1) | |||||
| #undef cb | #undef cb | ||||
| default: | |||||
| megdnn_assert_internal(0); | |||||
| default : megdnn_assert_internal(0); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -241,6 +241,23 @@ struct CudaDTypeParamImpl<dt_qint4> : DTypeParamImpl<dt_qint4> { | |||||
| } | } | ||||
| }; | }; | ||||
| template <> | |||||
| struct CudaDTypeParamImpl<dt_qint1> : DTypeParamImpl<dt_qint1> { | |||||
| float inv_scale; | |||||
| CudaDTypeParamImpl() = default; | |||||
| CudaDTypeParamImpl(float scale) | |||||
| : DTypeParamImpl<dt_qint1>(scale), inv_scale(1.0f / scale) {} | |||||
| CudaDTypeParamImpl(const DTypeParamImpl<dt_qint1>& param) | |||||
| : CudaDTypeParamImpl(param.scale) {} | |||||
| __device__ dt_qint1 quantize(float in) const { | |||||
| float v = in * inv_scale; | |||||
| v = roundf(v); | |||||
| v = fmin(fmax(0.f, v), 1.f); | |||||
| return static_cast<dt_qint1>(v); | |||||
| } | |||||
| }; | |||||
| #if MEGDNN_CC_CUDA | #if MEGDNN_CC_CUDA | ||||
| static inline MEGDNN_DEVICE void dot_prod(int a, int b, int c, int& d) { | static inline MEGDNN_DEVICE void dot_prod(int a, int b, int c, int& d) { | ||||
| #if __CUDA_ARCH__ >= 610 | #if __CUDA_ARCH__ >= 610 | ||||
| @@ -510,7 +510,9 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { | |||||
| }; | }; | ||||
| if (src.layout.is_contiguous() && dst.layout.is_contiguous() && | if (src.layout.is_contiguous() && dst.layout.is_contiguous() && | ||||
| !is_quantize_lowbit(src.layout.dtype) && | !is_quantize_lowbit(src.layout.dtype) && | ||||
| !is_quantize_lowbit(dst.layout.dtype)) { | |||||
| !is_quantize_lowbit(dst.layout.dtype) && | |||||
| dst.layout.dtype.enumv() != DTypeEnum::QuantizedS1 && | |||||
| src.layout.dtype.enumv() != DTypeEnum::QuantizedS1) { | |||||
| MEGDNN_DISPATCH_CPU_KERN_OPR(run_contiguous(src, dst)); | MEGDNN_DISPATCH_CPU_KERN_OPR(run_contiguous(src, dst)); | ||||
| } else { | } else { | ||||
| naive::TypeCvtImpl::exec(src, dst); | naive::TypeCvtImpl::exec(src, dst); | ||||
| @@ -79,8 +79,9 @@ void on_dest_ctype(HandleImpl* handle, const TensorND& dest, const TensorND& src | |||||
| MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) | MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) | ||||
| MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) | MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) | ||||
| cb(::megdnn::dtype::Bool) cb(::megdnn::dtype::Uint16) | cb(::megdnn::dtype::Bool) cb(::megdnn::dtype::Uint16) | ||||
| cb(::megdnn::dtype::QuantizedS1) | |||||
| #undef cb | #undef cb | ||||
| default : megdnn_throw("bad dtype"); | |||||
| default : megdnn_throw("bad dtype"); | |||||
| } | } | ||||
| } | } | ||||
| @@ -100,8 +101,9 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { | |||||
| MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) | MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) | ||||
| MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) | MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) | ||||
| cb(::megdnn::dtype::Bool) cb(::megdnn::dtype::Uint16) | cb(::megdnn::dtype::Bool) cb(::megdnn::dtype::Uint16) | ||||
| cb(::megdnn::dtype::QuantizedS1) | |||||
| #undef cb | #undef cb | ||||
| default : megdnn_throw("bad dtype"); | |||||
| default : megdnn_throw("bad dtype"); | |||||
| } | } | ||||
| } | } | ||||
| @@ -79,7 +79,8 @@ template <typename ctype> | |||||
| const char* expr0, const char* expr1, const TensorND& v0, const TensorND& v1, | const char* expr0, const char* expr1, const TensorND& v0, const TensorND& v1, | ||||
| float maxerr, float maxerr_avg, float maxerr_avg_biased, bool allow_invalid) { | float maxerr, float maxerr_avg, float maxerr_avg_biased, bool allow_invalid) { | ||||
| if (!std::is_same<ctype, dt_qint4>::value && | if (!std::is_same<ctype, dt_qint4>::value && | ||||
| !std::is_same<ctype, dt_quint4>::value) { | |||||
| !std::is_same<ctype, dt_quint4>::value && | |||||
| !std::is_same<ctype, dt_qint1>::value) { | |||||
| if (v0.layout.is_physical_contiguous() && v1.layout.is_physical_contiguous()) { | if (v0.layout.is_physical_contiguous() && v1.layout.is_physical_contiguous()) { | ||||
| return assert_tensor_eq_with_iter<ctype>( | return assert_tensor_eq_with_iter<ctype>( | ||||
| expr0, expr1, v0.ptr<ctype>(), v1.ptr<ctype>(), v0.layout, maxerr, | expr0, expr1, v0.ptr<ctype>(), v1.ptr<ctype>(), v0.layout, maxerr, | ||||
| @@ -158,7 +159,7 @@ void copy_tensors( | |||||
| //! In order to avoid an unnecessary increase in binary size, we just | //! In order to avoid an unnecessary increase in binary size, we just | ||||
| //! use QuantizedS16 dtype in winograd_filter_preprocess now. | //! use QuantizedS16 dtype in winograd_filter_preprocess now. | ||||
| cb(::megdnn::dtype::QuantizedS16) MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) | cb(::megdnn::dtype::QuantizedS16) MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) | ||||
| cb(::megdnn::dtype::Uint16) | |||||
| cb(::megdnn::dtype::Uint16) cb(::megdnn::dtype::QuantizedS1) | |||||
| #undef cb | #undef cb | ||||
| default : megdnn_trap(); | default : megdnn_trap(); | ||||
| } | } | ||||
| @@ -71,6 +71,32 @@ TEST(TestDType, TestQuantized8Asymm) { | |||||
| EXPECT_ANY_THROW(DType::from_enum(DTypeEnum::Quantized8Asymm)); | EXPECT_ANY_THROW(DType::from_enum(DTypeEnum::Quantized8Asymm)); | ||||
| } | } | ||||
| TEST(TestDType, QuantizedS1) { | |||||
| using namespace megdnn; | |||||
| dtype::QuantizedS1 qint1(0.1f); | |||||
| EXPECT_EQ(qint1.size(1), 1u); | |||||
| EXPECT_FLOAT_EQ(qint1.param().scale, 0.1f); | |||||
| dtype::QuantizedS1 qint1_copy = qint1; | |||||
| EXPECT_NO_THROW(qint1_copy.assert_is(qint1)); | |||||
| EXPECT_FLOAT_EQ(qint1_copy.param().scale, 0.1f); | |||||
| dtype::QuantizedS1 qint1_reconstruct_with_same_param(0.1f); | |||||
| EXPECT_NO_THROW(qint1_reconstruct_with_same_param.assert_is(qint1)); | |||||
| dtype::QuantizedS1 qint1_diff(0.2f); | |||||
| EXPECT_ANY_THROW(qint1_diff.assert_is(qint1)); | |||||
| DType parent = qint1; | |||||
| ASSERT_NO_THROW(dtype::QuantizedS1::downcast_from(parent)); | |||||
| auto param = dtype::QuantizedS1::downcast_from(parent).param(); | |||||
| EXPECT_FLOAT_EQ(param.scale, 0.1f); | |||||
| EXPECT_ANY_THROW(dtype::QuantizedS1::downcast_from(dtype::IntB1())); | |||||
| EXPECT_ANY_THROW(DType::from_enum(DTypeEnum::QuantizedS1)); | |||||
| } | |||||
| TEST(TestDType, TestQuantizedS4) { | TEST(TestDType, TestQuantizedS4) { | ||||
| using namespace megdnn; | using namespace megdnn; | ||||
| @@ -149,7 +149,7 @@ void IIDRNG::gen(const TensorND& tensor) { | |||||
| MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) | MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) | ||||
| //! In order to avoid an unnecessary increase in binary size, we just | //! In order to avoid an unnecessary increase in binary size, we just | ||||
| //! use QuantizedS16 dtype in winograd_filter_preprocess now. | //! use QuantizedS16 dtype in winograd_filter_preprocess now. | ||||
| cb(::megdnn::dtype::QuantizedS16) | |||||
| cb(::megdnn::dtype::QuantizedS16) cb(::megdnn::dtype::QuantizedS1) | |||||
| #undef cb | #undef cb | ||||
| if (tensor.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) { | if (tensor.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) { | ||||
| auto ptr = static_cast<uint8_t*>(tensor.raw_ptr()); | auto ptr = static_cast<uint8_t*>(tensor.raw_ptr()); | ||||
| @@ -226,6 +226,10 @@ static inline int diff(dt_qint4 x, dt_qint4 y) { | |||||
| return x.as_int8() - y.as_int8(); | return x.as_int8() - y.as_int8(); | ||||
| } | } | ||||
| static inline int diff(dt_qint1 x, dt_qint1 y) { | |||||
| return x.as_int8() - y.as_int8(); | |||||
| } | |||||
| static inline int diff(dt_quint4 x, dt_quint4 y) { | static inline int diff(dt_quint4 x, dt_quint4 y) { | ||||
| return x.as_uint8() - y.as_uint8(); | return x.as_uint8() - y.as_uint8(); | ||||
| } | } | ||||
| @@ -339,6 +343,10 @@ static inline bool good_float(dt_qint4) { | |||||
| return true; | return true; | ||||
| } | } | ||||
| static inline bool good_float(dt_qint1) { | |||||
| return true; | |||||
| } | |||||
| static inline bool good_float(dt_quint4) { | static inline bool good_float(dt_quint4) { | ||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -373,6 +381,11 @@ static inline int operator+(dt_qint4 lhs, int rhs) { | |||||
| megdnn_assert(rhs == 0, "unexpected rhs"); | megdnn_assert(rhs == 0, "unexpected rhs"); | ||||
| return lhs.as_int8(); | return lhs.as_int8(); | ||||
| } | } | ||||
| static inline int operator+(dt_qint1 lhs, int rhs) { | |||||
| megdnn_assert(rhs == 0, "unexpected rhs"); | |||||
| return lhs.as_int8(); | |||||
| } | |||||
| } // namespace test | } // namespace test | ||||
| static inline bool operator==(const TensorLayout& a, const TensorLayout& b) { | static inline bool operator==(const TensorLayout& a, const TensorLayout& b) { | ||||
| @@ -77,16 +77,19 @@ TEST_F(CUDA, QUANTIZED_TYPECVT) { | |||||
| }; | }; | ||||
| run(dtype::Float32(), dtype::QuantizedS8(3.0f)); | run(dtype::Float32(), dtype::QuantizedS8(3.0f)); | ||||
| run(dtype::Float32(), dtype::QuantizedS1(3.0f)); | |||||
| run(dtype::Float16(), dtype::QuantizedS8(3.0f)); | run(dtype::Float16(), dtype::QuantizedS8(3.0f)); | ||||
| run(dtype::Int32(), dtype::QuantizedS32(5.0f)); | run(dtype::Int32(), dtype::QuantizedS32(5.0f)); | ||||
| run(dtype::Int8(), dtype::QuantizedS32(10.0f)); | run(dtype::Int8(), dtype::QuantizedS32(10.0f)); | ||||
| run(dtype::Float32(), dtype::QuantizedS8(2e-3f)); | run(dtype::Float32(), dtype::QuantizedS8(2e-3f)); | ||||
| run(dtype::Float32(), dtype::QuantizedS1(2e-3f)); | |||||
| run(dtype::Float16(), dtype::QuantizedS8(1e-3f)); | run(dtype::Float16(), dtype::QuantizedS8(1e-3f)); | ||||
| run(dtype::Int32(), dtype::QuantizedS32(1e-3f)); | run(dtype::Int32(), dtype::QuantizedS32(1e-3f)); | ||||
| run(dtype::Int8(), dtype::QuantizedS32(7e-4f)); | run(dtype::Int8(), dtype::QuantizedS32(7e-4f)); | ||||
| run(dtype::QuantizedS8(3.0f), dtype::QuantizedS8(10.0f)); | run(dtype::QuantizedS8(3.0f), dtype::QuantizedS8(10.0f)); | ||||
| run(dtype::QuantizedS1(3.0f), dtype::QuantizedS1(10.0f)); | |||||
| run(dtype::QuantizedS32(3.0f), dtype::QuantizedS8(10.0f)); | run(dtype::QuantizedS32(3.0f), dtype::QuantizedS8(10.0f)); | ||||
| run(dtype::QuantizedS8(3.0f), dtype::QuantizedS32(10.0f)); | run(dtype::QuantizedS8(3.0f), dtype::QuantizedS32(10.0f)); | ||||
| run(dtype::QuantizedS32(3.0f), dtype::QuantizedS32(10.0f)); | run(dtype::QuantizedS32(3.0f), dtype::QuantizedS32(10.0f)); | ||||
| @@ -95,6 +98,7 @@ TEST_F(CUDA, QUANTIZED_TYPECVT) { | |||||
| run(dtype::QuantizedS32(2e-3f), dtype::QuantizedS8(9e-4f)); | run(dtype::QuantizedS32(2e-3f), dtype::QuantizedS8(9e-4f)); | ||||
| run(dtype::QuantizedS8(9e-4f), dtype::QuantizedS32(7e-4f)); | run(dtype::QuantizedS8(9e-4f), dtype::QuantizedS32(7e-4f)); | ||||
| run(dtype::QuantizedS32(5e-3f), dtype::QuantizedS32(1e-3f)); | run(dtype::QuantizedS32(5e-3f), dtype::QuantizedS32(1e-3f)); | ||||
| run(dtype::QuantizedS1(1e-3f), dtype::Float32()); | |||||
| run(dtype::Quantized8Asymm(5.0f, (uint8_t)128), dtype::Float32()); | run(dtype::Quantized8Asymm(5.0f, (uint8_t)128), dtype::Float32()); | ||||
| run(dtype::Quantized8Asymm(5.0f, (uint8_t)124), dtype::Float16()); | run(dtype::Quantized8Asymm(5.0f, (uint8_t)124), dtype::Float16()); | ||||
| @@ -94,6 +94,7 @@ _builtin_quant_dtypes = { | |||||
| "qint8_narrow": QuantDtypeMeta("qint8_narrow", "QuantizedS8", "int8", -127, 127), | "qint8_narrow": QuantDtypeMeta("qint8_narrow", "QuantizedS8", "int8", -127, 127), | ||||
| "quint4": QuantDtypeMeta("quint4", "Quantized4Asymm", "uint8", 0, 15), | "quint4": QuantDtypeMeta("quint4", "Quantized4Asymm", "uint8", 0, 15), | ||||
| "qint4": QuantDtypeMeta("qint4", "QuantizedS4", "int8", -8, 7), | "qint4": QuantDtypeMeta("qint4", "QuantizedS4", "int8", -8, 7), | ||||
| "qint1": QuantDtypeMeta("qint1", "QuantizedS1", "int8", 0, 1), | |||||
| "qint32": QuantDtypeMeta( | "qint32": QuantDtypeMeta( | ||||
| "qint32", "QuantizedS32", "int32", -(2 ** 31), 2 ** 31 - 1, | "qint32", "QuantizedS32", "int32", -(2 ** 31), 2 ** 31 - 1, | ||||
| ), | ), | ||||
| @@ -192,6 +193,13 @@ def qint4(scale): | |||||
| return create_quantized_dtype(_builtin_quant_dtypes["qint4"], scale, None) | return create_quantized_dtype(_builtin_quant_dtypes["qint4"], scale, None) | ||||
| def qint1(scale): | |||||
| r"""Construct a quantized int1 data type with ``scale`` (float). The real value | |||||
| represented by a qint1 data type is float_val = scale * int1_val | |||||
| """ | |||||
| return create_quantized_dtype(_builtin_quant_dtypes["qint1"], scale, None) | |||||
| def _convert_to_quantized_dtype( | def _convert_to_quantized_dtype( | ||||
| arr: np.ndarray, dtype: np.dtype, dtype_meta: QuantDtypeMeta | arr: np.ndarray, dtype: np.dtype, dtype_meta: QuantDtypeMeta | ||||
| ): | ): | ||||
| @@ -335,3 +343,22 @@ def convert_from_qint4(arr: np.ndarray): | |||||
| arr: Input ndarray. | arr: Input ndarray. | ||||
| """ | """ | ||||
| return _convert_from_quantized_dtype(arr, _builtin_quant_dtypes["qint4"]) | return _convert_from_quantized_dtype(arr, _builtin_quant_dtypes["qint4"]) | ||||
| def convert_to_qint1(arr: np.ndarray, q: np.dtype): | |||||
| r"""Quantize a float NumPy ndarray into a qint1 one with specified params. | |||||
| Args: | |||||
| arr: Input ndarray. | |||||
| q: Target data type, should be a qint1. | |||||
| """ | |||||
| return _convert_to_quantized_dtype(arr, q, _builtin_quant_dtypes["qint1"]) | |||||
| def convert_from_qint1(arr: np.ndarray): | |||||
| r"""Dequantize a qint1 NumPy ndarray into a float one. | |||||
| Args: | |||||
| arr: Input ndarray. | |||||
| """ | |||||
| return _convert_from_quantized_dtype(arr, _builtin_quant_dtypes["qint1"]) | |||||
| @@ -214,6 +214,14 @@ std::unique_ptr<PyArray_Descr, PyArrayDescrDeleter> dtype_mgb2np_descr(DType dty | |||||
| if (dtype.has_param()) { | if (dtype.has_param()) { | ||||
| PyArray_Descr* type_descr; | PyArray_Descr* type_descr; | ||||
| switch (dtype.enumv()) { | switch (dtype.enumv()) { | ||||
| case DTypeEnum::QuantizedS1: { | |||||
| auto& param = dtype.param<dtype::QuantizedS1>(); | |||||
| type_descr = PyArray_DescrNewFromType(NPY_INT8); | |||||
| type_descr->metadata = build_mgb_dtype_dict( | |||||
| DTypeTrait<dtype::QuantizedS1>::name, | |||||
| {{"scale", PyFloat_FromDouble(param.scale)}}); | |||||
| break; | |||||
| } | |||||
| case DTypeEnum::Quantized4Asymm: { | case DTypeEnum::Quantized4Asymm: { | ||||
| auto& param = dtype.param<dtype::Quantized4Asymm>(); | auto& param = dtype.param<dtype::Quantized4Asymm>(); | ||||
| type_descr = PyArray_DescrNewFromType(NPY_UINT8); | type_descr = PyArray_DescrNewFromType(NPY_UINT8); | ||||
| @@ -354,7 +362,7 @@ DType dtype_np2mgb_descr(PyArray_Descr* descr) { | |||||
| static_cast<uint8_t>(zero_point)); | static_cast<uint8_t>(zero_point)); | ||||
| } | } | ||||
| if (dtype_name == "QuantizedS32" || dtype_name == "QuantizedS8" || | if (dtype_name == "QuantizedS32" || dtype_name == "QuantizedS8" || | ||||
| dtype_name == "QuantizedS4") { | |||||
| dtype_name == "QuantizedS4" || dtype_name == "QuantizedS1") { | |||||
| PyObject* scale_py = PyDict_GetItemString(metadata, "scale"); | PyObject* scale_py = PyDict_GetItemString(metadata, "scale"); | ||||
| mgb_assert(scale_py, "Invalid metadata: missing scale"); | mgb_assert(scale_py, "Invalid metadata: missing scale"); | ||||
| mgb_assert( | mgb_assert( | ||||
| @@ -364,8 +372,10 @@ DType dtype_np2mgb_descr(PyArray_Descr* descr) { | |||||
| return dtype::QuantizedS32(scale); | return dtype::QuantizedS32(scale); | ||||
| } else if (dtype_name == "QuantizedS8") { | } else if (dtype_name == "QuantizedS8") { | ||||
| return dtype::QuantizedS8(scale); | return dtype::QuantizedS8(scale); | ||||
| } else { | |||||
| } else if (dtype_name == "QuantizedS4") { | |||||
| return dtype::QuantizedS4(scale); | return dtype::QuantizedS4(scale); | ||||
| } else if (dtype_name == "QuantizedS1") { | |||||
| return dtype::QuantizedS1(scale); | |||||
| } | } | ||||
| } | } | ||||
| throw ConversionError( | throw ConversionError( | ||||
| @@ -15,10 +15,12 @@ import megengine.core.tensor.megbrain_graph as G | |||||
| from megengine.core.ops import builtin as ops | from megengine.core.ops import builtin as ops | ||||
| from megengine.core.tensor.dtype import ( | from megengine.core.tensor.dtype import ( | ||||
| _builtin_quant_dtypes, | _builtin_quant_dtypes, | ||||
| convert_from_qint1, | |||||
| convert_from_qint4, | convert_from_qint4, | ||||
| convert_from_qint8, | convert_from_qint8, | ||||
| convert_from_quint4, | convert_from_quint4, | ||||
| convert_from_quint8, | convert_from_quint8, | ||||
| convert_to_qint1, | |||||
| convert_to_qint4, | convert_to_qint4, | ||||
| convert_to_qint8, | convert_to_qint8, | ||||
| convert_to_quint4, | convert_to_quint4, | ||||
| @@ -26,6 +28,7 @@ from megengine.core.tensor.dtype import ( | |||||
| get_scale, | get_scale, | ||||
| get_zero_point, | get_zero_point, | ||||
| is_quantize, | is_quantize, | ||||
| qint1, | |||||
| qint4, | qint4, | ||||
| qint8, | qint8, | ||||
| quint4, | quint4, | ||||
| @@ -113,9 +116,20 @@ def test_dtype_qint4(): | |||||
| np.testing.assert_allclose(get_scale(dt), 0.01) | np.testing.assert_allclose(get_scale(dt), 0.01) | ||||
| def test_dtype_qint1(): | |||||
| dt = qint1(0.01) | |||||
| assert isinstance(dt, np.dtype) | |||||
| assert "mgb_dtype" in dt.metadata | |||||
| np.testing.assert_allclose(dt.metadata["mgb_dtype"]["scale"], 0.01) | |||||
| assert is_quantize(dt) | |||||
| np.testing.assert_allclose(get_scale(dt), 0.01) | |||||
| @pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
| "dtype, dtype_name", | "dtype, dtype_name", | ||||
| [ | [ | ||||
| (qint1(0.01), "qint1"), | |||||
| (quint4(0.01, 5), "quint4"), | (quint4(0.01, 5), "quint4"), | ||||
| (qint4(0.01), "qint4"), | (qint4(0.01), "qint4"), | ||||
| (quint8(0.01, 135), "quint8"), | (quint8(0.01, 135), "quint8"), | ||||
| @@ -141,6 +155,7 @@ def test_dtype_qint_mgb_ffi_handle(dtype, dtype_name): | |||||
| @pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
| "dtype, dtype_name", | "dtype, dtype_name", | ||||
| [ | [ | ||||
| (qint1(0.01), "qint1"), | |||||
| (quint4(0.01, 5), "quint4"), | (quint4(0.01, 5), "quint4"), | ||||
| (qint4(0.01), "qint4"), | (qint4(0.01), "qint4"), | ||||
| (quint8(0.01, 135), "quint8"), | (quint8(0.01, 135), "quint8"), | ||||
| @@ -178,6 +193,7 @@ def test_qint_typecvt(dtype, dtype_name): | |||||
| @pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
| "dtype, dtype_name", | "dtype, dtype_name", | ||||
| [ | [ | ||||
| (qint1(0.01), "qint1"), | |||||
| (quint4(0.01, 5), "quint4"), | (quint4(0.01, 5), "quint4"), | ||||
| (qint4(0.01), "qint4"), | (qint4(0.01), "qint4"), | ||||
| (quint8(0.01, 135), "quint8"), | (quint8(0.01, 135), "quint8"), | ||||
| @@ -207,6 +223,7 @@ def test_qint_astype(dtype, dtype_name): | |||||
| @pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
| "dtype, dtype_name", | "dtype, dtype_name", | ||||
| [ | [ | ||||
| (qint1(0.01), "qint1"), | |||||
| (quint4(0.01, 5), "quint4"), | (quint4(0.01, 5), "quint4"), | ||||
| (qint4(0.01), "qint4"), | (qint4(0.01), "qint4"), | ||||
| (quint8(0.01, 135), "quint8"), | (quint8(0.01, 135), "quint8"), | ||||
| @@ -42,6 +42,10 @@ double as_double(megdnn::dt_qint4& a) { | |||||
| return static_cast<double>(a.as_int8()); | return static_cast<double>(a.as_int8()); | ||||
| } | } | ||||
| template <> | template <> | ||||
| double as_double(megdnn::dt_qint1& a) { | |||||
| return static_cast<double>(a.as_int8()); | |||||
| } | |||||
| template <> | |||||
| double as_double(megdnn::dt_qint32& a) { | double as_double(megdnn::dt_qint32& a) { | ||||
| return static_cast<double>(a.as_int32()); | return static_cast<double>(a.as_int32()); | ||||
| } | } | ||||
| @@ -111,7 +115,7 @@ void print_host_val( | |||||
| MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | ||||
| MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) | MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) | ||||
| MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) | MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) | ||||
| cb(dtype::Bool) | |||||
| cb(dtype::Bool) cb(::megdnn::dtype::QuantizedS1) | |||||
| #undef cb | #undef cb | ||||
| default : mgb_throw( | default : mgb_throw( | ||||
| MegBrainError, | MegBrainError, | ||||
| @@ -23,6 +23,7 @@ enum DTypeEnum : byte { | |||||
| BFloat16, | BFloat16, | ||||
| Bool, | Bool, | ||||
| Uint16, | Uint16, | ||||
| QuantizedS1, | |||||
| } | } | ||||
| table LinearQuantizationParam { | table LinearQuantizationParam { | ||||
| @@ -55,6 +55,8 @@ megdnn::DType load_dtype(const fbs::DType* dtype) { | |||||
| return dtype::_dt{}; | return dtype::_dt{}; | ||||
| MEGDNN_FOREACH_DTYPE_NAME(cb) | MEGDNN_FOREACH_DTYPE_NAME(cb) | ||||
| #undef cb | #undef cb | ||||
| case DTypeEnum_QuantizedS1: | |||||
| return dtype::QuantizedS1{param->scale()}; | |||||
| case DTypeEnum_QuantizedS4: | case DTypeEnum_QuantizedS4: | ||||
| return dtype::QuantizedS4{param->scale()}; | return dtype::QuantizedS4{param->scale()}; | ||||
| case DTypeEnum_QuantizedS8: | case DTypeEnum_QuantizedS8: | ||||
| @@ -113,6 +115,7 @@ flatbuffers::Offset<fbs::DType> build_dtype( | |||||
| break; | break; | ||||
| CASE_ASYMMETRIC(Quantized4Asymm) | CASE_ASYMMETRIC(Quantized4Asymm) | ||||
| CASE_ASYMMETRIC(Quantized8Asymm) | CASE_ASYMMETRIC(Quantized8Asymm) | ||||
| CASE_SYMMETRIC(QuantizedS1) | |||||
| CASE_SYMMETRIC(QuantizedS4) | CASE_SYMMETRIC(QuantizedS4) | ||||
| CASE_SYMMETRIC(QuantizedS8) | CASE_SYMMETRIC(QuantizedS8) | ||||
| CASE_SYMMETRIC(QuantizedS16) | CASE_SYMMETRIC(QuantizedS16) | ||||