Browse Source

paddle fleet set_dist_repro_dataloader的测试例

tags/v1.0.0alpha
x54-729 3 years ago
parent
commit
de544707d9
2 changed files with 347 additions and 17 deletions
  1. +1
    -3
      fastNLP/core/drivers/paddle_driver/fleet.py
  2. +346
    -14
      tests/core/drivers/paddle_driver/test_fleet.py

+ 1
- 3
fastNLP/core/drivers/paddle_driver/fleet.py View File

@@ -1,6 +1,5 @@
import os import os
import shutil import shutil
from functools import partial
from typing import List, Union, Optional, Dict, Tuple, Callable from typing import List, Union, Optional, Dict, Tuple, Callable


from .paddle_driver import PaddleDriver from .paddle_driver import PaddleDriver
@@ -38,7 +37,6 @@ if _NEED_IMPORT_PADDLE:
from paddle import DataParallel from paddle import DataParallel
import paddle.distributed.fleet as fleet import paddle.distributed.fleet as fleet
import paddle.distributed as paddledist import paddle.distributed as paddledist
from paddle.io import BatchSampler
from paddle.optimizer import Optimizer from paddle.optimizer import Optimizer
from paddle.fluid.reader import _DatasetKind from paddle.fluid.reader import _DatasetKind
from paddle.fluid.dygraph import parallel_helper from paddle.fluid.dygraph import parallel_helper
@@ -305,7 +303,7 @@ class PaddleFleetDriver(PaddleDriver):
raise RuntimeError(f"There is no `{fn}` method in your model.") raise RuntimeError(f"There is no `{fn}` method in your model.")
else: else:
if hasattr(model, fn): if hasattr(model, fn):
logger.warning("Notice your model is a `DistributedDataParallel` model. And your model also implements "
logger.warning("Notice your model is a `DataParallel` model. And your model also implements "
f"the `{fn}` method, which we can not call actually, we will" f"the `{fn}` method, which we can not call actually, we will"
" call `forward` function instead of `train_step` and you should note that.") " call `forward` function instead of `train_step` and you should note that.")
elif fn not in {"train_step", "evaluate_step"}: elif fn not in {"train_step", "evaluate_step"}:


+ 346
- 14
tests/core/drivers/paddle_driver/test_fleet.py View File

@@ -12,28 +12,44 @@ from fastNLP.core.samplers import (
UnrepeatedSequentialSampler, UnrepeatedSequentialSampler,
) )
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
from tests.helpers.datasets.paddle_data import PaddleNormalDataset
from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset
from tests.helpers.utils import magic_argv_env_context from tests.helpers.utils import magic_argv_env_context
from fastNLP.core import synchronize_safe_rm


import paddle import paddle
import paddle.distributed as dist import paddle.distributed as dist
from paddle.io import DataLoader, BatchSampler from paddle.io import DataLoader, BatchSampler


def generate_driver(num_labels, feature_dimension, device=None, fp16=False):
    """
    Build a fully initialized PaddleFleetDriver for the tests in this module.

    :param num_labels: number of classification labels of the test model.
    :param feature_dimension: input feature dimension of the test model.
    :param device: devices to run data-parallel on; defaults to ``[0, 1]``.
        A ``None`` sentinel is used instead of a mutable list default so the
        default list is never shared (and possibly mutated) across calls.
    :param fp16: whether to enable mixed precision in the driver.
    :return: a ``PaddleFleetDriver`` with its optimizer set and ``setup()``
        already called.
    """
    if device is None:
        device = [0, 1]
    paddle_model = PaddleNormalModel_Classification_1(num_labels, feature_dimension)
    paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
    driver = PaddleFleetDriver(
        model=paddle_model,
        parallel_device=device,
        fp16=fp16,
    )
    driver.set_optimizers(paddle_opt)
    driver.setup()

    return driver


@magic_argv_env_context
def test_multi_drivers():
    """
    Test the behaviour when several PaddleFleetDriver instances are created
    in the same process.
    """
    first_driver = generate_driver(10, 10)
    second_driver = generate_driver(20, 10)

    # A driver configured with a different device set must be rejected.
    with pytest.raises(RuntimeError):
        generate_driver(20, 3, device=[0, 2])

    dist.barrier()

