GitOrigin-RevId: e690bc42b0
@@ -209,7 +209,7 @@ def conv_transpose2d(
         dilate_w=dilate_w,
         strategy=get_conv_execution_strategy(),
     )
-    (output,) = apply(op, inp, weight)
+    (output,) = apply(op, weight, inp)
     if bias is not None:
         output += bias
     return output
@@ -241,7 +241,7 @@ def local_conv2d(
         pad_w=pad_w,
         dilate_h=dilate_h,
         dilate_w=dilate_w,
-        strategy=get_conv_execution_strategy(),
+        # strategy=get_conv_execution_strategy(),
     )
     (output,) = apply(op, inp, weight)
     if bias is not None:
@@ -724,7 +724,7 @@ def sync_batch_norm(
     """
     assert eps_mode in {"MAX", "ADDITIVE"}, "unknown eps_mode: {}".format(eps_mode)
    _channels = input.shape[1]
-    _ndim = len(input.shape)
+    _ndim = input.ndim
     _param_shape = (1, _channels) + (1,) * (_ndim - 2)
     if training:
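For reference, the `_param_shape` computed above broadcasts the per-channel statistics over every non-channel axis; a standalone sketch with hypothetical values:

    # Per-channel parameter shape for a 4D NCHW input: broadcasts over N, H, W.
    _channels, _ndim = 8, 4
    _param_shape = (1, _channels) + (1,) * (_ndim - 2)
    assert _param_shape == (1, 8, 1, 1)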
@@ -12,7 +12,7 @@ import numpy as np
 from ..distributed.group import WORLD, Group
 from ..functional import batch_norm2d, sync_batch_norm
-from ..tensor_nn import Buffer, Parameter
+from ..tensor_nn import Buffer, Parameter, Tensor
 from . import init
 from .module import Module
@@ -74,12 +74,12 @@ class _BatchNorm(Module):
         _ndims = len(inp.shape)
         if _ndims != 4:
-            origin_shape = inp.shapeof()
+            origin_shape = inp.shape
             if _ndims == 2:
-                n, c = inp.shapeof(0), inp.shapeof(1)
+                n, c = inp.shape[0], inp.shape[1]
                 new_shape = (n, c, 1, 1)
             elif _ndims == 3:
-                n, c, h = inp.shapeof(0), inp.shapeof(1), inp.shapeof(2)
+                n, c, h = inp.shape[0], inp.shape[1], inp.shape[2]
                 new_shape = (n, c, h, 1)
             inp = inp.reshape(new_shape)
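The branch above normalizes 2D and 3D inputs to 4D before batch norm runs; the same idea in plain numpy (shapes are hypothetical):

    import numpy as np
    x = np.random.rand(3, 8)                      # 2D (N, C) input
    x4 = x.reshape(x.shape[0], x.shape[1], 1, 1)  # handled as (N, C, 1, 1)
    assert x4.shape == (3, 8, 1, 1)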
@@ -127,7 +127,7 @@ class SyncBatchNorm(_BatchNorm):
         affine=True,
         track_running_stats=True,
         freeze=False,
-        group: Optional[Group] = None,
+        group: Optional[Group] = WORLD,
     ) -> None:
         super().__init__(
             num_features, eps, momentum, affine, track_running_stats, freeze
@@ -145,13 +145,16 @@ class SyncBatchNorm(_BatchNorm):
         _ndims = len(inp.shape)
         if _ndims != 4:
-            origin_shape = inp.shapeof()
+            new_shape = Tensor([1, 1, 1, 1], device=inp.device)
+            origin_shape = inp.shape
             if _ndims == 2:
-                n, c = inp.shape[0], inp.shape[1]
-                new_shape = (n, c, 1, 1)
+                new_shape[:2] = origin_shape[:2]
             elif _ndims == 3:
-                n, c, h = inp.shape[0], inp.shape[1], inp.shape[2]
-                new_shape = (n, c, h, 1)
+                new_shape[:3] = origin_shape[:3]
+            else:
+                raise ValueError(
+                    "expected 2D, 3D or 4D input (got {}D input)".format(len(inp.shape))
+                )
             inp = inp.reshape(new_shape)
@@ -376,7 +376,13 @@ class LocalConv2d(Conv2d):
     def forward(self, inp):
         return local_conv2d(
-            inp, self.weight, self.stride, self.padding, self.dilation, self.conv_mode
+            inp,
+            self.weight,
+            None,
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.conv_mode,
         )
@@ -0,0 +1,24 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import numpy as np
+
+import megengine as mge
+from megengine.module import LeakyReLU
+from megengine.test import assertTensorClose
+
+
+def test_leaky_relu():
+    data = np.array([-8, -12, 6, 10]).astype(np.float32)
+    negative_slope = 0.1
+
+    leaky_relu = LeakyReLU(negative_slope)
+    output = leaky_relu(mge.tensor(data))
+
+    np_output = np.maximum(0, data) + negative_slope * np.minimum(0, data)
+    assertTensorClose(output.numpy(), np_output, max_err=0)
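The numpy reference in the test can equivalently be written with np.where; a quick self-contained check (a sketch, not part of the diff):

    import numpy as np
    data = np.array([-8, -12, 6, 10], dtype=np.float32)
    slope = 0.1
    a = np.maximum(0, data) + slope * np.minimum(0, data)
    b = np.where(data >= 0, data, slope * data)
    assert np.allclose(a, b)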
@@ -0,0 +1,419 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import multiprocessing as mp
+import platform
+
+import numpy as np
+import pytest
+
+import megengine as mge
+import megengine.distributed as dist
+from megengine import tensor
+from megengine.core._trace_option import use_tensor_shape
+from megengine.module import BatchNorm1d, BatchNorm2d, SyncBatchNorm
+from megengine.tensor import Tensor
+from megengine.test import assertTensorClose
+
+
+@pytest.mark.skipif(
+    platform.system() == "Darwin", reason="do not implement GPU mode on macOS now"
+)
+@pytest.mark.skipif(
+    platform.system() == "Windows", reason="do not implement GPU mode on Windows now"
+)
+@pytest.mark.skipif(use_tensor_shape(), reason="syncbn does not support symbolic shape")
+@pytest.mark.isolated_distributed
+def test_syncbn():
+    nr_chan = 8
+    data_shape = (3, nr_chan, 4, 16)
+    momentum = 0.9
+    eps = 1e-5
+    running_mean = np.zeros((1, nr_chan, 1, 1), dtype=np.float32)
+    running_var = np.ones((1, nr_chan, 1, 1), dtype=np.float32)
+    steps = 4
+    nr_ranks = 2
+    server = dist.Server(0)
+    port = server.py_server_port
+
+    def worker(rank, data, yv_expect, running_mean, running_var):
+        if mge.get_device_count("gpu") < nr_ranks:
+            return
+        dist.init_process_group("localhost", port, nr_ranks, rank, rank)
+        bn = SyncBatchNorm(nr_chan, momentum=momentum, eps=eps)
+        data_tensor = tensor([])
+        for i in range(steps):
+            data_tensor.set_value(data[i])
+            yv = bn(data_tensor)
+
+            assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)
+            assertTensorClose(running_mean, bn.running_mean.numpy(), max_err=5e-6)
+            assertTensorClose(running_var, bn.running_var.numpy(), max_err=5e-6)
+
+    xv = []
+    for i in range(steps):
+        xv.append(np.random.normal(loc=2.3, size=data_shape).astype(np.float32))
+        xv_transposed = np.transpose(xv[i], [0, 2, 3, 1]).reshape(
+            (data_shape[0] * data_shape[2] * data_shape[3], nr_chan)
+        )
+
+        mean = np.mean(xv_transposed, axis=0).reshape(1, nr_chan, 1, 1)
+
+        var_biased = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1, 1))
+        sd = np.sqrt(var_biased + eps)
+
+        var_unbiased = np.var(xv_transposed, axis=0, ddof=1).reshape((1, nr_chan, 1, 1))
+        running_mean = running_mean * momentum + mean * (1 - momentum)
+        running_var = running_var * momentum + var_unbiased * (1 - momentum)
+
+        yv_expect = (xv[i] - mean) / sd
+
+    data = []
+    for i in range(nr_ranks):
+        data.append([])
+        for j in range(steps):
+            data[i].append(xv[j][:, :, :, i * 8 : i * 8 + 8])
+
+    procs = []
+    for rank in range(nr_ranks):
+        p = mp.Process(
+            target=worker,
+            args=(
+                rank,
+                data[rank],
+                yv_expect[:, :, :, rank * 8 : rank * 8 + 8],
+                running_mean,
+                running_var,
+            ),
+        )
+        p.start()
+        procs.append(p)
+
+    for p in procs:
+        p.join(10)
+        assert p.exitcode == 0
+
+
+def test_batchnorm():
+    nr_chan = 8
+    data_shape = (3, nr_chan, 4)
+    momentum = 0.9
+    bn = BatchNorm1d(nr_chan, momentum=momentum)
+    running_mean = np.zeros((1, nr_chan, 1), dtype=np.float32)
+    running_var = np.ones((1, nr_chan, 1), dtype=np.float32)
+    data = tensor([])
+    for i in range(3):
+        xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
+        mean = np.mean(np.mean(xv, axis=0, keepdims=True), axis=2, keepdims=True)
+        xv_transposed = np.transpose(xv, [0, 2, 1]).reshape(
+            (data_shape[0] * data_shape[2], nr_chan)
+        )
+
+        var_biased = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1))
+        sd = np.sqrt(var_biased + bn.eps)
+
+        var_unbiased = np.var(xv_transposed, axis=0, ddof=1).reshape((1, nr_chan, 1))
+        running_mean = running_mean * momentum + mean * (1 - momentum)
+        running_var = running_var * momentum + var_unbiased * (1 - momentum)
+
+        data.set_value(xv)
+        yv = bn(data)
+        yv_expect = (xv - mean) / sd
+
+        assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)
+        assertTensorClose(
+            running_mean.reshape(-1), bn.running_mean.numpy().reshape(-1), max_err=5e-6
+        )
+        assertTensorClose(
+            running_var.reshape(-1), bn.running_var.numpy().reshape(-1), max_err=5e-6
+        )
+
+    # test set 'training' flag to False
+    mean_backup = bn.running_mean.numpy()
+    var_backup = bn.running_var.numpy()
+    bn.training = False
+    xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
+    data.set_value(xv)
+    yv1 = bn(data)
+    yv2 = bn(data)
+    assertTensorClose(yv1.numpy(), yv2.numpy(), max_err=0)
+    assertTensorClose(mean_backup, bn.running_mean.numpy(), max_err=0)
+    assertTensorClose(var_backup, bn.running_var.numpy(), max_err=0)
+    yv_expect = (xv - running_mean) / np.sqrt(running_var + bn.eps)
+    assertTensorClose(yv_expect, yv1.numpy(), max_err=5e-6)
+
+
+@pytest.mark.skipif(
+    platform.system() == "Darwin", reason="do not implement GPU mode on macOS now"
+)
+@pytest.mark.skipif(
+    platform.system() == "Windows", reason="do not implement GPU mode on Windows now"
+)
+@pytest.mark.skipif(use_tensor_shape(), reason="syncbn does not support symbolic shape")
+@pytest.mark.isolated_distributed
+def test_syncbn1d():
+    nr_chan = 8
+    data_shape = (3, nr_chan, 4)
+    momentum = 0.9
+    bn = SyncBatchNorm(nr_chan, momentum=momentum)
+    running_mean = np.zeros((1, nr_chan, 1), dtype=np.float32)
+    running_var = np.ones((1, nr_chan, 1), dtype=np.float32)
+    data = tensor([])
+    for i in range(3):
+        xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
+        mean = np.mean(np.mean(xv, axis=0, keepdims=True), axis=2, keepdims=True)
+        xv_transposed = np.transpose(xv, [0, 2, 1]).reshape(
+            (data_shape[0] * data_shape[2], nr_chan)
+        )
+
+        var_biased = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1))
+        sd = np.sqrt(var_biased + bn.eps)
+
+        var_unbiased = np.var(xv_transposed, axis=0, ddof=1).reshape((1, nr_chan, 1))
+        running_mean = running_mean * momentum + mean * (1 - momentum)
+        running_var = running_var * momentum + var_unbiased * (1 - momentum)
+
+        data.set_value(xv)
+        yv = bn(data)
+        yv_expect = (xv - mean) / sd
+
+        assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)
+        assertTensorClose(
+            running_mean.reshape(-1), bn.running_mean.numpy().reshape(-1), max_err=5e-6
+        )
+        assertTensorClose(
+            running_var.reshape(-1), bn.running_var.numpy().reshape(-1), max_err=5e-6
+        )
+
+    # test set 'training' flag to False
+    mean_backup = bn.running_mean.numpy()
+    var_backup = bn.running_var.numpy()
+    bn.training = False
+    xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
+    data.set_value(xv)
+    yv1 = bn(data)
+    yv2 = bn(data)
+    assertTensorClose(yv1.numpy(), yv2.numpy(), max_err=0)
+    assertTensorClose(mean_backup, bn.running_mean.numpy(), max_err=0)
+    assertTensorClose(var_backup, bn.running_var.numpy(), max_err=0)
+    yv_expect = (xv - running_mean) / np.sqrt(running_var + bn.eps)
+    assertTensorClose(yv_expect, yv1.numpy(), max_err=5e-6)
+
+
+def test_batchnorm2d():
+    nr_chan = 8
+    data_shape = (3, nr_chan, 16, 16)
+    momentum = 0.9
+    bn = BatchNorm2d(nr_chan, momentum=momentum)
+    running_mean = np.zeros((1, nr_chan, 1, 1), dtype=np.float32)
+    running_var = np.ones((1, nr_chan, 1, 1), dtype=np.float32)
+    data = tensor([])
+    for i in range(3):
+        xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
+        xv_transposed = np.transpose(xv, [0, 2, 3, 1]).reshape(
+            (data_shape[0] * data_shape[2] * data_shape[3], nr_chan)
+        )
+
+        mean = np.mean(xv_transposed, axis=0).reshape(1, nr_chan, 1, 1)
+
+        var_biased = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1, 1))
+        sd = np.sqrt(var_biased + bn.eps)
+
+        var_unbiased = np.var(xv_transposed, axis=0, ddof=1).reshape((1, nr_chan, 1, 1))
+        running_mean = running_mean * momentum + mean * (1 - momentum)
+        running_var = running_var * momentum + var_unbiased * (1 - momentum)
+
+        data.set_value(xv)
+        yv = bn(data)
+        yv_expect = (xv - mean) / sd
+
+        assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)
+        assertTensorClose(running_mean, bn.running_mean.numpy(), max_err=5e-6)
+        assertTensorClose(running_var, bn.running_var.numpy(), max_err=5e-6)
+
+    # test set 'training' flag to False
+    mean_backup = bn.running_mean.numpy()
+    var_backup = bn.running_var.numpy()
+    bn.training = False
+    xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
+    data.set_value(xv)
+    yv1 = bn(data)
+    yv2 = bn(data)
+    assertTensorClose(yv1.numpy(), yv2.numpy(), max_err=0)
+    assertTensorClose(mean_backup, bn.running_mean.numpy(), max_err=0)
+    assertTensorClose(var_backup, bn.running_var.numpy(), max_err=0)
+    yv_expect = (xv - running_mean) / np.sqrt(running_var + bn.eps)
+    assertTensorClose(yv_expect, yv1.numpy(), max_err=5e-6)
+
+
+@pytest.mark.skipif(
+    platform.system() == "Darwin", reason="do not implement GPU mode on macOS now"
+)
+@pytest.mark.skipif(
+    platform.system() == "Windows", reason="do not implement GPU mode on Windows now"
+)
+@pytest.mark.skipif(use_tensor_shape(), reason="syncbn does not support symbolic shape")
+@pytest.mark.isolated_distributed
+def test_syncbn2d():
+    nr_chan = 8
+    data_shape = (3, nr_chan, 16, 16)
+    momentum = 0.9
+    bn = SyncBatchNorm(nr_chan, momentum=momentum)
+    running_mean = np.zeros((1, nr_chan, 1, 1), dtype=np.float32)
+    running_var = np.ones((1, nr_chan, 1, 1), dtype=np.float32)
+    data = tensor([])
+    for i in range(3):
+        xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
+        xv_transposed = np.transpose(xv, [0, 2, 3, 1]).reshape(
+            (data_shape[0] * data_shape[2] * data_shape[3], nr_chan)
+        )
+
+        mean = np.mean(xv_transposed, axis=0).reshape(1, nr_chan, 1, 1)
+
+        var_biased = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1, 1))
+        sd = np.sqrt(var_biased + bn.eps)
+
+        var_unbiased = np.var(xv_transposed, axis=0, ddof=1).reshape((1, nr_chan, 1, 1))
+        running_mean = running_mean * momentum + mean * (1 - momentum)
+        running_var = running_var * momentum + var_unbiased * (1 - momentum)
+
+        data.set_value(xv)
+        yv = bn(data)
+        yv_expect = (xv - mean) / sd
+
+        assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)
+        assertTensorClose(running_mean, bn.running_mean.numpy(), max_err=5e-6)
+        assertTensorClose(running_var, bn.running_var.numpy(), max_err=5e-6)
+
+    # test set 'training' flag to False
+    mean_backup = bn.running_mean.numpy()
+    var_backup = bn.running_var.numpy()
+    bn.training = False
+    xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
+    data.set_value(xv)
+    yv1 = bn(data)
+    yv2 = bn(data)
+    assertTensorClose(yv1.numpy(), yv2.numpy(), max_err=0)
+    assertTensorClose(mean_backup, bn.running_mean.numpy(), max_err=0)
+    assertTensorClose(var_backup, bn.running_var.numpy(), max_err=0)
+    yv_expect = (xv - running_mean) / np.sqrt(running_var + bn.eps)
+    assertTensorClose(yv_expect, yv1.numpy(), max_err=5e-6)
+
+
+def test_batchnorm_no_stats():
+    nr_chan = 8
+    data_shape = (3, nr_chan, 4)
+    bn = BatchNorm1d(8, track_running_stats=False)
+    data = tensor([])
+    for i in range(4):
+        if i == 2:
+            bn.training = False
+        xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
+        mean = np.mean(np.mean(xv, axis=0, keepdims=True), axis=2, keepdims=True)
+        var = np.var(
+            np.transpose(xv, [0, 2, 1]).reshape(
+                (data_shape[0] * data_shape[2], nr_chan)
+            ),
+            axis=0,
+        ).reshape((1, nr_chan, 1))
+        sd = np.sqrt(var + bn.eps)
+
+        data.set_value(xv)
+        yv = bn(data)
+        yv_expect = (xv - mean) / sd
+
+        assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)
+
+
+@pytest.mark.skipif(
+    platform.system() == "Darwin", reason="do not implement GPU mode on macOS now"
+)
+@pytest.mark.skipif(
+    platform.system() == "Windows", reason="do not implement GPU mode on Windows now"
+)
+@pytest.mark.skipif(use_tensor_shape(), reason="syncbn does not support symbolic shape")
+@pytest.mark.isolated_distributed
+def test_syncbn_no_stats():
+    nr_chan = 8
+    data_shape = (3, nr_chan, 4)
+    bn = SyncBatchNorm(8, track_running_stats=False)
+    data = tensor([])
+    for i in range(4):
+        if i == 2:
+            bn.training = False
+        xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
+        mean = np.mean(np.mean(xv, axis=0, keepdims=True), axis=2, keepdims=True)
+        var = np.var(
+            np.transpose(xv, [0, 2, 1]).reshape(
+                (data_shape[0] * data_shape[2], nr_chan)
+            ),
+            axis=0,
+        ).reshape((1, nr_chan, 1))
+        sd = np.sqrt(var + bn.eps)
+
+        data.set_value(xv)
+        yv = bn(data)
+        yv_expect = (xv - mean) / sd
+
+        assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)
+
+
+def test_batchnorm2d_no_stats():
+    nr_chan = 8
+    data_shape = (3, nr_chan, 16, 16)
+    bn = BatchNorm2d(8, track_running_stats=False)
+    data = tensor([])
+    for i in range(4):
+        if i == 2:
+            bn.training = False
+        xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
+        xv_transposed = np.transpose(xv, [0, 2, 3, 1]).reshape(
+            (data_shape[0] * data_shape[2] * data_shape[3], nr_chan)
+        )
+
+        mean = np.mean(xv_transposed, axis=0).reshape(1, nr_chan, 1, 1)
+        var = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1, 1))
+        sd = np.sqrt(var + bn.eps)
+
+        data.set_value(xv)
+        yv = bn(data)
+        yv_expect = (xv - mean) / sd
+
+        assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)
+
+
+@pytest.mark.skipif(
+    platform.system() == "Darwin", reason="do not implement GPU mode on macOS now"
+)
+@pytest.mark.skipif(
+    platform.system() == "Windows", reason="do not implement GPU mode on Windows now"
+)
+@pytest.mark.skipif(use_tensor_shape(), reason="syncbn does not support symbolic shape")
+@pytest.mark.isolated_distributed
+def test_syncbn2d_no_stats():
+    nr_chan = 8
+    data_shape = (3, nr_chan, 16, 16)
+    bn = SyncBatchNorm(8, track_running_stats=False)
+    data = tensor([])
+    for i in range(4):
+        if i == 2:
+            bn.training = False
+        xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
+        xv_transposed = np.transpose(xv, [0, 2, 3, 1]).reshape(
+            (data_shape[0] * data_shape[2] * data_shape[3], nr_chan)
+        )
+
+        mean = np.mean(xv_transposed, axis=0).reshape(1, nr_chan, 1, 1)
+        var = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1, 1))
+        sd = np.sqrt(var + bn.eps)
+
+        data.set_value(xv)
+        yv = bn(data)
+        yv_expect = (xv - mean) / sd
+
+        assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)
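Every training step in these tests reproduces the same exponential-moving-average update of the running statistics; condensed into a helper sketch (the name is illustrative only):

    def update_running(stat, batch_stat, momentum=0.9):
        # the running stat keeps `momentum` of its old value per step
        return stat * momentum + batch_stat * (1 - momentum)

    assert abs(update_running(0.0, 1.0) - 0.1) < 1e-12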
@@ -0,0 +1,110 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import itertools
+
+import numpy as np
+
+from megengine import Parameter, tensor
+from megengine.module import ConvTranspose2d, LocalConv2d
+from megengine.test import assertTensorClose
+
+
+def test_conv_transpose2d():
+    SH, SW = 3, 1
+    PH, PW = 2, 0
+    N, IC, IH, IW = 4, 5, 8, 6
+    KH, KW = 3, 4
+    OC = 3
+    BIAS = False
+
+    def getsize(inp, kern, stride):
+        return (inp - 1) * stride + kern
+
+    OH = getsize(IH, KH, SH)
+    OW = getsize(IW, KW, SW)
+
+    inp = np.random.normal(size=(N, IC, IH, IW)).astype(np.float32)
+    out = np.zeros((N, OC, OH, OW), dtype=np.float32)
+    weight = np.random.normal(size=(IC, OC, KH, KW)).astype(np.float32)
+    bias = np.random.normal(size=(1, OC, 1, 1)).astype(np.float32)
+
+    # naive calculation using numpy
+    for n, ic, ih, iw in itertools.product(*map(range, [N, IC, IH, IW])):
+        oh, ow = ih * SH, iw * SW
+        out[n, :, oh : oh + KH, ow : ow + KW] += inp[n, ic, ih, iw] * weight[ic]
+    out = out[:, :, PH : OH - PH, PW : OW - PW]
+    if BIAS:
+        out += bias
+
+    # megengine conv_transpose2d calculation
+    conv_transpose2d = ConvTranspose2d(IC, OC, (KH, KW), (SH, SW), (PH, PW), bias=BIAS)
+    conv_transpose2d.weight = Parameter(weight, dtype=np.float32)
+    if BIAS:
+        conv_transpose2d.bias = Parameter(bias, dtype=np.float32)
+
+    y = conv_transpose2d(tensor(inp))
+    assertTensorClose(out, y.numpy(), max_err=2e-6)
+
+
+def test_local_conv2d():
+    batch_size = 10
+    in_channels = 4
+    out_channels = 8
+    input_height = 8
+    input_width = 8
+    kernel_size = 3
+    stride = 1
+    padding = 1
+    dilation = 1
+    groups = 1
+    local_conv2d = LocalConv2d(
+        in_channels=in_channels,
+        out_channels=out_channels,
+        input_height=input_height,
+        input_width=input_width,
+        kernel_size=kernel_size,
+        stride=stride,
+        padding=padding,
+        dilation=dilation,
+        groups=groups,
+    )
+    inputs = np.random.normal(
+        size=(batch_size, in_channels, input_height, input_width)
+    ).astype(np.float32)
+    output_height = (input_height + padding * 2 - kernel_size) // stride + 1
+    output_width = (input_width + padding * 2 - kernel_size) // stride + 1
+    weights = np.random.normal(
+        size=(
+            groups,
+            output_height,
+            output_width,
+            in_channels // groups,
+            kernel_size,
+            kernel_size,
+            out_channels // groups,
+        )
+    ).astype(np.float32)
+    local_conv2d.weight = Parameter(weights)
+    outputs = local_conv2d(tensor(inputs))
+    # naive calculation using numpy
+    # only tests output_height == input_height, output_width == input_width, groups == 1
+    inputs = np.pad(inputs, ((0, 0), (0, 0), (1, 1), (1, 1)))
+    expected = np.zeros(
+        (batch_size, out_channels, output_height, output_width), dtype=np.float32,
+    )
+    for n, oc, oh, ow in itertools.product(
+        *map(range, [batch_size, out_channels, output_height, output_width])
+    ):
+        ih, iw = oh * stride, ow * stride
+        expected[n, oc, ih, iw] = np.sum(
+            inputs[n, :, ih : ih + kernel_size, iw : iw + kernel_size]
+            * weights[0, oh, ow, :, :, :, oc]
+        )
+    assertTensorClose(outputs.numpy(), expected, max_err=1e-5)
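The getsize helper above encodes the transposed-convolution output size before the padding crop, out = (in - 1) * stride + kernel; restated with the test's own numbers:

    def getsize(inp, kern, stride):
        return (inp - 1) * stride + kern

    assert getsize(8, 3, 3) == 24  # IH=8, KH=3, SH=3 -> OH=24 before cropping
    assert 24 - 2 * 2 == 20        # PH=2 cropped from both sides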
@@ -0,0 +1,46 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import os
+
+import numpy as np
+import pytest
+
+import megengine as mge
+from megengine import tensor
+from megengine.module import Module
+
+
+class MyModule(Module):
+    def __init__(self, data):
+        from megengine.module.external import CambriconSubgraph
+
+        super().__init__()
+        self.cambricon = CambriconSubgraph(data, "subnet0", True)
+
+    def forward(self, inputs):
+        out = self.cambricon(inputs)
+        return out
+
+
+@pytest.mark.skip(reason="cambricon unimplemented")
+def test_cambricon_module():
+    model = "CambriconRuntimeOprTest.MutableBatchSize.mlu"
+    model = os.path.join(os.path.dirname(__file__), model)
+    with open(model, "rb") as f:
+        data = f.read()
+    m = MyModule(data)
+    inputs = []
+    inputs.append(tensor(data=[], dtype=np.float16, device="cambricon0"))
+    inputs[0].set_value(np.random.normal(size=(1, 64, 32, 32)).astype(np.float16))
+
+    def inference(inps):
+        pred = m(inps)
+        return pred
+
+    pred = inference(inputs)
@@ -0,0 +1,27 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import pytest
+
+from megengine.module import Conv2d, Linear
+from megengine.module.init import calculate_fan_in_and_fan_out
+
+
+def test_calculate_fan_in_and_fan_out():
+    l = Linear(in_features=3, out_features=8)
+    fanin, fanout = calculate_fan_in_and_fan_out(l.weight)
+    assert fanin == 3
+    assert fanout == 8
+
+    with pytest.raises(ValueError):
+        calculate_fan_in_and_fan_out(l.bias)
+
+    l = Conv2d(in_channels=2, out_channels=3, kernel_size=(5, 7))
+    fanin, fanout = calculate_fan_in_and_fan_out(l.weight)
+    assert fanin == 2 * 5 * 7
+    assert fanout == 3 * 5 * 7
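The assertions encode the usual fan computation for convolutions: fan_in = in_channels * KH * KW and fan_out = out_channels * KH * KW; a standalone restatement of the Conv2d case above:

    OC, IC, KH, KW = 3, 2, 5, 7
    fan_in, fan_out = IC * KH * KW, OC * KH * KW
    assert (fan_in, fan_out) == (2 * 5 * 7, 3 * 5 * 7)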
@@ -0,0 +1,614 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import os
+import tempfile
+from collections import OrderedDict
+from io import BytesIO
+
+import numpy as np
+import pytest
+
+import megengine as mge
+import megengine.functional as F
+from megengine import Buffer, Parameter, Tensor, tensor
+from megengine.module import (
+    BatchNorm1d,
+    BatchNorm2d,
+    Conv2d,
+    Linear,
+    Module,
+    Sequential,
+)
+from megengine.quantization.quantize import quantize, quantize_qat
+from megengine.test import assertTensorClose
+
+
+class MLP(Module):
+    def __init__(self):
+        super().__init__()
+        self.dense0 = Linear(28, 50)
+        self.dense1 = Linear(50, 20)
+
+    def forward(self, x):
+        x = self.dense0(x)
+        x = F.relu(x)
+        x = self.dense1(x)
+        return x
+
+
+def has_gpu(num=1):
+    # NOTE: relies on the legacy ``mgb`` handle, which is not imported in
+    # this file.
+    try:
+        mgb.comp_node("gpu{}".format(num - 1))
+    except mgb.MegBrainError:
+        return False
+
+    return True
+
+
+def randomNp(*args):
+    for arg in args:
+        assert isinstance(arg, int)
+    return np.random.random(args)
+
+
+def randomTorch(*args):
+    import torch  # pylint: disable=import-outside-toplevel
+
+    for arg in args:
+        assert isinstance(arg, int)
+    return torch.tensor(randomNp(*args), dtype=torch.float32)
+
+
+def graph_mode(*modes):
+    if not set(modes).issubset({"eager", "static"}):
+        raise ValueError("graph mode must be in (eager, static)")
+
+    def decorator(func):
+        def wrapper(*args, **kwargs):
+            if "eager" in set(modes):
+                func(*args, **kwargs)
+            if "static" in set(modes):
+                # NOTE: ``Graph`` comes from the legacy graph API and is not
+                # imported in this file.
+                with Graph() as cg:
+                    cg.set_option("eager_evaluation", False)
+                    func(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
+def _default_compare_fn(x, y):
+    assertTensorClose(x.numpy(), y)
+
+
+def opr_test(
+    cases,
+    func,
+    mode=("eager", "static", "dynamic_shape"),
+    compare_fn=_default_compare_fn,
+    ref_fn=None,
+    **kwargs
+):
+    """
+    mode: the list of test modes, chosen from eager, static and dynamic_shape;
+          all modes are tested if None.
+    func: the function that runs the opr.
+    compare_fn: the function that compares the result with the expected value;
+                assertTensorClose is used if None.
+    ref_fn: the function that generates the expected data; "output" must be
+            assigned in the case dict if None.
+    cases: a list of dict elements; the list length should be 2 for the
+           dynamic shape test. Each dict should have "input", and should have
+           "output" if ref_fn is None. Use lists for multiple inputs and
+           outputs in each case.
+    kwargs: additional kwargs for the opr func.
+
+    simple example:
+        dtype = np.float32
+        cases = [{"input": [10, 20]}, {"input": [20, 30]}]
+        opr_test(cases,
+                 F.eye,
+                 ref_fn=lambda n, m: np.eye(n, m).astype(dtype),
+                 dtype=dtype)
+    """
+
+    def check_results(results, expected):
+        if not isinstance(results, Tuple):
+            results = (results,)
+        for r, e in zip(results, expected):
+            compare_fn(r, e)
+
+    def get_trace_fn(func, enabled, symbolic):
+        jit.trace.enabled = enabled
+        return jit.trace(func, symbolic=symbolic)
+
+    def get_param(cases, idx):
+        case = cases[idx]
+        inp = case.get("input", None)
+        outp = case.get("output", None)
+        if inp is None:
+            raise ValueError("the test case should have input")
+        if not isinstance(inp, List):
+            inp = (inp,)
+        else:
+            inp = tuple(inp)
+        if ref_fn is not None and callable(ref_fn):
+            outp = ref_fn(*inp)
+        if outp is None:
+            raise ValueError("the test case should have output or reference function")
+        if not isinstance(outp, List):
+            outp = (outp,)
+        else:
+            outp = tuple(outp)
+        return inp, outp
+
+    if not set(mode).issubset({"eager", "static", "dynamic_shape"}):
+        raise ValueError("opr test mode must be in (eager, static, dynamic_shape)")
+
+    if len(cases) == 0:
+        raise ValueError("should give at least one case")
+
+    if "dynamic_shape" in set(mode):
+        if len(cases) != 2:
+            raise ValueError("should give 2 cases for dynamic shape test")
+
+    if not callable(func):
+        raise ValueError("the input func should be callable")
+
+    inp, outp = get_param(cases, 0)
+
+    def run(*args, **kwargs):
+        return func(*args, **kwargs)
+
+    if "eager" in set(mode):
+        f = get_trace_fn(run, False, False)
+        results = f(*inp, **kwargs)
+        check_results(results, outp)
+
+    if "static" in set(mode) or "dynamic_shape" in set(mode):
+        f = get_trace_fn(run, True, True)
+        results = f(*inp, **kwargs)
+        check_results(results, outp)
+        if "dynamic_shape" in set(mode):
+            inp, outp = get_param(cases, 1)
+            results = f(*inp, **kwargs)
+            check_results(results, outp)
+
+
+class MyModule(Module):
+    class InnerModule(Module):
+        def __init__(self):
+            super().__init__()
+            self.bn = BatchNorm2d(4)
+
+        def forward(self, x):
+            return self.bn(x)
+
+    def __init__(self):
+        super().__init__()
+        self.i = self.InnerModule()
+        self.bn = BatchNorm2d(4)
+        self.param = Parameter(np.ones(1, dtype=np.float32))
+        self.buff = Buffer(np.ones(1, dtype=np.float32))
+
+    def forward(self, x):
+        x = self.i(x)
+        x = self.bn(x)
+        return x
+
+
+def test_module_api():
+    m = MyModule()
+    assert list(m.children()) == [m.bn, m.i]
+    assert list(m.named_children()) == [("bn", m.bn), ("i", m.i)]
+    assert list(m.modules()) == [m, m.bn, m.i, m.i.bn]
+    assert list(m.named_modules()) == [
+        ("", m),
+        ("bn", m.bn),
+        ("i", m.i),
+        ("i.bn", m.i.bn),
+    ]
+    assert list(m.named_modules(prefix="x")) == [
+        ("x", m),
+        ("x.bn", m.bn),
+        ("x.i", m.i),
+        ("x.i.bn", m.i.bn),
+    ]
+    assert list(m.buffers()) == [
+        m.bn.running_mean,
+        m.bn.running_var,
+        m.buff,
+        m.i.bn.running_mean,
+        m.i.bn.running_var,
+    ]
+    assert list(m.buffers(recursive=False)) == [m.buff]
+    assert list(m.named_buffers()) == [
+        ("bn.running_mean", m.bn.running_mean),
+        ("bn.running_var", m.bn.running_var),
+        ("buff", m.buff),
+        ("i.bn.running_mean", m.i.bn.running_mean),
+        ("i.bn.running_var", m.i.bn.running_var),
+    ]
+    assert list(m.parameters()) == [
+        m.bn.bias,
+        m.bn.weight,
+        m.i.bn.bias,
+        m.i.bn.weight,
+        m.param,
+    ]
+    assert list(m.named_parameters()) == [
+        ("bn.bias", m.bn.bias),
+        ("bn.weight", m.bn.weight),
+        ("i.bn.bias", m.i.bn.bias),
+        ("i.bn.weight", m.i.bn.weight),
+        ("param", m.param),
+    ]
+    m.eval()
+    assert (
+        m.training == False
+        and m.bn.training == False
+        and m.i.training == False
+        and m.i.bn.training == False
+    )
+    m.bn.train()
+    assert m.training == False and m.bn.training == True and m.i.bn.training == False
+    m.eval()
+    m.i.train()
+    assert (
+        m.training == False
+        and m.bn.training == False
+        and m.i.training == True
+        and m.i.bn.training == True
+    )
+    m.eval()
+    m.train()
+    assert m.training == True and m.bn.training == True and m.i.bn.training == True
+
+    def fn(m):
+        m.training = False
+
+    m.apply(fn)
+    assert m.bn.training == False and m.i.bn.training == False
+
+
+def test_module_api_reuse_submodule():
+    m = MyModule()
+    m.h = m.i  # pylint: disable=attribute-defined-outside-init
+    assert list(m.modules()) == [m, m.bn, m.i, m.i.bn]
+    assert list(m.named_modules()) == [
+        ("", m),
+        ("bn", m.bn),
+        ("h", m.i),
+        ("h.bn", m.i.bn),
+    ]
+
+
+def test_module_api_iterable_stability():
+    m = MyModule()
+    l = list(m.modules())
+    for _ in range(100):
+        assert list(m.modules()) == l
+
+
+def test_module_api_hooks():
+    net = MyModule()
+    pre_hook_num = 0
+    post_hook_num = 0
+    hooks = []
+
+    def pre_hook(module, inputs):
+        nonlocal pre_hook_num
+        pre_hook_num += 1
+        modified_inputs = tuple(inp + 1 for inp in inputs)
+        return modified_inputs
+
+    def post_hook(module, inputs, outputs):
+        nonlocal post_hook_num
+        post_hook_num += 1
+        outputs += 1
+        return outputs
+
+    net.apply(lambda module: hooks.append(module.register_forward_pre_hook(pre_hook)))
+    net.apply(lambda module: hooks.append(module.register_forward_hook(post_hook)))
+    shape = (1, 4, 1, 1)
+    x = tensor(np.zeros(shape, dtype=np.float32))
+    y = net(x)
+    assert pre_hook_num == 4
+    assert post_hook_num == 4
+    mean1 = Parameter(np.zeros(shape), dtype=np.float32)
+    bn1 = F.batch_norm2d(
+        x + 3, mean1, Parameter(np.ones(shape), dtype=np.float32), training=True
+    )
+    assertTensorClose(
+        net.i.bn.running_mean.numpy(), mean1.numpy(),
+    )
+    mean2 = Parameter(np.zeros(shape), dtype=np.float32)
+    bn2 = F.batch_norm2d(
+        bn1 + 3, mean2, Parameter(np.ones(shape), dtype=np.float32), training=True
+    )
+    assertTensorClose(
+        net.bn.running_mean.numpy(), mean2.numpy(),
+    )
+    assertTensorClose((bn2 + 2).numpy(), y.numpy())
+
+    assert len(hooks) == 8
+    for handler in hooks:
+        handler.remove()
+    y = net(x)
+    assert pre_hook_num == 4
+    assert post_hook_num == 4
+
+
+class MyModule2(Module):
+    class InnerModule(Module):
+        def __init__(self):
+            super().__init__()
+            self.bn = BatchNorm2d(4)
+            self.test_bool_key = {True: 1, False: 0}
+
+        def forward(self, x):
+            x = self.bn(x)
+
+    def __init__(self):
+        super().__init__()
+        self.bn = BatchNorm2d(4)
+        self.a = [
+            BatchNorm2d(4),
+            {"x": BatchNorm2d(4), "y": [BatchNorm2d(4), self.InnerModule()], "z": 0},
+            (self.InnerModule(),),
+        ]
+
+    def forward(self, x):
+        return x
+
+
+def test_expand_structure():
+    m = MyModule2()
+    assert list(m.named_modules()) == [
+        ("", m),
+        ("a.0", m.a[0]),
+        ("a.1.x", m.a[1]["x"]),
+        ("a.1.y.0", m.a[1]["y"][0]),
+        ("a.1.y.1", m.a[1]["y"][1]),
+        ("a.1.y.1.bn", m.a[1]["y"][1].bn),
+        ("a.2.0", m.a[2][0]),
+        ("a.2.0.bn", m.a[2][0].bn),
+        ("bn", m.bn),
+    ]
+
+
+def test_flatten_others():
+    def be_others(obj):
+        return not isinstance(obj, (Tensor, Module))
+
+    m = MyModule2()
+    assert len(list(m._flatten(with_key=True, predicate=be_others))) == 0
+
+
+def test_flatten_with_parent():
+    m = MyModule2()
+    assert list(m.named_modules(with_parent=True)) == [
+        ("", m, None),
+        ("a.0", m.a[0], m),
+        ("a.1.x", m.a[1]["x"], m),
+        ("a.1.y.0", m.a[1]["y"][0], m),
+        ("a.1.y.1", m.a[1]["y"][1], m),
+        ("a.1.y.1.bn", m.a[1]["y"][1].bn, m.a[1]["y"][1]),
+        ("a.2.0", m.a[2][0], m),
+        ("a.2.0.bn", m.a[2][0].bn, m.a[2][0]),
+        ("bn", m.bn, m),
+    ]
+    assert list(m.modules(with_parent=True)) == [
+        (m, None),
+        (m.a[0], m),
+        (m.a[1]["x"], m),
+        (m.a[1]["y"][0], m),
+        (m.a[1]["y"][1], m),
+        (m.a[1]["y"][1].bn, m.a[1]["y"][1]),
+        (m.a[2][0], m),
+        (m.a[2][0].bn, m.a[2][0]),
+        (m.bn, m),
+    ]
+
+
+class MyModule3(Module):
+    class InnerModule(Module):
+        def __init__(self):
+            super().__init__()
+            self.bn = BatchNorm2d(4)
+
+        def forward(self, x):
+            x = self.bn(x)
+
+    def __init__(self):
+        super().__init__()
+        self.bn = BatchNorm2d(4)
+        self.seq = Sequential(BatchNorm2d(4), self.InnerModule(),)
+
+    def forward(self, x):
+        return x
+
+
+def test_module_api_with_sequential():
+    m = MyModule3()
+    assert list(m.named_modules()) == [
+        ("", m),
+        ("bn", m.bn),
+        ("seq", m.seq),
+        ("seq.0", m.seq[0]),
+        ("seq.1", m.seq[1]),
+        ("seq.1.bn", m.seq[1].bn),
+    ]
+
+
+def test_sequential_named_children():
+    modules = OrderedDict()
+    modules["name0"] = Linear(20, 10)
+    modules["name1"] = Linear(10, 5)
+    modules["name2"] = Linear(5, 1)
+    m = Sequential(modules)
+    l = list(m.named_children())
+    assert l[0][0] == "layer_values.0"
+    assert l[1][0] == "layer_values.1"
+    assert l[2][0] == "layer_values.2"
+
+
+def test_state_dict():
+    data_shape = (2, 28)
+    data = tensor([])
+    data.set_value(np.random.random(data_shape))
+    mlp = MLP()
+    pred0 = mlp(data)
+
+    with BytesIO() as fout:
+        mge.save(mlp.state_dict(), fout)
+        fout.seek(0)
+        state_dict = mge.load(fout)
+
+    state_dict["extra"] = None
+    mlp1 = MLP()
+    mlp1.load_state_dict(state_dict, strict=False)
+    pred1 = mlp1(data)
+    assertTensorClose(pred0.numpy(), pred1.numpy(), max_err=5e-6)
+    with pytest.raises(KeyError):
+        mlp1.load_state_dict(state_dict)
+    del state_dict["extra"]
+    del state_dict["dense0.bias"]
+    with pytest.raises(KeyError):
+        mlp1.load_state_dict(state_dict)
+
+
+class AssertModule(Module):
+    def __init__(self):
+        super().__init__()
+        self.error_tensor_key = {True: tensor([]), False: 0}
+
+    def forward(self, x):
+        return x
+
+
+def test_assert_message():
+    m = AssertModule()
+    with pytest.raises(
+        AssertionError, match="keys for Tensor and Module must be str, error key: True"
+    ):
+        list(m._flatten())
+
+
+class Simple(Module):
+    def __init__(self):
+        super().__init__()
+        self.conv0 = Conv2d(1, 1, kernel_size=3, bias=False)
+        self.conv1 = Conv2d(1, 1, kernel_size=3, bias=False)
+        self.conv1.weight = self.conv0.weight
+
+    def forward(self, inputs):
+        pass
+
+
+def test_shared_param():
+    net = Simple()
+    assert net.conv0.weight is net.conv1.weight
+    data = tensor(np.random.random((1, 1, 8, 8)).astype(np.float32))
+    assertTensorClose(net.conv0(data).numpy(), net.conv1(data).numpy())
+    with BytesIO() as f:
+        mge.save(net, f)
+        f.seek(0)
+        net1 = mge.load(f)
+    assert net1.conv0.weight is net1.conv1.weight
+    assertTensorClose(net1.conv0(data).numpy(), net1.conv1(data).numpy())
+
+    with BytesIO() as f:
+        mge.save(net.conv0, f)
+        f.seek(0)
+        conv0 = mge.load(f)
+    with BytesIO() as f:
+        mge.save(net.conv1, f)
+        f.seek(0)
+        conv1 = mge.load(f)
+    assert conv0.weight is not conv1.weight
+    assertTensorClose(conv0(data).numpy(), conv1(data).numpy())
+
+
+def test_pickle_module():
+    data_shape = (2, 28)
+    data = tensor([])
+    data.set_value(np.random.random(data_shape))
+    mlp = MLP()
+    # pickle before forward
+    with BytesIO() as fout:
+        mge.save(mlp, fout)
+        fout.seek(0)
+        mlp1 = mge.load(fout)
+    pred0 = mlp1(data)
+
+    pred1 = mlp(data)
+
+    # pickle after forward
+    with BytesIO() as fout:
+        mge.save(mlp, fout)
+        fout.seek(0)
+        mlp1 = mge.load(fout)
+    pred2 = mlp1(data)
+
+    assertTensorClose(pred0.numpy(), pred1.numpy(), max_err=5e-6)
+    assertTensorClose(pred0.numpy(), pred2.numpy(), max_err=5e-6)
+
+
+@pytest.mark.skip(reason="under development")
+def test_dump_model():
+    data_shape = (2, 28)
+    data = tensor([])
+    data.set_value(np.random.random(data_shape))
+    mlp = MLP()
+    pred = mlp(data)
+    f = tempfile.NamedTemporaryFile(delete=False)
+    f_name = f.name
+    try:
+        mge.dump(pred, f_name)
+    finally:
+        f.close()
+        os.unlink(f_name)
+
+
+def test_load_quantized():
+    from megengine.core.tensor import dtype
+
+    data_shape = (2, 28)
+    data = tensor(np.random.random(data_shape), dtype="float32")
+    data = data.astype(dtype.qint8(0.1))
+    mlp = MLP()
+    quantize_qat(mlp)
+    quantize(mlp)
+    mlp.dense0.weight = Parameter(mlp.dense0.weight.astype(dtype.qint8(0.001)).numpy())
+    mlp.dense1.weight = Parameter(mlp.dense1.weight.astype(dtype.qint8(0.0002)).numpy())
+    mlp.eval()
+    pred0 = mlp(data)
+
+    with BytesIO() as fout:
+        mge.save(mlp.state_dict(), fout)
+        fout.seek(0)
+        checkpoint = mge.load(fout)
+
+    # change mlp weight.
+    mlp.dense0.weight = Parameter(
+        mlp.dense0.weight.astype(dtype.qint8(0.00001)).numpy()
+    )
+    mlp.dense1.weight = Parameter(
+        mlp.dense1.weight.astype(dtype.qint8(0.2)).numpy()
+    )
+    mlp.load_state_dict(checkpoint)
+    pred1 = mlp(data)
+
+    assertTensorClose(
+        pred0.astype("float32").numpy(), pred1.astype("float32").numpy(), max_err=5e-6
+    )
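A minimal state-dict round trip in the same style as the tests above (a sketch using only the calls already exercised there):

    from io import BytesIO
    import megengine as mge

    net = MLP()
    with BytesIO() as f:
        mge.save(net.state_dict(), f)
        f.seek(0)
        net.load_state_dict(mge.load(f))  # strict by default: keys must match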
@@ -0,0 +1,91 @@
+from itertools import product
+
+import numpy as np
+
+from megengine import tensor
+from megengine.module import (
+    Conv2d,
+    ConvBn2d,
+    ConvRelu2d,
+    DequantStub,
+    Module,
+    QuantStub,
+)
+from megengine.quantization.quantize import disable_fake_quant, quantize_qat
+from megengine.test import assertTensorClose
+
+
+def test_qat_convbn2d():
+    in_channels = 32
+    out_channels = 64
+    kernel_size = 3
+    for groups, bias in product([1, 4], [True, False]):
+        module = ConvBn2d(
+            in_channels, out_channels, kernel_size, groups=groups, bias=bias
+        )
+        module.train()
+        qat_module = quantize_qat(module, inplace=False)
+        disable_fake_quant(qat_module)
+        inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
+        normal_outputs = module(inputs)
+        qat_outputs = qat_module(inputs)
+        assertTensorClose(normal_outputs.numpy(), qat_outputs.numpy(), max_err=5e-6)
+        assertTensorClose(
+            module.bn.running_mean.numpy(),
+            qat_module.bn.running_mean.numpy(),
+            max_err=5e-8,
+        )
+        assertTensorClose(
+            module.bn.running_var.numpy(),
+            qat_module.bn.running_var.numpy(),
+            max_err=5e-7,
+        )
+        module.eval()
+        normal_outputs = module(inputs)
+        qat_module.eval()
+        qat_outputs = qat_module(inputs)
+        assertTensorClose(normal_outputs.numpy(), qat_outputs.numpy(), max_err=5e-6)
+
+
+def test_qat_conv():
+    in_channels = 32
+    out_channels = 64
+    kernel_size = 3
+
+    class TestNet(Module):
+        def __init__(self, groups, bias):
+            super().__init__()
+            self.quant = QuantStub()
+            self.dequant = DequantStub()
+            self.conv = Conv2d(
+                in_channels, out_channels, kernel_size, groups=groups, bias=bias
+            )
+            self.conv_relu = ConvRelu2d(
+                out_channels, in_channels, kernel_size, groups=groups, bias=bias
+            )
+
+        def forward(self, inp):
+            out = self.quant(inp)
+            out = self.conv(out)
+            out = self.conv_relu(out)
+            out = self.dequant(out)
+            return out
+
+    inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
+    for groups, bias in product([1, 4], [True, False]):
+        net = TestNet(groups, bias)
+        net.train()
+        qat_net = quantize_qat(net, inplace=False)
+        disable_fake_quant(qat_net)
+        normal_outputs = net(inputs)
+        qat_outputs = qat_net(inputs)
+        assertTensorClose(normal_outputs.numpy(), qat_outputs.numpy())
+        net.eval()
+        normal_outputs = net(inputs)
+        qat_net.eval()
+        qat_outputs = qat_net(inputs)
+        assertTensorClose(normal_outputs.numpy(), qat_outputs.numpy())
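For clarity, the (groups, bias) grid the QAT tests sweep over is just the 4-element Cartesian product:

    from itertools import product
    assert list(product([1, 4], [True, False])) == [
        (1, True), (1, False), (4, True), (4, False),
    ]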
@@ -0,0 +1,89 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import copy
+
+import numpy as np
+import pytest
+
+import megengine as mge
+import megengine.functional as F
+from megengine import Buffer, Parameter
+from megengine.module import Conv2d
+from megengine.test import assertTensorClose
+
+
+def test_set_value():
+    v0 = np.random.random((2, 3)).astype(np.float32)
+    param = Parameter(v0)
+    v1 = np.random.random((2, 3)).astype(np.float32)
+    param.set_value(v1)
+    assertTensorClose(param.numpy(), v1, max_err=5e-6)
+    v2 = np.random.random((3, 3)).astype(np.float32)
+    # TODO: add this
+    # with pytest.raises(ValueError):
+    #     param.set_value(v2)
+    assertTensorClose(param.numpy(), v1, max_err=5e-6)
+
+
+@pytest.mark.skip(reason="fill unsupported")
+def test_fill():
+    a = Buffer(np.zeros((2, 3), dtype=np.float32))
+    a.fill(3)
+    assertTensorClose(a.numpy(), np.full((2, 3), 3, dtype=np.float32))
+    a.fill(124.568)
+    assertTensorClose(a.numpy(), np.full((2, 3), 124.568, dtype=np.float32))
+
+
+# TODO: remove or rewrite following test
+# def test_attach():
+#     p_ = np.random.random((2, 3)).astype(np.float32)
+#     with Graph() as g:
+#         g.set_option('eager_evaluation', False)
+#         p = Parameter(p_)
+#         v = p * 2
+#         f = compile(v, None)
+#         out, = f()
+#         assertTensorClose(out, p_ * 2)
+#         F.add_update(p, p)
+#         out, = f()
+#         assertTensorClose(out, p_ * 4)
+
+
+# TODO: remove or rewrite following test
+# def test_module_attach():
+#     v = np.random.random((1, 3, 64, 64)).astype(np.float32)
+#     net = Conv2d(3, 16, 3)
+#     with Graph() as g:
+#         g.set_option('eager_evaluation', False)
+#         data0 = Input("data")
+#         f = compile(net(data0), None)
+#         out0, = f(data=v)
+#         data1 = Input("data", value=v)
+#         out1 = net(data1)
+#         assertTensorClose(out0, out1.numpy())
+
+
+# def test_shape_warning():
+#     with Graph() as cg:
+#         cg.set_option("eager_evaluation", False)
+#         b = Buffer(np.ones((2, 3)).astype(np.float32))
+#         with pytest.warns(None) as record:
+#             print(b.shape)
+#         if len(record) != 0:
+#             raise ValueError(
+#                 "Getting the shape of a constant Tensor should throw no Warning"
+#             )
@@ -1,112 +0,0 @@
-# -*- coding: utf-8 -*-
-# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
-#
-# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-import platform
-
-import pytest
-
-import megengine as mge
-import megengine.distributed as dist
-from megengine import tensor
-from megengine.distributed.group import Group
-from megengine.distributed.helper import get_device_count_by_fork
-from megengine.module import SyncBatchNorm
-from megengine.test import assertTensorClose
-
-
-@pytest.mark.skipif(
-    platform.system() == "Darwin", reason="do not implement GPU mode on macOS now"
-)
-@pytest.mark.skipif(
-    platform.system() == "Windows", reason="do not implement GPU mode on Windows now"
-)
-@pytest.mark.skipif(get_device_count_by_fork("gpu") < 4, reason="need more gpu device")
-@pytest.mark.isolated_distributed
-def test_syncbn():
-    import numpy as np
-    import multiprocessing as mp
-    from megengine.distributed.group import Server
-    from megengine.core._trace_option import use_tensor_shape
-
-    if use_tensor_shape():  # XXX: fix sync bn if use_tensor_shape
-        return
-    nr_chan = 8
-    nr_ranks = 4
-    data_shape = (3, nr_chan, 4, nr_ranks * 8)
-    momentum = 0.9
-    eps = 1e-5
-    running_mean = np.zeros((1, nr_chan, 1, 1), dtype=np.float32)
-    running_var = np.ones((1, nr_chan, 1, 1), dtype=np.float32)
-    steps = 4
-
-    server = Server(0)
-    port = server.py_server_port
-
-    def worker(rank, data, yv_expect, running_mean, running_var):
-        dist.init_process_group("localhost", port, nr_ranks, rank, rank)
-        group = Group([i for i in range(nr_ranks)])
-        bn = SyncBatchNorm(nr_chan, eps=eps, momentum=momentum, group=group)
-        data_tensor = None
-        for i in range(steps):
-            if data_tensor is None:
-                data_tensor = tensor(data[i], device=f"gpu{rank}:0")
-            else:
-                data_tensor.set_value(data[i])
-            yv = bn(data_tensor)
-
-            assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)
-            assertTensorClose(running_mean, bn.running_mean.numpy(), max_err=5e-6)
-            assertTensorClose(running_var, bn.running_var.numpy(), max_err=5e-6)
-
-    xv = []
-    for i in range(steps):
-        xv.append(np.random.normal(loc=2.3, size=data_shape).astype(np.float32))
-        xv_transposed = np.transpose(xv[i], [0, 2, 3, 1]).reshape(
-            (data_shape[0] * data_shape[2] * data_shape[3], nr_chan)
-        )
-
-        mean = np.mean(xv_transposed, axis=0).reshape(1, nr_chan, 1, 1)
-
-        var_biased = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1, 1))
-        sd = np.sqrt(var_biased + eps)
-
-        var_unbiased = np.var(xv_transposed, axis=0, ddof=1).reshape((1, nr_chan, 1, 1))
-        running_mean = running_mean * momentum + mean * (1 - momentum)
-        running_var = running_var * momentum + var_unbiased * (1 - momentum)
-
-        yv_expect = (xv[i] - mean) / sd
-
-    data = []
-    for i in range(nr_ranks):
-        data.append([])
-        for j in range(steps):
-            data[i].append(xv[j][:, :, :, i * 8 : i * 8 + 8])
-
-    procs = []
-    for rank in range(nr_ranks):
-        p = mp.Process(
-            target=worker,
-            args=(
-                rank,
-                data[rank],
-                yv_expect[:, :, :, rank * 8 : rank * 8 + 8],
-                running_mean,
-                running_var,
-            ),
-        )
-        p.start()
-        procs.append(p)
-
-    for p in procs:
-        p.join(10)
-        assert p.exitcode == 0
-
-
-def test_module_conv2d():
-    from megengine.module.conv import Conv2d
-
-    conv = Conv2d(2, 3, 1)