You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

dataset.py 2.7 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. Data operations, will be used in train.py and eval.py
  17. """
  18. import mindspore.common.dtype as mstype
  19. import mindspore.dataset.engine as de
  20. import mindspore.dataset.transforms.c_transforms as C2
  21. import mindspore.dataset.vision.c_transforms as C
  22. def create_dataset(dataset_path, config, do_train, repeat_num=1):
  23. """
  24. create a train or eval dataset
  25. Args:
  26. dataset_path(string): the path of dataset.
  27. config(dict): config of dataset.
  28. do_train(bool): whether dataset is used for train or eval.
  29. repeat_num(int): the repeat times of dataset. Default: 1.
  30. Returns:
  31. dataset
  32. """
  33. rank = config.rank
  34. group_size = config.group_size
  35. if group_size == 1:
  36. ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=config.work_nums, shuffle=True)
  37. else:
  38. ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=config.work_nums, shuffle=True,
  39. num_shards=group_size, shard_id=rank)
  40. # define map operations
  41. if do_train:
  42. trans = [
  43. C.RandomCropDecodeResize(config.image_size),
  44. C.RandomHorizontalFlip(prob=0.5),
  45. C.RandomColorAdjust(brightness=0.4, saturation=0.5) # fast mode
  46. # C.RandomColorAdjust(brightness=0.4, contrast=0.5, saturation=0.5, hue=0.2)
  47. ]
  48. else:
  49. trans = [
  50. C.Decode(),
  51. C.Resize(int(config.image_size / 0.875)),
  52. C.CenterCrop(config.image_size)
  53. ]
  54. trans += [
  55. C.Rescale(1.0 / 255.0, 0.0),
  56. C.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
  57. C.HWC2CHW()
  58. ]
  59. type_cast_op = C2.TypeCast(mstype.int32)
  60. ds = ds.map(input_columns="image", operations=trans, num_parallel_workers=config.work_nums)
  61. ds = ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=config.work_nums)
  62. # apply batch operations
  63. ds = ds.batch(config.batch_size, drop_remainder=True)
  64. # apply dataset repeat operation
  65. ds = ds.repeat(repeat_num)
  66. return ds