|
|
|
@@ -26,61 +26,95 @@ logger = mge.get_logger(__name__) |
|
|
|
logger.setLevel("INFO") |
|
|
|
|
|
|
|
|
|
|
|
CALC_FLOPS = {} |
|
|
|
|
|
|
|
|
|
|
|
def _register_modules(*modules): |
|
|
|
# Maps a module class to its FLOPs implementation; populated through
# register_flops().
_calc_flops_dict = {}
# Maps a module class to its receptive-field implementation; populated through
# register_receptive_field().
_calc_receptive_field_dict = {}
|
|
|
|
|
|
|
|
|
|
|
def _receptive_field_fallback(module, inputs, outputs):
    """Assign a receptive field and stride to a module without a registered rule.

    With no inputs the module gets the unit values ``(1, 1)`` for both;
    otherwise the values are taken from ``preprocess_receptive_field``.
    The result is stored on the module as ``_rf`` and ``_stride``.

    :param module: module being visited; must not already carry ``_rf`` or
        ``_stride``.
    :param inputs: inputs of the module.
    :param outputs: outputs of the module.
    :return: ``(receptive_field, stride)``.
    """
    assert not hasattr(module, "_rf")
    assert not hasattr(module, "_stride")
    if len(inputs) > 0:
        field, step = preprocess_receptive_field(module, inputs, outputs)
    else:
        # TODO: support other dimension
        field, step = (1, 1), (1, 1)
    module._rf = field
    module._stride = step
    return field, step
|
|
|
|
|
|
|
|
|
|
|
# Each entry is (key, impl_dict, fallback):
#   key       -- result key, or a tuple of keys when the implementation
#                returns a tuple of values
#   impl_dict -- registry mapping module classes to implementations
#   fallback  -- callable run when no registered class matches, or None
_iter_list = [
    ("flops_num", _calc_flops_dict, None),
    (
        ("receptive_field", "stride"),
        _calc_receptive_field_dict,
        _receptive_field_fallback,
    ),
]
|
|
|
|
|
|
|
|
|
|
|
def _register_dict(*modules, dict=None):
    """Build a decorator that records an implementation for each module class.

    :param modules: module classes the decorated implementation applies to.
    :param dict: registry to fill (keyword-only; the name shadows the builtin
        but is what callers pass).
    :return: a decorator that stores the implementation in ``CALC_FLOPS`` and
        in ``dict`` under every class in ``modules`` and returns it unchanged.
    """

    def _bind(impl):
        for module_type in modules:
            CALC_FLOPS[module_type] = impl
            dict[module_type] = impl
        return impl

    return _bind
