| @@ -125,11 +125,9 @@ def create_dataset_imagenet(dataset_path, | |||||
| if device_num == 1: | if device_num == 1: | ||||
| data_set = ds.ImageFolderDataset(dataset_path, | data_set = ds.ImageFolderDataset(dataset_path, | ||||
| num_parallel_workers=8, | |||||
| shuffle=True) | shuffle=True) | ||||
| else: | else: | ||||
| data_set = ds.ImageFolderDataset(dataset_path, | data_set = ds.ImageFolderDataset(dataset_path, | ||||
| num_parallel_workers=8, | |||||
| shuffle=True, | shuffle=True, | ||||
| num_shards=device_num, | num_shards=device_num, | ||||
| shard_id=rank_id) | shard_id=rank_id) | ||||
| @@ -162,11 +160,10 @@ def create_dataset_imagenet(dataset_path, | |||||
| type_cast_op = C2.TypeCast(mstype.int32) | type_cast_op = C2.TypeCast(mstype.int32) | ||||
| data_set = data_set.map(operations=type_cast_op, | data_set = data_set.map(operations=type_cast_op, | ||||
| input_columns="label", | |||||
| num_parallel_workers=8) | |||||
| input_columns="label") | |||||
| data_set = data_set.map(operations=trans, | data_set = data_set.map(operations=trans, | ||||
| input_columns="image", | input_columns="image", | ||||
| num_parallel_workers=8) | |||||
| num_parallel_workers=10) | |||||
| # apply batch operations | # apply batch operations | ||||
| data_set = data_set.batch(batch_size, drop_remainder=True) | data_set = data_set.batch(batch_size, drop_remainder=True) | ||||