From a5c16ba5c4e153dc5f6edd611f0cb0f2f4507795 Mon Sep 17 00:00:00 2001 From: wangmin0104 Date: Sun, 27 Dec 2020 20:06:44 +0800 Subject: [PATCH] update tests/st/networks/models/resnet50/src/dataset.py. --- tests/st/networks/models/resnet50/src/dataset.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/st/networks/models/resnet50/src/dataset.py b/tests/st/networks/models/resnet50/src/dataset.py index 799b1fed74..0d019c0279 100755 --- a/tests/st/networks/models/resnet50/src/dataset.py +++ b/tests/st/networks/models/resnet50/src/dataset.py @@ -38,10 +38,14 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32): device_num = int(os.getenv("RANK_SIZE")) rank_id = int(os.getenv("RANK_ID")) - if device_num == 1: - data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True) + if do_train: + if device_num == 1: + data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True) + else: + data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True, + num_shards=device_num, shard_id=rank_id) else: - data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True, + data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=False, num_shards=device_num, shard_id=rank_id) image_size = 224