|
|
|
@@ -6,8 +6,6 @@ |
|
|
|
# Unless required by applicable law or agreed to in writing, |
|
|
|
# software distributed under the License is distributed on an |
|
|
|
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
import os |
|
|
|
import tempfile |
|
|
|
from collections import OrderedDict |
|
|
|
from io import BytesIO |
|
|
|
|
|
|
|
@@ -29,7 +27,9 @@ from megengine.module import ( |
|
|
|
Sequential, |
|
|
|
Softmax, |
|
|
|
) |
|
|
|
from megengine.module.module import _access_structure |
|
|
|
from megengine.quantization.quantize import quantize, quantize_qat |
|
|
|
from megengine.utils.module_utils import get_expand_structure, set_expand_structure |
|
|
|
|
|
|
|
|
|
|
|
class MLP(Module): |
|
|
|
@@ -45,146 +45,6 @@ class MLP(Module): |
|
|
|
return x |
|
|
|
|
|
|
|
|
|
|
|
def has_gpu(num=1): |
|
|
|
try: |
|
|
|
mgb.comp_node("gpu{}".format(num - 1)) |
|
|
|
except mgb.MegBrainError: |
|
|
|
return False |
|
|
|
|
|
|
|
return True |
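
# A usage sketch (assumption: illustrative only, not part of the original
# suite): has_gpu pairs with pytest's skipif marker to gate GPU-only tests.
#
#     @pytest.mark.skipif(not has_gpu(2), reason="requires at least 2 GPUs")
#     def test_multi_gpu_feature():
#         ...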
|
|
|
|
|
|
|
|
|
|
|
def randomNp(*args): |
|
|
|
for arg in args: |
|
|
|
assert isinstance(arg, int) |
|
|
|
return np.random.random(args) |
|
|
|
|
|
|
|
|
|
|
|
def randomTorch(*args): |
|
|
|
import torch # pylint: disable=import-outside-toplevel |
|
|
|
|
|
|
|
for arg in args: |
|
|
|
assert isinstance(arg, int) |
|
|
|
return torch.tensor(randomNp(*args), dtype=torch.float32) |
|
|
|
|
|
|
|
|
|
|
|
def graph_mode(*modes): |
|
|
|
if not set(modes).issubset({"eager", "static"}): |
|
|
|
raise ValueError("graph mode must be in (eager, static)") |
|
|
|
|
|
|
|
def decorator(func): |
|
|
|
def wrapper(*args, **kwargs): |
|
|
|
if "eager" in set(modes): |
|
|
|
func(*args, **kwargs) |
|
|
|
if "static" in set(modes): |
|
|
|
with Graph() as cg: |
|
|
|
cg.set_option("eager_evaluation", False) |
|
|
|
func(*args, **kwargs) |
|
|
|
|
|
|
|
return wrapper |
|
|
|
|
|
|
|
return decorator |
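
# A minimal usage sketch (assumption: illustrative only, not collected by
# pytest):
#
#     @graph_mode("eager", "static")
#     def check_op():
#         ...  # body runs once eagerly and once inside a static Graph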
|
|
|
|
|
|
|
|
|
|
|
def _default_compare_fn(x, y): |
|
|
|
np.testing.assert_allclose(x.numpy(), y, rtol=1e-6) |
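
# opr_test (below) accepts a custom comparator for ops that need looser
# tolerances; a sketch (assumption: illustrative only):
#
#     def loose_compare(x, y):
#         np.testing.assert_allclose(x.numpy(), y, rtol=1e-4, atol=1e-6)
#
#     opr_test(cases, func, compare_fn=loose_compare)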
|
|
|
|
|
|
|
|
|
|
|
def opr_test( |
|
|
|
cases, |
|
|
|
func, |
|
|
|
mode=("eager", "static", "dynamic_shape"), |
|
|
|
compare_fn=_default_compare_fn, |
|
|
|
ref_fn=None, |
|
|
|
**kwargs |
|
|
|
): |
|
|
|
""" |
|
|
|
    cases: a list of dict cases; each dict must have an "input" entry and,
        when ref_fn is None, an "output" entry. Use lists for multiple
        inputs or outputs. Exactly two cases are required for the
        dynamic-shape test.
    func: the operator function under test; must be callable.
    mode: the test modes to run, a subset of
        ("eager", "static", "dynamic_shape").
    compare_fn: compares an actual result with the expected one;
        np.testing.assert_allclose is used if None.
    ref_fn: generates the expected data from the inputs; if None, each case
        must carry an explicit "output".
    kwargs: additional keyword arguments passed through to the opr func.

    A simple example:
|
|
|
|
|
|
|
dtype = np.float32 |
|
|
|
cases = [{"input": [10, 20]}, {"input": [20, 30]}] |
|
|
|
opr_test(cases, |
|
|
|
F.eye, |
|
|
|
ref_fn=lambda n, m: np.eye(n, m).astype(dtype), |
|
|
|
dtype=dtype) |
|
|
|
|
|
|
|
""" |
|
|
|
|
|
|
|
def check_results(results, expected): |
|
|
|
        if not isinstance(results, tuple):
|
|
|
results = (results,) |
|
|
|
for r, e in zip(results, expected): |
|
|
|
compare_fn(r, e) |
|
|
|
|
|
|
|
def get_trace_fn(func, enabled, symbolic): |
|
|
|
jit.trace.enabled = enabled |
|
|
|
return jit.trace(func, symbolic=symbolic) |
|
|
|
|
|
|
|
def get_param(cases, idx): |
|
|
|
case = cases[idx] |
|
|
|
inp = case.get("input", None) |
|
|
|
outp = case.get("output", None) |
|
|
|
if inp is None: |
|
|
|
raise ValueError("the test case should have input") |
|
|
|
        if not isinstance(inp, list):
|
|
|
inp = (inp,) |
|
|
|
else: |
|
|
|
inp = tuple(inp) |
|
|
|
if ref_fn is not None and callable(ref_fn): |
|
|
|
outp = ref_fn(*inp) |
|
|
|
if outp is None: |
|
|
|
raise ValueError("the test case should have output or reference function") |
|
|
|
        if not isinstance(outp, list):
|
|
|
outp = (outp,) |
|
|
|
else: |
|
|
|
outp = tuple(outp) |
|
|
|
|
|
|
|
return inp, outp |
|
|
|
|
|
|
|
if not set(mode).issubset({"eager", "static", "dynamic_shape"}): |
|
|
|
raise ValueError("opr test mode must be in (eager, static, dynamic_shape)") |
|
|
|
|
|
|
|
if len(cases) == 0: |
|
|
|
raise ValueError("should give one case at least") |
|
|
|
|
|
|
|
if "dynamic_shape" in set(mode): |
|
|
|
if len(cases) != 2: |
|
|
|
raise ValueError("should give 2 cases for dynamic shape test") |
|
|
|
|
|
|
|
if not callable(func): |
|
|
|
raise ValueError("the input func should be callable") |
|
|
|
|
|
|
|
inp, outp = get_param(cases, 0) |
|
|
|
|
|
|
|
def run(*args, **kwargs): |
|
|
|
return func(*args, **kwargs) |
|
|
|
|
|
|
|
if "eager" in set(mode): |
|
|
|
f = get_trace_fn(run, False, False) |
|
|
|
results = f(*inp, **kwargs) |
|
|
|
check_results(results, outp) |
|
|
|
|
|
|
|
if "static" in set(mode) or "dynamic_shape" in set(mode): |
|
|
|
f = get_trace_fn(run, True, True) |
|
|
|
results = f(*inp, **kwargs) |
|
|
|
check_results(results, outp) |
|
|
|
if "dynamic_shape" in set(mode): |
|
|
|
inp, outp = get_param(cases, 1) |
|
|
|
results = f(*inp, **kwargs) |
|
|
|
check_results(results, outp) |
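
# A dynamic-shape sketch (assumption: illustrative only, with F standing for
# megengine.functional as elsewhere in the suite): two cases with different
# input shapes exercise the traced function with a second shape.
#
#     cases = [
#         {"input": [np.random.random((4, 6)).astype("float32")]},
#         {"input": [np.random.random((8, 3)).astype("float32")]},
#     ]
#     opr_test(cases, F.relu, ref_fn=lambda x: np.maximum(x, 0))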
|
|
|
|
|
|
|
|
|
|
|
class MyModule(Module): |
|
|
|
class InnerModule(Module): |
|
|
|
def __init__(self): |
|
|
|
@@ -306,13 +166,13 @@ def test_module_api_hooks(): |
|
|
|
post_hook_num = 0 |
|
|
|
hooks = [] |
|
|
|
|
|
|
|
|
|
|
def pre_hook(_, inputs): |
|
|
|
nonlocal pre_hook_num |
|
|
|
pre_hook_num += 1 |
|
|
|
modified_inputs = tuple(inp + 1 for inp in inputs) |
|
|
|
return modified_inputs |
|
|
|
|
|
|
|
|
|
|
def post_hook(_, __, outputs): |
|
|
|
nonlocal post_hook_num |
|
|
|
post_hook_num += 1 |
|
|
|
outputs += 1 |
|
|
|
@@ -376,7 +236,7 @@ class MyModule2(Module): |
|
|
|
|
|
|
|
def test_expand_structure(): |
|
|
|
m = MyModule2() |
|
|
|
|
|
|
rst = [ |
|
|
|
("", m), |
|
|
|
("a.0", m.a[0]), |
|
|
|
("a.1.x", m.a[1]["x"]), |
|
|
|
@@ -387,6 +247,16 @@ def test_expand_structure(): |
|
|
|
("a.2.0.bn", m.a[2][0].bn), |
|
|
|
("bn", m.bn), |
|
|
|
] |
|
|
|
assert list(m.named_modules()) == rst |
|
|
|
|
|
|
|
for item in rst[1:]: |
|
|
|
assert get_expand_structure(m, item[0]) == item[1] |
|
|
|
|
|
|
|
for item in reversed(rst[1:]): |
|
|
|
if _access_structure(m, item[0], lambda p, k, o: isinstance(p, tuple)): |
|
|
|
continue |
|
|
|
set_expand_structure(m, item[0], "TEST_VALUE") |
|
|
|
assert get_expand_structure(m, item[0]) == "TEST_VALUE" |
|
|
|
|
|
|
|
|
|
|
|
def test_flatten_others(): |
|
|
|
@@ -603,21 +473,6 @@ def test_pickle_module(): |
|
|
|
np.testing.assert_allclose(pred0.numpy(), pred2.numpy(), atol=5e-6) |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.skip(reason="under development") |
|
|
|
def test_dump_model(): |
|
|
|
data_shape = (2, 28) |
|
|
|
data = Tensor(np.random.random(data_shape)) |
|
|
|
mlp = MLP() |
|
|
|
pred = mlp(data) |
|
|
|
f = tempfile.NamedTemporaryFile(delete=False) |
|
|
|
f_name = f.name |
|
|
|
try: |
|
|
|
mge.dump(pred, f_name) |
|
|
|
finally: |
|
|
|
f.close() |
|
|
|
os.unlink(f_name) |
|
|
|
|
|
|
|
|
|
|
|
def test_load_quantized(): |
|
|
|
from megengine.core.tensor import dtype |
|
|
|
|
|
|
|
|