| @@ -8,21 +8,22 @@ | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import sys | |||
| import re | |||
| if sys.version_info[0] != 3 or sys.version_info[1] < 5: | |||
| print('This script requires Python version 3.5') | |||
| sys.exit(1) | |||
| import argparse | |||
| import json | |||
| import os | |||
| import re | |||
| import subprocess | |||
| import sys | |||
| import tempfile | |||
| from pathlib import Path | |||
| MIDOUT_TRACE_MAGIC = 'midout_trace v1\n' | |||
| if sys.version_info[0] != 3 or sys.version_info[1] < 5: | |||
| print("This script requires Python version 3.5") | |||
| sys.exit(1) | |||
| MIDOUT_TRACE_MAGIC = "midout_trace v1\n" | |||
| class HeaderGen: | |||
| _dtypes = None | |||
| @@ -42,20 +43,22 @@ class HeaderGen: | |||
| self._midout_files = [] | |||
| _megvii3_root_cache = None | |||
@classmethod
def get_megvii3_root(cls):
    """Locate the directory that holds a ``WORKSPACE`` file, if any.

    The search starts at the directory containing this script and walks
    towards the filesystem root.  The first directory that contains a
    regular file named ``WORKSPACE`` wins; its path is stored in
    ``cls._megvii3_root_cache`` so later calls return immediately.

    Returns:
        The directory path as ``str``, or ``None`` when no ancestor
        (the filesystem root itself is not probed) has a ``WORKSPACE``
        file.  A ``None`` result is not cached, so the walk is repeated
        on the next call.
    """
    if cls._megvii3_root_cache is not None:
        return cls._megvii3_root_cache

    current = Path(__file__).resolve().parent
    # ``current.parent == current`` only holds at the filesystem root,
    # which terminates the walk.
    while current.parent != current:
        if (current / "WORKSPACE").is_file():
            cls._megvii3_root_cache = str(current)
            return cls._megvii3_root_cache
        current = current.parent
    return None
| _megengine_root_cache = None | |||
| @classmethod | |||
| def get_megengine_root(cls): | |||
| if cls._megengine_root_cache is not None: | |||
| @@ -66,15 +69,15 @@ class HeaderGen: | |||
def extend_netinfo(self, data):
    """Merge one network-info mapping into the accumulated traits.

    Args:
        data: mapping with the keys ``"dtypes"``, ``"opr_types"`` and
            ``"elemwise_modes"`` (each an iterable) and optionally
            ``"hash"``.

    Side effects:
        * ``self._has_netinfo`` is set to ``True``.
        * When ``"hash"`` is present its ``str()`` form is added to
          ``self._graph_hashes``; otherwise ``self._file_without_hash``
          is set to ``True``.
        * Every item of the three iterables is added to
          ``self._dtypes``, ``self._oprs`` and ``self._elemwise_modes``
          respectively.

    Raises:
        KeyError: if one of the three mandatory keys is missing.
    """
    self._has_netinfo = True

    if "hash" in data:
        self._graph_hashes.add(str(data["hash"]))
    else:
        self._file_without_hash = True

    # (key in ``data``, collection on ``self`` that receives its items)
    targets = (
        ("dtypes", self._dtypes),
        ("opr_types", self._oprs),
        ("elemwise_modes", self._elemwise_modes),
    )
    for key, bucket in targets:
        for item in data[key]:
            bucket.add(item)
| def extend_midout(self, fname): | |||
| @@ -82,7 +85,7 @@ class HeaderGen: | |||
| def generate(self, fout): | |||
| self._fout = fout | |||
| self._write_def('MGB_BINREDUCE_VERSION', '20190219') | |||
| self._write_def("MGB_BINREDUCE_VERSION", "20190219") | |||
| if self._has_netinfo: | |||
| self._write_dtype() | |||
| self._write_elemwise_modes() | |||
| @@ -93,13 +96,13 @@ class HeaderGen: | |||
def strip_opr_name_with_version(self, name):
    """Return ``name`` without a trailing ``V<digits>`` version suffix.

    Args:
        name: operator name, e.g. ``"ConvolutionV2"``.

    Returns:
        The part of ``name`` that precedes the suffix matched by the
        pattern ``V\\d+$``; ``name`` unchanged when nothing matches.
    """
    suffix = re.search(r"V\d+$", name)
    return name[: suffix.start()] if suffix else name
