GitOrigin-RevId: 0b5574f526
@@ -20,7 +20,7 @@ class Conv2dOperation:
     #
     def __init__(self, conv_kind, conv_type, arch, tile_description, src, flt, bias, dst, element_epilogue, \
                  epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity4, \
-                 need_load_from_const = True, implicit_gemm_mode = ImplicitGemmMode.GemmNt):
+                 need_load_from_const = True, implicit_gemm_mode = ImplicitGemmMode.GemmNT, without_shared_load = False):
         self.operation_kind = OperationKind.Conv2d
         self.conv_kind = conv_kind
@@ -36,6 +36,7 @@ class Conv2dOperation:
         self.swizzling_functor = swizzling_functor
         self.need_load_from_const = need_load_from_const
         self.implicit_gemm_mode = implicit_gemm_mode
+        self.without_shared_load = without_shared_load
     #
     def accumulator_type(self):
         accum = self.tile_description.math_instruction.element_accumulator
@@ -58,11 +59,15 @@ class Conv2dOperation:
         unity_kernel = ''
         if not self.need_load_from_const:
             unity_kernel = '_1x1'

-        return "%s%s%s%s%s_%s" % (ShortDataTypeNames[self.accumulator_type()], \
+        reorder_k = ''
+        if self.without_shared_load:
+            reorder_k = '_roc'
+
+        return "%s%s%s%s%s%s_%s" % (ShortDataTypeNames[self.accumulator_type()], \
                inst_shape, intermediate_type, ConvKindNames[self.conv_kind], unity_kernel, \
-               ShortEpilogueNames[self.epilogue_functor])
+               reorder_k, ShortEpilogueNames[self.epilogue_functor])
     #
     def extended_name(self):
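Read outside the patch context, the `procedural_name` change amounts to appending a `_roc` tag for reordered-K (`without_shared_load`) kernels. A minimal standalone sketch, with the lookup tables (ShortDataTypeNames, ConvKindNames, ShortEpilogueNames) replaced by illustrative literal arguments:

```python
# Standalone sketch only; the real method pulls these values from the
# operation's tile description and name tables.
def procedural_name_sketch(accum='s8', inst_shape='', intermediate='',
                           conv_kind='ifprop', need_load_from_const=True,
                           without_shared_load=False, epilogue='id'):
    unity_kernel = '' if need_load_from_const else '_1x1'
    # Kernels built without the shared-memory load/reorder stage get a
    # '_roc' tag so their generated file names stay distinct.
    reorder_k = '_roc' if without_shared_load else ''
    return "%s%s%s%s%s%s_%s" % (accum, inst_shape, intermediate, conv_kind,
                                unity_kernel, reorder_k, epilogue)

print(procedural_name_sketch(without_shared_load=True))  # -> s8ifprop_roc_id
```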
@@ -177,7 +182,8 @@ using Convolution =
     ${alignment_filter},
     ${nonuninity_kernel},
     ${math_operator},
-    ${implicit_gemm_mode}>;
+    ${implicit_gemm_mode},
+    ${without_shared_load}>;
"""
@@ -219,7 +225,8 @@ using Convolution =
             'alignment_filter': str(operation.flt.alignment),
             'nonuninity_kernel': str(operation.need_load_from_const).lower(),
             'math_operator': MathOperationTag[operation.tile_description.math_instruction.math_operation],
-            'implicit_gemm_mode': ImplicitGemmModeTag[operation.implicit_gemm_mode]
+            'implicit_gemm_mode': ImplicitGemmModeTag[operation.implicit_gemm_mode],
+            'without_shared_load': str(operation.without_shared_load).lower()
         }

         return SubstituteTemplate(self.template, values)
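The new value is stringified with `str(...).lower()` because the emitter splices plain text into the C++ template. A rough sketch of the `${key}` substitution, assuming the usual regex-replace implementation of `SubstituteTemplate` in CUTLASS-style generator scripts:

```python
import re

# Sketch of ${key} -> value substitution; every value must already be a
# string (hence 'true'/'false' for the new without_shared_load flag).
def substitute_template_sketch(template, values):
    text = template
    for key, value in values.items():
        text = re.sub(r"\$\{%s\}" % re.escape(key), value, text)
    return text

snippet = "${implicit_gemm_mode},\n${without_shared_load}>;"
print(substitute_template_sketch(snippet, {
    'implicit_gemm_mode': 'cutlass::conv::ImplicitGemmMode::GEMM_TN',
    'without_shared_load': 'true',
}))
```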
@@ -312,13 +319,13 @@ using Deconvolution =
 #
 def GenerateConv2d(conv_kind, tile_descriptions, src_layout, flt_layout, dst_layout, dst_type, min_cc, src_align = 32, flt_align = 32, dst_align = 128, \
-                   skip_unity_kernel = False, implicit_gemm_mode = ImplicitGemmMode.GemmNt):
+                   skip_unity_kernel = False, implicit_gemm_mode = ImplicitGemmMode.GemmNT, without_shared_load = False):
     operations = []

     element_epilogue = DataType.f32
     if conv_kind == ConvKind.Fprop:
-        if src_layout == LayoutType.TensorNHWC:
-            swizzling_functor = SwizzlingFunctor.ConvFpropNHWC
+        if implicit_gemm_mode == ImplicitGemmMode.GemmTN:
+            swizzling_functor = SwizzlingFunctor.ConvFpropTrans
         else:
             swizzling_functor = SwizzlingFunctor.ConvFpropNCxHWx
     else:
@@ -399,10 +406,10 @@ def GenerateConv2d(conv_kind, tile_descriptions, src_layout, flt_layout, dst_lay
                 bias = TensorDescription(bias_type, dst_layout, max(1, int(32 / DataTypeSize[bias_type])))
                 dst = TensorDescription(dst_type, dst_layout, int(dst_align / DataTypeSize[dst_type]))

-                new_operation = Conv2dOperation(conv_kind, ConvType.Convolution, min_cc, tile, src, flt, bias, dst, element_epilogue, epilogue, swizzling_functor, True, implicit_gemm_mode)
+                new_operation = Conv2dOperation(conv_kind, ConvType.Convolution, min_cc, tile, src, flt, bias, dst, element_epilogue, epilogue, swizzling_functor, True, implicit_gemm_mode, without_shared_load)
                 operations.append(new_operation)
                 if not skip_unity_kernel:
-                    new_operation = Conv2dOperation(conv_kind, ConvType.Convolution, min_cc, tile, src, flt, bias, dst, element_epilogue, epilogue, swizzling_functor, False, implicit_gemm_mode)
+                    new_operation = Conv2dOperation(conv_kind, ConvType.Convolution, min_cc, tile, src, flt, bias, dst, element_epilogue, epilogue, swizzling_functor, False, implicit_gemm_mode, without_shared_load)
                     operations.append(new_operation)
     return operations
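The Fprop swizzle choice now keys on the implicit-GEMM mode rather than the source layout. A self-contained sketch of the new rule, with the enums stubbed out for illustration:

```python
import enum

class ImplicitGemmMode(enum.Enum):
    GemmNT = enum.auto()
    GemmTN = enum.auto()

class SwizzlingFunctor(enum.Enum):
    ConvFpropNCxHWx = enum.auto()
    ConvFpropTrans = enum.auto()

# New rule: the transposed implicit-GEMM path (GEMM_TN) picks the renamed
# ConvFpropTrans swizzle; every other Fprop case keeps the NCxHWx swizzle,
# independent of the source layout.
def pick_fprop_swizzle(implicit_gemm_mode):
    if implicit_gemm_mode == ImplicitGemmMode.GemmTN:
        return SwizzlingFunctor.ConvFpropTrans
    return SwizzlingFunctor.ConvFpropNCxHWx

assert pick_fprop_swizzle(ImplicitGemmMode.GemmTN) is SwizzlingFunctor.ConvFpropTrans
assert pick_fprop_swizzle(ImplicitGemmMode.GemmNT) is SwizzlingFunctor.ConvFpropNCxHWx
```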
@@ -175,12 +175,10 @@ def GenerateConv2d_Simt(args):
         TileDescription([128, 128, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc),
         TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
         TileDescription([ 64, 128, 32], 2, [1, 4, 1], math_inst, min_cc, max_cc),
-        TileDescription([ 64, 64, 32], 2, [1, 2, 1], math_inst, min_cc, max_cc),
         TileDescription([128, 32, 32], 2, [2, 1, 1], math_inst, min_cc, max_cc),
         TileDescription([ 32, 128, 32], 2, [1, 2, 1], math_inst, min_cc, max_cc),
         TileDescription([ 32, 64, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc),
         TileDescription([ 64, 32, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc),
-        TileDescription([ 32, 32, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc),
         TileDescription([ 16, 128, 16], 1, [1, 1, 1], math_inst, min_cc, max_cc),
         TileDescription([ 16, 64, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
     ]
@@ -223,28 +221,36 @@ def GenerateConv2d_TensorOp_8816(args):
             for dst_type, dst_layout in zip(dst_types, dst_layouts):
                 if dst_layout == LayoutType.TensorNC32HW32:
                     tile_descriptions = [
-                        TileDescription([256, 128, 64], 2, [4, 2, 1], math_inst, min_cc, max_cc),
                         TileDescription([128, 256, 64], 2, [2, 4, 1], math_inst, min_cc, max_cc),
+                        TileDescription([256, 128, 64], 2, [4, 2, 1], math_inst, min_cc, max_cc),
                         TileDescription([128, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc),
-                        TileDescription([ 64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc),
                         TileDescription([128, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc),
-                        TileDescription([ 64, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc),
-                        TileDescription([ 32, 64, 64], 2, [1, 4, 1], math_inst, min_cc, max_cc),
+                        TileDescription([ 64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc),
+                        TileDescription([128, 64, 32], 1, [2, 2, 1], math_inst, min_cc, max_cc),
+                        TileDescription([128, 32, 32], 1, [2, 1, 1], math_inst, min_cc, max_cc),
+                        TileDescription([ 64, 128, 32], 1, [2, 2, 1], math_inst, min_cc, max_cc),
+                        TileDescription([ 32, 128, 32], 1, [1, 2, 1], math_inst, min_cc, max_cc),
                     ]
+                    operations += GenerateConv2d(ConvKind.Fprop, tile_descriptions, layout[0], layout[1],
+                                                 dst_layout, dst_type, min_cc, 128, 128, 64,
+                                                 False, ImplicitGemmMode.GemmTN, True)
                 else:
                     assert dst_layout == LayoutType.TensorNC4HW4
                     tile_descriptions = [
-                        TileDescription([256, 128, 64], 2, [4, 2, 1], math_inst, min_cc, max_cc),
                         TileDescription([128, 256, 64], 2, [2, 4, 1], math_inst, min_cc, max_cc),
+                        TileDescription([256, 128, 64], 2, [4, 2, 1], math_inst, min_cc, max_cc),
                         TileDescription([128, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc),
-                        TileDescription([ 64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc),
                         TileDescription([128, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc),
-                        TileDescription([ 64, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc),
-                        TileDescription([ 32, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc),
+                        TileDescription([ 64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc),
+                        TileDescription([128, 64, 32], 1, [2, 2, 1], math_inst, min_cc, max_cc),
+                        TileDescription([128, 32, 32], 1, [2, 1, 1], math_inst, min_cc, max_cc),
+                        TileDescription([ 64, 128, 32], 1, [2, 2, 1], math_inst, min_cc, max_cc),
+                        TileDescription([ 32, 128, 32], 1, [1, 2, 1], math_inst, min_cc, max_cc),
                     ]
                     operations += GenerateConv2d(ConvKind.Fprop, tile_descriptions, layout[0], layout[1],
                                                  dst_layout, dst_type, min_cc, 128, 128, 64,
                                                  False)
     return operations

 def GenerateConv2d_TensorOp_8832(args):
@@ -279,12 +285,14 @@ def GenerateConv2d_TensorOp_8832(args):
             for dst_layout in dst_layouts:
                 dst_type = math_inst.element_b
                 tile_descriptions = [
-                    TileDescription([256, 128, 128], 2, [4, 2, 1], math_inst, min_cc, max_cc),
-                    TileDescription([128, 256, 128], 2, [2, 4, 1], math_inst, min_cc, max_cc),
                     TileDescription([128, 128, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc),
+                    TileDescription([128, 64, 128], 2, [2, 1, 1], math_inst, min_cc, max_cc),
+                    TileDescription([128, 64, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc),
                 ]
                 operations += GenerateConv2d(ConvKind.Fprop, tile_descriptions, layout[0], layout[1],
                                              dst_layout, dst_type, min_cc, 128, 128, 64,
-                                             True)
+                                             True, ImplicitGemmMode.GemmTN, True)

     layouts_nhwc = [
         (LayoutType.TensorNHWC, LayoutType.TensorNC8HW8, 32),
@@ -299,14 +307,21 @@ def GenerateConv2d_TensorOp_8832(args):
     for math_inst in math_instructions:
         for layout in layouts_nhwc:
             for dst_layout in dst_layouts_nhwc:
-                dst_type = math_inst.element_b
-                tile_descriptions = [
-                    TileDescription([128, 32, 64], 2, [2, 1, 1], math_inst, min_cc, max_cc),
-                    TileDescription([128, 64, 64], 2, [2, 1, 1], math_inst, min_cc, max_cc),
-                ]
-                operations += GenerateConv2d(ConvKind.Fprop, tile_descriptions, layout[0], layout[1],
-                                             dst_layout, dst_type, min_cc, layout[2], layout[2], 32,
-                                             False, ImplicitGemmMode.GemmTn)
+                dst_type = math_inst.element_b
+                tile_descriptions = [
+                    TileDescription([128, 32, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc),
+                    TileDescription([128, 64, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc),
+                ]
+                for tile in tile_descriptions:
+                    operations += GenerateConv2d(ConvKind.Fprop, [tile], layout[0], layout[1],
+                                                 dst_layout, dst_type, min_cc, layout[2], layout[2], 32,
+                                                 False, ImplicitGemmMode.GemmTN, False)
+                    if tile.threadblock_shape[1] == 32 or tile.threadblock_shape[1] == 64:
+                        dst_align = 32 if tile.threadblock_shape[1] == 32 else 64
+                        operations += GenerateConv2d(ConvKind.Fprop, [tile], layout[0], layout[1],
+                                                     dst_layout, dst_type, min_cc, layout[2], layout[2], dst_align,
+                                                     False, ImplicitGemmMode.GemmTN, True)
     return operations

 def GenerateDeconv_Simt(args):
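The NHWC int4 branch now emits, per tile, a baseline kernel plus, when the threadblock N dimension is 32 or 64, an extra reordered (`without_shared_load`) kernel whose destination alignment tracks that N. A sketch of just the emission plan, with tiles reduced to bare [M, N, K] lists and `GenerateConv2d` not reproduced:

```python
# Sketch only: returns (shape, without_shared_load, dst_align) tuples.
def emission_plan(threadblock_shapes):
    plan = []
    for shape in threadblock_shapes:
        # Baseline kernel: shared-memory reorder kept, dst_align = 32.
        plan.append((shape, False, 32))
        n = shape[1]
        if n == 32 or n == 64:
            # Reordered variant; destination alignment follows the
            # threadblock N dimension.
            dst_align = 32 if n == 32 else 64
            plan.append((shape, True, dst_align))
    return plan

for entry in emission_plan([[128, 32, 64], [128, 64, 64]]):
    print(entry)
# ([128, 32, 64], False, 32)
# ([128, 32, 64], True, 32)
# ([128, 64, 64], False, 32)
# ([128, 64, 64], True, 64)
```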
@@ -649,3 +664,4 @@ if __name__ == "__main__":
 #
 ###################################################################################################
@@ -464,10 +464,10 @@ EpilogueFunctorTag = {
 ShortEpilogueNames = {
     EpilogueFunctor.BiasAddLinearCombinationHSwishClamp: 'hswish',
     EpilogueFunctor.BiasAddLinearCombinationReluClamp: 'relu',
-    EpilogueFunctor.BiasAddLinearCombinationClamp: 'identity',
+    EpilogueFunctor.BiasAddLinearCombinationClamp: 'id',
     EpilogueFunctor.BiasAddLinearCombinationHSwish: 'hswish',
     EpilogueFunctor.BiasAddLinearCombinationRelu: 'relu',
-    EpilogueFunctor.BiasAddLinearCombination: 'identity',
+    EpilogueFunctor.BiasAddLinearCombination: 'id',
 }
@@ -482,7 +482,7 @@ class SwizzlingFunctor(enum.Enum):
     Identity4 = enum_auto()
     Identity8 = enum_auto()
     ConvFpropNCxHWx = enum_auto()
-    ConvFpropNHWC = enum_auto()
+    ConvFpropTrans = enum_auto()
     ConvDgradNCxHWx = enum_auto()
 #
@@ -492,7 +492,7 @@ SwizzlingFunctorTag = {
     SwizzlingFunctor.Identity4: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>',
     SwizzlingFunctor.Identity8: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>',
     SwizzlingFunctor.ConvFpropNCxHWx: 'cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle',
-    SwizzlingFunctor.ConvFpropNHWC: 'cutlass::conv::threadblock::ConvolutionFpropNHWCThreadblockSwizzle',
+    SwizzlingFunctor.ConvFpropTrans: 'cutlass::conv::threadblock::ConvolutionFpropTransThreadblockSwizzle',
     SwizzlingFunctor.ConvDgradNCxHWx: 'cutlass::conv::threadblock::ConvolutionDgradNCxHWxThreadblockSwizzle',
 }
@@ -563,17 +563,17 @@ StrideSupportNames = {
 }

 class ImplicitGemmMode(enum.Enum):
-    GemmNt = enum_auto()
-    GemmTn = enum_auto()
+    GemmNT = enum_auto()
+    GemmTN = enum_auto()

 ImplicitGemmModeNames = {
-    ImplicitGemmMode.GemmNt: 'gemm_nt',
-    ImplicitGemmMode.GemmTn: 'gemm_tn',
+    ImplicitGemmMode.GemmNT: 'gemm_nt',
+    ImplicitGemmMode.GemmTN: 'gemm_tn',
 }

 ImplicitGemmModeTag = {
-    ImplicitGemmMode.GemmNt: 'cutlass::conv::ImplicitGemmMode::GEMM_NT',
-    ImplicitGemmMode.GemmTn: 'cutlass::conv::ImplicitGemmMode::GEMM_TN',
+    ImplicitGemmMode.GemmNT: 'cutlass::conv::ImplicitGemmMode::GEMM_NT',
+    ImplicitGemmMode.GemmTN: 'cutlass::conv::ImplicitGemmMode::GEMM_TN',
 }

 ###################################################################################################
@@ -164,415 +164,461 @@ cutlass_gen_list = [
     "cutlass_simt_sgemv_batched_strided_1x32_32_tt_align1x4.cu",
     "cutlass_simt_sgemv_batched_strided_1x32_16_tt_align1x2.cu",
     "cutlass_simt_sgemv_batched_strided_1x32_8_tt_align1x1.cu",
-    "cutlass_simt_s8_idgrad_identity_s8_64x128x32_64x32x32_2_nc4hw4_k4rsc4.cu",
-    "cutlass_simt_s8_idgrad_identity_s8_32x128x32_32x64x32_2_nc4hw4_k4rsc4.cu",
-    "cutlass_simt_s8_idgrad_identity_s8_16x128x16_16x64x16_2_nc4hw4_k4rsc4.cu",
-    "cutlass_simt_s8_idgrad_identity_s8_16x128x16_16x128x16_1_nc4hw4_k4rsc4.cu",
-    "cutlass_simt_s8_idgrad_identity_s8_16x64x8_16x64x8_2_nc4hw4_k4rsc4.cu",
-    "cutlass_simt_s8_ifprop_identity_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu",
-    "cutlass_simt_s8_ifprop_1x1_identity_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu",
+    "cutlass_simt_s8_idgrad_id_s8_64x128x32_64x32x32_2_nc4hw4_k4rsc4.cu",
+    "cutlass_simt_s8_idgrad_id_s8_32x128x32_32x64x32_2_nc4hw4_k4rsc4.cu",
+    "cutlass_simt_s8_idgrad_id_s8_16x128x16_16x64x16_2_nc4hw4_k4rsc4.cu",
+    "cutlass_simt_s8_idgrad_id_s8_16x128x16_16x128x16_1_nc4hw4_k4rsc4.cu",
+    "cutlass_simt_s8_idgrad_id_s8_16x64x8_16x64x8_2_nc4hw4_k4rsc4.cu",
+    "cutlass_simt_s8_ifprop_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu",
+    "cutlass_simt_s8_ifprop_1x1_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu",
     "cutlass_simt_s8_ifprop_relu_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu",
     "cutlass_simt_s8_ifprop_1x1_relu_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu",
     "cutlass_simt_s8_ifprop_hswish_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu",
     "cutlass_simt_s8_ifprop_1x1_hswish_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu",
-    "cutlass_simt_s8_ifprop_identity_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4.cu",
-    "cutlass_simt_s8_ifprop_1x1_identity_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4.cu",
+    "cutlass_simt_s8_ifprop_id_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4.cu",
+    "cutlass_simt_s8_ifprop_1x1_id_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4.cu",
     "cutlass_simt_s8_ifprop_relu_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4.cu",
     "cutlass_simt_s8_ifprop_1x1_relu_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4.cu",
     "cutlass_simt_s8_ifprop_hswish_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4.cu",
     "cutlass_simt_s8_ifprop_1x1_hswish_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4.cu",
-    "cutlass_simt_s8_ifprop_identity_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4.cu",
-    "cutlass_simt_s8_ifprop_1x1_identity_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4.cu",
+    "cutlass_simt_s8_ifprop_id_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4.cu",
+    "cutlass_simt_s8_ifprop_1x1_id_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4.cu",
     "cutlass_simt_s8_ifprop_relu_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4.cu",
     "cutlass_simt_s8_ifprop_1x1_relu_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4.cu",
     "cutlass_simt_s8_ifprop_hswish_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4.cu",
     "cutlass_simt_s8_ifprop_1x1_hswish_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4.cu",
-    "cutlass_simt_s8_ifprop_identity_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4.cu",
-    "cutlass_simt_s8_ifprop_1x1_identity_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4.cu",
-    "cutlass_simt_s8_ifprop_relu_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4.cu",
-    "cutlass_simt_s8_ifprop_1x1_relu_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4.cu",
-    "cutlass_simt_s8_ifprop_hswish_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4.cu",
-    "cutlass_simt_s8_ifprop_1x1_hswish_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4.cu",
-    "cutlass_simt_s8_ifprop_identity_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4.cu",
-    "cutlass_simt_s8_ifprop_1x1_identity_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4.cu",
+    "cutlass_simt_s8_ifprop_id_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4.cu",
+    "cutlass_simt_s8_ifprop_1x1_id_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4.cu",
     "cutlass_simt_s8_ifprop_relu_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4.cu",
     "cutlass_simt_s8_ifprop_1x1_relu_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4.cu",
     "cutlass_simt_s8_ifprop_hswish_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4.cu",
     "cutlass_simt_s8_ifprop_1x1_hswish_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4.cu",
-    "cutlass_simt_s8_ifprop_identity_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4.cu",
-    "cutlass_simt_s8_ifprop_1x1_identity_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4.cu",
+    "cutlass_simt_s8_ifprop_id_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4.cu",
+    "cutlass_simt_s8_ifprop_1x1_id_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4.cu",
     "cutlass_simt_s8_ifprop_relu_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4.cu",
     "cutlass_simt_s8_ifprop_1x1_relu_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4.cu",
     "cutlass_simt_s8_ifprop_hswish_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4.cu",
     "cutlass_simt_s8_ifprop_1x1_hswish_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4.cu",
-    "cutlass_simt_s8_ifprop_identity_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4.cu",
-    "cutlass_simt_s8_ifprop_1x1_identity_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4.cu",
+    "cutlass_simt_s8_ifprop_id_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4.cu",
+    "cutlass_simt_s8_ifprop_1x1_id_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4.cu",
     "cutlass_simt_s8_ifprop_relu_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4.cu",
     "cutlass_simt_s8_ifprop_1x1_relu_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4.cu",
     "cutlass_simt_s8_ifprop_hswish_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4.cu",
     "cutlass_simt_s8_ifprop_1x1_hswish_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4.cu",
-    "cutlass_simt_s8_ifprop_identity_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4.cu",
-    "cutlass_simt_s8_ifprop_1x1_identity_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4.cu",
+    "cutlass_simt_s8_ifprop_id_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4.cu",
+    "cutlass_simt_s8_ifprop_1x1_id_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4.cu",
     "cutlass_simt_s8_ifprop_relu_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4.cu",
     "cutlass_simt_s8_ifprop_1x1_relu_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4.cu",
     "cutlass_simt_s8_ifprop_hswish_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4.cu",
     "cutlass_simt_s8_ifprop_1x1_hswish_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4.cu",
-    "cutlass_simt_s8_ifprop_identity_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4.cu",
-    "cutlass_simt_s8_ifprop_1x1_identity_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4.cu",
-    "cutlass_simt_s8_ifprop_relu_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4.cu",
-    "cutlass_simt_s8_ifprop_1x1_relu_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4.cu",
-    "cutlass_simt_s8_ifprop_hswish_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4.cu",
-    "cutlass_simt_s8_ifprop_1x1_hswish_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4.cu",
-    "cutlass_simt_s8_ifprop_identity_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4.cu",
-    "cutlass_simt_s8_ifprop_1x1_identity_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4.cu",
+    "cutlass_simt_s8_ifprop_id_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4.cu",
+    "cutlass_simt_s8_ifprop_1x1_id_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4.cu",
     "cutlass_simt_s8_ifprop_relu_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4.cu",
     "cutlass_simt_s8_ifprop_1x1_relu_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4.cu",
     "cutlass_simt_s8_ifprop_hswish_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4.cu",
     "cutlass_simt_s8_ifprop_1x1_hswish_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4.cu",
-    "cutlass_simt_s8_ifprop_identity_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4.cu",
-    "cutlass_simt_s8_ifprop_1x1_identity_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4.cu",
+    "cutlass_simt_s8_ifprop_id_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4.cu",
+    "cutlass_simt_s8_ifprop_1x1_id_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4.cu",
     "cutlass_simt_s8_ifprop_relu_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4.cu",
     "cutlass_simt_s8_ifprop_1x1_relu_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4.cu",
     "cutlass_simt_s8_ifprop_hswish_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4.cu",
     "cutlass_simt_s8_ifprop_1x1_hswish_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4.cu",
-    "cutlass_simt_s8_ifprop_identity_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
-    "cutlass_simt_s8_ifprop_1x1_identity_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
+    "cutlass_simt_s8_ifprop_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
+    "cutlass_simt_s8_ifprop_1x1_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
     "cutlass_simt_s8_ifprop_relu_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
     "cutlass_simt_s8_ifprop_1x1_relu_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
     "cutlass_simt_s8_ifprop_hswish_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
     "cutlass_simt_s8_ifprop_1x1_hswish_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
-    "cutlass_simt_s8_ifprop_identity_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
-    "cutlass_simt_s8_ifprop_1x1_identity_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
+    "cutlass_simt_s8_ifprop_id_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
+    "cutlass_simt_s8_ifprop_1x1_id_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
     "cutlass_simt_s8_ifprop_relu_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
     "cutlass_simt_s8_ifprop_1x1_relu_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
     "cutlass_simt_s8_ifprop_hswish_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
     "cutlass_simt_s8_ifprop_1x1_hswish_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
-    "cutlass_simt_s8_ifprop_identity_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
-    "cutlass_simt_s8_ifprop_1x1_identity_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
+    "cutlass_simt_s8_ifprop_id_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
+    "cutlass_simt_s8_ifprop_1x1_id_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
     "cutlass_simt_s8_ifprop_relu_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
     "cutlass_simt_s8_ifprop_1x1_relu_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
     "cutlass_simt_s8_ifprop_hswish_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
     "cutlass_simt_s8_ifprop_1x1_hswish_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
-    "cutlass_simt_s8_ifprop_identity_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
-    "cutlass_simt_s8_ifprop_1x1_identity_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
-    "cutlass_simt_s8_ifprop_relu_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
-    "cutlass_simt_s8_ifprop_1x1_relu_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
-    "cutlass_simt_s8_ifprop_hswish_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
-    "cutlass_simt_s8_ifprop_1x1_hswish_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
-    "cutlass_simt_s8_ifprop_identity_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
-    "cutlass_simt_s8_ifprop_1x1_identity_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
+    "cutlass_simt_s8_ifprop_id_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
+    "cutlass_simt_s8_ifprop_1x1_id_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
     "cutlass_simt_s8_ifprop_relu_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
     "cutlass_simt_s8_ifprop_1x1_relu_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
     "cutlass_simt_s8_ifprop_hswish_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
     "cutlass_simt_s8_ifprop_1x1_hswish_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
-    "cutlass_simt_s8_ifprop_identity_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
-    "cutlass_simt_s8_ifprop_1x1_identity_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
+    "cutlass_simt_s8_ifprop_id_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
+    "cutlass_simt_s8_ifprop_1x1_id_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
     "cutlass_simt_s8_ifprop_relu_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
     "cutlass_simt_s8_ifprop_1x1_relu_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
     "cutlass_simt_s8_ifprop_hswish_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
     "cutlass_simt_s8_ifprop_1x1_hswish_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
-    "cutlass_simt_s8_ifprop_identity_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
-    "cutlass_simt_s8_ifprop_1x1_identity_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
+    "cutlass_simt_s8_ifprop_id_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
+    "cutlass_simt_s8_ifprop_1x1_id_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
     "cutlass_simt_s8_ifprop_relu_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
     "cutlass_simt_s8_ifprop_1x1_relu_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
     "cutlass_simt_s8_ifprop_hswish_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
     "cutlass_simt_s8_ifprop_1x1_hswish_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
-    "cutlass_simt_s8_ifprop_identity_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
-    "cutlass_simt_s8_ifprop_1x1_identity_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
+    "cutlass_simt_s8_ifprop_id_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
+    "cutlass_simt_s8_ifprop_1x1_id_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
     "cutlass_simt_s8_ifprop_relu_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
     "cutlass_simt_s8_ifprop_1x1_relu_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
     "cutlass_simt_s8_ifprop_hswish_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
     "cutlass_simt_s8_ifprop_1x1_hswish_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
-    "cutlass_simt_s8_ifprop_identity_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
-    "cutlass_simt_s8_ifprop_1x1_identity_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
-    "cutlass_simt_s8_ifprop_relu_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
-    "cutlass_simt_s8_ifprop_1x1_relu_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
-    "cutlass_simt_s8_ifprop_hswish_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
-    "cutlass_simt_s8_ifprop_1x1_hswish_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nc32hw32.cu",
-    "cutlass_simt_u4_ifprop_identity_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu",
+    "cutlass_simt_u4_ifprop_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu",
     "cutlass_simt_u4_ifprop_relu_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu",
     "cutlass_simt_u4_ifprop_hswish_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu",
-    "cutlass_simt_u4_ifprop_identity_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu",
+    "cutlass_simt_u4_ifprop_id_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu",
     "cutlass_simt_u4_ifprop_relu_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu",
     "cutlass_simt_u4_ifprop_hswish_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu",
-    "cutlass_simt_u4_ifprop_identity_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu",
+    "cutlass_simt_u4_ifprop_id_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu",
     "cutlass_simt_u4_ifprop_relu_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu",
     "cutlass_simt_u4_ifprop_hswish_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu",
-    "cutlass_simt_u4_ifprop_identity_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu",
-    "cutlass_simt_u4_ifprop_relu_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu",
-    "cutlass_simt_u4_ifprop_hswish_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu",
-    "cutlass_simt_u4_ifprop_identity_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu",
+    "cutlass_simt_u4_ifprop_id_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu",
     "cutlass_simt_u4_ifprop_relu_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu",
     "cutlass_simt_u4_ifprop_hswish_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu",
-    "cutlass_simt_u4_ifprop_identity_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu",
+    "cutlass_simt_u4_ifprop_id_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu",
     "cutlass_simt_u4_ifprop_relu_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu",
     "cutlass_simt_u4_ifprop_hswish_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu",
-    "cutlass_simt_u4_ifprop_identity_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu",
+    "cutlass_simt_u4_ifprop_id_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu",
     "cutlass_simt_u4_ifprop_relu_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu",
     "cutlass_simt_u4_ifprop_hswish_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu",
-    "cutlass_simt_u4_ifprop_identity_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu",
+    "cutlass_simt_u4_ifprop_id_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu",
     "cutlass_simt_u4_ifprop_relu_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu",
     "cutlass_simt_u4_ifprop_hswish_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu",
-    "cutlass_simt_u4_ifprop_identity_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nhwc.cu",
-    "cutlass_simt_u4_ifprop_relu_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nhwc.cu",
-    "cutlass_simt_u4_ifprop_hswish_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nhwc.cu",
-    "cutlass_simt_u4_ifprop_identity_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nhwc.cu",
+    "cutlass_simt_u4_ifprop_id_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nhwc.cu",
     "cutlass_simt_u4_ifprop_relu_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nhwc.cu",
     "cutlass_simt_u4_ifprop_hswish_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nhwc.cu",
-    "cutlass_simt_u4_ifprop_identity_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nhwc.cu",
+    "cutlass_simt_u4_ifprop_id_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nhwc.cu",
     "cutlass_simt_u4_ifprop_relu_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nhwc.cu",
     "cutlass_simt_u4_ifprop_hswish_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nhwc.cu",
-    "cutlass_simt_s4_ifprop_identity_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu",
+    "cutlass_simt_s4_ifprop_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu",
     "cutlass_simt_s4_ifprop_relu_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu",
     "cutlass_simt_s4_ifprop_hswish_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu",
-    "cutlass_simt_s4_ifprop_identity_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu",
+    "cutlass_simt_s4_ifprop_id_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu",
     "cutlass_simt_s4_ifprop_relu_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu",
     "cutlass_simt_s4_ifprop_hswish_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu",
-    "cutlass_simt_s4_ifprop_identity_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu",
+    "cutlass_simt_s4_ifprop_id_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu",
     "cutlass_simt_s4_ifprop_relu_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu",
     "cutlass_simt_s4_ifprop_hswish_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu",
-    "cutlass_simt_s4_ifprop_identity_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu",
-    "cutlass_simt_s4_ifprop_relu_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu",
-    "cutlass_simt_s4_ifprop_hswish_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu",
-    "cutlass_simt_s4_ifprop_identity_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu",
+    "cutlass_simt_s4_ifprop_id_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu",
     "cutlass_simt_s4_ifprop_relu_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu",
     "cutlass_simt_s4_ifprop_hswish_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu",
-    "cutlass_simt_s4_ifprop_identity_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu",
+    "cutlass_simt_s4_ifprop_id_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu",
     "cutlass_simt_s4_ifprop_relu_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu",
     "cutlass_simt_s4_ifprop_hswish_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu",
-    "cutlass_simt_s4_ifprop_identity_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu",
+    "cutlass_simt_s4_ifprop_id_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu",
     "cutlass_simt_s4_ifprop_relu_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu",
     "cutlass_simt_s4_ifprop_hswish_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nhwc.cu",
-    "cutlass_simt_s4_ifprop_identity_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu",
+    "cutlass_simt_s4_ifprop_id_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu",
     "cutlass_simt_s4_ifprop_relu_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu",
     "cutlass_simt_s4_ifprop_hswish_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nhwc.cu",
-    "cutlass_simt_s4_ifprop_identity_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nhwc.cu",
-    "cutlass_simt_s4_ifprop_relu_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nhwc.cu",
-    "cutlass_simt_s4_ifprop_hswish_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nhwc.cu",
-    "cutlass_simt_s4_ifprop_identity_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nhwc.cu",
+    "cutlass_simt_s4_ifprop_id_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nhwc.cu",
     "cutlass_simt_s4_ifprop_relu_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nhwc.cu",
     "cutlass_simt_s4_ifprop_hswish_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nhwc.cu",
-    "cutlass_simt_s4_ifprop_identity_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nhwc.cu",
+    "cutlass_simt_s4_ifprop_id_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nhwc.cu",
     "cutlass_simt_s4_ifprop_relu_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nhwc.cu",
     "cutlass_simt_s4_ifprop_hswish_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nhwc.cu",
-    "cutlass_simt_f32_ifprop_identity_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu",
-    "cutlass_simt_f32_ifprop_1x1_identity_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu",
+    "cutlass_simt_f32_ifprop_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu",
+    "cutlass_simt_f32_ifprop_1x1_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu",
     "cutlass_simt_f32_ifprop_relu_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu",
     "cutlass_simt_f32_ifprop_1x1_relu_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu",
     "cutlass_simt_f32_ifprop_hswish_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu",
     "cutlass_simt_f32_ifprop_1x1_hswish_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu",
-    "cutlass_simt_f32_ifprop_identity_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu",
-    "cutlass_simt_f32_ifprop_1x1_identity_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu",
+    "cutlass_simt_f32_ifprop_id_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu",
+    "cutlass_simt_f32_ifprop_1x1_id_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu",
     "cutlass_simt_f32_ifprop_relu_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu",
     "cutlass_simt_f32_ifprop_1x1_relu_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu",
     "cutlass_simt_f32_ifprop_hswish_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu",
     "cutlass_simt_f32_ifprop_1x1_hswish_s8_128x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu",
-    "cutlass_simt_f32_ifprop_identity_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu",
-    "cutlass_simt_f32_ifprop_1x1_identity_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu",
+    "cutlass_simt_f32_ifprop_id_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu",
+    "cutlass_simt_f32_ifprop_1x1_id_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu",
     "cutlass_simt_f32_ifprop_relu_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu",
     "cutlass_simt_f32_ifprop_1x1_relu_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu",
     "cutlass_simt_f32_ifprop_hswish_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu",
     "cutlass_simt_f32_ifprop_1x1_hswish_s8_64x128x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu",
-    "cutlass_simt_f32_ifprop_identity_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu",
-    "cutlass_simt_f32_ifprop_1x1_identity_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu",
-    "cutlass_simt_f32_ifprop_relu_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu",
-    "cutlass_simt_f32_ifprop_1x1_relu_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu",
-    "cutlass_simt_f32_ifprop_hswish_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu",
-    "cutlass_simt_f32_ifprop_1x1_hswish_s8_64x64x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu",
-    "cutlass_simt_f32_ifprop_identity_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu",
-    "cutlass_simt_f32_ifprop_1x1_identity_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu",
+    "cutlass_simt_f32_ifprop_id_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu",
+    "cutlass_simt_f32_ifprop_1x1_id_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu",
     "cutlass_simt_f32_ifprop_relu_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu",
     "cutlass_simt_f32_ifprop_1x1_relu_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu",
     "cutlass_simt_f32_ifprop_hswish_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu",
     "cutlass_simt_f32_ifprop_1x1_hswish_s8_128x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu",
-    "cutlass_simt_f32_ifprop_identity_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu",
-    "cutlass_simt_f32_ifprop_1x1_identity_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu",
+    "cutlass_simt_f32_ifprop_id_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu",
+    "cutlass_simt_f32_ifprop_1x1_id_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu",
     "cutlass_simt_f32_ifprop_relu_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu",
     "cutlass_simt_f32_ifprop_1x1_relu_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu",
     "cutlass_simt_f32_ifprop_hswish_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu",
     "cutlass_simt_f32_ifprop_1x1_hswish_s8_32x128x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu",
-    "cutlass_simt_f32_ifprop_identity_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu",
-    "cutlass_simt_f32_ifprop_1x1_identity_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu",
+    "cutlass_simt_f32_ifprop_id_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu",
+    "cutlass_simt_f32_ifprop_1x1_id_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu",
     "cutlass_simt_f32_ifprop_relu_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu",
     "cutlass_simt_f32_ifprop_1x1_relu_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu",
     "cutlass_simt_f32_ifprop_hswish_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu",
     "cutlass_simt_f32_ifprop_1x1_hswish_s8_32x64x32_32x64x32_2_nc4hw4_c4rsk4_nchw.cu",
-    "cutlass_simt_f32_ifprop_identity_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu",
-    "cutlass_simt_f32_ifprop_1x1_identity_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu",
+    "cutlass_simt_f32_ifprop_id_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu",
+    "cutlass_simt_f32_ifprop_1x1_id_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu",
     "cutlass_simt_f32_ifprop_relu_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu",
     "cutlass_simt_f32_ifprop_1x1_relu_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu",
     "cutlass_simt_f32_ifprop_hswish_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu",
     "cutlass_simt_f32_ifprop_1x1_hswish_s8_64x32x32_64x32x32_2_nc4hw4_c4rsk4_nchw.cu",
-    "cutlass_simt_f32_ifprop_identity_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nchw.cu",
-    "cutlass_simt_f32_ifprop_1x1_identity_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nchw.cu",
-    "cutlass_simt_f32_ifprop_relu_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nchw.cu",
-    "cutlass_simt_f32_ifprop_1x1_relu_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nchw.cu",
-    "cutlass_simt_f32_ifprop_hswish_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nchw.cu",
-    "cutlass_simt_f32_ifprop_1x1_hswish_s8_32x32x32_32x32x32_2_nc4hw4_c4rsk4_nchw.cu",
-    "cutlass_simt_f32_ifprop_identity_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nchw.cu",
-    "cutlass_simt_f32_ifprop_1x1_identity_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nchw.cu",
+    "cutlass_simt_f32_ifprop_id_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nchw.cu",
+    "cutlass_simt_f32_ifprop_1x1_id_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nchw.cu",
     "cutlass_simt_f32_ifprop_relu_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nchw.cu",
     "cutlass_simt_f32_ifprop_1x1_relu_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nchw.cu",
     "cutlass_simt_f32_ifprop_hswish_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nchw.cu",
     "cutlass_simt_f32_ifprop_1x1_hswish_s8_16x128x16_16x128x16_1_nc4hw4_c4rsk4_nchw.cu",
-    "cutlass_simt_f32_ifprop_identity_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nchw.cu",
-    "cutlass_simt_f32_ifprop_1x1_identity_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nchw.cu",
+    "cutlass_simt_f32_ifprop_id_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nchw.cu",
+    "cutlass_simt_f32_ifprop_1x1_id_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nchw.cu",
     "cutlass_simt_f32_ifprop_relu_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nchw.cu",
     "cutlass_simt_f32_ifprop_1x1_relu_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nchw.cu",
     "cutlass_simt_f32_ifprop_hswish_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nchw.cu",
     "cutlass_simt_f32_ifprop_1x1_hswish_s8_16x64x8_16x64x8_2_nc4hw4_c4rsk4_nchw.cu",
-    "cutlass_tensorop_s8_i8816fprop_identity_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32.cu",
-    "cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32.cu",
-    "cutlass_tensorop_s8_i8816fprop_relu_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32.cu",
-    "cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32.cu",
-    "cutlass_tensorop_s8_i8816fprop_hswish_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32.cu",
-    "cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32.cu",
-    "cutlass_tensorop_s8_i8816fprop_identity_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu",
-    "cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu",
-    "cutlass_tensorop_s8_i8816fprop_relu_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu",
-    "cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu",
-    "cutlass_tensorop_s8_i8816fprop_hswish_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu",
-    "cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu",
| "cutlass_tensorop_s8_i8816fprop_identity_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_relu_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_hswish_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_identity_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_relu_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_hswish_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_identity_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_relu_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_hswish_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_identity_s8_64x64x64_32x32x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_64x64x64_32x32x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_relu_s8_64x64x64_32x32x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_64x64x64_32x32x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_hswish_s8_64x64x64_32x32x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_64x64x64_32x32x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_identity_s8_32x64x64_32x16x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_32x64x64_32x16x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_relu_s8_32x64x64_32x16x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_32x64x64_32x16x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_hswish_s8_32x64x64_32x16x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_32x64x64_32x16x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_identity_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_relu_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_hswish_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_identity_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_roc_id_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_roc_id_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_roc_relu_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_roc_relu_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_roc_hswish_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_roc_hswish_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_roc_id_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_roc_id_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_roc_relu_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_roc_relu_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_roc_hswish_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_roc_hswish_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_roc_id_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_roc_id_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_roc_relu_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_roc_relu_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_roc_hswish_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_roc_hswish_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_roc_id_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_roc_id_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_roc_relu_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_roc_relu_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_roc_hswish_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_roc_hswish_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_roc_id_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_roc_id_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_roc_relu_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_roc_relu_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_roc_hswish_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_roc_hswish_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_roc_id_s8_128x64x32_64x32x32_1_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_roc_id_s8_128x64x32_64x32x32_1_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_roc_relu_s8_128x64x32_64x32x32_1_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_roc_relu_s8_128x64x32_64x32x32_1_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_roc_hswish_s8_128x64x32_64x32x32_1_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_roc_hswish_s8_128x64x32_64x32x32_1_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_roc_id_s8_128x32x32_64x32x32_1_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_roc_id_s8_128x32x32_64x32x32_1_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_roc_relu_s8_128x32x32_64x32x32_1_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_roc_relu_s8_128x32x32_64x32x32_1_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_roc_hswish_s8_128x32x32_64x32x32_1_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_roc_hswish_s8_128x32x32_64x32x32_1_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_roc_id_s8_64x128x32_32x64x32_1_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_roc_id_s8_64x128x32_32x64x32_1_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_roc_relu_s8_64x128x32_32x64x32_1_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_roc_relu_s8_64x128x32_32x64x32_1_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_roc_hswish_s8_64x128x32_32x64x32_1_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_roc_hswish_s8_64x128x32_32x64x32_1_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_roc_id_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_roc_id_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_roc_relu_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_roc_relu_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_roc_hswish_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_roc_hswish_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_id_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_id_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_relu_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | "cutlass_tensorop_s8_i8816fprop_relu_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | ||||
| "cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | "cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | ||||
| "cutlass_tensorop_s8_i8816fprop_hswish_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | "cutlass_tensorop_s8_i8816fprop_hswish_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | ||||
| "cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | "cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_128x256x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | ||||
| "cutlass_tensorop_s8_i8816fprop_identity_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_id_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_id_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_relu_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_hswish_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_256x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_id_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_id_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_relu_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | "cutlass_tensorop_s8_i8816fprop_relu_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | ||||
| "cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | "cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | ||||
| "cutlass_tensorop_s8_i8816fprop_hswish_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | "cutlass_tensorop_s8_i8816fprop_hswish_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | ||||
| "cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | "cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_128x128x64_64x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | ||||
| "cutlass_tensorop_s8_i8816fprop_identity_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_relu_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_hswish_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_identity_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_id_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_id_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_relu_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | "cutlass_tensorop_s8_i8816fprop_relu_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | ||||
| "cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | "cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | ||||
| "cutlass_tensorop_s8_i8816fprop_hswish_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | "cutlass_tensorop_s8_i8816fprop_hswish_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | ||||
| "cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | "cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_128x64x64_64x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | ||||
| "cutlass_tensorop_s8_i8816fprop_identity_s8_64x64x64_32x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_64x64x64_32x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_relu_s8_64x64x64_32x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_64x64x64_32x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_hswish_s8_64x64x64_32x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_64x64x64_32x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_identity_s8_32x64x64_16x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_identity_s8_32x64x64_16x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_relu_s8_32x64x64_16x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_32x64x64_16x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_hswish_s8_32x64x64_16x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_32x64x64_16x32x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_identity_s4_256x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_relu_s4_256x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_hswish_s4_256x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_identity_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_relu_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_hswish_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_identity_u4_s4_256x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_relu_u4_s4_256x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_identity_u4_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_relu_u4_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_identity_s4_128x32x64_64x32x64_2_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_1x1_identity_s4_128x32x64_64x32x64_2_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_relu_s4_128x32x64_64x32x64_2_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_1x1_relu_s4_128x32x64_64x32x64_2_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_hswish_s4_128x32x64_64x32x64_2_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_1x1_hswish_s4_128x32x64_64x32x64_2_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_identity_s4_128x64x64_64x64x64_2_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_1x1_identity_s4_128x64x64_64x64x64_2_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_relu_s4_128x64x64_64x64x64_2_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_1x1_relu_s4_128x64x64_64x64x64_2_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_hswish_s4_128x64x64_64x64x64_2_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_1x1_hswish_s4_128x64x64_64x64x64_2_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_identity_s4_128x32x64_64x32x64_2_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_1x1_identity_s4_128x32x64_64x32x64_2_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_relu_s4_128x32x64_64x32x64_2_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_1x1_relu_s4_128x32x64_64x32x64_2_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_hswish_s4_128x32x64_64x32x64_2_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_1x1_hswish_s4_128x32x64_64x32x64_2_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_identity_s4_128x64x64_64x64x64_2_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_1x1_identity_s4_128x64x64_64x64x64_2_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_relu_s4_128x64x64_64x64x64_2_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_1x1_relu_s4_128x64x64_64x64x64_2_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_hswish_s4_128x64x64_64x64x64_2_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_1x1_hswish_s4_128x64x64_64x64x64_2_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_identity_s4_128x32x64_64x32x64_2_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_1x1_identity_s4_128x32x64_64x32x64_2_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_relu_s4_128x32x64_64x32x64_2_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_1x1_relu_s4_128x32x64_64x32x64_2_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_hswish_s4_128x32x64_64x32x64_2_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_1x1_hswish_s4_128x32x64_64x32x64_2_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_identity_s4_128x64x64_64x64x64_2_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_1x1_identity_s4_128x64x64_64x64x64_2_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_relu_s4_128x64x64_64x64x64_2_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_1x1_relu_s4_128x64x64_64x64x64_2_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_hswish_s4_128x64x64_64x64x64_2_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_1x1_hswish_s4_128x64x64_64x64x64_2_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_identity_u4_s4_128x32x64_64x32x64_2_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_1x1_identity_u4_s4_128x32x64_64x32x64_2_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_relu_u4_s4_128x32x64_64x32x64_2_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_1x1_relu_u4_s4_128x32x64_64x32x64_2_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_identity_u4_s4_128x64x64_64x64x64_2_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_1x1_identity_u4_s4_128x64x64_64x64x64_2_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_relu_u4_s4_128x64x64_64x64x64_2_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_1x1_relu_u4_s4_128x64x64_64x64x64_2_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_identity_u4_s4_128x32x64_64x32x64_2_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_1x1_identity_u4_s4_128x32x64_64x32x64_2_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_relu_u4_s4_128x32x64_64x32x64_2_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_1x1_relu_u4_s4_128x32x64_64x32x64_2_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_identity_u4_s4_128x64x64_64x64x64_2_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_1x1_identity_u4_s4_128x64x64_64x64x64_2_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_relu_u4_s4_128x64x64_64x64x64_2_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_1x1_relu_u4_s4_128x64x64_64x64x64_2_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_identity_u4_s4_128x32x64_64x32x64_2_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_1x1_identity_u4_s4_128x32x64_64x32x64_2_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_relu_u4_s4_128x32x64_64x32x64_2_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_1x1_relu_u4_s4_128x32x64_64x32x64_2_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_identity_u4_s4_128x64x64_64x64x64_2_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_1x1_identity_u4_s4_128x64x64_64x64x64_2_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_relu_u4_s4_128x64x64_64x64x64_2_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_1x1_relu_u4_s4_128x64x64_64x64x64_2_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_id_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_id_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_relu_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_hswish_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_64x128x64_32x64x64_2_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_id_s8_128x64x32_64x32x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_id_s8_128x64x32_64x32x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_relu_s8_128x64x32_64x32x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_128x64x32_64x32x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_hswish_s8_128x64x32_64x32x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_128x64x32_64x32x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_id_s8_128x32x32_64x32x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_id_s8_128x32x32_64x32x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_relu_s8_128x32x32_64x32x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_128x32x32_64x32x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_hswish_s8_128x32x32_64x32x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_128x32x32_64x32x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_id_s8_64x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_id_s8_64x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_relu_s8_64x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_64x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_hswish_s8_64x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_64x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_id_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_id_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_relu_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_relu_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_hswish_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s8_i8816fprop_1x1_hswish_s8_32x128x32_32x64x32_1_nc32hw32_c32rsk32_nc4hw4.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x256x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x128x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x64x128_64x64x128_2_nc64hw64_c64rsk64.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x64x64_64x64x64_1_nc64hw64_c64rsk64.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_id_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_1x1_id_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_relu_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_1x1_relu_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_hswish_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_1x1_hswish_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_1x1_roc_id_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_1x1_roc_relu_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_1x1_roc_hswish_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_id_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_1x1_id_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_relu_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_1x1_relu_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_hswish_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_1x1_hswish_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_1x1_roc_id_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_1x1_roc_relu_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_1x1_roc_hswish_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_id_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_1x1_id_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_relu_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_1x1_relu_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_hswish_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_1x1_hswish_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_1x1_roc_id_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_1x1_roc_relu_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_1x1_roc_hswish_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_id_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_1x1_id_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_relu_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_1x1_relu_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_hswish_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_1x1_hswish_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_1x1_roc_id_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_1x1_roc_relu_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_1x1_roc_hswish_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_id_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_1x1_id_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_relu_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_1x1_relu_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_hswish_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_1x1_hswish_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_1x1_roc_id_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_1x1_roc_relu_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_1x1_roc_hswish_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_id_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_1x1_id_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_relu_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_1x1_relu_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_hswish_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_1x1_hswish_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_roc_id_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_1x1_roc_id_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_roc_relu_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_1x1_roc_relu_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_roc_hswish_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_s4_i8832fprop_1x1_roc_hswish_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_id_u4_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_1x1_id_u4_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_relu_u4_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_1x1_relu_u4_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_1x1_roc_id_u4_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x32x64_64x32x64_1_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_id_u4_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_1x1_id_u4_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_1x1_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_1x1_roc_id_u4_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc8hw8.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_id_u4_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_1x1_id_u4_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_relu_u4_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_1x1_relu_u4_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_1x1_roc_id_u4_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x32x64_64x32x64_1_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_id_u4_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_1x1_id_u4_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_1x1_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_1x1_roc_id_u4_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc16hw16.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_id_u4_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_1x1_id_u4_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_relu_u4_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_1x1_relu_u4_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_1x1_roc_id_u4_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x32x64_64x32x64_1_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_id_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_1x1_id_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_1x1_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_roc_id_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_1x1_roc_id_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_roc_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||||
| "cutlass_tensorop_u4_i8832fprop_1x1_roc_relu_u4_s4_128x64x64_64x64x64_1_nhwc_nc32hw32.cu", | |||||
| ] | ] | ||||
| @@ -217,56 +217,68 @@ void ConvBiasForwardImpl::AlgoPack::fill_imma_algos() { | |||||
| #if CUDA_VERSION >= 10020 | #if CUDA_VERSION >= 10020 | ||||
| { | { | ||||
| using AlgoParam = AlgoInt8NCHW32IMMAImplicitGemm::AlgoParam; | using AlgoParam = AlgoInt8NCHW32IMMAImplicitGemm::AlgoParam; | ||||
| int8_nchw32_imma.emplace_back(AlgoParam{128, 256, 64, 64, 64, 64}); | |||||
| int8_nchw32_imma.emplace_back(AlgoParam{256, 128, 64, 64, 64, 64}); | |||||
| int8_nchw32_imma.emplace_back(AlgoParam{128, 128, 64, 64, 64, 64}); | |||||
| int8_nchw32_imma.emplace_back(AlgoParam{64, 128, 64, 32, 64, 64}); | |||||
| int8_nchw32_imma.emplace_back(AlgoParam{128, 64, 64, 64, 32, 64}); | |||||
| int8_nchw32_imma.emplace_back(AlgoParam{64, 64, 64, 32, 32, 64}); | |||||
| int8_nchw32_imma.emplace_back(AlgoParam{32, 64, 64, 32, 16, 64}); | |||||
| int8_nchw32_imma.emplace_back(AlgoParam{128, 256, 64, 64, 64, 64, 2}); | |||||
| int8_nchw32_imma.emplace_back(AlgoParam{256, 128, 64, 64, 64, 64, 2}); | |||||
| int8_nchw32_imma.emplace_back(AlgoParam{128, 128, 64, 64, 64, 64, 2}); | |||||
| int8_nchw32_imma.emplace_back(AlgoParam{128, 64, 64, 64, 32, 64, 2}); | |||||
| int8_nchw32_imma.emplace_back(AlgoParam{64, 128, 64, 32, 64, 64, 2}); | |||||
| int8_nchw32_imma.emplace_back(AlgoParam{128, 64, 32, 64, 32, 32, 1}); | |||||
| int8_nchw32_imma.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 1}); | |||||
| int8_nchw32_imma.emplace_back(AlgoParam{64, 128, 32, 32, 64, 32, 1}); | |||||
| int8_nchw32_imma.emplace_back(AlgoParam{32, 128, 32, 32, 64, 32, 1}); | |||||
| } | } | ||||
| { | { | ||||
| using AlgoParam = AlgoInt4Int4NCHW64IMMAImplicitGemm::AlgoParam; | using AlgoParam = AlgoInt4Int4NCHW64IMMAImplicitGemm::AlgoParam; | ||||
| int4_int4_nchw64_imma.emplace_back( | int4_int4_nchw64_imma.emplace_back( | ||||
| AlgoParam{128, 128, 128, 64, 64, 128}); | |||||
| AlgoParam{128, 128, 128, 64, 64, 128, 2}); | |||||
| int4_int4_nchw64_imma.emplace_back( | int4_int4_nchw64_imma.emplace_back( | ||||
| AlgoParam{256, 128, 128, 64, 64, 128}); | |||||
| AlgoParam{128, 256, 128, 64, 64, 128, 2}); | |||||
| int4_int4_nchw64_imma.emplace_back( | |||||
| AlgoParam{128, 64, 128, 64, 64, 128, 2}); | |||||
| int4_int4_nchw64_imma.emplace_back( | |||||
| AlgoParam{128, 64, 64, 64, 64, 64, 1}); | |||||
| } | } | ||||
| { | { | ||||
| using AlgoParam = AlgoUInt4Int4NCHW64IMMAImplicitGemm::AlgoParam; | using AlgoParam = AlgoUInt4Int4NCHW64IMMAImplicitGemm::AlgoParam; | ||||
| uint4_int4_nchw64_imma.emplace_back( | uint4_int4_nchw64_imma.emplace_back( | ||||
| AlgoParam{128, 128, 128, 64, 64, 128}); | |||||
| AlgoParam{128, 128, 128, 64, 64, 128, 2}); | |||||
| uint4_int4_nchw64_imma.emplace_back( | |||||
| AlgoParam{128, 256, 128, 64, 64, 128, 2}); | |||||
| uint4_int4_nchw64_imma.emplace_back( | uint4_int4_nchw64_imma.emplace_back( | ||||
| AlgoParam{256, 128, 128, 64, 64, 128}); | |||||
| AlgoParam{128, 64, 128, 64, 64, 128, 2}); | |||||
| uint4_int4_nchw64_imma.emplace_back( | |||||
| AlgoParam{128, 64, 64, 64, 64, 64, 1}); | |||||
| } | } | ||||
| { | { | ||||
| using AlgoParam = AlgoInt4Int4NHWCIMMAImplicitGemm::AlgoParam; | using AlgoParam = AlgoInt4Int4NHWCIMMAImplicitGemm::AlgoParam; | ||||
| int4_int4_nhwc_imma.emplace_back( | int4_int4_nhwc_imma.emplace_back( | ||||
| AlgoParam{128, 32, 64, 64, 32, 64, 32}); | |||||
| AlgoParam{128, 32, 64, 64, 32, 64, 1, 32}); | |||||
| int4_int4_nhwc_imma.emplace_back( | |||||
| AlgoParam{128, 32, 64, 64, 32, 64, 1, 16}); | |||||
| int4_int4_nhwc_imma.emplace_back( | |||||
| AlgoParam{128, 32, 64, 64, 32, 64, 1, 8}); | |||||
| int4_int4_nhwc_imma.emplace_back( | int4_int4_nhwc_imma.emplace_back( | ||||
| AlgoParam{128, 32, 64, 64, 32, 64, 16}); | |||||
| int4_int4_nhwc_imma.emplace_back(AlgoParam{128, 32, 64, 64, 32, 64, 8}); | |||||
| AlgoParam{128, 64, 64, 64, 64, 64, 1, 32}); | |||||
| int4_int4_nhwc_imma.emplace_back( | int4_int4_nhwc_imma.emplace_back( | ||||
| AlgoParam{128, 64, 64, 64, 64, 64, 32}); | |||||
| AlgoParam{128, 64, 64, 64, 64, 64, 1, 16}); | |||||
| int4_int4_nhwc_imma.emplace_back( | int4_int4_nhwc_imma.emplace_back( | ||||
| AlgoParam{128, 64, 64, 64, 64, 64, 16}); | |||||
| int4_int4_nhwc_imma.emplace_back(AlgoParam{128, 64, 64, 64, 64, 64, 8}); | |||||
| AlgoParam{128, 64, 64, 64, 64, 64, 1, 8}); | |||||
| } | } | ||||
| { | { | ||||
| using AlgoParam = AlgoUInt4Int4NHWCIMMAImplicitGemm::AlgoParam; | using AlgoParam = AlgoUInt4Int4NHWCIMMAImplicitGemm::AlgoParam; | ||||
| uint4_int4_nhwc_imma.emplace_back( | uint4_int4_nhwc_imma.emplace_back( | ||||
| AlgoParam{128, 32, 64, 64, 32, 64, 32}); | |||||
| AlgoParam{128, 32, 64, 64, 32, 64, 1, 32}); | |||||
| uint4_int4_nhwc_imma.emplace_back( | uint4_int4_nhwc_imma.emplace_back( | ||||
| AlgoParam{128, 32, 64, 64, 32, 64, 16}); | |||||
| AlgoParam{128, 32, 64, 64, 32, 64, 1, 16}); | |||||
| uint4_int4_nhwc_imma.emplace_back( | uint4_int4_nhwc_imma.emplace_back( | ||||
| AlgoParam{128, 32, 64, 64, 32, 64, 8}); | |||||
| AlgoParam{128, 32, 64, 64, 32, 64, 1, 8}); | |||||
| uint4_int4_nhwc_imma.emplace_back( | uint4_int4_nhwc_imma.emplace_back( | ||||
| AlgoParam{128, 64, 64, 64, 64, 64, 32}); | |||||
| AlgoParam{128, 64, 64, 64, 64, 64, 1, 32}); | |||||
| uint4_int4_nhwc_imma.emplace_back( | uint4_int4_nhwc_imma.emplace_back( | ||||
| AlgoParam{128, 64, 64, 64, 64, 64, 16}); | |||||
| AlgoParam{128, 64, 64, 64, 64, 64, 1, 16}); | |||||
| uint4_int4_nhwc_imma.emplace_back( | uint4_int4_nhwc_imma.emplace_back( | ||||
| AlgoParam{128, 64, 64, 64, 64, 64, 8}); | |||||
| AlgoParam{128, 64, 64, 64, 64, 64, 1, 8}); | |||||
| } | } | ||||
| #endif | #endif | ||||
| } | } | ||||
| @@ -279,10 +291,8 @@ void ConvBiasForwardImpl::AlgoPack::fill_dp4a_algos() { | |||||
| int8_nchw4_dotprod.emplace_back(AlgoParam{64, 128, 32, 64, 32, 32, 2}); | int8_nchw4_dotprod.emplace_back(AlgoParam{64, 128, 32, 64, 32, 32, 2}); | ||||
| int8_nchw4_dotprod.emplace_back(AlgoParam{32, 128, 32, 32, 64, 32, 2}); | int8_nchw4_dotprod.emplace_back(AlgoParam{32, 128, 32, 32, 64, 32, 2}); | ||||
| int8_nchw4_dotprod.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 2}); | int8_nchw4_dotprod.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 2}); | ||||
| int8_nchw4_dotprod.emplace_back(AlgoParam{64, 64, 32, 64, 32, 32, 2}); | |||||
| int8_nchw4_dotprod.emplace_back(AlgoParam{32, 64, 32, 32, 64, 32, 2}); | int8_nchw4_dotprod.emplace_back(AlgoParam{32, 64, 32, 32, 64, 32, 2}); | ||||
| int8_nchw4_dotprod.emplace_back(AlgoParam{64, 32, 32, 64, 32, 32, 2}); | int8_nchw4_dotprod.emplace_back(AlgoParam{64, 32, 32, 64, 32, 32, 2}); | ||||
| int8_nchw4_dotprod.emplace_back(AlgoParam{32, 32, 32, 32, 32, 32, 2}); | |||||
| int8_nchw4_dotprod.emplace_back(AlgoParam{16, 128, 16, 16, 128, 16, 1}); | int8_nchw4_dotprod.emplace_back(AlgoParam{16, 128, 16, 16, 128, 16, 1}); | ||||
| int8_nchw4_dotprod.emplace_back(AlgoParam{16, 64, 8, 16, 64, 8, 2}); | int8_nchw4_dotprod.emplace_back(AlgoParam{16, 64, 8, 16, 64, 8, 2}); | ||||
| } | } | ||||
| @@ -723,6 +723,7 @@ public: | |||||
| int warp_m; | int warp_m; | ||||
| int warp_n; | int warp_n; | ||||
| int warp_k; | int warp_k; | ||||
| int stage;  // number of pipelined mainloop stages | |||||
| }; | }; | ||||
| AlgoInt8NCHW32IMMAImplicitGemm(AlgoParam algo_param) | AlgoInt8NCHW32IMMAImplicitGemm(AlgoParam algo_param) | ||||
| : m_algo_param{algo_param} { | : m_algo_param{algo_param} { | ||||
| @@ -770,6 +771,7 @@ public: | |||||
| int warp_m; | int warp_m; | ||||
| int warp_n; | int warp_n; | ||||
| int warp_k; | int warp_k; | ||||
| int stage;  // number of pipelined mainloop stages | |||||
| }; | }; | ||||
| AlgoInt4NCHW64IMMAImplicitGemmBase(AlgoParam algo_param) | AlgoInt4NCHW64IMMAImplicitGemmBase(AlgoParam algo_param) | ||||
| @@ -897,6 +899,7 @@ public: | |||||
| int warp_m; | int warp_m; | ||||
| int warp_n; | int warp_n; | ||||
| int warp_k; | int warp_k; | ||||
| int stage;  // number of pipelined mainloop stages | |||||
| int access_size; | int access_size; | ||||
| }; | }; | ||||
| @@ -38,7 +38,7 @@ void do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32( | |||||
| const convolution::ConvParam& param, uint32_t nonlinear_mode, | const convolution::ConvParam& param, uint32_t nonlinear_mode, | ||||
| float alpha, float beta, float gamma, float scale, | float alpha, float beta, float gamma, float scale, | ||||
| const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | ||||
| cudaStream_t stream); | |||||
| int stages, cudaStream_t stream); | |||||
| template <bool NeedLoadFromConstMem> | template <bool NeedLoadFromConstMem> | ||||
| void do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4( | void do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4( | ||||
| @@ -47,7 +47,7 @@ void do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4( | |||||
| const convolution::ConvParam& param, uint32_t nonlinear_mode, | const convolution::ConvParam& param, uint32_t nonlinear_mode, | ||||
| float alpha, float beta, float gamma, float scale, | float alpha, float beta, float gamma, float scale, | ||||
| const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | ||||
| cudaStream_t stream); | |||||
| int stages, cudaStream_t stream); | |||||
| template <bool NeedLoadFromConstMem> | template <bool NeedLoadFromConstMem> | ||||
| void do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4( | void do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4( | ||||
| @@ -83,7 +83,7 @@ void do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64( | |||||
| const convolution::ConvParam& param, uint32_t nonlinear_mode, | const convolution::ConvParam& param, uint32_t nonlinear_mode, | ||||
| float alpha, float beta, float gamma, float scale, | float alpha, float beta, float gamma, float scale, | ||||
| const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | ||||
| cudaStream_t stream); | |||||
| int stages, cudaStream_t stream); | |||||
| template <bool NeedLoadFromConstMem> | template <bool NeedLoadFromConstMem> | ||||
| void do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64( | void do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64( | ||||
| @@ -92,7 +92,7 @@ void do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64( | |||||
| const convolution::ConvParam& param, uint32_t nonlinear_mode, | const convolution::ConvParam& param, uint32_t nonlinear_mode, | ||||
| float alpha, float beta, float gamma, float delta, float theta, | float alpha, float beta, float gamma, float delta, float theta, | ||||
| float scale, uint8_t src_zero_point, const GemmCoord& threadblock_shape, | float scale, uint8_t src_zero_point, const GemmCoord& threadblock_shape, | ||||
| const GemmCoord& warp_shape, cudaStream_t stream); | |||||
| const GemmCoord& warp_shape, int stages, cudaStream_t stream); | |||||
| template <bool signedness> | template <bool signedness> | ||||
| void do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nhwc( | void do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nhwc( | ||||
| @@ -110,7 +110,7 @@ void do_conv_bias_int4_int4_implicit_gemm_imma_nhwc( | |||||
| const convolution::ConvParam& param, uint32_t nonlinear_mode, | const convolution::ConvParam& param, uint32_t nonlinear_mode, | ||||
| float alpha, float beta, float gamma, float scale, | float alpha, float beta, float gamma, float scale, | ||||
| const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | ||||
| const int32_t access_size, cudaStream_t stream); | |||||
| const int32_t access_size, int stages, cudaStream_t stream); | |||||
| template <bool NeedLoadFromConstMem> | template <bool NeedLoadFromConstMem> | ||||
| void do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc( | void do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc( | ||||
| @@ -119,7 +119,7 @@ void do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc( | |||||
| const convolution::ConvParam& param, uint32_t nonlinear_mode, | const convolution::ConvParam& param, uint32_t nonlinear_mode, | ||||
| float alpha, float beta, float gamma, float delta, float theta, | float alpha, float beta, float gamma, float delta, float theta, | ||||
| float scale, uint8_t src_zero_point, const GemmCoord& threadblock_shape, | float scale, uint8_t src_zero_point, const GemmCoord& threadblock_shape, | ||||
| const GemmCoord& warp_shape, const int32_t access_size, | |||||
| const GemmCoord& warp_shape, const int32_t access_size, int stages, | |||||
| cudaStream_t stream); | cudaStream_t stream); | ||||
| } // namespace cutlass_wrapper | } // namespace cutlass_wrapper | ||||
| @@ -0,0 +1,595 @@ | |||||
| /** | |||||
| * \file dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cu | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| // ignore warnings from cutlass headers | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #if !MEGDNN_TEGRA_X1 | |||||
| #include "cutlass/convolution/device/convolution.h" | |||||
| #endif | |||||
| #include "src/common/opr_param_defs_enumv.cuh" | |||||
| #include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | |||||
| #pragma GCC diagnostic pop | |||||
| using namespace megdnn; | |||||
| using namespace cuda; | |||||
| using namespace cutlass_wrapper; | |||||
| /* ====== cutlass kernel wrapper for int4 x int4 nchw64 layout ====== */ | |||||
| #if MEGDNN_TEGRA_X1 | |||||
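| // Tegra X1 has no tensor cores; build an empty stub so callers still link. | |||||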
| template <bool NeedLoadFromConstMem> | |||||
| void megdnn::cuda::cutlass_wrapper:: | |||||
| do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64( | |||||
| const int8_t* /* d_src */, const int8_t* /* d_filter */, | |||||
| const int32_t* /* d_bias */, const int8_t* /* d_z */, | |||||
| int8_t* /* d_dst */, int* /* workspace */, | |||||
| const convolution::ConvParam& /* param */, | |||||
| uint32_t /* nonlinear_mode */, float /* alpha */, | |||||
| float /* beta */, float /* gamma */, float /* scale */, | |||||
| const GemmCoord& /* threadblock_shape */, | |||||
| const GemmCoord& /* warp_shape */, int /* stages */, | |||||
| cudaStream_t /* stream */) {} | |||||
| #else | |||||
| template <bool NeedLoadFromConstMem> | |||||
| void megdnn::cuda::cutlass_wrapper:: | |||||
| do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64( | |||||
| const int8_t* d_src, const int8_t* d_filter, | |||||
| const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, | |||||
| int* workspace, const convolution::ConvParam& param, | |||||
| uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||||
| float scale, const GemmCoord& threadblock_shape, | |||||
| const GemmCoord& warp_shape, int stages, cudaStream_t stream) { | |||||
| #define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||||
| threadblock_k_, warp_m_, warp_n_, \ | |||||
| warp_k_, stage_) \ | |||||
| if (threadblock_shape.m() == threadblock_m_ && \ | |||||
| threadblock_shape.n() == threadblock_n_ && \ | |||||
| threadblock_shape.k() == threadblock_k_ && \ | |||||
| warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||||
| warp_shape.k() == warp_k_ && stages == stage_) { \ | |||||
| using ThreadBlockShape = \ | |||||
| cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||||
| threadblock_k_>; \ | |||||
| using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \ | |||||
| using Convolution = cutlass::conv::device::Convolution< \ | |||||
| cutlass::int4b_t, cutlass::layout::TensorNCxHWx<64>, \ | |||||
| cutlass::int4b_t, cutlass::layout::TensorCxRSKx<64>, \ | |||||
| ElementOutput, cutlass::layout::TensorNCxHWx<64>, int32_t, \ | |||||
| cutlass::layout::TensorNCxHWx<64>, int32_t, \ | |||||
| cutlass::conv::ConvType::kConvolution, \ | |||||
| cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | |||||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||||
| cutlass::conv::threadblock:: \ | |||||
| ConvolutionFpropTransThreadblockSwizzle, \ | |||||
| stage_, 32, 32, NeedLoadFromConstMem, \ | |||||
| cutlass::arch::OpMultiplyAddSaturate, \ | |||||
| cutlass::conv::ImplicitGemmMode::GEMM_TN, true>; \ | |||||
| typename Convolution::ConvolutionParameter conv_param( \ | |||||
| param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||||
| param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||||
| param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||||
| return cutlass_convolution_wrapper<Convolution>( \ | |||||
| reinterpret_cast<const cutlass::int4b_t*>(d_src), \ | |||||
| reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \ | |||||
| reinterpret_cast<const cutlass::int4b_t*>(d_z), \ | |||||
| reinterpret_cast<cutlass::int4b_t*>(d_dst), workspace, \ | |||||
| conv_param, epilogue, stream); \ | |||||
| } | |||||
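| // NOTE: each DISPATCH_KERNEL_WITH_TILE_SHAPE expansion is an if-block that | |||||
| // returns from the enclosing function when the runtime threadblock/warp/stage | |||||
| // parameters match, so the trailing megdnn_assert fires only if no shape matched. | |||||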
| #define DISPATCH_KERNEL \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 128, 64, 64, 128, 2); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 128, 64, 64, 128, 2); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 128, 64, 64, 128, 2); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1); \ | |||||
| megdnn_assert(false, \ | |||||
| "unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||||
| "(%dx%dx%d)", \ | |||||
| threadblock_shape.m(), threadblock_shape.n(), \ | |||||
| threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||||
| warp_shape.k()); | |||||
| using ElementOutput = cutlass::int4b_t; | |||||
| using ElementAccumulator = int32_t; | |||||
| using ElementBias = int32_t; | |||||
| using ElementCompute = float; | |||||
| using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; | |||||
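| // Assumed epilogue convention (per the BiasAddLinearCombination* functors | |||||
| // used by megdnn): alpha scales the conv accumulator, beta the bias, and | |||||
| // gamma the residual z; RELU passes an extra 0 as its clamp threshold and | |||||
| // H_SWISH passes the output quantization scale. | |||||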
| switch (nonlinear_mode) { | |||||
| case NonlineMode::IDENTITY: { | |||||
| using EpilogueOp = | |||||
| cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||||
| ElementOutput, 16, ElementAccumulator, ElementBias, | |||||
| ElementCompute>; | |||||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma}; | |||||
| DISPATCH_KERNEL; | |||||
| } | |||||
| case NonlineMode::RELU: { | |||||
| using EpilogueOp = cutlass::epilogue::thread:: | |||||
| BiasAddLinearCombinationReluClamp< | |||||
| ElementOutput, 16, ElementAccumulator, ElementBias, | |||||
| ElementCompute>; | |||||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0}; | |||||
| DISPATCH_KERNEL; | |||||
| } | |||||
| case NonlineMode::H_SWISH: { | |||||
| using EpilogueOp = cutlass::epilogue::thread:: | |||||
| BiasAddLinearCombinationHSwishClamp< | |||||
| ElementOutput, 16, ElementAccumulator, ElementBias, | |||||
| ElementCompute>; | |||||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale}; | |||||
| DISPATCH_KERNEL; | |||||
| } | |||||
| default: | |||||
| megdnn_assert(false, | |||||
| "unsupported nonlinear mode for conv bias operator"); | |||||
| } | |||||
| #undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||||
| #undef DISPATCH_KERNEL | |||||
| } | |||||
| #endif | |||||
| #define INST(need_load_from_const_mem) \ | |||||
| template void megdnn::cuda::cutlass_wrapper:: \ | |||||
| do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64< \ | |||||
| need_load_from_const_mem>( \ | |||||
| const int8_t* d_src, const int8_t* d_filter, \ | |||||
| const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \ | |||||
| int* workspace, const convolution::ConvParam& param, \ | |||||
| uint32_t nonlinear_mode, float alpha, float beta, \ | |||||
| float gamma, float scale, \ | |||||
| const GemmCoord& threadblock_shape, \ | |||||
| const GemmCoord& warp_shape, int stages, \ | |||||
| cudaStream_t stream); | |||||
| INST(true); | |||||
| #undef INST | |||||
| /* ====== cutlass kernel wrapper for uint4 x int4 nchw64 layout ====== */ | |||||
| #if MEGDNN_TEGRA_X1 | |||||
| template <bool NeedLoadFromConstMem> | |||||
| void megdnn::cuda::cutlass_wrapper:: | |||||
| do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64( | |||||
| const uint8_t* /* d_src */, const int8_t* /* d_filter */, | |||||
| const int32_t* /* d_bias */, const uint8_t* /* d_z */, | |||||
| uint8_t* /* d_dst */, int* /* workspace */, | |||||
| const convolution::ConvParam& /* param */, | |||||
| uint32_t /* nonlinear_mode */, float /* alpha */, | |||||
| float /* beta */, float /* gamma */, float /* delta */, | |||||
| float /* theta */, float /* scale */, | |||||
| uint8_t /* src_zero_point */, | |||||
| const GemmCoord& /* threadblock_shape */, | |||||
| const GemmCoord& /* warp_shape */, int /* stages */, | |||||
| cudaStream_t /* stream */) {} | |||||
| #else | |||||
| template <bool NeedLoadFromConstMem> | |||||
| void megdnn::cuda::cutlass_wrapper:: | |||||
| do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64( | |||||
| const uint8_t* d_src, const int8_t* d_filter, | |||||
| const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, | |||||
| int* workspace, const convolution::ConvParam& param, | |||||
| uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||||
| float delta, float theta, float /* scale */, | |||||
| uint8_t src_zero_point, const GemmCoord& threadblock_shape, | |||||
| const GemmCoord& warp_shape, int stages, cudaStream_t stream) { | |||||
| #define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||||
| threadblock_k_, warp_m_, warp_n_, \ | |||||
| warp_k_, stage_) \ | |||||
| if (threadblock_shape.m() == threadblock_m_ && \ | |||||
| threadblock_shape.n() == threadblock_n_ && \ | |||||
| threadblock_shape.k() == threadblock_k_ && \ | |||||
| warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||||
| warp_shape.k() == warp_k_ && stages == stage_) { \ | |||||
| using ThreadBlockShape = \ | |||||
| cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||||
| threadblock_k_>; \ | |||||
| using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \ | |||||
| using Convolution = cutlass::conv::device::Convolution< \ | |||||
| cutlass::uint4b_t, cutlass::layout::TensorNCxHWx<64>, \ | |||||
| cutlass::int4b_t, cutlass::layout::TensorCxRSKx<64>, \ | |||||
| ElementOutput, cutlass::layout::TensorNCxHWx<64>, int32_t, \ | |||||
| cutlass::layout::TensorNCxHWx<64>, int32_t, \ | |||||
| cutlass::conv::ConvType::kConvolution, \ | |||||
| cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | |||||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||||
| cutlass::conv::threadblock:: \ | |||||
| ConvolutionFpropTransThreadblockSwizzle, \ | |||||
| stage_, 32, 32, NeedLoadFromConstMem, \ | |||||
| cutlass::arch::OpMultiplyAddSaturate, \ | |||||
| cutlass::conv::ImplicitGemmMode::GEMM_TN, true>; \ | |||||
| typename Convolution::ConvolutionParameter conv_param( \ | |||||
| param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||||
| param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||||
| param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||||
| return cutlass_convolution_wrapper<Convolution>( \ | |||||
| reinterpret_cast<const cutlass::uint4b_t*>(d_src), \ | |||||
| reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \ | |||||
| reinterpret_cast<const cutlass::uint4b_t*>(d_z), \ | |||||
| reinterpret_cast<cutlass::uint4b_t*>(d_dst), workspace, \ | |||||
| conv_param, epilogue, stream, {src_zero_point}); \ | |||||
| } | |||||
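| // Unlike the signed path, the uint4 wrapper forwards {src_zero_point} as an | |||||
| // extra argument, presumably so the kernel can compensate for the asymmetric | |||||
| // quantization of the unsigned source tensor. | |||||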
| #define DISPATCH_KERNEL \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 128, 64, 64, 128, 2); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 128, 64, 64, 128, 2); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 128, 64, 64, 128, 2); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1); \ | |||||
| megdnn_assert(false, \ | |||||
| "unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||||
| "(%dx%dx%d)", \ | |||||
| threadblock_shape.m(), threadblock_shape.n(), \ | |||||
| threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||||
| warp_shape.k()); | |||||
| using ElementOutput = cutlass::uint4b_t; | |||||
| using ElementAccumulator = int32_t; | |||||
| using ElementBias = int32_t; | |||||
| using ElementCompute = float; | |||||
| using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; | |||||
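| // For unsigned output the IDENTITY epilogue folds delta and theta into one | |||||
| // additive term (delta + theta), while RELU keeps them as separate offsets | |||||
| // after the 0 threshold; H_SWISH falls through to the unsupported-mode | |||||
| // assert on this path. | |||||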
| switch (nonlinear_mode) { | |||||
| case NonlineMode::IDENTITY: { | |||||
| using EpilogueOp = | |||||
| cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||||
| ElementOutput, 16, ElementAccumulator, ElementBias, | |||||
| ElementCompute>; | |||||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, | |||||
| delta + theta}; | |||||
| DISPATCH_KERNEL; | |||||
| } | |||||
| case NonlineMode::RELU: { | |||||
| using EpilogueOp = cutlass::epilogue::thread:: | |||||
| BiasAddLinearCombinationReluClamp< | |||||
| ElementOutput, 16, ElementAccumulator, ElementBias, | |||||
| ElementCompute>; | |||||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, | |||||
| 0, delta, theta}; | |||||
| DISPATCH_KERNEL; | |||||
| } | |||||
| default: | |||||
| megdnn_assert(false, | |||||
| "unsupported nonlinear mode for conv bias operator"); | |||||
| } | |||||
| #undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||||
| #undef DISPATCH_KERNEL | |||||
| } | |||||
| #endif | |||||
| #define INST(need_load_from_const_mem) \ | |||||
| template void megdnn::cuda::cutlass_wrapper:: \ | |||||
| do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64< \ | |||||
| need_load_from_const_mem>( \ | |||||
| const uint8_t* d_src, const int8_t* d_filter, \ | |||||
| const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, \ | |||||
| int* workspace, const convolution::ConvParam& param, \ | |||||
| uint32_t nonlinear_mode, float alpha, float beta, \ | |||||
| float gamma, float delta, float theta, float scale, \ | |||||
| uint8_t src_zero_point, \ | |||||
| const GemmCoord& threadblock_shape, \ | |||||
| const GemmCoord& warp_shape, int stages, \ | |||||
| cudaStream_t stream); | |||||
| INST(true); | |||||
| #undef INST | |||||
| /* ====== cutlass kernel wrapper for int4 x int4 nhwc layout ====== */ | |||||
| #if MEGDNN_TEGRA_X1 | |||||
| template <bool NeedLoadFromConstMem> | |||||
| void megdnn::cuda::cutlass_wrapper:: | |||||
| do_conv_bias_int4_int4_implicit_gemm_imma_nhwc( | |||||
| const int8_t* /* d_src */, const int8_t* /* d_filter */, | |||||
| const int32_t* /* d_bias */, const int8_t* /* d_z */, | |||||
| int8_t* /* d_dst */, int* /* workspace */, | |||||
| const convolution::ConvParam& /* param */, | |||||
| uint32_t /* nonlinear_mode */, float /* alpha */, | |||||
| float /* beta */, float /* gamma */, float /* scale */, | |||||
| const GemmCoord& /* threadblock_shape */, | |||||
| const GemmCoord& /* warp_shape */, | |||||
| const int32_t /* access_size */, int /* stages */, | |||||
| cudaStream_t /* stream */) {} | |||||
| #else | |||||
| template <bool NeedLoadFromConstMem> | |||||
| void megdnn::cuda::cutlass_wrapper:: | |||||
| do_conv_bias_int4_int4_implicit_gemm_imma_nhwc( | |||||
| const int8_t* d_src, const int8_t* d_filter, | |||||
| const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, | |||||
| int* workspace, const convolution::ConvParam& param, | |||||
| uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||||
| float scale, const GemmCoord& threadblock_shape, | |||||
| const GemmCoord& warp_shape, const int32_t access_size, | |||||
| int stages, cudaStream_t stream) { | |||||
| bool without_shared_load = | |||||
| ((param.co % threadblock_shape.n() == 0) && | |||||
| (threadblock_shape.n() == 32 || threadblock_shape.n() == 64)); | |||||
| int out_elements_per_access = | |||||
| without_shared_load ? threadblock_shape.n() / 4 : 8; | |||||
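| // Heuristic as encoded above: the epilogue may skip shared-memory staging | |||||
| // only when co tiles evenly into threadblock N and N is 32 or 64; the output | |||||
| // access then widens to n()/4 int4 elements (8 for N=32, 16 for N=64) | |||||
| // instead of the default 8, matching the out_elements_per_access_ column | |||||
| // in DISPATCH_KERNEL below. | |||||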
| #define RUN_CUTLASS_WRAPPER(stage_, access_size_, without_shared_load_) \ | |||||
| using Convolution = cutlass::conv::device::Convolution< \ | |||||
| cutlass::int4b_t, cutlass::layout::TensorNHWC, cutlass::int4b_t, \ | |||||
| cutlass::layout::TensorNCxHWx<access_size_>, ElementOutput, \ | |||||
| cutlass::layout::TensorNHWC, int32_t, cutlass::layout::TensorNHWC, \ | |||||
| int32_t, cutlass::conv::ConvType::kConvolution, \ | |||||
| cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | |||||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||||
| cutlass::conv::threadblock:: \ | |||||
| ConvolutionFpropTransThreadblockSwizzle, \ | |||||
| stage_, access_size_, access_size_, NeedLoadFromConstMem, \ | |||||
| cutlass::arch::OpMultiplyAddSaturate, \ | |||||
| cutlass::conv::ImplicitGemmMode::GEMM_TN, without_shared_load_>; \ | |||||
| typename Convolution::ConvolutionParameter conv_param( \ | |||||
| param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||||
| param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||||
| param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||||
| return cutlass_convolution_wrapper<Convolution>( \ | |||||
| reinterpret_cast<const cutlass::int4b_t*>(d_src), \ | |||||
| reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \ | |||||
| reinterpret_cast<const cutlass::int4b_t*>(d_z), \ | |||||
| reinterpret_cast<cutlass::int4b_t*>(d_dst), workspace, conv_param, \ | |||||
| epilogue, stream); | |||||
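| // RUN_CUTLASS_WRAPPER is split out of the tile-shape dispatch macro because | |||||
| // the Convolution type here also depends on access_size_ and | |||||
| // without_shared_load_, not just on the threadblock and warp shapes. | |||||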
| #define DISPATCH_KERNEL_WITH_TILE_SHAPE( \ | |||||
| threadblock_m_, threadblock_n_, threadblock_k_, warp_m_, warp_n_, \ | |||||
| warp_k_, stage_, access_size_, out_elements_per_access_, \ | |||||
| without_shared_load_) \ | |||||
| if (threadblock_shape.m() == threadblock_m_ && \ | |||||
| threadblock_shape.n() == threadblock_n_ && \ | |||||
| threadblock_shape.k() == threadblock_k_ && \ | |||||
| warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||||
| warp_shape.k() == warp_k_ && stages == stage_ && \ | |||||
| access_size == access_size_ && \ | |||||
| out_elements_per_access == out_elements_per_access_ && \ | |||||
| without_shared_load == without_shared_load_) { \ | |||||
| using ThreadBlockShape = \ | |||||
| cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||||
| threadblock_k_>; \ | |||||
| using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \ | |||||
| using ElementOutput = cutlass::int4b_t; \ | |||||
| using ElementAccumulator = int32_t; \ | |||||
| using ElementBias = int32_t; \ | |||||
| using ElementCompute = float; \ | |||||
| using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; \ | |||||
| switch (nonlinear_mode) { \ | |||||
| case NonlineMode::IDENTITY: { \ | |||||
| using EpilogueOp = cutlass::epilogue::thread:: \ | |||||
| BiasAddLinearCombinationClamp< \ | |||||
| ElementOutput, out_elements_per_access_, \ | |||||
| ElementAccumulator, ElementBias, \ | |||||
| ElementCompute>; \ | |||||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma}; \ | |||||
| RUN_CUTLASS_WRAPPER(stage_, access_size_, \ | |||||
| without_shared_load_); \ | |||||
| } \ | |||||
| case NonlineMode::RELU: { \ | |||||
| using EpilogueOp = cutlass::epilogue::thread:: \ | |||||
| BiasAddLinearCombinationReluClamp< \ | |||||
| ElementOutput, out_elements_per_access_, \ | |||||
| ElementAccumulator, ElementBias, \ | |||||
| ElementCompute>; \ | |||||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0}; \ | |||||
| RUN_CUTLASS_WRAPPER(stage_, access_size_, \ | |||||
| without_shared_load_); \ | |||||
| } \ | |||||
| case NonlineMode::H_SWISH: { \ | |||||
| using EpilogueOp = cutlass::epilogue::thread:: \ | |||||
| BiasAddLinearCombinationHSwishClamp< \ | |||||
| ElementOutput, out_elements_per_access_, \ | |||||
| ElementAccumulator, ElementBias, \ | |||||
| ElementCompute>; \ | |||||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, \ | |||||
| scale}; \ | |||||
| RUN_CUTLASS_WRAPPER(stage_, access_size_, \ | |||||
| without_shared_load_); \ | |||||
| } \ | |||||
| default: \ | |||||
| megdnn_assert( \ | |||||
| false, \ | |||||
| "unsupported nonlinear mode for conv bias operator"); \ | |||||
| } \ | |||||
| } | |||||
| #define DISPATCH_KERNEL \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 32, 8, false); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 16, 8, false); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 8, 8, false); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 32, 8, false); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 16, 8, false); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 8, 8, false); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 32, 8, true); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 16, 8, true); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 8, 8, true); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 32, 16, true); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 16, 16, true); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 8, 16, true); \ | |||||
| megdnn_assert(false, \ | |||||
| "unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||||
| "(%dx%dx%d) and access_size (%d)", \ | |||||
| threadblock_shape.m(), threadblock_shape.n(), \ | |||||
| threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||||
| warp_shape.k(), access_size); | |||||
| DISPATCH_KERNEL; | |||||
| #undef RUN_CUTLASS_WRAPPER | |||||
| #undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||||
| #undef DISPATCH_KERNEL | |||||
| } | |||||
| #endif | |||||
| #define INST(need_load_from_const_mem) \ | |||||
| template void megdnn::cuda::cutlass_wrapper:: \ | |||||
| do_conv_bias_int4_int4_implicit_gemm_imma_nhwc< \ | |||||
| need_load_from_const_mem>( \ | |||||
| const int8_t* d_src, const int8_t* d_filter, \ | |||||
| const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \ | |||||
| int* workspace, const convolution::ConvParam& param, \ | |||||
| uint32_t nonlinear_mode, float alpha, float beta, \ | |||||
| float gamma, float scale, \ | |||||
| const GemmCoord& threadblock_shape, \ | |||||
| const GemmCoord& warp_shape, const int32_t access_size, \ | |||||
| int stages, cudaStream_t stream); | |||||
| INST(true); | |||||
| INST(false); | |||||
| #undef INST | |||||
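| // Illustrative invocation (hypothetical values, not taken from this diff): | |||||
| // a stage-1 kernel with a 128x64x64 threadblock, 64x64x64 warp and 32-wide | |||||
| // filter accesses would be selected by | |||||
| //   do_conv_bias_int4_int4_implicit_gemm_imma_nhwc<false>( | |||||
| //           src, flt, bias, z, dst, workspace, param, | |||||
| //           /* nonlinear_mode */ 0 /* IDENTITY */, alpha, beta, gamma, | |||||
| //           scale, GemmCoord{128, 64, 64}, GemmCoord{64, 64, 64}, | |||||
| //           /* access_size */ 32, /* stages */ 1, stream); | |||||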
| /* ====== cutlass kernel wrapper for uint4 x int4 nhwc layout ====== */ | |||||
| #if MEGDNN_TEGRA_X1 | |||||
| template <bool NeedLoadFromConstMem> | |||||
| void megdnn::cuda::cutlass_wrapper:: | |||||
| do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc( | |||||
| const uint8_t* /* d_src */, const int8_t* /* d_filter */, | |||||
| const int32_t* /* d_bias */, const uint8_t* /* d_z */, | |||||
| uint8_t* /* d_dst */, int* /* workspace */, | |||||
| const convolution::ConvParam& /* param */, | |||||
| uint32_t /* nonlinear_mode */, float /* alpha */, | |||||
| float /* beta */, float /* gamma */, float /* delta */, | |||||
| float /* theta */, float /* scale */, | |||||
| uint8_t /* src_zero_point */, | |||||
| const GemmCoord& /* threadblock_shape */, | |||||
| const GemmCoord& /* warp_shape */, | |||||
| const int32_t /* access_size */, int /* stages */, | |||||
| cudaStream_t /* stream */) {} | |||||
| #else | |||||
| template <bool NeedLoadFromConstMem> | |||||
| void megdnn::cuda::cutlass_wrapper:: | |||||
| do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc( | |||||
| const uint8_t* d_src, const int8_t* d_filter, | |||||
| const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, | |||||
| int* workspace, const convolution::ConvParam& param, | |||||
| uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||||
| float delta, float theta, float /* scale */, | |||||
| uint8_t src_zero_point, const GemmCoord& threadblock_shape, | |||||
| const GemmCoord& warp_shape, const int32_t access_size, | |||||
| int stages, cudaStream_t stream) { | |||||
| bool without_shared_load = | |||||
| ((param.co % threadblock_shape.n() == 0) && | |||||
| (threadblock_shape.n() == 32 || threadblock_shape.n() == 64)); | |||||
| int out_elements_per_access = | |||||
| without_shared_load ? threadblock_shape.n() / 4 : 8; | |||||
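| // Same shared-load-skipping heuristic as the signed int4 NHWC path above. | |||||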
| #define RUN_CUTLASS_WRAPPER(stage_, access_size_, without_shared_load_) \ | |||||
| using Convolution = cutlass::conv::device::Convolution< \ | |||||
| cutlass::uint4b_t, cutlass::layout::TensorNHWC, cutlass::int4b_t, \ | |||||
| cutlass::layout::TensorNCxHWx<access_size_>, ElementOutput, \ | |||||
| cutlass::layout::TensorNHWC, int32_t, cutlass::layout::TensorNHWC, \ | |||||
| int32_t, cutlass::conv::ConvType::kConvolution, \ | |||||
| cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | |||||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||||
| cutlass::conv::threadblock:: \ | |||||
| ConvolutionFpropTransThreadblockSwizzle, \ | |||||
| stage_, access_size_, access_size_, NeedLoadFromConstMem, \ | |||||
| cutlass::arch::OpMultiplyAddSaturate, \ | |||||
| cutlass::conv::ImplicitGemmMode::GEMM_TN, without_shared_load_>; \ | |||||
| typename Convolution::ConvolutionParameter conv_param( \ | |||||
| param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||||
| param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||||
| param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||||
| return cutlass_convolution_wrapper<Convolution>( \ | |||||
| reinterpret_cast<const cutlass::uint4b_t*>(d_src), \ | |||||
| reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \ | |||||
| reinterpret_cast<const cutlass::uint4b_t*>(d_z), \ | |||||
| reinterpret_cast<cutlass::uint4b_t*>(d_dst), workspace, \ | |||||
| conv_param, epilogue, stream, {src_zero_point}); | |||||
| #define DISPATCH_KERNEL_WITH_TILE_SHAPE( \ | |||||
| threadblock_m_, threadblock_n_, threadblock_k_, warp_m_, warp_n_, \ | |||||
| warp_k_, stage_, access_size_, out_elements_per_access_, \ | |||||
| without_shared_load_) \ | |||||
| if (threadblock_shape.m() == threadblock_m_ && \ | |||||
| threadblock_shape.n() == threadblock_n_ && \ | |||||
| threadblock_shape.k() == threadblock_k_ && \ | |||||
| warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||||
| warp_shape.k() == warp_k_ && stages == stage_ && \ | |||||
| access_size == access_size_ && \ | |||||
| out_elements_per_access == out_elements_per_access_ && \ | |||||
| without_shared_load == without_shared_load_) { \ | |||||
| using ThreadBlockShape = \ | |||||
| cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||||
| threadblock_k_>; \ | |||||
| using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \ | |||||
| using ElementOutput = cutlass::uint4b_t; \ | |||||
| using ElementAccumulator = int32_t; \ | |||||
| using ElementBias = int32_t; \ | |||||
| using ElementCompute = float; \ | |||||
| using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; \ | |||||
| switch (nonlinear_mode) { \ | |||||
| case NonlineMode::IDENTITY: { \ | |||||
| using EpilogueOp = cutlass::epilogue::thread:: \ | |||||
| BiasAddLinearCombinationClamp< \ | |||||
| ElementOutput, out_elements_per_access_, \ | |||||
| ElementAccumulator, ElementBias, \ | |||||
| ElementCompute>; \ | |||||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, \ | |||||
| delta + theta}; \ | |||||
| RUN_CUTLASS_WRAPPER(stage_, access_size_, \ | |||||
| without_shared_load_); \ | |||||
| } \ | |||||
| case NonlineMode::RELU: { \ | |||||
| using EpilogueOp = cutlass::epilogue::thread:: \ | |||||
| BiasAddLinearCombinationReluClamp< \ | |||||
| ElementOutput, out_elements_per_access_, \ | |||||
| ElementAccumulator, ElementBias, \ | |||||
| ElementCompute>; \ | |||||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, \ | |||||
| 0, delta, theta}; \ | |||||
| RUN_CUTLASS_WRAPPER(stage_, access_size_, \ | |||||
| without_shared_load_); \ | |||||
| } \ | |||||
| default: \ | |||||
| megdnn_assert( \ | |||||
| false, \ | |||||
| "unsupported nonlinear mode for conv bias operator"); \ | |||||
| } \ | |||||
| } | |||||
| #define DISPATCH_KERNEL \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 32, 8, false); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 16, 8, false); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 8, 8, false); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 32, 8, false); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 16, 8, false); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 8, 8, false); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 32, 8, true); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 16, 8, true); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 8, 8, true); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 32, 16, true); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 16, 16, true); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 8, 16, true); \ | |||||
| megdnn_assert(false, \ | |||||
| "unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||||
| "(%dx%dx%d) and access_size (%d)", \ | |||||
| threadblock_shape.m(), threadblock_shape.n(), \ | |||||
| threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||||
| warp_shape.k(), access_size); | |||||
| DISPATCH_KERNEL; | |||||
| #undef RUN_CUTLASS_WRAPPER | |||||
| #undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||||
| #undef DISPATCH_KERNEL | |||||
| } | |||||
| #endif | |||||
| #define INST(need_load_from_const_mem) \ | |||||
| template void megdnn::cuda::cutlass_wrapper:: \ | |||||
| do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc< \ | |||||
| need_load_from_const_mem>( \ | |||||
| const uint8_t* d_src, const int8_t* d_filter, \ | |||||
| const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, \ | |||||
| int* workspace, const convolution::ConvParam& param, \ | |||||
| uint32_t nonlinear_mode, float alpha, float beta, \ | |||||
| float gamma, float delta, float theta, float scale, \ | |||||
| uint8_t src_zero_point, \ | |||||
| const GemmCoord& threadblock_shape, \ | |||||
| const GemmCoord& warp_shape, const int32_t access_size, \ | |||||
| int stages, cudaStream_t stream); | |||||
| INST(true); | |||||
| INST(false); | |||||
| #undef INST | |||||
| // vim: syntax=cuda.doxygen | |||||
| @@ -38,7 +38,8 @@ void megdnn::cuda::cutlass_wrapper:: | |||||
| uint32_t /* nonlinear_mode */, float /* alpha */, | uint32_t /* nonlinear_mode */, float /* alpha */, | ||||
| float /* beta */, float /* gamma */, float /* scale */, | float /* beta */, float /* gamma */, float /* scale */, | ||||
| const GemmCoord& /* threadblock_shape */, | const GemmCoord& /* threadblock_shape */, | ||||
| const GemmCoord& /* warp_shape */, cudaStream_t /* stream */) {} | |||||
| const GemmCoord& /* warp_shape */, int /* stages */, | |||||
| cudaStream_t /* stream */) {} | |||||
| #else | #else | ||||
| template <bool NeedLoadFromConstMem> | template <bool NeedLoadFromConstMem> | ||||
| void megdnn::cuda::cutlass_wrapper:: | void megdnn::cuda::cutlass_wrapper:: | ||||
| @@ -48,15 +49,15 @@ void megdnn::cuda::cutlass_wrapper:: | |||||
| int* workspace, const convolution::ConvParam& param, | int* workspace, const convolution::ConvParam& param, | ||||
| uint32_t nonlinear_mode, float alpha, float beta, float gamma, | uint32_t nonlinear_mode, float alpha, float beta, float gamma, | ||||
| float scale, const GemmCoord& threadblock_shape, | float scale, const GemmCoord& threadblock_shape, | ||||
| const GemmCoord& warp_shape, cudaStream_t stream) { | |||||
| const GemmCoord& warp_shape, int stages, cudaStream_t stream) { | |||||
| #define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | #define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | ||||
| threadblock_k_, warp_m_, warp_n_, \ | threadblock_k_, warp_m_, warp_n_, \ | ||||
| warp_k_) \ | |||||
| warp_k_, stage_) \ | |||||
| if (threadblock_shape.m() == threadblock_m_ && \ | if (threadblock_shape.m() == threadblock_m_ && \ | ||||
| threadblock_shape.n() == threadblock_n_ && \ | threadblock_shape.n() == threadblock_n_ && \ | ||||
| threadblock_shape.k() == threadblock_k_ && \ | threadblock_shape.k() == threadblock_k_ && \ | ||||
| warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | ||||
| warp_shape.k() == warp_k_) { \ | |||||
| warp_shape.k() == warp_k_ && stages == stage_) { \ | |||||
| using ThreadBlockShape = \ | using ThreadBlockShape = \ | ||||
| cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | ||||
| threadblock_k_>; \ | threadblock_k_>; \ | ||||
| @@ -71,8 +72,10 @@ void megdnn::cuda::cutlass_wrapper:: | |||||
| cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | ||||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | ||||
| cutlass::conv::threadblock:: \ | cutlass::conv::threadblock:: \ | ||||
| ConvolutionFpropNCxHWxThreadblockSwizzle, \ | |||||
| 2, 16, 16, NeedLoadFromConstMem>; \ | |||||
| ConvolutionFpropTransThreadblockSwizzle, \ | |||||
| stage_, 16, 16, NeedLoadFromConstMem, \ | |||||
| cutlass::arch::OpMultiplyAddSaturate, \ | |||||
| cutlass::conv::ImplicitGemmMode::GEMM_TN, true>; \ | |||||
| typename Convolution::ConvolutionParameter conv_param( \ | typename Convolution::ConvolutionParameter conv_param( \ | ||||
| param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | ||||
| param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | ||||
| @@ -82,13 +85,15 @@ void megdnn::cuda::cutlass_wrapper:: | |||||
| epilogue, stream); \ | epilogue, stream); \ | ||||
| } | } | ||||
| #define DISPATCH_KERNEL \ | #define DISPATCH_KERNEL \ | ||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 64, 64, 64, 64); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 64, 64, 64, 64); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 64, 64, 64, 64); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 64, 32, 64, 64); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 32, 64); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 64, 32, 32, 64); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 64, 32, 16, 64); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 64, 64, 64, 64, 2); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 64, 64, 64, 64, 2); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 64, 64, 64, 64, 2); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 32, 64, 2); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 64, 32, 64, 64, 2); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 1); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 1); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 32, 64, 32, 1); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 1); \ | |||||
| megdnn_assert(false, \ | megdnn_assert(false, \ | ||||
| "unsupported threadblock shape (%dx%dx%d) and warp shape " \ | "unsupported threadblock shape (%dx%dx%d) and warp shape " \ | ||||
| "(%dx%dx%d)", \ | "(%dx%dx%d)", \ | ||||
| @@ -144,7 +149,8 @@ void megdnn::cuda::cutlass_wrapper:: | |||||
| uint32_t nonlinear_mode, float alpha, float beta, \ | uint32_t nonlinear_mode, float alpha, float beta, \ | ||||
| float gamma, float scale, \ | float gamma, float scale, \ | ||||
| const GemmCoord& threadblock_shape, \ | const GemmCoord& threadblock_shape, \ | ||||
| const GemmCoord& warp_shape, cudaStream_t stream); | |||||
| const GemmCoord& warp_shape, int stages, \ | |||||
| cudaStream_t stream); | |||||
| INST(true); | INST(true); | ||||
| INST(false); | INST(false); | ||||
| #undef INST | #undef INST | ||||
| @@ -162,7 +168,8 @@ void megdnn::cuda::cutlass_wrapper:: | |||||
| uint32_t /* nonlinear_mode */, float /* alpha */, | uint32_t /* nonlinear_mode */, float /* alpha */, | ||||
| float /* beta */, float /* gamma */, float /* scale */, | float /* beta */, float /* gamma */, float /* scale */, | ||||
| const GemmCoord& /* threadblock_shape */, | const GemmCoord& /* threadblock_shape */, | ||||
| const GemmCoord& /* warp_shape */, cudaStream_t /* stream */) {} | |||||
| const GemmCoord& /* warp_shape */, int /* stages */, | |||||
| cudaStream_t /* stream */) {} | |||||
| #else | #else | ||||
| template <bool NeedLoadFromConstMem> | template <bool NeedLoadFromConstMem> | ||||
| void megdnn::cuda::cutlass_wrapper:: | void megdnn::cuda::cutlass_wrapper:: | ||||
| @@ -172,15 +179,15 @@ void megdnn::cuda::cutlass_wrapper:: | |||||
| int* workspace, const convolution::ConvParam& param, | int* workspace, const convolution::ConvParam& param, | ||||
| uint32_t nonlinear_mode, float alpha, float beta, float gamma, | uint32_t nonlinear_mode, float alpha, float beta, float gamma, | ||||
| float scale, const GemmCoord& threadblock_shape, | float scale, const GemmCoord& threadblock_shape, | ||||
| const GemmCoord& warp_shape, cudaStream_t stream) { | |||||
| const GemmCoord& warp_shape, int stages, cudaStream_t stream) { | |||||
| #define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | #define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | ||||
| threadblock_k_, warp_m_, warp_n_, \ | threadblock_k_, warp_m_, warp_n_, \ | ||||
| warp_k_) \ | |||||
| warp_k_, stage_) \ | |||||
| if (threadblock_shape.m() == threadblock_m_ && \ | if (threadblock_shape.m() == threadblock_m_ && \ | ||||
| threadblock_shape.n() == threadblock_n_ && \ | threadblock_shape.n() == threadblock_n_ && \ | ||||
| threadblock_shape.k() == threadblock_k_ && \ | threadblock_shape.k() == threadblock_k_ && \ | ||||
| warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | ||||
| warp_shape.k() == warp_k_) { \ | |||||
| warp_shape.k() == warp_k_ && stages == stage_) { \ | |||||
| using ThreadBlockShape = \ | using ThreadBlockShape = \ | ||||
| cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | ||||
| threadblock_k_>; \ | threadblock_k_>; \ | ||||
| @@ -196,7 +203,7 @@ void megdnn::cuda::cutlass_wrapper:: | |||||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | ||||
| cutlass::conv::threadblock:: \ | cutlass::conv::threadblock:: \ | ||||
| ConvolutionFpropNCxHWxThreadblockSwizzle, \ | ConvolutionFpropNCxHWxThreadblockSwizzle, \ | ||||
| 2, 16, 16, NeedLoadFromConstMem>; \ | |||||
| stage_, 16, 16, NeedLoadFromConstMem>; \ | |||||
| typename Convolution::ConvolutionParameter conv_param( \ | typename Convolution::ConvolutionParameter conv_param( \ | ||||
| param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | ||||
| param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | ||||
| @@ -206,13 +213,15 @@ void megdnn::cuda::cutlass_wrapper:: | |||||
| epilogue, stream); \ | epilogue, stream); \ | ||||
| } | } | ||||
| #define DISPATCH_KERNEL \ | #define DISPATCH_KERNEL \ | ||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 64, 64, 64, 64); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 64, 64, 64, 64); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 64, 64, 64, 64); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 64, 32, 64, 64); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 32, 64); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 64, 32, 32, 64); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 64, 16, 32, 64); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 64, 64, 64, 64, 2); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 64, 64, 64, 64, 2); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 64, 64, 64, 64, 2); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 32, 64, 2); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 64, 32, 64, 64, 2); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 1); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 1); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 32, 64, 32, 1); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 1); \ | |||||
| megdnn_assert(false, \ | megdnn_assert(false, \ | ||||
| "unsupported threadblock shape (%dx%dx%d) and warp shape " \ | "unsupported threadblock shape (%dx%dx%d) and warp shape " \ | ||||
| "(%dx%dx%d)", \ | "(%dx%dx%d)", \ | ||||
| @@ -268,7 +277,8 @@ void megdnn::cuda::cutlass_wrapper:: | |||||
| uint32_t nonlinear_mode, float alpha, float beta, \ | uint32_t nonlinear_mode, float alpha, float beta, \ | ||||
| float gamma, float scale, \ | float gamma, float scale, \ | ||||
| const GemmCoord& threadblock_shape, \ | const GemmCoord& threadblock_shape, \ | ||||
| const GemmCoord& warp_shape, cudaStream_t stream); | |||||
| const GemmCoord& warp_shape, int stages, \ | |||||
| cudaStream_t stream); | |||||
| INST(true); | INST(true); | ||||
| INST(false); | INST(false); | ||||
| #undef INST | #undef INST | ||||
| @@ -337,10 +347,8 @@ void megdnn::cuda::cutlass_wrapper:: | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \ | DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \ | ||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \ | DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \ | ||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \ | DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \ | ||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 32, 64, 32, 32, 2, 16); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \ | DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \ | ||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \ | DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \ | ||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 32, 32, 32, 32, 32, 2, 16); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 128, 16, 1, 8); \ | DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 128, 16, 1, 8); \ | ||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 2, 4); \ | DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 2, 4); \ | ||||
| megdnn_assert(false, \ | megdnn_assert(false, \ | ||||
| @@ -468,10 +476,8 @@ void megdnn::cuda::cutlass_wrapper:: | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \ | DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \ | ||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \ | DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \ | ||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \ | DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \ | ||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 32, 64, 32, 32, 2, 16); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \ | DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \ | ||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \ | DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \ | ||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 32, 32, 32, 32, 32, 2, 16); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 128, 16, 1, 8); \ | DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 128, 16, 1, 8); \ | ||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 2, 4); \ | DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 2, 4); \ | ||||
| megdnn_assert(false, \ | megdnn_assert(false, \ | ||||
| @@ -599,10 +605,8 @@ void megdnn::cuda::cutlass_wrapper:: | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \ | DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \ | ||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \ | DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \ | ||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \ | DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \ | ||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 32, 64, 32, 32, 2, 16); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \ | DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \ | ||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \ | DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \ | ||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 32, 32, 32, 32, 32, 2, 16); \ | |||||
| megdnn_assert(false, \ | megdnn_assert(false, \ | ||||
| "unsupported threadblock shape (%dx%dx%d) and warp shape " \ | "unsupported threadblock shape (%dx%dx%d) and warp shape " \ | ||||
| "(%dx%dx%d)", \ | "(%dx%dx%d)", \ | ||||
| @@ -664,246 +668,6 @@ INST(true); | |||||
| INST(false); | INST(false); | ||||
| #undef INST | #undef INST | ||||
| /* ====== cutlass kernel wrapper for int4 x int4 nchw64 layout ====== */ | |||||
| #if MEGDNN_TEGRA_X1 | |||||
| template <bool NeedLoadFromConstMem> | |||||
| void megdnn::cuda::cutlass_wrapper:: | |||||
| do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64( | |||||
| const int8_t* /* d_src */, const int8_t* /* d_filter */, | |||||
| const int32_t* /* d_bias */, const int8_t* /* d_z */, | |||||
| int8_t* /* d_dst */, int* /* workspace */, | |||||
| const convolution::ConvParam& /* param */, | |||||
| uint32_t /* nonlinear_mode */, float /* alpha */, | |||||
| float /* beta */, float /* gamma */, float /* scale */, | |||||
| const GemmCoord& /* threadblock_shape */, | |||||
| const GemmCoord& /* warp_shape */, cudaStream_t /* stream */) {} | |||||
| #else | |||||
| template <bool NeedLoadFromConstMem> | |||||
| void megdnn::cuda::cutlass_wrapper:: | |||||
| do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64( | |||||
| const int8_t* d_src, const int8_t* d_filter, | |||||
| const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, | |||||
| int* workspace, const convolution::ConvParam& param, | |||||
| uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||||
| float scale, const GemmCoord& threadblock_shape, | |||||
| const GemmCoord& warp_shape, cudaStream_t stream) { | |||||
| #define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||||
| threadblock_k_, warp_m_, warp_n_, \ | |||||
| warp_k_) \ | |||||
| if (threadblock_shape.m() == threadblock_m_ && \ | |||||
| threadblock_shape.n() == threadblock_n_ && \ | |||||
| threadblock_shape.k() == threadblock_k_ && \ | |||||
| warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||||
| warp_shape.k() == warp_k_) { \ | |||||
| using ThreadBlockShape = \ | |||||
| cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||||
| threadblock_k_>; \ | |||||
| using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \ | |||||
| using Convolution = cutlass::conv::device::Convolution< \ | |||||
| cutlass::int4b_t, cutlass::layout::TensorNCxHWx<64>, \ | |||||
| cutlass::int4b_t, cutlass::layout::TensorCxRSKx<64>, \ | |||||
| ElementOutput, cutlass::layout::TensorNCxHWx<64>, int32_t, \ | |||||
| cutlass::layout::TensorNCxHWx<64>, int32_t, \ | |||||
| cutlass::conv::ConvType::kConvolution, \ | |||||
| cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | |||||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||||
| cutlass::conv::threadblock:: \ | |||||
| ConvolutionFpropNCxHWxThreadblockSwizzle, \ | |||||
| 2, 32, 32, NeedLoadFromConstMem>; \ | |||||
| typename Convolution::ConvolutionParameter conv_param( \ | |||||
| param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||||
| param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||||
| param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||||
| return cutlass_convolution_wrapper<Convolution>( \ | |||||
| reinterpret_cast<const cutlass::int4b_t*>(d_src), \ | |||||
| reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \ | |||||
| reinterpret_cast<const cutlass::int4b_t*>(d_z), \ | |||||
| reinterpret_cast<cutlass::int4b_t*>(d_dst), workspace, \ | |||||
| conv_param, epilogue, stream); \ | |||||
| } | |||||
| #define DISPATCH_KERNEL \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 128, 64, 64, 128); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 128, 64, 64, 128); \ | |||||
| megdnn_assert(false, \ | |||||
| "unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||||
| "(%dx%dx%d)", \ | |||||
| threadblock_shape.m(), threadblock_shape.n(), \ | |||||
| threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||||
| warp_shape.k()); | |||||
| using ElementOutput = cutlass::int4b_t; | |||||
| using ElementAccumulator = int32_t; | |||||
| using ElementBias = int32_t; | |||||
| using ElementCompute = float; | |||||
| using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; | |||||
| switch (nonlinear_mode) { | |||||
| case NonlineMode::IDENTITY: { | |||||
| using EpilogueOp = | |||||
| cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||||
| ElementOutput, 16, ElementAccumulator, ElementBias, | |||||
| ElementCompute>; | |||||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma}; | |||||
| DISPATCH_KERNEL; | |||||
| } | |||||
| case NonlineMode::RELU: { | |||||
| using EpilogueOp = cutlass::epilogue::thread:: | |||||
| BiasAddLinearCombinationReluClamp< | |||||
| ElementOutput, 16, ElementAccumulator, ElementBias, | |||||
| ElementCompute>; | |||||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0}; | |||||
| DISPATCH_KERNEL; | |||||
| } | |||||
| case NonlineMode::H_SWISH: { | |||||
| using EpilogueOp = cutlass::epilogue::thread:: | |||||
| BiasAddLinearCombinationHSwishClamp< | |||||
| ElementOutput, 16, ElementAccumulator, ElementBias, | |||||
| ElementCompute>; | |||||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale}; | |||||
| DISPATCH_KERNEL; | |||||
| } | |||||
| default: | |||||
| megdnn_assert(false, | |||||
| "unsupported nonlinear mode for conv bias operator"); | |||||
| } | |||||
| #undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||||
| #undef DISPATCH_KERNEL | |||||
| } | |||||
| #endif | |||||
| #define INST(need_load_from_const_mem) \ | |||||
| template void megdnn::cuda::cutlass_wrapper:: \ | |||||
| do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64< \ | |||||
| need_load_from_const_mem>( \ | |||||
| const int8_t* d_src, const int8_t* d_filter, \ | |||||
| const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \ | |||||
| int* workspace, const convolution::ConvParam& param, \ | |||||
| uint32_t nonlinear_mode, float alpha, float beta, \ | |||||
| float gamma, float scale, \ | |||||
| const GemmCoord& threadblock_shape, \ | |||||
| const GemmCoord& warp_shape, cudaStream_t stream); | |||||
| INST(true); | |||||
| #undef INST | |||||
| /* ====== cutlass kernel wrapper for uint4 x int4 nchw64 layout ====== */ | |||||
| #if MEGDNN_TEGRA_X1 | |||||
| template <bool NeedLoadFromConstMem> | |||||
| void megdnn::cuda::cutlass_wrapper:: | |||||
| do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64( | |||||
| const uint8_t* /* d_src */, const int8_t* /* d_filter */, | |||||
| const int32_t* /* d_bias */, const uint8_t* /* d_z */, | |||||
| uint8_t* /* d_dst */, int* /* workspace */, | |||||
| const convolution::ConvParam& /* param */, | |||||
| uint32_t /* nonlinear_mode */, float /* alpha */, | |||||
| float /* beta */, float /* gamma */, float /* delta */, | |||||
| float /* theta */, float /* scale */, | |||||
| uint8_t /* src_zero_point */, | |||||
| const GemmCoord& /* threadblock_shape */, | |||||
| const GemmCoord& /* warp_shape */, cudaStream_t /* stream */) {} | |||||
| #else | |||||
| template <bool NeedLoadFromConstMem> | |||||
| void megdnn::cuda::cutlass_wrapper:: | |||||
| do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64( | |||||
| const uint8_t* d_src, const int8_t* d_filter, | |||||
| const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, | |||||
| int* workspace, const convolution::ConvParam& param, | |||||
| uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||||
| float delta, float theta, float /* scale */, | |||||
| uint8_t src_zero_point, const GemmCoord& threadblock_shape, | |||||
| const GemmCoord& warp_shape, cudaStream_t stream) { | |||||
| #define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||||
| threadblock_k_, warp_m_, warp_n_, \ | |||||
| warp_k_) \ | |||||
| if (threadblock_shape.m() == threadblock_m_ && \ | |||||
| threadblock_shape.n() == threadblock_n_ && \ | |||||
| threadblock_shape.k() == threadblock_k_ && \ | |||||
| warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||||
| warp_shape.k() == warp_k_) { \ | |||||
| using ThreadBlockShape = \ | |||||
| cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||||
| threadblock_k_>; \ | |||||
| using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \ | |||||
| using Convolution = cutlass::conv::device::Convolution< \ | |||||
| cutlass::uint4b_t, cutlass::layout::TensorNCxHWx<64>, \ | |||||
| cutlass::int4b_t, cutlass::layout::TensorCxRSKx<64>, \ | |||||
| ElementOutput, cutlass::layout::TensorNCxHWx<64>, int32_t, \ | |||||
| cutlass::layout::TensorNCxHWx<64>, int32_t, \ | |||||
| cutlass::conv::ConvType::kConvolution, \ | |||||
| cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | |||||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||||
| cutlass::conv::threadblock:: \ | |||||
| ConvolutionFpropNCxHWxThreadblockSwizzle, \ | |||||
| 2, 32, 32, NeedLoadFromConstMem>; \ | |||||
| typename Convolution::ConvolutionParameter conv_param( \ | |||||
| param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||||
| param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||||
| param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||||
| return cutlass_convolution_wrapper<Convolution>( \ | |||||
| reinterpret_cast<const cutlass::uint4b_t*>(d_src), \ | |||||
| reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \ | |||||
| reinterpret_cast<const cutlass::uint4b_t*>(d_z), \ | |||||
| reinterpret_cast<cutlass::uint4b_t*>(d_dst), workspace, \ | |||||
| conv_param, epilogue, stream, {src_zero_point}); \ | |||||
| } | |||||
| #define DISPATCH_KERNEL \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 128, 64, 64, 128); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 128, 64, 64, 128); \ | |||||
| megdnn_assert(false, \ | |||||
| "unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||||
| "(%dx%dx%d)", \ | |||||
| threadblock_shape.m(), threadblock_shape.n(), \ | |||||
| threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||||
| warp_shape.k()); | |||||
| using ElementOutput = cutlass::uint4b_t; | |||||
| using ElementAccumulator = int32_t; | |||||
| using ElementBias = int32_t; | |||||
| using ElementCompute = float; | |||||
| using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; | |||||
| switch (nonlinear_mode) { | |||||
| case NonlineMode::IDENTITY: { | |||||
| using EpilogueOp = | |||||
| cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||||
| ElementOutput, 16, ElementAccumulator, ElementBias, | |||||
| ElementCompute>; | |||||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, | |||||
| delta + theta}; | |||||
| DISPATCH_KERNEL; | |||||
| } | |||||
| case NonlineMode::RELU: { | |||||
| using EpilogueOp = cutlass::epilogue::thread:: | |||||
| BiasAddLinearCombinationReluClamp< | |||||
| ElementOutput, 16, ElementAccumulator, ElementBias, | |||||
| ElementCompute>; | |||||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, | |||||
| 0, delta, theta}; | |||||
| DISPATCH_KERNEL; | |||||
| } | |||||
| default: | |||||
| megdnn_assert(false, | |||||
| "unsupported nonlinear mode for conv bias operator"); | |||||
| } | |||||
| #undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||||
| #undef DISPATCH_KERNEL | |||||
| } | |||||
| #endif | |||||
| #define INST(need_load_from_const_mem) \ | |||||
| template void megdnn::cuda::cutlass_wrapper:: \ | |||||
| do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64< \ | |||||
| need_load_from_const_mem>( \ | |||||
| const uint8_t* d_src, const int8_t* d_filter, \ | |||||
| const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, \ | |||||
| int* workspace, const convolution::ConvParam& param, \ | |||||
| uint32_t nonlinear_mode, float alpha, float beta, \ | |||||
| float gamma, float delta, float theta, float scale, \ | |||||
| uint8_t src_zero_point, \ | |||||
| const GemmCoord& threadblock_shape, \ | |||||
| const GemmCoord& warp_shape, cudaStream_t stream); | |||||
| INST(true); | |||||
| #undef INST | |||||
| /* ===== cutlass kernel wrapper for nchw4 layout and nhwc output ===== */ | /* ===== cutlass kernel wrapper for nchw4 layout and nhwc output ===== */ | ||||
| #if MEGDNN_TEGRA_X1 | #if MEGDNN_TEGRA_X1 | ||||
| template <bool signedness> | template <bool signedness> | ||||
| @@ -970,10 +734,8 @@ void megdnn::cuda::cutlass_wrapper:: | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \ | DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \ | ||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \ | DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \ | ||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \ | DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \ | ||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 32, 64, 32, 32, 2, 16); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \ | DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \ | ||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \ | DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \ | ||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 32, 32, 32, 32, 32, 2, 16); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 128, 16, 1, 8); \ | DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 128, 16, 1, 8); \ | ||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 2, 4); \ | DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 2, 4); \ | ||||
| megdnn_assert(false, \ | megdnn_assert(false, \ | ||||
| @@ -1039,262 +801,4 @@ INST(true); | |||||
| INST(false); | INST(false); | ||||
| #undef INST | #undef INST | ||||
| /* ====== cutlass kernel wrapper for int4 x int4 nchw64 layout ====== */ | |||||
| #if MEGDNN_TEGRA_X1 | |||||
| template <bool NeedLoadFromConstMem> | |||||
| void megdnn::cuda::cutlass_wrapper:: | |||||
| do_conv_bias_int4_int4_implicit_gemm_imma_nhwc( | |||||
| const int8_t* /* d_src */, const int8_t* /* d_filter */, | |||||
| const int32_t* /* d_bias */, const int8_t* /* d_z */, | |||||
| int8_t* /* d_dst */, int* /* workspace */, | |||||
| const convolution::ConvParam& /* param */, | |||||
| uint32_t /* nonlinear_mode */, float /* alpha */, | |||||
| float /* beta */, float /* gamma */, float /* scale */, | |||||
| const GemmCoord& /* threadblock_shape */, | |||||
| const GemmCoord& /* warp_shape */, | |||||
| const int32_t /* access_size */, cudaStream_t /* stream */) {} | |||||
| #else | |||||
| template <bool NeedLoadFromConstMem> | |||||
| void megdnn::cuda::cutlass_wrapper:: | |||||
| do_conv_bias_int4_int4_implicit_gemm_imma_nhwc( | |||||
| const int8_t* d_src, const int8_t* d_filter, | |||||
| const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, | |||||
| int* workspace, const convolution::ConvParam& param, | |||||
| uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||||
| float scale, const GemmCoord& threadblock_shape, | |||||
| const GemmCoord& warp_shape, const int32_t access_size, | |||||
| cudaStream_t stream) { | |||||
| #define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||||
| threadblock_k_, warp_m_, warp_n_, \ | |||||
| warp_k_, access_size_) \ | |||||
| if (threadblock_shape.m() == threadblock_m_ && \ | |||||
| threadblock_shape.n() == threadblock_n_ && \ | |||||
| threadblock_shape.k() == threadblock_k_ && \ | |||||
| warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||||
| warp_shape.k() == warp_k_ && access_size == access_size_) { \ | |||||
| using ThreadBlockShape = \ | |||||
| cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||||
| threadblock_k_>; \ | |||||
| using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \ | |||||
| using Convolution = cutlass::conv::device::Convolution< \ | |||||
| cutlass::int4b_t, cutlass::layout::TensorNHWC, \ | |||||
| cutlass::int4b_t, cutlass::layout::TensorNCxHWx<access_size_>, \ | |||||
| ElementOutput, cutlass::layout::TensorNHWC, int32_t, \ | |||||
| cutlass::layout::TensorNHWC, int32_t, \ | |||||
| cutlass::conv::ConvType::kConvolution, \ | |||||
| cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | |||||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||||
| cutlass::conv::threadblock:: \ | |||||
| ConvolutionFpropNHWCThreadblockSwizzle, \ | |||||
| 2, access_size_, access_size_, NeedLoadFromConstMem, \ | |||||
| cutlass::arch::OpMultiplyAddSaturate, \ | |||||
| cutlass::conv::ImplicitGemmMode::GEMM_TN>; \ | |||||
| typename Convolution::ConvolutionParameter conv_param( \ | |||||
| param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||||
| param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||||
| param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||||
| return cutlass_convolution_wrapper<Convolution>( \ | |||||
| reinterpret_cast<const cutlass::int4b_t*>(d_src), \ | |||||
| reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \ | |||||
| reinterpret_cast<const cutlass::int4b_t*>(d_z), \ | |||||
| reinterpret_cast<cutlass::int4b_t*>(d_dst), workspace, \ | |||||
| conv_param, epilogue, stream); \ | |||||
| } | |||||
| #define DISPATCH_KERNEL \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 32); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 16); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 8); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 32); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 16); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 8); \ | |||||
| megdnn_assert(false, \ | |||||
| "unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||||
| "(%dx%dx%d) and access_size (%d)", \ | |||||
| threadblock_shape.m(), threadblock_shape.n(), \ | |||||
| threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||||
| warp_shape.k(), access_size); | |||||
| using ElementOutput = cutlass::int4b_t; | |||||
| using ElementAccumulator = int32_t; | |||||
| using ElementBias = int32_t; | |||||
| using ElementCompute = float; | |||||
| using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; | |||||
| switch (nonlinear_mode) { | |||||
| case NonlineMode::IDENTITY: { | |||||
| using EpilogueOp = | |||||
| cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||||
| ElementOutput, 8, ElementAccumulator, ElementBias, | |||||
| ElementCompute>; | |||||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma}; | |||||
| DISPATCH_KERNEL; | |||||
| } | |||||
| case NonlineMode::RELU: { | |||||
| using EpilogueOp = cutlass::epilogue::thread:: | |||||
| BiasAddLinearCombinationReluClamp< | |||||
| ElementOutput, 8, ElementAccumulator, ElementBias, | |||||
| ElementCompute>; | |||||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0}; | |||||
| DISPATCH_KERNEL; | |||||
| } | |||||
| case NonlineMode::H_SWISH: { | |||||
| using EpilogueOp = cutlass::epilogue::thread:: | |||||
| BiasAddLinearCombinationHSwishClamp< | |||||
| ElementOutput, 8, ElementAccumulator, ElementBias, | |||||
| ElementCompute>; | |||||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale}; | |||||
| DISPATCH_KERNEL; | |||||
| } | |||||
| default: | |||||
| megdnn_assert(false, | |||||
| "unsupported nonlinear mode for conv bias operator"); | |||||
| } | |||||
| #undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||||
| #undef DISPATCH_KERNEL | |||||
| } | |||||
| #endif | |||||
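The DISPATCH_KERNEL_WITH_TILE_SHAPE / DISPATCH_KERNEL pair above implements a runtime-to-compile-time shape dispatch: each candidate tile expands to an if-return, so control reaches the trailing megdnn_assert only when no tile matched. A stripped-down, self-contained sketch of the same pattern (TileShape, launch_conv and the shapes here are illustrative, not megdnn API):

    #include <cstdio>

    template <int M, int N, int K>
    struct TileShape {
        static constexpr int m = M, n = N, k = K;
    };

    template <typename Shape>
    void launch_conv() {
        std::printf("launch %dx%dx%d kernel\n", Shape::m, Shape::n, Shape::k);
    }

    #define DISPATCH_TILE(m_, n_, k_)                    \
        if (m == m_ && n == n_ && k == k_) {             \
            return launch_conv<TileShape<m_, n_, k_>>(); \
        }

    void dispatch(int m, int n, int k) {
        DISPATCH_TILE(128, 32, 64);  // first match returns immediately
        DISPATCH_TILE(128, 64, 64);
        std::printf("unsupported tile %dx%dx%d\n", m, n, k);  // fell through
    }
    #undef DISPATCH_TILE

    int main() { dispatch(128, 32, 64); return 0; }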
| #define INST(need_load_from_const_mem) \ | |||||
| template void megdnn::cuda::cutlass_wrapper:: \ | |||||
| do_conv_bias_int4_int4_implicit_gemm_imma_nhwc< \ | |||||
| need_load_from_const_mem>( \ | |||||
| const int8_t* d_src, const int8_t* d_filter, \ | |||||
| const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \ | |||||
| int* workspace, const convolution::ConvParam& param, \ | |||||
| uint32_t nonlinear_mode, float alpha, float beta, \ | |||||
| float gamma, float scale, \ | |||||
| const GemmCoord& threadblock_shape, \ | |||||
| const GemmCoord& warp_shape, const int32_t access_size, \ | |||||
| cudaStream_t stream); | |||||
| INST(true); | |||||
| INST(false); | |||||
| #undef INST | |||||
| #if MEGDNN_TEGRA_X1 | |||||
| template <bool NeedLoadFromConstMem> | |||||
| void megdnn::cuda::cutlass_wrapper:: | |||||
| do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc( | |||||
| const uint8_t* /* d_src */, const int8_t* /* d_filter */, | |||||
| const int32_t* /* d_bias */, const uint8_t* /* d_z */, | |||||
| uint8_t* /* d_dst */, int* /* workspace */, | |||||
| const convolution::ConvParam& /* param */, | |||||
| uint32_t /* nonlinear_mode */, float /* alpha */, | |||||
| float /* beta */, float /* gamma */, float /* delta */, | |||||
| float /* theta */, float /* scale */, | |||||
| uint8_t /* src_zero_point */, | |||||
| const GemmCoord& /* threadblock_shape */, | |||||
| const GemmCoord& /* warp_shape */, | |||||
| const int32_t /* access_size */, cudaStream_t /* stream */) {} | |||||
| #else | |||||
| template <bool NeedLoadFromConstMem> | |||||
| void megdnn::cuda::cutlass_wrapper:: | |||||
| do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc( | |||||
| const uint8_t* d_src, const int8_t* d_filter, | |||||
| const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, | |||||
| int* workspace, const convolution::ConvParam& param, | |||||
| uint32_t nonlinear_mode, float alpha, float beta, float gamma, | |||||
| float delta, float theta, float /* scale */, | |||||
| uint8_t src_zero_point, const GemmCoord& threadblock_shape, | |||||
| const GemmCoord& warp_shape, const int32_t access_size, | |||||
| cudaStream_t stream) { | |||||
| #define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ | |||||
| threadblock_k_, warp_m_, warp_n_, \ | |||||
| warp_k_, access_size_) \ | |||||
| if (threadblock_shape.m() == threadblock_m_ && \ | |||||
| threadblock_shape.n() == threadblock_n_ && \ | |||||
| threadblock_shape.k() == threadblock_k_ && \ | |||||
| warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ | |||||
| warp_shape.k() == warp_k_ && access_size == access_size_) { \ | |||||
| using ThreadBlockShape = \ | |||||
| cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \ | |||||
| threadblock_k_>; \ | |||||
| using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \ | |||||
| using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \ | |||||
| using Convolution = cutlass::conv::device::Convolution< \ | |||||
| cutlass::uint4b_t, cutlass::layout::TensorNHWC, \ | |||||
| cutlass::int4b_t, cutlass::layout::TensorNCxHWx<access_size_>, \ | |||||
| ElementOutput, cutlass::layout::TensorNHWC, int32_t, \ | |||||
| cutlass::layout::TensorNHWC, int32_t, \ | |||||
| cutlass::conv::ConvType::kConvolution, \ | |||||
| cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ | |||||
| ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ | |||||
| cutlass::conv::threadblock:: \ | |||||
| ConvolutionFpropNHWCThreadblockSwizzle, \ | |||||
| 2, access_size_, access_size_, NeedLoadFromConstMem, \ | |||||
| cutlass::arch::OpMultiplyAddSaturate, \ | |||||
| cutlass::conv::ImplicitGemmMode::GEMM_TN>; \ | |||||
| typename Convolution::ConvolutionParameter conv_param( \ | |||||
| param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ | |||||
| param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ | |||||
| param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ | |||||
| return cutlass_convolution_wrapper<Convolution>( \ | |||||
| reinterpret_cast<const cutlass::uint4b_t*>(d_src), \ | |||||
| reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \ | |||||
| reinterpret_cast<const cutlass::uint4b_t*>(d_z), \ | |||||
| reinterpret_cast<cutlass::uint4b_t*>(d_dst), workspace, \ | |||||
| conv_param, epilogue, stream, {src_zero_point}); \ | |||||
| } | |||||
| #define DISPATCH_KERNEL \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 32); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 16); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 8); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 32); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 16); \ | |||||
| DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 8); \ | |||||
| megdnn_assert(false, \ | |||||
| "unsupported threadblock shape (%dx%dx%d) and warp shape " \ | |||||
| "(%dx%dx%d) and access_size (%d)", \ | |||||
| threadblock_shape.m(), threadblock_shape.n(), \ | |||||
| threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | |||||
| warp_shape.k(), access_size); | |||||
| using ElementOutput = cutlass::uint4b_t; | |||||
| using ElementAccumulator = int32_t; | |||||
| using ElementBias = int32_t; | |||||
| using ElementCompute = float; | |||||
| using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; | |||||
| switch (nonlinear_mode) { | |||||
| case NonlineMode::IDENTITY: { | |||||
| using EpilogueOp = | |||||
| cutlass::epilogue::thread::BiasAddLinearCombinationClamp< | |||||
| ElementOutput, 8, ElementAccumulator, ElementBias, | |||||
| ElementCompute>; | |||||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, | |||||
| delta + theta}; | |||||
| DISPATCH_KERNEL; | |||||
| } | |||||
| case NonlineMode::RELU: { | |||||
| using EpilogueOp = cutlass::epilogue::thread:: | |||||
| BiasAddLinearCombinationReluClamp< | |||||
| ElementOutput, 8, ElementAccumulator, ElementBias, | |||||
| ElementCompute>; | |||||
| typename EpilogueOp::Params epilogue{alpha, beta, gamma, | |||||
| 0, delta, theta}; | |||||
| DISPATCH_KERNEL; | |||||
| } | |||||
| default: | |||||
| megdnn_assert(false, | |||||
| "unsupported nonlinear mode for conv bias operator"); | |||||
| } | |||||
| #undef DISPATCH_KERNEL_WITH_TILE_SHAPE | |||||
| #undef DISPATCH_KERNEL | |||||
| } | |||||
| #endif | |||||
| #define INST(need_load_from_const_mem) \ | |||||
| template void megdnn::cuda::cutlass_wrapper:: \ | |||||
| do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc< \ | |||||
| need_load_from_const_mem>( \ | |||||
| const uint8_t* d_src, const int8_t* d_filter, \ | |||||
| const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, \ | |||||
| int* workspace, const convolution::ConvParam& param, \ | |||||
| uint32_t nonlinear_mode, float alpha, float beta, \ | |||||
| float gamma, float delta, float theta, float scale, \ | |||||
| uint8_t src_zero_point, \ | |||||
| const GemmCoord& threadblock_shape, \ | |||||
| const GemmCoord& warp_shape, const int32_t access_size, \ | |||||
| cudaStream_t stream); | |||||
| INST(true); | |||||
| INST(false); | |||||
| #undef INST | |||||
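Note: unlike the int4 wrapper above, the uint4 path forwards an extra {src_zero_point} argument into cutlass_convolution_wrapper, presumably so the kernel can compensate for the asymmetric (zero-point shifted) uint4 input quantization.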
| // vim: syntax=cuda.doxygen | // vim: syntax=cuda.doxygen | ||||
| @@ -0,0 +1,194 @@ | |||||
| /** | |||||
| * \file dnn/src/cuda/conv_bias/cutlass_reorder_filter.cu | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "src/cuda/conv_bias/cutlass_reorder_filter.cuh" | |||||
| #include "src/cuda/query_blocksize.cuh" | |||||
| #include "src/cuda/integer_subbyte_utils.cuh" | |||||
| using namespace megdnn; | |||||
| using namespace cuda; | |||||
| using namespace cutlass_wrapper; | |||||
| namespace { | |||||
| template <uint32_t size_bits, uint32_t interleaved> | |||||
| __device__ __forceinline__ void reorder_ncxhwx_imma_filter_func( | |||||
| int8_t* dst, const int8_t* src, uint32_t OC, uint32_t IC, uint32_t FH, | |||||
| uint32_t FW, uint32_t lane, bool trans_oc) { | |||||
| static constexpr uint32_t elements_per_lane = 128 / size_bits; | |||||
| static constexpr uint32_t threads_per_interleaved = | |||||
| interleaved / elements_per_lane; | |||||
| static constexpr uint32_t instruction_shape_col = 8; | |||||
| // 4 threads per Quad | |||||
| static constexpr uint32_t elements_per_thread = instruction_shape_col / 4; | |||||
| // 4 threads per Quad | |||||
| static constexpr uint32_t reordered_elements_per_thread = interleaved / 4; | |||||
| uint32_t id = lane / threads_per_interleaved; | |||||
| uint32_t residue = lane % threads_per_interleaved; | |||||
| uint32_t ICx = IC / interleaved; | |||||
| uint32_t row = id / (ICx * FH * FW); | |||||
| uint32_t col = id - row * ICx * FH * FW; | |||||
| // transpose ncxhwx to cxhwnx | |||||
| uint32_t src_offset = id * interleaved + residue * elements_per_lane; | |||||
| row = (trans_oc) ? (row / interleaved) * interleaved + | |||||
| ((row % reordered_elements_per_thread) / | |||||
| elements_per_thread) * | |||||
| instruction_shape_col + | |||||
| ((row % interleaved) / | |||||
| reordered_elements_per_thread) * | |||||
| elements_per_thread + | |||||
| (row % elements_per_thread) | |||||
| : row; | |||||
| uint32_t dst_offset = | |||||
| (col * OC + row) * interleaved + residue * elements_per_lane; | |||||
| *(reinterpret_cast<int4*>(dst + dst_offset * size_bits / 8)) = | |||||
| *(reinterpret_cast<const int4*>(src + src_offset * size_bits / 8)); | |||||
| } | |||||
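The row remapping above swizzles oc so that each mma Quad's operands land contiguously after the relayout. A small host-side check, under the int4 NCHW64 parameters (size_bits = 4, interleaved = 64, so elements_per_thread = 2 and reordered_elements_per_thread = 16), that the trans_oc remap is a permutation of one interleave group:

    #include <cassert>
    #include <set>

    int main() {
        constexpr unsigned interleaved = 64;  // int4 NCHW64 path
        constexpr unsigned instruction_shape_col = 8;
        constexpr unsigned elements_per_thread = instruction_shape_col / 4;  // 2
        constexpr unsigned reordered_per_thread = interleaved / 4;           // 16
        std::set<unsigned> seen;
        for (unsigned row = 0; row < interleaved; ++row) {
            // same expression as the trans_oc branch in
            // reorder_ncxhwx_imma_filter_func above
            unsigned mapped =
                    (row / interleaved) * interleaved +
                    ((row % reordered_per_thread) / elements_per_thread) *
                            instruction_shape_col +
                    ((row % interleaved) / reordered_per_thread) *
                            elements_per_thread +
                    (row % elements_per_thread);
            seen.insert(mapped);
        }
        assert(seen.size() == interleaved);  // bijective: every slot hit once
        return 0;
    }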
| template <uint32_t size_bits, uint32_t interleaved> | |||||
| __global__ void reorder_ncxhwx_imma_filter_kernel( | |||||
| int8_t* __restrict__ dst_filter, const int8_t* __restrict__ src_filter, | |||||
| uint32_t OC, uint32_t IC, uint32_t FH, uint32_t FW, bool trans_oc) { | |||||
| static constexpr uint32_t elements_per_lane = 128 / size_bits; | |||||
| const uint32_t size = OC * IC * FH * FW / elements_per_lane; | |||||
| uint32_t lane = threadIdx.x + blockIdx.x * blockDim.x; | |||||
| if (lane < size) { | |||||
| reorder_ncxhwx_imma_filter_func<size_bits, interleaved>( | |||||
| dst_filter, src_filter, OC, IC, FH, FW, lane, trans_oc); | |||||
| } | |||||
| } | |||||
| template <uint32_t size_bits, uint32_t alignbits, uint32_t interleaved> | |||||
| __device__ __forceinline__ void reorder_nhwc_imma_filter_func( | |||||
| int8_t* dst, const int8_t* src, uint32_t OC, uint32_t IC, uint32_t FH, | |||||
| uint32_t FW, uint32_t lane, bool trans_oc) { | |||||
| static constexpr uint32_t elements_per_access = alignbits / size_bits; | |||||
| static constexpr uint32_t instruction_shape_col = 8; | |||||
| // 4 threads per Quad | |||||
| static constexpr uint32_t elements_per_thread = instruction_shape_col / 4; | |||||
| // 4 threads per Quad | |||||
| static constexpr uint32_t reordered_elements_per_thread = interleaved / 4; | |||||
| uint32_t ICx = IC / elements_per_access; | |||||
| uint32_t k = lane / (ICx * FH * FW); | |||||
| uint32_t cxrs = lane - k * ICx * FH * FW; | |||||
| uint32_t rs = cxrs / ICx; | |||||
| uint32_t cx = cxrs - rs * ICx; | |||||
| // transpose nhwc to ncxhwx | |||||
| uint32_t src_offset = lane * elements_per_access; | |||||
| // reorder k | |||||
| k = (trans_oc) | |||||
| ? (k / interleaved) * interleaved + | |||||
| ((k % reordered_elements_per_thread) / | |||||
| elements_per_thread) * | |||||
| instruction_shape_col + | |||||
| ((k % interleaved) / reordered_elements_per_thread) * | |||||
| elements_per_thread + | |||||
| (k % elements_per_thread) | |||||
| : k; | |||||
| uint32_t dst_offset = | |||||
| (k * ICx * FH * FW + cx * FH * FW + rs) * elements_per_access; | |||||
| if (alignbits == 32) { | |||||
| *(reinterpret_cast<int*>(dst + dst_offset * size_bits / 8)) = *( | |||||
| reinterpret_cast<const int*>(src + src_offset * size_bits / 8)); | |||||
| } else if (alignbits == 64) { | |||||
| *(reinterpret_cast<int2*>(dst + dst_offset * size_bits / 8)) = | |||||
| *(reinterpret_cast<const int2*>(src + | |||||
| src_offset * size_bits / 8)); | |||||
| } else { | |||||
| *(reinterpret_cast<int4*>(dst + dst_offset * size_bits / 8)) = | |||||
| *(reinterpret_cast<const int4*>(src + | |||||
| src_offset * size_bits / 8)); | |||||
| } | |||||
| } | |||||
| template <uint32_t size_bits, uint32_t alignbits, uint32_t interleaved> | |||||
| __global__ void reorder_nhwc_imma_filter_kernel( | |||||
| int8_t* __restrict__ dst_filter, const int8_t* __restrict__ src_filter, | |||||
| uint32_t OC, uint32_t IC, uint32_t FH, uint32_t FW, bool trans_oc) { | |||||
| static constexpr uint32_t elements_per_access = alignbits / size_bits; | |||||
| const uint32_t size = OC * IC * FH * FW / elements_per_access; | |||||
| uint32_t lane = threadIdx.x + blockIdx.x * blockDim.x; | |||||
| if (lane < size) { | |||||
| reorder_nhwc_imma_filter_func<size_bits, alignbits, interleaved>( | |||||
| dst_filter, src_filter, OC, IC, FH, FW, lane, trans_oc); | |||||
| } | |||||
| } | |||||
| } // namespace | |||||
| template <uint32_t size_bits, uint32_t interleaved> | |||||
| void megdnn::cuda::cutlass_wrapper::reorder_ncxhwx_imma_filter( | |||||
| int8_t* dst_filter, const int8_t* src_filter, uint32_t OC, uint32_t IC, | |||||
| uint32_t FH, uint32_t FW, bool trans_oc, cudaStream_t stream) { | |||||
| static constexpr uint32_t elements_per_lane = 128 / size_bits; | |||||
| uint32_t nr_threads = | |||||
| query_blocksize_for_kernel(reinterpret_cast<const void*>( | |||||
| reorder_ncxhwx_imma_filter_kernel<size_bits, interleaved>)); | |||||
| uint32_t vthreads = DIVUP(OC * IC * FH * FW, elements_per_lane); | |||||
| nr_threads = std::min(nr_threads, vthreads); | |||||
| uint32_t nr_blocks = DIVUP(vthreads, nr_threads); | |||||
| reorder_ncxhwx_imma_filter_kernel<size_bits, interleaved> | |||||
| <<<nr_blocks, nr_threads, 0, stream>>>(dst_filter, src_filter, OC, | |||||
| IC, FH, FW, trans_oc); | |||||
| after_kernel_launch(); | |||||
| } | |||||
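Worked size example (illustrative figures): with size_bits = 4, elements_per_lane = 128 / 4 = 32, so a filter with OC = IC = 256 and FH = FW = 3 has 256 * 256 * 9 = 589824 4-bit elements and vthreads = 589824 / 32 = 18432 lanes, each moving one 16-byte int4 vector. If the block-size query returns, say, 256 threads, the launch uses 18432 / 256 = 72 blocks.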
| template <uint32_t size_bits, uint32_t alignbits> | |||||
| void megdnn::cuda::cutlass_wrapper::reorder_nhwc_imma_filter( | |||||
| int8_t* dst_filter, const int8_t* src_filter, uint32_t OC, uint32_t IC, | |||||
| uint32_t FH, uint32_t FW, bool trans_oc, uint32_t oc_interleaved, | |||||
| cudaStream_t stream) { | |||||
| static constexpr uint32_t elements_per_access = alignbits / size_bits; | |||||
| uint32_t nr_threads = | |||||
| query_blocksize_for_kernel(reinterpret_cast<const void*>( | |||||
| reorder_nhwc_imma_filter_kernel<size_bits, alignbits, 32>)); | |||||
| uint32_t vthreads = DIVUP(OC * IC * FH * FW, elements_per_access); | |||||
| nr_threads = std::min(nr_threads, vthreads); | |||||
| uint32_t nr_blocks = DIVUP(vthreads, nr_threads); | |||||
| if (oc_interleaved == 32) { | |||||
| reorder_nhwc_imma_filter_kernel<size_bits, alignbits, 32> | |||||
| <<<nr_blocks, nr_threads, 0, stream>>>( | |||||
| dst_filter, src_filter, OC, IC, FH, FW, trans_oc); | |||||
| } else { | |||||
| reorder_nhwc_imma_filter_kernel<size_bits, alignbits, 64> | |||||
| <<<nr_blocks, nr_threads, 0, stream>>>( | |||||
| dst_filter, src_filter, OC, IC, FH, FW, trans_oc); | |||||
| } | |||||
| after_kernel_launch(); | |||||
| } | |||||
| #define INST(_size_bits, _interleaved) \ | |||||
| template void megdnn::cuda::cutlass_wrapper::reorder_ncxhwx_imma_filter< \ | |||||
| _size_bits, _interleaved>(int8_t * dst_filter, \ | |||||
| const int8_t* src_filter, uint32_t OC, \ | |||||
| uint32_t IC, uint32_t FH, uint32_t FW, \ | |||||
| bool trans_oc, cudaStream_t stream); | |||||
| INST(8, 32) | |||||
| INST(4, 64) | |||||
| #undef INST | |||||
| #define INST(_size_bits, _alignbits) \ | |||||
| template void megdnn::cuda::cutlass_wrapper::reorder_nhwc_imma_filter< \ | |||||
| _size_bits, _alignbits>( \ | |||||
| int8_t * dst_filter, const int8_t* src_filter, uint32_t OC, \ | |||||
| uint32_t IC, uint32_t FH, uint32_t FW, bool trans_oc, \ | |||||
| uint32_t oc_interleaved, cudaStream_t stream); | |||||
| INST(4, 32) | |||||
| INST(4, 64) | |||||
| INST(4, 128) | |||||
| #undef INST | |||||
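These instantiations line up with the NHWC dispatch macros earlier in the file: elements_per_access = alignbits / size_bits gives 32 / 4 = 8, 64 / 4 = 16 and 128 / 4 = 32, exactly the access_size values (8, 16, 32) accepted by DISPATCH_KERNEL_WITH_TILE_SHAPE.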
| // vim: syntax=cuda.doxygen | |||||
| @@ -0,0 +1,33 @@ | |||||
| /** | |||||
| * \file dnn/src/cuda/conv_bias/cutlass_reorder_filter.cuh | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include "src/cuda/utils.cuh" | |||||
| namespace megdnn { | |||||
| namespace cuda { | |||||
| namespace cutlass_wrapper { | |||||
| template <uint32_t size_bits, uint32_t interleaved> | |||||
| void reorder_ncxhwx_imma_filter(int8_t* dst_filter, const int8_t* src_filter, | |||||
| uint32_t OC, uint32_t IC, uint32_t FH, | |||||
| uint32_t FW, bool trans_oc, | |||||
| cudaStream_t stream); | |||||
| template <uint32_t size_bits, uint32_t alignbits> | |||||
| void reorder_nhwc_imma_filter(int8_t* dst_filter, const int8_t* src_filter, | |||||
| uint32_t OC, uint32_t IC, uint32_t FH, | |||||
| uint32_t FW, bool trans_oc, | |||||
| uint32_t oc_interleaved, cudaStream_t stream); | |||||
| } // namespace cutlass_wrapper | |||||
| } // namespace cuda | |||||
| } // namespace megdnn | |||||
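A hypothetical host-side driver for the header above (buffer sizes are illustrative; this only builds inside the megdnn tree, and error checking is omitted):

    #include <cuda_runtime.h>
    #include "src/cuda/conv_bias/cutlass_reorder_filter.cuh"

    int main() {
        // 4-bit filter, OC = IC = 256, 3x3: 256 * 256 * 9 * 4 / 8 = 294912 bytes
        constexpr uint32_t OC = 256, IC = 256, FH = 3, FW = 3;
        const size_t bytes = size_t(OC) * IC * FH * FW * 4 / 8;
        int8_t *src = nullptr, *dst = nullptr;
        cudaMalloc(&src, bytes);
        cudaMalloc(&dst, bytes);
        // KCRS64 => CRSK64 with the oc swizzle enabled, on the default stream
        megdnn::cuda::cutlass_wrapper::reorder_ncxhwx_imma_filter<4, 64>(
                dst, src, OC, IC, FH, FW, /* trans_oc */ true,
                /* stream */ nullptr);
        cudaDeviceSynchronize();
        cudaFree(src);
        cudaFree(dst);
        return 0;
    }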
| @@ -102,7 +102,7 @@ void ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm::do_exec( | |||||
| reinterpret_cast<int8_t*>(z_ptr), | reinterpret_cast<int8_t*>(z_ptr), | ||||
| reinterpret_cast<int8_t*>(args.dst_tensor->raw_ptr), nullptr, | reinterpret_cast<int8_t*>(args.dst_tensor->raw_ptr), nullptr, | ||||
| kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale, | kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale, | ||||
| threadblock_shape, warp_shape, stream); | |||||
| threadblock_shape, warp_shape, m_algo_param.stage, stream); | |||||
| } | } | ||||
| #endif | #endif | ||||
| @@ -104,7 +104,7 @@ void ConvBiasForwardImpl::AlgoInt4Int4NHWCIMMAImplicitGemm::do_exec( | |||||
| reinterpret_cast<int8_t*>(args.dst_tensor->raw_ptr), nullptr, | reinterpret_cast<int8_t*>(args.dst_tensor->raw_ptr), nullptr, | ||||
| kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale, | kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale, | ||||
| threadblock_shape, warp_shape, m_algo_param.access_size, | threadblock_shape, warp_shape, m_algo_param.access_size, | ||||
| stream); | |||||
| m_algo_param.stage, stream); | |||||
| } else { | } else { | ||||
| cutlass_wrapper::do_conv_bias_int4_int4_implicit_gemm_imma_nhwc<true>( | cutlass_wrapper::do_conv_bias_int4_int4_implicit_gemm_imma_nhwc<true>( | ||||
| reinterpret_cast<int8_t*>(args.src_tensor->raw_ptr), | reinterpret_cast<int8_t*>(args.src_tensor->raw_ptr), | ||||
| @@ -114,7 +114,7 @@ void ConvBiasForwardImpl::AlgoInt4Int4NHWCIMMAImplicitGemm::do_exec( | |||||
| reinterpret_cast<int8_t*>(args.dst_tensor->raw_ptr), nullptr, | reinterpret_cast<int8_t*>(args.dst_tensor->raw_ptr), nullptr, | ||||
| kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale, | kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale, | ||||
| threadblock_shape, warp_shape, m_algo_param.access_size, | threadblock_shape, warp_shape, m_algo_param.access_size, | ||||
| stream); | |||||
| m_algo_param.stage, stream); | |||||
| } | } | ||||
| } | } | ||||
| #endif | #endif | ||||
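The new trailing argument threads m_algo_param.stage through the wrapper calls; it presumably selects the software-pipelining stage count of the mainloop (CUTLASS's Stages template parameter), which is also why the algorithm name strings below gain a stage suffix.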
| @@ -12,6 +12,7 @@ | |||||
| #include "./algo.h" | #include "./algo.h" | ||||
| #include "src/common/conv_bias.h" | #include "src/common/conv_bias.h" | ||||
| #include "src/cuda/conv_bias/cutlass_reorder_filter.cuh" | |||||
| #include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | #include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | ||||
| #include "src/cuda/conv_bias/reduce_filter.cuh" | #include "src/cuda/conv_bias/reduce_filter.cuh" | ||||
| #include "src/cuda/convolution_helper/parameter.cuh" | #include "src/cuda/convolution_helper/parameter.cuh" | ||||
| @@ -121,41 +122,26 @@ void ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase::exec( | |||||
| std::string ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase::to_string( | std::string ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase::to_string( | ||||
| AlgoParam algo_param) { | AlgoParam algo_param) { | ||||
| return ssprintf("%dX%dX%d_%dX%dX%d", algo_param.threadblock_m, | |||||
| return ssprintf("%dX%dX%d_%dX%dX%d_%d", algo_param.threadblock_m, | |||||
| algo_param.threadblock_n, algo_param.threadblock_k, | algo_param.threadblock_n, algo_param.threadblock_k, | ||||
| algo_param.warp_m, algo_param.warp_n, algo_param.warp_k); | |||||
| algo_param.warp_m, algo_param.warp_n, algo_param.warp_k, | |||||
| algo_param.stage); | |||||
| } | } | ||||
| void ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase::reorder_filter( | void ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase::reorder_filter( | ||||
| const ExecArgs& args, void* reordered_filter) const { | const ExecArgs& args, void* reordered_filter) const { | ||||
| auto&& param = args.opr->param(); | |||||
| size_t ci = args.src_layout->operator[](1) * 64; | |||||
| size_t co = args.dst_layout->operator[](1) * 64; | |||||
| auto&& fm = args.filter_meta; | auto&& fm = args.filter_meta; | ||||
| size_t n = args.src_layout->operator[](0), | |||||
| ci = args.src_layout->operator[](1) * 64, | |||||
| hi = args.src_layout->operator[](2), | |||||
| wi = args.src_layout->operator[](3); | |||||
| size_t co = args.dst_layout->operator[](1) * 64, | |||||
| ho = args.dst_layout->operator[](2), | |||||
| wo = args.dst_layout->operator[](3); | |||||
| UNPACK_CONV_PARAMETER(fm, param); | |||||
| MARK_USED_VAR; | |||||
| // filter: KCRS64 => CRSK64 | |||||
| TensorLayout src{{co, ci / 64, fh, fw, 64}, dtype::QuantizedS4()}; | |||||
| src.init_contiguous_stride(); | |||||
| TensorLayout dst = src; | |||||
| dst.stride[0] = 64; | |||||
| dst.stride[1] = co * fh * fw * 64; | |||||
| dst.stride[2] = co * fw * 64; | |||||
| dst.stride[3] = co * 64; | |||||
| dst.stride[4] = 1; | |||||
| TensorND ts_src, ts_dst; | |||||
| ts_src.raw_ptr = args.filter_tensor->raw_ptr; | |||||
| ts_src.layout = src; | |||||
| ts_dst.raw_ptr = reordered_filter; | |||||
| ts_dst.layout = dst; | |||||
| auto&& transpose = args.opr->handle()->create_operator<RelayoutForward>(); | |||||
| transpose->exec(ts_src, ts_dst); | |||||
| size_t fh = fm.spatial[0], fw = fm.spatial[1]; | |||||
| cudaStream_t stream = cuda_stream(args.opr->handle()); | |||||
| // filter: KCRS64 => CRSK64 and reorder oc | |||||
| cutlass_wrapper::reorder_ncxhwx_imma_filter<4, 64>( | |||||
| reinterpret_cast<int8_t*>(reordered_filter), | |||||
| reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), co, ci, fh, | |||||
| fw, true, stream); | |||||
| } | } | ||||
| #endif | #endif | ||||
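This hunk replaces the generic 5-D strided RelayoutForward transpose with the fused reorder_ncxhwx_imma_filter kernel, which performs the KCRS64 => CRSK64 relayout and the tensor-core oc swizzle in a single pass.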
| @@ -12,6 +12,7 @@ | |||||
| #include "./algo.h" | #include "./algo.h" | ||||
| #include "src/common/conv_bias.h" | #include "src/common/conv_bias.h" | ||||
| #include "src/cuda/conv_bias/cutlass_reorder_filter.cuh" | |||||
| #include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | #include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | ||||
| #include "src/cuda/conv_bias/reduce_filter.cuh" | #include "src/cuda/conv_bias/reduce_filter.cuh" | ||||
| #include "src/cuda/convolution_helper/parameter.cuh" | #include "src/cuda/convolution_helper/parameter.cuh" | ||||
| @@ -128,10 +129,10 @@ void ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::exec( | |||||
| std::string ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::to_string( | std::string ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::to_string( | ||||
| AlgoParam algo_param) { | AlgoParam algo_param) { | ||||
| return ssprintf("%dX%dX%d_%dX%dX%d_%d", algo_param.threadblock_m, | |||||
| return ssprintf("%dX%dX%d_%dX%dX%d_%d_%d", algo_param.threadblock_m, | |||||
| algo_param.threadblock_n, algo_param.threadblock_k, | algo_param.threadblock_n, algo_param.threadblock_k, | ||||
| algo_param.warp_m, algo_param.warp_n, algo_param.warp_k, | algo_param.warp_m, algo_param.warp_n, algo_param.warp_k, | ||||
| algo_param.access_size); | |||||
| algo_param.stage, algo_param.access_size); | |||||
| } | } | ||||
| void ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::reorder_filter( | void ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::reorder_filter( | ||||
| @@ -142,17 +143,32 @@ void ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::reorder_filter( | |||||
| fh = args.filter_layout->operator[](1), | fh = args.filter_layout->operator[](1), | ||||
| fw = args.filter_layout->operator[](2); | fw = args.filter_layout->operator[](2); | ||||
| // reformat grad from nhwc to ncxhwx | |||||
| TensorLayout exec_src{{co, fh, fw, ci / iterleaved, (size_t)iterleaved / 2}, | |||||
| dtype::Int8()}; | |||||
| TensorLayout exec_dst{{co, ci / iterleaved, fh, fw, (size_t)iterleaved / 2}, | |||||
| dtype::Int8()}; | |||||
| exec_src = exec_src.dimshuffle({0, 3, 1, 2, 4}); | |||||
| cudaStream_t stream = cuda_stream(args.opr->handle()); | |||||
| auto&& relayout = args.opr->handle()->create_operator<RelayoutForward>(); | |||||
| relayout->exec({args.filter_tensor->raw_ptr, exec_src}, | |||||
| {reordered_filter, exec_dst}); | |||||
| // reformat filter from nhwc to ncxhwx and reorder oc | |||||
| // to use trans_oc, threadblock_n must be 32 or 64 | |||||
| bool trans_oc = ((co % m_algo_param.threadblock_n == 0) && | |||||
| (m_algo_param.threadblock_n == 32 || | |||||
| m_algo_param.threadblock_n == 64)); | |||||
| uint32_t oc_iterleave = (m_algo_param.threadblock_n == 64) ? 64 : 32; | |||||
| if (iterleaved == 8) { | |||||
| cutlass_wrapper::reorder_nhwc_imma_filter<4, 32>( | |||||
| reinterpret_cast<int8_t*>(reordered_filter), | |||||
| reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), co, ci, | |||||
| fh, fw, trans_oc, oc_iterleave, stream); | |||||
| } else if (iterleaved == 16) { | |||||
| cutlass_wrapper::reorder_nhwc_imma_filter<4, 64>( | |||||
| reinterpret_cast<int8_t*>(reordered_filter), | |||||
| reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), co, ci, | |||||
| fh, fw, trans_oc, oc_iterleave, stream); | |||||
| } else { | |||||
| megdnn_assert(iterleaved == 32); | |||||
| cutlass_wrapper::reorder_nhwc_imma_filter<4, 128>( | |||||
| reinterpret_cast<int8_t*>(reordered_filter), | |||||
| reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), co, ci, | |||||
| fh, fw, trans_oc, oc_iterleave, stream); | |||||
| } | |||||
| } | } | ||||
| #endif | #endif | ||||
| @@ -11,6 +11,7 @@ | |||||
| */ | */ | ||||
| #include "./algo.h" | #include "./algo.h" | ||||
| #include "src/cuda/conv_bias/cutlass_reorder_filter.cuh" | |||||
| #include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | #include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | ||||
| #include "src/cuda/convolution_helper/parameter.cuh" | #include "src/cuda/convolution_helper/parameter.cuh" | ||||
| #include "src/cuda/utils.h" | #include "src/cuda/utils.h" | ||||
| @@ -110,11 +111,14 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec( | |||||
| size_t ho = args.dst_layout->operator[](2), | size_t ho = args.dst_layout->operator[](2), | ||||
| wo = args.dst_layout->operator[](3); | wo = args.dst_layout->operator[](3); | ||||
| size_t co; | size_t co; | ||||
| bool trans_oc; | |||||
| if (param.format == Format::NCHW32) { | if (param.format == Format::NCHW32) { | ||||
| co = args.dst_layout->operator[](1) * 32; | co = args.dst_layout->operator[](1) * 32; | ||||
| trans_oc = true; | |||||
| } else { | } else { | ||||
| megdnn_assert(param.format == Format::NCHW32_NCHW4); | megdnn_assert(param.format == Format::NCHW32_NCHW4); | ||||
| co = args.dst_layout->operator[](1) * 4; | co = args.dst_layout->operator[](1) * 4; | ||||
| trans_oc = false; | |||||
| } | } | ||||
| UNPACK_CONV_PARAMETER(fm, param); | UNPACK_CONV_PARAMETER(fm, param); | ||||
| MARK_USED_VAR | MARK_USED_VAR | ||||
| @@ -123,23 +127,11 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec( | |||||
| int8_t* filter_ptr = nullptr; | int8_t* filter_ptr = nullptr; | ||||
| if (args.preprocessed_filter == nullptr) { | if (args.preprocessed_filter == nullptr) { | ||||
| filter_ptr = reinterpret_cast<int8_t*>(args.workspace.raw_ptr); | filter_ptr = reinterpret_cast<int8_t*>(args.workspace.raw_ptr); | ||||
| // reformat filter from nchw32 to chwn32 | |||||
| TensorLayout src{{co, ci / 32, fh, fw, 32}, dtype::Int8()}; | |||||
| src.init_contiguous_stride(); | |||||
| TensorLayout dst = src; | |||||
| dst.stride[0] = 32; | |||||
| dst.stride[1] = co * fh * fw * 32; | |||||
| dst.stride[2] = co * fw * 32; | |||||
| dst.stride[3] = co * 32; | |||||
| dst.stride[4] = 1; | |||||
| TensorND ts_src, ts_dst; | |||||
| ts_src.raw_ptr = args.filter_tensor->raw_ptr; | |||||
| ts_src.layout = src; | |||||
| ts_dst.raw_ptr = args.workspace.raw_ptr; | |||||
| ts_dst.layout = dst; | |||||
| auto&& transpose = | |||||
| args.opr->handle()->create_operator<RelayoutForward>(); | |||||
| transpose->exec(ts_src, ts_dst); | |||||
| // filter: KCRS32 => CRSK32 and reorder oc | |||||
| cutlass_wrapper::reorder_ncxhwx_imma_filter<8, 32>( | |||||
| filter_ptr, | |||||
| reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), co, ci, | |||||
| fh, fw, trans_oc, stream); | |||||
| } else { | } else { | ||||
| filter_ptr = reinterpret_cast<int8_t*>( | filter_ptr = reinterpret_cast<int8_t*>( | ||||
| args.preprocessed_filter->tensors[0].raw_ptr); | args.preprocessed_filter->tensors[0].raw_ptr); | ||||
| @@ -182,7 +174,7 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec( | |||||
| cutlass_wrapper::GemmCoord{m_algo_param.warp_m, | cutlass_wrapper::GemmCoord{m_algo_param.warp_m, | ||||
| m_algo_param.warp_n, | m_algo_param.warp_n, | ||||
| m_algo_param.warp_k}, | m_algo_param.warp_k}, | ||||
| stream); | |||||
| m_algo_param.stage, stream); | |||||
| } else { | } else { | ||||
| megdnn_assert(param.format == Format::NCHW32_NCHW4); | megdnn_assert(param.format == Format::NCHW32_NCHW4); | ||||
| cutlass_wrapper:: | cutlass_wrapper:: | ||||
| @@ -202,7 +194,7 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec( | |||||
| cutlass_wrapper::GemmCoord{m_algo_param.warp_m, | cutlass_wrapper::GemmCoord{m_algo_param.warp_m, | ||||
| m_algo_param.warp_n, | m_algo_param.warp_n, | ||||
| m_algo_param.warp_k}, | m_algo_param.warp_k}, | ||||
| stream); | |||||
| m_algo_param.stage, stream); | |||||
| } | } | ||||
| } else { | } else { | ||||
| if (param.format == Format::NCHW32) { | if (param.format == Format::NCHW32) { | ||||
| @@ -218,7 +210,7 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec( | |||||
| cutlass_wrapper::GemmCoord{m_algo_param.warp_m, | cutlass_wrapper::GemmCoord{m_algo_param.warp_m, | ||||
| m_algo_param.warp_n, | m_algo_param.warp_n, | ||||
| m_algo_param.warp_k}, | m_algo_param.warp_k}, | ||||
| stream); | |||||
| m_algo_param.stage, stream); | |||||
| } else { | } else { | ||||
| megdnn_assert(param.format == Format::NCHW32_NCHW4); | megdnn_assert(param.format == Format::NCHW32_NCHW4); | ||||
| cutlass_wrapper:: | cutlass_wrapper:: | ||||
| @@ -238,7 +230,7 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec( | |||||
| cutlass_wrapper::GemmCoord{m_algo_param.warp_m, | cutlass_wrapper::GemmCoord{m_algo_param.warp_m, | ||||
| m_algo_param.warp_n, | m_algo_param.warp_n, | ||||
| m_algo_param.warp_k}, | m_algo_param.warp_k}, | ||||
| stream); | |||||
| m_algo_param.stage, stream); | |||||
| } | } | ||||
| } | } | ||||
| after_kernel_launch(); | after_kernel_launch(); | ||||
| @@ -246,9 +238,10 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec( | |||||
| std::string ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::to_string( | std::string ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::to_string( | ||||
| AlgoParam algo_param) { | AlgoParam algo_param) { | ||||
| return ssprintf("%uX%uX%u_%uX%uX%u", algo_param.threadblock_m, | |||||
| return ssprintf("%uX%uX%u_%uX%uX%u_%u", algo_param.threadblock_m, | |||||
| algo_param.threadblock_n, algo_param.threadblock_k, | algo_param.threadblock_n, algo_param.threadblock_k, | ||||
| algo_param.warp_m, algo_param.warp_n, algo_param.warp_k); | |||||
| algo_param.warp_m, algo_param.warp_n, algo_param.warp_k, | |||||
| algo_param.stage); | |||||
| } | } | ||||
| size_t ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm:: | size_t ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm:: | ||||
| @@ -267,36 +260,26 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec_preprocess( | |||||
| using Format = Param::Format; | using Format = Param::Format; | ||||
| auto&& param = args.opr->param(); | auto&& param = args.opr->param(); | ||||
| auto&& fm = args.filter_meta; | auto&& fm = args.filter_meta; | ||||
| size_t n = args.src_layout->operator[](0), | |||||
| ci = args.src_layout->operator[](1) * 32, | |||||
| hi = args.src_layout->operator[](2), | |||||
| wi = args.src_layout->operator[](3); | |||||
| size_t ho = args.dst_layout->operator[](2), | |||||
| wo = args.dst_layout->operator[](3); | |||||
| size_t ci = args.src_layout->operator[](1) * 32; | |||||
| size_t co; | size_t co; | ||||
| bool trans_oc; | |||||
| if (param.format == Format::NCHW32) { | if (param.format == Format::NCHW32) { | ||||
| co = args.dst_layout->operator[](1) * 32; | co = args.dst_layout->operator[](1) * 32; | ||||
| trans_oc = true; | |||||
| } else { | } else { | ||||
| megdnn_assert(param.format == Format::NCHW32_NCHW4); | megdnn_assert(param.format == Format::NCHW32_NCHW4); | ||||
| co = args.dst_layout->operator[](1) * 4; | co = args.dst_layout->operator[](1) * 4; | ||||
| trans_oc = false; | |||||
| } | } | ||||
| UNPACK_CONV_PARAMETER(fm, param); | |||||
| MARK_USED_VAR | |||||
| TensorLayout src{{co, ci / 32, fh, fw, 32}, dtype::Int8()}; | |||||
| src.init_contiguous_stride(); | |||||
| TensorLayout dst = src; | |||||
| dst.stride[0] = 32; | |||||
| dst.stride[1] = co * fh * fw * 32; | |||||
| dst.stride[2] = co * fw * 32; | |||||
| dst.stride[3] = co * 32; | |||||
| dst.stride[4] = 1; | |||||
| TensorND ts_src, ts_dst; | |||||
| ts_src.raw_ptr = args.filter_tensor->raw_ptr; | |||||
| ts_src.layout = src; | |||||
| ts_dst.raw_ptr = args.preprocessed_filter->tensors[0].raw_ptr; | |||||
| ts_dst.layout = dst; | |||||
| auto&& transpose = args.opr->handle()->create_operator<RelayoutForward>(); | |||||
| transpose->exec(ts_src, ts_dst); | |||||
| size_t fh = fm.spatial[0], fw = fm.spatial[1]; | |||||
| cudaStream_t stream = cuda_stream(args.opr->handle()); | |||||
| // filter: KCRS32 => CRSK32 and reorder oc | |||||
| cutlass_wrapper::reorder_ncxhwx_imma_filter<8, 32>( | |||||
| reinterpret_cast<int8_t*>( | |||||
| args.preprocessed_filter->tensors[0].raw_ptr), | |||||
| reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), co, ci, fh, | |||||
| fw, trans_oc, stream); | |||||
| } | } | ||||
| #endif | #endif | ||||
| @@ -144,7 +144,8 @@ void ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm::do_exec( | |||||
| reinterpret_cast<uint8_t*>(z_ptr), | reinterpret_cast<uint8_t*>(z_ptr), | ||||
| reinterpret_cast<uint8_t*>(args.dst_tensor->raw_ptr), nullptr, | reinterpret_cast<uint8_t*>(args.dst_tensor->raw_ptr), nullptr, | ||||
| kern_param, nonlinear_mode, alpha, beta, gamma, delta, theta, | kern_param, nonlinear_mode, alpha, beta, gamma, delta, theta, | ||||
| dst_scale, src_zero, threadblock_shape, warp_shape, stream); | |||||
| dst_scale, src_zero, threadblock_shape, warp_shape, | |||||
| m_algo_param.stage, stream); | |||||
| } | } | ||||
| void ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm::update_bias( | void ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm::update_bias( | ||||
| @@ -147,7 +147,7 @@ void ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm::do_exec( | |||||
| reinterpret_cast<uint8_t*>(args.dst_tensor->raw_ptr), nullptr, | reinterpret_cast<uint8_t*>(args.dst_tensor->raw_ptr), nullptr, | ||||
| kern_param, nonlinear_mode, alpha, beta, gamma, delta, theta, | kern_param, nonlinear_mode, alpha, beta, gamma, delta, theta, | ||||
| dst_scale, src_zero, threadblock_shape, warp_shape, | dst_scale, src_zero, threadblock_shape, warp_shape, | ||||
| m_algo_param.access_size, stream); | |||||
| m_algo_param.access_size, m_algo_param.stage, stream); | |||||
| } else { | } else { | ||||
| cutlass_wrapper::do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc<true>( | cutlass_wrapper::do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc<true>( | ||||
| reinterpret_cast<uint8_t*>(args.src_tensor->raw_ptr), | reinterpret_cast<uint8_t*>(args.src_tensor->raw_ptr), | ||||
| @@ -157,7 +157,7 @@ void ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm::do_exec( | |||||
| reinterpret_cast<uint8_t*>(args.dst_tensor->raw_ptr), nullptr, | reinterpret_cast<uint8_t*>(args.dst_tensor->raw_ptr), nullptr, | ||||
| kern_param, nonlinear_mode, alpha, beta, gamma, delta, theta, | kern_param, nonlinear_mode, alpha, beta, gamma, delta, theta, | ||||
| dst_scale, src_zero, threadblock_shape, warp_shape, | dst_scale, src_zero, threadblock_shape, warp_shape, | ||||
| m_algo_param.access_size, stream); | |||||
| m_algo_param.access_size, m_algo_param.stage, stream); | |||||
| } | } | ||||
| } | } | ||||
| @@ -840,21 +840,21 @@ TEST_F(CUDA, CUTLASS_CONV_BIAS_INT8_NCHW32_IMMA) { | |||||
| param.pad_h = param.pad_w = 1; | param.pad_h = param.pad_w = 1; | ||||
| param.stride_h = param.stride_w = 1; | param.stride_h = param.stride_w = 1; | ||||
| param.format = param::ConvBias::Format::NCHW32; | param.format = param::ConvBias::Format::NCHW32; | ||||
| checker.set_param(param).execs({{16, 16, 7, 7, 32}, | |||||
| {512, 16, 3, 3, 32}, | |||||
| {1, 16, 1, 1, 32}, | |||||
| checker.set_param(param).execs({{16, 8, 7, 7, 32}, | |||||
| {256, 8, 3, 3, 32}, | |||||
| {1, 8, 1, 1, 32}, | |||||
| {}, | {}, | ||||
| {}}); | {}}); | ||||
| param.nonlineMode = param::ConvBias::NonlineMode::RELU; | param.nonlineMode = param::ConvBias::NonlineMode::RELU; | ||||
| checker.set_param(param).execs({{16, 16, 7, 7, 32}, | |||||
| {512, 16, 1, 1, 32}, | |||||
| {1, 16, 1, 1, 32}, | |||||
| checker.set_param(param).execs({{16, 8, 7, 7, 32}, | |||||
| {256, 8, 1, 1, 32}, | |||||
| {1, 8, 1, 1, 32}, | |||||
| {}, | {}, | ||||
| {}}); | {}}); | ||||
| param.nonlineMode = param::ConvBias::NonlineMode::H_SWISH; | param.nonlineMode = param::ConvBias::NonlineMode::H_SWISH; | ||||
| checker.set_param(param).execs({{16, 16, 7, 7, 32}, | |||||
| {512, 16, 3, 3, 32}, | |||||
| {1, 16, 1, 1, 32}, | |||||
| checker.set_param(param).execs({{16, 8, 7, 7, 32}, | |||||
| {256, 8, 3, 3, 32}, | |||||
| {1, 8, 1, 1, 32}, | |||||
| {}, | {}, | ||||
| {}}); | {}}); | ||||
| // use non integer scale | // use non integer scale | ||||
| @@ -867,18 +867,18 @@ TEST_F(CUDA, CUTLASS_CONV_BIAS_INT8_NCHW32_IMMA) { | |||||
| .set_epsilon(1 + 1e-3) | .set_epsilon(1 + 1e-3) | ||||
| .set_max_avg_error(1e-1) | .set_max_avg_error(1e-1) | ||||
| .set_max_avg_biased_error(1e-1) | .set_max_avg_biased_error(1e-1) | ||||
| .execs({{16, 16, 7, 7, 32}, | |||||
| {512, 16, 3, 3, 32}, | |||||
| {1, 16, 1, 1, 32}, | |||||
| {16, 16, 7, 7, 32}, | |||||
| .execs({{16, 8, 7, 7, 32}, | |||||
| {256, 8, 3, 3, 32}, | |||||
| {1, 8, 1, 1, 32}, | |||||
| {16, 8, 7, 7, 32}, | |||||
| {}}); | {}}); | ||||
| }; | }; | ||||
| std::string algo = ConvBias::algo_name<ConvBias::DirectParam>( | std::string algo = ConvBias::algo_name<ConvBias::DirectParam>( | ||||
| "INT8_NCHW32_IMMA_IMPLICIT_GEMM_256X128X64_64X64X64", | |||||
| "INT8_NCHW32_IMMA_IMPLICIT_GEMM_128X128X64_64X64X64_2", | |||||
| ConvBias::DirectParam{}); | ConvBias::DirectParam{}); | ||||
| check(algo); | check(algo); | ||||
| algo = ConvBias::algo_name<ConvBias::DirectParam>( | algo = ConvBias::algo_name<ConvBias::DirectParam>( | ||||
| "INT8_NCHW32_IMMA_IMPLICIT_GEMM_32X64X64_32X16X64", | |||||
| "INT8_NCHW32_IMMA_IMPLICIT_GEMM_128X32X32_64X32X32_1", | |||||
| ConvBias::DirectParam{}); | ConvBias::DirectParam{}); | ||||
| check(algo); | check(algo); | ||||
| } | } | ||||
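The updated algorithm names simply reflect the to_string changes above: the trailing _2 and _1 encode the new stage field (e.g. 128X128X64_64X64X64 with 2 stages).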
| @@ -969,7 +969,7 @@ TEST_F(CUDA, CUTLASS_CONV_BIAS_INT8_NCHW32_NCHW4) { | |||||
| checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker< | checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker< | ||||
| ConvBiasForward>( | ConvBiasForward>( | ||||
| ConvBias::algo_name<ConvBias::DirectParam>( | ConvBias::algo_name<ConvBias::DirectParam>( | ||||
| "INT8_NCHW32_IMMA_IMPLICIT_GEMM_256X128X64_64X64X64", | |||||
| "INT8_NCHW32_IMMA_IMPLICIT_GEMM_128X128X64_64X64X64_2", | |||||
| ConvBias::DirectParam{}) | ConvBias::DirectParam{}) | ||||
| .c_str())); | .c_str())); | ||||
| checker.set_dtype(0, dtype::QuantizedS8(1.9980618f)) | checker.set_dtype(0, dtype::QuantizedS8(1.9980618f)) | ||||