| @@ -7,7 +7,7 @@ | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import builtins | |||
| import collections | |||
| from typing import Callable, List | |||
| @@ -19,6 +19,7 @@ from ...module import Module | |||
| from ...tensor import Tensor | |||
| from .module_tracer import active_module_tracer | |||
| from .node import ModuleNode, Node, NodeMixin, TensorNode | |||
| from .pytree import TreeDef | |||
| class Expr: | |||
| @@ -28,9 +29,22 @@ class Expr: | |||
| inputs = None # type: List[Node] | |||
| outputs = None # type: List[Node] | |||
| def add_input(self, node): | |||
| self.inputs.append(node) | |||
| const_val = None # type: List[Any] | |||
| arg_def = None # type: TreeDef | |||
| def add_inputs(self, vals): | |||
| if not isinstance(vals, collections.abc.Sequence): | |||
| vals = (vals,) | |||
| for val in vals: | |||
| node = NodeMixin.get(val, None) | |||
| if isinstance(node, (TensorNode, ModuleNode)): | |||
| if node not in self.inputs: | |||
| self.inputs.append(node) | |||
| else: | |||
| assert node is None | |||
| assert type(val) in builtins.__dict__.values() | |||
| idx = len(self.inputs) + len(self.const_val) | |||
| self.const_val.append((idx, val)) | |||
| def add_outputs(self, outputs): | |||
| self.outputs = [] | |||
| @@ -38,50 +52,31 @@ class Expr: | |||
| outputs = (outputs,) | |||
| for i in outputs: | |||
| assert isinstance(i, RawTensor) | |||
| self.outputs.append(NodeMixin.get_wrapped_type(i)(self)) | |||
| for i, node in zip(outputs, self.outputs,): | |||
| NodeMixin.wrap_safe(i, node) | |||
| @classmethod | |||
| def get_args_node(cls, arg): | |||
| """ | |||
| Create nodes by ``arg``, which may be a container. | |||
| Return the same structure with arg. | |||
| If ``arg`` was not Tensor or Module, it will be stored as const. | |||
| :param arg: tensor, module or const. | |||
| """ | |||
| if isinstance(arg, (RawTensor, Module)): | |||
| if not NodeMixin.get(arg, None): | |||
| NodeMixin.wrap_safe(arg, Constant.make(arg)) | |||
| return NodeMixin.get(arg) | |||
| elif isinstance(arg, collections.abc.Sequence): | |||
| seq_cls = type(arg) | |||
| return seq_cls([Expr.get_args_node(a) for a in arg]) | |||
| def unflatten_args(self, inputs): | |||
| if self.arg_def is not None: | |||
| inputs = list(inputs) | |||
| for idx, val in self.const_val: | |||
| inputs.insert(idx, val) | |||
| args, kwargs = self.arg_def.unflatten(inputs) | |||
| return args, kwargs | |||
| else: | |||
| # TODO: assert arg type | |||
| return arg # as const | |||
| return inputs, {} | |||
| @classmethod | |||
| def get_arg_value(cls, inp_node, node2value): | |||
| """ | |||
| Get values from node2value by inp_node, which may be a container. | |||
| Return the same structure with inp_node. | |||
| If ``inp_node`` was not in node2value, it is a const. | |||
| :param inp_node: nodes. | |||
| :param node2value: dict from node to tensor and module. | |||
| """ | |||
| if inp_node in node2value: | |||
| return node2value[inp_node] | |||
| elif isinstance(inp_node, collections.abc.Sequence): | |||
| seq_cls = type(inp_node) | |||
| return seq_cls([Expr.get_arg_value(i, node2value) for i in inp_node]) | |||
| else: | |||
| return inp_node | |||
| @property | |||
| def kwargs(self): | |||
| _, kwargs = self.unflatten_args(self.inputs) | |||
| return kwargs | |||
| @property | |||
| def args(self): | |||
| args, _ = self.unflatten_args(self.inputs) | |||
| return args | |||
| # expr: None (i.e. fake expression which is used to mark input) | |||
| @@ -144,16 +139,8 @@ class CallMethod(Expr): | |||
| self.inputs = [ | |||
| module, | |||
| ] | |||
| self.const_val = [] | |||
| self.method = method | |||
| self.arg_names = [] | |||
| self.kwargs = {} # const kwargs | |||
| def add_input(self, node, arg_name=None): | |||
| if arg_name == "self": # FIXME: <XP> | |||
| return | |||
| self.inputs.append(node) | |||
| if arg_name is not None: | |||
| self.arg_names.append(arg_name) | |||
| @classmethod | |||
| def make(cls, *args, **kwargs): | |||
| @@ -162,19 +149,22 @@ class CallMethod(Expr): | |||
| return expr | |||
| def interpret(self, *inputs): | |||
| mod = inputs[0] | |||
| args = inputs[1:] | |||
| outputs = getattr(mod, self.method)(*args, **self.kwargs) | |||
| args, kwargs = self.unflatten_args(inputs) | |||
| obj = args[0] | |||
| args = args[1:] | |||
| outputs = getattr(obj, self.method)(*args, **kwargs) | |||
| if isinstance(outputs, RawTensor): | |||
| outputs = (outputs,) | |||
| return outputs | |||
| def __repr__(self): | |||
| return "{} = CallMethod({}, {})({})".format( | |||
| args = ", ".join(str(i) for i in self.args[1:]) | |||
| kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items()) | |||
| return "{} = {}.{}({})".format( | |||
| ", ".join(str(i) for i in self.outputs), | |||
| self.inputs[0], | |||
| self.method, | |||
| ", ".join(str(i) for i in self.inputs[1:]), | |||
| ", ".join([args, kwargs]), | |||
| ) | |||
| @@ -227,13 +217,8 @@ class CallFunction(Expr): | |||
| def __init__(self, func): | |||
| assert isinstance(func, Callable) | |||
| self.func = func | |||
| self.const_val = [] | |||
| self.inputs = [] | |||
| self.arg_names = [] | |||
| self.kwargs = {} # const kwargs | |||
| def add_input(self, node, arg_name): | |||
| self.inputs.append(node) | |||
| self.arg_names.append(arg_name) | |||
| @classmethod | |||
| def make(cls, *args, **kwargs): | |||
| @@ -242,18 +227,20 @@ class CallFunction(Expr): | |||
| return expr | |||
| def interpret(self, *inputs): | |||
| inp_dict = dict([(name, node) for node, name in zip(inputs, self.arg_names)]) | |||
| outputs = self.func(**inp_dict, **self.kwargs) | |||
| args, kwargs = self.unflatten_args(inputs) | |||
| outputs = self.func(*args, **kwargs) | |||
| outputs = ( | |||
| outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,) | |||
| ) | |||
| return outputs | |||
| def __repr__(self): | |||
| args = ", ".join(str(i) for i in self.args) | |||
| kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items()) | |||
| return "{} = {}({})".format( | |||
| ", ".join(str(i) for i in self.outputs), | |||
| self.func.__module__ + "." + self.func.__name__, | |||
| ", ".join(str(i) for i in self.inputs), | |||
| ", ".join([args, kwargs]), | |||
| ) | |||
| @@ -15,6 +15,72 @@ from ...module import Module | |||
| _active_module_tracer = None | |||
# Names of tensor dunder / array methods that the tracer patches on
# ArrayMethodMixin (see Patcher) so that calls made on traced tensors are
# recorded as expressions instead of executing silently.
BUILTIN_ARRAY_METHOD = [
    # rich comparisons
    "__lt__",
    "__le__",
    "__gt__",
    "__ge__",
    "__eq__",
    "__ne__",
    # unary operators
    "__neg__",
    "__pos__",
    "__abs__",
    "__invert__",
    "__round__",
    "__floor__",
    "__ceil__",
    # binary operators
    "__add__",
    "__sub__",
    "__mul__",
    "__matmul__",
    "__truediv__",
    "__floordiv__",
    "__mod__",
    "__pow__",
    "__lshift__",
    "__rshift__",
    "__and__",
    "__or__",
    "__xor__",
    # reflected binary operators
    "__radd__",
    "__rsub__",
    "__rmul__",
    "__rmatmul__",
    "__rtruediv__",
    "__rfloordiv__",
    "__rmod__",
    "__rpow__",
    "__rlshift__",
    "__rrshift__",
    "__rand__",
    "__ror__",
    "__rxor__",
    # in-place operators
    "__iadd__",
    "__isub__",
    "__imul__",
    "__imatmul__",
    "__itruediv__",
    "__ifloordiv__",
    "__imod__",
    "__ipow__",
    "__ilshift__",
    "__irshift__",
    "__iand__",
    "__ior__",
    "__ixor__",
    # array-manipulation helpers
    "T",
    "astype",
    "reshape",
    "_broadcast",
    "transpose",
    "flatten",
    "sum",
    "prod",
    "min",
    "max",
    "mean",
]
def active_module_tracer():
    """Return the currently installed module tracer.

    Reads the module-level ``_active_module_tracer`` slot; ``None`` means no
    tracing session is active.
    """
    return _active_module_tracer
