You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

image_folder.py 3.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. """
  2. Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
  3. Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
  4. """
  5. ###############################################################################
  6. # Code from
  7. # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
  8. # Modified the original code so that it also loads images from the current
  9. # directory as well as the subdirectories
  10. ###############################################################################
  11. from jittor.dataset.dataset import Dataset
  12. from PIL import Image
  13. import os
  14. IMG_EXTENSIONS = [
  15. '.jpg', '.JPG', '.jpeg', '.JPEG',
  16. '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff', '.webp'
  17. ]
  18. def is_image_file(filename):
  19. return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
  20. def make_dataset_rec(dir, images):
  21. assert os.path.isdir(dir), '%s is not a valid directory' % dir
  22. for root, dnames, fnames in sorted(os.walk(dir, followlinks=True)):
  23. for fname in fnames:
  24. if is_image_file(fname):
  25. path = os.path.join(root, fname)
  26. images.append(path)
  27. def make_dataset(dir, recursive=False, read_cache=False, write_cache=False):
  28. images = []
  29. if read_cache:
  30. possible_filelist = os.path.join(dir, 'files.list')
  31. if os.path.isfile(possible_filelist):
  32. with open(possible_filelist, 'r') as f:
  33. images = f.read().splitlines()
  34. return images
  35. if recursive:
  36. make_dataset_rec(dir, images)
  37. else:
  38. assert os.path.isdir(dir) or os.path.islink(
  39. dir), '%s is not a valid directory' % dir
  40. for root, dnames, fnames in sorted(os.walk(dir)):
  41. for fname in fnames:
  42. if is_image_file(fname):
  43. path = os.path.join(root, fname)
  44. images.append(path)
  45. if write_cache:
  46. filelist_cache = os.path.join(dir, 'files.list')
  47. with open(filelist_cache, 'w') as f:
  48. for path in images:
  49. f.write("%s\n" % path)
  50. print('wrote filelist cache at %s' % filelist_cache)
  51. return images
  52. def default_loader(path):
  53. return Image.open(path).convert('RGB')
  54. class ImageFolder(Dataset):
  55. def __init__(self, root, transform=None, return_paths=False,
  56. loader=default_loader):
  57. imgs = make_dataset(root)
  58. if len(imgs) == 0:
  59. raise(RuntimeError("Found 0 images in: " + root + "\n"
  60. "Supported image extensions are: " +
  61. ",".join(IMG_EXTENSIONS)))
  62. self.root = root
  63. self.imgs = imgs
  64. self.transform = transform
  65. self.return_paths = return_paths
  66. self.loader = loader
  67. def __getitem__(self, index):
  68. path = self.imgs[index]
  69. img = self.loader(path)
  70. if self.transform is not None:
  71. img = self.transform(img)
  72. if self.return_paths:
  73. return img, path
  74. else:
  75. return img
  76. def __len__(self):
  77. return len(self.imgs)

第三届计图人工智能挑战赛——风格及语义引导的风景图片生成赛道项目，由 Jittor（计图）框架实现。