BREAKING CHANGE:
GitOrigin-RevId: cd2a1acd11
tags/v1.5.0
| @@ -12,6 +12,7 @@ import re | |||||
| from collections import namedtuple | from collections import namedtuple | ||||
| import numpy as np | import numpy as np | ||||
| from tqdm import tqdm | |||||
| from megengine.core.tensor.dtype import is_quantize | from megengine.core.tensor.dtype import is_quantize | ||||
| from megengine.logger import _imperative_rt_logger, get_logger, set_mgb_log_level | from megengine.logger import _imperative_rt_logger, get_logger, set_mgb_log_level | ||||
| @@ -37,10 +38,13 @@ logger = get_logger(__name__) | |||||
| def visualize( | def visualize( | ||||
| model_path: str, | model_path: str, | ||||
| log_path: str, | log_path: str, | ||||
| input: np.ndarray = None, | |||||
| inp_dict: dict = None, | |||||
| cal_params: bool = True, | |||||
| cal_flops: bool = True, | |||||
| cal_activations: bool = True, | |||||
| logging_to_stdout: bool = True, | |||||
| bar_length_max: int = 20, | bar_length_max: int = 20, | ||||
| log_params: bool = True, | |||||
| log_flops: bool = True, | |||||
| log_activations: bool = True, | |||||
| ): | ): | ||||
| r""" | r""" | ||||
| Load megengine dumped model and visualize graph structure with tensorboard log files. | Load megengine dumped model and visualize graph structure with tensorboard log files. | ||||
| @@ -48,10 +52,14 @@ def visualize( | |||||
| :param model_path: dir path for megengine dumped model. | :param model_path: dir path for megengine dumped model. | ||||
| :param log_path: dir path for tensorboard graph log. | :param log_path: dir path for tensorboard graph log. | ||||
| :param input: user-defined input data for running the model and calculating stats; an alternative to inp_dict, used when the model has only one input. | |||||
| :param inp_dict: dict mapping input names to input data for running the model and calculating stats; an alternative to input, used when the model has more than one input. When both input and inp_dict are None, a random input will be used. | |||||
| :param cal_params: whether to calculate and record params size. | |||||
| :param cal_flops: whether to calculate and record op flops. | |||||
| :param cal_activations: whether to calculate and record op activations. | |||||
| :param logging_to_stdout: whether to print all calculated statistics details. | |||||
| :param bar_length_max: size of bar indicating max flops or parameter size in net stats. | :param bar_length_max: size of bar indicating max flops or parameter size in net stats. | ||||
| :param log_params: whether print and record params size. | |||||
| :param log_flops: whether print and record op flops. | |||||
| :param log_activations: whether print and record op activations. | |||||
| """ | """ | ||||
| if log_path: | if log_path: | ||||
| try: | try: | ||||
| @@ -78,6 +86,27 @@ def visualize( | |||||
| enable_receptive_field() | enable_receptive_field() | ||||
| graph = Network.load(model_path) | graph = Network.load(model_path) | ||||
| graph.reset_batch_size(1) | |||||
| has_input = False | |||||
| if input is not None or inp_dict is not None: | |||||
| has_input = True | |||||
| repl_dict = {} | |||||
| inp_vars = graph.input_vars | |||||
| if inp_dict is not None: | |||||
| assert len(inp_dict) == len( | |||||
| inp_vars | |||||
| ), "Inputs are not sufficient for calculation." | |||||
| for v in inp_vars: | |||||
| new_input = graph.make_const(inp_dict[v.name], name=v.name) | |||||
| repl_dict[v] = new_input | |||||
| else: | |||||
| assert len(inp_vars) == 1, "The graph needs more than one input." | |||||
| inp_var = inp_vars[0] | |||||
| repl_dict[inp_var] = graph.make_const(input, name=inp_var.name) | |||||
| graph.replace_vars(repl_dict=repl_dict) | |||||
| graph._compile() | |||||
| def process_name(name): | def process_name(name): | ||||
| # nodes that start with point or contain float const will lead to display bug | # nodes that start with point or contain float const will lead to display bug | ||||
| @@ -93,7 +122,7 @@ def visualize( | |||||
| total_stats = namedtuple("total_stats", ["param_size", "flops", "act_size"]) | total_stats = namedtuple("total_stats", ["param_size", "flops", "act_size"]) | ||||
| stats_details = namedtuple("module_stats", ["params", "flops", "activations"]) | stats_details = namedtuple("module_stats", ["params", "flops", "activations"]) | ||||
| for node in graph.all_oprs: | |||||
| for node in tqdm(graph.all_oprs): | |||||
| if hasattr(node, "output_idx"): | if hasattr(node, "output_idx"): | ||||
| node_oup = node.outputs[node.output_idx] | node_oup = node.outputs[node.output_idx] | ||||
| else: | else: | ||||
| @@ -123,31 +152,35 @@ def visualize( | |||||
| "params": AttrValue(s=str(node.params).encode(encoding="utf-8")), | "params": AttrValue(s=str(node.params).encode(encoding="utf-8")), | ||||
| "dtype": AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")), | "dtype": AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")), | ||||
| } | } | ||||
| flops_stats = get_op_stats(node, node.inputs, node.outputs) | |||||
| if flops_stats is not None: | |||||
| # add op flops attr | |||||
| if log_path and hasattr(flops_stats, "flops_num"): | |||||
| attr["flops"] = AttrValue( | |||||
| s=sizeof_fmt(flops_stats["flops"]).encode(encoding="utf-8") | |||||
| ) | |||||
| flops_stats["name"] = node.name | |||||
| flops_stats["class_name"] = node.type | |||||
| flops_list.append(flops_stats) | |||||
| acts = get_activation_stats(node_oup) | |||||
| if cal_flops: | |||||
| flops_stats = get_op_stats(node, node.inputs, node.outputs) | |||||
| if flops_stats is not None: | |||||
| # add op flops attr | |||||
| if log_path and hasattr(flops_stats, "flops_num"): | |||||
| attr["flops"] = AttrValue( | |||||
| s=sizeof_fmt(flops_stats["flops"]).encode(encoding="utf-8") | |||||
| ) | |||||
| flops_stats["name"] = node.name | |||||
| flops_stats["class_name"] = node.type | |||||
| flops_list.append(flops_stats) | |||||
| if cal_activations: | |||||
| acts = get_activation_stats(node_oup.numpy(), has_input=has_input) | |||||
| acts["name"] = node.name | acts["name"] = node.name | ||||
| acts["class_name"] = node.type | acts["class_name"] = node.type | ||||
| activations_list.append(acts) | activations_list.append(acts) | ||||
| if node.type == "ImmutableTensor": | |||||
| param_stats = get_param_stats(node_oup) | |||||
| # add tensor size attr | |||||
| if log_path: | |||||
| attr["size"] = AttrValue( | |||||
| s=sizeof_fmt(param_stats["size"]).encode(encoding="utf-8") | |||||
| ) | |||||
| param_stats["name"] = node.name | |||||
| params_list.append(param_stats) | |||||
| if cal_params: | |||||
| if node.type == "ImmutableTensor": | |||||
| param_stats = get_param_stats(node.numpy()) | |||||
| # add tensor size attr | |||||
| if log_path: | |||||
| attr["size"] = AttrValue( | |||||
| s=sizeof_fmt(param_stats["size"]).encode(encoding="utf-8") | |||||
| ) | |||||
| param_stats["name"] = node.name | |||||
| params_list.append(param_stats) | |||||
| if log_path: | if log_path: | ||||
| node_list.append( | node_list.append( | ||||
| @@ -169,31 +202,37 @@ def visualize( | |||||
| total_param_dims, | total_param_dims, | ||||
| total_param_size, | total_param_size, | ||||
| total_act_dims, | total_act_dims, | ||||
| total_param_size, | |||||
| total_act_size, | |||||
| ) = (0, 0, 0, 0, 0) | ) = (0, 0, 0, 0, 0) | ||||
| total_param_dims, total_param_size, params = sum_param_stats( | |||||
| params_list, bar_length_max | |||||
| ) | |||||
| extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="") | |||||
| extra_info["total_param_size"] = sizeof_fmt(total_param_size) | |||||
| if log_params: | |||||
| print_param_stats(params) | |||||
| total_flops, flops = sum_op_stats(flops_list, bar_length_max) | |||||
| extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs") | |||||
| if log_flops: | |||||
| print_op_stats(flops) | |||||
| total_act_dims, total_act_size, activations = sum_activations_stats( | |||||
| activations_list, bar_length_max | |||||
| ) | |||||
| extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="") | |||||
| extra_info["total_act_size"] = sizeof_fmt(total_act_size) | |||||
| if log_activations: | |||||
| print_activations_stats(activations) | |||||
| if cal_params: | |||||
| total_param_dims, total_param_size, params_list = sum_param_stats( | |||||
| params_list, bar_length_max | |||||
| ) | |||||
| extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="") | |||||
| extra_info["total_param_size"] = sizeof_fmt(total_param_size) | |||||
| if logging_to_stdout: | |||||
| print_param_stats(params_list) | |||||
| extra_info["flops/param_size"] = "{:3.3f}".format(total_flops / total_param_size) | |||||
| if cal_flops: | |||||
| total_flops, flops_list = sum_op_stats(flops_list, bar_length_max) | |||||
| extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs") | |||||
| if logging_to_stdout: | |||||
| print_op_stats(flops_list) | |||||
| if cal_activations: | |||||
| total_act_dims, total_act_size, activations_list = sum_activations_stats( | |||||
| activations_list, bar_length_max | |||||
| ) | |||||
| extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="") | |||||
| extra_info["total_act_size"] = sizeof_fmt(total_act_size) | |||||
| if logging_to_stdout: | |||||
| print_activations_stats(activations_list, has_input=has_input) | |||||
| if cal_flops and cal_params: | |||||
| extra_info["flops/param_size"] = "{:3.3f}".format( | |||||
| total_flops / total_param_size | |||||
| ) | |||||
| if log_path: | if log_path: | ||||
| graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22)) | graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22)) | ||||
| @@ -211,7 +250,9 @@ def visualize( | |||||
| total_stats( | total_stats( | ||||
| param_size=total_param_size, flops=total_flops, act_size=total_act_size, | param_size=total_param_size, flops=total_flops, act_size=total_act_size, | ||||
| ), | ), | ||||
| stats_details(params=params, flops=flops, activations=activations), | |||||
| stats_details( | |||||
| params=params_list, flops=flops_list, activations=activations_list | |||||
| ), | |||||
| ) | ) | ||||
| @@ -229,12 +270,24 @@ def main(): | |||||
| help="size of bar indicating max flops or parameter size in net stats.", | help="size of bar indicating max flops or parameter size in net stats.", | ||||
| ) | ) | ||||
| parser.add_argument( | parser.add_argument( | ||||
| "--log_params", | |||||
| "--cal_params", | |||||
| action="store_true", | |||||
| help="whether calculate and record params size.", | |||||
| ) | |||||
| parser.add_argument( | |||||
| "--cal_flops", | |||||
| action="store_true", | |||||
| help="whether calculate and record op flops.", | |||||
| ) | |||||
| parser.add_argument( | |||||
| "--cal_activations", | |||||
| action="store_true", | action="store_true", | ||||
| help="whether print and record params size.", | |||||
| help="whether calculate and record op activations.", | |||||
| ) | ) | ||||
| parser.add_argument( | parser.add_argument( | ||||
| "--log_flops", action="store_true", help="whether print and record op flops.", | |||||
| "--logging_to_stdout", | |||||
| action="store_true", | |||||
| help="whether print all calculated statistic details.", | |||||
| ) | ) | ||||
| parser.add_argument( | parser.add_argument( | ||||
| "--all", | "--all", | ||||
| @@ -243,8 +296,10 @@ def main(): | |||||
| ) | ) | ||||
| args = parser.parse_args() | args = parser.parse_args() | ||||
| if args.all: | if args.all: | ||||
| args.log_params = True | |||||
| args.log_flops = True | |||||
| args.cal_params = True | |||||
| args.cal_flops = True | |||||
| args.cal_activations = True | |||||
| args.logging_to_stdout = True | |||||
| if not args.log_path: | if not args.log_path: | ||||
| args.log_path = "./log" | args.log_path = "./log" | ||||
| kwargs = vars(args) | kwargs = vars(args) | ||||
| @@ -5,8 +5,9 @@ | |||||
| # Unless required by applicable law or agreed to in writing, | # Unless required by applicable law or agreed to in writing, | ||||
| # software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| from collections import namedtuple | |||||
| from collections import Iterable, namedtuple | |||||
| from functools import partial | from functools import partial | ||||
| from typing import Iterable | |||||
| import numpy as np | import numpy as np | ||||
| import tabulate | import tabulate | ||||
| @@ -19,6 +20,7 @@ from megengine import Tensor | |||||
| from megengine import functional as F | from megengine import functional as F | ||||
| from megengine.core.tensor.dtype import get_dtype_bit | from megengine.core.tensor.dtype import get_dtype_bit | ||||
| from megengine.functional.tensor import zeros | from megengine.functional.tensor import zeros | ||||
| from megengine.tensor import Tensor | |||||
| from .module_utils import set_module_mode_safe | from .module_utils import set_module_mode_safe | ||||
| @@ -335,21 +337,23 @@ def print_param_stats(params): | |||||
| ) | ) | ||||
| def get_activation_stats(output: Tensor): | |||||
| def get_activation_stats(output: np.ndarray, has_input=False): | |||||
| out_shape = output.shape | out_shape = output.shape | ||||
| activations_dtype = np.dtype(output.dtype) | activations_dtype = np.dtype(output.dtype) | ||||
| nbits = get_dtype_bit(activations_dtype.name) | nbits = get_dtype_bit(activations_dtype.name) | ||||
| act_dim = np.prod(out_shape) | act_dim = np.prod(out_shape) | ||||
| act_size = act_dim * nbits // 8 | act_size = act_dim * nbits // 8 | ||||
| return { | |||||
| activation_stats = { | |||||
| "dtype": activations_dtype, | "dtype": activations_dtype, | ||||
| "shape": out_shape, | "shape": out_shape, | ||||
| "act_dim": act_dim, | "act_dim": act_dim, | ||||
| "mean": "{:.3g}".format(_mean(output)), | |||||
| "std": "{:.3g}".format(_std(output)), | |||||
| "nbits": nbits, | "nbits": nbits, | ||||
| "size": act_size, | "size": act_size, | ||||
| } | } | ||||
| if has_input: | |||||
| activation_stats["mean"] = "{:.3g}".format(output.mean()) | |||||
| activation_stats["std"] = "{:.3g}".format(output.std()) | |||||
| return activation_stats | |||||
| def sum_activations_stats(activations, bar_length_max=20): | def sum_activations_stats(activations, bar_length_max=20): | ||||
| @@ -373,14 +377,12 @@ def sum_activations_stats(activations, bar_length_max=20): | |||||
| return total_act_dims, total_act_size, activations | return total_act_dims, total_act_size, activations | ||||
| def print_activations_stats(activations): | |||||
| def print_activations_stats(activations, has_input=False): | |||||
| header = [ | header = [ | ||||
| "name", | "name", | ||||
| "class_name", | "class_name", | ||||
| "dtype", | "dtype", | ||||
| "shape", | "shape", | ||||
| "mean", | |||||
| "std", | |||||
| "nbits", | "nbits", | ||||
| "act_dim", | "act_dim", | ||||
| "size", | "size", | ||||
| @@ -388,6 +390,9 @@ def print_activations_stats(activations): | |||||
| "percentage", | "percentage", | ||||
| "size_bar", | "size_bar", | ||||
| ] | ] | ||||
| if has_input: | |||||
| header.insert(4, "mean") | |||||
| header.insert(5, "std") | |||||
| logger.info( | logger.info( | ||||
| "activations stats: \n" | "activations stats: \n" | ||||
| + tabulate.tabulate(dict2table(activations, header=header)) | + tabulate.tabulate(dict2table(activations, header=header)) | ||||
| @@ -402,56 +407,80 @@ def print_summary(**kwargs): | |||||
| def module_stats( | def module_stats( | ||||
| model: m.Module, | model: m.Module, | ||||
| input_shapes: list, | |||||
| inputs: Iterable[np.ndarray] = None, | |||||
| input_shapes: list = None, | |||||
| cal_params: bool = True, | |||||
| cal_flops: bool = True, | |||||
| cal_activations: bool = True, | |||||
| logging_to_stdout: bool = True, | |||||
| bar_length_max: int = 20, | bar_length_max: int = 20, | ||||
| log_params: bool = True, | |||||
| log_flops: bool = True, | |||||
| log_activations: bool = True, | |||||
| ): | ): | ||||
| r""" | r""" | ||||
| Calculate and print ``model``'s statistics by adding hook and record Module's inputs outputs size. | Calculate and print ``model``'s statistics by adding hook and record Module's inputs outputs size. | ||||
| :param model: model that need to get stats info. | :param model: model that need to get stats info. | ||||
| :param input_shapes: shapes of inputs for running model and calculating stats. | |||||
| :param inputs: user-defined input data for running the model and calculating stats; an alternative to input_shapes. | |||||
| :param input_shapes: shapes used to generate inputs for running the model and calculating stats; an alternative to inputs. | |||||
| :param cal_params: whether to calculate and record params size. | |||||
| :param cal_flops: whether to calculate and record op flops. | |||||
| :param cal_activations: whether to calculate and record op activations. | |||||
| :param logging_to_stdout: whether to print all calculated statistics details. | |||||
| :param bar_length_max: size of bar indicating max flops or parameter size in net stats. | :param bar_length_max: size of bar indicating max flops or parameter size in net stats. | ||||
| :param log_params: whether print and record params size. | |||||
| :param log_flops: whether print and record op flops. | |||||
| :param log_activations: whether print and record op activations. | |||||
| """ | """ | ||||
| has_inputs = False | |||||
| if inputs is not None: | |||||
| has_inputs = True | |||||
| if not isinstance(inputs, (tuple, list)): | |||||
| inputs = [inputs] | |||||
| inputs = [Tensor(input, dtype=np.float32) for input in inputs] | |||||
| else: | |||||
| if input_shapes: | |||||
| if not isinstance(input_shapes[0], tuple): | |||||
| input_shapes = [input_shapes] | |||||
| inputs = [zeros(in_size, dtype=np.float32) for in_size in input_shapes] | |||||
| else: | |||||
| logger.error( | |||||
| "Inputs or input_shapes is required for running model and calculating stats.", | |||||
| exc_info=True, | |||||
| ) | |||||
| return | |||||
| if not cal_activations: | |||||
| log_activations = False | |||||
| disable_receptive_field() | disable_receptive_field() | ||||
| def module_stats_hook(module, inputs, outputs, name=""): | def module_stats_hook(module, inputs, outputs, name=""): | ||||
| class_name = str(module.__class__).split(".")[-1].split("'")[0] | class_name = str(module.__class__).split(".")[-1].split("'")[0] | ||||
| flops_stats = get_op_stats(module, inputs, outputs) | |||||
| if flops_stats is not None: | |||||
| flops_stats["name"] = name | |||||
| flops_stats["class_name"] = class_name | |||||
| flops.append(flops_stats) | |||||
| if hasattr(module, "weight") and module.weight is not None: | |||||
| w = module.weight | |||||
| param_stats = get_param_stats(w) | |||||
| param_stats["name"] = name + "-w" | |||||
| params.append(param_stats) | |||||
| if hasattr(module, "bias") and module.bias is not None: | |||||
| b = module.bias | |||||
| param_stats = get_param_stats(b) | |||||
| param_stats["name"] = name + "-b" | |||||
| params.append(param_stats) | |||||
| if not isinstance(outputs, tuple) or not isinstance(outputs, list): | |||||
| output = outputs | |||||
| else: | |||||
| output = outputs[0] | |||||
| activation_stats = get_activation_stats(output) | |||||
| activation_stats["name"] = name | |||||
| activation_stats["class_name"] = class_name | |||||
| activations.append(activation_stats) | |||||
| # multiple inputs to the network | |||||
| if not isinstance(input_shapes[0], tuple): | |||||
| input_shapes = [input_shapes] | |||||
| if cal_flops: | |||||
| flops_stats = get_op_stats(module, inputs, outputs) | |||||
| if flops_stats is not None: | |||||
| flops_stats["name"] = name | |||||
| flops_stats["class_name"] = class_name | |||||
| flops.append(flops_stats) | |||||
| if cal_params: | |||||
| if hasattr(module, "weight") and module.weight is not None: | |||||
| w = module.weight | |||||
| param_stats = get_param_stats(w.numpy()) | |||||
| param_stats["name"] = name + "-w" | |||||
| params.append(param_stats) | |||||
| if hasattr(module, "bias") and module.bias is not None: | |||||
| b = module.bias | |||||
| param_stats = get_param_stats(b.numpy()) | |||||
| param_stats["name"] = name + "-b" | |||||
| params.append(param_stats) | |||||
| if cal_activations: | |||||
| if not isinstance(outputs, (tuple, list)): | |||||
| output = outputs.numpy() | |||||
| else: | |||||
| output = outputs[0].numpy() | |||||
| activation_stats = get_activation_stats(output, has_inputs) | |||||
| activation_stats["name"] = name | |||||
| activation_stats["class_name"] = class_name | |||||
| activations.append(activation_stats) | |||||
| params = [] | params = [] | ||||
| flops = [] | flops = [] | ||||
| @@ -466,7 +495,6 @@ def module_stats( | |||||
| module.register_forward_hook(partial(module_stats_hook, name=name)) | module.register_forward_hook(partial(module_stats_hook, name=name)) | ||||
| ) | ) | ||||
| inputs = [zeros(in_size, dtype=np.float32) for in_size in input_shapes] | |||||
| with set_module_mode_safe(model, training=False) as model: | with set_module_mode_safe(model, training=False) as model: | ||||
| model(*inputs) | model(*inputs) | ||||
| @@ -481,29 +509,37 @@ def module_stats( | |||||
| total_param_dims, | total_param_dims, | ||||
| total_param_size, | total_param_size, | ||||
| total_act_dims, | total_act_dims, | ||||
| total_param_size, | |||||
| total_act_size, | |||||
| ) = (0, 0, 0, 0, 0) | ) = (0, 0, 0, 0, 0) | ||||
| total_param_dims, total_param_size, params = sum_param_stats(params, bar_length_max) | |||||
| extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="") | |||||
| extra_info["total_param_size"] = sizeof_fmt(total_param_size) | |||||
| if log_params: | |||||
| print_param_stats(params) | |||||
| total_flops, flops = sum_op_stats(flops, bar_length_max) | |||||
| extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs") | |||||
| if log_flops: | |||||
| print_op_stats(flops) | |||||
| total_act_dims, total_act_size, activations = sum_activations_stats( | |||||
| activations, bar_length_max | |||||
| ) | |||||
| extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="") | |||||
| extra_info["total_act_size"] = sizeof_fmt(total_act_size) | |||||
| if log_activations: | |||||
| print_activations_stats(activations) | |||||
| extra_info["flops/param_size"] = "{:3.3f}".format(total_flops / total_param_size) | |||||
| if cal_params: | |||||
| total_param_dims, total_param_size, params = sum_param_stats( | |||||
| params, bar_length_max | |||||
| ) | |||||
| extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="") | |||||
| extra_info["total_param_size"] = sizeof_fmt(total_param_size) | |||||
| if logging_to_stdout: | |||||
| print_param_stats(params) | |||||
| if cal_flops: | |||||
| total_flops, flops = sum_op_stats(flops, bar_length_max) | |||||
| extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs") | |||||
| if logging_to_stdout: | |||||
| print_op_stats(flops) | |||||
| if cal_activations: | |||||
| total_act_dims, total_act_size, activations = sum_activations_stats( | |||||
| activations, bar_length_max | |||||
| ) | |||||
| extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="") | |||||
| extra_info["total_act_size"] = sizeof_fmt(total_act_size) | |||||
| if logging_to_stdout: | |||||
| print_activations_stats(activations, has_inputs) | |||||
| if cal_flops and cal_params: | |||||
| extra_info["flops/param_size"] = "{:3.3f}".format( | |||||
| total_flops / total_param_size | |||||
| ) | |||||
| print_summary(**extra_info) | print_summary(**extra_info) | ||||
| @@ -18,11 +18,15 @@ from megengine.utils.module_stats import module_stats | |||||
| def test_module_stats(): | def test_module_stats(): | ||||
| net = ResNet(BasicBlock, [2, 2, 2, 2]) | net = ResNet(BasicBlock, [2, 2, 2, 2]) | ||||
| input_shape = (1, 3, 224, 224) | input_shape = (1, 3, 224, 224) | ||||
| total_stats, stats_details = module_stats(net, input_shape) | |||||
| x1 = mge.tensor(np.zeros((1, 3, 224, 224))) | |||||
| gt_flops, gt_acts = net.get_stats(x1) | |||||
| total_stats, stats_details = module_stats(net, input_shapes=input_shape) | |||||
| x1 = np.random.random((1, 3, 224, 224)).astype("float32") | |||||
| gt_flops, gt_acts = net.get_stats(mge.tensor(x1)) | |||||
| assert (total_stats.flops, stats_details.activations[-1]["act_dim"]) == ( | |||||
| gt_flops, | |||||
| gt_acts, | |||||
| ) | |||||
| total_stats, stats_details = module_stats(net, inputs=x1) | |||||
| assert (total_stats.flops, stats_details.activations[-1]["act_dim"]) == ( | assert (total_stats.flops, stats_details.activations[-1]["act_dim"]) == ( | ||||
| gt_flops, | gt_flops, | ||||
| gt_acts, | gt_acts, | ||||