| @@ -1,21 +1,35 @@ | |||||
| from dataclasses import replace | |||||
| import pytest | import pytest | ||||
| import os | import os | ||||
| import numpy as np | |||||
| from fastNLP.envs.set_env_on_import import set_env_on_import_paddle | |||||
| set_env_on_import_paddle() | |||||
| os.environ["FASTNLP_BACKEND"] = "paddle" | |||||
| from fastNLP.core.drivers.paddle_driver.fleet import PaddleFleetDriver | |||||
| from fastNLP.core.samplers import ( | |||||
| RandomSampler, | |||||
| UnrepeatedSampler, | |||||
| BucketedBatchSampler, | |||||
| UnrepeatedRandomSampler, | |||||
| UnrepeatedSequentialSampler, | |||||
| ) | |||||
| from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | |||||
| from tests.helpers.datasets.paddle_data import PaddleNormalDataset | |||||
| from tests.helpers.utils import magic_argv_env_context | |||||
| import paddle | import paddle | ||||
| import paddle.distributed as dist | import paddle.distributed as dist | ||||
| from paddle.io import DataLoader | |||||
| from paddle.io import DataLoader, BatchSampler | |||||
| from fastNLP.core.drivers.paddle_driver.fleet import PaddleFleetDriver | |||||
| from fastNLP.core.samplers.reproducible_sampler import RandomSampler | |||||
| from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK | |||||
| from tests.helpers.models.paddle_model import PaddleNormalModel_Classification | |||||
| from tests.helpers.datasets.paddle_data import PaddleDataset_MNIST | |||||
| from tests.helpers.utils import magic_argv_env_context | |||||
| from fastNLP.core import synchronize_safe_rm | |||||
def generate_driver(num_labels, feature_dimension):
    """
    Build a ``PaddleFleetDriver`` for the tests in this module and run its ``setup()``.

    :param num_labels: first argument forwarded to ``PaddleNormalModel_Classification_1``
    :param feature_dimension: second argument forwarded to ``PaddleNormalModel_Classification_1``
    :return: the driver, with an Adam optimizer attached and ``setup()`` already called
    """
    model = PaddleNormalModel_Classification_1(num_labels, feature_dimension)
    optimizer = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=0.01)

    # The driver is always created for devices 0 and 1.
    fleet_driver = PaddleFleetDriver(model=model, parallel_device=[0, 1])
    fleet_driver.set_optimizers(optimizer)
    fleet_driver.setup()
    return fleet_driver
| ############################################################################ | ############################################################################ | ||||
| # | # | ||||
| @@ -23,269 +37,340 @@ from fastNLP.core import synchronize_safe_rm | |||||
| # | # | ||||
| ############################################################################ | ############################################################################ | ||||
@magic_argv_env_context
def test_move_data_to_device():
    """
    Smoke-test ``PaddleFleetDriver.move_data_to_device`` with a random (32, 64) tensor.

    The original note states that this method only calls ``paddle_move_data_to_device``,
    whose test cases live in tests/core/utils/test_paddle_utils.py, so they are not
    repeated here.
    """
    try:
        paddle_model = PaddleNormalModel_Classification(10, 784)
        paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
        driver = PaddleFleetDriver(
            model=paddle_model,
            parallel_device=[0,1],
        )
        driver.set_optimizers(paddle_opt)
        # Distinguish the launching process from a spawned child process during setup
        if FASTNLP_DISTRIBUTED_CHECK not in os.environ:
            # Without the marker variable, setup() is expected to exit with code 0
            with pytest.raises(SystemExit) as e:
                driver.setup()
            assert e.value.code == 0
            return
        else:
            driver.setup()
        driver.move_data_to_device(paddle.rand((32, 64)))
    finally:
        # Always remove the "log" directory
        synchronize_safe_rm("log")
    # NOTE(review): SOURCE has lost its indentation; the barrier is assumed to follow the
    # try/finally rather than sit inside the finally clause -- confirm.
    dist.barrier()
