GitOrigin-RevId: a43ad1273c
tags/v1.7.0
| @@ -11,7 +11,7 @@ import collections | |||||
| import copy | import copy | ||||
| import inspect | import inspect | ||||
| import re | import re | ||||
| from typing import Callable, Dict, List | |||||
| from typing import Callable, Dict, List, Optional, Union | |||||
| from ..core._imperative_rt import OpDef | from ..core._imperative_rt import OpDef | ||||
| from ..core._imperative_rt.core2 import Tensor as RawTensor | from ..core._imperative_rt.core2 import Tensor as RawTensor | ||||
| @@ -32,6 +32,43 @@ def rstrip(s: str, __chars: str): | |||||
| return s | return s | ||||
def get_suffix_name(prefix: str, name: str):
    r"""Return the part of ``name`` that follows ``prefix`` and a dot.

    Args:
        prefix: the dotted prefix to strip from ``name``.
        name: the dotted name to inspect.

    Returns:
        ``""`` when ``name`` equals ``prefix``; the remainder after
        ``prefix + "."`` when ``name`` starts with it; otherwise ``None``.
    """
    if prefix == name:
        return ""
    # Escape the prefix so that regex metacharacters in it (most notably the
    # "." separator of dotted names) are matched literally rather than as
    # wildcards.
    matched = re.match(r"^%s\.(.*)" % re.escape(prefix), name)
    if matched is None:
        return None
    return matched.group(1)
def is_call_module(expr):
    r"""Tell whether ``expr`` is a ``CallMethod`` whose first input is a
    ``ModuleNode`` and whose method is ``__call__``."""
    if not isinstance(expr, CallMethod):
        return False
    if not isinstance(expr.inputs[0], ModuleNode):
        return False
    return expr.method == "__call__"
def is_call_tensor_method(expr):
    r"""Tell whether ``expr`` is a ``CallMethod`` that is not a module call."""
    if not isinstance(expr, CallMethod):
        return False
    return not is_call_module(expr)
def is_call_function(expr):
    r"""Tell whether ``expr`` is an instance of ``CallFunction``."""
    matches = isinstance(expr, CallFunction)
    return matches
def is_constant(expr):
    r"""Tell whether ``expr`` is an instance of ``Constant``."""
    matches = isinstance(expr, Constant)
    return matches
def is_getattr(expr):
    r"""Tell whether ``expr`` is an instance of ``GetAttr``."""
    matches = isinstance(expr, GetAttr)
    return matches
def is_apply_def(expr):
    r"""Tell whether ``expr`` is an instance of ``Apply``."""
    matches = isinstance(expr, Apply)
    return matches
| class Expr: | class Expr: | ||||
| r"""``Expr`` represents the operations (i.e. ``CallMethod``, ``CallFunction``, ``Apply``, | r"""``Expr`` represents the operations (i.e. ``CallMethod``, ``CallFunction``, ``Apply``, | ||||
| ``GetAttr``, ``Input``, ``Constant``) on ``Node``. | ``GetAttr``, ``Input``, ``Constant``) on ``Node``. | ||||
| @@ -76,50 +113,19 @@ class Expr: | |||||
| self.const_val.append((idx, val)) | self.const_val.append((idx, val)) | ||||
def add_outputs(self, outputs):
    r"""Wrap the traced ``outputs`` in new nodes and record them on this Expr.

    Args:
        outputs: ``None``, a single tensor, or a sequence of tensors produced
            by this Expr. Every element must be a ``RawTensor``.

    ``self.outputs`` is always reset to an empty list first. When ``outputs``
    is ``None`` nothing else happens. Otherwise one node is created per
    tensor with empty ``name``/``qualname``, and the namespace of the current
    scope is asked to name the outputs of this Expr.
    """
    # ``collections.Sequence`` was removed in Python 3.10; the ABC has to be
    # taken from ``collections.abc``.
    from collections.abc import Sequence

    assert active_module_tracer() is not None
    self.outputs = []
    if outputs is None:
        return
    current_graph = active_module_tracer().current_scope()
    if not isinstance(outputs, Sequence):
        outputs = (outputs,)
    for i in outputs:
        assert isinstance(i, RawTensor), "The output must be a Tensor"
        node = NodeMixin.get_wrapped_type(i)(expr=self, name="", qualname="",)
        NodeMixin.wrap_safe(i, node)
        self.outputs.append(node)
    # Names are assigned in one pass after all output nodes exist.
    current_graph._namespace.auto_naming_for_outputs(self)
