| @@ -130,3 +130,4 @@ import megengine.optimizer | |||||
| import megengine.quantization | import megengine.quantization | ||||
| import megengine.random | import megengine.random | ||||
| import megengine.utils | import megengine.utils | ||||
| import megengine.experimental | |||||
| @@ -6,4 +6,5 @@ | |||||
| # Unless required by applicable law or agreed to in writing, | # Unless required by applicable law or agreed to in writing, | ||||
| # software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| from . import traced_module | |||||
| from .weight_scaler import get_scaled_model | from .weight_scaler import get_scaled_model | ||||
| @@ -5,3 +5,15 @@ | |||||
| # Unless required by applicable law or agreed to in writing, | # Unless required by applicable law or agreed to in writing, | ||||
| # software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| from ...core._imperative_rt.core2 import set_cpp_apply_module_trace | |||||
| from .traced_module import ( | |||||
| TracedModule, | |||||
| _register_all_builtin_module, | |||||
| cpp_apply_module_trace, | |||||
| register_as_builtin, | |||||
| trace_module, | |||||
| ) | |||||
| _register_all_builtin_module() | |||||
| set_cpp_apply_module_trace(cpp_apply_module_trace) | |||||
| @@ -0,0 +1,215 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import collections | |||||
| from typing import List | |||||
| from ...core._imperative_rt import OpDef | |||||
| from ...core._imperative_rt.core2 import Tensor as RawTensor | |||||
| from ...core._imperative_rt.core2 import apply, set_module_tracing, unset_module_tracing | |||||
| from ...core.ops.special import Const | |||||
| from ...tensor import Tensor | |||||
| from .module_tracer import active_module_tracer | |||||
| from .node import ModuleNode, Node, NodeMixin, TensorNode | |||||
| class Expr: | |||||
| """ | |||||
| ``Expr`` represents the operations(i.e. Call, Apply, GetAttr, Input, Constant) on ``Node``. | |||||
| """ | |||||
| inputs = None # type: List[Node] | |||||
| outputs = None # type: List[Node] | |||||
| # expr: None (i.e. fake expression which is used to mark input) | |||||
| class Input(Expr): | |||||
| name = None | |||||
| def __init__(self, name=None, type=None): | |||||
| self.inputs = [] | |||||
| node_cls = type if type else Node | |||||
| self.outputs = [ | |||||
| node_cls(self, name=name), | |||||
| ] | |||||
| self.name = name | |||||
| @classmethod | |||||
| def make(cls, *args, **kwargs): | |||||
| expr = cls(*args, **kwargs) | |||||
| active_module_tracer().current_scope().add_input(expr.outputs[0]) | |||||
| return expr.outputs[0] | |||||
| def __repr__(self): | |||||
| return "{} = Input({})".format(self.outputs[0], self.name) | |||||
| # expr: outputs = getattr(inputs[0], self.name) | |||||
| class GetAttr(Expr): | |||||
| name = None | |||||
| def __init__(self, module, name, type=None): | |||||
| assert isinstance(module, ModuleNode) | |||||
| self.inputs = [ | |||||
| module, | |||||
| ] | |||||
| self.name = name | |||||
| node_cls = type if type else Node | |||||
| self.outputs = [ | |||||
| node_cls(self), | |||||
| ] | |||||
| @classmethod | |||||
| def make(cls, *args, **kwargs): | |||||
| expr = cls(*args, **kwargs) | |||||
| active_module_tracer().current_scope().insert(expr) | |||||
| expr.outputs[0]._name = expr.name | |||||
| return expr.outputs[0] | |||||
| def interpret(self, *inputs): | |||||
| return (getattr(inputs[0], self.name),) | |||||
| def __repr__(self): | |||||
| return '{} = GetAttr({}, "{}")'.format( | |||||
| self.outputs[0], self.inputs[0], self.name | |||||
| ) | |||||
| # expr: outputs = inputs[0].__call__(*inputs[1:]) | |||||
| class Call(Expr): | |||||
| def __init__(self, module): | |||||
| assert isinstance(module, ModuleNode) | |||||
| self.inputs = [ | |||||
| module, | |||||
| ] | |||||
| def add_input(self, node): | |||||
| self.inputs.append(node) | |||||
| def add_outputs(self, references): | |||||
| self.outputs = [] | |||||
| if not isinstance(references, collections.Sequence): | |||||
| references = (references,) | |||||
| for i in references: | |||||
| self.outputs.append(NodeMixin.get_wrapped_type(i)(self)) | |||||
| @classmethod | |||||
| def make(cls, *args, **kwargs): | |||||
| expr = cls(*args, **kwargs) | |||||
| active_module_tracer().current_scope().insert(expr) | |||||
| return expr | |||||
| def interpret(self, *inputs): | |||||
| mod = inputs[0] | |||||
| args = inputs[1:] | |||||
| outputs = mod(*args) | |||||
| if isinstance(outputs, RawTensor): | |||||
| outputs = (outputs,) | |||||
| return outputs | |||||
| def __repr__(self): | |||||
| return "{} = Call({})({})".format( | |||||
| ", ".join(str(i) for i in self.outputs), | |||||
| self.inputs[0], | |||||
| ", ".join(str(i) for i in self.inputs[1:]), | |||||
| ) | |||||
| # expr: outputs = apply(self.opdef, *inputs) | |||||
| class Apply(Expr): | |||||
| opdef = None | |||||
| def __init__(self, opdef): | |||||
| assert isinstance(opdef, OpDef) | |||||
| self.opdef = opdef | |||||
| self.inputs = [] | |||||
| def add_input(self, node): | |||||
| self.inputs.append(node) | |||||
| def add_outputs(self, references): | |||||
| self.outputs = [] | |||||
| if not isinstance(references, collections.Sequence): | |||||
| references = (references,) | |||||
| for i in references: | |||||
| self.outputs.append(NodeMixin.get_wrapped_type(i)(self)) | |||||
| @classmethod | |||||
| def make(cls, *args, **kwargs): | |||||
| expr = cls(*args, **kwargs) | |||||
| active_module_tracer().current_scope().insert(expr) | |||||
| return expr | |||||
| def interpret(self, *inputs): | |||||
| return apply(self.opdef, *inputs) | |||||
| def __repr__(self): | |||||
| return "{} = {}({})".format( | |||||
| ", ".join(str(i) for i in self.outputs), | |||||
| self.opdef, | |||||
| ", ".join(str(i) for i in self.inputs), | |||||
| ) | |||||
| @classmethod | |||||
| def apply_module_trace_hook(cls, opdef, *inputs): | |||||
| for i in inputs: | |||||
| node = NodeMixin.get(i, None) | |||||
| if node is None: # capture as constant | |||||
| NodeMixin.wrap_safe(i, Constant.make(i)) | |||||
| apply_node = cls.make(opdef) | |||||
| for i in inputs: | |||||
| apply_node.add_input(NodeMixin.get(i)) | |||||
| unset_module_tracing() | |||||
| outputs = apply(opdef, *inputs) | |||||
| set_module_tracing() | |||||
| apply_node.add_outputs(outputs) | |||||
| for n, v in zip(apply_node.outputs, outputs): | |||||
| NodeMixin.wrap_safe(v, n) | |||||
| return list(outputs) | |||||
| # expr outputs = self.value | |||||
| class Constant(Expr): | |||||
| value = None | |||||
| # TODO: constant cache to reduce the size of dumped model | |||||
| _constant_cache = {} | |||||
| def __init__(self, c): | |||||
| # TODO: type check, since not all types should be captured as constant | |||||
| self.value = c | |||||
| self.inputs = [] | |||||
| node_cls = NodeMixin.get_wrapped_type(c) | |||||
| self.outputs = [ | |||||
| node_cls(self), | |||||
| ] | |||||
| @classmethod | |||||
| def make(cls, *args, **kwargs): | |||||
| expr = cls(*args, **kwargs) | |||||
| active_module_tracer().current_scope().insert(expr) | |||||
| return expr.outputs[0] | |||||
| def interpret(self, *inputs): | |||||
| if isinstance(self.value, RawTensor): | |||||
| return Const(self.value.numpy())() | |||||
| return (self.value,) | |||||
| def __repr__(self): | |||||
| return "{} = Constant({})".format(self.outputs[0], self.value) | |||||
| def __getstate__(self): | |||||
| state = self.__dict__.copy() | |||||
| if isinstance(self.value, RawTensor): | |||||
| state["value"] = Tensor(self.value) | |||||
| return state | |||||
| @@ -0,0 +1,52 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| from ...module import Module | |||||
| _active_module_tracer = None | |||||
| def active_module_tracer(): | |||||
| return _active_module_tracer | |||||
| def set_active_module_tracer(tracer): | |||||
| global _active_module_tracer | |||||
| _active_module_tracer = tracer | |||||
| class module_tracer: | |||||
| _opaque_types = set() | |||||
| _active_scopes = None | |||||
| def __init__(self): | |||||
| self._active_scopes = [] | |||||
| @classmethod | |||||
| def register_as_builtin(cls, mod): | |||||
| assert issubclass(mod, Module) | |||||
| cls._opaque_types.add(mod) | |||||
| return mod | |||||
| @classmethod | |||||
| def is_builtin(cls, mod): | |||||
| return type(mod) in cls._opaque_types | |||||
| def push_scope(self, scope): | |||||
| self._active_scopes.append(scope) | |||||
| def pop_scope(self): | |||||
| self._active_scopes.pop() | |||||
| def current_scope(self): | |||||
| if self._active_scopes: | |||||
| return self._active_scopes[-1] | |||||
| return None | |||||
| @@ -0,0 +1,123 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| from typing import Any, Dict, Tuple, Type | |||||
| import numpy | |||||
| from ...core._imperative_rt.core2 import Tensor as RawTensor | |||||
| from ...module import Module | |||||
| from ...tensor import Tensor | |||||
| class Node: | |||||
| """ | |||||
| ``Node`` represents the variables (Tensor/Module/other python object) used in Module's forward method. They are inputs/outputs of Expr(the operations on variables). | |||||
| param expr: the Expr which produces the node | |||||
| param name: the name of the node | |||||
| """ | |||||
| expr = None | |||||
| __total_id = 0 | |||||
| _id = None | |||||
| _name = None | |||||
| def __init__(self, expr: "Expr", name: str = None): | |||||
| self.expr = expr | |||||
| self._id = Node.__total_id | |||||
| Node.__total_id += 1 | |||||
| self._name = name | |||||
| def __repr__(self): | |||||
| if self._name is None: | |||||
| return "%{}".format(self._id) | |||||
| else: | |||||
| return "%{}".format(self._name) | |||||
| class ModuleNode(Node): | |||||
| """ | |||||
| ``ModuleNode`` represents the Module objects. | |||||
| Attributes: | |||||
| module_type: type of the Module correspending to the ModuleNode | |||||
| graph: the InternalGraph which will be interpreted when call Module's forward method | |||||
| attr_type_map: record the type of Module's attributes | |||||
| """ | |||||
| module_type = Module # type: Type[Module] | |||||
| graph = None | |||||
| attr_type_map = None # type: Dict[str, Type[Any]] | |||||
| def __repr__(self): | |||||
| if self._name is None: | |||||
| return "%{}({})".format(self._id, self.module_type.__name__) | |||||
| else: | |||||
| return "%{}({})".format(self._name, self.module_type.__name__) | |||||
| class TensorNode(Node): | |||||
| """ | |||||
| ``TensorNode`` represents the Tensor objects. | |||||
| """ | |||||
| shape = None # type: Tuple[int] | |||||
| dtype = None # type: numpy.dtype | |||||
| def __repr__(self): | |||||
| if self._name is None: | |||||
| return "%{}(Tensor)".format(self._id) | |||||
| else: | |||||
| return "%{}(Tensor)".format(self._name) | |||||
| class NodeMixin: | |||||
| __node = None | |||||
| @classmethod | |||||
| def wrap(cls, value, node): | |||||
| if isinstance(value, (NodeMixin, RawTensor)): | |||||
| if isinstance(node, Node): | |||||
| if isinstance(value, RawTensor): | |||||
| node.dtype = value.dtype | |||||
| node.shape = ( | |||||
| value._tuple_shape if isinstance(value, Tensor) else value.shape | |||||
| ) | |||||
| setattr(value, "_NodeMixin__node", node) | |||||
| else: | |||||
| assert callable(node) | |||||
| n = node() | |||||
| if isinstance(value, RawTensor): | |||||
| n.dtype = value.dtype | |||||
| n.shape = ( | |||||
| value._tuple_shape if isinstance(value, Tensor) else value.shape | |||||
| ) | |||||
| setattr(value, "_NodeMixin__node", n) | |||||
| @classmethod | |||||
| def wrap_safe(cls, value, node): | |||||
| assert isinstance(value, (NodeMixin, RawTensor)) | |||||
| if isinstance(value, RawTensor): | |||||
| node.dtype = value.dtype | |||||
| node.shape = ( | |||||
| value._tuple_shape if isinstance(value, Tensor) else value.shape | |||||
| ) | |||||
| setattr(value, "_NodeMixin__node", node) | |||||
| @classmethod | |||||
| def get(cls, value, *default): | |||||
| return getattr(value, "_NodeMixin__node", *default) | |||||
| @classmethod | |||||
| def get_wrapped_type(cls, value): | |||||
| if isinstance(value, RawTensor): | |||||
| return TensorNode | |||||
| if isinstance(value, (Module, NodeMixin)): | |||||
| return ModuleNode | |||||
| return Node | |||||
| @@ -0,0 +1,295 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import collections | |||||
| import copy | |||||
| from typing import List, Type | |||||
| from ... import module as M | |||||
| from ...core._imperative_rt.core2 import set_module_tracing, unset_module_tracing | |||||
| from ...module import Module | |||||
| from ...tensor import Tensor | |||||
| from .expr import Apply, Call, Constant, Expr, GetAttr, Input | |||||
| from .module_tracer import active_module_tracer, module_tracer, set_active_module_tracer | |||||
| from .node import ModuleNode, Node, NodeMixin, TensorNode | |||||
| class InternalGraph: | |||||
| """ | |||||
| ``InternalGraph`` is a graph consist of ``Node`` and ``Expr``, it is used to represent the execution procedure of Module's forward method. | |||||
| Attributes: | |||||
| _exprs: List of Exprs in order of execution | |||||
| _inputs: Input Nodes of InternalGraph | |||||
| _outputs: Output Nodes of InternalGraph | |||||
| """ | |||||
| _exprs = None # type: List[Expr] | |||||
| _inputs = None # type: List[Node] | |||||
| _outputs = None # type: List[Node] | |||||
| def __init__(self): | |||||
| self._exprs = [] | |||||
| self._inputs = [] | |||||
| self._outputs = [] | |||||
| def insert(self, expr): | |||||
| self._exprs.append(expr) | |||||
| def add_input(self, i): | |||||
| self._inputs.append(i) | |||||
| def add_output(self, o): | |||||
| self._outputs.append(o) | |||||
| def interpret(self, *inputs): | |||||
| # TODO: support kwargs ? | |||||
| # TODO: skip expressions which are independent and have no side effect | |||||
| node2value = {} | |||||
| for n, v in zip(self._inputs, inputs): | |||||
| node2value[n] = v | |||||
| for expr in self._exprs: | |||||
| values = expr.interpret(*list(node2value[i] for i in expr.inputs)) | |||||
| for n, v in zip(expr.outputs, values): | |||||
| node2value[n] = v | |||||
| return list(node2value[i] for i in self._outputs) | |||||
| def __repr__(self): | |||||
| return "InternalGraph ({}) {{\n\t{}\n\treturn {}\n}}".format( | |||||
| ", ".join(str(i) for i in self._inputs), | |||||
| "\n\t".join(str(i) for i in self._exprs), | |||||
| ", ".join(str(i) for i in self._outputs), | |||||
| ) | |||||
| class TracedModuleBuilder(NodeMixin): | |||||
| _mod = None # type: Module | |||||
| _body = None # type: InternalGraph | |||||
| _is_builtin = None # type: bool | |||||
| __builder_attributes__ = [ | |||||
| "_mod", | |||||
| "_body", | |||||
| "_NodeMixin__node", | |||||
| "_is_builtin", | |||||
| "_is_traced", | |||||
| "build", | |||||
| ] | |||||
| def __init__(self, mod): | |||||
| super(TracedModuleBuilder, self).__init__() | |||||
| self._mod = mod | |||||
| self._body = InternalGraph() | |||||
| self._is_traced = False | |||||
| self._is_builtin = module_tracer.is_builtin(mod) | |||||
| def build(self): | |||||
| if self._is_builtin: | |||||
| node = NodeMixin.get(self) | |||||
| node.module_type = type(self._mod) | |||||
| return self._mod | |||||
| else: | |||||
| node = NodeMixin.get(self) | |||||
| node.graph = self._body | |||||
| node.attr_type_map = {} | |||||
| traced_module = TracedModule(node) | |||||
| for k, v in self.__dict__.items(): | |||||
| if k not in TracedModuleBuilder.__builder_attributes__: | |||||
| if isinstance(v, TracedModuleBuilder): | |||||
| v = v.build() | |||||
| setattr(traced_module, k, v) | |||||
| traced_module.m_node.attr_type_map[k] = type(v) | |||||
| return traced_module | |||||
| def __call__(self, *inputs, **kwargs): | |||||
| assert isinstance(self._mod, Module) | |||||
| # prepare args and kwargs for inner graph | |||||
| def mark_constant(x): | |||||
| node = NodeMixin.get(x, None) | |||||
| if node is None: # capture as constant | |||||
| NodeMixin.wrap(x, lambda: Constant.make(x)) | |||||
| for i in inputs: | |||||
| mark_constant(i) | |||||
| for k, v in kwargs.items(): | |||||
| mark_constant(v) | |||||
| callnode = Call.make(NodeMixin.get(self)) | |||||
| def add_input(x): | |||||
| callnode.add_input(NodeMixin.get(x)) | |||||
| for i in inputs: | |||||
| add_input(i) | |||||
| for k, v in kwargs.items(): | |||||
| add_input(v) | |||||
| if self._is_builtin or self._is_traced: | |||||
| unset_module_tracing() | |||||
| outputs = self._mod(*inputs, **kwargs) | |||||
| set_module_tracing() | |||||
| if self._is_builtin: | |||||
| self._body = None | |||||
| else: | |||||
| active_module_tracer().push_scope(self._body) | |||||
| # rebind self to new input node | |||||
| orig_self = NodeMixin.get(self) | |||||
| NodeMixin.wrap_safe( | |||||
| self, Input.make("self", NodeMixin.get_wrapped_type(self)) | |||||
| ) | |||||
| # prepare args and kwargs for inner graph | |||||
| def wrap(x): | |||||
| wrapped = copy.copy(x) # FIXME | |||||
| NodeMixin.wrap( | |||||
| wrapped, | |||||
| lambda: Input.make(type=NodeMixin.get_wrapped_type(wrapped)), | |||||
| ) | |||||
| return wrapped | |||||
| args = [] | |||||
| for i in inputs: | |||||
| args.append(wrap(i)) | |||||
| for k, v in kwargs.items(): | |||||
| kwargs[k] = wrap(v) | |||||
| outputs = type(self._mod).forward(self, *args, **kwargs) | |||||
| for i in ( | |||||
| outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,) | |||||
| ): | |||||
| active_module_tracer().current_scope().add_output(NodeMixin.get(i)) | |||||
| NodeMixin.wrap_safe(self, orig_self) | |||||
| self._is_traced = True | |||||
| active_module_tracer().pop_scope() | |||||
| # rebind output to outer graph | |||||
| callnode.add_outputs(outputs) | |||||
| for i, node in zip( | |||||
| outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,), | |||||
| callnode.outputs, | |||||
| ): | |||||
| NodeMixin.wrap_safe(i, node) | |||||
| return outputs | |||||
| def __getattr__(self, name): | |||||
| if name not in self._mod.__dict__: | |||||
| attr = getattr(type(self._mod), name).__get__(self, type(self)) | |||||
| else: | |||||
| attr = getattr(self._mod, name) | |||||
| if isinstance(attr, Module): | |||||
| attr = TracedModuleBuilder(attr) | |||||
| setattr(self, name, attr) | |||||
| NodeMixin.wrap( | |||||
| attr, | |||||
| lambda: GetAttr.make( | |||||
| NodeMixin.get(self), name, type=NodeMixin.get_wrapped_type(attr) | |||||
| ), | |||||
| ) | |||||
| return attr | |||||
| def __getattribute__(self, name): | |||||
| if name in TracedModuleBuilder.__builder_attributes__: | |||||
| return super().__getattribute__(name) | |||||
| else: | |||||
| wrapped = super().__getattribute__(name) | |||||
| if name in self._mod.__dict__ and not NodeMixin.get(wrapped, None): | |||||
| assert not self._is_builtin | |||||
| NodeMixin.wrap( | |||||
| wrapped, | |||||
| lambda: GetAttr.make( | |||||
| NodeMixin.get(self), | |||||
| name, | |||||
| type=NodeMixin.get_wrapped_type(wrapped), | |||||
| ), | |||||
| ) | |||||
| return wrapped | |||||
| class TracedModule(Module): | |||||
| """ | |||||
| `TracedModule` is the Module created by tracing normal module. It owns a ModuleNode(m_node), and will interpret the m_node.graph when it is called. | |||||
| """ | |||||
| m_node = None # type: ModuleNode | |||||
| def __init__(self, node): | |||||
| super(TracedModule, self).__init__() | |||||
| self.m_node = node | |||||
| def forward(self, *inputs): | |||||
| rst = self.m_node.graph.interpret(self, *inputs) | |||||
| if len(rst) == 1: | |||||
| rst = rst[0] | |||||
| return rst | |||||
| def __getstate__(self): | |||||
| d = self.__dict__ | |||||
| for k in Module.__dict__: | |||||
| d.pop(k, None) | |||||
| return d | |||||
| def cpp_apply_module_trace(opdef, *args): | |||||
| return Apply.apply_module_trace_hook(opdef, *args) | |||||
| def register_as_builtin(mod_cls: Type[Module]) -> None: | |||||
| """ | |||||
| Registers class ``mod_cls`` (subclass of megengine.module.Module) as builtin module. | |||||
| param mod_cls: the Module class which will be threated as builtin module in tracing | |||||
| """ | |||||
| module_tracer.register_as_builtin(mod_cls) | |||||
| def _register_all_builtin_module(): | |||||
| from inspect import getmembers, isclass | |||||
| for sub_mod in [M, M.qat, M.quantized]: | |||||
| for m in getmembers(sub_mod): | |||||
| if ( | |||||
| isclass(m[1]) | |||||
| and issubclass(m[1], M.Module) | |||||
| and m[1] is not M.Sequential | |||||
| ): | |||||
| module_tracer.register_as_builtin(m[1]) | |||||
| def trace_module(mod: Module, *inputs: Tensor, **kwargs: Tensor) -> TracedModule: | |||||
| """ | |||||
| Traces module ``mod`` and returns corresponding TracedModule. | |||||
| param mod: the module will be converted to TracedModule | |||||
| param input: the positional arguments passed to forward method of ``mod`` | |||||
| param kwargs: the keyword arguments passed to forward method of ``mod`` | |||||
| """ | |||||
| assert active_module_tracer() is None | |||||
| try: | |||||
| set_module_tracing() | |||||
| set_active_module_tracer(module_tracer()) | |||||
| global_scope = InternalGraph() | |||||
| active_module_tracer().push_scope(global_scope) | |||||
| builder = TracedModuleBuilder(mod) | |||||
| NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode)) | |||||
| for _, i in enumerate(inputs): | |||||
| NodeMixin.wrap_safe(i, Input.make("arg_{}".format(_))) | |||||
| for k, v in kwargs.items(): | |||||
| NodeMixin.wrap_safe(v, Input.make("kwarg_{}".format(k))) | |||||
| builder(*inputs, **kwargs) | |||||
| active_module_tracer().pop_scope() | |||||
| return builder.build() | |||||
| finally: | |||||
| set_active_module_tracer(None) | |||||
| unset_module_tracing() | |||||