# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
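"""Distributed tests: gradients should be all-reduced correctly when parameters
are packed by the all-reduce callback created with dist.make_allreduce_cb."""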
import multiprocessing as mp
import platform

import numpy as np
import pytest

import megengine
import megengine.autodiff as ad
import megengine.distributed as dist
import megengine.optimizer as optimizer
from megengine import Parameter, tensor
from megengine.distributed.helper import get_device_count_by_fork
from megengine.module import Module
from megengine.optimizer import SGD


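# A tiny module whose forward pass multiplies the input by each of its ten
# scalar parameters; used below to exercise gradient all-reduce with
# parameter packing.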
class Simple(Module):
    def __init__(self):
        super().__init__()
        self.params = [Parameter(1.0, dtype=np.float32) for i in range(10)]

    def forward(self, x):
        for p in self.params:
            x = x * p
        return x


@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.isolated_distributed
@pytest.mark.skipif(
    platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
)
def test_param_pack():
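    """Run a distributed step with the default parameter-packing all-reduce
    callback and check that every parameter's gradient comes out as 1.

    With all parameters initialised to 1.0 and x == 1, each gradient is the
    product of the remaining parameters (i.e. 1) on every worker, so the
    "MEAN" all-reduce should leave it unchanged.
    """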
    data = np.ones([1], dtype="float32")

    @dist.launcher
    def worker():
        net = Simple()
        opt = SGD(net.parameters(), lr=0.1)
        gm = ad.GradManager().attach(
            net.parameters(), callbacks=[dist.make_allreduce_cb("MEAN", dist.WORLD)]
        )
        opt.clear_grad()
        with gm:
            x = tensor(data)
            loss = net(x)
            loss = loss.sum()
            gm.backward(loss)
        for p in net.params:
            np.testing.assert_equal(p.grad.numpy(), 1)

    worker()


@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.isolated_distributed
@pytest.mark.skipif(
    platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
)
def test_param_pack_with_no_param():
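    """Same check as test_param_pack, but with the callback's packing threshold
    (_param_pack_thd) forced to 0, presumably so gradients are all-reduced
    without being grouped into packs; each gradient should still equal 1.
    """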
    data = np.ones([1], dtype="float32")

    @dist.launcher
    def worker():
        net = Simple()
        opt = SGD(net.parameters(), lr=0.1)
        allreduce_cb = dist.make_allreduce_cb("MEAN", dist.WORLD)
        allreduce_cb._param_pack_thd = 0
        gm = ad.GradManager().attach(net.parameters(), callbacks=[allreduce_cb])
        opt.clear_grad()
        with gm:
            x = tensor(data)
            loss = net(x)
            loss = loss.sum()
            gm.backward(loss)
        for p in net.params:
            np.testing.assert_equal(p.grad.numpy(), 1)

    worker()