| @@ -35,7 +35,9 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec( | |||||
| const TensorLayout& bias, const TensorLayout& z, | const TensorLayout& bias, const TensorLayout& z, | ||||
| const TensorLayout& dst, size_t workspace_in_bytes, | const TensorLayout& dst, size_t workspace_in_bytes, | ||||
| const PreprocessedFilter* preprocessed_filter) { | const PreprocessedFilter* preprocessed_filter) { | ||||
| megdnn_assert(src.dtype.enumv() == filter.dtype.enumv()); | |||||
| megdnn_assert((src.dtype.enumv() == filter.dtype.enumv()) || | |||||
| (src.dtype.enumv() == DTypeEnum::Quantized4Asymm && | |||||
| filter.dtype.enumv() == DTypeEnum::QuantizedS4)); | |||||
| // check compatibility of bias's scale | // check compatibility of bias's scale | ||||
| if (src.dtype.category() == DTypeCategory::QUANTIZED) { | if (src.dtype.category() == DTypeCategory::QUANTIZED) { | ||||
| if (bias.dtype.enumv() == DTypeEnum::QuantizedS32) { | if (bias.dtype.enumv() == DTypeEnum::QuantizedS32) { | ||||
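The relaxed assert above (and the matching one in `deduce_layout_fwd` in the next hunk) allows a `Quantized4Asymm` activation tensor to be paired with a `QuantizedS4` filter instead of requiring identical dtypes. A minimal sketch of the same predicate as a standalone helper (the helper name is hypothetical, not part of the patch):

```cpp
#include "megdnn/dtype.h"

// Hypothetical helper mirroring the relaxed dtype check: src and filter
// must share a dtype, except that an unsigned 4-bit activation may be
// convolved with a signed 4-bit filter.
static inline bool conv_src_filter_dtype_ok(megdnn::DTypeEnum src,
                                            megdnn::DTypeEnum filter) {
    using megdnn::DTypeEnum;
    return src == filter ||
           (src == DTypeEnum::Quantized4Asymm &&
            filter == DTypeEnum::QuantizedS4);
}
```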
| @@ -598,8 +598,10 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src, | |||||
| megdnn_assert_contiguous(src); | megdnn_assert_contiguous(src); | ||||
| megdnn_assert_contiguous(filter); | megdnn_assert_contiguous(filter); | ||||
| megdnn_assert(src.ndim >= 3_z, "%s", errmsg().c_str()); | megdnn_assert(src.ndim >= 3_z, "%s", errmsg().c_str()); | ||||
| megdnn_assert(src.dtype.enumv() == filter.dtype.enumv(), "%s", | |||||
| errmsg().c_str()); | |||||
| megdnn_assert(((src.dtype.enumv() == filter.dtype.enumv()) || | |||||
| (src.dtype.enumv() == DTypeEnum::Quantized4Asymm && | |||||
| filter.dtype.enumv() == DTypeEnum::QuantizedS4)), | |||||
| "%s", errmsg().c_str()); | |||||
| check_or_deduce_dtype_fwd(src.dtype, filter.dtype, dst.dtype); | check_or_deduce_dtype_fwd(src.dtype, filter.dtype, dst.dtype); | ||||
| size_t img_dim; | size_t img_dim; | ||||
| if (param().format == Param::Format::NCHW || | if (param().format == Param::Format::NCHW || | ||||
| @@ -488,6 +488,10 @@ void LowbitsAlignedTensorFormatBase::assert_valid( | |||||
| "bad stride:%s, %zu", layout.to_string().c_str(), | "bad stride:%s, %zu", layout.to_string().c_str(), | ||||
| layout.stride[i]); | layout.stride[i]); | ||||
| } | } | ||||
| if (!has_dim_unity_stride && | |||||
| (int)layout.stride[layout.ndim - 1] == | |||||
| round_up(1, (int)m_align_size_in_elements)) | |||||
| has_dim_unity_stride = true; | |||||
| megdnn_assert(layout.ndim == 0 || has_dim_unity_stride, | megdnn_assert(layout.ndim == 0 || has_dim_unity_stride, | ||||
| "innermost dim not contiguous"); | "innermost dim not contiguous"); | ||||
| } | } | ||||
| @@ -546,7 +550,12 @@ bool LowbitsAlignedTensorFormatBase::is_contiguous_spec( | |||||
| assert_valid(layout); | assert_valid(layout); | ||||
| ptrdiff_t expected = 1; | ptrdiff_t expected = 1; | ||||
| for (int i = static_cast<int>(layout.ndim) - 1; i >= 0; --i) { | for (int i = static_cast<int>(layout.ndim) - 1; i >= 0; --i) { | ||||
| if (layout.shape[i] != 1 && layout.stride[i] != expected) | |||||
| bool is_valid_stride = | |||||
| (layout.stride[i] == expected) || | |||||
| (expected == 1 && | |||||
| (int)layout.stride[i] == | |||||
| round_up(1, (int)m_align_size_in_elements)); | |||||
| if (layout.shape[i] != 1 && !is_valid_stride) | |||||
| return false; | return false; | ||||
| auto multiplier = layout.shape[i]; | auto multiplier = layout.shape[i]; | ||||
| if (i == static_cast<int>(layout.ndim) - 1) | if (i == static_cast<int>(layout.ndim) - 1) | ||||
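Both low-bit layout changes above treat an innermost stride of `round_up(1, m_align_size_in_elements)` as equivalent to a unity stride: for sub-byte dtypes the innermost dimension may be padded up to the alignment unit while the values inside each byte stay contiguous. A small illustration of the relaxed rule (plain C++ with a local `round_up` stand-in, not MegDNN's implementation):

```cpp
#include <cstddef>
#include <cstdio>

// Local stand-in for megdnn's round_up (assumption: same semantics).
static size_t round_up(size_t x, size_t align) {
    return (x + align - 1) / align * align;
}

// Relaxed "innermost dim is contiguous" rule sketched from the diff:
// stride 1 is accepted as before, and so is round_up(1, align), which is
// what a 4-bit tensor padded to the alignment unit reports.
static bool innermost_ok(ptrdiff_t innermost_stride, size_t align) {
    return innermost_stride == 1 ||
           innermost_stride == static_cast<ptrdiff_t>(round_up(1, align));
}

int main() {
    std::printf("%d %d\n", innermost_ok(1, 2), innermost_ok(2, 2));  // 1 1
    std::printf("%d\n", innermost_ok(3, 2));                         // 0
    return 0;
}
```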
| @@ -568,7 +577,7 @@ TensorLayout LowbitsAlignedTensorFormatBase::collapse_contiguous_spec( | |||||
| res.stride[0] = 1; | res.stride[0] = 1; | ||||
| return res; | return res; | ||||
| } | } | ||||
| if (res.shape[i] == 1 && res.stride[i] != 1) { | |||||
| if (res.shape[i] == 1) { | |||||
| res.remove_axis_inplace(i); | res.remove_axis_inplace(i); | ||||
| } | } | ||||
| } | } | ||||
| @@ -232,6 +232,7 @@ float megdnn::mul_scale(DType lhs, DType rhs) { | |||||
| (rhs.enumv() == DTypeTrait<dt2>::enumv)) \ | (rhs.enumv() == DTypeTrait<dt2>::enumv)) \ | ||||
| return lhs.param<dt1>().scale * rhs.param<dt2>().scale; | return lhs.param<dt1>().scale * rhs.param<dt2>().scale; | ||||
| cb_binary(::megdnn::dtype::QuantizedS8, ::megdnn::dtype::QuantizedS16) | cb_binary(::megdnn::dtype::QuantizedS8, ::megdnn::dtype::QuantizedS16) | ||||
| cb_binary(::megdnn::dtype::Quantized4Asymm, ::megdnn::dtype::QuantizedS4) | |||||
| #undef cb_binary | #undef cb_binary | ||||
| megdnn_assert(lhs.enumv() == rhs.enumv()); | megdnn_assert(lhs.enumv() == rhs.enumv()); | ||||
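The new `cb_binary` entry lets `mul_scale` combine a `Quantized4Asymm` lhs with a `QuantizedS4` rhs; only the scales multiply, the asymmetric zero point plays no role here. A small numeric illustration (the scale values are made up for the example):

```cpp
// Example scales only; not taken from the patch.
const float src_scale = 0.25f;  // dtype::Quantized4Asymm(0.25f, /*zero_point=*/8)
const float flt_scale = 0.5f;   // dtype::QuantizedS4(0.5f)
const float acc_scale = src_scale * flt_scale;  // 0.125f, what mul_scale() returns
```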
| @@ -66,7 +66,8 @@ public: | |||||
| CUDA_IMPLICIT_GEMM_1X1_SASS_NCHW4_DOTPROD_INT8, | CUDA_IMPLICIT_GEMM_1X1_SASS_NCHW4_DOTPROD_INT8, | ||||
| CUDA_IMPLICIT_GEMM_SASS_NCHW32_IMMA_INT8, | CUDA_IMPLICIT_GEMM_SASS_NCHW32_IMMA_INT8, | ||||
| CUDA_IMPLICIT_GEMM_1X1_SASS_NCHW32_IMMA_INT8, | CUDA_IMPLICIT_GEMM_1X1_SASS_NCHW32_IMMA_INT8, | ||||
| CUDA_IMPLICIT_GEMM_SASS_NCHW64_IMMA_INT4, | |||||
| CUDA_IMPLICIT_GEMM_SASS_NCHW64_IMMA_INT4_INT4, | |||||
| CUDA_IMPLICIT_GEMM_SASS_NCHW64_IMMA_UINT4_INT4, | |||||
| }; | }; | ||||
| using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | ||||
| @@ -15,7 +15,7 @@ | |||||
| #include "./quint4x4x32_wmma/activation_u4.cuh" | #include "./quint4x4x32_wmma/activation_u4.cuh" | ||||
| #include "./quint4x4x32_wmma/reduce_with_scale_data.cuh" | #include "./quint4x4x32_wmma/reduce_with_scale_data.cuh" | ||||
| #include "./quint4x4x32_wmma/reduce_with_scale_filter.cuh" | |||||
| #include "./reduce_with_scale_filter.cuh" | |||||
| #include "./quint4x4x32_wmma/wmma_conv_integer_u4.cuh" | #include "./quint4x4x32_wmma/wmma_conv_integer_u4.cuh" | ||||
| using namespace megdnn; | using namespace megdnn; | ||||
| @@ -75,7 +75,7 @@ WorkspaceBundle ConvBiasForwardImpl::AlgoQUInt4x4x32WMMA::get_workspace_bundle( | |||||
| // for reduce filter | // for reduce filter | ||||
| { | { | ||||
| size_t A = OC, B = IC * FH * FW / 8, C = 1; | size_t A = OC, B = IC * FH * FW / 8, C = 1; | ||||
| ws_size_zp_filter += _do_dispatch_reduce_workspace_in_bytes(A, B, C); | |||||
| ws_size_zp_filter += do_dispatch_reduce_workspace_in_bytes(A, B, C); | |||||
| } | } | ||||
| size_t ws_size_zp_data = N * OH * OW * sizeof(int32_t); | size_t ws_size_zp_data = N * OH * OW * sizeof(int32_t); | ||||
| size_t ws_size_relayout_filter = get_workspace_in_bytes_do_conv(args); | size_t ws_size_relayout_filter = get_workspace_in_bytes_do_conv(args); | ||||
| @@ -135,11 +135,11 @@ void ConvBiasForwardImpl::AlgoQUInt4x4x32WMMA::exec( | |||||
| int32_t zp_data_filter = zp_data * zp_filter * FH * FW * IC; | int32_t zp_data_filter = zp_data * zp_filter * FH * FW * IC; | ||||
| auto&& stream = cuda_stream(handle); | auto&& stream = cuda_stream(handle); | ||||
| // zp filter | // zp filter | ||||
| _do_dispatch_reduce_with_scale_filter_u4( | |||||
| do_dispatch_reduce_with_scale_filter_4bit<false>( | |||||
| static_cast<uint8_t*>(args.filter_tensor->raw_ptr), -zp_data, OC, | static_cast<uint8_t*>(args.filter_tensor->raw_ptr), -zp_data, OC, | ||||
| FH * FW * IC / 8, ws_zp_filter.ptr<int32_t>(), stream); | FH * FW * IC / 8, ws_zp_filter.ptr<int32_t>(), stream); | ||||
| // zp data | // zp data | ||||
| _do_dispatch_reduce_with_scale_data_u4( | |||||
| do_dispatch_reduce_with_scale_data_u4( | |||||
| ws_zp_data.ptr<int32_t>(), | ws_zp_data.ptr<int32_t>(), | ||||
| static_cast<uint8_t*>(args.src_tensor->raw_ptr), N, IH, IW, OH, OW, | static_cast<uint8_t*>(args.src_tensor->raw_ptr), N, IH, IW, OH, OW, | ||||
| PH, PW, FH, FW, SH, SW, IC, -zp_filter, | PH, PW, FH, FW, SH, SW, IC, -zp_filter, | ||||
| @@ -173,12 +173,12 @@ void ConvBiasForwardImpl::AlgoQUInt4x4x32WMMA::exec( | |||||
| args.bias_tensor->compatible_ptr<int32_t>(), s0, s1, s2, s3}; | args.bias_tensor->compatible_ptr<int32_t>(), s0, s1, s2, s3}; | ||||
| auto&& param = args.opr->param(); | auto&& param = args.opr->param(); | ||||
| if (param.nonlineMode == Param::NonlineMode::RELU) { | if (param.nonlineMode == Param::NonlineMode::RELU) { | ||||
| _do_dispatch_activation_u4<ActivationRELU>( | |||||
| do_dispatch_activation_u4<ActivationRELU>( | |||||
| args.dst_tensor->compatible_ptr<int32_t>(), visitor, | args.dst_tensor->compatible_ptr<int32_t>(), visitor, | ||||
| ws_zp_data.ptr<int32_t>(), ws_zp_filter.ptr<int32_t>(), | ws_zp_data.ptr<int32_t>(), ws_zp_filter.ptr<int32_t>(), | ||||
| zp_data_filter, N, OC, OH, OW, stream); | zp_data_filter, N, OC, OH, OW, stream); | ||||
| } else if (param.nonlineMode == Param::NonlineMode::IDENTITY) { | } else if (param.nonlineMode == Param::NonlineMode::IDENTITY) { | ||||
| _do_dispatch_activation_u4<ActivationIdentity>( | |||||
| do_dispatch_activation_u4<ActivationIdentity>( | |||||
| args.dst_tensor->compatible_ptr<int32_t>(), visitor, | args.dst_tensor->compatible_ptr<int32_t>(), visitor, | ||||
| ws_zp_data.ptr<int32_t>(), ws_zp_filter.ptr<int32_t>(), | ws_zp_data.ptr<int32_t>(), ws_zp_filter.ptr<int32_t>(), | ||||
| zp_data_filter, N, OC, OH, OW, stream); | zp_data_filter, N, OC, OH, OW, stream); | ||||
| @@ -87,11 +87,10 @@ __global__ void kern_activation_u4(int32_t* dst, const int32_t* zp_data, | |||||
| } // namespace | } // namespace | ||||
| template <typename ActivationOp> | template <typename ActivationOp> | ||||
| void _do_dispatch_activation_u4(int32_t* dst, BiasVisitor visitor, | |||||
| const int32_t* zp_data, | |||||
| const int32_t* zp_filter, | |||||
| int32_t zp_data_filter, int batch_size, int co, | |||||
| int ho, int wo, cudaStream_t stream) { | |||||
| void do_dispatch_activation_u4(int32_t* dst, BiasVisitor visitor, | |||||
| const int32_t* zp_data, const int32_t* zp_filter, | |||||
| int32_t zp_data_filter, int batch_size, int co, | |||||
| int ho, int wo, cudaStream_t stream) { | |||||
| void (*fptr)(int32_t*, const int32_t*, const int32_t*, int32_t, int, int OC, | void (*fptr)(int32_t*, const int32_t*, const int32_t*, int32_t, int, int OC, | ||||
| int, int, BiasVisitor) = kern_activation_u4<ActivationOp>; | int, int, BiasVisitor) = kern_activation_u4<ActivationOp>; | ||||
| dim3 grids{0, 0, 0}; | dim3 grids{0, 0, 0}; | ||||
| @@ -105,7 +104,7 @@ void _do_dispatch_activation_u4(int32_t* dst, BiasVisitor visitor, | |||||
| } | } | ||||
| #define INST(_op) \ | #define INST(_op) \ | ||||
| template void _do_dispatch_activation_u4<_op>( \ | |||||
| template void do_dispatch_activation_u4<_op>( \ | |||||
| int32_t * dst, BiasVisitor visitor, const int32_t* zp_data, \ | int32_t * dst, BiasVisitor visitor, const int32_t* zp_data, \ | ||||
| const int32_t* zp_filter, int32_t zp_data_filter, int batch_size, \ | const int32_t* zp_filter, int32_t zp_data_filter, int batch_size, \ | ||||
| int co, int ho, int wo, cudaStream_t stream); | int co, int ho, int wo, cudaStream_t stream); | ||||
| @@ -82,12 +82,10 @@ struct ActivationIdentity { | |||||
| } // namespace activation_u4 | } // namespace activation_u4 | ||||
| template <typename ActivationOp> | template <typename ActivationOp> | ||||
| void _do_dispatch_activation_u4(int32_t* dst, | |||||
| activation_u4::BiasVisitor visitor, | |||||
| const int32_t* zp_data, | |||||
| const int32_t* zp_filter, | |||||
| int32_t zp_data_filter, int batch_size, int co, | |||||
| int ho, int wo, cudaStream_t stream); | |||||
| void do_dispatch_activation_u4(int32_t* dst, activation_u4::BiasVisitor visitor, | |||||
| const int32_t* zp_data, const int32_t* zp_filter, | |||||
| int32_t zp_data_filter, int batch_size, int co, | |||||
| int ho, int wo, cudaStream_t stream); | |||||
| } // namespace cuda | } // namespace cuda | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| @@ -444,7 +444,7 @@ reduce_in_spatial_block_and_along_input_channel_with_scale_u4_large_channels( | |||||
| } // namespace | } // namespace | ||||
| void megdnn::cuda::_do_dispatch_reduce_with_scale_data_u4( | |||||
| void megdnn::cuda::do_dispatch_reduce_with_scale_data_u4( | |||||
| int32_t* dst, const uint8_t* src, int batch_size, int ih, int iw, | int32_t* dst, const uint8_t* src, int batch_size, int ih, int iw, | ||||
| int oh, int ow, int ph, int pw, int fh, int fw, int sh, int sw, int ic, | int oh, int ow, int ph, int pw, int fh, int fw, int sh, int sw, int ic, | ||||
| int32_t scale, uint8_t zp_data, cudaStream_t stream) { | int32_t scale, uint8_t zp_data, cudaStream_t stream) { | ||||
| @@ -37,7 +37,7 @@ | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace cuda { | namespace cuda { | ||||
| void _do_dispatch_reduce_with_scale_data_u4( | |||||
| void do_dispatch_reduce_with_scale_data_u4( | |||||
| int32_t* dst, const uint8_t* src, int batch_size, int ih, int iw, | int32_t* dst, const uint8_t* src, int batch_size, int ih, int iw, | ||||
| int oh, int ow, int ph, int pw, int fh, int fw, int sh, int sw, int ic, | int oh, int ow, int ph, int pw, int fh, int fw, int sh, int sw, int ic, | ||||
| int32_t scale, uint8_t zp_data, cudaStream_t stream); | int32_t scale, uint8_t zp_data, cudaStream_t stream); | ||||
| @@ -1,100 +0,0 @@ | |||||
| /*************************************************************************************************** | |||||
| * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. | |||||
| * | |||||
| * Redistribution and use in source and binary forms, with or without modification, are permitted | |||||
| * provided that the following conditions are met: | |||||
| * * Redistributions of source code must retain the above copyright notice, this list of | |||||
| * conditions and the following disclaimer. | |||||
| * * Redistributions in binary form must reproduce the above copyright notice, this list of | |||||
| * conditions and the following disclaimer in the documentation and/or other materials | |||||
| * provided with the distribution. | |||||
| * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used | |||||
| * to endorse or promote products derived from this software without specific prior written | |||||
| * permission. | |||||
| * | |||||
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR | |||||
| * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND | |||||
| * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE | |||||
| * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, | |||||
| * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; | |||||
| * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, | |||||
| * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | |||||
| * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||||
| * | |||||
| **************************************************************************************************/ | |||||
| /** | |||||
| * \file dnn/src/cuda/conv_bias/quint4x4x32_wmma/reduce_with_scale_filter.cu | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #include "./reduce_with_scale_filter.cuh" | |||||
| #include "src/cuda/reduce_helper.cuh" | |||||
| using namespace megdnn; | |||||
| using namespace cuda; | |||||
| namespace { | |||||
| struct ReduceWithScaleUInt4Op { | |||||
| typedef int32_t wtype; | |||||
| const uint8_t* src; | |||||
| int32_t* dst; | |||||
| int32_t scale; | |||||
| static const wtype INIT = 0; | |||||
| #if MEGDNN_CC_CUDA | |||||
| __host__ __device__ void write(uint32_t idx, wtype val) { | |||||
| dst[idx] = val * scale; | |||||
| } | |||||
| __host__ __device__ static wtype apply(wtype a, wtype b) { return a + b; } | |||||
| __device__ wtype read(uint32_t idx) { | |||||
| constexpr uint32_t subbytes_per_pixel = 8; | |||||
| const uint32_t* sptr = | |||||
| (const uint32_t*)(src + subbytes_per_pixel * idx / 2); | |||||
| uint32_t val = *sptr; | |||||
| int32_t ret = 0; | |||||
| #pragma unroll | |||||
| for (int j = 0; j < 8; j++) { | |||||
| uint8_t cur = (val & 0xF); | |||||
| ret += cur; | |||||
| val = (val >> 4); | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| #endif | |||||
| }; | |||||
| } // namespace | |||||
| void megdnn::cuda::_do_dispatch_reduce_with_scale_filter_u4( | |||||
| const uint8_t* src, int32_t scale, uint32_t rows, uint32_t cols, | |||||
| int32_t* dst, cudaStream_t stream) { | |||||
| // rows = OC | |||||
| // cols is measured in pixels, i.e. IC * FH * FW / 8; a pixel consists of 8 | |||||
| // subbyte values | |||||
| ReduceWithScaleUInt4Op op; | |||||
| op.src = src; | |||||
| op.scale = scale; | |||||
| op.dst = dst; | |||||
| static_cast<void>(op); | |||||
| static_cast<void>(stream); | |||||
| static_cast<void>(rows); | |||||
| static_cast<void>(cols); | |||||
| run_reduce<ReduceWithScaleUInt4Op, false>(dst + rows, rows, cols, 1, stream, | |||||
| op); | |||||
| } | |||||
| size_t megdnn::cuda::_do_dispatch_reduce_workspace_in_bytes(size_t A, size_t B, | |||||
| size_t C) { | |||||
| return get_reduce_workspace_in_bytes<ReduceWithScaleUInt4Op>(A, B, C); | |||||
| } | |||||
| // vim: ft=cpp syntax=cuda.doxygen | |||||
| @@ -1,48 +0,0 @@ | |||||
| /*************************************************************************************************** | |||||
| * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. | |||||
| * | |||||
| * Redistribution and use in source and binary forms, with or without modification, are permitted | |||||
| * provided that the following conditions are met: | |||||
| * * Redistributions of source code must retain the above copyright notice, this list of | |||||
| * conditions and the following disclaimer. | |||||
| * * Redistributions in binary form must reproduce the above copyright notice, this list of | |||||
| * conditions and the following disclaimer in the documentation and/or other materials | |||||
| * provided with the distribution. | |||||
| * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used | |||||
| * to endorse or promote products derived from this software without specific prior written | |||||
| * permission. | |||||
| * | |||||
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR | |||||
| * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND | |||||
| * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE | |||||
| * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, | |||||
| * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; | |||||
| * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, | |||||
| * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | |||||
| * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||||
| * | |||||
| **************************************************************************************************/ | |||||
| /** | |||||
| * \file dnn/src/cuda/conv_bias/quint4x4x32_wmma/reduce_with_scale_filter.cuh | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #include "src/cuda/utils.cuh" | |||||
| namespace megdnn { | |||||
| namespace cuda { | |||||
| void _do_dispatch_reduce_with_scale_filter_u4(const uint8_t* src, int32_t scale, | |||||
| uint32_t rows, uint32_t cols, | |||||
| int32_t* dst, | |||||
| cudaStream_t stream); | |||||
| size_t _do_dispatch_reduce_workspace_in_bytes(size_t A, size_t B, size_t C); | |||||
| } // namespace cuda | |||||
| } // namespace megdnn | |||||
| // vim: ft=cpp syntax=cuda.doxygen | |||||
| @@ -0,0 +1,114 @@ | |||||
| /*************************************************************************************************** | |||||
| * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. | |||||
| * | |||||
| * Redistribution and use in source and binary forms, with or without | |||||
| *modification, are permitted provided that the following conditions are met: | |||||
| * * Redistributions of source code must retain the above copyright notice, | |||||
| *this list of conditions and the following disclaimer. | |||||
| * * Redistributions in binary form must reproduce the above copyright | |||||
| *notice, this list of conditions and the following disclaimer in the | |||||
| *documentation and/or other materials provided with the distribution. | |||||
| * * Neither the name of the NVIDIA CORPORATION nor the names of its | |||||
| *contributors may be used to endorse or promote products derived from this | |||||
| *software without specific prior written permission. | |||||
| * | |||||
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
| *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
| *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |||||
| *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, | |||||
| *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, | |||||
| * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | |||||
| *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY | |||||
| *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING | |||||
| *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, | |||||
| *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||||
| * | |||||
| **************************************************************************************************/ | |||||
| /** | |||||
| * \file dnn/src/cuda/conv_bias/reduce_with_scale_filter.cu | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "./reduce_with_scale_filter.cuh" | |||||
| #include "src/cuda/reduce_helper.cuh" | |||||
| #include "src/cuda/integer_subbyte_utils.cuh" | |||||
| using namespace megdnn; | |||||
| using namespace cuda; | |||||
| namespace { | |||||
| template <bool signedness> | |||||
| struct ReduceWithScaleInt4Op { | |||||
| typedef int32_t wtype; | |||||
| const uint8_t* src; | |||||
| int32_t* dst; | |||||
| int32_t scale; | |||||
| static const wtype INIT = 0; | |||||
| #if MEGDNN_CC_CUDA | |||||
| __host__ __device__ void write(uint32_t idx, wtype val) { | |||||
| dst[idx] = val * scale; | |||||
| } | |||||
| __host__ __device__ static wtype apply(wtype a, wtype b) { return a + b; } | |||||
| __device__ wtype read(uint32_t idx) { | |||||
| constexpr uint32_t subbytes_per_pixel = 8; | |||||
| const uint32_t* sptr = | |||||
| (const uint32_t*)(src + subbytes_per_pixel * idx / 2); | |||||
| uint32_t val = *sptr; | |||||
| int32_t ret = 0; | |||||
| #pragma unroll | |||||
| for (int j = 0; j < 8; j++) { | |||||
| ret += integer_subbyte::unpack_integer_4bits<signedness>(val, | |||||
| (j << 2)); | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| #endif | |||||
| }; | |||||
| } // namespace | |||||
| template <bool signedness> | |||||
| void megdnn::cuda::do_dispatch_reduce_with_scale_filter_4bit( | |||||
| const uint8_t* src, int32_t scale, uint32_t rows, uint32_t cols, | |||||
| int32_t* dst, cudaStream_t stream) { | |||||
| // rows = OC | |||||
| // cols is measured in pixels, i.e. IC * FH * FW / 8; a pixel consists of 8 | |||||
| // subbyte values | |||||
| ReduceWithScaleInt4Op<signedness> op; | |||||
| op.src = src; | |||||
| op.scale = scale; | |||||
| op.dst = dst; | |||||
| static_cast<void>(op); | |||||
| static_cast<void>(stream); | |||||
| static_cast<void>(rows); | |||||
| static_cast<void>(cols); | |||||
| run_reduce<ReduceWithScaleInt4Op<signedness>, false>(dst + rows, rows, cols, | |||||
| 1, stream, op); | |||||
| } | |||||
| #define INST(signedness) \ | |||||
| template void \ | |||||
| megdnn::cuda::do_dispatch_reduce_with_scale_filter_4bit<signedness>( \ | |||||
| const uint8_t* src, int32_t scale, uint32_t rows, uint32_t cols, \ | |||||
| int32_t* dst, cudaStream_t stream) | |||||
| INST(false); | |||||
| INST(true); | |||||
| #undef INST | |||||
| size_t megdnn::cuda::do_dispatch_reduce_workspace_in_bytes(size_t A, size_t B, | |||||
| size_t C) { | |||||
| return get_reduce_workspace_in_bytes<ReduceWithScaleInt4Op<false>>(A, B, C); | |||||
| } | |||||
| // vim: ft=cpp syntax=cuda.doxygen | |||||
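The new `ReduceWithScaleInt4Op::read()` loads one 32-bit word (8 packed 4-bit values, low nibble first) and accumulates the unpacked values via `integer_subbyte::unpack_integer_4bits<signedness>`. A host-side sketch of the equivalent unpacking, handy for checking the kernel on the CPU (written from the packing convention visible above, not from the device helper itself):

```cpp
#include <cstdint>

// Sum the 8 packed 4-bit values in one 32-bit word, low nibble first.
// signedness == true treats each nibble as two's-complement int4,
// false treats it as uint4 (matching ReduceWithScaleInt4Op<signedness>).
template <bool signedness>
int32_t sum_packed_4bit(uint32_t word) {
    int32_t sum = 0;
    for (int j = 0; j < 8; ++j) {
        int32_t v = static_cast<int32_t>((word >> (4 * j)) & 0xF);
        if (signedness && v >= 8)
            v -= 16;  // sign-extend the 4-bit value
        sum += v;
    }
    return sum;
}
```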
| @@ -0,0 +1,52 @@ | |||||
| /*************************************************************************************************** | |||||
| * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. | |||||
| * | |||||
| * Redistribution and use in source and binary forms, with or without | |||||
| *modification, are permitted provided that the following conditions are met: | |||||
| * * Redistributions of source code must retain the above copyright notice, | |||||
| *this list of conditions and the following disclaimer. | |||||
| * * Redistributions in binary form must reproduce the above copyright | |||||
| *notice, this list of conditions and the following disclaimer in the | |||||
| *documentation and/or other materials provided with the distribution. | |||||
| * * Neither the name of the NVIDIA CORPORATION nor the names of its | |||||
| *contributors may be used to endorse or promote products derived from this | |||||
| *software without specific prior written permission. | |||||
| * | |||||
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
| *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
| *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |||||
| *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, | |||||
| *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, | |||||
| * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | |||||
| *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY | |||||
| *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING | |||||
| *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, | |||||
| *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||||
| * | |||||
| **************************************************************************************************/ | |||||
| /** | |||||
| * \file dnn/src/cuda/conv_bias/reduce_with_scale_filter.cuh | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "src/cuda/utils.cuh" | |||||
| namespace megdnn { | |||||
| namespace cuda { | |||||
| template <bool signedness> | |||||
| void do_dispatch_reduce_with_scale_filter_4bit(const uint8_t* src, | |||||
| int32_t scale, uint32_t rows, | |||||
| uint32_t cols, int32_t* dst, | |||||
| cudaStream_t stream); | |||||
| size_t do_dispatch_reduce_workspace_in_bytes(size_t A, size_t B, size_t C); | |||||
| } // namespace cuda | |||||
| } // namespace megdnn | |||||
| // vim: ft=cpp syntax=cuda.doxygen | |||||
| @@ -1,312 +0,0 @@ | |||||
| /** | |||||
| * \file dnn/src/cuda/conv_bias/sass_implicit_gemm_int4_nchw64_imma.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "./algo.h" | |||||
| #include "src/cuda/conv_bias/sass_helper.cuh" | |||||
| #include "src/cuda/sass_loader.h" | |||||
| #include "src/cuda/utils.h" | |||||
| #include "src/common/conv_bias.h" | |||||
| using namespace megdnn; | |||||
| using namespace cuda; | |||||
| using namespace sass; | |||||
| namespace { | |||||
| #if !MEGDNN_TEGRA_X1 | |||||
| // all stride are in bytes | |||||
| void compute_conv2d_offset(size_t fh, size_t fw, size_t ics, size_t ihs, | |||||
| Conv2dConstantOffset& constant_offset) { | |||||
| constexpr int interleaved = 64; | |||||
| constexpr int size_bits = 4; | |||||
| constexpr int threablock_k = 128; | |||||
| constexpr int inc_step = threablock_k / interleaved; | |||||
| size_t i = 0; | |||||
| int* s32 = reinterpret_cast<int*>(&(constant_offset.c_offset[0])); | |||||
| for (; i < inc_step; i++) { | |||||
| int c = i / (fh * fw); | |||||
| int khkw = i % (fh * fw); | |||||
| int kh = khkw / fw; | |||||
| int kw = khkw % fw; | |||||
| s32[2 * i] = c * ics + kh * ihs + kw * interleaved * size_bits / 8; | |||||
| int8_t* s8 = reinterpret_cast<int8_t*>(&(s32[2 * i + 1])); | |||||
| s8[0] = kh; | |||||
| s8[1] = kw; | |||||
| s8[2] = -kh; | |||||
| s8[3] = -kw; | |||||
| } | |||||
| for (; i < (inc_step + fh * fw * inc_step); i++) { | |||||
| int c = i / (fh * fw); | |||||
| int khkw = i % (fh * fw); | |||||
| int kh = khkw / fw; | |||||
| int kw = khkw % fw; | |||||
| s32[2 * i] = c * ics + kh * ihs + kw * interleaved * size_bits / 8; | |||||
| int8_t* s8 = reinterpret_cast<int8_t*>(&(s32[2 * i + 1])); | |||||
| s8[0] = kh; | |||||
| s8[1] = kw; | |||||
| s8[2] = -kh; | |||||
| s8[3] = -kw; | |||||
| int i_ = i - inc_step; | |||||
| c = i_ / (fh * fw); | |||||
| khkw = i_ % (fh * fw); | |||||
| kh = khkw / fw; | |||||
| kw = khkw % fw; | |||||
| s32[2 * i] -= c * ics + kh * ihs + kw * interleaved * size_bits / 8; | |||||
| } | |||||
| } | |||||
| #endif | |||||
| }; // namespace | |||||
| std::string ConvBiasForwardImpl::AlgoSASSInt4NCHW64IMMAImplicitGemm::kernel_key( | |||||
| const SizeArgs& args) const { | |||||
| std::string kernel_key; | |||||
| using NonlineMode = Param::NonlineMode; | |||||
| auto&& param = args.opr->param(); | |||||
| if (args.z_layout->ndim > 0) { | |||||
| kernel_key = | |||||
| ssprintf("%s_conv_bias_int4_fuse_z_imma8832_ldg16_%ux%u", | |||||
| current_device_arch_name(), m_tile_nhw, m_tile_oc); | |||||
| } else { | |||||
| kernel_key = | |||||
| ssprintf("%s_conv_bias_int4_imma8832_ldg16_%ux%u", | |||||
| current_device_arch_name(), m_tile_nhw, m_tile_oc); | |||||
| } | |||||
| if (param.nonlineMode == NonlineMode::H_SWISH) { | |||||
| kernel_key += "_hswish"; | |||||
| } else { | |||||
| megdnn_assert(param.nonlineMode == NonlineMode::RELU || | |||||
| param.nonlineMode == NonlineMode::IDENTITY); | |||||
| kernel_key += "_relu"; | |||||
| } | |||||
| return kernel_key; | |||||
| } | |||||
| bool ConvBiasForwardImpl::AlgoSASSInt4NCHW64IMMAImplicitGemm::is_available( | |||||
| const SizeArgs& args) const { | |||||
| if (args.bias_layout->ndim <= 0) | |||||
| return false; | |||||
| using Param = param::ConvBias; | |||||
| using Format = Param::Format; | |||||
| using Sparse = Param::Sparse; | |||||
| using Mode = Param::Mode; | |||||
| bool available = true; | |||||
| auto&& param = args.opr->param(); | |||||
| auto&& fm = args.filter_meta; | |||||
| if (!check_bias_share_in_channel(*(args.bias_layout), param.format)) | |||||
| return false; | |||||
| if (param.format != Format::NCHW64) | |||||
| return false; | |||||
| UNPACK_CONV_BIAS_NCHW64_PARAM(*(args.src_layout), fm, *(args.dst_layout), | |||||
| param); | |||||
| // TODO support group conv | |||||
| available &= param.sparse == Sparse::DENSE; | |||||
| // mode must be cross correlation | |||||
| available &= param.mode == Mode::CROSS_CORRELATION; | |||||
| // check data type | |||||
| auto src_dtype = args.src_layout->dtype, | |||||
| filter_dtype = args.filter_layout->dtype, | |||||
| bias_dtype = args.bias_layout->dtype, | |||||
| dst_dtype = args.dst_layout->dtype; | |||||
| available &= (src_dtype.enumv() == DTypeEnum::QuantizedS4 && | |||||
| filter_dtype.enumv() == DTypeEnum::QuantizedS4 && | |||||
| bias_dtype.enumv() == DTypeEnum::QuantizedS32 && | |||||
| dst_dtype.enumv() == DTypeEnum::QuantizedS4); | |||||
| // TODO: support dilation | |||||
| available &= dh == 1 && dw == 1; | |||||
| // ensure precomputed offsets are positive integers | |||||
| available &= hi >= fh && wi >= fw; | |||||
| // only support sm_75 or later, platform should have tensorcore int8 | |||||
| // support | |||||
| available &= is_compute_capability_required(7, 5); | |||||
| // param buffer size is 4K, use 3K to store precomputed offset, fh * fw <= | |||||
| // (3*1024/4/2/2) - 1 | |||||
| available &= fh * fw <= 191; | |||||
| return available; | |||||
| } | |||||
| size_t | |||||
| ConvBiasForwardImpl::AlgoSASSInt4NCHW64IMMAImplicitGemm::get_workspace_in_bytes( | |||||
| const SizeArgs& args) const { | |||||
| if (args.preprocessed_filter == nullptr) { | |||||
| return args.filter_layout->span().dist_byte() + | |||||
| args.bias_layout->span().dist_byte(); | |||||
| } | |||||
| return 0_z; | |||||
| } | |||||
| void ConvBiasForwardImpl::AlgoSASSInt4NCHW64IMMAImplicitGemm::exec( | |||||
| const ExecArgs& args) const { | |||||
| #if MEGDNN_TEGRA_X1 | |||||
| megdnn_throw("sass kernel is disabled at compile time for TX1"); | |||||
| #else | |||||
| using Format = Param::Format; | |||||
| auto&& param = args.opr->param(); | |||||
| auto&& fm = args.filter_meta; | |||||
| UNPACK_CONV_BIAS_NCHW64_PARAM(*(args.src_layout), fm, *(args.dst_layout), | |||||
| param); | |||||
| auto&& stream = cuda_stream(args.opr->handle()); | |||||
| constexpr int interleaved = 64; | |||||
| void* bias_ptr = nullptr; | |||||
| void* filter_ptr = nullptr; | |||||
| if (args.preprocessed_filter) { | |||||
| megdnn_assert(args.preprocessed_filter->tensors.size() == 2); | |||||
| filter_ptr = args.preprocessed_filter->tensors[0].raw_ptr; | |||||
| bias_ptr = args.preprocessed_filter->tensors[1].raw_ptr; | |||||
| } else { | |||||
| // reorder filter and bias | |||||
| filter_ptr = reinterpret_cast<void*>(args.workspace.raw_ptr); | |||||
| bias_ptr = | |||||
| reinterpret_cast<void*>(args.workspace.raw_ptr + | |||||
| args.filter_layout->span().dist_byte()); | |||||
| if (args.z_layout->ndim > 0) { | |||||
| reorder_imma_filter_bias<4, 64>( | |||||
| reinterpret_cast<int8_t*>(filter_ptr), | |||||
| reinterpret_cast<int32_t*>(bias_ptr), | |||||
| reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), | |||||
| args.bias_tensor->compatible_ptr<int32_t>(), co, ci, fh, fw, | |||||
| stream); | |||||
| } else { | |||||
| reorder_imma_filter_bias<4, 64, true>( | |||||
| reinterpret_cast<int8_t*>(filter_ptr), | |||||
| reinterpret_cast<int32_t*>(bias_ptr), | |||||
| reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), | |||||
| args.bias_tensor->compatible_ptr<int32_t>(), co, ci, fh, fw, | |||||
| stream); | |||||
| } | |||||
| } | |||||
| uint32_t u32_n = n, u32_ci = ci, u32_hi = hi, u32_wi = wi, u32_fh = fh, | |||||
| u32_fw = fw, u32_sh = sh, u32_sw = sw, u32_ph = ph, u32_pw = pw, | |||||
| u32_co = co, u32_ho = ho, u32_wo = wo; | |||||
| Conv2dInt4Param kern_param(u32_n, u32_ci, u32_hi, u32_wi, u32_fh, u32_fw, | |||||
| u32_sh, u32_sw, u32_ph, u32_pw, u32_co, u32_ho, | |||||
| u32_wo, interleaved); | |||||
| Conv2dConstantOffset kern_coffset; | |||||
| compute_conv2d_offset(fh, fw, kern_param.ics, kern_param.ihs, kern_coffset); | |||||
| // The starting address of Turing param buffer is c[0x0][0x160] | |||||
| kern_coffset.c_offset_param.begin = param_buffer_start_address(); | |||||
| kern_coffset.c_offset_param.size = 16 * (1 + fh * fw); | |||||
| kern_coffset.c_offset_param.max = 16 * fh * fw; | |||||
| kern_coffset.c_offset_param.rewind = 16 * (1 - fh * fw); | |||||
| auto kern_key = kernel_key(args); | |||||
| float src_scale = args.src_layout->dtype.param<dtype::QuantizedS4>().scale, | |||||
| filter_scale = | |||||
| args.filter_layout->dtype.param<dtype::QuantizedS4>().scale, | |||||
| bias_scale = | |||||
| args.bias_layout->dtype.param<dtype::QuantizedS32>().scale, | |||||
| dst_scale = args.dst_layout->dtype.param<dtype::QuantizedS4>().scale; | |||||
| float alpha = src_scale * filter_scale / dst_scale, | |||||
| beta = bias_scale / dst_scale; | |||||
| float inv_dst_scale = 1.f / dst_scale; | |||||
| unsigned int tx = m_threads, ty = 1; | |||||
| unsigned int gridx = div_ceil<unsigned int>( | |||||
| static_cast<unsigned int>(n * ho * wo), m_tile_nhw); | |||||
| unsigned int gridy = | |||||
| div_ceil<unsigned int>(static_cast<unsigned int>(co), m_tile_oc); | |||||
| void* src_ptr = const_cast<void*>(args.src_tensor->raw_ptr); | |||||
| void* dst_ptr = const_cast<void*>(args.dst_tensor->raw_ptr); | |||||
| using NonlineMode = Param::NonlineMode; | |||||
| auto&& kernel = SASSKernelLoader::instance().get_kernel(kern_key, kern_key); | |||||
| if (args.z_layout->ndim > 0) { | |||||
| void* z_ptr = const_cast<void*>(args.z_tensor->raw_ptr); | |||||
| float z_scale = args.z_layout->dtype.param<dtype::QuantizedS4>().scale; | |||||
| float gamma = z_scale / dst_scale; | |||||
| std::vector<void*> params = {&src_ptr, &filter_ptr, &bias_ptr, &z_ptr, | |||||
| &dst_ptr, &alpha, &beta, &gamma}; | |||||
| kern_coffset.c_offset_param.begin += | |||||
| sizeof(src_ptr) + sizeof(filter_ptr) + sizeof(bias_ptr) + | |||||
| sizeof(z_ptr) + sizeof(dst_ptr) + sizeof(alpha) + sizeof(beta) + | |||||
| sizeof(gamma); | |||||
| uint32_t relu = param.nonlineMode == NonlineMode::RELU ? 1 : 0; | |||||
| if (param.nonlineMode == NonlineMode::H_SWISH) { | |||||
| params.push_back(&dst_scale); | |||||
| params.push_back(&inv_dst_scale); | |||||
| kern_coffset.c_offset_param.begin += | |||||
| sizeof(dst_scale) + sizeof(inv_dst_scale); | |||||
| } else { | |||||
| params.push_back(&relu); | |||||
| kern_coffset.c_offset_param.begin += sizeof(relu); | |||||
| } | |||||
| params.push_back(&kern_param); | |||||
| kern_coffset.c_offset_param.begin += sizeof(kern_param); | |||||
| kern_coffset.c_offset_param.begin += | |||||
| sizeof(kern_coffset.c_offset_param); | |||||
| kern_coffset.c_offset_param.max += kern_coffset.c_offset_param.begin; | |||||
| params.push_back(&kern_coffset); | |||||
| cucheck(cuLaunchKernel(kernel, gridx, gridy, 1, tx, ty, 1, 0, stream, | |||||
| params.data(), 0)); | |||||
| } else { | |||||
| std::vector<void*> params = {&src_ptr, &filter_ptr, &bias_ptr, | |||||
| &dst_ptr, &alpha, &beta}; | |||||
| kern_coffset.c_offset_param.begin += | |||||
| sizeof(src_ptr) + sizeof(filter_ptr) + sizeof(bias_ptr) + | |||||
| sizeof(dst_ptr) + sizeof(alpha) + sizeof(beta); | |||||
| uint32_t relu = param.nonlineMode == NonlineMode::RELU ? 1 : 0; | |||||
| if (param.nonlineMode == NonlineMode::H_SWISH) { | |||||
| params.push_back(&dst_scale); | |||||
| params.push_back(&inv_dst_scale); | |||||
| kern_coffset.c_offset_param.begin += | |||||
| sizeof(dst_scale) + sizeof(inv_dst_scale); | |||||
| } else { | |||||
| params.push_back(&relu); | |||||
| kern_coffset.c_offset_param.begin += sizeof(relu); | |||||
| } | |||||
| params.push_back(&kern_param); | |||||
| kern_coffset.c_offset_param.begin += sizeof(kern_param); | |||||
| kern_coffset.c_offset_param.begin += | |||||
| sizeof(kern_coffset.c_offset_param); | |||||
| kern_coffset.c_offset_param.max += kern_coffset.c_offset_param.begin; | |||||
| params.push_back(&kern_coffset); | |||||
| cucheck(cuLaunchKernel(kernel, gridx, gridy, 1, tx, ty, 1, 0, stream, | |||||
| params.data(), 0)); | |||||
| } | |||||
| after_kernel_launch(); | |||||
| #endif | |||||
| } | |||||
| size_t ConvBiasForwardImpl::AlgoSASSInt4NCHW64IMMAImplicitGemm:: | |||||
| get_preprocess_workspace_in_bytes(const SizeArgs& args) const { | |||||
| return 0_z; | |||||
| } | |||||
| SmallVector<TensorLayout> ConvBiasForwardImpl:: | |||||
| AlgoSASSInt4NCHW64IMMAImplicitGemm::deduce_preprocessed_filter_layout( | |||||
| const SizeArgs& args) const { | |||||
| return {args.filter_layout->collapse_contiguous(), | |||||
| args.bias_layout->collapse_contiguous()}; | |||||
| } | |||||
| void ConvBiasForwardImpl::AlgoSASSInt4NCHW64IMMAImplicitGemm::exec_preprocess( | |||||
| const ExecArgs& args) const { | |||||
| using Format = Param::Format; | |||||
| auto&& param = args.opr->param(); | |||||
| auto&& fm = args.filter_meta; | |||||
| UNPACK_CONV_BIAS_NCHW64_PARAM(*(args.src_layout), fm, *(args.dst_layout), | |||||
| param); | |||||
| auto&& stream = cuda_stream(args.opr->handle()); | |||||
| reorder_imma_filter_bias<4, 64>( | |||||
| reinterpret_cast<int8_t*>( | |||||
| args.preprocessed_filter->tensors[0].raw_ptr), | |||||
| args.preprocessed_filter->tensors[1].compatible_ptr<int32_t>(), | |||||
| reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), | |||||
| args.bias_tensor->compatible_ptr<int32_t>(), co, ci, fh, fw, | |||||
| stream); | |||||
| } | |||||
| // vim: syntax=cpp.doxygen | |||||
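For reference, the launch code of the removed NCHW64 SASS algorithm folds all quantization scales into a handful of per-kernel constants; the int4_int4/uint4_int4 successors added to the algorithm enum presumably keep the same convention. A sketch of that bookkeeping (struct and function names are illustrative only):

```cpp
// Illustrative only: requantization constants as assembled in the removed
// exec() above. alpha scales the int32 convolution accumulator, beta the
// bias, gamma the fused z (residual) input; dst_scale/inv_dst_scale are
// only passed for the H_SWISH non-linearity.
struct RequantConsts {
    float alpha, beta, gamma, dst_scale, inv_dst_scale;
};

inline RequantConsts make_requant_consts(float src_scale, float filter_scale,
                                         float bias_scale, float z_scale,
                                         float dst_scale) {
    return {src_scale * filter_scale / dst_scale, bias_scale / dst_scale,
            z_scale / dst_scale, dst_scale, 1.f / dst_scale};
}
```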
| @@ -161,6 +161,38 @@ void forward_bias<dt_qint4, dt_qint4, dt_qint32, dt_qint32>( | |||||
| forward_bias<dt_qint8, dt_qint8, dt_qint32, dt_qint32>( | forward_bias<dt_qint8, dt_qint8, dt_qint32, dt_qint32>( | ||||
| new_src, new_flt, bias, dst, nullptr, new_filter_meta); | new_src, new_flt, bias, dst, nullptr, new_filter_meta); | ||||
| } | } | ||||
| template <> | |||||
| void forward_bias<dt_quint4, dt_qint4, dt_qint32, dt_qint32>( | |||||
| _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in bias, | |||||
| _megdnn_tensor_out dst, dt_byte* workspace_ptr, | |||||
| const ConvBiasForward::CanonizedFilterMeta& filter_meta) { | |||||
| auto convert_layout_src = [](const TensorLayout& layout) { | |||||
| auto ret = layout; | |||||
| auto param = layout.dtype.param<dtype::Quantized4Asymm>(); | |||||
| ret.dtype = dtype::QuantizedS8(param.scale); | |||||
| ret.format = TensorFormat(ret.dtype); | |||||
| ret.init_contiguous_stride(); | |||||
| return ret; | |||||
| }; | |||||
| auto convert_layout_flt = [](const TensorLayout& layout) { | |||||
| auto ret = layout; | |||||
| auto param = layout.dtype.param<dtype::QuantizedS4>(); | |||||
| ret.dtype = dtype::QuantizedS8(param.scale); | |||||
| ret.format = TensorFormat(ret.dtype); | |||||
| ret.init_contiguous_stride(); | |||||
| return ret; | |||||
| }; | |||||
| TensorND new_src = {workspace_ptr, convert_layout_src(src.layout)}; | |||||
| TensorND new_flt = {workspace_ptr + new_src.layout.span().dist_byte(), | |||||
| convert_layout_flt(filter.layout)}; | |||||
| uint4_to_int8(src, new_src); | |||||
| int4_to_int8(filter, new_flt); | |||||
| auto new_filter_meta = filter_meta; | |||||
| new_filter_meta.dtype = new_flt.layout.dtype; | |||||
| forward_bias<dt_qint8, dt_qint8, dt_qint32, dt_qint32>( | |||||
| new_src, new_flt, bias, dst, nullptr, new_filter_meta); | |||||
| } | |||||
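The specialization above implements the naive uint4 x int4 path by widening both operands to `QuantizedS8` tensors placed in the workspace and reusing the existing int8 reference kernel. A rough workspace estimate for this route, assuming only the two widened tensors are needed (an assumption; the real budget comes from `get_workspace_in_bytes`):

```cpp
#include <cstddef>
#include <cstdint>
#include "megdnn/basic_types.h"

// One int8 byte per 4-bit element of src and filter (assumption: no extra
// scratch beyond the two converted tensors).
inline size_t naive_u4s4_workspace(const megdnn::TensorLayout& src,
                                   const megdnn::TensorLayout& filter) {
    return src.total_nr_elems() * sizeof(int8_t) +
           filter.total_nr_elems() * sizeof(int8_t);
}
```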
| } // namespace convolution | } // namespace convolution | ||||
| size_t ConvBiasForwardImpl::get_workspace_in_bytes(const TensorLayout& src, | size_t ConvBiasForwardImpl::get_workspace_in_bytes(const TensorLayout& src, | ||||
| @@ -211,9 +243,10 @@ void ConvBiasForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, | |||||
| TensorLayout{dst.layout, bias.layout.dtype}}; | TensorLayout{dst.layout, bias.layout.dtype}}; | ||||
| workspace_ptr += sfb.layout.span().dist_byte(); | workspace_ptr += sfb.layout.span().dist_byte(); | ||||
| } | } | ||||
| #define DISPATCH_RAW(in_dt, bias_dt, out_dt, cmode, func) \ | |||||
| #define DISPATCH_RAW(in_dt, flt_dt, bias_dt, out_dt, cmode, func) \ | |||||
| else if (src.layout.dtype.enumv() == DTypeTrait<dtype::in_dt>::enumv && \ | else if (src.layout.dtype.enumv() == DTypeTrait<dtype::in_dt>::enumv && \ | ||||
| filter.layout.dtype.enumv() == DTypeTrait<dtype::in_dt>::enumv && \ | |||||
| filter.layout.dtype.enumv() == \ | |||||
| DTypeTrait<dtype::flt_dt>::enumv && \ | |||||
| bias.layout.dtype.enumv() == DTypeTrait<dtype::bias_dt>::enumv && \ | bias.layout.dtype.enumv() == DTypeTrait<dtype::bias_dt>::enumv && \ | ||||
| sfb.layout.dtype.enumv() == DTypeTrait<dtype::out_dt>::enumv && \ | sfb.layout.dtype.enumv() == DTypeTrait<dtype::out_dt>::enumv && \ | ||||
| param().compute_mode == Param::ComputeMode::cmode) { \ | param().compute_mode == Param::ComputeMode::cmode) { \ | ||||
| @@ -222,7 +255,7 @@ void ConvBiasForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, | |||||
| } | } | ||||
| #define DISPATCH(in_dt, out_dt) \ | #define DISPATCH(in_dt, out_dt) \ | ||||
| DISPATCH_RAW( \ | DISPATCH_RAW( \ | ||||
| in_dt, out_dt, out_dt, DEFAULT, \ | |||||
| in_dt, in_dt, out_dt, out_dt, DEFAULT, \ | |||||
| (convolution::forward_bias<DTypeTrait<dtype::in_dt>::ctype, \ | (convolution::forward_bias<DTypeTrait<dtype::in_dt>::ctype, \ | ||||
| DTypeTrait<dtype::in_dt>::ctype, \ | DTypeTrait<dtype::in_dt>::ctype, \ | ||||
| DTypeTrait<dtype::out_dt>::ctype, \ | DTypeTrait<dtype::out_dt>::ctype, \ | ||||
| @@ -236,16 +269,21 @@ void ConvBiasForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, | |||||
| DISPATCH(QuantizedS8, Float32) | DISPATCH(QuantizedS8, Float32) | ||||
| DISPATCH(Quantized8Asymm, QuantizedS32) | DISPATCH(Quantized8Asymm, QuantizedS32) | ||||
| DISPATCH(Quantized4Asymm, QuantizedS32) | DISPATCH(Quantized4Asymm, QuantizedS32) | ||||
| DISPATCH_RAW(QuantizedS8, QuantizedS32, QuantizedS32, FLOAT32, | |||||
| DISPATCH_RAW(QuantizedS8, QuantizedS8, QuantizedS32, QuantizedS32, | |||||
| FLOAT32, | |||||
| (convolution::forward_bias<dt_int8, dt_int8, dt_int32, | (convolution::forward_bias<dt_int8, dt_int8, dt_int32, | ||||
| dt_int32>)) | dt_int32>)) | ||||
| DISPATCH(QuantizedS4, QuantizedS32) | DISPATCH(QuantizedS4, QuantizedS32) | ||||
| DISPATCH_RAW(Quantized4Asymm, QuantizedS4, QuantizedS32, QuantizedS32, | |||||
| DEFAULT, | |||||
| (convolution::forward_bias<dt_quint4, dt_qint4, dt_qint32, | |||||
| dt_qint32>)) | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | #if !MEGDNN_DISABLE_FLOAT16 | ||||
| DISPATCH(Float16, Float16) | DISPATCH(Float16, Float16) | ||||
| DISPATCH_RAW(Float16, Float16, Float16, FLOAT32, | |||||
| DISPATCH_RAW(Float16, Float16, Float16, Float16, FLOAT32, | |||||
| (convolution::forward_bias<dt_float16, dt_float16, | (convolution::forward_bias<dt_float16, dt_float16, | ||||
| dt_float16, dt_float32>)) | dt_float16, dt_float32>)) | ||||
| DISPATCH_RAW(BFloat16, BFloat16, BFloat16, FLOAT32, | |||||
| DISPATCH_RAW(BFloat16, BFloat16, BFloat16, BFloat16, FLOAT32, | |||||
| (convolution::forward_bias<dt_bfloat16, dt_bfloat16, | (convolution::forward_bias<dt_bfloat16, dt_bfloat16, | ||||
| dt_bfloat16, dt_float32>)) | dt_bfloat16, dt_float32>)) | ||||
| #endif | #endif | ||||
| @@ -57,6 +57,54 @@ void megdnn::naive::uint8_to_uint4(const TensorND& in, const TensorND& out) { | |||||
| } | } | ||||
| } | } | ||||
| void megdnn::naive::uint4_to_int8(const TensorND& in, const TensorND& out) { | |||||
| auto in_ptr = static_cast<uint8_t*>(in.raw_ptr) + in.layout.span().low_byte; | |||||
| auto out_ptr = out.compatible_ptr<int8_t>() + out.layout.span().low_byte; | |||||
| const auto& ly = in.layout; | |||||
| int8_t zero_point = | |||||
| (int8_t)ly.dtype.param<dtype::Quantized4Asymm>().zero_point; | |||||
| auto dim_in = ly.shape[ly.ndim - 1]; | |||||
| auto elems = ly.total_nr_elems(); | |||||
| auto dim_out = elems / dim_in; | |||||
| auto stride_out = div_ceil(dim_in, 2_z); | |||||
| for (size_t i = 0; i < dim_out; ++i) { | |||||
| for (size_t j = 0; j < dim_in; j += 2) { | |||||
| uint8_t val = in_ptr[j / 2]; | |||||
| out_ptr[j] = (int8_t)(val & 0xF) - zero_point; | |||||
| if (j + 1 < dim_in) | |||||
| out_ptr[j + 1] = (int8_t)((val >> 4) & 0xF) - zero_point; | |||||
| } | |||||
| in_ptr += stride_out; | |||||
| out_ptr += dim_in; | |||||
| } | |||||
| } | |||||
| void megdnn::naive::int8_to_uint4(const TensorND& in, const TensorND& out) { | |||||
| auto in_ptr = static_cast<int8_t*>(in.raw_ptr) + in.layout.span().low_byte; | |||||
| auto out_ptr = | |||||
| static_cast<uint8_t*>(out.raw_ptr) + out.layout.span().low_byte; | |||||
| auto zero_point = | |||||
| out.layout.dtype.param<dtype::Quantized4Asymm>().zero_point; | |||||
| const auto& ly = in.layout; | |||||
| auto dim_in = ly.shape[ly.ndim - 1]; | |||||
| auto elems = ly.total_nr_elems(); | |||||
| auto dim_out = elems / dim_in; | |||||
| auto stride_out = div_ceil(dim_in, 2_z); | |||||
| for (size_t i = 0; i < dim_out; ++i) { | |||||
| for (size_t j = 0; j < dim_in; j += 2) { | |||||
| uint8_t a = (uint8_t)std::max((int32_t)in_ptr[j] + zero_point, 0); | |||||
| uint8_t b = 0; | |||||
| if (j + 1 < dim_in) | |||||
| b = (uint8_t)std::max((int32_t)in_ptr[j + 1] + zero_point, 0); | |||||
| a = std::min(a, DTypeTrait<dtype::Quantized4Asymm>::max()); | |||||
| b = std::min(b, DTypeTrait<dtype::Quantized4Asymm>::max()); | |||||
| out_ptr[j / 2] = a + (b << 4); | |||||
| } | |||||
| in_ptr += dim_in; | |||||
| out_ptr += stride_out; | |||||
| } | |||||
| } | |||||
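The two converters above use the usual MegDNN packing for 4-bit data: two values per byte, low nibble first, with the `Quantized4Asymm` zero point added or subtracted during the widening. A scalar round trip of a single byte, assuming a zero point of 8 (the value is illustrative):

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// One-byte round trip mirroring uint4_to_int8 / int8_to_uint4 above.
void demo_u4_round_trip() {
    const int32_t zero_point = 8;  // assumed Quantized4Asymm zero point
    const uint8_t packed = 0xA3;   // low nibble = 3, high nibble = 10

    // unpack: subtract the zero point to get signed int8 values
    int8_t lo = static_cast<int8_t>(packed & 0xF) - zero_point;         // -5
    int8_t hi = static_cast<int8_t>((packed >> 4) & 0xF) - zero_point;  //  2

    // repack: add the zero point back, clamp to [0, 15], pack low-first
    auto to_u4 = [&](int8_t v) {
        return static_cast<uint8_t>(std::min(std::max(v + zero_point, 0), 15));
    };
    uint8_t repacked = to_u4(lo) | (to_u4(hi) << 4);
    assert(repacked == packed);
    (void)repacked;
}
```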
| // ==================================qint4====================================== | // ==================================qint4====================================== | ||||
| void megdnn::naive::int4_to_int8(const TensorND& in, const TensorND& out) { | void megdnn::naive::int4_to_int8(const TensorND& in, const TensorND& out) { | ||||
| auto in_ptr = static_cast<int8_t*>(in.raw_ptr) + in.layout.span().low_byte; | auto in_ptr = static_cast<int8_t*>(in.raw_ptr) + in.layout.span().low_byte; | ||||
| @@ -20,6 +20,10 @@ void uint4_to_uint8(const TensorND& in, const TensorND& out); | |||||
| void uint8_to_uint4(const TensorND& in, const TensorND& out); | void uint8_to_uint4(const TensorND& in, const TensorND& out); | ||||
| void uint4_to_int8(const TensorND& in, const TensorND& out); | |||||
| void int8_to_uint4(const TensorND& in, const TensorND& out); | |||||
| void int4_to_int8(const TensorND& in, const TensorND& out); | void int4_to_int8(const TensorND& in, const TensorND& out); | ||||
| void int8_to_int4(const TensorND& in , const TensorND& out); | void int8_to_int4(const TensorND& in , const TensorND& out); | ||||
| @@ -733,19 +733,33 @@ void check_conv_bias(DType src_dtype, DType filter_dtype, DType bias_dtype, | |||||
| param::ConvBias::Format format, | param::ConvBias::Format format, | ||||
| const std::vector<TestArg>& args, bool fuse_z, | const std::vector<TestArg>& args, bool fuse_z, | ||||
| bool stable_test) { | bool stable_test) { | ||||
| megdnn_assert(src_dtype.enumv() == filter_dtype.enumv()); | |||||
| megdnn_assert((src_dtype.enumv() == filter_dtype.enumv()) || | |||||
| (src_dtype.enumv() == DTypeEnum::Quantized4Asymm && | |||||
| filter_dtype.enumv() == DTypeEnum::QuantizedS4)); | |||||
| Checker<ConvBiasForward> checker(handle, !stable_test); | Checker<ConvBiasForward> checker(handle, !stable_test); | ||||
| if (algo) { | if (algo) { | ||||
| checker.set_before_exec_callback( | checker.set_before_exec_callback( | ||||
| ConvBiasAlgoChecker<ConvBiasForward>(algo)); | ConvBiasAlgoChecker<ConvBiasForward>(algo)); | ||||
| } | } | ||||
| std::unique_ptr<RNG> rng; | std::unique_ptr<RNG> rng; | ||||
| std::unique_ptr<RNG> flt_rng; | |||||
| std::unique_ptr<RNG> bias_rng; | std::unique_ptr<RNG> bias_rng; | ||||
| std::unique_ptr<RNG> const_rng; | std::unique_ptr<RNG> const_rng; | ||||
| std::unique_ptr<RNG> zero_rng; | std::unique_ptr<RNG> zero_rng; | ||||
| // TODO: check range of rng | // TODO: check range of rng | ||||
| if (src_dtype.enumv() == DTypeEnum::QuantizedS8) { | if (src_dtype.enumv() == DTypeEnum::QuantizedS8) { | ||||
| rng = std::make_unique<UniformIntRNG>(-3, 3); | rng = std::make_unique<UniformIntRNG>(-3, 3); | ||||
| flt_rng = std::make_unique<UniformIntRNG>(-3, 3); | |||||
| const_rng = std::make_unique<UniformIntRNG>(1, 1); | |||||
| zero_rng = std::make_unique<UniformIntRNG>(0, 0); | |||||
| megdnn_assert(bias_dtype.enumv() == DTypeEnum::QuantizedS32); | |||||
| bias_rng = std::make_unique<UniformIntRNG>(-50, 50); | |||||
| checker.set_epsilon(1 + 1e-3) | |||||
| .set_max_avg_error(1e-1) | |||||
| .set_max_avg_biased_error(1e-3); | |||||
| } else if (src_dtype.enumv() == DTypeEnum::Quantized4Asymm) { | |||||
| rng = std::make_unique<UniformIntRNG>(0, 6); | |||||
| flt_rng = std::make_unique<UniformIntRNG>(-3, 3); | |||||
| const_rng = std::make_unique<UniformIntRNG>(1, 1); | const_rng = std::make_unique<UniformIntRNG>(1, 1); | ||||
| zero_rng = std::make_unique<UniformIntRNG>(0, 0); | zero_rng = std::make_unique<UniformIntRNG>(0, 0); | ||||
| megdnn_assert(bias_dtype.enumv() == DTypeEnum::QuantizedS32); | megdnn_assert(bias_dtype.enumv() == DTypeEnum::QuantizedS32); | ||||
| @@ -755,6 +769,7 @@ void check_conv_bias(DType src_dtype, DType filter_dtype, DType bias_dtype, | |||||
| .set_max_avg_biased_error(1e-3); | .set_max_avg_biased_error(1e-3); | ||||
| } else if (src_dtype.enumv() == DTypeEnum::QuantizedS4) { | } else if (src_dtype.enumv() == DTypeEnum::QuantizedS4) { | ||||
| rng = std::make_unique<UniformIntRNG>(-3, 3); | rng = std::make_unique<UniformIntRNG>(-3, 3); | ||||
| flt_rng = std::make_unique<UniformIntRNG>(-3, 3); | |||||
| const_rng = std::make_unique<UniformIntRNG>(1, 1); | const_rng = std::make_unique<UniformIntRNG>(1, 1); | ||||
| zero_rng = std::make_unique<UniformIntRNG>(0, 0); | zero_rng = std::make_unique<UniformIntRNG>(0, 0); | ||||
| megdnn_assert(bias_dtype.enumv() == DTypeEnum::QuantizedS32); | megdnn_assert(bias_dtype.enumv() == DTypeEnum::QuantizedS32); | ||||
| @@ -764,11 +779,13 @@ void check_conv_bias(DType src_dtype, DType filter_dtype, DType bias_dtype, | |||||
| .set_max_avg_biased_error(1e-3); | .set_max_avg_biased_error(1e-3); | ||||
| } else if (src_dtype.enumv() == DTypeEnum::Float16) { | } else if (src_dtype.enumv() == DTypeEnum::Float16) { | ||||
| rng = std::make_unique<NormalRNG>(2.f); | rng = std::make_unique<NormalRNG>(2.f); | ||||
| flt_rng = std::make_unique<NormalRNG>(2.f); | |||||
| megdnn_assert(bias_dtype.enumv() == DTypeEnum::Float16); | megdnn_assert(bias_dtype.enumv() == DTypeEnum::Float16); | ||||
| bias_rng = std::make_unique<NormalRNG>(2.f); | bias_rng = std::make_unique<NormalRNG>(2.f); | ||||
| checker.set_epsilon(1e-2); | checker.set_epsilon(1e-2); | ||||
| } else if (src_dtype.enumv() == DTypeEnum::Float32) { | } else if (src_dtype.enumv() == DTypeEnum::Float32) { | ||||
| rng = std::make_unique<NormalRNG>(2.f); | rng = std::make_unique<NormalRNG>(2.f); | ||||
| flt_rng = std::make_unique<NormalRNG>(2.f); | |||||
| megdnn_assert(bias_dtype.enumv() == DTypeEnum::Float32); | megdnn_assert(bias_dtype.enumv() == DTypeEnum::Float32); | ||||
| bias_rng = std::make_unique<NormalRNG>(2.f); | bias_rng = std::make_unique<NormalRNG>(2.f); | ||||
| } | } | ||||
| @@ -819,9 +836,9 @@ void check_conv_bias(DType src_dtype, DType filter_dtype, DType bias_dtype, | |||||
| } | } | ||||
| return z; | return z; | ||||
| }; | }; | ||||
| megdnn_assert(rng != nullptr && bias_rng != nullptr); | |||||
| megdnn_assert(rng != nullptr && flt_rng != nullptr && bias_rng != nullptr); | |||||
| checker.set_rng(0, rng.get()) | checker.set_rng(0, rng.get()) | ||||
| .set_rng(1, rng.get()) | |||||
| .set_rng(1, flt_rng.get()) | |||||
| .set_rng(2, bias_rng.get()) | .set_rng(2, bias_rng.get()) | ||||
| .set_rng(3, rng.get()); | .set_rng(3, rng.get()); | ||||
| if (stable_test) { | if (stable_test) { | ||||
| @@ -257,7 +257,9 @@ void benchmark_target_algo_with_cudnn_tsc( | |||||
| param::ConvBias::Format change_cudnn_format, | param::ConvBias::Format change_cudnn_format, | ||||
| DType change_cudnn_src_dtype, DType change_cudnn_filter_dtype, | DType change_cudnn_src_dtype, DType change_cudnn_filter_dtype, | ||||
| DType change_cudnn_bias_dtype, DType change_cudnn_dst_dtype) { | DType change_cudnn_bias_dtype, DType change_cudnn_dst_dtype) { | ||||
| megdnn_assert(src_dtype.enumv() == filter_dtype.enumv()); | |||||
| megdnn_assert((src_dtype.enumv() == filter_dtype.enumv()) || | |||||
| (src_dtype.enumv() == DTypeEnum::Quantized4Asymm && | |||||
| filter_dtype.enumv() == DTypeEnum::QuantizedS4)); | |||||
| CUBenchmarker<ConvBiasForward> benchmarker(handle); | CUBenchmarker<ConvBiasForward> benchmarker(handle); | ||||
| CUBenchmarker<ConvBiasForward> benchmarker_cudnn(handle); | CUBenchmarker<ConvBiasForward> benchmarker_cudnn(handle); | ||||
| size_t RUNS = 200; | size_t RUNS = 200; | ||||
| @@ -299,30 +301,30 @@ void benchmark_target_algo_with_cudnn_tsc( | |||||
| using Param = ConvBias::Param; | using Param = ConvBias::Param; | ||||
| using Format = Param::Format; | using Format = Param::Format; | ||||
| // helper function to change format | // helper function to change format | ||||
| auto get_tensor_shape = [](TensorShape shape, | |||||
| auto get_tensor_shape = [](TensorShape shape, DType dtype, | |||||
| Format format) -> TensorShape { | Format format) -> TensorShape { | ||||
| TensorShape ret; | TensorShape ret; | ||||
| if (format == Format::NCHW4) { | if (format == Format::NCHW4) { | ||||
| ret = static_cast<TensorShape>( | ret = static_cast<TensorShape>( | ||||
| TensorLayout{shape, dtype::Int8()} | |||||
| TensorLayout{shape, dtype} | |||||
| .reshape({shape[0], shape[1] / 4, 4, shape[2], | .reshape({shape[0], shape[1] / 4, 4, shape[2], | ||||
| shape[3]}) | shape[3]}) | ||||
| .dimshuffle({0, 1, 3, 4, 2})); | .dimshuffle({0, 1, 3, 4, 2})); | ||||
| } else if (format == Format::NCHW32) { | } else if (format == Format::NCHW32) { | ||||
| ret = static_cast<TensorShape>( | ret = static_cast<TensorShape>( | ||||
| TensorLayout{shape, dtype::Int8()} | |||||
| TensorLayout{shape, dtype} | |||||
| .reshape({shape[0], shape[1] / 32, 32, shape[2], | .reshape({shape[0], shape[1] / 32, 32, shape[2], | ||||
| shape[3]}) | shape[3]}) | ||||
| .dimshuffle({0, 1, 3, 4, 2})); | .dimshuffle({0, 1, 3, 4, 2})); | ||||
| } else if (format == Format::NCHW64) { | } else if (format == Format::NCHW64) { | ||||
| ret = static_cast<TensorShape>( | ret = static_cast<TensorShape>( | ||||
| TensorLayout{shape, dtype::QuantizedS4(1.f)} | |||||
| TensorLayout{shape, dtype} | |||||
| .reshape({shape[0], shape[1] / 64, 64, shape[2], | .reshape({shape[0], shape[1] / 64, 64, shape[2], | ||||
| shape[3]}) | shape[3]}) | ||||
| .dimshuffle({0, 1, 3, 4, 2})); | .dimshuffle({0, 1, 3, 4, 2})); | ||||
| } else if (format == Format::CHWN4) { | } else if (format == Format::CHWN4) { | ||||
| ret = static_cast<TensorShape>( | ret = static_cast<TensorShape>( | ||||
| TensorLayout{shape, dtype::Int8()} | |||||
| TensorLayout{shape, dtype} | |||||
| .reshape({shape[0], shape[1] / 4, 4, shape[2], | .reshape({shape[0], shape[1] / 4, 4, shape[2], | ||||
| shape[3]}) | shape[3]}) | ||||
| .dimshuffle({1, 3, 4, 0, 2})); | .dimshuffle({1, 3, 4, 0, 2})); | ||||
| @@ -370,21 +372,24 @@ void benchmark_target_algo_with_cudnn_tsc( | |||||
| if (algo) { | if (algo) { | ||||
| time_in_ms = | time_in_ms = | ||||
| algo_benchmark<ConvBiasForward, OprProxy<ConvBiasForward>, | algo_benchmark<ConvBiasForward, OprProxy<ConvBiasForward>, | ||||
| CUTimer>(benchmarker, | |||||
| {get_tensor_shape(src, format), | |||||
| get_tensor_shape(filter, format), | |||||
| get_tensor_shape(bias, format), | |||||
| {}, | |||||
| {}}, | |||||
| algo) / | |||||
| CUTimer>( | |||||
| benchmarker, | |||||
| {get_tensor_shape(src, src_dtype, format), | |||||
| get_tensor_shape(filter, filter_dtype, format), | |||||
| get_tensor_shape(bias, bias_dtype, format), | |||||
| {}, | |||||
| {}}, | |||||
| algo) / | |||||
| RUNS; | RUNS; | ||||
| } else { | } else { | ||||
| time_in_ms = benchmarker.execs({get_tensor_shape(src, format), | |||||
| get_tensor_shape(filter, format), | |||||
| get_tensor_shape(bias, format), | |||||
| {}, | |||||
| {}}) / | |||||
| RUNS; | |||||
| time_in_ms = | |||||
| benchmarker.execs( | |||||
| {get_tensor_shape(src, src_dtype, format), | |||||
| get_tensor_shape(filter, filter_dtype, format), | |||||
| get_tensor_shape(bias, bias_dtype, format), | |||||
| {}, | |||||
| {}}) / | |||||
| RUNS; | |||||
| } | } | ||||
| float time_in_ms_cudnn = 0; | float time_in_ms_cudnn = 0; | ||||
| if (with_cudnn) { | if (with_cudnn) { | ||||
| @@ -393,9 +398,11 @@ void benchmark_target_algo_with_cudnn_tsc( | |||||
| algo_benchmark<ConvBiasForward, | algo_benchmark<ConvBiasForward, | ||||
| OprProxy<ConvBiasForward>, CUTimer>( | OprProxy<ConvBiasForward>, CUTimer>( | ||||
| benchmarker_cudnn, | benchmarker_cudnn, | ||||
| {get_tensor_shape(src, format_cudnn), | |||||
| get_tensor_shape(filter, format_cudnn), | |||||
| get_tensor_shape(bias, format_cudnn), | |||||
| {get_tensor_shape(src, src_dtype, format_cudnn), | |||||
| get_tensor_shape(filter, filter_dtype, | |||||
| format_cudnn), | |||||
| get_tensor_shape(bias, bias_dtype, | |||||
| format_cudnn), | |||||
| {}, | {}, | ||||
| {}}, | {}}, | ||||
| change_cudnn_algo) / | change_cudnn_algo) / | ||||
| @@ -403,9 +410,11 @@ void benchmark_target_algo_with_cudnn_tsc( | |||||
| } else { | } else { | ||||
| time_in_ms_cudnn = | time_in_ms_cudnn = | ||||
| benchmarker_cudnn.execs( | benchmarker_cudnn.execs( | ||||
| {get_tensor_shape(src, format_cudnn), | |||||
| get_tensor_shape(filter, format_cudnn), | |||||
| get_tensor_shape(bias, format_cudnn), | |||||
| {get_tensor_shape(src, src_dtype, format_cudnn), | |||||
| get_tensor_shape(filter, filter_dtype, | |||||
| format_cudnn), | |||||
| get_tensor_shape(bias, bias_dtype, | |||||
| format_cudnn), | |||||
| {}, | {}, | ||||
| {}}) / | {}}) / | ||||
| RUNS; | RUNS; | ||||
| @@ -426,21 +435,24 @@ void benchmark_target_algo_with_cudnn_tsc( | |||||
| if (algo) { | if (algo) { | ||||
| time_in_ms = | time_in_ms = | ||||
| algo_benchmark<ConvBiasForward, OprProxy<ConvBiasForward>, | algo_benchmark<ConvBiasForward, OprProxy<ConvBiasForward>, | ||||
| CUTimer>(benchmarker, | |||||
| {get_tensor_shape(src, format), | |||||
| get_tensor_shape(filter, format), | |||||
| get_tensor_shape(bias, format), | |||||
| get_tensor_shape(z, format), | |||||
| {}}, | |||||
| algo) / | |||||
| CUTimer>( | |||||
| benchmarker, | |||||
| {get_tensor_shape(src, src_dtype, format), | |||||
| get_tensor_shape(filter, filter_dtype, format), | |||||
| get_tensor_shape(bias, bias_dtype, format), | |||||
| get_tensor_shape(z, src_dtype, format), | |||||
| {}}, | |||||
| algo) / | |||||
| RUNS; | RUNS; | ||||
| } else { | } else { | ||||
| time_in_ms = benchmarker.execs({get_tensor_shape(src, format), | |||||
| get_tensor_shape(filter, format), | |||||
| get_tensor_shape(bias, format), | |||||
| get_tensor_shape(z, format), | |||||
| {}}) / | |||||
| RUNS; | |||||
| time_in_ms = | |||||
| benchmarker.execs( | |||||
| {get_tensor_shape(src, src_dtype, format), | |||||
| get_tensor_shape(filter, filter_dtype, format), | |||||
| get_tensor_shape(bias, bias_dtype, format), | |||||
| get_tensor_shape(z, src_dtype, format), | |||||
| {}}) / | |||||
| RUNS; | |||||
| } | } | ||||
| time_in_ms_cudnn = 0; | time_in_ms_cudnn = 0; | ||||
| if (with_cudnn) { | if (with_cudnn) { | ||||
| @@ -449,20 +461,24 @@ void benchmark_target_algo_with_cudnn_tsc( | |||||
| algo_benchmark<ConvBiasForward, | algo_benchmark<ConvBiasForward, | ||||
| OprProxy<ConvBiasForward>, CUTimer>( | OprProxy<ConvBiasForward>, CUTimer>( | ||||
| benchmarker_cudnn, | benchmarker_cudnn, | ||||
| {get_tensor_shape(src, format_cudnn), | |||||
| get_tensor_shape(filter, format_cudnn), | |||||
| get_tensor_shape(bias, format_cudnn), | |||||
| get_tensor_shape(z, format_cudnn), | |||||
| {get_tensor_shape(src, src_dtype, format_cudnn), | |||||
| get_tensor_shape(filter, filter_dtype, | |||||
| format_cudnn), | |||||
| get_tensor_shape(bias, bias_dtype, | |||||
| format_cudnn), | |||||
| get_tensor_shape(z, src_dtype, format_cudnn), | |||||
| {}}, | {}}, | ||||
| change_cudnn_algo) / | change_cudnn_algo) / | ||||
| RUNS; | RUNS; | ||||
| } else { | } else { | ||||
| time_in_ms_cudnn = | time_in_ms_cudnn = | ||||
| benchmarker_cudnn.execs( | benchmarker_cudnn.execs( | ||||
| {get_tensor_shape(src, format_cudnn), | |||||
| get_tensor_shape(filter, format_cudnn), | |||||
| get_tensor_shape(bias, format_cudnn), | |||||
| get_tensor_shape(z, format_cudnn), | |||||
| {get_tensor_shape(src, src_dtype, format_cudnn), | |||||
| get_tensor_shape(filter, filter_dtype, | |||||
| format_cudnn), | |||||
| get_tensor_shape(bias, bias_dtype, | |||||
| format_cudnn), | |||||
| get_tensor_shape(z, src_dtype, format_cudnn), | |||||
| {}}) / | {}}) / | ||||
| RUNS; | RUNS; | ||||
| } | } | ||||
| @@ -746,6 +746,45 @@ TEST_F(NAIVE, CONV_BIAS_QUANTIZED4) { | |||||
| checker.set_param(param).exect(Testcase{input, filter, bias, z, {}}, | checker.set_param(param).exect(Testcase{input, filter, bias, z, {}}, | ||||
| Testcase{{}, {}, {}, {}, output}); | Testcase{{}, {}, {}, {}, output}); | ||||
| // test qu4 x q4 | |||||
| for (size_t i = 0; i < input_values.size(); i++) { | |||||
| input_values[i] = input_values[i] + 8; | |||||
| } | |||||
| for (size_t i = 0; i < z_values.size(); i++) { | |||||
| z_values[i] = z_values[i] + 8; | |||||
| } | |||||
| std::vector<int> output_uint4; | |||||
| auto dtype_qu4 = dtype::Quantized4Asymm(0.01, 8); | |||||
| for (size_t i = 0; i < output_values.size(); i++) { | |||||
| int result = | |||||
| static_cast<int>(dtype_qu4.param() | |||||
| .quantize(output_values.at(i) * 0.01) | |||||
| .as_uint8()); | |||||
| output_uint4.push_back(result); | |||||
| } | |||||
| auto input_qu4 = TensorValueLowbit4( | |||||
| {1, 1, 4, 4}, dtype::Quantized4Asymm(0.1, 8), input_values); | |||||
| auto filter_q4 = TensorValueLowbit4({3, 1, 3, 3}, dtype::QuantizedS4(0.1), | |||||
| filter_values); | |||||
| auto bias_s32 = GenTensorValue({1, 3, 1, 1}, dtype::QuantizedS32(0.01), | |||||
| bias_values); | |||||
| auto z_qu4 = TensorValueLowbit4({1, 3, 2, 2}, | |||||
| dtype::Quantized4Asymm(0.01, 8), z_values); | |||||
| auto output_qu4 = TensorValueLowbit4( | |||||
| {1, 3, 2, 2}, dtype::Quantized4Asymm(0.01, 8), output_uint4); | |||||
| checker.set_param(param).exect( | |||||
| Testcase{input_qu4, filter_q4, bias_s32, z_qu4, {}}, | |||||
| Testcase{{}, {}, {}, {}, output_qu4}); | |||||
| } | } | ||||
| TEST_F(NAIVE, CONV_BIAS_NCHW64_Q4) { | TEST_F(NAIVE, CONV_BIAS_NCHW64_Q4) { | ||||
| @@ -3329,7 +3368,7 @@ TEST_F(NAIVE, CONV_BIAS_NCHW64_Q4) { | |||||
| auto input_64 = TensorValueLowbit4({1, 1, 4, 4, 64}, | auto input_64 = TensorValueLowbit4({1, 1, 4, 4, 64}, | ||||
| dtype::QuantizedS4(0.1), input_values); | dtype::QuantizedS4(0.1), input_values); | ||||
| auto fliter_64 = TensorValueLowbit4({64, 1, 3, 3, 64}, | |||||
| auto filter_64 = TensorValueLowbit4({64, 1, 3, 3, 64}, | |||||
| dtype::QuantizedS4(0.1), filter_values); | dtype::QuantizedS4(0.1), filter_values); | ||||
| auto bias1_64 = | auto bias1_64 = | ||||
| GenTensorValue({1, 1, 1, 1, 64}, dtype::QuantizedS32(0.01), bias_1); | GenTensorValue({1, 1, 1, 1, 64}, dtype::QuantizedS32(0.01), bias_1); | ||||
| @@ -3338,7 +3377,31 @@ TEST_F(NAIVE, CONV_BIAS_NCHW64_Q4) { | |||||
| {1, 1, 2, 2, 64}, dtype::QuantizedS4(1), output_values); | {1, 1, 2, 2, 64}, dtype::QuantizedS4(1), output_values); | ||||
| checker.set_param(param).exect( | checker.set_param(param).exect( | ||||
| Testcase{input_64, fliter_64, bias1_64, {}, {}}, | |||||
| Testcase{input_64, filter_64, bias1_64, {}, {}}, | |||||
| Testcase{{}, {}, {}, {}, output_64}); | Testcase{{}, {}, {}, {}, output_64}); | ||||
| // test qu4 x q4 | |||||
| for (size_t i = 0; i < input_values.size(); i++) { | |||||
| input_values[i] = input_values[i] + 8; | |||||
| } | |||||
| for (size_t i = 0; i < output_values.size(); i++) { | |||||
| output_values[i] = output_values[i] + 8; | |||||
| } | |||||
| auto input_qu4_64 = TensorValueLowbit4( | |||||
| {1, 1, 4, 4, 64}, dtype::Quantized4Asymm(0.1, 8), input_values); | |||||
| auto filter_q4_64 = TensorValueLowbit4( | |||||
| {64, 1, 3, 3, 64}, dtype::QuantizedS4(0.1), filter_values); | |||||
| auto bias_64 = | |||||
| GenTensorValue({1, 1, 1, 1, 64}, dtype::QuantizedS32(0.01), bias_1); | |||||
| auto output_q4_64 = TensorValueLowbit4( | |||||
| {1, 1, 2, 2, 64}, dtype::Quantized4Asymm(1, 8), output_values); | |||||
| checker.set_param(param).exect( | |||||
| Testcase{input_qu4_64, filter_q4_64, bias_64, {}, {}}, | |||||
| Testcase{{}, {}, {}, {}, output_q4_64}); | |||||
| } | } | ||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||