|
|
@@ -115,8 +115,8 @@ def create_dataset_imagenet(dataset_path, repeat_num=1, training=True, |
|
|
|
|
|
|
|
|
transform_label = [C.TypeCast(mstype.int32)] |
|
|
transform_label = [C.TypeCast(mstype.int32)] |
|
|
|
|
|
|
|
|
data_set = data_set.map(input_columns="image", num_parallel_workers=8, operations=transform_img) |
|
|
|
|
|
data_set = data_set.map(input_columns="label", num_parallel_workers=8, operations=transform_label) |
|
|
|
|
|
|
|
|
data_set = data_set.map(input_columns="image", num_parallel_workers=12, operations=transform_img) |
|
|
|
|
|
data_set = data_set.map(input_columns="label", num_parallel_workers=4, operations=transform_label) |
|
|
|
|
|
|
|
|
# apply batch operations |
|
|
# apply batch operations |
|
|
data_set = data_set.batch(imagenet_cfg.batch_size, drop_remainder=True) |
|
|
data_set = data_set.batch(imagenet_cfg.batch_size, drop_remainder=True) |
|
|
|