Browse Source

paddle fleet set_dist_repro_dataloader的测试例

tags/v1.0.0alpha
x54-729 3 years ago
parent
commit
de544707d9
2 changed files with 347 additions and 17 deletions
  1. +1
    -3
      fastNLP/core/drivers/paddle_driver/fleet.py
  2. +346
    -14
      tests/core/drivers/paddle_driver/test_fleet.py

+ 1
- 3
fastNLP/core/drivers/paddle_driver/fleet.py View File

@@ -1,6 +1,5 @@
import os
import shutil
from functools import partial
from typing import List, Union, Optional, Dict, Tuple, Callable

from .paddle_driver import PaddleDriver
@@ -38,7 +37,6 @@ if _NEED_IMPORT_PADDLE:
from paddle import DataParallel
import paddle.distributed.fleet as fleet
import paddle.distributed as paddledist
from paddle.io import BatchSampler
from paddle.optimizer import Optimizer
from paddle.fluid.reader import _DatasetKind
from paddle.fluid.dygraph import parallel_helper
@@ -305,7 +303,7 @@ class PaddleFleetDriver(PaddleDriver):
raise RuntimeError(f"There is no `{fn}` method in your model.")
else:
if hasattr(model, fn):
logger.warning("Notice your model is a `DistributedDataParallel` model. And your model also implements "
logger.warning("Notice your model is a `DataParallel` model. And your model also implements "
f"the `{fn}` method, which we can not call actually, we will"
" call `forward` function instead of `train_step` and you should note that.")
elif fn not in {"train_step", "evaluate_step"}:


+ 346
- 14
tests/core/drivers/paddle_driver/test_fleet.py View File

@@ -12,28 +12,44 @@ from fastNLP.core.samplers import (
UnrepeatedSequentialSampler,
)
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
from tests.helpers.datasets.paddle_data import PaddleNormalDataset
from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset
from tests.helpers.utils import magic_argv_env_context
from fastNLP.core import synchronize_safe_rm

import paddle
import paddle.distributed as dist
from paddle.io import DataLoader, BatchSampler

def generate_driver(num_labels, feature_dimension, device=None, fp16=False):
    """
    Build a ``PaddleFleetDriver`` around a small classification model and set it up.

    :param num_labels: first argument forwarded to ``PaddleNormalModel_Classification_1``.
    :param feature_dimension: second argument forwarded to ``PaddleNormalModel_Classification_1``.
    :param device: value passed to the driver as ``parallel_device``; ``None`` (the default)
        selects ``[0, 1]``.
    :param fp16: forwarded to the driver's ``fp16`` argument.
    :return: the driver, after ``set_optimizers`` and ``setup`` have been called on it.
    """
    if device is None:
        # Build the default list per call so that no list object is shared between drivers.
        device = [0, 1]

    paddle_model = PaddleNormalModel_Classification_1(num_labels, feature_dimension)
    paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
    driver = PaddleFleetDriver(
        model=paddle_model,
        parallel_device=device,
        fp16=fp16,
    )
    driver.set_optimizers(paddle_opt)
    driver.setup()

    return driver

@magic_argv_env_context
def test_multi_drivers():
    """
    Check the behaviour when several ``PaddleFleetDriver`` objects are created.

    Two drivers are built with the default device list of ``generate_driver``; building a
    third one with a different device list is expected to raise ``RuntimeError``.
    """
    # Both drivers stay referenced while the conflicting one is requested.
    first_driver = generate_driver(10, 10)
    second_driver = generate_driver(20, 10)

    # A different device setting should be reported as an error.
    with pytest.raises(RuntimeError):
        generate_driver(20, 3, device=[0,2])

    dist.barrier()

############################################################################
#
# 测试PaddleFleetDriver的一些函数
# 测试 PaddleFleetDriver 的一些函数
#
############################################################################

@@ -106,10 +122,11 @@ class TestSetDistReproDataloader:

@classmethod
def setup_class(cls):
cls.driver = generate_driver(10, 10)
cls.device = [0, 1]
cls.driver = generate_driver(10, 10, device=cls.device)

def setup_method(self):
self.dataset = PaddleNormalDataset(20)
self.dataset = PaddleNormalDataset(40)

"""
传入的 `dist` 参数为具体的 ReproducibleSampler 或 ReproducibleBatchSampler 的情况
@@ -121,6 +138,7 @@ class TestSetDistReproDataloader:
def test_set_dist_repro_dataloader_with_dist_batch_sampler(self, shuffle):
"""
测试 set_dist_repro_dataloader 中 dist 为 BucketedBatchSampler 时的表现
此时应该将 batch_sampler 替换为 dist 对应的 BucketedBatchSampler
"""
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle)
batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle)
@@ -130,6 +148,7 @@ class TestSetDistReproDataloader:
assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler)
assert replaced_loader.batch_sampler is batch_sampler
self.check_distributed_sampler(replaced_loader.batch_sampler)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
dist.barrier()

@@ -138,6 +157,7 @@ class TestSetDistReproDataloader:
def test_set_dist_repro_dataloader_with_dist_sampler(self, shuffle):
"""
测试 set_dist_repro_dataloader 中 dist 为 RandomSampler 时的表现
此时应该将 batch_sampler.sampler 替换为 dist 对应的 RandomSampler
"""
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle)
sampler = RandomSampler(self.dataset, shuffle=shuffle)
@@ -150,6 +170,7 @@ class TestSetDistReproDataloader:
assert replaced_loader.batch_sampler.sampler is sampler
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

dist.barrier()
@@ -164,6 +185,7 @@ class TestSetDistReproDataloader:
def test_set_dist_repro_dataloader_with_dist_none_reproducible_true(self):
"""
测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 True 时的表现
当用户在 driver 之外初始化了分布式环境时,fastnlp 不支持进行断点重训,此时应该报错
"""
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
with pytest.raises(RuntimeError):
@@ -178,6 +200,8 @@ class TestSetDistReproDataloader:
"""
测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 有 BucketedBatchSampler
时的表现
此时传入的 dataloader 的 batch_sampler 应该已经执行了 set_distributed,产生一个新的 dataloader,其 batch_sampler
和原 dataloader 相同
"""
dataloader = DataLoader(
self.dataset,
@@ -194,6 +218,7 @@ class TestSetDistReproDataloader:
assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler)
assert replaced_loader.batch_sampler.batch_size == 4
self.check_distributed_sampler(dataloader.batch_sampler)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

dist.barrier()

@@ -202,8 +227,10 @@ class TestSetDistReproDataloader:
def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_smpler(self, shuffle):
"""
测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 有 RandomSampler 时的表现
此时传入的 dataloader 的 batch_sampler.sampler 应该已经执行了 set_distributed,产生一个新的 dataloader,其
batch_sampler.sampler 和原 dataloader 相同
"""
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=4)
batch_sampler.sampler = RandomSampler(self.dataset, shuffle)
batch_sampler.sampler.set_distributed(
num_replicas=self.driver.world_size,
@@ -220,9 +247,11 @@ class TestSetDistReproDataloader:
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler)
assert replaced_loader.batch_sampler.batch_size == 2
assert replaced_loader.batch_sampler.batch_size == 4
assert replaced_loader.batch_sampler.drop_last == False
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
dist.barrier()

@magic_argv_env_context
@@ -230,6 +259,7 @@ class TestSetDistReproDataloader:
def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_normal(self, shuffle):
"""
测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 为一般情况时的表现
此时直接返回原来的 dataloader,不做任何处理。
"""
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False)
@@ -248,6 +278,7 @@ class TestSetDistReproDataloader:
"""
测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader.batch_sampler 为 ReproducibleBatchSampler
的表现
此时应该返回一个新的 dataloader,其batch_sampler 和原 dataloader 相同,且应该正确地设置了分布式相关的属性
"""
dataloader = DataLoader(
dataset=self.dataset,
@@ -261,6 +292,7 @@ class TestSetDistReproDataloader:
assert replaced_loader.batch_sampler.batch_size == 4
assert replaced_loader.drop_last == dataloader.drop_last
self.check_distributed_sampler(replaced_loader.batch_sampler)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
dist.barrier()

@magic_argv_env_context
@@ -269,8 +301,10 @@ class TestSetDistReproDataloader:
"""
测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader.batch_sampler.sampler 为 ReproducibleSampler
的表现
此时应该返回一个新的 dataloader,其 batch_sampler.sampler 和原 dataloader 相同,且应该正确地设置了分布式相关
的属性
"""
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2, shuffle=shuffle)
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=4, shuffle=shuffle)
batch_sampler.sampler = RandomSampler(self.dataset, shuffle)
dataloader = DataLoader(
self.dataset,
@@ -282,9 +316,10 @@ class TestSetDistReproDataloader:
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler)
assert replaced_loader.batch_sampler.batch_size == 2
assert replaced_loader.batch_sampler.batch_size == 4
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
dist.barrier()

@magic_argv_env_context
@@ -292,6 +327,8 @@ class TestSetDistReproDataloader:
def test_set_dist_repro_dataloader_with_dist_dist_dataloader_normal(self, shuffle):
"""
测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader 为一般情况的表现
此时应该返回一个新的 dataloader,并替换其 batch_sampler.sampler 为 RandomSampler,且应该正确设置了分布式相关
的属性
"""
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False)
@@ -302,6 +339,8 @@ class TestSetDistReproDataloader:
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
dist.barrier()

"""
@@ -315,8 +354,10 @@ class TestSetDistReproDataloader:
"""
测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader.batch_sampler.sampler 为 ReproducibleSampler
的表现
此时应该返回一个新的 dataloader,且将原来的 Sampler 替换为 UnrepeatedRandomSampler,且正确地设置了分布式相关
的属性
"""
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=4)
batch_sampler.sampler = RandomSampler(self.dataset, shuffle)
dataloader = DataLoader(
self.dataset,
@@ -328,9 +369,10 @@ class TestSetDistReproDataloader:
assert isinstance(replaced_loader.batch_sampler, BatchSampler)
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedRandomSampler)
assert replaced_loader.batch_sampler.batch_size == 2
assert replaced_loader.batch_sampler.batch_size == 4
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
dist.barrier()

