|
|
@@ -169,6 +169,7 @@ def create_dataset3(dataset_path, do_train, repeat_num=1, batch_size=32, target= |
|
|
device_num = get_group_size() |
|
|
device_num = get_group_size() |
|
|
else: |
|
|
else: |
|
|
device_num = 1 |
|
|
device_num = 1 |
|
|
|
|
|
rank_id = 1 |
|
|
if device_num == 1: |
|
|
if device_num == 1: |
|
|
ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True) |
|
|
ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True) |
|
|
else: |
|
|
else: |
|
|
|