|
|
|
@@ -125,11 +125,9 @@ def create_dataset_imagenet(dataset_path, |
|
|
|
|
|
|
|
if device_num == 1: |
|
|
|
data_set = ds.ImageFolderDataset(dataset_path, |
|
|
|
num_parallel_workers=8, |
|
|
|
shuffle=True) |
|
|
|
else: |
|
|
|
data_set = ds.ImageFolderDataset(dataset_path, |
|
|
|
num_parallel_workers=8, |
|
|
|
shuffle=True, |
|
|
|
num_shards=device_num, |
|
|
|
shard_id=rank_id) |
|
|
|
@@ -162,11 +160,10 @@ def create_dataset_imagenet(dataset_path, |
|
|
|
type_cast_op = C2.TypeCast(mstype.int32) |
|
|
|
|
|
|
|
data_set = data_set.map(operations=type_cast_op, |
|
|
|
input_columns="label", |
|
|
|
num_parallel_workers=8) |
|
|
|
input_columns="label") |
|
|
|
data_set = data_set.map(operations=trans, |
|
|
|
input_columns="image", |
|
|
|
num_parallel_workers=8) |
|
|
|
num_parallel_workers=10) |
|
|
|
|
|
|
|
# apply batch operations |
|
|
|
data_set = data_set.batch(batch_size, drop_remainder=True) |
|
|
|
|