GitOrigin-RevId: ad2cdc1b61
tags/v1.6.0
| @@ -9,12 +9,13 @@ | |||
| import collections | |||
| from typing import List | |||
| from typing import Callable, List | |||
| from ...core._imperative_rt import OpDef | |||
| from ...core._imperative_rt.core2 import Tensor as RawTensor | |||
| from ...core._imperative_rt.core2 import apply, set_module_tracing, unset_module_tracing | |||
| from ...core.ops.special import Const | |||
| from ...module import Module | |||
| from ...tensor import Tensor | |||
| from .module_tracer import active_module_tracer | |||
| from .node import ModuleNode, Node, NodeMixin, TensorNode | |||
| @@ -22,12 +23,66 @@ from .node import ModuleNode, Node, NodeMixin, TensorNode | |||
class Expr:
    """
    ``Expr`` represents the operations(i.e. CallMethod, CallFunction, Apply, GetAttr, Input, Constant) on ``Node``.
    """

    inputs = None  # type: List[Node]
    outputs = None  # type: List[Node]

    def add_input(self, node):
        """Append ``node`` to the inputs of this expression."""
        self.inputs.append(node)

    def add_outputs(self, outputs):
        """
        Create one output node per produced value and bind each value to its node.

        :param outputs: a single value, or a sequence of values, produced by this expression.
        """
        # ``collections.Sequence`` was removed in Python 3.10; the ABC lives in ``collections.abc``.
        if not isinstance(outputs, collections.abc.Sequence):
            outputs = (outputs,)
        self.outputs = [NodeMixin.get_wrapped_type(i)(self) for i in outputs]
        for i, node in zip(outputs, self.outputs):
            NodeMixin.wrap_safe(i, node)

    @staticmethod
    def _is_container(obj):
        """Return True for sequences that should be traversed element-wise."""
        # str/bytes are sequences of themselves: recursing into them never terminates.
        return isinstance(obj, collections.abc.Sequence) and not isinstance(
            obj, (str, bytes, bytearray)
        )

    @staticmethod
    def _rebuild(template, items):
        """Build a container of the same type as ``template`` holding ``items``."""
        seq_cls = type(template)
        if hasattr(seq_cls, "_fields"):
            # namedtuple constructors take their fields positionally, not one iterable
            return seq_cls(*items)
        return seq_cls(items)

    @classmethod
    def get_args_node(cls, arg):
        """
        Create nodes by ``arg``, which may be a container.
        Return the same structure with arg.

        If ``arg`` was not Tensor or Module, it will be stored as const.

        :param arg: tensor, module or const.
        """
        if isinstance(arg, (RawTensor, Module)):
            if not NodeMixin.get(arg, None):
                NodeMixin.wrap_safe(arg, Constant.make(arg))
            return NodeMixin.get(arg)
        if cls._is_container(arg):
            return cls._rebuild(arg, [cls.get_args_node(a) for a in arg])
        # TODO: assert arg type
        return arg  # as const

    @classmethod
    def get_arg_value(cls, inp_node, node2value):
        """
        Get values from node2value by inp_node, which may be a container.
        Return the same structure with inp_node.

        If ``inp_node`` was not in node2value, it is a const.

        :param inp_node: nodes.
        :param node2value: dict from node to tensor and module.
        """
        try:
            if inp_node in node2value:
                return node2value[inp_node]
        except TypeError:
            # unhashable (e.g. a list of nodes, or a dict const): cannot be a key
            pass
        if cls._is_container(inp_node):
            return cls._rebuild(
                inp_node, [cls.get_arg_value(i, node2value) for i in inp_node]
            )
        return inp_node
