| @@ -0,0 +1,68 @@ | |||||
| """eval script""" | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| from src import ipt | |||||
| from src.args import args | |||||
| from src.data.srdata import SRData | |||||
| from src.metrics import calc_psnr, quantize | |||||
| from mindspore import context | |||||
| import mindspore.dataset as de | |||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU", device_id=0) | |||||
| def main(): | |||||
| """eval""" | |||||
| for arg in vars(args): | |||||
| if vars(args)[arg] == 'True': | |||||
| vars(args)[arg] = True | |||||
| elif vars(args)[arg] == 'False': | |||||
| vars(args)[arg] = False | |||||
| train_dataset = SRData(args, name=args.data_test, train=False, benchmark=False) | |||||
| train_de_dataset = de.GeneratorDataset(train_dataset, ['LR', "HR"], shuffle=False) | |||||
| train_de_dataset = train_de_dataset.batch(1, drop_remainder=True) | |||||
| train_loader = train_de_dataset.create_dict_iterator() | |||||
| net_m = ipt.IPT(args) | |||||
| print('load mindspore net successfully.') | |||||
| if args.pth_path: | |||||
| param_dict = load_checkpoint(args.pth_path) | |||||
| load_param_into_net(net_m, param_dict) | |||||
| net_m.set_train(False) | |||||
| num_imgs = train_de_dataset.get_dataset_size() | |||||
| psnrs = np.zeros((num_imgs, 1)) | |||||
| for batch_idx, imgs in enumerate(train_loader): | |||||
| lr = imgs['LR'] | |||||
| hr = imgs['HR'] | |||||
| hr_np = np.float32(hr.asnumpy()) | |||||
| pred = net_m.infrc(lr) | |||||
| pred_np = np.float32(pred.asnumpy()) | |||||
| pred_np = quantize(pred_np, 255) | |||||
psnr = calc_psnr(pred_np, hr_np, args.scale[0], 255.0, y_only=True)
| psnrs[batch_idx, 0] = psnr | |||||
| if args.denoise: | |||||
| print('Mean psnr of %s DN_%s is %.4f' % (args.data_test[0], args.sigma, psnrs.mean(axis=0)[0])) | |||||
| elif args.derain: | |||||
print('Mean psnr of Derain is %.4f' % (psnrs.mean(axis=0)[0]))
| else: | |||||
| print('Mean psnr of %s x%s is %.4f' % (args.data_test[0], args.scale[0], psnrs.mean(axis=0)[0])) | |||||
| if __name__ == '__main__': | |||||
| print("Start main function!") | |||||
| main() | |||||
| @@ -0,0 +1,26 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """hub config.""" | |||||
| from src.vitm import ViT | |||||
| def IPT(*args, **kwargs): | |||||
| return ViT(*args, **kwargs) | |||||
| def create_network(name, *args, **kwargs): | |||||
| if name == 'IPT': | |||||
| return IPT(*args, **kwargs) | |||||
| raise NotImplementedError(f"{name} is not implemented in the repo") | |||||
| @@ -0,0 +1,147 @@ | |||||
| <TOC> | |||||
| # Pre-Trained Image Processing Transformer (IPT) | |||||
| This repository is an official implementation of the paper "Pre-Trained Image Processing Transformer" from CVPR 2021. | |||||
We study low-level computer vision tasks (e.g., denoising, super-resolution and deraining) and develop a new pre-trained model, namely the image processing transformer (IPT). To maximally excavate the capability of the transformer, we utilize the well-known ImageNet benchmark to generate a large number of corrupted image pairs. The IPT model is trained on these images with multi-heads and multi-tails. In addition, contrastive learning is introduced to adapt the model to different image processing tasks. The pre-trained model can therefore be efficiently employed on the desired task after fine-tuning. With only one pre-trained model, IPT outperforms the current state-of-the-art methods on various low-level benchmarks.
| If you find our work useful in your research or publication, please cite our work: | |||||
| [1] Hanting Chen, Yunhe Wang, Tianyu Guo, Chang Xu, Yiping Deng, Zhenhua Liu, Siwei Ma, Chunjing Xu, Chao Xu, and Wen Gao. **"Pre-trained image processing transformer"**. <i>**CVPR 2021**.</i> [[arXiv](https://arxiv.org/abs/2012.00364)] | |||||
| @inproceedings{chen2020pre, | |||||
| title={Pre-trained image processing transformer}, | |||||
| author={Chen, Hanting and Wang, Yunhe and Guo, Tianyu and Xu, Chang and Deng, Yiping and Liu, Zhenhua and Ma, Siwei and Xu, Chunjing and Xu, Chao and Gao, Wen}, | |||||
| booktitle={CVPR}, | |||||
| year={2021} | |||||
| } | |||||
| ## Model architecture | |||||
### The overall network architecture of IPT is shown below:
|  | |||||
| ## Dataset | |||||
| The benchmark datasets can be downloaded as follows: | |||||
| For super-resolution: | |||||
| Set5, | |||||
| [Set14](https://sites.google.com/site/romanzeyde/research-interests), | |||||
| [B100](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/), | |||||
| [Urban100](https://sites.google.com/site/jbhuang0604/publications/struct_sr). | |||||
| For denoising: | |||||
| [CBSD68](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/). | |||||
| For deraining: | |||||
| [Rain100L](https://www.icst.pku.edu.cn/struct/Projects/joint_rain_removal.html) | |||||
| The result images are converted into YCbCr color space. The PSNR is evaluated on the Y channel only. | |||||
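A minimal sketch of how this metric is computed is given below; it mirrors `calc_psnr` in `src/metrics.py` (BT.601 luma coefficients, a border of `scale` pixels shaved for super-resolution, then `-10 * log10(MSE)`). The NCHW layout and the 255 RGB range follow `eval.py`; the helper name `y_channel_psnr` is only illustrative.
```
import numpy as np

def y_channel_psnr(sr, hr, scale, rgb_range=255.0):
    # convert the RGB difference to the Y (luma) channel, as in src/metrics.py
    diff = (np.float32(sr) - np.float32(hr)) / rgb_range          # NCHW arrays
    gray_coeffs = np.array([65.738, 129.057, 25.064]).reshape((1, 3, 1, 1)) / 256
    diff_y = (diff * gray_coeffs).sum(axis=1)
    # shave a border of `scale` pixels for super-resolution (no shave for scale 1)
    valid = diff_y if scale == 1 else diff_y[..., scale:-scale, scale:-scale]
    mse = np.mean(valid ** 2)
    return -10 * np.log10(mse)
```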
| ## Requirements | |||||
| ### Hardware (GPU) | |||||
| > Prepare hardware environment with GPU. | |||||
| ### Framework | |||||
| > [MindSpore](https://www.mindspore.cn/install/en) | |||||
| ### For more information, please check the resources below: | |||||
| [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html) | |||||
| [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html) | |||||
| ## Script Description | |||||
> This is the inference script of IPT. You can follow the steps below to test image processing tasks such as super-resolution, denoising and deraining with the corresponding pre-trained models.
| ### Scripts and Sample Code | |||||
| ``` | |||||
| IPT | |||||
| ├── eval.py # inference entry | |||||
| ├── image | |||||
| │ └── ipt.png # the illustration of IPT network | |||||
| ├── model | |||||
| │ ├── IPT_denoise30.ckpt # denoise model weights for noise level 30 | |||||
| │ ├── IPT_denoise50.ckpt # denoise model weights for noise level 50 | |||||
| │ ├── IPT_derain.ckpt # derain model weights | |||||
| │ ├── IPT_sr2.ckpt # X2 super-resolution model weights | |||||
| │ ├── IPT_sr3.ckpt # X3 super-resolution model weights | |||||
| │ └── IPT_sr4.ckpt # X4 super-resolution model weights | |||||
| ├── readme.md # Readme | |||||
| ├── scripts | |||||
| │ └── run_eval.sh # inference script for all tasks | |||||
| └── src | |||||
| ├── args.py # options/hyper-parameters of IPT | |||||
| ├── data | |||||
| │ ├── common.py # common dataset | |||||
| │ ├── __init__.py # Class data init function | |||||
| │ └── srdata.py # flow of loading sr data | |||||
| ├── foldunfold_stride.py # function of fold and unfold operations for images | |||||
| ├── metrics.py # PSNR calculator | |||||
| ├── template.py # setting of model selection | |||||
| └── vitm.py # IPT network | |||||
| ``` | |||||
### Script Parameters
| > For details about hyperparameters, see src/args.py. | |||||
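> The evaluation flags used most often in this repository are summarized below for quick reference; see src/args.py and scripts/run_eval.sh for the full set and exact values.
```
--dir_data     dataset directory
--data_test    test dataset name (e.g. Set5, Set14, B100, Urban100, CBSD68, Rain100L)
--scale        super-resolution scale (use 1 for denoising/deraining)
--denoise      evaluate denoising; set the noise level with --sigma (30 or 50)
--derain       evaluate deraining; --derain_test divides the test set size (1 = all pairs)
--pth_path     path to the pre-trained checkpoint (.ckpt)
--test_only    run evaluation only
--ext img      read images directly instead of prepared binary caches
```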
| ## Evaluation | |||||
| ### Evaluation Process | |||||
| > Inference example: | |||||
| > For SR x4: | |||||
| ``` | |||||
| python eval.py --dir_data ../../data/ --data_test Set14 --nochange --test_only --ext img --chop_new --scale 4 --pth_path ./model/IPT_sr4.ckpt | |||||
| ``` | |||||
> Alternatively, you can run the following script to evaluate all tasks.
| ``` | |||||
sh scripts/run_eval.sh [DEVICE_ID] [DATA_DIR] [DATA_SET] [CKPT_PATH]
| ``` | |||||
| ### Evaluation Result | |||||
The results are evaluated by PSNR (Peak Signal-to-Noise Ratio), and the output format is as follows.
| ``` | |||||
| result: {"Mean psnr of Se5 x4 is 32.68"} | |||||
| ``` | |||||
| ## Performance | |||||
| ### Inference Performance | |||||
The results on all tasks are listed below.
| Super-resolution results: | |||||
| | Scale | Set5 | Set14 | B100 | Urban100 | | |||||
| | ----- | ----- | ----- | ----- | ----- | | |||||
| | ×2 | 38.36 | 34.54 | 32.50 | 33.88 | | |||||
| | ×3 | 34.83 | 30.96 | 29.39 | 29.59 | | |||||
| | ×4 | 32.68 | 29.01 | 27.81 | 27.24 | | |||||
| Denoising results: | |||||
| Noise level | CBSD68 | Urban100 |
| | ----- | ----- | ----- | | |||||
| | 30 | 32.37 | 33.82 | | |||||
| | 50 | 29.94 | 31.56 | | |||||
| Derain results: | |||||
| | Task | Rain100L | | |||||
| | ----- | ----- | | |||||
| | Derain | 41.98 | | |||||
## ModelZoo Homepage
| Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo). | |||||
| @@ -0,0 +1,31 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
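# Usage: bash run_eval.sh [DEVICE_ID] [DATA_DIR] [DATA_SET] [CKPT_PATH]
# All six evaluations below are launched in the background on the same device;
# each writes its own log file.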
| export DEVICE_ID=$1 | |||||
| DATA_DIR=$2 | |||||
| DATA_SET=$3 | |||||
| PATH_CHECKPOINT=$4 | |||||
python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 4 --pth_path=$PATH_CHECKPOINT > eval_sr_x4.log 2>&1 &
python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 3 --pth_path=$PATH_CHECKPOINT > eval_sr_x3.log 2>&1 &
python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 2 --pth_path=$PATH_CHECKPOINT > eval_sr_x2.log 2>&1 &
## denoise
python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 1 --denoise --sigma 30 --pth_path=$PATH_CHECKPOINT > eval_denoise30.log 2>&1 &
python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 1 --denoise --sigma 50 --pth_path=$PATH_CHECKPOINT > eval_denoise50.log 2>&1 &
## derain
python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 1 --derain --derain_test 1 --pth_path=$PATH_CHECKPOINT > eval_derain.log 2>&1 &
| @@ -0,0 +1,239 @@ | |||||
| '''args''' | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import argparse | |||||
| from src import template | |||||
parser = argparse.ArgumentParser(description='IPT')
| parser.add_argument('--debug', action='store_true', | |||||
| help='Enables debug mode') | |||||
| parser.add_argument('--template', default='.', | |||||
| help='You can set various templates in option.py') | |||||
| # Hardware specifications | |||||
| parser.add_argument('--n_threads', type=int, default=6, | |||||
| help='number of threads for data loading') | |||||
| parser.add_argument('--cpu', action='store_true', | |||||
| help='use cpu only') | |||||
| parser.add_argument('--n_GPUs', type=int, default=1, | |||||
| help='number of GPUs') | |||||
| parser.add_argument('--seed', type=int, default=1, | |||||
| help='random seed') | |||||
| # Data specifications | |||||
| parser.add_argument('--dir_data', type=str, default='/cache/data/', | |||||
| help='dataset directory') | |||||
| parser.add_argument('--dir_demo', type=str, default='../test', | |||||
| help='demo image directory') | |||||
| parser.add_argument('--data_train', type=str, default='DIV2K', | |||||
| help='train dataset name') | |||||
| parser.add_argument('--data_test', type=str, default='DIV2K', | |||||
| help='test dataset name') | |||||
| parser.add_argument('--data_range', type=str, default='1-800/801-810', | |||||
| help='train/test data range') | |||||
| parser.add_argument('--ext', type=str, default='sep', | |||||
| help='dataset file extension') | |||||
| parser.add_argument('--scale', type=str, default='4', | |||||
| help='super resolution scale') | |||||
| parser.add_argument('--patch_size', type=int, default=48, | |||||
| help='output patch size') | |||||
| parser.add_argument('--rgb_range', type=int, default=255, | |||||
| help='maximum value of RGB') | |||||
| parser.add_argument('--n_colors', type=int, default=3, | |||||
| help='number of color channels to use') | |||||
| parser.add_argument('--chop', action='store_true', | |||||
| help='enable memory-efficient forward') | |||||
| parser.add_argument('--no_augment', action='store_true', | |||||
| help='do not use data augmentation') | |||||
| # Model specifications | |||||
| parser.add_argument('--model', default='vtip', | |||||
| help='model name') | |||||
| parser.add_argument('--act', type=str, default='relu', | |||||
| help='activation function') | |||||
| parser.add_argument('--pre_train', type=str, default='', | |||||
| help='pre-trained model directory') | |||||
| parser.add_argument('--extend', type=str, default='.', | |||||
| help='pre-trained model directory') | |||||
| parser.add_argument('--n_resblocks', type=int, default=16, | |||||
| help='number of residual blocks') | |||||
| parser.add_argument('--n_feats', type=int, default=64, | |||||
| help='number of feature maps') | |||||
| parser.add_argument('--res_scale', type=float, default=1, | |||||
| help='residual scaling') | |||||
| parser.add_argument('--shift_mean', default=True, | |||||
| help='subtract pixel mean from the input') | |||||
| parser.add_argument('--dilation', action='store_true', | |||||
| help='use dilated convolution') | |||||
| parser.add_argument('--precision', type=str, default='single', | |||||
| choices=('single', 'half'), | |||||
| help='FP precision for test (single | half)') | |||||
| # Option for Residual dense network (RDN) | |||||
| parser.add_argument('--G0', type=int, default=64, | |||||
| help='default number of filters. (Use in RDN)') | |||||
| parser.add_argument('--RDNkSize', type=int, default=3, | |||||
| help='default kernel size. (Use in RDN)') | |||||
| parser.add_argument('--RDNconfig', type=str, default='B', | |||||
| help='parameters config of RDN. (Use in RDN)') | |||||
| # Option for Residual channel attention network (RCAN) | |||||
| parser.add_argument('--n_resgroups', type=int, default=10, | |||||
| help='number of residual groups') | |||||
| parser.add_argument('--reduction', type=int, default=16, | |||||
| help='number of feature maps reduction') | |||||
| # Training specifications | |||||
| parser.add_argument('--reset', action='store_true', | |||||
| help='reset the training') | |||||
| parser.add_argument('--test_every', type=int, default=1000, | |||||
| help='do test per every N batches') | |||||
| parser.add_argument('--epochs', type=int, default=300, | |||||
| help='number of epochs to train') | |||||
| parser.add_argument('--batch_size', type=int, default=16, | |||||
| help='input batch size for training') | |||||
| parser.add_argument('--test_batch_size', type=int, default=1, | |||||
help='input batch size for testing')
| parser.add_argument('--split_batch', type=int, default=1, | |||||
| help='split the batch into smaller chunks') | |||||
| parser.add_argument('--self_ensemble', action='store_true', | |||||
| help='use self-ensemble method for test') | |||||
| parser.add_argument('--test_only', action='store_true', | |||||
| help='set this option to test the model') | |||||
| parser.add_argument('--gan_k', type=int, default=1, | |||||
| help='k value for adversarial loss') | |||||
| # Optimization specifications | |||||
| parser.add_argument('--lr', type=float, default=1e-4, | |||||
| help='learning rate') | |||||
| parser.add_argument('--decay', type=str, default='200', | |||||
| help='learning rate decay type') | |||||
| parser.add_argument('--gamma', type=float, default=0.5, | |||||
| help='learning rate decay factor for step decay') | |||||
| parser.add_argument('--optimizer', default='ADAM', | |||||
| choices=('SGD', 'ADAM', 'RMSprop'), | |||||
| help='optimizer to use (SGD | ADAM | RMSprop)') | |||||
| parser.add_argument('--momentum', type=float, default=0.9, | |||||
| help='SGD momentum') | |||||
| parser.add_argument('--betas', type=tuple, default=(0.9, 0.999), | |||||
| help='ADAM beta') | |||||
| parser.add_argument('--epsilon', type=float, default=1e-8, | |||||
| help='ADAM epsilon for numerical stability') | |||||
| parser.add_argument('--weight_decay', type=float, default=0, | |||||
| help='weight decay') | |||||
| parser.add_argument('--gclip', type=float, default=0, | |||||
| help='gradient clipping threshold (0 = no clipping)') | |||||
| # Loss specifications | |||||
| parser.add_argument('--loss', type=str, default='1*L1', | |||||
| help='loss function configuration') | |||||
parser.add_argument('--skip_threshold', type=float, default=1e8,
| help='skipping batch that has large error') | |||||
| # Log specifications | |||||
| parser.add_argument('--save', type=str, default='/cache/results/edsr_baseline_x2/', | |||||
| help='file name to save') | |||||
| parser.add_argument('--load', type=str, default='', | |||||
| help='file name to load') | |||||
| parser.add_argument('--resume', type=int, default=0, | |||||
| help='resume from specific checkpoint') | |||||
| parser.add_argument('--save_models', action='store_true', | |||||
| help='save all intermediate models') | |||||
| parser.add_argument('--print_every', type=int, default=100, | |||||
| help='how many batches to wait before logging training status') | |||||
| parser.add_argument('--save_results', action='store_true', | |||||
| help='save output results') | |||||
| parser.add_argument('--save_gt', action='store_true', | |||||
| help='save low-resolution and high-resolution images together') | |||||
| parser.add_argument('--scalelr', type=int, default=0) | |||||
| # cloud | |||||
| parser.add_argument('--moxfile', type=int, default=1) | |||||
| parser.add_argument('--imagenet', type=int, default=0) | |||||
| parser.add_argument('--data_url', type=str, help='path to dataset') | |||||
| parser.add_argument('--train_url', type=str, help='train_dir') | |||||
| parser.add_argument('--pretrain', type=str, default='') | |||||
| parser.add_argument('--pth_path', type=str, default='') | |||||
| parser.add_argument('--load_query', type=int, default=0) | |||||
| # transformer | |||||
| parser.add_argument('--patch_dim', type=int, default=3) | |||||
| parser.add_argument('--num_heads', type=int, default=12) | |||||
| parser.add_argument('--num_layers', type=int, default=12) | |||||
| parser.add_argument('--dropout_rate', type=float, default=0) | |||||
| parser.add_argument('--no_norm', action='store_true') | |||||
| parser.add_argument('--post_norm', action='store_true') | |||||
| parser.add_argument('--no_mlp', action='store_true') | |||||
| parser.add_argument('--test', action='store_true') | |||||
| parser.add_argument('--chop_new', action='store_true') | |||||
| parser.add_argument('--pos_every', action='store_true') | |||||
| parser.add_argument('--no_pos', action='store_true') | |||||
| parser.add_argument('--num_queries', type=int, default=6) | |||||
| parser.add_argument('--reweight', action='store_true') | |||||
| # denoise | |||||
| parser.add_argument('--denoise', action='store_true') | |||||
| parser.add_argument('--sigma', type=float, default=25) | |||||
| # derain | |||||
| parser.add_argument('--derain', action='store_true') | |||||
| parser.add_argument('--finetune', action='store_true') | |||||
| parser.add_argument('--derain_test', type=int, default=10) | |||||
| # alltask | |||||
| parser.add_argument('--alltask', action='store_true') | |||||
| # dehaze | |||||
| parser.add_argument('--dehaze', action='store_true') | |||||
| parser.add_argument('--dehaze_test', type=int, default=100) | |||||
| parser.add_argument('--indoor', action='store_true') | |||||
| parser.add_argument('--outdoor', action='store_true') | |||||
| parser.add_argument('--nochange', action='store_true') | |||||
| # deblur | |||||
| parser.add_argument('--deblur', action='store_true') | |||||
| parser.add_argument('--deblur_test', type=int, default=1000) | |||||
| # distribute | |||||
| parser.add_argument('--init_method', type=str, | |||||
| default=None, help='master address') | |||||
| parser.add_argument('--rank', type=int, default=0, | |||||
| help='Index of current task') | |||||
| parser.add_argument('--world_size', type=int, default=1, | |||||
| help='Total number of tasks') | |||||
| parser.add_argument('--gpu', default=None, type=int, | |||||
| help='GPU id to use.') | |||||
| parser.add_argument('--dist-url', default='', type=str, | |||||
| help='url used to set up distributed training') | |||||
| parser.add_argument('--dist-backend', default='nccl', type=str, | |||||
| help='distributed backend') | |||||
| parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', | |||||
| help='number of data loading workers (default: 4)') | |||||
| parser.add_argument('--distribute', action='store_true') | |||||
| args, unparsed = parser.parse_known_args() | |||||
| template.set_template(args) | |||||
| args.scale = [int(x) for x in args.scale.split("+")] | |||||
| args.data_train = args.data_train.split('+') | |||||
| args.data_test = args.data_test.split('+') | |||||
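# e.g. "--scale 2+3+4 --data_test Set5+Set14" yields args.scale == [2, 3, 4]
# and args.data_test == ['Set5', 'Set14']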
| if args.epochs == 0: | |||||
| args.epochs = 1e8 | |||||
| for arg in vars(args): | |||||
| if vars(args)[arg] == 'True': | |||||
| vars(args)[arg] = True | |||||
| elif vars(args)[arg] == 'False': | |||||
| vars(args)[arg] = False | |||||
| @@ -0,0 +1,35 @@ | |||||
| """data""" | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd" | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| from importlib import import_module | |||||
| class Data: | |||||
| """data""" | |||||
| def __init__(self, args): | |||||
| self.loader_train = None | |||||
| self.loader_test = [] | |||||
| for d in args.data_test: | |||||
| if d in ['Set5', 'Set14', 'B100', 'Urban100', 'Manga109', 'CBSD68', 'Rain100L', 'GOPRO_Large']: | |||||
| m = import_module('data.benchmark') | |||||
| testset = getattr(m, 'Benchmark')(args, train=False, name=d) | |||||
| else: | |||||
| module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG' | |||||
| m = import_module('data.' + module_name.lower()) | |||||
| testset = getattr(m, module_name)(args, train=False, name=d) | |||||
| self.loader_test.append( | |||||
| testset | |||||
| ) | |||||
| @@ -0,0 +1,93 @@ | |||||
| """common""" | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import random | |||||
| import numpy as np | |||||
| import skimage.color as sc | |||||
| def get_patch(*args, patch_size=96, scale=2, multi=False, input_large=False): | |||||
| """common""" | |||||
| ih, iw = args[0].shape[:2] | |||||
| tp = patch_size | |||||
| ip = tp // scale | |||||
| ix = random.randrange(0, iw - ip + 1) | |||||
| iy = random.randrange(0, ih - ip + 1) | |||||
| if not input_large: | |||||
| tx, ty = scale * ix, scale * iy | |||||
| else: | |||||
| tx, ty = ix, iy | |||||
| ret = [ | |||||
| args[0][iy:iy + ip, ix:ix + ip, :], | |||||
| *[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]] | |||||
| ] | |||||
| return ret | |||||
| def set_channel(*args, n_channels=3): | |||||
| """common""" | |||||
| def _set_channel(img): | |||||
| if img.ndim == 2: | |||||
| img = np.expand_dims(img, axis=2) | |||||
| c = img.shape[2] | |||||
| if n_channels == 1 and c == 3: | |||||
| img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2) | |||||
| elif n_channels == 3 and c == 1: | |||||
| img = np.concatenate([img] * n_channels, 2) | |||||
| return img[:, :, :n_channels] | |||||
| return [_set_channel(a) for a in args] | |||||
| def np2Tensor(*args, rgb_range=255): | |||||
| """common""" | |||||
| def _np2Tensor(img): | |||||
| np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1))) | |||||
| tensor = np_transpose.astype(np.float32) | |||||
| tensor = tensor * (rgb_range / 255) | |||||
| # tensor = torch.from_numpy(np_transpose).float() | |||||
| # tensor.mul_(rgb_range / 255) | |||||
| return tensor | |||||
| return [_np2Tensor(a) for a in args] | |||||
| def augment(*args, hflip=True, rot=True): | |||||
| """common""" | |||||
| hflip = hflip and random.random() < 0.5 | |||||
| vflip = rot and random.random() < 0.5 | |||||
| rot90 = rot and random.random() < 0.5 | |||||
| def _augment(img): | |||||
| if hflip: | |||||
| img = img[:, ::-1, :] | |||||
| if vflip: | |||||
| img = img[::-1, :, :] | |||||
| if rot90: | |||||
| img = img.transpose(1, 0, 2) | |||||
| return img | |||||
| return [_augment(a) for a in args] | |||||
| @@ -0,0 +1,301 @@ | |||||
| """srdata""" | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import os | |||||
| import glob | |||||
| import random | |||||
| import pickle | |||||
import numpy as np
import imageio
from src.data import common
| def search(root, target="JPEG"): | |||||
| """srdata""" | |||||
| item_list = [] | |||||
| items = os.listdir(root) | |||||
| for item in items: | |||||
| path = os.path.join(root, item) | |||||
| if os.path.isdir(path): | |||||
| item_list.extend(search(path, target)) | |||||
| elif path.split('/')[-1].startswith(target): | |||||
| item_list.append(path) | |||||
| elif target in (path.split('/')[-2], path.split('/')[-3], path.split('/')[-4]): | |||||
| item_list.append(path) | |||||
| return item_list | |||||
| def search_dehaze(root, target="JPEG"): | |||||
| """srdata""" | |||||
| item_list = [] | |||||
| items = os.listdir(root) | |||||
| for item in items: | |||||
| path = os.path.join(root, item) | |||||
| if os.path.isdir(path): | |||||
| extend_list = search_dehaze(path, target) | |||||
| if extend_list is not None: | |||||
| item_list.extend(extend_list) | |||||
| elif path.split('/')[-2].endswith(target): | |||||
| item_list.append(path) | |||||
| return item_list | |||||
| class SRData(): | |||||
| """srdata""" | |||||
| def __init__(self, args, name='', train=True, benchmark=False): | |||||
| self.args = args | |||||
| self.name = name | |||||
| self.train = train | |||||
| self.split = 'train' if train else 'test' | |||||
| self.do_eval = True | |||||
| self.benchmark = benchmark | |||||
| self.input_large = (args.model == 'VDSR') | |||||
| self.scale = args.scale | |||||
| self.idx_scale = 0 | |||||
| if self.args.derain: | |||||
| self.derain_test = os.path.join(args.dir_data, "Rain100L") | |||||
| self.derain_lr_test = search(self.derain_test, "rain") | |||||
| self.derain_hr_test = [path.replace( | |||||
| "rainy/", "no") for path in self.derain_lr_test] | |||||
| self._set_filesystem(args.dir_data) | |||||
| if args.ext.find('img') < 0: | |||||
| path_bin = os.path.join(self.apath, 'bin') | |||||
| os.makedirs(path_bin, exist_ok=True) | |||||
| list_hr, list_lr = self._scan() | |||||
| if args.ext.find('img') >= 0 or benchmark: | |||||
| self.images_hr, self.images_lr = list_hr, list_lr | |||||
| elif args.ext.find('sep') >= 0: | |||||
| os.makedirs( | |||||
| self.dir_hr.replace(self.apath, path_bin), | |||||
| exist_ok=True | |||||
| ) | |||||
| for s in self.scale: | |||||
| if s == 1: | |||||
| os.makedirs( | |||||
| os.path.join(self.dir_hr), | |||||
| exist_ok=True | |||||
| ) | |||||
| else: | |||||
| os.makedirs( | |||||
| os.path.join( | |||||
| self.dir_lr.replace(self.apath, path_bin), | |||||
| 'X{}'.format(s) | |||||
| ), | |||||
| exist_ok=True | |||||
| ) | |||||
| self.images_hr, self.images_lr = [], [[] for _ in self.scale] | |||||
| for h in list_hr: | |||||
| b = h.replace(self.apath, path_bin) | |||||
| b = b.replace(self.ext[0], '.pt') | |||||
| self.images_hr.append(b) | |||||
| self._check_and_load(args.ext, h, b, verbose=True) | |||||
| for i, ll in enumerate(list_lr): | |||||
| for l in ll: | |||||
| b = l.replace(self.apath, path_bin) | |||||
| b = b.replace(self.ext[1], '.pt') | |||||
| self.images_lr[i].append(b) | |||||
| self._check_and_load(args.ext, l, b, verbose=True) | |||||
# The functions below are used to prepare images
| def _scan(self): | |||||
| """srdata""" | |||||
| names_hr = sorted( | |||||
| glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0])) | |||||
| ) | |||||
| names_lr = [[] for _ in self.scale] | |||||
| for f in names_hr: | |||||
| filename, _ = os.path.splitext(os.path.basename(f)) | |||||
| for si, s in enumerate(self.scale): | |||||
| if s != 1: | |||||
| scale = s | |||||
| names_lr[si].append(os.path.join( | |||||
| self.dir_lr, 'X{}/{}x{}{}'.format( | |||||
| s, filename, scale, self.ext[1] | |||||
| ) | |||||
| )) | |||||
| for si, s in enumerate(self.scale): | |||||
| if s == 1: | |||||
| names_lr[si] = names_hr | |||||
| return names_hr, names_lr | |||||
| def _set_filesystem(self, dir_data): | |||||
| self.apath = os.path.join(dir_data, self.name[0]) | |||||
| self.dir_hr = os.path.join(self.apath, 'HR') | |||||
| self.dir_lr = os.path.join(self.apath, 'LR_bicubic') | |||||
| self.ext = ('.png', '.png') | |||||
| def _check_and_load(self, ext, img, f, verbose=True): | |||||
| if not os.path.isfile(f) or ext.find('reset') >= 0: | |||||
| if verbose: | |||||
| print('Making a binary: {}'.format(f)) | |||||
| with open(f, 'wb') as _f: | |||||
| pickle.dump(imageio.imread(img), _f) | |||||
| def __getitem__(self, idx): | |||||
| if self.args.model == 'vtip' and self.args.derain and self.scale[ | |||||
| self.idx_scale] == 1 and not self.args.finetune: | |||||
| norain, rain, _ = self._load_rain_test(idx) | |||||
| pair = common.set_channel( | |||||
| *[rain, norain], n_channels=self.args.n_colors) | |||||
| pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range) | |||||
| return pair_t[0], pair_t[1] | |||||
| if self.args.model == 'vtip' and self.args.denoise and self.scale[self.idx_scale] == 1: | |||||
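# Denoising: load the clean HR image (cropping a patch only when training) and
# synthesize the noisy input by adding Gaussian noise with std args.sigma.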
| hr, _ = self._load_file_hr(idx) | |||||
| pair = self.get_patch_hr(hr) | |||||
| pair = common.set_channel(*[pair], n_channels=self.args.n_colors) | |||||
| pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range) | |||||
| noise = np.random.randn(*pair_t[0].shape) * self.args.sigma | |||||
| lr = pair_t[0] + noise | |||||
| lr = np.float32(np.clip(lr, 0, 255)) | |||||
| return lr, pair_t[0] | |||||
| lr, hr, _ = self._load_file(idx) | |||||
| pair = self.get_patch(lr, hr) | |||||
| pair = common.set_channel(*pair, n_channels=self.args.n_colors) | |||||
| pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range) | |||||
| return pair_t[0], pair_t[1] | |||||
| def __len__(self): | |||||
| if self.train: | |||||
| return len(self.images_hr) * self.repeat | |||||
| if self.args.derain and not self.args.alltask: | |||||
| return int(len(self.derain_hr_test) / self.args.derain_test) | |||||
| return len(self.images_hr) | |||||
| def _get_index(self, idx): | |||||
| """srdata""" | |||||
| if self.train: | |||||
| return idx % len(self.images_hr) | |||||
| return idx | |||||
| def _load_file_hr(self, idx): | |||||
| """srdata""" | |||||
| idx = self._get_index(idx) | |||||
| f_hr = self.images_hr[idx] | |||||
| filename, _ = os.path.splitext(os.path.basename(f_hr)) | |||||
| if self.args.ext == 'img' or self.benchmark: | |||||
| hr = imageio.imread(f_hr) | |||||
| elif self.args.ext.find('sep') >= 0: | |||||
| with open(f_hr, 'rb') as _f: | |||||
| hr = pickle.load(_f) | |||||
| return hr, filename | |||||
| def _load_rain(self, idx, rain_img=False): | |||||
| """srdata""" | |||||
| idx = random.randint(0, len(self.derain_img_list) - 1) | |||||
| f_lr = self.derain_img_list[idx] | |||||
| if rain_img: | |||||
| norain = imageio.imread(f_lr.replace("rainstreak", "norain")) | |||||
| rain = imageio.imread(f_lr.replace("rainstreak", "rain")) | |||||
| return norain, rain | |||||
| lr = imageio.imread(f_lr) | |||||
| return lr | |||||
| def _load_rain_test(self, idx): | |||||
| """srdata""" | |||||
| f_hr = self.derain_hr_test[idx] | |||||
| f_lr = self.derain_lr_test[idx] | |||||
| filename, _ = os.path.splitext(os.path.basename(f_lr)) | |||||
| norain = imageio.imread(f_hr) | |||||
| rain = imageio.imread(f_lr) | |||||
| return norain, rain, filename | |||||
| def _load_denoise(self, idx): | |||||
| """srdata""" | |||||
| idx = self._get_index(idx) | |||||
| f_lr = self.images_hr[idx] | |||||
| norain = imageio.imread(f_lr) | |||||
| rain = imageio.imread(f_lr.replace("HR", "LR_bicubic")) | |||||
| return norain, rain | |||||
| def _load_file(self, idx): | |||||
| """srdata""" | |||||
| idx = self._get_index(idx) | |||||
| f_hr = self.images_hr[idx] | |||||
| f_lr = self.images_lr[self.idx_scale][idx] | |||||
| filename, _ = os.path.splitext(os.path.basename(f_hr)) | |||||
| if self.args.ext == 'img' or self.benchmark: | |||||
| hr = imageio.imread(f_hr) | |||||
| lr = imageio.imread(f_lr) | |||||
| elif self.args.ext.find('sep') >= 0: | |||||
| with open(f_hr, 'rb') as _f: | |||||
| hr = pickle.load(_f) | |||||
| with open(f_lr, 'rb') as _f: | |||||
| lr = pickle.load(_f) | |||||
| return lr, hr, filename | |||||
| def get_patch_hr(self, hr): | |||||
| """srdata""" | |||||
| if self.train: | |||||
| hr = self.get_patch_img_hr( | |||||
| hr, | |||||
| patch_size=self.args.patch_size, | |||||
| scale=1 | |||||
| ) | |||||
| return hr | |||||
| def get_patch_img_hr(self, img, patch_size=96, scale=2): | |||||
| """srdata""" | |||||
| ih, iw = img.shape[:2] | |||||
| tp = patch_size | |||||
| ip = tp // scale | |||||
| ix = random.randrange(0, iw - ip + 1) | |||||
| iy = random.randrange(0, ih - ip + 1) | |||||
| ret = img[iy:iy + ip, ix:ix + ip, :] | |||||
| return ret | |||||
| def get_patch(self, lr, hr): | |||||
| """srdata""" | |||||
| scale = self.scale[self.idx_scale] | |||||
| if self.train: | |||||
| lr, hr = common.get_patch( | |||||
| lr, hr, | |||||
| patch_size=self.args.patch_size * scale, | |||||
| scale=scale, | |||||
| multi=(len(self.scale) > 1) | |||||
| ) | |||||
| if not self.args.no_augment: | |||||
| lr, hr = common.augment(lr, hr) | |||||
| else: | |||||
| ih, iw = lr.shape[:2] | |||||
| hr = hr[0:ih * scale, 0:iw * scale] | |||||
| return lr, hr | |||||
| def set_scale(self, idx_scale): | |||||
| """srdata""" | |||||
| if not self.input_large: | |||||
| self.idx_scale = idx_scale | |||||
| else: | |||||
| self.idx_scale = random.randint(0, len(self.scale) - 1) | |||||
| @@ -0,0 +1,241 @@ | |||||
| '''stride''' | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| from mindspore import nn | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore.common import dtype as mstype | |||||
| class _stride_unfold_(nn.Cell): | |||||
| '''stride''' | |||||
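# Extract kernel_size x kernel_size blocks starting every `stride` pixels,
# place them on a zero canvas, and flatten them with _unfold_ into
# shape (N, num_blocks, C * kernel_size * kernel_size).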
| def __init__(self, | |||||
| kernel_size, | |||||
| stride=-1): | |||||
| super(_stride_unfold_, self).__init__() | |||||
| if stride == -1: | |||||
| self.stride = kernel_size | |||||
| else: | |||||
| self.stride = stride | |||||
| self.kernel_size = kernel_size | |||||
| self.reshape = P.Reshape() | |||||
| self.transpose = P.Transpose() | |||||
| self.unfold = _unfold_(kernel_size) | |||||
| def construct(self, x): | |||||
| """stride""" | |||||
| N, C, H, W = x.shape | |||||
| leftup_idx_x = [] | |||||
| leftup_idx_y = [] | |||||
| nh = int(H / self.stride) | |||||
| nw = int(W / self.stride) | |||||
| for i in range(nh): | |||||
| leftup_idx_x.append(i * self.stride) | |||||
| for i in range(nw): | |||||
| leftup_idx_y.append(i * self.stride) | |||||
| NumBlock_x = len(leftup_idx_x) | |||||
| NumBlock_y = len(leftup_idx_y) | |||||
| zeroslike = P.ZerosLike() | |||||
| cc_2 = P.Concat(axis=2) | |||||
| cc_3 = P.Concat(axis=3) | |||||
| unf_x = P.Zeros()((N, C, NumBlock_x * self.kernel_size, | |||||
| NumBlock_y * self.kernel_size), mstype.float32) | |||||
| N, C, H, W = unf_x.shape | |||||
| for i in range(NumBlock_x): | |||||
| for j in range(NumBlock_y): | |||||
| unf_i = i * self.kernel_size | |||||
| unf_j = j * self.kernel_size | |||||
| org_i = leftup_idx_x[i] | |||||
| org_j = leftup_idx_y[j] | |||||
| fills = x[:, :, org_i:org_i + self.kernel_size, | |||||
| org_j:org_j + self.kernel_size] | |||||
| unf_x += cc_3((cc_3((zeroslike(unf_x[:, :, :, :unf_j]), cc_2((cc_2( | |||||
| (zeroslike(unf_x[:, :, :unf_i, unf_j:unf_j + self.kernel_size]), fills)), zeroslike( | |||||
| unf_x[:, :, unf_i + self.kernel_size:, unf_j:unf_j + self.kernel_size]))))), | |||||
| zeroslike(unf_x[:, :, :, unf_j + self.kernel_size:]))) | |||||
| y = self.unfold(unf_x) | |||||
| return y | |||||
| class _stride_fold_(nn.Cell): | |||||
| '''stride''' | |||||
| def __init__(self, | |||||
| kernel_size, | |||||
| output_shape=(-1, -1), | |||||
| stride=-1): | |||||
| super(_stride_fold_, self).__init__() | |||||
| if isinstance(kernel_size, (list, tuple)): | |||||
| self.kernel_size = kernel_size | |||||
| else: | |||||
| self.kernel_size = [kernel_size, kernel_size] | |||||
| if stride == -1: | |||||
| self.stride = kernel_size[0] | |||||
| else: | |||||
| self.stride = stride | |||||
| self.output_shape = output_shape | |||||
| self.reshape = P.Reshape() | |||||
| self.transpose = P.Transpose() | |||||
| self.fold = _fold_(kernel_size) | |||||
| def construct(self, x): | |||||
'''stride'''
# concat/zeros helpers used in both branches below (previously referenced
# without being defined, which raised a NameError)
zeroslike = P.ZerosLike()
cc_2 = P.Concat(axis=2)
cc_3 = P.Concat(axis=3)
| if self.output_shape[0] == -1: | |||||
| large_x = self.fold(x) | |||||
| N, C, H, _ = large_x.shape | |||||
| leftup_idx = [] | |||||
| for i in range(0, H, self.kernel_size[0]): | |||||
| leftup_idx.append(i) | |||||
| NumBlock = len(leftup_idx) | |||||
| fold_x = P.Zeros()((N, C, (NumBlock - 1) * self.stride + self.kernel_size[0], | |||||
| (NumBlock - 1) * self.stride + self.kernel_size[0]), mstype.float32) | |||||
| for i in range(NumBlock): | |||||
| for j in range(NumBlock): | |||||
| fold_i = i * self.stride | |||||
| fold_j = j * self.stride | |||||
| org_i = leftup_idx[i] | |||||
| org_j = leftup_idx[j] | |||||
| fills = x[:, :, org_i:org_i + self.kernel_size[0], | |||||
| org_j:org_j + self.kernel_size[1]] | |||||
| fold_x += cc_3((cc_3((zeroslike(fold_x[:, :, :, :fold_j]), cc_2((cc_2( | |||||
| (zeroslike(fold_x[:, :, :fold_i, fold_j:fold_j + self.kernel_size[1]]), fills)), zeroslike( | |||||
| fold_x[:, :, fold_i + self.kernel_size[0]:, fold_j:fold_j + self.kernel_size[1]]))))), | |||||
| zeroslike(fold_x[:, :, :, fold_j + self.kernel_size[1]:]))) | |||||
| y = fold_x | |||||
| else: | |||||
| NumBlock_x = int( | |||||
| (self.output_shape[0] - self.kernel_size[0]) / self.stride + 1) | |||||
| NumBlock_y = int( | |||||
| (self.output_shape[1] - self.kernel_size[1]) / self.stride + 1) | |||||
| large_shape = [NumBlock_x * self.kernel_size[0], | |||||
| NumBlock_y * self.kernel_size[1]] | |||||
| self.fold = _fold_(self.kernel_size, large_shape) | |||||
| large_x = self.fold(x) | |||||
| N, C, H, _ = large_x.shape | |||||
| leftup_idx_x = [] | |||||
| leftup_idx_y = [] | |||||
| for i in range(NumBlock_x): | |||||
| leftup_idx_x.append(i * self.kernel_size[0]) | |||||
| for i in range(NumBlock_y): | |||||
| leftup_idx_y.append(i * self.kernel_size[1]) | |||||
| fold_x = P.Zeros()((N, C, (NumBlock_x - 1) * self.stride + self.kernel_size[0], | |||||
| (NumBlock_y - 1) * self.stride + self.kernel_size[1]), mstype.float32) | |||||
| for i in range(NumBlock_x): | |||||
| for j in range(NumBlock_y): | |||||
| fold_i = i * self.stride | |||||
| fold_j = j * self.stride | |||||
| org_i = leftup_idx_x[i] | |||||
| org_j = leftup_idx_y[j] | |||||
| fills = x[:, :, org_i:org_i + self.kernel_size[0], | |||||
| org_j:org_j + self.kernel_size[1]] | |||||
| fold_x += cc_3((cc_3((zeroslike(fold_x[:, :, :, :fold_j]), cc_2((cc_2( | |||||
| (zeroslike(fold_x[:, :, :fold_i, fold_j:fold_j + self.kernel_size[1]]), fills)), zeroslike( | |||||
| fold_x[:, :, fold_i + self.kernel_size[0]:, fold_j:fold_j + self.kernel_size[1]]))))), | |||||
| zeroslike(fold_x[:, :, :, fold_j + self.kernel_size[1]:]))) | |||||
| y = fold_x | |||||
| return y | |||||
| class _unfold_(nn.Cell): | |||||
| '''stride''' | |||||
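# Split an NCHW tensor into non-overlapping kernel_size x kernel_size patches
# and flatten each one, returning shape (N, numH * numW, C * kernel_size * kernel_size).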
| def __init__(self, | |||||
| kernel_size, | |||||
| stride=-1): | |||||
| super(_unfold_, self).__init__() | |||||
| if stride == -1: | |||||
| self.stride = kernel_size | |||||
| self.kernel_size = kernel_size | |||||
| self.reshape = P.Reshape() | |||||
| self.transpose = P.Transpose() | |||||
| def construct(self, x): | |||||
| '''stride''' | |||||
| N, C, H, W = x.shape | |||||
| numH = int(H / self.kernel_size) | |||||
| numW = int(W / self.kernel_size) | |||||
| if numH * self.kernel_size != H or numW * self.kernel_size != W: | |||||
x = x[:, :, :numH * self.kernel_size, :numW * self.kernel_size]
| output_img = self.reshape(x, (N, C, numH, self.kernel_size, W)) | |||||
| output_img = self.transpose(output_img, (0, 1, 2, 4, 3)) | |||||
| output_img = self.reshape(output_img, (N, C, int( | |||||
| numH * numW), self.kernel_size, self.kernel_size)) | |||||
| output_img = self.transpose(output_img, (0, 2, 1, 4, 3)) | |||||
| output_img = self.reshape(output_img, (N, int(numH * numW), -1)) | |||||
| return output_img | |||||
| class _fold_(nn.Cell): | |||||
| '''stride''' | |||||
| def __init__(self, | |||||
| kernel_size, | |||||
| output_shape=(-1, -1), | |||||
| stride=-1): | |||||
| super(_fold_, self).__init__() | |||||
| if isinstance(kernel_size, (list, tuple)): | |||||
| self.kernel_size = kernel_size | |||||
| else: | |||||
| self.kernel_size = [kernel_size, kernel_size] | |||||
| if stride == -1: | |||||
| self.stride = kernel_size[0] | |||||
| self.output_shape = output_shape | |||||
| self.reshape = P.Reshape() | |||||
| self.transpose = P.Transpose() | |||||
| def construct(self, x): | |||||
| '''stride''' | |||||
| N, C, L = x.shape | |||||
| org_C = int(L / self.kernel_size[0] / self.kernel_size[1]) | |||||
| if self.output_shape[0] == -1: | |||||
| numH = int(np.sqrt(C)) | |||||
| numW = int(np.sqrt(C)) | |||||
| org_H = int(numH * self.kernel_size[0]) | |||||
| org_W = org_H | |||||
| else: | |||||
| org_H = int(self.output_shape[0]) | |||||
| org_W = int(self.output_shape[1]) | |||||
| numH = int(org_H / self.kernel_size[0]) | |||||
| numW = int(org_W / self.kernel_size[1]) | |||||
| output_img = self.reshape( | |||||
| x, (N, C, org_C, self.kernel_size[0], self.kernel_size[1])) | |||||
| output_img = self.transpose(output_img, (0, 2, 1, 3, 4)) | |||||
| output_img = self.reshape( | |||||
| output_img, (N, org_C, numH, numW, self.kernel_size[0], self.kernel_size[1])) | |||||
| output_img = self.transpose(output_img, (0, 1, 2, 4, 3, 5)) | |||||
| output_img = self.reshape(output_img, (N, org_C, org_H, org_W)) | |||||
| return output_img | |||||
| @@ -0,0 +1,56 @@ | |||||
| '''metrics''' | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import math | |||||
| import numpy as np | |||||
| def quantize(img, rgb_range): | |||||
| '''metrics''' | |||||
| pixel_range = 255 / rgb_range | |||||
| img = np.multiply(img, pixel_range) | |||||
| img = np.clip(img, 0, 255) | |||||
| img = np.round(img) / pixel_range | |||||
| return img | |||||
| def calc_psnr(sr, hr, scale, rgb_range, y_only=False, dataset=None): | |||||
| '''metrics''' | |||||
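# PSNR on the Y channel: convert the RGB difference to luma with BT.601
# coefficients, shave a `scale`-pixel border when scale > 1, and return
# -10 * log10(MSE) over the remaining pixels.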
| hr = np.float32(hr) | |||||
| sr = np.float32(sr) | |||||
| diff = (sr - hr) / rgb_range | |||||
| gray_coeffs = np.array([65.738, 129.057, 25.064] | |||||
| ).reshape((1, 3, 1, 1)) / 256 | |||||
| diff = np.multiply(diff, gray_coeffs).sum(1) | |||||
| if hr.size == 1: | |||||
| return 0 | |||||
| if scale != 1: | |||||
| shave = scale | |||||
| else: | |||||
| shave = scale + 6 | |||||
| if scale == 1: | |||||
| valid = diff | |||||
| else: | |||||
| valid = diff[..., shave:-shave, shave:-shave] | |||||
| mse = np.mean(pow(valid, 2)) | |||||
| return -10 * math.log10(mse) | |||||
| def rgb2ycbcr(img, y_only=True): | |||||
| '''metrics''' | |||||
img = img.astype(np.float32)
| if y_only: | |||||
| rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0 | |||||
| return rlt | |||||
| @@ -0,0 +1,67 @@ | |||||
| '''temp''' | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| def set_template(args): | |||||
| '''temp''' | |||||
| if args.template.find('jpeg') >= 0: | |||||
| args.data_train = 'DIV2K_jpeg' | |||||
| args.data_test = 'DIV2K_jpeg' | |||||
| args.epochs = 200 | |||||
| args.decay = '100' | |||||
| if args.template.find('EDSR_paper') >= 0: | |||||
| args.model = 'EDSR' | |||||
| args.n_resblocks = 32 | |||||
| args.n_feats = 256 | |||||
| args.res_scale = 0.1 | |||||
| if args.template.find('MDSR') >= 0: | |||||
| args.model = 'MDSR' | |||||
| args.patch_size = 48 | |||||
| args.epochs = 650 | |||||
| if args.template.find('DDBPN') >= 0: | |||||
| args.model = 'DDBPN' | |||||
| args.patch_size = 128 | |||||
| args.scale = '4' | |||||
| args.data_test = 'Set5' | |||||
| args.batch_size = 20 | |||||
| args.epochs = 1000 | |||||
| args.decay = '500' | |||||
| args.gamma = 0.1 | |||||
| args.weight_decay = 1e-4 | |||||
| args.loss = '1*MSE' | |||||
| if args.template.find('GAN') >= 0: | |||||
| args.epochs = 200 | |||||
| args.lr = 5e-5 | |||||
| args.decay = '150' | |||||
| if args.template.find('RCAN') >= 0: | |||||
| args.model = 'RCAN' | |||||
| args.n_resgroups = 10 | |||||
| args.n_resblocks = 20 | |||||
| args.n_feats = 64 | |||||
| args.chop = True | |||||
| if args.template.find('VDSR') >= 0: | |||||
| args.model = 'VDSR' | |||||
| args.n_resblocks = 20 | |||||
| args.n_feats = 64 | |||||
| args.patch_size = 41 | |||||
| args.lr = 1e-1 | |||||