GitOrigin-RevId: ac0603057a
tags/v1.6.0
| @@ -9,6 +9,7 @@ | |||
| import builtins | |||
| import collections | |||
| import inspect | |||
| from typing import Callable, List | |||
| from ...core._imperative_rt import OpDef | |||
| @@ -16,10 +17,10 @@ from ...core._imperative_rt.core2 import Tensor as RawTensor | |||
| from ...core._imperative_rt.core2 import apply, set_module_tracing, unset_module_tracing | |||
| from ...core.ops.special import Const | |||
| from ...module import Module | |||
| from ...tensor import Tensor | |||
| from ...tensor import Parameter, Tensor | |||
| from .module_tracer import active_module_tracer, module_tracer | |||
| from .node import ModuleNode, Node, NodeMixin, TensorNode | |||
| from .pytree import TreeDef | |||
| from .pytree import TreeDef, tree_flatten | |||
| class Expr: | |||
| @@ -38,25 +39,28 @@ class Expr: | |||
| for val in vals: | |||
| node = NodeMixin.get(val, None) | |||
| if isinstance(node, (TensorNode, ModuleNode)): | |||
| if node not in self.inputs: | |||
| self.inputs.append(node) | |||
| self.inputs.append(node) | |||
| node.users.append(self) | |||
| else: | |||
| assert node is None | |||
| assert type(val) in builtins.__dict__.values() | |||
| idx = len(self.inputs) + len(self.const_val) | |||
| self.const_val.append((idx, val)) | |||
| def add_outputs(self, outputs): | |||
| def add_outputs(self, outputs, check_inplace=True): | |||
| self.outputs = [] | |||
| if not isinstance(outputs, collections.Sequence): | |||
| outputs = (outputs,) | |||
| if outputs is not None: | |||
| if not isinstance(outputs, collections.Sequence): | |||
| outputs = (outputs,) | |||
| for i in outputs: | |||
| assert isinstance(i, RawTensor) | |||
| self.outputs.append(NodeMixin.get_wrapped_type(i)(self)) | |||
| for i in outputs: | |||
| assert isinstance(i, RawTensor) | |||
| node = NodeMixin.get(i, None) if check_inplace else None | |||
| self.outputs.append( | |||
| node if node else NodeMixin.get_wrapped_type(i)(self) | |||
| ) | |||
| for i, node in zip(outputs, self.outputs,): | |||
| NodeMixin.wrap_safe(i, node) | |||
| for i, node in zip(outputs, self.outputs,): | |||
| NodeMixin.wrap_safe(i, node) | |||
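Note on the `check_inplace` flag added above: when an output tensor is already bound to a node (e.g. the target of an in-place `__setitem__`), that node is reused instead of minting a new one; only the `Tensor.__new__` path further down passes `check_inplace=False`. A minimal standalone sketch of that selection rule, with `FakeNode` as a hypothetical stand-in rather than a MegEngine class:

```python
class FakeNode:
    """Stand-in for TensorNode: records which expr produced it."""
    def __init__(self, expr):
        self.expr = expr

def pick_output_node(bound_node, expr, check_inplace=True):
    # Mirrors add_outputs: reuse the node already bound to the tensor when
    # check_inplace is on, otherwise create a fresh node owned by `expr`.
    node = bound_node if check_inplace else None
    return node if node else FakeNode(expr)

setitem_expr = object()
existing = FakeNode(expr=None)                       # tensor written in place
assert pick_output_node(existing, setitem_expr) is existing
assert pick_output_node(existing, setitem_expr, check_inplace=False) is not existing
```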
| def unflatten_args(self, inputs): | |||
| if self.arg_def is not None: | |||
| @@ -110,6 +114,7 @@ class GetAttr(Expr): | |||
| self.inputs = [ | |||
| module, | |||
| ] | |||
| module.users.append(self) | |||
| self.name = name | |||
| node_cls = type if type else Node | |||
| self.outputs = [ | |||
| @@ -134,12 +139,20 @@ class GetAttr(Expr): | |||
| # expr: outputs = inputs[0].__call__(*inputs[1:]) | |||
| class CallMethod(Expr): | |||
| def __init__(self, module, method="__call__"): | |||
| assert isinstance(module, (TensorNode, ModuleNode)) | |||
| self.inputs = [ | |||
| module, | |||
| ] | |||
| self.const_val = [] | |||
| def __init__(self, node, method="__call__"): | |||
| if isinstance(node, type): | |||
| assert issubclass(node, Tensor) | |||
| cls = Parameter if issubclass(node, Parameter) else Tensor | |||
| self.inputs = [] | |||
| self.const_val = [(0, cls)] | |||
| else: | |||
| assert isinstance(node, (TensorNode, ModuleNode)) | |||
| node.users.append(self) | |||
| self.inputs = [ | |||
| node, | |||
| ] | |||
| self.const_val = [] | |||
| self.method = method | |||
| @classmethod | |||
| @@ -160,10 +173,13 @@ class CallMethod(Expr): | |||
| def interpret(self, *inputs): | |||
| args, kwargs = self.unflatten_args(inputs) | |||
| obj = args[0] | |||
| args = args[1:] | |||
| meth = getattr(obj, self.method) | |||
| if inspect.ismethod(meth): | |||
| args = args[1:] | |||
| outputs = getattr(obj, self.method)(*args, **kwargs) | |||
| if isinstance(outputs, RawTensor): | |||
| outputs = (outputs,) | |||
| if outputs is None: | |||
| return outputs | |||
| outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor)) | |||
| return outputs | |||
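The `inspect.ismethod` check added to `interpret` exists because `args[0]` may now be the `Tensor` class itself (a traced `Tensor(...)` construction, see `CallMethod.__init__` above): a bound instance method already carries its receiver, while `__new__` is not bound and still needs the class as its first argument. A small standalone illustration with a toy class (not MegEngine code):

```python
import inspect

class Toy:
    def scale(self, k):          # ordinary instance method
        return k

obj = Toy()
# Bound method: the receiver is baked in, so interpret() drops args[0].
assert inspect.ismethod(obj.scale)
# __new__ is not a bound method: args[0] (the class) must be kept and passed through.
assert not inspect.ismethod(Toy.__new__)
assert Toy.__new__(Toy).__class__ is Toy
```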
| def __repr__(self): | |||
| @@ -171,7 +187,7 @@ class CallMethod(Expr): | |||
| kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items()) | |||
| return "{} = {}.{}({})".format( | |||
| ", ".join(str(i) for i in self.outputs), | |||
| self.inputs[0], | |||
| self.args[0], | |||
| self.method, | |||
| ", ".join([args, kwargs]), | |||
| ) | |||
| @@ -209,9 +225,8 @@ class Apply(Expr): | |||
| if node is None: # capture as constant | |||
| NodeMixin.wrap_safe(i, Constant.make(i)) | |||
| apply_node = cls.make(opdef) | |||
| for i in inputs: | |||
| assert isinstance(i, RawTensor) | |||
| apply_node.inputs.append(NodeMixin.get(i)) | |||
| apply_node.add_inputs(inputs) | |||
| assert not apply_node.const_val | |||
| unset_module_tracing() | |||
| outputs = apply(opdef, *inputs) | |||
| @@ -283,7 +298,7 @@ class Constant(Expr): | |||
| return (self.value,) | |||
| def __repr__(self): | |||
| return "{} = Constant({})".format(self.outputs[0], self.value) | |||
| return "{} = Constant({})".format(self.outputs[0], type(self.value)) | |||
| def __getstate__(self): | |||
| state = self.__dict__.copy() | |||
| @@ -79,6 +79,8 @@ BUILTIN_ARRAY_METHOD = [ | |||
| "min", | |||
| "max", | |||
| "mean", | |||
| "__getitem__", | |||
| "__setitem__", | |||
| ] | |||
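With `__getitem__`/`__setitem__` added to `BUILTIN_ARRAY_METHOD`, indexing and in-place assignment inside a forward are now traced; this is the pattern the `PreProcess` modules in the new tests rely on. A hedged sketch of that pattern, where `FillRow` is a hypothetical module written only for illustration:

```python
import megengine.functional as F
import megengine.module as M
from megengine.experimental.traced_module import trace_module

class FillRow(M.Module):
    def __init__(self):
        super().__init__()
        self.base = F.zeros((1,))

    def forward(self, x):
        buf = F.broadcast_to(self.base, (2, 2))
        buf[0] = x[0]          # __getitem__ and __setitem__ are both recorded
        return buf

tm = trace_module(FillRow(), F.ones((2, 2)))
```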
| @@ -176,7 +178,8 @@ class Patcher: | |||
| self.patch_module(module) | |||
| for meth in BUILTIN_ARRAY_METHOD: | |||
| self.patch_method(ArrayMethodMixin, meth, self.wrap_fn) | |||
| self.patch_method(Tensor, "detach", self.wrap_fn) | |||
| self.patch_method(Tensor, "__new__", self.wrap_fn) | |||
| for i, j in self._builtin_functions: | |||
| if id(i) not in self.visited_frames_ids: | |||
| self.patch_function(i, j, self.wrap_fn) | |||
| @@ -203,7 +206,13 @@ class Patcher: | |||
| import inspect | |||
| if id(module.__dict__) not in self.visited_frames_ids: | |||
| for k, v in module.__dict__.items(): | |||
| keys = ( | |||
| getattr(module, "__all__") | |||
| if hasattr(module, "__all__") | |||
| else module.__dict__.keys() | |||
| ) | |||
| for k in keys: | |||
| v = getattr(module, k) | |||
| if inspect.isfunction(v) and not k.startswith("_"): | |||
| self.patch_function(module.__dict__, k, self.wrap_fn) | |||
| self.visited_frames_ids.add(id(module.__dict__)) | |||
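`Patcher.patch_module` now respects a module's `__all__` when deciding which public functions to wrap. A standalone sketch of that selection rule on a hypothetical toy module (not MegEngine code):

```python
import inspect
import types

toy = types.ModuleType("toy")
exec(
    "def visible(x): return x\n"
    "def _hidden(x): return x\n"
    "__all__ = ['visible']\n",
    toy.__dict__,
)

# Same rule as patch_module: prefer __all__, fall back to __dict__, and only
# consider public functions.
keys = getattr(toy, "__all__") if hasattr(toy, "__all__") else toy.__dict__.keys()
to_patch = [k for k in keys if inspect.isfunction(getattr(toy, k)) and not k.startswith("_")]
assert to_patch == ["visible"]
```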
| @@ -6,7 +6,7 @@ | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from typing import Any, Dict, Tuple, Type | |||
| from typing import Any, Dict, List, Tuple, Type | |||
| import numpy | |||
| @@ -31,6 +31,7 @@ class Node: | |||
| def __init__(self, expr: "Expr", name: str = None): | |||
| self.expr = expr | |||
| self.users = [] # List[Expr] | |||
| self._id = Node.__total_id | |||
| Node.__total_id += 1 | |||
| self._name = name | |||
| @@ -59,11 +60,13 @@ class ModuleNode(Node): | |||
| module_type = Module # type: Type[Module] | |||
| attr_type_map = None # type: Dict[str, Type[Any]] | |||
| argdef_graph_map = None # type: Dict[TreeDef, "InternalGraph"] | |||
| argdef_outdef_map = None # type: Dict[TreeDef, TreeDef] | |||
| def __init__(self, expr: "Expr", name: str = None): | |||
| super().__init__(expr, name) | |||
| self.attr_type_map = {} | |||
| self.argdef_graph_map = {} | |||
| self.argdef_outdef_map = {} | |||
| def __repr__(self): | |||
| if self._name is None: | |||
| @@ -10,6 +10,8 @@ | |||
| import collections | |||
| from typing import Callable, NamedTuple | |||
| import numpy as np | |||
| SUPPORTED_TYPE = {} | |||
| NodeType = NamedTuple("NodeType", [("flatten", Callable), ("unflatten", Callable)]) | |||
| @@ -33,7 +35,7 @@ def _dict_unflatten(inps, aux_data): | |||
| register_supported_type(list, lambda x: (x, None), lambda x, aux_data: list(x)) | |||
| register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: list(x)) | |||
| register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: tuple(x)) | |||
| register_supported_type(dict, _dict_flatten, _dict_unflatten) | |||
| register_supported_type( | |||
| slice, | |||
| @@ -52,7 +54,10 @@ def tree_flatten( | |||
| assert is_leaf(values) | |||
| node = LeafDef(leaf_type(values)) | |||
| if is_const_leaf(values): | |||
| node.const_val = values | |||
| if isinstance(values, np.ndarray): | |||
| node.const_val = str(values) | |||
| else: | |||
| node.const_val = values | |||
| return [values,], node | |||
| rst = [] | |||
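For readers unfamiliar with the pytree helpers being changed here (note the fix above: `tuple` now unflattens back to `tuple` instead of `list`), this is a simplified standalone sketch of the flatten/unflatten contract behind `register_supported_type`; it is not the MegEngine implementation:

```python
def flatten(value):
    """Return (leaves, spec); spec records enough structure to rebuild value."""
    if isinstance(value, (list, tuple, dict)):
        items = list(value.values()) if isinstance(value, dict) else list(value)
        keys = list(value.keys()) if isinstance(value, dict) else None
        leaves, child_specs = [], []
        for item in items:
            sub_leaves, sub_spec = flatten(item)
            leaves += sub_leaves
            child_specs.append(sub_spec)
        return leaves, (type(value), keys, child_specs)
    return [value], None  # leaf

def unflatten(leaves, spec):
    """Consume leaves left-to-right and rebuild the container described by spec."""
    if spec is None:
        return leaves.pop(0)
    container_type, keys, child_specs = spec
    children = [unflatten(leaves, s) for s in child_specs]
    if container_type is dict:
        return dict(zip(keys, children))
    return container_type(children)   # list -> list, tuple -> tuple

leaves, spec = flatten({"a": (1, 2), "b": [3]})
assert leaves == [1, 2, 3]
assert unflatten(list(leaves), spec) == {"a": (1, 2), "b": [3]}
```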
| @@ -10,8 +10,13 @@ import collections | |||
| import copy | |||
| import functools | |||
| from inspect import getmembers, isclass, ismethod | |||
| from typing import Dict, List, Type | |||
| from typing import Callable, Dict, Iterable, List, Sequence, Type | |||
| import numpy as np | |||
| from numpy.lib.arraysetops import isin | |||
| from ... import functional as F | |||
| from ... import get_logger | |||
| from ... import module as M | |||
| from ...core._imperative_rt.core2 import Tensor as RawTensor | |||
| from ...core._imperative_rt.core2 import ( | |||
| @@ -19,6 +24,7 @@ from ...core._imperative_rt.core2 import ( | |||
| set_module_tracing, | |||
| unset_module_tracing, | |||
| ) | |||
| from ...core._trace_option import set_symbolic_shape | |||
| from ...core.tensor.array_method import ArrayMethodMixin | |||
| from ...module import Module | |||
| from ...tensor import Tensor | |||
| @@ -32,6 +38,8 @@ from .module_tracer import ( | |||
| from .node import ModuleNode, Node, NodeMixin, TensorNode | |||
| from .pytree import tree_flatten | |||
| logger = get_logger(__name__) | |||
| def _leaf_type(node): | |||
| if isinstance(node, RawTensor): | |||
| @@ -42,6 +50,11 @@ def _leaf_type(node): | |||
| return type(node) | |||
| def _is_leaf(node): | |||
| assert isinstance(node, RawTensor), type(node) | |||
| return isinstance(node, RawTensor) | |||
| def _is_const_leaf(node): | |||
| if isinstance(node, (RawTensor, NodeMixin, Module)): | |||
| return False | |||
| @@ -80,7 +93,13 @@ class InternalGraph: | |||
| @property | |||
| def exprs(self): | |||
| return _expr_list(self) | |||
| return ExprFilter(_expr_iter(self)) | |||
| def get_call_function(self, func: Callable = None): | |||
| return self.exprs.call_function(func) | |||
| def get_call_method(self, method: str = None): | |||
| return self.exprs.call_method(method) | |||
| def add_input(self, i): | |||
| self._inputs.append(i) | |||
| @@ -88,16 +107,131 @@ class InternalGraph: | |||
| def add_output(self, o): | |||
| self._outputs.append(o) | |||
| def get_dep_exprs(self, nodes: Sequence[Node]) -> List[Expr]: | |||
| if not isinstance(nodes, Sequence): | |||
| nodes = (nodes,) | |||
| ret = list() | |||
| queue = list(nodes) | |||
| while queue: | |||
| node = queue.pop() | |||
| expr = node.expr | |||
| if expr not in ret: | |||
| ret.append(expr) | |||
| for i in expr.inputs: | |||
| if i not in queue: | |||
| queue.append(i) | |||
| return ret | |||
| def insert_call_function(self, func: Callable, nodes: Sequence[Node]): | |||
| if not isinstance(nodes, Sequence): | |||
| nodes = [nodes] | |||
| assert isinstance(func, Callable) | |||
| for i in nodes: | |||
| assert isinstance( | |||
| i, TensorNode | |||
| ), "CallFunction only accept TensorNode as inputs" | |||
| expr = CallFunction(func) | |||
| expr.inputs = nodes | |||
| for i in nodes: | |||
| i.users.append(expr) | |||
| idx = max(self._exprs.index(i.expr) for i in nodes) + 1 | |||
| self._exprs.insert(idx, expr) | |||
| fake_inp_val = tuple(F.zeros(shape=i.shape, dtype=i.dtype) for i in nodes) | |||
| fake_out_val = func(*fake_inp_val) | |||
| def create_node(val: Tensor): | |||
| node = TensorNode(expr) | |||
| node.shape = val.shape | |||
| node.dtype = val.dtype | |||
| return node | |||
| out_nodes = list(create_node(i) for i in fake_out_val) | |||
| expr.outputs = out_nodes | |||
| return out_nodes | |||
| def insert_call_method(self, target, method, args): | |||
| if not isinstance(args, Sequence): | |||
| args = [args] | |||
| assert isinstance(target, (TensorNode, ModuleNode)) | |||
| assert isinstance(method, str) | |||
| for i in args: | |||
| assert isinstance(i, TensorNode) | |||
| expr = CallMethod(method) | |||
| expr.inputs = [target, *args] | |||
| if isinstance(target, TensorNode): | |||
| fake_target_val = F.zeros(shape=target.shape, dtype=target.dtype) | |||
| fake_inp_val = tuple(F.zeros(shape=i.shape, dtype=i.dtype) for i in args) | |||
| fake_out_val = getattr(fake_target_val, method)(fake_inp_val) | |||
| def create_node(val: Tensor): | |||
| node = TensorNode(expr) | |||
| node.shape = val.shape | |||
| node.dtype = val.dtype | |||
| return node | |||
| out_nodes = list(create_node(i) for i in fake_out_val) | |||
| expr.outputs = out_nodes | |||
| else: | |||
| raise NotImplementedError() | |||
| return out_nodes | |||
| def replace_node(self, repl_dict: Dict[Node, Node]): | |||
| while repl_dict: | |||
| node, repl_node = repl_dict.popitem() | |||
| # check graph inputs and outputs | |||
| assert node not in self.inputs, "Cannot replace inputs" | |||
| for i, n in enumerate(self.outputs): | |||
| if n is node: | |||
| self.outputs[i] = repl_node | |||
| # update users of node and repl_node | |||
| # update inputs of expr in node.users | |||
| dep_exprs = self.get_dep_exprs(repl_node) | |||
| i = 0 | |||
| while i < len(node.users): | |||
| n = node.users[i] | |||
| if n in dep_exprs: | |||
| logger.info("Find a loop: ignore this replacement once") | |||
| logger.info("node: %s" % node.__repr__()) | |||
| logger.info("repl_node: %s" % repl_node.__repr__()) | |||
| i += 1 | |||
| continue | |||
| repl_node.users.append(n) | |||
| node.users.pop(i) | |||
| idx = n.inputs.index(node) | |||
| n.inputs[idx] = repl_node | |||
| def compile(self): | |||
| """ | |||
| Delete unused exprs (those the graph outputs do not depend on). | |||
| """ | |||
| dep_exprs = self.get_dep_exprs(self.outputs) | |||
| i = 0 | |||
| while i < len(self._exprs): | |||
| expr = self._exprs[i] | |||
| if expr in dep_exprs: | |||
| i += 1 | |||
| continue | |||
| for n in expr.inputs: | |||
| n.users.remove(expr) | |||
| self._exprs.remove(expr) | |||
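The methods above compose into the rewrite workflow exercised by `test_insert`/`test_delete` further down: insert a new expr after an existing output node, redirect its users with `replace_node`, then drop dead exprs with `compile`. A hedged end-to-end sketch using a deliberately tiny hypothetical module (the tests below use `MyBlock` with conv/bn):

```python
import numpy as np
import megengine.functional as F
import megengine.module as M
from megengine.experimental.traced_module import trace_module

class TinyBlock(M.Module):
    def forward(self, x):
        return F.relu(x) + 1

inp = F.ones((2,))
tm = trace_module(TinyBlock(), inp)
graph = tm.graph

relu_out = graph.get_call_function(F.relu).as_unique().outputs  # [TensorNode]
neg_out = graph.insert_call_function(F.neg, relu_out)           # insert F.neg after relu
graph.replace_node({relu_out[0]: neg_out[0]})                   # feed neg into the add
graph.compile()                                                  # drop exprs no longer reachable

# relu(1) + 1 == 2 before the rewrite; -relu(1) + 1 == 0 afterwards.
np.testing.assert_allclose(tm(inp).numpy(), np.zeros((2,)), atol=1e-6)
```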
| def interpret(self, *inputs): | |||
| # TODO: support kwargs ? | |||
| # TODO: skip expressions which are independent and have no side effect | |||
| node2value = {} | |||
| for n, v in zip(self._inputs, inputs): | |||
| node2value[n] = v | |||
| for expr in self._exprs: | |||
| values = expr.interpret(*list(node2value[i] for i in expr.inputs)) | |||
| for n, v in zip(expr.outputs, values): | |||
| node2value[n] = v | |||
| if values is not None: | |||
| for n, v in zip(expr.outputs, values): | |||
| node2value[n] = v | |||
| return list(node2value[i] for i in self._outputs) | |||
| def __repr__(self): | |||
| @@ -109,7 +243,8 @@ class InternalGraph: | |||
| def _get_meth_name(obj, func): | |||
| for cls in type(obj).mro(): | |||
| tp = obj if isinstance(obj, type) else type(obj) | |||
| for cls in tp.mro(): | |||
| for k, v in cls.__dict__.items(): | |||
| if v == func: | |||
| return k | |||
| @@ -131,15 +266,31 @@ def _wrapped_function(orig_func): | |||
| meth_name = _get_meth_name(args[0], wrapped_fn) | |||
| if meth_name: | |||
| self = inputs[0] | |||
| call_node = CallMethod.make(NodeMixin.get(self), meth_name) | |||
| if meth_name == "__new__": | |||
| if all([not isinstance(i, RawTensor) for i in inputs]): | |||
| # only trace Tensor.__new__() when there are tensors in args | |||
| set_module_tracing() | |||
| return orig_func(*args, **kwargs) | |||
| if isinstance(args[1], RawTensor): | |||
| node = NodeMixin.get(inputs[1]) | |||
| inputs[1] = copy.copy(inputs[1]) | |||
| # copy inputs[1] so that tensor and Tensor(tensor) do not share the same m_tensor; otherwise they would share the same _NodeMixin__node during tracing. | |||
| NodeMixin.wrap_safe(inputs[1], node) | |||
| args, kwargs = tree_def.unflatten(inputs) | |||
| call_node = CallMethod.make(self, meth_name) | |||
| else: | |||
| call_node = CallMethod.make(NodeMixin.get(self), meth_name) | |||
| call_node.add_inputs(inputs[1:]) | |||
| else: | |||
| call_node = CallFunction.make(orig_func) | |||
| call_node.add_inputs(inputs) | |||
| call_node.add_inputs(inputs) | |||
| call_node.arg_def = tree_def | |||
| outputs = orig_func(*args, **kwargs) | |||
| call_node.add_outputs(outputs) | |||
| if meth_name == "__new__": | |||
| call_node.add_outputs(outputs, False) | |||
| else: | |||
| call_node.add_outputs(outputs) | |||
| set_module_tracing() | |||
| return outputs | |||
| return orig_func(*args, **kwargs) | |||
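The `__new__` branch above is what makes `Tensor(x)` construction inside a forward traceable (and only when a tensor actually appears in the arguments); the copy of `inputs[1]` keeps the new Tensor and its source from sharing a node. A hedged sketch in the spirit of `MyModule1` in the new tests, with `Wrap` as a hypothetical module:

```python
import megengine.module as M
from megengine import Tensor
from megengine.experimental.traced_module import trace_module

class Wrap(M.Module):
    def forward(self, x):
        y = Tensor(x)      # traced via the __new__ branch above
        return y + 1

tm = trace_module(Wrap(), Tensor(1))
assert int(tm(Tensor(2)).numpy()) == 3
```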
| @@ -197,13 +348,14 @@ class TracedModuleBuilder(NodeMixin): | |||
| mark_constant(i) | |||
| callnode = CallMethod.make(NodeMixin.get(self)) | |||
| callnode.add_inputs(inputs) | |||
| callnode.add_inputs(inputs[1:]) | |||
| callnode.arg_def = tree_def | |||
| if self._is_builtin: | |||
| unset_module_tracing() | |||
| outputs = self._mod(*args, **kwargs) | |||
| rst = self._mod(*args, **kwargs) | |||
| outputs, out_def = tree_flatten(rst, leaf_type=_leaf_type, is_leaf=_is_leaf) | |||
| set_module_tracing() | |||
| if self._is_builtin: | |||
| self._body = None | |||
| @@ -215,14 +367,13 @@ class TracedModuleBuilder(NodeMixin): | |||
| NodeMixin.wrap_safe( | |||
| self, Input.make("self", NodeMixin.get_wrapped_type(self)) | |||
| ) | |||
| origin_inp_node = [NodeMixin.get(i, None) for i in inputs[1:]] | |||
| # prepare args and kwargs for inner graph | |||
| def wrap(x): | |||
| wrapped = copy.copy(x) # FIXME | |||
| NodeMixin.wrap( | |||
| wrapped, | |||
| lambda: Input.make(type=NodeMixin.get_wrapped_type(wrapped)), | |||
| x, lambda: Input.make(type=NodeMixin.get_wrapped_type(x)), | |||
| ) | |||
| return wrapped | |||
| return x | |||
| args = [self] | |||
| for i in inputs[1:]: | |||
| @@ -231,21 +382,25 @@ class TracedModuleBuilder(NodeMixin): | |||
| active_module_tracer().patcher.auto_patch( | |||
| getattr(getattr(self._mod, "forward", self._mod), "__globals__", {}) | |||
| ) | |||
| outputs = type(self._mod).forward(*args, **kwargs) | |||
| rst = type(self._mod).forward(*args, **kwargs) | |||
| outputs, out_def = tree_flatten(rst, leaf_type=_leaf_type, is_leaf=_is_leaf) | |||
| for i in ( | |||
| outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,) | |||
| ): | |||
| active_module_tracer().current_scope().add_output(NodeMixin.get(i)) | |||
| NodeMixin.wrap_safe(self, orig_self) | |||
| for arg, node in zip(inputs[1:], origin_inp_node): | |||
| if node: | |||
| NodeMixin.wrap_safe(arg, node) | |||
| active_module_tracer().pop_scope() | |||
| # rebind output to outer graph | |||
| callnode.add_outputs(outputs) | |||
| self_node = NodeMixin.get(self) | |||
| self_node.argdef_graph_map[callnode.arg_def] = self._body | |||
| return outputs | |||
| self_node.argdef_outdef_map[callnode.arg_def] = out_def | |||
| return rst | |||
| def __getattr__(self, name): | |||
| if name not in self._mod.__dict__: | |||
| @@ -268,20 +423,29 @@ class TracedModuleBuilder(NodeMixin): | |||
| return super().__getattribute__(name) | |||
| else: | |||
| wrapped = super().__getattribute__(name) | |||
| if name in self._mod.__dict__ and not NodeMixin.get(wrapped, None): | |||
| assert not self._is_builtin | |||
| NodeMixin.wrap( | |||
| wrapped, | |||
| lambda: GetAttr.make( | |||
| if name in self._mod.__dict__: | |||
| if not NodeMixin.get(wrapped, None): | |||
| assert not self._is_builtin | |||
| NodeMixin.wrap( | |||
| wrapped, | |||
| lambda: GetAttr.make( | |||
| NodeMixin.get(self), | |||
| name, | |||
| type=NodeMixin.get_wrapped_type(wrapped), | |||
| ), | |||
| ) | |||
| else: | |||
| node = NodeMixin.get(wrapped) | |||
| expr = GetAttr.make( | |||
| NodeMixin.get(self), | |||
| name, | |||
| type=NodeMixin.get_wrapped_type(wrapped), | |||
| ), | |||
| ) | |||
| ).expr | |||
| expr.outputs[0] = node | |||
| return wrapped | |||
| class _expr_list: | |||
| class _expr_iter: | |||
| def __init__(self, graph: InternalGraph): | |||
| self.graph = graph | |||
| @@ -295,6 +459,59 @@ class _expr_list: | |||
| yield expr | |||
| class ExprFilter: | |||
| def __init__(self, expr_iter: Iterable): | |||
| self._iter = expr_iter | |||
| def __iter__(self): | |||
| return iter(self._iter) | |||
| def call_function(self, func): | |||
| return ExprFilterCallFunction(self, func) | |||
| def call_method(self, method): | |||
| return ExprFilterCallMethod(self, method) | |||
| def as_list(self): | |||
| return list(self) | |||
| def as_dict(self): | |||
| raise NotImplementedError("need key") | |||
| def as_unique(self): | |||
| (expr,) = self | |||
| return expr | |||
| def as_count(self): | |||
| return sum(1 for _ in self) | |||
| class ExprFilterCallFunction(ExprFilter): | |||
| def __init__(self, expr_iter, func: Callable = None): | |||
| super().__init__(expr_iter) | |||
| self.func = func | |||
| def __iter__(self): | |||
| for i in self._iter: | |||
| if not isinstance(i, CallFunction): | |||
| continue | |||
| if self.func is None or i.func == self.func: | |||
| yield i | |||
| class ExprFilterCallMethod(ExprFilter): | |||
| def __init__(self, expr_iter, method: str = None): | |||
| super().__init__(expr_iter) | |||
| self.method = method | |||
| def __iter__(self): | |||
| for i in self._iter: | |||
| if not isinstance(i, CallMethod): | |||
| continue | |||
| if self.method is None or i.method == self.method: | |||
| yield i | |||
| class TracedModule(Module): | |||
| """ | |||
| `TracedModule` is the Module created by tracing a normal module. It owns a ModuleNode (m_node). `TracedModule` cannot be called directly. It can be | |||
| @@ -312,10 +529,12 @@ class TracedModule(Module): | |||
| ((self, *args), kwargs), _leaf_type, is_const_leaf=_is_const_leaf | |||
| ) | |||
| assert treedef in self.m_node.argdef_graph_map | |||
| inputs = [i for i in inputs if isinstance(i, (Module, RawTensor))] | |||
| inputs = filter( | |||
| lambda i: isinstance(i, (Module, TracedModuleBuilder, RawTensor)), inputs | |||
| ) # allow TracedModuleBuilder for retrace. | |||
| outputs = self.m_node.argdef_graph_map[treedef].interpret(*inputs) | |||
| if len(outputs) == 1: | |||
| return outputs[0] | |||
| out_def = self.m_node.argdef_outdef_map[treedef] | |||
| outputs = out_def.unflatten(outputs) | |||
| return outputs | |||
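Because `__call__` now unflattens with the recorded `out_def`, a traced module returns the same container structure as the original module did rather than a flat list. A hedged sketch with a hypothetical dict-returning module (the second `test_preprocess` at the bottom traces a pipeline built around such a dict):

```python
import megengine.module as M
from megengine import Tensor
from megengine.experimental.traced_module import trace_module

class DictOut(M.Module):
    def forward(self, x):
        return {"doubled": x * 2, "same": x}

tm = trace_module(DictOut(), Tensor([1.0]))
out = tm(Tensor([3.0]))
# The dict structure survives the flatten/unflatten round trip.
assert sorted(out.keys()) == ["doubled", "same"]
assert float(out["doubled"].numpy()[0]) == 6.0
```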
| @property | |||
| @@ -339,9 +558,8 @@ class TracedModule(Module): | |||
| if graph is None: | |||
| assert not isinstance(module, TracedModule) | |||
| const = Constant(module) | |||
| modulenode = const.outputs[0] | |||
| modulenode.module_type = type(module) | |||
| call.inputs[0] = modulenode | |||
| const.outputs[0] = call.inputs[0] | |||
| const.outputs[0].expr = const | |||
| return [const, call] | |||
| exprs = [] | |||
| for expr in graph._exprs: | |||
| @@ -350,30 +568,41 @@ class TracedModule(Module): | |||
| if call and inp in graph._inputs: | |||
| inp_idx = graph._inputs.index(inp) | |||
| expr.inputs[idx] = call.inputs[inp_idx] | |||
| call.inputs[inp_idx].users.append(expr) | |||
| # replace outputs for submodule's expr | |||
| for idx, outp in enumerate(expr.outputs): | |||
| if call and outp in graph._outputs: | |||
| oup_idx = graph._outputs.index(outp) | |||
| expr.outputs[idx] = call.outputs[oup_idx] | |||
| call.outputs[oup_idx].expr = expr | |||
| if isinstance(expr, GetAttr): | |||
| # replace GetAttr with Constant | |||
| if isinstance(expr.outputs[0], TensorNode): | |||
| const = Constant(getattr(module, expr.name)) | |||
| const.outputs = expr.outputs | |||
| const.outputs[0].expr = const | |||
| exprs.append(const) | |||
| elif isinstance(expr, CallMethod): | |||
| obj_node = expr.inputs[0] | |||
| if isinstance(obj_node, ModuleNode): | |||
| assert isinstance(expr.inputs[0].expr, GetAttr) | |||
| (obj,) = expr.inputs[0].expr.interpret(module) | |||
| exprs.extend(_flatten_subgraph(expr.graph, obj, expr)) | |||
| pre_expr = expr.inputs[0].expr | |||
| if isinstance(pre_expr, GetAttr): | |||
| (obj,) = expr.inputs[0].expr.interpret(module) | |||
| exprs.extend(_flatten_subgraph(expr.graph, obj, expr)) | |||
| else: | |||
| # module has been replaced. | |||
| assert isinstance(pre_expr, Constant) | |||
| else: | |||
| exprs.append(expr) | |||
| else: | |||
| exprs.append(expr) | |||
| if call is not None: | |||
| for i in call.inputs: | |||
| i.users.remove(call) | |||
| return exprs | |||
| new_module.graph._exprs = _flatten_subgraph(new_module.graph, new_module) | |||
| @@ -422,22 +651,26 @@ def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule: | |||
| """ | |||
| assert active_module_tracer() is None | |||
| try: | |||
| use_sym_shape = set_symbolic_shape(True) | |||
| set_module_tracing() | |||
| set_active_module_tracer(module_tracer(_wrapped_function)) | |||
| with active_module_tracer().patcher: | |||
| global_scope = InternalGraph() | |||
| active_module_tracer().push_scope(global_scope) | |||
| builder = TracedModuleBuilder(mod, True) | |||
| NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode)) | |||
| inputs, _ = tree_flatten((args, kwargs)) | |||
| inputs, _ = tree_flatten((args, kwargs), is_const_leaf=_is_const_leaf) | |||
| for _, i in enumerate(inputs): | |||
| NodeMixin.wrap_safe( | |||
| i, Input.make("arg_{}".format(_), NodeMixin.get_wrapped_type(i)) | |||
| ) | |||
| if isinstance(i, RawTensor): | |||
| NodeMixin.wrap_safe( | |||
| i, Input.make("arg_{}".format(_), NodeMixin.get_wrapped_type(i)) | |||
| ) | |||
| builder(*args, **kwargs) | |||
| active_module_tracer().pop_scope() | |||
| return builder.build() | |||
| finally: | |||
| set_symbolic_shape(use_sym_shape) | |||
| set_active_module_tracer(None) | |||
| unset_module_tracing() | |||
| @@ -0,0 +1,90 @@ | |||
| import io | |||
| import pickle | |||
| import numpy as np | |||
| import megengine.functional as F | |||
| import megengine.module as M | |||
| import megengine.utils.comp_graph_tools as cgtools | |||
| from megengine.core._trace_option import set_symbolic_shape | |||
| from megengine.experimental.traced_module import trace_module | |||
| from megengine.jit import trace | |||
| set_symbolic_shape(True) | |||
| class Main(M.Module): | |||
| def forward(self, x): | |||
| return x | |||
| class PreProcess(M.Module): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.I = F.ones((1,)) | |||
| self.M = F.zeros((1,)) | |||
| def forward(self, data, idx, roi): | |||
| N, H, W, C = data.shape | |||
| xmax = roi[:, 1, 0] | |||
| xmin = roi[:, 0, 0] | |||
| ymax = roi[:, 1, 1] | |||
| ymin = roi[:, 0, 1] | |||
| scale = F.maximum((xmax - xmin) / W, (ymax - ymin) / H) | |||
| I = F.broadcast_to(self.I, (N,)) | |||
| M = F.broadcast_to(self.M, (N, 3, 3)) | |||
| M[:, 0, 0] = scale | |||
| M[:, 0, 2] = xmin | |||
| M[:, 1, 1] = scale | |||
| M[:, 1, 2] = ymin | |||
| M[:, 2, 2] = I | |||
| resized = ( | |||
| F.warp_perspective( | |||
| data, M, (H, W), mat_idx=idx, border_mode="CONSTANT", format="NHWC" | |||
| ) | |||
| .transpose(0, 3, 1, 2) | |||
| .astype(np.float32) | |||
| ) | |||
| return resized | |||
| class Net(M.Module): | |||
| def __init__(self, traced_module): | |||
| super().__init__() | |||
| self.pre_process = PreProcess() | |||
| self.traced_module = traced_module | |||
| def forward(self, data, idx, roi): | |||
| x = self.pre_process(data, idx, roi) | |||
| x = self.traced_module(x) | |||
| return x | |||
| def test_preprocess(): | |||
| module = Main() | |||
| data = F.ones((1, 14, 8, 8), dtype=np.uint8) | |||
| traced_module = trace_module(module, data) | |||
| obj = pickle.dumps(traced_module) | |||
| traced_module = pickle.loads(obj) | |||
| module = Net(traced_module) | |||
| module.eval() | |||
| idx = F.zeros((1,), dtype=np.int32) | |||
| roi = F.ones((1, 2, 2), dtype=np.float32) | |||
| y = module(data, idx, roi) | |||
| traced_module = trace_module(module, data, idx, roi) | |||
| np.testing.assert_array_equal(traced_module(data, idx, roi), y) | |||
| func = trace(traced_module, capture_as_const=True) | |||
| np.testing.assert_array_equal(func(data, idx, roi), y) | |||
| model = io.BytesIO() | |||
| func.dump(model, arg_names=("data", "idx", "roi")) | |||
| model.seek(0) | |||
| infer_cg = cgtools.GraphInference(model) | |||
| np.testing.assert_allclose( | |||
| list( | |||
| infer_cg.run( | |||
| inp_dict={"data": data.numpy(), "idx": idx.numpy(), "roi": roi.numpy()} | |||
| ).values() | |||
| )[0], | |||
| y, | |||
| atol=1e-6, | |||
| ) | |||
| @@ -0,0 +1,113 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import numpy as np | |||
| import megengine.functional as F | |||
| import megengine.module as M | |||
| from megengine.experimental.traced_module import trace_module | |||
| from megengine.experimental.traced_module.expr import CallFunction, GetAttr | |||
| class MyBlock(M.Module): | |||
| def __init__(self, in_channels=3, channels=3): | |||
| super(MyBlock, self).__init__() | |||
| self.conv1 = M.Conv2d(in_channels, channels, 3, 1, padding=1, bias=False) | |||
| self.bn1 = M.BatchNorm2d(channels) | |||
| def forward(self, x): | |||
| x = self.conv1(x) | |||
| x = self.bn1(x) | |||
| x = F.relu(x) + 1 | |||
| return x | |||
| class MyModule(M.Module): | |||
| def __init__(self): | |||
| super(MyModule, self).__init__() | |||
| self.block0 = MyBlock() | |||
| self.block1 = MyBlock() | |||
| def forward(self, x): | |||
| x = self.block0(x) | |||
| x = self.block1(x) | |||
| return x | |||
| def _init_cls(cls): | |||
| module = cls() | |||
| x = F.ones((1, 3, 3, 3)) | |||
| y = module(x) | |||
| traced_module = trace_module(module, x) | |||
| return traced_module, x, y | |||
| def _init_block(): | |||
| return _init_cls(MyBlock) | |||
| def _init_module(): | |||
| return _init_cls(MyModule) | |||
| def test_search(): | |||
| traced_module, *_ = _init_block() | |||
| graph = traced_module.graph | |||
| relu_expr = graph.get_call_function(F.relu).as_unique() | |||
| assert isinstance(relu_expr, CallFunction) and relu_expr.func == F.relu | |||
| def test_insert(): | |||
| traced_module, x, expect = _init_block() | |||
| graph = traced_module.graph | |||
| relu_node = graph.get_call_function(F.relu).as_unique().outputs | |||
| neg_node = graph.insert_call_function(F.neg, relu_node) | |||
| graph.replace_node({relu_node[0]: neg_node[0]}) | |||
| graph.compile() | |||
| np.testing.assert_allclose(expect - 1, 1 - traced_module(x), atol=1e-6) | |||
| def test_delete(): | |||
| traced_module, x, expect = _init_block() | |||
| graph = traced_module.graph | |||
| relu_expr = graph.get_call_function(F.relu).as_unique() | |||
| node = relu_expr.outputs | |||
| repl_node = relu_expr.inputs | |||
| graph.replace_node({node[0]: repl_node[0]}) | |||
| graph.compile() | |||
| np.testing.assert_allclose(expect - 1, F.relu(traced_module(x) - 1), atol=1e-6) | |||
| def test_flatten(): | |||
| traced_module, x, expect = _init_module() | |||
| traced_module = traced_module.flatten() | |||
| traced_module.graph.compile() | |||
| assert all(not isinstance(i, GetAttr) for i in traced_module.graph._exprs) | |||
| assert len(traced_module.graph._exprs) == 12 | |||
| def test_extra_block(): | |||
| class PostProcess(M.Module): | |||
| def forward(self, x): | |||
| return x * 2 | |||
| class Net(M.Module): | |||
| def __init__(self, traced_module): | |||
| super().__init__() | |||
| self.post_process = PostProcess() | |||
| self.traced_module = traced_module | |||
| def forward(self, x): | |||
| x = self.traced_module(x) | |||
| x = self.post_process(x) | |||
| return x | |||
| traced_module, x, expect = _init_block() | |||
| module = Net(traced_module) | |||
| np.testing.assert_allclose(2 * expect, module(x), atol=1e-6) | |||
| traced_module = trace_module(module, x) | |||
| np.testing.assert_allclose(2 * expect, traced_module(x), atol=1e-6) | |||
| @@ -0,0 +1,52 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import pickle | |||
| import numpy as np | |||
| import megengine.functional as F | |||
| import megengine.module as M | |||
| from megengine import Tensor | |||
| from megengine.experimental.traced_module import trace_module | |||
| from megengine.module import Module | |||
| class MyBlock(Module): | |||
| def __init__(self, in_channels, channels): | |||
| super(MyBlock, self).__init__() | |||
| self.conv1 = M.Conv2d(in_channels, channels, 3, 1, padding=1, bias=False) | |||
| self.bn1 = M.BatchNorm2d(channels) | |||
| def forward(self, x): | |||
| x = self.conv1(x) | |||
| x = self.bn1(x) | |||
| x = F.relu(x) + 1 | |||
| return x | |||
| class MyModule(Module): | |||
| def __init__(self): | |||
| super(MyModule, self).__init__() | |||
| self.block0 = MyBlock(8, 4) | |||
| self.block1 = MyBlock(4, 2) | |||
| def forward(self, x): | |||
| x = self.block0(x) | |||
| x = self.block1(x) | |||
| return x | |||
| def test_dump_and_load(): | |||
| module = MyModule() | |||
| x = Tensor(np.ones((1, 8, 14, 14))) | |||
| expect = module(x) | |||
| traced_module = trace_module(module, x) | |||
| np.testing.assert_array_equal(expect, traced_module(x)) | |||
| obj = pickle.dumps(traced_module) | |||
| pickle.loads(obj) | |||
| np.testing.assert_array_equal(expect, traced_module(x)) | |||
| @@ -0,0 +1,42 @@ | |||
| import numpy as np | |||
| from megengine import Tensor | |||
| from megengine.experimental.traced_module import trace_module | |||
| from megengine.module import Module as M | |||
| class MyModule1(M): | |||
| def forward(self, x): | |||
| y = Tensor(x) | |||
| y += 1 | |||
| x = x + 2 | |||
| return x, y | |||
| class MyModule2(M): | |||
| def forward(self, x): | |||
| y = Tensor([1, x, 1]) | |||
| y += 1 | |||
| x = x + 2 | |||
| return x, y | |||
| def test_trace_module(): | |||
| x = Tensor(1) | |||
| m1 = MyModule1() | |||
| tm1 = trace_module(m1, x) | |||
| m2 = MyModule2() | |||
| tm2 = trace_module(m2, x) | |||
| inp = Tensor(2) | |||
| gt = m1(inp) | |||
| output = tm1(inp) | |||
| for a, b in zip(output, gt): | |||
| np.testing.assert_equal(a.numpy(), b.numpy()) | |||
| gt1 = m2(inp) | |||
| output1 = tm2(inp) | |||
| for a, b in zip(output1, gt1): | |||
| np.testing.assert_equal(a.numpy(), b.numpy()) | |||
| @@ -0,0 +1,94 @@ | |||
| import io | |||
| import pickle | |||
| import numpy as np | |||
| import megengine as mge | |||
| import megengine.functional as F | |||
| import megengine.module as M | |||
| import megengine.utils.comp_graph_tools as cgtools | |||
| from megengine.core._trace_option import set_symbolic_shape | |||
| from megengine.experimental.traced_module import trace_module | |||
| from megengine.jit import trace | |||
| set_symbolic_shape(True) | |||
| class Main(M.Module): | |||
| def forward(self, x): | |||
| return x["data"] | |||
| class PreProcess(M.Module): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.A = F.zeros((1,)) | |||
| self.I = F.ones((1,)) | |||
| self.bb_out = mge.tensor( | |||
| np.array([[[0, 0], [160, 0], [160, 48], [0, 48]]], dtype="float32") | |||
| ) | |||
| def forward(self, data, quad): | |||
| """ | |||
| data: (1, 3, 48, 160) | |||
| quad: (1, 4, 2) | |||
| """ | |||
| N = quad.shape[0] | |||
| dst = F.repeat(self.bb_out, N, axis=0).reshape(-1, 4, 2) | |||
| I = F.broadcast_to(self.I, quad.shape) | |||
| A = F.broadcast_to(self.A, (N, 8, 8)) | |||
| A[:, 0:4, 0:2] = quad | |||
| A[:, 4:8, 5:6] = I[:, :, 0:1] | |||
| A[:, 0:4, 6:8] = -quad * dst[:, :, 0:1] | |||
| A[:, 4:8, 3:5] = quad | |||
| A[:, 0:4, 2:3] = I[:, :, 0:1] | |||
| A[:, 4:8, 6:8] = -quad * dst[:, :, 1:2] | |||
| B = dst.transpose(0, 2, 1).reshape(-1, 8, 1) | |||
| M = F.concat([F.matmul(F.matinv(A), B)[:, :, 0], I[:, 0:1, 0]], axis=1).reshape( | |||
| -1, 3, 3 | |||
| ) | |||
| new_data = F.warp_perspective(data, M, (48, 160)) # (N, 3, 48, 160) | |||
| return {"data": new_data} | |||
| class Net(M.Module): | |||
| def __init__(self, traced_module): | |||
| super().__init__() | |||
| self.pre_process = PreProcess() | |||
| self.traced_module = traced_module | |||
| def forward(self, data, quad): | |||
| x = self.pre_process(data, quad) | |||
| x = self.traced_module(x) | |||
| return x | |||
| def test_preprocess(): | |||
| batch_size = 2 | |||
| module = Main() | |||
| data = mge.tensor( | |||
| np.random.randint(0, 256, size=(batch_size, 3, 48, 160)), dtype=np.float32 | |||
| ) | |||
| traced_module = trace_module(module, {"data": data}) | |||
| obj = pickle.dumps(traced_module) | |||
| traced_module = pickle.loads(obj) | |||
| module = Net(traced_module) | |||
| module.eval() | |||
| quad = mge.tensor(np.random.normal(size=(batch_size, 4, 2)), dtype=np.float32) | |||
| expect = module(data, quad) | |||
| traced_module = trace_module(module, data, quad) | |||
| actual = traced_module(data, quad) | |||
| for i, j in zip(expect, actual): | |||
| np.testing.assert_array_equal(i, j) | |||
| func = trace(traced_module, capture_as_const=True) | |||
| actual = func(data, quad) | |||
| for i, j in zip(expect, actual): | |||
| np.testing.assert_array_equal(i, j) | |||
| model = io.BytesIO() | |||
| func.dump(model, arg_names=("data", "quad")) | |||
| model.seek(0) | |||
| infer_cg = cgtools.GraphInference(model) | |||
| actual = list( | |||
| infer_cg.run(inp_dict={"data": data.numpy(), "quad": quad.numpy()}).values() | |||
| )[0] | |||
| np.testing.assert_allclose(expect, actual) | |||