From be3acce3cb5460c006fb13ad910cb39db062546d Mon Sep 17 00:00:00 2001 From: Fuyu_Wang Date: Tue, 2 Mar 2021 23:00:35 +0800 Subject: [PATCH] srcnn gpu add --- model_zoo/official/cv/srcnn/README.md | 175 ++++++++++++++++++ model_zoo/official/cv/srcnn/create_dataset.py | 82 ++++++++ model_zoo/official/cv/srcnn/eval.py | 55 ++++++ model_zoo/official/cv/srcnn/requirements.txt | 1 + .../srcnn/scripts/run_distribute_train_gpu.sh | 66 +++++++ .../official/cv/srcnn/scripts/run_eval_gpu.sh | 43 +++++ model_zoo/official/cv/srcnn/src/config.py | 29 +++ model_zoo/official/cv/srcnn/src/dataset.py | 62 +++++++ model_zoo/official/cv/srcnn/src/metric.py | 46 +++++ model_zoo/official/cv/srcnn/src/srcnn.py | 30 +++ model_zoo/official/cv/srcnn/src/utils.py | 37 ++++ model_zoo/official/cv/srcnn/train.py | 105 +++++++++++ 12 files changed, 731 insertions(+) create mode 100644 model_zoo/official/cv/srcnn/README.md create mode 100644 model_zoo/official/cv/srcnn/create_dataset.py create mode 100644 model_zoo/official/cv/srcnn/eval.py create mode 100644 model_zoo/official/cv/srcnn/requirements.txt create mode 100644 model_zoo/official/cv/srcnn/scripts/run_distribute_train_gpu.sh create mode 100644 model_zoo/official/cv/srcnn/scripts/run_eval_gpu.sh create mode 100644 model_zoo/official/cv/srcnn/src/config.py create mode 100644 model_zoo/official/cv/srcnn/src/dataset.py create mode 100644 model_zoo/official/cv/srcnn/src/metric.py create mode 100644 model_zoo/official/cv/srcnn/src/srcnn.py create mode 100644 model_zoo/official/cv/srcnn/src/utils.py create mode 100644 model_zoo/official/cv/srcnn/train.py diff --git a/model_zoo/official/cv/srcnn/README.md b/model_zoo/official/cv/srcnn/README.md new file mode 100644 index 0000000000..c7a2af3a0d --- /dev/null +++ b/model_zoo/official/cv/srcnn/README.md @@ -0,0 +1,175 @@ +# Contents + +- [SRCNN Description](#srcnn-description) +- [Model Architecture](#model-architecture) +- [Dataset](#dataset) +- [Environment 
# [SRCNN Description](#contents)
+- Framework + - [MindSpore](https://www.mindspore.cn/install/en) +- For more information, please check the resources below: + - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html) + - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html) + +# [Script description](#contents) + +## [Script and sample code](#contents) + +```python +. +└─srcnn + ├─README.md + ├─scripts + ├─run_distribute_train_gpu.sh # launch distributed training with gpu platform + └─run_eval_gpu.sh # launch evaluating with gpu platform + ├─src + ├─config.py # parameter configuration + ├─dataset.py # data preprocessing + ├─metric.py # accuracy metric + ├─utils.py # some functions which is commonly used + ├─srcnn.py # network definition +├─create_dataset.py # generating mindrecord training dataset +├─eval.py # eval net +└─train.py # train net + +``` + +## [Script Parameters](#contents) + +Parameters for both training and evaluating can be set in config.py. + +```python +'lr': 1e-4, # learning rate +'patch_size': 33, # patch_size +'stride': 99, # stride +'scale': 2, # image scale +'epoch_size': 20, # total epoch numbers +'batch_size': 16, # input batchsize +'save_checkpoint': True, # whether saving ckpt file +'keep_checkpoint_max': 10, # max numbers to keep checkpoints +'save_checkpoint_path': 'outputs/' # save checkpoint path +``` + +## [Training Process](#contents) + +### Dataset + +To create dataset, download the training dataset firstly and then convert them to mindrecord files. We can deal with it as follows. 
+ +```shell + python create_dataset.py --src_folder=/dataset/ILSVRC2013_DET_train --output_folder=/dataset/mindrecord_dir +``` + +### Usage + +```bash +GPU: + sh run_distribute_train_gpu.sh DEVICE_NUM VISIABLE_DEVICES(0,1,2,3,4,5,6,7) DATASET_PATH +``` + +### Launch + +```bash +# distributed training example(8p) for GPU +sh run_distribute_train_gpu.sh 8 0,1,2,3,4,5,6,7 /dataset/train +# standalone training example for GPU +sh run_distribute_train_gpu.sh 1 0 /dataset/train +``` + +You can find checkpoint file together with result in log. + +## [Evaluation Process](#contents) + +### Usage + +```bash +# Evaluation +sh run_eval_gpu.sh DEVICE_ID DATASET_PATH CHECKPOINT_PATH +``` + +### Launch + +```bash +# Evaluation with checkpoint +sh run_eval_gpu.sh 1 /dataset/val /ckpt_dir/srcnn-20_*.ckpt +``` + +### Result + +Evaluation result will be stored in the scripts path. Under this, you can find result like the followings in log. + +result {'PSNR': 36.72421418219669} + +# [Model description](#contents) + +## [Performance](#contents) + +### Training Performance + +| Parameters | SRCNN | +| -------------------------- | ------------------------- | +| Resource | NV PCIE V100-32G | +| uploaded Date | 03/02/2021 | +| MindSpore Version | master | +| Dataset | ImageNet2013 scale:2 | +| Training Parameters | src/config.py | +| Optimizer | Adam | +| Loss Function | MSELoss | +| Loss | 0.00179 | +| Total time | 1 h 8ps | +| Checkpoint for Fine tuning | 671 K(.ckpt file) | + +### Inference Performance + +| Parameters | | +| -------------------------- | -------------------------- | +| Resource | NV PCIE V100-32G | +| uploaded Date | 03/02/2021 | +| MindSpore Version | master | +| Dataset | Set5/Set14/BSDS200 scale:2 | +| batch_size | 1 | +| PSNR | 36.72/32.58/33.81 | + +# [ModelZoo Homepage](#contents) + +Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo). 
"""Create Dataset."""
import os
import argparse
import glob
import numpy as np
import PIL.Image as pil_image
from PIL import ImageFile

from mindspore.mindrecord import FileWriter

from src.config import srcnn_cfg as config
from src.utils import convert_rgb_to_y
# Some source JPEGs are truncated; let PIL load them anyway instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True

parser = argparse.ArgumentParser(description='Generate dataset file.')
parser.add_argument("--src_folder", type=str, required=True, help="Raw data folder.")
parser.add_argument("--output_folder", type=str, required=True, help="Dataset output path.")

if __name__ == '__main__':
    args, _ = parser.parse_known_args()
    if not os.path.exists(args.output_folder):
        os.makedirs(args.output_folder)
    prefix = "srcnn.mindrecord"
    file_num = 32                    # number of mindrecord shard files to write
    patch_size = config.patch_size   # edge length of each square training patch
    stride = config.stride           # step between neighbouring patch origins
    scale = config.scale             # super-resolution upscaling factor
    mindrecord_path = os.path.join(args.output_folder, prefix)
    writer = FileWriter(mindrecord_path, file_num)

    # Each record holds one low-/high-resolution luma patch pair with a
    # leading channel axis: shape (1, patch_size, patch_size).
    srcnn_json = {
        "lr": {"type": "float32", "shape": [1, patch_size, patch_size]},
        "hr": {"type": "float32", "shape": [1, patch_size, patch_size]},
    }
    writer.add_schema(srcnn_json, "srcnn_json")
    # Collect image paths; accepts both flat files and one level of
    # sub-directories under --src_folder.
    image_list = []
    file_list = sorted(os.listdir(args.src_folder))
    for file_name in file_list:
        path = os.path.join(args.src_folder, file_name)
        if os.path.isfile(path):
            image_list.append(path)
        else:
            for image_path in sorted(glob.glob('{}/*'.format(path))):
                image_list.append(image_path)

    print("image_list size ", len(image_list), flush=True)

    for path in image_list:
        hr = pil_image.open(path).convert('RGB')
        # Crop so both dimensions are divisible by the scale factor.
        hr_width = (hr.width // scale) * scale
        hr_height = (hr.height // scale) * scale
        hr = hr.resize((hr_width, hr_height), resample=pil_image.BICUBIC)
        # Bicubic down- then up-sampling produces the degraded (lr) input at
        # the same resolution as hr.
        lr = hr.resize((hr_width // scale, hr_height // scale), resample=pil_image.BICUBIC)
        lr = lr.resize((lr.width * scale, lr.height * scale), resample=pil_image.BICUBIC)
        hr = np.array(hr).astype(np.float32)
        lr = np.array(lr).astype(np.float32)
        # SRCNN trains on the luma (Y) channel only.
        hr = convert_rgb_to_y(hr)
        lr = convert_rgb_to_y(lr)

        # Slide a patch_size window with the configured stride; values are
        # scaled from [0, 255] to [0, 1] before writing.
        for i in range(0, lr.shape[0] - patch_size + 1, stride):
            for j in range(0, lr.shape[1] - patch_size + 1, stride):
                lr_res = np.expand_dims(lr[i:i + patch_size, j:j + patch_size] / 255., 0)
                hr_res = np.expand_dims(hr[i:i + patch_size, j:j + patch_size] / 255., 0)
                row = {"lr": lr_res, "hr": hr_res}
                writer.write_raw_data([row])

    writer.commit()
    print("Finish!")
"""srcnn evaluation"""
import argparse
import mindspore as ms
import mindspore.nn as nn
from mindspore import context, Tensor
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.config import srcnn_cfg as config
from src.dataset import create_eval_dataset
from src.srcnn import SRCNN
from src.metric import SRCNNpsnr

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="srcnn eval")
    parser.add_argument('--dataset_path', type=str, required=True, help="Dataset, default is None.")
    parser.add_argument('--checkpoint_path', type=str, required=True, help="checkpoint file path")
    # Bug fix: choices=("GPU") is just the *string* "GPU", and argparse
    # membership on a string is a substring test, so "G", "P" or "U" would
    # also validate. A one-element tuple is what was intended.
    parser.add_argument('--device_target', type=str, default='GPU', choices=("GPU",),
                        help="Device target, support GPU.")
    args, _ = parser.parse_known_args()

    if args.device_target == "GPU":
        context.set_context(mode=context.GRAPH_MODE,
                            device_target=args.device_target,
                            save_graphs=False)
    else:
        raise ValueError("Unsupported device target.")

    eval_ds = create_eval_dataset(args.dataset_path)

    # Build the network and restore the trained weights.
    net = SRCNN()
    lr = Tensor(config.lr, ms.float32)
    opt = nn.Adam(params=net.trainable_params(), learning_rate=lr, eps=1e-07)
    loss = nn.MSELoss(reduction='mean')
    param_dict = load_checkpoint(args.checkpoint_path)
    load_param_into_net(net, param_dict)
    net.set_train(False)
    # PSNR is the only reported metric; see src/metric.py.
    model = Model(net, loss_fn=loss, optimizer=opt, metrics={'PSNR': SRCNNpsnr()})

    res = model.eval(eval_ds, dataset_sink_mode=False)
    print("result ", res)
# ============================================================================

# Require DEVICE_NUM, VISIABLE_DEVICES and DATASET_PATH; PRE_TRAINED is optional.
if [ $# -lt 3 ]
then
    echo "Usage: sh run_distribute_train_gpu.sh [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [PRE_TRAINED](optional)"
exit 1
fi

# Bug fix: the original used '&&' ("-lt 1 AND -gt 8"), which can never be
# true, so an out-of-range DEVICE_NUM was silently accepted. '||' is intended.
if [ $1 -lt 1 ] || [ $1 -gt 8 ]
then
    echo "error: DEVICE_NUM=$1 is not in (1-8)"
exit 1
fi

export DEVICE_NUM=$1
export RANK_SIZE=$1

# Resolve the scripts directory so train.py can be located relative to it.
BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
# Start each run from a clean working directory.
if [ -d "train_parallel" ];
then
    rm -rf train_parallel
fi
mkdir train_parallel
cd train_parallel || exit

export CUDA_VISIBLE_DEVICES="$2"

if [ -f $4 ] # pre_trained ckpt
then
    if [ $1 -gt 1 ]
    then
        # Multi-GPU: launch one process per device via mpirun.
        mpirun -n $1 --allow-run-as-root python3 ${BASEPATH}/../train.py \
            --dataset_path=$3 \
            --run_distribute=True \
            --pre_trained=$4 > log.txt 2>&1 &
    else
        python3 ${BASEPATH}/../train.py \
            --dataset_path=$3 \
            --pre_trained=$4 > log.txt 2>&1 &
    fi
else
    if [ $1 -gt 1 ]
    then
        mpirun -n $1 --allow-run-as-root python3 ${BASEPATH}/../train.py \
            --run_distribute=True \
            --dataset_path=$3 > log.txt 2>&1 &
    else
        python3 ${BASEPATH}/../train.py \
            --dataset_path=$3 > log.txt 2>&1 &
    fi
fi
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

# Require all three positional arguments: DEVICE_ID DATASET_PATH CHECKPOINT_PATH.
if [ $# -lt 3 ]
then
    echo "Usage: sh run_eval_gpu.sh [DEVICE_ID] [DATASET_PATH] [CHECKPOINT_PATH]"
exit 1
fi

# check checkpoint file
if [ ! -f $3 ]
then
    echo "error: CHECKPOINT_PATH=$3 is not a file"
exit 1
fi

# Resolve the scripts directory so eval.py can be located relative to it.
BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH

# Start from a clean eval output directory.
if [ -d "./eval" ];
then
    rm -rf ./eval
fi
mkdir ./eval

# Pin the evaluation to the requested GPU.
export CUDA_VISIBLE_DEVICES="$1"

# Run in the background; all output goes to eval/eval.log.
python3 ${BASEPATH}/../eval.py \
    --dataset_path=$2 \
    --checkpoint_path=$3 > eval/eval.log 2>&1 &
# ============================================================================
"""Network parameters."""

from easydict import EasyDict as edict

# Hyper-parameters shared by training (train.py), evaluation (eval.py) and
# dataset generation (create_dataset.py).
srcnn_cfg = edict({
    'lr': 1e-4,                          # learning rate for the Adam optimizer
    'patch_size': 33,                    # edge length of square training patches
    'stride': 99,                        # step between neighbouring patch origins
    'scale': 2,                          # super-resolution upscaling factor
    'epoch_size': 20,                    # total number of training epochs
    'batch_size': 16,                    # samples per training step
    'save_checkpoint': True,             # whether to write .ckpt files while training
    'keep_checkpoint_max': 10,           # maximum number of checkpoints to keep
    'save_checkpoint_path': 'outputs/'   # directory checkpoints are written under
})
# ============================================================================
"""Dataset helpers for SRCNN training and evaluation."""

import glob
import numpy as np
import PIL.Image as pil_image

import mindspore.dataset as ds

from src.config import srcnn_cfg as config
from src.utils import convert_rgb_to_y

class EvalDataset:
    """Generator-style dataset yielding (lr, hr) luma pairs for evaluation.

    Each image under images_dir is cropped to a multiple of the scale
    factor; a bicubic down-/up-sampled copy is the low-resolution input.
    Both planes are converted to the Y channel and cached in memory.
    """
    def __init__(self, images_dir):
        # images_dir: directory containing the evaluation images.
        self.images_dir = images_dir
        scale = config.scale
        self.lr_group = []
        self.hr_group = []
        for image_path in sorted(glob.glob('{}/*'.format(images_dir))):
            hr = pil_image.open(image_path).convert('RGB')
            # Crop so both dimensions are divisible by the scale factor.
            hr_width = (hr.width // scale) * scale
            hr_height = (hr.height // scale) * scale
            hr = hr.resize((hr_width, hr_height), resample=pil_image.BICUBIC)
            # Bicubic down- then up-sampling produces the degraded input at
            # the same resolution as hr.
            lr = hr.resize((hr_width // scale, hr_height // scale), resample=pil_image.BICUBIC)
            lr = lr.resize((lr.width * scale, lr.height * scale), resample=pil_image.BICUBIC)
            hr = np.array(hr).astype(np.float32)
            lr = np.array(lr).astype(np.float32)
            # SRCNN operates on the luma (Y) channel only.
            hr = convert_rgb_to_y(hr)
            lr = convert_rgb_to_y(lr)

            self.lr_group.append(lr)
            self.hr_group.append(hr)

    def __len__(self):
        # Number of evaluation images found.
        return len(self.lr_group)

    def __getitem__(self, idx):
        # Returns (lr, hr) with a leading channel axis, scaled to [0, 1].
        return np.expand_dims(self.lr_group[idx] / 255., 0), np.expand_dims(self.hr_group[idx] / 255., 0)

def create_train_dataset(mindrecord_file, batch_size=1, shard_id=0, num_shard=1, num_parallel_workers=4):
    """Build the (optionally sharded) training dataset from mindrecord files."""
    data_set = ds.MindDataset(mindrecord_file, columns_list=["lr", "hr"], num_shards=num_shard,
                              shard_id=shard_id, num_parallel_workers=num_parallel_workers, shuffle=True)
    data_set = data_set.batch(batch_size, drop_remainder=True)
    return data_set

def create_eval_dataset(images_dir, batch_size=1):
    """Build the evaluation dataset from a directory of images."""
    dataset = EvalDataset(images_dir)
    data_set = ds.GeneratorDataset(dataset, ["lr", "hr"], shuffle=False)
    data_set = data_set.batch(batch_size, drop_remainder=True)
    return data_set
"""Metric for accuracy evaluation."""
import numpy as np

from mindspore import nn

class SRCNNpsnr(nn.Metric):
    """Average PSNR metric over evaluation steps.

    Each update() computes 10*log10(1/MSE) for one (y_pred, y) pair (pixel
    values are expected to be scaled to [0, 1]); eval() returns the mean.
    """
    def __init__(self):
        # Bug fix: the original called super(SRCNNpsnr).__init__(), which
        # builds an unbound super proxy and never runs nn.Metric.__init__,
        # leaving the Metric base class uninitialised.
        super(SRCNNpsnr, self).__init__()
        self.clear()

    def clear(self):
        """Reset the accumulated state."""
        self.val = 0
        self.sum = 0
        self.count = 0

    def update(self, *inputs):
        """Accumulate the PSNR of one (y_pred, y) pair."""
        if len(inputs) != 2:
            raise ValueError('SRCNNpsnr need 2 inputs (y_pred, y), but got {}'.format(len(inputs)))

        y_pred = self._convert_data(inputs[0])
        y = self._convert_data(inputs[1])

        val = 10. * np.log10(1. / np.mean((y_pred - y) ** 2))

        # One PSNR value per update() call, so weight it as a single sample.
        # (The original used n = len(inputs), which is always 2; the constant
        # cancelled in eval(), but the sample-count semantics were wrong.)
        self.val = val
        self.sum += val
        self.count += 1

    def eval(self):
        """Return the mean PSNR since the last clear()."""
        if self.count == 0:
            raise RuntimeError('PSNR can not be calculated, because the number of samples is 0.')
        return self.sum / self.count
"""SRCNN network definition."""

import mindspore.nn as nn

class SRCNN(nn.Cell):
    """Three-layer super-resolution CNN (Dong et al., 2014): 9x9 feature
    extraction -> 5x5 non-linear mapping -> 5x5 reconstruction.

    'pad' mode with kernel_size // 2 padding keeps the spatial size
    unchanged, so the output has the same height/width as the input.

    Args:
        num_channels (int): number of image channels; 1 for the Y (luma) plane.
    """
    def __init__(self, num_channels=1):
        super(SRCNN, self).__init__()
        self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=9, padding=9 // 2, pad_mode='pad')
        self.conv2 = nn.Conv2d(64, 32, kernel_size=5, padding=5 // 2, pad_mode='pad')
        self.conv3 = nn.Conv2d(32, num_channels, kernel_size=5, padding=5 // 2, pad_mode='pad')
        self.relu = nn.ReLU()

    def construct(self, x):
        """Map an upscaled low-resolution image to its restored version."""
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        # No activation on the last layer: this is a regression output.
        x = self.conv3(x)
        return x
"""Colour-space helpers shared by dataset generation and evaluation."""

import numpy as np

def convert_rgb_to_y(img):
    """Return the luma (Y) plane of an (H, W, 3) RGB numpy array.

    Raises:
        Exception: if img is not a numpy array.
    """
    if not isinstance(img, np.ndarray):
        raise Exception('Unknown Type', type(img))
    red, green, blue = img[:, :, 0], img[:, :, 1], img[:, :, 2]
    return 16. + (64.738 * red + 129.057 * green + 25.064 * blue) / 256.

def convert_rgb_to_ycbcr(img):
    """Convert an (H, W, 3) RGB numpy array to YCbCr (same shape).

    Raises:
        Exception: if img is not a numpy array.
    """
    if not isinstance(img, np.ndarray):
        raise Exception('Unknown Type', type(img))
    red, green, blue = img[:, :, 0], img[:, :, 1], img[:, :, 2]
    y = 16. + (64.738 * red + 129.057 * green + 25.064 * blue) / 256.
    cb = 128. + (-37.945 * red - 74.494 * green + 112.439 * blue) / 256.
    cr = 128. + (112.439 * red - 94.154 * green - 18.285 * blue) / 256.
    return np.stack([y, cb, cr], axis=-1)

def convert_ycbcr_to_rgb(img):
    """Convert an (H, W, 3) YCbCr numpy array back to RGB (same shape).

    Raises:
        Exception: if img is not a numpy array.
    """
    if not isinstance(img, np.ndarray):
        raise Exception('Unknown Type', type(img))
    luma, cb, cr = img[:, :, 0], img[:, :, 1], img[:, :, 2]
    red = 298.082 * luma / 256. + 408.583 * cr / 256. - 222.921
    green = 298.082 * luma / 256. - 100.291 * cb / 256. - 208.120 * cr / 256. + 135.576
    blue = 298.082 * luma / 256. + 516.412 * cb / 256. - 276.836
    return np.stack([red, green, blue], axis=-1)
# ============================================================================
"""srcnn training"""

import os
import argparse
import ast

import mindspore as ms
import mindspore.nn as nn
from mindspore import context, Tensor
from mindspore.common import set_seed
from mindspore.train.model import Model, ParallelMode
from mindspore.train.callback import TimeMonitor, LossMonitor, CheckpointConfig, ModelCheckpoint
# Bug fix: load_checkpoint/load_param_into_net were used below but never
# imported, so running with --pre_trained crashed with NameError.
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.communication.management import init, get_rank, get_group_size

from src.config import srcnn_cfg as config
from src.dataset import create_train_dataset
from src.srcnn import SRCNN

set_seed(1)

def filter_checkpoint_parameter_by_list(origin_dict, param_filter):
    """Remove (in place) every checkpoint entry whose key contains any of
    the substrings in param_filter."""
    for key in list(origin_dict.keys()):
        for name in param_filter:
            if name in key:
                print("Delete parameter from checkpoint: ", key)
                del origin_dict[key]
                break

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="srcnn training")
    parser.add_argument('--dataset_path', type=str, default='', help='Dataset path')
    parser.add_argument('--device_num', type=int, default=1, help='Device num.')
    # Bug fix: choices=("GPU") is just the *string* "GPU" and argparse
    # membership on a string is a substring test, so "G", "P" or "U" would
    # also validate. A one-element tuple is what was intended.
    parser.add_argument('--device_target', type=str, default='GPU', choices=("GPU",),
                        help="Device target, support GPU.")
    parser.add_argument('--pre_trained', type=str, default='', help='model_path, local pretrained model to load')
    parser.add_argument("--run_distribute", type=ast.literal_eval, default=False,
                        help="Run distribute, default: false.")
    parser.add_argument("--filter_weight", type=ast.literal_eval, default=False,
                        help="Filter head weight parameters, default is False.")
    args, _ = parser.parse_known_args()

    if args.device_target == "GPU":
        context.set_context(mode=context.GRAPH_MODE,
                            device_target=args.device_target,
                            save_graphs=False)
    else:
        raise ValueError("Unsupported device target.")

    rank = 0
    device_num = 1
    if args.run_distribute:
        # Initialise collective communication and shard the dataset per rank.
        init()
        rank = get_rank()
        device_num = get_group_size()
        context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL)

    train_dataset = create_train_dataset(args.dataset_path, batch_size=config.batch_size,
                                         shard_id=rank, num_shard=device_num)

    step_size = train_dataset.get_dataset_size()

    # define net
    net = SRCNN()

    # init weight
    if args.pre_trained:
        param_dict = load_checkpoint(args.pre_trained)
        if args.filter_weight:
            # NOTE(review): SRCNN (src/srcnn.py) defines no attribute named
            # "end_point", so this branch would raise AttributeError --
            # confirm the intended head-layer attribute name.
            filter_list = [x.name for x in net.end_point.get_parameters()]
            filter_checkpoint_parameter_by_list(param_dict, filter_list)
        load_param_into_net(net, param_dict)

    lr = Tensor(config.lr, ms.float32)

    opt = nn.Adam(params=net.trainable_params(), learning_rate=lr, eps=1e-07)
    loss = nn.MSELoss(reduction='mean')
    model = Model(net, loss_fn=loss, optimizer=opt)

    # define callbacks; only rank 0 writes checkpoints to avoid concurrent
    # writers in distributed runs.
    callbacks = [LossMonitor(), TimeMonitor(data_size=step_size)]
    if config.save_checkpoint and rank == 0:
        config_ck = CheckpointConfig(save_checkpoint_steps=step_size,
                                     keep_checkpoint_max=config.keep_checkpoint_max)
        save_ckpt_path = os.path.join(config.save_checkpoint_path, 'ckpt_' + str(rank) + '/')
        ckpt_cb = ModelCheckpoint(prefix="srcnn", directory=save_ckpt_path, config=config_ck)
        callbacks.append(ckpt_cb)

    model.train(config.epoch_size, train_dataset, callbacks=callbacks)