You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

caltech.py 9.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
  1. # Modified from https://github.com/pytorch/vision/blob/master/torchvision/datasets/caltech.py
  2. from __future__ import print_function
  3. from PIL import Image
  4. import os
  5. import os.path
  6. from torchvision.datasets.vision import VisionDataset
  7. from torchvision.datasets.utils import download_url
  8. class Caltech101(VisionDataset):
  9. """`Caltech 101 <http://www.vision.caltech.edu/Image_Datasets/Caltech101/>`_ Dataset.
  10. Args:
  11. root (string): Root directory of dataset where directory
  12. ``caltech101`` exists or will be saved to if download is set to True.
  13. target_type (string or list, optional): Type of target to use, ``category`` or
  14. ``annotation``. Can also be a list to output a tuple with all specified target types.
  15. ``category`` represents the target class, and ``annotation`` is a list of points
  16. from a hand-generated outline. Defaults to ``category``.
  17. transform (callable, optional): A function/transform that takes in an PIL image
  18. and returns a transformed version. E.g, ``transforms.RandomCrop``
  19. target_transform (callable, optional): A function/transform that takes in the
  20. target and transforms it.
  21. download (bool, optional): If true, downloads the dataset from the internet and
  22. puts it in root directory. If dataset is already downloaded, it is not
  23. downloaded again.
  24. """
  25. def __init__(self, root, target_type="category", train=True,
  26. transform=None, target_transform=None,
  27. download=False):
  28. super(Caltech101, self).__init__(os.path.join(root, 'caltech101'))
  29. self.train = train
  30. self.dir_name = '101_ObjectCategories_split/train' if self.train else '101_ObjectCategories_split/test'
  31. os.makdirs(self.root, exist_ok=True)
  32. if isinstance(target_type, list):
  33. self.target_type = target_type
  34. else:
  35. self.target_type = [target_type]
  36. self.transform = transform
  37. self.target_transform = target_transform
  38. if download:
  39. self.download()
  40. if not self._check_integrity():
  41. raise RuntimeError('Dataset not found or corrupted.' +
  42. ' You can use download=True to download it')
  43. self.categories = sorted(os.listdir(os.path.join(self.root, "101_ObjectCategories")))
  44. self.categories.remove("BACKGROUND_Google") # this is not a real class
  45. # For some reason, the category names in "101_ObjectCategories" and
  46. # "Annotations" do not always match. This is a manual map between the
  47. # two. Defaults to using same name, since most names are fine.
  48. name_map = {"Faces": "Faces_2",
  49. "Faces_easy": "Faces_3",
  50. "Motorbikes": "Motorbikes_16",
  51. "airplanes": "Airplanes_Side_2"}
  52. self.annotation_categories = list(map(lambda x: name_map[x] if x in name_map else x, self.categories))
  53. self.index = []
  54. self.y = []
  55. for (i, c) in enumerate(self.categories):
  56. file_names = os.listdir(os.path.join(self.root, self.dir_name, c))
  57. n = len(file_names)
  58. self.index.extend( file_names )
  59. self.y.extend(n * [i])
  60. print(self.train, len(self.index))
  61. def __getitem__(self, index):
  62. """
  63. Args:
  64. index (int): Index
  65. Returns:
  66. tuple: (image, target) where the type of target specified by target_type.
  67. """
  68. import scipy.io
  69. img = Image.open(os.path.join(self.root,
  70. self.dir_name,
  71. self.categories[self.y[index]],
  72. self.index[index])).convert("RGB")
  73. target = []
  74. for t in self.target_type:
  75. if t == "category":
  76. target.append(self.y[index])
  77. elif t == "annotation":
  78. data = scipy.io.loadmat(os.path.join(self.root,
  79. "Annotations",
  80. self.annotation_categories[self.y[index]],
  81. "annotation_{:04d}.mat".format(self.index[index])))
  82. target.append(data["obj_contour"])
  83. else:
  84. raise ValueError("Target type \"{}\" is not recognized.".format(t))
  85. target = tuple(target) if len(target) > 1 else target[0]
  86. if self.transform is not None:
  87. img = self.transform(img)
  88. if self.target_transform is not None:
  89. target = self.target_transform(target)
  90. return img, target
  91. def _check_integrity(self):
  92. # can be more robust and check hash of files
  93. return os.path.exists(os.path.join(self.root, "101_ObjectCategories"))
  94. def __len__(self):
  95. return len(self.index)
  96. def download(self):
  97. import tarfile
  98. if self._check_integrity():
  99. print('Files already downloaded and verified')
  100. return
  101. download_url("http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz",
  102. self.root,
  103. "101_ObjectCategories.tar.gz",
  104. "b224c7392d521a49829488ab0f1120d9")
  105. download_url("http://www.vision.caltech.edu/Image_Datasets/Caltech101/Annotations.tar",
  106. self.root,
  107. "101_Annotations.tar",
  108. "6f83eeb1f24d99cab4eb377263132c91")
  109. # extract file
  110. with tarfile.open(os.path.join(self.root, "101_ObjectCategories.tar.gz"), "r:gz") as tar:
  111. tar.extractall(path=self.root)
  112. with tarfile.open(os.path.join(self.root, "101_Annotations.tar"), "r:") as tar:
  113. tar.extractall(path=self.root)
  114. def extra_repr(self):
  115. return "Target type: {target_type}".format(**self.__dict__)
  116. class Caltech256(VisionDataset):
  117. """`Caltech 256 <http://www.vision.caltech.edu/Image_Datasets/Caltech256/>`_ Dataset.
  118. Args:
  119. root (string): Root directory of dataset where directory
  120. ``caltech256`` exists or will be saved to if download is set to True.
  121. transform (callable, optional): A function/transform that takes in an PIL image
  122. and returns a transformed version. E.g, ``transforms.RandomCrop``
  123. target_transform (callable, optional): A function/transform that takes in the
  124. target and transforms it.
  125. download (bool, optional): If true, downloads the dataset from the internet and
  126. puts it in root directory. If dataset is already downloaded, it is not
  127. downloaded again.
  128. """
  129. def __init__(self, root,
  130. transform=None, target_transform=None,
  131. download=False):
  132. super(Caltech256, self).__init__(os.path.join(root, 'caltech256'))
  133. os.makedirs(self.root, exist_ok=True)
  134. self.transform = transform
  135. self.target_transform = target_transform
  136. if download:
  137. self.download()
  138. if not self._check_integrity():
  139. raise RuntimeError('Dataset not found or corrupted.' +
  140. ' You can use download=True to download it')
  141. self.categories = sorted(os.listdir(os.path.join(self.root, "256_ObjectCategories")))
  142. self.index = []
  143. self.y = []
  144. for (i, c) in enumerate(self.categories):
  145. n = len(os.listdir(os.path.join(self.root, "256_ObjectCategories", c)))
  146. self.index.extend(range(1, n + 1))
  147. self.y.extend(n * [i])
  148. def __getitem__(self, index):
  149. """
  150. Args:
  151. index (int): Index
  152. Returns:
  153. tuple: (image, target) where target is index of the target class.
  154. """
  155. img = Image.open(os.path.join(self.root,
  156. "256_ObjectCategories",
  157. self.categories[self.y[index]],
  158. "{:03d}_{:04d}.jpg".format(self.y[index] + 1, self.index[index])))
  159. target = self.y[index]
  160. if self.transform is not None:
  161. img = self.transform(img)
  162. if self.target_transform is not None:
  163. target = self.target_transform(target)
  164. return img, target
  165. def _check_integrity(self):
  166. # can be more robust and check hash of files
  167. return os.path.exists(os.path.join(self.root, "256_ObjectCategories"))
  168. def __len__(self):
  169. return len(self.index)
  170. def download(self):
  171. import tarfile
  172. if self._check_integrity():
  173. print('Files already downloaded and verified')
  174. return
  175. download_url("http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar",
  176. self.root,
  177. "256_ObjectCategories.tar",
  178. "67b4f42ca05d46448c6bb8ecd2220f6d")
  179. # extract file
  180. with tarfile.open(os.path.join(self.root, "256_ObjectCategories.tar"), "r:") as tar:
  181. tar.extractall(path=self.root)

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具，在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势，目前已为产学研等各领域近千家单位及个人提供AI应用赋能。

Contributors (1)