| @@ -0,0 +1,68 @@ | |||
| """eval script""" | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| from src import ipt | |||
| from src.args import args | |||
| from src.data.srdata import SRData | |||
| from src.metrics import calc_psnr, quantize | |||
| from mindspore import context | |||
| import mindspore.dataset as de | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU", device_id=0) | |||
| def main(): | |||
| """eval""" | |||
| for arg in vars(args): | |||
| if vars(args)[arg] == 'True': | |||
| vars(args)[arg] = True | |||
| elif vars(args)[arg] == 'False': | |||
| vars(args)[arg] = False | |||
| train_dataset = SRData(args, name=args.data_test, train=False, benchmark=False) | |||
| train_de_dataset = de.GeneratorDataset(train_dataset, ['LR', "HR"], shuffle=False) | |||
| train_de_dataset = train_de_dataset.batch(1, drop_remainder=True) | |||
| train_loader = train_de_dataset.create_dict_iterator() | |||
| net_m = ipt.IPT(args) | |||
| print('load mindspore net successfully.') | |||
| if args.pth_path: | |||
| param_dict = load_checkpoint(args.pth_path) | |||
| load_param_into_net(net_m, param_dict) | |||
| net_m.set_train(False) | |||
| num_imgs = train_de_dataset.get_dataset_size() | |||
| psnrs = np.zeros((num_imgs, 1)) | |||
| for batch_idx, imgs in enumerate(train_loader): | |||
| lr = imgs['LR'] | |||
| hr = imgs['HR'] | |||
| hr_np = np.float32(hr.asnumpy()) | |||
| pred = net_m.infrc(lr) | |||
| pred_np = np.float32(pred.asnumpy()) | |||
| pred_np = quantize(pred_np, 255) | |||
| psnr = calc_psnr(pred_np, hr_np, 4, 255.0, y_only=True) | |||
| psnrs[batch_idx, 0] = psnr | |||
| if args.denoise: | |||
| print('Mean psnr of %s DN_%s is %.4f' % (args.data_test[0], args.sigma, psnrs.mean(axis=0)[0])) | |||
| elif args.derain: | |||
| print('Mean psnr of Derain is %.4f' % (psnrs.mean(axis=0))) | |||
| else: | |||
| print('Mean psnr of %s x%s is %.4f' % (args.data_test[0], args.scale[0], psnrs.mean(axis=0)[0])) | |||
| if __name__ == '__main__': | |||
| print("Start main function!") | |||
| main() | |||
| @@ -0,0 +1,26 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """hub config.""" | |||
| from src.vitm import ViT | |||
| def IPT(*args, **kwargs): | |||
| return ViT(*args, **kwargs) | |||
| def create_network(name, *args, **kwargs): | |||
| if name == 'IPT': | |||
| return IPT(*args, **kwargs) | |||
| raise NotImplementedError(f"{name} is not implemented in the repo") | |||
| @@ -0,0 +1,147 @@ | |||
| <TOC> | |||
| # Pre-Trained Image Processing Transformer (IPT) | |||
| This repository is an official implementation of the paper "Pre-Trained Image Processing Transformer" from CVPR 2021. | |||
| We study the low-level computer vision task (e.g., denoising, super-resolution and deraining) and develop a new pre-trained model, namely, image processing transformer (IPT). To maximally excavate the capability of transformer, we present to utilize the well-known ImageNet benchmark for generating a large amount of corrupted image pairs. The IPT model is trained on these images with multi-heads and multi-tails. In addition, the contrastive learning is introduced for well adapting to different image processing tasks. The pre-trained model can therefore efficiently employed on desired task after fine-tuning. With only one pre-trained model, IPT outperforms the current state-of-the-art methods on various low-level benchmarks. | |||
| If you find our work useful in your research or publication, please cite our work: | |||
| [1] Hanting Chen, Yunhe Wang, Tianyu Guo, Chang Xu, Yiping Deng, Zhenhua Liu, Siwei Ma, Chunjing Xu, Chao Xu, and Wen Gao. **"Pre-trained image processing transformer"**. <i>**CVPR 2021**.</i> [[arXiv](https://arxiv.org/abs/2012.00364)] | |||
| @inproceedings{chen2020pre, | |||
| title={Pre-trained image processing transformer}, | |||
| author={Chen, Hanting and Wang, Yunhe and Guo, Tianyu and Xu, Chang and Deng, Yiping and Liu, Zhenhua and Ma, Siwei and Xu, Chunjing and Xu, Chao and Gao, Wen}, | |||
| booktitle={CVPR}, | |||
| year={2021} | |||
| } | |||
| ## Model architecture | |||
| ### The overall network architecture of IPT is shown as below: | |||
|  | |||
| ## Dataset | |||
| The benchmark datasets can be downloaded as follows: | |||
| For super-resolution: | |||
| Set5, | |||
| [Set14](https://sites.google.com/site/romanzeyde/research-interests), | |||
| [B100](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/), | |||
| [Urban100](https://sites.google.com/site/jbhuang0604/publications/struct_sr). | |||
| For denoising: | |||
| [CBSD68](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/). | |||
| For deraining: | |||
| [Rain100L](https://www.icst.pku.edu.cn/struct/Projects/joint_rain_removal.html) | |||
| The result images are converted into YCbCr color space. The PSNR is evaluated on the Y channel only. | |||
| ## Requirements | |||
| ### Hardware (GPU) | |||
| > Prepare hardware environment with GPU. | |||
| ### Framework | |||
| > [MindSpore](https://www.mindspore.cn/install/en) | |||
| ### For more information, please check the resources below: | |||
| [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html) | |||
| [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html) | |||
| ## Script Description | |||
| > This is the inference script of IPT, you can following steps to finish the test of image processing tasks, like SR, denoise and derain, via the corresponding pretrained models. | |||
| ### Scripts and Sample Code | |||
| ``` | |||
| IPT | |||
| ├── eval.py # inference entry | |||
| ├── image | |||
| │ └── ipt.png # the illustration of IPT network | |||
| ├── model | |||
| │ ├── IPT_denoise30.ckpt # denoise model weights for noise level 30 | |||
| │ ├── IPT_denoise50.ckpt # denoise model weights for noise level 50 | |||
| │ ├── IPT_derain.ckpt # derain model weights | |||
| │ ├── IPT_sr2.ckpt # X2 super-resolution model weights | |||
| │ ├── IPT_sr3.ckpt # X3 super-resolution model weights | |||
| │ └── IPT_sr4.ckpt # X4 super-resolution model weights | |||
| ├── readme.md # Readme | |||
| ├── scripts | |||
| │ └── run_eval.sh # inference script for all tasks | |||
| └── src | |||
| ├── args.py # options/hyper-parameters of IPT | |||
| ├── data | |||
| │ ├── common.py # common dataset | |||
| │ ├── __init__.py # Class data init function | |||
| │ └── srdata.py # flow of loading sr data | |||
| ├── foldunfold_stride.py # function of fold and unfold operations for images | |||
| ├── metrics.py # PSNR calculator | |||
| ├── template.py # setting of model selection | |||
| └── vitm.py # IPT network | |||
| ``` | |||
| ### Script Parameter | |||
| > For details about hyperparameters, see src/args.py. | |||
| ## Evaluation | |||
| ### Evaluation Process | |||
| > Inference example: | |||
| > For SR x4: | |||
| ``` | |||
| python eval.py --dir_data ../../data/ --data_test Set14 --nochange --test_only --ext img --chop_new --scale 4 --pth_path ./model/IPT_sr4.ckpt | |||
| ``` | |||
| > Or one can run following script for all tasks. | |||
| ``` | |||
| sh scripts/run_eval.sh | |||
| ``` | |||
| ### Evaluation Result | |||
| The result are evaluated by the value of PSNR (Peak Signal-to-Noise Ratio), and the format is as following. | |||
| ``` | |||
| result: {"Mean psnr of Se5 x4 is 32.68"} | |||
| ``` | |||
| ## Performance | |||
| ### Inference Performance | |||
| The Results on all tasks are listed as below. | |||
| Super-resolution results: | |||
| | Scale | Set5 | Set14 | B100 | Urban100 | | |||
| | ----- | ----- | ----- | ----- | ----- | | |||
| | ×2 | 38.36 | 34.54 | 32.50 | 33.88 | | |||
| | ×3 | 34.83 | 30.96 | 29.39 | 29.59 | | |||
| | ×4 | 32.68 | 29.01 | 27.81 | 27.24 | | |||
| Denoising results: | |||
| | noisy level | CBSD68 | Urban100 | | |||
| | ----- | ----- | ----- | | |||
| | 30 | 32.37 | 33.82 | | |||
| | 50 | 29.94 | 31.56 | | |||
| Derain results: | |||
| | Task | Rain100L | | |||
| | ----- | ----- | | |||
| | Derain | 41.98 | | |||
| ## ModeZoo Homepage | |||
| Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo). | |||
| @@ -0,0 +1,31 @@ | |||
| #!/bin/bash | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| export DEVICE_ID=$1 | |||
| DATA_DIR=$2 | |||
| DATA_SET=$3 | |||
| PATH_CHECKPOINT=$4 | |||
| python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 4 --pth_path=$PATH_CHECKPOINT > eval.log 2>&1 & | |||
| python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 3 --pth_path=$PATH_CHECKPOINT > eval.log 2>&1 & | |||
| python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 2 --pth_path=$PATH_CHECKPOINT > eval.log 2>&1 & | |||
| ##denoise | |||
| python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 1 --denoise --sigma 30 --pth_path=$PATH_CHECKPOINT > eval.log 2>&1 & | |||
| python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 1 --denoise --sigma 50 --pth_path=$PATH_CHECKPOINT > eval.log 2>&1 & | |||
| ##derain | |||
| python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 1 --derain --derain_test 1 --pth_path=$PATH_CHECKPOINT > eval.log 2>&1 & | |||
| @@ -0,0 +1,239 @@ | |||
| '''args''' | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import argparse | |||
| from src import template | |||
| parser = argparse.ArgumentParser(description='EDSR and MDSR') | |||
| parser.add_argument('--debug', action='store_true', | |||
| help='Enables debug mode') | |||
| parser.add_argument('--template', default='.', | |||
| help='You can set various templates in option.py') | |||
| # Hardware specifications | |||
| parser.add_argument('--n_threads', type=int, default=6, | |||
| help='number of threads for data loading') | |||
| parser.add_argument('--cpu', action='store_true', | |||
| help='use cpu only') | |||
| parser.add_argument('--n_GPUs', type=int, default=1, | |||
| help='number of GPUs') | |||
| parser.add_argument('--seed', type=int, default=1, | |||
| help='random seed') | |||
| # Data specifications | |||
| parser.add_argument('--dir_data', type=str, default='/cache/data/', | |||
| help='dataset directory') | |||
| parser.add_argument('--dir_demo', type=str, default='../test', | |||
| help='demo image directory') | |||
| parser.add_argument('--data_train', type=str, default='DIV2K', | |||
| help='train dataset name') | |||
| parser.add_argument('--data_test', type=str, default='DIV2K', | |||
| help='test dataset name') | |||
| parser.add_argument('--data_range', type=str, default='1-800/801-810', | |||
| help='train/test data range') | |||
| parser.add_argument('--ext', type=str, default='sep', | |||
| help='dataset file extension') | |||
| parser.add_argument('--scale', type=str, default='4', | |||
| help='super resolution scale') | |||
| parser.add_argument('--patch_size', type=int, default=48, | |||
| help='output patch size') | |||
| parser.add_argument('--rgb_range', type=int, default=255, | |||
| help='maximum value of RGB') | |||
| parser.add_argument('--n_colors', type=int, default=3, | |||
| help='number of color channels to use') | |||
| parser.add_argument('--chop', action='store_true', | |||
| help='enable memory-efficient forward') | |||
| parser.add_argument('--no_augment', action='store_true', | |||
| help='do not use data augmentation') | |||
| # Model specifications | |||
| parser.add_argument('--model', default='vtip', | |||
| help='model name') | |||
| parser.add_argument('--act', type=str, default='relu', | |||
| help='activation function') | |||
| parser.add_argument('--pre_train', type=str, default='', | |||
| help='pre-trained model directory') | |||
| parser.add_argument('--extend', type=str, default='.', | |||
| help='pre-trained model directory') | |||
| parser.add_argument('--n_resblocks', type=int, default=16, | |||
| help='number of residual blocks') | |||
| parser.add_argument('--n_feats', type=int, default=64, | |||
| help='number of feature maps') | |||
| parser.add_argument('--res_scale', type=float, default=1, | |||
| help='residual scaling') | |||
| parser.add_argument('--shift_mean', default=True, | |||
| help='subtract pixel mean from the input') | |||
| parser.add_argument('--dilation', action='store_true', | |||
| help='use dilated convolution') | |||
| parser.add_argument('--precision', type=str, default='single', | |||
| choices=('single', 'half'), | |||
| help='FP precision for test (single | half)') | |||
| # Option for Residual dense network (RDN) | |||
| parser.add_argument('--G0', type=int, default=64, | |||
| help='default number of filters. (Use in RDN)') | |||
| parser.add_argument('--RDNkSize', type=int, default=3, | |||
| help='default kernel size. (Use in RDN)') | |||
| parser.add_argument('--RDNconfig', type=str, default='B', | |||
| help='parameters config of RDN. (Use in RDN)') | |||
| # Option for Residual channel attention network (RCAN) | |||
| parser.add_argument('--n_resgroups', type=int, default=10, | |||
| help='number of residual groups') | |||
| parser.add_argument('--reduction', type=int, default=16, | |||
| help='number of feature maps reduction') | |||
| # Training specifications | |||
| parser.add_argument('--reset', action='store_true', | |||
| help='reset the training') | |||
| parser.add_argument('--test_every', type=int, default=1000, | |||
| help='do test per every N batches') | |||
| parser.add_argument('--epochs', type=int, default=300, | |||
| help='number of epochs to train') | |||
| parser.add_argument('--batch_size', type=int, default=16, | |||
| help='input batch size for training') | |||
| parser.add_argument('--test_batch_size', type=int, default=1, | |||
| help='input batch size for training') | |||
| parser.add_argument('--split_batch', type=int, default=1, | |||
| help='split the batch into smaller chunks') | |||
| parser.add_argument('--self_ensemble', action='store_true', | |||
| help='use self-ensemble method for test') | |||
| parser.add_argument('--test_only', action='store_true', | |||
| help='set this option to test the model') | |||
| parser.add_argument('--gan_k', type=int, default=1, | |||
| help='k value for adversarial loss') | |||
| # Optimization specifications | |||
| parser.add_argument('--lr', type=float, default=1e-4, | |||
| help='learning rate') | |||
| parser.add_argument('--decay', type=str, default='200', | |||
| help='learning rate decay type') | |||
| parser.add_argument('--gamma', type=float, default=0.5, | |||
| help='learning rate decay factor for step decay') | |||
| parser.add_argument('--optimizer', default='ADAM', | |||
| choices=('SGD', 'ADAM', 'RMSprop'), | |||
| help='optimizer to use (SGD | ADAM | RMSprop)') | |||
| parser.add_argument('--momentum', type=float, default=0.9, | |||
| help='SGD momentum') | |||
| parser.add_argument('--betas', type=tuple, default=(0.9, 0.999), | |||
| help='ADAM beta') | |||
| parser.add_argument('--epsilon', type=float, default=1e-8, | |||
| help='ADAM epsilon for numerical stability') | |||
| parser.add_argument('--weight_decay', type=float, default=0, | |||
| help='weight decay') | |||
| parser.add_argument('--gclip', type=float, default=0, | |||
| help='gradient clipping threshold (0 = no clipping)') | |||
| # Loss specifications | |||
| parser.add_argument('--loss', type=str, default='1*L1', | |||
| help='loss function configuration') | |||
| parser.add_argument('--skip_threshold', type=float, default='1e8', | |||
| help='skipping batch that has large error') | |||
| # Log specifications | |||
| parser.add_argument('--save', type=str, default='/cache/results/edsr_baseline_x2/', | |||
| help='file name to save') | |||
| parser.add_argument('--load', type=str, default='', | |||
| help='file name to load') | |||
| parser.add_argument('--resume', type=int, default=0, | |||
| help='resume from specific checkpoint') | |||
| parser.add_argument('--save_models', action='store_true', | |||
| help='save all intermediate models') | |||
| parser.add_argument('--print_every', type=int, default=100, | |||
| help='how many batches to wait before logging training status') | |||
| parser.add_argument('--save_results', action='store_true', | |||
| help='save output results') | |||
| parser.add_argument('--save_gt', action='store_true', | |||
| help='save low-resolution and high-resolution images together') | |||
| parser.add_argument('--scalelr', type=int, default=0) | |||
| # cloud | |||
| parser.add_argument('--moxfile', type=int, default=1) | |||
| parser.add_argument('--imagenet', type=int, default=0) | |||
| parser.add_argument('--data_url', type=str, help='path to dataset') | |||
| parser.add_argument('--train_url', type=str, help='train_dir') | |||
| parser.add_argument('--pretrain', type=str, default='') | |||
| parser.add_argument('--pth_path', type=str, default='') | |||
| parser.add_argument('--load_query', type=int, default=0) | |||
| # transformer | |||
| parser.add_argument('--patch_dim', type=int, default=3) | |||
| parser.add_argument('--num_heads', type=int, default=12) | |||
| parser.add_argument('--num_layers', type=int, default=12) | |||
| parser.add_argument('--dropout_rate', type=float, default=0) | |||
| parser.add_argument('--no_norm', action='store_true') | |||
| parser.add_argument('--post_norm', action='store_true') | |||
| parser.add_argument('--no_mlp', action='store_true') | |||
| parser.add_argument('--test', action='store_true') | |||
| parser.add_argument('--chop_new', action='store_true') | |||
| parser.add_argument('--pos_every', action='store_true') | |||
| parser.add_argument('--no_pos', action='store_true') | |||
| parser.add_argument('--num_queries', type=int, default=6) | |||
| parser.add_argument('--reweight', action='store_true') | |||
| # denoise | |||
| parser.add_argument('--denoise', action='store_true') | |||
| parser.add_argument('--sigma', type=float, default=25) | |||
| # derain | |||
| parser.add_argument('--derain', action='store_true') | |||
| parser.add_argument('--finetune', action='store_true') | |||
| parser.add_argument('--derain_test', type=int, default=10) | |||
| # alltask | |||
| parser.add_argument('--alltask', action='store_true') | |||
| # dehaze | |||
| parser.add_argument('--dehaze', action='store_true') | |||
| parser.add_argument('--dehaze_test', type=int, default=100) | |||
| parser.add_argument('--indoor', action='store_true') | |||
| parser.add_argument('--outdoor', action='store_true') | |||
| parser.add_argument('--nochange', action='store_true') | |||
| # deblur | |||
| parser.add_argument('--deblur', action='store_true') | |||
| parser.add_argument('--deblur_test', type=int, default=1000) | |||
| # distribute | |||
| parser.add_argument('--init_method', type=str, | |||
| default=None, help='master address') | |||
| parser.add_argument('--rank', type=int, default=0, | |||
| help='Index of current task') | |||
| parser.add_argument('--world_size', type=int, default=1, | |||
| help='Total number of tasks') | |||
| parser.add_argument('--gpu', default=None, type=int, | |||
| help='GPU id to use.') | |||
| parser.add_argument('--dist-url', default='', type=str, | |||
| help='url used to set up distributed training') | |||
| parser.add_argument('--dist-backend', default='nccl', type=str, | |||
| help='distributed backend') | |||
| parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', | |||
| help='number of data loading workers (default: 4)') | |||
| parser.add_argument('--distribute', action='store_true') | |||
| args, unparsed = parser.parse_known_args() | |||
| template.set_template(args) | |||
| args.scale = [int(x) for x in args.scale.split("+")] | |||
| args.data_train = args.data_train.split('+') | |||
| args.data_test = args.data_test.split('+') | |||
| if args.epochs == 0: | |||
| args.epochs = 1e8 | |||
| for arg in vars(args): | |||
| if vars(args)[arg] == 'True': | |||
| vars(args)[arg] = True | |||
| elif vars(args)[arg] == 'False': | |||
| vars(args)[arg] = False | |||
| @@ -0,0 +1,35 @@ | |||
| """data""" | |||
| # Copyright 2021 Huawei Technologies Co., Ltd" | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| from importlib import import_module | |||
| class Data: | |||
| """data""" | |||
| def __init__(self, args): | |||
| self.loader_train = None | |||
| self.loader_test = [] | |||
| for d in args.data_test: | |||
| if d in ['Set5', 'Set14', 'B100', 'Urban100', 'Manga109', 'CBSD68', 'Rain100L', 'GOPRO_Large']: | |||
| m = import_module('data.benchmark') | |||
| testset = getattr(m, 'Benchmark')(args, train=False, name=d) | |||
| else: | |||
| module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG' | |||
| m = import_module('data.' + module_name.lower()) | |||
| testset = getattr(m, module_name)(args, train=False, name=d) | |||
| self.loader_test.append( | |||
| testset | |||
| ) | |||
| @@ -0,0 +1,93 @@ | |||
| """common""" | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import random | |||
| import numpy as np | |||
| import skimage.color as sc | |||
| def get_patch(*args, patch_size=96, scale=2, multi=False, input_large=False): | |||
| """common""" | |||
| ih, iw = args[0].shape[:2] | |||
| tp = patch_size | |||
| ip = tp // scale | |||
| ix = random.randrange(0, iw - ip + 1) | |||
| iy = random.randrange(0, ih - ip + 1) | |||
| if not input_large: | |||
| tx, ty = scale * ix, scale * iy | |||
| else: | |||
| tx, ty = ix, iy | |||
| ret = [ | |||
| args[0][iy:iy + ip, ix:ix + ip, :], | |||
| *[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]] | |||
| ] | |||
| return ret | |||
| def set_channel(*args, n_channels=3): | |||
| """common""" | |||
| def _set_channel(img): | |||
| if img.ndim == 2: | |||
| img = np.expand_dims(img, axis=2) | |||
| c = img.shape[2] | |||
| if n_channels == 1 and c == 3: | |||
| img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2) | |||
| elif n_channels == 3 and c == 1: | |||
| img = np.concatenate([img] * n_channels, 2) | |||
| return img[:, :, :n_channels] | |||
| return [_set_channel(a) for a in args] | |||
| def np2Tensor(*args, rgb_range=255): | |||
| """common""" | |||
| def _np2Tensor(img): | |||
| np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1))) | |||
| tensor = np_transpose.astype(np.float32) | |||
| tensor = tensor * (rgb_range / 255) | |||
| # tensor = torch.from_numpy(np_transpose).float() | |||
| # tensor.mul_(rgb_range / 255) | |||
| return tensor | |||
| return [_np2Tensor(a) for a in args] | |||
| def augment(*args, hflip=True, rot=True): | |||
| """common""" | |||
| hflip = hflip and random.random() < 0.5 | |||
| vflip = rot and random.random() < 0.5 | |||
| rot90 = rot and random.random() < 0.5 | |||
| def _augment(img): | |||
| if hflip: | |||
| img = img[:, ::-1, :] | |||
| if vflip: | |||
| img = img[::-1, :, :] | |||
| if rot90: | |||
| img = img.transpose(1, 0, 2) | |||
| return img | |||
| return [_augment(a) for a in args] | |||
| @@ -0,0 +1,301 @@ | |||
| """srdata""" | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import os | |||
| import glob | |||
| import random | |||
| import pickle | |||
| from src.data import common | |||
| import numpy as np | |||
| import imageio | |||
| def search(root, target="JPEG"): | |||
| """srdata""" | |||
| item_list = [] | |||
| items = os.listdir(root) | |||
| for item in items: | |||
| path = os.path.join(root, item) | |||
| if os.path.isdir(path): | |||
| item_list.extend(search(path, target)) | |||
| elif path.split('/')[-1].startswith(target): | |||
| item_list.append(path) | |||
| elif target in (path.split('/')[-2], path.split('/')[-3], path.split('/')[-4]): | |||
| item_list.append(path) | |||
| else: | |||
| item_list = [] | |||
| return item_list | |||
| def search_dehaze(root, target="JPEG"): | |||
| """srdata""" | |||
| item_list = [] | |||
| items = os.listdir(root) | |||
| for item in items: | |||
| path = os.path.join(root, item) | |||
| if os.path.isdir(path): | |||
| extend_list = search_dehaze(path, target) | |||
| if extend_list is not None: | |||
| item_list.extend(extend_list) | |||
| elif path.split('/')[-2].endswith(target): | |||
| item_list.append(path) | |||
| return item_list | |||
| class SRData(): | |||
| """srdata""" | |||
| def __init__(self, args, name='', train=True, benchmark=False): | |||
| self.args = args | |||
| self.name = name | |||
| self.train = train | |||
| self.split = 'train' if train else 'test' | |||
| self.do_eval = True | |||
| self.benchmark = benchmark | |||
| self.input_large = (args.model == 'VDSR') | |||
| self.scale = args.scale | |||
| self.idx_scale = 0 | |||
| if self.args.derain: | |||
| self.derain_test = os.path.join(args.dir_data, "Rain100L") | |||
| self.derain_lr_test = search(self.derain_test, "rain") | |||
| self.derain_hr_test = [path.replace( | |||
| "rainy/", "no") for path in self.derain_lr_test] | |||
| self._set_filesystem(args.dir_data) | |||
| if args.ext.find('img') < 0: | |||
| path_bin = os.path.join(self.apath, 'bin') | |||
| os.makedirs(path_bin, exist_ok=True) | |||
| list_hr, list_lr = self._scan() | |||
| if args.ext.find('img') >= 0 or benchmark: | |||
| self.images_hr, self.images_lr = list_hr, list_lr | |||
| elif args.ext.find('sep') >= 0: | |||
| os.makedirs( | |||
| self.dir_hr.replace(self.apath, path_bin), | |||
| exist_ok=True | |||
| ) | |||
| for s in self.scale: | |||
| if s == 1: | |||
| os.makedirs( | |||
| os.path.join(self.dir_hr), | |||
| exist_ok=True | |||
| ) | |||
| else: | |||
| os.makedirs( | |||
| os.path.join( | |||
| self.dir_lr.replace(self.apath, path_bin), | |||
| 'X{}'.format(s) | |||
| ), | |||
| exist_ok=True | |||
| ) | |||
| self.images_hr, self.images_lr = [], [[] for _ in self.scale] | |||
| for h in list_hr: | |||
| b = h.replace(self.apath, path_bin) | |||
| b = b.replace(self.ext[0], '.pt') | |||
| self.images_hr.append(b) | |||
| self._check_and_load(args.ext, h, b, verbose=True) | |||
| for i, ll in enumerate(list_lr): | |||
| for l in ll: | |||
| b = l.replace(self.apath, path_bin) | |||
| b = b.replace(self.ext[1], '.pt') | |||
| self.images_lr[i].append(b) | |||
| self._check_and_load(args.ext, l, b, verbose=True) | |||
| # Below functions as used to prepare images | |||
| def _scan(self): | |||
| """srdata""" | |||
| names_hr = sorted( | |||
| glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0])) | |||
| ) | |||
| names_lr = [[] for _ in self.scale] | |||
| for f in names_hr: | |||
| filename, _ = os.path.splitext(os.path.basename(f)) | |||
| for si, s in enumerate(self.scale): | |||
| if s != 1: | |||
| scale = s | |||
| names_lr[si].append(os.path.join( | |||
| self.dir_lr, 'X{}/{}x{}{}'.format( | |||
| s, filename, scale, self.ext[1] | |||
| ) | |||
| )) | |||
| for si, s in enumerate(self.scale): | |||
| if s == 1: | |||
| names_lr[si] = names_hr | |||
| return names_hr, names_lr | |||
| def _set_filesystem(self, dir_data): | |||
| self.apath = os.path.join(dir_data, self.name[0]) | |||
| self.dir_hr = os.path.join(self.apath, 'HR') | |||
| self.dir_lr = os.path.join(self.apath, 'LR_bicubic') | |||
| self.ext = ('.png', '.png') | |||
| def _check_and_load(self, ext, img, f, verbose=True): | |||
| if not os.path.isfile(f) or ext.find('reset') >= 0: | |||
| if verbose: | |||
| print('Making a binary: {}'.format(f)) | |||
| with open(f, 'wb') as _f: | |||
| pickle.dump(imageio.imread(img), _f) | |||
| def __getitem__(self, idx): | |||
| if self.args.model == 'vtip' and self.args.derain and self.scale[ | |||
| self.idx_scale] == 1 and not self.args.finetune: | |||
| norain, rain, _ = self._load_rain_test(idx) | |||
| pair = common.set_channel( | |||
| *[rain, norain], n_channels=self.args.n_colors) | |||
| pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range) | |||
| return pair_t[0], pair_t[1] | |||
| if self.args.model == 'vtip' and self.args.denoise and self.scale[self.idx_scale] == 1: | |||
| hr, _ = self._load_file_hr(idx) | |||
| pair = self.get_patch_hr(hr) | |||
| pair = common.set_channel(*[pair], n_channels=self.args.n_colors) | |||
| pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range) | |||
| noise = np.random.randn(*pair_t[0].shape) * self.args.sigma | |||
| lr = pair_t[0] + noise | |||
| lr = np.float32(np.clip(lr, 0, 255)) | |||
| return lr, pair_t[0] | |||
| lr, hr, _ = self._load_file(idx) | |||
| pair = self.get_patch(lr, hr) | |||
| pair = common.set_channel(*pair, n_channels=self.args.n_colors) | |||
| pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range) | |||
| return pair_t[0], pair_t[1] | |||
| def __len__(self): | |||
| if self.train: | |||
| return len(self.images_hr) * self.repeat | |||
| if self.args.derain and not self.args.alltask: | |||
| return int(len(self.derain_hr_test) / self.args.derain_test) | |||
| return len(self.images_hr) | |||
| def _get_index(self, idx): | |||
| """srdata""" | |||
| if self.train: | |||
| return idx % len(self.images_hr) | |||
| return idx | |||
| def _load_file_hr(self, idx): | |||
| """srdata""" | |||
| idx = self._get_index(idx) | |||
| f_hr = self.images_hr[idx] | |||
| filename, _ = os.path.splitext(os.path.basename(f_hr)) | |||
| if self.args.ext == 'img' or self.benchmark: | |||
| hr = imageio.imread(f_hr) | |||
| elif self.args.ext.find('sep') >= 0: | |||
| with open(f_hr, 'rb') as _f: | |||
| hr = pickle.load(_f) | |||
| return hr, filename | |||
| def _load_rain(self, idx, rain_img=False): | |||
| """srdata""" | |||
| idx = random.randint(0, len(self.derain_img_list) - 1) | |||
| f_lr = self.derain_img_list[idx] | |||
| if rain_img: | |||
| norain = imageio.imread(f_lr.replace("rainstreak", "norain")) | |||
| rain = imageio.imread(f_lr.replace("rainstreak", "rain")) | |||
| return norain, rain | |||
| lr = imageio.imread(f_lr) | |||
| return lr | |||
| def _load_rain_test(self, idx): | |||
| """srdata""" | |||
| f_hr = self.derain_hr_test[idx] | |||
| f_lr = self.derain_lr_test[idx] | |||
| filename, _ = os.path.splitext(os.path.basename(f_lr)) | |||
| norain = imageio.imread(f_hr) | |||
| rain = imageio.imread(f_lr) | |||
| return norain, rain, filename | |||
| def _load_denoise(self, idx): | |||
| """srdata""" | |||
| idx = self._get_index(idx) | |||
| f_lr = self.images_hr[idx] | |||
| norain = imageio.imread(f_lr) | |||
| rain = imageio.imread(f_lr.replace("HR", "LR_bicubic")) | |||
| return norain, rain | |||
| def _load_file(self, idx): | |||
| """srdata""" | |||
| idx = self._get_index(idx) | |||
| f_hr = self.images_hr[idx] | |||
| f_lr = self.images_lr[self.idx_scale][idx] | |||
| filename, _ = os.path.splitext(os.path.basename(f_hr)) | |||
| if self.args.ext == 'img' or self.benchmark: | |||
| hr = imageio.imread(f_hr) | |||
| lr = imageio.imread(f_lr) | |||
| elif self.args.ext.find('sep') >= 0: | |||
| with open(f_hr, 'rb') as _f: | |||
| hr = pickle.load(_f) | |||
| with open(f_lr, 'rb') as _f: | |||
| lr = pickle.load(_f) | |||
| return lr, hr, filename | |||
| def get_patch_hr(self, hr): | |||
| """srdata""" | |||
| if self.train: | |||
| hr = self.get_patch_img_hr( | |||
| hr, | |||
| patch_size=self.args.patch_size, | |||
| scale=1 | |||
| ) | |||
| return hr | |||
| def get_patch_img_hr(self, img, patch_size=96, scale=2): | |||
| """srdata""" | |||
| ih, iw = img.shape[:2] | |||
| tp = patch_size | |||
| ip = tp // scale | |||
| ix = random.randrange(0, iw - ip + 1) | |||
| iy = random.randrange(0, ih - ip + 1) | |||
| ret = img[iy:iy + ip, ix:ix + ip, :] | |||
| return ret | |||
| def get_patch(self, lr, hr): | |||
| """srdata""" | |||
| scale = self.scale[self.idx_scale] | |||
| if self.train: | |||
| lr, hr = common.get_patch( | |||
| lr, hr, | |||
| patch_size=self.args.patch_size * scale, | |||
| scale=scale, | |||
| multi=(len(self.scale) > 1) | |||
| ) | |||
| if not self.args.no_augment: | |||
| lr, hr = common.augment(lr, hr) | |||
| else: | |||
| ih, iw = lr.shape[:2] | |||
| hr = hr[0:ih * scale, 0:iw * scale] | |||
| return lr, hr | |||
| def set_scale(self, idx_scale): | |||
| """srdata""" | |||
| if not self.input_large: | |||
| self.idx_scale = idx_scale | |||
| else: | |||
| self.idx_scale = random.randint(0, len(self.scale) - 1) | |||
| @@ -0,0 +1,241 @@ | |||
| '''stride''' | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| from mindspore import nn | |||
| from mindspore.ops import operations as P | |||
| from mindspore.common import dtype as mstype | |||
| class _stride_unfold_(nn.Cell): | |||
| '''stride''' | |||
| def __init__(self, | |||
| kernel_size, | |||
| stride=-1): | |||
| super(_stride_unfold_, self).__init__() | |||
| if stride == -1: | |||
| self.stride = kernel_size | |||
| else: | |||
| self.stride = stride | |||
| self.kernel_size = kernel_size | |||
| self.reshape = P.Reshape() | |||
| self.transpose = P.Transpose() | |||
| self.unfold = _unfold_(kernel_size) | |||
| def construct(self, x): | |||
| """stride""" | |||
| N, C, H, W = x.shape | |||
| leftup_idx_x = [] | |||
| leftup_idx_y = [] | |||
| nh = int(H / self.stride) | |||
| nw = int(W / self.stride) | |||
| for i in range(nh): | |||
| leftup_idx_x.append(i * self.stride) | |||
| for i in range(nw): | |||
| leftup_idx_y.append(i * self.stride) | |||
| NumBlock_x = len(leftup_idx_x) | |||
| NumBlock_y = len(leftup_idx_y) | |||
| zeroslike = P.ZerosLike() | |||
| cc_2 = P.Concat(axis=2) | |||
| cc_3 = P.Concat(axis=3) | |||
| unf_x = P.Zeros()((N, C, NumBlock_x * self.kernel_size, | |||
| NumBlock_y * self.kernel_size), mstype.float32) | |||
| N, C, H, W = unf_x.shape | |||
| for i in range(NumBlock_x): | |||
| for j in range(NumBlock_y): | |||
| unf_i = i * self.kernel_size | |||
| unf_j = j * self.kernel_size | |||
| org_i = leftup_idx_x[i] | |||
| org_j = leftup_idx_y[j] | |||
| fills = x[:, :, org_i:org_i + self.kernel_size, | |||
| org_j:org_j + self.kernel_size] | |||
| unf_x += cc_3((cc_3((zeroslike(unf_x[:, :, :, :unf_j]), cc_2((cc_2( | |||
| (zeroslike(unf_x[:, :, :unf_i, unf_j:unf_j + self.kernel_size]), fills)), zeroslike( | |||
| unf_x[:, :, unf_i + self.kernel_size:, unf_j:unf_j + self.kernel_size]))))), | |||
| zeroslike(unf_x[:, :, :, unf_j + self.kernel_size:]))) | |||
| y = self.unfold(unf_x) | |||
| return y | |||
| class _stride_fold_(nn.Cell): | |||
| '''stride''' | |||
| def __init__(self, | |||
| kernel_size, | |||
| output_shape=(-1, -1), | |||
| stride=-1): | |||
| super(_stride_fold_, self).__init__() | |||
| if isinstance(kernel_size, (list, tuple)): | |||
| self.kernel_size = kernel_size | |||
| else: | |||
| self.kernel_size = [kernel_size, kernel_size] | |||
| if stride == -1: | |||
| self.stride = kernel_size[0] | |||
| else: | |||
| self.stride = stride | |||
| self.output_shape = output_shape | |||
| self.reshape = P.Reshape() | |||
| self.transpose = P.Transpose() | |||
| self.fold = _fold_(kernel_size) | |||
| def construct(self, x): | |||
| '''stride''' | |||
| if self.output_shape[0] == -1: | |||
| large_x = self.fold(x) | |||
| N, C, H, _ = large_x.shape | |||
| leftup_idx = [] | |||
| for i in range(0, H, self.kernel_size[0]): | |||
| leftup_idx.append(i) | |||
| NumBlock = len(leftup_idx) | |||
| fold_x = P.Zeros()((N, C, (NumBlock - 1) * self.stride + self.kernel_size[0], | |||
| (NumBlock - 1) * self.stride + self.kernel_size[0]), mstype.float32) | |||
| for i in range(NumBlock): | |||
| for j in range(NumBlock): | |||
| fold_i = i * self.stride | |||
| fold_j = j * self.stride | |||
| org_i = leftup_idx[i] | |||
| org_j = leftup_idx[j] | |||
| fills = x[:, :, org_i:org_i + self.kernel_size[0], | |||
| org_j:org_j + self.kernel_size[1]] | |||
| fold_x += cc_3((cc_3((zeroslike(fold_x[:, :, :, :fold_j]), cc_2((cc_2( | |||
| (zeroslike(fold_x[:, :, :fold_i, fold_j:fold_j + self.kernel_size[1]]), fills)), zeroslike( | |||
| fold_x[:, :, fold_i + self.kernel_size[0]:, fold_j:fold_j + self.kernel_size[1]]))))), | |||
| zeroslike(fold_x[:, :, :, fold_j + self.kernel_size[1]:]))) | |||
| y = fold_x | |||
| else: | |||
| NumBlock_x = int( | |||
| (self.output_shape[0] - self.kernel_size[0]) / self.stride + 1) | |||
| NumBlock_y = int( | |||
| (self.output_shape[1] - self.kernel_size[1]) / self.stride + 1) | |||
| large_shape = [NumBlock_x * self.kernel_size[0], | |||
| NumBlock_y * self.kernel_size[1]] | |||
| self.fold = _fold_(self.kernel_size, large_shape) | |||
| large_x = self.fold(x) | |||
| N, C, H, _ = large_x.shape | |||
| leftup_idx_x = [] | |||
| leftup_idx_y = [] | |||
| for i in range(NumBlock_x): | |||
| leftup_idx_x.append(i * self.kernel_size[0]) | |||
| for i in range(NumBlock_y): | |||
| leftup_idx_y.append(i * self.kernel_size[1]) | |||
| fold_x = P.Zeros()((N, C, (NumBlock_x - 1) * self.stride + self.kernel_size[0], | |||
| (NumBlock_y - 1) * self.stride + self.kernel_size[1]), mstype.float32) | |||
| for i in range(NumBlock_x): | |||
| for j in range(NumBlock_y): | |||
| fold_i = i * self.stride | |||
| fold_j = j * self.stride | |||
| org_i = leftup_idx_x[i] | |||
| org_j = leftup_idx_y[j] | |||
| fills = x[:, :, org_i:org_i + self.kernel_size[0], | |||
| org_j:org_j + self.kernel_size[1]] | |||
| fold_x += cc_3((cc_3((zeroslike(fold_x[:, :, :, :fold_j]), cc_2((cc_2( | |||
| (zeroslike(fold_x[:, :, :fold_i, fold_j:fold_j + self.kernel_size[1]]), fills)), zeroslike( | |||
| fold_x[:, :, fold_i + self.kernel_size[0]:, fold_j:fold_j + self.kernel_size[1]]))))), | |||
| zeroslike(fold_x[:, :, :, fold_j + self.kernel_size[1]:]))) | |||
| y = fold_x | |||
| return y | |||
| class _unfold_(nn.Cell): | |||
| '''stride''' | |||
| def __init__(self, | |||
| kernel_size, | |||
| stride=-1): | |||
| super(_unfold_, self).__init__() | |||
| if stride == -1: | |||
| self.stride = kernel_size | |||
| self.kernel_size = kernel_size | |||
| self.reshape = P.Reshape() | |||
| self.transpose = P.Transpose() | |||
| def construct(self, x): | |||
| '''stride''' | |||
| N, C, H, W = x.shape | |||
| numH = int(H / self.kernel_size) | |||
| numW = int(W / self.kernel_size) | |||
| if numH * self.kernel_size != H or numW * self.kernel_size != W: | |||
| x = x[:, :, :numH * self.kernel_size, :, numW * self.kernel_size] | |||
| output_img = self.reshape(x, (N, C, numH, self.kernel_size, W)) | |||
| output_img = self.transpose(output_img, (0, 1, 2, 4, 3)) | |||
| output_img = self.reshape(output_img, (N, C, int( | |||
| numH * numW), self.kernel_size, self.kernel_size)) | |||
| output_img = self.transpose(output_img, (0, 2, 1, 4, 3)) | |||
| output_img = self.reshape(output_img, (N, int(numH * numW), -1)) | |||
| return output_img | |||
| class _fold_(nn.Cell): | |||
| '''stride''' | |||
| def __init__(self, | |||
| kernel_size, | |||
| output_shape=(-1, -1), | |||
| stride=-1): | |||
| super(_fold_, self).__init__() | |||
| if isinstance(kernel_size, (list, tuple)): | |||
| self.kernel_size = kernel_size | |||
| else: | |||
| self.kernel_size = [kernel_size, kernel_size] | |||
| if stride == -1: | |||
| self.stride = kernel_size[0] | |||
| self.output_shape = output_shape | |||
| self.reshape = P.Reshape() | |||
| self.transpose = P.Transpose() | |||
| def construct(self, x): | |||
| '''stride''' | |||
| N, C, L = x.shape | |||
| org_C = int(L / self.kernel_size[0] / self.kernel_size[1]) | |||
| if self.output_shape[0] == -1: | |||
| numH = int(np.sqrt(C)) | |||
| numW = int(np.sqrt(C)) | |||
| org_H = int(numH * self.kernel_size[0]) | |||
| org_W = org_H | |||
| else: | |||
| org_H = int(self.output_shape[0]) | |||
| org_W = int(self.output_shape[1]) | |||
| numH = int(org_H / self.kernel_size[0]) | |||
| numW = int(org_W / self.kernel_size[1]) | |||
| output_img = self.reshape( | |||
| x, (N, C, org_C, self.kernel_size[0], self.kernel_size[1])) | |||
| output_img = self.transpose(output_img, (0, 2, 1, 3, 4)) | |||
| output_img = self.reshape( | |||
| output_img, (N, org_C, numH, numW, self.kernel_size[0], self.kernel_size[1])) | |||
| output_img = self.transpose(output_img, (0, 1, 2, 4, 3, 5)) | |||
| output_img = self.reshape(output_img, (N, org_C, org_H, org_W)) | |||
| return output_img | |||
| @@ -0,0 +1,56 @@ | |||
| '''metrics''' | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import math | |||
| import numpy as np | |||
| def quantize(img, rgb_range): | |||
| '''metrics''' | |||
| pixel_range = 255 / rgb_range | |||
| img = np.multiply(img, pixel_range) | |||
| img = np.clip(img, 0, 255) | |||
| img = np.round(img) / pixel_range | |||
| return img | |||
| def calc_psnr(sr, hr, scale, rgb_range, y_only=False, dataset=None): | |||
| '''metrics''' | |||
| hr = np.float32(hr) | |||
| sr = np.float32(sr) | |||
| diff = (sr - hr) / rgb_range | |||
| gray_coeffs = np.array([65.738, 129.057, 25.064] | |||
| ).reshape((1, 3, 1, 1)) / 256 | |||
| diff = np.multiply(diff, gray_coeffs).sum(1) | |||
| if hr.size == 1: | |||
| return 0 | |||
| if scale != 1: | |||
| shave = scale | |||
| else: | |||
| shave = scale + 6 | |||
| if scale == 1: | |||
| valid = diff | |||
| else: | |||
| valid = diff[..., shave:-shave, shave:-shave] | |||
| mse = np.mean(pow(valid, 2)) | |||
| return -10 * math.log10(mse) | |||
| def rgb2ycbcr(img, y_only=True): | |||
| '''metrics''' | |||
| img.astype(np.float32) | |||
| if y_only: | |||
| rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0 | |||
| return rlt | |||
| @@ -0,0 +1,67 @@ | |||
| '''temp''' | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| def set_template(args): | |||
| '''temp''' | |||
| if args.template.find('jpeg') >= 0: | |||
| args.data_train = 'DIV2K_jpeg' | |||
| args.data_test = 'DIV2K_jpeg' | |||
| args.epochs = 200 | |||
| args.decay = '100' | |||
| if args.template.find('EDSR_paper') >= 0: | |||
| args.model = 'EDSR' | |||
| args.n_resblocks = 32 | |||
| args.n_feats = 256 | |||
| args.res_scale = 0.1 | |||
| if args.template.find('MDSR') >= 0: | |||
| args.model = 'MDSR' | |||
| args.patch_size = 48 | |||
| args.epochs = 650 | |||
| if args.template.find('DDBPN') >= 0: | |||
| args.model = 'DDBPN' | |||
| args.patch_size = 128 | |||
| args.scale = '4' | |||
| args.data_test = 'Set5' | |||
| args.batch_size = 20 | |||
| args.epochs = 1000 | |||
| args.decay = '500' | |||
| args.gamma = 0.1 | |||
| args.weight_decay = 1e-4 | |||
| args.loss = '1*MSE' | |||
| if args.template.find('GAN') >= 0: | |||
| args.epochs = 200 | |||
| args.lr = 5e-5 | |||
| args.decay = '150' | |||
| if args.template.find('RCAN') >= 0: | |||
| args.model = 'RCAN' | |||
| args.n_resgroups = 10 | |||
| args.n_resblocks = 20 | |||
| args.n_feats = 64 | |||
| args.chop = True | |||
| if args.template.find('VDSR') >= 0: | |||
| args.model = 'VDSR' | |||
| args.n_resblocks = 20 | |||
| args.n_feats = 64 | |||
| args.patch_size = 41 | |||
| args.lr = 1e-1 | |||