BREAKING CHANGE:
GitOrigin-RevId: cd2a1acd11
tags/v1.5.0
| @@ -12,6 +12,7 @@ import re | |||
| from collections import namedtuple | |||
| import numpy as np | |||
| from tqdm import tqdm | |||
| from megengine.core.tensor.dtype import is_quantize | |||
| from megengine.logger import _imperative_rt_logger, get_logger, set_mgb_log_level | |||
| @@ -37,10 +38,13 @@ logger = get_logger(__name__) | |||
| def visualize( | |||
| model_path: str, | |||
| log_path: str, | |||
| input: np.ndarray = None, | |||
| inp_dict: dict = None, | |||
| cal_params: bool = True, | |||
| cal_flops: bool = True, | |||
| cal_activations: bool = True, | |||
| logging_to_stdout: bool = True, | |||
| bar_length_max: int = 20, | |||
| log_params: bool = True, | |||
| log_flops: bool = True, | |||
| log_activations: bool = True, | |||
| ): | |||
| r""" | |||
| Load megengine dumped model and visualize graph structure with tensorboard log files. | |||
| @@ -48,10 +52,14 @@ def visualize( | |||
| :param model_path: dir path for megengine dumped model. | |||
| :param log_path: dir path for tensorboard graph log. | |||
| :param input: user defined input data for running model and calculating stats, alternative with inp_dict, used when the model has only one input. | |||
| :param inp_dict: input dict for running model and calculating stats, alternative with input, used when the model has more than one input. When both input and inp_dict are None, a random input will be used. | |||
| :param cal_params: whether calculate and record params size. | |||
| :param cal_flops: whether calculate and record op flops. | |||
| :param cal_activations: whether calculate and record op activations. | |||
| :param logging_to_stdout: whether print all calculated statistic details. | |||
| :param bar_length_max: size of bar indicating max flops or parameter size in net stats. | |||
| :param log_params: whether print and record params size. | |||
| :param log_flops: whether print and record op flops. | |||
| :param log_activations: whether print and record op activations. | |||
| """ | |||
| if log_path: | |||
| try: | |||
| @@ -78,6 +86,27 @@ def visualize( | |||
| enable_receptive_field() | |||
| graph = Network.load(model_path) | |||
| graph.reset_batch_size(1) | |||
| has_input = False | |||
| if input is not None or inp_dict is not None: | |||
| has_input = True | |||
| repl_dict = {} | |||
| inp_vars = graph.input_vars | |||
| if inp_dict is not None: | |||
| assert len(inp_dict) == len( | |||
| inp_vars | |||
| ), "Inputs are not sufficient for calculation." | |||
| for v in inp_vars: | |||
| new_input = graph.make_const(inp_dict[v.name], name=v.name) | |||
| repl_dict[v] = new_input | |||
| else: | |||
| assert len(inp_vars) == 1, "The graph needs more than one input." | |||
| inp_var = inp_vars[0] | |||
| repl_dict[inp_var] = graph.make_const(input, name=inp_var.name) | |||
| graph.replace_vars(repl_dict=repl_dict) | |||
| graph._compile() | |||
| def process_name(name): | |||
| # nodes that start with point or contain float const will lead to display bug | |||
| @@ -93,7 +122,7 @@ def visualize( | |||
| total_stats = namedtuple("total_stats", ["param_size", "flops", "act_size"]) | |||
| stats_details = namedtuple("module_stats", ["params", "flops", "activations"]) | |||
| for node in graph.all_oprs: | |||
| for node in tqdm(graph.all_oprs): | |||
| if hasattr(node, "output_idx"): | |||
| node_oup = node.outputs[node.output_idx] | |||
| else: | |||
| @@ -123,31 +152,35 @@ def visualize( | |||
| "params": AttrValue(s=str(node.params).encode(encoding="utf-8")), | |||
| "dtype": AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")), | |||
| } | |||
| flops_stats = get_op_stats(node, node.inputs, node.outputs) | |||
| if flops_stats is not None: | |||
| # add op flops attr | |||
| if log_path and hasattr(flops_stats, "flops_num"): | |||
| attr["flops"] = AttrValue( | |||
| s=sizeof_fmt(flops_stats["flops"]).encode(encoding="utf-8") | |||
| ) | |||
| flops_stats["name"] = node.name | |||
| flops_stats["class_name"] = node.type | |||
| flops_list.append(flops_stats) | |||
| acts = get_activation_stats(node_oup) | |||
| if cal_flops: | |||
| flops_stats = get_op_stats(node, node.inputs, node.outputs) | |||
| if flops_stats is not None: | |||
| # add op flops attr | |||
| if log_path and hasattr(flops_stats, "flops_num"): | |||
| attr["flops"] = AttrValue( | |||
| s=sizeof_fmt(flops_stats["flops"]).encode(encoding="utf-8") | |||
| ) | |||
| flops_stats["name"] = node.name | |||
| flops_stats["class_name"] = node.type | |||
| flops_list.append(flops_stats) | |||
| if cal_activations: | |||
| acts = get_activation_stats(node_oup.numpy(), has_input=has_input) | |||
| acts["name"] = node.name | |||
| acts["class_name"] = node.type | |||
| activations_list.append(acts) | |||
| if node.type == "ImmutableTensor": | |||
| param_stats = get_param_stats(node_oup) | |||
| # add tensor size attr | |||
| if log_path: | |||
| attr["size"] = AttrValue( | |||
| s=sizeof_fmt(param_stats["size"]).encode(encoding="utf-8") | |||
| ) | |||
| param_stats["name"] = node.name | |||
| params_list.append(param_stats) | |||
| if cal_params: | |||
| if node.type == "ImmutableTensor": | |||
| param_stats = get_param_stats(node.numpy()) | |||
| # add tensor size attr | |||
| if log_path: | |||
| attr["size"] = AttrValue( | |||
| s=sizeof_fmt(param_stats["size"]).encode(encoding="utf-8") | |||
| ) | |||
| param_stats["name"] = node.name | |||
| params_list.append(param_stats) | |||
| if log_path: | |||
| node_list.append( | |||
| @@ -169,31 +202,37 @@ def visualize( | |||
| total_param_dims, | |||
| total_param_size, | |||
| total_act_dims, | |||
| total_param_size, | |||
| total_act_size, | |||
| ) = (0, 0, 0, 0, 0) | |||
| total_param_dims, total_param_size, params = sum_param_stats( | |||
| params_list, bar_length_max | |||
| ) | |||
| extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="") | |||
| extra_info["total_param_size"] = sizeof_fmt(total_param_size) | |||
| if log_params: | |||
| print_param_stats(params) | |||
| total_flops, flops = sum_op_stats(flops_list, bar_length_max) | |||
| extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs") | |||
| if log_flops: | |||
| print_op_stats(flops) | |||
| total_act_dims, total_act_size, activations = sum_activations_stats( | |||
| activations_list, bar_length_max | |||
| ) | |||
| extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="") | |||
| extra_info["total_act_size"] = sizeof_fmt(total_act_size) | |||
| if log_activations: | |||
| print_activations_stats(activations) | |||
| if cal_params: | |||
| total_param_dims, total_param_size, params_list = sum_param_stats( | |||
| params_list, bar_length_max | |||
| ) | |||
| extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="") | |||
| extra_info["total_param_size"] = sizeof_fmt(total_param_size) | |||
| if logging_to_stdout: | |||
| print_param_stats(params_list) | |||
| extra_info["flops/param_size"] = "{:3.3f}".format(total_flops / total_param_size) | |||
| if cal_flops: | |||
| total_flops, flops_list = sum_op_stats(flops_list, bar_length_max) | |||
| extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs") | |||
| if logging_to_stdout: | |||
| print_op_stats(flops_list) | |||
| if cal_activations: | |||
| total_act_dims, total_act_size, activations_list = sum_activations_stats( | |||
| activations_list, bar_length_max | |||
| ) | |||
| extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="") | |||
| extra_info["total_act_size"] = sizeof_fmt(total_act_size) | |||
| if logging_to_stdout: | |||
| print_activations_stats(activations_list, has_input=has_input) | |||
| if cal_flops and cal_params: | |||
| extra_info["flops/param_size"] = "{:3.3f}".format( | |||
| total_flops / total_param_size | |||
| ) | |||
| if log_path: | |||
| graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22)) | |||
| @@ -211,7 +250,9 @@ def visualize( | |||
| total_stats( | |||
| param_size=total_param_size, flops=total_flops, act_size=total_act_size, | |||
| ), | |||
| stats_details(params=params, flops=flops, activations=activations), | |||
| stats_details( | |||
| params=params_list, flops=flops_list, activations=activations_list | |||
| ), | |||
| ) | |||
| @@ -229,12 +270,24 @@ def main(): | |||
| help="size of bar indicating max flops or parameter size in net stats.", | |||
| ) | |||
| parser.add_argument( | |||
| "--log_params", | |||
| "--cal_params", | |||
| action="store_true", | |||
| help="whether calculate and record params size.", | |||
| ) | |||
| parser.add_argument( | |||
| "--cal_flops", | |||
| action="store_true", | |||
| help="whether calculate and record op flops.", | |||
| ) | |||
| parser.add_argument( | |||
| "--cal_activations", | |||
| action="store_true", | |||
| help="whether print and record params size.", | |||
| help="whether calculate and record op activations.", | |||
| ) | |||
| parser.add_argument( | |||
| "--log_flops", action="store_true", help="whether print and record op flops.", | |||
| "--logging_to_stdout", | |||
| action="store_true", | |||
| help="whether print all calculated statistic details.", | |||
| ) | |||
| parser.add_argument( | |||
| "--all", | |||
| @@ -243,8 +296,10 @@ def main(): | |||
| ) | |||
| args = parser.parse_args() | |||
| if args.all: | |||
| args.log_params = True | |||
| args.log_flops = True | |||
| args.cal_params = True | |||
| args.cal_flops = True | |||
| args.cal_activations = True | |||
| args.logging_to_stdout = True | |||
| if not args.log_path: | |||
| args.log_path = "./log" | |||
| kwargs = vars(args) | |||
| @@ -5,8 +5,9 @@ | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from collections import namedtuple | |||
| from collections import Iterable, namedtuple | |||
| from functools import partial | |||
| from typing import Iterable | |||
| import numpy as np | |||
| import tabulate | |||
| @@ -19,6 +20,7 @@ from megengine import Tensor | |||
| from megengine import functional as F | |||
| from megengine.core.tensor.dtype import get_dtype_bit | |||
| from megengine.functional.tensor import zeros | |||
| from megengine.tensor import Tensor | |||
| from .module_utils import set_module_mode_safe | |||
| @@ -335,21 +337,23 @@ def print_param_stats(params): | |||
| ) | |||
| def get_activation_stats(output: Tensor): | |||
| def get_activation_stats(output: np.ndarray, has_input=False): | |||
| out_shape = output.shape | |||
| activations_dtype = np.dtype(output.dtype) | |||
| nbits = get_dtype_bit(activations_dtype.name) | |||
| act_dim = np.prod(out_shape) | |||
| act_size = act_dim * nbits // 8 | |||
| return { | |||
| activation_stats = { | |||
| "dtype": activations_dtype, | |||
| "shape": out_shape, | |||
| "act_dim": act_dim, | |||
| "mean": "{:.3g}".format(_mean(output)), | |||
| "std": "{:.3g}".format(_std(output)), | |||
| "nbits": nbits, | |||
| "size": act_size, | |||
| } | |||
| if has_input: | |||
| activation_stats["mean"] = "{:.3g}".format(output.mean()) | |||
| activation_stats["std"] = "{:.3g}".format(output.std()) | |||
| return activation_stats | |||
| def sum_activations_stats(activations, bar_length_max=20): | |||
| @@ -373,14 +377,12 @@ def sum_activations_stats(activations, bar_length_max=20): | |||
| return total_act_dims, total_act_size, activations | |||
| def print_activations_stats(activations): | |||
| def print_activations_stats(activations, has_input=False): | |||
| header = [ | |||
| "name", | |||
| "class_name", | |||
| "dtype", | |||
| "shape", | |||
| "mean", | |||
| "std", | |||
| "nbits", | |||
| "act_dim", | |||
| "size", | |||
| @@ -388,6 +390,9 @@ def print_activations_stats(activations): | |||
| "percentage", | |||
| "size_bar", | |||
| ] | |||
| if has_input: | |||
| header.insert(4, "mean") | |||
| header.insert(5, "std") | |||
| logger.info( | |||
| "activations stats: \n" | |||
| + tabulate.tabulate(dict2table(activations, header=header)) | |||
| @@ -402,56 +407,80 @@ def print_summary(**kwargs): | |||
| def module_stats( | |||
| model: m.Module, | |||
| input_shapes: list, | |||
| inputs: Iterable[np.ndarray] = None, | |||
| input_shapes: list = None, | |||
| cal_params: bool = True, | |||
| cal_flops: bool = True, | |||
| cal_activations: bool = True, | |||
| logging_to_stdout: bool = True, | |||
| bar_length_max: int = 20, | |||
| log_params: bool = True, | |||
| log_flops: bool = True, | |||
| log_activations: bool = True, | |||
| ): | |||
| r""" | |||
| Calculate and print ``model``'s statistics by adding hook and record Module's inputs outputs size. | |||
| :param model: model that need to get stats info. | |||
| :param input_shapes: shapes of inputs for running model and calculating stats. | |||
| :param inputs: user defined input data for running model and calculating stats, alternative with input_shapes. | |||
| :param input_shapes: shapes to generate random inputs for running model and calculating stats, alternative with inputs. | |||
| :param cal_params: whether calculate and record params size. | |||
| :param cal_flops: whether calculate and record op flops. | |||
| :param cal_activations: whether calculate and record op activations. | |||
| :param logging_to_stdout: whether print all calculated statistic details. | |||
| :param bar_length_max: size of bar indicating max flops or parameter size in net stats. | |||
| :param log_params: whether print and record params size. | |||
| :param log_flops: whether print and record op flops. | |||
| :param log_activations: whether print and record op activations. | |||
| """ | |||
| has_inputs = False | |||
| if inputs is not None: | |||
| has_inputs = True | |||
| if not isinstance(inputs, (tuple, list)): | |||
| inputs = [inputs] | |||
| inputs = [Tensor(input, dtype=np.float32) for input in inputs] | |||
| else: | |||
| if input_shapes: | |||
| if not isinstance(input_shapes[0], tuple): | |||
| input_shapes = [input_shapes] | |||
| inputs = [zeros(in_size, dtype=np.float32) for in_size in input_shapes] | |||
| else: | |||
| logger.error( | |||
| "Inputs or input_shapes is required for running model and calculating stats.", | |||
| exc_info=True, | |||
| ) | |||
| return | |||
| if not cal_activations: | |||
| log_activations = False | |||
| disable_receptive_field() | |||
| def module_stats_hook(module, inputs, outputs, name=""): | |||
| class_name = str(module.__class__).split(".")[-1].split("'")[0] | |||
| flops_stats = get_op_stats(module, inputs, outputs) | |||
| if flops_stats is not None: | |||
| flops_stats["name"] = name | |||
| flops_stats["class_name"] = class_name | |||
| flops.append(flops_stats) | |||
| if hasattr(module, "weight") and module.weight is not None: | |||
| w = module.weight | |||
| param_stats = get_param_stats(w) | |||
| param_stats["name"] = name + "-w" | |||
| params.append(param_stats) | |||
| if hasattr(module, "bias") and module.bias is not None: | |||
| b = module.bias | |||
| param_stats = get_param_stats(b) | |||
| param_stats["name"] = name + "-b" | |||
| params.append(param_stats) | |||
| if not isinstance(outputs, tuple) or not isinstance(outputs, list): | |||
| output = outputs | |||
| else: | |||
| output = outputs[0] | |||
| activation_stats = get_activation_stats(output) | |||
| activation_stats["name"] = name | |||
| activation_stats["class_name"] = class_name | |||
| activations.append(activation_stats) | |||
| # multiple inputs to the network | |||
| if not isinstance(input_shapes[0], tuple): | |||
| input_shapes = [input_shapes] | |||
| if cal_flops: | |||
| flops_stats = get_op_stats(module, inputs, outputs) | |||
| if flops_stats is not None: | |||
| flops_stats["name"] = name | |||
| flops_stats["class_name"] = class_name | |||
| flops.append(flops_stats) | |||
| if cal_params: | |||
| if hasattr(module, "weight") and module.weight is not None: | |||
| w = module.weight | |||
| param_stats = get_param_stats(w.numpy()) | |||
| param_stats["name"] = name + "-w" | |||
| params.append(param_stats) | |||
| if hasattr(module, "bias") and module.bias is not None: | |||
| b = module.bias | |||
| param_stats = get_param_stats(b.numpy()) | |||
| param_stats["name"] = name + "-b" | |||
| params.append(param_stats) | |||
| if cal_activations: | |||
| if not isinstance(outputs, (tuple, list)): | |||
| output = outputs.numpy() | |||
| else: | |||
| output = outputs[0].numpy() | |||
| activation_stats = get_activation_stats(output, has_inputs) | |||
| activation_stats["name"] = name | |||
| activation_stats["class_name"] = class_name | |||
| activations.append(activation_stats) | |||
| params = [] | |||
| flops = [] | |||
| @@ -466,7 +495,6 @@ def module_stats( | |||
| module.register_forward_hook(partial(module_stats_hook, name=name)) | |||
| ) | |||
| inputs = [zeros(in_size, dtype=np.float32) for in_size in input_shapes] | |||
| with set_module_mode_safe(model, training=False) as model: | |||
| model(*inputs) | |||
| @@ -481,29 +509,37 @@ def module_stats( | |||
| total_param_dims, | |||
| total_param_size, | |||
| total_act_dims, | |||
| total_param_size, | |||
| total_act_size, | |||
| ) = (0, 0, 0, 0, 0) | |||
| total_param_dims, total_param_size, params = sum_param_stats(params, bar_length_max) | |||
| extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="") | |||
| extra_info["total_param_size"] = sizeof_fmt(total_param_size) | |||
| if log_params: | |||
| print_param_stats(params) | |||
| total_flops, flops = sum_op_stats(flops, bar_length_max) | |||
| extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs") | |||
| if log_flops: | |||
| print_op_stats(flops) | |||
| total_act_dims, total_act_size, activations = sum_activations_stats( | |||
| activations, bar_length_max | |||
| ) | |||
| extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="") | |||
| extra_info["total_act_size"] = sizeof_fmt(total_act_size) | |||
| if log_activations: | |||
| print_activations_stats(activations) | |||
| extra_info["flops/param_size"] = "{:3.3f}".format(total_flops / total_param_size) | |||
| if cal_params: | |||
| total_param_dims, total_param_size, params = sum_param_stats( | |||
| params, bar_length_max | |||
| ) | |||
| extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="") | |||
| extra_info["total_param_size"] = sizeof_fmt(total_param_size) | |||
| if logging_to_stdout: | |||
| print_param_stats(params) | |||
| if cal_flops: | |||
| total_flops, flops = sum_op_stats(flops, bar_length_max) | |||
| extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs") | |||
| if logging_to_stdout: | |||
| print_op_stats(flops) | |||
| if cal_activations: | |||
| total_act_dims, total_act_size, activations = sum_activations_stats( | |||
| activations, bar_length_max | |||
| ) | |||
| extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="") | |||
| extra_info["total_act_size"] = sizeof_fmt(total_act_size) | |||
| if logging_to_stdout: | |||
| print_activations_stats(activations, has_inputs) | |||
| if cal_flops and cal_params: | |||
| extra_info["flops/param_size"] = "{:3.3f}".format( | |||
| total_flops / total_param_size | |||
| ) | |||
| print_summary(**extra_info) | |||
| @@ -18,11 +18,15 @@ from megengine.utils.module_stats import module_stats | |||
| def test_module_stats(): | |||
| net = ResNet(BasicBlock, [2, 2, 2, 2]) | |||
| input_shape = (1, 3, 224, 224) | |||
| total_stats, stats_details = module_stats(net, input_shape) | |||
| x1 = mge.tensor(np.zeros((1, 3, 224, 224))) | |||
| gt_flops, gt_acts = net.get_stats(x1) | |||
| total_stats, stats_details = module_stats(net, input_shapes=input_shape) | |||
| x1 = np.random.random((1, 3, 224, 224)).astype("float32") | |||
| gt_flops, gt_acts = net.get_stats(mge.tensor(x1)) | |||
| assert (total_stats.flops, stats_details.activations[-1]["act_dim"]) == ( | |||
| gt_flops, | |||
| gt_acts, | |||
| ) | |||
| total_stats, stats_details = module_stats(net, inputs=x1) | |||
| assert (total_stats.flops, stats_details.activations[-1]["act_dim"]) == ( | |||
| gt_flops, | |||
| gt_acts, | |||