| @@ -1,75 +1,28 @@ | |||||
| import unittest | |||||
| import torch | |||||
| import os | |||||
| import pytest | |||||
| os.environ["FASTNLP_BACKEND"] = "paddle" | |||||
| from fastNLP.core.drivers.paddle_driver.paddle_driver import PaddleDriver | from fastNLP.core.drivers.paddle_driver.paddle_driver import PaddleDriver | ||||
| import paddle | |||||
| from paddle.io import Dataset, DataLoader | |||||
| class Net(paddle.nn.Layer): | |||||
| def __init__(self): | |||||
| super(Net, self).__init__() | |||||
| self.fc1 = paddle.nn.Linear(784, 64) | |||||
| self.fc2 = paddle.nn.Linear(64, 32) | |||||
| self.fc3 = paddle.nn.Linear(32, 10) | |||||
| self.fc4 = paddle.nn.Linear(10, 10) | |||||
| def forward(self, x): | |||||
| x = self.fc1(x) | |||||
| x = self.fc2(x) | |||||
| x = self.fc3(x) | |||||
| x = self.fc4(x) | |||||
| return x | |||||
| class PaddleDataset(Dataset): | |||||
| def __init__(self): | |||||
| super(PaddleDataset, self).__init__() | |||||
| self.items = [paddle.rand((3, 4)) for i in range(320)] | |||||
| def __len__(self): | |||||
| return len(self.items) | |||||
| def __getitem__(self, idx): | |||||
| return self.items[idx] | |||||
| class TorchNet(torch.nn.Module): | |||||
| def __init__(self): | |||||
| super(TorchNet, self).__init__() | |||||
| self.torch_fc1 = torch.nn.Linear(10, 10) | |||||
| self.torch_softmax = torch.nn.Softmax(0) | |||||
| self.torch_conv2d1 = torch.nn.Conv2d(10, 10, 3) | |||||
| self.torch_tensor = torch.ones(3, 3) | |||||
| self.torch_param = torch.nn.Parameter(torch.ones(4, 4)) | |||||
| class TorchDataset(torch.utils.data.Dataset): | |||||
| def __init__(self): | |||||
| super(TorchDataset, self).__init__() | |||||
| self.items = [torch.ones(3, 4) for i in range(320)] | |||||
| def __len__(self): | |||||
| return len(self.items) | |||||
| def __getitem__(self, idx): | |||||
| return self.items[idx] | |||||
| from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | |||||
| from tests.helpers.datasets.paddle_data import PaddleNormalDataset | |||||
| from tests.helpers.datasets.torch_data import TorchNormalDataset | |||||
| from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
| import torch | |||||
| import paddle | |||||
| from paddle.io import DataLoader | |||||
| class PaddleDriverTestCase(unittest.TestCase): | |||||
| class TestPaddleDriverFunctions: | |||||
| """ | """ | ||||
| PaddleDriver的测试类,由于类的特殊性仅测试部分函数,其它的由PaddleSingleDriver和PaddleFleetDriver完成测试 | |||||
| PaddleDriver的测试类,使用仅测试部分函数,其它的由PaddleSingleDriver和PaddleFleetDriver完成测试 | |||||
| """ | """ | ||||
| def setUp(self): | |||||
| model = Net() | |||||
| @classmethod | |||||
| def setup_class(self): | |||||
| model = PaddleNormalModel_Classification_1(10, 32) | |||||
| self.driver = PaddleDriver(model) | self.driver = PaddleDriver(model) | ||||
| def test_check_single_optimizer_legacy(self): | |||||
| def test_check_single_optimizer_legality(self): | |||||
| """ | """ | ||||
| 测试传入单个optimizer时的表现 | 测试传入单个optimizer时的表现 | ||||
| """ | """ | ||||
| @@ -80,12 +33,12 @@ class PaddleDriverTestCase(unittest.TestCase): | |||||
| self.driver.set_optimizers(optimizer) | self.driver.set_optimizers(optimizer) | ||||
| optimizer = torch.optim.Adam(TorchNet().parameters(), 0.01) | |||||
| optimizer = torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01) | |||||
| # 传入torch的optimizer时,应该报错ValueError | # 传入torch的optimizer时,应该报错ValueError | ||||
| with self.assertRaises(ValueError) as cm: | with self.assertRaises(ValueError) as cm: | ||||
| self.driver.set_optimizers(optimizer) | self.driver.set_optimizers(optimizer) | ||||
| def test_check_optimizers_legacy(self): | |||||
| def test_check_optimizers_legality(self): | |||||
| """ | """ | ||||
| 测试传入optimizer list的表现 | 测试传入optimizer list的表现 | ||||
| """ | """ | ||||
| @@ -99,22 +52,27 @@ class PaddleDriverTestCase(unittest.TestCase): | |||||
| self.driver.set_optimizers(optimizers) | self.driver.set_optimizers(optimizers) | ||||
| optimizers += [ | optimizers += [ | ||||
| torch.optim.Adam(TorchNet().parameters(), 0.01) | |||||
| torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01) | |||||
| ] | ] | ||||
| with self.assertRaises(ValueError) as cm: | with self.assertRaises(ValueError) as cm: | ||||
| self.driver.set_optimizers(optimizers) | self.driver.set_optimizers(optimizers) | ||||
| def test_check_dataloader_legacy_in_train(self): | |||||
| def test_check_dataloader_legality_in_train(self): | |||||
| """ | """ | ||||
| 测试is_train参数为True时,_check_dataloader_legality函数的表现 | 测试is_train参数为True时,_check_dataloader_legality函数的表现 | ||||
| """ | """ | ||||
| dataloader = paddle.io.DataLoader(PaddleDataset()) | |||||
| dataloader = paddle.io.DataLoader(PaddleNormalDataset()) | |||||
| PaddleDriver._check_dataloader_legality(dataloader, "dataloader", True) | PaddleDriver._check_dataloader_legality(dataloader, "dataloader", True) | ||||
| # batch_size 和 batch_sampler 均为 None 的情形 | |||||
| dataloader = paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None) | |||||
| with self.assertRaises(ValueError) as cm: | |||||
| PaddleDriver._check_dataloader_legality(dataloader, "dataloader", True) | |||||
| # 创建torch的dataloader | # 创建torch的dataloader | ||||
| dataloader = torch.utils.data.DataLoader( | dataloader = torch.utils.data.DataLoader( | ||||
| TorchDataset(), | |||||
| TorchNormalDataset(), | |||||
| batch_size=32, shuffle=True | batch_size=32, shuffle=True | ||||
| ) | ) | ||||
| with self.assertRaises(ValueError) as cm: | with self.assertRaises(ValueError) as cm: | ||||
| @@ -125,21 +83,31 @@ class PaddleDriverTestCase(unittest.TestCase): | |||||
| 测试is_train参数为False时,_check_dataloader_legality函数的表现 | 测试is_train参数为False时,_check_dataloader_legality函数的表现 | ||||
| """ | """ | ||||
| # 此时传入的应该是dict | # 此时传入的应该是dict | ||||
| dataloader = {"train": paddle.io.DataLoader(PaddleDataset()), "test":paddle.io.DataLoader(PaddleDataset())} | |||||
| dataloader = { | |||||
| "train": paddle.io.DataLoader(PaddleNormalDataset()), | |||||
| "test":paddle.io.DataLoader(PaddleNormalDataset()) | |||||
| } | |||||
| PaddleDriver._check_dataloader_legality(dataloader, "dataloader", False) | |||||
| # batch_size 和 batch_sampler 均为 None 的情形 | |||||
| dataloader = { | |||||
| "train": paddle.io.DataLoader(PaddleNormalDataset()), | |||||
| "test":paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None) | |||||
| } | |||||
| PaddleDriver._check_dataloader_legality(dataloader, "dataloader", False) | PaddleDriver._check_dataloader_legality(dataloader, "dataloader", False) | ||||
| # 传入的不是dict,应该报错 | # 传入的不是dict,应该报错 | ||||
| dataloader = paddle.io.DataLoader(PaddleDataset()) | |||||
| dataloader = paddle.io.DataLoader(PaddleNormalDataset()) | |||||
| with self.assertRaises(ValueError) as cm: | with self.assertRaises(ValueError) as cm: | ||||
| PaddleDriver._check_dataloader_legality(dataloader, "dataloader", False) | PaddleDriver._check_dataloader_legality(dataloader, "dataloader", False) | ||||
| # 创建torch的dataloader | # 创建torch的dataloader | ||||
| train_loader = torch.utils.data.DataLoader( | train_loader = torch.utils.data.DataLoader( | ||||
| TorchDataset(), | |||||
| TorchNormalDataset(), | |||||
| batch_size=32, shuffle=True | batch_size=32, shuffle=True | ||||
| ) | ) | ||||
| test_loader = torch.utils.data.DataLoader( | test_loader = torch.utils.data.DataLoader( | ||||
| TorchDataset(), | |||||
| TorchNormalDataset(), | |||||
| batch_size=32, shuffle=True | batch_size=32, shuffle=True | ||||
| ) | ) | ||||
| dataloader = {"train": train_loader, "test": test_loader} | dataloader = {"train": train_loader, "test": test_loader} | ||||
| @@ -240,7 +208,7 @@ class PaddleDriverTestCase(unittest.TestCase): | |||||
| """ | """ | ||||
| # 先确保不影响运行 | # 先确保不影响运行 | ||||
| # TODO:正确性 | # TODO:正确性 | ||||
| dataloader = DataLoader(PaddleDataset()) | |||||
| dataloader = DataLoader(PaddleNormalDataset()) | |||||
| self.driver.set_deterministic_dataloader(dataloader) | self.driver.set_deterministic_dataloader(dataloader) | ||||
| def test_set_sampler_epoch(self): | def test_set_sampler_epoch(self): | ||||
| @@ -249,7 +217,7 @@ class PaddleDriverTestCase(unittest.TestCase): | |||||
| """ | """ | ||||
| # 先确保不影响运行 | # 先确保不影响运行 | ||||
| # TODO:正确性 | # TODO:正确性 | ||||
| dataloader = DataLoader(PaddleDataset()) | |||||
| dataloader = DataLoader(PaddleNormalDataset()) | |||||
| self.driver.set_sampler_epoch(dataloader, 0) | self.driver.set_sampler_epoch(dataloader, 0) | ||||
| def test_get_dataloader_args(self): | def test_get_dataloader_args(self): | ||||
| @@ -258,5 +226,5 @@ class PaddleDriverTestCase(unittest.TestCase): | |||||
| """ | """ | ||||
| # 先确保不影响运行 | # 先确保不影响运行 | ||||
| # TODO:正确性 | # TODO:正确性 | ||||
| dataloader = DataLoader(PaddleDataset()) | |||||
| dataloader = DataLoader(PaddleNormalDataset()) | |||||
| res = PaddleDriver.get_dataloader_args(dataloader) | res = PaddleDriver.get_dataloader_args(dataloader) | ||||