| def unflatten_args(self, inputs): | def unflatten_args(self, inputs): | ||||
| if self.arg_def is not None: | if self.arg_def is not None: | ||||
| @@ -152,9 +158,7 @@ class Expr: | |||||
| ), "({}) must be generated before ({})".format(repl_node, self) | ), "({}) must be generated before ({})".format(repl_node, self) | ||||
| idx = self.inputs.index(node) | idx = self.inputs.index(node) | ||||
| self.inputs[idx] = repl_node | self.inputs[idx] = repl_node | ||||
| user_idx = node.users.index(self) | |||||
| assert user_idx >= 0 | |||||
| node.users.pop(user_idx) | |||||
| node.users.remove(self) | |||||
| repl_node.users.append(self) | repl_node.users.append(self) | ||||
| @property | @property | ||||
| @@ -197,26 +201,23 @@ class Input(Expr): | |||||
| r"""A fake Expr which is used to mark the input of graph.""" | r"""A fake Expr which is used to mark the input of graph.""" | ||||
| name = None | name = None | ||||
def __init__(self, type: List[Node], name: str = "args", qualname: str = ""):
    r"""Create the fake input Expr together with its single output node.

    Args:
        type: node class of the output; asserted to be ``ModuleNode`` or
            ``TensorNode``.
        name: name given to the output node and stored on this Expr.
        qualname: qualname given to the output node.

    Both ``name`` and ``qualname`` are asserted to be non-empty.
    """
    super().__init__()
    assert type in [ModuleNode, TensorNode]
    assert name and qualname
    self.inputs = []
    node_cls = type if type else Node
    out_node = node_cls(self, name=name, qualname=qualname)
    self.outputs = [out_node]
    self.name = name
@classmethod
def make(cls, *args, **kwargs):
    r"""Build an ``Input`` Expr, register its output node as an input of the
    current scope, and return that node."""
    assert active_module_tracer() is not None
    input_expr = cls(*args, **kwargs)
    scope = active_module_tracer().current_scope()
    scope._add_input(input_expr.outputs[0])
    return input_expr.outputs[0]
| def __repr__(self): | def __repr__(self): | ||||
| @@ -230,34 +231,41 @@ class GetAttr(Expr): | |||||
| name = None | name = None | ||||
| r"""name: the qualified name of the attribute to be retrieved.""" | r"""name: the qualified name of the attribute to be retrieved.""" | ||||
def __init__(
    self, module: ModuleNode, type: Union[Node], attr_name: str, name: str = "",
):
    r"""Create a ``GetAttr`` Expr reading ``attr_name`` from ``module``.

    Args:
        module: the ``ModuleNode`` the attribute is read from; this Expr is
            appended to its ``users``.
        type: node class of the output; asserted to be ``TensorNode`` or
            ``ModuleNode``.
        attr_name: attribute name, stored as ``self.name``.
        name: name given to the output node.
    """
    super().__init__()
    assert isinstance(module, ModuleNode)
    assert type in [TensorNode, ModuleNode]
    self.inputs = [module]
    module.users.append(self)
    self.name = attr_name
    # The output qualname is the owner's qualname extended by the attribute.
    out_qualname = "{}.{}".format(module.qualname, attr_name)
    self.outputs = [type(self, name=name, qualname=out_qualname)]
@classmethod
def make(cls, *args, **kwargs):
    r"""Build a ``GetAttr`` Expr, let the current scope's namespace name its
    outputs, insert it into the current scope, and return its output node."""
    assert active_module_tracer() is not None
    scope = active_module_tracer().current_scope()
    getattr_expr = cls(*args, **kwargs)
    scope._namespace.auto_naming_for_outputs(getattr_expr)
    scope._insert(getattr_expr)
    return getattr_expr.outputs[0]
def interpret(self, *inputs):
    r"""Resolve the dotted ``self.name`` starting from ``inputs[0]``.

    Every component before the last dot must resolve to a ``Module``,
    otherwise ``AttributeError`` is raised. Returns a 1-tuple holding the
    attribute named by the last component.
    """
    target = inputs[0]
    module_path, _, attr = self.name.rpartition(".")
    # An empty path means the attribute lives directly on the first input.
    for part in (module_path.split(".") if module_path != "" else ()):
        target = getattr(target, part)
        if not isinstance(target, Module):
            raise AttributeError("`{}` is not an Module".format(part))
    return (getattr(target, attr),)
