GitOrigin-RevId: 0da9b3bca8 (tags/v1.11.1)
@@ -19,6 +19,7 @@ genrule(
     CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations dwconv2d_dgrad --type tensorop884 $(@D)
     CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations dwconv2d_wgrad --type simt $(@D)
     CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations dwconv2d_wgrad --type tensorop884 $(@D)
+    CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations rrconv2d_wgrad --type simt $(@D)
     """,
     tools = ["//brain/megbrain/dnn/scripts/cutlass_generator:generator.py"],
     visibility = ["//visibility:public"],
@@ -35,6 +35,8 @@ class Conv2dOperation:
         without_shared_load=False,
         required_cuda_ver_major=9,
         required_cuda_ver_minor=2,
+        rin=None,
+        rout=None,
     ):
         self.operation_kind = OperationKind.Conv2d
@@ -54,6 +56,8 @@ class Conv2dOperation:
         self.without_shared_load = without_shared_load
         self.required_cuda_ver_major = required_cuda_ver_major
         self.required_cuda_ver_minor = required_cuda_ver_minor
+        self.rin = rin
+        self.rout = rout

     #
     def accumulator_type(self):
@@ -95,6 +99,8 @@ class Conv2dOperation:
         conv_type_name = ""
         if self.conv_type == ConvType.DepthwiseConvolution:
             conv_type_name = "dw"
+        elif self.conv_type == ConvType.RegionRestrictedConvolution:
+            conv_type_name = "rr"

         return "%s%s%s%s%s%s%s_%s" % (
             ShortDataTypeNames[self.accumulator_type()],
@@ -125,6 +131,9 @@ class Conv2dOperation:
         elif self.src.element == self.flt.element:
             extended_name = "${core_name}_${element_src}"

+        if self.rin != None:
+            extended_name += "_${element_rin}"
+
         extended_name = SubstituteTemplate(
             extended_name,
             {
@@ -132,6 +141,7 @@ class Conv2dOperation:
                 "element_flt": DataTypeNames[self.flt.element],
                 "element_dst": DataTypeNames[self.dst.element],
                 "core_name": self.core_name(),
+                "element_rin": DataTypeNames[self.rin.element],
             },
         )
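The `_${element_rin}` suffix above is resolved by plain `${key}` substitution, so region-restricted kernels get the mask dtype embedded in their procedural names. A minimal sketch, assuming `SubstituteTemplate` behaves like simple placeholder replacement (its implementation is outside this diff):

    import re

    def substitute_template(template, values):
        # stand-in for the generator's SubstituteTemplate helper (assumed behavior)
        for key, value in values.items():
            template = re.sub(r"\$\{%s\}" % key, value, template)
        return template

    # an f32 wgrad kernel whose region masks are s32 gains an "_s32" suffix
    name = substitute_template(
        "${core_name}_${element_src}_${element_rin}",
        {"core_name": "wgrad", "element_src": "f32", "element_rin": "s32"},
    )
    assert name == "wgrad_f32_s32"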
@@ -512,6 +522,115 @@ using Convolution_${operation_name} =
         return SubstituteTemplate(self.template, values)


+class EmitRegionRestrictedConvolutionBackwardFilterInstance:
+    def __init__(self):
+        self.template = """
+// kernel instance "${operation_name}" generated by cutlass generator
+using Convolution_${operation_name} =
+    typename cutlass::conv::device::RegionRestrictedConvolutionBackwardFilter<
+        ${element_src},
+        ${layout_src},
+        ${element_diff},
+        ${layout_diff},
+        ${element_src_mask},
+        ${layout_src_mask},
+        ${element_output_mask},
+        ${layout_output_mask},
+        ${element_grad},
+        ${layout_grad},
+        ${element_accumulator},
+        ${conv_type},
+        ${opcode_class},
+        ${arch},
+        cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
+        cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
+        cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
+        ${epilogue_functor}<
+            ${element_grad},
+            ${epilogue_vector_length},
+            ${element_accumulator},
+            ${element_epilogue}
+        >,
+        ${swizzling_functor},
+        ${stages},
+        ${alignment_src},
+        ${alignment_diff},
+        ${alignment_src_mask},
+        ${alignment_output_mask},
+        ${special_optimization},
+        ${math_operator},
+        ${implicit_gemm_mode}>;
+"""
+
+    def emit(self, operation):
+        warp_shape = [
+            int(
+                operation.tile_description.threadblock_shape[idx]
+                / operation.tile_description.warp_count[idx]
+            )
+            for idx in range(3)
+        ]
+
+        epilogue_vector_length = int(
+            min(operation.dst.alignment * DataTypeSize[operation.dst.element], 128)
+            / DataTypeSize[operation.dst.element]
+        )
+
+        values = {
+            "operation_name": operation.procedural_name(),
+            "conv_type": ConvTypeTag[operation.conv_type],
+            "element_src": DataTypeTag[operation.src.element],
+            "layout_src": LayoutTag[operation.src.layout],
+            "element_diff": DataTypeTag[operation.flt.element],
+            "layout_diff": LayoutTag[operation.flt.layout],
+            "element_src_mask": DataTypeTag[operation.rin.element],
+            "layout_src_mask": LayoutTag[operation.rin.layout],
+            "element_output_mask": DataTypeTag[operation.rout.element],
+            "layout_output_mask": LayoutTag[operation.rout.layout],
+            "element_grad": DataTypeTag[operation.dst.element],
+            "layout_grad": LayoutTag[operation.dst.layout],
+            "element_accumulator": DataTypeTag[operation.accumulator_type()],
+            "opcode_class": OpcodeClassTag[
+                operation.tile_description.math_instruction.opcode_class
+            ],
+            "arch": "cutlass::arch::Sm%d" % operation.arch,
+            "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]),
+            "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]),
+            "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]),
+            "warp_shape_m": str(warp_shape[0]),
+            "warp_shape_n": str(warp_shape[1]),
+            "warp_shape_k": str(warp_shape[2]),
+            "instruction_shape_m": str(
+                operation.tile_description.math_instruction.instruction_shape[0]
+            ),
+            "instruction_shape_n": str(
+                operation.tile_description.math_instruction.instruction_shape[1]
+            ),
+            "instruction_shape_k": str(
+                operation.tile_description.math_instruction.instruction_shape[2]
+            ),
+            "epilogue_vector_length": str(epilogue_vector_length),
+            "epilogue_functor": EpilogueFunctorTag[operation.epilogue_functor],
+            "element_epilogue": str(DataTypeTag[operation.element_epilogue]),
+            "swizzling_functor": SwizzlingFunctorTag[operation.swizzling_functor],
+            "stages": str(operation.tile_description.stages),
+            "alignment_src": str(operation.src.alignment),
+            "alignment_diff": str(operation.flt.alignment),
+            "alignment_src_mask": str(operation.rin.alignment),
+            "alignment_output_mask": str(operation.rout.alignment),
+            "special_optimization": SpecialOptimizeDescTag[
+                operation.special_optimization
+            ],
+            "math_operator": MathOperationTag[
+                operation.tile_description.math_instruction.math_operation
+            ],
+            "implicit_gemm_mode": ImplicitGemmModeTag[operation.implicit_gemm_mode],
+        }
+
+        return SubstituteTemplate(self.template, values)
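Both derived values in `emit` are plain integer arithmetic over the tile description. A worked sketch under assumed inputs (the 128x128x8 / [4, 2, 1] tile from the SIMT generator below, and an f32 output tensor with alignment 1):

    threadblock_shape = [128, 128, 8]
    warp_count = [4, 2, 1]
    warp_shape = [threadblock_shape[i] // warp_count[i] for i in range(3)]
    assert warp_shape == [32, 64, 8]

    # epilogue vector length: widest vector (capped at 128 bits) the dst
    # alignment permits, expressed in elements
    dst_alignment, dst_bits = 1, 32  # f32
    epilogue_vector_length = min(dst_alignment * dst_bits, 128) // dst_bits
    assert epilogue_vector_length == 1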
###################################################################################################
 #
 # Generator functions for all layouts
@@ -540,7 +659,10 @@ def GenerateConv2d(
     operations = []

     element_epilogue = DataType.f32
-    if conv_type == ConvType.DepthwiseConvolution:
+    if (
+        conv_type == ConvType.DepthwiseConvolution
+        or conv_type == ConvType.RegionRestrictedConvolution
+    ):
         if conv_kind == ConvKind.Fprop:
             swizzling_functor = SwizzlingFunctor.DepthwiseConvolutionFprop
         elif conv_kind == ConvKind.Dgrad:
@@ -680,6 +802,16 @@ def GenerateConv2d(
             flt_layout,
             int(flt_align / DataTypeSize[tile.math_instruction.element_a]),
         )
+        rin = TensorDescription(
+            tile.math_instruction.element_rin,
+            src_layout,
+            int(src_align / DataTypeSize[tile.math_instruction.element_rin]),
+        )
+        rout = TensorDescription(
+            tile.math_instruction.element_rout,
+            dst_layout,
+            int(dst_align / DataTypeSize[tile.math_instruction.element_rout]),
+        )
         bias = TensorDescription(
             bias_type, dst_layout, max(1, int(32 / DataTypeSize[bias_type]))
         )
@@ -704,6 +836,8 @@ def GenerateConv2d(
             without_shared_load,
             required_cuda_ver_major,
             required_cuda_ver_minor,
+            rin,
+            rout,
         )
         operations.append(new_operation)
         if use_special_optimization != SpecialOptimizeDesc.NoneSpecialOpt:
@@ -724,6 +858,8 @@ def GenerateConv2d(
                 without_shared_load,
                 required_cuda_ver_major,
                 required_cuda_ver_minor,
+                rin,
+                rout,
             )
             operations.append(new_operation)
     return operations
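The new `rin`/`rout` descriptors reuse the bits-to-elements conversion applied to `src` and `flt`. A small sketch, assuming the 32-bit `src_align`/`dst_align` values passed by the region-restricted generator below:

    DataTypeSize = {"s32": 32, "s8": 8}  # bits per element (subset, for illustration)

    src_align = dst_align = 32  # bits, as passed by GenerateRegionRestrictedconv2d_Simt
    rin_alignment = src_align // DataTypeSize["s32"]   # -> 1 element per access
    rout_alignment = dst_align // DataTypeSize["s32"]  # -> 1 element per access
    # with s8 masks the same 32-bit alignment allows 4-element vector access
    assert src_align // DataTypeSize["s8"] == 4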
@@ -955,5 +1091,89 @@ void initialize_${operation_name}(Manifest &manifest) {
         self.kernel_file.close()


+class EmitRegionRestrictedConvSingleKernelWrapper:
+    def __init__(self, kernel_path, operation, short_path=False):
+        self.kernel_path = kernel_path
+        self.operation = operation
+        self.short_path = short_path
+
+        # Now only support wgrad
+        assert self.operation.conv_kind == ConvKind.Wgrad
+        self.instance_emitter = EmitRegionRestrictedConvolutionBackwardFilterInstance()
+
+        self.convolution_name = "RegionRestrictedConvolutionBackwardFilterOperation"
+
+        self.header_template = """
+#if __CUDACC_VER_MAJOR__ > ${required_cuda_ver_major} || (__CUDACC_VER_MAJOR__ == ${required_cuda_ver_major} && __CUDACC_VER_MINOR__ >= ${required_cuda_ver_minor})
+// ignore warning of cutlass
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wunused-parameter"
+#pragma GCC diagnostic ignored "-Wstrict-aliasing"
+#pragma GCC diagnostic ignored "-Wuninitialized"
+#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
+#include "cutlass/convolution/device/convolution.h"
+#include "src/cuda/cutlass/manifest.h"
+#include "src/cuda/cutlass/convolution_operation.h"
+"""
+        self.instance_template = """
+${operation_instance}
+"""
+
+        self.manifest_template = """
+namespace cutlass {
+namespace library {
+
+void initialize_${operation_name}(Manifest &manifest) {
+    manifest.append(new ${convolution_name}<Convolution_${operation_name}>(
+        "${operation_name}"
+    ));
+}
+
+}  // namespace library
+}  // namespace cutlass
+"""
+
+        self.epilogue_template = """
+#pragma GCC diagnostic pop
+#endif
+"""
+
+    #
+    def __enter__(self):
+        if self.short_path:
+            self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % GlobalCnt.cnt)
+            GlobalCnt.cnt += 1
+        else:
+            self.kernel_path = os.path.join(
+                self.kernel_path, "%s.cu" % self.operation.procedural_name()
+            )
+        self.kernel_file = open(self.kernel_path, "w")
+        return self
+
+    #
+    def emit(self):
+        self.kernel_file.write(
+            SubstituteTemplate(
+                self.instance_template,
+                {"operation_instance": self.instance_emitter.emit(self.operation)},
+            )
+        )
+
+        # emit manifest helper
+        manifest = SubstituteTemplate(
+            self.manifest_template,
+            {
+                "operation_name": self.operation.procedural_name(),
+                "convolution_name": self.convolution_name,
+            },
+        )
+        self.kernel_file.write(manifest)
+
+    #
+    def __exit__(self, exception_type, exception_value, traceback):
+        self.kernel_file.close()
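The wrapper follows the same context-manager protocol as the existing single-kernel emitters: enter opens one `.cu` file (numbered when `short_path` is set), `emit` writes the instance plus its manifest initializer, and exit closes the file. A hedged usage sketch (`output_dir` is an illustrative stand-in for the `args.output` used by the driver later in this diff):

    # one generated .cu per operation; short_path=True yields numbered file names
    for operation in operations:
        with EmitRegionRestrictedConvSingleKernelWrapper(
            output_dir, operation, short_path=True
        ) as emitter:
            emitter.emit()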
###################################################################################################
 ###################################################################################################
@@ -64,4 +64,5 @@ if __name__ == "__main__":
     write_merge_file_name(f, "dwconv2d_dgrad", "tensorop884", 4)
     write_merge_file_name(f, "dwconv2d_wgrad", "simt", 2)
     write_merge_file_name(f, "dwconv2d_wgrad", "tensorop884", 4)
+    write_merge_file_name(f, "rrconv2d_wgrad", "simt", 2)
     f.write("]")
@@ -1260,6 +1260,218 @@ def GenerateDwconv2d_Simt(args, conv_kind):
     return operations


+def GenerateRegionRestrictedconv2d_Simt(args, conv_kind):
+    ################################################################################
+    # warps per threadblock
+    ################################################################################
+    warpsPerThreadblocks = []
+    for warpsPerThreadblock0 in warpsPerThreadblockEdge:
+        for warpsPerThreadblock1 in warpsPerThreadblockEdge:
+            if (
+                warpsPerThreadblock0 / warpsPerThreadblock1 <= warpsPerThreadblockRatio
+                and warpsPerThreadblock1 / warpsPerThreadblock0
+                <= warpsPerThreadblockRatio
+                and warpsPerThreadblock0 * warpsPerThreadblock1
+                <= warpsPerThreadblockMax
+            ):
+                warpsPerThreadblocks.append(
+                    [warpsPerThreadblock0, warpsPerThreadblock1]
+                )
+
+    ################################################################################
+    # warp shapes
+    ################################################################################
+    warpNumThreads = 32
+    warpShapes = []
+    for warp0 in warpShapeEdges:
+        for warp1 in warpShapeEdges:
+            if (
+                warp0 / warp1 <= warpShapeRatio
+                and warp1 / warp0 <= warpShapeRatio
+                and warp0 * warp1 <= warpShapeMax
+                and warp0 * warp1 > warpShapeMin
+            ):
+                warpShapes.append([warp0, warp1])
+
+    # sgemm
+    (
+        precisionType,
+        precisionBits,
+        threadblockMaxElements,
+        threadblockTilesL0,
+    ) = precisions["s"]
+
+    layouts = [(LayoutType.TensorNCHW, LayoutType.TensorNCHW)]
+
+    math_instructions = [
+        MathInstruction(
+            [1, 1, 1],
+            DataType.f32,
+            DataType.f32,
+            DataType.f32,
+            OpcodeClass.Simt,
+            MathOperation.multiply_add,
+            DataType.s32,
+            DataType.s32,
+        ),
+        MathInstruction(
+            [1, 1, 1],
+            DataType.f32,
+            DataType.f32,
+            DataType.f32,
+            OpcodeClass.Simt,
+            MathOperation.multiply_add,
+            DataType.s8,
+            DataType.s8,
+        ),
+    ]
+
+    min_cc = 50
+    max_cc = 1024
+
+    dst_layouts = [LayoutType.TensorNCHW]
+
+    dst_types = [DataType.f32]
+
+    if conv_kind == ConvKind.Wgrad:
+        alignment_constraints = [32]
+    else:
+        alignment_constraints = [128, 32]
+
+    operations = []
+    for math_inst in math_instructions:
+        tile_descriptions = [
+            TileDescription([128, 128, 8], 1, [4, 2, 1], math_inst, min_cc, max_cc),
+            TileDescription([128, 64, 8], 1, [2, 2, 1], math_inst, min_cc, max_cc),
+            TileDescription([64, 128, 8], 1, [2, 2, 1], math_inst, min_cc, max_cc),
+            TileDescription([128, 32, 8], 1, [2, 1, 1], math_inst, min_cc, max_cc),
+            TileDescription([32, 128, 8], 1, [1, 2, 1], math_inst, min_cc, max_cc),
+            TileDescription([64, 64, 8], 1, [2, 1, 1], math_inst, min_cc, max_cc),
+            TileDescription([32, 64, 8], 1, [1, 1, 1], math_inst, min_cc, max_cc),
+            TileDescription([64, 32, 8], 1, [1, 1, 1], math_inst, min_cc, max_cc),
+            TileDescription([32, 32, 8], 1, [1, 1, 1], math_inst, min_cc, max_cc),
+        ]
+        for warpsPerThreadblock in warpsPerThreadblocks:
+            for warpShape in warpShapes:
+                warpThreadsM = 0
+                if warpShape[0] > warpShape[1]:
+                    warpThreadsM = 8
+                else:
+                    warpThreadsM = 4
+                warpThreadsN = warpNumThreads / warpThreadsM
+
+                # skip shapes with conflicting rectangularity
+                # they are unlikely to be fastest
+                blockG = warpsPerThreadblock[0] > warpsPerThreadblock[1]
+                blockL = warpsPerThreadblock[0] < warpsPerThreadblock[1]
+                warpG = warpShape[0] > warpShape[1]
+                warpL = warpShape[0] < warpShape[1]
+
+                blockG2 = warpsPerThreadblock[0] > warpsPerThreadblock[1] * 2
+                blockL2 = warpsPerThreadblock[0] * 2 < warpsPerThreadblock[1]
+                warpG2 = warpShape[0] > warpShape[1] * 2
+                warpL2 = warpShape[0] * 2 < warpShape[1]
+
+                if blockG2 and warpL:
+                    continue
+                if blockL2 and warpG:
+                    continue
+                if warpG2 and blockL:
+                    continue
+                if warpL2 and blockG:
+                    continue
+
+                # check threadblock ratios and max
+                threadblockTile = [
+                    warpShape[0] * warpsPerThreadblock[0],
+                    warpShape[1] * warpsPerThreadblock[1],
+                ]
+                if threadblockTile[0] * threadblockTile[1] > threadblockMaxElements:
+                    continue
+                if threadblockTile[0] > threadblockEdgeMax:
+                    continue
+                if threadblockTile[1] > threadblockEdgeMax:
+                    continue
+                totalThreads = (
+                    warpNumThreads * warpsPerThreadblock[0] * warpsPerThreadblock[1]
+                )
+
+                # calculate unroll
+                # ensure that every iteration at least a full load of A,B are done
+                unrollMin = 8
+                unrollMin0 = totalThreads // threadblockTile[0]
+                unrollMin1 = totalThreads // threadblockTile[1]
+                unroll = max(unrollMin, unrollMin0, unrollMin1)
+
+                threadTileM = warpShape[0] // warpThreadsM
+                threadTileN = warpShape[1] // warpThreadsN
+                if threadTileM < 2 or threadTileN < 2:
+                    continue
+                if threadTileM * threadTileN * precisionBits > 8 * 8 * 32:
+                    continue
+
+                # epilogue currently only supports N < WarpNumThreads
+                if threadblockTile[1] < warpNumThreads:
+                    continue
+
+                # limit smem
+                smemBitsA = threadblockTile[0] * unroll * 2 * precisionBits
+                smemBitsB = threadblockTile[1] * unroll * 2 * precisionBits
+                smemKBytes = (smemBitsA + smemBitsB) / 8 / 1024
+                if smemKBytes > 48:
+                    continue
+
+                tile = TileDescription(
+                    [threadblockTile[0], threadblockTile[1], unroll],
+                    1,
+                    [
+                        threadblockTile[0] // warpShape[0],
+                        threadblockTile[1] // warpShape[1],
+                        1,
+                    ],
+                    math_inst,
+                    min_cc,
+                    max_cc,
+                )
+
+                def filter(t: TileDescription) -> bool:
+                    nonlocal tile
+                    return (
+                        t.threadblock_shape[0] == tile.threadblock_shape[0]
+                        and t.threadblock_shape[1] == tile.threadblock_shape[1]
+                        and t.threadblock_shape[2] == tile.threadblock_shape[2]
+                        and t.warp_count[0] == tile.warp_count[0]
+                        and t.warp_count[1] == tile.warp_count[1]
+                        and t.warp_count[2] == tile.warp_count[2]
+                        and t.stages == tile.stages
+                    )
+
+                if not any(t for t in tile_descriptions if filter(t)):
+                    continue
+
+                for layout in layouts:
+                    for dst_type, dst_layout in zip(dst_types, dst_layouts):
+                        for alignment_src in alignment_constraints:
+                            operations += GenerateConv2d(
+                                ConvType.RegionRestrictedConvolution,
+                                conv_kind,
+                                [tile],
+                                layout[0],
+                                layout[1],
+                                dst_layout,
+                                dst_type,
+                                min_cc,
+                                alignment_src,
+                                32,
+                                32,
+                                SpecialOptimizeDesc.NoneSpecialOpt,
+                                ImplicitGemmMode.GemmNT
+                                if conv_kind == ConvKind.Wgrad
+                                else ImplicitGemmMode.GemmTN,
+                            )
+    return operations
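The unroll computation above sizes the threadblock K extent so that every main-loop iteration performs at least one full load of A and B. A worked example for the pairing that reproduces the 128x128x8 entry in `tile_descriptions`:

    warpNumThreads = 32
    warpsPerThreadblock = [4, 2]
    warpShape = [32, 64]

    threadblockTile = [warpShape[0] * warpsPerThreadblock[0],
                       warpShape[1] * warpsPerThreadblock[1]]  # [128, 128]
    totalThreads = warpNumThreads * warpsPerThreadblock[0] * warpsPerThreadblock[1]  # 256

    unrollMin = 8
    unroll = max(unrollMin,
                 totalThreads // threadblockTile[0],   # full load of A per iteration
                 totalThreads // threadblockTile[1])   # full load of B per iteration
    assert [threadblockTile[0], threadblockTile[1], unroll] == [128, 128, 8]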

 #
 def GenerateDwconv2d_TensorOp_884(args, conv_kind):
     layouts = [(LayoutType.TensorNCHW, LayoutType.TensorNCHW)]
@@ -1644,6 +1856,14 @@ def GenerateDwconv2dWgradOperations(args):
     return GenerateDwconv2d_TensorOp_884(args, ConvKind.Wgrad)
+def GenerateRegionRestrictedconv2dWgradOperations(args):
+    assert args.type == "simt", (
+        "operation RegionRestrictedconv2d wgrad only support "
+        "simt. (got:{})".format(args.type)
+    )
+    return GenerateRegionRestrictedconv2d_Simt(args, ConvKind.Wgrad)

 def GenerateGemmOperations(args):
     if args.type == "tensorop884":
         return GeneratesGemm_TensorOp_884(args)
@@ -1698,6 +1918,8 @@ def ConcatFile(
         sub_string_1 = sub_string_2 = "simt"
     if "dwconv2d_" in operations:
         filtered_operations = operations[:2] + operations[9:]
+    if "rrconv2d_" in operations:
+        filtered_operations = operations[:2] + operations[9:]
     elif ("conv2d" in operations) or ("deconv" in operations):
         filtered_operations = "cutlass"
     else:
@@ -1893,6 +2115,7 @@ if __name__ == "__main__":
             "dwconv2d_fprop",
             "dwconv2d_dgrad",
             "dwconv2d_wgrad",
+            "rrconv2d_wgrad",
         ],
         required=True,
         help="Specifies the operation to generate (gemm, gemv, conv2d, deconv, dwconv2d_fprop, dwconv2d_dgrad, dwconv2d_wgrad)",
@@ -1928,9 +2151,11 @@ if __name__ == "__main__":
         operations = GenerateDwconv2dFpropOperations(args)
     elif args.operations == "dwconv2d_dgrad":
         operations = GenerateDwconv2dDgradOperations(args)
-    else:
-        assert args.operations == "dwconv2d_wgrad", "invalid operation"
+    elif args.operations == "dwconv2d_wgrad":
         operations = GenerateDwconv2dWgradOperations(args)
+    else:
+        assert args.operations == "rrconv2d_wgrad", "invalid operation"
+        operations = GenerateRegionRestrictedconv2dWgradOperations(args)

     if (
         args.operations == "conv2d"
@@ -1974,6 +2199,42 @@ if __name__ == "__main__":
                 required_cuda_ver_minor,
                 epilogue,
             )
+    elif args.operations == "rrconv2d_wgrad":
+        for operation in operations:
+            with EmitRegionRestrictedConvSingleKernelWrapper(
+                args.output, operation, short_path
+            ) as emitter:
+                emitter.emit()
+        head = EmitRegionRestrictedConvSingleKernelWrapper(
+            args.output, operations[0], short_path
+        ).header_template
+        required_cuda_ver_major = operations[0].required_cuda_ver_major
+        required_cuda_ver_minor = operations[0].required_cuda_ver_minor
+        epilogue = EmitRegionRestrictedConvSingleKernelWrapper(
+            args.output, operations[0], short_path
+        ).epilogue_template
+        if "tensorop" in args.type:
+            ConcatFile(
+                4,
+                args.output,
+                args.operations,
+                args.type,
+                head,
+                required_cuda_ver_major,
+                required_cuda_ver_minor,
+                epilogue,
+            )
+        else:
+            ConcatFile(
+                2,
+                args.output,
+                args.operations,
+                args.type,
+                head,
+                required_cuda_ver_major,
+                required_cuda_ver_minor,
+                epilogue,
+            )
     elif args.operations == "gemm":
         for operation in operations:
             with EmitGemmSingleKernelWrapper(
@@ -532,6 +532,7 @@ class ConvType(enum.Enum):
     Local = enum_auto()
     LocalShare = enum_auto()
     DepthwiseConvolution = enum_auto()
+    RegionRestrictedConvolution = enum_auto()


 ConvTypeTag = {
@@ -540,6 +541,8 @@ ConvTypeTag = {
     ConvType.Local: "cutlass::conv::ConvType::kLocal",
     ConvType.LocalShare: "cutlass::conv::ConvType::kLocalShare",
     ConvType.DepthwiseConvolution: "cutlass::conv::ConvType::kDepthwiseConvolution",
+    # RegionRestrictedConvolution uses the same conv type as Depthwise
+    ConvType.RegionRestrictedConvolution: "cutlass::conv::ConvType::kDepthwiseConvolution",
 }

 #
@@ -640,6 +643,8 @@ class MathInstruction:
         element_accumulator,
         opcode_class,
         math_operation=MathOperation.multiply_add,
+        element_rin=DataType.s32,
+        element_rout=DataType.s32,
     ):
         self.instruction_shape = instruction_shape
         self.element_a = element_a
@@ -647,6 +652,8 @@ class MathInstruction:
         self.element_accumulator = element_accumulator
         self.opcode_class = opcode_class
         self.math_operation = math_operation
+        self.element_rin = element_rin
+        self.element_rout = element_rout

 #
@@ -85,4 +85,7 @@ cutlass_gen_list = [
     "dwconv2d_wgrad_tensorop884_2.cu",
     "dwconv2d_wgrad_tensorop884_3.cu",
     "all_dwconv2d_wgrad_tensorop884_operations.cu",
+    "rrconv2d_wgrad_simt_0.cu",
+    "rrconv2d_wgrad_simt_1.cu",
+    "all_rrconv2d_wgrad_simt_operations.cu",
 ]
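The two `rrconv2d_wgrad_simt_*.cu` entries plus the `all_*` initializer match the `nr_files = 2` passed to `write_merge_file_name` earlier in this diff. A sketch of the assumed naming convention (hypothetical helper, inferred from the list entries above):

    def merge_file_names(op, op_type, nr_files):
        # assumed behavior: one entry per merged .cu plus the
        # all_<op>_<type>_operations.cu initializer
        return ["%s_%s_%d.cu" % (op, op_type, i) for i in range(nr_files)] + [
            "all_%s_%s_operations.cu" % (op, op_type)
        ]

    assert merge_file_names("rrconv2d_wgrad", "simt", 2) == [
        "rrconv2d_wgrad_simt_0.cu",
        "rrconv2d_wgrad_simt_1.cu",
        "all_rrconv2d_wgrad_simt_operations.cu",
    ]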
@@ -188,6 +188,7 @@ if(MGE_WITH_CUDA)
   gen_cutlass_kimpl(dwconv2d_dgrad tensorop884 CUTLASS_SOURCES)
   gen_cutlass_kimpl(dwconv2d_wgrad simt CUTLASS_SOURCES)
   gen_cutlass_kimpl(dwconv2d_wgrad tensorop884 CUTLASS_SOURCES)
+  gen_cutlass_kimpl(rrconv2d_wgrad simt CUTLASS_SOURCES)
   list(PREPEND CUSOURCES ${CUTLASS_SOURCES})

 # Compile the following file first, the priority_compile_opr.txt is generated by
@@ -452,6 +452,86 @@ public:

 ///////////////////////////////////////////////////////////////////////////////////////////////////

+template <typename Operator_>
+class RegionRestrictedConvolutionBackwardFilterOperation
+        : public ConvolutionBackwardFilterOperationBase<Operator_> {
+public:
+    using Operator = Operator_;
+    using ElementSrc = typename Operator::ElementSrc;
+    using LayoutSrc = typename Operator::LayoutSrc;
+    using ElementDiff = typename Operator::ElementDiff;
+    using LayoutDiff = typename Operator::LayoutDiff;
+    using ElementGrad = typename Operator::ElementGrad;
+    using LayoutGrad = typename Operator::LayoutGrad;
+    using ElementAccumulator = typename Operator::ElementAccumulator;
+    using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute;
+    using OperatorArguments = typename Operator::Arguments;
+    using ElementRin = typename Operator::ElementMaskInput;
+    using LayoutRin = typename Operator::LayoutMaskInput;
+    using ElementRout = typename Operator::ElementMaskOutput;
+    using LayoutRout = typename Operator::LayoutMaskOutput;
+
+    RegionRestrictedConvolutionBackwardFilterOperation(
+            char const* name = "unknown_gemm")
+            : ConvolutionBackwardFilterOperationBase<Operator_>(name) {
+        /// rin in description -> rin in C++ template
+        this->m_description.rin = make_TensorDescription<ElementRin, LayoutRin>(
+                Operator::kAlignmentMaskInput);
+        /// rout in description -> rout in C++ template
+        this->m_description.rout = make_TensorDescription<ElementRout, LayoutRout>(
+                Operator::kAlignmentMaskOutput);
+        this->m_description.without_shared_load = false;
+    }
+
+    virtual Status run(
+            void const* arguments_ptr, void* device_workspace = nullptr,
+            cudaStream_t stream = nullptr) const {
+        cutlass::conv::Operator conv_op = this->m_description.conv_op;
+        ConvolutionArguments const* conv_args =
+                reinterpret_cast<ConvolutionArguments const*>(arguments_ptr);
+        const auto& ps = conv_args->problem_size;
+
+        OperatorArguments args;
+        args.problem_size = ps;
+        /// src in convolution arguments -> ref_src
+        args.ref_src = {
+                static_cast<ElementSrc*>(const_cast<void*>(conv_args->src)),
+                LayoutSrc::packed(implicit_gemm_tensor_b_extent(conv_op, ps))};
+        /// filter in convolution arguments -> ref_diff
+        args.ref_diff = {
+                static_cast<ElementDiff*>(const_cast<void*>(conv_args->filter)),
+                LayoutDiff::packed(implicit_gemm_tensor_a_extent(conv_op, ps))};
+        /// dst in convolution arguments -> ref_grad
+        args.ref_grad = {
+                static_cast<ElementGrad*>(conv_args->dst),
+                LayoutGrad::packed(implicit_gemm_tensor_c_extent(conv_op, ps))};
+        /// rin in convolution arguments -> ref_mask_input
+        args.ref_mask_input = {
+                static_cast<ElementRin*>(const_cast<void*>(conv_args->rin)),
+                LayoutRin::packed(implicit_gemm_tensor_rin_extent(conv_op, ps))};
+        /// rout in convolution arguments -> ref_mask_output
+        args.ref_mask_output = {
+                static_cast<ElementRout*>(const_cast<void*>(conv_args->rout)),
+                LayoutRout::packed(implicit_gemm_tensor_rout_extent(conv_op, ps))};
+
+        args.output_op = init_epilogue_param<typename Operator::EpilogueOutputOp>().get(
+                conv_args);
+
+        Operator op;
+        Status status = op.initialize(args, device_workspace);
+
+        if (status != Status::kSuccess) {
+            return status;
+        }
+
+        return op.run(stream);
+    }
+};
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
 }  // namespace library
 }  // namespace cutlass
@@ -50,6 +50,7 @@ void initialize_all_deconv_simt_operations(Manifest& manifest);
 void initialize_all_dwconv2d_fprop_simt_operations(Manifest& manifest);
 void initialize_all_dwconv2d_dgrad_simt_operations(Manifest& manifest);
 void initialize_all_dwconv2d_wgrad_simt_operations(Manifest& manifest);
+void initialize_all_rrconv2d_wgrad_simt_operations(Manifest& manifest);
 #if defined(CUTLASS_ARCH_MMA_SM70_SUPPORTED) && CUTLASS_ARCH_MMA_SM70_SUPPORTED
 void initialize_all_gemm_tensorop884_operations(Manifest& manifest);
 void initialize_all_dwconv2d_fprop_tensorop884_operations(Manifest& manifest);
@@ -70,6 +71,7 @@ void initialize_all(Manifest& manifest) {
     initialize_all_dwconv2d_fprop_simt_operations(manifest);
     initialize_all_dwconv2d_dgrad_simt_operations(manifest);
     initialize_all_dwconv2d_wgrad_simt_operations(manifest);
+    initialize_all_rrconv2d_wgrad_simt_operations(manifest);
 #if defined(CUTLASS_ARCH_MMA_SM70_SUPPORTED) && CUTLASS_ARCH_MMA_SM70_SUPPORTED
     initialize_all_gemm_tensorop884_operations(manifest);
     initialize_all_dwconv2d_fprop_tensorop884_operations(manifest);
@@ -471,6 +471,10 @@ struct ConvolutionDescription : public OperationDescription {
     conv::SpecialOptimizeDesc special_optimization;
     conv::ImplicitGemmMode gemm_mode;
     bool without_shared_load;
+
+    // only used by rrconv
+    TensorDescription rin;
+    TensorDescription rout;
 };

 /////////////////////////////////////////////////////////////////////////////////////////////////
@@ -499,6 +503,10 @@ struct ConvolutionArguments {

     /// Host pointer to extra param struct
     void const* extra_param;
+
+    // only used by rrconv, default: nullptr
+    void const* rin = nullptr;
+    void const* rout = nullptr;
 };

 /////////////////////////////////////////////////////////////////////////////////////////////////
@@ -118,6 +118,11 @@ ConvolutionKey get_convolution_key_from_desc(const ConvolutionDescription& desc)
     key.alignment_filter = desc.filter.alignment;
     key.without_shared_load = desc.without_shared_load;

+    key.element_rin = desc.rin.element;
+    key.layout_rin = desc.rin.layout;
+    key.element_rout = desc.rout.element;
+    key.layout_rout = desc.rout.layout;
+
     return key;
 }
@@ -201,6 +201,12 @@ struct ConvolutionKey {
     bool without_shared_load;

+    // only used by rrconv
+    library::NumericTypeID element_rin = library::NumericTypeID::kInvalid;
+    library::LayoutTypeID layout_rin = library::LayoutTypeID::kInvalid;
+    library::NumericTypeID element_rout = library::NumericTypeID::kInvalid;
+    library::LayoutTypeID layout_rout = library::LayoutTypeID::kInvalid;
+
     inline bool operator==(ConvolutionKey const& rhs) const {
         return (conv_op == rhs.conv_op) && (element_src == rhs.element_src) &&
                (layout_src == rhs.layout_src) &&
@@ -223,7 +229,9 @@ struct ConvolutionKey {
                (special_optimization == rhs.special_optimization) &&
                (alignment_src == rhs.alignment_src) &&
                (alignment_filter == rhs.alignment_filter) &&
-               (without_shared_load == rhs.without_shared_load);
+               (without_shared_load == rhs.without_shared_load) &&
+               (element_rin == rhs.element_rin) && (layout_rin == rhs.layout_rin) &&
+               (element_rout == rhs.element_rout) && (layout_rout == rhs.layout_rout);
     }

     inline bool operator!=(ConvolutionKey const& rhs) const { return !(*this == rhs); }
@@ -260,7 +268,11 @@ struct ConvolutionKey {
                "\n  special_optimization: " + to_string(special_optimization) +
                "\n  alignment_src: " + std::to_string(alignment_src) +
                "\n  alignment_filter: " + std::to_string(alignment_filter) +
-               "\n  without_shared_load: " + to_string(without_shared_load) + "\n}";
+               "\n  without_shared_load: " + to_string(without_shared_load) +
+               "\n  element_rin: " + to_string(element_rin) +
+               "\n  layout_rin: " + to_string(layout_rin) +
+               "\n  element_rout: " + to_string(element_rout) +
+               "\n  layout_rout: " + to_string(layout_rout) + "\n}";
     }
 };
@@ -293,6 +305,10 @@ struct ConvolutionKeyHasher {
             .update(&key.alignment_src, sizeof(key.alignment_src))
             .update(&key.alignment_filter, sizeof(key.alignment_filter))
             .update(&key.without_shared_load, sizeof(key.without_shared_load))
+            .update(&key.element_rin, sizeof(key.element_rin))
+            .update(&key.layout_rin, sizeof(key.layout_rin))
+            .update(&key.element_rout, sizeof(key.element_rout))
+            .update(&key.layout_rout, sizeof(key.layout_rout))
             .digest();
     }
 };
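Because `element_rin`/`layout_rin`/`element_rout`/`layout_rout` default to `kInvalid`, every pre-existing (non-rrconv) key keeps comparing and hashing consistently, while rrconv keys differ in the new fields. A Python model of that defaulting (illustrative, not the C++ API):

    from dataclasses import dataclass

    @dataclass(frozen=True)
    class Key:
        # non-rrconv operations never set the mask fields, so they all share
        # the same "kInvalid" defaults and compare exactly as before
        element_src: str
        element_rin: str = "kInvalid"
        element_rout: str = "kInvalid"

    plain = Key("kF32")
    rr = Key("kF32", element_rin="kS32", element_rout="kS32")
    assert plain != rr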
@@ -1,4 +1,5 @@
 #include "src/cuda/region_restricted_convolution/opr_impl.h"
+#include "src/cuda/cutlass/singleton.h"
 #include "src/cuda/region_restricted_convolution/chanwise/depthwise_large_filter.cuh"
 #include "src/cuda/region_restricted_convolution/chanwise/kern.cuh"
 #include "src/cuda/utils.h"
@@ -6,6 +7,7 @@
 using namespace megdnn;
 using namespace cuda;
 using namespace region_restricted_convolution;
+using namespace cutlass::library;

 /* ============== RegionRestrictedConvolutionForwardImpl ============== */

 void RegionRestrictedConvolutionForwardImpl::exec(
@@ -113,7 +115,137 @@ size_t RegionRestrictedConvolutionBackwardFilterImpl::get_workspace_in_bytes(
 void RegionRestrictedConvolutionBackwardFilterImpl::exec(
         _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_in rin,
         _megdnn_tensor_in rout, _megdnn_tensor_out grad, _megdnn_workspace workspace) {
-    megdnn_throw("Region Restricted Conv BackwardFilter unimplemented");
+    auto fm = check_exec(
+            src.layout, diff.layout, rin.layout, rout.layout, grad.layout,
+            workspace.size);
+    megdnn_assert(
+            fm.group > 1 && src.layout.dtype.category() == DTypeCategory::FLOAT &&
+            param().compute_mode == Param::ComputeMode::DEFAULT &&
+            fm.spatial_ndim == 2 && fm.icpg == 1 && fm.ocpg == 1 &&
+            fm.dilation[0] == 1 && fm.dilation[1] == 1 && !fm.should_flip &&
+            param().stride_h == 1 && param().stride_w == 1);
+
+    int hi = src.layout.operator[](2), wi = src.layout.operator[](3);
+    int n = diff.layout.operator[](0), ho = diff.layout.operator[](2),
+        wo = diff.layout.operator[](3);
+    int co = fm.group, ci = co, groups = co;
+    int fh = fm.spatial[0], fw = fm.spatial[1];
+    int sh = fm.stride[0], sw = fm.stride[1];
+    int ph = fm.padding[0], pw = fm.padding[1];
+    int dh = 0, dw = 0;
+
+    // check if channelwise convolution
+    megdnn_assert(fm.icpg == 1 && fm.ocpg == 1);
+    auto stream = cuda_stream(handle());
+
+    float alpha = 1.f;
+    float beta = 0.f;
+
+    ConvolutionKey key;
+
+    int threadblock_shape_n = 128;
+    int warp_shape_m = 32;
+    int warp_shape_n = 64;
+    if (grad.layout.operator[](3) % 8 < 4) {
+        threadblock_shape_n = 64;
+        warp_shape_m = 64;
+        warp_shape_n = 32;
+    }
+
+    if (rin.layout.dtype == dtype::Int32() && rout.layout.dtype == dtype::Int32()) {
+        key = {
+                cutlass::conv::Operator::kWgrad,
+                NumericTypeID::kF32,
+                LayoutTypeID::kTensorNCHW,
+                NumericTypeID::kF32,
+                LayoutTypeID::kTensorNCHW,
+                NumericTypeID::kF32,
+                LayoutTypeID::kTensorNCHW,
+                NumericTypeID::kF32,
+                LayoutTypeID::kTensorNCHW,
+                NumericTypeID::kF32,
+                cutlass::conv::ConvType::kDepthwiseConvolution,
+                128,
+                threadblock_shape_n,
+                8,
+                warp_shape_m,
+                warp_shape_n,
+                8,
+                1,
+                1,
+                1,
+                cutlass::epilogue::EpilogueType::kLinearCombination,
+                1,
+                cutlass::conv::SpecialOptimizeDesc::NONE,
+                1,
+                1,
+                false,
+                NumericTypeID::kS32,
+                LayoutTypeID::kTensorNCHW,
+                NumericTypeID::kS32,
+                LayoutTypeID::kTensorNCHW,
+        };
+    } else if (
+            rin.layout.dtype == dtype::Uint8() && rout.layout.dtype == dtype::Uint8()) {
+        key = {
+                cutlass::conv::Operator::kWgrad,
+                NumericTypeID::kF32,
+                LayoutTypeID::kTensorNCHW,
+                NumericTypeID::kF32,
+                LayoutTypeID::kTensorNCHW,
+                NumericTypeID::kF32,
+                LayoutTypeID::kTensorNCHW,
+                NumericTypeID::kF32,
+                LayoutTypeID::kTensorNCHW,
+                NumericTypeID::kF32,
+                cutlass::conv::ConvType::kDepthwiseConvolution,
+                128,
+                threadblock_shape_n,
+                8,
+                warp_shape_m,
+                warp_shape_n,
+                8,
+                1,
+                1,
+                1,
+                cutlass::epilogue::EpilogueType::kLinearCombination,
+                1,
+                cutlass::conv::SpecialOptimizeDesc::NONE,
+                1,
+                1,
+                false,
+                NumericTypeID::kS8,
+                LayoutTypeID::kTensorNCHW,
+                NumericTypeID::kS8,
+                LayoutTypeID::kTensorNCHW,
+        };
+    } else {
+        megdnn_throw(ssprintf(
+                             "don't support region restricted type rin: %s, rout: %s",
+                             rin.layout.dtype.name(), rout.layout.dtype.name())
+                             .c_str());
+    }
+
+    const Operation* op =
+            (const Operation*)Singleton::get().operation_table.find_op(key);
+
+    cutlass::conv::Conv2dProblemSize problem_size{
+            n,  hi, wi, ci, co, fh, fw, ho,
+            wo, ph, pw, sh, sw, dh, dw, cutlass::conv::Mode::kCrossCorrelation,
+            1,       // split k slices, always 1
+            groups,  // groups
+    };
+
+    cutlass::library::ConvolutionArguments conv_args{
+            problem_size, src.raw_ptr(),  diff.raw_ptr(), nullptr,
+            nullptr,      grad.raw_ptr(), &alpha,         &beta,
+            nullptr,      nullptr,        nullptr,        nullptr,
+            nullptr,      nullptr,        rin.raw_ptr(),  rout.raw_ptr()};
+
+    cutlass_check(op->run(&conv_args, nullptr, stream));
+    after_kernel_launch();
 }

 // vim: syntax=cpp.doxygen
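The `exec` above selects between two generated tile shapes based on `grad.layout[3]` modulo 8 (index 3 is `fh` for the `{g, 1, 1, fh, fw}` grad layout). A Python sketch of just that heuristic (names are illustrative, not part of the C++ sources):

    def pick_wgrad_tile(grad_dim3):
        # mirrors RegionRestrictedConvolutionBackwardFilterImpl::exec:
        # when grad_dim3 % 8 < 4 use a 128x64x8 threadblock with 64x32x8 warps,
        # otherwise 128x128x8 with 32x64x8 warps
        if grad_dim3 % 8 < 4:
            return (128, 64, 8), (64, 32, 8)
        return (128, 128, 8), (32, 64, 8)

    assert pick_wgrad_tile(3) == ((128, 64, 8), (64, 32, 8))
    assert pick_wgrad_tile(31) == ((128, 128, 8), (32, 64, 8))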
@@ -465,6 +465,206 @@ TEST_F(CUDA, BENCHMARK_REGION_RESTRICTED_CONV_FORWARD_LARGE_FILTER_UINT8) {
     run_bench(64, 384, 32, 32, 31, 31, 1, 1, 10);
 }

+TEST_F(CUDA, BENCHMARK_REGION_RESTRICTED_CONV_BACKWARD_FILTER_FP32) {
+    require_compute_capability(7, 5);
+    Benchmarker<ConvolutionBackwardFilter> bencher(handle_cuda());
+    bencher.set_display(false);
+    bencher.set_before_exec_callback(AlgoChecker<ConvolutionBackwardFilter>(
+            "FLOAT32_NCHW_FMA_IMPLICIT_BATCHED_GEMM_128X128X8_32X64X8_2stage"));
+
+    Benchmarker<RegionRestrictedConvolutionBackwardFilter> rr_bencher(handle_cuda());
+    rr_bencher.set_display(false);
+
+    ConvolutionBackwardFilter::Param param;
+    param.format = ConvolutionBackwardFilter::Param::Format::NCHW;
+    param.sparse = ConvolutionBackwardFilter::Param::Sparse::GROUP;
+
+    RegionRestrictedConvolutionBackwardFilter::Param rr_param;
+    rr_param.format = RegionRestrictedConvolutionBackwardFilter::Param::Format::NCHW;
+    rr_param.sparse = RegionRestrictedConvolutionBackwardFilter::Param::Sparse::GROUP;
+
+    UniformIntRNG r_rng{1, 3};
+
+    auto run_bench = [&](size_t batch, size_t g, size_t hi, size_t wi, size_t fh,
+                         size_t fw, size_t sh, size_t sw, size_t nr_times) {
+        param.pad_h = fh / 2;
+        param.pad_w = fw / 2;
+        param.stride_h = sh;
+        param.stride_w = sw;
+        rr_param.pad_h = fh / 2;
+        rr_param.pad_w = fw / 2;
+        rr_param.stride_h = sh;
+        rr_param.stride_w = sw;
+
+        bencher.set_param(param)
+                .set_dtype(0, dtype::Float32())
+                .set_dtype(1, dtype::Float32())
+                .set_dtype(2, dtype::Float32())
+                .set_dtype(4, dtype::Float32());
+        bencher.proxy()->target_execution_policy = {};
+        bencher.set_times(nr_times);
+
+        rr_bencher.set_param(rr_param)
+                .set_dtype(0, dtype::Float32())
+                .set_dtype(1, dtype::Float32())
+                .set_dtype(2, dtype::Int32())
+                .set_dtype(3, dtype::Int32());
+        rr_bencher.set_rng(2, &r_rng).set_rng(3, &r_rng);
+        rr_bencher.set_times(nr_times);
+
+        size_t ho = infer_conv_shape(hi, fh, sh, param.pad_h);
+        size_t wo = infer_conv_shape(wi, fw, sw, param.pad_w);
+        TensorShape src{batch, g, hi, wi}, diff{batch, g, ho, wo}, rin{batch, hi, wi},
+                rout{batch, ho, wo}, grad{g, 1, 1, fh, fw};
+
+        float bandwidth = static_cast<float>(
+                                  src.total_nr_elems() + diff.total_nr_elems() +
+                                  grad.total_nr_elems()) /
+                          (1024 * 1024 * 1024) * 1e3;
+        float rr_bandwidth = static_cast<float>(
+                                     src.total_nr_elems() + diff.total_nr_elems() +
+                                     rin.total_nr_elems() + rout.total_nr_elems() +
+                                     grad.total_nr_elems()) /
+                             (1024 * 1024 * 1024) * 1e3;
+
+        auto time_in_ms = bencher.execs({src, diff, grad}) / nr_times;
+        auto ops = 2.0 * batch * g * hi * wi * fh * fw / (time_in_ms * 1e-3) * 1e-12;
+        auto rr_time_in_ms = rr_bencher.execs({src, diff, rin, rout, grad}) / nr_times;
+        auto rr_ops =
+                2.0 * batch * g * hi * wi * fh * fw / (rr_time_in_ms * 1e-3) * 1e-12;
+        printf("[WGRAD]RegionRestrictedDepthwiseLargeFilter vs DepthwiseLargeFilter: "
+               "src=%s, "
+               "diff=%s, grad=%s\n"
+               "time: %.2f ms, time(rr): %.2f ms, perf: %.2f Tops, perf(rr): %.2f Tops\n"
+               "bandwidth: %.2f GB/s, bandwidth(rr): %.2f GB/s, speedup: %.2f.\n",
+               src.to_string().c_str(), diff.to_string().c_str(),
+               grad.to_string().c_str(), time_in_ms, rr_time_in_ms, ops, rr_ops,
+               bandwidth * 4 / time_in_ms, rr_bandwidth * 4 / rr_time_in_ms,
+               time_in_ms / rr_time_in_ms);
+    };
+
+    run_bench(64, 384, 32, 32, 3, 3, 1, 1, 1000);
+    run_bench(64, 384, 32, 32, 5, 5, 1, 1, 1000);
+    run_bench(64, 384, 32, 32, 7, 7, 1, 1, 1000);
+    run_bench(64, 384, 32, 32, 9, 9, 1, 1, 1000);
+    run_bench(64, 384, 32, 32, 11, 11, 1, 1, 1000);
+    run_bench(64, 384, 32, 32, 13, 13, 1, 1, 1000);
+    run_bench(64, 384, 32, 32, 15, 15, 1, 1, 1000);
+    run_bench(64, 384, 32, 32, 17, 17, 1, 1, 1000);
+    run_bench(64, 384, 32, 32, 19, 19, 1, 1, 1000);
+    run_bench(64, 384, 32, 32, 21, 21, 1, 1, 1000);
+    run_bench(64, 384, 32, 32, 23, 23, 1, 1, 1000);
+    run_bench(64, 384, 32, 32, 25, 25, 1, 1, 1000);
+    run_bench(64, 384, 32, 32, 27, 27, 1, 1, 1000);
+    run_bench(64, 384, 32, 32, 29, 29, 1, 1, 1000);
+    run_bench(64, 384, 32, 32, 31, 31, 1, 1, 1000);
+}
+TEST_F(CUDA, BENCHMARK_REGION_RESTRICTED_CONV_BACKWARD_FILTER_FP32_RINT8) {
+    require_compute_capability(7, 5);
+    Benchmarker<ConvolutionBackwardFilter> bencher(handle_cuda());
+    bencher.set_display(false);
+    bencher.set_before_exec_callback(AlgoChecker<ConvolutionBackwardFilter>(
+            "FLOAT32_NCHW_FMA_IMPLICIT_BATCHED_GEMM_128X128X8_32X64X8_2stage"));
+
+    Benchmarker<RegionRestrictedConvolutionBackwardFilter> rr_bencher(handle_cuda());
+    rr_bencher.set_display(false);
+
+    ConvolutionBackwardFilter::Param param;
+    param.format = ConvolutionBackwardFilter::Param::Format::NCHW;
+    param.sparse = ConvolutionBackwardFilter::Param::Sparse::GROUP;
+
+    RegionRestrictedConvolutionBackwardFilter::Param rr_param;
+    rr_param.format = RegionRestrictedConvolutionBackwardFilter::Param::Format::NCHW;
+    rr_param.sparse = RegionRestrictedConvolutionBackwardFilter::Param::Sparse::GROUP;
+
+    UniformIntRNG r_rng{1, 3};
+
+    auto run_bench = [&](size_t batch, size_t g, size_t hi, size_t wi, size_t fh,
+                         size_t fw, size_t sh, size_t sw, size_t nr_times) {
+        param.pad_h = fh / 2;
+        param.pad_w = fw / 2;
+        param.stride_h = sh;
+        param.stride_w = sw;
+        rr_param.pad_h = fh / 2;
+        rr_param.pad_w = fw / 2;
+        rr_param.stride_h = sh;
+        rr_param.stride_w = sw;
+
+        bencher.set_param(param)
+                .set_dtype(0, dtype::Float32())
+                .set_dtype(1, dtype::Float32())
+                .set_dtype(2, dtype::Float32())
+                .set_dtype(4, dtype::Float32());
+        bencher.proxy()->target_execution_policy = {};
+        bencher.set_times(nr_times);
+
+        rr_bencher.set_param(rr_param)
+                .set_dtype(0, dtype::Float32())
+                .set_dtype(1, dtype::Float32())
+                .set_dtype(2, dtype::Uint8())
+                .set_dtype(3, dtype::Uint8());
+        rr_bencher.set_rng(2, &r_rng).set_rng(3, &r_rng);
+        rr_bencher.set_times(nr_times);
+
+        size_t ho = infer_conv_shape(hi, fh, sh, param.pad_h);
+        size_t wo = infer_conv_shape(wi, fw, sw, param.pad_w);
+        TensorShape src{batch, g, hi, wi}, diff{batch, g, ho, wo}, rin{batch, hi, wi},
+                rout{batch, ho, wo}, grad{g, 1, 1, fh, fw};
+
+        float bandwidth = static_cast<float>(
+                                  src.total_nr_elems() + diff.total_nr_elems() +
+                                  grad.total_nr_elems()) /
+                          (1024 * 1024 * 1024) * 1e3;
+        float rr_bandwidth = static_cast<float>(
+                                     src.total_nr_elems() + diff.total_nr_elems() +
+                                     rin.total_nr_elems() + rout.total_nr_elems() +
+                                     grad.total_nr_elems()) /
+                             (1024 * 1024 * 1024) * 1e3;
+
+        auto time_in_ms = bencher.execs({src, diff, grad}) / nr_times;
+        auto ops = 2.0 * batch * g * hi * wi * fh * fw / (time_in_ms * 1e-3) * 1e-12;
+        auto rr_time_in_ms = rr_bencher.execs({src, diff, rin, rout, grad}) / nr_times;
+        auto rr_ops =
+                2.0 * batch * g * hi * wi * fh * fw / (rr_time_in_ms * 1e-3) * 1e-12;
+        printf("[WGRAD]RegionRestrictedDepthwiseLargeFilter vs DepthwiseLargeFilter: "
+               "src=%s, "
+               "diff=%s, grad=%s\n"
+               "time: %.2f ms, time(rr): %.2f ms, perf: %.2f Tops, perf(rr): %.2f Tops\n"
+               "bandwidth: %.2f GB/s, bandwidth(rr): %.2f GB/s, speedup: %.2f.\n",
+               src.to_string().c_str(), diff.to_string().c_str(),
+               grad.to_string().c_str(), time_in_ms, rr_time_in_ms, ops, rr_ops,
+               bandwidth * 4 / time_in_ms, rr_bandwidth * 4 / rr_time_in_ms,
+               time_in_ms / rr_time_in_ms);
+    };
+
+    run_bench(64, 384, 32, 32, 3, 3, 1, 1, 1000);
+    run_bench(64, 384, 32, 32, 5, 5, 1, 1, 1000);
+    run_bench(64, 384, 32, 32, 7, 7, 1, 1, 1000);
+    run_bench(64, 384, 32, 32, 9, 9, 1, 1, 1000);
+    run_bench(64, 384, 32, 32, 11, 11, 1, 1, 1000);
+    run_bench(64, 384, 32, 32, 13, 13, 1, 1, 1000);
+    run_bench(64, 384, 32, 32, 15, 15, 1, 1, 1000);
+    run_bench(64, 384, 32, 32, 17, 17, 1, 1, 1000);
+    run_bench(64, 384, 32, 32, 19, 19, 1, 1, 1000);
+    run_bench(64, 384, 32, 32, 21, 21, 1, 1, 1000);
+    run_bench(64, 384, 32, 32, 23, 23, 1, 1, 1000);
+    run_bench(64, 384, 32, 32, 25, 25, 1, 1, 1000);
+    run_bench(64, 384, 32, 32, 27, 27, 1, 1, 1000);
+    run_bench(64, 384, 32, 32, 29, 29, 1, 1, 1000);
+    run_bench(64, 384, 32, 32, 31, 31, 1, 1, 1000);
+}
 #endif

 TEST_F(CUDA, REGION_RESTRICTED_CONV_BWD_DATA_FP32) {
@@ -585,6 +785,125 @@ TEST_F(CUDA, REGION_RESTRICTED_CONV_BWD_DATA_FP32_RIN_EQ_ROUT) {
     }
 }

+TEST_F(CUDA, REGION_RESTRICTED_CONV_BWD_FILTER_FP32) {
+    Checker<RegionRestrictedConvolutionBackwardFilter> checker(handle_cuda());
+    for (auto dt : std::vector<DType>{dtype::Int32(), dtype::Uint8()}) {
+        auto run = [&checker, &dt](
+                           size_t n, size_t g, size_t ih, size_t fh, size_t padding,
+                           size_t stride) {
+            RegionRestrictedConvolutionBackwardFilter::Param cur_param;
+            cur_param.mode = RegionRestrictedConvolutionBackwardFilter::Param::Mode::
+                    CROSS_CORRELATION;
+            cur_param.compute_mode = RegionRestrictedConvolutionBackwardFilter::Param::
+                    ComputeMode::DEFAULT;
+            cur_param.sparse =
+                    RegionRestrictedConvolutionBackwardFilter::Param::Sparse::GROUP;
+            checker.set_dtype(0, dtype::Float32())
+                    .set_dtype(1, dtype::Float32())
+                    .set_dtype(2, dt)
+                    .set_dtype(3, dt);
+            float scale = 64.f / sqrt(fh * fh);
+            UniformFloatRNG rng(scale, 2 * scale);
+            UniformIntRNG r_rng{1, 2};
+            checker.set_rng(0, &rng).set_rng(1, &rng).set_rng(2, &r_rng).set_rng(
+                    3, &r_rng);
+            cur_param.pad_h = cur_param.pad_w = padding;
+            cur_param.stride_h = cur_param.stride_w = stride;
+            size_t oh = (ih + 2 * padding - fh + 1) / stride;
+            checker.set_param(cur_param).execs({
+                    {n, g * 1, ih, ih},  // src
+                    {n, g * 1, oh, oh},  // diff
+                    {n, ih, ih},         // rin
+                    {n, oh, oh},         // rout
+                    {g, 1, 1, fh, fh}    // grad
+            });
+        };
+        if (dt == dtype::Int32()) {
+            run(4, 8, 32, 5, 5 / 2, 1);
+            run(1, 2, 2, 2, 0, 1);
+            run(1, 2, 3, 3, 0, 1);
+            run(1, 2, 4, 4, 0, 1);
+            run(1, 2, 5, 5, 0, 1);
+            run(1, 2, 6, 6, 0, 1);
+            run(1, 2, 7, 7, 0, 1);
+        }
+        run(4, 8, 32, 7, 7 / 2, 1);
+        run(4, 8, 32, 9, 9 / 2, 1);
+        run(4, 8, 32, 11, 11 / 2, 1);
+        run(4, 8, 32, 13, 13 / 2, 1);
+        run(4, 8, 32, 15, 15 / 2, 1);
+        run(4, 8, 32, 17, 17 / 2, 1);
+        run(4, 8, 32, 19, 19 / 2, 1);
+        run(4, 8, 32, 21, 21 / 2, 1);
+        run(4, 8, 32, 23, 23 / 2, 1);
+        run(4, 8, 32, 25, 25 / 2, 1);
+        run(4, 8, 32, 27, 27 / 2, 1);
+        run(4, 8, 32, 29, 29 / 2, 1);
+        run(4, 8, 32, 31, 31 / 2, 1);
+    }
+}
+TEST_F(CUDA, REGION_RESTRICTED_CONV_BWD_FILTER_FP32_RIN_EQ_ROUT) {
+    Checker<RegionRestrictedConvolutionBackwardFilter> checker(handle_cuda());
+    for (auto dt : std::vector<DType>{dtype::Int32(), dtype::Uint8()}) {
+        auto run = [&checker, &dt](
+                           size_t n, size_t g, size_t ih, size_t fh, size_t padding,
+                           size_t stride) {
+            RegionRestrictedConvolutionBackwardFilter::Param cur_param;
+            cur_param.mode = RegionRestrictedConvolutionBackwardFilter::Param::Mode::
+                    CROSS_CORRELATION;
+            cur_param.compute_mode = RegionRestrictedConvolutionBackwardFilter::Param::
+                    ComputeMode::DEFAULT;
+            cur_param.sparse =
+                    RegionRestrictedConvolutionBackwardFilter::Param::Sparse::GROUP;
+            checker.set_dtype(0, dtype::Float32())
+                    .set_dtype(1, dtype::Float32())
+                    .set_dtype(2, dt)
+                    .set_dtype(3, dt);
+            float scale = 64.f / sqrt(fh * fh);
+            UniformFloatRNG rng(scale, 2 * scale);
+            UniformIntRNG r_rng{1, 1};
+            checker.set_rng(0, &rng).set_rng(1, &rng).set_rng(2, &r_rng).set_rng(
+                    3, &r_rng);
+            cur_param.pad_h = cur_param.pad_w = padding;
+            cur_param.stride_h = cur_param.stride_w = stride;
+            size_t oh = (ih + 2 * padding - fh + 1) / stride;
+            checker.set_param(cur_param).execs({
+                    {n, g * 1, ih, ih},  // src
+                    {n, g * 1, oh, oh},  // diff
+                    {n, ih, ih},         // rin
+                    {n, oh, oh},         // rout
+                    {g, 1, 1, fh, fh}    // grad
+            });
+        };
+        if (dt == dtype::Int32()) {
+            run(4, 8, 32, 5, 5 / 2, 1);
+            run(1, 2, 2, 2, 0, 1);
+            run(1, 2, 3, 3, 0, 1);
+            run(1, 2, 4, 4, 0, 1);
+            run(1, 2, 5, 5, 0, 1);
+            run(1, 2, 6, 6, 0, 1);
+            run(1, 2, 7, 7, 0, 1);
+        }
+        run(4, 8, 32, 7, 7 / 2, 1);
+        run(4, 8, 32, 9, 9 / 2, 1);
+        run(4, 8, 32, 11, 11 / 2, 1);
+        run(4, 8, 32, 13, 13 / 2, 1);
+        run(4, 8, 32, 15, 15 / 2, 1);
+        run(4, 8, 32, 17, 17 / 2, 1);
+        run(4, 8, 32, 19, 19 / 2, 1);
+        run(4, 8, 32, 21, 21 / 2, 1);
+        run(4, 8, 32, 23, 23 / 2, 1);
+        run(4, 8, 32, 25, 25 / 2, 1);
+        run(4, 8, 32, 27, 27 / 2, 1);
+        run(4, 8, 32, 29, 29 / 2, 1);
+        run(4, 8, 32, 31, 31 / 2, 1);
+    }
+}
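The `oh` used to build the checker shapes is the usual valid-convolution output size, `oh = (ih + 2 * padding - fh + 1) / stride`, as written in the lambdas above. For example:

    def out_size(ih, fh, padding, stride):
        # matches the oh computation in the checker lambdas above
        return (ih + 2 * padding - fh + 1) // stride

    assert out_size(32, 5, 5 // 2, 1) == 32  # run(4, 8, 32, 5, 5 / 2, 1)
    assert out_size(2, 2, 0, 1) == 1         # run(1, 2, 2, 2, 0, 1)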
 }  // namespace test
 }  // namespace megdnn