| @@ -8,21 +8,22 @@ | |||||
| # software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| import sys | |||||
| import re | |||||
| if sys.version_info[0] != 3 or sys.version_info[1] < 5: | |||||
| print('This script requires Python version 3.5') | |||||
| sys.exit(1) | |||||
| import argparse | import argparse | ||||
| import json | import json | ||||
| import os | import os | ||||
| import re | |||||
| import subprocess | import subprocess | ||||
| import sys | |||||
| import tempfile | import tempfile | ||||
| from pathlib import Path | from pathlib import Path | ||||
| MIDOUT_TRACE_MAGIC = 'midout_trace v1\n' | |||||
| if sys.version_info[0] != 3 or sys.version_info[1] < 5: | |||||
| print("This script requires Python version 3.5") | |||||
| sys.exit(1) | |||||
| MIDOUT_TRACE_MAGIC = "midout_trace v1\n" | |||||
| class HeaderGen: | class HeaderGen: | ||||
| _dtypes = None | _dtypes = None | ||||
| @@ -42,20 +43,22 @@ class HeaderGen: | |||||
| self._midout_files = [] | self._midout_files = [] | ||||
| _megvii3_root_cache = None | _megvii3_root_cache = None | ||||
| @classmethod | @classmethod | ||||
| def get_megvii3_root(cls): | def get_megvii3_root(cls): | ||||
| if cls._megvii3_root_cache is not None: | if cls._megvii3_root_cache is not None: | ||||
| return cls._megvii3_root_cache | return cls._megvii3_root_cache | ||||
| wd = Path(__file__).resolve().parent | wd = Path(__file__).resolve().parent | ||||
| while wd.parent != wd: | while wd.parent != wd: | ||||
| workspace_file = wd / 'WORKSPACE' | |||||
| if workspace_file.is_file(): | |||||
| cls._megvii3_root_cache = str(wd) | |||||
| return cls._megvii3_root_cache | |||||
| wd = wd.parent | |||||
| workspace_file = wd / "WORKSPACE" | |||||
| if workspace_file.is_file(): | |||||
| cls._megvii3_root_cache = str(wd) | |||||
| return cls._megvii3_root_cache | |||||
| wd = wd.parent | |||||
| return None | return None | ||||
| _megengine_root_cache = None | _megengine_root_cache = None | ||||
| @classmethod | @classmethod | ||||
| def get_megengine_root(cls): | def get_megengine_root(cls): | ||||
| if cls._megengine_root_cache is not None: | if cls._megengine_root_cache is not None: | ||||
| @@ -66,15 +69,15 @@ class HeaderGen: | |||||
| def extend_netinfo(self, data): | def extend_netinfo(self, data): | ||||
| self._has_netinfo = True | self._has_netinfo = True | ||||
| if 'hash' not in data: | |||||
| if "hash" not in data: | |||||
| self._file_without_hash = True | self._file_without_hash = True | ||||
| else: | else: | ||||
| self._graph_hashes.add(str(data['hash'])) | |||||
| for i in data['dtypes']: | |||||
| self._graph_hashes.add(str(data["hash"])) | |||||
| for i in data["dtypes"]: | |||||
| self._dtypes.add(i) | self._dtypes.add(i) | ||||
| for i in data['opr_types']: | |||||
| for i in data["opr_types"]: | |||||
| self._oprs.add(i) | self._oprs.add(i) | ||||
| for i in data['elemwise_modes']: | |||||
| for i in data["elemwise_modes"]: | |||||
| self._elemwise_modes.add(i) | self._elemwise_modes.add(i) | ||||
| def extend_midout(self, fname): | def extend_midout(self, fname): | ||||
| @@ -82,7 +85,7 @@ class HeaderGen: | |||||
| def generate(self, fout): | def generate(self, fout): | ||||
| self._fout = fout | self._fout = fout | ||||
| self._write_def('MGB_BINREDUCE_VERSION', '20190219') | |||||
| self._write_def("MGB_BINREDUCE_VERSION", "20190219") | |||||
| if self._has_netinfo: | if self._has_netinfo: | ||||
| self._write_dtype() | self._write_dtype() | ||||
| self._write_elemwise_modes() | self._write_elemwise_modes() | ||||
| @@ -93,13 +96,13 @@ class HeaderGen: | |||||
| def strip_opr_name_with_version(self, name): | def strip_opr_name_with_version(self, name): | ||||
| pos = len(name) | pos = len(name) | ||||
| t = re.search(r'V\d+$', name) | |||||
| t = re.search(r"V\d+$", name) | |||||
| if t: | if t: | ||||
| pos = t.start() | pos = t.start() | ||||
| return name[:pos] | return name[:pos] | ||||
| def _write_oprs(self): | def _write_oprs(self): | ||||
| defs = ['}', 'namespace opr {'] | |||||
| defs = ["}", "namespace opr {"] | |||||
| already_declare = set() | already_declare = set() | ||||
| already_instance = set() | already_instance = set() | ||||
| for i in self._oprs: | for i in self._oprs: | ||||
| @@ -109,13 +112,15 @@ class HeaderGen: | |||||
| else: | else: | ||||
| already_declare.add(i) | already_declare.add(i) | ||||
| defs.append('class {};'.format(i)) | |||||
| defs.append('}') | |||||
| defs.append('namespace serialization {') | |||||
| defs.append(""" | |||||
| defs.append("class {};".format(i)) | |||||
| defs.append("}") | |||||
| defs.append("namespace serialization {") | |||||
| defs.append( | |||||
| """ | |||||
| template<class Opr, class Callee> | template<class Opr, class Callee> | ||||
| struct OprRegistryCaller { | struct OprRegistryCaller { | ||||
| }; """) | |||||
| }; """ | |||||
| ) | |||||
| for i in sorted(self._oprs): | for i in sorted(self._oprs): | ||||
| i = self.strip_opr_name_with_version(i) | i = self.strip_opr_name_with_version(i) | ||||
| if i in already_instance: | if i in already_instance: | ||||
| @@ -123,40 +128,53 @@ class HeaderGen: | |||||
| else: | else: | ||||
| already_instance.add(i) | already_instance.add(i) | ||||
| defs.append(""" | |||||
| defs.append( | |||||
| """ | |||||
| template<class Callee> | template<class Callee> | ||||
| struct OprRegistryCaller<opr::{}, Callee>: public | struct OprRegistryCaller<opr::{}, Callee>: public | ||||
| OprRegistryCallerDefaultImpl<Callee> {{ | OprRegistryCallerDefaultImpl<Callee> {{ | ||||
| }}; """.format(i)) | |||||
| self._write_def('MGB_OPR_REGISTRY_CALLER_SPECIALIZE', defs) | |||||
| }}; """.format( | |||||
| i | |||||
| ) | |||||
| ) | |||||
| self._write_def("MGB_OPR_REGISTRY_CALLER_SPECIALIZE", defs) | |||||
| def _write_elemwise_modes(self): | def _write_elemwise_modes(self): | ||||
| with tempfile.NamedTemporaryFile() as ftmp: | with tempfile.NamedTemporaryFile() as ftmp: | ||||
| fpath = os.path.realpath(ftmp.name) | fpath = os.path.realpath(ftmp.name) | ||||
| subprocess.check_call( | subprocess.check_call( | ||||
| ['./dnn/scripts/gen_param_defs.py', | |||||
| '--write-enum-items', 'Elemwise:Mode', | |||||
| './dnn/scripts/opr_param_defs.py', | |||||
| fpath], | |||||
| cwd=self.get_megengine_root() | |||||
| [ | |||||
| "./dnn/scripts/gen_param_defs.py", | |||||
| "--write-enum-items", | |||||
| "Elemwise:Mode", | |||||
| "./dnn/scripts/opr_param_defs.py", | |||||
| fpath, | |||||
| ], | |||||
| cwd=self.get_megengine_root(), | |||||
| ) | ) | ||||
| with open(fpath) as fin: | with open(fpath) as fin: | ||||
| mode_list = [i.strip() for i in fin] | mode_list = [i.strip() for i in fin] | ||||
| for i in mode_list: | for i in mode_list: | ||||
| i = i.split(' ')[0].split('=')[0] | |||||
| i = i.split(" ")[0].split("=")[0] | |||||
| if i in self._elemwise_modes: | if i in self._elemwise_modes: | ||||
| content = '_cb({})'.format(i) | |||||
| content = "_cb({})".format(i) | |||||
| else: | else: | ||||
| content = '' | |||||
| content = "" | |||||
| self._write_def( | self._write_def( | ||||
| '_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_{}(_cb)'.format(i.split(' ')[0].split('=')[0]), content) | |||||
| self._write_def('MEGDNN_ELEMWISE_MODE_ENABLE(_mode, _cb)', | |||||
| '_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_##_mode(_cb)') | |||||
| "_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_{}(_cb)".format( | |||||
| i.split(" ")[0].split("=")[0] | |||||
| ), | |||||
| content, | |||||
| ) | |||||
| self._write_def( | |||||
| "MEGDNN_ELEMWISE_MODE_ENABLE(_mode, _cb)", | |||||
| "_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_##_mode(_cb)", | |||||
| ) | |||||
| def _write_dtype(self): | def _write_dtype(self): | ||||
| if 'Float16' not in self._dtypes: | |||||
| if "Float16" not in self._dtypes: | |||||
| # MegBrain/MegDNN used MEGDNN_DISABLE_FLOT16 to turn off float16 | # MegBrain/MegDNN used MEGDNN_DISABLE_FLOT16 to turn off float16 | ||||
| # support in the past; however `FLOT16' is really a typo. We plan to | # support in the past; however `FLOT16' is really a typo. We plan to | ||||
| # change MEGDNN_DISABLE_FLOT16 to MEGDNN_DISABLE_FLOAT16 soon. | # change MEGDNN_DISABLE_FLOT16 to MEGDNN_DISABLE_FLOAT16 soon. | ||||
| @@ -166,74 +184,86 @@ class HeaderGen: | |||||
| # In the future when the situation is settled and no one would ever | # In the future when the situation is settled and no one would ever | ||||
| # use legacy MegBrain/MegDNN, the `FLOT16' macro definition can be | # use legacy MegBrain/MegDNN, the `FLOT16' macro definition can be | ||||
| # safely deleted. | # safely deleted. | ||||
| self._write_def('MEGDNN_DISABLE_FLOT16', 1) | |||||
| self._write_def('MEGDNN_DISABLE_FLOAT16', 1) | |||||
| self._write_def("MEGDNN_DISABLE_FLOT16", 1) | |||||
| self._write_def("MEGDNN_DISABLE_FLOAT16", 1) | |||||
| def _write_hash(self): | def _write_hash(self): | ||||
| if self._file_without_hash: | if self._file_without_hash: | ||||
| print('WARNING: network info has no graph hash. Using json file ' | |||||
| 'generated by MegBrain >= 7.28.0 is recommended') | |||||
| print( | |||||
| "WARNING: network info has no graph hash. Using json file " | |||||
| "generated by MegBrain >= 7.28.0 is recommended" | |||||
| ) | |||||
| else: | else: | ||||
| defs = 'ULL,'.join(self._graph_hashes) + 'ULL' | |||||
| self._write_def('MGB_BINREDUCE_GRAPH_HASHES', defs) | |||||
| defs = "ULL,".join(self._graph_hashes) + "ULL" | |||||
| self._write_def("MGB_BINREDUCE_GRAPH_HASHES", defs) | |||||
| def _write_def(self, name, val): | def _write_def(self, name, val): | ||||
| if isinstance(val, list): | if isinstance(val, list): | ||||
| val = '\n'.join(val) | |||||
| val = str(val).strip().replace('\n', ' \\\n') | |||||
| self._fout.write('#define {} {}\n'.format(name, val)) | |||||
| val = "\n".join(val) | |||||
| val = str(val).strip().replace("\n", " \\\n") | |||||
| self._fout.write("#define {} {}\n".format(name, val)) | |||||
| def _write_midout(self): | def _write_midout(self): | ||||
| if not self._midout_files: | if not self._midout_files: | ||||
| return | return | ||||
| gen = os.path.join(self.get_megengine_root(), 'third_party', 'midout', 'gen_header.py') | |||||
| gen = os.path.join( | |||||
| self.get_megengine_root(), "third_party", "midout", "gen_header.py" | |||||
| ) | |||||
| if self.get_megvii3_root(): | if self.get_megvii3_root(): | ||||
| gen = os.path.join(self.get_megvii3_root(), 'brain', 'midout', 'gen_header.py') | |||||
| print('use {} to gen bin_reduce header'.format(gen)) | |||||
| gen = os.path.join( | |||||
| self.get_megvii3_root(), "brain", "midout", "gen_header.py" | |||||
| ) | |||||
| print("use {} to gen bin_reduce header".format(gen)) | |||||
| cvt = subprocess.run( | cvt = subprocess.run( | ||||
| [gen] + self._midout_files, | |||||
| stdout=subprocess.PIPE, check=True, | |||||
| ).stdout.decode('utf-8') | |||||
| self._fout.write('// midout \n') | |||||
| [gen] + self._midout_files, stdout=subprocess.PIPE, check=True, | |||||
| ).stdout.decode("utf-8") | |||||
| self._fout.write("// midout \n") | |||||
| self._fout.write(cvt) | self._fout.write(cvt) | ||||
| if cvt.find(" half,") > 0: | if cvt.find(" half,") > 0: | ||||
| change = open(self._fout.name).read().replace(" half,", " __fp16,") | change = open(self._fout.name).read().replace(" half,", " __fp16,") | ||||
| with open("fix_fp16_bin_reduce.h", "w") as fix_fp16: | with open("fix_fp16_bin_reduce.h", "w") as fix_fp16: | ||||
| fix_fp16.write(change) | fix_fp16.write(change) | ||||
| msg = ( | msg = ( | ||||
| "WARNING:\n" | |||||
| "hit half in trace, try use fix_fp16_bin_reduce.h when build failed with bin_reduce.h\n" | |||||
| "which caused by LLVM mangle issue on __fp16 dtype, if you find msg 'error: use of undeclared identifier 'half'\n" | |||||
| "then try use fix_fp16_bin_reduce.h, if build failed again, submit a issue to Engine team!!!" | |||||
| ) | |||||
| "WARNING:\n" | |||||
| "hit half in trace, try use fix_fp16_bin_reduce.h when build failed with bin_reduce.h\n" | |||||
| "which caused by LLVM mangle issue on __fp16 dtype, if you find msg 'error: use of undeclared identifier 'half'\n" | |||||
| "then try use fix_fp16_bin_reduce.h, if build failed again, submit a issue to Engine team!!!" | |||||
| ) | |||||
| print(msg) | print(msg) | ||||
| def main(): | def main(): | ||||
| parser = argparse.ArgumentParser( | parser = argparse.ArgumentParser( | ||||
| description='generate header file for reducing binary size by ' | |||||
| 'stripping unused oprs in a particular network; output file would ' | |||||
| 'be written to bin_reduce.h', | |||||
| formatter_class=argparse.ArgumentDefaultsHelpFormatter) | |||||
| description="generate header file for reducing binary size by " | |||||
| "stripping unused oprs in a particular network; output file would " | |||||
| "be written to bin_reduce.h", | |||||
| formatter_class=argparse.ArgumentDefaultsHelpFormatter, | |||||
| ) | |||||
| parser.add_argument( | parser.add_argument( | ||||
| 'inputs', nargs='+', | |||||
| help='input files that describe specific traits of the network; ' | |||||
| 'can be one of the following:' | |||||
| ' 1. json files generated by ' | |||||
| 'megbrain.serialize_comp_graph_to_file() in python; ' | |||||
| ' 2. trace files generated by midout library') | |||||
| default_file=os.path.join(HeaderGen.get_megengine_root(), 'src', 'bin_reduce_cmake.h') | |||||
| "inputs", | |||||
| nargs="+", | |||||
| help="input files that describe specific traits of the network; " | |||||
| "can be one of the following:" | |||||
| " 1. json files generated by " | |||||
| "megbrain.serialize_comp_graph_to_file() in python; " | |||||
| " 2. trace files generated by midout library", | |||||
| ) | |||||
| default_file = os.path.join( | |||||
| HeaderGen.get_megengine_root(), "src", "bin_reduce_cmake.h" | |||||
| ) | |||||
| is_megvii3 = HeaderGen.get_megvii3_root() | is_megvii3 = HeaderGen.get_megvii3_root() | ||||
| if is_megvii3: | if is_megvii3: | ||||
| default_file=os.path.join(HeaderGen.get_megvii3_root(), 'utils', 'bin_reduce.h') | |||||
| parser.add_argument('-o', '--output', help='output file', default=default_file) | |||||
| default_file = os.path.join( | |||||
| HeaderGen.get_megvii3_root(), "utils", "bin_reduce.h" | |||||
| ) | |||||
| parser.add_argument("-o", "--output", help="output file", default=default_file) | |||||
| args = parser.parse_args() | args = parser.parse_args() | ||||
| print('config output file: {}'.format(args.output)) | |||||
| print("config output file: {}".format(args.output)) | |||||
| gen = HeaderGen() | gen = HeaderGen() | ||||
| for i in args.inputs: | for i in args.inputs: | ||||
| print('==== processing {}'.format(i)) | |||||
| print("==== processing {}".format(i)) | |||||
| with open(i) as fin: | with open(i) as fin: | ||||
| if fin.read(len(MIDOUT_TRACE_MAGIC)) == MIDOUT_TRACE_MAGIC: | if fin.read(len(MIDOUT_TRACE_MAGIC)) == MIDOUT_TRACE_MAGIC: | ||||
| gen.extend_midout(i) | gen.extend_midout(i) | ||||
| @@ -241,8 +271,9 @@ def main(): | |||||
| fin.seek(0) | fin.seek(0) | ||||
| gen.extend_netinfo(json.loads(fin.read())) | gen.extend_netinfo(json.loads(fin.read())) | ||||
| with open(args.output, 'w') as fout: | |||||
| with open(args.output, "w") as fout: | |||||
| gen.generate(fout) | gen.generate(fout) | ||||
| if __name__ == '__main__': | |||||
| if __name__ == "__main__": | |||||
| main() | main() | ||||