GitOrigin-RevId: 0d6bb20b2b
tags/v1.6.0
@@ -201,7 +201,8 @@ class Apply(Expr):
                 NodeMixin.wrap_safe(i, Constant.make(i))
         apply_node = cls.make(opdef)
         for i in inputs:
-            apply_node.add_input(NodeMixin.get(i))
+            assert isinstance(i, RawTensor)
+            apply_node.inputs.append(NodeMixin.get(i))
         unset_module_tracing()
         outputs = apply(opdef, *inputs)
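This hunk swaps an add_input() helper for an explicit type check plus a direct
append, so only concrete RawTensor inputs are ever recorded on an Apply node.
A minimal sketch of the same guard-then-append pattern, with stand-in names
(RawTensorStub, ApplyStub, node_of are hypothetical; the real code uses
RawTensor and NodeMixin.get):

    class RawTensorStub:
        pass

    class ApplyStub:
        def __init__(self):
            self.inputs = []

    def record_inputs(apply_node, tensors, node_of):
        for t in tensors:
            # fail fast: anything that is not a concrete tensor would
            # silently corrupt the traced graph later
            assert isinstance(t, RawTensorStub)
            apply_node.inputs.append(node_of(t))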
@@ -1,3 +1,13 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import collections
 from typing import Callable, NamedTuple

 SUPPORTED_TYPE = {}
@@ -9,11 +19,22 @@ def register_supported_type(type, flatten, unflatten):
     SUPPORTED_TYPE[type] = NodeType(flatten, unflatten)


+def _dict_flatten(inp):
+    aux_data = []
+    results = []
+    for key, value in sorted(inp.items()):
+        results.append(value)
+        aux_data.append(key)
+    return results, aux_data
+
+
+def _dict_unflatten(inps, aux_data):
+    return dict(zip(aux_data, inps))
+
+
 register_supported_type(list, lambda x: (x, None), lambda x, aux_data: list(x))
 register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: list(x))
-register_supported_type(
-    dict, lambda x: (list(x.values()), list(x.keys())), lambda x, y: dict(zip(y, x))
-)
+register_supported_type(dict, _dict_flatten, _dict_unflatten)
 register_supported_type(
     slice,
     lambda x: ([x.start, x.stop, x.step], None),
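The old dict lambdas flattened values in insertion order; the new
_dict_flatten sorts by key, so two dicts with the same items but different
insertion order now produce the same leaf ordering. A quick round-trip sketch
using the two functions from this hunk:

    d = {"b": 2, "a": 1}
    leaves, keys = _dict_flatten(d)
    assert leaves == [1, 2] and keys == ["a", "b"]  # sorted() fixes the order
    assert _dict_unflatten(leaves, keys) == d       # lossless reconstruction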
@@ -68,6 +89,8 @@ class TreeDef:

 class LeafDef(TreeDef):
     def __init__(self, type):
+        if not isinstance(type, collections.abc.Sequence):
+            type = (type,)
         super().__init__(type, None, [])
         self.num_leaves = 1
@@ -77,4 +100,4 @@ class LeafDef(TreeDef):
         return leaves[0]

     def __repr__(self):
-        return "Leaf({})".format(self.type.__name__)
+        return "Leaf({})".format(", ".join(t.__name__ for t in self.type))
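Together with the __init__ change above, self.type is now always a sequence,
so a leaf may advertise several acceptable types. A sketch of the resulting
repr, assuming the LeafDef definition from these hunks:

    leaf = LeafDef((int, float))
    repr(leaf)   # -> "Leaf(int, float)"
    leaf = LeafDef(int)   # a bare type is normalized to (int,)
    repr(leaf)   # -> "Leaf(int)"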
@@ -14,6 +14,7 @@ import megengine as mge
 import megengine.autodiff as ad
 import megengine.functional as F
 from megengine import Tensor
+from megengine.experimental.traced_module import trace_module
 from megengine.module import Linear, Module
 from megengine.optimizer import SGD
@@ -71,8 +72,13 @@ class XORNet(Module):
         return x


-def test_training_converge():
+@pytest.mark.parametrize("test_traced_module", [True, False])
+def test_training_converge(test_traced_module):
     net = XORNet()
+    if test_traced_module:
+        inp = Tensor(np.random.random((14, 2)))
+        net = trace_module(net, inp)
     opt = SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
     gm = ad.GradManager().attach(net.parameters())
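The point of the parametrization: a traced module still behaves like a
Module, so the same training setup works on either object. A usage sketch
condensed from the test (XORNet as defined in this file):

    net = XORNet()
    net = trace_module(net, Tensor(np.random.random((14, 2))))  # trace once
    opt = SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
    gm = ad.GradManager().attach(net.parameters())  # gradients flow as before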
@@ -105,9 +111,8 @@ def test_training_converge():
     xx = xx.reshape((ngrid * ngrid, 1))
     yy = yy.reshape((ngrid * ngrid, 1))
     data = mge.tensor(np.concatenate((xx, yy), axis=1).astype(np.float32))
-    pred = infer(data).numpy()
-    precision = calculate_precision(data.numpy(), pred)
+    pred = infer(data)
+    precision = calculate_precision(data.numpy(), pred.numpy())
     assert precision == 1.0, "Test precision must be high enough, get {}".format(
         precision
     )
@@ -15,6 +15,7 @@ import megengine.autodiff as ad
 import megengine.functional as F
 import megengine.optimizer as optim
 from megengine import Tensor
+from megengine.experimental.traced_module import trace_module
 from megengine.jit import trace
 from megengine.module import Linear, Module
 from megengine.optimizer import SGD
@@ -73,8 +74,12 @@ class XORNet(Module):
         return x


-def test_training_converge():
+@pytest.mark.parametrize("test_traced_module", [True, False])
+def test_training_converge(test_traced_module):
     net = XORNet()
+    if test_traced_module:
+        inp = Tensor(np.random.random((14, 2)))
+        net = trace_module(net, inp)
     opt = SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
     gm = ad.GradManager().attach(net.parameters())
@@ -110,9 +115,8 @@ def test_training_converge():
     xx = xx.reshape((ngrid * ngrid, 1))
     yy = yy.reshape((ngrid * ngrid, 1))
     data = mge.tensor(np.concatenate((xx, yy), axis=1).astype(np.float32))
-    pred = infer(data).numpy()
-    precision = calculate_precision(data.numpy(), pred)
+    pred = infer(data)
+    precision = calculate_precision(data.numpy(), pred.numpy())
     print("precision=", precision)
     assert precision == 1.0, "Test precision must be high enough, get {}".format(
         precision
@@ -19,6 +19,7 @@ import megengine.module as M
 import megengine.optimizer as optim
 from megengine import tensor
 from megengine.autodiff import GradManager
+from megengine.experimental.traced_module import trace_module
 from megengine.jit import trace
@@ -15,6 +15,7 @@ import pytest
 import megengine as mge
 import megengine.functional as F
 from megengine import Parameter, Tensor, tensor
+from megengine.experimental.traced_module import TracedModule, trace_module
 from megengine.module import (
     BatchNorm1d,
     BatchNorm2d,
@@ -67,8 +68,18 @@ class MyModule(Module):
         return x


-def test_module_api():
+@pytest.mark.parametrize("test_traced_module", [True, False])
+def test_module_api(test_traced_module):
     m = MyModule()
+    if test_traced_module:
+        buff = m.buff
+        param = m.param
+        m = trace_module(m, Tensor(np.random.random((1, 4, 16, 16))))
+        assert "buff" not in m.__dict__
+        assert "param" not in m.__dict__
+        m.buff = buff
+        m.param = param
     assert list(m.children()) == [m.bn, m.i]
     assert list(m.named_children()) == [("bn", m.bn), ("i", m.i)]
     assert list(m.modules()) == [m, m.bn, m.i, m.i.bn]
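The new assertions pin down a subtlety: tracing rebuilds the module, so plain
tensor attributes are not carried over and must be re-attached by the caller.
A sketch of that behaviour, using MyModule from this file:

    m = MyModule()
    buff, param = m.buff, m.param
    m = trace_module(m, Tensor(np.random.random((1, 4, 16, 16))))
    assert "buff" not in m.__dict__   # dropped during tracing
    m.buff, m.param = buff, param     # restore on the traced module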
@@ -141,8 +152,11 @@ def test_module_api():
     assert m.bn.training == False and m.i.bn.training == False


-def test_module_api_reuse_submodule():
+@pytest.mark.parametrize("test_traced_module", [True, False])
+def test_module_api_reuse_submodule(test_traced_module):
     m = MyModule()
+    if test_traced_module:
+        m = trace_module(m, Tensor(np.random.random((1, 4, 16, 16))))
     m.h = m.i  # pylint: disable=attribute-defined-outside-init
     assert list(m.modules()) == [m, m.bn, m.i, m.i.bn]
     assert list(m.named_modules()) == [
@@ -153,15 +167,21 @@ def test_module_api_reuse_submodule():
     ]


-def test_module_api_iterable_stability():
+@pytest.mark.parametrize("test_traced_module", [True, False])
+def test_module_api_iterable_stability(test_traced_module):
     m = MyModule()
+    if test_traced_module:
+        m = trace_module(m, Tensor(np.random.random((1, 4, 16, 16))))
     l = list(m.modules())
     for _ in range(100):
         assert list(m.modules()) == l


-def test_module_api_hooks():
+@pytest.mark.parametrize("test_traced_module", [True, False])
+def test_module_api_hooks(test_traced_module):
     net = MyModule()
+    if test_traced_module:
+        net = trace_module(net, Tensor(np.zeros((1, 4, 1, 1))))
     pre_hook_num = 0
     post_hook_num = 0
     hooks = []
@@ -383,11 +403,16 @@ class Simple(Module):
         self.conv1.weight = self.conv0.weight

     def forward(self, inputs):
-        pass
+        x = self.conv0(inputs)
+        y = self.conv1(inputs)
+        return x + y


-def test_shared_param():
+@pytest.mark.parametrize("test_traced_module", [True, False])
+def test_shared_param(test_traced_module):
     net = Simple()
+    if test_traced_module:
+        net = trace_module(net, tensor(np.random.random((1, 1, 8, 8))))
     assert net.conv0.weight is net.conv1.weight
     data = tensor(np.random.random((1, 1, 8, 8)).astype(np.float32))
     np.testing.assert_allclose(net.conv0(data).numpy(), net.conv1(data).numpy())
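Worth noting: the `is` assertion now runs on the traced module too, i.e.
parameter sharing survives tracing. Condensed from the test:

    net = trace_module(Simple(), tensor(np.random.random((1, 1, 8, 8))))
    assert net.conv0.weight is net.conv1.weight  # still one shared Parameter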
@@ -449,15 +474,21 @@ def test_shared_param_1d():
     np.testing.assert_allclose(conv0(data).numpy(), conv1(data).numpy())


-def test_pickle_module():
+@pytest.mark.parametrize("test_traced_module", [True, False])
+def test_pickle_module(test_traced_module):
     data_shape = (2, 28)
     data = tensor(np.random.random(data_shape))
     mlp = MLP()
+    pred_gt = mlp(data)
+    if test_traced_module:
+        mlp = trace_module(mlp, data)
     # pickle before forward
     with BytesIO() as fout:
         mge.save(mlp, fout)
         fout.seek(0)
         mlp1 = mge.load(fout)
+        if test_traced_module:
+            assert type(mlp1) == TracedModule
         pred0 = mlp1(data)
     pred1 = mlp(data)
@@ -467,8 +498,11 @@ def test_pickle_module():
         mge.save(mlp, fout)
         fout.seek(0)
         mlp1 = mge.load(fout)
+        if test_traced_module:
+            assert type(mlp1) == TracedModule
         pred2 = mlp1(data)

+    np.testing.assert_allclose(pred_gt.numpy(), pred1.numpy(), atol=5e-6)
     np.testing.assert_allclose(pred0.numpy(), pred1.numpy(), atol=5e-6)
     np.testing.assert_allclose(pred0.numpy(), pred2.numpy(), atol=5e-6)
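Usage sketch distilled from this test: a TracedModule survives a pickle round
trip via mge.save/mge.load and comes back as a TracedModule with identical
predictions (MLP and data as defined in the test):

    mlp = trace_module(MLP(), data)
    with BytesIO() as fout:
        mge.save(mlp, fout)
        fout.seek(0)
        mlp1 = mge.load(fout)
    assert type(mlp1) == TracedModule
    np.testing.assert_allclose(mlp(data).numpy(), mlp1(data).numpy(), atol=5e-6)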
@@ -0,0 +1,59 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import io
+
+import numpy as np
+
+import megengine.functional as F
+import megengine.module as M
+import megengine.utils.comp_graph_tools as cgtools
+from megengine.experimental.traced_module import trace_module
+from megengine.jit import trace
+from megengine.module import Module
+
+
+class MyBlock(Module):
+    def __init__(self, in_channels, channels):
+        super(MyBlock, self).__init__()
+        self.conv1 = M.Conv2d(in_channels, channels, 3, 1, padding=1, bias=False)
+        self.bn1 = M.BatchNorm2d(channels)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = F.relu(x) + 1
+        return x
+
+
+class MyModule(Module):
+    def __init__(self):
+        super(MyModule, self).__init__()
+        self.block0 = MyBlock(8, 4)
+        self.block1 = MyBlock(4, 2)
+
+    def forward(self, x):
+        x = self.block0(x)
+        x = self.block1(x)
+        return x
+
+
+def test_jit_trace():
+    module = MyModule()
+    module.eval()
+    x = F.ones((1, 8, 14, 14))
+    expect = module(x)
+    traced_module = trace_module(module, x)
+    func = trace(traced_module, capture_as_const=True)
+    np.testing.assert_array_equal(func(x), expect)
+    model = io.BytesIO()
+    func.dump(model)
+    model.seek(0)
+    infer_cg = cgtools.GraphInference(model)
+    np.testing.assert_allclose(
+        list(infer_cg.run(x.numpy()).values())[0], expect, atol=1e-6
+    )
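This new test exercises the full export path end to end. Condensed, the flow
is (same names as in the test; note the test runs func(x) once before dump so
the computing graph has been built):

    traced = trace_module(module, x)             # Module -> TracedModule
    func = trace(traced, capture_as_const=True)  # parameters baked in as constants
    func(x)                                      # run once to build the graph
    func.dump(model)                             # serialize for deployment
    infer_cg = cgtools.GraphInference(model)     # replay for pure inference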