GitOrigin-RevId: aaa9e51c74
tags/v1.7.0
| @@ -7,6 +7,7 @@ | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from ..core._imperative_rt.core2 import set_cpp_apply_module_trace | |||
| from . import compat | |||
| from .traced_module import ( | |||
| TracedModule, | |||
| _register_all_builtin_module, | |||
| @@ -0,0 +1,136 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import numpy as np | |||
| from .. import tensor | |||
| from ..core.ops.builtin import BatchNorm | |||
| from .expr import CallMethod, Constant | |||
| from .node import TensorNode | |||
| from .serialization import ( | |||
| register_functional_loader, | |||
| register_module_loader, | |||
| register_opdef_loader, | |||
| register_tensor_method_loader, | |||
| ) | |||
| """ | |||
| # Expr loaders examples | |||
| from ..core.ops.builtin import Elemwise | |||
| @register_opdef_loader(Elemwise) | |||
| def add_opdef_loader(expr): | |||
| if expr.opdef_state["mode"] == "ADD": | |||
| expr.opdef_state["mode"] = "MUL" | |||
| node = expr.inputs[1] | |||
| astype_expr = CallMethod(node, "astype") | |||
| oup = TensorNode( | |||
| astype_expr, | |||
| shape=node.shape, | |||
| dtype=expr.inputs[0].dtype, | |||
| qparams=node.qparams, | |||
| ) | |||
| astype_expr.set_args_kwargs(node, expr.inputs[0].dtype) | |||
| astype_expr.return_val = (oup,) | |||
| expr.inputs[1] = oup | |||
| @register_functional_loader(("megengine.functional.nn", "conv2d")) | |||
| def conv2df_loader(expr): | |||
| # expr.func = ("megengine.functional.nn","conv2d") | |||
| kwargs = expr.kwargs | |||
| orig_weight = expr.named_args["weight"] | |||
| astype_expr = CallMethod(orig_weight, "astype") | |||
| oup = TensorNode( | |||
| astype_expr, | |||
| shape=orig_weight.shape, | |||
| dtype=orig_weight.dtype, | |||
| qparams=orig_weight.qparams, | |||
| ) | |||
| astype_expr.set_args_kwargs(orig_weight, expr.named_args["inp"].dtype) | |||
| astype_expr.return_val = (oup,) | |||
| expr.set_arg("weight", oup) | |||
| @register_module_loader(("megengine.module.conv", "Conv2d")) | |||
| def conv2dm_loader(expr): | |||
| module = expr.inputs[0].owner | |||
| args = list(expr.args) | |||
| orig_inp = args[1] | |||
| astype_expr = CallMethod(orig_inp, "astype") | |||
| oup = TensorNode( | |||
| astype_expr, | |||
| shape=orig_inp.shape, | |||
| dtype=orig_inp.dtype, | |||
| qparams=orig_inp.qparams, | |||
| ) | |||
| astype_expr.set_args_kwargs(orig_inp, module.weight.dtype) | |||
| astype_expr.return_val = (oup,) | |||
| args[1] = oup | |||
| expr.set_args_kwargs(*args) | |||
| @register_tensor_method_loader("__add__") | |||
| def add_loader(expr): | |||
| args = list(expr.args) | |||
| if not isinstance(args[1], TensorNode): | |||
| args[1] = tensor(args[1]) | |||
| node = Constant(args[1], "const").outputs[0] | |||
| astype_expr = CallMethod(node, "astype") | |||
| oup = TensorNode( | |||
| astype_expr, shape=node.shape, dtype=node.dtype, qparams=node.qparams, | |||
| ) | |||
| astype_expr.set_args_kwargs(node, expr.inputs[0].dtype) | |||
| astype_expr.return_val = (oup,) | |||
| args[1] = oup | |||
| expr.set_args_kwargs(*args) | |||
| """ | |||
@register_module_loader(
    ("megengine.module.batchnorm", "BatchNorm1d"),
    ("megengine.module.batchnorm", "BatchNorm2d"),
    ("megengine.module.batchnorm", "SyncBatchNorm"),
)
def bn2d_module_loader(expr):
    r"""Compat loader for batch-norm module calls dumped without a version.

    Exprs lacking a ``version`` attribute are treated as coming from mge 1.6.
    For those, the owner module of the first input gets ``param_dim`` set to
    ``"dim_1c11"`` when the attribute is missing.
    """
    # exprs that carry a ``version`` attribute need no patching
    if hasattr(expr, "version"):
        return
    bn_module = expr.inputs[0].owner
    if hasattr(bn_module, "param_dim"):
        return
    bn_module.param_dim = "dim_1c11"
@register_module_loader(
    ("megengine.module.conv_bn", "ConvBn2d"),
    ("megengine.module.conv_bn", "ConvBnRelu2d"),
    ("megengine.module.qat.conv_bn", "ConvBn2d"),
    ("megengine.module.qat.conv_bn", "ConvBnRelu2d"),
)
def convbn2d_module_loader(expr):
    r"""Compat loader for fused conv-bn module calls dumped without a version.

    Exprs lacking a ``version`` attribute are treated as coming from mge 1.6.
    For those, the ``bn`` sub-module of the owner module gets ``param_dim``
    set to ``"dim_1c11"`` when the attribute is missing.
    """
    # exprs that carry a ``version`` attribute need no patching
    if hasattr(expr, "version"):
        return
    inner_bn = expr.inputs[0].owner.bn
    if hasattr(inner_bn, "param_dim"):
        return
    inner_bn.param_dim = "dim_1c11"
@register_opdef_loader(BatchNorm)
def bn_opdef_loader(expr):
    r"""Compat loader for ``BatchNorm`` opdef exprs dumped without a version.

    Exprs lacking a ``version`` attribute are treated as coming from mge 1.6.
    For those, an extra placeholder output node (shape ``(0,)``, no dtype,
    qparams copied from the last existing output) is inserted at index 4 of
    ``expr.outputs``.
    """
    if hasattr(expr, "version"):
        return
    last_output = expr.outputs[-1]
    placeholder = TensorNode(
        expr, shape=(0,), dtype=None, qparams=last_output._qparams,
    )
    expr.outputs.insert(4, placeholder)
| @@ -11,19 +11,28 @@ import collections | |||
| import copy | |||
| import inspect | |||
| import re | |||
| from typing import Callable, Dict, List, Optional, Union | |||
| import weakref | |||
| from importlib import import_module | |||
| from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union | |||
| from ..core._imperative_rt import OpDef | |||
| from ..core._imperative_rt.core2 import Tensor as RawTensor | |||
| from ..core._imperative_rt.core2 import apply, set_module_tracing, unset_module_tracing | |||
| from ..core._imperative_rt.core2 import ( | |||
| apply, | |||
| is_tracing_module, | |||
| set_module_tracing, | |||
| unset_module_tracing, | |||
| ) | |||
| from ..core.ops.builtin import FakeQuant | |||
| from ..core.ops.special import Const | |||
| from ..module import Module | |||
| from ..tensor import Parameter, Tensor | |||
| from ..version import __version__ | |||
| from .module_tracer import active_module_tracer, module_tracer | |||
| from .node import ModuleNode, Node, NodeMixin, TensorNode | |||
| from .pytree import ArgsIndex, TreeDef, _is_const_leaf, _is_leaf, tree_flatten | |||
| from .serialization import get_opdef_state, load_opdef_from_state | |||
| from .serialization import _ModuleState | |||
| from .utils import _check_builtin_module_attr, _check_obj_attr, _convert_kwargs_to_args | |||
| def rstrip(s: str, __chars: str): | |||
| @@ -112,6 +121,7 @@ class Expr: | |||
| node.users.append(self) | |||
| else: | |||
| assert node is None | |||
| assert not isinstance(val, (Module, RawTensor)) | |||
| assert _is_leaf(val) and _is_const_leaf(val) | |||
| idx = len(self.inputs) + len(self.const_val) | |||
| self.const_val.append((idx, val)) | |||
| @@ -132,14 +142,14 @@ class Expr: | |||
| current_graph._namespace.auto_naming_for_outputs(self) | |||
| def unflatten_args(self, inputs): | |||
| if self.arg_def is not None: | |||
| inputs = list(inputs) | |||
| for idx, val in self.const_val: | |||
| inputs.insert(idx, val) | |||
| args, kwargs = self.arg_def.unflatten(inputs) | |||
| return args, kwargs | |||
| else: | |||
| return inputs, {} | |||
| assert self.arg_def is not None, "{} expr doesn't have args/kwargs".format( | |||
| type(self).__name__ | |||
| ) | |||
| inputs = list(inputs) | |||
| for idx, val in self.const_val: | |||
| inputs.insert(idx, val) | |||
| args, kwargs = self.arg_def.unflatten(inputs) | |||
| return args, kwargs | |||
| def replace_inputs(self, repl_dict: Dict[Node, Node]): | |||
| r"""Replace the input Nodes of this Expr. | |||
| @@ -165,6 +175,39 @@ class Expr: | |||
| node.users.remove(self) | |||
| repl_node.users.append(self) | |||
@property
def _support_set_args_kwargs(self):
    r"""Whether ``set_args_kwargs`` may be used on this Expr.

    The base implementation always answers ``False``.
    """
    return False
def set_args_kwargs(self, *args, **kwargs):
    r"""Set args and kwargs for Expr.

    The arguments are normalized with ``_convert_kwargs_to_args`` against the
    function returned by ``_get_func`` and flattened. Node leaves become the
    new ``inputs``; every other leaf is recorded in ``const_val`` together
    with its position in the flattened list. The ``users`` lists of nodes
    that left or joined the inputs are updated, and ``arg_def`` is replaced.
    """
    assert (
        self._support_set_args_kwargs
    ), "Doesn't support set args/kwargs for {} expr".format(type(self).__name__)
    args, kwargs = _convert_kwargs_to_args(self._get_func(), args, kwargs)
    flat_vals, new_arg_def = tree_flatten((args, kwargs))
    prev_inputs = self.inputs
    self.inputs = []
    self.const_val = []
    # every leaf lands in exactly one of the two lists, so its flattened
    # position equals len(inputs) + len(const_val) at insertion time
    for position, leaf in enumerate(flat_vals):
        if isinstance(leaf, (TensorNode, ModuleNode)):
            self.inputs.append(leaf)
            continue
        assert _is_leaf(leaf) and _is_const_leaf(leaf)
        self.const_val.append((position, leaf))

    # nodes that are no longer inputs stop listing this expr as a user
    for old_node in prev_inputs:
        if old_node not in self.inputs:
            old_node.users.remove(self)
    # nodes that are newly used start listing this expr as a user
    for new_node in self.inputs:
        if new_node not in prev_inputs:
            new_node.users.append(self)
    self.arg_def = new_arg_def
| @property | |||
| def kwargs(self): | |||
| r"""Get the keyword arguments of the operation corresponding to this Expr.""" | |||
| @@ -177,6 +220,61 @@ class Expr: | |||
| args, _ = self.unflatten_args(self.inputs) | |||
| return args | |||
def _get_func(self):
    r"""Return the function called when the expr is interpreted.

    The base implementation has no such function and raises
    ``NotImplementedError``.
    """
    raise NotImplementedError
@property
def named_args(self):
    r"""Map parameter names of the called function to the bound arguments.

    Computed with ``inspect.getcallargs`` from the function returned by
    ``_get_func`` and the current ``args`` / ``kwargs`` of this Expr.
    """
    return inspect.getcallargs(self._get_func(), *self.args, **self.kwargs)
def set_arg(self, name, val):
    r"""Replace the argument called ``name`` of this Expr with ``val``.

    The argument is located, in order, among the current keyword arguments,
    the positional parameters of the called function, its ``*varargs``
    parameter, and finally its keyword-only / ``**kwargs`` parameters. All
    other arguments are forwarded unchanged to ``set_args_kwargs``.

    Args:
        name: parameter name in the signature of the function returned by
            ``_get_func``.
        val: new value. For the ``*varargs`` parameter a non-sequence value
            is wrapped into a one-element tuple.

    Raises:
        AssertionError: if the function has no parameter called ``name`` and
            does not accept arbitrary keyword arguments.
    """
    func = self._get_func()
    # ``args`` / ``kwargs`` are rebuilt on every property access, so take one
    # snapshot of each and modify the copies.
    cur_args = list(self.args)
    cur_kwargs = dict(self.kwargs)
    if name in cur_kwargs:
        cur_kwargs[name] = val
        self.set_args_kwargs(*cur_args, **cur_kwargs)
        return

    arg_spec = inspect.getfullargspec(func)
    if name in arg_spec.args:
        ind = arg_spec.args.index(name)
        if ind < len(cur_args):
            cur_args[ind] = val
        else:
            # the parameter was not passed positionally (its default is in
            # use); pass it by keyword instead of indexing past the end
            cur_kwargs[name] = val
        # keep the existing keyword arguments instead of dropping them
        self.set_args_kwargs(*cur_args, **cur_kwargs)
    elif name == arg_spec.varargs:
        assert arg_spec.varargs is not None
        assert len(cur_args) >= len(arg_spec.args)
        val = (val,) if not isinstance(val, Sequence) else val
        self.set_args_kwargs(
            *cur_args[0 : len(arg_spec.args)], *val, **cur_kwargs
        )
    else:
        # keyword-only parameters are valid targets even without **kwargs
        assert (
            arg_spec.varkw is not None or name in arg_spec.kwonlyargs
        ), "func {} doesn't have argument named {}".format(func, name)
        cur_kwargs[name] = val
        self.set_args_kwargs(*cur_args, **cur_kwargs)
@property
def return_val(self):
    r"""Outputs of this Expr, rebuilt into their container structure via ``out_def``."""
    return self.out_def.unflatten(self.outputs)

@return_val.setter
def return_val(self, new_outputs):
    r"""Replace the outputs of this Expr.

    ``new_outputs`` is flattened with Nodes as leaves; each leaf must be a
    Node whose ``expr`` is either ``None`` or this Expr.
    """
    flat_outputs, new_out_def = tree_flatten(
        new_outputs, is_leaf=lambda x: isinstance(x, Node)
    )
    # first pass: every leaf has to be a Node
    for candidate in flat_outputs:
        assert isinstance(
            candidate, Node
        ), "Return values of expr must be ModuleNode or TensorNode or Container with them"
    # second pass: no leaf may already belong to a different expr
    for candidate in flat_outputs:
        assert candidate.expr in (
            None,
            self,
        ), "Some nodes are produced by other expr, can not be output of expr {}".format(
            self
        )
    self.outputs = flat_outputs
    self.out_def = new_out_def
| @property | |||
| def top_graph(self): | |||
| r"""Get the parent graph of this Expr.""" | |||
| @@ -184,12 +282,6 @@ class Expr: | |||
| return self._top_graph() | |||
| return None | |||
| def __getstate__(self): | |||
| state = self.__dict__.copy() | |||
| if "_top_graph" in state: | |||
| state.pop("_top_graph") | |||
| return state | |||
| @classmethod | |||
| def _get_next_id(cls): | |||
| return cls.__total_id | |||
| @@ -199,6 +291,23 @@ class Expr: | |||
| assert isinstance(id, int) | |||
| cls.__total_id = id | |||
def __copy__(self):
    r"""Shallow copy: a new instance of the same class sharing all attribute values."""
    klass = self.__class__
    clone = klass.__new__(klass)
    vars(clone).update(vars(self))
    return clone
def __deepcopy__(self, memo):
    r"""Deep copy that leaves out attributes holding weak references.

    The new instance is registered in ``memo`` before the attributes are
    copied, so objects that refer back to ``self`` resolve to the copy.
    """
    klass = self.__class__
    clone = klass.__new__(klass)
    memo[id(self)] = clone
    copied = {
        key: copy.deepcopy(value, memo)
        for key, value in self.__dict__.items()
        if not isinstance(value, weakref.ReferenceType)
    }
    clone.__dict__.update(copied)
    return clone
| # expr: None (i.e. fake expression which is used to mark input) | |||
| class Input(Expr): | |||
| @@ -229,6 +338,17 @@ class Input(Expr): | |||
| def __repr__(self): | |||
| return "%{}:\t{} = Input()".format(self._id, self.outputs[0]) | |||
def __getstate__(self):
    r"""Return the state dict used for pickling.

    Only ``_id``, ``_disable_remove``, ``inputs``, ``outputs`` and ``name``
    are kept; the dict is passed through ``_check_obj_attr`` before it is
    returned.
    """
    kept_attrs = ("_id", "_disable_remove", "inputs", "outputs", "name")
    state = {attr: getattr(self, attr) for attr in kept_attrs}
    _check_obj_attr(state)
    return state
| # expr: outputs = getattr(inputs[0], self.name) | |||
| class GetAttr(Expr): | |||
| @@ -276,11 +396,23 @@ class GetAttr(Expr): | |||
| def __repr__(self): | |||
| out_type = "Tensor" | |||
| if isinstance(self.outputs[0], ModuleNode): | |||
| out_type = self.outputs[0].module_type.__name__ | |||
| m_type = self.outputs[0].module_type | |||
| out_type = m_type.__name__ if isinstance(m_type, type) else m_type[1] | |||
| return '%{}:\t{} = getattr({}, "{}") -> ({})'.format( | |||
| self._id, self.outputs[0], self.inputs[0], self.name, out_type | |||
| ) | |||
def __getstate__(self):
    r"""Return the state dict used for pickling.

    Only ``_id``, ``_disable_remove``, ``inputs``, ``outputs`` and ``name``
    are kept; the dict is passed through ``_check_obj_attr`` before it is
    returned.
    """
    kept_attrs = ("_id", "_disable_remove", "inputs", "outputs", "name")
    state = {attr: getattr(self, attr) for attr in kept_attrs}
    _check_obj_attr(state)
    return state
| # expr: outputs = inputs[0].__call__(*inputs[1:]) | |||
| class CallMethod(Expr): | |||
| @@ -307,6 +439,7 @@ class CallMethod(Expr): | |||
| node, | |||
| ] | |||
| self.const_val = [] | |||
| self.arg_def = tree_flatten(((node,), {}))[1] | |||
| self.method = method | |||
| @classmethod | |||
| @@ -342,6 +475,27 @@ class CallMethod(Expr): | |||
| outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor)) | |||
| return outputs | |||
def _get_func(self):
    r"""Return the callable this method-call expr resolves to.

    The receiver is ``args[0]``: a type is used directly, a ``ModuleNode``
    contributes its ``module_type``, and anything else must be a
    ``TensorNode`` and maps to ``Tensor``. For ``Module`` subclasses the
    ``forward`` attribute is returned, otherwise the attribute named by
    ``self.method``.
    """
    receiver = self.args[0]
    if isinstance(receiver, type):
        obj_type = receiver
    elif isinstance(receiver, ModuleNode):
        obj_type = receiver.module_type
    else:
        assert isinstance(receiver, TensorNode)
        obj_type = Tensor
    attr_name = "forward" if issubclass(obj_type, Module) else self.method
    return getattr(obj_type, attr_name)
@property
def _support_set_args_kwargs(self):
    r"""Whether args/kwargs of this method-call expr may be modified.

    True when the receiver (``args[0]``) is a ``TensorNode`` or a type, or
    when its ``module_type`` is something other than the plain ``Module``
    class.
    """
    # only expr call tensor method or builtin module support modify args/kwargs
    receiver = self.args[0]
    if isinstance(receiver, (TensorNode, type)):
        return True
    return receiver.module_type is not Module
| def __repr__(self): | |||
| args = ", ".join(str(i) for i in self.args[1:]) | |||
| kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items()) | |||
| @@ -359,6 +513,21 @@ class CallMethod(Expr): | |||
| ", ".join([args, kwargs]), | |||
| ) | |||
def __getstate__(self):
    r"""Return the state dict used for pickling.

    Keeps the identifying attributes, inputs/outputs, constants, method name
    and the arg/out tree definitions, and stamps the dict with the current
    ``__version__``. The dict is passed through ``_check_obj_attr`` before it
    is returned.
    """
    kept_attrs = (
        "_id",
        "_disable_remove",
        "inputs",
        "const_val",
        "method",
        "arg_def",
        "out_def",
        "outputs",
    )
    state = {attr: getattr(self, attr) for attr in kept_attrs}
    state["version"] = __version__
    _check_obj_attr(state)
    return state
| # expr: outputs = apply(self.opdef, *inputs) | |||
| class Apply(Expr): | |||
| @@ -394,14 +563,32 @@ class Apply(Expr): | |||
| ) | |||
| def __getstate__(self): | |||
| state = super().__getstate__() | |||
| state["opdef"] = get_opdef_state(state["opdef"]) | |||
| opdef_state = self.opdef.__getstate__() | |||
| opdef_state["opdef_type"] = type(self.opdef) | |||
| state = { | |||
| "_id": self._id, | |||
| "_disable_remove": self._disable_remove, | |||
| "opdef_state": opdef_state, | |||
| "inputs": self.inputs, | |||
| "outputs": self.outputs, | |||
| "version": __version__, | |||
| } | |||
| _check_obj_attr(state) | |||
| return state | |||
| def __setstate__(self, state): | |||
| state["opdef"] = load_opdef_from_state(state["opdef"]) | |||
| for k, v in state.items(): | |||
| setattr(self, k, v) | |||
| # compat with mge 1.6 | |||
| if "opdef" in state and "opdef_state" not in state: | |||
| opdef_state = state.pop("opdef") | |||
| opdef_state["opdef_type"] = opdef_state.pop("type") | |||
| state["opdef_state"] = opdef_state | |||
| self.__dict__.update(state) | |||
| assert isinstance(state["opdef_state"], dict) | |||
| opdef_state = state["opdef_state"].copy() | |||
| opdef_type = opdef_state.pop("opdef_type") | |||
| opdef_obj = opdef_type() | |||
| opdef_obj.__setstate__(opdef_state) | |||
| setattr(self, "opdef", opdef_obj) | |||
| @classmethod | |||
| def apply_module_trace_hook(cls, opdef, *inputs): | |||
| @@ -458,12 +645,24 @@ class CallFunction(Expr): | |||
| def interpret(self, *inputs): | |||
| args, kwargs = self.unflatten_args(inputs) | |||
| outputs = self.func(*args, **kwargs) | |||
| func = ( | |||
| self.func | |||
| if not is_tracing_module() | |||
| else active_module_tracer().patcher.wrap_fn(self.func) | |||
| ) | |||
| outputs = func(*args, **kwargs) | |||
| if outputs is None: | |||
| return outputs | |||
| outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor)) | |||
| return outputs | |||
def _get_func(self):
    r"""Return the function this expr calls, i.e. ``self.func``."""
    return self.func
@property
def _support_set_args_kwargs(self):
    r"""Function-call exprs always allow their args/kwargs to be replaced."""
    return True
| def __repr__(self): | |||
| args = ", ".join(str(i) for i in self.args) | |||
| kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items()) | |||
| @@ -477,6 +676,33 @@ class CallFunction(Expr): | |||
| ", ".join([args, kwargs]), | |||
| ) | |||
def __getstate__(self):
    r"""Return the state dict used for pickling.

    The called function is stored as a ``(module name, qualified name)``
    pair rather than as the function object. The dict is stamped with the
    current ``__version__`` and passed through ``_check_obj_attr`` before it
    is returned.
    """
    state = {
        "_id": self._id,
        "_disable_remove": self._disable_remove,
        "func": (self.func.__module__, self.func.__qualname__),
    }
    for attr in ("const_val", "inputs", "arg_def", "out_def", "outputs"):
        state[attr] = getattr(self, attr)
    state["version"] = __version__
    _check_obj_attr(state)
    return state
def __setstate__(self, state):
    r"""Restore the pickled state and try to resolve ``func``.

    When ``func`` is a ``(module name, qualified name)`` pair, the module is
    imported and the dotted name is followed attribute by attribute. Any
    failure is swallowed on purpose, leaving ``func`` as the pair.
    """
    self.__dict__.update(state)
    try:
        if not isinstance(self.func, tuple):
            return
        mname, fname = self.func
        resolved = import_module(mname)
        for part in fname.split("."):
            resolved = getattr(resolved, part)
        self.func = resolved
    except Exception:
        pass
| # expr outputs = self.value | |||
| class Constant(Expr): | |||
| @@ -496,6 +722,13 @@ class Constant(Expr): | |||
| assert isinstance(c, (RawTensor, Module)) | |||
| if isinstance(c, Module): | |||
| assert module_tracer.is_builtin(c) or c.is_qat | |||
| if isinstance(c, RawTensor): | |||
| if is_tracing_module(): | |||
| unset_module_tracing() | |||
| c = Tensor(c) | |||
| set_module_tracing() | |||
| else: | |||
| c = Tensor(c) | |||
| self.value = c | |||
| self.name = name | |||
| self.inputs = [] | |||
| @@ -530,9 +763,25 @@ class Constant(Expr): | |||
| ) | |||
| def __getstate__(self): | |||
| state = self.__dict__.copy() | |||
| if "_top_graph" in state: | |||
| state.pop("_top_graph") | |||
| state = { | |||
| "_id": self._id, | |||
| "_disable_remove": self._disable_remove, | |||
| "value": self.value, | |||
| "name": self.name, | |||
| "inputs": self.inputs, | |||
| "outputs": self.outputs, | |||
| } | |||
| _check_obj_attr(state) | |||
| if isinstance(self.value, RawTensor): | |||
| state["value"] = Tensor(self.value) | |||
| if isinstance(self.value, Module) and module_tracer.is_builtin(self.value): | |||
| _check_builtin_module_attr(self.value) | |||
| state["value"] = _ModuleState.get_module_state(self.value) | |||
| return state | |||
def __setstate__(self, state):
    r"""Restore the pickled state.

    Every value of ``state`` that is a ``_ModuleState`` is replaced, in the
    passed dict itself, by the result of its ``to_module()`` before the dict
    is merged into ``__dict__``.
    """
    for key in list(state):
        entry = state[key]
        if isinstance(entry, _ModuleState):
            state[key] = entry.to_module()
    self.__dict__.update(state)
| @@ -72,7 +72,6 @@ BUILTIN_ARRAY_METHOD = [ | |||
| "astype", | |||
| "reshape", | |||
| "_broadcast", | |||
| "transpose", | |||
| "flatten", | |||
| "sum", | |||
| "prod", | |||
| @@ -6,7 +6,9 @@ | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import abc | |||
| import copy | |||
| import weakref | |||
| from importlib import import_module | |||
| from typing import Any, Dict, List, Tuple, Type | |||
| import numpy | |||
| @@ -14,7 +16,9 @@ import numpy | |||
| from .. import get_logger | |||
| from ..core._imperative_rt.core2 import Tensor as RawTensor | |||
| from ..module import Module | |||
| from ..quantization.utils import QParams | |||
| from ..tensor import Tensor | |||
| from .utils import _check_obj_attr | |||
| logger = get_logger(__name__) | |||
| @@ -145,6 +149,23 @@ class Node: | |||
| assert isinstance(id, int) | |||
| cls.__total_id = id | |||
def __copy__(self):
    r"""Shallow copy: a new instance of the same class sharing all attribute values."""
    klass = self.__class__
    clone = klass.__new__(klass)
    vars(clone).update(vars(self))
    return clone
def __deepcopy__(self, memo):
    r"""Deep copy that leaves out weak references and ``actual_node``.

    The new instance is registered in ``memo`` before the attributes are
    copied, so objects that refer back to ``self`` resolve to the copy.
    """
    klass = self.__class__
    clone = klass.__new__(klass)
    memo[id(self)] = clone
    copied = {}
    for key, value in self.__dict__.items():
        if isinstance(value, weakref.ReferenceType) or key == "actual_node":
            continue
        copied[key] = copy.deepcopy(value, memo)
    clone.__dict__.update(copied)
    return clone
| class ModuleNode(Node): | |||
| r"""``ModuleNode`` represents the Module objects.""" | |||
| @@ -157,19 +178,28 @@ class ModuleNode(Node): | |||
| super().__init__(expr, name, qualname) | |||
| def __getstate__(self): | |||
| return { | |||
| state = { | |||
| "expr": self.expr, | |||
| "users": self.users, | |||
| "_id": self._id, | |||
| "_name": self._name, | |||
| "_qualname": self._qualname, | |||
| "module_type": self.module_type, | |||
| "module_type": (self.module_type.__module__, self.module_type.__qualname__), | |||
| } | |||
| _check_obj_attr(state) | |||
| return state | |||
def __setstate__(self, state):
    r"""Restore the pickled state and try to resolve ``module_type``.

    A legacy ``_orig_name`` entry is renamed to ``_qualname``. When
    ``module_type`` is a ``(module name, qualified name)`` pair, the module
    is imported and the qualified name is followed attribute by attribute, so
    dotted names of nested classes resolve as well. Any failure is swallowed
    on purpose, leaving ``module_type`` as the pair.
    """
    if "_orig_name" in state:
        state["_qualname"] = state.pop("_orig_name")
    self.__dict__.update(state)
    try:
        if isinstance(self.module_type, tuple):
            mname, classname = self.module_type
            # the stored name is a ``__qualname__`` and may contain dots
            mtype = import_module(mname)
            for part in classname.split("."):
                mtype = getattr(mtype, part)
            self.module_type = mtype
    except Exception:
        pass
| @property | |||
| def owner(self): | |||
| @@ -185,12 +215,26 @@ class TensorNode(Node): | |||
| _shape = None # type: Tuple[int] | |||
| _dtype = None # type: numpy.dtype | |||
| _qparams = None | |||
| _qparams = None # type: QParams | |||
| _device = None | |||
| _value = None # type: Tensor | |||
def __init__(
    self,
    expr: "Expr",
    name: str = None,
    qualname: str = None,
    shape: Tuple[int] = None,
    dtype: numpy.dtype = None,
    qparams: QParams = None,
):
    r"""Create a tensor node.

    Args:
        expr: forwarded to the base ``Node`` initializer.
        name: forwarded to the base ``Node`` initializer.
        qualname: forwarded to the base ``Node`` initializer.
        shape: stored as ``_shape``.
        dtype: stored as ``_dtype``.
        qparams: stored as ``_qparams``.
    """
    super().__init__(expr, name, qualname)
    self._shape = shape
    # was ``self._dtype = shape``: the dtype argument was silently discarded
    self._dtype = dtype
    self._qparams = qparams
| def __getstate__(self): | |||
| return { | |||
| state = { | |||
| "expr": self.expr, | |||
| "users": self.users, | |||
| "_id": self._id, | |||
| @@ -201,6 +245,8 @@ class TensorNode(Node): | |||
| "_name": self._name, | |||
| "_qualname": self._qualname, | |||
| } | |||
| _check_obj_attr(state) | |||
| return state | |||
| def __setstate__(self, state): | |||
| if "_orig_name" in state: | |||
| @@ -276,7 +322,10 @@ class NodeMixin(abc.ABC): | |||
| assert isinstance(node, TensorNode) | |||
| assert isinstance(value, RawTensor) | |||
| if isinstance(value, RawTensor): | |||
| node._dtype = value.dtype | |||
| try: | |||
| node._dtype = value.dtype | |||
| except RuntimeError: | |||
| node._dtype = None | |||
| node._shape = ( | |||
| value._tuple_shape if isinstance(value, Tensor) else value.shape | |||
| ) | |||
| @@ -7,15 +7,18 @@ | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import collections | |||
| from collections import OrderedDict | |||
| from collections import OrderedDict, defaultdict | |||
| from functools import partial | |||
| from typing import Callable, NamedTuple | |||
| import numpy as np | |||
| from ..core._imperative_rt import OpDef | |||
| from ..core._imperative_rt.common import CompNode | |||
| from ..core._imperative_rt.core2 import Tensor as RawTensor | |||
| from ..core._wrap import Device | |||
| from ..core.tensor.dtype import QuantDtypeMeta | |||
| from ..distributed import Group | |||
| from ..module import Module | |||
| from ..quantization.utils import LSQParams, QParams, QuantMode | |||
| from ..tensor import Parameter, Tensor | |||
| @@ -49,45 +52,54 @@ SUPPORTED_LEAF_TYPE = { | |||
| type(Ellipsis), | |||
| QuantMode, | |||
| ArgsIndex, | |||
| Group, | |||
| } | |||
| USER_REGISTERED_LEAF_TYPE = [] | |||
| USER_REGISTERED_CONTAINER_TYPE = [] | |||
| # if isinstance(obj, SUPPORTED_LEAF_CLS) or issubclass(obj, SUPPORTED_LEAF_CLS) is True, the object can be treated as a leaf node of the pytree | |||
| SUPPORTED_LEAF_CLS = [Module, Node, NodeMixin, np.dtype, np.ndarray, np.number] | |||
| SUPPORTED_LEAF_CLS = [ | |||
| Module, | |||
| Node, | |||
| NodeMixin, | |||
| np.dtype, | |||
| np.ndarray, | |||
| np.number, | |||
| np.bool_, | |||
| OpDef, | |||
| ] | |||
| NodeType = NamedTuple("NodeType", [("flatten", Callable), ("unflatten", Callable)]) | |||
| def register_supported_type(type, flatten=None, unflatten=None): | |||
| tp_info = (type.__module__, type.__qualname__) | |||
| if flatten and unflatten: | |||
| SUPPORTED_TYPE[type] = NodeType(flatten, unflatten) | |||
| USER_REGISTERED_CONTAINER_TYPE.append(tp_info) | |||
| else: | |||
| SUPPORTED_LEAF_CLS.append(type) | |||
| def _dict_flatten(inp): | |||
| aux_data = [] | |||
| results = [] | |||
| for key, value in sorted(inp.items()): | |||
| results.append(value) | |||
| aux_data.append(key) | |||
| return results, tuple(aux_data) | |||
| USER_REGISTERED_LEAF_TYPE.append(tp_info) | |||
| _register_supported_type(type, flatten, unflatten) | |||
| def _dict_unflatten(inps, aux_data): | |||
| return dict(zip(aux_data, inps)) | |||
| def _register_supported_type(type, flatten=None, unflatten=None): | |||
| if flatten and unflatten: | |||
| SUPPORTED_TYPE[type] = NodeType(flatten, unflatten) | |||
| else: | |||
| SUPPORTED_LEAF_CLS.append(type) | |||
| def _ordereddict_flatten(inp): | |||
| def _dict_flatten(ordered, inp): | |||
| aux_data = [] | |||
| results = [] | |||
| for key, value in inp.items(): | |||
| dict_items = inp.items() if ordered else sorted(inp.items()) | |||
| for key, value in dict_items: | |||
| results.append(value) | |||
| aux_data.append(key) | |||
| return results, tuple(aux_data) | |||
| def _ordereddict_unflatten(inps, aux_data): | |||
| return OrderedDict(zip(aux_data, inps)) | |||
| def _dict_unflatten(dict_type, inps, aux_data): | |||
| return dict_type(zip(aux_data, inps)) | |||
| def qparams_flatten(inp): | |||
| @@ -99,33 +111,41 @@ def qparams_flatten(inp): | |||
| return results, tuple(aux_data) | |||
| def qparams_unflatten(inp, aux_data): | |||
| obj = QParams.__new__(QParams) | |||
| def qparams_unflatten(qparam_type, inp, aux_data): | |||
| obj = qparam_type.__new__(qparam_type) | |||
| for k, v in zip(aux_data, inp): | |||
| setattr(obj, k, v) | |||
| return obj | |||
| register_supported_type(list, lambda x: (x, None), lambda x, aux_data: list(x)) | |||
| register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: tuple(x)) | |||
| register_supported_type(dict, _dict_flatten, _dict_unflatten) | |||
| register_supported_type( | |||
| collections.OrderedDict, _ordereddict_flatten, _ordereddict_unflatten | |||
| _register_supported_type(list, lambda x: (x, None), lambda x, aux_data: list(x)) | |||
| _register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: tuple(x)) | |||
| _register_supported_type( | |||
| dict, partial(_dict_flatten, False), partial(_dict_unflatten, dict) | |||
| ) | |||
| _register_supported_type( | |||
| defaultdict, partial(_dict_flatten, False), partial(_dict_unflatten, defaultdict) | |||
| ) | |||
| register_supported_type( | |||
| _register_supported_type( | |||
| OrderedDict, partial(_dict_flatten, True), partial(_dict_unflatten, OrderedDict) | |||
| ) | |||
| _register_supported_type( | |||
| slice, | |||
| lambda x: ([x.start, x.stop, x.step], None), | |||
| lambda x, aux_data: slice(x[0], x[1], x[2]), | |||
| ) | |||
| register_supported_type(QParams, qparams_flatten, qparams_unflatten) | |||
| _register_supported_type(QParams, qparams_flatten, partial(qparams_unflatten, QParams)) | |||
| _register_supported_type( | |||
| LSQParams, qparams_flatten, partial(qparams_unflatten, LSQParams) | |||
| ) | |||
| def _is_leaf(obj): | |||
| if isinstance(obj, type): | |||
| return issubclass(obj, tuple(SUPPORTED_LEAF_CLS)) or obj in SUPPORTED_LEAF_TYPE | |||
| obj_type = obj if isinstance(obj, type) else type(obj) | |||
| return ( | |||
| isinstance(obj, tuple(SUPPORTED_LEAF_CLS)) or type(obj) in SUPPORTED_LEAF_TYPE | |||
| issubclass(obj_type, tuple(SUPPORTED_LEAF_CLS)) | |||
| or obj_type in SUPPORTED_LEAF_TYPE | |||
| ) | |||
| @@ -5,30 +5,158 @@ | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from typing import Dict | |||
| from importlib import import_module | |||
| from typing import Dict, Tuple | |||
| from ..core._imperative_rt import OpDef | |||
| from ..core.ops import builtin | |||
| from ..tensor import Tensor | |||
| from ..version import __version__ | |||
| from .utils import _convert_kwargs_to_args | |||
| OPDEF_PARAM_LOADER = {} | |||
| OPDEF_LOADER = {} | |||
| FUNCTIONAL_LOADER = {} | |||
| TENSORMETHOD_LOADER = {} | |||
| MODULE_LOADER = {} | |||
| def get_opdef_state(obj: OpDef) -> Dict: | |||
| state = obj.__getstate__() | |||
| state["type"] = type(obj) | |||
| state["version"] = __version__ | |||
| return state | |||
| class _ModuleState: | |||
| obj = None | |||
| def __init__(self, module: Tuple, state: Dict, version: str): | |||
| self.module = module | |||
| self.state = state | |||
| self.version = version | |||
| def load_opdef_from_state(state: Dict) -> OpDef: | |||
| assert "type" in state and issubclass(state["type"], OpDef) | |||
| assert "version" in state | |||
| opdef_type = state.pop("type") | |||
| if opdef_type in OPDEF_PARAM_LOADER: | |||
| loader = OPDEF_PARAM_LOADER[opdef_type] | |||
| state = loader(state) | |||
| state.pop("version") | |||
| opdef_obj = opdef_type() | |||
| opdef_obj.__setstate__(state) | |||
| return opdef_obj | |||
@classmethod
def get_module_state(cls, module):
    r"""Return the ``_ModuleState`` attached to ``module``, creating or refreshing it.

    The state is a copy of ``module.__dict__`` without the
    ``_m_dump_modulestate`` entry. An existing state object is re-initialized
    in place; otherwise a new one is stored in ``module.__dict__`` under
    ``_m_dump_modulestate``.
    """
    module_cls = type(module)
    typem = (module_cls.__module__, module_cls.__qualname__)
    state = module.__dict__.copy()
    state.pop("_m_dump_modulestate", None)
    if not hasattr(module, "_m_dump_modulestate"):
        module.__dict__["_m_dump_modulestate"] = _ModuleState(
            typem, state, __version__
        )
    else:
        assert isinstance(module._m_dump_modulestate, cls)
        # reuse the existing object so references to it stay valid
        module._m_dump_modulestate.__init__(typem, state, __version__)
    return module._m_dump_modulestate
def __getstate__(self):
    r"""Return the state dict used for pickling: ``module``, ``state`` and ``version``."""
    return {attr: getattr(self, attr) for attr in ("module", "state", "version")}
def to_module(self):
    r"""Rebuild the module described by this state and cache it in ``obj``.

    The class is looked up from the stored ``(module name, qualified name)``
    pair; the qualified name is followed attribute by attribute so dotted
    names of nested classes resolve. The instance is created with
    ``__new__`` (``__init__`` is not run) and its ``__dict__`` is filled from
    ``state``. Later calls return the cached object.
    """
    if self.obj is None:
        # the stored name is a ``__qualname__`` and may contain dots
        typem = import_module(self.module[0])
        for part in self.module[1].split("."):
            typem = getattr(typem, part)
        m_obj = typem.__new__(typem)
        m_obj.__dict__.update(self.state)
        self.obj = m_obj
    return self.obj
def register_opdef_loader(*opdefs):
    r"""Decorator factory that registers a loader in ``OPDEF_LOADER``.

    The decorated function is stored under every key in ``opdefs`` and is
    returned unchanged. Registering a key twice trips an assertion.
    """

    def decorator(loader):
        for key in opdefs:
            assert key not in OPDEF_LOADER
            OPDEF_LOADER[key] = loader
        return loader

    return decorator
def register_functional_loader(*funcs):
    r"""Decorator factory that registers a loader in ``FUNCTIONAL_LOADER``.

    The decorated function is stored under every key in ``funcs`` and is
    returned unchanged. Registering a key twice trips an assertion.
    """

    def decorator(loader):
        for key in funcs:
            assert key not in FUNCTIONAL_LOADER
            FUNCTIONAL_LOADER[key] = loader
        return loader

    return decorator
def register_module_loader(*module_types):
    r"""Decorator factory that registers a loader in ``MODULE_LOADER``.

    The decorated function is stored under every key in ``module_types`` and
    is returned unchanged. Registering a key twice trips an assertion.
    """

    def decorator(loader):
        for key in module_types:
            assert key not in MODULE_LOADER
            MODULE_LOADER[key] = loader
        return loader

    return decorator
def register_tensor_method_loader(*methods):
    """Decorator factory: register the decorated loader for each given method name.

    Registering a key that is already present in ``TENSORMETHOD_LOADER`` trips an
    assertion. The decorated function is returned unchanged.
    """

    def _register(loader):
        for key in methods:
            assert key not in TENSORMETHOD_LOADER
            TENSORMETHOD_LOADER[key] = loader
        return loader

    return _register
| def _replace_args_kwargs(expr, new_args, new_kwargs): | |||
| if len(new_args) != len(expr.args) or set(new_kwargs.keys()) != set( | |||
| expr.kwargs.keys() | |||
| ): | |||
| expr.set_args_kwargs(*new_args, **new_kwargs) | |||
def load_functional(expr):
    """Run the registered loader for a function-call expr and re-resolve its callable.

    ``expr.func`` may be a callable or a ``(module name, qualified name)`` tuple.
    The tuple form is used as the key into ``FUNCTIONAL_LOADER`` and to import
    the real callable, which is then stored back on ``expr.func``. If the expr
    has no ``version`` or was dumped by a different version, its arguments are
    normalised against the resolved callable's signature.
    """
    if callable(expr.func):
        func = (expr.func.__module__, expr.func.__qualname__)
    else:
        func = expr.func
    assert isinstance(func, tuple)
    if func in FUNCTIONAL_LOADER:
        FUNCTIONAL_LOADER[func](expr)
    module_name, dotted_name = func
    # Walk the dotted qualified name starting from the imported module.
    target = import_module(module_name)
    for attr in dotted_name.split("."):
        target = getattr(target, attr)
    expr.func = target
    assert callable(expr.func)
    if hasattr(expr, "version") and expr.version == __version__:
        return
    args, kwargs = _convert_kwargs_to_args(expr.func, expr.args, expr.kwargs)
    _replace_args_kwargs(expr, args, kwargs)
def load_call_module_expr(expr):
    """Run the registered loader for a module-call expr and resolve its module type.

    The module type of the first input is used as the ``MODULE_LOADER`` key, in
    ``(module name, qualified name)`` form. After the loader runs, a module type
    still stored as a tuple is imported and replaced by the real class. If the
    expr has no ``version`` or was dumped by a different version, its arguments
    are normalised against that class's ``forward``.
    """
    key = expr.inputs[0].module_type
    if isinstance(key, type):
        key = (key.__module__, key.__qualname__)
    if key in MODULE_LOADER:
        MODULE_LOADER[key](expr)
    # Re-read the first input after the loader has run.
    owner = expr.inputs[0]
    if isinstance(owner.module_type, tuple):
        mname, classname = owner.module_type
        owner.module_type = getattr(import_module(mname), classname)
    if hasattr(expr, "version") and expr.version == __version__:
        return
    fwd_func = getattr(expr.inputs[0].module_type, "forward")
    args, kwargs = _convert_kwargs_to_args(fwd_func, expr.args, expr.kwargs)
    _replace_args_kwargs(expr, args, kwargs)
def load_call_tensor_method_expr(expr):
    """Run the registered loader for a tensor-method expr and normalise its arguments.

    The loader is looked up in ``TENSORMETHOD_LOADER`` by ``expr.method``. If the
    expr has no ``version`` or was dumped by a different version, the method is
    resolved on ``expr.args[0]`` when that is a class, otherwise on ``Tensor``,
    and the arguments are normalised against its signature.
    """
    if expr.method in TENSORMETHOD_LOADER:
        TENSORMETHOD_LOADER[expr.method](expr)
    if hasattr(expr, "version") and expr.version == __version__:
        return
    if isinstance(expr.args[0], type):
        tmethod = getattr(expr.args[0], expr.method)
    else:
        tmethod = getattr(Tensor, expr.method)
    args, kwargs = _convert_kwargs_to_args(tmethod, expr.args, expr.kwargs)
    _replace_args_kwargs(expr, args, kwargs)
def load_apply_expr(expr):
    """Run the registered loader for an apply expr and rebuild its opdef.

    The loader is looked up in ``OPDEF_LOADER`` by the type of ``expr.opdef``.
    Afterwards ``"opdef_type"`` is popped from ``expr.opdef_state``, called to
    create a new opdef, and the remaining state is applied to it with
    ``__setstate__``. The result replaces ``expr.opdef``.
    """
    key = type(expr.opdef)
    if key in OPDEF_LOADER:
        OPDEF_LOADER[key](expr)
    # Read the state after the loader has run.
    remaining_state = expr.opdef_state
    factory = remaining_state.pop("opdef_type")
    rebuilt = factory()
    rebuilt.__setstate__(remaining_state)
    expr.opdef = rebuilt
| @@ -14,6 +14,7 @@ import inspect | |||
| import keyword | |||
| import re | |||
| import weakref | |||
| from importlib import import_module | |||
| from inspect import getcallargs, getmembers, isclass, ismethod | |||
| from itertools import chain | |||
| from types import FunctionType | |||
| @@ -53,6 +54,7 @@ from ..quantization.observer import ( | |||
| SyncMinMaxObserver, | |||
| ) | |||
| from ..tensor import Tensor | |||
| from ..version import __version__ | |||
| from .expr import ( | |||
| Apply, | |||
| CallFunction, | |||
| @@ -80,8 +82,27 @@ from .module_tracer import ( | |||
| set_active_module_tracer, | |||
| ) | |||
| from .node import ModuleNode, Node, NodeMixin, TensorNode | |||
| from .pytree import ArgsIndex, tree_flatten | |||
| from .utils import replace_container_with_module_container | |||
| from .pytree import ( | |||
| USER_REGISTERED_CONTAINER_TYPE, | |||
| USER_REGISTERED_LEAF_TYPE, | |||
| ArgsIndex, | |||
| TreeDef, | |||
| _register_supported_type, | |||
| tree_flatten, | |||
| ) | |||
| from .serialization import ( | |||
| _ModuleState, | |||
| load_apply_expr, | |||
| load_call_module_expr, | |||
| load_call_tensor_method_expr, | |||
| load_functional, | |||
| ) | |||
| from .utils import ( | |||
| _check_builtin_module_attr, | |||
| _check_obj_attr, | |||
| _convert_kwargs_to_args, | |||
| replace_container_with_module_container, | |||
| ) | |||
| logger = get_logger(__name__) | |||
| @@ -341,7 +362,7 @@ class NameSpace: | |||
| def create_unique_name(self, name: str, node: Any = None) -> str: | |||
| assert isinstance(name, str), "The name must be a string" | |||
| if name in self._used_names and self._used_names[name] is node: | |||
| if name in self._used_names and (self._used_names[name] is node): | |||
| return name | |||
| name = re.sub("[^0-9a-zA-Z_]+", "_", name) | |||
| @@ -1067,6 +1088,7 @@ class InternalGraph: | |||
| if node2value[n][1] == 0: | |||
| node2value.pop(n) | |||
| if values is not None: | |||
| assert len(values) == len(expr.outputs) | |||
| for n, v in zip(expr.outputs, values): | |||
| if ref_count(n) > 0: | |||
| node2value[n] = [v, ref_count(n)] | |||
| @@ -1105,13 +1127,27 @@ class InternalGraph: | |||
| return res | |||
| def __getstate__(self): | |||
| state = self.__dict__.copy() | |||
| if "_top_graph" in state: | |||
| state.pop("_top_graph") | |||
| state = { | |||
| "_exprs": self._exprs, | |||
| "_inputs": self._inputs, | |||
| "_outputs": self._outputs, | |||
| "_watch_point": [], | |||
| "_end_point": [], | |||
| "_namespace": self._namespace, | |||
| "_rst": collections.defaultdict(list), | |||
| "_name": self._name, | |||
| "_qualname": self._qualname, | |||
| } | |||
| if self._total_ids: | |||
| state["_total_ids"] = self._total_ids | |||
| _check_obj_attr(state) | |||
| return state | |||
| def __setstate__(self, state): | |||
| old_version = False | |||
| if "_module_name" in state: | |||
| old_version = True | |||
| state["_qualname"] = state.pop("_module_name") | |||
| @@ -1144,6 +1180,25 @@ class InternalGraph: | |||
| self._namespace = NameSpace(self._name, self._qualname) | |||
| self._re_associate_name() | |||
def __copy__(self):
    """Return a shallow copy that shares every attribute value with ``self``."""
    clone = self.__class__.__new__(self.__class__)
    for key, value in self.__dict__.items():
        clone.__dict__[key] = value
    return clone
def __deepcopy__(self, memo):
    """Return a deep copy of this object, skipping weak-reference attributes.

    The new object is put into ``memo`` before its attributes are copied, and an
    existing ``memo`` entry for ``self`` is returned as-is. Attributes holding a
    ``weakref.ReferenceType`` are not copied.
    """
    key = id(self)
    if key in memo:
        return memo[key]
    clone = self.__class__.__new__(self.__class__)
    memo[key] = clone
    copied = {
        name: copy.deepcopy(value, memo)
        for name, value in self.__dict__.items()
        if not isinstance(value, weakref.ReferenceType)
    }
    clone.__dict__.update(copied)
    return clone
| def _get_meth_name(obj, func): | |||
| tp = obj if isinstance(obj, type) else type(obj) | |||
| @@ -1157,9 +1212,7 @@ def _get_meth_name(obj, func): | |||
| def _wrapped_function(orig_func): | |||
| @functools.wraps(orig_func) | |||
| def wrapped_fn(*args, **kwargs): | |||
| method_func = wrapped_fn | |||
| if "method_func" in kwargs: | |||
| method_func = kwargs.pop("method_func") | |||
| method_func = kwargs.pop("method_func", wrapped_fn) | |||
| if is_tracing_module(): | |||
| unset_module_tracing() | |||
| inputs, tree_def = tree_flatten((args, kwargs)) | |||
| @@ -1167,11 +1220,11 @@ def _wrapped_function(orig_func): | |||
| if not NodeMixin.get(i, None): | |||
| if isinstance(i, (RawTensor, NodeMixin)): | |||
| NodeMixin.wrap_safe(i, Constant.make(i)) | |||
| meth_name, arg_type = None, None | |||
| if args: | |||
| meth_name = _get_meth_name(args[0], method_func) | |||
| arg_type = args[0] if isinstance(args[0], type) else type(args[0]) | |||
| args, kwargs = _convert_kwargs_to_args(orig_func, args, kwargs) | |||
| meth_name = _get_meth_name(args[0], method_func) | |||
| arg_type = args[0] if isinstance(args[0], type) else type(args[0]) | |||
| if meth_name and arg_type and issubclass(arg_type, RawTensor): | |||
| inputs, tree_def = tree_flatten((args, kwargs)) | |||
| self = inputs[0] | |||
| if meth_name == "__new__": | |||
| if all([not isinstance(i, RawTensor) for i in inputs]): | |||
| @@ -1190,6 +1243,7 @@ def _wrapped_function(orig_func): | |||
| call_node = CallMethod.make(NodeMixin.get(self), meth_name) | |||
| call_node.add_inputs(inputs[1:]) | |||
| else: | |||
| inputs, tree_def = tree_flatten((args, kwargs)) | |||
| call_node = CallFunction.make(orig_func) | |||
| call_node.add_inputs(inputs) | |||
| @@ -1228,9 +1282,11 @@ class TracedModuleBuilder(NodeMixin): | |||
| "_record_wrapped_nodes", | |||
| "_argdef_graph_map", | |||
| "_argdef_outdef_map", | |||
| "_check_qat_module", | |||
| "nodes", | |||
| "__class__", | |||
| "__dict__", | |||
| "_is_top", | |||
| ] | |||
| def __init__(self, mod, is_top_module=False): | |||
| @@ -1301,22 +1357,18 @@ class TracedModuleBuilder(NodeMixin): | |||
| qat_module.weight_fake_quant.set_qparams(qparams) | |||
| def build(self): | |||
| if self._is_builtin or isinstance(self._mod, TracedModule): | |||
| if module_tracer.is_builtin(self._mod) or isinstance( | |||
| self._mod, TracedModule | |||
| ): | |||
| mod_type = type(self._mod) | |||
| else: | |||
| assert isinstance(self._mod, (Observer, _FakeQuantize)) | |||
| mod_type = ( | |||
| Observer if isinstance(self._mod, Observer) else _FakeQuantize | |||
| ) | |||
| if self._is_builtin: | |||
| assert module_tracer.is_builtin(self._mod) | |||
| mod_type = type(self._mod) | |||
| for node in self.nodes: | |||
| node.module_type = mod_type | |||
| return self._mod | |||
| else: | |||
| is_qat = isinstance(self._mod, QATModule) | |||
| is_qat = isinstance(self._mod, QATModule) or ( | |||
| isinstance(self._mod, TracedModule) and self._mod.is_qat | |||
| ) | |||
| traced_module = TracedModule( | |||
| self._is_top, self._argdef_graph_map, self._argdef_outdef_map, is_qat | |||
| ) | |||
| @@ -1338,15 +1390,18 @@ class TracedModuleBuilder(NodeMixin): | |||
| traced_module.with_act = self._mod.with_act | |||
| traced_module.with_weight = self._mod.with_weight | |||
| if not hasattr(traced_module, "act_fake_quant"): | |||
| traced_module.act_fakequant = None | |||
| traced_module.act_fake_quant = None | |||
| if not hasattr(traced_module, "act_observer"): | |||
| traced_module.act_observer = None | |||
| if not hasattr(traced_module, "weight_fake_quant"): | |||
| traced_module.weight_fakequant = None | |||
| traced_module.weight_fake_quant = None | |||
| if not hasattr(traced_module, "weight_observer"): | |||
| traced_module.weight_observer = None | |||
| set_module_tracing() | |||
| if self._is_top: | |||
| traced_module._update_ref() | |||
| return traced_module | |||
| def _record_wrapped_nodes(self, node): | |||
| @@ -1357,6 +1412,7 @@ class TracedModuleBuilder(NodeMixin): | |||
| # prepare args and kwargs for inner graph | |||
| if "method_func" in kwargs: | |||
| kwargs.pop("method_func") | |||
| args, kwargs = _convert_kwargs_to_args(self._mod.forward, args, kwargs, True) | |||
| def mark_constant(x): | |||
| node = NodeMixin.get(x, None) | |||
| @@ -1372,11 +1428,7 @@ class TracedModuleBuilder(NodeMixin): | |||
| callnode.arg_def = tree_def | |||
| if ( | |||
| self._is_builtin | |||
| or tree_def in self._argdef_graph_map | |||
| or isinstance(self._mod, TracedModule) | |||
| ): | |||
| if self._is_builtin or tree_def in self._argdef_graph_map: | |||
| unset_module_tracing() | |||
| rst = self._mod(*args, **kwargs) | |||
| outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf) | |||
| @@ -1385,33 +1437,7 @@ class TracedModuleBuilder(NodeMixin): | |||
| self._body = None | |||
| elif tree_def in self._argdef_graph_map: | |||
| self._body = self._argdef_graph_map[tree_def] | |||
| else: | |||
| self._mod._is_top = False | |||
| self._body = self._mod.argdef_graph_map[tree_def] | |||
| module_qualname = NodeMixin.get(self).qualname | |||
| if module_qualname != self._body.qualname: | |||
| src_name, dst_name = self._body.qualname, module_qualname | |||
| def replace_qualname(g): | |||
| attr_name = get_suffix_name(src_name, g.qualname) | |||
| if attr_name is not None: | |||
| g._qualname = ( | |||
| ("%s.%s" % (dst_name, attr_name)) | |||
| if attr_name | |||
| else dst_name | |||
| ) | |||
| assert get_suffix_name(dst_name, g.qualname) is not None | |||
| for mod in self._mod.modules(): | |||
| if not hasattr(mod, "argdef_graph_map"): | |||
| continue | |||
| for g in mod.argdef_graph_map.values(): | |||
| replace_qualname(g) | |||
| g._namespace.qualname = g.qualname | |||
| for n in g.nodes(False): | |||
| replace_qualname(n) | |||
| else: | |||
| self_node = None | |||
| orig_self = NodeMixin.get(self) | |||
| parent_graph = active_module_tracer().current_scope() | |||
| module_qualname = orig_self._qualname | |||
| @@ -1423,20 +1449,14 @@ class TracedModuleBuilder(NodeMixin): | |||
| active_module_tracer().push_scope(self._body) | |||
| # rebind self to new input node | |||
| if self_node: | |||
| NodeMixin.wrap_safe(self, self_node) | |||
| active_module_tracer().current_scope()._add_input(self_node) | |||
| else: | |||
| NodeMixin.wrap_safe( | |||
| self, | |||
| self_node | |||
| if self_node | |||
| else Input.make( | |||
| name="self", | |||
| qualname=module_qualname, | |||
| type=NodeMixin.get_wrapped_type(self), | |||
| ), | |||
| ) | |||
| NodeMixin.wrap_safe( | |||
| self, | |||
| Input.make( | |||
| name="self", | |||
| qualname=module_qualname, | |||
| type=NodeMixin.get_wrapped_type(self), | |||
| ), | |||
| ) | |||
| origin_inp_node = [NodeMixin.get(i, None) for i in inputs[1:]] | |||
| # prepare args and kwargs for inner graph | |||
| @@ -1470,8 +1490,23 @@ class TracedModuleBuilder(NodeMixin): | |||
| return x | |||
| args = [self] | |||
| for i, v in enumerate(inputs[1:]): | |||
| args.append(wrap(v, idx2key[i + 1])) | |||
| orig_traced_inputs = ( | |||
| None | |||
| if not isinstance(self._mod, TracedModule) | |||
| else self._mod.argdef_graph_map[tree_def].inputs | |||
| ) | |||
| ind = 1 | |||
| for v in inputs[1:]: | |||
| if isinstance(v, (RawTensor, NodeMixin)): | |||
| args_name = ( | |||
| orig_traced_inputs[ind]._name | |||
| if orig_traced_inputs | |||
| else idx2key[ind] | |||
| ) | |||
| ind += 1 | |||
| args.append(wrap(v, args_name)) | |||
| else: | |||
| args.append(v) | |||
| args, kwargs = tree_def.unflatten(args) | |||
| active_module_tracer().patcher.auto_patch( | |||
| @@ -1514,7 +1549,6 @@ class TracedModuleBuilder(NodeMixin): | |||
| attr = getattr(type(self._mod), name).__get__(self, type(self)) | |||
| else: | |||
| attr = getattr(self._mod, name) | |||
| if ( | |||
| isinstance(attr, FunctionType) | |||
| and id(attr) in active_module_tracer().patcher.patched_fn_ids | |||
| @@ -1568,7 +1602,7 @@ class TracedModuleBuilder(NodeMixin): | |||
| wrapped = self.__getattr__(name) | |||
| if isinstance(wrapped, TracedModuleBuilder): | |||
| if not isinstance(mod_attr, (List, Dict)): | |||
| if not isinstance(mod_attr, (List, Dict, QATModule)): | |||
| assert mod_attr is wrapped._mod | |||
| else: | |||
| assert mod_attr is wrapped | |||
| @@ -1977,8 +2011,6 @@ class TracedModule(Module): | |||
| def graph(self) -> InternalGraph: | |||
| """Return the ``InternalGraph`` of this ``TracedModule``. | |||
| """ | |||
| if self._is_top: | |||
| self._update_ref() | |||
| assert len(self.argdef_graph_map) == 1 | |||
| return list(self.argdef_graph_map.values())[0] | |||
| @@ -2112,7 +2144,7 @@ class TracedModule(Module): | |||
| if hasattr(obj, "argdef_graph_map") | |||
| else None | |||
| ) | |||
| if expr_graph is not None: | |||
| if expr_graph is not None and not obj.is_qat: | |||
| exprs = _flatten_subgraph(graph, expr_graph, expr, obj) | |||
| if parent_graph is not None: | |||
| @@ -2137,26 +2169,119 @@ class TracedModule(Module): | |||
| ) | |||
| new_module.graph._re_associate_name() | |||
| new_module.graph.compile() | |||
| new_module._update_ref() | |||
| new_module.graph._reset_ids() | |||
| return new_module | |||
| def __getstate__(self): | |||
| d = self.__dict__ | |||
| d = self.__dict__.copy() | |||
| for k in Module.__dict__: | |||
| d.pop(k, None) | |||
| _check_obj_attr(d) | |||
| for k in d: | |||
| if module_tracer.is_builtin(d[k]): | |||
| assert _check_builtin_module_attr( | |||
| d[k] | |||
| ), "Module {} can not be serialized. ".format(type(d[k])) | |||
| d[k] = _ModuleState.get_module_state(d[k]) | |||
| dump_info = { | |||
| "version": __version__, | |||
| "register_type": USER_REGISTERED_LEAF_TYPE, | |||
| "register_container_type": USER_REGISTERED_CONTAINER_TYPE, | |||
| "register_mdule": USER_REGISTERED_MODULE, | |||
| "register_function": USER_REGISTERED_FUNCTION, | |||
| } | |||
| d["dump_info"] = dump_info | |||
| return d | |||
def __setstate__(self, state):
    """Restore a pickled ``TracedModule`` and repair its graphs.

    Steps, in order:

    1. ``_ModuleState`` placeholders in ``state`` are turned back into modules.
    2. Every expr of every graph is passed to the loader for its kind
       (function call, module call, tensor method call, or apply).
    3. Each graph is made consistent again: producer exprs missing from
       ``graph._exprs`` are inserted, ``users`` lists are resynchronised with
       ``inputs``, output nodes are re-pointed to their expr, and outputs are
       named.
    4. References and ids are refreshed.
    """
    for k, v in state.items():
        if isinstance(v, _ModuleState):
            state[k] = v.to_module()
    self.__dict__.update(state)
    self._update_ref()

    for _, graph in self.argdef_graph_map.items():
        for expr in graph._exprs:
            if isinstance(expr, CallFunction):
                load_functional(expr)
            if isinstance(expr, CallMethod):
                if expr.method == "__call__":
                    load_call_module_expr(expr)
                else:
                    load_call_tensor_method_expr(expr)
            if isinstance(expr, Apply):
                load_apply_expr(expr)

    for _, graph in self.argdef_graph_map.items():
        # Insert producer exprs that are not yet in the list in front of their
        # consumer. The same index is revisited until nothing new is inserted.
        ind = 0
        while ind < len(graph._exprs):
            cur_expr = graph._exprs[ind]
            has_new_expr = False
            for i in cur_expr.inputs:
                if i.expr not in graph._exprs and not isinstance(i.expr, Input):
                    graph._exprs.insert(ind, i.expr)
                    has_new_expr = True
            if not has_new_expr:
                ind += 1
        for expr in graph._exprs:
            for i in expr.inputs:
                if expr.inputs.count(i) != i.users.count(expr):
                    add_or_del_count = expr.inputs.count(i) - i.users.count(expr)
                    if add_or_del_count > 0:
                        i.users.extend([expr] * add_or_del_count)
                    else:
                        # Plain loop on purpose: a comprehension reusing the
                        # name ``i`` would rebind it to an int and break
                        # ``i.users``.
                        for _ in range(-add_or_del_count):
                            i.users.remove(expr)
            for o in expr.outputs:
                if o.expr is not expr:
                    assert o not in o.expr.outputs
                    o.expr = expr
        for node in graph.nodes(False):
            # remove users of node which doesn't use node as input
            node.users = [e for e in node.users if node in e.inputs]
        for expr in graph._exprs:
            graph._namespace.auto_naming_for_outputs(expr)
    self._update_ref()
    for _, graph in self.argdef_graph_map.items():
        graph._reset_ids()
def __copy__(self):
    """Return a shallow copy that shares every attribute value with ``self``."""
    clone = self.__class__.__new__(self.__class__)
    for key, value in self.__dict__.items():
        clone.__dict__[key] = value
    return clone
def __deepcopy__(self, memo):
    """Return a deep copy of this module, skipping weak-reference attributes.

    The new object is put into ``memo`` before its attributes are copied.
    Attributes holding a ``weakref.ReferenceType`` are not copied, and
    ``_update_ref`` is called on the copy before it is returned.
    """
    clone = self.__class__.__new__(self.__class__)
    memo[id(self)] = clone
    copied = {
        name: copy.deepcopy(value, memo)
        for name, value in self.__dict__.items()
        if not isinstance(value, weakref.ReferenceType)
    }
    clone.__dict__.update(copied)
    clone._update_ref()
    return clone
def cpp_apply_module_trace(opdef, *args):
    """Forward ``opdef`` and its arguments to ``Apply.apply_module_trace_hook``."""
    hook = Apply.apply_module_trace_hook
    return hook(opdef, *args)
| USER_REGISTERED_MODULE = [] | |||
| USER_REGISTERED_FUNCTION = [] | |||
| def register_as_builtin(mod_cls: Type[Module]) -> None: | |||
| r"""Registers class ``mod_cls`` (subclass of :class:`~.Module`) as builtin module. | |||
| Args: | |||
| mod_cls: the module class which will be treated as builtin module in tracing. | |||
| """ | |||
| USER_REGISTERED_MODULE.append((mod_cls.__module__, mod_cls.__qualname__)) | |||
| module_tracer.register_as_builtin(mod_cls) | |||
| @@ -2181,6 +2306,7 @@ def wrap(func: Callable): | |||
| Args: | |||
| func: the function of the global function to insert into the graph when it's called. | |||
| """ | |||
| USER_REGISTERED_FUNCTION.append((func.__module__, func.__qualname__)) | |||
| assert callable(func), "func must be a callable" | |||
| assert hasattr(func, "__code__") | |||
| fn_name = func.__code__.co_name | |||
| @@ -2247,6 +2373,8 @@ def trace_module( | |||
| NodeMixin.wrap_safe( | |||
| builder, Input.make(name="top", type=ModuleNode, qualname=net_name) | |||
| ) | |||
| args, kwargs = _convert_kwargs_to_args(mod.forward, args, kwargs, True) | |||
| inputs, _ = tree_flatten((args, kwargs)) | |||
| for _, i in enumerate(inputs): | |||
| # assert isinstance(i, Tensor), "not support " | |||
| @@ -5,12 +5,17 @@ | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import collections | |||
| import copy | |||
| import inspect | |||
| from collections.abc import MutableMapping, MutableSequence | |||
| from typing import Dict, Iterable, List, Optional, Sequence | |||
| from typing import Dict, Iterable, List, Optional, Sequence, Type | |||
| from .. import get_logger | |||
| from ..module import Module | |||
| logger = get_logger(__name__) | |||
| def replace_container_with_module_container(container): | |||
| has_module = False | |||
| @@ -52,6 +57,101 @@ def replace_container_with_module_container(container): | |||
| return has_module, module_container | |||
| def _convert_kwargs_to_args(func, args, kwargs, is_bounded=False): | |||
| # is_bounded = True when func is a method and provided args don't include 'self' | |||
| arg_specs = inspect.getfullargspec(func) | |||
| arg_specs_args = arg_specs.args | |||
| if is_bounded: | |||
| arg_specs_args = arg_specs.args[1:] | |||
| new_args = [] | |||
| new_kwargs = {} | |||
| new_args.extend(args) | |||
| if set(arg_specs_args[0 : len(new_args)]) & set(kwargs.keys()): | |||
| repeated_arg_name = set(arg_specs_args[0 : len(new_args)]) & set(kwargs.keys()) | |||
| raise TypeError( | |||
| "{} got multiple values for argument {}".format( | |||
| func.__qualname__, ", ".join(repeated_arg_name) | |||
| ) | |||
| ) | |||
| if len(new_args) < len(arg_specs.args): | |||
| for ind in range(len(new_args), len(arg_specs_args)): | |||
| arg_name = arg_specs_args[ind] | |||
| if arg_name in kwargs: | |||
| new_args.append(kwargs[arg_name]) | |||
| else: | |||
| index = ind - len(arg_specs_args) + len(arg_specs.defaults) | |||
| assert index < len(arg_specs.defaults) and index >= 0 | |||
| new_args.append(arg_specs.defaults[index]) | |||
| for kwarg_name in arg_specs.kwonlyargs: | |||
| if kwarg_name in kwargs: | |||
| new_kwargs[kwarg_name] = kwargs[kwarg_name] | |||
| else: | |||
| assert kwarg_name in arg_specs.kwonlydefaults | |||
| new_kwargs[kwarg_name] = arg_specs.kwonlydefaults[kwarg_name] | |||
| for k, v in kwargs.items(): | |||
| if k not in arg_specs.args and k not in arg_specs.kwonlyargs: | |||
| if arg_specs.varkw is None: | |||
| raise TypeError( | |||
| "{} got an unexpected keyword argument {}".format( | |||
| func.__qualname__, k | |||
| ) | |||
| ) | |||
| new_kwargs[k] = v | |||
| return tuple(new_args), new_kwargs | |||
def _check_obj_attr(obj):
    """Assert that every value in the mapping ``obj`` is built from supported leaf types.

    Each value is flattened with ``tree_flatten``. Every resulting leaf must be
    an instance (or subclass) of a supported leaf class or one of the traced
    module's own types, or have a type listed in ``SUPPORTED_LEAF_TYPE``.
    """
    # Imported here rather than at module level — presumably to avoid circular
    # imports (TODO confirm).
    from .pytree import tree_flatten
    from .pytree import SUPPORTED_LEAF_CLS, SUPPORTED_LEAF_TYPE, TreeDef
    from .expr import Expr
    from .traced_module import TracedModule, InternalGraph, NameSpace

    def _type_of(leaf):
        # A leaf may itself be a class; then it is checked as that class.
        return leaf if isinstance(leaf, type) else type(leaf)

    def _is_supported(leaf):
        leaf_type = _type_of(leaf)
        own_types = [Expr, TreeDef, TracedModule, InternalGraph, NameSpace]
        if issubclass(leaf_type, tuple(SUPPORTED_LEAF_CLS + own_types)):
            return True
        return leaf_type in SUPPORTED_LEAF_TYPE

    for value in obj.values():
        leaves, _ = tree_flatten(value, is_leaf=lambda _: True)
        for leaf in leaves:
            assert _is_supported(
                leaf
            ), "Type {} is not supported by traced module".format(_type_of(leaf))
def _check_builtin_module_attr(mod):
    """Return True if every attribute of the builtin module ``mod`` is serializable.

    Attributes that are ``Module`` instances are checked recursively. Every
    other attribute is flattened with ``tree_flatten`` and each leaf must pass
    ``pytree._is_leaf`` and must not be a non-serializable module. The cached
    ``_m_dump_modulestate`` attribute is ignored. A warning naming the
    offending type is logged before returning False for an unsupported leaf.
    """
    from .pytree import _is_leaf as _check_leaf_type
    from .pytree import tree_flatten

    def is_non_serializable_module(m):
        return isinstance(m, Module) and not _check_builtin_module_attr(m)

    for k, v in mod.__dict__.items():
        if k == "_m_dump_modulestate":
            continue
        if is_non_serializable_module(v):
            return False
        elif not isinstance(v, Module):
            leafs, _ = tree_flatten(v, is_leaf=lambda _: True)
            for leaf in leafs:
                if not _check_leaf_type(leaf) or is_non_serializable_module(leaf):
                    # ``Logger.warn`` is a deprecated alias; use ``warning``.
                    logger.warning(
                        "Type {} is not supported by traced module".format(
                            leaf if isinstance(leaf, type) else type(leaf)
                        )
                    )
                    return False
    return True
| class _ModuleList(Module, MutableSequence): | |||
| r"""A List-like container. | |||
| @@ -15,7 +15,6 @@ import numpy as np | |||
| import megengine as mge | |||
| from megengine import Parameter, Tensor | |||
| from megengine.core.ops import builtin | |||
| from megengine.traced_module.serialization import get_opdef_state, load_opdef_from_state | |||
| def test_tensor_serialization(): | |||
| @@ -88,25 +87,3 @@ def test_compatibility(): | |||
| test_old_tensor("tensor_v1_1.mge") | |||
| test_old_tensor("tensor_v1_2.mge") | |||
| def test_opdef_serialization(): | |||
| with TemporaryFile() as f: | |||
| x = builtin.Elemwise(mode="Add") | |||
| pickle.dump(get_opdef_state(x), f) | |||
| f.seek(0) | |||
| load_x = load_opdef_from_state(pickle.load(f)) | |||
| assert x == load_x | |||
| with TemporaryFile() as f: | |||
| x = builtin.Convolution(stride_h=9, compute_mode="float32") | |||
| x.strategy = ( | |||
| builtin.Convolution.Strategy.PROFILE | |||
| | builtin.Convolution.Strategy.HEURISTIC | |||
| | builtin.Convolution.Strategy.REPRODUCIBLE | |||
| ) | |||
| pickle.dump(get_opdef_state(x), f) | |||
| f.seek(0) | |||
| load_x = load_opdef_from_state(pickle.load(f)) | |||
| assert x.strategy == load_x.strategy | |||
| assert x == load_x | |||
| @@ -85,12 +85,12 @@ class NewModule(M.Module): | |||
| return x | |||
| def _check_expr_users(traced_module): | |||
| def _check_expr_users(flattened_module): | |||
| node_user = defaultdict(list) | |||
| for expr in traced_module.graph._exprs: | |||
| for expr in flattened_module.graph._exprs: | |||
| for node in expr.inputs: | |||
| node_user[node].append(expr) | |||
| for node in traced_module.graph.nodes(): | |||
| for node in flattened_module.graph.nodes(): | |||
| node.users.sort(key=lambda m: m._id) | |||
| node_user[node].sort(key=lambda m: m._id) | |||
| assert node.users == node_user[node] | |||
| @@ -8,6 +8,7 @@ import numpy as np | |||
| import megengine as mge | |||
| import megengine.functional as F | |||
| import megengine.module as M | |||
| import megengine.module.qat as QM | |||
| import megengine.quantization as Q | |||
| from megengine import Tensor | |||
| from megengine.module.qat.module import QATModule | |||
| @@ -28,10 +29,18 @@ def get_subattr(self: M.Module, name: str): | |||
| return getattr(self, name) | |||
| class MyConvBnRelu2d(M.ConvBnRelu2d): | |||
| pass | |||
| class MyQATConvBnRelu2d(QM.ConvBnRelu2d): | |||
| pass | |||
| class Myblcok(M.Module): | |||
| def __init__(self,): | |||
| super().__init__() | |||
| self.conv0 = M.ConvBnRelu2d(3, 3, 3, 1, 1) | |||
| self.conv0 = MyConvBnRelu2d(3, 3, 3, 1, 1) | |||
| self.conv1 = M.ConvBn2d(3, 3, 1, 1, 0) | |||
| self.conv2 = M.ConvBn2d(3, 3, 1, 1, 0) | |||
| self.add = M.Elemwise("FUSE_ADD_RELU") | |||
| @@ -106,7 +115,11 @@ def check_qparams(qparmsa: Q.QParams, qparmsb: Q.QParams): | |||
| def build_observered_net(net: M.Module, observer_cls): | |||
| qat_net = Q.quantize_qat(net, qconfig=get_observer_config(observer_cls)) | |||
| qat_net = Q.quantize_qat( | |||
| net, | |||
| qconfig=get_observer_config(observer_cls), | |||
| mapping={MyConvBnRelu2d: MyQATConvBnRelu2d}, | |||
| ) | |||
| Q.enable_observer(qat_net) | |||
| inp = Tensor(np.random.random(size=(5, 3, 32, 32))) | |||
| qat_net(inp) | |||
| @@ -134,6 +147,15 @@ def test_trace_qat(): | |||
| check_qparams(weight_qparams, traced_weight_qparams) | |||
| if act_qparams: | |||
| check_qparams(act_qparams, traced_act_qparams) | |||
| flatten_traced_net = traced_net.flatten() | |||
| conv0_node = flatten_traced_net.graph.get_node_by_name( | |||
| "MyModule_block0_conv0" | |||
| ).as_unique() | |||
| conv0_out_node = flatten_traced_net.graph.get_node_by_name( | |||
| "MyModule_block0_conv0_out" | |||
| ).as_unique() | |||
| assert isinstance(conv0_node.owner, TracedModule) | |||
| assert conv0_out_node.expr.inputs[0] is conv0_node | |||
| _check_qat_module(build_observered_net(MyModule(), Q.MinMaxObserver)) | |||
| _check_qat_module(build_observered_net(MyModule(), MyMinMaxObserver)) | |||
| @@ -6,14 +6,59 @@ | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import pickle | |||
| from collections import defaultdict | |||
| from tempfile import TemporaryFile | |||
| import numpy as np | |||
| import megengine.functional as F | |||
| import megengine.module as M | |||
| import megengine.traced_module.serialization as S | |||
| from megengine import Tensor | |||
| from megengine.core._imperative_rt.core2 import apply | |||
| from megengine.core.ops import builtin | |||
| from megengine.core.ops.builtin import Elemwise | |||
| from megengine.module import Module | |||
| from megengine.traced_module import trace_module | |||
| from megengine.traced_module.expr import CallMethod, Constant | |||
| from megengine.traced_module.node import TensorNode | |||
| from megengine.traced_module.serialization import ( | |||
| register_functional_loader, | |||
| register_module_loader, | |||
| register_opdef_loader, | |||
| register_tensor_method_loader, | |||
| ) | |||
| from megengine.traced_module.utils import _convert_kwargs_to_args | |||
| def _check_id(traced_module): | |||
| _total_ids = traced_module.graph._total_ids | |||
| node_ids = [n._id for n in traced_module.graph.nodes().as_list()] | |||
| assert len(set(node_ids)) == len(node_ids) | |||
| assert max(node_ids) + 1 == _total_ids[0] | |||
| expr_ids = [n._id for n in traced_module.graph.exprs().as_list()] | |||
| assert len(set(expr_ids)) == len(expr_ids) | |||
| assert max(expr_ids) + 1 == _total_ids[1] | |||
| def _check_name(flatened_module): | |||
| node_names = [n._name for n in flatened_module.graph.nodes().as_list()] | |||
| assert len(set(node_names)) == len(node_names) | |||
def _check_expr_users(traced_module):
    """Assert each node's ``users`` list matches the exprs that take it as input.

    The expected users are rebuilt from ``expr.inputs`` over the module's graph.
    Sub-modules reached through ``CallMethod`` exprs that carry a graph are
    checked recursively. ``node.users`` is sorted in place by expr id.
    """

    def _by_id(item):
        return item._id

    expected_users = defaultdict(list)
    for expr in traced_module.graph._exprs:
        for inp in expr.inputs:
            expected_users[inp].append(expr)
        if isinstance(expr, CallMethod) and expr.graph:
            _check_expr_users(expr.inputs[0].owner)
    for node in traced_module.graph.nodes(False):
        node.users.sort(key=_by_id)
        expected = expected_users[node]
        expected.sort(key=_by_id)
        assert node.users == expected
| class MyBlock(Module): | |||
| @@ -48,5 +93,274 @@ def test_dump_and_load(): | |||
| traced_module = trace_module(module, x) | |||
| np.testing.assert_array_equal(expect, traced_module(x)) | |||
| obj = pickle.dumps(traced_module) | |||
| pickle.loads(obj) | |||
| new_tm = pickle.loads(obj) | |||
| _check_id(new_tm) | |||
| _check_expr_users(new_tm) | |||
| traced_module.graph._reset_ids() | |||
| old_nodes = traced_module.graph.nodes().as_list() | |||
| new_nodes = new_tm.graph.nodes().as_list() | |||
| old_exprs = traced_module.graph.exprs().as_list() | |||
| new_exprs = new_tm.graph.exprs().as_list() | |||
| assert len(old_nodes) == len(new_nodes) | |||
| for i, j in zip(old_nodes, new_nodes): | |||
| assert i._name == j._name | |||
| assert i._qualname == j._qualname | |||
| assert i._id == j._id | |||
| assert len(old_exprs) == len(new_exprs) | |||
| for i, j in zip(old_exprs, new_exprs): | |||
| assert i._id == j._id | |||
| np.testing.assert_array_equal(expect, traced_module(x)) | |||
def test_opdef_loader():
    """Check that a registered opdef loader rewrites an ``Elemwise`` expr on load.

    The loader changes the traced ``ADD`` mode to ``MUL`` and routes the second
    input through an inserted ``astype`` call. After a pickle round trip the
    graph is asserted to hold exactly two exprs (a ``CallMethod`` followed by
    the ``MUL`` op) and the output is asserted equal to ``x``.
    """

    class MyModule1(Module):
        def forward(self, x, y):
            op = Elemwise("ADD")
            return apply(op, x, y)[0]

    m = MyModule1()
    x = Tensor(np.ones((20)))
    y = Tensor(np.ones((20)))
    traced_module = trace_module(m, x, y)

    # Replace the registry with an empty dict for the duration of the test
    # (register_opdef_loader presumably writes into S.OPDEF_LOADER).
    orig_loader_dict = S.OPDEF_LOADER
    S.OPDEF_LOADER = {}
    try:

        @register_opdef_loader(Elemwise)
        def add_opdef_loader(expr):
            if expr.opdef_state["mode"] == "ADD":
                expr.opdef_state["mode"] = "MUL"
                node = expr.inputs[1]
                # Insert `node.astype(<dtype of input 0>)` and feed its output
                # to the op in place of the original second input.
                astype_expr = CallMethod(node, "astype")
                oup = TensorNode(
                    astype_expr,
                    shape=node.shape,
                    dtype=expr.inputs[0].dtype,
                    qparams=node.qparams,
                )
                astype_expr.set_args_kwargs(node, expr.inputs[0].dtype)
                astype_expr.return_val = (oup,)
                expr.inputs[1] = oup

        obj = pickle.dumps(traced_module)
        new_module = pickle.loads(obj)
        _check_id(new_module)
        _check_expr_users(new_module)
        _check_name(new_module.flatten())
        assert (
            isinstance(new_module.graph._exprs[0], CallMethod)
            and new_module.graph._exprs[1].opdef.mode == "MUL"
            and len(new_module.graph._exprs) == 2
        )
        result = new_module(x, y)
        np.testing.assert_equal(result.numpy(), x.numpy())
    finally:
        # Restore the global registry even when an assertion above fails, so
        # the patched loaders cannot leak into other tests.
        S.OPDEF_LOADER = orig_loader_dict
def test_functional_loader():
    """Check that a registered functional loader rewrites a ``conv2d`` call on load.

    The loader routes the ``weight`` argument through an inserted ``astype``
    call. After a pickle round trip the graph is asserted to hold exactly two
    exprs, the first being a ``CallMethod``, and the output is asserted equal
    to the untraced module's output.
    """

    class MyModule2(Module):
        def forward(self, x, y):
            return F.conv2d(x, y)

    m = MyModule2()
    x = Tensor(np.random.random((1, 3, 32, 32)))
    y = Tensor(np.random.random((3, 3, 3, 3)))
    traced_module = trace_module(m, x, y)

    # Replace the registry with an empty dict for the duration of the test
    # (register_functional_loader presumably writes into S.FUNCTIONAL_LOADER).
    orig_loader_dict = S.FUNCTIONAL_LOADER
    S.FUNCTIONAL_LOADER = {}
    try:

        @register_functional_loader(("megengine.functional.nn", "conv2d"))
        def conv2df_loader(expr):
            # expr.func = ("megengine.functional.nn","conv2d")
            orig_weight = expr.named_args["weight"]
            # Insert `weight.astype(<dtype of inp>)` and pass its output as
            # the new `weight` argument.
            astype_expr = CallMethod(orig_weight, "astype")
            oup = TensorNode(
                astype_expr,
                shape=orig_weight.shape,
                dtype=orig_weight.dtype,
                qparams=orig_weight.qparams,
            )
            astype_expr.set_args_kwargs(orig_weight, expr.named_args["inp"].dtype)
            astype_expr.return_val = (oup,)
            expr.set_arg("weight", oup)

        obj = pickle.dumps(traced_module)
        new_module = pickle.loads(obj)
        _check_expr_users(new_module)
        _check_id(new_module)
        result = new_module(x, y)
        gt = m(x, y)
        assert (
            isinstance(new_module.graph._exprs[0], CallMethod)
            and len(new_module.graph._exprs) == 2
        )
        np.testing.assert_equal(result.numpy(), gt.numpy())
    finally:
        # Restore the global registry even when an assertion above fails, so
        # the patched loaders cannot leak into other tests.
        S.FUNCTIONAL_LOADER = orig_loader_dict
def test_tensor_method_loader():
    """Check that a registered tensor-method loader rewrites ``__add__`` on load.

    When the second operand is not a ``TensorNode``, the loader wraps it in a
    ``Constant`` expr, casts it with ``astype``, adds the cast value to itself
    and uses that sum as the new second operand. The traced ``x + 1`` is
    therefore expected to evaluate as ``x + 2`` after a pickle round trip, with
    four exprs in the graph and a ``Constant`` first.
    """

    class MyModule3(Module):
        def forward(self, x):
            return x + 1

    m = MyModule3()
    x = Tensor(np.ones((20)))
    traced_module = trace_module(m, x)

    # Replace the registry with an empty dict for the duration of the test
    # (register_tensor_method_loader presumably writes into
    # S.TENSORMETHOD_LOADER).
    orig_loader_dict = S.TENSORMETHOD_LOADER
    S.TENSORMETHOD_LOADER = {}
    try:

        @register_tensor_method_loader("__add__")
        def add_loader(expr):
            args = list(expr.args)
            if not isinstance(args[1], TensorNode):
                args[1] = Tensor(args[1])
                node = Constant(args[1], "const").outputs[0]

                # const.astype(<dtype of input 0>)
                astype_expr = CallMethod(node, "astype")
                oup = TensorNode(
                    astype_expr, shape=node.shape, dtype=node.dtype, qparams=node.qparams,
                )
                astype_expr.set_args_kwargs(node, expr.inputs[0].dtype)
                astype_expr.return_val = (oup,)

                # cast + cast, which doubles the original constant
                add_expr = CallMethod(oup, "__add__")
                add_expr.set_args_kwargs(oup, oup)
                oup1 = TensorNode(
                    add_expr, shape=oup.shape, dtype=oup.dtype, qparams=node.qparams,
                )
                add_expr.return_val = oup1
                args[1] = oup1
                expr.set_args_kwargs(*args)

        obj = pickle.dumps(traced_module)
        new_module = pickle.loads(obj)
        _check_expr_users(new_module)
        _check_id(new_module)
        result = new_module(x)
        assert (
            isinstance(new_module.graph._exprs[0], Constant)
            and len(new_module.graph._exprs) == 4
        )
        np.testing.assert_equal(result.numpy(), (x + 2).numpy())
    finally:
        # Restore the global registry even when an assertion above fails, so
        # the patched loaders cannot leak into other tests.
        S.TENSORMETHOD_LOADER = orig_loader_dict
def test_module_loader():
    """Check that a registered module loader rewrites a ``Conv2d`` call on load.

    The loader routes the module's input through an inserted ``astype`` call
    targeting the dtype of the module's weight. After a pickle round trip the
    graph is asserted to hold exactly three exprs, with a ``CallMethod`` at
    index 1, and the output is asserted equal to the untraced module's output.
    """

    class MyModule4(Module):
        def __init__(self):
            super().__init__()
            self.conv = M.Conv2d(3, 3, 3)

        def forward(self, x):
            return self.conv(x)

    m = MyModule4()
    x = Tensor(np.random.random((1, 3, 32, 32)))
    traced_module = trace_module(m, x)

    # Replace the registry with an empty dict for the duration of the test
    # (register_module_loader presumably writes into S.MODULE_LOADER).
    orig_loader_dict = S.MODULE_LOADER
    S.MODULE_LOADER = {}
    try:

        @register_module_loader(("megengine.module.conv", "Conv2d"))
        def conv2dm_loader(expr):
            module = expr.inputs[0].owner
            args = list(expr.args)
            orig_inp = args[1]
            # Insert `inp.astype(<dtype of module.weight>)` and pass its
            # output to the module call instead of the original input.
            astype_expr = CallMethod(orig_inp, "astype")
            oup = TensorNode(
                astype_expr,
                shape=orig_inp.shape,
                dtype=orig_inp.dtype,
                qparams=orig_inp.qparams,
            )
            astype_expr.set_args_kwargs(orig_inp, module.weight.dtype)
            astype_expr.return_val = (oup,)
            args[1] = oup
            expr.set_args_kwargs(*args)

        obj = pickle.dumps(traced_module)
        new_module = pickle.loads(obj)
        result = new_module(x)
        gt = m(x)
        assert (
            isinstance(new_module.graph._exprs[1], CallMethod)
            and len(new_module.graph._exprs) == 3
        )
        np.testing.assert_equal(result.numpy(), gt.numpy())
    finally:
        # Restore the global registry even when an assertion above fails, so
        # the patched loaders cannot leak into other tests.
        S.MODULE_LOADER = orig_loader_dict
def test_shared_module():
    """A submodule reachable under two attribute names stays shared after a
    pickle round trip of the traced module."""

    class MyModule(M.Module):
        def __init__(self):
            super().__init__()
            # `a` and `b` are two names for one and the same Elemwise module.
            self.a = self.b = M.Elemwise("ADD")

        def forward(self, x, y):
            return self.b(self.a(x, y), y)

    lhs, rhs = Tensor(1), Tensor(2)
    traced = trace_module(MyModule(), lhs, rhs)

    restored = pickle.loads(pickle.dumps(traced))

    _check_expr_users(restored)
    _check_name(restored.flatten())
    _check_id(restored)
    assert restored.a is restored.b
def test_convert_kwargs_to_args():
    """Table-driven check of ``_convert_kwargs_to_args`` against two signatures."""

    def func(a, b, c=4, *, d, e=3, f=4):
        pass

    def func1(a, b, c, d, e, *, f):
        pass

    # (callable, args, kwargs, extra options, expected args, expected kwargs)
    cases = [
        (func, (1,), {"b": 1, "d": 6}, {}, (1, 1, 4), {"d": 6, "e": 3, "f": 4}),
        (
            func,
            (1,),
            {"d": 6},
            {"is_bounded": True},
            (1, 4),
            {"d": 6, "e": 3, "f": 4},
        ),
        (
            func1,
            (),
            {"a": 1, "b": 2, "c": 3, "d": 4, "e": 5, "f": 6},
            {},
            (1, 2, 3, 4, 5),
            {"f": 6},
        ),
    ]

    for target, args, kwargs, options, want_args, want_kwargs in cases:
        new_args, new_kwargs = _convert_kwargs_to_args(target, args, kwargs, **options)
        assert new_args == want_args
        assert new_kwargs == want_kwargs
def test_opdef_serialization():
    """Opdefs pickled to a temporary file load back equal to the original."""

    def roundtrip(op):
        # Dump to a temporary file, rewind, and load the object back.
        with TemporaryFile() as buf:
            pickle.dump(op, buf)
            buf.seek(0)
            return pickle.load(buf)

    elemwise = builtin.Elemwise(mode="Add")
    assert elemwise == roundtrip(elemwise)

    conv = builtin.Convolution(stride_h=9, compute_mode="float32")
    strategy = builtin.Convolution.Strategy
    conv.strategy = strategy.PROFILE | strategy.HEURISTIC | strategy.REPRODUCIBLE
    loaded_conv = roundtrip(conv)
    assert conv.strategy == loaded_conv.strategy
    assert conv == loaded_conv