############################################################################ ############################################################################
# #
# 测试PaddleFleetDriver的一些函数
# 测试 PaddleFleetDriver 的一些函数
# #
############################################################################ ############################################################################


@@ -106,10 +122,11 @@ class TestSetDistReproDataloader:


@classmethod
def setup_class(cls):
    # One shared two-card driver for every test in this class; the device
    # list is kept on the class so tests can derive the replica count from it.
    cls.device = [0, 1]
    cls.driver = generate_driver(10, 10, device=cls.device)


def setup_method(self):
    # Fresh 40-sample dataset per test; 40 divides evenly across the two
    # replicas configured in setup_class, which the shard-size checks rely on.
    self.dataset = PaddleNormalDataset(40)


""" """
传入的 `dist` 参数为具体的 ReproducibleSampler 或 ReproducibleBatchSampler 的情况 传入的 `dist` 参数为具体的 ReproducibleSampler 或 ReproducibleBatchSampler 的情况
@@ -121,6 +138,7 @@ class TestSetDistReproDataloader:
def test_set_dist_repro_dataloader_with_dist_batch_sampler(self, shuffle): def test_set_dist_repro_dataloader_with_dist_batch_sampler(self, shuffle):
""" """
测试 set_dist_repro_dataloader 中 dist 为 BucketedBatchSampler 时的表现 测试 set_dist_repro_dataloader 中 dist 为 BucketedBatchSampler 时的表现
此时应该将 batch_sampler 替换为 dist 对应的 BucketedBatchSampler
""" """
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle) dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle)
batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle) batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle)
@@ -130,6 +148,7 @@ class TestSetDistReproDataloader:
assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler)
assert replaced_loader.batch_sampler is batch_sampler assert replaced_loader.batch_sampler is batch_sampler
self.check_distributed_sampler(replaced_loader.batch_sampler) self.check_distributed_sampler(replaced_loader.batch_sampler)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
dist.barrier() dist.barrier()


@@ -138,6 +157,7 @@ class TestSetDistReproDataloader:
def test_set_dist_repro_dataloader_with_dist_sampler(self, shuffle): def test_set_dist_repro_dataloader_with_dist_sampler(self, shuffle):
""" """
测试 set_dist_repro_dataloader 中 dist 为 RandomSampler 时的表现 测试 set_dist_repro_dataloader 中 dist 为 RandomSampler 时的表现
此时应该将 batch_sampler.sampler 替换为 dist 对应的 RandomSampler
""" """
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle) dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle)
sampler = RandomSampler(self.dataset, shuffle=shuffle) sampler = RandomSampler(self.dataset, shuffle=shuffle)
@@ -150,6 +170,7 @@ class TestSetDistReproDataloader:
assert replaced_loader.batch_sampler.sampler is sampler assert replaced_loader.batch_sampler.sampler is sampler
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)


dist.barrier() dist.barrier()
@@ -164,6 +185,7 @@ class TestSetDistReproDataloader:
def test_set_dist_repro_dataloader_with_dist_none_reproducible_true(self): def test_set_dist_repro_dataloader_with_dist_none_reproducible_true(self):
""" """
测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 True 时的表现 测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 True 时的表现
当用户在 driver 之外初始化了分布式环境时,fastnlp 不支持进行断点重训,此时应该报错
""" """
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True) dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
@@ -178,6 +200,8 @@ class TestSetDistReproDataloader:
""" """
测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 有 BucketedBatchSampler 测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 有 BucketedBatchSampler
时的表现 时的表现
此时传入的 dataloader 的 batch_sampler 应该已经执行了 set_distributed,产生一个新的 dataloader,其 batch_sampler
和原 dataloader 相同
""" """
dataloader = DataLoader( dataloader = DataLoader(
self.dataset, self.dataset,
@@ -194,6 +218,7 @@ class TestSetDistReproDataloader:
assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler)
assert replaced_loader.batch_sampler.batch_size == 4 assert replaced_loader.batch_sampler.batch_size == 4
self.check_distributed_sampler(dataloader.batch_sampler) self.check_distributed_sampler(dataloader.batch_sampler)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)