| def __repr__(self): | def __repr__(self): | ||||
| out_type = "Tensor" | out_type = "Tensor" | ||||
| @@ -297,6 +305,7 @@ class CallMethod(Expr): | |||||
@classmethod
def make(cls, *args, **kwargs):
    r"""Build the Expr, insert it into the current scope, and return it.

    An active module tracer is required (asserted).
    """
    assert active_module_tracer() is not None
    new_expr = cls(*args, **kwargs)
    scope = active_module_tracer().current_scope()
    scope._insert(new_expr)
    return new_expr
| @@ -362,6 +371,7 @@ class Apply(Expr): | |||||
@classmethod
def make(cls, *args, **kwargs):
    r"""Build the Expr, insert it into the current scope, and return it.

    An active module tracer is required (asserted).
    """
    assert active_module_tracer() is not None
    new_expr = cls(*args, **kwargs)
    scope = active_module_tracer().current_scope()
    scope._insert(new_expr)
    return new_expr
| @@ -435,6 +445,7 @@ class CallFunction(Expr): | |||||
@classmethod
def make(cls, *args, **kwargs):
    r"""Build the Expr, insert it into the current scope, and return it.

    An active module tracer is required (asserted).
    """
    assert active_module_tracer() is not None
    new_expr = cls(*args, **kwargs)
    scope = active_module_tracer().current_scope()
    scope._insert(new_expr)
    return new_expr
| @@ -474,7 +485,7 @@ class Constant(Expr): | |||||
| # TODO: constant cache to reduce the size of dumped model | # TODO: constant cache to reduce the size of dumped model | ||||
| _constant_cache = {} | _constant_cache = {} | ||||
| def __init__(self, c, name=None): | |||||
| def __init__(self, c, name: str = "", qualname: str = ""): | |||||
| super().__init__() | super().__init__() | ||||
| assert isinstance(c, (RawTensor, Module)) | assert isinstance(c, (RawTensor, Module)) | ||||
| if isinstance(c, Module): | if isinstance(c, Module): | ||||
| @@ -484,31 +495,16 @@ class Constant(Expr): | |||||
| self.inputs = [] | self.inputs = [] | ||||
| node_cls = NodeMixin.get_wrapped_type(c) | node_cls = NodeMixin.get_wrapped_type(c) | ||||
| self.outputs = [ | self.outputs = [ | ||||
| node_cls(self, name=name, orig_name=name), | |||||
| node_cls(self, name=name, qualname=qualname), | |||||
| ] | ] | ||||
| self.outputs[0]._name = name if name else "const_" + str(self._id) | |||||
@classmethod
def make(cls, *args, **kwargs):
    r"""Build a ``Constant`` Expr, let the current scope's namespace name its
    outputs, insert it into the current scope, and return its output node."""
    assert active_module_tracer() is not None
    const_expr = cls(*args, **kwargs)
    scope = active_module_tracer().current_scope()
    scope._namespace.auto_naming_for_outputs(const_expr)
    scope._insert(const_expr)
    return const_expr.outputs[0]
| def interpret(self, *inputs): | def interpret(self, *inputs): | ||||
| @@ -128,10 +128,9 @@ class module_tracer: | |||||
| _active_scopes = None | _active_scopes = None | ||||
def __init__(self, wrap_fn):
    r"""Start with an empty scope stack and a ``Patcher`` built from
    ``wrap_fn``."""
    self._active_scopes = list()
    self.patcher = Patcher(wrap_fn)
| @classmethod | @classmethod | ||||
| def register_as_builtin(cls, mod): | def register_as_builtin(cls, mod): | ||||
| @@ -29,17 +29,15 @@ class Node: | |||||
| __total_id = 0 # type: int | __total_id = 0 # type: int | ||||
| _id = None # type: int | _id = None # type: int | ||||
| _top_graph = None # type: weakref.ReferenceType | _top_graph = None # type: weakref.ReferenceType | ||||
| _name = None # type: str | |||||
| _orig_name = None # type: str | |||||
| _format_spec = "" # type: str | _format_spec = "" # type: str | ||||
def __init__(self, expr, name: str, qualname: str):
    r"""Initialise the node and give it the next class-wide id.

    Args:
        expr: the Expr stored as ``self.expr``.
        name: stored as ``self._name``.
        qualname: stored as ``self._qualname``.
    """
    # Take the current counter value as this node's id, then advance it.
    self._id = Node.__total_id
    Node.__total_id += 1
    self.expr = expr
    self._name = name
    self._qualname = qualname
    self.users = []  # List[Expr]
    self.actual_node = []  # type: List[Node]
