GitOrigin-RevId: 2a70335441
tags/v1.6.0-rc1
| @@ -163,7 +163,7 @@ using Convolution = | |||
| ${element_bias}, | |||
| ${layout_bias}, | |||
| ${element_accumulator}, | |||
| ${conv_type}, | |||
| ${conv_type}, | |||
| ${opcode_class}, | |||
| ${arch}, | |||
| cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, | |||
| @@ -246,6 +246,7 @@ using Deconvolution = | |||
| ${element_bias}, | |||
| ${layout_bias}, | |||
| ${element_accumulator}, | |||
| ${conv_type}, | |||
| ${opcode_class}, | |||
| ${arch}, | |||
| cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, | |||
| @@ -276,6 +277,7 @@ using Deconvolution = | |||
| values = { | |||
| 'operation_name': operation.procedural_name(), | |||
| 'conv_type': ConvTypeTag[operation.conv_type], | |||
| 'element_src': DataTypeTag[operation.src.element], | |||
| 'layout_src': LayoutTag[operation.src.layout], | |||
| 'element_flt': DataTypeTag[operation.flt.element], | |||
| @@ -530,44 +532,17 @@ void initialize_${configuration_name}(Manifest &manifest) { | |||
| ################################################################################################### | |||
| class EmitConvSingleKernelWrapper(): | |||
| def __init__(self, kernel_path, operation, wrapper_path): | |||
| def __init__(self, kernel_path, operation): | |||
| self.kernel_path = kernel_path | |||
| self.wrapper_path = wrapper_path | |||
| self.operation = operation | |||
| self.conv_wrappers = { \ | |||
| ConvKind.Fprop: """ | |||
| template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>( | |||
| const typename Convolution::ElementSrc* d_src, | |||
| const typename Convolution::ElementFilter* d_filter, | |||
| const typename Convolution::ElementBias* d_bias, | |||
| const typename Convolution::ElementDst* d_z, | |||
| typename Convolution::ElementDst* d_dst, | |||
| int* workspace, | |||
| typename Convolution::ConvolutionParameter const& conv_param, | |||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||
| cudaStream_t stream, | |||
| typename Convolution::ExtraParam extra_param); | |||
| """, \ | |||
| ConvKind.Dgrad: """ | |||
| template void megdnn::cuda::cutlass_wrapper::cutlass_deconvolution_wrapper<Deconvolution>( | |||
| const typename Deconvolution::ElementSrc* d_src, | |||
| const typename Deconvolution::ElementFilter* d_filter, | |||
| const typename Deconvolution::ElementBias* d_bias, | |||
| const typename Deconvolution::ElementDst* d_z, | |||
| typename Deconvolution::ElementDst* d_dst, | |||
| int* workspace, | |||
| typename Deconvolution::ConvolutionParameter const& conv_param, | |||
| typename Deconvolution::EpilogueOutputOp::Params const& epilogue, | |||
| cudaStream_t stream); | |||
| """, \ | |||
| } | |||
| if self.operation.conv_kind == ConvKind.Fprop: | |||
| self.instance_emitter = EmitConv2dInstance() | |||
| self.convolution_name = "Convolution" | |||
| else: | |||
| assert self.operation.conv_kind == ConvKind.Dgrad | |||
| self.instance_emitter = EmitDeconvInstance() | |||
| self.convolution_name = "Deconvolution" | |||
| self.header_template = """ | |||
| #if !MEGDNN_TEGRA_X1 | |||
| @@ -575,13 +550,30 @@ template void megdnn::cuda::cutlass_wrapper::cutlass_deconvolution_wrapper<Decon | |||
| #pragma GCC diagnostic push | |||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||
| #include "${wrapper_path}" | |||
| #pragma GCC diagnostic ignored "-Wuninitialized" | |||
| #pragma GCC diagnostic ignored "-Wmaybe-uninitialized" | |||
| #include "cutlass/convolution/device/convolution.h" | |||
| #include "src/cuda/cutlass/manifest.h" | |||
| #include "src/cuda/cutlass/convolution_operation.h" | |||
| """ | |||
| self.instance_template = """ | |||
| ${operation_instance} | |||
| """ | |||
| self.wrapper_template = """ | |||
| ${wrapper_instance} | |||
| self.manifest_template = """ | |||
| namespace cutlass { | |||
| namespace library { | |||
| void initialize_${operation_name}(Manifest &manifest) { | |||
| manifest.append(new ConvolutionOperation<${convolution_name}>( | |||
| "${operation_name}" | |||
| )); | |||
| } | |||
| } // namespace library | |||
| } // namespace cutlass | |||
| """ | |||
| self.epilogue_template = """ | |||
| @@ -593,9 +585,7 @@ ${wrapper_instance} | |||
| def __enter__(self): | |||
| self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % self.operation.procedural_name()) | |||
| self.kernel_file = LazyFile(self.kernel_path) | |||
| self.kernel_file.write(SubstituteTemplate(self.header_template, { | |||
| 'wrapper_path': self.wrapper_path, | |||
| })) | |||
| self.kernel_file.write(self.header_template) | |||
| return self | |||
| # | |||
| @@ -604,11 +594,12 @@ ${wrapper_instance} | |||
| 'operation_instance': self.instance_emitter.emit(self.operation), | |||
| })) | |||
| # emit wrapper | |||
| wrapper = SubstituteTemplate(self.wrapper_template, { | |||
| 'wrapper_instance': self.conv_wrappers[self.operation.conv_kind], | |||
| # emit manifest helper | |||
| manifest = SubstituteTemplate(self.manifest_template, { | |||
| 'operation_name': self.operation.procedural_name(), | |||
| 'convolution_name': self.convolution_name | |||
| }) | |||
| self.kernel_file.write(wrapper) | |||
| self.kernel_file.write(manifest) | |||
| # | |||
| def __exit__(self, exception_type, exception_value, traceback): | |||
| @@ -940,8 +940,8 @@ void initialize_${configuration_name}(Manifest &manifest) { | |||
| /////////////////////////////////////////////////////////////////////////////////////////////////// | |||
| } // namespace library | |||
| } // namespace cutlass | |||
| } // namespace library | |||
| } // namespace cutlass | |||
| /////////////////////////////////////////////////////////////////////////////////////////////////// | |||
| @@ -995,48 +995,101 @@ void initialize_${configuration_name}(Manifest &manifest) { | |||
| ################################################################################################### | |||
| class EmitGemmSingleKernelWrapper: | |||
| def __init__(self, kernel_path, gemm_operation, wrapper_path): | |||
| def __init__(self, kernel_path, gemm_operation): | |||
| self.kernel_path = kernel_path | |||
| self.wrapper_path = wrapper_path | |||
| self.operation = gemm_operation | |||
| gemm_wrapper = """ | |||
| template void megdnn::cuda::cutlass_wrapper::cutlass_matrix_mul_wrapper<Operation_${operation_name}>( | |||
| const typename Operation_${operation_name}::ElementA* d_A, size_t lda, | |||
| const typename Operation_${operation_name}::ElementB* d_B, size_t ldb, | |||
| typename Operation_${operation_name}::ElementC* d_C, size_t ldc, | |||
| int* workspace, | |||
| cutlass::gemm::GemmCoord const& problem_size, | |||
| typename Operation_${operation_name}::EpilogueOutputOp::Params const& epilogue, | |||
| cudaStream_t stream, int split_k_slices); | |||
| instance_emitters = { | |||
| GemmKind.Gemm: EmitGemmInstance(), | |||
| GemmKind.SplitKParallel: EmitGemmSplitKParallelInstance(), | |||
| } | |||
| self.instance_emitter = instance_emitters[self.operation.gemm_kind] | |||
| self.header_template = """ | |||
| #if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) | |||
| // ignore warning of cutlass | |||
| #pragma GCC diagnostic push | |||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||
| #pragma GCC diagnostic ignored "-Wuninitialized" | |||
| #pragma GCC diagnostic ignored "-Wmaybe-uninitialized" | |||
| #include "cutlass/gemm/device/gemm.h" | |||
| #include "cutlass/gemm/device/gemm_splitk_parallel.h" | |||
| #include "src/cuda/cutlass/manifest.h" | |||
| #include "src/cuda/cutlass/gemm_operation.h" | |||
| """ | |||
| self.instance_template = """ | |||
| ${operation_instance} | |||
| """ | |||
| self.manifest_template = """ | |||
| namespace cutlass { | |||
| namespace library { | |||
| void initialize_${operation_name}(Manifest &manifest) { | |||
| manifest.append(new GemmOperation< | |||
| Operation_${operation_name} | |||
| >("${operation_name}")); | |||
| } | |||
| } // namespace library | |||
| } // namespace cutlass | |||
| """ | |||
| self.epilogue_template = """ | |||
| #pragma GCC diagnostic pop | |||
| #endif | |||
| """ | |||
| # | |||
| def __enter__(self): | |||
| self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % self.operation.procedural_name()) | |||
| self.kernel_file = LazyFile(self.kernel_path) | |||
| self.kernel_file.write(self.header_template) | |||
| return self | |||
| # | |||
| def emit(self): | |||
| self.kernel_file.write(SubstituteTemplate(self.instance_template, { | |||
| 'operation_instance': self.instance_emitter.emit(self.operation), | |||
| })) | |||
| gemv_wrapper = """ | |||
| # emit manifest helper | |||
| manifest = SubstituteTemplate(self.manifest_template, { | |||
| 'operation_name': self.operation.procedural_name(), | |||
| }) | |||
| self.kernel_file.write(manifest) | |||
| # | |||
| def __exit__(self, exception_type, exception_value, traceback): | |||
| self.kernel_file.write(self.epilogue_template) | |||
| self.kernel_file.close() | |||
| ################################################################################################### | |||
| ################################################################################################### | |||
| class EmitGemvSingleKernelWrapper: | |||
| def __init__(self, kernel_path, gemm_operation, wrapper_path): | |||
| self.kernel_path = kernel_path | |||
| self.wrapper_path = wrapper_path | |||
| self.operation = gemm_operation | |||
| self.wrapper_template = """ | |||
| template void megdnn::cuda::cutlass_wrapper:: | |||
| cutlass_vector_matrix_mul_batched_strided_wrapper<Operation_${operation_name}>( | |||
| BatchedGemmCoord const& problem_size, | |||
| const typename Operation_${operation_name}::ElementA* d_A, size_t lda, size_t batch_stride_a, | |||
| const typename Operation_${operation_name}::ElementB* d_B, size_t ldb, size_t batch_stride_b, | |||
| const typename Operation_${operation_name}::ElementA* d_A, size_t lda, size_t batch_stride_a, | |||
| const typename Operation_${operation_name}::ElementB* d_B, size_t ldb, size_t batch_stride_b, | |||
| typename Operation_${operation_name}::ElementCD* d_C, size_t ldc, size_t batch_stride_c, | |||
| cudaStream_t stream); | |||
| """ | |||
| if self.operation.gemm_kind == GemmKind.SplitKParallel or \ | |||
| self.operation.gemm_kind == GemmKind.Gemm: | |||
| self.wrapper_template = gemm_wrapper | |||
| else: | |||
| assert self.operation.gemm_kind == GemmKind.GemvBatchedStrided | |||
| self.wrapper_template = gemv_wrapper | |||
| instance_emitters = { | |||
| GemmKind.Gemm: EmitGemmInstance(), | |||
| GemmKind.SplitKParallel: EmitGemmSplitKParallelInstance(), | |||
| GemmKind.GemvBatchedStrided: EmitGemvBatchedStridedInstance(), | |||
| } | |||
| self.instance_emitter = instance_emitters[self.operation.gemm_kind] | |||
| self.instance_emitter = EmitGemvBatchedStridedInstance() | |||
| self.header_template = """ | |||
| #if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) | |||
| #if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) | |||
| // ignore warning of cutlass | |||
| #pragma GCC diagnostic push | |||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||
| @@ -1055,10 +1108,10 @@ ${operation_instance} | |||
| """ | |||
| # | |||
| def __enter__(self): | |||
| self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % self.operation.procedural_name()) | |||
| self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % self.operation.procedural_name()) | |||
| self.kernel_file = LazyFile(self.kernel_path) | |||
| self.kernel_file.write(SubstituteTemplate(self.header_template, { | |||
| 'wrapper_path': self.wrapper_path, | |||
| 'wrapper_path': self.wrapper_path, | |||
| })) | |||
| return self | |||
| @@ -1070,7 +1123,7 @@ ${operation_instance} | |||
| # emit wrapper | |||
| wrapper = SubstituteTemplate(self.wrapper_template, { | |||
| 'operation_name': self.operation.procedural_name(), | |||
| 'operation_name': self.operation.procedural_name(), | |||
| }) | |||
| self.kernel_file.write(wrapper) | |||
| @@ -1079,7 +1132,5 @@ ${operation_instance} | |||
| self.kernel_file.write(self.epilogue_template) | |||
| self.kernel_file.close() | |||
| ################################################################################################### | |||
| ################################################################################################### | |||
| @@ -23,6 +23,8 @@ def write_op_list(f, gen_op, gen_type): | |||
| operations = GenerateDeconvOperations(GenArg(gen_op, gen_type)) | |||
| for op in operations: | |||
| f.write(' "%s.cu",\n' % op.procedural_name()) | |||
| if gen_op != "gemv": | |||
| f.write(' "all_%s_%s_operations.cu",\n' % (gen_op, gen_type)) | |||
| if __name__ == "__main__": | |||
| @@ -292,7 +292,7 @@ def GenerateConv2d_TensorOp_8832(args): | |||
| ] | |||
| operations += GenerateConv2d(ConvKind.Fprop, tile_descriptions, layout[0], layout[1], | |||
| dst_layout, dst_type, min_cc, 128, 128, 64, | |||
| True, ImplicitGemmMode.GemmTN, True) | |||
| False, ImplicitGemmMode.GemmTN, True) | |||
| layouts_nhwc = [ | |||
| (LayoutType.TensorNHWC, LayoutType.TensorNC8HW8, 32), | |||
| @@ -633,16 +633,10 @@ if __name__ == "__main__": | |||
| parser.add_argument("--type", type=str, choices=['simt', 'tensorop8816', 'tensorop8832'], | |||
| default='simt', help="kernel type of CUTLASS kernel generator") | |||
| operation2wrapper_path = { | |||
| "gemm": "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuinl", \ | |||
| "gemv": "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl", \ | |||
| "conv2d": "src/cuda/conv_bias/implicit_gemm_conv_bias_cutlass_wrapper.cuinl", \ | |||
| "deconv": "src/cuda/convolution/backward_data/implicit_gemm_deconv_cutlass_wrapper.cuinl", \ | |||
| } | |||
| gemv_wrapper_path = "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl" | |||
| args = parser.parse_args() | |||
| wrapper_path = operation2wrapper_path[args.operations] | |||
| if args.operations == "gemm": | |||
| operations = GenerateGemmOperations(args) | |||
| elif args.operations == "gemv": | |||
| @@ -652,16 +646,22 @@ if __name__ == "__main__": | |||
| elif args.operations == "deconv": | |||
| operations = GenerateDeconvOperations(args) | |||
| if args.operations == "conv2d" or args.operations == "deconv": | |||
| for operation in operations: | |||
| with EmitConvSingleKernelWrapper(args.output, operation, wrapper_path) as emitter: | |||
| with EmitConvSingleKernelWrapper(args.output, operation) as emitter: | |||
| emitter.emit() | |||
| elif args.operations == "gemm" or args.operations == "gemv": | |||
| elif args.operations == "gemm": | |||
| for operation in operations: | |||
| with EmitGemmSingleKernelWrapper(args.output, operation, wrapper_path) as emitter: | |||
| with EmitGemmSingleKernelWrapper(args.output, operation) as emitter: | |||
| emitter.emit() | |||
| elif args.operations == "gemv": | |||
| for operation in operations: | |||
| with EmitGemvSingleKernelWrapper(args.output, operation, gemv_wrapper_path) as emitter: | |||
| emitter.emit() | |||
| if args.operations != "gemv": | |||
| GenerateManifest(args, operations, args.output) | |||
| # | |||
| ################################################################################################### | |||
| @@ -137,6 +137,7 @@ cutlass_gen_list = [ | |||
| "cutlass_simt_sgemm_split_k_parallel_256x32_8x2_tt_align1.cu", | |||
| "cutlass_simt_sgemm_256x64_8x2_tt_align1.cu", | |||
| "cutlass_simt_sgemm_split_k_parallel_256x64_8x2_tt_align1.cu", | |||
| "all_gemm_simt_operations.cu", | |||
| "cutlass_simt_sgemv_batched_strided_1x128_32_tt_align4x4.cu", | |||
| "cutlass_simt_sgemv_batched_strided_1x128_16_tt_align4x2.cu", | |||
| "cutlass_simt_sgemv_batched_strided_1x128_8_tt_align4x1.cu", | |||
| @@ -169,6 +170,7 @@ cutlass_gen_list = [ | |||
| "cutlass_simt_s8_idgrad_id_s8_16x128x16_16x64x16_2_nc4hw4_k4rsc4.cu", | |||
| "cutlass_simt_s8_idgrad_id_s8_16x128x16_16x128x16_1_nc4hw4_k4rsc4.cu", | |||
| "cutlass_simt_s8_idgrad_id_s8_16x64x8_16x64x8_2_nc4hw4_k4rsc4.cu", | |||
| "all_deconv_simt_operations.cu", | |||
| "cutlass_simt_s8_ifprop_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_relu_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
| @@ -373,6 +375,7 @@ cutlass_gen_list = [ | |||
| "cutlass_simt_f32_ifprop_1x1_relu_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_hswish_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_1x1_hswish_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "all_conv2d_simt_operations.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_roc_id_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_roc_id_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_roc_relu_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
| @@ -481,26 +484,47 @@ cutlass_gen_list = [ | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_hswish_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "all_conv2d_tensorop8816_operations.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_roc_id_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_roc_relu_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_roc_hswish_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_roc_id_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_roc_relu_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_roc_hswish_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_roc_id_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_roc_relu_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_roc_hswish_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_roc_id_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_roc_relu_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_roc_hswish_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_1x1_roc_id_u4_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_1x1_roc_id_u4_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_1x1_roc_id_u4_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_1x1_roc_id_u4_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_id_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_id_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_relu_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||
| @@ -621,4 +645,5 @@ cutlass_gen_list = [ | |||
| "cutlass_tensorop_u4_i8832fprop_1x1_roc_id_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||
| "all_conv2d_tensorop8832_operations.cu", | |||
| ] | |||
| @@ -8,6 +8,7 @@ import enum | |||
| import os.path | |||
| import shutil | |||
| from lazy_file import LazyFile | |||
| from library import * | |||
| from gemm_operation import * | |||
| from conv2d_operation import * | |||
| @@ -349,3 +350,41 @@ void initialize_all(Manifest &manifest) { | |||
| # | |||
| ################################################################################################### | |||
| def GenerateManifest(args, operations, output_dir): | |||
| manifest_path = os.path.join(output_dir, "all_%s_%s_operations.cu" % (args.operations, args.type)) | |||
| f = LazyFile(manifest_path) | |||
| f.write(""" | |||
| /* | |||
| Generated by generator.py - Do not edit. | |||
| */ | |||
| #if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) | |||
| #include "cutlass/cutlass.h" | |||
| #include "src/cuda/cutlass/library.h" | |||
| #include "src/cuda/cutlass/manifest.h" | |||
| namespace cutlass { | |||
| namespace library { | |||
| """) | |||
| for op in operations: | |||
| f.write("void initialize_%s(Manifest &manifest);\n" % op.procedural_name()) | |||
| f.write(""" | |||
| void initialize_all_%s_%s_operations(Manifest &manifest) { | |||
| """ % (args.operations, args.type)) | |||
| for op in operations: | |||
| f.write(" initialize_%s(manifest);\n" % op.procedural_name()) | |||
| f.write(""" | |||
| } | |||
| } // namespace library | |||
| } // namespace cutlass | |||
| #endif | |||
| """) | |||
| f.close() | |||
| @@ -217,68 +217,77 @@ void ConvBiasForwardImpl::AlgoPack::fill_imma_algos() { | |||
| #if CUDA_VERSION >= 10020 | |||
| { | |||
| using AlgoParam = AlgoInt8NCHW32IMMAImplicitGemm::AlgoParam; | |||
| int8_nchw32_imma.emplace_back(AlgoParam{128, 256, 64, 64, 64, 64, 2}); | |||
| int8_nchw32_imma.emplace_back(AlgoParam{256, 128, 64, 64, 64, 64, 2}); | |||
| int8_nchw32_imma.emplace_back(AlgoParam{128, 128, 64, 64, 64, 64, 2}); | |||
| int8_nchw32_imma.emplace_back(AlgoParam{128, 64, 64, 64, 32, 64, 2}); | |||
| int8_nchw32_imma.emplace_back(AlgoParam{64, 128, 64, 32, 64, 64, 2}); | |||
| int8_nchw32_imma.emplace_back(AlgoParam{128, 64, 32, 64, 32, 32, 1}); | |||
| int8_nchw32_imma.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 1}); | |||
| int8_nchw32_imma.emplace_back(AlgoParam{64, 128, 32, 32, 64, 32, 1}); | |||
| int8_nchw32_imma.emplace_back(AlgoParam{32, 128, 32, 32, 64, 32, 1}); | |||
| int8_nchw32_imma.emplace_back( | |||
| AlgoParam{128, 256, 64, 64, 64, 64, 8, 8, 16, 2}); | |||
| int8_nchw32_imma.emplace_back( | |||
| AlgoParam{256, 128, 64, 64, 64, 64, 8, 8, 16, 2}); | |||
| int8_nchw32_imma.emplace_back( | |||
| AlgoParam{128, 128, 64, 64, 64, 64, 8, 8, 16, 2}); | |||
| int8_nchw32_imma.emplace_back( | |||
| AlgoParam{128, 64, 64, 64, 32, 64, 8, 8, 16, 2}); | |||
| int8_nchw32_imma.emplace_back( | |||
| AlgoParam{64, 128, 64, 32, 64, 64, 8, 8, 16, 2}); | |||
| int8_nchw32_imma.emplace_back( | |||
| AlgoParam{128, 64, 32, 64, 32, 32, 8, 8, 16, 1}); | |||
| int8_nchw32_imma.emplace_back( | |||
| AlgoParam{128, 32, 32, 64, 32, 32, 8, 8, 16, 1}); | |||
| int8_nchw32_imma.emplace_back( | |||
| AlgoParam{64, 128, 32, 32, 64, 32, 8, 8, 16, 1}); | |||
| int8_nchw32_imma.emplace_back( | |||
| AlgoParam{32, 128, 32, 32, 64, 32, 8, 8, 16, 1}); | |||
| } | |||
| { | |||
| using AlgoParam = AlgoInt4Int4NCHW64IMMAImplicitGemm::AlgoParam; | |||
| int4_int4_nchw64_imma.emplace_back( | |||
| AlgoParam{128, 128, 128, 64, 64, 128, 2}); | |||
| AlgoParam{128, 128, 128, 64, 64, 128, 8, 8, 32, 2}); | |||
| int4_int4_nchw64_imma.emplace_back( | |||
| AlgoParam{128, 256, 128, 64, 64, 128, 2}); | |||
| AlgoParam{128, 256, 128, 64, 64, 128, 8, 8, 32, 2}); | |||
| int4_int4_nchw64_imma.emplace_back( | |||
| AlgoParam{128, 64, 128, 64, 64, 128, 2}); | |||
| AlgoParam{128, 64, 128, 64, 64, 128, 8, 8, 32, 2}); | |||
| int4_int4_nchw64_imma.emplace_back( | |||
| AlgoParam{128, 64, 64, 64, 64, 64, 1}); | |||
| AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1}); | |||
| } | |||
| { | |||
| using AlgoParam = AlgoUInt4Int4NCHW64IMMAImplicitGemm::AlgoParam; | |||
| uint4_int4_nchw64_imma.emplace_back( | |||
| AlgoParam{128, 128, 128, 64, 64, 128, 2}); | |||
| AlgoParam{128, 128, 128, 64, 64, 128, 8, 8, 32, 2}); | |||
| uint4_int4_nchw64_imma.emplace_back( | |||
| AlgoParam{128, 256, 128, 64, 64, 128, 2}); | |||
| AlgoParam{128, 256, 128, 64, 64, 128, 8, 8, 32, 2}); | |||
| uint4_int4_nchw64_imma.emplace_back( | |||
| AlgoParam{128, 64, 128, 64, 64, 128, 2}); | |||
| AlgoParam{128, 64, 128, 64, 64, 128, 8, 8, 32, 2}); | |||
| uint4_int4_nchw64_imma.emplace_back( | |||
| AlgoParam{128, 64, 64, 64, 64, 64, 1}); | |||
| AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1}); | |||
| } | |||
| { | |||
| using AlgoParam = AlgoInt4Int4NHWCIMMAImplicitGemm::AlgoParam; | |||
| int4_int4_nhwc_imma.emplace_back( | |||
| AlgoParam{128, 32, 64, 64, 32, 64, 1, 32}); | |||
| AlgoParam{128, 32, 64, 64, 32, 64, 8, 8, 32, 1, 32}); | |||
| int4_int4_nhwc_imma.emplace_back( | |||
| AlgoParam{128, 32, 64, 64, 32, 64, 1, 16}); | |||
| AlgoParam{128, 32, 64, 64, 32, 64, 8, 8, 32, 1, 16}); | |||
| int4_int4_nhwc_imma.emplace_back( | |||
| AlgoParam{128, 32, 64, 64, 32, 64, 1, 8}); | |||
| AlgoParam{128, 32, 64, 64, 32, 64, 8, 8, 32, 1, 8}); | |||
| int4_int4_nhwc_imma.emplace_back( | |||
| AlgoParam{128, 64, 64, 64, 64, 64, 1, 32}); | |||
| AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1, 32}); | |||
| int4_int4_nhwc_imma.emplace_back( | |||
| AlgoParam{128, 64, 64, 64, 64, 64, 1, 16}); | |||
| AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1, 16}); | |||
| int4_int4_nhwc_imma.emplace_back( | |||
| AlgoParam{128, 64, 64, 64, 64, 64, 1, 8}); | |||
| AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1, 8}); | |||
| } | |||
| { | |||
| using AlgoParam = AlgoUInt4Int4NHWCIMMAImplicitGemm::AlgoParam; | |||
| uint4_int4_nhwc_imma.emplace_back( | |||
| AlgoParam{128, 32, 64, 64, 32, 64, 1, 32}); | |||
| AlgoParam{128, 32, 64, 64, 32, 64, 8, 8, 32, 1, 32}); | |||
| uint4_int4_nhwc_imma.emplace_back( | |||
| AlgoParam{128, 32, 64, 64, 32, 64, 1, 16}); | |||
| AlgoParam{128, 32, 64, 64, 32, 64, 8, 8, 32, 1, 16}); | |||
| uint4_int4_nhwc_imma.emplace_back( | |||
| AlgoParam{128, 32, 64, 64, 32, 64, 1, 8}); | |||
| AlgoParam{128, 32, 64, 64, 32, 64, 8, 8, 32, 1, 8}); | |||
| uint4_int4_nhwc_imma.emplace_back( | |||
| AlgoParam{128, 64, 64, 64, 64, 64, 1, 32}); | |||
| AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1, 32}); | |||
| uint4_int4_nhwc_imma.emplace_back( | |||
| AlgoParam{128, 64, 64, 64, 64, 64, 1, 16}); | |||
| AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1, 16}); | |||
| uint4_int4_nhwc_imma.emplace_back( | |||
| AlgoParam{128, 64, 64, 64, 64, 64, 1, 8}); | |||
| AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1, 8}); | |||
| } | |||
| #endif | |||
| } | |||
| @@ -286,15 +295,24 @@ void ConvBiasForwardImpl::AlgoPack::fill_imma_algos() { | |||
| void ConvBiasForwardImpl::AlgoPack::fill_dp4a_algos() { | |||
| using AlgoParam = AlgoInt8NCHW4DotProdImplicitGemm::AlgoParam; | |||
| int8_nchw4_dotprod.emplace_back(AlgoParam{128, 128, 32, 64, 32, 32, 2}); | |||
| int8_nchw4_dotprod.emplace_back(AlgoParam{128, 64, 32, 64, 32, 32, 2}); | |||
| int8_nchw4_dotprod.emplace_back(AlgoParam{64, 128, 32, 64, 32, 32, 2}); | |||
| int8_nchw4_dotprod.emplace_back(AlgoParam{32, 128, 32, 32, 64, 32, 2}); | |||
| int8_nchw4_dotprod.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 2}); | |||
| int8_nchw4_dotprod.emplace_back(AlgoParam{32, 64, 32, 32, 64, 32, 2}); | |||
| int8_nchw4_dotprod.emplace_back(AlgoParam{64, 32, 32, 64, 32, 32, 2}); | |||
| int8_nchw4_dotprod.emplace_back(AlgoParam{16, 128, 16, 16, 128, 16, 1}); | |||
| int8_nchw4_dotprod.emplace_back(AlgoParam{16, 64, 8, 16, 64, 8, 2}); | |||
| int8_nchw4_dotprod.emplace_back( | |||
| AlgoParam{128, 128, 32, 64, 32, 32, 1, 1, 4, 2}); | |||
| int8_nchw4_dotprod.emplace_back( | |||
| AlgoParam{128, 64, 32, 64, 32, 32, 1, 1, 4, 2}); | |||
| int8_nchw4_dotprod.emplace_back( | |||
| AlgoParam{64, 128, 32, 64, 32, 32, 1, 1, 4, 2}); | |||
| int8_nchw4_dotprod.emplace_back( | |||
| AlgoParam{32, 128, 32, 32, 64, 32, 1, 1, 4, 2}); | |||
| int8_nchw4_dotprod.emplace_back( | |||
| AlgoParam{128, 32, 32, 64, 32, 32, 1, 1, 4, 2}); | |||
| int8_nchw4_dotprod.emplace_back( | |||
| AlgoParam{32, 64, 32, 32, 64, 32, 1, 1, 4, 2}); | |||
| int8_nchw4_dotprod.emplace_back( | |||
| AlgoParam{64, 32, 32, 64, 32, 32, 1, 1, 4, 2}); | |||
| int8_nchw4_dotprod.emplace_back( | |||
| AlgoParam{16, 128, 16, 16, 128, 16, 1, 1, 4, 1}); | |||
| int8_nchw4_dotprod.emplace_back( | |||
| AlgoParam{16, 64, 8, 16, 64, 8, 1, 1, 4, 2}); | |||
| } | |||
| ConvBiasForwardImpl::AlgoBase* | |||
| @@ -28,6 +28,17 @@ | |||
| #include <memory> | |||
| #include <unordered_map> | |||
| namespace cutlass { | |||
| namespace library { | |||
| // forward declaration of cutlass library concepts, we hope that algo.h does | |||
| // not depend on cutlass headers | |||
| class Operation; | |||
| } // namespace library | |||
| } // namespace cutlass | |||
| namespace megdnn { | |||
| namespace cuda { | |||
| @@ -505,9 +516,44 @@ public: | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_CHWN4_DOTPROD_INT8) | |||
| }; | |||
| class ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm final | |||
| : public AlgoBase { | |||
| /*********************** Cutlass Algorithms ************************/ | |||
| /* The inheritance of cutlass algorithm classes: | |||
| * | |||
| * AlgoCutlassConvolutionBase | |||
| * + | |||
| * +--- AlgoInt8NCHW4DotProdImplicitGemm | |||
| * +--- AlgoInt8NCHW32IMMAImplicitGemm | |||
| * + | |||
| * +--- AlgoInt4NCHW64IMMAImplicitGemmBase | |||
| * +----+--- AlgoInt4Int4NCHW64IMMAImplicitGemm | |||
| * +----+--- AlgoUInt4Int4NCHW64IMMAImplicitGemm | |||
| * + | |||
| * +--- AlgoInt4NHWCIMMAImplicitGemmBase | |||
| * +----+--- AlgoInt4Int4NHWCIMMAImplicitGemm | |||
| * +----+--- AlgoUInt4Int4NHWCIMMAImplicitGemm | |||
| * + | |||
| */ | |||
| /* | |||
| * The base class for all cutlass algorithm classes | |||
| */ | |||
| class ConvBiasForwardImpl::AlgoCutlassConvolutionBase : public AlgoBase { | |||
| public: | |||
| // corresponds to cutlass::conv::Operator. we hope that algo.h does not | |||
| // depend on cutlass headers | |||
| enum class ConvOperator { kFprop, kDgrad, kWgrad }; | |||
| // corresponds to cutlass::conv::ConvType. we hope that algo.h does not | |||
| // depend on cutlass headers | |||
| enum class ConvType { | |||
| kConvolution, | |||
| kBatchConvolution, | |||
| kLocal, | |||
| kLocalShare | |||
| }; | |||
| // common parameters for operation selection | |||
| struct AlgoParam { | |||
| int threadblock_m; | |||
| int threadblock_n; | |||
| @@ -515,21 +561,54 @@ public: | |||
| int warp_m; | |||
| int warp_n; | |||
| int warp_k; | |||
| int instruction_m; | |||
| int instruction_n; | |||
| int instruction_k; | |||
| int stage; | |||
| std::string to_string() { | |||
| /// default algorithm | |||
| if (threadblock_m == 128 && threadblock_n == 128 && | |||
| threadblock_k == 32 && warp_m == 32 && warp_n == 64 && | |||
| warp_k == 32 && stage == 2) { | |||
| return ""; | |||
| } | |||
| return ssprintf("_%dX%dX%d_%dX%dX%d_%dstage", threadblock_m, | |||
| threadblock_n, threadblock_k, warp_m, warp_n, | |||
| warp_k, stage); | |||
| } | |||
| int access_size; | |||
| AlgoParam(int threadblock_m_, int threadblock_n_, int threadblock_k_, | |||
| int warp_m_, int warp_n_, int warp_k_, int instruction_m_, | |||
| int instruction_n_, int instruction_k_, int stage_, | |||
| int access_size_ = 0); | |||
| std::string to_string() const; | |||
| }; | |||
| AlgoCutlassConvolutionBase(AlgoParam algo_param) | |||
| : m_algo_param{algo_param} {} | |||
| // generate a cutlass::library::ConvolutionKey and find the corresponding | |||
| // operation (cutlass kernel) from the global OperationTable | |||
| const cutlass::library::Operation* get_cutlass_conv_op( | |||
| const SizeArgs& args, ConvOperator conv_op, ConvType conv_type, | |||
| bool load_from_const, bool without_shared_load) const; | |||
| // execute the cutlass kernel found by get_cutlass_conv_op. we give | |||
| // subclasses full freedom to decide where and how these arguments are | |||
| // extracted | |||
| void execute_cutlass_conv_op(const cutlass::library::Operation* op, | |||
| const void* src, const void* filter, | |||
| const void* bias, const void* z, void* dst, | |||
| void* workspace, size_t n, size_t hi, | |||
| size_t wi, size_t ci, size_t co, size_t fh, | |||
| size_t fw, size_t ho, size_t wo, size_t ph, | |||
| size_t pw, size_t sh, size_t sw, size_t dh, | |||
| size_t dw, const void* alpha, const void* beta, | |||
| const void* gamma, const void* delta, | |||
| const void* theta, const void* threshold, | |||
| const void* dst_scale, cudaStream_t stream, | |||
| const void* extra_param = nullptr) const; | |||
| protected: | |||
| AlgoParam m_algo_param; | |||
| }; | |||
| class ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm final | |||
| : public AlgoCutlassConvolutionBase { | |||
| public: | |||
| AlgoInt8NCHW4DotProdImplicitGemm(AlgoParam algo_param) | |||
| : m_algo_param{algo_param}, | |||
| : AlgoCutlassConvolutionBase(algo_param), | |||
| m_name{ssprintf("INT8_NCHW4_DOTPROD_IMPLICIT_GEMM%s", | |||
| m_algo_param.to_string().c_str())} {} | |||
| bool is_available(const SizeArgs& args) const override; | |||
| @@ -555,7 +634,6 @@ public: | |||
| private: | |||
| WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, | |||
| const SizeArgs& args) const; | |||
| AlgoParam m_algo_param; | |||
| std::string m_name; | |||
| }; | |||
| @@ -714,19 +792,10 @@ private: | |||
| #if CUDA_VERSION >= 10020 | |||
| class ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm final | |||
| : public AlgoBase { | |||
| : public AlgoCutlassConvolutionBase { | |||
| public: | |||
| struct AlgoParam { | |||
| int threadblock_m; | |||
| int threadblock_n; | |||
| int threadblock_k; | |||
| int warp_m; | |||
| int warp_n; | |||
| int warp_k; | |||
| int stage; | |||
| }; | |||
| AlgoInt8NCHW32IMMAImplicitGemm(AlgoParam algo_param) | |||
| : m_algo_param{algo_param} { | |||
| : AlgoCutlassConvolutionBase(algo_param) { | |||
| m_name = ConvBias::algo_name<ConvBias::DirectParam>( | |||
| ssprintf("INT8_NCHW32_IMMA_IMPLICIT_GEMM_%s", | |||
| to_string(m_algo_param).c_str()), | |||
| @@ -757,25 +826,14 @@ private: | |||
| WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, | |||
| const SizeArgs& args) const; | |||
| AlgoParam m_algo_param; | |||
| std::string m_name; | |||
| }; | |||
| class ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase | |||
| : public AlgoBase { | |||
| : public AlgoCutlassConvolutionBase { | |||
| public: | |||
| struct AlgoParam { | |||
| int threadblock_m; | |||
| int threadblock_n; | |||
| int threadblock_k; | |||
| int warp_m; | |||
| int warp_n; | |||
| int warp_k; | |||
| int stage; | |||
| }; | |||
| AlgoInt4NCHW64IMMAImplicitGemmBase(AlgoParam algo_param) | |||
| : m_algo_param(algo_param) {} | |||
| : AlgoCutlassConvolutionBase(algo_param) {} | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE; | |||
| @@ -799,16 +857,9 @@ protected: | |||
| virtual std::tuple<float, float, float, float, float> get_constants( | |||
| const ExecArgs& args) const = 0; | |||
| virtual void do_exec(const ExecArgs& args, void* filter_ptr, void* bias_ptr, | |||
| void* z_ptr, convolution::ConvParam kern_param, | |||
| uint32_t nonlinear_mode, float alpha, float beta, | |||
| float gamma, float delta, float theta, | |||
| cudaStream_t stream) const = 0; | |||
| void reorder_filter(const ExecArgs& args, void* reordered_filter) const; | |||
| std::string m_name; | |||
| AlgoParam m_algo_param; | |||
| }; | |||
| class ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm final | |||
| @@ -842,11 +893,6 @@ private: | |||
| std::tuple<float, float, float, float, float> get_constants( | |||
| const ExecArgs& args) const override; | |||
| void do_exec(const ExecArgs& args, void* filter_ptr, void* bias_ptr, | |||
| void* z_ptr, convolution::ConvParam kern_param, | |||
| uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||
| float delta, float theta, cudaStream_t stream) const override; | |||
| }; | |||
| class ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm final | |||
| @@ -881,30 +927,15 @@ private: | |||
| std::tuple<float, float, float, float, float> get_constants( | |||
| const ExecArgs& args) const override; | |||
| void do_exec(const ExecArgs& args, void* filter_ptr, void* bias_ptr, | |||
| void* z_ptr, convolution::ConvParam kern_param, | |||
| uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||
| float delta, float theta, cudaStream_t stream) const override; | |||
| void update_bias(const ExecArgs& args, void* updated_bias, | |||
| void* reduce_filter_ptr, void* reduce_workspace) const; | |||
| }; | |||
| class ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase : public AlgoBase { | |||
| class ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase | |||
| : public AlgoCutlassConvolutionBase { | |||
| public: | |||
| struct AlgoParam { | |||
| int threadblock_m; | |||
| int threadblock_n; | |||
| int threadblock_k; | |||
| int warp_m; | |||
| int warp_n; | |||
| int warp_k; | |||
| int stage; | |||
| int access_size; | |||
| }; | |||
| AlgoInt4NHWCIMMAImplicitGemmBase(AlgoParam algo_param) | |||
| : m_algo_param(algo_param) {} | |||
| : AlgoCutlassConvolutionBase(algo_param) {} | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE; | |||
| @@ -928,17 +959,10 @@ protected: | |||
| virtual std::tuple<float, float, float, float, float> get_constants( | |||
| const ExecArgs& args) const = 0; | |||
| virtual void do_exec(const ExecArgs& args, void* filter_ptr, void* bias_ptr, | |||
| void* z_ptr, convolution::ConvParam kern_param, | |||
| uint32_t nonlinear_mode, float alpha, float beta, | |||
| float gamma, float delta, float theta, | |||
| cudaStream_t stream) const = 0; | |||
| void reorder_filter(const ExecArgs& args, int interleaved, | |||
| void* reordered_filter) const; | |||
| std::string m_name; | |||
| AlgoParam m_algo_param; | |||
| }; | |||
| class ConvBiasForwardImpl::AlgoInt4Int4NHWCIMMAImplicitGemm final | |||
| @@ -971,11 +995,6 @@ private: | |||
| std::tuple<float, float, float, float, float> get_constants( | |||
| const ExecArgs& args) const override; | |||
| void do_exec(const ExecArgs& args, void* filter_ptr, void* bias_ptr, | |||
| void* z_ptr, convolution::ConvParam kern_param, | |||
| uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||
| float delta, float theta, cudaStream_t stream) const override; | |||
| }; | |||
| class ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm final | |||
| @@ -1009,11 +1028,6 @@ private: | |||
| std::tuple<float, float, float, float, float> get_constants( | |||
| const ExecArgs& args) const override; | |||
| void do_exec(const ExecArgs& args, void* filter_ptr, void* bias_ptr, | |||
| void* z_ptr, convolution::ConvParam kern_param, | |||
| uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||
| float delta, float theta, cudaStream_t stream) const override; | |||
| void update_bias(const ExecArgs& args, void* updated_bias, | |||
| void* reduce_filter_ptr, void* reduce_workspace) const; | |||
| }; | |||
| @@ -0,0 +1,253 @@ | |||
| /** | |||
| * \file dnn/src/cuda/conv_bias/cutlass_convolution_base.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/cuda/conv_bias/algo.h" | |||
| #include "src/cuda/cutlass/singleton.h" | |||
| namespace megdnn { | |||
| namespace cuda { | |||
| using namespace cutlass::library; | |||
| using namespace cutlass::epilogue; | |||
| ConvBiasForwardImpl::AlgoCutlassConvolutionBase::AlgoParam::AlgoParam( | |||
| int threadblock_m_, int threadblock_n_, int threadblock_k_, int warp_m_, | |||
| int warp_n_, int warp_k_, int instruction_m_, int instruction_n_, | |||
| int instruction_k_, int stage_, int access_size_) | |||
| : threadblock_m(threadblock_m_), | |||
| threadblock_n(threadblock_n_), | |||
| threadblock_k(threadblock_k_), | |||
| warp_m(warp_m_), | |||
| warp_n(warp_n_), | |||
| warp_k(warp_k_), | |||
| instruction_m(instruction_m_), | |||
| instruction_n(instruction_m_), | |||
| instruction_k(instruction_k_), | |||
| stage(stage_), | |||
| access_size(access_size_) {} | |||
| std::string | |||
| ConvBiasForwardImpl::AlgoCutlassConvolutionBase::AlgoParam::to_string() const { | |||
| /// default algorithm | |||
| if (threadblock_m == 128 && threadblock_n == 128 && threadblock_k == 32 && | |||
| warp_m == 32 && warp_n == 64 && warp_k == 32 && stage == 2) { | |||
| return ""; | |||
| } | |||
| return ssprintf("_%dX%dX%d_%dX%dX%d_%dstage", threadblock_m, threadblock_n, | |||
| threadblock_k, warp_m, warp_n, warp_k, stage); | |||
| } | |||
| namespace { | |||
| using Base = ConvBiasForwardImpl::AlgoCutlassConvolutionBase; | |||
| cutlass::conv::Operator convert_conv_op(Base::ConvOperator conv_op) { | |||
| switch (conv_op) { | |||
| case Base::ConvOperator::kFprop: | |||
| return cutlass::conv::Operator::kFprop; | |||
| case Base::ConvOperator::kDgrad: | |||
| return cutlass::conv::Operator::kDgrad; | |||
| case Base::ConvOperator::kWgrad: | |||
| return cutlass::conv::Operator::kWgrad; | |||
| default: | |||
| megdnn_assert(0, "invalid conv op"); | |||
| } | |||
| } | |||
| cutlass::conv::ConvType convert_conv_type(Base::ConvType conv_type) { | |||
| switch (conv_type) { | |||
| case Base::ConvType::kConvolution: | |||
| return cutlass::conv::ConvType::kConvolution; | |||
| case Base::ConvType::kBatchConvolution: | |||
| return cutlass::conv::ConvType::kBatchConvolution; | |||
| case Base::ConvType::kLocal: | |||
| return cutlass::conv::ConvType::kLocal; | |||
| case Base::ConvType::kLocalShare: | |||
| return cutlass::conv::ConvType::kLocalShare; | |||
| default: | |||
| megdnn_assert(0, "invalid conv type"); | |||
| } | |||
| } | |||
| NumericTypeID convert_dtype(DTypeEnum dtype) { | |||
| switch (dtype) { | |||
| case DTypeEnum::Float32: | |||
| return NumericTypeID::kF32; | |||
| case DTypeEnum::Float16: | |||
| return NumericTypeID::kF16; | |||
| case DTypeEnum::Int8: | |||
| return NumericTypeID::kS8; | |||
| case DTypeEnum::QuantizedS32: | |||
| return NumericTypeID::kS32; | |||
| case DTypeEnum::QuantizedS8: | |||
| return NumericTypeID::kS8; | |||
| case DTypeEnum::QuantizedS4: | |||
| return NumericTypeID::kS4; | |||
| case DTypeEnum::Quantized4Asymm: | |||
| return NumericTypeID::kU4; | |||
| default: | |||
| megdnn_assert(0, "invalid dtype"); | |||
| } | |||
| } | |||
| struct LayoutPack { | |||
| LayoutTypeID src; | |||
| LayoutTypeID filter; | |||
| LayoutTypeID dst; | |||
| LayoutTypeID bias; | |||
| }; | |||
| LayoutPack get_layout_pack(const param::ConvBias::Format format, | |||
| int access_type) { | |||
| using Format = param::ConvBias::Format; | |||
| switch (format) { | |||
| case Format::NCHW4: | |||
| return {LayoutTypeID::kTensorNC4HW4, LayoutTypeID::kTensorC4RSK4, | |||
| LayoutTypeID::kTensorNC4HW4, LayoutTypeID::kTensorNC4HW4}; | |||
| case Format::NCHW4_NCHW: | |||
| return {LayoutTypeID::kTensorNC4HW4, LayoutTypeID::kTensorC4RSK4, | |||
| LayoutTypeID::kTensorNCHW, LayoutTypeID::kTensorNCHW}; | |||
| case Format::NCHW4_NHWC: | |||
| return {LayoutTypeID::kTensorNC4HW4, LayoutTypeID::kTensorC4RSK4, | |||
| LayoutTypeID::kTensorNHWC, LayoutTypeID::kTensorNHWC}; | |||
| case Format::NCHW4_NCHW32: | |||
| return {LayoutTypeID::kTensorNC4HW4, LayoutTypeID::kTensorC4RSK4, | |||
| LayoutTypeID::kTensorNC32HW32, | |||
| LayoutTypeID::kTensorNC32HW32}; | |||
| case Format::NCHW32: | |||
| return {LayoutTypeID::kTensorNC32HW32, | |||
| LayoutTypeID::kTensorC32RSK32, | |||
| LayoutTypeID::kTensorNC32HW32, | |||
| LayoutTypeID::kTensorNC32HW32}; | |||
| case Format::NCHW32_NCHW4: | |||
| return {LayoutTypeID::kTensorNC32HW32, | |||
| LayoutTypeID::kTensorC32RSK32, LayoutTypeID::kTensorNC4HW4, | |||
| LayoutTypeID::kTensorNC4HW4}; | |||
| case Format::NCHW64: | |||
| return {LayoutTypeID::kTensorNC64HW64, | |||
| LayoutTypeID::kTensorC64RSK64, | |||
| LayoutTypeID::kTensorNC64HW64, | |||
| LayoutTypeID::kTensorNC64HW64}; | |||
| case Format::NHWC: | |||
| switch (access_type) { | |||
| case 8: | |||
| return {LayoutTypeID::kTensorNHWC, | |||
| LayoutTypeID::kTensorNC8HW8, | |||
| LayoutTypeID::kTensorNHWC, | |||
| LayoutTypeID::kTensorNHWC}; | |||
| case 16: | |||
| return {LayoutTypeID::kTensorNHWC, | |||
| LayoutTypeID::kTensorNC16HW16, | |||
| LayoutTypeID::kTensorNHWC, | |||
| LayoutTypeID::kTensorNHWC}; | |||
| case 32: | |||
| return {LayoutTypeID::kTensorNHWC, | |||
| LayoutTypeID::kTensorNC32HW32, | |||
| LayoutTypeID::kTensorNHWC, | |||
| LayoutTypeID::kTensorNHWC}; | |||
| default: | |||
| megdnn_assert(0, "invalid access_type"); | |||
| } | |||
| default: | |||
| megdnn_assert(0, "invalid format"); | |||
| } | |||
| } | |||
| EpilogueType get_epilogue_type(const param::ConvBias::NonlineMode mode, | |||
| bool clamp) { | |||
| using NonlineMode = param::ConvBias::NonlineMode; | |||
| if (clamp) { | |||
| if (mode == NonlineMode::IDENTITY) { | |||
| return EpilogueType::kBiasAddLinearCombinationClamp; | |||
| } else if (mode == NonlineMode::RELU) { | |||
| return EpilogueType::kBiasAddLinearCombinationReluClamp; | |||
| } else if (mode == NonlineMode::H_SWISH) { | |||
| return EpilogueType::kBiasAddLinearCombinationHSwishClamp; | |||
| } | |||
| } else { | |||
| if (mode == NonlineMode::IDENTITY) { | |||
| return EpilogueType::kBiasAddLinearCombination; | |||
| } else if (mode == NonlineMode::RELU) { | |||
| return EpilogueType::kBiasAddLinearCombinationRelu; | |||
| } else if (mode == NonlineMode::H_SWISH) { | |||
| return EpilogueType::kBiasAddLinearCombinationHSwish; | |||
| } | |||
| } | |||
| megdnn_assert(0, "invalid nonlinear mode"); | |||
| } | |||
| } // namespace | |||
| const Operation* | |||
| ConvBiasForwardImpl::AlgoCutlassConvolutionBase::get_cutlass_conv_op( | |||
| const SizeArgs& args, ConvOperator conv_op, ConvType conv_type, | |||
| bool load_from_const, bool without_shared_load) const { | |||
| using Format = param::ConvBias::Format; | |||
| auto&& param = args.opr->param(); | |||
| auto layouts = get_layout_pack(param.format, m_algo_param.access_size); | |||
| auto epilogue_type = get_epilogue_type(param.nonlineMode, | |||
| param.format != Format::NCHW4_NCHW); | |||
| ConvolutionKey key{convert_conv_op(conv_op), | |||
| convert_dtype(args.src_layout->dtype.enumv()), | |||
| layouts.src, | |||
| convert_dtype(args.filter_layout->dtype.enumv()), | |||
| layouts.filter, | |||
| convert_dtype(args.dst_layout->dtype.enumv()), | |||
| layouts.dst, | |||
| convert_dtype(args.bias_layout->dtype.enumv()), | |||
| layouts.bias, | |||
| convert_conv_type(conv_type), | |||
| m_algo_param.threadblock_m, | |||
| m_algo_param.threadblock_n, | |||
| m_algo_param.threadblock_k, | |||
| m_algo_param.warp_m, | |||
| m_algo_param.warp_n, | |||
| m_algo_param.warp_k, | |||
| m_algo_param.instruction_m, | |||
| m_algo_param.instruction_n, | |||
| m_algo_param.instruction_k, | |||
| epilogue_type, | |||
| m_algo_param.stage, | |||
| load_from_const, | |||
| without_shared_load}; | |||
| return Singleton::get().operation_table.find_op(key); | |||
| } | |||
| void ConvBiasForwardImpl::AlgoCutlassConvolutionBase::execute_cutlass_conv_op( | |||
| const Operation* op, const void* src, const void* filter, | |||
| const void* bias, const void* z, void* dst, void* workspace, size_t n, | |||
| size_t hi, size_t wi, size_t ci, size_t co, size_t fh, size_t fw, | |||
| size_t ho, size_t wo, size_t ph, size_t pw, size_t sh, size_t sw, | |||
| size_t dh, size_t dw, const void* alpha, const void* beta, | |||
| const void* gamma, const void* delta, const void* theta, | |||
| const void* threshold, const void* dst_scale, cudaStream_t stream, | |||
| const void* extra_param) const { | |||
| // gcc prints warnings when size_t values are implicitly narrowed to int | |||
| cutlass::conv::Conv2dProblemSize problem_size{ | |||
| int(n), int(hi), int(wi), int(ci), | |||
| int(co), int(fh), int(fw), int(ho), | |||
| int(wo), int(ph), int(pw), int(sh), | |||
| int(sw), int(dh), int(dw), cutlass::conv::Mode::kCrossCorrelation}; | |||
| ConvolutionArguments conv_args{ | |||
| problem_size, src, filter, bias, z, | |||
| dst, alpha, beta, gamma, delta, | |||
| theta, threshold, dst_scale, extra_param}; | |||
| cutlass_check(op->run(&conv_args, workspace, stream)); | |||
| } | |||
| } // namespace cuda | |||
| } // namespace megdnn | |||
| @@ -1,129 +0,0 @@ | |||
| /** | |||
| * \file dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cuh | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include "cutlass/gemm/gemm.h" | |||
| #include "src/cuda/convolution_helper/parameter.cuh" | |||
| #include "src/cuda/utils.cuh" | |||
| namespace megdnn { | |||
| namespace cuda { | |||
| namespace cutlass_wrapper { | |||
| using GemmCoord = cutlass::gemm::GemmCoord; | |||
| template <typename Convolution> | |||
| void cutlass_convolution_wrapper( | |||
| const typename Convolution::ElementSrc* d_src, | |||
| const typename Convolution::ElementFilter* d_filter, | |||
| const typename Convolution::ElementBias* d_bias, | |||
| const typename Convolution::ElementDst* d_z, | |||
| typename Convolution::ElementDst* d_dst, int* workspace, | |||
| typename Convolution::ConvolutionParameter const& conv_param, | |||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||
| cudaStream_t stream, typename Convolution::ExtraParam extra_param = {}); | |||
| template <bool NeedLoadFromConstMem> | |||
| void do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32( | |||
| const int8_t* d_src, const int8_t* d_filter, const int32_t* d_bias, | |||
| const int8_t* d_z, int8_t* d_dst, int* workspace, | |||
| const convolution::ConvParam& param, uint32_t nonlinear_mode, | |||
| float alpha, float beta, float gamma, float scale, | |||
| const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | |||
| int stages, cudaStream_t stream); | |||
| template <bool NeedLoadFromConstMem> | |||
| void do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4( | |||
| const int8_t* d_src, const int8_t* d_filter, const int32_t* d_bias, | |||
| const int8_t* d_z, int8_t* d_dst, int* workspace, | |||
| const convolution::ConvParam& param, uint32_t nonlinear_mode, | |||
| float alpha, float beta, float gamma, float scale, | |||
| const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | |||
| int stages, cudaStream_t stream); | |||
| template <bool NeedLoadFromConstMem> | |||
| void do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4( | |||
| const int8_t* d_src, const int8_t* d_filter, const int32_t* d_bias, | |||
| const int8_t* d_z, int8_t* d_dst, int* workspace, | |||
| const convolution::ConvParam& param, uint32_t nonlinear_mode, | |||
| float alpha, float beta, float gamma, float scale, | |||
| const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | |||
| int stages, cudaStream_t stream); | |||
| template <bool NeedLoadFromConstMem> | |||
| void do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nchw( | |||
| const int8_t* d_src, const int8_t* d_filter, const float* d_bias, | |||
| const float* d_z, float* d_dst, int* workspace, | |||
| const convolution::ConvParam& param, uint32_t nonlinear_mode, | |||
| float alpha, float beta, float gamma, float scale, | |||
| const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | |||
| int stages, cudaStream_t stream); | |||
| template <bool NeedLoadFromConstMem> | |||
| void do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32( | |||
| const int8_t* d_src, const int8_t* d_filter, const int32_t* d_bias, | |||
| const int8_t* d_z, int8_t* d_dst, int* workspace, | |||
| const convolution::ConvParam& param, uint32_t nonlinear_mode, | |||
| float alpha, float beta, float gamma, float scale, | |||
| const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | |||
| int stages, cudaStream_t stream); | |||
| template <bool NeedLoadFromConstMem> | |||
| void do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64( | |||
| const int8_t* d_src, const int8_t* d_filter, const int32_t* d_bias, | |||
| const int8_t* d_z, int8_t* d_dst, int* workspace, | |||
| const convolution::ConvParam& param, uint32_t nonlinear_mode, | |||
| float alpha, float beta, float gamma, float scale, | |||
| const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | |||
| int stages, cudaStream_t stream); | |||
| template <bool NeedLoadFromConstMem> | |||
| void do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64( | |||
| const uint8_t* d_src, const int8_t* d_filter, const int32_t* d_bias, | |||
| const uint8_t* d_z, uint8_t* d_dst, int* workspace, | |||
| const convolution::ConvParam& param, uint32_t nonlinear_mode, | |||
| float alpha, float beta, float gamma, float delta, float theta, | |||
| float scale, uint8_t src_zero_point, const GemmCoord& threadblock_shape, | |||
| const GemmCoord& warp_shape, int stages, cudaStream_t stream); | |||
| template <bool signedness> | |||
| void do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nhwc( | |||
| const int8_t* d_src, const int8_t* d_filter, const int32_t* d_bias, | |||
| const int8_t* d_z, int8_t* d_dst, int* workspace, | |||
| const convolution::ConvParam& param, uint32_t nonlinear_mode, | |||
| float alpha, float beta, float gamma, float delta, float theta, | |||
| float scale, const GemmCoord& threadblock_shape, | |||
| const GemmCoord& warp_shape, int stages, cudaStream_t stream); | |||
| template <bool NeedLoadFromConstMem> | |||
| void do_conv_bias_int4_int4_implicit_gemm_imma_nhwc( | |||
| const int8_t* d_src, const int8_t* d_filter, const int32_t* d_bias, | |||
| const int8_t* d_z, int8_t* d_dst, int* workspace, | |||
| const convolution::ConvParam& param, uint32_t nonlinear_mode, | |||
| float alpha, float beta, float gamma, float scale, | |||
| const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | |||
| const int32_t access_size, int stages, cudaStream_t stream); | |||
| template <bool NeedLoadFromConstMem> | |||
| void do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc( | |||
| const uint8_t* d_src, const int8_t* d_filter, const int32_t* d_bias, | |||
| const uint8_t* d_z, uint8_t* d_dst, int* workspace, | |||
| const convolution::ConvParam& param, uint32_t nonlinear_mode, | |||
| float alpha, float beta, float gamma, float delta, float theta, | |||
| float scale, uint8_t src_zero_point, const GemmCoord& threadblock_shape, | |||
| const GemmCoord& warp_shape, const int32_t access_size, int stages, | |||
| cudaStream_t stream); | |||
| } // namespace cutlass_wrapper | |||
| } // namespace cuda | |||
| } // namespace megdnn | |||
| // vim: syntax=cuda.doxygen | |||
| @@ -1,595 +0,0 @@ | |||
| /** | |||
| * \file dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| // ignore warning of cutlass | |||
| #pragma GCC diagnostic push | |||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||
| #if !MEGDNN_TEGRA_X1 | |||
| #include "cutlass/convolution/device/convolution.h" | |||
| #endif | |||
| #include "src/common/opr_param_defs_enumv.cuh" | |||
| #include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | |||
| #pragma GCC diagnostic pop | |||
| using namespace megdnn; | |||
| using namespace cuda; | |||
| using namespace cutlass_wrapper; | |||
| /* ====== cutlass kernel wrapper for int4 x int4 nchw64 layout ====== */ | |||
| #if MEGDNN_TEGRA_X1 | |||
| template <bool NeedLoadFromConstMem> | |||
| void megdnn::cuda::cutlass_wrapper:: | |||
| do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64( | |||
| const int8_t* /* d_src */, const int8_t* /* d_filter */, | |||
| const int32_t* /* d_bias */, const int8_t* /* d_z */, | |||
| int8_t* /* d_dst */, int* /* workspace */, | |||
| const convolution::ConvParam& /* param */, | |||
| uint32_t /* nonlinear_mode */, float /* alpha */, | |||
| float /* beta */, float /* gamma */, float /* scale */, | |||
| const GemmCoord& /* threadblock_shape */, | |||
| const GemmCoord& /* warp_shape */, int /* stages */, | |||
| cudaStream_t /* stream */) {} | |||
| #else | |||
| template <bool NeedLoadFromConstMem> | |||
| void megdnn::cuda::cutlass_wrapper:: | |||
| do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64( | |||
| const int8_t* d_src, const int8_t* d_filter, | |||
| const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, | |||
| int* workspace, const convolution::ConvParam& param, | |||
| uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||
| float scale, const GemmCoord& threadblock_shape, | |||
| const GemmCoord& warp_shape, int stages, cudaStream_t stream) { | |||
| #define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||
| threadblock_k_, warp_m_, warp_n_, \ | |||
| warp_k_, stage_) \ | |||
| if (threadblock_shape.m() == threadblock_m_ && \ | |||
| threadblock_shape.n() == threadblock_n_ && \ | |||
| threadblock_shape.k() == threadblock_k_ && \ | |||
| warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||
| warp_shape.k() == warp_k_ && stages == stage_) { \ | |||
| using ThreadBlockShape = \ | |||
| cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||
| threadblock_k_>; \ | |||
| using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \ | |||
| using Convolution = cutlass::conv::device::Convolution< \ | |||
| cutlass::int4b_t, cutlass::layout::TensorNCxHWx<64>, \ | |||
| cutlass::int4b_t, cutlass::layout::TensorCxRSKx<64>, \ | |||
| ElementOutput, cutlass::layout::TensorNCxHWx<64>, int32_t, \ | |||
| cutlass::layout::TensorNCxHWx<64>, int32_t, \ | |||
| cutlass::conv::ConvType::kConvolution, \ | |||
| cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||
| cutlass::conv::threadblock:: \ | |||
| ConvolutionFpropTransThreadblockSwizzle, \ | |||
| stage_, 32, 32, NeedLoadFromConstMem, \ | |||
| cutlass::arch::OpMultiplyAddSaturate, \ | |||
| cutlass::conv::ImplicitGemmMode::GEMM_TN, true>; \ | |||
| typename Convolution::ConvolutionParameter conv_param( \ | |||
| param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||
| param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||
| param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||
| return cutlass_convolution_wrapper<Convolution>( \ | |||
| reinterpret_cast<const cutlass::int4b_t*>(d_src), \ | |||
| reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \ | |||
| reinterpret_cast<const cutlass::int4b_t*>(d_z), \ | |||
| reinterpret_cast<cutlass::int4b_t*>(d_dst), workspace, \ | |||
| conv_param, epilogue, stream); \ | |||
| } | |||
| #define DISPATCH_KERNEL \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 128, 64, 64, 128, 2); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 128, 64, 64, 128, 2); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 128, 64, 64, 128, 2); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1); \ | |||
| megdnn_assert(false, \ | |||
| "unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||
| "(%dx%dx%d)", \ | |||
| threadblock_shape.m(), threadblock_shape.n(), \ | |||
| threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||
| warp_shape.k()); | |||
| using ElementOutput = cutlass::int4b_t; | |||
| using ElementAccumulator = int32_t; | |||
| using ElementBias = int32_t; | |||
| using ElementCompute = float; | |||
| using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; | |||
| switch (nonlinear_mode) { | |||
| case NonlineMode::IDENTITY: { | |||
| using EpilogueOp = | |||
| cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||
| ElementOutput, 16, ElementAccumulator, ElementBias, | |||
| ElementCompute>; | |||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma}; | |||
| DISPATCH_KERNEL; | |||
| } | |||
| case NonlineMode::RELU: { | |||
| using EpilogueOp = cutlass::epilogue::thread:: | |||
| BiasAddLinearCombinationReluClamp< | |||
| ElementOutput, 16, ElementAccumulator, ElementBias, | |||
| ElementCompute>; | |||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0}; | |||
| DISPATCH_KERNEL; | |||
| } | |||
| case NonlineMode::H_SWISH: { | |||
| using EpilogueOp = cutlass::epilogue::thread:: | |||
| BiasAddLinearCombinationHSwishClamp< | |||
| ElementOutput, 16, ElementAccumulator, ElementBias, | |||
| ElementCompute>; | |||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale}; | |||
| DISPATCH_KERNEL; | |||
| } | |||
| default: | |||
| megdnn_assert(false, | |||
| "unsupported nonlinear mode for conv bias operator"); | |||
| } | |||
| #undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||
| #undef DISPATCH_KERNEL | |||
| } | |||
| #endif | |||
| #define INST(need_load_from_const_mem) \ | |||
| template void megdnn::cuda::cutlass_wrapper:: \ | |||
| do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64< \ | |||
| need_load_from_const_mem>( \ | |||
| const int8_t* d_src, const int8_t* d_filter, \ | |||
| const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \ | |||
| int* workspace, const convolution::ConvParam& param, \ | |||
| uint32_t nonlinear_mode, float alpha, float beta, \ | |||
| float gamma, float scale, \ | |||
| const GemmCoord& threadblock_shape, \ | |||
| const GemmCoord& warp_shape, int stages, \ | |||
| cudaStream_t stream); | |||
| INST(true); | |||
| #undef INST | |||
| /* ====== cutlass kernel wrapper for uint4 x int4 nchw64 layout ====== */ | |||
| #if MEGDNN_TEGRA_X1 | |||
| template <bool NeedLoadFromConstMem> | |||
| void megdnn::cuda::cutlass_wrapper:: | |||
| do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64( | |||
| const uint8_t* /* d_src */, const int8_t* /* d_filter */, | |||
| const int32_t* /* d_bias */, const uint8_t* /* d_z */, | |||
| uint8_t* /* d_dst */, int* /* workspace */, | |||
| const convolution::ConvParam& /* param */, | |||
| uint32_t /* nonlinear_mode */, float /* alpha */, | |||
| float /* beta */, float /* gamma */, float /* delta */, | |||
| float /* theta */, float /* scale */, | |||
| uint8_t /* src_zero_point */, | |||
| const GemmCoord& /* threadblock_shape */, | |||
| const GemmCoord& /* warp_shape */, int /* stages */, | |||
| cudaStream_t /* stream */) {} | |||
| #else | |||
| template <bool NeedLoadFromConstMem> | |||
| void megdnn::cuda::cutlass_wrapper:: | |||
| do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64( | |||
| const uint8_t* d_src, const int8_t* d_filter, | |||
| const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, | |||
| int* workspace, const convolution::ConvParam& param, | |||
| uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||
| float delta, float theta, float /* scale */, | |||
| uint8_t src_zero_point, const GemmCoord& threadblock_shape, | |||
| const GemmCoord& warp_shape, int stages, cudaStream_t stream) { | |||
| #define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||
| threadblock_k_, warp_m_, warp_n_, \ | |||
| warp_k_, stage_) \ | |||
| if (threadblock_shape.m() == threadblock_m_ && \ | |||
| threadblock_shape.n() == threadblock_n_ && \ | |||
| threadblock_shape.k() == threadblock_k_ && \ | |||
| warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||
| warp_shape.k() == warp_k_ && stages == stage_) { \ | |||
| using ThreadBlockShape = \ | |||
| cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||
| threadblock_k_>; \ | |||
| using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \ | |||
| using Convolution = cutlass::conv::device::Convolution< \ | |||
| cutlass::uint4b_t, cutlass::layout::TensorNCxHWx<64>, \ | |||
| cutlass::int4b_t, cutlass::layout::TensorCxRSKx<64>, \ | |||
| ElementOutput, cutlass::layout::TensorNCxHWx<64>, int32_t, \ | |||
| cutlass::layout::TensorNCxHWx<64>, int32_t, \ | |||
| cutlass::conv::ConvType::kConvolution, \ | |||
| cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||
| cutlass::conv::threadblock:: \ | |||
| ConvolutionFpropTransThreadblockSwizzle, \ | |||
| stage_, 32, 32, NeedLoadFromConstMem, \ | |||
| cutlass::arch::OpMultiplyAddSaturate, \ | |||
| cutlass::conv::ImplicitGemmMode::GEMM_TN, true>; \ | |||
| typename Convolution::ConvolutionParameter conv_param( \ | |||
| param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||
| param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||
| param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||
| return cutlass_convolution_wrapper<Convolution>( \ | |||
| reinterpret_cast<const cutlass::uint4b_t*>(d_src), \ | |||
| reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \ | |||
| reinterpret_cast<const cutlass::uint4b_t*>(d_z), \ | |||
| reinterpret_cast<cutlass::uint4b_t*>(d_dst), workspace, \ | |||
| conv_param, epilogue, stream, {src_zero_point}); \ | |||
| } | |||
| #define DISPATCH_KERNEL \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 128, 64, 64, 128, 2); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 128, 64, 64, 128, 2); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 128, 64, 64, 128, 2); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1); \ | |||
| megdnn_assert(false, \ | |||
| "unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||
| "(%dx%dx%d)", \ | |||
| threadblock_shape.m(), threadblock_shape.n(), \ | |||
| threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||
| warp_shape.k()); | |||
| using ElementOutput = cutlass::uint4b_t; | |||
| using ElementAccumulator = int32_t; | |||
| using ElementBias = int32_t; | |||
| using ElementCompute = float; | |||
| using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; | |||
| switch (nonlinear_mode) { | |||
| case NonlineMode::IDENTITY: { | |||
| using EpilogueOp = | |||
| cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||
| ElementOutput, 16, ElementAccumulator, ElementBias, | |||
| ElementCompute>; | |||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, | |||
| delta + theta}; | |||
| DISPATCH_KERNEL; | |||
| } | |||
| case NonlineMode::RELU: { | |||
| using EpilogueOp = cutlass::epilogue::thread:: | |||
| BiasAddLinearCombinationReluClamp< | |||
| ElementOutput, 16, ElementAccumulator, ElementBias, | |||
| ElementCompute>; | |||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, | |||
| 0, delta, theta}; | |||
| DISPATCH_KERNEL; | |||
| } | |||
| default: | |||
| megdnn_assert(false, | |||
| "unsupported nonlinear mode for conv bias operator"); | |||
| } | |||
| #undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||
| #undef DISPATCH_KERNEL | |||
| } | |||
| #endif | |||
| #define INST(need_load_from_const_mem) \ | |||
| template void megdnn::cuda::cutlass_wrapper:: \ | |||
| do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64< \ | |||
| need_load_from_const_mem>( \ | |||
| const uint8_t* d_src, const int8_t* d_filter, \ | |||
| const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, \ | |||
| int* workspace, const convolution::ConvParam& param, \ | |||
| uint32_t nonlinear_mode, float alpha, float beta, \ | |||
| float gamma, float delta, float theta, float scale, \ | |||
| uint8_t src_zero_point, \ | |||
| const GemmCoord& threadblock_shape, \ | |||
| const GemmCoord& warp_shape, int stages, \ | |||
| cudaStream_t stream); | |||
| INST(true); | |||
| #undef INST | |||
| /* ====== cutlass kernel wrapper for int4 x int4 nhwc layout ====== */ | |||
| #if MEGDNN_TEGRA_X1 | |||
| template <bool NeedLoadFromConstMem> | |||
| void megdnn::cuda::cutlass_wrapper:: | |||
| do_conv_bias_int4_int4_implicit_gemm_imma_nhwc( | |||
| const int8_t* /* d_src */, const int8_t* /* d_filter */, | |||
| const int32_t* /* d_bias */, const int8_t* /* d_z */, | |||
| int8_t* /* d_dst */, int* /* workspace */, | |||
| const convolution::ConvParam& /* param */, | |||
| uint32_t /* nonlinear_mode */, float /* alpha */, | |||
| float /* beta */, float /* gamma */, float /* scale */, | |||
| const GemmCoord& /* threadblock_shape */, | |||
| const GemmCoord& /* warp_shape */, | |||
| const int32_t /* access_size */, int /* stages */, | |||
| cudaStream_t /* stream */) {} | |||
| #else | |||
| template <bool NeedLoadFromConstMem> | |||
| void megdnn::cuda::cutlass_wrapper:: | |||
| do_conv_bias_int4_int4_implicit_gemm_imma_nhwc( | |||
| const int8_t* d_src, const int8_t* d_filter, | |||
| const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, | |||
| int* workspace, const convolution::ConvParam& param, | |||
| uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||
| float scale, const GemmCoord& threadblock_shape, | |||
| const GemmCoord& warp_shape, const int32_t access_size, | |||
| int stages, cudaStream_t stream) { | |||
| bool without_shared_load = | |||
| ((param.co % threadblock_shape.n() == 0) && | |||
| (threadblock_shape.n() == 32 || threadblock_shape.n() == 64)); | |||
| int out_elements_per_access = | |||
| without_shared_load ? threadblock_shape.n() / 4 : 8; | |||
| #define RUN_CUTLASS_WRAPPER(stage_, access_size_, without_shared_load_) \ | |||
| using Convolution = cutlass::conv::device::Convolution< \ | |||
| cutlass::int4b_t, cutlass::layout::TensorNHWC, cutlass::int4b_t, \ | |||
| cutlass::layout::TensorNCxHWx<access_size_>, ElementOutput, \ | |||
| cutlass::layout::TensorNHWC, int32_t, cutlass::layout::TensorNHWC, \ | |||
| int32_t, cutlass::conv::ConvType::kConvolution, \ | |||
| cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||
| cutlass::conv::threadblock:: \ | |||
| ConvolutionFpropTransThreadblockSwizzle, \ | |||
| stage_, access_size_, access_size_, NeedLoadFromConstMem, \ | |||
| cutlass::arch::OpMultiplyAddSaturate, \ | |||
| cutlass::conv::ImplicitGemmMode::GEMM_TN, without_shared_load_>; \ | |||
| typename Convolution::ConvolutionParameter conv_param( \ | |||
| param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||
| param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||
| param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||
| return cutlass_convolution_wrapper<Convolution>( \ | |||
| reinterpret_cast<const cutlass::int4b_t*>(d_src), \ | |||
| reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \ | |||
| reinterpret_cast<const cutlass::int4b_t*>(d_z), \ | |||
| reinterpret_cast<cutlass::int4b_t*>(d_dst), workspace, conv_param, \ | |||
| epilogue, stream); | |||
| #define DISPATCH_KERNEL_WITH_TILE_SHAPE( \ | |||
| threadblock_m_, threadblock_n_, threadblock_k_, warp_m_, warp_n_, \ | |||
| warp_k_, stage_, access_size_, out_elements_per_access_, \ | |||
| without_shared_load_) \ | |||
| if (threadblock_shape.m() == threadblock_m_ && \ | |||
| threadblock_shape.n() == threadblock_n_ && \ | |||
| threadblock_shape.k() == threadblock_k_ && \ | |||
| warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||
| warp_shape.k() == warp_k_ && stages == stage_ && \ | |||
| access_size == access_size_ && \ | |||
| out_elements_per_access == out_elements_per_access_ && \ | |||
| without_shared_load == without_shared_load_) { \ | |||
| using ThreadBlockShape = \ | |||
| cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||
| threadblock_k_>; \ | |||
| using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \ | |||
| using ElementOutput = cutlass::int4b_t; \ | |||
| using ElementAccumulator = int32_t; \ | |||
| using ElementBias = int32_t; \ | |||
| using ElementCompute = float; \ | |||
| using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; \ | |||
| switch (nonlinear_mode) { \ | |||
| case NonlineMode::IDENTITY: { \ | |||
| using EpilogueOp = cutlass::epilogue::thread:: \ | |||
| BiasAddLinearCombinationClamp< \ | |||
| ElementOutput, out_elements_per_access_, \ | |||
| ElementAccumulator, ElementBias, \ | |||
| ElementCompute>; \ | |||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma}; \ | |||
| RUN_CUTLASS_WRAPPER(stage_, access_size_, \ | |||
| without_shared_load_); \ | |||
| } \ | |||
| case NonlineMode::RELU: { \ | |||
| using EpilogueOp = cutlass::epilogue::thread:: \ | |||
| BiasAddLinearCombinationReluClamp< \ | |||
| ElementOutput, out_elements_per_access_, \ | |||
| ElementAccumulator, ElementBias, \ | |||
| ElementCompute>; \ | |||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0}; \ | |||
| RUN_CUTLASS_WRAPPER(stage_, access_size_, \ | |||
| without_shared_load_); \ | |||
| } \ | |||
| case NonlineMode::H_SWISH: { \ | |||
| using EpilogueOp = cutlass::epilogue::thread:: \ | |||
| BiasAddLinearCombinationHSwishClamp< \ | |||
| ElementOutput, out_elements_per_access_, \ | |||
| ElementAccumulator, ElementBias, \ | |||
| ElementCompute>; \ | |||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, \ | |||
| scale}; \ | |||
| RUN_CUTLASS_WRAPPER(stage_, access_size_, \ | |||
| without_shared_load_); \ | |||
| } \ | |||
| default: \ | |||
| megdnn_assert( \ | |||
| false, \ | |||
| "unsupported nonlinear mode for conv bias operator"); \ | |||
| } \ | |||
| } | |||
| #define DISPATCH_KERNEL \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 32, 8, false); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 16, 8, false); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 8, 8, false); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 32, 8, false); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 16, 8, false); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 8, 8, false); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 32, 8, true); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 16, 8, true); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 8, 8, true); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 32, 16, true); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 16, 16, true); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 8, 16, true); \ | |||
| megdnn_assert(false, \ | |||
| "unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||
| "(%dx%dx%d) and access_size (%d)", \ | |||
| threadblock_shape.m(), threadblock_shape.n(), \ | |||
| threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||
| warp_shape.k(), access_size); | |||
| DISPATCH_KERNEL; | |||
| #undef RUN_CUTLASS_WRAPPER | |||
| #undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||
| #undef DISPATCH_KERNEL | |||
| } | |||
| #endif | |||
| #define INST(need_load_from_const_mem) \ | |||
| template void megdnn::cuda::cutlass_wrapper:: \ | |||
| do_conv_bias_int4_int4_implicit_gemm_imma_nhwc< \ | |||
| need_load_from_const_mem>( \ | |||
| const int8_t* d_src, const int8_t* d_filter, \ | |||
| const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \ | |||
| int* workspace, const convolution::ConvParam& param, \ | |||
| uint32_t nonlinear_mode, float alpha, float beta, \ | |||
| float gamma, float scale, \ | |||
| const GemmCoord& threadblock_shape, \ | |||
| const GemmCoord& warp_shape, const int32_t access_size, \ | |||
| int stages, cudaStream_t stream); | |||
| INST(true); | |||
| INST(false); | |||
| #undef INST | |||
| /* ====== cutlass kernel wrapper for uint4 x int4 nhwc layout ====== */ | |||
| #if MEGDNN_TEGRA_X1 | |||
| template <bool NeedLoadFromConstMem> | |||
| void megdnn::cuda::cutlass_wrapper:: | |||
| do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc( | |||
| const uint8_t* /* d_src */, const int8_t* /* d_filter */, | |||
| const int32_t* /* d_bias */, const uint8_t* /* d_z */, | |||
| uint8_t* /* d_dst */, int* /* workspace */, | |||
| const convolution::ConvParam& /* param */, | |||
| uint32_t /* nonlinear_mode */, float /* alpha */, | |||
| float /* beta */, float /* gamma */, float /* delta */, | |||
| float /* theta */, float /* scale */, | |||
| uint8_t /* src_zero_point */, | |||
| const GemmCoord& /* threadblock_shape */, | |||
| const GemmCoord& /* warp_shape */, | |||
| const int32_t /* access_size */, int /* stages */, | |||
| cudaStream_t /* stream */) {} | |||
| #else | |||
| template <bool NeedLoadFromConstMem> | |||
| void megdnn::cuda::cutlass_wrapper:: | |||
| do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc( | |||
| const uint8_t* d_src, const int8_t* d_filter, | |||
| const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, | |||
| int* workspace, const convolution::ConvParam& param, | |||
| uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||
| float delta, float theta, float /* scale */, | |||
| uint8_t src_zero_point, const GemmCoord& threadblock_shape, | |||
| const GemmCoord& warp_shape, const int32_t access_size, | |||
| int stages, cudaStream_t stream) { | |||
| bool without_shared_load = | |||
| ((param.co % threadblock_shape.n() == 0) && | |||
| (threadblock_shape.n() == 32 || threadblock_shape.n() == 64)); | |||
| int out_elements_per_access = | |||
| without_shared_load ? threadblock_shape.n() / 4 : 8; | |||
| #define RUN_CUTLASS_WRAPPER(stage_, access_size_, without_shared_load_) \ | |||
| using Convolution = cutlass::conv::device::Convolution< \ | |||
| cutlass::uint4b_t, cutlass::layout::TensorNHWC, cutlass::int4b_t, \ | |||
| cutlass::layout::TensorNCxHWx<access_size_>, ElementOutput, \ | |||
| cutlass::layout::TensorNHWC, int32_t, cutlass::layout::TensorNHWC, \ | |||
| int32_t, cutlass::conv::ConvType::kConvolution, \ | |||
| cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||
| cutlass::conv::threadblock:: \ | |||
| ConvolutionFpropTransThreadblockSwizzle, \ | |||
| stage_, access_size_, access_size_, NeedLoadFromConstMem, \ | |||
| cutlass::arch::OpMultiplyAddSaturate, \ | |||
| cutlass::conv::ImplicitGemmMode::GEMM_TN, without_shared_load_>; \ | |||
| typename Convolution::ConvolutionParameter conv_param( \ | |||
| param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||
| param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||
| param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||
| return cutlass_convolution_wrapper<Convolution>( \ | |||
| reinterpret_cast<const cutlass::uint4b_t*>(d_src), \ | |||
| reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \ | |||
| reinterpret_cast<const cutlass::uint4b_t*>(d_z), \ | |||
| reinterpret_cast<cutlass::uint4b_t*>(d_dst), workspace, \ | |||
| conv_param, epilogue, stream, {src_zero_point}); | |||
| #define DISPATCH_KERNEL_WITH_TILE_SHAPE( \ | |||
| threadblock_m_, threadblock_n_, threadblock_k_, warp_m_, warp_n_, \ | |||
| warp_k_, stage_, access_size_, out_elements_per_access_, \ | |||
| without_shared_load_) \ | |||
| if (threadblock_shape.m() == threadblock_m_ && \ | |||
| threadblock_shape.n() == threadblock_n_ && \ | |||
| threadblock_shape.k() == threadblock_k_ && \ | |||
| warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||
| warp_shape.k() == warp_k_ && stages == stage_ && \ | |||
| access_size == access_size_ && \ | |||
| out_elements_per_access == out_elements_per_access_ && \ | |||
| without_shared_load == without_shared_load_) { \ | |||
| using ThreadBlockShape = \ | |||
| cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||
| threadblock_k_>; \ | |||
| using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \ | |||
| using ElementOutput = cutlass::uint4b_t; \ | |||
| using ElementAccumulator = int32_t; \ | |||
| using ElementBias = int32_t; \ | |||
| using ElementCompute = float; \ | |||
| using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; \ | |||
| switch (nonlinear_mode) { \ | |||
| case NonlineMode::IDENTITY: { \ | |||
| using EpilogueOp = cutlass::epilogue::thread:: \ | |||
| BiasAddLinearCombinationClamp< \ | |||
| ElementOutput, out_elements_per_access_, \ | |||
| ElementAccumulator, ElementBias, \ | |||
| ElementCompute>; \ | |||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, \ | |||
| delta + theta}; \ | |||
| RUN_CUTLASS_WRAPPER(stage_, access_size_, \ | |||
| without_shared_load_); \ | |||
| } \ | |||
| case NonlineMode::RELU: { \ | |||
| using EpilogueOp = cutlass::epilogue::thread:: \ | |||
| BiasAddLinearCombinationReluClamp< \ | |||
| ElementOutput, out_elements_per_access_, \ | |||
| ElementAccumulator, ElementBias, \ | |||
| ElementCompute>; \ | |||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, \ | |||
| 0, delta, theta}; \ | |||
| RUN_CUTLASS_WRAPPER(stage_, access_size_, \ | |||
| without_shared_load_); \ | |||
| } \ | |||
| default: \ | |||
| megdnn_assert( \ | |||
| false, \ | |||
| "unsupported nonlinear mode for conv bias operator"); \ | |||
| } \ | |||
| } | |||
| #define DISPATCH_KERNEL \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 32, 8, false); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 16, 8, false); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 8, 8, false); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 32, 8, false); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 16, 8, false); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 8, 8, false); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 32, 8, true); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 16, 8, true); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 8, 8, true); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 32, 16, true); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 16, 16, true); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 8, 16, true); \ | |||
| megdnn_assert(false, \ | |||
| "unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||
| "(%dx%dx%d) and access_size (%d)", \ | |||
| threadblock_shape.m(), threadblock_shape.n(), \ | |||
| threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||
| warp_shape.k(), access_size); | |||
| DISPATCH_KERNEL; | |||
| #undef RUN_CUTLASS_WRAPPER | |||
| #undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||
| #undef DISPATCH_KERNEL | |||
| } | |||
| #endif | |||
| #define INST(need_load_from_const_mem) \ | |||
| template void megdnn::cuda::cutlass_wrapper:: \ | |||
| do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc< \ | |||
| need_load_from_const_mem>( \ | |||
| const uint8_t* d_src, const int8_t* d_filter, \ | |||
| const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, \ | |||
| int* workspace, const convolution::ConvParam& param, \ | |||
| uint32_t nonlinear_mode, float alpha, float beta, \ | |||
| float gamma, float delta, float theta, float scale, \ | |||
| uint8_t src_zero_point, \ | |||
| const GemmCoord& threadblock_shape, \ | |||
| const GemmCoord& warp_shape, const int32_t access_size, \ | |||
| int stages, cudaStream_t stream); | |||
| INST(true); | |||
| INST(false); | |||
| #undef INST | |||
| // vim: syntax=cuda.doxygen | |||
| @@ -1,804 +0,0 @@ | |||
| /** | |||
| * \file dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| // ignore warning of cutlass | |||
| #pragma GCC diagnostic push | |||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||
| #if !MEGDNN_TEGRA_X1 | |||
| #include "cutlass/convolution/device/convolution.h" | |||
| #endif | |||
| #include "src/common/opr_param_defs_enumv.cuh" | |||
| #include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | |||
| #pragma GCC diagnostic pop | |||
| using namespace megdnn; | |||
| using namespace cuda; | |||
| using namespace cutlass_wrapper; | |||
| /* ====== cutlass kernel wrapper for int8 nchw32 layout ====== */ | |||
| #if MEGDNN_TEGRA_X1 | |||
| template <bool NeedLoadFromConstMem> | |||
| void megdnn::cuda::cutlass_wrapper:: | |||
| do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32( | |||
| const int8_t* /* d_src */, const int8_t* /* d_filter */, | |||
| const int32_t* /* d_bias */, const int8_t* /* d_z */, | |||
| int8_t* /* d_dst */, int* /* workspace */, | |||
| const convolution::ConvParam& /* param */, | |||
| uint32_t /* nonlinear_mode */, float /* alpha */, | |||
| float /* beta */, float /* gamma */, float /* scale */, | |||
| const GemmCoord& /* threadblock_shape */, | |||
| const GemmCoord& /* warp_shape */, int /* stages */, | |||
| cudaStream_t /* stream */) {} | |||
| #else | |||
| template <bool NeedLoadFromConstMem> | |||
| void megdnn::cuda::cutlass_wrapper:: | |||
| do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32( | |||
| const int8_t* d_src, const int8_t* d_filter, | |||
| const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, | |||
| int* workspace, const convolution::ConvParam& param, | |||
| uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||
| float scale, const GemmCoord& threadblock_shape, | |||
| const GemmCoord& warp_shape, int stages, cudaStream_t stream) { | |||
| #define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||
| threadblock_k_, warp_m_, warp_n_, \ | |||
| warp_k_, stage_) \ | |||
| if (threadblock_shape.m() == threadblock_m_ && \ | |||
| threadblock_shape.n() == threadblock_n_ && \ | |||
| threadblock_shape.k() == threadblock_k_ && \ | |||
| warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||
| warp_shape.k() == warp_k_ && stages == stage_) { \ | |||
| using ThreadBlockShape = \ | |||
| cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||
| threadblock_k_>; \ | |||
| using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; \ | |||
| using Convolution = cutlass::conv::device::Convolution< \ | |||
| int8_t, cutlass::layout::TensorNCxHWx<32>, int8_t, \ | |||
| cutlass::layout::TensorCxRSKx<32>, ElementOutput, \ | |||
| cutlass::layout::TensorNCxHWx<32>, int32_t, \ | |||
| cutlass::layout::TensorNCxHWx<32>, int32_t, \ | |||
| cutlass::conv::ConvType::kConvolution, \ | |||
| cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||
| cutlass::conv::threadblock:: \ | |||
| ConvolutionFpropTransThreadblockSwizzle, \ | |||
| stage_, 16, 16, NeedLoadFromConstMem, \ | |||
| cutlass::arch::OpMultiplyAddSaturate, \ | |||
| cutlass::conv::ImplicitGemmMode::GEMM_TN, true>; \ | |||
| typename Convolution::ConvolutionParameter conv_param( \ | |||
| param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||
| param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||
| param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||
| return cutlass_convolution_wrapper<Convolution>( \ | |||
| d_src, d_filter, d_bias, d_z, d_dst, workspace, conv_param, \ | |||
| epilogue, stream); \ | |||
| } | |||
| #define DISPATCH_KERNEL \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 64, 64, 64, 64, 2); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 64, 64, 64, 64, 2); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 64, 64, 64, 64, 2); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 32, 64, 2); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 64, 32, 64, 64, 2); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 1); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 1); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 32, 64, 32, 1); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 1); \ | |||
| megdnn_assert(false, \ | |||
| "unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||
| "(%dx%dx%d)", \ | |||
| threadblock_shape.m(), threadblock_shape.n(), \ | |||
| threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||
| warp_shape.k()); | |||
| using ElementOutput = int8_t; | |||
| using ElementAccumulator = int32_t; | |||
| using ElementBias = int32_t; | |||
| using ElementCompute = float; | |||
| using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; | |||
| switch (nonlinear_mode) { | |||
| case NonlineMode::IDENTITY: { | |||
| using EpilogueOp = | |||
| cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||
| ElementOutput, 8, ElementAccumulator, ElementBias, | |||
| ElementCompute>; | |||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma}; | |||
| DISPATCH_KERNEL; | |||
| } | |||
| case NonlineMode::RELU: { | |||
| using EpilogueOp = cutlass::epilogue::thread:: | |||
| BiasAddLinearCombinationReluClamp< | |||
| ElementOutput, 8, ElementAccumulator, ElementBias, | |||
| ElementCompute>; | |||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0}; | |||
| DISPATCH_KERNEL; | |||
| } | |||
| case NonlineMode::H_SWISH: { | |||
| using EpilogueOp = cutlass::epilogue::thread:: | |||
| BiasAddLinearCombinationHSwishClamp< | |||
| ElementOutput, 8, ElementAccumulator, ElementBias, | |||
| ElementCompute>; | |||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale}; | |||
| DISPATCH_KERNEL; | |||
| } | |||
| default: | |||
| megdnn_assert(false, | |||
| "unsupported nonlinear mode for conv bias operator"); | |||
| } | |||
| #undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||
| #undef DISPATCH_KERNEL | |||
| } | |||
| #endif | |||
| #define INST(need_load_from_const_mem) \ | |||
| template void megdnn::cuda::cutlass_wrapper:: \ | |||
| do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32< \ | |||
| need_load_from_const_mem>( \ | |||
| const int8_t* d_src, const int8_t* d_filter, \ | |||
| const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \ | |||
| int* workspace, const convolution::ConvParam& param, \ | |||
| uint32_t nonlinear_mode, float alpha, float beta, \ | |||
| float gamma, float scale, \ | |||
| const GemmCoord& threadblock_shape, \ | |||
| const GemmCoord& warp_shape, int stages, \ | |||
| cudaStream_t stream); | |||
| INST(true); | |||
| INST(false); | |||
| #undef INST | |||
| /* ===== cutlass kernel wrapper for int8 nchw32 layout and nchw4 output ===== */ | |||
| #if MEGDNN_TEGRA_X1 | |||
| template <bool NeedLoadFromConstMem> | |||
| void megdnn::cuda::cutlass_wrapper:: | |||
| do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4( | |||
| const int8_t* /* d_src */, const int8_t* /* d_filter */, | |||
| const int32_t* /* d_bias */, const int8_t* /* d_z */, | |||
| int8_t* /* d_dst */, int* /* workspace */, | |||
| const convolution::ConvParam& /* param */, | |||
| uint32_t /* nonlinear_mode */, float /* alpha */, | |||
| float /* beta */, float /* gamma */, float /* scale */, | |||
| const GemmCoord& /* threadblock_shape */, | |||
| const GemmCoord& /* warp_shape */, int /* stages */, | |||
| cudaStream_t /* stream */) {} | |||
| #else | |||
| template <bool NeedLoadFromConstMem> | |||
| void megdnn::cuda::cutlass_wrapper:: | |||
| do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4( | |||
| const int8_t* d_src, const int8_t* d_filter, | |||
| const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, | |||
| int* workspace, const convolution::ConvParam& param, | |||
| uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||
| float scale, const GemmCoord& threadblock_shape, | |||
| const GemmCoord& warp_shape, int stages, cudaStream_t stream) { | |||
| #define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||
| threadblock_k_, warp_m_, warp_n_, \ | |||
| warp_k_, stage_) \ | |||
| if (threadblock_shape.m() == threadblock_m_ && \ | |||
| threadblock_shape.n() == threadblock_n_ && \ | |||
| threadblock_shape.k() == threadblock_k_ && \ | |||
| warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||
| warp_shape.k() == warp_k_ && stages == stage_) { \ | |||
| using ThreadBlockShape = \ | |||
| cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||
| threadblock_k_>; \ | |||
| using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; \ | |||
| using Convolution = cutlass::conv::device::Convolution< \ | |||
| int8_t, cutlass::layout::TensorNCxHWx<32>, int8_t, \ | |||
| cutlass::layout::TensorCxRSKx<32>, ElementOutput, \ | |||
| cutlass::layout::TensorNCxHWx<4>, int32_t, \ | |||
| cutlass::layout::TensorNCxHWx<4>, int32_t, \ | |||
| cutlass::conv::ConvType::kConvolution, \ | |||
| cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||
| cutlass::conv::threadblock:: \ | |||
| ConvolutionFpropNCxHWxThreadblockSwizzle, \ | |||
| stage_, 16, 16, NeedLoadFromConstMem>; \ | |||
| typename Convolution::ConvolutionParameter conv_param( \ | |||
| param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||
| param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||
| param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||
| return cutlass_convolution_wrapper<Convolution>( \ | |||
| d_src, d_filter, d_bias, d_z, d_dst, workspace, conv_param, \ | |||
| epilogue, stream); \ | |||
| } | |||
| #define DISPATCH_KERNEL \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 64, 64, 64, 64, 2); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 64, 64, 64, 64, 2); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 64, 64, 64, 64, 2); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 32, 64, 2); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 64, 32, 64, 64, 2); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 1); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 1); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 32, 64, 32, 1); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 1); \ | |||
| megdnn_assert(false, \ | |||
| "unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||
| "(%dx%dx%d)", \ | |||
| threadblock_shape.m(), threadblock_shape.n(), \ | |||
| threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||
| warp_shape.k()); | |||
| using ElementOutput = int8_t; | |||
| using ElementAccumulator = int32_t; | |||
| using ElementBias = int32_t; | |||
| using ElementCompute = float; | |||
| using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; | |||
| switch (nonlinear_mode) { | |||
| case NonlineMode::IDENTITY: { | |||
| using EpilogueOp = | |||
| cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||
| ElementOutput, 4, ElementAccumulator, ElementBias, | |||
| ElementCompute>; | |||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma}; | |||
| DISPATCH_KERNEL; | |||
| } | |||
| case NonlineMode::RELU: { | |||
| using EpilogueOp = cutlass::epilogue::thread:: | |||
| BiasAddLinearCombinationReluClamp< | |||
| ElementOutput, 4, ElementAccumulator, ElementBias, | |||
| ElementCompute>; | |||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0}; | |||
| DISPATCH_KERNEL; | |||
| } | |||
| case NonlineMode::H_SWISH: { | |||
| using EpilogueOp = cutlass::epilogue::thread:: | |||
| BiasAddLinearCombinationHSwishClamp< | |||
| ElementOutput, 4, ElementAccumulator, ElementBias, | |||
| ElementCompute>; | |||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale}; | |||
| DISPATCH_KERNEL; | |||
| } | |||
| default: | |||
| megdnn_assert(false, | |||
| "unsupported nonlinear mode for conv bias operator"); | |||
| } | |||
| #undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||
| #undef DISPATCH_KERNEL | |||
| } | |||
| #endif | |||
| #define INST(need_load_from_const_mem) \ | |||
| template void megdnn::cuda::cutlass_wrapper:: \ | |||
| do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4< \ | |||
| need_load_from_const_mem>( \ | |||
| const int8_t* d_src, const int8_t* d_filter, \ | |||
| const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \ | |||
| int* workspace, const convolution::ConvParam& param, \ | |||
| uint32_t nonlinear_mode, float alpha, float beta, \ | |||
| float gamma, float scale, \ | |||
| const GemmCoord& threadblock_shape, \ | |||
| const GemmCoord& warp_shape, int stages, \ | |||
| cudaStream_t stream); | |||
| INST(true); | |||
| INST(false); | |||
| #undef INST | |||
| /* ====== cutlass kernel wrapper for int8 nchw4 layout ====== */ | |||
| #if MEGDNN_TEGRA_X1 | |||
| template <bool NeedLoadFromConstMem> | |||
| void megdnn::cuda::cutlass_wrapper:: | |||
| do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4( | |||
| const int8_t* /* d_src */, const int8_t* /* d_filter */, | |||
| const int32_t* /* d_bias */, const int8_t* /* d_z */, | |||
| int8_t* /* d_dst */, int* /* workspace */, | |||
| const convolution::ConvParam& /* param */, | |||
| uint32_t /* nonlinear_mode */, float /* alpha */, | |||
| float /* beta */, float /* gamma */, float /* scale */, | |||
| const GemmCoord& /* threadblock_shape */, | |||
| const GemmCoord& /* warp_shape */, int /* stages */, | |||
| cudaStream_t /* stream */) {} | |||
| #else | |||
| template <bool NeedLoadFromConstMem> | |||
| void megdnn::cuda::cutlass_wrapper:: | |||
| do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4( | |||
| const int8_t* d_src, const int8_t* d_filter, | |||
| const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, | |||
| int* workspace, const convolution::ConvParam& param, | |||
| uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||
| float scale, const GemmCoord& threadblock_shape, | |||
| const GemmCoord& warp_shape, int stages, cudaStream_t stream) { | |||
| #define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||
| threadblock_k_, warp_m_, warp_n_, \ | |||
| warp_k_, stage_, aligned_) \ | |||
| if (threadblock_shape.m() == threadblock_m_ && \ | |||
| threadblock_shape.n() == threadblock_n_ && \ | |||
| threadblock_shape.k() == threadblock_k_ && \ | |||
| warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||
| warp_shape.k() == warp_k_ && stages == stage_) { \ | |||
| using ThreadBlockShape = \ | |||
| cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||
| threadblock_k_>; \ | |||
| using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||
| using InstructionShape = cutlass::gemm::GemmShape<1, 1, 4>; \ | |||
| using Convolution = cutlass::conv::device::Convolution< \ | |||
| int8_t, cutlass::layout::TensorNCxHWx<4>, int8_t, \ | |||
| cutlass::layout::TensorCxRSKx<4>, ElementOutput, \ | |||
| cutlass::layout::TensorNCxHWx<4>, int32_t, \ | |||
| cutlass::layout::TensorNCxHWx<4>, int32_t, \ | |||
| cutlass::conv::ConvType::kConvolution, \ | |||
| cutlass::arch::OpClassSimt, cutlass::arch::Sm61, \ | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||
| cutlass::conv::threadblock:: \ | |||
| ConvolutionFpropNCxHWxThreadblockSwizzle, \ | |||
| stage_, 4, aligned_, NeedLoadFromConstMem, \ | |||
| cutlass::arch::OpMultiplyAdd>; \ | |||
| typename Convolution::ConvolutionParameter conv_param( \ | |||
| param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||
| param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||
| param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||
| return cutlass_convolution_wrapper<Convolution>( \ | |||
| d_src, d_filter, d_bias, d_z, d_dst, workspace, conv_param, \ | |||
| epilogue, stream); \ | |||
| } | |||
| #define DISPATCH_KERNEL \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 32, 64, 32, 32, 2, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 2, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 128, 16, 1, 8); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 2, 4); \ | |||
| megdnn_assert(false, \ | |||
| "unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||
| "(%dx%dx%d)", \ | |||
| threadblock_shape.m(), threadblock_shape.n(), \ | |||
| threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||
| warp_shape.k()); | |||
| using ElementOutput = int8_t; | |||
| using ElementAccumulator = int32_t; | |||
| using ElementBias = int32_t; | |||
| using ElementCompute = float; | |||
| using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; | |||
| switch (nonlinear_mode) { | |||
| case NonlineMode::IDENTITY: { | |||
| using EpilogueOp = | |||
| cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||
| ElementOutput, 4, ElementAccumulator, ElementBias, | |||
| ElementCompute>; | |||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma}; | |||
| DISPATCH_KERNEL; | |||
| } | |||
| case NonlineMode::RELU: { | |||
| using EpilogueOp = cutlass::epilogue::thread:: | |||
| BiasAddLinearCombinationReluClamp< | |||
| ElementOutput, 4, ElementAccumulator, ElementBias, | |||
| ElementCompute>; | |||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0}; | |||
| DISPATCH_KERNEL; | |||
| } | |||
| case NonlineMode::H_SWISH: { | |||
| using EpilogueOp = cutlass::epilogue::thread:: | |||
| BiasAddLinearCombinationHSwishClamp< | |||
| ElementOutput, 4, ElementAccumulator, ElementBias, | |||
| ElementCompute>; | |||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale}; | |||
| DISPATCH_KERNEL; | |||
| } | |||
| default: | |||
| megdnn_assert(false, | |||
| "unsupported nonlinear mode for conv bias operator"); | |||
| } | |||
| #undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||
| #undef DISPATCH_KERNEL | |||
| } | |||
| #endif | |||
| #define INST(need_load_from_const_mem) \ | |||
| template void megdnn::cuda::cutlass_wrapper:: \ | |||
| do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4< \ | |||
| need_load_from_const_mem>( \ | |||
| const int8_t* d_src, const int8_t* d_filter, \ | |||
| const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \ | |||
| int* workspace, const convolution::ConvParam& param, \ | |||
| uint32_t nonlinear_mode, float alpha, float beta, \ | |||
| float gamma, float scale, \ | |||
| const GemmCoord& threadblock_shape, \ | |||
| const GemmCoord& warp_shape, int stages, \ | |||
| cudaStream_t stream); | |||
| INST(true); | |||
| INST(false); | |||
| #undef INST | |||
| /* ====== cutlass kernel wrapper for int8 nchw4 layout and nchw output ====== */ | |||
| #if MEGDNN_TEGRA_X1 | |||
| template <bool NeedLoadFromConstMem> | |||
| void megdnn::cuda::cutlass_wrapper:: | |||
| do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nchw( | |||
| const int8_t* /* d_src */, const int8_t* /* d_filter */, | |||
| const float* /* d_bias */, const float* /* d_z */, | |||
| float* /* d_dst */, int* /* workspace */, | |||
| const convolution::ConvParam& /* param */, | |||
| uint32_t /* nonlinear_mode */, float /* alpha */, | |||
| float /* beta */, float /* gamma */, float /* scale */, | |||
| const GemmCoord& /* threadblock_shape */, | |||
| const GemmCoord& /* warp_shape */, int /* stages */, | |||
| cudaStream_t /* stream */) {} | |||
| #else | |||
| template <bool NeedLoadFromConstMem> | |||
| void megdnn::cuda::cutlass_wrapper:: | |||
| do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nchw( | |||
| const int8_t* d_src, const int8_t* d_filter, | |||
| const float* d_bias, const float* d_z, float* d_dst, | |||
| int* workspace, const convolution::ConvParam& param, | |||
| uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||
| float scale, const GemmCoord& threadblock_shape, | |||
| const GemmCoord& warp_shape, int stages, cudaStream_t stream) { | |||
| #define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||
| threadblock_k_, warp_m_, warp_n_, \ | |||
| warp_k_, stages_, aligned_) \ | |||
| if (threadblock_shape.m() == threadblock_m_ && \ | |||
| threadblock_shape.n() == threadblock_n_ && \ | |||
| threadblock_shape.k() == threadblock_k_ && \ | |||
| warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||
| warp_shape.k() == warp_k_ && stages == stages_) { \ | |||
| using ThreadBlockShape = \ | |||
| cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||
| threadblock_k_>; \ | |||
| using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||
| using InstructionShape = cutlass::gemm::GemmShape<1, 1, 4>; \ | |||
| using Convolution = cutlass::conv::device::Convolution< \ | |||
| int8_t, cutlass::layout::TensorNCxHWx<4>, int8_t, \ | |||
| cutlass::layout::TensorCxRSKx<4>, ElementOutput, \ | |||
| cutlass::layout::TensorNCHW, float, \ | |||
| cutlass::layout::TensorNCHW, int32_t, \ | |||
| cutlass::conv::ConvType::kConvolution, \ | |||
| cutlass::arch::OpClassSimt, cutlass::arch::Sm61, \ | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||
| cutlass::conv::threadblock:: \ | |||
| ConvolutionFpropNCxHWxThreadblockSwizzle, \ | |||
| stages_, 4, aligned_, NeedLoadFromConstMem, \ | |||
| cutlass::arch::OpMultiplyAdd>; \ | |||
| typename Convolution::ConvolutionParameter conv_param( \ | |||
| param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||
| param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||
| param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||
| return cutlass_convolution_wrapper<Convolution>( \ | |||
| d_src, d_filter, d_bias, d_z, d_dst, workspace, conv_param, \ | |||
| epilogue, stream); \ | |||
| } | |||
| #define DISPATCH_KERNEL \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 32, 64, 32, 32, 2, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 2, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 128, 16, 1, 8); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 2, 4); \ | |||
| megdnn_assert(false, \ | |||
| "unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||
| "(%dx%dx%d)", \ | |||
| threadblock_shape.m(), threadblock_shape.n(), \ | |||
| threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||
| warp_shape.k()); | |||
| using ElementOutput = float; | |||
| using ElementAccumulator = int32_t; | |||
| using ElementBias = float; | |||
| using ElementCompute = float; | |||
| using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; | |||
| switch (nonlinear_mode) { | |||
| case NonlineMode::IDENTITY: { | |||
| using EpilogueOp = | |||
| cutlass::epilogue::thread::BiasAddLinearCombination< | |||
| ElementOutput, 1, ElementAccumulator, ElementBias, | |||
| ElementCompute>; | |||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma}; | |||
| DISPATCH_KERNEL; | |||
| } | |||
| case NonlineMode::RELU: { | |||
| using EpilogueOp = | |||
| cutlass::epilogue::thread::BiasAddLinearCombinationRelu< | |||
| ElementOutput, 1, ElementAccumulator, ElementBias, | |||
| ElementCompute>; | |||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0}; | |||
| DISPATCH_KERNEL; | |||
| } | |||
| case NonlineMode::H_SWISH: { | |||
| using EpilogueOp = | |||
| cutlass::epilogue::thread::BiasAddLinearCombinationHSwish< | |||
| ElementOutput, 1, ElementAccumulator, ElementBias, | |||
| ElementCompute>; | |||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale}; | |||
| DISPATCH_KERNEL; | |||
| } | |||
| default: | |||
| megdnn_assert(false, | |||
| "unsupported nonlinear mode for conv bias operator"); | |||
| } | |||
| #undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||
| #undef DISPATCH_KERNEL | |||
| } | |||
| #endif | |||
| #define INST(need_load_from_const_mem) \ | |||
| template void megdnn::cuda::cutlass_wrapper:: \ | |||
| do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nchw< \ | |||
| need_load_from_const_mem>( \ | |||
| const int8_t* d_src, const int8_t* d_filter, \ | |||
| const float* d_bias, const float* d_z, float* d_dst, \ | |||
| int* workspace, const convolution::ConvParam& param, \ | |||
| uint32_t nonlinear_mode, float alpha, float beta, \ | |||
| float gamma, float scale, \ | |||
| const GemmCoord& threadblock_shape, \ | |||
| const GemmCoord& warp_shape, int stages, \ | |||
| cudaStream_t stream); | |||
| INST(true); | |||
| INST(false); | |||
| #undef INST | |||
| /* ===== cutlass kernel wrapper for int8 nchw4 layout and nchw32 output ===== */ | |||
| #if MEGDNN_TEGRA_X1 | |||
| template <bool NeedLoadFromConstMem> | |||
| void megdnn::cuda::cutlass_wrapper:: | |||
| do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32( | |||
| const int8_t* /* d_src */, const int8_t* /* d_filter */, | |||
| const int32_t* /* d_bias */, const int8_t* /* d_z */, | |||
| int8_t* /* d_dst */, int* /* workspace */, | |||
| const convolution::ConvParam& /* param */, | |||
| uint32_t /* nonlinear_mode */, float /* alpha */, | |||
| float /* beta */, float /* gamma */, float /* scale */, | |||
| const GemmCoord& /* threadblock_shape */, | |||
| const GemmCoord& /* warp_shape */, int /* stages */, | |||
| cudaStream_t /* stream */) {} | |||
| #else | |||
| template <bool NeedLoadFromConstMem> | |||
| void megdnn::cuda::cutlass_wrapper:: | |||
| do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32( | |||
| const int8_t* d_src, const int8_t* d_filter, | |||
| const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, | |||
| int* workspace, const convolution::ConvParam& param, | |||
| uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||
| float scale, const GemmCoord& threadblock_shape, | |||
| const GemmCoord& warp_shape, int stages, cudaStream_t stream) { | |||
| #define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||
| threadblock_k_, warp_m_, warp_n_, \ | |||
| warp_k_, stages_, aligned_) \ | |||
| if (threadblock_shape.m() == threadblock_m_ && \ | |||
| threadblock_shape.n() == threadblock_n_ && \ | |||
| threadblock_shape.k() == threadblock_k_ && \ | |||
| warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||
| warp_shape.k() == warp_k_ && stages == stages_) { \ | |||
| using ThreadBlockShape = \ | |||
| cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||
| threadblock_k_>; \ | |||
| using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||
| using InstructionShape = cutlass::gemm::GemmShape<1, 1, 4>; \ | |||
| using Convolution = cutlass::conv::device::Convolution< \ | |||
| int8_t, cutlass::layout::TensorNCxHWx<4>, int8_t, \ | |||
| cutlass::layout::TensorCxRSKx<4>, ElementOutput, \ | |||
| cutlass::layout::TensorNCxHWx<32>, int32_t, \ | |||
| cutlass::layout::TensorNCxHWx<32>, int32_t, \ | |||
| cutlass::conv::ConvType::kConvolution, \ | |||
| cutlass::arch::OpClassSimt, cutlass::arch::Sm61, \ | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||
| cutlass::conv::threadblock:: \ | |||
| ConvolutionFpropNCxHWxThreadblockSwizzle, \ | |||
| stages_, 4, aligned_, NeedLoadFromConstMem, \ | |||
| cutlass::arch::OpMultiplyAdd>; \ | |||
| typename Convolution::ConvolutionParameter conv_param( \ | |||
| param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||
| param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||
| param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||
| return cutlass_convolution_wrapper<Convolution>( \ | |||
| d_src, d_filter, d_bias, d_z, d_dst, workspace, conv_param, \ | |||
| epilogue, stream); \ | |||
| } | |||
| #define DISPATCH_KERNEL \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 32, 64, 32, 32, 2, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 2, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \ | |||
| megdnn_assert(false, \ | |||
| "unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||
| "(%dx%dx%d)", \ | |||
| threadblock_shape.m(), threadblock_shape.n(), \ | |||
| threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||
| warp_shape.k()); | |||
| using ElementOutput = int8_t; | |||
| using ElementAccumulator = int32_t; | |||
| using ElementBias = int32_t; | |||
| using ElementCompute = float; | |||
| using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; | |||
| switch (nonlinear_mode) { | |||
| case NonlineMode::IDENTITY: { | |||
| using EpilogueOp = | |||
| cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||
| ElementOutput, 4, ElementAccumulator, ElementBias, | |||
| ElementCompute>; | |||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma}; | |||
| DISPATCH_KERNEL; | |||
| } | |||
| case NonlineMode::RELU: { | |||
| using EpilogueOp = cutlass::epilogue::thread:: | |||
| BiasAddLinearCombinationReluClamp< | |||
| ElementOutput, 4, ElementAccumulator, ElementBias, | |||
| ElementCompute>; | |||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0}; | |||
| DISPATCH_KERNEL; | |||
| } | |||
| case NonlineMode::H_SWISH: { | |||
| using EpilogueOp = cutlass::epilogue::thread:: | |||
| BiasAddLinearCombinationHSwishClamp< | |||
| ElementOutput, 4, ElementAccumulator, ElementBias, | |||
| ElementCompute>; | |||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale}; | |||
| DISPATCH_KERNEL; | |||
| } | |||
| default: | |||
| megdnn_assert(false, | |||
| "unsupported nonlinear mode for conv bias operator"); | |||
| } | |||
| #undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||
| #undef DISPATCH_KERNEL | |||
| } | |||
| #endif | |||
| #define INST(need_load_from_const_mem) \ | |||
| template void megdnn::cuda::cutlass_wrapper:: \ | |||
| do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32< \ | |||
| need_load_from_const_mem>( \ | |||
| const int8_t* d_src, const int8_t* d_filter, \ | |||
| const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \ | |||
| int* workspace, const convolution::ConvParam& param, \ | |||
| uint32_t nonlinear_mode, float alpha, float beta, \ | |||
| float gamma, float scale, \ | |||
| const GemmCoord& threadblock_shape, \ | |||
| const GemmCoord& warp_shape, int stages, \ | |||
| cudaStream_t stream); | |||
| INST(true); | |||
| INST(false); | |||
| #undef INST | |||
| /* ===== cutlass kernel wrapper for nchw4 layout and nhwc output ===== */ | |||
| #if MEGDNN_TEGRA_X1 | |||
| template <bool signedness> | |||
| void megdnn::cuda::cutlass_wrapper:: | |||
| do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nhwc( | |||
| const int8_t* /* d_src */, const int8_t* /* d_filter */, | |||
| const int32_t* /* d_bias */, const int8_t* /* d_z */, | |||
| int8_t* /* d_dst */, int* /* workspace */, | |||
| const convolution::ConvParam& /* param */, | |||
| uint32_t /* nonlinear_mode */, float /* alpha */, | |||
| float /* beta */, float /* gamma */, float /* delta */, | |||
| float /* theta */, float /* scale */, | |||
| const GemmCoord& /* threadblock_shape */, | |||
| const GemmCoord& /* warp_shape */, int /* stages */, | |||
| cudaStream_t /* stream */) {} | |||
| #else | |||
| template <bool signedness> | |||
| void megdnn::cuda::cutlass_wrapper:: | |||
| do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nhwc( | |||
| const int8_t* d_src, const int8_t* d_filter, | |||
| const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, | |||
| int* workspace, const convolution::ConvParam& param, | |||
| uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||
| float delta, float theta, float scale, | |||
| const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | |||
| int stages, cudaStream_t stream) { | |||
| #define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||
| threadblock_k_, warp_m_, warp_n_, \ | |||
| warp_k_, stages_, aligned_) \ | |||
| if (threadblock_shape.m() == threadblock_m_ && \ | |||
| threadblock_shape.n() == threadblock_n_ && \ | |||
| threadblock_shape.k() == threadblock_k_ && \ | |||
| warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||
| warp_shape.k() == warp_k_ && stages == stages_) { \ | |||
| using ThreadBlockShape = \ | |||
| cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||
| threadblock_k_>; \ | |||
| using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||
| using InstructionShape = cutlass::gemm::GemmShape<1, 1, 4>; \ | |||
| using Convolution = cutlass::conv::device::Convolution< \ | |||
| int8_t, cutlass::layout::TensorNCxHWx<4>, int8_t, \ | |||
| cutlass::layout::TensorCxRSKx<4>, ElementOutput, \ | |||
| cutlass::layout::TensorNHWC, int32_t, \ | |||
| cutlass::layout::TensorNHWC, int32_t, \ | |||
| cutlass::conv::ConvType::kConvolution, \ | |||
| cutlass::arch::OpClassSimt, cutlass::arch::Sm75, \ | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||
| cutlass::conv::threadblock:: \ | |||
| ConvolutionFpropNCxHWxThreadblockSwizzle, \ | |||
| stages_, 4, aligned_, true, cutlass::arch::OpMultiplyAdd>; \ | |||
| typename Convolution::ConvolutionParameter conv_param( \ | |||
| param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||
| param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||
| param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||
| return cutlass_convolution_wrapper<Convolution>( \ | |||
| d_src, d_filter, d_bias, \ | |||
| reinterpret_cast<const ElementOutput*>(d_z), \ | |||
| reinterpret_cast<ElementOutput*>(d_dst), workspace, \ | |||
| conv_param, epilogue, stream); \ | |||
| } | |||
| #define DISPATCH_KERNEL \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 32, 64, 32, 32, 2, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 2, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 128, 16, 1, 8); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 2, 4); \ | |||
| megdnn_assert(false, \ | |||
| "unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||
| "(%dx%dx%d)", \ | |||
| threadblock_shape.m(), threadblock_shape.n(), \ | |||
| threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||
| warp_shape.k()); | |||
| using ElementOutput = cutlass::integer_subbyte<4, signedness>; | |||
| using ElementAccumulator = int32_t; | |||
| using ElementBias = int32_t; | |||
| using ElementCompute = float; | |||
| using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; | |||
| switch (nonlinear_mode) { | |||
| case NonlineMode::IDENTITY: { | |||
| using EpilogueOp = | |||
| cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||
| ElementOutput, 8, ElementAccumulator, ElementBias, | |||
| ElementCompute>; | |||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, | |||
| delta + theta}; | |||
| DISPATCH_KERNEL; | |||
| } | |||
| case NonlineMode::RELU: { | |||
| using EpilogueOp = cutlass::epilogue::thread:: | |||
| BiasAddLinearCombinationReluClamp< | |||
| ElementOutput, 8, ElementAccumulator, ElementBias, | |||
| ElementCompute>; | |||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, | |||
| 0, delta, theta}; | |||
| DISPATCH_KERNEL; | |||
| } | |||
| case NonlineMode::H_SWISH: { | |||
| using EpilogueOp = cutlass::epilogue::thread:: | |||
| BiasAddLinearCombinationHSwishClamp< | |||
| ElementOutput, 8, ElementAccumulator, ElementBias, | |||
| ElementCompute>; | |||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, | |||
| scale, delta, theta}; | |||
| DISPATCH_KERNEL; | |||
| } | |||
| default: | |||
| megdnn_assert(false, | |||
| "unsupported nonlinear mode for conv bias operator"); | |||
| } | |||
| #undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||
| #undef DISPATCH_KERNEL | |||
| } | |||
| #endif | |||
| #define INST(signedness) \ | |||
| template void megdnn::cuda::cutlass_wrapper:: \ | |||
| do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nhwc<signedness>( \ | |||
| const int8_t* d_src, const int8_t* d_filter, \ | |||
| const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \ | |||
| int* workspace, const convolution::ConvParam& param, \ | |||
| uint32_t nonlinear_mode, float alpha, float beta, \ | |||
| float gamma, float delta, float theta, float scale, \ | |||
| const GemmCoord& threadblock_shape, \ | |||
| const GemmCoord& warp_shape, int stages, \ | |||
| cudaStream_t stream); | |||
| INST(true); | |||
| INST(false); | |||
| #undef INST | |||
| // vim: syntax=cuda.doxygen | |||
| @@ -1,65 +0,0 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/cuda/conv_bias/int8/implicit_gemm_conv_bias_cutlass_wrapper.cuinl | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "cutlass/convolution/device/convolution.h" | |||
| #include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | |||
| using namespace megdnn; | |||
| using namespace cuda; | |||
| using namespace cutlass_wrapper; | |||
| template <typename Convolution> | |||
| void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper( | |||
| const typename Convolution::ElementSrc* d_src, | |||
| const typename Convolution::ElementFilter* d_filter, | |||
| const typename Convolution::ElementBias* d_bias, | |||
| const typename Convolution::ElementDst* d_z, | |||
| typename Convolution::ElementDst* d_dst, int* workspace, | |||
| typename Convolution::ConvolutionParameter const& conv_param, | |||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||
| cudaStream_t stream, typename Convolution::ExtraParam extra_param) { | |||
| typename Convolution::TensorRefSrc tensor_src{ | |||
| const_cast<typename Convolution::ElementSrc*>(d_src), | |||
| Convolution::LayoutSrc::packed( | |||
| {conv_param.N, conv_param.H, conv_param.W, conv_param.C})}; | |||
| typename Convolution::TensorRefFilter tensor_filter{ | |||
| const_cast<typename Convolution::ElementFilter*>(d_filter), | |||
| Convolution::LayoutFilter::packed( | |||
| {conv_param.K, conv_param.R, conv_param.S, conv_param.C})}; | |||
| typename Convolution::TensorRefBias tensor_bias{ | |||
| const_cast<typename Convolution::ElementBias*>(d_bias), | |||
| Convolution::LayoutBias::packed({1, 1, 1, conv_param.K})}; | |||
| typename Convolution::TensorRefDst tensor_z{ | |||
| const_cast<typename Convolution::ElementDst*>(d_z), | |||
| Convolution::LayoutDst::packed( | |||
| {conv_param.N, conv_param.P, conv_param.Q, conv_param.K})}; | |||
| typename Convolution::TensorRefDst tensor_dst{ | |||
| d_dst, | |||
| Convolution::LayoutDst::packed( | |||
| {conv_param.N, conv_param.P, conv_param.Q, conv_param.K})}; | |||
| typename Convolution::Arguments arguments{conv_param, | |||
| tensor_src.non_const_ref(), | |||
| tensor_filter.non_const_ref(), | |||
| tensor_bias.non_const_ref(), | |||
| tensor_z.non_const_ref(), | |||
| tensor_dst.non_const_ref(), | |||
| epilogue, | |||
| {}, | |||
| {}, | |||
| extra_param}; | |||
| Convolution conv_op; | |||
| cutlass_check(conv_op.initialize(arguments, workspace)); | |||
| cutlass_check(conv_op(stream)); | |||
| after_kernel_launch(); | |||
| } | |||
| // vim: syntax=cuda.doxygen | |||
| @@ -10,8 +10,7 @@ | |||
| * implied. | |||
| */ | |||
| #include "./algo.h" | |||
| #include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | |||
| #include "src/cuda/conv_bias/algo.h" | |||
| using namespace megdnn; | |||
| using namespace cuda; | |||
| @@ -81,29 +80,6 @@ ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm::get_constants( | |||
| return {alpha, beta, gamma, delta, theta}; | |||
| } | |||
| void ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm::do_exec( | |||
| const ExecArgs& args, void* filter_ptr, void* bias_ptr, void* z_ptr, | |||
| ConvParam kern_param, uint32_t nonlinear_mode, float alpha, float beta, | |||
| float gamma, float delta, float theta, cudaStream_t stream) const { | |||
| float dst_scale = args.dst_layout->dtype.param<dtype::QuantizedS4>().scale; | |||
| cutlass_wrapper::GemmCoord threadblock_shape{m_algo_param.threadblock_m, | |||
| m_algo_param.threadblock_n, | |||
| m_algo_param.threadblock_k}; | |||
| cutlass_wrapper::GemmCoord warp_shape{ | |||
| m_algo_param.warp_m, m_algo_param.warp_n, m_algo_param.warp_k}; | |||
| cutlass_wrapper::do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64< | |||
| true>(reinterpret_cast<int8_t*>(args.src_tensor->raw_ptr), | |||
| reinterpret_cast<int8_t*>(filter_ptr), | |||
| reinterpret_cast<int32_t*>(bias_ptr), | |||
| reinterpret_cast<int8_t*>(z_ptr), | |||
| reinterpret_cast<int8_t*>(args.dst_tensor->raw_ptr), nullptr, | |||
| kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale, | |||
| threadblock_shape, warp_shape, m_algo_param.stage, stream); | |||
| } | |||
| #endif | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -10,8 +10,7 @@ | |||
| * implied. | |||
| */ | |||
| #include "./algo.h" | |||
| #include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | |||
| #include "src/cuda/conv_bias/algo.h" | |||
| using namespace megdnn; | |||
| using namespace cuda; | |||
| @@ -81,42 +80,6 @@ ConvBiasForwardImpl::AlgoInt4Int4NHWCIMMAImplicitGemm::get_constants( | |||
| return {alpha, beta, gamma, delta, theta}; | |||
| } | |||
| void ConvBiasForwardImpl::AlgoInt4Int4NHWCIMMAImplicitGemm::do_exec( | |||
| const ExecArgs& args, void* filter_ptr, void* bias_ptr, void* z_ptr, | |||
| ConvParam kern_param, uint32_t nonlinear_mode, float alpha, float beta, | |||
| float gamma, float delta, float theta, cudaStream_t stream) const { | |||
| float dst_scale = args.dst_layout->dtype.param<dtype::QuantizedS4>().scale; | |||
| cutlass_wrapper::GemmCoord threadblock_shape{m_algo_param.threadblock_m, | |||
| m_algo_param.threadblock_n, | |||
| m_algo_param.threadblock_k}; | |||
| cutlass_wrapper::GemmCoord warp_shape{ | |||
| m_algo_param.warp_m, m_algo_param.warp_n, m_algo_param.warp_k}; | |||
| if (kern_param.fh == 1 && kern_param.fw == 1) { | |||
| cutlass_wrapper::do_conv_bias_int4_int4_implicit_gemm_imma_nhwc<false>( | |||
| reinterpret_cast<int8_t*>(args.src_tensor->raw_ptr), | |||
| reinterpret_cast<int8_t*>(filter_ptr), | |||
| reinterpret_cast<int32_t*>(bias_ptr), | |||
| reinterpret_cast<int8_t*>(z_ptr), | |||
| reinterpret_cast<int8_t*>(args.dst_tensor->raw_ptr), nullptr, | |||
| kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale, | |||
| threadblock_shape, warp_shape, m_algo_param.access_size, | |||
| m_algo_param.stage, stream); | |||
| } else { | |||
| cutlass_wrapper::do_conv_bias_int4_int4_implicit_gemm_imma_nhwc<true>( | |||
| reinterpret_cast<int8_t*>(args.src_tensor->raw_ptr), | |||
| reinterpret_cast<int8_t*>(filter_ptr), | |||
| reinterpret_cast<int32_t*>(bias_ptr), | |||
| reinterpret_cast<int8_t*>(z_ptr), | |||
| reinterpret_cast<int8_t*>(args.dst_tensor->raw_ptr), nullptr, | |||
| kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale, | |||
| threadblock_shape, warp_shape, m_algo_param.access_size, | |||
| m_algo_param.stage, stream); | |||
| } | |||
| } | |||
| #endif | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -10,10 +10,9 @@ | |||
| * implied. | |||
| */ | |||
| #include "./algo.h" | |||
| #include "src/common/conv_bias.h" | |||
| #include "src/cuda/conv_bias/algo.h" | |||
| #include "src/cuda/conv_bias/cutlass_reorder_filter.cuh" | |||
| #include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | |||
| #include "src/cuda/conv_bias/reduce_filter.cuh" | |||
| #include "src/cuda/convolution_helper/parameter.cuh" | |||
| #include "src/cuda/utils.h" | |||
| @@ -102,22 +101,40 @@ void ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase::exec( | |||
| if (args.z_layout->ndim > 0) | |||
| z_ptr = args.z_tensor->raw_ptr; | |||
| // \note these constants of cutlass epilogue will be passed to method | |||
| // `execute_cutlass_conv_op` by pointer and interpreted as ElementCompute*, | |||
| // a different dtype here results in undefined epilogue behaviors | |||
| float alpha, beta, gamma, delta, theta; | |||
| std::tie(alpha, beta, gamma, delta, theta) = get_constants(args); | |||
| float dst_scale = 0.f; | |||
| float threshold = 0.f; | |||
| uint8_t src_zero = 0; | |||
| bool load_from_const = !(fh == 1 && fw == 1); | |||
| bool without_shared_load = true; | |||
| if (args.dst_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm) { | |||
| dst_scale = | |||
| args.dst_layout->dtype.param<dtype::Quantized4Asymm>().scale; | |||
| src_zero = args.src_layout->dtype.param<dtype::Quantized4Asymm>() | |||
| .zero_point; | |||
| } else { // DTypeEnum::QuantizedS4 | |||
| dst_scale = args.dst_layout->dtype.param<dtype::QuantizedS4>().scale; | |||
| } | |||
| ConvParam kern_param; | |||
| kern_param.n = n, kern_param.co = co, kern_param.ci = ci, | |||
| kern_param.hi = hi, kern_param.wi = wi, kern_param.ho = ho, | |||
| kern_param.wo = wo, kern_param.ph = ph, kern_param.pw = pw, | |||
| kern_param.sh = sh, kern_param.sw = sw, kern_param.fh = fh, | |||
| kern_param.fw = fw; | |||
| cudaStream_t stream = cuda_stream(args.opr->handle()); | |||
| uint32_t nonlinear_mode = static_cast<uint32_t>(param.nonlineMode); | |||
| const auto* op = get_cutlass_conv_op(args, ConvOperator::kFprop, | |||
| ConvType::kConvolution, | |||
| load_from_const, without_shared_load); | |||
| cudaStream_t stream = cuda_stream(args.opr->handle()); | |||
| execute_cutlass_conv_op(op, args.src_tensor->raw_ptr, filter_ptr, bias_ptr, | |||
| z_ptr, args.dst_tensor->raw_ptr, nullptr, n, hi, wi, | |||
| ci, co, fh, fw, ho, wo, ph, pw, sh, sw, dh, dw, | |||
| &alpha, &beta, &gamma, &delta, &theta, &threshold, | |||
| &dst_scale, stream, &src_zero); | |||
| do_exec(args, filter_ptr, bias_ptr, z_ptr, kern_param, nonlinear_mode, | |||
| alpha, beta, gamma, delta, theta, stream); | |||
| after_kernel_launch(); | |||
| } | |||
| std::string ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase::to_string( | |||
| @@ -10,10 +10,9 @@ | |||
| * implied. | |||
| */ | |||
| #include "./algo.h" | |||
| #include "src/common/conv_bias.h" | |||
| #include "src/cuda/conv_bias/algo.h" | |||
| #include "src/cuda/conv_bias/cutlass_reorder_filter.cuh" | |||
| #include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | |||
| #include "src/cuda/conv_bias/reduce_filter.cuh" | |||
| #include "src/cuda/convolution_helper/parameter.cuh" | |||
| #include "src/cuda/utils.h" | |||
| @@ -109,22 +108,43 @@ void ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::exec( | |||
| if (args.z_layout->ndim > 0) | |||
| z_ptr = args.z_tensor->raw_ptr; | |||
| // \note these constants of cutlass epilogue will be passed to method | |||
| // `execute_cutlass_conv_op` by pointer and interpreted as ElementCompute*, | |||
| // a different dtype here results in undefined epilogue behaviors | |||
| float alpha, beta, gamma, delta, theta; | |||
| std::tie(alpha, beta, gamma, delta, theta) = get_constants(args); | |||
| float dst_scale = 0.f; | |||
| float threshold = 0.f; | |||
| uint8_t src_zero = 0; | |||
| bool load_from_const = !(fh == 1 && fw == 1); | |||
| bool without_shared_load = ((co % m_algo_param.threadblock_n == 0) && | |||
| (m_algo_param.threadblock_n == 32 || | |||
| m_algo_param.threadblock_n == 64)); | |||
| if (args.dst_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm) { | |||
| dst_scale = | |||
| args.dst_layout->dtype.param<dtype::Quantized4Asymm>().scale; | |||
| src_zero = args.src_layout->dtype.param<dtype::Quantized4Asymm>() | |||
| .zero_point; | |||
| } else { // DTypeEnum::QuantizedS4 | |||
| dst_scale = args.dst_layout->dtype.param<dtype::QuantizedS4>().scale; | |||
| } | |||
| ConvParam kern_param; | |||
| kern_param.n = n, kern_param.co = co, kern_param.ci = ci, | |||
| kern_param.hi = hi, kern_param.wi = wi, kern_param.ho = ho, | |||
| kern_param.wo = wo, kern_param.ph = ph, kern_param.pw = pw, | |||
| kern_param.sh = sh, kern_param.sw = sw, kern_param.fh = fh, | |||
| kern_param.fw = fw; | |||
| cudaStream_t stream = cuda_stream(args.opr->handle()); | |||
| uint32_t nonlinear_mode = static_cast<uint32_t>(param.nonlineMode); | |||
| const auto* op = get_cutlass_conv_op(args, ConvOperator::kFprop, | |||
| ConvType::kConvolution, | |||
| load_from_const, without_shared_load); | |||
| cudaStream_t stream = cuda_stream(args.opr->handle()); | |||
| execute_cutlass_conv_op(op, args.src_tensor->raw_ptr, filter_ptr, bias_ptr, | |||
| z_ptr, args.dst_tensor->raw_ptr, nullptr, n, hi, wi, | |||
| ci, co, fh, fw, ho, wo, ph, pw, sh, sw, dh, dw, | |||
| &alpha, &beta, &gamma, &delta, &theta, &threshold, | |||
| &dst_scale, stream, &src_zero); | |||
| do_exec(args, filter_ptr, bias_ptr, z_ptr, kern_param, nonlinear_mode, | |||
| alpha, beta, gamma, delta, theta, stream); | |||
| after_kernel_launch(); | |||
| } | |||
| std::string ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::to_string( | |||
| @@ -10,12 +10,11 @@ | |||
| * implied. | |||
| */ | |||
| #include "./algo.h" | |||
| #include "src/common/conv_bias.h" | |||
| #include "src/cuda/conv_bias/algo.h" | |||
| #include "src/cuda/conv_bias/cutlass_reorder_filter.cuh" | |||
| #include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | |||
| #include "src/cuda/convolution_helper/parameter.cuh" | |||
| #include "src/cuda/utils.h" | |||
| #include "src/common/conv_bias.h" | |||
| using namespace megdnn; | |||
| using namespace cuda; | |||
| @@ -38,8 +37,7 @@ bool ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::is_available( | |||
| bool available = true; | |||
| auto&& param = args.opr->param(); | |||
| auto&& fm = args.filter_meta; | |||
| if (!check_bias_share_in_channel(*(args.bias_layout), | |||
| param.format)) | |||
| if (!check_bias_share_in_channel(*(args.bias_layout), param.format)) | |||
| return false; | |||
| if (param.format != Format::NCHW32 && param.format != Format::NCHW32_NCHW4) | |||
| return false; | |||
| @@ -137,19 +135,16 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec( | |||
| args.preprocessed_filter->tensors[0].raw_ptr); | |||
| } | |||
| ConvParam kern_param; | |||
| kern_param.n = n, kern_param.co = co, kern_param.ci = ci, | |||
| kern_param.hi = hi, kern_param.wi = wi, kern_param.ho = ho, | |||
| kern_param.wo = wo, kern_param.ph = ph, kern_param.pw = pw, | |||
| kern_param.sh = sh, kern_param.sw = sw, kern_param.fh = fh, | |||
| kern_param.fw = fw; | |||
| float src_scale = args.src_layout->dtype.param<dtype::QuantizedS8>().scale, | |||
| filter_scale = | |||
| args.filter_layout->dtype.param<dtype::QuantizedS8>().scale, | |||
| bias_scale = | |||
| args.bias_layout->dtype.param<dtype::QuantizedS32>().scale, | |||
| dst_scale = args.dst_layout->dtype.param<dtype::QuantizedS8>().scale; | |||
| // \note these constants of cutlass epilogue will be passed to method | |||
| // `execute_cutlass_conv_op` by pointer and interpreted as ElementCompute*, | |||
| // a different dtype here results in undefined epilogue behaviors | |||
| float alpha = src_scale * filter_scale / dst_scale, | |||
| beta = bias_scale / dst_scale; | |||
| int8_t* z_dev_ptr = nullptr; | |||
| @@ -159,80 +154,20 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec( | |||
| float z_scale = args.z_layout->dtype.param<dtype::QuantizedS8>().scale; | |||
| gamma = z_scale / dst_scale; | |||
| } | |||
| uint32_t nonlinear_mode = static_cast<uint32_t>(param.nonlineMode); | |||
| if (fh == 1 && fw == 1) { | |||
| if (param.format == Format::NCHW32) { | |||
| cutlass_wrapper::do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32< | |||
| false>( | |||
| args.src_tensor->compatible_ptr<int8_t>(), filter_ptr, | |||
| args.bias_tensor->compatible_ptr<int32_t>(), z_dev_ptr, | |||
| args.dst_tensor->compatible_ptr<int8_t>(), nullptr, | |||
| kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale, | |||
| cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, | |||
| m_algo_param.threadblock_n, | |||
| m_algo_param.threadblock_k}, | |||
| cutlass_wrapper::GemmCoord{m_algo_param.warp_m, | |||
| m_algo_param.warp_n, | |||
| m_algo_param.warp_k}, | |||
| m_algo_param.stage, stream); | |||
| } else { | |||
| megdnn_assert(param.format == Format::NCHW32_NCHW4); | |||
| cutlass_wrapper:: | |||
| do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4< | |||
| false>( | |||
| args.src_tensor->compatible_ptr<int8_t>(), | |||
| filter_ptr, | |||
| args.bias_tensor->compatible_ptr<int32_t>(), | |||
| z_dev_ptr, | |||
| args.dst_tensor->compatible_ptr<int8_t>(), nullptr, | |||
| kern_param, nonlinear_mode, alpha, beta, gamma, | |||
| dst_scale, | |||
| cutlass_wrapper::GemmCoord{ | |||
| m_algo_param.threadblock_m, | |||
| m_algo_param.threadblock_n, | |||
| m_algo_param.threadblock_k}, | |||
| cutlass_wrapper::GemmCoord{m_algo_param.warp_m, | |||
| m_algo_param.warp_n, | |||
| m_algo_param.warp_k}, | |||
| m_algo_param.stage, stream); | |||
| } | |||
| } else { | |||
| if (param.format == Format::NCHW32) { | |||
| cutlass_wrapper::do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32< | |||
| true>( | |||
| args.src_tensor->compatible_ptr<int8_t>(), filter_ptr, | |||
| args.bias_tensor->compatible_ptr<int32_t>(), z_dev_ptr, | |||
| args.dst_tensor->compatible_ptr<int8_t>(), nullptr, | |||
| kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale, | |||
| cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, | |||
| m_algo_param.threadblock_n, | |||
| m_algo_param.threadblock_k}, | |||
| cutlass_wrapper::GemmCoord{m_algo_param.warp_m, | |||
| m_algo_param.warp_n, | |||
| m_algo_param.warp_k}, | |||
| m_algo_param.stage, stream); | |||
| } else { | |||
| megdnn_assert(param.format == Format::NCHW32_NCHW4); | |||
| cutlass_wrapper:: | |||
| do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4< | |||
| true>( | |||
| args.src_tensor->compatible_ptr<int8_t>(), | |||
| filter_ptr, | |||
| args.bias_tensor->compatible_ptr<int32_t>(), | |||
| z_dev_ptr, | |||
| args.dst_tensor->compatible_ptr<int8_t>(), nullptr, | |||
| kern_param, nonlinear_mode, alpha, beta, gamma, | |||
| dst_scale, | |||
| cutlass_wrapper::GemmCoord{ | |||
| m_algo_param.threadblock_m, | |||
| m_algo_param.threadblock_n, | |||
| m_algo_param.threadblock_k}, | |||
| cutlass_wrapper::GemmCoord{m_algo_param.warp_m, | |||
| m_algo_param.warp_n, | |||
| m_algo_param.warp_k}, | |||
| m_algo_param.stage, stream); | |||
| } | |||
| } | |||
| float delta = 0.f, theta = 0.f, threshold = 0.f; | |||
| bool load_from_const = !(fh == 1 && fw == 1); | |||
| bool without_shared_load = (param.format == Format::NCHW32); | |||
| const auto* op = get_cutlass_conv_op(args, ConvOperator::kFprop, | |||
| ConvType::kConvolution, | |||
| load_from_const, without_shared_load); | |||
| execute_cutlass_conv_op( | |||
| op, args.src_tensor->raw_ptr, filter_ptr, args.bias_tensor->raw_ptr, | |||
| z_dev_ptr, args.dst_tensor->raw_ptr, nullptr, n, hi, wi, ci, co, fh, | |||
| fw, ho, wo, ph, pw, sh, sw, dh, dw, &alpha, &beta, &gamma, &delta, | |||
| &theta, &threshold, &dst_scale, stream); | |||
| after_kernel_launch(); | |||
| } | |||
| @@ -249,9 +184,8 @@ size_t ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm:: | |||
| return 0_z; | |||
| } | |||
| SmallVector<TensorLayout> ConvBiasForwardImpl:: | |||
| AlgoInt8NCHW32IMMAImplicitGemm::deduce_preprocessed_filter_layout( | |||
| const SizeArgs& args) const { | |||
| SmallVector<TensorLayout> ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm:: | |||
| deduce_preprocessed_filter_layout(const SizeArgs& args) const { | |||
| return {args.filter_layout->collapse_contiguous()}; | |||
| } | |||
| @@ -6,14 +6,14 @@ | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "./algo.h" | |||
| #include "src/cuda/utils.h" | |||
| #include "src/cuda/convolution_helper/parameter.cuh" | |||
| #include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | |||
| #include "src/common/conv_bias.h" | |||
| #include "src/cuda/conv_bias/algo.h" | |||
| #include "src/cuda/convolution_helper/parameter.cuh" | |||
| #include "src/cuda/utils.h" | |||
| using namespace megdnn; | |||
| using namespace cuda; | |||
| @@ -34,8 +34,7 @@ bool ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::is_available( | |||
| bool available = true; | |||
| auto&& param = args.opr->param(); | |||
| auto&& fm = args.filter_meta; | |||
| if (!check_bias_share_in_channel(*(args.bias_layout), | |||
| param.format)) | |||
| if (!check_bias_share_in_channel(*(args.bias_layout), param.format)) | |||
| return false; | |||
| bool valid_format = param.format == Format::NCHW4_NCHW32 && | |||
| m_algo_param.threadblock_m % 32 == 0; | |||
| @@ -48,7 +47,8 @@ bool ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::is_available( | |||
| (args.dst_layout->dtype.enumv() == DTypeEnum::QuantizedS4 || | |||
| args.dst_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm); | |||
| valid_format |= param.format == Format::NCHW4; | |||
| if (!valid_format) return false; | |||
| if (!valid_format) | |||
| return false; | |||
| size_t n = args.src_layout->operator[](0), | |||
| ci = args.src_layout->operator[](1) * 4, | |||
| hi = args.src_layout->operator[](2), | |||
| @@ -170,16 +170,13 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec( | |||
| args.preprocessed_filter->tensors[0].raw_ptr); | |||
| } | |||
| convolution::ConvParam kern_param; | |||
| kern_param.n = n, kern_param.co = co, kern_param.ci = ci, | |||
| kern_param.hi = hi, kern_param.wi = wi, kern_param.ho = ho, | |||
| kern_param.wo = wo, kern_param.ph = ph, kern_param.pw = pw, | |||
| kern_param.sh = sh, kern_param.sw = sw, kern_param.fh = fh, | |||
| kern_param.fw = fw; | |||
| float src_scale = args.src_layout->dtype.param<dtype::QuantizedS8>().scale, | |||
| filter_scale = | |||
| args.filter_layout->dtype.param<dtype::QuantizedS8>().scale; | |||
| // \note these constants of cutlass epilogue will be passed to method | |||
| // `execute_cutlass_conv_op` by pointer and interpreted as ElementCompute*, | |||
| // a different dtype here results in undefined epilogue behaviors | |||
| float alpha = src_scale * filter_scale; | |||
| float beta = 1.f; | |||
| float dst_scale = 1.f; | |||
| @@ -192,13 +189,15 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec( | |||
| if (args.bias_layout->dtype.enumv() == DTypeEnum::QuantizedS32) { | |||
| megdnn_assert(args.dst_layout->dtype.category() == | |||
| DTypeCategory::QUANTIZED); | |||
| float bias_scale = args.bias_layout->dtype.param<dtype::QuantizedS32>() | |||
| .scale; | |||
| float bias_scale = | |||
| args.bias_layout->dtype.param<dtype::QuantizedS32>().scale; | |||
| dst_scale = get_scale(args.dst_layout->dtype); | |||
| alpha /= dst_scale, beta = bias_scale / dst_scale; | |||
| } | |||
| float delta = 0.f; | |||
| void* z_ptr = nullptr; | |||
| if (args.z_layout->ndim > 0) { | |||
| z_ptr = args.z_tensor->raw_ptr; | |||
| gamma = 1.f; | |||
| if (args.z_layout->dtype.category() == DTypeCategory::QUANTIZED) { | |||
| megdnn_assert(args.dst_layout->dtype.category() == | |||
| @@ -213,98 +212,20 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec( | |||
| delta = -z_zero * gamma; | |||
| } | |||
| } | |||
| uint32_t nonlinear_mode = static_cast<uint32_t>(param.nonlineMode); | |||
| bool nonunity_kernel = !(fh == 1 && fw == 1); | |||
| #define DISPATCH(_nonunity_kernel) \ | |||
| if (nonunity_kernel == _nonunity_kernel) { \ | |||
| cb(_nonunity_kernel) \ | |||
| } | |||
| if (param.format == Format::NCHW4) { | |||
| #define cb(_nonunity_kernel) \ | |||
| cutlass_wrapper::do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4< \ | |||
| _nonunity_kernel>( \ | |||
| args.src_tensor->compatible_ptr<int8_t>(), filter_ptr, \ | |||
| args.bias_tensor->compatible_ptr<int32_t>(), \ | |||
| args.z_tensor->compatible_ptr<int8_t>(), \ | |||
| args.dst_tensor->compatible_ptr<int8_t>(), nullptr, kern_param, \ | |||
| nonlinear_mode, alpha, beta, gamma, dst_scale, \ | |||
| cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, \ | |||
| m_algo_param.threadblock_n, \ | |||
| m_algo_param.threadblock_k}, \ | |||
| cutlass_wrapper::GemmCoord{m_algo_param.warp_m, \ | |||
| m_algo_param.warp_n, \ | |||
| m_algo_param.warp_k}, \ | |||
| m_algo_param.stage, stream); | |||
| DISPATCH(true); | |||
| DISPATCH(false); | |||
| #undef cb | |||
| } else if (param.format == Format::NCHW4_NCHW) { | |||
| #define cb(_nonunity_kernel) \ | |||
| cutlass_wrapper::do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nchw< \ | |||
| _nonunity_kernel>( \ | |||
| args.src_tensor->compatible_ptr<int8_t>(), filter_ptr, \ | |||
| args.bias_tensor->compatible_ptr<float>(), \ | |||
| args.z_tensor->compatible_ptr<float>(), \ | |||
| args.dst_tensor->compatible_ptr<float>(), nullptr, kern_param, \ | |||
| nonlinear_mode, alpha, beta, gamma, dst_scale, \ | |||
| cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, \ | |||
| m_algo_param.threadblock_n, \ | |||
| m_algo_param.threadblock_k}, \ | |||
| cutlass_wrapper::GemmCoord{m_algo_param.warp_m, \ | |||
| m_algo_param.warp_n, \ | |||
| m_algo_param.warp_k}, \ | |||
| m_algo_param.stage, stream); | |||
| DISPATCH(true); | |||
| DISPATCH(false); | |||
| #undef cb | |||
| } else if (param.format == Format::NCHW4_NHWC) { | |||
| #define cb(_signedness) \ | |||
| cutlass_wrapper::do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nhwc< \ | |||
| _signedness>( \ | |||
| args.src_tensor->compatible_ptr<int8_t>(), filter_ptr, \ | |||
| args.bias_tensor->compatible_ptr<int32_t>(), \ | |||
| reinterpret_cast<int8_t*>(args.z_tensor->raw_ptr), \ | |||
| reinterpret_cast<int8_t*>(args.dst_tensor->raw_ptr), nullptr, \ | |||
| kern_param, nonlinear_mode, alpha, beta, gamma, delta, theta, \ | |||
| dst_scale, \ | |||
| cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, \ | |||
| m_algo_param.threadblock_n, \ | |||
| m_algo_param.threadblock_k}, \ | |||
| cutlass_wrapper::GemmCoord{m_algo_param.warp_m, \ | |||
| m_algo_param.warp_n, \ | |||
| m_algo_param.warp_k}, \ | |||
| m_algo_param.stage, stream); | |||
| if (args.dst_layout->dtype.enumv() == DTypeEnum::QuantizedS4) { | |||
| cb(true); | |||
| } else { | |||
| megdnn_assert(args.dst_layout->dtype.enumv() == | |||
| DTypeEnum::Quantized4Asymm); | |||
| cb(false); | |||
| } | |||
| #undef cb | |||
| } else { | |||
| megdnn_assert(param.format == Format::NCHW4_NCHW32); | |||
| #define cb(_nonunity_kernel) \ | |||
| cutlass_wrapper:: \ | |||
| do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32< \ | |||
| _nonunity_kernel>( \ | |||
| args.src_tensor->compatible_ptr<int8_t>(), filter_ptr, \ | |||
| args.bias_tensor->compatible_ptr<int32_t>(), \ | |||
| args.z_tensor->compatible_ptr<int8_t>(), \ | |||
| args.dst_tensor->compatible_ptr<int8_t>(), nullptr, \ | |||
| kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale, \ | |||
| cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, \ | |||
| m_algo_param.threadblock_n, \ | |||
| m_algo_param.threadblock_k}, \ | |||
| cutlass_wrapper::GemmCoord{m_algo_param.warp_m, \ | |||
| m_algo_param.warp_n, \ | |||
| m_algo_param.warp_k}, \ | |||
| m_algo_param.stage, stream); | |||
| DISPATCH(true); | |||
| DISPATCH(false); | |||
| #undef cb | |||
| #undef DISPATCH | |||
| } | |||
| float threshold = 0.f; | |||
| bool load_from_const = !(fh == 1 && fw == 1); | |||
| bool without_shared_load = false; | |||
| const auto* op = get_cutlass_conv_op(args, ConvOperator::kFprop, | |||
| ConvType::kConvolution, | |||
| load_from_const, without_shared_load); | |||
| execute_cutlass_conv_op( | |||
| op, args.src_tensor->raw_ptr, filter_ptr, args.bias_tensor->raw_ptr, | |||
| z_ptr, args.dst_tensor->raw_ptr, nullptr, n, hi, wi, ci, co, fh, fw, | |||
| ho, wo, ph, pw, sh, sw, dh, dw, &alpha, &beta, &gamma, &delta, | |||
| &theta, &threshold, &dst_scale, stream); | |||
| after_kernel_launch(); | |||
| } | |||
| @@ -10,8 +10,7 @@ | |||
| * implied. | |||
| */ | |||
| #include "./algo.h" | |||
| #include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | |||
| #include "src/cuda/conv_bias/algo.h" | |||
| #include "src/cuda/conv_bias/reduce_filter.cuh" | |||
| #include "src/cuda/utils.h" | |||
| @@ -120,32 +119,15 @@ ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm::get_constants( | |||
| delta = -z_zero * gamma; | |||
| } | |||
| return {alpha, beta, gamma, delta, theta}; | |||
| } | |||
| // identity epilogue has no theta: | |||
| // alpha * accumulator + beta * bias + gamma * source + delta | |||
| if (args.opr->param().nonlineMode == | |||
| param::ConvBias::NonlineMode::IDENTITY) { | |||
| delta += theta; | |||
| theta = 0.f; | |||
| } | |||
| void ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm::do_exec( | |||
| const ExecArgs& args, void* filter_ptr, void* bias_ptr, void* z_ptr, | |||
| ConvParam kern_param, uint32_t nonlinear_mode, float alpha, float beta, | |||
| float gamma, float delta, float theta, cudaStream_t stream) const { | |||
| float dst_scale = | |||
| args.dst_layout->dtype.param<dtype::Quantized4Asymm>().scale; | |||
| uint8_t src_zero = | |||
| args.src_layout->dtype.param<dtype::Quantized4Asymm>().zero_point; | |||
| cutlass_wrapper::GemmCoord threadblock_shape{m_algo_param.threadblock_m, | |||
| m_algo_param.threadblock_n, | |||
| m_algo_param.threadblock_k}; | |||
| cutlass_wrapper::GemmCoord warp_shape{ | |||
| m_algo_param.warp_m, m_algo_param.warp_n, m_algo_param.warp_k}; | |||
| cutlass_wrapper::do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64< | |||
| true>(reinterpret_cast<uint8_t*>(args.src_tensor->raw_ptr), | |||
| reinterpret_cast<int8_t*>(filter_ptr), | |||
| reinterpret_cast<int32_t*>(bias_ptr), | |||
| reinterpret_cast<uint8_t*>(z_ptr), | |||
| reinterpret_cast<uint8_t*>(args.dst_tensor->raw_ptr), nullptr, | |||
| kern_param, nonlinear_mode, alpha, beta, gamma, delta, theta, | |||
| dst_scale, src_zero, threadblock_shape, warp_shape, | |||
| m_algo_param.stage, stream); | |||
| return {alpha, beta, gamma, delta, theta}; | |||
| } | |||
| void ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm::update_bias( | |||
| @@ -10,8 +10,7 @@ | |||
| * implied. | |||
| */ | |||
| #include "./algo.h" | |||
| #include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | |||
| #include "src/cuda/conv_bias/algo.h" | |||
| #include "src/cuda/conv_bias/reduce_filter.cuh" | |||
| #include "src/cuda/utils.h" | |||
| @@ -121,44 +120,15 @@ ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm::get_constants( | |||
| delta = -z_zero * gamma; | |||
| } | |||
| return {alpha, beta, gamma, delta, theta}; | |||
| } | |||
| void ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm::do_exec( | |||
| const ExecArgs& args, void* filter_ptr, void* bias_ptr, void* z_ptr, | |||
| ConvParam kern_param, uint32_t nonlinear_mode, float alpha, float beta, | |||
| float gamma, float delta, float theta, cudaStream_t stream) const { | |||
| float dst_scale = | |||
| args.dst_layout->dtype.param<dtype::Quantized4Asymm>().scale; | |||
| uint8_t src_zero = | |||
| args.src_layout->dtype.param<dtype::Quantized4Asymm>().zero_point; | |||
| cutlass_wrapper::GemmCoord threadblock_shape{m_algo_param.threadblock_m, | |||
| m_algo_param.threadblock_n, | |||
| m_algo_param.threadblock_k}; | |||
| cutlass_wrapper::GemmCoord warp_shape{ | |||
| m_algo_param.warp_m, m_algo_param.warp_n, m_algo_param.warp_k}; | |||
| if (kern_param.fh == 1 && kern_param.fw == 1) { | |||
| cutlass_wrapper::do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc<false>( | |||
| reinterpret_cast<uint8_t*>(args.src_tensor->raw_ptr), | |||
| reinterpret_cast<int8_t*>(filter_ptr), | |||
| reinterpret_cast<int32_t*>(bias_ptr), | |||
| reinterpret_cast<uint8_t*>(z_ptr), | |||
| reinterpret_cast<uint8_t*>(args.dst_tensor->raw_ptr), nullptr, | |||
| kern_param, nonlinear_mode, alpha, beta, gamma, delta, theta, | |||
| dst_scale, src_zero, threadblock_shape, warp_shape, | |||
| m_algo_param.access_size, m_algo_param.stage, stream); | |||
| } else { | |||
| cutlass_wrapper::do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc<true>( | |||
| reinterpret_cast<uint8_t*>(args.src_tensor->raw_ptr), | |||
| reinterpret_cast<int8_t*>(filter_ptr), | |||
| reinterpret_cast<int32_t*>(bias_ptr), | |||
| reinterpret_cast<uint8_t*>(z_ptr), | |||
| reinterpret_cast<uint8_t*>(args.dst_tensor->raw_ptr), nullptr, | |||
| kern_param, nonlinear_mode, alpha, beta, gamma, delta, theta, | |||
| dst_scale, src_zero, threadblock_shape, warp_shape, | |||
| m_algo_param.access_size, m_algo_param.stage, stream); | |||
| // identity epilogue has no theta: | |||
| // alpha * accumulator + beta * bias + gamma * source + delta | |||
| if (args.opr->param().nonlineMode == | |||
| param::ConvBias::NonlineMode::IDENTITY) { | |||
| delta += theta; | |||
| theta = 0.f; | |||
| } | |||
| return {alpha, beta, gamma, delta, theta}; | |||
| } | |||
| void ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm::update_bias( | |||
| @@ -57,6 +57,7 @@ public: | |||
| class AlgoBatchedMatmul; | |||
| class AlgoGroupConvGeneral; | |||
| class AlgoQUInt4x4x32WMMA; | |||
| class AlgoCutlassConvolutionBase; | |||
| class AlgoInt8CHWN4DotProdImplicitGemm; | |||
| class AlgoInt8NCHW4DotProdImplicitGemm; | |||
| class AlgoInt8CHWN4IMMAImplicitGemm; | |||
| @@ -1,100 +0,0 @@ | |||
| /** | |||
| * \file src/cuda/convolution/backward_data/cutlass_deconvolution_wrapper.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| // ignore warning of cutlass | |||
| #pragma GCC diagnostic push | |||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||
| #if !MEGDNN_TEGRA_X1 | |||
| #include "cutlass/convolution/device/convolution.h" | |||
| #endif | |||
| #include "src/common/opr_param_defs_enumv.cuh" | |||
| #include "src/cuda/convolution/backward_data/cutlass_deconvolution_wrapper.cuh" | |||
| #pragma GCC diagnostic pop | |||
| using namespace megdnn; | |||
| using namespace cuda; | |||
| using namespace cutlass_wrapper; | |||
| /* ================ cutlass kernel wrapper for nchw4 layout ================= */ | |||
| #if MEGDNN_TEGRA_X1 | |||
| void megdnn::cuda::cutlass_wrapper::do_deconv_int8_implicit_gemm_dp4a_ncdiv4hw4( | |||
| const int8_t* /* d_src */, const int8_t* /* d_filter */, | |||
| int8_t* /* d_dst */, int* /* workspace */, | |||
| const convolution::ConvParam& /* param */, float /* alpha */, | |||
| const GemmCoord& /* threadblock_shape */, | |||
| const GemmCoord& /* warp_shape */, int /* stages */, | |||
| cudaStream_t /* stream */) {} | |||
| #else | |||
| void megdnn::cuda::cutlass_wrapper::do_deconv_int8_implicit_gemm_dp4a_ncdiv4hw4( | |||
| const int8_t* d_src, const int8_t* d_filter, int8_t* d_dst, | |||
| int* workspace, const convolution::ConvParam& param, float alpha, | |||
| const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | |||
| int stages, cudaStream_t stream) { | |||
| #define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||
| threadblock_k_, warp_m_, warp_n_, \ | |||
| warp_k_, stage_, aligned_) \ | |||
| if (threadblock_shape.m() == threadblock_m_ && \ | |||
| threadblock_shape.n() == threadblock_n_ && \ | |||
| threadblock_shape.k() == threadblock_k_ && \ | |||
| warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||
| warp_shape.k() == warp_k_ && stages == stage_) { \ | |||
| using ThreadBlockShape = \ | |||
| cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||
| threadblock_k_>; \ | |||
| using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||
| using InstructionShape = cutlass::gemm::GemmShape<1, 1, 4>; \ | |||
| using Deconvolution = cutlass::conv::device::Deconvolution< \ | |||
| int8_t, cutlass::layout::TensorNCxHWx<4>, int8_t, \ | |||
| cutlass::layout::TensorKxRSCx<4>, ElementOutput, \ | |||
| cutlass::layout::TensorNCxHWx<4>, int32_t, \ | |||
| cutlass::layout::TensorNCxHWx<4>, int32_t, \ | |||
| cutlass::arch::OpClassSimt, cutlass::arch::Sm61, \ | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||
| cutlass::conv::threadblock:: \ | |||
| ConvolutionDgradNCxHWxThreadblockSwizzle, \ | |||
| stage_, 4, aligned_, true, cutlass::arch::OpMultiplyAdd>; \ | |||
| typename Deconvolution::ConvolutionParameter conv_param( \ | |||
| param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||
| param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||
| param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||
| return cutlass_deconvolution_wrapper<Deconvolution>( \ | |||
| d_src, d_filter, nullptr, nullptr, d_dst, workspace, \ | |||
| conv_param, epilogue, stream); \ | |||
| } | |||
| #define DISPATCH_KERNEL \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 2, 4); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 64, 16, 2, 4); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 128, 16, 1, 8); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \ | |||
| megdnn_assert(false, \ | |||
| "unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||
| "(%dx%dx%d)", \ | |||
| threadblock_shape.m(), threadblock_shape.n(), \ | |||
| threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||
| warp_shape.k()); | |||
| using ElementOutput = int8_t; | |||
| using ElementAccumulator = int32_t; | |||
| using ElementBias = int32_t; | |||
| using ElementCompute = float; | |||
| using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||
| ElementOutput, 4, ElementAccumulator, ElementBias, ElementCompute>; | |||
| typename EpilogueOp::Params epilogue{alpha, 0, 0}; | |||
| DISPATCH_KERNEL; | |||
| #undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||
| #undef DISPATCH_KERNEL | |||
| } | |||
| #endif | |||
| // vim: syntax=cuda.doxygen | |||
| @@ -1,44 +0,0 @@ | |||
| /** | |||
| * \file src/cuda/convolution/backward_data/cutlass_deconvolution_wrapper.cuh | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include "cutlass/gemm/gemm.h" | |||
| #include "src/cuda/convolution_helper/parameter.cuh" | |||
| #include "src/cuda/utils.cuh" | |||
| namespace megdnn { | |||
| namespace cuda { | |||
| namespace cutlass_wrapper { | |||
| using GemmCoord = cutlass::gemm::GemmCoord; | |||
| template <typename Convolution> | |||
| void cutlass_deconvolution_wrapper( | |||
| const typename Convolution::ElementSrc* d_src, | |||
| const typename Convolution::ElementFilter* d_filter, | |||
| const typename Convolution::ElementBias* d_bias, | |||
| const typename Convolution::ElementDst* d_z, | |||
| typename Convolution::ElementDst* d_dst, int* workspace, | |||
| typename Convolution::ConvolutionParameter const& conv_param, | |||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||
| cudaStream_t stream); | |||
| void do_deconv_int8_implicit_gemm_dp4a_ncdiv4hw4( | |||
| const int8_t* d_src, const int8_t* d_filter, int8_t* d_dst, | |||
| int* workspace, const convolution::ConvParam& param, float alpha, | |||
| const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | |||
| int stages, cudaStream_t stream); | |||
| } // namespace cutlass_wrapper | |||
| } // namespace cuda | |||
| } // namespace megdnn | |||
| // vim: syntax=cuda.doxygen | |||
| @@ -1,62 +0,0 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/cuda/convolution/backward_data/implicit_gemm_deconv_cutlass_wrapper.cuinl | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "cutlass/convolution/device/convolution.h" | |||
| #include "src/cuda/convolution/backward_data/cutlass_deconvolution_wrapper.cuh" | |||
| using namespace megdnn; | |||
| using namespace cuda; | |||
| using namespace cutlass_wrapper; | |||
| template <typename Deconvolution> | |||
| void megdnn::cuda::cutlass_wrapper::cutlass_deconvolution_wrapper( | |||
| const typename Deconvolution::ElementSrc* d_src, | |||
| const typename Deconvolution::ElementFilter* d_filter, | |||
| const typename Deconvolution::ElementBias* d_bias, | |||
| const typename Deconvolution::ElementDst* d_z, | |||
| typename Deconvolution::ElementDst* d_dst, int* workspace, | |||
| typename Deconvolution::ConvolutionParameter const& conv_param, | |||
| typename Deconvolution::EpilogueOutputOp::Params const& epilogue, | |||
| cudaStream_t stream) { | |||
| typename Deconvolution::TensorRefSrc tensor_src{ | |||
| const_cast<typename Deconvolution::ElementSrc*>(d_src), | |||
| Deconvolution::LayoutSrc::packed( | |||
| {conv_param.N, conv_param.P, conv_param.Q, conv_param.K})}; | |||
| typename Deconvolution::TensorRefFilter tensor_filter{ | |||
| const_cast<typename Deconvolution::ElementFilter*>(d_filter), | |||
| Deconvolution::LayoutFilter::packed( | |||
| {conv_param.K, conv_param.R, conv_param.S, conv_param.C})}; | |||
| typename Deconvolution::TensorRefBias tensor_bias{ | |||
| const_cast<typename Deconvolution::ElementBias*>(d_bias), | |||
| Deconvolution::LayoutBias::packed({1, 1, 1, conv_param.K})}; | |||
| typename Deconvolution::TensorRefDst tensor_z{ | |||
| const_cast<typename Deconvolution::ElementDst*>(d_z), | |||
| Deconvolution::LayoutDst::packed( | |||
| {conv_param.N, conv_param.H, conv_param.W, conv_param.C})}; | |||
| typename Deconvolution::TensorRefDst tensor_dst{ | |||
| d_dst, | |||
| Deconvolution::LayoutDst::packed( | |||
| {conv_param.N, conv_param.H, conv_param.W, conv_param.C})}; | |||
| typename Deconvolution::Arguments arguments{conv_param, | |||
| tensor_src.non_const_ref(), | |||
| tensor_filter.non_const_ref(), | |||
| tensor_bias.non_const_ref(), | |||
| tensor_z.non_const_ref(), | |||
| tensor_dst.non_const_ref(), | |||
| epilogue}; | |||
| Deconvolution deconv_op; | |||
| cutlass_check(deconv_op.initialize(arguments, workspace)); | |||
| cutlass_check(deconv_op(stream)); | |||
| after_kernel_launch(); | |||
| } | |||
| // vim: syntax=cuda.doxygen | |||
| @@ -1,5 +1,6 @@ | |||
| /** | |||
| * \file dnn/src/cuda/convolution/backward_data/implicit_gemm_int8_nchw4_dp4a.cpp | |||
| * \file | |||
| * dnn/src/cuda/convolution/backward_data/implicit_gemm_int8_nchw4_dp4a.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| @@ -10,11 +11,11 @@ | |||
| * implied. | |||
| */ | |||
| #include "./algo.h" | |||
| #include "src/cuda/utils.h" | |||
| #include "src/cuda/convolution_helper/parameter.cuh" | |||
| #include "src/cuda/convolution/backward_data/cutlass_deconvolution_wrapper.cuh" | |||
| #include "src/cuda/convolution/backward_data/algo.h" | |||
| #include "src/cuda/convolution/backward_data/deconv_int8_helper.cuh" | |||
| #include "src/cuda/convolution_helper/parameter.cuh" | |||
| #include "src/cuda/cutlass/singleton.h" | |||
| #include "src/cuda/utils.h" | |||
| using namespace megdnn; | |||
| using namespace cuda; | |||
| @@ -70,6 +71,7 @@ size_t ConvolutionBackwardDataImpl::AlgoInt8NCHW4DotProdImplicitGemm:: | |||
| void ConvolutionBackwardDataImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec( | |||
| const ExecArgs& args) const { | |||
| auto&& param = args.opr->param(); | |||
| auto&& fm = args.filter_meta; | |||
| size_t n = args.diff_layout->operator[](0), | |||
| co = args.diff_layout->operator[](1) * 4, | |||
| @@ -81,6 +83,7 @@ void ConvolutionBackwardDataImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec( | |||
| size_t fh = fm.spatial[0], fw = fm.spatial[1]; | |||
| size_t sh = fm.stride[0], sw = fm.stride[1]; | |||
| size_t ph = fm.padding[0], pw = fm.padding[1]; | |||
| size_t dh = param.dilate_h, dw = param.dilate_w; | |||
| auto&& stream = cuda_stream(args.opr->handle()); | |||
| @@ -93,12 +96,6 @@ void ConvolutionBackwardDataImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec( | |||
| filter_ptr, args.filter_tensor->compatible_ptr<int8_t>(), co, | |||
| ci, fh, fw, stream); | |||
| } | |||
| convolution::ConvParam kern_param; | |||
| kern_param.n = n, kern_param.co = co, kern_param.ci = ci, | |||
| kern_param.hi = hi, kern_param.wi = wi, kern_param.ho = ho, | |||
| kern_param.wo = wo, kern_param.ph = ph, kern_param.pw = pw, | |||
| kern_param.sh = sh, kern_param.sw = sw, kern_param.fh = fh, | |||
| kern_param.fw = fw; | |||
| float diff_scale = | |||
| args.diff_layout->dtype.param<dtype::QuantizedS8>().scale, | |||
| @@ -106,17 +103,60 @@ void ConvolutionBackwardDataImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec( | |||
| args.filter_layout->dtype.param<dtype::QuantizedS8>().scale, | |||
| grad_scale = | |||
| args.grad_layout->dtype.param<dtype::QuantizedS8>().scale; | |||
| float alpha = diff_scale * filter_scale / grad_scale; | |||
| cutlass_wrapper::do_deconv_int8_implicit_gemm_dp4a_ncdiv4hw4( | |||
| args.diff_tensor->compatible_ptr<int8_t>(), filter_ptr, | |||
| args.grad_tensor->compatible_ptr<int8_t>(), nullptr, kern_param, | |||
| alpha, | |||
| cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, | |||
| m_algo_param.threadblock_n, | |||
| m_algo_param.threadblock_k}, | |||
| cutlass_wrapper::GemmCoord{m_algo_param.warp_m, m_algo_param.warp_n, | |||
| m_algo_param.warp_k}, | |||
| m_algo_param.stage, stream); | |||
| // \note these constants of cutlass epilogue will be passed to struct | |||
| // `ConvolutionArguments` by pointer and interpreted as ElementCompute*, a | |||
| // different dtype here results in undefined epilogue behaviors | |||
| float alpha = diff_scale * filter_scale / grad_scale, beta = 0.f, | |||
| gamma = 0.f, delta = 0.f; | |||
| using namespace cutlass::library; | |||
| // only use 16x64x8_16x64x8_2stages impl | |||
| ConvolutionKey key{ | |||
| cutlass::conv::Operator::kDgrad, | |||
| NumericTypeID::kS8, | |||
| LayoutTypeID::kTensorNC4HW4, | |||
| NumericTypeID::kS8, | |||
| LayoutTypeID::kTensorK4RSC4, | |||
| NumericTypeID::kS8, | |||
| LayoutTypeID::kTensorNC4HW4, | |||
| NumericTypeID::kS32, | |||
| LayoutTypeID::kTensorNC4HW4, | |||
| cutlass::conv::ConvType::kConvolution, | |||
| m_algo_param.threadblock_m, | |||
| m_algo_param.threadblock_n, | |||
| m_algo_param.threadblock_k, | |||
| m_algo_param.warp_m, | |||
| m_algo_param.warp_n, | |||
| m_algo_param.warp_k, | |||
| 1, | |||
| 1, | |||
| 4, | |||
| cutlass::epilogue::EpilogueType::kBiasAddLinearCombinationClamp, | |||
| m_algo_param.stage, | |||
| true, | |||
| false}; | |||
| const Operation* op = Singleton::get().operation_table.find_op(key); | |||
| // gcc prints warnings when size_t values are implicitly narrowed to int | |||
| cutlass::conv::Conv2dProblemSize problem_size{ | |||
| int(n), int(hi), int(wi), int(ci), | |||
| int(co), int(fh), int(fw), int(ho), | |||
| int(wo), int(ph), int(pw), int(sh), | |||
| int(sw), int(dh), int(dw), cutlass::conv::Mode::kCrossCorrelation}; | |||
| cutlass::library::ConvolutionArguments conv_args{ | |||
| problem_size, args.diff_tensor->compatible_ptr<int8_t>(), | |||
| filter_ptr, nullptr, | |||
| nullptr, args.grad_tensor->compatible_ptr<int8_t>(), | |||
| &alpha, &beta, | |||
| &gamma, &delta, | |||
| nullptr, nullptr, | |||
| nullptr, nullptr}; | |||
| cutlass_check(op->run(&conv_args, nullptr, stream)); | |||
| after_kernel_launch(); | |||
| } | |||
| @@ -11,16 +11,16 @@ | |||
| * implied. | |||
| */ | |||
| #include "./algo.h" | |||
| #include "src/cuda/utils.h" | |||
| #include "src/cuda/convolution/backward_data/algo.h" | |||
| #include "src/cuda/convolution_helper/parameter.cuh" | |||
| #include "src/cuda/convolution/backward_data/cutlass_deconvolution_wrapper.cuh" | |||
| #include "src/cuda/cutlass/singleton.h" | |||
| #include "src/cuda/utils.h" | |||
| using namespace megdnn; | |||
| using namespace cuda; | |||
| bool ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm:: | |||
| is_available(const SizeArgs& args) const { | |||
| bool ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm::is_available( | |||
| const SizeArgs& args) const { | |||
| auto&& fm = args.filter_meta; | |||
| if (fm.format != Param::Format::NCHW) | |||
| return false; | |||
| @@ -42,7 +42,8 @@ bool ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm:: | |||
| // TODO support group deconv int8 | |||
| available &= (fm.group == 1); | |||
| // ic and oc must be multiples of 4 | |||
| available &= ((fm.group * fm.icpg) % 4 == 0 && (fm.group * fm.ocpg) % 4 == 0); | |||
| available &= | |||
| ((fm.group * fm.icpg) % 4 == 0 && (fm.group * fm.ocpg) % 4 == 0); | |||
| // mode must be cross correlation | |||
| available &= !fm.should_flip; | |||
| // mode must be 2D | |||
| @@ -73,6 +74,7 @@ size_t ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm:: | |||
| void ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm::exec( | |||
| const ExecArgs& args) const { | |||
| auto&& param = args.opr->param(); | |||
| auto&& fm = args.filter_meta; | |||
| size_t n = args.diff_layout->operator[](0), | |||
| co = args.diff_layout->operator[](1), | |||
| @@ -84,6 +86,7 @@ void ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm::exec( | |||
| size_t fh = fm.spatial[0], fw = fm.spatial[1]; | |||
| size_t sh = fm.stride[0], sw = fm.stride[1]; | |||
| size_t ph = fm.padding[0], pw = fm.padding[1]; | |||
| size_t dh = param.dilate_h, dw = param.dilate_w; | |||
| auto&& stream = cuda_stream(args.opr->handle()); | |||
| @@ -120,26 +123,63 @@ void ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm::exec( | |||
| } | |||
| int8_t* inner_grad_ptr = reinterpret_cast<int8_t*>(bundle.get(2)); | |||
| convolution::ConvParam kern_param; | |||
| kern_param.n = n, kern_param.co = co, kern_param.ci = ci, | |||
| kern_param.hi = hi, kern_param.wi = wi, kern_param.ho = ho, | |||
| kern_param.wo = wo, kern_param.ph = ph, kern_param.pw = pw, | |||
| kern_param.sh = sh, kern_param.sw = sw, kern_param.fh = fh, | |||
| kern_param.fw = fw; | |||
| float diff_scale = | |||
| args.diff_layout->dtype.param<dtype::QuantizedS8>().scale, | |||
| filter_scale = | |||
| args.filter_layout->dtype.param<dtype::QuantizedS8>().scale, | |||
| grad_scale = | |||
| args.grad_layout->dtype.param<dtype::QuantizedS8>().scale; | |||
| float alpha = diff_scale * filter_scale / grad_scale; | |||
| // \note these constants of cutlass epilogue will be passed to struct | |||
| // `ConvolutionArguments` by pointer and interpreted as ElementCompute*, a | |||
| // different dtype here results in undefined epilogue behaviors | |||
| float alpha = diff_scale * filter_scale / grad_scale, beta = 0.f, | |||
| gamma = 0.f, delta = 0.f; | |||
| using namespace cutlass::library; | |||
| // only use 16x64x8_16x64x8_2stages impl | |||
| cutlass_wrapper::do_deconv_int8_implicit_gemm_dp4a_ncdiv4hw4( | |||
| inner_diff_ptr, inner_filter_ptr, inner_grad_ptr, nullptr, | |||
| kern_param, alpha, cutlass_wrapper::GemmCoord{16, 64, 8}, | |||
| cutlass_wrapper::GemmCoord{16, 64, 8}, 2, stream); | |||
| ConvolutionKey key{ | |||
| cutlass::conv::Operator::kDgrad, | |||
| NumericTypeID::kS8, | |||
| LayoutTypeID::kTensorNC4HW4, | |||
| NumericTypeID::kS8, | |||
| LayoutTypeID::kTensorK4RSC4, | |||
| NumericTypeID::kS8, | |||
| LayoutTypeID::kTensorNC4HW4, | |||
| NumericTypeID::kS32, | |||
| LayoutTypeID::kTensorNC4HW4, | |||
| cutlass::conv::ConvType::kConvolution, | |||
| 16, | |||
| 64, | |||
| 8, | |||
| 16, | |||
| 64, | |||
| 8, | |||
| 1, | |||
| 1, | |||
| 4, | |||
| cutlass::epilogue::EpilogueType::kBiasAddLinearCombinationClamp, | |||
| 2, | |||
| true, | |||
| false}; | |||
| const Operation* op = Singleton::get().operation_table.find_op(key); | |||
| // gcc prints warnings when size_t values are implicitly narrowed to int | |||
| cutlass::conv::Conv2dProblemSize problem_size{ | |||
| int(n), int(hi), int(wi), int(ci), | |||
| int(co), int(fh), int(fw), int(ho), | |||
| int(wo), int(ph), int(pw), int(sh), | |||
| int(sw), int(dh), int(dw), cutlass::conv::Mode::kCrossCorrelation}; | |||
| cutlass::library::ConvolutionArguments conv_args{ | |||
| problem_size, inner_diff_ptr, inner_filter_ptr, nullptr, | |||
| nullptr, inner_grad_ptr, &alpha, &beta, | |||
| &gamma, &delta, nullptr, nullptr, | |||
| nullptr, nullptr}; | |||
| cutlass_check(op->run(&conv_args, nullptr, stream)); | |||
| after_kernel_launch(); | |||
| @@ -0,0 +1,107 @@ | |||
| /*************************************************************************************************** | |||
| * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | |||
| * | |||
| * Redistribution and use in source and binary forms, with or without | |||
| *modification, are permitted provided that the following conditions are met: | |||
| * * Redistributions of source code must retain the above copyright notice, | |||
| *this list of conditions and the following disclaimer. | |||
| * * Redistributions in binary form must reproduce the above copyright | |||
| *notice, this list of conditions and the following disclaimer in the | |||
| *documentation and/or other materials provided with the distribution. | |||
| * * Neither the name of the NVIDIA CORPORATION nor the names of its | |||
| *contributors may be used to endorse or promote products derived from this | |||
| *software without specific prior written permission. | |||
| * | |||
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
| *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
| *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |||
| *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, | |||
| *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, | |||
| * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | |||
| *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY | |||
| *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TOR (INCLUDING | |||
| *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, | |||
| *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
| * | |||
| **************************************************************************************************/ | |||
| /** | |||
| * \file dnn/src/cuda/cutlass/arch_mappings.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include "cutlass/arch/arch.h" | |||
| #include "cutlass/arch/mma.h" | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| namespace cutlass { | |||
| namespace library { | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| template <typename ArchTag, typename OperatorClass> | |||
| struct ArchMap; | |||
| template <> | |||
| struct ArchMap<arch::Sm50, arch::OpClassSimt> { | |||
| static int const kMin = 50; | |||
| static int const kMax = 1024; | |||
| }; | |||
| template <> | |||
| struct ArchMap<arch::Sm60, arch::OpClassSimt> { | |||
| static int const kMin = 60; | |||
| static int const kMax = 1024; | |||
| }; | |||
| template <> | |||
| struct ArchMap<arch::Sm61, arch::OpClassSimt> { | |||
| static int const kMin = 61; | |||
| static int const kMax = 1024; | |||
| }; | |||
| template <> | |||
| struct ArchMap<arch::Sm70, arch::OpClassWmmaTensorOp> { | |||
| static int const kMin = 70; | |||
| static int const kMax = 1024; | |||
| }; | |||
| template <> | |||
| struct ArchMap<arch::Sm70, arch::OpClassTensorOp> { | |||
| static int const kMin = 70; | |||
| static int const kMax = 75; | |||
| }; | |||
| template <typename OperatorClass> | |||
| struct ArchMap<arch::Sm75, OperatorClass> { | |||
| static int const kMin = 75; | |||
| static int const kMax = 1024; | |||
| }; | |||
| template <typename OperatorClass> | |||
| struct ArchMap<arch::Sm80, OperatorClass> { | |||
| static int const kMin = 80; | |||
| static int const kMax = 1024; | |||
| }; | |||
| template <typename OperatorClass> | |||
| struct ArchMap<arch::Sm86, OperatorClass> { | |||
| static int const kMin = 86; | |||
| static int const kMax = 1024; | |||
| }; | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| } // namespace library | |||
| } // namespace cutlass | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| @@ -0,0 +1,307 @@ | |||
| /*************************************************************************************************** | |||
| * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. | |||
| * | |||
| * Redistribution and use in source and binary forms, with or without | |||
| *modification, are permitted provided that the following conditions are met: | |||
| * * Redistributions of source code must retain the above copyright notice, | |||
| *this list of conditions and the following disclaimer. | |||
| * * Redistributions in binary form must reproduce the above copyright | |||
| *notice, this list of conditions and the following disclaimer in the | |||
| *documentation and/or other materials provided with the distribution. | |||
| * * Neither the name of the NVIDIA CORPORATION nor the names of its | |||
| *contributors may be used to endorse or promote products derived from this | |||
| *software without specific prior written permission. | |||
| * | |||
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
| *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
| *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |||
| *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, | |||
| *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, | |||
| * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | |||
| *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY | |||
| *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TOR (INCLUDING | |||
| *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, | |||
| *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
| * | |||
| **************************************************************************************************/ | |||
| /** | |||
| * \file dnn/src/cuda/cutlass/convolution_operation.h | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include "cutlass/convolution/device/convolution.h" | |||
| #include "src/cuda/cutlass/library_internal.h" | |||
| /////////////////////////////////////////////////////////////////////////////////////////////////// | |||
| namespace cutlass { | |||
| namespace library { | |||
| /////////////////////////////////////////////////////////////////////////////////////////////////// | |||
| template <typename Operator_> | |||
| class ConvolutionOperationBase : public Operation { | |||
| public: | |||
| using Operator = Operator_; | |||
| using ElementSrc = typename Operator::ElementSrc; | |||
| using LayoutSrc = typename Operator::LayoutSrc; | |||
| using ElementFilter = typename Operator::ElementFilter; | |||
| using LayoutFilter = typename Operator::LayoutFilter; | |||
| using ElementDst = typename Operator::ElementDst; | |||
| using LayoutDst = typename Operator::LayoutDst; | |||
| using ElementBias = typename Operator::ElementBias; | |||
| using LayoutBias = typename Operator::LayoutBias; | |||
| using ElementAccumulator = typename Operator::ElementAccumulator; | |||
| ConvolutionOperationBase(char const* name = "unknown_convolution") { | |||
| m_description.name = name; | |||
| m_description.provider = Provider::kCUTLASS; | |||
| m_description.kind = OperationKind::kConvolution; | |||
| m_description.conv_op = Operator::kConvolutionalOperator; | |||
| m_description.tile_description.threadblock_shape = make_Coord( | |||
| Operator::ThreadblockShape::kM, Operator::ThreadblockShape::kN, | |||
| Operator::ThreadblockShape::kK); | |||
| m_description.tile_description.threadblock_stages = Operator::kStages; | |||
| m_description.tile_description.warp_count = | |||
| make_Coord(Operator::ConvolutionKernel::WarpCount::kM, | |||
| Operator::ConvolutionKernel::WarpCount::kN, | |||
| Operator::ConvolutionKernel::WarpCount::kK); | |||
| m_description.tile_description.math_instruction.instruction_shape = | |||
| make_Coord(Operator::InstructionShape::kM, | |||
| Operator::InstructionShape::kN, | |||
| Operator::InstructionShape::kK); | |||
| m_description.tile_description.math_instruction.element_accumulator = | |||
| NumericTypeMap<ElementAccumulator>::kId; | |||
| m_description.tile_description.math_instruction.opcode_class = | |||
| OpcodeClassMap<typename Operator::OperatorClass>::kId; | |||
| m_description.tile_description.math_instruction.math_operation = | |||
| MathOperationMap<typename Operator::Operator>::kId; | |||
| m_description.tile_description.minimum_compute_capability = | |||
| ArchMap<typename Operator::ArchTag, | |||
| typename Operator::OperatorClass>::kMin; | |||
| m_description.tile_description.maximum_compute_capability = | |||
| ArchMap<typename Operator::ArchTag, | |||
| typename Operator::OperatorClass>::kMax; | |||
| m_description.src = make_TensorDescription<ElementSrc, LayoutSrc>( | |||
| Operator::kAlignmentSrc); | |||
| m_description.filter = | |||
| make_TensorDescription<ElementFilter, LayoutFilter>( | |||
| Operator::kAlignmentFilter); | |||
| m_description.dst = make_TensorDescription<ElementDst, LayoutDst>( | |||
| Operator::kAlignmentDst); | |||
| m_description.bias = make_TensorDescription<ElementBias, LayoutBias>( | |||
| Operator::kAlignmentDst); | |||
| m_description.convolution_type = Operator::kConvolutionType; | |||
| m_description.arch_tag = ArchTagMap<typename Operator::ArchTag>::kId; | |||
| m_description.epilogue_type = Operator::EpilogueOutputOp::kType; | |||
| m_description.epilogue_count = Operator::EpilogueOutputOp::kCount; | |||
| m_description.threadblock_swizzle = ThreadblockSwizzleMap< | |||
| typename Operator::ThreadblockSwizzle>::kId; | |||
| m_description.need_load_from_const_mem = | |||
| Operator::kNeedLoadFromConstMem; | |||
| m_description.gemm_mode = Operator::kGemmMode; | |||
| m_description.without_shared_load = Operator::kWithoutSharedLoad; | |||
| } | |||
| virtual OperationDescription const& description() const { | |||
| return m_description; | |||
| } | |||
| protected: | |||
| ConvolutionDescription m_description; | |||
| }; | |||
| /////////////////////////////////////////////////////////////////////////////////////////////////// | |||
| namespace detail { | |||
| template <typename EpilogueOp, epilogue::EpilogueType type> | |||
| struct init_epilogue_param_; | |||
| template <typename EpilogueOp> | |||
| struct init_epilogue_param_<EpilogueOp, | |||
| epilogue::EpilogueType::kBiasAddLinearCombination> { | |||
| using ElementCompute = typename EpilogueOp::ElementCompute; | |||
| typename EpilogueOp::Params get(ConvolutionArguments const* conv_args) { | |||
| return {*static_cast<ElementCompute const*>(conv_args->alpha), | |||
| *static_cast<ElementCompute const*>(conv_args->beta), | |||
| *static_cast<ElementCompute const*>(conv_args->gamma), | |||
| *static_cast<ElementCompute const*>(conv_args->delta)}; | |||
| } | |||
| }; | |||
| template <typename EpilogueOp> | |||
| struct init_epilogue_param_< | |||
| EpilogueOp, epilogue::EpilogueType::kBiasAddLinearCombinationClamp> { | |||
| using ElementCompute = typename EpilogueOp::ElementCompute; | |||
| typename EpilogueOp::Params get(ConvolutionArguments const* conv_args) { | |||
| return {*static_cast<ElementCompute const*>(conv_args->alpha), | |||
| *static_cast<ElementCompute const*>(conv_args->beta), | |||
| *static_cast<ElementCompute const*>(conv_args->gamma), | |||
| *static_cast<ElementCompute const*>(conv_args->delta)}; | |||
| } | |||
| }; | |||
| template <typename EpilogueOp> | |||
| struct init_epilogue_param_< | |||
| EpilogueOp, epilogue::EpilogueType::kBiasAddLinearCombinationRelu> { | |||
| using ElementCompute = typename EpilogueOp::ElementCompute; | |||
| typename EpilogueOp::Params get(ConvolutionArguments const* conv_args) { | |||
| return {*static_cast<ElementCompute const*>(conv_args->alpha), | |||
| *static_cast<ElementCompute const*>(conv_args->beta), | |||
| *static_cast<ElementCompute const*>(conv_args->gamma), | |||
| *static_cast<ElementCompute const*>(conv_args->threshold), | |||
| *static_cast<ElementCompute const*>(conv_args->delta), | |||
| *static_cast<ElementCompute const*>(conv_args->theta)}; | |||
| } | |||
| }; | |||
| template <typename EpilogueOp> | |||
| struct init_epilogue_param_< | |||
| EpilogueOp, | |||
| epilogue::EpilogueType::kBiasAddLinearCombinationReluClamp> { | |||
| using ElementCompute = typename EpilogueOp::ElementCompute; | |||
| typename EpilogueOp::Params get(ConvolutionArguments const* conv_args) { | |||
| return {*static_cast<ElementCompute const*>(conv_args->alpha), | |||
| *static_cast<ElementCompute const*>(conv_args->beta), | |||
| *static_cast<ElementCompute const*>(conv_args->gamma), | |||
| *static_cast<ElementCompute const*>(conv_args->threshold), | |||
| *static_cast<ElementCompute const*>(conv_args->delta), | |||
| *static_cast<ElementCompute const*>(conv_args->theta)}; | |||
| } | |||
| }; | |||
| template <typename EpilogueOp> | |||
| struct init_epilogue_param_< | |||
| EpilogueOp, epilogue::EpilogueType::kBiasAddLinearCombinationHSwish> { | |||
| using ElementCompute = typename EpilogueOp::ElementCompute; | |||
| typename EpilogueOp::Params get(ConvolutionArguments const* conv_args) { | |||
| return {*static_cast<ElementCompute const*>(conv_args->alpha), | |||
| *static_cast<ElementCompute const*>(conv_args->beta), | |||
| *static_cast<ElementCompute const*>(conv_args->gamma), | |||
| *static_cast<ElementCompute const*>(conv_args->scale), | |||
| *static_cast<ElementCompute const*>(conv_args->delta), | |||
| *static_cast<ElementCompute const*>(conv_args->theta)}; | |||
| } | |||
| }; | |||
| template <typename EpilogueOp> | |||
| struct init_epilogue_param_< | |||
| EpilogueOp, | |||
| epilogue::EpilogueType::kBiasAddLinearCombinationHSwishClamp> { | |||
| using ElementCompute = typename EpilogueOp::ElementCompute; | |||
| typename EpilogueOp::Params get(ConvolutionArguments const* conv_args) { | |||
| return {*static_cast<ElementCompute const*>(conv_args->alpha), | |||
| *static_cast<ElementCompute const*>(conv_args->beta), | |||
| *static_cast<ElementCompute const*>(conv_args->gamma), | |||
| *static_cast<ElementCompute const*>(conv_args->scale), | |||
| *static_cast<ElementCompute const*>(conv_args->delta), | |||
| *static_cast<ElementCompute const*>(conv_args->theta)}; | |||
| } | |||
| }; | |||
| } // namespace detail | |||
| template <typename EpilogueOp> | |||
| struct init_epilogue_param | |||
| : public detail::init_epilogue_param_<EpilogueOp, EpilogueOp::kType> {}; | |||
| /////////////////////////////////////////////////////////////////////////////////////////////////// | |||
| template <typename Operator_> | |||
| class ConvolutionOperation : public ConvolutionOperationBase<Operator_> { | |||
| public: | |||
| using Operator = Operator_; | |||
| using ElementSrc = typename Operator::ElementSrc; | |||
| using LayoutSrc = typename Operator::LayoutSrc; | |||
| using ElementFilter = typename Operator::ElementFilter; | |||
| using LayoutFilter = typename Operator::LayoutFilter; | |||
| using ElementBias = typename Operator::ElementBias; | |||
| using LayoutBias = typename Operator::LayoutBias; | |||
| using ElementDst = typename Operator::ElementDst; | |||
| using LayoutDst = typename Operator::LayoutDst; | |||
| using ElementAccumulator = typename Operator::ElementAccumulator; | |||
| using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; | |||
| using OperatorArguments = typename Operator::Arguments; | |||
| ConvolutionOperation(char const* name = "unknown_gemm") | |||
| : ConvolutionOperationBase<Operator_>(name) {} | |||
| virtual Status run(void const* arguments_ptr, | |||
| void* device_workspace = nullptr, | |||
| cudaStream_t stream = nullptr) const { | |||
| cutlass::conv::Operator conv_op = this->m_description.conv_op; | |||
| ConvolutionArguments const* conv_args = | |||
| reinterpret_cast<ConvolutionArguments const*>(arguments_ptr); | |||
| const auto& ps = conv_args->problem_size; | |||
| OperatorArguments args; | |||
| args.problem_size = ps; | |||
| args.ref_src = { | |||
| static_cast<ElementSrc*>(const_cast<void*>(conv_args->src)), | |||
| LayoutSrc::packed(implicit_gemm_tensor_a_extent(conv_op, ps))}; | |||
| args.ref_filter = {static_cast<ElementFilter*>( | |||
| const_cast<void*>(conv_args->filter)), | |||
| LayoutFilter::packed( | |||
| implicit_gemm_tensor_b_extent(conv_op, ps))}; | |||
| args.ref_bias = { | |||
| static_cast<ElementBias*>(const_cast<void*>(conv_args->bias)), | |||
| LayoutBias::packed( | |||
| implicit_gemm_tensor_bias_extent(conv_op, ps))}; | |||
| args.ref_z = { | |||
| static_cast<ElementDst*>(const_cast<void*>(conv_args->z)), | |||
| LayoutDst::packed(implicit_gemm_tensor_c_extent(conv_op, ps))}; | |||
| args.ref_dst = { | |||
| static_cast<ElementDst*>(conv_args->dst), | |||
| LayoutDst::packed(implicit_gemm_tensor_c_extent(conv_op, ps))}; | |||
| args.output_op = | |||
| init_epilogue_param<typename Operator::EpilogueOutputOp>().get( | |||
| conv_args); | |||
| if (conv_args->extra_param) { | |||
| args.extra_param = | |||
| *reinterpret_cast<typename Operator::ExtraParam const*>( | |||
| conv_args->extra_param); | |||
| } | |||
| Operator op; | |||
| Status status = op.initialize(args, device_workspace); | |||
| if (status != Status::kSuccess) { | |||
| return status; | |||
| } | |||
| return op.run(stream); | |||
| } | |||
| }; | |||
| /////////////////////////////////////////////////////////////////////////////////////////////////// | |||
| } // namespace library | |||
| } // namespace cutlass | |||
| /////////////////////////////////////////////////////////////////////////////////////////////////// | |||
| @@ -0,0 +1,202 @@ | |||
| /*************************************************************************************************** | |||
| * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. | |||
| * | |||
| * Redistribution and use in source and binary forms, with or without | |||
| *modification, are permitted provided that the following conditions are met: | |||
| * * Redistributions of source code must retain the above copyright notice, | |||
| *this list of conditions and the following disclaimer. | |||
| * * Redistributions in binary form must reproduce the above copyright | |||
| *notice, this list of conditions and the following disclaimer in the | |||
| *documentation and/or other materials provided with the distribution. | |||
| * * Neither the name of the NVIDIA CORPORATION nor the names of its | |||
| *contributors may be used to endorse or promote products derived from this | |||
| *software without specific prior written permission. | |||
| * | |||
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
| *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
| *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |||
| *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, | |||
| *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, | |||
| * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | |||
| *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY | |||
| *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TOR (INCLUDING | |||
| *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, | |||
| *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
| * | |||
| **************************************************************************************************/ | |||
| /** | |||
| * \file dnn/src/cuda/cutlass/gemm_operation.h | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include "cutlass/gemm/device/gemm.h" | |||
| #include "src/cuda/cutlass/library_internal.h" | |||
| /////////////////////////////////////////////////////////////////////////////////////////////////// | |||
| namespace cutlass { | |||
| namespace library { | |||
| /////////////////////////////////////////////////////////////////////////////////////////////////// | |||
| /// Check whether Operator has member ReductionKernel using SFINAE (Substitution | |||
| /// Failure Is Not An Error) | |||
| template <typename Operator> | |||
| struct split_k_mode { | |||
| template <typename T> | |||
| static char check(typename T::ReductionKernel*); | |||
| template <typename T> | |||
| static int check(...); | |||
| SplitKMode operator()() { | |||
| if (sizeof(check<Operator>(0)) == sizeof(char)) { | |||
| // cutlass::gemm::device::GemmSplitKParallel | |||
| return SplitKMode::kParallel; | |||
| } else { | |||
| // cutlass::gemm::device::Gemm | |||
| return SplitKMode::kNone; | |||
| } | |||
| } | |||
| }; | |||
| /////////////////////////////////////////////////////////////////////////////////////////////////// | |||
| template <typename Operator_> | |||
| class GemmOperationBase : public Operation { | |||
| public: | |||
| using Operator = Operator_; | |||
| using ElementA = typename Operator::ElementA; | |||
| using LayoutA = typename Operator::LayoutA; | |||
| using ElementB = typename Operator::ElementB; | |||
| using LayoutB = typename Operator::LayoutB; | |||
| using ElementC = typename Operator::ElementC; | |||
| using LayoutC = typename Operator::LayoutC; | |||
| using ElementAccumulator = typename Operator::ElementAccumulator; | |||
| GemmOperationBase(char const* name = "unknown_gemm") { | |||
| m_description.name = name; | |||
| m_description.provider = Provider::kCUTLASS; | |||
| m_description.kind = OperationKind::kGemm; | |||
| m_description.gemm_kind = GemmKind::kGemm; | |||
| m_description.tile_description.threadblock_shape = make_Coord( | |||
| Operator::ThreadblockShape::kM, Operator::ThreadblockShape::kN, | |||
| Operator::ThreadblockShape::kK); | |||
| m_description.tile_description.threadblock_stages = Operator::kStages; | |||
| m_description.tile_description.warp_count = | |||
| make_Coord(Operator::GemmKernel::WarpCount::kM, | |||
| Operator::GemmKernel::WarpCount::kN, | |||
| Operator::GemmKernel::WarpCount::kK); | |||
| m_description.tile_description.math_instruction.instruction_shape = | |||
| make_Coord(Operator::InstructionShape::kM, | |||
| Operator::InstructionShape::kN, | |||
| Operator::InstructionShape::kK); | |||
| m_description.tile_description.math_instruction.element_accumulator = | |||
| NumericTypeMap<ElementAccumulator>::kId; | |||
| m_description.tile_description.math_instruction.opcode_class = | |||
| OpcodeClassMap<typename Operator::OperatorClass>::kId; | |||
| m_description.tile_description.math_instruction.math_operation = | |||
| MathOperationMap<typename Operator::Operator>::kId; | |||
| m_description.tile_description.minimum_compute_capability = | |||
| ArchMap<typename Operator::ArchTag, | |||
| typename Operator::OperatorClass>::kMin; | |||
| m_description.tile_description.maximum_compute_capability = | |||
| ArchMap<typename Operator::ArchTag, | |||
| typename Operator::OperatorClass>::kMax; | |||
| m_description.A = make_TensorDescription<ElementA, LayoutA>( | |||
| Operator::kAlignmentA); | |||
| m_description.B = make_TensorDescription<ElementB, LayoutB>( | |||
| Operator::kAlignmentB); | |||
| m_description.C = make_TensorDescription<ElementC, LayoutC>( | |||
| Operator::kAlignmentC); | |||
| m_description.stages = Operator::kStages; | |||
| split_k_mode<Operator> mode; | |||
| m_description.split_k_mode = mode(); | |||
| } | |||
| virtual OperationDescription const& description() const { | |||
| return m_description; | |||
| } | |||
| protected: | |||
| GemmDescription m_description; | |||
| }; | |||
| /////////////////////////////////////////////////////////////////////////////////////////////////// | |||
| template <typename Operator_> | |||
| class GemmOperation : public GemmOperationBase<Operator_> { | |||
| public: | |||
| using Operator = Operator_; | |||
| using ElementA = typename Operator::ElementA; | |||
| using LayoutA = typename Operator::LayoutA; | |||
| using ElementB = typename Operator::ElementB; | |||
| using LayoutB = typename Operator::LayoutB; | |||
| using ElementC = typename Operator::ElementC; | |||
| using LayoutC = typename Operator::LayoutC; | |||
| using ElementAccumulator = typename Operator::ElementAccumulator; | |||
| using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; | |||
| using OperatorArguments = typename Operator::Arguments; | |||
| GemmOperation(char const* name = "unknown_gemm") | |||
| : GemmOperationBase<Operator_>(name) {} | |||
| virtual Status run(void const* arguments_ptr, | |||
| void* device_workspace = nullptr, | |||
| cudaStream_t stream = nullptr) const { | |||
| GemmArguments const* gemm_args = | |||
| reinterpret_cast<GemmArguments const*>(arguments_ptr); | |||
| OperatorArguments args; | |||
| args.problem_size = gemm_args->problem_size; | |||
| args.ref_A = {static_cast<ElementA const*>(gemm_args->A), | |||
| int(gemm_args->lda)}; | |||
| args.ref_B = {static_cast<ElementB const*>(gemm_args->B), | |||
| int(gemm_args->ldb)}; | |||
| args.ref_C = {static_cast<ElementC const*>(gemm_args->C), | |||
| int(gemm_args->ldc)}; | |||
| args.ref_D = {static_cast<ElementC*>(gemm_args->D), | |||
| int(gemm_args->ldd)}; | |||
| args.split_k_slices = gemm_args->split_k_slices; | |||
| args.epilogue = {*static_cast<ElementCompute const*>(gemm_args->alpha), | |||
| *static_cast<ElementCompute const*>(gemm_args->beta)}; | |||
| Operator op; | |||
| Status status = op.initialize(args, device_workspace); | |||
| if (status != Status::kSuccess) { | |||
| return status; | |||
| } | |||
| return op.run(stream); | |||
| } | |||
| }; | |||
| /////////////////////////////////////////////////////////////////////////////////////////////////// | |||
| } // namespace library | |||
| } // namespace cutlass | |||
| /////////////////////////////////////////////////////////////////////////////////////////////////// | |||
| @@ -0,0 +1,76 @@ | |||
| /*************************************************************************************************** | |||
| * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | |||
| * | |||
| * Redistribution and use in source and binary forms, with or without | |||
| *modification, are permitted provided that the following conditions are met: | |||
| * * Redistributions of source code must retain the above copyright notice, | |||
| *this list of conditions and the following disclaimer. | |||
| * * Redistributions in binary form must reproduce the above copyright | |||
| *notice, this list of conditions and the following disclaimer in the | |||
| *documentation and/or other materials provided with the distribution. | |||
| * * Neither the name of the NVIDIA CORPORATION nor the names of its | |||
| *contributors may be used to endorse or promote products derived from this | |||
| *software without specific prior written permission. | |||
| * | |||
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
| *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
| *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |||
| *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, | |||
| *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, | |||
| * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | |||
| *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY | |||
| *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TOR (INCLUDING | |||
| *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, | |||
| *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
| * | |||
| **************************************************************************************************/ | |||
| /** | |||
| * \file dnn/src/cuda/cutlass/initialize_all.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/cuda/cutlass/manifest.h" | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| namespace cutlass { | |||
| namespace library { | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| #if __CUDACC_VER_MAJOR__ > 9 || \ | |||
| (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) | |||
| void initialize_all_gemm_simt_operations(Manifest& manifest); | |||
| void initialize_all_conv2d_simt_operations(Manifest& manifest); | |||
| void initialize_all_conv2d_tensorop8816_operations(Manifest& manifest); | |||
| void initialize_all_conv2d_tensorop8832_operations(Manifest& manifest); | |||
| void initialize_all_deconv_simt_operations(Manifest& manifest); | |||
| void initialize_all(Manifest& manifest) { | |||
| initialize_all_gemm_simt_operations(manifest); | |||
| initialize_all_conv2d_simt_operations(manifest); | |||
| initialize_all_conv2d_tensorop8816_operations(manifest); | |||
| initialize_all_conv2d_tensorop8832_operations(manifest); | |||
| initialize_all_deconv_simt_operations(manifest); | |||
| } | |||
| #else | |||
| void initialize_all(Manifest& manifest) {} | |||
| #endif | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| } // namespace library | |||
| } // namespace cutlass | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| @@ -0,0 +1,541 @@ | |||
| /*************************************************************************************************** | |||
| * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | |||
| * | |||
| * Redistribution and use in source and binary forms, with or without | |||
| *modification, are permitted provided that the following conditions are met: | |||
| * * Redistributions of source code must retain the above copyright notice, | |||
| *this list of conditions and the following disclaimer. | |||
| * * Redistributions in binary form must reproduce the above copyright | |||
| *notice, this list of conditions and the following disclaimer in the | |||
| *documentation and/or other materials provided with the distribution. | |||
| * * Neither the name of the NVIDIA CORPORATION nor the names of its | |||
| *contributors may be used to endorse or promote products derived from this | |||
| *software without specific prior written permission. | |||
| * | |||
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
| *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
| *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |||
| *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, | |||
| *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, | |||
| * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | |||
| *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY | |||
| *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TOR (INCLUDING | |||
| *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, | |||
| *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
| * | |||
| **************************************************************************************************/ | |||
| /** | |||
| * \file dnn/src/cuda/cutlass/library.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| #include <cuda_runtime.h> | |||
| #include <cstdint> | |||
| #include <stdexcept> | |||
| #include <string> | |||
| #include <vector> | |||
| #pragma GCC diagnostic push | |||
| #pragma GCC diagnostic ignored "-Wreorder" | |||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||
| #include "cutlass/cutlass.h" | |||
| #include "cutlass/layout/tensor.h" | |||
| #include "cutlass/matrix_coord.h" | |||
| #include "cutlass/tensor_coord.h" | |||
| #include "cutlass/conv/conv2d_problem_size.h" | |||
| #include "cutlass/conv/convolution.h" | |||
| #include "cutlass/epilogue/epilogue.h" | |||
| #include "cutlass/gemm/gemm.h" | |||
| #pragma GCC diagnostic pop | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| namespace cutlass { | |||
| namespace library { | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| /// Layout type identifier | |||
| enum class LayoutTypeID { | |||
| kUnknown, | |||
| kColumnMajor, | |||
| kRowMajor, | |||
| kColumnMajorInterleavedK2, | |||
| kRowMajorInterleavedK2, | |||
| kColumnMajorInterleavedK4, | |||
| kRowMajorInterleavedK4, | |||
| kColumnMajorInterleavedK16, | |||
| kRowMajorInterleavedK16, | |||
| kColumnMajorInterleavedK32, | |||
| kRowMajorInterleavedK32, | |||
| kColumnMajorInterleavedK64, | |||
| kRowMajorInterleavedK64, | |||
| kTensorNCHW, | |||
| kTensorNCDHW, | |||
| kTensorNHWC, | |||
| kTensorNDHWC, | |||
| kTensorNC4HW4, | |||
| kTensorC4RSK4, | |||
| kTensorNC8HW8, | |||
| kTensorC8RSK8, | |||
| kTensorNC16HW16, | |||
| kTensorC16RSK16, | |||
| kTensorNC32HW32, | |||
| kTensorC32RSK32, | |||
| kTensorNC64HW64, | |||
| kTensorC64RSK64, | |||
| kTensorK4RSC4, | |||
| kInvalid | |||
| }; | |||
| /// Numeric data type | |||
| enum class NumericTypeID { | |||
| kUnknown, | |||
| kVoid, | |||
| kB1, | |||
| kU2, | |||
| kU4, | |||
| kU8, | |||
| kU16, | |||
| kU32, | |||
| kU64, | |||
| kS2, | |||
| kS4, | |||
| kS8, | |||
| kS16, | |||
| kS32, | |||
| kS64, | |||
| kF16, | |||
| kBF16, | |||
| kTF32, | |||
| kF32, | |||
| kF64, | |||
| kCF16, | |||
| kCBF16, | |||
| kCF32, | |||
| kCTF32, | |||
| kCF64, | |||
| kCS2, | |||
| kCS4, | |||
| kCS8, | |||
| kCS16, | |||
| kCS32, | |||
| kCS64, | |||
| kCU2, | |||
| kCU4, | |||
| kCU8, | |||
| kCU16, | |||
| kCU32, | |||
| kCU64, | |||
| kInvalid | |||
| }; | |||
| /// Enumerated type describing a transformation on a complex value. | |||
| enum class ComplexTransform { kNone, kConjugate, kInvalid }; | |||
| /// Providers | |||
| enum class Provider { | |||
| kNone, | |||
| kCUTLASS, | |||
| kReferenceHost, | |||
| kReferenceDevice, | |||
| kCUBLAS, | |||
| kCUDNN, | |||
| kInvalid | |||
| }; | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| /// Enumeration indicating the kind of operation | |||
| enum class OperationKind { | |||
| kGemm, | |||
| kConv2d, | |||
| kConv3d, | |||
| kConvolution, | |||
| kEqGemm, | |||
| kSparseGemm, | |||
| kReduction, | |||
| kInvalid | |||
| }; | |||
| /// Enumeration indicating whether scalars are in host or device memory | |||
| enum class ScalarPointerMode { kHost, kDevice, kInvalid }; | |||
| /// Describes how reductions are performed across threadblocks | |||
| enum class SplitKMode { kNone, kSerial, kParallel, kParallelSerial, kInvalid }; | |||
| /// Indicates the classificaition of the math instruction | |||
| enum class OpcodeClassID { | |||
| kSimt, | |||
| kTensorOp, | |||
| kWmmaTensorOp, | |||
| kSparseTensorOp, | |||
| kInvalid | |||
| }; | |||
| enum class ArchTagID { | |||
| kSm50, | |||
| kSm60, | |||
| kSm61, | |||
| kSm70, | |||
| kSm72, | |||
| kSm75, | |||
| kSm80, | |||
| kSm86, | |||
| kInvalid | |||
| }; | |||
| enum class MathOperationID { | |||
| kAdd, | |||
| kMultiplyAdd, | |||
| kMultiplyAddSaturate, | |||
| kMultiplyAddFastBF16, | |||
| kMultiplyAddFastF16, | |||
| kMultiplyAddComplex, | |||
| kMultiplyAddGaussianComplex, | |||
| kXorPopc, | |||
| kInvalid | |||
| }; | |||
| enum class ThreadblockSwizzleID { | |||
| kGemmIdentity, | |||
| kGemmHorizontal, | |||
| kGemmBatchedIdentity, | |||
| kGemmSplitKIdentity, | |||
| kGemmSplitKHorizontal, | |||
| kGemvBatchedStridedDefault, | |||
| kGemvBatchedStridedReduction, | |||
| kConvolutionFpropCxRSKx, | |||
| kConvolutionDgradCxRSKx, | |||
| kConvolutionFpropNCxHWx, | |||
| kConvolutionFpropTrans, | |||
| kConvolutionDgradNCxHWx, | |||
| kInvalid | |||
| }; | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| /// Enumeration indicating what kind of GEMM operation to perform | |||
| enum class GemmKind { | |||
| kGemm, | |||
| kSparse, | |||
| kUniversal, | |||
| kPlanarComplex, | |||
| kPlanarComplexArray, | |||
| kInvalid | |||
| }; | |||
| /// Mode of Universal GEMM | |||
| using GemmUniversalMode = cutlass::gemm::GemmUniversalMode; | |||
| /// Enumeration indicating what kind of Conv2d operation to perform | |||
| enum class ConvKind { kUnknown, kFprop, kDgrad, kWgrad, kInvalid }; | |||
| enum class ConvModeID { kCrossCorrelation, kConvolution, kInvalid }; | |||
| // Iterator algorithm enum in order of general performance-efficiency | |||
| enum class IteratorAlgorithmID { kNone, kAnalytic, kOptimized, kInvalid }; | |||
| enum class EpilogueKind { | |||
| kUnknown, | |||
| kBiasAddLinearCombination, | |||
| kBiasAddLinearCombinationClamp, | |||
| kBiasAddLInearCombinationHSwish, | |||
| kBiasAddLInearCombinationHSwishClamp, | |||
| kBiasAddLInearCombinationRelu, | |||
| kBiasAddLInearCombinationReluClamp, | |||
| kConversion, | |||
| kLinearCombination, | |||
| kLinearCombinationClamp, | |||
| kLinearCombinationPlanarComplex, | |||
| kLinearCombinationRelu, | |||
| kLinearCombinationSigmoid, | |||
| kInvalid | |||
| }; | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| struct MathInstructionDescription { | |||
| /// Shape of the target math instruction | |||
| cutlass::gemm::GemmCoord instruction_shape; | |||
| /// Describes the data type of the internal accumulator | |||
| NumericTypeID element_accumulator; | |||
| /// Classification of math instruction | |||
| OpcodeClassID opcode_class; | |||
| /// Type of math operation performed | |||
| MathOperationID math_operation; | |||
| // | |||
| // Methods | |||
| // | |||
| MathInstructionDescription( | |||
| cutlass::gemm::GemmCoord instruction_shape = | |||
| cutlass::gemm::GemmCoord(), | |||
| NumericTypeID element_accumulator = NumericTypeID::kInvalid, | |||
| OpcodeClassID opcode_class = OpcodeClassID::kInvalid, | |||
| MathOperationID math_operation = MathOperationID::kMultiplyAdd) | |||
| : instruction_shape(instruction_shape), | |||
| element_accumulator(element_accumulator), | |||
| opcode_class(opcode_class), | |||
| math_operation(math_operation) {} | |||
| // Equality operator | |||
| inline bool operator==(MathInstructionDescription const& rhs) const { | |||
| return ((instruction_shape == rhs.instruction_shape) && | |||
| (element_accumulator == rhs.element_accumulator) && | |||
| (opcode_class == rhs.opcode_class) && | |||
| (math_operation == rhs.math_operation)); | |||
| } | |||
| // Inequality operator | |||
| inline bool operator!=(MathInstructionDescription const& rhs) const { | |||
| return !(*this == rhs); | |||
| } | |||
| }; | |||
| /// Structure describing the tiled structure of a GEMM-like computation | |||
| struct TileDescription { | |||
| /// Describes the shape of a threadblock (in elements) | |||
| cutlass::gemm::GemmCoord threadblock_shape; | |||
| /// Describes the number of pipeline stages in the threadblock-scoped | |||
| /// mainloop | |||
| int threadblock_stages; | |||
| /// Number of warps in each logical dimension | |||
| cutlass::gemm::GemmCoord warp_count; | |||
| /// Core math instruction | |||
| MathInstructionDescription math_instruction; | |||
| /// Minimum compute capability (e.g. 70, 75) of a device eligible to run the | |||
| /// operation. | |||
| int minimum_compute_capability; | |||
| /// Minimum compute capability (e.g. 70, 75) of a device eligible to run the | |||
| /// operation. | |||
| int maximum_compute_capability; | |||
| // | |||
| // Methods | |||
| // | |||
| TileDescription( | |||
| cutlass::gemm::GemmCoord threadblock_shape = | |||
| cutlass::gemm::GemmCoord(), | |||
| int threadblock_stages = 0, | |||
| cutlass::gemm::GemmCoord warp_count = cutlass::gemm::GemmCoord(), | |||
| MathInstructionDescription math_instruction = | |||
| MathInstructionDescription(), | |||
| int minimum_compute_capability = 0, | |||
| int maximum_compute_capability = 0) | |||
| : threadblock_shape(threadblock_shape), | |||
| threadblock_stages(threadblock_stages), | |||
| warp_count(warp_count), | |||
| math_instruction(math_instruction), | |||
| minimum_compute_capability(minimum_compute_capability), | |||
| maximum_compute_capability(maximum_compute_capability) {} | |||
| // Equality operator | |||
| inline bool operator==(TileDescription const& rhs) const { | |||
| return ((threadblock_shape == rhs.threadblock_shape) && | |||
| (threadblock_stages == rhs.threadblock_stages) && | |||
| (warp_count == rhs.warp_count) && | |||
| (math_instruction == rhs.math_instruction) && | |||
| (minimum_compute_capability == | |||
| rhs.minimum_compute_capability) && | |||
| (maximum_compute_capability == rhs.maximum_compute_capability)); | |||
| } | |||
| // Inequality operator | |||
| inline bool operator!=(TileDescription const& rhs) const { | |||
| return !(*this == rhs); | |||
| } | |||
| }; | |||
| /// High-level description of an operation | |||
| struct OperationDescription { | |||
| /// Unique identifier describing the operation | |||
| char const* name; | |||
| /// Operation provider | |||
| Provider provider; | |||
| /// Kind of operation | |||
| OperationKind kind; | |||
| /// Describes the tiled structure of a GEMM-like computation | |||
| TileDescription tile_description; | |||
| // | |||
| // Methods | |||
| // | |||
| OperationDescription( | |||
| char const* name = "unknown", | |||
| OperationKind kind = OperationKind::kInvalid, | |||
| TileDescription const& tile_description = TileDescription()) | |||
| : name(name), kind(kind), tile_description(tile_description) {} | |||
| }; | |||
| /// Structure describing the properties of a tensor | |||
| struct TensorDescription { | |||
| /// Numeric type of an individual element | |||
| NumericTypeID element; | |||
| /// Enumerant identifying the layout function for the tensor | |||
| LayoutTypeID layout; | |||
| /// Alignment restriction on pointers, strides, and extents | |||
| int alignment; | |||
| /// log2() of the maximum extent of each dimension | |||
| int log_extent_range; | |||
| /// log2() of the maximum value each relevant stride may have | |||
| int log_stride_range; | |||
| // | |||
| // Methods | |||
| // | |||
| TensorDescription(NumericTypeID element = NumericTypeID::kInvalid, | |||
| LayoutTypeID layout = LayoutTypeID::kInvalid, | |||
| int alignment = 1, int log_extent_range = 24, | |||
| int log_stride_range = 24) | |||
| : element(element), | |||
| layout(layout), | |||
| alignment(alignment), | |||
| log_extent_range(log_extent_range), | |||
| log_stride_range(log_stride_range) {} | |||
| }; | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| struct GemmDescription : public OperationDescription { | |||
| GemmKind gemm_kind; | |||
| TensorDescription A; | |||
| TensorDescription B; | |||
| TensorDescription C; | |||
| int stages; | |||
| SplitKMode split_k_mode; | |||
| }; | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| struct GemmArguments { | |||
| /// GEMM problem size | |||
| gemm::GemmCoord problem_size; | |||
| /// Device pointers to input and output matrices | |||
| void const* A; | |||
| void const* B; | |||
| void const* C; | |||
| void* D; | |||
| /// Leading dimensions of input and output matrices | |||
| int64_t lda; | |||
| int64_t ldb; | |||
| int64_t ldc; | |||
| int64_t ldd; | |||
| /// Number of partitions of K dimension | |||
| int split_k_slices; | |||
| /// Host or device pointers to epilogue scalars, note that these pointers | |||
| /// will be interpreted as ElementCompute* in method `op->run(args)`, a | |||
| /// different dtype here results in undefined epilogue behaviors | |||
| void const* alpha; | |||
| void const* beta; | |||
| }; | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| struct ConvolutionDescription : public OperationDescription { | |||
| conv::Operator conv_op; | |||
| TensorDescription src; | |||
| TensorDescription filter; | |||
| TensorDescription dst; | |||
| TensorDescription bias; | |||
| conv::ConvType convolution_type; | |||
| ArchTagID arch_tag; | |||
| epilogue::EpilogueType epilogue_type; | |||
| int epilogue_count; | |||
| ThreadblockSwizzleID threadblock_swizzle; | |||
| bool need_load_from_const_mem; | |||
| conv::ImplicitGemmMode gemm_mode; | |||
| bool without_shared_load; | |||
| }; | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| struct ConvolutionArguments { | |||
| /// Problem size | |||
| conv::Conv2dProblemSize problem_size; | |||
| /// Device pointers to input and output tensors | |||
| void const* src; | |||
| void const* filter; | |||
| void const* bias; | |||
| void const* z; | |||
| void* dst; | |||
| /// Host or device pointers to epilogue scalars, note that these pointers | |||
| /// will be interpreted as ElementCompute* in method `op->run(args)`, a | |||
| /// different dtype here results in undefined epilogue behaviors | |||
| void const* alpha; | |||
| void const* beta; | |||
| void const* gamma; | |||
| void const* delta; | |||
| void const* theta; | |||
| void const* threshold; | |||
| void const* scale; | |||
| /// Host pointer to extra param struct | |||
| void const* extra_param; | |||
| }; | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| /// Base class for all operations | |||
| class Operation { | |||
| public: | |||
| virtual ~Operation() {} | |||
| virtual OperationDescription const& description() const = 0; | |||
| virtual Status run(void const* arguments, void* device_workspace = nullptr, | |||
| cudaStream_t stream = nullptr) const = 0; | |||
| }; | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| } // namespace library | |||
| } // namespace cutlass | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| @@ -0,0 +1,580 @@ | |||
| /*************************************************************************************************** | |||
| * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. | |||
| * | |||
| * Redistribution and use in source and binary forms, with or without | |||
| *modification, are permitted provided that the following conditions are met: | |||
| * * Redistributions of source code must retain the above copyright notice, | |||
| *this list of conditions and the following disclaimer. | |||
| * * Redistributions in binary form must reproduce the above copyright | |||
| *notice, this list of conditions and the following disclaimer in the | |||
| *documentation and/or other materials provided with the distribution. | |||
| * * Neither the name of the NVIDIA CORPORATION nor the names of its | |||
| *contributors may be used to endorse or promote products derived from this | |||
| *software without specific prior written permission. | |||
| * | |||
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
| *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
| *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |||
| *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, | |||
| *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, | |||
| * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | |||
| *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY | |||
| *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TOR (INCLUDING | |||
| *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, | |||
| *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
| * | |||
| **************************************************************************************************/ | |||
| /** | |||
| * \file dnn/src/cuda/cutlass/library_internal.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #pragma GCC diagnostic push | |||
| #pragma GCC diagnostic ignored "-Wreorder" | |||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||
| #include "cutlass/arch/arch.h" | |||
| #include "cutlass/arch/mma.h" | |||
| #include "cutlass/complex.h" | |||
| #include "cutlass/convolution/threadblock/threadblock_swizzle.h" | |||
| #include "cutlass/cutlass.h" | |||
| #include "cutlass/gemm/threadblock/threadblock_swizzle.h" | |||
| #include "cutlass/layout/matrix.h" | |||
| #include "cutlass/numeric_types.h" | |||
| #pragma GCC diagnostic pop | |||
| #include "src/cuda/cutlass/arch_mappings.h" | |||
| #include "src/cuda/cutlass/library.h" | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| namespace cutlass { | |||
| namespace library { | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| template <typename T> | |||
| struct NumericTypeMap; | |||
| template <> | |||
| struct NumericTypeMap<cutlass::uint1b_t> { | |||
| static NumericTypeID const kId = NumericTypeID::kB1; | |||
| }; | |||
| template <> | |||
| struct NumericTypeMap<cutlass::int4b_t> { | |||
| static NumericTypeID const kId = NumericTypeID::kS4; | |||
| }; | |||
| template <> | |||
| struct NumericTypeMap<int8_t> { | |||
| static NumericTypeID const kId = NumericTypeID::kS8; | |||
| }; | |||
| template <> | |||
| struct NumericTypeMap<int16_t> { | |||
| static NumericTypeID const kId = NumericTypeID::kS16; | |||
| }; | |||
| template <> | |||
| struct NumericTypeMap<int32_t> { | |||
| static NumericTypeID const kId = NumericTypeID::kS32; | |||
| }; | |||
| template <> | |||
| struct NumericTypeMap<int64_t> { | |||
| static NumericTypeID const kId = NumericTypeID::kS64; | |||
| }; | |||
| template <> | |||
| struct NumericTypeMap<cutlass::uint4b_t> { | |||
| static NumericTypeID const kId = NumericTypeID::kU4; | |||
| }; | |||
| template <> | |||
| struct NumericTypeMap<uint8_t> { | |||
| static NumericTypeID const kId = NumericTypeID::kU8; | |||
| }; | |||
| template <> | |||
| struct NumericTypeMap<uint16_t> { | |||
| static NumericTypeID const kId = NumericTypeID::kU16; | |||
| }; | |||
| template <> | |||
| struct NumericTypeMap<uint32_t> { | |||
| static NumericTypeID const kId = NumericTypeID::kU32; | |||
| }; | |||
| template <> | |||
| struct NumericTypeMap<uint64_t> { | |||
| static NumericTypeID const kId = NumericTypeID::kU64; | |||
| }; | |||
| template <> | |||
| struct NumericTypeMap<cutlass::half_t> { | |||
| static NumericTypeID const kId = NumericTypeID::kF16; | |||
| }; | |||
| template <> | |||
| struct NumericTypeMap<float> { | |||
| static NumericTypeID const kId = NumericTypeID::kF32; | |||
| }; | |||
| template <> | |||
| struct NumericTypeMap<double> { | |||
| static NumericTypeID const kId = NumericTypeID::kF64; | |||
| }; | |||
| template <> | |||
| struct NumericTypeMap<cutlass::complex<cutlass::half_t>> { | |||
| static NumericTypeID const kId = NumericTypeID::kCF16; | |||
| }; | |||
| template <> | |||
| struct NumericTypeMap<cutlass::complex<float>> { | |||
| static NumericTypeID const kId = NumericTypeID::kCF32; | |||
| }; | |||
| template <> | |||
| struct NumericTypeMap<cutlass::complex<double>> { | |||
| static NumericTypeID const kId = NumericTypeID::kCF64; | |||
| }; | |||
| template <> | |||
| struct NumericTypeMap<cutlass::bfloat16_t> { | |||
| static NumericTypeID const kId = NumericTypeID::kBF16; | |||
| }; | |||
| template <> | |||
| struct NumericTypeMap<cutlass::tfloat32_t> { | |||
| static NumericTypeID const kId = NumericTypeID::kTF32; | |||
| }; | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| template <typename T> | |||
| struct MathOperationMap { | |||
| static MathOperationID const kId = MathOperationID::kInvalid; | |||
| }; | |||
| template <> | |||
| struct MathOperationMap<cutlass::arch::OpMultiplyAdd> { | |||
| static MathOperationID const kId = MathOperationID::kMultiplyAdd; | |||
| }; | |||
| template <> | |||
| struct MathOperationMap<cutlass::arch::OpMultiplyAddFastBF16> { | |||
| static MathOperationID const kId = MathOperationID::kMultiplyAddFastBF16; | |||
| }; | |||
| template <> | |||
| struct MathOperationMap<cutlass::arch::OpMultiplyAddFastF16> { | |||
| static MathOperationID const kId = MathOperationID::kMultiplyAddFastF16; | |||
| }; | |||
| template <> | |||
| struct MathOperationMap<cutlass::arch::OpMultiplyAddSaturate> { | |||
| static MathOperationID const kId = MathOperationID::kMultiplyAddSaturate; | |||
| }; | |||
| template <> | |||
| struct MathOperationMap<cutlass::arch::OpMultiplyAddComplex> { | |||
| static MathOperationID const kId = MathOperationID::kMultiplyAddComplex; | |||
| }; | |||
| template <> | |||
| struct MathOperationMap<cutlass::arch::OpMultiplyAddGaussianComplex> { | |||
| static MathOperationID const kId = | |||
| MathOperationID::kMultiplyAddGaussianComplex; | |||
| }; | |||
| template <> | |||
| struct MathOperationMap<cutlass::arch::OpXorPopc> { | |||
| static MathOperationID const kId = MathOperationID::kXorPopc; | |||
| }; | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| template <typename T> | |||
| struct LayoutMap; | |||
| template <> | |||
| struct LayoutMap<cutlass::layout::ColumnMajor> { | |||
| static LayoutTypeID const kId = LayoutTypeID::kColumnMajor; | |||
| }; | |||
| template <> | |||
| struct LayoutMap<cutlass::layout::RowMajor> { | |||
| static LayoutTypeID const kId = LayoutTypeID::kRowMajor; | |||
| }; | |||
| template <> | |||
| struct LayoutMap<cutlass::layout::ColumnMajorInterleaved<2>> { | |||
| static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK2; | |||
| }; | |||
| template <> | |||
| struct LayoutMap<cutlass::layout::RowMajorInterleaved<2>> { | |||
| static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK2; | |||
| }; | |||
| template <> | |||
| struct LayoutMap<cutlass::layout::ColumnMajorInterleaved<4>> { | |||
| static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK4; | |||
| }; | |||
| template <> | |||
| struct LayoutMap<cutlass::layout::RowMajorInterleaved<4>> { | |||
| static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK4; | |||
| }; | |||
| template <> | |||
| struct LayoutMap<cutlass::layout::ColumnMajorInterleaved<16>> { | |||
| static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK16; | |||
| }; | |||
| template <> | |||
| struct LayoutMap<cutlass::layout::RowMajorInterleaved<16>> { | |||
| static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK16; | |||
| }; | |||
| template <> | |||
| struct LayoutMap<cutlass::layout::ColumnMajorInterleaved<32>> { | |||
| static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK32; | |||
| }; | |||
| template <> | |||
| struct LayoutMap<cutlass::layout::RowMajorInterleaved<32>> { | |||
| static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK32; | |||
| }; | |||
| template <> | |||
| struct LayoutMap<cutlass::layout::ColumnMajorInterleaved<64>> { | |||
| static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK64; | |||
| }; | |||
| template <> | |||
| struct LayoutMap<cutlass::layout::RowMajorInterleaved<64>> { | |||
| static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK64; | |||
| }; | |||
| template <> | |||
| struct LayoutMap<cutlass::layout::TensorNCHW> { | |||
| static LayoutTypeID const kId = LayoutTypeID::kTensorNCHW; | |||
| }; | |||
| template <> | |||
| struct LayoutMap<cutlass::layout::TensorNHWC> { | |||
| static LayoutTypeID const kId = LayoutTypeID::kTensorNHWC; | |||
| }; | |||
| template <> | |||
| struct LayoutMap<cutlass::layout::TensorNDHWC> { | |||
| static LayoutTypeID const kId = LayoutTypeID::kTensorNDHWC; | |||
| }; | |||
| template <> | |||
| struct LayoutMap<cutlass::layout::TensorNCxHWx<4>> { | |||
| static LayoutTypeID const kId = LayoutTypeID::kTensorNC4HW4; | |||
| }; | |||
| template <> | |||
| struct LayoutMap<cutlass::layout::TensorNCxHWx<8>> { | |||
| static LayoutTypeID const kId = LayoutTypeID::kTensorNC8HW8; | |||
| }; | |||
| template <> | |||
| struct LayoutMap<cutlass::layout::TensorNCxHWx<16>> { | |||
| static LayoutTypeID const kId = LayoutTypeID::kTensorNC16HW16; | |||
| }; | |||
| template <> | |||
| struct LayoutMap<cutlass::layout::TensorNCxHWx<32>> { | |||
| static LayoutTypeID const kId = LayoutTypeID::kTensorNC32HW32; | |||
| }; | |||
| template <> | |||
| struct LayoutMap<cutlass::layout::TensorNCxHWx<64>> { | |||
| static LayoutTypeID const kId = LayoutTypeID::kTensorNC64HW64; | |||
| }; | |||
| template <> | |||
| struct LayoutMap<cutlass::layout::TensorCxRSKx<4>> { | |||
| static LayoutTypeID const kId = LayoutTypeID::kTensorC4RSK4; | |||
| }; | |||
| template <> | |||
| struct LayoutMap<cutlass::layout::TensorCxRSKx<8>> { | |||
| static LayoutTypeID const kId = LayoutTypeID::kTensorC8RSK8; | |||
| }; | |||
| template <> | |||
| struct LayoutMap<cutlass::layout::TensorCxRSKx<16>> { | |||
| static LayoutTypeID const kId = LayoutTypeID::kTensorC16RSK16; | |||
| }; | |||
| template <> | |||
| struct LayoutMap<cutlass::layout::TensorCxRSKx<32>> { | |||
| static LayoutTypeID const kId = LayoutTypeID::kTensorC32RSK32; | |||
| }; | |||
| template <> | |||
| struct LayoutMap<cutlass::layout::TensorCxRSKx<64>> { | |||
| static LayoutTypeID const kId = LayoutTypeID::kTensorC64RSK64; | |||
| }; | |||
| template <> | |||
| struct LayoutMap<cutlass::layout::TensorKxRSCx<4>> { | |||
| static LayoutTypeID const kId = LayoutTypeID::kTensorK4RSC4; | |||
| }; | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| template <typename T> | |||
| struct OpcodeClassMap; | |||
| template <> | |||
| struct OpcodeClassMap<arch::OpClassSimt> { | |||
| static OpcodeClassID const kId = OpcodeClassID::kSimt; | |||
| }; | |||
| template <> | |||
| struct OpcodeClassMap<arch::OpClassTensorOp> { | |||
| static OpcodeClassID const kId = OpcodeClassID::kTensorOp; | |||
| }; | |||
| template <> | |||
| struct OpcodeClassMap<arch::OpClassWmmaTensorOp> { | |||
| static OpcodeClassID const kId = OpcodeClassID::kWmmaTensorOp; | |||
| }; | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| template <typename T> | |||
| struct ArchTagMap; | |||
| template <> | |||
| struct ArchTagMap<arch::Sm50> { | |||
| static ArchTagID const kId = ArchTagID::kSm50; | |||
| }; | |||
| template <> | |||
| struct ArchTagMap<arch::Sm60> { | |||
| static ArchTagID const kId = ArchTagID::kSm60; | |||
| }; | |||
| template <> | |||
| struct ArchTagMap<arch::Sm61> { | |||
| static ArchTagID const kId = ArchTagID::kSm61; | |||
| }; | |||
| template <> | |||
| struct ArchTagMap<arch::Sm70> { | |||
| static ArchTagID const kId = ArchTagID::kSm70; | |||
| }; | |||
| template <> | |||
| struct ArchTagMap<arch::Sm72> { | |||
| static ArchTagID const kId = ArchTagID::kSm72; | |||
| }; | |||
| template <> | |||
| struct ArchTagMap<arch::Sm75> { | |||
| static ArchTagID const kId = ArchTagID::kSm75; | |||
| }; | |||
| template <> | |||
| struct ArchTagMap<arch::Sm80> { | |||
| static ArchTagID const kId = ArchTagID::kSm80; | |||
| }; | |||
| template <> | |||
| struct ArchTagMap<arch::Sm86> { | |||
| static ArchTagID const kId = ArchTagID::kSm86; | |||
| }; | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| template <cutlass::ComplexTransform Transform> | |||
| struct ComplexTransformMap; | |||
| template <> | |||
| struct ComplexTransformMap<cutlass::ComplexTransform::kNone> { | |||
| static cutlass::library::ComplexTransform const kId = | |||
| cutlass::library::ComplexTransform::kNone; | |||
| }; | |||
| template <> | |||
| struct ComplexTransformMap<cutlass::ComplexTransform::kConjugate> { | |||
| static cutlass::library::ComplexTransform const kId = | |||
| cutlass::library::ComplexTransform::kConjugate; | |||
| }; | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| template <cutlass::conv::Mode T> | |||
| struct ConvModeMap; | |||
| template <> | |||
| struct ConvModeMap<conv::Mode::kCrossCorrelation> { | |||
| static ConvModeID const kId = ConvModeID::kCrossCorrelation; | |||
| }; | |||
| template <> | |||
| struct ConvModeMap<conv::Mode::kConvolution> { | |||
| static ConvModeID const kId = ConvModeID::kConvolution; | |||
| }; | |||
| template <cutlass::conv::Operator T> | |||
| struct ConvKindMap; | |||
| template <> | |||
| struct ConvKindMap<conv::Operator::kFprop> { | |||
| static ConvKind const kId = ConvKind::kFprop; | |||
| }; | |||
| template <> | |||
| struct ConvKindMap<conv::Operator::kDgrad> { | |||
| static ConvKind const kId = ConvKind::kDgrad; | |||
| }; | |||
| template <> | |||
| struct ConvKindMap<conv::Operator::kWgrad> { | |||
| static ConvKind const kId = ConvKind::kWgrad; | |||
| }; | |||
| template <cutlass::conv::IteratorAlgorithm T> | |||
| struct IteratorAlgorithmMap; | |||
| template <> | |||
| struct IteratorAlgorithmMap<conv::IteratorAlgorithm::kAnalytic> { | |||
| static IteratorAlgorithmID const kId = IteratorAlgorithmID::kAnalytic; | |||
| }; | |||
| template <> | |||
| struct IteratorAlgorithmMap<conv::IteratorAlgorithm::kOptimized> { | |||
| static IteratorAlgorithmID const kId = IteratorAlgorithmID::kOptimized; | |||
| }; | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| template <typename T> | |||
| struct ThreadblockSwizzleMap; | |||
| template <int N> | |||
| struct ThreadblockSwizzleMap< | |||
| gemm::threadblock::GemmIdentityThreadblockSwizzle<N>> { | |||
| static ThreadblockSwizzleID const kId = ThreadblockSwizzleID::kGemmIdentity; | |||
| }; | |||
| template <> | |||
| struct ThreadblockSwizzleMap< | |||
| gemm::threadblock::GemmHorizontalThreadblockSwizzle> { | |||
| static ThreadblockSwizzleID const kId = | |||
| ThreadblockSwizzleID::kGemmHorizontal; | |||
| }; | |||
| template <> | |||
| struct ThreadblockSwizzleMap< | |||
| gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle> { | |||
| static ThreadblockSwizzleID const kId = | |||
| ThreadblockSwizzleID::kGemmBatchedIdentity; | |||
| }; | |||
| template <int N> | |||
| struct ThreadblockSwizzleMap< | |||
| gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<N>> { | |||
| static ThreadblockSwizzleID const kId = | |||
| ThreadblockSwizzleID::kGemmSplitKIdentity; | |||
| }; | |||
| template <> | |||
| struct ThreadblockSwizzleMap< | |||
| gemm::threadblock::GemmSplitKHorizontalThreadblockSwizzle> { | |||
| static ThreadblockSwizzleID const kId = | |||
| ThreadblockSwizzleID::kGemmSplitKHorizontal; | |||
| }; | |||
| template <> | |||
| struct ThreadblockSwizzleMap< | |||
| gemm::threadblock::GemvBatchedStridedThreadblockDefaultSwizzle> { | |||
| static ThreadblockSwizzleID const kId = | |||
| ThreadblockSwizzleID::kGemvBatchedStridedDefault; | |||
| }; | |||
| template <> | |||
| struct ThreadblockSwizzleMap< | |||
| gemm::threadblock::GemvBatchedStridedThreadblockReductionSwizzle> { | |||
| static ThreadblockSwizzleID const kId = | |||
| ThreadblockSwizzleID::kGemvBatchedStridedReduction; | |||
| }; | |||
| template <> | |||
| struct ThreadblockSwizzleMap< | |||
| conv::threadblock::ConvolutionFpropCxRSKxThreadblockSwizzle> { | |||
| static ThreadblockSwizzleID const kId = | |||
| ThreadblockSwizzleID::kConvolutionFpropCxRSKx; | |||
| }; | |||
| template <> | |||
| struct ThreadblockSwizzleMap< | |||
| conv::threadblock::ConvolutionDgradCxRSKxThreadblockSwizzle> { | |||
| static ThreadblockSwizzleID const kId = | |||
| ThreadblockSwizzleID::kConvolutionDgradCxRSKx; | |||
| }; | |||
| template <> | |||
| struct ThreadblockSwizzleMap< | |||
| conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle> { | |||
| static ThreadblockSwizzleID const kId = | |||
| ThreadblockSwizzleID::kConvolutionFpropNCxHWx; | |||
| }; | |||
| template <> | |||
| struct ThreadblockSwizzleMap< | |||
| conv::threadblock::ConvolutionFpropTransThreadblockSwizzle> { | |||
| static ThreadblockSwizzleID const kId = | |||
| ThreadblockSwizzleID::kConvolutionFpropTrans; | |||
| }; | |||
| template <> | |||
| struct ThreadblockSwizzleMap< | |||
| conv::threadblock::ConvolutionDgradNCxHWxThreadblockSwizzle> { | |||
| static ThreadblockSwizzleID const kId = | |||
| ThreadblockSwizzleID::kConvolutionDgradNCxHWx; | |||
| }; | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| template <typename Element, typename Layout> | |||
| TensorDescription make_TensorDescription(int alignment = 1) { | |||
| TensorDescription desc; | |||
| desc.element = NumericTypeMap<Element>::kId; | |||
| desc.layout = LayoutMap<Layout>::kId; | |||
| desc.alignment = alignment; | |||
| desc.log_extent_range = | |||
| int(sizeof(typename Layout::TensorCoord::Index) - 1) * 8; | |||
| desc.log_stride_range = int(sizeof(typename Layout::Stride::Index) - 1) * 8; | |||
| return desc; | |||
| } | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| } // namespace library | |||
| } // namespace cutlass | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| @@ -0,0 +1,96 @@ | |||
| /*************************************************************************************************** | |||
| * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. | |||
| * | |||
| * Redistribution and use in source and binary forms, with or without | |||
| *modification, are permitted provided that the following conditions are met: | |||
| * * Redistributions of source code must retain the above copyright notice, | |||
| *this list of conditions and the following disclaimer. | |||
| * * Redistributions in binary form must reproduce the above copyright | |||
| *notice, this list of conditions and the following disclaimer in the | |||
| *documentation and/or other materials provided with the distribution. | |||
| * * Neither the name of the NVIDIA CORPORATION nor the names of its | |||
| *contributors may be used to endorse or promote products derived from this | |||
| *software without specific prior written permission. | |||
| * | |||
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
| *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
| *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |||
| *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, | |||
| *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, | |||
| * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | |||
| *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY | |||
| *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TOR (INCLUDING | |||
| *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, | |||
| *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
| * | |||
| **************************************************************************************************/ | |||
| /** | |||
| * \file dnn/src/cuda/cutlass/manifest.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include <memory> | |||
| #include "src/cuda/cutlass/manifest.h" | |||
| namespace cutlass { | |||
| namespace library { | |||
| ////////////////////////////////////////////////////////////////////////////////////////////////////////// | |||
| /// Top-level initialization | |||
| Status Manifest::initialize() { | |||
| if (!operations_.empty()) { | |||
| operations_.clear(); | |||
| } | |||
| // initialize procedurally generated cutlass op in manifest object | |||
| initialize_all(*this); | |||
| return Status::kSuccess; | |||
| } | |||
| /// Used for initialization | |||
| void Manifest::reserve(size_t operation_count) { | |||
| operations_.reserve(operation_count); | |||
| } | |||
| /// Graceful shutdown | |||
| Status Manifest::release() { | |||
| operations_.clear(); | |||
| return Status::kSuccess; | |||
| } | |||
| /// Appends an operation and takes ownership | |||
| void Manifest::append(Operation* operation_ptr) { | |||
| operations_.emplace_back(operation_ptr); | |||
| } | |||
| /// Returns an iterator to the first operation | |||
| OperationVector const& Manifest::operations() const { | |||
| return operations_; | |||
| } | |||
| /// Returns a const iterator | |||
| OperationVector::const_iterator Manifest::begin() const { | |||
| return operations_.begin(); | |||
| } | |||
| /// Returns a const iterator | |||
| OperationVector::const_iterator Manifest::end() const { | |||
| return operations_.end(); | |||
| } | |||
| /////////////////////////////////////////////////////////////////////////////////////////////////// | |||
| } // namespace library | |||
| } // namespace cutlass | |||
| /////////////////////////////////////////////////////////////////////////////////////////////////// | |||
| @@ -0,0 +1,108 @@ | |||
| /*************************************************************************************************** | |||
| * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. | |||
| * | |||
| * Redistribution and use in source and binary forms, with or without | |||
| *modification, are permitted provided that the following conditions are met: | |||
| * * Redistributions of source code must retain the above copyright notice, | |||
| *this list of conditions and the following disclaimer. | |||
| * * Redistributions in binary form must reproduce the above copyright | |||
| *notice, this list of conditions and the following disclaimer in the | |||
| *documentation and/or other materials provided with the distribution. | |||
| * * Neither the name of the NVIDIA CORPORATION nor the names of its | |||
| *contributors may be used to endorse or promote products derived from this | |||
| *software without specific prior written permission. | |||
| * | |||
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
| *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
| *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |||
| *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, | |||
| *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, | |||
| * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | |||
| *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY | |||
| *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TOR (INCLUDING | |||
| *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, | |||
| *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
| * | |||
| **************************************************************************************************/ | |||
| /** | |||
| * \file dnn/src/cuda/cutlass/manifest.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include <list> | |||
| #include <map> | |||
| #include <memory> | |||
| #include "src/cuda/cutlass/library.h" | |||
| /////////////////////////////////////////////////////////////////////////////////////////////////// | |||
| namespace cutlass { | |||
| namespace library { | |||
| /////////////////////////////////////////////////////////////////////////////////////////////////// | |||
| // Forward declaration | |||
| class Manifest; | |||
| // init and insert all cutlass gemm operations in manifest object (procedurally | |||
| // generated using generator.py) | |||
| void initialize_all(Manifest& manifest); | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////////////// | |||
| /// List of operations | |||
| using OperationVector = std::vector<std::unique_ptr<Operation>>; | |||
| /////////////////////////////////////////////////////////////////////////////////////////////////// | |||
| /// Manifest of CUTLASS Library | |||
| class Manifest { | |||
| private: | |||
| /// Operation provider | |||
| Provider provider_; | |||
| /// Global list of operations | |||
| OperationVector operations_; | |||
| public: | |||
| Manifest(Provider provider = library::Provider::kCUTLASS) | |||
| : provider_(provider) {} | |||
| /// Top-level initialization | |||
| Status initialize(); | |||
| /// Used for initialization | |||
| void reserve(size_t operation_count); | |||
| /// Graceful shutdown | |||
| Status release(); | |||
| /// Appends an operation and takes ownership | |||
| void append(Operation* operation_ptr); | |||
| /// Returns an iterator to the first operation | |||
| OperationVector const& operations() const; | |||
| /// Returns a const iterator | |||
| OperationVector::const_iterator begin() const; | |||
| /// Returns a const iterator | |||
| OperationVector::const_iterator end() const; | |||
| }; | |||
| /////////////////////////////////////////////////////////////////////////////////////////////////// | |||
| } // namespace library | |||
| } // namespace cutlass | |||
| /////////////////////////////////////////////////////////////////////////////////////////////////// | |||
| @@ -0,0 +1,179 @@ | |||
| /*************************************************************************************************** | |||
| * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | |||
| * | |||
| * Redistribution and use in source and binary forms, with or without | |||
| *modification, are permitted provided that the following conditions are met: | |||
| * * Redistributions of source code must retain the above copyright notice, | |||
| *this list of conditions and the following disclaimer. | |||
| * * Redistributions in binary form must reproduce the above copyright | |||
| *notice, this list of conditions and the following disclaimer in the | |||
| *documentation and/or other materials provided with the distribution. | |||
| * * Neither the name of the NVIDIA CORPORATION nor the names of its | |||
| *contributors may be used to endorse or promote products derived from this | |||
| *software without specific prior written permission. | |||
| * | |||
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
| *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
| *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |||
| *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, | |||
| *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, | |||
| * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | |||
| *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY | |||
| *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TOR (INCLUDING | |||
| *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, | |||
| *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
| * | |||
| **************************************************************************************************/ | |||
| /** | |||
| * \file dnn/src/cuda/cutlass/operation_table.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/common/utils.h" | |||
| #include "src/cuda/cutlass/operation_table.h" | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| namespace cutlass { | |||
| namespace library { | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| GemmKey get_gemm_key_from_desc(const GemmDescription& desc) { | |||
| GemmKey key; | |||
| key.element_A = desc.A.element; | |||
| key.layout_A = desc.A.layout; | |||
| key.element_B = desc.B.element; | |||
| key.layout_B = desc.B.layout; | |||
| key.element_C = desc.C.element; | |||
| key.layout_C = desc.C.layout; | |||
| key.threadblock_shape_m = desc.tile_description.threadblock_shape.m(); | |||
| key.threadblock_shape_n = desc.tile_description.threadblock_shape.n(); | |||
| key.threadblock_shape_k = desc.tile_description.threadblock_shape.k(); | |||
| key.warp_shape_m = desc.tile_description.threadblock_shape.m() / | |||
| desc.tile_description.warp_count.m(); | |||
| key.warp_shape_n = desc.tile_description.threadblock_shape.n() / | |||
| desc.tile_description.warp_count.n(); | |||
| key.warp_shape_k = desc.tile_description.threadblock_shape.k() / | |||
| desc.tile_description.warp_count.k(); | |||
| key.instruction_shape_m = | |||
| desc.tile_description.math_instruction.instruction_shape.m(); | |||
| key.instruction_shape_n = | |||
| desc.tile_description.math_instruction.instruction_shape.n(); | |||
| key.instruction_shape_k = | |||
| desc.tile_description.math_instruction.instruction_shape.k(); | |||
| key.stages = desc.stages; | |||
| key.split_k_mode = desc.split_k_mode; | |||
| return key; | |||
| } | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| ConvolutionKey get_convolution_key_from_desc( | |||
| const ConvolutionDescription& desc) { | |||
| ConvolutionKey key; | |||
| key.conv_op = desc.conv_op; | |||
| key.element_src = desc.src.element; | |||
| key.layout_src = desc.src.layout; | |||
| key.element_filter = desc.filter.element; | |||
| key.layout_filter = desc.filter.layout; | |||
| key.element_dst = desc.dst.element; | |||
| key.layout_dst = desc.dst.layout; | |||
| key.element_bias = desc.bias.element; | |||
| key.layout_bias = desc.bias.layout; | |||
| key.convolution_type = desc.convolution_type; | |||
| key.threadblock_shape_m = desc.tile_description.threadblock_shape.m(); | |||
| key.threadblock_shape_n = desc.tile_description.threadblock_shape.n(); | |||
| key.threadblock_shape_k = desc.tile_description.threadblock_shape.k(); | |||
| key.warp_shape_m = desc.tile_description.threadblock_shape.m() / | |||
| desc.tile_description.warp_count.m(); | |||
| key.warp_shape_n = desc.tile_description.threadblock_shape.n() / | |||
| desc.tile_description.warp_count.n(); | |||
| key.warp_shape_k = desc.tile_description.threadblock_shape.k() / | |||
| desc.tile_description.warp_count.k(); | |||
| key.instruction_shape_m = | |||
| desc.tile_description.math_instruction.instruction_shape.m(); | |||
| key.instruction_shape_n = | |||
| desc.tile_description.math_instruction.instruction_shape.n(); | |||
| key.instruction_shape_k = | |||
| desc.tile_description.math_instruction.instruction_shape.k(); | |||
| key.epilogue_type = desc.epilogue_type; | |||
| key.stages = desc.tile_description.threadblock_stages; | |||
| key.need_load_from_const_mem = desc.need_load_from_const_mem; | |||
| key.without_shared_load = desc.without_shared_load; | |||
| return key; | |||
| } | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| void OperationTable::append(Manifest const& manifest) { | |||
| // Insert operations into appropriate data structure | |||
| for (auto const& operation : manifest) { | |||
| OperationDescription const& desc = operation->description(); | |||
| // insert all gemm operations into operation table | |||
| if (desc.kind == OperationKind::kGemm) { | |||
| GemmKey key = get_gemm_key_from_desc( | |||
| static_cast<GemmDescription const&>(desc)); | |||
| gemm_operations[key].push_back(operation.get()); | |||
| } | |||
| // insert all conv operations into operation table | |||
| if (desc.kind == OperationKind::kConvolution) { | |||
| ConvolutionKey key = get_convolution_key_from_desc( | |||
| static_cast<ConvolutionDescription const&>(desc)); | |||
| convolution_operations[key].push_back(operation.get()); | |||
| } | |||
| } | |||
| } | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| Operation const* OperationTable::find_op(GemmKey const& key) const { | |||
| megdnn_assert(gemm_operations.count(key) > 0, | |||
| "key not found in cutlass operation table"); | |||
| auto const& ops = gemm_operations.at(key); | |||
| megdnn_assert(ops.size() == 1, "exactly one kernel expected, got %zu", | |||
| ops.size()); | |||
| return ops[0]; | |||
| } | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| Operation const* OperationTable::find_op(ConvolutionKey const& key) const { | |||
| megdnn_assert(convolution_operations.count(key) > 0, | |||
| "key not found in cutlass operation table"); | |||
| auto const& ops = convolution_operations.at(key); | |||
| megdnn_assert(ops.size() == 1, "exactly one kernel expected, got %zu", | |||
| ops.size()); | |||
| return ops[0]; | |||
| } | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| } // namespace library | |||
| } // namespace cutlass | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| @@ -0,0 +1,334 @@ | |||
| /*************************************************************************************************** | |||
| * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | |||
| * | |||
| * Redistribution and use in source and binary forms, with or without | |||
| *modification, are permitted provided that the following conditions are met: | |||
| * * Redistributions of source code must retain the above copyright notice, | |||
| *this list of conditions and the following disclaimer. | |||
| * * Redistributions in binary form must reproduce the above copyright | |||
| *notice, this list of conditions and the following disclaimer in the | |||
| *documentation and/or other materials provided with the distribution. | |||
| * * Neither the name of the NVIDIA CORPORATION nor the names of its | |||
| *contributors may be used to endorse or promote products derived from this | |||
| *software without specific prior written permission. | |||
| * | |||
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
| *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
| *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |||
| *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, | |||
| *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, | |||
| * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | |||
| *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY | |||
| *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TOR (INCLUDING | |||
| *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, | |||
| *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
| * | |||
| **************************************************************************************************/ | |||
| /** | |||
| * \file dnn/src/cuda/cutlass/operation_table.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include <unordered_map> | |||
| #include "src/common/hash_ct.h" | |||
| #include "src/cuda/cutlass/manifest.h" | |||
| #include "src/cuda/cutlass/util.h" | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| namespace cutlass { | |||
| namespace library { | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| class Hash { | |||
| public: | |||
| Hash() : m_val(0) {} | |||
| Hash& update(const void* ptr, size_t len) { | |||
| m_val += megdnn::XXHash64CT::hash((const char*)ptr, len, 123456); | |||
| return *this; | |||
| } | |||
| uint64_t digest() const { return m_val; } | |||
| private: | |||
| uint64_t m_val; | |||
| }; | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| // Data Structures for GemmOperationMap | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| struct GemmKey { | |||
| NumericTypeID element_A; | |||
| LayoutTypeID layout_A; | |||
| NumericTypeID element_B; | |||
| LayoutTypeID layout_B; | |||
| NumericTypeID element_C; | |||
| LayoutTypeID layout_C; | |||
| int threadblock_shape_m; | |||
| int threadblock_shape_n; | |||
| int threadblock_shape_k; | |||
| int warp_shape_m; | |||
| int warp_shape_n; | |||
| int warp_shape_k; | |||
| int instruction_shape_m; | |||
| int instruction_shape_n; | |||
| int instruction_shape_k; | |||
| int stages; | |||
| SplitKMode split_k_mode; | |||
| inline bool operator==(GemmKey const& rhs) const { | |||
| return (element_A == rhs.element_A) && (layout_A == rhs.layout_A) && | |||
| (element_B == rhs.element_B) && (layout_B == rhs.layout_B) && | |||
| (element_C == rhs.element_C) && (layout_C == rhs.layout_C) && | |||
| (threadblock_shape_m == rhs.threadblock_shape_m) && | |||
| (threadblock_shape_n == rhs.threadblock_shape_n) && | |||
| (threadblock_shape_k == rhs.threadblock_shape_k) && | |||
| (warp_shape_m == rhs.warp_shape_m) && | |||
| (warp_shape_n == rhs.warp_shape_n) && | |||
| (warp_shape_k == rhs.warp_shape_k) && | |||
| (instruction_shape_m == rhs.instruction_shape_m) && | |||
| (instruction_shape_n == rhs.instruction_shape_n) && | |||
| (instruction_shape_k == rhs.instruction_shape_k) && | |||
| (stages == rhs.stages) && (split_k_mode == rhs.split_k_mode); | |||
| } | |||
| inline bool operator!=(GemmKey const& rhs) const { return !(*this == rhs); } | |||
| inline std::string str() const { | |||
| auto tuple_to_str = [](int m, int n, int k) -> std::string { | |||
| return std::to_string(m) + " x " + std::to_string(n) + " x " + | |||
| std::to_string(k); | |||
| }; | |||
| std::string threadblock_shape_str = tuple_to_str( | |||
| threadblock_shape_m, threadblock_shape_n, threadblock_shape_k); | |||
| std::string warp_shape_str = | |||
| tuple_to_str(warp_shape_m, warp_shape_n, warp_shape_k); | |||
| std::string instruction_shape_str = tuple_to_str( | |||
| instruction_shape_m, instruction_shape_n, instruction_shape_k); | |||
| return std::string("{") + "\n element_A: " + to_string(element_A) + | |||
| "\n layout_A: " + to_string(layout_A) + | |||
| "\n element_B: " + to_string(element_B) + | |||
| "\n layout_B: " + to_string(layout_B) + | |||
| "\n element_C: " + to_string(element_C) + | |||
| "\n layout_C: " + to_string(layout_C) + | |||
| "\n threadblock_shape: " + threadblock_shape_str + | |||
| "\n warp_shape: " + warp_shape_str + | |||
| "\n instruction_shape: " + instruction_shape_str + | |||
| "\n stages: " + std::to_string(stages) + | |||
| "\n split_k_mode: " + to_string(split_k_mode) + "\n}"; | |||
| } | |||
| }; | |||
| struct GemmKeyHasher { | |||
| inline size_t operator()(GemmKey const& key) const { | |||
| return Hash() | |||
| .update(&key.element_A, sizeof(key.element_A)) | |||
| .update(&key.layout_A, sizeof(key.layout_A)) | |||
| .update(&key.element_B, sizeof(key.element_B)) | |||
| .update(&key.layout_B, sizeof(key.layout_B)) | |||
| .update(&key.element_C, sizeof(key.element_C)) | |||
| .update(&key.layout_C, sizeof(key.layout_C)) | |||
| .update(&key.threadblock_shape_m, | |||
| sizeof(key.threadblock_shape_m)) | |||
| .update(&key.threadblock_shape_n, | |||
| sizeof(key.threadblock_shape_n)) | |||
| .update(&key.threadblock_shape_k, | |||
| sizeof(key.threadblock_shape_k)) | |||
| .update(&key.warp_shape_m, sizeof(key.warp_shape_m)) | |||
| .update(&key.warp_shape_n, sizeof(key.warp_shape_n)) | |||
| .update(&key.warp_shape_k, sizeof(key.warp_shape_k)) | |||
| .update(&key.stages, sizeof(key.stages)) | |||
| .update(&key.split_k_mode, sizeof(key.split_k_mode)) | |||
| .digest(); | |||
| } | |||
| }; | |||
| using GemmOperationMap = | |||
| std::unordered_map<GemmKey, std::vector<Operation const*>, | |||
| GemmKeyHasher>; | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| // Data Structures for ConvolutionOperationMap | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| struct ConvolutionKey { | |||
| conv::Operator conv_op; | |||
| library::NumericTypeID element_src; | |||
| library::LayoutTypeID layout_src; | |||
| library::NumericTypeID element_filter; | |||
| library::LayoutTypeID layout_filter; | |||
| library::NumericTypeID element_dst; | |||
| library::LayoutTypeID layout_dst; | |||
| library::NumericTypeID element_bias; | |||
| library::LayoutTypeID layout_bias; | |||
| conv::ConvType convolution_type; | |||
| int threadblock_shape_m; | |||
| int threadblock_shape_n; | |||
| int threadblock_shape_k; | |||
| int warp_shape_m; | |||
| int warp_shape_n; | |||
| int warp_shape_k; | |||
| int instruction_shape_m; | |||
| int instruction_shape_n; | |||
| int instruction_shape_k; | |||
| epilogue::EpilogueType epilogue_type; | |||
| int stages; | |||
| bool need_load_from_const_mem; | |||
| bool without_shared_load; | |||
| inline bool operator==(ConvolutionKey const& rhs) const { | |||
| return (conv_op == rhs.conv_op) && (element_src == rhs.element_src) && | |||
| (layout_src == rhs.layout_src) && | |||
| (element_filter == rhs.element_filter) && | |||
| (layout_filter == rhs.layout_filter) && | |||
| (element_dst == rhs.element_dst) && | |||
| (layout_dst == rhs.layout_dst) && | |||
| (element_bias == rhs.element_bias) && | |||
| (layout_bias == rhs.layout_bias) && | |||
| (convolution_type == rhs.convolution_type) && | |||
| (threadblock_shape_m == rhs.threadblock_shape_m) && | |||
| (threadblock_shape_n == rhs.threadblock_shape_n) && | |||
| (threadblock_shape_k == rhs.threadblock_shape_k) && | |||
| (warp_shape_m == rhs.warp_shape_m) && | |||
| (warp_shape_n == rhs.warp_shape_n) && | |||
| (warp_shape_k == rhs.warp_shape_k) && | |||
| (instruction_shape_m == rhs.instruction_shape_m) && | |||
| (instruction_shape_n == rhs.instruction_shape_n) && | |||
| (instruction_shape_k == rhs.instruction_shape_k) && | |||
| (epilogue_type == rhs.epilogue_type) && (stages == rhs.stages) && | |||
| (need_load_from_const_mem == rhs.need_load_from_const_mem) && | |||
| (without_shared_load == rhs.without_shared_load); | |||
| } | |||
| inline bool operator!=(ConvolutionKey const& rhs) const { | |||
| return !(*this == rhs); | |||
| } | |||
| inline std::string str() const { | |||
| auto tuple_to_str = [](int m, int n, int k) -> std::string { | |||
| return std::to_string(m) + " x " + std::to_string(n) + " x " + | |||
| std::to_string(k); | |||
| }; | |||
| std::string threadblock_shape_str = tuple_to_str( | |||
| threadblock_shape_m, threadblock_shape_n, threadblock_shape_k); | |||
| std::string warp_shape_str = | |||
| tuple_to_str(warp_shape_m, warp_shape_n, warp_shape_k); | |||
| std::string instruction_shape_str = tuple_to_str( | |||
| instruction_shape_m, instruction_shape_n, instruction_shape_k); | |||
| return std::string("{") + "\n conv_op: " + to_string(conv_op) + | |||
| "\n element_src: " + to_string(element_src) + | |||
| "\n layout_src: " + to_string(layout_src) + | |||
| "\n element_filter: " + to_string(element_filter) + | |||
| "\n layout_filter: " + to_string(layout_filter) + | |||
| "\n element_dst: " + to_string(element_dst) + | |||
| "\n layout_dst: " + to_string(layout_dst) + | |||
| "\n element_bias: " + to_string(element_bias) + | |||
| "\n layout_bias: " + to_string(layout_bias) + | |||
| "\n convolution_type: " + to_string(convolution_type) + | |||
| "\n threadblock_shape: " + threadblock_shape_str + | |||
| "\n warp_shape: " + warp_shape_str + | |||
| "\n instruction_shape: " + instruction_shape_str + | |||
| "\n epilogue_type: " + to_string(epilogue_type) + | |||
| "\n stages: " + std::to_string(stages) + | |||
| "\n need_load_from_const_mem: " + | |||
| to_string(need_load_from_const_mem) + | |||
| "\n without_shared_load: " + to_string(without_shared_load) + | |||
| "\n}"; | |||
| } | |||
| }; | |||
| struct ConvolutionKeyHasher { | |||
| inline size_t operator()(ConvolutionKey const& key) const { | |||
| return Hash() | |||
| .update(&key.conv_op, sizeof(key.conv_op)) | |||
| .update(&key.conv_op, sizeof(key.conv_op)) | |||
| .update(&key.element_src, sizeof(key.element_src)) | |||
| .update(&key.layout_src, sizeof(key.layout_src)) | |||
| .update(&key.element_filter, sizeof(key.element_filter)) | |||
| .update(&key.layout_filter, sizeof(key.layout_filter)) | |||
| .update(&key.element_dst, sizeof(key.element_dst)) | |||
| .update(&key.layout_dst, sizeof(key.layout_dst)) | |||
| .update(&key.element_bias, sizeof(key.element_bias)) | |||
| .update(&key.layout_bias, sizeof(key.layout_bias)) | |||
| .update(&key.convolution_type, sizeof(key.convolution_type)) | |||
| .update(&key.threadblock_shape_m, | |||
| sizeof(key.threadblock_shape_m)) | |||
| .update(&key.threadblock_shape_n, | |||
| sizeof(key.threadblock_shape_n)) | |||
| .update(&key.threadblock_shape_k, | |||
| sizeof(key.threadblock_shape_k)) | |||
| .update(&key.warp_shape_m, sizeof(key.warp_shape_m)) | |||
| .update(&key.warp_shape_n, sizeof(key.warp_shape_n)) | |||
| .update(&key.warp_shape_k, sizeof(key.warp_shape_k)) | |||
| .update(&key.instruction_shape_m, | |||
| sizeof(key.instruction_shape_m)) | |||
| .update(&key.instruction_shape_n, | |||
| sizeof(key.instruction_shape_n)) | |||
| .update(&key.instruction_shape_k, | |||
| sizeof(key.instruction_shape_k)) | |||
| .update(&key.epilogue_type, sizeof(key.epilogue_type)) | |||
| .update(&key.stages, sizeof(key.stages)) | |||
| .update(&key.need_load_from_const_mem, | |||
| sizeof(key.need_load_from_const_mem)) | |||
| .update(&key.without_shared_load, | |||
| sizeof(key.without_shared_load)) | |||
| .digest(); | |||
| } | |||
| }; | |||
| using ConvolutionOperationMap = | |||
| std::unordered_map<ConvolutionKey, std::vector<Operation const*>, | |||
| ConvolutionKeyHasher>; | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| /// Table of cutlass::library::Operation instances | |||
| class OperationTable { | |||
| public: | |||
| /// Map of all operations of type kGemm | |||
| GemmOperationMap gemm_operations; | |||
| /// Map of all operations of type kConvolution | |||
| ConvolutionOperationMap convolution_operations; | |||
| public: | |||
| void append(Manifest const& manifest); | |||
| Operation const* find_op(GemmKey const& key) const; | |||
| Operation const* find_op(ConvolutionKey const& key) const; | |||
| }; | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| } // namespace library | |||
| } // namespace cutlass | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| @@ -0,0 +1,72 @@ | |||
| /*************************************************************************************************** | |||
| * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | |||
| * | |||
| * Redistribution and use in source and binary forms, with or without | |||
| *modification, are permitted provided that the following conditions are met: | |||
| * * Redistributions of source code must retain the above copyright notice, | |||
| *this list of conditions and the following disclaimer. | |||
| * * Redistributions in binary form must reproduce the above copyright | |||
| *notice, this list of conditions and the following disclaimer in the | |||
| *documentation and/or other materials provided with the distribution. | |||
| * * Neither the name of the NVIDIA CORPORATION nor the names of its | |||
| *contributors may be used to endorse or promote products derived from this | |||
| *software without specific prior written permission. | |||
| * | |||
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
| *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
| *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |||
| *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, | |||
| *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, | |||
| * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | |||
| *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY | |||
| *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TOR (INCLUDING | |||
| *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, | |||
| *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
| * | |||
| **************************************************************************************************/ | |||
| /** | |||
| * \file dnn/src/cuda/cutlass/singleton.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include <memory> | |||
| #include "src/cuda/cutlass/singleton.h" | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| namespace cutlass { | |||
| namespace library { | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| static std::unique_ptr<Singleton> instance; | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| Singleton::Singleton() { | |||
| manifest.initialize(); | |||
| operation_table.append(manifest); | |||
| } | |||
| Singleton const& Singleton::get() { | |||
| if (!instance.get()) { | |||
| instance.reset(new Singleton); | |||
| } | |||
| return *instance.get(); | |||
| } | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| } // namespace library | |||
| } // namespace cutlass | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| @@ -0,0 +1,70 @@ | |||
| /*************************************************************************************************** | |||
| * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | |||
| * | |||
| * Redistribution and use in source and binary forms, with or without | |||
| *modification, are permitted provided that the following conditions are met: | |||
| * * Redistributions of source code must retain the above copyright notice, | |||
| *this list of conditions and the following disclaimer. | |||
| * * Redistributions in binary form must reproduce the above copyright | |||
| *notice, this list of conditions and the following disclaimer in the | |||
| *documentation and/or other materials provided with the distribution. | |||
| * * Neither the name of the NVIDIA CORPORATION nor the names of its | |||
| *contributors may be used to endorse or promote products derived from this | |||
| *software without specific prior written permission. | |||
| * | |||
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
| *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
| *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |||
| *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, | |||
| *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, | |||
| * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | |||
| *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY | |||
| *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TOR (INCLUDING | |||
| *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, | |||
| *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
| * | |||
| **************************************************************************************************/ | |||
| /** | |||
| * \file dnn/src/cuda/cutlass/singleton.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include "src/cuda/cutlass/operation_table.h" | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| namespace cutlass { | |||
| namespace library { | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| /// Singleton instance stores a Manifest and Operation table | |||
| class Singleton { | |||
| public: | |||
| /// Manifest object | |||
| Manifest manifest; | |||
| /// Operation table referencing the Manifest | |||
| OperationTable operation_table; | |||
| public: | |||
| Singleton(); | |||
| static Singleton const& get(); | |||
| }; | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| } // namespace library | |||
| } // namespace cutlass | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| @@ -0,0 +1,218 @@ | |||
| /*************************************************************************************************** | |||
| * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | |||
| * | |||
| * Redistribution and use in source and binary forms, with or without | |||
| *modification, are permitted provided that the following conditions are met: | |||
| * * Redistributions of source code must retain the above copyright notice, | |||
| *this list of conditions and the following disclaimer. | |||
| * * Redistributions in binary form must reproduce the above copyright | |||
| *notice, this list of conditions and the following disclaimer in the | |||
| *documentation and/or other materials provided with the distribution. | |||
| * * Neither the name of the NVIDIA CORPORATION nor the names of its | |||
| *contributors may be used to endorse or promote products derived from this | |||
| *software without specific prior written permission. | |||
| * | |||
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
| *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
| *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |||
| *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, | |||
| *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, | |||
| * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | |||
| *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY | |||
| *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TOR (INCLUDING | |||
| *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, | |||
| *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
| * | |||
| **************************************************************************************************/ | |||
| /** | |||
| * \file dnn/src/cuda/cutlass/util.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include "src/cuda/cutlass/library.h" | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| namespace cutlass { | |||
| namespace library { | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| /// Lexical cast from string | |||
| template <typename T> | |||
| T from_string(std::string const&); | |||
| /// Converts a Provider enumerant to a string | |||
| char const* to_string(Provider provider, bool pretty = false); | |||
| /// Parses a Provider enumerant from a string | |||
| template <> | |||
| Provider from_string<Provider>(std::string const& str); | |||
| /// Converts a GemmKind enumerant to a string | |||
| char const* to_string(GemmKind type, bool pretty = false); | |||
| /// Converts a NumericType enumerant to a string | |||
| char const* to_string(OperationKind type, bool pretty = false); | |||
| /// Parses a NumericType enumerant from a string | |||
| template <> | |||
| OperationKind from_string<OperationKind>(std::string const& str); | |||
| /// Converts a NumericType enumerant to a string | |||
| char const* to_string(NumericTypeID type, bool pretty = false); | |||
| /// Parses a NumericType enumerant from a string | |||
| template <> | |||
| NumericTypeID from_string<NumericTypeID>(std::string const& str); | |||
| /// Returns the size of a data type in bits | |||
| int sizeof_bits(NumericTypeID type); | |||
| /// Returns true if the numeric type is a complex data type or false if | |||
| /// real-valued. | |||
| bool is_complex_type(NumericTypeID type); | |||
| /// Returns the real-valued type underlying a type (only different from 'type' | |||
| /// if complex) | |||
| NumericTypeID get_real_type(NumericTypeID type); | |||
| /// Returns true if numeric type is integer | |||
| bool is_integer_type(NumericTypeID type); | |||
| /// Returns true if numeric type is signed | |||
| bool is_signed_type(NumericTypeID type); | |||
| /// Returns true if numeric type is a signed integer | |||
| bool is_signed_integer(NumericTypeID type); | |||
| /// returns true if numeric type is an unsigned integer | |||
| bool is_unsigned_integer(NumericTypeID type); | |||
| /// Returns true if numeric type is floating-point type | |||
| bool is_float_type(NumericTypeID type); | |||
| /// To string method for cutlass::Status | |||
| char const* to_string(Status status, bool pretty = false); | |||
| /// Converts a LayoutTypeID enumerant to a string | |||
| char const* to_string(LayoutTypeID layout, bool pretty = false); | |||
| /// Parses a LayoutType enumerant from a string | |||
| template <> | |||
| LayoutTypeID from_string<LayoutTypeID>(std::string const& str); | |||
| /// Returns the rank of a layout's stride base on the LayoutTypeID | |||
| int get_layout_stride_rank(LayoutTypeID layout_id); | |||
| /// Converts a OpcodeClassID enumerant to a string | |||
| char const* to_string(OpcodeClassID type, bool pretty = false); | |||
| /// Converts a OpcodeClassID enumerant from a string | |||
| template <> | |||
| OpcodeClassID from_string<OpcodeClassID>(std::string const& str); | |||
| /// Converts a ComplexTransform enumerant to a string | |||
| char const* to_string(ComplexTransform type, bool pretty = false); | |||
| /// Converts a ComplexTransform enumerant from a string | |||
| template <> | |||
| ComplexTransform from_string<ComplexTransform>(std::string const& str); | |||
| /// Converts a SplitKMode enumerant to a string | |||
| char const* to_string(SplitKMode split_k_mode, bool pretty = false); | |||
| /// Converts a SplitKMode enumerant from a string | |||
| template <> | |||
| SplitKMode from_string<SplitKMode>(std::string const& str); | |||
| /// Converts a ConvModeID enumerant to a string | |||
| char const* to_string(ConvModeID type, bool pretty = false); | |||
| /// Converts a ConvModeID enumerant from a string | |||
| template <> | |||
| ConvModeID from_string<ConvModeID>(std::string const& str); | |||
| /// Converts a IteratorAlgorithmID enumerant to a string | |||
| char const* to_string(IteratorAlgorithmID type, bool pretty = false); | |||
| /// Converts a IteratorAlgorithmID enumerant from a string | |||
| template <> | |||
| IteratorAlgorithmID from_string<IteratorAlgorithmID>(std::string const& str); | |||
| /// Converts a ConvKind enumerant to a string | |||
| char const* to_string(ConvKind type, bool pretty = false); | |||
| /// Converts a ConvKind enumerant from a string | |||
| template <> | |||
| ConvKind from_string<ConvKind>(std::string const& str); | |||
| /// Lexical cast from int64_t to string | |||
| std::string lexical_cast(int64_t int_value); | |||
| /// Lexical cast a string to a byte array. Returns true if cast is successful or | |||
| /// false if invalid. | |||
| bool lexical_cast(std::vector<uint8_t>& bytes, NumericTypeID type, | |||
| std::string const& str); | |||
| /// Lexical cast TO a string FROM a byte array. Returns true if cast is | |||
| /// successful or false if invalid. | |||
| std::string lexical_cast(std::vector<uint8_t>& bytes, NumericTypeID type); | |||
| /// Casts from a signed int64 to the destination type. Returns true if | |||
| /// successful. | |||
| bool cast_from_int64(std::vector<uint8_t>& bytes, NumericTypeID type, | |||
| int64_t src); | |||
| /// Casts from an unsigned int64 to the destination type. Returns true if | |||
| /// successful. | |||
| bool cast_from_uint64(std::vector<uint8_t>& bytes, NumericTypeID type, | |||
| uint64_t src); | |||
| /// Casts from a real value represented as a double to the destination type. | |||
| /// Returns true if successful. | |||
| bool cast_from_double(std::vector<uint8_t>& bytes, NumericTypeID type, | |||
| double src); | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| /// Converts a conv::Operator enumerant to a string | |||
| char const* to_string(conv::Operator conv_op, bool pretty = false); | |||
| /// Converts a ConvType enumerant to a string | |||
| char const* to_string(conv::ConvType type, bool pretty = false); | |||
| /// Converts an ArchTagID enumerant to a string | |||
| char const* to_string(ArchTagID tag, bool pretty = false); | |||
| /// Converts an EpilogueType enumerant to a string | |||
| char const* to_string(epilogue::EpilogueType type, bool pretty = false); | |||
| /// Converts a ThreadblockSwizzleID enumerant to a string | |||
| char const* to_string(ThreadblockSwizzleID threadblock_swizzle, | |||
| bool pretty = false); | |||
| /// Converts a bool value to a string | |||
| char const* to_string(bool val, bool pretty = false); | |||
| /// Converts a MathOperationID enumerant to a string | |||
| char const* to_string(MathOperationID math_op, bool pretty = false); | |||
| /// Converts an ImplicitGemmMode enumerant to a string | |||
| char const* to_string(conv::ImplicitGemmMode mode, bool pretty = false); | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| } // namespace library | |||
| } // namespace cutlass | |||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||
| @@ -10,15 +10,14 @@ | |||
| * implied. | |||
| */ | |||
| #include "src/cuda/cutlass/singleton.h" | |||
| #include "src/cuda/handle.h" | |||
| #include "src/cuda/matrix_mul/algos.h" | |||
| #include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh" | |||
| #include "src/cuda/utils.h" | |||
| #if CUDA_VERSION >= 9020 | |||
| using namespace megdnn; | |||
| using namespace cuda; | |||
| using namespace cutlass_wrapper; | |||
| bool MatrixMulForwardImpl::AlgoFloat32SIMT::is_available( | |||
| const SizeArgs& args) const { | |||
| @@ -44,25 +43,62 @@ size_t MatrixMulForwardImpl::AlgoFloat32SIMT::get_workspace_in_bytes( | |||
| } | |||
| void MatrixMulForwardImpl::AlgoFloat32SIMT::exec(const ExecArgs& args) const { | |||
| size_t lda = args.tensor_a.layout.stride[0], | |||
| ldb = args.tensor_b.layout.stride[0], | |||
| ldc = args.tensor_c.layout.stride[0]; | |||
| int64_t lda = args.tensor_a.layout.stride[0], | |||
| ldb = args.tensor_b.layout.stride[0], | |||
| ldc = args.tensor_c.layout.stride[0]; | |||
| auto&& param = args.opr->param(); | |||
| int m = args.tensor_c.layout.shape[0], n = args.tensor_c.layout.shape[1], | |||
| k = args.tensor_a.layout.shape[param.transposeA ? 0 : 1]; | |||
| GemmCoord problem_size{m, n, k}; | |||
| cutlass::gemm::GemmCoord problem_size{m, n, k}; | |||
| auto&& stream = cuda_stream(args.opr->handle()); | |||
| int* workspace = reinterpret_cast<int*>(args.workspace.raw_ptr); | |||
| return cutlass_matrix_mul_float32_simt( | |||
| args.tensor_a.ptr<dt_float32>(), param.transposeA, lda, | |||
| args.tensor_b.ptr<dt_float32>(), param.transposeB, ldb, | |||
| args.tensor_c.ptr<dt_float32>(), ldc, workspace, problem_size, 1.f, | |||
| 0.f, | |||
| GemmCoord{m_algo_param.threadblock_m, m_algo_param.threadblock_n, | |||
| m_algo_param.threadblock_k}, | |||
| GemmCoord{m_algo_param.warp_m, m_algo_param.warp_n, | |||
| m_algo_param.warp_k}, | |||
| stream); | |||
| // \note these constants of cutlass epilogue will be passed to struct | |||
| // `GemmArguments` by pointer and interpreted as ElementCompute*, a | |||
| // different dtype here results in undefined epilogue behaviors | |||
| float alpha = 1.f, beta = 0.f; | |||
| using namespace cutlass::library; | |||
| auto layoutA = param.transposeA ? LayoutTypeID::kColumnMajor | |||
| : LayoutTypeID::kRowMajor; | |||
| auto layoutB = param.transposeB ? LayoutTypeID::kColumnMajor | |||
| : LayoutTypeID::kRowMajor; | |||
| GemmKey key{NumericTypeID::kF32, | |||
| layoutA, | |||
| NumericTypeID::kF32, | |||
| layoutB, | |||
| NumericTypeID::kF32, | |||
| LayoutTypeID::kRowMajor, | |||
| m_algo_param.threadblock_m, | |||
| m_algo_param.threadblock_n, | |||
| m_algo_param.threadblock_k, | |||
| m_algo_param.warp_m, | |||
| m_algo_param.warp_n, | |||
| m_algo_param.warp_k, | |||
| 1, | |||
| 1, | |||
| 1, | |||
| 2, | |||
| SplitKMode::kNone}; | |||
| const Operation* op = Singleton::get().operation_table.find_op(key); | |||
| GemmArguments gemm_args{problem_size, | |||
| args.tensor_a.raw_ptr, | |||
| args.tensor_b.raw_ptr, | |||
| args.tensor_c.raw_ptr, | |||
| args.tensor_c.raw_ptr, | |||
| lda, | |||
| ldb, | |||
| ldc, | |||
| ldc, | |||
| 1, | |||
| &alpha, | |||
| &beta}; | |||
| cutlass_check(op->run(&gemm_args, workspace, stream)); | |||
| } | |||
| #endif | |||
| @@ -10,15 +10,14 @@ | |||
| * implied. | |||
| */ | |||
| #include "src/cuda/cutlass/singleton.h" | |||
| #include "src/cuda/handle.h" | |||
| #include "src/cuda/matrix_mul/algos.h" | |||
| #include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh" | |||
| #include "src/cuda/utils.h" | |||
| #if CUDA_VERSION >= 9020 | |||
| using namespace megdnn; | |||
| using namespace cuda; | |||
| using namespace cutlass_wrapper; | |||
| bool MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::is_available( | |||
| const SizeArgs& args) const { | |||
| @@ -50,26 +49,63 @@ size_t MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::get_workspace_in_bytes( | |||
| void MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::exec( | |||
| const ExecArgs& args) const { | |||
| size_t lda = args.tensor_a.layout.stride[0], | |||
| ldb = args.tensor_b.layout.stride[0], | |||
| ldc = args.tensor_c.layout.stride[0]; | |||
| int64_t lda = args.tensor_a.layout.stride[0], | |||
| ldb = args.tensor_b.layout.stride[0], | |||
| ldc = args.tensor_c.layout.stride[0]; | |||
| auto&& param = args.opr->param(); | |||
| int m = args.tensor_c.layout.shape[0], n = args.tensor_c.layout.shape[1], | |||
| k = args.tensor_a.layout.shape[param.transposeA ? 0 : 1]; | |||
| GemmCoord problem_size{m, n, k}; | |||
| cutlass::gemm::GemmCoord problem_size{m, n, k}; | |||
| int split_k_slices = std::max(1, k / n); | |||
| auto&& stream = cuda_stream(args.opr->handle()); | |||
| int* workspace = reinterpret_cast<int*>(args.workspace.raw_ptr); | |||
| return cutlass_matrix_mul_float32_simt( | |||
| args.tensor_a.ptr<dt_float32>(), param.transposeA, lda, | |||
| args.tensor_b.ptr<dt_float32>(), param.transposeB, ldb, | |||
| args.tensor_c.ptr<dt_float32>(), ldc, workspace, problem_size, 1.f, | |||
| 0.f, | |||
| GemmCoord{m_algo_param.threadblock_m, m_algo_param.threadblock_n, | |||
| m_algo_param.threadblock_k}, | |||
| GemmCoord{m_algo_param.warp_m, m_algo_param.warp_n, | |||
| m_algo_param.warp_k}, | |||
| stream, split_k_slices); | |||
| // \note these constants of cutlass epilogue will be passed to struct | |||
| // `GemmArguments` by pointer and interpreted as ElementCompute*, a | |||
| // different dtype here results in undefined epilogue behaviors | |||
| float alpha = 1.f, beta = 0.f; | |||
| using namespace cutlass::library; | |||
| auto layoutA = param.transposeA ? LayoutTypeID::kColumnMajor | |||
| : LayoutTypeID::kRowMajor; | |||
| auto layoutB = param.transposeB ? LayoutTypeID::kColumnMajor | |||
| : LayoutTypeID::kRowMajor; | |||
| GemmKey key{NumericTypeID::kF32, | |||
| layoutA, | |||
| NumericTypeID::kF32, | |||
| layoutB, | |||
| NumericTypeID::kF32, | |||
| LayoutTypeID::kRowMajor, | |||
| m_algo_param.threadblock_m, | |||
| m_algo_param.threadblock_n, | |||
| m_algo_param.threadblock_k, | |||
| m_algo_param.warp_m, | |||
| m_algo_param.warp_n, | |||
| m_algo_param.warp_k, | |||
| 1, | |||
| 1, | |||
| 1, | |||
| 2, | |||
| SplitKMode::kParallel}; | |||
| Operation const* op = Singleton::get().operation_table.find_op(key); | |||
| GemmArguments gemm_args{problem_size, | |||
| args.tensor_a.raw_ptr, | |||
| args.tensor_b.raw_ptr, | |||
| args.tensor_c.raw_ptr, | |||
| args.tensor_c.raw_ptr, | |||
| lda, | |||
| ldb, | |||
| ldc, | |||
| ldc, | |||
| split_k_slices, | |||
| &alpha, | |||
| &beta}; | |||
| cutlass_check(op->run(&gemm_args, workspace, stream)); | |||
| } | |||
| #endif | |||
| @@ -1,157 +0,0 @@ | |||
| /** | |||
| * \file dnn/src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| // ignore warning of cutlass | |||
| #include "cuda.h" | |||
| #if __CUDACC_VER_MAJOR__ > 9 || \ | |||
| (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) | |||
| #pragma GCC diagnostic push | |||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||
| #include "cutlass/gemm/device/gemm.h" | |||
| #include "cutlass/gemm/device/gemm_splitk_parallel.h" | |||
| #include "cutlass/gemm/kernel/default_gemv.h" | |||
| #include "src/common/opr_param_defs_enumv.cuh" | |||
| #include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh" | |||
| #pragma GCC diagnostic pop | |||
| using namespace megdnn; | |||
| using namespace cuda; | |||
| using namespace cutlass_wrapper; | |||
| /* ================= cutlass kernel wrapper for f32 matrix mul ================ | |||
| */ | |||
| #define DISPATCH(cb) \ | |||
| cb(64, 256, 8, 32, 64, 8); \ | |||
| cb(256, 64, 8, 64, 32, 8); \ | |||
| cb(32, 256, 8, 16, 64, 8); \ | |||
| cb(256, 32, 8, 64, 16, 8); \ | |||
| cb(128, 128, 8, 32, 64, 8); \ | |||
| cb(128, 64, 8, 64, 32, 8); \ | |||
| cb(64, 128, 8, 32, 64, 8); \ | |||
| cb(128, 32, 8, 64, 32, 8); \ | |||
| cb(32, 128, 8, 32, 64, 8); \ | |||
| cb(64, 64, 8, 32, 64, 8); \ | |||
| cb(32, 64, 8, 32, 64, 8); \ | |||
| cb(64, 32, 8, 64, 32, 8); \ | |||
| cb(32, 32, 8, 32, 32, 8); \ | |||
| cb(8, 32, 8, 8, 32, 8); \ | |||
| cb(16, 32, 8, 16, 32, 8); \ | |||
| cb(16, 64, 8, 16, 64, 8); \ | |||
| cb(16, 128, 8, 16, 64, 8); \ | |||
| megdnn_assert(false, \ | |||
| "unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||
| "(%dx%dx%d)", \ | |||
| threadblock_shape.m(), threadblock_shape.n(), \ | |||
| threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||
| warp_shape.k()); | |||
| void megdnn::cuda::cutlass_wrapper::cutlass_matrix_mul_float32_simt( | |||
| const float* d_A, bool transpose_A, size_t lda, const float* d_B, | |||
| bool transpose_B, size_t ldb, float* d_C, size_t ldc, int* workspace, | |||
| GemmCoord const& problem_size, float alpha, float beta, | |||
| const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | |||
| cudaStream_t stream, int split_k_slices) { | |||
| static constexpr int kEpilogueElementsPerAccess = 1; | |||
| using EpilogueOp = cutlass::epilogue::thread::LinearCombination< | |||
| float, kEpilogueElementsPerAccess, float, float>; | |||
| typename EpilogueOp::Params epilogue{alpha, beta}; | |||
| if (split_k_slices == 1) { | |||
| #define cb(threadblock_m_, threadblock_n_, threadblock_k_, warp_m_, warp_n_, \ | |||
| warp_k_) \ | |||
| if (threadblock_shape.m() == threadblock_m_ && \ | |||
| threadblock_shape.n() == threadblock_n_ && \ | |||
| threadblock_shape.k() == threadblock_k_ && \ | |||
| warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||
| warp_shape.k() == warp_k_) { \ | |||
| using ThreadBlockShape = \ | |||
| cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||
| threadblock_k_>; \ | |||
| using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||
| using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; \ | |||
| using Gemm = cutlass::gemm::device::Gemm< \ | |||
| float, LayoutA, float, LayoutB, float, \ | |||
| cutlass::layout::RowMajor, float, cutlass::arch::OpClassSimt, \ | |||
| cutlass::arch::Sm50, ThreadBlockShape, WarpShape, \ | |||
| InstructionShape, EpilogueOp, \ | |||
| cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, \ | |||
| 2>; \ | |||
| return cutlass_matrix_mul_wrapper<Gemm>(d_A, lda, d_B, ldb, d_C, ldc, \ | |||
| workspace, problem_size, \ | |||
| epilogue, stream); \ | |||
| } | |||
| if (!transpose_A && !transpose_B) { | |||
| using LayoutA = cutlass::layout::RowMajor; | |||
| using LayoutB = cutlass::layout::RowMajor; | |||
| DISPATCH(cb) | |||
| } else if (!transpose_A && transpose_B) { | |||
| using LayoutA = cutlass::layout::RowMajor; | |||
| using LayoutB = cutlass::layout::ColumnMajor; | |||
| DISPATCH(cb) | |||
| } else if (transpose_A && !transpose_B) { | |||
| using LayoutA = cutlass::layout::ColumnMajor; | |||
| using LayoutB = cutlass::layout::RowMajor; | |||
| DISPATCH(cb) | |||
| } else { | |||
| megdnn_assert(transpose_A && transpose_B); | |||
| using LayoutA = cutlass::layout::ColumnMajor; | |||
| using LayoutB = cutlass::layout::ColumnMajor; | |||
| DISPATCH(cb) | |||
| } | |||
| #undef cb | |||
| } else { | |||
| #define cb(threadblock_m_, threadblock_n_, threadblock_k_, warp_m_, warp_n_, \ | |||
| warp_k_) \ | |||
| if (threadblock_shape.m() == threadblock_m_ && \ | |||
| threadblock_shape.n() == threadblock_n_ && \ | |||
| threadblock_shape.k() == threadblock_k_ && \ | |||
| warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||
| warp_shape.k() == warp_k_) { \ | |||
| using ThreadBlockShape = \ | |||
| cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||
| threadblock_k_>; \ | |||
| using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||
| using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; \ | |||
| using Gemm = cutlass::gemm::device::GemmSplitKParallel< \ | |||
| float, LayoutA, float, LayoutB, float, \ | |||
| cutlass::layout::RowMajor, float, cutlass::arch::OpClassSimt, \ | |||
| cutlass::arch::Sm50, ThreadBlockShape, WarpShape, \ | |||
| InstructionShape, EpilogueOp>; \ | |||
| return cutlass_matrix_mul_wrapper<Gemm>( \ | |||
| d_A, lda, d_B, ldb, d_C, ldc, workspace, problem_size, \ | |||
| epilogue, stream, split_k_slices); \ | |||
| } | |||
| if (!transpose_A && !transpose_B) { | |||
| using LayoutA = cutlass::layout::RowMajor; | |||
| using LayoutB = cutlass::layout::RowMajor; | |||
| DISPATCH(cb) | |||
| } else if (!transpose_A && transpose_B) { | |||
| using LayoutA = cutlass::layout::RowMajor; | |||
| using LayoutB = cutlass::layout::ColumnMajor; | |||
| DISPATCH(cb) | |||
| } else if (transpose_A && !transpose_B) { | |||
| using LayoutA = cutlass::layout::ColumnMajor; | |||
| using LayoutB = cutlass::layout::RowMajor; | |||
| DISPATCH(cb) | |||
| } else { | |||
| megdnn_assert(transpose_A && transpose_B); | |||
| using LayoutA = cutlass::layout::ColumnMajor; | |||
| using LayoutB = cutlass::layout::ColumnMajor; | |||
| DISPATCH(cb) | |||
| } | |||
| #undef cb | |||
| } | |||
| } | |||
| #undef DISPATCH | |||
| #endif | |||
| // vim: syntax=cuda.doxygen | |||
| @@ -21,22 +21,6 @@ namespace cutlass_wrapper { | |||
| using GemmCoord = cutlass::gemm::GemmCoord; | |||
| using BatchedGemmCoord = cutlass::gemm::BatchedGemmCoord; | |||
| template <typename Gemm> | |||
| void cutlass_matrix_mul_wrapper( | |||
| const typename Gemm::ElementA* d_A, size_t lda, | |||
| const typename Gemm::ElementB* d_B, size_t ldb, | |||
| typename Gemm::ElementC* d_C, size_t ldc, int* workspace, | |||
| GemmCoord const& problem_size, | |||
| typename Gemm::EpilogueOutputOp::Params const& epilogue, | |||
| cudaStream_t stream, int split_k_slices = 1); | |||
| void cutlass_matrix_mul_float32_simt( | |||
| const float* d_A, bool transpose_A, size_t lda, const float* d_B, | |||
| bool transpose_B, size_t ldb, float* d_C, size_t ldc, int* workspace, | |||
| GemmCoord const& problem_size, float alpha, float beta, | |||
| const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | |||
| cudaStream_t stream, int split_k_slices = 1); | |||
| template <typename GemvKernel> | |||
| void cutlass_vector_matrix_mul_batched_strided_wrapper( | |||
| BatchedGemmCoord const& problem_size, | |||
| @@ -1,57 +0,0 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/cuda/matrix_mul/matrix_mul_float_simt_cutlass_wrapper.cuinl | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "cutlass/gemm/device/gemm.h" | |||
| #include "cutlass/gemm/device/gemm_splitk_parallel.h" | |||
| #include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh" | |||
| using namespace megdnn; | |||
| using namespace cuda; | |||
| using namespace cutlass_wrapper; | |||
| template <typename Gemm> | |||
| void megdnn::cuda::cutlass_wrapper::cutlass_matrix_mul_wrapper( | |||
| const typename Gemm::ElementA* d_A, size_t lda, | |||
| const typename Gemm::ElementB* d_B, size_t ldb, | |||
| typename Gemm::ElementC* d_C, size_t ldc, int* workspace, | |||
| GemmCoord const& problem_size, | |||
| typename Gemm::EpilogueOutputOp::Params const& epilogue, | |||
| cudaStream_t stream, int split_k_slices) { | |||
| using TensorRefA = cutlass::TensorRef<typename Gemm::ElementA const, | |||
| typename Gemm::LayoutA>; | |||
| using TensorRefB = cutlass::TensorRef<typename Gemm::ElementB const, | |||
| typename Gemm::LayoutB>; | |||
| using TensorRefC = cutlass::TensorRef<typename Gemm::ElementC const, | |||
| typename Gemm::LayoutC>; | |||
| using TensorRefD = | |||
| cutlass::TensorRef<typename Gemm::ElementC, typename Gemm::LayoutC>; | |||
| TensorRefA tensor_a{const_cast<typename Gemm::ElementA*>(d_A), | |||
| typename Gemm::LayoutA{static_cast<int>(lda)}}; | |||
| TensorRefB tensor_b{const_cast<typename Gemm::ElementB*>(d_B), | |||
| typename Gemm::LayoutB{static_cast<int>(ldb)}}; | |||
| TensorRefC tensor_c{nullptr, typename Gemm::LayoutC{static_cast<int>(ldc)}}; | |||
| TensorRefD tensor_d{d_C, typename Gemm::LayoutC{static_cast<int>(ldc)}}; | |||
| typename Gemm::Arguments arguments{problem_size, | |||
| tensor_a, | |||
| tensor_b, | |||
| tensor_c, | |||
| tensor_d.non_const_ref(), | |||
| epilogue, | |||
| split_k_slices}; | |||
| Gemm gemm_op; | |||
| cutlass_check(gemm_op.initialize(arguments, workspace)); | |||
| cutlass_check(gemm_op(stream)); | |||
| after_kernel_launch(); | |||
| } | |||
| // vim: syntax=cuda.doxygen | |||