GitOrigin-RevId: 5684e5ea43
tags/v1.11.1
| @@ -8,7 +8,7 @@ | |||||
| import enum | import enum | ||||
| import os.path | import os.path | ||||
| import shutil | import shutil | ||||
| from typing import Tuple, List | |||||
| from typing import List, Tuple | |||||
| from library import * | from library import * | ||||
| @@ -5,14 +5,13 @@ | |||||
| # | # | ||||
| import enum | import enum | ||||
| import os.path | |||||
| import shutil | |||||
| import functools | import functools | ||||
| import operator | import operator | ||||
| import os.path | |||||
| import shutil | |||||
| from library import * | from library import * | ||||
| ################################################################################################### | ################################################################################################### | ||||
| # | # | ||||
| # Data structure modeling a GEMM operation | # Data structure modeling a GEMM operation | ||||
| @@ -1,11 +1,11 @@ | |||||
| from generator import ( | |||||
| GenerateGemmOperations, | |||||
| GenerateGemvOperations, | |||||
| from generator import ( # isort: skip; isort: skip | |||||
| GenerateConv2dOperations, | GenerateConv2dOperations, | ||||
| GenerateDeconvOperations, | GenerateDeconvOperations, | ||||
| GenerateDwconv2dFpropOperations, | |||||
| GenerateDwconv2dDgradOperations, | GenerateDwconv2dDgradOperations, | ||||
| GenerateDwconv2dFpropOperations, | |||||
| GenerateDwconv2dWgradOperations, | GenerateDwconv2dWgradOperations, | ||||
| GenerateGemmOperations, | |||||
| GenerateGemvOperations, | |||||
| ) | ) | ||||
| @@ -35,12 +35,14 @@ def write_op_list(f, gen_op, gen_type): | |||||
| if gen_op != "gemv": | if gen_op != "gemv": | ||||
| f.write(' "all_%s_%s_operations.cu",\n' % (gen_op, gen_type)) | f.write(' "all_%s_%s_operations.cu",\n' % (gen_op, gen_type)) | ||||
| # Write down a list of merged filenames | # Write down a list of merged filenames | ||||
| def write_merge_file_name(f, gen_op, gen_type, split_number): | def write_merge_file_name(f, gen_op, gen_type, split_number): | ||||
| for i in range(0, split_number): | for i in range(0, split_number): | ||||
| f.write(' "{}_{}_{}.cu",\n'.format(gen_op,gen_type,i)) | |||||
| f.write(' "{}_{}_{}.cu",\n'.format(gen_op, gen_type, i)) | |||||
| if gen_op != "gemv": | if gen_op != "gemv": | ||||
| f.write(' "all_{}_{}_operations.cu",\n'.format(gen_op,gen_type)) | |||||
| f.write(' "all_{}_{}_operations.cu",\n'.format(gen_op, gen_type)) | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| with open("list.bzl", "w") as f: | with open("list.bzl", "w") as f: | ||||
| @@ -4,12 +4,12 @@ | |||||
| # \brief Generates the CUTLASS Library's instances | # \brief Generates the CUTLASS Library's instances | ||||
| # | # | ||||
| import argparse | |||||
| import enum | import enum | ||||
| import os.path | import os.path | ||||
| import shutil | |||||
| import argparse | |||||
| import platform | import platform | ||||
| import string | import string | ||||
| from library import * | from library import * | ||||
| from manifest import * | from manifest import * | ||||
| @@ -899,9 +899,12 @@ def GenerateGemm_Simt(args): | |||||
| warpShapes.append([warp0, warp1]) | warpShapes.append([warp0, warp1]) | ||||
| # sgemm | # sgemm | ||||
| precisionType, precisionBits, threadblockMaxElements, threadblockTilesL0 = precisions[ | |||||
| "s" | |||||
| ] | |||||
| ( | |||||
| precisionType, | |||||
| precisionBits, | |||||
| threadblockMaxElements, | |||||
| threadblockTilesL0, | |||||
| ) = precisions["s"] | |||||
| layouts = [ | layouts = [ | ||||
| (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), # nn | (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), # nn | ||||
| @@ -1091,9 +1094,12 @@ def GenerateDwconv2d_Simt(args, conv_kind): | |||||
| warpShapes.append([warp0, warp1]) | warpShapes.append([warp0, warp1]) | ||||
| # sgemm | # sgemm | ||||
| precisionType, precisionBits, threadblockMaxElements, threadblockTilesL0 = precisions[ | |||||
| "s" | |||||
| ] | |||||
| ( | |||||
| precisionType, | |||||
| precisionBits, | |||||
| threadblockMaxElements, | |||||
| threadblockTilesL0, | |||||
| ) = precisions["s"] | |||||
| layouts = [(LayoutType.TensorNCHW, LayoutType.TensorNCHW)] | layouts = [(LayoutType.TensorNCHW, LayoutType.TensorNCHW)] | ||||
| @@ -1304,7 +1310,7 @@ def GenerateDwconv2d_TensorOp_884(args, conv_kind): | |||||
| for dst_type, dst_layout in zip(dst_types, dst_layouts): | for dst_type, dst_layout in zip(dst_types, dst_layouts): | ||||
| for alignment_src in alignment_constraints: | for alignment_src in alignment_constraints: | ||||
| if conv_kind == ConvKind.Wgrad: | if conv_kind == ConvKind.Wgrad: | ||||
| # skip io16xc16 | |||||
| # skip io16xc16 | |||||
| if math_inst.element_accumulator == DataType.f16: | if math_inst.element_accumulator == DataType.f16: | ||||
| continue | continue | ||||
| for alignment_diff in alignment_constraints: | for alignment_diff in alignment_constraints: | ||||
| @@ -1319,7 +1325,7 @@ def GenerateDwconv2d_TensorOp_884(args, conv_kind): | |||||
| min_cc, | min_cc, | ||||
| alignment_src, | alignment_src, | ||||
| alignment_diff, | alignment_diff, | ||||
| 32, # always f32 output | |||||
| 32, # always f32 output | |||||
| SpecialOptimizeDesc.NoneSpecialOpt, | SpecialOptimizeDesc.NoneSpecialOpt, | ||||
| ImplicitGemmMode.GemmNT, | ImplicitGemmMode.GemmNT, | ||||
| False, | False, | ||||
| @@ -1656,6 +1662,7 @@ def GenerateGemvOperations(args): | |||||
| ) | ) | ||||
| return GenerateGemv_Simt(args) | return GenerateGemv_Simt(args) | ||||
| ################################################################################ | ################################################################################ | ||||
| # parameters | # parameters | ||||
| # split_number - the concated file will be divided into split_number parts | # split_number - the concated file will be divided into split_number parts | ||||
| @@ -1668,10 +1675,21 @@ def GenerateGemvOperations(args): | |||||
| # epilogue - the epilogue in the file | # epilogue - the epilogue in the file | ||||
| # wrapper_path - wrapper path | # wrapper_path - wrapper path | ||||
| ################################################################################ | ################################################################################ | ||||
| def ConcatFile(split_number:int, file_path:str,operations:str,type:str,head:str,required_cuda_ver_major:str, required_cuda_ver_minor:str, epilogue:str, wrapper_path = None): | |||||
| def ConcatFile( | |||||
| split_number: int, | |||||
| file_path: str, | |||||
| operations: str, | |||||
| type: str, | |||||
| head: str, | |||||
| required_cuda_ver_major: str, | |||||
| required_cuda_ver_minor: str, | |||||
| epilogue: str, | |||||
| wrapper_path=None, | |||||
| ): | |||||
| import os | import os | ||||
| meragefiledir = file_path | meragefiledir = file_path | ||||
| filenames=os.listdir(meragefiledir) | |||||
| filenames = os.listdir(meragefiledir) | |||||
| # filter file | # filter file | ||||
| if "tensorop" in type: | if "tensorop" in type: | ||||
| sub_string_1 = "tensorop" | sub_string_1 = "tensorop" | ||||
| @@ -1679,197 +1697,183 @@ def ConcatFile(split_number:int, file_path:str,operations:str,type:str,head:str, | |||||
| else: | else: | ||||
| sub_string_1 = sub_string_2 = "simt" | sub_string_1 = sub_string_2 = "simt" | ||||
| if "dwconv2d_" in operations: | if "dwconv2d_" in operations: | ||||
| filtered_operations = operations[:2]+operations[9:] | |||||
| filtered_operations = operations[:2] + operations[9:] | |||||
| elif ("conv2d" in operations) or ("deconv" in operations): | elif ("conv2d" in operations) or ("deconv" in operations): | ||||
| filtered_operations = "cutlass" | filtered_operations = "cutlass" | ||||
| else: | else: | ||||
| filtered_operations = operations | filtered_operations = operations | ||||
| #get the file list number | |||||
| # get the file list number | |||||
| file_list = {} | file_list = {} | ||||
| file_list[operations + type] = 0 | file_list[operations + type] = 0 | ||||
| for filename in filenames: | for filename in filenames: | ||||
| if (filtered_operations in filename) and (sub_string_1 in filename) and (sub_string_2 in filename) and ("all_" not in filename): | |||||
| if ( | |||||
| (filtered_operations in filename) | |||||
| and (sub_string_1 in filename) | |||||
| and (sub_string_2 in filename) | |||||
| and ("all_" not in filename) | |||||
| ): | |||||
| file_list[operations + type] += 1 | file_list[operations + type] += 1 | ||||
| #concat file for linux | |||||
| # concat file for linux | |||||
| flag_1 = 0 | flag_1 = 0 | ||||
| flag_2 = 0 | flag_2 = 0 | ||||
| for filename in filenames: | for filename in filenames: | ||||
| if (filtered_operations in filename) and (sub_string_1 in filename) and (sub_string_2 in filename) and ("all_" not in filename): | |||||
| if ( | |||||
| (filtered_operations in filename) | |||||
| and (sub_string_1 in filename) | |||||
| and (sub_string_2 in filename) | |||||
| and ("all_" not in filename) | |||||
| ): | |||||
| flag_1 += 1 | flag_1 += 1 | ||||
| filepath=meragefiledir+'/'+filename | |||||
| if (flag_1 >= flag_2 * (file_list[operations + type]/split_number)) and (flag_1 <= (flag_2 + 1) * (file_list[operations + type]/split_number)): | |||||
| file =open(file_path + '/{}_{}_{}.cu'.format(operations,type, flag_2),'a') | |||||
| #write Template at the head | |||||
| filepath = meragefiledir + "/" + filename | |||||
| if (flag_1 >= flag_2 * (file_list[operations + type] / split_number)) and ( | |||||
| flag_1 <= (flag_2 + 1) * (file_list[operations + type] / split_number) | |||||
| ): | |||||
| file = open( | |||||
| file_path + "/{}_{}_{}.cu".format(operations, type, flag_2), "a" | |||||
| ) | |||||
| # write Template at the head | |||||
| if wrapper_path is None: | if wrapper_path is None: | ||||
| file.write( | file.write( | ||||
| SubstituteTemplate( | SubstituteTemplate( | ||||
| head, | head, | ||||
| { | { | ||||
| "required_cuda_ver_major": str( | |||||
| required_cuda_ver_major | |||||
| ), | |||||
| "required_cuda_ver_minor": str( | |||||
| required_cuda_ver_minor | |||||
| ), | |||||
| "required_cuda_ver_major": str(required_cuda_ver_major), | |||||
| "required_cuda_ver_minor": str(required_cuda_ver_minor), | |||||
| }, | }, | ||||
| ) | ) | ||||
| ) | ) | ||||
| else: | else: | ||||
| file.write( | file.write( | ||||
| SubstituteTemplate( | |||||
| head, | |||||
| { | |||||
| "wrapper_path": wrapper_path, | |||||
| "required_cuda_ver_major": str( | |||||
| required_cuda_ver_major | |||||
| ), | |||||
| "required_cuda_ver_minor": str( | |||||
| required_cuda_ver_minor | |||||
| ), | |||||
| }, | |||||
| ) | |||||
| SubstituteTemplate( | |||||
| head, | |||||
| { | |||||
| "wrapper_path": wrapper_path, | |||||
| "required_cuda_ver_major": str(required_cuda_ver_major), | |||||
| "required_cuda_ver_minor": str(required_cuda_ver_minor), | |||||
| }, | |||||
| ) | ) | ||||
| ) | |||||
| # concat all the remaining files | # concat all the remaining files | ||||
| if flag_2 == (split_number - 1): | if flag_2 == (split_number - 1): | ||||
| for line in open(filepath): | for line in open(filepath): | ||||
| file.writelines(line) | file.writelines(line) | ||||
| os.remove(filepath) | os.remove(filepath) | ||||
| file.write('\n') | |||||
| file.write("\n") | |||||
| file.write(epilogue) | file.write(epilogue) | ||||
| continue | continue | ||||
| for line in open(filepath): | for line in open(filepath): | ||||
| file.writelines(line) | file.writelines(line) | ||||
| os.remove(filepath) | os.remove(filepath) | ||||
| file.write('\n') | |||||
| file.write("\n") | |||||
| file.write(epilogue) | file.write(epilogue) | ||||
| else: | else: | ||||
| #write Template at the head | |||||
| # write Template at the head | |||||
| if wrapper_path is None: | if wrapper_path is None: | ||||
| file.write( | file.write( | ||||
| SubstituteTemplate( | SubstituteTemplate( | ||||
| head, | head, | ||||
| { | { | ||||
| "required_cuda_ver_major": str( | |||||
| required_cuda_ver_major | |||||
| ), | |||||
| "required_cuda_ver_minor": str( | |||||
| required_cuda_ver_minor | |||||
| ), | |||||
| "required_cuda_ver_major": str(required_cuda_ver_major), | |||||
| "required_cuda_ver_minor": str(required_cuda_ver_minor), | |||||
| }, | }, | ||||
| ) | ) | ||||
| ) | ) | ||||
| else: | else: | ||||
| file.write( | file.write( | ||||
| SubstituteTemplate( | |||||
| head, | |||||
| { | |||||
| "wrapper_path": wrapper_path, | |||||
| "required_cuda_ver_major": str( | |||||
| required_cuda_ver_major | |||||
| ), | |||||
| "required_cuda_ver_minor": str( | |||||
| required_cuda_ver_minor | |||||
| ), | |||||
| }, | |||||
| ) | |||||
| SubstituteTemplate( | |||||
| head, | |||||
| { | |||||
| "wrapper_path": wrapper_path, | |||||
| "required_cuda_ver_major": str(required_cuda_ver_major), | |||||
| "required_cuda_ver_minor": str(required_cuda_ver_minor), | |||||
| }, | |||||
| ) | ) | ||||
| ) | |||||
| for line in open(filepath): | for line in open(filepath): | ||||
| file.writelines(line) | file.writelines(line) | ||||
| os.remove(filepath) | os.remove(filepath) | ||||
| file.write('\n') | |||||
| file.write("\n") | |||||
| file.write(epilogue) | file.write(epilogue) | ||||
| file.close() | file.close() | ||||
| flag_2 += 1 | flag_2 += 1 | ||||
| #concat file for windows | |||||
| # concat file for windows | |||||
| elif filename[0].isdigit() and ("all_" not in filename): | elif filename[0].isdigit() and ("all_" not in filename): | ||||
| flag_1 += 1 | flag_1 += 1 | ||||
| filepath=meragefiledir+'/'+filename | |||||
| if (flag_1 >= flag_2 * (len(filenames)/split_number)) and (flag_1 <= (flag_2 + 1) * (len(filenames)/split_number)): | |||||
| file =open(file_path + '/{}_{}_{}.cu'.format(operations,type, flag_2),'a') | |||||
| #write Template at the head | |||||
| filepath = meragefiledir + "/" + filename | |||||
| if (flag_1 >= flag_2 * (len(filenames) / split_number)) and ( | |||||
| flag_1 <= (flag_2 + 1) * (len(filenames) / split_number) | |||||
| ): | |||||
| file = open( | |||||
| file_path + "/{}_{}_{}.cu".format(operations, type, flag_2), "a" | |||||
| ) | |||||
| # write Template at the head | |||||
| if wrapper_path is None: | if wrapper_path is None: | ||||
| file.write( | file.write( | ||||
| SubstituteTemplate( | SubstituteTemplate( | ||||
| head, | head, | ||||
| { | { | ||||
| "required_cuda_ver_major": str( | |||||
| required_cuda_ver_major | |||||
| ), | |||||
| "required_cuda_ver_minor": str( | |||||
| required_cuda_ver_minor | |||||
| ), | |||||
| "required_cuda_ver_major": str(required_cuda_ver_major), | |||||
| "required_cuda_ver_minor": str(required_cuda_ver_minor), | |||||
| }, | }, | ||||
| ) | ) | ||||
| ) | ) | ||||
| else: | else: | ||||
| file.write( | file.write( | ||||
| SubstituteTemplate( | |||||
| head, | |||||
| { | |||||
| "wrapper_path": wrapper_path, | |||||
| "required_cuda_ver_major": str( | |||||
| required_cuda_ver_major | |||||
| ), | |||||
| "required_cuda_ver_minor": str( | |||||
| required_cuda_ver_minor | |||||
| ), | |||||
| }, | |||||
| ) | |||||
| SubstituteTemplate( | |||||
| head, | |||||
| { | |||||
| "wrapper_path": wrapper_path, | |||||
| "required_cuda_ver_major": str(required_cuda_ver_major), | |||||
| "required_cuda_ver_minor": str(required_cuda_ver_minor), | |||||
| }, | |||||
| ) | ) | ||||
| ) | |||||
| # concat all the remaining files | # concat all the remaining files | ||||
| if flag_2 == (split_number - 1): | if flag_2 == (split_number - 1): | ||||
| for line in open(filepath): | for line in open(filepath): | ||||
| file.writelines(line) | file.writelines(line) | ||||
| os.remove(filepath) | os.remove(filepath) | ||||
| file.write('\n') | |||||
| file.write("\n") | |||||
| file.write(epilogue) | file.write(epilogue) | ||||
| continue | continue | ||||
| for line in open(filepath): | for line in open(filepath): | ||||
| file.writelines(line) | file.writelines(line) | ||||
| os.remove(filepath) | os.remove(filepath) | ||||
| file.write('\n') | |||||
| file.write("\n") | |||||
| file.write(epilogue) | file.write(epilogue) | ||||
| else: | else: | ||||
| #write Template at the head | |||||
| # write Template at the head | |||||
| if wrapper_path is None: | if wrapper_path is None: | ||||
| file.write( | file.write( | ||||
| SubstituteTemplate( | SubstituteTemplate( | ||||
| head, | head, | ||||
| { | { | ||||
| "required_cuda_ver_major": str( | |||||
| required_cuda_ver_major | |||||
| ), | |||||
| "required_cuda_ver_minor": str( | |||||
| required_cuda_ver_minor | |||||
| ), | |||||
| "required_cuda_ver_major": str(required_cuda_ver_major), | |||||
| "required_cuda_ver_minor": str(required_cuda_ver_minor), | |||||
| }, | }, | ||||
| ) | ) | ||||
| ) | ) | ||||
| else: | else: | ||||
| file.write( | file.write( | ||||
| SubstituteTemplate( | |||||
| head, | |||||
| { | |||||
| "wrapper_path": wrapper_path, | |||||
| "required_cuda_ver_major": str( | |||||
| required_cuda_ver_major | |||||
| ), | |||||
| "required_cuda_ver_minor": str( | |||||
| required_cuda_ver_minor | |||||
| ), | |||||
| }, | |||||
| ) | |||||
| SubstituteTemplate( | |||||
| head, | |||||
| { | |||||
| "wrapper_path": wrapper_path, | |||||
| "required_cuda_ver_major": str(required_cuda_ver_major), | |||||
| "required_cuda_ver_minor": str(required_cuda_ver_minor), | |||||
| }, | |||||
| ) | ) | ||||
| ) | |||||
| for line in open(filepath): | for line in open(filepath): | ||||
| file.writelines(line) | file.writelines(line) | ||||
| os.remove(filepath) | os.remove(filepath) | ||||
| file.write('\n') | |||||
| file.write("\n") | |||||
| file.write(epilogue) | file.write(epilogue) | ||||
| file.close() | file.close() | ||||
| flag_2 += 1 | flag_2 += 1 | ||||
| ################################################################################################### | ################################################################################################### | ||||
| ################################################################################################### | ################################################################################################### | ||||
| @@ -1940,39 +1944,97 @@ if __name__ == "__main__": | |||||
| args.output, operation, short_path | args.output, operation, short_path | ||||
| ) as emitter: | ) as emitter: | ||||
| emitter.emit() | emitter.emit() | ||||
| head = EmitConvSingleKernelWrapper(args.output, operations[0], short_path).header_template | |||||
| head = EmitConvSingleKernelWrapper( | |||||
| args.output, operations[0], short_path | |||||
| ).header_template | |||||
| required_cuda_ver_major = operations[0].required_cuda_ver_major | required_cuda_ver_major = operations[0].required_cuda_ver_major | ||||
| required_cuda_ver_minor = operations[0].required_cuda_ver_minor | required_cuda_ver_minor = operations[0].required_cuda_ver_minor | ||||
| epilogue = EmitConvSingleKernelWrapper(args.output, operations[0], short_path).epilogue_template | |||||
| epilogue = EmitConvSingleKernelWrapper( | |||||
| args.output, operations[0], short_path | |||||
| ).epilogue_template | |||||
| if "tensorop" in args.type: | if "tensorop" in args.type: | ||||
| ConcatFile(4, args.output,args.operations, args.type, head,required_cuda_ver_major, required_cuda_ver_minor, epilogue) | |||||
| ConcatFile( | |||||
| 4, | |||||
| args.output, | |||||
| args.operations, | |||||
| args.type, | |||||
| head, | |||||
| required_cuda_ver_major, | |||||
| required_cuda_ver_minor, | |||||
| epilogue, | |||||
| ) | |||||
| else: | else: | ||||
| ConcatFile(2, args.output,args.operations, args.type, head,required_cuda_ver_major, required_cuda_ver_minor, epilogue) | |||||
| ConcatFile( | |||||
| 2, | |||||
| args.output, | |||||
| args.operations, | |||||
| args.type, | |||||
| head, | |||||
| required_cuda_ver_major, | |||||
| required_cuda_ver_minor, | |||||
| epilogue, | |||||
| ) | |||||
| elif args.operations == "gemm": | elif args.operations == "gemm": | ||||
| for operation in operations: | for operation in operations: | ||||
| with EmitGemmSingleKernelWrapper( | with EmitGemmSingleKernelWrapper( | ||||
| args.output, operation, short_path | args.output, operation, short_path | ||||
| ) as emitter: | ) as emitter: | ||||
| emitter.emit() | emitter.emit() | ||||
| head = EmitGemmSingleKernelWrapper(args.output, operations[0], short_path).header_template | |||||
| head = EmitGemmSingleKernelWrapper( | |||||
| args.output, operations[0], short_path | |||||
| ).header_template | |||||
| required_cuda_ver_major = operations[0].required_cuda_ver_major | required_cuda_ver_major = operations[0].required_cuda_ver_major | ||||
| required_cuda_ver_minor = operations[0].required_cuda_ver_minor | required_cuda_ver_minor = operations[0].required_cuda_ver_minor | ||||
| epilogue = EmitGemmSingleKernelWrapper(args.output, operations[0], short_path).epilogue_template | |||||
| epilogue = EmitGemmSingleKernelWrapper( | |||||
| args.output, operations[0], short_path | |||||
| ).epilogue_template | |||||
| if args.type == "tensorop884": | if args.type == "tensorop884": | ||||
| ConcatFile(30, args.output, args.operations, args.type, head,required_cuda_ver_major, required_cuda_ver_minor, epilogue) | |||||
| ConcatFile( | |||||
| 30, | |||||
| args.output, | |||||
| args.operations, | |||||
| args.type, | |||||
| head, | |||||
| required_cuda_ver_major, | |||||
| required_cuda_ver_minor, | |||||
| epilogue, | |||||
| ) | |||||
| else: | else: | ||||
| ConcatFile(2, args.output, args.operations, args.type, head,required_cuda_ver_major, required_cuda_ver_minor, epilogue) | |||||
| ConcatFile( | |||||
| 2, | |||||
| args.output, | |||||
| args.operations, | |||||
| args.type, | |||||
| head, | |||||
| required_cuda_ver_major, | |||||
| required_cuda_ver_minor, | |||||
| epilogue, | |||||
| ) | |||||
| elif args.operations == "gemv": | elif args.operations == "gemv": | ||||
| for operation in operations: | for operation in operations: | ||||
| with EmitGemvSingleKernelWrapper( | with EmitGemvSingleKernelWrapper( | ||||
| args.output, operation, gemv_wrapper_path, short_path | args.output, operation, gemv_wrapper_path, short_path | ||||
| ) as emitter: | ) as emitter: | ||||
| emitter.emit() | emitter.emit() | ||||
| head = EmitGemvSingleKernelWrapper(args.output, operations[0], gemv_wrapper_path, short_path).header_template | |||||
| head = EmitGemvSingleKernelWrapper( | |||||
| args.output, operations[0], gemv_wrapper_path, short_path | |||||
| ).header_template | |||||
| required_cuda_ver_major = operations[0].required_cuda_ver_major | required_cuda_ver_major = operations[0].required_cuda_ver_major | ||||
| required_cuda_ver_minor = operations[0].required_cuda_ver_minor | required_cuda_ver_minor = operations[0].required_cuda_ver_minor | ||||
| epilogue = EmitGemvSingleKernelWrapper(args.output, operations[0], gemv_wrapper_path, short_path).epilogue_template | |||||
| ConcatFile(2, args.output,args.operations, args.type, head,required_cuda_ver_major, required_cuda_ver_minor, epilogue, wrapper_path = gemv_wrapper_path) | |||||
| epilogue = EmitGemvSingleKernelWrapper( | |||||
| args.output, operations[0], gemv_wrapper_path, short_path | |||||
| ).epilogue_template | |||||
| ConcatFile( | |||||
| 2, | |||||
| args.output, | |||||
| args.operations, | |||||
| args.type, | |||||
| head, | |||||
| required_cuda_ver_major, | |||||
| required_cuda_ver_minor, | |||||
| epilogue, | |||||
| wrapper_path=gemv_wrapper_path, | |||||
| ) | |||||
| if args.operations != "gemv": | if args.operations != "gemv": | ||||
| GenerateManifest(args, operations, args.output) | GenerateManifest(args, operations, args.output) | ||||
| @@ -4,11 +4,11 @@ | |||||
| # \brief Generates the CUTLASS Library's instances | # \brief Generates the CUTLASS Library's instances | ||||
| # | # | ||||
| import enum | |||||
| import re | import re | ||||
| ################################################################################################### | ################################################################################################### | ||||
| import enum | |||||
| # The following block implements enum.auto() for Python 3.5 variants that don't include it such | # The following block implements enum.auto() for Python 3.5 variants that don't include it such | ||||
| # as the default 3.5.2 on Ubuntu 16.04. | # as the default 3.5.2 on Ubuntu 16.04. | ||||
| @@ -8,9 +8,9 @@ import enum | |||||
| import os.path | import os.path | ||||
| import shutil | import shutil | ||||
| from library import * | |||||
| from gemm_operation import * | |||||
| from conv2d_operation import * | from conv2d_operation import * | ||||
| from gemm_operation import * | |||||
| from library import * | |||||
| ################################################################################################### | ################################################################################################### | ||||
| @@ -1,59 +1,67 @@ | |||||
| #!/usr/bin/env python3 | #!/usr/bin/env python3 | ||||
| # -*- coding: utf-8 -*- | # -*- coding: utf-8 -*- | ||||
| import os | |||||
| import argparse | import argparse | ||||
| import os | |||||
| from gen_elemwise_utils import DTYPES | from gen_elemwise_utils import DTYPES | ||||
| def main(): | def main(): | ||||
| parser = argparse.ArgumentParser( | parser = argparse.ArgumentParser( | ||||
| description='generate elemwise impl files', | |||||
| formatter_class=argparse.ArgumentDefaultsHelpFormatter) | |||||
| parser.add_argument('--type', type=str, choices=['cuda'], | |||||
| default='cuda', | |||||
| help='generate cuda cond take kernel file') | |||||
| parser.add_argument('output', help='output directory') | |||||
| description="generate elemwise impl files", | |||||
| formatter_class=argparse.ArgumentDefaultsHelpFormatter, | |||||
| ) | |||||
| parser.add_argument( | |||||
| "--type", | |||||
| type=str, | |||||
| choices=["cuda"], | |||||
| default="cuda", | |||||
| help="generate cuda cond take kernel file", | |||||
| ) | |||||
| parser.add_argument("output", help="output directory") | |||||
| args = parser.parse_args() | args = parser.parse_args() | ||||
| if not os.path.isdir(args.output): | if not os.path.isdir(args.output): | ||||
| os.makedirs(args.output) | os.makedirs(args.output) | ||||
| assert args.type =='cuda' | |||||
| cpp_ext = 'cu' | |||||
| assert args.type == "cuda" | |||||
| cpp_ext = "cu" | |||||
| for dtype in DTYPES.keys(): | for dtype in DTYPES.keys(): | ||||
| fname = '{}.{}'.format(dtype, cpp_ext) | |||||
| fname = "{}.{}".format(dtype, cpp_ext) | |||||
| fname = os.path.join(args.output, fname) | fname = os.path.join(args.output, fname) | ||||
| with open(fname, 'w') as fout: | |||||
| with open(fname, "w") as fout: | |||||
| w = lambda s: print(s, file=fout) | w = lambda s: print(s, file=fout) | ||||
| w('// generated by gen_cond_take_kern_impls.py') | |||||
| w("// generated by gen_cond_take_kern_impls.py") | |||||
| w('#include "../kern.inl"') | w('#include "../kern.inl"') | ||||
| w('') | |||||
| if dtype == 'dt_float16' or dtype == 'dt_bfloat16': | |||||
| w('#if !MEGDNN_DISABLE_FLOAT16') | |||||
| w('namespace megdnn {') | |||||
| w('namespace cuda {') | |||||
| w('namespace cond_take {') | |||||
| w('') | |||||
| w('inst_genidx(::megdnn::dtype::{})'.format(DTYPES[dtype][0])) | |||||
| w('#undef inst_genidx') | |||||
| w('') | |||||
| w('inst_copy(::megdnn::dtype::{})'.format(DTYPES[dtype][0])) | |||||
| w('#undef inst_copy') | |||||
| w('#undef inst_copy_') | |||||
| w('') | |||||
| w('} // cond_take') | |||||
| w('} // cuda') | |||||
| w('} // megdnn') | |||||
| if dtype == 'dt_float16' or dtype == 'dt_bfloat16': | |||||
| w('#endif') | |||||
| print('generated {}'.format(fname)) | |||||
| w("") | |||||
| if dtype == "dt_float16" or dtype == "dt_bfloat16": | |||||
| w("#if !MEGDNN_DISABLE_FLOAT16") | |||||
| w("namespace megdnn {") | |||||
| w("namespace cuda {") | |||||
| w("namespace cond_take {") | |||||
| w("") | |||||
| w("inst_genidx(::megdnn::dtype::{})".format(DTYPES[dtype][0])) | |||||
| w("#undef inst_genidx") | |||||
| w("") | |||||
| w("inst_copy(::megdnn::dtype::{})".format(DTYPES[dtype][0])) | |||||
| w("#undef inst_copy") | |||||
| w("#undef inst_copy_") | |||||
| w("") | |||||
| w("} // cond_take") | |||||
| w("} // cuda") | |||||
| w("} // megdnn") | |||||
| if dtype == "dt_float16" or dtype == "dt_bfloat16": | |||||
| w("#endif") | |||||
| print("generated {}".format(fname)) | |||||
| os.utime(args.output) | os.utime(args.output) | ||||
| if __name__ == '__main__': | |||||
| if __name__ == "__main__": | |||||
| main() | main() | ||||
| @@ -1,37 +1,47 @@ | |||||
| #!/usr/bin/env python3 | #!/usr/bin/env python3 | ||||
| # -*- coding: utf-8 -*- | # -*- coding: utf-8 -*- | ||||
| import os | |||||
| import argparse | import argparse | ||||
| import itertools | import itertools | ||||
| import os | |||||
| PREFIXES = { | |||||
| "dp4a": [ | |||||
| ("batch_conv_bias_int8_implicit_gemm_precomp_ncdiv4hw4", True), | |||||
| ("batch_conv_bias_int8_gemm_ncdiv4hw4", False), | |||||
| ("batch_conv_bias_int8_gemm_ncdiv4hw4_ldg_128", False), | |||||
| ] | |||||
| } | |||||
| PREFIXES = {"dp4a": [("batch_conv_bias_int8_implicit_gemm_precomp_ncdiv4hw4", True), ("batch_conv_bias_int8_gemm_ncdiv4hw4", False), ("batch_conv_bias_int8_gemm_ncdiv4hw4_ldg_128", False)]} | |||||
| ACTIVATIONS = {1: ("IDENTITY", "_id"), 2: ("RELU", "_relu"), 3: ("H_SWISH", "_hswish")} | |||||
| ACTIVATIONS = {1: ("IDENTITY", "_id"), | |||||
| 2: ("RELU", "_relu"), | |||||
| 3: ("H_SWISH", "_hswish")} | |||||
| BIASES = { | |||||
| 1: ("PerElementBiasVisitor", "_per_elem"), | |||||
| 2: ("PerChannelBiasVisitor", "_per_chan"), | |||||
| } | |||||
| BIASES = {1: ("PerElementBiasVisitor", "_per_elem"), | |||||
| 2: ("PerChannelBiasVisitor", "_per_chan")} | |||||
| SUFFIXES = {"dp4a": [""], "imma": [""]} | |||||
| SUFFIXES = {"dp4a": [""], | |||||
| "imma": [""]} | |||||
| def main(): | def main(): | ||||
| parser = argparse.ArgumentParser( | parser = argparse.ArgumentParser( | ||||
| description='generate cuda batch conv bias (dp4a/imma) kern impl files', | |||||
| formatter_class=argparse.ArgumentDefaultsHelpFormatter) | |||||
| parser.add_argument('--type', type=str, choices=['dp4a', | |||||
| 'imma'], | |||||
| default='dp4a', help='generate cuda conv bias kernel file') | |||||
| parser.add_argument('output', help='output directory') | |||||
| description="generate cuda batch conv bias (dp4a/imma) kern impl files", | |||||
| formatter_class=argparse.ArgumentDefaultsHelpFormatter, | |||||
| ) | |||||
| parser.add_argument( | |||||
| "--type", | |||||
| type=str, | |||||
| choices=["dp4a", "imma"], | |||||
| default="dp4a", | |||||
| help="generate cuda conv bias kernel file", | |||||
| ) | |||||
| parser.add_argument("output", help="output directory") | |||||
| args = parser.parse_args() | args = parser.parse_args() | ||||
| if not os.path.isdir(args.output): | if not os.path.isdir(args.output): | ||||
| os.makedirs(args.output) | os.makedirs(args.output) | ||||
| inst = ''' | |||||
| inst = """ | |||||
| template void megdnn::cuda::batch_conv_bias::do_PREFIXSUFFIX<BIAS, | template void megdnn::cuda::batch_conv_bias::do_PREFIXSUFFIX<BIAS, | ||||
| IConvEpilogue<Activation<megdnn::param_enumv::BatchConvBias::NonlineMode::ACTIVATION>>>( | IConvEpilogue<Activation<megdnn::param_enumv::BatchConvBias::NonlineMode::ACTIVATION>>>( | ||||
| const int8_t* d_src, | const int8_t* d_src, | ||||
| @@ -41,7 +51,7 @@ template void megdnn::cuda::batch_conv_bias::do_PREFIXSUFFIX<BIAS, | |||||
| const ConvParam& param, | const ConvParam& param, | ||||
| float alpha, | float alpha, | ||||
| float beta, | float beta, | ||||
| cudaStream_t stream);''' | |||||
| cudaStream_t stream);""" | |||||
| for prefix in PREFIXES[args.type]: | for prefix in PREFIXES[args.type]: | ||||
| for suffix in SUFFIXES[args.type]: | for suffix in SUFFIXES[args.type]: | ||||
| @@ -52,17 +62,23 @@ template void megdnn::cuda::batch_conv_bias::do_PREFIXSUFFIX<BIAS, | |||||
| fname = os.path.join(args.output, fname) | fname = os.path.join(args.output, fname) | ||||
| with open(fname, "w") as fout: | with open(fname, "w") as fout: | ||||
| w = lambda s: print(s, file=fout) | w = lambda s: print(s, file=fout) | ||||
| w('// generated by gen_batch_cuda_conv_bias_kern_impls.py') | |||||
| cur_inst = inst.replace("PREFIX", prefix[0]).replace("SUFFIX", suffix).replace("BIAS", bias[0]).replace("ACTIVATION", act[0]) | |||||
| w("// generated by gen_batch_cuda_conv_bias_kern_impls.py") | |||||
| cur_inst = ( | |||||
| inst.replace("PREFIX", prefix[0]) | |||||
| .replace("SUFFIX", suffix) | |||||
| .replace("BIAS", bias[0]) | |||||
| .replace("ACTIVATION", act[0]) | |||||
| ) | |||||
| if has_workspace: | if has_workspace: | ||||
| cur_inst = cur_inst.replace("WORKSPACE", "\nint* d_workspace, ") | cur_inst = cur_inst.replace("WORKSPACE", "\nint* d_workspace, ") | ||||
| else: | else: | ||||
| cur_inst = cur_inst.replace("WORKSPACE", "") | |||||
| cur_inst = cur_inst.replace("WORKSPACE", "") | |||||
| w('#include "../{}{}.cuinl"'.format(prefix[0], suffix)) | w('#include "../{}{}.cuinl"'.format(prefix[0], suffix)) | ||||
| w(cur_inst) | w(cur_inst) | ||||
| print('generated {}'.format(fname)) | |||||
| print("generated {}".format(fname)) | |||||
| os.utime(args.output) | os.utime(args.output) | ||||
| if __name__ == '__main__': | |||||
| if __name__ == "__main__": | |||||
| main() | main() | ||||
| @@ -1,39 +1,57 @@ | |||||
| #!/usr/bin/env python3 | #!/usr/bin/env python3 | ||||
| # -*- coding: utf-8 -*- | # -*- coding: utf-8 -*- | ||||
| import os | |||||
| import argparse | import argparse | ||||
| import itertools | import itertools | ||||
| import os | |||||
| PREFIXES = { | |||||
| "dp4a": "conv_bias_int8_implicit_gemm_cdiv4hwn4", | |||||
| "imma": "conv_bias_int8_implicit_gemm", | |||||
| } | |||||
| PREFIXES = {"dp4a": "conv_bias_int8_implicit_gemm_cdiv4hwn4", "imma": "conv_bias_int8_implicit_gemm"} | |||||
| ACTIVATIONS = {1: ("IDENTITY", "_id"), 2: ("RELU", "_relu"), 3: ("H_SWISH", "_hswish")} | |||||
| ACTIVATIONS = {1: ("IDENTITY", "_id"), | |||||
| 2: ("RELU", "_relu"), | |||||
| 3: ("H_SWISH", "_hswish")} | |||||
| BIASES = { | |||||
| 1: ("PerElementBiasVisitor", "_per_elem"), | |||||
| 2: ("PerChannelBiasVisitor", "_per_chan"), | |||||
| } | |||||
| BIASES = {1: ("PerElementBiasVisitor", "_per_elem"), | |||||
| 2: ("PerChannelBiasVisitor", "_per_chan")} | |||||
| SUFFIXES = { | |||||
| "dp4a": ["", "_ld_64bit", "_ld_64bit_unroll_width", "_unroll_width"], | |||||
| "imma": [ | |||||
| "_imma16x16x16_cdiv4hwn4", | |||||
| "_imma8x32x16_cdiv4hwn4", | |||||
| "_imma32x8x16_cdiv4hwn4", | |||||
| "_imma16x16x16_cdiv4hwn4_reorder_filter", | |||||
| "_imma8x32x16_cdiv4hwn4_reorder_filter", | |||||
| "_imma32x8x16_cdiv4hwn4_reorder_filter", | |||||
| "_imma16x16x16_cdiv4hwn4_unroll_width", | |||||
| "_imma8x32x16_cdiv4hwn4_unroll_width", | |||||
| "_imma32x8x16_cdiv4hwn4_unroll_width", | |||||
| ], | |||||
| } | |||||
| SUFFIXES = {"dp4a": ["", "_ld_64bit", "_ld_64bit_unroll_width", "_unroll_width"], | |||||
| "imma": ["_imma16x16x16_cdiv4hwn4", "_imma8x32x16_cdiv4hwn4", "_imma32x8x16_cdiv4hwn4", | |||||
| "_imma16x16x16_cdiv4hwn4_reorder_filter", "_imma8x32x16_cdiv4hwn4_reorder_filter", "_imma32x8x16_cdiv4hwn4_reorder_filter", | |||||
| "_imma16x16x16_cdiv4hwn4_unroll_width", "_imma8x32x16_cdiv4hwn4_unroll_width", "_imma32x8x16_cdiv4hwn4_unroll_width"]} | |||||
| def main(): | def main(): | ||||
| parser = argparse.ArgumentParser( | parser = argparse.ArgumentParser( | ||||
| description='generate cuda conv bias (dp4a/imma) kern impl files', | |||||
| formatter_class=argparse.ArgumentDefaultsHelpFormatter) | |||||
| parser.add_argument('--type', type=str, choices=['dp4a', | |||||
| 'imma'], | |||||
| default='dp4a', help='generate cuda conv bias kernel file') | |||||
| parser.add_argument('output', help='output directory') | |||||
| description="generate cuda conv bias (dp4a/imma) kern impl files", | |||||
| formatter_class=argparse.ArgumentDefaultsHelpFormatter, | |||||
| ) | |||||
| parser.add_argument( | |||||
| "--type", | |||||
| type=str, | |||||
| choices=["dp4a", "imma"], | |||||
| default="dp4a", | |||||
| help="generate cuda conv bias kernel file", | |||||
| ) | |||||
| parser.add_argument("output", help="output directory") | |||||
| args = parser.parse_args() | args = parser.parse_args() | ||||
| if not os.path.isdir(args.output): | if not os.path.isdir(args.output): | ||||
| os.makedirs(args.output) | os.makedirs(args.output) | ||||
| inst = ''' | |||||
| inst = """ | |||||
| template void megdnn::cuda::conv_bias_int8::do_PREFIXSUFFIX<BIAS, | template void megdnn::cuda::conv_bias_int8::do_PREFIXSUFFIX<BIAS, | ||||
| IConvEpilogue<Activation<megdnn::param_enumv::ConvBias::NonlineMode::ACTIVATION>>>( | IConvEpilogue<Activation<megdnn::param_enumv::ConvBias::NonlineMode::ACTIVATION>>>( | ||||
| const int8_t* d_src, | const int8_t* d_src, | ||||
| @@ -43,7 +61,7 @@ template void megdnn::cuda::conv_bias_int8::do_PREFIXSUFFIX<BIAS, | |||||
| const ConvParam& param, | const ConvParam& param, | ||||
| float alpha, | float alpha, | ||||
| float beta, | float beta, | ||||
| cudaStream_t stream);''' | |||||
| cudaStream_t stream);""" | |||||
| for suffix in SUFFIXES[args.type]: | for suffix in SUFFIXES[args.type]: | ||||
| for _, act in ACTIVATIONS.items(): | for _, act in ACTIVATIONS.items(): | ||||
| @@ -53,13 +71,19 @@ template void megdnn::cuda::conv_bias_int8::do_PREFIXSUFFIX<BIAS, | |||||
| fname = os.path.join(args.output, fname) | fname = os.path.join(args.output, fname) | ||||
| with open(fname, "w") as fout: | with open(fname, "w") as fout: | ||||
| w = lambda s: print(s, file=fout) | w = lambda s: print(s, file=fout) | ||||
| w('// generated by gen_cuda_conv_bias_kern_impls.py') | |||||
| cur_inst = inst.replace("PREFIX", prefix).replace("SUFFIX", suffix).replace("BIAS", bias[0]).replace("ACTIVATION", act[0]) | |||||
| w("// generated by gen_cuda_conv_bias_kern_impls.py") | |||||
| cur_inst = ( | |||||
| inst.replace("PREFIX", prefix) | |||||
| .replace("SUFFIX", suffix) | |||||
| .replace("BIAS", bias[0]) | |||||
| .replace("ACTIVATION", act[0]) | |||||
| ) | |||||
| w('#include "../{}{}.cuinl"'.format(prefix, suffix)) | w('#include "../{}{}.cuinl"'.format(prefix, suffix)) | ||||
| w(cur_inst) | w(cur_inst) | ||||
| print('generated {}'.format(fname)) | |||||
| print("generated {}".format(fname)) | |||||
| os.utime(args.output) | os.utime(args.output) | ||||
| if __name__ == '__main__': | |||||
| if __name__ == "__main__": | |||||
| main() | main() | ||||
| @@ -1,34 +1,39 @@ | |||||
| #!/usr/bin/env python3 | #!/usr/bin/env python3 | ||||
| # -*- coding: utf-8 -*- | # -*- coding: utf-8 -*- | ||||
| import os | |||||
| import argparse | import argparse | ||||
| import os | |||||
| from gen_elemwise_utils import ARITIES, MODES | from gen_elemwise_utils import ARITIES, MODES | ||||
| def main(): | def main(): | ||||
| parser = argparse.ArgumentParser( | parser = argparse.ArgumentParser( | ||||
| description='generate elemwise each mode', | |||||
| formatter_class=argparse.ArgumentDefaultsHelpFormatter) | |||||
| description="generate elemwise each mode", | |||||
| formatter_class=argparse.ArgumentDefaultsHelpFormatter, | |||||
| ) | |||||
| parser.add_argument('output', help='output directory') | |||||
| parser.add_argument("output", help="output directory") | |||||
| args = parser.parse_args() | args = parser.parse_args() | ||||
| with open(args.output, 'w') as fout: | |||||
| with open(args.output, "w") as fout: | |||||
| w = lambda s: print(s, file=fout) | w = lambda s: print(s, file=fout) | ||||
| w('// generated by gen_elemwise_each_mode.py') | |||||
| w("// generated by gen_elemwise_each_mode.py") | |||||
| keys = list(MODES.keys()) | keys = list(MODES.keys()) | ||||
| keys.sort() | keys.sort() | ||||
| for (anum, ctype) in keys: | for (anum, ctype) in keys: | ||||
| w('#define MEGDNN_FOREACH_ELEMWISE_MODE_{}_{}(cb) \\'.format( | |||||
| ARITIES[anum], ctype)) | |||||
| w( | |||||
| "#define MEGDNN_FOREACH_ELEMWISE_MODE_{}_{}(cb) \\".format( | |||||
| ARITIES[anum], ctype | |||||
| ) | |||||
| ) | |||||
| for mode in MODES[(anum, ctype)]: | for mode in MODES[(anum, ctype)]: | ||||
| w(' MEGDNN_ELEMWISE_MODE_ENABLE({}, cb) \\'.format(mode)) | |||||
| w('') | |||||
| w(" MEGDNN_ELEMWISE_MODE_ENABLE({}, cb) \\".format(mode)) | |||||
| w("") | |||||
| print('generated each_mode.inl') | |||||
| print("generated each_mode.inl") | |||||
| os.utime(args.output) | os.utime(args.output) | ||||
| if __name__ == '__main__': | |||||
| if __name__ == "__main__": | |||||
| main() | main() | ||||
| @@ -1,56 +1,63 @@ | |||||
| #!/usr/bin/env python3 | #!/usr/bin/env python3 | ||||
| # -*- coding: utf-8 -*- | # -*- coding: utf-8 -*- | ||||
| import os | |||||
| import argparse | import argparse | ||||
| import itertools | import itertools | ||||
| import os | |||||
| from gen_elemwise_utils import ARITIES, DTYPES, MODES | from gen_elemwise_utils import ARITIES, DTYPES, MODES | ||||
| def main(): | def main(): | ||||
| parser = argparse.ArgumentParser( | parser = argparse.ArgumentParser( | ||||
| description='generate elemwise impl files', | |||||
| formatter_class=argparse.ArgumentDefaultsHelpFormatter) | |||||
| parser.add_argument('--type', type=str, choices=['cuda', | |||||
| 'hip', | |||||
| 'cpp'], | |||||
| default='cpp', help='generate cuda/hip kernel file') | |||||
| parser.add_argument('output', help='output directory') | |||||
| description="generate elemwise impl files", | |||||
| formatter_class=argparse.ArgumentDefaultsHelpFormatter, | |||||
| ) | |||||
| parser.add_argument( | |||||
| "--type", | |||||
| type=str, | |||||
| choices=["cuda", "hip", "cpp"], | |||||
| default="cpp", | |||||
| help="generate cuda/hip kernel file", | |||||
| ) | |||||
| parser.add_argument("output", help="output directory") | |||||
| args = parser.parse_args() | args = parser.parse_args() | ||||
| if not os.path.isdir(args.output): | if not os.path.isdir(args.output): | ||||
| os.makedirs(args.output) | os.makedirs(args.output) | ||||
| if args.type == 'cuda': | |||||
| cpp_ext = 'cu' | |||||
| elif args.type == 'hip': | |||||
| cpp_ext = 'cpp.hip' | |||||
| if args.type == "cuda": | |||||
| cpp_ext = "cu" | |||||
| elif args.type == "hip": | |||||
| cpp_ext = "cpp.hip" | |||||
| else: | else: | ||||
| assert args.type == 'cpp' | |||||
| cpp_ext = 'cpp' | |||||
| assert args.type == "cpp" | |||||
| cpp_ext = "cpp" | |||||
| for anum, ctype in itertools.product(ARITIES.keys(), DTYPES.keys()): | for anum, ctype in itertools.product(ARITIES.keys(), DTYPES.keys()): | ||||
| for mode in MODES[(anum, DTYPES[ctype][1])]: | for mode in MODES[(anum, DTYPES[ctype][1])]: | ||||
| formode = 'MEGDNN_ELEMWISE_MODE_ENABLE({}, cb)'.format(mode) | |||||
| fname = '{}_{}.{}'.format(mode, ctype, cpp_ext) | |||||
| formode = "MEGDNN_ELEMWISE_MODE_ENABLE({}, cb)".format(mode) | |||||
| fname = "{}_{}.{}".format(mode, ctype, cpp_ext) | |||||
| fname = os.path.join(args.output, fname) | fname = os.path.join(args.output, fname) | ||||
| with open(fname, 'w') as fout: | |||||
| with open(fname, "w") as fout: | |||||
| w = lambda s: print(s, file=fout) | w = lambda s: print(s, file=fout) | ||||
| w('// generated by gen_elemwise_kern_impls.py') | |||||
| w("// generated by gen_elemwise_kern_impls.py") | |||||
| if ctype == 'dt_float16' or ctype == 'dt_bfloat16': | |||||
| w('#if !MEGDNN_DISABLE_FLOAT16') | |||||
| if ctype == "dt_float16" or ctype == "dt_bfloat16": | |||||
| w("#if !MEGDNN_DISABLE_FLOAT16") | |||||
| w('#define KERN_IMPL_MODE(cb) {}'.format(formode)) | |||||
| w('#define KERN_IMPL_ARITY {}'.format(anum)) | |||||
| w('#define KERN_IMPL_CTYPE {}'.format(ctype)) | |||||
| w("#define KERN_IMPL_MODE(cb) {}".format(formode)) | |||||
| w("#define KERN_IMPL_ARITY {}".format(anum)) | |||||
| w("#define KERN_IMPL_CTYPE {}".format(ctype)) | |||||
| w('#include "../kern_impl.inl"') | w('#include "../kern_impl.inl"') | ||||
| if ctype == 'dt_float16' or ctype == 'dt_bfloat16': | |||||
| w('#endif') | |||||
| if ctype == "dt_float16" or ctype == "dt_bfloat16": | |||||
| w("#endif") | |||||
| print('generated {}'.format(fname)) | |||||
| print("generated {}".format(fname)) | |||||
| os.utime(args.output) | os.utime(args.output) | ||||
| if __name__ == '__main__': | |||||
| if __name__ == "__main__": | |||||
| main() | main() | ||||
| @@ -1,52 +1,66 @@ | |||||
| #!/usr/bin/env python3 | #!/usr/bin/env python3 | ||||
| # -*- coding: utf-8 -*- | # -*- coding: utf-8 -*- | ||||
| import os | |||||
| import argparse | import argparse | ||||
| import itertools | import itertools | ||||
| from gen_elemwise_multi_type_utils import SUPPORT_DTYPES, MODES, SUPPORT_QINT32_DTYPES, QINT32_MODES | |||||
| import os | |||||
| from gen_elemwise_multi_type_utils import ( # isort: skip; isort: skip | |||||
| MODES, | |||||
| QINT32_MODES, | |||||
| SUPPORT_DTYPES, | |||||
| SUPPORT_QINT32_DTYPES, | |||||
| ) | |||||
| def generate(modes, support_dtypes, output, cpp_ext): | def generate(modes, support_dtypes, output, cpp_ext): | ||||
| for anum, ctype in itertools.product(modes.keys(), support_dtypes): | for anum, ctype in itertools.product(modes.keys(), support_dtypes): | ||||
| print('{} : {}'.format(anum, ctype)) | |||||
| print("{} : {}".format(anum, ctype)) | |||||
| src_ctype = ctype[0] | src_ctype = ctype[0] | ||||
| dst_ctype = ctype[1] | dst_ctype = ctype[1] | ||||
| for mode in modes[anum]: | for mode in modes[anum]: | ||||
| formode = 'MEGDNN_ELEMWISE_MODE_ENABLE({}, cb)'.format(mode) | |||||
| fname = '{}_{}_{}.{}'.format(mode, src_ctype, dst_ctype, cpp_ext) | |||||
| formode = "MEGDNN_ELEMWISE_MODE_ENABLE({}, cb)".format(mode) | |||||
| fname = "{}_{}_{}.{}".format(mode, src_ctype, dst_ctype, cpp_ext) | |||||
| fname = os.path.join(output, fname) | fname = os.path.join(output, fname) | ||||
| with open(fname, 'w') as fout: | |||||
| with open(fname, "w") as fout: | |||||
| w = lambda s: print(s, file=fout) | w = lambda s: print(s, file=fout) | ||||
| w('// generated by gen_elemwise_multi_type_kern_impls.py') | |||||
| w("// generated by gen_elemwise_multi_type_kern_impls.py") | |||||
| w('#define KERN_IMPL_MODE(cb) {}'.format(formode)) | |||||
| w('#define KERN_IMPL_ARITY {}'.format(anum)) | |||||
| w('#define KERN_IMPL_STYPE {}'.format(src_ctype)) | |||||
| w('#define KERN_IMPL_DTYPE {}'.format(dst_ctype)) | |||||
| w("#define KERN_IMPL_MODE(cb) {}".format(formode)) | |||||
| w("#define KERN_IMPL_ARITY {}".format(anum)) | |||||
| w("#define KERN_IMPL_STYPE {}".format(src_ctype)) | |||||
| w("#define KERN_IMPL_DTYPE {}".format(dst_ctype)) | |||||
| w('#include "../kern_impl.inl"') | w('#include "../kern_impl.inl"') | ||||
| print('generated {}'.format(fname)) | |||||
| print("generated {}".format(fname)) | |||||
| def main(): | def main(): | ||||
| parser = argparse.ArgumentParser( | parser = argparse.ArgumentParser( | ||||
| description='generate elemwise impl files', | |||||
| formatter_class=argparse.ArgumentDefaultsHelpFormatter) | |||||
| parser.add_argument('--type', type=str, choices=['cuda'], | |||||
| default='cuda', help='generate cuda kernel file') | |||||
| parser.add_argument('output', help='output directory') | |||||
| description="generate elemwise impl files", | |||||
| formatter_class=argparse.ArgumentDefaultsHelpFormatter, | |||||
| ) | |||||
| parser.add_argument( | |||||
| "--type", | |||||
| type=str, | |||||
| choices=["cuda"], | |||||
| default="cuda", | |||||
| help="generate cuda kernel file", | |||||
| ) | |||||
| parser.add_argument("output", help="output directory") | |||||
| args = parser.parse_args() | args = parser.parse_args() | ||||
| if not os.path.isdir(args.output): | if not os.path.isdir(args.output): | ||||
| os.makedirs(args.output) | os.makedirs(args.output) | ||||
| assert args.type == 'cuda' | |||||
| if args.type == 'cuda': | |||||
| cpp_ext = 'cu' | |||||
| assert args.type == "cuda" | |||||
| if args.type == "cuda": | |||||
| cpp_ext = "cu" | |||||
| generate(MODES, SUPPORT_DTYPES, args.output, cpp_ext) | generate(MODES, SUPPORT_DTYPES, args.output, cpp_ext) | ||||
| generate(QINT32_MODES, SUPPORT_QINT32_DTYPES, args.output, cpp_ext) | generate(QINT32_MODES, SUPPORT_QINT32_DTYPES, args.output, cpp_ext) | ||||
| os.utime(args.output) | os.utime(args.output) | ||||
| if __name__ == '__main__': | |||||
| if __name__ == "__main__": | |||||
| main() | main() | ||||
| @@ -1,48 +1,131 @@ | |||||
| # As cuda currently do not support quint8, so we just ignore it. | # As cuda currently do not support quint8, so we just ignore it. | ||||
| SUPPORT_DTYPES = [('dt_qint8', 'dt_qint8')] | |||||
| SUPPORT_QINT32_DTYPES = [('dt_qint32', 'dt_qint8'), ('dt_qint8', 'dt_qint32'), | |||||
| ('dt_qint4', 'dt_qint32'), ('dt_quint4', 'dt_qint32')] | |||||
| SUPPORT_DTYPES = [("dt_qint8", "dt_qint8")] | |||||
| SUPPORT_QINT32_DTYPES = [ | |||||
| ("dt_qint32", "dt_qint8"), | |||||
| ("dt_qint8", "dt_qint32"), | |||||
| ("dt_qint4", "dt_qint32"), | |||||
| ("dt_quint4", "dt_qint32"), | |||||
| ] | |||||
| SUPPORT_DTYPES_Q4 = [('dt_qint4', 'dt_qint4'), ('dt_quint4', 'dt_quint4')] | |||||
| SUPPORT_QINT32_DTYPES_Q4 = [('dt_qint32', 'dt_qint4'), ('dt_qint32', 'dt_quint4')] | |||||
| SUPPORT_DTYPES_Q4 = [("dt_qint4", "dt_qint4"), ("dt_quint4", "dt_quint4")] | |||||
| SUPPORT_QINT32_DTYPES_Q4 = [("dt_qint32", "dt_qint4"), ("dt_qint32", "dt_quint4")] | |||||
| SUPPORT_ARRITY2_DTYPES = ['dt_int32', 'dt_uint8', 'dt_int8', 'dt_int16', 'dt_bool', 'dt_float32', | |||||
| 'dt_float16', 'dt_bfloat16'] | |||||
| SUPPORT_ARRITY1_DTYPES = ['dt_float32','dt_float16', 'dt_bfloat16'] | |||||
| SUPPORT_ARRITY2_DTYPES = [ | |||||
| "dt_int32", | |||||
| "dt_uint8", | |||||
| "dt_int8", | |||||
| "dt_int16", | |||||
| "dt_bool", | |||||
| "dt_float32", | |||||
| "dt_float16", | |||||
| "dt_bfloat16", | |||||
| ] | |||||
| SUPPORT_ARRITY1_DTYPES = ["dt_float32", "dt_float16", "dt_bfloat16"] | |||||
| MODES = { | MODES = { | ||||
| 1: ['RELU', 'ABS', 'NEGATE', 'ACOS', 'ASIN', 'CEIL', 'COS', | |||||
| 'EXP', 'EXPM1', 'FLOOR', 'LOG', 'LOG1P', 'SIGMOID', 'SIN', | |||||
| 'TANH', 'FAST_TANH', 'ROUND', 'ERF', 'ERFINV', 'ERFC', | |||||
| 'ERFCINV', 'H_SWISH', 'SILU', 'GELU'], | |||||
| 2: ['ABS_GRAD', 'ADD', 'FLOOR_DIV', 'MAX', 'MIN', 'MOD', 'MUL', | |||||
| 'SIGMOID_GRAD', 'SUB', 'SWITCH_GT0', 'TANH_GRAD', 'LT', | |||||
| 'LEQ', 'EQ', 'FUSE_ADD_RELU', 'TRUE_DIV', 'POW', | |||||
| 'LOG_SUM_EXP', 'FUSE_ADD_TANH', 'FAST_TANH_GRAD', | |||||
| 'FUSE_ADD_SIGMOID', 'ATAN2', 'H_SWISH_GRAD', | |||||
| 'FUSE_ADD_H_SWISH', 'SILU_GRAD', 'GELU_GRAD'], | |||||
| 3: ['COND_LEQ_MOV', 'COND_LT_MOV', 'FUSE_MUL_ADD3'], | |||||
| 1: [ | |||||
| "RELU", | |||||
| "ABS", | |||||
| "NEGATE", | |||||
| "ACOS", | |||||
| "ASIN", | |||||
| "CEIL", | |||||
| "COS", | |||||
| "EXP", | |||||
| "EXPM1", | |||||
| "FLOOR", | |||||
| "LOG", | |||||
| "LOG1P", | |||||
| "SIGMOID", | |||||
| "SIN", | |||||
| "TANH", | |||||
| "FAST_TANH", | |||||
| "ROUND", | |||||
| "ERF", | |||||
| "ERFINV", | |||||
| "ERFC", | |||||
| "ERFCINV", | |||||
| "H_SWISH", | |||||
| "SILU", | |||||
| "GELU", | |||||
| ], | |||||
| 2: [ | |||||
| "ABS_GRAD", | |||||
| "ADD", | |||||
| "FLOOR_DIV", | |||||
| "MAX", | |||||
| "MIN", | |||||
| "MOD", | |||||
| "MUL", | |||||
| "SIGMOID_GRAD", | |||||
| "SUB", | |||||
| "SWITCH_GT0", | |||||
| "TANH_GRAD", | |||||
| "LT", | |||||
| "LEQ", | |||||
| "EQ", | |||||
| "FUSE_ADD_RELU", | |||||
| "TRUE_DIV", | |||||
| "POW", | |||||
| "LOG_SUM_EXP", | |||||
| "FUSE_ADD_TANH", | |||||
| "FAST_TANH_GRAD", | |||||
| "FUSE_ADD_SIGMOID", | |||||
| "ATAN2", | |||||
| "H_SWISH_GRAD", | |||||
| "FUSE_ADD_H_SWISH", | |||||
| "SILU_GRAD", | |||||
| "GELU_GRAD", | |||||
| ], | |||||
| 3: ["COND_LEQ_MOV", "COND_LT_MOV", "FUSE_MUL_ADD3"], | |||||
| } | } | ||||
| QINT4_MODES = { | QINT4_MODES = { | ||||
| 1: ['RELU', 'ABS', 'NEGATE', 'CEIL', 'FLOOR', 'SIGMOID', | |||||
| 'TANH', 'FAST_TANH', 'ROUND', 'H_SWISH'], | |||||
| 2: ['ADD', 'MAX', 'MIN', 'MUL', 'SUB', 'SWITCH_GT0', | |||||
| 'LT', 'LEQ', 'EQ', 'FUSE_ADD_RELU', 'FUSE_ADD_TANH', | |||||
| 'FUSE_ADD_SIGMOID', 'FUSE_ADD_H_SWISH'], | |||||
| 3: ['COND_LEQ_MOV', 'COND_LT_MOV', 'FUSE_MUL_ADD3'], | |||||
| 1: [ | |||||
| "RELU", | |||||
| "ABS", | |||||
| "NEGATE", | |||||
| "CEIL", | |||||
| "FLOOR", | |||||
| "SIGMOID", | |||||
| "TANH", | |||||
| "FAST_TANH", | |||||
| "ROUND", | |||||
| "H_SWISH", | |||||
| ], | |||||
| 2: [ | |||||
| "ADD", | |||||
| "MAX", | |||||
| "MIN", | |||||
| "MUL", | |||||
| "SUB", | |||||
| "SWITCH_GT0", | |||||
| "LT", | |||||
| "LEQ", | |||||
| "EQ", | |||||
| "FUSE_ADD_RELU", | |||||
| "FUSE_ADD_TANH", | |||||
| "FUSE_ADD_SIGMOID", | |||||
| "FUSE_ADD_H_SWISH", | |||||
| ], | |||||
| 3: ["COND_LEQ_MOV", "COND_LT_MOV", "FUSE_MUL_ADD3"], | |||||
| } | } | ||||
| QINT32_MODES = { | QINT32_MODES = { | ||||
| 1: ['RELU', 'SIGMOID', 'TANH', 'FAST_TANH', 'H_SWISH'], | |||||
| 2: ['ADD', 'FUSE_ADD_RELU', 'FUSE_ADD_SIGMOID', | |||||
| 'FUSE_ADD_TANH', 'FUSE_ADD_H_SWISH'] | |||||
| 1: ["RELU", "SIGMOID", "TANH", "FAST_TANH", "H_SWISH"], | |||||
| 2: [ | |||||
| "ADD", | |||||
| "FUSE_ADD_RELU", | |||||
| "FUSE_ADD_SIGMOID", | |||||
| "FUSE_ADD_TANH", | |||||
| "FUSE_ADD_H_SWISH", | |||||
| ], | |||||
| } | } | ||||
| ARRITY1_BOOL_MODES = { | ARRITY1_BOOL_MODES = { | ||||
| 1: ['ISINF','ISNAN'], | |||||
| 1: ["ISINF", "ISNAN"], | |||||
| } | } | ||||
| ARRITY2_BOOL_MODES = { | ARRITY2_BOOL_MODES = { | ||||
| 2: ['EQ','LEQ','NEQ','LT'], | |||||
| 2: ["EQ", "LEQ", "NEQ", "LT"], | |||||
| } | } | ||||
| @@ -1,52 +1,57 @@ | |||||
| #!/usr/bin/env python3 | #!/usr/bin/env python3 | ||||
| # -*- coding: utf-8 -*- | # -*- coding: utf-8 -*- | ||||
| import os | |||||
| import argparse | import argparse | ||||
| import os | |||||
| from gen_elemwise_utils import DTYPES | from gen_elemwise_utils import DTYPES | ||||
| def main(): | def main(): | ||||
| parser = argparse.ArgumentParser( | parser = argparse.ArgumentParser( | ||||
| description='generate elemwise impl files', | |||||
| formatter_class=argparse.ArgumentDefaultsHelpFormatter) | |||||
| parser.add_argument('--type', type=str, choices=[ | |||||
| 'cuda', | |||||
| 'hip' | |||||
| ], | |||||
| default='cuda', | |||||
| help='generate cuda/hip elemwise special kernel file') | |||||
| parser.add_argument('output', help='output directory') | |||||
| description="generate elemwise impl files", | |||||
| formatter_class=argparse.ArgumentDefaultsHelpFormatter, | |||||
| ) | |||||
| parser.add_argument( | |||||
| "--type", | |||||
| type=str, | |||||
| choices=["cuda", "hip"], | |||||
| default="cuda", | |||||
| help="generate cuda/hip elemwise special kernel file", | |||||
| ) | |||||
| parser.add_argument("output", help="output directory") | |||||
| args = parser.parse_args() | args = parser.parse_args() | ||||
| if not os.path.isdir(args.output): | if not os.path.isdir(args.output): | ||||
| os.makedirs(args.output) | os.makedirs(args.output) | ||||
| if args.type == 'cuda': | |||||
| cpp_ext = 'cu' | |||||
| if args.type == "cuda": | |||||
| cpp_ext = "cu" | |||||
| else: | else: | ||||
| assert args.type =='hip' | |||||
| cpp_ext = 'cpp.hip' | |||||
| assert args.type == "hip" | |||||
| cpp_ext = "cpp.hip" | |||||
| for dtype in DTYPES.keys(): | for dtype in DTYPES.keys(): | ||||
| fname = 'special_{}.{}'.format(dtype, cpp_ext) | |||||
| fname = "special_{}.{}".format(dtype, cpp_ext) | |||||
| fname = os.path.join(args.output, fname) | fname = os.path.join(args.output, fname) | ||||
| with open(fname, 'w') as fout: | |||||
| with open(fname, "w") as fout: | |||||
| w = lambda s: print(s, file=fout) | w = lambda s: print(s, file=fout) | ||||
| w('// generated by gen_elemwise_special_kern_impls.py') | |||||
| if dtype == 'dt_float16' or dtype == 'dt_bfloat16': | |||||
| w('#if !MEGDNN_DISABLE_FLOAT16') | |||||
| w("// generated by gen_elemwise_special_kern_impls.py") | |||||
| if dtype == "dt_float16" or dtype == "dt_bfloat16": | |||||
| w("#if !MEGDNN_DISABLE_FLOAT16") | |||||
| w('#include "../special_kerns.inl"') | w('#include "../special_kerns.inl"') | ||||
| w('INST(::megdnn::dtype::{})'.format(DTYPES[dtype][0])) | |||||
| w('#undef INST') | |||||
| w('}') | |||||
| w('}') | |||||
| if dtype == 'dt_float16' or dtype == 'dt_bfloat16': | |||||
| w('#endif') | |||||
| w("INST(::megdnn::dtype::{})".format(DTYPES[dtype][0])) | |||||
| w("#undef INST") | |||||
| w("}") | |||||
| w("}") | |||||
| if dtype == "dt_float16" or dtype == "dt_bfloat16": | |||||
| w("#endif") | |||||
| print('generated {}'.format(fname)) | |||||
| print("generated {}".format(fname)) | |||||
| os.utime(args.output) | os.utime(args.output) | ||||
| if __name__ == '__main__': | |||||
| if __name__ == "__main__": | |||||
| main() | main() | ||||
| @@ -1,35 +1,95 @@ | |||||
| ARITIES = {1: "UNARY", 2: "BINARY", 3: "TERNARY"} | |||||
| ARITIES = {1: 'UNARY', 2: 'BINARY', 3: 'TERNARY'} | |||||
| DTYPES = {'dt_int32': ('Int32', 'INT'), | |||||
| 'dt_uint8': ('Uint8', 'INT'), | |||||
| 'dt_int8': ('Int8', 'INT'), | |||||
| 'dt_int16': ('Int16', 'INT'), | |||||
| 'dt_bool': ('Bool', 'BOOL'), | |||||
| 'dt_float32': ('Float32', 'FLOAT'), | |||||
| 'dt_float16': ('Float16', 'FLOAT'), | |||||
| 'dt_bfloat16': ('BFloat16', 'FLOAT') | |||||
| } | |||||
| DTYPES = { | |||||
| "dt_int32": ("Int32", "INT"), | |||||
| "dt_uint8": ("Uint8", "INT"), | |||||
| "dt_int8": ("Int8", "INT"), | |||||
| "dt_int16": ("Int16", "INT"), | |||||
| "dt_bool": ("Bool", "BOOL"), | |||||
| "dt_float32": ("Float32", "FLOAT"), | |||||
| "dt_float16": ("Float16", "FLOAT"), | |||||
| "dt_bfloat16": ("BFloat16", "FLOAT"), | |||||
| } | |||||
| MODES = { | MODES = { | ||||
| (1, 'INT'): ['RELU', 'ABS', 'NEGATE'], | |||||
| (2, 'INT'): ['ABS_GRAD', 'ADD', 'FLOOR_DIV', 'MAX', 'MIN', 'MOD', 'MUL', | |||||
| 'SIGMOID_GRAD', 'SUB', 'SWITCH_GT0', 'TANH_GRAD', 'LT', 'LEQ', | |||||
| 'EQ', 'FUSE_ADD_RELU', 'SHL', 'SHR', 'RMULH'], | |||||
| (3, 'INT'): ['COND_LEQ_MOV', 'COND_LT_MOV'], | |||||
| (1, 'FLOAT'): ['RELU', 'ABS', 'NEGATE', 'ACOS', 'ASIN', 'CEIL', 'COS', | |||||
| 'EXP', 'EXPM1', 'FLOOR', 'LOG', 'LOG1P', 'SIGMOID', 'SIN', | |||||
| 'TANH', 'FAST_TANH', 'ROUND', 'ERF', 'ERFINV', 'ERFC', | |||||
| 'ERFCINV', 'H_SWISH', 'SILU', 'GELU'], | |||||
| (2, 'FLOAT'): ['ABS_GRAD', 'ADD', 'FLOOR_DIV', 'MAX', 'MIN', 'MOD', 'MUL', | |||||
| 'SIGMOID_GRAD', 'SUB', 'SWITCH_GT0', 'TANH_GRAD', 'LT', | |||||
| 'LEQ', 'EQ', 'FUSE_ADD_RELU', 'TRUE_DIV', 'POW', | |||||
| 'LOG_SUM_EXP', 'FUSE_ADD_TANH', 'FAST_TANH_GRAD', | |||||
| 'FUSE_ADD_SIGMOID', 'ATAN2', 'H_SWISH_GRAD', | |||||
| 'FUSE_ADD_H_SWISH', 'SILU_GRAD', 'GELU_GRAD'], | |||||
| (3, 'FLOAT'): ['COND_LEQ_MOV', 'COND_LT_MOV', 'FUSE_MUL_ADD3'], | |||||
| (1, 'BOOL'): ['NOT'], | |||||
| (2, 'BOOL'): ['AND', 'OR', 'XOR', 'LT', 'LEQ', 'EQ'], | |||||
| (3, 'BOOL'): [] | |||||
| (1, "INT"): ["RELU", "ABS", "NEGATE"], | |||||
| (2, "INT"): [ | |||||
| "ABS_GRAD", | |||||
| "ADD", | |||||
| "FLOOR_DIV", | |||||
| "MAX", | |||||
| "MIN", | |||||
| "MOD", | |||||
| "MUL", | |||||
| "SIGMOID_GRAD", | |||||
| "SUB", | |||||
| "SWITCH_GT0", | |||||
| "TANH_GRAD", | |||||
| "LT", | |||||
| "LEQ", | |||||
| "EQ", | |||||
| "FUSE_ADD_RELU", | |||||
| "SHL", | |||||
| "SHR", | |||||
| "RMULH", | |||||
| ], | |||||
| (3, "INT"): ["COND_LEQ_MOV", "COND_LT_MOV"], | |||||
| (1, "FLOAT"): [ | |||||
| "RELU", | |||||
| "ABS", | |||||
| "NEGATE", | |||||
| "ACOS", | |||||
| "ASIN", | |||||
| "CEIL", | |||||
| "COS", | |||||
| "EXP", | |||||
| "EXPM1", | |||||
| "FLOOR", | |||||
| "LOG", | |||||
| "LOG1P", | |||||
| "SIGMOID", | |||||
| "SIN", | |||||
| "TANH", | |||||
| "FAST_TANH", | |||||
| "ROUND", | |||||
| "ERF", | |||||
| "ERFINV", | |||||
| "ERFC", | |||||
| "ERFCINV", | |||||
| "H_SWISH", | |||||
| "SILU", | |||||
| "GELU", | |||||
| ], | |||||
| (2, "FLOAT"): [ | |||||
| "ABS_GRAD", | |||||
| "ADD", | |||||
| "FLOOR_DIV", | |||||
| "MAX", | |||||
| "MIN", | |||||
| "MOD", | |||||
| "MUL", | |||||
| "SIGMOID_GRAD", | |||||
| "SUB", | |||||
| "SWITCH_GT0", | |||||
| "TANH_GRAD", | |||||
| "LT", | |||||
| "LEQ", | |||||
| "EQ", | |||||
| "FUSE_ADD_RELU", | |||||
| "TRUE_DIV", | |||||
| "POW", | |||||
| "LOG_SUM_EXP", | |||||
| "FUSE_ADD_TANH", | |||||
| "FAST_TANH_GRAD", | |||||
| "FUSE_ADD_SIGMOID", | |||||
| "ATAN2", | |||||
| "H_SWISH_GRAD", | |||||
| "FUSE_ADD_H_SWISH", | |||||
| "SILU_GRAD", | |||||
| "GELU_GRAD", | |||||
| ], | |||||
| (3, "FLOAT"): ["COND_LEQ_MOV", "COND_LT_MOV", "FUSE_MUL_ADD3"], | |||||
| (1, "BOOL"): ["NOT"], | |||||
| (2, "BOOL"): ["AND", "OR", "XOR", "LT", "LEQ", "EQ"], | |||||
| (3, "BOOL"): [], | |||||
| } | } | ||||
| @@ -3,13 +3,14 @@ | |||||
| import argparse | import argparse | ||||
| import collections | import collections | ||||
| import textwrap | |||||
| import os | |||||
| import hashlib | import hashlib | ||||
| import struct | |||||
| import io | import io | ||||
| import os | |||||
| import struct | |||||
| import textwrap | |||||
| from gen_param_defs import IndentWriterBase, ParamDef, member_defs | |||||
| from gen_param_defs import member_defs, ParamDef, IndentWriterBase | |||||
| class ConverterWriter(IndentWriterBase): | class ConverterWriter(IndentWriterBase): | ||||
| _skip_current_param = False | _skip_current_param = False | ||||
| @@ -20,7 +21,7 @@ class ConverterWriter(IndentWriterBase): | |||||
| def __call__(self, fout, defs): | def __call__(self, fout, defs): | ||||
| super().__call__(fout) | super().__call__(fout) | ||||
| self._write("// %s", self._get_header()) | self._write("// %s", self._get_header()) | ||||
| self._write('#include <flatbuffers/flatbuffers.h>') | |||||
| self._write("#include <flatbuffers/flatbuffers.h>") | |||||
| self._write("namespace mgb {") | self._write("namespace mgb {") | ||||
| self._write("namespace serialization {") | self._write("namespace serialization {") | ||||
| self._write("namespace fbs {") | self._write("namespace fbs {") | ||||
| @@ -33,8 +34,9 @@ class ConverterWriter(IndentWriterBase): | |||||
| self._last_param = p | self._last_param = p | ||||
| self._param_fields = [] | self._param_fields = [] | ||||
| self._fb_fields = ["builder"] | self._fb_fields = ["builder"] | ||||
| self._write("template<>\nstruct ParamConverter<megdnn::param::%s> {", | |||||
| p.name, indent=1) | |||||
| self._write( | |||||
| "template<>\nstruct ParamConverter<megdnn::param::%s> {", p.name, indent=1 | |||||
| ) | |||||
| self._write("using MegDNNType = megdnn::param::%s;", p.name) | self._write("using MegDNNType = megdnn::param::%s;", p.name) | ||||
| self._write("using FlatBufferType = fbs::param::%s;\n", p.name) | self._write("using FlatBufferType = fbs::param::%s;\n", p.name) | ||||
| @@ -42,22 +44,22 @@ class ConverterWriter(IndentWriterBase): | |||||
| if self._skip_current_param: | if self._skip_current_param: | ||||
| self._skip_current_param = False | self._skip_current_param = False | ||||
| return | return | ||||
| self._write("static MegDNNType to_param(const FlatBufferType* fb) {", | |||||
| indent=1) | |||||
| line = 'return {' | |||||
| line += ', '.join(self._param_fields) | |||||
| line += '};' | |||||
| self._write("static MegDNNType to_param(const FlatBufferType* fb) {", indent=1) | |||||
| line = "return {" | |||||
| line += ", ".join(self._param_fields) | |||||
| line += "};" | |||||
| self._write(line) | self._write(line) | ||||
| self._write("}\n", indent=-1) | self._write("}\n", indent=-1) | ||||
| self._write( | self._write( | ||||
| "static flatbuffers::Offset<FlatBufferType> to_flatbuffer(flatbuffers::FlatBufferBuilder& builder, const MegDNNType& param) {", | "static flatbuffers::Offset<FlatBufferType> to_flatbuffer(flatbuffers::FlatBufferBuilder& builder, const MegDNNType& param) {", | ||||
| indent=1) | |||||
| line = 'return fbs::param::Create{}('.format(str(p.name)) | |||||
| line += ', '.join(self._fb_fields) | |||||
| line += ');' | |||||
| indent=1, | |||||
| ) | |||||
| line = "return fbs::param::Create{}(".format(str(p.name)) | |||||
| line += ", ".join(self._fb_fields) | |||||
| line += ");" | |||||
| self._write(line) | self._write(line) | ||||
| self._write('}', indent=-1) | |||||
| self._write("}", indent=-1) | |||||
| self._write("};\n", indent=-1) | self._write("};\n", indent=-1) | ||||
| @@ -68,18 +70,23 @@ class ConverterWriter(IndentWriterBase): | |||||
| return | return | ||||
| self._param_fields.append( | self._param_fields.append( | ||||
| "static_cast<megdnn::param::{}::{}>(fb->{}())".format( | "static_cast<megdnn::param::{}::{}>(fb->{}())".format( | ||||
| str(p.name), str(e.name), e.name_field)) | |||||
| self._fb_fields.append("static_cast<fbs::param::{}>(param.{})".format( | |||||
| key, e.name_field)) | |||||
| str(p.name), str(e.name), e.name_field | |||||
| ) | |||||
| ) | |||||
| self._fb_fields.append( | |||||
| "static_cast<fbs::param::{}>(param.{})".format(key, e.name_field) | |||||
| ) | |||||
| def _on_member_field(self, f): | def _on_member_field(self, f): | ||||
| if self._skip_current_param: | if self._skip_current_param: | ||||
| return | return | ||||
| if f.dtype.cname == 'DTypeEnum': | |||||
| if f.dtype.cname == "DTypeEnum": | |||||
| self._param_fields.append( | self._param_fields.append( | ||||
| "intl::convert_dtype_to_megdnn(fb->{}())".format(f.name)) | |||||
| "intl::convert_dtype_to_megdnn(fb->{}())".format(f.name) | |||||
| ) | |||||
| self._fb_fields.append( | self._fb_fields.append( | ||||
| "intl::convert_dtype_to_fbs(param.{})".format(f.name)) | |||||
| "intl::convert_dtype_to_fbs(param.{})".format(f.name) | |||||
| ) | |||||
| else: | else: | ||||
| self._param_fields.append("fb->{}()".format(f.name)) | self._param_fields.append("fb->{}()".format(f.name)) | ||||
| self._fb_fields.append("param.{}".format(f.name)) | self._fb_fields.append("param.{}".format(f.name)) | ||||
| @@ -93,28 +100,33 @@ class ConverterWriter(IndentWriterBase): | |||||
| enum_name = e.src_class + e.src_name | enum_name = e.src_class + e.src_name | ||||
| self._param_fields.append( | self._param_fields.append( | ||||
| "static_cast<megdnn::param::{}::{}>(fb->{}())".format( | "static_cast<megdnn::param::{}::{}>(fb->{}())".format( | ||||
| e.src_class, e.src_name, e.name_field)) | |||||
| self._fb_fields.append("static_cast<fbs::param::{}>(param.{})".format( | |||||
| enum_name, e.name_field)) | |||||
| e.src_class, e.src_name, e.name_field | |||||
| ) | |||||
| ) | |||||
| self._fb_fields.append( | |||||
| "static_cast<fbs::param::{}>(param.{})".format(enum_name, e.name_field) | |||||
| ) | |||||
| def main(): | def main(): | ||||
| parser = argparse.ArgumentParser( | parser = argparse.ArgumentParser( | ||||
| 'generate convert functions between FlatBuffers type and MegBrain type') | |||||
| parser.add_argument('input') | |||||
| parser.add_argument('output') | |||||
| "generate convert functions between FlatBuffers type and MegBrain type" | |||||
| ) | |||||
| parser.add_argument("input") | |||||
| parser.add_argument("output") | |||||
| args = parser.parse_args() | args = parser.parse_args() | ||||
| with open(args.input) as fin: | with open(args.input) as fin: | ||||
| inputs = fin.read() | inputs = fin.read() | ||||
| exec(inputs, {'pdef': ParamDef, 'Doc': member_defs.Doc}) | |||||
| exec(inputs, {"pdef": ParamDef, "Doc": member_defs.Doc}) | |||||
| input_hash = hashlib.sha256() | input_hash = hashlib.sha256() | ||||
| input_hash.update(inputs.encode(encoding='UTF-8')) | |||||
| input_hash.update(inputs.encode(encoding="UTF-8")) | |||||
| input_hash = input_hash.hexdigest() | input_hash = input_hash.hexdigest() | ||||
| writer = ConverterWriter() | writer = ConverterWriter() | ||||
| with open(args.output, 'w') as fout: | |||||
| with open(args.output, "w") as fout: | |||||
| writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs) | writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs) | ||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| main() | main() | ||||
| @@ -3,13 +3,14 @@ | |||||
| import argparse | import argparse | ||||
| import collections | import collections | ||||
| import textwrap | |||||
| import os | |||||
| import hashlib | import hashlib | ||||
| import struct | |||||
| import io | import io | ||||
| import os | |||||
| import struct | |||||
| import textwrap | |||||
| from gen_param_defs import IndentWriterBase, ParamDef, member_defs | |||||
| from gen_param_defs import member_defs, ParamDef, IndentWriterBase | |||||
| def _cname_to_fbname(cname): | def _cname_to_fbname(cname): | ||||
| return { | return { | ||||
| @@ -22,17 +23,19 @@ def _cname_to_fbname(cname): | |||||
| "bool": "bool", | "bool": "bool", | ||||
| }[cname] | }[cname] | ||||
| def scramble_enum_member_name(name): | def scramble_enum_member_name(name): | ||||
| s = name.find('<<') | |||||
| s = name.find("<<") | |||||
| if s != -1: | if s != -1: | ||||
| name = name[0:name.find('=') + 1] + ' ' + name[s+2:] | |||||
| name = name[0 : name.find("=") + 1] + " " + name[s + 2 :] | |||||
| if name in ("MIN", "MAX"): | if name in ("MIN", "MAX"): | ||||
| return name + "_" | return name + "_" | ||||
| o_name = name.split(' ')[0].split('=')[0] | |||||
| o_name = name.split(" ")[0].split("=")[0] | |||||
| if o_name in ("MIN", "MAX"): | if o_name in ("MIN", "MAX"): | ||||
| return name.replace(o_name, o_name + "_") | return name.replace(o_name, o_name + "_") | ||||
| return name | return name | ||||
| class FlatBuffersWriter(IndentWriterBase): | class FlatBuffersWriter(IndentWriterBase): | ||||
| _skip_current_param = False | _skip_current_param = False | ||||
| _last_param = None | _last_param = None | ||||
| @@ -66,12 +69,13 @@ class FlatBuffersWriter(IndentWriterBase): | |||||
| self._write("}\n", indent=-1) | self._write("}\n", indent=-1) | ||||
| def _write_doc(self, doc): | def _write_doc(self, doc): | ||||
| if not isinstance(doc, member_defs.Doc) or not doc.doc: return | |||||
| if not isinstance(doc, member_defs.Doc) or not doc.doc: | |||||
| return | |||||
| doc_lines = [] | doc_lines = [] | ||||
| if doc.no_reformat: | if doc.no_reformat: | ||||
| doc_lines = doc.raw_lines | doc_lines = doc.raw_lines | ||||
| else: | else: | ||||
| doc = doc.doc.replace('\n', ' ') | |||||
| doc = doc.doc.replace("\n", " ") | |||||
| text_width = 80 - len(self._cur_indent) - 4 | text_width = 80 - len(self._cur_indent) - 4 | ||||
| doc_lines = textwrap.wrap(doc, text_width) | doc_lines = textwrap.wrap(doc, text_width) | ||||
| for line in doc_lines: | for line in doc_lines: | ||||
| @@ -101,7 +105,8 @@ class FlatBuffersWriter(IndentWriterBase): | |||||
| default = e.compose_combined_enum(e.default) | default = e.compose_combined_enum(e.default) | ||||
| else: | else: | ||||
| default = scramble_enum_member_name( | default = scramble_enum_member_name( | ||||
| str(e.members[e.default]).split(' ')[0].split('=')[0]) | |||||
| str(e.members[e.default]).split(" ")[0].split("=")[0] | |||||
| ) | |||||
| self._write("%s:%s%s = %s;", e.name_field, p.name, e.name, default) | self._write("%s:%s%s = %s;", e.name_field, p.name, e.name, default) | ||||
| def _resolve_const(self, v): | def _resolve_const(self, v): | ||||
| @@ -113,8 +118,12 @@ class FlatBuffersWriter(IndentWriterBase): | |||||
| if self._skip_current_param: | if self._skip_current_param: | ||||
| return | return | ||||
| self._write_doc(f.name) | self._write_doc(f.name) | ||||
| self._write("%s:%s = %s;", f.name, _cname_to_fbname(f.dtype.cname), | |||||
| self._get_fb_default(self._resolve_const(f.default))) | |||||
| self._write( | |||||
| "%s:%s = %s;", | |||||
| f.name, | |||||
| _cname_to_fbname(f.dtype.cname), | |||||
| self._get_fb_default(self._resolve_const(f.default)), | |||||
| ) | |||||
| def _on_const_field(self, f): | def _on_const_field(self, f): | ||||
| self._cur_const_val[str(f.name)] = str(f.default) | self._cur_const_val[str(f.name)] = str(f.default) | ||||
| @@ -129,7 +138,8 @@ class FlatBuffersWriter(IndentWriterBase): | |||||
| default = s.compose_combined_enum(e.get_default()) | default = s.compose_combined_enum(e.get_default()) | ||||
| else: | else: | ||||
| default = scramble_enum_member_name( | default = scramble_enum_member_name( | ||||
| str(s.members[e.get_default()]).split(' ')[0].split('=')[0]) | |||||
| str(s.members[e.get_default()]).split(" ")[0].split("=")[0] | |||||
| ) | |||||
| self._write("%s:%s = %s;", e.name_field, enum_name, default) | self._write("%s:%s = %s;", e.name_field, enum_name, default) | ||||
| def _get_fb_default(self, cppdefault): | def _get_fb_default(self, cppdefault): | ||||
| @@ -137,9 +147,9 @@ class FlatBuffersWriter(IndentWriterBase): | |||||
| return cppdefault | return cppdefault | ||||
| d = cppdefault | d = cppdefault | ||||
| if d.endswith('f'): # 1.f | |||||
| if d.endswith("f"): # 1.f | |||||
| return d[:-1] | return d[:-1] | ||||
| if d.endswith('ull'): | |||||
| if d.endswith("ull"): | |||||
| return d[:-3] | return d[:-3] | ||||
| if d.startswith("DTypeEnum::"): | if d.startswith("DTypeEnum::"): | ||||
| return d[11:] | return d[11:] | ||||
| @@ -148,21 +158,23 @@ class FlatBuffersWriter(IndentWriterBase): | |||||
| def main(): | def main(): | ||||
| parser = argparse.ArgumentParser( | parser = argparse.ArgumentParser( | ||||
| 'generate FlatBuffers schema of operator param from description file') | |||||
| parser.add_argument('input') | |||||
| parser.add_argument('output') | |||||
| "generate FlatBuffers schema of operator param from description file" | |||||
| ) | |||||
| parser.add_argument("input") | |||||
| parser.add_argument("output") | |||||
| args = parser.parse_args() | args = parser.parse_args() | ||||
| with open(args.input) as fin: | with open(args.input) as fin: | ||||
| inputs = fin.read() | inputs = fin.read() | ||||
| exec(inputs, {'pdef': ParamDef, 'Doc': member_defs.Doc}) | |||||
| exec(inputs, {"pdef": ParamDef, "Doc": member_defs.Doc}) | |||||
| input_hash = hashlib.sha256() | input_hash = hashlib.sha256() | ||||
| input_hash.update(inputs.encode(encoding='UTF-8')) | |||||
| input_hash.update(inputs.encode(encoding="UTF-8")) | |||||
| input_hash = input_hash.hexdigest() | input_hash = input_hash.hexdigest() | ||||
| writer = FlatBuffersWriter() | writer = FlatBuffersWriter() | ||||
| with open(args.output, 'w') as fout: | |||||
| with open(args.output, "w") as fout: | |||||
| writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs) | writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs) | ||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| main() | main() | ||||
| @@ -1,14 +1,16 @@ | |||||
| #! /usr/local/env python3 | #! /usr/local/env python3 | ||||
| import pickle | |||||
| import numpy as np | |||||
| import os | |||||
| import argparse | import argparse | ||||
| import re | |||||
| import collections | import collections | ||||
| import os | |||||
| import pickle | |||||
| import re | |||||
| import numpy as np | |||||
| def define_template(**kwargs): | def define_template(**kwargs): | ||||
| template = ''' | |||||
| template = """ | |||||
| float cuda{cuda_arch}_{conv_type}_time_pred[{out_dim}] = {{0.0f}}; | float cuda{cuda_arch}_{conv_type}_time_pred[{out_dim}] = {{0.0f}}; | ||||
| float cuda{cuda_arch}_{conv_type}_mask[{out_dim}] = {{0.0f}}; | float cuda{cuda_arch}_{conv_type}_mask[{out_dim}] = {{0.0f}}; | ||||
| float cuda{cuda_arch}_{conv_type}_hidden_units[{hidden_num}] = {{0.0f}}; | float cuda{cuda_arch}_{conv_type}_hidden_units[{hidden_num}] = {{0.0f}}; | ||||
| @@ -17,21 +19,23 @@ def define_template(**kwargs): | |||||
| const static float cuda{cuda_arch}_{conv_type}_biases[{biases_dim}] = {{{biases}}}; | const static float cuda{cuda_arch}_{conv_type}_biases[{biases_dim}] = {{{biases}}}; | ||||
| const static float cuda{cuda_arch}_{conv_type}_alpha[{out_dim}] = {{{alpha}}}; | const static float cuda{cuda_arch}_{conv_type}_alpha[{out_dim}] = {{{alpha}}}; | ||||
| const static float cuda{cuda_arch}_{conv_type}_beta[{out_dim}] = {{{beta}}}; | const static float cuda{cuda_arch}_{conv_type}_beta[{out_dim}] = {{{beta}}}; | ||||
| ''' | |||||
| """ | |||||
| return template.format(**kwargs) | return template.format(**kwargs) | ||||
| def cudnn_slt_template(**kwargs): | def cudnn_slt_template(**kwargs): | ||||
| template = ("#if CUDNN_MAJOR == {cudnn_major} && CUDNN_MINOR == {cudnn_minor}\n" + | |||||
| " {define_cmd}\n" + | |||||
| " {select_cmd}\n" + | |||||
| " return true;\n" + | |||||
| "#endif\n" | |||||
| ) | |||||
| template = ( | |||||
| "#if CUDNN_MAJOR == {cudnn_major} && CUDNN_MINOR == {cudnn_minor}\n" | |||||
| + " {define_cmd}\n" | |||||
| + " {select_cmd}\n" | |||||
| + " return true;\n" | |||||
| + "#endif\n" | |||||
| ) | |||||
| return template.format(**kwargs) | return template.format(**kwargs) | ||||
| def select_template(**kwargs): | def select_template(**kwargs): | ||||
| template = \ | |||||
| '''if (conv_type == ConvolutionType::{conv_type} && cuda_major == {cuda_major} && | |||||
| template = """if (conv_type == ConvolutionType::{conv_type} && cuda_major == {cuda_major} && | |||||
| cuda_minor == {cuda_minor}) {{ | cuda_minor == {cuda_minor}) {{ | ||||
| *layer_num_p = {layer_num}; | *layer_num_p = {layer_num}; | ||||
| *hidden_units_p = cuda{cuda_arch}_{conv_type}_hidden_units; | *hidden_units_p = cuda{cuda_arch}_{conv_type}_hidden_units; | ||||
| @@ -42,7 +46,7 @@ def select_template(**kwargs): | |||||
| *beta_p = cuda{cuda_arch}_{conv_type}_beta; | *beta_p = cuda{cuda_arch}_{conv_type}_beta; | ||||
| *time_pred_p = cuda{cuda_arch}_{conv_type}_time_pred; | *time_pred_p = cuda{cuda_arch}_{conv_type}_time_pred; | ||||
| *mask_p = cuda{cuda_arch}_{conv_type}_mask; | *mask_p = cuda{cuda_arch}_{conv_type}_mask; | ||||
| }} else ''' | |||||
| }} else """ | |||||
| return template.format(**kwargs) | return template.format(**kwargs) | ||||
| @@ -58,48 +62,48 @@ def fill_src(): | |||||
| if len(matrix_files) == 0: | if len(matrix_files) == 0: | ||||
| print("Warning: no param files detected.") | print("Warning: no param files detected.") | ||||
| for fpath in matrix_files: | for fpath in matrix_files: | ||||
| cudnn_version = re.findall('cudnn([\d.]+)',fpath)[0] | |||||
| cudnn_version = re.findall("cudnn([\d.]+)", fpath)[0] | |||||
| gen_list[cudnn_version].append(fpath) | gen_list[cudnn_version].append(fpath) | ||||
| for cudnn in gen_list: | for cudnn in gen_list: | ||||
| select_cmd = ("{\n" + | |||||
| " " * 8 + "return false;\n" + | |||||
| " " * 4 + "}") | |||||
| select_cmd = "{\n" + " " * 8 + "return false;\n" + " " * 4 + "}" | |||||
| define_cmd = "" | define_cmd = "" | ||||
| cudnn_major, cudnn_minor = cudnn.split('.') | |||||
| cudnn_major, cudnn_minor = cudnn.split(".") | |||||
| for fpath in gen_list[cudnn]: | for fpath in gen_list[cudnn]: | ||||
| cuda_arch = fpath.split("-")[1].replace(".", "_") | cuda_arch = fpath.split("-")[1].replace(".", "_") | ||||
| print('cudnn_version: {}, cuda_arch: {}'.format(cudnn,cuda_arch)) | |||||
| print("cudnn_version: {}, cuda_arch: {}".format(cudnn, cuda_arch)) | |||||
| conv_type = fpath.split("-")[2].split(".")[0] | conv_type = fpath.split("-")[2].split(".")[0] | ||||
| with open(os.path.join(home, "params/{}".format(fpath)), "rb") as pobj: | with open(os.path.join(home, "params/{}".format(fpath)), "rb") as pobj: | ||||
| params = pickle.load(pobj) | params = pickle.load(pobj) | ||||
| crt_define_cmd, crt_select_cmd = gen_cmds( | |||||
| cuda_arch, conv_type, params) | |||||
| crt_define_cmd, crt_select_cmd = gen_cmds(cuda_arch, conv_type, params) | |||||
| select_cmd = crt_select_cmd + select_cmd | select_cmd = crt_select_cmd + select_cmd | ||||
| define_cmd = crt_define_cmd + define_cmd | define_cmd = crt_define_cmd + define_cmd | ||||
| cudnn_slt_cmd += cudnn_slt_template(cudnn_major=cudnn_major, | |||||
| cudnn_minor=cudnn_minor, | |||||
| select_cmd=select_cmd, | |||||
| define_cmd=define_cmd) | |||||
| cudnn_slt_cmd += cudnn_slt_template( | |||||
| cudnn_major=cudnn_major, | |||||
| cudnn_minor=cudnn_minor, | |||||
| select_cmd=select_cmd, | |||||
| define_cmd=define_cmd, | |||||
| ) | |||||
| #select_cmd = select_cmd | |||||
| # select_cmd = select_cmd | |||||
| with open(os.path.join(home, "get_params.template"), "r") as srcf: | with open(os.path.join(home, "get_params.template"), "r") as srcf: | ||||
| src = srcf.read() | src = srcf.read() | ||||
| dst = src.replace("{cudnn_select}", cudnn_slt_cmd) | dst = src.replace("{cudnn_select}", cudnn_slt_cmd) | ||||
| MegDNN_path = os.path.join(home, "../..") | MegDNN_path = os.path.join(home, "../..") | ||||
| with open(os.path.join(MegDNN_path, | |||||
| "src/cuda/convolution/get_params.cpp"), "w") as dstf: | |||||
| with open( | |||||
| os.path.join(MegDNN_path, "src/cuda/convolution/get_params.cpp"), "w" | |||||
| ) as dstf: | |||||
| dstf.write(dst) | dstf.write(dst) | ||||
| def gen_cmds(cuda_arch, conv_type, params): | def gen_cmds(cuda_arch, conv_type, params): | ||||
| cuda_major, cuda_minor = cuda_arch.split("_") | cuda_major, cuda_minor = cuda_arch.split("_") | ||||
| alphastr = format_array(params['alpha']).rstrip()[:-1] | |||||
| betastr = format_array(params['beta']).rstrip()[:-1] | |||||
| W_list = params['W'] | |||||
| b_list = params['b'] | |||||
| Wstr = '' | |||||
| bstr = '' | |||||
| alphastr = format_array(params["alpha"]).rstrip()[:-1] | |||||
| betastr = format_array(params["beta"]).rstrip()[:-1] | |||||
| W_list = params["W"] | |||||
| b_list = params["b"] | |||||
| Wstr = "" | |||||
| bstr = "" | |||||
| layer_num = str(len(b_list) + 1) | layer_num = str(len(b_list) + 1) | ||||
| layers_dim = [W_list[0].shape[1]] | layers_dim = [W_list[0].shape[1]] | ||||
| matrices_dim = 0 | matrices_dim = 0 | ||||
| @@ -118,16 +122,27 @@ def gen_cmds(cuda_arch, conv_type, params): | |||||
| out_dim = layers_dim[-1] | out_dim = layers_dim[-1] | ||||
| layers_dim_str = format_array(np.array(layers_dim)).rstrip()[:-1] | layers_dim_str = format_array(np.array(layers_dim)).rstrip()[:-1] | ||||
| select_cmd = select_template(conv_type=conv_type.upper(), cuda_major=cuda_major, | |||||
| cuda_minor=cuda_minor, layer_num=layer_num, | |||||
| cuda_arch=cuda_arch) | |||||
| define_cmd = define_template(cuda_arch=cuda_arch, conv_type=conv_type.upper(), | |||||
| hidden_num=hidden_num, | |||||
| layer_num=layer_num, out_dim=out_dim, | |||||
| layers_dim=layers_dim_str, | |||||
| matrices_dim=matrices_dim, matrices=Wstr, | |||||
| biases_dim=biases_dim, biases=bstr, | |||||
| alpha=alphastr, beta=betastr) | |||||
| select_cmd = select_template( | |||||
| conv_type=conv_type.upper(), | |||||
| cuda_major=cuda_major, | |||||
| cuda_minor=cuda_minor, | |||||
| layer_num=layer_num, | |||||
| cuda_arch=cuda_arch, | |||||
| ) | |||||
| define_cmd = define_template( | |||||
| cuda_arch=cuda_arch, | |||||
| conv_type=conv_type.upper(), | |||||
| hidden_num=hidden_num, | |||||
| layer_num=layer_num, | |||||
| out_dim=out_dim, | |||||
| layers_dim=layers_dim_str, | |||||
| matrices_dim=matrices_dim, | |||||
| matrices=Wstr, | |||||
| biases_dim=biases_dim, | |||||
| biases=bstr, | |||||
| alpha=alphastr, | |||||
| beta=betastr, | |||||
| ) | |||||
| return (define_cmd, select_cmd) | return (define_cmd, select_cmd) | ||||
| @@ -153,8 +168,9 @@ def format_array(array): | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| parser = argparse.ArgumentParser( | parser = argparse.ArgumentParser( | ||||
| description="Generate cuDNN heuristic code by neural network into" | description="Generate cuDNN heuristic code by neural network into" | ||||
| " {MEGDNN_ROOT}/src/cuda/convolution/get_params.cpp," | |||||
| " using parameter value from pickle files in" | |||||
| " {MEGDNN_ROOT}/scripts/gen_heuristic/params/") | |||||
| " {MEGDNN_ROOT}/src/cuda/convolution/get_params.cpp," | |||||
| " using parameter value from pickle files in" | |||||
| " {MEGDNN_ROOT}/scripts/gen_heuristic/params/" | |||||
| ) | |||||
| args = parser.parse_args() | args = parser.parse_args() | ||||
| main() | main() | ||||
| @@ -3,19 +3,17 @@ | |||||
| import argparse | import argparse | ||||
| import collections | import collections | ||||
| import textwrap | |||||
| import os | |||||
| import hashlib | import hashlib | ||||
| import struct | |||||
| import io | import io | ||||
| import os | |||||
| import struct | |||||
| import textwrap | |||||
| from gen_param_defs import member_defs, ParamDef, IndentWriterBase | |||||
| from gen_param_defs import IndentWriterBase, ParamDef, member_defs | |||||
| # FIXME: move supportToString flag definition into the param def source file | # FIXME: move supportToString flag definition into the param def source file | ||||
| ENUM_TO_STRING_SPECIAL_RULES = [ | |||||
| ("Elemwise", "Mode"), | |||||
| ("ElemwiseMultiType", "Mode") | |||||
| ] | |||||
| ENUM_TO_STRING_SPECIAL_RULES = [("Elemwise", "Mode"), ("ElemwiseMultiType", "Mode")] | |||||
| class ConverterWriter(IndentWriterBase): | class ConverterWriter(IndentWriterBase): | ||||
| _skip_current_param = False | _skip_current_param = False | ||||
| @@ -33,21 +31,21 @@ class ConverterWriter(IndentWriterBase): | |||||
| self._write("#endif // MGB_PARAM") | self._write("#endif // MGB_PARAM") | ||||
| def _ctype2attr(self, ctype, value): | def _ctype2attr(self, ctype, value): | ||||
| if ctype == 'uint32_t': | |||||
| return 'MgbUI32Attr', value | |||||
| if ctype == 'uint64_t': | |||||
| return 'MgbUI64Attr', value | |||||
| if ctype == 'int32_t': | |||||
| return 'MgbI32Attr', value | |||||
| if ctype == 'float': | |||||
| return 'MgbF32Attr', value | |||||
| if ctype == 'double': | |||||
| return 'MgbF64Attr', value | |||||
| if ctype == 'bool': | |||||
| return 'MgbBoolAttr', value | |||||
| if ctype == 'DTypeEnum': | |||||
| if ctype == "uint32_t": | |||||
| return "MgbUI32Attr", value | |||||
| if ctype == "uint64_t": | |||||
| return "MgbUI64Attr", value | |||||
| if ctype == "int32_t": | |||||
| return "MgbI32Attr", value | |||||
| if ctype == "float": | |||||
| return "MgbF32Attr", value | |||||
| if ctype == "double": | |||||
| return "MgbF64Attr", value | |||||
| if ctype == "bool": | |||||
| return "MgbBoolAttr", value | |||||
| if ctype == "DTypeEnum": | |||||
| self._packed = False | self._packed = False | ||||
| return 'MgbDTypeAttr', 'megdnn::DType::from_enum(megdnn::{})'.format(value) | |||||
| return "MgbDTypeAttr", "megdnn::DType::from_enum(megdnn::{})".format(value) | |||||
| raise RuntimeError("unknown ctype") | raise RuntimeError("unknown ctype") | ||||
| def _on_param_begin(self, p): | def _on_param_begin(self, p): | ||||
| @@ -61,21 +59,26 @@ class ConverterWriter(IndentWriterBase): | |||||
| self._skip_current_param = False | self._skip_current_param = False | ||||
| return | return | ||||
| if self._packed: | if self._packed: | ||||
| self._write("class {0}ParamBase<string accessor> : MgbPackedParamBase<\"{0}\", accessor> {{".format(p.name), indent=1) | |||||
| self._write( | |||||
| 'class {0}ParamBase<string accessor> : MgbPackedParamBase<"{0}", accessor> {{'.format( | |||||
| p.name | |||||
| ), | |||||
| indent=1, | |||||
| ) | |||||
| else: | else: | ||||
| self._write("def {0}Param: MgbParamBase<\"{0}\"> {{".format(p.name), indent=1) | |||||
| self._write('def {0}Param: MgbParamBase<"{0}"> {{'.format(p.name), indent=1) | |||||
| self._write("let fields = (ins", indent=1) | self._write("let fields = (ins", indent=1) | ||||
| self._write(",\n{}".format(self._cur_indent).join(self._current_tparams)) | self._write(",\n{}".format(self._cur_indent).join(self._current_tparams)) | ||||
| self._write(");", indent=-1) | self._write(");", indent=-1) | ||||
| self._write("}\n", indent=-1) | self._write("}\n", indent=-1) | ||||
| if self._packed: | if self._packed: | ||||
| self._write("def {0}Param : {0}ParamBase<\"param\">;\n".format(p.name)) | |||||
| self._write('def {0}Param : {0}ParamBase<"param">;\n'.format(p.name)) | |||||
| self._current_tparams = None | self._current_tparams = None | ||||
| self._packed = None | self._packed = None | ||||
| self._const = None | self._const = None | ||||
| def _wrapped_with_default_value(self, attr, default): | def _wrapped_with_default_value(self, attr, default): | ||||
| return 'MgbDefaultValuedAttr<{}, \"{}\">'.format(attr, default) | |||||
| return 'MgbDefaultValuedAttr<{}, "{}">'.format(attr, default) | |||||
| def _on_member_enum(self, e): | def _on_member_enum(self, e): | ||||
| p = self._last_param | p = self._last_param | ||||
| @@ -84,10 +87,12 @@ class ConverterWriter(IndentWriterBase): | |||||
| # directly used by any operator, or other enum couldn't alias to this enum | # directly used by any operator, or other enum couldn't alias to this enum | ||||
| td_class = "{}{}".format(p.name, e.name) | td_class = "{}{}".format(p.name, e.name) | ||||
| fullname = "::megdnn::param::{}".format(p.name) | fullname = "::megdnn::param::{}".format(p.name) | ||||
| enum_def = "MgbEnumAttr<\"{}\", \"{}\", [".format(fullname, e.name) | |||||
| enum_def = 'MgbEnumAttr<"{}", "{}", ['.format(fullname, e.name) | |||||
| def format(v): | def format(v): | ||||
| return '\"{}\"'.format(str(v).split(' ')[0].split('=')[0]) | |||||
| enum_def += ','.join(format(i) for i in e.members) | |||||
| return '"{}"'.format(str(v).split(" ")[0].split("=")[0]) | |||||
| enum_def += ",".join(format(i) for i in e.members) | |||||
| if e.combined: | if e.combined: | ||||
| enum_def += "], 1" | enum_def += "], 1" | ||||
| @@ -95,7 +100,7 @@ class ConverterWriter(IndentWriterBase): | |||||
| enum_def += "], 0" | enum_def += "], 0" | ||||
| if ENUM_TO_STRING_SPECIAL_RULES.count((p.name, e.name)): | if ENUM_TO_STRING_SPECIAL_RULES.count((p.name, e.name)): | ||||
| enum_def += ", 1" # whether generate ToStringTrait | |||||
| enum_def += ", 1" # whether generate ToStringTrait | |||||
| enum_def += ">" | enum_def += ">" | ||||
| self._write("def {} : {};".format(td_class, enum_def)) | self._write("def {} : {};".format(td_class, enum_def)) | ||||
| @@ -105,10 +110,12 @@ class ConverterWriter(IndentWriterBase): | |||||
| # wrapped with default value | # wrapped with default value | ||||
| if e.combined: | if e.combined: | ||||
| default_val = "static_cast<{}::{}>({})".format( | default_val = "static_cast<{}::{}>({})".format( | ||||
| fullname, e.name, e.compose_combined_enum(e.default)) | |||||
| fullname, e.name, e.compose_combined_enum(e.default) | |||||
| ) | |||||
| else: | else: | ||||
| default_val = "{}::{}::{}".format( | default_val = "{}::{}::{}".format( | ||||
| fullname, e.name, str(e.members[e.default]).split(' ')[0].split('=')[0]) | |||||
| fullname, e.name, str(e.members[e.default]).split(" ")[0].split("=")[0] | |||||
| ) | |||||
| wrapped = self._wrapped_with_default_value(td_class, default_val) | wrapped = self._wrapped_with_default_value(td_class, default_val) | ||||
| @@ -123,51 +130,58 @@ class ConverterWriter(IndentWriterBase): | |||||
| td_class = "{}{}".format(p.name, e.name) | td_class = "{}{}".format(p.name, e.name) | ||||
| fullname = "::megdnn::param::{}".format(p.name) | fullname = "::megdnn::param::{}".format(p.name) | ||||
| base_td_class = "{}{}".format(e.src_class, e.src_name) | base_td_class = "{}{}".format(e.src_class, e.src_name) | ||||
| enum_def = "MgbEnumAliasAttr<\"{}\", \"{}\", {}>".format(fullname, e.name, base_td_class) | |||||
| enum_def = 'MgbEnumAliasAttr<"{}", "{}", {}>'.format( | |||||
| fullname, e.name, base_td_class | |||||
| ) | |||||
| self._write("def {} : {};".format(td_class, enum_def)) | self._write("def {} : {};".format(td_class, enum_def)) | ||||
| # wrapped with default value | # wrapped with default value | ||||
| s = e.src_enum | s = e.src_enum | ||||
| if s.combined: | if s.combined: | ||||
| default_val = "static_cast<{}::{}>({})".format( | default_val = "static_cast<{}::{}>({})".format( | ||||
| fullname, e.name, s.compose_combined_enum(e.get_default())) | |||||
| fullname, e.name, s.compose_combined_enum(e.get_default()) | |||||
| ) | |||||
| else: | else: | ||||
| default_val = "{}::{}::{}".format(fullname, e.name, str( | |||||
| s.members[e.get_default()]).split(' ')[0].split('=')[0]) | |||||
| default_val = "{}::{}::{}".format( | |||||
| fullname, | |||||
| e.name, | |||||
| str(s.members[e.get_default()]).split(" ")[0].split("=")[0], | |||||
| ) | |||||
| wrapped = self._wrapped_with_default_value(td_class, default_val) | wrapped = self._wrapped_with_default_value(td_class, default_val) | ||||
| self._current_tparams.append("{}:${}".format(wrapped, e.name_field)) | self._current_tparams.append("{}:${}".format(wrapped, e.name_field)) | ||||
| def _on_member_field(self, f): | def _on_member_field(self, f): | ||||
| if self._skip_current_param: | if self._skip_current_param: | ||||
| return | return | ||||
| attr, value = self._ctype2attr(f.dtype.cname, str(f.default)) | attr, value = self._ctype2attr(f.dtype.cname, str(f.default)) | ||||
| if str(value) in self._const: | if str(value) in self._const: | ||||
| value = '::megdnn::param::{}::{}'.format(self._last_param.name, value) | |||||
| value = "::megdnn::param::{}::{}".format(self._last_param.name, value) | |||||
| wrapped = self._wrapped_with_default_value(attr, value) | wrapped = self._wrapped_with_default_value(attr, value) | ||||
| self._current_tparams.append("{}:${}".format(wrapped, f.name)) | self._current_tparams.append("{}:${}".format(wrapped, f.name)) | ||||
| def _on_const_field(self, f): | def _on_const_field(self, f): | ||||
| self._const.add(str(f.name)) | self._const.add(str(f.name)) | ||||
| def main(): | def main(): | ||||
| parser = argparse.ArgumentParser('generate op param tablegen file') | |||||
| parser.add_argument('input') | |||||
| parser.add_argument('output') | |||||
| parser = argparse.ArgumentParser("generate op param tablegen file") | |||||
| parser.add_argument("input") | |||||
| parser.add_argument("output") | |||||
| args = parser.parse_args() | args = parser.parse_args() | ||||
| with open(args.input) as fin: | with open(args.input) as fin: | ||||
| inputs = fin.read() | inputs = fin.read() | ||||
| exec(inputs, {'pdef': ParamDef, 'Doc': member_defs.Doc}) | |||||
| exec(inputs, {"pdef": ParamDef, "Doc": member_defs.Doc}) | |||||
| input_hash = hashlib.sha256() | input_hash = hashlib.sha256() | ||||
| input_hash.update(inputs.encode(encoding='UTF-8')) | |||||
| input_hash.update(inputs.encode(encoding="UTF-8")) | |||||
| input_hash = input_hash.hexdigest() | input_hash = input_hash.hexdigest() | ||||
| writer = ConverterWriter() | writer = ConverterWriter() | ||||
| with open(args.output, 'w') as fout: | |||||
| with open(args.output, "w") as fout: | |||||
| writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs) | writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs) | ||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| main() | main() | ||||
| @@ -19,6 +19,7 @@ device = { | |||||
| "thread_number": 3, | "thread_number": 3, | ||||
| } | } | ||||
| class SshConnector: | class SshConnector: | ||||
| """imp ssh control master connector""" | """imp ssh control master connector""" | ||||
| @@ -83,17 +84,17 @@ def main(): | |||||
| model_file = args.model_file | model_file = args.model_file | ||||
| # copy model file | # copy model file | ||||
| ssh.copy([args.model_file], workspace) | ssh.copy([args.model_file], workspace) | ||||
| m = model_file.split('\\')[-1] | |||||
| m = model_file.split("\\")[-1] | |||||
| # run single thread | # run single thread | ||||
| result = [] | result = [] | ||||
| thread_number = [1, 2, 4] | thread_number = [1, 2, 4] | ||||
| for b in thread_number : | |||||
| for b in thread_number: | |||||
| cmd = [] | cmd = [] | ||||
| cmd1 = "cd {} && ./load_and_run {} -multithread {} --fast-run --fast_run_algo_policy fastrun.cache --iter 1 --warmup-iter 1 --no-sanity-check --weight-preprocess".format( | cmd1 = "cd {} && ./load_and_run {} -multithread {} --fast-run --fast_run_algo_policy fastrun.cache --iter 1 --warmup-iter 1 --no-sanity-check --weight-preprocess".format( | ||||
| workspace, m, b | |||||
| workspace, m, b | |||||
| ) | ) | ||||
| cmd2 = "cd {} && ./load_and_run {} -multithread {} --fast_run_algo_policy fastrun.cache --iter 20 --warmup-iter 5 --no-sanity-check --weight-preprocess ".format( | cmd2 = "cd {} && ./load_and_run {} -multithread {} --fast_run_algo_policy fastrun.cache --iter 20 --warmup-iter 5 --no-sanity-check --weight-preprocess ".format( | ||||
| workspace, m, b | |||||
| workspace, m, b | |||||
| ) | ) | ||||
| cmd.append(cmd1) | cmd.append(cmd1) | ||||
| cmd.append(cmd2) | cmd.append(cmd2) | ||||
| @@ -103,12 +104,20 @@ def main(): | |||||
| logging.debug("model: {} with backend: {} result is: {}".format(m, b, ret)) | logging.debug("model: {} with backend: {} result is: {}".format(m, b, ret)) | ||||
| result.append(ret) | result.append(ret) | ||||
| thread_2 = result[0]/result[1] | |||||
| thread_4 = result[0]/result[2] | |||||
| thread_2 = result[0] / result[1] | |||||
| thread_4 = result[0] / result[2] | |||||
| if thread_2 > 1.6 or thread_4 > 3.0: | if thread_2 > 1.6 or thread_4 > 3.0: | ||||
| print("model: {} can has good parallelism. 2 thread is {}, 4 thread is {}".format(m, thread_2, thread_4)) | |||||
| print( | |||||
| "model: {} can has good parallelism. 2 thread is {}, 4 thread is {}".format( | |||||
| m, thread_2, thread_4 | |||||
| ) | |||||
| ) | |||||
| else: | else: | ||||
| print("model: {} can has bad parallelism. 2 thread is {}, 4 thread is {}".format(m, thread_2, thread_4)) | |||||
| print( | |||||
| "model: {} can has bad parallelism. 2 thread is {}, 4 thread is {}".format( | |||||
| m, thread_2, thread_4 | |||||
| ) | |||||
| ) | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| @@ -20,8 +20,12 @@ failed_files = Manager().list() | |||||
| def process_file(file, clang_format, write): | def process_file(file, clang_format, write): | ||||
| original_source = open(file, "r").read() | original_source = open(file, "r").read() | ||||
| source = original_source | source = original_source | ||||
| source = re.sub(r"MGB_DEFINE(?P<r>([^\\]|\n)*?)// *{", r"class MGB_DEFINE\g<r>{", source) | |||||
| source, count = re.subn(r"(?<!#define )MGB_DEFINE(.*) +\\", r"class MGB_DEFINE\1{\\", source) | |||||
| source = re.sub( | |||||
| r"MGB_DEFINE(?P<r>([^\\]|\n)*?)// *{", r"class MGB_DEFINE\g<r>{", source | |||||
| ) | |||||
| source, count = re.subn( | |||||
| r"(?<!#define )MGB_DEFINE(.*) +\\", r"class MGB_DEFINE\1{\\", source | |||||
| ) | |||||
| result = subprocess.check_output( | result = subprocess.check_output( | ||||
| [ | [ | ||||
| @@ -36,7 +40,9 @@ def process_file(file, clang_format, write): | |||||
| result = result.decode("utf-8") | result = result.decode("utf-8") | ||||
| if count: | if count: | ||||
| result = re.sub(r"class MGB_DEFINE(.*){( *)\\", r"MGB_DEFINE\1\2 \\", result) | |||||
| result = re.sub( | |||||
| r"class MGB_DEFINE(.*){( *)\\", r"MGB_DEFINE\1\2 \\", result | |||||
| ) | |||||
| result = re.sub(r"class MGB_DEFINE((.|\n)*?){", r"MGB_DEFINE\1// {", result) | result = re.sub(r"class MGB_DEFINE((.|\n)*?){", r"MGB_DEFINE\1// {", result) | ||||
| if write and original_source != result: | if write and original_source != result: | ||||
| @@ -109,19 +115,17 @@ def main(): | |||||
| raise ValueError("Invalid path {}".format(path)) | raise ValueError("Invalid path {}".format(path)) | ||||
| # check version, we only support 12.0.1 now | # check version, we only support 12.0.1 now | ||||
| version = subprocess.check_output( | |||||
| [ | |||||
| args.clang_format, | |||||
| "--version", | |||||
| ], | |||||
| ) | |||||
| version = subprocess.check_output([args.clang_format, "--version",],) | |||||
| version = version.decode("utf-8") | version = version.decode("utf-8") | ||||
| need_version = '12.0.1' | |||||
| need_version = "12.0.1" | |||||
| if version.find(need_version) < 0: | if version.find(need_version) < 0: | ||||
| print('We only support {} now, please install {} version, find version: {}' | |||||
| .format(need_version, need_version, version)) | |||||
| raise RuntimeError('clang-format version not equal {}'.format(need_version)) | |||||
| print( | |||||
| "We only support {} now, please install {} version, find version: {}".format( | |||||
| need_version, need_version, version | |||||
| ) | |||||
| ) | |||||
| raise RuntimeError("clang-format version not equal {}".format(need_version)) | |||||
| process_map( | process_map( | ||||
| partial(process_file, clang_format=args.clang_format, write=args.write,), | partial(process_file, clang_format=args.clang_format, write=args.write,), | ||||
| @@ -20,6 +20,7 @@ device = { | |||||
| "thread_number": 3, | "thread_number": 3, | ||||
| } | } | ||||
| class SshConnector: | class SshConnector: | ||||
| """imp ssh control master connector""" | """imp ssh control master connector""" | ||||
| @@ -54,6 +55,7 @@ class SshConnector: | |||||
| except: | except: | ||||
| raise | raise | ||||
| def main(): | def main(): | ||||
| parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter) | parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter) | ||||
| parser.add_argument("--model_file", help="megengine model", required=True) | parser.add_argument("--model_file", help="megengine model", required=True) | ||||
| @@ -78,10 +80,10 @@ def main(): | |||||
| model_file = args.model_file | model_file = args.model_file | ||||
| # copy model file | # copy model file | ||||
| ssh.copy([model_file], workspace) | ssh.copy([model_file], workspace) | ||||
| m = model_file.split('\\')[-1] | |||||
| m = model_file.split("\\")[-1] | |||||
| # run single thread | # run single thread | ||||
| cmd = "cd {} && ./load_and_run {} --fast-run --record-comp-seq --iter 1 --warmup-iter 1".format( | cmd = "cd {} && ./load_and_run {} --fast-run --record-comp-seq --iter 1 --warmup-iter 1".format( | ||||
| workspace, m | |||||
| workspace, m | |||||
| ) | ) | ||||
| try: | try: | ||||
| raw_log = ssh.cmd([cmd]) | raw_log = ssh.cmd([cmd]) | ||||
| @@ -91,6 +93,7 @@ def main(): | |||||
| print("model: {} is static model.".format(m)) | print("model: {} is static model.".format(m)) | ||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s" | LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s" | ||||
| DATE_FORMAT = "%Y/%m/%d %H:%M:%S" | DATE_FORMAT = "%Y/%m/%d %H:%M:%S" | ||||