You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

create_dataset.py 3.3 kB

4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Create Dataset."""
  16. import os
  17. import argparse
  18. import glob
  19. import numpy as np
  20. import PIL.Image as pil_image
  21. from PIL import ImageFile
  22. from mindspore.mindrecord import FileWriter
  23. from src.config import srcnn_cfg as config
  24. from src.utils import convert_rgb_to_y
  25. ImageFile.LOAD_TRUNCATED_IMAGES = True
  26. parser = argparse.ArgumentParser(description='Generate dataset file.')
  27. parser.add_argument("--src_folder", type=str, required=True, help="Raw data folder.")
  28. parser.add_argument("--output_folder", type=str, required=True, help="Dataset output path.")
  29. if __name__ == '__main__':
  30. args, _ = parser.parse_known_args()
  31. if not os.path.exists(args.output_folder):
  32. os.makedirs(args.output_folder)
  33. prefix = "srcnn.mindrecord"
  34. file_num = 32
  35. patch_size = config.patch_size
  36. stride = config.stride
  37. scale = config.scale
  38. mindrecord_path = os.path.join(args.output_folder, prefix)
  39. writer = FileWriter(mindrecord_path, file_num)
  40. srcnn_json = {
  41. "lr": {"type": "float32", "shape": [1, patch_size, patch_size]},
  42. "hr": {"type": "float32", "shape": [1, patch_size, patch_size]},
  43. }
  44. writer.add_schema(srcnn_json, "srcnn_json")
  45. image_list = []
  46. file_list = sorted(os.listdir(args.src_folder))
  47. for file_name in file_list:
  48. path = os.path.join(args.src_folder, file_name)
  49. if os.path.isfile(path):
  50. image_list.append(path)
  51. else:
  52. for image_path in sorted(glob.glob('{}/*'.format(path))):
  53. image_list.append(image_path)
  54. print("image_list size ", len(image_list), flush=True)
  55. for path in image_list:
  56. hr = pil_image.open(path).convert('RGB')
  57. hr_width = (hr.width // scale) * scale
  58. hr_height = (hr.height // scale) * scale
  59. hr = hr.resize((hr_width, hr_height), resample=pil_image.BICUBIC)
  60. lr = hr.resize((hr_width // scale, hr_height // scale), resample=pil_image.BICUBIC)
  61. lr = lr.resize((lr.width * scale, lr.height * scale), resample=pil_image.BICUBIC)
  62. hr = np.array(hr).astype(np.float32)
  63. lr = np.array(lr).astype(np.float32)
  64. hr = convert_rgb_to_y(hr)
  65. lr = convert_rgb_to_y(lr)
  66. for i in range(0, lr.shape[0] - patch_size + 1, stride):
  67. for j in range(0, lr.shape[1] - patch_size + 1, stride):
  68. lr_res = np.expand_dims(lr[i:i + patch_size, j:j + patch_size] / 255., 0)
  69. hr_res = np.expand_dims(hr[i:i + patch_size, j:j + patch_size] / 255., 0)
  70. row = {"lr": lr_res, "hr": hr_res}
  71. writer.write_raw_data([row])
  72. writer.commit()
  73. print("Finish!")