| @@ -108,9 +174,8 @@ class Patcher: | |||
| self.wrap_fn = wrap_fn | |||
| for module in self._builtin_modules: | |||
| self.patch_module(module) | |||
| for cls in self._builtin_methods: | |||
| self.patch_cls(cls) | |||
| for meth in BUILTIN_ARRAY_METHOD: | |||
| self.patch_method(ArrayMethodMixin, meth, self.wrap_fn) | |||
| for i, j in self._builtin_functions: | |||
| if id(i) not in self.visited_frames_ids: | |||
| @@ -13,6 +13,7 @@ import numpy | |||
| from ...core._imperative_rt.core2 import Tensor as RawTensor | |||
| from ...module import Module | |||
| from ...tensor import Tensor | |||
| from .pytree import TreeDef | |||
| class Node: | |||
| @@ -58,6 +59,7 @@ class ModuleNode(Node): | |||
| module_type = Module # type: Type[Module] | |||
| graph = None | |||
| attr_type_map = None # type: Dict[str, Type[Any]] | |||
| arg_def = None # type: TreeDef | |||
| def __repr__(self): | |||
| if self._name is None: | |||
| @@ -0,0 +1,80 @@ | |||
| from typing import Callable, NamedTuple | |||
# Registry mapping a container type to its (flatten, unflatten) callbacks.
SUPPORTED_TYPE = {}

