Browse Source

feat(mge/utils): add get/set_expand_structure to deal with complex key

GitOrigin-RevId: 4d1b952068
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
5a38ad3920
5 changed files with 87 additions and 176 deletions
  1. +23
    -3
      imperative/python/megengine/module/module.py
  2. +2
    -2
      imperative/python/megengine/module/sequential.py
  3. +4
    -11
      imperative/python/megengine/quantization/quantize.py
  4. +43
    -0
      imperative/python/megengine/utils/module_utils.py
  5. +15
    -160
      imperative/python/test/unit/module/test_module.py

+ 23
- 3
imperative/python/megengine/module/module.py View File

@@ -21,9 +21,9 @@ from ..utils.naming import auto_naming
logger = get_logger(__name__)


def _expand_structure(key, obj):
def _expand_structure(prefix, obj):
if isinstance(obj, (Tensor, Module)):
return [(key, obj)]
return [(prefix, obj)]
elif isinstance(obj, (list, tuple, dict)):
ret = []
if isinstance(obj, dict):
@@ -37,12 +37,32 @@ def _expand_structure(key, obj):
"keys for Tensor and Module must be str, error key: {}".format(k)
)
for kt, vt in sub_ret:
ret.extend([(key + "." + kt, vt)])
ret.extend([(prefix + "." + kt, vt)])
return ret
else:
return []


def _access_structure(obj, key, callback=None):
key_list = key.split(".")
cur = obj
parent = None
for k in key_list:
parent = cur
if isinstance(cur, (Tensor, Module)):
cur = getattr(cur, k)
elif isinstance(cur, (list, tuple)):
k = int(k)
cur = cur[k]
elif isinstance(cur, dict):
cur = cur[k]
else:
raise ValueError(
"Unsupport value type {} to access attribute".format(type(cur))
)
return callback(parent, k, cur)


def _is_parameter(obj):
return isinstance(obj, Parameter)



+ 2
- 2
imperative/python/megengine/module/sequential.py View File

@@ -18,9 +18,9 @@ class Sequential(Module):
Alternatively, an ordered dict of modules can also be passed in.

To make it easier to understand, here is a small example:
Examples:
.. testcode::

import numpy as np


+ 4
- 11
imperative/python/megengine/quantization/quantize.py View File

@@ -7,7 +7,7 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from copy import copy, deepcopy
from functools import partial
from typing import Callable, Dict, Tuple
from typing import Callable

import numpy as np

@@ -19,6 +19,7 @@ from ..module import quantized as Quantized
from ..module.qat import QATModule
from ..module.quantized import QuantizedModule
from ..tensor import Tensor
from ..utils.module_utils import set_expand_structure
from .qconfig import QConfig, ema_fakequant_qconfig


@@ -79,11 +80,7 @@ def quantize(module: Module, inplace: bool = True, mapping: dict = None):
module._flatten(with_key=True, with_parent=True, predicate=is_qat)
):
new_mod = convert_dict[type(submodule)].from_qat_module(submodule)
if isinstance(parent, Float.Sequential):
# cannnot use setattr to be compatible with Sequential's ``__setitem__``
parent[int(key.split(".")[-1])] = new_mod
else:
setattr(parent, key.split(".")[-1], new_mod)
set_expand_structure(parent, key, new_mod)

return module