| def __repr__(self): | def __repr__(self): | ||||
| @@ -54,21 +52,10 @@ class Node: | |||||
| name = "" | name = "" | ||||
| if format_spec in ["i", "p", "ip", "pi"]: | if format_spec in ["i", "p", "ip", "pi"]: | ||||
| if "p" in format_spec: | if "p" in format_spec: | ||||
| graph = self.top_graph | |||||
| prefix_name = "" | |||||
| if graph is not None: | |||||
| prefix_name = graph._name | |||||
| if graph._prefix_name: | |||||
| prefix_name = "{}_{}".format( | |||||
| graph._prefix_name, prefix_name.lstrip("_") | |||||
| ) | |||||
| if name: | |||||
| name = "_" + name.lstrip("_") | |||||
| name = "{}{}".format(prefix_name, name) | |||||
| prefix_name = self.top_graph._name | |||||
| name = "{}_{}".format(prefix_name, name) | |||||
| if "i" in format_spec: | if "i" in format_spec: | ||||
| if name: | |||||
| name = "_" + name.lstrip("_") | |||||
| name = "%{}{}".format(self._id, name) | |||||
| name = "%{}_{}".format(self._id, name) | |||||
| return name | return name | ||||
| else: | else: | ||||
| return name if name else ("%d" % self._id) | return name if name else ("%d" % self._id) | ||||
| @@ -80,15 +67,62 @@ class Node: | |||||
@name.setter
def name(self, new_name: str):
    r"""Rename this Node.

    The node must belong to a graph and ``new_name`` must not already be in
    the graph namespace's ``used_names``; both conditions are asserted.
    """
    owner_graph = self.top_graph
    assert owner_graph is not None, "The parent graph of this Node cannot be None."
    namespace = owner_graph._namespace
    assert new_name not in namespace.used_names, (
        "The name(%s) is already in use. Please try a different one again."
        % (new_name)
    )
    self._name = namespace.create_unique_name(new_name)
@property
def qualname(self):
    r"""The ``qualname`` stored on this Node.

    The ``qualname`` can be used to look up the corresponding submodule in
    either the original Module or the traced Module.

    Example:

        .. code-block::

            import megengine.module as M
            import megengine.functional as F
            import megengine.traced_module as tm
            import megengine as mge

            class block(M.Module):
                def __init__(self):
                    super().__init__()
                    self.param = mge.Tensor([1.])
                    self.relu = M.ReLU()

                def forward(self, x):
                    x = x + self.param
                    return self.relu(F.relu(x))

            class module(M.Module):
                def __init__(self):
                    super().__init__()
                    self.block = block()

                def forward(self, x):
                    x = self.block(x)
                    return x

            net = module()
            traced_net = tm.trace_module(net, mge.Tensor([0.]))
            traced_net = traced_net.flatten()
            out_node = traced_net.graph.outputs[0]

            # qualname : "module.block.relu.[out]"
            qualname = out_node.qualname
            # qualname : "block.relu"
            qualname = qualname.split(".", 1)[-1].rsplit(".", 1)[0]

            assert qualname in list(map(lambda x: x[0], net.named_modules()))
            assert qualname in list(map(lambda x: x[0], traced_net.named_modules()))
    """
    stored_qualname = self._qualname
    return stored_qualname
| @property | @property | ||||
| def top_graph(self): | def top_graph(self): | ||||
| @@ -120,8 +154,8 @@ class ModuleNode(Node): | |||||
| r"""The type of the Module corresponding to the ModuleNode.""" | r"""The type of the Module corresponding to the ModuleNode.""" | ||||
| _owner = None # type: weakref.ReferenceType | _owner = None # type: weakref.ReferenceType | ||||
def __init__(self, expr, name: str = None, qualname: str = None):
    r"""Forward ``expr``, ``name`` and ``qualname`` to the base ``Node``
    initialiser."""
    super().__init__(expr, name=name, qualname=qualname)
| def __getstate__(self): | def __getstate__(self): | ||||
| return { | return { | ||||
| @@ -129,10 +163,15 @@ class ModuleNode(Node): | |||||
| "users": self.users, | "users": self.users, | ||||
| "_id": self._id, | "_id": self._id, | ||||
| "_name": self._name, | "_name": self._name, | ||||
| "_orig_name": self._orig_name, | |||||
| "_qualname": self._qualname, | |||||
| "module_type": self.module_type, | "module_type": self.module_type, | ||||
| } | } | ||||
def __setstate__(self, state):
    r"""Restore the instance dict from ``state``.

    A legacy ``"_orig_name"`` entry, if present, is moved to ``"_qualname"``
    before the state is applied.
    """
    try:
        legacy_name = state.pop("_orig_name")
    except KeyError:
        pass
    else:
        state["_qualname"] = legacy_name
    self.__dict__.update(state)
