GitOrigin-RevId: a43ad1273c
tags/v1.7.0
| @@ -11,7 +11,7 @@ import collections | |||
| import copy | |||
| import inspect | |||
| import re | |||
| from typing import Callable, Dict, List | |||
| from typing import Callable, Dict, List, Optional, Union | |||
| from ..core._imperative_rt import OpDef | |||
| from ..core._imperative_rt.core2 import Tensor as RawTensor | |||
| @@ -32,6 +32,43 @@ def rstrip(s: str, __chars: str): | |||
| return s | |||
| def get_suffix_name(prefix: str, name: str): | |||
| if prefix == name: | |||
| return "" | |||
| matchd = re.compile("^%s\.(.*)" % prefix).match(name) | |||
| if matchd is None: | |||
| return None | |||
| return matchd.group(1) | |||
| def is_call_module(expr): | |||
| return ( | |||
| isinstance(expr, CallMethod) | |||
| and isinstance(expr.inputs[0], ModuleNode) | |||
| and expr.method == "__call__" | |||
| ) | |||
| def is_call_tensor_method(expr): | |||
| return isinstance(expr, CallMethod) and not is_call_module(expr) | |||
| def is_call_function(expr): | |||
| return isinstance(expr, CallFunction) | |||
| def is_constant(expr): | |||
| return isinstance(expr, Constant) | |||
| def is_getattr(expr): | |||
| return isinstance(expr, GetAttr) | |||
| def is_apply_def(expr): | |||
| return isinstance(expr, Apply) | |||
| class Expr: | |||
| r"""``Expr`` represents the operations (i.e. ``CallMethod``, ``CallFunction``, ``Apply``, | |||
| ``GetAttr``, ``Input``, ``Constant``) on ``Node``. | |||
| @@ -76,50 +113,19 @@ class Expr: | |||
| self.const_val.append((idx, val)) | |||
| def add_outputs(self, outputs): | |||
| assert active_module_tracer() is not None | |||
| self.outputs = [] | |||
| if outputs is not None: | |||
| if not isinstance(outputs, collections.Sequence): | |||
| outputs = (outputs,) | |||
| name = None | |||
| orig_name = None | |||
| if isinstance(self, CallMethod): | |||
| name = self.inputs[0]._name | |||
| orig_name = self.inputs[0]._orig_name | |||
| assert isinstance(name, str), "The name of ({}) must be a str".format( | |||
| self.inputs[0] | |||
| ) | |||
| assert isinstance( | |||
| orig_name, str | |||
| ), "The orig_name of ({}) must be a str".format(self.inputs[0]) | |||
| name = rstrip(name, "_out") | |||
| if self.method == "__call__": | |||
| name += "_out" | |||
| orig_name += "_out" | |||
| else: | |||
| strip_method = self.method.strip("_") | |||
| name = "%s_out" % strip_method | |||
| orig_name = name | |||
| elif isinstance(self, CallFunction): | |||
| name = self.func.__name__ + "_out" | |||
| elif isinstance(self, Apply): | |||
| name = str(self.opdef).lower() + "_out" | |||
| for i in outputs: | |||
| assert isinstance(i, RawTensor), "The output must be a Tensor" | |||
| o_name = ( | |||
| active_module_tracer().current_scope()._create_unique_name(name) | |||
| ) | |||
| self.outputs.append( | |||
| NodeMixin.get_wrapped_type(i)( | |||
| expr=self, | |||
| name=o_name, | |||
| orig_name=orig_name if orig_name else o_name, | |||
| ) | |||
| ) | |||
| for i, node in zip(outputs, self.outputs,): | |||
| NodeMixin.wrap_safe(i, node) | |||
| if outputs is None: | |||
| return | |||
| current_graph = active_module_tracer().current_scope() | |||
| if not isinstance(outputs, collections.Sequence): | |||
| outputs = (outputs,) | |||
| for i in outputs: | |||
| assert isinstance(i, RawTensor), "The output must be a Tensor" | |||
| node = NodeMixin.get_wrapped_type(i)(expr=self, name="", qualname="",) | |||
| NodeMixin.wrap_safe(i, node) | |||
| self.outputs.append(node) | |||
| current_graph._namespace.auto_naming_for_outputs(self) | |||
| def unflatten_args(self, inputs): | |||
| if self.arg_def is not None: | |||
| @@ -152,9 +158,7 @@ class Expr: | |||
| ), "({}) must be generated before ({})".format(repl_node, self) | |||
| idx = self.inputs.index(node) | |||
| self.inputs[idx] = repl_node | |||
| user_idx = node.users.index(self) | |||
| assert user_idx >= 0 | |||
| node.users.pop(user_idx) | |||
| node.users.remove(self) | |||
| repl_node.users.append(self) | |||
| @property | |||
| @@ -197,26 +201,23 @@ class Input(Expr): | |||
| r"""A fake Expr which is used to mark the input of graph.""" | |||
| name = None | |||
| def __init__(self, name=None, type=None, orig_name=None): | |||
| def __init__(self, type: List[Node], name: str = "args", qualname: str = ""): | |||
| super().__init__() | |||
| assert type in [ModuleNode, TensorNode] | |||
| assert name and qualname | |||
| self.inputs = [] | |||
| node_cls = type if type else Node | |||
| if orig_name is None: | |||
| orig_name = name | |||
| self.outputs = [ | |||
| node_cls(self, name=name, orig_name=orig_name), | |||
| node_cls(self, name=name, qualname=qualname), | |||
| ] | |||
| self.name = name | |||
| @classmethod | |||
| def make(cls, *args, **kwargs): | |||
| assert active_module_tracer() is not None | |||
| expr = cls(*args, **kwargs) | |||
| oup_node = expr.outputs[0] | |||
| name = ( | |||
| active_module_tracer().current_scope()._create_unique_name(oup_node._name) | |||
| ) | |||
| oup_node._name = name | |||
| active_module_tracer().current_scope()._add_input(oup_node) | |||
| out_node = expr.outputs[0] | |||
| active_module_tracer().current_scope()._add_input(out_node) | |||
| return expr.outputs[0] | |||
| def __repr__(self): | |||
| @@ -230,34 +231,41 @@ class GetAttr(Expr): | |||
| name = None | |||
| r"""name: the qualified name of the attribute to be retrieved.""" | |||
| def __init__(self, module, name, type=None, orig_name=None): | |||
| def __init__( | |||
| self, module: ModuleNode, type: Union[Node], attr_name: str, name: str = "", | |||
| ): | |||
| super().__init__() | |||
| assert isinstance(module, ModuleNode) | |||
| assert type in [TensorNode, ModuleNode] | |||
| self.inputs = [ | |||
| module, | |||
| ] | |||
| module.users.append(self) | |||
| self.name = name | |||
| node_cls = type if type else Node | |||
| self.name = attr_name | |||
| self.outputs = [ | |||
| node_cls(self, name=name, orig_name=orig_name), | |||
| type(self, name=name, qualname="{}.{}".format(module.qualname, attr_name)), | |||
| ] | |||
| @classmethod | |||
| def make(cls, *args, **kwargs): | |||
| assert active_module_tracer() is not None | |||
| current_graph = active_module_tracer().current_scope() | |||
| expr = cls(*args, **kwargs) | |||
| module = expr.inputs[0] | |||
| oup_name = expr.name | |||
| while module._name != "self": | |||
| oup_name = module._name + "_" + oup_name | |||
| module = module.expr.inputs[0] | |||
| oup_name = active_module_tracer().current_scope()._create_unique_name(oup_name) | |||
| expr.outputs[0]._name = oup_name | |||
| active_module_tracer().current_scope()._insert(expr) | |||
| current_graph._namespace.auto_naming_for_outputs(expr) | |||
| current_graph._insert(expr) | |||
| return expr.outputs[0] | |||
| def interpret(self, *inputs): | |||
| return (getattr(inputs[0], self.name),) | |||
| mod = inputs[0] | |||
| module_path, _, name = self.name.rpartition(".") | |||
| if module_path == "": | |||
| return (getattr(mod, name),) | |||
| module_names = module_path.split(".") | |||
| for item in module_names: | |||
| mod = getattr(mod, item) | |||
| if not isinstance(mod, Module): | |||
| raise AttributeError("`{}` is not an Module".format(item)) | |||
| return (getattr(mod, name),) | |||
| def __repr__(self): | |||
| out_type = "Tensor" | |||
| @@ -297,6 +305,7 @@ class CallMethod(Expr): | |||
| @classmethod | |||
| def make(cls, *args, **kwargs): | |||
| assert active_module_tracer() is not None | |||
| expr = cls(*args, **kwargs) | |||
| active_module_tracer().current_scope()._insert(expr) | |||
| return expr | |||
| @@ -362,6 +371,7 @@ class Apply(Expr): | |||
| @classmethod | |||
| def make(cls, *args, **kwargs): | |||
| assert active_module_tracer() is not None | |||
| expr = cls(*args, **kwargs) | |||
| active_module_tracer().current_scope()._insert(expr) | |||
| return expr | |||
| @@ -435,6 +445,7 @@ class CallFunction(Expr): | |||
| @classmethod | |||
| def make(cls, *args, **kwargs): | |||
| assert active_module_tracer() is not None | |||
| expr = cls(*args, **kwargs) | |||
| active_module_tracer().current_scope()._insert(expr) | |||
| return expr | |||
| @@ -474,7 +485,7 @@ class Constant(Expr): | |||
| # TODO: constant cache to reduce the size of dumped model | |||
| _constant_cache = {} | |||
| def __init__(self, c, name=None): | |||
| def __init__(self, c, name: str = "", qualname: str = ""): | |||
| super().__init__() | |||
| assert isinstance(c, (RawTensor, Module)) | |||
| if isinstance(c, Module): | |||
| @@ -484,31 +495,16 @@ class Constant(Expr): | |||
| self.inputs = [] | |||
| node_cls = NodeMixin.get_wrapped_type(c) | |||
| self.outputs = [ | |||
| node_cls(self, name=name, orig_name=name), | |||
| node_cls(self, name=name, qualname=qualname), | |||
| ] | |||
| self.outputs[0]._name = name if name else "const_" + str(self._id) | |||
| @classmethod | |||
| def make(cls, *args, **kwargs): | |||
| assert active_module_tracer() is not None | |||
| expr = cls(*args, **kwargs) | |||
| name = "const_module" if isinstance(expr.value, Module) else "const_tensor" | |||
| full_name = name | |||
| if ( | |||
| isinstance(expr.value, RawTensor) | |||
| and id(expr.value) in active_module_tracer().id2name | |||
| ): | |||
| full_name = active_module_tracer().id2name[id(expr.value)] | |||
| scope_name = active_module_tracer().current_scope()._module_name | |||
| if full_name and scope_name: | |||
| full_name = ("self." + full_name)[len(scope_name) + 1 :] | |||
| else: | |||
| full_name = name | |||
| else: | |||
| full_name = name | |||
| name = active_module_tracer().current_scope()._create_unique_name(full_name) | |||
| expr.outputs[0]._name = name | |||
| expr.outputs[0]._orig_name = full_name | |||
| active_module_tracer().current_scope()._insert(expr) | |||
| current_graph = active_module_tracer().current_scope() | |||
| current_graph._namespace.auto_naming_for_outputs(expr) | |||
| current_graph._insert(expr) | |||
| return expr.outputs[0] | |||
| def interpret(self, *inputs): | |||
| @@ -128,10 +128,9 @@ class module_tracer: | |||
| _active_scopes = None | |||
| def __init__(self, wrap_fn, id2name): | |||
| def __init__(self, wrap_fn): | |||
| self._active_scopes = [] | |||
| self.patcher = Patcher(wrap_fn) | |||
| self.id2name = id2name | |||
| @classmethod | |||
| def register_as_builtin(cls, mod): | |||
| @@ -29,17 +29,15 @@ class Node: | |||
| __total_id = 0 # type: int | |||
| _id = None # type: int | |||
| _top_graph = None # type: weakref.ReferenceType | |||
| _name = None # type: str | |||
| _orig_name = None # type: str | |||
| _format_spec = "" # type: str | |||
| def __init__(self, expr, name: str, orig_name: str): | |||
| def __init__(self, expr, name: str, qualname: str): | |||
| self.expr = expr | |||
| self.users = [] # List[Expr] | |||
| self._id = Node.__total_id | |||
| Node.__total_id += 1 | |||
| self._name = name | |||
| self._orig_name = orig_name | |||
| self._qualname = qualname | |||
| self.actual_node = [] # type: List[Node] | |||
| def __repr__(self): | |||
| @@ -54,21 +52,10 @@ class Node: | |||
| name = "" | |||
| if format_spec in ["i", "p", "ip", "pi"]: | |||
| if "p" in format_spec: | |||
| graph = self.top_graph | |||
| prefix_name = "" | |||
| if graph is not None: | |||
| prefix_name = graph._name | |||
| if graph._prefix_name: | |||
| prefix_name = "{}_{}".format( | |||
| graph._prefix_name, prefix_name.lstrip("_") | |||
| ) | |||
| if name: | |||
| name = "_" + name.lstrip("_") | |||
| name = "{}{}".format(prefix_name, name) | |||
| prefix_name = self.top_graph._name | |||
| name = "{}_{}".format(prefix_name, name) | |||
| if "i" in format_spec: | |||
| if name: | |||
| name = "_" + name.lstrip("_") | |||
| name = "%{}{}".format(self._id, name) | |||
| name = "%{}_{}".format(self._id, name) | |||
| return name | |||
| else: | |||
| return name if name else ("%d" % self._id) | |||
| @@ -80,15 +67,62 @@ class Node: | |||
| @name.setter | |||
| def name(self, new_name: str): | |||
| r"""Set a new name to this Node.""" | |||
| graph = self.top_graph | |||
| assert graph is not None, "The parent graph of this Node cannot be None." | |||
| assert new_name not in graph._used_names, ( | |||
| assert new_name not in graph._namespace.used_names, ( | |||
| "The name(%s) is already in use. Please try a different one again." | |||
| % (new_name) | |||
| ) | |||
| new_name = graph._create_unique_name(new_name) | |||
| new_name = graph._namespace.create_unique_name(new_name) | |||
| self._name = new_name | |||
| self._orig_name = new_name | |||
| @property | |||
| def qualname(self): | |||
| r"""Get the `qualname` of this Node. The `qualname` can be used to get the | |||
| submodule from the traced Module or Module. | |||
| Example: | |||
| .. code-block:: | |||
| import megengine.module as M | |||
| import megengine.functional as F | |||
| import megengine.traced_module as tm | |||
| import megengine as mge | |||
| class block(M.Module): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.param = mge.Tensor([1.]) | |||
| self.relu = M.ReLU() | |||
| def forward(self, x): | |||
| x = x + self.param | |||
| return self.relu(F.relu(x)) | |||
| class module(M.Module): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.block = block() | |||
| def forward(self, x): | |||
| x = self.block(x) | |||
| return x | |||
| net = module() | |||
| traced_net = tm.trace_module(net, mge.Tensor([0.])) | |||
| traced_net = traced_net.flatten() | |||
| out_node = traced_net.graph.outputs[0] | |||
| # qualname : "module.block.relu.[out]" | |||
| qualname = out_node.qualname | |||
| # qualname : "block.relu" | |||
| qualname = qualname.split(".", 1)[-1].rsplit(".", 1)[0] | |||
| assert qualname in list(map(lambda x: x[0], net.named_modules())) | |||
| assert qualname in list(map(lambda x: x[0], traced_net.named_modules())) | |||
| """ | |||
| return self._qualname | |||
| @property | |||
| def top_graph(self): | |||
| @@ -120,8 +154,8 @@ class ModuleNode(Node): | |||
| r"""The type of the Module correspending to the ModuleNode.""" | |||
| _owner = None # type: weakref.ReferenceType | |||
| def __init__(self, expr, name: str = None, orig_name: str = None): | |||
| super().__init__(expr, name, orig_name) | |||
| def __init__(self, expr, name: str = None, qualname: str = None): | |||
| super().__init__(expr, name, qualname) | |||
| def __getstate__(self): | |||
| return { | |||
| @@ -129,10 +163,15 @@ class ModuleNode(Node): | |||
| "users": self.users, | |||
| "_id": self._id, | |||
| "_name": self._name, | |||
| "_orig_name": self._orig_name, | |||
| "_qualname": self._qualname, | |||
| "module_type": self.module_type, | |||
| } | |||
| def __setstate__(self, state): | |||
| if "_orig_name" in state: | |||
| state["_qualname"] = state.pop("_orig_name") | |||
| self.__dict__.update(state) | |||
| @property | |||
| def owner(self): | |||
| r"""Get the ``Module`` corresponding to this ``ModuleNode``. | |||
| @@ -161,9 +200,21 @@ class TensorNode(Node): | |||
| "_dtype": self._dtype, | |||
| "_device": self._device, | |||
| "_name": self._name, | |||
| "_orig_name": self._orig_name, | |||
| "_qualname": self._qualname, | |||
| } | |||
| def __setstate__(self, state): | |||
| if "_orig_name" in state: | |||
| qualname = state.pop("_orig_name") | |||
| modulepath, comma, qualname = qualname.rpartition(".") | |||
| expr_name = state["expr"].__class__.__name__ | |||
| if expr_name not in ["GetAttr"]: | |||
| qualname = "[{}]".format(qualname) | |||
| if comma: | |||
| qualname = "{}.{}".format(modulepath, qualname) | |||
| state["_qualname"] = qualname | |||
| self.__dict__.update(state) | |||
| @property | |||
| def shape(self): | |||
| r"""Get the shape of this Node.""" | |||
| @@ -6,6 +6,7 @@ | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import pickle | |||
| from itertools import chain | |||
| import numpy as np | |||
| @@ -13,8 +14,8 @@ import megengine.functional as F | |||
| import megengine.module as M | |||
| from megengine.module.identity import Identity | |||
| from megengine.traced_module import trace_module | |||
| from megengine.traced_module.expr import CallFunction, Expr, GetAttr | |||
| from megengine.traced_module.node import Node | |||
| from megengine.traced_module.expr import CallFunction, CallMethod, Expr, GetAttr, Input | |||
| from megengine.traced_module.node import ModuleNode, Node | |||
| class IdentityMod(M.Module): | |||
| @@ -85,6 +86,34 @@ def test_search(): | |||
| relu_expr = graph.get_function_by_type(F.relu).as_unique() | |||
| assert isinstance(relu_expr, CallFunction) and relu_expr.func == F.relu | |||
| conv_node = graph.get_module_by_type(M.Conv2d).as_unique() | |||
| assert isinstance(conv_node, ModuleNode) and conv_node.module_type == M.Conv2d | |||
| add_expr = graph.get_method_by_type("__add__").as_unique() | |||
| assert isinstance(add_expr, CallMethod) and add_expr.method == "__add__" | |||
| conv_node = graph.get_node_by_name("MyBlock_conv1").as_unique() | |||
| assert isinstance(conv_node, ModuleNode) and conv_node.module_type == M.Conv2d | |||
| def test_producer_and_users(): | |||
| traced_module, *_ = _init_module() | |||
| def _check(exprs): | |||
| for expr in exprs: | |||
| for n in chain(expr.inputs, expr.outputs): | |||
| if not isinstance(n.expr, Input): | |||
| assert n.expr in exprs | |||
| for e in n.users: | |||
| assert e in exprs | |||
| assert n in e.inputs | |||
| for mod in traced_module.modules(): | |||
| if not hasattr(mod, "argdef_graph_map"): | |||
| continue | |||
| for g in mod.argdef_graph_map.values(): | |||
| _check(g._exprs) | |||
| def test_insert(): | |||
| traced_module, x, expect = _init_block() | |||
| @@ -97,6 +126,54 @@ def test_insert(): | |||
| np.testing.assert_allclose(expect - 1, 1 - traced_module(x), atol=1e-6) | |||
| def test_insert_module(): | |||
| class Neg(M.Module): | |||
| def forward(self, x): | |||
| return F.neg(x) | |||
| traced_module, x, expect = _init_block() | |||
| graph = traced_module.graph | |||
| relu_out = graph.get_function_by_type(F.relu).as_unique().outputs[0] | |||
| self = graph.inputs[0] | |||
| setattr(traced_module, "neg", Neg()) | |||
| with graph.insert_exprs(): | |||
| neg_out = self.neg(relu_out) | |||
| graph.replace_node({relu_out: neg_out}) | |||
| graph.compile() | |||
| np.testing.assert_allclose(expect - 1, 1 - traced_module(x), atol=1e-6) | |||
| assert traced_module.neg.graph is not None | |||
| assert len(traced_module.neg.graph._exprs) == 1 | |||
| def test_add_input_and_output(): | |||
| traced_module, x, y = _init_module() | |||
| data_node = traced_module.graph.add_input_node(shape=(1, 3, 224, 224), name="data") | |||
| traced_module.graph.add_output_node(data_node) | |||
| assert data_node.name == "data" | |||
| assert traced_module.graph.inputs[-1] == data_node | |||
| assert len(traced_module.graph.inputs) == 3 | |||
| assert len(traced_module.graph.outputs) == 2 | |||
| y1, y2 = traced_module(x, x) | |||
| np.testing.assert_equal(y1.numpy(), y.numpy()) | |||
| np.testing.assert_equal(y2.numpy(), x.numpy()) | |||
| y1, y2 = traced_module(x, y) | |||
| np.testing.assert_equal(y2.numpy(), y.numpy()) | |||
| traced_module.graph.reset_outputs( | |||
| ({"orig_out": traced_module.graph.outputs[0]}, traced_module.graph.outputs[1]) | |||
| ) | |||
| out = traced_module(x, x) | |||
| assert isinstance(out, tuple) | |||
| assert isinstance(out[0], dict) | |||
| np.testing.assert_equal(out[0]["orig_out"].numpy(), y.numpy()) | |||
| np.testing.assert_equal(out[1].numpy(), x.numpy()) | |||
| def test_delete(): | |||
| traced_module, x, expect = _init_block() | |||
| graph = traced_module.graph | |||
| @@ -117,8 +194,10 @@ def test_delete(): | |||
| def test_flatten(): | |||
| traced_module, x, expect = _init_module() | |||
| traced_module = traced_module.flatten() | |||
| traced_module.graph.compile() | |||
| assert all(not isinstance(i, GetAttr) for i in traced_module.graph._exprs) | |||
| assert len(traced_module.graph._exprs) == 12 | |||
| np.testing.assert_equal(expect.numpy(), traced_module(x).numpy()) | |||
| traced_module = traced_module.flatten() | |||
| assert len(traced_module.graph._exprs) == 12 | |||
| np.testing.assert_equal(expect.numpy(), traced_module(x).numpy()) | |||
| @@ -128,7 +207,7 @@ def test_id_and_name(): | |||
| _total_ids = traced_module.graph._total_ids | |||
| node_ids = [n._id for n in traced_module.graph.nodes().as_list()] | |||
| assert len(set(node_ids)) == len(node_ids) | |||
| assert max(node_ids) + 1 == len(node_ids) | |||
| assert max(node_ids) + 1 == _total_ids[0] | |||
| expr_ids = [n._id for n in traced_module.graph.exprs().as_list()] | |||
| assert len(set(expr_ids)) == len(expr_ids) | |||
| @@ -177,7 +256,7 @@ def test_id_and_name(): | |||
| _check_name(flattened_module) | |||
| def test_set_name(): | |||
| def test_set_node_name(): | |||
| traced_module, x, expect = _init_module() | |||
| graph = traced_module.graph | |||
| output_node = graph.outputs[0] | |||
| @@ -190,6 +269,18 @@ def test_set_name(): | |||
| np.testing.assert_equal(str(graph.outputs[0]), "output") | |||
| def test_set_graph_name(): | |||
| traced_module, x, expect = _init_module() | |||
| graph = traced_module.graph | |||
| output_node = graph.outputs[0] | |||
| node_name = output_node.name | |||
| graph.name = "Top" | |||
| node = graph.get_node_by_name("{}_{}".format("Top", node_name)).as_unique() | |||
| assert node is output_node | |||
| def test_extra_block(): | |||
| class PostProcess(M.Module): | |||
| def forward(self, x): | |||
| @@ -0,0 +1,195 @@ | |||
| import io | |||
| from functools import partial | |||
| from itertools import chain | |||
| from typing import Callable | |||
| import numpy as np | |||
| import megengine as mge | |||
| import megengine.functional as F | |||
| import megengine.module as M | |||
| import megengine.quantization as Q | |||
| from megengine import Tensor | |||
| from megengine.module.qat.module import QATModule | |||
| from megengine.traced_module import TracedModule, trace_module | |||
| def get_subattr(self: M.Module, name: str): | |||
| if name == "": | |||
| return self | |||
| module_path, _, name = name.rpartition(".") | |||
| if module_path == "": | |||
| return getattr(self, name) | |||
| module_names = module_path.split(".") | |||
| for item in module_names: | |||
| self = getattr(self, item) | |||
| if not isinstance(self, M.Module): | |||
| raise AttributeError("`{}` is not an Module".format(item)) | |||
| return getattr(self, name) | |||
| class Myblcok(M.Module): | |||
| def __init__(self,): | |||
| super().__init__() | |||
| self.conv0 = M.ConvBnRelu2d(3, 3, 3, 1, 1) | |||
| self.conv1 = M.ConvBn2d(3, 3, 1, 1, 0) | |||
| self.conv2 = M.ConvBn2d(3, 3, 1, 1, 0) | |||
| self.add = M.Elemwise("FUSE_ADD_RELU") | |||
| def forward(self, x): | |||
| x = self.conv0(x) | |||
| x0 = self.conv1(x) | |||
| x1 = self.conv2(x) | |||
| o = self.add(x0, x1) | |||
| return o | |||
| class MyModule(M.Module): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.block0 = Myblcok() | |||
| self.block1 = Myblcok() | |||
| def forward(self, x): | |||
| x = self.block0(x) | |||
| x = self.block1(x) | |||
| return x | |||
| class MyMinMaxObserver(Q.MinMaxObserver): | |||
| pass | |||
| class MyTQT(Q.TQT): | |||
| pass | |||
| def get_lsq_config(lsq_cls): | |||
| return Q.QConfig( | |||
| weight_observer=None, | |||
| act_observer=None, | |||
| weight_fake_quant=partial(lsq_cls, dtype="qint8_narrow"), | |||
| act_fake_quant=partial(lsq_cls, dtype="qint8"), | |||
| ) | |||
| def get_observer_config(observer_cls): | |||
| return Q.QConfig( | |||
| weight_observer=partial(observer_cls, dtype="qint8_narrow"), | |||
| act_observer=partial(observer_cls, dtype="qint8"), | |||
| weight_fake_quant=None, | |||
| act_fake_quant=None, | |||
| ) | |||
| def get_qparams(mod: QATModule): | |||
| weight_qparams, act_qparams = None, None | |||
| if mod.act_observer is not None: | |||
| act_qparams = mod.act_observer.get_qparams() | |||
| if mod.act_fake_quant: | |||
| act_qparams = mod.act_fake_quant.get_qparams() | |||
| if mod.weight_observer is not None: | |||
| weight_qparams = mod.weight_observer.get_qparams() | |||
| if mod.weight_fake_quant: | |||
| weight_qparams = mod.weight_fake_quant.get_qparams() | |||
| return weight_qparams, act_qparams | |||
| def check_qparams(qparmsa: Q.QParams, qparmsb: Q.QParams): | |||
| assert qparmsa.dtype_meta == qparmsb.dtype_meta | |||
| assert qparmsa.mode == qparmsb.mode | |||
| np.testing.assert_equal(qparmsa.scale.numpy(), qparmsb.scale.numpy()) | |||
| if qparmsa.zero_point is not None: | |||
| np.testing.assert_equal(qparmsa.zero_point.numpy(), qparmsb.zero_point.numpy()) | |||
| def build_observered_net(net: M.Module, observer_cls): | |||
| qat_net = Q.quantize_qat(net, qconfig=get_observer_config(observer_cls)) | |||
| Q.enable_observer(qat_net) | |||
| for _ in range(5): | |||
| inp = Tensor(np.random.random(size=(5, 3, 32, 32))) | |||
| qat_net(inp) | |||
| Q.disable_observer(qat_net) | |||
| return qat_net | |||
| def build_fakequanted_net(net: QATModule, fakequant_cls): | |||
| qat_net = Q.reset_qconfig(net, get_lsq_config(fakequant_cls)) | |||
| return qat_net | |||
| def test_trace_qat(): | |||
| def _check_qat_module(qat_net: QATModule): | |||
| inp = Tensor(np.random.random(size=(5, 3, 32, 32))) | |||
| traced_net = trace_module(qat_net, inp) | |||
| for name, qat_module in qat_net.named_modules(): | |||
| if not isinstance(qat_module, QATModule): | |||
| continue | |||
| traced_qat_module = get_subattr(traced_net, name) | |||
| weight_qparams, act_qparams = get_qparams(qat_module) | |||
| traced_weight_qparams, traced_act_qparams = get_qparams(traced_qat_module) | |||
| if weight_qparams: | |||
| check_qparams(weight_qparams, traced_weight_qparams) | |||
| if act_qparams: | |||
| check_qparams(act_qparams, traced_act_qparams) | |||
| _check_qat_module(build_observered_net(MyModule(), Q.MinMaxObserver)) | |||
| _check_qat_module(build_observered_net(MyModule(), MyMinMaxObserver)) | |||
| _check_qat_module( | |||
| build_fakequanted_net(build_observered_net(MyModule(), Q.MinMaxObserver), Q.TQT) | |||
| ) | |||
| _check_qat_module( | |||
| build_fakequanted_net(build_observered_net(MyModule(), Q.MinMaxObserver), MyTQT) | |||
| ) | |||
| def test_load_param(): | |||
| def _check_param(moda: M.Module, modb: M.Module): | |||
| for name, attr in chain(moda.named_parameters(), moda.named_buffers()): | |||
| traced_attr = get_subattr(modb, name) | |||
| np.testing.assert_equal(attr.numpy(), traced_attr.numpy()) | |||
| def _check_module(build_func: Callable): | |||
| net = build_func() | |||
| buffer = io.BytesIO() | |||
| mge.save(net.state_dict(), buffer) | |||
| buffer.seek(0) | |||
| inp = Tensor(np.random.random(size=(5, 3, 32, 32))) | |||
| traced_net = trace_module(build_func(), inp) | |||
| traced_net.load_state_dict(mge.load(buffer)) | |||
| _check_param(net, traced_net) | |||
| buffer.seek(0) | |||
| traced_net = trace_module(build_func(), inp).flatten() | |||
| traced_net.load_state_dict(mge.load(buffer)) | |||
| _check_param(net, traced_net) | |||
| _check_module(lambda: MyModule()) | |||
| _check_module(lambda: build_observered_net(MyModule(), Q.MinMaxObserver)) | |||
| def test_qualname(): | |||
| def _check_qualname(net): | |||
| inp = Tensor(np.random.random(size=(5, 3, 32, 32))) | |||
| traced_net = trace_module(net, inp) | |||
| base_qualname = traced_net.graph.qualname | |||
| for node in traced_net.graph.nodes(): | |||
| qualname = node.qualname | |||
| qualname = qualname[len(base_qualname) + 1 :] | |||
| if qualname.endswith("]"): | |||
| qualname = qualname.rsplit(".", 1)[0] | |||
| if qualname.startswith("["): | |||
| qualname = "" | |||
| traced_attr = get_subattr(traced_net, qualname) | |||
| orig_attr = get_subattr(net, qualname) | |||
| assert traced_attr is not None | |||
| assert orig_attr is not None | |||
| _check_qualname(MyModule()) | |||
| _check_qualname(build_observered_net(MyModule(), Q.MinMaxObserver)) | |||