@@ -126,11 +123,7 @@ def quantize_qat(
continue

new_mod = convert_dict[type(submodule)].from_float_module(submodule)
if isinstance(parent, Float.Sequential):
# cannnot use setattr to be compatible with Sequential's ``__setitem__``
parent[int(key.split(".")[-1])] = new_mod
else:
setattr(parent, key.split(".")[-1], new_mod)
set_expand_structure(parent, key, new_mod)

propagate_qconfig(module, qconfig)
return module


+ 43
- 0
imperative/python/megengine/utils/module_utils.py View File

@@ -0,0 +1,43 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from collections import Iterable

from ..module import Sequential
from ..module.module import Module, _access_structure
from ..tensor import Tensor


def get_expand_structure(obj: Module, key: str):
"""
Gets Module's attribute compatible with complex key from Module's :meth:`~.named_children`.
Supports handling structure containing list or dict.
"""

def f(_, __, cur):
return cur

return _access_structure(obj, key, callback=f)


def set_expand_structure(obj: Module, key: str, value):
"""
Sets Module's attribute compatible with complex key from Module's :meth:`~.named_children`.
Supports handling structure containing list or dict.
"""

def f(parent, key, cur):
if isinstance(parent, (Tensor, Module)):
# cannnot use setattr to be compatible with Sequential's ``__setitem__``
if isinstance(cur, Sequential):
parent[int(key)] = value
else:
setattr(parent, key, value)
else:
parent[key] = value

_access_structure(obj, key, callback=f)

+ 15
- 160
imperative/python/test/unit/module/test_module.py View File

@@ -6,8 +6,6 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import os
import tempfile
from collections import OrderedDict
from io import BytesIO

@@ -29,7 +27,9 @@ from megengine.module import (
Sequential,
Softmax,
)
from megengine.module.module import _access_structure
from megengine.quantization.quantize import quantize, quantize_qat
from megengine.utils.module_utils import get_expand_structure, set_expand_structure


class MLP(Module):
@@ -45,146 +45,6 @@ class MLP(Module):
return x


def has_gpu(num=1):
try:
mgb.comp_node("gpu{}".format(num - 1))
except mgb.MegBrainError:
return False

return True


def randomNp(*args):
for arg in args:
assert isinstance(arg, int)
return np.random.random(args)


def randomTorch(*args):
import torch # pylint: disable=import-outside-toplevel

for arg in args:
assert isinstance(arg, int)
return torch.tensor(randomNp(*args), dtype=torch.float32)


def graph_mode(*modes):
if not set(modes).issubset({"eager", "static"}):
raise ValueError("graph mode must be in (eager, static)")

def decorator(func):
def wrapper(*args, **kwargs):
if "eager" in set(modes):
func(*args, **kwargs)
if "static" in set(modes):
with Graph() as cg:
cg.set_option("eager_evaluation", False)
func(*args, **kwargs)

return wrapper

return decorator


def _default_compare_fn(x, y):
np.testing.assert_allclose(x.numpy(), y, rtol=1e-6)


def opr_test(
cases,
func,
mode=("eager", "static", "dynamic_shape"),
compare_fn=_default_compare_fn,
ref_fn=None,
**kwargs
):
"""
mode: the list of test mode which are eager, static and dynamic_shape
will test all the cases if None.
func: the function to run opr.
compare_fn: the function to compare the result and expected, use np.testing.assert_allclose if None.
ref_fn: the function to generate expected data, should assign output if None.
cases: the list which have dict element, the list length should be 2 for dynamic shape test.
and the dict should have input,
and should have output if ref_fn is None.
should use list for multiple inputs and outputs for each case.
kwargs: The additional kwargs for opr func.

simple examples:

dtype = np.float32
cases = [{"input": [10, 20]}, {"input": [20, 30]}]
opr_test(cases,
F.eye,
ref_fn=lambda n, m: np.eye(n, m).astype(dtype),
dtype=dtype)

"""

def check_results(results, expected):
if not isinstance(results, Tuple):
results = (results,)
for r, e in zip(results, expected):
compare_fn(r, e)

def get_trace_fn(func, enabled, symbolic):
jit.trace.enabled = enabled
return jit.trace(func, symbolic=symbolic)

def get_param(cases, idx):
case = cases[idx]
inp = case.get("input", None)
outp = case.get("output", None)
if inp is None:
raise ValueError("the test case should have input")
if not isinstance(inp, List):
inp = (inp,)
else:
inp = tuple(inp)
if ref_fn is not None and callable(ref_fn):
outp = ref_fn(*inp)
if outp is None:
raise ValueError("the test case should have output or reference function")
if not isinstance(outp, List):
outp = (outp,)
else:
outp = tuple(outp)

return inp, outp

if not set(mode).issubset({"eager", "static", "dynamic_shape"}):
raise ValueError("opr test mode must be in (eager, static, dynamic_shape)")

if len(cases) == 0:
raise ValueError("should give one case at least")

if "dynamic_shape" in set(mode):
if len(cases) != 2:
raise ValueError("should give 2 cases for dynamic shape test")

if not callable(func):
raise ValueError("the input func should be callable")

inp, outp = get_param(cases, 0)

def run(*args, **kwargs):
return func(*args, **kwargs)

if "eager" in set(mode):
f = get_trace_fn(run, False, False)
results = f(*inp, **kwargs)
check_results(results, outp)

if "static" in set(mode) or "dynamic_shape" in set(mode):
f = get_trace_fn(run, True, True)
results = f(*inp, **kwargs)
check_results(results, outp)
if "dynamic_shape" in set(mode):
inp, outp = get_param(cases, 1)
results = f(*inp, **kwargs)
check_results(results, outp)


class MyModule(Module):
class InnerModule(Module):
def __init__(self):
@@ -306,13 +166,13 @@ def test_module_api_hooks():
post_hook_num = 0
hooks = []

def pre_hook(module, inputs):
def pre_hook(_, inputs):
nonlocal pre_hook_num
pre_hook_num += 1
modified_inputs = tuple(inp + 1 for inp in inputs)
return modified_inputs

def post_hook(module, inputs, outputs):
def post_hook(_, __, outputs):
nonlocal post_hook_num
post_hook_num += 1
outputs += 1
@@ -376,7 +236,7 @@ class MyModule2(Module):

def test_expand_structure():
m = MyModule2()
assert list(m.named_modules()) == [
rst = [
("", m),
("a.0", m.a[0]),
("a.1.x", m.a[1]["x"]),
@@ -387,6 +247,16 @@ def test_expand_structure():
("a.2.0.bn", m.a[2][0].bn),
("bn", m.bn),
]
assert list(m.named_modules()) == rst

for item in rst[1:]:
assert get_expand_structure(m, item[0]) == item[1]

for item in reversed(rst[1:]):
if _access_structure(m, item[0], lambda p, k, o: isinstance(p, tuple)):
continue
set_expand_structure(m, item[0], "TEST_VALUE")
assert get_expand_structure(m, item[0]) == "TEST_VALUE"


def test_flatten_others():
@@ -603,21 +473,6 @@ def test_pickle_module():
np.testing.assert_allclose(pred0.numpy(), pred2.numpy(), atol=5e-6)


@pytest.mark.skip(reason="under development")
def test_dump_model():
data_shape = (2, 28)
data = Tensor(np.random.random(data_shape))
mlp = MLP()
pred = mlp(data)
f = tempfile.NamedTemporaryFile(delete=False)
f_name = f.name
try:
mge.dump(pred, f_name)
finally:
f.close()
os.unlink(f_name)


def test_load_quantized():
from megengine.core.tensor import dtype



Loading…
Cancel
Save