| @property | @property | ||||
| def owner(self): | def owner(self): | ||||
| r"""Get the ``Module`` corresponding to this ``ModuleNode``. | r"""Get the ``Module`` corresponding to this ``ModuleNode``. | ||||
| @@ -161,9 +200,21 @@ class TensorNode(Node): | |||||
| "_dtype": self._dtype, | "_dtype": self._dtype, | ||||
| "_device": self._device, | "_device": self._device, | ||||
| "_name": self._name, | "_name": self._name, | ||||
| "_orig_name": self._orig_name, | |||||
| "_qualname": self._qualname, | |||||
| } | } | ||||
def __setstate__(self, state):
    r"""Restore the instance dict from ``state``.

    A legacy ``"_orig_name"`` entry, if present, is converted to
    ``"_qualname"``: its last dotted component is wrapped in ``[...]``
    unless the class name of ``state["expr"]`` is ``GetAttr``, and any
    module path in front of it is kept.
    """
    if "_orig_name" in state:
        module_path, dot, leaf = state.pop("_orig_name").rpartition(".")
        if state["expr"].__class__.__name__ != "GetAttr":
            leaf = "[{}]".format(leaf)
        # ``dot`` is empty when the legacy name had no module path at all.
        state["_qualname"] = "{}.{}".format(module_path, leaf) if dot else leaf
    self.__dict__.update(state)
| @property | @property | ||||
| def shape(self): | def shape(self): | ||||
| r"""Get the shape of this Node.""" | r"""Get the shape of this Node.""" | ||||
| @@ -6,6 +6,7 @@ | |||||
| # software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| import pickle | import pickle | ||||
| from itertools import chain | |||||
| import numpy as np | import numpy as np | ||||
| @@ -13,8 +14,8 @@ import megengine.functional as F | |||||
| import megengine.module as M | import megengine.module as M | ||||
| from megengine.module.identity import Identity | from megengine.module.identity import Identity | ||||
| from megengine.traced_module import trace_module | from megengine.traced_module import trace_module | ||||
| from megengine.traced_module.expr import CallFunction, Expr, GetAttr | |||||
| from megengine.traced_module.node import Node | |||||
| from megengine.traced_module.expr import CallFunction, CallMethod, Expr, GetAttr, Input | |||||
| from megengine.traced_module.node import ModuleNode, Node | |||||
| class IdentityMod(M.Module): | class IdentityMod(M.Module): | ||||
| @@ -85,6 +86,34 @@ def test_search(): | |||||
| relu_expr = graph.get_function_by_type(F.relu).as_unique() | relu_expr = graph.get_function_by_type(F.relu).as_unique() | ||||
| assert isinstance(relu_expr, CallFunction) and relu_expr.func == F.relu | assert isinstance(relu_expr, CallFunction) and relu_expr.func == F.relu | ||||
| conv_node = graph.get_module_by_type(M.Conv2d).as_unique() | |||||
| assert isinstance(conv_node, ModuleNode) and conv_node.module_type == M.Conv2d | |||||
| add_expr = graph.get_method_by_type("__add__").as_unique() | |||||
| assert isinstance(add_expr, CallMethod) and add_expr.method == "__add__" | |||||
| conv_node = graph.get_node_by_name("MyBlock_conv1").as_unique() | |||||
| assert isinstance(conv_node, ModuleNode) and conv_node.module_type == M.Conv2d | |||||
def test_producer_and_users():
    """Every node's producer and users must belong to the same expr list."""
    traced_module, *_ = _init_module()

    def _check(exprs):
        for expr in exprs:
            for node in chain(expr.inputs, expr.outputs):
                producer = node.expr
                # Nodes produced by ``Input`` are exempt from the producer check.
                if not isinstance(producer, Input):
                    assert producer in exprs
                for user in node.users:
                    assert user in exprs
                    assert node in user.inputs

    for mod in traced_module.modules():
        if hasattr(mod, "argdef_graph_map"):
            for g in mod.argdef_graph_map.values():
                _check(g._exprs)
