GitOrigin-RevId: 878bb8c955
tags/v1.5.0
| @@ -1,5 +1,6 @@ | |||||
| # Mark generated files as binary, ignore them in git diff. | # Mark generated files as binary, ignore them in git diff. | ||||
| # dnn | # dnn | ||||
| dnn/src/cuda/conv_bias/int4/kimpl/* binary | |||||
| dnn/src/cuda/conv_bias/int8/kimpl/* binary | dnn/src/cuda/conv_bias/int8/kimpl/* binary | ||||
| dnn/src/cuda/conv_bias/int8_imma/kimpl/* binary | dnn/src/cuda/conv_bias/int8_imma/kimpl/* binary | ||||
| dnn/src/cuda/batch_conv_bias/int8/kimpl/* binary | dnn/src/cuda/batch_conv_bias/int8/kimpl/* binary | ||||
| @@ -84,6 +84,9 @@ ConvBiasForwardImpl::AlgoPack::AlgoPack() { | |||||
| for (auto&& algo : int8_nchw32_imma) { | for (auto&& algo : int8_nchw32_imma) { | ||||
| all_algos.push_back(&algo); | all_algos.push_back(&algo); | ||||
| } | } | ||||
| for (auto&& algo : int4_int4_nchw64_imma) { | |||||
| all_algos.push_back(&algo); | |||||
| } | |||||
| #endif | #endif | ||||
| #endif | #endif | ||||
| fill_dp4a_algos(); | fill_dp4a_algos(); | ||||
| @@ -225,6 +228,12 @@ void ConvBiasForwardImpl::AlgoPack::fill_imma_algos() { | |||||
| int8_nchw32_imma.emplace_back(AlgoParam{64, 64, 64, 32, 32, 64}); | int8_nchw32_imma.emplace_back(AlgoParam{64, 64, 64, 32, 32, 64}); | ||||
| int8_nchw32_imma.emplace_back(AlgoParam{32, 64, 64, 32, 16, 64}); | int8_nchw32_imma.emplace_back(AlgoParam{32, 64, 64, 32, 16, 64}); | ||||
| } | } | ||||
| { | |||||
| using AlgoParam = AlgoInt4Int4NCHW64IMMAImplicitGemm::AlgoParam; | |||||
| int4_int4_nchw64_imma.emplace_back(AlgoParam{128, 128, 128, 64, 64, 128}); | |||||
| int4_int4_nchw64_imma.emplace_back(AlgoParam{256, 128, 128, 64, 64, 128}); | |||||
| } | |||||
| #endif | #endif | ||||
| } | } | ||||
| #endif | #endif | ||||
| @@ -61,6 +61,7 @@ public: | |||||
| CUDA_IMPLICIT_GEMM_REORDER_FILTER_CHWN4_IMMA_INT8, | CUDA_IMPLICIT_GEMM_REORDER_FILTER_CHWN4_IMMA_INT8, | ||||
| CUDA_IMPLICIT_GEMM_UNROLL_WIDTH_CHWN4_IMMA_INT8, | CUDA_IMPLICIT_GEMM_UNROLL_WIDTH_CHWN4_IMMA_INT8, | ||||
| CUDA_IMPLICIT_GEMM_IMMA_NCHW32_INT8, | CUDA_IMPLICIT_GEMM_IMMA_NCHW32_INT8, | ||||
| CUDA_IMPLICIT_GEMM_IMMA_NCHW64_INT4_INT4, | |||||
| CUDA_BFLOAT16, | CUDA_BFLOAT16, | ||||
| CUDA_IMPLICIT_GEMM_SASS_NCHW4_DOTPROD_INT8, | CUDA_IMPLICIT_GEMM_SASS_NCHW4_DOTPROD_INT8, | ||||
| CUDA_IMPLICIT_GEMM_1X1_SASS_NCHW4_DOTPROD_INT8, | CUDA_IMPLICIT_GEMM_1X1_SASS_NCHW4_DOTPROD_INT8, | ||||
| @@ -755,6 +756,53 @@ public: | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| private: | |||||
| WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, | |||||
| const SizeArgs& args) const; | |||||
| AlgoParam m_algo_param; | |||||
| std::string m_name; | |||||
| }; | |||||
| class ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm final | |||||
| : public AlgoBase { | |||||
| public: | |||||
| struct AlgoParam { | |||||
| int threadblock_m; | |||||
| int threadblock_n; | |||||
| int threadblock_k; | |||||
| int warp_m; | |||||
| int warp_n; | |||||
| int warp_k; | |||||
| }; | |||||
| AlgoInt4Int4NCHW64IMMAImplicitGemm(AlgoParam algo_param) | |||||
| : m_algo_param{algo_param} { | |||||
| m_name = ConvBias::algo_name<ConvBias::DirectParam>( | |||||
| ssprintf("INT4_INT4_NCHW64_IMMA_IMPLICIT_GEMM_%s", | |||||
| to_string(m_algo_param).c_str()), | |||||
| ConvBias::DirectParam{}); | |||||
| } | |||||
| bool is_available(const SizeArgs& args) const override; | |||||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||||
| void exec(const ExecArgs& args) const override; | |||||
| const char* name() const override { return m_name.c_str(); } | |||||
| AlgoAttribute attribute() const override { | |||||
| return AlgoAttribute::REPRODUCIBLE; | |||||
| } | |||||
| static std::string to_string(AlgoParam algo_param); | |||||
| size_t get_preprocess_workspace_in_bytes( | |||||
| const SizeArgs& args) const override; | |||||
| SmallVector<TensorLayout> deduce_preprocessed_filter_layout( | |||||
| const SizeArgs& args) const override; | |||||
| void exec_preprocess(const ExecArgs& args) const override; | |||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_IMMA_NCHW64_INT4_INT4) | |||||
| std::string param() const override { | |||||
| std::string ret; | |||||
| serialize_write_pod(m_algo_param, ret); | |||||
| return ret; | |||||
| } | |||||
| private: | private: | ||||
| WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, | WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, | ||||
| const SizeArgs& args) const; | const SizeArgs& args) const; | ||||
| @@ -819,6 +867,7 @@ public: | |||||
| #endif | #endif | ||||
| #if CUDA_VERSION >= 10020 | #if CUDA_VERSION >= 10020 | ||||
| std::vector<AlgoInt8NCHW32IMMAImplicitGemm> int8_nchw32_imma; | std::vector<AlgoInt8NCHW32IMMAImplicitGemm> int8_nchw32_imma; | ||||
| std::vector<AlgoInt4Int4NCHW64IMMAImplicitGemm> int4_int4_nchw64_imma; | |||||
| #endif | #endif | ||||
| std::vector<std::unique_ptr<AlgoGroupConvGeneral>> gconv_refhold; | std::vector<std::unique_ptr<AlgoGroupConvGeneral>> gconv_refhold; | ||||
| AlgoBFloat16 bfloat16; | AlgoBFloat16 bfloat16; | ||||
| @@ -25,8 +25,8 @@ using namespace megdnn; | |||||
| using namespace cuda; | using namespace cuda; | ||||
| using namespace cutlass_wrapper; | using namespace cutlass_wrapper; | ||||
| /* ================= cutlass kernel wrapper for nchw32 layout ================ | |||||
| */ | |||||
| /* ====== cutlass kernel wrapper for int8 nchw32 layout ====== */ | |||||
| #if MEGDNN_TEGRA_X1 | #if MEGDNN_TEGRA_X1 | ||||
| template <bool NeedLoadFromConstMem> | template <bool NeedLoadFromConstMem> | ||||
| void megdnn::cuda::cutlass_wrapper:: | void megdnn::cuda::cutlass_wrapper:: | ||||
| @@ -149,7 +149,8 @@ INST(true); | |||||
| INST(false); | INST(false); | ||||
| #undef INST | #undef INST | ||||
| /* ==== cutlass kernel wrapper for nchw32 layout and nchw4 output ===== */ | |||||
| /* ===== cutlass kernel wrapper for int8 nchw32 layout and nchw4 output ===== */ | |||||
| #if MEGDNN_TEGRA_X1 | #if MEGDNN_TEGRA_X1 | ||||
| template <bool NeedLoadFromConstMem> | template <bool NeedLoadFromConstMem> | ||||
| void megdnn::cuda::cutlass_wrapper:: | void megdnn::cuda::cutlass_wrapper:: | ||||
| @@ -272,7 +273,8 @@ INST(true); | |||||
| INST(false); | INST(false); | ||||
| #undef INST | #undef INST | ||||
| /* ================ cutlass kernel wrapper for nchw4 layout ================= */ | |||||
| /* ====== cutlass kernel wrapper for int8 nchw4 layout ====== */ | |||||
| #if MEGDNN_TEGRA_X1 | #if MEGDNN_TEGRA_X1 | ||||
| template <bool NeedLoadFromConstMem> | template <bool NeedLoadFromConstMem> | ||||
| void megdnn::cuda::cutlass_wrapper:: | void megdnn::cuda::cutlass_wrapper:: | ||||
| @@ -401,7 +403,8 @@ INST(true); | |||||
| INST(false); | INST(false); | ||||
| #undef INST | #undef INST | ||||
| /* ===== cutlass kernel wrapper for nchw4 layout and nchw output ===== */ | |||||
| /* ====== cutlass kernel wrapper for int8 nchw4 layout and nchw output ====== */ | |||||
| #if MEGDNN_TEGRA_X1 | #if MEGDNN_TEGRA_X1 | ||||
| template <bool NeedLoadFromConstMem> | template <bool NeedLoadFromConstMem> | ||||
| void megdnn::cuda::cutlass_wrapper:: | void megdnn::cuda::cutlass_wrapper:: | ||||
| @@ -531,7 +534,8 @@ INST(true); | |||||
| INST(false); | INST(false); | ||||
| #undef INST | #undef INST | ||||
| /* ====== cutlass kernel wrapper for nchw4 layout and nchw32 output ====== */ | |||||
| /* ===== cutlass kernel wrapper for int8 nchw4 layout and nchw32 output ===== */ | |||||
| #if MEGDNN_TEGRA_X1 | #if MEGDNN_TEGRA_X1 | ||||
| template <bool NeedLoadFromConstMem> | template <bool NeedLoadFromConstMem> | ||||
| void megdnn::cuda::cutlass_wrapper:: | void megdnn::cuda::cutlass_wrapper:: | ||||
| @@ -658,4 +662,125 @@ INST(true); | |||||
| INST(false); | INST(false); | ||||
| #undef INST | #undef INST | ||||
| /* ====== cutlass kernel wrapper for int4 nchw64 layout ====== */ | |||||
| #if MEGDNN_TEGRA_X1 | |||||
| template <bool NeedLoadFromConstMem> | |||||
| void megdnn::cuda::cutlass_wrapper:: | |||||
| do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64( | |||||
| const int8_t* /* d_src */, const int8_t* /* d_filter */, | |||||
| const int32_t* /* d_bias */, const int8_t* /* d_z */, | |||||
| int8_t* /* d_dst */, int* /* workspace */, | |||||
| const convolution::ConvParam& /* param */, | |||||
| uint32_t /* nonlinear_mode */, float /* alpha */, | |||||
| float /* beta */, float /* gamma */, float /* scale */, | |||||
| const GemmCoord& /* threadblock_shape */, | |||||
| const GemmCoord& /* warp_shape */, cudaStream_t /* stream */) {} | |||||
| #else | |||||
| template <bool NeedLoadFromConstMem> | |||||
| void megdnn::cuda::cutlass_wrapper:: | |||||
| do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64( | |||||
| const int8_t* d_src, const int8_t* d_filter, | |||||
| const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, | |||||
| int* workspace, const convolution::ConvParam& param, | |||||
| uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||||
| float scale, const GemmCoord& threadblock_shape, | |||||
| const GemmCoord& warp_shape, cudaStream_t stream) { | |||||
| #define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||||
| threadblock_k_, warp_m_, warp_n_, \ | |||||
| warp_k_) \ | |||||
| if (threadblock_shape.m() == threadblock_m_ && \ | |||||
| threadblock_shape.n() == threadblock_n_ && \ | |||||
| threadblock_shape.k() == threadblock_k_ && \ | |||||
| warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||||
| warp_shape.k() == warp_k_) { \ | |||||
| using ThreadBlockShape = \ | |||||
| cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||||
| threadblock_k_>; \ | |||||
| using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \ | |||||
| using Convolution = cutlass::conv::device::Convolution< \ | |||||
| cutlass::int4b_t, cutlass::layout::TensorNCxHWx<64>, \ | |||||
| cutlass::int4b_t, cutlass::layout::TensorCxRSKx<64>, \ | |||||
| ElementOutput, cutlass::layout::TensorNCxHWx<64>, int32_t, \ | |||||
| cutlass::layout::TensorNCxHWx<64>, int32_t, \ | |||||
| cutlass::conv::ConvType::kConvolution, \ | |||||
| cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | |||||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||||
| cutlass::conv::threadblock:: \ | |||||
| ConvolutionFpropNCxHWxThreadblockSwizzle, \ | |||||
| 2, 32, 32, NeedLoadFromConstMem>; \ | |||||
| typename Convolution::ConvolutionParameter conv_param( \ | |||||
| param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||||
| param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||||
| param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||||
| return cutlass_convolution_wrapper<Convolution>( \ | |||||
| reinterpret_cast<const cutlass::int4b_t*>(d_src), \ | |||||
| reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \ | |||||
| reinterpret_cast<const cutlass::int4b_t*>(d_z), \ | |||||
| reinterpret_cast<cutlass::int4b_t*>(d_dst), workspace, \ | |||||
| conv_param, epilogue, stream); \ | |||||
| } | |||||
| #define DISPATCH_KERNEL \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 128, 64, 64, 128); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 128, 64, 64, 128); \ | |||||
| megdnn_assert(false, \ | |||||
| "unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||||
| "(%dx%dx%d)", \ | |||||
| threadblock_shape.m(), threadblock_shape.n(), \ | |||||
| threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||||
| warp_shape.k()); | |||||
| using ElementOutput = cutlass::int4b_t; | |||||
| using ElementAccumulator = int32_t; | |||||
| using ElementBias = int32_t; | |||||
| using ElementCompute = float; | |||||
| using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; | |||||
| switch (nonlinear_mode) { | |||||
| case NonlineMode::IDENTITY: { | |||||
| using EpilogueOp = | |||||
| cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||||
| ElementOutput, 16, ElementAccumulator, ElementBias, | |||||
| ElementCompute>; | |||||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma}; | |||||
| DISPATCH_KERNEL; | |||||
| } | |||||
| case NonlineMode::RELU: { | |||||
| using EpilogueOp = cutlass::epilogue::thread:: | |||||
| BiasAddLinearCombinationReluClamp< | |||||
| ElementOutput, 16, ElementAccumulator, ElementBias, | |||||
| ElementCompute>; | |||||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0}; | |||||
| DISPATCH_KERNEL; | |||||
| } | |||||
| case NonlineMode::H_SWISH: { | |||||
| using EpilogueOp = cutlass::epilogue::thread:: | |||||
| BiasAddLinearCombinationHSwishClamp< | |||||
| ElementOutput, 16, ElementAccumulator, ElementBias, | |||||
| ElementCompute>; | |||||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale}; | |||||
| DISPATCH_KERNEL; | |||||
| } | |||||
| default: | |||||
| megdnn_assert(false, | |||||
| "unsupported nonlinear mode for conv bias operator"); | |||||
| } | |||||
| #undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||||
| #undef DISPATCH_KERNEL | |||||
| } | |||||
| #endif | |||||
| #define INST(need_load_from_const_mem) \ | |||||
| template void megdnn::cuda::cutlass_wrapper:: \ | |||||
| do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64< \ | |||||
| need_load_from_const_mem>( \ | |||||
| const int8_t* d_src, const int8_t* d_filter, \ | |||||
| const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \ | |||||
| int* workspace, const convolution::ConvParam& param, \ | |||||
| uint32_t nonlinear_mode, float alpha, float beta, \ | |||||
| float gamma, float scale, \ | |||||
| const GemmCoord& threadblock_shape, \ | |||||
| const GemmCoord& warp_shape, cudaStream_t stream); | |||||
| INST(true); | |||||
| #undef INST | |||||
| // vim: syntax=cuda.doxygen | // vim: syntax=cuda.doxygen | ||||
| @@ -76,6 +76,15 @@ void do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32( | |||||
| const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | ||||
| int stages, cudaStream_t stream); | int stages, cudaStream_t stream); | ||||
| template <bool NeedLoadFromConstMem> | |||||
| void do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64( | |||||
| const int8_t* d_src, const int8_t* d_filter, const int32_t* d_bias, | |||||
| const int8_t* d_z, int8_t* d_dst, int* workspace, | |||||
| const convolution::ConvParam& param, uint32_t nonlinear_mode, | |||||
| float alpha, float beta, float gamma, float scale, | |||||
| const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | |||||
| cudaStream_t stream); | |||||
| } // namespace cutlass_wrapper | } // namespace cutlass_wrapper | ||||
| } // namespace cuda | } // namespace cuda | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| @@ -0,0 +1,209 @@ | |||||
| /** | |||||
| * \file dnn/src/cuda/conv_bias/implicit_gemm_int4_nchw64_imma.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "./algo.h" | |||||
| #include "src/common/conv_bias.h" | |||||
| #include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | |||||
| #include "src/cuda/convolution_helper/parameter.cuh" | |||||
| #include "src/cuda/utils.h" | |||||
| using namespace megdnn; | |||||
| using namespace cuda; | |||||
| using namespace convolution; | |||||
| #if CUDA_VERSION >= 10020 | |||||
| bool ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm::is_available( | |||||
| const SizeArgs& args) const { | |||||
| if (args.bias_layout->ndim <= 0) | |||||
| return false; | |||||
| using Param = param::ConvBias; | |||||
| using Format = Param::Format; | |||||
| using Sparse = Param::Sparse; | |||||
| using Mode = Param::Mode; | |||||
| using NonlineMode = megdnn::param::ConvBias::NonlineMode; | |||||
| auto&& param = args.opr->param(); | |||||
| if (!check_bias_share_in_channel(*(args.bias_layout), param.format)) | |||||
| return false; | |||||
| if (param.format != Format::NCHW64 || param.sparse != Sparse::DENSE || | |||||
| param.mode != Mode::CROSS_CORRELATION) | |||||
| return false; | |||||
| if (param.nonlineMode != NonlineMode::IDENTITY && | |||||
| param.nonlineMode != NonlineMode::RELU && | |||||
| param.nonlineMode != NonlineMode::H_SWISH) | |||||
| return false; | |||||
| if (args.src_layout->dtype.enumv() != DTypeEnum::QuantizedS4 || | |||||
| args.filter_layout->dtype.enumv() != DTypeEnum::QuantizedS4 || | |||||
| args.bias_layout->dtype.enumv() != DTypeEnum::QuantizedS32 || | |||||
| args.dst_layout->dtype.enumv() != DTypeEnum::QuantizedS4) | |||||
| return false; | |||||
| if (!is_compute_capability_required(7, 5)) | |||||
| return false; | |||||
| return true; | |||||
| } | |||||
| WorkspaceBundle | |||||
| ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm::get_workspace_bundle( | |||||
| dt_byte* raw_ptr, const SizeArgs& args) const { | |||||
| if (args.preprocessed_filter) { | |||||
| return WorkspaceBundle{raw_ptr, {}}; | |||||
| } else { | |||||
| size_t ws_filter = args.filter_layout->span().dist_byte(); | |||||
| return WorkspaceBundle{raw_ptr, {ws_filter}}; | |||||
| } | |||||
| } | |||||
| size_t | |||||
| ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm::get_workspace_in_bytes( | |||||
| const SizeArgs& args) const { | |||||
| return get_workspace_bundle(nullptr, args).total_size_in_bytes(); | |||||
| } | |||||
| void ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm::exec( | |||||
| const ExecArgs& args) const { | |||||
| using Format = Param::Format; | |||||
| auto&& param = args.opr->param(); | |||||
| auto&& fm = args.filter_meta; | |||||
| size_t n = args.src_layout->operator[](0), | |||||
| ci = args.src_layout->operator[](1) * 64, | |||||
| hi = args.src_layout->operator[](2), | |||||
| wi = args.src_layout->operator[](3); | |||||
| size_t co = args.dst_layout->operator[](1) * 64, | |||||
| ho = args.dst_layout->operator[](2), | |||||
| wo = args.dst_layout->operator[](3); | |||||
| UNPACK_CONV_PARAMETER(fm, param); | |||||
| MARK_USED_VAR | |||||
| auto&& stream = cuda_stream(args.opr->handle()); | |||||
| int8_t* filter_ptr = nullptr; | |||||
| if (args.preprocessed_filter == nullptr) { | |||||
| filter_ptr = reinterpret_cast<int8_t*>(args.workspace.raw_ptr); | |||||
| // reformat filter from nchw64 to chwn64 | |||||
| TensorLayout src{{co, ci / 64, fh, fw, 64}, dtype::QuantizedS4()}; | |||||
| src.init_contiguous_stride(); | |||||
| TensorLayout dst = src; | |||||
| dst.stride[0] = 64; | |||||
| dst.stride[1] = co * fh * fw * 64; | |||||
| dst.stride[2] = co * fw * 64; | |||||
| dst.stride[3] = co * 64; | |||||
| dst.stride[4] = 1; | |||||
| TensorND ts_src, ts_dst; | |||||
| ts_src.raw_ptr = args.filter_tensor->raw_ptr; | |||||
| ts_src.layout = src; | |||||
| ts_dst.raw_ptr = args.workspace.raw_ptr; | |||||
| ts_dst.layout = dst; | |||||
| auto&& transpose = | |||||
| args.opr->handle()->create_operator<RelayoutForward>(); | |||||
| transpose->exec(ts_src, ts_dst); | |||||
| } else { | |||||
| filter_ptr = reinterpret_cast<int8_t*>( | |||||
| args.preprocessed_filter->tensors[0].raw_ptr); | |||||
| } | |||||
| ConvParam kern_param; | |||||
| kern_param.n = n, kern_param.co = co, kern_param.ci = ci, | |||||
| kern_param.hi = hi, kern_param.wi = wi, kern_param.ho = ho, | |||||
| kern_param.wo = wo, kern_param.ph = ph, kern_param.pw = pw, | |||||
| kern_param.sh = sh, kern_param.sw = sw, kern_param.fh = fh, | |||||
| kern_param.fw = fw; | |||||
| float src_scale = args.src_layout->dtype.param<dtype::QuantizedS4>().scale, | |||||
| filter_scale = | |||||
| args.filter_layout->dtype.param<dtype::QuantizedS4>().scale, | |||||
| bias_scale = | |||||
| args.bias_layout->dtype.param<dtype::QuantizedS32>().scale, | |||||
| dst_scale = args.dst_layout->dtype.param<dtype::QuantizedS4>().scale; | |||||
| float alpha = src_scale * filter_scale / dst_scale, | |||||
| beta = bias_scale / dst_scale; | |||||
| int8_t* z_dev_ptr = nullptr; | |||||
| float gamma = 0.f; | |||||
| if (args.z_layout->ndim > 0) { | |||||
| z_dev_ptr = reinterpret_cast<int8_t*>(args.z_tensor->raw_ptr); | |||||
| float z_scale = args.z_layout->dtype.param<dtype::QuantizedS4>().scale; | |||||
| gamma = z_scale / dst_scale; | |||||
| } | |||||
| uint32_t nonlinear_mode = static_cast<uint32_t>(param.nonlineMode); | |||||
| cutlass_wrapper::do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64< | |||||
| true>( | |||||
| reinterpret_cast<int8_t*>(args.src_tensor->raw_ptr), filter_ptr, | |||||
| args.bias_tensor->compatible_ptr<int32_t>(), z_dev_ptr, | |||||
| reinterpret_cast<int8_t*>(args.dst_tensor->raw_ptr), nullptr, | |||||
| kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale, | |||||
| cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, | |||||
| m_algo_param.threadblock_n, | |||||
| m_algo_param.threadblock_k}, | |||||
| cutlass_wrapper::GemmCoord{m_algo_param.warp_m, m_algo_param.warp_n, | |||||
| m_algo_param.warp_k}, | |||||
| stream); | |||||
| } | |||||
| std::string ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm::to_string( | |||||
| AlgoParam algo_param) { | |||||
| return ssprintf("%uX%uX%u_%uX%uX%u", algo_param.threadblock_m, | |||||
| algo_param.threadblock_n, algo_param.threadblock_k, | |||||
| algo_param.warp_m, algo_param.warp_n, algo_param.warp_k); | |||||
| } | |||||
| size_t ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm:: | |||||
| get_preprocess_workspace_in_bytes(const SizeArgs& args) const { | |||||
| return 0_z; | |||||
| } | |||||
| SmallVector<TensorLayout> ConvBiasForwardImpl:: | |||||
| AlgoInt4Int4NCHW64IMMAImplicitGemm::deduce_preprocessed_filter_layout( | |||||
| const SizeArgs& args) const { | |||||
| return {args.filter_layout->collapse_contiguous()}; | |||||
| } | |||||
| void ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm::exec_preprocess( | |||||
| const ExecArgs& args) const { | |||||
| auto&& param = args.opr->param(); | |||||
| auto&& fm = args.filter_meta; | |||||
| size_t n = args.src_layout->operator[](0), | |||||
| ci = args.src_layout->operator[](1) * 64, | |||||
| hi = args.src_layout->operator[](2), | |||||
| wi = args.src_layout->operator[](3); | |||||
| size_t co = args.dst_layout->operator[](1) * 64, | |||||
| ho = args.dst_layout->operator[](2), | |||||
| wo = args.dst_layout->operator[](3); | |||||
| UNPACK_CONV_PARAMETER(fm, param); | |||||
| MARK_USED_VAR | |||||
| TensorLayout src{{co, ci / 64, fh, fw, 64}, dtype::QuantizedS4()}; | |||||
| src.init_contiguous_stride(); | |||||
| TensorLayout dst = src; | |||||
| dst.stride[0] = 64; | |||||
| dst.stride[1] = co * fh * fw * 64; | |||||
| dst.stride[2] = co * fw * 64; | |||||
| dst.stride[3] = co * 64; | |||||
| dst.stride[4] = 1; | |||||
| TensorND ts_src, ts_dst; | |||||
| ts_src.raw_ptr = args.filter_tensor->raw_ptr; | |||||
| ts_src.layout = src; | |||||
| ts_dst.raw_ptr = args.preprocessed_filter->tensors[0].raw_ptr; | |||||
| ts_dst.layout = dst; | |||||
| auto&& transpose = args.opr->handle()->create_operator<RelayoutForward>(); | |||||
| transpose->exec(ts_src, ts_dst); | |||||
| } | |||||
| #endif | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -0,0 +1 @@ | |||||
| ../int8/conv_bias_int8_implicit_gemm_cutlass_wrapper.cuinl | |||||
| @@ -0,0 +1,36 @@ | |||||
| #if !MEGDNN_TEGRA_X1 | |||||
| // generated by gen_cuda_conv_bias_int4_kern_impls.py | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/conv_bias/int4/conv_bias_int4_implicit_gemm_cutlass_wrapper.cuinl" | |||||
| using LayoutSrc = cutlass::layout::TensorNCxHWx<64>; | |||||
| using LayoutFilter = cutlass::layout::TensorCxRSKx<64>; | |||||
| using LayoutDst = cutlass::layout::TensorNCxHWx<64>; | |||||
| using ThreadBlockShape = cutlass::gemm::GemmShape<128, 128, 128>; | |||||
| using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; | |||||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; | |||||
| using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationHSwishClamp< | |||||
| cutlass::int4b_t, 16, int32_t, int32_t, float>; | |||||
| using Convolution = cutlass::conv::device::Convolution< | |||||
| cutlass::int4b_t, LayoutSrc, cutlass::int4b_t, LayoutFilter, cutlass::int4b_t, | |||||
| LayoutDst, int32_t, LayoutDst, int32_t, | |||||
| cutlass::conv::ConvType::kConvolution, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, | |||||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, | |||||
| cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle, | |||||
| 2, 32, 32, true, | |||||
| cutlass::arch::OpMultiplyAddSaturate>; | |||||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||||
| const typename Convolution::ElementSrc* d_src, | |||||
| const typename Convolution::ElementFilter* d_filter, | |||||
| const typename Convolution::ElementBias* d_bias, | |||||
| const typename Convolution::ElementDst* d_z, | |||||
| typename Convolution::ElementDst* d_dst, | |||||
| int* workspace, | |||||
| typename Convolution::ConvolutionParameter const& conv_param, | |||||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,36 @@ | |||||
| #if !MEGDNN_TEGRA_X1 | |||||
| // generated by gen_cuda_conv_bias_int4_kern_impls.py | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/conv_bias/int4/conv_bias_int4_implicit_gemm_cutlass_wrapper.cuinl" | |||||
| using LayoutSrc = cutlass::layout::TensorNCxHWx<64>; | |||||
| using LayoutFilter = cutlass::layout::TensorCxRSKx<64>; | |||||
| using LayoutDst = cutlass::layout::TensorNCxHWx<64>; | |||||
| using ThreadBlockShape = cutlass::gemm::GemmShape<128, 128, 128>; | |||||
| using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; | |||||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; | |||||
| using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||||
| cutlass::int4b_t, 16, int32_t, int32_t, float>; | |||||
| using Convolution = cutlass::conv::device::Convolution< | |||||
| cutlass::int4b_t, LayoutSrc, cutlass::int4b_t, LayoutFilter, cutlass::int4b_t, | |||||
| LayoutDst, int32_t, LayoutDst, int32_t, | |||||
| cutlass::conv::ConvType::kConvolution, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, | |||||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, | |||||
| cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle, | |||||
| 2, 32, 32, true, | |||||
| cutlass::arch::OpMultiplyAddSaturate>; | |||||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||||
| const typename Convolution::ElementSrc* d_src, | |||||
| const typename Convolution::ElementFilter* d_filter, | |||||
| const typename Convolution::ElementBias* d_bias, | |||||
| const typename Convolution::ElementDst* d_z, | |||||
| typename Convolution::ElementDst* d_dst, | |||||
| int* workspace, | |||||
| typename Convolution::ConvolutionParameter const& conv_param, | |||||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,36 @@ | |||||
| #if !MEGDNN_TEGRA_X1 | |||||
| // generated by gen_cuda_conv_bias_int4_kern_impls.py | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/conv_bias/int4/conv_bias_int4_implicit_gemm_cutlass_wrapper.cuinl" | |||||
| using LayoutSrc = cutlass::layout::TensorNCxHWx<64>; | |||||
| using LayoutFilter = cutlass::layout::TensorCxRSKx<64>; | |||||
| using LayoutDst = cutlass::layout::TensorNCxHWx<64>; | |||||
| using ThreadBlockShape = cutlass::gemm::GemmShape<128, 128, 128>; | |||||
| using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; | |||||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; | |||||
| using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationReluClamp< | |||||
| cutlass::int4b_t, 16, int32_t, int32_t, float>; | |||||
| using Convolution = cutlass::conv::device::Convolution< | |||||
| cutlass::int4b_t, LayoutSrc, cutlass::int4b_t, LayoutFilter, cutlass::int4b_t, | |||||
| LayoutDst, int32_t, LayoutDst, int32_t, | |||||
| cutlass::conv::ConvType::kConvolution, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, | |||||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, | |||||
| cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle, | |||||
| 2, 32, 32, true, | |||||
| cutlass::arch::OpMultiplyAddSaturate>; | |||||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||||
| const typename Convolution::ElementSrc* d_src, | |||||
| const typename Convolution::ElementFilter* d_filter, | |||||
| const typename Convolution::ElementBias* d_bias, | |||||
| const typename Convolution::ElementDst* d_z, | |||||
| typename Convolution::ElementDst* d_dst, | |||||
| int* workspace, | |||||
| typename Convolution::ConvolutionParameter const& conv_param, | |||||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,36 @@ | |||||
| #if !MEGDNN_TEGRA_X1 | |||||
| // generated by gen_cuda_conv_bias_int4_kern_impls.py | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/conv_bias/int4/conv_bias_int4_implicit_gemm_cutlass_wrapper.cuinl" | |||||
| using LayoutSrc = cutlass::layout::TensorNCxHWx<64>; | |||||
| using LayoutFilter = cutlass::layout::TensorCxRSKx<64>; | |||||
| using LayoutDst = cutlass::layout::TensorNCxHWx<64>; | |||||
| using ThreadBlockShape = cutlass::gemm::GemmShape<256, 128, 128>; | |||||
| using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; | |||||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; | |||||
| using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationHSwishClamp< | |||||
| cutlass::int4b_t, 16, int32_t, int32_t, float>; | |||||
| using Convolution = cutlass::conv::device::Convolution< | |||||
| cutlass::int4b_t, LayoutSrc, cutlass::int4b_t, LayoutFilter, cutlass::int4b_t, | |||||
| LayoutDst, int32_t, LayoutDst, int32_t, | |||||
| cutlass::conv::ConvType::kConvolution, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, | |||||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, | |||||
| cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle, | |||||
| 2, 32, 32, true, | |||||
| cutlass::arch::OpMultiplyAddSaturate>; | |||||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||||
| const typename Convolution::ElementSrc* d_src, | |||||
| const typename Convolution::ElementFilter* d_filter, | |||||
| const typename Convolution::ElementBias* d_bias, | |||||
| const typename Convolution::ElementDst* d_z, | |||||
| typename Convolution::ElementDst* d_dst, | |||||
| int* workspace, | |||||
| typename Convolution::ConvolutionParameter const& conv_param, | |||||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,36 @@ | |||||
| #if !MEGDNN_TEGRA_X1 | |||||
| // generated by gen_cuda_conv_bias_int4_kern_impls.py | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/conv_bias/int4/conv_bias_int4_implicit_gemm_cutlass_wrapper.cuinl" | |||||
| using LayoutSrc = cutlass::layout::TensorNCxHWx<64>; | |||||
| using LayoutFilter = cutlass::layout::TensorCxRSKx<64>; | |||||
| using LayoutDst = cutlass::layout::TensorNCxHWx<64>; | |||||
| using ThreadBlockShape = cutlass::gemm::GemmShape<256, 128, 128>; | |||||
| using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; | |||||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; | |||||
| using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||||
| cutlass::int4b_t, 16, int32_t, int32_t, float>; | |||||
| using Convolution = cutlass::conv::device::Convolution< | |||||
| cutlass::int4b_t, LayoutSrc, cutlass::int4b_t, LayoutFilter, cutlass::int4b_t, | |||||
| LayoutDst, int32_t, LayoutDst, int32_t, | |||||
| cutlass::conv::ConvType::kConvolution, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, | |||||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, | |||||
| cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle, | |||||
| 2, 32, 32, true, | |||||
| cutlass::arch::OpMultiplyAddSaturate>; | |||||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||||
| const typename Convolution::ElementSrc* d_src, | |||||
| const typename Convolution::ElementFilter* d_filter, | |||||
| const typename Convolution::ElementBias* d_bias, | |||||
| const typename Convolution::ElementDst* d_z, | |||||
| typename Convolution::ElementDst* d_dst, | |||||
| int* workspace, | |||||
| typename Convolution::ConvolutionParameter const& conv_param, | |||||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,36 @@ | |||||
| #if !MEGDNN_TEGRA_X1 | |||||
| // generated by gen_cuda_conv_bias_int4_kern_impls.py | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/conv_bias/int4/conv_bias_int4_implicit_gemm_cutlass_wrapper.cuinl" | |||||
| using LayoutSrc = cutlass::layout::TensorNCxHWx<64>; | |||||
| using LayoutFilter = cutlass::layout::TensorCxRSKx<64>; | |||||
| using LayoutDst = cutlass::layout::TensorNCxHWx<64>; | |||||
| using ThreadBlockShape = cutlass::gemm::GemmShape<256, 128, 128>; | |||||
| using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; | |||||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; | |||||
| using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationReluClamp< | |||||
| cutlass::int4b_t, 16, int32_t, int32_t, float>; | |||||
| using Convolution = cutlass::conv::device::Convolution< | |||||
| cutlass::int4b_t, LayoutSrc, cutlass::int4b_t, LayoutFilter, cutlass::int4b_t, | |||||
| LayoutDst, int32_t, LayoutDst, int32_t, | |||||
| cutlass::conv::ConvType::kConvolution, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, | |||||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, | |||||
| cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle, | |||||
| 2, 32, 32, true, | |||||
| cutlass::arch::OpMultiplyAddSaturate>; | |||||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||||
| const typename Convolution::ElementSrc* d_src, | |||||
| const typename Convolution::ElementFilter* d_filter, | |||||
| const typename Convolution::ElementBias* d_bias, | |||||
| const typename Convolution::ElementDst* d_z, | |||||
| typename Convolution::ElementDst* d_dst, | |||||
| int* workspace, | |||||
| typename Convolution::ConvolutionParameter const& conv_param, | |||||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -64,6 +64,7 @@ public: | |||||
| class AlgoInt8CHWN4IMMAImplicitGemmReorderFilter; | class AlgoInt8CHWN4IMMAImplicitGemmReorderFilter; | ||||
| class AlgoInt8CHWN4IMMAImplicitGemmUnrollWidth; | class AlgoInt8CHWN4IMMAImplicitGemmUnrollWidth; | ||||
| class AlgoInt8NCHW32IMMAImplicitGemm; | class AlgoInt8NCHW32IMMAImplicitGemm; | ||||
| class AlgoInt4Int4NCHW64IMMAImplicitGemm; | |||||
| class AlgoBFloat16; | class AlgoBFloat16; | ||||
| class AlgoPack; | class AlgoPack; | ||||
| @@ -689,7 +689,7 @@ TEST_F(CUDA, CONV_BIAS_INT8_CHWN4_UNROLL_WIDTH_TENSORCORE_1x1_ALGO_2) { | |||||
| } | } | ||||
| TEST_F(CUDA, CUTLASS_WEIGHT_PREPROCESS) { | |||||
| TEST_F(CUDA, CUTLASS_INT8_WEIGHT_PREPROCESS) { | |||||
| require_compute_capability(6, 1); | require_compute_capability(6, 1); | ||||
| Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker( | Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker( | ||||
| handle_cuda()); | handle_cuda()); | ||||