# flatten(container)   -> (list_of_children, aux_data)
# unflatten(children, aux_data) -> container
NodeType = NamedTuple("NodeType", [("flatten", Callable), ("unflatten", Callable)])


def register_supported_type(type, flatten, unflatten):
    """Register *flatten*/*unflatten* callbacks for container *type*.

    ``flatten(value)`` must return ``(children, aux_data)`` and
    ``unflatten(children, aux_data)`` must rebuild the container.
    """
    SUPPORTED_TYPE[type] = NodeType(flatten, unflatten)


register_supported_type(list, lambda x: (x, None), lambda x, aux_data: list(x))
# NOTE(review): tuples are rebuilt as *lists* here, same as the original code —
# confirm this asymmetry is intended before changing it.
register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: list(x))
register_supported_type(
    dict, lambda x: (list(x.values()), list(x.keys())), lambda x, y: dict(zip(y, x))
)
register_supported_type(
    slice,
    lambda x: ([x.start, x.stop, x.step], None),
    lambda x, aux_data: slice(x[0], x[1], x[2]),
)
def tree_flatten(
    values, leaf_type: Callable = lambda x: type(x), is_leaf: Callable = lambda x: True
):
    """Flatten a (possibly nested) container into leaves plus a structure tag.

    :param values: value to flatten; any type registered in ``SUPPORTED_TYPE``
        is treated as a container, everything else as a leaf.
    :param leaf_type: callable mapping a leaf value to the type tag stored in
        its ``LeafDef``.
    :param is_leaf: predicate every unregistered value must satisfy; failing
        it aborts flattening.
    :return: ``(leaves, treedef)`` where ``treedef.unflatten(leaves)``
        reconstructs the original structure.
    """
    if type(values) not in SUPPORTED_TYPE:
        assert is_leaf(values)
        return [values], LeafDef(leaf_type(values))
    rst = []
    children_defs = []
    children_values, aux_data = SUPPORTED_TYPE[type(values)].flatten(values)
    for v in children_values:
        # BUGFIX: propagate ``is_leaf`` into the recursion; it was previously
        # dropped, so the predicate was only ever applied at the top level.
        v_list, treedef = tree_flatten(v, leaf_type, is_leaf)
        rst.extend(v_list)
        children_defs.append(treedef)
    return rst, TreeDef(type(values), aux_data, children_defs)
