Browse Source

fix(mge/tools): improve `module_visualize` result's robustness and beauty

GitOrigin-RevId: ef7b573776
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
c45f1eb298
2 changed files with 14 additions and 2 deletions
  1. +11
    -1
      imperative/python/megengine/tools/network_visualize.py
  2. +3
    -1
      imperative/python/megengine/utils/module_stats.py

+ 11
- 1
imperative/python/megengine/tools/network_visualize.py View File

@@ -7,11 +7,12 @@
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import argparse import argparse
import logging


import numpy as np import numpy as np


from megengine.core.tensor.dtype import is_quantize from megengine.core.tensor.dtype import is_quantize
from megengine.logger import get_logger
from megengine.logger import _imperative_rt_logger, get_logger, set_mgb_log_level
from megengine.utils.module_stats import ( from megengine.utils.module_stats import (
print_flops_stats, print_flops_stats,
print_params_stats, print_params_stats,
@@ -58,6 +59,8 @@ def visualize(
"TensorBoard and TensorboardX are required for visualize.", exc_info=True "TensorBoard and TensorboardX are required for visualize.", exc_info=True
) )
return return
# FIXME: remove this after resolving "span dist too large" warning
old_level = set_mgb_log_level(logging.ERROR)


graph = Network.load(model_path) graph = Network.load(model_path)
writer = SummaryWriter(log_path) writer = SummaryWriter(log_path)
@@ -126,6 +129,9 @@ def visualize(
std="{:.2g}".format(node.numpy().std()), std="{:.2g}".format(node.numpy().std()),
) )
) )
# FIXME(MGE-2165): nodes outside network module may lead to unknown display bug
if not len(node.name.split(".")) > 2 and not node in graph.input_vars:
continue
node_list.append( node_list.append(
NodeDef( NodeDef(
name=process_name(node.name), op=node.type, input=inp_list, attr=attr, name=process_name(node.name), op=node.type, input=inp_list, attr=attr,
@@ -145,6 +151,10 @@ def visualize(
step_stats=StepStats(dev_stats=[DeviceStepStats(device=device)]) step_stats=StepStats(dev_stats=[DeviceStepStats(device=device)])
) )
writer._get_file_writer().add_graph((graph_def, stepstats)) writer._get_file_writer().add_graph((graph_def, stepstats))

# FIXME: remove this after resolving "span dist too large" warning
_imperative_rt_logger.set_log_level(old_level)

return total_params, total_flops return total_params, total_flops






+ 3
- 1
imperative/python/megengine/utils/module_stats.py View File

@@ -135,7 +135,9 @@ def print_flops_stats(flops, bar_length_max=20):
] ]


total_flops_str = sizeof_fmt(total_flops_num, suffix="OPs") total_flops_str = sizeof_fmt(total_flops_num, suffix="OPs")
total_var_size = sum(sum(s[1] for s in i["output_shapes"]) for i in flops)
total_var_size = sum(
sum(s[1] if len(s) > 1 else 0 for s in i["output_shapes"]) for i in flops
)
flops.append( flops.append(
dict(name="total", flops=total_flops_str, output_shapes=total_var_size) dict(name="total", flops=total_flops_str, output_shapes=total_var_size)
) )


Loading…
Cancel
Save