| def test_insert(): | def test_insert(): | ||||
| traced_module, x, expect = _init_block() | traced_module, x, expect = _init_block() | ||||
| @@ -97,6 +126,54 @@ def test_insert(): | |||||
| np.testing.assert_allclose(expect - 1, 1 - traced_module(x), atol=1e-6) | np.testing.assert_allclose(expect - 1, 1 - traced_module(x), atol=1e-6) | ||||
def test_insert_module():
    """Insert a call to a newly attached submodule in place of the relu output."""

    class Neg(M.Module):
        def forward(self, x):
            return F.neg(x)

    traced_module, x, expect = _init_block()
    graph = traced_module.graph
    relu_expr = graph.get_function_by_type(F.relu).as_unique()
    relu_out = relu_expr.outputs[0]
    self_node = graph.inputs[0]
    traced_module.neg = Neg()
    with graph.insert_exprs():
        neg_out = self_node.neg(relu_out)
    graph.replace_node({relu_out: neg_out})
    graph.compile()
    np.testing.assert_allclose(expect - 1, 1 - traced_module(x), atol=1e-6)
    neg_graph = traced_module.neg.graph
    assert neg_graph is not None
    assert len(traced_module.neg.graph._exprs) == 1
def test_add_input_and_output():
    """Add an extra graph input, expose it as an output, then reshape outputs."""
    traced_module, x, y = _init_module()

    data_node = traced_module.graph.add_input_node(shape=(1, 3, 224, 224), name="data")
    traced_module.graph.add_output_node(data_node)

    assert data_node.name == "data"
    assert traced_module.graph.inputs[-1] == data_node
    assert len(traced_module.graph.inputs) == 3
    assert len(traced_module.graph.outputs) == 2

    first, second = traced_module(x, x)
    np.testing.assert_equal(first.numpy(), y.numpy())
    np.testing.assert_equal(second.numpy(), x.numpy())

    first, second = traced_module(x, y)
    np.testing.assert_equal(second.numpy(), y.numpy())

    # Re-declare the outputs as a (dict, tensor) structure.
    new_outputs = (
        {"orig_out": traced_module.graph.outputs[0]},
        traced_module.graph.outputs[1],
    )
    traced_module.graph.reset_outputs(new_outputs)

    out = traced_module(x, x)
    assert isinstance(out, tuple)
    assert isinstance(out[0], dict)
    np.testing.assert_equal(out[0]["orig_out"].numpy(), y.numpy())
    np.testing.assert_equal(out[1].numpy(), x.numpy())
| def test_delete(): | def test_delete(): | ||||
| traced_module, x, expect = _init_block() | traced_module, x, expect = _init_block() | ||||
| graph = traced_module.graph | graph = traced_module.graph | ||||
| @@ -117,8 +194,10 @@ def test_delete(): | |||||
def test_flatten():
    """Flattening once, and then a second time, keeps 12 exprs and the output."""
    traced_module, x, expect = _init_module()
    for _ in range(2):
        traced_module = traced_module.flatten()
        assert len(traced_module.graph._exprs) == 12
        np.testing.assert_equal(expect.numpy(), traced_module(x).numpy())
| @@ -128,7 +207,7 @@ def test_id_and_name(): | |||||
| _total_ids = traced_module.graph._total_ids | _total_ids = traced_module.graph._total_ids | ||||
| node_ids = [n._id for n in traced_module.graph.nodes().as_list()] | node_ids = [n._id for n in traced_module.graph.nodes().as_list()] | ||||
| assert len(set(node_ids)) == len(node_ids) | assert len(set(node_ids)) == len(node_ids) | ||||
| assert max(node_ids) + 1 == len(node_ids) | |||||
| assert max(node_ids) + 1 == _total_ids[0] | |||||
| expr_ids = [n._id for n in traced_module.graph.exprs().as_list()] | expr_ids = [n._id for n in traced_module.graph.exprs().as_list()] | ||||
| assert len(set(expr_ids)) == len(expr_ids) | assert len(set(expr_ids)) == len(expr_ids) | ||||
| @@ -177,7 +256,7 @@ def test_id_and_name(): | |||||
| _check_name(flattened_module) | _check_name(flattened_module) | ||||
| def test_set_name(): | |||||
| def test_set_node_name(): | |||||
| traced_module, x, expect = _init_module() | traced_module, x, expect = _init_module() | ||||
| graph = traced_module.graph | graph = traced_module.graph | ||||
| output_node = graph.outputs[0] | output_node = graph.outputs[0] | ||||
| @@ -190,6 +269,18 @@ def test_set_name(): | |||||
| np.testing.assert_equal(str(graph.outputs[0]), "output") | np.testing.assert_equal(str(graph.outputs[0]), "output") | ||||
def test_set_graph_name():
    """After renaming the graph, a node is found under ``<graph>_<node>``."""
    traced_module, x, expect = _init_module()
    graph = traced_module.graph
    output_node = graph.outputs[0]
    node_name = output_node.name

    graph.name = "Top"
    prefixed_name = "{}_{}".format("Top", node_name)
    found = graph.get_node_by_name(prefixed_name).as_unique()
    assert found is output_node
