GitOrigin-RevId: 9d6c48ed99
tags/v1.0.0-rc1
| @@ -57,6 +57,13 @@ add_dependencies(opr_param_defs _opr_param_defs) | |||
| install(TARGETS opr_param_defs EXPORT ${MGE_EXPORT_TARGETS}) | |||
| if(MGE_WITH_CUDA) | |||
| add_library(cutlass INTERFACE) | |||
| target_include_directories(cutlass | |||
| INTERFACE | |||
| $<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/third_party/cutlass/include>) | |||
| install(TARGETS cutlass EXPORT ${MGE_EXPORT_TARGETS}) | |||
| endif() | |||
| if(MGE_WITH_TEST) | |||
| if(NOT MGE_BUILD_IMPERATIVE_RT) | |||
| @@ -36,8 +36,9 @@ all: ${PARAM_DEFS} ${ELEMWISE_IMPL} ${CUDA_CONV_IMPL} | |||
| ../src/cuda/conv_bias/int8/kimpl: gen_cuda_conv_bias_kern_impls.py | |||
| ./$^ --type dp4a $@ | |||
| ../src/cuda/conv_bias/int8_imma/kimpl: gen_cuda_conv_bias_kern_impls.py | |||
| ./$^ --type imma $@ | |||
| ../src/cuda/conv_bias/int8_imma/kimpl: gen_cuda_conv_bias_kern_impls.py gen_cutlass_conv_bias_kern_impls.py | |||
| ./gen_cuda_conv_bias_kern_impls.py --type imma $@ | |||
| ./gen_cutlass_conv_bias_kern_impls.py --type imma $@ | |||
| ../src/cuda/batch_conv_bias/int8/kimpl: gen_cuda_batch_conv_bias_kern_impls.py | |||
| ./$^ --type dp4a $@ | |||
| @@ -51,6 +51,9 @@ add_definitions(${LIBMEGDNN_DEF}) | |||
| add_library(megdnn EXCLUDE_FROM_ALL OBJECT ${SOURCES}) | |||
| target_link_libraries(megdnn PUBLIC opr_param_defs) | |||
| if(MGE_WITH_CUDA) | |||
| target_link_libraries(megdnn PUBLIC cutlass) | |||
| endif() | |||
| if(${MGE_ARCH} STREQUAL "x86_64" OR ${MGE_ARCH} STREQUAL "i386" OR ${MGE_ARCH} STREQUAL "armv7" OR ${MGE_ARCH} STREQUAL "aarch64") | |||
| if(MGE_ENABLE_CPUINFO) | |||
| @@ -85,6 +85,11 @@ ConvBiasForwardImpl::AlgoPack::AlgoPack() { | |||
| for (auto&& algo : int8_chwn4_imma_unroll_width) { | |||
| all_algos.push_back(&algo); | |||
| } | |||
| #if CUDA_VERSION >= 10020 | |||
| for (auto&& algo : int8_nchw32_imma) { | |||
| all_algos.push_back(&algo); | |||
| } | |||
| #endif | |||
| #endif | |||
| all_algos.push_back(&int8_nchw4_dotprod); | |||
| all_algos.push_back(&int8_chwn4_dotprod); | |||
| @@ -233,6 +238,18 @@ void ConvBiasForwardImpl::AlgoPack::fill_imma_algos() { | |||
| int8_chwn4_imma_unroll_width.push_back( | |||
| {AlgoInt8CHWN4IMMAImplicitGemmUnrollWidth::MMATileSize:: | |||
| IMMA8x32x16}); | |||
| #if CUDA_VERSION >= 10020 | |||
| { | |||
| using AlgoParam = AlgoInt8NCHW32IMMAImplicitGemm::AlgoParam; | |||
| int8_nchw32_imma.emplace_back(AlgoParam{128, 256, 64, 64, 64, 64}); | |||
| int8_nchw32_imma.emplace_back(AlgoParam{256, 128, 64, 64, 64, 64}); | |||
| int8_nchw32_imma.emplace_back(AlgoParam{128, 128, 64, 64, 64, 64}); | |||
| int8_nchw32_imma.emplace_back(AlgoParam{64, 128, 64, 32, 64, 64}); | |||
| int8_nchw32_imma.emplace_back(AlgoParam{128, 64, 64, 64, 32, 64}); | |||
| int8_nchw32_imma.emplace_back(AlgoParam{64, 64, 64, 32, 32, 64}); | |||
| int8_nchw32_imma.emplace_back(AlgoParam{32, 64, 64, 32, 16, 64}); | |||
| } | |||
| #endif | |||
| } | |||
| #endif | |||
| @@ -499,6 +499,41 @@ private: | |||
| }; | |||
| #endif | |||
| #if CUDA_VERSION >= 10020 | |||
| class ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm final | |||
| : public AlgoBase { | |||
| public: | |||
| struct AlgoParam { | |||
| int threadblock_m; | |||
| int threadblock_n; | |||
| int threadblock_k; | |||
| int warp_m; | |||
| int warp_n; | |||
| int warp_k; | |||
| }; | |||
| AlgoInt8NCHW32IMMAImplicitGemm(AlgoParam algo_param) | |||
| : m_algo_param{algo_param} { | |||
| m_name = ConvBias::algo_name<ConvBias::DirectParam>( | |||
| ssprintf("INT8_NCHW32_IMMA_IMPLICIT_GEMM_%s", | |||
| to_string(m_algo_param).c_str()), | |||
| ConvBias::DirectParam{}); | |||
| } | |||
| bool is_available(const SizeArgs& args) const override; | |||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
| void exec(const ExecArgs& args) const override; | |||
| const char* name() const override { return m_name.c_str(); } | |||
| bool is_reproducible() const override { return true; } | |||
| static std::string to_string(AlgoParam algo_param); | |||
| private: | |||
| WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, | |||
| const SizeArgs& args) const; | |||
| AlgoParam m_algo_param; | |||
| std::string m_name; | |||
| }; | |||
| #endif | |||
| class ConvBiasForwardImpl::AlgoBFloat16 final : public AlgoBase { | |||
| public: | |||
| AlgoBFloat16(AlgoBase* impl); | |||
| @@ -553,6 +588,9 @@ public: | |||
| int8_chwn4_imma_reorder_filter; | |||
| std::vector<AlgoInt8CHWN4IMMAImplicitGemmUnrollWidth> | |||
| int8_chwn4_imma_unroll_width; | |||
| #endif | |||
| #if CUDA_VERSION >= 10020 | |||
| std::vector<AlgoInt8NCHW32IMMAImplicitGemm> int8_nchw32_imma; | |||
| #endif | |||
| std::vector<std::unique_ptr<AlgoGroupConvGeneral>> gconv_refhold; | |||
| std::vector<std::unique_ptr<AlgoBFloat16>> bfloat16_refhold; | |||
| @@ -142,4 +142,12 @@ void do_conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_unroll_width( | |||
| UNPACK_CONV_PARAMETER(_filter_meta, _param); \ | |||
| MARK_USED_VAR | |||
| #define UNPACK_CONV_BIAS_NCHW32_PARAM(_src, _filter_meta, _dst, _param) \ | |||
| using Format = param::ConvBias::Format; \ | |||
| megdnn_assert(_param.format == Format::NCHW32); \ | |||
| size_t n = (_src)[0], ci = (_src)[1] * 32, hi = (_src)[2], wi = (_src)[3]; \ | |||
| size_t co = (_dst)[1] * 32, ho = (_dst)[2], wo = (_dst)[3]; \ | |||
| UNPACK_CONV_PARAMETER(_filter_meta, _param); \ | |||
| MARK_USED_VAR | |||
| // vim: syntax=cuda.doxygen | |||
| @@ -0,0 +1,152 @@ | |||
| /** | |||
| * \file dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| // ignore warning of cutlass | |||
| #pragma GCC diagnostic push | |||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||
| #if !MEGDNN_TEGRA_X1 | |||
| #include "cutlass/convolution/device/convolution.h" | |||
| #endif | |||
| #include "src/common/opr_param_defs_enumv.cuh" | |||
| #include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | |||
| #pragma GCC diagnostic pop | |||
| using namespace megdnn; | |||
| using namespace cuda; | |||
| using namespace cutlass_wrapper; | |||
| #if MEGDNN_TEGRA_X1 | |||
| template <bool NeedLoadFromConstMem> | |||
| void megdnn::cuda::cutlass_wrapper:: | |||
| do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32( | |||
| const int8_t* /* d_src */, const int8_t* /* d_filter */, | |||
| const int32_t* /* d_bias */, const int8_t* /* d_z */, | |||
| int8_t* /* d_dst */, int* /* workspace */, | |||
| const convolution::ConvParam& /* param */, | |||
| uint32_t /* nonlinear_mode */, float /* alpha */, | |||
| float /* beta */, float /* gamma */, float /* scale */, | |||
| const GemmCoord& /* threadblock_shape */, | |||
| const GemmCoord& /* warp_shape */, cudaStream_t /* stream */) {} | |||
| #else | |||
| template <bool NeedLoadFromConstMem> | |||
| void megdnn::cuda::cutlass_wrapper:: | |||
| do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32( | |||
| const int8_t* d_src, const int8_t* d_filter, | |||
| const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, | |||
| int* workspace, const convolution::ConvParam& param, | |||
| uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||
| float scale, const GemmCoord& threadblock_shape, | |||
| const GemmCoord& warp_shape, cudaStream_t stream) { | |||
| #define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||
| threadblock_k_, warp_m_, warp_n_, \ | |||
| warp_k_) \ | |||
| if (threadblock_shape.m() == threadblock_m_ && \ | |||
| threadblock_shape.n() == threadblock_n_ && \ | |||
| threadblock_shape.k() == threadblock_k_ && \ | |||
| warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||
| warp_shape.k() == warp_k_) { \ | |||
| using ThreadBlockShape = \ | |||
| cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||
| threadblock_k_>; \ | |||
| using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; \ | |||
| using Convolution = cutlass::convolution::device::Convolution< \ | |||
| int8_t, cutlass::layout::TensorNCxHWx<32>, int8_t, \ | |||
| cutlass::layout::TensorCxRSKx<32>, ElementOutput, \ | |||
| cutlass::layout::TensorNCxHWx<32>, int32_t, \ | |||
| cutlass::layout::TensorNCxHWx<32>, int32_t, \ | |||
| cutlass::convolution::ConvType::kConvolution, \ | |||
| cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||
| cutlass::convolution::threadblock:: \ | |||
| ConvolutionNCxHWxThreadblockSwizzle< \ | |||
| cutlass::convolution::ConvType::kConvolution>, \ | |||
| 2, 16, 16, NeedLoadFromConstMem>; \ | |||
| typename Convolution::ConvolutionParameter conv_param{ \ | |||
| param.n, param.ci, param.co, param.hi, param.wi, \ | |||
| param.fh, param.fw, param.ho, param.wo, param.sh, \ | |||
| param.sw, param.ph, param.pw, 1, 1}; \ | |||
| return cutlass_convolution_wrapper<Convolution>( \ | |||
| d_src, d_filter, d_bias, d_z, d_dst, workspace, conv_param, \ | |||
| epilogue, stream); \ | |||
| } | |||
| #define DISPATCH_KERNEL \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 64, 64, 64, 64); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 64, 64, 64, 64); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 64, 64, 64, 64); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 64, 32, 64, 64); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 32, 64); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 64, 32, 32, 64); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 64, 32, 16, 64); \ | |||
| megdnn_assert(false, \ | |||
| "unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||
| "(%dx%dx%d)", \ | |||
| threadblock_shape.m(), threadblock_shape.n(), \ | |||
| threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||
| warp_shape.k()); | |||
| using ElementOutput = int8_t; | |||
| using ElementAccumulator = int32_t; | |||
| using ElementBias = int32_t; | |||
| using ElementCompute = float; | |||
| using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; | |||
| switch (nonlinear_mode) { | |||
| case NonlineMode::IDENTITY: { | |||
| using EpilogueOp = | |||
| cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||
| ElementOutput, 8, ElementAccumulator, ElementBias, | |||
| ElementCompute>; | |||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma}; | |||
| DISPATCH_KERNEL; | |||
| } | |||
| case NonlineMode::RELU: { | |||
| using EpilogueOp = cutlass::epilogue::thread:: | |||
| BiasAddLinearCombinationReluClamp< | |||
| ElementOutput, 8, ElementAccumulator, ElementBias, | |||
| ElementCompute>; | |||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0}; | |||
| DISPATCH_KERNEL; | |||
| } | |||
| case NonlineMode::H_SWISH: { | |||
| using EpilogueOp = cutlass::epilogue::thread:: | |||
| BiasAddLinearCombinationHSwishClamp< | |||
| ElementOutput, 8, ElementAccumulator, ElementBias, | |||
| ElementCompute>; | |||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale}; | |||
| DISPATCH_KERNEL; | |||
| } | |||
| default: | |||
| megdnn_assert(false, | |||
| "unsupported nonlinear mode for conv bias operator"); | |||
| } | |||
| #undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||
| #undef DISPATCH_KERNEL | |||
| } | |||
| #endif | |||
| #define INST(need_load_from_const_mem) \ | |||
| template void megdnn::cuda::cutlass_wrapper:: \ | |||
| do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32< \ | |||
| need_load_from_const_mem>( \ | |||
| const int8_t* d_src, const int8_t* d_filter, \ | |||
| const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \ | |||
| int* workspace, const convolution::ConvParam& param, \ | |||
| uint32_t nonlinear_mode, float alpha, float beta, \ | |||
| float gamma, float scale, \ | |||
| const GemmCoord& threadblock_shape, \ | |||
| const GemmCoord& warp_shape, cudaStream_t stream); | |||
| INST(true); | |||
| INST(false); | |||
| #undef INST | |||
| // vim: syntax=cuda.doxygen | |||
| @@ -0,0 +1,44 @@ | |||
| /** | |||
| * \file dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cuh | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include "cutlass/gemm/gemm.h" | |||
| #include "src/cuda/convolution_helper/parameter.cuh" | |||
| #include "src/cuda/utils.cuh" | |||
| namespace megdnn { | |||
| namespace cuda { | |||
| namespace cutlass_wrapper { | |||
| using GemmCoord = cutlass::gemm::GemmCoord; | |||
| template <typename Convolution> | |||
| void cutlass_convolution_wrapper( | |||
| const int8_t* d_src, const int8_t* d_filter, const int32_t* d_bias, | |||
| const int8_t* d_z, int8_t* d_dst, int* workspace, | |||
| typename Convolution::ConvolutionParameter const& conv_param, | |||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||
| cudaStream_t stream); | |||
| template <bool NeedLoadFromConstMem> | |||
| void do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32( | |||
| const int8_t* d_src, const int8_t* d_filter, const int32_t* d_bias, | |||
| const int8_t* d_z, int8_t* d_dst, int* workspace, | |||
| const convolution::ConvParam& param, uint32_t nonlinear_mode, | |||
| float alpha, float beta, float gamma, float scale, | |||
| const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | |||
| cudaStream_t stream); | |||
| } // namespace cutlass_wrapper | |||
| } // namespace cuda | |||
| } // namespace megdnn | |||
| // vim: syntax=cuda.doxygen | |||
| @@ -0,0 +1,188 @@ | |||
| /** | |||
| * \file dnn/src/cuda/conv_bias/implicit_gemm_int8_nchw32_imma.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "./algo.h" | |||
| #include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | |||
| #include "src/cuda/convolution_helper/parameter.cuh" | |||
| #include "src/cuda/utils.h" | |||
| using namespace megdnn; | |||
| using namespace cuda; | |||
| using namespace convolution; | |||
| #if CUDA_VERSION >= 10020 | |||
| bool ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::is_available( | |||
| const SizeArgs& args) const { | |||
| if (args.bias_layout->ndim <= 0) | |||
| return false; | |||
| using Param = param::ConvBias; | |||
| using Format = Param::Format; | |||
| using Sparse = Param::Sparse; | |||
| using Mode = Param::Mode; | |||
| bool available = true; | |||
| auto&& param = args.opr->param(); | |||
| auto&& fm = args.filter_meta; | |||
| if (!conv_bias::check_bias_share_in_channel(*(args.bias_layout), | |||
| param.format)) | |||
| return false; | |||
| if (param.format != Format::NCHW32) | |||
| return false; | |||
| UNPACK_CONV_BIAS_NCHW32_PARAM(*(args.src_layout), fm, *(args.dst_layout), | |||
| param); | |||
| // TODO support group conv | |||
| available &= param.sparse == Sparse::DENSE; | |||
| // mode must be cross correlation | |||
| available &= param.mode == Mode::CROSS_CORRELATION; | |||
| // check data type | |||
| auto src_dtype = args.src_layout->dtype, | |||
| filter_dtype = args.filter_layout->dtype, | |||
| bias_dtype = args.bias_layout->dtype, | |||
| dst_dtype = args.dst_layout->dtype; | |||
| available &= (src_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||
| filter_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||
| bias_dtype.enumv() == DTypeEnum::QuantizedS32 && | |||
| dst_dtype.enumv() == DTypeEnum::QuantizedS8); | |||
| // TODO: support dialtion | |||
| available &= dh == 1 && dw == 1; | |||
| // only support sm_75 or later, platform should have tensorcore int8 | |||
| // support | |||
| available &= is_compute_capability_required(7, 5); | |||
| if (fh == 1 && fw == 1) | |||
| return available; | |||
| // for non 1x1 convolution, we have to check constant memory size | |||
| auto&& device_prop = current_device_prop(); | |||
| // const mem size >= 64K | |||
| available &= device_prop.totalConstMem >= 65536; | |||
| size_t const_mem_usage = get_workspace_in_bytes(args) - | |||
| args.filter_layout->span().dist_byte(); | |||
| available &= const_mem_usage <= device_prop.totalConstMem; | |||
| return available; | |||
| } | |||
| WorkspaceBundle | |||
| ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::get_workspace_bundle( | |||
| dt_byte* raw_ptr, const SizeArgs& args) const { | |||
| size_t ci = args.filter_layout->operator[](1) * 32; | |||
| size_t fh = args.filter_layout->operator[](2); | |||
| size_t fw = args.filter_layout->operator[](3); | |||
| size_t ws_filter = args.filter_layout->span().dist_byte(); | |||
| if (fh == 1 && fw == 1) { | |||
| return WorkspaceBundle{raw_ptr, {ws_filter}}; | |||
| } | |||
| size_t ws_size = (ci / 32) * fh * fw * sizeof(int32_t) * 2; | |||
| return WorkspaceBundle{raw_ptr, {ws_filter, ws_size}}; | |||
| } | |||
| size_t | |||
| ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::get_workspace_in_bytes( | |||
| const SizeArgs& args) const { | |||
| return get_workspace_bundle(nullptr, args).total_size_in_bytes(); | |||
| } | |||
| void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec( | |||
| const ExecArgs& args) const { | |||
| using Format = Param::Format; | |||
| auto&& param = args.opr->param(); | |||
| auto&& fm = args.filter_meta; | |||
| UNPACK_CONV_BIAS_NCHW32_PARAM(*(args.src_layout), fm, *(args.dst_layout), | |||
| param); | |||
| auto ws = get_workspace_bundle(args.workspace.raw_ptr, args); | |||
| auto ws_filter = ws.get(0); | |||
| auto&& stream = cuda_stream(args.opr->handle()); | |||
| // reformat filter from nchw32 to chwn32 | |||
| { | |||
| TensorLayout src{{co, ci / 32, fh, fw, 32}, dtype::Int8()}; | |||
| src.init_contiguous_stride(); | |||
| TensorLayout dst = src; | |||
| dst.stride[0] = 32; | |||
| dst.stride[1] = co * fh * fw * 32; | |||
| dst.stride[2] = co * fw * 32; | |||
| dst.stride[3] = co * 32; | |||
| dst.stride[4] = 1; | |||
| TensorND ts_src, ts_dst; | |||
| ts_src.raw_ptr = args.filter_tensor->raw_ptr; | |||
| ts_src.layout = src; | |||
| ts_dst.raw_ptr = ws_filter; | |||
| ts_dst.layout = dst; | |||
| auto&& transpose = | |||
| args.opr->handle()->create_operator<RelayoutForward>(); | |||
| transpose->exec(ts_src, ts_dst); | |||
| } | |||
| ConvParam kern_param; | |||
| kern_param.n = n, kern_param.co = co, kern_param.ci = ci, | |||
| kern_param.hi = hi, kern_param.wi = wi, kern_param.ho = ho, | |||
| kern_param.wo = wo, kern_param.ph = ph, kern_param.pw = pw, | |||
| kern_param.sh = sh, kern_param.sw = sw, kern_param.fh = fh, | |||
| kern_param.fw = fw; | |||
| float src_scale = args.src_layout->dtype.param<dtype::QuantizedS8>().scale, | |||
| filter_scale = | |||
| args.filter_layout->dtype.param<dtype::QuantizedS8>().scale, | |||
| bias_scale = | |||
| args.bias_layout->dtype.param<dtype::QuantizedS32>().scale, | |||
| dst_scale = args.dst_layout->dtype.param<dtype::QuantizedS8>().scale; | |||
| float alpha = src_scale * filter_scale / dst_scale, | |||
| beta = bias_scale / dst_scale; | |||
| int8_t* z_dev_ptr = nullptr; | |||
| float gamma = 0.0; | |||
| if (args.z_layout->ndim > 0) { | |||
| z_dev_ptr = args.z_tensor->compatible_ptr<int8_t>(); | |||
| float z_scale = args.z_layout->dtype.param<dtype::QuantizedS8>().scale; | |||
| gamma = z_scale / dst_scale; | |||
| } | |||
| uint32_t nonlinear_mode = static_cast<uint32_t>(param.nonlineMode); | |||
| if (fh == 1 && fw == 1) { | |||
| cutlass_wrapper::do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32< | |||
| false>(args.src_tensor->compatible_ptr<int8_t>(), | |||
| reinterpret_cast<int8_t*>(ws_filter), | |||
| args.bias_tensor->compatible_ptr<int32_t>(), z_dev_ptr, | |||
| args.dst_tensor->compatible_ptr<int8_t>(), | |||
| nullptr, kern_param, nonlinear_mode, | |||
| alpha, beta, gamma, dst_scale, | |||
| cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, | |||
| m_algo_param.threadblock_n, | |||
| m_algo_param.threadblock_k}, | |||
| cutlass_wrapper::GemmCoord{m_algo_param.warp_m, | |||
| m_algo_param.warp_n, | |||
| m_algo_param.warp_k}, | |||
| stream); | |||
| } else { | |||
| auto workspace = ws.get(1); | |||
| cutlass_wrapper::do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32<true>( | |||
| args.src_tensor->compatible_ptr<int8_t>(), | |||
| reinterpret_cast<int8_t*>(ws_filter), | |||
| args.bias_tensor->compatible_ptr<int32_t>(), z_dev_ptr, | |||
| args.dst_tensor->compatible_ptr<int8_t>(), | |||
| reinterpret_cast<int*>(workspace), kern_param, nonlinear_mode, | |||
| alpha, beta, gamma, dst_scale, | |||
| cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, | |||
| m_algo_param.threadblock_n, | |||
| m_algo_param.threadblock_k}, | |||
| cutlass_wrapper::GemmCoord{m_algo_param.warp_m, | |||
| m_algo_param.warp_n, | |||
| m_algo_param.warp_k}, | |||
| stream); | |||
| } | |||
| } | |||
| std::string ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::to_string( | |||
| AlgoParam algo_param) { | |||
| return ssprintf("%uX%uX%u_%uX%uX%u", algo_param.threadblock_m, | |||
| algo_param.threadblock_n, algo_param.threadblock_k, | |||
| algo_param.warp_m, algo_param.warp_n, algo_param.warp_k); | |||
| } | |||
| #endif | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,57 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/cuda/conv_bias/int8_imma/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32.cuinl | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "cutlass/convolution/device/convolution.h" | |||
| #include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | |||
| using namespace megdnn; | |||
| using namespace cuda; | |||
| using namespace cutlass_wrapper; | |||
| template <typename Convolution> | |||
| void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper( | |||
| const int8_t* d_src, const int8_t* d_filter, const int32_t* d_bias, | |||
| const int8_t* d_z, int8_t* d_dst, int* workspace, | |||
| typename Convolution::ConvolutionParameter const& conv_param, | |||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||
| cudaStream_t stream) { | |||
| typename Convolution::TensorRefSrc tensor_src{ | |||
| const_cast<int8_t*>(d_src), | |||
| Convolution::LayoutSrc::packed({conv_param.n(), conv_param.hi(), | |||
| conv_param.wi(), conv_param.ci()})}; | |||
| typename Convolution::TensorRefFilter tensor_filter{ | |||
| const_cast<int8_t*>(d_filter), | |||
| Convolution::LayoutFilter::packed({conv_param.co(), conv_param.fh(), | |||
| conv_param.fw(), | |||
| conv_param.ci()})}; | |||
| typename Convolution::TensorRefBias tensor_bias{ | |||
| const_cast<int32_t*>(d_bias), | |||
| Convolution::LayoutBias::packed({1, 1, 1, conv_param.co()})}; | |||
| typename Convolution::TensorRefDst tensor_z{ | |||
| const_cast<int8_t*>(d_z), | |||
| Convolution::LayoutDst::packed({conv_param.n(), conv_param.ho(), | |||
| conv_param.wo(), conv_param.co()})}; | |||
| typename Convolution::TensorRefDst tensor_dst{ | |||
| d_dst, | |||
| Convolution::LayoutDst::packed({conv_param.n(), conv_param.ho(), | |||
| conv_param.wo(), conv_param.co()})}; | |||
| typename Convolution::Arguments arguments{ | |||
| conv_param, tensor_src, tensor_filter, | |||
| tensor_bias, tensor_z, tensor_dst.non_const_ref(), | |||
| epilogue}; | |||
| Convolution conv_op; | |||
| cutlass_check(conv_op.initialize(arguments, workspace)); | |||
| cutlass_check(conv_op(stream)); | |||
| after_kernel_launch(); | |||
| } | |||
| // vim: syntax=cuda.doxygen | |||
| @@ -0,0 +1,35 @@ | |||
| #if !MEGDNN_TEGRA_X1 | |||
| // generated by gen_cuda_conv_bias_kern_impls.py | |||
| // ignore warning of cutlass | |||
| #pragma GCC diagnostic push | |||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||
| #include "../conv_bias_int8_implicit_gemm_imma_ncdiv32hw32.cuinl" | |||
| using LayoutSrc = cutlass::layout::TensorNCxHWx<32>; | |||
| using LayoutFilter = cutlass::layout::TensorCxRSKx<32>; | |||
| using ThreadBlockShape = cutlass::gemm::GemmShape<128, 128, 64>; | |||
| using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; | |||
| using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationHSwishClamp< | |||
| int8_t, 8, int32_t, int32_t, float>; | |||
| using Convolution = cutlass::convolution::device::Convolution< | |||
| int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t, | |||
| LayoutSrc, int32_t, LayoutSrc, int32_t, | |||
| cutlass::convolution::ConvType::kConvolution, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, | |||
| cutlass::convolution::threadblock::ConvolutionNCxHWxThreadblockSwizzle< | |||
| cutlass::convolution::ConvType::kConvolution>, | |||
| 2, 16, 16, true>; | |||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||
| const int8_t* d_src, | |||
| const int8_t* d_filter, | |||
| const int32_t* d_bias, | |||
| const int8_t* d_z, | |||
| int8_t* d_dst, | |||
| int* workspace, | |||
| typename Convolution::ConvolutionParameter const& conv_param, | |||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||
| cudaStream_t stream); | |||
| #pragma GCC diagnostic pop | |||
| #endif | |||
| @@ -0,0 +1,35 @@ | |||
| #if !MEGDNN_TEGRA_X1 | |||
| // generated by gen_cuda_conv_bias_kern_impls.py | |||
| // ignore warning of cutlass | |||
| #pragma GCC diagnostic push | |||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||
| #include "../conv_bias_int8_implicit_gemm_imma_ncdiv32hw32.cuinl" | |||
| using LayoutSrc = cutlass::layout::TensorNCxHWx<32>; | |||
| using LayoutFilter = cutlass::layout::TensorCxRSKx<32>; | |||
| using ThreadBlockShape = cutlass::gemm::GemmShape<128, 128, 64>; | |||
| using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; | |||
| using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||
| int8_t, 8, int32_t, int32_t, float>; | |||
| using Convolution = cutlass::convolution::device::Convolution< | |||
| int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t, | |||
| LayoutSrc, int32_t, LayoutSrc, int32_t, | |||
| cutlass::convolution::ConvType::kConvolution, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, | |||
| cutlass::convolution::threadblock::ConvolutionNCxHWxThreadblockSwizzle< | |||
| cutlass::convolution::ConvType::kConvolution>, | |||
| 2, 16, 16, true>; | |||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||
| const int8_t* d_src, | |||
| const int8_t* d_filter, | |||
| const int32_t* d_bias, | |||
| const int8_t* d_z, | |||
| int8_t* d_dst, | |||
| int* workspace, | |||
| typename Convolution::ConvolutionParameter const& conv_param, | |||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||
| cudaStream_t stream); | |||
| #pragma GCC diagnostic pop | |||
| #endif | |||
| @@ -0,0 +1,35 @@ | |||
| #if !MEGDNN_TEGRA_X1 | |||
| // generated by gen_cuda_conv_bias_kern_impls.py | |||
| // ignore warning of cutlass | |||
| #pragma GCC diagnostic push | |||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||
| #include "../conv_bias_int8_implicit_gemm_imma_ncdiv32hw32.cuinl" | |||
| using LayoutSrc = cutlass::layout::TensorNCxHWx<32>; | |||
| using LayoutFilter = cutlass::layout::TensorCxRSKx<32>; | |||
| using ThreadBlockShape = cutlass::gemm::GemmShape<128, 128, 64>; | |||
| using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; | |||
| using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationReluClamp< | |||
| int8_t, 8, int32_t, int32_t, float>; | |||
| using Convolution = cutlass::convolution::device::Convolution< | |||
| int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t, | |||
| LayoutSrc, int32_t, LayoutSrc, int32_t, | |||
| cutlass::convolution::ConvType::kConvolution, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, | |||
| cutlass::convolution::threadblock::ConvolutionNCxHWxThreadblockSwizzle< | |||
| cutlass::convolution::ConvType::kConvolution>, | |||
| 2, 16, 16, true>; | |||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||
| const int8_t* d_src, | |||
| const int8_t* d_filter, | |||
| const int32_t* d_bias, | |||
| const int8_t* d_z, | |||
| int8_t* d_dst, | |||
| int* workspace, | |||
| typename Convolution::ConvolutionParameter const& conv_param, | |||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||
| cudaStream_t stream); | |||
| #pragma GCC diagnostic pop | |||
| #endif | |||
| @@ -0,0 +1,35 @@ | |||
| #if !MEGDNN_TEGRA_X1 | |||
| // generated by gen_cuda_conv_bias_kern_impls.py | |||
| // ignore warning of cutlass | |||
| #pragma GCC diagnostic push | |||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||
| #include "../conv_bias_int8_implicit_gemm_imma_ncdiv32hw32.cuinl" | |||
| using LayoutSrc = cutlass::layout::TensorNCxHWx<32>; | |||
| using LayoutFilter = cutlass::layout::TensorCxRSKx<32>; | |||
| using ThreadBlockShape = cutlass::gemm::GemmShape<128, 256, 64>; | |||
| using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; | |||
| using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationHSwishClamp< | |||
| int8_t, 8, int32_t, int32_t, float>; | |||
| using Convolution = cutlass::convolution::device::Convolution< | |||
| int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t, | |||
| LayoutSrc, int32_t, LayoutSrc, int32_t, | |||
| cutlass::convolution::ConvType::kConvolution, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, | |||
| cutlass::convolution::threadblock::ConvolutionNCxHWxThreadblockSwizzle< | |||
| cutlass::convolution::ConvType::kConvolution>, | |||
| 2, 16, 16, true>; | |||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||
| const int8_t* d_src, | |||
| const int8_t* d_filter, | |||
| const int32_t* d_bias, | |||
| const int8_t* d_z, | |||
| int8_t* d_dst, | |||
| int* workspace, | |||
| typename Convolution::ConvolutionParameter const& conv_param, | |||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||
| cudaStream_t stream); | |||
| #pragma GCC diagnostic pop | |||
| #endif | |||
| @@ -0,0 +1,35 @@ | |||
| #if !MEGDNN_TEGRA_X1 | |||
| // generated by gen_cuda_conv_bias_kern_impls.py | |||
| // ignore warning of cutlass | |||
| #pragma GCC diagnostic push | |||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||
| #include "../conv_bias_int8_implicit_gemm_imma_ncdiv32hw32.cuinl" | |||
| using LayoutSrc = cutlass::layout::TensorNCxHWx<32>; | |||
| using LayoutFilter = cutlass::layout::TensorCxRSKx<32>; | |||
| using ThreadBlockShape = cutlass::gemm::GemmShape<128, 256, 64>; | |||
| using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; | |||
| using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||
| int8_t, 8, int32_t, int32_t, float>; | |||
| using Convolution = cutlass::convolution::device::Convolution< | |||
| int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t, | |||
| LayoutSrc, int32_t, LayoutSrc, int32_t, | |||
| cutlass::convolution::ConvType::kConvolution, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, | |||
| cutlass::convolution::threadblock::ConvolutionNCxHWxThreadblockSwizzle< | |||
| cutlass::convolution::ConvType::kConvolution>, | |||
| 2, 16, 16, true>; | |||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||
| const int8_t* d_src, | |||
| const int8_t* d_filter, | |||
| const int32_t* d_bias, | |||
| const int8_t* d_z, | |||
| int8_t* d_dst, | |||
| int* workspace, | |||
| typename Convolution::ConvolutionParameter const& conv_param, | |||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||
| cudaStream_t stream); | |||
| #pragma GCC diagnostic pop | |||
| #endif | |||
| @@ -0,0 +1,35 @@ | |||
| #if !MEGDNN_TEGRA_X1 | |||
| // generated by gen_cuda_conv_bias_kern_impls.py | |||
| // ignore warning of cutlass | |||
| #pragma GCC diagnostic push | |||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||
| #include "../conv_bias_int8_implicit_gemm_imma_ncdiv32hw32.cuinl" | |||
| using LayoutSrc = cutlass::layout::TensorNCxHWx<32>; | |||
| using LayoutFilter = cutlass::layout::TensorCxRSKx<32>; | |||
| using ThreadBlockShape = cutlass::gemm::GemmShape<128, 256, 64>; | |||
| using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; | |||
| using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationReluClamp< | |||
| int8_t, 8, int32_t, int32_t, float>; | |||
| using Convolution = cutlass::convolution::device::Convolution< | |||
| int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t, | |||
| LayoutSrc, int32_t, LayoutSrc, int32_t, | |||
| cutlass::convolution::ConvType::kConvolution, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, | |||
| cutlass::convolution::threadblock::ConvolutionNCxHWxThreadblockSwizzle< | |||
| cutlass::convolution::ConvType::kConvolution>, | |||
| 2, 16, 16, true>; | |||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||
| const int8_t* d_src, | |||
| const int8_t* d_filter, | |||
| const int32_t* d_bias, | |||
| const int8_t* d_z, | |||
| int8_t* d_dst, | |||
| int* workspace, | |||
| typename Convolution::ConvolutionParameter const& conv_param, | |||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||
| cudaStream_t stream); | |||
| #pragma GCC diagnostic pop | |||
| #endif | |||
| @@ -0,0 +1,35 @@ | |||
| #if !MEGDNN_TEGRA_X1 | |||
| // generated by gen_cuda_conv_bias_kern_impls.py | |||
| // ignore warning of cutlass | |||
| #pragma GCC diagnostic push | |||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||
| #include "../conv_bias_int8_implicit_gemm_imma_ncdiv32hw32.cuinl" | |||
| using LayoutSrc = cutlass::layout::TensorNCxHWx<32>; | |||
| using LayoutFilter = cutlass::layout::TensorCxRSKx<32>; | |||
| using ThreadBlockShape = cutlass::gemm::GemmShape<128, 64, 64>; | |||
| using WarpShape = cutlass::gemm::GemmShape<64, 32, 64>; | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; | |||
| using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationHSwishClamp< | |||
| int8_t, 8, int32_t, int32_t, float>; | |||
| using Convolution = cutlass::convolution::device::Convolution< | |||
| int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t, | |||
| LayoutSrc, int32_t, LayoutSrc, int32_t, | |||
| cutlass::convolution::ConvType::kConvolution, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, | |||
| cutlass::convolution::threadblock::ConvolutionNCxHWxThreadblockSwizzle< | |||
| cutlass::convolution::ConvType::kConvolution>, | |||
| 2, 16, 16, true>; | |||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||
| const int8_t* d_src, | |||
| const int8_t* d_filter, | |||
| const int32_t* d_bias, | |||
| const int8_t* d_z, | |||
| int8_t* d_dst, | |||
| int* workspace, | |||
| typename Convolution::ConvolutionParameter const& conv_param, | |||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||
| cudaStream_t stream); | |||
| #pragma GCC diagnostic pop | |||
| #endif | |||
| @@ -0,0 +1,35 @@ | |||
| #if !MEGDNN_TEGRA_X1 | |||
| // generated by gen_cuda_conv_bias_kern_impls.py | |||
| // ignore warning of cutlass | |||
| #pragma GCC diagnostic push | |||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||
| #include "../conv_bias_int8_implicit_gemm_imma_ncdiv32hw32.cuinl" | |||
| using LayoutSrc = cutlass::layout::TensorNCxHWx<32>; | |||
| using LayoutFilter = cutlass::layout::TensorCxRSKx<32>; | |||
| using ThreadBlockShape = cutlass::gemm::GemmShape<128, 64, 64>; | |||
| using WarpShape = cutlass::gemm::GemmShape<64, 32, 64>; | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; | |||
| using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||
| int8_t, 8, int32_t, int32_t, float>; | |||
| using Convolution = cutlass::convolution::device::Convolution< | |||
| int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t, | |||
| LayoutSrc, int32_t, LayoutSrc, int32_t, | |||
| cutlass::convolution::ConvType::kConvolution, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, | |||
| cutlass::convolution::threadblock::ConvolutionNCxHWxThreadblockSwizzle< | |||
| cutlass::convolution::ConvType::kConvolution>, | |||
| 2, 16, 16, true>; | |||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||
| const int8_t* d_src, | |||
| const int8_t* d_filter, | |||
| const int32_t* d_bias, | |||
| const int8_t* d_z, | |||
| int8_t* d_dst, | |||
| int* workspace, | |||
| typename Convolution::ConvolutionParameter const& conv_param, | |||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||
| cudaStream_t stream); | |||
| #pragma GCC diagnostic pop | |||
| #endif | |||
| @@ -0,0 +1,35 @@ | |||
| #if !MEGDNN_TEGRA_X1 | |||
| // generated by gen_cuda_conv_bias_kern_impls.py | |||
| // ignore warning of cutlass | |||
| #pragma GCC diagnostic push | |||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||
| #include "../conv_bias_int8_implicit_gemm_imma_ncdiv32hw32.cuinl" | |||
| using LayoutSrc = cutlass::layout::TensorNCxHWx<32>; | |||
| using LayoutFilter = cutlass::layout::TensorCxRSKx<32>; | |||
| using ThreadBlockShape = cutlass::gemm::GemmShape<128, 64, 64>; | |||
| using WarpShape = cutlass::gemm::GemmShape<64, 32, 64>; | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; | |||
| using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationReluClamp< | |||
| int8_t, 8, int32_t, int32_t, float>; | |||
| using Convolution = cutlass::convolution::device::Convolution< | |||
| int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t, | |||
| LayoutSrc, int32_t, LayoutSrc, int32_t, | |||
| cutlass::convolution::ConvType::kConvolution, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, | |||
| cutlass::convolution::threadblock::ConvolutionNCxHWxThreadblockSwizzle< | |||
| cutlass::convolution::ConvType::kConvolution>, | |||
| 2, 16, 16, true>; | |||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||
| const int8_t* d_src, | |||
| const int8_t* d_filter, | |||
| const int32_t* d_bias, | |||
| const int8_t* d_z, | |||
| int8_t* d_dst, | |||
| int* workspace, | |||
| typename Convolution::ConvolutionParameter const& conv_param, | |||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||
| cudaStream_t stream); | |||
| #pragma GCC diagnostic pop | |||
| #endif | |||
| @@ -0,0 +1,35 @@ | |||
| #if !MEGDNN_TEGRA_X1 | |||
| // generated by gen_cuda_conv_bias_kern_impls.py | |||
| // ignore warning of cutlass | |||
| #pragma GCC diagnostic push | |||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||
| #include "../conv_bias_int8_implicit_gemm_imma_ncdiv32hw32.cuinl" | |||
| using LayoutSrc = cutlass::layout::TensorNCxHWx<32>; | |||
| using LayoutFilter = cutlass::layout::TensorCxRSKx<32>; | |||
| using ThreadBlockShape = cutlass::gemm::GemmShape<128, 128, 64>; | |||
| using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; | |||
| using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationHSwishClamp< | |||
| int8_t, 8, int32_t, int32_t, float>; | |||
| using Convolution = cutlass::convolution::device::Convolution< | |||
| int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t, | |||
| LayoutSrc, int32_t, LayoutSrc, int32_t, | |||
| cutlass::convolution::ConvType::kConvolution, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, | |||
| cutlass::convolution::threadblock::ConvolutionNCxHWxThreadblockSwizzle< | |||
| cutlass::convolution::ConvType::kConvolution>, | |||
| 2, 16, 16, false>; | |||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||
| const int8_t* d_src, | |||
| const int8_t* d_filter, | |||
| const int32_t* d_bias, | |||
| const int8_t* d_z, | |||
| int8_t* d_dst, | |||
| int* workspace, | |||
| typename Convolution::ConvolutionParameter const& conv_param, | |||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||
| cudaStream_t stream); | |||
| #pragma GCC diagnostic pop | |||
| #endif | |||
| @@ -0,0 +1,35 @@ | |||
| #if !MEGDNN_TEGRA_X1 | |||
| // generated by gen_cuda_conv_bias_kern_impls.py | |||
| // ignore warning of cutlass | |||
| #pragma GCC diagnostic push | |||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||
| #include "../conv_bias_int8_implicit_gemm_imma_ncdiv32hw32.cuinl" | |||
| using LayoutSrc = cutlass::layout::TensorNCxHWx<32>; | |||
| using LayoutFilter = cutlass::layout::TensorCxRSKx<32>; | |||
| using ThreadBlockShape = cutlass::gemm::GemmShape<128, 128, 64>; | |||
| using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; | |||
| using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||
| int8_t, 8, int32_t, int32_t, float>; | |||
| using Convolution = cutlass::convolution::device::Convolution< | |||
| int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t, | |||
| LayoutSrc, int32_t, LayoutSrc, int32_t, | |||
| cutlass::convolution::ConvType::kConvolution, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, | |||
| cutlass::convolution::threadblock::ConvolutionNCxHWxThreadblockSwizzle< | |||
| cutlass::convolution::ConvType::kConvolution>, | |||
| 2, 16, 16, false>; | |||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||
| const int8_t* d_src, | |||
| const int8_t* d_filter, | |||
| const int32_t* d_bias, | |||
| const int8_t* d_z, | |||
| int8_t* d_dst, | |||
| int* workspace, | |||
| typename Convolution::ConvolutionParameter const& conv_param, | |||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||
| cudaStream_t stream); | |||
| #pragma GCC diagnostic pop | |||
| #endif | |||
| @@ -0,0 +1,35 @@ | |||
| #if !MEGDNN_TEGRA_X1 | |||
| // generated by gen_cuda_conv_bias_kern_impls.py | |||
| // ignore warning of cutlass | |||
| #pragma GCC diagnostic push | |||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||
| #include "../conv_bias_int8_implicit_gemm_imma_ncdiv32hw32.cuinl" | |||
| using LayoutSrc = cutlass::layout::TensorNCxHWx<32>; | |||
| using LayoutFilter = cutlass::layout::TensorCxRSKx<32>; | |||
| using ThreadBlockShape = cutlass::gemm::GemmShape<128, 128, 64>; | |||
| using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; | |||
| using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationReluClamp< | |||
| int8_t, 8, int32_t, int32_t, float>; | |||
| using Convolution = cutlass::convolution::device::Convolution< | |||
| int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t, | |||
| LayoutSrc, int32_t, LayoutSrc, int32_t, | |||
| cutlass::convolution::ConvType::kConvolution, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, | |||
| cutlass::convolution::threadblock::ConvolutionNCxHWxThreadblockSwizzle< | |||
| cutlass::convolution::ConvType::kConvolution>, | |||
| 2, 16, 16, false>; | |||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||
| const int8_t* d_src, | |||
| const int8_t* d_filter, | |||
| const int32_t* d_bias, | |||
| const int8_t* d_z, | |||
| int8_t* d_dst, | |||
| int* workspace, | |||
| typename Convolution::ConvolutionParameter const& conv_param, | |||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||
| cudaStream_t stream); | |||
| #pragma GCC diagnostic pop | |||
| #endif | |||
| @@ -0,0 +1,35 @@ | |||
| #if !MEGDNN_TEGRA_X1 | |||
| // generated by gen_cuda_conv_bias_kern_impls.py | |||
| // ignore warning of cutlass | |||
| #pragma GCC diagnostic push | |||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||
| #include "../conv_bias_int8_implicit_gemm_imma_ncdiv32hw32.cuinl" | |||
| using LayoutSrc = cutlass::layout::TensorNCxHWx<32>; | |||
| using LayoutFilter = cutlass::layout::TensorCxRSKx<32>; | |||
| using ThreadBlockShape = cutlass::gemm::GemmShape<128, 256, 64>; | |||
| using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; | |||
| using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationHSwishClamp< | |||
| int8_t, 8, int32_t, int32_t, float>; | |||
| using Convolution = cutlass::convolution::device::Convolution< | |||
| int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t, | |||
| LayoutSrc, int32_t, LayoutSrc, int32_t, | |||
| cutlass::convolution::ConvType::kConvolution, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, | |||
| cutlass::convolution::threadblock::ConvolutionNCxHWxThreadblockSwizzle< | |||
| cutlass::convolution::ConvType::kConvolution>, | |||
| 2, 16, 16, false>; | |||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||
| const int8_t* d_src, | |||
| const int8_t* d_filter, | |||
| const int32_t* d_bias, | |||
| const int8_t* d_z, | |||
| int8_t* d_dst, | |||
| int* workspace, | |||
| typename Convolution::ConvolutionParameter const& conv_param, | |||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||
| cudaStream_t stream); | |||
| #pragma GCC diagnostic pop | |||
| #endif | |||
| @@ -0,0 +1,35 @@ | |||
| #if !MEGDNN_TEGRA_X1 | |||
| // generated by gen_cuda_conv_bias_kern_impls.py | |||
| // ignore warning of cutlass | |||
| #pragma GCC diagnostic push | |||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||
| #include "../conv_bias_int8_implicit_gemm_imma_ncdiv32hw32.cuinl" | |||
| using LayoutSrc = cutlass::layout::TensorNCxHWx<32>; | |||
| using LayoutFilter = cutlass::layout::TensorCxRSKx<32>; | |||
| using ThreadBlockShape = cutlass::gemm::GemmShape<128, 256, 64>; | |||
| using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; | |||
| using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||
| int8_t, 8, int32_t, int32_t, float>; | |||
| using Convolution = cutlass::convolution::device::Convolution< | |||
| int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t, | |||
| LayoutSrc, int32_t, LayoutSrc, int32_t, | |||
| cutlass::convolution::ConvType::kConvolution, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, | |||
| cutlass::convolution::threadblock::ConvolutionNCxHWxThreadblockSwizzle< | |||
| cutlass::convolution::ConvType::kConvolution>, | |||
| 2, 16, 16, false>; | |||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||
| const int8_t* d_src, | |||
| const int8_t* d_filter, | |||
| const int32_t* d_bias, | |||
| const int8_t* d_z, | |||
| int8_t* d_dst, | |||
| int* workspace, | |||
| typename Convolution::ConvolutionParameter const& conv_param, | |||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||
| cudaStream_t stream); | |||
| #pragma GCC diagnostic pop | |||
| #endif | |||
| @@ -0,0 +1,35 @@ | |||
| #if !MEGDNN_TEGRA_X1 | |||
| // generated by gen_cuda_conv_bias_kern_impls.py | |||
| // ignore warning of cutlass | |||
| #pragma GCC diagnostic push | |||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||
| #include "../conv_bias_int8_implicit_gemm_imma_ncdiv32hw32.cuinl" | |||
| using LayoutSrc = cutlass::layout::TensorNCxHWx<32>; | |||
| using LayoutFilter = cutlass::layout::TensorCxRSKx<32>; | |||
| using ThreadBlockShape = cutlass::gemm::GemmShape<128, 256, 64>; | |||
| using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; | |||
| using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationReluClamp< | |||
| int8_t, 8, int32_t, int32_t, float>; | |||
| using Convolution = cutlass::convolution::device::Convolution< | |||
| int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t, | |||
| LayoutSrc, int32_t, LayoutSrc, int32_t, | |||
| cutlass::convolution::ConvType::kConvolution, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, | |||
| cutlass::convolution::threadblock::ConvolutionNCxHWxThreadblockSwizzle< | |||
| cutlass::convolution::ConvType::kConvolution>, | |||
| 2, 16, 16, false>; | |||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||
| const int8_t* d_src, | |||
| const int8_t* d_filter, | |||
| const int32_t* d_bias, | |||
| const int8_t* d_z, | |||
| int8_t* d_dst, | |||
| int* workspace, | |||
| typename Convolution::ConvolutionParameter const& conv_param, | |||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||
| cudaStream_t stream); | |||
| #pragma GCC diagnostic pop | |||
| #endif | |||
| @@ -0,0 +1,35 @@ | |||
| #if !MEGDNN_TEGRA_X1 | |||
| // generated by gen_cuda_conv_bias_kern_impls.py | |||
| // ignore warning of cutlass | |||
| #pragma GCC diagnostic push | |||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||
| #include "../conv_bias_int8_implicit_gemm_imma_ncdiv32hw32.cuinl" | |||
| using LayoutSrc = cutlass::layout::TensorNCxHWx<32>; | |||
| using LayoutFilter = cutlass::layout::TensorCxRSKx<32>; | |||
| using ThreadBlockShape = cutlass::gemm::GemmShape<128, 64, 64>; | |||
| using WarpShape = cutlass::gemm::GemmShape<64, 32, 64>; | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; | |||
| using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationHSwishClamp< | |||
| int8_t, 8, int32_t, int32_t, float>; | |||
| using Convolution = cutlass::convolution::device::Convolution< | |||
| int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t, | |||
| LayoutSrc, int32_t, LayoutSrc, int32_t, | |||
| cutlass::convolution::ConvType::kConvolution, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, | |||
| cutlass::convolution::threadblock::ConvolutionNCxHWxThreadblockSwizzle< | |||
| cutlass::convolution::ConvType::kConvolution>, | |||
| 2, 16, 16, false>; | |||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||
| const int8_t* d_src, | |||
| const int8_t* d_filter, | |||
| const int32_t* d_bias, | |||
| const int8_t* d_z, | |||
| int8_t* d_dst, | |||
| int* workspace, | |||
| typename Convolution::ConvolutionParameter const& conv_param, | |||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||
| cudaStream_t stream); | |||
| #pragma GCC diagnostic pop | |||
| #endif | |||
| @@ -0,0 +1,35 @@ | |||
| #if !MEGDNN_TEGRA_X1 | |||
| // generated by gen_cuda_conv_bias_kern_impls.py | |||
| // ignore warning of cutlass | |||
| #pragma GCC diagnostic push | |||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||
| #include "../conv_bias_int8_implicit_gemm_imma_ncdiv32hw32.cuinl" | |||
| using LayoutSrc = cutlass::layout::TensorNCxHWx<32>; | |||
| using LayoutFilter = cutlass::layout::TensorCxRSKx<32>; | |||
| using ThreadBlockShape = cutlass::gemm::GemmShape<128, 64, 64>; | |||
| using WarpShape = cutlass::gemm::GemmShape<64, 32, 64>; | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; | |||
| using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||
| int8_t, 8, int32_t, int32_t, float>; | |||
| using Convolution = cutlass::convolution::device::Convolution< | |||
| int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t, | |||
| LayoutSrc, int32_t, LayoutSrc, int32_t, | |||
| cutlass::convolution::ConvType::kConvolution, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, | |||
| cutlass::convolution::threadblock::ConvolutionNCxHWxThreadblockSwizzle< | |||
| cutlass::convolution::ConvType::kConvolution>, | |||
| 2, 16, 16, false>; | |||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||
| const int8_t* d_src, | |||
| const int8_t* d_filter, | |||
| const int32_t* d_bias, | |||
| const int8_t* d_z, | |||
| int8_t* d_dst, | |||
| int* workspace, | |||
| typename Convolution::ConvolutionParameter const& conv_param, | |||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||
| cudaStream_t stream); | |||
| #pragma GCC diagnostic pop | |||
| #endif | |||
| @@ -0,0 +1,35 @@ | |||
| #if !MEGDNN_TEGRA_X1 | |||
| // generated by gen_cuda_conv_bias_kern_impls.py | |||
| // ignore warning of cutlass | |||
| #pragma GCC diagnostic push | |||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||
| #include "../conv_bias_int8_implicit_gemm_imma_ncdiv32hw32.cuinl" | |||
| using LayoutSrc = cutlass::layout::TensorNCxHWx<32>; | |||
| using LayoutFilter = cutlass::layout::TensorCxRSKx<32>; | |||
| using ThreadBlockShape = cutlass::gemm::GemmShape<128, 64, 64>; | |||
| using WarpShape = cutlass::gemm::GemmShape<64, 32, 64>; | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; | |||
| using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationReluClamp< | |||
| int8_t, 8, int32_t, int32_t, float>; | |||
| using Convolution = cutlass::convolution::device::Convolution< | |||
| int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t, | |||
| LayoutSrc, int32_t, LayoutSrc, int32_t, | |||
| cutlass::convolution::ConvType::kConvolution, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, | |||
| cutlass::convolution::threadblock::ConvolutionNCxHWxThreadblockSwizzle< | |||
| cutlass::convolution::ConvType::kConvolution>, | |||
| 2, 16, 16, false>; | |||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||
| const int8_t* d_src, | |||
| const int8_t* d_filter, | |||
| const int32_t* d_bias, | |||
| const int8_t* d_z, | |||
| int8_t* d_dst, | |||
| int* workspace, | |||
| typename Convolution::ConvolutionParameter const& conv_param, | |||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||
| cudaStream_t stream); | |||
| #pragma GCC diagnostic pop | |||
| #endif | |||
| @@ -0,0 +1,35 @@ | |||
| #if !MEGDNN_TEGRA_X1 | |||
| // generated by gen_cuda_conv_bias_kern_impls.py | |||
| // ignore warning of cutlass | |||
| #pragma GCC diagnostic push | |||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||
| #include "../conv_bias_int8_implicit_gemm_imma_ncdiv32hw32.cuinl" | |||
| using LayoutSrc = cutlass::layout::TensorNCxHWx<32>; | |||
| using LayoutFilter = cutlass::layout::TensorCxRSKx<32>; | |||
| using ThreadBlockShape = cutlass::gemm::GemmShape<256, 128, 64>; | |||
| using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; | |||
| using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationHSwishClamp< | |||
| int8_t, 8, int32_t, int32_t, float>; | |||
| using Convolution = cutlass::convolution::device::Convolution< | |||
| int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t, | |||
| LayoutSrc, int32_t, LayoutSrc, int32_t, | |||
| cutlass::convolution::ConvType::kConvolution, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, | |||
| cutlass::convolution::threadblock::ConvolutionNCxHWxThreadblockSwizzle< | |||
| cutlass::convolution::ConvType::kConvolution>, | |||
| 2, 16, 16, false>; | |||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||
| const int8_t* d_src, | |||
| const int8_t* d_filter, | |||
| const int32_t* d_bias, | |||
| const int8_t* d_z, | |||
| int8_t* d_dst, | |||
| int* workspace, | |||
| typename Convolution::ConvolutionParameter const& conv_param, | |||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||
| cudaStream_t stream); | |||
| #pragma GCC diagnostic pop | |||
| #endif | |||
| @@ -0,0 +1,35 @@ | |||
| #if !MEGDNN_TEGRA_X1 | |||
| // generated by gen_cuda_conv_bias_kern_impls.py | |||
| // ignore warning of cutlass | |||
| #pragma GCC diagnostic push | |||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||
| #include "../conv_bias_int8_implicit_gemm_imma_ncdiv32hw32.cuinl" | |||
| using LayoutSrc = cutlass::layout::TensorNCxHWx<32>; | |||
| using LayoutFilter = cutlass::layout::TensorCxRSKx<32>; | |||
| using ThreadBlockShape = cutlass::gemm::GemmShape<256, 128, 64>; | |||
| using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; | |||
| using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||
| int8_t, 8, int32_t, int32_t, float>; | |||
| using Convolution = cutlass::convolution::device::Convolution< | |||
| int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t, | |||
| LayoutSrc, int32_t, LayoutSrc, int32_t, | |||
| cutlass::convolution::ConvType::kConvolution, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, | |||
| cutlass::convolution::threadblock::ConvolutionNCxHWxThreadblockSwizzle< | |||
| cutlass::convolution::ConvType::kConvolution>, | |||
| 2, 16, 16, false>; | |||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||
| const int8_t* d_src, | |||
| const int8_t* d_filter, | |||
| const int32_t* d_bias, | |||
| const int8_t* d_z, | |||
| int8_t* d_dst, | |||
| int* workspace, | |||
| typename Convolution::ConvolutionParameter const& conv_param, | |||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||
| cudaStream_t stream); | |||
| #pragma GCC diagnostic pop | |||
| #endif | |||
| @@ -0,0 +1,35 @@ | |||
| #if !MEGDNN_TEGRA_X1 | |||
| // generated by gen_cuda_conv_bias_kern_impls.py | |||
| // ignore warning of cutlass | |||
| #pragma GCC diagnostic push | |||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||
| #include "../conv_bias_int8_implicit_gemm_imma_ncdiv32hw32.cuinl" | |||
| using LayoutSrc = cutlass::layout::TensorNCxHWx<32>; | |||
| using LayoutFilter = cutlass::layout::TensorCxRSKx<32>; | |||
| using ThreadBlockShape = cutlass::gemm::GemmShape<256, 128, 64>; | |||
| using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; | |||
| using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationReluClamp< | |||
| int8_t, 8, int32_t, int32_t, float>; | |||
| using Convolution = cutlass::convolution::device::Convolution< | |||
| int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t, | |||
| LayoutSrc, int32_t, LayoutSrc, int32_t, | |||
| cutlass::convolution::ConvType::kConvolution, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, | |||
| cutlass::convolution::threadblock::ConvolutionNCxHWxThreadblockSwizzle< | |||
| cutlass::convolution::ConvType::kConvolution>, | |||
| 2, 16, 16, false>; | |||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||
| const int8_t* d_src, | |||
| const int8_t* d_filter, | |||
| const int32_t* d_bias, | |||
| const int8_t* d_z, | |||
| int8_t* d_dst, | |||
| int* workspace, | |||
| typename Convolution::ConvolutionParameter const& conv_param, | |||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||
| cudaStream_t stream); | |||
| #pragma GCC diagnostic pop | |||
| #endif | |||
| @@ -0,0 +1,35 @@ | |||
| #if !MEGDNN_TEGRA_X1 | |||
| // generated by gen_cuda_conv_bias_kern_impls.py | |||
| // ignore warning of cutlass | |||
| #pragma GCC diagnostic push | |||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||
| #include "../conv_bias_int8_implicit_gemm_imma_ncdiv32hw32.cuinl" | |||
| using LayoutSrc = cutlass::layout::TensorNCxHWx<32>; | |||
| using LayoutFilter = cutlass::layout::TensorCxRSKx<32>; | |||
| using ThreadBlockShape = cutlass::gemm::GemmShape<32, 64, 64>; | |||
| using WarpShape = cutlass::gemm::GemmShape<32, 16, 64>; | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; | |||
| using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationHSwishClamp< | |||
| int8_t, 8, int32_t, int32_t, float>; | |||
| using Convolution = cutlass::convolution::device::Convolution< | |||
| int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t, | |||
| LayoutSrc, int32_t, LayoutSrc, int32_t, | |||
| cutlass::convolution::ConvType::kConvolution, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, | |||
| cutlass::convolution::threadblock::ConvolutionNCxHWxThreadblockSwizzle< | |||
| cutlass::convolution::ConvType::kConvolution>, | |||
| 2, 16, 16, false>; | |||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||
| const int8_t* d_src, | |||
| const int8_t* d_filter, | |||
| const int32_t* d_bias, | |||
| const int8_t* d_z, | |||
| int8_t* d_dst, | |||
| int* workspace, | |||
| typename Convolution::ConvolutionParameter const& conv_param, | |||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||
| cudaStream_t stream); | |||
| #pragma GCC diagnostic pop | |||
| #endif | |||
| @@ -0,0 +1,35 @@ | |||
| #if !MEGDNN_TEGRA_X1 | |||
| // generated by gen_cuda_conv_bias_kern_impls.py | |||
| // ignore warning of cutlass | |||
| #pragma GCC diagnostic push | |||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||
| #include "../conv_bias_int8_implicit_gemm_imma_ncdiv32hw32.cuinl" | |||
| using LayoutSrc = cutlass::layout::TensorNCxHWx<32>; | |||
| using LayoutFilter = cutlass::layout::TensorCxRSKx<32>; | |||
| using ThreadBlockShape = cutlass::gemm::GemmShape<32, 64, 64>; | |||
| using WarpShape = cutlass::gemm::GemmShape<32, 16, 64>; | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; | |||
| using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||
| int8_t, 8, int32_t, int32_t, float>; | |||
| using Convolution = cutlass::convolution::device::Convolution< | |||
| int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t, | |||
| LayoutSrc, int32_t, LayoutSrc, int32_t, | |||
| cutlass::convolution::ConvType::kConvolution, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, | |||
| cutlass::convolution::threadblock::ConvolutionNCxHWxThreadblockSwizzle< | |||
| cutlass::convolution::ConvType::kConvolution>, | |||
| 2, 16, 16, false>; | |||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||
| const int8_t* d_src, | |||
| const int8_t* d_filter, | |||
| const int32_t* d_bias, | |||
| const int8_t* d_z, | |||
| int8_t* d_dst, | |||
| int* workspace, | |||
| typename Convolution::ConvolutionParameter const& conv_param, | |||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||
| cudaStream_t stream); | |||
| #pragma GCC diagnostic pop | |||
| #endif | |||
| @@ -0,0 +1,35 @@ | |||
| #if !MEGDNN_TEGRA_X1 | |||
| // generated by gen_cuda_conv_bias_kern_impls.py | |||
| // ignore warning of cutlass | |||
| #pragma GCC diagnostic push | |||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||
| #include "../conv_bias_int8_implicit_gemm_imma_ncdiv32hw32.cuinl" | |||
| using LayoutSrc = cutlass::layout::TensorNCxHWx<32>; | |||
| using LayoutFilter = cutlass::layout::TensorCxRSKx<32>; | |||
| using ThreadBlockShape = cutlass::gemm::GemmShape<32, 64, 64>; | |||
| using WarpShape = cutlass::gemm::GemmShape<32, 16, 64>; | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; | |||
| using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationReluClamp< | |||
| int8_t, 8, int32_t, int32_t, float>; | |||
| using Convolution = cutlass::convolution::device::Convolution< | |||
| int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t, | |||
| LayoutSrc, int32_t, LayoutSrc, int32_t, | |||
| cutlass::convolution::ConvType::kConvolution, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, | |||
| cutlass::convolution::threadblock::ConvolutionNCxHWxThreadblockSwizzle< | |||
| cutlass::convolution::ConvType::kConvolution>, | |||
| 2, 16, 16, false>; | |||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||
| const int8_t* d_src, | |||
| const int8_t* d_filter, | |||
| const int32_t* d_bias, | |||
| const int8_t* d_z, | |||
| int8_t* d_dst, | |||
| int* workspace, | |||
| typename Convolution::ConvolutionParameter const& conv_param, | |||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||
| cudaStream_t stream); | |||
| #pragma GCC diagnostic pop | |||
| #endif | |||
| @@ -0,0 +1,35 @@ | |||
| #if !MEGDNN_TEGRA_X1 | |||
| // generated by gen_cuda_conv_bias_kern_impls.py | |||
| // ignore warning of cutlass | |||
| #pragma GCC diagnostic push | |||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||
| #include "../conv_bias_int8_implicit_gemm_imma_ncdiv32hw32.cuinl" | |||
| using LayoutSrc = cutlass::layout::TensorNCxHWx<32>; | |||
| using LayoutFilter = cutlass::layout::TensorCxRSKx<32>; | |||
| using ThreadBlockShape = cutlass::gemm::GemmShape<64, 128, 64>; | |||
| using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>; | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; | |||
| using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationHSwishClamp< | |||
| int8_t, 8, int32_t, int32_t, float>; | |||
| using Convolution = cutlass::convolution::device::Convolution< | |||
| int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t, | |||
| LayoutSrc, int32_t, LayoutSrc, int32_t, | |||
| cutlass::convolution::ConvType::kConvolution, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, | |||
| cutlass::convolution::threadblock::ConvolutionNCxHWxThreadblockSwizzle< | |||
| cutlass::convolution::ConvType::kConvolution>, | |||
| 2, 16, 16, false>; | |||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||
| const int8_t* d_src, | |||
| const int8_t* d_filter, | |||
| const int32_t* d_bias, | |||
| const int8_t* d_z, | |||
| int8_t* d_dst, | |||
| int* workspace, | |||
| typename Convolution::ConvolutionParameter const& conv_param, | |||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||
| cudaStream_t stream); | |||
| #pragma GCC diagnostic pop | |||
| #endif | |||
| @@ -0,0 +1,35 @@ | |||
| #if !MEGDNN_TEGRA_X1 | |||
| // generated by gen_cuda_conv_bias_kern_impls.py | |||
| // ignore warning of cutlass | |||
| #pragma GCC diagnostic push | |||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||
| #include "../conv_bias_int8_implicit_gemm_imma_ncdiv32hw32.cuinl" | |||
| using LayoutSrc = cutlass::layout::TensorNCxHWx<32>; | |||
| using LayoutFilter = cutlass::layout::TensorCxRSKx<32>; | |||
| using ThreadBlockShape = cutlass::gemm::GemmShape<64, 128, 64>; | |||
| using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>; | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; | |||
| using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||
| int8_t, 8, int32_t, int32_t, float>; | |||
| using Convolution = cutlass::convolution::device::Convolution< | |||
| int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t, | |||
| LayoutSrc, int32_t, LayoutSrc, int32_t, | |||
| cutlass::convolution::ConvType::kConvolution, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, | |||
| cutlass::convolution::threadblock::ConvolutionNCxHWxThreadblockSwizzle< | |||
| cutlass::convolution::ConvType::kConvolution>, | |||
| 2, 16, 16, false>; | |||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||
| const int8_t* d_src, | |||
| const int8_t* d_filter, | |||
| const int32_t* d_bias, | |||
| const int8_t* d_z, | |||
| int8_t* d_dst, | |||
| int* workspace, | |||
| typename Convolution::ConvolutionParameter const& conv_param, | |||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||
| cudaStream_t stream); | |||
| #pragma GCC diagnostic pop | |||
| #endif | |||
| @@ -0,0 +1,35 @@ | |||
| #if !MEGDNN_TEGRA_X1 | |||
| // generated by gen_cuda_conv_bias_kern_impls.py | |||
| // ignore warning of cutlass | |||
| #pragma GCC diagnostic push | |||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||
| #include "../conv_bias_int8_implicit_gemm_imma_ncdiv32hw32.cuinl" | |||
| using LayoutSrc = cutlass::layout::TensorNCxHWx<32>; | |||
| using LayoutFilter = cutlass::layout::TensorCxRSKx<32>; | |||
| using ThreadBlockShape = cutlass::gemm::GemmShape<64, 128, 64>; | |||
| using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>; | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; | |||
| using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationReluClamp< | |||
| int8_t, 8, int32_t, int32_t, float>; | |||
| using Convolution = cutlass::convolution::device::Convolution< | |||
| int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t, | |||
| LayoutSrc, int32_t, LayoutSrc, int32_t, | |||
| cutlass::convolution::ConvType::kConvolution, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, | |||
| cutlass::convolution::threadblock::ConvolutionNCxHWxThreadblockSwizzle< | |||
| cutlass::convolution::ConvType::kConvolution>, | |||
| 2, 16, 16, false>; | |||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||
| const int8_t* d_src, | |||
| const int8_t* d_filter, | |||
| const int32_t* d_bias, | |||
| const int8_t* d_z, | |||
| int8_t* d_dst, | |||
| int* workspace, | |||
| typename Convolution::ConvolutionParameter const& conv_param, | |||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||
| cudaStream_t stream); | |||
| #pragma GCC diagnostic pop | |||
| #endif | |||
| @@ -0,0 +1,35 @@ | |||
| #if !MEGDNN_TEGRA_X1 | |||
| // generated by gen_cuda_conv_bias_kern_impls.py | |||
| // ignore warning of cutlass | |||
| #pragma GCC diagnostic push | |||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||
| #include "../conv_bias_int8_implicit_gemm_imma_ncdiv32hw32.cuinl" | |||
| using LayoutSrc = cutlass::layout::TensorNCxHWx<32>; | |||
| using LayoutFilter = cutlass::layout::TensorCxRSKx<32>; | |||
| using ThreadBlockShape = cutlass::gemm::GemmShape<64, 64, 64>; | |||
| using WarpShape = cutlass::gemm::GemmShape<32, 32, 64>; | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; | |||
| using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationHSwishClamp< | |||
| int8_t, 8, int32_t, int32_t, float>; | |||
| using Convolution = cutlass::convolution::device::Convolution< | |||
| int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t, | |||
| LayoutSrc, int32_t, LayoutSrc, int32_t, | |||
| cutlass::convolution::ConvType::kConvolution, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, | |||
| cutlass::convolution::threadblock::ConvolutionNCxHWxThreadblockSwizzle< | |||
| cutlass::convolution::ConvType::kConvolution>, | |||
| 2, 16, 16, false>; | |||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||
| const int8_t* d_src, | |||
| const int8_t* d_filter, | |||
| const int32_t* d_bias, | |||
| const int8_t* d_z, | |||
| int8_t* d_dst, | |||
| int* workspace, | |||
| typename Convolution::ConvolutionParameter const& conv_param, | |||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||
| cudaStream_t stream); | |||
| #pragma GCC diagnostic pop | |||
| #endif | |||
| @@ -0,0 +1,35 @@ | |||
| #if !MEGDNN_TEGRA_X1 | |||
| // generated by gen_cuda_conv_bias_kern_impls.py | |||
| // ignore warning of cutlass | |||
| #pragma GCC diagnostic push | |||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||
| #include "../conv_bias_int8_implicit_gemm_imma_ncdiv32hw32.cuinl" | |||
| using LayoutSrc = cutlass::layout::TensorNCxHWx<32>; | |||
| using LayoutFilter = cutlass::layout::TensorCxRSKx<32>; | |||
| using ThreadBlockShape = cutlass::gemm::GemmShape<64, 64, 64>; | |||
| using WarpShape = cutlass::gemm::GemmShape<32, 32, 64>; | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; | |||
| using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||
| int8_t, 8, int32_t, int32_t, float>; | |||
| using Convolution = cutlass::convolution::device::Convolution< | |||
| int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t, | |||
| LayoutSrc, int32_t, LayoutSrc, int32_t, | |||
| cutlass::convolution::ConvType::kConvolution, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, | |||
| cutlass::convolution::threadblock::ConvolutionNCxHWxThreadblockSwizzle< | |||
| cutlass::convolution::ConvType::kConvolution>, | |||
| 2, 16, 16, false>; | |||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||
| const int8_t* d_src, | |||
| const int8_t* d_filter, | |||
| const int32_t* d_bias, | |||
| const int8_t* d_z, | |||
| int8_t* d_dst, | |||
| int* workspace, | |||
| typename Convolution::ConvolutionParameter const& conv_param, | |||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||
| cudaStream_t stream); | |||
| #pragma GCC diagnostic pop | |||
| #endif | |||
| @@ -0,0 +1,35 @@ | |||
| #if !MEGDNN_TEGRA_X1 | |||
| // generated by gen_cuda_conv_bias_kern_impls.py | |||
| // ignore warning of cutlass | |||
| #pragma GCC diagnostic push | |||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||
| #include "../conv_bias_int8_implicit_gemm_imma_ncdiv32hw32.cuinl" | |||
| using LayoutSrc = cutlass::layout::TensorNCxHWx<32>; | |||
| using LayoutFilter = cutlass::layout::TensorCxRSKx<32>; | |||
| using ThreadBlockShape = cutlass::gemm::GemmShape<64, 64, 64>; | |||
| using WarpShape = cutlass::gemm::GemmShape<32, 32, 64>; | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; | |||
| using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationReluClamp< | |||
| int8_t, 8, int32_t, int32_t, float>; | |||
| using Convolution = cutlass::convolution::device::Convolution< | |||
| int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t, | |||
| LayoutSrc, int32_t, LayoutSrc, int32_t, | |||
| cutlass::convolution::ConvType::kConvolution, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, | |||
| cutlass::convolution::threadblock::ConvolutionNCxHWxThreadblockSwizzle< | |||
| cutlass::convolution::ConvType::kConvolution>, | |||
| 2, 16, 16, false>; | |||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||
| const int8_t* d_src, | |||
| const int8_t* d_filter, | |||
| const int32_t* d_bias, | |||
| const int8_t* d_z, | |||
| int8_t* d_dst, | |||
| int* workspace, | |||
| typename Convolution::ConvolutionParameter const& conv_param, | |||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||
| cudaStream_t stream); | |||
| #pragma GCC diagnostic pop | |||
| #endif | |||
| @@ -0,0 +1,35 @@ | |||
| #if !MEGDNN_TEGRA_X1 | |||
| // generated by gen_cuda_conv_bias_kern_impls.py | |||
| // ignore warning of cutlass | |||
| #pragma GCC diagnostic push | |||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||
| #include "../conv_bias_int8_implicit_gemm_imma_ncdiv32hw32.cuinl" | |||
| using LayoutSrc = cutlass::layout::TensorNCxHWx<32>; | |||
| using LayoutFilter = cutlass::layout::TensorCxRSKx<32>; | |||
| using ThreadBlockShape = cutlass::gemm::GemmShape<256, 128, 64>; | |||
| using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; | |||
| using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationHSwishClamp< | |||
| int8_t, 8, int32_t, int32_t, float>; | |||
| using Convolution = cutlass::convolution::device::Convolution< | |||
| int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t, | |||
| LayoutSrc, int32_t, LayoutSrc, int32_t, | |||
| cutlass::convolution::ConvType::kConvolution, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, | |||
| cutlass::convolution::threadblock::ConvolutionNCxHWxThreadblockSwizzle< | |||
| cutlass::convolution::ConvType::kConvolution>, | |||
| 2, 16, 16, true>; | |||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||
| const int8_t* d_src, | |||
| const int8_t* d_filter, | |||
| const int32_t* d_bias, | |||
| const int8_t* d_z, | |||
| int8_t* d_dst, | |||
| int* workspace, | |||
| typename Convolution::ConvolutionParameter const& conv_param, | |||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||
| cudaStream_t stream); | |||
| #pragma GCC diagnostic pop | |||
| #endif | |||
| @@ -0,0 +1,35 @@ | |||
| #if !MEGDNN_TEGRA_X1 | |||
| // generated by gen_cuda_conv_bias_kern_impls.py | |||
| // ignore warning of cutlass | |||
| #pragma GCC diagnostic push | |||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||
| #include "../conv_bias_int8_implicit_gemm_imma_ncdiv32hw32.cuinl" | |||
| using LayoutSrc = cutlass::layout::TensorNCxHWx<32>; | |||
| using LayoutFilter = cutlass::layout::TensorCxRSKx<32>; | |||
| using ThreadBlockShape = cutlass::gemm::GemmShape<256, 128, 64>; | |||
| using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; | |||
| using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||
| int8_t, 8, int32_t, int32_t, float>; | |||
| using Convolution = cutlass::convolution::device::Convolution< | |||
| int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t, | |||
| LayoutSrc, int32_t, LayoutSrc, int32_t, | |||
| cutlass::convolution::ConvType::kConvolution, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, | |||
| cutlass::convolution::threadblock::ConvolutionNCxHWxThreadblockSwizzle< | |||
| cutlass::convolution::ConvType::kConvolution>, | |||
| 2, 16, 16, true>; | |||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||
| const int8_t* d_src, | |||
| const int8_t* d_filter, | |||
| const int32_t* d_bias, | |||
| const int8_t* d_z, | |||
| int8_t* d_dst, | |||
| int* workspace, | |||
| typename Convolution::ConvolutionParameter const& conv_param, | |||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||
| cudaStream_t stream); | |||
| #pragma GCC diagnostic pop | |||
| #endif | |||
| @@ -0,0 +1,35 @@ | |||
| #if !MEGDNN_TEGRA_X1 | |||
| // generated by gen_cuda_conv_bias_kern_impls.py | |||
| // ignore warning of cutlass | |||
| #pragma GCC diagnostic push | |||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||
| #include "../conv_bias_int8_implicit_gemm_imma_ncdiv32hw32.cuinl" | |||
| using LayoutSrc = cutlass::layout::TensorNCxHWx<32>; | |||
| using LayoutFilter = cutlass::layout::TensorCxRSKx<32>; | |||
| using ThreadBlockShape = cutlass::gemm::GemmShape<256, 128, 64>; | |||
| using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; | |||
| using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationReluClamp< | |||
| int8_t, 8, int32_t, int32_t, float>; | |||
| using Convolution = cutlass::convolution::device::Convolution< | |||
| int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t, | |||
| LayoutSrc, int32_t, LayoutSrc, int32_t, | |||
| cutlass::convolution::ConvType::kConvolution, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, | |||
| cutlass::convolution::threadblock::ConvolutionNCxHWxThreadblockSwizzle< | |||
| cutlass::convolution::ConvType::kConvolution>, | |||
| 2, 16, 16, true>; | |||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||
| const int8_t* d_src, | |||
| const int8_t* d_filter, | |||
| const int32_t* d_bias, | |||
| const int8_t* d_z, | |||
| int8_t* d_dst, | |||
| int* workspace, | |||
| typename Convolution::ConvolutionParameter const& conv_param, | |||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||
| cudaStream_t stream); | |||
| #pragma GCC diagnostic pop | |||
| #endif | |||
| @@ -0,0 +1,35 @@ | |||
| #if !MEGDNN_TEGRA_X1 | |||
| // generated by gen_cuda_conv_bias_kern_impls.py | |||
| // ignore warning of cutlass | |||
| #pragma GCC diagnostic push | |||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||
| #include "../conv_bias_int8_implicit_gemm_imma_ncdiv32hw32.cuinl" | |||
| using LayoutSrc = cutlass::layout::TensorNCxHWx<32>; | |||
| using LayoutFilter = cutlass::layout::TensorCxRSKx<32>; | |||
| using ThreadBlockShape = cutlass::gemm::GemmShape<32, 64, 64>; | |||
| using WarpShape = cutlass::gemm::GemmShape<32, 16, 64>; | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; | |||
| using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationHSwishClamp< | |||
| int8_t, 8, int32_t, int32_t, float>; | |||
| using Convolution = cutlass::convolution::device::Convolution< | |||
| int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t, | |||
| LayoutSrc, int32_t, LayoutSrc, int32_t, | |||
| cutlass::convolution::ConvType::kConvolution, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, | |||
| cutlass::convolution::threadblock::ConvolutionNCxHWxThreadblockSwizzle< | |||
| cutlass::convolution::ConvType::kConvolution>, | |||
| 2, 16, 16, true>; | |||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||
| const int8_t* d_src, | |||
| const int8_t* d_filter, | |||
| const int32_t* d_bias, | |||
| const int8_t* d_z, | |||
| int8_t* d_dst, | |||
| int* workspace, | |||
| typename Convolution::ConvolutionParameter const& conv_param, | |||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||
| cudaStream_t stream); | |||
| #pragma GCC diagnostic pop | |||
| #endif | |||
| @@ -0,0 +1,35 @@ | |||
| #if !MEGDNN_TEGRA_X1 | |||
| // generated by gen_cuda_conv_bias_kern_impls.py | |||
| // ignore warning of cutlass | |||
| #pragma GCC diagnostic push | |||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||
| #include "../conv_bias_int8_implicit_gemm_imma_ncdiv32hw32.cuinl" | |||
| using LayoutSrc = cutlass::layout::TensorNCxHWx<32>; | |||
| using LayoutFilter = cutlass::layout::TensorCxRSKx<32>; | |||
| using ThreadBlockShape = cutlass::gemm::GemmShape<32, 64, 64>; | |||
| using WarpShape = cutlass::gemm::GemmShape<32, 16, 64>; | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; | |||
| using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||
| int8_t, 8, int32_t, int32_t, float>; | |||
| using Convolution = cutlass::convolution::device::Convolution< | |||
| int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t, | |||
| LayoutSrc, int32_t, LayoutSrc, int32_t, | |||
| cutlass::convolution::ConvType::kConvolution, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, | |||
| cutlass::convolution::threadblock::ConvolutionNCxHWxThreadblockSwizzle< | |||
| cutlass::convolution::ConvType::kConvolution>, | |||
| 2, 16, 16, true>; | |||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||
| const int8_t* d_src, | |||
| const int8_t* d_filter, | |||
| const int32_t* d_bias, | |||
| const int8_t* d_z, | |||
| int8_t* d_dst, | |||
| int* workspace, | |||
| typename Convolution::ConvolutionParameter const& conv_param, | |||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||
| cudaStream_t stream); | |||
| #pragma GCC diagnostic pop | |||
| #endif | |||
| @@ -0,0 +1,35 @@ | |||
| #if !MEGDNN_TEGRA_X1 | |||
| // generated by gen_cuda_conv_bias_kern_impls.py | |||
| // ignore warning of cutlass | |||
| #pragma GCC diagnostic push | |||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||
| #include "../conv_bias_int8_implicit_gemm_imma_ncdiv32hw32.cuinl" | |||
| using LayoutSrc = cutlass::layout::TensorNCxHWx<32>; | |||
| using LayoutFilter = cutlass::layout::TensorCxRSKx<32>; | |||
| using ThreadBlockShape = cutlass::gemm::GemmShape<32, 64, 64>; | |||
| using WarpShape = cutlass::gemm::GemmShape<32, 16, 64>; | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; | |||
| using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationReluClamp< | |||
| int8_t, 8, int32_t, int32_t, float>; | |||
| using Convolution = cutlass::convolution::device::Convolution< | |||
| int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t, | |||
| LayoutSrc, int32_t, LayoutSrc, int32_t, | |||
| cutlass::convolution::ConvType::kConvolution, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, | |||
| cutlass::convolution::threadblock::ConvolutionNCxHWxThreadblockSwizzle< | |||
| cutlass::convolution::ConvType::kConvolution>, | |||
| 2, 16, 16, true>; | |||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||
| const int8_t* d_src, | |||
| const int8_t* d_filter, | |||
| const int32_t* d_bias, | |||
| const int8_t* d_z, | |||
| int8_t* d_dst, | |||
| int* workspace, | |||
| typename Convolution::ConvolutionParameter const& conv_param, | |||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||
| cudaStream_t stream); | |||
| #pragma GCC diagnostic pop | |||
| #endif | |||
| @@ -0,0 +1,35 @@ | |||
| #if !MEGDNN_TEGRA_X1 | |||
| // generated by gen_cuda_conv_bias_kern_impls.py | |||
| // ignore warning of cutlass | |||
| #pragma GCC diagnostic push | |||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||
| #include "../conv_bias_int8_implicit_gemm_imma_ncdiv32hw32.cuinl" | |||
| using LayoutSrc = cutlass::layout::TensorNCxHWx<32>; | |||
| using LayoutFilter = cutlass::layout::TensorCxRSKx<32>; | |||
| using ThreadBlockShape = cutlass::gemm::GemmShape<64, 128, 64>; | |||
| using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>; | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; | |||
| using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationHSwishClamp< | |||
| int8_t, 8, int32_t, int32_t, float>; | |||
| using Convolution = cutlass::convolution::device::Convolution< | |||
| int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t, | |||
| LayoutSrc, int32_t, LayoutSrc, int32_t, | |||
| cutlass::convolution::ConvType::kConvolution, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, | |||
| cutlass::convolution::threadblock::ConvolutionNCxHWxThreadblockSwizzle< | |||
| cutlass::convolution::ConvType::kConvolution>, | |||
| 2, 16, 16, true>; | |||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||
| const int8_t* d_src, | |||
| const int8_t* d_filter, | |||
| const int32_t* d_bias, | |||
| const int8_t* d_z, | |||
| int8_t* d_dst, | |||
| int* workspace, | |||
| typename Convolution::ConvolutionParameter const& conv_param, | |||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||
| cudaStream_t stream); | |||
| #pragma GCC diagnostic pop | |||
| #endif | |||
| @@ -0,0 +1,35 @@ | |||
| #if !MEGDNN_TEGRA_X1 | |||
| // generated by gen_cuda_conv_bias_kern_impls.py | |||
| // ignore warning of cutlass | |||
| #pragma GCC diagnostic push | |||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||
| #include "../conv_bias_int8_implicit_gemm_imma_ncdiv32hw32.cuinl" | |||
| using LayoutSrc = cutlass::layout::TensorNCxHWx<32>; | |||
| using LayoutFilter = cutlass::layout::TensorCxRSKx<32>; | |||
| using ThreadBlockShape = cutlass::gemm::GemmShape<64, 128, 64>; | |||
| using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>; | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; | |||
| using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||
| int8_t, 8, int32_t, int32_t, float>; | |||
| using Convolution = cutlass::convolution::device::Convolution< | |||
| int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t, | |||
| LayoutSrc, int32_t, LayoutSrc, int32_t, | |||
| cutlass::convolution::ConvType::kConvolution, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, | |||
| cutlass::convolution::threadblock::ConvolutionNCxHWxThreadblockSwizzle< | |||
| cutlass::convolution::ConvType::kConvolution>, | |||
| 2, 16, 16, true>; | |||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||
| const int8_t* d_src, | |||
| const int8_t* d_filter, | |||
| const int32_t* d_bias, | |||
| const int8_t* d_z, | |||
| int8_t* d_dst, | |||
| int* workspace, | |||
| typename Convolution::ConvolutionParameter const& conv_param, | |||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||
| cudaStream_t stream); | |||
| #pragma GCC diagnostic pop | |||
| #endif | |||
| @@ -0,0 +1,35 @@ | |||
| #if !MEGDNN_TEGRA_X1 | |||
| // generated by gen_cuda_conv_bias_kern_impls.py | |||
| // ignore warning of cutlass | |||
| #pragma GCC diagnostic push | |||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||
| #include "../conv_bias_int8_implicit_gemm_imma_ncdiv32hw32.cuinl" | |||
| using LayoutSrc = cutlass::layout::TensorNCxHWx<32>; | |||
| using LayoutFilter = cutlass::layout::TensorCxRSKx<32>; | |||
| using ThreadBlockShape = cutlass::gemm::GemmShape<64, 128, 64>; | |||
| using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>; | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; | |||
| using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationReluClamp< | |||
| int8_t, 8, int32_t, int32_t, float>; | |||
| using Convolution = cutlass::convolution::device::Convolution< | |||
| int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t, | |||
| LayoutSrc, int32_t, LayoutSrc, int32_t, | |||
| cutlass::convolution::ConvType::kConvolution, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, | |||
| cutlass::convolution::threadblock::ConvolutionNCxHWxThreadblockSwizzle< | |||
| cutlass::convolution::ConvType::kConvolution>, | |||
| 2, 16, 16, true>; | |||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||
| const int8_t* d_src, | |||
| const int8_t* d_filter, | |||
| const int32_t* d_bias, | |||
| const int8_t* d_z, | |||
| int8_t* d_dst, | |||
| int* workspace, | |||
| typename Convolution::ConvolutionParameter const& conv_param, | |||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||
| cudaStream_t stream); | |||
| #pragma GCC diagnostic pop | |||
| #endif | |||
| @@ -0,0 +1,35 @@ | |||
| #if !MEGDNN_TEGRA_X1 | |||
| // generated by gen_cuda_conv_bias_kern_impls.py | |||
| // ignore warning of cutlass | |||
| #pragma GCC diagnostic push | |||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||
| #include "../conv_bias_int8_implicit_gemm_imma_ncdiv32hw32.cuinl" | |||
| using LayoutSrc = cutlass::layout::TensorNCxHWx<32>; | |||
| using LayoutFilter = cutlass::layout::TensorCxRSKx<32>; | |||
| using ThreadBlockShape = cutlass::gemm::GemmShape<64, 64, 64>; | |||
| using WarpShape = cutlass::gemm::GemmShape<32, 32, 64>; | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; | |||
| using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationHSwishClamp< | |||
| int8_t, 8, int32_t, int32_t, float>; | |||
| using Convolution = cutlass::convolution::device::Convolution< | |||
| int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t, | |||
| LayoutSrc, int32_t, LayoutSrc, int32_t, | |||
| cutlass::convolution::ConvType::kConvolution, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, | |||
| cutlass::convolution::threadblock::ConvolutionNCxHWxThreadblockSwizzle< | |||
| cutlass::convolution::ConvType::kConvolution>, | |||
| 2, 16, 16, true>; | |||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||
| const int8_t* d_src, | |||
| const int8_t* d_filter, | |||
| const int32_t* d_bias, | |||
| const int8_t* d_z, | |||
| int8_t* d_dst, | |||
| int* workspace, | |||
| typename Convolution::ConvolutionParameter const& conv_param, | |||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||
| cudaStream_t stream); | |||
| #pragma GCC diagnostic pop | |||
| #endif | |||
| @@ -0,0 +1,35 @@ | |||
| #if !MEGDNN_TEGRA_X1 | |||
| // generated by gen_cuda_conv_bias_kern_impls.py | |||
| // ignore warning of cutlass | |||
| #pragma GCC diagnostic push | |||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||
| #include "../conv_bias_int8_implicit_gemm_imma_ncdiv32hw32.cuinl" | |||
| using LayoutSrc = cutlass::layout::TensorNCxHWx<32>; | |||
| using LayoutFilter = cutlass::layout::TensorCxRSKx<32>; | |||
| using ThreadBlockShape = cutlass::gemm::GemmShape<64, 64, 64>; | |||
| using WarpShape = cutlass::gemm::GemmShape<32, 32, 64>; | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; | |||
| using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||
| int8_t, 8, int32_t, int32_t, float>; | |||
| using Convolution = cutlass::convolution::device::Convolution< | |||
| int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t, | |||
| LayoutSrc, int32_t, LayoutSrc, int32_t, | |||
| cutlass::convolution::ConvType::kConvolution, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, | |||
| cutlass::convolution::threadblock::ConvolutionNCxHWxThreadblockSwizzle< | |||
| cutlass::convolution::ConvType::kConvolution>, | |||
| 2, 16, 16, true>; | |||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||
| const int8_t* d_src, | |||
| const int8_t* d_filter, | |||
| const int32_t* d_bias, | |||
| const int8_t* d_z, | |||
| int8_t* d_dst, | |||
| int* workspace, | |||
| typename Convolution::ConvolutionParameter const& conv_param, | |||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||
| cudaStream_t stream); | |||
| #pragma GCC diagnostic pop | |||
| #endif | |||
| @@ -0,0 +1,35 @@ | |||
| #if !MEGDNN_TEGRA_X1 | |||
| // generated by gen_cuda_conv_bias_kern_impls.py | |||
| // ignore warning of cutlass | |||
| #pragma GCC diagnostic push | |||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||
| #include "../conv_bias_int8_implicit_gemm_imma_ncdiv32hw32.cuinl" | |||
| using LayoutSrc = cutlass::layout::TensorNCxHWx<32>; | |||
| using LayoutFilter = cutlass::layout::TensorCxRSKx<32>; | |||
| using ThreadBlockShape = cutlass::gemm::GemmShape<64, 64, 64>; | |||
| using WarpShape = cutlass::gemm::GemmShape<32, 32, 64>; | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; | |||
| using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationReluClamp< | |||
| int8_t, 8, int32_t, int32_t, float>; | |||
| using Convolution = cutlass::convolution::device::Convolution< | |||
| int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t, | |||
| LayoutSrc, int32_t, LayoutSrc, int32_t, | |||
| cutlass::convolution::ConvType::kConvolution, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, | |||
| cutlass::convolution::threadblock::ConvolutionNCxHWxThreadblockSwizzle< | |||
| cutlass::convolution::ConvType::kConvolution>, | |||
| 2, 16, 16, true>; | |||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||
| const int8_t* d_src, | |||
| const int8_t* d_filter, | |||
| const int32_t* d_bias, | |||
| const int8_t* d_z, | |||
| int8_t* d_dst, | |||
| int* workspace, | |||
| typename Convolution::ConvolutionParameter const& conv_param, | |||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||
| cudaStream_t stream); | |||
| #pragma GCC diagnostic pop | |||
| #endif | |||
| @@ -80,6 +80,7 @@ public: | |||
| class AlgoInt8NCHW4IMMAImplicitGemm; | |||
| class AlgoInt8CHWN4IMMAImplicitGemmReorderFilter; | |||
| class AlgoInt8CHWN4IMMAImplicitGemmUnrollWidth; | |||
| class AlgoInt8NCHW32IMMAImplicitGemm; | |||
| class AlgoBFloat16; | |||
| class AlgoPack; | |||
| @@ -56,7 +56,6 @@ const char *cublasGetErrorString(cublasStatus_t error) { | |||
| } | |||
| return "Unknown CUBLAS error"; | |||
| } | |||
| } // anonymous namespace | |||
| void cuda::__throw_cuda_error__(cudaError_t err, const char *msg) { | |||
| @@ -87,6 +86,12 @@ void cuda::__throw_cuda_driver_error__(CUresult err, const char* msg) { | |||
| megdnn_throw(s.c_str()); | |||
| } | |||
| void cuda::__throw_cutlass_error__(cutlass::Status err, const char* msg) { | |||
| auto s = ssprintf("cutlass error %s(%d) occurred; expr: %s", | |||
| cutlass::cutlassGetStatusString(err), int(err), msg); | |||
| megdnn_throw(s.c_str()); | |||
| } | |||
| void cuda::report_error(const char *msg) { | |||
| megdnn_throw(msg); | |||
| MEGDNN_MARK_USED_VAR(msg); | |||
| @@ -20,6 +20,7 @@ | |||
| #include <cusolverDn.h> | |||
| #include "cuda.h" | |||
| #include "src/cuda/cudnn_with_check.h" | |||
| #include "cutlass/cutlass.h" | |||
| #define cuda_check(_x) \ | |||
| do { \ | |||
| @@ -61,6 +62,14 @@ | |||
| } \ | |||
| } while (0) | |||
| #define cutlass_check(_x) \ | |||
| do { \ | |||
| cutlass::Status _err = (_x); \ | |||
| if (_err != cutlass::Status::kSuccess) { \ | |||
| ::megdnn::cuda::__throw_cutlass_error__(_err, #_x); \ | |||
| } \ | |||
| } while (0) | |||
| #define after_kernel_launch() \ | |||
| do { \ | |||
| cuda_check(cudaGetLastError()); \ | |||
| @@ -93,6 +102,8 @@ MEGDNN_NORETURN void __throw_cublas_error__(cublasStatus_t err, | |||
| MEGDNN_NORETURN void __throw_cusolver_error__(cusolverStatus_t err, | |||
| const char* msg); | |||
| MEGDNN_NORETURN void __throw_cuda_driver_error__(CUresult err, const char* msg); | |||
| MEGDNN_NORETURN void __throw_cutlass_error__(cutlass::Status status, | |||
| const char* msg); | |||
| MEGDNN_NORETURN void report_error(const char* msg); | |||
| template <typename T, size_t N> | |||
| @@ -32,6 +32,10 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-narrowing") | |||
| target_link_libraries(megdnn_test gtest) | |||
| target_link_libraries(megdnn_test megdnn ${MGE_BLAS_LIBS}) | |||
| if (MGE_WITH_CUDA) | |||
| target_link_libraries(megdnn_test cutlass) | |||
| endif() | |||
| target_include_directories(megdnn_test | |||
| PRIVATE | |||
| ${PROJECT_SOURCE_DIR}/third_party/midout/src | |||
| @@ -254,8 +254,8 @@ public: | |||
| }; | |||
| ////////////////// Algo Benchmark //////////////////////// | |||
| template <typename Opr, typename Proxy = OprProxy<Opr>> | |||
| float algo_benchmark(Benchmarker<Opr>& benchmark, TensorLayoutArray layouts, | |||
| template <typename Opr, typename Proxy = OprProxy<Opr>, typename T = Timer> | |||
| float algo_benchmark(Benchmarker<Opr, T>& benchmark, TensorLayoutArray layouts, | |||
| const std::string& algo_base) { | |||
| Proxy proxy; | |||
| auto opr = benchmark.opr(); | |||
| @@ -279,8 +279,8 @@ float algo_benchmark(Benchmarker<Opr>& benchmark, TensorLayoutArray layouts, | |||
| return min_used; | |||
| } | |||
| template <typename Opr, typename Proxy = OprProxy<Opr>> | |||
| float algo_benchmark(Benchmarker<Opr>& benchmark, TensorShapeArray shapes, | |||
| template <typename Opr, typename Proxy = OprProxy<Opr>, typename T = Timer> | |||
| float algo_benchmark(Benchmarker<Opr, T>& benchmark, TensorShapeArray shapes, | |||
| const std::string& algo_base) { | |||
| return algo_benchmark(benchmark, benchmark.make_layouts(shapes), algo_base); | |||
| } | |||
| @@ -18,6 +18,8 @@ | |||
| #include "test/cuda/fixture.h" | |||
| #include "test/cuda/utils.h" | |||
| #define MEGDNN_WITH_BENCHMARK 1 | |||
| #define V1(x) #x | |||
| #define V(x) V1(x) | |||
| @@ -107,11 +109,6 @@ void benchmark_target_algo( | |||
| benchmarker.set_display(false).set_times(RUNS); | |||
| benchmarker_cudnn.set_display(false).set_times(RUNS); | |||
| if (algo) { | |||
| benchmarker.set_before_exec_callback( | |||
| conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(algo)); | |||
| } | |||
| #define CUDNN_VERSION_STRING \ | |||
| "v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) "." V(CUDNN_PATCHLEVEL) | |||
| benchmarker_cudnn.set_before_exec_callback( | |||
| @@ -133,168 +130,117 @@ void benchmark_target_algo( | |||
| using Param = ConvBias::Param; | |||
| using Format = Param::Format; | |||
| if (format == Format::NCHW4) { | |||
| for (auto&& arg : args) { | |||
| Param param; | |||
| param.pad_h = param.pad_w = arg.f / 2; | |||
| param.stride_h = param.stride_w = arg.s; | |||
| param.format = Format::NCHW4; | |||
| size_t ho = infer_conv_shape(arg.hi, arg.f, arg.s, arg.f / 2); | |||
| size_t wo = infer_conv_shape(arg.wi, arg.f, arg.s, arg.f / 2); | |||
| benchmarker.set_param(param); | |||
| auto time_in_ms = | |||
| benchmarker.execs({{arg.n, arg.ci / 4, arg.hi, arg.wi, 4}, | |||
| {arg.co, arg.ci / 4, arg.f, arg.f, 4}, | |||
| {1, arg.co / 4, 1, 1, 4}, | |||
| {}, | |||
| {}}) / | |||
| RUNS; | |||
| param.nonlineMode = Param::NonlineMode::IDENTITY; | |||
| benchmarker_cudnn.set_param(param); | |||
| auto time_in_ms_cudnn = | |||
| benchmarker_cudnn.execs( | |||
| {{arg.n, arg.ci / 4, arg.hi, arg.wi, 4}, | |||
| {arg.co, arg.ci / 4, arg.f, arg.f, 4}, | |||
| {1, arg.co / 4, 1, 1, 4}, | |||
| {}, | |||
| {}}) / | |||
| RUNS; | |||
| float flo = 2.0 * arg.n * arg.co * ho * wo * arg.ci * arg.f * | |||
| arg.f / (1e12); | |||
| TensorShape src{arg.n, arg.ci, arg.hi, arg.wi}, | |||
| filter{arg.co, arg.ci, arg.f, arg.f}; | |||
| printf("src=%s, filter=%s, time(algo=%s)=%.2f %.2fTops, " | |||
| "time(cudnn)=%.2f %.2fTops, " | |||
| "perf(algo=%s)/perf(cudnn)=%.2f\n", | |||
| src.to_string().c_str(), filter.to_string().c_str(), algo, | |||
| time_in_ms, (flo / (time_in_ms * 1e-3)), time_in_ms_cudnn, | |||
| (flo / (time_in_ms_cudnn * 1e-3)), algo, | |||
| time_in_ms_cudnn / time_in_ms); | |||
| // helper function to change format | |||
| auto get_tensor_shape = [](TensorShape shape, | |||
| Format format) -> TensorShape { | |||
| TensorShape ret; | |||
| if (format == Format::NCHW4) { | |||
| ret = static_cast<TensorShape>( | |||
| TensorLayout{shape, dtype::Int8()} | |||
| .reshape({shape[0], shape[1] / 4, 4, shape[2], | |||
| shape[3]}) | |||
| .dimshuffle({0, 1, 3, 4, 2})); | |||
| } else if (format == Format::CHWN4) { | |||
| ret = static_cast<TensorShape>( | |||
| TensorLayout{shape, dtype::Int8()} | |||
| .reshape({shape[0], shape[1] / 4, 4, shape[2], | |||
| shape[3]}) | |||
| .dimshuffle({1, 3, 4, 0, 2})); | |||
| } | |||
| printf("bench with z tensor\n"); | |||
| for (auto&& arg : args) { | |||
| Param param; | |||
| param.pad_h = param.pad_w = arg.f / 2; | |||
| param.stride_h = param.stride_w = arg.s; | |||
| param.format = Format::NCHW4; | |||
| size_t ho = infer_conv_shape(arg.hi, arg.f, arg.s, arg.f / 2); | |||
| size_t wo = infer_conv_shape(arg.wi, arg.f, arg.s, arg.f / 2); | |||
| benchmarker.set_param(param); | |||
| auto time_in_ms = | |||
| benchmarker.execs({{arg.n, arg.ci / 4, arg.hi, arg.wi, 4}, | |||
| {arg.co, arg.ci / 4, arg.f, arg.f, 4}, | |||
| {1, arg.co / 4, 1, 1, 4}, | |||
| {arg.n, arg.co / 4, ho, wo, 4}, | |||
| {}}) / | |||
| RUNS; | |||
| param.format = Format::NCHW4; | |||
| param.nonlineMode = Param::NonlineMode::IDENTITY; | |||
| benchmarker_cudnn.set_param(param); | |||
| auto time_in_ms_cudnn = | |||
| benchmarker_cudnn.execs( | |||
| {{arg.n, arg.ci / 4, arg.hi, arg.wi, 4}, | |||
| {arg.co, arg.ci / 4, arg.f, arg.f, 4}, | |||
| {1, arg.co / 4, 1, 1, 4}, | |||
| {arg.n, arg.co / 4, ho, wo, 4}, | |||
| {}}) / | |||
| RUNS; | |||
| float flo = 2.0 * arg.n * arg.co * ho * wo * arg.ci * arg.f * | |||
| arg.f / (1e12); | |||
| TensorShape src{arg.n, arg.ci, arg.hi, arg.wi}, | |||
| filter{arg.co, arg.ci, arg.f, arg.f}; | |||
| printf("src=%s, filter=%s, time(algo=%s)=%.2f %.2fTops, " | |||
| "time(cudnn)=%.2f %.2fTops, " | |||
| "perf(algo=%s)/perf(cudnn)=%.2f\n", | |||
| src.to_string().c_str(), filter.to_string().c_str(), algo, | |||
| time_in_ms, (flo / (time_in_ms * 1e-3)), time_in_ms_cudnn, | |||
| (flo / (time_in_ms_cudnn * 1e-3)), algo, | |||
| time_in_ms_cudnn / time_in_ms); | |||
| return ret; | |||
| }; | |||
| for (auto&& arg : args) { | |||
| Param param; | |||
| param.pad_h = param.pad_w = arg.f / 2; | |||
| param.stride_h = param.stride_w = arg.s; | |||
| param.format = format; | |||
| size_t ho = infer_conv_shape(arg.hi, arg.f, arg.s, arg.f / 2); | |||
| size_t wo = infer_conv_shape(arg.wi, arg.f, arg.s, arg.f / 2); | |||
| benchmarker.set_param(param); | |||
| if (!algo) { | |||
| benchmarker.proxy()->target_algo = nullptr; | |||
| } | |||
| } else if (format == Format::CHWN4) { | |||
| for (auto&& arg : args) { | |||
| Param param; | |||
| param.pad_h = param.pad_w = arg.f / 2; | |||
| param.stride_h = param.stride_w = arg.s; | |||
| param.format = Format::CHWN4; | |||
| size_t ho = infer_conv_shape(arg.hi, arg.f, arg.s, arg.f / 2); | |||
| size_t wo = infer_conv_shape(arg.wi, arg.f, arg.s, arg.f / 2); | |||
| benchmarker.set_param(param); | |||
| auto time_in_ms = | |||
| benchmarker.execs({{arg.ci / 4, arg.hi, arg.wi, arg.n, 4}, | |||
| {arg.ci / 4, arg.f, arg.f, arg.co, 4}, | |||
| {arg.co / 4, 1, 1, 1, 4}, | |||
| {}, | |||
| {}}) / | |||
| RUNS; | |||
| param.format = Format::NCHW4; | |||
| benchmarker_cudnn.set_param(param); | |||
| auto time_in_ms_cudnn = | |||
| benchmarker_cudnn.execs( | |||
| {{arg.n, arg.ci / 4, arg.hi, arg.wi, 4}, | |||
| {arg.co, arg.ci / 4, arg.f, arg.f, 4}, | |||
| {1, arg.co / 4, 1, 1, 4}, | |||
| {}, | |||
| {}}) / | |||
| TensorShape src{arg.n, arg.ci, arg.hi, arg.wi}, | |||
| filter{arg.co, arg.ci, arg.f, arg.f}, bias{1, arg.co, 1, 1}, | |||
| z{arg.n, arg.co, ho, wo}, dst = z; | |||
| float time_in_ms = 0.f; | |||
| if (algo) { | |||
| time_in_ms = | |||
| algo_benchmark<ConvBiasForward, OprProxy<ConvBiasForward>, | |||
| CUTimer>(benchmarker, | |||
| {get_tensor_shape(src, format), | |||
| get_tensor_shape(filter, format), | |||
| get_tensor_shape(bias, format), | |||
| {}, | |||
| {}}, | |||
| algo) / | |||
| RUNS; | |||
| float flo = 2.0 * arg.n * arg.co * ho * wo * arg.ci * arg.f * | |||
| arg.f / (1e12); | |||
| TensorShape src{arg.n, arg.ci, arg.hi, arg.wi}, | |||
| filter{arg.co, arg.ci, arg.f, arg.f}; | |||
| printf("src=%s, filter=%s, time(algo=%s)=%.2f %.2fTops, " | |||
| "time(cudnn)=%.2f %.2fTops, " | |||
| "perf(algo=%s)/perf(cudnn)=%.2f\n", | |||
| src.to_string().c_str(), filter.to_string().c_str(), algo, | |||
| time_in_ms, (flo / (time_in_ms * 1e-3)), time_in_ms_cudnn, | |||
| (flo / (time_in_ms_cudnn * 1e-3)), algo, | |||
| time_in_ms_cudnn / time_in_ms); | |||
| } else { | |||
| time_in_ms = benchmarker.execs({get_tensor_shape(src, format), | |||
| get_tensor_shape(filter, format), | |||
| get_tensor_shape(bias, format), | |||
| {}, | |||
| {}}) / | |||
| RUNS; | |||
| } | |||
| Format format_cudnn = Format::NCHW4; | |||
| param.format = format_cudnn; | |||
| benchmarker_cudnn.set_param(param); | |||
| auto time_in_ms_cudnn = | |||
| benchmarker_cudnn.execs({get_tensor_shape(src, format_cudnn), | |||
| get_tensor_shape(filter, format_cudnn), | |||
| get_tensor_shape(bias, format_cudnn), | |||
| {}, | |||
| {}}) / | |||
| RUNS; | |||
| float flo = 2.0 * arg.n * arg.co * ho * wo * arg.ci * arg.f * arg.f / | |||
| (1e12); | |||
| printf("src=%s, filter=%s, dst=%s, time(algo=%s)=%.2f %.2fTops, " | |||
| "time(cudnn)=%.2f %.2fTops, " | |||
| "perf(algo=%s)/perf(cudnn)=%.2f\n", | |||
| src.to_string().c_str(), filter.to_string().c_str(), | |||
| dst.to_string().c_str(), algo, time_in_ms, | |||
| (flo / (time_in_ms * 1e-3)), time_in_ms_cudnn, | |||
| (flo / (time_in_ms_cudnn * 1e-3)), algo, | |||
| time_in_ms_cudnn / time_in_ms); | |||
| printf("bench with z tensor\n"); | |||
| for (auto&& arg : args) { | |||
| Param param; | |||
| param.pad_h = param.pad_w = arg.f / 2; | |||
| param.stride_h = param.stride_w = arg.s; | |||
| param.format = Format::CHWN4; | |||
| size_t ho = infer_conv_shape(arg.hi, arg.f, arg.s, arg.f / 2); | |||
| size_t wo = infer_conv_shape(arg.wi, arg.f, arg.s, arg.f / 2); | |||
| benchmarker.set_param(param); | |||
| auto time_in_ms = | |||
| benchmarker.execs({{arg.ci / 4, arg.hi, arg.wi, arg.n, 4}, | |||
| {arg.ci / 4, arg.f, arg.f, arg.co, 4}, | |||
| {arg.co / 4, 1, 1, 1, 4}, | |||
| {arg.co / 4, ho, wo, arg.n, 4}, | |||
| {}}) / | |||
| if (algo) { | |||
| time_in_ms = | |||
| algo_benchmark<ConvBiasForward, OprProxy<ConvBiasForward>, | |||
| CUTimer>(benchmarker, | |||
| {get_tensor_shape(src, format), | |||
| get_tensor_shape(filter, format), | |||
| get_tensor_shape(bias, format), | |||
| get_tensor_shape(z, format), | |||
| {}}, | |||
| algo) / | |||
| RUNS; | |||
| param.format = Format::NCHW4; | |||
| benchmarker_cudnn.set_param(param); | |||
| param.nonlineMode = Param::NonlineMode::IDENTITY; | |||
| auto time_in_ms_cudnn = | |||
| benchmarker_cudnn.execs( | |||
| {{arg.n, arg.ci / 4, arg.hi, arg.wi, 4}, | |||
| {arg.co, arg.ci / 4, arg.f, arg.f, 4}, | |||
| {1, arg.co / 4, 1, 1, 4}, | |||
| {arg.n, arg.co / 4, ho, wo, 4}, | |||
| {}}) / | |||
| RUNS; | |||
| float flo = 2.0 * arg.n * arg.co * ho * wo * arg.ci * arg.f * | |||
| arg.f / (1e12); | |||
| TensorShape src{arg.n, arg.ci, arg.hi, arg.wi}, | |||
| filter{arg.co, arg.ci, arg.f, arg.f}; | |||
| printf("src=%s, filter=%s, time(algo=%s)=%.2f %.2fTops, " | |||
| "time(cudnn)=%.2f %.2fTops, " | |||
| "perf(algo=%s)/perf(cudnn)=%.2f\n", | |||
| src.to_string().c_str(), filter.to_string().c_str(), algo, | |||
| time_in_ms, (flo / (time_in_ms * 1e-3)), time_in_ms_cudnn, | |||
| (flo / (time_in_ms_cudnn * 1e-3)), algo, | |||
| time_in_ms_cudnn / time_in_ms); | |||
| } else { | |||
| time_in_ms = benchmarker.execs({get_tensor_shape(src, format), | |||
| get_tensor_shape(filter, format), | |||
| get_tensor_shape(bias, format), | |||
| get_tensor_shape(z, format), | |||
| {}}) / | |||
| RUNS; | |||
| } | |||
| time_in_ms_cudnn = | |||
| benchmarker_cudnn.execs({get_tensor_shape(src, format_cudnn), | |||
| get_tensor_shape(filter, format_cudnn), | |||
| get_tensor_shape(bias, format_cudnn), | |||
| get_tensor_shape(z, format_cudnn), | |||
| {}}) / | |||
| RUNS; | |||
| printf("src=%s, filter=%s, dst=%s, time(algo=%s)=%.2f %.2fTops, " | |||
| "time(cudnn)=%.2f %.2fTops, " | |||
| "perf(algo=%s)/perf(cudnn)=%.2f\n", | |||
| src.to_string().c_str(), filter.to_string().c_str(), | |||
| dst.to_string().c_str(), algo, time_in_ms, | |||
| (flo / (time_in_ms * 1e-3)), time_in_ms_cudnn, | |||
| (flo / (time_in_ms_cudnn * 1e-3)), algo, | |||
| time_in_ms_cudnn / time_in_ms); | |||
| } | |||
| } | |||
| @@ -313,10 +259,7 @@ void benchmark_target_algo_with_cudnn_tsc( | |||
| std::unique_ptr<OprProxy<ConvBiasForward>> proxy{ | |||
| new OprProxy<ConvBiasForward>{true}}; | |||
| if (algo) { | |||
| benchmarker.set_before_exec_callback( | |||
| conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(algo)); | |||
| } else { | |||
| if (!algo) { | |||
| benchmarker.set_proxy(proxy); | |||
| } | |||
| @@ -340,163 +283,132 @@ void benchmark_target_algo_with_cudnn_tsc( | |||
| using Param = ConvBias::Param; | |||
| using Format = Param::Format; | |||
| if (format == Format::NCHW4) { | |||
| for (auto&& arg : args) { | |||
| Param param; | |||
| param.pad_h = param.pad_w = arg.f / 2; | |||
| param.stride_h = param.stride_w = arg.s; | |||
| param.format = Format::NCHW4; | |||
| size_t ho = infer_conv_shape(arg.hi, arg.f, arg.s, arg.f / 2); | |||
| size_t wo = infer_conv_shape(arg.wi, arg.f, arg.s, arg.f / 2); | |||
| benchmarker.set_param(param); | |||
| if (!algo) { | |||
| benchmarker.proxy()->target_algo = nullptr; | |||
| } | |||
| auto time_in_ms = | |||
| benchmarker.execs({{arg.n, arg.ci / 4, arg.hi, arg.wi, 4}, | |||
| {arg.co, arg.ci / 4, arg.f, arg.f, 4}, | |||
| {1, arg.co / 4, 1, 1, 4}, | |||
| {}, | |||
| {}}) / | |||
| RUNS; | |||
| param.format = Format::NCHW32; | |||
| benchmarker_cudnn.set_param(param); | |||
| auto time_in_ms_cudnn = | |||
| benchmarker_cudnn.execs( | |||
| {{arg.n, arg.ci / 32, arg.hi, arg.wi, 32}, | |||
| {arg.co, arg.ci / 32, arg.f, arg.f, 32}, | |||
| {1, arg.co / 32, 1, 1, 32}, | |||
| {}, | |||
| {}}) / | |||
| RUNS; | |||
| float flo = 2.0 * arg.n * arg.co * ho * wo * arg.ci * arg.f * | |||
| arg.f / (1e12); | |||
| TensorShape src{arg.n, arg.ci, arg.hi, arg.wi}, | |||
| filter{arg.co, arg.ci, arg.f, arg.f}; | |||
| printf("src=%s, filter=%s, time(algo=%s)=%.2f %.2fTops, " | |||
| "time(cudnn)=%.2f %.2fTops, " | |||
| "perf(algo=%s)/perf(cudnn)=%.2f\n", | |||
| src.to_string().c_str(), filter.to_string().c_str(), algo, | |||
| time_in_ms, (flo / (time_in_ms * 1e-3)), time_in_ms_cudnn, | |||
| (flo / (time_in_ms_cudnn * 1e-3)), algo, | |||
| time_in_ms_cudnn / time_in_ms); | |||
| // helper function to change format | |||
| auto get_tensor_shape = [](TensorShape shape, | |||
| Format format) -> TensorShape { | |||
| TensorShape ret; | |||
| if (format == Format::NCHW4) { | |||
| ret = static_cast<TensorShape>( | |||
| TensorLayout{shape, dtype::Int8()} | |||
| .reshape({shape[0], shape[1] / 4, 4, shape[2], | |||
| shape[3]}) | |||
| .dimshuffle({0, 1, 3, 4, 2})); | |||
| } else if (format == Format::NCHW32) { | |||
| ret = static_cast<TensorShape>( | |||
| TensorLayout{shape, dtype::Int8()} | |||
| .reshape({shape[0], shape[1] / 32, 32, shape[2], | |||
| shape[3]}) | |||
| .dimshuffle({0, 1, 3, 4, 2})); | |||
| } else if (format == Format::CHWN4) { | |||
| ret = static_cast<TensorShape>( | |||
| TensorLayout{shape, dtype::Int8()} | |||
| .reshape({shape[0], shape[1] / 4, 4, shape[2], | |||
| shape[3]}) | |||
| .dimshuffle({1, 3, 4, 0, 2})); | |||
| } | |||
| } else if (format == Format::CHWN4) { | |||
| for (auto&& arg : args) { | |||
| Param param; | |||
| param.pad_h = param.pad_w = arg.f / 2; | |||
| param.stride_h = param.stride_w = arg.s; | |||
| param.format = Format::CHWN4; | |||
| size_t ho = infer_conv_shape(arg.hi, arg.f, arg.s, arg.f / 2); | |||
| size_t wo = infer_conv_shape(arg.wi, arg.f, arg.s, arg.f / 2); | |||
| benchmarker.set_param(param); | |||
| if (!algo) { | |||
| benchmarker.proxy()->target_algo = nullptr; | |||
| } | |||
| auto time_in_ms = | |||
| benchmarker.execs({{arg.ci / 4, arg.hi, arg.wi, arg.n, 4}, | |||
| {arg.ci / 4, arg.f, arg.f, arg.co, 4}, | |||
| {arg.co / 4, 1, 1, 1, 4}, | |||
| {}, | |||
| {}}) / | |||
| return ret; | |||
| }; | |||
| for (auto&& arg : args) { | |||
| Param param; | |||
| param.pad_h = param.pad_w = arg.f / 2; | |||
| param.stride_h = param.stride_w = arg.s; | |||
| param.format = format; | |||
| size_t ho = infer_conv_shape(arg.hi, arg.f, arg.s, arg.f / 2); | |||
| size_t wo = infer_conv_shape(arg.wi, arg.f, arg.s, arg.f / 2); | |||
| benchmarker.set_param(param); | |||
| if (!algo) { | |||
| benchmarker.proxy()->target_algo = nullptr; | |||
| } | |||
| TensorShape src{arg.n, arg.ci, arg.hi, arg.wi}, | |||
| filter{arg.co, arg.ci, arg.f, arg.f}, bias{1, arg.co, 1, 1}, | |||
| z{arg.n, arg.co, ho, wo}, dst = z; | |||
| // skip testcase which cannot enable nchw32 tensorcore | |||
| if (format == Format::NCHW32 && (arg.co % 32 != 0 || arg.ci % 32 != 0)) | |||
| continue; | |||
| // skip testcase which cannot enable nchw4/chwn4 tensorcore | |||
| if ((format == Format::CHWN4 || format == Format::NCHW4) && | |||
| (arg.ci % 16 != 0)) | |||
| continue; | |||
| float time_in_ms = 0.f; | |||
| if (algo) { | |||
| time_in_ms = | |||
| algo_benchmark<ConvBiasForward, OprProxy<ConvBiasForward>, | |||
| CUTimer>(benchmarker, | |||
| {get_tensor_shape(src, format), | |||
| get_tensor_shape(filter, format), | |||
| get_tensor_shape(bias, format), | |||
| {}, | |||
| {}}, | |||
| algo) / | |||
| RUNS; | |||
| float time_in_ms_cudnn = 0.f; | |||
| if (arg.ci % 32 == 0 && arg.co % 32 == 0) { | |||
| param.format = Format::NCHW32; | |||
| benchmarker_cudnn.set_param(param); | |||
| time_in_ms_cudnn = | |||
| benchmarker_cudnn.execs( | |||
| {{arg.n, arg.ci / 32, arg.hi, arg.wi, 32}, | |||
| {arg.co, arg.ci / 32, arg.f, arg.f, 32}, | |||
| {1, arg.co / 32, 1, 1, 32}, | |||
| {}, | |||
| {}}) / | |||
| RUNS; | |||
| } else { | |||
| param.format = Format::NCHW4; | |||
| benchmarker_cudnn.set_param(param); | |||
| time_in_ms_cudnn = | |||
| benchmarker_cudnn.execs( | |||
| {{arg.n, arg.ci / 4, arg.hi, arg.wi, 4}, | |||
| {arg.co, arg.ci / 4, arg.f, arg.f, 4}, | |||
| {1, arg.co / 4, 1, 1, 4}, | |||
| {}, | |||
| {}}) / | |||
| RUNS; | |||
| } | |||
| float flo = 2.0 * arg.n * arg.co * ho * wo * arg.ci * arg.f * | |||
| arg.f / (1e12); | |||
| TensorShape src{arg.n, arg.ci, arg.hi, arg.wi}, | |||
| filter{arg.co, arg.ci, arg.f, arg.f}; | |||
| printf("src=%s, filter=%s, time(algo=%s)=%.2f %.2fTops, " | |||
| "time(cudnn)=%.2f %.2fTops, " | |||
| "perf(algo=%s)/perf(cudnn)=%.2f\n", | |||
| src.to_string().c_str(), filter.to_string().c_str(), algo, | |||
| time_in_ms, (flo / (time_in_ms * 1e-3)), time_in_ms_cudnn, | |||
| (flo / (time_in_ms_cudnn * 1e-3)), algo, | |||
| time_in_ms_cudnn / time_in_ms); | |||
| } else { | |||
| time_in_ms = benchmarker.execs({get_tensor_shape(src, format), | |||
| get_tensor_shape(filter, format), | |||
| get_tensor_shape(bias, format), | |||
| {}, | |||
| {}}) / | |||
| RUNS; | |||
| } | |||
| Format format_cudnn = arg.ci % 32 == 0 && arg.co % 32 == 0 | |||
| ? Format::NCHW32 | |||
| : Format::NCHW4; | |||
| param.format = format_cudnn; | |||
| benchmarker_cudnn.set_param(param); | |||
| auto time_in_ms_cudnn = | |||
| benchmarker_cudnn.execs({get_tensor_shape(src, format_cudnn), | |||
| get_tensor_shape(filter, format_cudnn), | |||
| get_tensor_shape(bias, format_cudnn), | |||
| {}, | |||
| {}}) / | |||
| RUNS; | |||
| float flo = 2.0 * arg.n * arg.co * ho * wo * arg.ci * arg.f * arg.f / | |||
| (1e12); | |||
| printf("src=%s, filter=%s, dst=%s, time(algo=%s)=%.2f %.2fTops, " | |||
| "time(cudnn)=%.2f %.2fTops, " | |||
| "perf(algo=%s)/perf(cudnn)=%.2f\n", | |||
| src.to_string().c_str(), filter.to_string().c_str(), | |||
| dst.to_string().c_str(), algo, time_in_ms, | |||
| (flo / (time_in_ms * 1e-3)), time_in_ms_cudnn, | |||
| (flo / (time_in_ms_cudnn * 1e-3)), algo, | |||
| time_in_ms_cudnn / time_in_ms); | |||
| printf("bench with z tensor\n"); | |||
| for (auto&& arg : args) { | |||
| Param param; | |||
| param.pad_h = param.pad_w = arg.f / 2; | |||
| param.stride_h = param.stride_w = arg.s; | |||
| param.format = Format::CHWN4; | |||
| size_t ho = infer_conv_shape(arg.hi, arg.f, arg.s, arg.f / 2); | |||
| size_t wo = infer_conv_shape(arg.wi, arg.f, arg.s, arg.f / 2); | |||
| benchmarker.set_param(param); | |||
| if (!algo) { | |||
| benchmarker.proxy()->target_algo = nullptr; | |||
| } | |||
| auto time_in_ms = | |||
| benchmarker.execs({{arg.ci / 4, arg.hi, arg.wi, arg.n, 4}, | |||
| {arg.ci / 4, arg.f, arg.f, arg.co, 4}, | |||
| {arg.co / 4, 1, 1, 1, 4}, | |||
| {arg.co / 4, ho, wo, arg.n, 4}, | |||
| {}}) / | |||
| if (algo) { | |||
| time_in_ms = | |||
| algo_benchmark<ConvBiasForward, OprProxy<ConvBiasForward>, | |||
| CUTimer>(benchmarker, | |||
| {get_tensor_shape(src, format), | |||
| get_tensor_shape(filter, format), | |||
| get_tensor_shape(bias, format), | |||
| get_tensor_shape(z, format), | |||
| {}}, | |||
| algo) / | |||
| RUNS; | |||
| float time_in_ms_cudnn = 0.f; | |||
| if (arg.ci % 32 == 0 && arg.co % 32 == 0) { | |||
| param.format = Format::NCHW32; | |||
| benchmarker_cudnn.set_param(param); | |||
| time_in_ms_cudnn = | |||
| benchmarker_cudnn.execs( | |||
| {{arg.n, arg.ci / 32, arg.hi, arg.wi, 32}, | |||
| {arg.co, arg.ci / 32, arg.f, arg.f, 32}, | |||
| {1, arg.co / 32, 1, 1, 32}, | |||
| {arg.n, arg.co / 32, ho, wo, 32}, | |||
| {}}) / | |||
| RUNS; | |||
| } else { | |||
| param.format = Format::NCHW4; | |||
| benchmarker_cudnn.set_param(param); | |||
| time_in_ms_cudnn = | |||
| benchmarker_cudnn.execs( | |||
| {{arg.n, arg.ci / 4, arg.hi, arg.wi, 4}, | |||
| {arg.co, arg.ci / 4, arg.f, arg.f, 4}, | |||
| {1, arg.co / 4, 1, 1, 4}, | |||
| {arg.n, arg.co / 4, ho, wo, 4}, | |||
| {}}) / | |||
| RUNS; | |||
| } | |||
| float flo = 2.0 * arg.n * arg.co * ho * wo * arg.ci * arg.f * | |||
| arg.f / (1e12); | |||
| TensorShape src{arg.n, arg.ci, arg.hi, arg.wi}, | |||
| filter{arg.co, arg.ci, arg.f, arg.f}; | |||
| printf("src=%s, filter=%s, time(algo=%s)=%.2f %.2fTops, " | |||
| "time(cudnn)=%.2f %.2fTops, " | |||
| "perf(algo=%s)/perf(cudnn)=%.2f\n", | |||
| src.to_string().c_str(), filter.to_string().c_str(), algo, | |||
| time_in_ms, (flo / (time_in_ms * 1e-3)), time_in_ms_cudnn, | |||
| (flo / (time_in_ms_cudnn * 1e-3)), algo, | |||
| time_in_ms_cudnn / time_in_ms); | |||
| } else { | |||
| time_in_ms = benchmarker.execs({get_tensor_shape(src, format), | |||
| get_tensor_shape(filter, format), | |||
| get_tensor_shape(bias, format), | |||
| get_tensor_shape(z, format), | |||
| {}}) / | |||
| RUNS; | |||
| } | |||
| time_in_ms_cudnn = | |||
| benchmarker_cudnn.execs({get_tensor_shape(src, format_cudnn), | |||
| get_tensor_shape(filter, format_cudnn), | |||
| get_tensor_shape(bias, format_cudnn), | |||
| get_tensor_shape(z, format_cudnn), | |||
| {}}) / | |||
| RUNS; | |||
| printf("src=%s, filter=%s, dst=%s, time(algo=%s)=%.2f %.2fTops, " | |||
| "time(cudnn)=%.2f %.2fTops, " | |||
| "perf(algo=%s)/perf(cudnn)=%.2f\n", | |||
| src.to_string().c_str(), filter.to_string().c_str(), | |||
| dst.to_string().c_str(), algo, time_in_ms, | |||
| (flo / (time_in_ms * 1e-3)), time_in_ms_cudnn, | |||
| (flo / (time_in_ms_cudnn * 1e-3)), algo, | |||
| time_in_ms_cudnn / time_in_ms); | |||
| } | |||
| } | |||
| #endif | |||
| @@ -1166,6 +1078,77 @@ TEST_F(CUDA, CONV_BIAS_INT8_CHWN4_UNROLL_WIDTH_TENSORCORE_1x1_ALGO_2) { | |||
| } | |||
| #if CUDA_VERSION >= 10020 | |||
| /// \note: we only check several cases and block sizes in megdnn_test, the full | |||
| /// testcases are written in cutlass repository | |||
| TEST_F(CUDA, CUTLASS_CONV_BIAS_INT8_NCHW32_IMMA) { | |||
| require_compute_capability_eq(7, 5); | |||
| Checker<ConvBiasForward> checker(handle_cuda()); | |||
| auto check = [&checker](const std::string& algo) { | |||
| checker.set_before_exec_callback( | |||
| conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(algo.c_str())); | |||
| UniformIntRNG rng{-8, 8}; | |||
| UniformIntRNG bias_rng{-50, 50}; | |||
| UniformIntRNG const_rng{1, 1}; | |||
| // use scale that are all integers to avoid rouding error | |||
| checker.set_rng(0, &rng) | |||
| .set_rng(1, &rng) | |||
| .set_rng(2, &bias_rng) | |||
| .set_rng(3, &rng) | |||
| .set_dtype(0, dtype::QuantizedS8{6.0f}) | |||
| .set_dtype(1, dtype::QuantizedS8{1.0f}) | |||
| .set_dtype(2, dtype::QuantizedS32{6.0f}) | |||
| .set_dtype(3, dtype::QuantizedS8{1.0f}) | |||
| .set_dtype(4, dtype::QuantizedS8{6.0f}) | |||
| .set_epsilon(1e-3); | |||
| param::ConvBias param; | |||
| param.pad_h = param.pad_w = 1; | |||
| param.stride_h = param.stride_w = 1; | |||
| param.format = param::ConvBias::Format::NCHW32; | |||
| checker.set_param(param).execs({{16, 16, 7, 7, 32}, | |||
| {512, 16, 3, 3, 32}, | |||
| {1, 16, 1, 1, 32}, | |||
| {}, | |||
| {}}); | |||
| param.nonlineMode = param::ConvBias::NonlineMode::RELU; | |||
| checker.set_param(param).execs({{16, 16, 7, 7, 32}, | |||
| {512, 16, 1, 1, 32}, | |||
| {1, 16, 1, 1, 32}, | |||
| {}, | |||
| {}}); | |||
| param.nonlineMode = param::ConvBias::NonlineMode::H_SWISH; | |||
| checker.set_param(param).execs({{16, 16, 7, 7, 32}, | |||
| {512, 16, 3, 3, 32}, | |||
| {1, 16, 1, 1, 32}, | |||
| {}, | |||
| {}}); | |||
| // use non integer scale | |||
| param.nonlineMode = param::ConvBias::NonlineMode::H_SWISH; | |||
| checker.set_dtype(0, dtype::QuantizedS8{1.1f}) | |||
| .set_dtype(1, dtype::QuantizedS8{1.2f}) | |||
| .set_dtype(2, dtype::QuantizedS32{1.1f * 1.2f}) | |||
| .set_dtype(3, dtype::QuantizedS8{1.1f}) | |||
| .set_dtype(4, dtype::QuantizedS8{6.0f}) | |||
| .set_epsilon(1 + 1e-3) | |||
| .set_max_avg_error(1e-1) | |||
| .set_max_avg_biased_error(1e-1) | |||
| .execs({{16, 16, 7, 7, 32}, | |||
| {512, 16, 3, 3, 32}, | |||
| {1, 16, 1, 1, 32}, | |||
| {16, 16, 7, 7, 32}, | |||
| {}}); | |||
| }; | |||
| std::string algo = ConvBias::algo_name<ConvBias::DirectParam>( | |||
| "INT8_NCHW32_IMMA_IMPLICIT_GEMM_256X128X64_64X64X64", | |||
| ConvBias::DirectParam{}); | |||
| check(algo); | |||
| algo = ConvBias::algo_name<ConvBias::DirectParam>( | |||
| "INT8_NCHW32_IMMA_IMPLICIT_GEMM_32X64X64_32X16X64", | |||
| ConvBias::DirectParam{}); | |||
| check(algo); | |||
| } | |||
| #endif | |||
| #if MEGDNN_WITH_BENCHMARK | |||
| TEST_F(CUDA, BENCHMARK_CONV_BIAS_INT8_CHWN4) { | |||
| require_compute_capability(6, 1); | |||
| @@ -1233,6 +1216,18 @@ TEST_F(CUDA, BENCHMARK_CONV_BIAS_INT8_CHWN4_SMALL_CHANNEL) { | |||
| param::ConvBias::Format::CHWN4); | |||
| } | |||
| #if CUDA_VERSION >= 10020 | |||
| TEST_F(CUDA, BENCHMARK_CUTLASS_CONV_BIAS_INT8_NCHW32) { | |||
| require_compute_capability(7, 5); | |||
| benchmark_target_algo_with_cudnn_tsc( | |||
| handle_cuda(), get_resnet50_bench_args(256), | |||
| dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f}, | |||
| dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.0f}, | |||
| "DIRECT:INT8_NCHW32_IMMA_IMPLICIT_GEMM", | |||
| param::ConvBias::Format::NCHW32); | |||
| } | |||
| #endif | |||
| #endif | |||
| } // namespace test | |||
| @@ -34,7 +34,7 @@ bool check_compute_capability_eq(int major, int minor); | |||
| do { \ | |||
| if (!megdnn::test::check_compute_capability((x), (y))) { \ | |||
| printf("skip testcase due to cuda compute capability not " \ | |||
| "require.(expected:%d.%d)", \ | |||
| "require.(expected:%d.%d)\n", \ | |||
| (x), (y)); \ | |||
| return; \ | |||
| } \ | |||
| @@ -44,7 +44,7 @@ bool check_compute_capability_eq(int major, int minor); | |||
| do { \ | |||
| if (!megdnn::test::check_compute_capability_eq((x), (y))) { \ | |||
| printf("skip testcase due to cuda compute capability not " \ | |||
| "equal to %d.%d", \ | |||
| "equal to %d.%d\n", \ | |||
| (x), (y)); \ | |||
| return; \ | |||
| } \ | |||