GitOrigin-RevId: 0d6bb20b2b
tags/v1.6.0
@@ -201,7 +201,8 @@ class Apply(Expr):
                 NodeMixin.wrap_safe(i, Constant.make(i))
         apply_node = cls.make(opdef)
         for i in inputs:
-            apply_node.add_input(NodeMixin.get(i))
+            assert isinstance(i, RawTensor)
+            apply_node.inputs.append(NodeMixin.get(i))
         unset_module_tracing()
         outputs = apply(opdef, *inputs)
@@ -1,3 +1,13 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import collections
+
 from typing import Callable, NamedTuple

 SUPPORTED_TYPE = {}
@@ -9,11 +19,22 @@ def register_supported_type(type, flatten, unflatten):
     SUPPORTED_TYPE[type] = NodeType(flatten, unflatten)


+def _dict_flatten(inp):
+    aux_data = []
+    results = []
+    for key, value in sorted(inp.items()):
+        results.append(value)
+        aux_data.append(key)
+    return results, aux_data
+
+
+def _dict_unflatten(inps, aux_data):
+    return dict(zip(aux_data, inps))
+
+
 register_supported_type(list, lambda x: (x, None), lambda x, aux_data: list(x))
 register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: list(x))
-register_supported_type(
-    dict, lambda x: (list(x.values()), list(x.keys())), lambda x, y: dict(zip(y, x))
-)
+register_supported_type(dict, _dict_flatten, _dict_unflatten)
 register_supported_type(
     slice,
     lambda x: ([x.start, x.stop, x.step], None),
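A minimal, self-contained sketch of the round trip the new helpers provide (standalone copies, runnable without MegEngine). Unlike the replaced `list(x.values())`/`list(x.keys())` lambdas, which depended on dict insertion order, sorting by key makes the flattened leaf layout deterministic:

```python
# Hedged sketch: standalone reimplementations of _dict_flatten/_dict_unflatten.
def _dict_flatten(inp):
    aux_data = []
    results = []
    for key, value in sorted(inp.items()):  # sorted -> deterministic leaf order
        results.append(value)
        aux_data.append(key)
    return results, aux_data

def _dict_unflatten(inps, aux_data):
    return dict(zip(aux_data, inps))

leaves, keys = _dict_flatten({"b": 2, "a": 1})
assert leaves == [1, 2] and keys == ["a", "b"]  # insertion order no longer matters
assert _dict_unflatten(leaves, keys) == {"a": 1, "b": 2}
```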
@@ -68,6 +89,8 @@ class TreeDef:
 class LeafDef(TreeDef):
     def __init__(self, type):
+        if not isinstance(type, collections.abc.Sequence):
+            type = (type,)
         super().__init__(type, None, [])
         self.num_leaves = 1
@@ -77,4 +100,4 @@ class LeafDef(TreeDef):
         return leaves[0]

     def __repr__(self):
-        return "Leaf({})".format(self.type.__name__)
+        return "Leaf({})".format(", ".join(t.__name__ for t in self.type))
@@ -14,6 +14,7 @@ import megengine as mge
 import megengine.autodiff as ad
 import megengine.functional as F
 from megengine import Tensor
+from megengine.experimental.traced_module import trace_module
 from megengine.module import Linear, Module
 from megengine.optimizer import SGD
@@ -71,8 +72,13 @@ class XORNet(Module):
         return x


-def test_training_converge():
+@pytest.mark.parametrize("test_traced_module", [True, False])
+def test_training_converge(test_traced_module):
     net = XORNet()
+    if test_traced_module:
+        inp = Tensor(np.random.random((14, 2)))
+        net = trace_module(net, inp)
     opt = SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
     gm = ad.GradManager().attach(net.parameters())
@@ -105,9 +111,8 @@ def test_training_converge():
     xx = xx.reshape((ngrid * ngrid, 1))
     yy = yy.reshape((ngrid * ngrid, 1))
     data = mge.tensor(np.concatenate((xx, yy), axis=1).astype(np.float32))
-    pred = infer(data).numpy()
-    precision = calculate_precision(data.numpy(), pred)
+    pred = infer(data)
+    precision = calculate_precision(data.numpy(), pred.numpy())
     assert precision == 1.0, "Test precision must be high enough, get {}".format(
         precision
     )
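The parametrization pattern used in these convergence tests, condensed into a hedged sketch (`XORNet` is assumed from the test file): `trace_module` returns a module that still exposes `parameters()`, so the identical optimizer and GradManager setup covers both the eager and the traced variant:

```python
import numpy as np
import pytest

import megengine.autodiff as ad
from megengine import Tensor
from megengine.experimental.traced_module import trace_module
from megengine.optimizer import SGD

@pytest.mark.parametrize("test_traced_module", [True, False])
def test_training_converge(test_traced_module):
    net = XORNet()  # assumed: defined earlier in the test file
    if test_traced_module:
        # tracing needs one example input; the traced net keeps its parameters
        net = trace_module(net, Tensor(np.random.random((14, 2))))
    opt = SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
    gm = ad.GradManager().attach(net.parameters())
    ...  # training loop unchanged from the eager case
```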
@@ -15,6 +15,7 @@ import megengine.autodiff as ad
 import megengine.functional as F
 import megengine.optimizer as optim
 from megengine import Tensor
+from megengine.experimental.traced_module import trace_module
 from megengine.jit import trace
 from megengine.module import Linear, Module
 from megengine.optimizer import SGD
@@ -73,8 +74,12 @@ class XORNet(Module):
         return x


-def test_training_converge():
+@pytest.mark.parametrize("test_traced_module", [True, False])
+def test_training_converge(test_traced_module):
     net = XORNet()
+    if test_traced_module:
+        inp = Tensor(np.random.random((14, 2)))
+        net = trace_module(net, inp)
     opt = SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
     gm = ad.GradManager().attach(net.parameters())
@@ -110,9 +115,8 @@ def test_training_converge():
     xx = xx.reshape((ngrid * ngrid, 1))
     yy = yy.reshape((ngrid * ngrid, 1))
     data = mge.tensor(np.concatenate((xx, yy), axis=1).astype(np.float32))
-    pred = infer(data).numpy()
-    precision = calculate_precision(data.numpy(), pred)
+    pred = infer(data)
+    precision = calculate_precision(data.numpy(), pred.numpy())
     print("precision=", precision)
     assert precision == 1.0, "Test precision must be high enough, get {}".format(
         precision
@@ -19,6 +19,7 @@ import megengine.module as M
 import megengine.optimizer as optim
 from megengine import tensor
 from megengine.autodiff import GradManager
+from megengine.experimental.traced_module import trace_module
 from megengine.jit import trace
@@ -15,6 +15,7 @@ import pytest
 import megengine as mge
 import megengine.functional as F
 from megengine import Parameter, Tensor, tensor
+from megengine.experimental.traced_module import TracedModule, trace_module
 from megengine.module import (
     BatchNorm1d,
     BatchNorm2d,
@@ -67,8 +68,18 @@ class MyModule(Module):
         return x


-def test_module_api():
+@pytest.mark.parametrize("test_traced_module", [True, False])
+def test_module_api(test_traced_module):
     m = MyModule()
+    if test_traced_module:
+        buff = m.buff
+        param = m.param
+        m = trace_module(m, Tensor(np.random.random((1, 4, 16, 16))))
+        assert "buff" not in m.__dict__
+        assert "param" not in m.__dict__
+        m.buff = buff
+        m.param = param
     assert list(m.children()) == [m.bn, m.i]
     assert list(m.named_children()) == [("bn", m.bn), ("i", m.i)]
     assert list(m.modules()) == [m, m.bn, m.i, m.i.bn]
@@ -141,8 +152,11 @@ def test_module_api():
     assert m.bn.training == False and m.i.bn.training == False


-def test_module_api_reuse_submodule():
+@pytest.mark.parametrize("test_traced_module", [True, False])
+def test_module_api_reuse_submodule(test_traced_module):
     m = MyModule()
+    if test_traced_module:
+        m = trace_module(m, Tensor(np.random.random((1, 4, 16, 16))))
     m.h = m.i  # pylint: disable=attribute-defined-outside-init
     assert list(m.modules()) == [m, m.bn, m.i, m.i.bn]
     assert list(m.named_modules()) == [
@@ -153,15 +167,21 @@ def test_module_api_reuse_submodule():
     ]


-def test_module_api_iterable_stability():
+@pytest.mark.parametrize("test_traced_module", [True, False])
+def test_module_api_iterable_stability(test_traced_module):
     m = MyModule()
+    if test_traced_module:
+        m = trace_module(m, Tensor(np.random.random((1, 4, 16, 16))))
     l = list(m.modules())
     for _ in range(100):
         assert list(m.modules()) == l


-def test_module_api_hooks():
+@pytest.mark.parametrize("test_traced_module", [True, False])
+def test_module_api_hooks(test_traced_module):
     net = MyModule()
+    if test_traced_module:
+        net = trace_module(net, Tensor(np.zeros((1, 4, 1, 1))))
     pre_hook_num = 0
     post_hook_num = 0
     hooks = []
@@ -383,11 +403,16 @@ class Simple(Module):
         self.conv1.weight = self.conv0.weight

     def forward(self, inputs):
-        pass
+        x = self.conv0(inputs)
+        y = self.conv1(inputs)
+        return x + y


-def test_shared_param():
+@pytest.mark.parametrize("test_traced_module", [True, False])
+def test_shared_param(test_traced_module):
     net = Simple()
+    if test_traced_module:
+        net = trace_module(net, tensor(np.random.random((1, 1, 8, 8))))
     assert net.conv0.weight is net.conv1.weight
     data = tensor(np.random.random((1, 1, 8, 8)).astype(np.float32))
     np.testing.assert_allclose(net.conv0(data).numpy(), net.conv1(data).numpy())
@@ -449,15 +474,21 @@ def test_shared_param_1d():
     np.testing.assert_allclose(conv0(data).numpy(), conv1(data).numpy())


-def test_pickle_module():
+@pytest.mark.parametrize("test_traced_module", [True, False])
+def test_pickle_module(test_traced_module):
     data_shape = (2, 28)
     data = tensor(np.random.random(data_shape))
     mlp = MLP()
+    pred_gt = mlp(data)
+    if test_traced_module:
+        mlp = trace_module(mlp, data)
     # pickle before forward
     with BytesIO() as fout:
         mge.save(mlp, fout)
         fout.seek(0)
         mlp1 = mge.load(fout)
+        if test_traced_module:
+            assert type(mlp1) == TracedModule
         pred0 = mlp1(data)

     pred1 = mlp(data)
@@ -467,8 +498,11 @@ def test_pickle_module():
         mge.save(mlp, fout)
         fout.seek(0)
         mlp1 = mge.load(fout)
+        if test_traced_module:
+            assert type(mlp1) == TracedModule
         pred2 = mlp1(data)

+    np.testing.assert_allclose(pred_gt.numpy(), pred1.numpy(), atol=5e-6)
     np.testing.assert_allclose(pred0.numpy(), pred1.numpy(), atol=5e-6)
     np.testing.assert_allclose(pred0.numpy(), pred2.numpy(), atol=5e-6)
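Condensed from the test above, a hedged sketch of the serialization round trip (`MLP` is assumed from the test file): `mge.save`/`mge.load` go through pickle, and a `TracedModule` comes back as a `TracedModule` with unchanged outputs:

```python
from io import BytesIO

import numpy as np

import megengine as mge
from megengine import tensor
from megengine.experimental.traced_module import TracedModule, trace_module

data = tensor(np.random.random((2, 28)))
mlp = trace_module(MLP(), data)  # assumed: MLP defined earlier in the file
with BytesIO() as fout:
    mge.save(mlp, fout)  # pickle-based save
    fout.seek(0)
    restored = mge.load(fout)
assert type(restored) == TracedModule
np.testing.assert_allclose(restored(data).numpy(), mlp(data).numpy(), atol=5e-6)
```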
@@ -0,0 +1,59 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import io
+
+import numpy as np
+
+import megengine.functional as F
+import megengine.module as M
+import megengine.utils.comp_graph_tools as cgtools
+from megengine.experimental.traced_module import trace_module
+from megengine.jit import trace
+from megengine.module import Module
+
+
+class MyBlock(Module):
+    def __init__(self, in_channels, channels):
+        super(MyBlock, self).__init__()
+        self.conv1 = M.Conv2d(in_channels, channels, 3, 1, padding=1, bias=False)
+        self.bn1 = M.BatchNorm2d(channels)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = F.relu(x) + 1
+        return x
+
+
+class MyModule(Module):
+    def __init__(self):
+        super(MyModule, self).__init__()
+        self.block0 = MyBlock(8, 4)
+        self.block1 = MyBlock(4, 2)
+
+    def forward(self, x):
+        x = self.block0(x)
+        x = self.block1(x)
+        return x
+
+
+def test_jit_trace():
+    module = MyModule()
+    module.eval()
+    x = F.ones((1, 8, 14, 14))
+    expect = module(x)
+    traced_module = trace_module(module, x)
+    func = trace(traced_module, capture_as_const=True)
+    np.testing.assert_array_equal(func(x), expect)
+    model = io.BytesIO()
+    func.dump(model)
+    model.seek(0)
+    infer_cg = cgtools.GraphInference(model)
+    np.testing.assert_allclose(
+        list(infer_cg.run(x.numpy()).values())[0], expect, atol=1e-6
+    )
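The pipeline this new test exercises generalizes beyond `MyModule`. A hedged helper sketch (the `dump_for_inference` name is illustrative, not part of MegEngine) that traces any eager module and returns a `GraphInference` runner for the dumped graph:

```python
import io

import megengine.utils.comp_graph_tools as cgtools
from megengine.experimental.traced_module import trace_module
from megengine.jit import trace

def dump_for_inference(module, example_input):
    """Illustrative helper: Module -> TracedModule -> dumped graph -> runner."""
    module.eval()
    traced = trace_module(module, example_input)
    func = trace(traced, capture_as_const=True)
    func(example_input)  # run once so the trace is recorded before dumping
    buf = io.BytesIO()
    func.dump(buf)
    buf.seek(0)
    return cgtools.GraphInference(buf)

# Usage, mirroring test_jit_trace above:
# runner = dump_for_inference(MyModule(), F.ones((1, 8, 14, 14)))
# outputs = runner.run(np.ones((1, 8, 14, 14), dtype="float32"))
```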