|
- """
- Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
- Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
- """
-
- ###############################################################################
- # Code from
- # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
- # Modified the original code so that it also loads images from the current
- # directory as well as the subdirectories
- ###############################################################################
- from jittor.dataset.dataset import Dataset
- from PIL import Image
- import os
-
# File-name suffixes recognised as images. Kept as a public list because the
# ImageFolder error message joins and prints it.
IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff', '.webp'
]


def is_image_file(filename):
    """Return True when *filename* ends with a supported image extension.

    The comparison ignores case, so ``photo.TIFF`` or ``scan.Jpg`` are
    accepted even though ``IMG_EXTENSIONS`` does not spell out every
    upper/lower-case variant.

    Args:
        filename: File name or path to test.

    Returns:
        bool: Whether the name carries one of ``IMG_EXTENSIONS``.
    """
    # Built per call so that run-time changes to IMG_EXTENSIONS are honoured.
    suffixes = tuple(ext.lower() for ext in IMG_EXTENSIONS)
    # str.endswith accepts a tuple: one C-level call instead of a Python loop.
    return filename.lower().endswith(suffixes)
-
-
def make_dataset_rec(dir, images):
    """Append the path of every image file found under *dir* to *images*.

    The tree rooted at ``dir`` is walked with symbolic links followed.
    Directories are visited in sorted order of their path and the files of
    each directory in sorted order of their name, so the result does not
    depend on the order in which the operating system lists entries.

    Args:
        dir: Root directory to scan.
        images: List that is extended in place with the image paths found.

    Raises:
        AssertionError: If ``dir`` is not an existing directory.
    """
    # Raised explicitly rather than via ``assert`` so the check still runs
    # under ``python -O``; the exception type is unchanged for callers.
    if not os.path.isdir(dir):
        raise AssertionError('%s is not a valid directory' % dir)

    for root, _, fnames in sorted(os.walk(dir, followlinks=True)):
        images.extend(
            os.path.join(root, fname)
            for fname in sorted(fnames)
            if is_image_file(fname)
        )
-
-
def make_dataset(dir, recursive=False, read_cache=False, write_cache=False):
    """Return the paths of the image files found under *dir*.

    Args:
        dir: Directory to scan.
        recursive: When true the scan is delegated to ``make_dataset_rec``,
            which walks the tree with ``followlinks=True``. Otherwise the
            tree is walked here without following links. Both modes descend
            into subdirectories.
        read_cache: When true and ``<dir>/files.list`` exists, its lines are
            returned as-is and no scan takes place.
        write_cache: When true, the scanned paths are written, one per line,
            to ``<dir>/files.list``.

    Returns:
        list: Image paths. In the non-recursive scan, directories are
        ordered by path and files within a directory by name.

    Raises:
        AssertionError: If a scan is needed and ``dir`` is neither a
            directory nor a symbolic link.
    """
    cache_path = os.path.join(dir, 'files.list')

    if read_cache and os.path.isfile(cache_path):
        with open(cache_path, 'r') as f:
            return f.read().splitlines()

    images = []
    if recursive:
        make_dataset_rec(dir, images)
    else:
        # Raised explicitly rather than via ``assert`` so the check still
        # runs under ``python -O``; the exception type is unchanged.
        if not (os.path.isdir(dir) or os.path.islink(dir)):
            raise AssertionError('%s is not a valid directory' % dir)

        for root, _, fnames in sorted(os.walk(dir)):
            # Sort the names too: os.walk yields them in OS listing order,
            # which would make the result differ between machines.
            images.extend(
                os.path.join(root, fname)
                for fname in sorted(fnames)
                if is_image_file(fname)
            )

    if write_cache:
        with open(cache_path, 'w') as f:
            f.writelines("%s\n" % path for path in images)
        print('wrote filelist cache at %s' % cache_path)

    return images
-
-
def default_loader(path):
    """Load the image stored at *path* and return it converted to RGB.

    Args:
        path: Location of the image file.

    Returns:
        The image in ``'RGB'`` mode, as produced by ``Image.convert``.
    """
    # Image.open is lazy and keeps the underlying file open; the context
    # manager closes it once the pixel data has been converted, so file
    # handles are not leaked when many images are loaded.
    with Image.open(path) as img:
        return img.convert('RGB')
-
-
class ImageFolder(Dataset):
    """Dataset over the image files found under a root directory.

    The file list is built once, at construction time, with
    ``make_dataset(root)``.
    """

    def __init__(self, root, transform=None, return_paths=False,
                 loader=default_loader):
        """Collect the image paths under *root*.

        Args:
            root: Directory that is scanned for images.
            transform: Optional callable applied to every loaded image.
            return_paths: When true, items are ``(image, path)`` pairs
                instead of the image alone.
            loader: Callable that turns a path into an image.

        Raises:
            RuntimeError: If no image file is found under ``root``.
        """
        # Let the base class set up its own state before ours is added.
        # NOTE(review): relies on Dataset.__init__ being callable without
        # arguments - confirm against the installed jittor version.
        super().__init__()

        imgs = make_dataset(root)
        if not imgs:
            raise RuntimeError("Found 0 images in: " + root + "\n"
                               "Supported image extensions are: " +
                               ",".join(IMG_EXTENSIONS))

        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.return_paths = return_paths
        self.loader = loader

    def __getitem__(self, index):
        """Load, optionally transform, and return the item at *index*.

        Returns:
            The image, or ``(image, path)`` when ``return_paths`` is set.
        """
        path = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.return_paths:
            return img, path
        return img

    def __len__(self):
        """Return the number of image paths collected at construction."""
        return len(self.imgs)
|