From: @zhao_ting_v
Reviewed-by: @guoqi1024, @wuxuejian
Signed-off-by: @guoqi1024
@@ -228,7 +228,7 @@ python export.py --platform [PLATFORM] --G_A_ckpt [G_A_CKPT] --G_B_ckpt [G_B_CKP
 # [Description of Random Situation](#contents)
-In dataset.py, we set the seed inside "create_dataset" function. We also use random seed in train.py.
+In dataset.py, we set the seed inside the "create_dataset" function. We also use a random seed in train.py. If you set --use_random=False, no randomness is used during training.
 # [ModelZoo Homepage](#contents)
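For context on the claim above: determinism requires seeding every RNG the pipeline touches. A minimal sketch of that kind of setup, with assumed names and call sites (the actual seeding code in train.py and dataset.py may differ):

```python
# Hypothetical seeding helper; the function name and call sites are
# assumptions for illustration, not the repo's actual train.py code.
import random
import numpy as np
import mindspore
import mindspore.dataset as de

def seed_everything(seed=1):
    random.seed(seed)         # Python RNG used by UnalignedDataset.__getitem__
    np.random.seed(seed)      # NumPy RNG
    mindspore.set_seed(seed)  # global seed for weight init, dropout, etc.
    de.config.set_seed(seed)  # mindspore.dataset pipeline seed
```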
@@ -21,29 +21,38 @@ import mindspore.dataset.vision.c_transforms as C
 from .distributed_sampler import DistributedSampler
 from .datasets import UnalignedDataset, ImageFolderDataset

-def create_dataset(args, shuffle=True, max_dataset_size=float("inf")):
+def create_dataset(args):
     """Create dataset"""
     dataroot = args.dataroot
     phase = args.phase
     batch_size = args.batch_size
     device_num = args.device_num
     rank = args.rank
+    shuffle = args.use_random
+    max_dataset_size = args.max_dataset_size
     cores = multiprocessing.cpu_count()
     num_parallel_workers = min(8, int(cores / device_num))
     image_size = args.image_size
     mean = [0.5 * 255] * 3
     std = [0.5 * 255] * 3
     if phase == "train":
-        dataset = UnalignedDataset(dataroot, phase, max_dataset_size=max_dataset_size)
+        dataset = UnalignedDataset(dataroot, phase, max_dataset_size=max_dataset_size, use_random=args.use_random)
         distributed_sampler = DistributedSampler(len(dataset), device_num, rank, shuffle=shuffle)
         ds = de.GeneratorDataset(dataset, column_names=["image_A", "image_B"],
                                  sampler=distributed_sampler, num_parallel_workers=num_parallel_workers)
-        trans = [
-            C.RandomResizedCrop(image_size, scale=(0.5, 1.0), ratio=(0.75, 1.333)),
-            C.RandomHorizontalFlip(prob=0.5),
-            C.Normalize(mean=mean, std=std),
-            C.HWC2CHW()
-        ]
+        if args.use_random:
+            trans = [
+                C.RandomResizedCrop(image_size, scale=(0.5, 1.0), ratio=(0.75, 1.333)),
+                C.RandomHorizontalFlip(prob=0.5),
+                C.Normalize(mean=mean, std=std),
+                C.HWC2CHW()
+            ]
+        else:
+            trans = [
+                C.Resize((image_size, image_size)),
+                C.Normalize(mean=mean, std=std),
+                C.HWC2CHW()
+            ]
         ds = ds.map(operations=trans, input_columns=["image_A"], num_parallel_workers=num_parallel_workers)
         ds = ds.map(operations=trans, input_columns=["image_B"], num_parallel_workers=num_parallel_workers)
         ds = ds.batch(batch_size, drop_remainder=True)
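The deterministic branch swaps RandomResizedCrop/RandomHorizontalFlip for a fixed Resize, so with --use_random=False every epoch yields identical tensors. A sketch of the new call pattern; the module paths here are assumptions for illustration:

```python
# Hypothetical caller; import paths are guesses, not the repo's exact layout.
from src.utils.args import get_args
from src.dataset.cyclegan_dataset import create_dataset

args = get_args("train")      # parses --use_random / --max_dataset_size
ds = create_dataset(args)     # shuffle and the size cap now come from args
print(ds.get_dataset_size())  # number of batches per epoch
```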
@@ -53,7 +53,7 @@ class UnalignedDataset:
         Two domain image path list.
     """
-    def __init__(self, dataroot, phase, max_dataset_size=float("inf")):
+    def __init__(self, dataroot, phase, max_dataset_size=float("inf"), use_random=True):
         self.dir_A = os.path.join(dataroot, phase + 'A')
         self.dir_B = os.path.join(dataroot, phase + 'B')
@@ -61,12 +61,14 @@ class UnalignedDataset:
         self.B_paths = sorted(make_dataset(self.dir_B, max_dataset_size))  # load images from '/path/to/data/trainB'
         self.A_size = len(self.A_paths)  # get the size of dataset A
         self.B_size = len(self.B_paths)  # get the size of dataset B
+        self.use_random = use_random

     def __getitem__(self, index):
-        if index % max(self.A_size, self.B_size) == 0:
+        index_B = index % self.B_size
+        if index % max(self.A_size, self.B_size) == 0 and self.use_random:
             random.shuffle(self.A_paths)
+            index_B = random.randint(0, self.B_size - 1)
         A_path = self.A_paths[index % self.A_size]
-        index_B = random.randint(0, self.B_size - 1)
         B_path = self.B_paths[index_B]
         A_img = np.array(Image.open(A_path).convert('RGB'))
         B_img = np.array(Image.open(B_path).convert('RGB'))
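The effect: with use_random=False, `index_B = index % self.B_size` gives a fixed A/B pairing and the A list is never reshuffled, so every epoch visits the same pairs in the same order. A standalone sketch of just the indexing rule (no file I/O; the helper name is mine):

```python
import random

def pair_indices(index, a_size, b_size, use_random):
    """Mirror of the pairing logic in UnalignedDataset.__getitem__."""
    index_b = index % b_size                     # deterministic default
    if index % max(a_size, b_size) == 0 and use_random:
        index_b = random.randint(0, b_size - 1)  # re-drawn only in random mode
    return index % a_size, index_b

# Deterministic mode: the same index always maps to the same pair.
assert pair_indices(5, a_size=4, b_size=3, use_random=False) == (1, 2)
```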
@@ -15,7 +15,7 @@
 """Cycle GAN network."""
 import mindspore.nn as nn
 from mindspore.common import initializer as init

 def init_weights(net, init_type='normal', init_gain=0.02):
     """
@@ -27,12 +27,14 @@ def init_weights(net, init_type='normal', init_gain=0.02):
         init_gain (float): Gain factor for normal and xavier.
     """
-    for cell in net.cells_and_names():
-        if isinstance(cell, nn.Conv2d):
+    for _, cell in net.cells_and_names():
+        if isinstance(cell, (nn.Conv2d, nn.Conv2dTranspose)):
             if init_type == 'normal':
-                cell.weight.set_data(init.initializer(init.Normal(init_gain)))
+                cell.weight.set_data(init.initializer(init.Normal(init_gain), cell.weight.shape))
             elif init_type == 'xavier':
-                cell.weight.set_data(init.initializer(init.XavierUniform(init_gain)))
+                cell.weight.set_data(init.initializer(init.XavierUniform(init_gain), cell.weight.shape))
+            elif init_type == 'constant':
+                cell.weight.set_data(init.initializer(0.001, cell.weight.shape))
             else:
                 raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
         elif isinstance(cell, nn.BatchNorm2d):
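Two fixes are bundled here: `cells_and_names()` yields `(name, cell)` pairs, so the old loop compared tuples against `nn.Conv2d` and never matched; and `init.initializer` needs the target shape to build a tensor matching the parameter it overwrites. A minimal sketch of the corrected call pattern:

```python
# Minimal sketch mirroring the fixed calls above.
import mindspore.nn as nn
from mindspore.common import initializer as init

conv = nn.Conv2d(3, 64, 4)  # weight shape (64, 3, 4, 4)
conv.weight.set_data(init.initializer(init.Normal(0.02), conv.weight.shape))
# The new 'constant' branch (used when --use_random=False) passes a plain
# number instead of an Initializer object; initializer() accepts both:
conv.weight.set_data(init.initializer(0.001, conv.weight.shape))
```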
@@ -105,6 +105,9 @@ def get_args(phase):
     parser.add_argument('--save_imgs', type=ast.literal_eval, default=True, \
                         help='whether save imgs when epoch end, if True result images will generate in '
                              '`outputs_dir/imgs`, default is True.')
+    parser.add_argument('--use_random', type=ast.literal_eval, default=True, \
+                        help='whether use random when training, default is True.')
+    parser.add_argument('--max_dataset_size', type=int, default=None, help='max images per epoch, default is None.')
     if phase == "export":
         parser.add_argument("--file_name", type=str, default="cyclegan", help="output file name prefix.")
         parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', \
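A note on `type=ast.literal_eval` for boolean flags: the command-line value is parsed as a Python literal, so `--use_random False` really becomes the boolean `False`, whereas `type=bool` would treat any non-empty string as truthy:

```python
import ast

assert ast.literal_eval("False") is False  # '--use_random False' -> False
assert ast.literal_eval("True") is True
assert bool("False") is True               # why plain type=bool won't work
```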
@@ -140,6 +143,14 @@ def get_args(phase):
     if args.dataroot is None and (phase in ["train", "predict"]):
         raise ValueError('Must set dataroot!')
+    if not args.use_random:
+        args.need_dropout = False
+        args.init_type = "constant"
+    if args.max_dataset_size is None:
+        args.max_dataset_size = float("inf")
+    args.n_epochs = min(args.max_epoch, args.n_epochs)
     args.n_epochs_decay = args.max_epoch - args.n_epochs
     args.phase = phase
     return args
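Disabling randomness cascades here: dropout (itself a source of randomness) is turned off and weight init switches to the new constant branch. The `min()` clamp also keeps the epoch split sane when max_epoch is small; a quick check of that arithmetic with illustrative numbers:

```python
# Epoch split after the new clamp; the numbers are illustrative only.
def split_epochs(max_epoch, n_epochs):
    n_epochs = min(max_epoch, n_epochs)    # clamp introduced above
    n_epochs_decay = max_epoch - n_epochs  # now guaranteed >= 0
    return n_epochs, n_epochs_decay

assert split_epochs(200, 100) == (100, 100)  # constant LR, then linear decay
assert split_epochs(50, 100) == (50, 0)      # without the clamp: decay = -50
```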