@magic_argv_env_context
def test_is_distributed():
    """
    Check that ``PaddleFleetDriver.is_distributed()`` returns True after ``setup()``.
    """
    # Debug output: visible CUDA devices and the current paddle device
    print(os.getenv("CUDA_VISIBLE_DEVICES"))
    print(paddle.device.get_device())
    try:
        paddle_model = PaddleNormalModel_Classification(10, 784)
        paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
        driver = PaddleFleetDriver(
            model=paddle_model,
            parallel_device=[0,1],
            output_from_new_proc='all'
        )
        driver.set_optimizers(paddle_opt)
        # Distinguish the launching process from a spawned child process during setup
        if FASTNLP_DISTRIBUTED_CHECK not in os.environ:
            # Without the marker variable, setup() is expected to exit with code 0
            with pytest.raises(SystemExit) as e:
                driver.setup()
            assert e.value.code == 0
            return
        else:
            driver.setup()
        assert driver.is_distributed() == True
    finally:
        # Always remove the "log" directory
        synchronize_safe_rm("log")
    # NOTE(review): SOURCE has lost its indentation; the barrier is assumed to follow the
    # try/finally rather than sit inside the finally clause -- confirm.
    dist.barrier()
| @magic_argv_env_context | |||||
| def test_get_no_sync_context(): | |||||
| class TestFleetDriverFunction: | |||||
| """ | """ | ||||
| 测试能否运行 | |||||
| 测试 PaddleFleetDriver 一些简单函数的测试类,基本都是测试能否运行、是否存在 import 错误等问题 | |||||
| """ | """ | ||||
| try: | |||||
| paddle_model = PaddleNormalModel_Classification(10, 784) | |||||
| paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01) | |||||
| driver = PaddleFleetDriver( | |||||
| model=paddle_model, | |||||
| parallel_device=[0,1], | |||||
| ) | |||||
| driver.set_optimizers(paddle_opt) | |||||
| # 区分launch和子进程setup的时候 | |||||
| if FASTNLP_DISTRIBUTED_CHECK not in os.environ: | |||||
| with pytest.raises(SystemExit) as e: | |||||
| driver.setup() | |||||
| assert e.value.code == 0 | |||||
| return | |||||
| else: | |||||
| driver.setup() | |||||
| res = driver.get_no_sync_context() | |||||
| finally: | |||||
| synchronize_safe_rm("log") | |||||
| dist.barrier() | |||||
@magic_argv_env_context
def test_is_global_zero():
    """
    Smoke-test ``PaddleFleetDriver.is_global_zero()``; the return value is not checked.
    """
    try:
        paddle_model = PaddleNormalModel_Classification(10, 784)
        paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
        driver = PaddleFleetDriver(
            model=paddle_model,
            parallel_device=[0,1],
        )
        driver.set_optimizers(paddle_opt)
        # Distinguish the launching process from a spawned child process during setup
        if FASTNLP_DISTRIBUTED_CHECK not in os.environ:
            # Without the marker variable, setup() is expected to exit with code 0
            with pytest.raises(SystemExit) as e:
                driver.setup()
            assert e.value.code == 0
            return
        else:
            driver.setup()
        driver.is_global_zero()
    finally:
        # Always remove the "log" directory
        synchronize_safe_rm("log")
    # NOTE(review): SOURCE has lost its indentation; the barrier is assumed to follow the
    # try/finally rather than sit inside the finally clause -- confirm.
    dist.barrier()
@magic_argv_env_context
def test_unwrap_model():
    """
    Smoke-test ``PaddleFleetDriver.unwrap_model()``; the return value is not checked.
    """
    try:
        paddle_model = PaddleNormalModel_Classification(10, 784)
        paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
        driver = PaddleFleetDriver(
            model=paddle_model,
            parallel_device=[0,1],
        )
        driver.set_optimizers(paddle_opt)
        # Distinguish the launching process from a spawned child process during setup
        if FASTNLP_DISTRIBUTED_CHECK not in os.environ:
            # Without the marker variable, setup() is expected to exit with code 0
            with pytest.raises(SystemExit) as e:
                driver.setup()
            assert e.value.code == 0
            return
        else:
            driver.setup()
        driver.unwrap_model()
    finally:
        # Always remove the "log" directory
        synchronize_safe_rm("log")
    # NOTE(review): SOURCE has lost its indentation; the barrier is assumed to follow the
    # try/finally rather than sit inside the finally clause -- confirm.
    dist.barrier()
