GitOrigin-RevId: 8e31a00c7e
tags/v1.7.0
| @@ -9,6 +9,7 @@ | |||
| import collections | |||
| from collections import OrderedDict, defaultdict | |||
| from functools import partial | |||
| from inspect import FullArgSpec | |||
| from typing import Callable, NamedTuple | |||
| import numpy as np | |||
| @@ -53,6 +54,7 @@ SUPPORTED_LEAF_TYPE = { | |||
| QuantMode, | |||
| ArgsIndex, | |||
| Group, | |||
| FullArgSpec, | |||
| } | |||
| USER_REGISTERED_LEAF_TYPE = [] | |||
| @@ -1928,8 +1928,11 @@ class TracedModule(Module): | |||
| self.watch_node_value = {} | |||
| self.end_points = [] | |||
| self.is_qat = is_qat | |||
| self.argspec = None | |||
| def forward(self, *args, **kwargs): | |||
| if hasattr(self, "argspec") and self.argspec is not None: | |||
| args, kwargs = _convert_kwargs_to_args(self.argspec, args, kwargs, True) | |||
| inputs, treedef = tree_flatten(((self, *args), kwargs)) | |||
| assert treedef in self.argdef_graph_map | |||
| inputs = filter( | |||
| @@ -2422,8 +2425,12 @@ def trace_module( | |||
| NodeMixin.wrap_safe( | |||
| builder, Input.make(name="top", type=ModuleNode, qualname=net_name) | |||
| ) | |||
| args, kwargs = _convert_kwargs_to_args(mod.forward, args, kwargs, True) | |||
| forward_argspec = ( | |||
| mod.argspec | |||
| if hasattr(mod, "argspec") | |||
| else inspect.getfullargspec(mod.forward) | |||
| ) | |||
| args, kwargs = _convert_kwargs_to_args(forward_argspec, args, kwargs, True) | |||
| inputs, _ = tree_flatten((args, kwargs)) | |||
| for _, i in enumerate(inputs): | |||
| # assert isinstance(i, Tensor), "not support " | |||
| @@ -2439,6 +2446,7 @@ def trace_module( | |||
| builder(*args, **kwargs) | |||
| active_module_tracer().pop_scope() | |||
| traced_mod = builder.build() | |||
| traced_mod.argspec = forward_argspec | |||
| traced_mod.graph._reset_ids() | |||
| return traced_mod | |||
| finally: | |||
| @@ -9,7 +9,8 @@ import collections | |||
| import copy | |||
| import inspect | |||
| from collections.abc import MutableMapping, MutableSequence | |||
| from typing import Dict, Iterable, List, Optional, Sequence, Type | |||
| from inspect import FullArgSpec | |||
| from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, Union | |||
| from .. import get_logger | |||
| from ..module import Module | |||
| @@ -57,9 +58,14 @@ def replace_container_with_module_container(container): | |||
| return has_module, module_container | |||
| def _convert_kwargs_to_args(func, args, kwargs, is_bounded=False): | |||
| def _convert_kwargs_to_args( | |||
| argspecs: Union[Callable, FullArgSpec], args, kwargs, is_bounded=False | |||
| ): | |||
| # is_bounded = True when func is a method and provided args don't include 'self' | |||
| arg_specs = inspect.getfullargspec(func) | |||
| arg_specs = ( | |||
| inspect.getfullargspec(argspecs) if isinstance(argspecs, Callable) else argspecs | |||
| ) | |||
| assert isinstance(arg_specs, FullArgSpec) | |||
| arg_specs_args = arg_specs.args | |||
| if is_bounded: | |||
| arg_specs_args = arg_specs.args[1:] | |||
| @@ -5,6 +5,7 @@ import numpy as np | |||
| import megengine.functional as F | |||
| import megengine.module as M | |||
| from megengine import Tensor | |||
| from megengine.module.module import Module | |||
| from megengine.traced_module import TracedModule, trace_module | |||
| from megengine.traced_module.expr import CallFunction | |||
| @@ -89,5 +90,46 @@ def test_trace_module(): | |||
| m4 = MyModule4() | |||
| tm4 = trace_module(m4, a, b) | |||
| np.testing.assert_equal(tm4(a, b).numpy(), 3) | |||
| np.testing.assert_equal(tm4(a, y=b).numpy(), 3) | |||
| np.testing.assert_equal(tm4(x=a, y=b).numpy(), 3) | |||
| tm4 = trace_module(m4, a, y=b) | |||
| np.testing.assert_equal(tm4(a, b).numpy(), 3) | |||
| np.testing.assert_equal(tm4(a, y=b).numpy(), 3) | |||
| np.testing.assert_equal(tm4(x=a, y=b).numpy(), 3) | |||
| tm4 = trace_module(m4, x=a, y=b) | |||
| np.testing.assert_equal(tm4(a, b).numpy(), 3) | |||
| np.testing.assert_equal(tm4(a, y=b).numpy(), 3) | |||
| np.testing.assert_equal(tm4(x=a, y=b).numpy(), 3) | |||
| tm5 = trace_module(tm4, a, b) | |||
| np.testing.assert_equal(tm5(a, b).numpy(), 3) | |||
| np.testing.assert_equal(tm5(a, y=b).numpy(), 3) | |||
| np.testing.assert_equal(tm5(x=a, y=b).numpy(), 3) | |||
| tm5 = trace_module(tm4, a, y=b) | |||
| np.testing.assert_equal(tm5(a, b).numpy(), 3) | |||
| np.testing.assert_equal(tm5(a, y=b).numpy(), 3) | |||
| np.testing.assert_equal(tm5(x=a, y=b).numpy(), 3) | |||
| tm5 = trace_module(tm4, x=a, y=b) | |||
| np.testing.assert_equal(tm5(a, b).numpy(), 3) | |||
| np.testing.assert_equal(tm5(a, y=b).numpy(), 3) | |||
| np.testing.assert_equal(tm5(x=a, y=b).numpy(), 3) | |||
| assert len(tm4.graph._exprs) == 1 | |||
| assert isinstance(tm4.graph._exprs[0], CallFunction) | |||
| class MyModule5(Module): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.m1 = tm4 | |||
| def forward(self, x, y): | |||
| return self.m1(x, y) | |||
| tm6 = trace_module(MyModule5(), a, b) | |||
| assert tm6.m1.argspec is None | |||
| assert tm6.m1._is_top is False | |||