GitOrigin-RevId: 2aec72010f
tags/v1.5.0
| @@ -86,6 +86,16 @@ struct RoundingConverter<dt_qint4> { | |||||
| } | } | ||||
| }; | }; | ||||
| template <> | |||||
| struct RoundingConverter<dt_quint4> { | |||||
| __host__ __device__ __forceinline__ dt_quint4 operator()(float x) const { | |||||
| #if MEGDNN_CC_HOST | |||||
| using std::round; | |||||
| #endif | |||||
| return static_cast<dt_quint4>(round(x)); | |||||
| } | |||||
| }; | |||||
| } // namespace rounding | } // namespace rounding | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| @@ -73,9 +73,10 @@ void WarpPerspectiveBase::check_layout_fwd(const TensorLayout& src, | |||||
| src.dtype.enumv() == DTypeEnum::Uint8 || | src.dtype.enumv() == DTypeEnum::Uint8 || | ||||
| (src.dtype.enumv() == DTypeEnum::QuantizedS8 || | (src.dtype.enumv() == DTypeEnum::QuantizedS8 || | ||||
| src.dtype.enumv() == DTypeEnum::Quantized8Asymm) || | src.dtype.enumv() == DTypeEnum::Quantized8Asymm) || | ||||
| src.dtype.enumv() == DTypeEnum::QuantizedS4, | |||||
| src.dtype.enumv() == DTypeEnum::QuantizedS4 || | |||||
| src.dtype.enumv() == DTypeEnum::Quantized4Asymm, | |||||
| "WarpPerspective NCHW input dtype should be " | "WarpPerspective NCHW input dtype should be " | ||||
| "Float32/Int8/Uint8/QInt8/QUint8" DNN_FLOAT16_SELECT( | |||||
| "Float32/Int8/Uint8/QInt8/QUint8/QInt4/QUInt4" DNN_FLOAT16_SELECT( | |||||
| "/Float16/BFloat16", "") "."); | "/Float16/BFloat16", "") "."); | ||||
| megdnn_assert( | megdnn_assert( | ||||
| (src.dtype.category() == DTypeCategory::FLOAT && | (src.dtype.category() == DTypeCategory::FLOAT && | ||||
| @@ -118,8 +119,9 @@ void WarpPerspectiveBase::check_layout_fwd(const TensorLayout& src, | |||||
| megdnn_assert(param().bmode != | megdnn_assert(param().bmode != | ||||
| param::WarpPerspective::BorderMode::ISOLATED); | param::WarpPerspective::BorderMode::ISOLATED); | ||||
| } else if (param().format == param::WarpPerspective::Format::NCHW64) { | } else if (param().format == param::WarpPerspective::Format::NCHW64) { | ||||
| megdnn_assert(src.dtype.enumv() == DTypeEnum::QuantizedS4, | |||||
| "src expected QuantizedS4, but got %s", | |||||
| megdnn_assert((src.dtype.enumv() == DTypeEnum::QuantizedS4 || | |||||
| src.dtype.enumv() == DTypeEnum::Quantized4Asymm), | |||||
| "src expected QuantizedS4/Quantized4Asymm, but got %s", | |||||
| src.dtype.name()); | src.dtype.name()); | ||||
| megdnn_assert(mat.dtype == dtype::Float32(), | megdnn_assert(mat.dtype == dtype::Float32(), | ||||
| "matrix dtype expected float, got %s", | "matrix dtype expected float, got %s", | ||||
| @@ -44,8 +44,9 @@ void get_inner_layout(const TensorLayout& src, const TensorLayout& dst, | |||||
| TensorLayout& inner_src, TensorLayout& inner_dst, | TensorLayout& inner_src, TensorLayout& inner_dst, | ||||
| Handle* handle, | Handle* handle, | ||||
| WarpPerspectiveForwardImpl::Param::Format format) { | WarpPerspectiveForwardImpl::Param::Format format) { | ||||
| if (src.dtype.enumv() == DTypeEnum::QuantizedS4 && | |||||
| dst.dtype.enumv() == DTypeEnum::QuantizedS4 && | |||||
| if ((src.dtype.enumv() == DTypeEnum::QuantizedS4 || | |||||
| src.dtype.enumv() == DTypeEnum::Quantized4Asymm) && | |||||
| dst.dtype.enumv() == src.dtype.enumv() && | |||||
| format == param::WarpPerspective::Format::NCHW) { | format == param::WarpPerspective::Format::NCHW) { | ||||
| auto relayout_opr = handle->create_operator<RelayoutFormat>(); | auto relayout_opr = handle->create_operator<RelayoutFormat>(); | ||||
| deduce_reformat_layout(relayout_opr, src, inner_src, | deduce_reformat_layout(relayout_opr, src, inner_src, | ||||
| @@ -130,7 +131,8 @@ WorkspaceBundle WarpPerspectiveForwardImpl::get_workspace_bundle( | |||||
| TensorLayout fsrc = src; | TensorLayout fsrc = src; | ||||
| TensorLayout fmat = mat; | TensorLayout fmat = mat; | ||||
| TensorLayout fdst = dst; | TensorLayout fdst = dst; | ||||
| if (src.dtype.enumv() == DTypeEnum::QuantizedS4 && | |||||
| if ((src.dtype.enumv() == DTypeEnum::QuantizedS4 || | |||||
| src.dtype.enumv() == DTypeEnum::Quantized4Asymm) && | |||||
| param().format == param::WarpPerspective::Format::NCHW) { | param().format == param::WarpPerspective::Format::NCHW) { | ||||
| get_inner_layout(src, dst, fsrc, fdst, handle(), param().format); | get_inner_layout(src, dst, fsrc, fdst, handle(), param().format); | ||||
| sizes.push_back(fsrc.span().dist_byte()); | sizes.push_back(fsrc.span().dist_byte()); | ||||
| @@ -177,7 +179,8 @@ void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in ssrc, | |||||
| ctypecvt.src_to_comp_type(ssrc, src) | ctypecvt.src_to_comp_type(ssrc, src) | ||||
| .src_to_comp_type(smat, mat) | .src_to_comp_type(smat, mat) | ||||
| .src_to_comp_type(sdst, dst); | .src_to_comp_type(sdst, dst); | ||||
| } else if (ssrc.layout.dtype.enumv() == DTypeEnum::QuantizedS4 && | |||||
| } else if ((ssrc.layout.dtype.enumv() == DTypeEnum::QuantizedS4 || | |||||
| ssrc.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) && | |||||
| param().format == Param::Format::NCHW) { | param().format == Param::Format::NCHW) { | ||||
| auto handle_ptr = handle(); | auto handle_ptr = handle(); | ||||
| get_inner_layout(ssrc.layout, sdst.layout, src.layout, dst.layout, | get_inner_layout(ssrc.layout, sdst.layout, src.layout, dst.layout, | ||||
| @@ -330,7 +333,7 @@ void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in ssrc, | |||||
| param().format == Param::Format::NCHW64 || | param().format == Param::Format::NCHW64 || | ||||
| param().format == Param::Format::NCHW, | param().format == Param::Format::NCHW, | ||||
| "WarpPerspective on CUDA supports NCHW64 or NCHW+ " | "WarpPerspective on CUDA supports NCHW64 or NCHW+ " | ||||
| "QuantizedS4 only"); | |||||
| "QuantizedS4"); | |||||
| bval = roundf(bval); | bval = roundf(bval); | ||||
| bval = fmin(fmax(-8.f, bval), 7.f); | bval = fmin(fmax(-8.f, bval), 7.f); | ||||
| warp_perspective::forward_proxy_nchw64<dt_qint4>( | warp_perspective::forward_proxy_nchw64<dt_qint4>( | ||||
| @@ -352,6 +355,34 @@ void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in ssrc, | |||||
| relayout_opr->param() = trans_param; | relayout_opr->param() = trans_param; | ||||
| relayout_opr->exec(dst, sdst, {}); | relayout_opr->exec(dst, sdst, {}); | ||||
| } | } | ||||
| } else if (src.layout.dtype.enumv() == | |||||
| DTypeEnum::Quantized4Asymm) { | |||||
| megdnn_assert( | |||||
| param().format == Param::Format::NCHW64 || | |||||
| param().format == Param::Format::NCHW, | |||||
| "WarpPerspective on CUDA supports NCHW64 or NCHW+ " | |||||
| "Quantized4Asymm"); | |||||
| bval = roundf(bval); | |||||
| bval = fmin(fmax(0, bval), 15); | |||||
| warp_perspective::forward_proxy_nchw64<dt_quint4>( | |||||
| src.compatible_ptr<dt_quint4>(), | |||||
| mat.ptr<dt_float32>(), | |||||
| mat_idx.raw_ptr ? mat_idx.ptr<int>() : nullptr, | |||||
| dst.compatible_ptr<dt_quint4>(), src.layout[0], | |||||
| mat.layout[0], C, IH, IW, OH, OW, | |||||
| static_cast<dt_quint4>(bval), bmode, | |||||
| async_error_info(handle()), m_error_tracker, | |||||
| stream); | |||||
| if (param().format == Param::Format::NCHW) { | |||||
| auto relayout_opr = | |||||
| handle()->create_operator<RelayoutFormat>(); | |||||
| RelayoutFormat::Param trans_param; | |||||
| trans_param.mode = | |||||
| RelayoutFormat::Param::Mode::NCHW64_NCHW; | |||||
| trans_param.oc = sdst.layout[1]; | |||||
| relayout_opr->param() = trans_param; | |||||
| relayout_opr->exec(dst, sdst, {}); | |||||
| } | |||||
| } | } | ||||
| } else if ((src.layout.dtype.enumv() == | } else if ((src.layout.dtype.enumv() == | ||||
| DTypeEnum::Quantized8Asymm || | DTypeEnum::Quantized8Asymm || | ||||
| @@ -144,25 +144,68 @@ __global__ void kern_general_nchw4(SrcVisitor src, const float* __restrict mat, | |||||
| } | } | ||||
| } | } | ||||
| #define warp_perspective_transform(idx) \ | |||||
| template <bool signedness> | |||||
| MEGDNN_DEVICE __forceinline__ int transform_int8_to_bit4x8(int s0, int s1, | |||||
| int s2, int s3, | |||||
| int s4, int s5, | |||||
| int s6, int s7); | |||||
| template <> | |||||
| MEGDNN_DEVICE __forceinline__ int transform_int8_to_bit4x8<true>( | |||||
| int s0, int s1, int s2, int s3, int s4, int s5, int s6, int s7) { | |||||
| return transform_int8_to_int4x8(s0, s1, s2, s3, s4, s5, s6, s7); | |||||
| } | |||||
| template <> | |||||
| MEGDNN_DEVICE __forceinline__ int transform_int8_to_bit4x8<false>( | |||||
| int s0, int s1, int s2, int s3, int s4, int s5, int s6, int s7) { | |||||
| return transform_int8_to_uint4x8(s0, s1, s2, s3, s4, s5, s6, s7); | |||||
| } | |||||
| template <bool signedness> | |||||
| MEGDNN_DEVICE __forceinline__ void | |||||
| transform_bit4x8_to_int8(int (&result)[8], const int& source); | |||||
| template <> | |||||
| MEGDNN_DEVICE __forceinline__ void | |||||
| transform_bit4x8_to_int8<true>(int (&result)[8], const int& source){ | |||||
| transform_int4x8_to_int8(result, source); | |||||
| } | |||||
| template <> | |||||
| MEGDNN_DEVICE __forceinline__ void | |||||
| transform_bit4x8_to_int8<false>(int (&result)[8], const int& source){ | |||||
| transform_uint4x8_to_int8(result, source); | |||||
| } | |||||
| template <bool signedness, typename OutputConverter> | |||||
| MEGDNN_DEVICE __forceinline__ int pack_output_func( | |||||
| OutputConverter& output_converter, int (&s00)[8], int (&s01)[8], | |||||
| int (&s10)[8], int (&s11)[8], float palpha, float pbeta, float nalpha, | |||||
| float nbeta) { | |||||
| #define warp_perspective_transform(idx) \ | |||||
| static_cast<int>(output_converter(s00[idx] * nalpha * nbeta + \ | static_cast<int>(output_converter(s00[idx] * nalpha * nbeta + \ | ||||
| s01[idx] * nalpha * pbeta + \ | s01[idx] * nalpha * pbeta + \ | ||||
| s10[idx] * palpha * nbeta + \ | s10[idx] * palpha * nbeta + \ | ||||
| s11[idx] * palpha * pbeta) \ | s11[idx] * palpha * pbeta) \ | ||||
| .as_int8()) | |||||
| #define pack_output \ | |||||
| transform_int8_to_int4x8( \ | |||||
| warp_perspective_transform(0), warp_perspective_transform(1), \ | |||||
| warp_perspective_transform(2), warp_perspective_transform(3), \ | |||||
| warp_perspective_transform(4), warp_perspective_transform(5), \ | |||||
| warp_perspective_transform(6), warp_perspective_transform(7)) | |||||
| .as_storage()) | |||||
| return transform_int8_to_bit4x8<signedness>( | |||||
| warp_perspective_transform(0), warp_perspective_transform(1), | |||||
| warp_perspective_transform(2), warp_perspective_transform(3), | |||||
| warp_perspective_transform(4), warp_perspective_transform(5), | |||||
| warp_perspective_transform(6), warp_perspective_transform(7)); | |||||
| #undef warp_perspective_transform | |||||
| } | |||||
| template <typename ctype, typename Getter, typename SrcVisitor, | template <typename ctype, typename Getter, typename SrcVisitor, | ||||
| typename OutputConverter> | typename OutputConverter> | ||||
| __global__ void kern_general_nchw64(SrcVisitor src, const float* __restrict mat, | __global__ void kern_general_nchw64(SrcVisitor src, const float* __restrict mat, | ||||
| ctype* __restrict dst, int C, int IH, | ctype* __restrict dst, int C, int IH, | ||||
| int IW, int OH, int OW) { | int IW, int OH, int OW) { | ||||
| constexpr bool signedness = std::is_same<ctype, dt_qint4>::value; | |||||
| Getter getter; | Getter getter; | ||||
| OutputConverter output_converter; | OutputConverter output_converter; | ||||
| int ow = blockIdx.x * blockDim.x + threadIdx.x; | int ow = blockIdx.x * blockDim.x + threadIdx.x; | ||||
| @@ -199,29 +242,37 @@ __global__ void kern_general_nchw64(SrcVisitor src, const float* __restrict mat, | |||||
| s[2] = __ldg(sptr_int4 + i_coor_10 + c1); | s[2] = __ldg(sptr_int4 + i_coor_10 + c1); | ||||
| s[3] = __ldg(sptr_int4 + i_coor_11 + c1); | s[3] = __ldg(sptr_int4 + i_coor_11 + c1); | ||||
| transform_int4x8_to_int8(s00, s[0].x); | |||||
| transform_int4x8_to_int8(s01, s[1].x); | |||||
| transform_int4x8_to_int8(s10, s[2].x); | |||||
| transform_int4x8_to_int8(s11, s[3].x); | |||||
| d.x = pack_output; | |||||
| transform_int4x8_to_int8(s00, s[0].y); | |||||
| transform_int4x8_to_int8(s01, s[1].y); | |||||
| transform_int4x8_to_int8(s10, s[2].y); | |||||
| transform_int4x8_to_int8(s11, s[3].y); | |||||
| d.y = pack_output; | |||||
| transform_int4x8_to_int8(s00, s[0].z); | |||||
| transform_int4x8_to_int8(s01, s[1].z); | |||||
| transform_int4x8_to_int8(s10, s[2].z); | |||||
| transform_int4x8_to_int8(s11, s[3].z); | |||||
| d.z = pack_output; | |||||
| transform_int4x8_to_int8(s00, s[0].w); | |||||
| transform_int4x8_to_int8(s01, s[1].w); | |||||
| transform_int4x8_to_int8(s10, s[2].w); | |||||
| transform_int4x8_to_int8(s11, s[3].w); | |||||
| d.w = pack_output; | |||||
| transform_bit4x8_to_int8<signedness>(s00, s[0].x); | |||||
| transform_bit4x8_to_int8<signedness>(s01, s[1].x); | |||||
| transform_bit4x8_to_int8<signedness>(s10, s[2].x); | |||||
| transform_bit4x8_to_int8<signedness>(s11, s[3].x); | |||||
| d.x = pack_output_func<signedness>(output_converter, s00, s01, s10, | |||||
| s11, palpha, pbeta, nalpha, | |||||
| nbeta); | |||||
| transform_bit4x8_to_int8<signedness>(s00, s[0].y); | |||||
| transform_bit4x8_to_int8<signedness>(s01, s[1].y); | |||||
| transform_bit4x8_to_int8<signedness>(s10, s[2].y); | |||||
| transform_bit4x8_to_int8<signedness>(s11, s[3].y); | |||||
| d.y = pack_output_func<signedness>(output_converter, s00, s01, s10, | |||||
| s11, palpha, pbeta, nalpha, | |||||
| nbeta); | |||||
| transform_bit4x8_to_int8<signedness>(s00, s[0].z); | |||||
| transform_bit4x8_to_int8<signedness>(s01, s[1].z); | |||||
| transform_bit4x8_to_int8<signedness>(s10, s[2].z); | |||||
| transform_bit4x8_to_int8<signedness>(s11, s[3].z); | |||||
| d.z = pack_output_func<signedness>(output_converter, s00, s01, s10, | |||||
| s11, palpha, pbeta, nalpha, | |||||
| nbeta); | |||||
| transform_bit4x8_to_int8<signedness>(s00, s[0].w); | |||||
| transform_bit4x8_to_int8<signedness>(s01, s[1].w); | |||||
| transform_bit4x8_to_int8<signedness>(s10, s[2].w); | |||||
| transform_bit4x8_to_int8<signedness>(s11, s[3].w); | |||||
| d.w = pack_output_func<signedness>(output_converter, s00, s01, s10, | |||||
| s11, palpha, pbeta, nalpha, | |||||
| nbeta); | |||||
| dst_int4[o_coor + c1] = d; | dst_int4[o_coor + c1] = d; | ||||
| sptr_int4 += IH * IW * 2; | sptr_int4 += IH * IW * 2; | ||||
| @@ -320,15 +371,25 @@ __global__ void kern_const_border_nchw4(SrcVisitor src, | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| template <bool signedness> | |||||
| MEGDNN_DEVICE __forceinline__ static void transform_bit4x8_to_int8( | |||||
| int (&result)[8], const int& source) { | |||||
| #pragma unroll | |||||
| for (int i = 0; i < 8; i++) { | |||||
| result[i] = unpack_integer_4bits<signedness>( | |||||
| reinterpret_cast<unsigned const&>(source), (i << 2)); | |||||
| } | |||||
| } | |||||
| template <typename ctype, typename SrcVisitor, typename OutputConverter> | template <typename ctype, typename SrcVisitor, typename OutputConverter> | ||||
| __global__ void kern_const_border_nchw64(SrcVisitor src, | __global__ void kern_const_border_nchw64(SrcVisitor src, | ||||
| const float* __restrict mat, | const float* __restrict mat, | ||||
| ctype* __restrict dst, int C, int IH, | ctype* __restrict dst, int C, int IH, | ||||
| int IW, int OH, int OW, ctype bval) { | int IW, int OH, int OW, ctype bval) { | ||||
| constexpr bool signedness = std::is_same<ctype, dt_qint4>::value; | |||||
| OutputConverter output_converter; | OutputConverter output_converter; | ||||
| int ow = blockIdx.x * blockDim.x + threadIdx.x; | int ow = blockIdx.x * blockDim.x + threadIdx.x; | ||||
| int c1 = ow %2; | |||||
| int c1 = ow % 2; | |||||
| ow = ow / 2; | ow = ow / 2; | ||||
| int oh = blockIdx.y * blockDim.y + threadIdx.y; | int oh = blockIdx.y * blockDim.y + threadIdx.y; | ||||
| const ctype* __restrict sptr = src.get(blockIdx.z, C * IH * IW / 2); | const ctype* __restrict sptr = src.get(blockIdx.z, C * IH * IW / 2); | ||||
| @@ -359,9 +420,9 @@ __global__ void kern_const_border_nchw64(SrcVisitor src, | |||||
| int i_coor_11 = (ih1 * IW + iw1) << 1; | int i_coor_11 = (ih1 * IW + iw1) << 1; | ||||
| bool flag00 = okh0 && okw0, flag01 = okh0 && okw1, | bool flag00 = okh0 && okw0, flag01 = okh0 && okw1, | ||||
| flag10 = okh1 && okw0, flag11 = okh1 && okw1; | flag10 = okh1 && okw0, flag11 = okh1 && okw1; | ||||
| int8_t bval_4 = bval.as_int8() & 0xF; | |||||
| int bval_8 = transform_int8_to_int4x8(bval_4, bval_4, bval_4, bval_4, | |||||
| bval_4, bval_4, bval_4, bval_4); | |||||
| int8_t bval_4 = bval.as_storage() & 0xF; | |||||
| int bval_8 = transform_int8_to_bit4x8<signedness>( | |||||
| bval_4, bval_4, bval_4, bval_4, bval_4, bval_4, bval_4, bval_4); | |||||
| int4 bval_int4; | int4 bval_int4; | ||||
| bval_int4.x = bval_8; | bval_int4.x = bval_8; | ||||
| bval_int4.y = bval_8; | bval_int4.y = bval_8; | ||||
| @@ -391,29 +452,37 @@ __global__ void kern_const_border_nchw64(SrcVisitor src, | |||||
| s[3] = bval_int4; | s[3] = bval_int4; | ||||
| } | } | ||||
| transform_int4x8_to_int8(s00, s[0].x); | |||||
| transform_int4x8_to_int8(s01, s[1].x); | |||||
| transform_int4x8_to_int8(s10, s[2].x); | |||||
| transform_int4x8_to_int8(s11, s[3].x); | |||||
| d.x = pack_output; | |||||
| transform_int4x8_to_int8(s00, s[0].y); | |||||
| transform_int4x8_to_int8(s01, s[1].y); | |||||
| transform_int4x8_to_int8(s10, s[2].y); | |||||
| transform_int4x8_to_int8(s11, s[3].y); | |||||
| d.y = pack_output; | |||||
| transform_int4x8_to_int8(s00, s[0].z); | |||||
| transform_int4x8_to_int8(s01, s[1].z); | |||||
| transform_int4x8_to_int8(s10, s[2].z); | |||||
| transform_int4x8_to_int8(s11, s[3].z); | |||||
| d.z = pack_output; | |||||
| transform_int4x8_to_int8(s00, s[0].w); | |||||
| transform_int4x8_to_int8(s01, s[1].w); | |||||
| transform_int4x8_to_int8(s10, s[2].w); | |||||
| transform_int4x8_to_int8(s11, s[3].w); | |||||
| d.w = pack_output; | |||||
| transform_bit4x8_to_int8<signedness>(s00, s[0].x); | |||||
| transform_bit4x8_to_int8<signedness>(s01, s[1].x); | |||||
| transform_bit4x8_to_int8<signedness>(s10, s[2].x); | |||||
| transform_bit4x8_to_int8<signedness>(s11, s[3].x); | |||||
| d.x = pack_output_func<signedness>(output_converter, s00, s01, s10, | |||||
| s11, palpha, pbeta, nalpha, | |||||
| nbeta); | |||||
| transform_bit4x8_to_int8<signedness>(s00, s[0].y); | |||||
| transform_bit4x8_to_int8<signedness>(s01, s[1].y); | |||||
| transform_bit4x8_to_int8<signedness>(s10, s[2].y); | |||||
| transform_bit4x8_to_int8<signedness>(s11, s[3].y); | |||||
| d.y = pack_output_func<signedness>(output_converter, s00, s01, s10, | |||||
| s11, palpha, pbeta, nalpha, | |||||
| nbeta); | |||||
| transform_bit4x8_to_int8<signedness>(s00, s[0].z); | |||||
| transform_bit4x8_to_int8<signedness>(s01, s[1].z); | |||||
| transform_bit4x8_to_int8<signedness>(s10, s[2].z); | |||||
| transform_bit4x8_to_int8<signedness>(s11, s[3].z); | |||||
| d.z = pack_output_func<signedness>(output_converter, s00, s01, s10, | |||||
| s11, palpha, pbeta, nalpha, | |||||
| nbeta); | |||||
| transform_bit4x8_to_int8<signedness>(s00, s[0].w); | |||||
| transform_bit4x8_to_int8<signedness>(s01, s[1].w); | |||||
| transform_bit4x8_to_int8<signedness>(s10, s[2].w); | |||||
| transform_bit4x8_to_int8<signedness>(s11, s[3].w); | |||||
| d.w = pack_output_func<signedness>(output_converter, s00, s01, s10, | |||||
| s11, palpha, pbeta, nalpha, | |||||
| nbeta); | |||||
| dst_int4[o_coor + c1] = d; | dst_int4[o_coor + c1] = d; | ||||
| sptr_int4 += IH * IW * 2; | sptr_int4 += IH * IW * 2; | ||||
| @@ -1448,6 +1517,7 @@ INST(int8_t) | |||||
| void*, cudaStream_t); | void*, cudaStream_t); | ||||
| INST(dt_qint4) | INST(dt_qint4) | ||||
| INST(dt_quint4) | |||||
| #undef INST | #undef INST | ||||
| template <typename src_dtype, typename src_ctype, typename dst_ctype> | template <typename src_dtype, typename src_ctype, typename dst_ctype> | ||||
| @@ -249,6 +249,7 @@ void WarpPerspectiveForwardImpl::kern_naive_nhwcd4( | |||||
| MIDOUT_END(); | MIDOUT_END(); | ||||
| } | } | ||||
| template <typename ctype, typename mtype> | template <typename ctype, typename mtype> | ||||
| void WarpPerspectiveForwardImpl::kern_naive_int4( | void WarpPerspectiveForwardImpl::kern_naive_int4( | ||||
| const KernParam<ctype, mtype>& kern_param, size_t task_id) { | const KernParam<ctype, mtype>& kern_param, size_t task_id) { | ||||
| @@ -257,6 +258,7 @@ void WarpPerspectiveForwardImpl::kern_naive_int4( | |||||
| UNPACK_WARP_PERSPECTIVE_FWD_KERN_PARAM(kern_param); | UNPACK_WARP_PERSPECTIVE_FWD_KERN_PARAM(kern_param); | ||||
| MEGDNN_MARK_USED_VAR(N_MAT); | MEGDNN_MARK_USED_VAR(N_MAT); | ||||
| uint8_t c_shift, c_mask, iw_shift = 0, ow_shift = 0; | uint8_t c_shift, c_mask, iw_shift = 0, ow_shift = 0; | ||||
| constexpr bool signedness = std::is_same<ctype, dt_qint4>::value; | |||||
| switch (param().format) { | switch (param().format) { | ||||
| case Format::NCHW: | case Format::NCHW: | ||||
| c_shift = 0; | c_shift = 0; | ||||
| @@ -282,8 +284,13 @@ void WarpPerspectiveForwardImpl::kern_naive_int4( | |||||
| << c_shift) + | << c_shift) + | ||||
| (c & c_mask); | (c & c_mask); | ||||
| uint8_t result = | uint8_t result = | ||||
| (sptr[index / 2].as_int8() >> (4 * (index % 2))) & 0xF; | |||||
| return result & uint8_t(1 << 3) ? result | ~mask : result; | |||||
| (sptr[index / 2].as_storage() >> (4 * (index % 2))) & 0xF; | |||||
| if (signedness) { | |||||
| return result & uint8_t(1 << 3) ? result | ~mask : result; | |||||
| } else { | |||||
| megdnn_assert((std::is_same<ctype, dt_quint4>::value)); | |||||
| return result; | |||||
| } | |||||
| }; | }; | ||||
| auto visit_src_bd = [&sptr, sstrd, border_val, c_shift, c_mask]( | auto visit_src_bd = [&sptr, sstrd, border_val, c_shift, c_mask]( | ||||
| size_t c, int h, int w) -> float { | size_t c, int h, int w) -> float { | ||||
| @@ -292,8 +299,14 @@ void WarpPerspectiveForwardImpl::kern_naive_int4( | |||||
| << c_shift) + | << c_shift) + | ||||
| (c & c_mask); | (c & c_mask); | ||||
| uint8_t result = | uint8_t result = | ||||
| (sptr[index / 2].as_int8() >> (4 * (index % 2))) & 0xF; | |||||
| return result & uint8_t(1 << 3) ? result | ~mask : result; | |||||
| (sptr[index / 2].as_storage() >> (4 * (index % 2))) & | |||||
| 0xF; | |||||
| if (signedness) { | |||||
| return result & uint8_t(1 << 3) ? result | ~mask : result; | |||||
| } else { | |||||
| megdnn_assert((std::is_same<ctype, dt_quint4>::value)); | |||||
| return result;; | |||||
| } | |||||
| } else | } else | ||||
| return border_val; | return border_val; | ||||
| }; | }; | ||||
| @@ -302,9 +315,9 @@ void WarpPerspectiveForwardImpl::kern_naive_int4( | |||||
| size_t index = ((dstrd[0] * (c >> c_shift) + dstrd[1] * h + w) | size_t index = ((dstrd[0] * (c >> c_shift) + dstrd[1] * h + w) | ||||
| << c_shift) + | << c_shift) + | ||||
| (c & c_mask); | (c & c_mask); | ||||
| dptr[index / 2] = | |||||
| (dptr[index / 2].as_int8() & (0xF0 >> (4 * (index % 2)))) | | |||||
| (v.as_int8() << (4 * (index % 2))); | |||||
| dptr[index / 2] = (dptr[index / 2].as_storage() & | |||||
| (0xF0 >> (4 * (index % 2)))) | | |||||
| (v.as_storage() << (4 * (index % 2))); | |||||
| }; | }; | ||||
| rounding::RoundingConverter<ctype> output_converter; | rounding::RoundingConverter<ctype> output_converter; | ||||
| @@ -334,21 +347,20 @@ void WarpPerspectiveForwardImpl::kern_naive_int4( | |||||
| int iw1 = get_real_coord(std::floor(alphaw) + 1, IW); | int iw1 = get_real_coord(std::floor(alphaw) + 1, IW); | ||||
| int ih0 = get_real_coord(std::floor(alphah) + 0, IH); | int ih0 = get_real_coord(std::floor(alphah) + 0, IH); | ||||
| int ih1 = get_real_coord(std::floor(alphah) + 1, IH); | int ih1 = get_real_coord(std::floor(alphah) + 1, IH); | ||||
| alphaw -= floor(alphaw); | alphaw -= floor(alphaw); | ||||
| alphah -= floor(alphah); | alphah -= floor(alphah); | ||||
| if (bmode != BorderMode::CONSTANT) { | if (bmode != BorderMode::CONSTANT) { | ||||
| rep(c, C) { | rep(c, C) { | ||||
| set_visit_dst( | |||||
| c, oh, ow, | |||||
| output_converter( | |||||
| visit_src(c, ih0, iw0) * (1.0f - alphaw) * | |||||
| auto val = visit_src(c, ih0, iw0) * (1.0f - alphaw) * | |||||
| (1.0f - alphah) + | (1.0f - alphah) + | ||||
| visit_src(c, ih0, iw1) * alphaw * | visit_src(c, ih0, iw1) * alphaw * | ||||
| (1.0f - alphah) + | (1.0f - alphah) + | ||||
| visit_src(c, ih1, iw0) * (1.0f - alphaw) * | visit_src(c, ih1, iw0) * (1.0f - alphaw) * | ||||
| alphah + | alphah + | ||||
| visit_src(c, ih1, iw1) * alphaw * alphah)); | |||||
| visit_src(c, ih1, iw1) * alphaw * alphah; | |||||
| set_visit_dst( | |||||
| c, oh, ow, | |||||
| output_converter(val)); | |||||
| } | } | ||||
| } else { | } else { | ||||
| rep(c, C) { | rep(c, C) { | ||||
| @@ -613,6 +625,13 @@ void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in src, | |||||
| "WarpPerspective: %s", | "WarpPerspective: %s", | ||||
| src.layout.dtype.name()) | src.layout.dtype.name()) | ||||
| .c_str()); | .c_str()); | ||||
| } else if (src.layout.dtype.enumv() == | |||||
| DTypeTrait<dtype::Quantized4Asymm>::enumv) { | |||||
| DISPATCH_ST(dtype::Quantized4Asymm, dt_quint4, float, KERN_INT4); | |||||
| megdnn_throw(ssprintf("Unsupported input DType in " | |||||
| "WarpPerspective: %s", | |||||
| src.layout.dtype.name()) | |||||
| .c_str()); | |||||
| } | } | ||||
| bool is_fusion_dtype = src.layout.dtype.enumv() != dst.layout.dtype.enumv(); | bool is_fusion_dtype = src.layout.dtype.enumv() != dst.layout.dtype.enumv(); | ||||
| @@ -107,7 +107,8 @@ protected: | |||||
| ret.mptr = mat.ptr<mtype>(); | ret.mptr = mat.ptr<mtype>(); | ||||
| ret.dptr = dst.compatible_ptr<ctype>(); | ret.dptr = dst.compatible_ptr<ctype>(); | ||||
| } else if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS8 || | } else if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS8 || | ||||
| src.layout.dtype.enumv() == DTypeEnum::QuantizedS4) { | |||||
| src.layout.dtype.enumv() == DTypeEnum::QuantizedS4 || | |||||
| src.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) { | |||||
| ret.sptr = src.compatible_ptr<ctype>(); | ret.sptr = src.compatible_ptr<ctype>(); | ||||
| ret.mptr = mat.ptr<mtype>(); | ret.mptr = mat.ptr<mtype>(); | ||||
| ret.dptr = dst.compatible_ptr<ctype>(); | ret.dptr = dst.compatible_ptr<ctype>(); | ||||
| @@ -647,6 +647,31 @@ TEST_F(CUDA, WARP_PERSPECTIVE_FORWARD_QINT4) { | |||||
| } | } | ||||
| } | } | ||||
| TEST_F(CUDA, WARP_PERSPECTIVE_FORWARD_QUINT4) { | |||||
| using Param = WarpPerspective::Param; | |||||
| Checker<WarpPerspectiveForward> checker(handle_cuda()); | |||||
| WarpPerspectiveMatRNG rng; | |||||
| checker.set_rng(1, &rng); | |||||
| checker.set_dtype(0, dtype::Quantized4Asymm(1.25f, 0)) | |||||
| .set_dtype(1, dtype::Float32()) | |||||
| .set_dtype(2, dtype::Quantized4Asymm(1.25f, 0)); | |||||
| for (auto bmode : {WarpPerspective::BorderMode::WRAP, | |||||
| WarpPerspective::BorderMode::REFLECT, | |||||
| WarpPerspective::BorderMode::REPLICATE, | |||||
| WarpPerspective::BorderMode::CONSTANT}) { | |||||
| WarpPerspective::Param param; | |||||
| param.border_val = 0.3f; | |||||
| param.bmode = bmode; | |||||
| param.imode = Param::InterpolationMode::LINEAR; | |||||
| param.format = Param::Format::NCHW; | |||||
| checker.set_param(param); | |||||
| checker.set_epsilon(1 + 1e-3); | |||||
| checker.execs({{1, 64, 11, 11}, {1, 3, 3}, {1, 64, 11, 11}}); | |||||
| checker.execs({{20, 640, 11, 12}, {20, 3, 3}, {20, 640, 11, 12}}); | |||||
| } | |||||
| } | |||||
| TEST_F(CUDA, WARP_PERSPECTIVE_BACKWARD_DATA_BFLOAT16) { | TEST_F(CUDA, WARP_PERSPECTIVE_BACKWARD_DATA_BFLOAT16) { | ||||
| Checker<WarpPerspectiveBackwardData> checker(handle_cuda()); | Checker<WarpPerspectiveBackwardData> checker(handle_cuda()); | ||||
| WarpPerspectiveMatRNG rng; | WarpPerspectiveMatRNG rng; | ||||
| @@ -701,7 +726,7 @@ TEST_F(CUDA, WARP_PERSPECTIVE_MAT_IDX) { | |||||
| warp_perspective::run_mat_idx_test(handle_cuda()); | warp_perspective::run_mat_idx_test(handle_cuda()); | ||||
| } | } | ||||
| TEST_F(CUDA, WARP_PERSPECTIVE_NCHW64) { | |||||
| TEST_F(CUDA, WARP_PERSPECTIVE_NCHW64_QINT4) { | |||||
| using Param = WarpPerspective::Param; | using Param = WarpPerspective::Param; | ||||
| WarpPerspective::Param param; | WarpPerspective::Param param; | ||||
| Checker<WarpPerspectiveForward> checker(handle_cuda()); | Checker<WarpPerspectiveForward> checker(handle_cuda()); | ||||
| @@ -767,6 +792,72 @@ TEST_F(CUDA, WARP_PERSPECTIVE_NCHW64) { | |||||
| } | } | ||||
| } | } | ||||
| TEST_F(CUDA, WARP_PERSPECTIVE_NCHW64_QUINT4) { | |||||
| using Param = WarpPerspective::Param; | |||||
| WarpPerspective::Param param; | |||||
| Checker<WarpPerspectiveForward> checker(handle_cuda()); | |||||
| WarpPerspectiveMatRNG_V2 rng; | |||||
| checker.set_dtype(0, dtype::Quantized4Asymm(0.1f, 3)); | |||||
| checker.set_dtype(2, dtype::Quantized4Asymm(0.1f, 3)); | |||||
| for (auto bmode : {WarpPerspective::BorderMode::WRAP, | |||||
| WarpPerspective::BorderMode::REFLECT, | |||||
| WarpPerspective::BorderMode::REPLICATE, | |||||
| WarpPerspective::BorderMode::CONSTANT}) { | |||||
| param.border_val = 0.3f; | |||||
| param.bmode = bmode; | |||||
| param.imode = Param::InterpolationMode::LINEAR; | |||||
| param.format = Param::Format::NCHW64; | |||||
| checker.set_param(param); | |||||
| checker.set_epsilon(1 + 1e-3); | |||||
| rng.set_hw(10, 11); | |||||
| checker.set_rng(1, &rng); | |||||
| checker.execs({{2, 1, 10, 11, 64}, {2, 3, 3}, {2, 1, 11, 12, 64}}); | |||||
| checker.execs( | |||||
| {{20, 300, 10, 11, 64}, {20, 3, 3}, {20, 300, 11, 12, 64}}); | |||||
| checker.execs( | |||||
| {{2200, 3, 10, 11, 64}, {2200, 3, 3}, {2200, 3, 11, 12, 64}}); | |||||
| rng.set_hw(25, 25); | |||||
| checker.set_rng(1, &rng); | |||||
| checker.execs({{1, 25, 25, 25, 64}, {1, 3, 3}, {1, 25, 25, 51, 64}}); | |||||
| rng.set_hw(25, 510); | |||||
| checker.set_rng(1, &rng); | |||||
| checker.execs({{1, 1, 25, 510, 64}, {1, 3, 3}, {1, 1, 25, 25, 64}}); | |||||
| rng.set_hw(25, 25); | |||||
| checker.set_rng(1, &rng); | |||||
| checker.execs({{1, 1, 25, 25, 64}, {1, 3, 3}, {1, 1, 51, 51, 64}}); | |||||
| rng.set_hw(51, 51); | |||||
| checker.set_rng(1, &rng); | |||||
| checker.execs({{1, 1, 51, 51, 64}, {1, 3, 3}, {1, 1, 25, 25, 64}}); | |||||
| } | |||||
| { | |||||
| Checker<WarpPerspective, WarpPerspectiveMatIdxProxy> checker( | |||||
| handle_cuda()); | |||||
| constexpr int N_SRC = 5; | |||||
| UniformIntRNG mat_idx_rng{0, N_SRC - 1}; | |||||
| checker.set_dtype(0, dtype::Quantized4Asymm(0.1f, 3)); | |||||
| checker.set_rng(1, &rng); | |||||
| checker.set_dtype(2, dtype::Int32()); | |||||
| checker.set_rng(2, &mat_idx_rng); | |||||
| checker.set_dtype(3, dtype::Quantized4Asymm(0.1f, 3)); | |||||
| param.bmode = WarpPerspective::Param::BorderMode::REFLECT; | |||||
| param.imode = param::WarpPerspective::InterpolationMode::LINEAR; | |||||
| checker.set_param(param); | |||||
| checker.set_epsilon(1 + 1e-3); | |||||
| rng.set_hw(10, 11); | |||||
| checker.set_rng(1, &rng); | |||||
| checker.execs( | |||||
| {{N_SRC, 3, 10, 11, 64}, {2, 3, 3}, {2}, {2, 3, 11, 12, 64}}); | |||||
| rng.set_hw(17, 13); | |||||
| checker.set_rng(1, &rng); | |||||
| checker.execs({{N_SRC, 14, 17, 13, 64}, | |||||
| {123, 3, 3}, | |||||
| {123}, | |||||
| {123, 14, 16, 15, 64}}); | |||||
| } | |||||
| } | |||||
| #if MEGDNN_WITH_BENCHMARK | #if MEGDNN_WITH_BENCHMARK | ||||
| TEST_F(CUDA, BENCHMARK_WARP_PERSPECTIVE_NCHW4) { | TEST_F(CUDA, BENCHMARK_WARP_PERSPECTIVE_NCHW4) { | ||||
| @@ -196,8 +196,8 @@ TEST_F(NAIVE, WARP_PERSPECTIVE_NCHW_QINT4) { | |||||
| param.imode = WarpPerspective::Param::InterpolationMode::LINEAR; | param.imode = WarpPerspective::Param::InterpolationMode::LINEAR; | ||||
| param.format = WarpPerspective::Param::Format::NCHW; | param.format = WarpPerspective::Param::Format::NCHW; | ||||
| std::vector<int> input_values = {1, 3, 2, 2, 0, 0, 0, 0, 2}, | |||||
| output_values = {1, 2, 2, 2}; | |||||
| std::vector<int> input_values = {-1, -3, -2, -2, 0, 0, 0, 0, -2}, | |||||
| output_values = {-1, -2, -2, -2}; | |||||
| checker.set_param(param).exect( | checker.set_param(param).exect( | ||||
| Testcase{TensorValueLowbit4({1, 1, 3, 3}, dtype::QuantizedS4(0.1), | Testcase{TensorValueLowbit4({1, 1, 3, 3}, dtype::QuantizedS4(0.1), | ||||
| @@ -212,6 +212,31 @@ TEST_F(NAIVE, WARP_PERSPECTIVE_NCHW_QINT4) { | |||||
| output_values)}); | output_values)}); | ||||
| } | } | ||||
| TEST_F(NAIVE, WARP_PERSPECTIVE_NCHW_QUINT4) { | |||||
| Checker<WarpPerspective> checker(handle(), false); | |||||
| WarpPerspective::Param param; | |||||
| param.bmode = WarpPerspective::Param::BorderMode::BORDER_REFLECT; | |||||
| param.imode = WarpPerspective::Param::InterpolationMode::LINEAR; | |||||
| param.format = WarpPerspective::Param::Format::NCHW; | |||||
| std::vector<int> input_values = {4, 13, 0, 0, 0, 0, 0, 0, 0}, | |||||
| output_values = {6, 8, 8, 9}; | |||||
| checker.set_param(param).exect( | |||||
| Testcase{TensorValueLowbit4({1, 1, 3, 3}, | |||||
| dtype::Quantized4Asymm(0.1, 3), | |||||
| input_values), | |||||
| TensorValue({1, 3, 3}, dtype::Float32{}, | |||||
| {1.2f, 1.2f, 0.6f, -1.05f, -2.0f, -0.7f, 1.3f, | |||||
| 1.5f, 3.0f}), | |||||
| {}}, | |||||
| Testcase{{}, | |||||
| {}, | |||||
| TensorValueLowbit4({1, 1, 2, 2}, | |||||
| dtype::Quantized4Asymm(0.1, 3), | |||||
| output_values)}); | |||||
| } | |||||
| TEST_F(NAIVE_MULTI_THREADS, WARP_PERSPECTIVE_NCHW4) { | TEST_F(NAIVE_MULTI_THREADS, WARP_PERSPECTIVE_NCHW4) { | ||||
| using Param = WarpPerspective::Param; | using Param = WarpPerspective::Param; | ||||