| def test_extra_block(): | def test_extra_block(): | ||||
| class PostProcess(M.Module): | class PostProcess(M.Module): | ||||
| def forward(self, x): | def forward(self, x): | ||||
| @@ -0,0 +1,195 @@ | |||||
| import io | |||||
| from functools import partial | |||||
| from itertools import chain | |||||
| from typing import Callable | |||||
| import numpy as np | |||||
| import megengine as mge | |||||
| import megengine.functional as F | |||||
| import megengine.module as M | |||||
| import megengine.quantization as Q | |||||
| from megengine import Tensor | |||||
| from megengine.module.qat.module import QATModule | |||||
| from megengine.traced_module import TracedModule, trace_module | |||||
def get_subattr(self: M.Module, name: str):
    """Resolve the dotted ``name`` starting from ``self``.

    An empty ``name`` returns ``self``. Every component before the last dot
    must resolve to an ``M.Module``, otherwise ``AttributeError`` is raised.
    """
    if name == "":
        return self
    module_path, _, attr = name.rpartition(".")
    # An empty path means the attribute lives directly on ``self``.
    for part in (module_path.split(".") if module_path != "" else ()):
        self = getattr(self, part)
        if not isinstance(self, M.Module):
            raise AttributeError("`{}` is not an Module".format(part))
    return getattr(self, attr)
class Myblcok(M.Module):
    """Test block: one conv feeding two parallel convs merged by an elemwise op."""

    def __init__(self,):
        super().__init__()
        self.conv0 = M.ConvBnRelu2d(3, 3, 3, 1, 1)
        self.conv1 = M.ConvBn2d(3, 3, 1, 1, 0)
        self.conv2 = M.ConvBn2d(3, 3, 1, 1, 0)
        self.add = M.Elemwise("FUSE_ADD_RELU")

    def forward(self, x):
        stem = self.conv0(x)
        left = self.conv1(stem)
        right = self.conv2(stem)
        return self.add(left, right)
class MyModule(M.Module):
    """Test network made of two ``Myblcok`` instances applied in sequence."""

    def __init__(self):
        super().__init__()
        self.block0 = Myblcok()
        self.block1 = Myblcok()

    def forward(self, x):
        hidden = self.block0(x)
        return self.block1(hidden)
class MyMinMaxObserver(Q.MinMaxObserver):
    """Subclass of ``Q.MinMaxObserver`` that adds nothing of its own."""
class MyTQT(Q.TQT):
    """Subclass of ``Q.TQT`` that adds nothing of its own."""
def get_lsq_config(lsq_cls):
    """Return a ``Q.QConfig`` with only fake-quant factories built from
    ``lsq_cls`` (observers are ``None``)."""
    weight_fake_quant = partial(lsq_cls, dtype="qint8_narrow")
    act_fake_quant = partial(lsq_cls, dtype="qint8")
    return Q.QConfig(
        weight_observer=None,
        act_observer=None,
        weight_fake_quant=weight_fake_quant,
        act_fake_quant=act_fake_quant,
    )
def get_observer_config(observer_cls):
    """Return a ``Q.QConfig`` with only observer factories built from
    ``observer_cls`` (fake-quant entries are ``None``)."""
    weight_observer = partial(observer_cls, dtype="qint8_narrow")
    act_observer = partial(observer_cls, dtype="qint8")
    return Q.QConfig(
        weight_observer=weight_observer,
        act_observer=act_observer,
        weight_fake_quant=None,
        act_fake_quant=None,
    )
def get_qparams(mod: QATModule):
    """Return ``(weight_qparams, act_qparams)`` for ``mod``.

    For each of weight and activation, the observer's qparams are used when
    the observer is not ``None``, and are overridden by the fake-quant's
    qparams when the fake-quant is truthy. ``None`` is returned for a side
    that has neither.
    """

    def _resolve(observer, fake_quant):
        qparams = None
        if observer is not None:
            qparams = observer.get_qparams()
        if fake_quant:
            qparams = fake_quant.get_qparams()
        return qparams

    act_qparams = _resolve(mod.act_observer, mod.act_fake_quant)
    weight_qparams = _resolve(mod.weight_observer, mod.weight_fake_quant)
    return weight_qparams, act_qparams
