GitOrigin-RevId: b60a7b6cf8
tags/v1.7.0
| @@ -47,6 +47,8 @@ struct MaxPooler { | |||
| } | |||
| }; | |||
| //! WARNING:for Integer, if sum ctype_ set incorrectly may cause overflow such as | |||
| //! (stype_=ctype_ =int8_t) | |||
| template <typename stype_, typename ctype_> | |||
| struct MeanIncludePoolerBase { | |||
| using stype = stype_; | |||
| @@ -65,6 +67,7 @@ struct MeanIncludePooler : public MeanIncludePoolerBase<T, T> { | |||
| ctype get_ans() { return this->sum / this->count; } | |||
| }; | |||
| //! WARNING: the result is truncated | |||
| template <> | |||
| struct MeanIncludePooler<int8_t> : public MeanIncludePoolerBase<int8_t, int32_t> { | |||
| using MeanIncludePoolerBase::MeanIncludePoolerBase; | |||
| @@ -74,7 +77,9 @@ struct MeanIncludePooler<int8_t> : public MeanIncludePoolerBase<int8_t, int32_t> | |||
| std::numeric_limits<int8_t>::max()); | |||
| } | |||
| }; | |||
| /*! | |||
| * average pooling with zero point for quint8 | |||
| /*/ | |||
| template <> | |||
| struct MeanIncludePooler<dt_quint8> { | |||
| int32_t sum; | |||
| @@ -107,7 +112,7 @@ struct MeanIncludePooler<dt_quint8> { | |||
| /*! | |||
| * \brief Average pooling operation within a single window. | |||
| * Works on integers. Rounds toward +INF. | |||
| * Works on integers. Rounds toward nearest Integer | |||
| * \tparam T input data type | |||
| * \tparam U convert input data type to U before accumulating | |||
| * \tparam ICType data type for intermediate result | |||
| @@ -228,10 +233,11 @@ struct MeanExcludePooler { | |||
| /*! | |||
| * \brief Average pooling operation within a single window. | |||
| * Works on integers. Rounds toward +INF. | |||
| * Works on integers. Rounds toward nearest Integer | |||
| * \tparam T input data type | |||
| * \tparam U convert input data type to U before accumulating | |||
| * \tparam ICType data type for intermediate result | |||
| * WARNING:for Integer, if type U or ICType set incorrectly may cause overflow | |||
| */ | |||
| template <typename T, typename U, typename ICType = U> | |||
| struct MeanExcludeRoundedPooler { | |||
| @@ -256,6 +262,10 @@ struct MeanExcludeRoundedPooler { | |||
| } | |||
| }; | |||
| template <> | |||
| struct MeanExcludePooler<int8_t> : MeanExcludeRoundedPooler<int8_t, int8_t, int32_t> { | |||
| using MeanExcludeRoundedPooler::MeanExcludeRoundedPooler; | |||
| }; | |||
| template <> | |||
| struct MeanExcludePooler<dt_quint8> | |||
| : MeanExcludeRoundedPooler<dt_quint8, uint8_t, uint32_t> { | |||
| @@ -100,4 +100,35 @@ TEST_F(NAIVE, POOLING_QUANTIZED_Q4) { | |||
| TensorValueLowbit4({1, 1, 2, 2}, u4_dt, u8_avg_exclu_dst_vec)}); | |||
| } | |||
| } | |||
| TEST_F(NAIVE, POOLING_INT_AVERAGE) { | |||
| using Mode = Pooling::Param::Mode; | |||
| Checker<Pooling> checker(handle(), /* check_dispatch */ false); | |||
| auto dt = dtype::Int8(); | |||
| Pooling::Param param = {Mode::AVERAGE, 0, 0, 1, 1, 2, 2}; | |||
| Testcase input_positive{ | |||
| TensorValue( | |||
| {1, 1, 3, 3}, dt, {127, 127, 127, 127, 127, 127, 127, 127, 127}), | |||
| {}}; | |||
| Testcase input_negative{ | |||
| TensorValue( | |||
| {1, 1, 3, 3}, dt, | |||
| {-127, -127, -127, -127, -127, -127, -127, -127, -127}), | |||
| {}}; | |||
| checker.set_param(param).exect( | |||
| input_positive, | |||
| Testcase{{}, TensorValue({1, 1, 2, 2}, dt, {127, 127, 127, 127})}); | |||
| checker.set_param(param).exect( | |||
| input_negative, | |||
| Testcase{{}, TensorValue({1, 1, 2, 2}, dt, {-127, -127, -127, -127})}); | |||
| param = {Mode::AVERAGE_COUNT_EXCLUDE_PADDING, 0, 0, 1, 1, 2, 2}; | |||
| checker.set_param(param).exect( | |||
| input_positive, | |||
| Testcase{{}, TensorValue({1, 1, 2, 2}, dt, {127, 127, 127, 127})}); | |||
| checker.set_param(param).exect( | |||
| input_negative, | |||
| Testcase{{}, TensorValue({1, 1, 2, 2}, dt, {-127, -127, -127, -127})}); | |||
| } | |||
| // vim: syntax=cpp.doxygen | |||