| # expr: None (i.e. fake expression which is used to mark input) | |||
| class Input(Expr): | |||
| @@ -83,23 +138,22 @@ class GetAttr(Expr): | |||
| # expr: outputs = getattr(inputs[0], method)(*inputs[1:], **kwargs); method defaults to "__call__" | |||
| class Call(Expr): | |||
| def __init__(self, module): | |||
| assert isinstance(module, ModuleNode) | |||
| class CallMethod(Expr): | |||
| def __init__(self, module, method="__call__"): | |||
| assert isinstance(module, (TensorNode, ModuleNode)) | |||
| self.inputs = [ | |||
| module, | |||
| ] | |||
| self.method = method | |||
| self.arg_names = [] | |||
| self.kwargs = {} # const kwargs | |||
| def add_input(self, node): | |||
| def add_input(self, node, arg_name=None): | |||
| if arg_name == "self": # FIXME: <XP> | |||
| return | |||
| self.inputs.append(node) | |||
| def add_outputs(self, references): | |||
| self.outputs = [] | |||
| if not isinstance(references, collections.Sequence): | |||
| references = (references,) | |||
| for i in references: | |||
| self.outputs.append(NodeMixin.get_wrapped_type(i)(self)) | |||
| if arg_name is not None: | |||
| self.arg_names.append(arg_name) | |||
| @classmethod | |||
| def make(cls, *args, **kwargs): | |||
| @@ -110,15 +164,16 @@ class Call(Expr): | |||
def interpret(self, *inputs):
    """
    Call ``self.method`` on the first input, passing the remaining inputs.

    The constant keyword arguments stored in ``self.kwargs`` are forwarded too.

    :param inputs: the receiver followed by the positional arguments.
    :return: the call result; a single ``RawTensor`` is wrapped into a 1-tuple.
    """
    mod = inputs[0]
    args = inputs[1:]
    # Single call through the recorded method name; the receiver must not be
    # invoked a second time via plain ``mod(*args)``.
    outputs = getattr(mod, self.method)(*args, **self.kwargs)
    if isinstance(outputs, RawTensor):
        outputs = (outputs,)
    return outputs
