| @@ -37,24 +37,31 @@ def create_dataset(dataset_path, do_train, config, platform, repeat_num=1, batch | |||
| if platform == "Ascend": | |||
| rank_size = int(os.getenv("RANK_SIZE")) | |||
| rank_id = int(os.getenv("RANK_ID")) | |||
| if rank_size == 1: | |||
| ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True) | |||
| if do_train: | |||
| if rank_size == 1: | |||
| ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True) | |||
| else: | |||
| ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True, | |||
| num_shards=rank_size, shard_id=rank_id) | |||
| else: | |||
| ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True, | |||
| num_shards=rank_size, shard_id=rank_id) | |||
| ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=False) | |||
| elif platform == "GPU": | |||
| if do_train: | |||
| from mindspore.communication.management import get_rank, get_group_size | |||
| ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True, | |||
| num_shards=get_group_size(), shard_id=get_rank()) | |||
| else: | |||
| ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True) | |||
| ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=False) | |||
| else: | |||
| raise ValueError("Unsupport platform.") | |||
| resize_height = config.image_height | |||
| resize_width = config.image_width | |||
| buffer_size = 1000 | |||
| if do_train: | |||
| buffer_size = 20480 | |||
| # apply shuffle operations | |||
| ds = ds.shuffle(buffer_size=buffer_size) | |||
| # define map operations | |||
| decode_op = C.Decode() | |||
| @@ -63,23 +70,23 @@ def create_dataset(dataset_path, do_train, config, platform, repeat_num=1, batch | |||
| resize_op = C.Resize((256, 256)) | |||
| center_crop = C.CenterCrop(resize_width) | |||
| rescale_op = C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4) | |||
| random_color_op = C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4) | |||
| normalize_op = C.Normalize(mean=[0.485*255, 0.456*255, 0.406*255], std=[0.229*255, 0.224*255, 0.225*255]) | |||
| change_swap_op = C.HWC2CHW() | |||
| transform_uniform = [horizontal_flip_op, random_color_op] | |||
| uni_aug = C.UniformAugment(operations=transform_uniform, num_ops=2) | |||
| if do_train: | |||
| trans = [resize_crop_op, horizontal_flip_op, rescale_op, normalize_op, change_swap_op] | |||
| trans = [resize_crop_op, uni_aug, normalize_op, change_swap_op] | |||
| else: | |||
| trans = [decode_op, resize_op, center_crop, normalize_op, change_swap_op] | |||
| type_cast_op = C2.TypeCast(mstype.int32) | |||
| ds = ds.map(input_columns="image", operations=trans, num_parallel_workers=8) | |||
| ds = ds.map(input_columns="image", operations=trans, num_parallel_workers=16) | |||
| ds = ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=8) | |||
| # apply shuffle operations | |||
| ds = ds.shuffle(buffer_size=buffer_size) | |||
| # apply batch operations | |||
| ds = ds.batch(batch_size, drop_remainder=True) | |||