GitOrigin-RevId: 2a70335441
tags/v1.6.0-rc1
@@ -163,7 +163,7 @@ using Convolution =
     ${element_bias},
     ${layout_bias},
     ${element_accumulator},
-    ${conv_type},
+    ${conv_type},
     ${opcode_class},
     ${arch},
    cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
@@ -246,6 +246,7 @@ using Deconvolution =
     ${element_bias},
     ${layout_bias},
     ${element_accumulator},
+    ${conv_type},
     ${opcode_class},
     ${arch},
     cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
@@ -276,6 +277,7 @@ using Deconvolution =
     values = {
       'operation_name': operation.procedural_name(),
+      'conv_type': ConvTypeTag[operation.conv_type],
       'element_src': DataTypeTag[operation.src.element],
       'layout_src': LayoutTag[operation.src.layout],
       'element_flt': DataTypeTag[operation.flt.element],
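Note: ConvTypeTag maps the generator-side conv_type to the C++ enumerator, so the emitted alias now pins the conv type explicitly. A sketch of the relevant slice of an expanded instantiation (surrounding template arguments are elided as comments; the concrete values come from the same `values` dict):

    // excerpt of an emitted instantiation (sketch)
    using Deconvolution = typename cutlass::conv::device::Deconvolution<
            /* element/layout arguments substituted above */
            int32_t,                                // ${element_accumulator}
            cutlass::conv::ConvType::kConvolution,  // ${conv_type}
            cutlass::arch::OpClassSimt,             // ${opcode_class}
            /* arch, threadblock/warp shapes, epilogue, ... */>;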
@@ -530,44 +532,17 @@ void initialize_${configuration_name}(Manifest &manifest) {
 ###################################################################################################
 class EmitConvSingleKernelWrapper():
-  def __init__(self, kernel_path, operation, wrapper_path):
+  def __init__(self, kernel_path, operation):
     self.kernel_path = kernel_path
-    self.wrapper_path = wrapper_path
     self.operation = operation
-    self.conv_wrappers = { \
-      ConvKind.Fprop: """
-template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>(
-        const typename Convolution::ElementSrc* d_src,
-        const typename Convolution::ElementFilter* d_filter,
-        const typename Convolution::ElementBias* d_bias,
-        const typename Convolution::ElementDst* d_z,
-        typename Convolution::ElementDst* d_dst,
-        int* workspace,
-        typename Convolution::ConvolutionParameter const& conv_param,
-        typename Convolution::EpilogueOutputOp::Params const& epilogue,
-        cudaStream_t stream,
-        typename Convolution::ExtraParam extra_param);
-""", \
-      ConvKind.Dgrad: """
-template void megdnn::cuda::cutlass_wrapper::cutlass_deconvolution_wrapper<Deconvolution>(
-        const typename Deconvolution::ElementSrc* d_src,
-        const typename Deconvolution::ElementFilter* d_filter,
-        const typename Deconvolution::ElementBias* d_bias,
-        const typename Deconvolution::ElementDst* d_z,
-        typename Deconvolution::ElementDst* d_dst,
-        int* workspace,
-        typename Deconvolution::ConvolutionParameter const& conv_param,
-        typename Deconvolution::EpilogueOutputOp::Params const& epilogue,
-        cudaStream_t stream);
-""", \
-    }
     if self.operation.conv_kind == ConvKind.Fprop:
       self.instance_emitter = EmitConv2dInstance()
+      self.convolution_name = "Convolution"
     else:
       assert self.operation.conv_kind == ConvKind.Dgrad
       self.instance_emitter = EmitDeconvInstance()
+      self.convolution_name = "Deconvolution"
     self.header_template = """
 #if !MEGDNN_TEGRA_X1
@@ -575,13 +550,30 @@ template void megdnn::cuda::cutlass_wrapper::cutlass_deconvolution_wrapper<Decon
 #pragma GCC diagnostic push
 #pragma GCC diagnostic ignored "-Wunused-parameter"
 #pragma GCC diagnostic ignored "-Wstrict-aliasing"
-#include "${wrapper_path}"
+#pragma GCC diagnostic ignored "-Wuninitialized"
+#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
+#include "cutlass/convolution/device/convolution.h"
+#include "src/cuda/cutlass/manifest.h"
+#include "src/cuda/cutlass/convolution_operation.h"
 """
     self.instance_template = """
 ${operation_instance}
 """
-    self.wrapper_template = """
-${wrapper_instance}
+    self.manifest_template = """
+namespace cutlass {
+namespace library {
+
+void initialize_${operation_name}(Manifest &manifest) {
+  manifest.append(new ConvolutionOperation<${convolution_name}>(
+    "${operation_name}"
+  ));
+}
+
+}  // namespace library
+}  // namespace cutlass
 """
     self.epilogue_template = """
@@ -593,9 +585,7 @@ ${wrapper_instance}
   def __enter__(self):
     self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % self.operation.procedural_name())
     self.kernel_file = LazyFile(self.kernel_path)
-    self.kernel_file.write(SubstituteTemplate(self.header_template, {
-      'wrapper_path': self.wrapper_path,
-    }))
+    self.kernel_file.write(self.header_template)
     return self

  #
@@ -604,11 +594,12 @@ ${wrapper_instance}
       'operation_instance': self.instance_emitter.emit(self.operation),
     }))
-    # emit wrapper
-    wrapper = SubstituteTemplate(self.wrapper_template, {
-      'wrapper_instance': self.conv_wrappers[self.operation.conv_kind],
+    # emit manifest helper
+    manifest = SubstituteTemplate(self.manifest_template, {
+      'operation_name': self.operation.procedural_name(),
+      'convolution_name': self.convolution_name
     })
-    self.kernel_file.write(wrapper)
+    self.kernel_file.write(manifest)

  #
  def __exit__(self, exception_type, exception_value, traceback):
@@ -940,8 +940,8 @@ void initialize_${configuration_name}(Manifest &manifest) {
 ///////////////////////////////////////////////////////////////////////////////////////////////////
-} // namespace library
-} // namespace cutlass
+}  // namespace library
+}  // namespace cutlass
 ///////////////////////////////////////////////////////////////////////////////////////////////////
@@ -995,48 +995,101 @@ void initialize_${configuration_name}(Manifest &manifest) {
 ###################################################################################################
 class EmitGemmSingleKernelWrapper:
-  def __init__(self, kernel_path, gemm_operation, wrapper_path):
+  def __init__(self, kernel_path, gemm_operation):
     self.kernel_path = kernel_path
-    self.wrapper_path = wrapper_path
     self.operation = gemm_operation
-    gemm_wrapper = """
-template void megdnn::cuda::cutlass_wrapper::cutlass_matrix_mul_wrapper<Operation_${operation_name}>(
-        const typename Operation_${operation_name}::ElementA* d_A, size_t lda,
-        const typename Operation_${operation_name}::ElementB* d_B, size_t ldb,
-        typename Operation_${operation_name}::ElementC* d_C, size_t ldc,
-        int* workspace,
-        cutlass::gemm::GemmCoord const& problem_size,
-        typename Operation_${operation_name}::EpilogueOutputOp::Params const& epilogue,
-        cudaStream_t stream, int split_k_slices);
-"""
+    instance_emitters = {
+      GemmKind.Gemm: EmitGemmInstance(),
+      GemmKind.SplitKParallel: EmitGemmSplitKParallelInstance(),
+    }
+    self.instance_emitter = instance_emitters[self.operation.gemm_kind]
+    self.header_template = """
+#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
+// ignore warning of cutlass
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wunused-parameter"
+#pragma GCC diagnostic ignored "-Wstrict-aliasing"
+#pragma GCC diagnostic ignored "-Wuninitialized"
+#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
+#include "cutlass/gemm/device/gemm.h"
+#include "cutlass/gemm/device/gemm_splitk_parallel.h"
+#include "src/cuda/cutlass/manifest.h"
+#include "src/cuda/cutlass/gemm_operation.h"
+"""
+    self.instance_template = """
+${operation_instance}
+"""
+    self.manifest_template = """
+namespace cutlass {
+namespace library {
+
+void initialize_${operation_name}(Manifest &manifest) {
+  manifest.append(new GemmOperation<
+    Operation_${operation_name}
+  >("${operation_name}"));
+}
+
+}  // namespace library
+}  // namespace cutlass
+"""
+    self.epilogue_template = """
+#pragma GCC diagnostic pop
+#endif
+"""
+
+  #
+  def __enter__(self):
+    self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % self.operation.procedural_name())
+    self.kernel_file = LazyFile(self.kernel_path)
+    self.kernel_file.write(self.header_template)
+    return self
+
+  #
+  def emit(self):
+    self.kernel_file.write(SubstituteTemplate(self.instance_template, {
+      'operation_instance': self.instance_emitter.emit(self.operation),
+    }))
+
+    # emit manifest helper
+    manifest = SubstituteTemplate(self.manifest_template, {
+      'operation_name': self.operation.procedural_name(),
+    })
+    self.kernel_file.write(manifest)
+
+  #
+  def __exit__(self, exception_type, exception_value, traceback):
+    self.kernel_file.write(self.epilogue_template)
+    self.kernel_file.close()
+
+###################################################################################################
+###################################################################################################
+
+class EmitGemvSingleKernelWrapper:
+  def __init__(self, kernel_path, gemm_operation, wrapper_path):
+    self.kernel_path = kernel_path
+    self.wrapper_path = wrapper_path
+    self.operation = gemm_operation
-    gemv_wrapper = """
+    self.wrapper_template = """
 template void megdnn::cuda::cutlass_wrapper::
         cutlass_vector_matrix_mul_batched_strided_wrapper<Operation_${operation_name}>(
         BatchedGemmCoord const& problem_size,
-        const typename Operation_${operation_name}::ElementA* d_A, size_t lda, size_t batch_stride_a,
-        const typename Operation_${operation_name}::ElementB* d_B, size_t ldb, size_t batch_stride_b,
+        const typename Operation_${operation_name}::ElementA* d_A, size_t lda, size_t batch_stride_a,
+        const typename Operation_${operation_name}::ElementB* d_B, size_t ldb, size_t batch_stride_b,
         typename Operation_${operation_name}::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
         cudaStream_t stream);
 """
-    if self.operation.gemm_kind == GemmKind.SplitKParallel or \
-        self.operation.gemm_kind == GemmKind.Gemm:
-      self.wrapper_template = gemm_wrapper
-    else:
-      assert self.operation.gemm_kind == GemmKind.GemvBatchedStrided
-      self.wrapper_template = gemv_wrapper
-    instance_emitters = {
-      GemmKind.Gemm: EmitGemmInstance(),
-      GemmKind.SplitKParallel: EmitGemmSplitKParallelInstance(),
-      GemmKind.GemvBatchedStrided: EmitGemvBatchedStridedInstance(),
-    }
-    self.instance_emitter = instance_emitters[self.operation.gemm_kind]
+    self.instance_emitter = EmitGemvBatchedStridedInstance()
     self.header_template = """
-#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
+#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
 // ignore warning of cutlass
 #pragma GCC diagnostic push
 #pragma GCC diagnostic ignored "-Wunused-parameter"
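Note: the GEMM manifest_template expands the same way; for the sgemm kernel named in the gen list below:

    namespace cutlass {
    namespace library {

    void initialize_cutlass_simt_sgemm_256x64_8x2_tt_align1(Manifest &manifest) {
      manifest.append(new GemmOperation<
        Operation_cutlass_simt_sgemm_256x64_8x2_tt_align1
      >("cutlass_simt_sgemm_256x64_8x2_tt_align1"));
    }

    }  // namespace library
    }  // namespace cutlass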
@@ -1055,10 +1108,10 @@ ${operation_instance}
 """
  #
  def __enter__(self):
-    self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % self.operation.procedural_name())
+    self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % self.operation.procedural_name())
     self.kernel_file = LazyFile(self.kernel_path)
     self.kernel_file.write(SubstituteTemplate(self.header_template, {
-      'wrapper_path': self.wrapper_path,
+      'wrapper_path': self.wrapper_path,
     }))
     return self
@@ -1070,7 +1123,7 @@ ${operation_instance}
     # emit wrapper
     wrapper = SubstituteTemplate(self.wrapper_template, {
-      'operation_name': self.operation.procedural_name(),
+      'operation_name': self.operation.procedural_name(),
     })
     self.kernel_file.write(wrapper)
@@ -1079,7 +1132,5 @@ ${operation_instance}
     self.kernel_file.write(self.epilogue_template)
     self.kernel_file.close()
 ###################################################################################################
 ###################################################################################################
@@ -23,6 +23,8 @@ def write_op_list(f, gen_op, gen_type):
     operations = GenerateDeconvOperations(GenArg(gen_op, gen_type))
   for op in operations:
     f.write('    "%s.cu",\n' % op.procedural_name())
+  if gen_op != "gemv":
+    f.write('    "all_%s_%s_operations.cu",\n' % (gen_op, gen_type))

if __name__ == "__main__":
@@ -292,7 +292,7 @@ def GenerateConv2d_TensorOp_8832(args):
         ]
         operations += GenerateConv2d(ConvKind.Fprop, tile_descriptions, layout[0], layout[1],
                                      dst_layout, dst_type, min_cc, 128, 128, 64,
-                                     True, ImplicitGemmMode.GemmTN, True)
+                                     False, ImplicitGemmMode.GemmTN, True)
     layouts_nhwc = [
         (LayoutType.TensorNHWC, LayoutType.TensorNC8HW8, 32),
@@ -633,16 +633,10 @@ if __name__ == "__main__":
   parser.add_argument("--type", type=str, choices=['simt', 'tensorop8816', 'tensorop8832'],
                       default='simt', help="kernel type of CUTLASS kernel generator")
-  operation2wrapper_path = {
-    "gemm": "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuinl", \
-    "gemv": "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl", \
-    "conv2d": "src/cuda/conv_bias/implicit_gemm_conv_bias_cutlass_wrapper.cuinl", \
-    "deconv": "src/cuda/convolution/backward_data/implicit_gemm_deconv_cutlass_wrapper.cuinl", \
-  }
+  gemv_wrapper_path = "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl"
   args = parser.parse_args()
-  wrapper_path = operation2wrapper_path[args.operations]
   if args.operations == "gemm":
     operations = GenerateGemmOperations(args)
   elif args.operations == "gemv":
@@ -652,16 +646,22 @@ if __name__ == "__main__":
   elif args.operations == "deconv":
     operations = GenerateDeconvOperations(args)
   if args.operations == "conv2d" or args.operations == "deconv":
     for operation in operations:
-      with EmitConvSingleKernelWrapper(args.output, operation, wrapper_path) as emitter:
+      with EmitConvSingleKernelWrapper(args.output, operation) as emitter:
         emitter.emit()
-  elif args.operations == "gemm" or args.operations == "gemv":
+  elif args.operations == "gemm":
     for operation in operations:
-      with EmitGemmSingleKernelWrapper(args.output, operation, wrapper_path) as emitter:
+      with EmitGemmSingleKernelWrapper(args.output, operation) as emitter:
         emitter.emit()
+  elif args.operations == "gemv":
+    for operation in operations:
+      with EmitGemvSingleKernelWrapper(args.output, operation, gemv_wrapper_path) as emitter:
+        emitter.emit()
+
+  if args.operations != "gemv":
+    GenerateManifest(args, operations, args.output)
 #
 ###################################################################################################
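Note: with this dispatch in place, each generated <kernel>.cu is self-registering and no longer includes a .cuinl wrapper. Assembled from the conv templates above, a generated file has roughly this shape (<operation_name> is a placeholder; the instance alias is emitted by EmitConv2dInstance):

    #if !MEGDNN_TEGRA_X1
    // ignore warning of cutlass
    #pragma GCC diagnostic push
    #pragma GCC diagnostic ignored "-Wunused-parameter"
    #pragma GCC diagnostic ignored "-Wstrict-aliasing"
    #pragma GCC diagnostic ignored "-Wuninitialized"
    #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
    #include "cutlass/convolution/device/convolution.h"
    #include "src/cuda/cutlass/manifest.h"
    #include "src/cuda/cutlass/convolution_operation.h"

    // using Convolution = typename cutlass::conv::device::Convolution<...>;
    // (the ${operation_instance} emitted by the instance emitter goes here)

    namespace cutlass {
    namespace library {

    void initialize_<operation_name>(Manifest &manifest) {
      manifest.append(new ConvolutionOperation<Convolution>("<operation_name>"));
    }

    }  // namespace library
    }  // namespace cutlass

    #pragma GCC diagnostic pop
    #endif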
@@ -137,6 +137,7 @@ cutlass_gen_list = [
     "cutlass_simt_sgemm_split_k_parallel_256x32_8x2_tt_align1.cu",
     "cutlass_simt_sgemm_256x64_8x2_tt_align1.cu",
     "cutlass_simt_sgemm_split_k_parallel_256x64_8x2_tt_align1.cu",
+    "all_gemm_simt_operations.cu",
     "cutlass_simt_sgemv_batched_strided_1x128_32_tt_align4x4.cu",
     "cutlass_simt_sgemv_batched_strided_1x128_16_tt_align4x2.cu",
     "cutlass_simt_sgemv_batched_strided_1x128_8_tt_align4x1.cu",
@@ -169,6 +170,7 @@ cutlass_gen_list = [
     "cutlass_simt_s8_idgrad_id_s8_16x128x16_16x64x16_2_nc4hw4_k4rsc4.cu",
     "cutlass_simt_s8_idgrad_id_s8_16x128x16_16x128x16_1_nc4hw4_k4rsc4.cu",
     "cutlass_simt_s8_idgrad_id_s8_16x64x8_16x64x8_2_nc4hw4_k4rsc4.cu",
+    "all_deconv_simt_operations.cu",
     "cutlass_simt_s8_ifprop_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu",
     "cutlass_simt_s8_ifprop_1x1_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu",
     "cutlass_simt_s8_ifprop_relu_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu",
@@ -373,6 +375,7 @@ cutlass_gen_list = [
     "cutlass_simt_f32_ifprop_1x1_relu_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nchw.cu",
     "cutlass_simt_f32_ifprop_hswish_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nchw.cu",
     "cutlass_simt_f32_ifprop_1x1_hswish_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nchw.cu",
+    "all_conv2d_simt_operations.cu",
     "cutlass_tensorop_s8_i8816fprop_roc_id_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu",
     "cutlass_tensorop_s8_i8816fprop_1x1_roc_id_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu",
     "cutlass_tensorop_s8_i8816fprop_roc_relu_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu",
@@ -481,26 +484,47 @@ cutlass_gen_list = [
     "cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu",
     "cutlass_tensorop_s8_i8816fprop_hswish_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu",
     "cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu",
+    "all_conv2d_tensorop8816_operations.cu",
     "cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu",
+    "cutlass_tensorop_s4_i8832fprop_1x1_roc_id_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu",
     "cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu",
+    "cutlass_tensorop_s4_i8832fprop_1x1_roc_relu_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu",
     "cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu",
+    "cutlass_tensorop_s4_i8832fprop_1x1_roc_hswish_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu",
     "cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu",
+    "cutlass_tensorop_s4_i8832fprop_1x1_roc_id_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu",
     "cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu",
+    "cutlass_tensorop_s4_i8832fprop_1x1_roc_relu_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu",
     "cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu",
+    "cutlass_tensorop_s4_i8832fprop_1x1_roc_hswish_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu",
     "cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu",
+    "cutlass_tensorop_s4_i8832fprop_1x1_roc_id_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu",
     "cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu",
+    "cutlass_tensorop_s4_i8832fprop_1x1_roc_relu_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu",
     "cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu",
+    "cutlass_tensorop_s4_i8832fprop_1x1_roc_hswish_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu",
     "cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu",
+    "cutlass_tensorop_s4_i8832fprop_1x1_roc_id_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu",
     "cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu",
+    "cutlass_tensorop_s4_i8832fprop_1x1_roc_relu_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu",
     "cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu",
+    "cutlass_tensorop_s4_i8832fprop_1x1_roc_hswish_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu",
     "cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu",
+    "cutlass_tensorop_u4_i8832fprop_1x1_roc_id_u4_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu",
     "cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu",
+    "cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu",
     "cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu",
+    "cutlass_tensorop_u4_i8832fprop_1x1_roc_id_u4_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu",
     "cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu",
+    "cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu",
     "cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu",
+    "cutlass_tensorop_u4_i8832fprop_1x1_roc_id_u4_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu",
     "cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu",
+    "cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu",
     "cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu",
+    "cutlass_tensorop_u4_i8832fprop_1x1_roc_id_u4_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu",
     "cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu",
+    "cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu",
     "cutlass_tensorop_s4_i8832fprop_id_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu",
     "cutlass_tensorop_s4_i8832fprop_1x1_id_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu",
     "cutlass_tensorop_s4_i8832fprop_relu_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu",
@@ -621,4 +645,5 @@ cutlass_gen_list = [
     "cutlass_tensorop_u4_i8832fprop_1x1_roc_id_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu",
     "cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu",
     "cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu",
+    "all_conv2d_tensorop8832_operations.cu",
 ]
@@ -8,6 +8,7 @@ import enum
 import os.path
 import shutil
+from lazy_file import LazyFile
 from library import *
 from gemm_operation import *
 from conv2d_operation import *
@@ -349,3 +350,41 @@ void initialize_all(Manifest &manifest) {
 #
 ###################################################################################################
+
+def GenerateManifest(args, operations, output_dir):
+  manifest_path = os.path.join(output_dir, "all_%s_%s_operations.cu" % (args.operations, args.type))
+  f = LazyFile(manifest_path)
+  f.write("""
+/*
+  Generated by generator.py - Do not edit.
+*/
+
+#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
+#include "cutlass/cutlass.h"
+#include "src/cuda/cutlass/library.h"
+#include "src/cuda/cutlass/manifest.h"
+
+namespace cutlass {
+namespace library {
+""")
+  for op in operations:
+    f.write("void initialize_%s(Manifest &manifest);\n" % op.procedural_name())
+  f.write("""
+void initialize_all_%s_%s_operations(Manifest &manifest) {
+""" % (args.operations, args.type))
+  for op in operations:
+    f.write("  initialize_%s(manifest);\n" % op.procedural_name())
+  f.write("""
+}
+
+}  // namespace library
+}  // namespace cutlass
+
+#endif
+""")
+  f.close()
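Note: for example, when the generator runs for conv2d/simt, GenerateManifest writes all_conv2d_simt_operations.cu of the following shape (two kernel names from the gen list shown as stand-ins; the real file declares and calls one initializer per generated operation):

    /*
      Generated by generator.py - Do not edit.
    */

    #if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
    #include "cutlass/cutlass.h"
    #include "src/cuda/cutlass/library.h"
    #include "src/cuda/cutlass/manifest.h"

    namespace cutlass {
    namespace library {

    void initialize_cutlass_simt_s8_ifprop_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4(Manifest &manifest);
    void initialize_cutlass_simt_s8_ifprop_relu_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4(Manifest &manifest);

    void initialize_all_conv2d_simt_operations(Manifest &manifest) {
      initialize_cutlass_simt_s8_ifprop_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4(manifest);
      initialize_cutlass_simt_s8_ifprop_relu_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4(manifest);
    }

    }  // namespace library
    }  // namespace cutlass

    #endif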
@@ -217,68 +217,77 @@ void ConvBiasForwardImpl::AlgoPack::fill_imma_algos() {
 #if CUDA_VERSION >= 10020
     {
         using AlgoParam = AlgoInt8NCHW32IMMAImplicitGemm::AlgoParam;
-        int8_nchw32_imma.emplace_back(AlgoParam{128, 256, 64, 64, 64, 64, 2});
-        int8_nchw32_imma.emplace_back(AlgoParam{256, 128, 64, 64, 64, 64, 2});
-        int8_nchw32_imma.emplace_back(AlgoParam{128, 128, 64, 64, 64, 64, 2});
-        int8_nchw32_imma.emplace_back(AlgoParam{128, 64, 64, 64, 32, 64, 2});
-        int8_nchw32_imma.emplace_back(AlgoParam{64, 128, 64, 32, 64, 64, 2});
-        int8_nchw32_imma.emplace_back(AlgoParam{128, 64, 32, 64, 32, 32, 1});
-        int8_nchw32_imma.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 1});
-        int8_nchw32_imma.emplace_back(AlgoParam{64, 128, 32, 32, 64, 32, 1});
-        int8_nchw32_imma.emplace_back(AlgoParam{32, 128, 32, 32, 64, 32, 1});
+        int8_nchw32_imma.emplace_back(
+                AlgoParam{128, 256, 64, 64, 64, 64, 8, 8, 16, 2});
+        int8_nchw32_imma.emplace_back(
+                AlgoParam{256, 128, 64, 64, 64, 64, 8, 8, 16, 2});
+        int8_nchw32_imma.emplace_back(
+                AlgoParam{128, 128, 64, 64, 64, 64, 8, 8, 16, 2});
+        int8_nchw32_imma.emplace_back(
+                AlgoParam{128, 64, 64, 64, 32, 64, 8, 8, 16, 2});
+        int8_nchw32_imma.emplace_back(
+                AlgoParam{64, 128, 64, 32, 64, 64, 8, 8, 16, 2});
+        int8_nchw32_imma.emplace_back(
+                AlgoParam{128, 64, 32, 64, 32, 32, 8, 8, 16, 1});
+        int8_nchw32_imma.emplace_back(
+                AlgoParam{128, 32, 32, 64, 32, 32, 8, 8, 16, 1});
+        int8_nchw32_imma.emplace_back(
+                AlgoParam{64, 128, 32, 32, 64, 32, 8, 8, 16, 1});
+        int8_nchw32_imma.emplace_back(
+                AlgoParam{32, 128, 32, 32, 64, 32, 8, 8, 16, 1});
     }
     {
         using AlgoParam = AlgoInt4Int4NCHW64IMMAImplicitGemm::AlgoParam;
         int4_int4_nchw64_imma.emplace_back(
-                AlgoParam{128, 128, 128, 64, 64, 128, 2});
+                AlgoParam{128, 128, 128, 64, 64, 128, 8, 8, 32, 2});
         int4_int4_nchw64_imma.emplace_back(
-                AlgoParam{128, 256, 128, 64, 64, 128, 2});
+                AlgoParam{128, 256, 128, 64, 64, 128, 8, 8, 32, 2});
         int4_int4_nchw64_imma.emplace_back(
-                AlgoParam{128, 64, 128, 64, 64, 128, 2});
+                AlgoParam{128, 64, 128, 64, 64, 128, 8, 8, 32, 2});
         int4_int4_nchw64_imma.emplace_back(
-                AlgoParam{128, 64, 64, 64, 64, 64, 1});
+                AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1});
     }
     {
         using AlgoParam = AlgoUInt4Int4NCHW64IMMAImplicitGemm::AlgoParam;
         uint4_int4_nchw64_imma.emplace_back(
-                AlgoParam{128, 128, 128, 64, 64, 128, 2});
+                AlgoParam{128, 128, 128, 64, 64, 128, 8, 8, 32, 2});
         uint4_int4_nchw64_imma.emplace_back(
-                AlgoParam{128, 256, 128, 64, 64, 128, 2});
+                AlgoParam{128, 256, 128, 64, 64, 128, 8, 8, 32, 2});
         uint4_int4_nchw64_imma.emplace_back(
-                AlgoParam{128, 64, 128, 64, 64, 128, 2});
+                AlgoParam{128, 64, 128, 64, 64, 128, 8, 8, 32, 2});
         uint4_int4_nchw64_imma.emplace_back(
-                AlgoParam{128, 64, 64, 64, 64, 64, 1});
+                AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1});
     }
     {
         using AlgoParam = AlgoInt4Int4NHWCIMMAImplicitGemm::AlgoParam;
         int4_int4_nhwc_imma.emplace_back(
-                AlgoParam{128, 32, 64, 64, 32, 64, 1, 32});
+                AlgoParam{128, 32, 64, 64, 32, 64, 8, 8, 32, 1, 32});
         int4_int4_nhwc_imma.emplace_back(
-                AlgoParam{128, 32, 64, 64, 32, 64, 1, 16});
+                AlgoParam{128, 32, 64, 64, 32, 64, 8, 8, 32, 1, 16});
         int4_int4_nhwc_imma.emplace_back(
-                AlgoParam{128, 32, 64, 64, 32, 64, 1, 8});
+                AlgoParam{128, 32, 64, 64, 32, 64, 8, 8, 32, 1, 8});
         int4_int4_nhwc_imma.emplace_back(
-                AlgoParam{128, 64, 64, 64, 64, 64, 1, 32});
+                AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1, 32});
         int4_int4_nhwc_imma.emplace_back(
-                AlgoParam{128, 64, 64, 64, 64, 64, 1, 16});
+                AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1, 16});
         int4_int4_nhwc_imma.emplace_back(
-                AlgoParam{128, 64, 64, 64, 64, 64, 1, 8});
+                AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1, 8});
     }
     {
         using AlgoParam = AlgoUInt4Int4NHWCIMMAImplicitGemm::AlgoParam;
         uint4_int4_nhwc_imma.emplace_back(
-                AlgoParam{128, 32, 64, 64, 32, 64, 1, 32});
+                AlgoParam{128, 32, 64, 64, 32, 64, 8, 8, 32, 1, 32});
         uint4_int4_nhwc_imma.emplace_back(
-                AlgoParam{128, 32, 64, 64, 32, 64, 1, 16});
+                AlgoParam{128, 32, 64, 64, 32, 64, 8, 8, 32, 1, 16});
         uint4_int4_nhwc_imma.emplace_back(
-                AlgoParam{128, 32, 64, 64, 32, 64, 1, 8});
+                AlgoParam{128, 32, 64, 64, 32, 64, 8, 8, 32, 1, 8});
         uint4_int4_nhwc_imma.emplace_back(
-                AlgoParam{128, 64, 64, 64, 64, 64, 1, 32});
+                AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1, 32});
         uint4_int4_nhwc_imma.emplace_back(
-                AlgoParam{128, 64, 64, 64, 64, 64, 1, 16});
+                AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1, 16});
         uint4_int4_nhwc_imma.emplace_back(
-                AlgoParam{128, 64, 64, 64, 64, 64, 1, 8});
+                AlgoParam{128, 64, 64, 64, 64, 64, 8, 8, 32, 1, 8});
     }
 #endif
 }
@@ -286,15 +295,24 @@ void ConvBiasForwardImpl::AlgoPack::fill_imma_algos() {
 void ConvBiasForwardImpl::AlgoPack::fill_dp4a_algos() {
     using AlgoParam = AlgoInt8NCHW4DotProdImplicitGemm::AlgoParam;
-    int8_nchw4_dotprod.emplace_back(AlgoParam{128, 128, 32, 64, 32, 32, 2});
-    int8_nchw4_dotprod.emplace_back(AlgoParam{128, 64, 32, 64, 32, 32, 2});
-    int8_nchw4_dotprod.emplace_back(AlgoParam{64, 128, 32, 64, 32, 32, 2});
-    int8_nchw4_dotprod.emplace_back(AlgoParam{32, 128, 32, 32, 64, 32, 2});
-    int8_nchw4_dotprod.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 2});
-    int8_nchw4_dotprod.emplace_back(AlgoParam{32, 64, 32, 32, 64, 32, 2});
-    int8_nchw4_dotprod.emplace_back(AlgoParam{64, 32, 32, 64, 32, 32, 2});
-    int8_nchw4_dotprod.emplace_back(AlgoParam{16, 128, 16, 16, 128, 16, 1});
-    int8_nchw4_dotprod.emplace_back(AlgoParam{16, 64, 8, 16, 64, 8, 2});
+    int8_nchw4_dotprod.emplace_back(
+            AlgoParam{128, 128, 32, 64, 32, 32, 1, 1, 4, 2});
+    int8_nchw4_dotprod.emplace_back(
+            AlgoParam{128, 64, 32, 64, 32, 32, 1, 1, 4, 2});
+    int8_nchw4_dotprod.emplace_back(
+            AlgoParam{64, 128, 32, 64, 32, 32, 1, 1, 4, 2});
+    int8_nchw4_dotprod.emplace_back(
+            AlgoParam{32, 128, 32, 32, 64, 32, 1, 1, 4, 2});
+    int8_nchw4_dotprod.emplace_back(
+            AlgoParam{128, 32, 32, 64, 32, 32, 1, 1, 4, 2});
+    int8_nchw4_dotprod.emplace_back(
+            AlgoParam{32, 64, 32, 32, 64, 32, 1, 1, 4, 2});
+    int8_nchw4_dotprod.emplace_back(
+            AlgoParam{64, 32, 32, 64, 32, 32, 1, 1, 4, 2});
+    int8_nchw4_dotprod.emplace_back(
+            AlgoParam{16, 128, 16, 16, 128, 16, 1, 1, 4, 1});
+    int8_nchw4_dotprod.emplace_back(
+            AlgoParam{16, 64, 8, 16, 64, 8, 1, 1, 4, 2});
 }
 ConvBiasForwardImpl::AlgoBase*
@@ -28,6 +28,17 @@
 #include <memory>
 #include <unordered_map>
+namespace cutlass {
+namespace library {
+
+// forward declaration of cutlass library concepts, we hope that algo.h does
+// not depend on cutlass headers
+class Operation;
+
+}  // namespace library
+}  // namespace cutlass
+
 namespace megdnn {
 namespace cuda {
@@ -505,9 +516,44 @@ public:
     MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_CHWN4_DOTPROD_INT8)
 };
-class ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm final
-        : public AlgoBase {
+/*********************** Cutlass Algorithms ************************/
+
+/* The inheritance of cutlass algorithm classes:
+ *
+ * AlgoCutlassConvolutionBase
+ * +
+ * +--- AlgoInt8NCHW4DotProdImplicitGemm
+ * +--- AlgoInt8NCHW32IMMAImplicitGemm
+ * +
+ * +--- AlgoInt4NCHW64IMMAImplicitGemmBase
+ * +----+--- AlgoInt4Int4NCHW64IMMAImplicitGemm
+ * +----+--- AlgoUInt4Int4NCHW64IMMAImplicitGemm
+ * +
+ * +--- AlgoInt4NHWCIMMAImplicitGemmBase
+ * +----+--- AlgoInt4Int4NHWCIMMAImplicitGemm
+ * +----+--- AlgoUInt4Int4NHWCIMMAImplicitGemm
+ * +
+ */
+
+/*
+ * The base class for all cutlass algorithm classes
+ */
+class ConvBiasForwardImpl::AlgoCutlassConvolutionBase : public AlgoBase {
 public:
+    // corresponds to cutlass::conv::Operator. we hope that algo.h does not
+    // depend on cutlass headers
+    enum class ConvOperator { kFprop, kDgrad, kWgrad };
+
+    // corresponds to cutlass::conv::ConvType. we hope that algo.h does not
+    // depend on cutlass headers
+    enum class ConvType {
+        kConvolution,
+        kBatchConvolution,
+        kLocal,
+        kLocalShare
+    };
+
+    // common parameters for operation selection
     struct AlgoParam {
         int threadblock_m;
         int threadblock_n;
@@ -515,21 +561,54 @@ public:
         int warp_m;
         int warp_n;
         int warp_k;
+        int instruction_m;
+        int instruction_n;
+        int instruction_k;
         int stage;
-        std::string to_string() {
-            /// default algorithm
-            if (threadblock_m == 128 && threadblock_n == 128 &&
-                threadblock_k == 32 && warp_m == 32 && warp_n == 64 &&
-                warp_k == 32 && stage == 2) {
-                return "";
-            }
-            return ssprintf("_%dX%dX%d_%dX%dX%d_%dstage", threadblock_m,
-                            threadblock_n, threadblock_k, warp_m, warp_n,
-                            warp_k, stage);
-        }
+        int access_size;
+
+        AlgoParam(int threadblock_m_, int threadblock_n_, int threadblock_k_,
+                  int warp_m_, int warp_n_, int warp_k_, int instruction_m_,
+                  int instruction_n_, int instruction_k_, int stage_,
+                  int access_size_ = 0);
+
+        std::string to_string() const;
     };
+
+    AlgoCutlassConvolutionBase(AlgoParam algo_param)
+            : m_algo_param{algo_param} {}
+
+    // generate a cutlass::library::ConvolutionKey and find the corresponding
+    // operation (cutlass kernel) from the global OperationTable
+    const cutlass::library::Operation* get_cutlass_conv_op(
+            const SizeArgs& args, ConvOperator conv_op, ConvType conv_type,
+            bool load_from_const, bool without_shared_load) const;
+
+    // execute the cutlass kernel found by get_cutlass_conv_op. we give
+    // subclasses full freedom to decide where and how these arguments are
+    // extracted
+    void execute_cutlass_conv_op(const cutlass::library::Operation* op,
+                                 const void* src, const void* filter,
+                                 const void* bias, const void* z, void* dst,
+                                 void* workspace, size_t n, size_t hi,
+                                 size_t wi, size_t ci, size_t co, size_t fh,
+                                 size_t fw, size_t ho, size_t wo, size_t ph,
+                                 size_t pw, size_t sh, size_t sw, size_t dh,
+                                 size_t dw, const void* alpha, const void* beta,
+                                 const void* gamma, const void* delta,
+                                 const void* theta, const void* threshold,
+                                 const void* dst_scale, cudaStream_t stream,
+                                 const void* extra_param = nullptr) const;
+
+protected:
+    AlgoParam m_algo_param;
+};
+
+class ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm final
+        : public AlgoCutlassConvolutionBase {
+public:
     AlgoInt8NCHW4DotProdImplicitGemm(AlgoParam algo_param)
-            : m_algo_param{algo_param},
+            : AlgoCutlassConvolutionBase(algo_param),
               m_name{ssprintf("INT8_NCHW4_DOTPROD_IMPLICIT_GEMM%s",
                               m_algo_param.to_string().c_str())} {}
     bool is_available(const SizeArgs& args) const override;
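Note: a hedged sketch of how a subclass's exec() is meant to drive these two hooks (this is not the actual implementation; pointer extraction and the epilogue scalars are collapsed into locals, and ExecArgs is usable where SizeArgs is expected since it derives from it):

    // inside some AlgoXxx::exec(const ExecArgs& args) const  (sketch)
    const cutlass::library::Operation* op = get_cutlass_conv_op(
            args, ConvOperator::kFprop, ConvType::kConvolution,
            /* load_from_const */ false, /* without_shared_load */ false);

    float alpha = 1.f, beta = 0.f, gamma = 0.f, delta = 0.f, theta = 0.f,
          threshold = 0.f, dst_scale = 1.f;
    execute_cutlass_conv_op(op, src_ptr, filter_ptr, bias_ptr, z_ptr, dst_ptr,
                            workspace_ptr, n, hi, wi, ci, co, fh, fw, ho, wo,
                            ph, pw, sh, sw, dh, dw, &alpha, &beta, &gamma,
                            &delta, &theta, &threshold, &dst_scale, stream);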
@@ -555,7 +634,6 @@ public:
 private:
     WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr,
                                          const SizeArgs& args) const;
-    AlgoParam m_algo_param;
     std::string m_name;
 };
@@ -714,19 +792,10 @@ private:
 #if CUDA_VERSION >= 10020
 class ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm final
-        : public AlgoBase {
+        : public AlgoCutlassConvolutionBase {
 public:
-    struct AlgoParam {
-        int threadblock_m;
-        int threadblock_n;
-        int threadblock_k;
-        int warp_m;
-        int warp_n;
-        int warp_k;
-        int stage;
-    };
     AlgoInt8NCHW32IMMAImplicitGemm(AlgoParam algo_param)
-            : m_algo_param{algo_param} {
+            : AlgoCutlassConvolutionBase(algo_param) {
         m_name = ConvBias::algo_name<ConvBias::DirectParam>(
                 ssprintf("INT8_NCHW32_IMMA_IMPLICIT_GEMM_%s",
                          to_string(m_algo_param).c_str()),
@@ -757,25 +826,14 @@ private:
     WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr,
                                          const SizeArgs& args) const;
-    AlgoParam m_algo_param;
     std::string m_name;
 };
 
 class ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase
-        : public AlgoBase {
+        : public AlgoCutlassConvolutionBase {
 public:
-    struct AlgoParam {
-        int threadblock_m;
-        int threadblock_n;
-        int threadblock_k;
-        int warp_m;
-        int warp_n;
-        int warp_k;
-        int stage;
-    };
     AlgoInt4NCHW64IMMAImplicitGemmBase(AlgoParam algo_param)
-            : m_algo_param(algo_param) {}
+            : AlgoCutlassConvolutionBase(algo_param) {}
 
     AlgoAttribute attribute() const override {
         return AlgoAttribute::REPRODUCIBLE;
@@ -799,16 +857,9 @@ protected:
     virtual std::tuple<float, float, float, float, float> get_constants(
             const ExecArgs& args) const = 0;
 
-    virtual void do_exec(const ExecArgs& args, void* filter_ptr, void* bias_ptr,
-                         void* z_ptr, convolution::ConvParam kern_param,
-                         uint32_t nonlinear_mode, float alpha, float beta,
-                         float gamma, float delta, float theta,
-                         cudaStream_t stream) const = 0;
-
     void reorder_filter(const ExecArgs& args, void* reordered_filter) const;
 
     std::string m_name;
-    AlgoParam m_algo_param;
 };
 
 class ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm final
@@ -842,11 +893,6 @@ private:
     std::tuple<float, float, float, float, float> get_constants(
             const ExecArgs& args) const override;
-
-    void do_exec(const ExecArgs& args, void* filter_ptr, void* bias_ptr,
-                 void* z_ptr, convolution::ConvParam kern_param,
-                 uint32_t nonlinear_mode, float alpha, float beta, float gamma,
-                 float delta, float theta, cudaStream_t stream) const override;
 };
 
 class ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm final
@@ -881,30 +927,15 @@ private:
     std::tuple<float, float, float, float, float> get_constants(
             const ExecArgs& args) const override;
 
-    void do_exec(const ExecArgs& args, void* filter_ptr, void* bias_ptr,
-                 void* z_ptr, convolution::ConvParam kern_param,
-                 uint32_t nonlinear_mode, float alpha, float beta, float gamma,
-                 float delta, float theta, cudaStream_t stream) const override;
-
     void update_bias(const ExecArgs& args, void* updated_bias,
                      void* reduce_filter_ptr, void* reduce_workspace) const;
 };
 
-class ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase : public AlgoBase {
+class ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase
+        : public AlgoCutlassConvolutionBase {
 public:
-    struct AlgoParam {
-        int threadblock_m;
-        int threadblock_n;
-        int threadblock_k;
-        int warp_m;
-        int warp_n;
-        int warp_k;
-        int stage;
-        int access_size;
-    };
     AlgoInt4NHWCIMMAImplicitGemmBase(AlgoParam algo_param)
-            : m_algo_param(algo_param) {}
+            : AlgoCutlassConvolutionBase(algo_param) {}
 
     AlgoAttribute attribute() const override {
         return AlgoAttribute::REPRODUCIBLE;
@@ -928,17 +959,10 @@ protected:
     virtual std::tuple<float, float, float, float, float> get_constants(
             const ExecArgs& args) const = 0;
 
-    virtual void do_exec(const ExecArgs& args, void* filter_ptr, void* bias_ptr,
-                         void* z_ptr, convolution::ConvParam kern_param,
-                         uint32_t nonlinear_mode, float alpha, float beta,
-                         float gamma, float delta, float theta,
-                         cudaStream_t stream) const = 0;
-
     void reorder_filter(const ExecArgs& args, int interleaved,
                         void* reordered_filter) const;
 
     std::string m_name;
-    AlgoParam m_algo_param;
 };
 
 class ConvBiasForwardImpl::AlgoInt4Int4NHWCIMMAImplicitGemm final
@@ -971,11 +995,6 @@ private:
     std::tuple<float, float, float, float, float> get_constants(
             const ExecArgs& args) const override;
-
-    void do_exec(const ExecArgs& args, void* filter_ptr, void* bias_ptr,
-                 void* z_ptr, convolution::ConvParam kern_param,
-                 uint32_t nonlinear_mode, float alpha, float beta, float gamma,
-                 float delta, float theta, cudaStream_t stream) const override;
 };
 
 class ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm final
@@ -1009,11 +1028,6 @@ private:
     std::tuple<float, float, float, float, float> get_constants(
             const ExecArgs& args) const override;
 
-    void do_exec(const ExecArgs& args, void* filter_ptr, void* bias_ptr,
-                 void* z_ptr, convolution::ConvParam kern_param,
-                 uint32_t nonlinear_mode, float alpha, float beta, float gamma,
-                 float delta, float theta, cudaStream_t stream) const override;
-
     void update_bias(const ExecArgs& args, void* updated_bias,
                      void* reduce_filter_ptr, void* reduce_workspace) const;
 };
| @@ -0,0 +1,253 @@ | |||||
| /** | |||||
| * \file dnn/src/cuda/conv_bias/cutlass_convolution_base.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "src/cuda/conv_bias/algo.h" | |||||
| #include "src/cuda/cutlass/singleton.h" | |||||
| namespace megdnn { | |||||
| namespace cuda { | |||||
| using namespace cutlass::library; | |||||
| using namespace cutlass::epilogue; | |||||
| ConvBiasForwardImpl::AlgoCutlassConvolutionBase::AlgoParam::AlgoParam( | |||||
| int threadblock_m_, int threadblock_n_, int threadblock_k_, int warp_m_, | |||||
| int warp_n_, int warp_k_, int instruction_m_, int instruction_n_, | |||||
| int instruction_k_, int stage_, int access_size_) | |||||
| : threadblock_m(threadblock_m_), | |||||
| threadblock_n(threadblock_n_), | |||||
| threadblock_k(threadblock_k_), | |||||
| warp_m(warp_m_), | |||||
| warp_n(warp_n_), | |||||
| warp_k(warp_k_), | |||||
| instruction_m(instruction_m_), | |||||
| instruction_n(instruction_n_), | |||||
| instruction_k(instruction_k_), | |||||
| stage(stage_), | |||||
| access_size(access_size_) {} | |||||
| std::string | |||||
| ConvBiasForwardImpl::AlgoCutlassConvolutionBase::AlgoParam::to_string() const { | |||||
| /// the default algorithm is encoded as an empty suffix | |||||
| if (threadblock_m == 128 && threadblock_n == 128 && threadblock_k == 32 && | |||||
| warp_m == 32 && warp_n == 64 && warp_k == 32 && stage == 2) { | |||||
| return ""; | |||||
| } | |||||
| return ssprintf("_%dX%dX%d_%dX%dX%d_%dstage", threadblock_m, threadblock_n, | |||||
| threadblock_k, warp_m, warp_n, warp_k, stage); | |||||
| } | |||||
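| // Illustration (hypothetical values, not taken from the dispatch | |||||
| // tables below): a non-default configuration such as | |||||
| //     AlgoParam p(64, 128, 64, 32, 64, 64, 8, 8, 16, 2, 16); | |||||
| // yields the suffix "_64X128X64_32X64X64_2stage", while the default | |||||
| // 128x128x32 threadblock / 32x64x32 warp / 2-stage configuration | |||||
| // yields an empty string. | |||||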
| namespace { | |||||
| using Base = ConvBiasForwardImpl::AlgoCutlassConvolutionBase; | |||||
| cutlass::conv::Operator convert_conv_op(Base::ConvOperator conv_op) { | |||||
| switch (conv_op) { | |||||
| case Base::ConvOperator::kFprop: | |||||
| return cutlass::conv::Operator::kFprop; | |||||
| case Base::ConvOperator::kDgrad: | |||||
| return cutlass::conv::Operator::kDgrad; | |||||
| case Base::ConvOperator::kWgrad: | |||||
| return cutlass::conv::Operator::kWgrad; | |||||
| default: | |||||
| megdnn_assert(0, "invalid conv op"); | |||||
| } | |||||
| } | |||||
| cutlass::conv::ConvType convert_conv_type(Base::ConvType conv_type) { | |||||
| switch (conv_type) { | |||||
| case Base::ConvType::kConvolution: | |||||
| return cutlass::conv::ConvType::kConvolution; | |||||
| case Base::ConvType::kBatchConvolution: | |||||
| return cutlass::conv::ConvType::kBatchConvolution; | |||||
| case Base::ConvType::kLocal: | |||||
| return cutlass::conv::ConvType::kLocal; | |||||
| case Base::ConvType::kLocalShare: | |||||
| return cutlass::conv::ConvType::kLocalShare; | |||||
| default: | |||||
| megdnn_assert(0, "invalid conv type"); | |||||
| } | |||||
| } | |||||
| NumericTypeID convert_dtype(DTypeEnum dtype) { | |||||
| switch (dtype) { | |||||
| case DTypeEnum::Float32: | |||||
| return NumericTypeID::kF32; | |||||
| case DTypeEnum::Float16: | |||||
| return NumericTypeID::kF16; | |||||
| case DTypeEnum::Int8: | |||||
| return NumericTypeID::kS8; | |||||
| case DTypeEnum::QuantizedS32: | |||||
| return NumericTypeID::kS32; | |||||
| case DTypeEnum::QuantizedS8: | |||||
| return NumericTypeID::kS8; | |||||
| case DTypeEnum::QuantizedS4: | |||||
| return NumericTypeID::kS4; | |||||
| case DTypeEnum::Quantized4Asymm: | |||||
| return NumericTypeID::kU4; | |||||
| default: | |||||
| megdnn_assert(0, "invalid dtype"); | |||||
| } | |||||
| } | |||||
| struct LayoutPack { | |||||
| LayoutTypeID src; | |||||
| LayoutTypeID filter; | |||||
| LayoutTypeID dst; | |||||
| LayoutTypeID bias; | |||||
| }; | |||||
| LayoutPack get_layout_pack(const param::ConvBias::Format format, | |||||
| int access_type) { | |||||
| using Format = param::ConvBias::Format; | |||||
| switch (format) { | |||||
| case Format::NCHW4: | |||||
| return {LayoutTypeID::kTensorNC4HW4, LayoutTypeID::kTensorC4RSK4, | |||||
| LayoutTypeID::kTensorNC4HW4, LayoutTypeID::kTensorNC4HW4}; | |||||
| case Format::NCHW4_NCHW: | |||||
| return {LayoutTypeID::kTensorNC4HW4, LayoutTypeID::kTensorC4RSK4, | |||||
| LayoutTypeID::kTensorNCHW, LayoutTypeID::kTensorNCHW}; | |||||
| case Format::NCHW4_NHWC: | |||||
| return {LayoutTypeID::kTensorNC4HW4, LayoutTypeID::kTensorC4RSK4, | |||||
| LayoutTypeID::kTensorNHWC, LayoutTypeID::kTensorNHWC}; | |||||
| case Format::NCHW4_NCHW32: | |||||
| return {LayoutTypeID::kTensorNC4HW4, LayoutTypeID::kTensorC4RSK4, | |||||
| LayoutTypeID::kTensorNC32HW32, | |||||
| LayoutTypeID::kTensorNC32HW32}; | |||||
| case Format::NCHW32: | |||||
| return {LayoutTypeID::kTensorNC32HW32, | |||||
| LayoutTypeID::kTensorC32RSK32, | |||||
| LayoutTypeID::kTensorNC32HW32, | |||||
| LayoutTypeID::kTensorNC32HW32}; | |||||
| case Format::NCHW32_NCHW4: | |||||
| return {LayoutTypeID::kTensorNC32HW32, | |||||
| LayoutTypeID::kTensorC32RSK32, LayoutTypeID::kTensorNC4HW4, | |||||
| LayoutTypeID::kTensorNC4HW4}; | |||||
| case Format::NCHW64: | |||||
| return {LayoutTypeID::kTensorNC64HW64, | |||||
| LayoutTypeID::kTensorC64RSK64, | |||||
| LayoutTypeID::kTensorNC64HW64, | |||||
| LayoutTypeID::kTensorNC64HW64}; | |||||
| case Format::NHWC: | |||||
| switch (access_type) { | |||||
| case 8: | |||||
| return {LayoutTypeID::kTensorNHWC, | |||||
| LayoutTypeID::kTensorNC8HW8, | |||||
| LayoutTypeID::kTensorNHWC, | |||||
| LayoutTypeID::kTensorNHWC}; | |||||
| case 16: | |||||
| return {LayoutTypeID::kTensorNHWC, | |||||
| LayoutTypeID::kTensorNC16HW16, | |||||
| LayoutTypeID::kTensorNHWC, | |||||
| LayoutTypeID::kTensorNHWC}; | |||||
| case 32: | |||||
| return {LayoutTypeID::kTensorNHWC, | |||||
| LayoutTypeID::kTensorNC32HW32, | |||||
| LayoutTypeID::kTensorNHWC, | |||||
| LayoutTypeID::kTensorNHWC}; | |||||
| default: | |||||
| megdnn_assert(0, "invalid access_type"); | |||||
| } | |||||
| default: | |||||
| megdnn_assert(0, "invalid format"); | |||||
| } | |||||
| } | |||||
| EpilogueType get_epilogue_type(const param::ConvBias::NonlineMode mode, | |||||
| bool clamp) { | |||||
| using NonlineMode = param::ConvBias::NonlineMode; | |||||
| if (clamp) { | |||||
| if (mode == NonlineMode::IDENTITY) { | |||||
| return EpilogueType::kBiasAddLinearCombinationClamp; | |||||
| } else if (mode == NonlineMode::RELU) { | |||||
| return EpilogueType::kBiasAddLinearCombinationReluClamp; | |||||
| } else if (mode == NonlineMode::H_SWISH) { | |||||
| return EpilogueType::kBiasAddLinearCombinationHSwishClamp; | |||||
| } | |||||
| } else { | |||||
| if (mode == NonlineMode::IDENTITY) { | |||||
| return EpilogueType::kBiasAddLinearCombination; | |||||
| } else if (mode == NonlineMode::RELU) { | |||||
| return EpilogueType::kBiasAddLinearCombinationRelu; | |||||
| } else if (mode == NonlineMode::H_SWISH) { | |||||
| return EpilogueType::kBiasAddLinearCombinationHSwish; | |||||
| } | |||||
| } | |||||
| megdnn_assert(0, "invalid nonlinear mode"); | |||||
| } | |||||
| } // namespace | |||||
| const Operation* | |||||
| ConvBiasForwardImpl::AlgoCutlassConvolutionBase::get_cutlass_conv_op( | |||||
| const SizeArgs& args, ConvOperator conv_op, ConvType conv_type, | |||||
| bool load_from_const, bool without_shared_load) const { | |||||
| using Format = param::ConvBias::Format; | |||||
| auto&& param = args.opr->param(); | |||||
| auto layouts = get_layout_pack(param.format, m_algo_param.access_size); | |||||
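| // NCHW4_NCHW is the only format handled here whose destination tensor | |||||
| // is float, so it is the only one that skips the clamping | |||||
| // (saturating) epilogue. | |||||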
| auto epilogue_type = get_epilogue_type(param.nonlineMode, | |||||
| param.format != Format::NCHW4_NCHW); | |||||
| ConvolutionKey key{convert_conv_op(conv_op), | |||||
| convert_dtype(args.src_layout->dtype.enumv()), | |||||
| layouts.src, | |||||
| convert_dtype(args.filter_layout->dtype.enumv()), | |||||
| layouts.filter, | |||||
| convert_dtype(args.dst_layout->dtype.enumv()), | |||||
| layouts.dst, | |||||
| convert_dtype(args.bias_layout->dtype.enumv()), | |||||
| layouts.bias, | |||||
| convert_conv_type(conv_type), | |||||
| m_algo_param.threadblock_m, | |||||
| m_algo_param.threadblock_n, | |||||
| m_algo_param.threadblock_k, | |||||
| m_algo_param.warp_m, | |||||
| m_algo_param.warp_n, | |||||
| m_algo_param.warp_k, | |||||
| m_algo_param.instruction_m, | |||||
| m_algo_param.instruction_n, | |||||
| m_algo_param.instruction_k, | |||||
| epilogue_type, | |||||
| m_algo_param.stage, | |||||
| load_from_const, | |||||
| without_shared_load}; | |||||
| return Singleton::get().operation_table.find_op(key); | |||||
| } | |||||
| void ConvBiasForwardImpl::AlgoCutlassConvolutionBase::execute_cutlass_conv_op( | |||||
| const Operation* op, const void* src, const void* filter, | |||||
| const void* bias, const void* z, void* dst, void* workspace, size_t n, | |||||
| size_t hi, size_t wi, size_t ci, size_t co, size_t fh, size_t fw, | |||||
| size_t ho, size_t wo, size_t ph, size_t pw, size_t sh, size_t sw, | |||||
| size_t dh, size_t dw, const void* alpha, const void* beta, | |||||
| const void* gamma, const void* delta, const void* theta, | |||||
| const void* threshold, const void* dst_scale, cudaStream_t stream, | |||||
| const void* extra_param) const { | |||||
| // gcc prints warnings when size_t values are implicitly narrowed to int | |||||
| cutlass::conv::Conv2dProblemSize problem_size{ | |||||
| int(n), int(hi), int(wi), int(ci), | |||||
| int(co), int(fh), int(fw), int(ho), | |||||
| int(wo), int(ph), int(pw), int(sh), | |||||
| int(sw), int(dh), int(dw), cutlass::conv::Mode::kCrossCorrelation}; | |||||
| ConvolutionArguments conv_args{ | |||||
| problem_size, src, filter, bias, z, | |||||
| dst, alpha, beta, gamma, delta, | |||||
| theta, threshold, dst_scale, extra_param}; | |||||
| cutlass_check(op->run(&conv_args, workspace, stream)); | |||||
| } | |||||
| } // namespace cuda | |||||
| } // namespace megdnn | |||||
| @@ -1,129 +0,0 @@ | |||||
| /** | |||||
| * \file dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cuh | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include "cutlass/gemm/gemm.h" | |||||
| #include "src/cuda/convolution_helper/parameter.cuh" | |||||
| #include "src/cuda/utils.cuh" | |||||
| namespace megdnn { | |||||
| namespace cuda { | |||||
| namespace cutlass_wrapper { | |||||
| using GemmCoord = cutlass::gemm::GemmCoord; | |||||
| template <typename Convolution> | |||||
| void cutlass_convolution_wrapper( | |||||
| const typename Convolution::ElementSrc* d_src, | |||||
| const typename Convolution::ElementFilter* d_filter, | |||||
| const typename Convolution::ElementBias* d_bias, | |||||
| const typename Convolution::ElementDst* d_z, | |||||
| typename Convolution::ElementDst* d_dst, int* workspace, | |||||
| typename Convolution::ConvolutionParameter const& conv_param, | |||||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||||
| cudaStream_t stream, typename Convolution::ExtraParam extra_param = {}); | |||||
| template <bool NeedLoadFromConstMem> | |||||
| void do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32( | |||||
| const int8_t* d_src, const int8_t* d_filter, const int32_t* d_bias, | |||||
| const int8_t* d_z, int8_t* d_dst, int* workspace, | |||||
| const convolution::ConvParam& param, uint32_t nonlinear_mode, | |||||
| float alpha, float beta, float gamma, float scale, | |||||
| const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | |||||
| int stages, cudaStream_t stream); | |||||
| template <bool NeedLoadFromConstMem> | |||||
| void do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4( | |||||
| const int8_t* d_src, const int8_t* d_filter, const int32_t* d_bias, | |||||
| const int8_t* d_z, int8_t* d_dst, int* workspace, | |||||
| const convolution::ConvParam& param, uint32_t nonlinear_mode, | |||||
| float alpha, float beta, float gamma, float scale, | |||||
| const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | |||||
| int stages, cudaStream_t stream); | |||||
| template <bool NeedLoadFromConstMem> | |||||
| void do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4( | |||||
| const int8_t* d_src, const int8_t* d_filter, const int32_t* d_bias, | |||||
| const int8_t* d_z, int8_t* d_dst, int* workspace, | |||||
| const convolution::ConvParam& param, uint32_t nonlinear_mode, | |||||
| float alpha, float beta, float gamma, float scale, | |||||
| const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | |||||
| int stages, cudaStream_t stream); | |||||
| template <bool NeedLoadFromConstMem> | |||||
| void do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nchw( | |||||
| const int8_t* d_src, const int8_t* d_filter, const float* d_bias, | |||||
| const float* d_z, float* d_dst, int* workspace, | |||||
| const convolution::ConvParam& param, uint32_t nonlinear_mode, | |||||
| float alpha, float beta, float gamma, float scale, | |||||
| const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | |||||
| int stages, cudaStream_t stream); | |||||
| template <bool NeedLoadFromConstMem> | |||||
| void do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32( | |||||
| const int8_t* d_src, const int8_t* d_filter, const int32_t* d_bias, | |||||
| const int8_t* d_z, int8_t* d_dst, int* workspace, | |||||
| const convolution::ConvParam& param, uint32_t nonlinear_mode, | |||||
| float alpha, float beta, float gamma, float scale, | |||||
| const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | |||||
| int stages, cudaStream_t stream); | |||||
| template <bool NeedLoadFromConstMem> | |||||
| void do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64( | |||||
| const int8_t* d_src, const int8_t* d_filter, const int32_t* d_bias, | |||||
| const int8_t* d_z, int8_t* d_dst, int* workspace, | |||||
| const convolution::ConvParam& param, uint32_t nonlinear_mode, | |||||
| float alpha, float beta, float gamma, float scale, | |||||
| const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | |||||
| int stages, cudaStream_t stream); | |||||
| template <bool NeedLoadFromConstMem> | |||||
| void do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64( | |||||
| const uint8_t* d_src, const int8_t* d_filter, const int32_t* d_bias, | |||||
| const uint8_t* d_z, uint8_t* d_dst, int* workspace, | |||||
| const convolution::ConvParam& param, uint32_t nonlinear_mode, | |||||
| float alpha, float beta, float gamma, float delta, float theta, | |||||
| float scale, uint8_t src_zero_point, const GemmCoord& threadblock_shape, | |||||
| const GemmCoord& warp_shape, int stages, cudaStream_t stream); | |||||
| template <bool signedness> | |||||
| void do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nhwc( | |||||
| const int8_t* d_src, const int8_t* d_filter, const int32_t* d_bias, | |||||
| const int8_t* d_z, int8_t* d_dst, int* workspace, | |||||
| const convolution::ConvParam& param, uint32_t nonlinear_mode, | |||||
| float alpha, float beta, float gamma, float delta, float theta, | |||||
| float scale, const GemmCoord& threadblock_shape, | |||||
| const GemmCoord& warp_shape, int stages, cudaStream_t stream); | |||||
| template <bool NeedLoadFromConstMem> | |||||
| void do_conv_bias_int4_int4_implicit_gemm_imma_nhwc( | |||||
| const int8_t* d_src, const int8_t* d_filter, const int32_t* d_bias, | |||||
| const int8_t* d_z, int8_t* d_dst, int* workspace, | |||||
| const convolution::ConvParam& param, uint32_t nonlinear_mode, | |||||
| float alpha, float beta, float gamma, float scale, | |||||
| const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | |||||
| const int32_t access_size, int stages, cudaStream_t stream); | |||||
| template <bool NeedLoadFromConstMem> | |||||
| void do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc( | |||||
| const uint8_t* d_src, const int8_t* d_filter, const int32_t* d_bias, | |||||
| const uint8_t* d_z, uint8_t* d_dst, int* workspace, | |||||
| const convolution::ConvParam& param, uint32_t nonlinear_mode, | |||||
| float alpha, float beta, float gamma, float delta, float theta, | |||||
| float scale, uint8_t src_zero_point, const GemmCoord& threadblock_shape, | |||||
| const GemmCoord& warp_shape, const int32_t access_size, int stages, | |||||
| cudaStream_t stream); | |||||
| } // namespace cutlass_wrapper | |||||
| } // namespace cuda | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cuda.doxygen | |||||
| @@ -1,595 +0,0 @@ | |||||
| /** | |||||
| * \file dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cu | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #if !MEGDNN_TEGRA_X1 | |||||
| #include "cutlass/convolution/device/convolution.h" | |||||
| #endif | |||||
| #include "src/common/opr_param_defs_enumv.cuh" | |||||
| #include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | |||||
| #pragma GCC diagnostic pop | |||||
| using namespace megdnn; | |||||
| using namespace cuda; | |||||
| using namespace cutlass_wrapper; | |||||
| /* ====== cutlass kernel wrapper for int4 x int4 nchw64 layout ====== */ | |||||
| #if MEGDNN_TEGRA_X1 | |||||
| template <bool NeedLoadFromConstMem> | |||||
| void megdnn::cuda::cutlass_wrapper:: | |||||
| do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64( | |||||
| const int8_t* /* d_src */, const int8_t* /* d_filter */, | |||||
| const int32_t* /* d_bias */, const int8_t* /* d_z */, | |||||
| int8_t* /* d_dst */, int* /* workspace */, | |||||
| const convolution::ConvParam& /* param */, | |||||
| uint32_t /* nonlinear_mode */, float /* alpha */, | |||||
| float /* beta */, float /* gamma */, float /* scale */, | |||||
| const GemmCoord& /* threadblock_shape */, | |||||
| const GemmCoord& /* warp_shape */, int /* stages */, | |||||
| cudaStream_t /* stream */) {} | |||||
| #else | |||||
| template <bool NeedLoadFromConstMem> | |||||
| void megdnn::cuda::cutlass_wrapper:: | |||||
| do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64( | |||||
| const int8_t* d_src, const int8_t* d_filter, | |||||
| const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, | |||||
| int* workspace, const convolution::ConvParam& param, | |||||
| uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||||
| float scale, const GemmCoord& threadblock_shape, | |||||
| const GemmCoord& warp_shape, int stages, cudaStream_t stream) { | |||||
| #define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||||
| threadblock_k_, warp_m_, warp_n_, \ | |||||
| warp_k_, stage_) \ | |||||
| if (threadblock_shape.m() == threadblock_m_ && \ | |||||
| threadblock_shape.n() == threadblock_n_ && \ | |||||
| threadblock_shape.k() == threadblock_k_ && \ | |||||
| warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||||
| warp_shape.k() == warp_k_ && stages == stage_) { \ | |||||
| using ThreadBlockShape = \ | |||||
| cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||||
| threadblock_k_>; \ | |||||
| using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \ | |||||
| using Convolution = cutlass::conv::device::Convolution< \ | |||||
| cutlass::int4b_t, cutlass::layout::TensorNCxHWx<64>, \ | |||||
| cutlass::int4b_t, cutlass::layout::TensorCxRSKx<64>, \ | |||||
| ElementOutput, cutlass::layout::TensorNCxHWx<64>, int32_t, \ | |||||
| cutlass::layout::TensorNCxHWx<64>, int32_t, \ | |||||
| cutlass::conv::ConvType::kConvolution, \ | |||||
| cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | |||||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||||
| cutlass::conv::threadblock:: \ | |||||
| ConvolutionFpropTransThreadblockSwizzle, \ | |||||
| stage_, 32, 32, NeedLoadFromConstMem, \ | |||||
| cutlass::arch::OpMultiplyAddSaturate, \ | |||||
| cutlass::conv::ImplicitGemmMode::GEMM_TN, true>; \ | |||||
| typename Convolution::ConvolutionParameter conv_param( \ | |||||
| param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||||
| param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||||
| param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||||
| return cutlass_convolution_wrapper<Convolution>( \ | |||||
| reinterpret_cast<const cutlass::int4b_t*>(d_src), \ | |||||
| reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \ | |||||
| reinterpret_cast<const cutlass::int4b_t*>(d_z), \ | |||||
| reinterpret_cast<cutlass::int4b_t*>(d_dst), workspace, \ | |||||
| conv_param, epilogue, stream); \ | |||||
| } | |||||
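| // Each DISPATCH_KERNEL_WITH_TILE_SHAPE expansion compares the runtime | |||||
| // tile request against one compile-time instantiation and returns on | |||||
| // a match; control reaches the trailing megdnn_assert only when no | |||||
| // candidate matched. Simplified view of one expansion (sketch, | |||||
| // template arguments elided): | |||||
| // | |||||
| //   if (threadblock_shape.m() == 128 && threadblock_shape.n() == 128 && | |||||
| //       threadblock_shape.k() == 128 && warp_shape.m() == 64 && | |||||
| //       warp_shape.n() == 64 && warp_shape.k() == 128 && stages == 2) { | |||||
| //       using Convolution = cutlass::conv::device::Convolution<...>; | |||||
| //       return cutlass_convolution_wrapper<Convolution>(...); | |||||
| //   } | |||||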
| #define DISPATCH_KERNEL \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 128, 64, 64, 128, 2); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 128, 64, 64, 128, 2); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 128, 64, 64, 128, 2); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1); \ | |||||
| megdnn_assert(false, \ | |||||
| "unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||||
| "(%dx%dx%d)", \ | |||||
| threadblock_shape.m(), threadblock_shape.n(), \ | |||||
| threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||||
| warp_shape.k()); | |||||
| using ElementOutput = cutlass::int4b_t; | |||||
| using ElementAccumulator = int32_t; | |||||
| using ElementBias = int32_t; | |||||
| using ElementCompute = float; | |||||
| using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; | |||||
| switch (nonlinear_mode) { | |||||
| case NonlineMode::IDENTITY: { | |||||
| using EpilogueOp = | |||||
| cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||||
| ElementOutput, 16, ElementAccumulator, ElementBias, | |||||
| ElementCompute>; | |||||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma}; | |||||
| DISPATCH_KERNEL; | |||||
| } | |||||
| case NonlineMode::RELU: { | |||||
| using EpilogueOp = cutlass::epilogue::thread:: | |||||
| BiasAddLinearCombinationReluClamp< | |||||
| ElementOutput, 16, ElementAccumulator, ElementBias, | |||||
| ElementCompute>; | |||||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0}; | |||||
| DISPATCH_KERNEL; | |||||
| } | |||||
| case NonlineMode::H_SWISH: { | |||||
| using EpilogueOp = cutlass::epilogue::thread:: | |||||
| BiasAddLinearCombinationHSwishClamp< | |||||
| ElementOutput, 16, ElementAccumulator, ElementBias, | |||||
| ElementCompute>; | |||||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale}; | |||||
| DISPATCH_KERNEL; | |||||
| } | |||||
| default: | |||||
| megdnn_assert(false, | |||||
| "unsupported nonlinear mode for conv bias operator"); | |||||
| } | |||||
| #undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||||
| #undef DISPATCH_KERNEL | |||||
| } | |||||
| #endif | |||||
| #define INST(need_load_from_const_mem) \ | |||||
| template void megdnn::cuda::cutlass_wrapper:: \ | |||||
| do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64< \ | |||||
| need_load_from_const_mem>( \ | |||||
| const int8_t* d_src, const int8_t* d_filter, \ | |||||
| const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \ | |||||
| int* workspace, const convolution::ConvParam& param, \ | |||||
| uint32_t nonlinear_mode, float alpha, float beta, \ | |||||
| float gamma, float scale, \ | |||||
| const GemmCoord& threadblock_shape, \ | |||||
| const GemmCoord& warp_shape, int stages, \ | |||||
| cudaStream_t stream); | |||||
| INST(true); | |||||
| #undef INST | |||||
| /* ====== cutlass kernel wrapper for uint4 x int4 nchw64 layout ====== */ | |||||
| #if MEGDNN_TEGRA_X1 | |||||
| template <bool NeedLoadFromConstMem> | |||||
| void megdnn::cuda::cutlass_wrapper:: | |||||
| do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64( | |||||
| const uint8_t* /* d_src */, const int8_t* /* d_filter */, | |||||
| const int32_t* /* d_bias */, const uint8_t* /* d_z */, | |||||
| uint8_t* /* d_dst */, int* /* workspace */, | |||||
| const convolution::ConvParam& /* param */, | |||||
| uint32_t /* nonlinear_mode */, float /* alpha */, | |||||
| float /* beta */, float /* gamma */, float /* delta */, | |||||
| float /* theta */, float /* scale */, | |||||
| uint8_t /* src_zero_point */, | |||||
| const GemmCoord& /* threadblock_shape */, | |||||
| const GemmCoord& /* warp_shape */, int /* stages */, | |||||
| cudaStream_t /* stream */) {} | |||||
| #else | |||||
| template <bool NeedLoadFromConstMem> | |||||
| void megdnn::cuda::cutlass_wrapper:: | |||||
| do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64( | |||||
| const uint8_t* d_src, const int8_t* d_filter, | |||||
| const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, | |||||
| int* workspace, const convolution::ConvParam& param, | |||||
| uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||||
| float delta, float theta, float /* scale */, | |||||
| uint8_t src_zero_point, const GemmCoord& threadblock_shape, | |||||
| const GemmCoord& warp_shape, int stages, cudaStream_t stream) { | |||||
| #define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||||
| threadblock_k_, warp_m_, warp_n_, \ | |||||
| warp_k_, stage_) \ | |||||
| if (threadblock_shape.m() == threadblock_m_ && \ | |||||
| threadblock_shape.n() == threadblock_n_ && \ | |||||
| threadblock_shape.k() == threadblock_k_ && \ | |||||
| warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||||
| warp_shape.k() == warp_k_ && stages == stage_) { \ | |||||
| using ThreadBlockShape = \ | |||||
| cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||||
| threadblock_k_>; \ | |||||
| using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \ | |||||
| using Convolution = cutlass::conv::device::Convolution< \ | |||||
| cutlass::uint4b_t, cutlass::layout::TensorNCxHWx<64>, \ | |||||
| cutlass::int4b_t, cutlass::layout::TensorCxRSKx<64>, \ | |||||
| ElementOutput, cutlass::layout::TensorNCxHWx<64>, int32_t, \ | |||||
| cutlass::layout::TensorNCxHWx<64>, int32_t, \ | |||||
| cutlass::conv::ConvType::kConvolution, \ | |||||
| cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | |||||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||||
| cutlass::conv::threadblock:: \ | |||||
| ConvolutionFpropTransThreadblockSwizzle, \ | |||||
| stage_, 32, 32, NeedLoadFromConstMem, \ | |||||
| cutlass::arch::OpMultiplyAddSaturate, \ | |||||
| cutlass::conv::ImplicitGemmMode::GEMM_TN, true>; \ | |||||
| typename Convolution::ConvolutionParameter conv_param( \ | |||||
| param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||||
| param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||||
| param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||||
| return cutlass_convolution_wrapper<Convolution>( \ | |||||
| reinterpret_cast<const cutlass::uint4b_t*>(d_src), \ | |||||
| reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \ | |||||
| reinterpret_cast<const cutlass::uint4b_t*>(d_z), \ | |||||
| reinterpret_cast<cutlass::uint4b_t*>(d_dst), workspace, \ | |||||
| conv_param, epilogue, stream, {src_zero_point}); \ | |||||
| } | |||||
| #define DISPATCH_KERNEL \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 128, 64, 64, 128, 2); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 128, 64, 64, 128, 2); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 128, 64, 64, 128, 2); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1); \ | |||||
| megdnn_assert(false, \ | |||||
| "unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||||
| "(%dx%dx%d)", \ | |||||
| threadblock_shape.m(), threadblock_shape.n(), \ | |||||
| threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||||
| warp_shape.k()); | |||||
| using ElementOutput = cutlass::uint4b_t; | |||||
| using ElementAccumulator = int32_t; | |||||
| using ElementBias = int32_t; | |||||
| using ElementCompute = float; | |||||
| using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; | |||||
| switch (nonlinear_mode) { | |||||
| case NonlineMode::IDENTITY: { | |||||
| using EpilogueOp = | |||||
| cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||||
| ElementOutput, 16, ElementAccumulator, ElementBias, | |||||
| ElementCompute>; | |||||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, | |||||
| delta + theta}; | |||||
| DISPATCH_KERNEL; | |||||
| } | |||||
| case NonlineMode::RELU: { | |||||
| using EpilogueOp = cutlass::epilogue::thread:: | |||||
| BiasAddLinearCombinationReluClamp< | |||||
| ElementOutput, 16, ElementAccumulator, ElementBias, | |||||
| ElementCompute>; | |||||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, | |||||
| 0, delta, theta}; | |||||
| DISPATCH_KERNEL; | |||||
| } | |||||
| default: | |||||
| megdnn_assert(false, | |||||
| "unsupported nonlinear mode for conv bias operator"); | |||||
| } | |||||
| #undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||||
| #undef DISPATCH_KERNEL | |||||
| } | |||||
| #endif | |||||
| #define INST(need_load_from_const_mem) \ | |||||
| template void megdnn::cuda::cutlass_wrapper:: \ | |||||
| do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64< \ | |||||
| need_load_from_const_mem>( \ | |||||
| const uint8_t* d_src, const int8_t* d_filter, \ | |||||
| const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, \ | |||||
| int* workspace, const convolution::ConvParam& param, \ | |||||
| uint32_t nonlinear_mode, float alpha, float beta, \ | |||||
| float gamma, float delta, float theta, float scale, \ | |||||
| uint8_t src_zero_point, \ | |||||
| const GemmCoord& threadblock_shape, \ | |||||
| const GemmCoord& warp_shape, int stages, \ | |||||
| cudaStream_t stream); | |||||
| INST(true); | |||||
| #undef INST | |||||
| /* ====== cutlass kernel wrapper for int4 x int4 nhwc layout ====== */ | |||||
| #if MEGDNN_TEGRA_X1 | |||||
| template <bool NeedLoadFromConstMem> | |||||
| void megdnn::cuda::cutlass_wrapper:: | |||||
| do_conv_bias_int4_int4_implicit_gemm_imma_nhwc( | |||||
| const int8_t* /* d_src */, const int8_t* /* d_filter */, | |||||
| const int32_t* /* d_bias */, const int8_t* /* d_z */, | |||||
| int8_t* /* d_dst */, int* /* workspace */, | |||||
| const convolution::ConvParam& /* param */, | |||||
| uint32_t /* nonlinear_mode */, float /* alpha */, | |||||
| float /* beta */, float /* gamma */, float /* scale */, | |||||
| const GemmCoord& /* threadblock_shape */, | |||||
| const GemmCoord& /* warp_shape */, | |||||
| const int32_t /* access_size */, int /* stages */, | |||||
| cudaStream_t /* stream */) {} | |||||
| #else | |||||
| template <bool NeedLoadFromConstMem> | |||||
| void megdnn::cuda::cutlass_wrapper:: | |||||
| do_conv_bias_int4_int4_implicit_gemm_imma_nhwc( | |||||
| const int8_t* d_src, const int8_t* d_filter, | |||||
| const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, | |||||
| int* workspace, const convolution::ConvParam& param, | |||||
| uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||||
| float scale, const GemmCoord& threadblock_shape, | |||||
| const GemmCoord& warp_shape, const int32_t access_size, | |||||
| int stages, cudaStream_t stream) { | |||||
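| // The direct-store epilogue is usable only when every threadblock | |||||
| // tile covers whole output channels: co must be a multiple of | |||||
| // threadblock_n, and threadblock_n must be 32 or 64. It then skips | |||||
| // the shared-memory staging pass and widens each store to | |||||
| // threadblock_n / 4 int4 elements (8 or 16), matching the dispatch | |||||
| // table below; otherwise stores stay at 8 elements per access. | |||||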
| bool without_shared_load = | |||||
| ((param.co % threadblock_shape.n() == 0) && | |||||
| (threadblock_shape.n() == 32 || threadblock_shape.n() == 64)); | |||||
| int out_elements_per_access = | |||||
| without_shared_load ? threadblock_shape.n() / 4 : 8; | |||||
| #define RUN_CUTLASS_WRAPPER(stage_, access_size_, without_shared_load_) \ | |||||
| using Convolution = cutlass::conv::device::Convolution< \ | |||||
| cutlass::int4b_t, cutlass::layout::TensorNHWC, cutlass::int4b_t, \ | |||||
| cutlass::layout::TensorNCxHWx<access_size_>, ElementOutput, \ | |||||
| cutlass::layout::TensorNHWC, int32_t, cutlass::layout::TensorNHWC, \ | |||||
| int32_t, cutlass::conv::ConvType::kConvolution, \ | |||||
| cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | |||||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||||
| cutlass::conv::threadblock:: \ | |||||
| ConvolutionFpropTransThreadblockSwizzle, \ | |||||
| stage_, access_size_, access_size_, NeedLoadFromConstMem, \ | |||||
| cutlass::arch::OpMultiplyAddSaturate, \ | |||||
| cutlass::conv::ImplicitGemmMode::GEMM_TN, without_shared_load_>; \ | |||||
| typename Convolution::ConvolutionParameter conv_param( \ | |||||
| param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||||
| param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||||
| param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||||
| return cutlass_convolution_wrapper<Convolution>( \ | |||||
| reinterpret_cast<const cutlass::int4b_t*>(d_src), \ | |||||
| reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \ | |||||
| reinterpret_cast<const cutlass::int4b_t*>(d_z), \ | |||||
| reinterpret_cast<cutlass::int4b_t*>(d_dst), workspace, conv_param, \ | |||||
| epilogue, stream); | |||||
| #define DISPATCH_KERNEL_WITH_TILE_SHAPE( \ | |||||
| threadblock_m_, threadblock_n_, threadblock_k_, warp_m_, warp_n_, \ | |||||
| warp_k_, stage_, access_size_, out_elements_per_access_, \ | |||||
| without_shared_load_) \ | |||||
| if (threadblock_shape.m() == threadblock_m_ && \ | |||||
| threadblock_shape.n() == threadblock_n_ && \ | |||||
| threadblock_shape.k() == threadblock_k_ && \ | |||||
| warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||||
| warp_shape.k() == warp_k_ && stages == stage_ && \ | |||||
| access_size == access_size_ && \ | |||||
| out_elements_per_access == out_elements_per_access_ && \ | |||||
| without_shared_load == without_shared_load_) { \ | |||||
| using ThreadBlockShape = \ | |||||
| cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||||
| threadblock_k_>; \ | |||||
| using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \ | |||||
| using ElementOutput = cutlass::int4b_t; \ | |||||
| using ElementAccumulator = int32_t; \ | |||||
| using ElementBias = int32_t; \ | |||||
| using ElementCompute = float; \ | |||||
| using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; \ | |||||
| switch (nonlinear_mode) { \ | |||||
| case NonlineMode::IDENTITY: { \ | |||||
| using EpilogueOp = cutlass::epilogue::thread:: \ | |||||
| BiasAddLinearCombinationClamp< \ | |||||
| ElementOutput, out_elements_per_access_, \ | |||||
| ElementAccumulator, ElementBias, \ | |||||
| ElementCompute>; \ | |||||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma}; \ | |||||
| RUN_CUTLASS_WRAPPER(stage_, access_size_, \ | |||||
| without_shared_load_); \ | |||||
| } \ | |||||
| case NonlineMode::RELU: { \ | |||||
| using EpilogueOp = cutlass::epilogue::thread:: \ | |||||
| BiasAddLinearCombinationReluClamp< \ | |||||
| ElementOutput, out_elements_per_access_, \ | |||||
| ElementAccumulator, ElementBias, \ | |||||
| ElementCompute>; \ | |||||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0}; \ | |||||
| RUN_CUTLASS_WRAPPER(stage_, access_size_, \ | |||||
| without_shared_load_); \ | |||||
| } \ | |||||
| case NonlineMode::H_SWISH: { \ | |||||
| using EpilogueOp = cutlass::epilogue::thread:: \ | |||||
| BiasAddLinearCombinationHSwishClamp< \ | |||||
| ElementOutput, out_elements_per_access_, \ | |||||
| ElementAccumulator, ElementBias, \ | |||||
| ElementCompute>; \ | |||||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, \ | |||||
| scale}; \ | |||||
| RUN_CUTLASS_WRAPPER(stage_, access_size_, \ | |||||
| without_shared_load_); \ | |||||
| } \ | |||||
| default: \ | |||||
| megdnn_assert( \ | |||||
| false, \ | |||||
| "unsupported nonlinear mode for conv bias operator"); \ | |||||
| } \ | |||||
| } | |||||
| #define DISPATCH_KERNEL \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 32, 8, false); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 16, 8, false); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 8, 8, false); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 32, 8, false); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 16, 8, false); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 8, 8, false); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 32, 8, true); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 16, 8, true); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 8, 8, true); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 32, 16, true); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 16, 16, true); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 8, 16, true); \ | |||||
| megdnn_assert(false, \ | |||||
| "unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||||
| "(%dx%dx%d) and access_size (%d)", \ | |||||
| threadblock_shape.m(), threadblock_shape.n(), \ | |||||
| threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||||
| warp_shape.k(), access_size); | |||||
| DISPATCH_KERNEL; | |||||
| #undef RUN_CUTLASS_WRAPPER | |||||
| #undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||||
| #undef DISPATCH_KERNEL | |||||
| } | |||||
| #endif | |||||
| #define INST(need_load_from_const_mem) \ | |||||
| template void megdnn::cuda::cutlass_wrapper:: \ | |||||
| do_conv_bias_int4_int4_implicit_gemm_imma_nhwc< \ | |||||
| need_load_from_const_mem>( \ | |||||
| const int8_t* d_src, const int8_t* d_filter, \ | |||||
| const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \ | |||||
| int* workspace, const convolution::ConvParam& param, \ | |||||
| uint32_t nonlinear_mode, float alpha, float beta, \ | |||||
| float gamma, float scale, \ | |||||
| const GemmCoord& threadblock_shape, \ | |||||
| const GemmCoord& warp_shape, const int32_t access_size, \ | |||||
| int stages, cudaStream_t stream); | |||||
| INST(true); | |||||
| INST(false); | |||||
| #undef INST | |||||
| /* ====== cutlass kernel wrapper for uint4 x int4 nhwc layout ====== */ | |||||
| #if MEGDNN_TEGRA_X1 | |||||
| template <bool NeedLoadFromConstMem> | |||||
| void megdnn::cuda::cutlass_wrapper:: | |||||
| do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc( | |||||
| const uint8_t* /* d_src */, const int8_t* /* d_filter */, | |||||
| const int32_t* /* d_bias */, const uint8_t* /* d_z */, | |||||
| uint8_t* /* d_dst */, int* /* workspace */, | |||||
| const convolution::ConvParam& /* param */, | |||||
| uint32_t /* nonlinear_mode */, float /* alpha */, | |||||
| float /* beta */, float /* gamma */, float /* delta */, | |||||
| float /* theta */, float /* scale */, | |||||
| uint8_t /* src_zero_point */, | |||||
| const GemmCoord& /* threadblock_shape */, | |||||
| const GemmCoord& /* warp_shape */, | |||||
| const int32_t /* access_size */, int /* stages */, | |||||
| cudaStream_t /* stream */) {} | |||||
| #else | |||||
| template <bool NeedLoadFromConstMem> | |||||
| void megdnn::cuda::cutlass_wrapper:: | |||||
| do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc( | |||||
| const uint8_t* d_src, const int8_t* d_filter, | |||||
| const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, | |||||
| int* workspace, const convolution::ConvParam& param, | |||||
| uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||||
| float delta, float theta, float /* scale */, | |||||
| uint8_t src_zero_point, const GemmCoord& threadblock_shape, | |||||
| const GemmCoord& warp_shape, const int32_t access_size, | |||||
| int stages, cudaStream_t stream) { | |||||
| bool without_shared_load = | |||||
| ((param.co % threadblock_shape.n() == 0) && | |||||
| (threadblock_shape.n() == 32 || threadblock_shape.n() == 64)); | |||||
| int out_elements_per_access = | |||||
| without_shared_load ? threadblock_shape.n() / 4 : 8; | |||||
| #define RUN_CUTLASS_WRAPPER(stage_, access_size_, without_shared_load_) \ | |||||
| using Convolution = cutlass::conv::device::Convolution< \ | |||||
| cutlass::uint4b_t, cutlass::layout::TensorNHWC, cutlass::int4b_t, \ | |||||
| cutlass::layout::TensorNCxHWx<access_size_>, ElementOutput, \ | |||||
| cutlass::layout::TensorNHWC, int32_t, cutlass::layout::TensorNHWC, \ | |||||
| int32_t, cutlass::conv::ConvType::kConvolution, \ | |||||
| cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | |||||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||||
| cutlass::conv::threadblock:: \ | |||||
| ConvolutionFpropTransThreadblockSwizzle, \ | |||||
| stage_, access_size_, access_size_, NeedLoadFromConstMem, \ | |||||
| cutlass::arch::OpMultiplyAddSaturate, \ | |||||
| cutlass::conv::ImplicitGemmMode::GEMM_TN, without_shared_load_>; \ | |||||
| typename Convolution::ConvolutionParameter conv_param( \ | |||||
| param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||||
| param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||||
| param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||||
| return cutlass_convolution_wrapper<Convolution>( \ | |||||
| reinterpret_cast<const cutlass::uint4b_t*>(d_src), \ | |||||
| reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \ | |||||
| reinterpret_cast<const cutlass::uint4b_t*>(d_z), \ | |||||
| reinterpret_cast<cutlass::uint4b_t*>(d_dst), workspace, \ | |||||
| conv_param, epilogue, stream, {src_zero_point}); | |||||
| #define DISPATCH_KERNEL_WITH_TILE_SHAPE( \ | |||||
| threadblock_m_, threadblock_n_, threadblock_k_, warp_m_, warp_n_, \ | |||||
| warp_k_, stage_, access_size_, out_elements_per_access_, \ | |||||
| without_shared_load_) \ | |||||
| if (threadblock_shape.m() == threadblock_m_ && \ | |||||
| threadblock_shape.n() == threadblock_n_ && \ | |||||
| threadblock_shape.k() == threadblock_k_ && \ | |||||
| warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||||
| warp_shape.k() == warp_k_ && stages == stage_ && \ | |||||
| access_size == access_size_ && \ | |||||
| out_elements_per_access == out_elements_per_access_ && \ | |||||
| without_shared_load == without_shared_load_) { \ | |||||
| using ThreadBlockShape = \ | |||||
| cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||||
| threadblock_k_>; \ | |||||
| using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \ | |||||
| using ElementOutput = cutlass::uint4b_t; \ | |||||
| using ElementAccumulator = int32_t; \ | |||||
| using ElementBias = int32_t; \ | |||||
| using ElementCompute = float; \ | |||||
| using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; \ | |||||
| switch (nonlinear_mode) { \ | |||||
| case NonlineMode::IDENTITY: { \ | |||||
| using EpilogueOp = cutlass::epilogue::thread:: \ | |||||
| BiasAddLinearCombinationClamp< \ | |||||
| ElementOutput, out_elements_per_access_, \ | |||||
| ElementAccumulator, ElementBias, \ | |||||
| ElementCompute>; \ | |||||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, \ | |||||
| delta + theta}; \ | |||||
| RUN_CUTLASS_WRAPPER(stage_, access_size_, \ | |||||
| without_shared_load_); \ | |||||
| } \ | |||||
| case NonlineMode::RELU: { \ | |||||
| using EpilogueOp = cutlass::epilogue::thread:: \ | |||||
| BiasAddLinearCombinationReluClamp< \ | |||||
| ElementOutput, out_elements_per_access_, \ | |||||
| ElementAccumulator, ElementBias, \ | |||||
| ElementCompute>; \ | |||||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, \ | |||||
| 0, delta, theta}; \ | |||||
| RUN_CUTLASS_WRAPPER(stage_, access_size_, \ | |||||
| without_shared_load_); \ | |||||
| } \ | |||||
| default: \ | |||||
| megdnn_assert( \ | |||||
| false, \ | |||||
| "unsupported nonlinear mode for conv bias operator"); \ | |||||
| } \ | |||||
| } | |||||
| #define DISPATCH_KERNEL \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 32, 8, false); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 16, 8, false); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 8, 8, false); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 32, 8, false); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 16, 8, false); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 8, 8, false); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 32, 8, true); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 16, 8, true); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 8, 8, true); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 32, 16, true); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 16, 16, true); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 8, 16, true); \ | |||||
| megdnn_assert(false, \ | |||||
| "unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||||
| "(%dx%dx%d) and access_size (%d)", \ | |||||
| threadblock_shape.m(), threadblock_shape.n(), \ | |||||
| threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||||
| warp_shape.k(), access_size); | |||||
| DISPATCH_KERNEL; | |||||
| #undef RUN_CUTLASS_WRAPPER | |||||
| #undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||||
| #undef DISPATCH_KERNEL | |||||
| } | |||||
| #endif | |||||
| #define INST(need_load_from_const_mem) \ | |||||
| template void megdnn::cuda::cutlass_wrapper:: \ | |||||
| do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc< \ | |||||
| need_load_from_const_mem>( \ | |||||
| const uint8_t* d_src, const int8_t* d_filter, \ | |||||
| const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, \ | |||||
| int* workspace, const convolution::ConvParam& param, \ | |||||
| uint32_t nonlinear_mode, float alpha, float beta, \ | |||||
| float gamma, float delta, float theta, float scale, \ | |||||
| uint8_t src_zero_point, \ | |||||
| const GemmCoord& threadblock_shape, \ | |||||
| const GemmCoord& warp_shape, const int32_t access_size, \ | |||||
| int stages, cudaStream_t stream); | |||||
| INST(true); | |||||
| INST(false); | |||||
| #undef INST | |||||
| // vim: syntax=cuda.doxygen | |||||
| @@ -1,804 +0,0 @@ | |||||
| /** | |||||
| * \file dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cu | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #if !MEGDNN_TEGRA_X1 | |||||
| #include "cutlass/convolution/device/convolution.h" | |||||
| #endif | |||||
| #include "src/common/opr_param_defs_enumv.cuh" | |||||
| #include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | |||||
| #pragma GCC diagnostic pop | |||||
| using namespace megdnn; | |||||
| using namespace cuda; | |||||
| using namespace cutlass_wrapper; | |||||
| /* ====== cutlass kernel wrapper for int8 nchw32 layout ====== */ | |||||
| #if MEGDNN_TEGRA_X1 | |||||
| template <bool NeedLoadFromConstMem> | |||||
| void megdnn::cuda::cutlass_wrapper:: | |||||
| do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32( | |||||
| const int8_t* /* d_src */, const int8_t* /* d_filter */, | |||||
| const int32_t* /* d_bias */, const int8_t* /* d_z */, | |||||
| int8_t* /* d_dst */, int* /* workspace */, | |||||
| const convolution::ConvParam& /* param */, | |||||
| uint32_t /* nonlinear_mode */, float /* alpha */, | |||||
| float /* beta */, float /* gamma */, float /* scale */, | |||||
| const GemmCoord& /* threadblock_shape */, | |||||
| const GemmCoord& /* warp_shape */, int /* stages */, | |||||
| cudaStream_t /* stream */) {} | |||||
| #else | |||||
| template <bool NeedLoadFromConstMem> | |||||
| void megdnn::cuda::cutlass_wrapper:: | |||||
| do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32( | |||||
| const int8_t* d_src, const int8_t* d_filter, | |||||
| const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, | |||||
| int* workspace, const convolution::ConvParam& param, | |||||
| uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||||
| float scale, const GemmCoord& threadblock_shape, | |||||
| const GemmCoord& warp_shape, int stages, cudaStream_t stream) { | |||||
| #define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||||
| threadblock_k_, warp_m_, warp_n_, \ | |||||
| warp_k_, stage_) \ | |||||
| if (threadblock_shape.m() == threadblock_m_ && \ | |||||
| threadblock_shape.n() == threadblock_n_ && \ | |||||
| threadblock_shape.k() == threadblock_k_ && \ | |||||
| warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||||
| warp_shape.k() == warp_k_ && stages == stage_) { \ | |||||
| using ThreadBlockShape = \ | |||||
| cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||||
| threadblock_k_>; \ | |||||
| using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; \ | |||||
| using Convolution = cutlass::conv::device::Convolution< \ | |||||
| int8_t, cutlass::layout::TensorNCxHWx<32>, int8_t, \ | |||||
| cutlass::layout::TensorCxRSKx<32>, ElementOutput, \ | |||||
| cutlass::layout::TensorNCxHWx<32>, int32_t, \ | |||||
| cutlass::layout::TensorNCxHWx<32>, int32_t, \ | |||||
| cutlass::conv::ConvType::kConvolution, \ | |||||
| cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | |||||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||||
| cutlass::conv::threadblock:: \ | |||||
| ConvolutionFpropTransThreadblockSwizzle, \ | |||||
| stage_, 16, 16, NeedLoadFromConstMem, \ | |||||
| cutlass::arch::OpMultiplyAddSaturate, \ | |||||
| cutlass::conv::ImplicitGemmMode::GEMM_TN, true>; \ | |||||
| typename Convolution::ConvolutionParameter conv_param( \ | |||||
| param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||||
| param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||||
| param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||||
| return cutlass_convolution_wrapper<Convolution>( \ | |||||
| d_src, d_filter, d_bias, d_z, d_dst, workspace, conv_param, \ | |||||
| epilogue, stream); \ | |||||
| } | |||||
| #define DISPATCH_KERNEL \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 64, 64, 64, 64, 2); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 64, 64, 64, 64, 2); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 64, 64, 64, 64, 2); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 32, 64, 2); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 64, 32, 64, 64, 2); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 1); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 1); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 32, 64, 32, 1); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 1); \ | |||||
| megdnn_assert(false, \ | |||||
| "unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||||
| "(%dx%dx%d)", \ | |||||
| threadblock_shape.m(), threadblock_shape.n(), \ | |||||
| threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||||
| warp_shape.k()); | |||||
| using ElementOutput = int8_t; | |||||
| using ElementAccumulator = int32_t; | |||||
| using ElementBias = int32_t; | |||||
| using ElementCompute = float; | |||||
| using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; | |||||
| switch (nonlinear_mode) { | |||||
| case NonlineMode::IDENTITY: { | |||||
| using EpilogueOp = | |||||
| cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||||
| ElementOutput, 8, ElementAccumulator, ElementBias, | |||||
| ElementCompute>; | |||||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma}; | |||||
| DISPATCH_KERNEL; | |||||
| } | |||||
| case NonlineMode::RELU: { | |||||
| using EpilogueOp = cutlass::epilogue::thread:: | |||||
| BiasAddLinearCombinationReluClamp< | |||||
| ElementOutput, 8, ElementAccumulator, ElementBias, | |||||
| ElementCompute>; | |||||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0}; | |||||
| DISPATCH_KERNEL; | |||||
| } | |||||
| case NonlineMode::H_SWISH: { | |||||
| using EpilogueOp = cutlass::epilogue::thread:: | |||||
| BiasAddLinearCombinationHSwishClamp< | |||||
| ElementOutput, 8, ElementAccumulator, ElementBias, | |||||
| ElementCompute>; | |||||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale}; | |||||
| DISPATCH_KERNEL; | |||||
| } | |||||
| default: | |||||
| megdnn_assert(false, | |||||
| "unsupported nonlinear mode for conv bias operator"); | |||||
| } | |||||
| #undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||||
| #undef DISPATCH_KERNEL | |||||
| } | |||||
| #endif | |||||
| #define INST(need_load_from_const_mem) \ | |||||
| template void megdnn::cuda::cutlass_wrapper:: \ | |||||
| do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32< \ | |||||
| need_load_from_const_mem>( \ | |||||
| const int8_t* d_src, const int8_t* d_filter, \ | |||||
| const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \ | |||||
| int* workspace, const convolution::ConvParam& param, \ | |||||
| uint32_t nonlinear_mode, float alpha, float beta, \ | |||||
| float gamma, float scale, \ | |||||
| const GemmCoord& threadblock_shape, \ | |||||
| const GemmCoord& warp_shape, int stages, \ | |||||
| cudaStream_t stream); | |||||
| INST(true); | |||||
| INST(false); | |||||
| #undef INST | |||||
| /* ===== cutlass kernel wrapper for int8 nchw32 layout and nchw4 output ===== */ | |||||
| #if MEGDNN_TEGRA_X1 | |||||
| template <bool NeedLoadFromConstMem> | |||||
| void megdnn::cuda::cutlass_wrapper:: | |||||
| do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4( | |||||
| const int8_t* /* d_src */, const int8_t* /* d_filter */, | |||||
| const int32_t* /* d_bias */, const int8_t* /* d_z */, | |||||
| int8_t* /* d_dst */, int* /* workspace */, | |||||
| const convolution::ConvParam& /* param */, | |||||
| uint32_t /* nonlinear_mode */, float /* alpha */, | |||||
| float /* beta */, float /* gamma */, float /* scale */, | |||||
| const GemmCoord& /* threadblock_shape */, | |||||
| const GemmCoord& /* warp_shape */, int /* stages */, | |||||
| cudaStream_t /* stream */) {} | |||||
| #else | |||||
| template <bool NeedLoadFromConstMem> | |||||
| void megdnn::cuda::cutlass_wrapper:: | |||||
| do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4( | |||||
| const int8_t* d_src, const int8_t* d_filter, | |||||
| const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, | |||||
| int* workspace, const convolution::ConvParam& param, | |||||
| uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||||
| float scale, const GemmCoord& threadblock_shape, | |||||
| const GemmCoord& warp_shape, int stages, cudaStream_t stream) { | |||||
| #define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||||
| threadblock_k_, warp_m_, warp_n_, \ | |||||
| warp_k_, stage_) \ | |||||
| if (threadblock_shape.m() == threadblock_m_ && \ | |||||
| threadblock_shape.n() == threadblock_n_ && \ | |||||
| threadblock_shape.k() == threadblock_k_ && \ | |||||
| warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||||
| warp_shape.k() == warp_k_ && stages == stage_) { \ | |||||
| using ThreadBlockShape = \ | |||||
| cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||||
| threadblock_k_>; \ | |||||
| using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; \ | |||||
| using Convolution = cutlass::conv::device::Convolution< \ | |||||
| int8_t, cutlass::layout::TensorNCxHWx<32>, int8_t, \ | |||||
| cutlass::layout::TensorCxRSKx<32>, ElementOutput, \ | |||||
| cutlass::layout::TensorNCxHWx<4>, int32_t, \ | |||||
| cutlass::layout::TensorNCxHWx<4>, int32_t, \ | |||||
| cutlass::conv::ConvType::kConvolution, \ | |||||
| cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | |||||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||||
| cutlass::conv::threadblock:: \ | |||||
| ConvolutionFpropNCxHWxThreadblockSwizzle, \ | |||||
| stage_, 16, 16, NeedLoadFromConstMem>; \ | |||||
| typename Convolution::ConvolutionParameter conv_param( \ | |||||
| param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||||
| param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||||
| param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||||
| return cutlass_convolution_wrapper<Convolution>( \ | |||||
| d_src, d_filter, d_bias, d_z, d_dst, workspace, conv_param, \ | |||||
| epilogue, stream); \ | |||||
| } | |||||
| #define DISPATCH_KERNEL \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 64, 64, 64, 64, 2); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 64, 64, 64, 64, 2); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 64, 64, 64, 64, 2); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 32, 64, 2); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 64, 32, 64, 64, 2); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 1); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 1); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 32, 64, 32, 1); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 1); \ | |||||
| megdnn_assert(false, \ | |||||
| "unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||||
| "(%dx%dx%d)", \ | |||||
| threadblock_shape.m(), threadblock_shape.n(), \ | |||||
| threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||||
| warp_shape.k()); | |||||
| using ElementOutput = int8_t; | |||||
| using ElementAccumulator = int32_t; | |||||
| using ElementBias = int32_t; | |||||
| using ElementCompute = float; | |||||
| using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; | |||||
| switch (nonlinear_mode) { | |||||
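| // Each case only binds a different epilogue functor before expanding | |||||
| // DISPATCH_KERNEL; a matched tile shape returns directly and the trailing | |||||
| // megdnn_assert aborts otherwise, so no break statements are needed. | |||||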
| case NonlineMode::IDENTITY: { | |||||
| using EpilogueOp = | |||||
| cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||||
| ElementOutput, 4, ElementAccumulator, ElementBias, | |||||
| ElementCompute>; | |||||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma}; | |||||
| DISPATCH_KERNEL; | |||||
| } | |||||
| case NonlineMode::RELU: { | |||||
| using EpilogueOp = cutlass::epilogue::thread:: | |||||
| BiasAddLinearCombinationReluClamp< | |||||
| ElementOutput, 4, ElementAccumulator, ElementBias, | |||||
| ElementCompute>; | |||||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0}; | |||||
| DISPATCH_KERNEL; | |||||
| } | |||||
| case NonlineMode::H_SWISH: { | |||||
| using EpilogueOp = cutlass::epilogue::thread:: | |||||
| BiasAddLinearCombinationHSwishClamp< | |||||
| ElementOutput, 4, ElementAccumulator, ElementBias, | |||||
| ElementCompute>; | |||||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale}; | |||||
| DISPATCH_KERNEL; | |||||
| } | |||||
| default: | |||||
| megdnn_assert(false, | |||||
| "unsupported nonlinear mode for conv bias operator"); | |||||
| } | |||||
| #undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||||
| #undef DISPATCH_KERNEL | |||||
| } | |||||
| #endif | |||||
| #define INST(need_load_from_const_mem) \ | |||||
| template void megdnn::cuda::cutlass_wrapper:: \ | |||||
| do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4< \ | |||||
| need_load_from_const_mem>( \ | |||||
| const int8_t* d_src, const int8_t* d_filter, \ | |||||
| const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \ | |||||
| int* workspace, const convolution::ConvParam& param, \ | |||||
| uint32_t nonlinear_mode, float alpha, float beta, \ | |||||
| float gamma, float scale, \ | |||||
| const GemmCoord& threadblock_shape, \ | |||||
| const GemmCoord& warp_shape, int stages, \ | |||||
| cudaStream_t stream); | |||||
| INST(true); | |||||
| INST(false); | |||||
| #undef INST | |||||
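| // A minimal caller sketch for the instantiations above (pointer names and | |||||
| // parameter values here are hypothetical; only tile shapes listed in | |||||
| // DISPATCH_KERNEL are accepted, anything else trips the megdnn_assert): | |||||
| // | |||||
| //     using namespace megdnn::cuda::cutlass_wrapper; | |||||
| //     GemmCoord threadblock{128, 128, 64}, warp{64, 64, 64}; | |||||
| //     do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4<true>( | |||||
| //             d_src, d_filter, d_bias, d_z, d_dst, /*workspace=*/nullptr, | |||||
| //             param, nonlinear_mode, alpha, beta, gamma, scale, | |||||
| //             threadblock, warp, /*stages=*/2, stream); | |||||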
| /* ====== cutlass kernel wrapper for int8 nchw4 layout ====== */ | |||||
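| /* Unlike the IMMA path above, this wrapper targets the dp4a instruction | |||||
|  * (OpClassSimt on Sm61, 1x1x4 instruction shape) and threads an extra | |||||
|  * `aligned_` value through the dispatch macro as what appears to be the | |||||
|  * kernel's memory-access alignment template argument. */ | |||||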
| #if MEGDNN_TEGRA_X1 | |||||
| template <bool NeedLoadFromConstMem> | |||||
| void megdnn::cuda::cutlass_wrapper:: | |||||
| do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4( | |||||
| const int8_t* /* d_src */, const int8_t* /* d_filter */, | |||||
| const int32_t* /* d_bias */, const int8_t* /* d_z */, | |||||
| int8_t* /* d_dst */, int* /* workspace */, | |||||
| const convolution::ConvParam& /* param */, | |||||
| uint32_t /* nonlinear_mode */, float /* alpha */, | |||||
| float /* beta */, float /* gamma */, float /* scale */, | |||||
| const GemmCoord& /* threadblock_shape */, | |||||
| const GemmCoord& /* warp_shape */, int /* stages */, | |||||
| cudaStream_t /* stream */) {} | |||||
| #else | |||||
| template <bool NeedLoadFromConstMem> | |||||
| void megdnn::cuda::cutlass_wrapper:: | |||||
| do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4( | |||||
| const int8_t* d_src, const int8_t* d_filter, | |||||
| const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, | |||||
| int* workspace, const convolution::ConvParam& param, | |||||
| uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||||
| float scale, const GemmCoord& threadblock_shape, | |||||
| const GemmCoord& warp_shape, int stages, cudaStream_t stream) { | |||||
| #define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||||
| threadblock_k_, warp_m_, warp_n_, \ | |||||
| warp_k_, stage_, aligned_) \ | |||||
| if (threadblock_shape.m() == threadblock_m_ && \ | |||||
| threadblock_shape.n() == threadblock_n_ && \ | |||||
| threadblock_shape.k() == threadblock_k_ && \ | |||||
| warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||||
| warp_shape.k() == warp_k_ && stages == stage_) { \ | |||||
| using ThreadBlockShape = \ | |||||
| cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||||
| threadblock_k_>; \ | |||||
| using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||||
| using InstructionShape = cutlass::gemm::GemmShape<1, 1, 4>; \ | |||||
| using Convolution = cutlass::conv::device::Convolution< \ | |||||
| int8_t, cutlass::layout::TensorNCxHWx<4>, int8_t, \ | |||||
| cutlass::layout::TensorCxRSKx<4>, ElementOutput, \ | |||||
| cutlass::layout::TensorNCxHWx<4>, int32_t, \ | |||||
| cutlass::layout::TensorNCxHWx<4>, int32_t, \ | |||||
| cutlass::conv::ConvType::kConvolution, \ | |||||
| cutlass::arch::OpClassSimt, cutlass::arch::Sm61, \ | |||||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||||
| cutlass::conv::threadblock:: \ | |||||
| ConvolutionFpropNCxHWxThreadblockSwizzle, \ | |||||
| stage_, 4, aligned_, NeedLoadFromConstMem, \ | |||||
| cutlass::arch::OpMultiplyAdd>; \ | |||||
| typename Convolution::ConvolutionParameter conv_param( \ | |||||
| param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||||
| param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||||
| param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||||
| return cutlass_convolution_wrapper<Convolution>( \ | |||||
| d_src, d_filter, d_bias, d_z, d_dst, workspace, conv_param, \ | |||||
| epilogue, stream); \ | |||||
| } | |||||
| #define DISPATCH_KERNEL \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 32, 64, 32, 32, 2, 16); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 2, 16); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 128, 16, 1, 8); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 2, 4); \ | |||||
| megdnn_assert(false, \ | |||||
| "unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||||
| "(%dx%dx%d)", \ | |||||
| threadblock_shape.m(), threadblock_shape.n(), \ | |||||
| threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||||
| warp_shape.k()); | |||||
| using ElementOutput = int8_t; | |||||
| using ElementAccumulator = int32_t; | |||||
| using ElementBias = int32_t; | |||||
| using ElementCompute = float; | |||||
| using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; | |||||
| switch (nonlinear_mode) { | |||||
| case NonlineMode::IDENTITY: { | |||||
| using EpilogueOp = | |||||
| cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||||
| ElementOutput, 4, ElementAccumulator, ElementBias, | |||||
| ElementCompute>; | |||||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma}; | |||||
| DISPATCH_KERNEL; | |||||
| } | |||||
| case NonlineMode::RELU: { | |||||
| using EpilogueOp = cutlass::epilogue::thread:: | |||||
| BiasAddLinearCombinationReluClamp< | |||||
| ElementOutput, 4, ElementAccumulator, ElementBias, | |||||
| ElementCompute>; | |||||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0}; | |||||
| DISPATCH_KERNEL; | |||||
| } | |||||
| case NonlineMode::H_SWISH: { | |||||
| using EpilogueOp = cutlass::epilogue::thread:: | |||||
| BiasAddLinearCombinationHSwishClamp< | |||||
| ElementOutput, 4, ElementAccumulator, ElementBias, | |||||
| ElementCompute>; | |||||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale}; | |||||
| DISPATCH_KERNEL; | |||||
| } | |||||
| default: | |||||
| megdnn_assert(false, | |||||
| "unsupported nonlinear mode for conv bias operator"); | |||||
| } | |||||
| #undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||||
| #undef DISPATCH_KERNEL | |||||
| } | |||||
| #endif | |||||
| #define INST(need_load_from_const_mem) \ | |||||
| template void megdnn::cuda::cutlass_wrapper:: \ | |||||
| do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4< \ | |||||
| need_load_from_const_mem>( \ | |||||
| const int8_t* d_src, const int8_t* d_filter, \ | |||||
| const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \ | |||||
| int* workspace, const convolution::ConvParam& param, \ | |||||
| uint32_t nonlinear_mode, float alpha, float beta, \ | |||||
| float gamma, float scale, \ | |||||
| const GemmCoord& threadblock_shape, \ | |||||
| const GemmCoord& warp_shape, int stages, \ | |||||
| cudaStream_t stream); | |||||
| INST(true); | |||||
| INST(false); | |||||
| #undef INST | |||||
| /* ====== cutlass kernel wrapper for int8 nchw4 layout and nchw output ====== */ | |||||
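| /* This variant dequantizes on the way out: the destination is plain float | |||||
|  * NCHW, so the epilogues below are the non-clamping BiasAddLinearCombination | |||||
|  * family with an output vector length of 1. */ | |||||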
| #if MEGDNN_TEGRA_X1 | |||||
| template <bool NeedLoadFromConstMem> | |||||
| void megdnn::cuda::cutlass_wrapper:: | |||||
| do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nchw( | |||||
| const int8_t* /* d_src */, const int8_t* /* d_filter */, | |||||
| const float* /* d_bias */, const float* /* d_z */, | |||||
| float* /* d_dst */, int* /* workspace */, | |||||
| const convolution::ConvParam& /* param */, | |||||
| uint32_t /* nonlinear_mode */, float /* alpha */, | |||||
| float /* beta */, float /* gamma */, float /* scale */, | |||||
| const GemmCoord& /* threadblock_shape */, | |||||
| const GemmCoord& /* warp_shape */, int /* stages */, | |||||
| cudaStream_t /* stream */) {} | |||||
| #else | |||||
| template <bool NeedLoadFromConstMem> | |||||
| void megdnn::cuda::cutlass_wrapper:: | |||||
| do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nchw( | |||||
| const int8_t* d_src, const int8_t* d_filter, | |||||
| const float* d_bias, const float* d_z, float* d_dst, | |||||
| int* workspace, const convolution::ConvParam& param, | |||||
| uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||||
| float scale, const GemmCoord& threadblock_shape, | |||||
| const GemmCoord& warp_shape, int stages, cudaStream_t stream) { | |||||
| #define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||||
| threadblock_k_, warp_m_, warp_n_, \ | |||||
| warp_k_, stages_, aligned_) \ | |||||
| if (threadblock_shape.m() == threadblock_m_ && \ | |||||
| threadblock_shape.n() == threadblock_n_ && \ | |||||
| threadblock_shape.k() == threadblock_k_ && \ | |||||
| warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||||
| warp_shape.k() == warp_k_ && stages == stages_) { \ | |||||
| using ThreadBlockShape = \ | |||||
| cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||||
| threadblock_k_>; \ | |||||
| using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||||
| using InstructionShape = cutlass::gemm::GemmShape<1, 1, 4>; \ | |||||
| using Convolution = cutlass::conv::device::Convolution< \ | |||||
| int8_t, cutlass::layout::TensorNCxHWx<4>, int8_t, \ | |||||
| cutlass::layout::TensorCxRSKx<4>, ElementOutput, \ | |||||
| cutlass::layout::TensorNCHW, float, \ | |||||
| cutlass::layout::TensorNCHW, int32_t, \ | |||||
| cutlass::conv::ConvType::kConvolution, \ | |||||
| cutlass::arch::OpClassSimt, cutlass::arch::Sm61, \ | |||||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||||
| cutlass::conv::threadblock:: \ | |||||
| ConvolutionFpropNCxHWxThreadblockSwizzle, \ | |||||
| stages_, 4, aligned_, NeedLoadFromConstMem, \ | |||||
| cutlass::arch::OpMultiplyAdd>; \ | |||||
| typename Convolution::ConvolutionParameter conv_param( \ | |||||
| param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||||
| param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||||
| param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||||
| return cutlass_convolution_wrapper<Convolution>( \ | |||||
| d_src, d_filter, d_bias, d_z, d_dst, workspace, conv_param, \ | |||||
| epilogue, stream); \ | |||||
| } | |||||
| #define DISPATCH_KERNEL \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 32, 64, 32, 32, 2, 16); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 2, 16); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 128, 16, 1, 8); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 2, 4); \ | |||||
| megdnn_assert(false, \ | |||||
| "unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||||
| "(%dx%dx%d)", \ | |||||
| threadblock_shape.m(), threadblock_shape.n(), \ | |||||
| threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||||
| warp_shape.k()); | |||||
| using ElementOutput = float; | |||||
| using ElementAccumulator = int32_t; | |||||
| using ElementBias = float; | |||||
| using ElementCompute = float; | |||||
| using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; | |||||
| switch (nonlinear_mode) { | |||||
| case NonlineMode::IDENTITY: { | |||||
| using EpilogueOp = | |||||
| cutlass::epilogue::thread::BiasAddLinearCombination< | |||||
| ElementOutput, 1, ElementAccumulator, ElementBias, | |||||
| ElementCompute>; | |||||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma}; | |||||
| DISPATCH_KERNEL; | |||||
| } | |||||
| case NonlineMode::RELU: { | |||||
| using EpilogueOp = | |||||
| cutlass::epilogue::thread::BiasAddLinearCombinationRelu< | |||||
| ElementOutput, 1, ElementAccumulator, ElementBias, | |||||
| ElementCompute>; | |||||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0}; | |||||
| DISPATCH_KERNEL; | |||||
| } | |||||
| case NonlineMode::H_SWISH: { | |||||
| using EpilogueOp = | |||||
| cutlass::epilogue::thread::BiasAddLinearCombinationHSwish< | |||||
| ElementOutput, 1, ElementAccumulator, ElementBias, | |||||
| ElementCompute>; | |||||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale}; | |||||
| DISPATCH_KERNEL; | |||||
| } | |||||
| default: | |||||
| megdnn_assert(false, | |||||
| "unsupported nonlinear mode for conv bias operator"); | |||||
| } | |||||
| #undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||||
| #undef DISPATCH_KERNEL | |||||
| } | |||||
| #endif | |||||
| #define INST(need_load_from_const_mem) \ | |||||
| template void megdnn::cuda::cutlass_wrapper:: \ | |||||
| do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nchw< \ | |||||
| need_load_from_const_mem>( \ | |||||
| const int8_t* d_src, const int8_t* d_filter, \ | |||||
| const float* d_bias, const float* d_z, float* d_dst, \ | |||||
| int* workspace, const convolution::ConvParam& param, \ | |||||
| uint32_t nonlinear_mode, float alpha, float beta, \ | |||||
| float gamma, float scale, \ | |||||
| const GemmCoord& threadblock_shape, \ | |||||
| const GemmCoord& warp_shape, int stages, \ | |||||
| cudaStream_t stream); | |||||
| INST(true); | |||||
| INST(false); | |||||
| #undef INST | |||||
| /* ===== cutlass kernel wrapper for int8 nchw4 layout and nchw32 output ===== */ | |||||
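| /* Same dp4a kernel, re-blocking the output into NCHW32; note that the | |||||
|  * dispatch table below omits the small 16-wide tiles kept by the other | |||||
|  * dp4a wrappers. */ | |||||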
| #if MEGDNN_TEGRA_X1 | |||||
| template <bool NeedLoadFromConstMem> | |||||
| void megdnn::cuda::cutlass_wrapper:: | |||||
| do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32( | |||||
| const int8_t* /* d_src */, const int8_t* /* d_filter */, | |||||
| const int32_t* /* d_bias */, const int8_t* /* d_z */, | |||||
| int8_t* /* d_dst */, int* /* workspace */, | |||||
| const convolution::ConvParam& /* param */, | |||||
| uint32_t /* nonlinear_mode */, float /* alpha */, | |||||
| float /* beta */, float /* gamma */, float /* scale */, | |||||
| const GemmCoord& /* threadblock_shape */, | |||||
| const GemmCoord& /* warp_shape */, int /* stages */, | |||||
| cudaStream_t /* stream */) {} | |||||
| #else | |||||
| template <bool NeedLoadFromConstMem> | |||||
| void megdnn::cuda::cutlass_wrapper:: | |||||
| do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32( | |||||
| const int8_t* d_src, const int8_t* d_filter, | |||||
| const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, | |||||
| int* workspace, const convolution::ConvParam& param, | |||||
| uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||||
| float scale, const GemmCoord& threadblock_shape, | |||||
| const GemmCoord& warp_shape, int stages, cudaStream_t stream) { | |||||
| #define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||||
| threadblock_k_, warp_m_, warp_n_, \ | |||||
| warp_k_, stages_, aligned_) \ | |||||
| if (threadblock_shape.m() == threadblock_m_ && \ | |||||
| threadblock_shape.n() == threadblock_n_ && \ | |||||
| threadblock_shape.k() == threadblock_k_ && \ | |||||
| warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||||
| warp_shape.k() == warp_k_ && stages == stages_) { \ | |||||
| using ThreadBlockShape = \ | |||||
| cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||||
| threadblock_k_>; \ | |||||
| using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||||
| using InstructionShape = cutlass::gemm::GemmShape<1, 1, 4>; \ | |||||
| using Convolution = cutlass::conv::device::Convolution< \ | |||||
| int8_t, cutlass::layout::TensorNCxHWx<4>, int8_t, \ | |||||
| cutlass::layout::TensorCxRSKx<4>, ElementOutput, \ | |||||
| cutlass::layout::TensorNCxHWx<32>, int32_t, \ | |||||
| cutlass::layout::TensorNCxHWx<32>, int32_t, \ | |||||
| cutlass::conv::ConvType::kConvolution, \ | |||||
| cutlass::arch::OpClassSimt, cutlass::arch::Sm61, \ | |||||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||||
| cutlass::conv::threadblock:: \ | |||||
| ConvolutionFpropNCxHWxThreadblockSwizzle, \ | |||||
| stages_, 4, aligned_, NeedLoadFromConstMem, \ | |||||
| cutlass::arch::OpMultiplyAdd>; \ | |||||
| typename Convolution::ConvolutionParameter conv_param( \ | |||||
| param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||||
| param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||||
| param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||||
| return cutlass_convolution_wrapper<Convolution>( \ | |||||
| d_src, d_filter, d_bias, d_z, d_dst, workspace, conv_param, \ | |||||
| epilogue, stream); \ | |||||
| } | |||||
| #define DISPATCH_KERNEL \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 32, 64, 32, 32, 2, 16); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 2, 16); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \ | |||||
| megdnn_assert(false, \ | |||||
| "unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||||
| "(%dx%dx%d)", \ | |||||
| threadblock_shape.m(), threadblock_shape.n(), \ | |||||
| threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||||
| warp_shape.k()); | |||||
| using ElementOutput = int8_t; | |||||
| using ElementAccumulator = int32_t; | |||||
| using ElementBias = int32_t; | |||||
| using ElementCompute = float; | |||||
| using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; | |||||
| switch (nonlinear_mode) { | |||||
| case NonlineMode::IDENTITY: { | |||||
| using EpilogueOp = | |||||
| cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||||
| ElementOutput, 4, ElementAccumulator, ElementBias, | |||||
| ElementCompute>; | |||||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma}; | |||||
| DISPATCH_KERNEL; | |||||
| } | |||||
| case NonlineMode::RELU: { | |||||
| using EpilogueOp = cutlass::epilogue::thread:: | |||||
| BiasAddLinearCombinationReluClamp< | |||||
| ElementOutput, 4, ElementAccumulator, ElementBias, | |||||
| ElementCompute>; | |||||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0}; | |||||
| DISPATCH_KERNEL; | |||||
| } | |||||
| case NonlineMode::H_SWISH: { | |||||
| using EpilogueOp = cutlass::epilogue::thread:: | |||||
| BiasAddLinearCombinationHSwishClamp< | |||||
| ElementOutput, 4, ElementAccumulator, ElementBias, | |||||
| ElementCompute>; | |||||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale}; | |||||
| DISPATCH_KERNEL; | |||||
| } | |||||
| default: | |||||
| megdnn_assert(false, | |||||
| "unsupported nonlinear mode for conv bias operator"); | |||||
| } | |||||
| #undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||||
| #undef DISPATCH_KERNEL | |||||
| } | |||||
| #endif | |||||
| #define INST(need_load_from_const_mem) \ | |||||
| template void megdnn::cuda::cutlass_wrapper:: \ | |||||
| do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32< \ | |||||
| need_load_from_const_mem>( \ | |||||
| const int8_t* d_src, const int8_t* d_filter, \ | |||||
| const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \ | |||||
| int* workspace, const convolution::ConvParam& param, \ | |||||
| uint32_t nonlinear_mode, float alpha, float beta, \ | |||||
| float gamma, float scale, \ | |||||
| const GemmCoord& threadblock_shape, \ | |||||
| const GemmCoord& warp_shape, int stages, \ | |||||
| cudaStream_t stream); | |||||
| INST(true); | |||||
| INST(false); | |||||
| #undef INST | |||||
| /* ===== cutlass kernel wrapper for nchw4 layout and nhwc output ===== */ | |||||
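| /* Here the template parameter is `signedness` instead of | |||||
|  * NeedLoadFromConstMem (hard-coded to true below): the NHWC result is | |||||
|  * reinterpreted as 4-bit integer_subbyte<4, signedness> values, and the | |||||
|  * extra delta/theta constants presumably carry the asymmetric-quantization | |||||
|  * offsets into the epilogue. */ | |||||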
| #if MEGDNN_TEGRA_X1 | |||||
| template <bool signedness> | |||||
| void megdnn::cuda::cutlass_wrapper:: | |||||
| do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nhwc( | |||||
| const int8_t* /* d_src */, const int8_t* /* d_filter */, | |||||
| const int32_t* /* d_bias */, const int8_t* /* d_z */, | |||||
| int8_t* /* d_dst */, int* /* workspace */, | |||||
| const convolution::ConvParam& /* param */, | |||||
| uint32_t /* nonlinear_mode */, float /* alpha */, | |||||
| float /* beta */, float /* gamma */, float /* delta */, | |||||
| float /* theta */, float /* scale */, | |||||
| const GemmCoord& /* threadblock_shape */, | |||||
| const GemmCoord& /* warp_shape */, int /* stages */, | |||||
| cudaStream_t /* stream */) {} | |||||
| #else | |||||
| template <bool signedness> | |||||
| void megdnn::cuda::cutlass_wrapper:: | |||||
| do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nhwc( | |||||
| const int8_t* d_src, const int8_t* d_filter, | |||||
| const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, | |||||
| int* workspace, const convolution::ConvParam& param, | |||||
| uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||||
| float delta, float theta, float scale, | |||||
| const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | |||||
| int stages, cudaStream_t stream) { | |||||
| #define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||||
| threadblock_k_, warp_m_, warp_n_, \ | |||||
| warp_k_, stages_, aligned_) \ | |||||
| if (threadblock_shape.m() == threadblock_m_ && \ | |||||
| threadblock_shape.n() == threadblock_n_ && \ | |||||
| threadblock_shape.k() == threadblock_k_ && \ | |||||
| warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||||
| warp_shape.k() == warp_k_ && stages == stages_) { \ | |||||
| using ThreadBlockShape = \ | |||||
| cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||||
| threadblock_k_>; \ | |||||
| using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||||
| using InstructionShape = cutlass::gemm::GemmShape<1, 1, 4>; \ | |||||
| using Convolution = cutlass::conv::device::Convolution< \ | |||||
| int8_t, cutlass::layout::TensorNCxHWx<4>, int8_t, \ | |||||
| cutlass::layout::TensorCxRSKx<4>, ElementOutput, \ | |||||
| cutlass::layout::TensorNHWC, int32_t, \ | |||||
| cutlass::layout::TensorNHWC, int32_t, \ | |||||
| cutlass::conv::ConvType::kConvolution, \ | |||||
| cutlass::arch::OpClassSimt, cutlass::arch::Sm75, \ | |||||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||||
| cutlass::conv::threadblock:: \ | |||||
| ConvolutionFpropNCxHWxThreadblockSwizzle, \ | |||||
| stages_, 4, aligned_, true, cutlass::arch::OpMultiplyAdd>; \ | |||||
| typename Convolution::ConvolutionParameter conv_param( \ | |||||
| param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||||
| param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||||
| param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||||
| return cutlass_convolution_wrapper<Convolution>( \ | |||||
| d_src, d_filter, d_bias, \ | |||||
| reinterpret_cast<const ElementOutput*>(d_z), \ | |||||
| reinterpret_cast<ElementOutput*>(d_dst), workspace, \ | |||||
| conv_param, epilogue, stream); \ | |||||
| } | |||||
| #define DISPATCH_KERNEL \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 32, 64, 32, 32, 2, 16); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 2, 16); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 128, 16, 1, 8); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 2, 4); \ | |||||
| megdnn_assert(false, \ | |||||
| "unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||||
| "(%dx%dx%d)", \ | |||||
| threadblock_shape.m(), threadblock_shape.n(), \ | |||||
| threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||||
| warp_shape.k()); | |||||
| using ElementOutput = cutlass::integer_subbyte<4, signedness>; | |||||
| using ElementAccumulator = int32_t; | |||||
| using ElementBias = int32_t; | |||||
| using ElementCompute = float; | |||||
| using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; | |||||
| switch (nonlinear_mode) { | |||||
| case NonlineMode::IDENTITY: { | |||||
| using EpilogueOp = | |||||
| cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||||
| ElementOutput, 8, ElementAccumulator, ElementBias, | |||||
| ElementCompute>; | |||||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, | |||||
| delta + theta}; | |||||
| DISPATCH_KERNEL; | |||||
| } | |||||
| case NonlineMode::RELU: { | |||||
| using EpilogueOp = cutlass::epilogue::thread:: | |||||
| BiasAddLinearCombinationReluClamp< | |||||
| ElementOutput, 8, ElementAccumulator, ElementBias, | |||||
| ElementCompute>; | |||||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, | |||||
| 0, delta, theta}; | |||||
| DISPATCH_KERNEL; | |||||
| } | |||||
| case NonlineMode::H_SWISH: { | |||||
| using EpilogueOp = cutlass::epilogue::thread:: | |||||
| BiasAddLinearCombinationHSwishClamp< | |||||
| ElementOutput, 8, ElementAccumulator, ElementBias, | |||||
| ElementCompute>; | |||||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, | |||||
| scale, delta, theta}; | |||||
| DISPATCH_KERNEL; | |||||
| } | |||||
| default: | |||||
| megdnn_assert(false, | |||||
| "unsupported nonlinear mode for conv bias operator"); | |||||
| } | |||||
| #undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||||
| #undef DISPATCH_KERNEL | |||||
| } | |||||
| #endif | |||||
| #define INST(signedness) \ | |||||
| template void megdnn::cuda::cutlass_wrapper:: \ | |||||
| do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nhwc<signedness>( \ | |||||
| const int8_t* d_src, const int8_t* d_filter, \ | |||||
| const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \ | |||||
| int* workspace, const convolution::ConvParam& param, \ | |||||
| uint32_t nonlinear_mode, float alpha, float beta, \ | |||||
| float gamma, float delta, float theta, float scale, \ | |||||
| const GemmCoord& threadblock_shape, \ | |||||
| const GemmCoord& warp_shape, int stages, \ | |||||
| cudaStream_t stream); | |||||
| INST(true); | |||||
| INST(false); | |||||
| #undef INST | |||||
| // vim: syntax=cuda.doxygen | |||||
| @@ -1,65 +0,0 @@ | |||||
| /** | |||||
| * \file | |||||
| * dnn/src/cuda/conv_bias/int8/implicit_gemm_conv_bias_cutlass_wrapper.cuinl | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "cutlass/convolution/device/convolution.h" | |||||
| #include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | |||||
| using namespace megdnn; | |||||
| using namespace cuda; | |||||
| using namespace cutlass_wrapper; | |||||
| template <typename Convolution> | |||||
| void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper( | |||||
| const typename Convolution::ElementSrc* d_src, | |||||
| const typename Convolution::ElementFilter* d_filter, | |||||
| const typename Convolution::ElementBias* d_bias, | |||||
| const typename Convolution::ElementDst* d_z, | |||||
| typename Convolution::ElementDst* d_dst, int* workspace, | |||||
| typename Convolution::ConvolutionParameter const& conv_param, | |||||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||||
| cudaStream_t stream, typename Convolution::ExtraParam extra_param) { | |||||
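| // Wrap the raw device pointers in packed TensorRefs whose extents come | |||||
| // from conv_param (N/H/W/C for src, K/R/S/C for the filter, N/P/Q/K for | |||||
| // z and dst), then hand everything to the device-level operator. | |||||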
| typename Convolution::TensorRefSrc tensor_src{ | |||||
| const_cast<typename Convolution::ElementSrc*>(d_src), | |||||
| Convolution::LayoutSrc::packed( | |||||
| {conv_param.N, conv_param.H, conv_param.W, conv_param.C})}; | |||||
| typename Convolution::TensorRefFilter tensor_filter{ | |||||
| const_cast<typename Convolution::ElementFilter*>(d_filter), | |||||
| Convolution::LayoutFilter::packed( | |||||
| {conv_param.K, conv_param.R, conv_param.S, conv_param.C})}; | |||||
| typename Convolution::TensorRefBias tensor_bias{ | |||||
| const_cast<typename Convolution::ElementBias*>(d_bias), | |||||
| Convolution::LayoutBias::packed({1, 1, 1, conv_param.K})}; | |||||
| typename Convolution::TensorRefDst tensor_z{ | |||||
| const_cast<typename Convolution::ElementDst*>(d_z), | |||||
| Convolution::LayoutDst::packed( | |||||
| {conv_param.N, conv_param.P, conv_param.Q, conv_param.K})}; | |||||
| typename Convolution::TensorRefDst tensor_dst{ | |||||
| d_dst, | |||||
| Convolution::LayoutDst::packed( | |||||
| {conv_param.N, conv_param.P, conv_param.Q, conv_param.K})}; | |||||
| typename Convolution::Arguments arguments{conv_param, | |||||
| tensor_src.non_const_ref(), | |||||
| tensor_filter.non_const_ref(), | |||||
| tensor_bias.non_const_ref(), | |||||
| tensor_z.non_const_ref(), | |||||
| tensor_dst.non_const_ref(), | |||||
| epilogue, | |||||
| {}, | |||||
| {}, | |||||
| extra_param}; | |||||
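| // Standard two-step CUTLASS launch: initialize() validates the arguments | |||||
| // and prepares any workspace, operator() enqueues the kernel on `stream`. | |||||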
| Convolution conv_op; | |||||
| cutlass_check(conv_op.initialize(arguments, workspace)); | |||||
| cutlass_check(conv_op(stream)); | |||||
| after_kernel_launch(); | |||||
| } | |||||
| // vim: syntax=cuda.doxygen | |||||
| @@ -10,8 +10,7 @@ | |||||
| * implied. | * implied. | ||||
| */ | */ | ||||
| #include "./algo.h" | |||||
| #include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | |||||
| #include "src/cuda/conv_bias/algo.h" | |||||
| using namespace megdnn; | using namespace megdnn; | ||||
| using namespace cuda; | using namespace cuda; | ||||
| @@ -81,29 +80,6 @@ ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm::get_constants( | |||||
| return {alpha, beta, gamma, delta, theta}; | return {alpha, beta, gamma, delta, theta}; | ||||
| } | } | ||||
| void ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm::do_exec( | |||||
| const ExecArgs& args, void* filter_ptr, void* bias_ptr, void* z_ptr, | |||||
| ConvParam kern_param, uint32_t nonlinear_mode, float alpha, float beta, | |||||
| float gamma, float delta, float theta, cudaStream_t stream) const { | |||||
| float dst_scale = args.dst_layout->dtype.param<dtype::QuantizedS4>().scale; | |||||
| cutlass_wrapper::GemmCoord threadblock_shape{m_algo_param.threadblock_m, | |||||
| m_algo_param.threadblock_n, | |||||
| m_algo_param.threadblock_k}; | |||||
| cutlass_wrapper::GemmCoord warp_shape{ | |||||
| m_algo_param.warp_m, m_algo_param.warp_n, m_algo_param.warp_k}; | |||||
| cutlass_wrapper::do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64< | |||||
| true>(reinterpret_cast<int8_t*>(args.src_tensor->raw_ptr), | |||||
| reinterpret_cast<int8_t*>(filter_ptr), | |||||
| reinterpret_cast<int32_t*>(bias_ptr), | |||||
| reinterpret_cast<int8_t*>(z_ptr), | |||||
| reinterpret_cast<int8_t*>(args.dst_tensor->raw_ptr), nullptr, | |||||
| kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale, | |||||
| threadblock_shape, warp_shape, m_algo_param.stage, stream); | |||||
| } | |||||
| #endif | #endif | ||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -10,8 +10,7 @@ | |||||
| * implied. | * implied. | ||||
| */ | */ | ||||
| #include "./algo.h" | |||||
| #include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | |||||
| #include "src/cuda/conv_bias/algo.h" | |||||
| using namespace megdnn; | using namespace megdnn; | ||||
| using namespace cuda; | using namespace cuda; | ||||
| @@ -81,42 +80,6 @@ ConvBiasForwardImpl::AlgoInt4Int4NHWCIMMAImplicitGemm::get_constants( | |||||
| return {alpha, beta, gamma, delta, theta}; | return {alpha, beta, gamma, delta, theta}; | ||||
| } | } | ||||
| void ConvBiasForwardImpl::AlgoInt4Int4NHWCIMMAImplicitGemm::do_exec( | |||||
| const ExecArgs& args, void* filter_ptr, void* bias_ptr, void* z_ptr, | |||||
| ConvParam kern_param, uint32_t nonlinear_mode, float alpha, float beta, | |||||
| float gamma, float delta, float theta, cudaStream_t stream) const { | |||||
| float dst_scale = args.dst_layout->dtype.param<dtype::QuantizedS4>().scale; | |||||
| cutlass_wrapper::GemmCoord threadblock_shape{m_algo_param.threadblock_m, | |||||
| m_algo_param.threadblock_n, | |||||
| m_algo_param.threadblock_k}; | |||||
| cutlass_wrapper::GemmCoord warp_shape{ | |||||
| m_algo_param.warp_m, m_algo_param.warp_n, m_algo_param.warp_k}; | |||||
| if (kern_param.fh == 1 && kern_param.fw == 1) { | |||||
| cutlass_wrapper::do_conv_bias_int4_int4_implicit_gemm_imma_nhwc<false>( | |||||
| reinterpret_cast<int8_t*>(args.src_tensor->raw_ptr), | |||||
| reinterpret_cast<int8_t*>(filter_ptr), | |||||
| reinterpret_cast<int32_t*>(bias_ptr), | |||||
| reinterpret_cast<int8_t*>(z_ptr), | |||||
| reinterpret_cast<int8_t*>(args.dst_tensor->raw_ptr), nullptr, | |||||
| kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale, | |||||
| threadblock_shape, warp_shape, m_algo_param.access_size, | |||||
| m_algo_param.stage, stream); | |||||
| } else { | |||||
| cutlass_wrapper::do_conv_bias_int4_int4_implicit_gemm_imma_nhwc<true>( | |||||
| reinterpret_cast<int8_t*>(args.src_tensor->raw_ptr), | |||||
| reinterpret_cast<int8_t*>(filter_ptr), | |||||
| reinterpret_cast<int32_t*>(bias_ptr), | |||||
| reinterpret_cast<int8_t*>(z_ptr), | |||||
| reinterpret_cast<int8_t*>(args.dst_tensor->raw_ptr), nullptr, | |||||
| kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale, | |||||
| threadblock_shape, warp_shape, m_algo_param.access_size, | |||||
| m_algo_param.stage, stream); | |||||
| } | |||||
| } | |||||
| #endif | #endif | ||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -10,10 +10,9 @@ | |||||
| * implied. | * implied. | ||||
| */ | */ | ||||
| #include "./algo.h" | |||||
| #include "src/common/conv_bias.h" | #include "src/common/conv_bias.h" | ||||
| #include "src/cuda/conv_bias/algo.h" | |||||
| #include "src/cuda/conv_bias/cutlass_reorder_filter.cuh" | #include "src/cuda/conv_bias/cutlass_reorder_filter.cuh" | ||||
| #include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | |||||
| #include "src/cuda/conv_bias/reduce_filter.cuh" | #include "src/cuda/conv_bias/reduce_filter.cuh" | ||||
| #include "src/cuda/convolution_helper/parameter.cuh" | #include "src/cuda/convolution_helper/parameter.cuh" | ||||
| #include "src/cuda/utils.h" | #include "src/cuda/utils.h" | ||||
| @@ -102,22 +101,40 @@ void ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase::exec( | |||||
| if (args.z_layout->ndim > 0) | if (args.z_layout->ndim > 0) | ||||
| z_ptr = args.z_tensor->raw_ptr; | z_ptr = args.z_tensor->raw_ptr; | ||||
| // \note these constants of the cutlass epilogue are passed to the method | |||||
| // `execute_cutlass_conv_op` by pointer and interpreted as ElementCompute*; | |||||
| // using a different dtype here results in undefined epilogue behavior | |||||
| float alpha, beta, gamma, delta, theta; | float alpha, beta, gamma, delta, theta; | ||||
| std::tie(alpha, beta, gamma, delta, theta) = get_constants(args); | std::tie(alpha, beta, gamma, delta, theta) = get_constants(args); | ||||
| float dst_scale = 0.f; | |||||
| float threshold = 0.f; | |||||
| uint8_t src_zero = 0; | |||||
| bool load_from_const = !(fh == 1 && fw == 1); | |||||
| bool without_shared_load = true; | |||||
| if (args.dst_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm) { | |||||
| dst_scale = | |||||
| args.dst_layout->dtype.param<dtype::Quantized4Asymm>().scale; | |||||
| src_zero = args.src_layout->dtype.param<dtype::Quantized4Asymm>() | |||||
| .zero_point; | |||||
| } else { // DTypeEnum::QuantizedS4 | |||||
| dst_scale = args.dst_layout->dtype.param<dtype::QuantizedS4>().scale; | |||||
| } | |||||
| ConvParam kern_param; | |||||
| kern_param.n = n, kern_param.co = co, kern_param.ci = ci, | |||||
| kern_param.hi = hi, kern_param.wi = wi, kern_param.ho = ho, | |||||
| kern_param.wo = wo, kern_param.ph = ph, kern_param.pw = pw, | |||||
| kern_param.sh = sh, kern_param.sw = sw, kern_param.fh = fh, | |||||
| kern_param.fw = fw; | |||||
| cudaStream_t stream = cuda_stream(args.opr->handle()); | |||||
| uint32_t nonlinear_mode = static_cast<uint32_t>(param.nonlineMode); | |||||
| const auto* op = get_cutlass_conv_op(args, ConvOperator::kFprop, | |||||
| ConvType::kConvolution, | |||||
| load_from_const, without_shared_load); | |||||
| cudaStream_t stream = cuda_stream(args.opr->handle()); | |||||
| execute_cutlass_conv_op(op, args.src_tensor->raw_ptr, filter_ptr, bias_ptr, | |||||
| z_ptr, args.dst_tensor->raw_ptr, nullptr, n, hi, wi, | |||||
| ci, co, fh, fw, ho, wo, ph, pw, sh, sw, dh, dw, | |||||
| &alpha, &beta, &gamma, &delta, &theta, &threshold, | |||||
| &dst_scale, stream, &src_zero); | |||||
| do_exec(args, filter_ptr, bias_ptr, z_ptr, kern_param, nonlinear_mode, | |||||
| alpha, beta, gamma, delta, theta, stream); | |||||
| after_kernel_launch(); | |||||
| } | } | ||||
| std::string ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase::to_string( | std::string ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase::to_string( | ||||
| @@ -10,10 +10,9 @@ | |||||
| * implied. | * implied. | ||||
| */ | */ | ||||
| #include "./algo.h" | |||||
| #include "src/common/conv_bias.h" | #include "src/common/conv_bias.h" | ||||
| #include "src/cuda/conv_bias/algo.h" | |||||
| #include "src/cuda/conv_bias/cutlass_reorder_filter.cuh" | #include "src/cuda/conv_bias/cutlass_reorder_filter.cuh" | ||||
| #include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | |||||
| #include "src/cuda/conv_bias/reduce_filter.cuh" | #include "src/cuda/conv_bias/reduce_filter.cuh" | ||||
| #include "src/cuda/convolution_helper/parameter.cuh" | #include "src/cuda/convolution_helper/parameter.cuh" | ||||
| #include "src/cuda/utils.h" | #include "src/cuda/utils.h" | ||||
| @@ -109,22 +108,43 @@ void ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::exec( | |||||
| if (args.z_layout->ndim > 0) | if (args.z_layout->ndim > 0) | ||||
| z_ptr = args.z_tensor->raw_ptr; | z_ptr = args.z_tensor->raw_ptr; | ||||
| // \note these constants of the cutlass epilogue are passed to the method | |||||
| // `execute_cutlass_conv_op` by pointer and interpreted as ElementCompute*; | |||||
| // using a different dtype here results in undefined epilogue behavior | |||||
| float alpha, beta, gamma, delta, theta; | float alpha, beta, gamma, delta, theta; | ||||
| std::tie(alpha, beta, gamma, delta, theta) = get_constants(args); | std::tie(alpha, beta, gamma, delta, theta) = get_constants(args); | ||||
| float dst_scale = 0.f; | |||||
| float threshold = 0.f; | |||||
| uint8_t src_zero = 0; | |||||
| bool load_from_const = !(fh == 1 && fw == 1); | |||||
| bool without_shared_load = ((co % m_algo_param.threadblock_n == 0) && | |||||
| (m_algo_param.threadblock_n == 32 || | |||||
| m_algo_param.threadblock_n == 64)); | |||||
| if (args.dst_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm) { | |||||
| dst_scale = | |||||
| args.dst_layout->dtype.param<dtype::Quantized4Asymm>().scale; | |||||
| src_zero = args.src_layout->dtype.param<dtype::Quantized4Asymm>() | |||||
| .zero_point; | |||||
| } else { // DTypeEnum::QuantizedS4 | |||||
| dst_scale = args.dst_layout->dtype.param<dtype::QuantizedS4>().scale; | |||||
| } | |||||
| ConvParam kern_param; | |||||
| kern_param.n = n, kern_param.co = co, kern_param.ci = ci, | |||||
| kern_param.hi = hi, kern_param.wi = wi, kern_param.ho = ho, | |||||
| kern_param.wo = wo, kern_param.ph = ph, kern_param.pw = pw, | |||||
| kern_param.sh = sh, kern_param.sw = sw, kern_param.fh = fh, | |||||
| kern_param.fw = fw; | |||||
| cudaStream_t stream = cuda_stream(args.opr->handle()); | |||||
| uint32_t nonlinear_mode = static_cast<uint32_t>(param.nonlineMode); | |||||
| const auto* op = get_cutlass_conv_op(args, ConvOperator::kFprop, | |||||
| ConvType::kConvolution, | |||||
| load_from_const, without_shared_load); | |||||
| cudaStream_t stream = cuda_stream(args.opr->handle()); | |||||
| execute_cutlass_conv_op(op, args.src_tensor->raw_ptr, filter_ptr, bias_ptr, | |||||
| z_ptr, args.dst_tensor->raw_ptr, nullptr, n, hi, wi, | |||||
| ci, co, fh, fw, ho, wo, ph, pw, sh, sw, dh, dw, | |||||
| &alpha, &beta, &gamma, &delta, &theta, &threshold, | |||||
| &dst_scale, stream, &src_zero); | |||||
| do_exec(args, filter_ptr, bias_ptr, z_ptr, kern_param, nonlinear_mode, | |||||
| alpha, beta, gamma, delta, theta, stream); | |||||
| after_kernel_launch(); | |||||
| } | } | ||||
| std::string ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::to_string( | std::string ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::to_string( | ||||
| @@ -10,12 +10,11 @@ | |||||
| * implied. | * implied. | ||||
| */ | */ | ||||
| #include "./algo.h" | |||||
| #include "src/common/conv_bias.h" | |||||
| #include "src/cuda/conv_bias/algo.h" | |||||
| #include "src/cuda/conv_bias/cutlass_reorder_filter.cuh" | #include "src/cuda/conv_bias/cutlass_reorder_filter.cuh" | ||||
| #include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | |||||
| #include "src/cuda/convolution_helper/parameter.cuh" | #include "src/cuda/convolution_helper/parameter.cuh" | ||||
| #include "src/cuda/utils.h" | #include "src/cuda/utils.h" | ||||
| #include "src/common/conv_bias.h" | |||||
| using namespace megdnn; | using namespace megdnn; | ||||
| using namespace cuda; | using namespace cuda; | ||||
| @@ -38,8 +37,7 @@ bool ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::is_available( | |||||
| bool available = true; | bool available = true; | ||||
| auto&& param = args.opr->param(); | auto&& param = args.opr->param(); | ||||
| auto&& fm = args.filter_meta; | auto&& fm = args.filter_meta; | ||||
| if (!check_bias_share_in_channel(*(args.bias_layout), | |||||
| param.format)) | |||||
| if (!check_bias_share_in_channel(*(args.bias_layout), param.format)) | |||||
| return false; | return false; | ||||
| if (param.format != Format::NCHW32 && param.format != Format::NCHW32_NCHW4) | if (param.format != Format::NCHW32 && param.format != Format::NCHW32_NCHW4) | ||||
| return false; | return false; | ||||
| @@ -137,19 +135,16 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec( | |||||
| args.preprocessed_filter->tensors[0].raw_ptr); | args.preprocessed_filter->tensors[0].raw_ptr); | ||||
| } | } | ||||
| ConvParam kern_param; | |||||
| kern_param.n = n, kern_param.co = co, kern_param.ci = ci, | |||||
| kern_param.hi = hi, kern_param.wi = wi, kern_param.ho = ho, | |||||
| kern_param.wo = wo, kern_param.ph = ph, kern_param.pw = pw, | |||||
| kern_param.sh = sh, kern_param.sw = sw, kern_param.fh = fh, | |||||
| kern_param.fw = fw; | |||||
| float src_scale = args.src_layout->dtype.param<dtype::QuantizedS8>().scale, | float src_scale = args.src_layout->dtype.param<dtype::QuantizedS8>().scale, | ||||
| filter_scale = | filter_scale = | ||||
| args.filter_layout->dtype.param<dtype::QuantizedS8>().scale, | args.filter_layout->dtype.param<dtype::QuantizedS8>().scale, | ||||
| bias_scale = | bias_scale = | ||||
| args.bias_layout->dtype.param<dtype::QuantizedS32>().scale, | args.bias_layout->dtype.param<dtype::QuantizedS32>().scale, | ||||
| dst_scale = args.dst_layout->dtype.param<dtype::QuantizedS8>().scale; | dst_scale = args.dst_layout->dtype.param<dtype::QuantizedS8>().scale; | ||||
| // \note these constants of the cutlass epilogue are passed to the method | |||||
| // `execute_cutlass_conv_op` by pointer and interpreted as ElementCompute*; | |||||
| // using a different dtype here results in undefined epilogue behavior | |||||
| float alpha = src_scale * filter_scale / dst_scale, | float alpha = src_scale * filter_scale / dst_scale, | ||||
| beta = bias_scale / dst_scale; | beta = bias_scale / dst_scale; | ||||
| int8_t* z_dev_ptr = nullptr; | int8_t* z_dev_ptr = nullptr; | ||||
| @@ -159,80 +154,20 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec( | |||||
| float z_scale = args.z_layout->dtype.param<dtype::QuantizedS8>().scale; | float z_scale = args.z_layout->dtype.param<dtype::QuantizedS8>().scale; | ||||
| gamma = z_scale / dst_scale; | gamma = z_scale / dst_scale; | ||||
| } | } | ||||
| uint32_t nonlinear_mode = static_cast<uint32_t>(param.nonlineMode); | |||||
| if (fh == 1 && fw == 1) { | |||||
| if (param.format == Format::NCHW32) { | |||||
| cutlass_wrapper::do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32< | |||||
| false>( | |||||
| args.src_tensor->compatible_ptr<int8_t>(), filter_ptr, | |||||
| args.bias_tensor->compatible_ptr<int32_t>(), z_dev_ptr, | |||||
| args.dst_tensor->compatible_ptr<int8_t>(), nullptr, | |||||
| kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale, | |||||
| cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, | |||||
| m_algo_param.threadblock_n, | |||||
| m_algo_param.threadblock_k}, | |||||
| cutlass_wrapper::GemmCoord{m_algo_param.warp_m, | |||||
| m_algo_param.warp_n, | |||||
| m_algo_param.warp_k}, | |||||
| m_algo_param.stage, stream); | |||||
| } else { | |||||
| megdnn_assert(param.format == Format::NCHW32_NCHW4); | |||||
| cutlass_wrapper:: | |||||
| do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4< | |||||
| false>( | |||||
| args.src_tensor->compatible_ptr<int8_t>(), | |||||
| filter_ptr, | |||||
| args.bias_tensor->compatible_ptr<int32_t>(), | |||||
| z_dev_ptr, | |||||
| args.dst_tensor->compatible_ptr<int8_t>(), nullptr, | |||||
| kern_param, nonlinear_mode, alpha, beta, gamma, | |||||
| dst_scale, | |||||
| cutlass_wrapper::GemmCoord{ | |||||
| m_algo_param.threadblock_m, | |||||
| m_algo_param.threadblock_n, | |||||
| m_algo_param.threadblock_k}, | |||||
| cutlass_wrapper::GemmCoord{m_algo_param.warp_m, | |||||
| m_algo_param.warp_n, | |||||
| m_algo_param.warp_k}, | |||||
| m_algo_param.stage, stream); | |||||
| } | |||||
| } else { | |||||
| if (param.format == Format::NCHW32) { | |||||
| cutlass_wrapper::do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32< | |||||
| true>( | |||||
| args.src_tensor->compatible_ptr<int8_t>(), filter_ptr, | |||||
| args.bias_tensor->compatible_ptr<int32_t>(), z_dev_ptr, | |||||
| args.dst_tensor->compatible_ptr<int8_t>(), nullptr, | |||||
| kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale, | |||||
| cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, | |||||
| m_algo_param.threadblock_n, | |||||
| m_algo_param.threadblock_k}, | |||||
| cutlass_wrapper::GemmCoord{m_algo_param.warp_m, | |||||
| m_algo_param.warp_n, | |||||
| m_algo_param.warp_k}, | |||||
| m_algo_param.stage, stream); | |||||
| } else { | |||||
| megdnn_assert(param.format == Format::NCHW32_NCHW4); | |||||
| cutlass_wrapper:: | |||||
| do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4< | |||||
| true>( | |||||
| args.src_tensor->compatible_ptr<int8_t>(), | |||||
| filter_ptr, | |||||
| args.bias_tensor->compatible_ptr<int32_t>(), | |||||
| z_dev_ptr, | |||||
| args.dst_tensor->compatible_ptr<int8_t>(), nullptr, | |||||
| kern_param, nonlinear_mode, alpha, beta, gamma, | |||||
| dst_scale, | |||||
| cutlass_wrapper::GemmCoord{ | |||||
| m_algo_param.threadblock_m, | |||||
| m_algo_param.threadblock_n, | |||||
| m_algo_param.threadblock_k}, | |||||
| cutlass_wrapper::GemmCoord{m_algo_param.warp_m, | |||||
| m_algo_param.warp_n, | |||||
| m_algo_param.warp_k}, | |||||
| m_algo_param.stage, stream); | |||||
| } | |||||
| } | |||||
| float delta = 0.f, theta = 0.f, threshold = 0.f; | |||||
| bool load_from_const = !(fh == 1 && fw == 1); | |||||
| bool without_shared_load = (param.format == Format::NCHW32); | |||||
| const auto* op = get_cutlass_conv_op(args, ConvOperator::kFprop, | |||||
| ConvType::kConvolution, | |||||
| load_from_const, without_shared_load); | |||||
| execute_cutlass_conv_op( | |||||
| op, args.src_tensor->raw_ptr, filter_ptr, args.bias_tensor->raw_ptr, | |||||
| z_dev_ptr, args.dst_tensor->raw_ptr, nullptr, n, hi, wi, ci, co, fh, | |||||
| fw, ho, wo, ph, pw, sh, sw, dh, dw, &alpha, &beta, &gamma, &delta, | |||||
| &theta, &threshold, &dst_scale, stream); | |||||
| after_kernel_launch(); | after_kernel_launch(); | ||||
| } | } | ||||
| @@ -249,9 +184,8 @@ size_t ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm:: | |||||
| return 0_z; | return 0_z; | ||||
| } | } | ||||
| SmallVector<TensorLayout> ConvBiasForwardImpl:: | |||||
| AlgoInt8NCHW32IMMAImplicitGemm::deduce_preprocessed_filter_layout( | |||||
| const SizeArgs& args) const { | |||||
| SmallVector<TensorLayout> ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm:: | |||||
| deduce_preprocessed_filter_layout(const SizeArgs& args) const { | |||||
| return {args.filter_layout->collapse_contiguous()}; | return {args.filter_layout->collapse_contiguous()}; | ||||
| } | } | ||||
| @@ -6,14 +6,14 @@ | |||||
| * | * | ||||
| * Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
| * software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | */ | ||||
| #include "./algo.h" | |||||
| #include "src/cuda/utils.h" | |||||
| #include "src/cuda/convolution_helper/parameter.cuh" | |||||
| #include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | |||||
| #include "src/common/conv_bias.h" | #include "src/common/conv_bias.h" | ||||
| #include "src/cuda/conv_bias/algo.h" | |||||
| #include "src/cuda/convolution_helper/parameter.cuh" | |||||
| #include "src/cuda/utils.h" | |||||
| using namespace megdnn; | using namespace megdnn; | ||||
| using namespace cuda; | using namespace cuda; | ||||
| @@ -34,8 +34,7 @@ bool ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::is_available( | |||||
| bool available = true; | bool available = true; | ||||
| auto&& param = args.opr->param(); | auto&& param = args.opr->param(); | ||||
| auto&& fm = args.filter_meta; | auto&& fm = args.filter_meta; | ||||
| if (!check_bias_share_in_channel(*(args.bias_layout), | |||||
| param.format)) | |||||
| if (!check_bias_share_in_channel(*(args.bias_layout), param.format)) | |||||
| return false; | return false; | ||||
| bool valid_format = param.format == Format::NCHW4_NCHW32 && | bool valid_format = param.format == Format::NCHW4_NCHW32 && | ||||
| m_algo_param.threadblock_m % 32 == 0; | m_algo_param.threadblock_m % 32 == 0; | ||||
| @@ -48,7 +47,8 @@ bool ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::is_available( | |||||
| (args.dst_layout->dtype.enumv() == DTypeEnum::QuantizedS4 || | (args.dst_layout->dtype.enumv() == DTypeEnum::QuantizedS4 || | ||||
| args.dst_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm); | args.dst_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm); | ||||
| valid_format |= param.format == Format::NCHW4; | valid_format |= param.format == Format::NCHW4; | ||||
| if (!valid_format) return false; | |||||
| if (!valid_format) | |||||
| return false; | |||||
| size_t n = args.src_layout->operator[](0), | size_t n = args.src_layout->operator[](0), | ||||
| ci = args.src_layout->operator[](1) * 4, | ci = args.src_layout->operator[](1) * 4, | ||||
| hi = args.src_layout->operator[](2), | hi = args.src_layout->operator[](2), | ||||
| @@ -170,16 +170,13 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec( | |||||
| args.preprocessed_filter->tensors[0].raw_ptr); | args.preprocessed_filter->tensors[0].raw_ptr); | ||||
| } | } | ||||
| convolution::ConvParam kern_param; | |||||
| kern_param.n = n, kern_param.co = co, kern_param.ci = ci, | |||||
| kern_param.hi = hi, kern_param.wi = wi, kern_param.ho = ho, | |||||
| kern_param.wo = wo, kern_param.ph = ph, kern_param.pw = pw, | |||||
| kern_param.sh = sh, kern_param.sw = sw, kern_param.fh = fh, | |||||
| kern_param.fw = fw; | |||||
| float src_scale = args.src_layout->dtype.param<dtype::QuantizedS8>().scale, | float src_scale = args.src_layout->dtype.param<dtype::QuantizedS8>().scale, | ||||
| filter_scale = | filter_scale = | ||||
| args.filter_layout->dtype.param<dtype::QuantizedS8>().scale; | args.filter_layout->dtype.param<dtype::QuantizedS8>().scale; | ||||
| // \note these constants of the cutlass epilogue are passed to the method | |||||
| // `execute_cutlass_conv_op` by pointer and interpreted as ElementCompute*; | |||||
| // using a different dtype here results in undefined epilogue behavior | |||||
| float alpha = src_scale * filter_scale; | float alpha = src_scale * filter_scale; | ||||
| float beta = 1.f; | float beta = 1.f; | ||||
| float dst_scale = 1.f; | float dst_scale = 1.f; | ||||
| @@ -192,13 +189,15 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec( | |||||
| if (args.bias_layout->dtype.enumv() == DTypeEnum::QuantizedS32) { | if (args.bias_layout->dtype.enumv() == DTypeEnum::QuantizedS32) { | ||||
| megdnn_assert(args.dst_layout->dtype.category() == | megdnn_assert(args.dst_layout->dtype.category() == | ||||
| DTypeCategory::QUANTIZED); | DTypeCategory::QUANTIZED); | ||||
| float bias_scale = args.bias_layout->dtype.param<dtype::QuantizedS32>() | |||||
| .scale; | |||||
| float bias_scale = | |||||
| args.bias_layout->dtype.param<dtype::QuantizedS32>().scale; | |||||
| dst_scale = get_scale(args.dst_layout->dtype); | dst_scale = get_scale(args.dst_layout->dtype); | ||||
| alpha /= dst_scale, beta = bias_scale / dst_scale; | alpha /= dst_scale, beta = bias_scale / dst_scale; | ||||
| } | } | ||||
| float delta = 0.f; | float delta = 0.f; | ||||
| void* z_ptr = nullptr; | |||||
| if (args.z_layout->ndim > 0) { | if (args.z_layout->ndim > 0) { | ||||
| z_ptr = args.z_tensor->raw_ptr; | |||||
| gamma = 1.f; | gamma = 1.f; | ||||
| if (args.z_layout->dtype.category() == DTypeCategory::QUANTIZED) { | if (args.z_layout->dtype.category() == DTypeCategory::QUANTIZED) { | ||||
| megdnn_assert(args.dst_layout->dtype.category() == | megdnn_assert(args.dst_layout->dtype.category() == | ||||
| @@ -213,98 +212,20 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec( | |||||
| delta = -z_zero * gamma; | delta = -z_zero * gamma; | ||||
| } | } | ||||
| } | } | ||||
| uint32_t nonlinear_mode = static_cast<uint32_t>(param.nonlineMode); | |||||
| bool nonunity_kernel = !(fh == 1 && fw == 1); | |||||
| #define DISPATCH(_nonunity_kernel) \ | |||||
| if (nonunity_kernel == _nonunity_kernel) { \ | |||||
| cb(_nonunity_kernel) \ | |||||
| } | |||||
| if (param.format == Format::NCHW4) { | |||||
| #define cb(_nonunity_kernel) \ | |||||
| cutlass_wrapper::do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4< \ | |||||
| _nonunity_kernel>( \ | |||||
| args.src_tensor->compatible_ptr<int8_t>(), filter_ptr, \ | |||||
| args.bias_tensor->compatible_ptr<int32_t>(), \ | |||||
| args.z_tensor->compatible_ptr<int8_t>(), \ | |||||
| args.dst_tensor->compatible_ptr<int8_t>(), nullptr, kern_param, \ | |||||
| nonlinear_mode, alpha, beta, gamma, dst_scale, \ | |||||
| cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, \ | |||||
| m_algo_param.threadblock_n, \ | |||||
| m_algo_param.threadblock_k}, \ | |||||
| cutlass_wrapper::GemmCoord{m_algo_param.warp_m, \ | |||||
| m_algo_param.warp_n, \ | |||||
| m_algo_param.warp_k}, \ | |||||
| m_algo_param.stage, stream); | |||||
| DISPATCH(true); | |||||
| DISPATCH(false); | |||||
| #undef cb | |||||
| } else if (param.format == Format::NCHW4_NCHW) { | |||||
| #define cb(_nonunity_kernel) \ | |||||
| cutlass_wrapper::do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nchw< \ | |||||
| _nonunity_kernel>( \ | |||||
| args.src_tensor->compatible_ptr<int8_t>(), filter_ptr, \ | |||||
| args.bias_tensor->compatible_ptr<float>(), \ | |||||
| args.z_tensor->compatible_ptr<float>(), \ | |||||
| args.dst_tensor->compatible_ptr<float>(), nullptr, kern_param, \ | |||||
| nonlinear_mode, alpha, beta, gamma, dst_scale, \ | |||||
| cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, \ | |||||
| m_algo_param.threadblock_n, \ | |||||
| m_algo_param.threadblock_k}, \ | |||||
| cutlass_wrapper::GemmCoord{m_algo_param.warp_m, \ | |||||
| m_algo_param.warp_n, \ | |||||
| m_algo_param.warp_k}, \ | |||||
| m_algo_param.stage, stream); | |||||
| DISPATCH(true); | |||||
| DISPATCH(false); | |||||
| #undef cb | |||||
| } else if (param.format == Format::NCHW4_NHWC) { | |||||
| #define cb(_signedness) \ | |||||
| cutlass_wrapper::do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nhwc< \ | |||||
| _signedness>( \ | |||||
| args.src_tensor->compatible_ptr<int8_t>(), filter_ptr, \ | |||||
| args.bias_tensor->compatible_ptr<int32_t>(), \ | |||||
| reinterpret_cast<int8_t*>(args.z_tensor->raw_ptr), \ | |||||
| reinterpret_cast<int8_t*>(args.dst_tensor->raw_ptr), nullptr, \ | |||||
| kern_param, nonlinear_mode, alpha, beta, gamma, delta, theta, \ | |||||
| dst_scale, \ | |||||
| cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, \ | |||||
| m_algo_param.threadblock_n, \ | |||||
| m_algo_param.threadblock_k}, \ | |||||
| cutlass_wrapper::GemmCoord{m_algo_param.warp_m, \ | |||||
| m_algo_param.warp_n, \ | |||||
| m_algo_param.warp_k}, \ | |||||
| m_algo_param.stage, stream); | |||||
| if (args.dst_layout->dtype.enumv() == DTypeEnum::QuantizedS4) { | |||||
| cb(true); | |||||
| } else { | |||||
| megdnn_assert(args.dst_layout->dtype.enumv() == | |||||
| DTypeEnum::Quantized4Asymm); | |||||
| cb(false); | |||||
| } | |||||
| #undef cb | |||||
| } else { | |||||
| megdnn_assert(param.format == Format::NCHW4_NCHW32); | |||||
| #define cb(_nonunity_kernel) \ | |||||
| cutlass_wrapper:: \ | |||||
| do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32< \ | |||||
| _nonunity_kernel>( \ | |||||
| args.src_tensor->compatible_ptr<int8_t>(), filter_ptr, \ | |||||
| args.bias_tensor->compatible_ptr<int32_t>(), \ | |||||
| args.z_tensor->compatible_ptr<int8_t>(), \ | |||||
| args.dst_tensor->compatible_ptr<int8_t>(), nullptr, \ | |||||
| kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale, \ | |||||
| cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, \ | |||||
| m_algo_param.threadblock_n, \ | |||||
| m_algo_param.threadblock_k}, \ | |||||
| cutlass_wrapper::GemmCoord{m_algo_param.warp_m, \ | |||||
| m_algo_param.warp_n, \ | |||||
| m_algo_param.warp_k}, \ | |||||
| m_algo_param.stage, stream); | |||||
| DISPATCH(true); | |||||
| DISPATCH(false); | |||||
| #undef cb | |||||
| #undef DISPATCH | |||||
| } | |||||
| float threshold = 0.f; | |||||
| bool load_from_const = !(fh == 1 && fw == 1); | |||||
| bool without_shared_load = false; | |||||
| const auto* op = get_cutlass_conv_op(args, ConvOperator::kFprop, | |||||
| ConvType::kConvolution, | |||||
| load_from_const, without_shared_load); | |||||
| execute_cutlass_conv_op( | |||||
| op, args.src_tensor->raw_ptr, filter_ptr, args.bias_tensor->raw_ptr, | |||||
| z_ptr, args.dst_tensor->raw_ptr, nullptr, n, hi, wi, ci, co, fh, fw, | |||||
| ho, wo, ph, pw, sh, sw, dh, dw, &alpha, &beta, &gamma, &delta, | |||||
| &theta, &threshold, &dst_scale, stream); | |||||
| after_kernel_launch(); | after_kernel_launch(); | ||||
| } | } | ||||
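The hunk above replaces the per-format macro dispatch with one runtime lookup: get_cutlass_conv_op resolves a type-erased kernel from an operation table and execute_cutlass_conv_op hands it the problem size plus the epilogue scalars by pointer. A minimal sketch of that registry pattern, with Key, Kernel, registry, and find_op as hypothetical stand-ins rather than the MegEngine API:

#include <functional>
#include <map>
#include <tuple>

// Hypothetical stand-ins for ConvolutionKey / the cutlass operation table.
using Key = std::tuple<int, int, int, int>;  // tb_m, tb_n, tb_k, stages
using Kernel = std::function<void(const void* src, void* dst)>;

std::map<Key, Kernel>& registry() {
    static std::map<Key, Kernel> table;  // populated once by registrars
    return table;
}

const Kernel* find_op(const Key& key) {
    auto it = registry().find(key);
    return it == registry().end() ? nullptr : &it->second;
}

// exec-time use: if (const Kernel* op = find_op({16, 64, 8, 2})) (*op)(src, dst);

The payoff is that adding a tile configuration no longer touches every call site; it only adds one entry to the table.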
| @@ -10,8 +10,7 @@ | |||||
| * implied. | * implied. | ||||
| */ | */ | ||||
| #include "./algo.h" | |||||
| #include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | |||||
| #include "src/cuda/conv_bias/algo.h" | |||||
| #include "src/cuda/conv_bias/reduce_filter.cuh" | #include "src/cuda/conv_bias/reduce_filter.cuh" | ||||
| #include "src/cuda/utils.h" | #include "src/cuda/utils.h" | ||||
| @@ -120,32 +119,15 @@ ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm::get_constants( | |||||
| delta = -z_zero * gamma; | delta = -z_zero * gamma; | ||||
| } | } | ||||
| return {alpha, beta, gamma, delta, theta}; | |||||
| } | |||||
| // identity epilogue has no theta: | |||||
| // alpha * accumulator + beta * bias + gamma * source + delta | |||||
| if (args.opr->param().nonlineMode == | |||||
| param::ConvBias::NonlineMode::IDENTITY) { | |||||
| delta += theta; | |||||
| theta = 0.f; | |||||
| } | |||||
| void ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm::do_exec( | |||||
| const ExecArgs& args, void* filter_ptr, void* bias_ptr, void* z_ptr, | |||||
| ConvParam kern_param, uint32_t nonlinear_mode, float alpha, float beta, | |||||
| float gamma, float delta, float theta, cudaStream_t stream) const { | |||||
| float dst_scale = | |||||
| args.dst_layout->dtype.param<dtype::Quantized4Asymm>().scale; | |||||
| uint8_t src_zero = | |||||
| args.src_layout->dtype.param<dtype::Quantized4Asymm>().zero_point; | |||||
| cutlass_wrapper::GemmCoord threadblock_shape{m_algo_param.threadblock_m, | |||||
| m_algo_param.threadblock_n, | |||||
| m_algo_param.threadblock_k}; | |||||
| cutlass_wrapper::GemmCoord warp_shape{ | |||||
| m_algo_param.warp_m, m_algo_param.warp_n, m_algo_param.warp_k}; | |||||
| cutlass_wrapper::do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64< | |||||
| true>(reinterpret_cast<uint8_t*>(args.src_tensor->raw_ptr), | |||||
| reinterpret_cast<int8_t*>(filter_ptr), | |||||
| reinterpret_cast<int32_t*>(bias_ptr), | |||||
| reinterpret_cast<uint8_t*>(z_ptr), | |||||
| reinterpret_cast<uint8_t*>(args.dst_tensor->raw_ptr), nullptr, | |||||
| kern_param, nonlinear_mode, alpha, beta, gamma, delta, theta, | |||||
| dst_scale, src_zero, threadblock_shape, warp_shape, | |||||
| m_algo_param.stage, stream); | |||||
| return {alpha, beta, gamma, delta, theta}; | |||||
| } | } | ||||
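To make the fold in get_constants above explicit (a restatement of the comment, not new behavior): the identity epilogue evaluates

dst = alpha * acc + beta * bias + gamma * z + delta

and has no slot for theta, so a nonzero theta can only take effect by being absorbed into the constant term, delta' = delta + theta with theta' = 0, which is exactly what the IDENTITY branch performs before the constants are returned.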
| void ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm::update_bias( | void ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm::update_bias( | ||||
| @@ -10,8 +10,7 @@ | |||||
| * implied. | * implied. | ||||
| */ | */ | ||||
| #include "./algo.h" | |||||
| #include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | |||||
| #include "src/cuda/conv_bias/algo.h" | |||||
| #include "src/cuda/conv_bias/reduce_filter.cuh" | #include "src/cuda/conv_bias/reduce_filter.cuh" | ||||
| #include "src/cuda/utils.h" | #include "src/cuda/utils.h" | ||||
| @@ -121,44 +120,15 @@ ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm::get_constants( | |||||
| delta = -z_zero * gamma; | delta = -z_zero * gamma; | ||||
| } | } | ||||
| return {alpha, beta, gamma, delta, theta}; | |||||
| } | |||||
| void ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm::do_exec( | |||||
| const ExecArgs& args, void* filter_ptr, void* bias_ptr, void* z_ptr, | |||||
| ConvParam kern_param, uint32_t nonlinear_mode, float alpha, float beta, | |||||
| float gamma, float delta, float theta, cudaStream_t stream) const { | |||||
| float dst_scale = | |||||
| args.dst_layout->dtype.param<dtype::Quantized4Asymm>().scale; | |||||
| uint8_t src_zero = | |||||
| args.src_layout->dtype.param<dtype::Quantized4Asymm>().zero_point; | |||||
| cutlass_wrapper::GemmCoord threadblock_shape{m_algo_param.threadblock_m, | |||||
| m_algo_param.threadblock_n, | |||||
| m_algo_param.threadblock_k}; | |||||
| cutlass_wrapper::GemmCoord warp_shape{ | |||||
| m_algo_param.warp_m, m_algo_param.warp_n, m_algo_param.warp_k}; | |||||
| if (kern_param.fh == 1 && kern_param.fw == 1) { | |||||
| cutlass_wrapper::do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc<false>( | |||||
| reinterpret_cast<uint8_t*>(args.src_tensor->raw_ptr), | |||||
| reinterpret_cast<int8_t*>(filter_ptr), | |||||
| reinterpret_cast<int32_t*>(bias_ptr), | |||||
| reinterpret_cast<uint8_t*>(z_ptr), | |||||
| reinterpret_cast<uint8_t*>(args.dst_tensor->raw_ptr), nullptr, | |||||
| kern_param, nonlinear_mode, alpha, beta, gamma, delta, theta, | |||||
| dst_scale, src_zero, threadblock_shape, warp_shape, | |||||
| m_algo_param.access_size, m_algo_param.stage, stream); | |||||
| } else { | |||||
| cutlass_wrapper::do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc<true>( | |||||
| reinterpret_cast<uint8_t*>(args.src_tensor->raw_ptr), | |||||
| reinterpret_cast<int8_t*>(filter_ptr), | |||||
| reinterpret_cast<int32_t*>(bias_ptr), | |||||
| reinterpret_cast<uint8_t*>(z_ptr), | |||||
| reinterpret_cast<uint8_t*>(args.dst_tensor->raw_ptr), nullptr, | |||||
| kern_param, nonlinear_mode, alpha, beta, gamma, delta, theta, | |||||
| dst_scale, src_zero, threadblock_shape, warp_shape, | |||||
| m_algo_param.access_size, m_algo_param.stage, stream); | |||||
| // identity epilogue has no theta: | |||||
| // alpha * accumulator + beta * bias + gamma * source + delta | |||||
| if (args.opr->param().nonlineMode == | |||||
| param::ConvBias::NonlineMode::IDENTITY) { | |||||
| delta += theta; | |||||
| theta = 0.f; | |||||
| } | } | ||||
| return {alpha, beta, gamma, delta, theta}; | |||||
| } | } | ||||
| void ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm::update_bias( | void ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm::update_bias( | ||||
| @@ -57,6 +57,7 @@ public: | |||||
| class AlgoBatchedMatmul; | class AlgoBatchedMatmul; | ||||
| class AlgoGroupConvGeneral; | class AlgoGroupConvGeneral; | ||||
| class AlgoQUInt4x4x32WMMA; | class AlgoQUInt4x4x32WMMA; | ||||
| class AlgoCutlassConvolutionBase; | |||||
| class AlgoInt8CHWN4DotProdImplicitGemm; | class AlgoInt8CHWN4DotProdImplicitGemm; | ||||
| class AlgoInt8NCHW4DotProdImplicitGemm; | class AlgoInt8NCHW4DotProdImplicitGemm; | ||||
| class AlgoInt8CHWN4IMMAImplicitGemm; | class AlgoInt8CHWN4IMMAImplicitGemm; | ||||
| @@ -1,100 +0,0 @@ | |||||
| /** | |||||
| * \file src/cuda/convolution/backward_data/cutlass_deconvolution_wrapper.cu | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| // ignore warnings from cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #if !MEGDNN_TEGRA_X1 | |||||
| #include "cutlass/convolution/device/convolution.h" | |||||
| #endif | |||||
| #include "src/common/opr_param_defs_enumv.cuh" | |||||
| #include "src/cuda/convolution/backward_data/cutlass_deconvolution_wrapper.cuh" | |||||
| #pragma GCC diagnostic pop | |||||
| using namespace megdnn; | |||||
| using namespace cuda; | |||||
| using namespace cutlass_wrapper; | |||||
| /* ================ cutlass kernel wrapper for nchw4 layout ================= */ | |||||
| #if MEGDNN_TEGRA_X1 | |||||
| void megdnn::cuda::cutlass_wrapper::do_deconv_int8_implicit_gemm_dp4a_ncdiv4hw4( | |||||
| const int8_t* /* d_src */, const int8_t* /* d_filter */, | |||||
| int8_t* /* d_dst */, int* /* workspace */, | |||||
| const convolution::ConvParam& /* param */, float /* alpha */, | |||||
| const GemmCoord& /* threadblock_shape */, | |||||
| const GemmCoord& /* warp_shape */, int /* stages */, | |||||
| cudaStream_t /* stream */) {} | |||||
| #else | |||||
| void megdnn::cuda::cutlass_wrapper::do_deconv_int8_implicit_gemm_dp4a_ncdiv4hw4( | |||||
| const int8_t* d_src, const int8_t* d_filter, int8_t* d_dst, | |||||
| int* workspace, const convolution::ConvParam& param, float alpha, | |||||
| const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | |||||
| int stages, cudaStream_t stream) { | |||||
| #define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||||
| threadblock_k_, warp_m_, warp_n_, \ | |||||
| warp_k_, stage_, aligned_) \ | |||||
| if (threadblock_shape.m() == threadblock_m_ && \ | |||||
| threadblock_shape.n() == threadblock_n_ && \ | |||||
| threadblock_shape.k() == threadblock_k_ && \ | |||||
| warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||||
| warp_shape.k() == warp_k_ && stages == stage_) { \ | |||||
| using ThreadBlockShape = \ | |||||
| cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||||
| threadblock_k_>; \ | |||||
| using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||||
| using InstructionShape = cutlass::gemm::GemmShape<1, 1, 4>; \ | |||||
| using Deconvolution = cutlass::conv::device::Deconvolution< \ | |||||
| int8_t, cutlass::layout::TensorNCxHWx<4>, int8_t, \ | |||||
| cutlass::layout::TensorKxRSCx<4>, ElementOutput, \ | |||||
| cutlass::layout::TensorNCxHWx<4>, int32_t, \ | |||||
| cutlass::layout::TensorNCxHWx<4>, int32_t, \ | |||||
| cutlass::arch::OpClassSimt, cutlass::arch::Sm61, \ | |||||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||||
| cutlass::conv::threadblock:: \ | |||||
| ConvolutionDgradNCxHWxThreadblockSwizzle, \ | |||||
| stage_, 4, aligned_, true, cutlass::arch::OpMultiplyAdd>; \ | |||||
| typename Deconvolution::ConvolutionParameter conv_param( \ | |||||
| param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||||
| param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||||
| param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||||
| return cutlass_deconvolution_wrapper<Deconvolution>( \ | |||||
| d_src, d_filter, nullptr, nullptr, d_dst, workspace, \ | |||||
| conv_param, epilogue, stream); \ | |||||
| } | |||||
| #define DISPATCH_KERNEL \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 2, 4); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 64, 16, 2, 4); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 128, 16, 1, 8); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \ | |||||
| megdnn_assert(false, \ | |||||
| "unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||||
| "(%dx%dx%d)", \ | |||||
| threadblock_shape.m(), threadblock_shape.n(), \ | |||||
| threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||||
| warp_shape.k()); | |||||
| using ElementOutput = int8_t; | |||||
| using ElementAccumulator = int32_t; | |||||
| using ElementBias = int32_t; | |||||
| using ElementCompute = float; | |||||
| using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||||
| ElementOutput, 4, ElementAccumulator, ElementBias, ElementCompute>; | |||||
| typename EpilogueOp::Params epilogue{alpha, 0, 0}; | |||||
| DISPATCH_KERNEL; | |||||
| #undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||||
| #undef DISPATCH_KERNEL | |||||
| } | |||||
| #endif | |||||
| // vim: syntax=cuda.doxygen | |||||
| @@ -1,44 +0,0 @@ | |||||
| /** | |||||
| * \file src/cuda/convolution/backward_data/cutlass_deconvolution_wrapper.cuh | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include "cutlass/gemm/gemm.h" | |||||
| #include "src/cuda/convolution_helper/parameter.cuh" | |||||
| #include "src/cuda/utils.cuh" | |||||
| namespace megdnn { | |||||
| namespace cuda { | |||||
| namespace cutlass_wrapper { | |||||
| using GemmCoord = cutlass::gemm::GemmCoord; | |||||
| template <typename Convolution> | |||||
| void cutlass_deconvolution_wrapper( | |||||
| const typename Convolution::ElementSrc* d_src, | |||||
| const typename Convolution::ElementFilter* d_filter, | |||||
| const typename Convolution::ElementBias* d_bias, | |||||
| const typename Convolution::ElementDst* d_z, | |||||
| typename Convolution::ElementDst* d_dst, int* workspace, | |||||
| typename Convolution::ConvolutionParameter const& conv_param, | |||||
| typename Convolution::EpilogueOutputOp::Params const& epilogue, | |||||
| cudaStream_t stream); | |||||
| void do_deconv_int8_implicit_gemm_dp4a_ncdiv4hw4( | |||||
| const int8_t* d_src, const int8_t* d_filter, int8_t* d_dst, | |||||
| int* workspace, const convolution::ConvParam& param, float alpha, | |||||
| const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | |||||
| int stages, cudaStream_t stream); | |||||
| } // namespace cutlass_wrapper | |||||
| } // namespace cuda | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cuda.doxygen | |||||
| @@ -1,62 +0,0 @@ | |||||
| /** | |||||
| * \file | |||||
| * dnn/src/cuda/convolution/backward_data/implicit_gemm_deconv_cutlass_wrapper.cuinl | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "cutlass/convolution/device/convolution.h" | |||||
| #include "src/cuda/convolution/backward_data/cutlass_deconvolution_wrapper.cuh" | |||||
| using namespace megdnn; | |||||
| using namespace cuda; | |||||
| using namespace cutlass_wrapper; | |||||
| template <typename Deconvolution> | |||||
| void megdnn::cuda::cutlass_wrapper::cutlass_deconvolution_wrapper( | |||||
| const typename Deconvolution::ElementSrc* d_src, | |||||
| const typename Deconvolution::ElementFilter* d_filter, | |||||
| const typename Deconvolution::ElementBias* d_bias, | |||||
| const typename Deconvolution::ElementDst* d_z, | |||||
| typename Deconvolution::ElementDst* d_dst, int* workspace, | |||||
| typename Deconvolution::ConvolutionParameter const& conv_param, | |||||
| typename Deconvolution::EpilogueOutputOp::Params const& epilogue, | |||||
| cudaStream_t stream) { | |||||
| typename Deconvolution::TensorRefSrc tensor_src{ | |||||
| const_cast<typename Deconvolution::ElementSrc*>(d_src), | |||||
| Deconvolution::LayoutSrc::packed( | |||||
| {conv_param.N, conv_param.P, conv_param.Q, conv_param.K})}; | |||||
| typename Deconvolution::TensorRefFilter tensor_filter{ | |||||
| const_cast<typename Deconvolution::ElementFilter*>(d_filter), | |||||
| Deconvolution::LayoutFilter::packed( | |||||
| {conv_param.K, conv_param.R, conv_param.S, conv_param.C})}; | |||||
| typename Deconvolution::TensorRefBias tensor_bias{ | |||||
| const_cast<typename Deconvolution::ElementBias*>(d_bias), | |||||
| Deconvolution::LayoutBias::packed({1, 1, 1, conv_param.K})}; | |||||
| typename Deconvolution::TensorRefDst tensor_z{ | |||||
| const_cast<typename Deconvolution::ElementDst*>(d_z), | |||||
| Deconvolution::LayoutDst::packed( | |||||
| {conv_param.N, conv_param.H, conv_param.W, conv_param.C})}; | |||||
| typename Deconvolution::TensorRefDst tensor_dst{ | |||||
| d_dst, | |||||
| Deconvolution::LayoutDst::packed( | |||||
| {conv_param.N, conv_param.H, conv_param.W, conv_param.C})}; | |||||
| typename Deconvolution::Arguments arguments{conv_param, | |||||
| tensor_src.non_const_ref(), | |||||
| tensor_filter.non_const_ref(), | |||||
| tensor_bias.non_const_ref(), | |||||
| tensor_z.non_const_ref(), | |||||
| tensor_dst.non_const_ref(), | |||||
| epilogue}; | |||||
| Deconvolution deconv_op; | |||||
| cutlass_check(deconv_op.initialize(arguments, workspace)); | |||||
| cutlass_check(deconv_op(stream)); | |||||
| after_kernel_launch(); | |||||
| } | |||||
| // vim: syntax=cuda.doxygen | |||||
| @@ -1,5 +1,6 @@ | |||||
| /** | /** | ||||
| * \file dnn/src/cuda/convolution/backward_data/implicit_gemm_int8_nchw4_dp4a.cpp | |||||
| * \file | |||||
| * dnn/src/cuda/convolution/backward_data/implicit_gemm_int8_nchw4_dp4a.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
| * | * | ||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
| @@ -10,11 +11,11 @@ | |||||
| * implied. | * implied. | ||||
| */ | */ | ||||
| #include "./algo.h" | |||||
| #include "src/cuda/utils.h" | |||||
| #include "src/cuda/convolution_helper/parameter.cuh" | |||||
| #include "src/cuda/convolution/backward_data/cutlass_deconvolution_wrapper.cuh" | |||||
| #include "src/cuda/convolution/backward_data/algo.h" | |||||
| #include "src/cuda/convolution/backward_data/deconv_int8_helper.cuh" | #include "src/cuda/convolution/backward_data/deconv_int8_helper.cuh" | ||||
| #include "src/cuda/convolution_helper/parameter.cuh" | |||||
| #include "src/cuda/cutlass/singleton.h" | |||||
| #include "src/cuda/utils.h" | |||||
| using namespace megdnn; | using namespace megdnn; | ||||
| using namespace cuda; | using namespace cuda; | ||||
| @@ -70,6 +71,7 @@ size_t ConvolutionBackwardDataImpl::AlgoInt8NCHW4DotProdImplicitGemm:: | |||||
| void ConvolutionBackwardDataImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec( | void ConvolutionBackwardDataImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec( | ||||
| const ExecArgs& args) const { | const ExecArgs& args) const { | ||||
| auto&& param = args.opr->param(); | |||||
| auto&& fm = args.filter_meta; | auto&& fm = args.filter_meta; | ||||
| size_t n = args.diff_layout->operator[](0), | size_t n = args.diff_layout->operator[](0), | ||||
| co = args.diff_layout->operator[](1) * 4, | co = args.diff_layout->operator[](1) * 4, | ||||
| @@ -81,6 +83,7 @@ void ConvolutionBackwardDataImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec( | |||||
| size_t fh = fm.spatial[0], fw = fm.spatial[1]; | size_t fh = fm.spatial[0], fw = fm.spatial[1]; | ||||
| size_t sh = fm.stride[0], sw = fm.stride[1]; | size_t sh = fm.stride[0], sw = fm.stride[1]; | ||||
| size_t ph = fm.padding[0], pw = fm.padding[1]; | size_t ph = fm.padding[0], pw = fm.padding[1]; | ||||
| size_t dh = param.dilate_h, dw = param.dilate_w; | |||||
| auto&& stream = cuda_stream(args.opr->handle()); | auto&& stream = cuda_stream(args.opr->handle()); | ||||
| @@ -93,12 +96,6 @@ void ConvolutionBackwardDataImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec( | |||||
| filter_ptr, args.filter_tensor->compatible_ptr<int8_t>(), co, | filter_ptr, args.filter_tensor->compatible_ptr<int8_t>(), co, | ||||
| ci, fh, fw, stream); | ci, fh, fw, stream); | ||||
| } | } | ||||
| convolution::ConvParam kern_param; | |||||
| kern_param.n = n, kern_param.co = co, kern_param.ci = ci, | |||||
| kern_param.hi = hi, kern_param.wi = wi, kern_param.ho = ho, | |||||
| kern_param.wo = wo, kern_param.ph = ph, kern_param.pw = pw, | |||||
| kern_param.sh = sh, kern_param.sw = sw, kern_param.fh = fh, | |||||
| kern_param.fw = fw; | |||||
| float diff_scale = | float diff_scale = | ||||
| args.diff_layout->dtype.param<dtype::QuantizedS8>().scale, | args.diff_layout->dtype.param<dtype::QuantizedS8>().scale, | ||||
| @@ -106,17 +103,60 @@ void ConvolutionBackwardDataImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec( | |||||
| args.filter_layout->dtype.param<dtype::QuantizedS8>().scale, | args.filter_layout->dtype.param<dtype::QuantizedS8>().scale, | ||||
| grad_scale = | grad_scale = | ||||
| args.grad_layout->dtype.param<dtype::QuantizedS8>().scale; | args.grad_layout->dtype.param<dtype::QuantizedS8>().scale; | ||||
| float alpha = diff_scale * filter_scale / grad_scale; | |||||
| cutlass_wrapper::do_deconv_int8_implicit_gemm_dp4a_ncdiv4hw4( | |||||
| args.diff_tensor->compatible_ptr<int8_t>(), filter_ptr, | |||||
| args.grad_tensor->compatible_ptr<int8_t>(), nullptr, kern_param, | |||||
| alpha, | |||||
| cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, | |||||
| m_algo_param.threadblock_n, | |||||
| m_algo_param.threadblock_k}, | |||||
| cutlass_wrapper::GemmCoord{m_algo_param.warp_m, m_algo_param.warp_n, | |||||
| m_algo_param.warp_k}, | |||||
| m_algo_param.stage, stream); | |||||
| // \note these cutlass epilogue constants are passed to struct | |||||
| // `ConvolutionArguments` by pointer and read as ElementCompute*; a | |||||
| // mismatched dtype here results in undefined epilogue behavior | |||||
| float alpha = diff_scale * filter_scale / grad_scale, beta = 0.f, | |||||
| gamma = 0.f, delta = 0.f; | |||||
| using namespace cutlass::library; | |||||
| // tile/warp shapes and the stage count are taken from m_algo_param | |||||
| ConvolutionKey key{ | |||||
| cutlass::conv::Operator::kDgrad, | |||||
| NumericTypeID::kS8, | |||||
| LayoutTypeID::kTensorNC4HW4, | |||||
| NumericTypeID::kS8, | |||||
| LayoutTypeID::kTensorK4RSC4, | |||||
| NumericTypeID::kS8, | |||||
| LayoutTypeID::kTensorNC4HW4, | |||||
| NumericTypeID::kS32, | |||||
| LayoutTypeID::kTensorNC4HW4, | |||||
| cutlass::conv::ConvType::kConvolution, | |||||
| m_algo_param.threadblock_m, | |||||
| m_algo_param.threadblock_n, | |||||
| m_algo_param.threadblock_k, | |||||
| m_algo_param.warp_m, | |||||
| m_algo_param.warp_n, | |||||
| m_algo_param.warp_k, | |||||
| 1, | |||||
| 1, | |||||
| 4, | |||||
| cutlass::epilogue::EpilogueType::kBiasAddLinearCombinationClamp, | |||||
| m_algo_param.stage, | |||||
| true, | |||||
| false}; | |||||
| const Operation* op = Singleton::get().operation_table.find_op(key); | |||||
| // gcc prints warnings when size_t values are implicitly narrowed to int | |||||
| cutlass::conv::Conv2dProblemSize problem_size{ | |||||
| int(n), int(hi), int(wi), int(ci), | |||||
| int(co), int(fh), int(fw), int(ho), | |||||
| int(wo), int(ph), int(pw), int(sh), | |||||
| int(sw), int(dh), int(dw), cutlass::conv::Mode::kCrossCorrelation}; | |||||
| cutlass::library::ConvolutionArguments conv_args{ | |||||
| problem_size, args.diff_tensor->compatible_ptr<int8_t>(), | |||||
| filter_ptr, nullptr, | |||||
| nullptr, args.grad_tensor->compatible_ptr<int8_t>(), | |||||
| &alpha, &beta, | |||||
| &gamma, &delta, | |||||
| nullptr, nullptr, | |||||
| nullptr, nullptr}; | |||||
| cutlass_check(op->run(&conv_args, nullptr, stream)); | |||||
| after_kernel_launch(); | after_kernel_launch(); | ||||
| } | } | ||||
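On the narrowing comment above: brace-initializing Conv2dProblemSize's int fields from size_t values is a narrowing conversion, hence the explicit int(...) casts. A self-contained illustration, with Extent and demo as hypothetical names:

#include <cstddef>

struct Extent { int n; };

void demo(std::size_t n) {
    // Extent bad{n};    // narrowing in list-initialization: rejected/warned
    Extent ok{int(n)};   // explicit cast: compiles quietly
    (void)ok;
}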
| @@ -11,16 +11,16 @@ | |||||
| * implied. | * implied. | ||||
| */ | */ | ||||
| #include "./algo.h" | |||||
| #include "src/cuda/utils.h" | |||||
| #include "src/cuda/convolution/backward_data/algo.h" | |||||
| #include "src/cuda/convolution_helper/parameter.cuh" | #include "src/cuda/convolution_helper/parameter.cuh" | ||||
| #include "src/cuda/convolution/backward_data/cutlass_deconvolution_wrapper.cuh" | |||||
| #include "src/cuda/cutlass/singleton.h" | |||||
| #include "src/cuda/utils.h" | |||||
| using namespace megdnn; | using namespace megdnn; | ||||
| using namespace cuda; | using namespace cuda; | ||||
| bool ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm:: | |||||
| is_available(const SizeArgs& args) const { | |||||
| bool ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm::is_available( | |||||
| const SizeArgs& args) const { | |||||
| auto&& fm = args.filter_meta; | auto&& fm = args.filter_meta; | ||||
| if (fm.format != Param::Format::NCHW) | if (fm.format != Param::Format::NCHW) | ||||
| return false; | return false; | ||||
| @@ -42,7 +42,8 @@ bool ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm:: | |||||
| // TODO support group deconv int8 | // TODO support group deconv int8 | ||||
| available &= (fm.group == 1); | available &= (fm.group == 1); | ||||
| // ic and oc must be multiples of 4 | // ic and oc must be multiples of 4 | ||||
| available &= ((fm.group * fm.icpg) % 4 == 0 && (fm.group * fm.ocpg) % 4 == 0); | |||||
| available &= | |||||
| ((fm.group * fm.icpg) % 4 == 0 && (fm.group * fm.ocpg) % 4 == 0); | |||||
| // mode must be cross correlation | // mode must be cross correlation | ||||
| available &= !fm.should_flip; | available &= !fm.should_flip; | ||||
| // mode must be 2D | // mode must be 2D | ||||
| @@ -73,6 +74,7 @@ size_t ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm:: | |||||
| void ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm::exec( | void ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm::exec( | ||||
| const ExecArgs& args) const { | const ExecArgs& args) const { | ||||
| auto&& param = args.opr->param(); | |||||
| auto&& fm = args.filter_meta; | auto&& fm = args.filter_meta; | ||||
| size_t n = args.diff_layout->operator[](0), | size_t n = args.diff_layout->operator[](0), | ||||
| co = args.diff_layout->operator[](1), | co = args.diff_layout->operator[](1), | ||||
| @@ -84,6 +86,7 @@ void ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm::exec( | |||||
| size_t fh = fm.spatial[0], fw = fm.spatial[1]; | size_t fh = fm.spatial[0], fw = fm.spatial[1]; | ||||
| size_t sh = fm.stride[0], sw = fm.stride[1]; | size_t sh = fm.stride[0], sw = fm.stride[1]; | ||||
| size_t ph = fm.padding[0], pw = fm.padding[1]; | size_t ph = fm.padding[0], pw = fm.padding[1]; | ||||
| size_t dh = param.dilate_h, dw = param.dilate_w; | |||||
| auto&& stream = cuda_stream(args.opr->handle()); | auto&& stream = cuda_stream(args.opr->handle()); | ||||
| @@ -120,26 +123,63 @@ void ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm::exec( | |||||
| } | } | ||||
| int8_t* inner_grad_ptr = reinterpret_cast<int8_t*>(bundle.get(2)); | int8_t* inner_grad_ptr = reinterpret_cast<int8_t*>(bundle.get(2)); | ||||
| convolution::ConvParam kern_param; | |||||
| kern_param.n = n, kern_param.co = co, kern_param.ci = ci, | |||||
| kern_param.hi = hi, kern_param.wi = wi, kern_param.ho = ho, | |||||
| kern_param.wo = wo, kern_param.ph = ph, kern_param.pw = pw, | |||||
| kern_param.sh = sh, kern_param.sw = sw, kern_param.fh = fh, | |||||
| kern_param.fw = fw; | |||||
| float diff_scale = | float diff_scale = | ||||
| args.diff_layout->dtype.param<dtype::QuantizedS8>().scale, | args.diff_layout->dtype.param<dtype::QuantizedS8>().scale, | ||||
| filter_scale = | filter_scale = | ||||
| args.filter_layout->dtype.param<dtype::QuantizedS8>().scale, | args.filter_layout->dtype.param<dtype::QuantizedS8>().scale, | ||||
| grad_scale = | grad_scale = | ||||
| args.grad_layout->dtype.param<dtype::QuantizedS8>().scale; | args.grad_layout->dtype.param<dtype::QuantizedS8>().scale; | ||||
| float alpha = diff_scale * filter_scale / grad_scale; | |||||
| // \note these cutlass epilogue constants are passed to struct | |||||
| // `ConvolutionArguments` by pointer and read as ElementCompute*; a | |||||
| // mismatched dtype here results in undefined epilogue behavior | |||||
| float alpha = diff_scale * filter_scale / grad_scale, beta = 0.f, | |||||
| gamma = 0.f, delta = 0.f; | |||||
| using namespace cutlass::library; | |||||
| // only use 16x64x8_16x64x8_2stages impl | // only use 16x64x8_16x64x8_2stages impl | ||||
| cutlass_wrapper::do_deconv_int8_implicit_gemm_dp4a_ncdiv4hw4( | |||||
| inner_diff_ptr, inner_filter_ptr, inner_grad_ptr, nullptr, | |||||
| kern_param, alpha, cutlass_wrapper::GemmCoord{16, 64, 8}, | |||||
| cutlass_wrapper::GemmCoord{16, 64, 8}, 2, stream); | |||||
| ConvolutionKey key{ | |||||
| cutlass::conv::Operator::kDgrad, | |||||
| NumericTypeID::kS8, | |||||
| LayoutTypeID::kTensorNC4HW4, | |||||
| NumericTypeID::kS8, | |||||
| LayoutTypeID::kTensorK4RSC4, | |||||
| NumericTypeID::kS8, | |||||
| LayoutTypeID::kTensorNC4HW4, | |||||
| NumericTypeID::kS32, | |||||
| LayoutTypeID::kTensorNC4HW4, | |||||
| cutlass::conv::ConvType::kConvolution, | |||||
| 16, | |||||
| 64, | |||||
| 8, | |||||
| 16, | |||||
| 64, | |||||
| 8, | |||||
| 1, | |||||
| 1, | |||||
| 4, | |||||
| cutlass::epilogue::EpilogueType::kBiasAddLinearCombinationClamp, | |||||
| 2, | |||||
| true, | |||||
| false}; | |||||
| const Operation* op = Singleton::get().operation_table.find_op(key); | |||||
| // gcc prints warnings when size_t values are implicitly narrowed to int | |||||
| cutlass::conv::Conv2dProblemSize problem_size{ | |||||
| int(n), int(hi), int(wi), int(ci), | |||||
| int(co), int(fh), int(fw), int(ho), | |||||
| int(wo), int(ph), int(pw), int(sh), | |||||
| int(sw), int(dh), int(dw), cutlass::conv::Mode::kCrossCorrelation}; | |||||
| cutlass::library::ConvolutionArguments conv_args{ | |||||
| problem_size, inner_diff_ptr, inner_filter_ptr, nullptr, | |||||
| nullptr, inner_grad_ptr, &alpha, &beta, | |||||
| &gamma, &delta, nullptr, nullptr, | |||||
| nullptr, nullptr}; | |||||
| cutlass_check(op->run(&conv_args, nullptr, stream)); | |||||
| after_kernel_launch(); | after_kernel_launch(); | ||||
| @@ -0,0 +1,107 @@ | |||||
| /*************************************************************************************************** | |||||
| * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | |||||
| * | |||||
| * Redistribution and use in source and binary forms, with or without | |||||
| *modification, are permitted provided that the following conditions are met: | |||||
| * * Redistributions of source code must retain the above copyright notice, | |||||
| *this list of conditions and the following disclaimer. | |||||
| * * Redistributions in binary form must reproduce the above copyright | |||||
| *notice, this list of conditions and the following disclaimer in the | |||||
| *documentation and/or other materials provided with the distribution. | |||||
| * * Neither the name of the NVIDIA CORPORATION nor the names of its | |||||
| *contributors may be used to endorse or promote products derived from this | |||||
| *software without specific prior written permission. | |||||
| * | |||||
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
| *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
| *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |||||
| *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, | |||||
| *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, | |||||
| * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | |||||
| *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY | |||||
| *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING | |||||
| *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, | |||||
| *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||||
| * | |||||
| **************************************************************************************************/ | |||||
| /** | |||||
| * \file dnn/src/cuda/cutlass/arch_mappings.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include "cutlass/arch/arch.h" | |||||
| #include "cutlass/arch/mma.h" | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| namespace cutlass { | |||||
| namespace library { | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| template <typename ArchTag, typename OperatorClass> | |||||
| struct ArchMap; | |||||
| template <> | |||||
| struct ArchMap<arch::Sm50, arch::OpClassSimt> { | |||||
| static int const kMin = 50; | |||||
| static int const kMax = 1024; | |||||
| }; | |||||
| template <> | |||||
| struct ArchMap<arch::Sm60, arch::OpClassSimt> { | |||||
| static int const kMin = 60; | |||||
| static int const kMax = 1024; | |||||
| }; | |||||
| template <> | |||||
| struct ArchMap<arch::Sm61, arch::OpClassSimt> { | |||||
| static int const kMin = 61; | |||||
| static int const kMax = 1024; | |||||
| }; | |||||
| template <> | |||||
| struct ArchMap<arch::Sm70, arch::OpClassWmmaTensorOp> { | |||||
| static int const kMin = 70; | |||||
| static int const kMax = 1024; | |||||
| }; | |||||
| template <> | |||||
| struct ArchMap<arch::Sm70, arch::OpClassTensorOp> { | |||||
| static int const kMin = 70; | |||||
| static int const kMax = 75; | |||||
| }; | |||||
| template <typename OperatorClass> | |||||
| struct ArchMap<arch::Sm75, OperatorClass> { | |||||
| static int const kMin = 75; | |||||
| static int const kMax = 1024; | |||||
| }; | |||||
| template <typename OperatorClass> | |||||
| struct ArchMap<arch::Sm80, OperatorClass> { | |||||
| static int const kMin = 80; | |||||
| static int const kMax = 1024; | |||||
| }; | |||||
| template <typename OperatorClass> | |||||
| struct ArchMap<arch::Sm86, OperatorClass> { | |||||
| static int const kMin = 86; | |||||
| static int const kMax = 1024; | |||||
| }; | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| } // namespace library | |||||
| } // namespace cutlass | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
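A short usage sketch for ArchMap; arch_supported is illustrative and not part of this diff. The kMin/kMax bounds let callers filter kernels against a device's compute capability:

// True when a kernel built for (ArchTag, OpClass) may run on a device
// reporting compute capability `cc` (e.g. 61 for sm_61).
template <typename ArchTag, typename OpClass>
bool arch_supported(int cc) {
    using Map = cutlass::library::ArchMap<ArchTag, OpClass>;
    return cc >= Map::kMin && cc <= Map::kMax;
}
// arch_supported<cutlass::arch::Sm61, cutlass::arch::OpClassSimt>(61)     -> true
// arch_supported<cutlass::arch::Sm70, cutlass::arch::OpClassTensorOp>(80) -> false (kMax is 75)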
| @@ -0,0 +1,307 @@ | |||||
| /*************************************************************************************************** | |||||
| * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. | |||||
| * | |||||
| * Redistribution and use in source and binary forms, with or without | |||||
| *modification, are permitted provided that the following conditions are met: | |||||
| * * Redistributions of source code must retain the above copyright notice, | |||||
| *this list of conditions and the following disclaimer. | |||||
| * * Redistributions in binary form must reproduce the above copyright | |||||
| *notice, this list of conditions and the following disclaimer in the | |||||
| *documentation and/or other materials provided with the distribution. | |||||
| * * Neither the name of the NVIDIA CORPORATION nor the names of its | |||||
| *contributors may be used to endorse or promote products derived from this | |||||
| *software without specific prior written permission. | |||||
| * | |||||
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
| *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
| *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |||||
| *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, | |||||
| *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, | |||||
| * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | |||||
| *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY | |||||
| *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING | |||||
| *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, | |||||
| *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||||
| * | |||||
| **************************************************************************************************/ | |||||
| /** | |||||
| * \file dnn/src/cuda/cutlass/convolution_operation.h | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include "cutlass/convolution/device/convolution.h" | |||||
| #include "src/cuda/cutlass/library_internal.h" | |||||
| /////////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| namespace cutlass { | |||||
| namespace library { | |||||
| /////////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| template <typename Operator_> | |||||
| class ConvolutionOperationBase : public Operation { | |||||
| public: | |||||
| using Operator = Operator_; | |||||
| using ElementSrc = typename Operator::ElementSrc; | |||||
| using LayoutSrc = typename Operator::LayoutSrc; | |||||
| using ElementFilter = typename Operator::ElementFilter; | |||||
| using LayoutFilter = typename Operator::LayoutFilter; | |||||
| using ElementDst = typename Operator::ElementDst; | |||||
| using LayoutDst = typename Operator::LayoutDst; | |||||
| using ElementBias = typename Operator::ElementBias; | |||||
| using LayoutBias = typename Operator::LayoutBias; | |||||
| using ElementAccumulator = typename Operator::ElementAccumulator; | |||||
| ConvolutionOperationBase(char const* name = "unknown_convolution") { | |||||
| m_description.name = name; | |||||
| m_description.provider = Provider::kCUTLASS; | |||||
| m_description.kind = OperationKind::kConvolution; | |||||
| m_description.conv_op = Operator::kConvolutionalOperator; | |||||
| m_description.tile_description.threadblock_shape = make_Coord( | |||||
| Operator::ThreadblockShape::kM, Operator::ThreadblockShape::kN, | |||||
| Operator::ThreadblockShape::kK); | |||||
| m_description.tile_description.threadblock_stages = Operator::kStages; | |||||
| m_description.tile_description.warp_count = | |||||
| make_Coord(Operator::ConvolutionKernel::WarpCount::kM, | |||||
| Operator::ConvolutionKernel::WarpCount::kN, | |||||
| Operator::ConvolutionKernel::WarpCount::kK); | |||||
| m_description.tile_description.math_instruction.instruction_shape = | |||||
| make_Coord(Operator::InstructionShape::kM, | |||||
| Operator::InstructionShape::kN, | |||||
| Operator::InstructionShape::kK); | |||||
| m_description.tile_description.math_instruction.element_accumulator = | |||||
| NumericTypeMap<ElementAccumulator>::kId; | |||||
| m_description.tile_description.math_instruction.opcode_class = | |||||
| OpcodeClassMap<typename Operator::OperatorClass>::kId; | |||||
| m_description.tile_description.math_instruction.math_operation = | |||||
| MathOperationMap<typename Operator::Operator>::kId; | |||||
| m_description.tile_description.minimum_compute_capability = | |||||
| ArchMap<typename Operator::ArchTag, | |||||
| typename Operator::OperatorClass>::kMin; | |||||
| m_description.tile_description.maximum_compute_capability = | |||||
| ArchMap<typename Operator::ArchTag, | |||||
| typename Operator::OperatorClass>::kMax; | |||||
| m_description.src = make_TensorDescription<ElementSrc, LayoutSrc>( | |||||
| Operator::kAlignmentSrc); | |||||
| m_description.filter = | |||||
| make_TensorDescription<ElementFilter, LayoutFilter>( | |||||
| Operator::kAlignmentFilter); | |||||
| m_description.dst = make_TensorDescription<ElementDst, LayoutDst>( | |||||
| Operator::kAlignmentDst); | |||||
| m_description.bias = make_TensorDescription<ElementBias, LayoutBias>( | |||||
| Operator::kAlignmentDst); | |||||
| m_description.convolution_type = Operator::kConvolutionType; | |||||
| m_description.arch_tag = ArchTagMap<typename Operator::ArchTag>::kId; | |||||
| m_description.epilogue_type = Operator::EpilogueOutputOp::kType; | |||||
| m_description.epilogue_count = Operator::EpilogueOutputOp::kCount; | |||||
| m_description.threadblock_swizzle = ThreadblockSwizzleMap< | |||||
| typename Operator::ThreadblockSwizzle>::kId; | |||||
| m_description.need_load_from_const_mem = | |||||
| Operator::kNeedLoadFromConstMem; | |||||
| m_description.gemm_mode = Operator::kGemmMode; | |||||
| m_description.without_shared_load = Operator::kWithoutSharedLoad; | |||||
| } | |||||
| virtual OperationDescription const& description() const { | |||||
| return m_description; | |||||
| } | |||||
| protected: | |||||
| ConvolutionDescription m_description; | |||||
| }; | |||||
| /////////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| namespace detail { | |||||
| template <typename EpilogueOp, epilogue::EpilogueType type> | |||||
| struct init_epilogue_param_; | |||||
| template <typename EpilogueOp> | |||||
| struct init_epilogue_param_<EpilogueOp, | |||||
| epilogue::EpilogueType::kBiasAddLinearCombination> { | |||||
| using ElementCompute = typename EpilogueOp::ElementCompute; | |||||
| typename EpilogueOp::Params get(ConvolutionArguments const* conv_args) { | |||||
| return {*static_cast<ElementCompute const*>(conv_args->alpha), | |||||
| *static_cast<ElementCompute const*>(conv_args->beta), | |||||
| *static_cast<ElementCompute const*>(conv_args->gamma), | |||||
| *static_cast<ElementCompute const*>(conv_args->delta)}; | |||||
| } | |||||
| }; | |||||
| template <typename EpilogueOp> | |||||
| struct init_epilogue_param_< | |||||
| EpilogueOp, epilogue::EpilogueType::kBiasAddLinearCombinationClamp> { | |||||
| using ElementCompute = typename EpilogueOp::ElementCompute; | |||||
| typename EpilogueOp::Params get(ConvolutionArguments const* conv_args) { | |||||
| return {*static_cast<ElementCompute const*>(conv_args->alpha), | |||||
| *static_cast<ElementCompute const*>(conv_args->beta), | |||||
| *static_cast<ElementCompute const*>(conv_args->gamma), | |||||
| *static_cast<ElementCompute const*>(conv_args->delta)}; | |||||
| } | |||||
| }; | |||||
| template <typename EpilogueOp> | |||||
| struct init_epilogue_param_< | |||||
| EpilogueOp, epilogue::EpilogueType::kBiasAddLinearCombinationRelu> { | |||||
| using ElementCompute = typename EpilogueOp::ElementCompute; | |||||
| typename EpilogueOp::Params get(ConvolutionArguments const* conv_args) { | |||||
| return {*static_cast<ElementCompute const*>(conv_args->alpha), | |||||
| *static_cast<ElementCompute const*>(conv_args->beta), | |||||
| *static_cast<ElementCompute const*>(conv_args->gamma), | |||||
| *static_cast<ElementCompute const*>(conv_args->threshold), | |||||
| *static_cast<ElementCompute const*>(conv_args->delta), | |||||
| *static_cast<ElementCompute const*>(conv_args->theta)}; | |||||
| } | |||||
| }; | |||||
| template <typename EpilogueOp> | |||||
| struct init_epilogue_param_< | |||||
| EpilogueOp, | |||||
| epilogue::EpilogueType::kBiasAddLinearCombinationReluClamp> { | |||||
| using ElementCompute = typename EpilogueOp::ElementCompute; | |||||
| typename EpilogueOp::Params get(ConvolutionArguments const* conv_args) { | |||||
| return {*static_cast<ElementCompute const*>(conv_args->alpha), | |||||
| *static_cast<ElementCompute const*>(conv_args->beta), | |||||
| *static_cast<ElementCompute const*>(conv_args->gamma), | |||||
| *static_cast<ElementCompute const*>(conv_args->threshold), | |||||
| *static_cast<ElementCompute const*>(conv_args->delta), | |||||
| *static_cast<ElementCompute const*>(conv_args->theta)}; | |||||
| } | |||||
| }; | |||||
| template <typename EpilogueOp> | |||||
| struct init_epilogue_param_< | |||||
| EpilogueOp, epilogue::EpilogueType::kBiasAddLinearCombinationHSwish> { | |||||
| using ElementCompute = typename EpilogueOp::ElementCompute; | |||||
| typename EpilogueOp::Params get(ConvolutionArguments const* conv_args) { | |||||
| return {*static_cast<ElementCompute const*>(conv_args->alpha), | |||||
| *static_cast<ElementCompute const*>(conv_args->beta), | |||||
| *static_cast<ElementCompute const*>(conv_args->gamma), | |||||
| *static_cast<ElementCompute const*>(conv_args->scale), | |||||
| *static_cast<ElementCompute const*>(conv_args->delta), | |||||
| *static_cast<ElementCompute const*>(conv_args->theta)}; | |||||
| } | |||||
| }; | |||||
| template <typename EpilogueOp> | |||||
| struct init_epilogue_param_< | |||||
| EpilogueOp, | |||||
| epilogue::EpilogueType::kBiasAddLinearCombinationHSwishClamp> { | |||||
| using ElementCompute = typename EpilogueOp::ElementCompute; | |||||
| typename EpilogueOp::Params get(ConvolutionArguments const* conv_args) { | |||||
| return {*static_cast<ElementCompute const*>(conv_args->alpha), | |||||
| *static_cast<ElementCompute const*>(conv_args->beta), | |||||
| *static_cast<ElementCompute const*>(conv_args->gamma), | |||||
| *static_cast<ElementCompute const*>(conv_args->scale), | |||||
| *static_cast<ElementCompute const*>(conv_args->delta), | |||||
| *static_cast<ElementCompute const*>(conv_args->theta)}; | |||||
| } | |||||
| }; | |||||
| } // namespace detail | |||||
| template <typename EpilogueOp> | |||||
| struct init_epilogue_param | |||||
| : public detail::init_epilogue_param_<EpilogueOp, EpilogueOp::kType> {}; | |||||
| /////////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| template <typename Operator_> | |||||
| class ConvolutionOperation : public ConvolutionOperationBase<Operator_> { | |||||
| public: | |||||
| using Operator = Operator_; | |||||
| using ElementSrc = typename Operator::ElementSrc; | |||||
| using LayoutSrc = typename Operator::LayoutSrc; | |||||
| using ElementFilter = typename Operator::ElementFilter; | |||||
| using LayoutFilter = typename Operator::LayoutFilter; | |||||
| using ElementBias = typename Operator::ElementBias; | |||||
| using LayoutBias = typename Operator::LayoutBias; | |||||
| using ElementDst = typename Operator::ElementDst; | |||||
| using LayoutDst = typename Operator::LayoutDst; | |||||
| using ElementAccumulator = typename Operator::ElementAccumulator; | |||||
| using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; | |||||
| using OperatorArguments = typename Operator::Arguments; | |||||
| ConvolutionOperation(char const* name = "unknown_convolution") | |||||
| : ConvolutionOperationBase<Operator_>(name) {} | |||||
| virtual Status run(void const* arguments_ptr, | |||||
| void* device_workspace = nullptr, | |||||
| cudaStream_t stream = nullptr) const { | |||||
| cutlass::conv::Operator conv_op = this->m_description.conv_op; | |||||
| ConvolutionArguments const* conv_args = | |||||
| reinterpret_cast<ConvolutionArguments const*>(arguments_ptr); | |||||
| const auto& ps = conv_args->problem_size; | |||||
| OperatorArguments args; | |||||
| args.problem_size = ps; | |||||
| args.ref_src = { | |||||
| static_cast<ElementSrc*>(const_cast<void*>(conv_args->src)), | |||||
| LayoutSrc::packed(implicit_gemm_tensor_a_extent(conv_op, ps))}; | |||||
| args.ref_filter = {static_cast<ElementFilter*>( | |||||
| const_cast<void*>(conv_args->filter)), | |||||
| LayoutFilter::packed( | |||||
| implicit_gemm_tensor_b_extent(conv_op, ps))}; | |||||
| args.ref_bias = { | |||||
| static_cast<ElementBias*>(const_cast<void*>(conv_args->bias)), | |||||
| LayoutBias::packed( | |||||
| implicit_gemm_tensor_bias_extent(conv_op, ps))}; | |||||
| args.ref_z = { | |||||
| static_cast<ElementDst*>(const_cast<void*>(conv_args->z)), | |||||
| LayoutDst::packed(implicit_gemm_tensor_c_extent(conv_op, ps))}; | |||||
| args.ref_dst = { | |||||
| static_cast<ElementDst*>(conv_args->dst), | |||||
| LayoutDst::packed(implicit_gemm_tensor_c_extent(conv_op, ps))}; | |||||
| args.output_op = | |||||
| init_epilogue_param<typename Operator::EpilogueOutputOp>().get( | |||||
| conv_args); | |||||
| if (conv_args->extra_param) { | |||||
| args.extra_param = | |||||
| *reinterpret_cast<typename Operator::ExtraParam const*>( | |||||
| conv_args->extra_param); | |||||
| } | |||||
| Operator op; | |||||
| Status status = op.initialize(args, device_workspace); | |||||
| if (status != Status::kSuccess) { | |||||
| return status; | |||||
| } | |||||
| return op.run(stream); | |||||
| } | |||||
| }; | |||||
| /////////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| } // namespace library | |||||
| } // namespace cutlass | |||||
| /////////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| @@ -0,0 +1,202 @@ | |||||
| /*************************************************************************************************** | |||||
| * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. | |||||
| * | |||||
| * Redistribution and use in source and binary forms, with or without | |||||
| *modification, are permitted provided that the following conditions are met: | |||||
| * * Redistributions of source code must retain the above copyright notice, | |||||
| *this list of conditions and the following disclaimer. | |||||
| * * Redistributions in binary form must reproduce the above copyright | |||||
| *notice, this list of conditions and the following disclaimer in the | |||||
| *documentation and/or other materials provided with the distribution. | |||||
| * * Neither the name of the NVIDIA CORPORATION nor the names of its | |||||
| *contributors may be used to endorse or promote products derived from this | |||||
| *software without specific prior written permission. | |||||
| * | |||||
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
| *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
| *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |||||
| *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, | |||||
| *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, | |||||
| * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | |||||
| *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY | |||||
| *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING | |||||
| *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, | |||||
| *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||||
| * | |||||
| **************************************************************************************************/ | |||||
| /** | |||||
| * \file dnn/src/cuda/cutlass/gemm_operation.h | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include "cutlass/gemm/device/gemm.h" | |||||
| #include "src/cuda/cutlass/library_internal.h" | |||||
| /////////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| namespace cutlass { | |||||
| namespace library { | |||||
| /////////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| /// Check whether Operator has member ReductionKernel using SFINAE (Substitution | |||||
| /// Failure Is Not An Error) | |||||
| template <typename Operator> | |||||
| struct split_k_mode { | |||||
| template <typename T> | |||||
| static char check(typename T::ReductionKernel*); | |||||
| template <typename T> | |||||
| static int check(...); | |||||
| SplitKMode operator()() { | |||||
| if (sizeof(check<Operator>(0)) == sizeof(char)) { | |||||
| // cutlass::gemm::device::GemmSplitKParallel | |||||
| return SplitKMode::kParallel; | |||||
| } else { | |||||
| // cutlass::gemm::device::Gemm | |||||
| return SplitKMode::kNone; | |||||
| } | |||||
| } | |||||
| }; | |||||
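For orientation, the detection idiom above can be exercised in isolation. The two operator types below are hypothetical stand-ins, not real CUTLASS kernels; any type exposing a nested ReductionKernel typedef (as cutlass::gemm::device::GemmSplitKParallel does) is classified as parallel split-K:

    #include <cassert>
    #include "src/cuda/cutlass/gemm_operation.h"

    // Hypothetical stand-ins for illustration only.
    struct FakeSerialGemm {};                                    // no ReductionKernel
    struct FakeParallelGemm { using ReductionKernel = void*; };  // has the member

    void check_split_k_detection() {
        using namespace cutlass::library;
        assert(split_k_mode<FakeSerialGemm>{}() == SplitKMode::kNone);
        assert(split_k_mode<FakeParallelGemm>{}() == SplitKMode::kParallel);
    }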
| /////////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| template <typename Operator_> | |||||
| class GemmOperationBase : public Operation { | |||||
| public: | |||||
| using Operator = Operator_; | |||||
| using ElementA = typename Operator::ElementA; | |||||
| using LayoutA = typename Operator::LayoutA; | |||||
| using ElementB = typename Operator::ElementB; | |||||
| using LayoutB = typename Operator::LayoutB; | |||||
| using ElementC = typename Operator::ElementC; | |||||
| using LayoutC = typename Operator::LayoutC; | |||||
| using ElementAccumulator = typename Operator::ElementAccumulator; | |||||
| GemmOperationBase(char const* name = "unknown_gemm") { | |||||
| m_description.name = name; | |||||
| m_description.provider = Provider::kCUTLASS; | |||||
| m_description.kind = OperationKind::kGemm; | |||||
| m_description.gemm_kind = GemmKind::kGemm; | |||||
| m_description.tile_description.threadblock_shape = make_Coord( | |||||
| Operator::ThreadblockShape::kM, Operator::ThreadblockShape::kN, | |||||
| Operator::ThreadblockShape::kK); | |||||
| m_description.tile_description.threadblock_stages = Operator::kStages; | |||||
| m_description.tile_description.warp_count = | |||||
| make_Coord(Operator::GemmKernel::WarpCount::kM, | |||||
| Operator::GemmKernel::WarpCount::kN, | |||||
| Operator::GemmKernel::WarpCount::kK); | |||||
| m_description.tile_description.math_instruction.instruction_shape = | |||||
| make_Coord(Operator::InstructionShape::kM, | |||||
| Operator::InstructionShape::kN, | |||||
| Operator::InstructionShape::kK); | |||||
| m_description.tile_description.math_instruction.element_accumulator = | |||||
| NumericTypeMap<ElementAccumulator>::kId; | |||||
| m_description.tile_description.math_instruction.opcode_class = | |||||
| OpcodeClassMap<typename Operator::OperatorClass>::kId; | |||||
| m_description.tile_description.math_instruction.math_operation = | |||||
| MathOperationMap<typename Operator::Operator>::kId; | |||||
| m_description.tile_description.minimum_compute_capability = | |||||
| ArchMap<typename Operator::ArchTag, | |||||
| typename Operator::OperatorClass>::kMin; | |||||
| m_description.tile_description.maximum_compute_capability = | |||||
| ArchMap<typename Operator::ArchTag, | |||||
| typename Operator::OperatorClass>::kMax; | |||||
| m_description.A = make_TensorDescription<ElementA, LayoutA>( | |||||
| Operator::kAlignmentA); | |||||
| m_description.B = make_TensorDescription<ElementB, LayoutB>( | |||||
| Operator::kAlignmentB); | |||||
| m_description.C = make_TensorDescription<ElementC, LayoutC>( | |||||
| Operator::kAlignmentC); | |||||
| m_description.stages = Operator::kStages; | |||||
| split_k_mode<Operator> mode; | |||||
| m_description.split_k_mode = mode(); | |||||
| } | |||||
| virtual OperationDescription const& description() const { | |||||
| return m_description; | |||||
| } | |||||
| protected: | |||||
| GemmDescription m_description; | |||||
| }; | |||||
| /////////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| template <typename Operator_> | |||||
| class GemmOperation : public GemmOperationBase<Operator_> { | |||||
| public: | |||||
| using Operator = Operator_; | |||||
| using ElementA = typename Operator::ElementA; | |||||
| using LayoutA = typename Operator::LayoutA; | |||||
| using ElementB = typename Operator::ElementB; | |||||
| using LayoutB = typename Operator::LayoutB; | |||||
| using ElementC = typename Operator::ElementC; | |||||
| using LayoutC = typename Operator::LayoutC; | |||||
| using ElementAccumulator = typename Operator::ElementAccumulator; | |||||
| using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; | |||||
| using OperatorArguments = typename Operator::Arguments; | |||||
| GemmOperation(char const* name = "unknown_gemm") | |||||
| : GemmOperationBase<Operator_>(name) {} | |||||
| virtual Status run(void const* arguments_ptr, | |||||
| void* device_workspace = nullptr, | |||||
| cudaStream_t stream = nullptr) const { | |||||
| GemmArguments const* gemm_args = | |||||
| reinterpret_cast<GemmArguments const*>(arguments_ptr); | |||||
| OperatorArguments args; | |||||
| args.problem_size = gemm_args->problem_size; | |||||
| args.ref_A = {static_cast<ElementA const*>(gemm_args->A), | |||||
| int(gemm_args->lda)}; | |||||
| args.ref_B = {static_cast<ElementB const*>(gemm_args->B), | |||||
| int(gemm_args->ldb)}; | |||||
| args.ref_C = {static_cast<ElementC const*>(gemm_args->C), | |||||
| int(gemm_args->ldc)}; | |||||
| args.ref_D = {static_cast<ElementC*>(gemm_args->D), | |||||
| int(gemm_args->ldd)}; | |||||
| args.split_k_slices = gemm_args->split_k_slices; | |||||
| args.epilogue = {*static_cast<ElementCompute const*>(gemm_args->alpha), | |||||
| *static_cast<ElementCompute const*>(gemm_args->beta)}; | |||||
| Operator op; | |||||
| Status status = op.initialize(args, device_workspace); | |||||
| if (status != Status::kSuccess) { | |||||
| return status; | |||||
| } | |||||
| return op.run(stream); | |||||
| } | |||||
| }; | |||||
| /////////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| } // namespace library | |||||
| } // namespace cutlass | |||||
| /////////////////////////////////////////////////////////////////////////////////////////////////// | |||||
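As a sketch of how this class is consumed, the generated initialize_* functions instantiate GemmOperation over a concrete device-level kernel and hand it to the manifest. The Gemm type parameters below are illustrative, not taken from this diff; the remaining template parameters fall back to their CUTLASS defaults:

    #include "cutlass/gemm/device/gemm.h"
    #include "src/cuda/cutlass/gemm_operation.h"
    #include "src/cuda/cutlass/manifest.h"

    // An illustrative f32 SIMT kernel.
    using ExampleGemm = cutlass::gemm::device::Gemm<
            float, cutlass::layout::RowMajor,   // A
            float, cutlass::layout::RowMajor,   // B
            float, cutlass::layout::RowMajor>;  // C / D

    void register_example(cutlass::library::Manifest& manifest) {
        // Manifest::append takes ownership of the raw pointer.
        manifest.append(new cutlass::library::GemmOperation<ExampleGemm>(
                "example_sgemm"));
    }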
| @@ -0,0 +1,76 @@ | |||||
| /*************************************************************************************************** | |||||
| * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | |||||
| * | |||||
| * Redistribution and use in source and binary forms, with or without | |||||
| *modification, are permitted provided that the following conditions are met: | |||||
| * * Redistributions of source code must retain the above copyright notice, | |||||
| *this list of conditions and the following disclaimer. | |||||
| * * Redistributions in binary form must reproduce the above copyright | |||||
| *notice, this list of conditions and the following disclaimer in the | |||||
| *documentation and/or other materials provided with the distribution. | |||||
| * * Neither the name of the NVIDIA CORPORATION nor the names of its | |||||
| *contributors may be used to endorse or promote products derived from this | |||||
| *software without specific prior written permission. | |||||
| * | |||||
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
| *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
| *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |||||
| *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, | |||||
| *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, | |||||
| * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | |||||
| *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY | |||||
| *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING | |||||
| *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, | |||||
| *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||||
| * | |||||
| **************************************************************************************************/ | |||||
| /** | |||||
| * \file dnn/src/cuda/cutlass/initialize_all.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "src/cuda/cutlass/manifest.h" | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| namespace cutlass { | |||||
| namespace library { | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| #if __CUDACC_VER_MAJOR__ > 9 || \ | |||||
| (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) | |||||
| void initialize_all_gemm_simt_operations(Manifest& manifest); | |||||
| void initialize_all_conv2d_simt_operations(Manifest& manifest); | |||||
| void initialize_all_conv2d_tensorop8816_operations(Manifest& manifest); | |||||
| void initialize_all_conv2d_tensorop8832_operations(Manifest& manifest); | |||||
| void initialize_all_deconv_simt_operations(Manifest& manifest); | |||||
| void initialize_all(Manifest& manifest) { | |||||
| initialize_all_gemm_simt_operations(manifest); | |||||
| initialize_all_conv2d_simt_operations(manifest); | |||||
| initialize_all_conv2d_tensorop8816_operations(manifest); | |||||
| initialize_all_conv2d_tensorop8832_operations(manifest); | |||||
| initialize_all_deconv_simt_operations(manifest); | |||||
| } | |||||
| #else | |||||
| void initialize_all(Manifest& manifest) {} | |||||
| #endif | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| } // namespace library | |||||
| } // namespace cutlass | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
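A minimal sketch of how this entry point is consumed; with an nvcc older than 9.2 the #else branch above turns initialization into a no-op and the manifest simply stays empty:

    #include "src/cuda/cutlass/manifest.h"

    // Build the operation table once at startup; Manifest::initialize()
    // forwards to initialize_all(*this) (see manifest.cpp later in this diff).
    void build_operation_table(cutlass::library::Manifest& manifest) {
        cutlass::Status status = manifest.initialize();
        (void)status;  // kSuccess on every code path shown here
    }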
| @@ -0,0 +1,541 @@ | |||||
| /*************************************************************************************************** | |||||
| * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | |||||
| * | |||||
| * Redistribution and use in source and binary forms, with or without | |||||
| *modification, are permitted provided that the following conditions are met: | |||||
| * * Redistributions of source code must retain the above copyright notice, | |||||
| *this list of conditions and the following disclaimer. | |||||
| * * Redistributions in binary form must reproduce the above copyright | |||||
| *notice, this list of conditions and the following disclaimer in the | |||||
| *documentation and/or other materials provided with the distribution. | |||||
| * * Neither the name of the NVIDIA CORPORATION nor the names of its | |||||
| *contributors may be used to endorse or promote products derived from this | |||||
| *software without specific prior written permission. | |||||
| * | |||||
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
| *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
| *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |||||
| *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, | |||||
| *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, | |||||
| * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | |||||
| *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY | |||||
| *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING | |||||
| *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, | |||||
| *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||||
| * | |||||
| **************************************************************************************************/ | |||||
| /** | |||||
| * \file dnn/src/cuda/cutlass/library.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #pragma once | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| #include <cuda_runtime.h> | |||||
| #include <cstdint> | |||||
| #include <stdexcept> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wreorder" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #include "cutlass/cutlass.h" | |||||
| #include "cutlass/layout/tensor.h" | |||||
| #include "cutlass/matrix_coord.h" | |||||
| #include "cutlass/tensor_coord.h" | |||||
| #include "cutlass/conv/conv2d_problem_size.h" | |||||
| #include "cutlass/conv/convolution.h" | |||||
| #include "cutlass/epilogue/epilogue.h" | |||||
| #include "cutlass/gemm/gemm.h" | |||||
| #pragma GCC diagnostic pop | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| namespace cutlass { | |||||
| namespace library { | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| /// Layout type identifier | |||||
| enum class LayoutTypeID { | |||||
| kUnknown, | |||||
| kColumnMajor, | |||||
| kRowMajor, | |||||
| kColumnMajorInterleavedK2, | |||||
| kRowMajorInterleavedK2, | |||||
| kColumnMajorInterleavedK4, | |||||
| kRowMajorInterleavedK4, | |||||
| kColumnMajorInterleavedK16, | |||||
| kRowMajorInterleavedK16, | |||||
| kColumnMajorInterleavedK32, | |||||
| kRowMajorInterleavedK32, | |||||
| kColumnMajorInterleavedK64, | |||||
| kRowMajorInterleavedK64, | |||||
| kTensorNCHW, | |||||
| kTensorNCDHW, | |||||
| kTensorNHWC, | |||||
| kTensorNDHWC, | |||||
| kTensorNC4HW4, | |||||
| kTensorC4RSK4, | |||||
| kTensorNC8HW8, | |||||
| kTensorC8RSK8, | |||||
| kTensorNC16HW16, | |||||
| kTensorC16RSK16, | |||||
| kTensorNC32HW32, | |||||
| kTensorC32RSK32, | |||||
| kTensorNC64HW64, | |||||
| kTensorC64RSK64, | |||||
| kTensorK4RSC4, | |||||
| kInvalid | |||||
| }; | |||||
| /// Numeric data type | |||||
| enum class NumericTypeID { | |||||
| kUnknown, | |||||
| kVoid, | |||||
| kB1, | |||||
| kU2, | |||||
| kU4, | |||||
| kU8, | |||||
| kU16, | |||||
| kU32, | |||||
| kU64, | |||||
| kS2, | |||||
| kS4, | |||||
| kS8, | |||||
| kS16, | |||||
| kS32, | |||||
| kS64, | |||||
| kF16, | |||||
| kBF16, | |||||
| kTF32, | |||||
| kF32, | |||||
| kF64, | |||||
| kCF16, | |||||
| kCBF16, | |||||
| kCF32, | |||||
| kCTF32, | |||||
| kCF64, | |||||
| kCS2, | |||||
| kCS4, | |||||
| kCS8, | |||||
| kCS16, | |||||
| kCS32, | |||||
| kCS64, | |||||
| kCU2, | |||||
| kCU4, | |||||
| kCU8, | |||||
| kCU16, | |||||
| kCU32, | |||||
| kCU64, | |||||
| kInvalid | |||||
| }; | |||||
| /// Enumerated type describing a transformation on a complex value. | |||||
| enum class ComplexTransform { kNone, kConjugate, kInvalid }; | |||||
| /// Providers | |||||
| enum class Provider { | |||||
| kNone, | |||||
| kCUTLASS, | |||||
| kReferenceHost, | |||||
| kReferenceDevice, | |||||
| kCUBLAS, | |||||
| kCUDNN, | |||||
| kInvalid | |||||
| }; | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| /// Enumeration indicating the kind of operation | |||||
| enum class OperationKind { | |||||
| kGemm, | |||||
| kConv2d, | |||||
| kConv3d, | |||||
| kConvolution, | |||||
| kEqGemm, | |||||
| kSparseGemm, | |||||
| kReduction, | |||||
| kInvalid | |||||
| }; | |||||
| /// Enumeration indicating whether scalars are in host or device memory | |||||
| enum class ScalarPointerMode { kHost, kDevice, kInvalid }; | |||||
| /// Describes how reductions are performed across threadblocks | |||||
| enum class SplitKMode { kNone, kSerial, kParallel, kParallelSerial, kInvalid }; | |||||
| /// Indicates the classification of the math instruction | |||||
| enum class OpcodeClassID { | |||||
| kSimt, | |||||
| kTensorOp, | |||||
| kWmmaTensorOp, | |||||
| kSparseTensorOp, | |||||
| kInvalid | |||||
| }; | |||||
| enum class ArchTagID { | |||||
| kSm50, | |||||
| kSm60, | |||||
| kSm61, | |||||
| kSm70, | |||||
| kSm72, | |||||
| kSm75, | |||||
| kSm80, | |||||
| kSm86, | |||||
| kInvalid | |||||
| }; | |||||
| enum class MathOperationID { | |||||
| kAdd, | |||||
| kMultiplyAdd, | |||||
| kMultiplyAddSaturate, | |||||
| kMultiplyAddFastBF16, | |||||
| kMultiplyAddFastF16, | |||||
| kMultiplyAddComplex, | |||||
| kMultiplyAddGaussianComplex, | |||||
| kXorPopc, | |||||
| kInvalid | |||||
| }; | |||||
| enum class ThreadblockSwizzleID { | |||||
| kGemmIdentity, | |||||
| kGemmHorizontal, | |||||
| kGemmBatchedIdentity, | |||||
| kGemmSplitKIdentity, | |||||
| kGemmSplitKHorizontal, | |||||
| kGemvBatchedStridedDefault, | |||||
| kGemvBatchedStridedReduction, | |||||
| kConvolutionFpropCxRSKx, | |||||
| kConvolutionDgradCxRSKx, | |||||
| kConvolutionFpropNCxHWx, | |||||
| kConvolutionFpropTrans, | |||||
| kConvolutionDgradNCxHWx, | |||||
| kInvalid | |||||
| }; | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| /// Enumeration indicating what kind of GEMM operation to perform | |||||
| enum class GemmKind { | |||||
| kGemm, | |||||
| kSparse, | |||||
| kUniversal, | |||||
| kPlanarComplex, | |||||
| kPlanarComplexArray, | |||||
| kInvalid | |||||
| }; | |||||
| /// Mode of Universal GEMM | |||||
| using GemmUniversalMode = cutlass::gemm::GemmUniversalMode; | |||||
| /// Enumeration indicating what kind of Conv2d operation to perform | |||||
| enum class ConvKind { kUnknown, kFprop, kDgrad, kWgrad, kInvalid }; | |||||
| enum class ConvModeID { kCrossCorrelation, kConvolution, kInvalid }; | |||||
| // Iterator algorithm enum in order of general performance-efficiency | |||||
| enum class IteratorAlgorithmID { kNone, kAnalytic, kOptimized, kInvalid }; | |||||
| enum class EpilogueKind { | |||||
| kUnknown, | |||||
| kBiasAddLinearCombination, | |||||
| kBiasAddLinearCombinationClamp, | |||||
| kBiasAddLInearCombinationHSwish, | |||||
| kBiasAddLInearCombinationHSwishClamp, | |||||
| kBiasAddLInearCombinationRelu, | |||||
| kBiasAddLInearCombinationReluClamp, | |||||
| kConversion, | |||||
| kLinearCombination, | |||||
| kLinearCombinationClamp, | |||||
| kLinearCombinationPlanarComplex, | |||||
| kLinearCombinationRelu, | |||||
| kLinearCombinationSigmoid, | |||||
| kInvalid | |||||
| }; | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| struct MathInstructionDescription { | |||||
| /// Shape of the target math instruction | |||||
| cutlass::gemm::GemmCoord instruction_shape; | |||||
| /// Describes the data type of the internal accumulator | |||||
| NumericTypeID element_accumulator; | |||||
| /// Classification of math instruction | |||||
| OpcodeClassID opcode_class; | |||||
| /// Type of math operation performed | |||||
| MathOperationID math_operation; | |||||
| // | |||||
| // Methods | |||||
| // | |||||
| MathInstructionDescription( | |||||
| cutlass::gemm::GemmCoord instruction_shape = | |||||
| cutlass::gemm::GemmCoord(), | |||||
| NumericTypeID element_accumulator = NumericTypeID::kInvalid, | |||||
| OpcodeClassID opcode_class = OpcodeClassID::kInvalid, | |||||
| MathOperationID math_operation = MathOperationID::kMultiplyAdd) | |||||
| : instruction_shape(instruction_shape), | |||||
| element_accumulator(element_accumulator), | |||||
| opcode_class(opcode_class), | |||||
| math_operation(math_operation) {} | |||||
| // Equality operator | |||||
| inline bool operator==(MathInstructionDescription const& rhs) const { | |||||
| return ((instruction_shape == rhs.instruction_shape) && | |||||
| (element_accumulator == rhs.element_accumulator) && | |||||
| (opcode_class == rhs.opcode_class) && | |||||
| (math_operation == rhs.math_operation)); | |||||
| } | |||||
| // Inequality operator | |||||
| inline bool operator!=(MathInstructionDescription const& rhs) const { | |||||
| return !(*this == rhs); | |||||
| } | |||||
| }; | |||||
| /// Structure describing the tiled structure of a GEMM-like computation | |||||
| struct TileDescription { | |||||
| /// Describes the shape of a threadblock (in elements) | |||||
| cutlass::gemm::GemmCoord threadblock_shape; | |||||
| /// Describes the number of pipeline stages in the threadblock-scoped | |||||
| /// mainloop | |||||
| int threadblock_stages; | |||||
| /// Number of warps in each logical dimension | |||||
| cutlass::gemm::GemmCoord warp_count; | |||||
| /// Core math instruction | |||||
| MathInstructionDescription math_instruction; | |||||
| /// Minimum compute capability (e.g. 70, 75) of a device eligible to run the | |||||
| /// operation. | |||||
| int minimum_compute_capability; | |||||
| /// Maximum compute capability (e.g. 70, 75) of a device eligible to run the | |||||
| /// operation. | |||||
| int maximum_compute_capability; | |||||
| // | |||||
| // Methods | |||||
| // | |||||
| TileDescription( | |||||
| cutlass::gemm::GemmCoord threadblock_shape = | |||||
| cutlass::gemm::GemmCoord(), | |||||
| int threadblock_stages = 0, | |||||
| cutlass::gemm::GemmCoord warp_count = cutlass::gemm::GemmCoord(), | |||||
| MathInstructionDescription math_instruction = | |||||
| MathInstructionDescription(), | |||||
| int minimum_compute_capability = 0, | |||||
| int maximum_compute_capability = 0) | |||||
| : threadblock_shape(threadblock_shape), | |||||
| threadblock_stages(threadblock_stages), | |||||
| warp_count(warp_count), | |||||
| math_instruction(math_instruction), | |||||
| minimum_compute_capability(minimum_compute_capability), | |||||
| maximum_compute_capability(maximum_compute_capability) {} | |||||
| // Equality operator | |||||
| inline bool operator==(TileDescription const& rhs) const { | |||||
| return ((threadblock_shape == rhs.threadblock_shape) && | |||||
| (threadblock_stages == rhs.threadblock_stages) && | |||||
| (warp_count == rhs.warp_count) && | |||||
| (math_instruction == rhs.math_instruction) && | |||||
| (minimum_compute_capability == | |||||
| rhs.minimum_compute_capability) && | |||||
| (maximum_compute_capability == rhs.maximum_compute_capability)); | |||||
| } | |||||
| // Inequality operator | |||||
| inline bool operator!=(TileDescription const& rhs) const { | |||||
| return !(*this == rhs); | |||||
| } | |||||
| }; | |||||
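For orientation, this is roughly the description a 128x128x32, two-stage f32 SIMT kernel would carry; all values are illustrative, not taken from a generated kernel:

    cutlass::library::TileDescription example_tile() {
        using namespace cutlass;
        library::MathInstructionDescription mi(
                gemm::GemmCoord(1, 1, 1),          // SIMT "instruction" shape
                library::NumericTypeID::kF32,      // accumulator type
                library::OpcodeClassID::kSimt,
                library::MathOperationID::kMultiplyAdd);
        return library::TileDescription(
                gemm::GemmCoord(128, 128, 32),     // threadblock_shape
                2,                                 // threadblock_stages
                gemm::GemmCoord(4, 2, 1),          // warp_count
                mi,
                61,                                // minimum_compute_capability
                1024);                             // maximum_compute_capability
    }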
| /// High-level description of an operation | |||||
| struct OperationDescription { | |||||
| /// Unique identifier describing the operation | |||||
| char const* name; | |||||
| /// Operation provider | |||||
| Provider provider; | |||||
| /// Kind of operation | |||||
| OperationKind kind; | |||||
| /// Describes the tiled structure of a GEMM-like computation | |||||
| TileDescription tile_description; | |||||
| // | |||||
| // Methods | |||||
| // | |||||
| OperationDescription( | |||||
| char const* name = "unknown", | |||||
| OperationKind kind = OperationKind::kInvalid, | |||||
| TileDescription const& tile_description = TileDescription()) | |||||
| : name(name), kind(kind), tile_description(tile_description) {} | |||||
| }; | |||||
| /// Structure describing the properties of a tensor | |||||
| struct TensorDescription { | |||||
| /// Numeric type of an individual element | |||||
| NumericTypeID element; | |||||
| /// Enumerant identifying the layout function for the tensor | |||||
| LayoutTypeID layout; | |||||
| /// Alignment restriction on pointers, strides, and extents | |||||
| int alignment; | |||||
| /// log2() of the maximum extent of each dimension | |||||
| int log_extent_range; | |||||
| /// log2() of the maximum value each relevant stride may have | |||||
| int log_stride_range; | |||||
| // | |||||
| // Methods | |||||
| // | |||||
| TensorDescription(NumericTypeID element = NumericTypeID::kInvalid, | |||||
| LayoutTypeID layout = LayoutTypeID::kInvalid, | |||||
| int alignment = 1, int log_extent_range = 24, | |||||
| int log_stride_range = 24) | |||||
| : element(element), | |||||
| layout(layout), | |||||
| alignment(alignment), | |||||
| log_extent_range(log_extent_range), | |||||
| log_stride_range(log_stride_range) {} | |||||
| }; | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| struct GemmDescription : public OperationDescription { | |||||
| GemmKind gemm_kind; | |||||
| TensorDescription A; | |||||
| TensorDescription B; | |||||
| TensorDescription C; | |||||
| int stages; | |||||
| SplitKMode split_k_mode; | |||||
| }; | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| struct GemmArguments { | |||||
| /// GEMM problem size | |||||
| gemm::GemmCoord problem_size; | |||||
| /// Device pointers to input and output matrices | |||||
| void const* A; | |||||
| void const* B; | |||||
| void const* C; | |||||
| void* D; | |||||
| /// Leading dimensions of input and output matrices | |||||
| int64_t lda; | |||||
| int64_t ldb; | |||||
| int64_t ldc; | |||||
| int64_t ldd; | |||||
| /// Number of partitions of K dimension | |||||
| int split_k_slices; | |||||
| /// Host or device pointers to epilogue scalars; note that these pointers | |||||
| /// are reinterpreted as ElementCompute* inside `op->run(args)`, so passing | |||||
| /// a different dtype here results in undefined epilogue behavior | |||||
| void const* alpha; | |||||
| void const* beta; | |||||
| }; | |||||
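A sketch of filling these arguments for a row-major f32 problem; the device pointers and extents are assumed, and alpha/beta must genuinely point at the operation's ElementCompute type (float here), as the note above warns:

    cutlass::library::GemmArguments make_example_args(
            int M, int N, int K,
            float const* d_A, float const* d_B, float const* d_C, float* d_D,
            float const* alpha, float const* beta) {
        cutlass::library::GemmArguments args;
        args.problem_size = cutlass::gemm::GemmCoord(M, N, K);
        args.A = d_A; args.lda = K;    // row-major A: leading dimension = K
        args.B = d_B; args.ldb = N;    // row-major B: leading dimension = N
        args.C = d_C; args.ldc = N;
        args.D = d_D; args.ldd = N;
        args.split_k_slices = 1;       // no split-K
        args.alpha = alpha;            // read back as ElementCompute const*
        args.beta = beta;
        return args;
    }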
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| struct ConvolutionDescription : public OperationDescription { | |||||
| conv::Operator conv_op; | |||||
| TensorDescription src; | |||||
| TensorDescription filter; | |||||
| TensorDescription dst; | |||||
| TensorDescription bias; | |||||
| conv::ConvType convolution_type; | |||||
| ArchTagID arch_tag; | |||||
| epilogue::EpilogueType epilogue_type; | |||||
| int epilogue_count; | |||||
| ThreadblockSwizzleID threadblock_swizzle; | |||||
| bool need_load_from_const_mem; | |||||
| conv::ImplicitGemmMode gemm_mode; | |||||
| bool without_shared_load; | |||||
| }; | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| struct ConvolutionArguments { | |||||
| /// Problem size | |||||
| conv::Conv2dProblemSize problem_size; | |||||
| /// Device pointers to input and output tensors | |||||
| void const* src; | |||||
| void const* filter; | |||||
| void const* bias; | |||||
| void const* z; | |||||
| void* dst; | |||||
| /// Host or device pointers to epilogue scalars; note that these pointers | |||||
| /// are reinterpreted as ElementCompute* inside `op->run(args)`, so passing | |||||
| /// a different dtype here results in undefined epilogue behavior | |||||
| void const* alpha; | |||||
| void const* beta; | |||||
| void const* gamma; | |||||
| void const* delta; | |||||
| void const* theta; | |||||
| void const* threshold; | |||||
| void const* scale; | |||||
| /// Host pointer to extra param struct | |||||
| void const* extra_param; | |||||
| }; | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| /// Base class for all operations | |||||
| class Operation { | |||||
| public: | |||||
| virtual ~Operation() {} | |||||
| virtual OperationDescription const& description() const = 0; | |||||
| virtual Status run(void const* arguments, void* device_workspace = nullptr, | |||||
| cudaStream_t stream = nullptr) const = 0; | |||||
| }; | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| } // namespace library | |||||
| } // namespace cutlass | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| @@ -0,0 +1,580 @@ | |||||
| /*************************************************************************************************** | |||||
| * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. | |||||
| * | |||||
| * Redistribution and use in source and binary forms, with or without | |||||
| *modification, are permitted provided that the following conditions are met: | |||||
| * * Redistributions of source code must retain the above copyright notice, | |||||
| *this list of conditions and the following disclaimer. | |||||
| * * Redistributions in binary form must reproduce the above copyright | |||||
| *notice, this list of conditions and the following disclaimer in the | |||||
| *documentation and/or other materials provided with the distribution. | |||||
| * * Neither the name of the NVIDIA CORPORATION nor the names of its | |||||
| *contributors may be used to endorse or promote products derived from this | |||||
| *software without specific prior written permission. | |||||
| * | |||||
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
| *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
| *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |||||
| *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, | |||||
| *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, | |||||
| * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | |||||
| *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY | |||||
| *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING | |||||
| *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, | |||||
| *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||||
| * | |||||
| **************************************************************************************************/ | |||||
| /** | |||||
| * \file dnn/src/cuda/cutlass/library_internal.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #pragma once | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wreorder" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #include "cutlass/arch/arch.h" | |||||
| #include "cutlass/arch/mma.h" | |||||
| #include "cutlass/complex.h" | |||||
| #include "cutlass/convolution/threadblock/threadblock_swizzle.h" | |||||
| #include "cutlass/cutlass.h" | |||||
| #include "cutlass/gemm/threadblock/threadblock_swizzle.h" | |||||
| #include "cutlass/layout/matrix.h" | |||||
| #include "cutlass/numeric_types.h" | |||||
| #pragma GCC diagnostic pop | |||||
| #include "src/cuda/cutlass/arch_mappings.h" | |||||
| #include "src/cuda/cutlass/library.h" | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| namespace cutlass { | |||||
| namespace library { | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| template <typename T> | |||||
| struct NumericTypeMap; | |||||
| template <> | |||||
| struct NumericTypeMap<cutlass::uint1b_t> { | |||||
| static NumericTypeID const kId = NumericTypeID::kB1; | |||||
| }; | |||||
| template <> | |||||
| struct NumericTypeMap<cutlass::int4b_t> { | |||||
| static NumericTypeID const kId = NumericTypeID::kS4; | |||||
| }; | |||||
| template <> | |||||
| struct NumericTypeMap<int8_t> { | |||||
| static NumericTypeID const kId = NumericTypeID::kS8; | |||||
| }; | |||||
| template <> | |||||
| struct NumericTypeMap<int16_t> { | |||||
| static NumericTypeID const kId = NumericTypeID::kS16; | |||||
| }; | |||||
| template <> | |||||
| struct NumericTypeMap<int32_t> { | |||||
| static NumericTypeID const kId = NumericTypeID::kS32; | |||||
| }; | |||||
| template <> | |||||
| struct NumericTypeMap<int64_t> { | |||||
| static NumericTypeID const kId = NumericTypeID::kS64; | |||||
| }; | |||||
| template <> | |||||
| struct NumericTypeMap<cutlass::uint4b_t> { | |||||
| static NumericTypeID const kId = NumericTypeID::kU4; | |||||
| }; | |||||
| template <> | |||||
| struct NumericTypeMap<uint8_t> { | |||||
| static NumericTypeID const kId = NumericTypeID::kU8; | |||||
| }; | |||||
| template <> | |||||
| struct NumericTypeMap<uint16_t> { | |||||
| static NumericTypeID const kId = NumericTypeID::kU16; | |||||
| }; | |||||
| template <> | |||||
| struct NumericTypeMap<uint32_t> { | |||||
| static NumericTypeID const kId = NumericTypeID::kU32; | |||||
| }; | |||||
| template <> | |||||
| struct NumericTypeMap<uint64_t> { | |||||
| static NumericTypeID const kId = NumericTypeID::kU64; | |||||
| }; | |||||
| template <> | |||||
| struct NumericTypeMap<cutlass::half_t> { | |||||
| static NumericTypeID const kId = NumericTypeID::kF16; | |||||
| }; | |||||
| template <> | |||||
| struct NumericTypeMap<float> { | |||||
| static NumericTypeID const kId = NumericTypeID::kF32; | |||||
| }; | |||||
| template <> | |||||
| struct NumericTypeMap<double> { | |||||
| static NumericTypeID const kId = NumericTypeID::kF64; | |||||
| }; | |||||
| template <> | |||||
| struct NumericTypeMap<cutlass::complex<cutlass::half_t>> { | |||||
| static NumericTypeID const kId = NumericTypeID::kCF16; | |||||
| }; | |||||
| template <> | |||||
| struct NumericTypeMap<cutlass::complex<float>> { | |||||
| static NumericTypeID const kId = NumericTypeID::kCF32; | |||||
| }; | |||||
| template <> | |||||
| struct NumericTypeMap<cutlass::complex<double>> { | |||||
| static NumericTypeID const kId = NumericTypeID::kCF64; | |||||
| }; | |||||
| template <> | |||||
| struct NumericTypeMap<cutlass::bfloat16_t> { | |||||
| static NumericTypeID const kId = NumericTypeID::kBF16; | |||||
| }; | |||||
| template <> | |||||
| struct NumericTypeMap<cutlass::tfloat32_t> { | |||||
| static NumericTypeID const kId = NumericTypeID::kTF32; | |||||
| }; | |||||
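These maps are purely compile-time; the pattern can be sanity-checked with static_assert (illustrative):

    // The trait is consulted at compile time; no object is ever created.
    static_assert(cutlass::library::NumericTypeMap<float>::kId ==
                          cutlass::library::NumericTypeID::kF32,
                  "float maps to kF32");
    static_assert(cutlass::library::NumericTypeMap<int8_t>::kId ==
                          cutlass::library::NumericTypeID::kS8,
                  "int8_t maps to kS8");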
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| template <typename T> | |||||
| struct MathOperationMap { | |||||
| static MathOperationID const kId = MathOperationID::kInvalid; | |||||
| }; | |||||
| template <> | |||||
| struct MathOperationMap<cutlass::arch::OpMultiplyAdd> { | |||||
| static MathOperationID const kId = MathOperationID::kMultiplyAdd; | |||||
| }; | |||||
| template <> | |||||
| struct MathOperationMap<cutlass::arch::OpMultiplyAddFastBF16> { | |||||
| static MathOperationID const kId = MathOperationID::kMultiplyAddFastBF16; | |||||
| }; | |||||
| template <> | |||||
| struct MathOperationMap<cutlass::arch::OpMultiplyAddFastF16> { | |||||
| static MathOperationID const kId = MathOperationID::kMultiplyAddFastF16; | |||||
| }; | |||||
| template <> | |||||
| struct MathOperationMap<cutlass::arch::OpMultiplyAddSaturate> { | |||||
| static MathOperationID const kId = MathOperationID::kMultiplyAddSaturate; | |||||
| }; | |||||
| template <> | |||||
| struct MathOperationMap<cutlass::arch::OpMultiplyAddComplex> { | |||||
| static MathOperationID const kId = MathOperationID::kMultiplyAddComplex; | |||||
| }; | |||||
| template <> | |||||
| struct MathOperationMap<cutlass::arch::OpMultiplyAddGaussianComplex> { | |||||
| static MathOperationID const kId = | |||||
| MathOperationID::kMultiplyAddGaussianComplex; | |||||
| }; | |||||
| template <> | |||||
| struct MathOperationMap<cutlass::arch::OpXorPopc> { | |||||
| static MathOperationID const kId = MathOperationID::kXorPopc; | |||||
| }; | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| template <typename T> | |||||
| struct LayoutMap; | |||||
| template <> | |||||
| struct LayoutMap<cutlass::layout::ColumnMajor> { | |||||
| static LayoutTypeID const kId = LayoutTypeID::kColumnMajor; | |||||
| }; | |||||
| template <> | |||||
| struct LayoutMap<cutlass::layout::RowMajor> { | |||||
| static LayoutTypeID const kId = LayoutTypeID::kRowMajor; | |||||
| }; | |||||
| template <> | |||||
| struct LayoutMap<cutlass::layout::ColumnMajorInterleaved<2>> { | |||||
| static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK2; | |||||
| }; | |||||
| template <> | |||||
| struct LayoutMap<cutlass::layout::RowMajorInterleaved<2>> { | |||||
| static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK2; | |||||
| }; | |||||
| template <> | |||||
| struct LayoutMap<cutlass::layout::ColumnMajorInterleaved<4>> { | |||||
| static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK4; | |||||
| }; | |||||
| template <> | |||||
| struct LayoutMap<cutlass::layout::RowMajorInterleaved<4>> { | |||||
| static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK4; | |||||
| }; | |||||
| template <> | |||||
| struct LayoutMap<cutlass::layout::ColumnMajorInterleaved<16>> { | |||||
| static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK16; | |||||
| }; | |||||
| template <> | |||||
| struct LayoutMap<cutlass::layout::RowMajorInterleaved<16>> { | |||||
| static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK16; | |||||
| }; | |||||
| template <> | |||||
| struct LayoutMap<cutlass::layout::ColumnMajorInterleaved<32>> { | |||||
| static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK32; | |||||
| }; | |||||
| template <> | |||||
| struct LayoutMap<cutlass::layout::RowMajorInterleaved<32>> { | |||||
| static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK32; | |||||
| }; | |||||
| template <> | |||||
| struct LayoutMap<cutlass::layout::ColumnMajorInterleaved<64>> { | |||||
| static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK64; | |||||
| }; | |||||
| template <> | |||||
| struct LayoutMap<cutlass::layout::RowMajorInterleaved<64>> { | |||||
| static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK64; | |||||
| }; | |||||
| template <> | |||||
| struct LayoutMap<cutlass::layout::TensorNCHW> { | |||||
| static LayoutTypeID const kId = LayoutTypeID::kTensorNCHW; | |||||
| }; | |||||
| template <> | |||||
| struct LayoutMap<cutlass::layout::TensorNHWC> { | |||||
| static LayoutTypeID const kId = LayoutTypeID::kTensorNHWC; | |||||
| }; | |||||
| template <> | |||||
| struct LayoutMap<cutlass::layout::TensorNDHWC> { | |||||
| static LayoutTypeID const kId = LayoutTypeID::kTensorNDHWC; | |||||
| }; | |||||
| template <> | |||||
| struct LayoutMap<cutlass::layout::TensorNCxHWx<4>> { | |||||
| static LayoutTypeID const kId = LayoutTypeID::kTensorNC4HW4; | |||||
| }; | |||||
| template <> | |||||
| struct LayoutMap<cutlass::layout::TensorNCxHWx<8>> { | |||||
| static LayoutTypeID const kId = LayoutTypeID::kTensorNC8HW8; | |||||
| }; | |||||
| template <> | |||||
| struct LayoutMap<cutlass::layout::TensorNCxHWx<16>> { | |||||
| static LayoutTypeID const kId = LayoutTypeID::kTensorNC16HW16; | |||||
| }; | |||||
| template <> | |||||
| struct LayoutMap<cutlass::layout::TensorNCxHWx<32>> { | |||||
| static LayoutTypeID const kId = LayoutTypeID::kTensorNC32HW32; | |||||
| }; | |||||
| template <> | |||||
| struct LayoutMap<cutlass::layout::TensorNCxHWx<64>> { | |||||
| static LayoutTypeID const kId = LayoutTypeID::kTensorNC64HW64; | |||||
| }; | |||||
| template <> | |||||
| struct LayoutMap<cutlass::layout::TensorCxRSKx<4>> { | |||||
| static LayoutTypeID const kId = LayoutTypeID::kTensorC4RSK4; | |||||
| }; | |||||
| template <> | |||||
| struct LayoutMap<cutlass::layout::TensorCxRSKx<8>> { | |||||
| static LayoutTypeID const kId = LayoutTypeID::kTensorC8RSK8; | |||||
| }; | |||||
| template <> | |||||
| struct LayoutMap<cutlass::layout::TensorCxRSKx<16>> { | |||||
| static LayoutTypeID const kId = LayoutTypeID::kTensorC16RSK16; | |||||
| }; | |||||
| template <> | |||||
| struct LayoutMap<cutlass::layout::TensorCxRSKx<32>> { | |||||
| static LayoutTypeID const kId = LayoutTypeID::kTensorC32RSK32; | |||||
| }; | |||||
| template <> | |||||
| struct LayoutMap<cutlass::layout::TensorCxRSKx<64>> { | |||||
| static LayoutTypeID const kId = LayoutTypeID::kTensorC64RSK64; | |||||
| }; | |||||
| template <> | |||||
| struct LayoutMap<cutlass::layout::TensorKxRSCx<4>> { | |||||
| static LayoutTypeID const kId = LayoutTypeID::kTensorK4RSC4; | |||||
| }; | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| template <typename T> | |||||
| struct OpcodeClassMap; | |||||
| template <> | |||||
| struct OpcodeClassMap<arch::OpClassSimt> { | |||||
| static OpcodeClassID const kId = OpcodeClassID::kSimt; | |||||
| }; | |||||
| template <> | |||||
| struct OpcodeClassMap<arch::OpClassTensorOp> { | |||||
| static OpcodeClassID const kId = OpcodeClassID::kTensorOp; | |||||
| }; | |||||
| template <> | |||||
| struct OpcodeClassMap<arch::OpClassWmmaTensorOp> { | |||||
| static OpcodeClassID const kId = OpcodeClassID::kWmmaTensorOp; | |||||
| }; | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| template <typename T> | |||||
| struct ArchTagMap; | |||||
| template <> | |||||
| struct ArchTagMap<arch::Sm50> { | |||||
| static ArchTagID const kId = ArchTagID::kSm50; | |||||
| }; | |||||
| template <> | |||||
| struct ArchTagMap<arch::Sm60> { | |||||
| static ArchTagID const kId = ArchTagID::kSm60; | |||||
| }; | |||||
| template <> | |||||
| struct ArchTagMap<arch::Sm61> { | |||||
| static ArchTagID const kId = ArchTagID::kSm61; | |||||
| }; | |||||
| template <> | |||||
| struct ArchTagMap<arch::Sm70> { | |||||
| static ArchTagID const kId = ArchTagID::kSm70; | |||||
| }; | |||||
| template <> | |||||
| struct ArchTagMap<arch::Sm72> { | |||||
| static ArchTagID const kId = ArchTagID::kSm72; | |||||
| }; | |||||
| template <> | |||||
| struct ArchTagMap<arch::Sm75> { | |||||
| static ArchTagID const kId = ArchTagID::kSm75; | |||||
| }; | |||||
| template <> | |||||
| struct ArchTagMap<arch::Sm80> { | |||||
| static ArchTagID const kId = ArchTagID::kSm80; | |||||
| }; | |||||
| template <> | |||||
| struct ArchTagMap<arch::Sm86> { | |||||
| static ArchTagID const kId = ArchTagID::kSm86; | |||||
| }; | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| template <cutlass::ComplexTransform Transform> | |||||
| struct ComplexTransformMap; | |||||
| template <> | |||||
| struct ComplexTransformMap<cutlass::ComplexTransform::kNone> { | |||||
| static cutlass::library::ComplexTransform const kId = | |||||
| cutlass::library::ComplexTransform::kNone; | |||||
| }; | |||||
| template <> | |||||
| struct ComplexTransformMap<cutlass::ComplexTransform::kConjugate> { | |||||
| static cutlass::library::ComplexTransform const kId = | |||||
| cutlass::library::ComplexTransform::kConjugate; | |||||
| }; | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| template <cutlass::conv::Mode T> | |||||
| struct ConvModeMap; | |||||
| template <> | |||||
| struct ConvModeMap<conv::Mode::kCrossCorrelation> { | |||||
| static ConvModeID const kId = ConvModeID::kCrossCorrelation; | |||||
| }; | |||||
| template <> | |||||
| struct ConvModeMap<conv::Mode::kConvolution> { | |||||
| static ConvModeID const kId = ConvModeID::kConvolution; | |||||
| }; | |||||
| template <cutlass::conv::Operator T> | |||||
| struct ConvKindMap; | |||||
| template <> | |||||
| struct ConvKindMap<conv::Operator::kFprop> { | |||||
| static ConvKind const kId = ConvKind::kFprop; | |||||
| }; | |||||
| template <> | |||||
| struct ConvKindMap<conv::Operator::kDgrad> { | |||||
| static ConvKind const kId = ConvKind::kDgrad; | |||||
| }; | |||||
| template <> | |||||
| struct ConvKindMap<conv::Operator::kWgrad> { | |||||
| static ConvKind const kId = ConvKind::kWgrad; | |||||
| }; | |||||
| template <cutlass::conv::IteratorAlgorithm T> | |||||
| struct IteratorAlgorithmMap; | |||||
| template <> | |||||
| struct IteratorAlgorithmMap<conv::IteratorAlgorithm::kAnalytic> { | |||||
| static IteratorAlgorithmID const kId = IteratorAlgorithmID::kAnalytic; | |||||
| }; | |||||
| template <> | |||||
| struct IteratorAlgorithmMap<conv::IteratorAlgorithm::kOptimized> { | |||||
| static IteratorAlgorithmID const kId = IteratorAlgorithmID::kOptimized; | |||||
| }; | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| template <typename T> | |||||
| struct ThreadblockSwizzleMap; | |||||
| template <int N> | |||||
| struct ThreadblockSwizzleMap< | |||||
| gemm::threadblock::GemmIdentityThreadblockSwizzle<N>> { | |||||
| static ThreadblockSwizzleID const kId = ThreadblockSwizzleID::kGemmIdentity; | |||||
| }; | |||||
| template <> | |||||
| struct ThreadblockSwizzleMap< | |||||
| gemm::threadblock::GemmHorizontalThreadblockSwizzle> { | |||||
| static ThreadblockSwizzleID const kId = | |||||
| ThreadblockSwizzleID::kGemmHorizontal; | |||||
| }; | |||||
| template <> | |||||
| struct ThreadblockSwizzleMap< | |||||
| gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle> { | |||||
| static ThreadblockSwizzleID const kId = | |||||
| ThreadblockSwizzleID::kGemmBatchedIdentity; | |||||
| }; | |||||
| template <int N> | |||||
| struct ThreadblockSwizzleMap< | |||||
| gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<N>> { | |||||
| static ThreadblockSwizzleID const kId = | |||||
| ThreadblockSwizzleID::kGemmSplitKIdentity; | |||||
| }; | |||||
| template <> | |||||
| struct ThreadblockSwizzleMap< | |||||
| gemm::threadblock::GemmSplitKHorizontalThreadblockSwizzle> { | |||||
| static ThreadblockSwizzleID const kId = | |||||
| ThreadblockSwizzleID::kGemmSplitKHorizontal; | |||||
| }; | |||||
| template <> | |||||
| struct ThreadblockSwizzleMap< | |||||
| gemm::threadblock::GemvBatchedStridedThreadblockDefaultSwizzle> { | |||||
| static ThreadblockSwizzleID const kId = | |||||
| ThreadblockSwizzleID::kGemvBatchedStridedDefault; | |||||
| }; | |||||
| template <> | |||||
| struct ThreadblockSwizzleMap< | |||||
| gemm::threadblock::GemvBatchedStridedThreadblockReductionSwizzle> { | |||||
| static ThreadblockSwizzleID const kId = | |||||
| ThreadblockSwizzleID::kGemvBatchedStridedReduction; | |||||
| }; | |||||
| template <> | |||||
| struct ThreadblockSwizzleMap< | |||||
| conv::threadblock::ConvolutionFpropCxRSKxThreadblockSwizzle> { | |||||
| static ThreadblockSwizzleID const kId = | |||||
| ThreadblockSwizzleID::kConvolutionFpropCxRSKx; | |||||
| }; | |||||
| template <> | |||||
| struct ThreadblockSwizzleMap< | |||||
| conv::threadblock::ConvolutionDgradCxRSKxThreadblockSwizzle> { | |||||
| static ThreadblockSwizzleID const kId = | |||||
| ThreadblockSwizzleID::kConvolutionDgradCxRSKx; | |||||
| }; | |||||
| template <> | |||||
| struct ThreadblockSwizzleMap< | |||||
| conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle> { | |||||
| static ThreadblockSwizzleID const kId = | |||||
| ThreadblockSwizzleID::kConvolutionFpropNCxHWx; | |||||
| }; | |||||
| template <> | |||||
| struct ThreadblockSwizzleMap< | |||||
| conv::threadblock::ConvolutionFpropTransThreadblockSwizzle> { | |||||
| static ThreadblockSwizzleID const kId = | |||||
| ThreadblockSwizzleID::kConvolutionFpropTrans; | |||||
| }; | |||||
| template <> | |||||
| struct ThreadblockSwizzleMap< | |||||
| conv::threadblock::ConvolutionDgradNCxHWxThreadblockSwizzle> { | |||||
| static ThreadblockSwizzleID const kId = | |||||
| ThreadblockSwizzleID::kConvolutionDgradNCxHWx; | |||||
| }; | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| template <typename Element, typename Layout> | |||||
| TensorDescription make_TensorDescription(int alignment = 1) { | |||||
| TensorDescription desc; | |||||
| desc.element = NumericTypeMap<Element>::kId; | |||||
| desc.layout = LayoutMap<Layout>::kId; | |||||
| desc.alignment = alignment; | |||||
| desc.log_extent_range = | |||||
| int(sizeof(typename Layout::TensorCoord::Index) - 1) * 8; | |||||
| desc.log_stride_range = int(sizeof(typename Layout::Stride::Index) - 1) * 8; | |||||
| return desc; | |||||
| } | |||||
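A worked example of the helper above, assuming RowMajor's Index and Stride index types are both 32-bit (true of the common CUTLASS matrix layouts; a 64-bit stride index would yield 56 for the last field instead):

    // (sizeof(int32_t) - 1) * 8 == 24, matching TensorDescription's defaults.
    auto desc = cutlass::library::make_TensorDescription<
            float, cutlass::layout::RowMajor>(/* alignment = */ 4);
    // desc.element          == NumericTypeID::kF32
    // desc.layout           == LayoutTypeID::kRowMajor
    // desc.alignment        == 4
    // desc.log_extent_range == 24
    // desc.log_stride_range == 24 (under the 32-bit stride assumption)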
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| } // namespace library | |||||
| } // namespace cutlass | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| @@ -0,0 +1,96 @@ | |||||
| /*************************************************************************************************** | |||||
| * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. | |||||
| * | |||||
| * Redistribution and use in source and binary forms, with or without | |||||
| *modification, are permitted provided that the following conditions are met: | |||||
| * * Redistributions of source code must retain the above copyright notice, | |||||
| *this list of conditions and the following disclaimer. | |||||
| * * Redistributions in binary form must reproduce the above copyright | |||||
| *notice, this list of conditions and the following disclaimer in the | |||||
| *documentation and/or other materials provided with the distribution. | |||||
| * * Neither the name of the NVIDIA CORPORATION nor the names of its | |||||
| *contributors may be used to endorse or promote products derived from this | |||||
| *software without specific prior written permission. | |||||
| * | |||||
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
| *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
| *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |||||
| *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, | |||||
| *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, | |||||
| * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | |||||
| *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY | |||||
| *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING | |||||
| *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, | |||||
| *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||||
| * | |||||
| **************************************************************************************************/ | |||||
| /** | |||||
| * \file dnn/src/cuda/cutlass/manifest.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include <memory> | |||||
| #include "src/cuda/cutlass/manifest.h" | |||||
| namespace cutlass { | |||||
| namespace library { | |||||
| ////////////////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| /// Top-level initialization | |||||
| Status Manifest::initialize() { | |||||
| if (!operations_.empty()) { | |||||
| operations_.clear(); | |||||
| } | |||||
| // initialize procedurally generated cutlass ops in the manifest object | |||||
| initialize_all(*this); | |||||
| return Status::kSuccess; | |||||
| } | |||||
| /// Used for initialization | |||||
| void Manifest::reserve(size_t operation_count) { | |||||
| operations_.reserve(operation_count); | |||||
| } | |||||
| /// Graceful shutdown | |||||
| Status Manifest::release() { | |||||
| operations_.clear(); | |||||
| return Status::kSuccess; | |||||
| } | |||||
| /// Appends an operation and takes ownership | |||||
| void Manifest::append(Operation* operation_ptr) { | |||||
| operations_.emplace_back(operation_ptr); | |||||
| } | |||||
| /// Returns the vector of all registered operations | |||||
| OperationVector const& Manifest::operations() const { | |||||
| return operations_; | |||||
| } | |||||
| /// Returns a const iterator | |||||
| OperationVector::const_iterator Manifest::begin() const { | |||||
| return operations_.begin(); | |||||
| } | |||||
| /// Returns a const iterator | |||||
| OperationVector::const_iterator Manifest::end() const { | |||||
| return operations_.end(); | |||||
| } | |||||
| /////////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| } // namespace library | |||||
| } // namespace cutlass | |||||
| /////////////////////////////////////////////////////////////////////////////////////////////////// | |||||
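Typical consumption once the manifest is initialized; the filtering policy here is illustrative:

    #include <cstdio>
    #include "src/cuda/cutlass/manifest.h"

    // Walk every registered operation and print the GEMM kernels; real
    // callers match on the full description before invoking Operation::run().
    void list_gemms(cutlass::library::Manifest const& manifest) {
        for (auto const& op : manifest) {
            auto const& desc = op->description();
            if (desc.kind == cutlass::library::OperationKind::kGemm)
                printf("%s\n", desc.name);
        }
    }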
| @@ -0,0 +1,108 @@ | |||||
| /*************************************************************************************************** | |||||
| * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. | |||||
| * | |||||
| * Redistribution and use in source and binary forms, with or without | |||||
| *modification, are permitted provided that the following conditions are met: | |||||
| * * Redistributions of source code must retain the above copyright notice, | |||||
| *this list of conditions and the following disclaimer. | |||||
| * * Redistributions in binary form must reproduce the above copyright | |||||
| *notice, this list of conditions and the following disclaimer in the | |||||
| *documentation and/or other materials provided with the distribution. | |||||
| * * Neither the name of the NVIDIA CORPORATION nor the names of its | |||||
| *contributors may be used to endorse or promote products derived from this | |||||
| *software without specific prior written permission. | |||||
| * | |||||
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
| *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
| *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |||||
| *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, | |||||
| *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, | |||||
| * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | |||||
| *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY | |||||
| *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING | |||||
| *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, | |||||
| *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||||
| * | |||||
| **************************************************************************************************/ | |||||
| /** | |||||
| * \file dnn/src/cuda/cutlass/manifest.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include <list> | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include "src/cuda/cutlass/library.h" | |||||
| /////////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| namespace cutlass { | |||||
| namespace library { | |||||
| /////////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| // Forward declaration | |||||
| class Manifest; | |||||
| // init and insert all cutlass operations into the manifest object (procedurally | |||||
| // generated using generator.py) | |||||
| void initialize_all(Manifest& manifest); | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| /// List of operations | |||||
| using OperationVector = std::vector<std::unique_ptr<Operation>>; | |||||
| /////////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| /// Manifest of CUTLASS Library | |||||
| class Manifest { | |||||
| private: | |||||
| /// Operation provider | |||||
| Provider provider_; | |||||
| /// Global list of operations | |||||
| OperationVector operations_; | |||||
| public: | |||||
| Manifest(Provider provider = library::Provider::kCUTLASS) | |||||
| : provider_(provider) {} | |||||
| /// Top-level initialization | |||||
| Status initialize(); | |||||
| /// Used for initialization | |||||
| void reserve(size_t operation_count); | |||||
| /// Graceful shutdown | |||||
| Status release(); | |||||
| /// Appends an operation and takes ownership | |||||
| void append(Operation* operation_ptr); | |||||
| /// Returns the vector of operations | |||||
| OperationVector const& operations() const; | |||||
| /// Returns a const iterator to the first operation | |||||
| OperationVector::const_iterator begin() const; | |||||
| /// Returns a const iterator past the last operation | |||||
| OperationVector::const_iterator end() const; | |||||
| }; | |||||
| /////////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| } // namespace library | |||||
| } // namespace cutlass | |||||
| /////////////////////////////////////////////////////////////////////////////////////////////////// | |||||
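| // A minimal usage sketch of the Manifest API above. `ExampleOperation` is | |||||
| // hypothetical; the real operations are emitted by generator.py and | |||||
| // registered through initialize_all(): | |||||
| // | |||||
| //     void initialize_example(cutlass::library::Manifest& manifest) { | |||||
| //         manifest.reserve(1); | |||||
| //         manifest.append(new ExampleOperation);  // manifest takes ownership | |||||
| //     } | |||||
| // | |||||
| //     // consumers can range-iterate over every registered operation | |||||
| //     for (auto const& op : manifest) { /* op->description() ... */ } | |||||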
| @@ -0,0 +1,179 @@ | |||||
| /*************************************************************************************************** | |||||
| * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | |||||
| * | |||||
| * Redistribution and use in source and binary forms, with or without | |||||
| *modification, are permitted provided that the following conditions are met: | |||||
| * * Redistributions of source code must retain the above copyright notice, | |||||
| *this list of conditions and the following disclaimer. | |||||
| * * Redistributions in binary form must reproduce the above copyright | |||||
| *notice, this list of conditions and the following disclaimer in the | |||||
| *documentation and/or other materials provided with the distribution. | |||||
| * * Neither the name of the NVIDIA CORPORATION nor the names of its | |||||
| *contributors may be used to endorse or promote products derived from this | |||||
| *software without specific prior written permission. | |||||
| * | |||||
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
| *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
| *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |||||
| *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, | |||||
| *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, | |||||
| * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | |||||
| *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY | |||||
| *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING | |||||
| *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, | |||||
| *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||||
| * | |||||
| **************************************************************************************************/ | |||||
| /** | |||||
| * \file dnn/src/cuda/cutlass/operation_table.cu | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "src/common/utils.h" | |||||
| #include "src/cuda/cutlass/operation_table.h" | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| namespace cutlass { | |||||
| namespace library { | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| GemmKey get_gemm_key_from_desc(const GemmDescription& desc) { | |||||
| GemmKey key; | |||||
| key.element_A = desc.A.element; | |||||
| key.layout_A = desc.A.layout; | |||||
| key.element_B = desc.B.element; | |||||
| key.layout_B = desc.B.layout; | |||||
| key.element_C = desc.C.element; | |||||
| key.layout_C = desc.C.layout; | |||||
| key.threadblock_shape_m = desc.tile_description.threadblock_shape.m(); | |||||
| key.threadblock_shape_n = desc.tile_description.threadblock_shape.n(); | |||||
| key.threadblock_shape_k = desc.tile_description.threadblock_shape.k(); | |||||
| key.warp_shape_m = desc.tile_description.threadblock_shape.m() / | |||||
| desc.tile_description.warp_count.m(); | |||||
| key.warp_shape_n = desc.tile_description.threadblock_shape.n() / | |||||
| desc.tile_description.warp_count.n(); | |||||
| key.warp_shape_k = desc.tile_description.threadblock_shape.k() / | |||||
| desc.tile_description.warp_count.k(); | |||||
| key.instruction_shape_m = | |||||
| desc.tile_description.math_instruction.instruction_shape.m(); | |||||
| key.instruction_shape_n = | |||||
| desc.tile_description.math_instruction.instruction_shape.n(); | |||||
| key.instruction_shape_k = | |||||
| desc.tile_description.math_instruction.instruction_shape.k(); | |||||
| key.stages = desc.stages; | |||||
| key.split_k_mode = desc.split_k_mode; | |||||
| return key; | |||||
| } | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| ConvolutionKey get_convolution_key_from_desc( | |||||
| const ConvolutionDescription& desc) { | |||||
| ConvolutionKey key; | |||||
| key.conv_op = desc.conv_op; | |||||
| key.element_src = desc.src.element; | |||||
| key.layout_src = desc.src.layout; | |||||
| key.element_filter = desc.filter.element; | |||||
| key.layout_filter = desc.filter.layout; | |||||
| key.element_dst = desc.dst.element; | |||||
| key.layout_dst = desc.dst.layout; | |||||
| key.element_bias = desc.bias.element; | |||||
| key.layout_bias = desc.bias.layout; | |||||
| key.convolution_type = desc.convolution_type; | |||||
| key.threadblock_shape_m = desc.tile_description.threadblock_shape.m(); | |||||
| key.threadblock_shape_n = desc.tile_description.threadblock_shape.n(); | |||||
| key.threadblock_shape_k = desc.tile_description.threadblock_shape.k(); | |||||
| key.warp_shape_m = desc.tile_description.threadblock_shape.m() / | |||||
| desc.tile_description.warp_count.m(); | |||||
| key.warp_shape_n = desc.tile_description.threadblock_shape.n() / | |||||
| desc.tile_description.warp_count.n(); | |||||
| key.warp_shape_k = desc.tile_description.threadblock_shape.k() / | |||||
| desc.tile_description.warp_count.k(); | |||||
| key.instruction_shape_m = | |||||
| desc.tile_description.math_instruction.instruction_shape.m(); | |||||
| key.instruction_shape_n = | |||||
| desc.tile_description.math_instruction.instruction_shape.n(); | |||||
| key.instruction_shape_k = | |||||
| desc.tile_description.math_instruction.instruction_shape.k(); | |||||
| key.epilogue_type = desc.epilogue_type; | |||||
| key.stages = desc.tile_description.threadblock_stages; | |||||
| key.need_load_from_const_mem = desc.need_load_from_const_mem; | |||||
| key.without_shared_load = desc.without_shared_load; | |||||
| return key; | |||||
| } | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| void OperationTable::append(Manifest const& manifest) { | |||||
| // Insert operations into appropriate data structure | |||||
| for (auto const& operation : manifest) { | |||||
| OperationDescription const& desc = operation->description(); | |||||
| // insert all gemm operations into operation table | |||||
| if (desc.kind == OperationKind::kGemm) { | |||||
| GemmKey key = get_gemm_key_from_desc( | |||||
| static_cast<GemmDescription const&>(desc)); | |||||
| gemm_operations[key].push_back(operation.get()); | |||||
| } | |||||
| // insert all conv operations into operation table | |||||
| if (desc.kind == OperationKind::kConvolution) { | |||||
| ConvolutionKey key = get_convolution_key_from_desc( | |||||
| static_cast<ConvolutionDescription const&>(desc)); | |||||
| convolution_operations[key].push_back(operation.get()); | |||||
| } | |||||
| } | |||||
| } | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| Operation const* OperationTable::find_op(GemmKey const& key) const { | |||||
| megdnn_assert(gemm_operations.count(key) > 0, | |||||
| "key not found in cutlass operation table"); | |||||
| auto const& ops = gemm_operations.at(key); | |||||
| megdnn_assert(ops.size() == 1, "exactly one kernel expected, got %zu", | |||||
| ops.size()); | |||||
| return ops[0]; | |||||
| } | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| Operation const* OperationTable::find_op(ConvolutionKey const& key) const { | |||||
| megdnn_assert(convolution_operations.count(key) > 0, | |||||
| "key not found in cutlass operation table"); | |||||
| auto const& ops = convolution_operations.at(key); | |||||
| megdnn_assert(ops.size() == 1, "exactly one kernel expected, got %zu", | |||||
| ops.size()); | |||||
| return ops[0]; | |||||
| } | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| } // namespace library | |||||
| } // namespace cutlass | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| @@ -0,0 +1,334 @@ | |||||
| /*************************************************************************************************** | |||||
| * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | |||||
| * | |||||
| * Redistribution and use in source and binary forms, with or without | |||||
| *modification, are permitted provided that the following conditions are met: | |||||
| * * Redistributions of source code must retain the above copyright notice, | |||||
| *this list of conditions and the following disclaimer. | |||||
| * * Redistributions in binary form must reproduce the above copyright | |||||
| *notice, this list of conditions and the following disclaimer in the | |||||
| *documentation and/or other materials provided with the distribution. | |||||
| * * Neither the name of the NVIDIA CORPORATION nor the names of its | |||||
| *contributors may be used to endorse or promote products derived from this | |||||
| *software without specific prior written permission. | |||||
| * | |||||
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
| *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
| *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |||||
| *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, | |||||
| *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, | |||||
| * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | |||||
| *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY | |||||
| *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING | |||||
| *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, | |||||
| *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||||
| * | |||||
| **************************************************************************************************/ | |||||
| /** | |||||
| * \file dnn/src/cuda/cutlass/operation_table.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include <string> | |||||
| #include <unordered_map> | |||||
| #include <vector> | |||||
| #include "src/common/hash_ct.h" | |||||
| #include "src/cuda/cutlass/manifest.h" | |||||
| #include "src/cuda/cutlass/util.h" | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| namespace cutlass { | |||||
| namespace library { | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| class Hash { | |||||
| public: | |||||
| Hash() : m_val(0) {} | |||||
| Hash& update(const void* ptr, size_t len) { | |||||
| m_val += megdnn::XXHash64CT::hash((const char*)ptr, len, 123456); | |||||
| return *this; | |||||
| } | |||||
| uint64_t digest() const { return m_val; } | |||||
| private: | |||||
| uint64_t m_val; | |||||
| }; | |||||
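| // Sketch of the intended use (values illustrative): chain update() over POD | |||||
| // fields and take the 64-bit digest, as the key hashers below do. Since the | |||||
| // per-field digests are combined by addition, field order does not change | |||||
| // the result. | |||||
| // | |||||
| //     int m = 128, n = 32; | |||||
| //     uint64_t h = Hash().update(&m, sizeof(m)).update(&n, sizeof(n)).digest(); | |||||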
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| // Data Structures for GemmOperationMap | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| struct GemmKey { | |||||
| NumericTypeID element_A; | |||||
| LayoutTypeID layout_A; | |||||
| NumericTypeID element_B; | |||||
| LayoutTypeID layout_B; | |||||
| NumericTypeID element_C; | |||||
| LayoutTypeID layout_C; | |||||
| int threadblock_shape_m; | |||||
| int threadblock_shape_n; | |||||
| int threadblock_shape_k; | |||||
| int warp_shape_m; | |||||
| int warp_shape_n; | |||||
| int warp_shape_k; | |||||
| int instruction_shape_m; | |||||
| int instruction_shape_n; | |||||
| int instruction_shape_k; | |||||
| int stages; | |||||
| SplitKMode split_k_mode; | |||||
| inline bool operator==(GemmKey const& rhs) const { | |||||
| return (element_A == rhs.element_A) && (layout_A == rhs.layout_A) && | |||||
| (element_B == rhs.element_B) && (layout_B == rhs.layout_B) && | |||||
| (element_C == rhs.element_C) && (layout_C == rhs.layout_C) && | |||||
| (threadblock_shape_m == rhs.threadblock_shape_m) && | |||||
| (threadblock_shape_n == rhs.threadblock_shape_n) && | |||||
| (threadblock_shape_k == rhs.threadblock_shape_k) && | |||||
| (warp_shape_m == rhs.warp_shape_m) && | |||||
| (warp_shape_n == rhs.warp_shape_n) && | |||||
| (warp_shape_k == rhs.warp_shape_k) && | |||||
| (instruction_shape_m == rhs.instruction_shape_m) && | |||||
| (instruction_shape_n == rhs.instruction_shape_n) && | |||||
| (instruction_shape_k == rhs.instruction_shape_k) && | |||||
| (stages == rhs.stages) && (split_k_mode == rhs.split_k_mode); | |||||
| } | |||||
| inline bool operator!=(GemmKey const& rhs) const { return !(*this == rhs); } | |||||
| inline std::string str() const { | |||||
| auto tuple_to_str = [](int m, int n, int k) -> std::string { | |||||
| return std::to_string(m) + " x " + std::to_string(n) + " x " + | |||||
| std::to_string(k); | |||||
| }; | |||||
| std::string threadblock_shape_str = tuple_to_str( | |||||
| threadblock_shape_m, threadblock_shape_n, threadblock_shape_k); | |||||
| std::string warp_shape_str = | |||||
| tuple_to_str(warp_shape_m, warp_shape_n, warp_shape_k); | |||||
| std::string instruction_shape_str = tuple_to_str( | |||||
| instruction_shape_m, instruction_shape_n, instruction_shape_k); | |||||
| return std::string("{") + "\n element_A: " + to_string(element_A) + | |||||
| "\n layout_A: " + to_string(layout_A) + | |||||
| "\n element_B: " + to_string(element_B) + | |||||
| "\n layout_B: " + to_string(layout_B) + | |||||
| "\n element_C: " + to_string(element_C) + | |||||
| "\n layout_C: " + to_string(layout_C) + | |||||
| "\n threadblock_shape: " + threadblock_shape_str + | |||||
| "\n warp_shape: " + warp_shape_str + | |||||
| "\n instruction_shape: " + instruction_shape_str + | |||||
| "\n stages: " + std::to_string(stages) + | |||||
| "\n split_k_mode: " + to_string(split_k_mode) + "\n}"; | |||||
| } | |||||
| }; | |||||
| struct GemmKeyHasher { | |||||
| inline size_t operator()(GemmKey const& key) const { | |||||
| return Hash() | |||||
| .update(&key.element_A, sizeof(key.element_A)) | |||||
| .update(&key.layout_A, sizeof(key.layout_A)) | |||||
| .update(&key.element_B, sizeof(key.element_B)) | |||||
| .update(&key.layout_B, sizeof(key.layout_B)) | |||||
| .update(&key.element_C, sizeof(key.element_C)) | |||||
| .update(&key.layout_C, sizeof(key.layout_C)) | |||||
| .update(&key.threadblock_shape_m, | |||||
| sizeof(key.threadblock_shape_m)) | |||||
| .update(&key.threadblock_shape_n, | |||||
| sizeof(key.threadblock_shape_n)) | |||||
| .update(&key.threadblock_shape_k, | |||||
| sizeof(key.threadblock_shape_k)) | |||||
| .update(&key.warp_shape_m, sizeof(key.warp_shape_m)) | |||||
| .update(&key.warp_shape_n, sizeof(key.warp_shape_n)) | |||||
| .update(&key.warp_shape_k, sizeof(key.warp_shape_k)) | |||||
| .update(&key.instruction_shape_m, | |||||
| sizeof(key.instruction_shape_m)) | |||||
| .update(&key.instruction_shape_n, | |||||
| sizeof(key.instruction_shape_n)) | |||||
| .update(&key.instruction_shape_k, | |||||
| sizeof(key.instruction_shape_k)) | |||||
| .update(&key.stages, sizeof(key.stages)) | |||||
| .update(&key.split_k_mode, sizeof(key.split_k_mode)) | |||||
| .digest(); | |||||
| } | |||||
| }; | |||||
| using GemmOperationMap = | |||||
| std::unordered_map<GemmKey, std::vector<Operation const*>, | |||||
| GemmKeyHasher>; | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| // Data Structures for ConvolutionOperationMap | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| struct ConvolutionKey { | |||||
| conv::Operator conv_op; | |||||
| library::NumericTypeID element_src; | |||||
| library::LayoutTypeID layout_src; | |||||
| library::NumericTypeID element_filter; | |||||
| library::LayoutTypeID layout_filter; | |||||
| library::NumericTypeID element_dst; | |||||
| library::LayoutTypeID layout_dst; | |||||
| library::NumericTypeID element_bias; | |||||
| library::LayoutTypeID layout_bias; | |||||
| conv::ConvType convolution_type; | |||||
| int threadblock_shape_m; | |||||
| int threadblock_shape_n; | |||||
| int threadblock_shape_k; | |||||
| int warp_shape_m; | |||||
| int warp_shape_n; | |||||
| int warp_shape_k; | |||||
| int instruction_shape_m; | |||||
| int instruction_shape_n; | |||||
| int instruction_shape_k; | |||||
| epilogue::EpilogueType epilogue_type; | |||||
| int stages; | |||||
| bool need_load_from_const_mem; | |||||
| bool without_shared_load; | |||||
| inline bool operator==(ConvolutionKey const& rhs) const { | |||||
| return (conv_op == rhs.conv_op) && (element_src == rhs.element_src) && | |||||
| (layout_src == rhs.layout_src) && | |||||
| (element_filter == rhs.element_filter) && | |||||
| (layout_filter == rhs.layout_filter) && | |||||
| (element_dst == rhs.element_dst) && | |||||
| (layout_dst == rhs.layout_dst) && | |||||
| (element_bias == rhs.element_bias) && | |||||
| (layout_bias == rhs.layout_bias) && | |||||
| (convolution_type == rhs.convolution_type) && | |||||
| (threadblock_shape_m == rhs.threadblock_shape_m) && | |||||
| (threadblock_shape_n == rhs.threadblock_shape_n) && | |||||
| (threadblock_shape_k == rhs.threadblock_shape_k) && | |||||
| (warp_shape_m == rhs.warp_shape_m) && | |||||
| (warp_shape_n == rhs.warp_shape_n) && | |||||
| (warp_shape_k == rhs.warp_shape_k) && | |||||
| (instruction_shape_m == rhs.instruction_shape_m) && | |||||
| (instruction_shape_n == rhs.instruction_shape_n) && | |||||
| (instruction_shape_k == rhs.instruction_shape_k) && | |||||
| (epilogue_type == rhs.epilogue_type) && (stages == rhs.stages) && | |||||
| (need_load_from_const_mem == rhs.need_load_from_const_mem) && | |||||
| (without_shared_load == rhs.without_shared_load); | |||||
| } | |||||
| inline bool operator!=(ConvolutionKey const& rhs) const { | |||||
| return !(*this == rhs); | |||||
| } | |||||
| inline std::string str() const { | |||||
| auto tuple_to_str = [](int m, int n, int k) -> std::string { | |||||
| return std::to_string(m) + " x " + std::to_string(n) + " x " + | |||||
| std::to_string(k); | |||||
| }; | |||||
| std::string threadblock_shape_str = tuple_to_str( | |||||
| threadblock_shape_m, threadblock_shape_n, threadblock_shape_k); | |||||
| std::string warp_shape_str = | |||||
| tuple_to_str(warp_shape_m, warp_shape_n, warp_shape_k); | |||||
| std::string instruction_shape_str = tuple_to_str( | |||||
| instruction_shape_m, instruction_shape_n, instruction_shape_k); | |||||
| return std::string("{") + "\n conv_op: " + to_string(conv_op) + | |||||
| "\n element_src: " + to_string(element_src) + | |||||
| "\n layout_src: " + to_string(layout_src) + | |||||
| "\n element_filter: " + to_string(element_filter) + | |||||
| "\n layout_filter: " + to_string(layout_filter) + | |||||
| "\n element_dst: " + to_string(element_dst) + | |||||
| "\n layout_dst: " + to_string(layout_dst) + | |||||
| "\n element_bias: " + to_string(element_bias) + | |||||
| "\n layout_bias: " + to_string(layout_bias) + | |||||
| "\n convolution_type: " + to_string(convolution_type) + | |||||
| "\n threadblock_shape: " + threadblock_shape_str + | |||||
| "\n warp_shape: " + warp_shape_str + | |||||
| "\n instruction_shape: " + instruction_shape_str + | |||||
| "\n epilogue_type: " + to_string(epilogue_type) + | |||||
| "\n stages: " + std::to_string(stages) + | |||||
| "\n need_load_from_const_mem: " + | |||||
| to_string(need_load_from_const_mem) + | |||||
| "\n without_shared_load: " + to_string(without_shared_load) + | |||||
| "\n}"; | |||||
| } | |||||
| }; | |||||
| struct ConvolutionKeyHasher { | |||||
| inline size_t operator()(ConvolutionKey const& key) const { | |||||
| return Hash() | |||||
| .update(&key.conv_op, sizeof(key.conv_op)) | |||||
| .update(&key.element_src, sizeof(key.element_src)) | |||||
| .update(&key.layout_src, sizeof(key.layout_src)) | |||||
| .update(&key.element_filter, sizeof(key.element_filter)) | |||||
| .update(&key.layout_filter, sizeof(key.layout_filter)) | |||||
| .update(&key.element_dst, sizeof(key.element_dst)) | |||||
| .update(&key.layout_dst, sizeof(key.layout_dst)) | |||||
| .update(&key.element_bias, sizeof(key.element_bias)) | |||||
| .update(&key.layout_bias, sizeof(key.layout_bias)) | |||||
| .update(&key.convolution_type, sizeof(key.convolution_type)) | |||||
| .update(&key.threadblock_shape_m, | |||||
| sizeof(key.threadblock_shape_m)) | |||||
| .update(&key.threadblock_shape_n, | |||||
| sizeof(key.threadblock_shape_n)) | |||||
| .update(&key.threadblock_shape_k, | |||||
| sizeof(key.threadblock_shape_k)) | |||||
| .update(&key.warp_shape_m, sizeof(key.warp_shape_m)) | |||||
| .update(&key.warp_shape_n, sizeof(key.warp_shape_n)) | |||||
| .update(&key.warp_shape_k, sizeof(key.warp_shape_k)) | |||||
| .update(&key.instruction_shape_m, | |||||
| sizeof(key.instruction_shape_m)) | |||||
| .update(&key.instruction_shape_n, | |||||
| sizeof(key.instruction_shape_n)) | |||||
| .update(&key.instruction_shape_k, | |||||
| sizeof(key.instruction_shape_k)) | |||||
| .update(&key.epilogue_type, sizeof(key.epilogue_type)) | |||||
| .update(&key.stages, sizeof(key.stages)) | |||||
| .update(&key.need_load_from_const_mem, | |||||
| sizeof(key.need_load_from_const_mem)) | |||||
| .update(&key.without_shared_load, | |||||
| sizeof(key.without_shared_load)) | |||||
| .digest(); | |||||
| } | |||||
| }; | |||||
| using ConvolutionOperationMap = | |||||
| std::unordered_map<ConvolutionKey, std::vector<Operation const*>, | |||||
| ConvolutionKeyHasher>; | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| /// Table of cutlass::library::Operation instances | |||||
| class OperationTable { | |||||
| public: | |||||
| /// Map of all operations of type kGemm | |||||
| GemmOperationMap gemm_operations; | |||||
| /// Map of all operations of type kConvolution | |||||
| ConvolutionOperationMap convolution_operations; | |||||
| public: | |||||
| void append(Manifest const& manifest); | |||||
| Operation const* find_op(GemmKey const& key) const; | |||||
| Operation const* find_op(ConvolutionKey const& key) const; | |||||
| }; | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| } // namespace library | |||||
| } // namespace cutlass | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
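| // End-to-end sketch of how the table is populated and queried (key fields | |||||
| // elided; the concrete keys are built by the algo implementations): | |||||
| // | |||||
| //     OperationTable table; | |||||
| //     table.append(manifest);  // bucket ops under Gemm/Convolution keys | |||||
| //     GemmKey key{/* element, layout, tile and split-k fields from above */}; | |||||
| //     Operation const* op = table.find_op(key);  // asserts exactly one match | |||||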
| @@ -0,0 +1,72 @@ | |||||
| /*************************************************************************************************** | |||||
| * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | |||||
| * | |||||
| * Redistribution and use in source and binary forms, with or without | |||||
| *modification, are permitted provided that the following conditions are met: | |||||
| * * Redistributions of source code must retain the above copyright notice, | |||||
| *this list of conditions and the following disclaimer. | |||||
| * * Redistributions in binary form must reproduce the above copyright | |||||
| *notice, this list of conditions and the following disclaimer in the | |||||
| *documentation and/or other materials provided with the distribution. | |||||
| * * Neither the name of the NVIDIA CORPORATION nor the names of its | |||||
| *contributors may be used to endorse or promote products derived from this | |||||
| *software without specific prior written permission. | |||||
| * | |||||
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
| *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
| *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |||||
| *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, | |||||
| *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, | |||||
| * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | |||||
| *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY | |||||
| *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING | |||||
| *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, | |||||
| *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||||
| * | |||||
| **************************************************************************************************/ | |||||
| /** | |||||
| * \file dnn/src/cuda/cutlass/singleton.cu | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include <memory> | |||||
| #include "src/cuda/cutlass/singleton.h" | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| namespace cutlass { | |||||
| namespace library { | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| static std::unique_ptr<Singleton> instance; | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| Singleton::Singleton() { | |||||
| manifest.initialize(); | |||||
| operation_table.append(manifest); | |||||
| } | |||||
| Singleton const& Singleton::get() { | |||||
| if (!instance) { | |||||
| instance.reset(new Singleton); | |||||
| } | |||||
| return *instance; | |||||
| } | |||||
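| // Note: the lazy initialization above is unguarded, so the first call to | |||||
| // get() must not race with concurrent callers. If thread-safe construction | |||||
| // is ever required, a C++11 function-local static is one possible shape: | |||||
| // | |||||
| //     Singleton const& Singleton::get() { | |||||
| //         static Singleton inst;  // initialized exactly once, thread-safely | |||||
| //         return inst; | |||||
| //     } | |||||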
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| } // namespace library | |||||
| } // namespace cutlass | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| @@ -0,0 +1,70 @@ | |||||
| /*************************************************************************************************** | |||||
| * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | |||||
| * | |||||
| * Redistribution and use in source and binary forms, with or without | |||||
| *modification, are permitted provided that the following conditions are met: | |||||
| * * Redistributions of source code must retain the above copyright notice, | |||||
| *this list of conditions and the following disclaimer. | |||||
| * * Redistributions in binary form must reproduce the above copyright | |||||
| *notice, this list of conditions and the following disclaimer in the | |||||
| *documentation and/or other materials provided with the distribution. | |||||
| * * Neither the name of the NVIDIA CORPORATION nor the names of its | |||||
| *contributors may be used to endorse or promote products derived from this | |||||
| *software without specific prior written permission. | |||||
| * | |||||
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
| *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
| *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |||||
| *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, | |||||
| *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, | |||||
| * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | |||||
| *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY | |||||
| *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING | |||||
| *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, | |||||
| *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||||
| * | |||||
| **************************************************************************************************/ | |||||
| /** | |||||
| * \file dnn/src/cuda/cutlass/singleton.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include "src/cuda/cutlass/operation_table.h" | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| namespace cutlass { | |||||
| namespace library { | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| /// Singleton instance stores a Manifest and Operation table | |||||
| class Singleton { | |||||
| public: | |||||
| /// Manifest object | |||||
| Manifest manifest; | |||||
| /// Operation table referencing the Manifest | |||||
| OperationTable operation_table; | |||||
| public: | |||||
| Singleton(); | |||||
| static Singleton const& get(); | |||||
| }; | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| } // namespace library | |||||
| } // namespace cutlass | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
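| // Typical consumer flow (sketch; key construction elided, see the GemmKey | |||||
| // built in src/cuda/matrix_mul/algos for a concrete example): | |||||
| // | |||||
| //     Operation const* op = | |||||
| //             Singleton::get().operation_table.find_op(key); | |||||
| //     cutlass_check(op->run(&gemm_args, workspace, stream)); | |||||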
| @@ -0,0 +1,218 @@ | |||||
| /*************************************************************************************************** | |||||
| * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | |||||
| * | |||||
| * Redistribution and use in source and binary forms, with or without | |||||
| *modification, are permitted provided that the following conditions are met: | |||||
| * * Redistributions of source code must retain the above copyright notice, | |||||
| *this list of conditions and the following disclaimer. | |||||
| * * Redistributions in binary form must reproduce the above copyright | |||||
| *notice, this list of conditions and the following disclaimer in the | |||||
| *documentation and/or other materials provided with the distribution. | |||||
| * * Neither the name of the NVIDIA CORPORATION nor the names of its | |||||
| *contributors may be used to endorse or promote products derived from this | |||||
| *software without specific prior written permission. | |||||
| * | |||||
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
| *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
| *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |||||
| *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, | |||||
| *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, | |||||
| * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | |||||
| *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY | |||||
| *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING | |||||
| *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, | |||||
| *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||||
| * | |||||
| **************************************************************************************************/ | |||||
| /** | |||||
| * \file dnn/src/cuda/cutlass/util.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include "src/cuda/cutlass/library.h" | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| namespace cutlass { | |||||
| namespace library { | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| /// Lexical cast from string | |||||
| template <typename T> | |||||
| T from_string(std::string const&); | |||||
| /// Converts a Provider enumerant to a string | |||||
| char const* to_string(Provider provider, bool pretty = false); | |||||
| /// Parses a Provider enumerant from a string | |||||
| template <> | |||||
| Provider from_string<Provider>(std::string const& str); | |||||
| /// Converts a GemmKind enumerant to a string | |||||
| char const* to_string(GemmKind type, bool pretty = false); | |||||
| /// Converts an OperationKind enumerant to a string | |||||
| char const* to_string(OperationKind type, bool pretty = false); | |||||
| /// Parses an OperationKind enumerant from a string | |||||
| template <> | |||||
| OperationKind from_string<OperationKind>(std::string const& str); | |||||
| /// Converts a NumericType enumerant to a string | |||||
| char const* to_string(NumericTypeID type, bool pretty = false); | |||||
| /// Parses a NumericType enumerant from a string | |||||
| template <> | |||||
| NumericTypeID from_string<NumericTypeID>(std::string const& str); | |||||
| /// Returns the size of a data type in bits | |||||
| int sizeof_bits(NumericTypeID type); | |||||
| /// Returns true if the numeric type is a complex data type or false if | |||||
| /// real-valued. | |||||
| bool is_complex_type(NumericTypeID type); | |||||
| /// Returns the real-valued type underlying a type (only different from 'type' | |||||
| /// if complex) | |||||
| NumericTypeID get_real_type(NumericTypeID type); | |||||
| /// Returns true if numeric type is integer | |||||
| bool is_integer_type(NumericTypeID type); | |||||
| /// Returns true if numeric type is signed | |||||
| bool is_signed_type(NumericTypeID type); | |||||
| /// Returns true if numeric type is a signed integer | |||||
| bool is_signed_integer(NumericTypeID type); | |||||
| /// Returns true if numeric type is an unsigned integer | |||||
| bool is_unsigned_integer(NumericTypeID type); | |||||
| /// Returns true if numeric type is floating-point type | |||||
| bool is_float_type(NumericTypeID type); | |||||
| /// To string method for cutlass::Status | |||||
| char const* to_string(Status status, bool pretty = false); | |||||
| /// Converts a LayoutTypeID enumerant to a string | |||||
| char const* to_string(LayoutTypeID layout, bool pretty = false); | |||||
| /// Parses a LayoutType enumerant from a string | |||||
| template <> | |||||
| LayoutTypeID from_string<LayoutTypeID>(std::string const& str); | |||||
| /// Returns the rank of a layout's stride based on the LayoutTypeID | |||||
| int get_layout_stride_rank(LayoutTypeID layout_id); | |||||
| /// Converts an OpcodeClassID enumerant to a string | |||||
| char const* to_string(OpcodeClassID type, bool pretty = false); | |||||
| /// Parses an OpcodeClassID enumerant from a string | |||||
| template <> | |||||
| OpcodeClassID from_string<OpcodeClassID>(std::string const& str); | |||||
| /// Converts a ComplexTransform enumerant to a string | |||||
| char const* to_string(ComplexTransform type, bool pretty = false); | |||||
| /// Parses a ComplexTransform enumerant from a string | |||||
| template <> | |||||
| ComplexTransform from_string<ComplexTransform>(std::string const& str); | |||||
| /// Converts a SplitKMode enumerant to a string | |||||
| char const* to_string(SplitKMode split_k_mode, bool pretty = false); | |||||
| /// Parses a SplitKMode enumerant from a string | |||||
| template <> | |||||
| SplitKMode from_string<SplitKMode>(std::string const& str); | |||||
| /// Converts a ConvModeID enumerant to a string | |||||
| char const* to_string(ConvModeID type, bool pretty = false); | |||||
| /// Parses a ConvModeID enumerant from a string | |||||
| template <> | |||||
| ConvModeID from_string<ConvModeID>(std::string const& str); | |||||
| /// Converts an IteratorAlgorithmID enumerant to a string | |||||
| char const* to_string(IteratorAlgorithmID type, bool pretty = false); | |||||
| /// Parses an IteratorAlgorithmID enumerant from a string | |||||
| template <> | |||||
| IteratorAlgorithmID from_string<IteratorAlgorithmID>(std::string const& str); | |||||
| /// Converts a ConvKind enumerant to a string | |||||
| char const* to_string(ConvKind type, bool pretty = false); | |||||
| /// Parses a ConvKind enumerant from a string | |||||
| template <> | |||||
| ConvKind from_string<ConvKind>(std::string const& str); | |||||
| /// Lexical cast from int64_t to string | |||||
| std::string lexical_cast(int64_t int_value); | |||||
| /// Lexical cast a string to a byte array. Returns true if cast is successful or | |||||
| /// false if invalid. | |||||
| bool lexical_cast(std::vector<uint8_t>& bytes, NumericTypeID type, | |||||
| std::string const& str); | |||||
| /// Lexical cast TO a string FROM a byte array. | |||||
| std::string lexical_cast(std::vector<uint8_t>& bytes, NumericTypeID type); | |||||
| /// Casts from a signed int64 to the destination type. Returns true if | |||||
| /// successful. | |||||
| bool cast_from_int64(std::vector<uint8_t>& bytes, NumericTypeID type, | |||||
| int64_t src); | |||||
| /// Casts from an unsigned int64 to the destination type. Returns true if | |||||
| /// successful. | |||||
| bool cast_from_uint64(std::vector<uint8_t>& bytes, NumericTypeID type, | |||||
| uint64_t src); | |||||
| /// Casts from a real value represented as a double to the destination type. | |||||
| /// Returns true if successful. | |||||
| bool cast_from_double(std::vector<uint8_t>& bytes, NumericTypeID type, | |||||
| double src); | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| /// Converts a conv::Operator enumerant to a string | |||||
| char const* to_string(conv::Operator conv_op, bool pretty = false); | |||||
| /// Converts a ConvType enumerant to a string | |||||
| char const* to_string(conv::ConvType type, bool pretty = false); | |||||
| /// Converts an ArchTagID enumerant to a string | |||||
| char const* to_string(ArchTagID tag, bool pretty = false); | |||||
| /// Converts an EpilogueType enumerant to a string | |||||
| char const* to_string(epilogue::EpilogueType type, bool pretty = false); | |||||
| /// Converts a ThreadblockSwizzleID enumerant to a string | |||||
| char const* to_string(ThreadblockSwizzleID threadblock_swizzle, | |||||
| bool pretty = false); | |||||
| /// Converts a bool value to a string | |||||
| char const* to_string(bool val, bool pretty = false); | |||||
| /// Converts a MathOperationID enumerant to a string | |||||
| char const* to_string(MathOperationID math_op, bool pretty = false); | |||||
| /// Converts an ImplicitGemmMode enumerant to a string | |||||
| char const* to_string(conv::ImplicitGemmMode mode, bool pretty = false); | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
| } // namespace library | |||||
| } // namespace cutlass | |||||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | |||||
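| // Round-trip sketch for the conversion helpers above (the literal "f32" is | |||||
| // an assumed spelling of the NumericTypeID enumerant): | |||||
| // | |||||
| //     NumericTypeID id = from_string<NumericTypeID>("f32"); | |||||
| //     char const* terse = to_string(id);         // e.g. "f32" | |||||
| //     char const* pretty = to_string(id, true);  // human-readable variant | |||||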
| @@ -10,15 +10,14 @@ | |||||
| * implied. | * implied. | ||||
| */ | */ | ||||
| #include "src/cuda/cutlass/singleton.h" | |||||
| #include "src/cuda/handle.h" | #include "src/cuda/handle.h" | ||||
| #include "src/cuda/matrix_mul/algos.h" | #include "src/cuda/matrix_mul/algos.h" | ||||
| #include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh" | |||||
| #include "src/cuda/utils.h" | #include "src/cuda/utils.h" | ||||
| #if CUDA_VERSION >= 9020 | #if CUDA_VERSION >= 9020 | ||||
| using namespace megdnn; | using namespace megdnn; | ||||
| using namespace cuda; | using namespace cuda; | ||||
| using namespace cutlass_wrapper; | |||||
| bool MatrixMulForwardImpl::AlgoFloat32SIMT::is_available( | bool MatrixMulForwardImpl::AlgoFloat32SIMT::is_available( | ||||
| const SizeArgs& args) const { | const SizeArgs& args) const { | ||||
| @@ -44,25 +43,62 @@ size_t MatrixMulForwardImpl::AlgoFloat32SIMT::get_workspace_in_bytes( | |||||
| } | } | ||||
| void MatrixMulForwardImpl::AlgoFloat32SIMT::exec(const ExecArgs& args) const { | void MatrixMulForwardImpl::AlgoFloat32SIMT::exec(const ExecArgs& args) const { | ||||
| size_t lda = args.tensor_a.layout.stride[0], | |||||
| ldb = args.tensor_b.layout.stride[0], | |||||
| ldc = args.tensor_c.layout.stride[0]; | |||||
| int64_t lda = args.tensor_a.layout.stride[0], | |||||
| ldb = args.tensor_b.layout.stride[0], | |||||
| ldc = args.tensor_c.layout.stride[0]; | |||||
| auto&& param = args.opr->param(); | auto&& param = args.opr->param(); | ||||
| int m = args.tensor_c.layout.shape[0], n = args.tensor_c.layout.shape[1], | int m = args.tensor_c.layout.shape[0], n = args.tensor_c.layout.shape[1], | ||||
| k = args.tensor_a.layout.shape[param.transposeA ? 0 : 1]; | k = args.tensor_a.layout.shape[param.transposeA ? 0 : 1]; | ||||
| GemmCoord problem_size{m, n, k}; | |||||
| cutlass::gemm::GemmCoord problem_size{m, n, k}; | |||||
| auto&& stream = cuda_stream(args.opr->handle()); | auto&& stream = cuda_stream(args.opr->handle()); | ||||
| int* workspace = reinterpret_cast<int*>(args.workspace.raw_ptr); | int* workspace = reinterpret_cast<int*>(args.workspace.raw_ptr); | ||||
| return cutlass_matrix_mul_float32_simt( | |||||
| args.tensor_a.ptr<dt_float32>(), param.transposeA, lda, | |||||
| args.tensor_b.ptr<dt_float32>(), param.transposeB, ldb, | |||||
| args.tensor_c.ptr<dt_float32>(), ldc, workspace, problem_size, 1.f, | |||||
| 0.f, | |||||
| GemmCoord{m_algo_param.threadblock_m, m_algo_param.threadblock_n, | |||||
| m_algo_param.threadblock_k}, | |||||
| GemmCoord{m_algo_param.warp_m, m_algo_param.warp_n, | |||||
| m_algo_param.warp_k}, | |||||
| stream); | |||||
| // \note These epilogue constants are passed to the `GemmArguments` struct | |||||
| // by pointer and reinterpreted as ElementCompute*; using a different dtype | |||||
| // here results in undefined epilogue behavior. | |||||
| float alpha = 1.f, beta = 0.f; | |||||
| using namespace cutlass::library; | |||||
| auto layoutA = param.transposeA ? LayoutTypeID::kColumnMajor | |||||
| : LayoutTypeID::kRowMajor; | |||||
| auto layoutB = param.transposeB ? LayoutTypeID::kColumnMajor | |||||
| : LayoutTypeID::kRowMajor; | |||||
| GemmKey key{NumericTypeID::kF32, | |||||
| layoutA, | |||||
| NumericTypeID::kF32, | |||||
| layoutB, | |||||
| NumericTypeID::kF32, | |||||
| LayoutTypeID::kRowMajor, | |||||
| m_algo_param.threadblock_m, | |||||
| m_algo_param.threadblock_n, | |||||
| m_algo_param.threadblock_k, | |||||
| m_algo_param.warp_m, | |||||
| m_algo_param.warp_n, | |||||
| m_algo_param.warp_k, | |||||
| 1, | |||||
| 1, | |||||
| 1, | |||||
| 2, | |||||
| SplitKMode::kNone}; | |||||
| Operation const* op = Singleton::get().operation_table.find_op(key); | |||||
| GemmArguments gemm_args{problem_size, | |||||
| args.tensor_a.raw_ptr, | |||||
| args.tensor_b.raw_ptr, | |||||
| args.tensor_c.raw_ptr, | |||||
| args.tensor_c.raw_ptr, | |||||
| lda, | |||||
| ldb, | |||||
| ldc, | |||||
| ldc, | |||||
| 1, | |||||
| &alpha, | |||||
| &beta}; | |||||
| cutlass_check(op->run(&gemm_args, workspace, stream)); | |||||
| } | } | ||||
| #endif | #endif | ||||
| @@ -10,15 +10,14 @@ | |||||
| * implied. | * implied. | ||||
| */ | */ | ||||
| #include "src/cuda/cutlass/singleton.h" | |||||
| #include "src/cuda/handle.h" | #include "src/cuda/handle.h" | ||||
| #include "src/cuda/matrix_mul/algos.h" | #include "src/cuda/matrix_mul/algos.h" | ||||
| #include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh" | |||||
| #include "src/cuda/utils.h" | #include "src/cuda/utils.h" | ||||
| #if CUDA_VERSION >= 9020 | #if CUDA_VERSION >= 9020 | ||||
| using namespace megdnn; | using namespace megdnn; | ||||
| using namespace cuda; | using namespace cuda; | ||||
| using namespace cutlass_wrapper; | |||||
| bool MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::is_available( | bool MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::is_available( | ||||
| const SizeArgs& args) const { | const SizeArgs& args) const { | ||||
| @@ -50,26 +49,63 @@ size_t MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::get_workspace_in_bytes( | |||||
| void MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::exec( | void MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::exec( | ||||
| const ExecArgs& args) const { | const ExecArgs& args) const { | ||||
| size_t lda = args.tensor_a.layout.stride[0], | |||||
| ldb = args.tensor_b.layout.stride[0], | |||||
| ldc = args.tensor_c.layout.stride[0]; | |||||
| int64_t lda = args.tensor_a.layout.stride[0], | |||||
| ldb = args.tensor_b.layout.stride[0], | |||||
| ldc = args.tensor_c.layout.stride[0]; | |||||
| auto&& param = args.opr->param(); | auto&& param = args.opr->param(); | ||||
| int m = args.tensor_c.layout.shape[0], n = args.tensor_c.layout.shape[1], | int m = args.tensor_c.layout.shape[0], n = args.tensor_c.layout.shape[1], | ||||
| k = args.tensor_a.layout.shape[param.transposeA ? 0 : 1]; | k = args.tensor_a.layout.shape[param.transposeA ? 0 : 1]; | ||||
| GemmCoord problem_size{m, n, k}; | |||||
| cutlass::gemm::GemmCoord problem_size{m, n, k}; | |||||
| int split_k_slices = std::max(1, k / n); | int split_k_slices = std::max(1, k / n); | ||||
| auto&& stream = cuda_stream(args.opr->handle()); | auto&& stream = cuda_stream(args.opr->handle()); | ||||
| int* workspace = reinterpret_cast<int*>(args.workspace.raw_ptr); | int* workspace = reinterpret_cast<int*>(args.workspace.raw_ptr); | ||||
| return cutlass_matrix_mul_float32_simt( | |||||
| args.tensor_a.ptr<dt_float32>(), param.transposeA, lda, | |||||
| args.tensor_b.ptr<dt_float32>(), param.transposeB, ldb, | |||||
| args.tensor_c.ptr<dt_float32>(), ldc, workspace, problem_size, 1.f, | |||||
| 0.f, | |||||
| GemmCoord{m_algo_param.threadblock_m, m_algo_param.threadblock_n, | |||||
| m_algo_param.threadblock_k}, | |||||
| GemmCoord{m_algo_param.warp_m, m_algo_param.warp_n, | |||||
| m_algo_param.warp_k}, | |||||
| stream, split_k_slices); | |||||
| // \note These epilogue constants are passed to the `GemmArguments` struct | |||||
| // by pointer and reinterpreted as ElementCompute*; using a different dtype | |||||
| // here results in undefined epilogue behavior. | |||||
| float alpha = 1.f, beta = 0.f; | |||||
| using namespace cutlass::library; | |||||
| auto layoutA = param.transposeA ? LayoutTypeID::kColumnMajor | |||||
| : LayoutTypeID::kRowMajor; | |||||
| auto layoutB = param.transposeB ? LayoutTypeID::kColumnMajor | |||||
| : LayoutTypeID::kRowMajor; | |||||
| GemmKey key{NumericTypeID::kF32, | |||||
| layoutA, | |||||
| NumericTypeID::kF32, | |||||
| layoutB, | |||||
| NumericTypeID::kF32, | |||||
| LayoutTypeID::kRowMajor, | |||||
| m_algo_param.threadblock_m, | |||||
| m_algo_param.threadblock_n, | |||||
| m_algo_param.threadblock_k, | |||||
| m_algo_param.warp_m, | |||||
| m_algo_param.warp_n, | |||||
| m_algo_param.warp_k, | |||||
| 1, | |||||
| 1, | |||||
| 1, | |||||
| 2, | |||||
| SplitKMode::kParallel}; | |||||
| Operation const* op = Singleton::get().operation_table.find_op(key); | |||||
| GemmArguments gemm_args{problem_size, | |||||
| args.tensor_a.raw_ptr, | |||||
| args.tensor_b.raw_ptr, | |||||
| args.tensor_c.raw_ptr, | |||||
| args.tensor_c.raw_ptr, | |||||
| lda, | |||||
| ldb, | |||||
| ldc, | |||||
| ldc, | |||||
| split_k_slices, | |||||
| &alpha, | |||||
| &beta}; | |||||
| cutlass_check(op->run(&gemm_args, workspace, stream)); | |||||
| } | } | ||||
| #endif | #endif | ||||
| @@ -1,157 +0,0 @@ | |||||
| /** | |||||
| * \file dnn/src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cu | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| // ignore warning of cutlass | |||||
| #include "cuda.h" | |||||
| #if __CUDACC_VER_MAJOR__ > 9 || \ | |||||
| (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "cutlass/gemm/device/gemm.h" | |||||
| #include "cutlass/gemm/device/gemm_splitk_parallel.h" | |||||
| #include "cutlass/gemm/kernel/default_gemv.h" | |||||
| #include "src/common/opr_param_defs_enumv.cuh" | |||||
| #include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh" | |||||
| #pragma GCC diagnostic pop | |||||
| using namespace megdnn; | |||||
| using namespace cuda; | |||||
| using namespace cutlass_wrapper; | |||||
| /* ================= cutlass kernel wrapper for f32 matrix mul ================ | |||||
| */ | |||||
| #define DISPATCH(cb) \ | |||||
| cb(64, 256, 8, 32, 64, 8); \ | |||||
| cb(256, 64, 8, 64, 32, 8); \ | |||||
| cb(32, 256, 8, 16, 64, 8); \ | |||||
| cb(256, 32, 8, 64, 16, 8); \ | |||||
| cb(128, 128, 8, 32, 64, 8); \ | |||||
| cb(128, 64, 8, 64, 32, 8); \ | |||||
| cb(64, 128, 8, 32, 64, 8); \ | |||||
| cb(128, 32, 8, 64, 32, 8); \ | |||||
| cb(32, 128, 8, 32, 64, 8); \ | |||||
| cb(64, 64, 8, 32, 64, 8); \ | |||||
| cb(32, 64, 8, 32, 64, 8); \ | |||||
| cb(64, 32, 8, 64, 32, 8); \ | |||||
| cb(32, 32, 8, 32, 32, 8); \ | |||||
| cb(8, 32, 8, 8, 32, 8); \ | |||||
| cb(16, 32, 8, 16, 32, 8); \ | |||||
| cb(16, 64, 8, 16, 64, 8); \ | |||||
| cb(16, 128, 8, 16, 64, 8); \ | |||||
| megdnn_assert(false, \ | |||||
| "unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||||
| "(%dx%dx%d)", \ | |||||
| threadblock_shape.m(), threadblock_shape.n(), \ | |||||
| threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||||
| warp_shape.k()); | |||||
| void megdnn::cuda::cutlass_wrapper::cutlass_matrix_mul_float32_simt( | |||||
| const float* d_A, bool transpose_A, size_t lda, const float* d_B, | |||||
| bool transpose_B, size_t ldb, float* d_C, size_t ldc, int* workspace, | |||||
| GemmCoord const& problem_size, float alpha, float beta, | |||||
| const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | |||||
| cudaStream_t stream, int split_k_slices) { | |||||
| static constexpr int kEpilogueElementsPerAccess = 1; | |||||
| using EpilogueOp = cutlass::epilogue::thread::LinearCombination< | |||||
| float, kEpilogueElementsPerAccess, float, float>; | |||||
| typename EpilogueOp::Params epilogue{alpha, beta}; | |||||
| if (split_k_slices == 1) { | |||||
| #define cb(threadblock_m_, threadblock_n_, threadblock_k_, warp_m_, warp_n_, \ | |||||
| warp_k_) \ | |||||
| if (threadblock_shape.m() == threadblock_m_ && \ | |||||
| threadblock_shape.n() == threadblock_n_ && \ | |||||
| threadblock_shape.k() == threadblock_k_ && \ | |||||
| warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||||
| warp_shape.k() == warp_k_) { \ | |||||
| using ThreadBlockShape = \ | |||||
| cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||||
| threadblock_k_>; \ | |||||
| using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||||
| using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; \ | |||||
| using Gemm = cutlass::gemm::device::Gemm< \ | |||||
| float, LayoutA, float, LayoutB, float, \ | |||||
| cutlass::layout::RowMajor, float, cutlass::arch::OpClassSimt, \ | |||||
| cutlass::arch::Sm50, ThreadBlockShape, WarpShape, \ | |||||
| InstructionShape, EpilogueOp, \ | |||||
| cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, \ | |||||
| 2>; \ | |||||
| return cutlass_matrix_mul_wrapper<Gemm>(d_A, lda, d_B, ldb, d_C, ldc, \ | |||||
| workspace, problem_size, \ | |||||
| epilogue, stream); \ | |||||
| } | |||||
| if (!transpose_A && !transpose_B) { | |||||
| using LayoutA = cutlass::layout::RowMajor; | |||||
| using LayoutB = cutlass::layout::RowMajor; | |||||
| DISPATCH(cb) | |||||
| } else if (!transpose_A && transpose_B) { | |||||
| using LayoutA = cutlass::layout::RowMajor; | |||||
| using LayoutB = cutlass::layout::ColumnMajor; | |||||
| DISPATCH(cb) | |||||
| } else if (transpose_A && !transpose_B) { | |||||
| using LayoutA = cutlass::layout::ColumnMajor; | |||||
| using LayoutB = cutlass::layout::RowMajor; | |||||
| DISPATCH(cb) | |||||
| } else { | |||||
| megdnn_assert(transpose_A && transpose_B); | |||||
| using LayoutA = cutlass::layout::ColumnMajor; | |||||
| using LayoutB = cutlass::layout::ColumnMajor; | |||||
| DISPATCH(cb) | |||||
| } | |||||
| #undef cb | |||||
| } else { | |||||
| #define cb(threadblock_m_, threadblock_n_, threadblock_k_, warp_m_, warp_n_, \ | |||||
| warp_k_) \ | |||||
| if (threadblock_shape.m() == threadblock_m_ && \ | |||||
| threadblock_shape.n() == threadblock_n_ && \ | |||||
| threadblock_shape.k() == threadblock_k_ && \ | |||||
| warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||||
| warp_shape.k() == warp_k_) { \ | |||||
| using ThreadBlockShape = \ | |||||
| cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||||
| threadblock_k_>; \ | |||||
| using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||||
| using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; \ | |||||
| using Gemm = cutlass::gemm::device::GemmSplitKParallel< \ | |||||
| float, LayoutA, float, LayoutB, float, \ | |||||
| cutlass::layout::RowMajor, float, cutlass::arch::OpClassSimt, \ | |||||
| cutlass::arch::Sm50, ThreadBlockShape, WarpShape, \ | |||||
| InstructionShape, EpilogueOp>; \ | |||||
| return cutlass_matrix_mul_wrapper<Gemm>( \ | |||||
| d_A, lda, d_B, ldb, d_C, ldc, workspace, problem_size, \ | |||||
| epilogue, stream, split_k_slices); \ | |||||
| } | |||||
    if (!transpose_A && !transpose_B) {
        using LayoutA = cutlass::layout::RowMajor;
        using LayoutB = cutlass::layout::RowMajor;
        DISPATCH(cb)
    } else if (!transpose_A && transpose_B) {
        using LayoutA = cutlass::layout::RowMajor;
        using LayoutB = cutlass::layout::ColumnMajor;
        DISPATCH(cb)
    } else if (transpose_A && !transpose_B) {
        using LayoutA = cutlass::layout::ColumnMajor;
        using LayoutB = cutlass::layout::RowMajor;
        DISPATCH(cb)
    } else {
        megdnn_assert(transpose_A && transpose_B);
        using LayoutA = cutlass::layout::ColumnMajor;
        using LayoutB = cutlass::layout::ColumnMajor;
        DISPATCH(cb)
    }
#undef cb
    }
}
#undef DISPATCH
#endif

// vim: syntax=cuda.doxygen
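For reference, here is a minimal sketch of how the float32 SIMT dispatcher
above might be invoked. The buffer names, problem size, and tile shapes are
illustrative assumptions, not part of this change, and the tile shapes must
match one of the cb(...) entries in the DISPATCH table.

// Hypothetical caller: computes C = 1.0 * A * B for a 128x128x64 problem
// with both operands row-major. All names and sizes here are assumptions.
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh"

void run_example(const float* d_A, const float* d_B, float* d_C,
                 int* d_workspace, cudaStream_t stream) {
    using megdnn::cuda::cutlass_wrapper::GemmCoord;
    GemmCoord problem_size{128, 128, 64};    // m, n, k
    GemmCoord threadblock_shape{64, 64, 8};  // must match a DISPATCH entry
    GemmCoord warp_shape{32, 64, 8};
    megdnn::cuda::cutlass_wrapper::cutlass_matrix_mul_float32_simt(
            d_A, /* transpose_A */ false, /* lda */ 64,
            d_B, /* transpose_B */ false, /* ldb */ 128,
            d_C, /* ldc */ 128, d_workspace, problem_size,
            /* alpha */ 1.f, /* beta */ 0.f,
            threadblock_shape, warp_shape, stream, /* split_k_slices */ 1);
}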
@@ -21,22 +21,6 @@ namespace cutlass_wrapper {
using GemmCoord = cutlass::gemm::GemmCoord;
using BatchedGemmCoord = cutlass::gemm::BatchedGemmCoord;
template <typename Gemm>
void cutlass_matrix_mul_wrapper(
        const typename Gemm::ElementA* d_A, size_t lda,
        const typename Gemm::ElementB* d_B, size_t ldb,
        typename Gemm::ElementC* d_C, size_t ldc, int* workspace,
        GemmCoord const& problem_size,
        typename Gemm::EpilogueOutputOp::Params const& epilogue,
        cudaStream_t stream, int split_k_slices = 1);
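// Runtime entry point (implemented by the dispatcher shown above): it picks
// LayoutA/LayoutB from the transpose flags and a tile shape via DISPATCH(cb),
// then forwards to the templated wrapper.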
void cutlass_matrix_mul_float32_simt(
        const float* d_A, bool transpose_A, size_t lda, const float* d_B,
        bool transpose_B, size_t ldb, float* d_C, size_t ldc, int* workspace,
        GemmCoord const& problem_size, float alpha, float beta,
        const GemmCoord& threadblock_shape, const GemmCoord& warp_shape,
        cudaStream_t stream, int split_k_slices = 1);
template <typename GemvKernel>
void cutlass_vector_matrix_mul_batched_strided_wrapper(
        BatchedGemmCoord const& problem_size,
@@ -1,57 +0,0 @@
/**
 * \file
 * dnn/src/cuda/matrix_mul/matrix_mul_float_simt_cutlass_wrapper.cuinl
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
 * express or implied.
 */
| #include "cutlass/gemm/device/gemm.h" | |||||
| #include "cutlass/gemm/device/gemm_splitk_parallel.h" | |||||
| #include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh" | |||||
| using namespace megdnn; | |||||
| using namespace cuda; | |||||
| using namespace cutlass_wrapper; | |||||
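// By the .cuinl convention used in this tree, this file appears to be an
// inline implementation unit: it is #included by generated .cu files that
// explicitly instantiate cutlass_matrix_mul_wrapper for concrete Gemm types.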
template <typename Gemm>
void megdnn::cuda::cutlass_wrapper::cutlass_matrix_mul_wrapper(
        const typename Gemm::ElementA* d_A, size_t lda,
        const typename Gemm::ElementB* d_B, size_t ldb,
        typename Gemm::ElementC* d_C, size_t ldc, int* workspace,
        GemmCoord const& problem_size,
        typename Gemm::EpilogueOutputOp::Params const& epilogue,
        cudaStream_t stream, int split_k_slices) {
    using TensorRefA = cutlass::TensorRef<typename Gemm::ElementA const,
                                          typename Gemm::LayoutA>;
    using TensorRefB = cutlass::TensorRef<typename Gemm::ElementB const,
                                          typename Gemm::LayoutB>;
    using TensorRefC = cutlass::TensorRef<typename Gemm::ElementC const,
                                          typename Gemm::LayoutC>;
    using TensorRefD =
            cutlass::TensorRef<typename Gemm::ElementC, typename Gemm::LayoutC>;
    TensorRefA tensor_a{const_cast<typename Gemm::ElementA*>(d_A),
                        typename Gemm::LayoutA{static_cast<int>(lda)}};
    TensorRefB tensor_b{const_cast<typename Gemm::ElementB*>(d_B),
                        typename Gemm::LayoutB{static_cast<int>(ldb)}};
    TensorRefC tensor_c{nullptr, typename Gemm::LayoutC{static_cast<int>(ldc)}};
    TensorRefD tensor_d{d_C, typename Gemm::LayoutC{static_cast<int>(ldc)}};
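    // Note that tensor_c above is constructed with a null pointer: only its
    // layout (the ldc stride) is supplied for the source-C operand, and the
    // actual output is written through tensor_d, which wraps d_C.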
    typename Gemm::Arguments arguments{problem_size,
                                       tensor_a,
                                       tensor_b,
                                       tensor_c,
                                       tensor_d.non_const_ref(),
                                       epilogue,
                                       split_k_slices};
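    // The trailing Arguments field is the split-K factor; for the parallel
    // split-K kernel the workspace handed to initialize() is presumably
    // where the partial K-slice results live before the final reduction.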
    Gemm gemm_op;
    cutlass_check(gemm_op.initialize(arguments, workspace));
    cutlass_check(gemm_op(stream));
    after_kernel_launch();
}

// vim: syntax=cuda.doxygen