class TreeDef:
    """Structure descriptor produced by ``tree_flatten``.

    Records the container type, its auxiliary data (e.g. dict keys) and the
    descriptors of its children; ``unflatten`` inverts the flattening.
    """

    def __init__(self, type, aux_data, children_defs):
        self.type = type
        self.aux_data = aux_data
        self.children_defs = children_defs
        # total number of leaves in the whole subtree
        self.num_leaves = sum(child.num_leaves for child in children_defs)

    def unflatten(self, leaves):
        """Rebuild a container of this shape from the flat *leaves* sequence."""
        assert len(leaves) == self.num_leaves
        children = []
        offset = 0
        for child in self.children_defs:
            children.append(child.unflatten(leaves[offset : offset + child.num_leaves]))
            offset += child.num_leaves
        return SUPPORTED_TYPE[self.type].unflatten(children, self.aux_data)

    def __eq__(self, other):
        # NOTE: defining __eq__ without __hash__ leaves TreeDef unhashable,
        # matching the original behavior.
        return (
            self.type == other.type
            and self.aux_data == other.aux_data
            and self.num_leaves == other.num_leaves
            and self.children_defs == other.children_defs
        )

    def __repr__(self):
        return "{}[{}]".format(self.type.__name__, self.children_defs)
class LeafDef(TreeDef):
    """TreeDef for a single leaf.

    ``type`` is the expected leaf type — either a single type or, as produced
    by ``_leaf_type`` in the tracer, a tuple of acceptable types.
    """

    def __init__(self, type):
        super().__init__(type, None, [])
        self.num_leaves = 1

    def unflatten(self, leaves):
        assert len(leaves) == 1
        # isinstance natively accepts a tuple of types as its second argument.
        assert isinstance(leaves[0], self.type), self.type
        return leaves[0]

    def __repr__(self):
        # BUGFIX: ``self.type`` may be a tuple of types (see _leaf_type), and
        # tuples have no __name__ — the original crashed with AttributeError.
        if isinstance(self.type, tuple):
            name = "|".join(t.__name__ for t in self.type)
        else:
            name = self.type.__name__
        return "Leaf({})".format(name)
| @@ -9,9 +9,11 @@ | |||
| import collections | |||
| import copy | |||
| import functools | |||
| from inspect import getmembers, isclass, ismethod | |||
| from typing import List, Type | |||
| from ... import module as M | |||
| from ...core._imperative_rt.core2 import Tensor as RawTensor | |||
| from ...core._imperative_rt.core2 import ( | |||
| is_tracing_module, | |||
| set_module_tracing, | |||
| @@ -28,6 +30,16 @@ from .module_tracer import ( | |||
| set_active_module_tracer, | |||
| ) | |||
| from .node import ModuleNode, Node, NodeMixin, TensorNode | |||
| from .pytree import tree_flatten | |||
def _leaf_type(node):
    """Map a runtime value to the leaf-type tag used when flattening call args.

    Tensors and modules may appear either as concrete values or as their
    tracing-node stand-ins, so both forms are accepted in the returned tuple.
    """
    if isinstance(node, RawTensor):
        return (Tensor, TensorNode)
    if isinstance(node, (NodeMixin, Module)):
        return (Module, ModuleNode, NodeMixin)
    return type(node)
| class InternalGraph: | |||
| @@ -65,9 +77,7 @@ class InternalGraph: | |||
| for n, v in zip(self._inputs, inputs): | |||
| node2value[n] = v | |||
| for expr in self._exprs: | |||
| values = expr.interpret( | |||
| *list(Expr.get_arg_value(i, node2value) for i in expr.inputs) | |||
| ) | |||
| values = expr.interpret(*list(node2value[i] for i in expr.inputs)) | |||
| for n, v in zip(expr.outputs, values): | |||
| node2value[n] = v | |||
| return list(node2value[i] for i in self._outputs) | |||
| @@ -80,37 +90,39 @@ class InternalGraph: | |||
| ) | |||
def _get_meth_name(obj, func):
    """Return the attribute name under which *func* is defined on ``type(obj)``.

    Walks the MRO and compares each class-dict entry against *func*; returns
    ``None`` when no class in the hierarchy defines it.
    """
    for klass in type(obj).mro():
        for name, member in klass.__dict__.items():
            if member == func:
                return name
    return None