@magic_argv_env_context
def test_get_local_rank():
    """
    Smoke-test ``PaddleFleetDriver.get_local_rank()``; the return value is not checked.
    """
    try:
        paddle_model = PaddleNormalModel_Classification(10, 784)
        paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
        driver = PaddleFleetDriver(
            model=paddle_model,
            parallel_device=[0,1],
        )
        driver.set_optimizers(paddle_opt)
        # Distinguish the launching process from a spawned child process during setup
        if FASTNLP_DISTRIBUTED_CHECK not in os.environ:
            # Without the marker variable, setup() is expected to exit with code 0
            with pytest.raises(SystemExit) as e:
                driver.setup()
            assert e.value.code == 0
            return
        else:
            driver.setup()
        driver.get_local_rank()
    finally:
        # Always remove the "log" directory
        synchronize_safe_rm("log")
    # NOTE(review): SOURCE has lost its indentation; the barrier is assumed to follow the
    # try/finally rather than sit inside the finally clause -- confirm.
    dist.barrier()
@magic_argv_env_context
@pytest.mark.parametrize(
    "dist_sampler",
    ["dist", "unrepeatdist", RandomSampler(PaddleDataset_MNIST("train"))]
)
@pytest.mark.parametrize(
    "reproducible",
    [True, False]
)
def test_replace_sampler(dist_sampler, reproducible):
    """
    Smoke-test ``set_dist_repro_dataloader`` (sampler replacement).

    Runs once per combination of ``dist_sampler`` and ``reproducible``; the returned
    dataloader is not inspected.
    """
    try:
        paddle_model = PaddleNormalModel_Classification(10, 784)
        paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
        driver = PaddleFleetDriver(
            model=paddle_model,
            parallel_device=[0,1],
        )
        driver.set_optimizers(paddle_opt)
        # Distinguish the launching process from a spawned child process during setup
        if FASTNLP_DISTRIBUTED_CHECK not in os.environ:
            # Without the marker variable, setup() is expected to exit with code 0
            with pytest.raises(SystemExit) as e:
                driver.setup()
            assert e.value.code == 0
            return
        else:
            driver.setup()
        dataloader = DataLoader(PaddleDataset_MNIST("train"), batch_size=100, shuffle=True)
        driver.set_dist_repro_dataloader(dataloader, dist_sampler, reproducible)
    finally:
        # Always remove the "log" directory
        synchronize_safe_rm("log")
    # NOTE(review): SOURCE has lost its indentation; the barrier is assumed to follow the
    # try/finally rather than sit inside the finally clause -- confirm.
    dist.barrier()
@classmethod
def setup_class(cls):
    """Create one driver via ``generate_driver(10, 10)`` and store it on the class."""
    shared_driver = generate_driver(10, 10)
    cls.driver = shared_driver
@magic_argv_env_context
def test_move_data_to_device(self):
    """
    Smoke-test ``move_data_to_device`` with a random (32, 64) tensor.

    The original note states that this method only calls ``paddle_move_data_to_device``,
    whose test cases live in tests/core/utils/test_paddle_utils.py, so they are not
    repeated here.
    """
    random_batch = paddle.rand((32, 64))
    self.driver.move_data_to_device(random_batch)
    dist.barrier()
@magic_argv_env_context
def test_is_distributed(self):
    """
    Check that ``is_distributed()`` on the shared driver returns True.
    """
    distributed = self.driver.is_distributed()
    assert distributed == True
    dist.barrier()
@magic_argv_env_context
def test_get_no_sync_context(self):
    """
    Smoke-test ``get_no_sync_context()``; the returned object is not inspected.
    """
    fleet_driver = self.driver
    fleet_driver.get_no_sync_context()
    dist.barrier()
@magic_argv_env_context
def test_is_global_zero(self):
    """
    Smoke-test ``is_global_zero()``; the return value is not checked.
    """
    fleet_driver = self.driver
    fleet_driver.is_global_zero()
    dist.barrier()
@magic_argv_env_context
def test_unwrap_model(self):
    """
    Smoke-test ``unwrap_model()``; the return value is not checked.
    """
    fleet_driver = self.driver
    fleet_driver.unwrap_model()
    dist.barrier()
@magic_argv_env_context
def test_get_local_rank(self):
    """
    Smoke-test ``get_local_rank()``; the return value is not checked.
    """
    fleet_driver = self.driver
    fleet_driver.get_local_rank()
    dist.barrier()