@magic_argv_env_context
@@ -339,8 +381,9 @@ class TestSetDistReproDataloader:
"""
测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader.batch_sampler.sampler 为 UnrepeatedSampler
的表现
此时应该返回一个新的 dataloader,且重新实例化了原来的 Sampler
"""
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=4)
batch_sampler.sampler = UnrepeatedRandomSampler(self.dataset, shuffle)
dataloader = DataLoader(
self.dataset,
@@ -353,9 +396,10 @@ class TestSetDistReproDataloader:
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedRandomSampler)
assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler)
assert replaced_loader.batch_sampler.batch_size == 2
assert replaced_loader.batch_sampler.batch_size == 4
assert replaced_loader.drop_last == dataloader.drop_last
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
dist.barrier()

@magic_argv_env_context
@@ -363,6 +407,8 @@ class TestSetDistReproDataloader:
def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_normal(self, shuffle):
"""
测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader 为一般情况的表现
此时应该返回一个新的 dataloader,且将 sampler 替换为 UnrepeatedSequentialSampler,并正确地设置了分布式相关
的属性
"""
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)
@@ -374,6 +420,7 @@ class TestSetDistReproDataloader:
assert replaced_loader.batch_sampler.batch_size == 4
assert replaced_loader.drop_last == dataloader.drop_last
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
dist.barrier()

