GitOrigin-RevId: b092699dee
tags/v1.7.0
| @@ -74,7 +74,6 @@ option(MGE_ENABLE_EXCEPTIONS "Build with exceptions" ON) | |||
| option(MGE_WITH_TEST "Enable test for MegEngine." OFF) | |||
| option(MGE_WITH_DISTRIBUTED "Build with distributed support" ON) | |||
| option(MGE_BUILD_IMPERATIVE_RT "Build _imperative_rt Python Module " ON) | |||
| option(MGE_BUILD_SDK "Build load_and_run" ON) | |||
| option(MGE_INFERENCE_ONLY "Build inference only library." OFF) | |||
| option(MGE_WITH_MKLDNN "Enable Intel MKL_DNN support" ON) | |||
| option(MGE_WITH_ROCM "Enable ROCM support" OFF) | |||
| @@ -542,6 +541,8 @@ if(MGE_WITH_TEST) | |||
| include(cmake/gtest.cmake) | |||
| endif() | |||
| include(cmake/gflags.cmake) | |||
| if(MGE_BUILD_IMPERATIVE_RT) | |||
| set(CMAKE_CXX_STANDARD 17) | |||
| endif() | |||
| @@ -1147,10 +1148,6 @@ endif() | |||
| add_subdirectory(src) | |||
| if(MGE_BUILD_SDK) | |||
| add_subdirectory(sdk/load-and-run) | |||
| endif() | |||
| if(MGE_BUILD_IMPERATIVE_RT) | |||
| add_subdirectory(imperative) | |||
| message(STATUS "Enable imperative python wrapper runtime") | |||
| @@ -0,0 +1 @@ | |||
| add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/gflags ${CMAKE_CURRENT_BINARY_DIR}/gflags) | |||
| @@ -150,6 +150,9 @@ if(MGE_WITH_TEST) | |||
| add_subdirectory(test) | |||
| endif() | |||
| # load_and_run | |||
| add_subdirectory(load_and_run) | |||
| # tools and example | |||
| add_executable(rc4_encryptor tools/rc4_encrypt.cpp) | |||
| @@ -0,0 +1,38 @@ | |||
| load("//brain/megbrain/lite:flags.bzl","pthread_select") | |||
| cc_library( | |||
| name = "mgblar", | |||
| copts = ["-std=c++14"], | |||
| srcs = glob(["src/**/*.cpp"], exclude = ["src/main.cpp"]), | |||
| hdrs = glob(["src/**/*.h"]), | |||
| includes = ["src"], | |||
| features = if_opt([ | |||
| "no_exceptions", | |||
| "no_rtti", | |||
| ]), | |||
| defines = [ | |||
| "LITE_BUILD_WITH_MGE=1", | |||
| ], | |||
| deps = ["//brain/megbrain/lite:lite_static_test"]+ | |||
| pthread_select( | |||
| ["@com_github_gflags_gflags//:gflags_nothreads"], | |||
| ["//external:gflags"] | |||
| ), | |||
| alwayslink = 1, | |||
| visibility = ["//visibility:public"], | |||
| ) | |||
| cc_megvii_binary( | |||
| name = "load_and_run", | |||
| copts = ["-std=c++14"], | |||
| srcs = ["src/main.cpp"], | |||
| features = if_opt([ | |||
| "no_exceptions", | |||
| "no_rtti", | |||
| ]), | |||
| internal_deps = [":mgblar"], | |||
| visibility = ["//visibility:public"], | |||
| ) | |||
| @@ -0,0 +1,29 @@ | |||
| # Build load_and_run for lite | |||
| include_directories(PUBLIC $<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/lite/load_and_run/src>) | |||
| file (GLOB_RECURSE SOURCES ./*.cpp) | |||
| add_executable (load_and_run ${SOURCES}) | |||
| target_link_libraries(load_and_run lite_static) | |||
| target_link_libraries(load_and_run megbrain) | |||
| target_link_libraries(load_and_run gflags) | |||
| if(LITE_BUILD_WITH_RKNPU) | |||
| # rknn sdk 1.0.0 depends on libc++_shared; use gold to drop the NEEDED shared-object symbol check | |||
| target_link_options(load_and_run PRIVATE "-fuse-ld=gold") | |||
| endif() | |||
| if(MGE_WITH_ROCM) | |||
| # FIXME: hip objects cannot find cpp objects through lite_static alone | |||
| target_link_libraries(load_and_run megdnn) | |||
| endif() | |||
| if(UNIX) | |||
| if(APPLE OR ANDROID) | |||
| target_link_libraries(load_and_run dl) | |||
| else() | |||
| target_link_libraries(load_and_run dl rt) | |||
| endif() | |||
| endif() | |||
| install (TARGETS load_and_run EXPORT ${LITE_EXPORT_TARGETS} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) | |||
| @@ -0,0 +1,404 @@ | |||
| #!/usr/bin/env mdl | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from megskull.graph import NodeFilter, FpropEnv | |||
| from megskull.opr.all import AssertEqual, DataProvider, BatchNormalization | |||
| from megskull.utils.logconf import get_logger | |||
| from meghair.utils import io | |||
| import megbrain as mgb | |||
| import argparse | |||
| import struct | |||
| import re | |||
| import os | |||
| import numpy as np | |||
| import cv2 | |||
| logger = get_logger(__name__) | |||
| def auto_reformat_image(args, path, data, dst_shape): | |||
| """reformat image to target shape | |||
| :param data: image data as numpy array | |||
| :param dst_shape: target shape | |||
| """ | |||
| dim3_format = False # required input format does not contain batch | |||
| hwc_format = False # required input format is NHWC | |||
| if len(dst_shape) == 3: | |||
| dst_shape = (1, ) + dst_shape | |||
| dim3_format = True | |||
| assert len(dst_shape) == 4, 'bad dst_shape: {}'.format(dst_shape) | |||
| chl = dst_shape[1] | |||
| if chl in [1, 3]: | |||
| n, c, h, w = dst_shape | |||
| dst_shape = (n, h, w, c) | |||
| else: | |||
| chl = dst_shape[3] | |||
| assert chl in [1, 3], ( | |||
| 'can not infer input format from shape: {}'.format(dst_shape)) | |||
| hwc_format = True | |||
| # dst_shape has now been normalized to NHWC format | |||
| if args.resize_input: | |||
| h, w = dst_shape[1:3] | |||
| data = cv2.resize(data, (w, h)) | |||
| logger.info('input {} resized to {}'.format(path, data.shape)) | |||
| if chl == 1: | |||
| data = cv2.cvtColor(data, cv2.COLOR_BGR2GRAY) | |||
| data = data[:, :, np.newaxis] | |||
| assert data.ndim == 3 | |||
| data = data[np.newaxis] | |||
| # data normalized to NHWC format | |||
| if not hwc_format: | |||
| data = np.transpose(data, (0, 3, 1, 2)) | |||
| if dim3_format: | |||
| data = np.squeeze(data, 0) | |||
| return data | |||
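For reference, a minimal standalone sketch of the layout normalization above, assuming a 3-channel HWC image and an NCHW destination shape (all concrete values hypothetical):

```python
import numpy as np

data = np.zeros((224, 224, 3), dtype=np.uint8)   # HWC image, as cv2.imread returns
dst_shape = (1, 3, 224, 224)                     # chl = dst_shape[1] == 3, so NCHW
data = data[np.newaxis]                          # (1, 224, 224, 3): batched NHWC
data = np.transpose(data, (0, 3, 1, 2))          # (1, 3, 224, 224): NCHW
assert data.shape == dst_shape
```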
| def read_input_data(args, dst_shape, dtype, path, repeat): | |||
| def check_shape_equal(dst_shape, data_shape): | |||
| assert len(data_shape) == len(dst_shape), ( | |||
| 'input/data shapes mismatch: {} vs {}'.format( | |||
| dst_shape, data_shape)) | |||
| if data_shape[1:] != dst_shape[1:]: | |||
| logger.warning('dst_shape is {}; data_shape is {}'.format( | |||
| dst_shape, data_shape)) | |||
| if path.startswith('#'): | |||
| assert not args.resize_input | |||
| assert not args.input_transform | |||
| spec = path | |||
| m = re.match( | |||
| r'^#rand\(([-0-9.]*)\s*,\s*([-0-9.]*)\s*(,[^\)]+)?\)$', spec) | |||
| assert m, 'bad spec {}'.format(spec) | |||
| rng_min = float(m.group(1)) | |||
| rng_max = float(m.group(2)) | |||
| if m.group(3): | |||
| shape_str = m.group(3) | |||
| try: | |||
| shape = shape_str[1:].split(',') | |||
| if shape[-1].strip() == '...': | |||
| shape = shape[:-1] | |||
| shape.extend(list(dst_shape[len(shape):])) | |||
| data_shape = tuple(map(int, shape)) | |||
| except ValueError as e: | |||
| raise ValueError('bad spec {}: {}'.format(spec, e.args)) | |||
| else: | |||
| data_shape = dst_shape | |||
| check_shape_equal(dst_shape, data_shape) | |||
| return np.random.uniform(rng_min, rng_max, data_shape).astype(dtype) | |||
| # try to load image | |||
| data = cv2.imread(path, cv2.IMREAD_COLOR) | |||
| if data is None: | |||
| assert not args.resize_input | |||
| data = io.load(path) | |||
| assert isinstance(data, np.ndarray) | |||
| else: | |||
| # image load succeeded, so we expect the input to be in image format | |||
| data = auto_reformat_image(args, path, data, dst_shape) | |||
| data = np.repeat(data, repeat, axis=0) | |||
| if repeat > 1: | |||
| logger.info('repeat input for {} times, data shape is {}'.format( | |||
| repeat, data.shape)) | |||
| check_shape_equal(dst_shape, data.shape) | |||
| if args.input_transform: | |||
| data = eval(args.input_transform, {'data': data, 'np': np}) | |||
| return data | |||
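The `#rand(...)` spec is parsed with the regex above; a small sketch of what the groups capture (spec string hypothetical):

```python
import re

spec = '#rand(0, 255, 1, ...)'
m = re.match(r'^#rand\(([-0-9.]*)\s*,\s*([-0-9.]*)\s*(,[^\)]+)?\)$', spec)
rng_min, rng_max = float(m.group(1)), float(m.group(2))   # 0.0, 255.0
shape = [s.strip() for s in m.group(3)[1:].split(',')]    # ['1', '...']
assert shape[-1] == '...'  # trailing '...' keeps the remaining dst_shape dims
```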
| def gen_one_testcase(args, inputs, spec): | |||
| paths = spec.split(';') | |||
| if len(paths) != len(inputs): | |||
| if len(paths) == 1 and paths[0].startswith('#'): | |||
| paths = ['{}:{}'.format(name, paths[0]) for name in inputs.keys()] | |||
| assert len(paths) == len(inputs), ( | |||
| 'required inputs: {}; data paths: {}'.format(inputs.keys(), paths)) | |||
| if len(paths) == 1 and ':' not in paths[0]: | |||
| paths[0] = next(iter(inputs.keys())) + ':' + paths[0] | |||
| ret = {} | |||
| for path in paths: | |||
| var, path = path.split(':') | |||
| if args.repeat: | |||
| repeat = args.repeat | |||
| else: | |||
| repeat = 1 | |||
| ret[var] = read_input_data(args, inputs[var].imm_shape, | |||
| inputs[var].dtype, path, repeat) | |||
| return ret | |||
| def make_feeds(args): | |||
| outputs = io.load_network(args.input).outputs | |||
| if not args.no_assert: | |||
| env = FpropEnv(verbose_fprop=False) | |||
| # set flag so ExternCOprPlaceholder produces expected output | |||
| env.flags.user['extern_c_opr_eval'] = True | |||
| func = env.comp_graph.compile(None, [mgb.copy_output(env.get_mgbvar(i)) | |||
| for i in outputs]) | |||
| def expect_name(var): return 'expect:{}'.format(var.name) | |||
| nf = NodeFilter.make_all_deps(*outputs) | |||
| inputs = {i.name: i for i in nf.data_provider()} | |||
| if args.init_bn: | |||
| for i in nf: | |||
| if isinstance(i, BatchNormalization): | |||
| if i._iter.get_value() == 0: | |||
| i._iter.set_value(1) | |||
| i._variance.set_value(np.ones(i._variance.shape)) | |||
| testcases = [] | |||
| np.set_printoptions(precision=2, threshold=4, suppress=True) | |||
| data_list = [] | |||
| for item in args.data: | |||
| if item.startswith('@'): | |||
| with open(item[1:], 'r') as f: | |||
| data_list.extend([ line.rstrip() for line in f if line.rstrip() != '']) | |||
| else: | |||
| data_list.append(item) | |||
| for inp_spec in data_list: | |||
| cur_testcase = gen_one_testcase(args, inputs, inp_spec) | |||
| assert len(cur_testcase) == len(inputs), ( | |||
| 'required inputs: {}; given data: {}'.format( | |||
| inputs.keys(), cur_testcase.keys())) | |||
| if not args.no_assert: | |||
| outputs_get = func(**cur_testcase) | |||
| for var, val in zip(outputs, outputs_get): | |||
| cur_testcase[expect_name(var)] = val | |||
| logger.info( | |||
| 'generate test groundtruth: var={} shape={} range=({}, {})' | |||
| ' mean={} var={}'.format( | |||
| var, val.shape, val.min(), val.max(), | |||
| np.mean(val), np.var(val))) | |||
| testcases.append(cur_testcase) | |||
| logger.info('add testcase: \n {}'.format( | |||
| '\n '.join('{}: shape={} dtype={} range=({:.2f},{:.2f}) ' | |||
| 'mean={:.2f} sd={:.2f}'.format( | |||
| k, v.shape, v.dtype, v.min(), v.max(), np.mean(v), | |||
| np.std(v)) | |||
| for k, v in sorted(cur_testcase.items())))) | |||
| if not args.no_assert: | |||
| def expect_shp(var): | |||
| ret = var.partial_shape.determined_shape | |||
| if ret: | |||
| return ret | |||
| return testcases[0][expect_name(var)].shape | |||
| verbose = not args.silent | |||
| outputs = [AssertEqual(DataProvider(expect_name(i), expect_shp(i), | |||
| dtype=i.dtype, | |||
| comp_node=i.comp_node), | |||
| i, verbose=verbose, maxerr=args.maxerr) | |||
| for i in outputs] | |||
| return {'outputs': outputs, 'testcases': testcases} | |||
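Each generated testcase is a plain dict: one entry per network input plus one `expect:<output>` entry per output holding the groundtruth computed by `func`. Roughly (names and shapes hypothetical):

```python
import numpy as np

testcase = {
    'data': np.random.uniform(0, 255, (1, 3, 224, 224)).astype('float32'),  # input var
    'expect:prob': np.zeros((1, 1000), dtype='float32'),  # groundtruth for output 'prob'
}
```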
| def optimize_for_inference(args, outputs): | |||
| args_map = { | |||
| 'enable_io16xc32': 'f16_io_f32_comp', | |||
| 'enable_ioc16': 'f16_io_comp', | |||
| 'enable_hwcd4': 'use_nhwcd4', | |||
| 'enable_nchw4': 'use_nchw4', | |||
| 'enable_nchw88': 'use_nchw88', | |||
| 'enable_nchw44': 'use_nchw44', | |||
| 'enable_nchw44_dot': 'use_nchw44_dot', | |||
| 'enable_nchw32': 'use_nchw32', | |||
| 'enable_chwn4': 'use_chwn4', | |||
| 'enable_fuse_conv_bias_nonlinearity': 'fuse_conv_bias_nonlinearity', | |||
| 'enable_fuse_conv_bias_with_z': 'fuse_conv_bias_with_z', | |||
| 'enable_nchw64': 'use_nchw64', | |||
| 'enable_fuse_preprocess': 'fuse_preprocess', | |||
| } | |||
| kwargs = {} | |||
| for k, v in args_map.items(): | |||
| if getattr(args, k): | |||
| assert args.optimize_for_inference, ( | |||
| 'optimize_for_inference should be set when {} is given'.format( | |||
| k)) | |||
| kwargs[v] = True | |||
| if args.optimize_for_inference: | |||
| return mgb.optimize_for_inference(outputs, **kwargs) | |||
| return outputs | |||
| def main(): | |||
| parser = argparse.ArgumentParser( | |||
| description='Pack computing graph, input values and expected output ' | |||
| 'values into one file for checking correctness. README.md gives more ' | |||
| 'details on the usage', | |||
| formatter_class=argparse.ArgumentDefaultsHelpFormatter) | |||
| parser.add_argument('input', help='input file; see README for details') | |||
| parser.add_argument('-o', '--output', help='output file', required=True) | |||
| parser.add_argument('--init-bn', action='store_true', | |||
| help='initialize untrained batch-normalization, to ' | |||
| 'avoid NaN or Inf results') | |||
| parser.add_argument( | |||
| '-d', '--data', default=[], action='append', | |||
| help='Give input test data when the input file is a network, ' | |||
| 'and the current network output will be used as groundtruth. ' | |||
| 'The format is var0:file0;var1:file1... to specify data files for ' | |||
| 'input vars. It can also be #rand(min,max,shape...) for generating ' | |||
| 'random input data, for example, #rand(0,255), ' | |||
| '#rand(0,255,1,3,224,224) or #rand(0, 255, 1, ...) where `...` means ' | |||
| 'the remaining part of the original shape. ' | |||
| 'If the shape is not specified, the shape of ' | |||
| 'corresponding DataProvider in the network will be used. ' | |||
| 'If there is only one input var, its name can be omitted. ' | |||
| 'Each data file can either be an image which can be loaded by opencv, ' | |||
| 'or a pickled numpy.ndarray. ' | |||
| 'This option can be given multiple times to add multiple testcases. ' | |||
| ' *NOTE* ' | |||
| 'If you start the data with the letter @, the rest should be a ' | |||
| 'filename, and each line in the file should be a single datum in ' | |||
| 'the format described above. ' | |||
| ) | |||
| parser.add_argument( | |||
| '--repeat', type=int, default=1, | |||
| help='Specify how many times the input image is repeated. ' | |||
| 'Useful when running benchmark for batch size other than one. ' | |||
| 'Has no effect on randomly generated input data.') | |||
| parser.add_argument('--silent', action='store_true', | |||
| help='set verbose to False in AssertEqual opr') | |||
| parser.add_argument('--optimize-for-inference', action='store_true', | |||
| help='enable optimization for inference') | |||
| parser.add_argument('--no-assert', action='store_true', | |||
| help='do not insert AssertEqual opr to check result; ' | |||
| 'this option is useful for benchmarking') | |||
| parser.add_argument('--maxerr', type=float, default=AssertEqual.maxerr, | |||
| help='max error for AssertEqual check during runtime') | |||
| parser.add_argument('--resize-input', action='store_true', | |||
| help='resize input image to fit input var shape') | |||
| parser.add_argument('--input-transform', | |||
| help='a python expression to transform the input data. ' | |||
| 'Example: data / np.std(data)') | |||
| parser.add_argument('--discard-var-name', action='store_true', | |||
| help='discard variable and param names in the ' | |||
| 'generated output') | |||
| parser.add_argument('--output-strip-info', action='store_true', | |||
| help='output code strip information') | |||
| parser.add_argument('--enable-io16xc32', action='store_true', | |||
| help='transform the mode to float16 io float32 compute') | |||
| parser.add_argument('--enable-ioc16', action='store_true', | |||
| help='transform the dtype of the model to float16 io ' | |||
| 'and compute') | |||
| parser.add_argument('--enable-fuse-conv-bias-nonlinearity', | |||
| action='store_true', | |||
| help='fuse convolution bias and nonlinearity opr to a ' | |||
| 'conv_bias opr and compute') | |||
| parser.add_argument('--enable-hwcd4', action='store_true', | |||
| help='transform the model format from NCHW to NHWCD4 ' | |||
| 'for inference; you may need to disable CUDA and set ' | |||
| 'MGB_USE_MEGDNN_DBG=2') | |||
| parser.add_argument('--enable-nchw4', action='store_true', | |||
| help='transform the model format from NCHW to NCHW4 ' | |||
| 'for inference') | |||
| parser.add_argument('--enable-nchw88', action='store_true', | |||
| help='transform the model format from NCHW to NCHW88 ' | |||
| 'for inference') | |||
| parser.add_argument('--enable-nchw44', action='store_true', | |||
| help='transform the model format from NCHW to NCHW44 ' | |||
| 'for inference') | |||
| parser.add_argument('--enable-nchw44-dot', action='store_true', | |||
| help='transform the model format from NCHW to NCHW44_DOT ' | |||
| 'for optimizing armv8.2 dot in inference') | |||
| parser.add_argument('--enable-chwn4', action='store_true', | |||
| help='transform the model format to CHWN4 ' | |||
| 'for inference, mainly used for nvidia tensorcore') | |||
| parser.add_argument('--enable-nchw32', action='store_true', | |||
| help='transform the model format from NCHW4 to NCHW32 ' | |||
| 'for inference on nvidia TensorCore') | |||
| parser.add_argument('--enable-nchw64', action='store_true', | |||
| help='transform the model format from NCHW to NCHW64 ' | |||
| 'for inference on Nvidia GPU') | |||
| parser.add_argument('--enable-fuse-conv-bias-with-z', action='store_true', | |||
| help='fuse conv_bias with z input for inference on ' | |||
| 'nvidia GPU (this optimization pass will result in mismatch ' | |||
| 'of the precision of output of training and inference)') | |||
| parser.add_argument('--enable-fuse-preprocess', action='store_true', | |||
| help='fuse astype/pad_channel/dimshuffle etc. oprs ' | |||
| 'from the h2d opr') | |||
| args = parser.parse_args() | |||
| if args.data: | |||
| feeds = make_feeds(args) | |||
| else: | |||
| feeds = io.load(args.input) | |||
| assert isinstance(feeds, dict) and feeds['testcases'], ( | |||
| 'testcases can not be empty') | |||
| env = FpropEnv(verbose_fprop=False) | |||
| outputs = feeds['outputs'] | |||
| output_mgbvars = list(map(env.get_mgbvar, outputs)) | |||
| output_mgbvars = optimize_for_inference(args, output_mgbvars) | |||
| inputs = sorted(((i.name, i.dtype) for i in | |||
| NodeFilter.make_all_deps(*outputs).data_provider())) | |||
| if args.discard_var_name: | |||
| sereg_kwargs = dict(keep_var_name=0, keep_param_name=False) | |||
| else: | |||
| sereg_kwargs = dict(keep_var_name=2, keep_param_name=True) | |||
| with open(args.output, 'wb') as fout: | |||
| fout.write(b'mgbtest0') | |||
| fout.write(struct.pack('I', len(feeds['testcases']))) | |||
| stat = mgb.serialize_comp_graph_to_file( | |||
| args.output, output_mgbvars, append=True, | |||
| output_strip_info=args.output_strip_info, | |||
| **sereg_kwargs) | |||
| logger.info('graph dump sizes: tot_size={:.3f}KiB overhead={:.3f}KiB'. | |||
| format(stat.tot_bytes / 1024, | |||
| (stat.tot_bytes - stat.tensor_value_bytes) / 1024)) | |||
| for testcase in feeds['testcases']: | |||
| assert isinstance(testcase, dict) | |||
| cg = mgb.comp_graph() | |||
| cn = mgb.comp_node('cpux') | |||
| output_mgbvars = [] | |||
| for name, dtype in inputs: | |||
| output_mgbvars.append(cg.make_shared(cn, value=testcase.pop(name), | |||
| dtype=dtype)) | |||
| assert not testcase, 'extra inputs provided in testcase: {}'.format( | |||
| testcase.keys()) | |||
| mgb.serialize_comp_graph_to_file( | |||
| args.output, | |||
| output_mgbvars, | |||
| append=True, | |||
| output_strip_info=args.output_strip_info, | |||
| append_json=True) | |||
| if __name__ == '__main__': | |||
| main() | |||
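The dump prefix written in `main()` is an 8-byte magic plus a native-endian uint32 testcase count, followed by the serialized graph(s); a minimal reader for that prefix, assuming only what the code above writes:

```python
import struct

def read_testcase_count(path):
    """Read back the prefix written by main(): b'mgbtest0' + uint32 count."""
    with open(path, 'rb') as f:
        assert f.read(8) == b'mgbtest0'
        (num_testcases,) = struct.unpack('I', f.read(4))
    return num_testcases
```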
| @@ -0,0 +1,535 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import argparse | |||
| import os | |||
| import re | |||
| import struct | |||
| import cv2 | |||
| import numpy as np | |||
| import megengine as mge | |||
| import megengine.core._imperative_rt as rt | |||
| import megengine.core.tensor.megbrain_graph as G | |||
| from megengine import tensor | |||
| from megengine.core._imperative_rt.core2 import apply | |||
| from megengine.core.ops import builtin | |||
| from megengine.utils import comp_graph_tools as cgtools | |||
| logger = mge.get_logger(__name__) | |||
| def auto_reformat_image(args, path, data, dst_shape): | |||
| """reformat image to target shape | |||
| :param data: image data as numpy array | |||
| :param dst_shape: target shape | |||
| """ | |||
| dim3_format = False # required input format does not contain batch | |||
| hwc_format = False # required input format is NHWC | |||
| if not dst_shape: # input tensor shape is not predefined | |||
| if len(data.shape) == 2: | |||
| chl = 1 | |||
| h = data.shape[0] | |||
| w = data.shape[1] | |||
| else: | |||
| assert len(data.shape) == 3, "Input image must be of dimension 2 or 3" | |||
| h, w, chl = data.shape | |||
| dst_shape = (1, chl, h, w) | |||
| if len(dst_shape) == 3: | |||
| dst_shape = (1,) + dst_shape | |||
| dim3_format = True | |||
| assert len(dst_shape) == 4, "bad dst_shape: {}".format(dst_shape) | |||
| chl = dst_shape[1] | |||
| if chl in [1, 3]: | |||
| n, c, h, w = dst_shape | |||
| dst_shape = (n, h, w, c) | |||
| else: | |||
| chl = dst_shape[3] | |||
| assert chl in [1, 3], "can not infer input format from shape: {}".format( | |||
| dst_shape | |||
| ) | |||
| hwc_format = True | |||
| # dst_shape has now been normalized to NHWC format | |||
| if args.resize_input: | |||
| h, w = dst_shape[1:3] | |||
| data = cv2.resize(data, (w, h)) | |||
| logger.info("input {} resized to {}".format(path, data.shape)) | |||
| if chl == 1: | |||
| data = cv2.cvtColor(data, cv2.COLOR_BGR2GRAY) | |||
| data = data[:, :, np.newaxis] | |||
| assert data.ndim == 3 | |||
| data = data[np.newaxis] | |||
| # data normalized to NHWC format | |||
| if not hwc_format: | |||
| data = np.transpose(data, (0, 3, 1, 2)) | |||
| if dim3_format: | |||
| data = np.squeeze(data, 0) | |||
| return data | |||
| def read_input_data(args, dst_shape, dtype, path, repeat): | |||
| def check_shape_equal(dst_shape, data_shape): | |||
| if len(dst_shape): | |||
| assert len(data_shape) == len( | |||
| dst_shape | |||
| ), "input/data shapes mismatch: {} vs {}".format(dst_shape, data_shape) | |||
| if data_shape[1:] != dst_shape[1:]: | |||
| logger.warning( | |||
| "dst_shape is {}; data_shape is {}".format(dst_shape, data_shape) | |||
| ) | |||
| if path.startswith("#"): | |||
| assert not args.resize_input | |||
| assert not args.input_transform | |||
| spec = path | |||
| m = re.match(r"^#rand\(([-0-9.]*)\s*,\s*([-0-9.]*)\s*(,[^\)]+)?\)$", spec) | |||
| assert m, "bad spec {}".format(spec) | |||
| rng_min = float(m.group(1)) | |||
| rng_max = float(m.group(2)) | |||
| if m.group(3): | |||
| shape_str = m.group(3) | |||
| try: | |||
| shape = shape_str[1:].split(",") | |||
| if shape[-1].strip() == "...": | |||
| shape = shape[:-1] | |||
| shape.extend(list(dst_shape[len(shape) :])) | |||
| data_shape = tuple(map(int, shape)) | |||
| except ValueError as e: | |||
| raise ValueError("bad spec {}: {}".format(spec, e.args)) | |||
| else: | |||
| data_shape = dst_shape | |||
| check_shape_equal(dst_shape, data_shape) | |||
| return np.random.uniform(rng_min, rng_max, data_shape).astype(dtype) | |||
| # try to load image | |||
| data = cv2.imread(path, cv2.IMREAD_COLOR) | |||
| if data is None: | |||
| assert not args.resize_input | |||
| data = np.load(path) | |||
| assert isinstance(data, np.ndarray) | |||
| else: | |||
| # image load succeeded, so we expect the input to be in image format | |||
| data = auto_reformat_image(args, path, data, dst_shape) | |||
| data = np.repeat(data, repeat, axis=0) | |||
| if repeat > 1: | |||
| logger.info( | |||
| "repeat input for {} times, data shape is {}".format(repeat, data.shape) | |||
| ) | |||
| check_shape_equal(dst_shape, data.shape) | |||
| if args.input_transform: | |||
| data = eval(args.input_transform, {"data": data, "np": np}) | |||
| return data | |||
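The `--input-transform` expression is evaluated with `data` and `np` in scope, as in the `eval` call above; a sketch using the help string's own example:

```python
import numpy as np

data = np.random.uniform(0, 255, (1, 3, 8, 8)).astype('float32')
expr = '(data - np.mean(data)) / np.std(data)'   # e.g. --input-transform "..."
data = eval(expr, {'data': data, 'np': np})      # same environment as above
```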
| def gen_one_testcase(args, inputs, spec): | |||
| paths = spec.split(";") | |||
| if len(paths) != len(inputs): | |||
| if len(paths) == 1 and paths[0].startswith("#"): | |||
| paths = ["{}:{}".format(name, paths[0]) for name in inputs.keys()] | |||
| assert len(paths) == len(inputs), "required inputs: {}; data paths: {}".format( | |||
| inputs.keys(), paths | |||
| ) | |||
| if len(paths) == 1 and ":" not in paths[0]: | |||
| paths[0] = next(iter(inputs.keys())) + ":" + paths[0] | |||
| ret = {} | |||
| for path in paths: | |||
| var, path = path.split(":") | |||
| if args.repeat: | |||
| repeat = args.repeat | |||
| else: | |||
| repeat = 1 | |||
| ret[var] = read_input_data( | |||
| args, inputs[var].shape, inputs[var].dtype, path, repeat | |||
| ) | |||
| return ret | |||
| def make_feeds(args): | |||
| ret = G.load_graph(args.input) | |||
| cg_rt, outputs = ret.graph, ret.output_vars_list | |||
| inputs = cgtools.get_dep_vars(outputs, "Host2DeviceCopy") | |||
| inputs = {i.name: i for i in inputs} | |||
| if not args.no_assert: | |||
| replace_varmap = {} | |||
| inp_map = {} | |||
| # replace var use InputNode | |||
| for name, var in inputs.items(): | |||
| inp = G.InputNode( | |||
| device="xpux", dtype=var.dtype, shape=var.shape, graph=cg_rt | |||
| ) | |||
| replace_varmap[var] = inp.outputs[0] | |||
| inp_map[name] = inp | |||
| new = cgtools.replace_vars(outputs, replace_varmap) | |||
| if isinstance(new, rt.VarNode): | |||
| new = list(new) | |||
| output_nodes = [G.OutputNode(var) for var in new] | |||
| func = cg_rt.compile([node.outputs[0] for node in output_nodes]) | |||
| def make_dev_tensor(value, dtype=None, device=None): | |||
| return tensor(value, dtype=dtype, device=device)._dev_tensor() | |||
| def calculate(*args, **kwargs): | |||
| output_val = [] | |||
| # set inputs value | |||
| for name, var in inputs.items(): | |||
| val = kwargs.pop(name, None) | |||
| assert val is not None, "missing input name {}".format(name) | |||
| dev_tensor = make_dev_tensor(val, dtype=var.dtype, device="xpux") | |||
| inp_map[name].set_value(dev_tensor) | |||
| func.execute() | |||
| for res in output_nodes: | |||
| output_val.append(res.get_value().numpy()) | |||
| return output_val | |||
| def expect_name(var): | |||
| return "{}:expect".format(var.name) | |||
| testcases = [] | |||
| np.set_printoptions(precision=2, threshold=4, suppress=True) | |||
| data_list = [] | |||
| for item in args.data: | |||
| if item.startswith("@"): | |||
| with open(item[1:], "r") as f: | |||
| data_list.extend([line.rstrip() for line in f if line.rstrip() != ""]) | |||
| else: | |||
| data_list.append(item) | |||
| for inp_spec in data_list: | |||
| cur_testcase = gen_one_testcase(args, inputs, inp_spec) | |||
| assert len(cur_testcase) == len( | |||
| inputs | |||
| ), "required inputs: {}; given data: {}".format( | |||
| inputs.keys(), cur_testcase.keys() | |||
| ) | |||
| if not args.no_assert: | |||
| outputs_get = calculate(**cur_testcase) | |||
| for var, val in zip(outputs, outputs_get): | |||
| cur_testcase[expect_name(var)] = val | |||
| logger.info( | |||
| "generate test groundtruth: var={} shape={} range=({}, {})" | |||
| " mean={} var={}".format( | |||
| var, val.shape, val.min(), val.max(), np.mean(val), np.var(val) | |||
| ) | |||
| ) | |||
| testcases.append(cur_testcase) | |||
| logger.info( | |||
| "add testcase: \n {}".format( | |||
| "\n ".join( | |||
| "{}: shape={} dtype={} range=({:.2f},{:.2f}) " | |||
| "mean={:.2f} sd={:.2f}".format( | |||
| k, v.shape, v.dtype, v.min(), v.max(), np.mean(v), np.std(v) | |||
| ) | |||
| for k, v in sorted(cur_testcase.items()) | |||
| ) | |||
| ) | |||
| ) | |||
| if not args.no_assert: | |||
| def expect_shp(var): | |||
| ret = var.shape | |||
| if ret: | |||
| return ret | |||
| return testcases[0][expect_name(var)].shape | |||
| def assert_equal(expect, real, **kwargs): | |||
| op = builtin.AssertEqual(**kwargs) | |||
| (res,) = G.apply_normal_varnode(op, expect, real) | |||
| return res | |||
| verbose = not args.silent | |||
| outputs_new = [] | |||
| for i in outputs: | |||
| device = rt.CompNode("xpux") | |||
| dtype = i.dtype | |||
| name = expect_name(i) | |||
| shape = expect_shp(i) | |||
| # make expect output as one input of model. | |||
| expect_get = rt.make_h2d(cg_rt, device, dtype, shape, name) | |||
| # insert assert opr to check expect and real. | |||
| outputs_new.append( | |||
| assert_equal( | |||
| expect_get, | |||
| i, | |||
| verbose=verbose, | |||
| maxerr=args.maxerr, | |||
| ) | |||
| ) | |||
| inputs[expect_name(i)] = expect_get | |||
| outputs = outputs_new | |||
| return {"outputs": outputs, "testcases": testcases} | |||
| def optimize_for_inference(args, outputs): | |||
| args_list = [ | |||
| "enable_io16xc32", | |||
| "enable_ioc16", | |||
| "enable_hwcd4", | |||
| "enable_nchw4", | |||
| "enable_nchw88", | |||
| "enable_nchw44", | |||
| "enable_nchw44_dot", | |||
| "enable_nchw32", | |||
| "enable_chwn4", | |||
| "enable_fuse_conv_bias_nonlinearity", | |||
| "enable_fuse_conv_bias_with_z", | |||
| "enable_fuse_preprocess", | |||
| ] | |||
| kwargs = {} | |||
| for k in args_list: | |||
| if getattr(args, k): | |||
| kwargs[k] = True | |||
| if args.optimize_for_inference: | |||
| outputs = G.optimize_for_inference(outputs, **kwargs) | |||
| return outputs | |||
| def main(): | |||
| parser = argparse.ArgumentParser( | |||
| description="Pack computing graph, input values and expected output " | |||
| "values into one file for checking correctness. README.md gives more " | |||
| "details on the usage", | |||
| formatter_class=argparse.ArgumentDefaultsHelpFormatter, | |||
| ) | |||
| parser.add_argument("input", help="MegEngine dumped model file") | |||
| parser.add_argument("-o", "--output", help="output file", required=True) | |||
| parser.add_argument( | |||
| "-d", | |||
| "--data", | |||
| default=[], | |||
| action="append", | |||
| required=True, | |||
| help="Given input test data when input file is a network, " | |||
| "and current network output would be used as groundtruth. " | |||
| "The format is var0:file0;var1:file1... to specify data files for " | |||
| "input vars. It can also be #rand(min,max,shape...) for generating " | |||
| "random input data, for example, #rand(0,255), " | |||
| "#rand(0,255,1,3,224,224) or #rand(0, 255, 1, ...) where `...` means " | |||
| "the remaining part of the original shape. " | |||
| "If the shape is not specified, the shape of " | |||
| "corresponding input tensors in the network will be used. " | |||
| "If there is only one input var, its name can be omitted. " | |||
| "Each data file can either be an image which can be loaded by opencv, " | |||
| "or a pickled numpy.ndarray. " | |||
| "This option can be given multiple times to add multiple testcases. " | |||
| " *NOTE* " | |||
| "If you start the data with the letter @, the rest should be a " | |||
| "filename, and each line in the file should be a single datum in " | |||
| "the format described above. ", | |||
| ) | |||
| parser.add_argument( | |||
| "--repeat", | |||
| type=int, | |||
| default=1, | |||
| help="Specify how many times the input image is repeated. " | |||
| "Useful when running benchmark for batch size other than one. " | |||
| "Have no effect on randomly generated input data.", | |||
| ) | |||
| parser.add_argument( | |||
| "--silent", | |||
| action="store_true", | |||
| help="set verbose to False in asserti_equal opr", | |||
| ) | |||
| parser.add_argument( | |||
| "--optimize-for-inference", | |||
| action="store_true", | |||
| help="enable optimization for inference", | |||
| ) | |||
| parser.add_argument( | |||
| "--no-assert", | |||
| action="store_true", | |||
| help="do not insert assert_equal opr to check result; " | |||
| "this option is useful for benchmarking", | |||
| ) | |||
| parser.add_argument( | |||
| "--maxerr", | |||
| type=float, | |||
| default=1e-4, | |||
| help="max error for assert_equal check during runtime", | |||
| ) | |||
| parser.add_argument( | |||
| "--resize-input", | |||
| action="store_true", | |||
| help="resize input image to fit input var shape", | |||
| ) | |||
| parser.add_argument( | |||
| "--input-transform", | |||
| help="a python expression to transform the input data. " | |||
| "Example: data / np.std(data)", | |||
| ) | |||
| parser.add_argument( | |||
| "--discard-var-name", | |||
| action="store_true", | |||
| help="discard variable and param names in the " "generated output", | |||
| ) | |||
| parser.add_argument( | |||
| "--output-strip-info", action="store_true", help="output code strip information" | |||
| ) | |||
| parser.add_argument( | |||
| "--enable-io16xc32", | |||
| action="store_true", | |||
| help="transform the mode to float16 io float32 compute", | |||
| ) | |||
| parser.add_argument( | |||
| "--enable-ioc16", | |||
| action="store_true", | |||
| help="transform the dtype of the model to float16 io " "and compute", | |||
| ) | |||
| parser.add_argument( | |||
| "--enable-fuse-conv-bias-nonlinearity", | |||
| action="store_true", | |||
| help="fuse convolution bias and nonlinearity opr to a " | |||
| "conv_bias opr and compute", | |||
| ) | |||
| parser.add_argument( | |||
| "--enable-hwcd4", | |||
| action="store_true", | |||
| help="transform the model format from NCHW to NHWCD4 " | |||
| "for inference; you may need to disable CUDA and set " | |||
| "MGB_USE_MEGDNN_DBG=2", | |||
| ) | |||
| parser.add_argument( | |||
| "--enable-nchw4", | |||
| action="store_true", | |||
| help="transform the model format from NCHW to NCHW4 " "for inference", | |||
| ) | |||
| parser.add_argument( | |||
| "--enable-nchw88", | |||
| action="store_true", | |||
| help="transform the model format from NCHW to NCHW88 " "for inference", | |||
| ) | |||
| parser.add_argument( | |||
| "--enable-nchw44", | |||
| action="store_true", | |||
| help="transform the model format from NCHW to NCHW44 " "for inference", | |||
| ) | |||
| parser.add_argument( | |||
| "--enable-nchw44-dot", | |||
| action="store_true", | |||
| help="transform the model format from NCHW to NCHW44_DOT " | |||
| "for optimizing armv8.2 dot in inference", | |||
| ) | |||
| parser.add_argument( | |||
| "--enable-nchw32", | |||
| action="store_true", | |||
| help="transform the model format from NCHW4 to NCHW32 " | |||
| "for inference on nvidia TensoCore", | |||
| ) | |||
| parser.add_argument( | |||
| "--enable-chwn4", | |||
| action="store_true", | |||
| help="transform the model format to CHWN4 " | |||
| "for inference, mainly used for nvidia tensorcore", | |||
| ) | |||
| parser.add_argument( | |||
| "--enable-fuse-conv-bias-with-z", | |||
| action="store_true", | |||
| help="fuse conv_bias with z input for inference on " | |||
| "nvidia GPU (this optimization pass will result in mismatch " | |||
| "of the precision of output of training and inference)", | |||
| ) | |||
| parser.add_argument( | |||
| "--enable-fuse-preprocess", | |||
| action="store_true", | |||
| help="fuse astype\pad_channel\dimshuffle and etc opr " | |||
| "from h2d opr", | |||
| ) | |||
| args = parser.parse_args() | |||
| feeds = make_feeds(args) | |||
| assert isinstance(feeds, dict) and feeds["testcases"], "testcases can not be empty" | |||
| output_mgbvars = feeds["outputs"] | |||
| output_mgbvars = optimize_for_inference(args, output_mgbvars) | |||
| inputs = cgtools.get_dep_vars(output_mgbvars, "Host2DeviceCopy") | |||
| inputs = sorted((i.name, i.dtype) for i in inputs) | |||
| if args.discard_var_name: | |||
| sereg_kwargs = dict(keep_var_name=0, keep_param_name=False) | |||
| else: | |||
| sereg_kwargs = dict(keep_var_name=2, keep_param_name=True) | |||
| strip_info_file = args.output + ".json" if args.output_strip_info else None | |||
| with open(args.output, "wb") as fout: | |||
| fout.write(b"mgbtest0") | |||
| fout.write(struct.pack("I", len(feeds["testcases"]))) | |||
| dump_content, stat = G.dump_graph( | |||
| output_mgbvars, | |||
| append_json=True, | |||
| strip_info_file=strip_info_file, | |||
| **sereg_kwargs, | |||
| ) | |||
| fout.write(dump_content) | |||
| logger.info( | |||
| "graph dump sizes: tot_size={:.3f}KiB overhead={:.3f}KiB".format( | |||
| stat.tot_bytes / 1024, (stat.tot_bytes - stat.tensor_value_bytes) / 1024 | |||
| ) | |||
| ) | |||
| def make_dev_tensor(value, dtype=None, device=None): | |||
| return tensor(value, dtype=dtype, device=device)._dev_tensor() | |||
| for testcase in feeds["testcases"]: | |||
| assert isinstance(testcase, dict) | |||
| cg = G.Graph() | |||
| output_mgbvars = [] | |||
| for name, dtype in inputs: | |||
| output_mgbvars.append( | |||
| cg.make_const( | |||
| make_dev_tensor(testcase.pop(name), dtype=dtype, device="cpux") | |||
| ) | |||
| ) | |||
| assert not testcase, "extra inputs provided in testcase: {}".format( | |||
| testcase.keys() | |||
| ) | |||
| with open(args.output, "ab") as fout: | |||
| dump_content, _ = G.dump_graph( | |||
| output_mgbvars, strip_info_file=strip_info_file, append_json=True | |||
| ) | |||
| fout.write(dump_content) | |||
| if __name__ == "__main__": | |||
| main() | |||
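The `@file` form of `--data` handled in `make_feeds` reads one spec per line; a sketch of preparing such a file (file and input names hypothetical):

```python
with open('cases.txt', 'w') as f:      # hypothetical spec file
    f.write('#rand(0,255)\n')          # testcase 1: random data
    f.write('data:img0.ppm\n')         # testcase 2: image for input var `data`
# then invoke the dumper with:  -d @cases.txt
```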
| @@ -0,0 +1,74 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/helpers/common.h | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #pragma once | |||
| #include <gflags/gflags.h> | |||
| #include <memory> | |||
| DECLARE_int32(thread); | |||
| namespace lar { | |||
| /*! | |||
| * \brief: stage of model running | |||
| */ | |||
| enum class RunStage { | |||
| BEFORE_MODEL_LOAD = 0, | |||
| AFTER_MODEL_LOAD = 1, | |||
| BEFORE_OUTSPEC_SET = 2, | |||
| //! used for dumping the static memory information SVG file | |||
| AFTER_OUTSPEC_SET = 3, | |||
| //! used for the external c opr library | |||
| MODEL_RUNNING = 4, | |||
| //! used for the output dumper | |||
| AFTER_RUNNING_WAIT = 5, | |||
| //! used for the external c opr library | |||
| AFTER_RUNNING_ITER = 6, | |||
| AFTER_MODEL_RUNNING = 7, | |||
| }; | |||
| /*! | |||
| * \brief: types of different models | |||
| */ | |||
| enum class ModelType { | |||
| LITE_MODEL = 0, | |||
| MEGDL_MODEL, | |||
| UNKNOWN, | |||
| }; | |||
| /*! | |||
| * \brief: parameters for running the model | |||
| */ | |||
| struct RuntimeParam { | |||
| RunStage stage = RunStage::AFTER_MODEL_LOAD; | |||
| size_t warmup_iter; //! number of warm-up iterations before running the model | |||
| size_t run_iter; //! number of iterations for running the model | |||
| size_t threads = FLAGS_thread; //! thread number for running the model (NOTE: it's | |||
| //! different from the multithread device) | |||
| size_t testcase_num = 1; //! testcase number for model with testcase | |||
| }; | |||
| /*! | |||
| * \brief: layout type for running model optimization | |||
| */ | |||
| enum class OptLayoutType { | |||
| NCHW4 = 1 << 0, | |||
| CHWN4 = 1 << 1, | |||
| NCHW44 = 1 << 2, | |||
| NCHW88 = 1 << 3, | |||
| NCHW32 = 1 << 4, | |||
| NCHW64 = 1 << 5, | |||
| NHWCD4 = 1 << 6, | |||
| NCHW44_DOT = 1 << 7 | |||
| }; | |||
| } // namespace lar | |||
| // vim: syntax=cpp.doxygen | |||
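The `OptLayoutType` values are distinct powers of two so several layout options can be OR-ed into one mask; a Python mirror of the same design, with the same values:

```python
from enum import IntFlag

class OptLayoutType(IntFlag):    # same bit values as the C++ enum above
    NCHW4 = 1 << 0
    CHWN4 = 1 << 1
    NCHW44 = 1 << 2
    NCHW88 = 1 << 3
    NCHW32 = 1 << 4
    NCHW64 = 1 << 5
    NHWCD4 = 1 << 6
    NCHW44_DOT = 1 << 7

mask = OptLayoutType.NCHW4 | OptLayoutType.NCHW32
assert OptLayoutType.NCHW4 in mask and OptLayoutType.NCHW88 not in mask
```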
| @@ -0,0 +1,266 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/helpers/data_parser.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "data_parser.h" | |||
| #include <sstream> | |||
| #include "json_loader.h" | |||
| #include "npy.h" | |||
| using namespace lar; | |||
| /*! | |||
| * \brief feed different data to different parsers | |||
| * \param path data file path or data string | |||
| */ | |||
| void DataParser::feed(const std::string& path) { | |||
| std::string blob_name = "data", blob_string = path; | |||
| size_t sep = path.find(":"); | |||
| if (sep != std::string::npos) { | |||
| blob_name = path.substr(0, sep); | |||
| blob_string = path.substr(sep + 1); | |||
| } | |||
| auto endWith = [blob_string](std::string suffix) -> bool { | |||
| return blob_string.rfind(suffix) == (blob_string.length() - suffix.length()); | |||
| }; | |||
| if (endWith(".ppm") || endWith(".pgm")) { | |||
| parse_image(blob_name, blob_string); | |||
| } else if (endWith(".json")) { | |||
| parse_json(blob_string); | |||
| } else if (endWith(".npy")) { | |||
| parse_npy(blob_name, blob_string); | |||
| } else { | |||
| parse_string(blob_name, blob_string); | |||
| } | |||
| } | |||
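`feed()` splits an optional `name:` prefix (defaulting the blob name to `data`) and dispatches on the file suffix; the same rule as a Python sketch:

```python
def classify(arg):
    name, sep, payload = arg.partition(':')
    if not sep:
        name, payload = 'data', arg        # default blob name, as in feed()
    for suffix, kind in (('.ppm', 'image'), ('.pgm', 'image'),
                         ('.json', 'json'), ('.npy', 'npy')):
        if payload.endswith(suffix):
            return name, kind
    return name, 'string'                  # fall through to parse_string()

assert classify('img0.ppm') == ('data', 'image')
assert classify('x:[[0,1],[2,3]]') == ('x', 'string')
```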
| void DataParser::parse_json(const std::string& path) { | |||
| mgb::JsonLoader json; | |||
| std::shared_ptr<mgb::JsonLoader::Value> root = json.load(path.c_str()); | |||
| mgb_assert(root != nullptr, "failed to parse json %s", path.c_str()); | |||
| // parse json to data map | |||
| const std::string SHAPE = "shape", TYPE = "type", RAW = "raw"; | |||
| for (auto& item : root->objects()) { | |||
| auto&& value = *item.second; | |||
| auto&& shape = value[SHAPE]; | |||
| mgb_assert(shape->is_array()); | |||
| auto&& type = value[TYPE]; | |||
| mgb_assert(type->is_str()); | |||
| auto&& raw = value[RAW]; | |||
| mgb_assert(raw->is_array()); | |||
| megdnn::SmallVector<size_t> data_shape; | |||
| for (auto&& shape_ptr : shape->array()) { | |||
| data_shape.append({static_cast<size_t>(std::round(shape_ptr->number()))}); | |||
| } | |||
| // get type | |||
| const std::map<std::string, megdnn::DType> type_map = { | |||
| {"float32", mgb::dtype::Float32()}, {"float", mgb::dtype::Float32()}, | |||
| {"int32", mgb::dtype::Int32()}, {"int", mgb::dtype::Int32()}, | |||
| {"int8", mgb::dtype::Int8()}, {"uint8", mgb::dtype::Uint8()}}; | |||
| const std::string& type_str = type->str(); | |||
| mgb_assert( | |||
| type_map.find(type_str) != type_map.end(), | |||
| "unknown json data type for --input"); | |||
| mgb::DType datatype = type_map.at(type_str); | |||
| mgb::HostTensorND hv; | |||
| hv.comp_node(mgb::CompNode::default_cpu(), true) | |||
| .dtype(datatype) | |||
| .resize(data_shape); | |||
| mgb::dt_byte* raw_ptr = hv.raw_ptr(); | |||
| size_t elem_size = datatype.size(); | |||
| // get raw | |||
| const size_t array_size = raw->len(); | |||
| for (size_t idx = 0; idx < array_size; ++idx) { | |||
| double tmp = (*raw)[idx]->number(); | |||
| switch (datatype.enumv()) { | |||
| case megdnn::DTypeEnum::Int32: { | |||
| int32_t ival = std::round(tmp); | |||
| memcpy(((char*)raw_ptr) + idx * elem_size, &ival, elem_size); | |||
| } break; | |||
| case megdnn::DTypeEnum::Uint8: | |||
| case megdnn::DTypeEnum::Int8: { | |||
| int8_t cval = std::round(tmp); | |||
| memcpy(((char*)raw_ptr) + idx, &cval, sizeof(int8_t)); | |||
| } break; | |||
| case megdnn::DTypeEnum::Float32: { | |||
| float fval = tmp; | |||
| memcpy(((char*)raw_ptr) + idx * elem_size, &fval, elem_size); | |||
| } break; | |||
| default: | |||
| break; | |||
| } | |||
| } | |||
| inputs.insert(std::make_pair(item.first, std::move(hv))); | |||
| } | |||
| } | |||
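A sketch of the JSON document shape `parse_json()` consumes: each top-level key is a blob name whose value carries `shape`, `type`, and flattened `raw` data (blob name and values hypothetical):

```python
import json

doc = {'data': {'shape': [1, 2, 2],      # keys read by parse_json(): shape/type/raw
                'type': 'float32',
                'raw': [0.0, 1.0, 2.0, 3.0]}}   # 4 values == product of shape
with open('input.json', 'w') as f:
    json.dump(doc, f)
```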
| void DataParser::parse_image(const std::string& name, const std::string& path) { | |||
| // load binary ppm/pgm | |||
| std::ifstream fin; | |||
| fin.open(path, std::ifstream::binary | std::ifstream::in); | |||
| mgb_assert(fin.is_open(), "open file %s failed for --input", path.c_str()); | |||
| size_t w = 0, h = 0, channel = 0; | |||
| char buf[128] = {0}; | |||
| fin.getline(buf, 128); | |||
| if ('5' == buf[1]) { | |||
| channel = 1; | |||
| } else if ('6' == buf[1]) { | |||
| channel = 3; | |||
| } else { | |||
| mgb_assert(0, "not a formal ppm/pgm"); | |||
| } | |||
| while (fin.getline(buf, 128)) { | |||
| if (buf[0] == '#') { | |||
| continue; | |||
| } | |||
| break; | |||
| } | |||
| std::stringstream ss; | |||
| ss << std::string(buf); | |||
| ss >> w; | |||
| ss >> h; | |||
| mgb_assert(w > 0 and h > 0); | |||
| mgb::HostTensorND hv; | |||
| hv.comp_node(mgb::CompNode::default_cpu(), true) | |||
| .dtype(mgb::dtype::Uint8()) | |||
| .resize({1, h, w, channel}); | |||
| fin.read((char*)(hv.raw_ptr()), hv.layout().total_nr_elems()); | |||
| fin.close(); | |||
| inputs.insert(std::make_pair(name, std::move(hv))); | |||
| } | |||
| void DataParser::parse_npy(const std::string& name, const std::string& path) { | |||
| std::string type_str; | |||
| std::vector<npy::ndarray_len_t> stl_shape; | |||
| std::vector<int8_t> raw; | |||
| npy::LoadArrayFromNumpy(path, type_str, stl_shape, raw); | |||
| megdnn::SmallVector<size_t> shape; | |||
| for (auto val : stl_shape) { | |||
| shape.append({static_cast<size_t>(val)}); | |||
| } | |||
| const std::map<std::string, megdnn::DType> type_map = { | |||
| {"f4", mgb::dtype::Float32()}, {"i4", mgb::dtype::Int32()}, | |||
| {"i2", mgb::dtype::Int16()}, {"u2", mgb::dtype::Uint16()}, | |||
| {"i1", mgb::dtype::Int8()}, {"u1", mgb::dtype::Uint8()}}; | |||
| megdnn::DType hv_type; | |||
| for (auto& item : type_map) { | |||
| if (type_str.find(item.first) != std::string::npos) { | |||
| hv_type = item.second; | |||
| break; | |||
| } | |||
| } | |||
| mgb::HostTensorND hv; | |||
| hv.comp_node(mgb::CompNode::default_cpu(), true).dtype(hv_type).resize(shape); | |||
| mgb::dt_byte* raw_ptr = hv.raw_ptr(); | |||
| memcpy(raw_ptr, raw.data(), raw.size()); | |||
| inputs.insert(std::make_pair(name, std::move(hv))); | |||
| } | |||
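The `type_map` keys above are numpy dtype codes as they appear in the `.npy` header descr; for example, on a little-endian host:

```python
import numpy as np

assert np.dtype(np.float32).str == '<f4'    # matched by the "f4" entry above
assert np.dtype(np.uint8).str == '|u1'      # matched by "u1"
np.save('blob.npy', np.arange(4, dtype=np.float32))   # readable via parse_npy
```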
| void DataParser::parse_string(const std::string name, const std::string& str) { | |||
| // data type | |||
| megdnn::DType data_type = mgb::dtype::Int32(); | |||
| if (str.find(".") != std::string::npos or str.find(".") != std::string::npos) { | |||
| data_type = mgb::dtype::Float32(); | |||
| } | |||
| // shape | |||
| size_t number_cnt = 0; | |||
| std::shared_ptr<Brace> brace_root = std::make_shared<Brace>(); | |||
| std::shared_ptr<Brace> cur = brace_root; | |||
| for (size_t i = 0; i < str.size(); ++i) { | |||
| char c = str[i]; | |||
| if (c == '[') { | |||
| std::shared_ptr<Brace> child = std::make_shared<Brace>(); | |||
| child->parent = cur; | |||
| cur->chidren.emplace_back(child); | |||
| cur = child; | |||
| } else if (c == ']') { | |||
| cur = cur->parent.lock(); | |||
| } else if (c == ',') { | |||
| number_cnt++; | |||
| } | |||
| continue; | |||
| } | |||
| ++number_cnt; | |||
| mgb_assert(cur == brace_root, "braces not closed for --input"); | |||
| megdnn::SmallVector<size_t> shape; | |||
| cur = brace_root; | |||
| while (not cur->chidren.empty()) { | |||
| shape.append({cur->chidren.size()}); | |||
| number_cnt /= cur->chidren.size(); | |||
| cur = cur->chidren[0]; | |||
| } | |||
| mgb_assert(number_cnt > 0); | |||
| shape.append({number_cnt}); | |||
| // data | |||
| std::string json_arr; | |||
| for (size_t i = 0; i < str.size(); ++i) { | |||
| char c = str[i]; | |||
| if (c != '[' and c != ']') { | |||
| json_arr += c; | |||
| } | |||
| } | |||
| json_arr = "[" + json_arr + "]"; | |||
| // reuse json parser to resolve raw data | |||
| mgb::JsonLoader json; | |||
| std::shared_ptr<mgb::JsonLoader::Value> json_root = | |||
| json.load(json_arr.data(), json_arr.size()); | |||
| mgb_assert(json_root != nullptr, "failed to parse json in parse_string"); | |||
| mgb::HostTensorND hv; | |||
| hv.comp_node(mgb::CompNode::default_cpu(), true).dtype(data_type).resize(shape); | |||
| mgb::dt_byte* raw_ptr = hv.raw_ptr(); | |||
| const size_t array_len = json_root->len(); | |||
| const size_t elem_size = data_type.size(); | |||
| for (size_t idx = 0; idx < array_len; ++idx) { | |||
| double tmp = json_root->array()[idx]->number(); | |||
| switch (data_type.enumv()) { | |||
| case megdnn::DTypeEnum::Int32: { | |||
| int32_t ival = std::round(tmp); | |||
| memcpy(((char*)raw_ptr) + idx * elem_size, &ival, elem_size); | |||
| } break; | |||
| case megdnn::DTypeEnum::Float32: { | |||
| float fval = tmp; | |||
| memcpy(((char*)raw_ptr) + idx * elem_size, &fval, elem_size); | |||
| } break; | |||
| default: | |||
| break; | |||
| } | |||
| } | |||
| inputs.insert(std::make_pair(name, std::move(hv))); | |||
| } | |||
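A sketch of `parse_string()`'s shape inference: every comma bumps the element count (for a complete nested array, commas always total numbers-1), and bracket nesting fixes the leading dimensions:

```python
s = '[[1,2,3],[4,5,6]]'
number_cnt = s.count(',') + 1     # 6: commas, including list separators, total numbers-1
top_level_children = 2            # two '[' groups directly under the root brace
shape = (top_level_children, number_cnt // top_level_children)
assert shape == (2, 3)
```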
| @@ -0,0 +1,48 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/helpers/data_parser.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include <memory> | |||
| #include <unordered_map> | |||
| #include <vector> | |||
| #include "megbrain/opr/io.h" | |||
| namespace lar { | |||
| /*! | |||
| * \brief data parser for --input | |||
| * support .json|.ppm|.pgm|.npy data and user-defined data strings | |||
| * data string format: [0,0,227,227] | |||
| */ | |||
| struct DataParser { | |||
| struct Brace { | |||
| std::weak_ptr<Brace> parent; | |||
| std::vector<std::shared_ptr<Brace>> chidren; | |||
| }; | |||
| void feed(const std::string& path); | |||
| std::unordered_map<std::string, mgb::HostTensorND> inputs; | |||
| private: | |||
| //! parser for json data | |||
| void parse_json(const std::string& path); | |||
| //! parser for .ppm .pgm image | |||
| void parse_image(const std::string& name, const std::string& path); | |||
| //! parser for .npy data | |||
| void parse_npy(const std::string& name, const std::string& path); | |||
| //! parser for user-defined strings | |||
| void parse_string(const std::string name, const std::string& str); | |||
| }; | |||
| } // namespace lar | |||
| @@ -0,0 +1,297 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/helpers/json_loader.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "json_loader.h" | |||
| using namespace mgb; | |||
| template <typename T> | |||
| T* JsonLoader::Value::safe_cast() { | |||
| T* ptr = (T*)(this); | |||
| if (nullptr == ptr) { | |||
| fprintf(stderr, "cast ptr is null\n"); | |||
| } | |||
| return ptr; | |||
| } | |||
| std::unique_ptr<JsonLoader::Value>& JsonLoader::Value::operator[]( | |||
| const std::string& key) { | |||
| mgb_assert(Type::OBJECT == m_type); | |||
| auto t = safe_cast<JsonLoader::ObjectValue>(); | |||
| return t->m_obj.at(key); | |||
| } | |||
| std::unique_ptr<JsonLoader::Value>& JsonLoader::Value::operator[](const size_t index) { | |||
| mgb_assert(Type::ARRAY == m_type); | |||
| auto t = safe_cast<JsonLoader::ArrayValue>(); | |||
| return t->m_obj[index]; | |||
| } | |||
| std::map<std::string, std::unique_ptr<JsonLoader::Value>>& JsonLoader::Value:: | |||
| objects() { | |||
| mgb_assert(Type::OBJECT == m_type); | |||
| auto t = safe_cast<JsonLoader::ObjectValue>(); | |||
| return t->m_obj; | |||
| } | |||
| size_t JsonLoader::Value::len() { | |||
| if (Type::ARRAY == m_type) { | |||
| auto t = safe_cast<JsonLoader::ArrayValue>(); | |||
| return t->m_obj.size(); | |||
| } else if (Type::OBJECT == m_type) { | |||
| auto t = safe_cast<JsonLoader::ObjectValue>(); | |||
| return t->m_obj.size(); | |||
| } | |||
| return 0; | |||
| } | |||
| megdnn::SmallVector<std::unique_ptr<JsonLoader::Value>>& JsonLoader::Value::array() { | |||
| mgb_assert(Type::ARRAY == m_type); | |||
| auto t = safe_cast<JsonLoader::ArrayValue>(); | |||
| return t->m_obj; | |||
| } | |||
| double JsonLoader::Value::number() { | |||
| mgb_assert(Type::NUMBER == m_type); | |||
| auto t = safe_cast<JsonLoader::NumberValue>(); | |||
| return t->value(); | |||
| } | |||
| std::string JsonLoader::Value::str() { | |||
| if (Type::STRING == m_type) { | |||
| auto t = safe_cast<StringValue>(); | |||
| return t->value(); | |||
| } | |||
| return std::string(); | |||
| } | |||
| void JsonLoader::expect(char c) { | |||
| mgb_assert(c == (*m_buf)); | |||
| m_buf++; | |||
| } | |||
| void JsonLoader::skip_whitespace() { | |||
| const char* p = m_buf; | |||
| while (*p == ' ' || *p == '\t' || *p == '\n' || *p == '\r') { | |||
| ++p; | |||
| } | |||
| m_buf = p; | |||
| } | |||
| std::unique_ptr<JsonLoader::Value> JsonLoader::parse_object() { | |||
| expect('{'); | |||
| skip_whitespace(); | |||
| std::unique_ptr<JsonLoader::Value> ret; | |||
| JsonLoader::ObjectValue* pObject = new JsonLoader::ObjectValue(); | |||
| if ('}' == *m_buf) { | |||
| m_buf = m_buf + 1; | |||
| ret.reset((JsonLoader::Value*)(pObject)); | |||
| return ret; | |||
| } | |||
| while (true) { | |||
| std::unique_ptr<JsonLoader::Value> key = parse_string(); | |||
| if (m_state != State::OK) { | |||
| return ret; | |||
| } | |||
| skip_whitespace(); | |||
| if (':' != (*m_buf)) { | |||
| m_state = State::MISS_COLON; | |||
| return ret; | |||
| } | |||
| m_buf++; | |||
| skip_whitespace(); | |||
| std::unique_ptr<JsonLoader::Value> pVal = parse_value(); | |||
| if (m_state != State::OK) { | |||
| return ret; | |||
| } | |||
| if (pObject->m_obj.find(key->str()) != pObject->m_obj.end()) { | |||
| m_state = State::KEY_NOT_UNIQUE; | |||
| return ret; | |||
| } | |||
| pObject->m_obj.insert(std::make_pair(key->str(), std::move(pVal))); | |||
| skip_whitespace(); | |||
| if (',' == (*m_buf)) { | |||
| m_buf++; | |||
| skip_whitespace(); | |||
| } else if ('}' == (*m_buf)) { | |||
| m_buf++; | |||
| break; | |||
| } else { | |||
| m_state = State::MISS_BRACE; | |||
| break; | |||
| } | |||
| } | |||
| ret.reset((JsonLoader::Value*)(pObject)); | |||
| return ret; | |||
| } | |||
| std::unique_ptr<JsonLoader::Value> JsonLoader::parse_array() { | |||
| expect('['); | |||
| skip_whitespace(); | |||
| std::unique_ptr<JsonLoader::Value> ret; | |||
| JsonLoader::ArrayValue* pArray = new JsonLoader::ArrayValue(); | |||
| if (']' == *m_buf) { | |||
| m_buf = m_buf + 1; | |||
| ret.reset((JsonLoader::Value*)(pArray)); | |||
| return ret; | |||
| } | |||
| while (true) { | |||
| std::unique_ptr<JsonLoader::Value> pVal = parse_value(); | |||
| if (m_state != State::OK) { | |||
| mgb_assert(0, "parse value failed during pase array"); | |||
| return ret; | |||
| } | |||
| pArray->m_obj.emplace_back(pVal.get()); | |||
| pVal.release(); | |||
| skip_whitespace(); | |||
| if (',' == *m_buf) { | |||
| m_buf++; | |||
| skip_whitespace(); | |||
| } else if (']' == *m_buf) { | |||
| m_buf++; | |||
| break; | |||
| } else { | |||
| m_state = State::BAD_ARRAY; | |||
| return ret; | |||
| } | |||
| } | |||
| ret.reset((JsonLoader::Value*)(pArray)); | |||
| return ret; | |||
| } | |||
| std::unique_ptr<JsonLoader::Value> JsonLoader::parse_string() { | |||
| expect('\"'); | |||
| std::unique_ptr<JsonLoader::Value> ret; | |||
| JsonLoader::StringValue* pStr = new JsonLoader::StringValue(); | |||
| const char* p = m_buf; | |||
| while (true) { | |||
| if (*p == '\"') { | |||
| p++; | |||
| break; | |||
| } else { | |||
| pStr->m_value += (*p); | |||
| p++; | |||
| } | |||
| } | |||
| m_buf = p; | |||
| ret.reset((JsonLoader::Value*)(pStr)); | |||
| return ret; | |||
| } | |||
| std::unique_ptr<JsonLoader::Value> JsonLoader::parse_number() { | |||
| const char* p = m_buf; | |||
| auto loop_digit = [this](const char*& p) { | |||
| if (not std::isdigit(*p)) { | |||
| m_state = State::BAD_DIGIT; | |||
| return; | |||
| } | |||
| while (std::isdigit(*p)) { | |||
| p++; | |||
| } | |||
| return; | |||
| }; | |||
| if (*p == '-') | |||
| p++; | |||
| if (*p == '0') | |||
| p++; | |||
| else { | |||
| loop_digit(std::ref(p)); | |||
| } | |||
| if (*p == '.') { | |||
| p++; | |||
| loop_digit(std::ref(p)); | |||
| } | |||
| if (*p == 'e' || *p == 'E') { | |||
| p++; | |||
| if (*p == '+' || *p == '-') | |||
| p++; | |||
| loop_digit(std::ref(p)); | |||
| } | |||
| JsonLoader::NumberValue* pNum = new JsonLoader::NumberValue(); | |||
| pNum->m_value = strtod(m_buf, nullptr); | |||
| m_buf = p; | |||
| std::unique_ptr<JsonLoader::Value> ret; | |||
| ret.reset((JsonLoader::Value*)(pNum)); | |||
| return ret; | |||
| } | |||
| std::unique_ptr<JsonLoader::Value> JsonLoader::parse_value() { | |||
| switch (*m_buf) { | |||
| case '[': | |||
| return parse_array(); | |||
| case '{': | |||
| return parse_object(); | |||
| case '\"': | |||
| return parse_string(); | |||
| case '\0': | |||
| m_state = State::BAD_TYPE; | |||
| break; | |||
| default: | |||
| return parse_number(); | |||
| } | |||
| return nullptr; | |||
| } | |||
| std::unique_ptr<JsonLoader::Value> JsonLoader::load( | |||
| const char* content, const size_t size) { | |||
| m_buf = content; | |||
| skip_whitespace(); | |||
| std::unique_ptr<JsonLoader::Value> value = parse_value(); | |||
| skip_whitespace(); | |||
| if (m_state != State::OK) { | |||
| return nullptr; | |||
| } | |||
| mgb_assert(size == static_cast<size_t>(m_buf - content)); | |||
| return value; | |||
| } | |||
| std::unique_ptr<JsonLoader::Value> JsonLoader::load(const char* path) { | |||
| std::unique_ptr<std::FILE, void (*)(std::FILE*)> fin( | |||
| std::fopen(path, "rb"), [](std::FILE* fp) { std::fclose(fp); }); | |||
| mgb_assert(fin.get(), "failed to open %s: %s", path, strerror(errno)); | |||
| std::fseek(fin.get(), 0, SEEK_END); | |||
| const size_t size = ftell(fin.get()); | |||
| std::fseek(fin.get(), 0, SEEK_SET); | |||
| std::unique_ptr<char, void (*)(void*)> buf(static_cast<char*>(malloc(size)), free); // malloc'd memory must be freed, not deleted | |||
| auto nr = std::fread(buf.get(), 1, size, fin.get()); | |||
| mgb_assert(nr == size); | |||
| return load(buf.get(), size); | |||
| } | |||
| @@ -0,0 +1,183 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/helpers/json_loader.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include <cctype> | |||
| #include <fstream> | |||
| #include <functional> | |||
| #include <iostream> | |||
| #include <map> | |||
| #include <memory> | |||
| #include "megbrain/common.h" | |||
| #include "megdnn/thin/small_vector.h" | |||
| namespace mgb { | |||
| /*! | |||
| * \brief JSON format data loader for --input | |||
| */ | |||
| class JsonLoader { | |||
| public: | |||
| // base class for different value format | |||
| class Value { | |||
| protected: | |||
| enum struct Type : uint8_t { UNKNOWN, NUMBER, STRING, OBJECT, ARRAY }; | |||
| Type m_type; | |||
| public: | |||
| template <typename T> | |||
| T* safe_cast(); | |||
| Value() { m_type = Type::UNKNOWN; } | |||
| Value(Type type) : m_type(type) {} | |||
| virtual ~Value() {} | |||
| bool is_array() { return Type::ARRAY == m_type; } | |||
| bool is_object() { return Type::OBJECT == m_type; } | |||
| bool is_number() { return Type::NUMBER == m_type; } | |||
| bool is_str() { return Type::STRING == m_type; } | |||
| std::unique_ptr<Value>& operator[](const std::string& key); | |||
| std::unique_ptr<Value>& operator[](const size_t index); | |||
| std::map<std::string, std::unique_ptr<Value>>& objects(); | |||
| size_t len(); | |||
| megdnn::SmallVector<std::unique_ptr<Value>>& array(); | |||
| double number(); | |||
| std::string str(); | |||
| }; | |||
| void expect(char c); | |||
| void skip_whitespace(); | |||
| std::unique_ptr<Value> parse_object(); | |||
| std::unique_ptr<Value> parse_array(); | |||
| std::unique_ptr<Value> parse_string(); | |||
| std::unique_ptr<Value> parse_number(); | |||
| std::unique_ptr<Value> parse_value(); | |||
| enum struct State : uint8_t { | |||
| OK = 0, | |||
| BAD_TYPE, | |||
| BAD_DIGIT, | |||
| BAD_ARRAY, | |||
| MISS_COLON, | |||
| MISS_BRACE, | |||
| KEY_NOT_UNIQUE | |||
| }; | |||
| JsonLoader() { m_state = State::OK; } | |||
| std::unique_ptr<Value> load(const char* content, const size_t size); | |||
| std::unique_ptr<Value> load(const char* path); | |||
| class NumberValue final : public Value { | |||
| friend std::unique_ptr<Value> JsonLoader::parse_number(); | |||
| double m_value; | |||
| public: | |||
| NumberValue() : Value(Type::NUMBER) {} | |||
| double value() { return m_value; } | |||
| }; | |||
| class StringValue final : public Value { | |||
| std::string m_value; | |||
| public: | |||
| StringValue() : Value(Type::STRING) {} | |||
| std::string value() { return m_value; } | |||
| friend std::unique_ptr<Value> JsonLoader::parse_string(); | |||
| }; | |||
| class ArrayValue final : public Value { | |||
| megdnn::SmallVector<std::unique_ptr<Value>> m_obj; | |||
| public: | |||
| ArrayValue() : Value(Type::ARRAY) {} | |||
| ArrayValue(ArrayValue& arr) : Value(arr) { | |||
| m_obj.clear(); | |||
| for (auto& item : arr.m_obj) { | |||
| m_obj.emplace_back(item.get()); | |||
| item.release(); | |||
| } | |||
| } | |||
| ArrayValue(ArrayValue&& arr) : Value(arr) { | |||
| m_obj.clear(); | |||
| for (auto& item : arr.m_obj) { | |||
| m_obj.emplace_back(item.get()); | |||
| item.release(); | |||
| } | |||
| } | |||
| friend std::unique_ptr<Value> JsonLoader::parse_array(); | |||
| friend std::unique_ptr<JsonLoader::Value>& JsonLoader::Value::operator[]( | |||
| const size_t index); | |||
| friend megdnn::SmallVector<std::unique_ptr<JsonLoader::Value>>& JsonLoader:: | |||
| Value::array(); | |||
| friend size_t JsonLoader::Value::len(); | |||
| }; | |||
| class ObjectValue final : public Value { | |||
| std::map<std::string, std::unique_ptr<Value>> m_obj; | |||
| public: | |||
| ObjectValue() : Value(Type::OBJECT) {} | |||
| ObjectValue(ObjectValue& arr) : Value(arr) { | |||
| m_obj.clear(); | |||
| for (auto itra = arr.m_obj.begin(); itra != arr.m_obj.end(); ++itra) { | |||
| m_obj.emplace(std::make_pair(itra->first, std::move(itra->second))); | |||
| } | |||
| } | |||
| ObjectValue(ObjectValue&& arr) : Value(arr) { | |||
| m_obj.clear(); | |||
| for (auto itra = arr.m_obj.begin(); itra != arr.m_obj.end(); ++itra) { | |||
| m_obj.emplace(std::make_pair(itra->first, std::move(itra->second))); | |||
| } | |||
| } | |||
| friend std::unique_ptr<Value> JsonLoader::parse_object(); | |||
| friend std::unique_ptr<JsonLoader::Value>& JsonLoader::Value::operator[]( | |||
| const std::string&); | |||
| friend std::map<std::string, std::unique_ptr<JsonLoader::Value>>& JsonLoader:: | |||
| Value::objects(); | |||
| friend size_t JsonLoader::Value::len(); | |||
| }; | |||
| private: | |||
| const char* m_buf; | |||
| State m_state; | |||
| }; | |||
| } // namespace mgb | |||
| @@ -0,0 +1,615 @@ | |||
| /* | |||
| Copyright 2017 Leon Merten Lohse | |||
| Permission is hereby granted, free of charge, to any person obtaining a copy | |||
| of this software and associated documentation files (the "Software"), to deal | |||
| in the Software without restriction, including without limitation the rights | |||
| to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | |||
| copies of the Software, and to permit persons to whom the Software is | |||
| furnished to do so, subject to the following conditions: | |||
| The above copyright notice and this permission notice shall be included in | |||
| all copies or substantial portions of the Software. | |||
| THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |||
| IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | |||
| FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | |||
| AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | |||
| LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | |||
| OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | |||
| SOFTWARE. | |||
| */ | |||
| #ifndef NPY_H | |||
| #define NPY_H | |||
| #include <algorithm> | |||
| #include <complex> | |||
| #include <cstdint> | |||
| #include <cstring> | |||
| #include <fstream> | |||
| #include <iostream> | |||
| #include <regex> | |||
| #include <sstream> | |||
| #include <stdexcept> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <vector> | |||
| namespace npy { | |||
| /* Compile-time test for byte order. | |||
| If your compiler does not define these per default, you may want to define | |||
| one of these constants manually. | |||
| Defaults to little endian order. */ | |||
| #if defined(__BYTE_ORDER) && __BYTE_ORDER == __BIG_ENDIAN || \ | |||
| defined(__BIG_ENDIAN__) || defined(__ARMEB__) || defined(__THUMBEB__) || \ | |||
| defined(__AARCH64EB__) || defined(_MIBSEB) || defined(__MIBSEB) || \ | |||
| defined(__MIBSEB__) | |||
| const bool big_endian = true; | |||
| #else | |||
| const bool big_endian = false; | |||
| #endif | |||
| const char magic_string[] = "\x93NUMPY"; | |||
| const size_t magic_string_length = 6; | |||
| const char little_endian_char = '<'; | |||
| const char big_endian_char = '>'; | |||
| const char no_endian_char = '|'; | |||
| constexpr char host_endian_char = (big_endian ? big_endian_char : little_endian_char); | |||
| /* npy array length */ | |||
| typedef unsigned long int ndarray_len_t; | |||
| inline void write_magic( | |||
| std::ostream& ostream, unsigned char v_major = 1, unsigned char v_minor = 0) { | |||
| ostream.write(magic_string, magic_string_length); | |||
| ostream.put(v_major); | |||
| ostream.put(v_minor); | |||
| } | |||
| inline void read_magic( | |||
| std::istream& istream, unsigned char& v_major, unsigned char& v_minor) { | |||
| char buf[magic_string_length + 2]; | |||
| istream.read(buf, magic_string_length + 2); | |||
| if (!istream) { | |||
| fprintf(stderr, "io error: failed reading file"); | |||
| } | |||
| if (0 != std::memcmp(buf, magic_string, magic_string_length)) { | |||
| fprintf(stderr, "this file does not have a valid npy format."); | |||
| } | |||
| v_major = buf[magic_string_length]; | |||
| v_minor = buf[magic_string_length + 1]; | |||
| } | |||
| // typestring magic | |||
| struct Typestring { | |||
| private: | |||
| char c_endian; | |||
| char c_type; | |||
| int len; | |||
| public: | |||
| inline std::string str() { | |||
| const size_t max_buflen = 16; | |||
| char buf[max_buflen]; | |||
| std::sprintf(buf, "%c%c%u", c_endian, c_type, len); | |||
| return std::string(buf); | |||
| } | |||
| Typestring(const std::vector<float>&) | |||
| : c_endian{host_endian_char}, c_type{'f'}, len{sizeof(float)} {} | |||
| Typestring(const std::vector<double>&) | |||
| : c_endian{host_endian_char}, c_type{'f'}, len{sizeof(double)} {} | |||
| Typestring(const std::vector<long double>&) | |||
| : c_endian{host_endian_char}, c_type{'f'}, len{sizeof(long double)} {} | |||
| Typestring(const std::vector<char>&) | |||
| : c_endian{no_endian_char}, c_type{'i'}, len{sizeof(char)} {} | |||
| Typestring(const std::vector<short>&) | |||
| : c_endian{host_endian_char}, c_type{'i'}, len{sizeof(short)} {} | |||
| Typestring(const std::vector<int>&) | |||
| : c_endian{host_endian_char}, c_type{'i'}, len{sizeof(int)} {} | |||
| Typestring(const std::vector<long>&) | |||
| : c_endian{host_endian_char}, c_type{'i'}, len{sizeof(long)} {} | |||
| Typestring(const std::vector<long long>&) | |||
| : c_endian{host_endian_char}, c_type{'i'}, len{sizeof(long long)} {} | |||
| Typestring(const std::vector<unsigned char>&) | |||
| : c_endian{no_endian_char}, c_type{'u'}, len{sizeof(unsigned char)} {} | |||
| Typestring(const std::vector<unsigned short>&) | |||
| : c_endian{host_endian_char}, c_type{'u'}, len{sizeof(unsigned short)} {} | |||
| Typestring(const std::vector<unsigned int>&) | |||
| : c_endian{host_endian_char}, c_type{'u'}, len{sizeof(unsigned int)} {} | |||
| Typestring(const std::vector<unsigned long>&) | |||
| : c_endian{host_endian_char}, c_type{'u'}, len{sizeof(unsigned long)} {} | |||
| Typestring(const std::vector<unsigned long long>&) | |||
| : c_endian{host_endian_char}, | |||
| c_type{'u'}, | |||
| len{sizeof(unsigned long long)} {} | |||
| Typestring(const std::vector<std::complex<float>>&) | |||
| : c_endian{host_endian_char}, | |||
| c_type{'c'}, | |||
| len{sizeof(std::complex<float>)} {} | |||
| Typestring(const std::vector<std::complex<double>>&) | |||
| : c_endian{host_endian_char}, | |||
| c_type{'c'}, | |||
| len{sizeof(std::complex<double>)} {} | |||
| Typestring(const std::vector<std::complex<long double>>&) | |||
| : c_endian{host_endian_char}, | |||
| c_type{'c'}, | |||
| len{sizeof(std::complex<long double>)} {} | |||
| }; | |||
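| // Worked example (illustrative): on a little-endian host, | |||
| //   Typestring(std::vector<float>{}).str()         == "<f4" | |||
| //   Typestring(std::vector<unsigned char>{}).str() == "|u1" | |||
| // i.e. endian char, type char ('f'/'i'/'u'/'c'), then the element width in bytes. | |||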
| inline void parse_typestring(std::string typestring) { | |||
| std::regex re("'([<>|])([ifuc])(\\d+)'"); | |||
| std::smatch sm; | |||
| std::regex_match(typestring, sm, re); | |||
| if (sm.size() != 4) { | |||
| fprintf(stderr, "invalid typestring"); | |||
| } | |||
| } | |||
| namespace pyparse { | |||
| /** | |||
| Removes leading and trailing whitespaces | |||
| */ | |||
| inline std::string trim(const std::string& str) { | |||
| const std::string whitespace = " \t"; | |||
| auto begin = str.find_first_not_of(whitespace); | |||
| if (begin == std::string::npos) | |||
| return ""; | |||
| auto end = str.find_last_not_of(whitespace); | |||
| return str.substr(begin, end - begin + 1); | |||
| } | |||
| inline std::string get_value_from_map(const std::string& mapstr) { | |||
| size_t sep_pos = mapstr.find_first_of(":"); | |||
| if (sep_pos == std::string::npos) | |||
| return ""; | |||
| std::string tmp = mapstr.substr(sep_pos + 1); | |||
| return trim(tmp); | |||
| } | |||
| /** | |||
| Parses the string representation of a Python dict | |||
| The keys need to be known and may not appear anywhere else in the data. | |||
| */ | |||
| inline std::unordered_map<std::string, std::string> parse_dict( | |||
| std::string in, std::vector<std::string>& keys) { | |||
| std::unordered_map<std::string, std::string> map; | |||
| if (keys.size() == 0) | |||
| return map; | |||
| in = trim(in); | |||
| // unwrap dictionary | |||
| if ((in.front() == '{') && (in.back() == '}')) | |||
| in = in.substr(1, in.length() - 2); | |||
| else { | |||
| fprintf(stderr, "Not a Python dictionary."); | |||
| } | |||
| std::vector<std::pair<size_t, std::string>> positions; | |||
| for (auto const& value : keys) { | |||
| size_t pos = in.find("'" + value + "'"); | |||
| if (pos == std::string::npos) { | |||
| fprintf(stderr, "Missing %s key.", value.c_str()); | |||
| } | |||
| std::pair<size_t, std::string> position_pair{pos, value}; | |||
| positions.push_back(position_pair); | |||
| } | |||
| // sort by position in dict | |||
| std::sort(positions.begin(), positions.end()); | |||
| for (size_t i = 0; i < positions.size(); ++i) { | |||
| std::string raw_value; | |||
| size_t begin{positions[i].first}; | |||
| size_t end{std::string::npos}; | |||
| std::string key = positions[i].second; | |||
| if (i + 1 < positions.size()) | |||
| end = positions[i + 1].first; | |||
| raw_value = in.substr(begin, end - begin); | |||
| raw_value = trim(raw_value); | |||
| if (raw_value.back() == ',') | |||
| raw_value.pop_back(); | |||
| map[key] = get_value_from_map(raw_value); | |||
| } | |||
| return map; | |||
| } | |||
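| // Worked example (illustrative): for a typical npy v1.0 header dict | |||
| //   {'descr': '<f4', 'fortran_order': False, 'shape': (3, 4), } | |||
| // parse_dict with keys {"descr", "fortran_order", "shape"} returns | |||
| //   map["descr"] == "'<f4'" (quotes kept; parse_str strips them later), | |||
| //   map["fortran_order"] == "False", map["shape"] == "(3, 4)". | |||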
| /** | |||
| Parses the string representation of a Python boolean | |||
| */ | |||
| inline bool parse_bool(const std::string& in) { | |||
| if (in == "True") | |||
| return true; | |||
| if (in == "False") | |||
| return false; | |||
| fprintf(stderr, "Invalid Python boolean."); | |||
| return false; | |||
| } | |||
| /** | |||
| Parses the string representation of a Python str | |||
| */ | |||
| inline std::string parse_str(const std::string& in) { | |||
| if ((in.front() == '\'') && (in.back() == '\'')) | |||
| return in.substr(1, in.length() - 2); | |||
| fprintf(stderr, "Invalid Python string."); | |||
| return ""; | |||
| } | |||
| /** | |||
| Parses the string representation of a Python tuple into a vector of its items | |||
| */ | |||
| inline std::vector<std::string> parse_tuple(std::string in) { | |||
| std::vector<std::string> v; | |||
| const char separator = ','; | |||
| in = trim(in); | |||
| if ((in.front() == '(') && (in.back() == ')')) | |||
| in = in.substr(1, in.length() - 2); | |||
| else { | |||
| fprintf(stderr, "Invalid Python tuple."); | |||
| } | |||
| std::istringstream iss(in); | |||
| for (std::string token; std::getline(iss, token, separator);) { | |||
| v.push_back(token); | |||
| } | |||
| return v; | |||
| } | |||
| template <typename T> | |||
| inline std::string write_tuple(const std::vector<T>& v) { | |||
| if (v.size() == 0) | |||
| return ""; | |||
| std::ostringstream ss; | |||
| if (v.size() == 1) { | |||
| ss << "(" << v.front() << ",)"; | |||
| } else { | |||
| const std::string delimiter = ", "; | |||
| // v.size() > 1 | |||
| ss << "("; | |||
| std::copy( | |||
| v.begin(), v.end() - 1, | |||
| std::ostream_iterator<T>(ss, delimiter.c_str())); | |||
| ss << v.back(); | |||
| ss << ")"; | |||
| } | |||
| return ss.str(); | |||
| } | |||
| inline std::string write_boolean(bool b) { | |||
| if (b) | |||
| return "True"; | |||
| else | |||
| return "False"; | |||
| } | |||
| } // namespace pyparse | |||
| inline void parse_header(std::string header, std::string& descr) { | |||
| /* | |||
| The first 6 bytes are a magic string: exactly "\x93NUMPY". | |||
| The next 1 byte is an unsigned byte: the major version number of the file | |||
| format, e.g. \x01. The next 1 byte is an unsigned byte: the minor version | |||
| number of the file format, e.g. \x00. Note: the version of the file format | |||
| is not tied to the version of the numpy package. The next 2 bytes form a | |||
| little-endian unsigned short int: the length of the header data | |||
| HEADER_LEN. The next HEADER_LEN bytes form the header data describing the | |||
| array's format. It is an ASCII string which contains a Python literal | |||
| expression of a dictionary. It is terminated by a newline ('\n') and | |||
| padded with spaces | |||
| ('\x20') to make the total length of the magic string + 4 + HEADER_LEN be | |||
| evenly divisible by 16 for alignment purposes. The dictionary contains | |||
| three keys: | |||
| "descr" : dtype.descr | |||
| An object that can be passed as an argument to the numpy.dtype() | |||
| constructor to create the array's dtype. For repeatability and | |||
| readability, this dictionary is formatted using pprint.pformat() so the | |||
| keys are in alphabetic order. | |||
| */ | |||
| // remove trailing newline | |||
| if (header.back() != '\n') | |||
| fprintf(stderr, "invalid header"); | |||
| header.pop_back(); | |||
| // parse the dictionary | |||
| std::vector<std::string> keys{"descr"}; | |||
| auto dict_map = npy::pyparse::parse_dict(header, keys); | |||
| if (dict_map.size() == 0) | |||
| fprintf(stderr, "invalid dictionary in header"); | |||
| std::string descr_s = dict_map["descr"]; | |||
| parse_typestring(descr_s); | |||
| // remove | |||
| descr = npy::pyparse::parse_str(descr_s); | |||
| return; | |||
| } | |||
| inline void parse_header( | |||
| std::string header, std::string& descr, bool& fortran_order, | |||
| std::vector<ndarray_len_t>& shape) { | |||
| /* | |||
| The first 6 bytes are a magic string: exactly "\x93NUMPY". | |||
| The next 1 byte is an unsigned byte: the major version number of the file | |||
| format, e.g. \x01. The next 1 byte is an unsigned byte: the minor version | |||
| number of the file format, e.g. \x00. Note: the version of the file format | |||
| is not tied to the version of the numpy package. The next 2 bytes form a | |||
| little-endian unsigned short int: the length of the header data | |||
| HEADER_LEN. The next HEADER_LEN bytes form the header data describing the | |||
| array's format. It is an ASCII string which contains a Python literal | |||
| expression of a dictionary. It is terminated by a newline ('\n') and | |||
| padded with spaces | |||
| ('\x20') to make the total length of the magic string + 4 + HEADER_LEN be | |||
| evenly divisible by 16 for alignment purposes. The dictionary contains | |||
| three keys: | |||
| "descr" : dtype.descr | |||
| An object that can be passed as an argument to the numpy.dtype() | |||
| constructor to create the array's dtype. "fortran_order" : bool Whether | |||
| the array data is Fortran-contiguous or not. Since Fortran-contiguous | |||
| arrays are a common form of non-C-contiguity, we allow them to be written | |||
| directly to disk for efficiency. "shape" : tuple of int The shape of the | |||
| array. For repeatability and readability, this dictionary is formatted | |||
| using pprint.pformat() so the keys are in alphabetic order. | |||
| */ | |||
| // remove trailing newline | |||
| if (header.back() != '\n') | |||
| fprintf(stderr, "invalid header"); | |||
| header.pop_back(); | |||
| // parse the dictionary | |||
| std::vector<std::string> keys{"descr", "fortran_order", "shape"}; | |||
| auto dict_map = npy::pyparse::parse_dict(header, keys); | |||
| if (dict_map.size() == 0) | |||
| fprintf(stderr, "invalid dictionary in header"); | |||
| std::string descr_s = dict_map["descr"]; | |||
| std::string fortran_s = dict_map["fortran_order"]; | |||
| std::string shape_s = dict_map["shape"]; | |||
| // TODO: extract info from typestring | |||
| parse_typestring(descr_s); | |||
| // remove | |||
| descr = npy::pyparse::parse_str(descr_s); | |||
| // convert literal Python bool to C++ bool | |||
| fortran_order = npy::pyparse::parse_bool(fortran_s); | |||
| // parse the shape tuple | |||
| auto shape_v = npy::pyparse::parse_tuple(shape_s); | |||
| if (shape_v.size() == 0) | |||
| fprintf(stderr, "invalid shape tuple in header"); | |||
| for (auto item : shape_v) { | |||
| ndarray_len_t dim = static_cast<ndarray_len_t>(std::stoul(item)); | |||
| shape.push_back(dim); | |||
| } | |||
| } | |||
| inline std::string write_header_dict( | |||
| const std::string& descr, bool fortran_order, | |||
| const std::vector<ndarray_len_t>& shape) { | |||
| std::string s_fortran_order = npy::pyparse::write_boolean(fortran_order); | |||
| std::string shape_s = npy::pyparse::write_tuple(shape); | |||
| return "{'descr': '" + descr + "', 'fortran_order': " + s_fortran_order + | |||
| ", 'shape': " + shape_s + ", }"; | |||
| } | |||
| inline void write_header( | |||
| std::ostream& out, const std::string& descr, bool fortran_order, | |||
| const std::vector<ndarray_len_t>& shape_v) { | |||
| std::string header_dict = write_header_dict(descr, fortran_order, shape_v); | |||
| size_t length = magic_string_length + 2 + 2 + header_dict.length() + 1; | |||
| unsigned char version[2] = {1, 0}; | |||
| if (length >= 255 * 255) { | |||
| length = magic_string_length + 2 + 4 + header_dict.length() + 1; | |||
| version[0] = 2; | |||
| version[1] = 0; | |||
| } | |||
| size_t padding_len = 16 - length % 16; | |||
| std::string padding(padding_len, ' '); | |||
| // write magic | |||
| write_magic(out, version[0], version[1]); | |||
| // write header length | |||
| if (version[0] == 1 && version[1] == 0) { | |||
| char header_len_le16[2]; | |||
| uint16_t header_len = | |||
| static_cast<uint16_t>(header_dict.length() + padding.length() + 1); | |||
| header_len_le16[0] = (header_len >> 0) & 0xff; | |||
| header_len_le16[1] = (header_len >> 8) & 0xff; | |||
| out.write(reinterpret_cast<char*>(header_len_le16), 2); | |||
| } else { | |||
| char header_len_le32[4]; | |||
| uint32_t header_len = | |||
| static_cast<uint32_t>(header_dict.length() + padding.length() + 1); | |||
| header_len_le32[0] = (header_len >> 0) & 0xff; | |||
| header_len_le32[1] = (header_len >> 8) & 0xff; | |||
| header_len_le32[2] = (header_len >> 16) & 0xff; | |||
| header_len_le32[3] = (header_len >> 24) & 0xff; | |||
| out.write(reinterpret_cast<char*>(header_len_le32), 4); | |||
| } | |||
| out << header_dict << padding << '\n'; | |||
| } | |||
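| // Worked example (assumed values): with descr "<f4", fortran_order False and | |||
| // shape (3, 4), write_header_dict() yields a 59-char dict string, so | |||
| // length = 6 (magic) + 2 (version) + 2 (HEADER_LEN field) + 59 + 1 ('\n') = 70, | |||
| // padding_len = 16 - 70 % 16 = 10 and HEADER_LEN = 59 + 10 + 1 = 70; the full | |||
| // header block is then 80 bytes, a multiple of 16 as the format requires. | |||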
| inline std::string read_header(std::istream& istream) { | |||
| // check magic bytes and version number | |||
| unsigned char v_major, v_minor; | |||
| read_magic(istream, v_major, v_minor); | |||
| uint32_t header_length = 0; | |||
| if (v_major == 1 && v_minor == 0) { | |||
| unsigned char header_len_le16[2]; | |||
| istream.read(reinterpret_cast<char*>(header_len_le16), 2); | |||
| // use unsigned bytes to avoid sign extension when widening | |||
| header_length = (header_len_le16[0] << 0) | (header_len_le16[1] << 8); | |||
| if ((magic_string_length + 2 + 2 + header_length) % 16 != 0) { | |||
| // TODO: display warning | |||
| } | |||
| } else if (v_major == 2 && v_minor == 0) { | |||
| unsigned char header_len_le32[4]; | |||
| istream.read(reinterpret_cast<char*>(header_len_le32), 4); | |||
| // use unsigned bytes to avoid sign extension when widening | |||
| header_length = (header_len_le32[0] << 0) | (header_len_le32[1] << 8) | | |||
| (header_len_le32[2] << 16) | (header_len_le32[3] << 24); | |||
| if ((magic_string_length + 2 + 4 + header_length) % 16 != 0) { | |||
| // TODO: display warning | |||
| } | |||
| } else { | |||
| fprintf(stderr, "unsupported file format version"); | |||
| } | |||
| auto buf_v = std::vector<char>(); | |||
| buf_v.resize(header_length); // resize (not reserve) so data() is valid to fill | |||
| istream.read(buf_v.data(), header_length); | |||
| std::string header(buf_v.data(), header_length); | |||
| return header; | |||
| } | |||
| inline ndarray_len_t comp_size(const std::vector<ndarray_len_t>& shape) { | |||
| ndarray_len_t size = 1; | |||
| for (ndarray_len_t i : shape) | |||
| size *= i; | |||
| return size; | |||
| } | |||
| template <typename Scalar> | |||
| inline void SaveArrayAsNumpy( | |||
| const std::string& filename, bool fortran_order, unsigned int n_dims, | |||
| const unsigned long shape[], const std::vector<Scalar>& data) { | |||
| Typestring typestring_o(data); | |||
| std::string typestring = typestring_o.str(); | |||
| std::ofstream stream(filename, std::ofstream::binary); | |||
| if (!stream) { | |||
| fprintf(stderr, "io error: failed to open a file."); | |||
| } | |||
| std::vector<ndarray_len_t> shape_v(shape, shape + n_dims); | |||
| write_header(stream, typestring, fortran_order, shape_v); | |||
| auto size = static_cast<size_t>(comp_size(shape_v)); | |||
| stream.write(reinterpret_cast<const char*>(data.data()), sizeof(Scalar) * size); | |||
| } | |||
| template <typename Scalar> | |||
| inline void LoadArrayFromNumpy( | |||
| const std::string& filename, std::vector<unsigned long>& shape, | |||
| std::vector<Scalar>& data) { | |||
| bool fortran_order; | |||
| LoadArrayFromNumpy<Scalar>(filename, shape, fortran_order, data); | |||
| } | |||
| template <typename Scalar> | |||
| inline void LoadArrayFromNumpy( | |||
| const std::string& filename, std::vector<unsigned long>& shape, | |||
| bool& fortran_order, std::vector<Scalar>& data) { | |||
| std::ifstream stream(filename, std::ifstream::binary); | |||
| if (!stream) { | |||
| fprintf(stderr, "io error: failed to open a file."); | |||
| } | |||
| std::string header = read_header(stream); | |||
| // parse header | |||
| std::string typestr; | |||
| parse_header(header, typestr, fortran_order, shape); | |||
| // check if the typestring matches the given one | |||
| Typestring typestring_o{data}; | |||
| std::string expect_typestr = typestring_o.str(); | |||
| if (typestr != expect_typestr) { | |||
| fprintf(stderr, "formatting error: typestrings not matching"); | |||
| } | |||
| // compute the data size based on the shape | |||
| auto size = static_cast<size_t>(comp_size(shape)); | |||
| data.resize(size); | |||
| // read the data | |||
| stream.read(reinterpret_cast<char*>(data.data()), sizeof(Scalar) * size); | |||
| } | |||
| inline void LoadArrayFromNumpy( | |||
| const std::string& filename, std::string& type_str, | |||
| std::vector<ndarray_len_t>& shape, std::vector<int8_t>& data) { | |||
| std::ifstream stream(filename, std::ifstream::binary); | |||
| if (!stream) { | |||
| fprintf(stderr, "io error: failed to open a file."); | |||
| } | |||
| std::string header = read_header(stream); | |||
| bool fortran_order; | |||
| // parse header | |||
| parse_header(header, type_str, fortran_order, shape); | |||
| // element size in bytes: the digits after the endian and type characters | |||
| std::string size_str = type_str.substr(2); | |||
| size_t elem_size = atoi(size_str.c_str()); | |||
| // compute the data size based on the shape | |||
| auto byte_size = elem_size * static_cast<size_t>(comp_size(shape)); | |||
| data.resize(byte_size); | |||
| // read the data | |||
| stream.read(reinterpret_cast<char*>(data.data()), byte_size); | |||
| } | |||
| } // namespace npy | |||
| #endif // NPY_H | |||
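| A round-trip sketch for the save/load entry points above; the file name "data.npy" is illustrative: | |||
| #include <vector> | |||
| #include "npy.h" | |||
| void npy_roundtrip() { | |||
|     std::vector<float> src{1.f, 2.f, 3.f, 4.f, 5.f, 6.f}; | |||
|     unsigned long shape[] = {2, 3}; | |||
|     npy::SaveArrayAsNumpy("data.npy", false, 2, shape, src); | |||
|     std::vector<unsigned long> loaded_shape; | |||
|     std::vector<float> dst; | |||
|     npy::LoadArrayFromNumpy("data.npy", loaded_shape, dst); | |||
|     // expect loaded_shape == {2, 3} and dst == src | |||
| } | |||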
| @@ -0,0 +1,48 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/helpers/outdumper.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #include "outdumper.h" | |||
| #include "megbrain/utils/debug.h" | |||
| using namespace lar; | |||
| void OutputDumper::set(mgb::SymbolVarArray& symb_var) { | |||
| for (auto&& i : symb_var) { | |||
| auto&& var = i.node(); | |||
| DumpInfo info; | |||
| info.var_info = mgb::cg::dump_var_info({var}); | |||
| info.owner_inputs_info = mgb::cg::dump_var_info(var->owner_opr()->input()); | |||
| info.id = var->id(); | |||
| m_infos.push_back(info); | |||
| } | |||
| } | |||
| mgb::ComputingGraph::Callback OutputDumper::bind() { | |||
| auto& info = m_infos.at(m_bind_id++); | |||
| mgb::ComputingGraph::Callback cb = [&info](const mgb::DeviceTensorND& dv) { | |||
| info.hv.copy_from(dv); | |||
| }; | |||
| return cb; | |||
| } | |||
| void OutputDumper::write_to_file() { | |||
| if (!dump_file.empty()) { | |||
| for (auto&& info : m_infos) { | |||
| auto value = mgb::debug::dump_tensor( | |||
| info.hv, | |||
| mgb::ssprintf( | |||
| "var=%s owner_opr_inputs= %s", info.var_info.c_str(), | |||
| info.owner_inputs_info.c_str())); | |||
| mgb::debug::write_to_file( | |||
| mgb::ssprintf( | |||
| "%s/run%zu-var %zd", dump_file.c_str(), m_run_id, info.id) | |||
| .c_str(), | |||
| value); | |||
| } | |||
| } | |||
| m_run_id++; | |||
| } | |||
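| A usage sketch for OutputDumper; the output directory "/tmp/out" is illustrative and load_result is assumed to come from a GraphLoader, as in model_mdl.cpp below: | |||
| #include "outdumper.h" | |||
| #include "megbrain/serialization/serializer.h" | |||
| void dump_all_outputs(mgb::serialization::GraphLoader::LoadResult& load_result) { | |||
|     lar::OutputDumper dumper("/tmp/out"); | |||
|     dumper.set(load_result.output_var_list); | |||
|     mgb::ComputingGraph::OutputSpec spec; | |||
|     for (auto&& var : load_result.output_var_list) | |||
|         spec.emplace_back(var, dumper.bind()); // one callback per output var | |||
|     auto func = load_result.graph_compile(spec); | |||
|     func->execute(); | |||
|     func->wait(); | |||
|     dumper.write_to_file(); // emits run<N>-var<id> files under /tmp/out | |||
| } | |||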
| @@ -0,0 +1,42 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/helpers/outdumper.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #pragma once | |||
| #include "megbrain/serialization/serializer.h" | |||
| namespace lar { | |||
| /*! | |||
| * \brief dumper for model outputs, used for --bin-out-dump | |||
| */ | |||
| class OutputDumper { | |||
| public: | |||
| struct DumpInfo { | |||
| mgb::HostTensorND hv = {}; | |||
| std::string var_info; | |||
| std::string owner_inputs_info; | |||
| size_t id; | |||
| }; | |||
| //! init the dump_file path | |||
| OutputDumper(const char* file) { dump_file = file; } | |||
| //! set the dump information | |||
| void set(mgb::SymbolVarArray& symb_var); | |||
| //! callback function for specifying outputs when compiling the computing graph | |||
| mgb::ComputingGraph::Callback bind(); | |||
| //! write dumped output into dump_file | |||
| void write_to_file(); | |||
| private: | |||
| mgb::SmallVector<DumpInfo> m_infos; | |||
| size_t m_run_id = 0; | |||
| size_t m_bind_id = 0; | |||
| std::string dump_file; | |||
| }; | |||
| } // namespace lar | |||
| @@ -0,0 +1,119 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/helpers/text_table.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "text_table.h" | |||
| using namespace mgb; | |||
| namespace { | |||
| inline void mid(std::ostream& os, const std::string& str, size_t max_w) { | |||
| size_t l = (max_w - str.length()) / 2 + str.length(); | |||
| size_t r = max_w - l; | |||
| os << std::setw(l) << std::right << str; | |||
| if (r > 0) | |||
| os << std::setw(r) << ' '; | |||
| } | |||
| inline size_t char_length(char c) { | |||
| return c ? 1 : 0; | |||
| } | |||
| } // namespace | |||
| void TextTable::adjuster_last_row() { | |||
| if (m_rows.empty()) | |||
| return; | |||
| auto& row = m_rows.back(); | |||
| if (row.params.horizontal == 0 or row.params.vertical == 0) { | |||
| row.params.corner = 0; | |||
| } | |||
| if (row.params.horizontal != 0 && row.params.vertical != 0 && | |||
| row.params.corner == 0) { | |||
| row.params.corner = row.params.horizontal; | |||
| } | |||
| } | |||
| void TextTable::show(std::ostream& os) { | |||
| if (m_rows.empty()) | |||
| return; | |||
| auto& last_row = m_rows.front(); | |||
| bool first = true; | |||
| for (auto& row : m_rows) { | |||
| auto& lrow = | |||
| (last_row.values.size() * char_length(last_row.params.horizontal)) > | |||
| (row.values.size() * char_length(row.params.horizontal)) | |||
| ? last_row | |||
| : row; | |||
| // line before row | |||
| if (lrow.params.horizontal) { | |||
| if (not first) | |||
| os << std::endl; | |||
| os << m_prefix; | |||
| if (lrow.params.corner) | |||
| os << lrow.params.corner; | |||
| size_t skip_size = 0; | |||
| // table name | |||
| if (first) { | |||
| os << m_name; | |||
| skip_size = m_name.length(); | |||
| } | |||
| for (size_t i = 0; i < lrow.values.size(); ++i) { | |||
| auto max_w = m_cols_max_w.at(i) + m_padding * 2; | |||
| if (max_w + char_length(lrow.params.corner) <= skip_size) { | |||
| skip_size = skip_size - max_w - char_length(lrow.params.corner); | |||
| continue; | |||
| } | |||
| size_t rest = max_w + char_length(lrow.params.corner) - skip_size; | |||
| skip_size = 0; | |||
| if (rest > char_length(lrow.params.corner)) { | |||
| os << std::string( | |||
| rest - char_length(lrow.params.corner), | |||
| lrow.params.horizontal); | |||
| rest = char_length(lrow.params.corner); | |||
| } | |||
| if (rest > 0 && lrow.params.corner) | |||
| os << lrow.params.corner; | |||
| } | |||
| } else if (first) { | |||
| os << m_prefix << ' ' << m_name; | |||
| } | |||
| first = false; | |||
| os << std::endl << m_prefix; | |||
| if (row.params.vertical) | |||
| os << row.params.vertical; | |||
| // row | |||
| for (size_t i = 0; i < row.values.size(); ++i) { | |||
| auto& str = row.values.at(i); | |||
| auto max_w = m_cols_max_w.at(i) + 2 * m_padding; | |||
| if (row.params.align == Align::Mid) { | |||
| mid(os, str, max_w); | |||
| } else if (row.params.align == Align::Left) { | |||
| os << std::setw(max_w) << std::left << str; | |||
| } else { | |||
| os << std::setw(max_w) << std::right << str; | |||
| } | |||
| if (row.params.vertical) | |||
| os << row.params.vertical; | |||
| } | |||
| last_row = row; | |||
| } | |||
| if (last_row.params.horizontal) { | |||
| os << std::endl << m_prefix; | |||
| if (last_row.params.corner) | |||
| os << last_row.params.corner; | |||
| for (size_t i = 0; i < last_row.values.size(); ++i) { | |||
| auto max_w = m_cols_max_w.at(i); | |||
| std::string tmp(max_w + m_padding * 2, last_row.params.horizontal); | |||
| os << tmp; | |||
| if (last_row.params.corner) | |||
| os << last_row.params.corner; | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,133 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/helpers/text_table.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include <array> | |||
| #include <iomanip> | |||
| #include <ostream> | |||
| #include <sstream> | |||
| #include <string> | |||
| #include <tuple> | |||
| #include <type_traits> | |||
| #include <vector> | |||
| #include "megbrain/common.h" | |||
| namespace mgb { | |||
| class TextTable { | |||
| public: | |||
| enum Level { Summary, Detail }; | |||
| enum class Align : int { Left, Right, Mid }; | |||
| explicit TextTable(const std::string& table_name) : m_name(table_name) {} | |||
| TextTable& horizontal(char c) { | |||
| m_row.params.horizontal = c; | |||
| return *this; | |||
| } | |||
| TextTable& vertical(char c) { | |||
| m_row.params.vertical = c; | |||
| return *this; | |||
| } | |||
| TextTable& corner(char c) { | |||
| m_row.params.corner = c; | |||
| return *this; | |||
| } | |||
| TextTable& align(Align v) { | |||
| m_row.params.align = v; | |||
| return *this; | |||
| } | |||
| TextTable& padding(size_t w) { | |||
| m_padding = w; | |||
| return *this; | |||
| } | |||
| TextTable& prefix(const std::string& str) { | |||
| m_prefix = str; | |||
| return *this; | |||
| } | |||
| //! string overload; keeping it non-template avoids overload | |||
| //! ambiguity with the arithmetic overloads below | |||
| TextTable& add(const std::string& value) { | |||
| m_row.values.emplace_back(value); | |||
| if (m_cols_max_w.size() < m_row.values.size()) { | |||
| m_cols_max_w.emplace_back(m_row.values.back().length()); | |||
| } else { | |||
| mgb_assert(m_row.values.size() >= 1); | |||
| size_t i = m_row.values.size() - 1; | |||
| m_cols_max_w[i] = std::max(m_cols_max_w[i], m_row.values.back().length()); | |||
| } | |||
| return *this; | |||
| } | |||
| template < | |||
| typename T, | |||
| typename std::enable_if<std::is_floating_point<T>::value, bool>::type = 0> | |||
| TextTable& add(const T& value) { | |||
| std::stringstream ss; | |||
| ss << std::setiosflags(std::ios::fixed) << std::setprecision(2); | |||
| ss << value; | |||
| m_row.values.emplace_back(ss.str()); | |||
| if (m_cols_max_w.size() < m_row.values.size()) { | |||
| m_cols_max_w.emplace_back(m_row.values.back().length()); | |||
| } else { | |||
| mgb_assert(m_row.values.size() >= 1); | |||
| size_t i = m_row.values.size() - 1; | |||
| m_cols_max_w[i] = std::max(m_cols_max_w[i], m_row.values.back().length()); | |||
| } | |||
| return *this; | |||
| } | |||
| template < | |||
| typename T, | |||
| typename std::enable_if<std::is_integral<T>::value, bool>::type = 0> | |||
| TextTable& add(const T& value) { | |||
| m_row.values.emplace_back(std::to_string(value)); | |||
| // keep column widths in sync, as the other overloads do | |||
| if (m_cols_max_w.size() < m_row.values.size()) { | |||
| m_cols_max_w.emplace_back(m_row.values.back().length()); | |||
| } else { | |||
| size_t i = m_row.values.size() - 1; | |||
| m_cols_max_w[i] = std::max(m_cols_max_w[i], m_row.values.back().length()); | |||
| } | |||
| return *this; | |||
| } | |||
| void eor() { | |||
| m_rows.emplace_back(m_row); | |||
| adjuster_last_row(); | |||
| m_row.values.clear(); | |||
| } | |||
| void reset() { | |||
| m_row = {}; | |||
| m_cols_max_w.clear(); | |||
| m_padding = 0; | |||
| m_rows.clear(); | |||
| } | |||
| void show(std::ostream& os); | |||
| private: | |||
| void adjuster_last_row(); | |||
| std::string m_name; | |||
| std::vector<size_t> m_cols_max_w; | |||
| size_t m_padding = 0; | |||
| std::string m_prefix = ""; | |||
| struct Row { | |||
| std::vector<std::string> values; | |||
| struct Params { | |||
| Align align = Align::Left; | |||
| char horizontal = '-', vertical = '|', corner = '+'; | |||
| } params; | |||
| }; | |||
| std::vector<Row> m_rows; | |||
| Row m_row; | |||
| }; | |||
| inline std::ostream& operator<<(std::ostream& stream, TextTable& table) { | |||
| table.show(stream); | |||
| return stream; | |||
| } | |||
| } // namespace mgb | |||
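| A usage sketch for TextTable; the row contents are illustrative: | |||
| #include <iostream> | |||
| #include "text_table.h" | |||
| void table_demo() { | |||
|     mgb::TextTable table("bench"); | |||
|     table.padding(1); | |||
|     table.add("opr").add("time(ms)"); | |||
|     table.eor(); // close the header row | |||
|     table.add("conv1").add("1.25"); | |||
|     table.eor(); | |||
|     std::cout << table << std::endl; // operator<< calls table.show() | |||
| } | |||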
| @@ -0,0 +1,31 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/main.cpp | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #include <gflags/gflags.h> | |||
| #include <string> | |||
| #include "strategys/strategy.h" | |||
| int main(int argc, char** argv) { | |||
| std::string usage = "load_and_run <model_path> [options...]"; | |||
| if (argc < 2) { | |||
| printf("usage: %s\n", usage.c_str()); | |||
| return -1; | |||
| } | |||
| gflags::SetUsageMessage(usage); | |||
| gflags::SetVersionString("1.0"); | |||
| gflags::ParseCommandLineFlags(&argc, &argv, true); | |||
| std::string model_path = argv[1]; | |||
| auto strategy = lar::StrategyBase::create_strategy(model_path); | |||
| strategy->run(); | |||
| gflags::ShutDownCommandLineFlags(); | |||
| return 0; | |||
| } | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
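| Typical invocations look like the following; the model file name is illustrative, --lite is defined in model.cpp below, and the device flags (--cpu, --cuda, --multithread) in device_options.cpp: | |||
| load_and_run model.mge | |||
| load_and_run model.mge --lite | |||
| load_and_run model.mge --cpu | |||
| load_and_run model.mge --multithread 4 | |||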
| @@ -0,0 +1,60 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/models/model.cpp | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #include "model.h" | |||
| #include <iostream> | |||
| #include <memory> | |||
| #include "model_lite.h" | |||
| #include "model_mdl.h" | |||
| using namespace lar; | |||
| ModelType ModelBase::get_model_type(std::string model_path) { | |||
| //! read magic number of dump file | |||
| FILE* fin = fopen(model_path.c_str(), "rb"); | |||
| mgb_assert(fin, "failed to open %s: %s", model_path.c_str(), strerror(errno)); | |||
| char buf[16]; | |||
| mgb_assert(fread(buf, 1, 16, fin) == 16, "read model failed"); | |||
| fclose(fin); | |||
| // get model type | |||
| // uint32_t MGB_MAGIC = 0x5342474D | |||
| std::string tag(buf, sizeof(buf)); // buf is not null-terminated | |||
| ModelType type; | |||
| if (tag.substr(0, 7) == std::string("mgb0001") || | |||
| tag.substr(0, 8) == std::string("mgb0000a") || | |||
| tag.substr(0, 4) == std::string("MGBS") || | |||
| tag.substr(0, 8) == std::string("mgbtest0")) { | |||
| type = ModelType::MEGDL_MODEL; | |||
| } else { | |||
| type = ModelType::LITE_MODEL; | |||
| } | |||
| return type; | |||
| } | |||
| std::shared_ptr<ModelBase> ModelBase::create_model(std::string model_path) { | |||
| mgb_log_debug("model path %s\n", model_path.c_str()); | |||
| auto model_type = get_model_type(model_path); | |||
| if (ModelType::LITE_MODEL == model_type) { | |||
| return std::make_shared<ModelLite>(model_path); | |||
| } else if (ModelType::MEGDL_MODEL == model_type) { | |||
| if (FLAGS_lite) | |||
| return std::make_shared<ModelLite>(model_path); | |||
| else | |||
| return std::make_shared<ModelMdl>(model_path); | |||
| } else { | |||
| return nullptr; | |||
| } | |||
| } | |||
| DEFINE_bool(lite, false, "use the lite runtime to load and run an mdl model"); | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
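| A minimal driver sketch around ModelBase; the path is illustrative, and note that for an mdl model the run strategy still has to build the output spec (ModelMdl::make_output_spec) between load and run: | |||
| #include "model.h" | |||
| int run_once(const std::string& path) { | |||
|     auto model = lar::ModelBase::create_model(path); | |||
|     if (!model) | |||
|         return -1; | |||
|     model->set_shared_mem(false); | |||
|     model->load_model(); | |||
|     // ... graph compilation happens here for mdl models ... | |||
|     model->run_model(); | |||
|     model->wait(); // block until the asynchronous execution finishes | |||
|     return 0; | |||
| } | |||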
| @@ -0,0 +1,49 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/models/model.h | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #pragma once | |||
| #include <gflags/gflags.h> | |||
| #include <string> | |||
| #include "helpers/common.h" | |||
| DECLARE_bool(lite); | |||
| namespace lar { | |||
| /*! | |||
| * \brief: base class of model | |||
| */ | |||
| class ModelBase { | |||
| public: | |||
| //! get model type by the magic number in dump file | |||
| static ModelType get_model_type(std::string model_path); | |||
| //! create model by different model type | |||
| static std::shared_ptr<ModelBase> create_model(std::string model_path); | |||
| //! type of the model | |||
| virtual ModelType type() = 0; | |||
| //! set model load state | |||
| virtual void set_shared_mem(bool state) = 0; | |||
| //! load model interface for load and run strategy | |||
| virtual void load_model() = 0; | |||
| //! run model interface for load and run strategy | |||
| virtual void run_model() = 0; | |||
| //! wait asynchronous function interface for load and run strategy | |||
| virtual void wait() = 0; | |||
| virtual ~ModelBase() = default; | |||
| }; | |||
| } // namespace lar | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -0,0 +1,50 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/models/model_lite.cpp | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #include "model_lite.h" | |||
| #include <gflags/gflags.h> | |||
| #include <cstring> | |||
| #include "misc.h" | |||
| DECLARE_bool(share_param_mem); | |||
| using namespace lar; | |||
| ModelLite::ModelLite(const std::string& path) : model_path(path) { | |||
| LITE_WARN("create lite model using CPU as the default comp node"); | |||
| } | |||
| void ModelLite::load_model() { | |||
| m_network = std::make_shared<lite::Network>(config, IO); | |||
| if (share_model_mem) { | |||
| //! WARNING: sharing param memory this way may not always be safe | |||
| LITE_WARN("enable share model memory"); | |||
| FILE* fin = fopen(model_path.c_str(), "rb"); | |||
| LITE_ASSERT(fin, "failed to open %s: %s", model_path.c_str(), strerror(errno)); | |||
| fseek(fin, 0, SEEK_END); | |||
| size_t size = ftell(fin); | |||
| fseek(fin, 0, SEEK_SET); | |||
| void* ptr = malloc(size); | |||
| std::shared_ptr<void> buf{ptr, free}; | |||
| auto nr = fread(buf.get(), 1, size, fin); | |||
| LITE_ASSERT(nr == size, "read model file failed"); | |||
| fclose(fin); | |||
| m_network->load_model(buf.get(), size); | |||
| } else { | |||
| m_network->load_model(model_path); | |||
| } | |||
| } | |||
| void ModelLite::run_model() { | |||
| m_network->forward(); | |||
| } | |||
| void ModelLite::wait() { | |||
| m_network->wait(); | |||
| } | |||
| @@ -0,0 +1,73 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/models/model_lite.h | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #pragma once | |||
| #include <string> | |||
| #include "helpers/common.h" | |||
| #include "helpers/data_parser.h" | |||
| #include "lite/network.h" | |||
| #include "model.h" | |||
| namespace lar { | |||
| /*! | |||
| * \brief: megengine lite model | |||
| */ | |||
| class ModelLite : public ModelBase { | |||
| public: | |||
| using Strategy = LiteAlgoSelectStrategy; | |||
| ModelLite(const std::string& path); | |||
| //! model type | |||
| ModelType type() override { return ModelType::LITE_MODEL; } | |||
| //! set to load from shared memory | |||
| void set_shared_mem(bool state) override { share_model_mem = state; } | |||
| //! load model from dump file | |||
| void load_model() override; | |||
| //! run model with given runtime parameter | |||
| void run_model() override; | |||
| //! wait the end of asynchronous function execution | |||
| void wait() override; | |||
| //! get the network of lite model | |||
| std::shared_ptr<lite::Network> get_lite_network() { return m_network; } | |||
| //! get the config of lite model | |||
| lite::Config& get_config() { return config; } | |||
| //! get the networkIO of lite model | |||
| lite::NetworkIO& get_networkIO() { return IO; } | |||
| //! get the data parser | |||
| DataParser& get_input_parser() { return parser; } | |||
| //! set the strategy before load model | |||
| void set_lite_strategy(Strategy& u_strategy) { m_strategy = u_strategy; } | |||
| //! get algo strategy | |||
| Strategy& get_lite_strategy() { return m_strategy; } | |||
| private: | |||
| bool share_model_mem; | |||
| std::string model_path; | |||
| DataParser parser; | |||
| lite::Config config; | |||
| lite::NetworkIO IO; | |||
| std::shared_ptr<lite::Network> m_network; | |||
| Strategy m_strategy; | |||
| }; | |||
| } // namespace lar | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -0,0 +1,105 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/models/model_mdl.cpp | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #include "model_mdl.h" | |||
| #include <gflags/gflags.h> | |||
| #include <iostream> | |||
| DECLARE_bool(share_param_mem); | |||
| using namespace lar; | |||
| ModelMdl::ModelMdl(const std::string& path) : model_path(path) { | |||
| mgb_log_warn("create mdl model using XPU as the default comp node"); | |||
| m_load_config.comp_graph = mgb::ComputingGraph::make(); | |||
| m_load_config.comp_graph->options().graph_opt_level = 0; | |||
| testcase_num = 0; | |||
| } | |||
| void ModelMdl::load_model() { | |||
| //! read dump file | |||
| if (share_model_mem) { | |||
| mgb_log_warn("enable share model memory"); | |||
| FILE* fin = fopen(model_path.c_str(), "rb"); | |||
| mgb_assert(fin, "failed to open %s: %s", model_path.c_str(), strerror(errno)); | |||
| fseek(fin, 0, SEEK_END); | |||
| size_t size = ftell(fin); | |||
| fseek(fin, 0, SEEK_SET); | |||
| void* ptr = malloc(size); | |||
| std::shared_ptr<void> buf{ptr, free}; | |||
| auto nr = fread(buf.get(), 1, size, fin); | |||
| mgb_assert(nr == size, "read model file failed"); | |||
| fclose(fin); | |||
| m_model_file = mgb::serialization::InputFile::make_mem_proxy(buf, size); | |||
| } else { | |||
| m_model_file = mgb::serialization::InputFile::make_fs(model_path.c_str()); | |||
| } | |||
| //! get dump_with_testcase model testcase number | |||
| char magic[8]; | |||
| m_model_file->read(magic, sizeof(magic)); | |||
| if (strncmp(magic, "mgbtest0", 8)) { | |||
| m_model_file->rewind(); | |||
| } else { | |||
| m_model_file->read(&testcase_num, sizeof(testcase_num)); | |||
| } | |||
| auto format = | |||
| mgb::serialization::GraphLoader::identify_graph_dump_format(*m_model_file); | |||
| mgb_assert( | |||
| format.valid(), | |||
| "invalid format, please make sure model is dumped by GraphDumper"); | |||
| //! load computing graph of model | |||
| m_loader = mgb::serialization::GraphLoader::make( | |||
| std::move(m_model_file), format.val()); | |||
| m_load_result = m_loader->load(m_load_config, false); | |||
| m_load_config.comp_graph.reset(); | |||
| // get testcase input generated by dump_with_testcase.py | |||
| if (testcase_num) { | |||
| for (auto&& i : m_load_result.tensor_map) { | |||
| test_input_tensors.emplace_back(i.first, i.second.get()); | |||
| } | |||
| std::sort(test_input_tensors.begin(), test_input_tensors.end()); | |||
| } | |||
| // initialize output callback | |||
| for (size_t i = 0; i < m_load_result.output_var_list.size(); i++) { | |||
| mgb::ComputingGraph::Callback cb; | |||
| m_callbacks.push_back(cb); | |||
| } | |||
| } | |||
| void ModelMdl::make_output_spec() { | |||
| for (size_t i = 0; i < m_load_result.output_var_list.size(); i++) { | |||
| auto item = m_load_result.output_var_list[i]; | |||
| m_output_spec.emplace_back(item, std::move(m_callbacks[i])); | |||
| } | |||
| m_asyc_exec = m_load_result.graph_compile(m_output_spec); | |||
| } | |||
| std::shared_ptr<mgb::serialization::GraphLoader>& ModelMdl::reset_loader() { | |||
| m_loader = mgb::serialization::GraphLoader::make( | |||
| m_loader->reset_file(), m_loader->format()); | |||
| return m_loader; | |||
| } | |||
| void ModelMdl::run_model() { | |||
| mgb_assert( | |||
| m_asyc_exec != nullptr, | |||
| "empty asynchronous function to execute after graph compiled"); | |||
| m_asyc_exec->execute(); | |||
| } | |||
| void ModelMdl::wait() { | |||
| m_asyc_exec->wait(); | |||
| } | |||
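| The full ModelMdl life cycle as the strategies drive it (sketch; the path is illustrative): | |||
| #include "model_mdl.h" | |||
| void run_mdl(const std::string& path) { | |||
|     lar::ModelMdl model(path); | |||
|     model.set_shared_mem(false); | |||
|     model.load_model();       // reads magic and testcase count, then the graph | |||
|     model.make_output_spec(); // compiles the graph into an AsyncExecutable | |||
|     model.run_model(); | |||
|     model.wait(); | |||
| } | |||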
| @@ -0,0 +1,117 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/models/model_mdl.h | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #pragma once | |||
| #include <string> | |||
| #include "megbrain/opr/search_policy/algo_chooser_helper.h" | |||
| #include "megbrain/plugin/opr_io_dump.h" | |||
| #include "megbrain/serialization/extern_c_opr.h" | |||
| #include "megbrain/serialization/serializer.h" | |||
| #include "megbrain/utils/debug.h" | |||
| #include "megbrain/plugin/num_range_checker.h" | |||
| #include "megbrain/plugin/profiler.h" | |||
| #include "helpers/common.h" | |||
| #include "helpers/data_parser.h" | |||
| #include "model.h" | |||
| namespace lar { | |||
| class ModelMdl : public ModelBase { | |||
| public: | |||
| using Strategy = mgb::opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy; | |||
| //! interface implement of ModelBase | |||
| ModelMdl(const std::string& path); | |||
| ModelType type() override { return ModelType::MEGDL_MODEL; } | |||
| void set_shared_mem(bool state) override { share_model_mem = state; } | |||
| void load_model() override; | |||
| void make_output_spec(); | |||
| void run_model() override; | |||
| void wait() override; | |||
| //! get load result for megDL model | |||
| mgb::serialization::GraphLoader::LoadResult& get_mdl_load_result() { | |||
| return m_load_result; | |||
| } | |||
| //! get load config for megDL model | |||
| mgb::serialization::GraphLoadConfig& get_mdl_config() { return m_load_config; } | |||
| //! reset the graph loader for dump_with_testcase model | |||
| std::shared_ptr<mgb::serialization::GraphLoader>& reset_loader(); | |||
| //! algo strategy for running the model | |||
| void set_mdl_strategy(Strategy& u_strategy) { m_strategy = u_strategy; } | |||
| Strategy& get_mdl_strategy() { return m_strategy; } | |||
| //! get data parser | |||
| DataParser& get_input_parser() { return parser; } | |||
| uint32_t get_testcase_num() { return testcase_num; } | |||
| std::vector<std::pair<std::string, mgb::HostTensorND*>>& get_test_input() { | |||
| return test_input_tensors; | |||
| } | |||
| //! get output specified configuration | |||
| mgb::ComputingGraph::OutputSpec& get_output_spec() { return m_output_spec; } | |||
| std::unique_ptr<mgb::cg::AsyncExecutable>& get_async_func() { return m_asyc_exec; } | |||
| void set_output_callback(std::vector<mgb::ComputingGraph::Callback>& cb) { | |||
| mgb_assert( | |||
| m_callbacks.size() == cb.size(), | |||
| "invalid output callback list to set!"); | |||
| for (size_t i = 0; i < cb.size(); i++) { | |||
| m_callbacks[i] = cb[i]; | |||
| } | |||
| } | |||
| #if MGB_ENABLE_JSON | |||
| std::unique_ptr<mgb::GraphProfiler>& get_profiler() { return m_profiler; } | |||
| void set_profiler() { | |||
| m_profiler = | |||
| std::make_unique<mgb::GraphProfiler>(m_load_config.comp_graph.get()); | |||
| } | |||
| #endif | |||
| void set_num_range_checker(float range) { | |||
| m_num_range_checker = std::make_unique<mgb::NumRangeChecker>( | |||
| m_load_config.comp_graph.get(), range); | |||
| } | |||
| private: | |||
| bool share_model_mem; | |||
| std::string model_path; | |||
| std::unique_ptr<mgb::serialization::InputFile> m_model_file; | |||
| mgb::serialization::GraphLoadConfig m_load_config; | |||
| mgb::serialization::GraphLoader::LoadResult m_load_result; | |||
| std::shared_ptr<mgb::serialization::GraphLoader> m_loader; | |||
| std::unique_ptr<mgb::cg::AsyncExecutable> m_asyc_exec; | |||
| uint32_t testcase_num; | |||
| std::vector<std::pair<std::string, mgb::HostTensorND*>> test_input_tensors; | |||
| DataParser parser; | |||
| Strategy m_strategy = Strategy::HEURISTIC; | |||
| std::vector<mgb::ComputingGraph::Callback> m_callbacks; | |||
| mgb::ComputingGraph::OutputSpec m_output_spec; | |||
| std::unique_ptr<mgb::NumRangeChecker> m_num_range_checker; | |||
| #if MGB_ENABLE_JSON | |||
| std::unique_ptr<mgb::GraphProfiler> m_profiler; | |||
| #endif | |||
| }; | |||
| } // namespace lar | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -0,0 +1,200 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/options/device_options.cpp | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #include <iostream> | |||
| #include <sstream> | |||
| #include "lite/global.h" | |||
| #include "megbrain/comp_node_env.h" | |||
| #include "misc.h" | |||
| #include "device_options.h" | |||
| #include "models/model_lite.h" | |||
| #include "models/model_mdl.h" | |||
| DECLARE_bool(weight_preprocess); | |||
| using namespace lar; | |||
| /////////////////// XPUDeviceOption ////////////////////// | |||
| namespace lar { | |||
| template <> | |||
| void XPUDeviceOption::config_model_internel<ModelLite>( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) { | |||
| if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
| if ((enable_cpu) || (enable_cpu_default) || (enable_multithread) || | |||
| (enable_multithread_default)) { | |||
| LITE_WARN("using cpu device\n"); | |||
| model->get_config().device_type = LiteDeviceType::LITE_CPU; | |||
| } | |||
| #if MGE_WITH_CUDA | |||
| if (enable_cuda) { | |||
| model->get_config().device_type = LiteDeviceType::LITE_CUDA; | |||
| } | |||
| #endif | |||
| } else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { | |||
| auto network = model->get_lite_network(); | |||
| if (enable_cpu_default) { | |||
| LITE_WARN("using cpu default device\n"); | |||
| lite::Runtime::set_cpu_inplace_mode(network); | |||
| } | |||
| if (enable_multithread) { | |||
| LITE_WARN("using multithread device\n"); | |||
| lite::Runtime::set_cpu_threads_number(network, thread_num); | |||
| } | |||
| if (enable_multithread_default) { | |||
| LITE_WARN("using multithread default device\n"); | |||
| lite::Runtime::set_cpu_inplace_mode(network); | |||
| lite::Runtime::set_cpu_threads_number(network, thread_num); | |||
| } | |||
| if (enable_set_core_ids) { | |||
| std::string core_str; | |||
| for (auto id : core_ids) { | |||
| core_str += std::to_string(id) + ","; | |||
| } | |||
| LITE_WARN("multi thread core ids: %s\n", core_str.c_str()); | |||
| lite::ThreadAffinityCallback affinity_callback = [&](size_t thread_id) { | |||
| mgb::sys::set_cpu_affinity({core_ids[thread_id]}); | |||
| }; | |||
| lite::Runtime::set_runtime_thread_affinity(network, affinity_callback); | |||
| } | |||
| } | |||
| } | |||
| template <> | |||
| void XPUDeviceOption::config_model_internel<ModelMdl>( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) { | |||
| if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
| if (enable_cpu) { | |||
| mgb_log_warn("using cpu device\n"); | |||
| model->get_mdl_config().comp_node_mapper = [](mgb::CompNode::Locator& loc) { | |||
| loc.type = mgb::CompNode::DeviceType::CPU; | |||
| }; | |||
| } | |||
| #if MGE_WITH_CUDA | |||
| if (enable_cuda) { | |||
| mgb_log_warn("using cuda device\n"); | |||
| model->get_mdl_config().comp_node_mapper = [](mgb::CompNode::Locator& loc) { | |||
| loc.type = mgb::CompNode::DeviceType::CUDA; | |||
| }; | |||
| } | |||
| #endif | |||
| if (enable_cpu_default) { | |||
| mgb_log_warn("using cpu default device\n"); | |||
| model->get_mdl_config().comp_node_mapper = [](mgb::CompNode::Locator& loc) { | |||
| loc.type = mgb::CompNode::DeviceType::CPU; | |||
| loc.device = mgb::CompNode::Locator::DEVICE_CPU_DEFAULT; | |||
| }; | |||
| } | |||
| if (enable_multithread) { | |||
| mgb_log_warn("using multithread device\n"); | |||
| model->get_mdl_config().comp_node_mapper = | |||
| [&](mgb::CompNode::Locator& loc) { | |||
| loc.type = mgb::CompNode::DeviceType::MULTITHREAD; | |||
| loc.device = 0; | |||
| loc.stream = thread_num; | |||
| }; | |||
| } | |||
| if (enable_multithread_default) { | |||
| mgb_log_warn("using multithread default device\n"); | |||
| model->get_mdl_config().comp_node_mapper = | |||
| [&](mgb::CompNode::Locator& loc) { | |||
| loc.type = mgb::CompNode::DeviceType::MULTITHREAD; | |||
| loc.device = mgb::CompNode::Locator::DEVICE_MULTITHREAD_DEFAULT; | |||
| loc.stream = thread_num; | |||
| }; | |||
| } | |||
| if (enable_set_core_ids) { | |||
| std::string core_str; | |||
| for (auto id : core_ids) { | |||
| core_str += std::to_string(id) + ","; | |||
| } | |||
| mgb_log_warn("set multi thread core ids:%s\n", core_str.c_str()); | |||
| auto affinity_callback = [&](size_t thread_id) { | |||
| mgb::sys::set_cpu_affinity({core_ids[thread_id]}); | |||
| }; | |||
| mgb::CompNode::Locator loc; | |||
| model->get_mdl_config().comp_node_mapper(loc); | |||
| auto comp_node = mgb::CompNode::load(loc); | |||
| mgb::CompNodeEnv::from_comp_node(comp_node).cpu_env().set_affinity( | |||
| affinity_callback); | |||
| } | |||
| } | |||
| } | |||
| } // namespace lar | |||
| XPUDeviceOption::XPUDeviceOption() { | |||
| m_option_name = "xpu_device"; | |||
| enable_cpu = FLAGS_cpu; | |||
| #if MGE_WITH_CUDA | |||
| enable_cuda = FLAGS_cuda; | |||
| #endif | |||
| enable_cpu_default = FLAGS_cpu_default; | |||
| if (FLAGS_multithread >= 0) { | |||
| thread_num = FLAGS_multithread; | |||
| enable_multithread = true; | |||
| } | |||
| if (FLAGS_multithread_default >= 0) { | |||
| thread_num = FLAGS_multithread_default; | |||
| enable_multithread_default = true; | |||
| } | |||
| if (!FLAGS_multi_thread_core_ids.empty()) { | |||
| mgb_assert(enable_multithread, "core ids should be set after --multithread"); | |||
| std::stringstream id_stream(FLAGS_multi_thread_core_ids); | |||
| std::string id; | |||
| size_t thread_cnt = 0; | |||
| while (getline(id_stream, id, ',')) { | |||
| thread_cnt++; | |||
| core_ids.push_back(atoi(id.c_str())); | |||
| } | |||
| mgb_assert( | |||
| thread_cnt == thread_num, | |||
| "core ids number should be same with thread number set before"); | |||
| enable_set_core_ids = true; | |||
| } | |||
| } | |||
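| // Illustrative example (values not from the source): "--multithread 4 | |||
| // --multi_thread_core_ids 0,1,2,3" parses into thread_num = 4 and | |||
| // core_ids = {0, 1, 2, 3}; the asserts above reject a missing | |||
| // --multithread or a core id count that differs from thread_num. | |||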
| bool XPUDeviceOption::is_valid() { | |||
| bool ret = FLAGS_cpu || FLAGS_cpu_default; | |||
| #if MGE_WITH_CUDA | |||
| ret = ret || FLAGS_cuda; | |||
| #endif | |||
| ret = ret || FLAGS_multithread >= 0; | |||
| ret = ret || FLAGS_multithread_default >= 0; | |||
| ret = ret || !FLAGS_multi_thread_core_ids.empty(); | |||
| return ret; | |||
| } | |||
| std::shared_ptr<OptionBase> XPUDeviceOption::create_option() { | |||
| static std::shared_ptr<lar::XPUDeviceOption> option(new XPUDeviceOption); | |||
| if (XPUDeviceOption::is_valid()) { | |||
| return std::static_pointer_cast<lar::OptionBase>(option); | |||
| } else { | |||
| return nullptr; | |||
| } | |||
| } | |||
| void XPUDeviceOption::config_model( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | |||
| CONFIG_MODEL_FUN; | |||
| } | |||
| ///////////////////////// xpu gflags //////////////////////////// | |||
| DEFINE_bool(cpu, false, "set CPU device as running device"); | |||
| #if MGE_WITH_CUDA | |||
| DEFINE_bool(cuda, false, "set CUDA device as running device "); | |||
| #endif | |||
| DEFINE_bool(cpu_default, false, "set running device as CPU device with inplace mode"); | |||
| DEFINE_int32(multithread, -1, "set multithread device as running device"); | |||
| DEFINE_int32( | |||
| multithread_default, -1, | |||
| "set multithread device as running device with inplace mode"); | |||
| DEFINE_string(multi_thread_core_ids, "", "set multithread core id"); | |||
| REGIST_OPTION_CREATOR(xpu_device, lar::XPUDeviceOption::create_option); | |||
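| // Hypothetical invocations exercising the flags defined above (model | |||
| // file name is illustrative): | |||
| //   load_and_run model.mge --cpu | |||
| //   load_and_run model.mge --multithread 4 --multi_thread_core_ids 0,1,2,3 | |||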
| @@ -0,0 +1,49 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/options/device_options.h | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #pragma once | |||
| #include <gflags/gflags.h> | |||
| #include "models/model.h" | |||
| #include "option_base.h" | |||
| DECLARE_bool(cpu); | |||
| #if MGE_WITH_CUDA | |||
| DECLARE_bool(cuda); | |||
| #endif | |||
| DECLARE_bool(cpu_default); | |||
| DECLARE_int32(multithread); | |||
| DECLARE_int32(multithread_default); | |||
| DECLARE_string(multi_thread_core_ids); | |||
| namespace lar { | |||
| class XPUDeviceOption final : public OptionBase { | |||
| public: | |||
| static bool is_valid(); | |||
| static std::shared_ptr<OptionBase> create_option(); | |||
| void config_model( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override; | |||
| std::string option_name() const override { return m_option_name; }; | |||
| private: | |||
| XPUDeviceOption(); | |||
| template <typename ModelImpl> | |||
| void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | |||
| bool enable_cpu; | |||
| #if MGE_WITH_CUDA | |||
| bool enable_cuda; | |||
| #endif | |||
| bool enable_cpu_default; | |||
| bool enable_multithread; | |||
| bool enable_multithread_default; | |||
| bool enable_set_core_ids; | |||
| size_t thread_num; | |||
| std::vector<int> core_ids; | |||
| std::string m_option_name; | |||
| }; | |||
| } // namespace lar | |||
| @@ -0,0 +1,216 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/options/extern_c_opr_options.cpp | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #include "extern_c_opr_options.h" | |||
| #include "megbrain/utils/debug.h" | |||
| #include "misc.h" | |||
| #include "models/model_lite.h" | |||
| #include "models/model_mdl.h" | |||
| namespace lar { | |||
| template <> | |||
| void COprLibOption::config_model_internel( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) { | |||
| MGB_MARK_USED_VAR(model); | |||
| if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
| if (!lib_path.empty()) { | |||
| lite::set_loader_lib_path(lib_path); | |||
| } | |||
| if (c_opr_args.is_run_c_opr_with_param) { | |||
| LITE_THROW( | |||
| "lite model dont't support run with external c opr " | |||
| "parmeter"); | |||
| } | |||
| } | |||
| } | |||
| template <> | |||
| void COprLibOption::config_model_internel( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) { | |||
| if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
| if (!lib_path.empty()) { | |||
| load_lib(); | |||
| } | |||
| if (c_opr_args.is_run_c_opr_with_param) { | |||
| mgb_assert( | |||
| c_opr_args.is_run_c_opr && | |||
| c_opr_args.copr_param_device_ptr_malloc && | |||
| c_opr_args.copr_param_device_ptr_free && | |||
| c_opr_args.copr_param_device_ptr_h2d, | |||
| "--c-opr-lib-with-param need config with --c-opr-lib, also " | |||
| "extern c opr loader need implemente " | |||
| "copr_param_device_ptr_malloc, copr_param_device_ptr_free " | |||
| "and copr_param_device_ptr_h2d symbols"); | |||
| } | |||
| } else if (runtime_param.stage == RunStage::MODEL_RUNNING) { | |||
| if (model->get_testcase_num() && c_opr_args.is_run_c_opr_with_param) { | |||
| init_extern_param(model); | |||
| set_Copr_IO(model); | |||
| } | |||
| } else if (runtime_param.stage == RunStage::AFTER_RUNNING_ITER) { | |||
| if (model->get_testcase_num() && c_opr_args.is_run_c_opr_with_param) { | |||
| c_opr_args.copr_param_device_ptr_free(c_opr_param.get()); | |||
| free(c_opr_param->input); | |||
| } | |||
| } | |||
| } | |||
| } // namespace lar | |||
| using namespace lar; | |||
| MGBDType COprLibOption::dtype_cpp2c(megdnn::DType dtype) { | |||
| switch (dtype.enumv()) { | |||
| case megdnn::DTypeEnum::Float32: | |||
| return MGB_DTYPE_FLOAT32; | |||
| case megdnn::DTypeEnum::Int32: | |||
| return MGB_DTYPE_INT32; | |||
| case megdnn::DTypeEnum::Int16: | |||
| return MGB_DTYPE_INT16; | |||
| case megdnn::DTypeEnum::Uint8: | |||
| return MGB_DTYPE_UINT8; | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| case megdnn::DTypeEnum::Float16: | |||
| return MGB_DTYPE_FLOAT16; | |||
| #endif | |||
| default: | |||
| mgb_throw( | |||
| mgb::InternalError, "unsupported dtype for extern C API: %s", | |||
| dtype.name()); | |||
| } | |||
| } | |||
| void COprLibOption::tensor_shape_to_c( | |||
| const megdnn::TensorShape& shape, MGBTensorShape& mgb_shape) { | |||
| mgb_assert( | |||
| shape.ndim <= MGB_TENSOR_MAX_NDIM, "shape ndim too large: %zu", shape.ndim); | |||
| mgb_shape.ndim = shape.ndim; | |||
| for (size_t i = 0; i < shape.ndim; ++i) { | |||
| mgb_shape.shape[i] = shape[i]; | |||
| } | |||
| } | |||
| void COprLibOption::init_extern_param(std::shared_ptr<ModelBase> model_ptr) { | |||
| auto model = std::static_pointer_cast<ModelMdl>(model_ptr); | |||
| auto inp_tensors = model->get_test_input(); | |||
| c_opr_param = std::make_shared<ExternCOprParam>(); | |||
| memset(c_opr_param.get(), 0, sizeof(ExternCOprParam)); | |||
| //! only the input is tested in the npu case, not the output, so | |||
| //! only the input shapes and dtypes are initialized here | |||
| c_opr_param->nr_input = inp_tensors.size(); | |||
| c_opr_param->input = (ExternDeviceTensor*)malloc( | |||
| sizeof(ExternDeviceTensor) * inp_tensors.size()); | |||
| memset(c_opr_param->input, 0, sizeof(ExternDeviceTensor) * inp_tensors.size()); | |||
| //! init input ExternDeviceTensor shape and dtype | |||
| for (size_t input_idx = 0; input_idx < inp_tensors.size(); input_idx++) { | |||
| auto& mgb_tensor_layout = c_opr_param->input[input_idx].layout; | |||
| auto host_tensor_nd_p = inp_tensors[input_idx].second; | |||
| mgb_tensor_layout.dtype = dtype_cpp2c(host_tensor_nd_p->dtype()); | |||
| tensor_shape_to_c( | |||
| inp_tensors[input_idx].second->shape(), mgb_tensor_layout.shape); | |||
| } | |||
| c_opr_param->nr_output = 0; | |||
| //! now call copr_param_device_ptr_malloc to malloc | |||
| //! device_ptr | |||
| c_opr_args.copr_param_device_ptr_malloc(c_opr_param.get()); | |||
| } | |||
| void COprLibOption::load_lib() { | |||
| auto handle = dlopen(lib_path.c_str(), RTLD_LAZY); | |||
| mgb_assert(handle, "failed to open c opr lib %s: %s", lib_path.c_str(), dlerror()); | |||
| const char* entry = MGB_C_OPR_INIT_FUNC_STR; | |||
| auto func = dlsym(handle, entry); | |||
| mgb_assert(func, "can not resolve %s: %s", entry, dlerror()); | |||
| typedef void (*entry_f_t)(void*); | |||
| reinterpret_cast<entry_f_t>(func)( | |||
| reinterpret_cast<void*>(&mgb_get_extern_c_opr_api_versioned)); | |||
| printf("loaded C opr library: %s\n", lib_path.c_str()); | |||
| entry = "copr_param_device_ptr_malloc"; | |||
| func = dlsym(handle, entry); | |||
| if (func) { | |||
| printf("get %s from: %s\n", entry, lib_path.c_str()); | |||
| c_opr_args.copr_param_device_ptr_malloc = | |||
| reinterpret_cast<COprArgs::COPR_PARAM_DEVICE_PTR_MEM_T>(func); | |||
| } | |||
| entry = "copr_param_device_ptr_free"; | |||
| func = dlsym(handle, entry); | |||
| if (func) { | |||
| printf("get %s from: %s\n", entry, lib_path.c_str()); | |||
| c_opr_args.copr_param_device_ptr_free = | |||
| reinterpret_cast<COprArgs::COPR_PARAM_DEVICE_PTR_MEM_T>(func); | |||
| } | |||
| entry = "copr_param_device_ptr_h2d"; | |||
| func = dlsym(handle, entry); | |||
| if (func) { | |||
| printf("get %s from: %s\n", entry, lib_path.c_str()); | |||
| c_opr_args.copr_param_device_ptr_h2d = | |||
| reinterpret_cast<COprArgs::COPR_PARAM_DEVICE_PTR_H2D_T>(func); | |||
| } | |||
| } | |||
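| // Note: dlsym failures for the three copr_param_device_ptr_* symbols | |||
| // are tolerated here (the pointers stay nullptr); the assert in | |||
| // config_model_internel<ModelMdl> only demands them together with | |||
| // --c-opr-lib-with-param. | |||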
| void COprLibOption::set_Copr_IO(std::shared_ptr<ModelBase> model_ptr) { | |||
| auto model = std::static_pointer_cast<ModelMdl>(model_ptr); | |||
| auto inp_tensors = model->get_test_input(); | |||
| auto loader = model->reset_loader(); | |||
| auto testcase = loader->load(model->get_mdl_config(), false); | |||
| mgb_assert(testcase.output_var_list.size() == inp_tensors.size()); | |||
| for (size_t i = 0; i < inp_tensors.size(); ++i) { | |||
| auto&& opr = testcase.output_var_list[i] | |||
| .node() | |||
| ->owner_opr() | |||
| ->cast_final_safe<mgb::opr::SharedDeviceTensor>(); | |||
| c_opr_args.copr_param_device_ptr_h2d( | |||
| c_opr_param.get(), opr.dev_data()->raw_ptr(), i); | |||
| } | |||
| //! now config c opr dynamic param | |||
| config_extern_c_opr_dynamic_param(model->get_async_func(), c_opr_param); | |||
| } | |||
| COprLibOption::COprLibOption() { | |||
| m_option_name = "c_opr_lib"; | |||
| lib_path = FLAGS_c_opr_lib; | |||
| c_opr_args.is_run_c_opr = !lib_path.empty(); | |||
| c_opr_args.is_run_c_opr_with_param = FLAGS_c_opr_lib_with_param; | |||
| } | |||
| bool COprLibOption::is_valid() { | |||
| return !FLAGS_c_opr_lib.empty() || FLAGS_c_opr_lib_with_param; | |||
| } | |||
| std::shared_ptr<OptionBase> COprLibOption::create_option() { | |||
| static std::shared_ptr<COprLibOption> option(new COprLibOption); | |||
| if (COprLibOption::is_valid()) { | |||
| return std::static_pointer_cast<OptionBase>(option); | |||
| } else { | |||
| return nullptr; | |||
| } | |||
| } | |||
| void COprLibOption::config_model( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | |||
| CONFIG_MODEL_FUN; | |||
| } | |||
| DEFINE_string( | |||
| c_opr_lib, "", | |||
| "Load external operator library. It must implement " | |||
| "MGB_C_OPR_INIT_FUNC_STR as the entry point"); | |||
| DEFINE_bool( | |||
| c_opr_lib_with_param, false, | |||
| "Run c opr lib with param, use to benchmark speed and check result, " | |||
| "need c opr loader implemente `copr_param_device_ptr_malloc, " | |||
| "copr_param_device_ptr_free and copr_param_device_ptr_h2d' symbols"); | |||
| REGIST_OPTION_CREATOR(c_opr_lib, lar::COprLibOption::create_option); | |||
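| // Hypothetical invocation (library path and model name illustrative): | |||
| //   load_and_run model.mge --c-opr-lib ./libcustom_opr.so --c-opr-lib-with-param | |||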
| @@ -0,0 +1,64 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/options/extern_c_opr_options.h | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #pragma once | |||
| #include <gflags/gflags.h> | |||
| #include "megbrain/graph/extern_copr_api.h" | |||
| #include "models/model.h" | |||
| #include "option_base.h" | |||
| DECLARE_bool(c_opr_lib_with_param); | |||
| DECLARE_string(c_opr_lib); | |||
| namespace lar { | |||
| struct COprArgs { | |||
| //! for run c opr | |||
| bool is_run_c_opr = false; | |||
| bool is_run_c_opr_with_param = false; | |||
| typedef void (*COPR_PARAM_DEVICE_PTR_MEM_T)(ExternCOprParam* param); | |||
| typedef void (*COPR_PARAM_DEVICE_PTR_H2D_T)( | |||
| ExternCOprParam* param, void* host_ptr, size_t extern_device_tensor_id); | |||
| COPR_PARAM_DEVICE_PTR_MEM_T copr_param_device_ptr_malloc = nullptr; | |||
| COPR_PARAM_DEVICE_PTR_MEM_T copr_param_device_ptr_free = nullptr; | |||
| COPR_PARAM_DEVICE_PTR_H2D_T copr_param_device_ptr_h2d = nullptr; | |||
| }; | |||
| class COprLibOption final : public OptionBase { | |||
| public: | |||
| static bool is_valid(); | |||
| static std::shared_ptr<OptionBase> create_option(); | |||
| void config_model( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override; | |||
| std::string option_name() const override { return m_option_name; }; | |||
| private: | |||
| COprLibOption(); | |||
| template <typename ModelImpl> | |||
| void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | |||
| void load_lib(); | |||
| MGBDType dtype_cpp2c(megdnn::DType dtype); | |||
| void tensor_shape_to_c(const megdnn::TensorShape& shape, MGBTensorShape& mgb_shape); | |||
| void init_extern_param(std::shared_ptr<ModelBase> model); | |||
| void set_Copr_IO(std::shared_ptr<ModelBase> model); | |||
| std::string m_option_name; | |||
| COprArgs c_opr_args; | |||
| std::string lib_path; | |||
| std::shared_ptr<ExternCOprParam> c_opr_param; | |||
| }; | |||
| } // namespace lar | |||
| @@ -0,0 +1,231 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/options/fastrun_options.cpp | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #include <gflags/gflags.h> | |||
| #if defined(_WIN32) | |||
| #include <io.h> | |||
| #define F_OK 0 | |||
| #define access(a, b) _access(a, b) | |||
| #elif __linux__ || __unix__ || __APPLE__ | |||
| #include <unistd.h> | |||
| #endif | |||
| #include "fastrun_options.h" | |||
| #include "megbrain/gopt/inference.h" | |||
| #include "megbrain/utils/infile_persistent_cache.h" | |||
| #include "misc.h" | |||
| #include "models/model_lite.h" | |||
| #include "models/model_mdl.h" | |||
| namespace lar { | |||
| template <> | |||
| void FastRunOption::config_model_internel<ModelLite>( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) { | |||
| if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
| //! set the algo policy before model load | |||
| using Strategy = ModelLite::Strategy; | |||
| uint32_t strategy = 0; | |||
| #if MGB_ENABLE_FASTRUN | |||
| if (enable_full_run) { | |||
| LITE_WARN("enable full-run strategy for algo profile"); | |||
| strategy = static_cast<uint32_t>(Strategy::LITE_ALGO_PROFILE) | strategy; | |||
| } else if (enable_fast_run) { | |||
| LITE_WARN("enable fast-run strategy for algo profile"); | |||
| strategy = static_cast<uint32_t>(Strategy::LITE_ALGO_PROFILE) | | |||
| static_cast<uint32_t>(Strategy::LITE_ALGO_OPTIMIZED) | strategy; | |||
| } else { | |||
| strategy = static_cast<uint32_t>(Strategy::LITE_ALGO_HEURISTIC) | strategy; | |||
| } | |||
| #else | |||
| strategy = static_cast<uint32_t>(Strategy::LITE_ALGO_HEURISTIC) | strategy; | |||
| #endif | |||
| if (batch_binary_equal || enable_reproducible) { | |||
| LITE_WARN("enable reproducible strategy for algo profile"); | |||
| if (batch_binary_equal) | |||
| strategy = static_cast<uint32_t>(Strategy::LITE_ALGO_REPRODUCIBLE) | | |||
| strategy; | |||
| } | |||
| auto lite_strategy = static_cast<Strategy>(strategy); | |||
| model->set_lite_strategy(lite_strategy); | |||
| } else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { | |||
| auto lite_network = model->get_lite_network(); | |||
| auto lite_strategy = model->get_lite_strategy(); | |||
| //! set algo policy for model | |||
| lite::Runtime::set_network_algo_policy( | |||
| lite_network, lite_strategy, share_batch_size, batch_binary_equal); | |||
| if (!m_fast_run_cache.empty()) { | |||
| if (!access(m_fast_run_cache.c_str(), F_OK)) { | |||
| lite::set_persistent_cache(m_fast_run_cache); | |||
| } else { | |||
| lite::set_persistent_cache(m_fast_run_cache, true); | |||
| } | |||
| //! TODO: this comes from the mdl model settings but has no | |||
| //! matching setting in the lite model yet | |||
| // if (!enable_full_run && !enable_fast_run) | |||
| // mgb::gopt::enable_opr_use_profiling_cache_inplace(vars); | |||
| } | |||
| } else if (runtime_param.stage == RunStage::AFTER_MODEL_RUNNING) { | |||
| #if MGB_ENABLE_FASTRUN | |||
| //! dump algo cache | |||
| if (!m_fast_run_cache.empty()) { | |||
| lite::dump_persistent_cache(m_fast_run_cache); | |||
| } | |||
| #endif | |||
| } | |||
| } | |||
| template <> | |||
| void FastRunOption::config_model_internel<ModelMdl>( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) { | |||
| if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
| //! set the algo policy before model load | |||
| using Strategy = ModelMdl::Strategy; | |||
| auto strategy = static_cast<Strategy>(0); | |||
| #if MGB_ENABLE_FASTRUN | |||
| if (enable_full_run) { | |||
| mgb_log_warn("enable full-run strategy for algo profile"); | |||
| strategy = Strategy::PROFILE | strategy; | |||
| } else if (enable_fast_run) { | |||
| mgb_log_warn("enable fast-run strategy for algo profile"); | |||
| strategy = Strategy::PROFILE | Strategy::OPTIMIZED | strategy; | |||
| } else { | |||
| strategy = Strategy::HEURISTIC | strategy; | |||
| } | |||
| #else | |||
| strategy = Strategy::HEURISTIC | strategy; | |||
| #endif | |||
| if (batch_binary_equal || enable_reproducible) { | |||
| mgb_log_warn("enable reproducible strategy for algo profile"); | |||
| strategy = Strategy::REPRODUCIBLE | strategy; | |||
| } | |||
| model->set_mdl_strategy(strategy); | |||
| //! set binary_equal_between_batch and shared_batch_size | |||
| if (batch_binary_equal) { | |||
| mgb_log_warn("enable batch binary equal"); | |||
| model->get_mdl_config() | |||
| .comp_graph->options() | |||
| .fast_run_config.binary_equal_between_batch = true; | |||
| } | |||
| if (share_batch_size > 0) { | |||
| mgb_log_warn("set shared shared batch"); | |||
| model->get_mdl_config() | |||
| .comp_graph->options() | |||
| .fast_run_config.shared_batch_size = share_batch_size; | |||
| } | |||
| } else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { | |||
| auto vars = model->get_mdl_load_result().output_var_list; | |||
| auto strategy = model->get_mdl_strategy(); | |||
| mgb::gopt::modify_opr_algo_strategy_inplace(vars, strategy); | |||
| // set algo cache path | |||
| if (!m_fast_run_cache.empty()) { | |||
| if (!access(m_fast_run_cache.c_str(), F_OK)) { | |||
| mgb::PersistentCache::set_impl( | |||
| std::make_shared<mgb::InFilePersistentCache>( | |||
| m_fast_run_cache.c_str())); | |||
| } else { | |||
| mgb::PersistentCache::set_impl( | |||
| std::make_shared<mgb::InFilePersistentCache>()); | |||
| } | |||
| #if MGB_ENABLE_FASTRUN | |||
| if (!enable_full_run && !enable_fast_run) | |||
| #endif | |||
| mgb::gopt::enable_opr_use_profiling_cache_inplace(vars); | |||
| } | |||
| } else if (runtime_param.stage == RunStage::AFTER_MODEL_RUNNING) { | |||
| #if MGB_ENABLE_FASTRUN | |||
| //! dump algo cache | |||
| if (!m_fast_run_cache.empty()) { | |||
| static_cast<mgb::InFilePersistentCache&>(mgb::PersistentCache::inst()) | |||
| .dump_cache(m_fast_run_cache.c_str()); | |||
| } | |||
| #endif | |||
| } | |||
| } | |||
| } // namespace lar | |||
| using namespace lar; | |||
| FastRunOption::FastRunOption() { | |||
| m_option_name = "fastrun"; | |||
| #if MGB_ENABLE_FASTRUN | |||
| enable_fast_run = FLAGS_fast_run; | |||
| enable_full_run = FLAGS_full_run; | |||
| #endif | |||
| batch_binary_equal = FLAGS_binary_equal_between_batch; | |||
| enable_reproducible = FLAGS_reproducible; | |||
| m_fast_run_cache = FLAGS_fast_run_algo_policy; | |||
| share_batch_size = FLAGS_fast_run_shared_batch_size; | |||
| #if MGB_ENABLE_FASTRUN | |||
| //! when the fastrun cache file path is not empty but can't be accessed | |||
| if (!m_fast_run_cache.empty() && access(m_fast_run_cache.c_str(), F_OK)) { | |||
| mgb_assert( | |||
| enable_full_run || enable_fast_run, | |||
| "--fast-run or --full-run should be enabled"); | |||
| } | |||
| if (share_batch_size) { | |||
| mgb_assert( | |||
| enable_full_run || enable_fast_run || !m_fast_run_cache.empty(), | |||
| "--fast-run-shared-batch-size should be used with " | |||
| "--fast-run|--full-run|--fast-run-algo-policy"); | |||
| } | |||
| #endif | |||
| } | |||
| bool FastRunOption::is_valid() { | |||
| bool ret = false; | |||
| #if MGB_ENABLE_FASTRUN | |||
| ret = ret || FLAGS_fast_run; | |||
| ret = ret || FLAGS_full_run; | |||
| #endif | |||
| ret = ret || FLAGS_binary_equal_between_batch; | |||
| ret = ret || FLAGS_fast_run_shared_batch_size > 0; | |||
| ret = ret || FLAGS_reproducible; | |||
| ret = ret || FLAGS_fast_run_algo_policy.size() > 0; | |||
| return ret; | |||
| } | |||
| std::shared_ptr<OptionBase> FastRunOption::create_option() { | |||
| static std::shared_ptr<FastRunOption> option(new FastRunOption); | |||
| if (FastRunOption::is_valid()) { | |||
| return std::static_pointer_cast<OptionBase>(option); | |||
| } else { | |||
| return nullptr; | |||
| } | |||
| } | |||
| void FastRunOption::config_model( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | |||
| CONFIG_MODEL_FUN; | |||
| } | |||
| #if MGB_ENABLE_FASTRUN | |||
| DEFINE_bool(fast_run, false, "whether to use fast-run in model run"); | |||
| DEFINE_bool(full_run, false, "whether to use full-run in model run"); | |||
| #endif | |||
| DEFINE_bool( | |||
| binary_equal_between_batch, false, | |||
| "Each batch of output is promised binary equal if each batch of " | |||
| "input is binary equal\n Note that if this option is turned on, " | |||
| "`--reproducible` will also be turned on."); | |||
| DEFINE_bool( | |||
| reproducible, false, | |||
| "Enable choose algo which is reproducible. It mainly used for " | |||
| "cudnn algos.See " | |||
| "https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/" | |||
| "index.html#reproducibility" | |||
| "for more details."); | |||
| DEFINE_uint32(fast_run_shared_batch_size, 0, "Set the batch size used during fastrun"); | |||
| DEFINE_string(fast_run_algo_policy, "", "fast-run cache path."); | |||
| REGIST_OPTION_CREATOR(fastrun, lar::FastRunOption::create_option); | |||
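| // Hypothetical two-step usage (cache file name illustrative): profile and | |||
| // dump the cache once, then later runs load it via the access() branch above: | |||
| //   load_and_run model.mge --full-run --fast-run-algo-policy algo.cache | |||
| //   load_and_run model.mge --fast-run-algo-policy algo.cache | |||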
| @@ -0,0 +1,57 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/options/fastrun_options.h | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #pragma once | |||
| #include <gflags/gflags.h> | |||
| #include "models/model.h" | |||
| #include "option_base.h" | |||
| #if MGB_ENABLE_FASTRUN | |||
| DECLARE_bool(fast_run); | |||
| DECLARE_bool(full_run); | |||
| #endif | |||
| DECLARE_bool(reproducible); | |||
| DECLARE_bool(binary_equal_between_batch); | |||
| DECLARE_uint32(fast_run_shared_batch_size); | |||
| DECLARE_string(fast_run_algo_policy); | |||
| namespace lar { | |||
| class FastRunOption final : public OptionBase { | |||
| public: | |||
| //! get the condition for constructing FastRunOption | |||
| static bool is_valid(); | |||
| //! create the option from cmdline args when the condition is met | |||
| static std::shared_ptr<OptionBase> create_option(); | |||
| //! configure model for different runtime_param | |||
| void config_model( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override; | |||
| //! get option name for quick search | |||
| std::string option_name() const override { return m_option_name; } | |||
| private: | |||
| FastRunOption(); | |||
| //! config template for different model | |||
| template <typename ModelImpl> | |||
| void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>) {} | |||
| #if MGB_ENABLE_FASTRUN | |||
| bool enable_fast_run; //! fast run strategy flag | |||
| bool enable_full_run; //! full run strategy flag | |||
| #endif | |||
| bool batch_binary_equal; //! fast run strategy setting | |||
| bool enable_reproducible; //! enable reproducible strategy | |||
| size_t share_batch_size; //! fast run strategy share batch size setting | |||
| std::string m_fast_run_cache; //! fast run cache file path | |||
| std::string m_option_name; //! option name | |||
| }; | |||
| } // namespace lar | |||
| @@ -0,0 +1,295 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/options/io_options.cpp | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #include <map> | |||
| #include "helpers/data_parser.h" | |||
| #include "misc.h" | |||
| #include "models/model_lite.h" | |||
| #include "models/model_mdl.h" | |||
| #include "io_options.h" | |||
| namespace lar { | |||
| template <> | |||
| void InputOption::config_model_internel<ModelLite>( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) { | |||
| if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
| auto parser = model->get_input_parser(); | |||
| auto io = model->get_networkIO(); | |||
| for (size_t idx = 0; idx < data_path.size(); ++idx) { | |||
| parser.feed(data_path[idx].c_str()); | |||
| } | |||
| auto inputs = parser.inputs; | |||
| bool is_host = true; | |||
| for (auto& i : inputs) { | |||
| io.inputs.push_back({i.first, is_host}); | |||
| } | |||
| } else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { | |||
| auto config = model->get_config(); | |||
| auto parser = model->get_input_parser(); | |||
| auto network = model->get_lite_network(); | |||
| //! data type map from mgb data type to lite data type | |||
| std::map<megdnn::DTypeEnum, LiteDataType> type_map = { | |||
| {megdnn::DTypeEnum::Float32, LiteDataType::LITE_FLOAT}, | |||
| {megdnn::DTypeEnum::Int32, LiteDataType::LITE_INT}, | |||
| {megdnn::DTypeEnum::Int8, LiteDataType::LITE_INT8}, | |||
| {megdnn::DTypeEnum::Uint8, LiteDataType::LITE_UINT8}}; | |||
| for (auto& i : parser.inputs) { | |||
| //! get tensor information from data parser | |||
| auto tensor = i.second; | |||
| auto data_type = tensor.dtype(); | |||
| auto tensor_shape = tensor.shape(); | |||
| mgb::dt_byte* src = tensor.raw_ptr(); | |||
| //! set lite layout | |||
| lite::Layout layout; | |||
| layout.ndim = tensor_shape.ndim; | |||
| for (size_t idx = 0; idx < tensor_shape.ndim; idx++) { | |||
| layout.shapes[idx] = tensor_shape[idx]; | |||
| } | |||
| layout.data_type = type_map[data_type.enumv()]; | |||
| //! set network input tensor | |||
| std::shared_ptr<lite::Tensor> input_tensor = | |||
| network->get_io_tensor(i.first); | |||
| input_tensor->reset(src, layout); | |||
| } | |||
| } | |||
| } | |||
| template <> | |||
| void InputOption::config_model_internel<ModelMdl>( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) { | |||
| if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
| auto parser = model->get_input_parser(); | |||
| for (size_t idx = 0; idx < data_path.size(); ++idx) { | |||
| parser.feed(data_path[idx].c_str()); | |||
| } | |||
| } else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { | |||
| auto parser = model->get_input_parser(); | |||
| auto network = model->get_mdl_load_result(); | |||
| auto tensormap = network.tensor_map; | |||
| for (auto& i : parser.inputs) { | |||
| mgb_assert( | |||
| tensormap.find(i.first) != tensormap.end(), | |||
| "can't find tesnor named %s", i.first.c_str()); | |||
| auto& in = tensormap.find(i.first)->second; | |||
| in->copy_from(i.second); | |||
| } | |||
| } | |||
| } | |||
| template <> | |||
| void IOdumpOption::config_model_internel<ModelLite>( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) { | |||
| if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { | |||
| if (enable_io_dump) { | |||
| LITE_WARN("enable text io dump"); | |||
| lite::Runtime::enable_io_txt_dump(model->get_lite_network(), dump_path); | |||
| } | |||
| if (enable_bin_io_dump) { | |||
| LITE_WARN("enable binary io dump"); | |||
| lite::Runtime::enable_io_bin_dump(model->get_lite_network(), dump_path); | |||
| } | |||
| //! FIXME: complete this when the corresponding API is added in lite | |||
| if (enable_io_dump_stdout || enable_io_dump_stderr) { | |||
| LITE_THROW("lite model don't support the stdout or stderr io dump"); | |||
| } | |||
| if (enable_bin_out_dump) { | |||
| LITE_THROW("lite model don't support the binary output dump"); | |||
| } | |||
| if (enable_copy_to_host) { | |||
| LITE_WARN("lite model set copy to host defaultly"); | |||
| } | |||
| } | |||
| } | |||
| template <> | |||
| void IOdumpOption::config_model_internel<ModelMdl>( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) { | |||
| if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
| if (enable_io_dump) { | |||
| mgb_log_warn("enable text io dump"); | |||
| auto iodump = std::make_unique<mgb::TextOprIODump>( | |||
| model->get_mdl_config().comp_graph.get(), dump_path.c_str()); | |||
| iodump->print_addr(false); | |||
| io_dumper = std::move(iodump); | |||
| } | |||
| if (enable_io_dump_stdout) { | |||
| mgb_log_warn("enable text io dump to stdout"); | |||
| std::shared_ptr<FILE> std_out(stdout, [](FILE*) {}); | |||
| auto iodump = std::make_unique<mgb::TextOprIODump>( | |||
| model->get_mdl_config().comp_graph.get(), std_out); | |||
| iodump->print_addr(false); | |||
| io_dumper = std::move(iodump); | |||
| } | |||
| if (enable_io_dump_stderr) { | |||
| mgb_log_warn("enable text io dump to stderr"); | |||
| std::shared_ptr<FILE> std_err(stderr, [](FILE*) {}); | |||
| auto iodump = std::make_unique<mgb::TextOprIODump>( | |||
| model->get_mdl_config().comp_graph.get(), std_err); | |||
| iodump->print_addr(false); | |||
| io_dumper = std::move(iodump); | |||
| } | |||
| if (enable_bin_io_dump) { | |||
| mgb_log_warn("enable binary io dump"); | |||
| auto iodump = std::make_unique<mgb::BinaryOprIODump>( | |||
| model->get_mdl_config().comp_graph.get(), dump_path); | |||
| io_dumper = std::move(iodump); | |||
| } | |||
| if (enable_bin_out_dump) { | |||
| mgb_log_warn("enable binary output dump"); | |||
| out_dumper = std::make_unique<OutputDumper>(dump_path.c_str()); | |||
| } | |||
| } else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { | |||
| if (enable_bin_out_dump) { | |||
| auto load_result = model->get_mdl_load_result(); | |||
| out_dumper->set(load_result.output_var_list); | |||
| std::vector<mgb::ComputingGraph::Callback> cb; | |||
| for (size_t i = 0; i < load_result.output_var_list.size(); i++) { | |||
| cb.push_back(out_dumper->bind()); | |||
| } | |||
| model->set_output_callback(cb); | |||
| } | |||
| if (enable_copy_to_host) { | |||
| auto load_result = model->get_mdl_load_result(); | |||
| std::vector<mgb::ComputingGraph::Callback> cb; | |||
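| // each callback captures its own HostTensorND by value (mutable), so | |||
| // copy_from() pulls every output back to host memory after each run | |||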
| for (size_t i = 0; i < load_result.output_var_list.size(); i++) { | |||
| mgb::HostTensorND val; | |||
| auto callback = [val](const mgb::DeviceTensorND& dv) mutable { | |||
| val.copy_from(dv); | |||
| }; | |||
| cb.push_back(callback); | |||
| } | |||
| model->set_output_callback(cb); | |||
| } | |||
| } else if (runtime_param.stage == RunStage::AFTER_RUNNING_WAIT) { | |||
| if (enable_bin_out_dump) { | |||
| out_dumper->write_to_file(); | |||
| } | |||
| } | |||
| } | |||
| } // namespace lar | |||
| ////////////////////// Input options //////////////////////// | |||
| using namespace lar; | |||
| InputOption::InputOption() { | |||
| m_option_name = "input"; | |||
| size_t start = 0; | |||
| auto end = FLAGS_input.find(";", start); | |||
| while (end != std::string::npos) { | |||
| std::string path = FLAGS_input.substr(start, end - start); | |||
| data_path.emplace_back(path); | |||
| start = end + 1; | |||
| end = FLAGS_input.find(";", start); | |||
| } | |||
| data_path.emplace_back(FLAGS_input.substr(start)); | |||
| } | |||
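| // e.g. (illustrative): --input "data0.json;data1.json" splits on ';' into | |||
| // data_path = {"data0.json", "data1.json"}; a single path needs no separator. | |||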
| std::shared_ptr<lar::OptionBase> lar::InputOption::create_option() { | |||
| static std::shared_ptr<InputOption> m_option(new InputOption); | |||
| if (InputOption::is_valid()) { | |||
| return std::static_pointer_cast<OptionBase>(m_option); | |||
| } else { | |||
| return nullptr; | |||
| } | |||
| } | |||
| void InputOption::config_model( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | |||
| CONFIG_MODEL_FUN; | |||
| } | |||
| ////////////////////// OprIOdump options //////////////////////// | |||
| IOdumpOption::IOdumpOption() { | |||
| m_option_name = "iodump"; | |||
| size_t valid_flag = 0; | |||
| if (!FLAGS_io_dump.empty()) { | |||
| dump_path = FLAGS_io_dump; | |||
| enable_io_dump = true; | |||
| valid_flag = valid_flag | (1 << 0); | |||
| } | |||
| if (!FLAGS_bin_io_dump.empty()) { | |||
| dump_path = FLAGS_bin_io_dump; | |||
| enable_bin_io_dump = true; | |||
| valid_flag = valid_flag | (1 << 1); | |||
| } | |||
| if (!FLAGS_bin_out_dump.empty()) { | |||
| dump_path = FLAGS_bin_out_dump; | |||
| enable_bin_out_dump = true; | |||
| valid_flag = valid_flag | (1 << 2); | |||
| } | |||
| if (FLAGS_io_dump_stdout) { | |||
| enable_io_dump_stdout = FLAGS_io_dump_stdout; | |||
| valid_flag = valid_flag | (1 << 3); | |||
| } | |||
| if (FLAGS_io_dump_stderr) { | |||
| enable_io_dump_stderr = FLAGS_io_dump_stderr; | |||
| valid_flag = valid_flag | (1 << 4); | |||
| } | |||
| // more than one dump option was set | |||
| if (valid_flag && (valid_flag & (valid_flag - 1))) { | |||
| mgb_log_warn( | |||
| "ONLY the last io dump option is validate and others is " | |||
| "skipped!!!"); | |||
| } | |||
| enable_copy_to_host = FLAGS_copy_to_host; | |||
| } | |||
| bool IOdumpOption::is_valid() { | |||
| bool ret = !FLAGS_io_dump.empty(); | |||
| ret = ret || FLAGS_io_dump_stdout; | |||
| ret = ret || FLAGS_io_dump_stderr; | |||
| ret = ret || !FLAGS_bin_io_dump.empty(); | |||
| ret = ret || !FLAGS_bin_out_dump.empty(); | |||
| ret = ret || FLAGS_copy_to_host; | |||
| return ret; | |||
| } | |||
| std::shared_ptr<OptionBase> IOdumpOption::create_option() { | |||
| static std::shared_ptr<IOdumpOption> option(new IOdumpOption); | |||
| if (IOdumpOption::is_valid()) { | |||
| return std::static_pointer_cast<OptionBase>(option); | |||
| } else { | |||
| return nullptr; | |||
| } | |||
| } | |||
| void IOdumpOption::config_model( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | |||
| CONFIG_MODEL_FUN; | |||
| } | |||
| ////////////////////// Input gflags //////////////////////// | |||
| DEFINE_string( | |||
| input, "", "Set up inputs data for model --input [ file_path | data_string]"); | |||
| ////////////////////// OprIOdump gflags //////////////////////// | |||
| DEFINE_string(io_dump, "", "set the io dump file path in text format"); | |||
| DEFINE_bool(io_dump_stdout, false, "dump io opr to stdout in text format"); | |||
| DEFINE_bool(io_dump_stderr, false, "dump io opr to stderr in text format"); | |||
| DEFINE_string(bin_io_dump, "", "set the io dump file path in binary format"); | |||
| DEFINE_string(bin_out_dump, "", "set the out dump file path in binary format"); | |||
| DEFINE_bool(copy_to_host, false, "copy device data to host"); | |||
| REGIST_OPTION_CREATOR(input, lar::InputOption::create_option); | |||
| REGIST_OPTION_CREATOR(iodump, lar::IOdumpOption::create_option); | |||
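| // Hypothetical invocation (file names illustrative): | |||
| //   load_and_run model.mge --input data.json --io-dump dump.txt | |||
| // dumps each opr's io in text form via the dumpers configured above. | |||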
| @@ -0,0 +1,78 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/options/io_options.h | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #pragma once | |||
| #include <gflags/gflags.h> | |||
| #include "helpers/outdumper.h" | |||
| #include "megbrain/plugin/opr_io_dump.h" | |||
| #include "models/model.h" | |||
| #include "option_base.h" | |||
| DECLARE_string(input); | |||
| DECLARE_string(io_dump); | |||
| DECLARE_bool(io_dump_stdout); | |||
| DECLARE_bool(io_dump_stderr); | |||
| DECLARE_string(bin_io_dump); | |||
| DECLARE_string(bin_out_dump); | |||
| DECLARE_bool(copy_to_host); | |||
| namespace lar { | |||
| /*! | |||
| * \brief: input option set by --input | |||
| */ | |||
| class InputOption final : public OptionBase { | |||
| public: | |||
| //! static functions for registering the option | |||
| static bool is_valid() { return !FLAGS_input.empty(); }; | |||
| static std::shared_ptr<OptionBase> create_option(); | |||
| void config_model( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override; | |||
| //! interface implement from OptionBase | |||
| std::string option_name() const override { return m_option_name; }; | |||
| private: | |||
| InputOption(); | |||
| template <typename ModelImpl> | |||
| void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | |||
| std::string m_option_name; | |||
| std::vector<std::string> data_path; // data string or data file path | |||
| }; | |||
| class IOdumpOption : public OptionBase { | |||
| public: | |||
| static bool is_valid(); | |||
| static std::shared_ptr<OptionBase> create_option(); | |||
| //! config the model; dispatch to the model-specific configure code | |||
| //! when implementations differ | |||
| void config_model( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override; | |||
| std::string option_name() const override { return m_option_name; }; | |||
| private: | |||
| IOdumpOption(); | |||
| template <typename ModelImpl> | |||
| void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | |||
| bool enable_io_dump; | |||
| bool enable_io_dump_stdout; | |||
| bool enable_io_dump_stderr; | |||
| bool enable_bin_io_dump; | |||
| bool enable_bin_out_dump; | |||
| bool enable_copy_to_host; | |||
| std::string m_option_name; | |||
| std::string dump_path; | |||
| std::unique_ptr<mgb::OprIODumpBase> io_dumper; | |||
| std::unique_ptr<OutputDumper> out_dumper; | |||
| }; | |||
| } // namespace lar | |||
| @@ -0,0 +1,171 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/options/layout_options.cpp | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #include <gflags/gflags.h> | |||
| #include "misc.h" | |||
| #include "models/model_lite.h" | |||
| #include "models/model_mdl.h" | |||
| #include "layout_options.h" | |||
| namespace lar { | |||
| template <> | |||
| void LayoutOption::config_model_internel<ModelLite>( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) { | |||
| if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
| #define ENABLE_LAYOUT(layout) \ | |||
| LITE_WARN("enable " #layout " optimization"); \ | |||
| model->get_config().options.enable_##layout = true; \ | |||
| break; | |||
| switch (option_flag) { | |||
| case OptLayoutType::NCHW4: | |||
| ENABLE_LAYOUT(nchw4) | |||
| case OptLayoutType::CHWN4: | |||
| LITE_THROW("lite model unsupport chwn4 layout"); | |||
| break; | |||
| case OptLayoutType::NCHW44: | |||
| ENABLE_LAYOUT(nchw44) | |||
| case OptLayoutType::NCHW88: | |||
| ENABLE_LAYOUT(nchw88) | |||
| case OptLayoutType::NCHW32: | |||
| ENABLE_LAYOUT(nchw32) | |||
| case OptLayoutType::NCHW64: | |||
| ENABLE_LAYOUT(nchw64) | |||
| case OptLayoutType::NHWCD4: | |||
| ENABLE_LAYOUT(nhwcd4) | |||
| case OptLayoutType::NCHW44_DOT: | |||
| ENABLE_LAYOUT(nchw44_dot) | |||
| default: | |||
| break; | |||
| } | |||
| #undef ENABLE_LAYOUT | |||
| } | |||
| } | |||
| template <> | |||
| void lar::LayoutOption::config_model_internel<ModelMdl>( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) { | |||
| if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
| mgb_log_debug("mdl layout config start"); | |||
| #define ENABLE_LAYOUT(layout) \ | |||
| mgb_log_warn("enable " #layout " optimization"); \ | |||
| model->get_mdl_config().comp_graph->options().graph_opt.enable_##layout(); \ | |||
| break; | |||
| switch (option_flag) { | |||
| case OptLayoutType::NCHW4: | |||
| ENABLE_LAYOUT(nchw4) | |||
| case OptLayoutType::CHWN4: | |||
| ENABLE_LAYOUT(chwn4) | |||
| case OptLayoutType::NCHW44: | |||
| ENABLE_LAYOUT(nchw44) | |||
| case OptLayoutType::NCHW88: | |||
| ENABLE_LAYOUT(nchw88) | |||
| case OptLayoutType::NCHW32: | |||
| ENABLE_LAYOUT(nchw32) | |||
| case OptLayoutType::NCHW64: | |||
| ENABLE_LAYOUT(nchw64) | |||
| case OptLayoutType::NHWCD4: | |||
| ENABLE_LAYOUT(nhwcd4) | |||
| case OptLayoutType::NCHW44_DOT: | |||
| ENABLE_LAYOUT(nchw44_dot) | |||
| default: | |||
| break; | |||
| } | |||
| mgb_log_debug("mdl layout config end"); | |||
| #undef ENABLE_LAYOUT | |||
| } | |||
| } | |||
| } // namespace lar | |||
| using namespace lar; | |||
| OptLayoutType LayoutOption::option_flag; | |||
| LayoutOption::LayoutOption() { | |||
| m_option_name = "layout"; | |||
| } | |||
| bool LayoutOption::is_valid() { | |||
| size_t valid_flag = 0; | |||
| if (FLAGS_enable_nchw4) { | |||
| valid_flag = valid_flag | (1 << 0); | |||
| } | |||
| if (FLAGS_enable_chwn4) { | |||
| valid_flag = valid_flag | (1 << 1); | |||
| } | |||
| if (FLAGS_enable_nchw44) { | |||
| valid_flag = valid_flag | (1 << 2); | |||
| } | |||
| if (FLAGS_enable_nchw88) { | |||
| valid_flag = valid_flag | (1 << 3); | |||
| } | |||
| if (FLAGS_enable_nchw32) { | |||
| valid_flag = valid_flag | (1 << 4); | |||
| } | |||
| if (FLAGS_enable_nchw64) { | |||
| valid_flag = valid_flag | (1 << 5); | |||
| } | |||
| if (FLAGS_enable_nhwcd4) { | |||
| valid_flag = valid_flag | (1 << 6); | |||
| } | |||
| if (FLAGS_enable_nchw44_dot) { | |||
| valid_flag = valid_flag | (1 << 7); | |||
| } | |||
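| // x & (x - 1) clears the lowest set bit, so the check below is true | |||
| // iff exactly one layout flag bit is set in valid_flag | |||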
| bool ret = valid_flag && !(valid_flag & (valid_flag - 1)); | |||
| if (ret) { | |||
| option_flag = static_cast<OptLayoutType>(valid_flag); | |||
| } else { | |||
| option_flag = static_cast<OptLayoutType>(0); | |||
| } | |||
| return ret; | |||
| }; | |||
| std::shared_ptr<OptionBase> LayoutOption::create_option() { | |||
| static std::shared_ptr<LayoutOption> option(new LayoutOption); | |||
| if (LayoutOption::is_valid()) { | |||
| return std::static_pointer_cast<OptionBase>(option); | |||
| } else { | |||
| return nullptr; | |||
| } | |||
| } | |||
| void LayoutOption::config_model( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | |||
| CONFIG_MODEL_FUN; | |||
| } | |||
| DEFINE_bool(enable_nchw4, false, "enable nchw4 layout optimization!!"); | |||
| DEFINE_bool(enable_chwn4, false, "enable chwn4 layout optimization!!"); | |||
| DEFINE_bool(enable_nchw44, false, "enable nchw44 layout optimization!!"); | |||
| DEFINE_bool(enable_nchw88, false, "enable nchw88 layout optimization!!"); | |||
| DEFINE_bool(enable_nchw32, false, "enable nchw32 layout optimization!!"); | |||
| DEFINE_bool(enable_nchw64, false, "enable nchw64 layout optimization!!"); | |||
| DEFINE_bool(enable_nhwcd4, false, "enable nhwcd4 layout optimization!!"); | |||
| DEFINE_bool(enable_nchw44_dot, false, "enable nchw444-dot layout optimization!!"); | |||
| REGIST_OPTION_CREATOR(layout, lar::LayoutOption::create_option); | |||
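| // Hypothetical invocation (model name illustrative): | |||
| //   load_and_run model.mge --enable-nchw44 | |||
| // sets option_flag to NCHW44; passing two layout flags makes is_valid() | |||
| // return false, so no layout option is created. | |||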
| @@ -0,0 +1,56 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/options/layout_options.h | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #pragma once | |||
| #include <gflags/gflags.h> | |||
| #include "helpers/common.h" | |||
| #include "models/model.h" | |||
| #include "option_base.h" | |||
| DECLARE_bool(enable_nchw4); | |||
| DECLARE_bool(enable_chwn4); | |||
| DECLARE_bool(enable_nchw44); | |||
| DECLARE_bool(enable_nchw88); | |||
| DECLARE_bool(enable_nchw32); | |||
| DECLARE_bool(enable_nchw64); | |||
| DECLARE_bool(enable_nhwcd4); | |||
| DECLARE_bool(enable_nchw44_dot); | |||
| namespace lar { | |||
| /*! | |||
| * \brief: layout option for optimization | |||
| */ | |||
| class LayoutOption final : public OptionBase { | |||
| public: | |||
| //! check the validity of the layout option flags | |||
| static bool is_valid(); | |||
| //! create the option when exactly one layout flag is set | |||
| static std::shared_ptr<OptionBase> create_option(); | |||
| //! config the model, dispatch configuration for different model implement | |||
| void config_model( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override; | |||
| //! get option name | |||
| std::string option_name() const override { return m_option_name; }; | |||
| private: | |||
| //! Constructor | |||
| LayoutOption(); | |||
| //! configuration for different model implement | |||
| template <typename ModelImpl> | |||
| void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | |||
| static OptLayoutType option_flag; | |||
| std::string m_option_name; | |||
| }; | |||
| } // namespace lar | |||
| @@ -0,0 +1,600 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/options/optimize_options.cpp | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #include "megbrain/gopt/inference.h" | |||
| #if MGB_ENABLE_TENSOR_RT | |||
| #include "megbrain/tensorrt/tensorrt_engine_cache.h" | |||
| #endif | |||
| #include "lite/global.h" | |||
| #include "misc.h" | |||
| #include "models/model_lite.h" | |||
| #include "models/model_mdl.h" | |||
| #include "optimize_options.h" | |||
| ///////////////////////// fuse and preprocess optimize options /////////////// | |||
| namespace lar { | |||
| template <> | |||
| void FusePreprocessOption::config_model_internel<ModelLite>( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) { | |||
| if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
| if (enable_fuse_preprocess) { | |||
| LITE_WARN("enable fuse-preprocess optimization"); | |||
| model->get_config().options.fuse_preprocess = true; | |||
| } | |||
| } | |||
| } | |||
| template <> | |||
| void FusePreprocessOption::config_model_internel<ModelMdl>( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) { | |||
| if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
| auto&& graph_option = model->get_mdl_config().comp_graph->options(); | |||
| if (enable_fuse_preprocess) { | |||
| mgb_log_warn("enable fuse-preprocess optimization"); | |||
| graph_option.graph_opt.enable_fuse_preprocess(); | |||
| } | |||
| } | |||
| } | |||
| } // namespace lar | |||
| using namespace lar; | |||
| FusePreprocessOption::FusePreprocessOption() { | |||
| m_option_name = "fuse_preprocess"; | |||
| enable_fuse_preprocess = FLAGS_enable_fuse_preprocess; | |||
| } | |||
| bool FusePreprocessOption::is_valid() { | |||
| bool ret = FLAGS_enable_fuse_preprocess; | |||
| return ret; | |||
| } | |||
| std::shared_ptr<OptionBase> FusePreprocessOption::create_option() { | |||
| static std::shared_ptr<FusePreprocessOption> option(new FusePreprocessOption); | |||
| if (FusePreprocessOption::is_valid()) { | |||
| return std::static_pointer_cast<OptionBase>(option); | |||
| } else { | |||
| return nullptr; | |||
| } | |||
| } | |||
| void FusePreprocessOption::config_model( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | |||
| CONFIG_MODEL_FUN; | |||
| } | |||
| ///////////////////////// weight preprocess optimize options /////////////// | |||
| namespace lar { | |||
| template <> | |||
| void WeightPreprocessOption::config_model_internel<ModelLite>( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) { | |||
| if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
| if (weight_preprocess) { | |||
| LITE_WARN("enable weight-preprocess optimization"); | |||
| model->get_config().options.weight_preprocess = true; | |||
| //! FIXME: the algo searcher should enable weight preprocess for | |||
| //! opencl (the implementation below has some problem); | |||
| // #if MGB_OPENCL | |||
| // megdnn::opencl::algo_searcher::AlgoSearcherBase:: | |||
| // enable_weight_preprocess(); | |||
| // #endif | |||
| } | |||
| } | |||
| } | |||
| template <> | |||
| void WeightPreprocessOption::config_model_internel<ModelMdl>( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) { | |||
| if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
| auto&& graph_option = model->get_mdl_config().comp_graph->options(); | |||
| if (weight_preprocess) { | |||
| mgb_log_warn("enable weight-preprocess optimization"); | |||
| graph_option.graph_opt.enable_weight_preprocess(); | |||
| //! FIXME: this implementation is not right | |||
| // #if MGB_OPENCL | |||
| // megdnn::opencl::algo_searcher::AlgoSearcherBase:: | |||
| // enable_weight_preprocess(); | |||
| // #endif | |||
| } | |||
| } | |||
| } | |||
| } // namespace lar | |||
| WeightPreprocessOption::WeightPreprocessOption() { | |||
| m_option_name = "weight_preprocess"; | |||
| weight_preprocess = FLAGS_weight_preprocess; | |||
| } | |||
| bool WeightPreprocessOption::is_valid() { | |||
| bool ret = FLAGS_weight_preprocess; | |||
| return ret; | |||
| } | |||
| std::shared_ptr<OptionBase> WeightPreprocessOption::create_option() { | |||
| static std::shared_ptr<WeightPreprocessOption> option(new WeightPreprocessOption); | |||
| if (WeightPreprocessOption::is_valid()) { | |||
| return std::static_pointer_cast<OptionBase>(option); | |||
| } else { | |||
| return nullptr; | |||
| } | |||
| } | |||
| void WeightPreprocessOption::config_model( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | |||
| CONFIG_MODEL_FUN; | |||
| } | |||
| ///// fuse conv bias and nonlinear activation opr optimize options //////// | |||
| namespace lar { | |||
| template <> | |||
| void FuseConvBiasNonlinearOption::config_model_internel<ModelLite>( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) { | |||
| LITE_MARK_USED_VAR(model); | |||
| if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
| if (enable_fuse_conv_bias_nonlinearity) { | |||
| LITE_THROW("fuse conv+bias+nonlinearity not supported in lite model"); | |||
| } | |||
| } | |||
| } | |||
| template <> | |||
| void FuseConvBiasNonlinearOption::config_model_internel<ModelMdl>( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) { | |||
| if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
| auto&& graph_option = model->get_mdl_config().comp_graph->options(); | |||
| if (enable_fuse_conv_bias_nonlinearity) { | |||
| mgb_log_warn("enable fuse conv+bias+nonlinearity optimization"); | |||
| graph_option.graph_opt.enable_fuse_conv_bias_nonlinearity(); | |||
| } | |||
| } | |||
| } | |||
| } // namespace lar | |||
| FuseConvBiasNonlinearOption::FuseConvBiasNonlinearOption() { | |||
| m_option_name = "fuse_conv_bias_nonlinear"; | |||
| enable_fuse_conv_bias_nonlinearity = FLAGS_enable_fuse_conv_bias_nonlinearity; | |||
| } | |||
| bool FuseConvBiasNonlinearOption::is_valid() { | |||
| bool ret = FLAGS_enable_fuse_conv_bias_nonlinearity; | |||
| return ret; | |||
| } | |||
| std::shared_ptr<OptionBase> FuseConvBiasNonlinearOption::create_option() { | |||
| static std::shared_ptr<FuseConvBiasNonlinearOption> option( | |||
| new FuseConvBiasNonlinearOption); | |||
| if (FuseConvBiasNonlinearOption::is_valid()) { | |||
| return std::static_pointer_cast<OptionBase>(option); | |||
| } else { | |||
| return nullptr; | |||
| } | |||
| } | |||
| void FuseConvBiasNonlinearOption::config_model( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | |||
| CONFIG_MODEL_FUN; | |||
| } | |||
| ///////////////////////// fuse conv bias with z optimize options /////////////// | |||
| namespace lar { | |||
| template <> | |||
| void FuseConvBiasElemwiseAddOption::config_model_internel<ModelLite>( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) { | |||
| LITE_MARK_USED_VAR(model); | |||
| if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
| if (enable_fuse_conv_bias_with_z) { | |||
| LITE_THROW( | |||
| "fuse conv+bias+z optimization not supported in lite " | |||
| "model"); | |||
| } | |||
| } | |||
| } | |||
| template <> | |||
| void FuseConvBiasElemwiseAddOption::config_model_internel<ModelMdl>( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) { | |||
| if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
| auto&& graph_option = model->get_mdl_config().comp_graph->options(); | |||
| if (enable_fuse_conv_bias_with_z) { | |||
| mgb_log_warn("enable fuse conv+bias+z optimization"); | |||
| graph_option.graph_opt.enable_fuse_conv_bias_with_z(); | |||
| } | |||
| } | |||
| } | |||
| } // namespace lar | |||
| FuseConvBiasElemwiseAddOption::FuseConvBiasElemwiseAddOption() { | |||
| m_option_name = "fuse_conv_bias_z"; | |||
| enable_fuse_conv_bias_with_z = FLAGS_enable_fuse_conv_bias_with_z; | |||
| } | |||
| bool FuseConvBiasElemwiseAddOption::is_valid() { | |||
| bool ret = FLAGS_enable_fuse_conv_bias_with_z; | |||
| return ret; | |||
| } | |||
| std::shared_ptr<OptionBase> FuseConvBiasElemwiseAddOption::create_option() { | |||
| static std::shared_ptr<FuseConvBiasElemwiseAddOption> option( | |||
| new FuseConvBiasElemwiseAddOption); | |||
| if (FuseConvBiasElemwiseAddOption::is_valid()) { | |||
| return std::static_pointer_cast<OptionBase>(option); | |||
| } else { | |||
| return nullptr; | |||
| } | |||
| } | |||
| void FuseConvBiasElemwiseAddOption::config_model( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | |||
| CONFIG_MODEL_FUN; | |||
| } | |||
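| //! Illustrative invocations of the fuse options (MegDL models only, the | |||
| //! model path is a placeholder): | |||
| //!     load_and_run model.mge --enable_fuse_conv_bias_nonlinearity | |||
| //!     load_and_run model.mge --enable_fuse_conv_bias_with_z | |||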
| ///////////////////////// graph record options ///////////////////////// | |||
| namespace lar { | |||
| template <> | |||
| void GraphRecordOption::config_model_internel<ModelLite>( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) { | |||
| if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
| auto&& config_option = model->get_config().options; | |||
| if (const_shape) { | |||
| LITE_WARN("enable const var shape"); | |||
| config_option.const_shape = true; | |||
| } | |||
| if (fake_first) { | |||
| LITE_WARN("enable fake-first optimization"); | |||
| config_option.fake_next_exec = true; | |||
| } | |||
| if (no_sanity_check) { | |||
| LITE_WARN("disable var sanity check optimization"); | |||
| config_option.var_sanity_check_first_run = false; | |||
| } | |||
| if (m_record_comp_seq == 1) { | |||
| LITE_WARN("set record_comp_seq_level to 1"); | |||
| } | |||
| if (m_record_comp_seq == 2) { | |||
| mgb_assert( | |||
| no_sanity_check, | |||
| "--no-sanity-check should be set before " | |||
| "--record-comp-seq2"); | |||
| LITE_WARN("set record_comp_seq_level to 2"); | |||
| } | |||
| config_option.comp_node_seq_record_level = m_record_comp_seq; | |||
| } | |||
| } | |||
| template <> | |||
| void GraphRecordOption::config_model_internel<ModelMdl>( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) { | |||
| if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
| auto&& graph_option = model->get_mdl_config().comp_graph->options(); | |||
| if (const_shape) { | |||
| mgb_log_warn("enable const var shape"); | |||
| model->get_mdl_config().const_var_shape = true; | |||
| } | |||
| if (fake_first) { | |||
| mgb_log_warn("enable fake-first optimization"); | |||
| graph_option.fake_next_exec = true; | |||
| } | |||
| if (no_sanity_check) { | |||
| mgb_log_warn("disable var sanity check optimization"); | |||
| graph_option.var_sanity_check_first_run = false; | |||
| } | |||
| if (m_record_comp_seq == 1) { | |||
| mgb_log_warn("set record_comp_seq_level to 1"); | |||
| } | |||
| if (m_record_comp_seq == 2) { | |||
| mgb_assert( | |||
| no_sanity_check && !fake_first, | |||
| "--no-sanity-check should be set before " | |||
| "--record-comp-seq2 and --fake-first should not be set"); | |||
| mgb_log_warn("set record_comp_seq_level to 2"); | |||
| } | |||
| graph_option.comp_node_seq_record_level = m_record_comp_seq; | |||
| } | |||
| } | |||
| } // namespace lar | |||
| GraphRecordOption::GraphRecordOption() { | |||
| m_option_name = "graph_record"; | |||
| m_record_comp_seq = 0; | |||
| const_shape = FLAGS_const_shape; | |||
| fake_first = FLAGS_fake_first; | |||
| no_sanity_check = FLAGS_no_sanity_check; | |||
| if (FLAGS_record_comp_seq) { | |||
| m_record_comp_seq = 1; | |||
| } | |||
| if (FLAGS_record_comp_seq2) { | |||
| m_record_comp_seq = 2; | |||
| } | |||
| } | |||
| bool GraphRecordOption::is_valid() { | |||
| bool ret = FLAGS_const_shape; | |||
| ret = ret || FLAGS_fake_first; | |||
| ret = ret || FLAGS_no_sanity_check; | |||
| ret = ret || FLAGS_record_comp_seq; | |||
| ret = ret || FLAGS_record_comp_seq2; | |||
| return ret; | |||
| } | |||
| std::shared_ptr<OptionBase> GraphRecordOption::create_option() { | |||
| static std::shared_ptr<GraphRecordOption> option(new GraphRecordOption); | |||
| if (GraphRecordOption::is_valid()) { | |||
| return std::static_pointer_cast<OptionBase>(option); | |||
| } else { | |||
| return nullptr; | |||
| } | |||
| } | |||
| void GraphRecordOption::config_model( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | |||
| CONFIG_MODEL_FUN; | |||
| } | |||
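| //! Illustrative flag combinations (the model path is a placeholder): | |||
| //!     load_and_run model.mge --record_comp_seq | |||
| //!     load_and_run model.mge --no_sanity_check --record_comp_seq2 | |||
| //! level 2 recording requires no_sanity_check and, for MegDL models, must | |||
| //! not be combined with fake_first, as asserted above. | |||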
| ///////////////////////// memory optimize options ///////////////////////// | |||
| namespace lar { | |||
| template <> | |||
| void MemoryOptimizeOption::config_model_internel<ModelLite>( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) { | |||
| LITE_MARK_USED_VAR(model); | |||
| if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
| if (disable_mem_opt) { | |||
| LITE_THROW("lite model don't support disable memory optimization"); | |||
| } | |||
| } else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { | |||
| if (workspace_limit != SIZE_MAX) { | |||
| LITE_WARN("set workspace limit to %ld", workspace_limit); | |||
| lite::Runtime::set_network_algo_workspace_limit( | |||
| model->get_lite_network(), workspace_limit); | |||
| } | |||
| } | |||
| } | |||
| template <> | |||
| void MemoryOptimizeOption::config_model_internel<ModelMdl>( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) { | |||
| if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
| auto&& graph_option = model->get_mdl_config().comp_graph->options(); | |||
| if (disable_mem_opt) { | |||
| mgb_log_warn("disable memory optimization"); | |||
| graph_option.seq_opt.enable_mem_plan_opt = false; | |||
| graph_option.seq_opt.enable_mem_reuse_alloc = false; | |||
| } | |||
| if (workspace_limit < SIZE_MAX) { | |||
| mgb_log_warn("set workspace limit to %ld", workspace_limit); | |||
| auto output_spec = model->get_output_spec(); | |||
| mgb::SymbolVarArray vars; | |||
| for (auto i : output_spec) { | |||
| vars.push_back(i.first); | |||
| } | |||
| mgb::gopt::set_opr_algo_workspace_limit_inplace(vars, workspace_limit); | |||
| } | |||
| } | |||
| } | |||
| } // namespace lar | |||
| MemoryOptimizeOption::MemoryOptimizeOption() { | |||
| m_option_name = "memory_optimize"; | |||
| disable_mem_opt = FLAGS_disable_mem_opt; | |||
| workspace_limit = FLAGS_workspace_limit; | |||
| } | |||
| bool MemoryOptimizeOption::is_valid() { | |||
| bool ret = FLAGS_disable_mem_opt; | |||
| ret = ret || FLAGS_workspace_limit < SIZE_MAX; | |||
| return ret; | |||
| } | |||
| std::shared_ptr<OptionBase> MemoryOptimizeOption::create_option() { | |||
| static std::shared_ptr<MemoryOptimizeOption> option(new MemoryOptimizeOption); | |||
| if (MemoryOptimizeOption::is_valid()) { | |||
| return std::static_pointer_cast<OptionBase>(option); | |||
| } else { | |||
| return nullptr; | |||
| } | |||
| } | |||
| void MemoryOptimizeOption::config_model( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | |||
| CONFIG_MODEL_FUN; | |||
| } | |||
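| //! Illustrative invocation (the byte count is a placeholder): | |||
| //!     load_and_run model.mge --workspace_limit 1048576 --disable_mem_opt | |||
| //! note that --disable_mem_opt is rejected for lite models, see above. | |||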
| ///////////////////////// other options for optimization ///////////////// | |||
| namespace lar { | |||
| template <> | |||
| void JITOption::config_model_internel<ModelLite>( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) { | |||
| if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
| auto&& config_option = model->get_config().options; | |||
| if (enable_jit) { | |||
| LITE_WARN("enable JIT (level 1)"); | |||
| config_option.jit_level = 1; | |||
| } | |||
| } | |||
| } | |||
| template <> | |||
| void JITOption::config_model_internel<ModelMdl>( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) { | |||
| if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
| auto&& graph_option = model->get_mdl_config().comp_graph->options(); | |||
| if (enable_jit) { | |||
| mgb_log_warn("enable JIT (level 1)"); | |||
| graph_option.graph_opt.jit = 1; | |||
| } | |||
| } | |||
| } | |||
| } // namespace lar | |||
| JITOption::JITOption() { | |||
| m_option_name = "JIT"; | |||
| enable_jit = FLAGS_enable_jit; | |||
| } | |||
| bool JITOption::is_valid() { | |||
| bool ret = FLAGS_enable_jit; | |||
| return ret; | |||
| } | |||
| std::shared_ptr<OptionBase> JITOption::create_option() { | |||
| static std::shared_ptr<JITOption> option(new JITOption); | |||
| if (JITOption::is_valid()) { | |||
| return std::static_pointer_cast<OptionBase>(option); | |||
| } else { | |||
| return nullptr; | |||
| } | |||
| } | |||
| void JITOption::config_model( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | |||
| CONFIG_MODEL_FUN; | |||
| } | |||
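| //! Illustrative invocation (requires an NVIDIA GPU, see the gflag below): | |||
| //!     load_and_run model.mge --enable_jit | |||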
| ///////////////////////// TensorRT options for optimization ///////////////// | |||
| #if MGB_ENABLE_TENSOR_RT | |||
| namespace lar { | |||
| template <> | |||
| void TensorRTOption::config_model_internel<ModelLite>( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) { | |||
| if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
| if (!tensorrt_cache.empty()) { | |||
| LITE_WARN("set tensorrt cache as %s", tensorrt_cache.c_str()); | |||
| lite::set_tensor_rt_cache(tensorrt_cache); | |||
| } | |||
| } else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { | |||
| if (enable_tensorrt) { | |||
| LITE_WARN("enable TensorRT"); | |||
| lite::Runtime::use_tensorrt(model->get_lite_network()); | |||
| } | |||
| } else if (runtime_param.stage == RunStage::AFTER_MODEL_RUNNING) { | |||
| if (!tensorrt_cache.empty()) { | |||
| lite::dump_tensor_rt_cache(); | |||
| } | |||
| } | |||
| } | |||
| template <> | |||
| void TensorRTOption::config_model_internel<ModelMdl>( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) { | |||
| if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
| auto&& graph_option = model->get_mdl_config().comp_graph->options(); | |||
| if (enable_tensorrt) { | |||
| mgb_log_warn("using tensorRT"); | |||
| graph_option.graph_opt.tensorrt = true; | |||
| } | |||
| if (!tensorrt_cache.empty()) { | |||
| mgb_log_warn("use tensorrt cache: %s", tensorrt_cache.c_str()); | |||
| mgb::TensorRTEngineCache::enable_engine_cache(true); | |||
| mgb::TensorRTEngineCache::set_impl( | |||
| std::make_shared<mgb::TensorRTEngineCacheIO>( | |||
| tensorrt_cache.c_str())); | |||
| } | |||
| } else if (runtime_param.stage == RunStage::AFTER_MODEL_RUNNING) { | |||
| if (!tensorrt_cache.empty()) { | |||
| if (mgb::TensorRTEngineCache::enable_engine_cache()) { | |||
| mgb::TensorRTEngineCache::inst().dump_cache(); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } // namespace lar | |||
| TensorRTOption::TensorRTOption() { | |||
| m_option_name = "tensorRT"; | |||
| enable_tensorrt = FLAGS_tensorrt; | |||
| tensorrt_cache = FLAGS_tensorrt_cache; | |||
| } | |||
| bool TensorRTOption::is_valid() { | |||
| bool ret = FLAGS_tensorrt; | |||
| ret = ret || !FLAGS_tensorrt_cache.empty(); | |||
| return ret; | |||
| } | |||
| std::shared_ptr<OptionBase> TensorRTOption::create_option() { | |||
| static std::shared_ptr<TensorRTOption> option(new TensorRTOption); | |||
| if (TensorRTOption::is_valid()) { | |||
| return std::static_pointer_cast<OptionBase>(option); | |||
| } else { | |||
| return nullptr; | |||
| } | |||
| } | |||
| void TensorRTOption::config_model( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | |||
| CONFIG_MODEL_FUN; | |||
| } | |||
| #endif | |||
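| //! Illustrative invocation (the cache path is a placeholder): | |||
| //!     load_and_run model.mge --tensorrt --tensorrt_cache ./trt.cache | |||
| //! the engine cache is written back after the model finishes running. | |||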
| ///////////////////////// fuse and preprocess optimize options /////////////// | |||
| DEFINE_bool( | |||
| enable_fuse_preprocess, false, | |||
| "Fusion astype | pad_channel | dimshuffle and etc opr from h2d opr"); | |||
| DEFINE_bool( | |||
| weight_preprocess, false, | |||
| "Execute operators with weight preprocess, which can optimize the " | |||
| "operator execution time with algo of winograd, im2col ,etc., but " | |||
| "it may consume more memory."); | |||
| DEFINE_bool( | |||
| enable_fuse_conv_bias_nonlinearity, false, | |||
| "whether to fuse conv+bias+nonlinearity"); | |||
| DEFINE_bool( | |||
| enable_fuse_conv_bias_with_z, false, | |||
| "fuse conv,bias (elemwise add),z(elemwise add) into one opr " | |||
| "(only support on GPU)"); | |||
| ///////////////////////// graph record & memory optimize options ///////////// | |||
| DEFINE_bool( | |||
| const_shape, false, | |||
| "set const_var_shape to reduce memory usage, since some static " | |||
| "inference data structures can be omitted"); | |||
| DEFINE_bool( | |||
| fake_first, false, | |||
| "Enable fake exec for the first run. In fake exec mode, some " | |||
| "initialization job would be done, but no actual computing is " | |||
| "performed."); | |||
| DEFINE_bool(no_sanity_check, false, "Disable var sanity check on the first run"); | |||
| DEFINE_bool( | |||
| record_comp_seq, false, | |||
| "Record the computing sequence, in level 1 . It reduces overhead of API" | |||
| "calls of some asynchronous computing devices"); | |||
| DEFINE_bool( | |||
| record_comp_seq2, false, | |||
| "Record the computing sequence, in level 2, the computing graph can be" | |||
| "destructed to reduce memory usage"); | |||
| DEFINE_bool(disable_mem_opt, false, "disable memory optimization"); | |||
| DEFINE_uint64(workspace_limit, SIZE_MAX, "set the upper bound of the workspace size"); | |||
| ///////////////////////// other options for optimization ///////////////// | |||
| DEFINE_bool( | |||
| enable_jit, false, | |||
| " Execute supported operators with JIT(now only support NVRTC). " | |||
| "Can only be used on Nvidia GPUs"); | |||
| #if MGB_ENABLE_ANDROID_NN | |||
| DEFINE_bool( | |||
| android_nn, false, | |||
| "Execute supported operators with Android NN. Can only be used " | |||
| "with --cpu."); | |||
| #endif | |||
| #if MGB_ENABLE_TENSOR_RT | |||
| DEFINE_bool( | |||
| tensorrt, false, | |||
| " Execute supported operators with TensorRT. Can only be used on " | |||
| "Nvidia GPUs,i.e. comp node is xpu or gpu."); | |||
| DEFINE_string( | |||
| tensorrt_cache, "", | |||
| "Set the TensorRT engine cache path for serialized prebuilt " | |||
| "ICudaEngine"); | |||
| #endif | |||
| REGIST_OPTION_CREATOR(fuse_preprocess, lar::FusePreprocessOption::create_option); | |||
| REGIST_OPTION_CREATOR(weight_preprocess, lar::WeightPreprocessOption::create_option); | |||
| REGIST_OPTION_CREATOR( | |||
| fuse_conv_bias_nonlinear, lar::FuseConvBiasNonlinearOption::create_option); | |||
| REGIST_OPTION_CREATOR( | |||
| fuse_conv_bias_z, lar::FuseConvBiasElemwiseAddOption::create_option); | |||
| REGIST_OPTION_CREATOR(graph_record, lar::GraphRecordOption::create_option); | |||
| REGIST_OPTION_CREATOR(memory_optimize, lar::MemoryOptimizeOption::create_option); | |||
| REGIST_OPTION_CREATOR(JIT, lar::JITOption::create_option); | |||
| #if MGB_ENABLE_TENSOR_RT | |||
| REGIST_OPTION_CREATOR(tensorRT, lar::TensorRTOption::create_option); | |||
| #endif | |||
| @@ -0,0 +1,207 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/options/optimize_options.h | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #pragma once | |||
| #include <gflags/gflags.h> | |||
| #include "helpers/common.h" | |||
| #include "models/model.h" | |||
| #include "option_base.h" | |||
| DECLARE_bool(enable_fuse_preprocess); | |||
| DECLARE_bool(weight_preprocess); | |||
| DECLARE_bool(enable_fuse_conv_bias_nonlinearity); | |||
| DECLARE_bool(enable_fuse_conv_bias_with_z); | |||
| DECLARE_bool(const_shape); | |||
| DECLARE_bool(fake_first); | |||
| DECLARE_bool(no_sanity_check); | |||
| DECLARE_bool(record_comp_seq); | |||
| DECLARE_bool(record_comp_seq2); | |||
| DECLARE_bool(disable_mem_opt); | |||
| DECLARE_uint64(workspace_limit); | |||
| DECLARE_bool(enable_jit); | |||
| #if MGB_ENABLE_TENSOR_RT | |||
| DECLARE_bool(tensorrt); | |||
| DECLARE_string(tensorrt_cache); | |||
| #endif | |||
| namespace lar { | |||
| ///////////////////////// fuse_preprocess optimize options ////////////// | |||
| class FusePreprocessOption final : public OptionBase { | |||
| public: | |||
| static bool is_valid(); | |||
| static std::shared_ptr<OptionBase> create_option(); | |||
| void config_model( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override; | |||
| std::string option_name() const override { return m_option_name; }; | |||
| private: | |||
| FusePreprocessOption(); | |||
| template <typename ModelImpl> | |||
| void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | |||
| std::string m_option_name; | |||
| bool enable_fuse_preprocess; | |||
| }; | |||
| ///////////////////////// weight preprocess optimize options ////////////// | |||
| class WeightPreprocessOption final : public OptionBase { | |||
| public: | |||
| static bool is_valid(); | |||
| static std::shared_ptr<OptionBase> create_option(); | |||
| void config_model( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override; | |||
| std::string option_name() const override { return m_option_name; }; | |||
| private: | |||
| WeightPreprocessOption(); | |||
| template <typename ModelImpl> | |||
| void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | |||
| std::string m_option_name; | |||
| bool weight_preprocess; | |||
| }; | |||
| /////////////// fuse_conv_bias_nonlinearity optimize options /////////////// | |||
| class FuseConvBiasNonlinearOption final : public OptionBase { | |||
| public: | |||
| static bool is_valid(); | |||
| static std::shared_ptr<OptionBase> create_option(); | |||
| void config_model( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override; | |||
| std::string option_name() const override { return m_option_name; }; | |||
| private: | |||
| FuseConvBiasNonlinearOption(); | |||
| template <typename ModelImpl> | |||
| void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | |||
| std::string m_option_name; | |||
| bool enable_fuse_conv_bias_nonlinearity; | |||
| }; | |||
| ///////////////////////// fuse_conv_bias_with_z optimize options ////////////// | |||
| class FuseConvBiasElemwiseAddOption final : public OptionBase { | |||
| public: | |||
| static bool is_valid(); | |||
| static std::shared_ptr<OptionBase> create_option(); | |||
| void config_model( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override; | |||
| std::string option_name() const override { return m_option_name; }; | |||
| private: | |||
| FuseConvBiasElemwiseAddOption(); | |||
| template <typename ModelImpl> | |||
| void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | |||
| std::string m_option_name; | |||
| bool enable_fuse_conv_bias_with_z; | |||
| }; | |||
| ///////////////////////// graph record options /////////////////////////// | |||
| class GraphRecordOption final : public OptionBase { | |||
| public: | |||
| static bool is_valid(); | |||
| static std::shared_ptr<OptionBase> create_option(); | |||
| void config_model( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override; | |||
| std::string option_name() const override { return m_option_name; }; | |||
| private: | |||
| GraphRecordOption(); | |||
| template <typename ModelImpl> | |||
| void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | |||
| std::string m_option_name; | |||
| size_t m_record_comp_seq; | |||
| bool const_shape; | |||
| bool fake_first; | |||
| bool no_sanity_check; | |||
| }; | |||
| ///////////////////////// memory optimize options ///////////////////////// | |||
| class MemoryOptimizeOption final : public OptionBase { | |||
| public: | |||
| static bool is_valid(); | |||
| static std::shared_ptr<OptionBase> create_option(); | |||
| void config_model( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override; | |||
| std::string option_name() const override { return m_option_name; }; | |||
| private: | |||
| MemoryOptimizeOption(); | |||
| template <typename ModelImpl> | |||
| void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | |||
| std::string m_option_name; | |||
| bool disable_mem_opt; | |||
| uint64_t workspace_limit; | |||
| }; | |||
| ///////////////////////// other options for optimization ///////////////// | |||
| class JITOption final : public OptionBase { | |||
| public: | |||
| static bool is_valid(); | |||
| static std::shared_ptr<OptionBase> create_option(); | |||
| void config_model( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override; | |||
| std::string option_name() const override { return m_option_name; }; | |||
| private: | |||
| JITOption(); | |||
| template <typename ModelImpl> | |||
| void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | |||
| std::string m_option_name; | |||
| bool enable_jit; | |||
| }; | |||
| ///////////////////////// TensorRT options for optimization ///////////////// | |||
| #if MGB_ENABLE_TENSOR_RT | |||
| class TensorRTOption final : public OptionBase { | |||
| public: | |||
| static bool is_valid(); | |||
| static std::shared_ptr<OptionBase> create_option(); | |||
| void config_model( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override; | |||
| std::string option_name() const override { return m_option_name; }; | |||
| private: | |||
| TensorRTOption(); | |||
| template <typename ModelImpl> | |||
| void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | |||
| std::string m_option_name; | |||
| bool enable_tensorrt; | |||
| std::string tensorrt_cache; | |||
| }; | |||
| #endif | |||
| } // namespace lar | |||
| @@ -0,0 +1,87 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/options/option_base.h | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #pragma once | |||
| #include <functional> | |||
| #include <iostream> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <vector> | |||
| #include "megbrain/common.h" | |||
| #include "helpers/common.h" | |||
| #include "models/model.h" | |||
| namespace lar { | |||
| /*! | |||
| * \brief: base class of options | |||
| */ | |||
| class OptionBase { | |||
| public: | |||
| //! configure the model at different runtime stages | |||
| virtual void config_model( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) = 0; | |||
| //! get dependent options | |||
| virtual std::vector<std::string> depend_option() const { return {}; }; | |||
| //! get option name | |||
| virtual std::string option_name() const = 0; | |||
| virtual ~OptionBase() = default; | |||
| }; | |||
| /*! | |||
| * \brief: singleton option factory that registers options before the main function | |||
| */ | |||
| class OptionFactory { | |||
| public: | |||
| using OptionCreator = std::function<std::shared_ptr<OptionBase>()>; | |||
| using OptionMap = std::unordered_map<std::string, OptionCreator>; | |||
| //! get the singleton option factory | |||
| static OptionFactory& get_Instance() { | |||
| static OptionFactory instance; | |||
| return instance; | |||
| } | |||
| //! register an option creator into the option map | |||
| void registe_options(std::string name, OptionCreator creator) { | |||
| if (option_creator_map.count(name) == 0) { | |||
| option_creator_map[name] = creator; | |||
| } | |||
| } | |||
| //! get creator map | |||
| OptionMap* get_option_creator_map() { return &option_creator_map; } | |||
| private: | |||
| OptionFactory(){}; | |||
| OptionMap option_creator_map; | |||
| }; | |||
| } // namespace lar | |||
| #define REGIST_OPTION_CREATOR(name_, creator_) \ | |||
| struct OptionRegister_##name_ { \ | |||
| OptionRegister_##name_() { \ | |||
| lar::OptionFactory::get_Instance().registe_options(#name_, creator_); \ | |||
| } \ | |||
| }; \ | |||
| OptionRegister_##name_ name_; | |||
| #define CONFIG_MODEL_FUN \ | |||
| if (model->type() == ModelType::LITE_MODEL) { \ | |||
| config_model_internel<ModelLite>( \ | |||
| runtime_param, std::static_pointer_cast<ModelLite>(model)); \ | |||
| } else if (model->type() == ModelType::MEGDL_MODEL) { \ | |||
| config_model_internel<ModelMdl>( \ | |||
| runtime_param, std::static_pointer_cast<ModelMdl>(model)); \ | |||
| } | |||
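| // A minimal sketch of a new option built on the helpers above (MyOption and | |||
| // FLAGS_my_flag are hypothetical names, shown only to illustrate the flow): | |||
| // | |||
| //     class MyOption final : public lar::OptionBase { | |||
| //     public: | |||
| //         static bool is_valid() { return FLAGS_my_flag; } | |||
| //         static std::shared_ptr<OptionBase> create_option(); | |||
| //         void config_model( | |||
| //                 RuntimeParam& runtime_param, | |||
| //                 std::shared_ptr<ModelBase> model) override { | |||
| //             CONFIG_MODEL_FUN; | |||
| //         } | |||
| //         std::string option_name() const override { return "my_option"; } | |||
| //     private: | |||
| //         template <typename ModelImpl> | |||
| //         void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>) {} | |||
| //     }; | |||
| //     REGIST_OPTION_CREATOR(my_option, lar::MyOption::create_option); | |||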
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -0,0 +1,401 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/options/plugin_options.cpp | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #include "plugin_options.h" | |||
| #include "misc.h" | |||
| #include "models/model_lite.h" | |||
| #include "models/model_mdl.h" | |||
| ///////////////////// Plugin options/////////////////////////// | |||
| namespace lar { | |||
| template <> | |||
| void PluginOption::config_model_internel<ModelLite>( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) { | |||
| if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
| LITE_ASSERT(range == 0, "lite model doesn't support NumRangeChecker plugin"); | |||
| LITE_ASSERT( | |||
| !enable_check_dispatch, | |||
| "lite model doesn't support CPUDispatchChecker plugin"); | |||
| LITE_ASSERT( | |||
| var_value_check_str.empty(), | |||
| "lite model doesn't support VarValueChecker plugin"); | |||
| } | |||
| #if MGB_ENABLE_JSON | |||
| else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { | |||
| if (!profile_path.empty()) { | |||
| if (!enable_profile_host) { | |||
| LITE_WARN("enable profiling"); | |||
| model->get_lite_network()->enable_profile_performance(profile_path); | |||
| } else { | |||
| LITE_WARN("enable profiling for host"); | |||
| model->get_lite_network()->enable_profile_performance(profile_path); | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| template <> | |||
| void PluginOption::config_model_internel<ModelMdl>( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) { | |||
| if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
| auto config = model->get_mdl_config(); | |||
| if (range > 0) { | |||
| mgb_log_warn("enable number range check"); | |||
| model->set_num_range_checker(float(range)); | |||
| } | |||
| if (enable_check_dispatch) { | |||
| mgb_log_warn("enable cpu dispatch check"); | |||
| cpu_dispatch_checker = | |||
| std::make_unique<mgb::CPUDispatchChecker>(config.comp_graph.get()); | |||
| } | |||
| if (!var_value_check_str.empty()) { | |||
| mgb_log_warn("enable variable value check"); | |||
| size_t init_idx = 0, switch_interval; | |||
| auto sep = var_value_check_str.find(':'); | |||
| if (sep != std::string::npos) { | |||
| switch_interval = std::stoul(var_value_check_str.substr(0, sep)); | |||
| init_idx = std::stoul(var_value_check_str.substr(sep + 1)); | |||
| } else { | |||
| switch_interval = std::stoul(var_value_check_str); | |||
| } | |||
| var_value_checker = std::make_unique<mgb::VarValueChecker>( | |||
| config.comp_graph.get(), switch_interval, init_idx); | |||
| } | |||
| #if MGB_ENABLE_JSON | |||
| if (!profile_path.empty()) { | |||
| if (!enable_profile_host) { | |||
| mgb_log_warn("enable profiling"); | |||
| } else { | |||
| mgb_log_warn("enable profiling for host"); | |||
| } | |||
| model->set_profiler(); | |||
| } | |||
| #endif | |||
| } | |||
| else if (runtime_param.stage == RunStage::AFTER_MODEL_RUNNING) { | |||
| #if MGB_ENABLE_JSON | |||
| if (!profile_path.empty()) { | |||
| mgb_log_warn("filename %s", profile_path.c_str()); | |||
| if (model->get_profiler()) { | |||
| model->get_profiler() | |||
| ->to_json_full(model->get_async_func().get()) | |||
| ->writeto_fpath(profile_path); | |||
| mgb_log_warn("profiling result written to %s", profile_path.c_str()); | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| } | |||
| } // namespace lar | |||
| using namespace lar; | |||
| PluginOption::PluginOption() { | |||
| m_option_name = "plugin"; | |||
| range = FLAGS_range; | |||
| enable_check_dispatch = FLAGS_check_dispatch; | |||
| var_value_check_str = FLAGS_check_var_value; | |||
| #if MGB_ENABLE_JSON | |||
| enable_profile_host = false; | |||
| if (!FLAGS_profile.empty()) { | |||
| profile_path = FLAGS_profile; | |||
| } | |||
| if (!FLAGS_profile_host.empty()) { | |||
| enable_profile_host = true; | |||
| profile_path = FLAGS_profile_host; | |||
| } | |||
| #endif | |||
| } | |||
| bool PluginOption::is_valid() { | |||
| bool ret = FLAGS_check_dispatch; | |||
| ret = ret || FLAGS_range > 0; | |||
| ret = ret || !FLAGS_check_var_value.empty(); | |||
| #if MGB_ENABLE_JSON | |||
| ret = ret || !FLAGS_profile.empty(); | |||
| ret = ret || !FLAGS_profile_host.empty(); | |||
| #endif | |||
| return ret; | |||
| } | |||
| std::shared_ptr<OptionBase> PluginOption::create_option() { | |||
| static std::shared_ptr<PluginOption> option(new PluginOption); | |||
| if (PluginOption::is_valid()) { | |||
| return std::static_pointer_cast<OptionBase>(option); | |||
| } else { | |||
| return nullptr; | |||
| } | |||
| } | |||
| void PluginOption::config_model( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | |||
| CONFIG_MODEL_FUN; | |||
| } | |||
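| //! Illustrative invocation: --check_var_value accepts "interval" or | |||
| //! "interval:init_idx", matching the parsing above, e.g. | |||
| //!     load_and_run model.mge --check_var_value 10:0 | |||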
| ///////////////////// Debug options/////////////////////////// | |||
| namespace lar { | |||
| template <> | |||
| void DebugOption::format_and_print( | |||
| const std::string& tablename, std::shared_ptr<ModelLite> model) { | |||
| auto table = mgb::TextTable(tablename); | |||
| auto network = model->get_lite_network(); | |||
| table.padding(1); | |||
| table.align(mgb::TextTable::Align::Mid).add("type").add("name").add("shape").eor(); | |||
| auto to_string = [&](lite::Layout& layout) { | |||
| std::string shape("{"); | |||
| for (size_t i = 0; i < layout.ndim; i++) { | |||
| if (i) | |||
| shape.append(","); | |||
| shape.append(std::to_string(layout.shapes[i])); | |||
| } | |||
| shape.append("}"); | |||
| return shape; | |||
| }; | |||
| auto input_name = network->get_all_input_name(); | |||
| for (auto& i : input_name) { | |||
| auto layout = network->get_io_tensor(i)->get_layout(); | |||
| table.align(mgb::TextTable::Align::Mid) | |||
| .add("INPUT") | |||
| .add(i) | |||
| .add(to_string(layout)) | |||
| .eor(); | |||
| } | |||
| auto output_name = network->get_all_output_name(); | |||
| for (auto& i : output_name) { | |||
| auto layout = network->get_io_tensor(i)->get_layout(); | |||
| table.align(mgb::TextTable::Align::Mid) | |||
| .add("OUTPUT") | |||
| .add(i) | |||
| .add(to_string(layout)) | |||
| .eor(); | |||
| } | |||
| std::stringstream ss; | |||
| ss << table; | |||
| printf("%s\n\n", ss.str().c_str()); | |||
| } | |||
| template <> | |||
| void DebugOption::format_and_print( | |||
| const std::string& tablename, std::shared_ptr<ModelMdl> model) { | |||
| auto table = mgb::TextTable(tablename); | |||
| table.padding(1); | |||
| table.align(mgb::TextTable::Align::Mid).add("type").add("name").add("shape").eor(); | |||
| for (auto&& i : model->get_mdl_load_result().tensor_map) { | |||
| table.align(mgb::TextTable::Align::Mid) | |||
| .add("INPUT") | |||
| .add(i.first) | |||
| .add(i.second->shape().to_string()) | |||
| .eor(); | |||
| } | |||
| for (auto&& i : model->get_mdl_load_result().output_var_list) { | |||
| table.align(mgb::TextTable::Align::Mid) | |||
| .add("OUTPUT") | |||
| .add(i.node()->name()) | |||
| .add(i.shape().to_string()) | |||
| .eor(); | |||
| } | |||
| std::stringstream ss; | |||
| ss << table; | |||
| printf("%s\n\n", ss.str().c_str()); | |||
| } | |||
| template <> | |||
| void DebugOption::config_model_internel<ModelLite>( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) { | |||
| if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
| LITE_ASSERT( | |||
| !disable_assert_throw, "lite model don't support disable assert throw"); | |||
| #ifndef __IN_TEE_ENV__ | |||
| #if MGB_ENABLE_JSON | |||
| LITE_ASSERT( | |||
| static_mem_log_dir_path.empty(), | |||
| "lite model don't support static memory information export"); | |||
| #endif | |||
| #endif | |||
| if (enable_verbose) { | |||
| LITE_WARN("enable verbose"); | |||
| lite::set_log_level(LiteLogLevel::DEBUG); | |||
| } | |||
| #if __linux__ || __unix__ | |||
| if (enable_wait_gdb) { | |||
| printf("wait for gdb attach (pid=%d): ", getpid()); | |||
| getchar(); | |||
| } | |||
| #endif | |||
| } else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { | |||
| if (enable_display_model_info) { | |||
| LITE_WARN("enable display model information"); | |||
| format_and_print<ModelLite>("Runtime Model Info", model); | |||
| } | |||
| } else if (runtime_param.stage == RunStage::AFTER_MODEL_RUNNING) { | |||
| if (enable_display_model_info) { | |||
| format_and_print<ModelLite>("Runtime Model Info", model); | |||
| } | |||
| } | |||
| } | |||
| template <> | |||
| void DebugOption::config_model_internel<ModelMdl>( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) { | |||
| if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
| auto config = model->get_mdl_config(); | |||
| if (enable_verbose) { | |||
| mgb_log_warn("enable verbose"); | |||
| mgb::set_log_level(mgb::LogLevel::DEBUG); | |||
| } | |||
| #if __linux__ || __unix__ | |||
| if (enable_wait_gdb) { | |||
| printf("wait for gdb attach (pid=%d): ", getpid()); | |||
| getchar(); | |||
| } | |||
| #endif | |||
| } else if (runtime_param.stage == RunStage::AFTER_OUTSPEC_SET) { | |||
| if (enable_display_model_info) { | |||
| mgb_log_warn("enable display model information"); | |||
| format_and_print<ModelMdl>("Runtime Model Info", model); | |||
| } | |||
| if (disable_assert_throw) { | |||
| mgb_log_warn("disable assert throw"); | |||
| auto on_opr = [](mgb::cg::OperatorNodeBase* opr) { | |||
| if (opr->same_type<mgb::opr::AssertEqual>()) { | |||
| opr->cast_final<mgb::opr::AssertEqual>().disable_throw_on_error(); | |||
| } | |||
| }; | |||
| mgb::cg::DepOprIter iter{on_opr}; | |||
| for (auto&& i : model->get_output_spec()) { | |||
| iter.add(i.first.node()->owner_opr()); | |||
| } | |||
| } | |||
| //! FIXME: the static memory dump below doesn't work for cpu-only builds | |||
| //! (nothing is dumped); megbrain/sdk origin code will assert(m_recorded) | |||
| //! in EventImplHelper::finished(); | |||
| #ifndef __IN_TEE_ENV__ | |||
| #if MGB_ENABLE_JSON | |||
| if (!static_mem_log_dir_path.empty()) { | |||
| mgb_log_warn("enable get static memeory information"); | |||
| model->get_async_func()->get_static_memory_alloc_info( | |||
| static_mem_log_dir_path); | |||
| } | |||
| #endif | |||
| #endif | |||
| } else if (runtime_param.stage == RunStage::AFTER_MODEL_RUNNING) { | |||
| if (enable_display_model_info) { | |||
| format_and_print<ModelMdl>("Runtime Model Info", model); | |||
| } | |||
| } | |||
| } | |||
| } // namespace lar | |||
| DebugOption::DebugOption() { | |||
| m_option_name = "debug"; | |||
| enable_display_model_info = FLAGS_model_info; | |||
| enable_verbose = FLAGS_verbose; | |||
| disable_assert_throw = FLAGS_disable_assert_throw; | |||
| #if __linux__ || __unix__ | |||
| enable_wait_gdb = FLAGS_wait_gdb; | |||
| #endif | |||
| #ifndef __IN_TEE_ENV__ | |||
| #if MGB_ENABLE_JSON | |||
| static_mem_log_dir_path = FLAGS_get_static_mem_info; | |||
| #endif | |||
| #endif | |||
| } | |||
| bool DebugOption::is_valid() { | |||
| bool ret = FLAGS_model_info; | |||
| ret = ret || FLAGS_verbose; | |||
| ret = ret || FLAGS_disable_assert_throw; | |||
| #if __linux__ || __unix__ | |||
| ret = ret || FLAGS_wait_gdb; | |||
| #endif | |||
| #ifndef __IN_TEE_ENV__ | |||
| #if MGB_ENABLE_JSON | |||
| ret = ret || !FLAGS_get_static_mem_info.empty(); | |||
| #endif | |||
| #endif | |||
| return ret; | |||
| } | |||
| std::shared_ptr<OptionBase> DebugOption::create_option() { | |||
| static std::shared_ptr<DebugOption> option(new DebugOption); | |||
| if (DebugOption::is_valid()) { | |||
| return std::static_pointer_cast<OptionBase>(option); | |||
| } else { | |||
| return nullptr; | |||
| } | |||
| } | |||
| void DebugOption::config_model( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | |||
| CONFIG_MODEL_FUN; | |||
| } | |||
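| //! Illustrative invocations (paths are placeholders): | |||
| //!     load_and_run model.mge --model_info --verbose | |||
| //!     load_and_run model.mge --get_static_mem_info ./static_mem_logs | |||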
| ///////////////////// Plugin gflags/////////////////////////// | |||
| DEFINE_double( | |||
| range, 0, | |||
| "check whether absolute value of all numbers in computing graph " | |||
| "is in the given range"); | |||
| DEFINE_bool( | |||
| check_dispatch, false, | |||
| "check whether an operator call dispatch on cpu comp nodes"); | |||
| DEFINE_string( | |||
| check_var_value, "", | |||
| "--check-var-value [interval]|[interval:init_idx], Enable " | |||
| "VarValueChecker plugin. Refer to its doc for more details"); | |||
| #if MGB_ENABLE_JSON | |||
| DEFINE_string( | |||
| profile, "", | |||
| "Write profiling result to given file. The output file is in " | |||
| "JSON format"); | |||
| DEFINE_string(profile_host, "", "focus on host time profiling for some backends"); | |||
| #endif | |||
| ///////////////////// Debug gflags/////////////////////////// | |||
| DEFINE_bool( | |||
| model_info, false, | |||
| " Format and display model input/output tensor inforamtion"); | |||
| DEFINE_bool(verbose, false, "get more inforamtion for debug"); | |||
| DEFINE_bool(disable_assert_throw, false, "disable assert throw on error check"); | |||
| #if __linux__ || __unix__ | |||
| DEFINE_bool(wait_gdb, false, "print current process PID and wait for gdb attach"); | |||
| #endif | |||
| #ifndef __IN_TEE_ENV__ | |||
| #if MGB_ENABLE_JSON | |||
| DEFINE_string( | |||
| get_static_mem_info, "", | |||
| "Record the static computing graph's static memory information"); | |||
| #endif | |||
| #endif | |||
| REGIST_OPTION_CREATOR(plugin, lar::PluginOption::create_option); | |||
| REGIST_OPTION_CREATOR(debug, lar::DebugOption::create_option); | |||
| @@ -0,0 +1,105 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/options/plugin_options.h | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #pragma once | |||
| #include <gflags/gflags.h> | |||
| #if __linux__ || __unix__ | |||
| #include <unistd.h> | |||
| #endif | |||
| #include "megbrain/plugin/cpu_dispatch_checker.h" | |||
| #include "megbrain/plugin/var_value_checker.h" | |||
| #include "helpers/common.h" | |||
| #include "helpers/text_table.h" | |||
| #include "models/model.h" | |||
| #include "option_base.h" | |||
| DECLARE_bool(check_dispatch); | |||
| DECLARE_double(range); | |||
| DECLARE_string(check_var_value); | |||
| #if MGB_ENABLE_JSON | |||
| DECLARE_string(profile); | |||
| DECLARE_string(profile_host); | |||
| #endif | |||
| DECLARE_bool(model_info); | |||
| DECLARE_bool(verbose); | |||
| DECLARE_bool(disable_assert_throw); | |||
| #if __linux__ || __unix__ | |||
| DECLARE_bool(wait_gdb); | |||
| #endif | |||
| #ifndef __IN_TEE_ENV__ | |||
| #if MGB_ENABLE_JSON | |||
| DECLARE_string(get_static_mem_info); | |||
| #endif | |||
| #endif | |||
| namespace lar { | |||
| class PluginOption final : public OptionBase { | |||
| public: | |||
| static bool is_valid(); | |||
| static std::shared_ptr<OptionBase> create_option(); | |||
| void config_model( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override; | |||
| std::string option_name() const override { return m_option_name; }; | |||
| private: | |||
| PluginOption(); | |||
| template <typename ModelImpl> | |||
| void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | |||
| double range; | |||
| bool enable_check_dispatch; | |||
| #if MGB_ENABLE_JSON | |||
| bool enable_profile_host; | |||
| std::string profile_path; | |||
| #endif | |||
| std::string var_value_check_str; | |||
| std::string m_option_name; | |||
| std::unique_ptr<mgb::VarValueChecker> var_value_checker; | |||
| std::unique_ptr<mgb::CPUDispatchChecker> cpu_dispatch_checker; | |||
| }; | |||
| class DebugOption final : public OptionBase { | |||
| public: | |||
| static bool is_valid(); | |||
| static std::shared_ptr<OptionBase> create_option(); | |||
| void config_model( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override; | |||
| std::string option_name() const override { return m_option_name; }; | |||
| private: | |||
| DebugOption(); | |||
| template <typename ModelImpl> | |||
| void format_and_print(const std::string&, std::shared_ptr<ModelImpl>){}; | |||
| template <typename ModelImpl> | |||
| void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | |||
| bool enable_display_model_info; | |||
| bool enable_verbose; | |||
| bool disable_assert_throw; | |||
| #if __linux__ || __unix__ | |||
| bool enable_wait_gdb; | |||
| #endif | |||
| #ifndef __IN_TEE_ENV__ | |||
| #if MGB_ENABLE_JSON | |||
| std::string static_mem_log_dir_path; | |||
| #endif | |||
| #endif | |||
| std::string m_option_name; | |||
| }; | |||
| } // namespace lar | |||
| @@ -0,0 +1,96 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/options/strategy_options.cpp | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #include "strategy_options.h" | |||
| #include "models/model_mdl.h" | |||
| using namespace lar; | |||
| DECLARE_bool(c_opr_lib_with_param); | |||
| StrategyOption::StrategyOption() { | |||
| m_option_name = "run_strategy"; | |||
| warmup_iter = FLAGS_warmup_iter; | |||
| run_iter = FLAGS_iter; | |||
| threads = FLAGS_thread; | |||
| } | |||
| std::shared_ptr<OptionBase> StrategyOption::create_option() { | |||
| static std::shared_ptr<StrategyOption> option(new StrategyOption); | |||
| return std::static_pointer_cast<OptionBase>(option); | |||
| } | |||
| void StrategyOption::config_model( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | |||
| if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
| model->set_shared_mem(FLAGS_share_param_mem); | |||
| runtime_param.warmup_iter = warmup_iter; | |||
| runtime_param.run_iter = run_iter; | |||
| runtime_param.threads = threads; | |||
| runtime_param.testcase_num = 1; | |||
| } else if (runtime_param.stage == RunStage::BEFORE_OUTSPEC_SET) { | |||
| if (model->type() == ModelType::MEGDL_MODEL) { | |||
| auto model_ptr = std::static_pointer_cast<ModelMdl>(model); | |||
| auto num = model_ptr->get_testcase_num(); | |||
| if (num != 0) | |||
| runtime_param.testcase_num = num; | |||
| model_ptr->make_output_spec(); | |||
| } | |||
| } | |||
| } | |||
| TestcaseOption::TestcaseOption() { | |||
| m_option_name = "run_testcase"; | |||
| } | |||
| std::shared_ptr<OptionBase> TestcaseOption::create_option() { | |||
| static std::shared_ptr<TestcaseOption> option(new TestcaseOption); | |||
| return std::static_pointer_cast<OptionBase>(option); | |||
| } | |||
| void TestcaseOption::config_model( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | |||
| if (model->type() == ModelType::MEGDL_MODEL) { | |||
| auto model_ptr = std::static_pointer_cast<ModelMdl>(model); | |||
| if (model_ptr->get_testcase_num() && !FLAGS_c_opr_lib_with_param) { | |||
| if (runtime_param.stage == RunStage::MODEL_RUNNING) { | |||
| auto load_result = model_ptr->get_mdl_load_result(); | |||
| auto input_tensor = model_ptr->get_test_input(); | |||
| auto loader = model_ptr->reset_loader(); | |||
| auto testcase = loader->load(model_ptr->get_mdl_config(), false); | |||
| mgb_assert(testcase.output_var_list.size() == input_tensor.size()); | |||
| for (size_t i = 0; i < input_tensor.size(); ++i) { | |||
| auto&& opr = | |||
| testcase.output_var_list[i] | |||
| .node() | |||
| ->owner_opr() | |||
| ->cast_final_safe<mgb::opr::SharedDeviceTensor>(); | |||
| input_tensor[i].second->copy_from( | |||
| mgb::HostTensorND::make_proxy(*opr.dev_data())); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| DEFINE_int32(iter, 10, "number of iterations for running the model"); | |||
| DEFINE_int32(warmup_iter, 1, "number of warm-up iterations before the timed run"); | |||
| DEFINE_int32( | |||
| thread, 1, | |||
| "number of threads used to run the model when <thread> is supported (NOTE: " | |||
| "this is not a device setting, just for load_and_run)"); | |||
| DEFINE_bool(share_param_mem, false, "load model from shared memory"); | |||
| REGIST_OPTION_CREATOR(run_strategy, lar::StrategyOption::create_option); | |||
| REGIST_OPTION_CREATOR(run_testcase, lar::TestcaseOption::create_option); | |||
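| //! Illustrative invocation (numbers are placeholders): | |||
| //!     load_and_run model.mge --warmup_iter 5 --iter 50 --thread 4 | |||
| //! i.e. 5 warm-up iterations, then 50 timed iterations on each of 4 threads. | |||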
| @@ -0,0 +1,68 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/options/strategy_options.h | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #include <gflags/gflags.h> | |||
| #include "models/model.h" | |||
| #include "option_base.h" | |||
| DECLARE_int32(iter); | |||
| DECLARE_int32(warmup_iter); | |||
| DECLARE_int32(thread); | |||
| DECLARE_bool(share_param_mem); | |||
| namespace lar { | |||
| /*! | |||
| * \brief: strategy option for running model | |||
| */ | |||
| class StrategyOption final : public OptionBase { | |||
| public: | |||
| //! create the option when it is used | |||
| static std::shared_ptr<OptionBase> create_option(); | |||
| //! config the model, dispatching configuration to different model implementations | |||
| void config_model( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override; | |||
| //! get option name | |||
| std::string option_name() const override { return m_option_name; }; | |||
| private: | |||
| //! Constructor | |||
| StrategyOption(); | |||
| //! configuration for different model implementations | |||
| std::string m_option_name; | |||
| size_t warmup_iter; //! number of warm-up iterations before running the model | |||
| size_t run_iter; //! number of iterations for running the model | |||
| size_t threads; //! number of threads for running the model (NOTE: it's | |||
| //! different from a multithread device) | |||
| }; | |||
| class TestcaseOption final : public OptionBase { | |||
| public: | |||
| //! create the option when it is used | |||
| static std::shared_ptr<OptionBase> create_option(); | |||
| //! config the model, dispatching configuration to different model implementations | |||
| void config_model( | |||
| RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override; | |||
| //! get option name | |||
| std::string option_name() const override { return m_option_name; }; | |||
| private: | |||
| //! Constructor | |||
| TestcaseOption(); | |||
| //! configuration for different model implementations | |||
| std::string m_option_name; | |||
| }; | |||
| } // namespace lar | |||
| @@ -0,0 +1,24 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/strategys/strategy.cpp | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #include "strategy.h" | |||
| #include <iostream> | |||
| using namespace lar; | |||
| std::shared_ptr<StrategyBase> StrategyBase::create_strategy(std::string model_path) { | |||
| if (FLAGS_fitting) { | |||
| return std::make_shared<FittingStrategy>(model_path); | |||
| } else { | |||
| return std::make_shared<NormalStrategy>(model_path); | |||
| } | |||
| } | |||
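| //! A minimal sketch of the intended call site (the surrounding main() is | |||
| //! assumed, not part of this file): | |||
| //!     auto strategy = lar::StrategyBase::create_strategy(model_path); | |||
| //!     strategy->run(); | |||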
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -0,0 +1,63 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/strategys/strategy.h | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #pragma once | |||
| #include <gflags/gflags.h> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include "helpers/common.h" | |||
| #include "models/model.h" | |||
| #include "options/option_base.h" | |||
| DECLARE_bool(fitting); | |||
| namespace lar { | |||
| /*! | |||
| * \brief: load and run strategy base class | |||
| */ | |||
| class StrategyBase { | |||
| public: | |||
| static std::shared_ptr<StrategyBase> create_strategy(std::string model_path); | |||
| virtual void run() = 0; | |||
| virtual ~StrategyBase() = default; | |||
| RuntimeParam m_runtime_param; | |||
| std::unordered_map<std::string, std::shared_ptr<OptionBase>> m_options; | |||
| }; | |||
| /*! | |||
| * \brief: normal strategy for running | |||
| */ | |||
| class NormalStrategy : public StrategyBase { | |||
| public: | |||
| NormalStrategy(std::string model_path); | |||
| //! run model with runtime parameter | |||
| void run() override; | |||
| private: | |||
| //! run the model subline; invoked once per thread | |||
| void run_subline(); | |||
| std::string m_model_path; | |||
| }; | |||
| /*! | |||
| * \brief: Fitting strategy for running | |||
| */ | |||
| class FittingStrategy : public StrategyBase { | |||
| public: | |||
| FittingStrategy(std::string model_path); | |||
| void run() override; | |||
| }; | |||
| } // namespace lar | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -0,0 +1,24 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/strategys/strategy_fitting.cpp | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #include "strategy.h" | |||
| using namespace lar; | |||
| FittingStrategy::FittingStrategy(std::string) { | |||
| mgb_assert(false, "this version doesn't support the fitting strategy"); | |||
| } | |||
| void FittingStrategy::run() { | |||
| mgb_assert(false, "this version doesn't support the fitting strategy"); | |||
| } | |||
| DEFINE_bool( | |||
| fitting, false, | |||
| "whether to use the fitting model, which will auto profile and get " | |||
| "the best option set!"); | |||
| @@ -0,0 +1,167 @@ | |||
| /** | |||
| * \file lite/load_and_run/src/strategys/strategy_normal.cpp | |||
| * | |||
| * This file is part of MegEngine, a deep learning framework developed by | |||
| * Megvii. | |||
| * | |||
| * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
| */ | |||
| #include <iostream> | |||
| #include <thread> | |||
| #include "megbrain/common.h" | |||
| #include "megbrain/utils/timer.h" | |||
| #include "megbrain/version.h" | |||
| #include "megdnn/version.h" | |||
| #include "misc.h" | |||
| #include "strategy.h" | |||
| using namespace lar; | |||
| NormalStrategy::NormalStrategy(std::string model_path) { | |||
| mgb::set_log_level(mgb::LogLevel::WARN); | |||
| lite::set_log_level(LiteLogLevel::WARN); | |||
| m_model_path = model_path; | |||
| auto option_creator_map = OptionFactory::get_Instance().get_option_creator_map(); | |||
| mgb_log_debug("option map size: %lu", option_creator_map->size()); | |||
| auto construct_option = [&](std::string name) -> void { | |||
| auto& creator = (*option_creator_map)[name]; | |||
| auto option = creator(); | |||
| if (option) { | |||
| m_options.insert({name, option}); | |||
| } | |||
| }; | |||
| for (auto& creator : *option_creator_map) { | |||
| auto name = creator.first; | |||
| if (m_options.count(name) == 0) { | |||
| construct_option(name); | |||
| } | |||
| } | |||
| } | |||
| void NormalStrategy::run_subline() { | |||
| auto model = ModelBase::create_model(m_model_path); | |||
| mgb_assert(model != nullptr, "create model failed!!"); | |||
| auto stage_config_model = [&]() { | |||
| for (auto& option : m_options) { | |||
| option.second->config_model(m_runtime_param, model); | |||
| } | |||
| }; | |||
| //! execute before load config | |||
| m_runtime_param.stage = RunStage::BEFORE_MODEL_LOAD; | |||
| stage_config_model(); | |||
| mgb::RealTimer timer; | |||
| model->load_model(); | |||
| printf("load model: %.3fms\n", timer.get_msecs_reset()); | |||
| //! after load configure | |||
| m_runtime_param.stage = RunStage::AFTER_MODEL_LOAD; | |||
| stage_config_model(); | |||
| m_runtime_param.stage = RunStage::BEFORE_OUTSPEC_SET; | |||
| stage_config_model(); | |||
| //! for the static memory information option | |||
| m_runtime_param.stage = RunStage::AFTER_OUTSPEC_SET; | |||
| stage_config_model(); | |||
| auto warm_up = [&]() { | |||
| auto warmup_num = m_runtime_param.warmup_iter; | |||
| for (size_t i = 0; i < warmup_num; i++) { | |||
| printf("=== prepare: %.3fms; going to warmup\n\n", timer.get_msecs_reset()); | |||
| model->run_model(); | |||
| model->wait(); | |||
| printf("warm up %lu %.3fms\n", i, timer.get_msecs_reset()); | |||
| m_runtime_param.stage = RunStage::AFTER_RUNNING_WAIT; | |||
| stage_config_model(); | |||
| } | |||
| }; | |||
| auto run_iter = [&](int idx) { | |||
| double time_sqrsum = 0, time_sum = 0, | |||
| min_time = std::numeric_limits<double>::max(), max_time = 0; | |||
| auto run_num = m_runtime_param.run_iter; | |||
| for (size_t i = 0; i < run_num; i++) { | |||
| timer.reset(); | |||
| model->run_model(); | |||
| auto exec_time = timer.get_msecs(); | |||
| model->wait(); | |||
| m_runtime_param.stage = RunStage::AFTER_RUNNING_WAIT; | |||
| stage_config_model(); | |||
| auto cur = timer.get_msecs(); | |||
| printf("iter %lu/%lu: %.3fms (exec=%.3fms)\n", i, run_num, cur, exec_time); | |||
| time_sum += cur; | |||
| time_sqrsum += cur * cur; | |||
| fflush(stdout); | |||
| min_time = std::min(min_time, cur); | |||
| max_time = std::max(max_time, cur); | |||
| } | |||
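| //! the sqrt term below is the sample standard deviation of the per-iter | |||
| //! time: sqrt((n * sum(t^2) - sum(t)^2) / (n * (n - 1))) | |||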
| printf("\n=== finished test #%u: time=%.3fms avg_time=%.3fms " | |||
| "sexec=%.3fms min=%.3fms max=%.3fms\n\n", | |||
| idx, time_sum, time_sum / run_num, | |||
| std::sqrt( | |||
| (time_sqrsum * run_num - time_sum * time_sum) / | |||
| (run_num * (run_num - 1))), | |||
| min_time, max_time); | |||
| return time_sum; | |||
| }; | |||
| //! model with testcase | |||
| size_t iter_num = m_runtime_param.testcase_num; | |||
| double tot_time = 0; | |||
| for (size_t idx = 0; idx < iter_num; idx++) { | |||
| //! config when running model | |||
| mgb_log_warn("run testcase: %zu ", idx); | |||
| m_runtime_param.stage = RunStage::MODEL_RUNNING; | |||
| stage_config_model(); | |||
| if (!idx) { | |||
| warm_up(); | |||
| } | |||
| tot_time += run_iter(idx); | |||
| m_runtime_param.stage = RunStage::AFTER_RUNNING_ITER; | |||
| stage_config_model(); | |||
| } | |||
| printf("=== total time: %.3fms\n", tot_time); | |||
| //! execute after run | |||
| m_runtime_param.stage = RunStage::AFTER_MODEL_RUNNING; | |||
| stage_config_model(); | |||
| }; | |||
| void NormalStrategy::run() { | |||
| auto v0 = mgb::get_version(); | |||
| auto v1 = megdnn::get_version(); | |||
| printf("megbrain/lite/load_and_run:\nusing MegBrain " | |||
| "%d.%d.%d(%d) and MegDNN %d.%d.%d\n", | |||
| v0.major, v0.minor, v0.patch, v0.is_dev, v1.major, v1.minor, v1.patch); | |||
| size_t thread_num = m_runtime_param.threads; | |||
| auto run_sub = [&]() { run_subline(); }; | |||
| if (thread_num == 1) { | |||
| run_sub(); | |||
| } else if (thread_num > 1) { | |||
| #if MGB_HAVE_THREAD | |||
| std::vector<std::thread> threads; | |||
| for (size_t i = 0; i < thread_num; ++i) { | |||
| threads.emplace_back(run_sub); | |||
| } | |||
| for (auto&& i : threads) { | |||
| i.join(); | |||
| } | |||
| #else | |||
| mgb_log_error( | |||
| "%zu threads requested, but load_and_run was compiled " | |||
| "without <thread> support.", | |||
| thread_num); | |||
| #endif | |||
| } else { | |||
| mgb_assert(false, "--thread must input a positive number!!"); | |||
| } | |||
| //! execute before run | |||
| } | |||