GitOrigin-RevId: ab86e66533
@@ -110,35 +110,33 @@ MEGDNN_DEVICE __forceinline__ static int unpack_integer_4bits(T storage,
     return (result << (shift - bits)) >> shift;
 }
-MEGDNN_DEVICE __forceinline__ static void transform_int4x8_to_int8(
-        int (&result)[8], const int& source) {
-#pragma unroll
-    for (int i = 0; i < 8; i++) {
-        result[i] = unpack_integer_4bits<true>(
-                reinterpret_cast<unsigned const&>(source), (i << 2));
-    }
-}
-MEGDNN_DEVICE __forceinline__ static void transform_uint4x8_to_int8(
+template <bool signedness>
+MEGDNN_DEVICE __forceinline__ static void transform_b4x8_to_int8(
         int (&result)[8], const int& source) {
 #pragma unroll
     for (int i = 0; i < 8; i++) {
-        result[i] = unpack_integer_4bits<false>(
+        result[i] = unpack_integer_4bits<signedness>(
                 reinterpret_cast<unsigned const&>(source), (i << 2));
     }
 }
-MEGDNN_DEVICE __forceinline__ static void transform_int4x2_to_int8(
+template <bool signedness>
+MEGDNN_DEVICE __forceinline__ static void transform_b4x2_to_int8(
         int (&result)[2], const uint8_t& source) {
-    result[0] = unpack_integer_4bits<true>(source, 0);
-    result[1] = unpack_integer_4bits<true>(source, 4);
+    result[0] = unpack_integer_4bits<signedness>(source, 0);
+    result[1] = unpack_integer_4bits<signedness>(source, 4);
 }
-MEGDNN_DEVICE __forceinline__ static void transform_uint4x2_to_int8(
-        int (&result)[2], const uint8_t& source) {
-    result[0] = unpack_integer_4bits<false>(source, 0);
-    result[1] = unpack_integer_4bits<false>(source, 4);
+template <bool signedness>
+MEGDNN_DEVICE __forceinline__ static int transform_int8_to_b4x8(
+        int s0, int s1, int s2, int s3, int s4, int s5, int s6, int s7) {
+    if (signedness) {
+        return transform_int8_to_int4x8(s0, s1, s2, s3, s4, s5, s6, s7);
+    } else {
+        return transform_int8_to_uint4x8(s0, s1, s2, s3, s4, s5, s6, s7);
+    }
+}
 } // namespace integer_subbyte
 } // namespace cuda
 } // namespace megdnn
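For reference, a minimal round-trip sketch with the unified templates introduced above (the call sequence and the variable names are an illustration, not code from this commit; the functions are the ones in this hunk):

    int v[8];
    // was transform_int4x8_to_int8; <false> selects the old uint4 path
    integer_subbyte::transform_b4x8_to_int8<true>(v, packed_i4x8);
    int repacked = integer_subbyte::transform_int8_to_b4x8<true>(
            v[0], v[1], v[2], v[3], v[4], v[5], v[6], v[7]);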
@@ -0,0 +1,171 @@
/**
 * \file dnn/src/cuda/relayout_format/cuda_post_process.cuh
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
#pragma once
#include "src/cuda/relayout_format/relayout_format.cuh"
namespace megdnn {
namespace cuda {
namespace relayout_format {
namespace internal {
template <typename SrcType, typename DstType, bool same_scale>
struct CudaPostProcess;
template <>
struct CudaPostProcess<dtype::Uint8, dtype::QuantizedS8, true> {
    CudaPostProcess(float, uint8_t, float, uint8_t){};
    inline __device__ int8_t operator()(uint8_t val) { return val - 128; }
};
template <>
struct CudaPostProcess<dtype::Uint8, dtype::QuantizedS8, false> {
    CudaDTypeParamImpl<dt_qint8> m_dst_type_cvt;
    CudaPostProcess(float, uint8_t, float dst_scale, uint8_t) {
        m_dst_type_cvt = CudaDTypeParamImpl<dt_qint8>(dst_scale);
    };
    inline __device__ int8_t operator()(uint8_t val) {
        return m_dst_type_cvt.quantize((float)val - 128.f).as_int8();
    }
};
template <>
struct CudaPostProcess<dtype::Quantized8Asymm, dtype::QuantizedS8, false> {
    CudaDTypeParamImpl<dt_qint8> m_dst_type_cvt;
    CudaDTypeParamImpl<dt_quint8> m_src_type_cvt;
    CudaPostProcess(float src_scale, uint8_t src_zero_point, float dst_scale,
                    uint8_t) {
        m_dst_type_cvt = CudaDTypeParamImpl<dt_qint8>(dst_scale);
        m_src_type_cvt =
                CudaDTypeParamImpl<dt_quint8>(src_scale, src_zero_point);
    };
    inline __device__ int8_t operator()(uint8_t val) {
        float med_var = m_src_type_cvt.dequantize(dt_quint8(val));
        return m_dst_type_cvt.quantize(med_var).as_int8();
    }
};
template <>
struct CudaPostProcess<dtype::Quantized8Asymm, dtype::QuantizedS8, true> {
    uint8_t m_src_zero_point = 0;
    CudaPostProcess(float, uint8_t src_zero_point, float, uint8_t) {
        m_src_zero_point = src_zero_point;
    };
    inline __device__ int8_t operator()(uint8_t val) {
        return val - m_src_zero_point;
    }
};
template <>
struct CudaPostProcess<dtype::QuantizedS8, dtype::QuantizedS8, false> {
    CudaDTypeParamImpl<dt_qint8> m_dst_type_cvt;
    CudaDTypeParamImpl<dt_qint8> m_src_type_cvt;
    CudaPostProcess(float src_scale, uint8_t, float dst_scale, uint8_t) {
        m_dst_type_cvt = CudaDTypeParamImpl<dt_qint8>(dst_scale);
        m_src_type_cvt = CudaDTypeParamImpl<dt_qint8>(src_scale);
    };
    inline __device__ int8_t operator()(int8_t val) {
        float med_var = m_src_type_cvt.dequantize(dt_qint8(val));
        return m_dst_type_cvt.quantize(med_var).as_int8();
    }
};
template <>
struct CudaPostProcess<dtype::QuantizedS8, dtype::QuantizedS8, true> {
    CudaPostProcess(){};
    CudaPostProcess(float, uint8_t, float, uint8_t){};
    inline __device__ int8_t operator()(int8_t val) { return val; }
};
template <>
struct CudaPostProcess<dtype::QuantizedS32, dtype::QuantizedS32, false> {
    CudaDTypeParamImpl<dt_qint32> m_dst_type_cvt;
    CudaDTypeParamImpl<dt_qint32> m_src_type_cvt;
    CudaPostProcess(float src_scale, int, float dst_scale, int) {
        m_dst_type_cvt = CudaDTypeParamImpl<dt_qint32>(dst_scale);
        m_src_type_cvt = CudaDTypeParamImpl<dt_qint32>(src_scale);
    };
    inline __device__ int operator()(int val) {
        float med_var = m_src_type_cvt.dequantize(dt_qint32(val));
        return m_dst_type_cvt.quantize(med_var).as_int32();
    }
};
template <>
struct CudaPostProcess<dtype::QuantizedS32, dtype::QuantizedS32, true> {
    CudaPostProcess(float, int, float, int){};
    inline __device__ int operator()(int val) { return val; }
};
template <>
struct CudaPostProcess<dtype::QuantizedS4, dtype::QuantizedS4, false> {
    using SrcType = dtype::QuantizedS4;
    using DstType = dtype::QuantizedS4;
    CudaDTypeParamImpl<dt_qint4> m_dst_type_cvt;
    CudaDTypeParamImpl<dt_qint4> m_src_type_cvt;
    CudaPostProcess(float src_scale, uint8_t, float dst_scale, uint8_t) {
        m_dst_type_cvt = CudaDTypeParamImpl<dt_qint4>(dst_scale);
        m_src_type_cvt = CudaDTypeParamImpl<dt_qint4>(src_scale);
    }
    inline __device__ int8_t operator()(int8_t val) {
        float intermediate = m_src_type_cvt.dequantize(dt_qint4(val));
        return m_dst_type_cvt.quantize(intermediate).as_int8();
    }
};
template <>
struct CudaPostProcess<dtype::QuantizedS4, dtype::QuantizedS4, true> {
    using SrcType = dtype::QuantizedS4;
    using DstType = dtype::QuantizedS4;
    CudaPostProcess(float, uint8_t, float, uint8_t){};
    inline __device__ int8_t operator()(int8_t val) { return val; }
};
template <>
struct CudaPostProcess<dtype::Quantized4Asymm, dtype::Quantized4Asymm, false> {
    using SrcType = dtype::Quantized4Asymm;
    using DstType = dtype::Quantized4Asymm;
    CudaDTypeParamImpl<dt_quint4> m_dst_type_cvt;
    CudaDTypeParamImpl<dt_quint4> m_src_type_cvt;
    CudaPostProcess(float src_scale, uint8_t src_zero_point, float dst_scale,
                    uint8_t dst_zero_point) {
        m_dst_type_cvt =
                CudaDTypeParamImpl<dt_quint4>(dst_scale, dst_zero_point);
        m_src_type_cvt =
                CudaDTypeParamImpl<dt_quint4>(src_scale, src_zero_point);
    };
    inline __device__ uint8_t operator()(uint8_t val) {
        float intermediate = m_src_type_cvt.dequantize(dt_quint4(val));
        return m_dst_type_cvt.quantize(intermediate).as_uint8();
    }
};
template <>
struct CudaPostProcess<dtype::Quantized4Asymm, dtype::Quantized4Asymm, true> {
    using SrcType = dtype::Quantized4Asymm;
    using DstType = dtype::Quantized4Asymm;
    uint8_t m_src_zero_point = 0;
    uint8_t m_dst_zero_point = 0;
    CudaPostProcess(float, uint8_t src_zero_point, float,
                    uint8_t dst_zero_point) {
        m_src_zero_point = src_zero_point;
        m_dst_zero_point = dst_zero_point;
    };
    inline __device__ uint8_t operator()(uint8_t val) {
        int result = val - m_src_zero_point + m_dst_zero_point;
        result = result >= 0 ? result : 0;
        result = result < 16 ? result : 15;
        return static_cast<uint8_t>(result);
    }
};
} // namespace internal
} // namespace relayout_format
} // namespace cuda
} // namespace megdnn
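A brief usage sketch of these functors (the scale and zero-point values below are illustrative, not from this commit): every specialization takes its four constructor arguments as (src_scale, src_zero_point, dst_scale, dst_zero_point) and simply ignores the slots it does not need.

    using PostProcess =
            CudaPostProcess<dtype::Quantized8Asymm, dtype::QuantizedS8, false>;
    PostProcess cvt(0.5f, 128, 0.25f, 0);  // dst zero point unused here
    // in device code: dequantize with (0.5, 128), requantize with scale 0.25
    // int8_t q = cvt(raw_uint8);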
@@ -1,252 +0,0 @@
/**
 * \file dnn/src/cuda/relayout_format/helper.cuh
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
namespace megdnn {
namespace cuda {
namespace relayout_format {
#define devfunc __forceinline__ __device__
template <int size_nbits>
devfunc int make_zero(int zero_point);
template <>
devfunc int make_zero<4>(int zero_point) {
    return transform_int8_to_uint4x8(zero_point, zero_point, zero_point,
                                     zero_point, zero_point, zero_point,
                                     zero_point, zero_point);
}
template <typename AccessType, int LoadBytes>
struct global_load_with_zero_point;
/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Specializations
//
/////////////////////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////////////
// The redundant mov PTX instructions are used to force the compiler to
// initialize data to the zero point before ld.global
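// (Effect: every 32-bit lane of D is pre-filled with the broadcast
// zero_point by the mov.b32 instructions, and the predicated ld.global
// overwrites it only when pred_guard is non-zero, so out-of-bounds
// accesses yield the zero point instead of stale register contents.)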
template <typename AccessType>
struct global_load_with_zero_point<AccessType, 32> {
    devfunc global_load_with_zero_point(AccessType& D, void const* ptr,
                                        bool pred_guard, int zero_point) {
        uint4* data = reinterpret_cast<uint4*>(&D);
        asm volatile(
                "{\n"
                "  .reg .pred p;\n"
                "  setp.ne.b32 p, %9, 0;\n"
                "  mov.b32 %0, %10;\n"
                "  mov.b32 %1, %10;\n"
                "  mov.b32 %2, %10;\n"
                "  mov.b32 %3, %10;\n"
                "  mov.b32 %4, %10;\n"
                "  mov.b32 %5, %10;\n"
                "  mov.b32 %6, %10;\n"
                "  mov.b32 %7, %10;\n"
                "  @p ld.global.v4.u32 {%0, %1, %2, %3}, [%8];\n"
                "  @p ld.global.v4.u32 {%4, %5, %6, %7}, [%11];\n"
                "}\n"
                : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z),
                  "=r"(data[0].w), "=r"(data[1].x), "=r"(data[1].y),
                  "=r"(data[1].z), "=r"(data[1].w)
                : "l"(ptr), "r"((int)pred_guard),
                  "r"(reinterpret_cast<unsigned&>(zero_point)),
                  "l"(((uint8_t*)ptr) + 16));
    }
};
template <typename AccessType>
struct global_load_with_zero_point<AccessType, 16> {
    devfunc global_load_with_zero_point(AccessType& D, void const* ptr,
                                        bool pred_guard, int zero_point) {
        uint4& data = reinterpret_cast<uint4&>(D);
        asm volatile(
                "{\n"
                "  .reg .pred p;\n"
                "  setp.ne.b32 p, %5, 0;\n"
                "  mov.b32 %0, %6;\n"
                "  mov.b32 %1, %6;\n"
                "  mov.b32 %2, %6;\n"
                "  mov.b32 %3, %6;\n"
                "  @p ld.global.v4.u32 {%0, %1, %2, %3}, [%4];\n"
                "}\n"
                : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w)
                : "l"(ptr), "r"((int)pred_guard),
                  "r"(reinterpret_cast<unsigned&>(zero_point)));
    }
};
template <typename AccessType>
struct global_load_with_zero_point<AccessType, 8> {
    devfunc global_load_with_zero_point(AccessType& D, void const* ptr,
                                        bool pred_guard, int zero_point) {
        uint2& data = reinterpret_cast<uint2&>(D);
        asm volatile(
                "{\n"
                "  .reg .pred p;\n"
                "  setp.ne.b32 p, %3, 0;\n"
                "  mov.b32 %0, %4;\n"
                "  mov.b32 %1, %4;\n"
                "  @p ld.global.v2.u32 {%0, %1}, [%2];\n"
                "}\n"
                : "=r"(data.x), "=r"(data.y)
                : "l"(ptr), "r"((int)pred_guard),
                  "r"(reinterpret_cast<unsigned&>(zero_point)));
    }
};
template <typename AccessType>
struct global_load_with_zero_point<AccessType, 4> {
    devfunc global_load_with_zero_point(AccessType& D, void const* ptr,
                                        bool pred_guard, int zero_point) {
        unsigned& data = reinterpret_cast<unsigned&>(D);
        asm volatile(
                "{\n"
                "  .reg .pred p;\n"
                "  setp.ne.b32 p, %2, 0;\n"
                "  mov.b32 %0, %3;\n"
                "  @p ld.global.u32 %0, [%1];\n"
                "}\n"
                : "=r"(data)
                : "l"(ptr), "r"((int)pred_guard),
                  "r"(reinterpret_cast<unsigned&>(zero_point)));
    }
};
template <typename AccessType>
struct global_load_with_zero_point<AccessType, 1> {
    devfunc global_load_with_zero_point(AccessType& D, void const* ptr,
                                        bool pred_guard, int zero_point) {
        if (pred_guard)
            D = *(reinterpret_cast<AccessType const*>(ptr));
        else {
            unsigned uv = reinterpret_cast<unsigned&>(zero_point);
            uint8_t& data = reinterpret_cast<uint8_t&>(D);
            data = uv & 0xff;
        }
    }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
        /// Fragment type to store loaded data
        typename AccessType,
        /// The bytes of loading
        int LoadBytes>
struct global_store;
/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Specializations
//
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename AccessType>
struct global_store<AccessType, 32> {
    devfunc global_store(AccessType const& D, void* ptr, bool pred_guard) {
        uint4 const* data = reinterpret_cast<uint4 const*>(&D);
        asm volatile(
                "{\n"
                "  .reg .pred p;\n"
                "  setp.ne.b32 p, %5, 0;\n"
                "  @p st.global.v4.u32 [%0], {%1, %2, %3, %4};\n"
                "  @p st.global.v4.u32 [%6], {%7, %8, %9, %10};\n"
                "}\n"
                :
                : "l"(ptr), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z),
                  "r"(data[0].w), "r"((int)pred_guard),
                  "l"(((uint8_t*)ptr) + 16), "r"(data[1].x), "r"(data[1].y),
                  "r"(data[1].z), "r"(data[1].w));
    }
};
template <typename AccessType>
struct global_store<AccessType, 16> {
    devfunc global_store(AccessType const& D, void* ptr, bool pred_guard) {
        uint4 const& data = reinterpret_cast<uint4 const&>(D);
        asm volatile(
                "{\n"
                "  .reg .pred p;\n"
                "  setp.ne.b32 p, %5, 0;\n"
                "  @p st.global.v4.u32 [%0], {%1, %2, %3, %4};\n"
                "}\n"
                :
                : "l"(ptr), "r"(data.x), "r"(data.y), "r"(data.z), "r"(data.w),
                  "r"((int)pred_guard));
    }
};
template <typename AccessType>
struct global_store<AccessType, 8> {
    devfunc global_store(AccessType const& D, void* ptr, bool pred_guard) {
        uint2 const& data = reinterpret_cast<uint2 const&>(D);
        asm volatile(
                "{\n"
                "  .reg .pred p;\n"
                "  setp.ne.b32 p, %3, 0;\n"
                "  @p st.global.v2.u32 [%0], {%1, %2};\n"
                "}\n"
                :
                : "l"(ptr), "r"(data.x), "r"(data.y), "r"((int)pred_guard));
    }
};
template <typename AccessType>
struct global_store<AccessType, 4> {
    devfunc global_store(AccessType const& D, void* ptr, bool pred_guard) {
        uint32_t const& data = reinterpret_cast<uint32_t const&>(D);
        asm volatile(
                "{\n"
                "  .reg .pred p;\n"
                "  setp.ne.b32 p, %2, 0;\n"
                "  @p st.global.u32 [%0], %1;\n"
                "}\n"
                :
                : "l"(ptr), "r"(data), "r"((int)pred_guard));
    }
};
template <typename AccessType>
struct global_store<AccessType, 2> {
    devfunc global_store(AccessType const& D, void* ptr, bool pred_guard) {
        uint16_t const& data = reinterpret_cast<uint16_t const&>(D);
        asm volatile(
                "{\n"
                "  .reg .pred p;\n"
                "  setp.ne.b32 p, %2, 0;\n"
                "  @p st.global.u16 [%0], %1;\n"
                "}\n"
                :
                : "l"(ptr), "h"(data), "r"((int)pred_guard));
    }
};
template <typename AccessType>
struct global_store<AccessType, 1> {
    devfunc global_store(AccessType const& D, void* ptr, bool pred_guard) {
        if (pred_guard)
            *(reinterpret_cast<AccessType*>(ptr)) = D;
    }
};
#undef devfunc
} // namespace relayout_format
} // namespace cuda
} // namespace megdnn
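All of the PTX specializations above implement one contract; an equivalent plain-CUDA sketch of the 4-byte case (an illustration under stated assumptions, not code from this file):

    template <typename AccessType>
    __device__ __forceinline__ void load_or_zero_point(
            AccessType& D, void const* ptr, bool pred_guard, int zero_point) {
        unsigned& data = reinterpret_cast<unsigned&>(D);
        data = reinterpret_cast<unsigned&>(zero_point);  // broadcast fallback
        if (pred_guard)                                  // guarded 32-bit load
            data = *reinterpret_cast<unsigned const*>(ptr);
    }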
@@ -39,6 +39,20 @@ void relayout_format_cuda_nchwx_nchw(const TensorND& src, const TensorND& dst,
                                      const uint8_t src_zero_point = 0,
                                      const uint8_t dst_zero_point = 0);
+void relayout_format_cuda_nchw_nhwc(const TensorND& src, const TensorND& dst,
+                                    const cudaStream_t& stream,
+                                    const float src_scale = 1.f,
+                                    const float dst_scale = 1.f,
+                                    const uint8_t src_zero_point = 0,
+                                    const uint8_t dst_zero_point = 0);
+void relayout_format_cuda_nhwc_nchw(const TensorND& src, const TensorND& dst,
+                                    const cudaStream_t& stream,
+                                    const float src_scale = 1.f,
+                                    const float dst_scale = 1.f,
+                                    const uint8_t src_zero_point = 0,
+                                    const uint8_t dst_zero_point = 0);
 void relayout_format_cuda_nchw_nchw4_weight(const TensorND& src,
                                             const TensorND& dst,
                                             const cudaStream_t& stream);
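A hedged call-site sketch for the two new entry points (tensor and stream setup elided; the scale and zero-point values are placeholders, not from this commit):

    // NCHW -> NHWC with requantization from src_scale 0.5 to dst_scale 0.25:
    relayout_format_cuda_nchw_nhwc(src, dst, stream, 0.5f, 0.25f,
                                   /*src_zero_point=*/8, /*dst_zero_point=*/8);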
@@ -0,0 +1,346 @@
/**
 * \file dnn/src/cuda/relayout_format/relayout_format_kern.cuh
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
#pragma once
#include "src/cuda/int_fastdiv.cuh"
#include "src/cuda/memory_utils.cuh"
#include "src/cuda/relayout_format/translayout.cuh"
namespace megdnn {
namespace cuda {
namespace relayout_format {
namespace internal {
using namespace memory;
template <typename Type_, int pack_size_, int chan_blk_, int width_,
          int size_nbits_>
class TensorIteratorOverChannel {
public:
    using Type = Type_;
    static constexpr int pack_size = pack_size_;
    static constexpr int chan_blk = chan_blk_;
    static constexpr int width = width_;
    static constexpr int size_nbits = size_nbits_;
    static constexpr int elements_in_type =
            chan_blk * width * size_nbits / (8 * sizeof(Type));
    static constexpr int lane_size_in_type =
            (width * pack_size * size_nbits) / (8 * sizeof(Type));
    static constexpr int pack_size_in_type =
            (pack_size * size_nbits) >= (8 * sizeof(Type))
                    ? (pack_size * size_nbits / (8 * sizeof(Type)))
                    : (width * pack_size * size_nbits / (8 * sizeof(Type)));
    static constexpr int pack_size_in_byte = pack_size_in_type * sizeof(Type);
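    // Worked instance (illustrative): with Type = char, pack_size = 64 and
    // size_nbits = 4, one channel pack occupies 64 * 4 / 8 = 32 chars, so
    // pack_size_in_type = 32 and each access moves pack_size_in_byte = 32
    // bytes, matching the 32-byte global_load/global_store specializations.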
    using AccessType = array_wrapper<Type, pack_size_in_type>;
    using Fragment = array_wrapper<Type, elements_in_type>;
    MEGDNN_HOST TensorIteratorOverChannel()
            : pointer{nullptr}, chan_stride_in_elements{0}, channel{0} {}
    MEGDNN_HOST TensorIteratorOverChannel(Type* pointer_,
                                          int chan_stride_in_elements_,
                                          int channel_, int, int)
            : pointer{pointer_},
              chan_stride_in_elements{chan_stride_in_elements_},
              channel{channel_} {}
    MEGDNN_DEVICE __forceinline__ void initialize(int c_idx, int hw_idx) {
        pointer += (c_idx / pack_size) * chan_stride_in_elements +
                   hw_idx * pack_size * size_nbits / (8 * sizeof(Type));
        channel -= c_idx;
    }
    MEGDNN_DEVICE __forceinline__ void add_pointer_offset(
            size_t offset_in_type) {
        pointer += offset_in_type;
    }
    MEGDNN_DEVICE __forceinline__ void load(Fragment& frag, int zero_point) {
        AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag);
        Type* pointer_ = pointer;
#pragma unroll
        for (int i = 0; i < chan_blk; i += pack_size) {
#pragma unroll
            for (int j = 0; j < lane_size_in_type / pack_size_in_type; j++) {
                int frag_idx = i / pack_size *
                                       (lane_size_in_type /
                                        pack_size_in_type) +
                               j;
                bool guard = i < channel;
                global_load<AccessType, pack_size_in_byte>(
                        frag_ptr[frag_idx],
                        reinterpret_cast<void*>(pointer_ +
                                                j * pack_size_in_type),
                        guard, zero_point);
            }
            pointer_ += chan_stride_in_elements;
        }
    }
    MEGDNN_DEVICE __forceinline__ void store(const Fragment& frag) {
        const AccessType* frag_ptr = reinterpret_cast<const AccessType*>(&frag);
        Type* pointer_ = pointer;
#pragma unroll
        for (int i = 0; i < chan_blk; i += pack_size) {
#pragma unroll
            for (int j = 0; j < lane_size_in_type / pack_size_in_type; j++) {
                int frag_idx = i / pack_size *
                                       (lane_size_in_type /
                                        pack_size_in_type) +
                               j;
                bool guard = i < channel;
                global_store<AccessType, pack_size_in_byte>(
                        frag_ptr[frag_idx],
                        reinterpret_cast<void*>(pointer_ +
                                                j * pack_size_in_type),
                        guard);
            }
            pointer_ += chan_stride_in_elements;
        }
    }
    MEGDNN_DEVICE __forceinline__ void advance() {
        pointer += (chan_blk / pack_size) * chan_stride_in_elements;
        channel -= chan_blk;
    }

private:
    Type* pointer;
    int chan_stride_in_elements;
    int channel;
};
template <typename Type_, int pack_size_, int chan_blk_, int width_,
          int size_nbits_>
class MaskedTensorIteratorOverChannel {
public:
    using Type = Type_;
    static constexpr int pack_size = pack_size_;
    static constexpr int chan_blk = chan_blk_;
    static constexpr int width = width_;
    static constexpr int size_nbits = size_nbits_;
    static constexpr int elements_in_type =
            chan_blk * width * size_nbits / (8 * sizeof(Type));
    static constexpr int lane_size_in_type =
            (width * pack_size * size_nbits) / (8 * sizeof(Type));
    static constexpr int pack_size_in_type =
            (pack_size * size_nbits) >= (8 * sizeof(Type))
                    ? (pack_size * size_nbits / (8 * sizeof(Type)))
                    : (width * pack_size * size_nbits / (8 * sizeof(Type)));
    static constexpr int pack_size_in_byte = pack_size_in_type * sizeof(Type);
    static constexpr int accesses = elements_in_type / pack_size_in_type;
    static constexpr int mask_size = (accesses + 32 - 1) / 32;
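    // One predicate bit per access, packed 32 to a uint32_t word: the bounds
    // checks are precomputed once in initialize(), then load() and store()
    // just test mask[access >> 5] bit (access & 0x1f) on every iteration.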
    using AccessType = array_wrapper<Type, pack_size_in_type>;
    using Fragment = array_wrapper<Type, elements_in_type>;
    MEGDNN_HOST MaskedTensorIteratorOverChannel()
            : pointer{nullptr}, chan_stride_in_elements{0}, channel{0} {}
    MEGDNN_HOST MaskedTensorIteratorOverChannel(Type* pointer_,
                                                int chan_stride_in_elements_,
                                                int channel_, int bound_,
                                                int div_)
            : pointer{pointer_},
              chan_stride_in_elements{chan_stride_in_elements_},
              channel{channel_},
              bound{bound_},
              div{uint32_t(div_)} {}
    MEGDNN_DEVICE __forceinline__ void initialize(int c_idx, int hw_idx) {
        pointer += (c_idx / pack_size) * chan_stride_in_elements;
        channel -= c_idx;
        int w[lane_size_in_type / pack_size_in_type];
#pragma unroll
        for (int i = 0; i < mask_size; ++i) {
            mask[i] = 0;
        }
#pragma unroll
        for (int j = 0; j < lane_size_in_type / pack_size_in_type; j++) {
            int offset = hw_idx + j;
            int h = (int)((uint32_t)(offset) / div);
            w[j] = (int)((uint32_t)(offset) % div);
            stride[j] = (h * bound + w[j]) * pack_size * size_nbits /
                        (8 * sizeof(Type));
        }
#pragma unroll
        for (int i = 0; i < chan_blk; i += pack_size) {
#pragma unroll
            for (int j = 0; j < lane_size_in_type / pack_size_in_type; j++) {
                bool guard = (i < channel) && (w[j] < bound);
                int index = (i / pack_size) *
                                    (lane_size_in_type / pack_size_in_type) +
                            j;
                int mask_index = (index >> 5);
                int mask_shift = (index & 0x1f);
                mask[mask_index] |= (guard << mask_shift);
            }
        }
    }
    MEGDNN_DEVICE __forceinline__ void add_pointer_offset(
            size_t offset_in_type) {
        pointer += offset_in_type;
    }
    MEGDNN_DEVICE __forceinline__ void load(Fragment& frag, int zero_point) {
        AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag);
        Type* pointer_ = pointer;
#pragma unroll
        for (int i = 0; i < chan_blk; i += pack_size) {
#pragma unroll
            for (int j = 0; j < lane_size_in_type / pack_size_in_type; j++) {
                int frag_idx = i / pack_size *
                                       (lane_size_in_type /
                                        pack_size_in_type) +
                               j;
                int mask_index = (frag_idx >> 5);
                int mask_shift = (frag_idx & 0x1f);
                bool guard = (mask[mask_index] & (1 << mask_shift));
                global_load<AccessType, pack_size_in_byte>(
                        frag_ptr[frag_idx],
                        reinterpret_cast<void*>(pointer_ + stride[j]), guard,
                        zero_point);
            }
            pointer_ += chan_stride_in_elements;
        }
    }
    MEGDNN_DEVICE __forceinline__ void store(const Fragment& frag) {
        const AccessType* frag_ptr = reinterpret_cast<const AccessType*>(&frag);
        Type* pointer_ = pointer;
#pragma unroll
        for (int i = 0; i < chan_blk; i += pack_size) {
#pragma unroll
            for (int j = 0; j < lane_size_in_type / pack_size_in_type; j++) {
                int frag_idx = i / pack_size *
                                       (lane_size_in_type /
                                        pack_size_in_type) +
                               j;
                int mask_index = (frag_idx >> 5);
                int mask_shift = (frag_idx & 0x1f);
                bool guard = (mask[mask_index] & (1 << mask_shift));
                global_store<AccessType, pack_size_in_byte>(
                        frag_ptr[frag_idx],
                        reinterpret_cast<void*>(pointer_ + stride[j]), guard);
            }
            pointer_ += chan_stride_in_elements;
        }
    }
    MEGDNN_DEVICE __forceinline__ void advance() {
        pointer += (chan_blk / pack_size) * chan_stride_in_elements;
        channel -= chan_blk;
    }

private:
    Type* pointer;
    int chan_stride_in_elements;
    int channel;
    int bound;
    Uint32Fastdiv div;
    uint32_t mask[mask_size];
    size_t stride[lane_size_in_type / pack_size_in_type];
};
template <bool padding_, typename Type_, int pack_size_, int chan_blk_,
          int width_, int size_nbits_>
struct TensorIteratorPolicy;
template <typename Type_, int pack_size_, int chan_blk_, int width_,
          int size_nbits_>
struct TensorIteratorPolicy<true, Type_, pack_size_, chan_blk_, width_,
                            size_nbits_> {
    using TensorIterator =
            MaskedTensorIteratorOverChannel<Type_, pack_size_, chan_blk_,
                                            width_, size_nbits_>;
};
template <typename Type_, int pack_size_, int chan_blk_, int width_,
          int size_nbits_>
struct TensorIteratorPolicy<false, Type_, pack_size_, chan_blk_, width_,
                            size_nbits_> {
    using TensorIterator =
            TensorIteratorOverChannel<Type_, pack_size_, chan_blk_, width_,
                                      size_nbits_>;
};
template <typename SrcIterator_, typename DstIterator_, typename Transpose_,
          typename CudaPostProcess_>
struct RelayoutProblem {
    using SrcIterator = SrcIterator_;
    using DstIterator = DstIterator_;
    using Transpose = Transpose_;
    using CudaPostProcess = CudaPostProcess_;
    MEGDNN_STATIC_ASSERT(SrcIterator::chan_blk == DstIterator::chan_blk,
                         "channel block mismatch");
    MEGDNN_STATIC_ASSERT(SrcIterator::width == DstIterator::width,
                         "width block mismatch");
    MEGDNN_STATIC_ASSERT(SrcIterator::size_nbits == DstIterator::size_nbits,
                         "size in bits of elements mismatch");
    static constexpr int pack_chan = SrcIterator::chan_blk;
    static constexpr int pack_width = SrcIterator::width;
    using DnnSrcType = typename CudaPostProcess::SrcType;
    using DnnDstType = typename CudaPostProcess::DstType;
    struct Param {
        SrcIterator src_iterator;
        DstIterator dst_iterator;
        CudaPostProcess post_process;
        int n_stride_src;
        int n_stride_dst;
        int batch_size;
        int channels;
        int hw;
        int zero_point;
        MEGDNN_HOST MEGDNN_DEVICE Param(SrcIterator src_iterator_,
                                        DstIterator dst_iterator_,
                                        CudaPostProcess post_process_,
                                        int n_stride_src_, int n_stride_dst_,
                                        int batch_size_, int channels_, int hw_,
                                        int zero_point_)
                : src_iterator{src_iterator_},
                  dst_iterator{dst_iterator_},
                  post_process{post_process_},
                  n_stride_src{n_stride_src_},
                  n_stride_dst{n_stride_dst_},
                  batch_size{batch_size_},
                  channels{channels_},
                  hw{hw_},
                  zero_point{zero_point_} {}
    };
};
template <typename RelayoutProblem_>
__global__ void relayout_kern(typename RelayoutProblem_::Param param) {
    using SrcIterator = typename RelayoutProblem_::SrcIterator;
    using DstIterator = typename RelayoutProblem_::DstIterator;
    static constexpr int pack_chan = RelayoutProblem_::pack_chan;
    static constexpr int pack_width = RelayoutProblem_::pack_width;
    const int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int thread_offset = thread_idx * pack_width;
    const int hw_idx = (thread_offset % param.hw);
    const int nc_blks = thread_offset / param.hw;
    const int c_blks = (param.channels + pack_chan - 1) / pack_chan;
    const int n_idx = nc_blks / c_blks;
    const int c_blk_idx = nc_blks % c_blks;
    const int c_idx = c_blk_idx * pack_chan;
    if (n_idx < param.batch_size) {
        const int src_offset = n_idx * param.n_stride_src;
        const int dst_offset = n_idx * param.n_stride_dst;
        param.src_iterator.add_pointer_offset(src_offset);
        param.dst_iterator.add_pointer_offset(dst_offset);
        param.src_iterator.initialize(c_idx, hw_idx);
        param.dst_iterator.initialize(c_idx, hw_idx);
        typename SrcIterator::Fragment src_frag;
        typename DstIterator::Fragment dst_frag;
        int zp = make_zero<SrcIterator::size_nbits>(param.zero_point);
        param.src_iterator.load(src_frag, zp);
        RelayoutProblem_::Transpose::trans(
                reinterpret_cast<typename SrcIterator::Fragment&>(dst_frag),
                src_frag, param.post_process);
        param.dst_iterator.store(dst_frag);
    }
}
} // namespace internal
} // namespace relayout_format
} // namespace cuda
} // namespace megdnn
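A hedged sketch of how relayout_kern is expected to be launched (the launch helper below is an assumption for illustration; only relayout_kern, RelayoutProblem and Param come from this file): each thread covers pack_width positions of one pack_chan-channel block, so the grid must supply batch_size * c_blks * hw / pack_width threads.

    // illustrative launch, assuming hw % pack_width == 0
    using Problem =
            RelayoutProblem<SrcIterator, DstIterator, Transpose, PostProcess>;
    constexpr int BLOCK = 256;
    int c_blks = (channels + Problem::pack_chan - 1) / Problem::pack_chan;
    int nr_threads = batch_size * c_blks * hw / Problem::pack_width;
    relayout_kern<Problem>
            <<<(nr_threads + BLOCK - 1) / BLOCK, BLOCK, 0, stream>>>(param);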
@@ -0,0 +1,128 @@
/**
 * \file dnn/src/cuda/relayout_format/relayout_format_utils.cuh
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
#pragma once
#include "src/cuda/integer_subbyte_utils.cuh"
#include "src/cuda/relayout_format/relayout_format.cuh"
namespace megdnn {
namespace cuda {
namespace relayout_format {
namespace internal {
template <typename ctype, int pack_w, typename enable = void>
struct DTypeRWHelper;
template <typename ctype>
struct DTypeRWHelper<
        ctype, 1,
        typename std::enable_if<std::is_same<ctype, dt_qint8>::value ||
                                std::is_same<ctype, dt_quint8>::value ||
                                std::is_same<ctype, dt_uint8>::value>::type> {
    using InnerDtype = char;
    using DstDtype = char4;
};
template <typename ctype>
struct DTypeRWHelper<
        ctype, 4,
        typename std::enable_if<std::is_same<ctype, dt_qint8>::value ||
                                std::is_same<ctype, dt_quint8>::value ||
                                std::is_same<ctype, dt_uint8>::value>::type> {
    using InnerDtype = char4;
    using DstDtype = char4;
};
template <>
struct DTypeRWHelper<dt_qint32, 1> {
    using InnerDtype = int;
    using DstDtype = int4;
};
template <>
struct DTypeRWHelper<dt_qint32, 4> {
    using InnerDtype = int4;
    using DstDtype = int4;
};
template <typename ctype>
struct DTypeRWHelper<
        ctype, 2,
        typename std::enable_if<std::is_same<ctype, dt_qint4>::value ||
                                std::is_same<ctype, dt_quint4>::value>::type> {
    using InnerDtype = char;
    using DstDtype = array_wrapper<uint8_t, 32>;
};
template <typename ctype>
struct DTypeRWHelper<
        ctype, 8,
        typename std::enable_if<std::is_same<ctype, dt_qint4>::value ||
                                std::is_same<ctype, dt_quint4>::value>::type> {
    using InnerDtype = unsigned;
    using DstDtype = array_wrapper<uint8_t, 32>;
};
template <typename DstType>
inline __device__ DstType make_zero_pad(const uint8_t zero_point) {
    return zero_point;
}
template <>
inline __device__ char4 make_zero_pad<char4>(const uint8_t zero_point) {
    char izp = reinterpret_cast<const char&>(zero_point);
    return {izp, izp, izp, izp};
}
template <>
inline __device__ int4 make_zero_pad<int4>(const uint8_t zero_point) {
    return {zero_point, zero_point, zero_point, zero_point};
}
template <int size_nbits>
inline __device__ int make_zero(int zero_point);
template <>
inline __device__ int make_zero<4>(int zero_point) {
    return integer_subbyte::transform_int8_to_uint4x8(
            zero_point, zero_point, zero_point, zero_point, zero_point,
            zero_point, zero_point, zero_point);
}
template <typename DstDtype>
inline __device__ void write_helper(DstDtype* ptr, DstDtype val) {
    *ptr = val;
}
template <>
inline __device__ void write_helper<char4>(char4* ptr, char4 val) {
    int32_t* rel_ptr = (int32_t*)ptr;
    *rel_ptr = *(int32_t*)(&val);
}
template <>
inline __device__ void write_helper<array_wrapper<uint8_t, 32>>(
        array_wrapper<uint8_t, 32>* ptr, array_wrapper<uint8_t, 32> val) {
    uint4 const* data = reinterpret_cast<uint4 const*>(&val);
    void* ptr_ = reinterpret_cast<void*>(ptr);
    asm volatile(
            "{\n"
            "  st.global.v4.u32 [%0], {%1, %2, %3, %4};\n"
            "  st.global.v4.u32 [%5], {%6, %7, %8, %9};\n"
            "}\n"
            :
            : "l"(ptr_), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z),
              "r"(data[0].w), "l"(((uint8_t*)ptr_) + 16), "r"(data[1].x),
              "r"(data[1].y), "r"(data[1].z), "r"(data[1].w));
}
} // namespace internal
} // namespace relayout_format
} // namespace cuda
} // namespace megdnn
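For concreteness, make_zero<4> simply broadcasts the 4-bit zero point into all eight nibbles of a 32-bit word, e.g. (following the uint4x8 packing rule; an illustration, not code from this file):

    int zeros = make_zero<4>(3);  // eight uint4 lanes of 3 -> 0x33333333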
@@ -0,0 +1,537 @@
/**
 * \file dnn/src/cuda/relayout_format/translayout.cuh
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
#pragma once
#include "src/cuda/integer_subbyte_utils.cuh"
#include "src/cuda/relayout_format/cuda_post_process.cuh"
#include "src/cuda/relayout_format/relayout_format.cuh"
#include "src/cuda/relayout_format/relayout_format_utils.cuh"
namespace megdnn {
namespace cuda {
namespace relayout_format {
namespace internal {
using namespace integer_subbyte;
template <typename dt>
struct qtype_signedness;
template <>
struct qtype_signedness<dtype::QuantizedS4> {
    static constexpr bool value = true;
};
template <>
struct qtype_signedness<dtype::Quantized4Asymm> {
    static constexpr bool value = false;
};
template <typename dt_src, typename dt_dst>
struct enable_qtype_b4 {
    static constexpr bool val_src =
            std::is_same<dt_src, dtype::QuantizedS4>::value ||
            std::is_same<dt_src, dtype::Quantized4Asymm>::value;
    static constexpr bool val_dst =
            std::is_same<dt_dst, dtype::QuantizedS4>::value ||
            std::is_same<dt_dst, dtype::Quantized4Asymm>::value;
    using type = typename std::enable_if<std::is_same<dt_src, dt_dst>::value &&
                                         val_src && val_dst>::type;
};
// The input fragment is stored in RowMajor order. The translayout operator
// performs a transpose operation on the input fragment, and produces a
// reordered fragment, i.e. a fragment stored in ColumnMajor order.
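// For instance, the <1, 4> specialization below gathers one element from
// each of four channel rows (read_channel[0..3]) into the x/y/z/w lanes of
// a single output vector: a 4x1 RowMajor tile becomes a 1x4 ColumnMajor
// tile, with post_process applied element-wise on the way through.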
| template <int col, int row, typename SrcType, typename DnnSrcType, | |||
| typename DnnDstType, bool same_scale, typename enable = void> | |||
| struct Translayout; | |||
| // partial specialization for translayout operator for qint8 and quint8 | |||
| template <typename SrcType, typename DnnSrcType, typename DnnDstType, | |||
| bool same_scale> | |||
| struct Translayout<1, 4, SrcType, DnnSrcType, DnnDstType, same_scale> { | |||
| using InnerDtype = | |||
| typename DTypeRWHelper<typename DTypeTrait<DnnSrcType>::ctype, | |||
| 1>::InnerDtype; | |||
| using DstDtype = | |||
| typename DTypeRWHelper<typename DTypeTrait<DnnSrcType>::ctype, | |||
| 1>::DstDtype; | |||
| static inline __device__ void trans( | |||
| DstDtype (&dst_width)[1], InnerDtype (&read_channel)[4], | |||
| CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process, | |||
| const char zero_point) { | |||
| dst_width[0].x = post_process(read_channel[0]); | |||
| dst_width[0].y = post_process(read_channel[1]); | |||
| dst_width[0].z = post_process(read_channel[2]); | |||
| dst_width[0].w = post_process(read_channel[3]); | |||
| } | |||
| }; | |||
| template <typename SrcType, typename DnnSrcType, typename DnnDstType, | |||
| bool same_scale> | |||
| struct Translayout<4, 4, SrcType, DnnSrcType, DnnDstType, same_scale> { | |||
| using InnerDtype = | |||
| typename DTypeRWHelper<typename DTypeTrait<DnnSrcType>::ctype, | |||
| 4>::InnerDtype; | |||
| using DstDtype = | |||
| typename DTypeRWHelper<typename DTypeTrait<DnnSrcType>::ctype, | |||
| 4>::DstDtype; | |||
| static inline __device__ void trans( | |||
| DstDtype (&dst_width)[4], InnerDtype (&read_channel)[4], | |||
| CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process, | |||
| const char zero_point) { | |||
| dst_width[0].x = post_process(read_channel[0].x); | |||
| dst_width[0].y = post_process(read_channel[1].x); | |||
| dst_width[0].z = post_process(read_channel[2].x); | |||
| dst_width[0].w = post_process(read_channel[3].x); | |||
| dst_width[1].x = post_process(read_channel[0].y); | |||
| dst_width[1].y = post_process(read_channel[1].y); | |||
| dst_width[1].z = post_process(read_channel[2].y); | |||
| dst_width[1].w = post_process(read_channel[3].y); | |||
| dst_width[2].x = post_process(read_channel[0].z); | |||
| dst_width[2].y = post_process(read_channel[1].z); | |||
| dst_width[2].z = post_process(read_channel[2].z); | |||
| dst_width[2].w = post_process(read_channel[3].z); | |||
| dst_width[3].x = post_process(read_channel[0].w); | |||
| dst_width[3].y = post_process(read_channel[1].w); | |||
| dst_width[3].z = post_process(read_channel[2].w); | |||
| dst_width[3].w = post_process(read_channel[3].w); | |||
| } | |||
| }; | |||
| // ========================================================= | |||
| // partial specialization for translayout operator for qint4 | |||
| // NCHW <-> NCHW64 | |||
| template <typename SrcType, typename DnnSrcType_, typename DnnDstType_, | |||
| bool same_scale> | |||
| struct Translayout<2, 64, SrcType, DnnSrcType_, DnnDstType_, same_scale, | |||
| typename enable_qtype_b4<DnnSrcType_, DnnDstType_>::type> { | |||
| using DnnSrcType = DnnSrcType_; | |||
| using DnnDstType = DnnDstType_; | |||
| using InnerDtype = | |||
| typename DTypeRWHelper<typename DTypeTrait<DnnSrcType>::ctype, | |||
| 2>::InnerDtype; | |||
| using DstDtype = | |||
| typename DTypeRWHelper<typename DTypeTrait<DnnSrcType>::ctype, | |||
| 2>::DstDtype; | |||
| static constexpr bool signedness = qtype_signedness<DnnSrcType>::value; | |||
| static inline __device__ void trans( | |||
| DstDtype (&dst_width)[2], InnerDtype (&read_channel)[64], | |||
| CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process, | |||
| const char zero_point) { | |||
| int intermediate[8][2]; | |||
| int* dst_frag = reinterpret_cast<int*>(dst_width); | |||
| auto pack_channel = [&](int idx) -> int { | |||
| return transform_int8_to_b4x8<signedness>( | |||
| post_process(intermediate[0][idx]), | |||
| post_process(intermediate[1][idx]), | |||
| post_process(intermediate[2][idx]), | |||
| post_process(intermediate[3][idx]), | |||
| post_process(intermediate[4][idx]), | |||
| post_process(intermediate[5][idx]), | |||
| post_process(intermediate[6][idx]), | |||
| post_process(intermediate[7][idx])); | |||
| }; | |||
| #pragma unroll | |||
| for (int i = 0; i < 64; i += 8) { | |||
| transform_b4x2_to_int8<signedness>( | |||
| intermediate[0], | |||
| reinterpret_cast<uint8_t&>(read_channel[i + 0])); | |||
| transform_b4x2_to_int8<signedness>( | |||
| intermediate[1], | |||
| reinterpret_cast<uint8_t&>(read_channel[i + 1])); | |||
| transform_b4x2_to_int8<signedness>( | |||
| intermediate[2], | |||
| reinterpret_cast<uint8_t&>(read_channel[i + 2])); | |||
| transform_b4x2_to_int8<signedness>( | |||
| intermediate[3], | |||
| reinterpret_cast<uint8_t&>(read_channel[i + 3])); | |||
| transform_b4x2_to_int8<signedness>( | |||
| intermediate[4], | |||
| reinterpret_cast<uint8_t&>(read_channel[i + 4])); | |||
| transform_b4x2_to_int8<signedness>( | |||
| intermediate[5], | |||
| reinterpret_cast<uint8_t&>(read_channel[i + 5])); | |||
| transform_b4x2_to_int8<signedness>( | |||
| intermediate[6], | |||
| reinterpret_cast<uint8_t&>(read_channel[i + 6])); | |||
| transform_b4x2_to_int8<signedness>( | |||
| intermediate[7], | |||
| reinterpret_cast<uint8_t&>(read_channel[i + 7])); | |||
| int frag_idx = i / 8; | |||
| dst_frag[0 * 8 + frag_idx] = pack_channel(0); | |||
| dst_frag[1 * 8 + frag_idx] = pack_channel(1); | |||
| } | |||
| } | |||
| using Fragment = array_wrapper<SrcType, 64>; | |||
| static inline __device__ void trans( | |||
| Fragment& dst, Fragment& src, | |||
| CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process) { | |||
| trans(reinterpret_cast<DstDtype(&)[2]>(dst), | |||
| reinterpret_cast<InnerDtype(&)[64]>(src), post_process, 0); | |||
| } | |||
| }; | |||
| template <typename SrcType, typename DnnSrcType_, typename DnnDstType_, | |||
| bool same_scale> | |||
| struct Translayout<8, 64, SrcType, DnnSrcType_, DnnDstType_, same_scale, | |||
| typename enable_qtype_b4<DnnSrcType_, DnnDstType_>::type> { | |||
| using DnnSrcType = DnnSrcType_; | |||
| using DnnDstType = DnnDstType_; | |||
| using InnerDtype = | |||
| typename DTypeRWHelper<typename DTypeTrait<DnnSrcType>::ctype, | |||
| 8>::InnerDtype; | |||
| using DstDtype = | |||
| typename DTypeRWHelper<typename DTypeTrait<DnnSrcType>::ctype, | |||
| 8>::DstDtype; | |||
| static constexpr bool signedness = qtype_signedness<DnnSrcType>::value; | |||
| static inline __device__ void trans( | |||
| DstDtype (&dst_width)[8], InnerDtype (&read_channel)[64], | |||
| CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process, | |||
| const char zero_point) { | |||
| int intermediate[8][8]; | |||
| int* dst_frag = reinterpret_cast<int*>(dst_width); | |||
| auto pack_channel = [&](int idx) -> int { | |||
| return transform_int8_to_b4x8<signedness>( | |||
| post_process(intermediate[0][idx]), | |||
| post_process(intermediate[1][idx]), | |||
| post_process(intermediate[2][idx]), | |||
| post_process(intermediate[3][idx]), | |||
| post_process(intermediate[4][idx]), | |||
| post_process(intermediate[5][idx]), | |||
| post_process(intermediate[6][idx]), | |||
| post_process(intermediate[7][idx])); | |||
| }; | |||
| #pragma unroll | |||
| for (int i = 0; i < 64; i += 8) { | |||
| transform_b4x8_to_int8<signedness>(intermediate[0], | |||
| read_channel[i + 0]); | |||
| transform_b4x8_to_int8<signedness>(intermediate[1], | |||
| read_channel[i + 1]); | |||
| transform_b4x8_to_int8<signedness>(intermediate[2], | |||
| read_channel[i + 2]); | |||
| transform_b4x8_to_int8<signedness>(intermediate[3], | |||
| read_channel[i + 3]); | |||
| transform_b4x8_to_int8<signedness>(intermediate[4], | |||
| read_channel[i + 4]); | |||
| transform_b4x8_to_int8<signedness>(intermediate[5], | |||
| read_channel[i + 5]); | |||
| transform_b4x8_to_int8<signedness>(intermediate[6], | |||
| read_channel[i + 6]); | |||
| transform_b4x8_to_int8<signedness>(intermediate[7], | |||
| read_channel[i + 7]); | |||
| int frag_idx = i / 8; | |||
| dst_frag[0 * 8 + frag_idx] = pack_channel(0); | |||
| dst_frag[1 * 8 + frag_idx] = pack_channel(1); | |||
| dst_frag[2 * 8 + frag_idx] = pack_channel(2); | |||
| dst_frag[3 * 8 + frag_idx] = pack_channel(3); | |||
| dst_frag[4 * 8 + frag_idx] = pack_channel(4); | |||
| dst_frag[5 * 8 + frag_idx] = pack_channel(5); | |||
| dst_frag[6 * 8 + frag_idx] = pack_channel(6); | |||
| dst_frag[7 * 8 + frag_idx] = pack_channel(7); | |||
| } | |||
| } | |||
| using Fragment = array_wrapper<unsigned, 64>; | |||
| static inline __device__ void trans( | |||
| Fragment& dst, Fragment& src, | |||
| CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process) { | |||
| trans(reinterpret_cast<DstDtype(&)[8]>(dst), | |||
| reinterpret_cast<InnerDtype(&)[64]>(src), post_process, 0); | |||
| } | |||
| }; | |||
| template <typename SrcType, typename DnnSrcType_, typename DnnDstType_, | |||
| bool same_scale> | |||
| struct Translayout<64, 8, SrcType, DnnSrcType_, DnnDstType_, same_scale, | |||
| typename enable_qtype_b4<DnnSrcType_, DnnDstType_>::type> { | |||
| using DnnSrcType = DnnSrcType_; | |||
| using DnnDstType = DnnDstType_; | |||
| static constexpr int row = 8; | |||
| static constexpr int col = 64; | |||
| static constexpr int size_nbits = 4; | |||
| static constexpr int col_in_type = col * size_nbits / (8 * sizeof(SrcType)); | |||
| static constexpr int elements_in_type = row * col_in_type; | |||
| static constexpr int inc_col = 8; | |||
| static constexpr int inc_col_in_type = | |||
| inc_col * size_nbits / (8 * sizeof(SrcType)); | |||
| static constexpr bool signedness = qtype_signedness<DnnSrcType>::value; | |||
| using Fragment = array_wrapper<SrcType, elements_in_type>; | |||
| static MEGDNN_DEVICE __forceinline__ void trans( | |||
| Fragment& dst, const Fragment& src, | |||
| CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process) { | |||
| int intermediate[8][8]; | |||
| int* dst_frag = reinterpret_cast<int*>(&dst); | |||
| auto pack = [&](int idx) -> int { | |||
| return transform_int8_to_b4x8<signedness>( | |||
| post_process(intermediate[0][idx]), | |||
| post_process(intermediate[1][idx]), | |||
| post_process(intermediate[2][idx]), | |||
| post_process(intermediate[3][idx]), | |||
| post_process(intermediate[4][idx]), | |||
| post_process(intermediate[5][idx]), | |||
| post_process(intermediate[6][idx]), | |||
| post_process(intermediate[7][idx])); | |||
| }; | |||
| #pragma unroll | |||
| for (int j = 0; j < col_in_type; j += inc_col_in_type) { | |||
| transform_b4x8_to_int8<signedness>( | |||
| intermediate[0], | |||
| reinterpret_cast<const int&>(src[0 * col_in_type + j])); | |||
| transform_b4x8_to_int8<signedness>( | |||
| intermediate[1], | |||
| reinterpret_cast<const int&>(src[1 * col_in_type + j])); | |||
| transform_b4x8_to_int8<signedness>( | |||
| intermediate[2], | |||
| reinterpret_cast<const int&>(src[2 * col_in_type + j])); | |||
| transform_b4x8_to_int8<signedness>( | |||
| intermediate[3], | |||
| reinterpret_cast<const int&>(src[3 * col_in_type + j])); | |||
| transform_b4x8_to_int8<signedness>( | |||
| intermediate[4], | |||
| reinterpret_cast<const int&>(src[4 * col_in_type + j])); | |||
| transform_b4x8_to_int8<signedness>( | |||
| intermediate[5], | |||
| reinterpret_cast<const int&>(src[5 * col_in_type + j])); | |||
| transform_b4x8_to_int8<signedness>( | |||
| intermediate[6], | |||
| reinterpret_cast<const int&>(src[6 * col_in_type + j])); | |||
| transform_b4x8_to_int8<signedness>( | |||
| intermediate[7], | |||
| reinterpret_cast<const int&>(src[7 * col_in_type + j])); | |||
| dst_frag[(j / inc_col_in_type) * 8 + 0] = pack(0); | |||
| dst_frag[(j / inc_col_in_type) * 8 + 1] = pack(1); | |||
| dst_frag[(j / inc_col_in_type) * 8 + 2] = pack(2); | |||
| dst_frag[(j / inc_col_in_type) * 8 + 3] = pack(3); | |||
| dst_frag[(j / inc_col_in_type) * 8 + 4] = pack(4); | |||
| dst_frag[(j / inc_col_in_type) * 8 + 5] = pack(5); | |||
| dst_frag[(j / inc_col_in_type) * 8 + 6] = pack(6); | |||
| dst_frag[(j / inc_col_in_type) * 8 + 7] = pack(7); | |||
| } | |||
| } | |||
| }; | |||
| template <typename SrcType, typename DnnSrcType_, typename DnnDstType_, | |||
| bool same_scale> | |||
| struct Translayout<64, 2, SrcType, DnnSrcType_, DnnDstType_, same_scale, | |||
| typename enable_qtype_b4<DnnSrcType_, DnnDstType_>::type> { | |||
| using DnnSrcType = DnnSrcType_; | |||
| using DnnDstType = DnnDstType_; | |||
| static constexpr int row = 2; | |||
| static constexpr int col = 64; | |||
| static constexpr int size_nbits = 4; | |||
| static constexpr int col_in_type = col * size_nbits / (8 * sizeof(SrcType)); | |||
| static constexpr int elements_in_type = row * col_in_type; | |||
| static constexpr int inc_col = 8; | |||
| static constexpr int inc_col_in_type = | |||
| inc_col * size_nbits / (8 * sizeof(SrcType)); | |||
| static constexpr bool signedness = qtype_signedness<DnnSrcType>::value; | |||
| using Fragment = array_wrapper<SrcType, elements_in_type>; | |||
| static MEGDNN_DEVICE __forceinline__ void trans( | |||
| Fragment& dst, const Fragment& src, | |||
| CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process) { | |||
| int intermediate[2][8]; | |||
| int* dst_frag = reinterpret_cast<int*>(&dst); | |||
| #pragma unroll | |||
| for (int j = 0; j < col_in_type; j += inc_col_in_type) { | |||
| transform_b4x8_to_int8<signedness>( | |||
| intermediate[0], | |||
| reinterpret_cast<const int&>(src[0 * col_in_type + j])); | |||
| transform_b4x8_to_int8<signedness>( | |||
| intermediate[1], | |||
| reinterpret_cast<const int&>(src[1 * col_in_type + j])); | |||
| dst_frag[(j / inc_col_in_type) * 2 + 0] = | |||
| transform_int8_to_b4x8<signedness>( | |||
| post_process(intermediate[0][0]), | |||
| post_process(intermediate[1][0]), | |||
| post_process(intermediate[0][1]), | |||
| post_process(intermediate[1][1]), | |||
| post_process(intermediate[0][2]), | |||
| post_process(intermediate[1][2]), | |||
| post_process(intermediate[0][3]), | |||
| post_process(intermediate[1][3])); | |||
| dst_frag[(j / inc_col_in_type) * 2 + 1] = | |||
| transform_int8_to_b4x8<signedness>( | |||
| post_process(intermediate[0][4]), | |||
| post_process(intermediate[1][4]), | |||
| post_process(intermediate[0][5]), | |||
| post_process(intermediate[1][5]), | |||
| post_process(intermediate[0][6]), | |||
| post_process(intermediate[1][6]), | |||
| post_process(intermediate[0][7]), | |||
| post_process(intermediate[1][7])); | |||
| } | |||
| } | |||
| }; | |||
| // ========================================================= | |||
| // partial specialization for translayout operator for qint4 | |||
| // NCHW <-> NHWC | |||
| template <typename SrcType, typename DnnSrcType_, typename DnnDstType_, | |||
| bool same_scale> | |||
| struct Translayout<2, 8, SrcType, DnnSrcType_, DnnDstType_, same_scale, | |||
| typename enable_qtype_b4<DnnSrcType_, DnnDstType_>::type> { | |||
| using DnnSrcType = DnnSrcType_; | |||
| using DnnDstType = DnnDstType_; | |||
| static constexpr int row = 8; | |||
| static constexpr int col = 2; | |||
| static constexpr int size_nbits = 4; | |||
| static constexpr int col_in_type = col * size_nbits / (8 * sizeof(SrcType)); | |||
| static constexpr int elements_in_type = row * col_in_type; | |||
| static constexpr bool signedness = qtype_signedness<DnnSrcType>::value; | |||
| using Fragment = array_wrapper<SrcType, elements_in_type>; | |||
| static inline __device__ void trans( | |||
| Fragment& dst, const Fragment& src, | |||
| CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process, | |||
| const char zero_point) { | |||
| int intermediate[8][2]; | |||
| transform_b4x2_to_int8<signedness>(intermediate[0], | |||
| reinterpret_cast<uint8_t&>(src[0])); | |||
| transform_b4x2_to_int8<signedness>(intermediate[1], | |||
| reinterpret_cast<uint8_t&>(src[1])); | |||
| transform_b4x2_to_int8<signedness>(intermediate[2], | |||
| reinterpret_cast<uint8_t&>(src[2])); | |||
| transform_b4x2_to_int8<signedness>(intermediate[3], | |||
| reinterpret_cast<uint8_t&>(src[3])); | |||
| transform_b4x2_to_int8<signedness>(intermediate[4], | |||
| reinterpret_cast<uint8_t&>(src[4])); | |||
| transform_b4x2_to_int8<signedness>(intermediate[5], | |||
| reinterpret_cast<uint8_t&>(src[5])); | |||
| transform_b4x2_to_int8<signedness>(intermediate[6], | |||
| reinterpret_cast<uint8_t&>(src[6])); | |||
| transform_b4x2_to_int8<signedness>(intermediate[7], | |||
| reinterpret_cast<uint8_t&>(src[7])); | |||
| int* dst_frag = reinterpret_cast<int*>(&dst); | |||
| auto pack = [&](int idx) -> int { | |||
| return transform_int8_to_b4x8<signedness>( | |||
| post_process(intermediate[0][idx]), | |||
| post_process(intermediate[1][idx]), | |||
| post_process(intermediate[2][idx]), | |||
| post_process(intermediate[3][idx]), | |||
| post_process(intermediate[4][idx]), | |||
| post_process(intermediate[5][idx]), | |||
| post_process(intermediate[6][idx]), | |||
| post_process(intermediate[7][idx])); | |||
| }; | |||
| dst_frag[0] = pack(0); | |||
| dst_frag[1] = pack(1); | |||
| } | |||
| }; | |||
| template <typename SrcType, typename DnnSrcType_, typename DnnDstType_, | |||
| bool same_scale> | |||
| struct Translayout<8, 8, SrcType, DnnSrcType_, DnnDstType_, same_scale, | |||
| typename enable_qtype_b4<DnnSrcType_, DnnDstType_>::type> { | |||
| using DnnSrcType = DnnSrcType_; | |||
| using DnnDstType = DnnDstType_; | |||
| static constexpr int row = 8; | |||
| static constexpr int col = 8; | |||
| static constexpr int size_nbits = 4; | |||
| static constexpr int col_in_type = col * size_nbits / (8 * sizeof(SrcType)); | |||
| static constexpr int elements_in_type = row * col_in_type; | |||
| static constexpr bool signedness = qtype_signedness<DnnSrcType>::value; | |||
| using Fragment = array_wrapper<SrcType, elements_in_type>; | |||
| static inline __device__ void trans( | |||
| Fragment& dst, const Fragment& src, | |||
| CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process, | |||
| const char zero_point) { | |||
| int intermediate[8][8]; | |||
| transform_b4x8_to_int8<signedness>( | |||
| intermediate[0], reinterpret_cast<const int&>(src[0])); | |||
| transform_b4x8_to_int8<signedness>( | |||
| intermediate[1], reinterpret_cast<const int&>(src[1])); | |||
| transform_b4x8_to_int8<signedness>( | |||
| intermediate[2], reinterpret_cast<const int&>(src[2])); | |||
| transform_b4x8_to_int8<signedness>( | |||
| intermediate[3], reinterpret_cast<const int&>(src[3])); | |||
| transform_b4x8_to_int8<signedness>( | |||
| intermediate[4], reinterpret_cast<const int&>(src[4])); | |||
| transform_b4x8_to_int8<signedness>( | |||
| intermediate[5], reinterpret_cast<const int&>(src[5])); | |||
| transform_b4x8_to_int8<signedness>( | |||
| intermediate[6], reinterpret_cast<const int&>(src[6])); | |||
| transform_b4x8_to_int8<signedness>( | |||
| intermediate[7], reinterpret_cast<const int&>(src[7])); | |||
| int* dst_frag = reinterpret_cast<int*>(&dst); | |||
| auto pack = [&](int idx) { | |||
| return transform_int8_to_b4x8<signedness>( | |||
| post_process(intermediate[0][idx]), | |||
| post_process(intermediate[1][idx]), | |||
| post_process(intermediate[2][idx]), | |||
| post_process(intermediate[3][idx]), | |||
| post_process(intermediate[4][idx]), | |||
| post_process(intermediate[5][idx]), | |||
| post_process(intermediate[6][idx]), | |||
| post_process(intermediate[7][idx])); | |||
| }; | |||
| dst_frag[0] = pack(0); | |||
| dst_frag[1] = pack(1); | |||
| dst_frag[2] = pack(2); | |||
| dst_frag[3] = pack(3); | |||
| dst_frag[4] = pack(4); | |||
| dst_frag[5] = pack(5); | |||
| dst_frag[6] = pack(6); | |||
| dst_frag[7] = pack(7); | |||
| } | |||
| }; | |||
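The <8, 8> case is a full 8x8 nibble transpose: dst_frag[j] gathers nibble j of every source word. A minimal host model, assuming the first argument of transform_int8_to_b4x8 lands in the lowest nibble and skipping post_process:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    uint32_t src[8], dst[8] = {0};
    // Fill nibble (i, j) of the source with a recognizable pattern.
    for (int i = 0; i < 8; ++i) {
        src[i] = 0;
        for (int j = 0; j < 8; ++j)
            src[i] |= uint32_t((2 * i + j) & 0xF) << (4 * j);
    }
    // Transpose: nibble (i, j) of src becomes nibble (j, i) of dst.
    for (int j = 0; j < 8; ++j)
        for (int i = 0; i < 8; ++i) {
            uint32_t n = (src[i] >> (4 * j)) & 0xFu;
            dst[j] |= n << (4 * i);
        }
    for (int j = 0; j < 8; ++j) printf("dst[%d] = %08x\n", j, dst[j]);
}
```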
| template <typename SrcType, typename DnnSrcType_, typename DnnDstType_, | |||
| bool same_scale> | |||
| struct Translayout<8, 2, SrcType, DnnSrcType_, DnnDstType_, same_scale, | |||
| typename enable_qtype_b4<DnnSrcType_, DnnDstType_>::type> { | |||
| using DnnSrcType = DnnSrcType_; | |||
| using DnnDstType = DnnDstType_; | |||
| static constexpr int row = 2; | |||
| static constexpr int col = 8; | |||
| static constexpr int size_nbits = 4; | |||
| static constexpr int col_in_type = col * size_nbits / (8 * sizeof(SrcType)); | |||
| static constexpr int elements_in_type = row * col_in_type; | |||
| static constexpr bool signedness = qtype_signedness<DnnSrcType>::value; | |||
| using Fragment = array_wrapper<SrcType, elements_in_type>; | |||
| static inline __device__ void trans( | |||
| Fragment& dst, const Fragment& src, | |||
| CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process, | |||
| const char zero_point) { | |||
| int intermediate[2][8]; | |||
| transform_b4x8_to_int8<signedness>( | |||
| intermediate[0], reinterpret_cast<const int&>(src[0])); | |||
| transform_b4x8_to_int8<signedness>( | |||
| intermediate[1], reinterpret_cast<const int&>(src[1])); | |||
| int* dst_frag = reinterpret_cast<int*>(&dst); | |||
| dst_frag[0] = transform_int8_to_b4x8<signedness>( | |||
| post_process(intermediate[0][0]), | |||
| post_process(intermediate[1][0]), | |||
| post_process(intermediate[0][1]), | |||
| post_process(intermediate[1][1]), | |||
| post_process(intermediate[0][2]), | |||
| post_process(intermediate[1][2]), | |||
| post_process(intermediate[0][3]), | |||
| post_process(intermediate[1][3])); | |||
| dst_frag[1] = transform_int8_to_b4x8<signedness>( | |||
| post_process(intermediate[0][4]), | |||
| post_process(intermediate[1][4]), | |||
| post_process(intermediate[0][5]), | |||
| post_process(intermediate[1][5]), | |||
| post_process(intermediate[0][6]), | |||
| post_process(intermediate[1][6]), | |||
| post_process(intermediate[0][7]), | |||
| post_process(intermediate[1][7])); | |||
| } | |||
| }; | |||
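Unlike the previous two cases, the two unpacked rows here are interleaved rather than concatenated: dst word w takes its k-th nibble from intermediate[k % 2][4 * w + k / 2]. The sketch below (hypothetical helper; again assuming the first argument of transform_int8_to_b4x8 becomes the lowest nibble) states that index map explicitly:

```cpp
#include <cstdint>

// Hypothetical restatement of the packing order used by the <8, 2>
// specialization above: a 2x8 -> 8x2 nibble transpose with the two
// source rows interleaved pairwise in the output words.
void pack_2x8_interleaved(const int (&intermediate)[2][8], uint32_t (&dst)[2]) {
    for (int w = 0; w < 2; ++w) {
        dst[w] = 0;
        for (int k = 0; k < 8; ++k)
            dst[w] |= (uint32_t(intermediate[k % 2][4 * w + k / 2]) & 0xFu)
                      << (4 * k);
    }
}
```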
| } // namespace internal | |||
| } // namespace relayout_format | |||
| } // namespace cuda | |||
| } // namespace megdnn | |||
| @@ -176,60 +176,22 @@ __global__ void kern_general_nchw4(SrcVisitor src, const float* __restrict mat, | |||
| } | |||
| } | |||
| template <bool signedness> | |||
| MEGDNN_DEVICE __forceinline__ int transform_int8_to_bit4x8(int s0, int s1, | |||
| int s2, int s3, | |||
| int s4, int s5, | |||
| int s6, int s7); | |||
| template <> | |||
| MEGDNN_DEVICE __forceinline__ int transform_int8_to_bit4x8<true>( | |||
| int s0, int s1, int s2, int s3, int s4, int s5, int s6, int s7) { | |||
| return transform_int8_to_int4x8(s0, s1, s2, s3, s4, s5, s6, s7); | |||
| } | |||
| template <> | |||
| MEGDNN_DEVICE __forceinline__ int transform_int8_to_bit4x8<false>( | |||
| int s0, int s1, int s2, int s3, int s4, int s5, int s6, int s7) { | |||
| return transform_int8_to_uint4x8(s0, s1, s2, s3, s4, s5, s6, s7); | |||
| } | |||
| template <bool signedness> | |||
| MEGDNN_DEVICE __forceinline__ void | |||
| transform_bit4x8_to_int8(int (&result)[8], const int& source); | |||
| template <> | |||
| MEGDNN_DEVICE __forceinline__ void | |||
| transform_bit4x8_to_int8<true>(int (&result)[8], const int& source) { | |||
| transform_int4x8_to_int8(result, source); | |||
| } | |||
| template <> | |||
| MEGDNN_DEVICE __forceinline__ void | |||
| transform_bit4x8_to_int8<false>(int (&result)[8], const int& source) { | |||
| transform_uint4x8_to_int8(result, source); | |||
| } | |||
| template <bool signedness, typename OutputConverter> | |||
| MEGDNN_DEVICE __forceinline__ int pack_output_func( | |||
| OutputConverter& output_converter, int (&s00)[8], int (&s01)[8], | |||
| int (&s10)[8], int (&s11)[8], float w00, float w01, float w10, | |||
| float w11) { | |||
| #define warp_perspective_transform(idx) \ | |||
| static_cast<int>(output_converter(s00[idx] * w00 + \ | |||
| s01[idx] * w01 + \ | |||
| s10[idx] * w10 + \ | |||
| s11[idx] * w11) \ | |||
| #define warp_perspective_transform(idx) \ | |||
| static_cast<int>(output_converter(s00[idx] * w00 + s01[idx] * w01 + \ | |||
| s10[idx] * w10 + s11[idx] * w11) \ | |||
| .as_storage()) | |||
| return transform_int8_to_bit4x8<signedness>( | |||
| return transform_int8_to_b4x8<signedness>( | |||
| warp_perspective_transform(0), warp_perspective_transform(1), | |||
| warp_perspective_transform(2), warp_perspective_transform(3), | |||
| warp_perspective_transform(4), warp_perspective_transform(5), | |||
| warp_perspective_transform(6), warp_perspective_transform(7)); | |||
| #undef warp_perspective_transform | |||
| } | |||
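pack_output_func is the bilinear blend: for each of the eight sub-byte channels it combines the four neighbouring samples with weights w00..w11 (which sum to 1 for bilinear interpolation), converts the result, and repacks all eight into one word. A host-side model, assuming OutputConverter rounds to nearest and saturates to the qint4 range [-8, 7] (the real converter also folds in the quantization scale):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// Blend four unpacked 8-channel groups and repack the results as nibbles.
static uint32_t blend_and_pack(const int (&s00)[8], const int (&s01)[8],
                               const int (&s10)[8], const int (&s11)[8],
                               float w00, float w01, float w10, float w11) {
    uint32_t out = 0;
    for (int c = 0; c < 8; ++c) {
        float v = s00[c] * w00 + s01[c] * w01 + s10[c] * w10 + s11[c] * w11;
        int q = std::min(7, std::max(-8, (int)std::nearbyint(v)));  // qint4 range
        out |= (uint32_t(q) & 0xFu) << (4 * c);  // channel c -> nibble c
    }
    return out;
}
```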
| template <typename ctype, typename Getter, typename SrcVisitor, | |||
| @@ -278,31 +240,31 @@ __global__ void kern_general_nchw64(SrcVisitor src, const float* __restrict mat, | |||
| s[2] = __ldg(sptr_int4 + i_coor_10 + c1); | |||
| s[3] = __ldg(sptr_int4 + i_coor_11 + c1); | |||
| transform_bit4x8_to_int8<signedness>(s00, s[0].x); | |||
| transform_bit4x8_to_int8<signedness>(s01, s[1].x); | |||
| transform_bit4x8_to_int8<signedness>(s10, s[2].x); | |||
| transform_bit4x8_to_int8<signedness>(s11, s[3].x); | |||
| transform_b4x8_to_int8<signedness>(s00, s[0].x); | |||
| transform_b4x8_to_int8<signedness>(s01, s[1].x); | |||
| transform_b4x8_to_int8<signedness>(s10, s[2].x); | |||
| transform_b4x8_to_int8<signedness>(s11, s[3].x); | |||
| d.x = pack_output_func<signedness>(output_converter, s00, s01, s10, | |||
| s11, w00, w01, w10, w11); | |||
| transform_bit4x8_to_int8<signedness>(s00, s[0].y); | |||
| transform_bit4x8_to_int8<signedness>(s01, s[1].y); | |||
| transform_bit4x8_to_int8<signedness>(s10, s[2].y); | |||
| transform_bit4x8_to_int8<signedness>(s11, s[3].y); | |||
| transform_b4x8_to_int8<signedness>(s00, s[0].y); | |||
| transform_b4x8_to_int8<signedness>(s01, s[1].y); | |||
| transform_b4x8_to_int8<signedness>(s10, s[2].y); | |||
| transform_b4x8_to_int8<signedness>(s11, s[3].y); | |||
| d.y = pack_output_func<signedness>(output_converter, s00, s01, s10, | |||
| s11, w00, w01, w10, w11); | |||
| transform_bit4x8_to_int8<signedness>(s00, s[0].z); | |||
| transform_bit4x8_to_int8<signedness>(s01, s[1].z); | |||
| transform_bit4x8_to_int8<signedness>(s10, s[2].z); | |||
| transform_bit4x8_to_int8<signedness>(s11, s[3].z); | |||
| transform_b4x8_to_int8<signedness>(s00, s[0].z); | |||
| transform_b4x8_to_int8<signedness>(s01, s[1].z); | |||
| transform_b4x8_to_int8<signedness>(s10, s[2].z); | |||
| transform_b4x8_to_int8<signedness>(s11, s[3].z); | |||
| d.z = pack_output_func<signedness>(output_converter, s00, s01, s10, | |||
| s11, w00, w01, w10, w11); | |||
| transform_bit4x8_to_int8<signedness>(s00, s[0].w); | |||
| transform_bit4x8_to_int8<signedness>(s01, s[1].w); | |||
| transform_bit4x8_to_int8<signedness>(s10, s[2].w); | |||
| transform_bit4x8_to_int8<signedness>(s11, s[3].w); | |||
| transform_b4x8_to_int8<signedness>(s00, s[0].w); | |||
| transform_b4x8_to_int8<signedness>(s01, s[1].w); | |||
| transform_b4x8_to_int8<signedness>(s10, s[2].w); | |||
| transform_b4x8_to_int8<signedness>(s11, s[3].w); | |||
| d.w = pack_output_func<signedness>(output_converter, s00, s01, s10, | |||
| s11, w00, w01, w10, w11); | |||
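The four unrolled blocks above each process one 32-bit lane of the 128-bit vector load: an int4 carries 4 x 8 = 32 4-bit channels, and each lane is unpacked, blended, and repacked independently. A small host illustration of that layout (the struct is a stand-in for CUDA's int4):

```cpp
#include <cstdint>
#include <cstdio>

struct Int4 { uint32_t x, y, z, w; };  // host stand-in for CUDA's int4

int main() {
    Int4 v = {0x76543210u, 0xfedcba98u, 0x10325476u, 0x98badcfeu};
    uint32_t lane[4] = {v.x, v.y, v.z, v.w};
    for (int l = 0; l < 4; ++l)      // lanes .x/.y/.z/.w
        for (int c = 0; c < 8; ++c)  // eight nibbles per lane
            printf("channel %2d = %x\n", 8 * l + c, (lane[l] >> (4 * c)) & 0xFu);
}
```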
| @@ -403,15 +365,7 @@ __global__ void kern_const_border_nchw4(SrcVisitor src, | |||
| } | |||
| } | |||
| } | |||
| template <bool signedness> | |||
| MEGDNN_DEVICE __forceinline__ static void transform_bit4x8_to_int8( | |||
| int (&result)[8], const int& source) { | |||
| #pragma unroll | |||
| for (int i = 0; i < 8; i++) { | |||
| result[i] = unpack_integer_4bits<signedness>( | |||
| reinterpret_cast<unsigned const&>(source), (i << 2)); | |||
| } | |||
| } | |||
| template <typename ctype, typename SrcVisitor, typename OutputConverter> | |||
| __global__ void kern_const_border_nchw64(SrcVisitor src, | |||
| @@ -457,7 +411,7 @@ __global__ void kern_const_border_nchw64(SrcVisitor src, | |||
| bool flag00 = okh0 && okw0, flag01 = okh0 && okw1, | |||
| flag10 = okh1 && okw0, flag11 = okh1 && okw1; | |||
| int8_t bval_4 = bval.as_storage() & 0xF; | |||
| int bval_8 = transform_int8_to_bit4x8<signedness>( | |||
| int bval_8 = transform_int8_to_b4x8<signedness>( | |||
| bval_4, bval_4, bval_4, bval_4, bval_4, bval_4, bval_4, bval_4); | |||
| int4 bval_int4; | |||
| bval_int4.x = bval_8; | |||
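For the constant-border path the border value must also exist in packed form: the 4-bit value is replicated into all eight nibbles of one word via transform_int8_to_b4x8, and that word then fills the lanes of bval_int4 (only the .x assignment is visible in this hunk), so a single 128-bit constant stands in for the 32 channels of an out-of-bounds access. A standalone sketch of the nibble broadcast (illustrative name, plain bit tricks instead of the packing helper):

```cpp
#include <cstdint>

// Replicate the low 4 bits of bval_4 into all eight nibbles of a word:
// 0x0000000b -> 0xbbbbbbbb.
uint32_t broadcast_nibble(int8_t bval_4) {
    uint32_t w = uint32_t(bval_4) & 0xFu;
    w |= w << 4;
    w |= w << 8;
    w |= w << 16;
    return w;
}
```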
| @@ -488,31 +442,31 @@ __global__ void kern_const_border_nchw64(SrcVisitor src, | |||
| s[3] = bval_int4; | |||
| } | |||
| transform_bit4x8_to_int8<signedness>(s00, s[0].x); | |||
| transform_bit4x8_to_int8<signedness>(s01, s[1].x); | |||
| transform_bit4x8_to_int8<signedness>(s10, s[2].x); | |||
| transform_bit4x8_to_int8<signedness>(s11, s[3].x); | |||
| transform_b4x8_to_int8<signedness>(s00, s[0].x); | |||
| transform_b4x8_to_int8<signedness>(s01, s[1].x); | |||
| transform_b4x8_to_int8<signedness>(s10, s[2].x); | |||
| transform_b4x8_to_int8<signedness>(s11, s[3].x); | |||
| d.x = pack_output_func<signedness>(output_converter, s00, s01, s10, | |||
| s11, w00, w01, w10, w11); | |||
| transform_bit4x8_to_int8<signedness>(s00, s[0].y); | |||
| transform_bit4x8_to_int8<signedness>(s01, s[1].y); | |||
| transform_bit4x8_to_int8<signedness>(s10, s[2].y); | |||
| transform_bit4x8_to_int8<signedness>(s11, s[3].y); | |||
| transform_b4x8_to_int8<signedness>(s00, s[0].y); | |||
| transform_b4x8_to_int8<signedness>(s01, s[1].y); | |||
| transform_b4x8_to_int8<signedness>(s10, s[2].y); | |||
| transform_b4x8_to_int8<signedness>(s11, s[3].y); | |||
| d.y = pack_output_func<signedness>(output_converter, s00, s01, s10, | |||
| s11, w00, w01, w10, w11); | |||
| transform_bit4x8_to_int8<signedness>(s00, s[0].z); | |||
| transform_bit4x8_to_int8<signedness>(s01, s[1].z); | |||
| transform_bit4x8_to_int8<signedness>(s10, s[2].z); | |||
| transform_bit4x8_to_int8<signedness>(s11, s[3].z); | |||
| transform_b4x8_to_int8<signedness>(s00, s[0].z); | |||
| transform_b4x8_to_int8<signedness>(s01, s[1].z); | |||
| transform_b4x8_to_int8<signedness>(s10, s[2].z); | |||
| transform_b4x8_to_int8<signedness>(s11, s[3].z); | |||
| d.z = pack_output_func<signedness>(output_converter, s00, s01, s10, | |||
| s11, w00, w01, w10, w11); | |||
| transform_bit4x8_to_int8<signedness>(s00, s[0].w); | |||
| transform_bit4x8_to_int8<signedness>(s01, s[1].w); | |||
| transform_bit4x8_to_int8<signedness>(s10, s[2].w); | |||
| transform_bit4x8_to_int8<signedness>(s11, s[3].w); | |||
| transform_b4x8_to_int8<signedness>(s00, s[0].w); | |||
| transform_b4x8_to_int8<signedness>(s01, s[1].w); | |||
| transform_b4x8_to_int8<signedness>(s10, s[2].w); | |||
| transform_b4x8_to_int8<signedness>(s11, s[3].w); | |||
| d.w = pack_output_func<signedness>(output_converter, s00, s01, s10, | |||
| s11, w00, w01, w10, w11); | |||