| def _wrapped_function(orig_func): | |||
| @functools.wraps(orig_func) | |||
| def wrapped_fn(*inputs, **kwargs): | |||
| def wrapped_fn(*args, **kwargs): | |||
| if is_tracing_module(): | |||
| unset_module_tracing() | |||
| const_kwargs = {} | |||
| arg_names = orig_func.__code__.co_varnames | |||
| if orig_func.__qualname__.split(".").__len__() > 1: | |||
| # FIXME: a robust way to distinguish method and function. <XP> | |||
| inputs, tree_def = tree_flatten((args, kwargs), leaf_type=_leaf_type) | |||
| for i in inputs: | |||
| if not NodeMixin.get(i, None): | |||
| if isinstance(i, (RawTensor, NodeMixin)): | |||
| NodeMixin.wrap_safe(i, Constant.make(i)) | |||
| meth_name = _get_meth_name(args[0], wrapped_fn) | |||
| if meth_name: | |||
| self = inputs[0] | |||
| call_node = CallMethod.make(NodeMixin.get(self), orig_func.__name__) | |||
| call_node = CallMethod.make(NodeMixin.get(self), meth_name) | |||
| else: | |||
| call_node = CallFunction.make(orig_func) | |||
| def add_input(inp, varname=None): | |||
| node = Expr.get_args_node(inp) | |||
| if node is not None: | |||
| call_node.add_input(node, varname) | |||
| else: | |||
| const_kwargs[varname] = inp | |||
| for ind, inp in enumerate(inputs): | |||
| add_input(inp, arg_names[ind]) | |||
| for k, v in kwargs.items(): | |||
| add_input(v, k) | |||
| call_node.kwargs = const_kwargs | |||
| outputs = orig_func(*inputs, **kwargs) | |||
| call_node.add_inputs(inputs) | |||
| call_node.arg_def = tree_def | |||
| outputs = orig_func(*args, **kwargs) | |||
| call_node.add_outputs(outputs) | |||
| set_module_tracing() | |||
| return outputs | |||
| return orig_func(*inputs, **kwargs) | |||
| return orig_func(*args, **kwargs) | |||
| return wrapped_fn | |||
| @@ -120,14 +132,14 @@ class TracedModuleBuilder(NodeMixin): | |||
| _mod = None # type: Module | |||
| _body = None # type: InternalGraph | |||
| _is_builtin = None # type: bool | |||
| _arg_def = None # type: TreeDef | |||
# Attribute names that belong to the builder itself and must NOT be copied
# onto the resulting TracedModule in build().
__builder_attributes__ = [
    "_mod",
    "_body",
    "_NodeMixin__node",
    "_is_builtin",
    "_is_traced",
    # BUGFIX: '"_arg_def" "build"' was implicit string concatenation, yielding
    # a single bogus entry "_arg_defbuild" so neither name was excluded.
    "_arg_def",
    "build",
]
| def __init__(self, mod): | |||
| @@ -146,6 +158,7 @@ class TracedModuleBuilder(NodeMixin): | |||
| node = NodeMixin.get(self) | |||
| node.graph = self._body | |||
| node.attr_type_map = {} | |||
| node.arg_def = self._arg_def | |||
| traced_module = TracedModule(node) | |||
| for k, v in self.__dict__.items(): | |||
| if k not in TracedModuleBuilder.__builder_attributes__: | |||
| @@ -155,32 +168,34 @@ class TracedModuleBuilder(NodeMixin): | |||
| traced_module.m_node.attr_type_map[k] = type(v) | |||
| return traced_module | |||
| def __call__(self, *inputs, **kwargs): | |||
| def __call__(self, *args, **kwargs): | |||
| assert isinstance(self._mod, Module) | |||
| for arg in args: | |||
| assert isinstance(arg, RawTensor) | |||
| for k, v in kwargs.items(): | |||
| assert isinstance(v, RawTensor) | |||
| # prepare args and kwargs for inner graph | |||
| def mark_constant(x): | |||
| node = NodeMixin.get(x, None) | |||
| if node is None: # capture as constant | |||
| NodeMixin.wrap(x, lambda: Constant.make(x)) | |||
| inputs, tree_def = tree_flatten(((self, *args), kwargs), leaf_type=_leaf_type) | |||
| if self._arg_def is None: | |||
| self._arg_def = tree_def | |||
| assert self._arg_def == tree_def | |||
| for i in inputs: | |||
| mark_constant(i) | |||
| for k, v in kwargs.items(): | |||
| mark_constant(v) | |||
| callnode = CallMethod.make(NodeMixin.get(self)) | |||
| def add_input(x): | |||
| callnode.add_input(NodeMixin.get(x)) | |||
| callnode.add_inputs(inputs) | |||
| for i in inputs: | |||
| add_input(i) | |||
| for k, v in kwargs.items(): | |||
| add_input(v) | |||
| callnode.arg_def = tree_def | |||
| if self._is_builtin or self._is_traced: | |||
| unset_module_tracing() | |||
| outputs = self._mod(*inputs, **kwargs) | |||
| outputs = self._mod(*args, **kwargs) | |||
| set_module_tracing() | |||
| if self._is_builtin: | |||
| self._body = None | |||
| @@ -193,23 +208,21 @@ class TracedModuleBuilder(NodeMixin): | |||
| ) | |||
| # prepare args and kwargs for inner graph | |||
| def wrap(x): | |||
| # wrapped = copy.copy(x) # FIXME | |||
| wrapped = x # FIXME: <XP> | |||
| wrapped = copy.copy(x) # FIXME | |||
| NodeMixin.wrap( | |||
| wrapped, | |||
| lambda: Input.make(type=NodeMixin.get_wrapped_type(wrapped)), | |||
| ) | |||
| return wrapped | |||
| args = [] | |||
| for i in inputs: | |||
| args = [self] | |||
| for i in inputs[1:]: | |||
| args.append(wrap(i)) | |||
| for k, v in kwargs.items(): | |||
| kwargs[k] = wrap(v) | |||
| args, kwargs = tree_def.unflatten(args) | |||
| active_module_tracer().patcher.auto_patch( | |||
| getattr(getattr(self._mod, "forward", self._mod), "__globals__", {}) | |||
| ) | |||
| outputs = type(self._mod).forward(self, *args, **kwargs) | |||
| outputs = type(self._mod).forward(*args, **kwargs) | |||
| for i in ( | |||
| outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,) | |||
| @@ -269,8 +282,10 @@ class TracedModule(Module): | |||
| super(TracedModule, self).__init__() | |||
| self.m_node = node | |||
| def forward(self, *inputs): | |||
| rst = self.m_node.graph.interpret(self, *inputs) | |||
| def forward(self, *args, **kwargs): | |||
| inputs, treedef = tree_flatten(((self, *args), kwargs), leaf_type=_leaf_type) | |||
| assert treedef == self.m_node.arg_def | |||
| rst = self.m_node.graph.interpret(*inputs) | |||
| if len(rst) == 1: | |||
| rst = rst[0] | |||
| return rst | |||
| @@ -345,7 +360,6 @@ def register_as_builtin(mod_cls: Type[Module]) -> None: | |||
| def _register_all_builtin_module(): | |||
| from inspect import getmembers, isclass | |||
| for sub_mod in [M, M.qat, M.quantized]: | |||
| for m in getmembers(sub_mod): | |||
| @@ -357,7 +371,7 @@ def _register_all_builtin_module(): | |||
| module_tracer.register_as_builtin(m[1]) | |||
| def trace_module(mod: Module, *inputs: Tensor, **kwargs: Tensor) -> TracedModule: | |||
| def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule: | |||
| """ | |||
| Traces module ``mod`` and returns corresponding TracedModule. | |||
| @@ -375,15 +389,13 @@ def trace_module(mod: Module, *inputs: Tensor, **kwargs: Tensor) -> TracedModule | |||
| builder = TracedModuleBuilder(mod) | |||
| NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode)) | |||
| inputs, _ = tree_flatten((args, kwargs)) | |||
| for _, i in enumerate(inputs): | |||
| NodeMixin.wrap_safe(i, Input.make("arg_{}".format(_))) | |||
| for k, v in kwargs.items(): | |||
| NodeMixin.wrap_safe(v, Input.make("kwarg_{}".format(k))) | |||
| builder(*inputs, **kwargs) | |||
| NodeMixin.wrap_safe( | |||
| i, Input.make("arg_{}".format(_), NodeMixin.get_wrapped_type(i)) | |||
| ) | |||
| builder(*args, **kwargs) | |||
| active_module_tracer().pop_scope() | |||
| return builder.build() | |||
| finally: | |||
| set_active_module_tracer(None) | |||