GitOrigin-RevId: 10512645d5
tags/v1.5.0
| @@ -37,15 +37,16 @@ all: ${PARAM_DEFS} ${ELEMWISE_IMPL} ${CUDA_CONV_IMPL} $(CUDA_MATMUL_IMPL) | |||||
| ../src/cuda/elemwise_multi_type/kimpl: gen_elemwise_multi_type_kern_impls.py | ../src/cuda/elemwise_multi_type/kimpl: gen_elemwise_multi_type_kern_impls.py | ||||
| ./$^ --type cuda $@ | ./$^ --type cuda $@ | ||||
| ../src/cuda/conv_bias/int8/kimpl: gen_cuda_conv_bias_kern_impls.py gen_cutlass_conv_bias_kern_impls.py | |||||
| ../src/cuda/conv_bias/int8/kimpl: gen_cuda_conv_bias_kern_impls.py gen_cutlass_conv_bias_kern_impls.py cutlass_generator/generator.py | |||||
| ./gen_cuda_conv_bias_kern_impls.py --type dp4a $@ | ./gen_cuda_conv_bias_kern_impls.py --type dp4a $@ | ||||
| ./gen_cutlass_conv_bias_kern_impls.py --type dp4a $@ | ./gen_cutlass_conv_bias_kern_impls.py --type dp4a $@ | ||||
| python3 ./cutlass_generator/generator.py --operations all --type simt $@ | |||||
| ../src/cuda/conv_bias/int8_imma/kimpl: gen_cuda_conv_bias_kern_impls.py gen_cutlass_conv_bias_kern_impls.py | ../src/cuda/conv_bias/int8_imma/kimpl: gen_cuda_conv_bias_kern_impls.py gen_cutlass_conv_bias_kern_impls.py | ||||
| ./gen_cuda_conv_bias_kern_impls.py --type imma $@ | ./gen_cuda_conv_bias_kern_impls.py --type imma $@ | ||||
| ./gen_cutlass_conv_bias_kern_impls.py --type imma $@ | ./gen_cutlass_conv_bias_kern_impls.py --type imma $@ | ||||
| ../src/cuda/batch_conv_bias/int8/kimpl: gen_cuda_batch_conv_bias_kern_impls.py | |||||
| ../src/cuda/batch_conv_bias/int8/kimpl: gen_cuda_batch_conv_bias_kern_impls.py | |||||
| ./$^ --type dp4a $@ | ./$^ --type dp4a $@ | ||||
| ../src/cuda/matrix_mul/fp32_simt/kimpl: gen_cutlass_matmul_kern_impls.py | ../src/cuda/matrix_mul/fp32_simt/kimpl: gen_cutlass_matmul_kern_impls.py | ||||
| @@ -43,6 +43,7 @@ pdef('Axis').add_fields('int32', 'axis', 0) | |||||
| Doc('NCHW4_NCHW32', 'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'), | Doc('NCHW4_NCHW32', 'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'), | ||||
| Doc('NCHW32_NCHW4', 'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'), | Doc('NCHW32_NCHW4', 'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'), | ||||
| Doc('NCHW4_NCHW', 'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'), | Doc('NCHW4_NCHW', 'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'), | ||||
| Doc('NCHW4_NHWC', 'NCHW4_NHWC means input tensors are nchw4 layout, output tensor is nhwc layout'), | |||||
| Doc('NHWC_NCHW', 'NHWC_NCHW means input tensors are nhwc layout, ' | Doc('NHWC_NCHW', 'NHWC_NCHW means input tensors are nhwc layout, ' | ||||
| 'output tensor is nchw layout'), | 'output tensor is nchw layout'), | ||||
| Doc('NHWC_NCHW4_IC_SMALL', 'NHWC_NCHW4_IC_SMALL means input tensors are nhwc(c < 4) layout, ' | Doc('NHWC_NCHW4_IC_SMALL', 'NHWC_NCHW4_IC_SMALL means input tensors are nhwc(c < 4) layout, ' | ||||
| @@ -99,6 +100,7 @@ pdef('Axis').add_fields('int32', 'axis', 0) | |||||
| Doc('NCHW4_NCHW32', 'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'), | Doc('NCHW4_NCHW32', 'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'), | ||||
| Doc('NCHW32_NCHW4', 'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'), | Doc('NCHW32_NCHW4', 'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'), | ||||
| Doc('NCHW4_NCHW', 'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'), | Doc('NCHW4_NCHW', 'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'), | ||||
| Doc('NCHW4_NHWC', 'NCHW4_NHWC means input tensors are nchw4 layout, output tensor is nhwc layout'), | |||||
| Doc('NHWC_NCHW', 'NHWC_NCHW means input tensors are nhwc layout, ' | Doc('NHWC_NCHW', 'NHWC_NCHW means input tensors are nhwc layout, ' | ||||
| 'output tensor is nchw layout'), | 'output tensor is nchw layout'), | ||||
| Doc('NHWC_NCHW4_IC_SMALL', 'NHWC_NCHW4_IC_SMALL means input tensors are nhwc(c < 4) layout, ' | Doc('NHWC_NCHW4_IC_SMALL', 'NHWC_NCHW4_IC_SMALL means input tensors are nhwc(c < 4) layout, ' | ||||
| @@ -65,7 +65,8 @@ void do_check_exec_common( | |||||
| bias.to_string().c_str(), dst.to_string().c_str()); | bias.to_string().c_str(), dst.to_string().c_str()); | ||||
| megdnn_assert(bias.shape[2] == 1); | megdnn_assert(bias.shape[2] == 1); | ||||
| megdnn_assert(bias.shape[3] == 1); | megdnn_assert(bias.shape[3] == 1); | ||||
| } else if (opr->param().format == param::ConvBias::Format::NHWC) { | |||||
| } else if (param().format == param::ConvBias::Format::NHWC || | |||||
| param().format == param::ConvBias::Format::NCHW4_NHWC) { | |||||
| megdnn_assert(bias.shape[0] == 1); | megdnn_assert(bias.shape[0] == 1); | ||||
| megdnn_assert(bias.shape[1] == 1); | megdnn_assert(bias.shape[1] == 1); | ||||
| megdnn_assert(bias.shape[2] == 1); | megdnn_assert(bias.shape[2] == 1); | ||||
| @@ -368,7 +368,8 @@ void make_canonized_filter_meta_nchwx( | |||||
| megdnn_assert(param.format == Param::Format::NCHW4 || | megdnn_assert(param.format == Param::Format::NCHW4 || | ||||
| param.format == Param::Format::NCHW8 || | param.format == Param::Format::NCHW8 || | ||||
| param.format == Param::Format::NCHW32 || | param.format == Param::Format::NCHW32 || | ||||
| param.format == Param::Format::NCHW4_NCHW || | |||||
| param.format == Param::Format::NCHW4_NCHW || | |||||
| param.format == Param::Format::NCHW4_NHWC || | |||||
| param.format == Param::Format::NCHW4_NCHW32 || | param.format == Param::Format::NCHW4_NCHW32 || | ||||
| param.format == Param::Format::NCHW32_NCHW4 || | param.format == Param::Format::NCHW32_NCHW4 || | ||||
| param.format == Param::Format::NCHW64); | param.format == Param::Format::NCHW64); | ||||
| @@ -498,6 +499,7 @@ ConvolutionBase<Parameter>::make_canonized_filter_meta( | |||||
| } | } | ||||
| } else if (param().format == Param::Format::NCHW4 || | } else if (param().format == Param::Format::NCHW4 || | ||||
| param().format == Param::Format::NCHW4_NCHW || | param().format == Param::Format::NCHW4_NCHW || | ||||
| param().format == Param::Format::NCHW4_NHWC || | |||||
| param().format == Param::Format::NCHW4_NCHW32) { | param().format == Param::Format::NCHW4_NCHW32) { | ||||
| make_canonized_filter_meta_nchwx<4, Parameter>(src_ndim, filter, | make_canonized_filter_meta_nchwx<4, Parameter>(src_ndim, filter, | ||||
| param(), ret); | param(), ret); | ||||
| @@ -547,7 +549,12 @@ void ConvolutionBase<Parameter>::check_or_deduce_dtype_fwd(DType src, | |||||
| src.enumv() == DTypeEnum::Quantized4Asymm) { | src.enumv() == DTypeEnum::Quantized4Asymm) { | ||||
| supported_dst_dtype.push_back( | supported_dst_dtype.push_back( | ||||
| dtype::QuantizedS32(mul_scale(src, filter))); | dtype::QuantizedS32(mul_scale(src, filter))); | ||||
| if (dst.valid() && dst.enumv() == src.enumv()) { | |||||
| bool cond_dst = | |||||
| dst.valid() && (dst.enumv() == src.enumv() || | |||||
| ((dst.enumv() == DTypeEnum::QuantizedS4 || | |||||
| dst.enumv() == DTypeEnum::Quantized4Asymm) && | |||||
| src.enumv() == DTypeEnum::QuantizedS8)); | |||||
| if (cond_dst) { | |||||
| supported_dst_dtype.push_back(dst); | supported_dst_dtype.push_back(dst); | ||||
| } | } | ||||
| if (src.enumv() == DTypeEnum::QuantizedS8) { | if (src.enumv() == DTypeEnum::QuantizedS8) { | ||||
| @@ -611,7 +618,8 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src, | |||||
| } else { | } else { | ||||
| megdnn_assert(param().format == Param::Format::NHWCD4 || | megdnn_assert(param().format == Param::Format::NHWCD4 || | ||||
| param().format == Param::Format::NCHW4 || | param().format == Param::Format::NCHW4 || | ||||
| param().format == Param::Format::NCHW4_NCHW || | |||||
| param().format == Param::Format::NCHW4_NCHW || | |||||
| param().format == Param::Format::NCHW4_NHWC || | |||||
| param().format == Param::Format::NCHW4_NCHW32 || | param().format == Param::Format::NCHW4_NCHW32 || | ||||
| param().format == Param::Format::NCHW44 || | param().format == Param::Format::NCHW44 || | ||||
| param().format == Param::Format::NCHW44_DOT || | param().format == Param::Format::NCHW44_DOT || | ||||
| @@ -879,6 +887,21 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src, | |||||
| cflt.stride[0], cflt.padding[0]); | cflt.stride[0], cflt.padding[0]); | ||||
| dst[3] = infer_conv_shape(src[3], cflt.dilated_spatial[1], | dst[3] = infer_conv_shape(src[3], cflt.dilated_spatial[1], | ||||
| cflt.stride[1], cflt.padding[1]); | cflt.stride[1], cflt.padding[1]); | ||||
| } else if (param().format == Param::Format::NCHW4_NHWC) { | |||||
| megdnn_assert(src.ndim == 5, | |||||
| "invalid src ndim for NCHW4_NHWC, expected=5, got=%zu", | |||||
| src.ndim); | |||||
| megdnn_assert(cflt.icpg * cflt.group == src[1] * 4, | |||||
| "%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg, | |||||
| cflt.group); | |||||
| dst.ndim = 4; | |||||
| dst[0] = src[0]; | |||||
| dst[1] = infer_conv_shape(src[2], cflt.dilated_spatial[0], | |||||
| cflt.stride[0], cflt.padding[0]); | |||||
| dst[2] = infer_conv_shape(src[3], cflt.dilated_spatial[1], | |||||
| cflt.stride[1], cflt.padding[1]); | |||||
| auto oc = cflt.ocpg * cflt.group; | |||||
| dst[3] = oc; | |||||
| } else if (param().format == Param::Format::NCHW4_NCHW32) { | } else if (param().format == Param::Format::NCHW4_NCHW32) { | ||||
| megdnn_assert(src.ndim == 5, | megdnn_assert(src.ndim == 5, | ||||
| "invalid src ndim for NCHW4_NCHW32, expected=5, got=%zu", | "invalid src ndim for NCHW4_NCHW32, expected=5, got=%zu", | ||||
| @@ -35,6 +35,9 @@ bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::is_available( | |||||
| args.src_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm) && | args.src_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm) && | ||||
| args.filter_layout->dtype.enumv() == DTypeEnum::QuantizedS4) | args.filter_layout->dtype.enumv() == DTypeEnum::QuantizedS4) | ||||
| return false; | return false; | ||||
| if (args.dst_layout->dtype.enumv() == DTypeEnum::QuantizedS4 || | |||||
| args.dst_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm) | |||||
| return false; | |||||
| if (args.src_layout->dtype == args.filter_layout->dtype && | if (args.src_layout->dtype == args.filter_layout->dtype && | ||||
| args.src_layout->dtype == dtype::BFloat16()) { | args.src_layout->dtype == dtype::BFloat16()) { | ||||
| return false; | return false; | ||||
| @@ -911,4 +911,140 @@ void megdnn::cuda::cutlass_wrapper:: | |||||
| INST(true); | INST(true); | ||||
| #undef INST | #undef INST | ||||
| /* ===== cutlass kernel wrapper for nchw4 layout and nhwc output ===== */ | |||||
| #if MEGDNN_TEGRA_X1 | |||||
| template <bool signedness> | |||||
| void megdnn::cuda::cutlass_wrapper:: | |||||
| do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nhwc( | |||||
| const int8_t* /* d_src */, const int8_t* /* d_filter */, | |||||
| const int32_t* /* d_bias */, const int8_t* /* d_z */, | |||||
| int8_t* /* d_dst */, int* /* workspace */, | |||||
| const convolution::ConvParam& /* param */, | |||||
| uint32_t /* nonlinear_mode */, float /* alpha */, | |||||
| float /* beta */, float /* gamma */, float /* delta */, | |||||
| float /* theta */, float /* scale */, | |||||
| const GemmCoord& /* threadblock_shape */, | |||||
| const GemmCoord& /* warp_shape */, int /* stages */, | |||||
| cudaStream_t /* stream */) {} | |||||
| #else | |||||
| template <bool signedness> | |||||
| void megdnn::cuda::cutlass_wrapper:: | |||||
| do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nhwc( | |||||
| const int8_t* d_src, const int8_t* d_filter, | |||||
| const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, | |||||
| int* workspace, const convolution::ConvParam& param, | |||||
| uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||||
| float delta, float theta, float scale, | |||||
| const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | |||||
| int stages, cudaStream_t stream) { | |||||
| #define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||||
| threadblock_k_, warp_m_, warp_n_, \ | |||||
| warp_k_, stages_, aligned_) \ | |||||
| if (threadblock_shape.m() == threadblock_m_ && \ | |||||
| threadblock_shape.n() == threadblock_n_ && \ | |||||
| threadblock_shape.k() == threadblock_k_ && \ | |||||
| warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||||
| warp_shape.k() == warp_k_ && stages == stages_) { \ | |||||
| using ThreadBlockShape = \ | |||||
| cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||||
| threadblock_k_>; \ | |||||
| using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||||
| using InstructionShape = cutlass::gemm::GemmShape<1, 1, 4>; \ | |||||
| using Convolution = cutlass::conv::device::Convolution< \ | |||||
| int8_t, cutlass::layout::TensorNCxHWx<4>, int8_t, \ | |||||
| cutlass::layout::TensorCxRSKx<4>, ElementOutput, \ | |||||
| cutlass::layout::TensorNHWC, int32_t, \ | |||||
| cutlass::layout::TensorNHWC, int32_t, \ | |||||
| cutlass::conv::ConvType::kConvolution, \ | |||||
| cutlass::arch::OpClassSimt, cutlass::arch::Sm75, \ | |||||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||||
| cutlass::conv::threadblock:: \ | |||||
| ConvolutionFpropNCxHWxThreadblockSwizzle, \ | |||||
| stages_, 4, aligned_, NeedLoadFromConstMem, \ | |||||
| cutlass::arch::OpMultiplyAddSaturate>; \ | |||||
| typename Convolution::ConvolutionParameter conv_param( \ | |||||
| param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||||
| param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||||
| param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||||
| return cutlass_convolution_wrapper<Convolution>( \ | |||||
| d_src, d_filter, d_bias, \ | |||||
| reinterpret_cast<const ElementOutput*>(d_z), \ | |||||
| reinterpret_cast<ElementOutput*>(d_dst), workspace, \ | |||||
| conv_param, epilogue, stream); \ | |||||
| } | |||||
| #define DISPATCH_KERNEL \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 32, 64, 32, 32, 2, 16); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 2, 16); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 32, 64, 32, 32, 2, 16); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 32, 32, 32, 32, 32, 2, 16); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 128, 16, 1, 8); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 2, 4); \ | |||||
| megdnn_assert(false, \ | |||||
| "unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||||
| "(%dx%dx%d)", \ | |||||
| threadblock_shape.m(), threadblock_shape.n(), \ | |||||
| threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||||
| warp_shape.k()); | |||||
| using ElementOutput = cutlass::integer_subbyte<4, signedness>; | |||||
| using ElementAccumulator = int32_t; | |||||
| using ElementBias = int32_t; | |||||
| using ElementCompute = float; | |||||
| using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; | |||||
| switch (nonlinear_mode) { | |||||
| case NonlineMode::IDENTITY: { | |||||
| using EpilogueOp = | |||||
| cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||||
| ElementOutput, 8, ElementAccumulator, ElementBias, | |||||
| ElementCompute>; | |||||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, | |||||
| delta + theta}; | |||||
| DISPATCH_KERNEL; | |||||
| } | |||||
| case NonlineMode::RELU: { | |||||
| using EpilogueOp = cutlass::epilogue::thread:: | |||||
| BiasAddLinearCombinationReluClamp< | |||||
| ElementOutput, 8, ElementAccumulator, ElementBias, | |||||
| ElementCompute>; | |||||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, | |||||
| 0, delta, theta}; | |||||
| DISPATCH_KERNEL; | |||||
| } | |||||
| case NonlineMode::H_SWISH: { | |||||
| using EpilogueOp = cutlass::epilogue::thread:: | |||||
| BiasAddLinearCombinationHSwishClamp< | |||||
| ElementOutput, 8, ElementAccumulator, ElementBias, | |||||
| ElementCompute>; | |||||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, | |||||
| scale, detla, theta}; | |||||
| DISPATCH_KERNEL; | |||||
| } | |||||
| default: | |||||
| megdnn_assert(false, | |||||
| "unsupported nonlinear mode for conv bias operator"); | |||||
| } | |||||
| #undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||||
| #undef DISPATCH_KERNEL | |||||
| } | |||||
| #endif | |||||
| #define INST(signedness) \ | |||||
| template void megdnn::cuda::cutlass_wrapper:: \ | |||||
| do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nhwc<signedness>( \ | |||||
| const int8_t* d_src, const int8_t* d_filter, \ | |||||
| const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \ | |||||
| int* workspace, const convolution::ConvParam& param, \ | |||||
| uint32_t nonlinear_mode, float alpha, float beta, \ | |||||
| float gamma, float delta, float theta, float scale, \ | |||||
| const GemmCoord& threadblock_shape, \ | |||||
| const GemmCoord& warp_shape, int stages, \ | |||||
| cudaStream_t stream); | |||||
| INST(true); | |||||
| INST(false); | |||||
| #undef INST | |||||
| // vim: syntax=cuda.doxygen | // vim: syntax=cuda.doxygen | ||||
| @@ -94,6 +94,15 @@ void do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64( | |||||
| float scale, uint8_t src_zero_point, const GemmCoord& threadblock_shape, | float scale, uint8_t src_zero_point, const GemmCoord& threadblock_shape, | ||||
| const GemmCoord& warp_shape, cudaStream_t stream); | const GemmCoord& warp_shape, cudaStream_t stream); | ||||
| template <bool signedness> | |||||
| void do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nhwc( | |||||
| const int8_t* d_src, const int8_t* d_filter, const int32_t* d_bias, | |||||
| const int8_t* d_z, int8_t* d_dst, int* workspace, | |||||
| const convolution::ConvParam& param, uint32_t nonlinear_mode, | |||||
| float alpha, float beta, float gamma, float delta, float theta, | |||||
| float scale, const GemmCoord& threadblock_shape, | |||||
| const GemmCoord& warp_shape, int stages, cudaStream_t stream); | |||||
| } // namespace cutlass_wrapper | } // namespace cutlass_wrapper | ||||
| } // namespace cuda | } // namespace cuda | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| @@ -0,0 +1,65 @@ | |||||
| /** | |||||
| * \file | |||||
| * dnn/src/cuda/conv_bias/int8/conv_bias_int8_implicit_gemm_cutlass_wrapper.cuinl | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "cutlass/convolution/device/convolution.h" | |||||
| #include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | |||||
| using namespace megdnn; | |||||
| using namespace cuda; | |||||
| using namespace cutlass_wrapper; | |||||
| template <typename Convolution> | |||||
| void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper( | |||||
| const typename Convolution::ElementSrc* d_src, | |||||
| const typename Convolution::ElementFilter* d_filter, | |||||
| const typename Convolution::ElementBias* d_bias, | |||||
| const typename Convolution::ElementDst* d_z, | |||||
| typename Convolution::ElementDst* d_dst, int* workspace, | |||||
| typename Convolution::ConvolutionParameter const& conv_param, | |||||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||||
| cudaStream_t stream, typename Convolution::ExtraParam extra_param) { | |||||
| typename Convolution::TensorRefSrc tensor_src{ | |||||
| const_cast<typename Convolution::ElementSrc*>(d_src), | |||||
| Convolution::LayoutSrc::packed( | |||||
| {conv_param.N, conv_param.H, conv_param.W, conv_param.C})}; | |||||
| typename Convolution::TensorRefFilter tensor_filter{ | |||||
| const_cast<typename Convolution::ElementFilter*>(d_filter), | |||||
| Convolution::LayoutFilter::packed( | |||||
| {conv_param.K, conv_param.R, conv_param.S, conv_param.C})}; | |||||
| typename Convolution::TensorRefBias tensor_bias{ | |||||
| const_cast<typename Convolution::ElementBias*>(d_bias), | |||||
| Convolution::LayoutBias::packed({1, 1, 1, conv_param.K})}; | |||||
| typename Convolution::TensorRefDst tensor_z{ | |||||
| const_cast<typename Convolution::ElementDst*>(d_z), | |||||
| Convolution::LayoutDst::packed( | |||||
| {conv_param.N, conv_param.P, conv_param.Q, conv_param.K})}; | |||||
| typename Convolution::TensorRefDst tensor_dst{ | |||||
| d_dst, | |||||
| Convolution::LayoutDst::packed( | |||||
| {conv_param.N, conv_param.P, conv_param.Q, conv_param.K})}; | |||||
| typename Convolution::Arguments arguments{conv_param, | |||||
| tensor_src.non_const_ref(), | |||||
| tensor_filter.non_const_ref(), | |||||
| tensor_bias.non_const_ref(), | |||||
| tensor_z.non_const_ref(), | |||||
| tensor_dst.non_const_ref(), | |||||
| epilogue, | |||||
| {}, | |||||
| {}, | |||||
| extra_param}; | |||||
| Convolution conv_op; | |||||
| cutlass_check(conv_op.initialize(arguments, workspace)); | |||||
| cutlass_check(conv_op(stream)); | |||||
| after_kernel_launch(); | |||||
| } | |||||
| // vim: syntax=cuda.doxygen | |||||
| @@ -37,27 +37,40 @@ bool ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::is_available( | |||||
| if (!check_bias_share_in_channel(*(args.bias_layout), | if (!check_bias_share_in_channel(*(args.bias_layout), | ||||
| param.format)) | param.format)) | ||||
| return false; | return false; | ||||
| if (param.format == Format::NCHW4_NCHW32) { | |||||
| if (m_algo_param.threadblock_m % 32 != 0) | |||||
| return false; | |||||
| } else if (param.format != Format::NCHW4_NCHW && | |||||
| param.format != Format::NCHW4) | |||||
| return false; | |||||
| bool valid_format = param.format == Format::NCHW4_NCHW32 && | |||||
| m_algo_param.threadblock_m % 32 == 0; | |||||
| valid_format |= param.format == Format::NCHW4_NCHW && | |||||
| args.bias_layout->dtype.enumv() == DTypeEnum::Float32 && | |||||
| args.dst_layout->dtype.enumv() == DTypeEnum::Float32; | |||||
| valid_format |= | |||||
| param.format == Format::NCHW4_NHWC && | |||||
| args.bias_layout->dtype.enumv() == DTypeEnum::QuantizedS32 && | |||||
| (args.dst_layout->dtype.enumv() == DTypeEnum::QuantizedS4 || | |||||
| args.dst_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm); | |||||
| valid_format |= param.format == Format::NCHW4; | |||||
| if (!valid_format) return false; | |||||
| size_t n = args.src_layout->operator[](0), | size_t n = args.src_layout->operator[](0), | ||||
| ci = args.src_layout->operator[](1) * 4, | ci = args.src_layout->operator[](1) * 4, | ||||
| hi = args.src_layout->operator[](2), | hi = args.src_layout->operator[](2), | ||||
| wi = args.src_layout->operator[](3); | wi = args.src_layout->operator[](3); | ||||
| size_t ho = args.dst_layout->operator[](2), | |||||
| wo = args.dst_layout->operator[](3); | |||||
| size_t co; | size_t co; | ||||
| size_t dst_spatial_pos; | |||||
| if (param.format == Format::NCHW4) { | if (param.format == Format::NCHW4) { | ||||
| co = args.dst_layout->operator[](1) * 4; | co = args.dst_layout->operator[](1) * 4; | ||||
| dst_spatial_pos = 2; | |||||
| } else if (param.format == Format::NCHW4_NCHW) { | } else if (param.format == Format::NCHW4_NCHW) { | ||||
| co = args.dst_layout->operator[](1); | co = args.dst_layout->operator[](1); | ||||
| dst_spatial_pos = 2; | |||||
| } else if (param.format == Format::NCHW4_NHWC) { | |||||
| co = args.dst_layout->operator[](3); | |||||
| dst_spatial_pos = 1; | |||||
| } else { | } else { | ||||
| megdnn_assert(param.format == Format::NCHW4_NCHW32); | megdnn_assert(param.format == Format::NCHW4_NCHW32); | ||||
| dst_spatial_pos = 2; | |||||
| co = args.dst_layout->operator[](1) * 32; | co = args.dst_layout->operator[](1) * 32; | ||||
| } | } | ||||
| size_t ho = args.dst_layout->operator[](dst_spatial_pos), | |||||
| wo = args.dst_layout->operator[](dst_spatial_pos + 1); | |||||
| UNPACK_CONV_PARAMETER(fm, param); | UNPACK_CONV_PARAMETER(fm, param); | ||||
| MARK_USED_VAR | MARK_USED_VAR | ||||
| // TODO support group conv | // TODO support group conv | ||||
| @@ -72,7 +85,9 @@ bool ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::is_available( | |||||
| available &= (src_dtype.enumv() == DTypeEnum::QuantizedS8 && | available &= (src_dtype.enumv() == DTypeEnum::QuantizedS8 && | ||||
| filter_dtype.enumv() == DTypeEnum::QuantizedS8); | filter_dtype.enumv() == DTypeEnum::QuantizedS8); | ||||
| available &= (bias_dtype.enumv() == DTypeEnum::QuantizedS32 && | available &= (bias_dtype.enumv() == DTypeEnum::QuantizedS32 && | ||||
| dst_dtype.enumv() == DTypeEnum::QuantizedS8) || | |||||
| (dst_dtype.enumv() == DTypeEnum::QuantizedS8 || | |||||
| dst_dtype.enumv() == DTypeEnum::QuantizedS4 || | |||||
| dst_dtype.enumv() == DTypeEnum::Quantized4Asymm)) || | |||||
| (bias_dtype.enumv() == DTypeEnum::Float32 && | (bias_dtype.enumv() == DTypeEnum::Float32 && | ||||
| dst_dtype.enumv() == DTypeEnum::Float32); | dst_dtype.enumv() == DTypeEnum::Float32); | ||||
| // TODO: support dialtion | // TODO: support dialtion | ||||
| @@ -111,17 +126,23 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec( | |||||
| ci = args.src_layout->operator[](1) * 4, | ci = args.src_layout->operator[](1) * 4, | ||||
| hi = args.src_layout->operator[](2), | hi = args.src_layout->operator[](2), | ||||
| wi = args.src_layout->operator[](3); | wi = args.src_layout->operator[](3); | ||||
| size_t ho = args.dst_layout->operator[](2), | |||||
| wo = args.dst_layout->operator[](3); | |||||
| size_t co; | |||||
| size_t co, dst_spatial_pos; | |||||
| if (param.format == Format::NCHW4) { | if (param.format == Format::NCHW4) { | ||||
| co = args.dst_layout->operator[](1) * 4; | co = args.dst_layout->operator[](1) * 4; | ||||
| dst_spatial_pos = 2; | |||||
| } else if (param.format == Format::NCHW4_NCHW) { | } else if (param.format == Format::NCHW4_NCHW) { | ||||
| co = args.dst_layout->operator[](1); | co = args.dst_layout->operator[](1); | ||||
| dst_spatial_pos = 2; | |||||
| } else if (param.format == Format::NCHW4_NHWC) { | |||||
| co = args.dst_layout->operator[](3); | |||||
| dst_spatial_pos = 1; | |||||
| } else { | } else { | ||||
| megdnn_assert(param.format == Format::NCHW4_NCHW32); | megdnn_assert(param.format == Format::NCHW4_NCHW32); | ||||
| dst_spatial_pos = 2; | |||||
| co = args.dst_layout->operator[](1) * 32; | co = args.dst_layout->operator[](1) * 32; | ||||
| } | } | ||||
| size_t ho = args.dst_layout->operator[](dst_spatial_pos), | |||||
| wo = args.dst_layout->operator[](dst_spatial_pos + 1); | |||||
| UNPACK_CONV_PARAMETER(fm, param); | UNPACK_CONV_PARAMETER(fm, param); | ||||
| MARK_USED_VAR | MARK_USED_VAR | ||||
| auto&& stream = cuda_stream(args.opr->handle()); | auto&& stream = cuda_stream(args.opr->handle()); | ||||
| @@ -161,136 +182,107 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec( | |||||
| float beta = 1.f; | float beta = 1.f; | ||||
| float dst_scale = 1.f; | float dst_scale = 1.f; | ||||
| if (args.bias_layout->dtype.enumv() == DTypeEnum::QuantizedS32) { | if (args.bias_layout->dtype.enumv() == DTypeEnum::QuantizedS32) { | ||||
| megdnn_assert(args.dst_layout->dtype.enumv() == DTypeEnum::QuantizedS8); | |||||
| megdnn_assert(args.dst_layout->dtype.category() == | |||||
| DTypeCategory::QUANTIZED); | |||||
| float bias_scale = args.bias_layout->dtype.param<dtype::QuantizedS32>() | float bias_scale = args.bias_layout->dtype.param<dtype::QuantizedS32>() | ||||
| .scale, | |||||
| dst_scale = | |||||
| args.dst_layout->dtype.param<dtype::QuantizedS8>().scale; | |||||
| .scale; | |||||
| dst_scale = get_scale(args.dst_layout->dtype); | |||||
| alpha /= dst_scale, beta = bias_scale / dst_scale; | alpha /= dst_scale, beta = bias_scale / dst_scale; | ||||
| } | } | ||||
| float gamma = 0.f; | float gamma = 0.f; | ||||
| if (args.z_layout->ndim > 0) { | if (args.z_layout->ndim > 0) { | ||||
| gamma = 1.f; | gamma = 1.f; | ||||
| if (args.z_layout->dtype.enumv() == DTypeEnum::QuantizedS8) { | |||||
| megdnn_assert(args.dst_layout->dtype.enumv() == | |||||
| DTypeEnum::QuantizedS8); | |||||
| float z_scale = args.z_layout->dtype.param<dtype::QuantizedS8>() | |||||
| .scale; | |||||
| if (args.z_layout->dtype.category() == DTypeCategory::QUANTIZED) { | |||||
| megdnn_assert(args.dst_layout->dtype.category() == | |||||
| DTypeCategory::QUANTIZED); | |||||
| float z_scale = get_scale(args.z_layout->dtype); | |||||
| gamma = z_scale / dst_scale; | gamma = z_scale / dst_scale; | ||||
| } | } | ||||
| } | } | ||||
| uint32_t nonlinear_mode = static_cast<uint32_t>(param.nonlineMode); | uint32_t nonlinear_mode = static_cast<uint32_t>(param.nonlineMode); | ||||
| if (fh == 1 && fw == 1) { | |||||
| if (param.format == Format::NCHW4) { | |||||
| cutlass_wrapper::do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4< | |||||
| false>( | |||||
| args.src_tensor->compatible_ptr<int8_t>(), filter_ptr, | |||||
| args.bias_tensor->compatible_ptr<int32_t>(), | |||||
| args.z_tensor->compatible_ptr<int8_t>(), | |||||
| args.dst_tensor->compatible_ptr<int8_t>(), nullptr, | |||||
| kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale, | |||||
| cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, | |||||
| m_algo_param.threadblock_n, | |||||
| m_algo_param.threadblock_k}, | |||||
| cutlass_wrapper::GemmCoord{m_algo_param.warp_m, | |||||
| m_algo_param.warp_n, | |||||
| m_algo_param.warp_k}, | |||||
| m_algo_param.stage, stream); | |||||
| } else if (param.format == Format::NCHW4_NCHW) { | |||||
| cutlass_wrapper:: | |||||
| do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nchw<false>( | |||||
| args.src_tensor->compatible_ptr<int8_t>(), | |||||
| filter_ptr, | |||||
| args.bias_tensor->compatible_ptr<float>(), | |||||
| args.z_tensor->compatible_ptr<float>(), | |||||
| args.dst_tensor->compatible_ptr<float>(), nullptr, | |||||
| kern_param, nonlinear_mode, alpha, beta, gamma, | |||||
| dst_scale, | |||||
| cutlass_wrapper::GemmCoord{ | |||||
| m_algo_param.threadblock_m, | |||||
| m_algo_param.threadblock_n, | |||||
| m_algo_param.threadblock_k}, | |||||
| cutlass_wrapper::GemmCoord{m_algo_param.warp_m, | |||||
| m_algo_param.warp_n, | |||||
| m_algo_param.warp_k}, | |||||
| m_algo_param.stage, stream); | |||||
| } else { | |||||
| megdnn_assert(param.format == Format::NCHW4_NCHW32); | |||||
| cutlass_wrapper:: | |||||
| do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32< | |||||
| false>( | |||||
| args.src_tensor->compatible_ptr<int8_t>(), | |||||
| filter_ptr, | |||||
| args.bias_tensor->compatible_ptr<int32_t>(), | |||||
| args.z_tensor->compatible_ptr<int8_t>(), | |||||
| args.dst_tensor->compatible_ptr<int8_t>(), nullptr, | |||||
| kern_param, nonlinear_mode, alpha, beta, gamma, | |||||
| dst_scale, | |||||
| cutlass_wrapper::GemmCoord{ | |||||
| m_algo_param.threadblock_m, | |||||
| m_algo_param.threadblock_n, | |||||
| m_algo_param.threadblock_k}, | |||||
| cutlass_wrapper::GemmCoord{m_algo_param.warp_m, | |||||
| m_algo_param.warp_n, | |||||
| m_algo_param.warp_k}, | |||||
| m_algo_param.stage, stream); | |||||
| } | |||||
| bool nonunity_kernel = !(fh == 1 && fw == 1); | |||||
| #define DISPATCH(_nonunity_kernel) \ | |||||
| if (nonunity_kernel == _nonunity_kernel) { \ | |||||
| cb(_nonunity_kernel) \ | |||||
| } | |||||
| if (param.format == Format::NCHW4) { | |||||
| #define cb(_nonunity_kernel) \ | |||||
| cutlass_wrapper::do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4< \ | |||||
| _nonunity_kernel>( \ | |||||
| args.src_tensor->compatible_ptr<int8_t>(), filter_ptr, \ | |||||
| args.bias_tensor->compatible_ptr<int32_t>(), \ | |||||
| args.z_tensor->compatible_ptr<int8_t>(), \ | |||||
| args.dst_tensor->compatible_ptr<int8_t>(), nullptr, kern_param, \ | |||||
| nonlinear_mode, alpha, beta, gamma, dst_scale, \ | |||||
| cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, \ | |||||
| m_algo_param.threadblock_n, \ | |||||
| m_algo_param.threadblock_k}, \ | |||||
| cutlass_wrapper::GemmCoord{m_algo_param.warp_m, \ | |||||
| m_algo_param.warp_n, \ | |||||
| m_algo_param.warp_k}, \ | |||||
| m_algo_param.stage, stream); | |||||
| DISPATCH(true); | |||||
| DISPATCH(false); | |||||
| #undef cb | |||||
| } else if (param.format == Format::NCHW4_NCHW) { | |||||
| #define cb(_nonunity_kernel) \ | |||||
| cutlass_wrapper::do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nchw< \ | |||||
| _nonunity_kernel>( \ | |||||
| args.src_tensor->compatible_ptr<int8_t>(), filter_ptr, \ | |||||
| args.bias_tensor->compatible_ptr<float>(), \ | |||||
| args.z_tensor->compatible_ptr<float>(), \ | |||||
| args.dst_tensor->compatible_ptr<float>(), nullptr, kern_param, \ | |||||
| nonlinear_mode, alpha, beta, gamma, dst_scale, \ | |||||
| cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, \ | |||||
| m_algo_param.threadblock_n, \ | |||||
| m_algo_param.threadblock_k}, \ | |||||
| cutlass_wrapper::GemmCoord{m_algo_param.warp_m, \ | |||||
| m_algo_param.warp_n, \ | |||||
| m_algo_param.warp_k}, \ | |||||
| m_algo_param.stage, stream); | |||||
| DISPATCH(true); | |||||
| DISPATCH(false); | |||||
| #undef cb | |||||
| } else if (param.format == Format::NCHW4_NHWC) { | |||||
| #define cb(_nonunity_kernel) \ | |||||
| cutlass_wrapper::do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nhwc< \ | |||||
| _nonunity_kernel>( \ | |||||
| args.src_tensor->compatible_ptr<int8_t>(), filter_ptr, \ | |||||
| args.bias_tensor->compatible_ptr<int32_t>(), \ | |||||
| reinterpret_cast<int8_t*>(args.z_tensor->raw_ptr), \ | |||||
| reinterpret_cast<int8_t*>(args.dst_tensor->raw_ptr), nullptr, \ | |||||
| kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale, \ | |||||
| cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, \ | |||||
| m_algo_param.threadblock_n, \ | |||||
| m_algo_param.threadblock_k}, \ | |||||
| cutlass_wrapper::GemmCoord{m_algo_param.warp_m, \ | |||||
| m_algo_param.warp_n, \ | |||||
| m_algo_param.warp_k}, \ | |||||
| m_algo_param.stage, stream); | |||||
| cb(true); | |||||
| #undef cb | |||||
| } else { | } else { | ||||
| if (param.format == Format::NCHW4) { | |||||
| cutlass_wrapper::do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4< | |||||
| true>( | |||||
| args.src_tensor->compatible_ptr<int8_t>(), filter_ptr, | |||||
| args.bias_tensor->compatible_ptr<int32_t>(), | |||||
| args.z_tensor->compatible_ptr<int8_t>(), | |||||
| args.dst_tensor->compatible_ptr<int8_t>(), nullptr, | |||||
| kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale, | |||||
| cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, | |||||
| m_algo_param.threadblock_n, | |||||
| m_algo_param.threadblock_k}, | |||||
| cutlass_wrapper::GemmCoord{m_algo_param.warp_m, | |||||
| m_algo_param.warp_n, | |||||
| m_algo_param.warp_k}, | |||||
| megdnn_assert(param.format == Format::NCHW4_NCHW32); | |||||
| #define cb(_nonunity_kernel) \ | |||||
| cutlass_wrapper:: \ | |||||
| do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32< \ | |||||
| _nonunity_kernel>( \ | |||||
| args.src_tensor->compatible_ptr<int8_t>(), filter_ptr, \ | |||||
| args.bias_tensor->compatible_ptr<int32_t>(), \ | |||||
| args.z_tensor->compatible_ptr<int8_t>(), \ | |||||
| args.dst_tensor->compatible_ptr<int8_t>(), nullptr, \ | |||||
| kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale, \ | |||||
| cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, \ | |||||
| m_algo_param.threadblock_n, \ | |||||
| m_algo_param.threadblock_k}, \ | |||||
| cutlass_wrapper::GemmCoord{m_algo_param.warp_m, \ | |||||
| m_algo_param.warp_n, \ | |||||
| m_algo_param.warp_k}, \ | |||||
| m_algo_param.stage, stream); | m_algo_param.stage, stream); | ||||
| } else if (param.format == Format::NCHW4_NCHW) { | |||||
| cutlass_wrapper:: | |||||
| do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nchw<true>( | |||||
| args.src_tensor->compatible_ptr<int8_t>(), | |||||
| filter_ptr, | |||||
| args.bias_tensor->compatible_ptr<float>(), | |||||
| args.z_tensor->compatible_ptr<float>(), | |||||
| args.dst_tensor->compatible_ptr<float>(), nullptr, | |||||
| kern_param, nonlinear_mode, alpha, beta, gamma, | |||||
| dst_scale, | |||||
| cutlass_wrapper::GemmCoord{ | |||||
| m_algo_param.threadblock_m, | |||||
| m_algo_param.threadblock_n, | |||||
| m_algo_param.threadblock_k}, | |||||
| cutlass_wrapper::GemmCoord{m_algo_param.warp_m, | |||||
| m_algo_param.warp_n, | |||||
| m_algo_param.warp_k}, | |||||
| m_algo_param.stage, stream); | |||||
| } else { | |||||
| megdnn_assert(param.format == Format::NCHW4_NCHW32); | |||||
| cutlass_wrapper:: | |||||
| do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32< | |||||
| true>( | |||||
| args.src_tensor->compatible_ptr<int8_t>(), | |||||
| filter_ptr, | |||||
| args.bias_tensor->compatible_ptr<int32_t>(), | |||||
| args.z_tensor->compatible_ptr<int8_t>(), | |||||
| args.dst_tensor->compatible_ptr<int8_t>(), nullptr, | |||||
| kern_param, nonlinear_mode, alpha, beta, gamma, | |||||
| dst_scale, | |||||
| cutlass_wrapper::GemmCoord{ | |||||
| m_algo_param.threadblock_m, | |||||
| m_algo_param.threadblock_n, | |||||
| m_algo_param.threadblock_k}, | |||||
| cutlass_wrapper::GemmCoord{m_algo_param.warp_m, | |||||
| m_algo_param.warp_n, | |||||
| m_algo_param.warp_k}, | |||||
| m_algo_param.stage, stream); | |||||
| } | |||||
| DISPATCH(true); | |||||
| DISPATCH(false); | |||||
| #undef cb | |||||
| #undef DISPATCH | |||||
| } | } | ||||
| after_kernel_launch(); | after_kernel_launch(); | ||||
| } | } | ||||
| @@ -315,17 +307,23 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec_preprocess( | |||||
| ci = args.src_layout->operator[](1) * 4, | ci = args.src_layout->operator[](1) * 4, | ||||
| hi = args.src_layout->operator[](2), | hi = args.src_layout->operator[](2), | ||||
| wi = args.src_layout->operator[](3); | wi = args.src_layout->operator[](3); | ||||
| size_t ho = args.dst_layout->operator[](2), | |||||
| wo = args.dst_layout->operator[](3); | |||||
| size_t co; | |||||
| size_t co, dst_spatial_pos; | |||||
| if (param.format == Format::NCHW4) { | if (param.format == Format::NCHW4) { | ||||
| co = args.dst_layout->operator[](1) * 4; | co = args.dst_layout->operator[](1) * 4; | ||||
| dst_spatial_pos = 2; | |||||
| } else if (param.format == Format::NCHW4_NCHW) { | } else if (param.format == Format::NCHW4_NCHW) { | ||||
| co = args.dst_layout->operator[](1); | co = args.dst_layout->operator[](1); | ||||
| dst_spatial_pos = 2; | |||||
| } else if (param.format == Format::NCHW4_NHWC) { | |||||
| co = args.dst_layout->operator[](3); | |||||
| dst_spatial_pos = 1; | |||||
| } else { | } else { | ||||
| megdnn_assert(param.format == Format::NCHW4_NCHW32); | megdnn_assert(param.format == Format::NCHW4_NCHW32); | ||||
| dst_spatial_pos = 2; | |||||
| co = args.dst_layout->operator[](1) * 32; | co = args.dst_layout->operator[](1) * 32; | ||||
| } | } | ||||
| size_t ho = args.dst_layout->operator[](dst_spatial_pos), | |||||
| wo = args.dst_layout->operator[](dst_spatial_pos + 1); | |||||
| UNPACK_CONV_PARAMETER(fm, param); | UNPACK_CONV_PARAMETER(fm, param); | ||||
| MARK_USED_VAR | MARK_USED_VAR | ||||
| TensorLayout src{{co, ci / 4 * fh * fw}, dtype::Int32()}; | TensorLayout src{{co, ci / 4 * fh * fw}, dtype::Int32()}; | ||||
| @@ -1,65 +0,0 @@ | |||||
| /** | |||||
| * \file | |||||
| * dnn/src/cuda/conv_bias/int8/conv_bias_int8_implicit_gemm_cutlass_wrapper.cuinl | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "cutlass/convolution/device/convolution.h" | |||||
| #include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | |||||
| using namespace megdnn; | |||||
| using namespace cuda; | |||||
| using namespace cutlass_wrapper; | |||||
| template <typename Convolution> | |||||
| void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper( | |||||
| const typename Convolution::ElementSrc* d_src, | |||||
| const typename Convolution::ElementFilter* d_filter, | |||||
| const typename Convolution::ElementBias* d_bias, | |||||
| const typename Convolution::ElementDst* d_z, | |||||
| typename Convolution::ElementDst* d_dst, int* workspace, | |||||
| typename Convolution::ConvolutionParameter const& conv_param, | |||||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||||
| cudaStream_t stream, typename Convolution::ExtraParam extra_param) { | |||||
| typename Convolution::TensorRefSrc tensor_src{ | |||||
| const_cast<typename Convolution::ElementSrc*>(d_src), | |||||
| Convolution::LayoutSrc::packed( | |||||
| {conv_param.N, conv_param.H, conv_param.W, conv_param.C})}; | |||||
| typename Convolution::TensorRefFilter tensor_filter{ | |||||
| const_cast<typename Convolution::ElementFilter*>(d_filter), | |||||
| Convolution::LayoutFilter::packed( | |||||
| {conv_param.K, conv_param.R, conv_param.S, conv_param.C})}; | |||||
| typename Convolution::TensorRefBias tensor_bias{ | |||||
| const_cast<typename Convolution::ElementBias*>(d_bias), | |||||
| Convolution::LayoutBias::packed({1, 1, 1, conv_param.K})}; | |||||
| typename Convolution::TensorRefDst tensor_z{ | |||||
| const_cast<typename Convolution::ElementDst*>(d_z), | |||||
| Convolution::LayoutDst::packed( | |||||
| {conv_param.N, conv_param.P, conv_param.Q, conv_param.K})}; | |||||
| typename Convolution::TensorRefDst tensor_dst{ | |||||
| d_dst, | |||||
| Convolution::LayoutDst::packed( | |||||
| {conv_param.N, conv_param.P, conv_param.Q, conv_param.K})}; | |||||
| typename Convolution::Arguments arguments{conv_param, | |||||
| tensor_src.non_const_ref(), | |||||
| tensor_filter.non_const_ref(), | |||||
| tensor_bias.non_const_ref(), | |||||
| tensor_z.non_const_ref(), | |||||
| tensor_dst.non_const_ref(), | |||||
| epilogue, | |||||
| {}, | |||||
| {}, | |||||
| extra_param}; | |||||
| Convolution conv_op; | |||||
| cutlass_check(conv_op.initialize(arguments, workspace)); | |||||
| cutlass_check(conv_op(stream)); | |||||
| after_kernel_launch(); | |||||
| } | |||||
| // vim: syntax=cuda.doxygen | |||||
| @@ -0,0 +1 @@ | |||||
| ../implicit_gemm_conv_bias_cutlass_wrapper.cuinl | |||||
| @@ -0,0 +1,54 @@ | |||||
| #if !MEGDNN_TEGRA_X1 | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/conv_bias/implicit_gemm_conv_bias_cutlass_wrapper.cuinl" | |||||
| // kernel instance "cutlass_simt_s4_ifprop_hswish_s8_128x128x32_64x32x32_2_nc4hw4_nhwc" generated by cutlass generator | |||||
| using Convolution = | |||||
| typename cutlass::conv::device::Convolution< | |||||
| int8_t, | |||||
| cutlass::layout::TensorNCxHWx<4>, | |||||
| int8_t, | |||||
| cutlass::layout::TensorCxRSKx<4>, | |||||
| cutlass::int4b_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::conv::ConvType::kConvolution, | |||||
| cutlass::arch::OpClassSimt, | |||||
| cutlass::arch::Sm75, | |||||
| cutlass::gemm::GemmShape<128, 128, 32>, | |||||
| cutlass::gemm::GemmShape<64, 32, 32>, | |||||
| cutlass::gemm::GemmShape<1, 1, 4>, | |||||
| cutlass::epilogue::thread::BiasAddLinearCombinationHSwishClamp< | |||||
| cutlass::int4b_t, | |||||
| 8, | |||||
| int32_t, | |||||
| int32_t, | |||||
| float | |||||
| >, | |||||
| cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle, | |||||
| 2, | |||||
| 4, | |||||
| 16, | |||||
| true, | |||||
| cutlass::arch::OpMultiplyAddSaturate>; | |||||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||||
| const typename Convolution::ElementSrc* d_src, | |||||
| const typename Convolution::ElementFilter* d_filter, | |||||
| const typename Convolution::ElementBias* d_bias, | |||||
| const typename Convolution::ElementDst* d_z, | |||||
| typename Convolution::ElementDst* d_dst, | |||||
| int* workspace, | |||||
| typename Convolution::ConvolutionParameter const& conv_param, | |||||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,54 @@ | |||||
| #if !MEGDNN_TEGRA_X1 | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/conv_bias/implicit_gemm_conv_bias_cutlass_wrapper.cuinl" | |||||
| // kernel instance "cutlass_simt_s4_ifprop_hswish_s8_128x32x32_64x32x32_2_nc4hw4_nhwc" generated by cutlass generator | |||||
| using Convolution = | |||||
| typename cutlass::conv::device::Convolution< | |||||
| int8_t, | |||||
| cutlass::layout::TensorNCxHWx<4>, | |||||
| int8_t, | |||||
| cutlass::layout::TensorCxRSKx<4>, | |||||
| cutlass::int4b_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::conv::ConvType::kConvolution, | |||||
| cutlass::arch::OpClassSimt, | |||||
| cutlass::arch::Sm75, | |||||
| cutlass::gemm::GemmShape<128, 32, 32>, | |||||
| cutlass::gemm::GemmShape<64, 32, 32>, | |||||
| cutlass::gemm::GemmShape<1, 1, 4>, | |||||
| cutlass::epilogue::thread::BiasAddLinearCombinationHSwishClamp< | |||||
| cutlass::int4b_t, | |||||
| 8, | |||||
| int32_t, | |||||
| int32_t, | |||||
| float | |||||
| >, | |||||
| cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle, | |||||
| 2, | |||||
| 4, | |||||
| 16, | |||||
| true, | |||||
| cutlass::arch::OpMultiplyAddSaturate>; | |||||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||||
| const typename Convolution::ElementSrc* d_src, | |||||
| const typename Convolution::ElementFilter* d_filter, | |||||
| const typename Convolution::ElementBias* d_bias, | |||||
| const typename Convolution::ElementDst* d_z, | |||||
| typename Convolution::ElementDst* d_dst, | |||||
| int* workspace, | |||||
| typename Convolution::ConvolutionParameter const& conv_param, | |||||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,54 @@ | |||||
| #if !MEGDNN_TEGRA_X1 | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/conv_bias/implicit_gemm_conv_bias_cutlass_wrapper.cuinl" | |||||
| // kernel instance "cutlass_simt_s4_ifprop_hswish_s8_128x64x32_64x32x32_2_nc4hw4_nhwc" generated by cutlass generator | |||||
| using Convolution = | |||||
| typename cutlass::conv::device::Convolution< | |||||
| int8_t, | |||||
| cutlass::layout::TensorNCxHWx<4>, | |||||
| int8_t, | |||||
| cutlass::layout::TensorCxRSKx<4>, | |||||
| cutlass::int4b_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::conv::ConvType::kConvolution, | |||||
| cutlass::arch::OpClassSimt, | |||||
| cutlass::arch::Sm75, | |||||
| cutlass::gemm::GemmShape<128, 64, 32>, | |||||
| cutlass::gemm::GemmShape<64, 32, 32>, | |||||
| cutlass::gemm::GemmShape<1, 1, 4>, | |||||
| cutlass::epilogue::thread::BiasAddLinearCombinationHSwishClamp< | |||||
| cutlass::int4b_t, | |||||
| 8, | |||||
| int32_t, | |||||
| int32_t, | |||||
| float | |||||
| >, | |||||
| cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle, | |||||
| 2, | |||||
| 4, | |||||
| 16, | |||||
| true, | |||||
| cutlass::arch::OpMultiplyAddSaturate>; | |||||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||||
| const typename Convolution::ElementSrc* d_src, | |||||
| const typename Convolution::ElementFilter* d_filter, | |||||
| const typename Convolution::ElementBias* d_bias, | |||||
| const typename Convolution::ElementDst* d_z, | |||||
| typename Convolution::ElementDst* d_dst, | |||||
| int* workspace, | |||||
| typename Convolution::ConvolutionParameter const& conv_param, | |||||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,54 @@ | |||||
| #if !MEGDNN_TEGRA_X1 | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/conv_bias/implicit_gemm_conv_bias_cutlass_wrapper.cuinl" | |||||
| // kernel instance "cutlass_simt_s4_ifprop_hswish_s8_16x128x16_16x128x16_1_nc4hw4_nhwc" generated by cutlass generator | |||||
| using Convolution = | |||||
| typename cutlass::conv::device::Convolution< | |||||
| int8_t, | |||||
| cutlass::layout::TensorNCxHWx<4>, | |||||
| int8_t, | |||||
| cutlass::layout::TensorCxRSKx<4>, | |||||
| cutlass::int4b_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::conv::ConvType::kConvolution, | |||||
| cutlass::arch::OpClassSimt, | |||||
| cutlass::arch::Sm75, | |||||
| cutlass::gemm::GemmShape<16, 128, 16>, | |||||
| cutlass::gemm::GemmShape<16, 128, 16>, | |||||
| cutlass::gemm::GemmShape<1, 1, 4>, | |||||
| cutlass::epilogue::thread::BiasAddLinearCombinationHSwishClamp< | |||||
| cutlass::int4b_t, | |||||
| 8, | |||||
| int32_t, | |||||
| int32_t, | |||||
| float | |||||
| >, | |||||
| cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle, | |||||
| 1, | |||||
| 4, | |||||
| 8, | |||||
| true, | |||||
| cutlass::arch::OpMultiplyAddSaturate>; | |||||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||||
| const typename Convolution::ElementSrc* d_src, | |||||
| const typename Convolution::ElementFilter* d_filter, | |||||
| const typename Convolution::ElementBias* d_bias, | |||||
| const typename Convolution::ElementDst* d_z, | |||||
| typename Convolution::ElementDst* d_dst, | |||||
| int* workspace, | |||||
| typename Convolution::ConvolutionParameter const& conv_param, | |||||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,54 @@ | |||||
| #if !MEGDNN_TEGRA_X1 | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/conv_bias/implicit_gemm_conv_bias_cutlass_wrapper.cuinl" | |||||
| // kernel instance "cutlass_simt_s4_ifprop_hswish_s8_16x64x8_16x64x8_2_nc4hw4_nhwc" generated by cutlass generator | |||||
| using Convolution = | |||||
| typename cutlass::conv::device::Convolution< | |||||
| int8_t, | |||||
| cutlass::layout::TensorNCxHWx<4>, | |||||
| int8_t, | |||||
| cutlass::layout::TensorCxRSKx<4>, | |||||
| cutlass::int4b_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::conv::ConvType::kConvolution, | |||||
| cutlass::arch::OpClassSimt, | |||||
| cutlass::arch::Sm75, | |||||
| cutlass::gemm::GemmShape<16, 64, 8>, | |||||
| cutlass::gemm::GemmShape<16, 64, 8>, | |||||
| cutlass::gemm::GemmShape<1, 1, 4>, | |||||
| cutlass::epilogue::thread::BiasAddLinearCombinationHSwishClamp< | |||||
| cutlass::int4b_t, | |||||
| 8, | |||||
| int32_t, | |||||
| int32_t, | |||||
| float | |||||
| >, | |||||
| cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle, | |||||
| 2, | |||||
| 4, | |||||
| 4, | |||||
| true, | |||||
| cutlass::arch::OpMultiplyAddSaturate>; | |||||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||||
| const typename Convolution::ElementSrc* d_src, | |||||
| const typename Convolution::ElementFilter* d_filter, | |||||
| const typename Convolution::ElementBias* d_bias, | |||||
| const typename Convolution::ElementDst* d_z, | |||||
| typename Convolution::ElementDst* d_dst, | |||||
| int* workspace, | |||||
| typename Convolution::ConvolutionParameter const& conv_param, | |||||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,54 @@ | |||||
| #if !MEGDNN_TEGRA_X1 | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/conv_bias/implicit_gemm_conv_bias_cutlass_wrapper.cuinl" | |||||
| // kernel instance "cutlass_simt_s4_ifprop_hswish_s8_32x128x32_32x64x32_2_nc4hw4_nhwc" generated by cutlass generator | |||||
| using Convolution = | |||||
| typename cutlass::conv::device::Convolution< | |||||
| int8_t, | |||||
| cutlass::layout::TensorNCxHWx<4>, | |||||
| int8_t, | |||||
| cutlass::layout::TensorCxRSKx<4>, | |||||
| cutlass::int4b_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::conv::ConvType::kConvolution, | |||||
| cutlass::arch::OpClassSimt, | |||||
| cutlass::arch::Sm75, | |||||
| cutlass::gemm::GemmShape<32, 128, 32>, | |||||
| cutlass::gemm::GemmShape<32, 64, 32>, | |||||
| cutlass::gemm::GemmShape<1, 1, 4>, | |||||
| cutlass::epilogue::thread::BiasAddLinearCombinationHSwishClamp< | |||||
| cutlass::int4b_t, | |||||
| 8, | |||||
| int32_t, | |||||
| int32_t, | |||||
| float | |||||
| >, | |||||
| cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle, | |||||
| 2, | |||||
| 4, | |||||
| 16, | |||||
| true, | |||||
| cutlass::arch::OpMultiplyAddSaturate>; | |||||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||||
| const typename Convolution::ElementSrc* d_src, | |||||
| const typename Convolution::ElementFilter* d_filter, | |||||
| const typename Convolution::ElementBias* d_bias, | |||||
| const typename Convolution::ElementDst* d_z, | |||||
| typename Convolution::ElementDst* d_dst, | |||||
| int* workspace, | |||||
| typename Convolution::ConvolutionParameter const& conv_param, | |||||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,54 @@ | |||||
| #if !MEGDNN_TEGRA_X1 | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/conv_bias/implicit_gemm_conv_bias_cutlass_wrapper.cuinl" | |||||
| // kernel instance "cutlass_simt_s4_ifprop_hswish_s8_32x32x32_32x32x32_2_nc4hw4_nhwc" generated by cutlass generator | |||||
| using Convolution = | |||||
| typename cutlass::conv::device::Convolution< | |||||
| int8_t, | |||||
| cutlass::layout::TensorNCxHWx<4>, | |||||
| int8_t, | |||||
| cutlass::layout::TensorCxRSKx<4>, | |||||
| cutlass::int4b_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::conv::ConvType::kConvolution, | |||||
| cutlass::arch::OpClassSimt, | |||||
| cutlass::arch::Sm75, | |||||
| cutlass::gemm::GemmShape<32, 32, 32>, | |||||
| cutlass::gemm::GemmShape<32, 32, 32>, | |||||
| cutlass::gemm::GemmShape<1, 1, 4>, | |||||
| cutlass::epilogue::thread::BiasAddLinearCombinationHSwishClamp< | |||||
| cutlass::int4b_t, | |||||
| 8, | |||||
| int32_t, | |||||
| int32_t, | |||||
| float | |||||
| >, | |||||
| cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle, | |||||
| 2, | |||||
| 4, | |||||
| 16, | |||||
| true, | |||||
| cutlass::arch::OpMultiplyAddSaturate>; | |||||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||||
| const typename Convolution::ElementSrc* d_src, | |||||
| const typename Convolution::ElementFilter* d_filter, | |||||
| const typename Convolution::ElementBias* d_bias, | |||||
| const typename Convolution::ElementDst* d_z, | |||||
| typename Convolution::ElementDst* d_dst, | |||||
| int* workspace, | |||||
| typename Convolution::ConvolutionParameter const& conv_param, | |||||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,54 @@ | |||||
| #if !MEGDNN_TEGRA_X1 | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/conv_bias/implicit_gemm_conv_bias_cutlass_wrapper.cuinl" | |||||
| // kernel instance "cutlass_simt_s4_ifprop_hswish_s8_32x64x32_32x64x32_2_nc4hw4_nhwc" generated by cutlass generator | |||||
| using Convolution = | |||||
| typename cutlass::conv::device::Convolution< | |||||
| int8_t, | |||||
| cutlass::layout::TensorNCxHWx<4>, | |||||
| int8_t, | |||||
| cutlass::layout::TensorCxRSKx<4>, | |||||
| cutlass::int4b_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::conv::ConvType::kConvolution, | |||||
| cutlass::arch::OpClassSimt, | |||||
| cutlass::arch::Sm75, | |||||
| cutlass::gemm::GemmShape<32, 64, 32>, | |||||
| cutlass::gemm::GemmShape<32, 64, 32>, | |||||
| cutlass::gemm::GemmShape<1, 1, 4>, | |||||
| cutlass::epilogue::thread::BiasAddLinearCombinationHSwishClamp< | |||||
| cutlass::int4b_t, | |||||
| 8, | |||||
| int32_t, | |||||
| int32_t, | |||||
| float | |||||
| >, | |||||
| cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle, | |||||
| 2, | |||||
| 4, | |||||
| 16, | |||||
| true, | |||||
| cutlass::arch::OpMultiplyAddSaturate>; | |||||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||||
| const typename Convolution::ElementSrc* d_src, | |||||
| const typename Convolution::ElementFilter* d_filter, | |||||
| const typename Convolution::ElementBias* d_bias, | |||||
| const typename Convolution::ElementDst* d_z, | |||||
| typename Convolution::ElementDst* d_dst, | |||||
| int* workspace, | |||||
| typename Convolution::ConvolutionParameter const& conv_param, | |||||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,54 @@ | |||||
| #if !MEGDNN_TEGRA_X1 | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/conv_bias/implicit_gemm_conv_bias_cutlass_wrapper.cuinl" | |||||
| // kernel instance "cutlass_simt_s4_ifprop_hswish_s8_64x128x32_64x32x32_2_nc4hw4_nhwc" generated by cutlass generator | |||||
| using Convolution = | |||||
| typename cutlass::conv::device::Convolution< | |||||
| int8_t, | |||||
| cutlass::layout::TensorNCxHWx<4>, | |||||
| int8_t, | |||||
| cutlass::layout::TensorCxRSKx<4>, | |||||
| cutlass::int4b_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::conv::ConvType::kConvolution, | |||||
| cutlass::arch::OpClassSimt, | |||||
| cutlass::arch::Sm75, | |||||
| cutlass::gemm::GemmShape<64, 128, 32>, | |||||
| cutlass::gemm::GemmShape<64, 32, 32>, | |||||
| cutlass::gemm::GemmShape<1, 1, 4>, | |||||
| cutlass::epilogue::thread::BiasAddLinearCombinationHSwishClamp< | |||||
| cutlass::int4b_t, | |||||
| 8, | |||||
| int32_t, | |||||
| int32_t, | |||||
| float | |||||
| >, | |||||
| cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle, | |||||
| 2, | |||||
| 4, | |||||
| 16, | |||||
| true, | |||||
| cutlass::arch::OpMultiplyAddSaturate>; | |||||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||||
| const typename Convolution::ElementSrc* d_src, | |||||
| const typename Convolution::ElementFilter* d_filter, | |||||
| const typename Convolution::ElementBias* d_bias, | |||||
| const typename Convolution::ElementDst* d_z, | |||||
| typename Convolution::ElementDst* d_dst, | |||||
| int* workspace, | |||||
| typename Convolution::ConvolutionParameter const& conv_param, | |||||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,54 @@ | |||||
| #if !MEGDNN_TEGRA_X1 | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/conv_bias/implicit_gemm_conv_bias_cutlass_wrapper.cuinl" | |||||
| // kernel instance "cutlass_simt_s4_ifprop_hswish_s8_64x32x32_64x32x32_2_nc4hw4_nhwc" generated by cutlass generator | |||||
| using Convolution = | |||||
| typename cutlass::conv::device::Convolution< | |||||
| int8_t, | |||||
| cutlass::layout::TensorNCxHWx<4>, | |||||
| int8_t, | |||||
| cutlass::layout::TensorCxRSKx<4>, | |||||
| cutlass::int4b_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::conv::ConvType::kConvolution, | |||||
| cutlass::arch::OpClassSimt, | |||||
| cutlass::arch::Sm75, | |||||
| cutlass::gemm::GemmShape<64, 32, 32>, | |||||
| cutlass::gemm::GemmShape<64, 32, 32>, | |||||
| cutlass::gemm::GemmShape<1, 1, 4>, | |||||
| cutlass::epilogue::thread::BiasAddLinearCombinationHSwishClamp< | |||||
| cutlass::int4b_t, | |||||
| 8, | |||||
| int32_t, | |||||
| int32_t, | |||||
| float | |||||
| >, | |||||
| cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle, | |||||
| 2, | |||||
| 4, | |||||
| 16, | |||||
| true, | |||||
| cutlass::arch::OpMultiplyAddSaturate>; | |||||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||||
| const typename Convolution::ElementSrc* d_src, | |||||
| const typename Convolution::ElementFilter* d_filter, | |||||
| const typename Convolution::ElementBias* d_bias, | |||||
| const typename Convolution::ElementDst* d_z, | |||||
| typename Convolution::ElementDst* d_dst, | |||||
| int* workspace, | |||||
| typename Convolution::ConvolutionParameter const& conv_param, | |||||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,54 @@ | |||||
| #if !MEGDNN_TEGRA_X1 | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/conv_bias/implicit_gemm_conv_bias_cutlass_wrapper.cuinl" | |||||
| // kernel instance "cutlass_simt_s4_ifprop_hswish_s8_64x64x32_64x32x32_2_nc4hw4_nhwc" generated by cutlass generator | |||||
| using Convolution = | |||||
| typename cutlass::conv::device::Convolution< | |||||
| int8_t, | |||||
| cutlass::layout::TensorNCxHWx<4>, | |||||
| int8_t, | |||||
| cutlass::layout::TensorCxRSKx<4>, | |||||
| cutlass::int4b_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::conv::ConvType::kConvolution, | |||||
| cutlass::arch::OpClassSimt, | |||||
| cutlass::arch::Sm75, | |||||
| cutlass::gemm::GemmShape<64, 64, 32>, | |||||
| cutlass::gemm::GemmShape<64, 32, 32>, | |||||
| cutlass::gemm::GemmShape<1, 1, 4>, | |||||
| cutlass::epilogue::thread::BiasAddLinearCombinationHSwishClamp< | |||||
| cutlass::int4b_t, | |||||
| 8, | |||||
| int32_t, | |||||
| int32_t, | |||||
| float | |||||
| >, | |||||
| cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle, | |||||
| 2, | |||||
| 4, | |||||
| 16, | |||||
| true, | |||||
| cutlass::arch::OpMultiplyAddSaturate>; | |||||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||||
| const typename Convolution::ElementSrc* d_src, | |||||
| const typename Convolution::ElementFilter* d_filter, | |||||
| const typename Convolution::ElementBias* d_bias, | |||||
| const typename Convolution::ElementDst* d_z, | |||||
| typename Convolution::ElementDst* d_dst, | |||||
| int* workspace, | |||||
| typename Convolution::ConvolutionParameter const& conv_param, | |||||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,54 @@ | |||||
| #if !MEGDNN_TEGRA_X1 | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/conv_bias/implicit_gemm_conv_bias_cutlass_wrapper.cuinl" | |||||
| // kernel instance "cutlass_simt_s4_ifprop_identity_s8_128x128x32_64x32x32_2_nc4hw4_nhwc" generated by cutlass generator | |||||
| using Convolution = | |||||
| typename cutlass::conv::device::Convolution< | |||||
| int8_t, | |||||
| cutlass::layout::TensorNCxHWx<4>, | |||||
| int8_t, | |||||
| cutlass::layout::TensorCxRSKx<4>, | |||||
| cutlass::int4b_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::conv::ConvType::kConvolution, | |||||
| cutlass::arch::OpClassSimt, | |||||
| cutlass::arch::Sm75, | |||||
| cutlass::gemm::GemmShape<128, 128, 32>, | |||||
| cutlass::gemm::GemmShape<64, 32, 32>, | |||||
| cutlass::gemm::GemmShape<1, 1, 4>, | |||||
| cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||||
| cutlass::int4b_t, | |||||
| 8, | |||||
| int32_t, | |||||
| int32_t, | |||||
| float | |||||
| >, | |||||
| cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle, | |||||
| 2, | |||||
| 4, | |||||
| 16, | |||||
| true, | |||||
| cutlass::arch::OpMultiplyAddSaturate>; | |||||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||||
| const typename Convolution::ElementSrc* d_src, | |||||
| const typename Convolution::ElementFilter* d_filter, | |||||
| const typename Convolution::ElementBias* d_bias, | |||||
| const typename Convolution::ElementDst* d_z, | |||||
| typename Convolution::ElementDst* d_dst, | |||||
| int* workspace, | |||||
| typename Convolution::ConvolutionParameter const& conv_param, | |||||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,54 @@ | |||||
| #if !MEGDNN_TEGRA_X1 | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/conv_bias/implicit_gemm_conv_bias_cutlass_wrapper.cuinl" | |||||
| // kernel instance "cutlass_simt_s4_ifprop_identity_s8_128x32x32_64x32x32_2_nc4hw4_nhwc" generated by cutlass generator | |||||
| using Convolution = | |||||
| typename cutlass::conv::device::Convolution< | |||||
| int8_t, | |||||
| cutlass::layout::TensorNCxHWx<4>, | |||||
| int8_t, | |||||
| cutlass::layout::TensorCxRSKx<4>, | |||||
| cutlass::int4b_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::conv::ConvType::kConvolution, | |||||
| cutlass::arch::OpClassSimt, | |||||
| cutlass::arch::Sm75, | |||||
| cutlass::gemm::GemmShape<128, 32, 32>, | |||||
| cutlass::gemm::GemmShape<64, 32, 32>, | |||||
| cutlass::gemm::GemmShape<1, 1, 4>, | |||||
| cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||||
| cutlass::int4b_t, | |||||
| 8, | |||||
| int32_t, | |||||
| int32_t, | |||||
| float | |||||
| >, | |||||
| cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle, | |||||
| 2, | |||||
| 4, | |||||
| 16, | |||||
| true, | |||||
| cutlass::arch::OpMultiplyAddSaturate>; | |||||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||||
| const typename Convolution::ElementSrc* d_src, | |||||
| const typename Convolution::ElementFilter* d_filter, | |||||
| const typename Convolution::ElementBias* d_bias, | |||||
| const typename Convolution::ElementDst* d_z, | |||||
| typename Convolution::ElementDst* d_dst, | |||||
| int* workspace, | |||||
| typename Convolution::ConvolutionParameter const& conv_param, | |||||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,54 @@ | |||||
| #if !MEGDNN_TEGRA_X1 | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/conv_bias/implicit_gemm_conv_bias_cutlass_wrapper.cuinl" | |||||
| // kernel instance "cutlass_simt_s4_ifprop_identity_s8_128x64x32_64x32x32_2_nc4hw4_nhwc" generated by cutlass generator | |||||
| using Convolution = | |||||
| typename cutlass::conv::device::Convolution< | |||||
| int8_t, | |||||
| cutlass::layout::TensorNCxHWx<4>, | |||||
| int8_t, | |||||
| cutlass::layout::TensorCxRSKx<4>, | |||||
| cutlass::int4b_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::conv::ConvType::kConvolution, | |||||
| cutlass::arch::OpClassSimt, | |||||
| cutlass::arch::Sm75, | |||||
| cutlass::gemm::GemmShape<128, 64, 32>, | |||||
| cutlass::gemm::GemmShape<64, 32, 32>, | |||||
| cutlass::gemm::GemmShape<1, 1, 4>, | |||||
| cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||||
| cutlass::int4b_t, | |||||
| 8, | |||||
| int32_t, | |||||
| int32_t, | |||||
| float | |||||
| >, | |||||
| cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle, | |||||
| 2, | |||||
| 4, | |||||
| 16, | |||||
| true, | |||||
| cutlass::arch::OpMultiplyAddSaturate>; | |||||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||||
| const typename Convolution::ElementSrc* d_src, | |||||
| const typename Convolution::ElementFilter* d_filter, | |||||
| const typename Convolution::ElementBias* d_bias, | |||||
| const typename Convolution::ElementDst* d_z, | |||||
| typename Convolution::ElementDst* d_dst, | |||||
| int* workspace, | |||||
| typename Convolution::ConvolutionParameter const& conv_param, | |||||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,54 @@ | |||||
| #if !MEGDNN_TEGRA_X1 | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/conv_bias/implicit_gemm_conv_bias_cutlass_wrapper.cuinl" | |||||
| // kernel instance "cutlass_simt_s4_ifprop_identity_s8_16x128x16_16x128x16_1_nc4hw4_nhwc" generated by cutlass generator | |||||
| using Convolution = | |||||
| typename cutlass::conv::device::Convolution< | |||||
| int8_t, | |||||
| cutlass::layout::TensorNCxHWx<4>, | |||||
| int8_t, | |||||
| cutlass::layout::TensorCxRSKx<4>, | |||||
| cutlass::int4b_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::conv::ConvType::kConvolution, | |||||
| cutlass::arch::OpClassSimt, | |||||
| cutlass::arch::Sm75, | |||||
| cutlass::gemm::GemmShape<16, 128, 16>, | |||||
| cutlass::gemm::GemmShape<16, 128, 16>, | |||||
| cutlass::gemm::GemmShape<1, 1, 4>, | |||||
| cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||||
| cutlass::int4b_t, | |||||
| 8, | |||||
| int32_t, | |||||
| int32_t, | |||||
| float | |||||
| >, | |||||
| cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle, | |||||
| 1, | |||||
| 4, | |||||
| 8, | |||||
| true, | |||||
| cutlass::arch::OpMultiplyAddSaturate>; | |||||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||||
| const typename Convolution::ElementSrc* d_src, | |||||
| const typename Convolution::ElementFilter* d_filter, | |||||
| const typename Convolution::ElementBias* d_bias, | |||||
| const typename Convolution::ElementDst* d_z, | |||||
| typename Convolution::ElementDst* d_dst, | |||||
| int* workspace, | |||||
| typename Convolution::ConvolutionParameter const& conv_param, | |||||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,54 @@ | |||||
| #if !MEGDNN_TEGRA_X1 | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/conv_bias/implicit_gemm_conv_bias_cutlass_wrapper.cuinl" | |||||
| // kernel instance "cutlass_simt_s4_ifprop_identity_s8_16x64x8_16x64x8_2_nc4hw4_nhwc" generated by cutlass generator | |||||
| using Convolution = | |||||
| typename cutlass::conv::device::Convolution< | |||||
| int8_t, | |||||
| cutlass::layout::TensorNCxHWx<4>, | |||||
| int8_t, | |||||
| cutlass::layout::TensorCxRSKx<4>, | |||||
| cutlass::int4b_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::conv::ConvType::kConvolution, | |||||
| cutlass::arch::OpClassSimt, | |||||
| cutlass::arch::Sm75, | |||||
| cutlass::gemm::GemmShape<16, 64, 8>, | |||||
| cutlass::gemm::GemmShape<16, 64, 8>, | |||||
| cutlass::gemm::GemmShape<1, 1, 4>, | |||||
| cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||||
| cutlass::int4b_t, | |||||
| 8, | |||||
| int32_t, | |||||
| int32_t, | |||||
| float | |||||
| >, | |||||
| cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle, | |||||
| 2, | |||||
| 4, | |||||
| 4, | |||||
| true, | |||||
| cutlass::arch::OpMultiplyAddSaturate>; | |||||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||||
| const typename Convolution::ElementSrc* d_src, | |||||
| const typename Convolution::ElementFilter* d_filter, | |||||
| const typename Convolution::ElementBias* d_bias, | |||||
| const typename Convolution::ElementDst* d_z, | |||||
| typename Convolution::ElementDst* d_dst, | |||||
| int* workspace, | |||||
| typename Convolution::ConvolutionParameter const& conv_param, | |||||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,54 @@ | |||||
| #if !MEGDNN_TEGRA_X1 | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/conv_bias/implicit_gemm_conv_bias_cutlass_wrapper.cuinl" | |||||
| // kernel instance "cutlass_simt_s4_ifprop_identity_s8_32x128x32_32x64x32_2_nc4hw4_nhwc" generated by cutlass generator | |||||
| using Convolution = | |||||
| typename cutlass::conv::device::Convolution< | |||||
| int8_t, | |||||
| cutlass::layout::TensorNCxHWx<4>, | |||||
| int8_t, | |||||
| cutlass::layout::TensorCxRSKx<4>, | |||||
| cutlass::int4b_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::conv::ConvType::kConvolution, | |||||
| cutlass::arch::OpClassSimt, | |||||
| cutlass::arch::Sm75, | |||||
| cutlass::gemm::GemmShape<32, 128, 32>, | |||||
| cutlass::gemm::GemmShape<32, 64, 32>, | |||||
| cutlass::gemm::GemmShape<1, 1, 4>, | |||||
| cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||||
| cutlass::int4b_t, | |||||
| 8, | |||||
| int32_t, | |||||
| int32_t, | |||||
| float | |||||
| >, | |||||
| cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle, | |||||
| 2, | |||||
| 4, | |||||
| 16, | |||||
| true, | |||||
| cutlass::arch::OpMultiplyAddSaturate>; | |||||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||||
| const typename Convolution::ElementSrc* d_src, | |||||
| const typename Convolution::ElementFilter* d_filter, | |||||
| const typename Convolution::ElementBias* d_bias, | |||||
| const typename Convolution::ElementDst* d_z, | |||||
| typename Convolution::ElementDst* d_dst, | |||||
| int* workspace, | |||||
| typename Convolution::ConvolutionParameter const& conv_param, | |||||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,54 @@ | |||||
| #if !MEGDNN_TEGRA_X1 | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/conv_bias/implicit_gemm_conv_bias_cutlass_wrapper.cuinl" | |||||
| // kernel instance "cutlass_simt_s4_ifprop_identity_s8_32x32x32_32x32x32_2_nc4hw4_nhwc" generated by cutlass generator | |||||
| using Convolution = | |||||
| typename cutlass::conv::device::Convolution< | |||||
| int8_t, | |||||
| cutlass::layout::TensorNCxHWx<4>, | |||||
| int8_t, | |||||
| cutlass::layout::TensorCxRSKx<4>, | |||||
| cutlass::int4b_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::conv::ConvType::kConvolution, | |||||
| cutlass::arch::OpClassSimt, | |||||
| cutlass::arch::Sm75, | |||||
| cutlass::gemm::GemmShape<32, 32, 32>, | |||||
| cutlass::gemm::GemmShape<32, 32, 32>, | |||||
| cutlass::gemm::GemmShape<1, 1, 4>, | |||||
| cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||||
| cutlass::int4b_t, | |||||
| 8, | |||||
| int32_t, | |||||
| int32_t, | |||||
| float | |||||
| >, | |||||
| cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle, | |||||
| 2, | |||||
| 4, | |||||
| 16, | |||||
| true, | |||||
| cutlass::arch::OpMultiplyAddSaturate>; | |||||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||||
| const typename Convolution::ElementSrc* d_src, | |||||
| const typename Convolution::ElementFilter* d_filter, | |||||
| const typename Convolution::ElementBias* d_bias, | |||||
| const typename Convolution::ElementDst* d_z, | |||||
| typename Convolution::ElementDst* d_dst, | |||||
| int* workspace, | |||||
| typename Convolution::ConvolutionParameter const& conv_param, | |||||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,54 @@ | |||||
| #if !MEGDNN_TEGRA_X1 | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/conv_bias/implicit_gemm_conv_bias_cutlass_wrapper.cuinl" | |||||
| // kernel instance "cutlass_simt_s4_ifprop_identity_s8_32x64x32_32x64x32_2_nc4hw4_nhwc" generated by cutlass generator | |||||
| using Convolution = | |||||
| typename cutlass::conv::device::Convolution< | |||||
| int8_t, | |||||
| cutlass::layout::TensorNCxHWx<4>, | |||||
| int8_t, | |||||
| cutlass::layout::TensorCxRSKx<4>, | |||||
| cutlass::int4b_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::conv::ConvType::kConvolution, | |||||
| cutlass::arch::OpClassSimt, | |||||
| cutlass::arch::Sm75, | |||||
| cutlass::gemm::GemmShape<32, 64, 32>, | |||||
| cutlass::gemm::GemmShape<32, 64, 32>, | |||||
| cutlass::gemm::GemmShape<1, 1, 4>, | |||||
| cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||||
| cutlass::int4b_t, | |||||
| 8, | |||||
| int32_t, | |||||
| int32_t, | |||||
| float | |||||
| >, | |||||
| cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle, | |||||
| 2, | |||||
| 4, | |||||
| 16, | |||||
| true, | |||||
| cutlass::arch::OpMultiplyAddSaturate>; | |||||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||||
| const typename Convolution::ElementSrc* d_src, | |||||
| const typename Convolution::ElementFilter* d_filter, | |||||
| const typename Convolution::ElementBias* d_bias, | |||||
| const typename Convolution::ElementDst* d_z, | |||||
| typename Convolution::ElementDst* d_dst, | |||||
| int* workspace, | |||||
| typename Convolution::ConvolutionParameter const& conv_param, | |||||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,54 @@ | |||||
| #if !MEGDNN_TEGRA_X1 | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/conv_bias/implicit_gemm_conv_bias_cutlass_wrapper.cuinl" | |||||
| // kernel instance "cutlass_simt_s4_ifprop_identity_s8_64x128x32_64x32x32_2_nc4hw4_nhwc" generated by cutlass generator | |||||
| using Convolution = | |||||
| typename cutlass::conv::device::Convolution< | |||||
| int8_t, | |||||
| cutlass::layout::TensorNCxHWx<4>, | |||||
| int8_t, | |||||
| cutlass::layout::TensorCxRSKx<4>, | |||||
| cutlass::int4b_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::conv::ConvType::kConvolution, | |||||
| cutlass::arch::OpClassSimt, | |||||
| cutlass::arch::Sm75, | |||||
| cutlass::gemm::GemmShape<64, 128, 32>, | |||||
| cutlass::gemm::GemmShape<64, 32, 32>, | |||||
| cutlass::gemm::GemmShape<1, 1, 4>, | |||||
| cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||||
| cutlass::int4b_t, | |||||
| 8, | |||||
| int32_t, | |||||
| int32_t, | |||||
| float | |||||
| >, | |||||
| cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle, | |||||
| 2, | |||||
| 4, | |||||
| 16, | |||||
| true, | |||||
| cutlass::arch::OpMultiplyAddSaturate>; | |||||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||||
| const typename Convolution::ElementSrc* d_src, | |||||
| const typename Convolution::ElementFilter* d_filter, | |||||
| const typename Convolution::ElementBias* d_bias, | |||||
| const typename Convolution::ElementDst* d_z, | |||||
| typename Convolution::ElementDst* d_dst, | |||||
| int* workspace, | |||||
| typename Convolution::ConvolutionParameter const& conv_param, | |||||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,54 @@ | |||||
| #if !MEGDNN_TEGRA_X1 | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/conv_bias/implicit_gemm_conv_bias_cutlass_wrapper.cuinl" | |||||
| // kernel instance "cutlass_simt_s4_ifprop_identity_s8_64x32x32_64x32x32_2_nc4hw4_nhwc" generated by cutlass generator | |||||
| using Convolution = | |||||
| typename cutlass::conv::device::Convolution< | |||||
| int8_t, | |||||
| cutlass::layout::TensorNCxHWx<4>, | |||||
| int8_t, | |||||
| cutlass::layout::TensorCxRSKx<4>, | |||||
| cutlass::int4b_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::conv::ConvType::kConvolution, | |||||
| cutlass::arch::OpClassSimt, | |||||
| cutlass::arch::Sm75, | |||||
| cutlass::gemm::GemmShape<64, 32, 32>, | |||||
| cutlass::gemm::GemmShape<64, 32, 32>, | |||||
| cutlass::gemm::GemmShape<1, 1, 4>, | |||||
| cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||||
| cutlass::int4b_t, | |||||
| 8, | |||||
| int32_t, | |||||
| int32_t, | |||||
| float | |||||
| >, | |||||
| cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle, | |||||
| 2, | |||||
| 4, | |||||
| 16, | |||||
| true, | |||||
| cutlass::arch::OpMultiplyAddSaturate>; | |||||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||||
| const typename Convolution::ElementSrc* d_src, | |||||
| const typename Convolution::ElementFilter* d_filter, | |||||
| const typename Convolution::ElementBias* d_bias, | |||||
| const typename Convolution::ElementDst* d_z, | |||||
| typename Convolution::ElementDst* d_dst, | |||||
| int* workspace, | |||||
| typename Convolution::ConvolutionParameter const& conv_param, | |||||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,54 @@ | |||||
| #if !MEGDNN_TEGRA_X1 | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/conv_bias/implicit_gemm_conv_bias_cutlass_wrapper.cuinl" | |||||
| // kernel instance "cutlass_simt_s4_ifprop_identity_s8_64x64x32_64x32x32_2_nc4hw4_nhwc" generated by cutlass generator | |||||
| using Convolution = | |||||
| typename cutlass::conv::device::Convolution< | |||||
| int8_t, | |||||
| cutlass::layout::TensorNCxHWx<4>, | |||||
| int8_t, | |||||
| cutlass::layout::TensorCxRSKx<4>, | |||||
| cutlass::int4b_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::conv::ConvType::kConvolution, | |||||
| cutlass::arch::OpClassSimt, | |||||
| cutlass::arch::Sm75, | |||||
| cutlass::gemm::GemmShape<64, 64, 32>, | |||||
| cutlass::gemm::GemmShape<64, 32, 32>, | |||||
| cutlass::gemm::GemmShape<1, 1, 4>, | |||||
| cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||||
| cutlass::int4b_t, | |||||
| 8, | |||||
| int32_t, | |||||
| int32_t, | |||||
| float | |||||
| >, | |||||
| cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle, | |||||
| 2, | |||||
| 4, | |||||
| 16, | |||||
| true, | |||||
| cutlass::arch::OpMultiplyAddSaturate>; | |||||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||||
| const typename Convolution::ElementSrc* d_src, | |||||
| const typename Convolution::ElementFilter* d_filter, | |||||
| const typename Convolution::ElementBias* d_bias, | |||||
| const typename Convolution::ElementDst* d_z, | |||||
| typename Convolution::ElementDst* d_dst, | |||||
| int* workspace, | |||||
| typename Convolution::ConvolutionParameter const& conv_param, | |||||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,54 @@ | |||||
| #if !MEGDNN_TEGRA_X1 | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/conv_bias/implicit_gemm_conv_bias_cutlass_wrapper.cuinl" | |||||
| // kernel instance "cutlass_simt_s4_ifprop_relu_s8_128x128x32_64x32x32_2_nc4hw4_nhwc" generated by cutlass generator | |||||
| using Convolution = | |||||
| typename cutlass::conv::device::Convolution< | |||||
| int8_t, | |||||
| cutlass::layout::TensorNCxHWx<4>, | |||||
| int8_t, | |||||
| cutlass::layout::TensorCxRSKx<4>, | |||||
| cutlass::int4b_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::conv::ConvType::kConvolution, | |||||
| cutlass::arch::OpClassSimt, | |||||
| cutlass::arch::Sm75, | |||||
| cutlass::gemm::GemmShape<128, 128, 32>, | |||||
| cutlass::gemm::GemmShape<64, 32, 32>, | |||||
| cutlass::gemm::GemmShape<1, 1, 4>, | |||||
| cutlass::epilogue::thread::BiasAddLinearCombinationReluClamp< | |||||
| cutlass::int4b_t, | |||||
| 8, | |||||
| int32_t, | |||||
| int32_t, | |||||
| float | |||||
| >, | |||||
| cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle, | |||||
| 2, | |||||
| 4, | |||||
| 16, | |||||
| true, | |||||
| cutlass::arch::OpMultiplyAddSaturate>; | |||||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||||
| const typename Convolution::ElementSrc* d_src, | |||||
| const typename Convolution::ElementFilter* d_filter, | |||||
| const typename Convolution::ElementBias* d_bias, | |||||
| const typename Convolution::ElementDst* d_z, | |||||
| typename Convolution::ElementDst* d_dst, | |||||
| int* workspace, | |||||
| typename Convolution::ConvolutionParameter const& conv_param, | |||||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,54 @@ | |||||
| #if !MEGDNN_TEGRA_X1 | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/conv_bias/implicit_gemm_conv_bias_cutlass_wrapper.cuinl" | |||||
| // kernel instance "cutlass_simt_s4_ifprop_relu_s8_128x32x32_64x32x32_2_nc4hw4_nhwc" generated by cutlass generator | |||||
| using Convolution = | |||||
| typename cutlass::conv::device::Convolution< | |||||
| int8_t, | |||||
| cutlass::layout::TensorNCxHWx<4>, | |||||
| int8_t, | |||||
| cutlass::layout::TensorCxRSKx<4>, | |||||
| cutlass::int4b_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::conv::ConvType::kConvolution, | |||||
| cutlass::arch::OpClassSimt, | |||||
| cutlass::arch::Sm75, | |||||
| cutlass::gemm::GemmShape<128, 32, 32>, | |||||
| cutlass::gemm::GemmShape<64, 32, 32>, | |||||
| cutlass::gemm::GemmShape<1, 1, 4>, | |||||
| cutlass::epilogue::thread::BiasAddLinearCombinationReluClamp< | |||||
| cutlass::int4b_t, | |||||
| 8, | |||||
| int32_t, | |||||
| int32_t, | |||||
| float | |||||
| >, | |||||
| cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle, | |||||
| 2, | |||||
| 4, | |||||
| 16, | |||||
| true, | |||||
| cutlass::arch::OpMultiplyAddSaturate>; | |||||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||||
| const typename Convolution::ElementSrc* d_src, | |||||
| const typename Convolution::ElementFilter* d_filter, | |||||
| const typename Convolution::ElementBias* d_bias, | |||||
| const typename Convolution::ElementDst* d_z, | |||||
| typename Convolution::ElementDst* d_dst, | |||||
| int* workspace, | |||||
| typename Convolution::ConvolutionParameter const& conv_param, | |||||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,54 @@ | |||||
| #if !MEGDNN_TEGRA_X1 | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/conv_bias/implicit_gemm_conv_bias_cutlass_wrapper.cuinl" | |||||
| // kernel instance "cutlass_simt_s4_ifprop_relu_s8_128x64x32_64x32x32_2_nc4hw4_nhwc" generated by cutlass generator | |||||
| using Convolution = | |||||
| typename cutlass::conv::device::Convolution< | |||||
| int8_t, | |||||
| cutlass::layout::TensorNCxHWx<4>, | |||||
| int8_t, | |||||
| cutlass::layout::TensorCxRSKx<4>, | |||||
| cutlass::int4b_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::conv::ConvType::kConvolution, | |||||
| cutlass::arch::OpClassSimt, | |||||
| cutlass::arch::Sm75, | |||||
| cutlass::gemm::GemmShape<128, 64, 32>, | |||||
| cutlass::gemm::GemmShape<64, 32, 32>, | |||||
| cutlass::gemm::GemmShape<1, 1, 4>, | |||||
| cutlass::epilogue::thread::BiasAddLinearCombinationReluClamp< | |||||
| cutlass::int4b_t, | |||||
| 8, | |||||
| int32_t, | |||||
| int32_t, | |||||
| float | |||||
| >, | |||||
| cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle, | |||||
| 2, | |||||
| 4, | |||||
| 16, | |||||
| true, | |||||
| cutlass::arch::OpMultiplyAddSaturate>; | |||||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||||
| const typename Convolution::ElementSrc* d_src, | |||||
| const typename Convolution::ElementFilter* d_filter, | |||||
| const typename Convolution::ElementBias* d_bias, | |||||
| const typename Convolution::ElementDst* d_z, | |||||
| typename Convolution::ElementDst* d_dst, | |||||
| int* workspace, | |||||
| typename Convolution::ConvolutionParameter const& conv_param, | |||||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,54 @@ | |||||
| #if !MEGDNN_TEGRA_X1 | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/conv_bias/implicit_gemm_conv_bias_cutlass_wrapper.cuinl" | |||||
| // kernel instance "cutlass_simt_s4_ifprop_relu_s8_16x128x16_16x128x16_1_nc4hw4_nhwc" generated by cutlass generator | |||||
| using Convolution = | |||||
| typename cutlass::conv::device::Convolution< | |||||
| int8_t, | |||||
| cutlass::layout::TensorNCxHWx<4>, | |||||
| int8_t, | |||||
| cutlass::layout::TensorCxRSKx<4>, | |||||
| cutlass::int4b_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::conv::ConvType::kConvolution, | |||||
| cutlass::arch::OpClassSimt, | |||||
| cutlass::arch::Sm75, | |||||
| cutlass::gemm::GemmShape<16, 128, 16>, | |||||
| cutlass::gemm::GemmShape<16, 128, 16>, | |||||
| cutlass::gemm::GemmShape<1, 1, 4>, | |||||
| cutlass::epilogue::thread::BiasAddLinearCombinationReluClamp< | |||||
| cutlass::int4b_t, | |||||
| 8, | |||||
| int32_t, | |||||
| int32_t, | |||||
| float | |||||
| >, | |||||
| cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle, | |||||
| 1, | |||||
| 4, | |||||
| 8, | |||||
| true, | |||||
| cutlass::arch::OpMultiplyAddSaturate>; | |||||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||||
| const typename Convolution::ElementSrc* d_src, | |||||
| const typename Convolution::ElementFilter* d_filter, | |||||
| const typename Convolution::ElementBias* d_bias, | |||||
| const typename Convolution::ElementDst* d_z, | |||||
| typename Convolution::ElementDst* d_dst, | |||||
| int* workspace, | |||||
| typename Convolution::ConvolutionParameter const& conv_param, | |||||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,54 @@ | |||||
| #if !MEGDNN_TEGRA_X1 | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/conv_bias/implicit_gemm_conv_bias_cutlass_wrapper.cuinl" | |||||
| // kernel instance "cutlass_simt_s4_ifprop_relu_s8_16x64x8_16x64x8_2_nc4hw4_nhwc" generated by cutlass generator | |||||
| using Convolution = | |||||
| typename cutlass::conv::device::Convolution< | |||||
| int8_t, | |||||
| cutlass::layout::TensorNCxHWx<4>, | |||||
| int8_t, | |||||
| cutlass::layout::TensorCxRSKx<4>, | |||||
| cutlass::int4b_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::conv::ConvType::kConvolution, | |||||
| cutlass::arch::OpClassSimt, | |||||
| cutlass::arch::Sm75, | |||||
| cutlass::gemm::GemmShape<16, 64, 8>, | |||||
| cutlass::gemm::GemmShape<16, 64, 8>, | |||||
| cutlass::gemm::GemmShape<1, 1, 4>, | |||||
| cutlass::epilogue::thread::BiasAddLinearCombinationReluClamp< | |||||
| cutlass::int4b_t, | |||||
| 8, | |||||
| int32_t, | |||||
| int32_t, | |||||
| float | |||||
| >, | |||||
| cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle, | |||||
| 2, | |||||
| 4, | |||||
| 4, | |||||
| true, | |||||
| cutlass::arch::OpMultiplyAddSaturate>; | |||||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||||
| const typename Convolution::ElementSrc* d_src, | |||||
| const typename Convolution::ElementFilter* d_filter, | |||||
| const typename Convolution::ElementBias* d_bias, | |||||
| const typename Convolution::ElementDst* d_z, | |||||
| typename Convolution::ElementDst* d_dst, | |||||
| int* workspace, | |||||
| typename Convolution::ConvolutionParameter const& conv_param, | |||||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,54 @@ | |||||
| #if !MEGDNN_TEGRA_X1 | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/conv_bias/implicit_gemm_conv_bias_cutlass_wrapper.cuinl" | |||||
| // kernel instance "cutlass_simt_s4_ifprop_relu_s8_32x128x32_32x64x32_2_nc4hw4_nhwc" generated by cutlass generator | |||||
| using Convolution = | |||||
| typename cutlass::conv::device::Convolution< | |||||
| int8_t, | |||||
| cutlass::layout::TensorNCxHWx<4>, | |||||
| int8_t, | |||||
| cutlass::layout::TensorCxRSKx<4>, | |||||
| cutlass::int4b_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::conv::ConvType::kConvolution, | |||||
| cutlass::arch::OpClassSimt, | |||||
| cutlass::arch::Sm75, | |||||
| cutlass::gemm::GemmShape<32, 128, 32>, | |||||
| cutlass::gemm::GemmShape<32, 64, 32>, | |||||
| cutlass::gemm::GemmShape<1, 1, 4>, | |||||
| cutlass::epilogue::thread::BiasAddLinearCombinationReluClamp< | |||||
| cutlass::int4b_t, | |||||
| 8, | |||||
| int32_t, | |||||
| int32_t, | |||||
| float | |||||
| >, | |||||
| cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle, | |||||
| 2, | |||||
| 4, | |||||
| 16, | |||||
| true, | |||||
| cutlass::arch::OpMultiplyAddSaturate>; | |||||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||||
| const typename Convolution::ElementSrc* d_src, | |||||
| const typename Convolution::ElementFilter* d_filter, | |||||
| const typename Convolution::ElementBias* d_bias, | |||||
| const typename Convolution::ElementDst* d_z, | |||||
| typename Convolution::ElementDst* d_dst, | |||||
| int* workspace, | |||||
| typename Convolution::ConvolutionParameter const& conv_param, | |||||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,54 @@ | |||||
| #if !MEGDNN_TEGRA_X1 | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/conv_bias/implicit_gemm_conv_bias_cutlass_wrapper.cuinl" | |||||
| // kernel instance "cutlass_simt_s4_ifprop_relu_s8_32x32x32_32x32x32_2_nc4hw4_nhwc" generated by cutlass generator | |||||
| using Convolution = | |||||
| typename cutlass::conv::device::Convolution< | |||||
| int8_t, | |||||
| cutlass::layout::TensorNCxHWx<4>, | |||||
| int8_t, | |||||
| cutlass::layout::TensorCxRSKx<4>, | |||||
| cutlass::int4b_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::conv::ConvType::kConvolution, | |||||
| cutlass::arch::OpClassSimt, | |||||
| cutlass::arch::Sm75, | |||||
| cutlass::gemm::GemmShape<32, 32, 32>, | |||||
| cutlass::gemm::GemmShape<32, 32, 32>, | |||||
| cutlass::gemm::GemmShape<1, 1, 4>, | |||||
| cutlass::epilogue::thread::BiasAddLinearCombinationReluClamp< | |||||
| cutlass::int4b_t, | |||||
| 8, | |||||
| int32_t, | |||||
| int32_t, | |||||
| float | |||||
| >, | |||||
| cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle, | |||||
| 2, | |||||
| 4, | |||||
| 16, | |||||
| true, | |||||
| cutlass::arch::OpMultiplyAddSaturate>; | |||||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||||
| const typename Convolution::ElementSrc* d_src, | |||||
| const typename Convolution::ElementFilter* d_filter, | |||||
| const typename Convolution::ElementBias* d_bias, | |||||
| const typename Convolution::ElementDst* d_z, | |||||
| typename Convolution::ElementDst* d_dst, | |||||
| int* workspace, | |||||
| typename Convolution::ConvolutionParameter const& conv_param, | |||||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,54 @@ | |||||
| #if !MEGDNN_TEGRA_X1 | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/conv_bias/implicit_gemm_conv_bias_cutlass_wrapper.cuinl" | |||||
| // kernel instance "cutlass_simt_s4_ifprop_relu_s8_32x64x32_32x64x32_2_nc4hw4_nhwc" generated by cutlass generator | |||||
| using Convolution = | |||||
| typename cutlass::conv::device::Convolution< | |||||
| int8_t, | |||||
| cutlass::layout::TensorNCxHWx<4>, | |||||
| int8_t, | |||||
| cutlass::layout::TensorCxRSKx<4>, | |||||
| cutlass::int4b_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::conv::ConvType::kConvolution, | |||||
| cutlass::arch::OpClassSimt, | |||||
| cutlass::arch::Sm75, | |||||
| cutlass::gemm::GemmShape<32, 64, 32>, | |||||
| cutlass::gemm::GemmShape<32, 64, 32>, | |||||
| cutlass::gemm::GemmShape<1, 1, 4>, | |||||
| cutlass::epilogue::thread::BiasAddLinearCombinationReluClamp< | |||||
| cutlass::int4b_t, | |||||
| 8, | |||||
| int32_t, | |||||
| int32_t, | |||||
| float | |||||
| >, | |||||
| cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle, | |||||
| 2, | |||||
| 4, | |||||
| 16, | |||||
| true, | |||||
| cutlass::arch::OpMultiplyAddSaturate>; | |||||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||||
| const typename Convolution::ElementSrc* d_src, | |||||
| const typename Convolution::ElementFilter* d_filter, | |||||
| const typename Convolution::ElementBias* d_bias, | |||||
| const typename Convolution::ElementDst* d_z, | |||||
| typename Convolution::ElementDst* d_dst, | |||||
| int* workspace, | |||||
| typename Convolution::ConvolutionParameter const& conv_param, | |||||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,54 @@ | |||||
| #if !MEGDNN_TEGRA_X1 | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/conv_bias/implicit_gemm_conv_bias_cutlass_wrapper.cuinl" | |||||
| // kernel instance "cutlass_simt_s4_ifprop_relu_s8_64x128x32_64x32x32_2_nc4hw4_nhwc" generated by cutlass generator | |||||
| using Convolution = | |||||
| typename cutlass::conv::device::Convolution< | |||||
| int8_t, | |||||
| cutlass::layout::TensorNCxHWx<4>, | |||||
| int8_t, | |||||
| cutlass::layout::TensorCxRSKx<4>, | |||||
| cutlass::int4b_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::conv::ConvType::kConvolution, | |||||
| cutlass::arch::OpClassSimt, | |||||
| cutlass::arch::Sm75, | |||||
| cutlass::gemm::GemmShape<64, 128, 32>, | |||||
| cutlass::gemm::GemmShape<64, 32, 32>, | |||||
| cutlass::gemm::GemmShape<1, 1, 4>, | |||||
| cutlass::epilogue::thread::BiasAddLinearCombinationReluClamp< | |||||
| cutlass::int4b_t, | |||||
| 8, | |||||
| int32_t, | |||||
| int32_t, | |||||
| float | |||||
| >, | |||||
| cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle, | |||||
| 2, | |||||
| 4, | |||||
| 16, | |||||
| true, | |||||
| cutlass::arch::OpMultiplyAddSaturate>; | |||||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||||
| const typename Convolution::ElementSrc* d_src, | |||||
| const typename Convolution::ElementFilter* d_filter, | |||||
| const typename Convolution::ElementBias* d_bias, | |||||
| const typename Convolution::ElementDst* d_z, | |||||
| typename Convolution::ElementDst* d_dst, | |||||
| int* workspace, | |||||
| typename Convolution::ConvolutionParameter const& conv_param, | |||||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,54 @@ | |||||
| #if !MEGDNN_TEGRA_X1 | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/conv_bias/implicit_gemm_conv_bias_cutlass_wrapper.cuinl" | |||||
| // kernel instance "cutlass_simt_s4_ifprop_relu_s8_64x32x32_64x32x32_2_nc4hw4_nhwc" generated by cutlass generator | |||||
| using Convolution = | |||||
| typename cutlass::conv::device::Convolution< | |||||
| int8_t, | |||||
| cutlass::layout::TensorNCxHWx<4>, | |||||
| int8_t, | |||||
| cutlass::layout::TensorCxRSKx<4>, | |||||
| cutlass::int4b_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::conv::ConvType::kConvolution, | |||||
| cutlass::arch::OpClassSimt, | |||||
| cutlass::arch::Sm75, | |||||
| cutlass::gemm::GemmShape<64, 32, 32>, | |||||
| cutlass::gemm::GemmShape<64, 32, 32>, | |||||
| cutlass::gemm::GemmShape<1, 1, 4>, | |||||
| cutlass::epilogue::thread::BiasAddLinearCombinationReluClamp< | |||||
| cutlass::int4b_t, | |||||
| 8, | |||||
| int32_t, | |||||
| int32_t, | |||||
| float | |||||
| >, | |||||
| cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle, | |||||
| 2, | |||||
| 4, | |||||
| 16, | |||||
| true, | |||||
| cutlass::arch::OpMultiplyAddSaturate>; | |||||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||||
| const typename Convolution::ElementSrc* d_src, | |||||
| const typename Convolution::ElementFilter* d_filter, | |||||
| const typename Convolution::ElementBias* d_bias, | |||||
| const typename Convolution::ElementDst* d_z, | |||||
| typename Convolution::ElementDst* d_dst, | |||||
| int* workspace, | |||||
| typename Convolution::ConvolutionParameter const& conv_param, | |||||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,54 @@ | |||||
| #if !MEGDNN_TEGRA_X1 | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/conv_bias/implicit_gemm_conv_bias_cutlass_wrapper.cuinl" | |||||
| // kernel instance "cutlass_simt_s4_ifprop_relu_s8_64x64x32_64x32x32_2_nc4hw4_nhwc" generated by cutlass generator | |||||
| using Convolution = | |||||
| typename cutlass::conv::device::Convolution< | |||||
| int8_t, | |||||
| cutlass::layout::TensorNCxHWx<4>, | |||||
| int8_t, | |||||
| cutlass::layout::TensorCxRSKx<4>, | |||||
| cutlass::int4b_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::layout::TensorNHWC, | |||||
| int32_t, | |||||
| cutlass::conv::ConvType::kConvolution, | |||||
| cutlass::arch::OpClassSimt, | |||||
| cutlass::arch::Sm75, | |||||
| cutlass::gemm::GemmShape<64, 64, 32>, | |||||
| cutlass::gemm::GemmShape<64, 32, 32>, | |||||
| cutlass::gemm::GemmShape<1, 1, 4>, | |||||
| cutlass::epilogue::thread::BiasAddLinearCombinationReluClamp< | |||||
| cutlass::int4b_t, | |||||
| 8, | |||||
| int32_t, | |||||
| int32_t, | |||||
| float | |||||
| >, | |||||
| cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle, | |||||
| 2, | |||||
| 4, | |||||
| 16, | |||||
| true, | |||||
| cutlass::arch::OpMultiplyAddSaturate>; | |||||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||||
| const typename Convolution::ElementSrc* d_src, | |||||
| const typename Convolution::ElementFilter* d_filter, | |||||
| const typename Convolution::ElementBias* d_bias, | |||||
| const typename Convolution::ElementDst* d_z, | |||||
| typename Convolution::ElementDst* d_dst, | |||||
| int* workspace, | |||||
| typename Convolution::ConvolutionParameter const& conv_param, | |||||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -159,6 +159,7 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr, | |||||
| filter_meta.format == Format::NCHW44_DOT || | filter_meta.format == Format::NCHW44_DOT || | ||||
| filter_meta.format == Format::NCHW4 || | filter_meta.format == Format::NCHW4 || | ||||
| filter_meta.format == Format::NCHW4_NCHW || | filter_meta.format == Format::NCHW4_NCHW || | ||||
| filter_meta.format == Format::NCHW4_NHWC || | |||||
| filter_meta.format == Format::NCHW4_NCHW32 || | filter_meta.format == Format::NCHW4_NCHW32 || | ||||
| filter_meta.format == Format::NCHW8 || | filter_meta.format == Format::NCHW8 || | ||||
| filter_meta.format == Format::NCHW32 || | filter_meta.format == Format::NCHW32 || | ||||
| @@ -182,9 +183,15 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr, | |||||
| auto N = src.layout.shape[batch_pos], IH = src.layout.shape[spatial_start], | auto N = src.layout.shape[batch_pos], IH = src.layout.shape[spatial_start], | ||||
| IW = src.layout.shape[spatial_start + 1]; | IW = src.layout.shape[spatial_start + 1]; | ||||
| auto FH = filter_meta.spatial[0], FW = filter_meta.spatial[1]; | auto FH = filter_meta.spatial[0], FW = filter_meta.spatial[1]; | ||||
| auto OC = dst.layout.shape[channel_pos], | |||||
| OH = dst.layout.shape[spatial_start], | |||||
| OW = dst.layout.shape[spatial_start + 1]; | |||||
| size_t OC, OH, OW; | |||||
| if (filter_meta.format == Format::NCHW4_NHWC) { | |||||
| OC = dst.layout.shape[3], OH = dst.layout.shape[1], | |||||
| OW = dst.layout.shape[2]; | |||||
| } else { | |||||
| OC = dst.layout.shape[channel_pos], | |||||
| OH = dst.layout.shape[spatial_start], | |||||
| OW = dst.layout.shape[spatial_start + 1]; | |||||
| } | |||||
| if (filter_meta.format == Format::NCHW4 || | if (filter_meta.format == Format::NCHW4 || | ||||
| filter_meta.format == Format::CHWN4 || | filter_meta.format == Format::CHWN4 || | ||||
| @@ -206,6 +213,7 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr, | |||||
| if (filter_meta.format == Format::NCHW || | if (filter_meta.format == Format::NCHW || | ||||
| filter_meta.format == Format::NCHW4 || | filter_meta.format == Format::NCHW4 || | ||||
| filter_meta.format == Format::NCHW4_NCHW || | filter_meta.format == Format::NCHW4_NCHW || | ||||
| filter_meta.format == Format::NCHW4_NHWC || | |||||
| filter_meta.format == Format::NCHW4_NCHW32 || | filter_meta.format == Format::NCHW4_NCHW32 || | ||||
| filter_meta.format == Format::NCHW8 || | filter_meta.format == Format::NCHW8 || | ||||
| filter_meta.format == Format::NCHW32 || | filter_meta.format == Format::NCHW32 || | ||||
| @@ -343,6 +351,15 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr, | |||||
| h * layout.stride[2] + w * layout.stride[3] + | h * layout.stride[2] + w * layout.stride[3] + | ||||
| (c & 0b11) * layout.stride[4]; | (c & 0b11) * layout.stride[4]; | ||||
| } | } | ||||
| } else if (filter_meta.format == Format::NCHW4_NHWC) { | |||||
| if (is_output) { | |||||
| return n * layout.stride[0] + h * layout.stride[1] + | |||||
| w * layout.stride[2] + c * layout.stride[3]; | |||||
| } else { | |||||
| return n * layout.stride[0] + (c / 4) * layout.stride[1] + | |||||
| h * layout.stride[2] + w * layout.stride[3] + | |||||
| (c & 0b11) * layout.stride[4]; | |||||
| } | |||||
| } else if (filter_meta.format == Format::NCHW4_NCHW32) { | } else if (filter_meta.format == Format::NCHW4_NCHW32) { | ||||
| if (is_output) { | if (is_output) { | ||||
| return n * layout.stride[0] + (c >> 5) * layout.stride[1] + | return n * layout.stride[0] + (c >> 5) * layout.stride[1] + | ||||
| @@ -370,6 +387,7 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr, | |||||
| size_t fh, size_t fw) { | size_t fh, size_t fw) { | ||||
| if (filter_meta.format == Format::NCHW4 || | if (filter_meta.format == Format::NCHW4 || | ||||
| filter_meta.format == Format::NCHW4_NCHW || | filter_meta.format == Format::NCHW4_NCHW || | ||||
| filter_meta.format == Format::NCHW4_NHWC || | |||||
| filter_meta.format == Format::NCHW4_NCHW32) { | filter_meta.format == Format::NCHW4_NCHW32) { | ||||
| return gc_out.cur_grp * FS_G + gc_out.cur_off * FS_OC + | return gc_out.cur_grp * FS_G + gc_out.cur_off * FS_OC + | ||||
| (ic - ic0) / 4 * FS_IC * 4 + | (ic - ic0) / 4 * FS_IC * 4 + | ||||
| @@ -695,6 +713,7 @@ void forward_bias(_megdnn_tensor_in src, _megdnn_tensor_in filter, | |||||
| case param::Convolution::Format::NHWC: | case param::Convolution::Format::NHWC: | ||||
| case param::Convolution::Format::NCHW4: | case param::Convolution::Format::NCHW4: | ||||
| case param::Convolution::Format::NCHW4_NCHW: | case param::Convolution::Format::NCHW4_NCHW: | ||||
| case param::Convolution::Format::NCHW4_NHWC: | |||||
| case param::Convolution::Format::NCHW4_NCHW32: | case param::Convolution::Format::NCHW4_NCHW32: | ||||
| case param::Convolution::Format::NCHW8: | case param::Convolution::Format::NCHW8: | ||||
| case param::Convolution::Format::NCHW32: | case param::Convolution::Format::NCHW32: | ||||
| @@ -820,6 +839,7 @@ void forward_bias(_megdnn_tensor_in src, _megdnn_tensor_in filter, | |||||
| BIAS_ADD_CHWNx(4); | BIAS_ADD_CHWNx(4); | ||||
| break; | break; | ||||
| } | } | ||||
| case Format::NCHW4_NHWC: | |||||
| case Format::NHWC: { | case Format::NHWC: { | ||||
| int dst_nhw = dst.layout.shape[0] * dst.layout.shape[1] * | int dst_nhw = dst.layout.shape[0] * dst.layout.shape[1] * | ||||
| dst.layout.shape[2]; | dst.layout.shape[2]; | ||||