def check_distributed_sampler(self, sampler):
@@ -385,3 +432,288 @@ class TestSetDistReproDataloader:
if not isinstance(sampler, UnrepeatedSampler):
assert sampler.pad == True

def check_set_dist_repro_dataloader(self, dataloader, replaced_loader, shuffle):
    """
    Check that the result of ``set_dist_repro_dataloader`` can be interrupted and resumed
    in the multi-card setting.

    Two batches are drawn from ``replaced_loader``, the state of its batch sampler (for a
    ``BucketedBatchSampler``) or of its sampler (otherwise) is read, and that state is loaded
    into a newly built dataloader. The items seen before and after the reload must not overlap
    and together must amount to ``len(self.dataset) / len(self.device)`` items.

    :param dataloader: the dataloader that was passed to ``set_dist_repro_dataloader``;
        not used by this check.
    :param replaced_loader: the dataloader returned by ``set_dist_repro_dataloader``.
    :param shuffle: shuffle flag used when the sampler is rebuilt.
    """
    # Iterate over two batches.
    num_consumed_batches = 2
    already_seen_idx = set()
    # NOTE(review): if a batch is made of tensor elements, the sets used here may compare
    # them by identity rather than by value - confirm against the dataset helpers.
    for idx, batch in enumerate(replaced_loader):
        if idx >= num_consumed_batches:
            break
        already_seen_idx.update(batch)

    is_bucketed = isinstance(replaced_loader.batch_sampler, BucketedBatchSampler)
    if is_bucketed:
        sampler_states = replaced_loader.batch_sampler.state_dict()
    else:
        sampler_states = replaced_loader.batch_sampler.sampler.state_dict()

    # Take out num_consumed_samples_array; when present it is used below to set the number
    # of samples that count as consumed.
    num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None)

    # After reloading, the remaining content should be produced; for PaddleNormalDataset the
    # sorted result should be a range.
    batch_size = replaced_loader.batch_sampler.batch_size
    if is_bucketed:
        if num_consumed_samples_array is not None:
            # For the batch sampler the array is indexed by the number of consumed batches.
            sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches]
        else:
            sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
        # Rebuild the dataloader.
        new_loader = DataLoader(
            dataset=replaced_loader.dataset,
            batch_sampler=BucketedBatchSampler(
                replaced_loader.dataset,
                length=replaced_loader.dataset._data,
                batch_size=batch_size,
                shuffle=shuffle,
            )
        )
        new_loader.batch_sampler.set_distributed(
            num_replicas=self.driver.world_size,
            rank=self.driver.global_rank,
            pad=True
        )
        new_loader.batch_sampler.load_state_dict(sampler_states)
    else:
        num_consumed_samples = num_consumed_batches * batch_size
        if num_consumed_samples_array is not None:
            # For the sampler the array is indexed by the number of consumed samples.
            sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_samples]
        else:
            sampler_states["num_consumed_samples"] = num_consumed_samples
        # Rebuild the dataloader.
        batch_sampler = BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size)
        batch_sampler.sampler = RandomSampler(replaced_loader.dataset, shuffle=shuffle)
        batch_sampler.sampler.set_distributed(
            num_replicas=self.driver.world_size,
            rank=self.driver.global_rank
        )
        new_loader = DataLoader(replaced_loader.dataset, batch_sampler=batch_sampler)
        new_loader.batch_sampler.sampler.load_state_dict(sampler_states)

    left_idxes = set()
    for batch in new_loader:
        left_idxes.update(batch)

    # Each replica is expected to see its share of the dataset exactly once.
    num_replicas = len(self.device)
    expected_per_replica = len(self.dataset) / num_replicas
    assert len(left_idxes) + len(already_seen_idx) == expected_per_replica
    assert len(left_idxes | already_seen_idx) == expected_per_replica