| def __repr__(self): | |||
| return "{} = Call({})({})".format( | |||
| return "{} = CallMethod({}, {})({})".format( | |||
| ", ".join(str(i) for i in self.outputs), | |||
| self.inputs[0], | |||
| self.method, | |||
| ", ".join(str(i) for i in self.inputs[1:]), | |||
| ) | |||
| @@ -132,17 +187,6 @@ class Apply(Expr): | |||
| self.opdef = opdef | |||
| self.inputs = [] | |||
| def add_input(self, node): | |||
| self.inputs.append(node) | |||
| def add_outputs(self, references): | |||
| self.outputs = [] | |||
| if not isinstance(references, collections.Sequence): | |||
| references = (references,) | |||
| for i in references: | |||
| self.outputs.append(NodeMixin.get_wrapped_type(i)(self)) | |||
| @classmethod | |||
| def make(cls, *args, **kwargs): | |||
| expr = cls(*args, **kwargs) | |||
| @@ -179,6 +223,40 @@ class Apply(Expr): | |||
| return list(outputs) | |||
class CallFunction(Expr):
    """
    Expression recording a call to a plain function.

    Inputs are stored together with the argument name each one was passed as;
    ``kwargs`` holds constant keyword arguments.
    """

    def __init__(self, func):
        assert isinstance(func, Callable)
        self.inputs = []
        self.arg_names = []
        self.kwargs = {}  # const kwargs
        self.func = func

    def add_input(self, node, arg_name):
        """Record ``node`` as the input passed under the name ``arg_name``."""
        self.inputs.append(node)
        self.arg_names.append(arg_name)

    @classmethod
    def make(cls, *args, **kwargs):
        """Build the expression and insert it into the tracer's current scope."""
        expr = cls(*args, **kwargs)
        scope = active_module_tracer().current_scope()
        scope.insert(expr)
        return expr

    def interpret(self, *inputs):
        """
        Call the recorded function with ``inputs`` bound to the recorded names.

        :param inputs: values in the same order as the recorded inputs.
        :return: the result if it is a sequence, otherwise a 1-tuple holding it.
        """
        named_inputs = {}
        for value, name in zip(inputs, self.arg_names):
            named_inputs[name] = value
        result = self.func(**named_inputs, **self.kwargs)
        if isinstance(result, collections.abc.Sequence):
            return result
        return (result,)

    def __repr__(self):
        rendered_outputs = ", ".join(str(i) for i in self.outputs)
        qualified_name = self.func.__module__ + "." + self.func.__name__
        rendered_inputs = ", ".join(str(i) for i in self.inputs)
        return "{} = {}({})".format(rendered_outputs, qualified_name, rendered_inputs)
| # expr outputs = self.value | |||
| class Constant(Expr): | |||
| value = None | |||
| @@ -6,7 +6,11 @@ | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import collections | |||
| from ... import Tensor | |||
| from ... import functional as F | |||
| from ...core.tensor.array_method import ArrayMethodMixin | |||
| from ...module import Module | |||
| _active_module_tracer = None | |||
| @@ -23,12 +27,14 @@ def set_active_module_tracer(tracer): | |||
| class module_tracer: | |||
| # builtin types | |||
| _opaque_types = set() | |||
| _active_scopes = None | |||
| def __init__(self): | |||
| def __init__(self, wrap_fn): | |||
| self._active_scopes = [] | |||
| self.patcher = Patcher(wrap_fn) | |||
| @classmethod | |||
| def register_as_builtin(cls, mod): | |||
| @@ -50,3 +56,105 @@ class module_tracer: | |||
| if self._active_scopes: | |||
| return self._active_scopes[-1] | |||
| return None | |||
class PatchedFn:
    """
    Handle on one named entry of a namespace, remembering its original value.

    The namespace is either a mapping (looked up by key) or any other object
    (looked up by attribute).
    """

    frame_dict = None
    name = None
    origin_fn = None

    def __init__(self, frame_dict, name):
        self.frame_dict = frame_dict
        self.name = name
        if isinstance(frame_dict, collections.abc.Mapping):
            self.origin_fn = self.frame_dict[name]
        else:
            self.origin_fn = getattr(frame_dict, name)

    def set_func(self, func):
        """Store ``func`` under the remembered name in the namespace."""
        target = self.frame_dict
        if not isinstance(target, collections.abc.Mapping):
            setattr(target, self.name, func)
            return
        target[self.name] = func
class Patcher:
    """
    Replace functions in modules, classes and frame globals with wrapped versions.

    Every entry of ``_builtin_modules``, ``_builtin_methods`` and
    ``_builtin_functions`` is patched when the instance is constructed.
    Leaving the ``with`` block restores all original functions.

    :param wrap_fn: callable taking the original function and returning its replacement.
    """

    # ids of the original (unwrapped) functions; class-level, so shared by all instances
    patched_fn_ids = set()
    _builtin_functions = []
    _builtin_modules = [
        F,
        F.distributed,
        F.elemwise,
        F.inplace,
        F.loss,
        F.math,
        F.metric,
        F.nn,
        F.quantized,
        F.tensor,
        F.utils,
        F.vision,
    ]
    _builtin_methods = [
        Tensor,
        ArrayMethodMixin,
    ]

    def __init__(self, wrap_fn):
        self.patched_fn = []
        self.visited_frames_ids = set()
        self.wrap_fn = wrap_fn
        for module in self._builtin_modules:
            self.patch_module(module)
        for cls in self._builtin_methods:
            self.patch_cls(cls)
        for frame, fn_name in self._builtin_functions:
            if id(frame) not in self.visited_frames_ids:
                self.patch_function(frame, fn_name, self.wrap_fn)

    def patch_function(self, frame_dict, fn, wrap_fn):
        """
        Replace entry ``fn`` of ``frame_dict`` with ``wrap_fn(original)``.

        :param frame_dict: mapping or object that owns the function.
        :param fn: name of the function inside ``frame_dict``.
        :param wrap_fn: callable producing the replacement from the original.
        """
        patched_fn = PatchedFn(frame_dict, fn)
        self.patched_fn_ids.add(id(patched_fn.origin_fn))
        patched_fn.set_func(wrap_fn(patched_fn.origin_fn))
        self.patched_fn.append(patched_fn)

    def patch_method(self, cls, name, wrap_fn):
        """Replace method ``name`` of ``cls``; same as :meth:`patch_function`."""
        self.patch_function(cls, name, wrap_fn)

    @staticmethod
    def _public_function_names(namespace):
        """Return the names of the plain functions in ``namespace`` not starting with ``_``."""
        import inspect

        # Snapshot the items so that patching never reassigns entries of a
        # namespace while its live view is being iterated.
        return [
            name
            for name, value in list(namespace.items())
            if inspect.isfunction(value) and not name.startswith("_")
        ]

    def patch_cls(self, cls):
        """Wrap every public plain function defined directly on ``cls`` (once per class)."""
        if id(cls) in self.visited_frames_ids:
            return
        for name in self._public_function_names(cls.__dict__):
            self.patch_function(cls, name, self.wrap_fn)
        self.visited_frames_ids.add(id(cls))

    def patch_module(self, module):
        """Wrap every public plain function in ``module.__dict__`` (once per module)."""
        namespace = module.__dict__
        if id(namespace) in self.visited_frames_ids:
            return
        for name in self._public_function_names(namespace):
            self.patch_function(namespace, name, self.wrap_fn)
        self.visited_frames_ids.add(id(namespace))

    def auto_patch(self, frame_dict):
        """
        Wrap the entries of ``frame_dict`` that refer to an already patched original.

        :param frame_dict: mapping from names to objects, e.g. a function's ``__globals__``.
        """
        if id(frame_dict) in self.visited_frames_ids:
            return
        for name, value in list(frame_dict.items()):
            if id(value) in self.patched_fn_ids:
                self.patch_function(frame_dict, name, self.wrap_fn)
        self.visited_frames_ids.add(id(frame_dict))

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore in reverse order of patching so the earliest original wins.
        while self.patched_fn:
            pf = self.patched_fn.pop()
            pf.set_func(pf.origin_fn)
        self.visited_frames_ids.clear()
| @@ -34,6 +34,10 @@ class Node: | |||
| Node.__total_id += 1 | |||
| self._name = name | |||
def __setstate__(self, d):
    """
    Restore the instance from the state dict ``d``.

    The class-wide id counter is then set to one more than the larger of its
    current value and the restored ``_id``.
    """
    self.__dict__ = d
    highest_seen = max(Node.__total_id, self._id)
    Node.__total_id = highest_seen + 1
| def __repr__(self): | |||
| if self._name is None: | |||
| return "%{}".format(self._id) | |||
| @@ -8,14 +8,25 @@ | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import collections | |||
| import copy | |||
| import functools | |||
| from typing import List, Type | |||
| from ... import module as M | |||
| from ...core._imperative_rt.core2 import set_module_tracing, unset_module_tracing | |||
| from ...core._imperative_rt.core2 import ( | |||
| is_tracing_module, | |||
| set_module_tracing, | |||
| unset_module_tracing, | |||
| ) | |||
| from ...core.tensor.array_method import ArrayMethodMixin | |||
| from ...module import Module | |||
| from ...tensor import Tensor | |||
| from .expr import Apply, Call, Constant, Expr, GetAttr, Input | |||
| from .module_tracer import active_module_tracer, module_tracer, set_active_module_tracer | |||
| from .expr import Apply, CallFunction, CallMethod, Constant, Expr, GetAttr, Input | |||
| from .module_tracer import ( | |||
| Patcher, | |||
| active_module_tracer, | |||
| module_tracer, | |||
| set_active_module_tracer, | |||
| ) | |||
| from .node import ModuleNode, Node, NodeMixin, TensorNode | |||
| @@ -54,7 +65,9 @@ class InternalGraph: | |||
| for n, v in zip(self._inputs, inputs): | |||
| node2value[n] = v | |||
| for expr in self._exprs: | |||
| values = expr.interpret(*list(node2value[i] for i in expr.inputs)) | |||
| values = expr.interpret( | |||
| *list(Expr.get_arg_value(i, node2value) for i in expr.inputs) | |||
| ) | |||
| for n, v in zip(expr.outputs, values): | |||
| node2value[n] = v | |||
| return list(node2value[i] for i in self._outputs) | |||
| @@ -67,6 +80,41 @@ class InternalGraph: | |||
| ) | |||
def _wrapped_function(orig_func):
    """
    Wrap ``orig_func`` so that calls made while module tracing is active are recorded.

    Outside tracing the wrapper simply forwards to ``orig_func``. During tracing
    it creates a ``CallMethod`` (dotted ``__qualname__``) or ``CallFunction``
    expression, registers the arguments as inputs, runs ``orig_func`` with
    tracing switched off, and registers the result as the expression's outputs.

    :param orig_func: the function or method to wrap.
    :return: the wrapping function, carrying ``orig_func``'s metadata.
    """

    @functools.wraps(orig_func)
    def wrapped_fn(*inputs, **kwargs):
        if not is_tracing_module():
            return orig_func(*inputs, **kwargs)

        # Tracing is switched off while recording and running ``orig_func``
        # (presumably so nested calls are not recorded one by one); the
        # ``finally`` guarantees it is switched back on even if the call raises.
        unset_module_tracing()
        try:
            const_kwargs = {}
            arg_names = orig_func.__code__.co_varnames
            if "." in orig_func.__qualname__:
                # FIXME: a robust way to distinguish method and function. <XP>
                receiver = inputs[0]
                call_node = CallMethod.make(
                    NodeMixin.get(receiver), orig_func.__name__
                )
            else:
                call_node = CallFunction.make(orig_func)

            def add_input(inp, varname=None):
                node = Expr.get_args_node(inp)
                if node is not None:
                    call_node.add_input(node, varname)
                else:
                    # ``None`` arguments are kept as constant keyword arguments
                    const_kwargs[varname] = inp

            for ind, inp in enumerate(inputs):
                add_input(inp, arg_names[ind])
            for k, v in kwargs.items():
                add_input(v, k)
            call_node.kwargs = const_kwargs

            outputs = orig_func(*inputs, **kwargs)
            call_node.add_outputs(outputs)
            return outputs
        finally:
            set_module_tracing()

    return wrapped_fn
