| @@ -1,21 +1,35 @@ | |||
| from dataclasses import replace | |||
| import pytest | |||
| import os | |||
| import numpy as np | |||
| from fastNLP.envs.set_env_on_import import set_env_on_import_paddle | |||
| set_env_on_import_paddle() | |||
| os.environ["FASTNLP_BACKEND"] = "paddle" | |||
| from fastNLP.core.drivers.paddle_driver.fleet import PaddleFleetDriver | |||
| from fastNLP.core.samplers import ( | |||
| RandomSampler, | |||
| UnrepeatedSampler, | |||
| BucketedBatchSampler, | |||
| UnrepeatedRandomSampler, | |||
| UnrepeatedSequentialSampler, | |||
| ) | |||
| from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | |||
| from tests.helpers.datasets.paddle_data import PaddleNormalDataset | |||
| from tests.helpers.utils import magic_argv_env_context | |||
| import paddle | |||
| import paddle.distributed as dist | |||
| from paddle.io import DataLoader | |||
| from paddle.io import DataLoader, BatchSampler | |||
| from fastNLP.core.drivers.paddle_driver.fleet import PaddleFleetDriver | |||
| from fastNLP.core.samplers.reproducible_sampler import RandomSampler | |||
| from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK | |||
| from tests.helpers.models.paddle_model import PaddleNormalModel_Classification | |||
| from tests.helpers.datasets.paddle_data import PaddleDataset_MNIST | |||
| from tests.helpers.utils import magic_argv_env_context | |||
| from fastNLP.core import synchronize_safe_rm | |||
| def generate_driver(num_labels, feature_dimension): | |||
| paddle_model = PaddleNormalModel_Classification_1(num_labels, feature_dimension) | |||
| paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01) | |||
| driver = PaddleFleetDriver( | |||
| model=paddle_model, | |||
| parallel_device=[0,1], | |||
| ) | |||
| driver.set_optimizers(paddle_opt) | |||
| driver.setup() | |||
| return driver | |||
| ############################################################################ | |||
| # | |||
| @@ -23,269 +37,340 @@ from fastNLP.core import synchronize_safe_rm | |||
| # | |||
| ############################################################################ | |||
| @magic_argv_env_context | |||
| def test_move_data_to_device(): | |||
| """ | |||
| 这个函数仅调用了paddle_move_data_to_device,测试例在tests/core/utils/test_paddle_utils.py中 | |||
| 就不重复测试了 | |||
| """ | |||
| try: | |||
| paddle_model = PaddleNormalModel_Classification(10, 784) | |||
| paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01) | |||
| driver = PaddleFleetDriver( | |||
| model=paddle_model, | |||
| parallel_device=[0,1], | |||
| ) | |||
| driver.set_optimizers(paddle_opt) | |||
| # 区分launch和子进程setup的时候 | |||
| if FASTNLP_DISTRIBUTED_CHECK not in os.environ: | |||
| with pytest.raises(SystemExit) as e: | |||
| driver.setup() | |||
| assert e.value.code == 0 | |||
| return | |||
| else: | |||
| driver.setup() | |||
| driver.move_data_to_device(paddle.rand((32, 64))) | |||
| finally: | |||
| synchronize_safe_rm("log") | |||
| dist.barrier() | |||
| @magic_argv_env_context | |||
| def test_is_distributed(): | |||
| print(os.getenv("CUDA_VISIBLE_DEVICES")) | |||
| print(paddle.device.get_device()) | |||
| try: | |||
| paddle_model = PaddleNormalModel_Classification(10, 784) | |||
| paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01) | |||
| driver = PaddleFleetDriver( | |||
| model=paddle_model, | |||
| parallel_device=[0,1], | |||
| output_from_new_proc='all' | |||
| ) | |||
| driver.set_optimizers(paddle_opt) | |||
| # 区分launch和子进程setup的时候 | |||
| if FASTNLP_DISTRIBUTED_CHECK not in os.environ: | |||
| with pytest.raises(SystemExit) as e: | |||
| driver.setup() | |||
| assert e.value.code == 0 | |||
| return | |||
| else: | |||
| driver.setup() | |||
| assert driver.is_distributed() == True | |||
| finally: | |||
| synchronize_safe_rm("log") | |||
| dist.barrier() | |||
| @magic_argv_env_context | |||
| def test_get_no_sync_context(): | |||
| class TestFleetDriverFunction: | |||
| """ | |||
| 测试能否运行 | |||
| 测试 PaddleFleetDriver 一些简单函数的测试类,基本都是测试能否运行、是否存在 import 错误等问题 | |||
| """ | |||
| try: | |||
| paddle_model = PaddleNormalModel_Classification(10, 784) | |||
| paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01) | |||
| driver = PaddleFleetDriver( | |||
| model=paddle_model, | |||
| parallel_device=[0,1], | |||
| ) | |||
| driver.set_optimizers(paddle_opt) | |||
| # 区分launch和子进程setup的时候 | |||
| if FASTNLP_DISTRIBUTED_CHECK not in os.environ: | |||
| with pytest.raises(SystemExit) as e: | |||
| driver.setup() | |||
| assert e.value.code == 0 | |||
| return | |||
| else: | |||
| driver.setup() | |||
| res = driver.get_no_sync_context() | |||
| finally: | |||
| synchronize_safe_rm("log") | |||
| dist.barrier() | |||
| @magic_argv_env_context | |||
| def test_is_global_zero(): | |||
| try: | |||
| paddle_model = PaddleNormalModel_Classification(10, 784) | |||
| paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01) | |||
| driver = PaddleFleetDriver( | |||
| model=paddle_model, | |||
| parallel_device=[0,1], | |||
| ) | |||
| driver.set_optimizers(paddle_opt) | |||
| # 区分launch和子进程setup的时候 | |||
| if FASTNLP_DISTRIBUTED_CHECK not in os.environ: | |||
| with pytest.raises(SystemExit) as e: | |||
| driver.setup() | |||
| assert e.value.code == 0 | |||
| return | |||
| else: | |||
| driver.setup() | |||
| driver.is_global_zero() | |||
| finally: | |||
| synchronize_safe_rm("log") | |||
| dist.barrier() | |||
| @magic_argv_env_context | |||
| def test_unwrap_model(): | |||
| try: | |||
| paddle_model = PaddleNormalModel_Classification(10, 784) | |||
| paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01) | |||
| driver = PaddleFleetDriver( | |||
| model=paddle_model, | |||
| parallel_device=[0,1], | |||
| ) | |||
| driver.set_optimizers(paddle_opt) | |||
| # 区分launch和子进程setup的时候 | |||
| if FASTNLP_DISTRIBUTED_CHECK not in os.environ: | |||
| with pytest.raises(SystemExit) as e: | |||
| driver.setup() | |||
| assert e.value.code == 0 | |||
| return | |||
| else: | |||
| driver.setup() | |||
| driver.unwrap_model() | |||
| finally: | |||
| synchronize_safe_rm("log") | |||
| dist.barrier() | |||
| @magic_argv_env_context | |||
| def test_get_local_rank(): | |||
| try: | |||
| paddle_model = PaddleNormalModel_Classification(10, 784) | |||
| paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01) | |||
| driver = PaddleFleetDriver( | |||
| model=paddle_model, | |||
| parallel_device=[0,1], | |||
| ) | |||
| driver.set_optimizers(paddle_opt) | |||
| # 区分launch和子进程setup的时候 | |||
| if FASTNLP_DISTRIBUTED_CHECK not in os.environ: | |||
| with pytest.raises(SystemExit) as e: | |||
| driver.setup() | |||
| assert e.value.code == 0 | |||
| return | |||
| else: | |||
| driver.setup() | |||
| driver.get_local_rank() | |||
| finally: | |||
| synchronize_safe_rm("log") | |||
| dist.barrier() | |||
| @magic_argv_env_context | |||
| @pytest.mark.parametrize( | |||
| "dist_sampler", | |||
| ["dist", "unrepeatdist", RandomSampler(PaddleDataset_MNIST("train"))] | |||
| ) | |||
| @pytest.mark.parametrize( | |||
| "reproducible", | |||
| [True, False] | |||
| ) | |||
| def test_replace_sampler(dist_sampler, reproducible): | |||
| """ | |||
| 测试replace_sampler | |||
| """ | |||
| try: | |||
| paddle_model = PaddleNormalModel_Classification(10, 784) | |||
| paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01) | |||
| driver = PaddleFleetDriver( | |||
| model=paddle_model, | |||
| parallel_device=[0,1], | |||
| ) | |||
| driver.set_optimizers(paddle_opt) | |||
| # 区分launch和子进程setup的时候 | |||
| if FASTNLP_DISTRIBUTED_CHECK not in os.environ: | |||
| with pytest.raises(SystemExit) as e: | |||
| driver.setup() | |||
| assert e.value.code == 0 | |||
| return | |||
| else: | |||
| driver.setup() | |||
| dataloader = DataLoader(PaddleDataset_MNIST("train"), batch_size=100, shuffle=True) | |||
| driver.set_dist_repro_dataloader(dataloader, dist_sampler, reproducible) | |||
| finally: | |||
| synchronize_safe_rm("log") | |||
| dist.barrier() | |||
| @classmethod | |||
| def setup_class(cls): | |||
| cls.driver = generate_driver(10, 10) | |||
| @magic_argv_env_context | |||
| def test_move_data_to_device(self): | |||
| """ | |||
| 这个函数仅调用了paddle_move_data_to_device,测试例在tests/core/utils/test_paddle_utils.py中 | |||
| 就不重复测试了 | |||
| """ | |||
| self.driver.move_data_to_device(paddle.rand((32, 64))) | |||
| dist.barrier() | |||
| @magic_argv_env_context | |||
| def test_is_distributed(self): | |||
| """ | |||
| 测试 is_distributed 函数 | |||
| """ | |||
| assert self.driver.is_distributed() == True | |||
| dist.barrier() | |||
| @magic_argv_env_context | |||
| def test_get_no_sync_context(self): | |||
| """ | |||
| 测试 get_no_sync_context 函数 | |||
| """ | |||
| res = self.driver.get_no_sync_context() | |||
| dist.barrier() | |||
| @magic_argv_env_context | |||
| def test_is_global_zero(self): | |||
| """ | |||
| 测试 is_global_zero 函数 | |||
| """ | |||
| self.driver.is_global_zero() | |||
| dist.barrier() | |||
| @magic_argv_env_context | |||
| def test_unwrap_model(self): | |||
| """ | |||
| 测试 unwrap_model 函数 | |||
| """ | |||
| self.driver.unwrap_model() | |||
| dist.barrier() | |||
| @magic_argv_env_context | |||
| def test_get_local_rank(self): | |||
| """ | |||
| 测试 get_local_rank 函数 | |||
| """ | |||
| self.driver.get_local_rank() | |||
| dist.barrier() | |||
| ############################################################################ | |||
| # | |||
| # 测试单机多卡的训练情况 | |||
| # 测试 set_dist_repro_dataloader 函数 | |||
| # | |||
| ############################################################################ | |||
| @magic_argv_env_context | |||
| class SingleMachineMultiGPUTrainingTestCase: | |||
| class TestSetDistReproDataloader: | |||
| @classmethod | |||
| def setup_class(cls): | |||
| cls.driver = generate_driver(10, 10) | |||
| def setup_method(self): | |||
| self.dataset = PaddleNormalDataset(20) | |||
| """ | |||
| 测试在单机多卡上使用PaddleFleetDriver进行训练。 | |||
| 分布式训练用pytest会有些混乱 | |||
| 传入的 `dist` 参数为具体的 ReproducibleSampler 或 ReproducibleBatchSampler 的情况 | |||
| 此时对应 driver.load 中的情况 | |||
| """ | |||
| def test_case1(self): | |||
| gpus = [0, 1] | |||
| lr = 0.0003 | |||
| epochs = 20 | |||
| @magic_argv_env_context | |||
| def test_set_dist_repro_dataloader_with_dist_batch_sampler(self): | |||
| """ | |||
| 测试 set_dist_repro_dataloader 中 dist 为 BucketedBatchSampler 时的表现 | |||
| """ | |||
| dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True) | |||
| batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4) | |||
| replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, batch_sampler, False) | |||
| assert not (replaced_loader is dataloader) | |||
| assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) | |||
| assert replaced_loader.batch_sampler is batch_sampler | |||
| self.check_distributed_sampler(replaced_loader.batch_sampler) | |||
| dist.barrier() | |||
| @magic_argv_env_context | |||
| def test_set_dist_repro_dataloader_with_dist_sampler(self): | |||
| """ | |||
| 测试 set_dist_repro_dataloader 中 dist 为 RandomSampler 时的表现 | |||
| """ | |||
| dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True) | |||
| sampler = RandomSampler(self.dataset, shuffle=True) | |||
| replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, sampler, False) | |||
| assert not (replaced_loader is dataloader) | |||
| assert isinstance(replaced_loader.batch_sampler, BatchSampler) | |||
| assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | |||
| assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||
| assert replaced_loader.batch_sampler.sampler is sampler | |||
| assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | |||
| self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | |||
| dist.barrier() | |||
| """ | |||
| 传入的参数 `dist` 为 None 的情况,这种情况出现在 trainer 和 evaluator 的初始化过程中,用户指定了 `use_dist_sampler` | |||
| 参数为 False。此时函数会根据 `reproducible` 的设置进行不同的处理。 | |||
| 当 `reproducible` 为 False 时,需要根据 dataloader 的 batch_sampler 或 sampler 是否为 Reproducible 来决定 | |||
| 是否重新实例化 dataloader | |||
| """ | |||
| paddle_model = PaddleNormalModel_Classification() | |||
| @magic_argv_env_context | |||
| def test_set_dist_repro_dataloader_with_dist_none_reproducible_true(self): | |||
| """ | |||
| 测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 True 时的表现 | |||
| """ | |||
| dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True) | |||
| with pytest.raises(RuntimeError): | |||
| # 应当抛出 RuntimeError | |||
| replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, True) | |||
| dist.barrier() | |||
| @magic_argv_env_context | |||
| def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self): | |||
| """ | |||
| 测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 有 BucketedBatchSampler | |||
| 时的表现 | |||
| """ | |||
| dataloader = DataLoader( | |||
| self.dataset, | |||
| batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4), | |||
| ) | |||
| dataloader.batch_sampler.set_distributed( | |||
| num_replicas=self.driver.world_size, | |||
| rank=self.driver.global_rank, | |||
| pad=True | |||
| ) | |||
| replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False) | |||
| assert not (replaced_loader is dataloader) | |||
| assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) | |||
| assert replaced_loader.batch_sampler.batch_size == 4 | |||
| self.check_distributed_sampler(dataloader.batch_sampler) | |||
| dist.barrier() | |||
| @magic_argv_env_context | |||
| def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_smpler(self): | |||
| """ | |||
| 测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 有 RandomSampler 时的表现 | |||
| """ | |||
| batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2) | |||
| batch_sampler.sampler = RandomSampler(self.dataset, True) | |||
| batch_sampler.sampler.set_distributed( | |||
| num_replicas=self.driver.world_size, | |||
| rank=self.driver.global_rank | |||
| ) | |||
| dataloader = DataLoader( | |||
| self.dataset, | |||
| batch_sampler=batch_sampler | |||
| ) | |||
| replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False) | |||
| assert not (replaced_loader is dataloader) | |||
| assert isinstance(replaced_loader.batch_sampler, BatchSampler) | |||
| assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||
| assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | |||
| assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler) | |||
| assert replaced_loader.batch_sampler.batch_size == 2 | |||
| assert replaced_loader.batch_sampler.drop_last == False | |||
| self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | |||
| dist.barrier() | |||
| @magic_argv_env_context | |||
| def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_normal(self): | |||
| """ | |||
| 测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 为一般情况时的表现 | |||
| """ | |||
| dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True) | |||
| replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False) | |||
| assert replaced_loader is dataloader | |||
| dist.barrier() | |||
| paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=lr) | |||
| """ | |||
| 传入的参数 `dist` 为 'dist' 的情况,这种情况出现在 trainer 的初始化过程中,用户指定了 `use_dist_sampler` 参数 | |||
| 为 True。此时函数会根据 dataloader 的 batch_sampler 或 sampler 是否为 Reproducible 来决定如何重新实例化 dataloader | |||
| """ | |||
| train_dataset = PaddleDataset_MNIST("train") | |||
| test_dataset = PaddleDataset_MNIST("test") | |||
| loss_func = paddle.nn.CrossEntropyLoss() | |||
| @magic_argv_env_context | |||
| def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_batch_sampler(self): | |||
| """ | |||
| 测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader.batch_sampler 为 ReproducibleBatchSampler | |||
| 的表现 | |||
| """ | |||
| dataloader = DataLoader( | |||
| dataset=self.dataset, | |||
| batch_sampler=BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4) | |||
| ) | |||
| replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False) | |||
| assert not (replaced_loader is dataloader) | |||
| assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) | |||
| assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||
| assert replaced_loader.batch_sampler.batch_size == 4 | |||
| assert replaced_loader.drop_last == dataloader.drop_last | |||
| self.check_distributed_sampler(replaced_loader.batch_sampler) | |||
| dist.barrier() | |||
| @magic_argv_env_context | |||
| def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_sampler(self): | |||
| """ | |||
| 测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader.batch_sampler.sampler 为 ReproducibleSampler | |||
| 的表现 | |||
| """ | |||
| batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2) | |||
| batch_sampler.sampler = RandomSampler(self.dataset, True) | |||
| dataloader = DataLoader( | |||
| self.dataset, | |||
| batch_sampler=batch_sampler | |||
| ) | |||
| replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False) | |||
| assert not (replaced_loader is dataloader) | |||
| assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||
| assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | |||
| assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler) | |||
| assert replaced_loader.batch_sampler.batch_size == 2 | |||
| assert replaced_loader.batch_sampler.sampler.shuffle == True | |||
| self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | |||
| dist.barrier() | |||
| @magic_argv_env_context | |||
| def test_set_dist_repro_dataloader_with_dist_dist_dataloader_normal(self): | |||
| """ | |||
| 测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader 为一般情况的表现 | |||
| """ | |||
| dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True) | |||
| replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False) | |||
| assert not (replaced_loader is dataloader) | |||
| assert isinstance(replaced_loader.batch_sampler, BatchSampler) | |||
| assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||
| assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | |||
| assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | |||
| assert replaced_loader.batch_sampler.sampler.shuffle == True | |||
| dist.barrier() | |||
| dataloader = DataLoader(train_dataset, batch_size=100, shuffle=True) | |||
| """ | |||
| 传入的参数 `dist` 为 'unrepeatdist' 的情况,这种情况出现在 evaluator 的初始化过程中,用户指定了 `use_dist_sampler` 参数 | |||
| 为 True。此时函数会根据 dataloader 的 sampler 是否为 Unrepeated 和 Reproducible 来决定如何重新实例化 dataloader | |||
| """ | |||
| driver = PaddleFleetDriver( | |||
| model=paddle_model, | |||
| parallel_device=gpus, | |||
| @magic_argv_env_context | |||
| def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_reproducible_sampler(self): | |||
| """ | |||
| 测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader.batch_sampler.sampler 为 ReproducibleSampler | |||
| 的表现 | |||
| """ | |||
| batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2) | |||
| batch_sampler.sampler = RandomSampler(self.dataset, True) | |||
| dataloader = DataLoader( | |||
| self.dataset, | |||
| batch_sampler=batch_sampler | |||
| ) | |||
| driver.set_optimizers(paddle_opt) | |||
| dataloader = driver.set_dist_repro_dataloader(dataloader, ) | |||
| driver.setup() | |||
| # 检查model_device | |||
| self.assertEqual(driver.model_device, f"gpu:{os.environ['PADDLE_LOCAL_DEVICE_IDS']}") | |||
| driver.barrier() | |||
| driver.zero_grad() | |||
| current_epoch_idx = 0 | |||
| while current_epoch_idx < epochs: | |||
| epoch_loss, batch = 0, 0 | |||
| driver.set_model_mode("train") | |||
| driver.set_sampler_epoch(dataloader, current_epoch_idx) | |||
| for batch, (img, label) in enumerate(dataloader): | |||
| img = paddle.to_tensor(img) | |||
| out = driver.train_step(img) | |||
| label + 1 | |||
| loss = loss_func(out, label) | |||
| epoch_loss += loss.item() | |||
| if batch % 50 == 0: | |||
| print("epoch:{}, batch:{}, loss: {}, rank:{}".format(current_epoch_idx, batch, loss.item(), driver.local_rank)) | |||
| driver.backward(loss) | |||
| driver.step() | |||
| driver.zero_grad() | |||
| driver.barrier() | |||
| current_epoch_idx += 1 | |||
| # test | |||
| correct = 0 | |||
| driver.set_model_mode("eval") | |||
| for img, label in test_dataset: | |||
| img = paddle.to_tensor(np.array(img).astype('float32').reshape(1, -1)) | |||
| out = driver.test_step(img) | |||
| res = paddle.nn.functional.softmax(out).argmax().item() | |||
| label = label.item() | |||
| if res == label: | |||
| correct += 1 | |||
| print("{} / {}, acc: {}".format(correct, len(test_dataset), correct / len(test_dataset))) | |||
| replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) | |||
| assert not (replaced_loader is dataloader) | |||
| assert isinstance(replaced_loader.batch_sampler, BatchSampler) | |||
| assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||
| assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedRandomSampler) | |||
| assert replaced_loader.batch_sampler.batch_size == 2 | |||
| assert replaced_loader.batch_sampler.sampler.shuffle == True | |||
| self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | |||
| dist.barrier() | |||
| @magic_argv_env_context | |||
| def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_unrepreated_sampler(self): | |||
| """ | |||
| 测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader.batch_sampler.sampler 为 UnrepeatedSampler | |||
| 的表现 | |||
| """ | |||
| batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2) | |||
| batch_sampler.sampler = UnrepeatedRandomSampler(self.dataset, True) | |||
| dataloader = DataLoader( | |||
| self.dataset, | |||
| batch_sampler=batch_sampler | |||
| ) | |||
| replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) | |||
| assert not (replaced_loader is dataloader) | |||
| assert isinstance(replaced_loader.batch_sampler, BatchSampler) | |||
| assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||
| assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedRandomSampler) | |||
| assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler) | |||
| assert replaced_loader.batch_sampler.batch_size == 2 | |||
| assert replaced_loader.drop_last == dataloader.drop_last | |||
| self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | |||
| dist.barrier() | |||
| @magic_argv_env_context | |||
| def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_normal(self): | |||
| """ | |||
| 测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader 为一般情况的表现 | |||
| """ | |||
| dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True) | |||
| replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) | |||
| assert not (replaced_loader is dataloader) | |||
| assert isinstance(replaced_loader.batch_sampler, BatchSampler) | |||
| assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||
| assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedSequentialSampler) | |||
| assert replaced_loader.batch_sampler.batch_size == 4 | |||
| assert replaced_loader.drop_last == dataloader.drop_last | |||
| self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | |||
| dist.barrier() | |||
| def check_distributed_sampler(self, sampler): | |||
| """ | |||
| 测试替换得到的 sampler 或 batch_sampler 的分布式设置是否正确 | |||
| """ | |||
| assert sampler.num_replicas == dist.get_world_size() | |||
| assert sampler.rank == dist.get_rank() | |||
| if not isinstance(sampler, UnrepeatedSampler): | |||
| assert sampler.pad == True | |||