| ############################################################################ | ############################################################################ | ||||
| # | # | ||||
| # 测试单机多卡的训练情况 | |||||
| # 测试 set_dist_repro_dataloader 函数 | |||||
| # | # | ||||
| ############################################################################ | ############################################################################ | ||||
| @magic_argv_env_context | |||||
| class SingleMachineMultiGPUTrainingTestCase: | |||||
| class TestSetDistReproDataloader: | |||||
@classmethod
def setup_class(cls):
    """Create one driver via ``generate_driver(10, 10)`` and store it on the class."""
    shared_driver = generate_driver(10, 10)
    cls.driver = shared_driver
def setup_method(self):
    """Give every test a fresh ``PaddleNormalDataset(20)`` as ``self.dataset``."""
    fresh_dataset = PaddleNormalDataset(20)
    self.dataset = fresh_dataset
| """ | """ | ||||
| 测试在单机多卡上使用PaddleFleetDriver进行训练。 | |||||
| 分布式训练用pytest会有些混乱 | |||||
| 传入的 `dist` 参数为具体的 ReproducibleSampler 或 ReproducibleBatchSampler 的情况 | |||||
| 此时对应 driver.load 中的情况 | |||||
| """ | """ | ||||
| def test_case1(self): | |||||
| gpus = [0, 1] | |||||
| lr = 0.0003 | |||||
| epochs = 20 | |||||
@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_batch_sampler(self):
    """
    ``set_dist_repro_dataloader`` with a ``BucketedBatchSampler`` instance passed as ``dist``.

    Expects a new dataloader whose batch sampler is the very object that was passed in,
    with distributed settings that satisfy ``check_distributed_sampler``.
    """
    loader = DataLoader(self.dataset, batch_size=4, shuffle=True)
    bucket_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4)
    new_loader = self.driver.set_dist_repro_dataloader(loader, bucket_sampler, False)

    assert new_loader is not loader
    new_batch_sampler = new_loader.batch_sampler
    assert isinstance(new_batch_sampler, BucketedBatchSampler)
    assert new_batch_sampler is bucket_sampler
    self.check_distributed_sampler(new_batch_sampler)
    dist.barrier()
@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_sampler(self):
    """
    ``set_dist_repro_dataloader`` with a ``RandomSampler`` instance passed as ``dist``.

    Expects a new dataloader with a new ``BatchSampler`` that wraps the very sampler that
    was passed in, keeps the original batch size, and satisfies
    ``check_distributed_sampler``.
    """
    loader = DataLoader(self.dataset, batch_size=4, shuffle=True)
    random_sampler = RandomSampler(self.dataset, shuffle=True)
    new_loader = self.driver.set_dist_repro_dataloader(loader, random_sampler, False)

    assert new_loader is not loader
    new_batch_sampler = new_loader.batch_sampler
    assert isinstance(new_batch_sampler, BatchSampler)
    assert isinstance(new_batch_sampler.sampler, RandomSampler)
    assert new_batch_sampler is not loader.batch_sampler
    assert new_batch_sampler.sampler is random_sampler
    assert new_batch_sampler.batch_size == loader.batch_sampler.batch_size
    self.check_distributed_sampler(new_batch_sampler.sampler)
    dist.barrier()
| """ | |||||
| 传入的参数 `dist` 为 None 的情况,这种情况出现在 trainer 和 evaluator 的初始化过程中,用户指定了 `use_dist_sampler` | |||||
| 参数为 False。此时函数会根据 `reproducible` 的设置进行不同的处理。 | |||||
| 当 `reproducible` 为 False 时,需要根据 dataloader 的 batch_sampler 或 sampler 是否为 Reproducible 来决定 | |||||
| 是否重新实例化 dataloader | |||||
| """ | |||||
| paddle_model = PaddleNormalModel_Classification() | |||||
@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_none_reproducible_true(self):
    """
    ``set_dist_repro_dataloader`` with ``dist=None`` and ``reproducible=True``.

    The call is expected to raise ``RuntimeError``.
    """
    loader = DataLoader(self.dataset, batch_size=4, shuffle=True)
    # A RuntimeError should be raised here
    with pytest.raises(RuntimeError):
        self.driver.set_dist_repro_dataloader(loader, None, True)

    dist.barrier()