| def _write_oprs(self): | |||
| defs = ['}', 'namespace opr {'] | |||
| defs = ["}", "namespace opr {"] | |||
| already_declare = set() | |||
| already_instance = set() | |||
| for i in self._oprs: | |||
| @@ -109,13 +112,15 @@ class HeaderGen: | |||
| else: | |||
| already_declare.add(i) | |||
| defs.append('class {};'.format(i)) | |||
| defs.append('}') | |||
| defs.append('namespace serialization {') | |||
| defs.append(""" | |||
| defs.append("class {};".format(i)) | |||
| defs.append("}") | |||
| defs.append("namespace serialization {") | |||
| defs.append( | |||
| """ | |||
| template<class Opr, class Callee> | |||
| struct OprRegistryCaller { | |||
| }; """) | |||
| }; """ | |||
| ) | |||
| for i in sorted(self._oprs): | |||
| i = self.strip_opr_name_with_version(i) | |||
| if i in already_instance: | |||
| @@ -123,40 +128,53 @@ class HeaderGen: | |||
| else: | |||
| already_instance.add(i) | |||
| defs.append(""" | |||
| defs.append( | |||
| """ | |||
| template<class Callee> | |||
| struct OprRegistryCaller<opr::{}, Callee>: public | |||
| OprRegistryCallerDefaultImpl<Callee> {{ | |||
| }}; """.format(i)) | |||
| self._write_def('MGB_OPR_REGISTRY_CALLER_SPECIALIZE', defs) | |||
| }}; """.format( | |||
| i | |||
| ) | |||
| ) | |||
| self._write_def("MGB_OPR_REGISTRY_CALLER_SPECIALIZE", defs) | |||
def _write_elemwise_modes(self):
    """Emit one enable/disable macro per known ``Elemwise:Mode`` item.

    ``./dnn/scripts/gen_param_defs.py`` is run (with the MegEngine root
    as working directory) to dump the enum items of ``Elemwise:Mode``
    into a temporary file, one per line.  For every dumped mode a macro
    ``_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_<mode>(_cb)`` is written: it
    expands to ``_cb(<mode>)`` when the mode is in
    ``self._elemwise_modes`` and to nothing otherwise.  Finally the
    dispatching macro ``MEGDNN_ELEMWISE_MODE_ENABLE(_mode, _cb)`` is
    written.

    Raises:
        subprocess.CalledProcessError: if the generator script exits
            with a non-zero status.
    """
    with tempfile.NamedTemporaryFile() as ftmp:
        fpath = os.path.realpath(ftmp.name)
        subprocess.check_call(
            [
                "./dnn/scripts/gen_param_defs.py",
                "--write-enum-items",
                "Elemwise:Mode",
                "./dnn/scripts/opr_param_defs.py",
                fpath,
            ],
            cwd=self.get_megengine_root(),
        )
        # Read the result while the temporary file still exists.
        with open(fpath) as fin:
            mode_list = [line.strip() for line in fin]

    for line in mode_list:
        # Keep only the identifier: drop anything after the first space
        # and after an ``=``.  The result is parsed once and reused for
        # both the macro name and its expansion.
        mode = line.split(" ")[0].split("=")[0]
        content = "_cb({})".format(mode) if mode in self._elemwise_modes else ""
        self._write_def(
            "_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_{}(_cb)".format(mode), content
        )

    self._write_def(
        "MEGDNN_ELEMWISE_MODE_ENABLE(_mode, _cb)",
        "_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_##_mode(_cb)",
    )
