GitOrigin-RevId: 025c591f75

@@ -5,6 +5,8 @@ genrule(
     outs = cutlass_gen_list,
     cmd = """GEN=$(location //brain/megbrain/dnn/scripts/cutlass_generator:generator.py)
     CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations gemm --type simt $(@D)
+    CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations gemm --type tensorop884 $(@D)
+    CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations gemm --type tensorop1688 $(@D)
     CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations gemv --type simt $(@D)
     CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations deconv --type simt $(@D)
     CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations conv2d --type simt $(@D)

@@ -252,7 +252,8 @@ def GeneratesGemm(tile, data_type, layout_a, layout_b, layout_c, min_cc, align_a
     if tile.math_instruction.element_accumulator == DataType.s32:
         epilogues = [EpilogueFunctor.LinearCombinationClamp]
     else:
-        assert tile.math_instruction.element_accumulator == DataType.f32
+        assert tile.math_instruction.element_accumulator == DataType.f32 or \
+               tile.math_instruction.element_accumulator == DataType.f16
         epilogues = [EpilogueFunctor.LinearCombination]
     for epilogue in epilogues:

@@ -799,7 +800,22 @@ class EmitGemmSplitKParallelInstance:
       ${epilogue_vector_length},
       ${element_accumulator},
       ${element_epilogue}
-    >
+    >,
+    cutlass::epilogue::thread::Convert<
+      ${element_accumulator},
+      ${epilogue_vector_length},
+      ${element_accumulator}
+    >,
+    cutlass::reduction::thread::ReduceAdd<
+      ${element_accumulator},
+      ${element_accumulator},
+      ${epilogue_vector_length}
+    >,
+    cutlass::gemm::threadblock::GemmSplitKHorizontalThreadblockSwizzle,
+    ${stages},
+    ${align_a},
+    ${align_b},
+    ${math_operation}
   >;
 """

   def emit(self, operation):

@@ -831,7 +847,10 @@ class EmitGemmSplitKParallelInstance:
             'epilogue_vector_length': str(epilogue_vector_length),
             'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
             'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
-            'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
+            'stages': str(operation.tile_description.stages),
+            'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation],
+            'align_a': str(operation.A.alignment),
+            'align_b': str(operation.B.alignment),
         }

         return SubstituteTemplate(self.template, values)
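
For reference, the parameter order in the template above mirrors CUTLASS's `cutlass::gemm::device::GemmSplitKParallel`. An illustrative, hand-written instantiation for one of the new f16 kernels (values chosen to match the `h1688gemm_..._128x128_32x2_nn_align8` naming; this is a sketch, not actual generator output):

    using Operation_cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_nn_align8 =
            cutlass::gemm::device::GemmSplitKParallel<
                    cutlass::half_t, cutlass::layout::ColumnMajor,  // A
                    cutlass::half_t, cutlass::layout::ColumnMajor,  // B
                    cutlass::half_t, cutlass::layout::RowMajor,     // C
                    cutlass::half_t,                                // accumulator
                    cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
                    cutlass::gemm::GemmShape<128, 128, 32>,         // threadblock
                    cutlass::gemm::GemmShape<64, 64, 32>,           // warp
                    cutlass::gemm::GemmShape<16, 8, 8>,             // mma instruction
                    cutlass::epilogue::thread::LinearCombination<
                            cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,
                    cutlass::epilogue::thread::Convert<
                            cutlass::half_t, 8, cutlass::half_t>,
                    cutlass::reduction::thread::ReduceAdd<
                            cutlass::half_t, cutlass::half_t, 8>,
                    cutlass::gemm::threadblock::GemmSplitKHorizontalThreadblockSwizzle,
                    2,                                              // stages
                    8, 8,                                           // align_a, align_b
                    cutlass::arch::OpMultiplyAdd>;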

@@ -32,6 +32,8 @@ if __name__ == "__main__":
         f.write("# Generated by dnn/scripts/cutlass_generator/gen_list.py\n\n")
         f.write("cutlass_gen_list = [\n")
         write_op_list(f, "gemm", "simt")
+        write_op_list(f, "gemm", "tensorop1688")
+        write_op_list(f, "gemm", "tensorop884")
         write_op_list(f, "gemv", "simt")
         write_op_list(f, "deconv", "simt")
         write_op_list(f, "conv2d", "simt")

@@ -596,6 +596,131 @@ def GenerateGemv_Simt(args):
                                 align_b))
     return operations

+#
+def GeneratesGemm_TensorOp_1688(args):
+    layouts = [
+        (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), # nn
+        (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor),    # nt
+        (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),    # tn
+        (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor),       # tt
+    ]
+
+    math_instructions = [
+        MathInstruction( \
+            [16, 8, 8], \
+            DataType.f16, DataType.f16, DataType.f32, \
+            OpcodeClass.TensorOp, \
+            MathOperation.multiply_add),
+        MathInstruction( \
+            [16, 8, 8], \
+            DataType.f16, DataType.f16, DataType.f16, \
+            OpcodeClass.TensorOp, \
+            MathOperation.multiply_add),
+    ]
+
+    min_cc = 75
+    max_cc = 1024
+
+    alignment_constraints = [8, 4, 2,
+                             #1
+                            ]
+
+    operations = []
+    for math_inst in math_instructions:
+        for layout in layouts:
+            for align in alignment_constraints:
+                tile_descriptions = [
+                    TileDescription([256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc),
+                    TileDescription([128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc),
+                    TileDescription([128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
+                    ## comment some configuration to reduce compilation time and binary size
+                    # TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
+                    # TileDescription([128,  64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
+                    # TileDescription([ 64,  64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
+                ]
+                data_type = [
+                    math_inst.element_a,
+                    math_inst.element_b,
+                    math_inst.element_a,
+                    math_inst.element_accumulator,
+                ]
+                for tile in tile_descriptions:
+                    operations += GeneratesGemm(tile, \
+                                                data_type, \
+                                                layout[0], \
+                                                layout[1], \
+                                                layout[2], \
+                                                min_cc, \
+                                                align * 16, \
+                                                align * 16, \
+                                                align * 16)
+    return operations
+
+#
+def GeneratesGemm_TensorOp_884(args):
+    layouts = [
+        (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), # nn
+        (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor),    # nt
+        (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),    # tn
+        (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor),       # tt
+    ]
+
+    math_instructions = [
+        MathInstruction( \
+            [8, 8, 4], \
+            DataType.f16, DataType.f16, DataType.f32, \
+            OpcodeClass.TensorOp, \
+            MathOperation.multiply_add),
+        MathInstruction( \
+            [8, 8, 4], \
+            DataType.f16, DataType.f16, DataType.f16, \
+            OpcodeClass.TensorOp, \
+            MathOperation.multiply_add),
+    ]
+
+    min_cc = 70
+    max_cc = 75
+
+    alignment_constraints = [8, 4, 2,
+                             # 1
+                            ]
+
+    operations = []
+    for math_inst in math_instructions:
+        for layout in layouts:
+            for align in alignment_constraints:
+                tile_descriptions = [
+                    TileDescription([256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc),
+                    TileDescription([128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc),
+                    TileDescription([128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
+                    ## comment some configuration to reduce compilation time and binary size
+                    # TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
+                    # TileDescription([128,  64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
+                    # TileDescription([ 64,  64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
+                ]
+                data_type = [
+                    math_inst.element_a,
+                    math_inst.element_b,
+                    math_inst.element_a,
+                    math_inst.element_accumulator,
+                ]
+                for tile in tile_descriptions:
+                    operations += GeneratesGemm(tile, \
+                                                data_type, \
+                                                layout[0], \
+                                                layout[1], \
+                                                layout[2], \
+                                                min_cc, \
+                                                align * 16, \
+                                                align * 16, \
+                                                align * 16)
+    return operations
+
 #
 def GenerateConv2dOperations(args):
     if args.type == "simt":

@@ -613,9 +738,14 @@ def GenerateDeconvOperations(args):
     return GenerateDeconv_Simt(args)

 def GenerateGemmOperations(args):
-    assert args.type == "simt", "operation gemm only support" \
-        "simt. (got:{})".format(args.type)
-    return GenerateGemm_Simt(args)
+    if args.type == "tensorop884":
+        return GeneratesGemm_TensorOp_884(args)
+    elif args.type == "tensorop1688":
+        return GeneratesGemm_TensorOp_1688(args)
+    else:
+        assert args.type == "simt", "operation gemm only supports " \
+            "simt. (got:{})".format(args.type)
+        return GenerateGemm_Simt(args)

 def GenerateGemvOperations(args):
     assert args.type == "simt", "operation gemv only supports " \

@@ -631,7 +761,7 @@ if __name__ == "__main__":
     parser.add_argument("--operations", type=str, choices=['gemm', 'gemv', 'conv2d', 'deconv'],
                         required=True, help="Specifies the operation to generate (gemm, gemv, conv2d, deconv)")
     parser.add_argument("output", type=str, help="output directory for CUTLASS kernel files")
-    parser.add_argument("--type", type=str, choices=['simt', 'tensorop8816', 'tensorop8832'],
+    parser.add_argument("--type", type=str, choices=['simt', 'tensorop8816', 'tensorop8832', 'tensorop884', 'tensorop1688'],
                         default='simt', help="kernel type of CUTLASS kernel generator")

     gemv_wrapper_path = "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl"

@@ -138,6 +138,296 @@ cutlass_gen_list = [
     "cutlass_simt_sgemm_256x64_8x2_tt_align1.cu",
     "cutlass_simt_sgemm_split_k_parallel_256x64_8x2_tt_align1.cu",
     "all_gemm_simt_operations.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_nn_align8.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_nn_align8.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_nn_align8.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_nn_align8.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_nn_align8.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_nn_align8.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_nn_align4.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_nn_align4.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_nn_align4.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_nn_align4.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_nn_align4.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_nn_align4.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_nn_align2.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_nn_align2.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_nn_align2.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_nn_align2.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_nn_align2.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_nn_align2.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_nt_align8.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_nt_align8.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_nt_align8.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_nt_align8.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_nt_align8.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_nt_align8.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_nt_align4.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_nt_align4.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_nt_align4.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_nt_align4.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_nt_align4.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_nt_align4.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_nt_align2.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_nt_align2.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_nt_align2.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_nt_align2.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_nt_align2.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_nt_align2.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_tn_align8.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_tn_align8.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_tn_align8.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_tn_align8.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_tn_align8.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_tn_align8.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_tn_align4.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_tn_align4.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_tn_align4.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_tn_align4.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_tn_align4.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_tn_align4.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_tn_align2.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_tn_align2.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_tn_align2.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_tn_align2.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_tn_align2.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_tn_align2.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_tt_align8.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_tt_align8.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_tt_align8.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_tt_align8.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_tt_align8.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_tt_align8.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_tt_align4.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_tt_align4.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_tt_align4.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_tt_align4.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_tt_align4.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_tt_align4.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_tt_align2.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_256x128_32x2_tt_align2.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_128x256_32x2_tt_align2.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x256_32x2_tt_align2.cu",
+    "cutlass_tensorop_f16_s1688gemm_f16_128x128_32x2_tt_align2.cu",
+    "cutlass_tensorop_f16_s1688gemm_split_k_parallel_f16_128x128_32x2_tt_align2.cu",
+    "cutlass_tensorop_h1688gemm_256x128_32x2_nn_align8.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_nn_align8.cu",
+    "cutlass_tensorop_h1688gemm_128x256_32x2_nn_align8.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_nn_align8.cu",
+    "cutlass_tensorop_h1688gemm_128x128_32x2_nn_align8.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_nn_align8.cu",
+    "cutlass_tensorop_h1688gemm_256x128_32x2_nn_align4.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_nn_align4.cu",
+    "cutlass_tensorop_h1688gemm_128x256_32x2_nn_align4.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_nn_align4.cu",
+    "cutlass_tensorop_h1688gemm_128x128_32x2_nn_align4.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_nn_align4.cu",
+    "cutlass_tensorop_h1688gemm_256x128_32x2_nn_align2.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_nn_align2.cu",
+    "cutlass_tensorop_h1688gemm_128x256_32x2_nn_align2.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_nn_align2.cu",
+    "cutlass_tensorop_h1688gemm_128x128_32x2_nn_align2.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_nn_align2.cu",
+    "cutlass_tensorop_h1688gemm_256x128_32x2_nt_align8.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_nt_align8.cu",
+    "cutlass_tensorop_h1688gemm_128x256_32x2_nt_align8.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_nt_align8.cu",
+    "cutlass_tensorop_h1688gemm_128x128_32x2_nt_align8.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_nt_align8.cu",
+    "cutlass_tensorop_h1688gemm_256x128_32x2_nt_align4.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_nt_align4.cu",
+    "cutlass_tensorop_h1688gemm_128x256_32x2_nt_align4.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_nt_align4.cu",
+    "cutlass_tensorop_h1688gemm_128x128_32x2_nt_align4.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_nt_align4.cu",
+    "cutlass_tensorop_h1688gemm_256x128_32x2_nt_align2.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_nt_align2.cu",
+    "cutlass_tensorop_h1688gemm_128x256_32x2_nt_align2.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_nt_align2.cu",
+    "cutlass_tensorop_h1688gemm_128x128_32x2_nt_align2.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_nt_align2.cu",
+    "cutlass_tensorop_h1688gemm_256x128_32x2_tn_align8.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_tn_align8.cu",
+    "cutlass_tensorop_h1688gemm_128x256_32x2_tn_align8.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_tn_align8.cu",
+    "cutlass_tensorop_h1688gemm_128x128_32x2_tn_align8.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_tn_align8.cu",
+    "cutlass_tensorop_h1688gemm_256x128_32x2_tn_align4.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_tn_align4.cu",
+    "cutlass_tensorop_h1688gemm_128x256_32x2_tn_align4.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_tn_align4.cu",
+    "cutlass_tensorop_h1688gemm_128x128_32x2_tn_align4.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_tn_align4.cu",
+    "cutlass_tensorop_h1688gemm_256x128_32x2_tn_align2.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_tn_align2.cu",
+    "cutlass_tensorop_h1688gemm_128x256_32x2_tn_align2.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_tn_align2.cu",
+    "cutlass_tensorop_h1688gemm_128x128_32x2_tn_align2.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_tn_align2.cu",
+    "cutlass_tensorop_h1688gemm_256x128_32x2_tt_align8.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_tt_align8.cu",
+    "cutlass_tensorop_h1688gemm_128x256_32x2_tt_align8.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_tt_align8.cu",
+    "cutlass_tensorop_h1688gemm_128x128_32x2_tt_align8.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_tt_align8.cu",
+    "cutlass_tensorop_h1688gemm_256x128_32x2_tt_align4.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_tt_align4.cu",
+    "cutlass_tensorop_h1688gemm_128x256_32x2_tt_align4.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_tt_align4.cu",
+    "cutlass_tensorop_h1688gemm_128x128_32x2_tt_align4.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_tt_align4.cu",
+    "cutlass_tensorop_h1688gemm_256x128_32x2_tt_align2.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_256x128_32x2_tt_align2.cu",
+    "cutlass_tensorop_h1688gemm_128x256_32x2_tt_align2.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_128x256_32x2_tt_align2.cu",
+    "cutlass_tensorop_h1688gemm_128x128_32x2_tt_align2.cu",
+    "cutlass_tensorop_h1688gemm_split_k_parallel_128x128_32x2_tt_align2.cu",
+    "all_gemm_tensorop1688_operations.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_nn_align8.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_nn_align8.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_nn_align8.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_nn_align8.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_nn_align8.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_nn_align8.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_nn_align4.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_nn_align4.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_nn_align4.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_nn_align4.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_nn_align4.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_nn_align4.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_nn_align2.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_nn_align2.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_nn_align2.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_nn_align2.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_nn_align2.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_nn_align2.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_nt_align8.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_nt_align8.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_nt_align8.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_nt_align8.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_nt_align8.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_nt_align8.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_nt_align4.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_nt_align4.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_nt_align4.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_nt_align4.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_nt_align4.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_nt_align4.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_nt_align2.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_nt_align2.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_nt_align2.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_nt_align2.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_nt_align2.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_nt_align2.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_tn_align8.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_tn_align8.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_tn_align8.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_tn_align8.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_tn_align8.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_tn_align8.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_tn_align4.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_tn_align4.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_tn_align4.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_tn_align4.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_tn_align4.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_tn_align4.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_tn_align2.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_tn_align2.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_tn_align2.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_tn_align2.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_tn_align2.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_tn_align2.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_tt_align8.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_tt_align8.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_tt_align8.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_tt_align8.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_tt_align8.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_tt_align8.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_tt_align4.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_tt_align4.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_tt_align4.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_tt_align4.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_tt_align4.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_tt_align4.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_256x128_32x2_tt_align2.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_256x128_32x2_tt_align2.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_128x256_32x2_tt_align2.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x256_32x2_tt_align2.cu",
+    "cutlass_tensorop_f16_s884gemm_f16_128x128_32x2_tt_align2.cu",
+    "cutlass_tensorop_f16_s884gemm_split_k_parallel_f16_128x128_32x2_tt_align2.cu",
+    "cutlass_tensorop_h884gemm_256x128_32x2_nn_align8.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_nn_align8.cu",
+    "cutlass_tensorop_h884gemm_128x256_32x2_nn_align8.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_nn_align8.cu",
+    "cutlass_tensorop_h884gemm_128x128_32x2_nn_align8.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_nn_align8.cu",
+    "cutlass_tensorop_h884gemm_256x128_32x2_nn_align4.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_nn_align4.cu",
+    "cutlass_tensorop_h884gemm_128x256_32x2_nn_align4.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_nn_align4.cu",
+    "cutlass_tensorop_h884gemm_128x128_32x2_nn_align4.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_nn_align4.cu",
+    "cutlass_tensorop_h884gemm_256x128_32x2_nn_align2.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_nn_align2.cu",
+    "cutlass_tensorop_h884gemm_128x256_32x2_nn_align2.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_nn_align2.cu",
+    "cutlass_tensorop_h884gemm_128x128_32x2_nn_align2.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_nn_align2.cu",
+    "cutlass_tensorop_h884gemm_256x128_32x2_nt_align8.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_nt_align8.cu",
+    "cutlass_tensorop_h884gemm_128x256_32x2_nt_align8.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_nt_align8.cu",
+    "cutlass_tensorop_h884gemm_128x128_32x2_nt_align8.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_nt_align8.cu",
+    "cutlass_tensorop_h884gemm_256x128_32x2_nt_align4.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_nt_align4.cu",
+    "cutlass_tensorop_h884gemm_128x256_32x2_nt_align4.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_nt_align4.cu",
+    "cutlass_tensorop_h884gemm_128x128_32x2_nt_align4.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_nt_align4.cu",
+    "cutlass_tensorop_h884gemm_256x128_32x2_nt_align2.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_nt_align2.cu",
+    "cutlass_tensorop_h884gemm_128x256_32x2_nt_align2.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_nt_align2.cu",
+    "cutlass_tensorop_h884gemm_128x128_32x2_nt_align2.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_nt_align2.cu",
+    "cutlass_tensorop_h884gemm_256x128_32x2_tn_align8.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_tn_align8.cu",
+    "cutlass_tensorop_h884gemm_128x256_32x2_tn_align8.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_tn_align8.cu",
+    "cutlass_tensorop_h884gemm_128x128_32x2_tn_align8.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_tn_align8.cu",
+    "cutlass_tensorop_h884gemm_256x128_32x2_tn_align4.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_tn_align4.cu",
+    "cutlass_tensorop_h884gemm_128x256_32x2_tn_align4.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_tn_align4.cu",
+    "cutlass_tensorop_h884gemm_128x128_32x2_tn_align4.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_tn_align4.cu",
+    "cutlass_tensorop_h884gemm_256x128_32x2_tn_align2.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_tn_align2.cu",
+    "cutlass_tensorop_h884gemm_128x256_32x2_tn_align2.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_tn_align2.cu",
+    "cutlass_tensorop_h884gemm_128x128_32x2_tn_align2.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_tn_align2.cu",
+    "cutlass_tensorop_h884gemm_256x128_32x2_tt_align8.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_tt_align8.cu",
+    "cutlass_tensorop_h884gemm_128x256_32x2_tt_align8.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_tt_align8.cu",
+    "cutlass_tensorop_h884gemm_128x128_32x2_tt_align8.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_tt_align8.cu",
+    "cutlass_tensorop_h884gemm_256x128_32x2_tt_align4.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_tt_align4.cu",
+    "cutlass_tensorop_h884gemm_128x256_32x2_tt_align4.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_tt_align4.cu",
+    "cutlass_tensorop_h884gemm_128x128_32x2_tt_align4.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_tt_align4.cu",
+    "cutlass_tensorop_h884gemm_256x128_32x2_tt_align2.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_256x128_32x2_tt_align2.cu",
+    "cutlass_tensorop_h884gemm_128x256_32x2_tt_align2.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_128x256_32x2_tt_align2.cu",
+    "cutlass_tensorop_h884gemm_128x128_32x2_tt_align2.cu",
+    "cutlass_tensorop_h884gemm_split_k_parallel_128x128_32x2_tt_align2.cu",
+    "all_gemm_tensorop884_operations.cu",
     "cutlass_simt_sgemv_batched_strided_1x128_32_tt_align4x4.cu",
     "cutlass_simt_sgemv_batched_strided_1x128_16_tt_align4x2.cu",
     "cutlass_simt_sgemv_batched_strided_1x128_8_tt_align4x1.cu",

@@ -646,4 +936,4 @@ cutlass_gen_list = [
     "cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu",
     "cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu",
     "all_conv2d_tensorop8832_operations.cu",
-]
+]
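
As a cross-check on the size of this hunk: each tensor-op family enumerates 2 accumulator types (f32, giving the `f16_s...gemm` names, and f16, giving `h...gemm`) x 4 layouts x 3 alignments x 3 tile shapes, each emitted as a plain and a split-k-parallel kernel, plus one aggregate file. A quick count check (hand-written, not part of the commit):

    // per family: 2 * 4 * 3 * 3 kernels, two .cu files each, plus 1 aggregate
    static_assert(2 * 4 * 3 * 3 * 2 + 1 == 145, "files per tensor-op family");
    // two families (tensorop1688 + tensorop884) -> 290 new list entries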

@@ -151,6 +151,8 @@ if(MGE_WITH_CUDA)
     set(${gen_files} "${${gen_files}}" PARENT_SCOPE)
   endfunction()
   gen_cutlass_kimpl(gemm simt CUTLASS_SOURCES)
+  gen_cutlass_kimpl(gemm tensorop884 CUTLASS_SOURCES)
+  gen_cutlass_kimpl(gemm tensorop1688 CUTLASS_SOURCES)
   gen_cutlass_kimpl(gemv simt CUTLASS_SOURCES)
   gen_cutlass_kimpl(deconv simt CUTLASS_SOURCES)
   gen_cutlass_kimpl(conv2d simt CUTLASS_SOURCES)

@@ -49,6 +49,8 @@ namespace library {
         (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)

 void initialize_all_gemm_simt_operations(Manifest& manifest);
+void initialize_all_gemm_tensorop884_operations(Manifest& manifest);
+void initialize_all_gemm_tensorop1688_operations(Manifest& manifest);
 void initialize_all_conv2d_simt_operations(Manifest& manifest);
 void initialize_all_conv2d_tensorop8816_operations(Manifest& manifest);
 void initialize_all_conv2d_tensorop8832_operations(Manifest& manifest);

@@ -56,6 +58,8 @@ void initialize_all_deconv_simt_operations(Manifest& manifest);
 void initialize_all(Manifest& manifest) {
     initialize_all_gemm_simt_operations(manifest);
+    initialize_all_gemm_tensorop884_operations(manifest);
+    initialize_all_gemm_tensorop1688_operations(manifest);
     initialize_all_conv2d_simt_operations(manifest);
     initialize_all_conv2d_tensorop8816_operations(manifest);
     initialize_all_conv2d_tensorop8832_operations(manifest);
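
Each registration function called here lives in a generated aggregate file; following the pattern of the existing `all_gemm_simt_operations.cu`, it presumably contains one forward declaration per generated kernel file plus the registration function declared above. A sketch under that assumption (kernel names illustrative):

    // sketch of a generated all_gemm_tensorop884_operations.cu
    namespace cutlass {
    namespace library {
    void initialize_cutlass_tensorop_h884gemm_128x128_32x2_nn_align8(Manifest&);
    // ... one declaration per generated kernel ...
    void initialize_all_gemm_tensorop884_operations(Manifest& manifest) {
        initialize_cutlass_tensorop_h884gemm_128x128_32x2_nn_align8(manifest);
        // ... remaining generated initializers ...
    }
    }  // namespace library
    }  // namespace cutlass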

@@ -55,6 +55,8 @@ GemmKey get_gemm_key_from_desc(const GemmDescription& desc) {
     key.layout_B = desc.B.layout;
     key.element_C = desc.C.element;
     key.layout_C = desc.C.layout;
+    key.element_accumulator =
+            desc.tile_description.math_instruction.element_accumulator;

     key.threadblock_shape_m = desc.tile_description.threadblock_shape.m();
     key.threadblock_shape_n = desc.tile_description.threadblock_shape.n();

@@ -75,6 +77,8 @@ GemmKey get_gemm_key_from_desc(const GemmDescription& desc) {
             desc.tile_description.math_instruction.instruction_shape.k();

     key.stages = desc.stages;
+    key.alignment_A = desc.A.alignment;
+    key.alignment_B = desc.B.alignment;
     key.split_k_mode = desc.split_k_mode;

     return key;

@@ -77,6 +77,7 @@ struct GemmKey {
     LayoutTypeID layout_B;
     NumericTypeID element_C;
     LayoutTypeID layout_C;
+    NumericTypeID element_accumulator;

     int threadblock_shape_m;
     int threadblock_shape_n;

@@ -91,12 +92,15 @@ struct GemmKey {
     int instruction_shape_k;

     int stages;
+    int alignment_A;
+    int alignment_B;
     SplitKMode split_k_mode;

     inline bool operator==(GemmKey const& rhs) const {
         return (element_A == rhs.element_A) && (layout_A == rhs.layout_A) &&
                (element_B == rhs.element_B) && (layout_B == rhs.layout_B) &&
                (element_C == rhs.element_C) && (layout_C == rhs.layout_C) &&
+               (element_accumulator == rhs.element_accumulator) &&
                (threadblock_shape_m == rhs.threadblock_shape_m) &&
                (threadblock_shape_n == rhs.threadblock_shape_n) &&
                (threadblock_shape_k == rhs.threadblock_shape_k) &&

@@ -106,7 +110,9 @@ struct GemmKey {
                (instruction_shape_m == rhs.instruction_shape_m) &&
                (instruction_shape_n == rhs.instruction_shape_n) &&
                (instruction_shape_k == rhs.instruction_shape_k) &&
-               (stages == rhs.stages) && (split_k_mode == rhs.split_k_mode);
+               (stages == rhs.stages) && (alignment_A == rhs.alignment_A) &&
+               (alignment_B == rhs.alignment_B) &&
+               (split_k_mode == rhs.split_k_mode);
     }

     inline bool operator!=(GemmKey const& rhs) const { return !(*this == rhs); }

@@ -130,10 +136,13 @@ struct GemmKey {
                "\n layout_B: " + to_string(layout_B) +
                "\n element_C: " + to_string(element_C) +
                "\n layout_C: " + to_string(layout_C) +
+               "\n element_accumulator: " + to_string(element_accumulator) +
                "\n threadblock_shape: " + threadblock_shape_str +
                "\n warp_shape: " + warp_shape_str +
                "\n instruction_shape: " + instruction_shape_str +
                "\n stages: " + std::to_string(stages) +
+               "\n alignment_A: " + std::to_string(alignment_A) +
+               "\n alignment_B: " + std::to_string(alignment_B) +
                "\n split_k_mode: " + to_string(split_k_mode) + "\n}";
     }
 };

@@ -147,6 +156,8 @@ struct GemmKeyHasher {
                 .update(&key.layout_B, sizeof(key.layout_B))
                 .update(&key.element_C, sizeof(key.element_C))
                 .update(&key.layout_C, sizeof(key.layout_C))
+                .update(&key.element_accumulator,
+                        sizeof(key.element_accumulator))
                 .update(&key.threadblock_shape_m,
                         sizeof(key.threadblock_shape_m))
                 .update(&key.threadblock_shape_n,

@@ -157,6 +168,8 @@ struct GemmKeyHasher {
                 .update(&key.warp_shape_n, sizeof(key.warp_shape_n))
                 .update(&key.warp_shape_k, sizeof(key.warp_shape_k))
                 .update(&key.stages, sizeof(key.stages))
+                .update(&key.alignment_A, sizeof(key.alignment_A))
+                .update(&key.alignment_B, sizeof(key.alignment_B))
                 .update(&key.split_k_mode, sizeof(key.split_k_mode))
                 .digest();
     }
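
The motivation for the new key fields: an f16 GEMM accumulating in f32 (the `f16_s884gemm` kernels) and one accumulating in f16 (`h884gemm`) agree on every pre-existing field, as can two kernels differing only in operand alignment, so operation-table lookups would collide without them. A minimal sketch of the distinction (`base_key` is a hypothetical fully populated key):

    GemmKey k1 = base_key, k2 = base_key;
    k1.element_accumulator = NumericTypeID::kF32;  // f16_s884gemm variant
    k2.element_accumulator = NumericTypeID::kF16;  // h884gemm variant
    assert(k1 != k2);  // distinct after this change; a collision before it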

@@ -43,6 +43,12 @@ MatrixMulForwardImpl::AlgoPack::AlgoPack() {
     for (auto&& algo : simt_float32_gemv_batched_strided) {
         all_algos.push_back(&algo);
     }
+    for (auto&& algo : tensorop_float16) {
+        all_algos.push_back(&algo);
+    }
+    for (auto&& algo : tensorop_float16_split_k) {
+        all_algos.push_back(&algo);
+    }
 #endif

     all_algos.push_back(&naive);

@@ -53,7 +59,7 @@ MatrixMulForwardImpl::AlgoPack::AlgoPack() {
 #if CUDA_VERSION >= 9020
 void MatrixMulForwardImpl::AlgoPack::fill_cutlass_algos() {
-    using AlgoParam = AlgoFloat32SIMT::AlgoParam;
+    using AlgoParam = AlgoCutlassMatrixMulBase::AlgoParam;
     simt_float32.emplace_back(AlgoParam{64, 256, 8, 32, 64, 8});
     simt_float32.emplace_back(AlgoParam{256, 64, 8, 64, 32, 8});
     simt_float32.emplace_back(AlgoParam{32, 256, 8, 16, 64, 8});

@@ -91,6 +97,19 @@ void MatrixMulForwardImpl::AlgoPack::fill_cutlass_algos() {
     simt_float32_gemv_batched_strided.emplace_back(128);
     simt_float32_gemv_batched_strided.emplace_back(64);
     simt_float32_gemv_batched_strided.emplace_back(32);
+#define FOREACH_CUTLASS_MATMUL_F16_SHAPES(cb) \
+    cb(256, 128, 32, 64, 64, 32, 8, 8, 4);    \
+    cb(128, 256, 32, 64, 64, 32, 8, 8, 4);    \
+    cb(128, 128, 32, 64, 64, 32, 8, 8, 4);    \
+    cb(256, 128, 32, 64, 64, 32, 16, 8, 8);   \
+    cb(128, 256, 32, 64, 64, 32, 16, 8, 8);   \
+    cb(128, 128, 32, 64, 64, 32, 16, 8, 8);
+#define cb(...)                                                    \
+    tensorop_float16.emplace_back(AlgoParam{__VA_ARGS__});         \
+    tensorop_float16_split_k.emplace_back(AlgoParam{__VA_ARGS__});
+    FOREACH_CUTLASS_MATMUL_F16_SHAPES(cb)
+#undef cb
+#undef FOREACH_CUTLASS_MATMUL_F16_SHAPES
 }
 #endif
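
Each `cb(...)` row carries a threadblock shape, a warp shape, and an mma instruction shape; the first row, for example, expands to:

    // 256x128x32 threadblock, 64x64x32 warp, Volta 8x8x4 mma instruction
    tensorop_float16.emplace_back(AlgoParam{256, 128, 32, 64, 64, 32, 8, 8, 4});
    tensorop_float16_split_k.emplace_back(
            AlgoParam{256, 128, 32, 64, 64, 32, 8, 8, 4});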

@@ -41,11 +41,13 @@ public:
         CUDA_WMMA_UINT4X4X32,
         CUDA_CUBLASLT,
         CUDA_NAIVE,
-        CUDA_BFLOAT16,
+        CUDA_BFLOAT16,
 #if CUDA_VERSION >= 9020
-        CUDA_FLOAT32_SIMT,
-        CUDA_FLOAT32_SIMT_SPLIT_K,
-        CUDA_FLOAT32_SIMT_GEMV_BATCHED_STRIDED,
+        CUDA_FLOAT32_SIMT,
+        CUDA_FLOAT32_SIMT_SPLIT_K,
+        CUDA_FLOAT32_SIMT_GEMV_BATCHED_STRIDED,
+        CUDA_FLOAT16_TENSOR_OP,
+        CUDA_FLOAT16_TENSOR_OP_SPLIT_K,
 #endif
     };
     using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;

@@ -188,65 +190,83 @@ private:
 #endif

 #if CUDA_VERSION >= 9020
-class MatrixMulForwardImpl::AlgoFloat32SIMT final : public AlgoBase {
+class MatrixMulForwardImpl::AlgoCutlassMatrixMulBase : public AlgoBase {
 public:
     struct AlgoParam {
         int threadblock_m, threadblock_n, threadblock_k;
         int warp_m, warp_n, warp_k;
-        std::string to_string() {
-            return ssprintf("%dX%dX%d_%dX%dX%d", threadblock_m, threadblock_n,
-                            threadblock_k, warp_m, warp_n, warp_k);
-        }
+        int instruction_m, instruction_n, instruction_k;
+        AlgoParam(int threadblock_m_, int threadblock_n_, int threadblock_k_,
+                  int warp_m_, int warp_n_, int warp_k_, int instruction_m_ = 1,
+                  int instruction_n_ = 1, int instruction_k_ = 1)
+                : threadblock_m{threadblock_m_},
+                  threadblock_n{threadblock_n_},
+                  threadblock_k{threadblock_k_},
+                  warp_m{warp_m_},
+                  warp_n{warp_n_},
+                  warp_k{warp_k_},
+                  instruction_m{instruction_m_},
+                  instruction_n{instruction_n_},
+                  instruction_k{instruction_k_} {}
+        std::string to_string() const;
     };
+    AlgoCutlassMatrixMulBase(AlgoParam algo_param) : m_algo_param{algo_param} {}
+    void exec(const ExecArgs& args) const override;
+    std::string param() const override {
+        std::string ret;
+        serialize_write_pod(m_algo_param, ret);
+        return ret;
+    }
+
+protected:
+    virtual int min_alignment_requirement() const = 0;
+    virtual void do_exec(const ExecArgs& args) const = 0;
+    std::pair<bool, TensorLayoutArray> construct_aligned_layouts(
+            const SizeArgs& args) const;
+    int max_alignment(const SizeArgs& args) const;
+    AlgoParam m_algo_param;
+};
+
+class MatrixMulForwardImpl::AlgoFloat32SIMT final
+        : public AlgoCutlassMatrixMulBase {
+public:
     AlgoFloat32SIMT(AlgoParam algo_param)
-            : m_algo_param{algo_param},
+            : AlgoCutlassMatrixMulBase{algo_param},
              m_name{ssprintf("CUTLASS_FLOAT32_SIMT_%s",
                              m_algo_param.to_string().c_str())} {}
     bool is_available(const SizeArgs& args) const override;
     size_t get_workspace_in_bytes(const SizeArgs& args) const override;
     const char* name() const override { return m_name.c_str(); }
-    void exec(const ExecArgs& args) const override;
     AlgoAttribute attribute() const override {
         return AlgoAttribute::REPRODUCIBLE;
     }
     MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT32_SIMT)
-    std::string param() const override {
-        std::string ret;
-        serialize_write_pod(m_algo_param, ret);
-        return ret;
-    }

 private:
-    AlgoParam m_algo_param;
+    void do_exec(const ExecArgs& args) const override;
+    int min_alignment_requirement() const override { return 1; }
     std::string m_name;
 };

-class MatrixMulForwardImpl::AlgoFloat32SIMTSplitK final : public AlgoBase {
+class MatrixMulForwardImpl::AlgoFloat32SIMTSplitK final
+        : public AlgoCutlassMatrixMulBase {
 public:
-    using AlgoParam = MatrixMulForwardImpl::AlgoFloat32SIMT::AlgoParam;
     AlgoFloat32SIMTSplitK(AlgoParam algo_param)
-            : m_algo_param{algo_param},
+            : AlgoCutlassMatrixMulBase{algo_param},
              m_name{ssprintf("CUTLASS_FLOAT32_SIMT_SPLIT_K_%s",
                              m_algo_param.to_string().c_str())} {}
     bool is_available(const SizeArgs& args) const override;
     size_t get_workspace_in_bytes(const SizeArgs& args) const override;
     const char* name() const override { return m_name.c_str(); }
-    void exec(const ExecArgs& args) const override;
     AlgoAttribute attribute() const override {
         return AlgoAttribute::REPRODUCIBLE |
                AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
     }
     MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT32_SIMT_SPLIT_K)
-    std::string param() const override {
-        std::string ret;
-        serialize_write_pod(m_algo_param, ret);
-        return ret;
-    }

 private:
-    AlgoParam m_algo_param;
+    void do_exec(const ExecArgs& args) const override;
+    int min_alignment_requirement() const override { return 1; }
     std::string m_name;
 };
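
`AlgoParam::to_string()` is now only declared; its definition presumably moved out of line. A sketch consistent with the former inline version (the instruction shape is carried separately in the `h%d%d%d` prefix of the algorithm names below, so the old format can stay unchanged):

    std::string MatrixMulForwardImpl::AlgoCutlassMatrixMulBase::AlgoParam::
            to_string() const {
        return ssprintf("%dX%dX%d_%dX%dX%d", threadblock_m, threadblock_n,
                        threadblock_k, warp_m, warp_n, warp_k);
    }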

@@ -276,6 +296,56 @@ private:
     int m_threadblock_n;
     std::string m_name;
 };
+
+class MatrixMulForwardImpl::AlgoFloat16TensorOp final
+        : public AlgoCutlassMatrixMulBase {
+public:
+    AlgoFloat16TensorOp(AlgoParam algo_param)
+            : AlgoCutlassMatrixMulBase{algo_param},
+              m_name{ssprintf("CUTLASS_FLOAT16_TENSOR_OP_h%d%d%d_%s",
+                              m_algo_param.instruction_m,
+                              m_algo_param.instruction_n,
+                              m_algo_param.instruction_k,
+                              m_algo_param.to_string().c_str())} {}
+    bool is_available(const SizeArgs& args) const override;
+    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
+    const char* name() const override { return m_name.c_str(); }
+    AlgoAttribute attribute() const override {
+        return AlgoAttribute::REPRODUCIBLE;
+    }
+    MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT16_TENSOR_OP)
+
+private:
+    void do_exec(const ExecArgs& args) const override;
+    int min_alignment_requirement() const override { return 2; }
+    std::string m_name;
+};
+
+class MatrixMulForwardImpl::AlgoFloat16TensorOpSplitK final
+        : public AlgoCutlassMatrixMulBase {
+public:
+    AlgoFloat16TensorOpSplitK(AlgoParam algo_param)
+            : AlgoCutlassMatrixMulBase{algo_param},
+              m_name{ssprintf("CUTLASS_FLOAT16_TENSOR_OP_SPLIT_K_h%d%d%d_%s",
+                              m_algo_param.instruction_m,
+                              m_algo_param.instruction_n,
+                              m_algo_param.instruction_k,
+                              m_algo_param.to_string().c_str())} {}
+    bool is_available(const SizeArgs& args) const override;
+    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
+    const char* name() const override { return m_name.c_str(); }
+    AlgoAttribute attribute() const override {
+        return AlgoAttribute::REPRODUCIBLE |
+               AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
+    }
+    MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT16_TENSOR_OP_SPLIT_K)
+
+private:
+    void do_exec(const ExecArgs& args) const override;
+    int min_alignment_requirement() const override { return 2; }
+    std::string m_name;
+};
 #endif

 class MatrixMulForwardImpl::AlgoPack : NonCopyableObj {

@@ -300,6 +370,8 @@ public:
     std::vector<AlgoFloat32SIMTSplitK> simt_float32_split_k;
     std::vector<AlgoFloat32SIMTGemvBatchedStrided>
             simt_float32_gemv_batched_strided;
+    std::vector<AlgoFloat16TensorOp> tensorop_float16;
+    std::vector<AlgoFloat16TensorOpSplitK> tensorop_float16_split_k;
 #endif

     std::vector<AlgoBase*> all_algos;
| @@ -0,0 +1,154 @@ | |||||
| /** | |||||
| * \file dnn/src/cuda/matrix_mul/cutlass_float16_tensorop.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "src/cuda/cutlass/singleton.h" | |||||
| #include "src/cuda/handle.h" | |||||
| #include "src/cuda/matrix_mul/algos.h" | |||||
| #include "src/cuda/utils.h" | |||||
| #if CUDA_VERSION >= 9020 | |||||
| using namespace megdnn; | |||||
| using namespace cuda; | |||||
| bool MatrixMulForwardImpl::AlgoFloat16TensorOp::is_available( | |||||
| const SizeArgs& args) const { | |||||
| bool available = | |||||
| args.opr->param().format == param::MatrixMul::Format::DEFAULT && | |||||
| args.layout_b.dtype == dtype::Float16() && | |||||
| args.layout_c.dtype == dtype::Float16(); | |||||
| int n = args.layout_c.shape[1]; | |||||
| auto&& device_prop = cuda::current_device_prop(); | |||||
| int y_grid_limit = device_prop.maxGridSize[1]; | |||||
| // limit y grid | |||||
| available &= ((n + m_algo_param.threadblock_n - 1) / | |||||
| m_algo_param.threadblock_n <= | |||||
| y_grid_limit); | |||||
| if (m_algo_param.instruction_m == 8 && m_algo_param.instruction_n == 8 && | |||||
| m_algo_param.instruction_k == 4) { | |||||
| available &= is_compute_capability_required(7, 0); | |||||
| } else { | |||||
| megdnn_assert(m_algo_param.instruction_m == 16 && | |||||
| m_algo_param.instruction_n == 8 && | |||||
| m_algo_param.instruction_k == 8); | |||||
| available &= is_compute_capability_required(7, 5); | |||||
| } | |||||
| return available; | |||||
| } | |||||
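The availability check above folds two hardware gates into one flag: the threadblock count along N must fit within CUDA's y-grid limit, and the MMA instruction shape fixes the minimum compute capability (m8n8k4 is the Volta HMMA shape; m16n8k8 requires Turing). A minimal standalone sketch of the same gating, with illustrative names rather than the real megdnn API:

// Sketch only: `tb_n` and `inst_*` mirror the m_algo_param fields;
// sm_major/sm_minor stand in for is_compute_capability_required().
bool gate_f16_tensorop(int n, int tb_n, int inst_m, int inst_n, int inst_k,
                       int y_grid_limit, int sm_major, int sm_minor) {
    // ceil(n / tb_n) threadblocks are launched along grid.y
    bool fits_grid = (n + tb_n - 1) / tb_n <= y_grid_limit;
    bool is_884 = inst_m == 8 && inst_n == 8 && inst_k == 4;
    // m8n8k4 HMMA: sm_70 and above; m16n8k8: sm_75 and above
    int need_minor = is_884 ? 0 : 5;
    bool cc_ok = sm_major > 7 || (sm_major == 7 && sm_minor >= need_minor);
    return fits_grid && cc_ok;
}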
| size_t MatrixMulForwardImpl::AlgoFloat16TensorOp::get_workspace_in_bytes( | |||||
| const SizeArgs& args) const { | |||||
| auto aligned = construct_aligned_layouts(args); | |||||
| if (!aligned.first) | |||||
| return 0_z; | |||||
| const auto& layouts = aligned.second; | |||||
| size_t ws_size = 0; | |||||
| for (auto&& ly : layouts) { | |||||
| ws_size += ly.span().dist_byte(); | |||||
| } | |||||
| return ws_size; | |||||
| } | |||||
| void MatrixMulForwardImpl::AlgoFloat16TensorOp::do_exec( | |||||
| const ExecArgs& args) const { | |||||
| int64_t lda = args.tensor_a.layout.stride[0], | |||||
| ldb = args.tensor_b.layout.stride[0], | |||||
| ldc = args.tensor_c.layout.stride[0]; | |||||
| int alignment = max_alignment(args); | |||||
| int min_alignment = min_alignment_requirement(); | |||||
| auto&& param = args.opr->param(); | |||||
| int m = args.tensor_c.layout.shape[0], n = args.tensor_c.layout.shape[1], | |||||
| k = args.tensor_a.layout.shape[param.transposeA ? 0 : 1]; | |||||
| megdnn_assert(lda % alignment == 0 && ldb % alignment == 0 && | |||||
| ldc % alignment == 0 && m % alignment == 0 && | |||||
| n % alignment == 0 && k % alignment == 0 && | |||||
| alignment >= min_alignment); | |||||
| cutlass::gemm::GemmCoord problem_size{m, n, k}; | |||||
| auto&& stream = cuda_stream(args.opr->handle()); | |||||
| int* workspace = reinterpret_cast<int*>(args.workspace.raw_ptr); | |||||
| // \note the constants (i.e. one and zero) for the cutlass epilogue are | |||||
| // passed by pointer and interpreted as ElementCompute*, and are used to | |||||
| // initialize kernel parameters. So the arguments' type on the host side | |||||
| // must match the ElementCompute of the kernel instance; otherwise the | |||||
| // kernel exhibits undefined behavior caused by incorrect interpretation | |||||
| // of these pointers. | |||||
| float one = 1.f, zero = 0.f; | |||||
| dt_float16 one_f16 = static_cast<dt_float16>(one), | |||||
| zero_f16 = static_cast<dt_float16>(zero); | |||||
| using namespace cutlass::library; | |||||
| auto layoutA = param.transposeA ? LayoutTypeID::kColumnMajor | |||||
| : LayoutTypeID::kRowMajor; | |||||
| auto layoutB = param.transposeB ? LayoutTypeID::kColumnMajor | |||||
| : LayoutTypeID::kRowMajor; | |||||
| void *host_one, *host_zero; | |||||
| NumericTypeID element_accumulator; | |||||
| if (param.compute_mode == param::MatrixMul::ComputeMode::DEFAULT) { | |||||
| element_accumulator = NumericTypeID::kF16; | |||||
| host_one = &one_f16; | |||||
| host_zero = &zero_f16; | |||||
| } else { | |||||
| megdnn_assert(param.compute_mode == | |||||
| param::MatrixMul::ComputeMode::FLOAT32); | |||||
| element_accumulator = NumericTypeID::kF32; | |||||
| host_one = &one; | |||||
| host_zero = &zero; | |||||
| } | |||||
| GemmKey key{NumericTypeID::kF16, | |||||
| layoutA, | |||||
| NumericTypeID::kF16, | |||||
| layoutB, | |||||
| NumericTypeID::kF16, | |||||
| LayoutTypeID::kRowMajor, | |||||
| element_accumulator, | |||||
| m_algo_param.threadblock_m, | |||||
| m_algo_param.threadblock_n, | |||||
| m_algo_param.threadblock_k, | |||||
| m_algo_param.warp_m, | |||||
| m_algo_param.warp_n, | |||||
| m_algo_param.warp_k, | |||||
| m_algo_param.instruction_m, | |||||
| m_algo_param.instruction_n, | |||||
| m_algo_param.instruction_k, | |||||
| 2, | |||||
| alignment, | |||||
| alignment, | |||||
| SplitKMode::kNone}; | |||||
| const auto& table = Singleton::get().operation_table; | |||||
| megdnn_assert(table.gemm_operations.count(key) > 0, | |||||
| "key not found in cutlass operation table"); | |||||
| const auto& ops = table.gemm_operations.at(key); | |||||
| megdnn_assert(ops.size() == 1, "exactly one kernel expected, got %zu", | |||||
| ops.size()); | |||||
| GemmArguments gemm_args{problem_size, | |||||
| args.tensor_a.raw_ptr, | |||||
| args.tensor_b.raw_ptr, | |||||
| args.tensor_c.raw_ptr, | |||||
| args.tensor_c.raw_ptr, | |||||
| lda, | |||||
| ldb, | |||||
| ldc, | |||||
| ldc, | |||||
| 1, | |||||
| host_one, | |||||
| host_zero}; | |||||
| cutlass_check(ops[0]->run(&gemm_args, workspace, stream)); | |||||
| } | |||||
| #endif | |||||
| // vim: syntax=cpp.doxygen | |||||
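The one/zero plumbing in do_exec() is easy to get wrong when adding a new compute mode: the scalars cross the library boundary as void* and are reinterpreted as ElementCompute*. A minimal sketch of the failure mode the comment warns about, using a hypothetical load helper (not the CUTLASS API):

#include <cuda_fp16.h>

// Hypothetical stand-in for what the library side does when building
// kernel parameters from the type-erased alpha/beta pointers.
template <typename ElementCompute>
ElementCompute load_scalar(const void* host_ptr) {
    return *reinterpret_cast<const ElementCompute*>(host_ptr);
}

void demo() {
    float one_f32 = 1.f;
    __half one_f16 = __float2half(1.f);
    __half ok = load_scalar<__half>(&one_f16);  // reads 0x3c00 -> 1.0
    // Bug: float 1.0f is 0x3f800000; on a little-endian host the first two
    // bytes are 0x0000, so the epilogue silently scales by +0.0 instead of 1.
    __half bad = load_scalar<__half>(&one_f32);
    (void)ok; (void)bad;
}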
| @@ -0,0 +1,165 @@ | |||||
| /** | |||||
| * \file dnn/src/cuda/matrix_mul/cutlass_float16_tensorop_split_k.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "src/cuda/cutlass/singleton.h" | |||||
| #include "src/cuda/handle.h" | |||||
| #include "src/cuda/matrix_mul/algos.h" | |||||
| #include "src/cuda/utils.h" | |||||
| #if CUDA_VERSION >= 9020 | |||||
| using namespace megdnn; | |||||
| using namespace cuda; | |||||
| bool MatrixMulForwardImpl::AlgoFloat16TensorOpSplitK::is_available( | |||||
| const SizeArgs& args) const { | |||||
| auto&& param = args.opr->param(); | |||||
| int n = args.layout_c.shape[1], | |||||
| k = args.layout_a.shape[param.transposeA ? 0 : 1]; | |||||
| bool available = | |||||
| args.opr->param().format == param::MatrixMul::Format::DEFAULT && | |||||
| args.layout_a.dtype == dtype::Float16() && | |||||
| args.layout_b.dtype == dtype::Float16() && | |||||
| args.layout_c.dtype == dtype::Float16() && k > n; | |||||
| auto&& device_prop = cuda::current_device_prop(); | |||||
| int y_grid_limit = device_prop.maxGridSize[1]; | |||||
| // limit y grid | |||||
| available &= ((n + m_algo_param.threadblock_n - 1) / | |||||
| m_algo_param.threadblock_n <= | |||||
| y_grid_limit); | |||||
| if (m_algo_param.instruction_m == 8 && m_algo_param.instruction_n == 8 && | |||||
| m_algo_param.instruction_k == 4) { | |||||
| available &= is_compute_capability_required(7, 0); | |||||
| } else { | |||||
| megdnn_assert(m_algo_param.instruction_m == 16 && | |||||
| m_algo_param.instruction_n == 8 && | |||||
| m_algo_param.instruction_k == 8); | |||||
| available &= is_compute_capability_required(7, 5); | |||||
| } | |||||
| return available; | |||||
| } | |||||
| size_t MatrixMulForwardImpl::AlgoFloat16TensorOpSplitK::get_workspace_in_bytes( | |||||
| const SizeArgs& args) const { | |||||
| auto aligned = construct_aligned_layouts(args); | |||||
| auto&& param = args.opr->param(); | |||||
| int m = args.layout_c.shape[0], n = args.layout_c.shape[1], | |||||
| k = args.layout_a.shape[param.transposeA ? 0 : 1]; | |||||
| int split_k_slices = std::max(1, k / n); | |||||
| if (!aligned.first) | |||||
| return args.layout_c.dtype.size(m * n * split_k_slices); | |||||
| const auto& layouts = aligned.second; | |||||
| int align_m = layouts[2].shape[0], align_n = layouts[2].shape[1], | |||||
| align_k = layouts[0].shape[1]; | |||||
| split_k_slices = std::max(1, align_k / align_n); | |||||
| size_t ws_size = | |||||
| args.layout_c.dtype.size(align_m * align_n * split_k_slices); | |||||
| for (auto&& ly : layouts) | |||||
| ws_size += ly.span().dist_byte(); | |||||
| return ws_size; | |||||
| } | |||||
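The split-k workspace is sized in two parts: one M x N tile of partial results per slice, which the parallel split-k reduction collapses afterwards, plus, on the unaligned path, the three padded operand buffers. A worked sketch of the partials term under the same split_k_slices = max(1, k / n) heuristic:

#include <algorithm>
#include <cstddef>

// Illustrative recomputation of the partial-results term; dtype_size is
// sizeof(element_c), i.e. 2 for f16. Padded-layout handling is omitted.
size_t split_k_partials_bytes(int m, int n, int k, size_t dtype_size) {
    int split_k_slices = std::max(1, k / n);
    return dtype_size * static_cast<size_t>(m) * n * split_k_slices;
}
// e.g. m = 128, n = 64, k = 9216, f16: 9216 / 64 = 144 slices,
// 128 * 64 * 144 * 2 bytes = 2.25 MiB of partials.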
| void MatrixMulForwardImpl::AlgoFloat16TensorOpSplitK::do_exec( | |||||
| const ExecArgs& args) const { | |||||
| int64_t lda = args.tensor_a.layout.stride[0], | |||||
| ldb = args.tensor_b.layout.stride[0], | |||||
| ldc = args.tensor_c.layout.stride[0]; | |||||
| int alignment = max_alignment(args); | |||||
| int min_alignment = min_alignment_requirement(); | |||||
| auto&& param = args.opr->param(); | |||||
| int m = args.tensor_c.layout.shape[0], n = args.tensor_c.layout.shape[1], | |||||
| k = args.tensor_a.layout.shape[param.transposeA ? 0 : 1]; | |||||
| megdnn_assert(lda % alignment == 0 && ldb % alignment == 0 && | |||||
| ldc % alignment == 0 && m % alignment == 0 && | |||||
| n % alignment == 0 && k % alignment == 0 && | |||||
| alignment >= min_alignment); | |||||
| cutlass::gemm::GemmCoord problem_size{m, n, k}; | |||||
| int split_k_slices = std::max(1, k / n); | |||||
| auto&& stream = cuda_stream(args.opr->handle()); | |||||
| int* workspace = reinterpret_cast<int*>(args.workspace.raw_ptr); | |||||
| // \note the constants (i.e. one and zero) for the cutlass epilogue are | |||||
| // passed by pointer and interpreted as ElementCompute*, and are used to | |||||
| // initialize kernel parameters. So the arguments' type on the host side | |||||
| // must match the ElementCompute of the kernel instance; otherwise the | |||||
| // kernel exhibits undefined behavior caused by incorrect interpretation | |||||
| // of these pointers. | |||||
| float one = 1.f, zero = 0.f; | |||||
| dt_float16 one_f16 = static_cast<dt_float16>(one), | |||||
| zero_f16 = static_cast<dt_float16>(zero); | |||||
| using namespace cutlass::library; | |||||
| auto layoutA = param.transposeA ? LayoutTypeID::kColumnMajor | |||||
| : LayoutTypeID::kRowMajor; | |||||
| auto layoutB = param.transposeB ? LayoutTypeID::kColumnMajor | |||||
| : LayoutTypeID::kRowMajor; | |||||
| void *host_one, *host_zero; | |||||
| NumericTypeID element_accumulator; | |||||
| if (param.compute_mode == param::MatrixMul::ComputeMode::DEFAULT) { | |||||
| element_accumulator = NumericTypeID::kF16; | |||||
| host_one = &one_f16; | |||||
| host_zero = &zero_f16; | |||||
| } else { | |||||
| megdnn_assert(param.compute_mode == | |||||
| param::MatrixMul::ComputeMode::FLOAT32); | |||||
| element_accumulator = NumericTypeID::kF32; | |||||
| host_one = &one; | |||||
| host_zero = &zero; | |||||
| } | |||||
| GemmKey key{NumericTypeID::kF16, | |||||
| layoutA, | |||||
| NumericTypeID::kF16, | |||||
| layoutB, | |||||
| NumericTypeID::kF16, | |||||
| LayoutTypeID::kRowMajor, | |||||
| element_accumulator, | |||||
| m_algo_param.threadblock_m, | |||||
| m_algo_param.threadblock_n, | |||||
| m_algo_param.threadblock_k, | |||||
| m_algo_param.warp_m, | |||||
| m_algo_param.warp_n, | |||||
| m_algo_param.warp_k, | |||||
| m_algo_param.instruction_m, | |||||
| m_algo_param.instruction_n, | |||||
| m_algo_param.instruction_k, | |||||
| 2, | |||||
| alignment, | |||||
| alignment, | |||||
| SplitKMode::kParallel}; | |||||
| const auto& table = Singleton::get().operation_table; | |||||
| megdnn_assert(table.gemm_operations.count(key) > 0, | |||||
| "key not found in cutlass operation table"); | |||||
| const auto& ops = table.gemm_operations.at(key); | |||||
| megdnn_assert(ops.size() == 1, "exactly one kernel expected, got %zu", | |||||
| ops.size()); | |||||
| GemmArguments gemm_args{problem_size, | |||||
| args.tensor_a.raw_ptr, | |||||
| args.tensor_b.raw_ptr, | |||||
| args.tensor_c.raw_ptr, | |||||
| args.tensor_c.raw_ptr, | |||||
| lda, | |||||
| ldb, | |||||
| ldc, | |||||
| ldc, | |||||
| split_k_slices, | |||||
| host_one, | |||||
| host_zero}; | |||||
| cutlass_check(ops[0]->run(&gemm_args, workspace, stream)); | |||||
| } | |||||
| #endif | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -42,7 +42,8 @@ size_t MatrixMulForwardImpl::AlgoFloat32SIMT::get_workspace_in_bytes( | |||||
| return 0_z; | return 0_z; | ||||
| } | } | ||||
| void MatrixMulForwardImpl::AlgoFloat32SIMT::exec(const ExecArgs& args) const { | |||||
| void MatrixMulForwardImpl::AlgoFloat32SIMT::do_exec( | |||||
| const ExecArgs& args) const { | |||||
| int64_t lda = args.tensor_a.layout.stride[0], | int64_t lda = args.tensor_a.layout.stride[0], | ||||
| ldb = args.tensor_b.layout.stride[0], | ldb = args.tensor_b.layout.stride[0], | ||||
| ldc = args.tensor_c.layout.stride[0]; | ldc = args.tensor_c.layout.stride[0]; | ||||
| @@ -65,12 +66,14 @@ void MatrixMulForwardImpl::AlgoFloat32SIMT::exec(const ExecArgs& args) const { | |||||
| auto layoutB = param.transposeB ? LayoutTypeID::kColumnMajor | auto layoutB = param.transposeB ? LayoutTypeID::kColumnMajor | ||||
| : LayoutTypeID::kRowMajor; | : LayoutTypeID::kRowMajor; | ||||
| int alignment = min_alignment_requirement(); | |||||
| GemmKey key{NumericTypeID::kF32, | GemmKey key{NumericTypeID::kF32, | ||||
| layoutA, | layoutA, | ||||
| NumericTypeID::kF32, | NumericTypeID::kF32, | ||||
| layoutB, | layoutB, | ||||
| NumericTypeID::kF32, | NumericTypeID::kF32, | ||||
| LayoutTypeID::kRowMajor, | LayoutTypeID::kRowMajor, | ||||
| NumericTypeID::kF32, | |||||
| m_algo_param.threadblock_m, | m_algo_param.threadblock_m, | ||||
| m_algo_param.threadblock_n, | m_algo_param.threadblock_n, | ||||
| m_algo_param.threadblock_k, | m_algo_param.threadblock_k, | ||||
| @@ -79,8 +82,10 @@ void MatrixMulForwardImpl::AlgoFloat32SIMT::exec(const ExecArgs& args) const { | |||||
| m_algo_param.warp_k, | m_algo_param.warp_k, | ||||
| 1, | 1, | ||||
| 1, | 1, | ||||
| 1, | |||||
| 2, | |||||
| 1, | |||||
| 2, | |||||
| alignment, | |||||
| alignment, | |||||
| SplitKMode::kNone}; | SplitKMode::kNone}; | ||||
| const Operation* op = Singleton::get().operation_table.find_op(key); | const Operation* op = Singleton::get().operation_table.find_op(key); | ||||
| @@ -22,7 +22,7 @@ using namespace cuda; | |||||
| bool MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::is_available( | bool MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::is_available( | ||||
| const SizeArgs& args) const { | const SizeArgs& args) const { | ||||
| auto&& param = args.opr->param(); | auto&& param = args.opr->param(); | ||||
| int m = args.layout_c.shape[0], n = args.layout_c.shape[1], | |||||
| int n = args.layout_c.shape[1], | |||||
| k = args.layout_a.shape[param.transposeA ? 0 : 1]; | k = args.layout_a.shape[param.transposeA ? 0 : 1]; | ||||
| bool available = | bool available = | ||||
| args.opr->param().format == param::MatrixMul::Format::DEFAULT && | args.opr->param().format == param::MatrixMul::Format::DEFAULT && | ||||
| @@ -32,8 +32,8 @@ bool MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::is_available( | |||||
| auto&& device_prop = cuda::current_device_prop(); | auto&& device_prop = cuda::current_device_prop(); | ||||
| int y_grid_limit = device_prop.maxGridSize[1]; | int y_grid_limit = device_prop.maxGridSize[1]; | ||||
| // limit y grid | // limit y grid | ||||
| available &= ((m + m_algo_param.threadblock_m - 1) / | |||||
| m_algo_param.threadblock_m <= | |||||
| available &= ((n + m_algo_param.threadblock_n - 1) / | |||||
| m_algo_param.threadblock_n <= | |||||
| y_grid_limit); | y_grid_limit); | ||||
| return available; | return available; | ||||
| } | } | ||||
| @@ -47,7 +47,7 @@ size_t MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::get_workspace_in_bytes( | |||||
| return args.layout_c.dtype.size(m * n * split_k_slices); | return args.layout_c.dtype.size(m * n * split_k_slices); | ||||
| } | } | ||||
| void MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::exec( | |||||
| void MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::do_exec( | |||||
| const ExecArgs& args) const { | const ExecArgs& args) const { | ||||
| int64_t lda = args.tensor_a.layout.stride[0], | int64_t lda = args.tensor_a.layout.stride[0], | ||||
| ldb = args.tensor_b.layout.stride[0], | ldb = args.tensor_b.layout.stride[0], | ||||
| @@ -72,12 +72,14 @@ void MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::exec( | |||||
| auto layoutB = param.transposeB ? LayoutTypeID::kColumnMajor | auto layoutB = param.transposeB ? LayoutTypeID::kColumnMajor | ||||
| : LayoutTypeID::kRowMajor; | : LayoutTypeID::kRowMajor; | ||||
| int alignment = min_alignment_requirement(); | |||||
| GemmKey key{NumericTypeID::kF32, | GemmKey key{NumericTypeID::kF32, | ||||
| layoutA, | layoutA, | ||||
| NumericTypeID::kF32, | NumericTypeID::kF32, | ||||
| layoutB, | layoutB, | ||||
| NumericTypeID::kF32, | NumericTypeID::kF32, | ||||
| LayoutTypeID::kRowMajor, | LayoutTypeID::kRowMajor, | ||||
| NumericTypeID::kF32, | |||||
| m_algo_param.threadblock_m, | m_algo_param.threadblock_m, | ||||
| m_algo_param.threadblock_n, | m_algo_param.threadblock_n, | ||||
| m_algo_param.threadblock_k, | m_algo_param.threadblock_k, | ||||
| @@ -87,7 +89,9 @@ void MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::exec( | |||||
| 1, | 1, | ||||
| 1, | 1, | ||||
| 1, | 1, | ||||
| 2, | |||||
| 2, | |||||
| alignment, | |||||
| alignment, | |||||
| SplitKMode::kParallel}; | SplitKMode::kParallel}; | ||||
| Operation const* op = Singleton::get().operation_table.find_op(key); | Operation const* op = Singleton::get().operation_table.find_op(key); | ||||
| @@ -0,0 +1,136 @@ | |||||
| /** | |||||
| * \file dnn/src/cuda/matrix_mul/cutlass_matrix_mul_base.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "src/cuda/handle.h" | |||||
| #include "src/cuda/matrix_mul/algos.h" | |||||
| #include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh" | |||||
| #include "src/cuda/utils.h" | |||||
| #if CUDA_VERSION >= 9020 | |||||
| using namespace megdnn; | |||||
| using namespace cuda; | |||||
| std::string | |||||
| MatrixMulForwardImpl::AlgoCutlassMatrixMulBase::AlgoParam::to_string() const { | |||||
| return ssprintf("%dX%dX%d_%dX%dX%d", threadblock_m, threadblock_n, | |||||
| threadblock_k, warp_m, warp_n, warp_k); | |||||
| } | |||||
| std::pair<bool, TensorLayoutArray> | |||||
| MatrixMulForwardImpl::AlgoCutlassMatrixMulBase::construct_aligned_layouts( | |||||
| const SizeArgs& args) const { | |||||
| int alignment = max_alignment(args); | |||||
| int min_alignment = min_alignment_requirement(); | |||||
| bool aligned = alignment >= min_alignment; | |||||
| if (aligned) | |||||
| return std::make_pair(!aligned, TensorLayoutArray{{}}); | |||||
| auto&& param = args.opr->param(); | |||||
| int m = args.layout_c.shape[0], n = args.layout_c.shape[1], | |||||
| k = args.layout_a.shape[param.transposeA ? 0 : 1]; | |||||
| size_t align_m = get_aligned_power2(m, min_alignment); | |||||
| size_t align_n = get_aligned_power2(n, min_alignment); | |||||
| size_t align_k = get_aligned_power2(k, min_alignment); | |||||
| TensorLayoutArray layouts; | |||||
| layouts.emplace_back(TensorLayout{{align_m, align_k}, args.layout_a.dtype}); | |||||
| layouts.emplace_back(TensorLayout{{align_k, align_n}, args.layout_b.dtype}); | |||||
| layouts.emplace_back(TensorLayout{{align_m, align_n}, args.layout_c.dtype}); | |||||
| return std::make_pair(!aligned, std::move(layouts)); | |||||
| } | |||||
| void MatrixMulForwardImpl::AlgoCutlassMatrixMulBase::exec( | |||||
| const ExecArgs& args) const { | |||||
| auto aligned = construct_aligned_layouts(args); | |||||
| if (!aligned.first) | |||||
| return do_exec(args); | |||||
| const auto& layouts = aligned.second; | |||||
| auto tensor_a = args.tensor_a; | |||||
| auto tensor_b = args.tensor_b; | |||||
| auto workspace = args.workspace; | |||||
| size_t copy_size = 0; | |||||
| for (const auto& ly : layouts) | |||||
| copy_size += ly.span().dist_byte(); | |||||
| auto&& param = args.opr->param(); | |||||
| auto&& stream = cuda_stream(args.opr->handle()); | |||||
| cuda_check(cudaMemsetAsync(workspace.raw_ptr, 0, copy_size, stream)); | |||||
| auto&& relayout = args.opr->handle()->create_operator<RelayoutForward>(); | |||||
| auto copy_stride = [](const TensorLayout& src, TensorLayout& dst, | |||||
| bool trans) { | |||||
| dst.stride[0] = src.stride[0], dst.stride[1] = src.stride[1]; | |||||
| if (trans) | |||||
| std::swap(dst.stride[0], dst.stride[1]); | |||||
| }; | |||||
| copy_stride(layouts[0], tensor_a.layout, param.transposeA); | |||||
| tensor_a.raw_ptr = workspace.raw_ptr; | |||||
| relayout->exec(args.tensor_a, tensor_a); | |||||
| workspace.raw_ptr += layouts[0].span().dist_byte(); | |||||
| workspace.size -= layouts[0].span().dist_byte(); | |||||
| copy_stride(layouts[1], tensor_b.layout, param.transposeB); | |||||
| tensor_b.raw_ptr = workspace.raw_ptr; | |||||
| relayout->exec(args.tensor_b, tensor_b); | |||||
| workspace.raw_ptr += layouts[1].span().dist_byte(); | |||||
| workspace.size -= layouts[1].span().dist_byte(); | |||||
| decltype(tensor_a) tensor_c{workspace.raw_ptr, layouts[2]}; | |||||
| workspace.raw_ptr += layouts[2].span().dist_byte(); | |||||
| workspace.size -= layouts[2].span().dist_byte(); | |||||
| auto&& matmul = args.opr->handle()->create_operator<MatrixMulForward>(); | |||||
| matmul->param().transposeA = false; | |||||
| matmul->param().transposeB = false; | |||||
| matmul->param().compute_mode = args.opr->param().compute_mode; | |||||
| tensor_a.layout = layouts[0]; | |||||
| tensor_b.layout = layouts[1]; | |||||
| ExecArgs args_{static_cast<MatrixMulForwardImpl*>(matmul.get()), tensor_a, | |||||
| tensor_b, tensor_c, workspace}; | |||||
| do_exec(args_); | |||||
| tensor_c.layout.TensorShape::operator=(args.layout_c); | |||||
| relayout->exec(tensor_c, args.tensor_c); | |||||
| } | |||||
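Together, construct_aligned_layouts() and exec() implement the fallback for badly aligned problems: pad M/N/K up, relayout A and B into zero-filled workspace buffers, run do_exec() on the padded problem, then relayout the m x n window of C back out. Assuming get_aligned_power2(x, a) rounds x up to the next multiple of the power-of-two a (its definition is not part of this diff), a worked sketch:

#include <cstddef>

// Assumed semantics of get_aligned_power2 (not shown in this diff):
static size_t aligned_pow2(size_t x, size_t a) { return (x + a - 1) & ~(a - 1); }

// With min_alignment_requirement() == 2 for the f16 tensor-op algos,
// an m=100, n=60, k=35 problem is padded to 100 x 60 x 36:
//   aligned_pow2(100, 2) == 100
//   aligned_pow2(60, 2)  == 60
//   aligned_pow2(35, 2)  == 36
// The zero fill (cudaMemsetAsync) ensures the padded tail rows/columns
// contribute nothing to the visible m x n block of the product.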
| int MatrixMulForwardImpl::AlgoCutlassMatrixMulBase::max_alignment( | |||||
| const SizeArgs& args) const { | |||||
| auto&& dtype_a = args.layout_a.dtype; | |||||
| auto&& dtype_b = args.layout_b.dtype; | |||||
| auto&& dtype_c = args.layout_c.dtype; | |||||
| auto get_alignment = [](const DType& dt, int len) { | |||||
| int size_bits = dt.size(1) * 8; | |||||
| int align = 128; | |||||
| while (align > 1) { | |||||
| if ((len * size_bits) % align == 0) | |||||
| break; | |||||
| align = align / 2; | |||||
| } | |||||
| return align / size_bits; | |||||
| }; | |||||
| int lda = args.layout_a.stride[0], ldb = args.layout_b.stride[0], | |||||
| ldc = args.layout_c.stride[0]; | |||||
| auto&& param = args.opr->param(); | |||||
| int m = args.layout_c.shape[0], n = args.layout_c.shape[1], | |||||
| k = args.layout_a.shape[param.transposeA ? 0 : 1]; | |||||
| int max_align = get_alignment(dtype_a, lda); | |||||
| max_align = std::min(get_alignment(dtype_a, m), max_align); | |||||
| max_align = std::min(get_alignment(dtype_a, n), max_align); | |||||
| max_align = std::min(get_alignment(dtype_a, k), max_align); | |||||
| max_align = std::min(get_alignment(dtype_b, ldb), max_align); | |||||
| max_align = std::min(get_alignment(dtype_c, ldc), max_align); | |||||
| return max_align; | |||||
| } | |||||
| #endif | |||||
| // vim: syntax=cpp.doxygen | |||||
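max_alignment() picks the widest vectorized access the tensors actually permit: for each extent and leading stride it finds the largest power-of-two bit width (capped at 128 bits) that divides len * size_bits, then converts back to an element count. The same search, specialized to f16 for illustration:

// Same alignment search as the get_alignment lambda above, for one value.
int get_alignment_f16(int len) {
    const int size_bits = 16;  // dt_float16
    int align = 128;           // start from a 128-bit (8 x f16) access
    while (align > 1 && (len * size_bits) % align != 0)
        align /= 2;
    return align / size_bits;  // back to elements
}
// get_alignment_f16(9216) == 8  (9216 * 16 is divisible by 128)
// get_alignment_f16(100)  == 4  (100 * 16 = 1600; 1600 % 128 != 0, % 64 == 0)
// The minimum over {lda, ldb, ldc, m, n, k} is then compared against
// min_alignment_requirement() (2 for the f16 tensor-op algos).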
| @@ -42,9 +42,12 @@ public: | |||||
| class AlgoBFloat16; | class AlgoBFloat16; | ||||
| #endif | #endif | ||||
| #if CUDA_VERSION >= 9020 | #if CUDA_VERSION >= 9020 | ||||
| class AlgoCutlassMatrixMulBase; | |||||
| class AlgoFloat32SIMT; | class AlgoFloat32SIMT; | ||||
| class AlgoFloat32SIMTSplitK; | class AlgoFloat32SIMTSplitK; | ||||
| class AlgoFloat32SIMTGemvBatchedStrided; | class AlgoFloat32SIMTGemvBatchedStrided; | ||||
| class AlgoFloat16TensorOp; | |||||
| class AlgoFloat16TensorOpSplitK; | |||||
| #endif | #endif | ||||
| class AlgoPack; | class AlgoPack; | ||||
| @@ -184,7 +184,8 @@ void matrix_mul::check_matrix_mul(DType A_dtype, DType B_dtype, DType C_dtype, | |||||
| const ExecutionPolicyAlgoName& algo, | const ExecutionPolicyAlgoName& algo, | ||||
| param::MatrixMul::Format format, size_t nbase, | param::MatrixMul::Format format, size_t nbase, | ||||
| float eps, std::vector<TestArg>&& user_args, | float eps, std::vector<TestArg>&& user_args, | ||||
| bool force_deduce_dst) { | |||||
| bool force_deduce_dst, | |||||
| param::MatrixMul::ComputeMode compute_mode) { | |||||
| megdnn_assert(A_dtype.enumv() == B_dtype.enumv()); | megdnn_assert(A_dtype.enumv() == B_dtype.enumv()); | ||||
| Checker<Opr> checker(handle); | Checker<Opr> checker(handle); | ||||
| checker.set_force_deduce_dst(force_deduce_dst); | checker.set_force_deduce_dst(force_deduce_dst); | ||||
| @@ -261,6 +262,7 @@ void matrix_mul::check_matrix_mul(DType A_dtype, DType B_dtype, DType C_dtype, | |||||
| Param param; | Param param; | ||||
| param.transposeA = arg.mask & 0x1; | param.transposeA = arg.mask & 0x1; | ||||
| param.transposeB = arg.mask & 0x2; | param.transposeB = arg.mask & 0x2; | ||||
| param.compute_mode = compute_mode; | |||||
| param.format = format; | param.format = format; | ||||
| checker.set_dtype(0, A_dtype) | checker.set_dtype(0, A_dtype) | ||||
| .set_dtype(1, B_dtype) | .set_dtype(1, B_dtype) | ||||
| @@ -69,7 +69,9 @@ void check_matrix_mul( | |||||
| const ExecutionPolicyAlgoName& algo = {"", {}}, | const ExecutionPolicyAlgoName& algo = {"", {}}, | ||||
| param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT, | param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT, | ||||
| size_t nbase = 8, float eps = 1e-3, std::vector<TestArg>&& args = {}, | size_t nbase = 8, float eps = 1e-3, std::vector<TestArg>&& args = {}, | ||||
| bool force_deduce_dst = true); | |||||
| bool force_deduce_dst = true, | |||||
| param::MatrixMul::ComputeMode compute_mode = | |||||
| param::MatrixMul::ComputeMode::DEFAULT); | |||||
| void check_matrix_mul( | void check_matrix_mul( | ||||
| DType A_dtype, DType B_dtype, DType C_dtype, Handle* handle, | DType A_dtype, DType B_dtype, DType C_dtype, Handle* handle, | ||||
| @@ -21,6 +21,7 @@ | |||||
| #include "test/cuda/fixture.h" | #include "test/cuda/fixture.h" | ||||
| #include "test/cuda/utils.h" | #include "test/cuda/utils.h" | ||||
| #if CUDA_VERSION >= 9020 | #if CUDA_VERSION >= 9020 | ||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace test { | namespace test { | ||||
| @@ -215,6 +216,14 @@ std::vector<BenchArgs> get_feat_model_args() { | |||||
| return args; | return args; | ||||
| } | } | ||||
| std::vector<BenchArgs> get_f16_feat_model_args() { | |||||
| std::vector<BenchArgs> args; | |||||
| args.emplace_back(BenchArgs{128, 9216, 9216}); | |||||
| args.emplace_back(BenchArgs{128, 6400, 6400}); | |||||
| args.emplace_back(BenchArgs{128, 5184, 5184}); | |||||
| return args; | |||||
| } | |||||
| void benchmark_matrix_mul( | void benchmark_matrix_mul( | ||||
| Handle* handle, const std::vector<BenchArgs>& args, DType A_dtype, | Handle* handle, const std::vector<BenchArgs>& args, DType A_dtype, | ||||
| DType B_dtype, DType C_dtype, const char* algo = nullptr, | DType B_dtype, DType C_dtype, const char* algo = nullptr, | ||||
| @@ -364,6 +373,82 @@ MEGDNN_FOREACH_CUTLASS_KERNEL(cb) | |||||
| #undef cb | #undef cb | ||||
| #undef MEGDNN_FOREACH_CUTLASS_KERNEL | #undef MEGDNN_FOREACH_CUTLASS_KERNEL | ||||
| #define MEGDNN_FOREACH_CUTLASS_KERNEL(cb) \ | |||||
| cb(1, 256, 128, 32, 64, 64, 32, 8, 8, 4); \ | |||||
| cb(2, 128, 256, 32, 64, 64, 32, 8, 8, 4); \ | |||||
| cb(3, 128, 128, 32, 64, 64, 32, 8, 8, 4); | |||||
| #define cb(name, tbm, tbn, tbk, wm, wn, wk, im, in, ik) \ | |||||
| TEST_F(CUDA, CUTLASS_F16_884_GEMM_##name) { \ | |||||
| require_compute_capability(7, 0); \ | |||||
| matrix_mul::check_matrix_mul<MatrixMulForward>( \ | |||||
| dtype::Float16(), dtype::Float16(), dtype::Float16(), \ | |||||
| handle_cuda(), \ | |||||
| "CUTLASS_FLOAT16_TENSOR_OP_h" #im #in #ik "_" #tbm "X" #tbn \ | |||||
| "X" #tbk "_" #wm "X" #wn "X" #wk, \ | |||||
| param::MatrixMul::Format::DEFAULT, 8, 1e-2, \ | |||||
| matrix_mul::get_matmul_args()); \ | |||||
| } | |||||
| MEGDNN_FOREACH_CUTLASS_KERNEL(cb) | |||||
| #undef cb | |||||
| #define cb(name, tbm, tbn, tbk, wm, wn, wk, im, in, ik) \ | |||||
| TEST_F(CUDA, CUTLASS_F16_884_GEMM_SPLIT_K_##name) { \ | |||||
| require_compute_capability(7, 0); \ | |||||
| matrix_mul::check_matrix_mul<MatrixMulForward>( \ | |||||
| dtype::Float16(), dtype::Float16(), dtype::Float16(), \ | |||||
| handle_cuda(), \ | |||||
| "CUTLASS_FLOAT16_TENSOR_OP_SPLIT_K_h" #im #in #ik "_" #tbm \ | |||||
| "X" #tbn "X" #tbk "_" #wm "X" #wn "X" #wk, \ | |||||
| param::MatrixMul::Format::DEFAULT, 8, 1e-3, \ | |||||
| matrix_mul::get_matmul_args_split_k(), true, \ | |||||
| param::MatrixMul::ComputeMode::FLOAT32); \ | |||||
| } | |||||
| MEGDNN_FOREACH_CUTLASS_KERNEL(cb) | |||||
| #undef cb | |||||
| #undef MEGDNN_FOREACH_CUTLASS_KERNEL | |||||
| #define MEGDNN_FOREACH_CUTLASS_KERNEL(cb) \ | |||||
| cb(1, 256, 128, 32, 64, 64, 32, 16, 8, 8); \ | |||||
| cb(2, 128, 256, 32, 64, 64, 32, 16, 8, 8); \ | |||||
| cb(3, 128, 128, 32, 64, 64, 32, 16, 8, 8); | |||||
| #define cb(name, tbm, tbn, tbk, wm, wn, wk, im, in, ik) \ | |||||
| TEST_F(CUDA, CUTLASS_F16_1688_GEMM_##name) { \ | |||||
| require_compute_capability(7, 5); \ | |||||
| matrix_mul::check_matrix_mul<MatrixMulForward>( \ | |||||
| dtype::Float16(), dtype::Float16(), dtype::Float16(), \ | |||||
| handle_cuda(), \ | |||||
| "CUTLASS_FLOAT16_TENSOR_OP_h" #im #in #ik "_" #tbm "X" #tbn \ | |||||
| "X" #tbk "_" #wm "X" #wn "X" #wk, \ | |||||
| param::MatrixMul::Format::DEFAULT, 8, 1e-2, \ | |||||
| matrix_mul::get_matmul_args(), true, \ | |||||
| param::MatrixMul::ComputeMode::FLOAT32); \ | |||||
| } | |||||
| MEGDNN_FOREACH_CUTLASS_KERNEL(cb) | |||||
| #undef cb | |||||
| #define cb(name, tbm, tbn, tbk, wm, wn, wk, im, in, ik) \ | |||||
| TEST_F(CUDA, CUTLASS_F16_1688_GEMM_SPLIT_K_##name) { \ | |||||
| require_compute_capability(7, 5); \ | |||||
| matrix_mul::check_matrix_mul<MatrixMulForward>( \ | |||||
| dtype::Float16(), dtype::Float16(), dtype::Float16(), \ | |||||
| handle_cuda(), \ | |||||
| "CUTLASS_FLOAT16_TENSOR_OP_SPLIT_K_h" #im #in #ik "_" #tbm \ | |||||
| "X" #tbn "X" #tbk "_" #wm "X" #wn "X" #wk, \ | |||||
| param::MatrixMul::Format::DEFAULT, 8, 1e-3, \ | |||||
| matrix_mul::get_matmul_args_split_k()); \ | |||||
| } | |||||
| MEGDNN_FOREACH_CUTLASS_KERNEL(cb) | |||||
| #undef cb | |||||
| #undef MEGDNN_FOREACH_CUTLASS_KERNEL | |||||
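Each cb() row stringifies its shape arguments into the algorithm name, so every test binds to exactly one generated kernel. For example, the first 1688 row resolves to:

// cb(1, 256, 128, 32, 64, 64, 32, 16, 8, 8) selects the algorithm named
//   "CUTLASS_FLOAT16_TENSOR_OP_h1688_256X128X32_64X64X32"
// and the split-k variant of the same row selects
//   "CUTLASS_FLOAT16_TENSOR_OP_SPLIT_K_h1688_256X128X32_64X64X32"
// matching the m_name format strings in the AlgoFloat16TensorOp{,SplitK}
// constructors.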
| #if MEGDNN_WITH_BENCHMARK | #if MEGDNN_WITH_BENCHMARK | ||||
| TEST_F(CUDA, BENCHMARK_CUTLASS_MATMUL) { | TEST_F(CUDA, BENCHMARK_CUTLASS_MATMUL) { | ||||
| benchmark_matrix_mul(handle_cuda(), get_square_matmul_args(), | benchmark_matrix_mul(handle_cuda(), get_square_matmul_args(), | ||||
| @@ -376,6 +461,12 @@ TEST_F(CUDA, BENCHMARK_CUTLASS_MATMUL_FEAT) { | |||||
| dtype::Float32(), dtype::Float32(), | dtype::Float32(), dtype::Float32(), | ||||
| "CUTLASS_FLOAT32_SIMT"); | "CUTLASS_FLOAT32_SIMT"); | ||||
| } | } | ||||
| TEST_F(CUDA, BENCHMARK_CUTLASS_F16_MATMUL_FEAT) { | |||||
| benchmark_matrix_mul(handle_cuda(), get_f16_feat_model_args(), | |||||
| dtype::Float16(), dtype::Float16(), dtype::Float16(), | |||||
| "CUTLASS_FLOAT16_TENSOR_OP"); | |||||
| } | |||||
| #endif | #endif | ||||
| } // namespace test | } // namespace test | ||||
| } // namespace megdnn | } // namespace megdnn | ||||