dist.barrier() dist.barrier()


@@ -202,8 +227,10 @@ class TestSetDistReproDataloader:
def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_smpler(self, shuffle): def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_smpler(self, shuffle):
""" """
测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 有 RandomSampler 时的表现 测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 有 RandomSampler 时的表现
此时传入的 dataloader 的 batch_sampler.sampler 应该已经执行了 set_distributed,产生一个新的 dataloader,其
batch_sampler.sampler 和原 dataloader 相同
""" """
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=4)
batch_sampler.sampler = RandomSampler(self.dataset, shuffle) batch_sampler.sampler = RandomSampler(self.dataset, shuffle)
batch_sampler.sampler.set_distributed( batch_sampler.sampler.set_distributed(
num_replicas=self.driver.world_size, num_replicas=self.driver.world_size,
@@ -220,9 +247,11 @@ class TestSetDistReproDataloader:
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler) assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler)
assert replaced_loader.batch_sampler.batch_size == 2
assert replaced_loader.batch_sampler.batch_size == 4
assert replaced_loader.batch_sampler.drop_last == False assert replaced_loader.batch_sampler.drop_last == False
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
dist.barrier() dist.barrier()


@magic_argv_env_context @magic_argv_env_context
@@ -230,6 +259,7 @@ class TestSetDistReproDataloader:
def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_normal(self, shuffle): def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_normal(self, shuffle):
""" """
测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 为一般情况时的表现 测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 为一般情况时的表现
此时直接返回原来的 dataloader,不做任何处理。
""" """
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False) replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False)
@@ -248,6 +278,7 @@ class TestSetDistReproDataloader:
""" """
测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader.batch_sampler 为 ReproducibleBatchSampler 测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader.batch_sampler 为 ReproducibleBatchSampler
的表现 的表现
此时应该返回一个新的 dataloader,其batch_sampler 和原 dataloader 相同,且应该正确地设置了分布式相关的属性
""" """
dataloader = DataLoader( dataloader = DataLoader(
dataset=self.dataset, dataset=self.dataset,
@@ -261,6 +292,7 @@ class TestSetDistReproDataloader:
assert replaced_loader.batch_sampler.batch_size == 4 assert replaced_loader.batch_sampler.batch_size == 4
assert replaced_loader.drop_last == dataloader.drop_last assert replaced_loader.drop_last == dataloader.drop_last
self.check_distributed_sampler(replaced_loader.batch_sampler) self.check_distributed_sampler(replaced_loader.batch_sampler)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
dist.barrier() dist.barrier()


@magic_argv_env_context @magic_argv_env_context
@@ -269,8 +301,10 @@ class TestSetDistReproDataloader:
""" """
测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader.batch_sampler.sampler 为 ReproducibleSampler 测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader.batch_sampler.sampler 为 ReproducibleSampler
的表现 的表现
此时应该返回一个新的 dataloader,其 batch_sampler.sampler 和原 dataloader 相同,且应该正确地设置了分布式相关
的属性
""" """
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2, shuffle=shuffle)
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=4, shuffle=shuffle)
batch_sampler.sampler = RandomSampler(self.dataset, shuffle) batch_sampler.sampler = RandomSampler(self.dataset, shuffle)
dataloader = DataLoader( dataloader = DataLoader(
self.dataset, self.dataset,
@@ -282,9 +316,10 @@ class TestSetDistReproDataloader:
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler) assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler)
assert replaced_loader.batch_sampler.batch_size == 2
assert replaced_loader.batch_sampler.batch_size == 4
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
dist.barrier() dist.barrier()


