From: @fuyu-wang Reviewed-by: Signed-off-by:pull/12807/MERGE
| @@ -0,0 +1,175 @@ | |||||
| # Contents | |||||
| - [SRCNN Description](#srcnn-description) | |||||
| - [Model Architecture](#model-architecture) | |||||
| - [Dataset](#dataset) | |||||
| - [Environment Requirements](#environment-requirements) | |||||
| - [Quick Start](#quick-start) | |||||
| - [Script Description](#script-description) | |||||
| - [Script and Sample Code](#script-and-sample-code) | |||||
| - [Script Parameters](#script-parameters) | |||||
| - [Training Process](#training-process) | |||||
| - [Evaluation Process](#evaluation-process) | |||||
| - [Model Description](#model-description) | |||||
| - [Performance](#performance) | |||||
| - [Training Performance](#evaluation-performance) | |||||
| - [Inference Performance](#evaluation-performance) | |||||
| - [ModelZoo Homepage](#modelzoo-homepage) | |||||
| # [NASNet Description](#contents) | |||||
| SRCNN learns an end-to-end mapping between low- and high-resolution images, with little extra pre/post-processing beyond the optimization. With a lightweight structure, the SRCNN has achieved superior performance than the state-of-the-art methods. | |||||
| [Paper](https://arxiv.org/pdf/1501.00092.pdf): Chao Dong, Chen Change Loy, Kaiming He, Xiaoou Tang. Image Super-Resolution Using Deep Convolutional Networks. 2014. | |||||
| # [Model architecture](#contents) | |||||
| The overall network architecture of SRCNN is show below: | |||||
| [Link](https://arxiv.org/pdf/1501.00092.pdf) | |||||
| # [Dataset](#contents) | |||||
| - Training Dataset | |||||
| - ILSVRC2013_DET_train: 395918 images, 200 classes | |||||
| - Evaluation Dataset | |||||
| - Set5: 5 images | |||||
| - Set14: 14 images | |||||
| - Set5 & Set14 download url: http://vllab.ucmerced.edu/wlai24/LapSRN/results/SR_testing_datasets.zip | |||||
| - BSDS200: 200 images | |||||
| - BSDS200 download url: http://vllab.ucmerced.edu/wlai24/LapSRN/results/SR_training_datasets.zip | |||||
| - Data format: RGB images. | |||||
| - Note: Data will be processed in src/dataset.py | |||||
| # [Environment Requirements](#contents) | |||||
| - Hardware GPU | |||||
| - Prepare hardware environment with GPU processor. | |||||
| - Framework | |||||
| - [MindSpore](https://www.mindspore.cn/install/en) | |||||
| - For more information, please check the resources below: | |||||
| - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html) | |||||
| - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html) | |||||
| # [Script description](#contents) | |||||
| ## [Script and sample code](#contents) | |||||
| ```python | |||||
| . | |||||
| └─srcnn | |||||
| ├─README.md | |||||
| ├─scripts | |||||
| ├─run_distribute_train_gpu.sh # launch distributed training with gpu platform | |||||
| └─run_eval_gpu.sh # launch evaluating with gpu platform | |||||
| ├─src | |||||
| ├─config.py # parameter configuration | |||||
| ├─dataset.py # data preprocessing | |||||
| ├─metric.py # accuracy metric | |||||
| ├─utils.py # some functions which is commonly used | |||||
| ├─srcnn.py # network definition | |||||
| ├─create_dataset.py # generating mindrecord training dataset | |||||
| ├─eval.py # eval net | |||||
| └─train.py # train net | |||||
| ``` | |||||
| ## [Script Parameters](#contents) | |||||
| Parameters for both training and evaluating can be set in config.py. | |||||
| ```python | |||||
| 'lr': 1e-4, # learning rate | |||||
| 'patch_size': 33, # patch_size | |||||
| 'stride': 99, # stride | |||||
| 'scale': 2, # image scale | |||||
| 'epoch_size': 20, # total epoch numbers | |||||
| 'batch_size': 16, # input batchsize | |||||
| 'save_checkpoint': True, # whether saving ckpt file | |||||
| 'keep_checkpoint_max': 10, # max numbers to keep checkpoints | |||||
| 'save_checkpoint_path': 'outputs/' # save checkpoint path | |||||
| ``` | |||||
| ## [Training Process](#contents) | |||||
| ### Dataset | |||||
| To create dataset, download the training dataset firstly and then convert them to mindrecord files. We can deal with it as follows. | |||||
| ```shell | |||||
| python create_dataset.py --src_folder=/dataset/ILSVRC2013_DET_train --output_folder=/dataset/mindrecord_dir | |||||
| ``` | |||||
| ### Usage | |||||
| ```bash | |||||
| GPU: | |||||
| sh run_distribute_train_gpu.sh DEVICE_NUM VISIABLE_DEVICES(0,1,2,3,4,5,6,7) DATASET_PATH | |||||
| ``` | |||||
| ### Launch | |||||
| ```bash | |||||
| # distributed training example(8p) for GPU | |||||
| sh run_distribute_train_gpu.sh 8 0,1,2,3,4,5,6,7 /dataset/train | |||||
| # standalone training example for GPU | |||||
| sh run_distribute_train_gpu.sh 1 0 /dataset/train | |||||
| ``` | |||||
| You can find checkpoint file together with result in log. | |||||
| ## [Evaluation Process](#contents) | |||||
| ### Usage | |||||
| ```bash | |||||
| # Evaluation | |||||
| sh run_eval_gpu.sh DEVICE_ID DATASET_PATH CHECKPOINT_PATH | |||||
| ``` | |||||
| ### Launch | |||||
| ```bash | |||||
| # Evaluation with checkpoint | |||||
| sh run_eval_gpu.sh 1 /dataset/val /ckpt_dir/srcnn-20_*.ckpt | |||||
| ``` | |||||
| ### Result | |||||
| Evaluation result will be stored in the scripts path. Under this, you can find result like the followings in log. | |||||
| result {'PSNR': 36.72421418219669} | |||||
| # [Model description](#contents) | |||||
| ## [Performance](#contents) | |||||
| ### Training Performance | |||||
| | Parameters | SRCNN | | |||||
| | -------------------------- | ------------------------- | | |||||
| | Resource | NV PCIE V100-32G | | |||||
| | uploaded Date | 03/02/2021 | | |||||
| | MindSpore Version | master | | |||||
| | Dataset | ImageNet2013 scale:2 | | |||||
| | Training Parameters | src/config.py | | |||||
| | Optimizer | Adam | | |||||
| | Loss Function | MSELoss | | |||||
| | Loss | 0.00179 | | |||||
| | Total time | 1 h 8ps | | |||||
| | Checkpoint for Fine tuning | 671 K(.ckpt file) | | |||||
| ### Inference Performance | |||||
| | Parameters | | | |||||
| | -------------------------- | -------------------------- | | |||||
| | Resource | NV PCIE V100-32G | | |||||
| | uploaded Date | 03/02/2021 | | |||||
| | MindSpore Version | master | | |||||
| | Dataset | Set5/Set14/BSDS200 scale:2 | | |||||
| | batch_size | 1 | | |||||
| | PSNR | 36.72/32.58/33.81 | | |||||
| # [ModelZoo Homepage](#contents) | |||||
| Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo). | |||||
| @@ -0,0 +1,82 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Create Dataset.""" | |||||
| import os | |||||
| import argparse | |||||
| import glob | |||||
| import numpy as np | |||||
| import PIL.Image as pil_image | |||||
| from PIL import ImageFile | |||||
| from mindspore.mindrecord import FileWriter | |||||
| from src.config import srcnn_cfg as config | |||||
| from src.utils import convert_rgb_to_y | |||||
| ImageFile.LOAD_TRUNCATED_IMAGES = True | |||||
| parser = argparse.ArgumentParser(description='Generate dataset file.') | |||||
| parser.add_argument("--src_folder", type=str, required=True, help="Raw data folder.") | |||||
| parser.add_argument("--output_folder", type=str, required=True, help="Dataset output path.") | |||||
| if __name__ == '__main__': | |||||
| args, _ = parser.parse_known_args() | |||||
| if not os.path.exists(args.output_folder): | |||||
| os.makedirs(args.output_folder) | |||||
| prefix = "srcnn.mindrecord" | |||||
| file_num = 32 | |||||
| patch_size = config.patch_size | |||||
| stride = config.stride | |||||
| scale = config.scale | |||||
| mindrecord_path = os.path.join(args.output_folder, prefix) | |||||
| writer = FileWriter(mindrecord_path, file_num) | |||||
| srcnn_json = { | |||||
| "lr": {"type": "float32", "shape": [1, patch_size, patch_size]}, | |||||
| "hr": {"type": "float32", "shape": [1, patch_size, patch_size]}, | |||||
| } | |||||
| writer.add_schema(srcnn_json, "srcnn_json") | |||||
| image_list = [] | |||||
| file_list = sorted(os.listdir(args.src_folder)) | |||||
| for file_name in file_list: | |||||
| path = os.path.join(args.src_folder, file_name) | |||||
| if os.path.isfile(path): | |||||
| image_list.append(path) | |||||
| else: | |||||
| for image_path in sorted(glob.glob('{}/*'.format(path))): | |||||
| image_list.append(image_path) | |||||
| print("image_list size ", len(image_list), flush=True) | |||||
| for path in image_list: | |||||
| hr = pil_image.open(path).convert('RGB') | |||||
| hr_width = (hr.width // scale) * scale | |||||
| hr_height = (hr.height // scale) * scale | |||||
| hr = hr.resize((hr_width, hr_height), resample=pil_image.BICUBIC) | |||||
| lr = hr.resize((hr_width // scale, hr_height // scale), resample=pil_image.BICUBIC) | |||||
| lr = lr.resize((lr.width * scale, lr.height * scale), resample=pil_image.BICUBIC) | |||||
| hr = np.array(hr).astype(np.float32) | |||||
| lr = np.array(lr).astype(np.float32) | |||||
| hr = convert_rgb_to_y(hr) | |||||
| lr = convert_rgb_to_y(lr) | |||||
| for i in range(0, lr.shape[0] - patch_size + 1, stride): | |||||
| for j in range(0, lr.shape[1] - patch_size + 1, stride): | |||||
| lr_res = np.expand_dims(lr[i:i + patch_size, j:j + patch_size] / 255., 0) | |||||
| hr_res = np.expand_dims(hr[i:i + patch_size, j:j + patch_size] / 255., 0) | |||||
| row = {"lr": lr_res, "hr": hr_res} | |||||
| writer.write_raw_data([row]) | |||||
| writer.commit() | |||||
| print("Finish!") | |||||
| @@ -0,0 +1,55 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """srcnn evaluation""" | |||||
| import argparse | |||||
| import mindspore as ms | |||||
| import mindspore.nn as nn | |||||
| from mindspore import context, Tensor | |||||
| from mindspore.train.model import Model | |||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||||
| from src.config import srcnn_cfg as config | |||||
| from src.dataset import create_eval_dataset | |||||
| from src.srcnn import SRCNN | |||||
| from src.metric import SRCNNpsnr | |||||
| if __name__ == '__main__': | |||||
| parser = argparse.ArgumentParser(description="srcnn eval") | |||||
| parser.add_argument('--dataset_path', type=str, required=True, help="Dataset, default is None.") | |||||
| parser.add_argument('--checkpoint_path', type=str, required=True, help="checkpoint file path") | |||||
| parser.add_argument('--device_target', type=str, default='GPU', choices=("GPU"), | |||||
| help="Device target, support GPU.") | |||||
| args, _ = parser.parse_known_args() | |||||
| if args.device_target == "GPU": | |||||
| context.set_context(mode=context.GRAPH_MODE, | |||||
| device_target=args.device_target, | |||||
| save_graphs=False) | |||||
| else: | |||||
| raise ValueError("Unsupported device target.") | |||||
| eval_ds = create_eval_dataset(args.dataset_path) | |||||
| net = SRCNN() | |||||
| lr = Tensor(config.lr, ms.float32) | |||||
| opt = nn.Adam(params=net.trainable_params(), learning_rate=lr, eps=1e-07) | |||||
| loss = nn.MSELoss(reduction='mean') | |||||
| param_dict = load_checkpoint(args.checkpoint_path) | |||||
| load_param_into_net(net, param_dict) | |||||
| net.set_train(False) | |||||
| model = Model(net, loss_fn=loss, optimizer=opt, metrics={'PSNR': SRCNNpsnr()}) | |||||
| res = model.eval(eval_ds, dataset_sink_mode=False) | |||||
| print("result ", res) | |||||
| @@ -0,0 +1 @@ | |||||
| Pillow | |||||
| @@ -0,0 +1,66 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| if [ $# -lt 3 ] | |||||
| then | |||||
| echo "Usage: sh run_distribute_train_gpu.sh [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [PRE_TRAINED](optional)" | |||||
| exit 1 | |||||
| fi | |||||
| if [ $1 -lt 1 ] && [ $1 -gt 8 ] | |||||
| then | |||||
| echo "error: DEVICE_NUM=$1 is not in (1-8)" | |||||
| exit 1 | |||||
| fi | |||||
| export DEVICE_NUM=$1 | |||||
| export RANK_SIZE=$1 | |||||
| BASEPATH=$(cd "`dirname $0`" || exit; pwd) | |||||
| export PYTHONPATH=${BASEPATH}:$PYTHONPATH | |||||
| if [ -d "train_parallel" ]; | |||||
| then | |||||
| rm -rf train_parallel | |||||
| fi | |||||
| mkdir train_parallel | |||||
| cd train_parallel || exit | |||||
| export CUDA_VISIBLE_DEVICES="$2" | |||||
| if [ -f $4 ] # pre_trained ckpt | |||||
| then | |||||
| if [ $1 -gt 1 ] | |||||
| then | |||||
| mpirun -n $1 --allow-run-as-root python3 ${BASEPATH}/../train.py \ | |||||
| --dataset_path=$3 \ | |||||
| --run_distribute=True \ | |||||
| --pre_trained=$4 > log.txt 2>&1 & | |||||
| else | |||||
| python3 ${BASEPATH}/../train.py \ | |||||
| --dataset_path=$3 \ | |||||
| --pre_trained=$4 > log.txt 2>&1 & | |||||
| fi | |||||
| else | |||||
| if [ $1 -gt 1 ] | |||||
| then | |||||
| mpirun -n $1 --allow-run-as-root python3 ${BASEPATH}/../train.py \ | |||||
| --run_distribute=True \ | |||||
| --dataset_path=$3 > log.txt 2>&1 & | |||||
| else | |||||
| python3 ${BASEPATH}/../train.py \ | |||||
| --dataset_path=$3 > log.txt 2>&1 & | |||||
| fi | |||||
| fi | |||||
| @@ -0,0 +1,43 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| if [ $# -lt 3 ] | |||||
| then | |||||
| echo "Usage: sh run_eval_gpu.sh [DEVICE_ID] [DATASET_PATH] [CHECKPOINT_PATH]" | |||||
| exit 1 | |||||
| fi | |||||
| # check checkpoint file | |||||
| if [ ! -f $3 ] | |||||
| then | |||||
| echo "error: CHECKPOINT_PATH=$3 is not a file" | |||||
| exit 1 | |||||
| fi | |||||
| BASEPATH=$(cd "`dirname $0`" || exit; pwd) | |||||
| export PYTHONPATH=${BASEPATH}:$PYTHONPATH | |||||
| if [ -d "./eval" ]; | |||||
| then | |||||
| rm -rf ./eval | |||||
| fi | |||||
| mkdir ./eval | |||||
| export CUDA_VISIBLE_DEVICES="$1" | |||||
| python3 ${BASEPATH}/../eval.py \ | |||||
| --dataset_path=$2 \ | |||||
| --checkpoint_path=$3 > eval/eval.log 2>&1 & | |||||
| @@ -0,0 +1,29 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Network parameters.""" | |||||
| from easydict import EasyDict as edict | |||||
| srcnn_cfg = edict({ | |||||
| 'lr': 1e-4, | |||||
| 'patch_size': 33, | |||||
| 'stride': 99, | |||||
| 'scale': 2, | |||||
| 'epoch_size': 20, | |||||
| 'batch_size': 16, | |||||
| 'save_checkpoint': True, | |||||
| 'keep_checkpoint_max': 10, | |||||
| 'save_checkpoint_path': 'outputs/' | |||||
| }) | |||||
| @@ -0,0 +1,62 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import glob | |||||
| import numpy as np | |||||
| import PIL.Image as pil_image | |||||
| import mindspore.dataset as ds | |||||
| from src.config import srcnn_cfg as config | |||||
| from src.utils import convert_rgb_to_y | |||||
| class EvalDataset: | |||||
| def __init__(self, images_dir): | |||||
| self.images_dir = images_dir | |||||
| scale = config.scale | |||||
| self.lr_group = [] | |||||
| self.hr_group = [] | |||||
| for image_path in sorted(glob.glob('{}/*'.format(images_dir))): | |||||
| hr = pil_image.open(image_path).convert('RGB') | |||||
| hr_width = (hr.width // scale) * scale | |||||
| hr_height = (hr.height // scale) * scale | |||||
| hr = hr.resize((hr_width, hr_height), resample=pil_image.BICUBIC) | |||||
| lr = hr.resize((hr_width // scale, hr_height // scale), resample=pil_image.BICUBIC) | |||||
| lr = lr.resize((lr.width * scale, lr.height * scale), resample=pil_image.BICUBIC) | |||||
| hr = np.array(hr).astype(np.float32) | |||||
| lr = np.array(lr).astype(np.float32) | |||||
| hr = convert_rgb_to_y(hr) | |||||
| lr = convert_rgb_to_y(lr) | |||||
| self.lr_group.append(lr) | |||||
| self.hr_group.append(hr) | |||||
| def __len__(self): | |||||
| return len(self.lr_group) | |||||
| def __getitem__(self, idx): | |||||
| return np.expand_dims(self.lr_group[idx] / 255., 0), np.expand_dims(self.hr_group[idx] / 255., 0) | |||||
| def create_train_dataset(mindrecord_file, batch_size=1, shard_id=0, num_shard=1, num_parallel_workers=4): | |||||
| data_set = ds.MindDataset(mindrecord_file, columns_list=["lr", "hr"], num_shards=num_shard, | |||||
| shard_id=shard_id, num_parallel_workers=num_parallel_workers, shuffle=True) | |||||
| data_set = data_set.batch(batch_size, drop_remainder=True) | |||||
| return data_set | |||||
| def create_eval_dataset(images_dir, batch_size=1): | |||||
| dataset = EvalDataset(images_dir) | |||||
| data_set = ds.GeneratorDataset(dataset, ["lr", "hr"], shuffle=False) | |||||
| data_set = data_set.batch(batch_size, drop_remainder=True) | |||||
| return data_set | |||||
| @@ -0,0 +1,46 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Metric for accuracy evaluation.""" | |||||
| from mindspore import nn | |||||
| import numpy as np | |||||
| class SRCNNpsnr(nn.Metric): | |||||
| def __init__(self): | |||||
| super(SRCNNpsnr).__init__() | |||||
| self.clear() | |||||
| def clear(self): | |||||
| self.val = 0 | |||||
| self.sum = 0 | |||||
| self.count = 0 | |||||
| def update(self, *inputs): | |||||
| if len(inputs) != 2: | |||||
| raise ValueError('SRCNNpsnr need 2 inputs (y_pred, y), but got {}'.format(len(inputs))) | |||||
| y_pred = self._convert_data(inputs[0]) | |||||
| y = self._convert_data(inputs[1]) | |||||
| n = len(inputs) | |||||
| val = 10. * np.log10(1. / np.mean((y_pred - y) ** 2)) | |||||
| self.val = val | |||||
| self.sum += val * n | |||||
| self.count += n | |||||
| def eval(self): | |||||
| if self.count == 0: | |||||
| raise RuntimeError('PSNR can not be calculated, because the number of samples is 0.') | |||||
| return self.sum / self.count | |||||
| @@ -0,0 +1,30 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import mindspore.nn as nn | |||||
| class SRCNN(nn.Cell): | |||||
| def __init__(self, num_channels=1): | |||||
| super(SRCNN, self).__init__() | |||||
| self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=9, padding=9 // 2, pad_mode='pad') | |||||
| self.conv2 = nn.Conv2d(64, 32, kernel_size=5, padding=5 // 2, pad_mode='pad') | |||||
| self.conv3 = nn.Conv2d(32, num_channels, kernel_size=5, padding=5 // 2, pad_mode='pad') | |||||
| self.relu = nn.ReLU() | |||||
| def construct(self, x): | |||||
| x = self.relu(self.conv1(x)) | |||||
| x = self.relu(self.conv2(x)) | |||||
| x = self.conv3(x) | |||||
| return x | |||||
| @@ -0,0 +1,37 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| def convert_rgb_to_y(img): | |||||
| if isinstance(img, np.ndarray): | |||||
| return 16. + (64.738 * img[:, :, 0] + 129.057 * img[:, :, 1] + 25.064 * img[:, :, 2]) / 256. | |||||
| raise Exception('Unknown Type', type(img)) | |||||
| def convert_rgb_to_ycbcr(img): | |||||
| if isinstance(img, np.ndarray): | |||||
| y = 16. + (64.738 * img[:, :, 0] + 129.057 * img[:, :, 1] + 25.064 * img[:, :, 2]) / 256. | |||||
| cb = 128. + (-37.945 * img[:, :, 0] - 74.494 * img[:, :, 1] + 112.439 * img[:, :, 2]) / 256. | |||||
| cr = 128. + (112.439 * img[:, :, 0] - 94.154 * img[:, :, 1] - 18.285 * img[:, :, 2]) / 256. | |||||
| return np.array([y, cb, cr]).transpose([1, 2, 0]) | |||||
| raise Exception('Unknown Type', type(img)) | |||||
| def convert_ycbcr_to_rgb(img): | |||||
| if isinstance(img, np.ndarray): | |||||
| r = 298.082 * img[:, :, 0] / 256. + 408.583 * img[:, :, 2] / 256. - 222.921 | |||||
| g = 298.082 * img[:, :, 0] / 256. - 100.291 * img[:, :, 1] / 256. - 208.120 * img[:, :, 2] / 256. + 135.576 | |||||
| b = 298.082 * img[:, :, 0] / 256. + 516.412 * img[:, :, 1] / 256. - 276.836 | |||||
| return np.array([r, g, b]).transpose([1, 2, 0]) | |||||
| raise Exception('Unknown Type', type(img)) | |||||
| @@ -0,0 +1,105 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """srcnn training""" | |||||
| import os | |||||
| import argparse | |||||
| import ast | |||||
| import mindspore as ms | |||||
| import mindspore.nn as nn | |||||
| from mindspore import context, Tensor | |||||
| from mindspore.common import set_seed | |||||
| from mindspore.train.model import Model | |||||
| from mindspore.train.callback import TimeMonitor, LossMonitor, CheckpointConfig, ModelCheckpoint | |||||
| from mindspore.communication.management import init, get_rank, get_group_size | |||||
| from mindspore.train.model import ParallelMode | |||||
| from src.config import srcnn_cfg as config | |||||
| from src.dataset import create_train_dataset | |||||
| from src.srcnn import SRCNN | |||||
| set_seed(1) | |||||
| def filter_checkpoint_parameter_by_list(origin_dict, param_filter): | |||||
| """remove useless parameters according to filter_list""" | |||||
| for key in list(origin_dict.keys()): | |||||
| for name in param_filter: | |||||
| if name in key: | |||||
| print("Delete parameter from checkpoint: ", key) | |||||
| del origin_dict[key] | |||||
| break | |||||
| if __name__ == '__main__': | |||||
| parser = argparse.ArgumentParser(description="srcnn training") | |||||
| parser.add_argument('--dataset_path', type=str, default='', help='Dataset path') | |||||
| parser.add_argument('--device_num', type=int, default=1, help='Device num.') | |||||
| parser.add_argument('--device_target', type=str, default='GPU', choices=("GPU"), | |||||
| help="Device target, support GPU.") | |||||
| parser.add_argument('--pre_trained', type=str, default='', help='model_path, local pretrained model to load') | |||||
| parser.add_argument("--run_distribute", type=ast.literal_eval, default=False, | |||||
| help="Run distribute, default: false.") | |||||
| parser.add_argument("--filter_weight", type=ast.literal_eval, default=False, | |||||
| help="Filter head weight parameters, default is False.") | |||||
| args, _ = parser.parse_known_args() | |||||
| if args.device_target == "GPU": | |||||
| context.set_context(mode=context.GRAPH_MODE, | |||||
| device_target=args.device_target, | |||||
| save_graphs=False) | |||||
| else: | |||||
| raise ValueError("Unsupported device target.") | |||||
| rank = 0 | |||||
| device_num = 1 | |||||
| if args.run_distribute: | |||||
| init() | |||||
| rank = get_rank() | |||||
| device_num = get_group_size() | |||||
| context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL) | |||||
| train_dataset = create_train_dataset(args.dataset_path, batch_size=config.batch_size, | |||||
| shard_id=rank, num_shard=device_num) | |||||
| step_size = train_dataset.get_dataset_size() | |||||
| # define net | |||||
| net = SRCNN() | |||||
| # init weight | |||||
| if args.pre_trained: | |||||
| param_dict = load_checkpoint(args.pre_trained) | |||||
| if args.filter_weight: | |||||
| filter_list = [x.name for x in net.end_point.get_parameters()] | |||||
| filter_checkpoint_parameter_by_list(param_dict, filter_list) | |||||
| load_param_into_net(net, param_dict) | |||||
| lr = Tensor(config.lr, ms.float32) | |||||
| opt = nn.Adam(params=net.trainable_params(), learning_rate=lr, eps=1e-07) | |||||
| loss = nn.MSELoss(reduction='mean') | |||||
| model = Model(net, loss_fn=loss, optimizer=opt) | |||||
| # define callbacks | |||||
| callbacks = [LossMonitor(), TimeMonitor(data_size=step_size)] | |||||
| if config.save_checkpoint and rank == 0: | |||||
| config_ck = CheckpointConfig(save_checkpoint_steps=step_size, | |||||
| keep_checkpoint_max=config.keep_checkpoint_max) | |||||
| save_ckpt_path = os.path.join(config.save_checkpoint_path, 'ckpt_' + str(rank) + '/') | |||||
| ckpt_cb = ModelCheckpoint(prefix="srcnn", directory=save_ckpt_path, config=config_ck) | |||||
| callbacks.append(ckpt_cb) | |||||
| model.train(config.epoch_size, train_dataset, callbacks=callbacks) | |||||