BREAKING CHANGE:
GitOrigin-RevId: ced3da3a12
tags/v1.4.0
@@ -9,6 +9,7 @@
import argparse
import logging
import re
from collections import namedtuple

import numpy as np

@@ -16,12 +17,17 @@ from megengine.core.tensor.dtype import is_quantize
from megengine.logger import _imperative_rt_logger, get_logger, set_mgb_log_level
from megengine.utils.module_stats import (
    enable_receptive_field,
    get_activation_stats,
    get_op_stats,
    get_param_stats,
    print_activations_stats,
    print_op_stats,
    print_param_stats,
    print_summary,
    sizeof_fmt,
    sum_activations_stats,
    sum_op_stats,
    sum_param_stats,
)
from megengine.utils.network import Network
@@ -34,6 +40,7 @@ def visualize(
    bar_length_max: int = 20,
    log_params: bool = True,
    log_flops: bool = True,
    log_activations: bool = True,
):
    r"""
    Load a MegEngine dumped model and visualize its graph structure with tensorboard log files.

@@ -44,6 +51,7 @@ def visualize(
    :param bar_length_max: size of the bar indicating max flops or parameter size in net stats.
    :param log_params: whether to print and record params size.
    :param log_flops: whether to print and record op flops.
    :param log_activations: whether to print and record op activations.
    """
    if log_path:
        try:
@@ -83,6 +91,10 @@ def visualize(
    node_list = []
    flops_list = []
    params_list = []
    activations_list = []
    total_stats = namedtuple("total_stats", ["param_size", "flops", "act_size"])
    stats_details = namedtuple("module_stats", ["params", "flops", "activations"])

    for node in graph.all_oprs:
        if hasattr(node, "output_idx"):
            node_oup = node.outputs[node.output_idx]

@@ -124,6 +136,11 @@ def visualize(
            flops_stats["class_name"] = node.type
            flops_list.append(flops_stats)

            acts = get_activation_stats(node_oup.numpy())
            acts["name"] = node.name
            acts["class_name"] = node.type
            activations_list.append(acts)

        if node.type == "ImmutableTensor":
            param_stats = get_param_stats(node.numpy())
            # add tensor size attr

@@ -149,20 +166,36 @@ def visualize(
        "#params": len(params_list),
    }
    total_flops, total_param_dims, total_param_size = 0, 0, 0
    (
        total_flops,
        total_param_dims,
        total_param_size,
        total_act_dims,
        total_act_size,
    ) = (0, 0, 0, 0, 0)
    total_param_dims, total_param_size, params = sum_param_stats(
        params_list, bar_length_max
    )
    extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="")
    extra_info["total_param_size"] = sizeof_fmt(total_param_size)
    if log_params:
        total_param_dims, total_param_size = print_param_stats(
            params_list, bar_length_max
        )
        extra_info["total_param_dims"] = sizeof_fmt(total_param_dims)
        extra_info["total_param_size"] = sizeof_fmt(total_param_size)
        print_param_stats(params)
    total_flops, flops = sum_op_stats(flops_list, bar_length_max)
    extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
    if log_flops:
        total_flops = print_op_stats(flops_list, bar_length_max)
        extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
    if log_params and log_flops:
        extra_info["flops/param_size"] = "{:3.3f}".format(
            total_flops / total_param_size
        )
        print_op_stats(flops)
    total_act_dims, total_act_size, activations = sum_activations_stats(
        activations_list, bar_length_max
    )
    extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="")
    extra_info["total_act_size"] = sizeof_fmt(total_act_size)
    if log_activations:
        print_activations_stats(activations)
    extra_info["flops/param_size"] = "{:3.3f}".format(total_flops / total_param_size)

    if log_path:
        graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))
@@ -179,7 +212,12 @@ def visualize(
    # FIXME: remove this after resolving "span dist too large" warning
    _imperative_rt_logger.set_log_level(old_level)

    return total_param_size, total_flops
    return (
        total_stats(
            param_size=total_param_size, flops=total_flops, act_size=total_act_size,
        ),
        stats_details(params=params, flops=flops, activations=activations),
    )
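With this change `visualize` returns two namedtuples instead of the old `(total_param_size, total_flops)` pair. A minimal usage sketch, assuming the tool is importable as megengine.tools.network_visualize and that the first positional argument is the path of the dumped model ("model.mge" and "./logs" are placeholders):

# Hedged sketch: import path, argument order and file paths are assumptions.
from megengine.tools.network_visualize import visualize

total_stats, stats_details = visualize("model.mge", log_path="./logs")
print(total_stats.param_size, total_stats.flops, total_stats.act_size)
print(len(stats_details.params), len(stats_details.flops), len(stats_details.activations))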
def main():
@@ -5,7 +5,7 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import contextlib
from collections import namedtuple
from functools import partial

import numpy as np

@@ -18,6 +18,8 @@ import megengine.module.quantized as qm
from megengine.core.tensor.dtype import get_dtype_bit
from megengine.functional.tensor import zeros

from .module_utils import set_module_mode_safe

try:
    mge.logger.MegEngineLogFormatter.max_lines = float("inf")
except AttributeError as e:
@@ -98,6 +100,27 @@ def flops_convNd(module: m.Conv2d, inputs, outputs):
    )


@register_flops(
    m.batchnorm._BatchNorm, m.SyncBatchNorm, m.GroupNorm, m.LayerNorm, m.InstanceNorm,
)
def flops_norm(module: m.Linear, inputs, outputs):
    return np.prod(inputs[0].shape) * 7


@register_flops(m.AvgPool2d, m.MaxPool2d)
def flops_pool(module: m.AvgPool2d, inputs, outputs):
    return np.prod(outputs[0].shape) * (module.kernel_size ** 2)


@register_flops(m.AdaptiveAvgPool2d, m.AdaptiveMaxPool2d)
def flops_adaptivePool(module: m.AdaptiveAvgPool2d, inputs, outputs):
    stride_h = np.floor(inputs[0].shape[2] / (inputs[0].shape[2] - 1))
    kernel_h = inputs[0].shape[2] - (inputs[0].shape[2] - 1) * stride_h
    stride_w = np.floor(inputs[0].shape[3] / (inputs[0].shape[3] - 1))
    kernel_w = inputs[0].shape[3] - (inputs[0].shape[3] - 1) * stride_w
    return np.prod(outputs[0].shape) * kernel_h * kernel_w


@register_flops(m.Linear)
def flops_linear(module: m.Linear, inputs, outputs):
    bias = module.out_features if module.bias is not None else 0
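The hooks above estimate FLOPs purely from module attributes and tensor shapes. A rough hand check of the norm and pooling formulas with arbitrary shapes (illustrative only, not part of the change):

import numpy as np

# Norm layers are counted as ~7 ops per input element (mean, var, normalize, scale, shift).
norm_in_shape = (1, 16, 8, 8)
norm_flops = np.prod(norm_in_shape) * 7                     # 1024 * 7 = 7168

# Plain pooling: one window of kernel_size**2 ops per output element.
pool_out_shape = (1, 16, 4, 4)
kernel_size = 2
pool_flops = np.prod(pool_out_shape) * kernel_size ** 2     # 256 * 4 = 1024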
@@ -120,6 +143,12 @@ hook_modules = (
    m.conv._ConvNd,
    m.Linear,
    m.BatchMatMulActivation,
    m.batchnorm._BatchNorm,
    m.LayerNorm,
    m.GroupNorm,
    m.InstanceNorm,
    m.pooling._PoolNd,
    m.adaptive_pooling._AdaptivePoolNd,
)
@@ -137,12 +166,16 @@ def dict2table(list_of_dict, header):
def sizeof_fmt(num, suffix="B"):
    for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]:
        if abs(num) < 1024.0:
    if suffix == "B":
        scale = 1024.0
        units = ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi", "Yi"]
    else:
        scale = 1000.0
        units = ["", "K", "M", "G", "T", "P", "E", "Z", "Y"]
    for unit in units:
        if abs(num) < scale or unit == units[-1]:
            return "{:3.3f} {}{}".format(num, unit, suffix)
        num /= 1024.0
    sign_str = "-" if num < 0 else ""
    return "{}{:.1f} {}{}".format(sign_str, num, "Yi", suffix)
        num /= scale
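The reworked `sizeof_fmt` picks 1024-based binary units for byte sizes (the default suffix "B") and 1000-based decimal units for plain counts such as OPs or dims; roughly:

# Illustrative calls against the new implementation above.
sizeof_fmt(4096)                      # -> "4.000 KiB"   (1024-based, suffix "B")
sizeof_fmt(1500000, suffix="OPs")     # -> "1.500 MOPs"  (1000-based for counts)
sizeof_fmt(123, suffix="")            # -> "123.000 "    (dimension counts, no suffix)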
def preprocess_receptive_field(module, inputs, outputs):

@@ -159,6 +192,8 @@ def preprocess_receptive_field(module, inputs, outputs):
def get_op_stats(module, inputs, outputs):
    if not isinstance(outputs, tuple) and not isinstance(outputs, list):
        outputs = (outputs,)

    rst = {
        "input_shapes": [i.shape for i in inputs],
        "output_shapes": [o.shape for o in outputs],

@@ -189,7 +224,7 @@ def get_op_stats(module, inputs, outputs):
    return
def print_op_stats(flops, bar_length_max=20):
def sum_op_stats(flops, bar_length_max=20):
    max_flops_num = max([i["flops_num"] for i in flops] + [0])
    total_flops_num = 0
    for d in flops:

@@ -203,6 +238,18 @@ def print_op_stats(flops, bar_length_max=20):
        d["bar"] = "#" * bar_length
        d["flops"] = sizeof_fmt(d["flops_num"], suffix="OPs")

    total_flops_str = sizeof_fmt(total_flops_num, suffix="OPs")
    total_var_size = sum(
        sum(s[1] if len(s) > 1 else 0 for s in d["output_shapes"]) for d in flops
    )
    flops.append(
        dict(name="total", flops=total_flops_str, output_shapes=total_var_size)
    )

    return total_flops_num, flops
def print_op_stats(flops):
    header = [
        "name",
        "class_name",

@@ -216,19 +263,8 @@ def print_op_stats(flops, bar_length_max=20):
    if _receptive_field_enabled:
        header.insert(4, "receptive_field")
        header.insert(5, "stride")

    total_flops_str = sizeof_fmt(total_flops_num, suffix="OPs")
    total_var_size = sum(
        sum(s[1] if len(s) > 1 else 0 for s in d["output_shapes"]) for d in flops
    )
    flops.append(
        dict(name="total", flops=total_flops_str, output_shapes=total_var_size)
    )
    logger.info("flops stats: \n" + tabulate.tabulate(dict2table(flops, header=header)))

    return total_flops_num
def get_param_stats(param: np.ndarray):
    nbits = get_dtype_bit(param.dtype.name)

@@ -246,7 +282,7 @@ def get_param_stats(param: np.ndarray):
    }


def print_param_stats(params, bar_length_max=20):
def sum_param_stats(params, bar_length_max=20):
    max_size = max([d["size"] for d in params] + [0])
    total_param_dims, total_param_size = 0, 0
    for d in params:

@@ -265,6 +301,10 @@ def print_param_stats(params, bar_length_max=20):
    param_size = sizeof_fmt(total_param_size)
    params.append(dict(name="total", param_dim=total_param_dims, size=param_size,))

    return total_param_dims, total_param_size, params
def print_param_stats(params):
    header = [
        "name",
        "dtype",

@@ -272,18 +312,74 @@ def print_param_stats(params, bar_length_max=20):
        "mean",
        "std",
        "param_dim",
        "bits",
        "nbits",
        "size",
        "size_cum",
        "percentage",
        "size_bar",
    ]
    logger.info(
        "param stats: \n" + tabulate.tabulate(dict2table(params, header=header))
    )

    return total_param_dims, total_param_size
def get_activation_stats(output: np.ndarray):
    out_shape = output.shape
    activations_dtype = output.dtype
    nbits = get_dtype_bit(activations_dtype.name)
    act_dim = np.prod(out_shape)
    act_size = act_dim * nbits // 8
    return {
        "dtype": activations_dtype,
        "shape": out_shape,
        "act_dim": act_dim,
        "mean": "{:.3g}".format(output.mean()),
        "std": "{:.3g}".format(output.std()),
        "nbits": nbits,
        "size": act_size,
    }
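`get_activation_stats` is fed the raw numpy output of an op; for a float32 feature map the reported size is simply the element count times 4 bytes. A small illustration (shape chosen arbitrarily):

import numpy as np

feat = np.zeros((1, 64, 56, 56), dtype=np.float32)   # float32 -> 32 bits per element
stats = get_activation_stats(feat)
# stats["act_dim"] == 1 * 64 * 56 * 56 == 200704
# stats["size"]    == 200704 * 32 // 8 == 802816 bytes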
def sum_activations_stats(activations, bar_length_max=20):
    max_act_size = max([i["size"] for i in activations] + [0])
    total_act_dims, total_act_size = 0, 0
    for d in activations:
        total_act_size += int(d["size"])
        total_act_dims += int(d["act_dim"])
        d["size_cum"] = sizeof_fmt(total_act_size)

    for d in activations:
        ratio = d["ratio"] = d["size"] / total_act_size
        d["percentage"] = "{:.2f}%".format(ratio * 100)
        bar_length = int(d["size"] / max_act_size * bar_length_max)
        d["size_bar"] = "#" * bar_length
        d["size"] = sizeof_fmt(d["size"])

    act_size = sizeof_fmt(total_act_size)
    activations.append(dict(name="total", act_dim=total_act_dims, size=act_size,))

    return total_act_dims, total_act_size, activations
def print_activations_stats(activations):
    header = [
        "name",
        "class_name",
        "dtype",
        "shape",
        "mean",
        "std",
        "nbits",
        "act_dim",
        "size",
        "size_cum",
        "percentage",
        "size_bar",
    ]
    logger.info(
        "activations stats: \n"
        + tabulate.tabulate(dict2table(activations, header=header))
    )


def print_summary(**kwargs):
@@ -294,25 +390,26 @@ def print_summary(**kwargs):
def module_stats(
    model: m.Module,
    input_size: int,
    input_shapes: list,
    bar_length_max: int = 20,
    log_params: bool = True,
    log_flops: bool = True,
    log_activations: bool = True,
):
    r"""
    Calculate and print ``model``'s statistics by adding hooks and recording the modules' input and output sizes.

    :param model: model that needs to get stats info.
    :param input_size: size of input for running model and calculating stats.
    :param input_shapes: shapes of inputs for running model and calculating stats.
    :param bar_length_max: size of the bar indicating max flops or parameter size in net stats.
    :param log_params: whether to print and record params size.
    :param log_flops: whether to print and record op flops.
    :param log_activations: whether to print and record op activations.
    """
    disable_receptive_field()

    def module_stats_hook(module, inputs, outputs, name=""):
        class_name = str(module.__class__).split(".")[-1].split("'")[0]

        flops_stats = get_op_stats(module, inputs, outputs)
        if flops_stats is not None:
            flops_stats["name"] = name
@@ -331,38 +428,25 @@ def module_stats(
            param_stats["name"] = name + "-b"
            params.append(param_stats)

    @contextlib.contextmanager
    def adjust_stats(module, training=False):
        """Adjust module to training/eval mode temporarily.

        Args:
            module (M.Module): used module.
            training (bool): training mode. True for train mode, False for eval mode.
        """

        def recursive_backup_stats(module, mode):
            for m in module.modules():
                # save prev status to _prev_training
                m._prev_training = m.training
                m.train(mode, recursive=False)

        def recursive_recover_stats(module):
            for m in module.modules():
                # recover prev status and delete attribute
                m.training = m._prev_training
                delattr(m, "_prev_training")

        recursive_backup_stats(module, mode=training)
        yield module
        recursive_recover_stats(module)

        if not isinstance(outputs, tuple) and not isinstance(outputs, list):
            output = outputs.numpy()
        else:
            output = outputs[0].numpy()
        activation_stats = get_activation_stats(output)
        activation_stats["name"] = name
        activation_stats["class_name"] = class_name
        activations.append(activation_stats)

    # multiple inputs to the network
    if not isinstance(input_size[0], tuple):
        input_size = [input_size]
    if not isinstance(input_shapes[0], tuple):
        input_shapes = [input_shapes]

    params = []
    flops = []
    hooks = []
    activations = []
    total_stats = namedtuple("total_stats", ["param_size", "flops", "act_size"])
    stats_details = namedtuple("module_stats", ["params", "flops", "activations"])

    for (name, module) in model.named_modules():
        if isinstance(module, hook_modules):
@@ -370,8 +454,8 @@ def module_stats(
                module.register_forward_hook(partial(module_stats_hook, name=name))
            )

    inputs = [zeros(in_size, dtype=np.float32) for in_size in input_size]
    with adjust_stats(model, training=False) as model:
    inputs = [zeros(in_size, dtype=np.float32) for in_size in input_shapes]
    with set_module_mode_safe(model, training=False) as model:
        model(*inputs)

    for h in hooks:

@@ -380,19 +464,40 @@ def module_stats(
    extra_info = {
        "#params": len(params),
    }
    total_flops, total_param_dims, total_param_size = 0, 0, 0
    (
        total_flops,
        total_param_dims,
        total_param_size,
        total_act_dims,
        total_act_size,
    ) = (0, 0, 0, 0, 0)
    total_param_dims, total_param_size, params = sum_param_stats(params, bar_length_max)
    extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="")
    extra_info["total_param_size"] = sizeof_fmt(total_param_size)
    if log_params:
        total_param_dims, total_param_size = print_param_stats(params, bar_length_max)
        extra_info["total_param_dims"] = sizeof_fmt(total_param_dims)
        extra_info["total_param_size"] = sizeof_fmt(total_param_size)
        print_param_stats(params)
    total_flops, flops = sum_op_stats(flops, bar_length_max)
    extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
    if log_flops:
        total_flops = print_op_stats(flops, bar_length_max)
        extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
    if log_params and log_flops:
        extra_info["flops/param_size"] = "{:3.3f}".format(
            total_flops / total_param_size
        )
        print_op_stats(flops)
    total_act_dims, total_act_size, activations = sum_activations_stats(
        activations, bar_length_max
    )
    extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="")
    extra_info["total_act_size"] = sizeof_fmt(total_act_size)
    if log_activations:
        print_activations_stats(activations)
    extra_info["flops/param_size"] = "{:3.3f}".format(total_flops / total_param_size)

    print_summary(**extra_info)

    return total_param_size, total_flops
    return (
        total_stats(
            param_size=total_param_size, flops=total_flops, act_size=total_act_size,
        ),
        stats_details(params=params, flops=flops, activations=activations),
    )
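`module_stats` now takes `input_shapes` instead of `input_size` and returns the same `(total_stats, stats_details)` pair as the visualize tool. A minimal sketch with a toy module (the Conv2d here is only an example):

import megengine.module as M
from megengine.utils.module_stats import module_stats

net = M.Conv2d(3, 8, kernel_size=3)                 # any Module covered by hook_modules works
total_stats, stats_details = module_stats(net, (1, 3, 32, 32))
print(total_stats.param_size, total_stats.flops, total_stats.act_size)
print(stats_details.params[-1], stats_details.flops[-1])   # "total" summary rows appended by sum_*_stats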
@@ -5,6 +5,7 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import contextlib
from collections import Iterable

from ..module import Sequential

@@ -41,3 +42,28 @@ def set_expand_structure(obj: Module, key: str, value):
            parent[key] = value

    _access_structure(obj, key, callback=f)


@contextlib.contextmanager
def set_module_mode_safe(
    module: Module, training: bool = False,
):
    """Adjust module to training/eval mode temporarily.

    :param module: used module.
    :param training: training mode. True for train mode, False for eval mode.
    """
    backup_stats = {}

    def recursive_backup_stats(module, mode):
        for m in module.modules():
            backup_stats[m] = m.training
            m.train(mode, recursive=False)

    def recursive_recover_stats(module):
        for m in module.modules():
            m.training = backup_stats.pop(m)

    recursive_backup_stats(module, mode=training)
    yield module
    recursive_recover_stats(module)
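`set_module_mode_safe` replaces the private `adjust_stats` helper that used to live in module_stats.py: it forces every submodule into the requested mode and restores each one's original `training` flag on exit. A usage sketch, assuming the helper is importable as megengine.utils.module_utils (per the relative import `from .module_utils import set_module_mode_safe` above):

import megengine.module as M
from megengine.utils.module_utils import set_module_mode_safe   # assumed import path

net = M.BatchNorm2d(8)
net.train()
with set_module_mode_safe(net, training=False) as m:
    assert not m.training        # temporarily forced into eval mode
assert net.training              # original mode restored on exit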
@@ -0,0 +1,377 @@
import math
from copy import deepcopy

import numpy as np
import pytest

import megengine as mge
import megengine.functional as F
import megengine.hub as hub
import megengine.module as M
from megengine.core._trace_option import use_symbolic_shape
from megengine.utils.module_stats import module_stats


@pytest.mark.skipif(
    use_symbolic_shape(), reason="This test does not support symbolic shape.",
)
def test_module_stats():
    net = ResNet(BasicBlock, [2, 2, 2, 2])

    input_shape = (1, 3, 224, 224)
    total_stats, stats_details = module_stats(net, input_shape)
    x1 = mge.tensor(np.zeros((1, 3, 224, 224)))
    gt_flops, gt_acts = net.get_stats(x1)

    assert (total_stats.flops, stats_details.activations[-1]["act_dim"]) == (
        gt_flops,
        gt_acts,
    )
class BasicBlock(M.Module):
    expansion = 1

    def __init__(
        self,
        in_channels,
        channels,
        stride=1,
        groups=1,
        base_width=64,
        dilation=1,
        norm=M.BatchNorm2d,
    ):
        super().__init__()
        self.tmp_in_channels = in_channels
        self.tmp_channels = channels
        self.stride = stride

        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")

        self.conv1 = M.Conv2d(
            in_channels, channels, 3, stride, padding=dilation, bias=False
        )
        self.bn1 = norm(channels)
        self.conv2 = M.Conv2d(channels, channels, 3, 1, padding=1, bias=False)
        self.bn2 = norm(channels)
        self.downsample_id = M.Identity()
        self.downsample_conv = M.Conv2d(in_channels, channels, 1, stride, bias=False)
        self.downsample_norm = norm(channels)

    def forward(self, x):
        identity = x
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        if self.tmp_in_channels == self.tmp_channels and self.stride == 1:
            identity = self.downsample_id(identity)
        else:
            identity = self.downsample_conv(identity)
            identity = self.downsample_norm(identity)
        x += identity
        x = F.relu(x)
        return x
    def get_stats(self, x):
        activations, flops = 0, 0

        identity = x

        in_x = deepcopy(x)
        x = self.conv1(x)
        tmp_flops, tmp_acts = cal_conv_stats(self.conv1, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        in_x = deepcopy(x)
        x = self.bn1(x)
        tmp_flops, tmp_acts = cal_norm_stats(self.bn1, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        x = F.relu(x)

        in_x = deepcopy(x)
        x = self.conv2(x)
        tmp_flops, tmp_acts = cal_conv_stats(self.conv2, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        in_x = deepcopy(x)
        x = self.bn2(x)
        tmp_flops, tmp_acts = cal_norm_stats(self.bn2, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        if self.tmp_in_channels == self.tmp_channels and self.stride == 1:
            identity = self.downsample_id(identity)
        else:
            in_x = deepcopy(identity)
            identity = self.downsample_conv(identity)
            tmp_flops, tmp_acts = cal_conv_stats(self.downsample_conv, in_x, identity)
            activations += tmp_acts
            flops += tmp_flops

            in_x = deepcopy(identity)
            identity = self.downsample_norm(identity)
            tmp_flops, tmp_acts = cal_norm_stats(self.downsample_norm, in_x, identity)
            activations += tmp_acts
            flops += tmp_flops

        x += identity
        x = F.relu(x)
        return x, flops, activations
class ResNet(M.Module):
    def __init__(
        self,
        block,
        layers=[2, 2, 2, 2],
        num_classes=1000,
        zero_init_residual=False,
        groups=1,
        width_per_group=64,
        replace_stride_with_dilation=None,
        norm=M.BatchNorm2d,
    ):
        super().__init__()
        self.in_channels = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                "or a 3-element tuple, got {}".format(replace_stride_with_dilation)
            )
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = M.Conv2d(
            3, self.in_channels, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn1 = norm(self.in_channels)
        self.maxpool = M.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1_0 = BasicBlock(
            self.in_channels,
            64,
            stride=1,
            groups=self.groups,
            base_width=self.base_width,
            dilation=self.dilation,
            norm=M.BatchNorm2d,
        )
        self.layer1_1 = BasicBlock(
            self.in_channels,
            64,
            stride=1,
            groups=self.groups,
            base_width=self.base_width,
            dilation=self.dilation,
            norm=M.BatchNorm2d,
        )
        self.layer2_0 = BasicBlock(64, 128, stride=2)
        self.layer2_1 = BasicBlock(128, 128)
        self.layer3_0 = BasicBlock(128, 256, stride=2)
        self.layer3_1 = BasicBlock(256, 256)
        self.layer4_0 = BasicBlock(256, 512, stride=2)
        self.layer4_1 = BasicBlock(512, 512)
        self.layer1 = self._make_layer(block, 64, layers[0], norm=norm)
        self.layer2 = self._make_layer(
            block, 128, 2, stride=2, dilate=replace_stride_with_dilation[0], norm=norm
        )
        self.layer3 = self._make_layer(
            block, 256, 2, stride=2, dilate=replace_stride_with_dilation[1], norm=norm
        )
        self.layer4 = self._make_layer(
            block, 512, 2, stride=2, dilate=replace_stride_with_dilation[2], norm=norm
        )
        self.fc = M.Linear(512, num_classes)

        for m in self.modules():
            if isinstance(m, M.Conv2d):
                M.init.msra_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    fan_in, _ = M.init.calculate_fan_in_and_fan_out(m.weight)
                    bound = 1 / math.sqrt(fan_in)
                    M.init.uniform_(m.bias, -bound, bound)
            elif isinstance(m, M.BatchNorm2d):
                M.init.ones_(m.weight)
                M.init.zeros_(m.bias)
            elif isinstance(m, M.Linear):
                M.init.msra_uniform_(m.weight, a=math.sqrt(5))
                if m.bias is not None:
                    fan_in, _ = M.init.calculate_fan_in_and_fan_out(m.weight)
                    bound = 1 / math.sqrt(fan_in)
                    M.init.uniform_(m.bias, -bound, bound)

        if zero_init_residual:
            for m in self.modules():
                M.init.zeros_(m.bn2.weight)
    def _make_layer(
        self, block, channels, blocks, stride=1, dilate=False, norm=M.BatchNorm2d
    ):
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1

        layers = []
        layers.append(
            block(
                self.in_channels,
                channels,
                stride,
                groups=self.groups,
                base_width=self.base_width,
                dilation=previous_dilation,
                norm=norm,
            )
        )
        self.in_channels = channels * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.in_channels,
                    channels,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm=norm,
                )
            )

        return M.Sequential(*layers)
    def extract_features(self, x):
        outputs = {}
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.maxpool(x)
        outputs["stem"] = x

        x = self.layer1(x)
        outputs["res2"] = x
        x = self.layer2(x)
        outputs["res3"] = x
        x = self.layer3(x)
        outputs["res4"] = x
        x = self.layer4(x)
        outputs["res5"] = x
        return outputs

    def forward(self, x):
        x = self.extract_features(x)["res5"]

        x = F.avg_pool2d(x, 7)
        x = F.flatten(x, 1)
        x = self.fc(x)

        return x
    def get_stats(self, x):
        flops, activations = 0, 0

        in_x = deepcopy(x)
        x = self.conv1(x)
        tmp_flops, tmp_acts = cal_conv_stats(self.conv1, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        in_x = deepcopy(x)
        x = self.bn1(x)
        tmp_flops, tmp_acts = cal_norm_stats(self.bn1, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        x = F.relu(x)

        in_x = deepcopy(x)
        x = self.maxpool(x)
        tmp_flops, tmp_acts = cal_pool_stats(self.maxpool, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer1_0.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer1_1.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer2_0.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer2_1.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer3_0.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer3_1.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer4_0.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer4_1.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x = F.avg_pool2d(x, 7)
        x = F.flatten(x, 1)

        in_x = deepcopy(x)
        x = self.fc(x)
        tmp_flops, tmp_acts = cal_linear_stats(self.fc, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        return flops, activations
def cal_conv_stats(module, input, output):
    bias = 1 if module.bias is not None else 0
    flops = np.prod(output[0].shape) * (
        module.in_channels // module.groups * np.prod(module.kernel_size) + bias
    )
    acts = np.prod(output[0].shape)
    return flops, acts


def cal_norm_stats(module, input, output):
    return np.prod(input[0].shape) * 7, np.prod(output[0].shape)


def cal_linear_stats(module, inputs, outputs):
    bias = module.out_features if module.bias is not None else 0
    return (
        np.prod(outputs[0].shape) * module.in_features + bias,
        np.prod(outputs[0].shape),
    )


def cal_pool_stats(module, inputs, outputs):
    return (
        np.prod(outputs[0].shape) * (module.kernel_size ** 2),
        np.prod(outputs[0].shape),
    )
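As a quick hand check of the ground-truth helpers above, the ResNet stem conv (3->64, 7x7 kernel, stride 2, padding 3, no bias) applied to a 224x224 input gives (numbers only illustrate the formulas):

import numpy as np

# output[0].shape for a batch-1 input: spatial size (224 + 2*3 - 7) // 2 + 1 = 112
out_shape = (64, 112, 112)
flops = np.prod(out_shape) * (3 // 1 * np.prod((7, 7)) + 0)   # 802816 * 147 = 118013952
acts = np.prod(out_shape)                                      # 802816 activations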