You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

dataset.py 6.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Data operations, will be used in train.py and eval.py"""
  16. import math
  17. import os
  18. import numpy as np
  19. import mindspore.dataset.vision.py_transforms as py_vision
  20. import mindspore.dataset.transforms.py_transforms as py_transforms
  21. import mindspore.dataset.transforms.c_transforms as c_transforms
  22. import mindspore.common.dtype as mstype
  23. import mindspore.dataset as ds
  24. from mindspore.communication.management import get_rank, get_group_size
  25. from mindspore.dataset.vision import Inter
# Values that should remain constant: they must match between training and
# evaluation preprocessing for the model to see consistently prepared inputs.
DEFAULT_CROP_PCT = 0.875
# Standard ImageNet per-channel normalization statistics (RGB order).
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
# Data preprocess configs for RandomResizedCrop:
# SCALE is the sampled area fraction range, RATIO the aspect-ratio range.
SCALE = (0.08, 1.0)
RATIO = (3./4., 4./3.)
# Fix the global dataset seed so shuffling/augmentation order is reproducible.
ds.config.set_seed(1)
  34. def split_imgs_and_labels(imgs, labels, batchInfo):
  35. """split data into labels and images"""
  36. ret_imgs = []
  37. ret_labels = []
  38. for i, image in enumerate(imgs):
  39. ret_imgs.append(image)
  40. ret_labels.append(labels[i])
  41. return np.array(ret_imgs), np.array(ret_labels)
  42. def create_dataset(batch_size, train_data_url='', workers=8, distributed=False,
  43. input_size=224, color_jitter=0.4):
  44. """Creat ImageNet training dataset"""
  45. if not os.path.exists(train_data_url):
  46. raise ValueError('Path not exists')
  47. decode_op = py_vision.Decode()
  48. type_cast_op = c_transforms.TypeCast(mstype.int32)
  49. random_resize_crop_bicubic = py_vision.RandomResizedCrop(size=(input_size, input_size),
  50. scale=SCALE, ratio=RATIO,
  51. interpolation=Inter.BICUBIC)
  52. random_horizontal_flip_op = py_vision.RandomHorizontalFlip(0.5)
  53. adjust_range = (max(0, 1 - color_jitter), 1 + color_jitter)
  54. random_color_jitter_op = py_vision.RandomColorAdjust(brightness=adjust_range,
  55. contrast=adjust_range,
  56. saturation=adjust_range)
  57. to_tensor = py_vision.ToTensor()
  58. normalize_op = py_vision.Normalize(
  59. IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
  60. # assemble all the transforms
  61. image_ops = py_transforms.Compose([decode_op, random_resize_crop_bicubic,
  62. random_horizontal_flip_op, random_color_jitter_op, to_tensor, normalize_op])
  63. rank_id = get_rank() if distributed else 0
  64. rank_size = get_group_size() if distributed else 1
  65. dataset_train = ds.ImageFolderDataset(train_data_url,
  66. num_parallel_workers=workers,
  67. shuffle=True,
  68. num_shards=rank_size,
  69. shard_id=rank_id)
  70. dataset_train = dataset_train.map(input_columns=["image"],
  71. operations=image_ops,
  72. num_parallel_workers=workers)
  73. dataset_train = dataset_train.map(input_columns=["label"],
  74. operations=type_cast_op,
  75. num_parallel_workers=workers)
  76. # batch dealing
  77. ds_train = dataset_train.batch(batch_size,
  78. per_batch_map=split_imgs_and_labels,
  79. input_columns=["image", "label"],
  80. num_parallel_workers=2,
  81. drop_remainder=True)
  82. ds_train = ds_train.repeat(1)
  83. return ds_train
  84. def create_dataset_val(batch_size=128, val_data_url='', workers=8, distributed=False,
  85. input_size=224):
  86. """Creat ImageNet validation dataset"""
  87. if not os.path.exists(val_data_url):
  88. raise ValueError('Path not exists')
  89. rank_id = get_rank() if distributed else 0
  90. rank_size = get_group_size() if distributed else 1
  91. dataset = ds.ImageFolderDataset(val_data_url, num_parallel_workers=workers,
  92. num_shards=rank_size, shard_id=rank_id)
  93. scale_size = None
  94. if isinstance(input_size, tuple):
  95. assert len(input_size) == 2
  96. if input_size[-1] == input_size[-2]:
  97. scale_size = int(math.floor(input_size[0] / DEFAULT_CROP_PCT))
  98. else:
  99. scale_size = tuple([int(x / DEFAULT_CROP_PCT) for x in input_size])
  100. else:
  101. scale_size = int(math.floor(input_size / DEFAULT_CROP_PCT))
  102. type_cast_op = c_transforms.TypeCast(mstype.int32)
  103. decode_op = py_vision.Decode()
  104. resize_op = py_vision.Resize(size=scale_size, interpolation=Inter.BICUBIC)
  105. center_crop = py_vision.CenterCrop(size=input_size)
  106. to_tensor = py_vision.ToTensor()
  107. normalize_op = py_vision.Normalize(
  108. IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
  109. image_ops = py_transforms.Compose([decode_op, resize_op, center_crop,
  110. to_tensor, normalize_op])
  111. dataset = dataset.map(input_columns=["label"], operations=type_cast_op,
  112. num_parallel_workers=workers)
  113. dataset = dataset.map(input_columns=["image"], operations=image_ops,
  114. num_parallel_workers=workers)
  115. dataset = dataset.batch(batch_size, per_batch_map=split_imgs_and_labels,
  116. input_columns=["image", "label"],
  117. num_parallel_workers=2,
  118. drop_remainder=True)
  119. dataset = dataset.repeat(1)
  120. return dataset