| @@ -9,6 +9,7 @@ genrule( | |||
| CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations gemm --type tensorop1688 $(@D) | |||
| CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations gemv --type simt $(@D) | |||
| CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations deconv --type simt $(@D) | |||
| CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations deconv --type tensorop8816 $(@D) | |||
| CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations conv2d --type simt $(@D) | |||
| CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations conv2d --type tensorop8816 $(@D) | |||
| CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations conv2d --type tensorop8832 $(@D) | |||
| @@ -337,7 +337,10 @@ def GenerateConv2d(conv_kind, tile_descriptions, src_layout, flt_layout, dst_lay | |||
| else: | |||
| swizzling_functor = SwizzlingFunctor.ConvFpropNCxHWx | |||
| else: | |||
| swizzling_functor = SwizzlingFunctor.ConvDgradNCxHWx | |||
| if implicit_gemm_mode == ImplicitGemmMode.GemmTN: | |||
| swizzling_functor = SwizzlingFunctor.ConvDgradTrans | |||
| else: | |||
| swizzling_functor = SwizzlingFunctor.ConvDgradNCxHWx | |||
| # skip rule | |||
| def filter_tile_with_layout(tile: TileDescription, layout: LayoutType) -> bool: | |||
| @@ -36,6 +36,7 @@ if __name__ == "__main__": | |||
| write_op_list(f, "gemm", "tensorop884") | |||
| write_op_list(f, "gemv", "simt") | |||
| write_op_list(f, "deconv", "simt") | |||
| write_op_list(f, "deconv", "tensorop8816") | |||
| write_op_list(f, "conv2d", "simt") | |||
| write_op_list(f, "conv2d", "tensorop8816") | |||
| write_op_list(f, "conv2d", "tensorop8832") | |||
| @@ -445,6 +445,53 @@ def GenerateDeconv_Simt(args): | |||
| use_special_optimization) | |||
| return operations | |||
| def GenerateDeconv_TensorOp_8816(args): | |||
| operations = [] | |||
| layouts = [ | |||
| (LayoutType.TensorNHWC, LayoutType.TensorCK4RS4, 32), | |||
| (LayoutType.TensorNHWC, LayoutType.TensorCK8RS8, 64), | |||
| (LayoutType.TensorNHWC, LayoutType.TensorCK16RS16, 128), | |||
| ] | |||
| math_instructions = [ | |||
| MathInstruction( \ | |||
| [8, 8, 16], \ | |||
| DataType.s8, DataType.s8, DataType.s32, \ | |||
| OpcodeClass.TensorOp, \ | |||
| MathOperation.multiply_add_saturate), | |||
| ] | |||
| dst_layouts = [ | |||
| LayoutType.TensorNHWC, | |||
| ] | |||
| dst_types = [ | |||
| DataType.s8, | |||
| ] | |||
| use_special_optimization = SpecialOptimizeDesc.DeconvDoubleUpsampling | |||
| min_cc = 75 | |||
| max_cc = 1024 | |||
| cuda_major = 10 | |||
| cuda_minor = 2 | |||
| for math_inst in math_instructions: | |||
| for layout in layouts: | |||
| for dst_type, dst_layout in zip(dst_types, dst_layouts): | |||
| tile_descriptions = [ | |||
| TileDescription([128, 32, 32], 1, [2, 1, 1], math_inst, min_cc, max_cc), | |||
| TileDescription([64, 16, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc), | |||
| ] | |||
| for tile in tile_descriptions: | |||
| dst_align = 32 if tile.threadblock_shape[1] == 16 else 64 | |||
| operations += GenerateConv2d(ConvKind.Dgrad, [tile], layout[0], layout[1], dst_layout, dst_type, | |||
| min_cc, layout[2], layout[2], dst_align, use_special_optimization, | |||
| ImplicitGemmMode.GemmTN, False, cuda_major, cuda_minor) | |||
| return operations | |||
| ################################################################################ | |||
| # parameters | |||
| # Edge - for tiles, the edges represent the length of one side | |||
| @@ -820,9 +867,12 @@ def GenerateConv2dOperations(args): | |||
| return GenerateConv2d_TensorOp_8832(args) | |||
| def GenerateDeconvOperations(args): | |||
| assert args.type == "simt", "operation deconv only support" \ | |||
| "simt. (got:{})".format(args.type) | |||
| return GenerateDeconv_Simt(args) | |||
| if args.type == "simt": | |||
| return GenerateDeconv_Simt(args) | |||
| else: | |||
| assert args.type == "tensorop8816", "operation deconv only support" \ | |||
| "simt and tensorop8816. (got:{})".format(args.type) | |||
| return GenerateDeconv_TensorOp_8816(args) | |||
| def GenerateGemmOperations(args): | |||
| if args.type == "tensorop884": | |||
| @@ -280,6 +280,9 @@ class LayoutType(enum.Enum): | |||
| TensorC32RSK32 = enum_auto() | |||
| TensorC64RSK64 = enum_auto() | |||
| TensorK4RSC4 = enum_auto() | |||
| TensorCK4RS4 = enum_auto() | |||
| TensorCK8RS8 = enum_auto() | |||
| TensorCK16RS16 = enum_auto() | |||
| # | |||
| LayoutTag = { | |||
| @@ -303,7 +306,10 @@ LayoutTag = { | |||
| LayoutType.TensorC32RSK32: 'cutlass::layout::TensorCxRSKx<32>', | |||
| LayoutType.TensorNC64HW64: 'cutlass::layout::TensorNCxHWx<64>', | |||
| LayoutType.TensorC64RSK64: 'cutlass::layout::TensorCxRSKx<64>', | |||
| LayoutType.TensorK4RSC4: 'cutlass::layout::TensorKxRSCx<4>', | |||
| LayoutType.TensorK4RSC4: 'cutlass::layout::TensorKxRSCx<4>', | |||
| LayoutType.TensorCK4RS4: 'cutlass::layout::TensorCKxRSx<4>', | |||
| LayoutType.TensorCK8RS8: 'cutlass::layout::TensorCKxRSx<8>', | |||
| LayoutType.TensorCK16RS16: 'cutlass::layout::TensorCKxRSx<16>', | |||
| } | |||
| # | |||
| @@ -342,6 +348,9 @@ ShortLayoutTypeNames = { | |||
| LayoutType.TensorC32RSK32: 'c32rsk32', | |||
| LayoutType.TensorC64RSK64: 'c64rsk64', | |||
| LayoutType.TensorK4RSC4: 'k4rsc4', | |||
| LayoutType.TensorCK4RS4: 'ck4rs4', | |||
| LayoutType.TensorCK8RS8: 'ck8rs8', | |||
| LayoutType.TensorCK16RS16: 'ck16rs16', | |||
| } | |||
| # | |||
| @@ -484,6 +493,7 @@ class SwizzlingFunctor(enum.Enum): | |||
| ConvFpropNCxHWx = enum_auto() | |||
| ConvFpropTrans = enum_auto() | |||
| ConvDgradNCxHWx = enum_auto() | |||
| ConvDgradTrans = enum_auto() | |||
| # | |||
| SwizzlingFunctorTag = { | |||
| @@ -494,6 +504,7 @@ SwizzlingFunctorTag = { | |||
| SwizzlingFunctor.ConvFpropNCxHWx: 'cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle', | |||
| SwizzlingFunctor.ConvFpropTrans: 'cutlass::conv::threadblock::ConvolutionFpropTransThreadblockSwizzle', | |||
| SwizzlingFunctor.ConvDgradNCxHWx: 'cutlass::conv::threadblock::ConvolutionDgradNCxHWxThreadblockSwizzle', | |||
| SwizzlingFunctor.ConvDgradTrans: 'cutlass::conv::threadblock::ConvolutionDgradTransThreadblockSwizzle', | |||
| } | |||
| ################################################################################################### | |||
| @@ -464,6 +464,19 @@ cutlass_gen_list = [ | |||
| "cutlass_simt_s8_idgrad_id_s8_16x64x8_16x64x8_2_nc4hw4_k4rsc4.cu", | |||
| "cutlass_simt_s8_idgrad_s2_id_s8_16x64x8_16x64x8_2_nc4hw4_k4rsc4.cu", | |||
| "all_deconv_simt_operations.cu", | |||
| "cutlass_tensorop_s8_i8816dgrad_id_s8_128x32x32_64x32x32_1_nhwc_ck4rs4.cu", | |||
| "cutlass_tensorop_s8_i8816dgrad_s2_id_s8_128x32x32_64x32x32_1_nhwc_ck4rs4.cu", | |||
| "cutlass_tensorop_s8_i8816dgrad_id_s8_64x16x32_64x16x32_2_nhwc_ck4rs4.cu", | |||
| "cutlass_tensorop_s8_i8816dgrad_s2_id_s8_64x16x32_64x16x32_2_nhwc_ck4rs4.cu", | |||
| "cutlass_tensorop_s8_i8816dgrad_id_s8_128x32x32_64x32x32_1_nhwc_ck8rs8.cu", | |||
| "cutlass_tensorop_s8_i8816dgrad_s2_id_s8_128x32x32_64x32x32_1_nhwc_ck8rs8.cu", | |||
| "cutlass_tensorop_s8_i8816dgrad_id_s8_64x16x32_64x16x32_2_nhwc_ck8rs8.cu", | |||
| "cutlass_tensorop_s8_i8816dgrad_s2_id_s8_64x16x32_64x16x32_2_nhwc_ck8rs8.cu", | |||
| "cutlass_tensorop_s8_i8816dgrad_id_s8_128x32x32_64x32x32_1_nhwc_ck16rs16.cu", | |||
| "cutlass_tensorop_s8_i8816dgrad_s2_id_s8_128x32x32_64x32x32_1_nhwc_ck16rs16.cu", | |||
| "cutlass_tensorop_s8_i8816dgrad_id_s8_64x16x32_64x16x32_2_nhwc_ck16rs16.cu", | |||
| "cutlass_tensorop_s8_i8816dgrad_s2_id_s8_64x16x32_64x16x32_2_nhwc_ck16rs16.cu", | |||
| "all_deconv_tensorop8816_operations.cu", | |||
| "cutlass_simt_s8_ifprop_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_relu_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
| @@ -155,6 +155,7 @@ if(MGE_WITH_CUDA) | |||
| gen_cutlass_kimpl(gemm tensorop1688 CUTLASS_SOURCES) | |||
| gen_cutlass_kimpl(gemv simt CUTLASS_SOURCES) | |||
| gen_cutlass_kimpl(deconv simt CUTLASS_SOURCES) | |||
| gen_cutlass_kimpl(deconv tensorop8816 CUTLASS_SOURCES) | |||
| gen_cutlass_kimpl(conv2d simt CUTLASS_SOURCES) | |||
| gen_cutlass_kimpl(conv2d tensorop8816 CUTLASS_SOURCES) | |||
| gen_cutlass_kimpl(conv2d tensorop8832 CUTLASS_SOURCES) | |||
| @@ -36,6 +36,12 @@ ConvolutionBackwardDataImpl::AlgoPack::AlgoPack() { | |||
| int8_algos.push_back(&algo); | |||
| } | |||
| fill_int8_imma_algos(); | |||
| for (auto&& algo : int8_nhwc_imma) { | |||
| all_algos.push_back(&algo); | |||
| int8_algos.push_back(&algo); | |||
| } | |||
| int8_algos.push_back(&int8_nchw_dotprod); | |||
| all_algos.push_back(&int8_nchw_dotprod); | |||
| @@ -40,7 +40,8 @@ public: | |||
| CUDA_BFLOAT16, | |||
| CUDA_GROUP_CONV_GENERAL, | |||
| CUDA_IMPLICIT_GEMM_NCHW4_DOTPROD_INT8, | |||
| CUDA_IMPLICIT_GEMM_NCHW_DOTPROD_INT8 | |||
| CUDA_IMPLICIT_GEMM_NCHW_DOTPROD_INT8, | |||
| CUDA_IMPLICIT_GEMM_NHWC_IMMA_INT8 | |||
| }; | |||
| using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||
| @@ -299,11 +300,53 @@ private: | |||
| const void* get_available_op(const SizeArgs& args) const; | |||
| }; | |||
| class ConvolutionBackwardDataImpl::AlgoInt8NHWCIMMAImplicitGemm final | |||
| : public AlgoBase { | |||
| public: | |||
| struct AlgoParam { | |||
| int threadblock_m; | |||
| int threadblock_n; | |||
| int threadblock_k; | |||
| int warp_m; | |||
| int warp_n; | |||
| int warp_k; | |||
| int stage; | |||
| int access_size; | |||
| std::string to_string() { | |||
| return ssprintf("_%dX%dX%d_%dX%dX%d_%dstage_%d", threadblock_m, | |||
| threadblock_n, threadblock_k, warp_m, warp_n, | |||
| warp_k, stage, access_size); | |||
| } | |||
| }; | |||
| AlgoInt8NHWCIMMAImplicitGemm(AlgoParam algo_param) | |||
| : m_algo_param{algo_param}, | |||
| m_name{ssprintf("INT8_NHWC_IMMA_IMPLICIT_GEMM%s", | |||
| m_algo_param.to_string().c_str())} {} | |||
| bool is_available(const SizeArgs& args) const override; | |||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
| void exec(const ExecArgs& args) const override; | |||
| const char* name() const override { return m_name.c_str(); } | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_NHWC_IMMA_INT8) | |||
| private: | |||
| WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, | |||
| const SizeArgs& args) const; | |||
| const void* get_available_op(const SizeArgs& args) const; | |||
| void reorder_filter(const ExecArgs& args, const int iterleaved, | |||
| int8_t* reordered_filter) const; | |||
| AlgoParam m_algo_param; | |||
| std::string m_name; | |||
| }; | |||
| class ConvolutionBackwardDataImpl::AlgoPack : NonCopyableObj { | |||
| // defined in cudnn.cpp | |||
| void fill_cudnn_algos(); | |||
| // defined in implicit_gemm_int8_nchw4_dp4a.cpp | |||
| void fill_int8_dp4a_algos(); | |||
| // defined in implicit_gemm_int8_nhwc_imma.cpp | |||
| void fill_int8_imma_algos(); | |||
| AlgoBase::Mapper m_all_algos_map; | |||
| @@ -318,6 +361,7 @@ public: | |||
| AlgoGroupConvGeneral group; | |||
| std::vector<AlgoInt8NCHW4DotProdImplicitGemm> int8_nchw4_dotprod; | |||
| AlgoInt8NCHWDotProdImplicitGemm int8_nchw_dotprod; | |||
| std::vector<AlgoInt8NHWCIMMAImplicitGemm> int8_nhwc_imma; | |||
| std::vector<AlgoBase*> | |||
| //! all algorithms | |||
| @@ -11,6 +11,7 @@ | |||
| */ | |||
| #include "src/cuda/convolution/backward_data/deconv_int8_helper.cuh" | |||
| #include "src/cuda/transpose_utils.cuh" | |||
| using namespace megdnn; | |||
| using namespace cuda; | |||
| @@ -21,7 +22,6 @@ using namespace deconv; | |||
| namespace { | |||
| // | |||
| __global__ void reorder_filter_nc4hw4_to_n4hwc4_kernel( | |||
| int8_t* __restrict__ dst, const int8_t* __restrict__ src, uint32_t OC, | |||
| uint32_t IC, uint32_t FHFW) { | |||
| @@ -30,32 +30,55 @@ __global__ void reorder_filter_nc4hw4_to_n4hwc4_kernel( | |||
| const int32_t fhfw = blockIdx.x * BLOCKSIZE_Y + threadIdx.x; | |||
| if (fhfw < FHFW && icb < IC / 4) { | |||
| int src0 = *reinterpret_cast<const int*>( | |||
| src + (ocb * 4 + 0) * IC * FHFW + (icb * FHFW + fhfw) * 4); | |||
| int src1 = *reinterpret_cast<const int*>( | |||
| src + (ocb * 4 + 1) * IC * FHFW + (icb * FHFW + fhfw) * 4); | |||
| int src2 = *reinterpret_cast<const int*>( | |||
| src + (ocb * 4 + 2) * IC * FHFW + (icb * FHFW + fhfw) * 4); | |||
| int src3 = *reinterpret_cast<const int*>( | |||
| src + (ocb * 4 + 3) * IC * FHFW + (icb * FHFW + fhfw) * 4); | |||
| int src_value[4], dst_value[4]; | |||
| #pragma unroll | |||
| for (int i = 0; i < 4; i++) { | |||
| src_value[i] = *reinterpret_cast<const int*>( | |||
| src + (ocb * 4 + i) * IC * FHFW + (icb * FHFW + fhfw) * 4); | |||
| } | |||
| // transpose 4x4 | |||
| int dst01_lo = __byte_perm(src0, src1, 0x5140); | |||
| int dst01_hi = __byte_perm(src0, src1, 0x7362); | |||
| int dst23_lo = __byte_perm(src2, src3, 0x5140); | |||
| int dst23_hi = __byte_perm(src2, src3, 0x7362); | |||
| int dst0 = __byte_perm(dst01_lo, dst23_lo, 0x5410); | |||
| int dst1 = __byte_perm(dst01_lo, dst23_lo, 0x7632); | |||
| int dst2 = __byte_perm(dst01_hi, dst23_hi, 0x5410); | |||
| int dst3 = __byte_perm(dst01_hi, dst23_hi, 0x7632); | |||
| *reinterpret_cast<int*>( | |||
| dst + (ocb * FHFW * IC + fhfw * IC + icb * 4 + 0) * 4) = dst0; | |||
| *reinterpret_cast<int*>( | |||
| dst + (ocb * FHFW * IC + fhfw * IC + icb * 4 + 1) * 4) = dst1; | |||
| *reinterpret_cast<int*>( | |||
| dst + (ocb * FHFW * IC + fhfw * IC + icb * 4 + 2) * 4) = dst2; | |||
| *reinterpret_cast<int*>( | |||
| dst + (ocb * FHFW * IC + fhfw * IC + icb * 4 + 3) * 4) = dst3; | |||
| transpose_int8_interleavedx4<4, int>(src_value, dst_value); | |||
| #pragma unroll | |||
| for (int i = 0; i < 4; i++) { | |||
| *reinterpret_cast<int*>( | |||
| dst + (ocb * FHFW * IC + fhfw * IC + icb * 4 + i) * 4) = | |||
| dst_value[i]; | |||
| } | |||
| } | |||
| } | |||
| template <uint32_t interleaved, typename vec_type> | |||
| __global__ void reorder_filter_nhwc_to_cnxhwx_kernel( | |||
| int8_t* __restrict__ dst, const int8_t* __restrict__ src, uint32_t OC, | |||
| uint32_t IC, uint32_t FHFW) { | |||
| uint32_t lane = threadIdx.x + blockIdx.x * blockDim.x; | |||
| const int32_t ocb = lane / (FHFW * IC / 4); | |||
| const int32_t fhfw_icb = lane % (FHFW * IC / 4); | |||
| const int32_t fhfw = fhfw_icb / (IC / 4); | |||
| const int32_t icb = fhfw_icb % (IC / 4); | |||
| if (ocb < OC / interleaved && fhfw < FHFW) { | |||
| int src_value[interleaved]; | |||
| vec_type dst_value[4]; | |||
| #pragma unroll | |||
| for (int i = 0; i < interleaved; i++) { | |||
| src_value[i] = *reinterpret_cast<const int*>( | |||
| src + (ocb * interleaved + i) * FHFW * IC + fhfw * IC + | |||
| icb * 4); | |||
| } | |||
| transpose_int8_interleavedx4<interleaved, vec_type>(src_value, | |||
| dst_value); | |||
| #pragma unroll | |||
| for (int i = 0; i < 4; i++) { | |||
| *reinterpret_cast<vec_type*>(dst + (icb * 4 + i) * FHFW * OC + | |||
| (ocb * FHFW + fhfw) * interleaved) = | |||
| dst_value[i]; | |||
| } | |||
| } | |||
| } | |||
| @@ -73,4 +96,27 @@ void megdnn::cuda::deconv::reorder_filter_nc4hw4_to_n4hwc4( | |||
| after_kernel_launch(); | |||
| } | |||
| void megdnn::cuda::deconv::reorder_filter_nhwc_to_cnxhwx( | |||
| int8_t* dst, const int8_t* src, uint32_t OC, uint32_t IC, uint32_t FH, | |||
| uint32_t FW, uint32_t interleaved, cudaStream_t stream) { | |||
| int32_t vthreads = OC / interleaved * IC / 4 * FH * FW; | |||
| int32_t nr_threads = std::min(256, vthreads); | |||
| int32_t nr_blocks = DIVUP(vthreads, nr_threads); | |||
| if (interleaved == 4) { | |||
| reorder_filter_nhwc_to_cnxhwx_kernel<4, int> | |||
| <<<nr_blocks, nr_threads, 0, stream>>>(dst, src, OC, IC, | |||
| FH * FW); | |||
| } else if (interleaved == 8) { | |||
| reorder_filter_nhwc_to_cnxhwx_kernel<8, int2> | |||
| <<<nr_blocks, nr_threads, 0, stream>>>(dst, src, OC, IC, | |||
| FH * FW); | |||
| } else { | |||
| reorder_filter_nhwc_to_cnxhwx_kernel<16, int4> | |||
| <<<nr_blocks, nr_threads, 0, stream>>>(dst, src, OC, IC, | |||
| FH * FW); | |||
| } | |||
| after_kernel_launch(); | |||
| } | |||
| // vim: syntax=cuda.doxygen | |||
| @@ -20,6 +20,10 @@ void reorder_filter_nc4hw4_to_n4hwc4(int8_t* dst, const int8_t* src, | |||
| uint32_t OC, uint32_t IC, uint32_t FH, | |||
| uint32_t FW, cudaStream_t stream); | |||
| void reorder_filter_nhwc_to_cnxhwx(int8_t* dst, const int8_t* src, uint32_t OC, | |||
| uint32_t IC, uint32_t FH, uint32_t FW, | |||
| uint32_t interleaved, cudaStream_t stream); | |||
| } // namespace deconv | |||
| } // namespace cuda | |||
| } // namespace megdnn | |||
| @@ -0,0 +1,214 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/cuda/convolution/backward_data/implicit_gemm_int8_nchw4_dp4a.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/cuda/convolution/backward_data/algo.h" | |||
| #include "src/cuda/convolution/backward_data/deconv_int8_helper.cuh" | |||
| #include "src/cuda/convolution_helper/parameter.cuh" | |||
| #include "src/cuda/cutlass/singleton.h" | |||
| #include "src/cuda/utils.h" | |||
| using namespace megdnn; | |||
| using namespace cuda; | |||
| const void* | |||
| ConvolutionBackwardDataImpl::AlgoInt8NHWCIMMAImplicitGemm::get_available_op( | |||
| const SizeArgs& args) const { | |||
| using namespace cutlass::library; | |||
| auto&& fm = args.filter_meta; | |||
| size_t sh = fm.stride[0], sw = fm.stride[1]; | |||
| cutlass::conv::SpecialOptimizeDesc special_optimization = | |||
| (sh == 2 && sw == 2) ? cutlass::conv::SpecialOptimizeDesc:: | |||
| DECONV_DOUBLE_UPSAMPLING | |||
| : cutlass::conv::SpecialOptimizeDesc::NONE; | |||
| LayoutTypeID filter_layout; | |||
| if (m_algo_param.access_size == 16) { | |||
| filter_layout = LayoutTypeID::kTensorCK16RS16; | |||
| } else if (m_algo_param.access_size == 8) { | |||
| filter_layout = LayoutTypeID::kTensorCK8RS8; | |||
| } else { | |||
| megdnn_assert(m_algo_param.access_size == 4, "invalid access_size: %d", | |||
| m_algo_param.access_size); | |||
| filter_layout = LayoutTypeID::kTensorCK4RS4; | |||
| } | |||
| ConvolutionKey key{ | |||
| cutlass::conv::Operator::kDgrad, | |||
| NumericTypeID::kS8, | |||
| LayoutTypeID::kTensorNHWC, | |||
| NumericTypeID::kS8, | |||
| filter_layout, | |||
| NumericTypeID::kS8, | |||
| LayoutTypeID::kTensorNHWC, | |||
| NumericTypeID::kS32, | |||
| LayoutTypeID::kTensorNHWC, | |||
| cutlass::conv::ConvType::kConvolution, | |||
| m_algo_param.threadblock_m, | |||
| m_algo_param.threadblock_n, | |||
| m_algo_param.threadblock_k, | |||
| m_algo_param.warp_m, | |||
| m_algo_param.warp_n, | |||
| m_algo_param.warp_k, | |||
| 8, | |||
| 8, | |||
| 16, | |||
| cutlass::epilogue::EpilogueType::kBiasAddLinearCombinationClamp, | |||
| m_algo_param.stage, | |||
| special_optimization, | |||
| false}; | |||
| return (void*)Singleton::get().operation_table.find_op(key); | |||
| } | |||
| bool ConvolutionBackwardDataImpl::AlgoInt8NHWCIMMAImplicitGemm::is_available( | |||
| const SizeArgs& args) const { | |||
| auto&& fm = args.filter_meta; | |||
| if (fm.format != Param::Format::NHWC) | |||
| return false; | |||
| if (!args.grad_layout->is_contiguous() || | |||
| !args.diff_layout->is_contiguous()) { | |||
| return false; | |||
| } | |||
| bool available = true; | |||
| auto src_dtype = args.diff_layout->dtype, | |||
| filter_dtype = args.filter_layout->dtype, | |||
| dst_dtype = args.grad_layout->dtype; | |||
| size_t co = args.diff_layout->operator[](3); | |||
| size_t ci = args.grad_layout->operator[](3); | |||
| available &= (src_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||
| filter_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||
| dst_dtype.enumv() == DTypeEnum::QuantizedS8); | |||
| // TODO support group deconv int8 | |||
| available &= (fm.group == 1); | |||
| // mode must be cross correlation | |||
| available &= !fm.should_flip; | |||
| // mode must be 2D | |||
| available &= fm.spatial_ndim == 2; | |||
| // TODO: support dialtion | |||
| available &= (fm.dilation[0] == 1 && fm.dilation[1] == 1); | |||
| // FIXME: too large filter size is not supported now | |||
| size_t kMaxFilterPixels = | |||
| 848 / (m_algo_param.warp_k / m_algo_param.access_size) - 1; | |||
| available &= fm.spatial[0] * fm.spatial[1] <= kMaxFilterPixels; | |||
| // ci should be aligned with 4, and co should be aligned with | |||
| // algo_param.access_size | |||
| available &= ((ci % 4 == 0) && (co % m_algo_param.access_size == 0)); | |||
| available &= (get_available_op(args) != nullptr); | |||
| // only support sm_75 or later, platform should have imma int8 support | |||
| available &= is_compute_capability_required(7, 5); | |||
| return available; | |||
| } | |||
| WorkspaceBundle | |||
| ConvolutionBackwardDataImpl::AlgoInt8NHWCIMMAImplicitGemm::get_workspace_bundle( | |||
| dt_byte* raw_ptr, const SizeArgs& args) const { | |||
| size_t ws_filter = args.filter_layout->span().dist_byte(); | |||
| return WorkspaceBundle{raw_ptr, {ws_filter}}; | |||
| } | |||
| size_t ConvolutionBackwardDataImpl::AlgoInt8NHWCIMMAImplicitGemm:: | |||
| get_workspace_in_bytes(const SizeArgs& args) const { | |||
| return get_workspace_bundle(nullptr, args).total_size_in_bytes(); | |||
| } | |||
| void ConvolutionBackwardDataImpl::AlgoInt8NHWCIMMAImplicitGemm::exec( | |||
| const ExecArgs& args) const { | |||
| auto&& param = args.opr->param(); | |||
| auto&& fm = args.filter_meta; | |||
| size_t n = args.diff_layout->operator[](0), | |||
| co = args.diff_layout->operator[](3), | |||
| ho = args.diff_layout->operator[](1), | |||
| wo = args.diff_layout->operator[](2); | |||
| size_t ci = args.grad_layout->operator[](3), | |||
| hi = args.grad_layout->operator[](1), | |||
| wi = args.grad_layout->operator[](2); | |||
| size_t fh = fm.spatial[0], fw = fm.spatial[1]; | |||
| size_t sh = fm.stride[0], sw = fm.stride[1]; | |||
| size_t ph = fm.padding[0], pw = fm.padding[1]; | |||
| size_t dh = param.dilate_h, dw = param.dilate_w; | |||
| auto&& stream = cuda_stream(args.opr->handle()); | |||
| int8_t* filter_ptr = nullptr; | |||
| // TODO: weight preprocess | |||
| { | |||
| filter_ptr = reinterpret_cast<int8_t*>(args.workspace.raw_ptr); | |||
| // reformat filter from nc4hw4 to n4hwc4 | |||
| reorder_filter(args, m_algo_param.access_size, filter_ptr); | |||
| } | |||
| float diff_scale = | |||
| args.diff_layout->dtype.param<dtype::QuantizedS8>().scale, | |||
| filter_scale = | |||
| args.filter_layout->dtype.param<dtype::QuantizedS8>().scale, | |||
| grad_scale = | |||
| args.grad_layout->dtype.param<dtype::QuantizedS8>().scale; | |||
| // \note these constants of cutlass epilogue will be passed to struct | |||
| // `ConvolutionArguments` by pointer and interpreted as ElementCompute*, | |||
| // a different dtype here results in undefined epilogue behaviors | |||
| float alpha = diff_scale * filter_scale / grad_scale, beta = 0.f, | |||
| gamma = 0.f, delta = 0.f; | |||
| using namespace cutlass::library; | |||
| const Operation* op = (const Operation*)get_available_op(args); | |||
| // gcc prints warnings when size_t values are implicitly narrowed to int | |||
| cutlass::conv::Conv2dProblemSize problem_size{ | |||
| int(n), int(hi), int(wi), int(ci), | |||
| int(co), int(fh), int(fw), int(ho), | |||
| int(wo), int(ph), int(pw), int(sh), | |||
| int(sw), int(dh), int(dw), cutlass::conv::Mode::kCrossCorrelation}; | |||
| cutlass::library::ConvolutionArguments conv_args{ | |||
| problem_size, args.diff_tensor->compatible_ptr<int8_t>(), | |||
| filter_ptr, nullptr, | |||
| nullptr, args.grad_tensor->compatible_ptr<int8_t>(), | |||
| &alpha, &beta, | |||
| &gamma, &delta, | |||
| nullptr, nullptr, | |||
| nullptr, nullptr}; | |||
| cutlass_check(op->run(&conv_args, nullptr, stream)); | |||
| after_kernel_launch(); | |||
| } | |||
| void ConvolutionBackwardDataImpl::AlgoInt8NHWCIMMAImplicitGemm::reorder_filter( | |||
| const ExecArgs& args, const int interleaved, | |||
| int8_t* reordered_filter) const { | |||
| auto&& fm = args.filter_meta; | |||
| size_t co = args.diff_layout->operator[](3); | |||
| size_t ci = args.grad_layout->operator[](3); | |||
| size_t fh = fm.spatial[0], fw = fm.spatial[1]; | |||
| auto&& stream = cuda_stream(args.opr->handle()); | |||
| megdnn::cuda::deconv::reorder_filter_nhwc_to_cnxhwx( | |||
| reordered_filter, args.filter_tensor->compatible_ptr<int8_t>(), co, | |||
| ci, fh, fw, interleaved, stream); | |||
| } | |||
| void ConvolutionBackwardDataImpl::AlgoPack::fill_int8_imma_algos() { | |||
| using AlgoParam = AlgoInt8NHWCIMMAImplicitGemm::AlgoParam; | |||
| int8_nhwc_imma.emplace_back(AlgoParam{64, 16, 32, 64, 16, 32, 2, 4}); | |||
| int8_nhwc_imma.emplace_back(AlgoParam{64, 16, 32, 64, 16, 32, 2, 8}); | |||
| int8_nhwc_imma.emplace_back(AlgoParam{64, 16, 32, 64, 16, 32, 2, 16}); | |||
| int8_nhwc_imma.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 1, 4}); | |||
| int8_nhwc_imma.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 1, 8}); | |||
| int8_nhwc_imma.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 1, 16}); | |||
| } | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -99,6 +99,7 @@ public: | |||
| class AlgoBFloat16; | |||
| class AlgoInt8NCHW4DotProdImplicitGemm; | |||
| class AlgoInt8NCHWDotProdImplicitGemm; | |||
| class AlgoInt8NHWCIMMAImplicitGemm; | |||
| class AlgoPack; | |||
| @@ -60,6 +60,7 @@ void initialize_all_gemm_tensorop884_operations(Manifest& manifest); | |||
| void initialize_all_gemm_tensorop1688_operations(Manifest& manifest); | |||
| void initialize_all_conv2d_tensorop8816_operations(Manifest& manifest); | |||
| void initialize_all_conv2d_tensorop8832_operations(Manifest& manifest); | |||
| void initialize_all_deconv_tensorop8816_operations(Manifest& manifest); | |||
| #endif | |||
| void initialize_all(Manifest& manifest) { | |||
| @@ -71,6 +72,7 @@ void initialize_all(Manifest& manifest) { | |||
| initialize_all_gemm_tensorop1688_operations(manifest); | |||
| initialize_all_conv2d_tensorop8816_operations(manifest); | |||
| initialize_all_conv2d_tensorop8832_operations(manifest); | |||
| initialize_all_deconv_tensorop8816_operations(manifest); | |||
| #endif | |||
| } | |||
| @@ -100,6 +100,9 @@ enum class LayoutTypeID { | |||
| kTensorNC64HW64, | |||
| kTensorC64RSK64, | |||
| kTensorK4RSC4, | |||
| kTensorCK4RS4, | |||
| kTensorCK8RS8, | |||
| kTensorCK16RS16, | |||
| kInvalid | |||
| }; | |||
| @@ -225,6 +228,7 @@ enum class ThreadblockSwizzleID { | |||
| kConvolutionFpropNCxHWx, | |||
| kConvolutionFpropTrans, | |||
| kConvolutionDgradNCxHWx, | |||
| kConvolutionDgradTrans, | |||
| kInvalid | |||
| }; | |||
| @@ -340,6 +340,21 @@ struct LayoutMap<cutlass::layout::TensorKxRSCx<4>> { | |||
| static LayoutTypeID const kId = LayoutTypeID::kTensorK4RSC4; | |||
| }; | |||
| template <> | |||
| struct LayoutMap<cutlass::layout::TensorCKxRSx<4>> { | |||
| static LayoutTypeID const kId = LayoutTypeID::kTensorCK4RS4; | |||
| }; | |||
| template <> | |||
| struct LayoutMap<cutlass::layout::TensorCKxRSx<8>> { | |||
| static LayoutTypeID const kId = LayoutTypeID::kTensorCK8RS8; | |||
| }; | |||
| template <> | |||
| struct LayoutMap<cutlass::layout::TensorCKxRSx<16>> { | |||
| static LayoutTypeID const kId = LayoutTypeID::kTensorCK16RS16; | |||
| }; | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| template <typename T> | |||
| @@ -556,6 +571,13 @@ struct ThreadblockSwizzleMap< | |||
| ThreadblockSwizzleID::kConvolutionDgradNCxHWx; | |||
| }; | |||
| template <> | |||
| struct ThreadblockSwizzleMap< | |||
| conv::threadblock::ConvolutionDgradTransThreadblockSwizzle> { | |||
| static ThreadblockSwizzleID const kId = | |||
| ThreadblockSwizzleID::kConvolutionDgradTrans; | |||
| }; | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| template <typename Element, typename Layout> | |||
| @@ -533,7 +533,10 @@ static struct { | |||
| {LayoutTypeID::kTensorC16RSK16, "c16rsk16"}, | |||
| {LayoutTypeID::kTensorC32RSK32, "c32rsk32"}, | |||
| {LayoutTypeID::kTensorC64RSK64, "c64rsk64"}, | |||
| {LayoutTypeID::kTensorK4RSC4, "k4rsC4"}, | |||
| {LayoutTypeID::kTensorK4RSC4, "k4rsc4"}, | |||
| {LayoutTypeID::kTensorCK4RS4, "ck4rs4"}, | |||
| {LayoutTypeID::kTensorCK8RS8, "ck8rs8"}, | |||
| {LayoutTypeID::kTensorCK16RS16, "ck16rs16"}, | |||
| {LayoutTypeID::kUnknown, "*"}, | |||
| {LayoutTypeID::kInvalid, nullptr}}; | |||
| @@ -1499,6 +1502,8 @@ static struct { | |||
| ThreadblockSwizzleID::kConvolutionFpropTrans}, | |||
| {"convolution_dgrad_ncxhwx", "ConvolutionDgradNCxHWxThreadblockSwizzle", | |||
| ThreadblockSwizzleID::kConvolutionDgradNCxHWx}, | |||
| {"convolution_dgrad_ncxhwx", "ConvolutionDgradTransThreadblockSwizzle", | |||
| ThreadblockSwizzleID::kConvolutionDgradTrans}, | |||
| }; | |||
| /// Converts a ThreadblockSwizzleID enumerant to a string | |||
| @@ -0,0 +1,69 @@ | |||
| /** | |||
| * \file dnn/src/cuda/memory_utils.cuh | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #if MEGDNN_CC_CUDA | |||
| #pragma once | |||
| #include "src/cuda/utils.cuh" | |||
| namespace megdnn { | |||
| namespace cuda { | |||
| MEGDNN_DEVICE __forceinline__ void transpose_int8_4x4_impl( | |||
| const int src0, const int src1, const int src2, const int src3, | |||
| int& dst0, int& dst1, int& dst2, int& dst3) { | |||
| int dst01_lo = __byte_perm(src0, src1, 0x5140); | |||
| int dst01_hi = __byte_perm(src0, src1, 0x7362); | |||
| int dst23_lo = __byte_perm(src2, src3, 0x5140); | |||
| int dst23_hi = __byte_perm(src2, src3, 0x7362); | |||
| dst0 = __byte_perm(dst01_lo, dst23_lo, 0x5410); | |||
| dst1 = __byte_perm(dst01_lo, dst23_lo, 0x7632); | |||
| dst2 = __byte_perm(dst01_hi, dst23_hi, 0x5410); | |||
| dst3 = __byte_perm(dst01_hi, dst23_hi, 0x7632); | |||
| } | |||
| template <uint32_t interleaved, typename vec_type> | |||
| MEGDNN_DEVICE __forceinline__ void transpose_int8_interleavedx4( | |||
| const int src[interleaved], vec_type (&dst)[4]); | |||
| template <> | |||
| MEGDNN_DEVICE __forceinline__ void transpose_int8_interleavedx4<4, int>( | |||
| const int src[4], int (&dst)[4]) { | |||
| transpose_int8_4x4_impl(src[0], src[1], src[2], src[3], dst[0], dst[1], | |||
| dst[2], dst[3]); | |||
| } | |||
| template <> | |||
| MEGDNN_DEVICE __forceinline__ void transpose_int8_interleavedx4<8, int2>( | |||
| const int src[8], int2 (&dst)[4]) { | |||
| transpose_int8_4x4_impl(src[0], src[1], src[2], src[3], dst[0].x, dst[1].x, | |||
| dst[2].x, dst[3].x); | |||
| transpose_int8_4x4_impl(src[4], src[5], src[6], src[7], dst[0].y, dst[1].y, | |||
| dst[2].y, dst[3].y); | |||
| } | |||
| template <> | |||
| MEGDNN_DEVICE __forceinline__ void transpose_int8_interleavedx4<16, int4>( | |||
| const int src[16], int4 (&dst)[4]) { | |||
| transpose_int8_4x4_impl(src[0], src[1], src[2], src[3], dst[0].x, dst[1].x, | |||
| dst[2].x, dst[3].x); | |||
| transpose_int8_4x4_impl(src[4], src[5], src[6], src[7], dst[0].y, dst[1].y, | |||
| dst[2].y, dst[3].y); | |||
| transpose_int8_4x4_impl(src[8], src[9], src[10], src[11], dst[0].z, | |||
| dst[1].z, dst[2].z, dst[3].z); | |||
| transpose_int8_4x4_impl(src[12], src[13], src[14], src[15], dst[0].w, | |||
| dst[1].w, dst[2].w, dst[3].w); | |||
| } | |||
| } // namespace cuda | |||
| } // namespace megdnn | |||
| #endif | |||
| // vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -469,7 +469,6 @@ std::vector<TestArg> convolution::get_args_int8_nchw4_conv_bwd_data() { | |||
| return args; | |||
| } | |||
| std::vector<TestArg> convolution::get_args_int8_nchw_conv_bwd_data() { | |||
| std::vector<TestArg> args; | |||
| param::Convolution cur_param; | |||
| @@ -511,6 +510,46 @@ std::vector<TestArg> convolution::get_args_int8_nchw_conv_bwd_data() { | |||
| return args; | |||
| } | |||
| std::vector<TestArg> convolution::get_args_int8_nhwc_conv_bwd_data() { | |||
| std::vector<TestArg> args; | |||
| param::Convolution cur_param; | |||
| // clang-format off | |||
| for (auto mode : {param::Convolution::Mode::CROSS_CORRELATION}) { | |||
| for (size_t b : {64, 16}) { | |||
| for (size_t ic : {16, 32}) { | |||
| for (size_t oc : {16, 32}) { | |||
| for (size_t h : {8}) { | |||
| for (size_t w : {8, 11}) { | |||
| for (size_t kernel_size : {3, 4, 5, 7}) { | |||
| for (int p : {0, static_cast<int>(kernel_size / 2)}) { | |||
| for (size_t s : {2}) { | |||
| if (kernel_size >= 7) { | |||
| b = std::min(b, 32_z); | |||
| } | |||
| size_t f = kernel_size; | |||
| cur_param.mode = mode; | |||
| cur_param.format = param::Convolution::Format::NHWC; | |||
| cur_param.sparse = param::Convolution::Sparse::DENSE; | |||
| cur_param.pad_h = cur_param.pad_w = p; | |||
| cur_param.stride_h = cur_param.stride_w = s; | |||
| //! bias channel | |||
| args.emplace_back(cur_param, TensorShape{b, h, w, ic}, | |||
| TensorShape{oc, f, f, ic}); | |||
| } } } } } } } } } | |||
| // clang-format on | |||
| cur_param.pad_h = cur_param.pad_w = 1; | |||
| cur_param.stride_h = cur_param.stride_w = 1; | |||
| args.emplace_back(cur_param, TensorShape{16, 8, 11, 16}, | |||
| TensorShape{16, 3, 3, 16}); | |||
| return args; | |||
| } | |||
| void convolution::test_conv_config_combinations( | |||
| int k_size, Handle* handle, bool test_int8, bool test_backward, | |||
| bool is_cuda, ConvEPSGetter eps_getter, bool use_io16xc32) { | |||
| @@ -50,6 +50,7 @@ std::vector<TestArg> get_dilated_args(); | |||
| std::vector<TestArg> get_chanwise_args(); | |||
| std::vector<TestArg> get_args_int8_nchw4_conv_bwd_data(); | |||
| std::vector<TestArg> get_args_int8_nchw_conv_bwd_data(); | |||
| std::vector<TestArg> get_args_int8_nhwc_conv_bwd_data(); | |||
| //! \param stage 0 for fwd, 1 for bwd data, 2 for bwd filter | |||
| using ConvEPSGetter = | |||
| @@ -386,6 +386,69 @@ TEST_F(CUDA, CONVOLUTION_BACKWARD_DATA_INT8_NCHW_DP4A) { | |||
| } | |||
| } | |||
| #if CUDA_VERSION >= 10020 | |||
| TEST_F(CUDA, CONVOLUTION_BACKWARD_DATA_INT8_NHWC_IMMA) { | |||
| if (!cuda::is_compute_capability_required(7, 5)) { | |||
| printf("Skip CUDA.CONVOLUTION_BACKWARD_DATA_INT8_NHWC_IMMA test as " | |||
| "current device doesn't support\n"); | |||
| return; | |||
| } | |||
| using namespace convolution; | |||
| std::vector<TestArg> args = get_args_int8_nhwc_conv_bwd_data(); | |||
| struct AlgoParam { | |||
| int threadblock_m; | |||
| int threadblock_n; | |||
| int threadblock_k; | |||
| int warp_m; | |||
| int warp_n; | |||
| int warp_k; | |||
| int stage; | |||
| int access_size; | |||
| std::string to_string() { | |||
| return ssprintf("_%dX%dX%d_%dX%dX%d_%dstage_%d", threadblock_m, | |||
| threadblock_n, threadblock_k, warp_m, warp_n, | |||
| warp_k, stage, access_size); | |||
| } | |||
| }; | |||
| std::vector<AlgoParam> all_params; | |||
| all_params.emplace_back(AlgoParam{64, 16, 32, 64, 16, 32, 2, 4}); | |||
| all_params.emplace_back(AlgoParam{64, 16, 32, 64, 16, 32, 2, 8}); | |||
| all_params.emplace_back(AlgoParam{64, 16, 32, 64, 16, 32, 2, 16}); | |||
| all_params.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 1, 4}); | |||
| all_params.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 1, 8}); | |||
| all_params.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 1, 16}); | |||
| for (auto algo_param : all_params) { | |||
| Checker<ConvolutionBackwardData> checker(handle_cuda()); | |||
| std::string algo_name(ssprintf("INT8_NHWC_IMMA_IMPLICIT_GEMM%s", | |||
| algo_param.to_string().c_str())); | |||
| checker.set_before_exec_callback( | |||
| AlgoChecker<ConvolutionBackwardData>(algo_name.c_str())); | |||
| checker.set_epsilon(1 + 1e-3).set_max_avg_error(1e-1); | |||
| for (auto&& arg : args) { | |||
| UniformIntRNG rng(-3, 3); | |||
| auto src = TensorLayout(arg.src, dtype::QuantizedS8{1.2f}); | |||
| auto filter = TensorLayout(arg.filter, dtype::QuantizedS8{1.3f}); | |||
| TensorLayout dst; | |||
| dst.dtype = dtype::QuantizedS8{1.2f}; | |||
| { | |||
| auto opr = handle_cuda()->create_operator<Convolution>(); | |||
| opr->param() = arg.param; | |||
| opr->deduce_layout(src, filter, dst); | |||
| } | |||
| checker.set_rng(0, &rng).set_rng(1, &rng).set_param(arg.param).exec( | |||
| TensorLayoutArray{filter, dst, src}); | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| TEST_F(CUDA, CONVOLUTION_BACKWARD_DATA_FAILED_CUDNN7_5) { | |||
| // BRAIN-481 failed on architectures 7.0, remove the following if statement, | |||
| // when cudnn fixed the problem. | |||