GitOrigin-RevId: 5dce356452
tags/v1.7.0
| @@ -13,10 +13,17 @@ import itertools | |||||
| import json | import json | ||||
| import os | import os | ||||
| import pickle | import pickle | ||||
| import re | |||||
| import struct | |||||
| from typing import Any | from typing import Any | ||||
| import cv2 | |||||
| import numpy as np | import numpy as np | ||||
| from megengine.logger import get_logger | |||||
| from .. import tensor | |||||
| from ..core import _imperative_rt as rt | |||||
| from ..core._imperative_rt import GraphProfiler, GraphProfiler2, SerializationMetadata | from ..core._imperative_rt import GraphProfiler, GraphProfiler2, SerializationMetadata | ||||
| from ..core._imperative_rt.core2 import Tensor as RawTensor | from ..core._imperative_rt.core2 import Tensor as RawTensor | ||||
| from ..core._imperative_rt.core2 import ( | from ..core._imperative_rt.core2 import ( | ||||
| @@ -38,12 +45,15 @@ from ..core._wrap import as_device | |||||
| from ..core.ops.builtin import BatchNorm, OpDef | from ..core.ops.builtin import BatchNorm, OpDef | ||||
| from ..core.tensor import megbrain_graph as G | from ..core.tensor import megbrain_graph as G | ||||
| from ..core.tensor.utils import setscalar | from ..core.tensor.utils import setscalar | ||||
| from ..utils import comp_graph_tools as cgtools | |||||
| from ..utils.naming import AutoNaming | from ..utils.naming import AutoNaming | ||||
| from ..utils.profiler import is_profiling | from ..utils.profiler import is_profiling | ||||
| from .dtr_config import DTRConfig | from .dtr_config import DTRConfig | ||||
| from .graph_opt_config import GraphOptimizationConfig | from .graph_opt_config import GraphOptimizationConfig | ||||
| from .sublinear_memory_config import SublinearMemoryConfig | from .sublinear_memory_config import SublinearMemoryConfig | ||||
| logger = get_logger(__name__) | |||||
| def _input_node_use_static_shape(): | def _input_node_use_static_shape(): | ||||
| return os.environ.get("MEGENGINE_INPUT_NODE_USE_STATIC_SHAPE") is not None | return os.environ.get("MEGENGINE_INPUT_NODE_USE_STATIC_SHAPE") is not None | ||||
| @@ -692,6 +702,289 @@ class trace: | |||||
| self._process_outputs(outputs) | self._process_outputs(outputs) | ||||
| return outputs | return outputs | ||||
| def _make_feed( | |||||
| self, | |||||
| graph, | |||||
| outputs, | |||||
| input_data, | |||||
| repeat, | |||||
| silent, | |||||
| no_assert, | |||||
| maxerr, | |||||
| resize_input, | |||||
| input_transform, | |||||
| ): | |||||
| def auto_reformat_image(path, data, dst_shape): | |||||
| """reformat image to target shape | |||||
| :param data: image data as numpy array | |||||
| :param dst_shape: target shape | |||||
| """ | |||||
| dim3_format = False # required input format does not contain batch | |||||
| hwc_format = False # required input format is NHWC | |||||
| if not dst_shape: # input tensor shape is not predefined | |||||
| if len(data.shape) == 2: | |||||
| chl = 1 | |||||
| h = data.shape[0] | |||||
| w = data.shape[1] | |||||
| else: | |||||
| assert ( | |||||
| len(data.shape) == 3 | |||||
| ), "Input image must be of dimension 2 or 3" | |||||
| h, w, chl = data.shape | |||||
| dst_shape = (1, chl, h, w) | |||||
| if len(dst_shape) == 3: | |||||
| dst_shape = (1,) + dst_shape | |||||
| dim3_format = True | |||||
| assert len(dst_shape) == 4, "bad dst_shape: {}".format(dst_shape) | |||||
| chl = dst_shape[1] | |||||
| if chl in [1, 3]: | |||||
| n, c, h, w = dst_shape | |||||
| dst_shape = (n, h, w, c) | |||||
| else: | |||||
| chl = dst_shape[3] | |||||
| assert chl in [ | |||||
| 1, | |||||
| 3, | |||||
| ], "can not infer input format from shape: {}".format(dst_shape) | |||||
| hwc_format = True | |||||
| # dst_shape has now been normalized to NHWC format | |||||
| if resize_input: | |||||
| h, w = dst_shape[1:3] | |||||
| data = cv2.resize(data, (w, h)) | |||||
| logger.info("input {} resized to {}".format(path, data.shape)) | |||||
| if chl == 1: | |||||
| data = cv2.cvtColor(data, cv2.COLOR_BGR2GRAY) | |||||
| data = data[:, :, np.newaxis] | |||||
| assert data.ndim == 3 | |||||
| data = data[np.newaxis] | |||||
| # data normalized to NHWC format | |||||
| if not hwc_format: | |||||
| data = np.transpose(data, (0, 3, 1, 2)) | |||||
| if dim3_format: | |||||
| data = np.squeeze(data, 0) | |||||
| return data | |||||
| def read_input_data(dst_shape, dtype, path): | |||||
| def check_shape_equal(dst_shape, data_shape): | |||||
| if len(dst_shape): | |||||
| assert len(data_shape) == len( | |||||
| dst_shape | |||||
| ), "input/data shapes mismatch: {} vs {}".format( | |||||
| dst_shape, data_shape | |||||
| ) | |||||
| if data_shape[1:] != dst_shape[1:]: | |||||
| logger.warning( | |||||
| "dst_shape is {}; data_shape is {}".format( | |||||
| dst_shape, data_shape | |||||
| ) | |||||
| ) | |||||
| if path.startswith("#"): | |||||
| assert not resize_input | |||||
| assert not input_transform | |||||
| spec = path | |||||
| m = re.match( | |||||
| r"^#rand\(([-0-9.]*)\s*,\s*([-0-9.]*)\s*(,[^\)]+)?\)$", spec | |||||
| ) | |||||
| assert m, "bad spec {}".format(spec) | |||||
| rng_min = float(m.group(1)) | |||||
| rng_max = float(m.group(2)) | |||||
| if m.group(3): | |||||
| shape_str = m.group(3) | |||||
| try: | |||||
| shape = shape_str[1:].split(",") | |||||
| if shape[-1].strip() == "...": | |||||
| shape = shape[:-1] | |||||
| shape.extend(list(dst_shape[len(shape) :])) | |||||
| data_shape = tuple(map(int, shape)) | |||||
| except ValueError as e: | |||||
| raise ValueError("bad spec {}: {}".format(spec, e.args)) | |||||
| else: | |||||
| data_shape = dst_shape | |||||
| check_shape_equal(dst_shape, data_shape) | |||||
| return np.random.uniform(rng_min, rng_max, data_shape).astype(dtype) | |||||
| # try to load image | |||||
| data = cv2.imread(path, cv2.IMREAD_COLOR) | |||||
| if data is None: | |||||
| assert not resize_input | |||||
| data = np.load(path) | |||||
| assert isinstance(data, np.ndarray) | |||||
| else: | |||||
| # load image succeeds, so we expect input format is image format | |||||
| data = auto_reformat_image(path, data, dst_shape) | |||||
| data = np.repeat(data, repeat, axis=0) | |||||
| if repeat > 1: | |||||
| logger.info( | |||||
| "repeat input for {} times, data shape is {}".format( | |||||
| repeat, data.shape | |||||
| ) | |||||
| ) | |||||
| check_shape_equal(dst_shape, data.shape) | |||||
| if input_transform: | |||||
| data = eval(input_transform, {"data": data, "np": np}) | |||||
| return data | |||||
| def gen_one_testcase(inputs, spec): | |||||
| paths = spec.split(";") | |||||
| if len(paths) != len(inputs): | |||||
| if len(paths) == 1 and paths[0].startswith("#"): | |||||
| paths = ["{}:{}".format(name, paths[0]) for name in inputs.keys()] | |||||
| assert len(paths) == len( | |||||
| inputs | |||||
| ), "required inputs: {}; data paths: {}".format(inputs.keys(), paths) | |||||
| if len(paths) == 1 and ":" not in paths[0]: | |||||
| paths[0] = next(iter(inputs.keys())) + ":" + paths[0] | |||||
| ret = {} | |||||
| for path in paths: | |||||
| var, path = path.split(":") | |||||
| ret[var] = read_input_data(inputs[var].shape, inputs[var].dtype, path) | |||||
| return ret | |||||
| inputs = cgtools.get_dep_vars(outputs, "Host2DeviceCopy") | |||||
| inputs = {i.name: i for i in inputs} | |||||
| if not no_assert: | |||||
| replace_varmap = {} | |||||
| inp_map = {} | |||||
| # replace var use InputNode | |||||
| for name, var in inputs.items(): | |||||
| inp = G.InputNode( | |||||
| device="xpux", dtype=var.dtype, shape=var.shape, graph=graph | |||||
| ) | |||||
| replace_varmap[var] = inp.outputs[0]._node | |||||
| inp_map[name] = inp | |||||
| new = cgtools.replace_vars(outputs, replace_varmap) | |||||
| if isinstance(new, rt.VarNode): | |||||
| new = list(new) | |||||
| output_nodes = [G.OutputNode(var) for var in new] | |||||
| func = graph.compile(*[node.outputs[0]._node for node in output_nodes]) | |||||
| def make_dev_tensor(value, dtype=None, device=None): | |||||
| return tensor(value, dtype=dtype, device=device)._dev_tensor() | |||||
| def calculate(*args, **kwargs): | |||||
| output_val = [] | |||||
| # set inputs value | |||||
| for name, var in inputs.items(): | |||||
| val = kwargs.pop(name, None) | |||||
| assert val is not None, "miss input name{}".format(name) | |||||
| dev_tensor = make_dev_tensor(val, dtype=var.dtype, device="xpux") | |||||
| inp_map[name].set_value(dev_tensor) | |||||
| func.execute() | |||||
| for res in output_nodes: | |||||
| output_val.append(res.get_value().numpy()) | |||||
| return output_val | |||||
| def expect_name(var): | |||||
| return "{}:expect".format(var.name) | |||||
| testcases = [] | |||||
| np.set_printoptions(precision=2, threshold=4, suppress=True) | |||||
| data_list = [] | |||||
| for item in input_data: | |||||
| if item.startswith("@"): | |||||
| with open(item[1:], "r") as f: | |||||
| data_list.extend( | |||||
| [line.rstrip() for line in f if line.rstrip() != ""] | |||||
| ) | |||||
| else: | |||||
| data_list.append(item) | |||||
| for inp_spec in data_list: | |||||
| cur_testcase = gen_one_testcase(inputs, inp_spec) | |||||
| assert len(cur_testcase) == len( | |||||
| inputs | |||||
| ), "required inputs: {}; given data: {}".format( | |||||
| inputs.keys(), cur_testcase.keys() | |||||
| ) | |||||
| if not no_assert: | |||||
| outputs_get = calculate(**cur_testcase) | |||||
| for var, val in zip(outputs, outputs_get): | |||||
| cur_testcase[expect_name(var)] = val | |||||
| logger.info( | |||||
| "generate test groundtruth: var={} shape={} range=({}, {})" | |||||
| " mean={} var={}".format( | |||||
| var, | |||||
| val.shape, | |||||
| val.min(), | |||||
| val.max(), | |||||
| np.mean(val), | |||||
| np.var(val), | |||||
| ) | |||||
| ) | |||||
| testcases.append(cur_testcase) | |||||
| logger.info( | |||||
| "add testcase: \n {}".format( | |||||
| "\n ".join( | |||||
| "{}: shape={} dtype={} range=({:.2f},{:.2f}) " | |||||
| "mean={:.2f} sd={:.2f}".format( | |||||
| k, v.shape, v.dtype, v.min(), v.max(), np.mean(v), np.std(v) | |||||
| ) | |||||
| for k, v in sorted(cur_testcase.items()) | |||||
| ) | |||||
| ) | |||||
| ) | |||||
| if not no_assert: | |||||
| def expect_shp(var): | |||||
| ret = var.shape | |||||
| if ret: | |||||
| return ret | |||||
| return testcases[0][expect_name(var)].shape | |||||
| def assert_equal(expect, real, **kwargs): | |||||
| op = AssertEqual(**kwargs) | |||||
| (res,) = G.apply_normal_varnode(op, expect, real) | |||||
| return res._node | |||||
| verbose = not silent | |||||
| outputs_new = [] | |||||
| for i in outputs: | |||||
| device = rt.CompNode("xpux") | |||||
| dtype = i.dtype | |||||
| name = expect_name(i) | |||||
| shape = expect_shp(i) | |||||
| # make expect output as one input of model. | |||||
| expect_get = rt.make_h2d(graph, device, dtype, shape, name) | |||||
| # insert assert opr to check expect and real. | |||||
| outputs_new.append( | |||||
| assert_equal(expect_get, i, verbose=verbose, maxerr=maxerr,) | |||||
| ) | |||||
| inputs[expect_name(i)] = expect_get | |||||
| outputs = outputs_new | |||||
| return {"outputs": outputs, "testcases": testcases} | |||||
| def dump( | def dump( | ||||
| self, | self, | ||||
| file, | file, | ||||
| @@ -708,6 +1001,13 @@ class trace: | |||||
| optimize_for_inference=True, | optimize_for_inference=True, | ||||
| user_info: Any = None, | user_info: Any = None, | ||||
| enable_metadata: bool = True, | enable_metadata: bool = True, | ||||
| input_data=None, | |||||
| repeat=1, | |||||
| silent=False, | |||||
| no_assert=False, | |||||
| maxerr=1e-4, | |||||
| resize_input=False, | |||||
| input_transform=None, | |||||
| **kwargs | **kwargs | ||||
| ): | ): | ||||
| r"""Serializes trace to file system. | r"""Serializes trace to file system. | ||||
| @@ -738,6 +1038,27 @@ class trace: | |||||
| will skip all optimize options if this is False. Default: True | will skip all optimize options if this is False. Default: True | ||||
| user_info: any type object, which will be pickled to bytes. | user_info: any type object, which will be pickled to bytes. | ||||
| enable_metadata: whether to save metadata into output file. | enable_metadata: whether to save metadata into output file. | ||||
| input_data: input test data and current network output would be used as groundtruth. | |||||
| The format is "var0:file0;var1:file1..." to specify data files for input vars. | |||||
| It can also be "#rand(min,max,shape...)" for generating random input data, for | |||||
| example, "#rand(0,255)", "#rand(0,255,1,3,224,224)" or "#rand(0, 255, 1, ...)" | |||||
| where `...` means the remaining part of the original shape. If the shape is not | |||||
| specified, the shape of corresponding input tensors in the network will be used. | |||||
| If there is only one input var, its name can be omitted. Each data file can either | |||||
| be an image which can be loaded by opencv, or a pickled numpy.ndarray. This option | |||||
| can be given multiple times to add multiple testcases. If you start the data | |||||
| with the letter @, the rest should be a filename, and each line in the file should | |||||
| be a single datum in the format described above. *NOTE* If `input_data` is not None, | |||||
| you can only use load-and-run to run the output file. | |||||
| repeat: how many times the input image is repeated. Useful when running benchmark for | |||||
| batch size other than one. Have no effect on randomly generated input data. | |||||
| silent: whether set verbose to False in assert_equal opr. | |||||
| no_assert: whether insert assert_equal opr to check result; this option is useful for | |||||
| benchmarking. | |||||
| maxerr: max error for assert_equal check during runtime. | |||||
| resize_input: whether resize input image to fit input var shape. | |||||
| input_transform: a python expression to transform the input data. | |||||
| Example: data / np.std(data) | |||||
| Keyword Arguments: | Keyword Arguments: | ||||
| @@ -778,6 +1099,8 @@ class trace: | |||||
| input for inference on nvidia backend(this optimization pass will | input for inference on nvidia backend(this optimization pass will | ||||
| result in mismatch of the precision of output of training and | result in mismatch of the precision of output of training and | ||||
| inference) | inference) | ||||
| * enable_fuse_preprocess: whether to fuse astype\pad_channel\dimshuffle and | |||||
| etc opr | |||||
| """ | """ | ||||
| if not self._capture_as_const: | if not self._capture_as_const: | ||||
| raise ValueError( | raise ValueError( | ||||
| @@ -892,8 +1215,28 @@ class trace: | |||||
| v.name = output_names[i] | v.name = output_names[i] | ||||
| dest_vars.append(v) | dest_vars.append(v) | ||||
| dest_vars = [i._node for i in dest_vars] | |||||
| if input_data is not None: | |||||
| feeds = self._make_feed( | |||||
| graph, | |||||
| dest_vars, | |||||
| input_data, | |||||
| repeat, | |||||
| silent, | |||||
| no_assert, | |||||
| maxerr, | |||||
| resize_input, | |||||
| input_transform, | |||||
| ) | |||||
| assert ( | |||||
| isinstance(feeds, dict) and feeds["testcases"] | |||||
| ), "testcases can not be empty" | |||||
| dest_vars = feeds["outputs"] | |||||
| if optimize_for_inference: | if optimize_for_inference: | ||||
| dest_vars, optimize_options = G.optimize_for_inference(dest_vars, **kwargs) | dest_vars, optimize_options = G.optimize_for_inference(dest_vars, **kwargs) | ||||
| dest_vars = [i._node for i in dest_vars] | |||||
| metadata = SerializationMetadata() | metadata = SerializationMetadata() | ||||
| if enable_metadata: | if enable_metadata: | ||||
| @@ -910,6 +1253,9 @@ class trace: | |||||
| if keep_opr_priority: | if keep_opr_priority: | ||||
| graph._set_priority_to_id(dest_vars) | graph._set_priority_to_id(dest_vars) | ||||
| if input_data is not None: | |||||
| file.write(b"mgbtest0") | |||||
| file.write(struct.pack("I", len(feeds["testcases"]))) | |||||
| dump_content, dump_info = G.dump_graph( | dump_content, dump_info = G.dump_graph( | ||||
| dest_vars, | dest_vars, | ||||
| keep_var_name=keep_var_name, | keep_var_name=keep_var_name, | ||||
| @@ -921,6 +1267,34 @@ class trace: | |||||
| metadata=metadata, | metadata=metadata, | ||||
| ) | ) | ||||
| file.write(dump_content) | file.write(dump_content) | ||||
| if input_data is not None: | |||||
| inputs = cgtools.get_dep_vars(dest_vars, "Host2DeviceCopy") | |||||
| inputs = sorted((i.name, i.dtype) for i in inputs) | |||||
| def make_dev_tensor(value, dtype=None, device=None): | |||||
| return tensor(value, dtype=dtype, device=device)._dev_tensor() | |||||
| for testcase in feeds["testcases"]: | |||||
| assert isinstance(testcase, dict) | |||||
| cg = G.Graph() | |||||
| output_mgbvars = [] | |||||
| for name, dtype in inputs: | |||||
| output_mgbvars.append( | |||||
| cg.make_const( | |||||
| make_dev_tensor( | |||||
| testcase.pop(name), dtype=dtype, device="cpux" | |||||
| ) | |||||
| ) | |||||
| ) | |||||
| assert not testcase, "extra inputs provided in testcase: {}".format( | |||||
| testcase.keys() | |||||
| ) | |||||
| dump_content, _ = G.dump_graph( | |||||
| output_mgbvars, strip_info_file=strip_info_file, append_json=True, | |||||
| ) | |||||
| file.write(dump_content) | |||||
| return dump_info | return dump_info | ||||
| def _process_inputs(self, *args, **kwargs): | def _process_inputs(self, *args, **kwargs): | ||||
| @@ -287,6 +287,16 @@ def test_dump_backward_graph(): | |||||
| np.testing.assert_equal(results[1], dx0) | np.testing.assert_equal(results[1], dx0) | ||||
| def test_dump_with_testcase(): | |||||
| @trace(symbolic=True, capture_as_const=True) | |||||
| def f(x): | |||||
| return exp(x) | |||||
| f(tensor(1.0)) | |||||
| file = io.BytesIO() | |||||
| f.dump(file, input_data=["#rand(0, 255, 1)"]) | |||||
| @pytest.mark.parametrize("trace_mode", [False, True]) | @pytest.mark.parametrize("trace_mode", [False, True]) | ||||
| def test_trace_profiler(trace_mode): | def test_trace_profiler(trace_mode): | ||||
| @trace(symbolic=trace_mode, profiling=True) | @trace(symbolic=trace_mode, profiling=True) | ||||
| @@ -1,535 +0,0 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import argparse | |||||
| import os | |||||
| import re | |||||
| import struct | |||||
| import cv2 | |||||
| import numpy as np | |||||
| import megengine as mge | |||||
| import megengine.core._imperative_rt as rt | |||||
| import megengine.core.tensor.megbrain_graph as G | |||||
| from megengine import tensor | |||||
| from megengine.core._imperative_rt.core2 import apply | |||||
| from megengine.core.ops import builtin | |||||
| from megengine.utils import comp_graph_tools as cgtools | |||||
| logger = mge.get_logger(__name__) | |||||
| def auto_reformat_image(args, path, data, dst_shape): | |||||
| """reformat image to target shape | |||||
| :param data: image data as numpy array | |||||
| :param dst_shape: target shape | |||||
| """ | |||||
| dim3_format = False # required input format does not contain batch | |||||
| hwc_format = False # required input format is NHWC | |||||
| if not dst_shape: # input tensor shape is not predefined | |||||
| if len(data.shape) == 2: | |||||
| chl = 1 | |||||
| h = data.shape[0] | |||||
| w = data.shape[1] | |||||
| else: | |||||
| assert len(data.shape) == 3, "Input image must be of dimension 2 or 3" | |||||
| h, w, chl = data.shape | |||||
| dst_shape = (1, chl, h, w) | |||||
| if len(dst_shape) == 3: | |||||
| dst_shape = (1,) + dst_shape | |||||
| dim3_format = True | |||||
| assert len(dst_shape) == 4, "bad dst_shape: {}".format(dst_shape) | |||||
| chl = dst_shape[1] | |||||
| if chl in [1, 3]: | |||||
| n, c, h, w = dst_shape | |||||
| dst_shape = (n, h, w, c) | |||||
| else: | |||||
| chl = dst_shape[3] | |||||
| assert chl in [1, 3], "can not infer input format from shape: {}".format( | |||||
| dst_shape | |||||
| ) | |||||
| hwc_format = True | |||||
| # dst_shape has now been normalized to NHWC format | |||||
| if args.resize_input: | |||||
| h, w = dst_shape[1:3] | |||||
| data = cv2.resize(data, (w, h)) | |||||
| logger.info("input {} resized to {}".format(path, data.shape)) | |||||
| if chl == 1: | |||||
| data = cv2.cvtColor(data, cv2.COLOR_BGR2GRAY) | |||||
| data = data[:, :, np.newaxis] | |||||
| assert data.ndim == 3 | |||||
| data = data[np.newaxis] | |||||
| # data normalized to NHWC format | |||||
| if not hwc_format: | |||||
| data = np.transpose(data, (0, 3, 1, 2)) | |||||
| if dim3_format: | |||||
| data = np.squeeze(data, 0) | |||||
| return data | |||||
| def read_input_data(args, dst_shape, dtype, path, repeat): | |||||
| def check_shape_equal(dst_shape, data_shape): | |||||
| if len(dst_shape): | |||||
| assert len(data_shape) == len( | |||||
| dst_shape | |||||
| ), "input/data shapes mismatch: {} vs {}".format(dst_shape, data_shape) | |||||
| if data_shape[1:] != dst_shape[1:]: | |||||
| logger.warning( | |||||
| "dst_shape is {}; data_shape is {}".format(dst_shape, data_shape) | |||||
| ) | |||||
| if path.startswith("#"): | |||||
| assert not args.resize_input | |||||
| assert not args.input_transform | |||||
| spec = path | |||||
| m = re.match(r"^#rand\(([-0-9.]*)\s*,\s*([-0-9.]*)\s*(,[^\)]+)?\)$", spec) | |||||
| assert m, "bad spec {}".format(spec) | |||||
| rng_min = float(m.group(1)) | |||||
| rng_max = float(m.group(2)) | |||||
| if m.group(3): | |||||
| shape_str = m.group(3) | |||||
| try: | |||||
| shape = shape_str[1:].split(",") | |||||
| if shape[-1].strip() == "...": | |||||
| shape = shape[:-1] | |||||
| shape.extend(list(dst_shape[len(shape) :])) | |||||
| data_shape = tuple(map(int, shape)) | |||||
| except ValueError as e: | |||||
| raise ValueError("bad spec {}: {}".format(spec, e.args)) | |||||
| else: | |||||
| data_shape = dst_shape | |||||
| check_shape_equal(dst_shape, data_shape) | |||||
| return np.random.uniform(rng_min, rng_max, data_shape).astype(dtype) | |||||
| # try to load image | |||||
| data = cv2.imread(path, cv2.IMREAD_COLOR) | |||||
| if data is None: | |||||
| assert not args.resize_input | |||||
| data = np.load(path) | |||||
| assert isinstance(data, np.ndarray) | |||||
| else: | |||||
| # load image succeeds, so we expect input format is image format | |||||
| data = auto_reformat_image(args, path, data, dst_shape) | |||||
| data = np.repeat(data, repeat, axis=0) | |||||
| if repeat > 1: | |||||
| logger.info( | |||||
| "repeat input for {} times, data shape is {}".format(repeat, data.shape) | |||||
| ) | |||||
| check_shape_equal(dst_shape, data.shape) | |||||
| if args.input_transform: | |||||
| data = eval(args.input_transform, {"data": data, "np": np}) | |||||
| return data | |||||
| def gen_one_testcase(args, inputs, spec): | |||||
| paths = spec.split(";") | |||||
| if len(paths) != len(inputs): | |||||
| if len(paths) == 1 and paths[0].startswith("#"): | |||||
| paths = ["{}:{}".format(name, paths[0]) for name in inputs.keys()] | |||||
| assert len(paths) == len(inputs), "required inputs: {}; data paths: {}".format( | |||||
| inputs.keys(), paths | |||||
| ) | |||||
| if len(paths) == 1 and ":" not in paths[0]: | |||||
| paths[0] = next(iter(inputs.keys())) + ":" + paths[0] | |||||
| ret = {} | |||||
| for path in paths: | |||||
| var, path = path.split(":") | |||||
| if args.repeat: | |||||
| repeat = args.repeat | |||||
| else: | |||||
| repeat = 1 | |||||
| ret[var] = read_input_data( | |||||
| args, inputs[var].shape, inputs[var].dtype, path, repeat | |||||
| ) | |||||
| return ret | |||||
| def make_feeds(args): | |||||
| ret = G.load_graph(args.input) | |||||
| cg_rt, outputs = ret.graph, ret.output_vars_list | |||||
| inputs = cgtools.get_dep_vars(outputs, "Host2DeviceCopy") | |||||
| inputs = {i.name: i for i in inputs} | |||||
| if not args.no_assert: | |||||
| replace_varmap = {} | |||||
| inp_map = {} | |||||
| # replace var use InputNode | |||||
| for name, var in inputs.items(): | |||||
| inp = G.InputNode( | |||||
| device="xpux", dtype=var.dtype, shape=var.shape, graph=cg_rt | |||||
| ) | |||||
| replace_varmap[var] = inp.outputs[0] | |||||
| inp_map[name] = inp | |||||
| new = cgtools.replace_vars(outputs, replace_varmap) | |||||
| if isinstance(new, rt.VarNode): | |||||
| new = list(new) | |||||
| output_nodes = [G.OutputNode(var) for var in new] | |||||
| func = cg_rt.compile([node.outputs[0] for node in output_nodes]) | |||||
| def make_dev_tensor(value, dtype=None, device=None): | |||||
| return tensor(value, dtype=dtype, device=device)._dev_tensor() | |||||
| def calculate(*args, **kwargs): | |||||
| output_val = [] | |||||
| # set inputs value | |||||
| for name, var in inputs.items(): | |||||
| val = kwargs.pop(name, None) | |||||
| assert val is not None, "miss input name{}".format(name) | |||||
| dev_tensor = make_dev_tensor(val, dtype=var.dtype, device="xpux") | |||||
| inp_map[name].set_value(dev_tensor) | |||||
| func.execute() | |||||
| for res in output_nodes: | |||||
| output_val.append(res.get_value().numpy()) | |||||
| return output_val | |||||
| def expect_name(var): | |||||
| return "{}:expect".format(var.name) | |||||
| testcases = [] | |||||
| np.set_printoptions(precision=2, threshold=4, suppress=True) | |||||
| data_list = [] | |||||
| for item in args.data: | |||||
| if item.startswith("@"): | |||||
| with open(item[1:], "r") as f: | |||||
| data_list.extend([line.rstrip() for line in f if line.rstrip() != ""]) | |||||
| else: | |||||
| data_list.append(item) | |||||
| for inp_spec in data_list: | |||||
| cur_testcase = gen_one_testcase(args, inputs, inp_spec) | |||||
| assert len(cur_testcase) == len( | |||||
| inputs | |||||
| ), "required inputs: {}; given data: {}".format( | |||||
| inputs.keys(), cur_testcase.keys() | |||||
| ) | |||||
| if not args.no_assert: | |||||
| outputs_get = calculate(**cur_testcase) | |||||
| for var, val in zip(outputs, outputs_get): | |||||
| cur_testcase[expect_name(var)] = val | |||||
| logger.info( | |||||
| "generate test groundtruth: var={} shape={} range=({}, {})" | |||||
| " mean={} var={}".format( | |||||
| var, val.shape, val.min(), val.max(), np.mean(val), np.var(val) | |||||
| ) | |||||
| ) | |||||
| testcases.append(cur_testcase) | |||||
| logger.info( | |||||
| "add testcase: \n {}".format( | |||||
| "\n ".join( | |||||
| "{}: shape={} dtype={} range=({:.2f},{:.2f}) " | |||||
| "mean={:.2f} sd={:.2f}".format( | |||||
| k, v.shape, v.dtype, v.min(), v.max(), np.mean(v), np.std(v) | |||||
| ) | |||||
| for k, v in sorted(cur_testcase.items()) | |||||
| ) | |||||
| ) | |||||
| ) | |||||
| if not args.no_assert: | |||||
| def expect_shp(var): | |||||
| ret = var.shape | |||||
| if ret: | |||||
| return ret | |||||
| return testcases[0][expect_name(var)].shape | |||||
| def assert_equal(expect, real, **kwargs): | |||||
| op = builtin.AssertEqual(**kwargs) | |||||
| (res,) = apply(op, expect, real) | |||||
| return res | |||||
| verbose = not args.silent | |||||
| outputs_new = [] | |||||
| for i in outputs: | |||||
| device = rt.CompNode("xpux") | |||||
| dtype = i.dtype | |||||
| name = expect_name(i) | |||||
| shape = expect_shp(i) | |||||
| # make expect output as one input of model. | |||||
| expect_get = rt.make_h2d(cg_rt, device, dtype, shape, name) | |||||
| # insert assert opr to check expect and real. | |||||
| outputs_new.append( | |||||
| assert_equal( | |||||
| expect_get, | |||||
| i, | |||||
| verbose=verbose, | |||||
| maxerr=args.maxerr, | |||||
| ) | |||||
| ) | |||||
| inputs[expect_name(i)] = expect_get | |||||
| outputs = outputs_new | |||||
| return {"outputs": outputs, "testcases": testcases} | |||||
| def optimize_for_inference(args, outputs): | |||||
| args_list = [ | |||||
| "enable_io16xc32", | |||||
| "enable_ioc16", | |||||
| "enable_hwcd4", | |||||
| "enable_nchw4", | |||||
| "enable_nchw88", | |||||
| "enable_nchw44", | |||||
| "enable_nchw44_dot", | |||||
| "enable_nchw32", | |||||
| "enable_chwn4", | |||||
| "enable_fuse_conv_bias_nonlinearity", | |||||
| "enable_fuse_conv_bias_with_z", | |||||
| "enable_fuse_preprocess", | |||||
| ] | |||||
| kwargs = {} | |||||
| for k in args_list: | |||||
| if getattr(args, k): | |||||
| kwargs[k] = True | |||||
| if args.optimize_for_inference: | |||||
| outputs = G.optimize_for_inference(outputs, **kwargs) | |||||
| return outputs | |||||
| def main(): | |||||
| parser = argparse.ArgumentParser( | |||||
| description="Pack computing graph, input values and expected output " | |||||
| "values into one file for checking correctness. README.md gives more " | |||||
| "details on the usage", | |||||
| formatter_class=argparse.ArgumentDefaultsHelpFormatter, | |||||
| ) | |||||
| parser.add_argument("input", help="MegEngine dumped model file") | |||||
| parser.add_argument("-o", "--output", help="output file", required=True) | |||||
| parser.add_argument( | |||||
| "-d", | |||||
| "--data", | |||||
| default=[], | |||||
| action="append", | |||||
| required=True, | |||||
| help="Given input test data when input file is a network, " | |||||
| "and current network output would be used as groundtruth. " | |||||
| "The format is var0:file0;var1:file1... to specify data files for " | |||||
| "input vars. It can also be #rand(min,max,shape...) for generating " | |||||
| "random input data, for example, #rand(0,255), " | |||||
| "#rand(0,255,1,3,224,224) or #rand(0, 255, 1, ...) where `...` means " | |||||
| "the remaining part of the original shape. " | |||||
| "If the shape is not specified, the shape of " | |||||
| "corresponding input tensors in the network will be used. " | |||||
| "If there is only one input var, its name can be omitted. " | |||||
| "Each data file can either be an image which can be loaded by opencv, " | |||||
| "or a pickled numpy.ndarray. " | |||||
| "This option can be given multiple times to add multiple testcases. " | |||||
| " *NOTE* " | |||||
| "If you start the data with the letter @, the rest should be a " | |||||
| "filename, and each line in the file should be a single datum in " | |||||
| "the format described above. ", | |||||
| ) | |||||
| parser.add_argument( | |||||
| "--repeat", | |||||
| type=int, | |||||
| default=1, | |||||
| help="Specify how many times the input image is repeated. " | |||||
| "Useful when running benchmark for batch size other than one. " | |||||
| "Have no effect on randomly generated input data.", | |||||
| ) | |||||
| parser.add_argument( | |||||
| "--silent", | |||||
| action="store_true", | |||||
| help="set verbose to False in asserti_equal opr", | |||||
| ) | |||||
| parser.add_argument( | |||||
| "--optimize-for-inference", | |||||
| action="store_true", | |||||
| help="enable optimization for inference", | |||||
| ) | |||||
| parser.add_argument( | |||||
| "--no-assert", | |||||
| action="store_true", | |||||
| help="do not insert assert_equal opr to check result; " | |||||
| "this option is useful for benchmarking", | |||||
| ) | |||||
| parser.add_argument( | |||||
| "--maxerr", | |||||
| type=float, | |||||
| default=1e-4, | |||||
| help="max error for assert_equal check during runtime", | |||||
| ) | |||||
| parser.add_argument( | |||||
| "--resize-input", | |||||
| action="store_true", | |||||
| help="resize input image to fit input var shape", | |||||
| ) | |||||
| parser.add_argument( | |||||
| "--input-transform", | |||||
| help="a python expression to transform the input data. " | |||||
| "Example: data / np.std(data)", | |||||
| ) | |||||
| parser.add_argument( | |||||
| "--discard-var-name", | |||||
| action="store_true", | |||||
| help="discard variable and param names in the " "generated output", | |||||
| ) | |||||
| parser.add_argument( | |||||
| "--output-strip-info", action="store_true", help="output code strip information" | |||||
| ) | |||||
| parser.add_argument( | |||||
| "--enable-io16xc32", | |||||
| action="store_true", | |||||
| help="transform the mode to float16 io float32 compute", | |||||
| ) | |||||
| parser.add_argument( | |||||
| "--enable-ioc16", | |||||
| action="store_true", | |||||
| help="transform the dtype of the model to float16 io " "and compute", | |||||
| ) | |||||
| parser.add_argument( | |||||
| "--enable-fuse-conv-bias-nonlinearity", | |||||
| action="store_true", | |||||
| help="fuse convolution bias and nonlinearity opr to a " | |||||
| "conv_bias opr and compute", | |||||
| ) | |||||
| parser.add_argument( | |||||
| "--enable-hwcd4", | |||||
| action="store_true", | |||||
| help="transform the model format from NCHW to NHWCD4 " | |||||
| "for inference; you may need to disable CUDA and set " | |||||
| "MGB_USE_MEGDNN_DBG=2", | |||||
| ) | |||||
| parser.add_argument( | |||||
| "--enable-nchw4", | |||||
| action="store_true", | |||||
| help="transform the model format from NCHW to NCHW4 " "for inference", | |||||
| ) | |||||
| parser.add_argument( | |||||
| "--enable-nchw88", | |||||
| action="store_true", | |||||
| help="transform the model format from NCHW to NCHW88 " "for inference", | |||||
| ) | |||||
| parser.add_argument( | |||||
| "--enable-nchw44", | |||||
| action="store_true", | |||||
| help="transform the model format from NCHW to NCHW44 " "for inference", | |||||
| ) | |||||
| parser.add_argument( | |||||
| "--enable-nchw44-dot", | |||||
| action="store_true", | |||||
| help="transform the model format from NCHW to NCHW44_DOT " | |||||
| "for optimizing armv8.2 dot in inference", | |||||
| ) | |||||
| parser.add_argument( | |||||
| "--enable-nchw32", | |||||
| action="store_true", | |||||
| help="transform the model format from NCHW4 to NCHW32 " | |||||
| "for inference on nvidia TensoCore", | |||||
| ) | |||||
| parser.add_argument( | |||||
| "--enable-chwn4", | |||||
| action="store_true", | |||||
| help="transform the model format to CHWN4 " | |||||
| "for inference, mainly used for nvidia tensorcore", | |||||
| ) | |||||
| parser.add_argument( | |||||
| "--enable-fuse-conv-bias-with-z", | |||||
| action="store_true", | |||||
| help="fuse conv_bias with z input for inference on " | |||||
| "nvidia GPU (this optimization pass will result in mismatch " | |||||
| "of the precision of output of training and inference)", | |||||
| ) | |||||
| parser.add_argument( | |||||
| "--enable-fuse-preprocess", | |||||
| action="store_true", | |||||
| help="fuse astype\pad_channel\dimshuffle and etc opr " | |||||
| "from h2d opr", | |||||
| ) | |||||
| args = parser.parse_args() | |||||
| feeds = make_feeds(args) | |||||
| assert isinstance(feeds, dict) and feeds["testcases"], "testcases can not be empty" | |||||
| output_mgbvars = feeds["outputs"] | |||||
| output_mgbvars = optimize_for_inference(args, output_mgbvars) | |||||
| inputs = cgtools.get_dep_vars(output_mgbvars, "Host2DeviceCopy") | |||||
| inputs = sorted((i.name, i.dtype) for i in inputs) | |||||
| if args.discard_var_name: | |||||
| sereg_kwargs = dict(keep_var_name=0, keep_param_name=False) | |||||
| else: | |||||
| sereg_kwargs = dict(keep_var_name=2, keep_param_name=True) | |||||
| strip_info_file = args.output + ".json" if args.output_strip_info else None | |||||
| with open(args.output, "wb") as fout: | |||||
| fout.write(b"mgbtest0") | |||||
| fout.write(struct.pack("I", len(feeds["testcases"]))) | |||||
| dump_content, stat = G.dump_graph( | |||||
| output_mgbvars, | |||||
| append_json=True, | |||||
| strip_info_file=strip_info_file, | |||||
| **sereg_kwargs, | |||||
| ) | |||||
| fout.write(dump_content) | |||||
| logger.info( | |||||
| "graph dump sizes: tot_size={:.3f}KiB overhead={:.3f}KiB".format( | |||||
| stat.tot_bytes / 1024, (stat.tot_bytes - stat.tensor_value_bytes) / 1024 | |||||
| ) | |||||
| ) | |||||
| def make_dev_tensor(value, dtype=None, device=None): | |||||
| return tensor(value, dtype=dtype, device=device)._dev_tensor() | |||||
| for testcase in feeds["testcases"]: | |||||
| assert isinstance(testcase, dict) | |||||
| cg = G.Graph() | |||||
| output_mgbvars = [] | |||||
| for name, dtype in inputs: | |||||
| output_mgbvars.append( | |||||
| cg.make_const( | |||||
| make_dev_tensor(testcase.pop(name), dtype=dtype, device="cpux") | |||||
| ) | |||||
| ) | |||||
| assert not testcase, "extra inputs provided in testcase: {}".format( | |||||
| testcase.keys() | |||||
| ) | |||||
| with open(args.output, "ab") as fout: | |||||
| dump_content, _ = G.dump_graph( | |||||
| output_mgbvars, strip_info_file=strip_info_file, append_json=True | |||||
| ) | |||||
| fout.write(dump_content) | |||||
| if __name__ == "__main__": | |||||
| main() | |||||