|
|
|
|
|
|
|
|
|
|
|
@_register_modules( |
|
|
|
m.Conv2d, |
|
|
|
m.ConvTranspose2d, |
|
|
|
m.LocalConv2d, |
|
|
|
qm.Conv2d, |
|
|
|
qm.ConvRelu2d, |
|
|
|
qm.ConvBn2d, |
|
|
|
qm.ConvBnRelu2d, |
|
|
|
qatm.Conv2d, |
|
|
|
qatm.ConvRelu2d, |
|
|
|
qatm.ConvBn2d, |
|
|
|
qatm.ConvBnRelu2d, |
|
|
|
def register_flops(*modules):
    """Return a decorator registering a FLOPs implementation for ``modules``."""
    registry = _calc_flops_dict
    return _register_dict(*modules, dict=registry)
|
|
|
|
|
|
|
|
|
|
|
def register_receptive_field(*modules):
    """Return a decorator registering a receptive-field implementation for ``modules``."""
    registry = _calc_receptive_field_dict
    return _register_dict(*modules, dict=registry)
|
|
|
|
|
|
|
|
|
|
|
@register_flops(
    m.Conv1d, m.Conv2d, m.Conv3d,
)
def flops_convNd(module: m.Conv2d, inputs, outputs):
    """Count the operations of a convolution from its input/output shapes.

    Every output element costs ``Cin / groups * prod(kernel_size)``
    operations, plus one when the module has a bias.

    :param module: convolution module; ``bias``, ``groups`` and
        ``kernel_size`` are read from it.
    :param inputs: module inputs; axis 1 of ``inputs[0]`` is used as the input
        channel count.
    :param outputs: module outputs; ``outputs[0]`` is indexed as
        (batch, channels, *spatial).
    :return: the operation count.
    """
    bias = 1 if module.bias is not None else 0
    group = module.groups
    ic = inputs[0].shape[1]
    oc = outputs[0].shape[1]
    # Each output channel only sees the input channels of its own group.
    gic = ic // group
    N = outputs[0].shape[0]
    HW = np.prod(outputs[0].shape[2:])
    # N x Cout x H x W x (Cin x Kw x Kh + bias)
    # All Cout channels are counted; multiplying by Cout // groups would
    # undercount grouped convolutions by a factor of `groups`.
    return N * HW * oc * (gic * np.prod(module.kernel_size) + bias)


# Both names appear for this function; keep the older one importable.
count_convNd = flops_convNd
|
|
|
|
|
|
|
|
|
|
|
@_register_modules(m.ConvTranspose2d)
def count_deconvNd(module, input, output):
    """Count the operations of a transposed convolution.

    :return: number of elements of ``input[0]`` times the channel count of
        ``output[0]`` (axis 1) times ``prod(module.kernel_size)``.
    """
    in_elems = np.prod(input[0].shape)
    out_channels = output[0].shape[1]
    return in_elems * out_channels * np.prod(module.kernel_size)
|
|
|
@register_flops(m.ConvTranspose2d)
def flops_deconvNd(module: m.ConvTranspose2d, inputs, outputs):
    """Count the operations of a transposed convolution.

    :param module: module whose ``kernel_size`` is read.
    :param inputs: module inputs; all elements of ``inputs[0]`` are counted.
    :param outputs: module outputs; axis 1 of ``outputs[0]`` is used as the
        output channel count.
    :return: input elements x output channels x kernel elements.
    """
    in_elems = np.prod(inputs[0].shape)
    out_channels = outputs[0].shape[1]
    kernel_elems = np.prod(module.kernel_size)
    return in_elems * out_channels * kernel_elems
|
|
|
|
|
|
|
|
|
|
|
@register_flops(m.Linear)
def flops_linear(module: m.Linear, inputs, outputs):
    """Count the operations of a linear layer.

    Every output element costs ``in_features`` operations, plus one when the
    module has a bias -- the same per-output-element accounting used for the
    convolution and batched-matmul counters.

    :param module: linear module; ``bias`` and ``in_features`` are read.
    :param inputs: module inputs (unused).
    :param outputs: module outputs; all elements of ``outputs[0]`` are counted.
    :return: the operation count.
    """
    bias = 1 if module.bias is not None else 0
    return np.prod(outputs[0].shape) * (module.in_features + bias)
|
|
|
|
|
|
|
@_register_modules(m.Linear, qatm.Linear, qm.Linear)
def count_linear(module, input, output):
    """Count the operations of a linear layer.

    :return: number of elements of ``output[0]`` times ``module.in_features``.
    """
    out_elems = np.prod(output[0].shape)
    return out_elems * module.in_features
|
|
|
|
|
|
|
@register_flops(m.BatchMatMulActivation)
def flops_batchmatmul(module: m.BatchMatMulActivation, inputs, outputs):
    """Count the operations of a batched matrix multiplication.

    :param module: module whose ``bias`` and ``weight`` are read.
    :param inputs: module inputs; ``inputs[0]`` is unpacked as
        (batch, n, p).
    :param outputs: module outputs (unused).
    :return: ``n * (p + bias) * cols * batch``, where ``cols`` is the last
        axis of ``module.weight``.
    """
    has_bias = 0 if module.bias is None else 1
    data = inputs[0]
    batch = data.shape[0]
    rows, inner = data.shape[1:]
    # Only the last weight axis is needed; locals avoid the name ``m`` so the
    # module alias is not shadowed.
    _, cols = module.weight.shape[1:]
    return rows * (inner + has_bias) * cols * batch
|
|
|
|
|
|
|
|
|
|
|
# QAT and quantized modules do not need to be listed here because they inherit
# from the float modules.
hook_modules = (
    m.Conv2d,
    m.ConvTranspose2d,
    m.LocalConv2d,
    m.BatchNorm2d,
    m.conv._ConvNd,
    m.Linear,
    m.BatchMatMulActivation,
)
|
|
|
|
|
|
|
|
|
|
|
@@ -106,28 +140,71 @@ def sizeof_fmt(num, suffix="B"): |
|
|
|
return "{}{:.1f} {}{}".format(sign_str, num, "Yi", suffix) |
|
|
|
|
|
|
|
|
|
|
|
def preprocess_receptive_field(module, inputs, outputs):
    """Merge the receptive field and stride recorded on the inputs' owners.

    For each of the two dimensions the largest value among all inputs is
    taken. An owner that has no ``_rf`` / ``_stride`` attribute contributes
    ``(1, 1)``, and so does an empty ``inputs``.

    :param module: module being visited (unused).
    :param inputs: objects exposing ``owner``; presumably ``owner`` is the
        module that produced the input -- confirm against the caller.
    :param outputs: module outputs (unused).
    :return: ``(pre_rf, pre_stride)``, each a 2-tuple.
    """
    # TODO: support other dimensions
    unit = (1, 1)
    # Look every attribute up once, with the same default for both dimensions,
    # so an owner lacking the attribute cannot raise for the second one.
    fields = [getattr(i.owner, "_rf", unit) for i in inputs]
    strides = [getattr(i.owner, "_stride", unit) for i in inputs]
    pre_rf = (
        max((f[0] for f in fields), default=1),
        max((f[1] for f in fields), default=1),
    )
    pre_stride = (
        max((s[0] for s in strides), default=1),
        max((s[1] for s in strides), default=1),
    )
    return pre_rf, pre_stride
|
|
|
|
|
|
|
|
|
|
|
def get_flops_stats(module, inputs, outputs):
    """Collect the statistics of one module call from the registered rules.

    For every entry of ``_iter_list`` the first registered class that
    ``module`` is an instance of supplies the value(s) stored under the
    entry's key(s). When no class matches, the entry's fallback (if any) is
    still called, but nothing is recorded for that entry.

    :param module: module being visited.
    :param inputs: module inputs; each must expose ``shape``.
    :param outputs: module outputs; each must expose ``shape``.
    :return: a dict with ``input_shapes``, ``output_shapes`` and the computed
        entries, or ``None`` when no registered implementation matched.
    """
    rst = {
        "input_shapes": [i.shape for i in inputs],
        "output_shapes": [o.shape for o in outputs],
    }
    valid_flag = False
    for key, _dict, fallback in _iter_list:
        for _type in _dict:
            if isinstance(module, _type):
                value = _dict[_type](module, inputs, outputs)
                valid_flag = True
                break
        else:
            # No implementation matched: the fallback runs only for what it
            # records on the module; its return value is not reported.
            if fallback is not None:
                fallback(module, inputs, outputs)
            continue

        if isinstance(key, tuple):
            # A tuple key means the implementation returns one value per key.
            assert isinstance(value, tuple)
            rst.update(zip(key, value))
        else:
            rst[key] = value

    return rst if valid_flag else None
|
|
|
|
|
|
|
|
|
|
|
def print_flops_stats(flops, bar_length_max=20): |
|
|
|
flops_list = [i["flops_num"] for i in flops] |
|
|
|
max_flops_num = max(flops_list + [0]) |
|
|
|
# calc total flops and set flops_cum |
|
|
|
max_flops_num = max([i["flops_num"] for i in flops] + [0]) |
|
|
|
total_flops_num = 0 |
|
|
|
for d in flops: |
|
|
|
total_flops_num += int(d["flops_num"]) |
|
|
|
d["flops_cum"] = sizeof_fmt(total_flops_num, suffix="OPs") |
|
|
|
|
|
|
|
for d in flops: |
|
|
|
f = d["flops_num"] |
|
|
|
d["flops"] = sizeof_fmt(f, suffix="OPs") |
|
|
|
r = d["ratio"] = f / total_flops_num |
|
|
|
d["percentage"] = "{:.2f}%".format(r * 100) |
|
|
|
bar_length = int(f / max_flops_num * bar_length_max) |
|
|
|
ratio = d["ratio"] = d["flops_num"] / total_flops_num |
|
|
|
d["percentage"] = "{:.2f}%".format(ratio * 100) |
|
|
|
bar_length = int(d["flops_num"] / max_flops_num * bar_length_max) |
|
|
|
d["bar"] = "#" * bar_length |
|
|
|
d["flops"] = sizeof_fmt(d["flops_num"], suffix="OPs") |
|
|
|
|
|
|
|
header = [ |
|
|
|
"name", |
|
|
|
"class_name", |
|
|
|
"input_shapes", |
|
|
|
"output_shapes", |
|
|
|
"receptive_field", |
|
|
|
"stride", |
|
|
|
"flops", |
|
|
|
"flops_cum", |
|
|
|
"percentage", |
|
|
|
@@ -154,8 +231,8 @@ def get_param_stats(param: np.ndarray): |
|
|
|
param_size = param_dim * nbits // 8 |
|
|
|
return { |
|
|
|
"shape": shape, |
|
|
|
"mean": param.mean(), |
|
|
|
"std": param.std(), |
|
|
|
"mean": "{:.3g}".format(param.mean()), |
|
|
|
"std": "{:.3g}".format(param.std()), |
|
|
|
"param_dim": param_dim, |
|
|
|
"nbits": nbits, |
|
|
|
"size": param_size, |
|
|
|
@@ -163,21 +240,20 @@ def get_param_stats(param: np.ndarray): |
|
|
|
|
|
|
|
|
|
|
|
def print_params_stats(params, bar_length_max=20): |
|
|
|
max_size = max([d["size"] for d in params] + [0]) |
|
|
|
total_param_dims, total_param_size = 0, 0 |
|
|
|
for d in params: |
|
|
|
total_param_dims += int(d["param_dim"]) |
|
|
|
total_param_size += int(d["size"]) |
|
|
|
ratio = d["size"] / total_param_size |
|
|
|
d["size"] = sizeof_fmt(d["size"]) |
|
|
|
d["size_cum"] = sizeof_fmt(total_param_size) |
|
|
|
d["ratio"] = ratio |
|
|
|
d["percentage"] = "{:.2f}%".format(ratio * 100) |
|
|
|
|
|
|
|
# construct bar |
|
|
|
max_ratio = max([d["ratio"] for d in params]) |
|
|
|
for d in params: |
|
|
|
bar_length = int(d["ratio"] / max_ratio * bar_length_max) |
|
|
|
ratio = d["size"] / total_param_size |
|
|
|
d["ratio"] = ratio |
|
|
|
d["percentage"] = "{:.2f}%".format(ratio * 100) |
|
|
|
bar_length = int(d["size"] / max_size * bar_length_max) |
|
|
|
d["size_bar"] = "#" * bar_length |
|
|
|
d["size"] = sizeof_fmt(d["size"]) |
|
|
|
|
|
|
|
param_size = sizeof_fmt(total_param_size) |
|
|
|
params.append(dict(name="total", param_dim=total_param_dims, size=param_size,)) |
|
|
|
@@ -225,26 +301,14 @@ def module_stats( |
|
|
|
:param log_flops: whether print and record op flops. |
|
|
|
""" |
|
|
|
|
|
|
|
def module_stats_hook(module, input, output, name=""): |
|
|
|
def module_stats_hook(module, inputs, outputs, name=""): |
|
|
|
class_name = str(module.__class__).split(".")[-1].split("'")[0] |
|
|
|
|
|
|
|
flops_fun = CALC_FLOPS.get(type(module)) |
|
|
|
if callable(flops_fun): |
|
|
|
flops_num = flops_fun(module, input, output) |
|
|
|
|
|
|
|
if not isinstance(output, (list, tuple)): |
|
|
|
output = [output] |
|
|
|
|
|
|
|
flops.append( |
|
|
|
dict( |
|
|
|
name=name, |
|
|
|
class_name=class_name, |
|
|
|
input_shapes=[i.shape for i in input], |
|
|
|
output_shapes=[o.shape for o in output], |
|
|
|
flops_num=flops_num, |
|
|
|
flops_cum=0, |
|
|
|
) |
|
|
|
) |
|
|
|
flops_stats = get_flops_stats(module, inputs, outputs) |
|
|
|
if flops_stats is not None: |
|
|
|
flops_stats["name"] = name |
|
|
|
flops_stats["class_name"] = class_name |
|
|
|
flops.append(flops_stats) |
|
|
|
|
|
|
|
if hasattr(module, "weight") and module.weight is not None: |
|
|
|
w = module.weight |
|
|
|
@@ -278,19 +342,22 @@ def module_stats( |
|
|
|
for h in hooks: |
|
|
|
h.remove() |
|
|
|
|
|
|
|
total_flops, total_params = 0, 0 |
|
|
|
extra_info = { |
|
|
|
"#params": len(params), |
|
|
|
} |
|
|
|
total_flops, total_param_dims, total_param_size = 0, 0, 0 |
|
|
|
if log_params: |
|
|
|
total_param_dims, total_param_size = print_params_stats(params, bar_length_max) |
|
|
|
extra_info["total_param_dims"] = sizeof_fmt(total_param_dims) |
|
|
|
extra_info["total_param_size"] = sizeof_fmt(total_param_size) |
|
|
|
if log_flops: |
|
|
|
total_flops = print_flops_stats(flops, bar_length_max) |
|
|
|
extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs") |
|
|
|
if log_params and log_flops: |
|
|
|
extra_info["flops/param_size"] = "{:3.3f}".format( |
|
|
|
total_flops / total_param_size |
|
|
|
) |
|
|
|
|
|
|
|
extra_info = { |
|
|
|
"#params": len(params), |
|
|
|
"total_param_dims": sizeof_fmt(total_param_dims), |
|
|
|
"total_param_size": sizeof_fmt(total_param_size), |
|
|
|
"total_flops": sizeof_fmt(total_flops, suffix="OPs"), |
|
|
|
"flops/param_size": "{:3.3f}".format(total_flops / total_param_size), |
|
|
|
} |
|
|
|
print_summary(**extra_info) |
|
|
|
|
|
|
|
return total_params, total_flops |
|
|
|
return total_param_size, total_flops |