GitOrigin-RevId: 0b5574f526
tags/v1.6.0-rc1
| @@ -20,7 +20,7 @@ class Conv2dOperation: | |||
| # | |||
| def __init__(self, conv_kind, conv_type, arch, tile_description, src, flt, bias, dst, element_epilogue, \ | |||
| epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity4, \ | |||
| need_load_from_const = True, implicit_gemm_mode = ImplicitGemmMode.GemmNt): | |||
| need_load_from_const = True, implicit_gemm_mode = ImplicitGemmMode.GemmNT, without_shared_load = False): | |||
| self.operation_kind = OperationKind.Conv2d | |||
| self.conv_kind = conv_kind | |||
| @@ -36,6 +36,7 @@ class Conv2dOperation: | |||
| self.swizzling_functor = swizzling_functor | |||
| self.need_load_from_const = need_load_from_const | |||
| self.implicit_gemm_mode = implicit_gemm_mode | |||
| self.without_shared_load = without_shared_load | |||
| # | |||
| def accumulator_type(self): | |||
| accum = self.tile_description.math_instruction.element_accumulator | |||
| @@ -58,11 +59,15 @@ class Conv2dOperation: | |||
| unity_kernel = '' | |||
| if not self.need_load_from_const: | |||
| unity_kernel = '_1x1' | |||
| unity_kernel = '_1x1' | |||
| return "%s%s%s%s%s_%s" % (ShortDataTypeNames[self.accumulator_type()], \ | |||
| reorder_k = '' | |||
| if self.without_shared_load: | |||
| reorder_k = '_roc' | |||
| return "%s%s%s%s%s%s_%s" % (ShortDataTypeNames[self.accumulator_type()], \ | |||
| inst_shape, intermediate_type, ConvKindNames[self.conv_kind], unity_kernel, \ | |||
| ShortEpilogueNames[self.epilogue_functor]) | |||
| reorder_k, ShortEpilogueNames[self.epilogue_functor]) | |||
| # | |||
| def extended_name(self): | |||
| @@ -177,7 +182,8 @@ using Convolution = | |||
| ${alignment_filter}, | |||
| ${nonuninity_kernel}, | |||
| ${math_operator}, | |||
| ${implicit_gemm_mode}>; | |||
| ${implicit_gemm_mode}, | |||
| ${without_shared_load}>; | |||
| """ | |||
| @@ -219,7 +225,8 @@ using Convolution = | |||
| 'alignment_filter': str(operation.flt.alignment), | |||
| 'nonuninity_kernel': str(operation.need_load_from_const).lower(), | |||
| 'math_operator': MathOperationTag[operation.tile_description.math_instruction.math_operation], | |||
| 'implicit_gemm_mode': ImplicitGemmModeTag[operation.implicit_gemm_mode] | |||
| 'implicit_gemm_mode': ImplicitGemmModeTag[operation.implicit_gemm_mode], | |||
| 'without_shared_load': str(operation.without_shared_load).lower() | |||
| } | |||
| return SubstituteTemplate(self.template, values) | |||
| @@ -312,13 +319,13 @@ using Deconvolution = | |||
| # | |||
| def GenerateConv2d(conv_kind, tile_descriptions, src_layout, flt_layout, dst_layout, dst_type, min_cc, src_align = 32, flt_align = 32, dst_align = 128, \ | |||
| skip_unity_kernel = False, implicit_gemm_mode = ImplicitGemmMode.GemmNt): | |||
| skip_unity_kernel = False, implicit_gemm_mode = ImplicitGemmMode.GemmNT, without_shared_load = False): | |||
| operations = [] | |||
| element_epilogue = DataType.f32 | |||
| if conv_kind == ConvKind.Fprop: | |||
| if src_layout == LayoutType.TensorNHWC: | |||
| swizzling_functor = SwizzlingFunctor.ConvFpropNHWC | |||
| if implicit_gemm_mode == ImplicitGemmMode.GemmTN: | |||
| swizzling_functor = SwizzlingFunctor.ConvFpropTrans | |||
| else: | |||
| swizzling_functor = SwizzlingFunctor.ConvFpropNCxHWx | |||
| else: | |||
| @@ -399,10 +406,10 @@ def GenerateConv2d(conv_kind, tile_descriptions, src_layout, flt_layout, dst_lay | |||
| bias = TensorDescription(bias_type, dst_layout, max(1, int(32 / DataTypeSize[bias_type]))) | |||
| dst = TensorDescription(dst_type, dst_layout, int(dst_align / DataTypeSize[dst_type])) | |||
| new_operation = Conv2dOperation(conv_kind, ConvType.Convolution, min_cc, tile, src, flt, bias, dst, element_epilogue, epilogue, swizzling_functor, True, implicit_gemm_mode) | |||
| new_operation = Conv2dOperation(conv_kind, ConvType.Convolution, min_cc, tile, src, flt, bias, dst, element_epilogue, epilogue, swizzling_functor, True, implicit_gemm_mode, without_shared_load) | |||
| operations.append(new_operation) | |||
| if not skip_unity_kernel: | |||
| new_operation = Conv2dOperation(conv_kind, ConvType.Convolution, min_cc, tile, src, flt, bias, dst, element_epilogue, epilogue, swizzling_functor, False, implicit_gemm_mode) | |||
| new_operation = Conv2dOperation(conv_kind, ConvType.Convolution, min_cc, tile, src, flt, bias, dst, element_epilogue, epilogue, swizzling_functor, False, implicit_gemm_mode, without_shared_load) | |||
| operations.append(new_operation) | |||
| return operations | |||
| @@ -175,12 +175,10 @@ def GenerateConv2d_Simt(args): | |||
| TileDescription([128, 128, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc), | |||
| TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), | |||
| TileDescription([ 64, 128, 32], 2, [1, 4, 1], math_inst, min_cc, max_cc), | |||
| TileDescription([ 64, 64, 32], 2, [1, 2, 1], math_inst, min_cc, max_cc), | |||
| TileDescription([128, 32, 32], 2, [2, 1, 1], math_inst, min_cc, max_cc), | |||
| TileDescription([ 32, 128, 32], 2, [1, 2, 1], math_inst, min_cc, max_cc), | |||
| TileDescription([ 32, 64, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc), | |||
| TileDescription([ 64, 32, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc), | |||
| TileDescription([ 32, 32, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc), | |||
| TileDescription([ 16, 128, 16], 1, [1, 1, 1], math_inst, min_cc, max_cc), | |||
| TileDescription([ 16, 64, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc), | |||
| ] | |||
| @@ -223,28 +221,36 @@ def GenerateConv2d_TensorOp_8816(args): | |||
| for dst_type, dst_layout in zip(dst_types, dst_layouts): | |||
| if dst_layout == LayoutType.TensorNC32HW32: | |||
| tile_descriptions = [ | |||
| TileDescription([256, 128, 64], 2, [4, 2, 1], math_inst, min_cc, max_cc), | |||
| TileDescription([128, 256, 64], 2, [2, 4, 1], math_inst, min_cc, max_cc), | |||
| TileDescription([256, 128, 64], 2, [4, 2, 1], math_inst, min_cc, max_cc), | |||
| TileDescription([128, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), | |||
| TileDescription([ 64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), | |||
| TileDescription([128, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), | |||
| TileDescription([ 64, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), | |||
| TileDescription([ 32, 64, 64], 2, [1, 4, 1], math_inst, min_cc, max_cc), | |||
| TileDescription([ 64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), | |||
| TileDescription([128, 64, 32], 1, [2, 2, 1], math_inst, min_cc, max_cc), | |||
| TileDescription([128, 32, 32], 1, [2, 1, 1], math_inst, min_cc, max_cc), | |||
| TileDescription([ 64, 128, 32], 1, [2, 2, 1], math_inst, min_cc, max_cc), | |||
| TileDescription([ 32, 128, 32], 1, [1, 2, 1], math_inst, min_cc, max_cc), | |||
| ] | |||
| operations += GenerateConv2d(ConvKind.Fprop, tile_descriptions, layout[0], layout[1], | |||
| dst_layout, dst_type, min_cc, 128, 128, 64, | |||
| False, ImplicitGemmMode.GemmTN, True) | |||
| else: | |||
| assert dst_layout == LayoutType.TensorNC4HW4 | |||
| tile_descriptions = [ | |||
| TileDescription([256, 128, 64], 2, [4, 2, 1], math_inst, min_cc, max_cc), | |||
| TileDescription([128, 256, 64], 2, [2, 4, 1], math_inst, min_cc, max_cc), | |||
| TileDescription([256, 128, 64], 2, [4, 2, 1], math_inst, min_cc, max_cc), | |||
| TileDescription([128, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), | |||
| TileDescription([ 64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), | |||
| TileDescription([128, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), | |||
| TileDescription([ 64, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), | |||
| TileDescription([ 32, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), | |||
| TileDescription([ 64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), | |||
| TileDescription([128, 64, 32], 1, [2, 2, 1], math_inst, min_cc, max_cc), | |||
| TileDescription([128, 32, 32], 1, [2, 1, 1], math_inst, min_cc, max_cc), | |||
| TileDescription([ 64, 128, 32], 1, [2, 2, 1], math_inst, min_cc, max_cc), | |||
| TileDescription([ 32, 128, 32], 1, [1, 2, 1], math_inst, min_cc, max_cc), | |||
| ] | |||
| operations += GenerateConv2d(ConvKind.Fprop, tile_descriptions, layout[0], layout[1], | |||
| operations += GenerateConv2d(ConvKind.Fprop, tile_descriptions, layout[0], layout[1], | |||
| dst_layout, dst_type, min_cc, 128, 128, 64, | |||
| False) | |||
| return operations | |||
| def GenerateConv2d_TensorOp_8832(args): | |||
| @@ -279,12 +285,14 @@ def GenerateConv2d_TensorOp_8832(args): | |||
| for dst_layout in dst_layouts: | |||
| dst_type = math_inst.element_b | |||
| tile_descriptions = [ | |||
| TileDescription([256, 128, 128], 2, [4, 2, 1], math_inst, min_cc, max_cc), | |||
| TileDescription([128, 256, 128], 2, [2, 4, 1], math_inst, min_cc, max_cc), | |||
| TileDescription([128, 128, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc), | |||
| TileDescription([128, 64, 128], 2, [2, 1, 1], math_inst, min_cc, max_cc), | |||
| TileDescription([128, 64, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc), | |||
| ] | |||
| operations += GenerateConv2d(ConvKind.Fprop, tile_descriptions, layout[0], layout[1], | |||
| dst_layout, dst_type, min_cc, 128, 128, 64, | |||
| True) | |||
| True, ImplicitGemmMode.GemmTN, True) | |||
| layouts_nhwc = [ | |||
| (LayoutType.TensorNHWC, LayoutType.TensorNC8HW8, 32), | |||
| @@ -299,14 +307,21 @@ def GenerateConv2d_TensorOp_8832(args): | |||
| for math_inst in math_instructions: | |||
| for layout in layouts_nhwc: | |||
| for dst_layout in dst_layouts_nhwc: | |||
| dst_type = math_inst.element_b | |||
| tile_descriptions = [ | |||
| TileDescription([128, 32, 64], 2, [2, 1, 1], math_inst, min_cc, max_cc), | |||
| TileDescription([128, 64, 64], 2, [2, 1, 1], math_inst, min_cc, max_cc), | |||
| ] | |||
| operations += GenerateConv2d(ConvKind.Fprop, tile_descriptions, layout[0], layout[1], | |||
| dst_layout, dst_type, min_cc, layout[2], layout[2], 32, | |||
| False, ImplicitGemmMode.GemmTn) | |||
| dst_type = math_inst.element_b | |||
| tile_descriptions = [ | |||
| TileDescription([128, 32, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc), | |||
| TileDescription([128, 64, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc), | |||
| ] | |||
| for tile in tile_descriptions: | |||
| operations += GenerateConv2d(ConvKind.Fprop, [tile], layout[0], layout[1], | |||
| dst_layout, dst_type, min_cc, layout[2], layout[2], 32, | |||
| False, ImplicitGemmMode.GemmTN, False) | |||
| if tile.threadblock_shape[1] == 32 or tile.threadblock_shape[1] == 64: | |||
| dst_align = 32 if tile.threadblock_shape[1] == 32 else 64 | |||
| operations += GenerateConv2d(ConvKind.Fprop, [tile], layout[0], layout[1], | |||
| dst_layout, dst_type, min_cc, layout[2], layout[2], dst_align, | |||
| False, ImplicitGemmMode.GemmTN, True) | |||
| return operations | |||
| def GenerateDeconv_Simt(args): | |||
| @@ -649,3 +664,4 @@ if __name__ == "__main__": | |||
| # | |||
| ################################################################################################### | |||
| @@ -464,10 +464,10 @@ EpilogueFunctorTag = { | |||
| ShortEpilogueNames = { | |||
| EpilogueFunctor.BiasAddLinearCombinationHSwishClamp: 'hswish', | |||
| EpilogueFunctor.BiasAddLinearCombinationReluClamp: 'relu', | |||
| EpilogueFunctor.BiasAddLinearCombinationClamp: 'identity', | |||
| EpilogueFunctor.BiasAddLinearCombinationClamp: 'id', | |||
| EpilogueFunctor.BiasAddLinearCombinationHSwish: 'hswish', | |||
| EpilogueFunctor.BiasAddLinearCombinationRelu: 'relu', | |||
| EpilogueFunctor.BiasAddLinearCombination: 'identity', | |||
| EpilogueFunctor.BiasAddLinearCombination: 'id', | |||
| } | |||
| @@ -482,7 +482,7 @@ class SwizzlingFunctor(enum.Enum): | |||
| Identity4 = enum_auto() | |||
| Identity8 = enum_auto() | |||
| ConvFpropNCxHWx = enum_auto() | |||
| ConvFpropNHWC = enum_auto() | |||
| ConvFpropTrans = enum_auto() | |||
| ConvDgradNCxHWx = enum_auto() | |||
| # | |||
| @@ -492,7 +492,7 @@ SwizzlingFunctorTag = { | |||
| SwizzlingFunctor.Identity4: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>', | |||
| SwizzlingFunctor.Identity8: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>', | |||
| SwizzlingFunctor.ConvFpropNCxHWx: 'cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle', | |||
| SwizzlingFunctor.ConvFpropNHWC: 'cutlass::conv::threadblock::ConvolutionFpropNHWCThreadblockSwizzle', | |||
| SwizzlingFunctor.ConvFpropTrans: 'cutlass::conv::threadblock::ConvolutionFpropTransThreadblockSwizzle', | |||
| SwizzlingFunctor.ConvDgradNCxHWx: 'cutlass::conv::threadblock::ConvolutionDgradNCxHWxThreadblockSwizzle', | |||
| } | |||
| @@ -563,17 +563,17 @@ StrideSupportNames = { | |||
| } | |||
| class ImplicitGemmMode(enum.Enum): | |||
| GemmNt = enum_auto() | |||
| GemmTn = enum_auto() | |||
| GemmNT = enum_auto() | |||
| GemmTN = enum_auto() | |||
| ImplicitGemmModeNames = { | |||
| ImplicitGemmMode.GemmNt: 'gemm_nt', | |||
| ImplicitGemmMode.GemmTn: 'gemm_tn', | |||
| ImplicitGemmMode.GemmNT: 'gemm_nt', | |||
| ImplicitGemmMode.GemmTN: 'gemm_tn', | |||
| } | |||
| ImplicitGemmModeTag = { | |||
| ImplicitGemmMode.GemmNt: 'cutlass::conv::ImplicitGemmMode::GEMM_NT', | |||
| ImplicitGemmMode.GemmTn: 'cutlass::conv::ImplicitGemmMode::GEMM_TN', | |||
| ImplicitGemmMode.GemmNT: 'cutlass::conv::ImplicitGemmMode::GEMM_NT', | |||
| ImplicitGemmMode.GemmTN: 'cutlass::conv::ImplicitGemmMode::GEMM_TN', | |||
| } | |||
| ################################################################################################### | |||
| @@ -164,415 +164,461 @@ cutlass_gen_list = [ | |||
| "cutlass_simt_sgemv_batched_strided_1x32_32_tt_align1x4.cu", | |||
| "cutlass_simt_sgemv_batched_strided_1x32_16_tt_align1x2.cu", | |||
| "cutlass_simt_sgemv_batched_strided_1x32_8_tt_align1x1.cu", | |||
| "cutlass_simt_s8_idgrad_identity_s8_64x128x32_64x32x32_2_nc4hw4_k4rsc4.cu", | |||
| "cutlass_simt_s8_idgrad_identity_s8_32x128x32_32x64x32_2_nc4hw4_k4rsc4.cu", | |||
| "cutlass_simt_s8_idgrad_identity_s8_16x128x16_16x64x16_2_nc4hw4_k4rsc4.cu", | |||
| "cutlass_simt_s8_idgrad_identity_s8_16x128x16_16x128x16_1_nc4hw4_k4rsc4.cu", | |||
| "cutlass_simt_s8_idgrad_identity_s8_16x64x8_16x64x8_2_nc4hw4_k4rsc4.cu", | |||
| "cutlass_simt_s8_ifprop_identity_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_identity_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_idgrad_id_s8_64x128x32_64x32x32_2_nc4hw4_k4rsc4.cu", | |||
| "cutlass_simt_s8_idgrad_id_s8_32x128x32_32x64x32_2_nc4hw4_k4rsc4.cu", | |||
| "cutlass_simt_s8_idgrad_id_s8_16x128x16_16x64x16_2_nc4hw4_k4rsc4.cu", | |||
| "cutlass_simt_s8_idgrad_id_s8_16x128x16_16x128x16_1_nc4hw4_k4rsc4.cu", | |||
| "cutlass_simt_s8_idgrad_id_s8_16x64x8_16x64x8_2_nc4hw4_k4rsc4.cu", | |||
| "cutlass_simt_s8_ifprop_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_relu_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_relu_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_hswish_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_hswish_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_identity_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_identity_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_id_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_id_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_relu_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_relu_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_hswish_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_hswish_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_identity_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_identity_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_id_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_id_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_relu_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_relu_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_hswish_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_hswish_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_identity_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_identity_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_relu_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_relu_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_hswish_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_hswish_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_identity_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_identity_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_id_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_id_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_relu_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_relu_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_hswish_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_hswish_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_identity_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_identity_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_id_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_id_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_relu_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_relu_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_hswish_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_hswish_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_identity_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_identity_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_id_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_id_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_relu_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_relu_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_hswish_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_hswish_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_identity_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_identity_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_id_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_id_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_relu_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_relu_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_hswish_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_hswish_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_identity_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_identity_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_relu_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_relu_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_hswish_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_hswish_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_identity_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_identity_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_id_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_id_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_relu_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_relu_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_hswish_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_hswish_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_identity_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_identity_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_id_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_id_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_relu_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_relu_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_hswish_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_hswish_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4.cu", | |||
| "cutlass_simt_s8_ifprop_identity_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_identity_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_relu_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_relu_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_hswish_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_hswish_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_identity_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_identity_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_id_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_id_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_relu_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_relu_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_hswish_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_hswish_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_identity_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_identity_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_id_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_id_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_relu_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_relu_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_hswish_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_hswish_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_identity_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_identity_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_relu_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_relu_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_hswish_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_hswish_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_identity_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_identity_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_id_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_id_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_relu_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_relu_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_hswish_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_hswish_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_identity_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_identity_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_id_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_id_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_relu_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_relu_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_hswish_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_hswish_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_identity_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_identity_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_id_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_id_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_relu_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_relu_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_hswish_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_hswish_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_identity_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_identity_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_id_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_id_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_relu_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_relu_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_hswish_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_hswish_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_identity_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_identity_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_relu_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_relu_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_hswish_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_s8_ifprop_1x1_hswish_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu", | |||
| "cutlass_simt_u4_ifprop_identity_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_u4_ifprop_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_u4_ifprop_relu_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_u4_ifprop_hswish_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_u4_ifprop_identity_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_u4_ifprop_id_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_u4_ifprop_relu_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_u4_ifprop_hswish_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_u4_ifprop_identity_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_u4_ifprop_id_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_u4_ifprop_relu_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_u4_ifprop_hswish_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_u4_ifprop_identity_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_u4_ifprop_relu_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_u4_ifprop_hswish_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_u4_ifprop_identity_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_u4_ifprop_id_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_u4_ifprop_relu_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_u4_ifprop_hswish_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_u4_ifprop_identity_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_u4_ifprop_id_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_u4_ifprop_relu_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_u4_ifprop_hswish_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_u4_ifprop_identity_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_u4_ifprop_id_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_u4_ifprop_relu_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_u4_ifprop_hswish_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_u4_ifprop_identity_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_u4_ifprop_id_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_u4_ifprop_relu_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_u4_ifprop_hswish_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_u4_ifprop_identity_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_u4_ifprop_relu_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_u4_ifprop_hswish_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_u4_ifprop_identity_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_u4_ifprop_id_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_u4_ifprop_relu_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_u4_ifprop_hswish_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_u4_ifprop_identity_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_u4_ifprop_id_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_u4_ifprop_relu_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_u4_ifprop_hswish_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_s4_ifprop_identity_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_s4_ifprop_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_s4_ifprop_relu_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_s4_ifprop_hswish_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_s4_ifprop_identity_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_s4_ifprop_id_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_s4_ifprop_relu_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_s4_ifprop_hswish_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_s4_ifprop_identity_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_s4_ifprop_id_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_s4_ifprop_relu_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_s4_ifprop_hswish_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_s4_ifprop_identity_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_s4_ifprop_relu_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_s4_ifprop_hswish_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_s4_ifprop_identity_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_s4_ifprop_id_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_s4_ifprop_relu_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_s4_ifprop_hswish_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_s4_ifprop_identity_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_s4_ifprop_id_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_s4_ifprop_relu_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_s4_ifprop_hswish_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_s4_ifprop_identity_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_s4_ifprop_id_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_s4_ifprop_relu_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_s4_ifprop_hswish_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_s4_ifprop_identity_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_s4_ifprop_id_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_s4_ifprop_relu_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_s4_ifprop_hswish_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_s4_ifprop_identity_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_s4_ifprop_relu_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_s4_ifprop_hswish_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_s4_ifprop_identity_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_s4_ifprop_id_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_s4_ifprop_relu_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_s4_ifprop_hswish_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_s4_ifprop_identity_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_s4_ifprop_id_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_s4_ifprop_relu_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_s4_ifprop_hswish_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nhwc.cu", | |||
| "cutlass_simt_f32_ifprop_identity_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_1x1_identity_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_1x1_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_relu_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_1x1_relu_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_hswish_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_1x1_hswish_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_identity_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_1x1_identity_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_id_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_1x1_id_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_relu_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_1x1_relu_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_hswish_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_1x1_hswish_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_identity_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_1x1_identity_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_id_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_1x1_id_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_relu_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_1x1_relu_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_hswish_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_1x1_hswish_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_identity_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_1x1_identity_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_relu_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_1x1_relu_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_hswish_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_1x1_hswish_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_identity_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_1x1_identity_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_id_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_1x1_id_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_relu_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_1x1_relu_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_hswish_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_1x1_hswish_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_identity_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_1x1_identity_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_id_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_1x1_id_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_relu_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_1x1_relu_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_hswish_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_1x1_hswish_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_identity_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_1x1_identity_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_id_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_1x1_id_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_relu_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_1x1_relu_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_hswish_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_1x1_hswish_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_identity_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_1x1_identity_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_id_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_1x1_id_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_relu_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_1x1_relu_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_hswish_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_1x1_hswish_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_identity_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_1x1_identity_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_relu_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_1x1_relu_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_hswish_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_1x1_hswish_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_identity_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_1x1_identity_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_id_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_1x1_id_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_relu_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_1x1_relu_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_hswish_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_1x1_hswish_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_identity_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_1x1_identity_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_id_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_1x1_id_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_relu_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_1x1_relu_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_hswish_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_simt_f32_ifprop_1x1_hswish_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nchw.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_identity_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_relu_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_hswish_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_identity_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_relu_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_hswish_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_identity_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_relu_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_hswish_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_identity_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_relu_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_hswish_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_identity_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_relu_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_hswish_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_identity_s8_64x64x64_32x32x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_64x64x64_32x32x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_relu_s8_64x64x64_32x32x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_64x64x64_32x32x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_hswish_s8_64x64x64_32x32x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_64x64x64_32x32x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_identity_s8_32x64x64_32x16x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_32x64x64_32x16x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_relu_s8_32x64x64_32x16x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_32x64x64_32x16x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_hswish_s8_32x64x64_32x16x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_32x64x64_32x16x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_identity_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_relu_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_hswish_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_identity_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_roc_id_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_roc_id_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_roc_relu_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_roc_relu_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_roc_hswish_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_roc_hswish_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_roc_id_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_roc_id_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_roc_relu_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_roc_relu_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_roc_hswish_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_roc_hswish_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_roc_id_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_roc_id_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_roc_relu_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_roc_relu_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_roc_hswish_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_roc_hswish_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_roc_id_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_roc_id_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_roc_relu_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_roc_relu_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_roc_hswish_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_roc_hswish_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_roc_id_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_roc_id_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_roc_relu_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_roc_relu_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_roc_hswish_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_roc_hswish_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_roc_id_s8_128x64x32_64x32x32_1_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_roc_id_s8_128x64x32_64x32x32_1_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_roc_relu_s8_128x64x32_64x32x32_1_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_roc_relu_s8_128x64x32_64x32x32_1_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_roc_hswish_s8_128x64x32_64x32x32_1_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_roc_hswish_s8_128x64x32_64x32x32_1_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_roc_id_s8_128x32x32_64x32x32_1_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_roc_id_s8_128x32x32_64x32x32_1_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_roc_relu_s8_128x32x32_64x32x32_1_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_roc_relu_s8_128x32x32_64x32x32_1_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_roc_hswish_s8_128x32x32_64x32x32_1_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_roc_hswish_s8_128x32x32_64x32x32_1_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_roc_id_s8_64x128x32_32x64x32_1_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_roc_id_s8_64x128x32_32x64x32_1_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_roc_relu_s8_64x128x32_32x64x32_1_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_roc_relu_s8_64x128x32_32x64x32_1_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_roc_hswish_s8_64x128x32_32x64x32_1_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_roc_hswish_s8_64x128x32_32x64x32_1_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_roc_id_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_roc_id_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_roc_relu_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_roc_relu_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_roc_hswish_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_roc_hswish_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_id_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_id_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_relu_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_hswish_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_identity_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_id_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_id_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_relu_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_hswish_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_id_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_id_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_relu_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_hswish_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_identity_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_relu_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_hswish_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_identity_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_id_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_id_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_relu_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_hswish_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_identity_s8_64x64x64_32x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_64x64x64_32x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_relu_s8_64x64x64_32x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_64x64x64_32x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_hswish_s8_64x64x64_32x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_64x64x64_32x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_identity_s8_32x64x64_16x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_32x64x64_16x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_relu_s8_32x64x64_16x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_32x64x64_16x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_hswish_s8_32x64x64_16x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_32x64x64_16x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_identity_s4_256x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_relu_s4_256x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_hswish_s4_256x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_identity_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_relu_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_hswish_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_identity_u4_s4_256x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_relu_u4_s4_256x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_identity_u4_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_relu_u4_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_identity_s4_128x32x64_64x32x64_2_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_identity_s4_128x32x64_64x32x64_2_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_relu_s4_128x32x64_64x32x64_2_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_relu_s4_128x32x64_64x32x64_2_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_hswish_s4_128x32x64_64x32x64_2_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_hswish_s4_128x32x64_64x32x64_2_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_identity_s4_128x64x64_64x64x64_2_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_identity_s4_128x64x64_64x64x64_2_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_relu_s4_128x64x64_64x64x64_2_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_relu_s4_128x64x64_64x64x64_2_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_hswish_s4_128x64x64_64x64x64_2_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_hswish_s4_128x64x64_64x64x64_2_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_identity_s4_128x32x64_64x32x64_2_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_identity_s4_128x32x64_64x32x64_2_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_relu_s4_128x32x64_64x32x64_2_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_relu_s4_128x32x64_64x32x64_2_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_hswish_s4_128x32x64_64x32x64_2_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_hswish_s4_128x32x64_64x32x64_2_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_identity_s4_128x64x64_64x64x64_2_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_identity_s4_128x64x64_64x64x64_2_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_relu_s4_128x64x64_64x64x64_2_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_relu_s4_128x64x64_64x64x64_2_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_hswish_s4_128x64x64_64x64x64_2_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_hswish_s4_128x64x64_64x64x64_2_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_identity_s4_128x32x64_64x32x64_2_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_identity_s4_128x32x64_64x32x64_2_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_relu_s4_128x32x64_64x32x64_2_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_relu_s4_128x32x64_64x32x64_2_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_hswish_s4_128x32x64_64x32x64_2_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_hswish_s4_128x32x64_64x32x64_2_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_identity_s4_128x64x64_64x64x64_2_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_identity_s4_128x64x64_64x64x64_2_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_relu_s4_128x64x64_64x64x64_2_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_relu_s4_128x64x64_64x64x64_2_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_hswish_s4_128x64x64_64x64x64_2_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_hswish_s4_128x64x64_64x64x64_2_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_identity_u4_s4_128x32x64_64x32x64_2_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_1x1_identity_u4_s4_128x32x64_64x32x64_2_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_relu_u4_s4_128x32x64_64x32x64_2_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_1x1_relu_u4_s4_128x32x64_64x32x64_2_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_identity_u4_s4_128x64x64_64x64x64_2_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_1x1_identity_u4_s4_128x64x64_64x64x64_2_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_relu_u4_s4_128x64x64_64x64x64_2_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_1x1_relu_u4_s4_128x64x64_64x64x64_2_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_identity_u4_s4_128x32x64_64x32x64_2_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_1x1_identity_u4_s4_128x32x64_64x32x64_2_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_relu_u4_s4_128x32x64_64x32x64_2_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_1x1_relu_u4_s4_128x32x64_64x32x64_2_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_identity_u4_s4_128x64x64_64x64x64_2_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_1x1_identity_u4_s4_128x64x64_64x64x64_2_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_relu_u4_s4_128x64x64_64x64x64_2_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_1x1_relu_u4_s4_128x64x64_64x64x64_2_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_identity_u4_s4_128x32x64_64x32x64_2_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_1x1_identity_u4_s4_128x32x64_64x32x64_2_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_relu_u4_s4_128x32x64_64x32x64_2_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_1x1_relu_u4_s4_128x32x64_64x32x64_2_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_identity_u4_s4_128x64x64_64x64x64_2_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_1x1_identity_u4_s4_128x64x64_64x64x64_2_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_relu_u4_s4_128x64x64_64x64x64_2_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_1x1_relu_u4_s4_128x64x64_64x64x64_2_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_id_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_id_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_relu_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_hswish_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_id_s8_128x64x32_64x32x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_id_s8_128x64x32_64x32x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_relu_s8_128x64x32_64x32x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_128x64x32_64x32x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_hswish_s8_128x64x32_64x32x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_128x64x32_64x32x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_id_s8_128x32x32_64x32x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_id_s8_128x32x32_64x32x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_relu_s8_128x32x32_64x32x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_128x32x32_64x32x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_hswish_s8_128x32x32_64x32x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_128x32x32_64x32x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_id_s8_64x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_id_s8_64x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_relu_s8_64x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_64x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_hswish_s8_64x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_64x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_id_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_id_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_relu_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_hswish_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_id_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_id_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_relu_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_relu_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_hswish_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_hswish_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_roc_id_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_roc_relu_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_roc_hswish_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_id_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_id_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_relu_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_relu_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_hswish_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_hswish_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_roc_id_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_roc_relu_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_roc_hswish_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_id_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_id_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_relu_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_relu_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_hswish_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_hswish_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_roc_id_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_roc_relu_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_roc_hswish_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_id_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_id_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_relu_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_relu_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_hswish_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_hswish_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_roc_id_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_roc_relu_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_roc_hswish_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_id_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_id_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_relu_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_relu_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_hswish_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_hswish_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_roc_id_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_roc_relu_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_roc_hswish_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_id_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_id_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_relu_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_relu_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_hswish_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_hswish_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_roc_id_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_roc_relu_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_s4_i8832fprop_1x1_roc_hswish_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_id_u4_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_1x1_id_u4_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_relu_u4_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_1x1_relu_u4_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_1x1_roc_id_u4_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_id_u4_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_1x1_id_u4_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_1x1_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_1x1_roc_id_u4_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_id_u4_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_1x1_id_u4_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_relu_u4_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_1x1_relu_u4_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_1x1_roc_id_u4_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_id_u4_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_1x1_id_u4_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_1x1_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_1x1_roc_id_u4_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_id_u4_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_1x1_id_u4_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_relu_u4_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_1x1_relu_u4_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_1x1_roc_id_u4_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_id_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_1x1_id_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_1x1_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_1x1_roc_id_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||
| "cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||
| ] | |||
| @@ -217,56 +217,68 @@ void ConvBiasForwardImpl::AlgoPack::fill_imma_algos() { | |||
| #if CUDA_VERSION >= 10020 | |||
| { | |||
| using AlgoParam = AlgoInt8NCHW32IMMAImplicitGemm::AlgoParam; | |||
| int8_nchw32_imma.emplace_back(AlgoParam{128, 256, 64, 64, 64, 64}); | |||
| int8_nchw32_imma.emplace_back(AlgoParam{256, 128, 64, 64, 64, 64}); | |||
| int8_nchw32_imma.emplace_back(AlgoParam{128, 128, 64, 64, 64, 64}); | |||
| int8_nchw32_imma.emplace_back(AlgoParam{64, 128, 64, 32, 64, 64}); | |||
| int8_nchw32_imma.emplace_back(AlgoParam{128, 64, 64, 64, 32, 64}); | |||
| int8_nchw32_imma.emplace_back(AlgoParam{64, 64, 64, 32, 32, 64}); | |||
| int8_nchw32_imma.emplace_back(AlgoParam{32, 64, 64, 32, 16, 64}); | |||
| int8_nchw32_imma.emplace_back(AlgoParam{128, 256, 64, 64, 64, 64, 2}); | |||
| int8_nchw32_imma.emplace_back(AlgoParam{256, 128, 64, 64, 64, 64, 2}); | |||
| int8_nchw32_imma.emplace_back(AlgoParam{128, 128, 64, 64, 64, 64, 2}); | |||
| int8_nchw32_imma.emplace_back(AlgoParam{128, 64, 64, 64, 32, 64, 2}); | |||
| int8_nchw32_imma.emplace_back(AlgoParam{64, 128, 64, 32, 64, 64, 2}); | |||
| int8_nchw32_imma.emplace_back(AlgoParam{128, 64, 32, 64, 32, 32, 1}); | |||
| int8_nchw32_imma.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 1}); | |||
| int8_nchw32_imma.emplace_back(AlgoParam{64, 128, 32, 32, 64, 32, 1}); | |||
| int8_nchw32_imma.emplace_back(AlgoParam{32, 128, 32, 32, 64, 32, 1}); | |||
| } | |||
| { | |||
| using AlgoParam = AlgoInt4Int4NCHW64IMMAImplicitGemm::AlgoParam; | |||
| int4_int4_nchw64_imma.emplace_back( | |||
| AlgoParam{128, 128, 128, 64, 64, 128}); | |||
| AlgoParam{128, 128, 128, 64, 64, 128, 2}); | |||
| int4_int4_nchw64_imma.emplace_back( | |||
| AlgoParam{256, 128, 128, 64, 64, 128}); | |||
| AlgoParam{128, 256, 128, 64, 64, 128, 2}); | |||
| int4_int4_nchw64_imma.emplace_back( | |||
| AlgoParam{128, 64, 128, 64, 64, 128, 2}); | |||
| int4_int4_nchw64_imma.emplace_back( | |||
| AlgoParam{128, 64, 64, 64, 64, 64, 1}); | |||
| } | |||
| { | |||
| using AlgoParam = AlgoUInt4Int4NCHW64IMMAImplicitGemm::AlgoParam; | |||
| uint4_int4_nchw64_imma.emplace_back( | |||
| AlgoParam{128, 128, 128, 64, 64, 128}); | |||
| AlgoParam{128, 128, 128, 64, 64, 128, 2}); | |||
| uint4_int4_nchw64_imma.emplace_back( | |||
| AlgoParam{128, 256, 128, 64, 64, 128, 2}); | |||
| uint4_int4_nchw64_imma.emplace_back( | |||
| AlgoParam{256, 128, 128, 64, 64, 128}); | |||
| AlgoParam{128, 64, 128, 64, 64, 128, 2}); | |||
| uint4_int4_nchw64_imma.emplace_back( | |||
| AlgoParam{128, 64, 64, 64, 64, 64, 1}); | |||
| } | |||
| { | |||
| using AlgoParam = AlgoInt4Int4NHWCIMMAImplicitGemm::AlgoParam; | |||
| int4_int4_nhwc_imma.emplace_back( | |||
| AlgoParam{128, 32, 64, 64, 32, 64, 32}); | |||
| AlgoParam{128, 32, 64, 64, 32, 64, 1, 32}); | |||
| int4_int4_nhwc_imma.emplace_back( | |||
| AlgoParam{128, 32, 64, 64, 32, 64, 1, 16}); | |||
| int4_int4_nhwc_imma.emplace_back( | |||
| AlgoParam{128, 32, 64, 64, 32, 64, 1, 8}); | |||
| int4_int4_nhwc_imma.emplace_back( | |||
| AlgoParam{128, 32, 64, 64, 32, 64, 16}); | |||
| int4_int4_nhwc_imma.emplace_back(AlgoParam{128, 32, 64, 64, 32, 64, 8}); | |||
| AlgoParam{128, 64, 64, 64, 64, 64, 1, 32}); | |||
| int4_int4_nhwc_imma.emplace_back( | |||
| AlgoParam{128, 64, 64, 64, 64, 64, 32}); | |||
| AlgoParam{128, 64, 64, 64, 64, 64, 1, 16}); | |||
| int4_int4_nhwc_imma.emplace_back( | |||
| AlgoParam{128, 64, 64, 64, 64, 64, 16}); | |||
| int4_int4_nhwc_imma.emplace_back(AlgoParam{128, 64, 64, 64, 64, 64, 8}); | |||
| AlgoParam{128, 64, 64, 64, 64, 64, 1, 8}); | |||
| } | |||
| { | |||
| using AlgoParam = AlgoUInt4Int4NHWCIMMAImplicitGemm::AlgoParam; | |||
| uint4_int4_nhwc_imma.emplace_back( | |||
| AlgoParam{128, 32, 64, 64, 32, 64, 32}); | |||
| AlgoParam{128, 32, 64, 64, 32, 64, 1, 32}); | |||
| uint4_int4_nhwc_imma.emplace_back( | |||
| AlgoParam{128, 32, 64, 64, 32, 64, 16}); | |||
| AlgoParam{128, 32, 64, 64, 32, 64, 1, 16}); | |||
| uint4_int4_nhwc_imma.emplace_back( | |||
| AlgoParam{128, 32, 64, 64, 32, 64, 8}); | |||
| AlgoParam{128, 32, 64, 64, 32, 64, 1, 8}); | |||
| uint4_int4_nhwc_imma.emplace_back( | |||
| AlgoParam{128, 64, 64, 64, 64, 64, 32}); | |||
| AlgoParam{128, 64, 64, 64, 64, 64, 1, 32}); | |||
| uint4_int4_nhwc_imma.emplace_back( | |||
| AlgoParam{128, 64, 64, 64, 64, 64, 16}); | |||
| AlgoParam{128, 64, 64, 64, 64, 64, 1, 16}); | |||
| uint4_int4_nhwc_imma.emplace_back( | |||
| AlgoParam{128, 64, 64, 64, 64, 64, 8}); | |||
| AlgoParam{128, 64, 64, 64, 64, 64, 1, 8}); | |||
| } | |||
| #endif | |||
| } | |||
| @@ -279,10 +291,8 @@ void ConvBiasForwardImpl::AlgoPack::fill_dp4a_algos() { | |||
| int8_nchw4_dotprod.emplace_back(AlgoParam{64, 128, 32, 64, 32, 32, 2}); | |||
| int8_nchw4_dotprod.emplace_back(AlgoParam{32, 128, 32, 32, 64, 32, 2}); | |||
| int8_nchw4_dotprod.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 2}); | |||
| int8_nchw4_dotprod.emplace_back(AlgoParam{64, 64, 32, 64, 32, 32, 2}); | |||
| int8_nchw4_dotprod.emplace_back(AlgoParam{32, 64, 32, 32, 64, 32, 2}); | |||
| int8_nchw4_dotprod.emplace_back(AlgoParam{64, 32, 32, 64, 32, 32, 2}); | |||
| int8_nchw4_dotprod.emplace_back(AlgoParam{32, 32, 32, 32, 32, 32, 2}); | |||
| int8_nchw4_dotprod.emplace_back(AlgoParam{16, 128, 16, 16, 128, 16, 1}); | |||
| int8_nchw4_dotprod.emplace_back(AlgoParam{16, 64, 8, 16, 64, 8, 2}); | |||
| } | |||
| @@ -723,6 +723,7 @@ public: | |||
| int warp_m; | |||
| int warp_n; | |||
| int warp_k; | |||
| int stage; | |||
| }; | |||
| AlgoInt8NCHW32IMMAImplicitGemm(AlgoParam algo_param) | |||
| : m_algo_param{algo_param} { | |||
| @@ -770,6 +771,7 @@ public: | |||
| int warp_m; | |||
| int warp_n; | |||
| int warp_k; | |||
| int stage; | |||
| }; | |||
| AlgoInt4NCHW64IMMAImplicitGemmBase(AlgoParam algo_param) | |||
| @@ -897,6 +899,7 @@ public: | |||
| int warp_m; | |||
| int warp_n; | |||
| int warp_k; | |||
| int stage; | |||
| int access_size; | |||
| }; | |||
| @@ -38,7 +38,7 @@ void do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32( | |||
| const convolution::ConvParam& param, uint32_t nonlinear_mode, | |||
| float alpha, float beta, float gamma, float scale, | |||
| const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | |||
| cudaStream_t stream); | |||
| int stages, cudaStream_t stream); | |||
| template <bool NeedLoadFromConstMem> | |||
| void do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4( | |||
| @@ -47,7 +47,7 @@ void do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4( | |||
| const convolution::ConvParam& param, uint32_t nonlinear_mode, | |||
| float alpha, float beta, float gamma, float scale, | |||
| const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | |||
| cudaStream_t stream); | |||
| int stages, cudaStream_t stream); | |||
| template <bool NeedLoadFromConstMem> | |||
| void do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4( | |||
| @@ -83,7 +83,7 @@ void do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64( | |||
| const convolution::ConvParam& param, uint32_t nonlinear_mode, | |||
| float alpha, float beta, float gamma, float scale, | |||
| const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | |||
| cudaStream_t stream); | |||
| int stages, cudaStream_t stream); | |||
| template <bool NeedLoadFromConstMem> | |||
| void do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64( | |||
| @@ -92,7 +92,7 @@ void do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64( | |||
| const convolution::ConvParam& param, uint32_t nonlinear_mode, | |||
| float alpha, float beta, float gamma, float delta, float theta, | |||
| float scale, uint8_t src_zero_point, const GemmCoord& threadblock_shape, | |||
| const GemmCoord& warp_shape, cudaStream_t stream); | |||
| const GemmCoord& warp_shape, int stages, cudaStream_t stream); | |||
| template <bool signedness> | |||
| void do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nhwc( | |||
| @@ -110,7 +110,7 @@ void do_conv_bias_int4_int4_implicit_gemm_imma_nhwc( | |||
| const convolution::ConvParam& param, uint32_t nonlinear_mode, | |||
| float alpha, float beta, float gamma, float scale, | |||
| const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | |||
| const int32_t access_size, cudaStream_t stream); | |||
| const int32_t access_size, int stages, cudaStream_t stream); | |||
| template <bool NeedLoadFromConstMem> | |||
| void do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc( | |||
| @@ -119,7 +119,7 @@ void do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc( | |||
| const convolution::ConvParam& param, uint32_t nonlinear_mode, | |||
| float alpha, float beta, float gamma, float delta, float theta, | |||
| float scale, uint8_t src_zero_point, const GemmCoord& threadblock_shape, | |||
| const GemmCoord& warp_shape, const int32_t access_size, | |||
| const GemmCoord& warp_shape, const int32_t access_size, int stages, | |||
| cudaStream_t stream); | |||
| } // namespace cutlass_wrapper | |||
| @@ -0,0 +1,595 @@ | |||
| /** | |||
| * \file dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| // ignore warning of cutlass | |||
| #pragma GCC diagnostic push | |||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||
| #if !MEGDNN_TEGRA_X1 | |||
| #include "cutlass/convolution/device/convolution.h" | |||
| #endif | |||
| #include "src/common/opr_param_defs_enumv.cuh" | |||
| #include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | |||
| #pragma GCC diagnostic pop | |||
| using namespace megdnn; | |||
| using namespace cuda; | |||
| using namespace cutlass_wrapper; | |||
| /* ====== cutlass kernel wrapper for int4 x int4 nchw64 layout ====== */ | |||
| #if MEGDNN_TEGRA_X1 | |||
| template <bool NeedLoadFromConstMem> | |||
| void megdnn::cuda::cutlass_wrapper:: | |||
| do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64( | |||
| const int8_t* /* d_src */, const int8_t* /* d_filter */, | |||
| const int32_t* /* d_bias */, const int8_t* /* d_z */, | |||
| int8_t* /* d_dst */, int* /* workspace */, | |||
| const convolution::ConvParam& /* param */, | |||
| uint32_t /* nonlinear_mode */, float /* alpha */, | |||
| float /* beta */, float /* gamma */, float /* scale */, | |||
| const GemmCoord& /* threadblock_shape */, | |||
| const GemmCoord& /* warp_shape */, int /* stages */, | |||
| cudaStream_t /* stream */) {} | |||
| #else | |||
| template <bool NeedLoadFromConstMem> | |||
| void megdnn::cuda::cutlass_wrapper:: | |||
| do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64( | |||
| const int8_t* d_src, const int8_t* d_filter, | |||
| const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, | |||
| int* workspace, const convolution::ConvParam& param, | |||
| uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||
| float scale, const GemmCoord& threadblock_shape, | |||
| const GemmCoord& warp_shape, int stages, cudaStream_t stream) { | |||
| #define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||
| threadblock_k_, warp_m_, warp_n_, \ | |||
| warp_k_, stage_) \ | |||
| if (threadblock_shape.m() == threadblock_m_ && \ | |||
| threadblock_shape.n() == threadblock_n_ && \ | |||
| threadblock_shape.k() == threadblock_k_ && \ | |||
| warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||
| warp_shape.k() == warp_k_ && stages == stage_) { \ | |||
| using ThreadBlockShape = \ | |||
| cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||
| threadblock_k_>; \ | |||
| using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \ | |||
| using Convolution = cutlass::conv::device::Convolution< \ | |||
| cutlass::int4b_t, cutlass::layout::TensorNCxHWx<64>, \ | |||
| cutlass::int4b_t, cutlass::layout::TensorCxRSKx<64>, \ | |||
| ElementOutput, cutlass::layout::TensorNCxHWx<64>, int32_t, \ | |||
| cutlass::layout::TensorNCxHWx<64>, int32_t, \ | |||
| cutlass::conv::ConvType::kConvolution, \ | |||
| cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||
| cutlass::conv::threadblock:: \ | |||
| ConvolutionFpropTransThreadblockSwizzle, \ | |||
| stage_, 32, 32, NeedLoadFromConstMem, \ | |||
| cutlass::arch::OpMultiplyAddSaturate, \ | |||
| cutlass::conv::ImplicitGemmMode::GEMM_TN, true>; \ | |||
| typename Convolution::ConvolutionParameter conv_param( \ | |||
| param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||
| param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||
| param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||
| return cutlass_convolution_wrapper<Convolution>( \ | |||
| reinterpret_cast<const cutlass::int4b_t*>(d_src), \ | |||
| reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \ | |||
| reinterpret_cast<const cutlass::int4b_t*>(d_z), \ | |||
| reinterpret_cast<cutlass::int4b_t*>(d_dst), workspace, \ | |||
| conv_param, epilogue, stream); \ | |||
| } | |||
| #define DISPATCH_KERNEL \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 128, 64, 64, 128, 2); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 128, 64, 64, 128, 2); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 128, 64, 64, 128, 2); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1); \ | |||
| megdnn_assert(false, \ | |||
| "unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||
| "(%dx%dx%d)", \ | |||
| threadblock_shape.m(), threadblock_shape.n(), \ | |||
| threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||
| warp_shape.k()); | |||
| using ElementOutput = cutlass::int4b_t; | |||
| using ElementAccumulator = int32_t; | |||
| using ElementBias = int32_t; | |||
| using ElementCompute = float; | |||
| using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; | |||
| switch (nonlinear_mode) { | |||
| case NonlineMode::IDENTITY: { | |||
| using EpilogueOp = | |||
| cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||
| ElementOutput, 16, ElementAccumulator, ElementBias, | |||
| ElementCompute>; | |||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma}; | |||
| DISPATCH_KERNEL; | |||
| } | |||
| case NonlineMode::RELU: { | |||
| using EpilogueOp = cutlass::epilogue::thread:: | |||
| BiasAddLinearCombinationReluClamp< | |||
| ElementOutput, 16, ElementAccumulator, ElementBias, | |||
| ElementCompute>; | |||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0}; | |||
| DISPATCH_KERNEL; | |||
| } | |||
| case NonlineMode::H_SWISH: { | |||
| using EpilogueOp = cutlass::epilogue::thread:: | |||
| BiasAddLinearCombinationHSwishClamp< | |||
| ElementOutput, 16, ElementAccumulator, ElementBias, | |||
| ElementCompute>; | |||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale}; | |||
| DISPATCH_KERNEL; | |||
| } | |||
| default: | |||
| megdnn_assert(false, | |||
| "unsupported nonlinear mode for conv bias operator"); | |||
| } | |||
| #undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||
| #undef DISPATCH_KERNEL | |||
| } | |||
| #endif | |||
| #define INST(need_load_from_const_mem) \ | |||
| template void megdnn::cuda::cutlass_wrapper:: \ | |||
| do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64< \ | |||
| need_load_from_const_mem>( \ | |||
| const int8_t* d_src, const int8_t* d_filter, \ | |||
| const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \ | |||
| int* workspace, const convolution::ConvParam& param, \ | |||
| uint32_t nonlinear_mode, float alpha, float beta, \ | |||
| float gamma, float scale, \ | |||
| const GemmCoord& threadblock_shape, \ | |||
| const GemmCoord& warp_shape, int stages, \ | |||
| cudaStream_t stream); | |||
| INST(true); | |||
| #undef INST | |||
| /* ====== cutlass kernel wrapper for uint4 x int4 nchw64 layout ====== */ | |||
| #if MEGDNN_TEGRA_X1 | |||
| template <bool NeedLoadFromConstMem> | |||
| void megdnn::cuda::cutlass_wrapper:: | |||
| do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64( | |||
| const uint8_t* /* d_src */, const int8_t* /* d_filter */, | |||
| const int32_t* /* d_bias */, const uint8_t* /* d_z */, | |||
| uint8_t* /* d_dst */, int* /* workspace */, | |||
| const convolution::ConvParam& /* param */, | |||
| uint32_t /* nonlinear_mode */, float /* alpha */, | |||
| float /* beta */, float /* gamma */, float /* delta */, | |||
| float /* theta */, float /* scale */, | |||
| uint8_t /* src_zero_point */, | |||
| const GemmCoord& /* threadblock_shape */, | |||
| const GemmCoord& /* warp_shape */, int /* stages */, | |||
| cudaStream_t /* stream */) {} | |||
| #else | |||
| template <bool NeedLoadFromConstMem> | |||
| void megdnn::cuda::cutlass_wrapper:: | |||
| do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64( | |||
| const uint8_t* d_src, const int8_t* d_filter, | |||
| const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, | |||
| int* workspace, const convolution::ConvParam& param, | |||
| uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||
| float delta, float theta, float /* scale */, | |||
| uint8_t src_zero_point, const GemmCoord& threadblock_shape, | |||
| const GemmCoord& warp_shape, int stages, cudaStream_t stream) { | |||
| #define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||
| threadblock_k_, warp_m_, warp_n_, \ | |||
| warp_k_, stage_) \ | |||
| if (threadblock_shape.m() == threadblock_m_ && \ | |||
| threadblock_shape.n() == threadblock_n_ && \ | |||
| threadblock_shape.k() == threadblock_k_ && \ | |||
| warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||
| warp_shape.k() == warp_k_ && stages == stage_) { \ | |||
| using ThreadBlockShape = \ | |||
| cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||
| threadblock_k_>; \ | |||
| using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \ | |||
| using Convolution = cutlass::conv::device::Convolution< \ | |||
| cutlass::uint4b_t, cutlass::layout::TensorNCxHWx<64>, \ | |||
| cutlass::int4b_t, cutlass::layout::TensorCxRSKx<64>, \ | |||
| ElementOutput, cutlass::layout::TensorNCxHWx<64>, int32_t, \ | |||
| cutlass::layout::TensorNCxHWx<64>, int32_t, \ | |||
| cutlass::conv::ConvType::kConvolution, \ | |||
| cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||
| cutlass::conv::threadblock:: \ | |||
| ConvolutionFpropTransThreadblockSwizzle, \ | |||
| stage_, 32, 32, NeedLoadFromConstMem, \ | |||
| cutlass::arch::OpMultiplyAddSaturate, \ | |||
| cutlass::conv::ImplicitGemmMode::GEMM_TN, true>; \ | |||
| typename Convolution::ConvolutionParameter conv_param( \ | |||
| param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||
| param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||
| param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||
| return cutlass_convolution_wrapper<Convolution>( \ | |||
| reinterpret_cast<const cutlass::uint4b_t*>(d_src), \ | |||
| reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \ | |||
| reinterpret_cast<const cutlass::uint4b_t*>(d_z), \ | |||
| reinterpret_cast<cutlass::uint4b_t*>(d_dst), workspace, \ | |||
| conv_param, epilogue, stream, {src_zero_point}); \ | |||
| } | |||
| #define DISPATCH_KERNEL \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 128, 64, 64, 128, 2); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 128, 64, 64, 128, 2); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 128, 64, 64, 128, 2); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1); \ | |||
| megdnn_assert(false, \ | |||
| "unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||
| "(%dx%dx%d)", \ | |||
| threadblock_shape.m(), threadblock_shape.n(), \ | |||
| threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||
| warp_shape.k()); | |||
| using ElementOutput = cutlass::uint4b_t; | |||
| using ElementAccumulator = int32_t; | |||
| using ElementBias = int32_t; | |||
| using ElementCompute = float; | |||
| using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; | |||
| switch (nonlinear_mode) { | |||
| case NonlineMode::IDENTITY: { | |||
| using EpilogueOp = | |||
| cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||
| ElementOutput, 16, ElementAccumulator, ElementBias, | |||
| ElementCompute>; | |||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, | |||
| delta + theta}; | |||
| DISPATCH_KERNEL; | |||
| } | |||
| case NonlineMode::RELU: { | |||
| using EpilogueOp = cutlass::epilogue::thread:: | |||
| BiasAddLinearCombinationReluClamp< | |||
| ElementOutput, 16, ElementAccumulator, ElementBias, | |||
| ElementCompute>; | |||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, | |||
| 0, delta, theta}; | |||
| DISPATCH_KERNEL; | |||
| } | |||
| default: | |||
| megdnn_assert(false, | |||
| "unsupported nonlinear mode for conv bias operator"); | |||
| } | |||
| #undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||
| #undef DISPATCH_KERNEL | |||
| } | |||
| #endif | |||
| #define INST(need_load_from_const_mem) \ | |||
| template void megdnn::cuda::cutlass_wrapper:: \ | |||
| do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64< \ | |||
| need_load_from_const_mem>( \ | |||
| const uint8_t* d_src, const int8_t* d_filter, \ | |||
| const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, \ | |||
| int* workspace, const convolution::ConvParam& param, \ | |||
| uint32_t nonlinear_mode, float alpha, float beta, \ | |||
| float gamma, float delta, float theta, float scale, \ | |||
| uint8_t src_zero_point, \ | |||
| const GemmCoord& threadblock_shape, \ | |||
| const GemmCoord& warp_shape, int stages, \ | |||
| cudaStream_t stream); | |||
| INST(true); | |||
| #undef INST | |||
| /* ====== cutlass kernel wrapper for int4 x int4 nhwc layout ====== */ | |||
| #if MEGDNN_TEGRA_X1 | |||
| template <bool NeedLoadFromConstMem> | |||
| void megdnn::cuda::cutlass_wrapper:: | |||
| do_conv_bias_int4_int4_implicit_gemm_imma_nhwc( | |||
| const int8_t* /* d_src */, const int8_t* /* d_filter */, | |||
| const int32_t* /* d_bias */, const int8_t* /* d_z */, | |||
| int8_t* /* d_dst */, int* /* workspace */, | |||
| const convolution::ConvParam& /* param */, | |||
| uint32_t /* nonlinear_mode */, float /* alpha */, | |||
| float /* beta */, float /* gamma */, float /* scale */, | |||
| const GemmCoord& /* threadblock_shape */, | |||
| const GemmCoord& /* warp_shape */, | |||
| const int32_t /* access_size */, int /* stages */, | |||
| cudaStream_t /* stream */) {} | |||
| #else | |||
| template <bool NeedLoadFromConstMem> | |||
| void megdnn::cuda::cutlass_wrapper:: | |||
| do_conv_bias_int4_int4_implicit_gemm_imma_nhwc( | |||
| const int8_t* d_src, const int8_t* d_filter, | |||
| const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, | |||
| int* workspace, const convolution::ConvParam& param, | |||
| uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||
| float scale, const GemmCoord& threadblock_shape, | |||
| const GemmCoord& warp_shape, const int32_t access_size, | |||
| int stages, cudaStream_t stream) { | |||
| bool without_shared_load = | |||
| ((param.co % threadblock_shape.n() == 0) && | |||
| (threadblock_shape.n() == 32 || threadblock_shape.n() == 64)); | |||
| int out_elements_per_access = | |||
| without_shared_load ? threadblock_shape.n() / 4 : 8; | |||
| #define RUN_CUTLASS_WRAPPER(stage_, access_size_, without_shared_load_) \ | |||
| using Convolution = cutlass::conv::device::Convolution< \ | |||
| cutlass::int4b_t, cutlass::layout::TensorNHWC, cutlass::int4b_t, \ | |||
| cutlass::layout::TensorNCxHWx<access_size_>, ElementOutput, \ | |||
| cutlass::layout::TensorNHWC, int32_t, cutlass::layout::TensorNHWC, \ | |||
| int32_t, cutlass::conv::ConvType::kConvolution, \ | |||
| cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||
| cutlass::conv::threadblock:: \ | |||
| ConvolutionFpropTransThreadblockSwizzle, \ | |||
| stage_, access_size_, access_size_, NeedLoadFromConstMem, \ | |||
| cutlass::arch::OpMultiplyAddSaturate, \ | |||
| cutlass::conv::ImplicitGemmMode::GEMM_TN, without_shared_load_>; \ | |||
| typename Convolution::ConvolutionParameter conv_param( \ | |||
| param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||
| param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||
| param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||
| return cutlass_convolution_wrapper<Convolution>( \ | |||
| reinterpret_cast<const cutlass::int4b_t*>(d_src), \ | |||
| reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \ | |||
| reinterpret_cast<const cutlass::int4b_t*>(d_z), \ | |||
| reinterpret_cast<cutlass::int4b_t*>(d_dst), workspace, conv_param, \ | |||
| epilogue, stream); | |||
| #define DISPATCH_KERNEL_WITH_TILE_SHAPE( \ | |||
| threadblock_m_, threadblock_n_, threadblock_k_, warp_m_, warp_n_, \ | |||
| warp_k_, stage_, access_size_, out_elements_per_access_, \ | |||
| without_shared_load_) \ | |||
| if (threadblock_shape.m() == threadblock_m_ && \ | |||
| threadblock_shape.n() == threadblock_n_ && \ | |||
| threadblock_shape.k() == threadblock_k_ && \ | |||
| warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||
| warp_shape.k() == warp_k_ && stages == stage_ && \ | |||
| access_size == access_size_ && \ | |||
| out_elements_per_access == out_elements_per_access_ && \ | |||
| without_shared_load == without_shared_load_) { \ | |||
| using ThreadBlockShape = \ | |||
| cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||
| threadblock_k_>; \ | |||
| using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \ | |||
| using ElementOutput = cutlass::int4b_t; \ | |||
| using ElementAccumulator = int32_t; \ | |||
| using ElementBias = int32_t; \ | |||
| using ElementCompute = float; \ | |||
| using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; \ | |||
| switch (nonlinear_mode) { \ | |||
| case NonlineMode::IDENTITY: { \ | |||
| using EpilogueOp = cutlass::epilogue::thread:: \ | |||
| BiasAddLinearCombinationClamp< \ | |||
| ElementOutput, out_elements_per_access_, \ | |||
| ElementAccumulator, ElementBias, \ | |||
| ElementCompute>; \ | |||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma}; \ | |||
| RUN_CUTLASS_WRAPPER(stage_, access_size_, \ | |||
| without_shared_load_); \ | |||
| } \ | |||
| case NonlineMode::RELU: { \ | |||
| using EpilogueOp = cutlass::epilogue::thread:: \ | |||
| BiasAddLinearCombinationReluClamp< \ | |||
| ElementOutput, out_elements_per_access_, \ | |||
| ElementAccumulator, ElementBias, \ | |||
| ElementCompute>; \ | |||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0}; \ | |||
| RUN_CUTLASS_WRAPPER(stage_, access_size_, \ | |||
| without_shared_load_); \ | |||
| } \ | |||
| case NonlineMode::H_SWISH: { \ | |||
| using EpilogueOp = cutlass::epilogue::thread:: \ | |||
| BiasAddLinearCombinationHSwishClamp< \ | |||
| ElementOutput, out_elements_per_access_, \ | |||
| ElementAccumulator, ElementBias, \ | |||
| ElementCompute>; \ | |||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, \ | |||
| scale}; \ | |||
| RUN_CUTLASS_WRAPPER(stage_, access_size_, \ | |||
| without_shared_load_); \ | |||
| } \ | |||
| default: \ | |||
| megdnn_assert( \ | |||
| false, \ | |||
| "unsupported nonlinear mode for conv bias operator"); \ | |||
| } \ | |||
| } | |||
| #define DISPATCH_KERNEL \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 32, 8, false); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 16, 8, false); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 8, 8, false); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 32, 8, false); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 16, 8, false); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 8, 8, false); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 32, 8, true); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 16, 8, true); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 8, 8, true); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 32, 16, true); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 16, 16, true); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 8, 16, true); \ | |||
| megdnn_assert(false, \ | |||
| "unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||
| "(%dx%dx%d) and access_size (%d)", \ | |||
| threadblock_shape.m(), threadblock_shape.n(), \ | |||
| threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||
| warp_shape.k(), access_size); | |||
| DISPATCH_KERNEL; | |||
| #undef RUN_CUTLASS_WRAPPER | |||
| #undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||
| #undef DISPATCH_KERNEL | |||
| } | |||
| #endif | |||
| #define INST(need_load_from_const_mem) \ | |||
| template void megdnn::cuda::cutlass_wrapper:: \ | |||
| do_conv_bias_int4_int4_implicit_gemm_imma_nhwc< \ | |||
| need_load_from_const_mem>( \ | |||
| const int8_t* d_src, const int8_t* d_filter, \ | |||
| const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \ | |||
| int* workspace, const convolution::ConvParam& param, \ | |||
| uint32_t nonlinear_mode, float alpha, float beta, \ | |||
| float gamma, float scale, \ | |||
| const GemmCoord& threadblock_shape, \ | |||
| const GemmCoord& warp_shape, const int32_t access_size, \ | |||
| int stages, cudaStream_t stream); | |||
| INST(true); | |||
| INST(false); | |||
| #undef INST | |||
| /* ====== cutlass kernel wrapper for uint4 x int4 nhwc layout ====== */ | |||
| #if MEGDNN_TEGRA_X1 | |||
| template <bool NeedLoadFromConstMem> | |||
| void megdnn::cuda::cutlass_wrapper:: | |||
| do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc( | |||
| const uint8_t* /* d_src */, const int8_t* /* d_filter */, | |||
| const int32_t* /* d_bias */, const uint8_t* /* d_z */, | |||
| uint8_t* /* d_dst */, int* /* workspace */, | |||
| const convolution::ConvParam& /* param */, | |||
| uint32_t /* nonlinear_mode */, float /* alpha */, | |||
| float /* beta */, float /* gamma */, float /* delta */, | |||
| float /* theta */, float /* scale */, | |||
| uint8_t /* src_zero_point */, | |||
| const GemmCoord& /* threadblock_shape */, | |||
| const GemmCoord& /* warp_shape */, | |||
| const int32_t /* access_size */, int /* stages */, | |||
| cudaStream_t /* stream */) {} | |||
| #else | |||
| template <bool NeedLoadFromConstMem> | |||
| void megdnn::cuda::cutlass_wrapper:: | |||
| do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc( | |||
| const uint8_t* d_src, const int8_t* d_filter, | |||
| const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, | |||
| int* workspace, const convolution::ConvParam& param, | |||
| uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||
| float delta, float theta, float /* scale */, | |||
| uint8_t src_zero_point, const GemmCoord& threadblock_shape, | |||
| const GemmCoord& warp_shape, const int32_t access_size, | |||
| int stages, cudaStream_t stream) { | |||
| bool without_shared_load = | |||
| ((param.co % threadblock_shape.n() == 0) && | |||
| (threadblock_shape.n() == 32 || threadblock_shape.n() == 64)); | |||
| int out_elements_per_access = | |||
| without_shared_load ? threadblock_shape.n() / 4 : 8; | |||
| #define RUN_CUTLASS_WRAPPER(stage_, access_size_, without_shared_load_) \ | |||
| using Convolution = cutlass::conv::device::Convolution< \ | |||
| cutlass::uint4b_t, cutlass::layout::TensorNHWC, cutlass::int4b_t, \ | |||
| cutlass::layout::TensorNCxHWx<access_size_>, ElementOutput, \ | |||
| cutlass::layout::TensorNHWC, int32_t, cutlass::layout::TensorNHWC, \ | |||
| int32_t, cutlass::conv::ConvType::kConvolution, \ | |||
| cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||
| cutlass::conv::threadblock:: \ | |||
| ConvolutionFpropTransThreadblockSwizzle, \ | |||
| stage_, access_size_, access_size_, NeedLoadFromConstMem, \ | |||
| cutlass::arch::OpMultiplyAddSaturate, \ | |||
| cutlass::conv::ImplicitGemmMode::GEMM_TN, without_shared_load_>; \ | |||
| typename Convolution::ConvolutionParameter conv_param( \ | |||
| param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||
| param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||
| param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||
| return cutlass_convolution_wrapper<Convolution>( \ | |||
| reinterpret_cast<const cutlass::uint4b_t*>(d_src), \ | |||
| reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \ | |||
| reinterpret_cast<const cutlass::uint4b_t*>(d_z), \ | |||
| reinterpret_cast<cutlass::uint4b_t*>(d_dst), workspace, \ | |||
| conv_param, epilogue, stream, {src_zero_point}); | |||
| #define DISPATCH_KERNEL_WITH_TILE_SHAPE( \ | |||
| threadblock_m_, threadblock_n_, threadblock_k_, warp_m_, warp_n_, \ | |||
| warp_k_, stage_, access_size_, out_elements_per_access_, \ | |||
| without_shared_load_) \ | |||
| if (threadblock_shape.m() == threadblock_m_ && \ | |||
| threadblock_shape.n() == threadblock_n_ && \ | |||
| threadblock_shape.k() == threadblock_k_ && \ | |||
| warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||
| warp_shape.k() == warp_k_ && stages == stage_ && \ | |||
| access_size == access_size_ && \ | |||
| out_elements_per_access == out_elements_per_access_ && \ | |||
| without_shared_load == without_shared_load_) { \ | |||
| using ThreadBlockShape = \ | |||
| cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||
| threadblock_k_>; \ | |||
| using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \ | |||
| using ElementOutput = cutlass::uint4b_t; \ | |||
| using ElementAccumulator = int32_t; \ | |||
| using ElementBias = int32_t; \ | |||
| using ElementCompute = float; \ | |||
| using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; \ | |||
| switch (nonlinear_mode) { \ | |||
| case NonlineMode::IDENTITY: { \ | |||
| using EpilogueOp = cutlass::epilogue::thread:: \ | |||
| BiasAddLinearCombinationClamp< \ | |||
| ElementOutput, out_elements_per_access_, \ | |||
| ElementAccumulator, ElementBias, \ | |||
| ElementCompute>; \ | |||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, \ | |||
| delta + theta}; \ | |||
| RUN_CUTLASS_WRAPPER(stage_, access_size_, \ | |||
| without_shared_load_); \ | |||
| } \ | |||
| case NonlineMode::RELU: { \ | |||
| using EpilogueOp = cutlass::epilogue::thread:: \ | |||
| BiasAddLinearCombinationReluClamp< \ | |||
| ElementOutput, out_elements_per_access_, \ | |||
| ElementAccumulator, ElementBias, \ | |||
| ElementCompute>; \ | |||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, \ | |||
| 0, delta, theta}; \ | |||
| RUN_CUTLASS_WRAPPER(stage_, access_size_, \ | |||
| without_shared_load_); \ | |||
| } \ | |||
| default: \ | |||
| megdnn_assert( \ | |||
| false, \ | |||
| "unsupported nonlinear mode for conv bias operator"); \ | |||
| } \ | |||
| } | |||
| #define DISPATCH_KERNEL \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 32, 8, false); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 16, 8, false); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 8, 8, false); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 32, 8, false); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 16, 8, false); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 8, 8, false); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 32, 8, true); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 16, 8, true); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 8, 8, true); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 32, 16, true); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 16, 16, true); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 8, 16, true); \ | |||
| megdnn_assert(false, \ | |||
| "unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||
| "(%dx%dx%d) and access_size (%d)", \ | |||
| threadblock_shape.m(), threadblock_shape.n(), \ | |||
| threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||
| warp_shape.k(), access_size); | |||
| DISPATCH_KERNEL; | |||
| #undef RUN_CUTLASS_WRAPPER | |||
| #undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||
| #undef DISPATCH_KERNEL | |||
| } | |||
| #endif | |||
| #define INST(need_load_from_const_mem) \ | |||
| template void megdnn::cuda::cutlass_wrapper:: \ | |||
| do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc< \ | |||
| need_load_from_const_mem>( \ | |||
| const uint8_t* d_src, const int8_t* d_filter, \ | |||
| const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, \ | |||
| int* workspace, const convolution::ConvParam& param, \ | |||
| uint32_t nonlinear_mode, float alpha, float beta, \ | |||
| float gamma, float delta, float theta, float scale, \ | |||
| uint8_t src_zero_point, \ | |||
| const GemmCoord& threadblock_shape, \ | |||
| const GemmCoord& warp_shape, const int32_t access_size, \ | |||
| int stages, cudaStream_t stream); | |||
| INST(true); | |||
| INST(false); | |||
| #undef INST | |||
| // vim: syntax=cuda.doxygen | |||
| @@ -38,7 +38,8 @@ void megdnn::cuda::cutlass_wrapper:: | |||
| uint32_t /* nonlinear_mode */, float /* alpha */, | |||
| float /* beta */, float /* gamma */, float /* scale */, | |||
| const GemmCoord& /* threadblock_shape */, | |||
| const GemmCoord& /* warp_shape */, cudaStream_t /* stream */) {} | |||
| const GemmCoord& /* warp_shape */, int /* stages */, | |||
| cudaStream_t /* stream */) {} | |||
| #else | |||
| template <bool NeedLoadFromConstMem> | |||
| void megdnn::cuda::cutlass_wrapper:: | |||
| @@ -48,15 +49,15 @@ void megdnn::cuda::cutlass_wrapper:: | |||
| int* workspace, const convolution::ConvParam& param, | |||
| uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||
| float scale, const GemmCoord& threadblock_shape, | |||
| const GemmCoord& warp_shape, cudaStream_t stream) { | |||
| const GemmCoord& warp_shape, int stages, cudaStream_t stream) { | |||
| #define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||
| threadblock_k_, warp_m_, warp_n_, \ | |||
| warp_k_) \ | |||
| warp_k_, stage_) \ | |||
| if (threadblock_shape.m() == threadblock_m_ && \ | |||
| threadblock_shape.n() == threadblock_n_ && \ | |||
| threadblock_shape.k() == threadblock_k_ && \ | |||
| warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||
| warp_shape.k() == warp_k_) { \ | |||
| warp_shape.k() == warp_k_ && stages == stage_) { \ | |||
| using ThreadBlockShape = \ | |||
| cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||
| threadblock_k_>; \ | |||
| @@ -71,8 +72,10 @@ void megdnn::cuda::cutlass_wrapper:: | |||
| cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||
| cutlass::conv::threadblock:: \ | |||
| ConvolutionFpropNCxHWxThreadblockSwizzle, \ | |||
| 2, 16, 16, NeedLoadFromConstMem>; \ | |||
| ConvolutionFpropTransThreadblockSwizzle, \ | |||
| stage_, 16, 16, NeedLoadFromConstMem, \ | |||
| cutlass::arch::OpMultiplyAddSaturate, \ | |||
| cutlass::conv::ImplicitGemmMode::GEMM_TN, true>; \ | |||
| typename Convolution::ConvolutionParameter conv_param( \ | |||
| param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||
| param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||
| @@ -82,13 +85,15 @@ void megdnn::cuda::cutlass_wrapper:: | |||
| epilogue, stream); \ | |||
| } | |||
| #define DISPATCH_KERNEL \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 64, 64, 64, 64); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 64, 64, 64, 64); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 64, 64, 64, 64); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 64, 32, 64, 64); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 32, 64); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 64, 32, 32, 64); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 64, 32, 16, 64); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 64, 64, 64, 64, 2); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 64, 64, 64, 64, 2); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 64, 64, 64, 64, 2); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 32, 64, 2); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 64, 32, 64, 64, 2); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 1); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 1); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 32, 64, 32, 1); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 1); \ | |||
| megdnn_assert(false, \ | |||
| "unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||
| "(%dx%dx%d)", \ | |||
| @@ -144,7 +149,8 @@ void megdnn::cuda::cutlass_wrapper:: | |||
| uint32_t nonlinear_mode, float alpha, float beta, \ | |||
| float gamma, float scale, \ | |||
| const GemmCoord& threadblock_shape, \ | |||
| const GemmCoord& warp_shape, cudaStream_t stream); | |||
| const GemmCoord& warp_shape, int stages, \ | |||
| cudaStream_t stream); | |||
| INST(true); | |||
| INST(false); | |||
| #undef INST | |||
| @@ -162,7 +168,8 @@ void megdnn::cuda::cutlass_wrapper:: | |||
| uint32_t /* nonlinear_mode */, float /* alpha */, | |||
| float /* beta */, float /* gamma */, float /* scale */, | |||
| const GemmCoord& /* threadblock_shape */, | |||
| const GemmCoord& /* warp_shape */, cudaStream_t /* stream */) {} | |||
| const GemmCoord& /* warp_shape */, int /* stages */, | |||
| cudaStream_t /* stream */) {} | |||
| #else | |||
| template <bool NeedLoadFromConstMem> | |||
| void megdnn::cuda::cutlass_wrapper:: | |||
| @@ -172,15 +179,15 @@ void megdnn::cuda::cutlass_wrapper:: | |||
| int* workspace, const convolution::ConvParam& param, | |||
| uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||
| float scale, const GemmCoord& threadblock_shape, | |||
| const GemmCoord& warp_shape, cudaStream_t stream) { | |||
| const GemmCoord& warp_shape, int stages, cudaStream_t stream) { | |||
| #define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||
| threadblock_k_, warp_m_, warp_n_, \ | |||
| warp_k_) \ | |||
| warp_k_, stage_) \ | |||
| if (threadblock_shape.m() == threadblock_m_ && \ | |||
| threadblock_shape.n() == threadblock_n_ && \ | |||
| threadblock_shape.k() == threadblock_k_ && \ | |||
| warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||
| warp_shape.k() == warp_k_) { \ | |||
| warp_shape.k() == warp_k_ && stages == stage_) { \ | |||
| using ThreadBlockShape = \ | |||
| cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||
| threadblock_k_>; \ | |||
| @@ -196,7 +203,7 @@ void megdnn::cuda::cutlass_wrapper:: | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||
| cutlass::conv::threadblock:: \ | |||
| ConvolutionFpropNCxHWxThreadblockSwizzle, \ | |||
| 2, 16, 16, NeedLoadFromConstMem>; \ | |||
| stage_, 16, 16, NeedLoadFromConstMem>; \ | |||
| typename Convolution::ConvolutionParameter conv_param( \ | |||
| param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||
| param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||
| @@ -206,13 +213,15 @@ void megdnn::cuda::cutlass_wrapper:: | |||
| epilogue, stream); \ | |||
| } | |||
| #define DISPATCH_KERNEL \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 64, 64, 64, 64); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 64, 64, 64, 64); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 64, 64, 64, 64); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 64, 32, 64, 64); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 32, 64); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 64, 32, 32, 64); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 64, 16, 32, 64); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 64, 64, 64, 64, 2); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 64, 64, 64, 64, 2); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 64, 64, 64, 64, 2); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 32, 64, 2); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 64, 32, 64, 64, 2); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 1); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 1); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 32, 64, 32, 1); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 1); \ | |||
| megdnn_assert(false, \ | |||
| "unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||
| "(%dx%dx%d)", \ | |||
| @@ -268,7 +277,8 @@ void megdnn::cuda::cutlass_wrapper:: | |||
| uint32_t nonlinear_mode, float alpha, float beta, \ | |||
| float gamma, float scale, \ | |||
| const GemmCoord& threadblock_shape, \ | |||
| const GemmCoord& warp_shape, cudaStream_t stream); | |||
| const GemmCoord& warp_shape, int stages, \ | |||
| cudaStream_t stream); | |||
| INST(true); | |||
| INST(false); | |||
| #undef INST | |||
| @@ -337,10 +347,8 @@ void megdnn::cuda::cutlass_wrapper:: | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 32, 64, 32, 32, 2, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 32, 32, 32, 32, 32, 2, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 128, 16, 1, 8); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 2, 4); \ | |||
| megdnn_assert(false, \ | |||
| @@ -468,10 +476,8 @@ void megdnn::cuda::cutlass_wrapper:: | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 32, 64, 32, 32, 2, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 32, 32, 32, 32, 32, 2, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 128, 16, 1, 8); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 2, 4); \ | |||
| megdnn_assert(false, \ | |||
| @@ -599,10 +605,8 @@ void megdnn::cuda::cutlass_wrapper:: | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 32, 64, 32, 32, 2, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 32, 32, 32, 32, 32, 2, 16); \ | |||
| megdnn_assert(false, \ | |||
| "unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||
| "(%dx%dx%d)", \ | |||
| @@ -664,246 +668,6 @@ INST(true); | |||
| INST(false); | |||
| #undef INST | |||
| /* ====== cutlass kernel wrapper for int4 x int4 nchw64 layout ====== */ | |||
| #if MEGDNN_TEGRA_X1 | |||
| template <bool NeedLoadFromConstMem> | |||
| void megdnn::cuda::cutlass_wrapper:: | |||
| do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64( | |||
| const int8_t* /* d_src */, const int8_t* /* d_filter */, | |||
| const int32_t* /* d_bias */, const int8_t* /* d_z */, | |||
| int8_t* /* d_dst */, int* /* workspace */, | |||
| const convolution::ConvParam& /* param */, | |||
| uint32_t /* nonlinear_mode */, float /* alpha */, | |||
| float /* beta */, float /* gamma */, float /* scale */, | |||
| const GemmCoord& /* threadblock_shape */, | |||
| const GemmCoord& /* warp_shape */, cudaStream_t /* stream */) {} | |||
| #else | |||
| template <bool NeedLoadFromConstMem> | |||
| void megdnn::cuda::cutlass_wrapper:: | |||
| do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64( | |||
| const int8_t* d_src, const int8_t* d_filter, | |||
| const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, | |||
| int* workspace, const convolution::ConvParam& param, | |||
| uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||
| float scale, const GemmCoord& threadblock_shape, | |||
| const GemmCoord& warp_shape, cudaStream_t stream) { | |||
| #define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||
| threadblock_k_, warp_m_, warp_n_, \ | |||
| warp_k_) \ | |||
| if (threadblock_shape.m() == threadblock_m_ && \ | |||
| threadblock_shape.n() == threadblock_n_ && \ | |||
| threadblock_shape.k() == threadblock_k_ && \ | |||
| warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||
| warp_shape.k() == warp_k_) { \ | |||
| using ThreadBlockShape = \ | |||
| cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||
| threadblock_k_>; \ | |||
| using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \ | |||
| using Convolution = cutlass::conv::device::Convolution< \ | |||
| cutlass::int4b_t, cutlass::layout::TensorNCxHWx<64>, \ | |||
| cutlass::int4b_t, cutlass::layout::TensorCxRSKx<64>, \ | |||
| ElementOutput, cutlass::layout::TensorNCxHWx<64>, int32_t, \ | |||
| cutlass::layout::TensorNCxHWx<64>, int32_t, \ | |||
| cutlass::conv::ConvType::kConvolution, \ | |||
| cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||
| cutlass::conv::threadblock:: \ | |||
| ConvolutionFpropNCxHWxThreadblockSwizzle, \ | |||
| 2, 32, 32, NeedLoadFromConstMem>; \ | |||
| typename Convolution::ConvolutionParameter conv_param( \ | |||
| param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||
| param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||
| param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||
| return cutlass_convolution_wrapper<Convolution>( \ | |||
| reinterpret_cast<const cutlass::int4b_t*>(d_src), \ | |||
| reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \ | |||
| reinterpret_cast<const cutlass::int4b_t*>(d_z), \ | |||
| reinterpret_cast<cutlass::int4b_t*>(d_dst), workspace, \ | |||
| conv_param, epilogue, stream); \ | |||
| } | |||
| #define DISPATCH_KERNEL \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 128, 64, 64, 128); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 128, 64, 64, 128); \ | |||
| megdnn_assert(false, \ | |||
| "unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||
| "(%dx%dx%d)", \ | |||
| threadblock_shape.m(), threadblock_shape.n(), \ | |||
| threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||
| warp_shape.k()); | |||
| using ElementOutput = cutlass::int4b_t; | |||
| using ElementAccumulator = int32_t; | |||
| using ElementBias = int32_t; | |||
| using ElementCompute = float; | |||
| using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; | |||
| switch (nonlinear_mode) { | |||
| case NonlineMode::IDENTITY: { | |||
| using EpilogueOp = | |||
| cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||
| ElementOutput, 16, ElementAccumulator, ElementBias, | |||
| ElementCompute>; | |||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma}; | |||
| DISPATCH_KERNEL; | |||
| } | |||
| case NonlineMode::RELU: { | |||
| using EpilogueOp = cutlass::epilogue::thread:: | |||
| BiasAddLinearCombinationReluClamp< | |||
| ElementOutput, 16, ElementAccumulator, ElementBias, | |||
| ElementCompute>; | |||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0}; | |||
| DISPATCH_KERNEL; | |||
| } | |||
| case NonlineMode::H_SWISH: { | |||
| using EpilogueOp = cutlass::epilogue::thread:: | |||
| BiasAddLinearCombinationHSwishClamp< | |||
| ElementOutput, 16, ElementAccumulator, ElementBias, | |||
| ElementCompute>; | |||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale}; | |||
| DISPATCH_KERNEL; | |||
| } | |||
| default: | |||
| megdnn_assert(false, | |||
| "unsupported nonlinear mode for conv bias operator"); | |||
| } | |||
| #undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||
| #undef DISPATCH_KERNEL | |||
| } | |||
| #endif | |||
| #define INST(need_load_from_const_mem) \ | |||
| template void megdnn::cuda::cutlass_wrapper:: \ | |||
| do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64< \ | |||
| need_load_from_const_mem>( \ | |||
| const int8_t* d_src, const int8_t* d_filter, \ | |||
| const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \ | |||
| int* workspace, const convolution::ConvParam& param, \ | |||
| uint32_t nonlinear_mode, float alpha, float beta, \ | |||
| float gamma, float scale, \ | |||
| const GemmCoord& threadblock_shape, \ | |||
| const GemmCoord& warp_shape, cudaStream_t stream); | |||
| INST(true); | |||
| #undef INST | |||
| /* ====== cutlass kernel wrapper for uint4 x int4 nchw64 layout ====== */ | |||
| #if MEGDNN_TEGRA_X1 | |||
| template <bool NeedLoadFromConstMem> | |||
| void megdnn::cuda::cutlass_wrapper:: | |||
| do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64( | |||
| const uint8_t* /* d_src */, const int8_t* /* d_filter */, | |||
| const int32_t* /* d_bias */, const uint8_t* /* d_z */, | |||
| uint8_t* /* d_dst */, int* /* workspace */, | |||
| const convolution::ConvParam& /* param */, | |||
| uint32_t /* nonlinear_mode */, float /* alpha */, | |||
| float /* beta */, float /* gamma */, float /* delta */, | |||
| float /* theta */, float /* scale */, | |||
| uint8_t /* src_zero_point */, | |||
| const GemmCoord& /* threadblock_shape */, | |||
| const GemmCoord& /* warp_shape */, cudaStream_t /* stream */) {} | |||
| #else | |||
| template <bool NeedLoadFromConstMem> | |||
| void megdnn::cuda::cutlass_wrapper:: | |||
| do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64( | |||
| const uint8_t* d_src, const int8_t* d_filter, | |||
| const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, | |||
| int* workspace, const convolution::ConvParam& param, | |||
| uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||
| float delta, float theta, float /* scale */, | |||
| uint8_t src_zero_point, const GemmCoord& threadblock_shape, | |||
| const GemmCoord& warp_shape, cudaStream_t stream) { | |||
| #define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||
| threadblock_k_, warp_m_, warp_n_, \ | |||
| warp_k_) \ | |||
| if (threadblock_shape.m() == threadblock_m_ && \ | |||
| threadblock_shape.n() == threadblock_n_ && \ | |||
| threadblock_shape.k() == threadblock_k_ && \ | |||
| warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||
| warp_shape.k() == warp_k_) { \ | |||
| using ThreadBlockShape = \ | |||
| cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||
| threadblock_k_>; \ | |||
| using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \ | |||
| using Convolution = cutlass::conv::device::Convolution< \ | |||
| cutlass::uint4b_t, cutlass::layout::TensorNCxHWx<64>, \ | |||
| cutlass::int4b_t, cutlass::layout::TensorCxRSKx<64>, \ | |||
| ElementOutput, cutlass::layout::TensorNCxHWx<64>, int32_t, \ | |||
| cutlass::layout::TensorNCxHWx<64>, int32_t, \ | |||
| cutlass::conv::ConvType::kConvolution, \ | |||
| cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||
| cutlass::conv::threadblock:: \ | |||
| ConvolutionFpropNCxHWxThreadblockSwizzle, \ | |||
| 2, 32, 32, NeedLoadFromConstMem>; \ | |||
| typename Convolution::ConvolutionParameter conv_param( \ | |||
| param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||
| param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||
| param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||
| return cutlass_convolution_wrapper<Convolution>( \ | |||
| reinterpret_cast<const cutlass::uint4b_t*>(d_src), \ | |||
| reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \ | |||
| reinterpret_cast<const cutlass::uint4b_t*>(d_z), \ | |||
| reinterpret_cast<cutlass::uint4b_t*>(d_dst), workspace, \ | |||
| conv_param, epilogue, stream, {src_zero_point}); \ | |||
| } | |||
| #define DISPATCH_KERNEL \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 128, 64, 64, 128); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 128, 64, 64, 128); \ | |||
| megdnn_assert(false, \ | |||
| "unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||
| "(%dx%dx%d)", \ | |||
| threadblock_shape.m(), threadblock_shape.n(), \ | |||
| threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||
| warp_shape.k()); | |||
| using ElementOutput = cutlass::uint4b_t; | |||
| using ElementAccumulator = int32_t; | |||
| using ElementBias = int32_t; | |||
| using ElementCompute = float; | |||
| using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; | |||
| switch (nonlinear_mode) { | |||
| case NonlineMode::IDENTITY: { | |||
| using EpilogueOp = | |||
| cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||
| ElementOutput, 16, ElementAccumulator, ElementBias, | |||
| ElementCompute>; | |||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, | |||
| delta + theta}; | |||
| DISPATCH_KERNEL; | |||
| } | |||
| case NonlineMode::RELU: { | |||
| using EpilogueOp = cutlass::epilogue::thread:: | |||
| BiasAddLinearCombinationReluClamp< | |||
| ElementOutput, 16, ElementAccumulator, ElementBias, | |||
| ElementCompute>; | |||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, | |||
| 0, delta, theta}; | |||
| DISPATCH_KERNEL; | |||
| } | |||
| default: | |||
| megdnn_assert(false, | |||
| "unsupported nonlinear mode for conv bias operator"); | |||
| } | |||
| #undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||
| #undef DISPATCH_KERNEL | |||
| } | |||
| #endif | |||
| #define INST(need_load_from_const_mem) \ | |||
| template void megdnn::cuda::cutlass_wrapper:: \ | |||
| do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64< \ | |||
| need_load_from_const_mem>( \ | |||
| const uint8_t* d_src, const int8_t* d_filter, \ | |||
| const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, \ | |||
| int* workspace, const convolution::ConvParam& param, \ | |||
| uint32_t nonlinear_mode, float alpha, float beta, \ | |||
| float gamma, float delta, float theta, float scale, \ | |||
| uint8_t src_zero_point, \ | |||
| const GemmCoord& threadblock_shape, \ | |||
| const GemmCoord& warp_shape, cudaStream_t stream); | |||
| INST(true); | |||
| #undef INST | |||
| /* ===== cutlass kernel wrapper for nchw4 layout and nhwc output ===== */ | |||
| #if MEGDNN_TEGRA_X1 | |||
| template <bool signedness> | |||
| @@ -970,10 +734,8 @@ void megdnn::cuda::cutlass_wrapper:: | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 32, 64, 32, 32, 2, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 32, 32, 32, 32, 32, 2, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 128, 16, 1, 8); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 2, 4); \ | |||
| megdnn_assert(false, \ | |||
| @@ -1039,262 +801,4 @@ INST(true); | |||
| INST(false); | |||
| #undef INST | |||
| /* ====== cutlass kernel wrapper for int4 x int4 nchw64 layout ====== */ | |||
| #if MEGDNN_TEGRA_X1 | |||
| template <bool NeedLoadFromConstMem> | |||
| void megdnn::cuda::cutlass_wrapper:: | |||
| do_conv_bias_int4_int4_implicit_gemm_imma_nhwc( | |||
| const int8_t* /* d_src */, const int8_t* /* d_filter */, | |||
| const int32_t* /* d_bias */, const int8_t* /* d_z */, | |||
| int8_t* /* d_dst */, int* /* workspace */, | |||
| const convolution::ConvParam& /* param */, | |||
| uint32_t /* nonlinear_mode */, float /* alpha */, | |||
| float /* beta */, float /* gamma */, float /* scale */, | |||
| const GemmCoord& /* threadblock_shape */, | |||
| const GemmCoord& /* warp_shape */, | |||
| const int32_t /* access_size */, cudaStream_t /* stream */) {} | |||
| #else | |||
| template <bool NeedLoadFromConstMem> | |||
| void megdnn::cuda::cutlass_wrapper:: | |||
| do_conv_bias_int4_int4_implicit_gemm_imma_nhwc( | |||
| const int8_t* d_src, const int8_t* d_filter, | |||
| const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, | |||
| int* workspace, const convolution::ConvParam& param, | |||
| uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||
| float scale, const GemmCoord& threadblock_shape, | |||
| const GemmCoord& warp_shape, const int32_t access_size, | |||
| cudaStream_t stream) { | |||
| #define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||
| threadblock_k_, warp_m_, warp_n_, \ | |||
| warp_k_, access_size_) \ | |||
| if (threadblock_shape.m() == threadblock_m_ && \ | |||
| threadblock_shape.n() == threadblock_n_ && \ | |||
| threadblock_shape.k() == threadblock_k_ && \ | |||
| warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||
| warp_shape.k() == warp_k_ && access_size == access_size_) { \ | |||
| using ThreadBlockShape = \ | |||
| cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||
| threadblock_k_>; \ | |||
| using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \ | |||
| using Convolution = cutlass::conv::device::Convolution< \ | |||
| cutlass::int4b_t, cutlass::layout::TensorNHWC, \ | |||
| cutlass::int4b_t, cutlass::layout::TensorNCxHWx<access_size_>, \ | |||
| ElementOutput, cutlass::layout::TensorNHWC, int32_t, \ | |||
| cutlass::layout::TensorNHWC, int32_t, \ | |||
| cutlass::conv::ConvType::kConvolution, \ | |||
| cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||
| cutlass::conv::threadblock:: \ | |||
| ConvolutionFpropNHWCThreadblockSwizzle, \ | |||
| 2, access_size_, access_size_, NeedLoadFromConstMem, \ | |||
| cutlass::arch::OpMultiplyAddSaturate, \ | |||
| cutlass::conv::ImplicitGemmMode::GEMM_TN>; \ | |||
| typename Convolution::ConvolutionParameter conv_param( \ | |||
| param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||
| param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||
| param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||
| return cutlass_convolution_wrapper<Convolution>( \ | |||
| reinterpret_cast<const cutlass::int4b_t*>(d_src), \ | |||
| reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \ | |||
| reinterpret_cast<const cutlass::int4b_t*>(d_z), \ | |||
| reinterpret_cast<cutlass::int4b_t*>(d_dst), workspace, \ | |||
| conv_param, epilogue, stream); \ | |||
| } | |||
| #define DISPATCH_KERNEL \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 32); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 8); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 32); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 8); \ | |||
| megdnn_assert(false, \ | |||
| "unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||
| "(%dx%dx%d) and access_size (%d)", \ | |||
| threadblock_shape.m(), threadblock_shape.n(), \ | |||
| threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||
| warp_shape.k(), access_size); | |||
| using ElementOutput = cutlass::int4b_t; | |||
| using ElementAccumulator = int32_t; | |||
| using ElementBias = int32_t; | |||
| using ElementCompute = float; | |||
| using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; | |||
| switch (nonlinear_mode) { | |||
| case NonlineMode::IDENTITY: { | |||
| using EpilogueOp = | |||
| cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||
| ElementOutput, 8, ElementAccumulator, ElementBias, | |||
| ElementCompute>; | |||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma}; | |||
| DISPATCH_KERNEL; | |||
| } | |||
| case NonlineMode::RELU: { | |||
| using EpilogueOp = cutlass::epilogue::thread:: | |||
| BiasAddLinearCombinationReluClamp< | |||
| ElementOutput, 8, ElementAccumulator, ElementBias, | |||
| ElementCompute>; | |||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0}; | |||
| DISPATCH_KERNEL; | |||
| } | |||
| case NonlineMode::H_SWISH: { | |||
| using EpilogueOp = cutlass::epilogue::thread:: | |||
| BiasAddLinearCombinationHSwishClamp< | |||
| ElementOutput, 8, ElementAccumulator, ElementBias, | |||
| ElementCompute>; | |||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale}; | |||
| DISPATCH_KERNEL; | |||
| } | |||
| default: | |||
| megdnn_assert(false, | |||
| "unsupported nonlinear mode for conv bias operator"); | |||
| } | |||
| #undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||
| #undef DISPATCH_KERNEL | |||
| } | |||
| #endif | |||
| #define INST(need_load_from_const_mem) \ | |||
| template void megdnn::cuda::cutlass_wrapper:: \ | |||
| do_conv_bias_int4_int4_implicit_gemm_imma_nhwc< \ | |||
| need_load_from_const_mem>( \ | |||
| const int8_t* d_src, const int8_t* d_filter, \ | |||
| const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \ | |||
| int* workspace, const convolution::ConvParam& param, \ | |||
| uint32_t nonlinear_mode, float alpha, float beta, \ | |||
| float gamma, float scale, \ | |||
| const GemmCoord& threadblock_shape, \ | |||
| const GemmCoord& warp_shape, const int32_t access_size, \ | |||
| cudaStream_t stream); | |||
| INST(true); | |||
| INST(false); | |||
| #undef INST | |||
| #if MEGDNN_TEGRA_X1 | |||
| template <bool NeedLoadFromConstMem> | |||
| void megdnn::cuda::cutlass_wrapper:: | |||
| do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc( | |||
| const uint8_t* /* d_src */, const int8_t* /* d_filter */, | |||
| const int32_t* /* d_bias */, const uint8_t* /* d_z */, | |||
| uint8_t* /* d_dst */, int* /* workspace */, | |||
| const convolution::ConvParam& /* param */, | |||
| uint32_t /* nonlinear_mode */, float /* alpha */, | |||
| float /* beta */, float /* gamma */, float /* delta */, | |||
| float /* theta */, float /* scale */, | |||
| uint8_t /* src_zero_point */, | |||
| const GemmCoord& /* threadblock_shape */, | |||
| const GemmCoord& /* warp_shape */, | |||
| const int32_t /* access_size */, cudaStream_t /* stream */) {} | |||
| #else | |||
| template <bool NeedLoadFromConstMem> | |||
| void megdnn::cuda::cutlass_wrapper:: | |||
| do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc( | |||
| const uint8_t* d_src, const int8_t* d_filter, | |||
| const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, | |||
| int* workspace, const convolution::ConvParam& param, | |||
| uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||
| float delta, float theta, float /* scale */, | |||
| uint8_t src_zero_point, const GemmCoord& threadblock_shape, | |||
| const GemmCoord& warp_shape, const int32_t access_size, | |||
| cudaStream_t stream) { | |||
| #define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||
| threadblock_k_, warp_m_, warp_n_, \ | |||
| warp_k_, access_size_) \ | |||
| if (threadblock_shape.m() == threadblock_m_ && \ | |||
| threadblock_shape.n() == threadblock_n_ && \ | |||
| threadblock_shape.k() == threadblock_k_ && \ | |||
| warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||
| warp_shape.k() == warp_k_ && access_size == access_size_) { \ | |||
| using ThreadBlockShape = \ | |||
| cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||
| threadblock_k_>; \ | |||
| using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \ | |||
| using Convolution = cutlass::conv::device::Convolution< \ | |||
| cutlass::uint4b_t, cutlass::layout::TensorNHWC, \ | |||
| cutlass::int4b_t, cutlass::layout::TensorNCxHWx<access_size_>, \ | |||
| ElementOutput, cutlass::layout::TensorNHWC, int32_t, \ | |||
| cutlass::layout::TensorNHWC, int32_t, \ | |||
| cutlass::conv::ConvType::kConvolution, \ | |||
| cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | |||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||
| cutlass::conv::threadblock:: \ | |||
| ConvolutionFpropNHWCThreadblockSwizzle, \ | |||
| 2, access_size_, access_size_, NeedLoadFromConstMem, \ | |||
| cutlass::arch::OpMultiplyAddSaturate, \ | |||
| cutlass::conv::ImplicitGemmMode::GEMM_TN>; \ | |||
| typename Convolution::ConvolutionParameter conv_param( \ | |||
| param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||
| param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||
| param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||
| return cutlass_convolution_wrapper<Convolution>( \ | |||
| reinterpret_cast<const cutlass::uint4b_t*>(d_src), \ | |||
| reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \ | |||
| reinterpret_cast<const cutlass::uint4b_t*>(d_z), \ | |||
| reinterpret_cast<cutlass::uint4b_t*>(d_dst), workspace, \ | |||
| conv_param, epilogue, stream, {src_zero_point}); \ | |||
| } | |||
| #define DISPATCH_KERNEL \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 32); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 8); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 32); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 16); \ | |||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 8); \ | |||
| megdnn_assert(false, \ | |||
| "unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||
| "(%dx%dx%d) and access_size (%d)", \ | |||
| threadblock_shape.m(), threadblock_shape.n(), \ | |||
| threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||
| warp_shape.k(), access_size); | |||
| using ElementOutput = cutlass::uint4b_t; | |||
| using ElementAccumulator = int32_t; | |||
| using ElementBias = int32_t; | |||
| using ElementCompute = float; | |||
| using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; | |||
| switch (nonlinear_mode) { | |||
| case NonlineMode::IDENTITY: { | |||
| using EpilogueOp = | |||
| cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||
| ElementOutput, 8, ElementAccumulator, ElementBias, | |||
| ElementCompute>; | |||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, | |||
| delta + theta}; | |||
| DISPATCH_KERNEL; | |||
| } | |||
| case NonlineMode::RELU: { | |||
| using EpilogueOp = cutlass::epilogue::thread:: | |||
| BiasAddLinearCombinationReluClamp< | |||
| ElementOutput, 8, ElementAccumulator, ElementBias, | |||
| ElementCompute>; | |||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, | |||
| 0, delta, theta}; | |||
| DISPATCH_KERNEL; | |||
| } | |||
| default: | |||
| megdnn_assert(false, | |||
| "unsupported nonlinear mode for conv bias operator"); | |||
| } | |||
| #undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||
| #undef DISPATCH_KERNEL | |||
| } | |||
| #endif | |||
| #define INST(need_load_from_const_mem) \ | |||
| template void megdnn::cuda::cutlass_wrapper:: \ | |||
| do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc< \ | |||
| need_load_from_const_mem>( \ | |||
| const uint8_t* d_src, const int8_t* d_filter, \ | |||
| const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, \ | |||
| int* workspace, const convolution::ConvParam& param, \ | |||
| uint32_t nonlinear_mode, float alpha, float beta, \ | |||
| float gamma, float delta, float theta, float scale, \ | |||
| uint8_t src_zero_point, \ | |||
| const GemmCoord& threadblock_shape, \ | |||
| const GemmCoord& warp_shape, const int32_t access_size, \ | |||
| cudaStream_t stream); | |||
| INST(true); | |||
| INST(false); | |||
| #undef INST | |||
| // vim: syntax=cuda.doxygen | |||
| @@ -0,0 +1,194 @@ | |||
| /** | |||
| * \file dnn/src/cuda/conv_bias/cutlass_reorder_filter.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/cuda/conv_bias/cutlass_reorder_filter.cuh" | |||
| #include "src/cuda/query_blocksize.cuh" | |||
| #include "src/cuda/integer_subbyte_utils.cuh" | |||
| using namespace megdnn; | |||
| using namespace cuda; | |||
| using namespace cutlass_wrapper; | |||
| namespace { | |||
| template <uint32_t size_bits, uint32_t interleaved> | |||
| __device__ __forceinline__ void reorder_ncxhwx_imma_filter_func( | |||
| int8_t* dst, const int8_t* src, uint32_t OC, uint32_t IC, uint32_t FH, | |||
| uint32_t FW, uint32_t lane, bool trans_oc) { | |||
| static constexpr uint32_t elements_per_lane = 128 / size_bits; | |||
| static constexpr uint32_t threads_per_interleaved = | |||
| interleaved / elements_per_lane; | |||
| static constexpr uint32_t instruction_shape_col = 8; | |||
| // 4 threads per Quad | |||
| static constexpr uint32_t elements_per_thread = instruction_shape_col / 4; | |||
| // 4 threads per Quad | |||
| static constexpr uint32_t reordered_elements_per_thread = interleaved / 4; | |||
| uint32_t id = lane / threads_per_interleaved; | |||
| uint32_t residue = lane % threads_per_interleaved; | |||
| uint32_t ICx = IC / interleaved; | |||
| uint32_t row = id / (ICx * FH * FW); | |||
| uint32_t col = id - row * ICx * FH * FW; | |||
| // transpose ncxhwx to cxhwnx | |||
| uint32_t src_offset = id * interleaved + residue * elements_per_lane; | |||
| row = (trans_oc) ? (row / interleaved) * interleaved + | |||
| ((row % reordered_elements_per_thread) / | |||
| elements_per_thread) * | |||
| instruction_shape_col + | |||
| ((row % interleaved) / | |||
| reordered_elements_per_thread) * | |||
| elements_per_thread + | |||
| (row % elements_per_thread) | |||
| : row; | |||
| uint32_t dst_offset = | |||
| (col * OC + row) * interleaved + residue * elements_per_lane; | |||
| *(reinterpret_cast<int4*>(dst + dst_offset * size_bits / 8)) = | |||
| *(reinterpret_cast<const int4*>(src + src_offset * size_bits / 8)); | |||
| } | |||
| template <uint32_t size_bits, uint32_t interleaved> | |||
| __global__ void reorder_ncxhwx_imma_filter_kernel( | |||
| int8_t* __restrict__ dst_filter, const int8_t* __restrict__ src_filter, | |||
| uint32_t OC, uint32_t IC, uint32_t FH, uint32_t FW, bool trans_oc) { | |||
| static constexpr uint32_t elements_per_lane = 128 / size_bits; | |||
| const uint32_t size = OC * IC * FH * FW / elements_per_lane; | |||
| uint32_t lane = threadIdx.x + blockIdx.x * blockDim.x; | |||
| if (lane < size) { | |||
| reorder_ncxhwx_imma_filter_func<size_bits, interleaved>( | |||
| dst_filter, src_filter, OC, IC, FH, FW, lane, trans_oc); | |||
| } | |||
| } | |||
| template <uint32_t size_bits, uint32_t alignbits, uint32_t interleaved> | |||
| __device__ __forceinline__ void reorder_nhwc_imma_filter_func( | |||
| int8_t* dst, const int8_t* src, uint32_t OC, uint32_t IC, uint32_t FH, | |||
| uint32_t FW, uint32_t lane, bool trans_oc) { | |||
| static constexpr uint32_t elements_per_access = alignbits / size_bits; | |||
| static constexpr uint32_t instruction_shape_col = 8; | |||
| // 4 threads per Quad | |||
| static constexpr uint32_t elements_per_thread = instruction_shape_col / 4; | |||
| // 4 threads per Quad | |||
| static constexpr uint32_t reordered_elements_per_thread = interleaved / 4; | |||
| uint32_t ICx = IC / elements_per_access; | |||
| uint32_t k = lane / (ICx * FH * FW); | |||
| uint32_t cxrs = lane - k * ICx * FH * FW; | |||
| uint32_t rs = cxrs / ICx; | |||
| uint32_t cx = cxrs - rs * ICx; | |||
| // transpose nhwc to ncxhwx | |||
| uint32_t src_offset = lane * elements_per_access; | |||
| // reorder k | |||
| k = (trans_oc) | |||
| ? (k / interleaved) * interleaved + | |||
| ((k % reordered_elements_per_thread) / | |||
| elements_per_thread) * | |||
| instruction_shape_col + | |||
| ((k % interleaved) / reordered_elements_per_thread) * | |||
| elements_per_thread + | |||
| (k % elements_per_thread) | |||
| : k; | |||
| uint32_t dst_offset = | |||
| (k * ICx * FH * FW + cx * FH * FW + rs) * elements_per_access; | |||
| if (alignbits == 32) { | |||
| *(reinterpret_cast<int*>(dst + dst_offset * size_bits / 8)) = *( | |||
| reinterpret_cast<const int*>(src + src_offset * size_bits / 8)); | |||
| } else if (alignbits == 64) { | |||
| *(reinterpret_cast<int2*>(dst + dst_offset * size_bits / 8)) = | |||
| *(reinterpret_cast<const int2*>(src + | |||
| src_offset * size_bits / 8)); | |||
| } else { | |||
| *(reinterpret_cast<int4*>(dst + dst_offset * size_bits / 8)) = | |||
| *(reinterpret_cast<const int4*>(src + | |||
| src_offset * size_bits / 8)); | |||
| } | |||
| } | |||
| template <uint32_t size_bits, uint32_t alignbits, uint32_t interleaved> | |||
| __global__ void reorder_nhwc_imma_filter_kernel( | |||
| int8_t* __restrict__ dst_filter, const int8_t* __restrict__ src_filter, | |||
| uint32_t OC, uint32_t IC, uint32_t FH, uint32_t FW, bool trans_oc) { | |||
| static constexpr uint32_t elements_per_access = alignbits / size_bits; | |||
| const uint32_t size = OC * IC * FH * FW / elements_per_access; | |||
| uint32_t lane = threadIdx.x + blockIdx.x * blockDim.x; | |||
| if (lane < size) { | |||
| reorder_nhwc_imma_filter_func<size_bits, alignbits, interleaved>( | |||
| dst_filter, src_filter, OC, IC, FH, FW, lane, trans_oc); | |||
| } | |||
| } | |||
| } // namespace | |||
| template <uint32_t size_bits, uint32_t interleaved> | |||
| void megdnn::cuda::cutlass_wrapper::reorder_ncxhwx_imma_filter( | |||
| int8_t* dst_filter, const int8_t* src_filter, uint32_t OC, uint32_t IC, | |||
| uint32_t FH, uint32_t FW, bool trans_oc, cudaStream_t stream) { | |||
| static constexpr uint32_t elements_per_lane = 128 / size_bits; | |||
| uint32_t nr_threads = | |||
| query_blocksize_for_kernel(reinterpret_cast<const void*>( | |||
| reorder_ncxhwx_imma_filter_kernel<size_bits, interleaved>)); | |||
| uint32_t vthreads = DIVUP(OC * IC * FH * FW, elements_per_lane); | |||
| nr_threads = std::min(nr_threads, vthreads); | |||
| uint32_t nr_blocks = DIVUP(vthreads, nr_threads); | |||
| reorder_ncxhwx_imma_filter_kernel<size_bits, interleaved> | |||
| <<<nr_blocks, nr_threads, 0, stream>>>(dst_filter, src_filter, OC, | |||
| IC, FH, FW, trans_oc); | |||
| after_kernel_launch(); | |||
| } | |||
| template <uint32_t size_bits, uint32_t alignbits> | |||
| void megdnn::cuda::cutlass_wrapper::reorder_nhwc_imma_filter( | |||
| int8_t* dst_filter, const int8_t* src_filter, uint32_t OC, uint32_t IC, | |||
| uint32_t FH, uint32_t FW, bool trans_oc, uint32_t oc_interleaved, | |||
| cudaStream_t stream) { | |||
| static constexpr uint32_t elements_per_access = alignbits / size_bits; | |||
| uint32_t nr_threads = | |||
| query_blocksize_for_kernel(reinterpret_cast<const void*>( | |||
| reorder_nhwc_imma_filter_kernel<size_bits, alignbits, 32>)); | |||
| uint32_t vthreads = DIVUP(OC * IC * FH * FW, elements_per_access); | |||
| nr_threads = std::min(nr_threads, vthreads); | |||
| uint32_t nr_blocks = DIVUP(vthreads, nr_threads); | |||
| if (oc_interleaved == 32) { | |||
| reorder_nhwc_imma_filter_kernel<size_bits, alignbits, 32> | |||
| <<<nr_blocks, nr_threads, 0, stream>>>( | |||
| dst_filter, src_filter, OC, IC, FH, FW, trans_oc); | |||
| } else { | |||
| reorder_nhwc_imma_filter_kernel<size_bits, alignbits, 64> | |||
| <<<nr_blocks, nr_threads, 0, stream>>>( | |||
| dst_filter, src_filter, OC, IC, FH, FW, trans_oc); | |||
| } | |||
| after_kernel_launch(); | |||
| } | |||
| #define INST(_size_bits, _interleaved) \ | |||
| template void megdnn::cuda::cutlass_wrapper::reorder_ncxhwx_imma_filter< \ | |||
| _size_bits, _interleaved>(int8_t * dst_filter, \ | |||
| const int8_t* src_filter, uint32_t OC, \ | |||
| uint32_t IC, uint32_t FH, uint32_t FW, \ | |||
| bool trans_oc, cudaStream_t stream); | |||
| INST(8, 32) | |||
| INST(4, 64) | |||
| #undef INST | |||
| #define INST(_size_bits, _alignbits) \ | |||
| template void megdnn::cuda::cutlass_wrapper::reorder_nhwc_imma_filter< \ | |||
| _size_bits, _alignbits>( \ | |||
| int8_t * dst_filter, const int8_t* src_filter, uint32_t OC, \ | |||
| uint32_t IC, uint32_t FH, uint32_t FW, bool trans_oc, \ | |||
| uint32_t oc_interleaved, cudaStream_t stream); | |||
| INST(4, 32) | |||
| INST(4, 64) | |||
| INST(4, 128) | |||
| #undef INST | |||
| // vim: syntax=cuda.doxygen | |||
| @@ -0,0 +1,33 @@ | |||
| /** | |||
| * \file dnn/src/cuda/conv_bias/cutlass_reorder_filter.cuh | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include "src/cuda/utils.cuh" | |||
| namespace megdnn { | |||
| namespace cuda { | |||
| namespace cutlass_wrapper { | |||
| template <uint32_t size_bits, uint32_t interleaved> | |||
| void reorder_ncxhwx_imma_filter(int8_t* dst_filter, const int8_t* src_filter, | |||
| uint32_t OC, uint32_t IC, uint32_t FH, | |||
| uint32_t FW, bool trans_oc, | |||
| cudaStream_t stream); | |||
| template <uint32_t size_bits, uint32_t alignbits> | |||
| void reorder_nhwc_imma_filter(int8_t* dst_filter, const int8_t* src_filter, | |||
| uint32_t OC, uint32_t IC, uint32_t FH, | |||
| uint32_t FW, bool trans_oc, | |||
| uint32_t oc_interleaved, cudaStream_t stream); | |||
| } // namespace cutlass_wrapper | |||
| } // namespace cuda | |||
| } // namespace megdnn | |||
| @@ -102,7 +102,7 @@ void ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm::do_exec( | |||
| reinterpret_cast<int8_t*>(z_ptr), | |||
| reinterpret_cast<int8_t*>(args.dst_tensor->raw_ptr), nullptr, | |||
| kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale, | |||
| threadblock_shape, warp_shape, stream); | |||
| threadblock_shape, warp_shape, m_algo_param.stage, stream); | |||
| } | |||
| #endif | |||
| @@ -104,7 +104,7 @@ void ConvBiasForwardImpl::AlgoInt4Int4NHWCIMMAImplicitGemm::do_exec( | |||
| reinterpret_cast<int8_t*>(args.dst_tensor->raw_ptr), nullptr, | |||
| kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale, | |||
| threadblock_shape, warp_shape, m_algo_param.access_size, | |||
| stream); | |||
| m_algo_param.stage, stream); | |||
| } else { | |||
| cutlass_wrapper::do_conv_bias_int4_int4_implicit_gemm_imma_nhwc<true>( | |||
| reinterpret_cast<int8_t*>(args.src_tensor->raw_ptr), | |||
| @@ -114,7 +114,7 @@ void ConvBiasForwardImpl::AlgoInt4Int4NHWCIMMAImplicitGemm::do_exec( | |||
| reinterpret_cast<int8_t*>(args.dst_tensor->raw_ptr), nullptr, | |||
| kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale, | |||
| threadblock_shape, warp_shape, m_algo_param.access_size, | |||
| stream); | |||
| m_algo_param.stage, stream); | |||
| } | |||
| } | |||
| #endif | |||
| @@ -12,6 +12,7 @@ | |||
| #include "./algo.h" | |||
| #include "src/common/conv_bias.h" | |||
| #include "src/cuda/conv_bias/cutlass_reorder_filter.cuh" | |||
| #include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | |||
| #include "src/cuda/conv_bias/reduce_filter.cuh" | |||
| #include "src/cuda/convolution_helper/parameter.cuh" | |||
| @@ -121,41 +122,26 @@ void ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase::exec( | |||
| std::string ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase::to_string( | |||
| AlgoParam algo_param) { | |||
| return ssprintf("%dX%dX%d_%dX%dX%d", algo_param.threadblock_m, | |||
| return ssprintf("%dX%dX%d_%dX%dX%d_%d", algo_param.threadblock_m, | |||
| algo_param.threadblock_n, algo_param.threadblock_k, | |||
| algo_param.warp_m, algo_param.warp_n, algo_param.warp_k); | |||
| algo_param.warp_m, algo_param.warp_n, algo_param.warp_k, | |||
| algo_param.stage); | |||
| } | |||
| void ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase::reorder_filter( | |||
| const ExecArgs& args, void* reordered_filter) const { | |||
| auto&& param = args.opr->param(); | |||
| size_t ci = args.src_layout->operator[](1) * 64; | |||
| size_t co = args.dst_layout->operator[](1) * 64; | |||
| auto&& fm = args.filter_meta; | |||
| size_t n = args.src_layout->operator[](0), | |||
| ci = args.src_layout->operator[](1) * 64, | |||
| hi = args.src_layout->operator[](2), | |||
| wi = args.src_layout->operator[](3); | |||
| size_t co = args.dst_layout->operator[](1) * 64, | |||
| ho = args.dst_layout->operator[](2), | |||
| wo = args.dst_layout->operator[](3); | |||
| UNPACK_CONV_PARAMETER(fm, param); | |||
| MARK_USED_VAR; | |||
| // filter: KCRS64 => CRSK64 | |||
| TensorLayout src{{co, ci / 64, fh, fw, 64}, dtype::QuantizedS4()}; | |||
| src.init_contiguous_stride(); | |||
| TensorLayout dst = src; | |||
| dst.stride[0] = 64; | |||
| dst.stride[1] = co * fh * fw * 64; | |||
| dst.stride[2] = co * fw * 64; | |||
| dst.stride[3] = co * 64; | |||
| dst.stride[4] = 1; | |||
| TensorND ts_src, ts_dst; | |||
| ts_src.raw_ptr = args.filter_tensor->raw_ptr; | |||
| ts_src.layout = src; | |||
| ts_dst.raw_ptr = reordered_filter; | |||
| ts_dst.layout = dst; | |||
| auto&& transpose = args.opr->handle()->create_operator<RelayoutForward>(); | |||
| transpose->exec(ts_src, ts_dst); | |||
| size_t fh = fm.spatial[0], fw = fm.spatial[1]; | |||
| cudaStream_t stream = cuda_stream(args.opr->handle()); | |||
| // filter: KCRS64 => CRSK64 and reorder oc | |||
| cutlass_wrapper::reorder_ncxhwx_imma_filter<4, 64>( | |||
| reinterpret_cast<int8_t*>(reordered_filter), | |||
| reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), co, ci, fh, | |||
| fw, true, stream); | |||
| } | |||
| #endif | |||
| @@ -12,6 +12,7 @@ | |||
| #include "./algo.h" | |||
| #include "src/common/conv_bias.h" | |||
| #include "src/cuda/conv_bias/cutlass_reorder_filter.cuh" | |||
| #include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | |||
| #include "src/cuda/conv_bias/reduce_filter.cuh" | |||
| #include "src/cuda/convolution_helper/parameter.cuh" | |||
| @@ -128,10 +129,10 @@ void ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::exec( | |||
| std::string ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::to_string( | |||
| AlgoParam algo_param) { | |||
| return ssprintf("%dX%dX%d_%dX%dX%d_%d", algo_param.threadblock_m, | |||
| return ssprintf("%dX%dX%d_%dX%dX%d_%d_%d", algo_param.threadblock_m, | |||
| algo_param.threadblock_n, algo_param.threadblock_k, | |||
| algo_param.warp_m, algo_param.warp_n, algo_param.warp_k, | |||
| algo_param.access_size); | |||
| algo_param.stage, algo_param.access_size); | |||
| } | |||
| void ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::reorder_filter( | |||
| @@ -142,17 +143,32 @@ void ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::reorder_filter( | |||
| fh = args.filter_layout->operator[](1), | |||
| fw = args.filter_layout->operator[](2); | |||
| // reformat grad from nhwc to ncxhwx | |||
| TensorLayout exec_src{{co, fh, fw, ci / iterleaved, (size_t)iterleaved / 2}, | |||
| dtype::Int8()}; | |||
| TensorLayout exec_dst{{co, ci / iterleaved, fh, fw, (size_t)iterleaved / 2}, | |||
| dtype::Int8()}; | |||
| exec_src = exec_src.dimshuffle({0, 3, 1, 2, 4}); | |||
| cudaStream_t stream = cuda_stream(args.opr->handle()); | |||
| auto&& relayout = args.opr->handle()->create_operator<RelayoutForward>(); | |||
| relayout->exec({args.filter_tensor->raw_ptr, exec_src}, | |||
| {reordered_filter, exec_dst}); | |||
| // reformat filter from nhwc to ncxhwx and reorder oc | |||
| // use trans_oc threadblock_n must be 32 or 64 | |||
| bool trans_oc = ((co % m_algo_param.threadblock_n == 0) && | |||
| (m_algo_param.threadblock_n == 32 || | |||
| m_algo_param.threadblock_n == 64)); | |||
| uint32_t oc_iterleave = (m_algo_param.threadblock_n == 64) ? 64 : 32; | |||
| if (iterleaved == 8) { | |||
| cutlass_wrapper::reorder_nhwc_imma_filter<4, 32>( | |||
| reinterpret_cast<int8_t*>(reordered_filter), | |||
| reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), co, ci, | |||
| fh, fw, trans_oc, oc_iterleave, stream); | |||
| } else if (iterleaved == 16) { | |||
| cutlass_wrapper::reorder_nhwc_imma_filter<4, 64>( | |||
| reinterpret_cast<int8_t*>(reordered_filter), | |||
| reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), co, ci, | |||
| fh, fw, trans_oc, oc_iterleave, stream); | |||
| } else { | |||
| megdnn_assert(iterleaved == 32); | |||
| cutlass_wrapper::reorder_nhwc_imma_filter<4, 128>( | |||
| reinterpret_cast<int8_t*>(reordered_filter), | |||
| reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), co, ci, | |||
| fh, fw, trans_oc, oc_iterleave, stream); | |||
| } | |||
| } | |||
| #endif | |||
| @@ -11,6 +11,7 @@ | |||
| */ | |||
| #include "./algo.h" | |||
| #include "src/cuda/conv_bias/cutlass_reorder_filter.cuh" | |||
| #include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | |||
| #include "src/cuda/convolution_helper/parameter.cuh" | |||
| #include "src/cuda/utils.h" | |||
| @@ -110,11 +111,14 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec( | |||
| size_t ho = args.dst_layout->operator[](2), | |||
| wo = args.dst_layout->operator[](3); | |||
| size_t co; | |||
| bool trans_oc; | |||
| if (param.format == Format::NCHW32) { | |||
| co = args.dst_layout->operator[](1) * 32; | |||
| trans_oc = true; | |||
| } else { | |||
| megdnn_assert(param.format == Format::NCHW32_NCHW4); | |||
| co = args.dst_layout->operator[](1) * 4; | |||
| trans_oc = false; | |||
| } | |||
| UNPACK_CONV_PARAMETER(fm, param); | |||
| MARK_USED_VAR | |||
| @@ -123,23 +127,11 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec( | |||
| int8_t* filter_ptr = nullptr; | |||
| if (args.preprocessed_filter == nullptr) { | |||
| filter_ptr = reinterpret_cast<int8_t*>(args.workspace.raw_ptr); | |||
| // reformat filter from nchw32 to chwn32 | |||
| TensorLayout src{{co, ci / 32, fh, fw, 32}, dtype::Int8()}; | |||
| src.init_contiguous_stride(); | |||
| TensorLayout dst = src; | |||
| dst.stride[0] = 32; | |||
| dst.stride[1] = co * fh * fw * 32; | |||
| dst.stride[2] = co * fw * 32; | |||
| dst.stride[3] = co * 32; | |||
| dst.stride[4] = 1; | |||
| TensorND ts_src, ts_dst; | |||
| ts_src.raw_ptr = args.filter_tensor->raw_ptr; | |||
| ts_src.layout = src; | |||
| ts_dst.raw_ptr = args.workspace.raw_ptr; | |||
| ts_dst.layout = dst; | |||
| auto&& transpose = | |||
| args.opr->handle()->create_operator<RelayoutForward>(); | |||
| transpose->exec(ts_src, ts_dst); | |||
| // filter: KCRS32 => CRSK32 and reorder oc | |||
| cutlass_wrapper::reorder_ncxhwx_imma_filter<8, 32>( | |||
| filter_ptr, | |||
| reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), co, ci, | |||
| fh, fw, trans_oc, stream); | |||
| } else { | |||
| filter_ptr = reinterpret_cast<int8_t*>( | |||
| args.preprocessed_filter->tensors[0].raw_ptr); | |||
| @@ -182,7 +174,7 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec( | |||
| cutlass_wrapper::GemmCoord{m_algo_param.warp_m, | |||
| m_algo_param.warp_n, | |||
| m_algo_param.warp_k}, | |||
| stream); | |||
| m_algo_param.stage, stream); | |||
| } else { | |||
| megdnn_assert(param.format == Format::NCHW32_NCHW4); | |||
| cutlass_wrapper:: | |||
| @@ -202,7 +194,7 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec( | |||
| cutlass_wrapper::GemmCoord{m_algo_param.warp_m, | |||
| m_algo_param.warp_n, | |||
| m_algo_param.warp_k}, | |||
| stream); | |||
| m_algo_param.stage, stream); | |||
| } | |||
| } else { | |||
| if (param.format == Format::NCHW32) { | |||
| @@ -218,7 +210,7 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec( | |||
| cutlass_wrapper::GemmCoord{m_algo_param.warp_m, | |||
| m_algo_param.warp_n, | |||
| m_algo_param.warp_k}, | |||
| stream); | |||
| m_algo_param.stage, stream); | |||
| } else { | |||
| megdnn_assert(param.format == Format::NCHW32_NCHW4); | |||
| cutlass_wrapper:: | |||
| @@ -238,7 +230,7 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec( | |||
| cutlass_wrapper::GemmCoord{m_algo_param.warp_m, | |||
| m_algo_param.warp_n, | |||
| m_algo_param.warp_k}, | |||
| stream); | |||
| m_algo_param.stage, stream); | |||
| } | |||
| } | |||
| after_kernel_launch(); | |||
| @@ -246,9 +238,10 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec( | |||
| std::string ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::to_string( | |||
| AlgoParam algo_param) { | |||
| return ssprintf("%uX%uX%u_%uX%uX%u", algo_param.threadblock_m, | |||
| return ssprintf("%uX%uX%u_%uX%uX%u_%u", algo_param.threadblock_m, | |||
| algo_param.threadblock_n, algo_param.threadblock_k, | |||
| algo_param.warp_m, algo_param.warp_n, algo_param.warp_k); | |||
| algo_param.warp_m, algo_param.warp_n, algo_param.warp_k, | |||
| algo_param.stage); | |||
| } | |||
| size_t ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm:: | |||
| @@ -267,36 +260,26 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec_preprocess( | |||
| using Format = Param::Format; | |||
| auto&& param = args.opr->param(); | |||
| auto&& fm = args.filter_meta; | |||
| size_t n = args.src_layout->operator[](0), | |||
| ci = args.src_layout->operator[](1) * 32, | |||
| hi = args.src_layout->operator[](2), | |||
| wi = args.src_layout->operator[](3); | |||
| size_t ho = args.dst_layout->operator[](2), | |||
| wo = args.dst_layout->operator[](3); | |||
| size_t ci = args.src_layout->operator[](1) * 32; | |||
| size_t co; | |||
| bool trans_oc; | |||
| if (param.format == Format::NCHW32) { | |||
| co = args.dst_layout->operator[](1) * 32; | |||
| trans_oc = true; | |||
| } else { | |||
| megdnn_assert(param.format == Format::NCHW32_NCHW4); | |||
| co = args.dst_layout->operator[](1) * 4; | |||
| trans_oc = false; | |||
| } | |||
| UNPACK_CONV_PARAMETER(fm, param); | |||
| MARK_USED_VAR | |||
| TensorLayout src{{co, ci / 32, fh, fw, 32}, dtype::Int8()}; | |||
| src.init_contiguous_stride(); | |||
| TensorLayout dst = src; | |||
| dst.stride[0] = 32; | |||
| dst.stride[1] = co * fh * fw * 32; | |||
| dst.stride[2] = co * fw * 32; | |||
| dst.stride[3] = co * 32; | |||
| dst.stride[4] = 1; | |||
| TensorND ts_src, ts_dst; | |||
| ts_src.raw_ptr = args.filter_tensor->raw_ptr; | |||
| ts_src.layout = src; | |||
| ts_dst.raw_ptr = args.preprocessed_filter->tensors[0].raw_ptr; | |||
| ts_dst.layout = dst; | |||
| auto&& transpose = args.opr->handle()->create_operator<RelayoutForward>(); | |||
| transpose->exec(ts_src, ts_dst); | |||
| size_t fh = fm.spatial[0], fw = fm.spatial[1]; | |||
| cudaStream_t stream = cuda_stream(args.opr->handle()); | |||
| // filter: KCRS32 => CRSK32 and reorder oc | |||
| cutlass_wrapper::reorder_ncxhwx_imma_filter<8, 32>( | |||
| reinterpret_cast<int8_t*>( | |||
| args.preprocessed_filter->tensors[0].raw_ptr), | |||
| reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), co, ci, fh, | |||
| fw, trans_oc, stream); | |||
| } | |||
| #endif | |||
| @@ -144,7 +144,8 @@ void ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm::do_exec( | |||
| reinterpret_cast<uint8_t*>(z_ptr), | |||
| reinterpret_cast<uint8_t*>(args.dst_tensor->raw_ptr), nullptr, | |||
| kern_param, nonlinear_mode, alpha, beta, gamma, delta, theta, | |||
| dst_scale, src_zero, threadblock_shape, warp_shape, stream); | |||
| dst_scale, src_zero, threadblock_shape, warp_shape, | |||
| m_algo_param.stage, stream); | |||
| } | |||
| void ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm::update_bias( | |||
| @@ -147,7 +147,7 @@ void ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm::do_exec( | |||
| reinterpret_cast<uint8_t*>(args.dst_tensor->raw_ptr), nullptr, | |||
| kern_param, nonlinear_mode, alpha, beta, gamma, delta, theta, | |||
| dst_scale, src_zero, threadblock_shape, warp_shape, | |||
| m_algo_param.access_size, stream); | |||
| m_algo_param.access_size, m_algo_param.stage, stream); | |||
| } else { | |||
| cutlass_wrapper::do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc<true>( | |||
| reinterpret_cast<uint8_t*>(args.src_tensor->raw_ptr), | |||
| @@ -157,7 +157,7 @@ void ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm::do_exec( | |||
| reinterpret_cast<uint8_t*>(args.dst_tensor->raw_ptr), nullptr, | |||
| kern_param, nonlinear_mode, alpha, beta, gamma, delta, theta, | |||
| dst_scale, src_zero, threadblock_shape, warp_shape, | |||
| m_algo_param.access_size, stream); | |||
| m_algo_param.access_size, m_algo_param.stage, stream); | |||
| } | |||
| } | |||
| @@ -840,21 +840,21 @@ TEST_F(CUDA, CUTLASS_CONV_BIAS_INT8_NCHW32_IMMA) { | |||
| param.pad_h = param.pad_w = 1; | |||
| param.stride_h = param.stride_w = 1; | |||
| param.format = param::ConvBias::Format::NCHW32; | |||
| checker.set_param(param).execs({{16, 16, 7, 7, 32}, | |||
| {512, 16, 3, 3, 32}, | |||
| {1, 16, 1, 1, 32}, | |||
| checker.set_param(param).execs({{16, 8, 7, 7, 32}, | |||
| {256, 8, 3, 3, 32}, | |||
| {1, 8, 1, 1, 32}, | |||
| {}, | |||
| {}}); | |||
| param.nonlineMode = param::ConvBias::NonlineMode::RELU; | |||
| checker.set_param(param).execs({{16, 16, 7, 7, 32}, | |||
| {512, 16, 1, 1, 32}, | |||
| {1, 16, 1, 1, 32}, | |||
| checker.set_param(param).execs({{16, 8, 7, 7, 32}, | |||
| {256, 8, 1, 1, 32}, | |||
| {1, 8, 1, 1, 32}, | |||
| {}, | |||
| {}}); | |||
| param.nonlineMode = param::ConvBias::NonlineMode::H_SWISH; | |||
| checker.set_param(param).execs({{16, 16, 7, 7, 32}, | |||
| {512, 16, 3, 3, 32}, | |||
| {1, 16, 1, 1, 32}, | |||
| checker.set_param(param).execs({{16, 8, 7, 7, 32}, | |||
| {256, 8, 3, 3, 32}, | |||
| {1, 8, 1, 1, 32}, | |||
| {}, | |||
| {}}); | |||
| // use non integer scale | |||
| @@ -867,18 +867,18 @@ TEST_F(CUDA, CUTLASS_CONV_BIAS_INT8_NCHW32_IMMA) { | |||
| .set_epsilon(1 + 1e-3) | |||
| .set_max_avg_error(1e-1) | |||
| .set_max_avg_biased_error(1e-1) | |||
| .execs({{16, 16, 7, 7, 32}, | |||
| {512, 16, 3, 3, 32}, | |||
| {1, 16, 1, 1, 32}, | |||
| {16, 16, 7, 7, 32}, | |||
| .execs({{16, 8, 7, 7, 32}, | |||
| {256, 8, 3, 3, 32}, | |||
| {1, 8, 1, 1, 32}, | |||
| {16, 8, 7, 7, 32}, | |||
| {}}); | |||
| }; | |||
| std::string algo = ConvBias::algo_name<ConvBias::DirectParam>( | |||
| "INT8_NCHW32_IMMA_IMPLICIT_GEMM_256X128X64_64X64X64", | |||
| "INT8_NCHW32_IMMA_IMPLICIT_GEMM_128X128X64_64X64X64_2", | |||
| ConvBias::DirectParam{}); | |||
| check(algo); | |||
| algo = ConvBias::algo_name<ConvBias::DirectParam>( | |||
| "INT8_NCHW32_IMMA_IMPLICIT_GEMM_32X64X64_32X16X64", | |||
| "INT8_NCHW32_IMMA_IMPLICIT_GEMM_128X32X32_64X32X32_1", | |||
| ConvBias::DirectParam{}); | |||
| check(algo); | |||
| } | |||
| @@ -969,7 +969,7 @@ TEST_F(CUDA, CUTLASS_CONV_BIAS_INT8_NCHW32_NCHW4) { | |||
| checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker< | |||
| ConvBiasForward>( | |||
| ConvBias::algo_name<ConvBias::DirectParam>( | |||
| "INT8_NCHW32_IMMA_IMPLICIT_GEMM_256X128X64_64X64X64", | |||
| "INT8_NCHW32_IMMA_IMPLICIT_GEMM_128X128X64_64X64X64_2", | |||
| ConvBias::DirectParam{}) | |||
| .c_str())); | |||
| checker.set_dtype(0, dtype::QuantizedS8(1.9980618f)) | |||