| @@ -1,3 +1,5 @@ | |||||
| import multiprocessing as mp | |||||
| import numpy as np | import numpy as np | ||||
| import pytest | import pytest | ||||
| @@ -88,8 +90,7 @@ class ResNet(M.Module): | |||||
| return out | return out | ||||
| @pytest.mark.require_ngpu(1) | |||||
| def test_dtr_resnet1202(): | |||||
| def run_dtr_resnet1202(): | |||||
| batch_size = 8 | batch_size = 8 | ||||
| resnet1202 = ResNet(BasicBlock, [200, 200, 200]) | resnet1202 = ResNet(BasicBlock, [200, 200, 200]) | ||||
| opt = optim.SGD(resnet1202.parameters(), lr=0.05, momentum=0.9, weight_decay=1e-4) | opt = optim.SGD(resnet1202.parameters(), lr=0.05, momentum=0.9, weight_decay=1e-4) | ||||
| @@ -124,3 +125,13 @@ def test_dtr_resnet1202(): | |||||
| t.numpy() | t.numpy() | ||||
| mge.dtr.disable() | mge.dtr.disable() | ||||
| mge._exit(0) | |||||
| @pytest.mark.require_ngpu(1) | |||||
| @pytest.mark.isolated_distributed | |||||
| def test_dtr_resnet1202(): | |||||
| p = mp.Process(target=run_dtr_resnet1202) | |||||
| p.start() | |||||
| p.join() | |||||
| assert p.exitcode == 0 | |||||