############################################################################
#
# 测试 save 和 load 相关的功能
#
############################################################################
class TestSaveLoad:
    """
    Tests for the save and load related functions in the multi-card setting.
    """

    def setup_method(self):
        # A fresh dataset and two separately built drivers for every test method.
        self.dataset = PaddleRandomMaxDataset(20, 10)
        self.driver1, self.driver2 = generate_driver(10, 10), generate_driver(10, 10)

    @magic_argv_env_context
    @pytest.mark.parametrize("only_state_dict", ([True, False]))
    def test_save_and_load_model(self, only_state_dict):
        """
        Test the ``save_model`` and ``load_model`` functions.

        ``driver1`` saves its model, ``driver2`` loads it, and both drivers must then return
        equal ``pred`` results for every batch of the dataset.
        """
        # Bound before ``try`` so that the cleanup in ``finally`` can always refer to it.
        path = "model"
        try:
            dataloader = DataLoader(self.dataset, batch_size=2)

            if only_state_dict:
                self.driver1.save_model(path, only_state_dict)
            else:
                self.driver1.save_model(path, only_state_dict, input_spec=[paddle.ones((4, 10))])

            # Synchronize before loading.
            dist.barrier()
            self.driver2.load_model(path, only_state_dict)

            for batch in dataloader:
                batch = self.driver1.move_data_to_device(batch)
                res1 = self.driver1.model(
                    batch,
                    fastnlp_fn=self.driver1.model._layers.model.evaluate_step,
                    # Driver.model -> DataParallel._layers -> _FleetWrappingModel.model
                    fastnlp_signature_fn=None,
                    wo_auto_param_call=False,
                )
                res2 = self.driver2.model(
                    batch,
                    fastnlp_fn=self.driver2.model._layers.model.evaluate_step,
                    fastnlp_signature_fn=None,
                    wo_auto_param_call=False,
                )

                assert paddle.equal_all(res1["pred"], res2["pred"])
        finally:
            if only_state_dict:
                synchronize_safe_rm(path)
            else:
                synchronize_safe_rm(path + ".pdiparams")
                synchronize_safe_rm(path + ".pdiparams.info")
                synchronize_safe_rm(path + ".pdmodel")

    @magic_argv_env_context
    @pytest.mark.skip(reason="disabled: the original body returned immediately without running any checks")
    @pytest.mark.parametrize("only_state_dict", ([True, False]))
    @pytest.mark.parametrize("fp16", ([True, False]))
    def test_save_and_load_with_randombatchsampler(self, only_state_dict, fp16):
        """
        Test ``save`` and ``load`` for a dataloader whose batch sampler is a
        ``RandomBatchSampler``.

        Skipped for now; the body is kept for when the test is enabled.
        """
        # NOTE(review): `generate_random_driver`, `RandomBatchSampler` and `Path` are not
        # defined or imported in the visible part of this file - confirm before enabling.
        path = "model.ckp"
        try:
            driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10)
            dataset = PaddleRandomMaxDataset(40, 10)
            dataloader = DataLoader(
                dataset=dataset,
                batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=4), 4, False)
            )
            num_consumed_batches = 2

            already_seen_x_set = set()
            already_seen_y_set = set()
            for idx, batch in enumerate(dataloader):
                if idx >= num_consumed_batches:
                    break
                already_seen_x_set.update(batch["x"])
                already_seen_y_set.update(batch["y"])

            sampler_states = dataloader.batch_sampler.state_dict()
            save_states = {"num_consumed_batches": num_consumed_batches}
            if only_state_dict:
                driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
            else:
                driver1.save(
                    Path(path), save_states, dataloader, only_state_dict,
                    should_save_model=True, input_spec=[paddle.ones((16, 10))]
                )
            # Load.
            # Use a different batch_size.
            dataloader = DataLoader(
                dataset=dataset,
                batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=2, shuffle=True), 2, False)
            )
            load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
            replaced_loader = load_states.pop("dataloader")
            # 1. Check the state of the optimizer.
            # TODO the state_dict of the optimizer is always empty

            # 2. Check that the batch_sampler was loaded and replaced correctly.
            assert not (replaced_loader is dataloader)
            assert replaced_loader.batch_sampler is dataloader.batch_sampler
            assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
            assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"]
            assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4

            # 3. Check that fp16 was loaded.
            if fp16:
                assert isinstance(driver2.grad_scaler, paddle.amp.GradScaler)

            # 4. Check that the parameters of the model are correct.
            # 5. Check batch_idx.
            start_batch = load_states.pop('batch_idx_in_epoch')
            assert start_batch == 2 * num_consumed_batches
            left_x_batches = set()
            left_y_batches = set()
            for idx, batch in enumerate(replaced_loader):

                left_x_batches.update(batch["x"])
                left_y_batches.update(batch["y"])
                res1 = driver1.model.evaluate_step(**batch)
                res2 = driver2.model.evaluate_step(**batch)
                assert paddle.equal_all(res1["pred"], res2["pred"])

            assert len(left_x_batches) + len(already_seen_x_set) == len(dataset)
            assert len(left_x_batches | already_seen_x_set) == len(dataset)
            assert len(left_y_batches) + len(already_seen_y_set) == len(dataset)
            assert len(left_y_batches | already_seen_y_set) == len(dataset)
        finally:
            synchronize_safe_rm(path)

    @magic_argv_env_context
    @pytest.mark.skip(reason="disabled: the original body returned immediately without running any checks")
    @pytest.mark.parametrize("only_state_dict", ([True, False]))
    @pytest.mark.parametrize("fp16", ([True, False]))
    def test_save_and_load_with_randomsampler(self, only_state_dict, fp16):
        """
        Test ``save`` and ``load`` for a dataloader whose ``batch_sampler.sampler`` is a
        ``RandomSampler``.

        Skipped for now; the body is kept for when the test is enabled.
        """
        # NOTE(review): `generate_random_driver` and `Path` are not defined or imported in the
        # visible part of this file - confirm before enabling.
        path = "model.ckp"
        try:
            driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10)
            dataset = PaddleRandomMaxDataset(40, 10)
            batch_sampler = BatchSampler(dataset=dataset, batch_size=4)
            batch_sampler.sampler = RandomSampler(dataset, True)
            dataloader = DataLoader(
                dataset,
                batch_sampler=batch_sampler
            )
            num_consumed_batches = 2

            already_seen_x_set = set()
            already_seen_y_set = set()
            for idx, batch in enumerate(dataloader):
                if idx >= num_consumed_batches:
                    break
                already_seen_x_set.update(batch["x"])
                already_seen_y_set.update(batch["y"])

            sampler_states = dataloader.batch_sampler.sampler.state_dict()
            save_states = {"num_consumed_batches": num_consumed_batches}
            if only_state_dict:
                driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
            else:
                driver1.save(
                    Path(path), save_states, dataloader, only_state_dict,
                    should_save_model=True, input_spec=[paddle.ones((16, 10))]
                )
            # Load.
            # Use a different batch_size.
            batch_sampler = BatchSampler(dataset=dataset, batch_size=2)
            batch_sampler.sampler = RandomSampler(dataset, True)
            dataloader = DataLoader(
                dataset,
                batch_sampler=batch_sampler
            )
            load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
            replaced_loader = load_states.pop("dataloader")

            # 1. Check the state of the optimizer.
            # TODO the state_dict of the optimizer is always empty

            # 2. Check that the sampler was loaded and replaced correctly.
            assert not (replaced_loader is dataloader)
            assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
            assert replaced_loader.batch_sampler.sampler.seed == sampler_states["seed"]
            assert replaced_loader.batch_sampler.sampler.epoch == sampler_states["epoch"]
            assert replaced_loader.batch_sampler.sampler.num_consumed_samples == 4 * num_consumed_batches
            assert len(replaced_loader.batch_sampler.sampler.dataset) == sampler_states["length"]
            assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"]
            # 3. Check that fp16 was loaded.
            if fp16:
                assert isinstance(driver2.grad_scaler, paddle.amp.GradScaler)

            # 4. Check that the parameters of the model are correct.
            # 5. Check batch_idx.
            start_batch = load_states.pop('batch_idx_in_epoch')
            assert start_batch == 2 * num_consumed_batches
            left_x_batches = set()
            left_y_batches = set()
            for idx, batch in enumerate(replaced_loader):

                left_x_batches.update(batch["x"])
                left_y_batches.update(batch["y"])
                res1 = driver1.model.evaluate_step(**batch)
                res2 = driver2.model.evaluate_step(**batch)
                assert paddle.equal_all(res1["pred"], res2["pred"])

            assert len(left_x_batches) + len(already_seen_x_set) == len(dataset)
            assert len(left_x_batches | already_seen_x_set) == len(dataset)
            assert len(left_y_batches) + len(already_seen_y_set) == len(dataset)
            assert len(left_y_batches | already_seen_y_set) == len(dataset)
        finally:
            synchronize_safe_rm(path)

Loading…
Cancel
Save