|
- # Copyright 2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """Create Dataset."""
- import os
- import argparse
- import glob
- import numpy as np
- import PIL.Image as pil_image
- from PIL import ImageFile
-
- from mindspore.mindrecord import FileWriter
-
- from src.config import srcnn_cfg as config
- from src.utils import convert_rgb_to_y
# Let Pillow load truncated image files instead of rejecting them.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Command-line interface: both folders must be supplied explicitly.
parser = argparse.ArgumentParser(description='Generate dataset file.')
parser.add_argument(
    "--src_folder",
    required=True,
    type=str,
    help="Raw data folder.",
)
parser.add_argument(
    "--output_folder",
    required=True,
    type=str,
    help="Dataset output path.",
)
-
# Number of rows buffered before each FileWriter.write_raw_data call; calling
# it once per patch is very slow, while one call per image can hold tens of
# thousands of patches in memory for large images.
_WRITE_BATCH_SIZE = 1024


def _list_images(src_folder):
    """Return the files directly in ``src_folder`` and in its immediate sub-folders.

    Entries are visited in sorted order so the generated dataset is
    reproducible. Hidden entries (names starting with '.') are skipped at
    both levels, and directories nested inside sub-folders are ignored.
    """
    images = []
    for name in sorted(os.listdir(src_folder)):
        if name.startswith('.'):
            # e.g. .DS_Store; the sub-folder glob below skips dotfiles too.
            continue
        path = os.path.join(src_folder, name)
        if os.path.isfile(path):
            images.append(path)
        elif os.path.isdir(path):
            # glob.escape: folder names containing '[', '*' or '?' must be
            # matched literally, not interpreted as patterns.
            pattern = os.path.join(glob.escape(path), '*')
            for sub_path in sorted(glob.glob(pattern)):
                if os.path.isfile(sub_path):
                    images.append(sub_path)
    return images


def _load_y_pair(path, scale, patch_size):
    """Load one image and build its (lr, hr) Y-channel pair.

    The HR image is bicubically resized so both sides are multiples of
    ``scale``. The LR image is that HR image downscaled by ``scale`` and then
    upscaled back, so both arrays have the same height and width.

    Returns:
        ``(lr, hr)`` produced by ``convert_rgb_to_y`` on float32 RGB arrays,
        or ``None`` when the scale-aligned size is smaller than
        ``patch_size``. Such an image yields no patches, and a side shorter
        than ``scale`` would otherwise make the downscale to zero size fail.
    """
    # Context manager closes the file handle; convert() returns an
    # independent in-memory image that stays valid after the file is closed.
    with pil_image.open(path) as img:
        hr = img.convert('RGB')
    hr_width = (hr.width // scale) * scale
    hr_height = (hr.height // scale) * scale
    if hr_width < patch_size or hr_height < patch_size:
        return None
    hr = hr.resize((hr_width, hr_height), resample=pil_image.BICUBIC)
    lr = hr.resize((hr_width // scale, hr_height // scale), resample=pil_image.BICUBIC)
    lr = lr.resize((lr.width * scale, lr.height * scale), resample=pil_image.BICUBIC)
    hr = convert_rgb_to_y(np.array(hr).astype(np.float32))
    lr = convert_rgb_to_y(np.array(lr).astype(np.float32))
    return lr, hr


def _iter_patches(lr, hr, patch_size, stride):
    """Yield MindRecord rows of aligned ``patch_size`` x ``patch_size`` patches.

    Windows are taken row by row, left to right, with step ``stride``. Each
    value is divided by 255 and given a leading axis, so each patch matches
    the float32 ``[1, patch_size, patch_size]`` schema.
    """
    rows, cols = lr.shape[0], lr.shape[1]
    for i in range(0, rows - patch_size + 1, stride):
        for j in range(0, cols - patch_size + 1, stride):
            window = (slice(i, i + patch_size), slice(j, j + patch_size))
            yield {
                "lr": np.expand_dims(lr[window] / 255., 0).astype(np.float32, copy=False),
                "hr": np.expand_dims(hr[window] / 255., 0).astype(np.float32, copy=False),
            }


if __name__ == '__main__':
    args, _ = parser.parse_known_args()
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(args.output_folder, exist_ok=True)
    prefix = "srcnn.mindrecord"
    file_num = 32
    patch_size = config.patch_size
    stride = config.stride
    scale = config.scale
    mindrecord_path = os.path.join(args.output_folder, prefix)
    writer = FileWriter(mindrecord_path, file_num)

    # Each row holds one LR/HR patch pair with a single leading channel axis.
    srcnn_json = {
        "lr": {"type": "float32", "shape": [1, patch_size, patch_size]},
        "hr": {"type": "float32", "shape": [1, patch_size, patch_size]},
    }
    writer.add_schema(srcnn_json, "srcnn_json")

    image_list = _list_images(args.src_folder)
    print("image_list size ", len(image_list), flush=True)

    batch = []
    for path in image_list:
        pair = _load_y_pair(path, scale, patch_size)
        if pair is None:
            continue
        lr, hr = pair
        for row in _iter_patches(lr, hr, patch_size, stride):
            batch.append(row)
            if len(batch) >= _WRITE_BATCH_SIZE:
                writer.write_raw_data(batch)
                batch = []
    if batch:
        # Flush the final partial batch.
        writer.write_raw_data(batch)

    writer.commit()
    print("Finish!")
|