"""
Copyright (C) 2019 NVIDIA Corporation.  All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""

import json

from data.pix2pix_dataset import Pix2pixDataset
from data.image_folder import make_dataset


class CustomDataset(Pix2pixDataset):
    """Dataset that loads images from directories.

    Use the options --label_dir, --image_dir and --instance_dir to specify
    the directories. The images in the directories are sorted in
    alphabetical order and paired in order.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Register this dataset's options and defaults on *parser*.

        Args:
            parser: the option parser, first extended by
                ``Pix2pixDataset.modify_commandline_options``.
            is_train: training flag, forwarded to the base class.

        Returns:
            The parser returned by the base class, with ``label_nc`` and
            ``contain_dontcare_label`` defaults overridden and the
            ``--label_dir`` / ``--image_dir`` / ``--instance_dir`` options
            added.
        """
        parser = Pix2pixDataset.modify_commandline_options(parser, is_train)
        # Dataset-specific overrides of the base-class defaults.
        parser.set_defaults(label_nc=29)
        parser.set_defaults(contain_dontcare_label=False)

        parser.add_argument('--label_dir', type=str, required=True,
                            help='path to the directory that contains label images')
        parser.add_argument('--image_dir', type=str, default='',
                            help='path to the directory that contains photo images')
        parser.add_argument('--instance_dir', type=str, default='',
                            help='path to the directory that contains instance maps. '
                                 'Leave blank if it does not exist')
        return parser

    def get_paths(self, opt):
        """Collect label, image and instance file paths from the option dirs.

        Side effects: sets ``self.isTrain`` from ``opt.isTrain`` and
        ``self.name``. Sets ``self.image_dir`` when ``opt.image_dir`` is
        non-empty. In test mode only, loads ``self.ref_dict`` from
        ``./label_to_img.json``.

        Args:
            opt: parsed options providing ``label_dir``, ``image_dir``,
                ``instance_dir`` and ``isTrain``.

        Returns:
            Tuple ``(label_paths, image_paths, instance_paths)``. The last two
            are empty lists when the corresponding directory option is empty.

        Raises:
            AssertionError: in training mode, if the number of label files
                differs from the number of image files.
        """
        label_dir = opt.label_dir
        label_paths = make_dataset(label_dir, recursive=False, read_cache=True)
        print(f"label_dir: {label_dir}, image_dir: {opt.image_dir}, "
              f"instance_dir: {opt.instance_dir}")

        if opt.image_dir:
            self.image_dir = opt.image_dir
            image_paths = make_dataset(
                self.image_dir, recursive=False, read_cache=True)
        else:
            image_paths = []

        if opt.instance_dir:
            instance_paths = make_dataset(
                opt.instance_dir, recursive=False, read_cache=True)
        else:
            instance_paths = []

        if opt.isTrain:
            self.isTrain = True
            # Labels and images are paired by position, so the counts must match.
            assert len(label_paths) == len(image_paths), (
                "The #images in %s and %s do not match. Is there something wrong?"
                % (label_dir, opt.image_dir))
        else:
            self.isTrain = False
            # NOTE(review): this path is resolved relative to the current
            # working directory, not to any dataset root.
            with open('./label_to_img.json', 'r', encoding='utf-8') as f:
                self.ref_dict = json.load(f)

        self.name = "CustomDataset"
        return label_paths, image_paths, instance_paths