@magic_argv_env_context @magic_argv_env_context
@@ -292,6 +327,8 @@ class TestSetDistReproDataloader:
def test_set_dist_repro_dataloader_with_dist_dist_dataloader_normal(self, shuffle): def test_set_dist_repro_dataloader_with_dist_dist_dataloader_normal(self, shuffle):
""" """
测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader 为一般情况的表现 测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader 为一般情况的表现
此时应该返回一个新的 dataloader,并替换其 batch_sampler.sampler 为 RandomSampler,且应该正确设置了分布式相关
的属性
""" """
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False) replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False)
@@ -302,6 +339,8 @@ class TestSetDistReproDataloader:
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
dist.barrier() dist.barrier()


""" """
@@ -315,8 +354,10 @@ class TestSetDistReproDataloader:
""" """
测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader.batch_sampler.sampler 为 ReproducibleSampler 测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader.batch_sampler.sampler 为 ReproducibleSampler
的表现 的表现
此时应该返回一个新的 dataloader,且将原来的 Sampler 替换为 UnrepeatedRandomSampler,且正确地设置了分布式相关
的属性
""" """
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=4)
batch_sampler.sampler = RandomSampler(self.dataset, shuffle) batch_sampler.sampler = RandomSampler(self.dataset, shuffle)
dataloader = DataLoader( dataloader = DataLoader(
self.dataset, self.dataset,
@@ -328,9 +369,10 @@ class TestSetDistReproDataloader:
assert isinstance(replaced_loader.batch_sampler, BatchSampler) assert isinstance(replaced_loader.batch_sampler, BatchSampler)
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedRandomSampler) assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedRandomSampler)
assert replaced_loader.batch_sampler.batch_size == 2
assert replaced_loader.batch_sampler.batch_size == 4
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
dist.barrier() dist.barrier()


@magic_argv_env_context @magic_argv_env_context
@@ -339,8 +381,9 @@ class TestSetDistReproDataloader:
""" """
测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader.batch_sampler.sampler 为 UnrepeatedSampler 测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader.batch_sampler.sampler 为 UnrepeatedSampler
的表现 的表现
此时应该返回一个新的 dataloader,且重新实例化了原来的 Sampler
""" """
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=4)
batch_sampler.sampler = UnrepeatedRandomSampler(self.dataset, shuffle) batch_sampler.sampler = UnrepeatedRandomSampler(self.dataset, shuffle)
dataloader = DataLoader( dataloader = DataLoader(
self.dataset, self.dataset,
@@ -353,9 +396,10 @@ class TestSetDistReproDataloader:
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedRandomSampler) assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedRandomSampler)
assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler) assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler)
assert replaced_loader.batch_sampler.batch_size == 2
assert replaced_loader.batch_sampler.batch_size == 4
assert replaced_loader.drop_last == dataloader.drop_last assert replaced_loader.drop_last == dataloader.drop_last
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
dist.barrier() dist.barrier()


@magic_argv_env_context @magic_argv_env_context
@@ -363,6 +407,8 @@ class TestSetDistReproDataloader:
def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_normal(self, shuffle): def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_normal(self, shuffle):
""" """
测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader 为一般情况的表现 测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader 为一般情况的表现
此时应该返回一个新的 dataloader,且将 sampler 替换为 UnrepeatedSequentialSampler,并正确地设置了分布式相关
的属性
""" """
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)
@@ -374,6 +420,7 @@ class TestSetDistReproDataloader:
assert replaced_loader.batch_sampler.batch_size == 4 assert replaced_loader.batch_sampler.batch_size == 4
assert replaced_loader.drop_last == dataloader.drop_last assert replaced_loader.drop_last == dataloader.drop_last
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
dist.barrier() dist.barrier()


def check_distributed_sampler(self, sampler): def check_distributed_sampler(self, sampler):
@@ -385,3 +432,288 @@ class TestSetDistReproDataloader:
if not isinstance(sampler, UnrepeatedSampler): if not isinstance(sampler, UnrepeatedSampler):
assert sampler.pad == True assert sampler.pad == True


