GitOrigin-RevId: e0f97052ff
tags/v0.4.0
| @@ -132,6 +132,15 @@ namespace megdnn { | |||||
| cb(::megdnn::dtype::Quantized4Asymm) \ | cb(::megdnn::dtype::Quantized4Asymm) \ | ||||
| cb(::megdnn::dtype::QuantizedS4) | cb(::megdnn::dtype::QuantizedS4) | ||||
| #define MEGDNN_FOREACH_QUANTIZED_DTYPE_SYMM(cb) \ | |||||
| cb(::megdnn::dtype::QuantizedS32) \ | |||||
| cb(::megdnn::dtype::QuantizedS8) \ | |||||
| cb(::megdnn::dtype::QuantizedS4) | |||||
| #define MEGDNN_FOREACH_QUANTIZED_DTYPE_ASYMM(cb) \ | |||||
| cb(::megdnn::dtype::Quantized8Asymm) \ | |||||
| cb(::megdnn::dtype::Quantized4Asymm) | |||||
| /*! | /*! | ||||
| * \brief a POD representation of a single byte | * \brief a POD representation of a single byte | ||||
| * | * | ||||
| @@ -604,9 +604,16 @@ void ConvolutionBase<Parameter>::check_or_deduce_dtype_fwd(DType src, | |||||
| if (!dst.valid()) { | if (!dst.valid()) { | ||||
| dst = supported_dst_dtype.at(0); | dst = supported_dst_dtype.at(0); | ||||
| } else { | } else { | ||||
| megdnn_assert(vec_contains(supported_dst_dtype, dst), | |||||
| "unsupported Conv(%s, %s) -> %s", src.name(), | |||||
| filter.name(), dst.name()); | |||||
| bool dst_supported = false; | |||||
| for (auto&& dt : supported_dst_dtype) { | |||||
| if (dtype_almost_equal(dt, dst)) { | |||||
| dst_supported = true; | |||||
| break; | |||||
| } | |||||
| } | |||||
| MEGDNN_MARK_USED_VAR(dst_supported); | |||||
| megdnn_assert(dst_supported, "unsupported Conv(%s, %s) -> %s", | |||||
| src.name(), filter.name(), dst.name()); | |||||
| } | } | ||||
| megdnn_assert(param().compute_mode != Param::ComputeMode::FLOAT32 | megdnn_assert(param().compute_mode != Param::ComputeMode::FLOAT32 | ||||
| #if !MEGDNN_DISABLE_FLOAT16 | #if !MEGDNN_DISABLE_FLOAT16 | ||||
| @@ -245,6 +245,25 @@ float megdnn::mul_scale(DType lhs, DType rhs) { | |||||
| } | } | ||||
| // clang-format on | // clang-format on | ||||
| bool megdnn::dtype_almost_equal(DType lhs, DType rhs) { | |||||
| if (lhs.enumv() != rhs.enumv()) | |||||
| return false; | |||||
| if (lhs.category() != DTypeCategory::QUANTIZED) | |||||
| return true; | |||||
| #define cb(dt) \ | |||||
| if (lhs.enumv() == DTypeTrait<dt>::enumv) \ | |||||
| return almost_equal(lhs.param<dt>().scale, rhs.param<dt>().scale); | |||||
| MEGDNN_FOREACH_QUANTIZED_DTYPE_SYMM(cb) | |||||
| #undef cb | |||||
| #define cb(dt) \ | |||||
| if (lhs.enumv() == DTypeTrait<dt>::enumv) \ | |||||
| return almost_equal(lhs.param<dt>().scale, rhs.param<dt>().scale) && \ | |||||
| lhs.param<dt>().zero_point == rhs.param<dt>().zero_point; | |||||
| MEGDNN_FOREACH_QUANTIZED_DTYPE_ASYMM(cb) | |||||
| #undef cb | |||||
| megdnn_assert_internal(false); | |||||
| } | |||||
| template <> | template <> | ||||
| uint8_t megdnn::convert<dt_quint4, uint8_t>(dt_quint4 src, uint8_t dst, | uint8_t megdnn::convert<dt_quint4, uint8_t>(dt_quint4 src, uint8_t dst, | ||||
| size_t offset) { | size_t offset) { | ||||
| @@ -434,6 +434,20 @@ int8_t convert<dt_qint4, int8_t>(dt_qint4 src, int8_t dst, size_t offset); | |||||
| template <> | template <> | ||||
| dt_qint4 convert<int8_t, dt_qint4>(int8_t src, dt_qint4 dst, size_t offset); | dt_qint4 convert<int8_t, dt_qint4>(int8_t src, dt_qint4 dst, size_t offset); | ||||
| /*! | |||||
| * \brief check float equal within given ULP(unit in the last place) | |||||
| */ | |||||
| template <class T> | |||||
| static inline | |||||
| typename std::enable_if<!std::numeric_limits<T>::is_integer, bool>::type | |||||
| almost_equal(T x, T y, int unit_last_place = 1) { | |||||
| return std::abs(x - y) < (std::numeric_limits<T>::epsilon() * | |||||
| std::abs(x + y) * unit_last_place) || | |||||
| std::abs(x - y) < std::numeric_limits<T>::min(); | |||||
| } | |||||
| bool dtype_almost_equal(DType lhs, DType rhs); | |||||
| /** | /** | ||||
| * \brief N-dimensional index space | * \brief N-dimensional index space | ||||
| */ | */ | ||||