@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self):
    """
    ``set_dist_repro_dataloader`` with ``dist=None``, ``reproducible=False`` and a
    dataloader that already carries a ``BucketedBatchSampler``.

    The input batch sampler is put into distributed mode first. The test then expects a
    new dataloader whose batch sampler is a ``BucketedBatchSampler`` with batch size 4 and
    distributed settings that satisfy ``check_distributed_sampler``.
    """
    dataloader = DataLoader(
        self.dataset,
        batch_sampler=BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4),
    )
    dataloader.batch_sampler.set_distributed(
        num_replicas=self.driver.world_size,
        rank=self.driver.global_rank,
        pad=True
    )
    replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False)

    assert not (replaced_loader is dataloader)
    assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler)
    assert replaced_loader.batch_sampler.batch_size == 4
    # Inspect the sampler of the *returned* loader, as the sibling tests do. Checking the
    # input sampler would only re-verify the set_distributed() call made above.
    self.check_distributed_sampler(replaced_loader.batch_sampler)
    dist.barrier()
@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_smpler(self):
    """
    ``set_dist_repro_dataloader`` with ``dist=None``, ``reproducible=False`` and a
    dataloader whose batch sampler wraps a ``RandomSampler``.

    The input sampler is put into distributed mode first. The test then expects a new
    dataloader, a new ``BatchSampler`` and a new ``RandomSampler``, with batch size 2,
    ``drop_last`` False and distributed settings that satisfy ``check_distributed_sampler``.
    """
    source_batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
    source_batch_sampler.sampler = RandomSampler(self.dataset, True)
    source_batch_sampler.sampler.set_distributed(
        num_replicas=self.driver.world_size,
        rank=self.driver.global_rank
    )
    loader = DataLoader(self.dataset, batch_sampler=source_batch_sampler)
    new_loader = self.driver.set_dist_repro_dataloader(loader, None, False)

    assert new_loader is not loader
    new_batch_sampler = new_loader.batch_sampler
    assert isinstance(new_batch_sampler, BatchSampler)
    assert new_batch_sampler is not loader.batch_sampler
    assert isinstance(new_batch_sampler.sampler, RandomSampler)
    assert new_batch_sampler.sampler is not loader.batch_sampler.sampler
    assert new_batch_sampler.batch_size == 2
    assert new_batch_sampler.drop_last == False
    self.check_distributed_sampler(new_batch_sampler.sampler)
    dist.barrier()
@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_normal(self):
    """
    ``set_dist_repro_dataloader`` with ``dist=None``, ``reproducible=False`` and an
    ordinary dataloader.

    The very same dataloader object is expected back.
    """
    loader = DataLoader(self.dataset, batch_size=4, shuffle=True)
    returned_loader = self.driver.set_dist_repro_dataloader(loader, None, False)

    assert returned_loader is loader
    dist.barrier()
| paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=lr) | |||||
| """ | |||||
| 传入的参数 `dist` 为 'dist' 的情况,这种情况出现在 trainer 的初始化过程中,用户指定了 `use_dist_sampler` 参数 | |||||
| 为 True。此时函数会根据 dataloader 的 batch_sampler 或 sampler 是否为 Reproducible 来决定如何重新实例化 dataloader | |||||
| """ | |||||
| train_dataset = PaddleDataset_MNIST("train") | |||||
| test_dataset = PaddleDataset_MNIST("test") | |||||
| loss_func = paddle.nn.CrossEntropyLoss() | |||||
@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_batch_sampler(self):
    """
    ``set_dist_repro_dataloader`` with ``dist='dist'`` and a dataloader that carries a
    ``BucketedBatchSampler``.

    Expects a new dataloader with a new ``BucketedBatchSampler`` of batch size 4, the same
    ``drop_last`` as the input loader, and distributed settings that satisfy
    ``check_distributed_sampler``.
    """
    bucket_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4)
    loader = DataLoader(dataset=self.dataset, batch_sampler=bucket_sampler)
    new_loader = self.driver.set_dist_repro_dataloader(loader, "dist", False)

    assert new_loader is not loader
    new_batch_sampler = new_loader.batch_sampler
    assert isinstance(new_batch_sampler, BucketedBatchSampler)
    assert new_batch_sampler is not loader.batch_sampler
    assert new_batch_sampler.batch_size == 4
    assert new_loader.drop_last == loader.drop_last
    self.check_distributed_sampler(new_batch_sampler)
    dist.barrier()