def check_qparams(qparmsa: Q.QParams, qparmsb: Q.QParams):
    """Assert that two ``QParams`` agree on dtype meta, mode, scale and, when
    the first one has a zero point, on zero point too."""
    assert qparmsa.dtype_meta == qparmsb.dtype_meta
    assert qparmsa.mode == qparmsb.mode
    np.testing.assert_equal(qparmsa.scale.numpy(), qparmsb.scale.numpy())
    if qparmsa.zero_point is None:
        return
    np.testing.assert_equal(qparmsa.zero_point.numpy(), qparmsb.zero_point.numpy())
def build_observered_net(net: M.Module, observer_cls):
    """Convert ``net`` to QAT with an observer-only config, run five random
    batches through it with observers enabled, then disable the observers."""
    qconfig = get_observer_config(observer_cls)
    qat_net = Q.quantize_qat(net, qconfig=qconfig)
    Q.enable_observer(qat_net)
    for _ in range(5):
        batch = Tensor(np.random.random(size=(5, 3, 32, 32)))
        qat_net(batch)
    Q.disable_observer(qat_net)
    return qat_net
def build_fakequanted_net(net: QATModule, fakequant_cls):
    """Reset the qconfig of ``net`` to a fake-quant-only config built from
    ``fakequant_cls`` and return the result of ``Q.reset_qconfig``."""
    return Q.reset_qconfig(net, get_lsq_config(fakequant_cls))
def test_trace_qat():
    """Tracing a QAT network must keep the qparams of every QAT submodule."""

    def _check_qat_module(qat_net: QATModule):
        inp = Tensor(np.random.random(size=(5, 3, 32, 32)))
        traced_net = trace_module(qat_net, inp)

        for name, qat_module in qat_net.named_modules():
            if isinstance(qat_module, QATModule):
                traced_qat_module = get_subattr(traced_net, name)
                weight_qparams, act_qparams = get_qparams(qat_module)
                traced_weight_qparams, traced_act_qparams = get_qparams(
                    traced_qat_module
                )
                if weight_qparams:
                    check_qparams(weight_qparams, traced_weight_qparams)
                if act_qparams:
                    check_qparams(act_qparams, traced_act_qparams)

    # Observer-only networks, built-in and subclassed observer.
    for observer_cls in (Q.MinMaxObserver, MyMinMaxObserver):
        _check_qat_module(build_observered_net(MyModule(), observer_cls))

    # Observed networks switched to fake-quant, built-in and subclassed.
    for fakequant_cls in (Q.TQT, MyTQT):
        observed = build_observered_net(MyModule(), Q.MinMaxObserver)
        _check_qat_module(build_fakequanted_net(observed, fakequant_cls))
def test_load_param():
    """A state dict saved from a plain network must load into its traced and
    flattened-traced counterparts with identical parameters and buffers."""

    def _check_param(moda: M.Module, modb: M.Module):
        named_tensors = chain(moda.named_parameters(), moda.named_buffers())
        for name, attr in named_tensors:
            traced_attr = get_subattr(modb, name)
            np.testing.assert_equal(attr.numpy(), traced_attr.numpy())

    def _check_module(build_func: Callable):
        net = build_func()
        buffer = io.BytesIO()
        mge.save(net.state_dict(), buffer)
        inp = Tensor(np.random.random(size=(5, 3, 32, 32)))

        # Plain traced module.
        buffer.seek(0)
        traced_net = trace_module(build_func(), inp)
        traced_net.load_state_dict(mge.load(buffer))
        _check_param(net, traced_net)

        # Flattened traced module.
        buffer.seek(0)
        traced_net = trace_module(build_func(), inp).flatten()
        traced_net.load_state_dict(mge.load(buffer))
        _check_param(net, traced_net)

    _check_module(lambda: MyModule())
    _check_module(lambda: build_observered_net(MyModule(), Q.MinMaxObserver))
def test_qualname():
    """Every node qualname must resolve on both the traced and original net."""

    def _check_qualname(net):
        inp = Tensor(np.random.random(size=(5, 3, 32, 32)))
        traced_net = trace_module(net, inp)
        base_qualname = traced_net.graph.qualname
        # Length of the graph qualname plus the following separator.
        skip = len(base_qualname) + 1
        for node in traced_net.graph.nodes():
            relative = node.qualname[skip:]
            if relative.endswith("]"):
                # Drop the trailing "[...]" component.
                relative = relative.rsplit(".", 1)[0]
            if relative.startswith("["):
                relative = ""
            traced_attr = get_subattr(traced_net, relative)
            orig_attr = get_subattr(net, relative)
            assert traced_attr is not None
            assert orig_attr is not None

    _check_qualname(MyModule())
    _check_qualname(build_observered_net(MyModule(), Q.MinMaxObserver))