
md_dataset.py 3.7 kB

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dataset module."""
from PIL import Image
import mindspore.dataset as de
import mindspore.dataset.transforms.vision.c_transforms as C
import numpy as np

from .ei_dataset import HwVocRawDataset
from .utils import custom_transforms as tr

class DataTransform:
    """Transform dataset for DeepLabV3."""

    def __init__(self, args, usage):
        self.args = args
        self.usage = usage

    def __call__(self, image, label):
        if self.usage == "train":
            return self._train(image, label)
        if self.usage == "eval":
            return self._eval(image, label)
        return None

    def _train(self, image, label):
        """
        Process training data.

        Args:
            image (list): Image data.
            label (list): Dataset label.
        """
        image = Image.fromarray(image)
        label = Image.fromarray(label)
        rsc_tr = tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size)
        image, label = rsc_tr(image, label)
        rhf_tr = tr.RandomHorizontalFlip()
        image, label = rhf_tr(image, label)
        image = np.array(image).astype(np.float32)
        label = np.array(label).astype(np.float32)
        return image, label

    def _eval(self, image, label):
        """
        Process eval data.

        Args:
            image (list): Image data.
            label (list): Dataset label.
        """
        image = Image.fromarray(image)
        label = Image.fromarray(label)
        fsc_tr = tr.FixScaleCrop(crop_size=self.args.crop_size)
        image, label = fsc_tr(image, label)
        image = np.array(image).astype(np.float32)
        label = np.array(label).astype(np.float32)
        return image, label

def create_dataset(args, data_url, epoch_num=1, batch_size=1, usage="train", shuffle=True):
    """
    Create Dataset for DeepLabV3.

    Args:
        args (dict): Train parameters.
        data_url (str): Dataset path.
        epoch_num (int): Epoch of dataset (default=1).
        batch_size (int): Batch size of dataset (default=1).
        usage (str): Whether the dataset is used for train or eval (default='train').
        shuffle (bool): Whether to shuffle the training data (default=True).

    Returns:
        Dataset.
    """
    # create iterable raw dataset
    dataset = HwVocRawDataset(data_url, usage=usage)
    dataset_len = len(dataset)

    # wrap with GeneratorDataset
    dataset = de.GeneratorDataset(dataset, ["image", "label"], sampler=None)
    dataset.set_dataset_size(dataset_len)
    dataset = dataset.map(input_columns=["image", "label"], operations=DataTransform(args, usage=usage))

    channelswap_op = C.HWC2CHW()
    dataset = dataset.map(input_columns="image", operations=channelswap_op)

    # 1464 samples / batch_size 8 = 183 batches
    # epoch_num is num of steps
    # 3658 steps / 183 = 20 epochs
    if usage == "train" and shuffle:
        dataset = dataset.shuffle(1464)
    dataset = dataset.batch(batch_size, drop_remainder=(usage == "train"))
    dataset = dataset.repeat(count=epoch_num)
    dataset.map_model = 4

    return dataset
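
# ----------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). It assumes a VOC-style
# dataset directory and an `args` object exposing the `base_size` and
# `crop_size` fields read by the transforms above; the path and values shown
# are placeholders.
#
#     from argparse import Namespace
#
#     args = Namespace(base_size=513, crop_size=513)  # assumed scale/crop settings
#     train_ds = create_dataset(args, "./VOC2012", epoch_num=1, batch_size=8, usage="train")
#     for batch in train_ds.create_dict_iterator():
#         image, label = batch["image"], batch["label"]  # image is CHW float32 after HWC2CHW
#         break
# ----------------------------------------------------------------------------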