@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_sampler(self):
    """
    ``set_dist_repro_dataloader`` with ``dist='dist'`` and a dataloader whose batch
    sampler wraps a ``RandomSampler``.

    Expects a new dataloader, a new batch sampler and a new ``RandomSampler``, with batch
    size 2, ``shuffle`` True and distributed settings that satisfy
    ``check_distributed_sampler``.
    """
    source_batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
    source_batch_sampler.sampler = RandomSampler(self.dataset, True)
    loader = DataLoader(self.dataset, batch_sampler=source_batch_sampler)
    new_loader = self.driver.set_dist_repro_dataloader(loader, "dist", False)

    assert new_loader is not loader
    new_batch_sampler = new_loader.batch_sampler
    assert new_batch_sampler is not loader.batch_sampler
    assert isinstance(new_batch_sampler.sampler, RandomSampler)
    assert new_batch_sampler.sampler is not loader.batch_sampler.sampler
    assert new_batch_sampler.batch_size == 2
    assert new_batch_sampler.sampler.shuffle == True
    self.check_distributed_sampler(new_batch_sampler.sampler)
    dist.barrier()
@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_dist_dataloader_normal(self):
    """
    ``set_dist_repro_dataloader`` with ``dist='dist'`` and an ordinary dataloader.

    Expects a new dataloader with a new ``BatchSampler`` wrapping a ``RandomSampler``,
    the input loader's batch size, and ``shuffle`` True.
    """
    loader = DataLoader(self.dataset, batch_size=4, shuffle=True)
    new_loader = self.driver.set_dist_repro_dataloader(loader, "dist", False)

    assert new_loader is not loader
    new_batch_sampler = new_loader.batch_sampler
    assert isinstance(new_batch_sampler, BatchSampler)
    assert new_batch_sampler is not loader.batch_sampler
    assert isinstance(new_batch_sampler.sampler, RandomSampler)
    assert new_batch_sampler.batch_size == loader.batch_sampler.batch_size
    assert new_batch_sampler.sampler.shuffle == True
    dist.barrier()
