| @@ -451,7 +451,12 @@ namespace fallback { | |||||
| void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { | void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { | ||||
| check_exec(src.layout, dst.layout); | check_exec(src.layout, dst.layout); | ||||
| if (src.layout.is_contiguous() && dst.layout.is_contiguous()) { | |||||
| auto is_quantize_lowbit = [](const DType& dt) { | |||||
| return dt.category() == DTypeCategory::QUANTIZED && dt.is_low_bit(); | |||||
| }; | |||||
| if (src.layout.is_contiguous() && dst.layout.is_contiguous() && | |||||
| !is_quantize_lowbit(src.layout.dtype) && | |||||
| !is_quantize_lowbit(dst.layout.dtype)) { | |||||
| MEGDNN_DISPATCH_CPU_KERN_OPR(run_contiguous(src, dst)); | MEGDNN_DISPATCH_CPU_KERN_OPR(run_contiguous(src, dst)); | ||||
| } else { | } else { | ||||
| naive::TypeCvtImpl::exec(src, dst); | naive::TypeCvtImpl::exec(src, dst); | ||||
| @@ -32,7 +32,7 @@ def get_scale(dtype): | |||||
| def get_zero_point(dtype): | def get_zero_point(dtype): | ||||
| assert is_quantize(dtype) | assert is_quantize(dtype) | ||||
| metadata = dtype.metadata["mgb_dtype"] | metadata = dtype.metadata["mgb_dtype"] | ||||
| assert metadata["name"] == "Quantized8Asymm" | |||||
| assert metadata["name"] in ("Quantized8Asymm", "Quantized4Asymm") | |||||
| return metadata["zero_point"] | return metadata["zero_point"] | ||||
| @@ -79,6 +79,38 @@ def qint32(scale): | |||||
| ) | ) | ||||
| def quint4(scale, zero_point): | |||||
| """ | |||||
| Consturct a quantized unsigned int4 data type with ``scale`` (float) and | |||||
| ``zero_point`` (uint8). The real value represented by a quint4 data type is | |||||
| float_val = scale * (uint4_val - zero_point) | |||||
| """ | |||||
| int_zp = int(zero_point) | |||||
| assert int_zp == zero_point, "zero_point should be an integer" | |||||
| if int_zp < 0 or int_zp > 15: | |||||
| raise ValueError("zero_point should be within [0, 15] for quint4") | |||||
| return np.dtype( | |||||
| np.uint8, | |||||
| metadata={ | |||||
| "mgb_dtype": { | |||||
| "name": "Quantized4Asymm", | |||||
| "scale": float(scale), | |||||
| "zero_point": int(zero_point), | |||||
| } | |||||
| }, | |||||
| ) | |||||
| def qint4(scale): | |||||
| """ | |||||
| Construct a quantized int4 data type with ``scale`` (float). The real value | |||||
| represented by a qint4 data type is float_val = scale * int4_val | |||||
| """ | |||||
| return np.dtype( | |||||
| np.int8, metadata={"mgb_dtype": {"name": "QuantizedS4", "scale": float(scale)}} | |||||
| ) | |||||
| def convert_to_quint8(arr, q): | def convert_to_quint8(arr, q): | ||||
| """ | """ | ||||
| Quantize a float NumPy ndarray into a quint8 one with specified params. | Quantize a float NumPy ndarray into a quint8 one with specified params. | ||||
| @@ -177,3 +209,71 @@ def convert_from_qint32(arr): | |||||
| ), "arr should be a ndarray with qint8 dtype" | ), "arr should be a ndarray with qint8 dtype" | ||||
| scale = arr.dtype.metadata["mgb_dtype"]["scale"] | scale = arr.dtype.metadata["mgb_dtype"]["scale"] | ||||
| return arr.astype(np.float32) * scale | return arr.astype(np.float32) * scale | ||||
| def convert_to_quint4(arr, q): | |||||
| """ | |||||
| Quantize a float NumPy ndarray into a quint4 one with specified params. | |||||
| :param arr: Input ndarray. | |||||
| :type arr: :class:`np.ndarray` | |||||
| :param q: Target data type, should be a quint4. | |||||
| :type q: :class:`np.dtype` | |||||
| """ | |||||
| assert isinstance(arr, np.ndarray) | |||||
| assert ( | |||||
| "mgb_dtype" in q.metadata | |||||
| and q.metadata["mgb_dtype"]["name"] == "Quantized4Asymm" | |||||
| ), "q should be a quint4 dtype" | |||||
| scale, zp = q.metadata["mgb_dtype"]["scale"], q.metadata["mgb_dtype"]["zero_point"] | |||||
| return (np.round(arr / scale) + zp).clip(0, 15).astype(q) | |||||
| def convert_from_quint4(arr): | |||||
| """ | |||||
| Dequantize a quint4 NumPy ndarray into a float one. | |||||
| :param arr: Input ndarray. | |||||
| """ | |||||
| assert isinstance(arr, np.ndarray) | |||||
| assert ( | |||||
| "mgb_dtype" in arr.dtype.metadata | |||||
| and arr.dtype.metadata["mgb_dtype"]["name"] == "Quantized4Asymm" | |||||
| ), "arr should be a ndarray with quint4 dtype" | |||||
| scale, zp = ( | |||||
| arr.dtype.metadata["mgb_dtype"]["scale"], | |||||
| arr.dtype.metadata["mgb_dtype"]["zero_point"], | |||||
| ) | |||||
| return (arr.astype(np.float32) - zp) * scale | |||||
| def convert_to_qint4(arr, q): | |||||
| """ | |||||
| Quantize a float NumPy ndarray into a qint4 one with specified params. | |||||
| :param arr: Input ndarray. | |||||
| :type arr: :class:`np.ndarray` | |||||
| :param q: Target data type, should be a qint4. | |||||
| :type q: :class:`np.dtype` | |||||
| """ | |||||
| assert isinstance(arr, np.ndarray) | |||||
| assert ( | |||||
| "mgb_dtype" in q.metadata and q.metadata["mgb_dtype"]["name"] == "QuantizedS4" | |||||
| ), "q should be a qint4 dtype" | |||||
| scale = q.metadata["mgb_dtype"]["scale"] | |||||
| return (np.round(arr / scale)).clip(-8, 7).astype(q) | |||||
| def convert_from_qint4(arr): | |||||
| """ | |||||
| Dequantize a qint4 NumPy ndarray into a float one. | |||||
| :param arr: Input ndarray. | |||||
| """ | |||||
| assert isinstance(arr, np.ndarray) | |||||
| assert ( | |||||
| "mgb_dtype" in arr.dtype.metadata | |||||
| and arr.dtype.metadata["mgb_dtype"]["name"] == "QuantizedS4" | |||||
| ), "arr should be a ndarray with qint4 dtype" | |||||
| scale = arr.dtype.metadata["mgb_dtype"]["scale"] | |||||
| return arr.astype(np.float32) * scale | |||||
| @@ -452,6 +452,23 @@ std::unique_ptr<PyArray_Descr, PyArrayDescrDeleter> dtype_mgb2np_descr( | |||||
| {{"scale", PyFloat_FromDouble(param.scale)}}); | {{"scale", PyFloat_FromDouble(param.scale)}}); | ||||
| break; | break; | ||||
| } | } | ||||
| case DTypeEnum::Quantized4Asymm: { | |||||
| auto& param = dtype.param<dtype::Quantized4Asymm>(); | |||||
| type_descr = PyArray_DescrNewFromType(NPY_UINT8); | |||||
| type_descr->metadata = build_mgb_dtype_dict( | |||||
| DTypeTrait<dtype::Quantized4Asymm>::name, | |||||
| {{"scale", PyFloat_FromDouble(param.scale)}, | |||||
| {"zero_point", PyLong_FromLong(param.zero_point)}}); | |||||
| break; | |||||
| } | |||||
| case DTypeEnum::QuantizedS4: { | |||||
| auto& param = dtype.param<dtype::QuantizedS4>(); | |||||
| type_descr = PyArray_DescrNewFromType(NPY_INT8); | |||||
| type_descr->metadata = build_mgb_dtype_dict( | |||||
| DTypeTrait<dtype::QuantizedS4>::name, | |||||
| {{"scale", PyFloat_FromDouble(param.scale)}}); | |||||
| break; | |||||
| } | |||||
| case DTypeEnum::QuantizedS32: { | case DTypeEnum::QuantizedS32: { | ||||
| auto& param = dtype.param<dtype::QuantizedS32>(); | auto& param = dtype.param<dtype::QuantizedS32>(); | ||||
| type_descr = PyArray_DescrNewFromType(NPY_INT32); | type_descr = PyArray_DescrNewFromType(NPY_INT32); | ||||
| @@ -529,7 +546,29 @@ DType dtype_np2mgb_descr(PyArray_Descr* descr) { | |||||
| static_cast<float>(PyFloat_AS_DOUBLE(scale_py)), | static_cast<float>(PyFloat_AS_DOUBLE(scale_py)), | ||||
| static_cast<uint8_t>(zero_point)); | static_cast<uint8_t>(zero_point)); | ||||
| } | } | ||||
| if (dtype_name == "QuantizedS32" || dtype_name == "QuantizedS8") { | |||||
| if (dtype_name == "Quantized4Asymm") { | |||||
| PyObject* scale_py = PyDict_GetItemString(metadata, "scale"); | |||||
| PyObject* zero_point_py = | |||||
| PyDict_GetItemString(metadata, "zero_point"); | |||||
| mgb_assert(scale_py && zero_point_py, | |||||
| "Invalid Quantized4Asymm metadata: missing scale or " | |||||
| "zero_point."); | |||||
| mgb_assert( | |||||
| PyFloat_Check(scale_py), | |||||
| "Invalid Quantized4Asymm metadata: scale should be float"); | |||||
| mgb_assert(PyLong_Check(zero_point_py), | |||||
| "Invalid Quantized4Asymm metadata: zero_point should be " | |||||
| "integer"); | |||||
| auto zero_point = PyLong_AS_LONG(zero_point_py); | |||||
| mgb_assert(zero_point >= 0 && zero_point < 15, | |||||
| "Invalid Quantized4Asymm metadata: zero_point should be " | |||||
| "in [0, 15)"); | |||||
| return dtype::Quantized4Asymm( | |||||
| static_cast<float>(PyFloat_AS_DOUBLE(scale_py)), | |||||
| static_cast<uint8_t>(zero_point)); | |||||
| } | |||||
| if (dtype_name == "QuantizedS32" || dtype_name == "QuantizedS8" || | |||||
| dtype_name == "QuantizedS4") { | |||||
| PyObject* scale_py = PyDict_GetItemString(metadata, "scale"); | PyObject* scale_py = PyDict_GetItemString(metadata, "scale"); | ||||
| mgb_assert(scale_py, "Invalid metadata: missing scale"); | mgb_assert(scale_py, "Invalid metadata: missing scale"); | ||||
| mgb_assert(PyFloat_Check(scale_py), | mgb_assert(PyFloat_Check(scale_py), | ||||
| @@ -537,8 +576,10 @@ DType dtype_np2mgb_descr(PyArray_Descr* descr) { | |||||
| float scale = static_cast<float>(PyFloat_AS_DOUBLE(scale_py)); | float scale = static_cast<float>(PyFloat_AS_DOUBLE(scale_py)); | ||||
| if (dtype_name == "QuantizedS32") { | if (dtype_name == "QuantizedS32") { | ||||
| return dtype::QuantizedS32(scale); | return dtype::QuantizedS32(scale); | ||||
| } else { | |||||
| } else if (dtype_name == "QuantizedS8"){ | |||||
| return dtype::QuantizedS8(scale); | return dtype::QuantizedS8(scale); | ||||
| } else { | |||||
| return dtype::QuantizedS4(scale); | |||||
| } | } | ||||
| } | } | ||||
| throw ConversionError( | throw ConversionError( | ||||
| @@ -14,6 +14,7 @@ | |||||
| #include "megbrain/exception.h" | #include "megbrain/exception.h" | ||||
| #include "megbrain/utils/metahelper.h" | #include "megbrain/utils/metahelper.h" | ||||
| #include "megbrain/utils/arith_helper.h" | #include "megbrain/utils/arith_helper.h" | ||||
| #include "megdnn/dtype.h" | |||||
| #include <cmath> | #include <cmath> | ||||
| #include <cstring> | #include <cstring> | ||||
| @@ -357,6 +358,52 @@ struct LowbitMemcpy<bits, true> { | |||||
| } | } | ||||
| } | } | ||||
| }; | }; | ||||
| template<typename DT> | |||||
| struct QuantizedLowbitTrait; | |||||
| template<> | |||||
| struct QuantizedLowbitTrait<dtype::Quantized4Asymm> { | |||||
| static constexpr int8_t SHIFT = 0; | |||||
| }; | |||||
| template<> | |||||
| struct QuantizedLowbitTrait<dtype::QuantizedS4> { | |||||
| static constexpr int8_t SHIFT = 8; | |||||
| }; | |||||
| template <typename DT, bool div_byte = (DTypeTrait<DT>::category == | |||||
| DTypeCategory::QUANTIZED) && | |||||
| (8 % DTypeTrait<DT>::low_bit == 0)> | |||||
| struct QuantizedLowbitMemcpy; | |||||
| template <typename DT> | |||||
| struct QuantizedLowbitMemcpy<DT, true> { | |||||
| // cast with bits that 8 % bits == 0 | |||||
| static constexpr uint16_t bits = DTypeTrait<DT>::low_bit; | |||||
| static constexpr uint8_t MASK = (1 << bits) - 1; | |||||
| using Trait = QuantizedLowbitTrait<DT>; | |||||
| static void byte2compact(void* dest_raw, const void* src_raw, size_t n) { | |||||
| auto dest = static_cast<uint8_t*>(dest_raw); | |||||
| auto src = static_cast<const int8_t*>(src_raw); | |||||
| memset(dest, 0, divup<size_t>(n * bits, 8)); | |||||
| for (size_t i = 0; i < n; ++i) { | |||||
| int8_t val = src[i] + Trait::SHIFT; | |||||
| mgb_assert(val >= 0 && val < (1 << bits)); | |||||
| dest[i * bits / 8] |= val << (i * bits % 8); | |||||
| } | |||||
| } | |||||
| static void compact2byte(void* dest_raw, const void* src_raw, size_t n) { | |||||
| auto dest = static_cast<int8_t*>(dest_raw); | |||||
| auto src = static_cast<const uint8_t*>(src_raw); | |||||
| for (size_t i = 0; i < n; ++i) { | |||||
| int8_t val = ((src[i * bits / 8] >> (i * bits % 8)) & MASK); | |||||
| dest[i] = val - Trait::SHIFT; | |||||
| } | |||||
| } | |||||
| }; | |||||
| } // anonymous namespace | } // anonymous namespace | ||||
| void mgb::lowbit_memcpy_byte2compact( | void mgb::lowbit_memcpy_byte2compact( | ||||
| @@ -365,6 +412,11 @@ void mgb::lowbit_memcpy_byte2compact( | |||||
| if (dtype == mgb::dtype::name##bits()) \ | if (dtype == mgb::dtype::name##bits()) \ | ||||
| return LowbitMemcpy<bits>::byte2compact(dest, src, n); | return LowbitMemcpy<bits>::byte2compact(dest, src, n); | ||||
| MEGDNN_FOREACH_LOWBIT_DTYPE(cb) | MEGDNN_FOREACH_LOWBIT_DTYPE(cb) | ||||
| #undef cb | |||||
| #define cb(dt) \ | |||||
| if (dtype.enumv() == DTypeTrait<dt>::enumv) \ | |||||
| return QuantizedLowbitMemcpy<dt>::byte2compact(dest, src, n); | |||||
| MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) | |||||
| #undef cb | #undef cb | ||||
| mgb_throw(MegBrainError, "bad dtype for lowbit: %s", dtype.name()); | mgb_throw(MegBrainError, "bad dtype for lowbit: %s", dtype.name()); | ||||
| } | } | ||||
| @@ -375,6 +427,11 @@ void mgb::lowbit_memcpy_compact2byte( | |||||
| if (dtype == mgb::dtype::name##bits()) \ | if (dtype == mgb::dtype::name##bits()) \ | ||||
| return LowbitMemcpy<bits>::compact2byte(dest, src, n); | return LowbitMemcpy<bits>::compact2byte(dest, src, n); | ||||
| MEGDNN_FOREACH_LOWBIT_DTYPE(cb) | MEGDNN_FOREACH_LOWBIT_DTYPE(cb) | ||||
| #undef cb | |||||
| #define cb(dt) \ | |||||
| if (dtype.enumv() == DTypeTrait<dt>::enumv) \ | |||||
| return QuantizedLowbitMemcpy<dt>::compact2byte(dest, src, n); | |||||
| MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) | |||||
| #undef cb | #undef cb | ||||
| mgb_throw(MegBrainError, "bad dtype for lowbit: %s", dtype.name()); | mgb_throw(MegBrainError, "bad dtype for lowbit: %s", dtype.name()); | ||||
| } | } | ||||