| def _write_dtype(self): | |||
| if 'Float16' not in self._dtypes: | |||
| if "Float16" not in self._dtypes: | |||
| # MegBrain/MegDNN used MEGDNN_DISABLE_FLOT16 to turn off float16 | |||
| # support in the past; however `FLOT16' is really a typo. We plan to | |||
| # change MEGDNN_DISABLE_FLOT16 to MEGDNN_DISABLE_FLOAT16 soon. | |||
| @@ -166,74 +184,86 @@ class HeaderGen: | |||
| # In the future when the situation is settled and no one would ever | |||
| # use legacy MegBrain/MegDNN, the `FLOT16' macro definition can be | |||
| # safely deleted. | |||
| self._write_def('MEGDNN_DISABLE_FLOT16', 1) | |||
| self._write_def('MEGDNN_DISABLE_FLOAT16', 1) | |||
| self._write_def("MEGDNN_DISABLE_FLOT16", 1) | |||
| self._write_def("MEGDNN_DISABLE_FLOAT16", 1) | |||
def _write_hash(self):
    """Write the ``MGB_BINREDUCE_GRAPH_HASHES`` macro, or warn.

    When at least one input had no ``hash`` entry
    (``self._file_without_hash`` is true) only a warning is printed and
    no macro is written.  Otherwise every collected hash string gets a
    ``ULL`` suffix and the comma-separated list becomes the macro body.

    The hashes are emitted in sorted order so that the generated header
    does not depend on the iteration order of ``self._graph_hashes`` and
    is therefore identical from run to run for the same inputs.
    """
    if self._file_without_hash:
        print(
            "WARNING: network info has no graph hash. Using json file "
            "generated by MegBrain >= 7.28.0 is recommended"
        )
        return

    defs = "ULL,".join(sorted(self._graph_hashes)) + "ULL"
    self._write_def("MGB_BINREDUCE_GRAPH_HASHES", defs)
def _write_def(self, name, val):
    """Write a single ``#define`` line (possibly continued) to the output.

    Args:
        name: macro name, including a parameter list if there is one.
        val: macro body.  A ``list`` is joined with newlines first; any
            other value is converted with ``str()``.

    The body is stripped of surrounding whitespace and every embedded
    newline is turned into a backslash line continuation so that a
    multi-line body stays part of the same preprocessor definition.
    """
    if isinstance(val, list):
        val = "\n".join(val)
    body = str(val).strip().replace("\n", " \\\n")
    self._fout.write("#define {} {}\n".format(name, body))
def _write_midout(self):
    """Append the midout section generated from the collected trace files.

    Does nothing when no midout trace file was registered.  Otherwise
    the ``gen_header.py`` script is executed with the trace files as
    arguments and its standard output is written to ``self._fout``
    after a ``// midout`` marker line.  The script is taken from
    ``<megvii3 root>/brain/midout`` when a megvii3 root is found and
    from ``<megengine root>/third_party/midout`` otherwise.

    If the generated text mentions `` half,`` an additional file
    ``fix_fp16_bin_reduce.h`` is written to the current working
    directory: a copy of everything written to the output file so far
    with `` half,`` replaced by `` __fp16,``; a warning explaining when
    to use it is printed.

    Raises:
        subprocess.CalledProcessError: if the generator script exits
            with a non-zero status.
    """
    if not self._midout_files:
        return

    gen = os.path.join(
        self.get_megengine_root(), "third_party", "midout", "gen_header.py"
    )
    megvii3_root = self.get_megvii3_root()
    if megvii3_root:
        gen = os.path.join(megvii3_root, "brain", "midout", "gen_header.py")
    print("use {} to gen bin_reduce header".format(gen))

    cvt = subprocess.run(
        [gen] + self._midout_files, stdout=subprocess.PIPE, check=True,
    ).stdout.decode("utf-8")
    self._fout.write("// midout \n")
    self._fout.write(cvt)

    if " half," in cvt:
        # The output file is re-read through its path, so anything still
        # sitting in the write buffer must reach the disk first;
        # otherwise the fix-up copy would miss the most recent writes.
        self._fout.flush()
        with open(self._fout.name) as written:
            change = written.read().replace(" half,", " __fp16,")
        with open("fix_fp16_bin_reduce.h", "w") as fix_fp16:
            fix_fp16.write(change)
        msg = (
            "WARNING:\n"
            "hit half in trace, try use fix_fp16_bin_reduce.h when build failed with bin_reduce.h\n"
            "which caused by LLVM mangle issue on __fp16 dtype, if you find msg 'error: use of undeclared identifier 'half'\n"
            "then try use fix_fp16_bin_reduce.h, if build failed again, submit a issue to Engine team!!!"
        )
        print(msg)
