You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

MGDataset.py 7.0 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """MGDataset"""
  16. import math
  17. import sys
  18. import os
  19. import os.path as osp
  20. from collections import defaultdict
  21. import random
  22. import numpy as np
  23. from PIL import ImageFile
  24. import cv2
# Pillow module-level switch (see the Pillow ImageFile documentation); setting
# it asks Pillow to tolerate truncated image files when decoding.
# NOTE(review): pil_loader in this file reads images with cv2.imread rather
# than Pillow, so this flag may not affect the loading path here -- confirm.
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Names exported by a star-import of this module.
__all__ = ['DistributedPKSampler', 'MGDataset']
  27. IMG_EXTENSIONS = ('.jpg', 'jpeg', '.png', '.ppm', '.bmp', 'pgm', '.tif', '.tiff', 'webp')
  28. class DistributedPKSampler:
  29. '''DistributedPKSampler'''
  30. def __init__(self, dataset, shuffle=True, p=5, k=2):
  31. assert isinstance(dataset, MGDataset), 'PK Sampler Only Supports PK Dataset or MG Dataset!'
  32. self.p = p
  33. self.k = k
  34. self.dataset = dataset
  35. self.epoch = 0
  36. self.step_nums = int(math.ceil(len(self.dataset.classes)*1.0/p))
  37. self.total_ids = self.step_nums*p
  38. self.batch_size = p*k
  39. self.num_samples = self.total_ids * self.k
  40. self.shuffle = shuffle
  41. self.epoch_gen = 1
  42. def _sample_pk(self, indices):
  43. '''sample pk'''
  44. sampled_pk = []
  45. for indice in indices:
  46. sampled_id = indice
  47. replacement = False
  48. if len(self.dataset.id2range[sampled_id]) < self.k:
  49. replacement = True
  50. index_list = np.random.choice(self.dataset.id2range[sampled_id][0:], self.k, replace=replacement)
  51. sampled_pk.extend(index_list.tolist())
  52. return sampled_pk
  53. def __iter__(self):
  54. if self.shuffle:
  55. self.epoch_gen = (self.epoch_gen + 1) & 0xffffffff
  56. np.random.seed(self.epoch_gen)
  57. indices = np.random.permutation(len(self.dataset.classes))
  58. indices = indices.tolist()
  59. else:
  60. indices = list(range(len(self.dataset.classes)))
  61. indices += indices[:(self.total_ids - len(indices))]
  62. assert len(indices) == self.total_ids
  63. sampled_idxs = self._sample_pk(indices)
  64. return iter(sampled_idxs)
  65. def __len__(self):
  66. return self.num_samples
  67. def set_epoch(self, epoch):
  68. self.epoch = epoch
  69. def has_file_allowed_extension(filename, extensions):
  70. """ check if a file has an allowed extensio n.
  71. Args:
  72. filename (string): path to a file
  73. extensions (tuple of strings): extensions allowed (lowercase)
  74. Returns:
  75. bool: True if the file ends with one of the given extensions
  76. """
  77. return filename.lower().endswith(extensions)
  78. def make_dataset(dir_name, class_to_idx, extensions=None, is_valid_file=None):
  79. '''make dataset'''
  80. images = []
  81. masked_datasets = ["n95", "3m", "new", "mask_1", "mask_2", "mask_3", "mask_4", "mask_5"]
  82. dir_name = os.path.expanduser(dir_name)
  83. if not (extensions is None) ^ (is_valid_file is None):
  84. raise ValueError("Extensions and is_valid_file should not be the same")
  85. def is_valid(x):
  86. if extensions is not None:
  87. return has_file_allowed_extension(x, extensions)
  88. return is_valid_file(x)
  89. for target in sorted(class_to_idx.keys()):
  90. d = os.path.join(dir_name, target)
  91. if not os.path.isdir(d):
  92. continue
  93. for root, _, fnames in sorted(os.walk(d)):
  94. for fname in sorted(fnames):
  95. path = os.path.join(root, fname)
  96. if is_valid(path):
  97. scale = float(osp.splitext(fname)[0].split('_')[1])
  98. item = (path, class_to_idx[target], scale)
  99. images.append(item)
  100. mask_root_path = root.replace("faces_webface_112x112_raw_image", random.choice(masked_datasets))
  101. mask_name = fname.split('_')[0]+".jpg"
  102. mask_path = osp.join(mask_root_path, mask_name)
  103. if os.path.isfile(mask_path) and is_valid(mask_path):
  104. item = (mask_path, class_to_idx[target], scale)
  105. images.append(item)
  106. return images
  107. class ImageFolderPKDataset:
  108. '''Image Folder PKDataset'''
  109. def __init__(self, root):
  110. self.classes, self.classes_to_idx = self._find_classes(root)
  111. self.samples = make_dataset(root, self.classes_to_idx, IMG_EXTENSIONS, None)
  112. self.id2range = self._build_id2range()
  113. self.all_image_idxs = range(len(self.samples))
  114. self.classes = list(self.id2range.keys())
  115. def _find_classes(self, dir_name):
  116. """
  117. Finds the class folders in a dataset
  118. Args:
  119. dir (string): root directory path
  120. Returns:
  121. tuple (class, class_to_idx): where classes are relative to dir, and class_to_idx is a directionaty
  122. Ensures:
  123. No class is a subdirectory of others
  124. """
  125. if sys.version_info >= (3, 5):
  126. # Faster and available in Python 3.5 and above
  127. classes = [d.name for d in os.scandir(dir_name) if d.is_dir()]
  128. else:
  129. classes = [d for d in os.listdir(dir_name) if os.path.isdir(os.path.join(dir_name, d))]
  130. classes.sort()
  131. class_to_idx = {classes[i]: i for i in range(len(classes))}
  132. return classes, class_to_idx
  133. def _build_id2range(self):
  134. '''id to range'''
  135. id2range = defaultdict(list)
  136. ret_range = defaultdict(list)
  137. for idx, sample in enumerate(self.samples):
  138. label = sample[1]
  139. id2range[label].append((sample, idx))
  140. for key in id2range:
  141. id2range[key].sort(key=lambda x: int(os.path.basename(x[0][0]).split(".")[0]))
  142. for item in id2range[key]:
  143. ret_range[key].append(item[1])
  144. return ret_range
  145. def __getitem__(self, index):
  146. return self.samples[index]
  147. def __len__(self):
  148. return len(self.samples)
  149. def pil_loader(path):
  150. '''load pil'''
  151. img = cv2.imread(path)
  152. img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  153. return img
  154. class MGDataset:
  155. '''MGDataset'''
  156. def __init__(self, root, loader=pil_loader):
  157. self.dataset = ImageFolderPKDataset(root)
  158. print('MGDataset len(dataset):{}'.format(len(self.dataset)))
  159. self.loader = loader
  160. self.classes = self.dataset.classes
  161. self.id2range = self.dataset.id2range
  162. def __getitem__(self, index):
  163. path, target1, target2 = self.dataset[index]
  164. sample = self.loader(path)
  165. return sample, target1, target2
  166. def __len__(self):
  167. return len(self.dataset)