| class TracedModuleBuilder(NodeMixin): | |||
| _mod = None # type: Module | |||
| @@ -120,7 +168,7 @@ class TracedModuleBuilder(NodeMixin): | |||
| mark_constant(i) | |||
| for k, v in kwargs.items(): | |||
| mark_constant(v) | |||
| callnode = Call.make(NodeMixin.get(self)) | |||
| callnode = CallMethod.make(NodeMixin.get(self)) | |||
| def add_input(x): | |||
| callnode.add_input(NodeMixin.get(x)) | |||
| @@ -145,7 +193,8 @@ class TracedModuleBuilder(NodeMixin): | |||
| ) | |||
| # prepare args and kwargs for inner graph | |||
| def wrap(x): | |||
| wrapped = copy.copy(x) # FIXME | |||
| # wrapped = copy.copy(x) # FIXME | |||
| wrapped = x # FIXME: <XP> | |||
| NodeMixin.wrap( | |||
| wrapped, | |||
| lambda: Input.make(type=NodeMixin.get_wrapped_type(wrapped)), | |||
| @@ -157,7 +206,9 @@ class TracedModuleBuilder(NodeMixin): | |||
| args.append(wrap(i)) | |||
| for k, v in kwargs.items(): | |||
| kwargs[k] = wrap(v) | |||
| active_module_tracer().patcher.auto_patch( | |||
| getattr(getattr(self._mod, "forward", self._mod), "__globals__", {}) | |||
| ) | |||
| outputs = type(self._mod).forward(self, *args, **kwargs) | |||
| for i in ( | |||
| @@ -171,11 +222,6 @@ class TracedModuleBuilder(NodeMixin): | |||
| # rebind output to outer graph | |||
| callnode.add_outputs(outputs) | |||
| for i, node in zip( | |||
| outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,), | |||
| callnode.outputs, | |||
| ): | |||
| NodeMixin.wrap_safe(i, node) | |||
| return outputs | |||
| def __getattr__(self, name): | |||
| @@ -229,6 +275,55 @@ class TracedModule(Module): | |||
| rst = rst[0] | |||
| return rst | |||
@property
def all_exprs(self):
    """
    Visit all ``Expr``s in the graph recursively.

    Expressions of traced submodules are inlined in place of the call that
    invokes them. NOTE: the visited expressions are modified in place (their
    inputs/outputs are rebound to the nodes of the calling graph).

    :return: List[Expr]
    """
    in_nodes = [i.expr for i in self.m_node.graph._inputs if i is not self]

    def _position(nodes, target):
        # Same matching rule as the ``in`` operator: identity first, then equality.
        for pos, candidate in enumerate(nodes):
            if candidate is target or candidate == target:
                return pos
        return None

    def _rebind(slots, inner_nodes, outer_nodes):
        # Replace every slot holding the k-th inner graph node with the k-th
        # node of the call expression. The position inside the graph's node
        # list, not the position inside the expression, selects the outer node.
        for idx, node in enumerate(slots):
            pos = _position(inner_nodes, node)
            if pos is not None and pos < len(outer_nodes):
                slots[idx] = outer_nodes[pos]

    def _flatten_submodule(module, call=None):
        if not isinstance(module, TracedModule):
            # leaf module: keep the call, with the module object as its receiver
            call.inputs[0] = module
            return (call,)

        exprs = []
        graph = module.m_node.graph
        for expr in graph._exprs:
            if call:
                # replace inputs for submodule's expr
                _rebind(expr.inputs, graph._inputs, call.inputs)
                # replace outputs for submodule's expr
                _rebind(expr.outputs, graph._outputs, call.outputs)

            if isinstance(expr, GetAttr):
                # replace GetAttr with Constant
                if isinstance(expr.outputs[0], TensorNode):
                    const = Constant(getattr(module, expr.name))
                    const.outputs = expr.outputs
                    exprs.append(const)
            elif isinstance(expr, CallMethod):
                obj_node = expr.inputs[0]
                if isinstance(obj_node, ModuleNode):
                    (obj,) = expr.inputs[0].expr.interpret(module)
                    exprs.extend(_flatten_submodule(obj, expr))
                else:
                    exprs.append(expr)
            else:
                exprs.append(expr)
        return exprs

    return in_nodes + _flatten_submodule(self)
| def __getstate__(self): | |||
| d = self.__dict__ | |||
| for k in Module.__dict__: | |||
| @@ -273,23 +368,23 @@ def trace_module(mod: Module, *inputs: Tensor, **kwargs: Tensor) -> TracedModule | |||
| assert active_module_tracer() is None | |||
| try: | |||
| set_module_tracing() | |||
| set_active_module_tracer(module_tracer()) | |||
| global_scope = InternalGraph() | |||
| active_module_tracer().push_scope(global_scope) | |||
| set_active_module_tracer(module_tracer(_wrapped_function)) | |||
| with active_module_tracer().patcher: | |||
| global_scope = InternalGraph() | |||
| active_module_tracer().push_scope(global_scope) | |||
| builder = TracedModuleBuilder(mod) | |||
| NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode)) | |||
| builder = TracedModuleBuilder(mod) | |||
| NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode)) | |||
| for _, i in enumerate(inputs): | |||
| NodeMixin.wrap_safe(i, Input.make("arg_{}".format(_))) | |||
| for k, v in kwargs.items(): | |||
| NodeMixin.wrap_safe(v, Input.make("kwarg_{}".format(k))) | |||
| for _, i in enumerate(inputs): | |||
| NodeMixin.wrap_safe(i, Input.make("arg_{}".format(_))) | |||
| for k, v in kwargs.items(): | |||
| NodeMixin.wrap_safe(v, Input.make("kwarg_{}".format(k))) | |||
| builder(*inputs, **kwargs) | |||
| active_module_tracer().pop_scope() | |||
| builder(*inputs, **kwargs) | |||
| active_module_tracer().pop_scope() | |||
| return builder.build() | |||
| return builder.build() | |||
| finally: | |||
| set_active_module_tracer(None) | |||
| unset_module_tracing() | |||