|
|
|
@@ -16,18 +16,15 @@ |
|
|
|
"""create train or eval dataset.""" |
|
|
|
|
|
|
|
import os |
|
|
|
|
|
|
|
import mindspore.common.dtype as mstype |
|
|
|
import mindspore.dataset as ds |
|
|
|
import mindspore.dataset.transforms.c_transforms as C2 |
|
|
|
import mindspore.dataset.vision.c_transforms as C |
|
|
|
|
|
|
|
ds.config.set_seed(1) |
|
|
|
import mindspore.dataset.transforms.c_transforms as C2 |
|
|
|
|
|
|
|
|
|
|
|
def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32): |
|
|
|
""" |
|
|
|
Create a train or eval dataset. |
|
|
|
create a train or eval dataset. |
|
|
|
|
|
|
|
Args: |
|
|
|
dataset_path(string): the path of dataset. |
|
|
|
@@ -41,10 +38,14 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32): |
|
|
|
|
|
|
|
device_num = int(os.getenv("RANK_SIZE")) |
|
|
|
rank_id = int(os.getenv("RANK_ID")) |
|
|
|
if device_num == 1: |
|
|
|
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True) |
|
|
|
if do_train: |
|
|
|
if device_num == 1: |
|
|
|
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True) |
|
|
|
else: |
|
|
|
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True, |
|
|
|
num_shards=device_num, shard_id=rank_id) |
|
|
|
else: |
|
|
|
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True, |
|
|
|
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=False, |
|
|
|
num_shards=device_num, shard_id=rank_id) |
|
|
|
|
|
|
|
image_size = 224 |
|
|
|
@@ -54,8 +55,9 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32): |
|
|
|
# define map operations |
|
|
|
if do_train: |
|
|
|
trans = [ |
|
|
|
C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)), |
|
|
|
C.RandomHorizontalFlip(prob=0.5), |
|
|
|
C.Decode(), |
|
|
|
C.Resize((256, 256)), |
|
|
|
C.CenterCrop(image_size), |
|
|
|
C.Normalize(mean=mean, std=std), |
|
|
|
C.HWC2CHW() |
|
|
|
] |
|
|
|
|