@@ -21,29 +21,38 @@ import mindspore.dataset.vision.c_transforms as C
 from .distributed_sampler import DistributedSampler
 from .datasets import UnalignedDataset, ImageFolderDataset
 
-def create_dataset(args, shuffle=True, max_dataset_size=float("inf")):
+def create_dataset(args):
     """Create dataset"""
     dataroot = args.dataroot
     phase = args.phase
     batch_size = args.batch_size
     device_num = args.device_num
     rank = args.rank
+    shuffle = args.use_random
+    max_dataset_size = args.max_dataset_size
     cores = multiprocessing.cpu_count()
     num_parallel_workers = min(8, int(cores / device_num))
     image_size = args.image_size
     mean = [0.5 * 255] * 3
     std = [0.5 * 255] * 3
     if phase == "train":
-        dataset = UnalignedDataset(dataroot, phase, max_dataset_size=max_dataset_size)
+        dataset = UnalignedDataset(dataroot, phase, max_dataset_size=max_dataset_size, use_random=args.use_random)
         distributed_sampler = DistributedSampler(len(dataset), device_num, rank, shuffle=shuffle)
         ds = de.GeneratorDataset(dataset, column_names=["image_A", "image_B"],
                                  sampler=distributed_sampler, num_parallel_workers=num_parallel_workers)
-        trans = [
-            C.RandomResizedCrop(image_size, scale=(0.5, 1.0), ratio=(0.75, 1.333)),
-            C.RandomHorizontalFlip(prob=0.5),
-            C.Normalize(mean=mean, std=std),
-            C.HWC2CHW()
-        ]
+        if args.use_random:
+            trans = [
+                C.RandomResizedCrop(image_size, scale=(0.5, 1.0), ratio=(0.75, 1.333)),
+                C.RandomHorizontalFlip(prob=0.5),
+                C.Normalize(mean=mean, std=std),
+                C.HWC2CHW()
+            ]
+        else:
+            trans = [
+                C.Resize((image_size, image_size)),
+                C.Normalize(mean=mean, std=std),
+                C.HWC2CHW()
+            ]
         ds = ds.map(operations=trans, input_columns=["image_A"], num_parallel_workers=num_parallel_workers)
         ds = ds.map(operations=trans, input_columns=["image_B"], num_parallel_workers=num_parallel_workers)
         ds = ds.batch(batch_size, drop_remainder=True)
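For context, a minimal usage sketch under the new signature (not part of the patch): create_dataset now reads every option from the parsed-arguments object, so a caller passes only args. The field values and the SimpleNamespace stand-in below are illustrative assumptions, not the project's real argument parser.

# Hypothetical caller; the fields match what create_dataset reads in the train phase.
from types import SimpleNamespace

args = SimpleNamespace(
    dataroot="./data",              # assumed dataset root containing trainA/ and trainB/
    phase="train",
    batch_size=1,
    device_num=1,                   # single device: all workers, no sharding
    rank=0,
    use_random=True,                # True: RandomResizedCrop + flip; False: plain Resize
    max_dataset_size=float("inf"),
    image_size=256,
)
ds = create_dataset(args)

Note that with mean = std = 0.5 * 255, C.Normalize computes (x - 127.5) / 127.5, mapping pixel values from [0, 255] to [-1, 1], so both the random and the deterministic branches feed the network the same value range.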