feat(imperative): add more tools for megenginetags/v1.7.2.m1
| @@ -1,8 +1,156 @@ | |||
| # MegEngine Tools | |||
| This directory contains executable python files. | |||
| Use these files in the following way (replace `xxx` to specific file name, like `network_visualize`): | |||
| MegEngine 相关的工具汇总。使用方法如下(可将 `xxx` 替换成任一脚本文件,如 `network_visualize`): | |||
| ``` | |||
| ```bash | |||
| python -m megengine.tools.xxx | |||
| ``` | |||
| ``` | |||
| 工具列表: | |||
| ### accuracy_shake_var_tree | |||
| 将精度抖动分析结果构造成树结构,方便锁定引起抖动的根节点,以及查找依赖关系。 | |||
| 输入: compare_binary_iodump 的输出存入到的一个文件 | |||
| 输出: 第一个出现结果不一致的输出结点 | |||
| 执行命令: accuracy_shake_var_tree 中定义了一些函数组件,可按需集成到实际代码中。下面有一个测试代码: | |||
| ```python | |||
| import megengine.tools.accuracy_shake_var_tree as st | |||
| r = st.parse('diff.txt') | |||
| for key, value in r.items(): | |||
| n = st.varNode.get_varNode(key) | |||
| n.show_src_info() | |||
| print("reference nodes:") | |||
| for i in n.get_reference_list(): | |||
| print(i.id) | |||
| ``` | |||
| ### benchmark_op | |||
| 逐个运行 functional op(并不是所有的 functional op),对比 MegEngine 与 PyTorch 的性能,通过量化结果来指导如何进行下一步的优化。 | |||
| 输入: 无 | |||
| 输出: 打印一个列表,对比在小输入和大输入的情况下 MegEngine 和 Pytorch 执行一些 functional op 的速度对比 | |||
| 执行命令: `python3 -m megengine.tools.benchmark_op` | |||
| ### compare_binary_iodump | |||
| 分析同一模型在不同平台下给定相同输入之后的输出是否完全一致。 | |||
| 输入: 两个目录(假设分别为 expect/ 和 actual/),分别存有不同平台下运行的 tensor 结果 | |||
| 输出: 打印所有的输出 tensor 信息,如果某个 tensor 在两个平台上的值不一致,那么会打印出第一个不一致的值 | |||
| 执行命令: `python3 -m megengine.tools.compare_binary_iodump expect/ actual/` | |||
| ### draw_graph | |||
| 用来查看静态图的 op 序列,有助于理解 MegEngine 的静态图在动态图的基础上做了哪些优化。 | |||
| 输入: `megengine.core.tensor.megbrain_graph.Graph._to_json` 得出的静态图描述文件,为 json 格式 | |||
| 输出: 一个 dot 文件,可通过 dot 命令绘制出图片 | |||
| 执行命令: | |||
| ```bash | |||
| python3 -m megengine.tools.draw_graph -i dump.json -o dump.dot | |||
| dot -Tpng dump.dot -o dump.png | |||
| ``` | |||
| ### dump_with_testcase_mge | |||
| 将待测数据提前注入模型文件,并在本地运行得到期望结果,可与实际运行的结果进行比对以检查是否出错。 | |||
| 输入: 一个 MegEngine 模型文件,可选一些 npy 文件作为模型输入(也可以随机生成输入,如下面的命令示例) | |||
| 输出: 一个带输入的 MegEngine 模型文件 | |||
| 执行命令: `python3 -m megengine.tools.dump_with_testcase_mge model.mge -d "#rand(0,255,14,2)"` | |||
| ### graph_info_analyze | |||
| 将图和内存信息的 json 文件的文件夹 logs 转换为 TensorBoard 的输入文件夹 logs_p。以便 TensorBoard 对图结构以及内存信息进行可视化。 | |||
| 输入: 图和内存信息的 json 文件的文件夹 | |||
| 输出: TensorBoard 的输入文件夹 | |||
| 执行命令: `python3 -m megengine.tools.graph_info_analyze -i logs -o logs_p` | |||
| ### load_network_and_run | |||
| python 版本的 load_and_run。 | |||
| 输入: MegEngine 的模型文件,可选一些 npy 文件作为模型输入 | |||
| 输出: 模型执行并打印一些测速信息 | |||
| 执行命令: `python3 -m megengine.tools.load_network_and_run model.mge --iter 10` | |||
| ### network_visualize | |||
| 1. 分析给定的 MegEngine 模型中参数量信息,包括 shape、dtype、mean、std 以及 size 占比等。 | |||
| 2. 分析给定的 MegEngine 模型中算子 FLOPs 计算量以及占比,还有算子的 inputs\outputs shape、感受野、stride 等。 | |||
| 输入: MegEngine 的模型文件 | |||
| 输出: 模型中的参数量信息或计算量信息 | |||
| 执行命令: | |||
| ```bash | |||
| # 分析参数量 | |||
| python3 -m megengine.tools.network_visualize model.mge --cal_params --logging_to_stdout | |||
| # 分析计算量 | |||
| python3 -m megengine.tools.network_visualize model.mge --cal_flops --logging_to_stdout | |||
| ``` | |||
| ### profile_analyze | |||
| 对于 load_and_run --profile 运行模型生成的 profile.json 文件或者 trace 模式下开启 profiling 功能并通过 trace.get_profile() 得到的 json 文件进行分析,得到静态图中算子的时间和显存占比等信息,以表格形式呈现。 | |||
| 输入: load_and_run 生成的 profile 文件 | |||
| 输出: 一个按照参数在输入文件中筛选得出的数据表格 | |||
| 执行命令: | |||
| ```bash | |||
| # 生成供分析的 json 文件 | |||
| python3 -m megengine.tools.load_network_and_run model.mge --warm-up --iter 10 --profile profile.json | |||
| #分析耗时前 3 的单个算子 | |||
| python3 -m megengine.tools.profile_analyze profile.json -t 3 | |||
| #筛选用时超过 10us 的 conv 按 flops 排序 | |||
| python3 -m megengine.tools.profile_analyze profile.json -t 3 --order-by +flops --min-time 1e-5 --type ConvolutionForward | |||
| ``` | |||
| ### profiler | |||
| 对给定的训练程序,记录训练过程并以通用格式存储,可在浏览器上可视化。 | |||
| 输入: 需要一个 MegEngine 的训练程序(称之为 train.py,其中包含一个典型的 MegEngine 训练过程) | |||
| 输出: 一些记录 profile 过程的 json 文件,默认在 profile 子目录下,可用 https://ui.perfetto.dev/ 进行加载并且可视化 | |||
| 执行命令: `python3 -m megengine.tools.profiler train.py` | |||
| ### svg_viewer | |||
| 查看 MegEngine 生成的显存占用图,可以帮助用户了解显存使用情况. | |||
| 输入: 显存占用的 svg 图片 | |||
| 输出: 网页展示的可视化 | |||
| 执行命令: `python3 -m megengine.tools.svg_viewer` | |||
| @@ -0,0 +1,151 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import time | |||
| import numpy as np | |||
| import megengine as mge | |||
| import megengine.module as MM | |||
| import megengine.functional as MF | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as TF | |||
| from tabulate import tabulate | |||
| module_cache = { | |||
| "conv2d": (MM.Conv2d(32, 32, 3, 1, 0), nn.Conv2d(32, 32, 3, 1, 0).cuda()), | |||
| "dw_conv2d": (MM.Conv2d(32, 32, 3, 1, 0, groups=32), nn.Conv2d(32, 32, 3, 1, 0, groups=32).cuda()), | |||
| "conv3d": (MM.Conv3d(32, 32, 3, 1, 0), nn.Conv3d(32, 32, 3, 1, 0).cuda()), | |||
| "ConvTranspose2d": (MM.ConvTranspose2d(32, 32, 3, 1, 0), nn.ConvTranspose2d(32, 32, 3, 1, 0).cuda()), | |||
| "BatchNorm2d": (MM.BatchNorm2d(64), nn.BatchNorm2d(64).cuda()), | |||
| "Linear": (MM.Linear(1000, 1000), nn.Linear(1000, 1000).cuda()), | |||
| } | |||
| test_cases = [ | |||
| # (mge op, torch op, small inps, large inps, unpack_inps, rep) | |||
| ("adaptive_avg_pool2d", lambda x: MF.adaptive_avg_pool2d(x, (7,7)), lambda x: TF.adaptive_avg_pool2d(x, (7,7)), [(2, 32, 16, 16)], [(64, 512, 16, 16)], True, 1000), | |||
| ("adaptive_max_pool2d", lambda x: MF.adaptive_max_pool2d(x, (7,7)), lambda x: TF.adaptive_max_pool2d(x, (7,7)), [(2, 32, 16, 16)], [(64, 512, 16, 16)], True, 1000), | |||
| ("argsort", MF.argsort, torch.argsort, [(1000,)], [(1000, 1000),], True, 1000), | |||
| ("avg_pool2d", lambda x: MF.avg_pool2d(x, 2), lambda x: TF.avg_pool2d(x, 2), [(2, 32, 16, 16)], [(64, 512, 16, 16)], True, 1000), | |||
| ("broadcast", lambda x: MF.broadcast_to(x, (5,) + x.shape), lambda x: torch.broadcast_to(x, (5,)+x.shape), [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||
| ("batchedmatmul", MF.matmul, torch.matmul, [(8, 64, 32), (8, 32, 64)], [(8, 2048, 512), (8, 512, 2048)], True, 1000), | |||
| ("batchnrom2d", lambda x: module_cache["BatchNorm2d"][0](x), lambda x: module_cache["BatchNorm2d"][1](x), [(2, 64, 16, 16)], [(64, 64, 128, 128)], True, 1000), | |||
| ("concat", MF.concat, torch.cat, [(20, 100), (50, 100), (30, 100)], [(64, 512, 16, 16), (64, 512, 16, 16), (64, 512, 16, 16)], False, 1000), | |||
| ("conv2d", lambda x: module_cache["conv2d"][0](x), lambda x: module_cache["conv2d"][1](x), [(2, 32, 16, 16)], [(32, 32, 128, 128)], True, 1000), | |||
| ("conv3d", lambda x: module_cache["conv3d"][0](x), lambda x: module_cache["conv3d"][1](x), [(2, 32, 8, 8, 8)], [(32, 32, 16, 16, 16)], True, 1000), | |||
| ("convTranspose2d", lambda x: module_cache["ConvTranspose2d"][0](x), lambda x: module_cache["ConvTranspose2d"][1](x), [(2, 32, 16, 16)], [(32, 32, 128, 128)], True, 1000), | |||
| ("dropout", lambda x: MF.dropout(x, 0.5), TF.dropout, [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||
| ("dw_conv2d", lambda x: module_cache["dw_conv2d"][0](x), lambda x: module_cache["dw_conv2d"][1](x), [(2, 32, 16, 16)], [(32, 32, 128, 128)], True, 1000), | |||
| ("elemwise.unary", MF.log, torch.log, [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||
| ("elemwise.binary", MF.add, torch.add, [(100,100), (100,100)], [(64, 512, 16, 16), (64, 512, 16, 16)], True, 1000), | |||
| ("expand_dims", lambda x: MF.expand_dims(x, 0), lambda x: torch.unsqueeze(x, 0), [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||
| ("gelu", MF.gelu, TF.gelu, [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||
| ("hswish", MF.hswish, TF.hardswish, [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||
| ("hsigmoid", MF.hsigmoid, TF.hardsigmoid, [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||
| ("isinf", MF.isinf, torch.isinf, [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||
| ("indeixngMultiAxisVec", lambda x: x[[1,3,5], [1,3,5], [1,3,5], [1,3,5]], lambda x: x[[1,3,5], [1,3,5], [1,3,5], [1,3,5]], [(10,10,10,10)], [(64, 512, 16, 16)], True, 1000), | |||
| ("logsigmoid", MF.logsigmoid, TF.logsigmoid, [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||
| ("leaky_relu", lambda x: MF.leaky_relu(x, 0.5), lambda x: TF.leaky_relu(x, 0.5), [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||
| ("linear", lambda x: module_cache["Linear"][0](x), lambda x: module_cache["Linear"][1](x), [(10, 1000)], [(64, 128, 1000)], True, 1000), | |||
| ("matinv", MF.matinv, torch.inverse, [(10,10)], [(30, 30)], True, 1000), | |||
| ("matmul", MF.matmul, torch.matmul, [(64,32), (32, 64)], [(2048, 1024), (1024, 2048)], True, 1000), | |||
| ("max_pool2d", lambda x: MF.max_pool2d(x, 2), lambda x: TF.max_pool2d(x, 2), [(2, 32, 16, 16)], [(64, 512, 16, 16)], True, 1000), | |||
| ("normal", lambda x: mge.random.normal(0,1, x.shape), lambda x: torch.randn(x.shape, device="cuda"), [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||
| ("prelu", MF.prelu, TF.prelu, [(100,100), (1,)], [(64, 512, 16, 16), (1,)], True, 1000), | |||
| ("reduce.max", lambda x: MF.max(x, 0), lambda x: torch.max(x, 0), [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||
| ("reduce.mean", lambda x: MF.mean(x, 0), lambda x: torch.mean(x, 0), [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||
| ("reduce.mean", lambda x: MF.mean(x, 0), lambda x: torch.mean(x, 0), [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||
| ("relu", MF.relu, TF.relu, [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||
| ("relu6", MF.relu6, TF.relu6, [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||
| ("repeat", lambda x: MF.repeat(x, 5), lambda x: torch.repeat_interleave(x, 5), [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||
| ("silu", MF.silu, TF.silu, [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||
| ("split", lambda x: MF.split(x, 5), lambda x: torch.split(x, 5), [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||
| ("sigmoid", MF.sigmoid, TF.sigmoid, [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||
| ("softmax", lambda x: MF.softmax(x, axis=1), lambda x: TF.softmax(x, dim=1), [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||
| ("softplus", MF.softplus, TF.softplus, [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||
| ("squeeze", lambda x: MF.squeeze(x, 0), lambda x: torch.squeeze(x, 0), [(1, 100,100)], [(1, 64, 512, 16, 16)], True, 1000), | |||
| ("stack", MF.stack, torch.stack, [(100,100), (100,100)], [(64, 512, 16, 16), (64, 512, 16, 16)], False, 10000), | |||
| ("subtensor", lambda x: x[0:20, 10:60], lambda x: x[0:20, 10:60], [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||
| ("topk", lambda x: MF.topk(x, 10), lambda x: torch.topk(x, 10), [(100,100)], [(1000, 1000)], True, 1000), | |||
| ("tile", lambda x: MF.tile(x, (2,)*len(x.shape)), lambda x: torch.tile(x, (2,)*len(x.shape)), [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||
| ("transpose", lambda x: MF.transpose(x, list(range(len(x.shape)))[::-1]), lambda x: torch.permute(x, list(range(len(x.shape)))[::-1]), [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||
| ("where", lambda x: MF.where(x > 0.5, x, x), lambda x: torch.where(x > 0.5, x, x), [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||
| ("uniform", lambda x: mge.random.uniform(0,1, x.shape), lambda x: torch.rand(x.shape, device="cuda"), [(100,100)], [(64, 512, 16, 16)], True, 1000), | |||
| ] | |||
| def perf_func(func, inps, reps, unpack_inps, is_mge): | |||
| if is_mge: | |||
| mge._full_sync() | |||
| tik = time.time() | |||
| for _ in range(reps): | |||
| if unpack_inps: | |||
| out = func(*inps) | |||
| else: | |||
| out = func(inps) | |||
| mge._full_sync() | |||
| else: | |||
| torch.cuda.synchronize() | |||
| with torch.no_grad(): | |||
| tik = time.time() | |||
| for _ in range(reps): | |||
| if unpack_inps: | |||
| out = func(*inps) | |||
| else: | |||
| out = func(inps) | |||
| torch.cuda.synchronize() | |||
| return time.time() - tik | |||
| def get_avg_time(func, inps, reps, unpack_inps, is_mge): | |||
| # warm up | |||
| for _ in range(2): | |||
| t = perf_func(func, inps, reps, unpack_inps, is_mge) | |||
| times = [] | |||
| for _ in range(5): | |||
| t = perf_func(func, inps, reps, unpack_inps, is_mge) | |||
| times.append(t) | |||
| return np.mean(times) | |||
| def get_perf_results(mge_func, torch_func, shapes, unpack_inps, reps): | |||
| inps = [ | |||
| np.random.randn(*shape) for shape in shapes | |||
| ] | |||
| inps_mge = [mge.tensor(inp, dtype="float32") for inp in inps] | |||
| avg_time_mge = get_avg_time(mge_func, inps_mge, reps, unpack_inps, True) | |||
| inps_torch = [torch.Tensor(inp).type(torch.float).cuda() for inp in inps] | |||
| avg_time_torch = get_avg_time(torch_func, inps_torch, reps, unpack_inps, False) | |||
| return avg_time_mge, avg_time_torch | |||
| if __name__ == "__main__": | |||
| header = ["opr_name", "time(mge/pytorch; small input)", "time(mge/pytorch; large input)"] | |||
| table = [] | |||
| for case in test_cases: | |||
| assert len(case) == 7 | |||
| name, mge_func, torch_func, small_shapes, large_shapes, unpack_inps, reps = case | |||
| data = [] | |||
| data.append(name) | |||
| print("========== op: {}".format(name)) | |||
| avg_time_mge, avg_time_torch = get_perf_results(mge_func, torch_func, small_shapes, unpack_inps, reps) | |||
| print("mge time: {}".format(avg_time_mge)) | |||
| print("torch time: {}".format(avg_time_torch)) | |||
| data.append("{:.2f}".format(avg_time_mge / avg_time_torch)) | |||
| avg_time_mge, avg_time_torch = get_perf_results(mge_func, torch_func, large_shapes, unpack_inps, reps) | |||
| print("mge time: {}".format(avg_time_mge)) | |||
| print("torch time: {}".format(avg_time_torch)) | |||
| data.append("{:.2f}".format(avg_time_mge / avg_time_torch)) | |||
| table.append(data) | |||
| print(tabulate(table, header, tablefmt="github")) | |||
| @@ -0,0 +1,535 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import argparse | |||
| import os | |||
| import re | |||
| import struct | |||
| import cv2 | |||
| import numpy as np | |||
| import megengine as mge | |||
| import megengine.core._imperative_rt as rt | |||
| import megengine.core.tensor.megbrain_graph as G | |||
| from megengine import tensor | |||
| from megengine.core.ops import builtin | |||
| from megengine.utils import comp_graph_tools as cgtools | |||
| logger = mge.get_logger(__name__) | |||
| def auto_reformat_image(args, path, data, dst_shape): | |||
| """reformat image to target shape | |||
| :param data: image data as numpy array | |||
| :param dst_shape: target shape | |||
| """ | |||
| dim3_format = False # required input format does not contain batch | |||
| hwc_format = False # required input format is NHWC | |||
| if not dst_shape: # input tensor shape is not predefined | |||
| if len(data.shape) == 2: | |||
| chl = 1 | |||
| h = data.shape[0] | |||
| w = data.shape[1] | |||
| else: | |||
| assert len(data.shape) == 3, "Input image must be of dimension 2 or 3" | |||
| h, w, chl = data.shape | |||
| dst_shape = (1, chl, h, w) | |||
| if len(dst_shape) == 3: | |||
| dst_shape = (1,) + dst_shape | |||
| dim3_format = True | |||
| assert len(dst_shape) == 4, "bad dst_shape: {}".format(dst_shape) | |||
| chl = dst_shape[1] | |||
| if chl in [1, 3]: | |||
| n, c, h, w = dst_shape | |||
| dst_shape = (n, h, w, c) | |||
| else: | |||
| chl = dst_shape[3] | |||
| assert chl in [1, 3], "can not infer input format from shape: {}".format( | |||
| dst_shape | |||
| ) | |||
| hwc_format = True | |||
| # dst_shape has now been normalized to NHWC format | |||
| if args.resize_input: | |||
| h, w = dst_shape[1:3] | |||
| data = cv2.resize(data, (w, h)) | |||
| logger.info("input {} resized to {}".format(path, data.shape)) | |||
| if chl == 1: | |||
| data = cv2.cvtColor(data, cv2.COLOR_BGR2GRAY) | |||
| data = data[:, :, np.newaxis] | |||
| assert data.ndim == 3 | |||
| data = data[np.newaxis] | |||
| # data normalized to NHWC format | |||
| if not hwc_format: | |||
| data = np.transpose(data, (0, 3, 1, 2)) | |||
| if dim3_format: | |||
| data = np.squeeze(data, 0) | |||
| return data | |||
| def read_input_data(args, dst_shape, dtype, path, repeat): | |||
| def check_shape_equal(dst_shape, data_shape): | |||
| if len(dst_shape): | |||
| assert len(data_shape) == len( | |||
| dst_shape | |||
| ), "input/data shapes mismatch: {} vs {}".format(dst_shape, data_shape) | |||
| if data_shape[1:] != dst_shape[1:]: | |||
| logger.warning( | |||
| "dst_shape is {}; data_shape is {}".format(dst_shape, data_shape) | |||
| ) | |||
| if path.startswith("#"): | |||
| assert not args.resize_input | |||
| assert not args.input_transform | |||
| spec = path | |||
| m = re.match(r"^#rand\(([-0-9.]*)\s*,\s*([-0-9.]*)\s*(,[^\)]+)?\)$", spec) | |||
| assert m, "bad spec {}".format(spec) | |||
| rng_min = float(m.group(1)) | |||
| rng_max = float(m.group(2)) | |||
| if m.group(3): | |||
| shape_str = m.group(3) | |||
| try: | |||
| shape = shape_str[1:].split(",") | |||
| if shape[-1].strip() == "...": | |||
| shape = shape[:-1] | |||
| shape.extend(list(dst_shape[len(shape) :])) | |||
| data_shape = tuple(map(int, shape)) | |||
| except ValueError as e: | |||
| raise ValueError("bad spec {}: {}".format(spec, e.args)) | |||
| else: | |||
| data_shape = dst_shape | |||
| check_shape_equal(dst_shape, data_shape) | |||
| return np.random.uniform(rng_min, rng_max, data_shape).astype(dtype) | |||
| # try to load image | |||
| data = cv2.imread(path, cv2.IMREAD_COLOR) | |||
| if data is None: | |||
| assert not args.resize_input | |||
| data = np.load(path) | |||
| assert isinstance(data, np.ndarray) | |||
| else: | |||
| # load image succeeds, so we expect input format is image format | |||
| data = auto_reformat_image(args, path, data, dst_shape) | |||
| data = np.repeat(data, repeat, axis=0) | |||
| if repeat > 1: | |||
| logger.info( | |||
| "repeat input for {} times, data shape is {}".format(repeat, data.shape) | |||
| ) | |||
| check_shape_equal(dst_shape, data.shape) | |||
| if args.input_transform: | |||
| data = eval(args.input_transform, {"data": data, "np": np}) | |||
| return data | |||
| def gen_one_testcase(args, inputs, spec): | |||
| paths = spec.split(";") | |||
| if len(paths) != len(inputs): | |||
| if len(paths) == 1 and paths[0].startswith("#"): | |||
| paths = ["{}:{}".format(name, paths[0]) for name in inputs.keys()] | |||
| assert len(paths) == len(inputs), "required inputs: {}; data paths: {}".format( | |||
| inputs.keys(), paths | |||
| ) | |||
| if len(paths) == 1 and ":" not in paths[0]: | |||
| paths[0] = next(iter(inputs.keys())) + ":" + paths[0] | |||
| ret = {} | |||
| for path in paths: | |||
| var, path = path.split(":") | |||
| if args.repeat: | |||
| repeat = args.repeat | |||
| else: | |||
| repeat = 1 | |||
| ret[var] = read_input_data( | |||
| args, inputs[var].shape, inputs[var].dtype, path, repeat | |||
| ) | |||
| return ret | |||
| def make_feeds(args): | |||
| ret = G.load_graph(args.input) | |||
| cg_rt, outputs = ret.graph, ret.output_vars_list | |||
| inputs = cgtools.get_dep_vars(outputs, "Host2DeviceCopy") | |||
| inputs = {i.name: i for i in inputs} | |||
| if not args.no_assert: | |||
| replace_varmap = {} | |||
| inp_map = {} | |||
| # replace var use InputNode | |||
| for name, var in inputs.items(): | |||
| inp = G.InputNode( | |||
| device="xpux", dtype=var.dtype, shape=var.shape, graph=cg_rt | |||
| ) | |||
| replace_varmap[var] = inp.outputs[0] | |||
| inp_map[name] = inp | |||
| new = cgtools.replace_vars(outputs, replace_varmap) | |||
| if isinstance(new, rt.VarNode): | |||
| new = list(new) | |||
| output_nodes = [G.OutputNode(var) for var in new] | |||
| func = cg_rt.compile([node.outputs[0] for node in output_nodes]) | |||
| def make_dev_tensor(value, dtype=None, device=None): | |||
| return tensor(value, dtype=dtype, device=device)._dev_tensor() | |||
| def calculate(*args, **kwargs): | |||
| output_val = [] | |||
| # set inputs value | |||
| for name, var in inputs.items(): | |||
| val = kwargs.pop(name, None) | |||
| assert val is not None, "miss input name{}".format(name) | |||
| dev_tensor = make_dev_tensor(val, dtype=var.dtype, device="xpux") | |||
| inp_map[name].set_value(dev_tensor) | |||
| func.execute() | |||
| for res in output_nodes: | |||
| output_val.append(res.get_value().numpy()) | |||
| return output_val | |||
| def expect_name(var): | |||
| return "{}:expect".format(var.name) | |||
| testcases = [] | |||
| np.set_printoptions(precision=2, threshold=4, suppress=True) | |||
| data_list = [] | |||
| for item in args.data: | |||
| if item.startswith("@"): | |||
| with open(item[1:], "r") as f: | |||
| data_list.extend([line.rstrip() for line in f if line.rstrip() != ""]) | |||
| else: | |||
| data_list.append(item) | |||
| for inp_spec in data_list: | |||
| cur_testcase = gen_one_testcase(args, inputs, inp_spec) | |||
| assert len(cur_testcase) == len( | |||
| inputs | |||
| ), "required inputs: {}; given data: {}".format( | |||
| inputs.keys(), cur_testcase.keys() | |||
| ) | |||
| if not args.no_assert: | |||
| outputs_get = calculate(**cur_testcase) | |||
| for var, val in zip(outputs, outputs_get): | |||
| cur_testcase[expect_name(var)] = val | |||
| logger.info( | |||
| "generate test groundtruth: var={} shape={} range=({}, {})" | |||
| " mean={} var={}".format( | |||
| var, val.shape, val.min(), val.max(), np.mean(val), np.var(val) | |||
| ) | |||
| ) | |||
| testcases.append(cur_testcase) | |||
| logger.info( | |||
| "add testcase: \n {}".format( | |||
| "\n ".join( | |||
| "{}: shape={} dtype={} range=({:.2f},{:.2f}) " | |||
| "mean={:.2f} sd={:.2f}".format( | |||
| k, v.shape, v.dtype, v.min(), v.max(), np.mean(v), np.std(v) | |||
| ) | |||
| for k, v in sorted(cur_testcase.items()) | |||
| ) | |||
| ) | |||
| ) | |||
| if not args.no_assert: | |||
| def expect_shp(var): | |||
| ret = var.shape | |||
| if ret: | |||
| return ret | |||
| return testcases[0][expect_name(var)].shape | |||
| def assert_equal(expect, real, **kwargs): | |||
| op = builtin.AssertEqual(**kwargs) | |||
| (res,) = G.apply_normal_varnode(op, expect, real) | |||
| return res | |||
| verbose = not args.silent | |||
| outputs_new = [] | |||
| for i in outputs: | |||
| device = rt.CompNode("xpux") | |||
| dtype = i.dtype | |||
| name = expect_name(i) | |||
| shape = expect_shp(i) | |||
| # make expect output as one input of model. | |||
| expect_get = rt.make_h2d(cg_rt, device, dtype, shape, name) | |||
| # insert assert opr to check expect and real. | |||
| outputs_new.append( | |||
| assert_equal( | |||
| expect_get, | |||
| i, | |||
| verbose=verbose, | |||
| maxerr=args.maxerr, | |||
| ) | |||
| ) | |||
| inputs[expect_name(i)] = expect_get | |||
| outputs = outputs_new | |||
| return {"outputs": outputs, "testcases": testcases} | |||
| def optimize_for_inference(args, outputs): | |||
| args_list = [ | |||
| "enable_io16xc32", | |||
| "enable_ioc16", | |||
| "enable_hwcd4", | |||
| "enable_nchw4", | |||
| "enable_nchw88", | |||
| "enable_nchw44", | |||
| "enable_nchw44_dot", | |||
| "enable_nchw32", | |||
| "enable_chwn4", | |||
| "enable_fuse_conv_bias_nonlinearity", | |||
| "enable_fuse_conv_bias_with_z", | |||
| "enable_fuse_preprocess", | |||
| ] | |||
| kwargs = {} | |||
| for k in args_list: | |||
| if getattr(args, k): | |||
| kwargs[k] = True | |||
| if args.optimize_for_inference: | |||
| outputs = G.optimize_for_inference(outputs, **kwargs) | |||
| return outputs | |||
| def main(): | |||
| parser = argparse.ArgumentParser( | |||
| description="Pack computing graph, input values and expected output " | |||
| "values into one file for checking correctness. README.md gives more " | |||
| "details on the usage", | |||
| formatter_class=argparse.ArgumentDefaultsHelpFormatter, | |||
| ) | |||
| parser.add_argument("input", help="MegEngine dumped model file") | |||
| parser.add_argument("-o", "--output", help="output file", required=True) | |||
| parser.add_argument( | |||
| "-d", | |||
| "--data", | |||
| default=[], | |||
| action="append", | |||
| required=True, | |||
| help="Given input test data when input file is a network, " | |||
| "and current network output would be used as groundtruth. " | |||
| "The format is var0:file0;var1:file1... to specify data files for " | |||
| "input vars. It can also be #rand(min,max,shape...) for generating " | |||
| "random input data, for example, #rand(0,255), " | |||
| "#rand(0,255,1,3,224,224) or #rand(0, 255, 1, ...) where `...` means " | |||
| "the remaining part of the original shape. " | |||
| "If the shape is not specified, the shape of " | |||
| "corresponding input tensors in the network will be used. " | |||
| "If there is only one input var, its name can be omitted. " | |||
| "Each data file can either be an image which can be loaded by opencv, " | |||
| "or a pickled numpy.ndarray. " | |||
| "This option can be given multiple times to add multiple testcases. " | |||
| " *NOTE* " | |||
| "If you start the data with the letter @, the rest should be a " | |||
| "filename, and each line in the file should be a single datum in " | |||
| "the format described above. ", | |||
| ) | |||
| parser.add_argument( | |||
| "--repeat", | |||
| type=int, | |||
| default=1, | |||
| help="Specify how many times the input image is repeated. " | |||
| "Useful when running benchmark for batch size other than one. " | |||
| "Have no effect on randomly generated input data.", | |||
| ) | |||
| parser.add_argument( | |||
| "--silent", | |||
| action="store_true", | |||
| help="set verbose to False in asserti_equal opr", | |||
| ) | |||
| parser.add_argument( | |||
| "--optimize-for-inference", | |||
| action="store_true", | |||
| help="enable optimization for inference", | |||
| ) | |||
| parser.add_argument( | |||
| "--no-assert", | |||
| action="store_true", | |||
| help="do not insert assert_equal opr to check result; " | |||
| "this option is useful for benchmarking", | |||
| ) | |||
| parser.add_argument( | |||
| "--maxerr", | |||
| type=float, | |||
| default=1e-4, | |||
| help="max error for assert_equal check during runtime", | |||
| ) | |||
| parser.add_argument( | |||
| "--resize-input", | |||
| action="store_true", | |||
| help="resize input image to fit input var shape", | |||
| ) | |||
| parser.add_argument( | |||
| "--input-transform", | |||
| help="a python expression to transform the input data. " | |||
| "Example: data / np.std(data)", | |||
| ) | |||
| parser.add_argument( | |||
| "--discard-var-name", | |||
| action="store_true", | |||
| help="discard variable and param names in the " "generated output", | |||
| ) | |||
| parser.add_argument( | |||
| "--output-strip-info", action="store_true", help="output code strip information" | |||
| ) | |||
| parser.add_argument( | |||
| "--enable-io16xc32", | |||
| action="store_true", | |||
| help="transform the mode to float16 io float32 compute", | |||
| ) | |||
| parser.add_argument( | |||
| "--enable-ioc16", | |||
| action="store_true", | |||
| help="transform the dtype of the model to float16 io " "and compute", | |||
| ) | |||
| parser.add_argument( | |||
| "--enable-fuse-conv-bias-nonlinearity", | |||
| action="store_true", | |||
| help="fuse convolution bias and nonlinearity opr to a " | |||
| "conv_bias opr and compute", | |||
| ) | |||
| parser.add_argument( | |||
| "--enable-hwcd4", | |||
| action="store_true", | |||
| help="transform the model format from NCHW to NHWCD4 " | |||
| "for inference; you may need to disable CUDA and set " | |||
| "MGB_USE_MEGDNN_DBG=2", | |||
| ) | |||
| parser.add_argument( | |||
| "--enable-nchw4", | |||
| action="store_true", | |||
| help="transform the model format from NCHW to NCHW4 " "for inference", | |||
| ) | |||
| parser.add_argument( | |||
| "--enable-nchw88", | |||
| action="store_true", | |||
| help="transform the model format from NCHW to NCHW88 " "for inference", | |||
| ) | |||
| parser.add_argument( | |||
| "--enable-nchw44", | |||
| action="store_true", | |||
| help="transform the model format from NCHW to NCHW44 " "for inference", | |||
| ) | |||
| parser.add_argument( | |||
| "--enable-nchw44-dot", | |||
| action="store_true", | |||
| help="transform the model format from NCHW to NCHW44_DOT " | |||
| "for optimizing armv8.2 dot in inference", | |||
| ) | |||
| parser.add_argument( | |||
| "--enable-nchw32", | |||
| action="store_true", | |||
| help="transform the model format from NCHW4 to NCHW32 " | |||
| "for inference on nvidia TensoCore", | |||
| ) | |||
| parser.add_argument( | |||
| "--enable-chwn4", | |||
| action="store_true", | |||
| help="transform the model format to CHWN4 " | |||
| "for inference, mainly used for nvidia tensorcore", | |||
| ) | |||
| parser.add_argument( | |||
| "--enable-fuse-conv-bias-with-z", | |||
| action="store_true", | |||
| help="fuse conv_bias with z input for inference on " | |||
| "nvidia GPU (this optimization pass will result in mismatch " | |||
| "of the precision of output of training and inference)", | |||
| ) | |||
| parser.add_argument( | |||
| "--enable-fuse-preprocess", | |||
| action="store_true", | |||
| help="fuse astype\pad_channel\dimshuffle and etc opr " | |||
| "from h2d opr", | |||
| ) | |||
| args = parser.parse_args() | |||
| feeds = make_feeds(args) | |||
| assert isinstance(feeds, dict) and feeds["testcases"], "testcases can not be empty" | |||
| output_mgbvars = feeds["outputs"] | |||
| output_mgbvars = optimize_for_inference(args, output_mgbvars) | |||
| inputs = cgtools.get_dep_vars(output_mgbvars, "Host2DeviceCopy") | |||
| inputs = sorted((i.name, i.dtype) for i in inputs) | |||
| if args.discard_var_name: | |||
| sereg_kwargs = dict(keep_var_name=0, keep_param_name=False) | |||
| else: | |||
| sereg_kwargs = dict(keep_var_name=2, keep_param_name=True) | |||
| strip_info_file = args.output + ".json" if args.output_strip_info else None | |||
| with open(args.output, "wb") as fout: | |||
| fout.write(b"mgbtest0") | |||
| fout.write(struct.pack("I", len(feeds["testcases"]))) | |||
| dump_content, stat = G.dump_graph( | |||
| output_mgbvars, | |||
| append_json=True, | |||
| strip_info_file=strip_info_file, | |||
| **sereg_kwargs, | |||
| ) | |||
| fout.write(dump_content) | |||
| logger.info( | |||
| "graph dump sizes: tot_size={:.3f}KiB overhead={:.3f}KiB".format( | |||
| stat.tot_bytes / 1024, (stat.tot_bytes - stat.tensor_value_bytes) / 1024 | |||
| ) | |||
| ) | |||
| def make_dev_tensor(value, dtype=None, device=None): | |||
| return tensor(value, dtype=dtype, device=device)._dev_tensor() | |||
| for testcase in feeds["testcases"]: | |||
| assert isinstance(testcase, dict) | |||
| cg = G.Graph() | |||
| output_mgbvars = [] | |||
| for name, dtype in inputs: | |||
| output_mgbvars.append( | |||
| cg.make_const( | |||
| make_dev_tensor(testcase.pop(name), dtype=dtype, device="cpux") | |||
| ) | |||
| ) | |||
| assert not testcase, "extra inputs provided in testcase: {}".format( | |||
| testcase.keys() | |||
| ) | |||
| with open(args.output, "ab") as fout: | |||
| dump_content, _ = G.dump_graph( | |||
| output_mgbvars, strip_info_file=strip_info_file, append_json=True | |||
| ) | |||
| fout.write(dump_content) | |||
| if __name__ == "__main__": | |||
| main() | |||