| @@ -24,7 +24,7 @@ from ....core.serialization import load, save | |||
| from ....distributed.util import is_distributed | |||
| from ....logger import get_logger | |||
| from .folder import ImageFolder | |||
| from .utils import _default_dataset_root, untar, untargz | |||
| from .utils import _default_dataset_root, calculate_md5, untar, untargz | |||
| logger = get_logger(__name__) | |||
| @@ -33,40 +33,28 @@ class ImageNet(ImageFolder): | |||
| r""" | |||
| Load ImageNet from raw files or folder, expected folder looks like | |||
| raw files situation (optional): | |||
| root/ILSVRC2012_img_train.tar | |||
| root/ILSVRC2012_img_val.tar | |||
| root/ILSVRC2012_devkit_t12.tar.gz | |||
| image folder situation (required): | |||
| root/train/cls/xxx.${img_ext} | |||
| root/val/cls/xxx.${img_ext} | |||
| root/ILSVRC2012_devkit_t12/data/meta.mat | |||
| root/ILSVRC2012_devkit_t12/data/ILSVRC2012_validation_ground_truth.txt | |||
| If the required folders don't exist, raw files are required to get extracted and processed. | |||
| ${root}/ | |||
| | [REQUIRED TAR FILES] | |||
| |- ILSVRC2012_img_train.tar | |||
| |- ILSVRC2012_img_val.tar | |||
| |- ILSVRC2012_devkit_t12.tar.gz | |||
| | [OPTIONAL IMAGE FOLDERS] | |||
| |- train/cls/xxx.${img_ext} | |||
| |- val/cls/xxx.${img_ext} | |||
| |- ILSVRC2012_devkit_t12/data/meta.mat | |||
| |- ILSVRC2012_devkit_t12/data/ILSVRC2012_validation_ground_truth.txt | |||
| If the image folders don't exist, raw tar files are required to get extracted and processed. | |||
| """ | |||
| raw_file_meta = { | |||
| "train": ("ILSVRC2012_img_train.tar", "1d675b47d978889d74fa0da5fadfb00e"), | |||
| "val": ("ILSVRC2012_img_val.tar", "29b22e2961454d5413ddabcf34fc5622"), | |||
| "devkit": ("ILSVRC2012_devkit_t12.tar.gz", "fa75699e90414af021442c21a62c3abf"), | |||
| } | |||
| """ | |||
| raw files of ImageNet (train, val, devkit) | |||
| """ | |||
| } # ImageNet raw files | |||
| default_train_dir = "train" | |||
| """ | |||
| directory of train data | |||
| """ | |||
| default_val_dir = "val" | |||
| """ | |||
| directory of val data | |||
| """ | |||
| default_devkit_dir = "ILSVRC2012_devkit_t12" | |||
| """ | |||
| directory of devkit | |||
| """ | |||
| def __init__(self, root: str = None, train: bool = True, **kwargs): | |||
| r""" | |||
| @@ -97,13 +85,16 @@ class ImageNet(ImageFolder): | |||
| else: | |||
| self.root = root | |||
| self.devkit_dir = os.path.join(self.root, self.default_devkit_dir) | |||
| if not os.path.exists(self.root): | |||
| raise FileNotFoundError("dir %s does not exist" % self.root) | |||
| self.devkit_dir = os.path.join(self.root, self.default_devkit_dir) | |||
| if not os.path.exists(self.devkit_dir): | |||
| logger.warning("devkit directory %s does not exists" % self.devkit_dir) | |||
| self._prepare_devkit() | |||
| self.train = train | |||
| if train: | |||
| self.target_folder = os.path.join(self.root, self.default_train_dir) | |||
| @@ -125,7 +116,7 @@ class ImageNet(ImageFolder): | |||
| "extracting raw file shouldn't be done in distributed mode, use single process instead" | |||
| ) | |||
| else: | |||
| self.parse(train) | |||
| self._prepare_train() if train else self._prepare_val() | |||
| super().__init__(self.target_folder, **kwargs) | |||
| @@ -180,14 +171,13 @@ class ImageNet(ImageFolder): | |||
| ] | |||
| ) | |||
| def organize_val_data(self): | |||
| def _organize_val_data(self): | |||
| id2wnid = self.meta[0] | |||
| val_idcs = self.valid_ground_truth | |||
| val_wnids = [id2wnid[idx] for idx in val_idcs] | |||
| raw_val_dir = os.path.join(self.root, "ILSVRC2012_img_val") | |||
| val_images = sorted( | |||
| [os.path.join(raw_val_dir, image) for image in os.listdir(raw_val_dir)] | |||
| [os.path.join(self.target_folder, image) for image in os.listdir(self.target_folder)] | |||
| ) | |||
| logger.debug("mkdir for val set wnids") | |||
| @@ -203,24 +193,41 @@ class ImageNet(ImageFolder): | |||
| ), | |||
| ) | |||
| def parse(self, train): | |||
| if train: | |||
| logger.info("process train raw file.. this may take several hours") | |||
| untar( | |||
| os.path.join(self.root, self.raw_file_meta["train"][0]), | |||
| self.target_folder, | |||
| ) | |||
| paths = [ | |||
| os.path.join(self.target_folder, child_dir) | |||
| for child_dir in os.listdir(self.target_folder) | |||
| ] | |||
| for path in tqdm(paths): | |||
| untar(path, os.path.splitext(path)[0], remove=True) | |||
| else: | |||
| logger.info("process devkit file..") | |||
| untargz(os.path.join(self.root, self.raw_file_meta["devkit"][0])) | |||
| logger.info("process valid raw file.. this may take 10-20 minutes") | |||
| raw_val_dir = os.path.join(self.root, "ILSVRC2012_img_val") | |||
| os.makedirs(raw_val_dir, exist_ok=True) | |||
| untar(os.path.join(self.root, self.raw_file_meta["val"][0]), raw_val_dir) | |||
| self.organize_val_data() | |||
| def _prepare_val(self): | |||
| assert not self.train | |||
| raw_filename, checksum = self.raw_file_meta["val"] | |||
| raw_file = os.path.join(self.root, raw_filename) | |||
| logger.info("checksum valid tar file {} ..".format(raw_file)) | |||
| assert calculate_md5(raw_file) == checksum, \ | |||
| "checksum mismatch, {} may be damaged".format(raw_file) | |||
| logger.info("extract valid tar file.. this may take 10-20 minutes") | |||
| untar(os.path.join(self.root, raw_file), self.target_folder) | |||
| self._organize_val_data() | |||
| def _prepare_train(self): | |||
| assert self.train | |||
| raw_filename, checksum = self.raw_file_meta["train"] | |||
| raw_file = os.path.join(self.root, raw_filename) | |||
| logger.info("checksum train tar file {} ..".format(raw_file)) | |||
| assert calculate_md5(raw_file) == checksum, \ | |||
| "checksum mismatch, {} may be damaged".format(raw_file) | |||
| logger.info("extract train tar file.. this may take several hours") | |||
| untar( | |||
| os.path.join(self.root, raw_file), | |||
| self.target_folder, | |||
| ) | |||
| paths = [ | |||
| os.path.join(self.target_folder, child_dir) | |||
| for child_dir in os.listdir(self.target_folder) | |||
| ] | |||
| for path in tqdm(paths): | |||
| untar(path, os.path.splitext(path)[0], remove=True) | |||
| def _prepare_devkit(self): | |||
| raw_filename, checksum = self.raw_file_meta["val"] | |||
| raw_file = os.path.join(self.root, raw_filename) | |||
| logger.info("checksum devkit tar file {} ..".format(raw_file)) | |||
| assert calculate_md5(raw_file) == checksum, \ | |||
| "checksum mismatch, {} may be damaged".format(raw_file) | |||
| logger.info("extract devkit file..") | |||
| untargz(os.path.join(self.root, self.raw_file_meta["devkit"][0])) | |||