|
|
|
@@ -38,10 +38,14 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32): |
|
|
|
|
|
|
|
device_num = int(os.getenv("RANK_SIZE")) |
|
|
|
rank_id = int(os.getenv("RANK_ID")) |
|
|
|
if device_num == 1: |
|
|
|
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True) |
|
|
|
if do_train: |
|
|
|
if device_num == 1: |
|
|
|
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True) |
|
|
|
else: |
|
|
|
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True, |
|
|
|
num_shards=device_num, shard_id=rank_id) |
|
|
|
else: |
|
|
|
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True, |
|
|
|
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=False, |
|
|
|
num_shards=device_num, shard_id=rank_id) |
|
|
|
|
|
|
|
image_size = 224 |
|
|
|
|