@@ -9,6 +9,7 @@
 from ..core._imperative_rt.core2 import set_cpp_apply_module_trace
 from . import compat
 from ._passes import optimize
+from .pytree import register_supported_type
 from .traced_module import (
     TracedModule,
     _register_all_builtin_module,
@@ -23,6 +24,7 @@ set_cpp_apply_module_trace(cpp_apply_module_trace)
 
 __all__ = [
     "register_as_builtin",
+    "register_supported_type",
     "trace_module",
     "wrap",
     "TracedModule",
@@ -12,7 +12,7 @@ from ...core.ops.builtin import GetVarShape
 from ...logger import get_logger
 from ...tensor import Tensor
 from ..expr import Constant, Expr, is_apply_def, is_constant, is_getattr
-from ..node import Node, TensorNode
+from ..node import Node, NodeMixin, TensorNode
 from .matcher import PatternMatcher
 from .pass_base import BackwardPass, ForwardPass, register_pass
 from .pattern import is_op
@@ -21,6 +21,13 @@ from .utils import get_const_value
 logger = get_logger(__name__)
 
 
+def _as_const_node(x):
+    node = Constant.make(x)
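+    # bind the node to the value so that later uses of the value resolve to this Constant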
+    NodeMixin.wrap(x, node)
+    return node
+
+
 @register_pass("AttrToConstant")
 class AttrToConstant(BackwardPass):
     r"""Convert :class:`~.GetAttr` to :class:`~.Constant` expr."""
@@ -35,10 +41,10 @@ class AttrToConstant(BackwardPass):
         orig_node = expr.outputs[0]
         name = orig_node.name
         with graph.insert_exprs(expr):
-            const_node = Constant.make(value, name=name)
+            const_node = _as_const_node(value)
         graph.replace_node({orig_node: const_node})
         graph.compile()
         name = orig_node.name
         const_node.name = name
         return const_node.expr
@@ -53,7 +59,7 @@ class FixInputShape(BackwardPass):
         shape = Tensor(expr.inputs[0].shape, dtype="int32")
         graph = expr.top_graph
         with graph.insert_exprs(expr):
-            const_shape = Constant.make(shape)
+            const_shape = _as_const_node(shape)
         graph.replace_node({expr.outputs[0]: const_shape})
         graph.compile()
         const_shape.name = expr.outputs[0].name
@@ -73,7 +79,7 @@ class FlodConstant(ForwardPass):
         const_var = expr.interpret(*[get_const_value(n.expr) for n in expr.inputs])[0]
         graph = expr.top_graph
         with graph.insert_exprs(expr):
-            const_node = Constant.make(const_var)
+            const_node = _as_const_node(const_var)
         graph.replace_node({expr.outputs[0]: const_node})
         graph.compile()
         const_node.name = expr.outputs[0].name
@@ -10,7 +10,7 @@ import collections
 from collections import OrderedDict, defaultdict
 from functools import partial
 from inspect import FullArgSpec
-from typing import Callable, NamedTuple
+from typing import Any, Callable, List, NamedTuple, Tuple
 
 import numpy as np
@@ -46,6 +46,8 @@ SUPPORTED_LEAF_TYPE = {
     int,
     float,
     bool,
+    bytes,
+    bytearray,
     QuantDtypeMeta,
     CompNode,
     Device,
@@ -74,18 +76,51 @@ SUPPORTED_LEAF_CLS = [
 NodeType = NamedTuple("NodeType", [("flatten", Callable), ("unflatten", Callable)])
 
 
-def register_supported_type(type, flatten=None, unflatten=None):
+def register_supported_type(
+    type,
+    flatten_fn: Callable[[Any], Tuple[List, Any]] = None,
+    unflatten_fn: Callable[[List, Any], Any] = None,
+):
+    r"""Call this function to register ``type`` as a supported type, so that objects of
+    ``type`` can be used and serialized correctly in :py:class:`TracedModule`.
+
+    Examples:
+
+        .. code-block::
+
+            def dict_flatten(obj: Dict):
+                context, values = [], []
+                # obj.keys() needs to be sortable
+                keys = sorted(obj.keys())
+                for key in keys:
+                    values.append(obj[key])
+                    context.append(key)
+                return values, tuple(context)
+
+            def dict_unflatten(values: List, context: Any):
+                return dict(zip(context, values))
+
+            register_supported_type(dict, dict_flatten, dict_unflatten)
+
+    Args:
+        type: the type to be registered.
+        flatten_fn: a function that takes an instance of ``type`` and returns a flat
+            list of its values, together with some context that can be used to
+            reconstruct the instance. Default: None
+        unflatten_fn: a function that takes a flat list of values and the context
+            returned by ``flatten_fn``, and rebuilds the instance from them.
+            Default: None
+    """
     tp_info = (type.__module__, type.__qualname__)
-    if flatten and unflatten:
+    if flatten_fn and unflatten_fn:
         USER_REGISTERED_CONTAINER_TYPE.append(tp_info)
     else:
         USER_REGISTERED_LEAF_TYPE.append(tp_info)
-    _register_supported_type(type, flatten, unflatten)
+    _register_supported_type(type, flatten_fn, unflatten_fn)
 
 
-def _register_supported_type(type, flatten=None, unflatten=None):
-    if flatten and unflatten:
-        SUPPORTED_TYPE[type] = NodeType(flatten, unflatten)
+def _register_supported_type(type, flatten_fn=None, unflatten_fn=None):
+    if flatten_fn and unflatten_fn:
+        SUPPORTED_TYPE[type] = NodeType(flatten_fn, unflatten_fn)
     else:
         SUPPORTED_LEAF_CLS.append(type)
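For reference, a minimal sketch of registering a user-defined container type through the new keyword names; `Point`, `point_flatten`, and `point_unflatten` are hypothetical and not part of this patch:

    from typing import Any, List

    import megengine.traced_module as tm

    class Point:
        def __init__(self, x, y):
            self.x, self.y = x, y

    def point_flatten(p: Point):
        # return the flat list of values plus the context used for reconstruction
        return [p.x, p.y], None

    def point_unflatten(values: List, context: Any):
        return Point(*values)

    # both functions are given, so Point is registered as a container type
    tm.register_supported_type(Point, point_flatten, point_unflatten)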
@@ -131,6 +166,7 @@ _register_supported_type(
 _register_supported_type(
     OrderedDict, partial(_dict_flatten, True), partial(_dict_unflatten, OrderedDict)
 )
+
 _register_supported_type(
     slice,
     lambda x: ([x.start, x.stop, x.step], None),
@@ -42,6 +42,7 @@ from ..core._imperative_rt.core2 import (
 )
 from ..core._trace_option import set_symbolic_shape
 from ..module import Module
+from ..module import external as MExternal
 from ..module.qat import QATModule
 from ..quantization.fake_quant import LSQ, TQT, FakeQuantize, _FakeQuantize
 from ..quantization.observer import (
@@ -207,6 +208,7 @@ def _wrap_method_to_tensor_node():
     for method in get_tensor_wrapable_method():
         patch = PatchedFn(TensorNode, method)
         if type(getattr(Tensor, method)) == property:
+            # only the property getter is supported
             patch.set_func(property(_any_method(method, patch.origin_fn)))
         else:
             patch.set_func(_any_method(method, patch.origin_fn))
@@ -351,14 +353,14 @@ class _InsertExprs:
             assert (
                 node.top_graph == self.graph
             ), "The input node ({}) is not in the graph ({})".format(node, self.graph)
-            if isinstance(node, TensorNode) and node.expr in self.graph._exprs:
+            if node.expr in self.graph._exprs:
                 max_inp_expr_idx = max(
                     max_inp_expr_idx, self.graph._exprs.index(node.expr)
                 )
         max_inp_expr_idx += 1
 
         insert_index = -1
-        if self.expr is not None:
+        if self.expr in self.graph._exprs:
             insert_index = self.graph._exprs.index(self.expr)
         insert_index += 1
@@ -2070,7 +2072,8 @@ class TracedModule(Module):
         for inp_def, graph in self.argdef_graph_map.items():
             if top_graph is not None:
                 graph._top_graph = weakref.ref(top_graph)
-            for n in graph._inputs + graph.outputs:
+            for n in graph._inputs + graph._outputs:
+                n.expr._top_graph = weakref.ref(graph)
                 n._top_graph = weakref.ref(graph)
             graph._inputs[0]._owner = weakref.ref(self)
             for i, n in enumerate(graph._inputs):
@@ -2375,7 +2378,7 @@ def wrap(func: Callable):
 
 def _register_all_builtin_module():
-    for sub_mod in [M, M.qat, M.quantized]:
+    for sub_mod in [M, M.qat, M.quantized, MExternal]:
         for m in getmembers(sub_mod):
             if (
                 isclass(m[1])
@@ -126,10 +126,11 @@ def _check_obj_attr(obj):
     for _, v in obj.items():
         leafs, _ = tree_flatten(v, is_leaf=lambda _: True)
         for leaf in leafs:
-            assert _check_leaf_type(
-                leaf
-            ), "Type {} is not supported by traced module".format(
-                leaf if isinstance(leaf, type) else type(leaf)
-            )
+            tp = leaf if isinstance(leaf, type) else type(leaf)
+            assert _check_leaf_type(leaf), (
+                "Type {} is not supported in TracedModule serialization by default. "
+                "If you want to save this object to a file, please call "
+                "tm.register_supported_type({}) before saving.".format(tp, tp.__name__)
+            )
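As the new assertion message suggests, saving a TracedModule that carries an attribute of an unregistered type can be made to work by registering that type first. A minimal sketch, assuming `traced_net` is an already-traced module holding an attribute of a hypothetical `MyConfig` type:

    import megengine as mge
    import megengine.traced_module as tm

    class MyConfig:
        def __init__(self, alpha):
            self.alpha = alpha

    # no flatten/unflatten given, so MyConfig is registered as a leaf type
    tm.register_supported_type(MyConfig)

    mge.save(traced_net, "traced_net.pkl")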