| @@ -10,6 +10,7 @@ from ..core._imperative_rt.core2 import set_cpp_apply_module_trace | |||
| from . import compat | |||
| from ._passes import optimize | |||
| from .pytree import register_supported_type | |||
| from .tm_config import disable_default_checker, enable_expr_checker | |||
| from .traced_module import ( | |||
| TracedModule, | |||
| _register_all_builtin_module, | |||
| @@ -29,4 +30,6 @@ __all__ = [ | |||
| "wrap", | |||
| "TracedModule", | |||
| "optimize", | |||
| "enable_expr_checker", | |||
| "disable_default_checker", | |||
| ] | |||
| @@ -0,0 +1,142 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import traceback | |||
| from typing import Sequence | |||
| import numpy as np | |||
| from ..core._imperative_rt.core2 import apply | |||
| from ..core._imperative_rt.ops import ROIAlign, ROIPooling | |||
| from ..core.ops.builtin import Copy | |||
| from ..core.tensor.utils import isscalar, setscalar | |||
| from ..tensor import Tensor | |||
| from .tm_config import _exclude_from_trace | |||
| class TracedModuleChecker: | |||
| def __init__(self, tracer): | |||
| self._active_node2values = [] | |||
| self.tracer = tracer | |||
| self.node_without_tensor_info = {} | |||
| def push_scope(self): | |||
| self._active_node2values.append({}) | |||
| def pop_scope(self): | |||
| self._active_node2values.pop() | |||
| def current_node2values(self): | |||
| return self._active_node2values[-1] | |||
| def reset_checker(self): | |||
| self._active_node2values = [] | |||
| def check_node_not_in_scope(self): | |||
| if self.node_without_tensor_info: | |||
| for node, info in self.node_without_tensor_info.items(): | |||
| for expr in info[0]._exprs: | |||
| if node in expr.inputs or node in expr.outputs: | |||
| traceback.print_list(info[1]) | |||
| raise ValueError( | |||
| "node({}) not in the graph:\n{}".format(node, info[0]) | |||
| ) | |||
| return True | |||
| else: | |||
| return False | |||
| def check_net_outputs(self, tm_res, gt_res): | |||
| if isinstance(tm_res, Tensor): | |||
| np.testing.assert_allclose(tm_res.numpy(), gt_res.numpy()) | |||
| elif isinstance(tm_res, Sequence): | |||
| for i, j in zip(tm_res, gt_res): | |||
| np.testing.assert_allclose(i.numpy(), j.numpy()) | |||
| else: | |||
| for k in tm_res.__dict__.keys(): | |||
| np.testing.assert_allclose( | |||
| getattr(tm_res, k).numpy(), getattr(gt_res, k).numpy() | |||
| ) | |||
| def record_nodemixin(self, node, value): | |||
| self.current_node2values()[node] = value | |||
| def record_node2value(self, node, value): | |||
| with _exclude_from_trace(): | |||
| self.current_node2values()[node] = apply( | |||
| Copy(comp_node=value.device), value | |||
| )[0] | |||
| if isscalar(value): | |||
| setscalar(self.current_node2values()[node]) | |||
| def check_apply_special_cases(self, opdef, num_outputs): | |||
| indexs = list(range(num_outputs)) | |||
| if isinstance(opdef, ROIAlign) and opdef.mode == ROIAlign.Mode.AVERAGE: | |||
| indexs.pop(-1) | |||
| if isinstance(opdef, ROIPooling) and opdef.mode == ROIPooling.Mode.AVERAGE: | |||
| indexs.pop(-1) | |||
| return indexs | |||
| def check_expr_results(self, expr_outputs, gt_outputs, indexs=None): | |||
| expr_outputs = ( | |||
| (expr_outputs,) if not isinstance(expr_outputs, Sequence) else expr_outputs | |||
| ) | |||
| gt_outputs = ( | |||
| (gt_outputs,) if not isinstance(gt_outputs, Sequence) else gt_outputs | |||
| ) | |||
| if indexs is not None: | |||
| for i in indexs: | |||
| np.testing.assert_allclose( | |||
| expr_outputs[i].numpy(), gt_outputs[i].numpy() | |||
| ) | |||
| else: | |||
| np.testing.assert_allclose(expr_outputs, gt_outputs) | |||
| def get_node2value(self, inputs, start_idx=0): | |||
| inp_values = [] | |||
| has_node_not_in_scope = False | |||
| for i in range(start_idx, len(inputs)): | |||
| try: | |||
| inp_values.append(self.current_node2values()[inputs[i]]) | |||
| except: | |||
| has_node_not_in_scope = True | |||
| self.node_without_tensor_info[inputs[i]] = [ | |||
| self.tracer.current_scope(), | |||
| traceback.extract_stack(), | |||
| ] | |||
| return inp_values, has_node_not_in_scope | |||
| def check_expr_interpret(self, expr, gt_outputs): | |||
| ori_in, has_node_not_in_scope = self.get_node2value(expr.inputs) | |||
| if not has_node_not_in_scope: | |||
| expr_res = expr.interpret(*ori_in) | |||
| try: | |||
| self.check_expr_results(expr_res, gt_outputs) | |||
| except: | |||
| raise ValueError("Error occurred when checking expr: {}".format(expr)) | |||
| def check_apply(self, expr, gt_outputs, opdef): | |||
| ori_in, has_node_not_in_scope = self.get_node2value(expr.inputs) | |||
| if not has_node_not_in_scope: | |||
| expr_res = expr.interpret(*ori_in) | |||
| indexs = self.check_apply_special_cases(opdef, len(gt_outputs)) | |||
| try: | |||
| self.check_expr_results(expr_res, gt_outputs, indexs=indexs) | |||
| except: | |||
| raise ValueError("Error occurred when checking expr: {}".format(expr)) | |||
| def check_builtin_module(self, module, expr, gt_outputs): | |||
| ori_in, has_node_not_in_scope = self.get_node2value(expr.inputs, start_idx=1) | |||
| if not has_node_not_in_scope: | |||
| ori_in.insert(0, module) | |||
| expr_res = expr.interpret(*ori_in) | |||
| try: | |||
| self.check_expr_results(expr_res, gt_outputs) | |||
| except: | |||
| raise ValueError( | |||
| "{}, Error occurred when checking expr: {}".format(expr) | |||
| ) | |||
| @@ -32,6 +32,7 @@ from .module_tracer import active_module_tracer, module_tracer | |||
| from .node import ModuleNode, Node, NodeMixin, TensorNode | |||
| from .pytree import ArgsIndex, TreeDef, _is_const_leaf, _is_leaf, tree_flatten | |||
| from .serialization import _ModuleState | |||
| from .tm_config import _exclude_from_trace, _get_expr_checker | |||
| from .utils import _check_builtin_module_attr, _check_obj_attr, _convert_kwargs_to_args | |||
| @@ -611,6 +612,8 @@ class Apply(Expr): | |||
| inp_nodes = [NodeMixin.get(inputs[0])] | |||
| for i in inputs[1:]: | |||
| node = Constant.make(i) | |||
| if _get_expr_checker(): | |||
| active_module_tracer().checker.record_node2value(node, Tensor(i)) | |||
| inp_nodes.append(node) | |||
| apply_node = cls.make(opdef) | |||
| for n in inp_nodes: | |||
| @@ -624,11 +627,17 @@ class Apply(Expr): | |||
| unset_module_tracing() | |||
| outputs = apply(opdef, *inputs) | |||
| outputs = list(map(Tensor, outputs)) | |||
| set_module_tracing() | |||
| apply_node.add_outputs(outputs) | |||
| for n, v in zip(apply_node.outputs, outputs): | |||
| NodeMixin.wrap_safe(v, n) | |||
| if _get_expr_checker(): | |||
| with _exclude_from_trace(): | |||
| active_module_tracer().checker.check_apply(apply_node, outputs, opdef) | |||
| return list(outputs) | |||
| @@ -12,6 +12,7 @@ from .. import functional as F | |||
| from ..core.tensor.array_method import ArrayMethodMixin | |||
| from ..module import Module | |||
| from ..module.qat import QATModule | |||
| from .checker import TracedModuleChecker | |||
| _active_module_tracer = None | |||
| @@ -128,6 +129,7 @@ class module_tracer: | |||
| def __init__(self, wrap_fn): | |||
| self._active_scopes = [] | |||
| self.checker = TracedModuleChecker(self) | |||
| self.patcher = Patcher(wrap_fn) | |||
| @classmethod | |||
| @@ -142,9 +144,11 @@ class module_tracer: | |||
| def push_scope(self, scope): | |||
| self._active_scopes.append(scope) | |||
| self.checker.push_scope() | |||
| def pop_scope(self): | |||
| self._active_scopes.pop() | |||
| self.checker.pop_scope() | |||
| def current_scope(self): | |||
| if self._active_scopes: | |||
| @@ -18,6 +18,8 @@ from ..core._imperative_rt.core2 import Tensor as RawTensor | |||
| from ..module import Module | |||
| from ..quantization.utils import QParams | |||
| from ..tensor import Tensor | |||
| from .module_tracer import active_module_tracer | |||
| from .tm_config import _get_expr_checker | |||
| from .utils import _check_obj_attr | |||
| logger = get_logger(__name__) | |||
| @@ -343,6 +345,11 @@ class NodeMixin(abc.ABC): | |||
| if isinstance(value, NodeMixin): | |||
| value._record_wrapped_nodes(node) | |||
| setattr(value, "_NodeMixin__node", node) | |||
| if _get_expr_checker(): | |||
| if isinstance(value, RawTensor): | |||
| active_module_tracer().checker.record_node2value(node, value) | |||
| if isinstance(value, NodeMixin): | |||
| active_module_tracer().checker.record_nodemixin(node, value) | |||
| else: | |||
| assert callable(node) | |||
| n = node() | |||
| @@ -352,6 +359,11 @@ class NodeMixin(abc.ABC): | |||
| if isinstance(value, NodeMixin): | |||
| value._record_wrapped_nodes(n) | |||
| setattr(value, "_NodeMixin__node", n) | |||
| if _get_expr_checker(): | |||
| if isinstance(value, RawTensor): | |||
| active_module_tracer().checker.record_node2value(n, value) | |||
| if isinstance(value, NodeMixin): | |||
| active_module_tracer().checker.record_nodemixin(n, value) | |||
| @classmethod | |||
| def wrap_safe(cls, value, node): | |||
| @@ -359,6 +371,11 @@ class NodeMixin(abc.ABC): | |||
| if isinstance(value, RawTensor): | |||
| cls._record_tensornode_property(node, value) | |||
| setattr(value, "_NodeMixin__node", node) | |||
| if _get_expr_checker(): | |||
| if isinstance(value, RawTensor): | |||
| active_module_tracer().checker.record_node2value(node, value) | |||
| if isinstance(value, NodeMixin): | |||
| active_module_tracer().checker.record_nodemixin(node, value) | |||
| if isinstance(value, NodeMixin): | |||
| value._record_wrapped_nodes(node) | |||
| @@ -212,7 +212,11 @@ def tree_flatten( | |||
| to reconstruct the pytree. | |||
| """ | |||
| if type(values) not in SUPPORTED_TYPE: | |||
| assert is_leaf(values), values | |||
| assert is_leaf( | |||
| values | |||
| ), 'doesn\'t support {} type, MUST use "register_supported_type" method to register self-defined type'.format( | |||
| values | |||
| ) | |||
| node = LeafDef(leaf_type(values)) | |||
| if is_const_leaf(values): | |||
| node.const_val = values | |||
| @@ -0,0 +1,55 @@ | |||
| import contextlib | |||
| from ..core._imperative_rt.core2 import ( | |||
| is_tracing_module, | |||
| set_module_tracing, | |||
| unset_module_tracing, | |||
| ) | |||
| _enable_expr_checker = False | |||
| _enable_default_checker = True | |||
| def _get_expr_checker(): | |||
| return _enable_expr_checker | |||
| def _get_default_checker(): | |||
| return _enable_default_checker | |||
| def enable_expr_checker(): | |||
| r"""Call this function to check the result of each expr during tracing.""" | |||
| global _enable_expr_checker | |||
| _enable_expr_checker = True | |||
| _enable_default_checker = False | |||
| def disable_default_checker(): | |||
| r"""Call this function to disable checking the final output of the model after tracing.""" | |||
| global _enable_default_checker | |||
| _enable_default_checker = False | |||
| _enable_graph_surgery_mode = False | |||
| def _graph_surgery_mode(): | |||
| return _enable_graph_surgery_mode | |||
| def _set_graph_surgery_mode(mode: bool): | |||
| global _enable_graph_surgery_mode | |||
| pre_mode = _enable_graph_surgery_mode | |||
| _enable_graph_surgery_mode = mode | |||
| return pre_mode | |||
| @contextlib.contextmanager | |||
| def _exclude_from_trace(): | |||
| is_tracing = is_tracing_module() | |||
| if is_tracing: | |||
| unset_module_tracing() | |||
| yield | |||
| if is_tracing: | |||
| set_module_tracing() | |||
| @@ -36,11 +36,14 @@ from .. import get_logger | |||
| from .. import module as M | |||
| from ..core._imperative_rt.core2 import Tensor as RawTensor | |||
| from ..core._imperative_rt.core2 import ( | |||
| apply, | |||
| is_tracing_module, | |||
| set_module_tracing, | |||
| unset_module_tracing, | |||
| ) | |||
| from ..core._trace_option import set_symbolic_shape | |||
| from ..core.ops.builtin import Copy | |||
| from ..core.tensor.utils import isscalar, setscalar | |||
| from ..module import Module | |||
| from ..module import external as MExternal | |||
| from ..module.qat import QATModule | |||
| @@ -98,6 +101,13 @@ from .serialization import ( | |||
| load_call_tensor_method_expr, | |||
| load_functional, | |||
| ) | |||
| from .tm_config import ( | |||
| _exclude_from_trace, | |||
| _get_default_checker, | |||
| _get_expr_checker, | |||
| _graph_surgery_mode, | |||
| _set_graph_surgery_mode, | |||
| ) | |||
| from .utils import ( | |||
| _check_builtin_module_attr, | |||
| _check_obj_attr, | |||
| @@ -117,26 +127,14 @@ def _is_builtin_name(name: str) -> bool: | |||
| def _is_leaf(node): | |||
| assert isinstance(node, RawTensor), "doesn't support {} in return values".format( | |||
| assert isinstance( | |||
| node, RawTensor | |||
| ), 'doesn\'t support {} in return values, MUST use Tensor or use "register_supported_type" method to register self-defined type'.format( | |||
| type(node) | |||
| ) | |||
| return isinstance(node, RawTensor) | |||
| _enable_graph_surgery_mode = False | |||
| def _graph_surgery_mode(): | |||
| return _enable_graph_surgery_mode | |||
| def _set_graph_surgery_mode(mode: bool): | |||
| global _enable_graph_surgery_mode | |||
| pre_mode = _enable_graph_surgery_mode | |||
| _enable_graph_surgery_mode = mode | |||
| return pre_mode | |||
| def _node_to_tensor(*args, **kwargs): | |||
| tensors = [] | |||
| nodes, tree_def = tree_flatten((args, kwargs)) | |||
| @@ -1295,7 +1293,12 @@ def _wrapped_function(orig_func): | |||
| return orig_func(*args, **kwargs) | |||
| if isinstance(args[1], RawTensor): | |||
| node = NodeMixin.get(inputs[1]) | |||
| inputs[1] = copy.copy(inputs[1]) | |||
| is_scalar = isscalar(inputs[1]) | |||
| inputs[1] = apply( | |||
| Copy(comp_node=inputs[1].device), Tensor(inputs[1]) | |||
| )[0] | |||
| if is_scalar: | |||
| setscalar(inputs[1]) | |||
| # copy inputs[1] to avoid tensor and Tensor(tensor) share same m_tensor, | |||
| # which will cause they have same _NodeMixin__node in tracing. | |||
| NodeMixin.wrap_safe(inputs[1], node) | |||
| @@ -1319,6 +1322,13 @@ def _wrapped_function(orig_func): | |||
| else: | |||
| outputs = None | |||
| call_node.add_outputs(outputs) | |||
| if _get_expr_checker(): | |||
| with _exclude_from_trace(): | |||
| active_module_tracer().checker.check_expr_interpret( | |||
| call_node, outputs | |||
| ) | |||
| set_module_tracing() | |||
| return rst | |||
| return orig_func(*args, **kwargs) | |||
| @@ -1500,6 +1510,12 @@ class TracedModuleBuilder(NodeMixin): | |||
| unset_module_tracing() | |||
| rst = self._mod(*args, **kwargs) | |||
| outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf) | |||
| if _get_expr_checker(): | |||
| with _exclude_from_trace(): | |||
| tmp = self.build() | |||
| active_module_tracer().checker.check_builtin_module( | |||
| tmp, callnode, outputs | |||
| ) | |||
| set_module_tracing() | |||
| if self._is_builtin: | |||
| self._body = None | |||
| @@ -1674,7 +1690,9 @@ class TracedModuleBuilder(NodeMixin): | |||
| if not isinstance(mod_attr, (List, Dict, QATModule)): | |||
| assert mod_attr is wrapped._mod | |||
| else: | |||
| assert mod_attr is wrapped | |||
| assert ( | |||
| mod_attr is wrapped | |||
| ), "TracedModule do not support modify attributes, please check your code." | |||
| if isinstance(wrapped, (NodeMixin, RawTensor)): | |||
| NodeMixin.wrap( | |||
| @@ -2469,11 +2487,23 @@ def trace_module( | |||
| qualname="{}.[{}]".format(net_name, "arg_{}".format(_)), | |||
| ), | |||
| ) | |||
| builder(*args, **kwargs) | |||
| rst = builder(*copy.deepcopy(args), **copy.deepcopy(kwargs)) | |||
| active_module_tracer().pop_scope() | |||
| traced_mod = builder.build() | |||
| traced_mod.argspec = forward_argspec | |||
| traced_mod.graph._reset_ids() | |||
| has_expr_not_check = False | |||
| if _get_expr_checker(): | |||
| has_expr_not_check = ( | |||
| active_module_tracer().checker.check_node_not_in_scope() | |||
| ) | |||
| if _get_default_checker() or has_expr_not_check: | |||
| with _exclude_from_trace(): | |||
| tm_res = traced_mod(*args, **kwargs) | |||
| tm_res, _ = tree_flatten(tm_res, is_leaf=_is_leaf) | |||
| rst, _ = tree_flatten(rst, is_leaf=_is_leaf) | |||
| active_module_tracer().checker.check_net_outputs(tm_res, rst) | |||
| return traced_mod | |||
| finally: | |||
| set_symbolic_shape(use_sym_shape) | |||
| @@ -5,16 +5,15 @@ | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import collections | |||
| import copy | |||
| import inspect | |||
| from collections.abc import MutableMapping, MutableSequence | |||
| from inspect import FullArgSpec | |||
| from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, Union | |||
| from typing import Callable, Dict, Iterable, List, Optional, Sequence, Union | |||
| from .. import get_logger | |||
| from ..module import Module | |||
| from ..tensor import Parameter, Tensor | |||
| from ..tensor import Tensor | |||
| logger = get_logger(__name__) | |||
| @@ -109,6 +109,7 @@ def build_observered_net(net: M.Module, observer_cls): | |||
| ) | |||
| Q.enable_observer(qat_net) | |||
| inp = Tensor(np.random.random(size=(5, 3, 32, 32))) | |||
| qat_net.eval() | |||
| qat_net(inp) | |||
| Q.disable_observer(qat_net) | |||
| return qat_net | |||
| @@ -116,6 +117,7 @@ def build_observered_net(net: M.Module, observer_cls): | |||
| def build_fakequanted_net(net: QATModule, fakequant_cls): | |||
| qat_net = Q.reset_qconfig(net, get_lsq_config(fakequant_cls)) | |||
| qat_net.eval() | |||
| return qat_net | |||
| @@ -162,6 +164,7 @@ def test_load_param(): | |||
| def _check_module(build_func: Callable): | |||
| net = build_func() | |||
| net.eval() | |||
| buffer = io.BytesIO() | |||
| mge.save(net.state_dict(), buffer) | |||
| buffer.seek(0) | |||
| @@ -185,6 +188,7 @@ def test_load_param(): | |||
| def test_qualname(): | |||
| def _check_qualname(net): | |||
| inp = Tensor(np.random.random(size=(5, 3, 32, 32))) | |||
| net.eval() | |||
| traced_net = trace_module(net, inp) | |||
| base_qualname = traced_net.graph.qualname | |||
| for node in traced_net.graph.nodes(): | |||