def check_set_dist_repro_dataloader(self, dataloader, replaced_loader, shuffle):
    """
    Verify that the result of ``set_dist_repro_dataloader`` is resumable
    under multi-card execution.

    Two batches are consumed from ``replaced_loader``, the sampler state is
    snapshotted, and an equivalent distributed dataloader is rebuilt from
    that snapshot; the already-seen indices plus the remaining indices must
    exactly cover this rank's shard of the dataset, with no overlap.
    """
    # Consume two batches and remember which sample indices were seen.
    num_consumed_batches = 2
    already_seen_idx = set()
    for idx, batch in enumerate(replaced_loader):
        if idx >= num_consumed_batches:
            break
        already_seen_idx.update(batch)
    if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler):
        sampler_states = replaced_loader.batch_sampler.state_dict()
    else:
        sampler_states = replaced_loader.batch_sampler.sampler.state_dict()

    # When present, num_consumed_samples_array maps a batch index to the
    # number of samples consumed up to that batch (presumably — confirm
    # against the sampler implementation).
    num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None)

    # Reload from the snapshot; the remaining content should come out, and for
    # PaddleNormalDataset the sorted indices should form a range.
    left_idxes = set()
    if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler):
        batch_size = replaced_loader.batch_sampler.batch_size
        if num_consumed_samples_array is not None:
            sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches]
        else:
            sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
        # Rebuild the dataloader with an equivalent BucketedBatchSampler.
        new_loader = DataLoader(
            dataset=replaced_loader.dataset,
            batch_sampler=BucketedBatchSampler(
                replaced_loader.dataset,
                length=replaced_loader.dataset._data,
                batch_size=batch_size,
                shuffle=shuffle,
            )
        )
        new_loader.batch_sampler.set_distributed(
            num_replicas=self.driver.world_size,
            rank=self.driver.global_rank,
            pad=True
        )
        new_loader.batch_sampler.load_state_dict(sampler_states)
    else:
        batch_size = replaced_loader.batch_sampler.batch_size
        if num_consumed_samples_array is not None:
            # BUGFIX: index by the number of consumed *batches*, matching the
            # BucketedBatchSampler branch above; indexing by the sample count
            # reads past the intended entry of the per-batch array.
            sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches]
        else:
            sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
        # Rebuild the dataloader with an equivalent distributed RandomSampler.
        batch_sampler = BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size)
        batch_sampler.sampler = RandomSampler(replaced_loader.dataset, shuffle=shuffle)
        batch_sampler.sampler.set_distributed(
            num_replicas=self.driver.world_size,
            rank=self.driver.global_rank
        )
        new_loader = DataLoader(replaced_loader.dataset, batch_sampler=batch_sampler)
        new_loader.batch_sampler.sampler.load_state_dict(sampler_states)
    for idx, batch in enumerate(new_loader):
        left_idxes.update(batch)

    # Seen + remaining must exactly partition this rank's shard of the
    # dataset (integer division: the dataset size is a multiple of the
    # replica count — see setup_method). The stray `assert False` debug
    # leftover that forced every caller to fail has been removed.
    num_replicas = len(self.device)
    assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) // num_replicas
    assert len(left_idxes | already_seen_idx) == len(self.dataset) // num_replicas


