GitOrigin-RevId: ad2cdc1b61
tags/v1.6.0
| @@ -9,12 +9,13 @@ | |||||
| import collections | import collections | ||||
| from typing import List | |||||
| from typing import Callable, List | |||||
| from ...core._imperative_rt import OpDef | from ...core._imperative_rt import OpDef | ||||
| from ...core._imperative_rt.core2 import Tensor as RawTensor | from ...core._imperative_rt.core2 import Tensor as RawTensor | ||||
| from ...core._imperative_rt.core2 import apply, set_module_tracing, unset_module_tracing | from ...core._imperative_rt.core2 import apply, set_module_tracing, unset_module_tracing | ||||
| from ...core.ops.special import Const | from ...core.ops.special import Const | ||||
| from ...module import Module | |||||
| from ...tensor import Tensor | from ...tensor import Tensor | ||||
| from .module_tracer import active_module_tracer | from .module_tracer import active_module_tracer | ||||
| from .node import ModuleNode, Node, NodeMixin, TensorNode | from .node import ModuleNode, Node, NodeMixin, TensorNode | ||||
| @@ -22,12 +23,66 @@ from .node import ModuleNode, Node, NodeMixin, TensorNode | |||||
| class Expr: | class Expr: | ||||
| """ | """ | ||||
| ``Expr`` represents the operations(i.e. Call, Apply, GetAttr, Input, Constant) on ``Node``. | |||||
| ``Expr`` represents the operations(i.e. CallMethod, CallFunction, Apply, GetAttr, Input, Constant) on ``Node``. | |||||
| """ | """ | ||||
| inputs = None # type: List[Node] | inputs = None # type: List[Node] | ||||
| outputs = None # type: List[Node] | outputs = None # type: List[Node] | ||||
| def add_input(self, node): | |||||
| self.inputs.append(node) | |||||
| def add_outputs(self, outputs): | |||||
| self.outputs = [] | |||||
| if not isinstance(outputs, collections.Sequence): | |||||
| outputs = (outputs,) | |||||
| for i in outputs: | |||||
| self.outputs.append(NodeMixin.get_wrapped_type(i)(self)) | |||||
| for i, node in zip(outputs, self.outputs,): | |||||
| NodeMixin.wrap_safe(i, node) | |||||
| @classmethod | |||||
| def get_args_node(cls, arg): | |||||
| """ | |||||
| Create nodes by ``arg``, which may be a container. | |||||
| Return the same structure with arg. | |||||
| If ``arg`` was not Tensor or Module, it will be stored as const. | |||||
| :param arg: tensor, module or const. | |||||
| """ | |||||
| if isinstance(arg, (RawTensor, Module)): | |||||
| if not NodeMixin.get(arg, None): | |||||
| NodeMixin.wrap_safe(arg, Constant.make(arg)) | |||||
| return NodeMixin.get(arg) | |||||
| elif isinstance(arg, collections.abc.Sequence): | |||||
| seq_cls = type(arg) | |||||
| return seq_cls([Expr.get_args_node(a) for a in arg]) | |||||
| else: | |||||
| # TODO: assert arg type | |||||
| return arg # as const | |||||
| @classmethod | |||||
| def get_arg_value(cls, inp_node, node2value): | |||||
| """ | |||||
| Get values from node2value by inp_node, which may be a container. | |||||
| Return the same structure with inp_node. | |||||
| If ``inp_node`` was not in node2value, it is a const. | |||||
| :param inp_node: nodes. | |||||
| :param node2value: dict from node to tensor and module. | |||||
| """ | |||||
| if inp_node in node2value: | |||||
| return node2value[inp_node] | |||||
| elif isinstance(inp_node, collections.abc.Sequence): | |||||
| seq_cls = type(inp_node) | |||||
| return seq_cls([Expr.get_arg_value(i, node2value) for i in inp_node]) | |||||
| else: | |||||
| return inp_node | |||||
| # expr: None (i.e. fake expression which is used to mark input) | # expr: None (i.e. fake expression which is used to mark input) | ||||
| class Input(Expr): | class Input(Expr): | ||||
| @@ -83,23 +138,22 @@ class GetAttr(Expr): | |||||
| # expr: outputs = inputs[0].__call__(*inputs[1:]) | # expr: outputs = inputs[0].__call__(*inputs[1:]) | ||||
| class Call(Expr): | |||||
| def __init__(self, module): | |||||
| assert isinstance(module, ModuleNode) | |||||
| class CallMethod(Expr): | |||||
| def __init__(self, module, method="__call__"): | |||||
| assert isinstance(module, (TensorNode, ModuleNode)) | |||||
| self.inputs = [ | self.inputs = [ | ||||
| module, | module, | ||||
| ] | ] | ||||
| self.method = method | |||||
| self.arg_names = [] | |||||
| self.kwargs = {} # const kwargs | |||||
| def add_input(self, node): | |||||
| def add_input(self, node, arg_name=None): | |||||
| if arg_name == "self": # FIXME: <XP> | |||||
| return | |||||
| self.inputs.append(node) | self.inputs.append(node) | ||||
| def add_outputs(self, references): | |||||
| self.outputs = [] | |||||
| if not isinstance(references, collections.Sequence): | |||||
| references = (references,) | |||||
| for i in references: | |||||
| self.outputs.append(NodeMixin.get_wrapped_type(i)(self)) | |||||
| if arg_name is not None: | |||||
| self.arg_names.append(arg_name) | |||||
| @classmethod | @classmethod | ||||
| def make(cls, *args, **kwargs): | def make(cls, *args, **kwargs): | ||||
| @@ -110,15 +164,16 @@ class Call(Expr): | |||||
| def interpret(self, *inputs): | def interpret(self, *inputs): | ||||
| mod = inputs[0] | mod = inputs[0] | ||||
| args = inputs[1:] | args = inputs[1:] | ||||
| outputs = mod(*args) | |||||
| outputs = getattr(mod, self.method)(*args, **self.kwargs) | |||||
| if isinstance(outputs, RawTensor): | if isinstance(outputs, RawTensor): | ||||
| outputs = (outputs,) | outputs = (outputs,) | ||||
| return outputs | return outputs | ||||
| def __repr__(self): | def __repr__(self): | ||||
| return "{} = Call({})({})".format( | |||||
| return "{} = CallMethod({}, {})({})".format( | |||||
| ", ".join(str(i) for i in self.outputs), | ", ".join(str(i) for i in self.outputs), | ||||
| self.inputs[0], | self.inputs[0], | ||||
| self.method, | |||||
| ", ".join(str(i) for i in self.inputs[1:]), | ", ".join(str(i) for i in self.inputs[1:]), | ||||
| ) | ) | ||||
| @@ -132,17 +187,6 @@ class Apply(Expr): | |||||
| self.opdef = opdef | self.opdef = opdef | ||||
| self.inputs = [] | self.inputs = [] | ||||
| def add_input(self, node): | |||||
| self.inputs.append(node) | |||||
| def add_outputs(self, references): | |||||
| self.outputs = [] | |||||
| if not isinstance(references, collections.Sequence): | |||||
| references = (references,) | |||||
| for i in references: | |||||
| self.outputs.append(NodeMixin.get_wrapped_type(i)(self)) | |||||
| @classmethod | @classmethod | ||||
| def make(cls, *args, **kwargs): | def make(cls, *args, **kwargs): | ||||
| expr = cls(*args, **kwargs) | expr = cls(*args, **kwargs) | ||||
| @@ -179,6 +223,40 @@ class Apply(Expr): | |||||
| return list(outputs) | return list(outputs) | ||||
| class CallFunction(Expr): | |||||
| def __init__(self, func): | |||||
| assert isinstance(func, Callable) | |||||
| self.func = func | |||||
| self.inputs = [] | |||||
| self.arg_names = [] | |||||
| self.kwargs = {} # const kwargs | |||||
| def add_input(self, node, arg_name): | |||||
| self.inputs.append(node) | |||||
| self.arg_names.append(arg_name) | |||||
| @classmethod | |||||
| def make(cls, *args, **kwargs): | |||||
| expr = cls(*args, **kwargs) | |||||
| active_module_tracer().current_scope().insert(expr) | |||||
| return expr | |||||
| def interpret(self, *inputs): | |||||
| inp_dict = dict([(name, node) for node, name in zip(inputs, self.arg_names)]) | |||||
| outputs = self.func(**inp_dict, **self.kwargs) | |||||
| outputs = ( | |||||
| outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,) | |||||
| ) | |||||
| return outputs | |||||
| def __repr__(self): | |||||
| return "{} = {}({})".format( | |||||
| ", ".join(str(i) for i in self.outputs), | |||||
| self.func.__module__ + "." + self.func.__name__, | |||||
| ", ".join(str(i) for i in self.inputs), | |||||
| ) | |||||
| # expr outputs = self.value | # expr outputs = self.value | ||||
| class Constant(Expr): | class Constant(Expr): | ||||
| value = None | value = None | ||||
| @@ -6,7 +6,11 @@ | |||||
| # Unless required by applicable law or agreed to in writing, | # Unless required by applicable law or agreed to in writing, | ||||
| # software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| import collections | |||||
| from ... import Tensor | |||||
| from ... import functional as F | |||||
| from ...core.tensor.array_method import ArrayMethodMixin | |||||
| from ...module import Module | from ...module import Module | ||||
| _active_module_tracer = None | _active_module_tracer = None | ||||
| @@ -23,12 +27,14 @@ def set_active_module_tracer(tracer): | |||||
| class module_tracer: | class module_tracer: | ||||
| # builtin types | |||||
| _opaque_types = set() | _opaque_types = set() | ||||
| _active_scopes = None | _active_scopes = None | ||||
| def __init__(self): | |||||
| def __init__(self, wrap_fn): | |||||
| self._active_scopes = [] | self._active_scopes = [] | ||||
| self.patcher = Patcher(wrap_fn) | |||||
| @classmethod | @classmethod | ||||
| def register_as_builtin(cls, mod): | def register_as_builtin(cls, mod): | ||||
| @@ -50,3 +56,105 @@ class module_tracer: | |||||
| if self._active_scopes: | if self._active_scopes: | ||||
| return self._active_scopes[-1] | return self._active_scopes[-1] | ||||
| return None | return None | ||||
| class PatchedFn: | |||||
| frame_dict = None | |||||
| name = None | |||||
| origin_fn = None | |||||
| def __init__(self, frame_dict, name): | |||||
| self.frame_dict = frame_dict | |||||
| self.name = name | |||||
| self.origin_fn = ( | |||||
| self.frame_dict[name] | |||||
| if isinstance(frame_dict, collections.abc.Mapping) | |||||
| else getattr(frame_dict, name) | |||||
| ) | |||||
| def set_func(self, func): | |||||
| if isinstance(self.frame_dict, collections.abc.Mapping): | |||||
| self.frame_dict[self.name] = func | |||||
| else: | |||||
| setattr(self.frame_dict, self.name, func) | |||||
| class Patcher: | |||||
| patched_fn_ids = set() | |||||
| _builtin_functions = [] | |||||
| _builtin_modules = [ | |||||
| F, | |||||
| F.distributed, | |||||
| F.elemwise, | |||||
| F.inplace, | |||||
| F.loss, | |||||
| F.math, | |||||
| F.metric, | |||||
| F.nn, | |||||
| F.quantized, | |||||
| F.tensor, | |||||
| F.utils, | |||||
| F.vision, | |||||
| ] | |||||
| _builtin_methods = [ | |||||
| Tensor, | |||||
| ArrayMethodMixin, | |||||
| ] | |||||
| def __init__(self, wrap_fn): | |||||
| self.patched_fn = [] | |||||
| self.visited_frames_ids = set() | |||||
| self.wrap_fn = wrap_fn | |||||
| for module in self._builtin_modules: | |||||
| self.patch_module(module) | |||||
| for cls in self._builtin_methods: | |||||
| self.patch_cls(cls) | |||||
| for i, j in self._builtin_functions: | |||||
| if id(i) not in self.visited_frames_ids: | |||||
| self.patch_function(i, j, self.wrap_fn) | |||||
| def patch_function(self, frame_dict, fn, wrap_fn): | |||||
| patched_fn = PatchedFn(frame_dict, fn) | |||||
| self.patched_fn_ids.add(id(patched_fn.origin_fn)) | |||||
| patched_fn.set_func(wrap_fn(patched_fn.origin_fn)) | |||||
| self.patched_fn.append(patched_fn) | |||||
| def patch_method(self, cls, name, wrap_fn): | |||||
| self.patch_function(cls, name, wrap_fn) | |||||
| def patch_cls(self, cls): | |||||
| import inspect | |||||
| if id(cls) not in self.visited_frames_ids: | |||||
| for k, v in cls.__dict__.items(): | |||||
| if inspect.isfunction(v) and not k.startswith("_"): | |||||
| self.patch_function(cls, k, self.wrap_fn) | |||||
| self.visited_frames_ids.add(id(cls)) | |||||
| def patch_module(self, module): | |||||
| import inspect | |||||
| if id(module.__dict__) not in self.visited_frames_ids: | |||||
| for k, v in module.__dict__.items(): | |||||
| if inspect.isfunction(v) and not k.startswith("_"): | |||||
| self.patch_function(module.__dict__, k, self.wrap_fn) | |||||
| self.visited_frames_ids.add(id(module.__dict__)) | |||||
| def auto_patch(self, frame_dict): | |||||
| if id(frame_dict) not in self.visited_frames_ids: | |||||
| for k, v in frame_dict.items(): | |||||
| if id(v) in self.patched_fn_ids: | |||||
| self.patch_function(frame_dict, k, self.wrap_fn) | |||||
| self.visited_frames_ids.add(id(frame_dict)) | |||||
| def __enter__(self): | |||||
| return self | |||||
| def __exit__(self, type, vlaue, trace): | |||||
| while self.patched_fn: | |||||
| pf = self.patched_fn.pop() | |||||
| pf.set_func(pf.origin_fn) | |||||
| self.visited_frames_ids.clear() | |||||
| @@ -34,6 +34,10 @@ class Node: | |||||
| Node.__total_id += 1 | Node.__total_id += 1 | ||||
| self._name = name | self._name = name | ||||
| def __setstate__(self, d): | |||||
| self.__dict__ = d | |||||
| Node.__total_id = max(Node.__total_id, self._id) + 1 | |||||
| def __repr__(self): | def __repr__(self): | ||||
| if self._name is None: | if self._name is None: | ||||
| return "%{}".format(self._id) | return "%{}".format(self._id) | ||||
| @@ -8,14 +8,25 @@ | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| import collections | import collections | ||||
| import copy | import copy | ||||
| import functools | |||||
| from typing import List, Type | from typing import List, Type | ||||
| from ... import module as M | from ... import module as M | ||||
| from ...core._imperative_rt.core2 import set_module_tracing, unset_module_tracing | |||||
| from ...core._imperative_rt.core2 import ( | |||||
| is_tracing_module, | |||||
| set_module_tracing, | |||||
| unset_module_tracing, | |||||
| ) | |||||
| from ...core.tensor.array_method import ArrayMethodMixin | |||||
| from ...module import Module | from ...module import Module | ||||
| from ...tensor import Tensor | from ...tensor import Tensor | ||||
| from .expr import Apply, Call, Constant, Expr, GetAttr, Input | |||||
| from .module_tracer import active_module_tracer, module_tracer, set_active_module_tracer | |||||
| from .expr import Apply, CallFunction, CallMethod, Constant, Expr, GetAttr, Input | |||||
| from .module_tracer import ( | |||||
| Patcher, | |||||
| active_module_tracer, | |||||
| module_tracer, | |||||
| set_active_module_tracer, | |||||
| ) | |||||
| from .node import ModuleNode, Node, NodeMixin, TensorNode | from .node import ModuleNode, Node, NodeMixin, TensorNode | ||||
| @@ -54,7 +65,9 @@ class InternalGraph: | |||||
| for n, v in zip(self._inputs, inputs): | for n, v in zip(self._inputs, inputs): | ||||
| node2value[n] = v | node2value[n] = v | ||||
| for expr in self._exprs: | for expr in self._exprs: | ||||
| values = expr.interpret(*list(node2value[i] for i in expr.inputs)) | |||||
| values = expr.interpret( | |||||
| *list(Expr.get_arg_value(i, node2value) for i in expr.inputs) | |||||
| ) | |||||
| for n, v in zip(expr.outputs, values): | for n, v in zip(expr.outputs, values): | ||||
| node2value[n] = v | node2value[n] = v | ||||
| return list(node2value[i] for i in self._outputs) | return list(node2value[i] for i in self._outputs) | ||||
| @@ -67,6 +80,41 @@ class InternalGraph: | |||||
| ) | ) | ||||
| def _wrapped_function(orig_func): | |||||
| @functools.wraps(orig_func) | |||||
| def wrapped_fn(*inputs, **kwargs): | |||||
| if is_tracing_module(): | |||||
| unset_module_tracing() | |||||
| const_kwargs = {} | |||||
| arg_names = orig_func.__code__.co_varnames | |||||
| if orig_func.__qualname__.split(".").__len__() > 1: | |||||
| # FIXME: a robust way to distinguish method and function. <XP> | |||||
| self = inputs[0] | |||||
| call_node = CallMethod.make(NodeMixin.get(self), orig_func.__name__) | |||||
| else: | |||||
| call_node = CallFunction.make(orig_func) | |||||
| def add_input(inp, varname=None): | |||||
| node = Expr.get_args_node(inp) | |||||
| if node is not None: | |||||
| call_node.add_input(node, varname) | |||||
| else: | |||||
| const_kwargs[varname] = inp | |||||
| for ind, inp in enumerate(inputs): | |||||
| add_input(inp, arg_names[ind]) | |||||
| for k, v in kwargs.items(): | |||||
| add_input(v, k) | |||||
| call_node.kwargs = const_kwargs | |||||
| outputs = orig_func(*inputs, **kwargs) | |||||
| call_node.add_outputs(outputs) | |||||
| set_module_tracing() | |||||
| return outputs | |||||
| return orig_func(*inputs, **kwargs) | |||||
| return wrapped_fn | |||||
| class TracedModuleBuilder(NodeMixin): | class TracedModuleBuilder(NodeMixin): | ||||
| _mod = None # type: Module | _mod = None # type: Module | ||||
| @@ -120,7 +168,7 @@ class TracedModuleBuilder(NodeMixin): | |||||
| mark_constant(i) | mark_constant(i) | ||||
| for k, v in kwargs.items(): | for k, v in kwargs.items(): | ||||
| mark_constant(v) | mark_constant(v) | ||||
| callnode = Call.make(NodeMixin.get(self)) | |||||
| callnode = CallMethod.make(NodeMixin.get(self)) | |||||
| def add_input(x): | def add_input(x): | ||||
| callnode.add_input(NodeMixin.get(x)) | callnode.add_input(NodeMixin.get(x)) | ||||
| @@ -145,7 +193,8 @@ class TracedModuleBuilder(NodeMixin): | |||||
| ) | ) | ||||
| # prepare args and kwargs for inner graph | # prepare args and kwargs for inner graph | ||||
| def wrap(x): | def wrap(x): | ||||
| wrapped = copy.copy(x) # FIXME | |||||
| # wrapped = copy.copy(x) # FIXME | |||||
| wrapped = x # FIXME: <XP> | |||||
| NodeMixin.wrap( | NodeMixin.wrap( | ||||
| wrapped, | wrapped, | ||||
| lambda: Input.make(type=NodeMixin.get_wrapped_type(wrapped)), | lambda: Input.make(type=NodeMixin.get_wrapped_type(wrapped)), | ||||
| @@ -157,7 +206,9 @@ class TracedModuleBuilder(NodeMixin): | |||||
| args.append(wrap(i)) | args.append(wrap(i)) | ||||
| for k, v in kwargs.items(): | for k, v in kwargs.items(): | ||||
| kwargs[k] = wrap(v) | kwargs[k] = wrap(v) | ||||
| active_module_tracer().patcher.auto_patch( | |||||
| getattr(getattr(self._mod, "forward", self._mod), "__globals__", {}) | |||||
| ) | |||||
| outputs = type(self._mod).forward(self, *args, **kwargs) | outputs = type(self._mod).forward(self, *args, **kwargs) | ||||
| for i in ( | for i in ( | ||||
| @@ -171,11 +222,6 @@ class TracedModuleBuilder(NodeMixin): | |||||
| # rebind output to outer graph | # rebind output to outer graph | ||||
| callnode.add_outputs(outputs) | callnode.add_outputs(outputs) | ||||
| for i, node in zip( | |||||
| outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,), | |||||
| callnode.outputs, | |||||
| ): | |||||
| NodeMixin.wrap_safe(i, node) | |||||
| return outputs | return outputs | ||||
| def __getattr__(self, name): | def __getattr__(self, name): | ||||
| @@ -229,6 +275,55 @@ class TracedModule(Module): | |||||
| rst = rst[0] | rst = rst[0] | ||||
| return rst | return rst | ||||
| @property | |||||
| def all_exprs(self): | |||||
| """ | |||||
| Visit all ``Expr``s in the graph recursively. | |||||
| :return: List[Expr] | |||||
| """ | |||||
| in_nodes = [i.expr for i in self.m_node.graph._inputs if not i is self] | |||||
| def _flatten_submodule(module, call=None): | |||||
| if not isinstance(module, TracedModule): | |||||
| call.inputs[0] = module | |||||
| return (call,) | |||||
| exprs = [] | |||||
| graph = module.m_node.graph | |||||
| for expr in graph._exprs: | |||||
| # replace inputs for submodule's expr | |||||
| for idx, inp in enumerate(expr.inputs): | |||||
| if call and inp in graph._inputs: | |||||
| expr.inputs[idx] = call.inputs[idx] | |||||
| # replace outputs for submodule's expr | |||||
| for idx, outp in enumerate(expr.outputs): | |||||
| if call and outp in graph._outputs: | |||||
| expr.outputs[idx] = call.outputs[idx] | |||||
| if isinstance(expr, GetAttr): | |||||
| # replace GetAttr with Constant | |||||
| if isinstance(expr.outputs[0], TensorNode): | |||||
| const = Constant(getattr(module, expr.name)) | |||||
| const.outputs = expr.outputs | |||||
| exprs.append(const) | |||||
| elif isinstance(expr, CallMethod): | |||||
| obj_node = expr.inputs[0] | |||||
| if isinstance(obj_node, ModuleNode): | |||||
| (obj,) = expr.inputs[0].expr.interpret(module) | |||||
| exprs.extend(_flatten_submodule(obj, expr)) | |||||
| else: | |||||
| exprs.append(expr) | |||||
| else: | |||||
| exprs.append(expr) | |||||
| return exprs | |||||
| return in_nodes + _flatten_submodule(self) | |||||
| def __getstate__(self): | def __getstate__(self): | ||||
| d = self.__dict__ | d = self.__dict__ | ||||
| for k in Module.__dict__: | for k in Module.__dict__: | ||||
| @@ -273,23 +368,23 @@ def trace_module(mod: Module, *inputs: Tensor, **kwargs: Tensor) -> TracedModule | |||||
| assert active_module_tracer() is None | assert active_module_tracer() is None | ||||
| try: | try: | ||||
| set_module_tracing() | set_module_tracing() | ||||
| set_active_module_tracer(module_tracer()) | |||||
| global_scope = InternalGraph() | |||||
| active_module_tracer().push_scope(global_scope) | |||||
| set_active_module_tracer(module_tracer(_wrapped_function)) | |||||
| with active_module_tracer().patcher: | |||||
| global_scope = InternalGraph() | |||||
| active_module_tracer().push_scope(global_scope) | |||||
| builder = TracedModuleBuilder(mod) | |||||
| NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode)) | |||||
| builder = TracedModuleBuilder(mod) | |||||
| NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode)) | |||||
| for _, i in enumerate(inputs): | |||||
| NodeMixin.wrap_safe(i, Input.make("arg_{}".format(_))) | |||||
| for k, v in kwargs.items(): | |||||
| NodeMixin.wrap_safe(v, Input.make("kwarg_{}".format(k))) | |||||
| for _, i in enumerate(inputs): | |||||
| NodeMixin.wrap_safe(i, Input.make("arg_{}".format(_))) | |||||
| for k, v in kwargs.items(): | |||||
| NodeMixin.wrap_safe(v, Input.make("kwarg_{}".format(k))) | |||||
| builder(*inputs, **kwargs) | |||||
| active_module_tracer().pop_scope() | |||||
| builder(*inputs, **kwargs) | |||||
| active_module_tracer().pop_scope() | |||||
| return builder.build() | |||||
| return builder.build() | |||||
| finally: | finally: | ||||
| set_active_module_tracer(None) | set_active_module_tracer(None) | ||||
| unset_module_tracing() | unset_module_tracing() | ||||