You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

dataset.py 5.1 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. create train or eval dataset.
  17. """
  18. import os
  19. import numpy as np
  20. from mindspore import Tensor
  21. from mindspore.train.model import Model
  22. import mindspore.common.dtype as mstype
  23. import mindspore.dataset.engine as de
  24. import mindspore.dataset.vision.c_transforms as C
  25. import mindspore.dataset.transforms.c_transforms as C2
  26. def create_dataset(dataset_path, do_train, config, repeat_num=1):
  27. """
  28. create a train or eval dataset
  29. Args:
  30. dataset_path(string): the path of dataset.
  31. do_train(bool): whether dataset is used for train or eval.
  32. config(struct): the config of train and eval in diffirent platform.
  33. repeat_num(int): the repeat times of dataset. Default: 1.
  34. Returns:
  35. dataset
  36. """
  37. if config.platform == "Ascend":
  38. rank_size = int(os.getenv("RANK_SIZE", '1'))
  39. rank_id = int(os.getenv("RANK_ID", '0'))
  40. if rank_size == 1:
  41. ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)
  42. else:
  43. ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True,
  44. num_shards=rank_size, shard_id=rank_id)
  45. elif config.platform == "GPU":
  46. if do_train:
  47. if config.run_distribute:
  48. from mindspore.communication.management import get_rank, get_group_size
  49. ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True,
  50. num_shards=get_group_size(), shard_id=get_rank())
  51. else:
  52. ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)
  53. else:
  54. ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)
  55. elif config.platform == "CPU":
  56. ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)
  57. resize_height = config.image_height
  58. resize_width = config.image_width
  59. buffer_size = 1000
  60. # define map operations
  61. decode_op = C.Decode()
  62. resize_crop_op = C.RandomCropDecodeResize(resize_height, scale=(0.08, 1.0), ratio=(0.75, 1.333))
  63. horizontal_flip_op = C.RandomHorizontalFlip(prob=0.5)
  64. resize_op = C.Resize((256, 256))
  65. center_crop = C.CenterCrop(resize_width)
  66. rescale_op = C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4)
  67. normalize_op = C.Normalize(mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
  68. std=[0.229 * 255, 0.224 * 255, 0.225 * 255])
  69. change_swap_op = C.HWC2CHW()
  70. if do_train:
  71. trans = [resize_crop_op, horizontal_flip_op, rescale_op, normalize_op, change_swap_op]
  72. else:
  73. trans = [decode_op, resize_op, center_crop, normalize_op, change_swap_op]
  74. type_cast_op = C2.TypeCast(mstype.int32)
  75. ds = ds.map(operations=trans, input_columns="image", num_parallel_workers=8)
  76. ds = ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8)
  77. # apply shuffle operations
  78. ds = ds.shuffle(buffer_size=buffer_size)
  79. # apply batch operations
  80. ds = ds.batch(config.batch_size, drop_remainder=True)
  81. # apply dataset repeat operation
  82. ds = ds.repeat(repeat_num)
  83. step_size = ds.get_dataset_size()
  84. if step_size == 0:
  85. raise ValueError("The step_size of dataset is zero. Check if the images of train dataset is more than batch_\
  86. size in config.py")
  87. return ds, step_size
  88. def extract_features(net, dataset_path, config):
  89. features_folder = dataset_path + '_features'
  90. if not os.path.exists(features_folder):
  91. os.makedirs(features_folder)
  92. dataset = create_dataset(dataset_path=dataset_path,
  93. do_train=False,
  94. config=config,
  95. repeat_num=1)
  96. step_size = dataset.get_dataset_size()
  97. model = Model(net)
  98. for i, data in enumerate(dataset.create_dict_iterator(output_numpy=True)):
  99. features_path = os.path.join(features_folder, f"feature_{i}.npy")
  100. label_path = os.path.join(features_folder, f"label_{i}.npy")
  101. if not os.path.exists(features_path or not os.path.exists(label_path)):
  102. image = data["image"]
  103. label = data["label"]
  104. features = model.predict(Tensor(image))
  105. np.save(features_path, features.asnumpy())
  106. np.save(label_path, label)
  107. print(f"Complete the batch {i}/{step_size}")
  108. return step_size