
dataset.py 6.9 kB

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
dataset processing.
"""
import os
from mindspore.common import dtype as mstype
import mindspore.dataset as de
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as V_C
from PIL import Image, ImageFile
from src.utils.sampler import DistributedSampler
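# Tolerate truncated image files: with this flag set, PIL loads what it can
# from a cut-short JPEG instead of raising an error mid-epoch.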
ImageFile.LOAD_TRUNCATED_IMAGES = True


class TxtDataset():
    """
    Create a dataset from a label text file.

    Args:
        root (str): Directory that the image paths in the text file are relative to.
        txt_name (str): Path of the text file; each line holds an image path and
            an integer label separated by a space.

    Returns:
        A dataset yielding (image, label) pairs.
    """
    def __init__(self, root, txt_name):
        super(TxtDataset, self).__init__()
        self.imgs = []
        self.labels = []
        with open(txt_name, "r") as fin:
            for line in fin:
                img_name, label = line.strip().split(' ')
                self.imgs.append(os.path.join(root, img_name))
                self.labels.append(int(label))

    def __getitem__(self, index):
        img = Image.open(self.imgs[index]).convert('RGB')
        return img, self.labels[index]

    def __len__(self):
        return len(self.imgs)
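
# The label file consumed by TxtDataset holds one "<image path> <label>" pair
# per line; for example (paths and labels below are illustrative only):
#   train/class_a/img_0001.jpg 0
#   train/class_b/img_0002.jpg 1
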
def classification_dataset(data_dir, image_size, per_batch_size, max_epoch, rank, group_size,
                           mode='train',
                           input_mode='folder',
                           root='',
                           num_parallel_workers=None,
                           shuffle=None,
                           sampler=None,
                           class_indexing=None,
                           drop_remainder=True,
                           transform=None,
                           target_transform=None):
    """
    A function that returns a dataset for classification. The input mode can be "folder" or "txt".
    With "folder", all images within one folder have the same label; with "txt", the paths of all
    images are listed in a text file.

    Args:
        data_dir (str): Path to the root directory that contains the dataset when
            input_mode="folder", or path of the text file that lists every image
            of the dataset when input_mode="txt".
        image_size (Union[int, list]): Size of the input images.
        per_batch_size (int): Batch size of every step during training.
        max_epoch (int): Number of epochs the dataset is repeated for.
        rank (int): The shard ID within num_shards (default=None).
        group_size (int): Number of shards that the dataset should be divided
            into (default=None).
        mode (str): "train" or others. Default: "train".
        input_mode (str): The form of the input dataset, "folder" or "txt". Default: "folder".
        root (str): Root directory of the images when input_mode="txt". Default: ''.
        num_parallel_workers (int): Number of workers to read the data. Default: None.
        shuffle (bool): Whether or not to perform shuffle on the dataset
            (default=None, performs shuffle).
        sampler (Sampler): Object used to choose samples from the dataset. Default: None.
        class_indexing (dict): A str-to-int mapping from folder name to index
            (default=None, the folder names will be sorted
            alphabetically and each class will be given a
            unique index starting from 0).
        drop_remainder (bool): Whether to drop the last incomplete batch. Default: True.
        transform (list): Image transforms to apply; if None, a default pipeline is used. Default: None.
        target_transform (list): Label transforms to apply; if None, labels are cast to int32. Default: None.

    Examples:
        >>> from mindvision.common.datasets.classification import classification_dataset
        >>> # Path to an image-folder directory; it must contain sub-directories holding the images.
        >>> dataset_dir = "/path/to/imagefolder_directory"
        >>> de_dataset = classification_dataset(dataset_dir, image_size=[224, 224],
        >>>                                     per_batch_size=64, max_epoch=100,
        >>>                                     rank=0, group_size=4)
        >>> # Path of the text file that lists every image of the dataset.
        >>> dataset_dir = "/path/to/dataset/images/train.txt"
        >>> images_dir = "/path/to/dataset/images"
        >>> de_dataset = classification_dataset(dataset_dir, image_size=[224, 224],
        >>>                                     per_batch_size=64, max_epoch=100,
        >>>                                     rank=0, group_size=4,
        >>>                                     input_mode="txt", root=images_dir)
    """
    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
    if transform is None:
        if mode == 'train':
            transform_img = [
                V_C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
                V_C.RandomHorizontalFlip(prob=0.5),
                V_C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4),
                V_C.Normalize(mean=mean, std=std),
                V_C.HWC2CHW()
            ]
        else:
            transform_img = [
                V_C.Decode(),
                V_C.Resize((256, 256)),
                V_C.CenterCrop(image_size),
                V_C.Normalize(mean=mean, std=std),
                V_C.HWC2CHW()
            ]
    else:
        transform_img = transform
    if target_transform is None:
        transform_label = [C.TypeCast(mstype.int32)]
    else:
        transform_label = target_transform
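    # "folder" mode shards through ImageFolderDataset's num_shards/shard_id;
    # "txt" mode wraps TxtDataset in a GeneratorDataset and shards with the
    # project's DistributedSampler instead.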
    if input_mode == 'folder':
        de_dataset = de.ImageFolderDataset(data_dir, num_parallel_workers=num_parallel_workers,
                                           shuffle=shuffle, sampler=sampler, class_indexing=class_indexing,
                                           num_shards=group_size, shard_id=rank)
    else:
        dataset = TxtDataset(root, data_dir)
        sampler = DistributedSampler(dataset, rank, group_size, shuffle=shuffle)
        de_dataset = de.GeneratorDataset(dataset, ["image", "label"], sampler=sampler)
    de_dataset = de_dataset.map(input_columns="image", num_parallel_workers=num_parallel_workers,
                                operations=transform_img)
    de_dataset = de_dataset.map(input_columns="label", num_parallel_workers=num_parallel_workers,
                                operations=transform_label)
    columns_to_project = ["image", "label"]
    de_dataset = de_dataset.project(columns=columns_to_project)
    de_dataset = de_dataset.batch(per_batch_size, drop_remainder=drop_remainder)
    de_dataset = de_dataset.repeat(max_epoch)
    return de_dataset
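
# A minimal usage sketch (assuming a single device, i.e. rank=0/group_size=1,
# and an image-folder layout; the path below is a placeholder):
#
#     de_dataset = classification_dataset('/path/to/imagefolder_directory',
#                                         image_size=[224, 224], per_batch_size=64,
#                                         max_epoch=1, rank=0, group_size=1)
#     for batch in de_dataset.create_dict_iterator():
#         images, labels = batch['image'], batch['label']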