GitOrigin-RevId: e690bc42b0
tags/v1.0.0-rc1
| @@ -209,7 +209,7 @@ def conv_transpose2d( | |||
| dilate_w=dilate_w, | |||
| strategy=get_conv_execution_strategy(), | |||
| ) | |||
| - (output,) = apply(op, inp, weight) | |||
| + (output,) = apply(op, weight, inp) | |||
| if bias is not None: | |||
| output += bias | |||
| return output | |||
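For reference, the one-line fix above swaps the tensor order so the filter is passed to `apply` before the input, which is the order the deconvolution op expects. A minimal numpy sketch of what `conv_transpose2d` computes, mirroring the naive loop used in `test_conv_transpose2d` later in this diff (all names are local to this example):

```python
import itertools

import numpy as np


def conv_transpose2d_ref(inp, weight, stride=(1, 1)):
    # inp: (N, IC, IH, IW); weight: (IC, OC, KH, KW)
    N, IC, IH, IW = inp.shape
    _, OC, KH, KW = weight.shape
    SH, SW = stride
    out = np.zeros((N, OC, (IH - 1) * SH + KH, (IW - 1) * SW + KW), np.float32)
    for n, ic, ih, iw in itertools.product(*map(range, [N, IC, IH, IW])):
        oh, ow = ih * SH, iw * SW
        # each input element scatters a weighted copy of its filter into the output
        out[n, :, oh : oh + KH, ow : ow + KW] += inp[n, ic, ih, iw] * weight[ic]
    return out
```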
| @@ -241,7 +241,7 @@ def local_conv2d( | |||
| pad_w=pad_w, | |||
| dilate_h=dilate_h, | |||
| dilate_w=dilate_w, | |||
| - strategy=get_conv_execution_strategy(), | |||
| + # strategy=get_conv_execution_strategy(), | |||
| ) | |||
| (output,) = apply(op, inp, weight) | |||
| if bias is not None: | |||
| @@ -724,7 +724,7 @@ def sync_batch_norm( | |||
| """ | |||
| assert eps_mode in {"MAX", "ADDITIVE"}, "unknown eps_mode: {}".format(eps_mode) | |||
| _channels = input.shape[1] | |||
| - _ndim = len(input.shape) | |||
| + _ndim = input.ndim | |||
| _param_shape = (1, _channels) + (1,) * (_ndim - 2) | |||
| if training: | |||
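As a quick illustration of the `_param_shape` expression above — a plain-Python sketch, illustrative only:

```python
def param_shape(channels, ndim):
    # broadcastable shape for per-channel statistics: (1, C, 1, ..., 1)
    return (1, channels) + (1,) * (ndim - 2)

assert param_shape(8, 4) == (1, 8, 1, 1)  # NCHW input
assert param_shape(8, 3) == (1, 8, 1)     # NCL input
```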
| @@ -12,7 +12,7 @@ import numpy as np | |||
| from ..distributed.group import WORLD, Group | |||
| from ..functional import batch_norm2d, sync_batch_norm | |||
| - from ..tensor_nn import Buffer, Parameter | |||
| + from ..tensor_nn import Buffer, Parameter, Tensor | |||
| from . import init | |||
| from .module import Module | |||
| @@ -74,12 +74,12 @@ class _BatchNorm(Module): | |||
| _ndims = len(inp.shape) | |||
| if _ndims != 4: | |||
| - origin_shape = inp.shapeof() | |||
| + origin_shape = inp.shape | |||
| if _ndims == 2: | |||
| - n, c = inp.shapeof(0), inp.shapeof(1) | |||
| + n, c = inp.shape[0], inp.shape[1] | |||
| new_shape = (n, c, 1, 1) | |||
| elif _ndims == 3: | |||
| - n, c, h = inp.shapeof(0), inp.shapeof(1), inp.shapeof(2) | |||
| + n, c, h = inp.shape[0], inp.shape[1], inp.shape[2] | |||
| new_shape = (n, c, h, 1) | |||
| inp = inp.reshape(new_shape) | |||
| @@ -127,7 +127,7 @@ class SyncBatchNorm(_BatchNorm): | |||
| affine=True, | |||
| track_running_stats=True, | |||
| freeze=False, | |||
| - group: Optional[Group] = None, | |||
| + group: Optional[Group] = WORLD, | |||
| ) -> None: | |||
| super().__init__( | |||
| num_features, eps, momentum, affine, track_running_stats, freeze | |||
| @@ -145,13 +145,16 @@ class SyncBatchNorm(_BatchNorm): | |||
| _ndims = len(inp.shape) | |||
| if _ndims != 4: | |||
| - origin_shape = inp.shapeof() | |||
| + new_shape = Tensor([1, 1, 1, 1], device=inp.device) | |||
| + origin_shape = inp.shape | |||
| if _ndims == 2: | |||
| - n, c = inp.shape[0], inp.shape[1] | |||
| - new_shape = (n, c, 1, 1) | |||
| + new_shape[:2] = origin_shape[:2] | |||
| elif _ndims == 3: | |||
| - n, c, h = inp.shape[0], inp.shape[1], inp.shape[2] | |||
| - new_shape = (n, c, h, 1) | |||
| + new_shape[:3] = origin_shape[:3] | |||
| + else: | |||
| + raise ValueError( | |||
| + "expected 2D, 3D or 4D input (got {}D input)".format(len(inp.shape)) | |||
| + ) | |||
| inp = inp.reshape(new_shape) | |||
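The hunk above replaces the tuple-built `new_shape` with a device `Tensor` filled by slice assignment, presumably so the reshape target can be derived from `origin_shape` without unpacking individual dimensions. A rough numpy analogue of that shape construction for 2D/3D inputs (illustrative only):

```python
import numpy as np


def padded_shape(origin_shape):
    # start from an all-ones 4D shape, then copy the leading dims in place
    new_shape = np.array([1, 1, 1, 1])
    ndim = len(origin_shape)
    if ndim not in (2, 3):
        raise ValueError("expected 2D, 3D or 4D input (got {}D input)".format(ndim))
    new_shape[:ndim] = origin_shape[:ndim]
    return new_shape

print(padded_shape((32, 8)))      # [32  8  1  1]
print(padded_shape((32, 8, 10)))  # [32  8 10  1]
```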
| @@ -376,7 +376,13 @@ class LocalConv2d(Conv2d): | |||
| def forward(self, inp): | |||
| return local_conv2d( | |||
| - inp, self.weight, self.stride, self.padding, self.dilation, self.conv_mode | |||
| + inp, | |||
| + self.weight, | |||
| + None, | |||
| + self.stride, | |||
| + self.padding, | |||
| + self.dilation, | |||
| + self.conv_mode, | |||
| ) | |||
| @@ -0,0 +1,24 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import numpy as np | |||
| import megengine as mge | |||
| from megengine.module import LeakyReLU | |||
| from megengine.test import assertTensorClose | |||
| def test_leaky_relu(): | |||
| data = np.array([-8, -12, 6, 10]).astype(np.float32) | |||
| negative_slope = 0.1 | |||
| leaky_relu = LeakyReLU(negative_slope) | |||
| output = leaky_relu(mge.tensor(data)) | |||
| np_output = np.maximum(0, data) + negative_slope * np.minimum(0, data) | |||
| assertTensorClose(output.numpy(), np_output, max_err=0) | |||
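The reference output uses the identity `LeakyReLU(x) = max(0, x) + negative_slope * min(0, x)`; an equivalent formulation, shown as a small self-check (local names only):

```python
import numpy as np


def leaky_relu_ref(x, negative_slope=0.1):
    # identical to np.maximum(0, x) + negative_slope * np.minimum(0, x)
    return np.where(x >= 0, x, negative_slope * x)

x = np.array([-8, -12, 6, 10], dtype=np.float32)
assert np.allclose(leaky_relu_ref(x), np.maximum(0, x) + 0.1 * np.minimum(0, x))
```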
| @@ -0,0 +1,419 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import multiprocessing as mp | |||
| import platform | |||
| import numpy as np | |||
| import pytest | |||
| import megengine as mge | |||
| import megengine.distributed as dist | |||
| from megengine import tensor | |||
| from megengine.core._trace_option import use_tensor_shape | |||
| from megengine.module import BatchNorm1d, BatchNorm2d, SyncBatchNorm | |||
| from megengine.tensor import Tensor | |||
| from megengine.test import assertTensorClose | |||
| @pytest.mark.skipif( | |||
| platform.system() == "Darwin", reason="do not imp GPU mode at macos now" | |||
| ) | |||
| @pytest.mark.skipif( | |||
| platform.system() == "Windows", reason="do not imp GPU mode at Windows now" | |||
| ) | |||
| @pytest.mark.skipif(use_tensor_shape(), reason="syncbn doesnot support symbolic shape") | |||
| @pytest.mark.isolated_distributed | |||
| def test_syncbn(): | |||
| nr_chan = 8 | |||
| data_shape = (3, nr_chan, 4, 16) | |||
| momentum = 0.9 | |||
| eps = 1e-5 | |||
| running_mean = np.zeros((1, nr_chan, 1, 1), dtype=np.float32) | |||
| running_var = np.ones((1, nr_chan, 1, 1), dtype=np.float32) | |||
| steps = 4 | |||
| nr_ranks = 2 | |||
| server = dist.Server(0) | |||
| port = server.py_server_port | |||
| def worker(rank, data, yv_expect, running_mean, running_var): | |||
| if mge.get_device_count("gpu") < nr_ranks: | |||
| return | |||
| dist.init_process_group("localhost", port, nr_ranks, rank, rank) | |||
| bn = SyncBatchNorm(nr_chan, momentum=momentum, eps=eps) | |||
| data_tensor = tensor([]) | |||
| for i in range(steps): | |||
| data_tensor.set_value(data[i]) | |||
| yv = bn(data_tensor) | |||
| assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6) | |||
| assertTensorClose(running_mean, bn.running_mean.numpy(), max_err=5e-6) | |||
| assertTensorClose(running_var, bn.running_var.numpy(), max_err=5e-6) | |||
| xv = [] | |||
| for i in range(steps): | |||
| xv.append(np.random.normal(loc=2.3, size=data_shape).astype(np.float32)) | |||
| xv_transposed = np.transpose(xv[i], [0, 2, 3, 1]).reshape( | |||
| (data_shape[0] * data_shape[2] * data_shape[3], nr_chan) | |||
| ) | |||
| mean = np.mean(xv_transposed, axis=0).reshape(1, nr_chan, 1, 1) | |||
| var_biased = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1, 1)) | |||
| sd = np.sqrt(var_biased + eps) | |||
| var_unbiased = np.var(xv_transposed, axis=0, ddof=1).reshape((1, nr_chan, 1, 1)) | |||
| running_mean = running_mean * momentum + mean * (1 - momentum) | |||
| running_var = running_var * momentum + var_unbiased * (1 - momentum) | |||
| yv_expect = (xv[i] - mean) / sd | |||
| data = [] | |||
| for i in range(nr_ranks): | |||
| data.append([]) | |||
| for j in range(steps): | |||
| data[i].append(xv[j][:, :, :, i * 8 : i * 8 + 8]) | |||
| procs = [] | |||
| for rank in range(nr_ranks): | |||
| p = mp.Process( | |||
| target=worker, | |||
| args=( | |||
| rank, | |||
| data[rank], | |||
| yv_expect[:, :, :, rank * 8 : rank * 8 + 8], | |||
| running_mean, | |||
| running_var, | |||
| ), | |||
| ) | |||
| p.start() | |||
| procs.append(p) | |||
| for p in procs: | |||
| p.join(10) | |||
| assert p.exitcode == 0 | |||
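The expected values above follow the usual batch-norm bookkeeping: normalization divides by the biased batch variance, while the running variance is updated from the unbiased (`ddof=1`) estimate. A compact restatement of that update, as a sketch with assumed shapes:

```python
import numpy as np


def update_running_stats(running_mean, running_var, batch, momentum=0.9, eps=1e-5):
    # batch: (M, C) — samples flattened over all non-channel axes
    mean = batch.mean(axis=0)
    var_biased = batch.var(axis=0)            # used to normalize the batch
    var_unbiased = batch.var(axis=0, ddof=1)  # used for the running estimate
    running_mean = running_mean * momentum + mean * (1 - momentum)
    running_var = running_var * momentum + var_unbiased * (1 - momentum)
    inv_sd = 1.0 / np.sqrt(var_biased + eps)
    return running_mean, running_var, inv_sd
```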
| def test_batchnorm(): | |||
| nr_chan = 8 | |||
| data_shape = (3, nr_chan, 4) | |||
| momentum = 0.9 | |||
| bn = BatchNorm1d(nr_chan, momentum=momentum) | |||
| running_mean = np.zeros((1, nr_chan, 1), dtype=np.float32) | |||
| running_var = np.ones((1, nr_chan, 1), dtype=np.float32) | |||
| data = tensor([]) | |||
| for i in range(3): | |||
| xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32) | |||
| mean = np.mean(np.mean(xv, axis=0, keepdims=True), axis=2, keepdims=True) | |||
| xv_transposed = np.transpose(xv, [0, 2, 1]).reshape( | |||
| (data_shape[0] * data_shape[2], nr_chan) | |||
| ) | |||
| var_biased = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1)) | |||
| sd = np.sqrt(var_biased + bn.eps) | |||
| var_unbiased = np.var(xv_transposed, axis=0, ddof=1).reshape((1, nr_chan, 1)) | |||
| running_mean = running_mean * momentum + mean * (1 - momentum) | |||
| running_var = running_var * momentum + var_unbiased * (1 - momentum) | |||
| data.set_value(xv) | |||
| yv = bn(data) | |||
| yv_expect = (xv - mean) / sd | |||
| assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6) | |||
| assertTensorClose( | |||
| running_mean.reshape(-1), bn.running_mean.numpy().reshape(-1), max_err=5e-6 | |||
| ) | |||
| assertTensorClose( | |||
| running_var.reshape(-1), bn.running_var.numpy().reshape(-1), max_err=5e-6 | |||
| ) | |||
| # test set 'training' flag to False | |||
| mean_backup = bn.running_mean.numpy() | |||
| var_backup = bn.running_var.numpy() | |||
| bn.training = False | |||
| xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32) | |||
| data.set_value(xv) | |||
| yv1 = bn(data) | |||
| yv2 = bn(data) | |||
| assertTensorClose(yv1.numpy(), yv2.numpy(), max_err=0) | |||
| assertTensorClose(mean_backup, bn.running_mean.numpy(), max_err=0) | |||
| assertTensorClose(var_backup, bn.running_var.numpy(), max_err=0) | |||
| yv_expect = (xv - running_mean) / np.sqrt(running_var + bn.eps) | |||
| assertTensorClose(yv_expect, yv1.numpy(), max_err=5e-6) | |||
| @pytest.mark.skipif( | |||
| platform.system() == "Darwin", reason="do not imp GPU mode at macos now" | |||
| ) | |||
| @pytest.mark.skipif( | |||
| platform.system() == "Windows", reason="do not imp GPU mode at Windows now" | |||
| ) | |||
| @pytest.mark.skipif(use_tensor_shape(), reason="syncbn doesnot support symbolic shape") | |||
| @pytest.mark.isolated_distributed | |||
| def test_syncbn1d(): | |||
| nr_chan = 8 | |||
| data_shape = (3, nr_chan, 4) | |||
| momentum = 0.9 | |||
| bn = SyncBatchNorm(nr_chan, momentum=momentum) | |||
| running_mean = np.zeros((1, nr_chan, 1), dtype=np.float32) | |||
| running_var = np.ones((1, nr_chan, 1), dtype=np.float32) | |||
| data = tensor([]) | |||
| for i in range(3): | |||
| xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32) | |||
| mean = np.mean(np.mean(xv, axis=0, keepdims=True), axis=2, keepdims=True) | |||
| xv_transposed = np.transpose(xv, [0, 2, 1]).reshape( | |||
| (data_shape[0] * data_shape[2], nr_chan) | |||
| ) | |||
| var_biased = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1)) | |||
| sd = np.sqrt(var_biased + bn.eps) | |||
| var_unbiased = np.var(xv_transposed, axis=0, ddof=1).reshape((1, nr_chan, 1)) | |||
| running_mean = running_mean * momentum + mean * (1 - momentum) | |||
| running_var = running_var * momentum + var_unbiased * (1 - momentum) | |||
| data.set_value(xv) | |||
| yv = bn(data) | |||
| yv_expect = (xv - mean) / sd | |||
| assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6) | |||
| assertTensorClose( | |||
| running_mean.reshape(-1), bn.running_mean.numpy().reshape(-1), max_err=5e-6 | |||
| ) | |||
| assertTensorClose( | |||
| running_var.reshape(-1), bn.running_var.numpy().reshape(-1), max_err=5e-6 | |||
| ) | |||
| # test set 'training' flag to False | |||
| mean_backup = bn.running_mean.numpy() | |||
| var_backup = bn.running_var.numpy() | |||
| bn.training = False | |||
| xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32) | |||
| data.set_value(xv) | |||
| yv1 = bn(data) | |||
| yv2 = bn(data) | |||
| assertTensorClose(yv1.numpy(), yv2.numpy(), max_err=0) | |||
| assertTensorClose(mean_backup, bn.running_mean.numpy(), max_err=0) | |||
| assertTensorClose(var_backup, bn.running_var.numpy(), max_err=0) | |||
| yv_expect = (xv - running_mean) / np.sqrt(running_var + bn.eps) | |||
| assertTensorClose(yv_expect, yv1.numpy(), max_err=5e-6) | |||
| def test_batchnorm2d(): | |||
| nr_chan = 8 | |||
| data_shape = (3, nr_chan, 16, 16) | |||
| momentum = 0.9 | |||
| bn = BatchNorm2d(nr_chan, momentum=momentum) | |||
| running_mean = np.zeros((1, nr_chan, 1, 1), dtype=np.float32) | |||
| running_var = np.ones((1, nr_chan, 1, 1), dtype=np.float32) | |||
| data = tensor([]) | |||
| for i in range(3): | |||
| xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32) | |||
| xv_transposed = np.transpose(xv, [0, 2, 3, 1]).reshape( | |||
| (data_shape[0] * data_shape[2] * data_shape[3], nr_chan) | |||
| ) | |||
| mean = np.mean(xv_transposed, axis=0).reshape(1, nr_chan, 1, 1) | |||
| var_biased = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1, 1)) | |||
| sd = np.sqrt(var_biased + bn.eps) | |||
| var_unbiased = np.var(xv_transposed, axis=0, ddof=1).reshape((1, nr_chan, 1, 1)) | |||
| running_mean = running_mean * momentum + mean * (1 - momentum) | |||
| running_var = running_var * momentum + var_unbiased * (1 - momentum) | |||
| data.set_value(xv) | |||
| yv = bn(data) | |||
| yv_expect = (xv - mean) / sd | |||
| assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6) | |||
| assertTensorClose(running_mean, bn.running_mean.numpy(), max_err=5e-6) | |||
| assertTensorClose(running_var, bn.running_var.numpy(), max_err=5e-6) | |||
| # test set 'training' flag to False | |||
| mean_backup = bn.running_mean.numpy() | |||
| var_backup = bn.running_var.numpy() | |||
| bn.training = False | |||
| xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32) | |||
| data.set_value(xv) | |||
| yv1 = bn(data) | |||
| yv2 = bn(data) | |||
| assertTensorClose(yv1.numpy(), yv2.numpy(), max_err=0) | |||
| assertTensorClose(mean_backup, bn.running_mean.numpy(), max_err=0) | |||
| assertTensorClose(var_backup, bn.running_var.numpy(), max_err=0) | |||
| yv_expect = (xv - running_mean) / np.sqrt(running_var + bn.eps) | |||
| assertTensorClose(yv_expect, yv1.numpy(), max_err=5e-6) | |||
| @pytest.mark.skipif( | |||
| platform.system() == "Darwin", reason="do not imp GPU mode at macos now" | |||
| ) | |||
| @pytest.mark.skipif( | |||
| platform.system() == "Windows", reason="do not imp GPU mode at Windows now" | |||
| ) | |||
| @pytest.mark.skipif(use_tensor_shape(), reason="syncbn doesnot support symbolic shape") | |||
| @pytest.mark.isolated_distributed | |||
| def test_syncbn2d(): | |||
| nr_chan = 8 | |||
| data_shape = (3, nr_chan, 16, 16) | |||
| momentum = 0.9 | |||
| bn = SyncBatchNorm(nr_chan, momentum=momentum) | |||
| running_mean = np.zeros((1, nr_chan, 1, 1), dtype=np.float32) | |||
| running_var = np.ones((1, nr_chan, 1, 1), dtype=np.float32) | |||
| data = tensor([]) | |||
| for i in range(3): | |||
| xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32) | |||
| xv_transposed = np.transpose(xv, [0, 2, 3, 1]).reshape( | |||
| (data_shape[0] * data_shape[2] * data_shape[3], nr_chan) | |||
| ) | |||
| mean = np.mean(xv_transposed, axis=0).reshape(1, nr_chan, 1, 1) | |||
| var_biased = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1, 1)) | |||
| sd = np.sqrt(var_biased + bn.eps) | |||
| var_unbiased = np.var(xv_transposed, axis=0, ddof=1).reshape((1, nr_chan, 1, 1)) | |||
| running_mean = running_mean * momentum + mean * (1 - momentum) | |||
| running_var = running_var * momentum + var_unbiased * (1 - momentum) | |||
| data.set_value(xv) | |||
| yv = bn(data) | |||
| yv_expect = (xv - mean) / sd | |||
| assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6) | |||
| assertTensorClose(running_mean, bn.running_mean.numpy(), max_err=5e-6) | |||
| assertTensorClose(running_var, bn.running_var.numpy(), max_err=5e-6) | |||
| # test set 'training' flag to False | |||
| mean_backup = bn.running_mean.numpy() | |||
| var_backup = bn.running_var.numpy() | |||
| bn.training = False | |||
| xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32) | |||
| data.set_value(xv) | |||
| yv1 = bn(data) | |||
| yv2 = bn(data) | |||
| assertTensorClose(yv1.numpy(), yv2.numpy(), max_err=0) | |||
| assertTensorClose(mean_backup, bn.running_mean.numpy(), max_err=0) | |||
| assertTensorClose(var_backup, bn.running_var.numpy(), max_err=0) | |||
| yv_expect = (xv - running_mean) / np.sqrt(running_var + bn.eps) | |||
| assertTensorClose(yv_expect, yv1.numpy(), max_err=5e-6) | |||
| def test_batchnorm_no_stats(): | |||
| nr_chan = 8 | |||
| data_shape = (3, nr_chan, 4) | |||
| bn = BatchNorm1d(8, track_running_stats=False) | |||
| data = tensor([]) | |||
| for i in range(4): | |||
| if i == 2: | |||
| bn.training = False | |||
| xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32) | |||
| mean = np.mean(np.mean(xv, axis=0, keepdims=True), axis=2, keepdims=True) | |||
| var = np.var( | |||
| np.transpose(xv, [0, 2, 1]).reshape( | |||
| (data_shape[0] * data_shape[2], nr_chan) | |||
| ), | |||
| axis=0, | |||
| ).reshape((1, nr_chan, 1)) | |||
| sd = np.sqrt(var + bn.eps) | |||
| data.set_value(xv) | |||
| yv = bn(data) | |||
| yv_expect = (xv - mean) / sd | |||
| assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6) | |||
| @pytest.mark.skipif( | |||
| platform.system() == "Darwin", reason="do not imp GPU mode at macos now" | |||
| ) | |||
| @pytest.mark.skipif( | |||
| platform.system() == "Windows", reason="do not imp GPU mode at Windows now" | |||
| ) | |||
| @pytest.mark.skipif(use_tensor_shape(), reason="syncbn doesnot support symbolic shape") | |||
| @pytest.mark.isolated_distributed | |||
| def test_syncbn_no_stats(): | |||
| nr_chan = 8 | |||
| data_shape = (3, nr_chan, 4) | |||
| bn = SyncBatchNorm(8, track_running_stats=False) | |||
| data = tensor([]) | |||
| for i in range(4): | |||
| if i == 2: | |||
| bn.training = False | |||
| xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32) | |||
| mean = np.mean(np.mean(xv, axis=0, keepdims=True), axis=2, keepdims=True) | |||
| var = np.var( | |||
| np.transpose(xv, [0, 2, 1]).reshape( | |||
| (data_shape[0] * data_shape[2], nr_chan) | |||
| ), | |||
| axis=0, | |||
| ).reshape((1, nr_chan, 1)) | |||
| sd = np.sqrt(var + bn.eps) | |||
| data.set_value(xv) | |||
| yv = bn(data) | |||
| yv_expect = (xv - mean) / sd | |||
| assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6) | |||
| def test_batchnorm2d_no_stats(): | |||
| nr_chan = 8 | |||
| data_shape = (3, nr_chan, 16, 16) | |||
| bn = BatchNorm2d(8, track_running_stats=False) | |||
| data = tensor([]) | |||
| for i in range(4): | |||
| if i == 2: | |||
| bn.training = False | |||
| xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32) | |||
| xv_transposed = np.transpose(xv, [0, 2, 3, 1]).reshape( | |||
| (data_shape[0] * data_shape[2] * data_shape[3], nr_chan) | |||
| ) | |||
| mean = np.mean(xv_transposed, axis=0).reshape(1, nr_chan, 1, 1) | |||
| var = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1, 1)) | |||
| sd = np.sqrt(var + bn.eps) | |||
| data.set_value(xv) | |||
| yv = bn(data) | |||
| yv_expect = (xv - mean) / sd | |||
| assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6) | |||
| @pytest.mark.skipif( | |||
| platform.system() == "Darwin", reason="do not imp GPU mode at macos now" | |||
| ) | |||
| @pytest.mark.skipif( | |||
| platform.system() == "Windows", reason="do not imp GPU mode at Windows now" | |||
| ) | |||
| @pytest.mark.skipif(use_tensor_shape(), reason="syncbn doesnot support symbolic shape") | |||
| @pytest.mark.isolated_distributed | |||
| def test_syncbn2d_no_stats(): | |||
| nr_chan = 8 | |||
| data_shape = (3, nr_chan, 16, 16) | |||
| bn = SyncBatchNorm(8, track_running_stats=False) | |||
| data = tensor([]) | |||
| for i in range(4): | |||
| if i == 2: | |||
| bn.training = False | |||
| xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32) | |||
| xv_transposed = np.transpose(xv, [0, 2, 3, 1]).reshape( | |||
| (data_shape[0] * data_shape[2] * data_shape[3], nr_chan) | |||
| ) | |||
| mean = np.mean(xv_transposed, axis=0).reshape(1, nr_chan, 1, 1) | |||
| var = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1, 1)) | |||
| sd = np.sqrt(var + bn.eps) | |||
| data.set_value(xv) | |||
| yv = bn(data) | |||
| yv_expect = (xv - mean) / sd | |||
| assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6) | |||
| @@ -0,0 +1,110 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import itertools | |||
| import numpy as np | |||
| from megengine import Parameter, tensor | |||
| from megengine.module import ConvTranspose2d, LocalConv2d | |||
| from megengine.test import assertTensorClose | |||
| def test_conv_transpose2d(): | |||
| SH, SW = 3, 1 | |||
| PH, PW = 2, 0 | |||
| N, IC, IH, IW = 4, 5, 8, 6 | |||
| KH, KW = 3, 4 | |||
| OC = 3 | |||
| BIAS = False | |||
| def getsize(inp, kern, stride): | |||
| return (inp - 1) * stride + kern | |||
| OH = getsize(IH, KH, SH) | |||
| OW = getsize(IW, KW, SW) | |||
| inp = np.random.normal(size=(N, IC, IH, IW)).astype(np.float32) | |||
| out = np.zeros((N, OC, OH, OW), dtype=np.float32) | |||
| weight = np.random.normal(size=(IC, OC, KH, KW)).astype(np.float32) | |||
| bias = np.random.normal(size=(1, OC, 1, 1)).astype(np.float32) | |||
| # naive calculation using numpy | |||
| for n, ic, ih, iw in itertools.product(*map(range, [N, IC, IH, IW])): | |||
| oh, ow = ih * SH, iw * SW | |||
| out[n, :, oh : oh + KH, ow : ow + KW] += inp[n, ic, ih, iw] * weight[ic] | |||
| out = out[:, :, PH : OH - PH, PW : OW - PW] | |||
| if BIAS: | |||
| out += bias | |||
| # megengine conv_transpose2d calculation | |||
| conv_transpose2d = ConvTranspose2d(IC, OC, (KH, KW), (SH, SW), (PH, PW), bias=BIAS) | |||
| conv_transpose2d.weight = Parameter(weight, dtype=np.float32) | |||
| if BIAS: | |||
| conv_transpose2d.bias = Parameter(bias, dtype=np.float32) | |||
| y = conv_transpose2d(tensor(inp)) | |||
| assertTensorClose(out, y.numpy(), max_err=2e-6) | |||
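Plugging the test's configuration into `getsize` (the transposed-convolution output size without output padding): OH = (8 - 1) * 3 + 3 = 24 before cropping and 24 - 2 * PH = 20 after; OW = (6 - 1) * 1 + 4 = 9 with no crop. A few self-checks:

```python
def getsize(inp, kern, stride):
    return (inp - 1) * stride + kern

# the test's configuration: IH, IW = 8, 6; KH, KW = 3, 4; SH, SW = 3, 1; PH, PW = 2, 0
assert getsize(8, 3, 3) == 24  # full output height before cropping
assert getsize(6, 4, 1) == 9   # full output width before cropping
assert 24 - 2 * 2 == 20        # height after removing PH rows on each side
```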
| def test_local_conv2d(): | |||
| batch_size = 10 | |||
| in_channels = 4 | |||
| out_channels = 8 | |||
| input_height = 8 | |||
| input_width = 8 | |||
| kernel_size = 3 | |||
| stride = 1 | |||
| padding = 1 | |||
| dilation = 1 | |||
| groups = 1 | |||
| local_conv2d = LocalConv2d( | |||
| in_channels=in_channels, | |||
| out_channels=out_channels, | |||
| input_height=input_height, | |||
| input_width=input_width, | |||
| kernel_size=kernel_size, | |||
| stride=stride, | |||
| padding=padding, | |||
| dilation=dilation, | |||
| groups=groups, | |||
| ) | |||
| inputs = np.random.normal( | |||
| size=(batch_size, in_channels, input_height, input_width) | |||
| ).astype(np.float32) | |||
| output_height = (input_height + padding * 2 - kernel_size) // stride + 1 | |||
| output_width = (input_width + padding * 2 - kernel_size) // stride + 1 | |||
| weights = np.random.normal( | |||
| size=( | |||
| groups, | |||
| output_height, | |||
| output_width, | |||
| in_channels // groups, | |||
| kernel_size, | |||
| kernel_size, | |||
| out_channels // groups, | |||
| ) | |||
| ).astype(np.float32) | |||
| local_conv2d.weight = Parameter(weights) | |||
| outputs = local_conv2d(tensor(inputs)) | |||
| # naive calculation using numpy | |||
| # only tests the case output_height == input_height, output_width == input_width, groups == 1 | |||
| inputs = np.pad(inputs, ((0, 0), (0, 0), (1, 1), (1, 1))) | |||
| expected = np.zeros( | |||
| (batch_size, out_channels, output_height, output_width), dtype=np.float32, | |||
| ) | |||
| for n, oc, oh, ow in itertools.product( | |||
| *map(range, [batch_size, out_channels, output_height, output_width]) | |||
| ): | |||
| ih, iw = oh * stride, ow * stride | |||
| expected[n, oc, oh, ow] = np.sum( | |||
| inputs[n, :, ih : ih + kernel_size, iw : iw + kernel_size] | |||
| * weights[0, oh, ow, :, :, :, oc] | |||
| ) | |||
| assertTensorClose(outputs.numpy(), expected, max_err=1e-5) | |||
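Unlike an ordinary convolution, `LocalConv2d` assigns every output position its own filter, which is why the weight has shape `(groups, OH, OW, IC // groups, KH, KW, OC // groups)` and the reference indexes it as `weights[0, oh, ow, ...]`. A small shape sanity check with toy sizes (illustrative only):

```python
import numpy as np

# toy sizes for illustration only
groups, OH, OW, ICg, K, OCg = 1, 8, 8, 4, 3, 8
weights = np.zeros((groups, OH, OW, ICg, K, K, OCg), dtype=np.float32)

# an ordinary conv shares one (ICg, K, K, OCg) filter across all positions;
# local conv stores OH * OW of them:
print(weights[0, 0, 0].shape)               # (4, 3, 3, 8) — the filter at output (0, 0)
print(weights.size // (ICg * K * K * OCg))  # 64 distinct filters, one per position
```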
| @@ -0,0 +1,46 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import os | |||
| import numpy as np | |||
| import pytest | |||
| import megengine as mge | |||
| from megengine import tensor | |||
| from megengine.module import Module | |||
| class MyModule(Module): | |||
| def __init__(self, data): | |||
| from megengine.module.external import CambriconSubgraph | |||
| super().__init__() | |||
| self.cambricon = CambriconSubgraph(data, "subnet0", True) | |||
| def forward(self, inputs): | |||
| out = self.cambricon(inputs) | |||
| return out | |||
| @pytest.mark.skip(reason="cambricon unimplemented") | |||
| def test_cambricon_module(): | |||
| model = "CambriconRuntimeOprTest.MutableBatchSize.mlu" | |||
| model = os.path.join(os.path.dirname(__file__), model) | |||
| with open(model, "rb") as f: | |||
| data = f.read() | |||
| m = MyModule(data) | |||
| inputs = [] | |||
| inputs.append(tensor(data=[], dtype=np.float16, device="cambricon0")) | |||
| inputs[0].set_value(np.random.normal(size=(1, 64, 32, 32)).astype(np.float16)) | |||
| def inference(inps): | |||
| pred = m(inps) | |||
| return pred | |||
| pred = inference(inputs) | |||
| @@ -0,0 +1,27 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import pytest | |||
| from megengine.module import Conv2d, Linear | |||
| from megengine.module.init import calculate_fan_in_and_fan_out | |||
| def test_calculate_fan_in_and_fan_out(): | |||
| l = Linear(in_features=3, out_features=8) | |||
| fanin, fanout = calculate_fan_in_and_fan_out(l.weight) | |||
| assert fanin == 3 | |||
| assert fanout == 8 | |||
| with pytest.raises(ValueError): | |||
| calculate_fan_in_and_fan_out(l.bias) | |||
| l = Conv2d(in_channels=2, out_channels=3, kernel_size=(5, 7)) | |||
| fanin, fanout = calculate_fan_in_and_fan_out(l.weight) | |||
| assert fanin == 2 * 5 * 7 | |||
| assert fanout == 3 * 5 * 7 | |||
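Restating the arithmetic behind the conv assertions: fan_in = in_channels * KH * KW and fan_out = out_channels * KH * KW.

```python
in_channels, out_channels, KH, KW = 2, 3, 5, 7
fan_in = in_channels * KH * KW    # 2 * 5 * 7 = 70
fan_out = out_channels * KH * KW  # 3 * 5 * 7 = 105
assert (fan_in, fan_out) == (70, 105)
```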
| @@ -0,0 +1,614 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import os | |||
| import tempfile | |||
| from collections import OrderedDict | |||
| from io import BytesIO | |||
| import numpy as np | |||
| import pytest | |||
| import megengine as mge | |||
| import megengine.functional as F | |||
| import megengine.jit as jit | |||
| from megengine import Buffer, Parameter, Tensor, tensor | |||
| from megengine.module import ( | |||
| BatchNorm1d, | |||
| BatchNorm2d, | |||
| Conv2d, | |||
| Linear, | |||
| Module, | |||
| Sequential, | |||
| ) | |||
| from megengine.quantization.quantize import quantize, quantize_qat | |||
| from megengine.test import assertTensorClose | |||
| class MLP(Module): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.dense0 = Linear(28, 50) | |||
| self.dense1 = Linear(50, 20) | |||
| def forward(self, x): | |||
| x = self.dense0(x) | |||
| x = F.relu(x) | |||
| x = self.dense1(x) | |||
| return x | |||
| def has_gpu(num=1): | |||
| # probe via the device count; the old mgb.comp_node API is not imported here | |||
| return mge.get_device_count("gpu") >= num | |||
| def randomNp(*args): | |||
| for arg in args: | |||
| assert isinstance(arg, int) | |||
| return np.random.random(args) | |||
| def randomTorch(*args): | |||
| import torch # pylint: disable=import-outside-toplevel | |||
| for arg in args: | |||
| assert isinstance(arg, int) | |||
| return torch.tensor(randomNp(*args), dtype=torch.float32) | |||
| def graph_mode(*modes): | |||
| if not set(modes).issubset({"eager", "static"}): | |||
| raise ValueError("graph mode must be in (eager, static)") | |||
| def decorator(func): | |||
| def wrapper(*args, **kwargs): | |||
| if "eager" in set(modes): | |||
| func(*args, **kwargs) | |||
| if "static" in set(modes): | |||
| with Graph() as cg: # Graph: legacy static-graph context from the old graph API | |||
| cg.set_option("eager_evaluation", False) | |||
| func(*args, **kwargs) | |||
| return wrapper | |||
| return decorator | |||
| def _default_compare_fn(x, y): | |||
| assertTensorClose(x.numpy(), y) | |||
| def opr_test( | |||
| cases, | |||
| func, | |||
| mode=("eager", "static", "dynamic_shape"), | |||
| compare_fn=_default_compare_fn, | |||
| ref_fn=None, | |||
| **kwargs | |||
| ): | |||
| """ | |||
| mode: the list of test mode which are eager, static and dynamic_shape | |||
| will test all the cases if None. | |||
| func: the function to run opr. | |||
| compare_fn: the function to compare the result and expected, use assertTensorClose if None. | |||
| ref_fn: the function to generate expected data, should assign output if None. | |||
| cases: the list which have dict element, the list length should be 2 for dynamic shape test. | |||
| and the dict should have input, | |||
| and should have output if ref_fn is None. | |||
| should use list for multiple inputs and outputs for each case. | |||
| kwargs: The additional kwargs for opr func. | |||
| simple examples: | |||
| dtype = np.float32 | |||
| cases = [{"input": [10, 20]}, {"input": [20, 30]}] | |||
| opr_test(cases, | |||
| F.eye, | |||
| ref_fn=lambda n, m: np.eye(n, m).astype(dtype), | |||
| dtype=dtype) | |||
| """ | |||
| def check_results(results, expected): | |||
| if not isinstance(results, tuple): | |||
| results = (results,) | |||
| for r, e in zip(results, expected): | |||
| compare_fn(r, e) | |||
| def get_trace_fn(func, enabled, symbolic): | |||
| jit.trace.enabled = enabled | |||
| return jit.trace(func, symbolic=symbolic) | |||
| def get_param(cases, idx): | |||
| case = cases[idx] | |||
| inp = case.get("input", None) | |||
| outp = case.get("output", None) | |||
| if inp is None: | |||
| raise ValueError("the test case should have input") | |||
| if not isinstance(inp, list): | |||
| inp = (inp,) | |||
| else: | |||
| inp = tuple(inp) | |||
| if ref_fn is not None and callable(ref_fn): | |||
| outp = ref_fn(*inp) | |||
| if outp is None: | |||
| raise ValueError("the test case should have output or reference function") | |||
| if not isinstance(outp, list): | |||
| outp = (outp,) | |||
| else: | |||
| outp = tuple(outp) | |||
| return inp, outp | |||
| if not set(mode).issubset({"eager", "static", "dynamic_shape"}): | |||
| raise ValueError("opr test mode must be in (eager, static, dynamic_shape)") | |||
| if len(cases) == 0: | |||
| raise ValueError("should give one case at least") | |||
| if "dynamic_shape" in set(mode): | |||
| if len(cases) != 2: | |||
| raise ValueError("should give 2 cases for dynamic shape test") | |||
| if not callable(func): | |||
| raise ValueError("the input func should be callable") | |||
| inp, outp = get_param(cases, 0) | |||
| def run(*args, **kwargs): | |||
| return func(*args, **kwargs) | |||
| if "eager" in set(mode): | |||
| f = get_trace_fn(run, False, False) | |||
| results = f(*inp, **kwargs) | |||
| check_results(results, outp) | |||
| if "static" in set(mode) or "dynamic_shape" in set(mode): | |||
| f = get_trace_fn(run, True, True) | |||
| results = f(*inp, **kwargs) | |||
| check_results(results, outp) | |||
| if "dynamic_shape" in set(mode): | |||
| inp, outp = get_param(cases, 1) | |||
| results = f(*inp, **kwargs) | |||
| check_results(results, outp) | |||
| class MyModule(Module): | |||
| class InnerModule(Module): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.bn = BatchNorm2d(4) | |||
| def forward(self, x): | |||
| return self.bn(x) | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.i = self.InnerModule() | |||
| self.bn = BatchNorm2d(4) | |||
| self.param = Parameter(np.ones(1, dtype=np.float32)) | |||
| self.buff = Buffer(np.ones(1, dtype=np.float32)) | |||
| def forward(self, x): | |||
| x = self.i(x) | |||
| x = self.bn(x) | |||
| return x | |||
| def test_module_api(): | |||
| m = MyModule() | |||
| assert list(m.children()) == [m.bn, m.i] | |||
| assert list(m.named_children()) == [("bn", m.bn), ("i", m.i)] | |||
| assert list(m.modules()) == [m, m.bn, m.i, m.i.bn] | |||
| assert list(m.named_modules()) == [ | |||
| ("", m), | |||
| ("bn", m.bn), | |||
| ("i", m.i), | |||
| ("i.bn", m.i.bn), | |||
| ] | |||
| assert list(m.named_modules(prefix="x")) == [ | |||
| ("x", m), | |||
| ("x.bn", m.bn), | |||
| ("x.i", m.i), | |||
| ("x.i.bn", m.i.bn), | |||
| ] | |||
| assert list(m.buffers()) == [ | |||
| m.bn.running_mean, | |||
| m.bn.running_var, | |||
| m.buff, | |||
| m.i.bn.running_mean, | |||
| m.i.bn.running_var, | |||
| ] | |||
| assert list(m.buffers(recursive=False)) == [m.buff] | |||
| assert list(m.named_buffers()) == [ | |||
| ("bn.running_mean", m.bn.running_mean), | |||
| ("bn.running_var", m.bn.running_var), | |||
| ("buff", m.buff), | |||
| ("i.bn.running_mean", m.i.bn.running_mean), | |||
| ("i.bn.running_var", m.i.bn.running_var), | |||
| ] | |||
| assert list(m.parameters()) == [ | |||
| m.bn.bias, | |||
| m.bn.weight, | |||
| m.i.bn.bias, | |||
| m.i.bn.weight, | |||
| m.param, | |||
| ] | |||
| assert list(m.named_parameters()) == [ | |||
| ("bn.bias", m.bn.bias), | |||
| ("bn.weight", m.bn.weight), | |||
| ("i.bn.bias", m.i.bn.bias), | |||
| ("i.bn.weight", m.i.bn.weight), | |||
| ("param", m.param), | |||
| ] | |||
| m.eval() | |||
| assert ( | |||
| m.training == False | |||
| and m.bn.training == False | |||
| and m.i.training == False | |||
| and m.i.bn.training == False | |||
| ) | |||
| m.bn.train() | |||
| assert m.training == False and m.bn.training == True and m.i.bn.training == False | |||
| m.eval() | |||
| m.i.train() | |||
| assert ( | |||
| m.training == False | |||
| and m.bn.training == False | |||
| and m.i.training == True | |||
| and m.i.bn.training == True | |||
| ) | |||
| m.eval() | |||
| m.train() | |||
| assert m.training == True and m.bn.training == True and m.i.bn.training == True | |||
| def fn(m): | |||
| m.training = False | |||
| m.apply(fn) | |||
| assert m.bn.training == False and m.i.bn.training == False | |||
| def test_module_api_reuse_submodule(): | |||
| m = MyModule() | |||
| m.h = m.i # pylint: disable=attribute-defined-outside-init | |||
| assert list(m.modules()) == [m, m.bn, m.i, m.i.bn] | |||
| assert list(m.named_modules()) == [ | |||
| ("", m), | |||
| ("bn", m.bn), | |||
| ("h", m.i), | |||
| ("h.bn", m.i.bn), | |||
| ] | |||
| def test_module_api_iterable_stability(): | |||
| m = MyModule() | |||
| l = list(m.modules()) | |||
| for _ in range(100): | |||
| assert list(m.modules()) == l | |||
| def test_module_api_hooks(): | |||
| net = MyModule() | |||
| pre_hook_num = 0 | |||
| post_hook_num = 0 | |||
| hooks = [] | |||
| def pre_hook(module, inputs): | |||
| nonlocal pre_hook_num | |||
| pre_hook_num += 1 | |||
| modified_inputs = tuple(inp + 1 for inp in inputs) | |||
| return modified_inputs | |||
| def post_hook(module, inputs, outputs): | |||
| nonlocal post_hook_num | |||
| post_hook_num += 1 | |||
| outputs += 1 | |||
| return outputs | |||
| net.apply(lambda module: hooks.append(module.register_forward_pre_hook(pre_hook))) | |||
| net.apply(lambda module: hooks.append(module.register_forward_hook(post_hook))) | |||
| shape = (1, 4, 1, 1) | |||
| x = tensor(np.zeros(shape, dtype=np.float32)) | |||
| y = net(x) | |||
| assert pre_hook_num == 4 | |||
| assert post_hook_num == 4 | |||
| mean1 = Parameter(np.zeros(shape), dtype=np.float32) | |||
| bn1 = F.batch_norm2d( | |||
| x + 3, mean1, Parameter(np.ones(shape), dtype=np.float32), training=True | |||
| ) | |||
| assertTensorClose( | |||
| net.i.bn.running_mean.numpy(), mean1.numpy(), | |||
| ) | |||
| mean2 = Parameter(np.zeros(shape), dtype=np.float32) | |||
| bn2 = F.batch_norm2d( | |||
| bn1 + 3, mean2, Parameter(np.ones(shape), dtype=np.float32), training=True | |||
| ) | |||
| assertTensorClose( | |||
| net.bn.running_mean.numpy(), mean2.numpy(), | |||
| ) | |||
| assertTensorClose((bn2 + 2).numpy(), y.numpy()) | |||
| assert len(hooks) == 8 | |||
| for handler in hooks: | |||
| handler.remove() | |||
| y = net(x) | |||
| assert pre_hook_num == 4 | |||
| assert post_hook_num == 4 | |||
| class MyModule2(Module): | |||
| class InnerModule(Module): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.bn = BatchNorm2d(4) | |||
| self.test_bool_key = {True: 1, False: 0} | |||
| def forward(self, x): | |||
| x = self.bn(x) | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.bn = BatchNorm2d(4) | |||
| self.a = [ | |||
| BatchNorm2d(4), | |||
| {"x": BatchNorm2d(4), "y": [BatchNorm2d(4), self.InnerModule()], "z": 0}, | |||
| (self.InnerModule(),), | |||
| ] | |||
| def forward(self, x): | |||
| return x | |||
| def test_expand_structure(): | |||
| m = MyModule2() | |||
| assert list(m.named_modules()) == [ | |||
| ("", m), | |||
| ("a.0", m.a[0]), | |||
| ("a.1.x", m.a[1]["x"]), | |||
| ("a.1.y.0", m.a[1]["y"][0]), | |||
| ("a.1.y.1", m.a[1]["y"][1]), | |||
| ("a.1.y.1.bn", m.a[1]["y"][1].bn), | |||
| ("a.2.0", m.a[2][0]), | |||
| ("a.2.0.bn", m.a[2][0].bn), | |||
| ("bn", m.bn), | |||
| ] | |||
| def test_flatten_others(): | |||
| def be_others(obj): | |||
| return not isinstance(obj, (Tensor, Module)) | |||
| m = MyModule2() | |||
| assert len(list(m._flatten(with_key=True, predicate=be_others))) == 0 | |||
| def test_flatten_with_parent(): | |||
| m = MyModule2() | |||
| assert list(m.named_modules(with_parent=True)) == [ | |||
| ("", m, None), | |||
| ("a.0", m.a[0], m), | |||
| ("a.1.x", m.a[1]["x"], m), | |||
| ("a.1.y.0", m.a[1]["y"][0], m), | |||
| ("a.1.y.1", m.a[1]["y"][1], m), | |||
| ("a.1.y.1.bn", m.a[1]["y"][1].bn, m.a[1]["y"][1]), | |||
| ("a.2.0", m.a[2][0], m), | |||
| ("a.2.0.bn", m.a[2][0].bn, m.a[2][0]), | |||
| ("bn", m.bn, m), | |||
| ] | |||
| assert list(m.modules(with_parent=True)) == [ | |||
| (m, None), | |||
| (m.a[0], m), | |||
| (m.a[1]["x"], m), | |||
| (m.a[1]["y"][0], m), | |||
| (m.a[1]["y"][1], m), | |||
| (m.a[1]["y"][1].bn, m.a[1]["y"][1]), | |||
| (m.a[2][0], m), | |||
| (m.a[2][0].bn, m.a[2][0]), | |||
| (m.bn, m), | |||
| ] | |||
| class MyModule3(Module): | |||
| class InnerModule(Module): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.bn = BatchNorm2d(4) | |||
| def forward(self, x): | |||
| x = self.bn(x) | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.bn = BatchNorm2d(4) | |||
| self.seq = Sequential(BatchNorm2d(4), self.InnerModule(),) | |||
| def forward(self, x): | |||
| return x | |||
| def test_module_api_with_sequential(): | |||
| m = MyModule3() | |||
| assert list(m.named_modules()) == [ | |||
| ("", m), | |||
| ("bn", m.bn), | |||
| ("seq", m.seq), | |||
| ("seq.0", m.seq[0]), | |||
| ("seq.1", m.seq[1]), | |||
| ("seq.1.bn", m.seq[1].bn), | |||
| ] | |||
| def test_sequential_named_children(): | |||
| modules = OrderedDict() | |||
| modules["name0"] = Linear(20, 10) | |||
| modules["name1"] = Linear(10, 5) | |||
| modules["name2"] = Linear(5, 1) | |||
| m = Sequential(modules) | |||
| l = list(m.named_children()) | |||
| assert l[0][0] == "layer_values.0" | |||
| assert l[1][0] == "layer_values.1" | |||
| assert l[2][0] == "layer_values.2" | |||
| def test_state_dict(): | |||
| data_shape = (2, 28) | |||
| data = tensor([]) | |||
| data.set_value(np.random.random(data_shape)) | |||
| mlp = MLP() | |||
| pred0 = mlp(data) | |||
| with BytesIO() as fout: | |||
| mge.save(mlp.state_dict(), fout) | |||
| fout.seek(0) | |||
| state_dict = mge.load(fout) | |||
| state_dict["extra"] = None | |||
| mlp1 = MLP() | |||
| mlp1.load_state_dict(state_dict, strict=False) | |||
| pred1 = mlp1(data) | |||
| assertTensorClose(pred0.numpy(), pred1.numpy(), max_err=5e-6) | |||
| with pytest.raises(KeyError): | |||
| mlp1.load_state_dict(state_dict) | |||
| del state_dict["extra"] | |||
| del state_dict["dense0.bias"] | |||
| with pytest.raises(KeyError): | |||
| mlp1.load_state_dict(state_dict) | |||
| class AssertModule(Module): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.error_tensor_key = {True: tensor([]), False: 0} | |||
| def forward(self, x): | |||
| return x | |||
| def test_assert_message(): | |||
| m = AssertModule() | |||
| with pytest.raises( | |||
| AssertionError, match="keys for Tensor and Module must be str, error key: True" | |||
| ): | |||
| list(m._flatten()) | |||
| class Simple(Module): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.conv0 = Conv2d(1, 1, kernel_size=3, bias=False) | |||
| self.conv1 = Conv2d(1, 1, kernel_size=3, bias=False) | |||
| self.conv1.weight = self.conv0.weight | |||
| def forward(self, inputs): | |||
| pass | |||
| def test_shared_param(): | |||
| net = Simple() | |||
| assert net.conv0.weight is net.conv1.weight | |||
| data = tensor(np.random.random((1, 1, 8, 8)).astype(np.float32)) | |||
| assertTensorClose(net.conv0(data).numpy(), net.conv1(data).numpy()) | |||
| with BytesIO() as f: | |||
| mge.save(net, f) | |||
| f.seek(0) | |||
| net1 = mge.load(f) | |||
| assert net1.conv0.weight is net1.conv1.weight | |||
| assertTensorClose(net1.conv0(data).numpy(), net1.conv1(data).numpy()) | |||
| with BytesIO() as f: | |||
| mge.save(net.conv0, f) | |||
| f.seek(0) | |||
| conv0 = mge.load(f) | |||
| with BytesIO() as f: | |||
| mge.save(net.conv1, f) | |||
| f.seek(0) | |||
| conv1 = mge.load(f) | |||
| assert conv0.weight is not conv1.weight | |||
| assertTensorClose(conv0(data).numpy(), conv1(data).numpy()) | |||
| def test_pickle_module(): | |||
| data_shape = (2, 28) | |||
| data = tensor([]) | |||
| data.set_value(np.random.random(data_shape)) | |||
| mlp = MLP() | |||
| # pickle before forward | |||
| with BytesIO() as fout: | |||
| mge.save(mlp, fout) | |||
| fout.seek(0) | |||
| mlp1 = mge.load(fout) | |||
| pred0 = mlp1(data) | |||
| pred1 = mlp(data) | |||
| # pickle after forward | |||
| with BytesIO() as fout: | |||
| mge.save(mlp, fout) | |||
| fout.seek(0) | |||
| mlp1 = mge.load(fout) | |||
| pred2 = mlp1(data) | |||
| assertTensorClose(pred0.numpy(), pred1.numpy(), max_err=5e-6) | |||
| assertTensorClose(pred0.numpy(), pred2.numpy(), max_err=5e-6) | |||
| @pytest.mark.skip(reason="under development") | |||
| def test_dump_model(): | |||
| data_shape = (2, 28) | |||
| data = tensor([]) | |||
| data.set_value(np.random.random(data_shape)) | |||
| mlp = MLP() | |||
| pred = mlp(data) | |||
| f = tempfile.NamedTemporaryFile(delete=False) | |||
| f_name = f.name | |||
| try: | |||
| mge.dump(pred, f_name) | |||
| finally: | |||
| f.close() | |||
| os.unlink(f_name) | |||
| def test_load_quantized(): | |||
| from megengine.core.tensor import dtype | |||
| data_shape = (2, 28) | |||
| data = tensor(np.random.random(data_shape), dtype="float32") | |||
| data = data.astype(dtype.qint8(0.1)) | |||
| mlp = MLP() | |||
| quantize_qat(mlp) | |||
| quantize(mlp) | |||
| mlp.dense0.weight = Parameter(mlp.dense0.weight.astype(dtype.qint8(0.001)).numpy()) | |||
| mlp.dense1.weight = Parameter(mlp.dense1.weight.astype(dtype.qint8(0.0002)).numpy()) | |||
| mlp.eval() | |||
| pred0 = mlp(data) | |||
| with BytesIO() as fout: | |||
| mge.save(mlp.state_dict(), fout) | |||
| fout.seek(0) | |||
| checkpoint = mge.load(fout) | |||
| # change mlp weight. | |||
| mlp.dense0.weight = Parameter( | |||
| mlp.dense0.weight.astype(dtype.qint8(0.00001)).numpy() | |||
| ) | |||
| mlp.dense1.weight = Parameter( | |||
| mlp.dense1.weight.astype(dtype.qint8(0.2)).numpy() | |||
| ) | |||
| mlp.load_state_dict(checkpoint) | |||
| pred1 = mlp(data) | |||
| assertTensorClose( | |||
| pred0.astype("float32").numpy(), pred1.astype("float32").numpy(), max_err=5e-6 | |||
| ) | |||
| @@ -0,0 +1,91 @@ | |||
| from itertools import product | |||
| import numpy as np | |||
| from megengine import tensor | |||
| from megengine.module import ( | |||
| Conv2d, | |||
| ConvBn2d, | |||
| ConvRelu2d, | |||
| DequantStub, | |||
| Module, | |||
| QuantStub, | |||
| ) | |||
| from megengine.quantization.quantize import disable_fake_quant, quantize_qat | |||
| from megengine.test import assertTensorClose | |||
| def test_qat_convbn2d(): | |||
| in_channels = 32 | |||
| out_channels = 64 | |||
| kernel_size = 3 | |||
| for groups, bias in product([1, 4], [True, False]): | |||
| module = ConvBn2d( | |||
| in_channels, out_channels, kernel_size, groups=groups, bias=bias | |||
| ) | |||
| module.train() | |||
| qat_module = quantize_qat(module, inplace=False) | |||
| disable_fake_quant(qat_module) | |||
| inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32)) | |||
| normal_outputs = module(inputs) | |||
| qat_outputs = qat_module(inputs) | |||
| assertTensorClose(normal_outputs.numpy(), qat_outputs.numpy(), max_err=5e-6) | |||
| assertTensorClose( | |||
| module.bn.running_mean.numpy(), | |||
| qat_module.bn.running_mean.numpy(), | |||
| max_err=5e-8, | |||
| ) | |||
| assertTensorClose( | |||
| module.bn.running_var.numpy(), | |||
| qat_module.bn.running_var.numpy(), | |||
| max_err=5e-7, | |||
| ) | |||
| module.eval() | |||
| normal_outputs = module(inputs) | |||
| qat_module.eval() | |||
| qat_outputs = qat_module(inputs) | |||
| assertTensorClose(normal_outputs.numpy(), qat_outputs.numpy(), max_err=5e-6) | |||
| def test_qat_conv(): | |||
| in_channels = 32 | |||
| out_channels = 64 | |||
| kernel_size = 3 | |||
| class TestNet(Module): | |||
| def __init__(self, groups, bias): | |||
| super().__init__() | |||
| self.quant = QuantStub() | |||
| self.dequant = DequantStub() | |||
| self.conv = Conv2d( | |||
| in_channels, out_channels, kernel_size, groups=groups, bias=bias | |||
| ) | |||
| self.conv_relu = ConvRelu2d( | |||
| out_channels, in_channels, kernel_size, groups=groups, bias=bias | |||
| ) | |||
| def forward(self, inp): | |||
| out = self.quant(inp) | |||
| out = self.conv(out) | |||
| out = self.conv_relu(out) | |||
| out = self.dequant(out) | |||
| return out | |||
| inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32)) | |||
| for groups, bias in product([1, 4], [True, False]): | |||
| net = TestNet(groups, bias) | |||
| net.train() | |||
| qat_net = quantize_qat(net, inplace=False) | |||
| disable_fake_quant(qat_net) | |||
| normal_outputs = net(inputs) | |||
| qat_outputs = qat_net(inputs) | |||
| assertTensorClose(normal_outputs.numpy(), qat_outputs.numpy()) | |||
| net.eval() | |||
| normal_outputs = net(inputs) | |||
| qat_net.eval() | |||
| qat_outputs = qat_net(inputs) | |||
| assertTensorClose(normal_outputs.numpy(), qat_outputs.numpy()) | |||
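Both tests follow the same QAT recipe; restated below as a minimal sketch using only calls that appear above. With fake-quant disabled, the QAT module is expected to track the float module closely.

```python
import numpy as np
from megengine import tensor
from megengine.module import ConvBn2d
from megengine.quantization.quantize import disable_fake_quant, quantize_qat

float_module = ConvBn2d(8, 16, 3)
float_module.train()
qat_module = quantize_qat(float_module, inplace=False)  # keep the float copy intact
disable_fake_quant(qat_module)  # numerics should now match the float module

x = tensor(np.random.randn(2, 8, 32, 32).astype(np.float32))
y_float, y_qat = float_module(x), qat_module(x)
```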
| @@ -0,0 +1,89 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import copy | |||
| import numpy as np | |||
| import pytest | |||
| import megengine as mge | |||
| import megengine.functional as F | |||
| from megengine import Buffer, Parameter | |||
| from megengine.module import Conv2d | |||
| from megengine.test import assertTensorClose | |||
| def test_set_value(): | |||
| v0 = np.random.random((2, 3)).astype(np.float32) | |||
| param = Parameter(v0) | |||
| v1 = np.random.random((2, 3)).astype(np.float32) | |||
| param.set_value(v1) | |||
| assertTensorClose(param.numpy(), v1, max_err=5e-6) | |||
| v2 = np.random.random((3, 3)).astype(np.float32) | |||
| # TODO: add this | |||
| # with pytest.raises(ValueError): | |||
| # param.set_value(v2) | |||
| assertTensorClose(param.numpy(), v1, max_err=5e-6) | |||
| @pytest.mark.skip(reason="fill unsupported") | |||
| def test_fill(): | |||
| a = Buffer(np.zeros((2, 3), dtype=np.float32)) | |||
| a.fill(3) | |||
| assertTensorClose(a.numpy(), np.full((2, 3), 3, dtype=np.float32)) | |||
| a.fill(124.568) | |||
| assertTensorClose(a.numpy(), np.full((2, 3), 124.568, dtype=np.float32)) | |||
| # TODO: remove or rewrite following test | |||
| # def test_attach(): | |||
| # p_ = np.random.random((2, 3)).astype(np.float32) | |||
| # with Graph() as g: | |||
| # g.set_option('eager_evaluation', False) | |||
| # p = Parameter(p_) | |||
| # v = p * 2 | |||
| # f = compile(v, None) | |||
| # out, = f() | |||
| # assertTensorClose(out, p_ * 2) | |||
| # F.add_update(p, p) | |||
| # out, = f() | |||
| # assertTensorClose(out, p_ * 4) | |||
| # TODO: remove or rewrite following test | |||
| # def test_module_attach(): | |||
| # v = np.random.random((1, 3, 64, 64)).astype(np.float32) | |||
| # net = Conv2d(3, 16, 3) | |||
| # with Graph() as g: | |||
| # g.set_option('eager_evaluation', False) | |||
| # data0 = Input("data") | |||
| # f = compile(net(data0), None) | |||
| # out0, = f(data=v) | |||
| # data1 = Input("data", value=v) | |||
| # out1 = net(data1) | |||
| # assertTensorClose(out0, out1.numpy()) | |||
| # def test_shape_warning(): | |||
| # with Graph() as cg: | |||
| # cg.set_option("eager_evaluation", False) | |||
| # b = Buffer(np.ones((2, 3)).astype(np.float32)) | |||
| # with pytest.warns(None) as record: | |||
| # print(b.shape) | |||
| # if len(record) != 0: | |||
| # raise ValueError( | |||
| # "Getting the shape of a constant Tensor should throw no Warning" | |||
| # ) | |||
| @@ -1,112 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import platform | |||
| import pytest | |||
| import megengine as mge | |||
| import megengine.distributed as dist | |||
| from megengine import tensor | |||
| from megengine.distributed.group import Group | |||
| from megengine.distributed.helper import get_device_count_by_fork | |||
| from megengine.module import SyncBatchNorm | |||
| from megengine.test import assertTensorClose | |||
| @pytest.mark.skipif( | |||
| platform.system() == "Darwin", reason="do not imp GPU mode at macos now" | |||
| ) | |||
| @pytest.mark.skipif( | |||
| platform.system() == "Windows", reason="do not imp GPU mode at Windows now" | |||
| ) | |||
| @pytest.mark.skipif(get_device_count_by_fork("gpu") < 4, reason="need more gpu device") | |||
| @pytest.mark.isolated_distributed | |||
| def test_syncbn(): | |||
| import numpy as np | |||
| import multiprocessing as mp | |||
| from megengine.distributed.group import Server | |||
| from megengine.core._trace_option import use_tensor_shape | |||
| if use_tensor_shape(): # XXX: fix sync bn if use_tensor_shape | |||
| return | |||
| nr_chan = 8 | |||
| nr_ranks = 4 | |||
| data_shape = (3, nr_chan, 4, nr_ranks * 8) | |||
| momentum = 0.9 | |||
| eps = 1e-5 | |||
| running_mean = np.zeros((1, nr_chan, 1, 1), dtype=np.float32) | |||
| running_var = np.ones((1, nr_chan, 1, 1), dtype=np.float32) | |||
| steps = 4 | |||
| server = Server(0) | |||
| port = server.py_server_port | |||
| def worker(rank, data, yv_expect, running_mean, running_var): | |||
| dist.init_process_group("localhost", port, nr_ranks, rank, rank) | |||
| group = Group([i for i in range(nr_ranks)]) | |||
| bn = SyncBatchNorm(nr_chan, eps=eps, momentum=momentum, group=group) | |||
| data_tensor = None | |||
| for i in range(steps): | |||
| if data_tensor is None: | |||
| data_tensor = tensor(data[i], device=f"gpu{rank}:0") | |||
| else: | |||
| data_tensor.set_value(data[i]) | |||
| yv = bn(data_tensor) | |||
| assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6) | |||
| assertTensorClose(running_mean, bn.running_mean.numpy(), max_err=5e-6) | |||
| assertTensorClose(running_var, bn.running_var.numpy(), max_err=5e-6) | |||
| xv = [] | |||
| for i in range(steps): | |||
| xv.append(np.random.normal(loc=2.3, size=data_shape).astype(np.float32)) | |||
| xv_transposed = np.transpose(xv[i], [0, 2, 3, 1]).reshape( | |||
| (data_shape[0] * data_shape[2] * data_shape[3], nr_chan) | |||
| ) | |||
| mean = np.mean(xv_transposed, axis=0).reshape(1, nr_chan, 1, 1) | |||
| var_biased = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1, 1)) | |||
| sd = np.sqrt(var_biased + eps) | |||
| var_unbiased = np.var(xv_transposed, axis=0, ddof=1).reshape((1, nr_chan, 1, 1)) | |||
| running_mean = running_mean * momentum + mean * (1 - momentum) | |||
| running_var = running_var * momentum + var_unbiased * (1 - momentum) | |||
| yv_expect = (xv[i] - mean) / sd | |||
| data = [] | |||
| for i in range(nr_ranks): | |||
| data.append([]) | |||
| for j in range(steps): | |||
| data[i].append(xv[j][:, :, :, i * 8 : i * 8 + 8]) | |||
| procs = [] | |||
| for rank in range(nr_ranks): | |||
| p = mp.Process( | |||
| target=worker, | |||
| args=( | |||
| rank, | |||
| data[rank], | |||
| yv_expect[:, :, :, rank * 8 : rank * 8 + 8], | |||
| running_mean, | |||
| running_var, | |||
| ), | |||
| ) | |||
| p.start() | |||
| procs.append(p) | |||
| for p in procs: | |||
| p.join(10) | |||
| assert p.exitcode == 0 | |||
| def test_module_conv2d(): | |||
| from megengine.module.conv import Conv2d | |||
| conv = Conv2d(2, 3, 1) | |||