| dataloader = DataLoader(train_dataset, batch_size=100, shuffle=True) | |||||
| """ | |||||
| 传入的参数 `dist` 为 'unrepeatdist' 的情况,这种情况出现在 evaluator 的初始化过程中,用户指定了 `use_dist_sampler` 参数 | |||||
| 为 True。此时函数会根据 dataloader 的 sampler 是否为 Unrepeated 和 Reproducible 来决定如何重新实例化 dataloader | |||||
| """ | |||||
| driver = PaddleFleetDriver( | |||||
| model=paddle_model, | |||||
| parallel_device=gpus, | |||||
| @magic_argv_env_context | |||||
| def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_reproducible_sampler(self): | |||||
| """ | |||||
| 测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader.batch_sampler.sampler 为 ReproducibleSampler | |||||
| 的表现 | |||||
| """ | |||||
| batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2) | |||||
| batch_sampler.sampler = RandomSampler(self.dataset, True) | |||||
| dataloader = DataLoader( | |||||
| self.dataset, | |||||
| batch_sampler=batch_sampler | |||||
| ) | ) | ||||
| driver.set_optimizers(paddle_opt) | |||||
| dataloader = driver.set_dist_repro_dataloader(dataloader, ) | |||||
| driver.setup() | |||||
| # 检查model_device | |||||
| self.assertEqual(driver.model_device, f"gpu:{os.environ['PADDLE_LOCAL_DEVICE_IDS']}") | |||||
| driver.barrier() | |||||
| driver.zero_grad() | |||||
| current_epoch_idx = 0 | |||||
| while current_epoch_idx < epochs: | |||||
| epoch_loss, batch = 0, 0 | |||||
| driver.set_model_mode("train") | |||||
| driver.set_sampler_epoch(dataloader, current_epoch_idx) | |||||
| for batch, (img, label) in enumerate(dataloader): | |||||
| img = paddle.to_tensor(img) | |||||
| out = driver.train_step(img) | |||||
| label + 1 | |||||
| loss = loss_func(out, label) | |||||
| epoch_loss += loss.item() | |||||
| if batch % 50 == 0: | |||||
| print("epoch:{}, batch:{}, loss: {}, rank:{}".format(current_epoch_idx, batch, loss.item(), driver.local_rank)) | |||||
| driver.backward(loss) | |||||
| driver.step() | |||||
| driver.zero_grad() | |||||
| driver.barrier() | |||||
| current_epoch_idx += 1 | |||||
| # test | |||||
| correct = 0 | |||||
| driver.set_model_mode("eval") | |||||
| for img, label in test_dataset: | |||||
| img = paddle.to_tensor(np.array(img).astype('float32').reshape(1, -1)) | |||||
| out = driver.test_step(img) | |||||
| res = paddle.nn.functional.softmax(out).argmax().item() | |||||
| label = label.item() | |||||
| if res == label: | |||||
| correct += 1 | |||||
| print("{} / {}, acc: {}".format(correct, len(test_dataset), correct / len(test_dataset))) | |||||
| replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) | |||||
| assert not (replaced_loader is dataloader) | |||||
| assert isinstance(replaced_loader.batch_sampler, BatchSampler) | |||||
| assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||||
| assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedRandomSampler) | |||||
| assert replaced_loader.batch_sampler.batch_size == 2 | |||||
| assert replaced_loader.batch_sampler.sampler.shuffle == True | |||||
| self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | |||||
| dist.barrier() | |||||
@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_unrepreated_sampler(self):
    """
    ``set_dist_repro_dataloader`` with ``dist='unrepeatdist'`` and a dataloader whose
    batch sampler wraps an ``UnrepeatedRandomSampler``.

    Expects a new dataloader, a new ``BatchSampler`` and a new ``UnrepeatedRandomSampler``,
    with batch size 2, the input loader's ``drop_last`` and distributed settings that
    satisfy ``check_distributed_sampler``.
    """
    source_batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
    source_batch_sampler.sampler = UnrepeatedRandomSampler(self.dataset, True)
    loader = DataLoader(self.dataset, batch_sampler=source_batch_sampler)
    new_loader = self.driver.set_dist_repro_dataloader(loader, "unrepeatdist", False)

    assert new_loader is not loader
    new_batch_sampler = new_loader.batch_sampler
    assert isinstance(new_batch_sampler, BatchSampler)
    assert new_batch_sampler is not loader.batch_sampler
    assert isinstance(new_batch_sampler.sampler, UnrepeatedRandomSampler)
    assert new_batch_sampler.sampler is not loader.batch_sampler.sampler
    assert new_batch_sampler.batch_size == 2
    assert new_loader.drop_last == loader.drop_last
    self.check_distributed_sampler(new_batch_sampler.sampler)
    dist.barrier()
@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_normal(self):
    """
    ``set_dist_repro_dataloader`` with ``dist='unrepeatdist'`` and an ordinary dataloader.

    Expects a new dataloader with a new ``BatchSampler`` wrapping an
    ``UnrepeatedSequentialSampler``, batch size 4, the input loader's ``drop_last`` and
    distributed settings that satisfy ``check_distributed_sampler``.
    """
    loader = DataLoader(self.dataset, batch_size=4, shuffle=True)
    new_loader = self.driver.set_dist_repro_dataloader(loader, "unrepeatdist", False)

    assert new_loader is not loader
    new_batch_sampler = new_loader.batch_sampler
    assert isinstance(new_batch_sampler, BatchSampler)
    assert new_batch_sampler is not loader.batch_sampler
    assert isinstance(new_batch_sampler.sampler, UnrepeatedSequentialSampler)
    assert new_batch_sampler.batch_size == 4
    assert new_loader.drop_last == loader.drop_last
    self.check_distributed_sampler(new_batch_sampler.sampler)
    dist.barrier()
def check_distributed_sampler(self, sampler):
    """
    Assert that a replaced sampler or batch sampler has the expected distributed settings.

    ``num_replicas`` must equal ``dist.get_world_size()`` and ``rank`` must equal
    ``dist.get_rank()``. ``pad`` must be True unless the sampler is an ``UnrepeatedSampler``.

    :param sampler: the sampler or batch sampler taken from the replaced dataloader
    """
    assert sampler.num_replicas == dist.get_world_size()
    assert sampler.rank == dist.get_rank()
    # Unrepeated samplers are exempt from the ``pad`` check.
    if isinstance(sampler, UnrepeatedSampler):
        return
    assert sampler.pad == True