GitOrigin-RevId: 5684e5ea43
tags/v1.11.1
| @@ -8,7 +8,7 @@ | |||
| import enum | |||
| import os.path | |||
| import shutil | |||
| from typing import Tuple, List | |||
| from typing import List, Tuple | |||
| from library import * | |||
| @@ -5,14 +5,13 @@ | |||
| # | |||
| import enum | |||
| import os.path | |||
| import shutil | |||
| import functools | |||
| import operator | |||
| import os.path | |||
| import shutil | |||
| from library import * | |||
| ################################################################################################### | |||
| # | |||
| # Data structure modeling a GEMM operation | |||
| @@ -1,11 +1,11 @@ | |||
| from generator import ( | |||
| GenerateGemmOperations, | |||
| GenerateGemvOperations, | |||
| from generator import ( # isort: skip; isort: skip | |||
| GenerateConv2dOperations, | |||
| GenerateDeconvOperations, | |||
| GenerateDwconv2dFpropOperations, | |||
| GenerateDwconv2dDgradOperations, | |||
| GenerateDwconv2dFpropOperations, | |||
| GenerateDwconv2dWgradOperations, | |||
| GenerateGemmOperations, | |||
| GenerateGemvOperations, | |||
| ) | |||
| @@ -35,12 +35,14 @@ def write_op_list(f, gen_op, gen_type): | |||
| if gen_op != "gemv": | |||
| f.write(' "all_%s_%s_operations.cu",\n' % (gen_op, gen_type)) | |||
| # Write down a list of merged filenames | |||
| def write_merge_file_name(f, gen_op, gen_type, split_number): | |||
| for i in range(0, split_number): | |||
| f.write(' "{}_{}_{}.cu",\n'.format(gen_op,gen_type,i)) | |||
| f.write(' "{}_{}_{}.cu",\n'.format(gen_op, gen_type, i)) | |||
| if gen_op != "gemv": | |||
| f.write(' "all_{}_{}_operations.cu",\n'.format(gen_op,gen_type)) | |||
| f.write(' "all_{}_{}_operations.cu",\n'.format(gen_op, gen_type)) | |||
| if __name__ == "__main__": | |||
| with open("list.bzl", "w") as f: | |||
| @@ -4,12 +4,12 @@ | |||
| # \brief Generates the CUTLASS Library's instances | |||
| # | |||
| import argparse | |||
| import enum | |||
| import os.path | |||
| import shutil | |||
| import argparse | |||
| import platform | |||
| import string | |||
| from library import * | |||
| from manifest import * | |||
| @@ -899,9 +899,12 @@ def GenerateGemm_Simt(args): | |||
| warpShapes.append([warp0, warp1]) | |||
| # sgemm | |||
| precisionType, precisionBits, threadblockMaxElements, threadblockTilesL0 = precisions[ | |||
| "s" | |||
| ] | |||
| ( | |||
| precisionType, | |||
| precisionBits, | |||
| threadblockMaxElements, | |||
| threadblockTilesL0, | |||
| ) = precisions["s"] | |||
| layouts = [ | |||
| (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), # nn | |||
| @@ -1091,9 +1094,12 @@ def GenerateDwconv2d_Simt(args, conv_kind): | |||
| warpShapes.append([warp0, warp1]) | |||
| # sgemm | |||
| precisionType, precisionBits, threadblockMaxElements, threadblockTilesL0 = precisions[ | |||
| "s" | |||
| ] | |||
| ( | |||
| precisionType, | |||
| precisionBits, | |||
| threadblockMaxElements, | |||
| threadblockTilesL0, | |||
| ) = precisions["s"] | |||
| layouts = [(LayoutType.TensorNCHW, LayoutType.TensorNCHW)] | |||
| @@ -1304,7 +1310,7 @@ def GenerateDwconv2d_TensorOp_884(args, conv_kind): | |||
| for dst_type, dst_layout in zip(dst_types, dst_layouts): | |||
| for alignment_src in alignment_constraints: | |||
| if conv_kind == ConvKind.Wgrad: | |||
| # skip io16xc16 | |||
| # skip io16xc16 | |||
| if math_inst.element_accumulator == DataType.f16: | |||
| continue | |||
| for alignment_diff in alignment_constraints: | |||
| @@ -1319,7 +1325,7 @@ def GenerateDwconv2d_TensorOp_884(args, conv_kind): | |||
| min_cc, | |||
| alignment_src, | |||
| alignment_diff, | |||
| 32, # always f32 output | |||
| 32, # always f32 output | |||
| SpecialOptimizeDesc.NoneSpecialOpt, | |||
| ImplicitGemmMode.GemmNT, | |||
| False, | |||
| @@ -1656,6 +1662,7 @@ def GenerateGemvOperations(args): | |||
| ) | |||
| return GenerateGemv_Simt(args) | |||
| ################################################################################ | |||
| # parameters | |||
| # split_number - the concated file will be divided into split_number parts | |||
| @@ -1668,10 +1675,21 @@ def GenerateGemvOperations(args): | |||
| # epilogue - the epilogue in the file | |||
| # wrapper_path - wrapper path | |||
| ################################################################################ | |||
| def ConcatFile(split_number:int, file_path:str,operations:str,type:str,head:str,required_cuda_ver_major:str, required_cuda_ver_minor:str, epilogue:str, wrapper_path = None): | |||
| def ConcatFile( | |||
| split_number: int, | |||
| file_path: str, | |||
| operations: str, | |||
| type: str, | |||
| head: str, | |||
| required_cuda_ver_major: str, | |||
| required_cuda_ver_minor: str, | |||
| epilogue: str, | |||
| wrapper_path=None, | |||
| ): | |||
| import os | |||
| meragefiledir = file_path | |||
| filenames=os.listdir(meragefiledir) | |||
| filenames = os.listdir(meragefiledir) | |||
| # filter file | |||
| if "tensorop" in type: | |||
| sub_string_1 = "tensorop" | |||
| @@ -1679,197 +1697,183 @@ def ConcatFile(split_number:int, file_path:str,operations:str,type:str,head:str, | |||
| else: | |||
| sub_string_1 = sub_string_2 = "simt" | |||
| if "dwconv2d_" in operations: | |||
| filtered_operations = operations[:2]+operations[9:] | |||
| filtered_operations = operations[:2] + operations[9:] | |||
| elif ("conv2d" in operations) or ("deconv" in operations): | |||
| filtered_operations = "cutlass" | |||
| else: | |||
| filtered_operations = operations | |||
| #get the file list number | |||
| # get the file list number | |||
| file_list = {} | |||
| file_list[operations + type] = 0 | |||
| for filename in filenames: | |||
| if (filtered_operations in filename) and (sub_string_1 in filename) and (sub_string_2 in filename) and ("all_" not in filename): | |||
| if ( | |||
| (filtered_operations in filename) | |||
| and (sub_string_1 in filename) | |||
| and (sub_string_2 in filename) | |||
| and ("all_" not in filename) | |||
| ): | |||
| file_list[operations + type] += 1 | |||
| #concat file for linux | |||
| # concat file for linux | |||
| flag_1 = 0 | |||
| flag_2 = 0 | |||
| for filename in filenames: | |||
| if (filtered_operations in filename) and (sub_string_1 in filename) and (sub_string_2 in filename) and ("all_" not in filename): | |||
| if ( | |||
| (filtered_operations in filename) | |||
| and (sub_string_1 in filename) | |||
| and (sub_string_2 in filename) | |||
| and ("all_" not in filename) | |||
| ): | |||
| flag_1 += 1 | |||
| filepath=meragefiledir+'/'+filename | |||
| if (flag_1 >= flag_2 * (file_list[operations + type]/split_number)) and (flag_1 <= (flag_2 + 1) * (file_list[operations + type]/split_number)): | |||
| file =open(file_path + '/{}_{}_{}.cu'.format(operations,type, flag_2),'a') | |||
| #write Template at the head | |||
| filepath = meragefiledir + "/" + filename | |||
| if (flag_1 >= flag_2 * (file_list[operations + type] / split_number)) and ( | |||
| flag_1 <= (flag_2 + 1) * (file_list[operations + type] / split_number) | |||
| ): | |||
| file = open( | |||
| file_path + "/{}_{}_{}.cu".format(operations, type, flag_2), "a" | |||
| ) | |||
| # write Template at the head | |||
| if wrapper_path is None: | |||
| file.write( | |||
| SubstituteTemplate( | |||
| head, | |||
| { | |||
| "required_cuda_ver_major": str( | |||
| required_cuda_ver_major | |||
| ), | |||
| "required_cuda_ver_minor": str( | |||
| required_cuda_ver_minor | |||
| ), | |||
| "required_cuda_ver_major": str(required_cuda_ver_major), | |||
| "required_cuda_ver_minor": str(required_cuda_ver_minor), | |||
| }, | |||
| ) | |||
| ) | |||
| else: | |||
| file.write( | |||
| SubstituteTemplate( | |||
| head, | |||
| { | |||
| "wrapper_path": wrapper_path, | |||
| "required_cuda_ver_major": str( | |||
| required_cuda_ver_major | |||
| ), | |||
| "required_cuda_ver_minor": str( | |||
| required_cuda_ver_minor | |||
| ), | |||
| }, | |||
| ) | |||
| SubstituteTemplate( | |||
| head, | |||
| { | |||
| "wrapper_path": wrapper_path, | |||
| "required_cuda_ver_major": str(required_cuda_ver_major), | |||
| "required_cuda_ver_minor": str(required_cuda_ver_minor), | |||
| }, | |||
| ) | |||
| ) | |||
| # concat all the remaining files | |||
| if flag_2 == (split_number - 1): | |||
| for line in open(filepath): | |||
| file.writelines(line) | |||
| os.remove(filepath) | |||
| file.write('\n') | |||
| file.write("\n") | |||
| file.write(epilogue) | |||
| continue | |||
| for line in open(filepath): | |||
| file.writelines(line) | |||
| os.remove(filepath) | |||
| file.write('\n') | |||
| file.write("\n") | |||
| file.write(epilogue) | |||
| else: | |||
| #write Template at the head | |||
| # write Template at the head | |||
| if wrapper_path is None: | |||
| file.write( | |||
| SubstituteTemplate( | |||
| head, | |||
| { | |||
| "required_cuda_ver_major": str( | |||
| required_cuda_ver_major | |||
| ), | |||
| "required_cuda_ver_minor": str( | |||
| required_cuda_ver_minor | |||
| ), | |||
| "required_cuda_ver_major": str(required_cuda_ver_major), | |||
| "required_cuda_ver_minor": str(required_cuda_ver_minor), | |||
| }, | |||
| ) | |||
| ) | |||
| else: | |||
| file.write( | |||
| SubstituteTemplate( | |||
| head, | |||
| { | |||
| "wrapper_path": wrapper_path, | |||
| "required_cuda_ver_major": str( | |||
| required_cuda_ver_major | |||
| ), | |||
| "required_cuda_ver_minor": str( | |||
| required_cuda_ver_minor | |||
| ), | |||
| }, | |||
| ) | |||
| SubstituteTemplate( | |||
| head, | |||
| { | |||
| "wrapper_path": wrapper_path, | |||
| "required_cuda_ver_major": str(required_cuda_ver_major), | |||
| "required_cuda_ver_minor": str(required_cuda_ver_minor), | |||
| }, | |||
| ) | |||
| ) | |||
| for line in open(filepath): | |||
| file.writelines(line) | |||
| os.remove(filepath) | |||
| file.write('\n') | |||
| file.write("\n") | |||
| file.write(epilogue) | |||
| file.close() | |||
| flag_2 += 1 | |||
| #concat file for windows | |||
| # concat file for windows | |||
| elif filename[0].isdigit() and ("all_" not in filename): | |||
| flag_1 += 1 | |||
| filepath=meragefiledir+'/'+filename | |||
| if (flag_1 >= flag_2 * (len(filenames)/split_number)) and (flag_1 <= (flag_2 + 1) * (len(filenames)/split_number)): | |||
| file =open(file_path + '/{}_{}_{}.cu'.format(operations,type, flag_2),'a') | |||
| #write Template at the head | |||
| filepath = meragefiledir + "/" + filename | |||
| if (flag_1 >= flag_2 * (len(filenames) / split_number)) and ( | |||
| flag_1 <= (flag_2 + 1) * (len(filenames) / split_number) | |||
| ): | |||
| file = open( | |||
| file_path + "/{}_{}_{}.cu".format(operations, type, flag_2), "a" | |||
| ) | |||
| # write Template at the head | |||
| if wrapper_path is None: | |||
| file.write( | |||
| SubstituteTemplate( | |||
| head, | |||
| { | |||
| "required_cuda_ver_major": str( | |||
| required_cuda_ver_major | |||
| ), | |||
| "required_cuda_ver_minor": str( | |||
| required_cuda_ver_minor | |||
| ), | |||
| "required_cuda_ver_major": str(required_cuda_ver_major), | |||
| "required_cuda_ver_minor": str(required_cuda_ver_minor), | |||
| }, | |||
| ) | |||
| ) | |||
| else: | |||
| file.write( | |||
| SubstituteTemplate( | |||
| head, | |||
| { | |||
| "wrapper_path": wrapper_path, | |||
| "required_cuda_ver_major": str( | |||
| required_cuda_ver_major | |||
| ), | |||
| "required_cuda_ver_minor": str( | |||
| required_cuda_ver_minor | |||
| ), | |||
| }, | |||
| ) | |||
| SubstituteTemplate( | |||
| head, | |||
| { | |||
| "wrapper_path": wrapper_path, | |||
| "required_cuda_ver_major": str(required_cuda_ver_major), | |||
| "required_cuda_ver_minor": str(required_cuda_ver_minor), | |||
| }, | |||
| ) | |||
| ) | |||
| # concat all the remaining files | |||
| if flag_2 == (split_number - 1): | |||
| for line in open(filepath): | |||
| file.writelines(line) | |||
| os.remove(filepath) | |||
| file.write('\n') | |||
| file.write("\n") | |||
| file.write(epilogue) | |||
| continue | |||
| for line in open(filepath): | |||
| file.writelines(line) | |||
| os.remove(filepath) | |||
| file.write('\n') | |||
| file.write("\n") | |||
| file.write(epilogue) | |||
| else: | |||
| #write Template at the head | |||
| # write Template at the head | |||
| if wrapper_path is None: | |||
| file.write( | |||
| SubstituteTemplate( | |||
| head, | |||
| { | |||
| "required_cuda_ver_major": str( | |||
| required_cuda_ver_major | |||
| ), | |||
| "required_cuda_ver_minor": str( | |||
| required_cuda_ver_minor | |||
| ), | |||
| "required_cuda_ver_major": str(required_cuda_ver_major), | |||
| "required_cuda_ver_minor": str(required_cuda_ver_minor), | |||
| }, | |||
| ) | |||
| ) | |||
| else: | |||
| file.write( | |||
| SubstituteTemplate( | |||
| head, | |||
| { | |||
| "wrapper_path": wrapper_path, | |||
| "required_cuda_ver_major": str( | |||
| required_cuda_ver_major | |||
| ), | |||
| "required_cuda_ver_minor": str( | |||
| required_cuda_ver_minor | |||
| ), | |||
| }, | |||
| ) | |||
| SubstituteTemplate( | |||
| head, | |||
| { | |||
| "wrapper_path": wrapper_path, | |||
| "required_cuda_ver_major": str(required_cuda_ver_major), | |||
| "required_cuda_ver_minor": str(required_cuda_ver_minor), | |||
| }, | |||
| ) | |||
| ) | |||
| for line in open(filepath): | |||
| file.writelines(line) | |||
| os.remove(filepath) | |||
| file.write('\n') | |||
| file.write("\n") | |||
| file.write(epilogue) | |||
| file.close() | |||
| flag_2 += 1 | |||
| ################################################################################################### | |||
| ################################################################################################### | |||
| @@ -1940,39 +1944,97 @@ if __name__ == "__main__": | |||
| args.output, operation, short_path | |||
| ) as emitter: | |||
| emitter.emit() | |||
| head = EmitConvSingleKernelWrapper(args.output, operations[0], short_path).header_template | |||
| head = EmitConvSingleKernelWrapper( | |||
| args.output, operations[0], short_path | |||
| ).header_template | |||
| required_cuda_ver_major = operations[0].required_cuda_ver_major | |||
| required_cuda_ver_minor = operations[0].required_cuda_ver_minor | |||
| epilogue = EmitConvSingleKernelWrapper(args.output, operations[0], short_path).epilogue_template | |||
| epilogue = EmitConvSingleKernelWrapper( | |||
| args.output, operations[0], short_path | |||
| ).epilogue_template | |||
| if "tensorop" in args.type: | |||
| ConcatFile(4, args.output,args.operations, args.type, head,required_cuda_ver_major, required_cuda_ver_minor, epilogue) | |||
| ConcatFile( | |||
| 4, | |||
| args.output, | |||
| args.operations, | |||
| args.type, | |||
| head, | |||
| required_cuda_ver_major, | |||
| required_cuda_ver_minor, | |||
| epilogue, | |||
| ) | |||
| else: | |||
| ConcatFile(2, args.output,args.operations, args.type, head,required_cuda_ver_major, required_cuda_ver_minor, epilogue) | |||
| ConcatFile( | |||
| 2, | |||
| args.output, | |||
| args.operations, | |||
| args.type, | |||
| head, | |||
| required_cuda_ver_major, | |||
| required_cuda_ver_minor, | |||
| epilogue, | |||
| ) | |||
| elif args.operations == "gemm": | |||
| for operation in operations: | |||
| with EmitGemmSingleKernelWrapper( | |||
| args.output, operation, short_path | |||
| ) as emitter: | |||
| emitter.emit() | |||
| head = EmitGemmSingleKernelWrapper(args.output, operations[0], short_path).header_template | |||
| head = EmitGemmSingleKernelWrapper( | |||
| args.output, operations[0], short_path | |||
| ).header_template | |||
| required_cuda_ver_major = operations[0].required_cuda_ver_major | |||
| required_cuda_ver_minor = operations[0].required_cuda_ver_minor | |||
| epilogue = EmitGemmSingleKernelWrapper(args.output, operations[0], short_path).epilogue_template | |||
| epilogue = EmitGemmSingleKernelWrapper( | |||
| args.output, operations[0], short_path | |||
| ).epilogue_template | |||
| if args.type == "tensorop884": | |||
| ConcatFile(30, args.output, args.operations, args.type, head,required_cuda_ver_major, required_cuda_ver_minor, epilogue) | |||
| ConcatFile( | |||
| 30, | |||
| args.output, | |||
| args.operations, | |||
| args.type, | |||
| head, | |||
| required_cuda_ver_major, | |||
| required_cuda_ver_minor, | |||
| epilogue, | |||
| ) | |||
| else: | |||
| ConcatFile(2, args.output, args.operations, args.type, head,required_cuda_ver_major, required_cuda_ver_minor, epilogue) | |||
| ConcatFile( | |||
| 2, | |||
| args.output, | |||
| args.operations, | |||
| args.type, | |||
| head, | |||
| required_cuda_ver_major, | |||
| required_cuda_ver_minor, | |||
| epilogue, | |||
| ) | |||
| elif args.operations == "gemv": | |||
| for operation in operations: | |||
| with EmitGemvSingleKernelWrapper( | |||
| args.output, operation, gemv_wrapper_path, short_path | |||
| ) as emitter: | |||
| emitter.emit() | |||
| head = EmitGemvSingleKernelWrapper(args.output, operations[0], gemv_wrapper_path, short_path).header_template | |||
| head = EmitGemvSingleKernelWrapper( | |||
| args.output, operations[0], gemv_wrapper_path, short_path | |||
| ).header_template | |||
| required_cuda_ver_major = operations[0].required_cuda_ver_major | |||
| required_cuda_ver_minor = operations[0].required_cuda_ver_minor | |||
| epilogue = EmitGemvSingleKernelWrapper(args.output, operations[0], gemv_wrapper_path, short_path).epilogue_template | |||
| ConcatFile(2, args.output,args.operations, args.type, head,required_cuda_ver_major, required_cuda_ver_minor, epilogue, wrapper_path = gemv_wrapper_path) | |||
| epilogue = EmitGemvSingleKernelWrapper( | |||
| args.output, operations[0], gemv_wrapper_path, short_path | |||
| ).epilogue_template | |||
| ConcatFile( | |||
| 2, | |||
| args.output, | |||
| args.operations, | |||
| args.type, | |||
| head, | |||
| required_cuda_ver_major, | |||
| required_cuda_ver_minor, | |||
| epilogue, | |||
| wrapper_path=gemv_wrapper_path, | |||
| ) | |||
| if args.operations != "gemv": | |||
| GenerateManifest(args, operations, args.output) | |||
| @@ -4,11 +4,11 @@ | |||
| # \brief Generates the CUTLASS Library's instances | |||
| # | |||
| import enum | |||
| import re | |||
| ################################################################################################### | |||
| import enum | |||
| # The following block implements enum.auto() for Python 3.5 variants that don't include it such | |||
| # as the default 3.5.2 on Ubuntu 16.04. | |||
| @@ -8,9 +8,9 @@ import enum | |||
| import os.path | |||
| import shutil | |||
| from library import * | |||
| from gemm_operation import * | |||
| from conv2d_operation import * | |||
| from gemm_operation import * | |||
| from library import * | |||
| ################################################################################################### | |||
| @@ -1,59 +1,67 @@ | |||
| #!/usr/bin/env python3 | |||
| # -*- coding: utf-8 -*- | |||
| import os | |||
| import argparse | |||
| import os | |||
| from gen_elemwise_utils import DTYPES | |||
| def main(): | |||
| parser = argparse.ArgumentParser( | |||
| description='generate elemwise impl files', | |||
| formatter_class=argparse.ArgumentDefaultsHelpFormatter) | |||
| parser.add_argument('--type', type=str, choices=['cuda'], | |||
| default='cuda', | |||
| help='generate cuda cond take kernel file') | |||
| parser.add_argument('output', help='output directory') | |||
| description="generate elemwise impl files", | |||
| formatter_class=argparse.ArgumentDefaultsHelpFormatter, | |||
| ) | |||
| parser.add_argument( | |||
| "--type", | |||
| type=str, | |||
| choices=["cuda"], | |||
| default="cuda", | |||
| help="generate cuda cond take kernel file", | |||
| ) | |||
| parser.add_argument("output", help="output directory") | |||
| args = parser.parse_args() | |||
| if not os.path.isdir(args.output): | |||
| os.makedirs(args.output) | |||
| assert args.type =='cuda' | |||
| cpp_ext = 'cu' | |||
| assert args.type == "cuda" | |||
| cpp_ext = "cu" | |||
| for dtype in DTYPES.keys(): | |||
| fname = '{}.{}'.format(dtype, cpp_ext) | |||
| fname = "{}.{}".format(dtype, cpp_ext) | |||
| fname = os.path.join(args.output, fname) | |||
| with open(fname, 'w') as fout: | |||
| with open(fname, "w") as fout: | |||
| w = lambda s: print(s, file=fout) | |||
| w('// generated by gen_cond_take_kern_impls.py') | |||
| w("// generated by gen_cond_take_kern_impls.py") | |||
| w('#include "../kern.inl"') | |||
| w('') | |||
| if dtype == 'dt_float16' or dtype == 'dt_bfloat16': | |||
| w('#if !MEGDNN_DISABLE_FLOAT16') | |||
| w('namespace megdnn {') | |||
| w('namespace cuda {') | |||
| w('namespace cond_take {') | |||
| w('') | |||
| w('inst_genidx(::megdnn::dtype::{})'.format(DTYPES[dtype][0])) | |||
| w('#undef inst_genidx') | |||
| w('') | |||
| w('inst_copy(::megdnn::dtype::{})'.format(DTYPES[dtype][0])) | |||
| w('#undef inst_copy') | |||
| w('#undef inst_copy_') | |||
| w('') | |||
| w('} // cond_take') | |||
| w('} // cuda') | |||
| w('} // megdnn') | |||
| if dtype == 'dt_float16' or dtype == 'dt_bfloat16': | |||
| w('#endif') | |||
| print('generated {}'.format(fname)) | |||
| w("") | |||
| if dtype == "dt_float16" or dtype == "dt_bfloat16": | |||
| w("#if !MEGDNN_DISABLE_FLOAT16") | |||
| w("namespace megdnn {") | |||
| w("namespace cuda {") | |||
| w("namespace cond_take {") | |||
| w("") | |||
| w("inst_genidx(::megdnn::dtype::{})".format(DTYPES[dtype][0])) | |||
| w("#undef inst_genidx") | |||
| w("") | |||
| w("inst_copy(::megdnn::dtype::{})".format(DTYPES[dtype][0])) | |||
| w("#undef inst_copy") | |||
| w("#undef inst_copy_") | |||
| w("") | |||
| w("} // cond_take") | |||
| w("} // cuda") | |||
| w("} // megdnn") | |||
| if dtype == "dt_float16" or dtype == "dt_bfloat16": | |||
| w("#endif") | |||
| print("generated {}".format(fname)) | |||
| os.utime(args.output) | |||
| if __name__ == '__main__': | |||
| if __name__ == "__main__": | |||
| main() | |||
| @@ -1,37 +1,47 @@ | |||
| #!/usr/bin/env python3 | |||
| # -*- coding: utf-8 -*- | |||
| import os | |||
| import argparse | |||
| import itertools | |||
| import os | |||
| PREFIXES = { | |||
| "dp4a": [ | |||
| ("batch_conv_bias_int8_implicit_gemm_precomp_ncdiv4hw4", True), | |||
| ("batch_conv_bias_int8_gemm_ncdiv4hw4", False), | |||
| ("batch_conv_bias_int8_gemm_ncdiv4hw4_ldg_128", False), | |||
| ] | |||
| } | |||
| PREFIXES = {"dp4a": [("batch_conv_bias_int8_implicit_gemm_precomp_ncdiv4hw4", True), ("batch_conv_bias_int8_gemm_ncdiv4hw4", False), ("batch_conv_bias_int8_gemm_ncdiv4hw4_ldg_128", False)]} | |||
| ACTIVATIONS = {1: ("IDENTITY", "_id"), 2: ("RELU", "_relu"), 3: ("H_SWISH", "_hswish")} | |||
| ACTIVATIONS = {1: ("IDENTITY", "_id"), | |||
| 2: ("RELU", "_relu"), | |||
| 3: ("H_SWISH", "_hswish")} | |||
| BIASES = { | |||
| 1: ("PerElementBiasVisitor", "_per_elem"), | |||
| 2: ("PerChannelBiasVisitor", "_per_chan"), | |||
| } | |||
| BIASES = {1: ("PerElementBiasVisitor", "_per_elem"), | |||
| 2: ("PerChannelBiasVisitor", "_per_chan")} | |||
| SUFFIXES = {"dp4a": [""], "imma": [""]} | |||
| SUFFIXES = {"dp4a": [""], | |||
| "imma": [""]} | |||
| def main(): | |||
| parser = argparse.ArgumentParser( | |||
| description='generate cuda batch conv bias (dp4a/imma) kern impl files', | |||
| formatter_class=argparse.ArgumentDefaultsHelpFormatter) | |||
| parser.add_argument('--type', type=str, choices=['dp4a', | |||
| 'imma'], | |||
| default='dp4a', help='generate cuda conv bias kernel file') | |||
| parser.add_argument('output', help='output directory') | |||
| description="generate cuda batch conv bias (dp4a/imma) kern impl files", | |||
| formatter_class=argparse.ArgumentDefaultsHelpFormatter, | |||
| ) | |||
| parser.add_argument( | |||
| "--type", | |||
| type=str, | |||
| choices=["dp4a", "imma"], | |||
| default="dp4a", | |||
| help="generate cuda conv bias kernel file", | |||
| ) | |||
| parser.add_argument("output", help="output directory") | |||
| args = parser.parse_args() | |||
| if not os.path.isdir(args.output): | |||
| os.makedirs(args.output) | |||
| inst = ''' | |||
| inst = """ | |||
| template void megdnn::cuda::batch_conv_bias::do_PREFIXSUFFIX<BIAS, | |||
| IConvEpilogue<Activation<megdnn::param_enumv::BatchConvBias::NonlineMode::ACTIVATION>>>( | |||
| const int8_t* d_src, | |||
| @@ -41,7 +51,7 @@ template void megdnn::cuda::batch_conv_bias::do_PREFIXSUFFIX<BIAS, | |||
| const ConvParam& param, | |||
| float alpha, | |||
| float beta, | |||
| cudaStream_t stream);''' | |||
| cudaStream_t stream);""" | |||
| for prefix in PREFIXES[args.type]: | |||
| for suffix in SUFFIXES[args.type]: | |||
| @@ -52,17 +62,23 @@ template void megdnn::cuda::batch_conv_bias::do_PREFIXSUFFIX<BIAS, | |||
| fname = os.path.join(args.output, fname) | |||
| with open(fname, "w") as fout: | |||
| w = lambda s: print(s, file=fout) | |||
| w('// generated by gen_batch_cuda_conv_bias_kern_impls.py') | |||
| cur_inst = inst.replace("PREFIX", prefix[0]).replace("SUFFIX", suffix).replace("BIAS", bias[0]).replace("ACTIVATION", act[0]) | |||
| w("// generated by gen_batch_cuda_conv_bias_kern_impls.py") | |||
| cur_inst = ( | |||
| inst.replace("PREFIX", prefix[0]) | |||
| .replace("SUFFIX", suffix) | |||
| .replace("BIAS", bias[0]) | |||
| .replace("ACTIVATION", act[0]) | |||
| ) | |||
| if has_workspace: | |||
| cur_inst = cur_inst.replace("WORKSPACE", "\nint* d_workspace, ") | |||
| else: | |||
| cur_inst = cur_inst.replace("WORKSPACE", "") | |||
| cur_inst = cur_inst.replace("WORKSPACE", "") | |||
| w('#include "../{}{}.cuinl"'.format(prefix[0], suffix)) | |||
| w(cur_inst) | |||
| print('generated {}'.format(fname)) | |||
| print("generated {}".format(fname)) | |||
| os.utime(args.output) | |||
| if __name__ == '__main__': | |||
| if __name__ == "__main__": | |||
| main() | |||
| @@ -1,39 +1,57 @@ | |||
| #!/usr/bin/env python3 | |||
| # -*- coding: utf-8 -*- | |||
| import os | |||
| import argparse | |||
| import itertools | |||
| import os | |||
| PREFIXES = { | |||
| "dp4a": "conv_bias_int8_implicit_gemm_cdiv4hwn4", | |||
| "imma": "conv_bias_int8_implicit_gemm", | |||
| } | |||
| PREFIXES = {"dp4a": "conv_bias_int8_implicit_gemm_cdiv4hwn4", "imma": "conv_bias_int8_implicit_gemm"} | |||
| ACTIVATIONS = {1: ("IDENTITY", "_id"), 2: ("RELU", "_relu"), 3: ("H_SWISH", "_hswish")} | |||
| ACTIVATIONS = {1: ("IDENTITY", "_id"), | |||
| 2: ("RELU", "_relu"), | |||
| 3: ("H_SWISH", "_hswish")} | |||
| BIASES = { | |||
| 1: ("PerElementBiasVisitor", "_per_elem"), | |||
| 2: ("PerChannelBiasVisitor", "_per_chan"), | |||
| } | |||
| BIASES = {1: ("PerElementBiasVisitor", "_per_elem"), | |||
| 2: ("PerChannelBiasVisitor", "_per_chan")} | |||
| SUFFIXES = { | |||
| "dp4a": ["", "_ld_64bit", "_ld_64bit_unroll_width", "_unroll_width"], | |||
| "imma": [ | |||
| "_imma16x16x16_cdiv4hwn4", | |||
| "_imma8x32x16_cdiv4hwn4", | |||
| "_imma32x8x16_cdiv4hwn4", | |||
| "_imma16x16x16_cdiv4hwn4_reorder_filter", | |||
| "_imma8x32x16_cdiv4hwn4_reorder_filter", | |||
| "_imma32x8x16_cdiv4hwn4_reorder_filter", | |||
| "_imma16x16x16_cdiv4hwn4_unroll_width", | |||
| "_imma8x32x16_cdiv4hwn4_unroll_width", | |||
| "_imma32x8x16_cdiv4hwn4_unroll_width", | |||
| ], | |||
| } | |||
| SUFFIXES = {"dp4a": ["", "_ld_64bit", "_ld_64bit_unroll_width", "_unroll_width"], | |||
| "imma": ["_imma16x16x16_cdiv4hwn4", "_imma8x32x16_cdiv4hwn4", "_imma32x8x16_cdiv4hwn4", | |||
| "_imma16x16x16_cdiv4hwn4_reorder_filter", "_imma8x32x16_cdiv4hwn4_reorder_filter", "_imma32x8x16_cdiv4hwn4_reorder_filter", | |||
| "_imma16x16x16_cdiv4hwn4_unroll_width", "_imma8x32x16_cdiv4hwn4_unroll_width", "_imma32x8x16_cdiv4hwn4_unroll_width"]} | |||
| def main(): | |||
| parser = argparse.ArgumentParser( | |||
| description='generate cuda conv bias (dp4a/imma) kern impl files', | |||
| formatter_class=argparse.ArgumentDefaultsHelpFormatter) | |||
| parser.add_argument('--type', type=str, choices=['dp4a', | |||
| 'imma'], | |||
| default='dp4a', help='generate cuda conv bias kernel file') | |||
| parser.add_argument('output', help='output directory') | |||
| description="generate cuda conv bias (dp4a/imma) kern impl files", | |||
| formatter_class=argparse.ArgumentDefaultsHelpFormatter, | |||
| ) | |||
| parser.add_argument( | |||
| "--type", | |||
| type=str, | |||
| choices=["dp4a", "imma"], | |||
| default="dp4a", | |||
| help="generate cuda conv bias kernel file", | |||
| ) | |||
| parser.add_argument("output", help="output directory") | |||
| args = parser.parse_args() | |||
| if not os.path.isdir(args.output): | |||
| os.makedirs(args.output) | |||
| inst = ''' | |||
| inst = """ | |||
| template void megdnn::cuda::conv_bias_int8::do_PREFIXSUFFIX<BIAS, | |||
| IConvEpilogue<Activation<megdnn::param_enumv::ConvBias::NonlineMode::ACTIVATION>>>( | |||
| const int8_t* d_src, | |||
| @@ -43,7 +61,7 @@ template void megdnn::cuda::conv_bias_int8::do_PREFIXSUFFIX<BIAS, | |||
| const ConvParam& param, | |||
| float alpha, | |||
| float beta, | |||
| cudaStream_t stream);''' | |||
| cudaStream_t stream);""" | |||
| for suffix in SUFFIXES[args.type]: | |||
| for _, act in ACTIVATIONS.items(): | |||
| @@ -53,13 +71,19 @@ template void megdnn::cuda::conv_bias_int8::do_PREFIXSUFFIX<BIAS, | |||
| fname = os.path.join(args.output, fname) | |||
| with open(fname, "w") as fout: | |||
| w = lambda s: print(s, file=fout) | |||
| w('// generated by gen_cuda_conv_bias_kern_impls.py') | |||
| cur_inst = inst.replace("PREFIX", prefix).replace("SUFFIX", suffix).replace("BIAS", bias[0]).replace("ACTIVATION", act[0]) | |||
| w("// generated by gen_cuda_conv_bias_kern_impls.py") | |||
| cur_inst = ( | |||
| inst.replace("PREFIX", prefix) | |||
| .replace("SUFFIX", suffix) | |||
| .replace("BIAS", bias[0]) | |||
| .replace("ACTIVATION", act[0]) | |||
| ) | |||
| w('#include "../{}{}.cuinl"'.format(prefix, suffix)) | |||
| w(cur_inst) | |||
| print('generated {}'.format(fname)) | |||
| print("generated {}".format(fname)) | |||
| os.utime(args.output) | |||
| if __name__ == '__main__': | |||
| if __name__ == "__main__": | |||
| main() | |||
| @@ -1,34 +1,39 @@ | |||
| #!/usr/bin/env python3 | |||
| # -*- coding: utf-8 -*- | |||
| import os | |||
| import argparse | |||
| import os | |||
| from gen_elemwise_utils import ARITIES, MODES | |||
| def main(): | |||
| parser = argparse.ArgumentParser( | |||
| description='generate elemwise each mode', | |||
| formatter_class=argparse.ArgumentDefaultsHelpFormatter) | |||
| description="generate elemwise each mode", | |||
| formatter_class=argparse.ArgumentDefaultsHelpFormatter, | |||
| ) | |||
| parser.add_argument('output', help='output directory') | |||
| parser.add_argument("output", help="output directory") | |||
| args = parser.parse_args() | |||
| with open(args.output, 'w') as fout: | |||
| with open(args.output, "w") as fout: | |||
| w = lambda s: print(s, file=fout) | |||
| w('// generated by gen_elemwise_each_mode.py') | |||
| w("// generated by gen_elemwise_each_mode.py") | |||
| keys = list(MODES.keys()) | |||
| keys.sort() | |||
| for (anum, ctype) in keys: | |||
| w('#define MEGDNN_FOREACH_ELEMWISE_MODE_{}_{}(cb) \\'.format( | |||
| ARITIES[anum], ctype)) | |||
| w( | |||
| "#define MEGDNN_FOREACH_ELEMWISE_MODE_{}_{}(cb) \\".format( | |||
| ARITIES[anum], ctype | |||
| ) | |||
| ) | |||
| for mode in MODES[(anum, ctype)]: | |||
| w(' MEGDNN_ELEMWISE_MODE_ENABLE({}, cb) \\'.format(mode)) | |||
| w('') | |||
| w(" MEGDNN_ELEMWISE_MODE_ENABLE({}, cb) \\".format(mode)) | |||
| w("") | |||
| print('generated each_mode.inl') | |||
| print("generated each_mode.inl") | |||
| os.utime(args.output) | |||
| if __name__ == '__main__': | |||
| if __name__ == "__main__": | |||
| main() | |||
| @@ -1,56 +1,63 @@ | |||
| #!/usr/bin/env python3 | |||
| # -*- coding: utf-8 -*- | |||
| import os | |||
| import argparse | |||
| import itertools | |||
| import os | |||
| from gen_elemwise_utils import ARITIES, DTYPES, MODES | |||
| def main(): | |||
| parser = argparse.ArgumentParser( | |||
| description='generate elemwise impl files', | |||
| formatter_class=argparse.ArgumentDefaultsHelpFormatter) | |||
| parser.add_argument('--type', type=str, choices=['cuda', | |||
| 'hip', | |||
| 'cpp'], | |||
| default='cpp', help='generate cuda/hip kernel file') | |||
| parser.add_argument('output', help='output directory') | |||
| description="generate elemwise impl files", | |||
| formatter_class=argparse.ArgumentDefaultsHelpFormatter, | |||
| ) | |||
| parser.add_argument( | |||
| "--type", | |||
| type=str, | |||
| choices=["cuda", "hip", "cpp"], | |||
| default="cpp", | |||
| help="generate cuda/hip kernel file", | |||
| ) | |||
| parser.add_argument("output", help="output directory") | |||
| args = parser.parse_args() | |||
| if not os.path.isdir(args.output): | |||
| os.makedirs(args.output) | |||
| if args.type == 'cuda': | |||
| cpp_ext = 'cu' | |||
| elif args.type == 'hip': | |||
| cpp_ext = 'cpp.hip' | |||
| if args.type == "cuda": | |||
| cpp_ext = "cu" | |||
| elif args.type == "hip": | |||
| cpp_ext = "cpp.hip" | |||
| else: | |||
| assert args.type == 'cpp' | |||
| cpp_ext = 'cpp' | |||
| assert args.type == "cpp" | |||
| cpp_ext = "cpp" | |||
| for anum, ctype in itertools.product(ARITIES.keys(), DTYPES.keys()): | |||
| for mode in MODES[(anum, DTYPES[ctype][1])]: | |||
| formode = 'MEGDNN_ELEMWISE_MODE_ENABLE({}, cb)'.format(mode) | |||
| fname = '{}_{}.{}'.format(mode, ctype, cpp_ext) | |||
| formode = "MEGDNN_ELEMWISE_MODE_ENABLE({}, cb)".format(mode) | |||
| fname = "{}_{}.{}".format(mode, ctype, cpp_ext) | |||
| fname = os.path.join(args.output, fname) | |||
| with open(fname, 'w') as fout: | |||
| with open(fname, "w") as fout: | |||
| w = lambda s: print(s, file=fout) | |||
| w('// generated by gen_elemwise_kern_impls.py') | |||
| w("// generated by gen_elemwise_kern_impls.py") | |||
| if ctype == 'dt_float16' or ctype == 'dt_bfloat16': | |||
| w('#if !MEGDNN_DISABLE_FLOAT16') | |||
| if ctype == "dt_float16" or ctype == "dt_bfloat16": | |||
| w("#if !MEGDNN_DISABLE_FLOAT16") | |||
| w('#define KERN_IMPL_MODE(cb) {}'.format(formode)) | |||
| w('#define KERN_IMPL_ARITY {}'.format(anum)) | |||
| w('#define KERN_IMPL_CTYPE {}'.format(ctype)) | |||
| w("#define KERN_IMPL_MODE(cb) {}".format(formode)) | |||
| w("#define KERN_IMPL_ARITY {}".format(anum)) | |||
| w("#define KERN_IMPL_CTYPE {}".format(ctype)) | |||
| w('#include "../kern_impl.inl"') | |||
| if ctype == 'dt_float16' or ctype == 'dt_bfloat16': | |||
| w('#endif') | |||
| if ctype == "dt_float16" or ctype == "dt_bfloat16": | |||
| w("#endif") | |||
| print('generated {}'.format(fname)) | |||
| print("generated {}".format(fname)) | |||
| os.utime(args.output) | |||
| if __name__ == '__main__': | |||
| if __name__ == "__main__": | |||
| main() | |||
| @@ -1,52 +1,66 @@ | |||
| #!/usr/bin/env python3 | |||
| # -*- coding: utf-8 -*- | |||
| import os | |||
| import argparse | |||
| import itertools | |||
| from gen_elemwise_multi_type_utils import SUPPORT_DTYPES, MODES, SUPPORT_QINT32_DTYPES, QINT32_MODES | |||
| import os | |||
| from gen_elemwise_multi_type_utils import ( # isort: skip; isort: skip | |||
| MODES, | |||
| QINT32_MODES, | |||
| SUPPORT_DTYPES, | |||
| SUPPORT_QINT32_DTYPES, | |||
| ) | |||
| def generate(modes, support_dtypes, output, cpp_ext): | |||
| for anum, ctype in itertools.product(modes.keys(), support_dtypes): | |||
| print('{} : {}'.format(anum, ctype)) | |||
| print("{} : {}".format(anum, ctype)) | |||
| src_ctype = ctype[0] | |||
| dst_ctype = ctype[1] | |||
| for mode in modes[anum]: | |||
| formode = 'MEGDNN_ELEMWISE_MODE_ENABLE({}, cb)'.format(mode) | |||
| fname = '{}_{}_{}.{}'.format(mode, src_ctype, dst_ctype, cpp_ext) | |||
| formode = "MEGDNN_ELEMWISE_MODE_ENABLE({}, cb)".format(mode) | |||
| fname = "{}_{}_{}.{}".format(mode, src_ctype, dst_ctype, cpp_ext) | |||
| fname = os.path.join(output, fname) | |||
| with open(fname, 'w') as fout: | |||
| with open(fname, "w") as fout: | |||
| w = lambda s: print(s, file=fout) | |||
| w('// generated by gen_elemwise_multi_type_kern_impls.py') | |||
| w("// generated by gen_elemwise_multi_type_kern_impls.py") | |||
| w('#define KERN_IMPL_MODE(cb) {}'.format(formode)) | |||
| w('#define KERN_IMPL_ARITY {}'.format(anum)) | |||
| w('#define KERN_IMPL_STYPE {}'.format(src_ctype)) | |||
| w('#define KERN_IMPL_DTYPE {}'.format(dst_ctype)) | |||
| w("#define KERN_IMPL_MODE(cb) {}".format(formode)) | |||
| w("#define KERN_IMPL_ARITY {}".format(anum)) | |||
| w("#define KERN_IMPL_STYPE {}".format(src_ctype)) | |||
| w("#define KERN_IMPL_DTYPE {}".format(dst_ctype)) | |||
| w('#include "../kern_impl.inl"') | |||
| print('generated {}'.format(fname)) | |||
| print("generated {}".format(fname)) | |||
| def main(): | |||
| parser = argparse.ArgumentParser( | |||
| description='generate elemwise impl files', | |||
| formatter_class=argparse.ArgumentDefaultsHelpFormatter) | |||
| parser.add_argument('--type', type=str, choices=['cuda'], | |||
| default='cuda', help='generate cuda kernel file') | |||
| parser.add_argument('output', help='output directory') | |||
| description="generate elemwise impl files", | |||
| formatter_class=argparse.ArgumentDefaultsHelpFormatter, | |||
| ) | |||
| parser.add_argument( | |||
| "--type", | |||
| type=str, | |||
| choices=["cuda"], | |||
| default="cuda", | |||
| help="generate cuda kernel file", | |||
| ) | |||
| parser.add_argument("output", help="output directory") | |||
| args = parser.parse_args() | |||
| if not os.path.isdir(args.output): | |||
| os.makedirs(args.output) | |||
| assert args.type == 'cuda' | |||
| if args.type == 'cuda': | |||
| cpp_ext = 'cu' | |||
| assert args.type == "cuda" | |||
| if args.type == "cuda": | |||
| cpp_ext = "cu" | |||
| generate(MODES, SUPPORT_DTYPES, args.output, cpp_ext) | |||
| generate(QINT32_MODES, SUPPORT_QINT32_DTYPES, args.output, cpp_ext) | |||
| os.utime(args.output) | |||
| if __name__ == '__main__': | |||
| if __name__ == "__main__": | |||
| main() | |||
| @@ -1,48 +1,131 @@ | |||
| # As cuda currently do not support quint8, so we just ignore it. | |||
| SUPPORT_DTYPES = [('dt_qint8', 'dt_qint8')] | |||
| SUPPORT_QINT32_DTYPES = [('dt_qint32', 'dt_qint8'), ('dt_qint8', 'dt_qint32'), | |||
| ('dt_qint4', 'dt_qint32'), ('dt_quint4', 'dt_qint32')] | |||
| SUPPORT_DTYPES = [("dt_qint8", "dt_qint8")] | |||
| SUPPORT_QINT32_DTYPES = [ | |||
| ("dt_qint32", "dt_qint8"), | |||
| ("dt_qint8", "dt_qint32"), | |||
| ("dt_qint4", "dt_qint32"), | |||
| ("dt_quint4", "dt_qint32"), | |||
| ] | |||
| SUPPORT_DTYPES_Q4 = [('dt_qint4', 'dt_qint4'), ('dt_quint4', 'dt_quint4')] | |||
| SUPPORT_QINT32_DTYPES_Q4 = [('dt_qint32', 'dt_qint4'), ('dt_qint32', 'dt_quint4')] | |||
| SUPPORT_DTYPES_Q4 = [("dt_qint4", "dt_qint4"), ("dt_quint4", "dt_quint4")] | |||
| SUPPORT_QINT32_DTYPES_Q4 = [("dt_qint32", "dt_qint4"), ("dt_qint32", "dt_quint4")] | |||
| SUPPORT_ARRITY2_DTYPES = ['dt_int32', 'dt_uint8', 'dt_int8', 'dt_int16', 'dt_bool', 'dt_float32', | |||
| 'dt_float16', 'dt_bfloat16'] | |||
| SUPPORT_ARRITY1_DTYPES = ['dt_float32','dt_float16', 'dt_bfloat16'] | |||
| SUPPORT_ARRITY2_DTYPES = [ | |||
| "dt_int32", | |||
| "dt_uint8", | |||
| "dt_int8", | |||
| "dt_int16", | |||
| "dt_bool", | |||
| "dt_float32", | |||
| "dt_float16", | |||
| "dt_bfloat16", | |||
| ] | |||
| SUPPORT_ARRITY1_DTYPES = ["dt_float32", "dt_float16", "dt_bfloat16"] | |||
| MODES = { | |||
| 1: ['RELU', 'ABS', 'NEGATE', 'ACOS', 'ASIN', 'CEIL', 'COS', | |||
| 'EXP', 'EXPM1', 'FLOOR', 'LOG', 'LOG1P', 'SIGMOID', 'SIN', | |||
| 'TANH', 'FAST_TANH', 'ROUND', 'ERF', 'ERFINV', 'ERFC', | |||
| 'ERFCINV', 'H_SWISH', 'SILU', 'GELU'], | |||
| 2: ['ABS_GRAD', 'ADD', 'FLOOR_DIV', 'MAX', 'MIN', 'MOD', 'MUL', | |||
| 'SIGMOID_GRAD', 'SUB', 'SWITCH_GT0', 'TANH_GRAD', 'LT', | |||
| 'LEQ', 'EQ', 'FUSE_ADD_RELU', 'TRUE_DIV', 'POW', | |||
| 'LOG_SUM_EXP', 'FUSE_ADD_TANH', 'FAST_TANH_GRAD', | |||
| 'FUSE_ADD_SIGMOID', 'ATAN2', 'H_SWISH_GRAD', | |||
| 'FUSE_ADD_H_SWISH', 'SILU_GRAD', 'GELU_GRAD'], | |||
| 3: ['COND_LEQ_MOV', 'COND_LT_MOV', 'FUSE_MUL_ADD3'], | |||
| 1: [ | |||
| "RELU", | |||
| "ABS", | |||
| "NEGATE", | |||
| "ACOS", | |||
| "ASIN", | |||
| "CEIL", | |||
| "COS", | |||
| "EXP", | |||
| "EXPM1", | |||
| "FLOOR", | |||
| "LOG", | |||
| "LOG1P", | |||
| "SIGMOID", | |||
| "SIN", | |||
| "TANH", | |||
| "FAST_TANH", | |||
| "ROUND", | |||
| "ERF", | |||
| "ERFINV", | |||
| "ERFC", | |||
| "ERFCINV", | |||
| "H_SWISH", | |||
| "SILU", | |||
| "GELU", | |||
| ], | |||
| 2: [ | |||
| "ABS_GRAD", | |||
| "ADD", | |||
| "FLOOR_DIV", | |||
| "MAX", | |||
| "MIN", | |||
| "MOD", | |||
| "MUL", | |||
| "SIGMOID_GRAD", | |||
| "SUB", | |||
| "SWITCH_GT0", | |||
| "TANH_GRAD", | |||
| "LT", | |||
| "LEQ", | |||
| "EQ", | |||
| "FUSE_ADD_RELU", | |||
| "TRUE_DIV", | |||
| "POW", | |||
| "LOG_SUM_EXP", | |||
| "FUSE_ADD_TANH", | |||
| "FAST_TANH_GRAD", | |||
| "FUSE_ADD_SIGMOID", | |||
| "ATAN2", | |||
| "H_SWISH_GRAD", | |||
| "FUSE_ADD_H_SWISH", | |||
| "SILU_GRAD", | |||
| "GELU_GRAD", | |||
| ], | |||
| 3: ["COND_LEQ_MOV", "COND_LT_MOV", "FUSE_MUL_ADD3"], | |||
| } | |||
| QINT4_MODES = { | |||
| 1: ['RELU', 'ABS', 'NEGATE', 'CEIL', 'FLOOR', 'SIGMOID', | |||
| 'TANH', 'FAST_TANH', 'ROUND', 'H_SWISH'], | |||
| 2: ['ADD', 'MAX', 'MIN', 'MUL', 'SUB', 'SWITCH_GT0', | |||
| 'LT', 'LEQ', 'EQ', 'FUSE_ADD_RELU', 'FUSE_ADD_TANH', | |||
| 'FUSE_ADD_SIGMOID', 'FUSE_ADD_H_SWISH'], | |||
| 3: ['COND_LEQ_MOV', 'COND_LT_MOV', 'FUSE_MUL_ADD3'], | |||
| 1: [ | |||
| "RELU", | |||
| "ABS", | |||
| "NEGATE", | |||
| "CEIL", | |||
| "FLOOR", | |||
| "SIGMOID", | |||
| "TANH", | |||
| "FAST_TANH", | |||
| "ROUND", | |||
| "H_SWISH", | |||
| ], | |||
| 2: [ | |||
| "ADD", | |||
| "MAX", | |||
| "MIN", | |||
| "MUL", | |||
| "SUB", | |||
| "SWITCH_GT0", | |||
| "LT", | |||
| "LEQ", | |||
| "EQ", | |||
| "FUSE_ADD_RELU", | |||
| "FUSE_ADD_TANH", | |||
| "FUSE_ADD_SIGMOID", | |||
| "FUSE_ADD_H_SWISH", | |||
| ], | |||
| 3: ["COND_LEQ_MOV", "COND_LT_MOV", "FUSE_MUL_ADD3"], | |||
| } | |||
| QINT32_MODES = { | |||
| 1: ['RELU', 'SIGMOID', 'TANH', 'FAST_TANH', 'H_SWISH'], | |||
| 2: ['ADD', 'FUSE_ADD_RELU', 'FUSE_ADD_SIGMOID', | |||
| 'FUSE_ADD_TANH', 'FUSE_ADD_H_SWISH'] | |||
| 1: ["RELU", "SIGMOID", "TANH", "FAST_TANH", "H_SWISH"], | |||
| 2: [ | |||
| "ADD", | |||
| "FUSE_ADD_RELU", | |||
| "FUSE_ADD_SIGMOID", | |||
| "FUSE_ADD_TANH", | |||
| "FUSE_ADD_H_SWISH", | |||
| ], | |||
| } | |||
| ARRITY1_BOOL_MODES = { | |||
| 1: ['ISINF','ISNAN'], | |||
| 1: ["ISINF", "ISNAN"], | |||
| } | |||
| ARRITY2_BOOL_MODES = { | |||
| 2: ['EQ','LEQ','NEQ','LT'], | |||
| 2: ["EQ", "LEQ", "NEQ", "LT"], | |||
| } | |||
| @@ -1,52 +1,57 @@ | |||
| #!/usr/bin/env python3 | |||
| # -*- coding: utf-8 -*- | |||
| import os | |||
| import argparse | |||
| import os | |||
| from gen_elemwise_utils import DTYPES | |||
| def main(): | |||
| parser = argparse.ArgumentParser( | |||
| description='generate elemwise impl files', | |||
| formatter_class=argparse.ArgumentDefaultsHelpFormatter) | |||
| parser.add_argument('--type', type=str, choices=[ | |||
| 'cuda', | |||
| 'hip' | |||
| ], | |||
| default='cuda', | |||
| help='generate cuda/hip elemwise special kernel file') | |||
| parser.add_argument('output', help='output directory') | |||
| description="generate elemwise impl files", | |||
| formatter_class=argparse.ArgumentDefaultsHelpFormatter, | |||
| ) | |||
| parser.add_argument( | |||
| "--type", | |||
| type=str, | |||
| choices=["cuda", "hip"], | |||
| default="cuda", | |||
| help="generate cuda/hip elemwise special kernel file", | |||
| ) | |||
| parser.add_argument("output", help="output directory") | |||
| args = parser.parse_args() | |||
| if not os.path.isdir(args.output): | |||
| os.makedirs(args.output) | |||
| if args.type == 'cuda': | |||
| cpp_ext = 'cu' | |||
| if args.type == "cuda": | |||
| cpp_ext = "cu" | |||
| else: | |||
| assert args.type =='hip' | |||
| cpp_ext = 'cpp.hip' | |||
| assert args.type == "hip" | |||
| cpp_ext = "cpp.hip" | |||
| for dtype in DTYPES.keys(): | |||
| fname = 'special_{}.{}'.format(dtype, cpp_ext) | |||
| fname = "special_{}.{}".format(dtype, cpp_ext) | |||
| fname = os.path.join(args.output, fname) | |||
| with open(fname, 'w') as fout: | |||
| with open(fname, "w") as fout: | |||
| w = lambda s: print(s, file=fout) | |||
| w('// generated by gen_elemwise_special_kern_impls.py') | |||
| if dtype == 'dt_float16' or dtype == 'dt_bfloat16': | |||
| w('#if !MEGDNN_DISABLE_FLOAT16') | |||
| w("// generated by gen_elemwise_special_kern_impls.py") | |||
| if dtype == "dt_float16" or dtype == "dt_bfloat16": | |||
| w("#if !MEGDNN_DISABLE_FLOAT16") | |||
| w('#include "../special_kerns.inl"') | |||
| w('INST(::megdnn::dtype::{})'.format(DTYPES[dtype][0])) | |||
| w('#undef INST') | |||
| w('}') | |||
| w('}') | |||
| if dtype == 'dt_float16' or dtype == 'dt_bfloat16': | |||
| w('#endif') | |||
| w("INST(::megdnn::dtype::{})".format(DTYPES[dtype][0])) | |||
| w("#undef INST") | |||
| w("}") | |||
| w("}") | |||
| if dtype == "dt_float16" or dtype == "dt_bfloat16": | |||
| w("#endif") | |||
| print('generated {}'.format(fname)) | |||
| print("generated {}".format(fname)) | |||
| os.utime(args.output) | |||
| if __name__ == '__main__': | |||
| if __name__ == "__main__": | |||
| main() | |||
| @@ -1,35 +1,95 @@ | |||
| ARITIES = {1: "UNARY", 2: "BINARY", 3: "TERNARY"} | |||
| ARITIES = {1: 'UNARY', 2: 'BINARY', 3: 'TERNARY'} | |||
| DTYPES = {'dt_int32': ('Int32', 'INT'), | |||
| 'dt_uint8': ('Uint8', 'INT'), | |||
| 'dt_int8': ('Int8', 'INT'), | |||
| 'dt_int16': ('Int16', 'INT'), | |||
| 'dt_bool': ('Bool', 'BOOL'), | |||
| 'dt_float32': ('Float32', 'FLOAT'), | |||
| 'dt_float16': ('Float16', 'FLOAT'), | |||
| 'dt_bfloat16': ('BFloat16', 'FLOAT') | |||
| } | |||
| DTYPES = { | |||
| "dt_int32": ("Int32", "INT"), | |||
| "dt_uint8": ("Uint8", "INT"), | |||
| "dt_int8": ("Int8", "INT"), | |||
| "dt_int16": ("Int16", "INT"), | |||
| "dt_bool": ("Bool", "BOOL"), | |||
| "dt_float32": ("Float32", "FLOAT"), | |||
| "dt_float16": ("Float16", "FLOAT"), | |||
| "dt_bfloat16": ("BFloat16", "FLOAT"), | |||
| } | |||
| MODES = { | |||
| (1, 'INT'): ['RELU', 'ABS', 'NEGATE'], | |||
| (2, 'INT'): ['ABS_GRAD', 'ADD', 'FLOOR_DIV', 'MAX', 'MIN', 'MOD', 'MUL', | |||
| 'SIGMOID_GRAD', 'SUB', 'SWITCH_GT0', 'TANH_GRAD', 'LT', 'LEQ', | |||
| 'EQ', 'FUSE_ADD_RELU', 'SHL', 'SHR', 'RMULH'], | |||
| (3, 'INT'): ['COND_LEQ_MOV', 'COND_LT_MOV'], | |||
| (1, 'FLOAT'): ['RELU', 'ABS', 'NEGATE', 'ACOS', 'ASIN', 'CEIL', 'COS', | |||
| 'EXP', 'EXPM1', 'FLOOR', 'LOG', 'LOG1P', 'SIGMOID', 'SIN', | |||
| 'TANH', 'FAST_TANH', 'ROUND', 'ERF', 'ERFINV', 'ERFC', | |||
| 'ERFCINV', 'H_SWISH', 'SILU', 'GELU'], | |||
| (2, 'FLOAT'): ['ABS_GRAD', 'ADD', 'FLOOR_DIV', 'MAX', 'MIN', 'MOD', 'MUL', | |||
| 'SIGMOID_GRAD', 'SUB', 'SWITCH_GT0', 'TANH_GRAD', 'LT', | |||
| 'LEQ', 'EQ', 'FUSE_ADD_RELU', 'TRUE_DIV', 'POW', | |||
| 'LOG_SUM_EXP', 'FUSE_ADD_TANH', 'FAST_TANH_GRAD', | |||
| 'FUSE_ADD_SIGMOID', 'ATAN2', 'H_SWISH_GRAD', | |||
| 'FUSE_ADD_H_SWISH', 'SILU_GRAD', 'GELU_GRAD'], | |||
| (3, 'FLOAT'): ['COND_LEQ_MOV', 'COND_LT_MOV', 'FUSE_MUL_ADD3'], | |||
| (1, 'BOOL'): ['NOT'], | |||
| (2, 'BOOL'): ['AND', 'OR', 'XOR', 'LT', 'LEQ', 'EQ'], | |||
| (3, 'BOOL'): [] | |||
| (1, "INT"): ["RELU", "ABS", "NEGATE"], | |||
| (2, "INT"): [ | |||
| "ABS_GRAD", | |||
| "ADD", | |||
| "FLOOR_DIV", | |||
| "MAX", | |||
| "MIN", | |||
| "MOD", | |||
| "MUL", | |||
| "SIGMOID_GRAD", | |||
| "SUB", | |||
| "SWITCH_GT0", | |||
| "TANH_GRAD", | |||
| "LT", | |||
| "LEQ", | |||
| "EQ", | |||
| "FUSE_ADD_RELU", | |||
| "SHL", | |||
| "SHR", | |||
| "RMULH", | |||
| ], | |||
| (3, "INT"): ["COND_LEQ_MOV", "COND_LT_MOV"], | |||
| (1, "FLOAT"): [ | |||
| "RELU", | |||
| "ABS", | |||
| "NEGATE", | |||
| "ACOS", | |||
| "ASIN", | |||
| "CEIL", | |||
| "COS", | |||
| "EXP", | |||
| "EXPM1", | |||
| "FLOOR", | |||
| "LOG", | |||
| "LOG1P", | |||
| "SIGMOID", | |||
| "SIN", | |||
| "TANH", | |||
| "FAST_TANH", | |||
| "ROUND", | |||
| "ERF", | |||
| "ERFINV", | |||
| "ERFC", | |||
| "ERFCINV", | |||
| "H_SWISH", | |||
| "SILU", | |||
| "GELU", | |||
| ], | |||
| (2, "FLOAT"): [ | |||
| "ABS_GRAD", | |||
| "ADD", | |||
| "FLOOR_DIV", | |||
| "MAX", | |||
| "MIN", | |||
| "MOD", | |||
| "MUL", | |||
| "SIGMOID_GRAD", | |||
| "SUB", | |||
| "SWITCH_GT0", | |||
| "TANH_GRAD", | |||
| "LT", | |||
| "LEQ", | |||
| "EQ", | |||
| "FUSE_ADD_RELU", | |||
| "TRUE_DIV", | |||
| "POW", | |||
| "LOG_SUM_EXP", | |||
| "FUSE_ADD_TANH", | |||
| "FAST_TANH_GRAD", | |||
| "FUSE_ADD_SIGMOID", | |||
| "ATAN2", | |||
| "H_SWISH_GRAD", | |||
| "FUSE_ADD_H_SWISH", | |||
| "SILU_GRAD", | |||
| "GELU_GRAD", | |||
| ], | |||
| (3, "FLOAT"): ["COND_LEQ_MOV", "COND_LT_MOV", "FUSE_MUL_ADD3"], | |||
| (1, "BOOL"): ["NOT"], | |||
| (2, "BOOL"): ["AND", "OR", "XOR", "LT", "LEQ", "EQ"], | |||
| (3, "BOOL"): [], | |||
| } | |||
| @@ -3,13 +3,14 @@ | |||
| import argparse | |||
| import collections | |||
| import textwrap | |||
| import os | |||
| import hashlib | |||
| import struct | |||
| import io | |||
| import os | |||
| import struct | |||
| import textwrap | |||
| from gen_param_defs import IndentWriterBase, ParamDef, member_defs | |||
| from gen_param_defs import member_defs, ParamDef, IndentWriterBase | |||
| class ConverterWriter(IndentWriterBase): | |||
| _skip_current_param = False | |||
| @@ -20,7 +21,7 @@ class ConverterWriter(IndentWriterBase): | |||
| def __call__(self, fout, defs): | |||
| super().__call__(fout) | |||
| self._write("// %s", self._get_header()) | |||
| self._write('#include <flatbuffers/flatbuffers.h>') | |||
| self._write("#include <flatbuffers/flatbuffers.h>") | |||
| self._write("namespace mgb {") | |||
| self._write("namespace serialization {") | |||
| self._write("namespace fbs {") | |||
| @@ -33,8 +34,9 @@ class ConverterWriter(IndentWriterBase): | |||
| self._last_param = p | |||
| self._param_fields = [] | |||
| self._fb_fields = ["builder"] | |||
| self._write("template<>\nstruct ParamConverter<megdnn::param::%s> {", | |||
| p.name, indent=1) | |||
| self._write( | |||
| "template<>\nstruct ParamConverter<megdnn::param::%s> {", p.name, indent=1 | |||
| ) | |||
| self._write("using MegDNNType = megdnn::param::%s;", p.name) | |||
| self._write("using FlatBufferType = fbs::param::%s;\n", p.name) | |||
| @@ -42,22 +44,22 @@ class ConverterWriter(IndentWriterBase): | |||
| if self._skip_current_param: | |||
| self._skip_current_param = False | |||
| return | |||
| self._write("static MegDNNType to_param(const FlatBufferType* fb) {", | |||
| indent=1) | |||
| line = 'return {' | |||
| line += ', '.join(self._param_fields) | |||
| line += '};' | |||
| self._write("static MegDNNType to_param(const FlatBufferType* fb) {", indent=1) | |||
| line = "return {" | |||
| line += ", ".join(self._param_fields) | |||
| line += "};" | |||
| self._write(line) | |||
| self._write("}\n", indent=-1) | |||
| self._write( | |||
| "static flatbuffers::Offset<FlatBufferType> to_flatbuffer(flatbuffers::FlatBufferBuilder& builder, const MegDNNType& param) {", | |||
| indent=1) | |||
| line = 'return fbs::param::Create{}('.format(str(p.name)) | |||
| line += ', '.join(self._fb_fields) | |||
| line += ');' | |||
| indent=1, | |||
| ) | |||
| line = "return fbs::param::Create{}(".format(str(p.name)) | |||
| line += ", ".join(self._fb_fields) | |||
| line += ");" | |||
| self._write(line) | |||
| self._write('}', indent=-1) | |||
| self._write("}", indent=-1) | |||
| self._write("};\n", indent=-1) | |||
| @@ -68,18 +70,23 @@ class ConverterWriter(IndentWriterBase): | |||
| return | |||
| self._param_fields.append( | |||
| "static_cast<megdnn::param::{}::{}>(fb->{}())".format( | |||
| str(p.name), str(e.name), e.name_field)) | |||
| self._fb_fields.append("static_cast<fbs::param::{}>(param.{})".format( | |||
| key, e.name_field)) | |||
| str(p.name), str(e.name), e.name_field | |||
| ) | |||
| ) | |||
| self._fb_fields.append( | |||
| "static_cast<fbs::param::{}>(param.{})".format(key, e.name_field) | |||
| ) | |||
| def _on_member_field(self, f): | |||
| if self._skip_current_param: | |||
| return | |||
| if f.dtype.cname == 'DTypeEnum': | |||
| if f.dtype.cname == "DTypeEnum": | |||
| self._param_fields.append( | |||
| "intl::convert_dtype_to_megdnn(fb->{}())".format(f.name)) | |||
| "intl::convert_dtype_to_megdnn(fb->{}())".format(f.name) | |||
| ) | |||
| self._fb_fields.append( | |||
| "intl::convert_dtype_to_fbs(param.{})".format(f.name)) | |||
| "intl::convert_dtype_to_fbs(param.{})".format(f.name) | |||
| ) | |||
| else: | |||
| self._param_fields.append("fb->{}()".format(f.name)) | |||
| self._fb_fields.append("param.{}".format(f.name)) | |||
| @@ -93,28 +100,33 @@ class ConverterWriter(IndentWriterBase): | |||
| enum_name = e.src_class + e.src_name | |||
| self._param_fields.append( | |||
| "static_cast<megdnn::param::{}::{}>(fb->{}())".format( | |||
| e.src_class, e.src_name, e.name_field)) | |||
| self._fb_fields.append("static_cast<fbs::param::{}>(param.{})".format( | |||
| enum_name, e.name_field)) | |||
| e.src_class, e.src_name, e.name_field | |||
| ) | |||
| ) | |||
| self._fb_fields.append( | |||
| "static_cast<fbs::param::{}>(param.{})".format(enum_name, e.name_field) | |||
| ) | |||
| def main(): | |||
| parser = argparse.ArgumentParser( | |||
| 'generate convert functions between FlatBuffers type and MegBrain type') | |||
| parser.add_argument('input') | |||
| parser.add_argument('output') | |||
| "generate convert functions between FlatBuffers type and MegBrain type" | |||
| ) | |||
| parser.add_argument("input") | |||
| parser.add_argument("output") | |||
| args = parser.parse_args() | |||
| with open(args.input) as fin: | |||
| inputs = fin.read() | |||
| exec(inputs, {'pdef': ParamDef, 'Doc': member_defs.Doc}) | |||
| exec(inputs, {"pdef": ParamDef, "Doc": member_defs.Doc}) | |||
| input_hash = hashlib.sha256() | |||
| input_hash.update(inputs.encode(encoding='UTF-8')) | |||
| input_hash.update(inputs.encode(encoding="UTF-8")) | |||
| input_hash = input_hash.hexdigest() | |||
| writer = ConverterWriter() | |||
| with open(args.output, 'w') as fout: | |||
| with open(args.output, "w") as fout: | |||
| writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs) | |||
| if __name__ == "__main__": | |||
| main() | |||
| @@ -3,13 +3,14 @@ | |||
| import argparse | |||
| import collections | |||
| import textwrap | |||
| import os | |||
| import hashlib | |||
| import struct | |||
| import io | |||
| import os | |||
| import struct | |||
| import textwrap | |||
| from gen_param_defs import IndentWriterBase, ParamDef, member_defs | |||
| from gen_param_defs import member_defs, ParamDef, IndentWriterBase | |||
| def _cname_to_fbname(cname): | |||
| return { | |||
| @@ -22,17 +23,19 @@ def _cname_to_fbname(cname): | |||
| "bool": "bool", | |||
| }[cname] | |||
| def scramble_enum_member_name(name): | |||
| s = name.find('<<') | |||
| s = name.find("<<") | |||
| if s != -1: | |||
| name = name[0:name.find('=') + 1] + ' ' + name[s+2:] | |||
| name = name[0 : name.find("=") + 1] + " " + name[s + 2 :] | |||
| if name in ("MIN", "MAX"): | |||
| return name + "_" | |||
| o_name = name.split(' ')[0].split('=')[0] | |||
| o_name = name.split(" ")[0].split("=")[0] | |||
| if o_name in ("MIN", "MAX"): | |||
| return name.replace(o_name, o_name + "_") | |||
| return name | |||
| class FlatBuffersWriter(IndentWriterBase): | |||
| _skip_current_param = False | |||
| _last_param = None | |||
| @@ -66,12 +69,13 @@ class FlatBuffersWriter(IndentWriterBase): | |||
| self._write("}\n", indent=-1) | |||
| def _write_doc(self, doc): | |||
| if not isinstance(doc, member_defs.Doc) or not doc.doc: return | |||
| if not isinstance(doc, member_defs.Doc) or not doc.doc: | |||
| return | |||
| doc_lines = [] | |||
| if doc.no_reformat: | |||
| doc_lines = doc.raw_lines | |||
| else: | |||
| doc = doc.doc.replace('\n', ' ') | |||
| doc = doc.doc.replace("\n", " ") | |||
| text_width = 80 - len(self._cur_indent) - 4 | |||
| doc_lines = textwrap.wrap(doc, text_width) | |||
| for line in doc_lines: | |||
| @@ -101,7 +105,8 @@ class FlatBuffersWriter(IndentWriterBase): | |||
| default = e.compose_combined_enum(e.default) | |||
| else: | |||
| default = scramble_enum_member_name( | |||
| str(e.members[e.default]).split(' ')[0].split('=')[0]) | |||
| str(e.members[e.default]).split(" ")[0].split("=")[0] | |||
| ) | |||
| self._write("%s:%s%s = %s;", e.name_field, p.name, e.name, default) | |||
| def _resolve_const(self, v): | |||
| @@ -113,8 +118,12 @@ class FlatBuffersWriter(IndentWriterBase): | |||
| if self._skip_current_param: | |||
| return | |||
| self._write_doc(f.name) | |||
| self._write("%s:%s = %s;", f.name, _cname_to_fbname(f.dtype.cname), | |||
| self._get_fb_default(self._resolve_const(f.default))) | |||
| self._write( | |||
| "%s:%s = %s;", | |||
| f.name, | |||
| _cname_to_fbname(f.dtype.cname), | |||
| self._get_fb_default(self._resolve_const(f.default)), | |||
| ) | |||
| def _on_const_field(self, f): | |||
| self._cur_const_val[str(f.name)] = str(f.default) | |||
| @@ -129,7 +138,8 @@ class FlatBuffersWriter(IndentWriterBase): | |||
| default = s.compose_combined_enum(e.get_default()) | |||
| else: | |||
| default = scramble_enum_member_name( | |||
| str(s.members[e.get_default()]).split(' ')[0].split('=')[0]) | |||
| str(s.members[e.get_default()]).split(" ")[0].split("=")[0] | |||
| ) | |||
| self._write("%s:%s = %s;", e.name_field, enum_name, default) | |||
| def _get_fb_default(self, cppdefault): | |||
| @@ -137,9 +147,9 @@ class FlatBuffersWriter(IndentWriterBase): | |||
| return cppdefault | |||
| d = cppdefault | |||
| if d.endswith('f'): # 1.f | |||
| if d.endswith("f"): # 1.f | |||
| return d[:-1] | |||
| if d.endswith('ull'): | |||
| if d.endswith("ull"): | |||
| return d[:-3] | |||
| if d.startswith("DTypeEnum::"): | |||
| return d[11:] | |||
| @@ -148,21 +158,23 @@ class FlatBuffersWriter(IndentWriterBase): | |||
| def main(): | |||
| parser = argparse.ArgumentParser( | |||
| 'generate FlatBuffers schema of operator param from description file') | |||
| parser.add_argument('input') | |||
| parser.add_argument('output') | |||
| "generate FlatBuffers schema of operator param from description file" | |||
| ) | |||
| parser.add_argument("input") | |||
| parser.add_argument("output") | |||
| args = parser.parse_args() | |||
| with open(args.input) as fin: | |||
| inputs = fin.read() | |||
| exec(inputs, {'pdef': ParamDef, 'Doc': member_defs.Doc}) | |||
| exec(inputs, {"pdef": ParamDef, "Doc": member_defs.Doc}) | |||
| input_hash = hashlib.sha256() | |||
| input_hash.update(inputs.encode(encoding='UTF-8')) | |||
| input_hash.update(inputs.encode(encoding="UTF-8")) | |||
| input_hash = input_hash.hexdigest() | |||
| writer = FlatBuffersWriter() | |||
| with open(args.output, 'w') as fout: | |||
| with open(args.output, "w") as fout: | |||
| writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs) | |||
| if __name__ == "__main__": | |||
| main() | |||
| @@ -1,14 +1,16 @@ | |||
| #! /usr/local/env python3 | |||
| import pickle | |||
| import numpy as np | |||
| import os | |||
| import argparse | |||
| import re | |||
| import collections | |||
| import os | |||
| import pickle | |||
| import re | |||
| import numpy as np | |||
| def define_template(**kwargs): | |||
| template = ''' | |||
| template = """ | |||
| float cuda{cuda_arch}_{conv_type}_time_pred[{out_dim}] = {{0.0f}}; | |||
| float cuda{cuda_arch}_{conv_type}_mask[{out_dim}] = {{0.0f}}; | |||
| float cuda{cuda_arch}_{conv_type}_hidden_units[{hidden_num}] = {{0.0f}}; | |||
| @@ -17,21 +19,23 @@ def define_template(**kwargs): | |||
| const static float cuda{cuda_arch}_{conv_type}_biases[{biases_dim}] = {{{biases}}}; | |||
| const static float cuda{cuda_arch}_{conv_type}_alpha[{out_dim}] = {{{alpha}}}; | |||
| const static float cuda{cuda_arch}_{conv_type}_beta[{out_dim}] = {{{beta}}}; | |||
| ''' | |||
| """ | |||
| return template.format(**kwargs) | |||
| def cudnn_slt_template(**kwargs): | |||
| template = ("#if CUDNN_MAJOR == {cudnn_major} && CUDNN_MINOR == {cudnn_minor}\n" + | |||
| " {define_cmd}\n" + | |||
| " {select_cmd}\n" + | |||
| " return true;\n" + | |||
| "#endif\n" | |||
| ) | |||
| template = ( | |||
| "#if CUDNN_MAJOR == {cudnn_major} && CUDNN_MINOR == {cudnn_minor}\n" | |||
| + " {define_cmd}\n" | |||
| + " {select_cmd}\n" | |||
| + " return true;\n" | |||
| + "#endif\n" | |||
| ) | |||
| return template.format(**kwargs) | |||
| def select_template(**kwargs): | |||
| template = \ | |||
| '''if (conv_type == ConvolutionType::{conv_type} && cuda_major == {cuda_major} && | |||
| template = """if (conv_type == ConvolutionType::{conv_type} && cuda_major == {cuda_major} && | |||
| cuda_minor == {cuda_minor}) {{ | |||
| *layer_num_p = {layer_num}; | |||
| *hidden_units_p = cuda{cuda_arch}_{conv_type}_hidden_units; | |||
| @@ -42,7 +46,7 @@ def select_template(**kwargs): | |||
| *beta_p = cuda{cuda_arch}_{conv_type}_beta; | |||
| *time_pred_p = cuda{cuda_arch}_{conv_type}_time_pred; | |||
| *mask_p = cuda{cuda_arch}_{conv_type}_mask; | |||
| }} else ''' | |||
| }} else """ | |||
| return template.format(**kwargs) | |||
| @@ -58,48 +62,48 @@ def fill_src(): | |||
| if len(matrix_files) == 0: | |||
| print("Warning: no param files detected.") | |||
| for fpath in matrix_files: | |||
| cudnn_version = re.findall('cudnn([\d.]+)',fpath)[0] | |||
| cudnn_version = re.findall("cudnn([\d.]+)", fpath)[0] | |||
| gen_list[cudnn_version].append(fpath) | |||
| for cudnn in gen_list: | |||
| select_cmd = ("{\n" + | |||
| " " * 8 + "return false;\n" + | |||
| " " * 4 + "}") | |||
| select_cmd = "{\n" + " " * 8 + "return false;\n" + " " * 4 + "}" | |||
| define_cmd = "" | |||
| cudnn_major, cudnn_minor = cudnn.split('.') | |||
| cudnn_major, cudnn_minor = cudnn.split(".") | |||
| for fpath in gen_list[cudnn]: | |||
| cuda_arch = fpath.split("-")[1].replace(".", "_") | |||
| print('cudnn_version: {}, cuda_arch: {}'.format(cudnn,cuda_arch)) | |||
| print("cudnn_version: {}, cuda_arch: {}".format(cudnn, cuda_arch)) | |||
| conv_type = fpath.split("-")[2].split(".")[0] | |||
| with open(os.path.join(home, "params/{}".format(fpath)), "rb") as pobj: | |||
| params = pickle.load(pobj) | |||
| crt_define_cmd, crt_select_cmd = gen_cmds( | |||
| cuda_arch, conv_type, params) | |||
| crt_define_cmd, crt_select_cmd = gen_cmds(cuda_arch, conv_type, params) | |||
| select_cmd = crt_select_cmd + select_cmd | |||
| define_cmd = crt_define_cmd + define_cmd | |||
| cudnn_slt_cmd += cudnn_slt_template(cudnn_major=cudnn_major, | |||
| cudnn_minor=cudnn_minor, | |||
| select_cmd=select_cmd, | |||
| define_cmd=define_cmd) | |||
| cudnn_slt_cmd += cudnn_slt_template( | |||
| cudnn_major=cudnn_major, | |||
| cudnn_minor=cudnn_minor, | |||
| select_cmd=select_cmd, | |||
| define_cmd=define_cmd, | |||
| ) | |||
| #select_cmd = select_cmd | |||
| # select_cmd = select_cmd | |||
| with open(os.path.join(home, "get_params.template"), "r") as srcf: | |||
| src = srcf.read() | |||
| dst = src.replace("{cudnn_select}", cudnn_slt_cmd) | |||
| MegDNN_path = os.path.join(home, "../..") | |||
| with open(os.path.join(MegDNN_path, | |||
| "src/cuda/convolution/get_params.cpp"), "w") as dstf: | |||
| with open( | |||
| os.path.join(MegDNN_path, "src/cuda/convolution/get_params.cpp"), "w" | |||
| ) as dstf: | |||
| dstf.write(dst) | |||
| def gen_cmds(cuda_arch, conv_type, params): | |||
| cuda_major, cuda_minor = cuda_arch.split("_") | |||
| alphastr = format_array(params['alpha']).rstrip()[:-1] | |||
| betastr = format_array(params['beta']).rstrip()[:-1] | |||
| W_list = params['W'] | |||
| b_list = params['b'] | |||
| Wstr = '' | |||
| bstr = '' | |||
| alphastr = format_array(params["alpha"]).rstrip()[:-1] | |||
| betastr = format_array(params["beta"]).rstrip()[:-1] | |||
| W_list = params["W"] | |||
| b_list = params["b"] | |||
| Wstr = "" | |||
| bstr = "" | |||
| layer_num = str(len(b_list) + 1) | |||
| layers_dim = [W_list[0].shape[1]] | |||
| matrices_dim = 0 | |||
| @@ -118,16 +122,27 @@ def gen_cmds(cuda_arch, conv_type, params): | |||
| out_dim = layers_dim[-1] | |||
| layers_dim_str = format_array(np.array(layers_dim)).rstrip()[:-1] | |||
| select_cmd = select_template(conv_type=conv_type.upper(), cuda_major=cuda_major, | |||
| cuda_minor=cuda_minor, layer_num=layer_num, | |||
| cuda_arch=cuda_arch) | |||
| define_cmd = define_template(cuda_arch=cuda_arch, conv_type=conv_type.upper(), | |||
| hidden_num=hidden_num, | |||
| layer_num=layer_num, out_dim=out_dim, | |||
| layers_dim=layers_dim_str, | |||
| matrices_dim=matrices_dim, matrices=Wstr, | |||
| biases_dim=biases_dim, biases=bstr, | |||
| alpha=alphastr, beta=betastr) | |||
| select_cmd = select_template( | |||
| conv_type=conv_type.upper(), | |||
| cuda_major=cuda_major, | |||
| cuda_minor=cuda_minor, | |||
| layer_num=layer_num, | |||
| cuda_arch=cuda_arch, | |||
| ) | |||
| define_cmd = define_template( | |||
| cuda_arch=cuda_arch, | |||
| conv_type=conv_type.upper(), | |||
| hidden_num=hidden_num, | |||
| layer_num=layer_num, | |||
| out_dim=out_dim, | |||
| layers_dim=layers_dim_str, | |||
| matrices_dim=matrices_dim, | |||
| matrices=Wstr, | |||
| biases_dim=biases_dim, | |||
| biases=bstr, | |||
| alpha=alphastr, | |||
| beta=betastr, | |||
| ) | |||
| return (define_cmd, select_cmd) | |||
| @@ -153,8 +168,9 @@ def format_array(array): | |||
| if __name__ == "__main__": | |||
| parser = argparse.ArgumentParser( | |||
| description="Generate cuDNN heuristic code by neural network into" | |||
| " {MEGDNN_ROOT}/src/cuda/convolution/get_params.cpp," | |||
| " using parameter value from pickle files in" | |||
| " {MEGDNN_ROOT}/scripts/gen_heuristic/params/") | |||
| " {MEGDNN_ROOT}/src/cuda/convolution/get_params.cpp," | |||
| " using parameter value from pickle files in" | |||
| " {MEGDNN_ROOT}/scripts/gen_heuristic/params/" | |||
| ) | |||
| args = parser.parse_args() | |||
| main() | |||
| @@ -3,19 +3,17 @@ | |||
| import argparse | |||
| import collections | |||
| import textwrap | |||
| import os | |||
| import hashlib | |||
| import struct | |||
| import io | |||
| import os | |||
| import struct | |||
| import textwrap | |||
| from gen_param_defs import member_defs, ParamDef, IndentWriterBase | |||
| from gen_param_defs import IndentWriterBase, ParamDef, member_defs | |||
| # FIXME: move supportToString flag definition into the param def source file | |||
| ENUM_TO_STRING_SPECIAL_RULES = [ | |||
| ("Elemwise", "Mode"), | |||
| ("ElemwiseMultiType", "Mode") | |||
| ] | |||
| ENUM_TO_STRING_SPECIAL_RULES = [("Elemwise", "Mode"), ("ElemwiseMultiType", "Mode")] | |||
| class ConverterWriter(IndentWriterBase): | |||
| _skip_current_param = False | |||
| @@ -33,21 +31,21 @@ class ConverterWriter(IndentWriterBase): | |||
| self._write("#endif // MGB_PARAM") | |||
| def _ctype2attr(self, ctype, value): | |||
| if ctype == 'uint32_t': | |||
| return 'MgbUI32Attr', value | |||
| if ctype == 'uint64_t': | |||
| return 'MgbUI64Attr', value | |||
| if ctype == 'int32_t': | |||
| return 'MgbI32Attr', value | |||
| if ctype == 'float': | |||
| return 'MgbF32Attr', value | |||
| if ctype == 'double': | |||
| return 'MgbF64Attr', value | |||
| if ctype == 'bool': | |||
| return 'MgbBoolAttr', value | |||
| if ctype == 'DTypeEnum': | |||
| if ctype == "uint32_t": | |||
| return "MgbUI32Attr", value | |||
| if ctype == "uint64_t": | |||
| return "MgbUI64Attr", value | |||
| if ctype == "int32_t": | |||
| return "MgbI32Attr", value | |||
| if ctype == "float": | |||
| return "MgbF32Attr", value | |||
| if ctype == "double": | |||
| return "MgbF64Attr", value | |||
| if ctype == "bool": | |||
| return "MgbBoolAttr", value | |||
| if ctype == "DTypeEnum": | |||
| self._packed = False | |||
| return 'MgbDTypeAttr', 'megdnn::DType::from_enum(megdnn::{})'.format(value) | |||
| return "MgbDTypeAttr", "megdnn::DType::from_enum(megdnn::{})".format(value) | |||
| raise RuntimeError("unknown ctype") | |||
| def _on_param_begin(self, p): | |||
| @@ -61,21 +59,26 @@ class ConverterWriter(IndentWriterBase): | |||
| self._skip_current_param = False | |||
| return | |||
| if self._packed: | |||
| self._write("class {0}ParamBase<string accessor> : MgbPackedParamBase<\"{0}\", accessor> {{".format(p.name), indent=1) | |||
| self._write( | |||
| 'class {0}ParamBase<string accessor> : MgbPackedParamBase<"{0}", accessor> {{'.format( | |||
| p.name | |||
| ), | |||
| indent=1, | |||
| ) | |||
| else: | |||
| self._write("def {0}Param: MgbParamBase<\"{0}\"> {{".format(p.name), indent=1) | |||
| self._write('def {0}Param: MgbParamBase<"{0}"> {{'.format(p.name), indent=1) | |||
| self._write("let fields = (ins", indent=1) | |||
| self._write(",\n{}".format(self._cur_indent).join(self._current_tparams)) | |||
| self._write(");", indent=-1) | |||
| self._write("}\n", indent=-1) | |||
| if self._packed: | |||
| self._write("def {0}Param : {0}ParamBase<\"param\">;\n".format(p.name)) | |||
| self._write('def {0}Param : {0}ParamBase<"param">;\n'.format(p.name)) | |||
| self._current_tparams = None | |||
| self._packed = None | |||
| self._const = None | |||
| def _wrapped_with_default_value(self, attr, default): | |||
| return 'MgbDefaultValuedAttr<{}, \"{}\">'.format(attr, default) | |||
| return 'MgbDefaultValuedAttr<{}, "{}">'.format(attr, default) | |||
| def _on_member_enum(self, e): | |||
| p = self._last_param | |||
| @@ -84,10 +87,12 @@ class ConverterWriter(IndentWriterBase): | |||
| # directly used by any operator, or other enum couldn't alias to this enum | |||
| td_class = "{}{}".format(p.name, e.name) | |||
| fullname = "::megdnn::param::{}".format(p.name) | |||
| enum_def = "MgbEnumAttr<\"{}\", \"{}\", [".format(fullname, e.name) | |||
| enum_def = 'MgbEnumAttr<"{}", "{}", ['.format(fullname, e.name) | |||
| def format(v): | |||
| return '\"{}\"'.format(str(v).split(' ')[0].split('=')[0]) | |||
| enum_def += ','.join(format(i) for i in e.members) | |||
| return '"{}"'.format(str(v).split(" ")[0].split("=")[0]) | |||
| enum_def += ",".join(format(i) for i in e.members) | |||
| if e.combined: | |||
| enum_def += "], 1" | |||
| @@ -95,7 +100,7 @@ class ConverterWriter(IndentWriterBase): | |||
| enum_def += "], 0" | |||
| if ENUM_TO_STRING_SPECIAL_RULES.count((p.name, e.name)): | |||
| enum_def += ", 1" # whether generate ToStringTrait | |||
| enum_def += ", 1" # whether generate ToStringTrait | |||
| enum_def += ">" | |||
| self._write("def {} : {};".format(td_class, enum_def)) | |||
| @@ -105,10 +110,12 @@ class ConverterWriter(IndentWriterBase): | |||
| # wrapped with default value | |||
| if e.combined: | |||
| default_val = "static_cast<{}::{}>({})".format( | |||
| fullname, e.name, e.compose_combined_enum(e.default)) | |||
| fullname, e.name, e.compose_combined_enum(e.default) | |||
| ) | |||
| else: | |||
| default_val = "{}::{}::{}".format( | |||
| fullname, e.name, str(e.members[e.default]).split(' ')[0].split('=')[0]) | |||
| fullname, e.name, str(e.members[e.default]).split(" ")[0].split("=")[0] | |||
| ) | |||
| wrapped = self._wrapped_with_default_value(td_class, default_val) | |||
| @@ -123,51 +130,58 @@ class ConverterWriter(IndentWriterBase): | |||
| td_class = "{}{}".format(p.name, e.name) | |||
| fullname = "::megdnn::param::{}".format(p.name) | |||
| base_td_class = "{}{}".format(e.src_class, e.src_name) | |||
| enum_def = "MgbEnumAliasAttr<\"{}\", \"{}\", {}>".format(fullname, e.name, base_td_class) | |||
| enum_def = 'MgbEnumAliasAttr<"{}", "{}", {}>'.format( | |||
| fullname, e.name, base_td_class | |||
| ) | |||
| self._write("def {} : {};".format(td_class, enum_def)) | |||
| # wrapped with default value | |||
| s = e.src_enum | |||
| if s.combined: | |||
| default_val = "static_cast<{}::{}>({})".format( | |||
| fullname, e.name, s.compose_combined_enum(e.get_default())) | |||
| fullname, e.name, s.compose_combined_enum(e.get_default()) | |||
| ) | |||
| else: | |||
| default_val = "{}::{}::{}".format(fullname, e.name, str( | |||
| s.members[e.get_default()]).split(' ')[0].split('=')[0]) | |||
| default_val = "{}::{}::{}".format( | |||
| fullname, | |||
| e.name, | |||
| str(s.members[e.get_default()]).split(" ")[0].split("=")[0], | |||
| ) | |||
| wrapped = self._wrapped_with_default_value(td_class, default_val) | |||
| self._current_tparams.append("{}:${}".format(wrapped, e.name_field)) | |||
| def _on_member_field(self, f): | |||
| if self._skip_current_param: | |||
| return | |||
| attr, value = self._ctype2attr(f.dtype.cname, str(f.default)) | |||
| if str(value) in self._const: | |||
| value = '::megdnn::param::{}::{}'.format(self._last_param.name, value) | |||
| value = "::megdnn::param::{}::{}".format(self._last_param.name, value) | |||
| wrapped = self._wrapped_with_default_value(attr, value) | |||
| self._current_tparams.append("{}:${}".format(wrapped, f.name)) | |||
| def _on_const_field(self, f): | |||
| self._const.add(str(f.name)) | |||
| def main(): | |||
| parser = argparse.ArgumentParser('generate op param tablegen file') | |||
| parser.add_argument('input') | |||
| parser.add_argument('output') | |||
| parser = argparse.ArgumentParser("generate op param tablegen file") | |||
| parser.add_argument("input") | |||
| parser.add_argument("output") | |||
| args = parser.parse_args() | |||
| with open(args.input) as fin: | |||
| inputs = fin.read() | |||
| exec(inputs, {'pdef': ParamDef, 'Doc': member_defs.Doc}) | |||
| exec(inputs, {"pdef": ParamDef, "Doc": member_defs.Doc}) | |||
| input_hash = hashlib.sha256() | |||
| input_hash.update(inputs.encode(encoding='UTF-8')) | |||
| input_hash.update(inputs.encode(encoding="UTF-8")) | |||
| input_hash = input_hash.hexdigest() | |||
| writer = ConverterWriter() | |||
| with open(args.output, 'w') as fout: | |||
| with open(args.output, "w") as fout: | |||
| writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs) | |||
| if __name__ == "__main__": | |||
| main() | |||
| @@ -19,6 +19,7 @@ device = { | |||
| "thread_number": 3, | |||
| } | |||
| class SshConnector: | |||
| """imp ssh control master connector""" | |||
| @@ -83,17 +84,17 @@ def main(): | |||
| model_file = args.model_file | |||
| # copy model file | |||
| ssh.copy([args.model_file], workspace) | |||
| m = model_file.split('\\')[-1] | |||
| m = model_file.split("\\")[-1] | |||
| # run single thread | |||
| result = [] | |||
| thread_number = [1, 2, 4] | |||
| for b in thread_number : | |||
| for b in thread_number: | |||
| cmd = [] | |||
| cmd1 = "cd {} && ./load_and_run {} -multithread {} --fast-run --fast_run_algo_policy fastrun.cache --iter 1 --warmup-iter 1 --no-sanity-check --weight-preprocess".format( | |||
| workspace, m, b | |||
| workspace, m, b | |||
| ) | |||
| cmd2 = "cd {} && ./load_and_run {} -multithread {} --fast_run_algo_policy fastrun.cache --iter 20 --warmup-iter 5 --no-sanity-check --weight-preprocess ".format( | |||
| workspace, m, b | |||
| workspace, m, b | |||
| ) | |||
| cmd.append(cmd1) | |||
| cmd.append(cmd2) | |||
| @@ -103,12 +104,20 @@ def main(): | |||
| logging.debug("model: {} with backend: {} result is: {}".format(m, b, ret)) | |||
| result.append(ret) | |||
| thread_2 = result[0]/result[1] | |||
| thread_4 = result[0]/result[2] | |||
| thread_2 = result[0] / result[1] | |||
| thread_4 = result[0] / result[2] | |||
| if thread_2 > 1.6 or thread_4 > 3.0: | |||
| print("model: {} can has good parallelism. 2 thread is {}, 4 thread is {}".format(m, thread_2, thread_4)) | |||
| print( | |||
| "model: {} can has good parallelism. 2 thread is {}, 4 thread is {}".format( | |||
| m, thread_2, thread_4 | |||
| ) | |||
| ) | |||
| else: | |||
| print("model: {} can has bad parallelism. 2 thread is {}, 4 thread is {}".format(m, thread_2, thread_4)) | |||
| print( | |||
| "model: {} can has bad parallelism. 2 thread is {}, 4 thread is {}".format( | |||
| m, thread_2, thread_4 | |||
| ) | |||
| ) | |||
| if __name__ == "__main__": | |||
| @@ -20,8 +20,12 @@ failed_files = Manager().list() | |||
| def process_file(file, clang_format, write): | |||
| original_source = open(file, "r").read() | |||
| source = original_source | |||
| source = re.sub(r"MGB_DEFINE(?P<r>([^\\]|\n)*?)// *{", r"class MGB_DEFINE\g<r>{", source) | |||
| source, count = re.subn(r"(?<!#define )MGB_DEFINE(.*) +\\", r"class MGB_DEFINE\1{\\", source) | |||
| source = re.sub( | |||
| r"MGB_DEFINE(?P<r>([^\\]|\n)*?)// *{", r"class MGB_DEFINE\g<r>{", source | |||
| ) | |||
| source, count = re.subn( | |||
| r"(?<!#define )MGB_DEFINE(.*) +\\", r"class MGB_DEFINE\1{\\", source | |||
| ) | |||
| result = subprocess.check_output( | |||
| [ | |||
| @@ -36,7 +40,9 @@ def process_file(file, clang_format, write): | |||
| result = result.decode("utf-8") | |||
| if count: | |||
| result = re.sub(r"class MGB_DEFINE(.*){( *)\\", r"MGB_DEFINE\1\2 \\", result) | |||
| result = re.sub( | |||
| r"class MGB_DEFINE(.*){( *)\\", r"MGB_DEFINE\1\2 \\", result | |||
| ) | |||
| result = re.sub(r"class MGB_DEFINE((.|\n)*?){", r"MGB_DEFINE\1// {", result) | |||
| if write and original_source != result: | |||
| @@ -109,19 +115,17 @@ def main(): | |||
| raise ValueError("Invalid path {}".format(path)) | |||
| # check version, we only support 12.0.1 now | |||
| version = subprocess.check_output( | |||
| [ | |||
| args.clang_format, | |||
| "--version", | |||
| ], | |||
| ) | |||
| version = subprocess.check_output([args.clang_format, "--version",],) | |||
| version = version.decode("utf-8") | |||
| need_version = '12.0.1' | |||
| need_version = "12.0.1" | |||
| if version.find(need_version) < 0: | |||
| print('We only support {} now, please install {} version, find version: {}' | |||
| .format(need_version, need_version, version)) | |||
| raise RuntimeError('clang-format version not equal {}'.format(need_version)) | |||
| print( | |||
| "We only support {} now, please install {} version, find version: {}".format( | |||
| need_version, need_version, version | |||
| ) | |||
| ) | |||
| raise RuntimeError("clang-format version not equal {}".format(need_version)) | |||
| process_map( | |||
| partial(process_file, clang_format=args.clang_format, write=args.write,), | |||
| @@ -20,6 +20,7 @@ device = { | |||
| "thread_number": 3, | |||
| } | |||
| class SshConnector: | |||
| """imp ssh control master connector""" | |||
| @@ -54,6 +55,7 @@ class SshConnector: | |||
| except: | |||
| raise | |||
| def main(): | |||
| parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter) | |||
| parser.add_argument("--model_file", help="megengine model", required=True) | |||
| @@ -78,10 +80,10 @@ def main(): | |||
| model_file = args.model_file | |||
| # copy model file | |||
| ssh.copy([model_file], workspace) | |||
| m = model_file.split('\\')[-1] | |||
| m = model_file.split("\\")[-1] | |||
| # run single thread | |||
| cmd = "cd {} && ./load_and_run {} --fast-run --record-comp-seq --iter 1 --warmup-iter 1".format( | |||
| workspace, m | |||
| workspace, m | |||
| ) | |||
| try: | |||
| raw_log = ssh.cmd([cmd]) | |||
| @@ -91,6 +93,7 @@ def main(): | |||
| print("model: {} is static model.".format(m)) | |||
| if __name__ == "__main__": | |||
| LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s" | |||
| DATE_FORMAT = "%Y/%m/%d %H:%M:%S" | |||