You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Dataset.py 6.4 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """data process"""
  16. import math
  17. import sys
  18. import os
  19. from collections import defaultdict
  20. import numpy as np
  21. from PIL import ImageFile
  22. import cv2
  23. ImageFile.LOAD_TRUNCATED_IMAGES = True
  24. __all__ = ['DistributedPKSampler', 'Dataset']
  25. IMG_EXTENSIONS = ('.jpg', 'jpeg', '.png', '.ppm', '.bmp', 'pgm', '.tif', '.tiff', 'webp')
  26. class DistributedPKSampler:
  27. '''DistributedPKSampler'''
  28. def __init__(self, dataset, shuffle=True, p=5, k=2):
  29. assert isinstance(dataset, PKDataset), 'PK Sampler Only Supports PK Dataset!'
  30. self.p = p
  31. self.k = k
  32. self.dataset = dataset
  33. self.epoch = 0
  34. self.step_nums = int(math.ceil(len(self.dataset.classes)*1.0/p))
  35. self.total_ids = self.step_nums*p
  36. self.batch_size = p*k
  37. self.num_samples = self.total_ids * self.k
  38. self.shuffle = shuffle
  39. self.epoch_gen = 1
  40. def _sample_pk(self, indices):
  41. '''sample pk'''
  42. sampled_pk = []
  43. for indice in indices:
  44. sampled_id = indice
  45. replacement = False
  46. if len(self.dataset.id2range[sampled_id]) < self.k:
  47. replacement = True
  48. index_list = np.random.choice(self.dataset.id2range[sampled_id][0:], self.k, replace=replacement)
  49. sampled_pk.extend(index_list.tolist())
  50. return sampled_pk
  51. def __iter__(self):
  52. if self.shuffle:
  53. self.epoch_gen = (self.epoch_gen + 1) & 0xffffffff
  54. np.random.seed(self.epoch_gen)
  55. indices = np.random.permutation(len(self.dataset.classes))
  56. indices = indices.tolist()
  57. else:
  58. indices = list(range(len(self.dataset.classes)))
  59. indices += indices[:(self.total_ids - len(indices))]
  60. assert len(indices) == self.total_ids
  61. sampled_idxs = self._sample_pk(indices)
  62. return iter(sampled_idxs)
  63. def __len__(self):
  64. return self.num_samples
  65. def set_epoch(self, epoch):
  66. self.epoch = epoch
  67. def has_file_allowed_extension(filename, extensions):
  68. """ check if a file has an allowed extensio n.
  69. Args:
  70. filename (string): path to a file
  71. extensions (tuple of strings): extensions allowed (lowercase)
  72. Returns:
  73. bool: True if the file ends with one of the given extensions
  74. """
  75. return filename.lower().endswith(extensions)
  76. def make_dataset(dir_name, class_to_idx, extensions=None, is_valid_file=None):
  77. '''make dataset'''
  78. images = []
  79. dir_name = os.path.expanduser(dir_name)
  80. if not (extensions is None) ^ (is_valid_file is None):
  81. raise ValueError("Extensions and is_valid_file should not be the same.")
  82. def is_valid(x):
  83. if extensions is not None:
  84. return has_file_allowed_extension(x, extensions)
  85. return is_valid_file(x)
  86. for target in sorted(class_to_idx.keys()):
  87. d = os.path.join(dir_name, target)
  88. if not os.path.isdir(d):
  89. continue
  90. for root, _, fnames in sorted(os.walk(d)):
  91. for fname in sorted(fnames):
  92. path = os.path.join(root, fname)
  93. if is_valid(path):
  94. item = (path, class_to_idx[target], 0.6)
  95. images.append(item)
  96. return images
  97. class ImageFolderPKDataset:
  98. '''ImageFolderPKDataset'''
  99. def __init__(self, root):
  100. self.classes, self.classes_to_idx = self._find_classes(root)
  101. self.samples = make_dataset(root, self.classes_to_idx, IMG_EXTENSIONS, None)
  102. self.id2range = self._build_id2range()
  103. self.all_image_idxs = range(len(self.samples))
  104. self.classes = list(self.id2range.keys())
  105. def _find_classes(self, dir_name):
  106. """
  107. Finds the class folders in a dataset
  108. Args:
  109. dir (string): root directory path
  110. Returns:
  111. tuple (class, class_to_idx): where classes are relative to dir, and class_to_idx is a directionaty
  112. Ensures:
  113. No class is a subdirectory of others
  114. """
  115. if sys.version_info >= (3, 5):
  116. # Faster and available in Python 3.5 and above
  117. classes = [d.name for d in os.scandir(dir_name) if d.is_dir()]
  118. else:
  119. classes = [d for d in os.listdir(dir_name) if os.path.isdir(os.path.join(dir_name, d))]
  120. classes.sort()
  121. class_to_idx = {classes[i]: i for i in range(len(classes))}
  122. return classes, class_to_idx
  123. def _build_id2range(self):
  124. '''map id to range'''
  125. id2range = defaultdict(list)
  126. ret_range = defaultdict(list)
  127. for idx, sample in enumerate(self.samples):
  128. label = sample[1]
  129. id2range[label].append((sample, idx))
  130. for key in id2range:
  131. id2range[key].sort(key=lambda x: int(os.path.basename(x[0][0]).split(".")[0]))
  132. for item in id2range[key]:
  133. ret_range[key].append(item[1])
  134. return ret_range
  135. def __getitem__(self, index):
  136. return self.samples[index]
  137. def __len__(self):
  138. return len(self.samples)
  139. def pil_loader(path):
  140. '''pil loader'''
  141. img = cv2.imread(path)
  142. img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  143. return img
  144. class Dataset:
  145. '''Dataset'''
  146. def __init__(self, root, loader=pil_loader):
  147. self.dataset = ImageFolderPKDataset(root)
  148. print('Dataset len(dataset):{}'.format(len(self.dataset)))
  149. self.loader = loader
  150. self.classes = self.dataset.classes
  151. self.id2range = self.dataset.id2range
  152. def __getitem__(self, index):
  153. path, target1, target2 = self.dataset[index]
  154. sample = self.loader(path)
  155. return sample, target1, target2
  156. def __len__(self):
  157. return len(self.dataset)