| @@ -535,7 +535,7 @@ class TestSetDistReproDataloder: | |||||
| # | # | ||||
| ############################################################################ | ############################################################################ | ||||
| def generate_random_driver(features, labels, fp16, device="cpu"): | |||||
| def generate_random_driver(features, labels, fp16=False, device="cpu"): | |||||
| """ | """ | ||||
| 生成driver | 生成driver | ||||
| """ | """ | ||||
| @@ -549,8 +549,8 @@ def generate_random_driver(features, labels, fp16, device="cpu"): | |||||
| @pytest.fixture | @pytest.fixture | ||||
| def prepare_test_save_load(): | def prepare_test_save_load(): | ||||
| dataset = PaddleRandomMaxDataset(320, 10) | |||||
| dataloader = DataLoader(dataset, batch_size=32) | |||||
| dataset = PaddleRandomMaxDataset(40, 10) | |||||
| dataloader = DataLoader(dataset, batch_size=4) | |||||
| driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) | driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) | ||||
| return driver1, driver2, dataloader | return driver1, driver2, dataloader | ||||