############################################################################
#
# 测试 save 和 load 相关的功能
#
############################################################################
class TestSaveLoad:
    """
    Test the behaviour of the save/load related functions on multiple cards.
    """
    def setup_method(self):
        # A fresh dataset and two independent drivers per test, so that a
        # model saved by driver1 can be loaded into driver2 and compared.
        self.dataset = PaddleRandomMaxDataset(20, 10)
        self.driver1, self.driver2 = generate_driver(10, 10), generate_driver(10, 10)

    @magic_argv_env_context
    @pytest.mark.parametrize("only_state_dict", ([True, False]))
    def test_save_and_load_model(self, only_state_dict):
        """
        Test the ``save_model`` and ``load_model`` functions: the reloaded
        model must produce the same predictions as the saved one.
        """
        try:
            path = "model"
            dataloader = DataLoader(self.dataset, batch_size=2)

            if only_state_dict:
                self.driver1.save_model(path, only_state_dict)
            else:
                # Exporting the full model requires an input spec.
                self.driver1.save_model(path, only_state_dict, input_spec=[paddle.ones((4, 10))])

            # Synchronize all ranks before loading.
            dist.barrier()
            self.driver2.load_model(path, only_state_dict)

            for idx, batch in enumerate(dataloader):
                batch = self.driver1.move_data_to_device(batch)
                res1 = self.driver1.model(
                    batch,
                    fastnlp_fn=self.driver1.model._layers.model.evaluate_step,
                    # Driver.model -> DataParallel._layers -> _FleetWrappingModel.model
                    fastnlp_signature_fn=None,
                    wo_auto_param_call=False,
                )
                res2 = self.driver2.model(
                    batch,
                    fastnlp_fn=self.driver2.model._layers.model.evaluate_step,
                    fastnlp_signature_fn=None,
                    wo_auto_param_call=False,
                )

                # Predictions from the saved and the reloaded model must match.
                assert paddle.equal_all(res1["pred"], res2["pred"])
        finally:
            # Clean up on every rank; a full export produces three files.
            if only_state_dict:
                synchronize_safe_rm(path)
            else:
                synchronize_safe_rm(path + ".pdiparams")
                synchronize_safe_rm(path + ".pdiparams.info")
                synchronize_safe_rm(path + ".pdmodel")

    @magic_argv_env_context
    @pytest.mark.parametrize("only_state_dict", ([True, False]))
    @pytest.mark.parametrize("fp16", ([True, False]))
    def test_save_and_load_with_randombatchsampler(self, only_state_dict, fp16):
        # NOTE(review): this bare `return` disables the whole test, and it
        # makes the string literal below unreachable (it is not a docstring).
        # The disabled body references `generate_random_driver`, `Path` and
        # `RandomBatchSampler`, none of which are defined or imported in the
        # visible part of this file — confirm they exist before re-enabling.
        return
        """
        测试save和load函数,主要测试 dataloader 被替换了 sampler 之后的情况
        """

        try:
            path = "model.ckp"

            driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10)
            dataset = PaddleRandomMaxDataset(40, 10)
            dataloader = DataLoader(
                dataset=dataset,
                batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=4), 4, False)
            )
            num_consumed_batches = 2

            # Consume two batches before saving, remembering what was seen.
            already_seen_x_set = set()
            already_seen_y_set = set()
            for idx, batch in enumerate(dataloader):
                if idx >= num_consumed_batches:
                    break
                already_seen_x_set.update(batch["x"])
                already_seen_y_set.update(batch["y"])

            sampler_states = dataloader.batch_sampler.state_dict()
            save_states = {"num_consumed_batches": num_consumed_batches}
            if only_state_dict:
                driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
            else:
                driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])
            # Reload.
            # The batch_size is changed here on purpose.
            dataloader = DataLoader(
                dataset=dataset,
                batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=2, shuffle=True), 2, False)
            )
            load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
            replaced_loader = load_states.pop("dataloader")
            # 1. Check the optimizer state.
            # TODO the optimizer state_dict is always empty

            # 2. Check that the batch_sampler was correctly loaded and replaced.
            assert not (replaced_loader is dataloader)
            assert replaced_loader.batch_sampler is dataloader.batch_sampler
            assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
            assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"]
            assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4

            # 3. Check that the fp16 state was loaded.
            if fp16:
                assert isinstance(driver2.grad_scaler, paddle.amp.GradScaler)

            # 4. Check the model parameters (via identical predictions below).
            # 5. Check batch_idx.
            start_batch = load_states.pop('batch_idx_in_epoch')
            assert start_batch == 2 * num_consumed_batches
            left_x_batches = set()
            left_y_batches = set()
            for idx, batch in enumerate(replaced_loader):

                left_x_batches.update(batch["x"])
                left_y_batches.update(batch["y"])
                res1 = driver1.model.evaluate_step(**batch)
                res2 = driver2.model.evaluate_step(**batch)
                assert paddle.equal_all(res1["pred"], res2["pred"])

            # Seen + remaining must exactly cover the dataset, no overlap.
            assert len(left_x_batches) + len(already_seen_x_set) == len(dataset)
            assert len(left_x_batches | already_seen_x_set) == len(dataset)
            assert len(left_y_batches) + len(already_seen_y_set) == len(dataset)
            assert len(left_y_batches | already_seen_y_set) == len(dataset)
        finally:
            synchronize_safe_rm(path)

    @magic_argv_env_context
    @pytest.mark.parametrize("only_state_dict", ([True, False]))
    @pytest.mark.parametrize("fp16", ([True, False]))
    def test_save_and_load_with_randomsampler(self, only_state_dict, fp16):
        # NOTE(review): disabled the same way as the test above (bare `return`,
        # unreachable string literal, undefined `generate_random_driver` and
        # `Path` in the visible part of this file) — confirm before enabling.
        return
        """
        测试save和load函数,主要测试 dataloader 被替换了 batch_sampler 的情况
        """

        try:
            path = "model.ckp"

            driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10)
            dataset = PaddleRandomMaxDataset(40, 10)
            batch_sampler = BatchSampler(dataset=dataset, batch_size=4)
            batch_sampler.sampler = RandomSampler(dataset, True)
            dataloader = DataLoader(
                dataset,
                batch_sampler=batch_sampler
            )
            num_consumed_batches = 2

            # Consume two batches before saving, remembering what was seen.
            already_seen_x_set = set()
            already_seen_y_set = set()
            for idx, batch in enumerate(dataloader):
                if idx >= num_consumed_batches:
                    break
                already_seen_x_set.update(batch["x"])
                already_seen_y_set.update(batch["y"])

            sampler_states = dataloader.batch_sampler.sampler.state_dict()
            save_states = {"num_consumed_batches": num_consumed_batches}
            if only_state_dict:
                driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
            else:
                driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])
            # Reload.
            # The batch_size is changed here on purpose.
            batch_sampler = BatchSampler(dataset=dataset, batch_size=2)
            batch_sampler.sampler = RandomSampler(dataset, True)
            dataloader = DataLoader(
                dataset,
                batch_sampler=batch_sampler
            )
            load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
            replaced_loader = load_states.pop("dataloader")

            # 1. Check the optimizer state.
            # TODO the optimizer state_dict is always empty

            # 2. Check that the sampler was correctly loaded and replaced.
            assert not (replaced_loader is dataloader)
            assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
            assert replaced_loader.batch_sampler.sampler.seed == sampler_states["seed"]
            assert replaced_loader.batch_sampler.sampler.epoch == sampler_states["epoch"]
            assert replaced_loader.batch_sampler.sampler.num_consumed_samples == 4 * num_consumed_batches
            assert len(replaced_loader.batch_sampler.sampler.dataset) == sampler_states["length"]
            assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"]
            # 3. Check that the fp16 state was loaded.
            if fp16:
                assert isinstance(driver2.grad_scaler, paddle.amp.GradScaler)

            # 4. Check the model parameters (via identical predictions below).
            # 5. Check batch_idx.
            start_batch = load_states.pop('batch_idx_in_epoch')
            assert start_batch == 2 * num_consumed_batches
            left_x_batches = set()
            left_y_batches = set()
            for idx, batch in enumerate(replaced_loader):

                left_x_batches.update(batch["x"])
                left_y_batches.update(batch["y"])
                res1 = driver1.model.evaluate_step(**batch)
                res2 = driver2.model.evaluate_step(**batch)
                assert paddle.equal_all(res1["pred"], res2["pred"])

            # Seen + remaining must exactly cover the dataset, no overlap.
            assert len(left_x_batches) + len(already_seen_x_set) == len(dataset)
            assert len(left_x_batches | already_seen_x_set) == len(dataset)
            assert len(left_y_batches) + len(already_seen_y_set) == len(dataset)
            assert len(left_y_batches | already_seen_y_set) == len(dataset)
        finally:
            synchronize_safe_rm(path)

Loading…
Cancel
Save