GitOrigin-RevId: a1ab77c20a
tags/v1.3.0
| @@ -0,0 +1,8 @@ | |||||
| # MegEngine Tools | |||||
| This directory contains executable python files. | |||||
| Use these files in the following way (replace `xxx` to specific file name, like `network_visualize`): | |||||
| ``` | |||||
| python -m megengine.tools.xxx | |||||
| ``` | |||||
| @@ -1,3 +1,4 @@ | |||||
| #! /usr/bin/env python3 | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
| # | # | ||||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
| @@ -7,12 +8,55 @@ | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| import argparse | import argparse | ||||
| import os | import os | ||||
| import struct | |||||
| import textwrap | import textwrap | ||||
| from pathlib import Path | from pathlib import Path | ||||
| import numpy as np | import numpy as np | ||||
| from megengine.utils import plugin | |||||
| def load_tensor_binary(fobj): | |||||
| """ | |||||
| Load a tensor dumped by the :class:`BinaryOprIODump` plugin; the actual | |||||
| tensor value dump is implemented by ``mgb::debug::dump_tensor``. | |||||
| :param fobj: file object, or a string that contains the file name. | |||||
| :return: tuple ``(tensor_value, tensor_name)``. | |||||
| """ | |||||
| if isinstance(fobj, str): | |||||
| with open(fobj, "rb") as fin: | |||||
| return load_tensor_binary(fin) | |||||
| DTYPE_LIST = { | |||||
| 0: np.float32, | |||||
| 1: np.uint8, | |||||
| 2: np.int8, | |||||
| 3: np.int16, | |||||
| 4: np.int32, | |||||
| # 5: _mgb.intb1, | |||||
| # 6: _mgb.intb2, | |||||
| # 7: _mgb.intb4, | |||||
| 8: None, | |||||
| 9: np.float16, | |||||
| # quantized dtype start from 100000 | |||||
| # see MEGDNN_PARAMETERIZED_DTYPE_ENUM_BASE in | |||||
| # dnn/include/megdnn/dtype.h | |||||
| 100000: np.uint8, | |||||
| 100001: np.int32, | |||||
| 100002: np.int8, | |||||
| } | |||||
| header_fmt = struct.Struct("III") | |||||
| name_len, dtype, max_ndim = header_fmt.unpack(fobj.read(header_fmt.size)) | |||||
| assert ( | |||||
| DTYPE_LIST[dtype] is not None | |||||
| ), "Cannot load this tensor: dtype Byte is unsupported." | |||||
| shape = list(struct.unpack("I" * max_ndim, fobj.read(max_ndim * 4))) | |||||
| while shape[-1] == 0: | |||||
| shape.pop(-1) | |||||
| name = fobj.read(name_len).decode("ascii") | |||||
| return np.fromfile(fobj, dtype=DTYPE_LIST[dtype]).reshape(shape), name | |||||
| def check(v0, v1, name, max_err): | def check(v0, v1, name, max_err): | ||||
| @@ -26,9 +70,9 @@ def check(v0, v1, name, max_err): | |||||
| ) | ) | ||||
| vdiv = np.max([np.abs(v0), np.abs(v1), np.ones_like(v0)], axis=0) | vdiv = np.max([np.abs(v0), np.abs(v1), np.ones_like(v0)], axis=0) | ||||
| err = np.abs(v0 - v1) / vdiv | err = np.abs(v0 - v1) / vdiv | ||||
| check = err > max_err | |||||
| if check.sum(): | |||||
| idx = tuple(i[0] for i in np.nonzero(check)) | |||||
| rst = err > max_err | |||||
| if rst.sum(): | |||||
| idx = tuple(i[0] for i in np.nonzero(rst)) | |||||
| raise AssertionError( | raise AssertionError( | ||||
| "{} not equal: " | "{} not equal: " | ||||
| "shape={} nonequal_idx={} v0={} v1={} err={}".format( | "shape={} nonequal_idx={} v0={} v1={} err={}".format( | ||||
| @@ -79,8 +123,8 @@ def main(): | |||||
| files1 = sorted(files1) | files1 = sorted(files1) | ||||
| for i, j in zip(files0, files1): | for i, j in zip(files0, files1): | ||||
| val0, name0 = plugin.load_tensor_binary(i) | |||||
| val1, name1 = plugin.load_tensor_binary(j) | |||||
| val0, name0 = load_tensor_binary(i) | |||||
| val1, name1 = load_tensor_binary(j) | |||||
| name = "{}: \n{}\n{}\n".format( | name = "{}: \n{}\n{}\n".format( | ||||
| i, "\n ".join(textwrap.wrap(name0)), "\n ".join(textwrap.wrap(name1)) | i, "\n ".join(textwrap.wrap(name0)), "\n ".join(textwrap.wrap(name1)) | ||||
| ) | ) | ||||
| @@ -0,0 +1,176 @@ | |||||
| #! /usr/bin/env python3 | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import argparse | |||||
| import numpy as np | |||||
| from megengine.core.tensor.dtype import is_quantize | |||||
| from megengine.logger import get_logger | |||||
| from megengine.utils.module_stats import ( | |||||
| print_flops_stats, | |||||
| print_params_stats, | |||||
| sizeof_fmt, | |||||
| ) | |||||
| from megengine.utils.network import Network | |||||
| logger = get_logger(__name__) | |||||
| def visualize( | |||||
| model_path: str, | |||||
| log_path: str, | |||||
| bar_length_max: int = 20, | |||||
| log_params: bool = True, | |||||
| log_flops: bool = True, | |||||
| ): | |||||
| r""" | |||||
| Load megengine dumped model and visualize graph structure with tensorboard log files. | |||||
| Can also record and print model's statistics like :func:`~.net_stats` | |||||
| :param model_path: dir path for megengine dumped model. | |||||
| :param log_path: dir path for tensorboard graph log. | |||||
| :param bar_length_max: size of bar indicating max flops or parameter size in net stats. | |||||
| :param log_params: whether print and record params size. | |||||
| :param log_flops: whether print and record op flops. | |||||
| """ | |||||
| try: | |||||
| from tensorboard.compat.proto.attr_value_pb2 import AttrValue | |||||
| from tensorboard.compat.proto.config_pb2 import RunMetadata | |||||
| from tensorboard.compat.proto.graph_pb2 import GraphDef | |||||
| from tensorboard.compat.proto.node_def_pb2 import NodeDef | |||||
| from tensorboard.compat.proto.step_stats_pb2 import ( | |||||
| AllocatorMemoryUsed, | |||||
| DeviceStepStats, | |||||
| NodeExecStats, | |||||
| StepStats, | |||||
| ) | |||||
| from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto | |||||
| from tensorboard.compat.proto.versions_pb2 import VersionDef | |||||
| from tensorboardX import SummaryWriter | |||||
| except ImportError: | |||||
| logger.error( | |||||
| "TensorBoard and TensorboardX are required for visualize.", exc_info=True | |||||
| ) | |||||
| return | |||||
| graph = Network.load(model_path) | |||||
| writer = SummaryWriter(log_path) | |||||
| def process_name(name): | |||||
| return name.replace(".", "/").encode(encoding="utf-8") | |||||
| node_list = [] | |||||
| flops_list = [] | |||||
| params_list = [] | |||||
| for node in graph.all_oprs: | |||||
| if hasattr(node, "output_idx"): | |||||
| node_oup = node.outputs[node.output_idx] | |||||
| else: | |||||
| if len(node.outputs) != 1: | |||||
| logger.warning( | |||||
| "OpNode {} has more than one output and not has 'output_idx' attr.".format( | |||||
| node | |||||
| ) | |||||
| ) | |||||
| node_oup = node.outputs[0] | |||||
| inp_list = [process_name(var.owner.name) for var in node.inputs] | |||||
| attr = { | |||||
| "_output_shapes": AttrValue( | |||||
| list=AttrValue.ListValue( | |||||
| shape=[ | |||||
| TensorShapeProto( | |||||
| dim=[TensorShapeProto.Dim(size=d) for d in node_oup.shape] | |||||
| ) | |||||
| ] | |||||
| ) | |||||
| ), | |||||
| } | |||||
| if hasattr(node, "calc_flops"): | |||||
| flops_num = node.calc_flops() | |||||
| # add op flops attr | |||||
| attr["flops"] = AttrValue(s=sizeof_fmt(flops_num).encode(encoding="utf-8")) | |||||
| flops_list.append( | |||||
| dict( | |||||
| name=node.name, | |||||
| class_name=node.type, | |||||
| input_shapes=[i.shape for i in node.inputs], | |||||
| output_shapes=[o.shape for o in node.outputs], | |||||
| flops_num=flops_num, | |||||
| flops_cum=0, | |||||
| ) | |||||
| ) | |||||
| if node.type == "ImmutableTensor": | |||||
| param_dim = np.prod(node_oup.shape) | |||||
| # TODO: consider other quantize dtypes | |||||
| param_bytes = 1 if is_quantize(node_oup.dtype) else 4 | |||||
| # add tensor size attr | |||||
| attr["size"] = AttrValue( | |||||
| s=sizeof_fmt(param_dim * param_bytes).encode(encoding="utf-8") | |||||
| ) | |||||
| params_list.append( | |||||
| dict( | |||||
| name=node.name, | |||||
| shape=node_oup.shape, | |||||
| param_dim=param_dim, | |||||
| bits=param_bytes * 8, | |||||
| size=param_dim * param_bytes, | |||||
| size_cum=0, | |||||
| mean="{:.2g}".format(node.numpy().mean()), | |||||
| std="{:.2g}".format(node.numpy().std()), | |||||
| ) | |||||
| ) | |||||
| node_list.append( | |||||
| NodeDef( | |||||
| name=process_name(node.name), op=node.type, input=inp_list, attr=attr, | |||||
| ) | |||||
| ) | |||||
| total_flops, total_params = 0, 0 | |||||
| if log_params: | |||||
| total_params = print_params_stats(params_list, bar_length_max) | |||||
| if log_flops: | |||||
| total_flops = print_flops_stats(flops_list, bar_length_max) | |||||
| graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22)) | |||||
| device = "/device:CPU:0" | |||||
| stepstats = RunMetadata( | |||||
| step_stats=StepStats(dev_stats=[DeviceStepStats(device=device)]) | |||||
| ) | |||||
| writer._get_file_writer().add_graph((graph_def, stepstats)) | |||||
| return total_params, total_flops | |||||
| def main(): | |||||
| parser = argparse.ArgumentParser( | |||||
| description="load a megengine dumped model and export log file for tensorboard visualization.", | |||||
| formatter_class=argparse.ArgumentDefaultsHelpFormatter, | |||||
| ) | |||||
| parser.add_argument("model_path", help="dumped model path.") | |||||
| parser.add_argument("log_path", help="tensorboard log path.") | |||||
| parser.add_argument( | |||||
| "--bar_length_max", | |||||
| type=int, | |||||
| default=20, | |||||
| help="size of bar indicating max flops or parameter size in net stats.", | |||||
| ) | |||||
| parser.add_argument( | |||||
| "--log_params", | |||||
| action="store_true", | |||||
| help="whether print and record params size.", | |||||
| ) | |||||
| parser.add_argument( | |||||
| "--log_flops", action="store_true", help="whether print and record op flops.", | |||||
| ) | |||||
| visualize(**vars(parser.parse_args())) | |||||
| if __name__ == "__main__": | |||||
| main() | |||||
| @@ -1,4 +1,4 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| #! /usr/bin/env python3 | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
| # | # | ||||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
| @@ -84,26 +84,125 @@ hook_modules = ( | |||||
| ) | ) | ||||
| def net_stats(model, input_size, bar_length_max=20, log_params=True, log_flops=True): | |||||
| def dict2table(list_of_dict, header): | |||||
| table_data = [header] | |||||
| for d in list_of_dict: | |||||
| row = [] | |||||
| for h in header: | |||||
| v = "" | |||||
| if h in d: | |||||
| v = d[h] | |||||
| row.append(v) | |||||
| table_data.append(row) | |||||
| return table_data | |||||
| def sizeof_fmt(num, suffix="B"): | |||||
| for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]: | |||||
| if abs(num) < 1024.0: | |||||
| return "{:3.3f} {}{}".format(num, unit, suffix) | |||||
| num /= 1024.0 | |||||
| sign_str = "-" if num < 0 else "" | |||||
| return "{}{:.1f} {}{}".format(sign_str, num, "Yi", suffix) | |||||
| def dict2table(list_of_dict, header): | |||||
| table_data = [header] | |||||
| for d in list_of_dict: | |||||
| row = [] | |||||
| for h in header: | |||||
| v = "" | |||||
| if h in d: | |||||
| v = d[h] | |||||
| row.append(v) | |||||
| table_data.append(row) | |||||
| return table_data | |||||
| def sizeof_fmt(num, suffix="B"): | |||||
| for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]: | |||||
| if abs(num) < 1024.0: | |||||
| return "{:3.3f} {}{}".format(num, unit, suffix) | |||||
| num /= 1024.0 | |||||
| sign_str = "-" if num < 0 else "" | |||||
| return "{}{:.1f} {}{}".format(sign_str, num, "Yi", suffix) | |||||
| def print_flops_stats(flops, bar_length_max=20): | |||||
| flops_list = [i["flops_num"] for i in flops] | |||||
| max_flops_num = max(flops_list + [0]) | |||||
| # calc total flops and set flops_cum | |||||
| total_flops_num = 0 | |||||
| for d in flops: | |||||
| total_flops_num += int(d["flops_num"]) | |||||
| d["flops_cum"] = sizeof_fmt(total_flops_num, suffix="OPs") | |||||
| for i in flops: | |||||
| f = i["flops_num"] | |||||
| i["flops"] = sizeof_fmt(f, suffix="OPs") | |||||
| r = i["ratio"] = f / total_flops_num | |||||
| i["percentage"] = "{:.2f}%".format(r * 100) | |||||
| bar_length = int(f / max_flops_num * bar_length_max) | |||||
| i["bar"] = "#" * bar_length | |||||
| header = [ | |||||
| "name", | |||||
| "class_name", | |||||
| "input_shapes", | |||||
| "output_shapes", | |||||
| "flops", | |||||
| "flops_cum", | |||||
| "percentage", | |||||
| "bar", | |||||
| ] | |||||
| total_flops_str = sizeof_fmt(total_flops_num, suffix="OPs") | |||||
| total_var_size = sum(sum(s[1] for s in i["output_shapes"]) for i in flops) | |||||
| flops.append( | |||||
| dict(name="total", flops=total_flops_str, output_shapes=total_var_size) | |||||
| ) | |||||
| logger.info("flops stats: \n" + tabulate.tabulate(dict2table(flops, header=header))) | |||||
| return total_flops_num | |||||
| def print_params_stats(params, bar_length_max=20): | |||||
| total_param_dims, total_param_size = 0, 0 | |||||
| for d in params: | |||||
| total_param_dims += int(d["param_dim"]) | |||||
| total_param_size += int(d["size"]) | |||||
| d["size"] = sizeof_fmt(d["size"]) | |||||
| d["size_cum"] = sizeof_fmt(total_param_size) | |||||
| for d in params: | |||||
| ratio = d["param_dim"] / total_param_dims | |||||
| d["ratio"] = ratio | |||||
| d["percentage"] = "{:.2f}%".format(ratio * 100) | |||||
| # construct bar | |||||
| max_ratio = max([d["ratio"] for d in params]) | |||||
| for d in params: | |||||
| bar_length = int(d["ratio"] / max_ratio * bar_length_max) | |||||
| d["size_bar"] = "#" * bar_length | |||||
| param_size = sizeof_fmt(total_param_size) | |||||
| params.append(dict(name="total", param_dim=total_param_dims, size=param_size,)) | |||||
| header = [ | |||||
| "name", | |||||
| "shape", | |||||
| "mean", | |||||
| "std", | |||||
| "param_dim", | |||||
| "bits", | |||||
| "size", | |||||
| "size_cum", | |||||
| "percentage", | |||||
| "size_bar", | |||||
| ] | |||||
| logger.info( | |||||
| "param stats: \n" + tabulate.tabulate(dict2table(params, header=header)) | |||||
| ) | |||||
| return total_param_size | |||||
| def net_stats( | |||||
| model: m.Module, | |||||
| input_size: int, | |||||
| bar_length_max: int = 20, | |||||
| log_params: bool = True, | |||||
| log_flops: bool = True, | |||||
| ): | |||||
| r""" | |||||
| Calculate and print ``model``'s statistics by adding hook and record Module's inputs outputs size. | |||||
| :param model: model that need to get stats info. | |||||
| :param input_size: size of input for running model and calculating stats. | |||||
| :param bar_length_max: size of bar indicating max flops or parameter size in net stats. | |||||
| :param log_params: whether print and record params size. | |||||
| :param log_flops: whether print and record op flops. | |||||
| """ | |||||
| def get_byteswidth(tensor): | def get_byteswidth(tensor): | ||||
| if dtype.is_quantize(tensor.dtype): | if dtype.is_quantize(tensor.dtype): | ||||
| @@ -113,87 +212,6 @@ def net_stats(model, input_size, bar_length_max=20, log_params=True, log_flops=T | |||||
| else: | else: | ||||
| return 4 | return 4 | ||||
| def print_flops_stats(flops): | |||||
| flops_list = [i["flops_num"] for i in flops] | |||||
| max_flops_num = max(flops_list + [0]) | |||||
| # calc total flops and set flops_cum | |||||
| total_flops_num = 0 | |||||
| for d in flops: | |||||
| total_flops_num += int(d["flops_num"]) | |||||
| d["flops_cum"] = sizeof_fmt(total_flops_num, suffix="OPs") | |||||
| for i in flops: | |||||
| f = i["flops_num"] | |||||
| i["flops"] = sizeof_fmt(f, suffix="OPs") | |||||
| r = i["ratio"] = f / total_flops_num | |||||
| i["percentage"] = "{:.2f}%".format(r * 100) | |||||
| bar_length = int(f / max_flops_num * bar_length_max) | |||||
| i["bar"] = "#" * bar_length | |||||
| header = [ | |||||
| "name", | |||||
| "class_name", | |||||
| "input_shapes", | |||||
| "output_shapes", | |||||
| "flops", | |||||
| "flops_cum", | |||||
| "percentage", | |||||
| "bar", | |||||
| ] | |||||
| total_flops_str = sizeof_fmt(total_flops_num, suffix="OPs") | |||||
| total_var_size = sum(sum(s[1] for s in i["output_shapes"]) for i in flops) | |||||
| flops.append( | |||||
| dict(name="total", flops=total_flops_str, output_shapes=total_var_size) | |||||
| ) | |||||
| logger.info( | |||||
| "flops stats: \n" + tabulate.tabulate(dict2table(flops, header=header)) | |||||
| ) | |||||
| return total_flops_num | |||||
| def print_params_stats(params): | |||||
| total_param_dims, total_param_size = 0, 0 | |||||
| for d in params: | |||||
| total_param_dims += int(d["param_dim"]) | |||||
| total_param_size += int(d["size"]) | |||||
| d["size"] = sizeof_fmt(d["size"]) | |||||
| d["size_cum"] = sizeof_fmt(total_param_size) | |||||
| for d in params: | |||||
| ratio = d["param_dim"] / total_param_dims | |||||
| d["ratio"] = ratio | |||||
| d["percentage"] = "{:.2f}%".format(ratio * 100) | |||||
| # construct bar | |||||
| max_ratio = max([d["ratio"] for d in params]) | |||||
| for d in params: | |||||
| bar_length = int(d["ratio"] / max_ratio * bar_length_max) | |||||
| d["size_bar"] = "#" * bar_length | |||||
| param_size = sizeof_fmt(total_param_size) | |||||
| params.append(dict(name="total", param_dim=total_param_dims, size=param_size,)) | |||||
| header = [ | |||||
| "name", | |||||
| "shape", | |||||
| "mean", | |||||
| "std", | |||||
| "param_dim", | |||||
| "bits", | |||||
| "size", | |||||
| "size_cum", | |||||
| "percentage", | |||||
| "size_bar", | |||||
| ] | |||||
| logger.info( | |||||
| "param stats: \n" + tabulate.tabulate(dict2table(params, header=header)) | |||||
| ) | |||||
| return total_param_size | |||||
| def net_stats_hook(module, input, output, name=""): | def net_stats_hook(module, input, output, name=""): | ||||
| class_name = str(module.__class__).split(".")[-1].split("'")[0] | class_name = str(module.__class__).split(".")[-1].split("'")[0] | ||||
| @@ -273,8 +291,8 @@ def net_stats(model, input_size, bar_length_max=20, log_params=True, log_flops=T | |||||
| total_flops, total_params = 0, 0 | total_flops, total_params = 0, 0 | ||||
| if log_params: | if log_params: | ||||
| total_params = print_params_stats(params) | |||||
| total_params = print_params_stats(params, bar_length_max) | |||||
| if log_flops: | if log_flops: | ||||
| total_flops = print_flops_stats(flops) | |||||
| total_flops = print_flops_stats(flops, bar_length_max) | |||||
| return total_params, total_flops | return total_params, total_flops | ||||
| @@ -19,9 +19,9 @@ from ..core._imperative_rt import ComputingGraph | |||||
| from ..core.tensor import megbrain_graph as G | from ..core.tensor import megbrain_graph as G | ||||
| from .comp_graph_tools import get_dep_vars, get_opr_type, get_oprs_seq | from .comp_graph_tools import get_dep_vars, get_opr_type, get_oprs_seq | ||||
| from .network_node import ( | from .network_node import ( | ||||
| NetworkNode, | |||||
| Host2DeviceCopy, | Host2DeviceCopy, | ||||
| ImmutableTensor, | ImmutableTensor, | ||||
| NetworkNode, | |||||
| OpNode, | OpNode, | ||||
| VarNode, | VarNode, | ||||
| str_to_mge_class, | str_to_mge_class, | ||||
| @@ -606,9 +606,7 @@ class NodeFilterType(NodeFilter): | |||||
| _node_type = None | _node_type = None | ||||
| def __init__(self, node_iter, node_type): | def __init__(self, node_iter, node_type): | ||||
| assert issubclass(node_type, NetworkNode), "bad opr type: {}".format( | |||||
| node_type | |||||
| ) | |||||
| assert issubclass(node_type, NetworkNode), "bad opr type: {}".format(node_type) | |||||
| super().__init__(node_iter) | super().__init__(node_iter) | ||||
| self._node_type = node_type | self._node_type = node_type | ||||
| @@ -10,6 +10,8 @@ import json | |||||
| import sys | import sys | ||||
| from typing import Callable | from typing import Callable | ||||
| import numpy as np | |||||
| from ..core import _imperative_rt as rt | from ..core import _imperative_rt as rt | ||||
| from ..core._wrap import Device | from ..core._wrap import Device | ||||
| from ..core.ops import builtin | from ..core.ops import builtin | ||||
| @@ -52,7 +54,7 @@ class VarNode(NetworkNode): | |||||
| return self.var.dtype if self.var else None | return self.var.dtype if self.var else None | ||||
| def set_owner_opr(self, owner_opr): | def set_owner_opr(self, owner_opr): | ||||
| self.owner_opr = owner_opr | |||||
| self.owner = owner_opr | |||||
| class OpNode(NetworkNode): | class OpNode(NetworkNode): | ||||
| @@ -223,6 +225,9 @@ class Elemwise(OpNode): | |||||
| type = "Elemwise" | type = "Elemwise" | ||||
| opdef = builtin.Elemwise | opdef = builtin.Elemwise | ||||
| def calc_flops(self): | |||||
| return np.prod(self.outputs[0].shape) | |||||
| class Reduce(OpNode): | class Reduce(OpNode): | ||||
| type = "Reduce" | type = "Reduce" | ||||
| @@ -250,11 +255,21 @@ class MatrixMul(OpNode): | |||||
| type = "MatrixMul" | type = "MatrixMul" | ||||
| opdef = builtin.MatrixMul | opdef = builtin.MatrixMul | ||||
| def calc_flops(self): | |||||
| assert len(self.inputs[0].shape) == 2 and len(self.outputs[0].shape) == 2 | |||||
| mid_shape = self.inputs[0].shape[1] | |||||
| return np.prod(self.outputs[0].shape) * mid_shape | |||||
| class BatchedMatrixMul(OpNode): | class BatchedMatrixMul(OpNode): | ||||
| type = "BatchedMatmul" | type = "BatchedMatmul" | ||||
| opdef = builtin.BatchedMatrixMul | opdef = builtin.BatchedMatrixMul | ||||
| def calc_flops(self): | |||||
| assert len(self.inputs[0].shape) == 3 and len(self.outputs[0].shape) == 3 | |||||
| mid_shape = self.inputs[0].shape[2] | |||||
| return np.prod(self.outputs[0].shape) * mid_shape | |||||
| class Dot(OpNode): | class Dot(OpNode): | ||||
| type = "Dot" | type = "Dot" | ||||
| @@ -270,6 +285,18 @@ class ConvolutionForward(OpNode): | |||||
| type = "Convolution" | type = "Convolution" | ||||
| opdef = builtin.Convolution | opdef = builtin.Convolution | ||||
| def calc_flops(self): | |||||
| param_W_shape = self.inputs[1].shape | |||||
| kh = param_W_shape[-2] | |||||
| kw = param_W_shape[-1] | |||||
| if len(param_W_shape) == 5: | |||||
| num_input = param_W_shape[2] | |||||
| else: | |||||
| num_input = param_W_shape[1] | |||||
| NCHW = np.prod(self.outputs[0].shape) | |||||
| # N x Cout x H x W x (Cin x Kw x Kh) | |||||
| return NCHW * (num_input * kw * kh) | |||||
| class ConvolutionBackwardData(OpNode): | class ConvolutionBackwardData(OpNode): | ||||
| type = "ConvTranspose" | type = "ConvTranspose" | ||||
| @@ -316,6 +343,18 @@ class ConvBiasForward(OpNode): | |||||
| obj.params["dtype"] = opr.outputs[0].dtype | obj.params["dtype"] = opr.outputs[0].dtype | ||||
| return obj | return obj | ||||
| def calc_flops(self): | |||||
| param_W_shape = self.inputs[1].shape | |||||
| kh = param_W_shape[-2] | |||||
| kw = param_W_shape[-1] | |||||
| if len(param_W_shape) == 5: | |||||
| num_input = param_W_shape[2] | |||||
| else: | |||||
| num_input = param_W_shape[1] | |||||
| NCHW = np.prod(self.outputs[0].shape) | |||||
| # N x Cout x H x W x (Cin x Kw x Kh + bias) | |||||
| return NCHW * (num_input * kw * kh + 1) | |||||
| class BatchConvBiasForward(OpNode): | class BatchConvBiasForward(OpNode): | ||||
| type = "BatchConvBias" | type = "BatchConvBias" | ||||
| @@ -331,6 +370,7 @@ class BatchConvBiasForward(OpNode): | |||||
| class BatchNormForward(OpNode): | class BatchNormForward(OpNode): | ||||
| type = "BatchNorm" | type = "BatchNorm" | ||||
| opdef = builtin.BatchNorm | opdef = builtin.BatchNorm | ||||
| output_idx = -1 | |||||
| class ROIAlignForward(OpNode): | class ROIAlignForward(OpNode): | ||||
| @@ -622,6 +662,9 @@ class ElemwiseMultiType(OpNode): | |||||
| obj.params["dtype"] = opr.outputs[0].dtype | obj.params["dtype"] = opr.outputs[0].dtype | ||||
| return obj | return obj | ||||
| def calc_flops(self): | |||||
| return np.prod(self.outputs[0].shape) | |||||
| class CvtColorForward(OpNode): | class CvtColorForward(OpNode): | ||||
| type = "CvtColor" | type = "CvtColor" | ||||
| @@ -1,57 +0,0 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import struct | |||||
| import numpy as np | |||||
| def load_tensor_binary(fobj): | |||||
| """ | |||||
| Load a tensor dumped by the :class:`BinaryOprIODump` plugin; the actual | |||||
| tensor value dump is implemented by ``mgb::debug::dump_tensor``. | |||||
| Multiple values can be compared by ``tools/compare_binary_iodump.py``. | |||||
| :param fobj: file object, or a string that contains the file name. | |||||
| :return: tuple ``(tensor_value, tensor_name)``. | |||||
| """ | |||||
| if isinstance(fobj, str): | |||||
| with open(fobj, "rb") as fin: | |||||
| return load_tensor_binary(fin) | |||||
| DTYPE_LIST = { | |||||
| 0: np.float32, | |||||
| 1: np.uint8, | |||||
| 2: np.int8, | |||||
| 3: np.int16, | |||||
| 4: np.int32, | |||||
| # 5: _mgb.intb1, | |||||
| # 6: _mgb.intb2, | |||||
| # 7: _mgb.intb4, | |||||
| 8: None, | |||||
| 9: np.float16, | |||||
| # quantized dtype start from 100000 | |||||
| # see MEGDNN_PARAMETERIZED_DTYPE_ENUM_BASE in | |||||
| # dnn/include/megdnn/dtype.h | |||||
| 100000: np.uint8, | |||||
| 100001: np.int32, | |||||
| 100002: np.int8, | |||||
| } | |||||
| header_fmt = struct.Struct("III") | |||||
| name_len, dtype, max_ndim = header_fmt.unpack(fobj.read(header_fmt.size)) | |||||
| assert ( | |||||
| DTYPE_LIST[dtype] is not None | |||||
| ), "Cannot load this tensor: dtype Byte is unsupported." | |||||
| shape = list(struct.unpack("I" * max_ndim, fobj.read(max_ndim * 4))) | |||||
| while shape[-1] == 0: | |||||
| shape.pop(-1) | |||||
| name = fobj.read(name_len).decode("ascii") | |||||
| return np.fromfile(fobj, dtype=DTYPE_LIST[dtype]).reshape(shape), name | |||||