dataset.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
dataset processing.
"""
import os

from PIL import Image, ImageFile
from mindspore.common import dtype as mstype
import mindspore.dataset as de
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as vision
from src.utils.sampler import DistributedSampler

ImageFile.LOAD_TRUNCATED_IMAGES = True


def vgg_create_dataset(data_home, image_size, batch_size, rank_id=0, rank_size=1, repeat_num=1, training=True):
    """Data operations."""
    data_dir = os.path.join(data_home, "cifar-10-batches-bin")
    if not training:
        data_dir = os.path.join(data_home, "cifar-10-verify-bin")

    data_set = de.Cifar10Dataset(data_dir, num_shards=rank_size, shard_id=rank_id)

    rescale = 1.0 / 255.0
    shift = 0.0

    # define map operations
    random_crop_op = vision.RandomCrop((32, 32), (4, 4, 4, 4))  # padding_mode default CONSTANT
    random_horizontal_op = vision.RandomHorizontalFlip()
    resize_op = vision.Resize(image_size)  # interpolation default BILINEAR
    rescale_op = vision.Rescale(rescale, shift)
    normalize_op = vision.Normalize((0.4465, 0.4822, 0.4914), (0.2010, 0.1994, 0.2023))
    changeswap_op = vision.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)

    c_trans = []
    if training:
        c_trans = [random_crop_op, random_horizontal_op]
    c_trans += [resize_op, rescale_op, normalize_op, changeswap_op]

    # apply map operations on images
    data_set = data_set.map(operations=type_cast_op, input_columns="label")
    data_set = data_set.map(operations=c_trans, input_columns="image")

    # apply repeat operations
    data_set = data_set.repeat(repeat_num)

    # apply shuffle operations
    data_set = data_set.shuffle(buffer_size=10)

    # apply batch operations
    data_set = data_set.batch(batch_size=batch_size, drop_remainder=True)

    return data_set
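
# Usage sketch (illustrative, not part of the original training scripts). It
# assumes the CIFAR-10 binaries have been unpacked under "./cifar10", so that
# "./cifar10/cifar-10-batches-bin" exists, and builds a single-device training
# pipeline resized to 224x224 with batch size 32:
#
#     train_set = vgg_create_dataset("./cifar10", image_size=(224, 224),
#                                    batch_size=32, rank_id=0, rank_size=1)
#     for batch in train_set.create_dict_iterator():
#         images, labels = batch["image"], batch["label"]  # CHW float images, int32 labels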


def classification_dataset(data_dir, image_size, per_batch_size, rank=0, group_size=1,
                           mode='train',
                           input_mode='folder',
                           root='',
                           num_parallel_workers=None,
                           shuffle=None,
                           sampler=None,
                           repeat_num=1,
                           class_indexing=None,
                           drop_remainder=True,
                           transform=None,
                           target_transform=None):
    """
    A function that returns a dataset for classification. The mode of the input dataset can be "folder" or "txt".
    If it is "folder", all images within one folder have the same label. If it is "txt", all image paths of the
    dataset are written into a text file.

    Args:
        data_dir (str): Path to the root directory that contains the dataset when input_mode="folder",
            or path of the text file that contains every image path of the dataset when input_mode="txt".
        image_size (Union(int, sequence)): Size of the input images.
        per_batch_size (int): The batch size of every step during training.
        rank (int): The shard ID within group_size. Default: 0.
        group_size (int): Number of shards that the dataset should be divided into. Default: 1.
        mode (str): "train" applies training augmentation; any other value applies
            evaluation preprocessing. Default: "train".
        input_mode (str): The form of the input dataset, "folder" or "txt". Default: "folder".
        root (str): The images path when input_mode="txt". Default: ''.
        num_parallel_workers (int): Number of workers to read the data. Default: None.
        shuffle (bool): Whether or not to perform shuffle on the dataset.
            Default: None, performs shuffle.
        sampler (Sampler): Object used to choose samples from the dataset. Default: None.
        repeat_num (int): Number of times to repeat the dataset. Default: 1.
        class_indexing (dict): A str-to-int mapping from folder name to index.
            Default: None, the folder names will be sorted alphabetically and each
            class will be given a unique index starting from 0.
        drop_remainder (bool): Whether to drop the last incomplete batch. Default: True.
        transform (list): Custom transforms for the image column. Default: None, use the
            mode-dependent defaults below.
        target_transform (list): Custom transforms for the label column. Default: None.

    Examples:
        >>> from src.dataset import classification_dataset
        >>> # Path to an image folder directory; it must contain sub-directories which hold the images
        >>> data_dir = "/path/to/imagefolder_directory"
        >>> de_dataset = classification_dataset(data_dir, image_size=[224, 224],
        ...                                     per_batch_size=64, rank=0, group_size=4)
        >>> # Path of the text file that contains every image path of the dataset
        >>> data_dir = "/path/to/dataset/images/train.txt"
        >>> images_dir = "/path/to/dataset/images"
        >>> de_dataset = classification_dataset(data_dir, image_size=[224, 224],
        ...                                     per_batch_size=64, rank=0, group_size=4,
        ...                                     input_mode="txt", root=images_dir)
    """
    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]

    if transform is None:
        if mode == 'train':
            transform_img = [
                vision.RandomCropDecodeResize(image_size, scale=(0.08, 1.0)),
                vision.RandomHorizontalFlip(prob=0.5),
                vision.Normalize(mean=mean, std=std),
                vision.HWC2CHW()
            ]
        else:
            transform_img = [
                vision.Decode(),
                vision.Resize((256, 256)),
                vision.CenterCrop(image_size),
                vision.Normalize(mean=mean, std=std),
                vision.HWC2CHW()
            ]
    else:
        transform_img = transform

    if target_transform is None:
        transform_label = [C.TypeCast(mstype.int32)]
    else:
        transform_label = target_transform

    if input_mode == 'folder':
        de_dataset = de.ImageFolderDataset(data_dir, num_parallel_workers=num_parallel_workers,
                                           shuffle=shuffle, sampler=sampler, class_indexing=class_indexing,
                                           num_shards=group_size, shard_id=rank)
    else:
        dataset = TxtDataset(root, data_dir)
        sampler = DistributedSampler(dataset, rank, group_size, shuffle=shuffle)
        de_dataset = de.GeneratorDataset(dataset, ["image", "label"], sampler=sampler)

    de_dataset = de_dataset.map(operations=transform_img, input_columns="image", num_parallel_workers=8)
    de_dataset = de_dataset.map(operations=transform_label, input_columns="label", num_parallel_workers=8)

    columns_to_project = ["image", "label"]
    de_dataset = de_dataset.project(columns=columns_to_project)

    de_dataset = de_dataset.batch(per_batch_size, drop_remainder=drop_remainder)
    de_dataset = de_dataset.repeat(repeat_num)

    return de_dataset
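
# Usage sketch (illustrative): an evaluation pipeline over an ImageFolder-style
# directory; any mode other than 'train' selects the Decode/Resize/CenterCrop
# branch above.
#
#     eval_set = classification_dataset("/path/to/imagefolder_directory",
#                                       image_size=(224, 224), per_batch_size=32,
#                                       mode='eval', shuffle=False)
#     print(eval_set.get_dataset_size())  # number of batches per epoch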


class TxtDataset:
    """
    A dataset that reads (image path, label) pairs from a text file.

    Args:
        root (str): Directory that the image paths in the text file are relative to.
        txt_name (str): Path of the text file. Each line contains an image path and
            an integer label separated by a space.

    Returns:
        tuple, (img, label) for each index, where img is a PIL image in RGB mode.
    """
    def __init__(self, root, txt_name):
        super(TxtDataset, self).__init__()
        self.imgs = []
        self.labels = []
        with open(txt_name, "r") as fin:
            for line in fin:
                img_name, label = line.strip().split(' ')
                self.imgs.append(os.path.join(root, img_name))
                self.labels.append(int(label))

    def __getitem__(self, index):
        img = Image.open(self.imgs[index]).convert('RGB')
        return img, self.labels[index]

    def __len__(self):
        return len(self.imgs)
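
# Usage sketch (illustrative): the annotation file pairs an image path
# (relative to `root`) with an integer label, one pair per line, e.g.
#
#     dog/001.jpg 0
#     cat/002.jpg 1
#
#     dataset = TxtDataset("/path/to/images", "/path/to/train.txt")
#     img, label = dataset[0]  # PIL RGB image and its int label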