GitOrigin-RevId: 025c591f75
tags/v1.6.0-rc1
| @@ -5,6 +5,8 @@ genrule( | |||
| outs = cutlass_gen_list, | |||
| cmd = """GEN=$(location //brain/megbrain/dnn/scripts/cutlass_generator:generator.py) | |||
| CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations gemm --type simt $(@D) | |||
| CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations gemm --type tensorop884 $(@D) | |||
| CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations gemm --type tensorop1688 $(@D) | |||
| CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations gemv --type simt $(@D) | |||
| CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations deconv --type simt $(@D) | |||
| CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations conv2d --type simt $(@D) | |||
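Note: each generator invocation added to this genrule needs a matching write_op_list() entry in gen_list.py and a gen_cutlass_kimpl() line in CMake (both hunks appear further down); without them the tensor-op sources would be generated but never listed in cutlass_gen_list or compiled.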
| @@ -252,7 +252,8 @@ def GeneratesGemm(tile, data_type, layout_a, layout_b, layout_c, min_cc, align_a | |||
| if tile.math_instruction.element_accumulator == DataType.s32: | |||
| epilogues = [EpilogueFunctor.LinearCombinationClamp] | |||
| else: | |||
| assert tile.math_instruction.element_accumulator == DataType.f32 | |||
| assert tile.math_instruction.element_accumulator == DataType.f32 or \ | |||
| tile.math_instruction.element_accumulator == DataType.f16 | |||
| epilogues = [EpilogueFunctor.LinearCombination] | |||
| for epilogue in epilogues: | |||
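The relaxed assertion matters because the epilogue choice follows the accumulator type: integer (s32) accumulators need LinearCombinationClamp to saturate into the narrower output range, while floating-point accumulators, now including f16 for the pure-half kernels, use the plain LinearCombination epilogue.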
| @@ -799,7 +800,22 @@ class EmitGemmSplitKParallelInstance: | |||
| ${epilogue_vector_length}, | |||
| ${element_accumulator}, | |||
| ${element_epilogue} | |||
| > | |||
| >, | |||
| cutlass::epilogue::thread::Convert< | |||
| ${element_accumulator}, | |||
| ${epilogue_vector_length}, | |||
| ${element_accumulator} | |||
| >, | |||
| cutlass::reduction::thread::ReduceAdd< | |||
| ${element_accumulator}, | |||
| ${element_accumulator}, | |||
| ${epilogue_vector_length} | |||
| >, | |||
| cutlass::gemm::threadblock::GemmSplitKHorizontalThreadblockSwizzle, | |||
| ${stages}, | |||
| ${align_a}, | |||
| ${align_b}, | |||
| ${math_operation} | |||
| >; | |||
| """ | |||
| def emit(self, operation): | |||
| @@ -831,7 +847,10 @@ class EmitGemmSplitKParallelInstance: | |||
| 'epilogue_vector_length': str(epilogue_vector_length), | |||
| 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), | |||
| 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], | |||
| 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor], | |||
| 'stages': str(operation.tile_description.stages), | |||
| 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation], | |||
| 'align_a': str(operation.A.alignment), | |||
| 'align_b': str(operation.B.alignment), | |||
| } | |||
| return SubstituteTemplate(self.template, values) | |||
| @@ -32,6 +32,8 @@ if __name__ == "__main__": | |||
| f.write("# Generated by dnn/scripts/cutlass_generator/gen_list.py\n\n") | |||
| f.write("cutlass_gen_list = [\n") | |||
| write_op_list(f, "gemm", "simt") | |||
| write_op_list(f, "gemm", "tensorop1688") | |||
| write_op_list(f, "gemm", "tensorop884") | |||
| write_op_list(f, "gemv", "simt") | |||
| write_op_list(f, "deconv", "simt") | |||
| write_op_list(f, "conv2d", "simt") | |||
| @@ -596,6 +596,131 @@ def GenerateGemv_Simt(args): | |||
| align_b)) | |||
| return operations | |||
| # | |||
| def GeneratesGemm_TensorOp_1688(args): | |||
| layouts = [ | |||
| (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), # nn | |||
| (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor), # nt | |||
| (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), # tn | |||
| (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor), # tt | |||
| ] | |||
| math_instructions = [ | |||
| MathInstruction( \ | |||
| [16, 8, 8], \ | |||
| DataType.f16, DataType.f16, DataType.f32, \ | |||
| OpcodeClass.TensorOp, \ | |||
| MathOperation.multiply_add), | |||
| MathInstruction( \ | |||
| [16, 8, 8], \ | |||
| DataType.f16, DataType.f16, DataType.f16, \ | |||
| OpcodeClass.TensorOp, \ | |||
| MathOperation.multiply_add), | |||
| ] | |||
| min_cc = 75 | |||
| max_cc = 1024 | |||
| alignment_constraints = [8, 4, 2, | |||
| # 1 | |||

| ] | |||
| operations = [] | |||
| for math_inst in math_instructions: | |||
| for layout in layouts: | |||
| for align in alignment_constraints: | |||
| tile_descriptions = [ | |||
| TileDescription([256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), | |||
| TileDescription([128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc), | |||
| TileDescription([128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), | |||
| ## some configurations are commented out to reduce compilation time and binary size | |||
| # TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), | |||
| # TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), | |||
| # TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), | |||
| ] | |||
| data_type = [ | |||
| math_inst.element_a, | |||
| math_inst.element_b, | |||
| math_inst.element_a, | |||
| math_inst.element_accumulator, | |||
| ] | |||
| for tile in tile_descriptions: | |||
| operations += GeneratesGemm(tile, \ | |||
| data_type, \ | |||
| layout[0], \ | |||
| layout[1], \ | |||
| layout[2], \ | |||
| min_cc, \ | |||
| align * 16, \ | |||
| align * 16, \ | |||
| align * 16) | |||
| return operations | |||
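For scale: 2 math instructions x 4 layouts x 3 alignment constraints x 3 tiles = 72 GEMM configurations, and GeneratesGemm emits each one as both a regular and a split-K-parallel kernel, which is exactly the 144 tensorop1688 .cu files (plus the all_gemm_tensorop1688_operations.cu aggregate) added to the generated list further down.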
| # | |||
| def GeneratesGemm_TensorOp_884(args): | |||
| layouts = [ | |||
| (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), # nn | |||
| (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor), # nt | |||
| (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), # tn | |||
| (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor), # tt | |||
| ] | |||
| math_instructions = [ | |||
| MathInstruction( \ | |||
| [8, 8, 4], \ | |||
| DataType.f16, DataType.f16, DataType.f32, \ | |||
| OpcodeClass.TensorOp, \ | |||
| MathOperation.multiply_add), | |||
| MathInstruction( \ | |||
| [8, 8, 4], \ | |||
| DataType.f16, DataType.f16, DataType.f16, \ | |||
| OpcodeClass.TensorOp, \ | |||
| MathOperation.multiply_add), | |||
| ] | |||
| min_cc = 70 | |||
| max_cc = 75 | |||
| alignment_constraints = [8, 4, 2, | |||
| # 1 | |||
| ] | |||
| operations = [] | |||
| for math_inst in math_instructions: | |||
| for layout in layouts: | |||
| for align in alignment_constraints: | |||
| tile_descriptions = [ | |||
| TileDescription([256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), | |||
| TileDescription([128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc), | |||
| TileDescription([128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), | |||
| ## some configurations are commented out to reduce compilation time and binary size | |||
| # TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), | |||
| # TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), | |||
| # TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), | |||
| ] | |||
| data_type = [ | |||
| math_inst.element_a, | |||
| math_inst.element_b, | |||
| math_inst.element_a, | |||
| math_inst.element_accumulator, | |||
| ] | |||
| for tile in tile_descriptions: | |||
| operations += GeneratesGemm(tile, \ | |||
| data_type, \ | |||
| layout[0], \ | |||
| layout[1], \ | |||
| layout[2], \ | |||
| min_cc, \ | |||
| align * 16, \ | |||
| align * 16, \ | |||
| align * 16) | |||
| return operations | |||
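This is the same 144-kernel sweep with the 8x8x4 mma shape; min_cc=70/max_cc=75 gates it to Volta-class devices, while the 1688 set above (min_cc=75) covers Turing and later, so the two kernel families do not overlap.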
| # | |||
| def GenerateConv2dOperations(args): | |||
| if args.type == "simt": | |||
| @@ -613,9 +738,14 @@ def GenerateDeconvOperations(args): | |||
| return GenerateDeconv_Simt(args) | |||
| def GenerateGemmOperations(args): | |||
| assert args.type == "simt", "operation gemm only support" \ | |||
| "simt. (got:{})".format(args.type) | |||
| return GenerateGemm_Simt(args) | |||
| if args.type == "tensorop884": | |||
| return GeneratesGemm_TensorOp_884(args) | |||
| elif args.type == "tensorop1688": | |||
| return GeneratesGemm_TensorOp_1688(args) | |||
| else: | |||
| assert args.type == "simt", "operation gemm only supports " \ | |||
| "simt. (got:{})".format(args.type) | |||
| return GenerateGemm_Simt(args) | |||
| def GenerateGemvOperations(args): | |||
| assert args.type == "simt", "operation gemv only supports " \ | |||
| @@ -631,7 +761,7 @@ if __name__ == "__main__": | |||
| parser.add_argument("--operations", type=str, choices=['gemm', 'gemv', 'conv2d', 'deconv'], | |||
| required=True, help="Specifies the operation to generate (gemm, gemv, conv2d, deconv)") | |||
| parser.add_argument("output", type=str, help="output directory for CUTLASS kernel files") | |||
| parser.add_argument("--type", type=str, choices=['simt', 'tensorop8816', 'tensorop8832'], | |||
| parser.add_argument("--type", type=str, choices=['simt', 'tensorop8816', 'tensorop8832', 'tensorop884', 'tensorop1688'], | |||
| default='simt', help="kernel type of CUTLASS kernel generator") | |||
| gemv_wrapper_path = "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl" | |||
| @@ -138,6 +138,296 @@ cutlass_gen_list = [ | |||
| "cutlass_simt_sgemm_256x64_8x2_tt_align1.cu", | |||
| "cutlass_simt_sgemm_split_k_parallel_256x64_8x2_tt_align1.cu", | |||
| "all_gemm_simt_operations.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_nn_align8.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_nn_align8.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_nn_align8.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_nn_align8.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_nn_align8.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_nn_align8.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_nn_align4.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_nn_align4.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_nn_align4.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_nn_align4.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_nn_align4.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_nn_align4.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_nn_align2.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_nn_align2.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_nn_align2.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_nn_align2.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_nn_align2.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_nn_align2.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_nt_align8.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_nt_align8.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_nt_align8.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_nt_align8.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_nt_align8.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_nt_align8.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_nt_align4.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_nt_align4.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_nt_align4.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_nt_align4.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_nt_align4.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_nt_align4.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_nt_align2.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_nt_align2.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_nt_align2.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_nt_align2.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_nt_align2.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_nt_align2.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_tn_align8.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_tn_align8.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_tn_align8.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_tn_align8.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_tn_align8.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_tn_align8.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_tn_align4.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_tn_align4.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_tn_align4.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_tn_align4.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_tn_align4.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_tn_align4.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_tn_align2.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_tn_align2.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_tn_align2.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_tn_align2.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_tn_align2.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_tn_align2.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_tt_align8.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_tt_align8.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_tt_align8.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_tt_align8.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_tt_align8.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_tt_align8.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_tt_align4.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_tt_align4.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_tt_align4.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_tt_align4.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_tt_align4.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_tt_align4.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_tt_align2.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_tt_align2.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_tt_align2.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_tt_align2.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_tt_align2.cu", | |||
| "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_tt_align2.cu", | |||
| "cutlass_tensorop_h1688gemm_256x128_32x2_nn_align8.cu", | |||
| "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_nn_align8.cu", | |||
| "cutlass_tensorop_h1688gemm_128x256_32x2_nn_align8.cu", | |||
| "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_nn_align8.cu", | |||
| "cutlass_tensorop_h1688gemm_128x128_32x2_nn_align8.cu", | |||
| "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_nn_align8.cu", | |||
| "cutlass_tensorop_h1688gemm_256x128_32x2_nn_align4.cu", | |||
| "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_nn_align4.cu", | |||
| "cutlass_tensorop_h1688gemm_128x256_32x2_nn_align4.cu", | |||
| "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_nn_align4.cu", | |||
| "cutlass_tensorop_h1688gemm_128x128_32x2_nn_align4.cu", | |||
| "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_nn_align4.cu", | |||
| "cutlass_tensorop_h1688gemm_256x128_32x2_nn_align2.cu", | |||
| "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_nn_align2.cu", | |||
| "cutlass_tensorop_h1688gemm_128x256_32x2_nn_align2.cu", | |||
| "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_nn_align2.cu", | |||
| "cutlass_tensorop_h1688gemm_128x128_32x2_nn_align2.cu", | |||
| "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_nn_align2.cu", | |||
| "cutlass_tensorop_h1688gemm_256x128_32x2_nt_align8.cu", | |||
| "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_nt_align8.cu", | |||
| "cutlass_tensorop_h1688gemm_128x256_32x2_nt_align8.cu", | |||
| "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_nt_align8.cu", | |||
| "cutlass_tensorop_h1688gemm_128x128_32x2_nt_align8.cu", | |||
| "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_nt_align8.cu", | |||
| "cutlass_tensorop_h1688gemm_256x128_32x2_nt_align4.cu", | |||
| "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_nt_align4.cu", | |||
| "cutlass_tensorop_h1688gemm_128x256_32x2_nt_align4.cu", | |||
| "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_nt_align4.cu", | |||
| "cutlass_tensorop_h1688gemm_128x128_32x2_nt_align4.cu", | |||
| "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_nt_align4.cu", | |||
| "cutlass_tensorop_h1688gemm_256x128_32x2_nt_align2.cu", | |||
| "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_nt_align2.cu", | |||
| "cutlass_tensorop_h1688gemm_128x256_32x2_nt_align2.cu", | |||
| "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_nt_align2.cu", | |||
| "cutlass_tensorop_h1688gemm_128x128_32x2_nt_align2.cu", | |||
| "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_nt_align2.cu", | |||
| "cutlass_tensorop_h1688gemm_256x128_32x2_tn_align8.cu", | |||
| "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_tn_align8.cu", | |||
| "cutlass_tensorop_h1688gemm_128x256_32x2_tn_align8.cu", | |||
| "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_tn_align8.cu", | |||
| "cutlass_tensorop_h1688gemm_128x128_32x2_tn_align8.cu", | |||
| "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_tn_align8.cu", | |||
| "cutlass_tensorop_h1688gemm_256x128_32x2_tn_align4.cu", | |||
| "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_tn_align4.cu", | |||
| "cutlass_tensorop_h1688gemm_128x256_32x2_tn_align4.cu", | |||
| "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_tn_align4.cu", | |||
| "cutlass_tensorop_h1688gemm_128x128_32x2_tn_align4.cu", | |||
| "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_tn_align4.cu", | |||
| "cutlass_tensorop_h1688gemm_256x128_32x2_tn_align2.cu", | |||
| "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_tn_align2.cu", | |||
| "cutlass_tensorop_h1688gemm_128x256_32x2_tn_align2.cu", | |||
| "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_tn_align2.cu", | |||
| "cutlass_tensorop_h1688gemm_128x128_32x2_tn_align2.cu", | |||
| "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_tn_align2.cu", | |||
| "cutlass_tensorop_h1688gemm_256x128_32x2_tt_align8.cu", | |||
| "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_tt_align8.cu", | |||
| "cutlass_tensorop_h1688gemm_128x256_32x2_tt_align8.cu", | |||
| "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_tt_align8.cu", | |||
| "cutlass_tensorop_h1688gemm_128x128_32x2_tt_align8.cu", | |||
| "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_tt_align8.cu", | |||
| "cutlass_tensorop_h1688gemm_256x128_32x2_tt_align4.cu", | |||
| "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_tt_align4.cu", | |||
| "cutlass_tensorop_h1688gemm_128x256_32x2_tt_align4.cu", | |||
| "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_tt_align4.cu", | |||
| "cutlass_tensorop_h1688gemm_128x128_32x2_tt_align4.cu", | |||
| "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_tt_align4.cu", | |||
| "cutlass_tensorop_h1688gemm_256x128_32x2_tt_align2.cu", | |||
| "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_tt_align2.cu", | |||
| "cutlass_tensorop_h1688gemm_128x256_32x2_tt_align2.cu", | |||
| "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_tt_align2.cu", | |||
| "cutlass_tensorop_h1688gemm_128x128_32x2_tt_align2.cu", | |||
| "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_tt_align2.cu", | |||
| "all_gemm_tensorop1688_operations.cu", | |||
| "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_nn_align8.cu", | |||
| "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_nn_align8.cu", | |||
| "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_nn_align8.cu", | |||
| "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_nn_align8.cu", | |||
| "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_nn_align8.cu", | |||
| "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_nn_align8.cu", | |||
| "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_nn_align4.cu", | |||
| "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_nn_align4.cu", | |||
| "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_nn_align4.cu", | |||
| "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_nn_align4.cu", | |||
| "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_nn_align4.cu", | |||
| "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_nn_align4.cu", | |||
| "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_nn_align2.cu", | |||
| "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_nn_align2.cu", | |||
| "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_nn_align2.cu", | |||
| "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_nn_align2.cu", | |||
| "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_nn_align2.cu", | |||
| "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_nn_align2.cu", | |||
| "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_nt_align8.cu", | |||
| "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_nt_align8.cu", | |||
| "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_nt_align8.cu", | |||
| "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_nt_align8.cu", | |||
| "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_nt_align8.cu", | |||
| "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_nt_align8.cu", | |||
| "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_nt_align4.cu", | |||
| "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_nt_align4.cu", | |||
| "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_nt_align4.cu", | |||
| "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_nt_align4.cu", | |||
| "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_nt_align4.cu", | |||
| "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_nt_align4.cu", | |||
| "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_nt_align2.cu", | |||
| "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_nt_align2.cu", | |||
| "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_nt_align2.cu", | |||
| "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_nt_align2.cu", | |||
| "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_nt_align2.cu", | |||
| "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_nt_align2.cu", | |||
| "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_tn_align8.cu", | |||
| "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_tn_align8.cu", | |||
| "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_tn_align8.cu", | |||
| "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_tn_align8.cu", | |||
| "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_tn_align8.cu", | |||
| "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_tn_align8.cu", | |||
| "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_tn_align4.cu", | |||
| "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_tn_align4.cu", | |||
| "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_tn_align4.cu", | |||
| "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_tn_align4.cu", | |||
| "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_tn_align4.cu", | |||
| "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_tn_align4.cu", | |||
| "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_tn_align2.cu", | |||
| "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_tn_align2.cu", | |||
| "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_tn_align2.cu", | |||
| "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_tn_align2.cu", | |||
| "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_tn_align2.cu", | |||
| "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_tn_align2.cu", | |||
| "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_tt_align8.cu", | |||
| "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_tt_align8.cu", | |||
| "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_tt_align8.cu", | |||
| "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_tt_align8.cu", | |||
| "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_tt_align8.cu", | |||
| "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_tt_align8.cu", | |||
| "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_tt_align4.cu", | |||
| "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_tt_align4.cu", | |||
| "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_tt_align4.cu", | |||
| "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_tt_align4.cu", | |||
| "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_tt_align4.cu", | |||
| "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_tt_align4.cu", | |||
| "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_tt_align2.cu", | |||
| "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_tt_align2.cu", | |||
| "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_tt_align2.cu", | |||
| "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_tt_align2.cu", | |||
| "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_tt_align2.cu", | |||
| "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_tt_align2.cu", | |||
| "cutlass_tensorop_h884gemm_256x128_32x2_nn_align8.cu", | |||
| "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_nn_align8.cu", | |||
| "cutlass_tensorop_h884gemm_128x256_32x2_nn_align8.cu", | |||
| "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_nn_align8.cu", | |||
| "cutlass_tensorop_h884gemm_128x128_32x2_nn_align8.cu", | |||
| "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_nn_align8.cu", | |||
| "cutlass_tensorop_h884gemm_256x128_32x2_nn_align4.cu", | |||
| "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_nn_align4.cu", | |||
| "cutlass_tensorop_h884gemm_128x256_32x2_nn_align4.cu", | |||
| "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_nn_align4.cu", | |||
| "cutlass_tensorop_h884gemm_128x128_32x2_nn_align4.cu", | |||
| "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_nn_align4.cu", | |||
| "cutlass_tensorop_h884gemm_256x128_32x2_nn_align2.cu", | |||
| "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_nn_align2.cu", | |||
| "cutlass_tensorop_h884gemm_128x256_32x2_nn_align2.cu", | |||
| "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_nn_align2.cu", | |||
| "cutlass_tensorop_h884gemm_128x128_32x2_nn_align2.cu", | |||
| "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_nn_align2.cu", | |||
| "cutlass_tensorop_h884gemm_256x128_32x2_nt_align8.cu", | |||
| "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_nt_align8.cu", | |||
| "cutlass_tensorop_h884gemm_128x256_32x2_nt_align8.cu", | |||
| "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_nt_align8.cu", | |||
| "cutlass_tensorop_h884gemm_128x128_32x2_nt_align8.cu", | |||
| "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_nt_align8.cu", | |||
| "cutlass_tensorop_h884gemm_256x128_32x2_nt_align4.cu", | |||
| "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_nt_align4.cu", | |||
| "cutlass_tensorop_h884gemm_128x256_32x2_nt_align4.cu", | |||
| "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_nt_align4.cu", | |||
| "cutlass_tensorop_h884gemm_128x128_32x2_nt_align4.cu", | |||
| "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_nt_align4.cu", | |||
| "cutlass_tensorop_h884gemm_256x128_32x2_nt_align2.cu", | |||
| "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_nt_align2.cu", | |||
| "cutlass_tensorop_h884gemm_128x256_32x2_nt_align2.cu", | |||
| "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_nt_align2.cu", | |||
| "cutlass_tensorop_h884gemm_128x128_32x2_nt_align2.cu", | |||
| "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_nt_align2.cu", | |||
| "cutlass_tensorop_h884gemm_256x128_32x2_tn_align8.cu", | |||
| "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_tn_align8.cu", | |||
| "cutlass_tensorop_h884gemm_128x256_32x2_tn_align8.cu", | |||
| "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_tn_align8.cu", | |||
| "cutlass_tensorop_h884gemm_128x128_32x2_tn_align8.cu", | |||
| "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_tn_align8.cu", | |||
| "cutlass_tensorop_h884gemm_256x128_32x2_tn_align4.cu", | |||
| "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_tn_align4.cu", | |||
| "cutlass_tensorop_h884gemm_128x256_32x2_tn_align4.cu", | |||
| "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_tn_align4.cu", | |||
| "cutlass_tensorop_h884gemm_128x128_32x2_tn_align4.cu", | |||
| "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_tn_align4.cu", | |||
| "cutlass_tensorop_h884gemm_256x128_32x2_tn_align2.cu", | |||
| "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_tn_align2.cu", | |||
| "cutlass_tensorop_h884gemm_128x256_32x2_tn_align2.cu", | |||
| "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_tn_align2.cu", | |||
| "cutlass_tensorop_h884gemm_128x128_32x2_tn_align2.cu", | |||
| "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_tn_align2.cu", | |||
| "cutlass_tensorop_h884gemm_256x128_32x2_tt_align8.cu", | |||
| "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_tt_align8.cu", | |||
| "cutlass_tensorop_h884gemm_128x256_32x2_tt_align8.cu", | |||
| "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_tt_align8.cu", | |||
| "cutlass_tensorop_h884gemm_128x128_32x2_tt_align8.cu", | |||
| "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_tt_align8.cu", | |||
| "cutlass_tensorop_h884gemm_256x128_32x2_tt_align4.cu", | |||
| "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_tt_align4.cu", | |||
| "cutlass_tensorop_h884gemm_128x256_32x2_tt_align4.cu", | |||
| "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_tt_align4.cu", | |||
| "cutlass_tensorop_h884gemm_128x128_32x2_tt_align4.cu", | |||
| "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_tt_align4.cu", | |||
| "cutlass_tensorop_h884gemm_256x128_32x2_tt_align2.cu", | |||
| "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_tt_align2.cu", | |||
| "cutlass_tensorop_h884gemm_128x256_32x2_tt_align2.cu", | |||
| "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_tt_align2.cu", | |||
| "cutlass_tensorop_h884gemm_128x128_32x2_tt_align2.cu", | |||
| "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_tt_align2.cu", | |||
| "all_gemm_tensorop884_operations.cu", | |||
| "cutlass_simt_sgemv_batched_strided_1x128_32_tt_align4x4.cu", | |||
| "cutlass_simt_sgemv_batched_strided_1x128_16_tt_align4x2.cu", | |||
| "cutlass_simt_sgemv_batched_strided_1x128_8_tt_align4x1.cu", | |||
| @@ -646,4 +936,4 @@ cutlass_gen_list = [ | |||
| "cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||
| "all_conv2d_tensorop8832_operations.cu", | |||
| ] | |||
| ] | |||
| @@ -151,6 +151,8 @@ if(MGE_WITH_CUDA) | |||
| set(${gen_files} "${${gen_files}}" PARENT_SCOPE) | |||
| endfunction() | |||
| gen_cutlass_kimpl(gemm simt CUTLASS_SOURCES) | |||
| gen_cutlass_kimpl(gemm tensorop884 CUTLASS_SOURCES) | |||
| gen_cutlass_kimpl(gemm tensorop1688 CUTLASS_SOURCES) | |||
| gen_cutlass_kimpl(gemv simt CUTLASS_SOURCES) | |||
| gen_cutlass_kimpl(deconv simt CUTLASS_SOURCES) | |||
| gen_cutlass_kimpl(conv2d simt CUTLASS_SOURCES) | |||
| @@ -49,6 +49,8 @@ namespace library { | |||
| (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) | |||
| void initialize_all_gemm_simt_operations(Manifest& manifest); | |||
| void initialize_all_gemm_tensorop884_operations(Manifest& manifest); | |||
| void initialize_all_gemm_tensorop1688_operations(Manifest& manifest); | |||
| void initialize_all_conv2d_simt_operations(Manifest& manifest); | |||
| void initialize_all_conv2d_tensorop8816_operations(Manifest& manifest); | |||
| void initialize_all_conv2d_tensorop8832_operations(Manifest& manifest); | |||
| @@ -56,6 +58,8 @@ void initialize_all_deconv_simt_operations(Manifest& manifest); | |||
| void initialize_all(Manifest& manifest) { | |||
| initialize_all_gemm_simt_operations(manifest); | |||
| initialize_all_gemm_tensorop884_operations(manifest); | |||
| initialize_all_gemm_tensorop1688_operations(manifest); | |||
| initialize_all_conv2d_simt_operations(manifest); | |||
| initialize_all_conv2d_tensorop8816_operations(manifest); | |||
| initialize_all_conv2d_tensorop8832_operations(manifest); | |||
| @@ -55,6 +55,8 @@ GemmKey get_gemm_key_from_desc(const GemmDescription& desc) { | |||
| key.layout_B = desc.B.layout; | |||
| key.element_C = desc.C.element; | |||
| key.layout_C = desc.C.layout; | |||
| key.element_accumulator = | |||
| desc.tile_description.math_instruction.element_accumulator; | |||
| key.threadblock_shape_m = desc.tile_description.threadblock_shape.m(); | |||
| key.threadblock_shape_n = desc.tile_description.threadblock_shape.n(); | |||
| @@ -75,6 +77,8 @@ GemmKey get_gemm_key_from_desc(const GemmDescription& desc) { | |||
| desc.tile_description.math_instruction.instruction_shape.k(); | |||
| key.stages = desc.stages; | |||
| key.alignment_A = desc.A.alignment; | |||
| key.alignment_B = desc.B.alignment; | |||
| key.split_k_mode = desc.split_k_mode; | |||
| return key; | |||
| @@ -77,6 +77,7 @@ struct GemmKey { | |||
| LayoutTypeID layout_B; | |||
| NumericTypeID element_C; | |||
| LayoutTypeID layout_C; | |||
| NumericTypeID element_accumulator; | |||
| int threadblock_shape_m; | |||
| int threadblock_shape_n; | |||
| @@ -91,12 +92,15 @@ struct GemmKey { | |||
| int instruction_shape_k; | |||
| int stages; | |||
| int alignment_A; | |||
| int alignment_B; | |||
| SplitKMode split_k_mode; | |||
| inline bool operator==(GemmKey const& rhs) const { | |||
| return (element_A == rhs.element_A) && (layout_A == rhs.layout_A) && | |||
| (element_B == rhs.element_B) && (layout_B == rhs.layout_B) && | |||
| (element_C == rhs.element_C) && (layout_C == rhs.layout_C) && | |||
| (element_accumulator == rhs.element_accumulator) && | |||
| (threadblock_shape_m == rhs.threadblock_shape_m) && | |||
| (threadblock_shape_n == rhs.threadblock_shape_n) && | |||
| (threadblock_shape_k == rhs.threadblock_shape_k) && | |||
| @@ -106,7 +110,9 @@ struct GemmKey { | |||
| (instruction_shape_m == rhs.instruction_shape_m) && | |||
| (instruction_shape_n == rhs.instruction_shape_n) && | |||
| (instruction_shape_k == rhs.instruction_shape_k) && | |||
| (stages == rhs.stages) && (split_k_mode == rhs.split_k_mode); | |||
| (stages == rhs.stages) && (alignment_A == rhs.alignment_A) && | |||
| (alignment_B == rhs.alignment_B) && | |||
| (split_k_mode == rhs.split_k_mode); | |||
| } | |||
| inline bool operator!=(GemmKey const& rhs) const { return !(*this == rhs); } | |||
| @@ -130,10 +136,13 @@ struct GemmKey { | |||
| "\n layout_B: " + to_string(layout_B) + | |||
| "\n element_C: " + to_string(element_C) + | |||
| "\n layout_C: " + to_string(layout_C) + | |||
| "\n element_accumulator: " + to_string(element_accumulator) + | |||
| "\n threadblock_shape: " + threadblock_shape_str + | |||
| "\n warp_shape: " + warp_shape_str + | |||
| "\n instruction_shape: " + instruction_shape_str + | |||
| "\n stages: " + std::to_string(stages) + | |||
| "\n alignment_A: " + std::to_string(alignment_A) + | |||
| "\n alignment_B: " + std::to_string(alignment_B) + | |||
| "\n split_k_mode: " + to_string(split_k_mode) + "\n}"; | |||
| } | |||
| }; | |||
| @@ -147,6 +156,8 @@ struct GemmKeyHasher { | |||
| .update(&key.layout_B, sizeof(key.layout_B)) | |||
| .update(&key.element_C, sizeof(key.element_C)) | |||
| .update(&key.layout_C, sizeof(key.layout_C)) | |||
| .update(&key.element_accumulator, | |||
| sizeof(key.element_accumulator)) | |||
| .update(&key.threadblock_shape_m, | |||
| sizeof(key.threadblock_shape_m)) | |||
| .update(&key.threadblock_shape_n, | |||
| @@ -157,6 +168,8 @@ struct GemmKeyHasher { | |||
| .update(&key.warp_shape_n, sizeof(key.warp_shape_n)) | |||
| .update(&key.warp_shape_k, sizeof(key.warp_shape_k)) | |||
| .update(&key.stages, sizeof(key.stages)) | |||
| .update(&key.alignment_A, sizeof(key.alignment_A)) | |||
| .update(&key.alignment_B, sizeof(key.alignment_B)) | |||
| .update(&key.split_k_mode, sizeof(key.split_k_mode)) | |||
| .digest(); | |||
| } | |||
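Consistency note: element_accumulator, alignment_A and alignment_B are added to operator==, to_string() and the hasher in lockstep. Without the new equality terms, the f16- and f32-accumulator variants (and the align8/align4/align2 variants) of an otherwise identical tile shape would compare equal and shadow one another in the operation table; hashing the same fields keeps equal-hash buckets from degenerating into long collision chains.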
| @@ -43,6 +43,12 @@ MatrixMulForwardImpl::AlgoPack::AlgoPack() { | |||
| for (auto&& algo : simt_float32_gemv_batched_strided) { | |||
| all_algos.push_back(&algo); | |||
| } | |||
| for (auto&& algo : tensorop_float16) { | |||
| all_algos.push_back(&algo); | |||
| } | |||
| for (auto&& algo : tensorop_float16_split_k) { | |||
| all_algos.push_back(&algo); | |||
| } | |||
| #endif | |||
| all_algos.push_back(&naive); | |||
| @@ -53,7 +59,7 @@ MatrixMulForwardImpl::AlgoPack::AlgoPack() { | |||
| #if CUDA_VERSION >= 9020 | |||
| void MatrixMulForwardImpl::AlgoPack::fill_cutlass_algos() { | |||
| using AlgoParam = AlgoFloat32SIMT::AlgoParam; | |||
| using AlgoParam = AlgoCutlassMatrixMulBase::AlgoParam; | |||
| simt_float32.emplace_back(AlgoParam{64, 256, 8, 32, 64, 8}); | |||
| simt_float32.emplace_back(AlgoParam{256, 64, 8, 64, 32, 8}); | |||
| simt_float32.emplace_back(AlgoParam{32, 256, 8, 16, 64, 8}); | |||
| @@ -91,6 +97,19 @@ void MatrixMulForwardImpl::AlgoPack::fill_cutlass_algos() { | |||
| simt_float32_gemv_batched_strided.emplace_back(128); | |||
| simt_float32_gemv_batched_strided.emplace_back(64); | |||
| simt_float32_gemv_batched_strided.emplace_back(32); | |||
| #define FOREACH_CUTLASS_MATMUL_F16_SHAPES(cb) \ | |||
| cb(256, 128, 32, 64, 64, 32, 8, 8, 4); \ | |||
| cb(128, 256, 32, 64, 64, 32, 8, 8, 4); \ | |||
| cb(128, 128, 32, 64, 64, 32, 8, 8, 4); \ | |||
| cb(256, 128, 32, 64, 64, 32, 16, 8, 8); \ | |||
| cb(128, 256, 32, 64, 64, 32, 16, 8, 8); \ | |||
| cb(128, 128, 32, 64, 64, 32, 16, 8, 8); | |||
| #define cb(...) \ | |||
| tensorop_float16.emplace_back(AlgoParam{__VA_ARGS__}); \ | |||
| tensorop_float16_split_k.emplace_back(AlgoParam{__VA_ARGS__}); | |||
| FOREACH_CUTLASS_MATMUL_F16_SHAPES(cb) | |||
| #undef cb | |||
| #undef FOREACH_CUTLASS_MATMUL_F16_SHAPES | |||
| } | |||
| #endif | |||
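For reference, the first cb(...) expansion above is equivalent to

    tensorop_float16.emplace_back(AlgoParam{256, 128, 32, 64, 64, 32, 8, 8, 4});
    tensorop_float16_split_k.emplace_back(AlgoParam{256, 128, 32, 64, 64, 32, 8, 8, 4});

so each of the six threadblock/warp/instruction tuples registers one regular and one split-K f16 tensor-op algorithm.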
| @@ -41,11 +41,13 @@ public: | |||
| CUDA_WMMA_UINT4X4X32, | |||
| CUDA_CUBLASLT, | |||
| CUDA_NAIVE, | |||
| CUDA_BFLOAT16, | |||
| CUDA_BFLOAT16, | |||
| #if CUDA_VERSION >= 9020 | |||
| CUDA_FLOAT32_SIMT, | |||
| CUDA_FLOAT32_SIMT_SPLIT_K, | |||
| CUDA_FLOAT32_SIMT_GEMV_BATCHED_STRIDED, | |||
| CUDA_FLOAT32_SIMT, | |||
| CUDA_FLOAT32_SIMT_SPLIT_K, | |||
| CUDA_FLOAT32_SIMT_GEMV_BATCHED_STRIDED, | |||
| CUDA_FLOAT16_TENSOR_OP, | |||
| CUDA_FLOAT16_TENSOR_OP_SPLIT_K, | |||
| #endif | |||
| }; | |||
| using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||
| @@ -188,65 +190,83 @@ private: | |||
| #endif | |||
| #if CUDA_VERSION >= 9020 | |||
| class MatrixMulForwardImpl::AlgoFloat32SIMT final : public AlgoBase { | |||
| class MatrixMulForwardImpl::AlgoCutlassMatrixMulBase : public AlgoBase { | |||
| public: | |||
| struct AlgoParam { | |||
| int threadblock_m, threadblock_n, threadblock_k; | |||
| int warp_m, warp_n, warp_k; | |||
| std::string to_string() { | |||
| return ssprintf("%dX%dX%d_%dX%dX%d", threadblock_m, threadblock_n, | |||
| threadblock_k, warp_m, warp_n, warp_k); | |||
| } | |||
| int instruction_m, instruction_n, instruction_k; | |||
| AlgoParam(int threadblock_m_, int threadblock_n_, int threadblock_k_, | |||
| int warp_m_, int warp_n_, int warp_k_, int instruction_m_ = 1, | |||
| int instruction_n_ = 1, int instruction_k_ = 1) | |||
| : threadblock_m{threadblock_m_}, | |||
| threadblock_n{threadblock_n_}, | |||
| threadblock_k{threadblock_k_}, | |||
| warp_m{warp_m_}, | |||
| warp_n{warp_n_}, | |||
| warp_k{warp_k_}, | |||
| instruction_m{instruction_m_}, | |||
| instruction_n{instruction_n_}, | |||
| instruction_k{instruction_k_} {} | |||
| std::string to_string() const; | |||
| }; | |||
| AlgoCutlassMatrixMulBase(AlgoParam algo_param) : m_algo_param{algo_param} {} | |||
| void exec(const ExecArgs& args) const override; | |||
| std::string param() const override { | |||
| std::string ret; | |||
| serialize_write_pod(m_algo_param, ret); | |||
| return ret; | |||
| } | |||
| protected: | |||
| virtual int min_alignment_requirement() const = 0; | |||
| virtual void do_exec(const ExecArgs& args) const = 0; | |||
| std::pair<bool, TensorLayoutArray> construct_aligned_layouts( | |||
| const SizeArgs& args) const; | |||
| int max_alignment(const SizeArgs& args) const; | |||
| AlgoParam m_algo_param; | |||
| }; | |||
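With this split, exec() is implemented once in the base class; presumably it uses construct_aligned_layouts()/max_alignment() to re-pack operands that do not meet min_alignment_requirement() before delegating to the algorithm-specific do_exec(), so the re-alignment machinery is shared between the SIMT and tensor-op algorithms instead of being duplicated.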
| class MatrixMulForwardImpl::AlgoFloat32SIMT final | |||
| : public AlgoCutlassMatrixMulBase { | |||
| public: | |||
| AlgoFloat32SIMT(AlgoParam algo_param) | |||
| : m_algo_param{algo_param}, | |||
| : AlgoCutlassMatrixMulBase{algo_param}, | |||
| m_name{ssprintf("CUTLASS_FLOAT32_SIMT_%s", | |||
| m_algo_param.to_string().c_str())} {} | |||
| bool is_available(const SizeArgs& args) const override; | |||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
| const char* name() const override { return m_name.c_str(); } | |||
| void exec(const ExecArgs& args) const override; | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT32_SIMT) | |||
| std::string param() const override { | |||
| std::string ret; | |||
| serialize_write_pod(m_algo_param, ret); | |||
| return ret; | |||
| } | |||
| private: | |||
| AlgoParam m_algo_param; | |||
| void do_exec(const ExecArgs& args) const override; | |||
| int min_alignment_requirement() const override { return 1; } | |||
| std::string m_name; | |||
| }; | |||
| class MatrixMulForwardImpl::AlgoFloat32SIMTSplitK final : public AlgoBase { | |||
| class MatrixMulForwardImpl::AlgoFloat32SIMTSplitK final | |||
| : public AlgoCutlassMatrixMulBase { | |||
| public: | |||
| using AlgoParam = MatrixMulForwardImpl::AlgoFloat32SIMT::AlgoParam; | |||
| AlgoFloat32SIMTSplitK(AlgoParam algo_param) | |||
| : m_algo_param{algo_param}, | |||
| : AlgoCutlassMatrixMulBase{algo_param}, | |||
| m_name{ssprintf("CUTLASS_FLOAT32_SIMT_SPLIT_K_%s", | |||
| m_algo_param.to_string().c_str())} {} | |||
| bool is_available(const SizeArgs& args) const override; | |||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
| const char* name() const override { return m_name.c_str(); } | |||
| void exec(const ExecArgs& args) const override; | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE | | |||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT32_SIMT_SPLIT_K) | |||
| std::string param() const override { | |||
| std::string ret; | |||
| serialize_write_pod(m_algo_param, ret); | |||
| return ret; | |||
| } | |||
| private: | |||
| AlgoParam m_algo_param; | |||
| void do_exec(const ExecArgs& args) const override; | |||
| int min_alignment_requirement() const override { return 1; } | |||
| std::string m_name; | |||
| }; | |||
| @@ -276,6 +296,56 @@ private: | |||
| int m_threadblock_n; | |||
| std::string m_name; | |||
| }; | |||
| class MatrixMulForwardImpl::AlgoFloat16TensorOp final | |||
| : public AlgoCutlassMatrixMulBase { | |||
| public: | |||
| AlgoFloat16TensorOp(AlgoParam algo_param) | |||
| : AlgoCutlassMatrixMulBase{algo_param}, | |||
| m_name{ssprintf("CUTLASS_FLOAT16_TENSOR_OP_h%d%d%d_%s", | |||
| m_algo_param.instruction_m, | |||
| m_algo_param.instruction_n, | |||
| m_algo_param.instruction_k, | |||
| m_algo_param.to_string().c_str())} {} | |||
| bool is_available(const SizeArgs& args) const override; | |||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
| const char* name() const override { return m_name.c_str(); } | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT16_TENSOR_OP) | |||
| private: | |||
| void do_exec(const ExecArgs& args) const override; | |||
| int min_alignment_requirement() const override { return 2; } | |||
| std::string m_name; | |||
| }; | |||
| class MatrixMulForwardImpl::AlgoFloat16TensorOpSplitK final | |||
| : public AlgoCutlassMatrixMulBase { | |||
| public: | |||
| AlgoFloat16TensorOpSplitK(AlgoParam algo_param) | |||
| : AlgoCutlassMatrixMulBase{algo_param}, | |||
| m_name{ssprintf("CUTLASS_FLOAT16_TENSOR_OP_SPLIT_K_h%d%d%d_%s", | |||
| m_algo_param.instruction_m, | |||
| m_algo_param.instruction_n, | |||
| m_algo_param.instruction_k, | |||
| m_algo_param.to_string().c_str())} {} | |||
| bool is_available(const SizeArgs& args) const override; | |||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
| const char* name() const override { return m_name.c_str(); } | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE | | |||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT16_TENSOR_OP_SPLIT_K) | |||
| private: | |||
| void do_exec(const ExecArgs& args) const override; | |||
| int min_alignment_requirement() const override { return 2; } | |||
| std::string m_name; | |||
| }; | |||
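As a concrete example of the naming scheme, and assuming to_string() keeps the %dX%dX%d_%dX%dX%d threadblock/warp format shown in the removed inline version above, the AlgoParam{128, 128, 32, 64, 64, 32, 16, 8, 8} instance is reported as CUTLASS_FLOAT16_TENSOR_OP_h1688_128X128X32_64X64X32 and its split-K counterpart as CUTLASS_FLOAT16_TENSOR_OP_SPLIT_K_h1688_128X128X32_64X64X32.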
| #endif | |||
| class MatrixMulForwardImpl::AlgoPack : NonCopyableObj { | |||
| @@ -300,6 +370,8 @@ public: | |||
| std::vector<AlgoFloat32SIMTSplitK> simt_float32_split_k; | |||
| std::vector<AlgoFloat32SIMTGemvBatchedStrided> | |||
| simt_float32_gemv_batched_strided; | |||
| std::vector<AlgoFloat16TensorOp> tensorop_float16; | |||
| std::vector<AlgoFloat16TensorOpSplitK> tensorop_float16_split_k; | |||
| #endif | |||
| std::vector<AlgoBase*> all_algos; | |||
| @@ -0,0 +1,154 @@ | |||
| /** | |||
| * \file dnn/src/cuda/matrix_mul/cutlass_float16_tensorop.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/cuda/cutlass/singleton.h" | |||
| #include "src/cuda/handle.h" | |||
| #include "src/cuda/matrix_mul/algos.h" | |||
| #include "src/cuda/utils.h" | |||
| #if CUDA_VERSION >= 9020 | |||
| using namespace megdnn; | |||
| using namespace cuda; | |||
| bool MatrixMulForwardImpl::AlgoFloat16TensorOp::is_available( | |||
| const SizeArgs& args) const { | |||
| bool available = | |||
| args.opr->param().format == param::MatrixMul::Format::DEFAULT && | |||
| args.layout_a.dtype == dtype::Float16() && | |||
| args.layout_b.dtype == dtype::Float16() && | |||
| args.layout_c.dtype == dtype::Float16(); | |||
| int n = args.layout_c.shape[1]; | |||
| auto&& device_prop = cuda::current_device_prop(); | |||
| int y_grid_limit = device_prop.maxGridSize[1]; | |||
| // limit y grid | |||
| available &= ((n + m_algo_param.threadblock_n - 1) / | |||
| m_algo_param.threadblock_n <= | |||
| y_grid_limit); | |||
| if (m_algo_param.instruction_m == 8 && m_algo_param.instruction_n == 8 && | |||
| m_algo_param.instruction_k == 4) { | |||
| available &= is_compute_capability_required(7, 0); | |||
| } else { | |||
| megdnn_assert(m_algo_param.instruction_m == 16 && | |||
| m_algo_param.instruction_n == 8 && | |||
| m_algo_param.instruction_k == 8); | |||
| available &= is_compute_capability_required(7, 5); | |||
| } | |||
| return available; | |||
| } | |||
| size_t MatrixMulForwardImpl::AlgoFloat16TensorOp::get_workspace_in_bytes( | |||
| const SizeArgs& args) const { | |||
| auto aligned = construct_aligned_layouts(args); | |||
| if (!aligned.first) | |||
| return 0_z; | |||
| const auto& layouts = aligned.second; | |||
| size_t ws_size = 0; | |||
| for (auto&& ly : layouts) { | |||
| ws_size += ly.span().dist_byte(); | |||
| } | |||
| return ws_size; | |||
| } | |||
| void MatrixMulForwardImpl::AlgoFloat16TensorOp::do_exec( | |||
| const ExecArgs& args) const { | |||
| int64_t lda = args.tensor_a.layout.stride[0], | |||
| ldb = args.tensor_b.layout.stride[0], | |||
| ldc = args.tensor_c.layout.stride[0]; | |||
| int alignment = max_alignment(args); | |||
| int min_alignment = min_alignment_requirement(); | |||
| auto&& param = args.opr->param(); | |||
| int m = args.tensor_c.layout.shape[0], n = args.tensor_c.layout.shape[1], | |||
| k = args.tensor_a.layout.shape[param.transposeA ? 0 : 1]; | |||
| megdnn_assert(lda % alignment == 0 && ldb % alignment == 0 && | |||
| ldc % alignment == 0 && m % alignment == 0 && | |||
| n % alignment == 0 && k % alignment == 0 && | |||
| alignment >= min_alignment); | |||
| cutlass::gemm::GemmCoord problem_size{m, n, k}; | |||
| auto&& stream = cuda_stream(args.opr->handle()); | |||
| int* workspace = reinterpret_cast<int*>(args.workspace.raw_ptr); | |||
| // \note the epilogue constants (i.e. one and zero) are passed to the | |||
| // cutlass kernel by pointer and interpreted as ElementCompute*, which is | |||
| // used to initialize kernel parameters. So the arguments' types on the | |||
| // host side must match the ElementCompute of the kernel instance; | |||
| // otherwise undefined kernel behavior occurs due to misinterpretation of | |||
| // these pointers. | |||
| float one = 1.f, zero = 0.f; | |||
| dt_float16 one_f16 = static_cast<dt_float16>(one), | |||
| zero_f16 = static_cast<dt_float16>(zero); | |||
| using namespace cutlass::library; | |||
| auto layoutA = param.transposeA ? LayoutTypeID::kColumnMajor | |||
| : LayoutTypeID::kRowMajor; | |||
| auto layoutB = param.transposeB ? LayoutTypeID::kColumnMajor | |||
| : LayoutTypeID::kRowMajor; | |||
| void *host_one, *host_zero; | |||
| NumericTypeID element_accumulator; | |||
| if (param.compute_mode == param::MatrixMul::ComputeMode::DEFAULT) { | |||
| element_accumulator = NumericTypeID::kF16; | |||
| host_one = &one_f16; | |||
| host_zero = &zero_f16; | |||
| } else { | |||
| megdnn_assert(param.compute_mode == | |||
| param::MatrixMul::ComputeMode::FLOAT32); | |||
| element_accumulator = NumericTypeID::kF32; | |||
| host_one = &one; | |||
| host_zero = &zero; | |||
| } | |||
| GemmKey key{NumericTypeID::kF16, | |||
| layoutA, | |||
| NumericTypeID::kF16, | |||
| layoutB, | |||
| NumericTypeID::kF16, | |||
| LayoutTypeID::kRowMajor, | |||
| element_accumulator, | |||
| m_algo_param.threadblock_m, | |||
| m_algo_param.threadblock_n, | |||
| m_algo_param.threadblock_k, | |||
| m_algo_param.warp_m, | |||
| m_algo_param.warp_n, | |||
| m_algo_param.warp_k, | |||
| m_algo_param.instruction_m, | |||
| m_algo_param.instruction_n, | |||
| m_algo_param.instruction_k, | |||
| 2, | |||
| alignment, | |||
| alignment, | |||
| SplitKMode::kNone}; | |||
| const auto& table = Singleton::get().operation_table; | |||
| megdnn_assert(table.gemm_operations.count(key) > 0, | |||
| "key not found in cutlass operation table"); | |||
| const auto& ops = table.gemm_operations.at(key); | |||
| megdnn_assert(ops.size() == 1, "exactly one kernel expected, got %zu", | |||
| ops.size()); | |||
| GemmArguments gemm_args{problem_size, | |||
| args.tensor_a.raw_ptr, | |||
| args.tensor_b.raw_ptr, | |||
| args.tensor_c.raw_ptr, | |||
| args.tensor_c.raw_ptr, | |||
| lda, | |||
| ldb, | |||
| ldc, | |||
| ldc, | |||
| 1, | |||
| host_one, | |||
| host_zero}; | |||
| cutlass_check(ops[0]->run(&gemm_args, workspace, stream)); | |||
| } | |||
| #endif | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,165 @@ | |||
| /** | |||
| * \file dnn/src/cuda/matrix_mul/cutlass_float16_tensorop_split_k.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/cuda/cutlass/singleton.h" | |||
| #include "src/cuda/handle.h" | |||
| #include "src/cuda/matrix_mul/algos.h" | |||
| #include "src/cuda/utils.h" | |||
| #if CUDA_VERSION >= 9020 | |||
| using namespace megdnn; | |||
| using namespace cuda; | |||
| bool MatrixMulForwardImpl::AlgoFloat16TensorOpSplitK::is_available( | |||
| const SizeArgs& args) const { | |||
| auto&& param = args.opr->param(); | |||
| int n = args.layout_c.shape[1], | |||
| k = args.layout_a.shape[param.transposeA ? 0 : 1]; | |||
| bool available = | |||
| args.opr->param().format == param::MatrixMul::Format::DEFAULT && | |||
| args.layout_a.dtype == dtype::Float16() && | |||
| args.layout_b.dtype == dtype::Float16() && | |||
| args.layout_c.dtype == dtype::Float16() && k > n; | |||
| auto&& device_prop = cuda::current_device_prop(); | |||
| int y_grid_limit = device_prop.maxGridSize[1]; | |||
| // limit y grid | |||
| available &= ((n + m_algo_param.threadblock_n - 1) / | |||
| m_algo_param.threadblock_n <= | |||
| y_grid_limit); | |||
| if (m_algo_param.instruction_m == 8 && m_algo_param.instruction_n == 8 && | |||
| m_algo_param.instruction_k == 4) { | |||
| available &= is_compute_capability_required(7, 0); | |||
| } else { | |||
| megdnn_assert(m_algo_param.instruction_m == 16 && | |||
| m_algo_param.instruction_n == 8 && | |||
| m_algo_param.instruction_k == 8); | |||
| available &= is_compute_capability_required(7, 5); | |||
| } | |||
| return available; | |||
| } | |||
| size_t MatrixMulForwardImpl::AlgoFloat16TensorOpSplitK::get_workspace_in_bytes( | |||
| const SizeArgs& args) const { | |||
| auto aligned = construct_aligned_layouts(args); | |||
| auto&& param = args.opr->param(); | |||
| int m = args.layout_c.shape[0], n = args.layout_c.shape[1], | |||
| k = args.layout_a.shape[param.transposeA ? 0 : 1]; | |||
| int split_k_slices = std::max(1, k / n); | |||
| if (!aligned.first) | |||
| return args.layout_c.dtype.size(m * n * split_k_slices); | |||
| const auto& layouts = aligned.second; | |||
| int align_m = layouts[2].shape[0], align_n = layouts[2].shape[1], | |||
| align_k = layouts[0].shape[1]; | |||
| split_k_slices = std::max(1, align_k / align_n); | |||
| size_t ws_size = | |||
| args.layout_c.dtype.size(align_m * align_n * split_k_slices); | |||
| for (auto&& ly : layouts) | |||
| ws_size += ly.span().dist_byte(); | |||
| return ws_size; | |||
| } | |||
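Worked example, assuming construct_aligned_layouts() reports that no re-alignment is needed (aligned.first == false): for an f16 problem with m = n = 128 and k = 1024, split_k_slices = max(1, 1024 / 128) = 8, so the workspace holds 128 * 128 * 8 = 131072 f16 partial-sum elements, i.e. 256 KiB.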
| void MatrixMulForwardImpl::AlgoFloat16TensorOpSplitK::do_exec( | |||
| const ExecArgs& args) const { | |||
| int64_t lda = args.tensor_a.layout.stride[0], | |||
| ldb = args.tensor_b.layout.stride[0], | |||
| ldc = args.tensor_c.layout.stride[0]; | |||
| int alignment = max_alignment(args); | |||
| int min_alignment = min_alignment_requirement(); | |||
| auto&& param = args.opr->param(); | |||
| int m = args.tensor_c.layout.shape[0], n = args.tensor_c.layout.shape[1], | |||
| k = args.tensor_a.layout.shape[param.transposeA ? 0 : 1]; | |||
| megdnn_assert(lda % alignment == 0 && ldb % alignment == 0 && | |||
| ldc % alignment == 0 && m % alignment == 0 && | |||
| n % alignment == 0 && k % alignment == 0 && | |||
| alignment >= min_alignment); | |||
| cutlass::gemm::GemmCoord problem_size{m, n, k}; | |||
| int split_k_slices = std::max(1, k / n); | |||
| auto&& stream = cuda_stream(args.opr->handle()); | |||
| int* workspace = reinterpret_cast<int*>(args.workspace.raw_ptr); | |||
| // \note the epilogue constants (i.e. one and zero) are passed to the | |||
| // cutlass kernel by pointer and interpreted as ElementCompute*, which is | |||
| // used to initialize kernel parameters. So the arguments' types on the | |||
| // host side must match the ElementCompute of the kernel instance; | |||
| // otherwise undefined kernel behavior occurs due to misinterpretation of | |||
| // these pointers. | |||
| float one = 1.f, zero = 0.f; | |||
| dt_float16 one_f16 = static_cast<dt_float16>(one), | |||
| zero_f16 = static_cast<dt_float16>(zero); | |||
| using namespace cutlass::library; | |||
| auto layoutA = param.transposeA ? LayoutTypeID::kColumnMajor | |||
| : LayoutTypeID::kRowMajor; | |||
| auto layoutB = param.transposeB ? LayoutTypeID::kColumnMajor | |||
| : LayoutTypeID::kRowMajor; | |||
| void *host_one, *host_zero; | |||
| NumericTypeID element_accumulator; | |||
| if (param.compute_mode == param::MatrixMul::ComputeMode::DEFAULT) { | |||
| element_accumulator = NumericTypeID::kF16; | |||
| host_one = &one_f16; | |||
| host_zero = &zero_f16; | |||
| } else { | |||
| megdnn_assert(param.compute_mode == | |||
| param::MatrixMul::ComputeMode::FLOAT32); | |||
| element_accumulator = NumericTypeID::kF32; | |||
| host_one = &one; | |||
| host_zero = &zero; | |||
| } | |||
| GemmKey key{NumericTypeID::kF16, | |||
| layoutA, | |||
| NumericTypeID::kF16, | |||
| layoutB, | |||
| NumericTypeID::kF16, | |||
| LayoutTypeID::kRowMajor, | |||
| element_accumulator, | |||
| m_algo_param.threadblock_m, | |||
| m_algo_param.threadblock_n, | |||
| m_algo_param.threadblock_k, | |||
| m_algo_param.warp_m, | |||
| m_algo_param.warp_n, | |||
| m_algo_param.warp_k, | |||
| m_algo_param.instruction_m, | |||
| m_algo_param.instruction_n, | |||
| m_algo_param.instruction_k, | |||
| 2, | |||
| alignment, | |||
| alignment, | |||
| SplitKMode::kParallel}; | |||
| const auto& table = Singleton::get().operation_table; | |||
| megdnn_assert(table.gemm_operations.count(key) > 0, | |||
| "key not found in cutlass operation table"); | |||
| const auto& ops = table.gemm_operations.at(key); | |||
| megdnn_assert(ops.size() == 1, "exactly one kernel expected, got %zu", | |||
| ops.size()); | |||
| GemmArguments gemm_args{problem_size, | |||
| args.tensor_a.raw_ptr, | |||
| args.tensor_b.raw_ptr, | |||
| args.tensor_c.raw_ptr, | |||
| args.tensor_c.raw_ptr, | |||
| lda, | |||
| ldb, | |||
| ldc, | |||
| ldc, | |||
| split_k_slices, | |||
| host_one, | |||
| host_zero}; | |||
| cutlass_check(ops[0]->run(&gemm_args, workspace, stream)); | |||
| } | |||
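| // a reduced illustration of the \note above, under the assumption that the | |||
| // kernel's ElementCompute is half: the kernel reads 16 bits through the | |||
| // void*, so handing it a float would reinterpret part of the f32 bit | |||
| // pattern as an f16 value. use_f16_compute is a hypothetical flag standing | |||
| // in for the compute_mode test above: | |||
| // | |||
| //     float one_f32 = 1.f;                   // only valid for f32 epilogues | |||
| //     dt_float16 one_f16 = dt_float16(1.f);  // required for f16 epilogues | |||
| //     void* host_one = use_f16_compute ? static_cast<void*>(&one_f16) | |||
| //                                      : static_cast<void*>(&one_f32); | |||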
| #endif | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -42,7 +42,8 @@ size_t MatrixMulForwardImpl::AlgoFloat32SIMT::get_workspace_in_bytes( | |||
| return 0_z; | |||
| } | |||
| void MatrixMulForwardImpl::AlgoFloat32SIMT::exec(const ExecArgs& args) const { | |||
| void MatrixMulForwardImpl::AlgoFloat32SIMT::do_exec( | |||
| const ExecArgs& args) const { | |||
| int64_t lda = args.tensor_a.layout.stride[0], | |||
| ldb = args.tensor_b.layout.stride[0], | |||
| ldc = args.tensor_c.layout.stride[0]; | |||
| @@ -65,12 +66,14 @@ void MatrixMulForwardImpl::AlgoFloat32SIMT::exec(const ExecArgs& args) const { | |||
| auto layoutB = param.transposeB ? LayoutTypeID::kColumnMajor | |||
| : LayoutTypeID::kRowMajor; | |||
| int alignment = min_alignment_requirement(); | |||
| GemmKey key{NumericTypeID::kF32, | |||
| layoutA, | |||
| NumericTypeID::kF32, | |||
| layoutB, | |||
| NumericTypeID::kF32, | |||
| LayoutTypeID::kRowMajor, | |||
| NumericTypeID::kF32, | |||
| m_algo_param.threadblock_m, | |||
| m_algo_param.threadblock_n, | |||
| m_algo_param.threadblock_k, | |||
| @@ -79,8 +82,10 @@ void MatrixMulForwardImpl::AlgoFloat32SIMT::exec(const ExecArgs& args) const { | |||
| m_algo_param.warp_k, | |||
| 1, | |||
| 1, | |||
| 1, | |||
| 2, | |||
| 1, | |||
| 2, | |||
| alignment, | |||
| alignment, | |||
| SplitKMode::kNone}; | |||
| const Operation* op = Singleton::get().operation_table.find_op(key); | |||
| @@ -22,7 +22,7 @@ using namespace cuda; | |||
| bool MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::is_available( | |||
| const SizeArgs& args) const { | |||
| auto&& param = args.opr->param(); | |||
| int m = args.layout_c.shape[0], n = args.layout_c.shape[1], | |||
| int n = args.layout_c.shape[1], | |||
| k = args.layout_a.shape[param.transposeA ? 0 : 1]; | |||
| bool available = | |||
| args.opr->param().format == param::MatrixMul::Format::DEFAULT && | |||
| @@ -32,8 +32,8 @@ bool MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::is_available( | |||
| auto&& device_prop = cuda::current_device_prop(); | |||
| int y_grid_limit = device_prop.maxGridSize[1]; | |||
| // limit y grid | |||
| available &= ((m + m_algo_param.threadblock_m - 1) / | |||
| m_algo_param.threadblock_m <= | |||
| available &= ((n + m_algo_param.threadblock_n - 1) / | |||
| m_algo_param.threadblock_n <= | |||
| y_grid_limit); | |||
| return available; | |||
| } | |||
| @@ -47,7 +47,7 @@ size_t MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::get_workspace_in_bytes( | |||
| return args.layout_c.dtype.size(m * n * split_k_slices); | |||
| } | |||
| void MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::exec( | |||
| void MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::do_exec( | |||
| const ExecArgs& args) const { | |||
| int64_t lda = args.tensor_a.layout.stride[0], | |||
| ldb = args.tensor_b.layout.stride[0], | |||
| @@ -72,12 +72,14 @@ void MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::exec( | |||
| auto layoutB = param.transposeB ? LayoutTypeID::kColumnMajor | |||
| : LayoutTypeID::kRowMajor; | |||
| int alignment = min_alignment_requirement(); | |||
| GemmKey key{NumericTypeID::kF32, | |||
| layoutA, | |||
| NumericTypeID::kF32, | |||
| layoutB, | |||
| NumericTypeID::kF32, | |||
| LayoutTypeID::kRowMajor, | |||
| NumericTypeID::kF32, | |||
| m_algo_param.threadblock_m, | |||
| m_algo_param.threadblock_n, | |||
| m_algo_param.threadblock_k, | |||
| @@ -87,7 +89,9 @@ void MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::exec( | |||
| 1, | |||
| 1, | |||
| 1, | |||
| 2, | |||
| 2, | |||
| alignment, | |||
| alignment, | |||
| SplitKMode::kParallel}; | |||
| Operation const* op = Singleton::get().operation_table.find_op(key); | |||
| @@ -0,0 +1,136 @@ | |||
| /** | |||
| * \file dnn/src/cuda/matrix_mul/cutlass_matrix_mul_base.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/cuda/handle.h" | |||
| #include "src/cuda/matrix_mul/algos.h" | |||
| #include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh" | |||
| #include "src/cuda/utils.h" | |||
| #if CUDA_VERSION >= 9020 | |||
| using namespace megdnn; | |||
| using namespace cuda; | |||
| std::string | |||
| MatrixMulForwardImpl::AlgoCutlassMatrixMulBase::AlgoParam::to_string() const { | |||
| return ssprintf("%dX%dX%d_%dX%dX%d", threadblock_m, threadblock_n, | |||
| threadblock_k, warp_m, warp_n, warp_k); | |||
| } | |||
| std::pair<bool, TensorLayoutArray> | |||
| MatrixMulForwardImpl::AlgoCutlassMatrixMulBase::construct_aligned_layouts( | |||
| const SizeArgs& args) const { | |||
| int alignment = max_alignment(args); | |||
| int min_alignment = min_alignment_requirement(); | |||
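| // the boolean in the returned pair indicates whether padded layouts had | |||
| // to be constructed, i.e. the operands violate the minimum tensor-op | |||
| // alignment and must be relayouted before the gemm runs | |||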
| bool aligned = alignment >= min_alignment; | |||
| if (aligned) | |||
| return std::make_pair(!aligned, TensorLayoutArray{{}}); | |||
| auto&& param = args.opr->param(); | |||
| int m = args.layout_c.shape[0], n = args.layout_c.shape[1], | |||
| k = args.layout_a.shape[param.transposeA ? 0 : 1]; | |||
| size_t align_m = get_aligned_power2(m, min_alignment); | |||
| size_t align_n = get_aligned_power2(n, min_alignment); | |||
| size_t align_k = get_aligned_power2(k, min_alignment); | |||
| TensorLayoutArray layouts; | |||
| layouts.emplace_back(TensorLayout{{align_m, align_k}, args.layout_a.dtype}); | |||
| layouts.emplace_back(TensorLayout{{align_k, align_n}, args.layout_b.dtype}); | |||
| layouts.emplace_back(TensorLayout{{align_m, align_n}, args.layout_c.dtype}); | |||
| return std::make_pair(!aligned, std::move(layouts)); | |||
| } | |||
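| // assuming get_aligned_power2 rounds its argument up to the next multiple | |||
| // of min_alignment, an f16 problem with m = 100, n = 60, k = 50 and a | |||
| // minimum alignment of 8 elements would be padded to 104 x 64 x 56 before | |||
| // the kernel runs on the zero-filled buffers | |||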
| void MatrixMulForwardImpl::AlgoCutlassMatrixMulBase::exec( | |||
| const ExecArgs& args) const { | |||
| auto aligned = construct_aligned_layouts(args); | |||
| if (!aligned.first) | |||
| return do_exec(args); | |||
| const auto& layouts = aligned.second; | |||
| auto tensor_a = args.tensor_a; | |||
| auto tensor_b = args.tensor_b; | |||
| auto workspace = args.workspace; | |||
| size_t copy_size = 0; | |||
| for (const auto& ly : layouts) | |||
| copy_size += ly.span().dist_byte(); | |||
| auto&& param = args.opr->param(); | |||
| auto&& stream = cuda_stream(args.opr->handle()); | |||
| cuda_check(cudaMemsetAsync(workspace.raw_ptr, 0, copy_size, stream)); | |||
| auto&& relayout = args.opr->handle()->create_operator<RelayoutForward>(); | |||
| auto copy_stride = [](const TensorLayout& src, TensorLayout& dst, | |||
| bool trans) { | |||
| dst.stride[0] = src.stride[0], dst.stride[1] = src.stride[1]; | |||
| if (trans) | |||
| std::swap(dst.stride[0], dst.stride[1]); | |||
| }; | |||
| copy_stride(layouts[0], tensor_a.layout, param.transposeA); | |||
| tensor_a.raw_ptr = workspace.raw_ptr; | |||
| relayout->exec(args.tensor_a, tensor_a); | |||
| workspace.raw_ptr += layouts[0].span().dist_byte(); | |||
| workspace.size -= layouts[0].span().dist_byte(); | |||
| copy_stride(layouts[1], tensor_b.layout, param.transposeB); | |||
| tensor_b.raw_ptr = workspace.raw_ptr; | |||
| relayout->exec(args.tensor_b, tensor_b); | |||
| workspace.raw_ptr += layouts[1].span().dist_byte(); | |||
| workspace.size -= layouts[1].span().dist_byte(); | |||
| decltype(tensor_a) tensor_c{workspace.raw_ptr, layouts[2]}; | |||
| workspace.raw_ptr += layouts[2].span().dist_byte(); | |||
| workspace.size -= layouts[2].span().dist_byte(); | |||
| auto&& matmul = args.opr->handle()->create_operator<MatrixMulForward>(); | |||
| matmul->param().transposeA = false; | |||
| matmul->param().transposeB = false; | |||
| matmul->param().compute_mode = args.opr->param().compute_mode; | |||
| tensor_a.layout = layouts[0]; | |||
| tensor_b.layout = layouts[1]; | |||
| ExecArgs args_{static_cast<MatrixMulForwardImpl*>(matmul.get()), tensor_a, | |||
| tensor_b, tensor_c, workspace}; | |||
| do_exec(args_); | |||
| tensor_c.layout.TensorShape::operator=(args.layout_c); | |||
| relayout->exec(tensor_c, args.tensor_c); | |||
| } | |||
| int MatrixMulForwardImpl::AlgoCutlassMatrixMulBase::max_alignment( | |||
| const SizeArgs& args) const { | |||
| auto&& dtype_a = args.layout_a.dtype; | |||
| auto&& dtype_b = args.layout_b.dtype; | |||
| auto&& dtype_c = args.layout_c.dtype; | |||
| auto get_alignment = [](const DType& dt, int len) { | |||
| int size_bits = dt.size(1) * 8; | |||
| int align = 128; | |||
| while (align > 1) { | |||
| if ((len * size_bits) % align == 0) | |||
| break; | |||
| align = align / 2; | |||
| } | |||
| return align / size_bits; | |||
| }; | |||
| int lda = args.layout_a.stride[0], ldb = args.layout_b.stride[0], | |||
| ldc = args.layout_c.stride[0]; | |||
| auto&& param = args.opr->param(); | |||
| int m = args.layout_c.shape[0], n = args.layout_c.shape[1], | |||
| k = args.layout_a.shape[param.transposeA ? 0 : 1]; | |||
| int max_align = get_alignment(dtype_a, lda); | |||
| max_align = std::min(get_alignment(dtype_a, m), max_align); | |||
| max_align = std::min(get_alignment(dtype_a, n), max_align); | |||
| max_align = std::min(get_alignment(dtype_a, k), max_align); | |||
| max_align = std::min(get_alignment(dtype_a, lda), max_align); | |||
| max_align = std::min(get_alignment(dtype_b, ldb), max_align); | |||
| max_align = std::min(get_alignment(dtype_c, ldc), max_align); | |||
| return max_align; | |||
| } | |||
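| // a worked example of the get_alignment lambda: for f16 (16-bit elements) | |||
| // and len = 100, 100 * 16 = 1600 bits is divisible by 64 but not by 128, | |||
| // so the result is 64 / 16 = 4 elements; a fully 128-bit-aligned f16 | |||
| // operand yields the maximum of 128 / 16 = 8 elements | |||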
| #endif | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -42,9 +42,12 @@ public: | |||
| class AlgoBFloat16; | |||
| #endif | |||
| #if CUDA_VERSION >= 9020 | |||
| class AlgoCutlassMatrixMulBase; | |||
| class AlgoFloat32SIMT; | |||
| class AlgoFloat32SIMTSplitK; | |||
| class AlgoFloat32SIMTGemvBatchedStrided; | |||
| class AlgoFloat16TensorOp; | |||
| class AlgoFloat16TensorOpSplitK; | |||
| #endif | |||
| class AlgoPack; | |||
| @@ -184,7 +184,8 @@ void matrix_mul::check_matrix_mul(DType A_dtype, DType B_dtype, DType C_dtype, | |||
| const ExecutionPolicyAlgoName& algo, | |||
| param::MatrixMul::Format format, size_t nbase, | |||
| float eps, std::vector<TestArg>&& user_args, | |||
| bool force_deduce_dst) { | |||
| bool force_deduce_dst, | |||
| param::MatrixMul::ComputeMode compute_mode) { | |||
| megdnn_assert(A_dtype.enumv() == B_dtype.enumv()); | |||
| Checker<Opr> checker(handle); | |||
| checker.set_force_deduce_dst(force_deduce_dst); | |||
| @@ -261,6 +262,7 @@ void matrix_mul::check_matrix_mul(DType A_dtype, DType B_dtype, DType C_dtype, | |||
| Param param; | |||
| param.transposeA = arg.mask & 0x1; | |||
| param.transposeB = arg.mask & 0x2; | |||
| param.compute_mode = compute_mode; | |||
| param.format = format; | |||
| checker.set_dtype(0, A_dtype) | |||
| .set_dtype(1, B_dtype) | |||
| @@ -69,7 +69,9 @@ void check_matrix_mul( | |||
| const ExecutionPolicyAlgoName& algo = {"", {}}, | |||
| param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT, | |||
| size_t nbase = 8, float eps = 1e-3, std::vector<TestArg>&& args = {}, | |||
| bool force_deduce_dst = true); | |||
| bool force_deduce_dst = true, | |||
| param::MatrixMul::ComputeMode compute_mode = | |||
| param::MatrixMul::ComputeMode::DEFAULT); | |||
| void check_matrix_mul( | |||
| DType A_dtype, DType B_dtype, DType C_dtype, Handle* handle, | |||
| @@ -21,6 +21,7 @@ | |||
| #include "test/cuda/fixture.h" | |||
| #include "test/cuda/utils.h" | |||
| #define MEGDNN_WITH_BENCHMARK 1 | |||
| #if CUDA_VERSION >= 9020 | |||
| namespace megdnn { | |||
| namespace test { | |||
| @@ -215,6 +216,14 @@ std::vector<BenchArgs> get_feat_model_args() { | |||
| return args; | |||
| } | |||
| std::vector<BenchArgs> get_f16_feat_model_args() { | |||
| std::vector<BenchArgs> args; | |||
| args.emplace_back(BenchArgs{128, 9216, 9216}); | |||
| args.emplace_back(BenchArgs{128, 6400, 6400}); | |||
| args.emplace_back(BenchArgs{128, 5184, 5184}); | |||
| return args; | |||
| } | |||
| void benchmark_matrix_mul( | |||
| Handle* handle, const std::vector<BenchArgs>& args, DType A_dtype, | |||
| DType B_dtype, DType C_dtype, const char* algo = nullptr, | |||
| @@ -364,6 +373,82 @@ MEGDNN_FOREACH_CUTLASS_KERNEL(cb) | |||
| #undef cb | |||
| #undef MEGDNN_FOREACH_CUTLASS_KERNEL | |||
| #define MEGDNN_FOREACH_CUTLASS_KERNEL(cb) \ | |||
| cb(1, 256, 128, 32, 64, 64, 32, 8, 8, 4); \ | |||
| cb(2, 128, 256, 32, 64, 64, 32, 8, 8, 4); \ | |||
| cb(3, 128, 128, 32, 64, 64, 32, 8, 8, 4); | |||
| #define cb(name, tbm, tbn, tbk, wm, wn, wk, im, in, ik) \ | |||
| TEST_F(CUDA, CUTLASS_F16_884_GEMM_##name) { \ | |||
| require_compute_capability(7, 0); \ | |||
| matrix_mul::check_matrix_mul<MatrixMulForward>( \ | |||
| dtype::Float16(), dtype::Float16(), dtype::Float16(), \ | |||
| handle_cuda(), \ | |||
| "CUTLASS_FLOAT16_TENSOR_OP_h" #im #in #ik "_" #tbm "X" #tbn \ | |||
| "X" #tbk "_" #wm "X" #wn "X" #wk, \ | |||
| param::MatrixMul::Format::DEFAULT, 8, 1e-2, \ | |||
| matrix_mul::get_matmul_args()); \ | |||
| } | |||
| MEGDNN_FOREACH_CUTLASS_KERNEL(cb) | |||
| #undef cb | |||
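| // each cb(...) above instantiates one test case; e.g. | |||
| // cb(1, 256, 128, 32, 64, 64, 32, 8, 8, 4) checks the kernel named | |||
| // CUTLASS_FLOAT16_TENSOR_OP_h884_256X128X32_64X64X32 | |||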
| #define cb(name, tbm, tbn, tbk, wm, wn, wk, im, in, ik) \ | |||
| TEST_F(CUDA, CUTLASS_F16_884_GEMM_SPLIT_K_##name) { \ | |||
| require_compute_capability(7, 0); \ | |||
| matrix_mul::check_matrix_mul<MatrixMulForward>( \ | |||
| dtype::Float16(), dtype::Float16(), dtype::Float16(), \ | |||
| handle_cuda(), \ | |||
| "CUTLASS_FLOAT16_TENSOR_OP_SPLIT_K_h" #im #in #ik "_" #tbm \ | |||
| "X" #tbn "X" #tbk "_" #wm "X" #wn "X" #wk, \ | |||
| param::MatrixMul::Format::DEFAULT, 8, 1e-3, \ | |||
| matrix_mul::get_matmul_args_split_k(), true, \ | |||
| param::MatrixMul::ComputeMode::FLOAT32); \ | |||
| } | |||
| MEGDNN_FOREACH_CUTLASS_KERNEL(cb) | |||
| #undef cb | |||
| #undef MEGDNN_FOREACH_CUTLASS_KERNEL | |||
| #define MEGDNN_FOREACH_CUTLASS_KERNEL(cb) \ | |||
| cb(1, 256, 128, 32, 64, 64, 32, 16, 8, 8); \ | |||
| cb(2, 128, 256, 32, 64, 64, 32, 16, 8, 8); \ | |||
| cb(3, 128, 128, 32, 64, 64, 32, 16, 8, 8); | |||
| #define cb(name, tbm, tbn, tbk, wm, wn, wk, im, in, ik) \ | |||
| TEST_F(CUDA, CUTLASS_F16_1688_GEMM_##name) { \ | |||
| require_compute_capability(7, 5); \ | |||
| matrix_mul::check_matrix_mul<MatrixMulForward>( \ | |||
| dtype::Float16(), dtype::Float16(), dtype::Float16(), \ | |||
| handle_cuda(), \ | |||
| "CUTLASS_FLOAT16_TENSOR_OP_h" #im #in #ik "_" #tbm "X" #tbn \ | |||
| "X" #tbk "_" #wm "X" #wn "X" #wk, \ | |||
| param::MatrixMul::Format::DEFAULT, 8, 1e-2, \ | |||
| matrix_mul::get_matmul_args(), true, \ | |||
| param::MatrixMul::ComputeMode::FLOAT32); \ | |||
| } | |||
| MEGDNN_FOREACH_CUTLASS_KERNEL(cb) | |||
| #undef cb | |||
| #define cb(name, tbm, tbn, tbk, wm, wn, wk, im, in, ik) \ | |||
| TEST_F(CUDA, CUTLASS_F16_1688_GEMM_SPLIT_K_##name) { \ | |||
| require_compute_capability(7, 5); \ | |||
| matrix_mul::check_matrix_mul<MatrixMulForward>( \ | |||
| dtype::Float16(), dtype::Float16(), dtype::Float16(), \ | |||
| handle_cuda(), \ | |||
| "CUTLASS_FLOAT16_TENSOR_OP_SPLIT_K_h" #im #in #ik "_" #tbm \ | |||
| "X" #tbn "X" #tbk "_" #wm "X" #wn "X" #wk, \ | |||
| param::MatrixMul::Format::DEFAULT, 8, 1e-3, \ | |||
| matrix_mul::get_matmul_args_split_k()); \ | |||
| } | |||
| MEGDNN_FOREACH_CUTLASS_KERNEL(cb) | |||
| #undef cb | |||
| #undef MEGDNN_FOREACH_CUTLASS_KERNEL | |||
| #if MEGDNN_WITH_BENCHMARK | |||
| TEST_F(CUDA, BENCHMARK_CUTLASS_MATMUL) { | |||
| benchmark_matrix_mul(handle_cuda(), get_square_matmul_args(), | |||
| @@ -376,6 +461,12 @@ TEST_F(CUDA, BENCHMARK_CUTLASS_MATMUL_FEAT) { | |||
| dtype::Float32(), dtype::Float32(), | |||
| "CUTLASS_FLOAT32_SIMT"); | |||
| } | |||
| TEST_F(CUDA, BENCHMARK_CUTLASS_F16_MATMUL_FEAT) { | |||
| benchmark_matrix_mul(handle_cuda(), get_f16_feat_model_args(), | |||
| dtype::Float16(), dtype::Float16(), dtype::Float16(), | |||
| "CUTLASS_FLOAT16_TENSOR_OP"); | |||
| } | |||
| #endif | |||
| } // namespace test | |||
| } // namespace megdnn | |||