GitOrigin-RevId: 8e31a00c7e
tags/v1.7.0
| @@ -9,6 +9,7 @@ | |||||
| import collections | import collections | ||||
| from collections import OrderedDict, defaultdict | from collections import OrderedDict, defaultdict | ||||
| from functools import partial | from functools import partial | ||||
| from inspect import FullArgSpec | |||||
| from typing import Callable, NamedTuple | from typing import Callable, NamedTuple | ||||
| import numpy as np | import numpy as np | ||||
| @@ -53,6 +54,7 @@ SUPPORTED_LEAF_TYPE = { | |||||
| QuantMode, | QuantMode, | ||||
| ArgsIndex, | ArgsIndex, | ||||
| Group, | Group, | ||||
| FullArgSpec, | |||||
| } | } | ||||
| USER_REGISTERED_LEAF_TYPE = [] | USER_REGISTERED_LEAF_TYPE = [] | ||||
| @@ -1928,8 +1928,11 @@ class TracedModule(Module): | |||||
| self.watch_node_value = {} | self.watch_node_value = {} | ||||
| self.end_points = [] | self.end_points = [] | ||||
| self.is_qat = is_qat | self.is_qat = is_qat | ||||
| self.argspec = None | |||||
| def forward(self, *args, **kwargs): | def forward(self, *args, **kwargs): | ||||
| if hasattr(self, "argspec") and self.argspec is not None: | |||||
| args, kwargs = _convert_kwargs_to_args(self.argspec, args, kwargs, True) | |||||
| inputs, treedef = tree_flatten(((self, *args), kwargs)) | inputs, treedef = tree_flatten(((self, *args), kwargs)) | ||||
| assert treedef in self.argdef_graph_map | assert treedef in self.argdef_graph_map | ||||
| inputs = filter( | inputs = filter( | ||||
| @@ -2422,8 +2425,12 @@ def trace_module( | |||||
| NodeMixin.wrap_safe( | NodeMixin.wrap_safe( | ||||
| builder, Input.make(name="top", type=ModuleNode, qualname=net_name) | builder, Input.make(name="top", type=ModuleNode, qualname=net_name) | ||||
| ) | ) | ||||
| args, kwargs = _convert_kwargs_to_args(mod.forward, args, kwargs, True) | |||||
| forward_argspec = ( | |||||
| mod.argspec | |||||
| if hasattr(mod, "argspec") | |||||
| else inspect.getfullargspec(mod.forward) | |||||
| ) | |||||
| args, kwargs = _convert_kwargs_to_args(forward_argspec, args, kwargs, True) | |||||
| inputs, _ = tree_flatten((args, kwargs)) | inputs, _ = tree_flatten((args, kwargs)) | ||||
| for _, i in enumerate(inputs): | for _, i in enumerate(inputs): | ||||
| # assert isinstance(i, Tensor), "not support " | # assert isinstance(i, Tensor), "not support " | ||||
| @@ -2439,6 +2446,7 @@ def trace_module( | |||||
| builder(*args, **kwargs) | builder(*args, **kwargs) | ||||
| active_module_tracer().pop_scope() | active_module_tracer().pop_scope() | ||||
| traced_mod = builder.build() | traced_mod = builder.build() | ||||
| traced_mod.argspec = forward_argspec | |||||
| traced_mod.graph._reset_ids() | traced_mod.graph._reset_ids() | ||||
| return traced_mod | return traced_mod | ||||
| finally: | finally: | ||||
| @@ -9,7 +9,8 @@ import collections | |||||
| import copy | import copy | ||||
| import inspect | import inspect | ||||
| from collections.abc import MutableMapping, MutableSequence | from collections.abc import MutableMapping, MutableSequence | ||||
| from typing import Dict, Iterable, List, Optional, Sequence, Type | |||||
| from inspect import FullArgSpec | |||||
| from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, Union | |||||
| from .. import get_logger | from .. import get_logger | ||||
| from ..module import Module | from ..module import Module | ||||
| @@ -57,9 +58,14 @@ def replace_container_with_module_container(container): | |||||
| return has_module, module_container | return has_module, module_container | ||||
| def _convert_kwargs_to_args(func, args, kwargs, is_bounded=False): | |||||
| def _convert_kwargs_to_args( | |||||
| argspecs: Union[Callable, FullArgSpec], args, kwargs, is_bounded=False | |||||
| ): | |||||
| # is_bounded = True when func is a method and provided args don't include 'self' | # is_bounded = True when func is a method and provided args don't include 'self' | ||||
| arg_specs = inspect.getfullargspec(func) | |||||
| arg_specs = ( | |||||
| inspect.getfullargspec(argspecs) if isinstance(argspecs, Callable) else argspecs | |||||
| ) | |||||
| assert isinstance(arg_specs, FullArgSpec) | |||||
| arg_specs_args = arg_specs.args | arg_specs_args = arg_specs.args | ||||
| if is_bounded: | if is_bounded: | ||||
| arg_specs_args = arg_specs.args[1:] | arg_specs_args = arg_specs.args[1:] | ||||
| @@ -5,6 +5,7 @@ import numpy as np | |||||
| import megengine.functional as F | import megengine.functional as F | ||||
| import megengine.module as M | import megengine.module as M | ||||
| from megengine import Tensor | from megengine import Tensor | ||||
| from megengine.module.module import Module | |||||
| from megengine.traced_module import TracedModule, trace_module | from megengine.traced_module import TracedModule, trace_module | ||||
| from megengine.traced_module.expr import CallFunction | from megengine.traced_module.expr import CallFunction | ||||
| @@ -89,5 +90,46 @@ def test_trace_module(): | |||||
| m4 = MyModule4() | m4 = MyModule4() | ||||
| tm4 = trace_module(m4, a, b) | tm4 = trace_module(m4, a, b) | ||||
| np.testing.assert_equal(tm4(a, b).numpy(), 3) | |||||
| np.testing.assert_equal(tm4(a, y=b).numpy(), 3) | |||||
| np.testing.assert_equal(tm4(x=a, y=b).numpy(), 3) | |||||
| tm4 = trace_module(m4, a, y=b) | |||||
| np.testing.assert_equal(tm4(a, b).numpy(), 3) | |||||
| np.testing.assert_equal(tm4(a, y=b).numpy(), 3) | |||||
| np.testing.assert_equal(tm4(x=a, y=b).numpy(), 3) | |||||
| tm4 = trace_module(m4, x=a, y=b) | |||||
| np.testing.assert_equal(tm4(a, b).numpy(), 3) | |||||
| np.testing.assert_equal(tm4(a, y=b).numpy(), 3) | |||||
| np.testing.assert_equal(tm4(x=a, y=b).numpy(), 3) | |||||
| tm5 = trace_module(tm4, a, b) | |||||
| np.testing.assert_equal(tm5(a, b).numpy(), 3) | |||||
| np.testing.assert_equal(tm5(a, y=b).numpy(), 3) | |||||
| np.testing.assert_equal(tm5(x=a, y=b).numpy(), 3) | |||||
| tm5 = trace_module(tm4, a, y=b) | |||||
| np.testing.assert_equal(tm5(a, b).numpy(), 3) | |||||
| np.testing.assert_equal(tm5(a, y=b).numpy(), 3) | |||||
| np.testing.assert_equal(tm5(x=a, y=b).numpy(), 3) | |||||
| tm5 = trace_module(tm4, x=a, y=b) | |||||
| np.testing.assert_equal(tm5(a, b).numpy(), 3) | |||||
| np.testing.assert_equal(tm5(a, y=b).numpy(), 3) | |||||
| np.testing.assert_equal(tm5(x=a, y=b).numpy(), 3) | |||||
| assert len(tm4.graph._exprs) == 1 | assert len(tm4.graph._exprs) == 1 | ||||
| assert isinstance(tm4.graph._exprs[0], CallFunction) | assert isinstance(tm4.graph._exprs[0], CallFunction) | ||||
| class MyModule5(Module): | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| self.m1 = tm4 | |||||
| def forward(self, x, y): | |||||
| return self.m1(x, y) | |||||
| tm6 = trace_module(MyModule5(), a, b) | |||||
| assert tm6.m1.argspec is None | |||||
| assert tm6.m1._is_top is False | |||||