| def main(): | |||
| parser = argparse.ArgumentParser( | |||
| description='generate header file for reducing binary size by ' | |||
| 'stripping unused oprs in a particular network; output file would ' | |||
| 'be written to bin_reduce.h', | |||
| formatter_class=argparse.ArgumentDefaultsHelpFormatter) | |||
| description="generate header file for reducing binary size by " | |||
| "stripping unused oprs in a particular network; output file would " | |||
| "be written to bin_reduce.h", | |||
| formatter_class=argparse.ArgumentDefaultsHelpFormatter, | |||
| ) | |||
| parser.add_argument( | |||
| 'inputs', nargs='+', | |||
| help='input files that describe specific traits of the network; ' | |||
| 'can be one of the following:' | |||
| ' 1. json files generated by ' | |||
| 'megbrain.serialize_comp_graph_to_file() in python; ' | |||
| ' 2. trace files generated by midout library') | |||
| default_file=os.path.join(HeaderGen.get_megengine_root(), 'src', 'bin_reduce_cmake.h') | |||
| "inputs", | |||
| nargs="+", | |||
| help="input files that describe specific traits of the network; " | |||
| "can be one of the following:" | |||
| " 1. json files generated by " | |||
| "megbrain.serialize_comp_graph_to_file() in python; " | |||
| " 2. trace files generated by midout library", | |||
| ) | |||
| default_file = os.path.join( | |||
| HeaderGen.get_megengine_root(), "src", "bin_reduce_cmake.h" | |||
| ) | |||
| is_megvii3 = HeaderGen.get_megvii3_root() | |||
| if is_megvii3: | |||
| default_file=os.path.join(HeaderGen.get_megvii3_root(), 'utils', 'bin_reduce.h') | |||
| parser.add_argument('-o', '--output', help='output file', default=default_file) | |||
| default_file = os.path.join( | |||
| HeaderGen.get_megvii3_root(), "utils", "bin_reduce.h" | |||
| ) | |||
| parser.add_argument("-o", "--output", help="output file", default=default_file) | |||
| args = parser.parse_args() | |||
| print('config output file: {}'.format(args.output)) | |||
| print("config output file: {}".format(args.output)) | |||
| gen = HeaderGen() | |||
| for i in args.inputs: | |||
| print('==== processing {}'.format(i)) | |||
| print("==== processing {}".format(i)) | |||
| with open(i) as fin: | |||
| if fin.read(len(MIDOUT_TRACE_MAGIC)) == MIDOUT_TRACE_MAGIC: | |||
| gen.extend_midout(i) | |||
| @@ -241,8 +271,9 @@ def main(): | |||
| fin.seek(0) | |||
| gen.extend_netinfo(json.loads(fin.read())) | |||
| with open(args.output, 'w') as fout: | |||
| with open(args.output, "w") as fout: | |||
| gen.generate(fout) | |||
| if __name__ == '__main__': | |||
| if __name__ == "__main__": | |||
| main() | |||