| @@ -451,7 +451,12 @@ namespace fallback { | |||
| void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { | |||
| check_exec(src.layout, dst.layout); | |||
| if (src.layout.is_contiguous() && dst.layout.is_contiguous()) { | |||
| auto is_quantize_lowbit = [](const DType& dt) { | |||
| return dt.category() == DTypeCategory::QUANTIZED && dt.is_low_bit(); | |||
| }; | |||
| if (src.layout.is_contiguous() && dst.layout.is_contiguous() && | |||
| !is_quantize_lowbit(src.layout.dtype) && | |||
| !is_quantize_lowbit(dst.layout.dtype)) { | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR(run_contiguous(src, dst)); | |||
| } else { | |||
| naive::TypeCvtImpl::exec(src, dst); | |||
| @@ -32,7 +32,7 @@ def get_scale(dtype): | |||
| def get_zero_point(dtype): | |||
| assert is_quantize(dtype) | |||
| metadata = dtype.metadata["mgb_dtype"] | |||
| assert metadata["name"] == "Quantized8Asymm" | |||
| assert metadata["name"] in ("Quantized8Asymm", "Quantized4Asymm") | |||
| return metadata["zero_point"] | |||
| @@ -79,6 +79,38 @@ def qint32(scale): | |||
| ) | |||
| def quint4(scale, zero_point): | |||
| """ | |||
| Consturct a quantized unsigned int4 data type with ``scale`` (float) and | |||
| ``zero_point`` (uint8). The real value represented by a quint4 data type is | |||
| float_val = scale * (uint4_val - zero_point) | |||
| """ | |||
| int_zp = int(zero_point) | |||
| assert int_zp == zero_point, "zero_point should be an integer" | |||
| if int_zp < 0 or int_zp > 15: | |||
| raise ValueError("zero_point should be within [0, 15] for quint4") | |||
| return np.dtype( | |||
| np.uint8, | |||
| metadata={ | |||
| "mgb_dtype": { | |||
| "name": "Quantized4Asymm", | |||
| "scale": float(scale), | |||
| "zero_point": int(zero_point), | |||
| } | |||
| }, | |||
| ) | |||
| def qint4(scale): | |||
| """ | |||
| Construct a quantized int4 data type with ``scale`` (float). The real value | |||
| represented by a qint4 data type is float_val = scale * int4_val | |||
| """ | |||
| return np.dtype( | |||
| np.int8, metadata={"mgb_dtype": {"name": "QuantizedS4", "scale": float(scale)}} | |||
| ) | |||
| def convert_to_quint8(arr, q): | |||
| """ | |||
| Quantize a float NumPy ndarray into a quint8 one with specified params. | |||
| @@ -177,3 +209,71 @@ def convert_from_qint32(arr): | |||
| ), "arr should be a ndarray with qint8 dtype" | |||
| scale = arr.dtype.metadata["mgb_dtype"]["scale"] | |||
| return arr.astype(np.float32) * scale | |||
| def convert_to_quint4(arr, q): | |||
| """ | |||
| Quantize a float NumPy ndarray into a quint4 one with specified params. | |||
| :param arr: Input ndarray. | |||
| :type arr: :class:`np.ndarray` | |||
| :param q: Target data type, should be a quint4. | |||
| :type q: :class:`np.dtype` | |||
| """ | |||
| assert isinstance(arr, np.ndarray) | |||
| assert ( | |||
| "mgb_dtype" in q.metadata | |||
| and q.metadata["mgb_dtype"]["name"] == "Quantized4Asymm" | |||
| ), "q should be a quint4 dtype" | |||
| scale, zp = q.metadata["mgb_dtype"]["scale"], q.metadata["mgb_dtype"]["zero_point"] | |||
| return (np.round(arr / scale) + zp).clip(0, 15).astype(q) | |||
| def convert_from_quint4(arr): | |||
| """ | |||
| Dequantize a quint4 NumPy ndarray into a float one. | |||
| :param arr: Input ndarray. | |||
| """ | |||
| assert isinstance(arr, np.ndarray) | |||
| assert ( | |||
| "mgb_dtype" in arr.dtype.metadata | |||
| and arr.dtype.metadata["mgb_dtype"]["name"] == "Quantized4Asymm" | |||
| ), "arr should be a ndarray with quint4 dtype" | |||
| scale, zp = ( | |||
| arr.dtype.metadata["mgb_dtype"]["scale"], | |||
| arr.dtype.metadata["mgb_dtype"]["zero_point"], | |||
| ) | |||
| return (arr.astype(np.float32) - zp) * scale | |||
| def convert_to_qint4(arr, q): | |||
| """ | |||
| Quantize a float NumPy ndarray into a qint4 one with specified params. | |||
| :param arr: Input ndarray. | |||
| :type arr: :class:`np.ndarray` | |||
| :param q: Target data type, should be a qint4. | |||
| :type q: :class:`np.dtype` | |||
| """ | |||
| assert isinstance(arr, np.ndarray) | |||
| assert ( | |||
| "mgb_dtype" in q.metadata and q.metadata["mgb_dtype"]["name"] == "QuantizedS4" | |||
| ), "q should be a qint4 dtype" | |||
| scale = q.metadata["mgb_dtype"]["scale"] | |||
| return (np.round(arr / scale)).clip(-8, 7).astype(q) | |||
| def convert_from_qint4(arr): | |||
| """ | |||
| Dequantize a qint4 NumPy ndarray into a float one. | |||
| :param arr: Input ndarray. | |||
| """ | |||
| assert isinstance(arr, np.ndarray) | |||
| assert ( | |||
| "mgb_dtype" in arr.dtype.metadata | |||
| and arr.dtype.metadata["mgb_dtype"]["name"] == "QuantizedS4" | |||
| ), "arr should be a ndarray with qint4 dtype" | |||
| scale = arr.dtype.metadata["mgb_dtype"]["scale"] | |||
| return arr.astype(np.float32) * scale | |||
| @@ -452,6 +452,23 @@ std::unique_ptr<PyArray_Descr, PyArrayDescrDeleter> dtype_mgb2np_descr( | |||
| {{"scale", PyFloat_FromDouble(param.scale)}}); | |||
| break; | |||
| } | |||
| case DTypeEnum::Quantized4Asymm: { | |||
| auto& param = dtype.param<dtype::Quantized4Asymm>(); | |||
| type_descr = PyArray_DescrNewFromType(NPY_UINT8); | |||
| type_descr->metadata = build_mgb_dtype_dict( | |||
| DTypeTrait<dtype::Quantized4Asymm>::name, | |||
| {{"scale", PyFloat_FromDouble(param.scale)}, | |||
| {"zero_point", PyLong_FromLong(param.zero_point)}}); | |||
| break; | |||
| } | |||
| case DTypeEnum::QuantizedS4: { | |||
| auto& param = dtype.param<dtype::QuantizedS4>(); | |||
| type_descr = PyArray_DescrNewFromType(NPY_INT8); | |||
| type_descr->metadata = build_mgb_dtype_dict( | |||
| DTypeTrait<dtype::QuantizedS4>::name, | |||
| {{"scale", PyFloat_FromDouble(param.scale)}}); | |||
| break; | |||
| } | |||
| case DTypeEnum::QuantizedS32: { | |||
| auto& param = dtype.param<dtype::QuantizedS32>(); | |||
| type_descr = PyArray_DescrNewFromType(NPY_INT32); | |||
| @@ -529,7 +546,29 @@ DType dtype_np2mgb_descr(PyArray_Descr* descr) { | |||
| static_cast<float>(PyFloat_AS_DOUBLE(scale_py)), | |||
| static_cast<uint8_t>(zero_point)); | |||
| } | |||
| if (dtype_name == "QuantizedS32" || dtype_name == "QuantizedS8") { | |||
| if (dtype_name == "Quantized4Asymm") { | |||
| PyObject* scale_py = PyDict_GetItemString(metadata, "scale"); | |||
| PyObject* zero_point_py = | |||
| PyDict_GetItemString(metadata, "zero_point"); | |||
| mgb_assert(scale_py && zero_point_py, | |||
| "Invalid Quantized4Asymm metadata: missing scale or " | |||
| "zero_point."); | |||
| mgb_assert( | |||
| PyFloat_Check(scale_py), | |||
| "Invalid Quantized4Asymm metadata: scale should be float"); | |||
| mgb_assert(PyLong_Check(zero_point_py), | |||
| "Invalid Quantized4Asymm metadata: zero_point should be " | |||
| "integer"); | |||
| auto zero_point = PyLong_AS_LONG(zero_point_py); | |||
| mgb_assert(zero_point >= 0 && zero_point < 15, | |||
| "Invalid Quantized4Asymm metadata: zero_point should be " | |||
| "in [0, 15)"); | |||
| return dtype::Quantized4Asymm( | |||
| static_cast<float>(PyFloat_AS_DOUBLE(scale_py)), | |||
| static_cast<uint8_t>(zero_point)); | |||
| } | |||
| if (dtype_name == "QuantizedS32" || dtype_name == "QuantizedS8" || | |||
| dtype_name == "QuantizedS4") { | |||
| PyObject* scale_py = PyDict_GetItemString(metadata, "scale"); | |||
| mgb_assert(scale_py, "Invalid metadata: missing scale"); | |||
| mgb_assert(PyFloat_Check(scale_py), | |||
| @@ -537,8 +576,10 @@ DType dtype_np2mgb_descr(PyArray_Descr* descr) { | |||
| float scale = static_cast<float>(PyFloat_AS_DOUBLE(scale_py)); | |||
| if (dtype_name == "QuantizedS32") { | |||
| return dtype::QuantizedS32(scale); | |||
| } else { | |||
| } else if (dtype_name == "QuantizedS8"){ | |||
| return dtype::QuantizedS8(scale); | |||
| } else { | |||
| return dtype::QuantizedS4(scale); | |||
| } | |||
| } | |||
| throw ConversionError( | |||
| @@ -14,6 +14,7 @@ | |||
| #include "megbrain/exception.h" | |||
| #include "megbrain/utils/metahelper.h" | |||
| #include "megbrain/utils/arith_helper.h" | |||
| #include "megdnn/dtype.h" | |||
| #include <cmath> | |||
| #include <cstring> | |||
| @@ -357,6 +358,52 @@ struct LowbitMemcpy<bits, true> { | |||
| } | |||
| } | |||
| }; | |||
| template<typename DT> | |||
| struct QuantizedLowbitTrait; | |||
| template<> | |||
| struct QuantizedLowbitTrait<dtype::Quantized4Asymm> { | |||
| static constexpr int8_t SHIFT = 0; | |||
| }; | |||
| template<> | |||
| struct QuantizedLowbitTrait<dtype::QuantizedS4> { | |||
| static constexpr int8_t SHIFT = 8; | |||
| }; | |||
| template <typename DT, bool div_byte = (DTypeTrait<DT>::category == | |||
| DTypeCategory::QUANTIZED) && | |||
| (8 % DTypeTrait<DT>::low_bit == 0)> | |||
| struct QuantizedLowbitMemcpy; | |||
| template <typename DT> | |||
| struct QuantizedLowbitMemcpy<DT, true> { | |||
| // cast with bits that 8 % bits == 0 | |||
| static constexpr uint16_t bits = DTypeTrait<DT>::low_bit; | |||
| static constexpr uint8_t MASK = (1 << bits) - 1; | |||
| using Trait = QuantizedLowbitTrait<DT>; | |||
| static void byte2compact(void* dest_raw, const void* src_raw, size_t n) { | |||
| auto dest = static_cast<uint8_t*>(dest_raw); | |||
| auto src = static_cast<const int8_t*>(src_raw); | |||
| memset(dest, 0, divup<size_t>(n * bits, 8)); | |||
| for (size_t i = 0; i < n; ++i) { | |||
| int8_t val = src[i] + Trait::SHIFT; | |||
| mgb_assert(val >= 0 && val < (1 << bits)); | |||
| dest[i * bits / 8] |= val << (i * bits % 8); | |||
| } | |||
| } | |||
| static void compact2byte(void* dest_raw, const void* src_raw, size_t n) { | |||
| auto dest = static_cast<int8_t*>(dest_raw); | |||
| auto src = static_cast<const uint8_t*>(src_raw); | |||
| for (size_t i = 0; i < n; ++i) { | |||
| int8_t val = ((src[i * bits / 8] >> (i * bits % 8)) & MASK); | |||
| dest[i] = val - Trait::SHIFT; | |||
| } | |||
| } | |||
| }; | |||
| } // anonymous namespace | |||
| void mgb::lowbit_memcpy_byte2compact( | |||
| @@ -365,6 +412,11 @@ void mgb::lowbit_memcpy_byte2compact( | |||
| if (dtype == mgb::dtype::name##bits()) \ | |||
| return LowbitMemcpy<bits>::byte2compact(dest, src, n); | |||
| MEGDNN_FOREACH_LOWBIT_DTYPE(cb) | |||
| #undef cb | |||
| #define cb(dt) \ | |||
| if (dtype.enumv() == DTypeTrait<dt>::enumv) \ | |||
| return QuantizedLowbitMemcpy<dt>::byte2compact(dest, src, n); | |||
| MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) | |||
| #undef cb | |||
| mgb_throw(MegBrainError, "bad dtype for lowbit: %s", dtype.name()); | |||
| } | |||
| @@ -375,6 +427,11 @@ void mgb::lowbit_memcpy_compact2byte( | |||
| if (dtype == mgb::dtype::name##bits()) \ | |||
| return LowbitMemcpy<bits>::compact2byte(dest, src, n); | |||
| MEGDNN_FOREACH_LOWBIT_DTYPE(cb) | |||
| #undef cb | |||
| #define cb(dt) \ | |||
| if (dtype.enumv() == DTypeTrait<dt>::enumv) \ | |||
| return QuantizedLowbitMemcpy<dt>::compact2byte(dest, src, n); | |||
| MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) | |||
| #undef cb | |||
| mgb_throw(MegBrainError, "bad dtype for lowbit: %s", dtype.name()); | |||
| } | |||