GitOrigin-RevId: a1ab77c20a
tags/v1.3.0
@@ -0,0 +1,8 @@
# MegEngine Tools
This directory contains executable Python files.
Use these files in the following way (replace `xxx` with a specific file name, like `network_visualize`):
```
python -m megengine.tools.xxx
```
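For example, to export TensorBoard logs for a dumped model with the `network_visualize` tool (the two paths below are placeholders):
```
python -m megengine.tools.network_visualize /path/to/dumped_model.mge /path/to/log_dir
```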
@@ -1,3 +1,4 @@
#! /usr/bin/env python3
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -7,12 +8,55 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import argparse
import os
import struct
import textwrap
from pathlib import Path
import numpy as np
from megengine.utils import plugin
def load_tensor_binary(fobj):
"""
Load a tensor dumped by the :class:`BinaryOprIODump` plugin; the actual
tensor value dump is implemented by ``mgb::debug::dump_tensor``.
:param fobj: file object, or a string that contains the file name.
:return: tuple ``(tensor_value, tensor_name)``.
"""
if isinstance(fobj, str):
with open(fobj, "rb") as fin:
return load_tensor_binary(fin)
DTYPE_LIST = {
0: np.float32,
1: np.uint8,
2: np.int8,
3: np.int16,
4: np.int32,
# 5: _mgb.intb1,
# 6: _mgb.intb2,
# 7: _mgb.intb4,
8: None,
9: np.float16,
# quantized dtype start from 100000
# see MEGDNN_PARAMETERIZED_DTYPE_ENUM_BASE in
# dnn/include/megdnn/dtype.h
100000: np.uint8,
100001: np.int32,
100002: np.int8,
}
header_fmt = struct.Struct("III")
name_len, dtype, max_ndim = header_fmt.unpack(fobj.read(header_fmt.size))
assert (
DTYPE_LIST[dtype] is not None
), "Cannot load this tensor: dtype Byte is unsupported."
shape = list(struct.unpack("I" * max_ndim, fobj.read(max_ndim * 4)))
while shape[-1] == 0:
shape.pop(-1)
name = fobj.read(name_len).decode("ascii")
return np.fromfile(fobj, dtype=DTYPE_LIST[dtype]).reshape(shape), name
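A minimal usage sketch for this loader, assuming the script is importable as `megengine.tools.compare_binary_iodump` and that `io_dump/var0` is a placeholder path to a file written by `BinaryOprIODump`:
```
from megengine.tools.compare_binary_iodump import load_tensor_binary

# Load one dumped tensor and inspect it (path is a placeholder).
value, name = load_tensor_binary("io_dump/var0")
print(name, value.dtype, value.shape)
```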
def check(v0, v1, name, max_err):
@@ -26,9 +70,9 @@ def check(v0, v1, name, max_err):
)
vdiv = np.max([np.abs(v0), np.abs(v1), np.ones_like(v0)], axis=0)
err = np.abs(v0 - v1) / vdiv
check = err > max_err
if check.sum():
idx = tuple(i[0] for i in np.nonzero(check))
rst = err > max_err
if rst.sum():
idx = tuple(i[0] for i in np.nonzero(rst))
raise AssertionError(
"{} not equal: "
"shape={} nonequal_idx={} v0={} v1={} err={}".format(
@@ -79,8 +123,8 @@ def main():
files1 = sorted(files1)
for i, j in zip(files0, files1):
val0, name0 = plugin.load_tensor_binary(i)
val1, name1 = plugin.load_tensor_binary(j)
val0, name0 = load_tensor_binary(i)
val1, name1 = load_tensor_binary(j)
name = "{}: \n{}\n{}\n".format(
i, "\n ".join(textwrap.wrap(name0)), "\n ".join(textwrap.wrap(name1))
)
@@ -0,0 +1,176 @@
#! /usr/bin/env python3
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import argparse
import numpy as np
from megengine.core.tensor.dtype import is_quantize
from megengine.logger import get_logger
from megengine.utils.module_stats import (
print_flops_stats,
print_params_stats,
sizeof_fmt,
)
from megengine.utils.network import Network
logger = get_logger(__name__)
def visualize(
model_path: str,
log_path: str,
bar_length_max: int = 20,
log_params: bool = True,
log_flops: bool = True,
):
r"""
Load a dumped MegEngine model and visualize its graph structure via TensorBoard log files.
It can also record and print the model's statistics, like :func:`~.net_stats`.
:param model_path: path of the dumped MegEngine model.
:param log_path: dir path for the TensorBoard graph log.
:param bar_length_max: size of the bar indicating the max flops or parameter size in net stats.
:param log_params: whether to print and record params size.
:param log_flops: whether to print and record op flops.
"""
try:
from tensorboard.compat.proto.attr_value_pb2 import AttrValue
from tensorboard.compat.proto.config_pb2 import RunMetadata
from tensorboard.compat.proto.graph_pb2 import GraphDef
from tensorboard.compat.proto.node_def_pb2 import NodeDef
from tensorboard.compat.proto.step_stats_pb2 import (
AllocatorMemoryUsed,
DeviceStepStats,
NodeExecStats,
StepStats,
)
from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
from tensorboard.compat.proto.versions_pb2 import VersionDef
from tensorboardX import SummaryWriter
except ImportError:
logger.error(
"TensorBoard and TensorboardX are required for visualize.", exc_info=True
)
return
graph = Network.load(model_path)
writer = SummaryWriter(log_path)
def process_name(name):
return name.replace(".", "/").encode(encoding="utf-8")
node_list = []
flops_list = []
params_list = []
for node in graph.all_oprs:
if hasattr(node, "output_idx"):
node_oup = node.outputs[node.output_idx]
else:
if len(node.outputs) != 1:
logger.warning(
"OpNode {} has more than one output and no 'output_idx' attribute.".format(
node
)
)
node_oup = node.outputs[0]
inp_list = [process_name(var.owner.name) for var in node.inputs]
attr = {
"_output_shapes": AttrValue(
list=AttrValue.ListValue(
shape=[
TensorShapeProto(
dim=[TensorShapeProto.Dim(size=d) for d in node_oup.shape]
)
]
)
),
}
if hasattr(node, "calc_flops"):
flops_num = node.calc_flops()
# add op flops attr
attr["flops"] = AttrValue(s=sizeof_fmt(flops_num).encode(encoding="utf-8"))
flops_list.append(
dict(
name=node.name,
class_name=node.type,
input_shapes=[i.shape for i in node.inputs],
output_shapes=[o.shape for o in node.outputs],
flops_num=flops_num,
flops_cum=0,
)
)
if node.type == "ImmutableTensor":
param_dim = np.prod(node_oup.shape)
# TODO: consider other quantize dtypes
param_bytes = 1 if is_quantize(node_oup.dtype) else 4
# add tensor size attr
attr["size"] = AttrValue(
s=sizeof_fmt(param_dim * param_bytes).encode(encoding="utf-8")
)
params_list.append(
dict(
name=node.name,
shape=node_oup.shape,
param_dim=param_dim,
bits=param_bytes * 8,
size=param_dim * param_bytes,
size_cum=0,
mean="{:.2g}".format(node.numpy().mean()),
std="{:.2g}".format(node.numpy().std()),
)
)
node_list.append(
NodeDef(
name=process_name(node.name), op=node.type, input=inp_list, attr=attr,
)
)
total_flops, total_params = 0, 0
if log_params:
total_params = print_params_stats(params_list, bar_length_max)
if log_flops:
total_flops = print_flops_stats(flops_list, bar_length_max)
graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))
device = "/device:CPU:0"
stepstats = RunMetadata(
step_stats=StepStats(dev_stats=[DeviceStepStats(device=device)])
)
writer._get_file_writer().add_graph((graph_def, stepstats))
return total_params, total_flops
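The same export can be driven programmatically; a minimal sketch, assuming the tool is importable as `megengine.tools.network_visualize` and using placeholder paths:
```
from megengine.tools.network_visualize import visualize

# Writes TensorBoard logs for the dumped model and returns the totals.
total_params, total_flops = visualize("model.mge", "./tb_logs")
```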
def main():
parser = argparse.ArgumentParser(
description="load a dumped MegEngine model and export a log file for TensorBoard visualization.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("model_path", help="dumped model path.")
parser.add_argument("log_path", help="tensorboard log path.")
parser.add_argument(
"--bar_length_max",
type=int,
default=20,
help="size of bar indicating max flops or parameter size in net stats.",
)
parser.add_argument(
"--log_params",
action="store_true",
help="whether to print and record params size.",
)
parser.add_argument(
"--log_flops", action="store_true", help="whether to print and record op flops.",
)
visualize(**vars(parser.parse_args()))
if __name__ == "__main__":
main()
@@ -1,4 +1,4 @@
# -*- coding: utf-8 -*-
#! /usr/bin/env python3
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -84,26 +84,125 @@ hook_modules = (
)
def net_stats(model, input_size, bar_length_max=20, log_params=True, log_flops=True):
def dict2table(list_of_dict, header):
table_data = [header]
for d in list_of_dict:
row = []
for h in header:
v = ""
if h in d:
v = d[h]
row.append(v)
table_data.append(row)
return table_data
def sizeof_fmt(num, suffix="B"):
for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]:
if abs(num) < 1024.0:
return "{:3.3f} {}{}".format(num, unit, suffix)
num /= 1024.0
sign_str = "-" if num < 0 else ""
return "{}{:.1f} {}{}".format(sign_str, num, "Yi", suffix)
def dict2table(list_of_dict, header):
table_data = [header]
for d in list_of_dict:
row = []
for h in header:
v = ""
if h in d:
v = d[h]
row.append(v)
table_data.append(row)
return table_data
def sizeof_fmt(num, suffix="B"):
for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]:
if abs(num) < 1024.0:
return "{:3.3f} {}{}".format(num, unit, suffix)
num /= 1024.0
sign_str = "-" if num < 0 else ""
return "{}{:.1f} {}{}".format(sign_str, num, "Yi", suffix)
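A couple of illustrative values for this helper (1024-based units; `suffix` only tags the quantity), importable as shown by the `network_visualize` imports above:
```
from megengine.utils.module_stats import sizeof_fmt

assert sizeof_fmt(512) == "512.000 B"
assert sizeof_fmt(1536) == "1.500 KiB"
assert sizeof_fmt(1536, suffix="OPs") == "1.500 KiOPs"
```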
def print_flops_stats(flops, bar_length_max=20):
flops_list = [i["flops_num"] for i in flops]
max_flops_num = max(flops_list + [0])
# calc total flops and set flops_cum
total_flops_num = 0
for d in flops:
total_flops_num += int(d["flops_num"])
d["flops_cum"] = sizeof_fmt(total_flops_num, suffix="OPs")
for i in flops:
f = i["flops_num"]
i["flops"] = sizeof_fmt(f, suffix="OPs")
r = i["ratio"] = f / total_flops_num
i["percentage"] = "{:.2f}%".format(r * 100)
bar_length = int(f / max_flops_num * bar_length_max)
i["bar"] = "#" * bar_length
header = [
"name",
"class_name",
"input_shapes",
"output_shapes",
"flops",
"flops_cum",
"percentage",
"bar",
]
total_flops_str = sizeof_fmt(total_flops_num, suffix="OPs")
total_var_size = sum(sum(s[1] for s in i["output_shapes"]) for i in flops)
flops.append(
dict(name="total", flops=total_flops_str, output_shapes=total_var_size)
)
logger.info("flops stats: \n" + tabulate.tabulate(dict2table(flops, header=header)))
return total_flops_num
def print_params_stats(params, bar_length_max=20):
total_param_dims, total_param_size = 0, 0
for d in params:
total_param_dims += int(d["param_dim"])
total_param_size += int(d["size"])
d["size"] = sizeof_fmt(d["size"])
d["size_cum"] = sizeof_fmt(total_param_size)
for d in params:
ratio = d["param_dim"] / total_param_dims
d["ratio"] = ratio
d["percentage"] = "{:.2f}%".format(ratio * 100)
# construct bar
max_ratio = max([d["ratio"] for d in params])
for d in params:
bar_length = int(d["ratio"] / max_ratio * bar_length_max)
d["size_bar"] = "#" * bar_length
param_size = sizeof_fmt(total_param_size)
params.append(dict(name="total", param_dim=total_param_dims, size=param_size,))
header = [
"name",
"shape",
"mean",
"std",
"param_dim",
"bits",
"size",
"size_cum",
"percentage",
"size_bar",
]
logger.info(
"param stats: \n" + tabulate.tabulate(dict2table(params, header=header))
)
return total_param_size
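Both printers consume plain dicts, so they can also be fed hand-built entries; a minimal sketch with illustrative (not measured) numbers:
```
from megengine.utils.module_stats import print_flops_stats, print_params_stats

# One op entry and one parameter entry, using the keys the printers expect.
flops = [
    dict(
        name="conv1",
        class_name="Convolution",
        input_shapes=[(1, 3, 32, 32)],
        output_shapes=[(1, 8, 32, 32)],
        flops_num=221184,  # illustrative count
        flops_cum=0,
    )
]
params = [
    dict(
        name="conv1.weight",
        shape=(8, 3, 3, 3),
        param_dim=216,
        bits=32,
        size=864,  # param_dim * 4 bytes
        size_cum=0,
        mean="0",
        std="0.02",
    )
]
total_flops = print_flops_stats(flops, bar_length_max=20)
total_params = print_params_stats(params, bar_length_max=20)
```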
def net_stats(
model: m.Module,
input_size: int,
bar_length_max: int = 20,
log_params: bool = True,
log_flops: bool = True,
):
r"""
Calculate and print ``model``'s statistics by adding hooks that record each Module's input and output sizes.
:param model: model whose statistics are to be collected.
:param input_size: size of the input used to run the model and calculate stats.
:param bar_length_max: size of the bar indicating the max flops or parameter size in net stats.
:param log_params: whether to print and record params size.
:param log_flops: whether to print and record op flops.
"""
def get_byteswidth(tensor):
if dtype.is_quantize(tensor.dtype):
@@ -113,87 +212,6 @@ def net_stats(model, input_size, bar_length_max=20, log_params=True, log_flops=T
else:
return 4
def print_flops_stats(flops):
flops_list = [i["flops_num"] for i in flops]
max_flops_num = max(flops_list + [0])
# calc total flops and set flops_cum
total_flops_num = 0
for d in flops:
total_flops_num += int(d["flops_num"])
d["flops_cum"] = sizeof_fmt(total_flops_num, suffix="OPs")
for i in flops:
f = i["flops_num"]
i["flops"] = sizeof_fmt(f, suffix="OPs")
r = i["ratio"] = f / total_flops_num
i["percentage"] = "{:.2f}%".format(r * 100)
bar_length = int(f / max_flops_num * bar_length_max)
i["bar"] = "#" * bar_length
header = [
"name",
"class_name",
"input_shapes",
"output_shapes",
"flops",
"flops_cum",
"percentage",
"bar",
]
total_flops_str = sizeof_fmt(total_flops_num, suffix="OPs")
total_var_size = sum(sum(s[1] for s in i["output_shapes"]) for i in flops)
flops.append(
dict(name="total", flops=total_flops_str, output_shapes=total_var_size)
)
logger.info(
"flops stats: \n" + tabulate.tabulate(dict2table(flops, header=header))
)
return total_flops_num
def print_params_stats(params):
total_param_dims, total_param_size = 0, 0
for d in params:
total_param_dims += int(d["param_dim"])
total_param_size += int(d["size"])
d["size"] = sizeof_fmt(d["size"])
d["size_cum"] = sizeof_fmt(total_param_size)
for d in params:
ratio = d["param_dim"] / total_param_dims
d["ratio"] = ratio
d["percentage"] = "{:.2f}%".format(ratio * 100)
# construct bar
max_ratio = max([d["ratio"] for d in params])
for d in params:
bar_length = int(d["ratio"] / max_ratio * bar_length_max)
d["size_bar"] = "#" * bar_length
param_size = sizeof_fmt(total_param_size)
params.append(dict(name="total", param_dim=total_param_dims, size=param_size,))
header = [
"name",
"shape",
"mean",
"std",
"param_dim",
"bits",
"size",
"size_cum",
"percentage",
"size_bar",
]
logger.info(
"param stats: \n" + tabulate.tabulate(dict2table(params, header=header))
)
return total_param_size
def net_stats_hook(module, input, output, name=""):
class_name = str(module.__class__).split(".")[-1].split("'")[0]
@@ -273,8 +291,8 @@ def net_stats(model, input_size, bar_length_max=20, log_params=True, log_flops=T
total_flops, total_params = 0, 0
if log_params:
total_params = print_params_stats(params)
total_params = print_params_stats(params, bar_length_max)
if log_flops:
total_flops = print_flops_stats(flops)
total_flops = print_flops_stats(flops, bar_length_max)
return total_params, total_flops
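A minimal usage sketch for the refactored `net_stats`, assuming `input_size` is an `(N, C, H, W)` tuple (the tiny model is only for illustration):
```
import megengine.module as M
from megengine.utils.module_stats import net_stats

# Assumption: input_size takes the full input shape, e.g. (1, 3, 32, 32).
model = M.Sequential(
    M.Conv2d(3, 8, kernel_size=3, padding=1),
    M.ReLU(),
    M.Conv2d(8, 16, kernel_size=3, padding=1),
)
total_params, total_flops = net_stats(model, input_size=(1, 3, 32, 32))
```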
@@ -19,9 +19,9 @@ from ..core._imperative_rt import ComputingGraph
from ..core.tensor import megbrain_graph as G
from .comp_graph_tools import get_dep_vars, get_opr_type, get_oprs_seq
from .network_node import (
NetworkNode,
Host2DeviceCopy,
ImmutableTensor,
NetworkNode,
OpNode,
VarNode,
str_to_mge_class,
@@ -606,9 +606,7 @@ class NodeFilterType(NodeFilter):
_node_type = None
def __init__(self, node_iter, node_type):
assert issubclass(node_type, NetworkNode), "bad opr type: {}".format(
node_type
)
assert issubclass(node_type, NetworkNode), "bad opr type: {}".format(node_type)
super().__init__(node_iter)
self._node_type = node_type
@@ -10,6 +10,8 @@ import json
import sys
from typing import Callable
import numpy as np
from ..core import _imperative_rt as rt
from ..core._wrap import Device
from ..core.ops import builtin
@@ -52,7 +54,7 @@ class VarNode(NetworkNode):
return self.var.dtype if self.var else None
def set_owner_opr(self, owner_opr):
self.owner_opr = owner_opr
self.owner = owner_opr
class OpNode(NetworkNode):
@@ -223,6 +225,9 @@ class Elemwise(OpNode):
type = "Elemwise"
opdef = builtin.Elemwise
def calc_flops(self):
return np.prod(self.outputs[0].shape)
class Reduce(OpNode):
type = "Reduce"
@@ -250,11 +255,21 @@ class MatrixMul(OpNode):
type = "MatrixMul"
opdef = builtin.MatrixMul
def calc_flops(self):
assert len(self.inputs[0].shape) == 2 and len(self.outputs[0].shape) == 2
mid_shape = self.inputs[0].shape[1]
return np.prod(self.outputs[0].shape) * mid_shape
class BatchedMatrixMul(OpNode):
type = "BatchedMatmul"
opdef = builtin.BatchedMatrixMul
def calc_flops(self):
assert len(self.inputs[0].shape) == 3 and len(self.outputs[0].shape) == 3
mid_shape = self.inputs[0].shape[2]
return np.prod(self.outputs[0].shape) * mid_shape
class Dot(OpNode):
type = "Dot"
@@ -270,6 +285,18 @@ class ConvolutionForward(OpNode):
type = "Convolution"
opdef = builtin.Convolution
def calc_flops(self):
param_W_shape = self.inputs[1].shape
kh = param_W_shape[-2]
kw = param_W_shape[-1]
if len(param_W_shape) == 5:
num_input = param_W_shape[2]
else:
num_input = param_W_shape[1]
NCHW = np.prod(self.outputs[0].shape)
# N x Cout x H x W x (Cin x Kw x Kh)
return NCHW * (num_input * kw * kh)
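A worked instance of the formula in the comment above, with illustrative shapes (weight `(64, 3, 3, 3)`, output `(1, 64, 112, 112)`):
```
# N * Cout * H * W * (Cin * Kh * Kw)
flops = 1 * 64 * 112 * 112 * (3 * 3 * 3)  # = 21,676,032
```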
class ConvolutionBackwardData(OpNode):
type = "ConvTranspose"
@@ -316,6 +343,18 @@ class ConvBiasForward(OpNode):
obj.params["dtype"] = opr.outputs[0].dtype
return obj
def calc_flops(self):
param_W_shape = self.inputs[1].shape
kh = param_W_shape[-2]
kw = param_W_shape[-1]
if len(param_W_shape) == 5:
num_input = param_W_shape[2]
else:
num_input = param_W_shape[1]
NCHW = np.prod(self.outputs[0].shape)
# N x Cout x H x W x (Cin x Kw x Kh + bias)
return NCHW * (num_input * kw * kh + 1)
class BatchConvBiasForward(OpNode):
type = "BatchConvBias"
@@ -331,6 +370,7 @@ class BatchConvBiasForward(OpNode):
class BatchNormForward(OpNode):
type = "BatchNorm"
opdef = builtin.BatchNorm
output_idx = -1
class ROIAlignForward(OpNode):
@@ -622,6 +662,9 @@ class ElemwiseMultiType(OpNode):
obj.params["dtype"] = opr.outputs[0].dtype
return obj
def calc_flops(self):
return np.prod(self.outputs[0].shape)
class CvtColorForward(OpNode):
type = "CvtColor"
@@ -1,57 +0,0 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import struct
import numpy as np
def load_tensor_binary(fobj):
"""
Load a tensor dumped by the :class:`BinaryOprIODump` plugin; the actual
tensor value dump is implemented by ``mgb::debug::dump_tensor``.
Multiple values can be compared by ``tools/compare_binary_iodump.py``.
:param fobj: file object, or a string that contains the file name.
:return: tuple ``(tensor_value, tensor_name)``.
"""
if isinstance(fobj, str):
with open(fobj, "rb") as fin:
return load_tensor_binary(fin)
DTYPE_LIST = {
0: np.float32,
1: np.uint8,
2: np.int8,
3: np.int16,
4: np.int32,
# 5: _mgb.intb1,
# 6: _mgb.intb2,
# 7: _mgb.intb4,
8: None,
9: np.float16,
# quantized dtype start from 100000
# see MEGDNN_PARAMETERIZED_DTYPE_ENUM_BASE in
# dnn/include/megdnn/dtype.h
100000: np.uint8,
100001: np.int32,
100002: np.int8,
}
header_fmt = struct.Struct("III")
name_len, dtype, max_ndim = header_fmt.unpack(fobj.read(header_fmt.size))
assert (
DTYPE_LIST[dtype] is not None
), "Cannot load this tensor: dtype Byte is unsupported."
shape = list(struct.unpack("I" * max_ndim, fobj.read(max_ndim * 4)))
while shape[-1] == 0:
shape.pop(-1)
name = fobj.read(name_len).decode("ascii")
return np.fromfile(fobj, dtype=DTYPE_LIST[dtype]).reshape(shape), name