From 6d1a4f20e7aebcd14332c63e42f0dede6b108035 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Sat, 21 Aug 2021 19:39:09 +0800 Subject: [PATCH] feat(traced_module): support tracing submodules in list/dict GitOrigin-RevId: 4076b47a89ff5fdbe7c94778a649f8a01d6cc0b6 --- .../traced_module/traced_module.py | 36 +++- .../experimental/traced_module/utils.py | 186 ++++++++++++++++++ .../unit/traced_module/test_trace_module.py | 35 +++- 3 files changed, 243 insertions(+), 14 deletions(-) create mode 100644 imperative/python/megengine/experimental/traced_module/utils.py diff --git a/imperative/python/megengine/experimental/traced_module/traced_module.py b/imperative/python/megengine/experimental/traced_module/traced_module.py index 66ee1fa3..15f93552 100644 --- a/imperative/python/megengine/experimental/traced_module/traced_module.py +++ b/imperative/python/megengine/experimental/traced_module/traced_module.py @@ -58,6 +58,7 @@ from .module_tracer import ( ) from .node import ModuleNode, Node, NodeMixin, TensorNode from .pytree import ArgsIndex, tree_flatten +from .utils import replace_container_with_module_container logger = get_logger(__name__) @@ -988,7 +989,9 @@ class TracedModuleBuilder(NodeMixin): if k not in TracedModuleBuilder.__builder_attributes__: if isinstance(v, TracedModuleBuilder): v = v.build() - setattr(traced_module, k, v) + setattr(traced_module, k, v) + elif isinstance(v, RawTensor): + setattr(traced_module, k, v) if isinstance(self._mod, QATModule): unset_module_tracing() @@ -1146,7 +1149,16 @@ class TracedModuleBuilder(NodeMixin): if id(attr) in active_module_tracer().id2name: full_name = active_module_tracer().id2name[id(attr)] - + if isinstance(attr, (List, Dict)): + unset_module_tracing() + has_module, m_container = replace_container_with_module_container(attr) + if m_container: + attr = m_container + if has_module and not m_container: + raise ValueError( + "Can not trace the module that uses the same container to store Module and Non-Module objects " + ) + set_module_tracing() if isinstance(attr, Module): attr = TracedModuleBuilder(attr) @@ -1178,17 +1190,22 @@ class TracedModuleBuilder(NodeMixin): return object.__getattribute__(self, name) else: wrapped = object.__getattribute__(self, name) + class_members = dict(inspect.getmembers(self.__class__)) if name in self._mod.__dict__: mod_attr = getattr(self._mod, name) - - if not isinstance(mod_attr, Module) and wrapped is not mod_attr: - wrapped = mod_attr - setattr(self, name, wrapped) - - if isinstance(mod_attr, Module): - assert mod_attr is wrapped._mod + if name in class_members: + if ( + not isinstance(wrapped, TracedModuleBuilder) + and wrapped is not mod_attr + ): + wrapped = self.__getattr__(name) + + if isinstance(wrapped, TracedModuleBuilder): + if not isinstance(mod_attr, (List, Dict)): + assert mod_attr is wrapped._mod else: assert mod_attr is wrapped + full_name = None if id(mod_attr) in active_module_tracer().id2name: full_name = active_module_tracer().id2name[id(mod_attr)] @@ -1679,7 +1696,6 @@ def _register_all_builtin_module(): isclass(m[1]) and issubclass(m[1], M.Module) and m[1] is not M.Sequential - and m[1] is not M.ModuleList ): module_tracer.register_as_builtin(m[1]) diff --git a/imperative/python/megengine/experimental/traced_module/utils.py b/imperative/python/megengine/experimental/traced_module/utils.py new file mode 100644 index 00000000..dffb7da2 --- /dev/null +++ b/imperative/python/megengine/experimental/traced_module/utils.py @@ -0,0 +1,186 @@ +# -*- coding: utf-8 -*- +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +import copy +from collections.abc import MutableMapping, MutableSequence +from typing import Dict, Iterable, List, Optional, Sequence + +from ...module import Module + + +def replace_container_with_module_container(container): + has_module = False + module_container = None + if isinstance(container, Dict): + m_dic = copy.copy(container) + for key, value in container.items(): + if isinstance(value, Module): + has_module = True + elif isinstance(value, (List, Dict)): + ( + _has_module, + _module_container, + ) = replace_container_with_module_container(value) + m_dic[key] = _module_container + if _has_module: + has_module = True + if not all(isinstance(v, Module) for v in m_dic.values()): + return has_module, None + else: + return has_module, _ModuleDict(m_dic) + elif isinstance(container, List): + m_list = copy.copy(container) + for ind, value in enumerate(container): + if isinstance(value, Module): + has_module = True + elif isinstance(value, (List, Dict)): + ( + _has_module, + _module_container, + ) = replace_container_with_module_container(value) + m_list[ind] = _module_container + if _has_module: + has_module = True + if not all(isinstance(v, Module) for v in m_list): + return has_module, None + else: + return has_module, _ModuleList(m_list) + return has_module, module_container + + +class _ModuleList(Module, MutableSequence): + r""" + A List-like container. + + Using a ``ModuleList``, one can visit, add, delete and modify submodules + just like an ordinary python list. + """ + + def __init__(self, modules: Optional[Iterable[Module]] = None): + super().__init__() + self._size = 0 + if modules is None: + return + for mod in modules: + self.append(mod) + + @classmethod + def _ikey(cls, idx): + return "{}".format(idx) + + def _check_idx(self, idx): + L = len(self) + if idx < 0: + idx = L + idx + if idx < 0 or idx >= L: + raise IndexError("list index out of range") + return idx + + def __getitem__(self, idx: int): + if isinstance(idx, slice): + idx = range(self._size)[idx] + if not isinstance(idx, Sequence): + idx = [ + idx, + ] + rst = [] + for i in idx: + i = self._check_idx(i) + key = self._ikey(i) + try: + rst.append(getattr(self, key)) + except AttributeError: + raise IndexError("list index out of range") + return rst if len(rst) > 1 else rst[0] + + def __setitem__(self, idx: int, mod: Module): + if not isinstance(mod, Module): + raise ValueError("invalid sub-module") + idx = self._check_idx(idx) + setattr(self, self._ikey(idx), mod) + + def __delitem__(self, idx): + idx = self._check_idx(idx) + L = len(self) + for orig_idx in range(idx + 1, L): + new_idx = orig_idx - 1 + self[new_idx] = self[orig_idx] + delattr(self, self._ikey(L - 1)) + self._size -= 1 + + def __len__(self): + return self._size + + def insert(self, idx, mod: Module): + assert isinstance(mod, Module) + L = len(self) + if idx < 0: + idx = L - idx + # clip idx to (0, L) + if idx > L: + idx = L + elif idx < 0: + idx = 0 + + for new_idx in range(L, idx, -1): + orig_idx = new_idx - 1 + key = self._ikey(new_idx) + setattr(self, key, self[orig_idx]) + + key = self._ikey(idx) + setattr(self, key, mod) + self._size += 1 + + def forward(self): + raise RuntimeError("ModuleList is not callable") + + +class _ModuleDict(Module, MutableMapping): + r""" + A Dict-like container. + + Using a ``ModuleDict``, one can visit, add, delete and modify submodules + just like an ordinary python dict. + """ + + def __init__(self, modules: Optional[Dict[str, Module]] = None): + super().__init__() + self._size = 0 + if modules is not None: + self.update(modules) + + def __delitem__(self, key): + delattr(self, key) + self._size -= 1 + + def __getitem__(self, key): + return getattr(self, key) + + def __setitem__(self, key, value): + if not isinstance(value, Module): + raise ValueError("invalid sub-module") + setattr(self, key, value) + self._size += 1 + + def __iter__(self): + return iter(self.keys()) + + def __len__(self): + return self._size + + def items(self): + return dict(self.named_children()).items() + + def values(self): + return dict(self.named_children()).values() + + def keys(self): + return dict(self.named_children()).keys() + + def forward(self): + raise RuntimeError("ModuleList is not callable") diff --git a/imperative/python/test/unit/traced_module/test_trace_module.py b/imperative/python/test/unit/traced_module/test_trace_module.py index d4475ecd..5ccf935d 100644 --- a/imperative/python/test/unit/traced_module/test_trace_module.py +++ b/imperative/python/test/unit/traced_module/test_trace_module.py @@ -1,11 +1,11 @@ import numpy as np +import megengine.module as M from megengine import Tensor -from megengine.experimental.traced_module import trace_module -from megengine.module import Module as M +from megengine.experimental.traced_module import TracedModule, trace_module -class MyModule1(M): +class MyModule1(M.Module): def forward(self, x): y = Tensor(x) y += 1 @@ -13,7 +13,7 @@ class MyModule1(M): return x, y -class MyModule2(M): +class MyModule2(M.Module): def forward(self, x): y = Tensor([1, x, 1]) y += 1 @@ -21,6 +21,23 @@ class MyModule2(M): return x, y +class MyModule3(M.Module): + def __init__(self): + super().__init__() + self.modules = [ + M.Elemwise("ADD"), + M.Elemwise("ADD"), + {"a": M.Elemwise("ADD"), "b": M.Elemwise("ADD")}, + ] + + def forward(self, a, b): + x = self.modules[0](a, b) + y = self.modules[1](a, b) + y = self.modules[2]["a"](x, y) + y = self.modules[2]["b"](x, y) + return y + + def test_trace_module(): x = Tensor(1) @@ -40,3 +57,13 @@ def test_trace_module(): for a, b in zip(output1, gt1): np.testing.assert_equal(a.numpy(), b.numpy()) + + a, b = Tensor(1), Tensor(2) + m3 = MyModule3() + gt = m3(a, b) + tm3 = trace_module(m3, a, b) + out = tm3(a, b) + np.testing.assert_equal(out.numpy(), gt.numpy()) + assert isinstance(tm3.modules.__dict__["0"], M.Elemwise) + assert isinstance(tm3.modules.__dict__["2"], TracedModule) + assert isinstance(tm3.modules.__dict__["2"].a, M.Elemwise)