From: @Somnus2020 Reviewed-by: @oacjiewen,@c_34,@linqingke Signed-off-by: @linqingkepull/15439/MERGE
| @@ -1,6 +1,6 @@ | |||
| """eval script""" | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| @@ -13,48 +13,82 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import os | |||
| import numpy as np | |||
| from src import ipt | |||
| import mindspore.dataset as ds | |||
| from mindspore import Tensor, context | |||
| from mindspore.common import dtype as mstype | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from src.args import args | |||
| import src.ipt_model as ipt | |||
| from src.data.srdata import SRData | |||
| from src.metrics import calc_psnr, quantize | |||
| from mindspore import context | |||
| import mindspore.dataset as de | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| device_id = int(os.getenv('DEVICE_ID', '0')) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id, save_graphs=False) | |||
| context.set_context(max_call_depth=10000) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="ASCEND", device_id=0) | |||
def sub_mean(x):
    """Subtract the fixed per-channel RGB mean from ``x`` in place.

    ``x`` is indexed as ``x[:, channel, :, :]``, i.e. a 4-D array whose second
    axis holds the red, green and blue channels in that order. The array is
    modified in place and the same object is returned.
    """
    # Red, green and blue means, scaled to the 0-255 range.
    channel_means = (0.4488 * 255, 0.4371 * 255, 0.4040 * 255)
    for channel, mean in enumerate(channel_means):
        x[:, channel, :, :] -= mean
    return x
def add_mean(x):
    """Add the fixed per-channel RGB mean back onto ``x`` in place.

    ``x`` is indexed as ``x[:, channel, :, :]``, i.e. a 4-D array whose second
    axis holds the red, green and blue channels in that order. The array is
    modified in place and the same object is returned.
    """
    # Red, green and blue means, scaled to the 0-255 range.
    channel_means = (0.4488 * 255, 0.4371 * 255, 0.4040 * 255)
    for channel, mean in enumerate(channel_means):
        x[:, channel, :, :] += mean
    return x
| def main(): | |||
| def eval_net(): | |||
| """eval""" | |||
| args.batch_size = 128 | |||
| args.decay = 70 | |||
| args.patch_size = 48 | |||
| args.num_queries = 6 | |||
| args.model = 'vtip' | |||
| args.num_layers = 4 | |||
| if args.epochs == 0: | |||
| args.epochs = 1e8 | |||
| for arg in vars(args): | |||
| if vars(args)[arg] == 'True': | |||
| vars(args)[arg] = True | |||
| elif vars(args)[arg] == 'False': | |||
| vars(args)[arg] = False | |||
| train_dataset = SRData(args, name=args.data_test, train=False, benchmark=False) | |||
| train_de_dataset = de.GeneratorDataset(train_dataset, ['LR', "HR"], shuffle=False) | |||
| train_de_dataset = ds.GeneratorDataset(train_dataset, ['LR', 'HR', "idx", "filename"], shuffle=False) | |||
| train_de_dataset = train_de_dataset.batch(1, drop_remainder=True) | |||
| train_loader = train_de_dataset.create_dict_iterator() | |||
| train_loader = train_de_dataset.create_dict_iterator(output_numpy=True) | |||
| net_m = ipt.IPT(args) | |||
| print('load mindspore net successfully.') | |||
| if args.pth_path: | |||
| param_dict = load_checkpoint(args.pth_path) | |||
| load_param_into_net(net_m, param_dict) | |||
| net_m.set_train(False) | |||
| idx = Tensor(np.ones(args.task_id), mstype.int32) | |||
| inference = ipt.IPT_post(net_m, args) | |||
| print('load mindspore net successfully.') | |||
| num_imgs = train_de_dataset.get_dataset_size() | |||
| psnrs = np.zeros((num_imgs, 1)) | |||
| inference = ipt.IPT_post(net_m, args) | |||
| for batch_idx, imgs in enumerate(train_loader): | |||
| lr = imgs['LR'] | |||
| hr = imgs['HR'] | |||
| hr_np = np.float32(hr.asnumpy()) | |||
| pred = inference.forward(lr) | |||
| pred_np = np.float32(pred.asnumpy()) | |||
| lr = sub_mean(lr) | |||
| lr = Tensor(lr, mstype.float32) | |||
| pred = inference.forward(lr, idx) | |||
| pred_np = add_mean(pred.asnumpy()) | |||
| pred_np = quantize(pred_np, 255) | |||
| psnr = calc_psnr(pred_np, hr_np, 4, 255.0, y_only=True) | |||
| psnr = calc_psnr(pred_np, hr, 4, 255.0) | |||
| print("current psnr: ", psnr) | |||
| psnrs[batch_idx, 0] = psnr | |||
| if args.denoise: | |||
| print('Mean psnr of %s DN_%s is %.4f' % (args.data_test[0], args.sigma, psnrs.mean(axis=0)[0])) | |||
| @@ -63,7 +97,6 @@ def main(): | |||
| else: | |||
| print('Mean psnr of %s x%s is %.4f' % (args.data_test[0], args.scale[0], psnrs.mean(axis=0)[0])) | |||
| if __name__ == '__main__': | |||
| print("Start main function!") | |||
| main() | |||
| print("Start eval function!") | |||
| eval_net() | |||
| @@ -1,26 +0,0 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """hub config.""" | |||
| from src.vitm import ViT | |||
def IPT(*args, **kwargs):
    """Construct the IPT network by forwarding every argument to ``ViT``."""
    network = ViT(*args, **kwargs)
    return network
def create_network(name, *args, **kwargs):
    """Build the network registered under ``name``.

    Only ``'IPT'`` is available; the remaining positional and keyword
    arguments are forwarded to its constructor.

    Raises:
        NotImplementedError: if ``name`` is anything other than ``'IPT'``.
    """
    if name != 'IPT':
        raise NotImplementedError(f"{name} is not implemented in the repo")
    return IPT(*args, **kwargs)
| @@ -45,9 +45,9 @@ The result images are converted into YCbCr color space. The PSNR is evaluated on | |||
| ## Requirements | |||
| ### Hardware (GPU) | |||
| ### Hardware (Ascend) | |||
| > Prepare hardware environment with GPU. | |||
| > Prepare hardware environment with Ascend. | |||
| ### Framework | |||
| @@ -67,34 +67,73 @@ The result images are converted into YCbCr color space. The PSNR is evaluated on | |||
| ```bash | |||
| IPT | |||
| ├── eval.py # inference entry | |||
| ├── train.py # pre-training entry | |||
| ├── train_finetune.py # fine-tuning entry | |||
| ├── image | |||
| │ └── ipt.png # the illustration of IPT network | |||
| ├── model | |||
| │ ├── IPT_denoise30.ckpt # denoise model weights for noise level 30 | |||
| │ ├── IPT_denoise50.ckpt # denoise model weights for noise level 50 | |||
| │ ├── IPT_derain.ckpt # derain model weights | |||
| │ ├── IPT_sr2.ckpt # X2 super-resolution model weights | |||
| │ ├── IPT_sr3.ckpt # X3 super-resolution model weights | |||
| │ └── IPT_sr4.ckpt # X4 super-resolution model weights | |||
| ├── readme.md # Readme | |||
| ├── scripts | |||
| │ └── run_eval.sh # inference script for all tasks | |||
| │ ├── run_eval.sh # inference script for all tasks | |||
| │ ├── run_distributed.sh # pre-training script for all tasks | |||
| │ └── run_finetune_distributed.sh # fine-tuning script for all tasks | |||
| └── src | |||
| ├── args.py # options/hyper-parameters of IPT | |||
| ├── data | |||
| │ ├── common.py # common dataset | |||
| │ ├── __init__.py # Class data init function | |||
| │ └── srdata.py # flow of loading sr data | |||
| ├── foldunfold_stride.py # function of fold and unfold operations for images | |||
| │ ├── bicubic.py # scripts for data pre-processing | |||
| │ ├── div2k.py # DIV2K dataset | |||
| │ ├── imagenet.py # Imagenet data for pre-training | |||
| │ └── srdata.py # All dataset | |||
| ├── metrics.py # PSNR calculator | |||
| ├── template.py # setting of model selection | |||
| └── vitm.py # IPT network | |||
| ├── utils.py # training scripts | |||
| ├── loss.py # contrastive_loss | |||
| └── ipt_model.py # IPT network | |||
| ``` | |||
| ### Script Parameter | |||
| > For details about hyperparameters, see src/args.py. | |||
| ## Training Process | |||
| ### For pre-training | |||
| ```bash | |||
| python train.py --distribute --imagenet 1 --batch_size 64 --lr 5e-5 --scale 2+3+4+1+1+1 --alltask --react --model vtip --num_queries 6 --chop_new --num_layers 4 --data_train imagenet --dir_data $DATA_PATH --derain --save $SAVE_PATH | |||
| ``` | |||
| > Alternatively, run the following script for all tasks. | |||
| ```bash | |||
| sh scripts/run_distributed.sh RANK_TABLE_FILE DATA_PATH | |||
| ``` | |||
| ### For fine-tuning | |||
| > For SR tasks: | |||
| ```bash | |||
| python train_finetune.py --distribute --imagenet 0 --batch_size 64 --lr 2e-5 --scale 2+3+4+1+1+1 --model vtip --num_queries 6 --chop_new --num_layers 4 --task_id $TASK_ID --dir_data $DATA_PATH --pth_path $MODEL --epochs 50 | |||
| ``` | |||
| > For Denoising tasks: | |||
| ```bash | |||
| python train_finetune.py --distribute --imagenet 0 --batch_size 64 --lr 2e-5 --scale 2+3+4+1+1+1 --model vtip --num_queries 6 --chop_new --num_layers 4 --task_id $TASK_ID --dir_data $DATA_PATH --pth_path $MODEL --denoise --sigma $Noise --epochs 50 | |||
| ``` | |||
| > For deraining tasks: | |||
| ```bash | |||
| python train_finetune.py --distribute --imagenet 0 --batch_size 64 --lr 2e-5 --scale 2+3+4+1+1+1 --model vtip --num_queries 6 --chop_new --num_layers 4 --task_id $TASK_ID --dir_data $DATA_PATH --pth_path $MODEL --derain --epochs 50 | |||
| ``` | |||
| > Alternatively, run the following script for all tasks. | |||
| ```bash | |||
| sh scripts/run_finetune_distributed.sh RANK_TABLE_FILE DATA_PATH MODEL TASK_ID | |||
| ``` | |||
| ## Evaluation | |||
| ### Evaluation Process | |||
| @@ -103,13 +142,13 @@ IPT | |||
| > For SR x4: | |||
| ```bash | |||
| python eval.py --dir_data ../../data/ --data_test Set14 --nochange --test_only --ext img --chop_new --scale 4 --pth_path ./model/IPT_sr4.ckpt | |||
| python eval.py --dir_data $DATA_PATH --data_test $DATA_TEST --test_only --ext img --pth_path $MODEL --task_id $TASK_ID --scale $SCALE | |||
| ``` | |||
| > Alternatively, run the following script for all tasks. | |||
| ```bash | |||
| sh scripts/run_eval.sh | |||
| sh scripts/run_eval.sh DATA_PATH DATA_TEST MODEL TASK_ID | |||
| ``` | |||
| ### Evaluation Result | |||
| @@ -117,7 +156,7 @@ sh scripts/run_eval.sh | |||
| The results are evaluated by PSNR (Peak Signal-to-Noise Ratio) and are reported in the following format. | |||
| ```bash | |||
| result: {"Mean psnr of Se5 x4 is 32.68"} | |||
| result: {"Mean psnr of Set5 x4 is 32.68"} | |||
| ``` | |||
| ## Performance | |||
| @@ -0,0 +1,40 @@ | |||
| #!/bin/bash | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
# Launch distributed pre-training: one background train.py process per device.
# Usage: sh run_distributed.sh RANK_TABLE_FILE DATA_PATH
# The script copies ../src and ../*.py, so it must be started from the
# directory that contains it.
if [ $# -lt 2 ]; then
    echo "Usage: sh run_distributed.sh RANK_TABLE_FILE DATA_PATH"
    exit 1
fi

ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
# Quote the argument so rank-table paths containing spaces survive.
RANK_TABLE_FILE=$(realpath "$1")
export RANK_TABLE_FILE
export DATA_PATH="$2"
echo "RANK_TABLE_FILE=${RANK_TABLE_FILE}"

export SERVER_ID=0
rank_start=$((DEVICE_NUM * SERVER_ID))
for((i=0; i<${DEVICE_NUM}; i++))
do
    export DEVICE_ID=$i
    export RANK_ID=$((rank_start + i))
    # Each rank runs inside its own fresh copy of the sources.
    rm -rf "./train_parallel$i"
    mkdir "./train_parallel$i"
    cp -r ../src "./train_parallel$i"
    cp ../*.py "./train_parallel$i"
    echo "start training for rank $RANK_ID, device $DEVICE_ID"
    cd "./train_parallel$i" || exit
    env > "env$i.log"
    python train.py --distribute --imagenet 1 --batch_size 64 --lr 5e-5 --scale 2+3+4+1+1+1 --alltask --react --model vtip --num_queries 6 --chop_new --num_layers 4 --data_train imagenet --dir_data "$DATA_PATH" --derain --save experiments/ckpt_new_init/ > log 2>&1 &
    cd ..
done
| @@ -14,18 +14,49 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| export DEVICE_ID=$1 | |||
| DATA_DIR=$2 | |||
| DATA_SET=$3 | |||
| PATH_CHECKPOINT=$4 | |||
| ulimit -u unlimited | |||
| export DATA_PATH=$1 | |||
| export DATA_TEST=$2 | |||
| export MODEL=$3 | |||
| export TASK_ID=$4 | |||
| python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 4 --pth_path=$PATH_CHECKPOINT > eval.log 2>&1 & | |||
| python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 3 --pth_path=$PATH_CHECKPOINT > eval.log 2>&1 & | |||
| python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 2 --pth_path=$PATH_CHECKPOINT > eval.log 2>&1 & | |||
| if [[ $TASK_ID -lt 3 ]]; then | |||
| mkdir ./run_eval$TASK_ID | |||
| cp -r ../src ./run_eval$TASK_ID | |||
| cp ../*.py ./run_eval$TASK_ID | |||
| echo "start evaluation for Task $TASK_ID, device $DEVICE_ID" | |||
| cd ./run_eval$TASK_ID ||exit | |||
| env > env$TASK_ID.log | |||
| SCALE=$[$TASK_ID+2] | |||
| python eval.py --dir_data $DATA_PATH --data_test $DATA_TEST --test_only --ext img --pth_path $MODEL --task_id $TASK_ID --scale $SCALE > log 2>&1 & | |||
| fi | |||
| ##denoise | |||
| python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 1 --denoise --sigma 30 --pth_path=$PATH_CHECKPOINT > eval.log 2>&1 & | |||
| python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 1 --denoise --sigma 50 --pth_path=$PATH_CHECKPOINT > eval.log 2>&1 & | |||
| if [[ $TASK_ID -eq 3 ]]; then | |||
| mkdir ./run_eval$TASK_ID | |||
| cp -r ../src ./run_eval$TASK_ID | |||
| cp ../*.py ./run_eval$TASK_ID | |||
| echo "start evaluation for Task $TASK_ID, device $DEVICE_ID" | |||
| cd ./run_eval$TASK_ID ||exit | |||
| env > env$TASK_ID.log | |||
| python eval.py --dir_data $DATA_PATH --data_test $DATA_TEST --test_only --ext img --pth_path $MODEL --task_id $TASK_ID --scale 1 --derain > log 2>&1 & | |||
| fi | |||
| ##derain | |||
| python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 1 --derain --derain_test 1 --pth_path=$PATH_CHECKPOINT > eval.log 2>&1 & | |||
| if [[ $TASK_ID -eq 4 ]]; then | |||
| mkdir ./run_eval$TASK_ID | |||
| cp -r ../src ./run_eval$TASK_ID | |||
| cp ../*.py ./run_eval$TASK_ID | |||
| echo "start evaluation for Task $TASK_ID, device $DEVICE_ID" | |||
| cd ./run_eval$TASK_ID ||exit | |||
| env > env$TASK_ID.log | |||
| python eval.py --dir_data $DATA_PATH --data_test $DATA_TEST --test_only --ext img --pth_path $MODEL --task_id $TASK_ID --scale 1 --denoise --sigma 30 > log 2>&1 & | |||
| fi | |||
| if [[ $TASK_ID -eq 5 ]]; then | |||
| mkdir ./run_eval$TASK_ID | |||
| cp -r ../src ./run_eval$TASK_ID | |||
| cp ../*.py ./run_eval$TASK_ID | |||
| echo "start evaluation for Task $TASK_ID, device $DEVICE_ID" | |||
| cd ./run_eval$TASK_ID ||exit | |||
| env > env$TASK_ID.log | |||
| python eval.py --dir_data $DATA_PATH --data_test $DATA_TEST --test_only --ext img --pth_path $MODEL --task_id $TASK_ID --scale 1 --denoise --sigma 50 > log 2>&1 & | |||
| fi | |||
| @@ -0,0 +1,43 @@ | |||
| #!/bin/bash | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
# Launch distributed fine-tuning: one background train_finetune.py process per device.
# Usage: sh run_finetune_distributed.sh RANK_TABLE_FILE DATA_PATH MODEL TASK_ID
# The script copies ../src and ../*.py, so it must be started from the
# directory that contains it.
if [ $# -lt 4 ]; then
    echo "Usage: sh run_finetune_distributed.sh RANK_TABLE_FILE DATA_PATH MODEL TASK_ID"
    exit 1
fi

ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
# Quote the argument so rank-table paths containing spaces survive.
RANK_TABLE_FILE=$(realpath "$1")
export RANK_TABLE_FILE
export DATA_PATH="$2"
export MODEL="$3"
export TASK_ID="$4"
echo "RANK_TABLE_FILE=${RANK_TABLE_FILE}"

export SERVER_ID=0
rank_start=$((DEVICE_NUM * SERVER_ID))
for((i=0; i<${DEVICE_NUM}; i++))
do
    export DEVICE_ID=$i
    export RANK_ID=$((rank_start + i))
    # Each rank runs inside its own fresh copy of the sources.
    rm -rf "./train_parallel$i"
    mkdir "./train_parallel$i"
    cp -r ../src "./train_parallel$i"
    cp ../*.py "./train_parallel$i"
    echo "start training for rank $RANK_ID, device $DEVICE_ID"
    cd "./train_parallel$i" || exit
    env > "env$i.log"
    python train_finetune.py --distribute --imagenet 0 --batch_size 64 --lr 2e-5 --scale 2+3+4+1+1+1 --model vtip --num_queries 6 --chop_new --num_layers 4 --task_id "$TASK_ID" --dir_data "$DATA_PATH" --pth_path "$MODEL" --epochs 100 > log 2>&1 &
    cd ..
done
| @@ -1,4 +1,4 @@ | |||
| '''args''' | |||
| """args""" | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| @@ -13,8 +13,8 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import argparse | |||
| from src import template | |||
| parser = argparse.ArgumentParser(description='EDSR and MDSR') | |||
| @@ -24,12 +24,6 @@ parser.add_argument('--template', default='.', | |||
| help='You can set various templates in option.py') | |||
| # Hardware specifications | |||
| parser.add_argument('--n_threads', type=int, default=6, | |||
| help='number of threads for data loading') | |||
| parser.add_argument('--cpu', action='store_true', | |||
| help='use cpu only') | |||
| parser.add_argument('--n_GPUs', type=int, default=1, | |||
| help='number of GPUs') | |||
| parser.add_argument('--seed', type=int, default=1, | |||
| help='random seed') | |||
| @@ -60,9 +54,8 @@ parser.add_argument('--no_augment', action='store_true', | |||
| help='do not use data augmentation') | |||
| # Model specifications | |||
| parser.add_argument('--model', default='vtip', | |||
| parser.add_argument('--model', default='EDSR', | |||
| help='model name') | |||
| parser.add_argument('--act', type=str, default='relu', | |||
| help='activation function') | |||
| parser.add_argument('--pre_train', type=str, default='', | |||
| @@ -139,6 +132,7 @@ parser.add_argument('--gclip', type=float, default=0, | |||
| help='gradient clipping threshold (0 = no clipping)') | |||
| # Loss specifications | |||
| parser.add_argument('--con_loss', action='store_true') | |||
| parser.add_argument('--loss', type=str, default='1*L1', | |||
| help='loss function configuration') | |||
| parser.add_argument('--skip_threshold', type=float, default='1e8', | |||
| @@ -161,6 +155,7 @@ parser.add_argument('--save_gt', action='store_true', | |||
| help='save low-resolution and high-resolution images together') | |||
| parser.add_argument('--scalelr', type=int, default=0) | |||
| # cloud | |||
| parser.add_argument('--moxfile', type=int, default=1) | |||
| parser.add_argument('--imagenet', type=int, default=0) | |||
| @@ -169,10 +164,11 @@ parser.add_argument('--train_url', type=str, help='train_dir') | |||
| parser.add_argument('--pretrain', type=str, default='') | |||
| parser.add_argument('--pth_path', type=str, default='') | |||
| parser.add_argument('--load_query', type=int, default=0) | |||
| # transformer | |||
| parser.add_argument('--patch_dim', type=int, default=3) | |||
| parser.add_argument('--num_heads', type=int, default=12) | |||
| parser.add_argument('--num_layers', type=int, default=12) | |||
| parser.add_argument('--num_layers', type=int, default=4) | |||
| parser.add_argument('--dropout_rate', type=float, default=0) | |||
| parser.add_argument('--no_norm', action='store_true') | |||
| parser.add_argument('--post_norm', action='store_true') | |||
| @@ -192,8 +188,10 @@ parser.add_argument('--sigma', type=float, default=25) | |||
| parser.add_argument('--derain', action='store_true') | |||
| parser.add_argument('--finetune', action='store_true') | |||
| parser.add_argument('--derain_test', type=int, default=10) | |||
| # alltask | |||
| parser.add_argument('--alltask', action='store_true') | |||
| parser.add_argument('--task_id', type=int, default=0) | |||
| # dehaze | |||
| parser.add_argument('--dehaze', action='store_true') | |||
| @@ -201,6 +199,7 @@ parser.add_argument('--dehaze_test', type=int, default=100) | |||
| parser.add_argument('--indoor', action='store_true') | |||
| parser.add_argument('--outdoor', action='store_true') | |||
| parser.add_argument('--nochange', action='store_true') | |||
| # deblur | |||
| parser.add_argument('--deblur', action='store_true') | |||
| parser.add_argument('--deblur_test', type=int, default=1000) | |||
| @@ -210,6 +209,8 @@ parser.add_argument('--init_method', type=str, | |||
| default=None, help='master address') | |||
| parser.add_argument('--rank', type=int, default=0, | |||
| help='Index of current task') | |||
| parser.add_argument('--group_size', type=int, default=0, | |||
| help='Index of current task') | |||
| parser.add_argument('--world_size', type=int, default=1, | |||
| help='Total number of tasks') | |||
| parser.add_argument('--gpu', default=None, type=int, | |||
| @@ -223,7 +224,6 @@ parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', | |||
| parser.add_argument('--distribute', action='store_true') | |||
| args, unparsed = parser.parse_known_args() | |||
| template.set_template(args) | |||
| args.scale = [int(x) for x in args.scale.split("+")] | |||
| args.data_train = args.data_train.split('+') | |||
| @@ -1,35 +0,0 @@ | |||
| """data""" | |||
| # Copyright 2021 Huawei Technologies Co., Ltd" | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| from importlib import import_module | |||
class Data:
    """Instantiate one test dataset for every name in ``args.data_test``.

    The created datasets are collected, in order, in ``loader_test``;
    ``loader_train`` is always left as ``None``.
    """

    def __init__(self, args):
        benchmark_names = ['Set5', 'Set14', 'B100', 'Urban100', 'Manga109', 'CBSD68', 'Rain100L', 'GOPRO_Large']
        self.loader_train = None
        self.loader_test = []
        for dataset_name in args.data_test:
            if dataset_name in benchmark_names:
                # All benchmark sets share one loader class.
                module = import_module('data.benchmark')
                class_name = 'Benchmark'
            else:
                # Any name containing 'DIV2K-Q' is served by DIV2KJPEG; every
                # other name maps to a module named after its lower-cased form.
                class_name = 'DIV2KJPEG' if 'DIV2K-Q' in dataset_name else dataset_name
                module = import_module('data.' + class_name.lower())
            dataset_cls = getattr(module, class_name)
            self.loader_test.append(dataset_cls(args, train=False, name=dataset_name))
| @@ -0,0 +1,132 @@ | |||
| """bicubic""" | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
class bicubic:
    """Randomly degrade a batch: bicubic downscaling, rain overlay or Gaussian noise.

    ``forward`` draws a task index in ``[0, 6)`` from a seeded ``RandomState``
    and builds the degraded ("lr") batch for that task:

    * 0, 1, 2 -- bicubic downscaling of ``lrx2`` / ``lrx3`` / ``lrx4`` by
      1/2, 1/3 and 1/4 respectively;
    * 3 -- ``rain`` added to ``hr``;
    * 4, 5 -- Gaussian noise with standard deviation 30 / 50 added to ``hr``.
    """

    def __init__(self, seed=0):
        self.seed = seed
        # Dedicated generator so the task selection is reproducible per seed.
        self.rand_fn = np.random.RandomState(self.seed)

    def cubic(self, x):
        """Evaluate the piecewise cubic interpolation kernel element-wise.

        The kernel is non-zero only for ``|x| <= 2``; the two polynomial
        pieces cover ``|x| <= 1`` and ``1 < |x| <= 2``.
        """
        absx = np.abs(x)
        absx2 = absx * absx
        absx3 = absx * absx * absx

        condition1 = (absx <= 1).astype(np.float32)
        condition2 = ((absx > 1) & (absx <= 2)).astype(np.float32)

        inner = (1.5 * absx3 - 2.5 * absx2 + 1) * condition1
        outer = (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * condition2
        return inner + outer

    def contribute(self, in_size, out_size, scale):
        """Compute resampling weights and 1-based source indices for both axes.

        Args:
            in_size: ``[height, width]`` of the source image.
            out_size: ``[height, width]`` of the resized image.
            scale: resize factor; values below 1 widen the kernel.

        Returns:
            ``(weight0, weight1, indice0, indice1)`` where the ``0`` entries
            belong to the first (height) axis and the ``1`` entries to the
            second (width) axis. Each array has a leading axis of length 1.
        """
        kernel_width = 4
        if scale < 1:
            kernel_width = 4 / scale

        # Output coordinates (1-based) mapped back into input space.
        x0 = np.arange(start=1, stop=out_size[0]+1).astype(np.float32)
        x1 = np.arange(start=1, stop=out_size[1]+1).astype(np.float32)
        u0 = x0 / scale + 0.5 * (1 - 1 / scale)
        u1 = x1 / scale + 0.5 * (1 - 1 / scale)

        left0 = np.floor(u0 - kernel_width / 2)
        left1 = np.floor(u1 - kernel_width / 2)

        width = np.ceil(kernel_width) + 2
        offsets = np.expand_dims(np.arange(start=0, stop=width).astype(np.float32), axis=0)
        indice0 = np.expand_dims(left0, axis=1) + offsets
        indice1 = np.expand_dims(left1, axis=1) + offsets

        mid0 = np.expand_dims(u0, axis=1) - np.expand_dims(indice0, axis=0)
        mid1 = np.expand_dims(u1, axis=1) - np.expand_dims(indice1, axis=0)

        if scale < 1:
            weight0 = scale * self.cubic(mid0 * scale)
            weight1 = scale * self.cubic(mid1 * scale)
        else:
            weight0 = self.cubic(mid0)
            weight1 = self.cubic(mid1)

        # Normalise so the weights of every output sample sum to one.
        weight0 = weight0 / (np.expand_dims(np.sum(weight0, axis=2), 2))
        weight1 = weight1 / (np.expand_dims(np.sum(weight1, axis=2), 2))

        # Clamp indices to the valid 1-based range of the input.
        indice0 = np.expand_dims(np.minimum(np.maximum(1, indice0), in_size[0]), axis=0)
        indice1 = np.expand_dims(np.minimum(np.maximum(1, indice1), in_size[1]), axis=0)

        # Drop the taps whose weight is zero for the first output sample.
        kill0 = np.equal(weight0, 0)[0][0]
        kill1 = np.equal(weight1, 0)[0][0]

        weight0 = weight0[:, :, kill0 == 0]
        weight1 = weight1[:, :, kill1 == 0]

        indice0 = indice0[:, :, kill0 == 0]
        indice1 = indice1[:, :, kill1 == 0]

        return weight0, weight1, indice0, indice1

    def forward(self, hr, rain, lrx2, lrx3, lrx4, filename, batchInfo):
        """Build one randomly chosen degradation task for a batch.

        Args:
            hr: batch of clean images (sequence of equally shaped arrays).
            rain: batch of rain layers added to ``hr`` for task 3.
            lrx2, lrx3, lrx4: batches used as the target and as the source of
                the bicubic downscaling for tasks 0, 1 and 2.
            filename: passed through unchanged.
            batchInfo: unused by this method.

        Returns:
            ``(lr, hr, idx_list, filename)`` where ``lr`` and ``hr`` are lists
            of per-sample arrays and ``idx_list`` repeats the task index once
            per sample.
        """
        idx = self.rand_fn.randint(0, 6)
        if idx < 3:
            if idx == 0:
                scale = 1/2
                hr = lrx2
            elif idx == 1:
                scale = 1/3
                hr = lrx3
            elif idx == 2:
                scale = 1/4
                hr = lrx4
            hr = np.array(hr)
            [_, _, h, w] = hr.shape
            weight0, weight1, indice0, indice1 = self.contribute([h, w], [int(h * scale), int(w * scale)], scale)

            # np.long no longer exists in NumPy 1.24+; np.int64 is a valid
            # index dtype on every NumPy version and platform.
            weight0 = np.asarray(weight0[0], dtype=np.float32)
            indice0 = np.asarray(indice0[0], dtype=np.float32).astype(np.int64)
            weight0 = np.expand_dims(np.expand_dims(np.expand_dims(weight0, axis=0), axis=1), axis=4)
            # First pass: resample along axis 2 (indices are 1-based, hence -1).
            out = hr[:, :, (indice0-1), :] * weight0
            out = np.sum(out, axis=3)
            A = np.transpose(out, (0, 1, 3, 2))

            weight1 = np.asarray(weight1[0], dtype=np.float32)
            weight1 = np.expand_dims(np.expand_dims(np.expand_dims(weight1, axis=0), axis=1), axis=4)
            indice1 = np.asarray(indice1[0], dtype=np.float32).astype(np.int64)
            # Second pass: resample the remaining spatial axis.
            out = A[:, :, (indice1-1), :] * weight1
            # NOTE(review): the result is rounded twice (on a 1/255 grid, then
            # to integers); kept exactly as in the original -- confirm intent.
            out = np.round(255 * np.transpose(np.sum(out, axis=3), (0, 1, 3, 2)))/255
            out = np.clip(np.round(out), 0, 255)
            lr = list(out)
            hr = list(hr)
        else:
            hr = np.array(hr)
            if idx == 3:
                rain = np.array(rain)
                lr = np.clip((rain + hr), 0, 255)
            else:
                # NOTE(review): the noise comes from the global NumPy RNG, not
                # from self.rand_fn, so ``seed`` does not make it reproducible.
                sigma = 30 if idx == 4 else 50
                noise = np.random.randn(*hr.shape) * sigma
                lr = np.clip(noise + hr, 0, 255)
            hr = list(hr)
            lr = list(lr)
        return lr, hr, [idx] * len(hr), filename
| @@ -1,6 +1,6 @@ | |||
| """common""" | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| @@ -13,13 +13,11 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import random | |||
| import random | |||
| import numpy as np | |||
| import skimage.color as sc | |||
| def get_patch(*args, patch_size=96, scale=2, multi=False, input_large=False): | |||
| def get_patch(*args, patch_size=96, scale=2, input_large=False): | |||
| """common""" | |||
| ih, iw = args[0].shape[:2] | |||
| @@ -34,25 +32,19 @@ def get_patch(*args, patch_size=96, scale=2, multi=False, input_large=False): | |||
| else: | |||
| tx, ty = ix, iy | |||
| ret = [ | |||
| args[0][iy:iy + ip, ix:ix + ip, :], | |||
| *[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]] | |||
| ] | |||
| ret = [args[0][iy:iy + ip, ix:ix + ip, :], *[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]]] | |||
| return ret | |||
| def set_channel(*args, n_channels=3): | |||
| """common""" | |||
| def _set_channel(img): | |||
| if img.ndim == 2: | |||
| img = np.expand_dims(img, axis=2) | |||
| c = img.shape[2] | |||
| if n_channels == 1 and c == 3: | |||
| img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2) | |||
| elif n_channels == 3 and c == 1: | |||
| if n_channels == 3 and c == 1: | |||
| img = np.concatenate([img] * n_channels, 2) | |||
| return img[:, :, :n_channels] | |||
| @@ -61,14 +53,11 @@ def set_channel(*args, n_channels=3): | |||
| def np2Tensor(*args, rgb_range=255): | |||
| """common""" | |||
| def _np2Tensor(img): | |||
| np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1))) | |||
| tensor = np_transpose.astype(np.float32) | |||
| tensor = tensor * (rgb_range / 255) | |||
| return tensor | |||
| input_data = np_transpose.astype(np.float32) | |||
| output = input_data * (rgb_range / 255) | |||
| return output | |||
| return [_np2Tensor(a) for a in args] | |||
| @@ -79,6 +68,7 @@ def augment(*args, hflip=True, rot=True): | |||
| rot90 = rot and random.random() < 0.5 | |||
| def _augment(img): | |||
| """common""" | |||
| if hflip: | |||
| img = img[:, ::-1, :] | |||
| if vflip: | |||
| @@ -88,3 +78,18 @@ def augment(*args, hflip=True, rot=True): | |||
| return img | |||
| return [_augment(a) for a in args] | |||
def search(root, target="JPEG"):
    """Recursively collect file paths under ``root`` that match ``target``.

    A file is kept when its file name starts with ``target`` or when one of
    its three closest parent directory components (split on ``'/'``) equals
    ``target``.

    Args:
        root: directory to walk.
        target: file-name prefix or directory name to look for.

    Returns:
        List of matching paths, each built by joining ``root`` with the
        entries found below it.
    """
    # Imported here so the function does not depend on a module-level import.
    import os

    item_list = []
    for item in os.listdir(root):
        path = os.path.join(root, item)
        if os.path.isdir(path):
            item_list.extend(search(path, target))
            continue
        parts = path.split('/')
        # The slice covers parts[-4], parts[-3] and parts[-2] when they exist,
        # so short paths no longer raise IndexError.
        if parts[-1].startswith(target) or target in parts[-4:-1]:
            item_list.append(path)
    return item_list
| @@ -0,0 +1,45 @@ | |||
| """div2k""" | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import os | |||
| from src.data.srdata import SRData | |||
class DIV2K(SRData):
    """DIV2K dataset limited to the index range given by ``args.data_range``.

    ``args.data_range`` is a string of ``begin-end`` pairs separated by ``/``;
    the first pair is used for training and the second for evaluation.
    """

    def __init__(self, args, name='DIV2K', train=True, benchmark=False):
        ranges = [chunk.split('-') for chunk in args.data_range.split('/')]
        # Evaluation falls back to the first pair only when it is the sole
        # pair and the run is test-only.
        if train or (args.test_only and len(ranges) == 1):
            chosen = ranges[0]
        else:
            chosen = ranges[1]
        self.begin, self.end = [int(bound) for bound in chosen]
        super().__init__(args, name=name, train=train, benchmark=benchmark)

    def _scan(self):
        """Return the parent's HR/LR file lists cut down to [begin, end]."""
        all_hr, all_lr = super()._scan()
        window = slice(self.begin - 1, self.end)
        return all_hr[window], [per_scale[window] for per_scale in all_lr]

    def _set_filesystem(self, dir_data):
        """Point the HR/LR directories at the DIV2K training folders."""
        super()._set_filesystem(dir_data)
        self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
        self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic')
| @@ -0,0 +1,171 @@ | |||
| """imagenet""" | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import os | |||
| import random | |||
| import io | |||
| from PIL import Image | |||
| import numpy as np | |||
| import imageio | |||
def search(root, target="JPEG"):
    """Walk ``root`` recursively and list the files that match ``target``.

    A file is kept when the text after its last ``.`` equals ``target`` or
    when its base name starts with ``target``.
    """
    found = []
    for entry in os.listdir(root):
        full_path = os.path.join(root, entry)
        if os.path.isdir(full_path):
            found += search(full_path, target)
            continue
        extension = full_path.split('.')[-1]
        base_name = full_path.split('/')[-1]
        if extension == target or base_name.startswith(target):
            found.append(full_path)
    return found
def get_patch_img(img, patch_size=96, scale=2):
    """Return a square patch of side ``scale * patch_size`` taken from ``img``.

    When the image is at least as large as the patch in both directions a
    random crop is returned. When only one direction is large enough, a random
    strip is cut along it and resized to the patch size. When neither is, the
    whole image is resized.

    Args:
        img (numpy.ndarray): Image indexed as (height, width, channel).
        patch_size (int): Base patch side length.
        scale (int): Multiplier applied to ``patch_size``.

    Returns:
        numpy.ndarray: The patch. Cropped results keep at most the first three
        channels; the whole-image fallback keeps whatever ``img`` carries.
    """
    ih, iw = img.shape[:2]
    tp = scale * patch_size
    fits_w = (iw - tp) > -1
    # The original tested ``(ih - tp) > 1`` here, which sent images whose
    # height is exactly tp or tp + 1 to the whole-image resize fallback.
    fits_h = (ih - tp) > -1
    if fits_w and fits_h:
        ix = random.randrange(0, iw - tp + 1)
        iy = random.randrange(0, ih - tp + 1)
        hr = img[iy:iy + tp, ix:ix + tp, :3]
    elif fits_w:
        # Too short: crop horizontally, then stretch to a square.
        ix = random.randrange(0, iw - tp + 1)
        hr = img[:, ix:ix + tp, :3]
        pil_img = Image.fromarray(hr).resize((tp, tp), Image.BILINEAR)
        hr = np.array(pil_img)
    elif fits_h:
        # Too narrow: crop vertically, then stretch to a square.
        iy = random.randrange(0, ih - tp + 1)
        hr = img[iy:iy + tp, :, :3]
        pil_img = Image.fromarray(hr).resize((tp, tp), Image.BILINEAR)
        hr = np.array(pil_img)
    else:
        pil_img = Image.fromarray(img).resize((tp, tp), Image.BILINEAR)
        hr = np.array(pil_img)
    return hr
class ImgData():
    """Image dataset that yields random training patches.

    Files are collected from ``<args.dir_data>/train`` and
    ``<args.dir_data>/val`` with ``search(..., "JPEG")``. When ``args.derain``
    is set, rain-streak images are additionally collected from
    ``<args.dir_data>/RainTrainL``.

    ``__getitem__`` returns tuples of different length depending on the task
    flags in ``args``; see that method for the individual layouts.
    """

    def __init__(self, args, train=True):
        self.input_large = (args.model == 'VDSR')
        self.scale = args.scale
        self.idx_scale = 0
        self.dataroot = args.dir_data
        self.img_list = search(os.path.join(self.dataroot, "train"), "JPEG")
        self.img_list.extend(search(os.path.join(self.dataroot, "val"), "JPEG"))
        self.img_list = sorted(self.img_list)
        self.train = train
        self.args = args
        self.len = len(self.img_list)
        print("data length:", len(self.img_list))
        if self.args.derain:
            self.derain_dataroot = os.path.join(self.dataroot, "RainTrainL")
            self.derain_img_list = search(self.derain_dataroot, "rainstreak")

    def __len__(self):
        return len(self.img_list)

    def _get_index(self, idx):
        """Wrap ``idx`` into the valid range of ``img_list``."""
        return idx % len(self.img_list)

    def _load_file(self, idx):
        """Read image ``idx`` and return ``(array, path)``.

        Two-dimensional (single-channel) images are stacked into three
        identical channels.
        """
        idx = self._get_index(idx)
        f_lr = self.img_list[idx]
        lr = imageio.imread(f_lr)
        if len(lr.shape) == 2:
            lr = np.dstack([lr, lr, lr])
        return lr, f_lr

    def _np2Tensor(self, img, rgb_range):
        """Convert an HWC image to a float32 CHW array scaled by rgb_range / 255."""
        np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))
        tensor = np_transpose.astype(np.float32)
        tensor = tensor * (rgb_range / 255)
        return tensor

    def __getitem__(self, idx):
        """Return the training sample for ``idx``.

        Layouts, checked in this order:
            * 'vtip' + train + alltask: patches for scale indices 3, then a
              rain patch, then scale indices 0, 1, 2, then ``[self.scale]``
              and ``[filename]``.
            * 'vtip' + train + several scales: patches for scale indices
              0, 1, 2 and the filename.
            * 'vtip' + derain at scale 1: ``(patch, rain, filename)``.
            * jpeg: ``(compressed_patch, clean_patch, filename)``.
            * otherwise: ``(patch, filename)``.
        """
        if self.args.model == 'vtip' and self.train and self.args.alltask:
            lr, filename = self._load_file(idx % self.len)
            pair_list = []
            rain = self._load_rain()
            rain = np.expand_dims(rain, axis=2)
            rain = self.get_patch(rain, 1)
            rain = self._np2Tensor(rain, rgb_range=self.args.rgb_range)
            for idx_scale in range(4):
                # get_patch reads self.idx_scale, so it is advanced in place.
                self.idx_scale = idx_scale
                pair = self.get_patch(lr)
                pair_t = self._np2Tensor(pair, rgb_range=self.args.rgb_range)
                pair_list.append(pair_t)
            return pair_list[3], rain, pair_list[0], pair_list[1], pair_list[2], [self.scale], [filename]
        if self.args.model == 'vtip' and self.train and len(self.scale) > 1:
            lr, filename = self._load_file(idx % self.len)
            pair_list = []
            for idx_scale in range(3):
                self.idx_scale = idx_scale
                pair = self.get_patch(lr)
                pair_t = self._np2Tensor(pair, rgb_range=self.args.rgb_range)
                pair_list.append(pair_t)
            return pair_list[0], pair_list[1], pair_list[2], filename
        if self.args.model == 'vtip' and self.args.derain and self.scale[self.idx_scale] == 1:
            lr, filename = self._load_file(idx % self.len)
            rain = self._load_rain()
            rain = np.expand_dims(rain, axis=2)
            rain = self.get_patch(rain, 1)
            rain = self._np2Tensor(rain, rgb_range=self.args.rgb_range)
            pair = self.get_patch(lr)
            pair_t = self._np2Tensor(pair, rgb_range=self.args.rgb_range)
            return pair_t, rain, filename
        if self.args.jpeg:
            hr, filename = self._load_file(idx % self.len)
            # _load_file returns a numpy array, but the size/resize/save calls
            # below are PIL.Image methods. JPEG cannot store an alpha channel,
            # hence the conversion to RGB.
            hr = Image.fromarray(hr).convert('RGB')
            buffer = io.BytesIO()
            width, height = hr.size
            patch_size = self.scale[self.idx_scale] * self.args.patch_size
            # Image.LANCZOS is the filter formerly exposed as Image.ANTIALIAS,
            # which newer Pillow releases no longer provide.
            if width < patch_size:
                hr = hr.resize((patch_size, height), Image.LANCZOS)
                width, height = hr.size
            if height < patch_size:
                hr = hr.resize((width, patch_size), Image.LANCZOS)
            hr.save(buffer, format='jpeg', quality=25)
            # Keep both arrays uint8 until _np2Tensor (which casts to float32)
            # so that get_patch can still hand them to PIL if it must resize.
            lr = np.array(Image.open(buffer))
            hr = np.array(hr)
            # Replay the same random draws for both images so the compressed
            # and clean patches cover the same region.
            rng_state = random.getstate()
            lr = self.get_patch(lr)
            random.setstate(rng_state)
            hr = self.get_patch(hr)
            lr = self._np2Tensor(lr, rgb_range=self.args.rgb_range)
            hr = self._np2Tensor(hr, rgb_range=self.args.rgb_range)
            return lr, hr, filename
        lr, filename = self._load_file(idx % self.len)
        pair = self.get_patch(lr)
        pair_t = self._np2Tensor(pair, rgb_range=self.args.rgb_range)
        return pair_t, filename

    def _load_rain(self):
        """Read a randomly chosen rain-streak image (requires ``args.derain``)."""
        idx = random.randint(0, len(self.derain_img_list) - 1)
        f_lr = self.derain_img_list[idx]
        lr = imageio.imread(f_lr)
        return lr

    def get_patch(self, lr, scale=0):
        """Cut a patch from ``lr``; ``scale=0`` means "use the current scale"."""
        if scale == 0:
            scale = self.scale[self.idx_scale]
        lr = get_patch_img(lr, patch_size=self.args.patch_size, scale=scale)
        return lr

    def set_scale(self, idx_scale):
        """Select which entry of ``self.scale`` subsequent patches use."""
        self.idx_scale = idx_scale
| @@ -1,6 +1,6 @@ | |||
| """srdata""" | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| @@ -13,6 +13,7 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import os | |||
| import glob | |||
| import random | |||
| @@ -20,43 +21,12 @@ import pickle | |||
| import numpy as np | |||
| import imageio | |||
| from src.data import common | |||
| from PIL import ImageFile | |||
| ImageFile.LOAD_TRUNCATED_IMAGES = True | |||
def search(root, target="JPEG"):
    """Recursively collect file paths under ``root`` that match ``target``.

    A file matches when its base name starts with ``target`` or when one of
    its (up to) three closest parent path components equals ``target``.
    Files that do not match are skipped; they no longer discard the results
    gathered so far.

    Args:
        root (str): Directory to walk.
        target (str): File-name prefix or directory name to look for.

    Returns:
        list[str]: Matching paths in ``os.listdir`` order (not sorted).
    """
    item_list = []
    for item in os.listdir(root):
        path = os.path.join(root, item)
        if os.path.isdir(path):
            item_list.extend(search(path, target))
            continue
        parts = path.split('/')
        # Slicing the parent components avoids the IndexError that indexing
        # [-3]/[-4] raised for paths with fewer than four components.
        if parts[-1].startswith(target) or target in parts[-4:-1]:
            item_list.append(path)
    return item_list
| def search_dehaze(root, target="JPEG"): | |||
| class SRData: | |||
| """srdata""" | |||
| item_list = [] | |||
| items = os.listdir(root) | |||
| for item in items: | |||
| path = os.path.join(root, item) | |||
| if os.path.isdir(path): | |||
| extend_list = search_dehaze(path, target) | |||
| if extend_list is not None: | |||
| item_list.extend(extend_list) | |||
| elif path.split('/')[-2].endswith(target): | |||
| item_list.append(path) | |||
| return item_list | |||
| class SRData(): | |||
| """srdata""" | |||
| def __init__(self, args, name='', train=True, benchmark=False): | |||
| self.args = args | |||
| self.name = name | |||
| @@ -69,37 +39,46 @@ class SRData(): | |||
| self.idx_scale = 0 | |||
| if self.args.derain: | |||
| self.derain_test = os.path.join(args.dir_data, "Rain100L") | |||
| self.derain_lr_test = search(self.derain_test, "rain") | |||
| self.derain_hr_test = [path.replace( | |||
| "rainy/", "no") for path in self.derain_lr_test] | |||
| if self.train: | |||
| self.derain_dataroot = os.path.join(args.dir_data, "RainTrainL") | |||
| self.clear_train = common.search(self.derain_dataroot, "norain") | |||
| self.rain_train = [] | |||
| for path in self.clear_train: | |||
| change_path = path.split('/') | |||
| change_path[-1] = change_path[-1][2:] | |||
| self.rain_train.append('/'.join(change_path)) | |||
| self.derain_test = os.path.join(args.dir_data, "Rain100L") | |||
| self.deblur_lr_test = common.search(self.derain_test, "rain") | |||
| self.deblur_hr_test = [path.replace("rainy/", "no") for path in self.deblur_lr_test] | |||
| self.derain_hr_test = self.deblur_hr_test | |||
| else: | |||
| self.derain_test = os.path.join(args.dir_data, "Rain100L") | |||
| self.derain_lr_test = common.search(self.derain_test, "rain") | |||
| self.derain_hr_test = [path.replace("rainy/", "no") for path in self.derain_lr_test] | |||
| self._set_filesystem(args.dir_data) | |||
| self._set_img(args) | |||
| if self.args.derain and self.train: | |||
| self.images_hr, self.images_lr = self.clear_train, self.rain_train | |||
| if train: | |||
| self._repeat(args) | |||
| def _set_img(self, args): | |||
| """srdata""" | |||
| if args.ext.find('img') < 0: | |||
| path_bin = os.path.join(self.apath, 'bin') | |||
| os.makedirs(path_bin, exist_ok=True) | |||
| list_hr, list_lr = self._scan() | |||
| if args.ext.find('img') >= 0 or benchmark: | |||
| if args.ext.find('img') >= 0 or self.benchmark: | |||
| self.images_hr, self.images_lr = list_hr, list_lr | |||
| elif args.ext.find('sep') >= 0: | |||
| os.makedirs( | |||
| self.dir_hr.replace(self.apath, path_bin), | |||
| exist_ok=True | |||
| ) | |||
| os.makedirs(self.dir_hr.replace(self.apath, path_bin), exist_ok=True) | |||
| for s in self.scale: | |||
| if s == 1: | |||
| os.makedirs( | |||
| os.path.join(self.dir_hr), | |||
| exist_ok=True | |||
| ) | |||
| os.makedirs(os.path.join(self.dir_hr), exist_ok=True) | |||
| else: | |||
| os.makedirs( | |||
| os.path.join( | |||
| self.dir_lr.replace(self.apath, path_bin), | |||
| 'X{}'.format(s) | |||
| ), | |||
| exist_ok=True | |||
| ) | |||
| os.path.join(self.dir_lr.replace(self.apath, path_bin), 'X{}'.format(s)), exist_ok=True) | |||
| self.images_hr, self.images_lr = [], [[] for _ in self.scale] | |||
| for h in list_hr: | |||
| @@ -114,23 +93,27 @@ class SRData(): | |||
| self.images_lr[i].append(b) | |||
| self._check_and_load(args.ext, l, b, verbose=True) | |||
| # The functions below are used to prepare images | |||
| def _repeat(self, args): | |||
| """srdata""" | |||
| n_patches = args.batch_size * args.test_every | |||
| n_images = len(args.data_train) * len(self.images_hr) | |||
| if n_images == 0: | |||
| self.repeat = 0 | |||
| else: | |||
| self.repeat = max(n_patches // n_images, 1) | |||
| def _scan(self): | |||
| """srdata""" | |||
| names_hr = sorted( | |||
| glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0])) | |||
| ) | |||
| glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0]))) | |||
| names_lr = [[] for _ in self.scale] | |||
| for f in names_hr: | |||
| filename, _ = os.path.splitext(os.path.basename(f)) | |||
| for si, s in enumerate(self.scale): | |||
| if s != 1: | |||
| scale = s | |||
| names_lr[si].append(os.path.join( | |||
| self.dir_lr, 'X{}/{}x{}{}'.format( | |||
| s, filename, scale, self.ext[1] | |||
| ) | |||
| )) | |||
| names_lr[si].append(os.path.join(self.dir_lr, 'X{}/{}x{}{}' \ | |||
| .format(s, filename, scale, self.ext[1]))) | |||
| for si, s in enumerate(self.scale): | |||
| if s == 1: | |||
| names_lr[si] = names_hr | |||
| @@ -150,28 +133,33 @@ class SRData(): | |||
| pickle.dump(imageio.imread(img), _f) | |||
| def __getitem__(self, idx): | |||
| if self.args.model == 'vtip' and self.args.derain and self.scale[ | |||
| self.idx_scale] == 1 and not self.args.finetune: | |||
| norain, rain, _ = self._load_rain_test(idx) | |||
| pair = common.set_channel( | |||
| *[rain, norain], n_channels=self.args.n_colors) | |||
| pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range) | |||
| return pair_t[0], pair_t[1] | |||
| if self.args.model == 'vtip' and self.args.denoise and self.scale[self.idx_scale] == 1: | |||
| hr, _ = self._load_file_hr(idx) | |||
| if self.args.derain and self.scale[self.idx_scale] == 1: | |||
| if self.train: | |||
| lr, hr, filename = self._load_file_deblur(idx) | |||
| pair = self.get_patch(lr, hr) | |||
| pair = common.set_channel(*pair, n_channels=self.args.n_colors) | |||
| pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range) | |||
| else: | |||
| norain, rain, filename = self._load_rain_test(idx) | |||
| pair = common.set_channel(*[rain, norain], n_channels=self.args.n_colors) | |||
| pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range) | |||
| return pair_t[0], pair_t[1], [self.idx_scale], [filename] | |||
| if self.args.denoise and self.scale[self.idx_scale] == 1: | |||
| hr, filename = self._load_file_hr(idx) | |||
| pair = self.get_patch_hr(hr) | |||
| pair = common.set_channel(*[pair], n_channels=self.args.n_colors) | |||
| pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range) | |||
| noise = np.random.randn(*pair_t[0].shape) * self.args.sigma | |||
| lr = pair_t[0] + noise | |||
| lr = np.float32(np.clip(lr, 0, 255)) | |||
| return lr, pair_t[0] | |||
| lr, hr, _ = self._load_file(idx) | |||
| return lr, pair_t[0], [self.idx_scale], [filename] | |||
| lr, hr, filename = self._load_file(idx) | |||
| pair = self.get_patch(lr, hr) | |||
| pair = common.set_channel(*pair, n_channels=self.args.n_colors) | |||
| pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range) | |||
| return pair_t[0], pair_t[1] | |||
| return pair_t[0], pair_t[1], [self.idx_scale], [filename] | |||
| def __len__(self): | |||
| if self.train: | |||
| @@ -182,7 +170,6 @@ class SRData(): | |||
| return len(self.images_hr) | |||
| def _get_index(self, idx): | |||
| """srdata""" | |||
| if self.train: | |||
| return idx % len(self.images_hr) | |||
| return idx | |||
| @@ -198,22 +185,9 @@ class SRData(): | |||
| elif self.args.ext.find('sep') >= 0: | |||
| with open(f_hr, 'rb') as _f: | |||
| hr = pickle.load(_f) | |||
| return hr, filename | |||
| def _load_rain(self, idx, rain_img=False): | |||
| """srdata""" | |||
| idx = random.randint(0, len(self.derain_img_list) - 1) | |||
| f_lr = self.derain_img_list[idx] | |||
| if rain_img: | |||
| norain = imageio.imread(f_lr.replace("rainstreak", "norain")) | |||
| rain = imageio.imread(f_lr.replace("rainstreak", "rain")) | |||
| return norain, rain | |||
| lr = imageio.imread(f_lr) | |||
| return lr | |||
| def _load_rain_test(self, idx): | |||
| """srdata""" | |||
| f_hr = self.derain_hr_test[idx] | |||
| f_lr = self.derain_lr_test[idx] | |||
| filename, _ = os.path.splitext(os.path.basename(f_lr)) | |||
| @@ -221,14 +195,6 @@ class SRData(): | |||
| rain = imageio.imread(f_lr) | |||
| return norain, rain, filename | |||
| def _load_denoise(self, idx): | |||
| """srdata""" | |||
| idx = self._get_index(idx) | |||
| f_lr = self.images_hr[idx] | |||
| norain = imageio.imread(f_lr) | |||
| rain = imageio.imread(f_lr.replace("HR", "LR_bicubic")) | |||
| return norain, rain | |||
| def _load_file(self, idx): | |||
| """srdata""" | |||
| idx = self._get_index(idx) | |||
| @@ -251,12 +217,7 @@ class SRData(): | |||
| def get_patch_hr(self, hr): | |||
| """srdata""" | |||
| if self.train: | |||
| hr = self.get_patch_img_hr( | |||
| hr, | |||
| patch_size=self.args.patch_size, | |||
| scale=1 | |||
| ) | |||
| hr = self.get_patch_img_hr(hr, patch_size=self.args.patch_size, scale=1) | |||
| return hr | |||
| def get_patch_img_hr(self, img, patch_size=96, scale=2): | |||
| @@ -280,9 +241,7 @@ class SRData(): | |||
| lr, hr = common.get_patch( | |||
| lr, hr, | |||
| patch_size=self.args.patch_size * scale, | |||
| scale=scale, | |||
| multi=(len(self.scale) > 1) | |||
| ) | |||
| scale=scale) | |||
| if not self.args.no_augment: | |||
| lr, hr = common.augment(lr, hr) | |||
| else: | |||
| @@ -292,7 +251,6 @@ class SRData(): | |||
| return lr, hr | |||
| def set_scale(self, idx_scale): | |||
| """srdata""" | |||
| if not self.input_large: | |||
| self.idx_scale = idx_scale | |||
| else: | |||
| @@ -1,241 +0,0 @@ | |||
| '''stride''' | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| from mindspore import nn | |||
| from mindspore.ops import operations as P | |||
| from mindspore.common import dtype as mstype | |||
class _stride_unfold_(nn.Cell):
    '''Collect kernel-sized blocks sampled at a fixed stride, then unfold them.'''

    def __init__(self, kernel_size, stride=-1):
        super(_stride_unfold_, self).__init__()
        # stride == -1 means "step by one kernel", i.e. non-overlapping blocks.
        self.stride = kernel_size if stride == -1 else stride
        self.kernel_size = kernel_size
        self.reshape = P.Reshape()
        self.transpose = P.Transpose()
        self.unfold = _unfold_(kernel_size)

    def construct(self, x):
        """Lay the strided blocks of ``x`` side by side and unfold the result."""
        N, C, H, W = x.shape
        rows = int(H / self.stride)
        cols = int(W / self.stride)
        zeros_like = P.ZerosLike()
        concat_h = P.Concat(axis=2)
        concat_w = P.Concat(axis=3)
        k = self.kernel_size
        canvas = P.Zeros()((N, C, rows * k, cols * k), mstype.float32)
        for i in range(rows):
            for j in range(cols):
                dst_i = i * k
                dst_j = j * k
                src_i = i * self.stride
                src_j = j * self.stride
                block = x[:, :, src_i:src_i + k, src_j:src_j + k]
                # Pad the block with zeros above/below, then left/right, so
                # it can be added onto the canvas at (dst_i, dst_j).
                above = zeros_like(canvas[:, :, :dst_i, dst_j:dst_j + k])
                below = zeros_like(canvas[:, :, dst_i + k:, dst_j:dst_j + k])
                column = concat_h((concat_h((above, block)), below))
                left = zeros_like(canvas[:, :, :, :dst_j])
                right = zeros_like(canvas[:, :, :, dst_j + k:])
                canvas += concat_w((concat_w((left, column)), right))
        return self.unfold(canvas)
class _stride_fold_(nn.Cell):
    '''Fold patches into an image, then re-place the blocks at a given stride.

    Blocks that overlap after re-placement are summed.

    Args:
        kernel_size (int | list | tuple): Block size; an int is used for both
            directions.
        output_shape (tuple): Target (height, width); ``(-1, -1)`` selects the
            square layout handled by the first branch of ``construct``.
        stride (int): Distance between block origins in the output; ``-1``
            means ``kernel_size[0]``.
    '''

    def __init__(self,
                 kernel_size,
                 output_shape=(-1, -1),
                 stride=-1):
        super(_stride_fold_, self).__init__()
        if isinstance(kernel_size, (list, tuple)):
            self.kernel_size = kernel_size
        else:
            self.kernel_size = [kernel_size, kernel_size]
        if stride == -1:
            # Index the normalised pair: the raw argument may be a plain int.
            self.stride = self.kernel_size[0]
        else:
            self.stride = stride
        self.output_shape = output_shape
        self.reshape = P.Reshape()
        self.transpose = P.Transpose()
        self.fold = _fold_(kernel_size)

    def construct(self, x):
        '''Fold ``x`` and accumulate its blocks at stride-spaced positions.'''
        # These operators were referenced but never defined in this method.
        zeroslike = P.ZerosLike()
        cc_2 = P.Concat(axis=2)
        cc_3 = P.Concat(axis=3)
        if self.output_shape[0] == -1:
            large_x = self.fold(x)
            N, C, H, _ = large_x.shape
            leftup_idx = []
            for i in range(0, H, self.kernel_size[0]):
                leftup_idx.append(i)
            NumBlock = len(leftup_idx)
            fold_x = P.Zeros()((N, C, (NumBlock - 1) * self.stride + self.kernel_size[0],
                                (NumBlock - 1) * self.stride + self.kernel_size[0]), mstype.float32)
            for i in range(NumBlock):
                for j in range(NumBlock):
                    fold_i = i * self.stride
                    fold_j = j * self.stride
                    org_i = leftup_idx[i]
                    org_j = leftup_idx[j]
                    # Blocks come from the folded 4-D image, not the 3-D input.
                    fills = large_x[:, :, org_i:org_i + self.kernel_size[0],
                                    org_j:org_j + self.kernel_size[1]]
                    fold_x += cc_3((cc_3((zeroslike(fold_x[:, :, :, :fold_j]), cc_2((cc_2(
                        (zeroslike(fold_x[:, :, :fold_i, fold_j:fold_j + self.kernel_size[1]]), fills)), zeroslike(
                            fold_x[:, :, fold_i + self.kernel_size[0]:, fold_j:fold_j + self.kernel_size[1]]))))),
                                    zeroslike(fold_x[:, :, :, fold_j + self.kernel_size[1]:])))
            y = fold_x
        else:
            NumBlock_x = int(
                (self.output_shape[0] - self.kernel_size[0]) / self.stride + 1)
            NumBlock_y = int(
                (self.output_shape[1] - self.kernel_size[1]) / self.stride + 1)
            large_shape = [NumBlock_x * self.kernel_size[0],
                           NumBlock_y * self.kernel_size[1]]
            # NOTE(review): a sub-cell is rebuilt inside construct here; confirm
            # this is acceptable for the execution mode in use.
            self.fold = _fold_(self.kernel_size, large_shape)
            large_x = self.fold(x)
            N, C, H, _ = large_x.shape
            leftup_idx_x = []
            leftup_idx_y = []
            for i in range(NumBlock_x):
                leftup_idx_x.append(i * self.kernel_size[0])
            for i in range(NumBlock_y):
                leftup_idx_y.append(i * self.kernel_size[1])
            fold_x = P.Zeros()((N, C, (NumBlock_x - 1) * self.stride + self.kernel_size[0],
                                (NumBlock_y - 1) * self.stride + self.kernel_size[1]), mstype.float32)
            for i in range(NumBlock_x):
                for j in range(NumBlock_y):
                    fold_i = i * self.stride
                    fold_j = j * self.stride
                    org_i = leftup_idx_x[i]
                    org_j = leftup_idx_y[j]
                    # Blocks come from the folded 4-D image, not the 3-D input.
                    fills = large_x[:, :, org_i:org_i + self.kernel_size[0],
                                    org_j:org_j + self.kernel_size[1]]
                    fold_x += cc_3((cc_3((zeroslike(fold_x[:, :, :, :fold_j]), cc_2((cc_2(
                        (zeroslike(fold_x[:, :, :fold_i, fold_j:fold_j + self.kernel_size[1]]), fills)), zeroslike(
                            fold_x[:, :, fold_i + self.kernel_size[0]:, fold_j:fold_j + self.kernel_size[1]]))))),
                                    zeroslike(fold_x[:, :, :, fold_j + self.kernel_size[1]:])))
            y = fold_x
        return y
class _unfold_(nn.Cell):
    '''Split an (N, C, H, W) input into non-overlapping square blocks.

    Args:
        kernel_size (int): Side length of each block.
        stride (int): Stored on the cell; ``-1`` means ``kernel_size``.
            ``construct`` itself does not read it.
    '''

    def __init__(self,
                 kernel_size,
                 stride=-1):
        super(_unfold_, self).__init__()
        # Always define self.stride; it used to be left unset whenever an
        # explicit stride was passed.
        if stride == -1:
            self.stride = kernel_size
        else:
            self.stride = stride
        self.kernel_size = kernel_size
        self.reshape = P.Reshape()
        self.transpose = P.Transpose()

    def construct(self, x):
        '''Return an (N, numH * numW, C * kernel_size**2) tensor of blocks.'''
        N, C, H, W = x.shape
        numH = int(H / self.kernel_size)
        numW = int(W / self.kernel_size)
        crop_h = numH * self.kernel_size
        crop_w = numW * self.kernel_size
        if crop_h != H or crop_w != W:
            # Drop the rows/columns that do not fill a whole block. The
            # original indexed a 4-D tensor with five subscripts here.
            x = x[:, :, :crop_h, :crop_w]
        # Use the cropped width so the reshape matches the sliced tensor.
        output_img = self.reshape(x, (N, C, numH, self.kernel_size, crop_w))
        output_img = self.transpose(output_img, (0, 1, 2, 4, 3))
        output_img = self.reshape(output_img, (N, C, int(
            numH * numW), self.kernel_size, self.kernel_size))
        output_img = self.transpose(output_img, (0, 2, 1, 4, 3))
        output_img = self.reshape(output_img, (N, int(numH * numW), -1))
        return output_img
class _fold_(nn.Cell):
    '''Reassemble flattened blocks into an (N, channels, H, W) image.

    Args:
        kernel_size (int | list | tuple): Block size; an int is used for both
            directions.
        output_shape (tuple): Target (height, width). ``(-1, -1)`` assumes a
            square grid whose side is the square root of the block count.
        stride (int): Stored on the cell; ``-1`` means ``kernel_size[0]``.
            ``construct`` itself does not read it.
    '''

    def __init__(self,
                 kernel_size,
                 output_shape=(-1, -1),
                 stride=-1):
        super(_fold_, self).__init__()
        if isinstance(kernel_size, (list, tuple)):
            self.kernel_size = kernel_size
        else:
            self.kernel_size = [kernel_size, kernel_size]
        # Index the normalised pair (the raw argument may be a plain int) and
        # always define self.stride.
        if stride == -1:
            self.stride = self.kernel_size[0]
        else:
            self.stride = stride
        self.output_shape = output_shape
        self.reshape = P.Reshape()
        self.transpose = P.Transpose()

    def construct(self, x):
        '''Fold ``x`` of shape (N, blocks, L) back into an image.'''
        # Here C is the number of blocks; the image channel count is org_C.
        N, C, L = x.shape
        org_C = int(L / self.kernel_size[0] / self.kernel_size[1])
        if self.output_shape[0] == -1:
            numH = int(np.sqrt(C))
            numW = int(np.sqrt(C))
            org_H = int(numH * self.kernel_size[0])
            org_W = org_H
        else:
            org_H = int(self.output_shape[0])
            org_W = int(self.output_shape[1])
            numH = int(org_H / self.kernel_size[0])
            numW = int(org_W / self.kernel_size[1])
        output_img = self.reshape(
            x, (N, C, org_C, self.kernel_size[0], self.kernel_size[1]))
        output_img = self.transpose(output_img, (0, 2, 1, 3, 4))
        output_img = self.reshape(
            output_img, (N, org_C, numH, numW, self.kernel_size[0], self.kernel_size[1]))
        # Interleave block rows with in-block rows before flattening to H x W.
        output_img = self.transpose(output_img, (0, 1, 2, 4, 3, 5))
        output_img = self.reshape(output_img, (N, org_C, org_H, org_W))
        return output_img
| @@ -13,15 +13,30 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import math | |||
| import copy | |||
| import numpy as np | |||
| from mindspore import nn | |||
| import mindspore.common.dtype as mstype | |||
| from mindspore.ops import operations as P | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore.common.parameter import Parameter | |||
| from mindspore.common import Tensor, Parameter | |||
class LayerPreprocess(nn.Cell):
    """
    Layer normalisation run in float32, with the result cast back to the
    dtype of the incoming tensor.
    """

    def __init__(self, in_channels=None):
        super().__init__()
        self.layernorm = nn.LayerNorm((in_channels,))
        self.cast = P.Cast()
        self.get_dtype = P.DType()

    def construct(self, input_tensor):
        source_dtype = self.get_dtype(input_tensor)
        normalized = self.layernorm(self.cast(input_tensor, mstype.float32))
        return self.cast(normalized, source_dtype)
| class MultiheadAttention(nn.Cell): | |||
| """ | |||
| @@ -45,7 +60,7 @@ class MultiheadAttention(nn.Cell): | |||
| initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. | |||
| do_return_2d_tensor (bool): True for return 2d tensor. False for return 3d | |||
| tensor. Default: False. | |||
| compute_type (:class:`mindspore.dtype`): Compute type in MultiheadAttention. Default: mstype.float32. | |||
| compute_type (:class:`mindspore.dtype`): Compute type in MultiheadAttention. Default: mstype.float16. | |||
| """ | |||
| def __init__(self, | |||
| @@ -64,13 +79,12 @@ class MultiheadAttention(nn.Cell): | |||
| use_one_hot_embeddings=False, | |||
| initializer_range=0.02, | |||
| do_return_2d_tensor=False, | |||
| compute_type=mstype.float32, | |||
| compute_type=mstype.float16, | |||
| same_dim=True): | |||
| super(MultiheadAttention, self).__init__() | |||
| self.num_attention_heads = num_attention_heads | |||
| self.size_per_head = int(hidden_width / num_attention_heads) | |||
| self.has_attention_mask = has_attention_mask | |||
| assert has_attention_mask | |||
| self.use_one_hot_embeddings = use_one_hot_embeddings | |||
| self.initializer_range = initializer_range | |||
| self.do_return_2d_tensor = do_return_2d_tensor | |||
| @@ -83,11 +97,9 @@ class MultiheadAttention(nn.Cell): | |||
| self.shape_k_2d = (-1, k_tensor_width) | |||
| self.shape_v_2d = (-1, v_tensor_width) | |||
| self.hidden_width = int(hidden_width) | |||
| # units = num_attention_heads * self.size_per_head | |||
| if self.same_dim: | |||
| self.in_proj_layer = \ | |||
| Parameter(Tensor(np.random.rand(hidden_width * 3, | |||
| q_tensor_width), dtype=compute_type), name="weight") | |||
| self.in_proj_layer = Parameter(Tensor(np.random.rand(hidden_width * 3, | |||
| q_tensor_width), dtype=mstype.float32), name="weight") | |||
| else: | |||
| self.query_layer = nn.Dense(q_tensor_width, | |||
| hidden_width, | |||
| @@ -132,8 +144,10 @@ class MultiheadAttention(nn.Cell): | |||
| self.equal = P.Equal() | |||
| self.shape = P.Shape() | |||
| def construct(self, tensor_q, tensor_k, tensor_v, attention_mask=None): | |||
| """Apply multihead attention.""" | |||
| def construct(self, tensor_q, tensor_k, tensor_v): | |||
| """ | |||
| Apply multihead attention. | |||
| """ | |||
| batch_size, seq_length, _ = self.shape(tensor_q) | |||
| shape_qkv = (batch_size, -1, | |||
| self.num_attention_heads, self.size_per_head) | |||
| @@ -161,20 +175,14 @@ class MultiheadAttention(nn.Cell): | |||
| _start = 0 | |||
| _end = self.hidden_width | |||
| _w = self.in_proj_layer[_start:_end, :] | |||
| # _b = None | |||
| query_out = self.matmul_dense(_w, tensor_q_2d) | |||
| _start = self.hidden_width | |||
| _end = self.hidden_width * 2 | |||
| _w = self.in_proj_layer[_start:_end, :] | |||
| # _b = None | |||
| key_out = self.matmul_dense(_w, tensor_k_2d) | |||
| _start = self.hidden_width * 2 | |||
| _end = None | |||
| _w = self.in_proj_layer[_start:] | |||
| # _b = None | |||
| value_out = self.matmul_dense(_w, tensor_v_2d) | |||
| else: | |||
| query_out = self.query_layer(tensor_q_2d) | |||
| @@ -193,8 +201,7 @@ class MultiheadAttention(nn.Cell): | |||
| attention_scores = self.softmax_cast(attention_scores, mstype.float32) | |||
| attention_probs = self.softmax(attention_scores) | |||
| attention_probs = self.softmax_cast( | |||
| attention_probs, self.get_dtype(key_layer)) | |||
| attention_probs = self.softmax_cast(attention_probs, mstype.float16) | |||
| if self.use_dropout: | |||
| attention_probs = self.dropout(attention_probs) | |||
| @@ -212,11 +219,8 @@ class MultiheadAttention(nn.Cell): | |||
| class TransformerEncoderLayer(nn.Cell): | |||
| """ipt""" | |||
| def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, | |||
| activation="relu"): | |||
| def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, compute_type=mstype.float16): | |||
| super().__init__() | |||
| self.self_attn = MultiheadAttention(q_tensor_width=d_model, | |||
| k_tensor_width=d_model, | |||
| v_tensor_width=d_model, | |||
| @@ -224,12 +228,12 @@ class TransformerEncoderLayer(nn.Cell): | |||
| out_tensor_width=d_model, | |||
| num_attention_heads=nhead, | |||
| attention_probs_dropout_prob=dropout) | |||
| self.linear1 = nn.Dense(d_model, dim_feedforward) | |||
| self.linear1 = nn.Dense(d_model, dim_feedforward).to_float(compute_type) | |||
| self.dropout = nn.Dropout(1. - dropout) | |||
| self.linear2 = nn.Dense(dim_feedforward, d_model) | |||
| self.norm1 = nn.LayerNorm([d_model]) | |||
| self.norm2 = nn.LayerNorm([d_model]) | |||
| self.norm1 = LayerPreprocess(d_model) | |||
| self.norm2 = LayerPreprocess(d_model) | |||
| self.dropout1 = nn.Dropout(1. - dropout) | |||
| self.dropout2 = nn.Dropout(1. - dropout) | |||
| self.reshape = P.Reshape() | |||
| @@ -237,7 +241,6 @@ class TransformerEncoderLayer(nn.Cell): | |||
| self.activation = P.ReLU() | |||
def with_pos_embed(self, tensor, pos):
    """Return ``tensor + pos``, or ``tensor`` unchanged when ``pos`` is None."""
    if pos is None:
        return tensor
    return tensor + pos
| def construct(self, src, pos=None): | |||
| @@ -258,10 +261,8 @@ class TransformerEncoderLayer(nn.Cell): | |||
| class TransformerDecoderLayer(nn.Cell): | |||
| """ipt""" | |||
| def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, | |||
| activation="relu"): | |||
| """ ipt""" | |||
| def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1): | |||
| super().__init__() | |||
| self.self_attn = MultiheadAttention(q_tensor_width=d_model, | |||
| k_tensor_width=d_model, | |||
| @@ -281,9 +282,9 @@ class TransformerDecoderLayer(nn.Cell): | |||
| self.dropout = nn.Dropout(1. - dropout) | |||
| self.linear2 = nn.Dense(dim_feedforward, d_model) | |||
| self.norm1 = nn.LayerNorm([d_model]) | |||
| self.norm2 = nn.LayerNorm([d_model]) | |||
| self.norm3 = nn.LayerNorm([d_model]) | |||
| self.norm1 = LayerPreprocess(d_model) | |||
| self.norm2 = LayerPreprocess(d_model) | |||
| self.norm3 = LayerPreprocess(d_model) | |||
| self.dropout1 = nn.Dropout(1. - dropout) | |||
| self.dropout2 = nn.Dropout(1. - dropout) | |||
| self.dropout3 = nn.Dropout(1. - dropout) | |||
| @@ -291,7 +292,6 @@ class TransformerDecoderLayer(nn.Cell): | |||
| self.activation = P.ReLU() | |||
def with_pos_embed(self, tensor, pos):
    """Add the optional positional term ``pos`` to ``tensor``.

    When ``pos`` is None the input is handed back untouched.
    """
    if pos is None:
        return tensor
    return tensor + pos
| def construct(self, tgt, memory, pos=None, query_pos=None): | |||
| @@ -306,7 +306,7 @@ class TransformerDecoderLayer(nn.Cell): | |||
| tgt2 = self.norm2(tgt) | |||
| tgt2 = self.multihead_attn(tensor_q=self.with_pos_embed(tgt2, query_pos), | |||
| tensor_k=self.with_pos_embed(memory, pos), | |||
| tensor_v=memory,) | |||
| tensor_v=memory) | |||
| tgt = tgt + self.dropout2(tgt2) | |||
| tgt2 = self.norm3(tgt) | |||
| tgt2 = self.reshape(tgt2, permute_linear) | |||
| @@ -318,47 +318,38 @@ class TransformerDecoderLayer(nn.Cell): | |||
class TransformerEncoder(nn.Cell):
    """Chain of ``num_layers`` deep-copied encoder layers applied one after another."""

    def __init__(self, encoder_layer, num_layers):
        super(TransformerEncoder, self).__init__()
        self.num_layers = num_layers
        # _get_clones deep-copies the prototype, so the layers do not share state.
        self.layers = _get_clones(encoder_layer, num_layers)

    def construct(self, src, pos=None):
        """Feed ``src`` through every layer in order, forwarding ``pos`` to each.

        Args:
            src: Input handed to the first layer.
            pos: Optional positional term passed as ``pos=`` to every layer.

        Returns:
            The output of the last layer (``src`` itself when there are no layers).
        """
        hidden = src
        for encoder_block in self.layers:
            hidden = encoder_block(hidden, pos=pos)
        return hidden
class TransformerDecoder(nn.Cell):
    """Chain of ``num_layers`` deep-copied decoder layers applied one after another."""

    def __init__(self, decoder_layer, num_layers):
        super(TransformerDecoder, self).__init__()
        self.num_layers = num_layers
        # _get_clones deep-copies the prototype, so the layers do not share state.
        self.layers = _get_clones(decoder_layer, num_layers)

    def construct(self, tgt, memory, pos=None, query_pos=None):
        """Feed ``tgt`` through every layer in order.

        Args:
            tgt: Input handed to the first layer.
            memory: Passed unchanged as the second positional argument of every layer.
            pos: Optional value forwarded as ``pos=`` to every layer.
            query_pos: Optional value forwarded as ``query_pos=`` to every layer.

        Returns:
            The output of the last layer (``tgt`` itself when there are no layers).
        """
        hidden = tgt
        for decoder_block in self.layers:
            hidden = decoder_block(hidden, memory, pos=pos, query_pos=query_pos)
        return hidden
def _get_clones(module, n):
    """Return an ``nn.CellList`` holding ``n`` independent deep copies of ``module``."""
    clones = []
    for _ in range(n):
        clones.append(copy.deepcopy(module))
    return nn.CellList(clones)
| class LearnedPositionalEncoding(nn.Cell): | |||
| """ipt""" | |||
| def __init__(self, max_position_embeddings, embedding_dim, seq_length): | |||
| super(LearnedPositionalEncoding, self).__init__() | |||
| self.pe = nn.Embedding( | |||
| @@ -370,8 +361,7 @@ class LearnedPositionalEncoding(nn.Cell): | |||
| self.position_ids = self.reshape( | |||
| self.position_ids, (1, self.seq_length)) | |||
| def construct(self, x, position_ids=None): | |||
| """ipt""" | |||
| def construct(self, position_ids=None): | |||
| if position_ids is None: | |||
| position_ids = self.position_ids[:, : self.seq_length] | |||
| @@ -381,46 +371,35 @@ class LearnedPositionalEncoding(nn.Cell): | |||
| class VisionTransformer(nn.Cell): | |||
| """ipt""" | |||
| def __init__( | |||
| self, | |||
| img_dim, | |||
| patch_dim, | |||
| num_channels, | |||
| embedding_dim, | |||
| num_heads, | |||
| num_layers, | |||
| hidden_dim, | |||
| num_queries, | |||
| idx, | |||
| positional_encoding_type="learned", | |||
| dropout_rate=0, | |||
| norm=False, | |||
| mlp=False, | |||
| pos_every=False, | |||
| no_pos=False | |||
| ): | |||
| def __init__(self, | |||
| img_dim, | |||
| patch_dim, | |||
| num_channels, | |||
| embedding_dim, | |||
| num_heads, | |||
| num_layers, | |||
| hidden_dim, | |||
| num_queries, | |||
| dropout_rate=0, | |||
| norm=False, | |||
| mlp=False, | |||
| pos_every=False, | |||
| no_pos=False, | |||
| con_loss=False): | |||
| super(VisionTransformer, self).__init__() | |||
| assert embedding_dim % num_heads == 0 | |||
| assert img_dim % patch_dim == 0 | |||
| self.norm = norm | |||
| self.mlp = mlp | |||
| self.embedding_dim = embedding_dim | |||
| self.num_heads = num_heads | |||
| self.patch_dim = patch_dim | |||
| self.num_channels = num_channels | |||
| self.img_dim = img_dim | |||
| self.pos_every = pos_every | |||
| self.num_patches = int((img_dim // patch_dim) ** 2) | |||
| self.seq_length = self.num_patches | |||
| self.flatten_dim = patch_dim * patch_dim * num_channels | |||
| self.out_dim = patch_dim * patch_dim * num_channels | |||
| self.no_pos = no_pos | |||
| self.unf = _unfold_(patch_dim) | |||
| self.fold = _fold_(patch_dim, output_shape=(img_dim, img_dim)) | |||
| @@ -432,8 +411,7 @@ class VisionTransformer(nn.Cell): | |||
| nn.Dropout(1. - dropout_rate), | |||
| nn.ReLU(), | |||
| nn.Dense(hidden_dim, self.out_dim), | |||
| nn.Dropout(1. - dropout_rate) | |||
| ) | |||
| nn.Dropout(1. - dropout_rate)) | |||
| self.query_embed = nn.Embedding( | |||
| num_queries, embedding_dim * self.seq_length) | |||
| @@ -449,55 +427,54 @@ class VisionTransformer(nn.Cell): | |||
| self.tile = P.Tile() | |||
| self.transpose = P.Transpose() | |||
| if not self.no_pos: | |||
| self.position_encoding = LearnedPositionalEncoding( | |||
| self.seq_length, self.embedding_dim, self.seq_length | |||
| ) | |||
| self.position_encoding = LearnedPositionalEncoding(self.seq_length, self.embedding_dim, self.seq_length) | |||
| self.dropout_layer1 = nn.Dropout(1. - dropout_rate) | |||
| self.query_idx = idx | |||
| self.query_idx_tensor = Tensor(idx, mstype.int32) | |||
| def construct(self, x): | |||
| self.con_loss = con_loss | |||
def construct(self, x, query_idx_tensor):
    """Run the patch-level encoder/decoder over ``x``.

    Args:
        x: Input passed to ``self.unf``; the result must be 3-D, since it is
            unpacked as ``(batch, num_patches, _)``.
        query_idx_tensor: Index handed to the ``self.query_embed`` lookup.

    Returns:
        ``self.fold(x)``; when ``self.con_loss`` is set, the tuple
        ``(self.fold(x), x)`` where the second item is the sequence before folding.
    """
    x = self.unf(x)
    batch, num_patches, _ = x.shape
    flat_shape = (batch * num_patches, -1)
    seq_shape = (batch, num_patches, -1)
    # The original compares identity with True, so truthy non-True values
    # still take the residual branches; keep that exact test.
    use_residual = self.mlp is not True

    if use_residual:
        flat = self.reshape(x, flat_shape)
        flat = self.dropout_layer1(self.linear_encoding(flat)) + flat
        x = self.reshape(flat, seq_shape)

    # One query embedding, reshaped to a single sequence and repeated per batch item.
    query = self.query_embed(query_idx_tensor)
    query = self.reshape(query, (1, self.seq_length, self.embedding_dim))
    query_embed = self.tile(query, (batch, 1, 1))

    if self.no_pos:
        x = self.encoder(x)
    else:
        pos = self.position_encoding()
        x = self.encoder(x + pos)

    # The encoder output serves as both decoder target and memory.
    x = self.decoder(x, x, query_pos=query_embed)

    if use_residual:
        flat = self.reshape(x, flat_shape)
        flat = self.mlp_head(flat) + flat
        x = self.reshape(flat, seq_shape)

    if self.con_loss:
        con_x = x
        return self.fold(x), con_x
    return self.fold(x)
def default_conv(in_channels, out_channels, kernel_size, has_bias=True):
    """Build an ``nn.Conv2d`` with the given channels, kernel size and bias flag.

    All other ``nn.Conv2d`` options are left at the library defaults.
    """
    conv_layer = nn.Conv2d(in_channels, out_channels, kernel_size,
                           has_bias=has_bias)
    return conv_layer
| class MeanShift(nn.Conv2d): | |||
| """ipt""" | |||
| def __init__( | |||
| self, rgb_range, | |||
| rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1): | |||
| def __init__(self, | |||
| rgb_range, | |||
| rgb_mean=(0.4488, 0.4371, 0.4040), | |||
| rgb_std=(1.0, 1.0, 1.0), | |||
| sign=-1): | |||
| super(MeanShift, self).__init__(3, 3, kernel_size=1) | |||
| self.reshape = P.Reshape() | |||
| self.eye = P.Eye() | |||
| @@ -512,10 +489,14 @@ class MeanShift(nn.Conv2d): | |||
| class ResBlock(nn.Cell): | |||
| """ipt""" | |||
| def __init__( | |||
| self, conv, n_feats, kernel_size, | |||
| bias=True, bn=False, act=nn.ReLU(), res_scale=1): | |||
| def __init__(self, | |||
| conv, | |||
| n_feats, | |||
| kernel_size, | |||
| bias=True, | |||
| bn=False, | |||
| act=nn.ReLU(), | |||
| res_scale=1): | |||
| super(ResBlock, self).__init__() | |||
| m = [] | |||
| @@ -532,35 +513,28 @@ class ResBlock(nn.Cell): | |||
| self.mul = P.Mul() | |||
def construct(self, x):
    """Return ``body(x) * res_scale + x`` (scaled residual branch plus skip input)."""
    scaled_branch = self.mul(self.body(x), self.res_scale)
    scaled_branch += x
    return scaled_branch
def _pixelsf_(x, scale):
    """Rearrange channel blocks of a 4-D tensor into space.

    ``x`` is unpacked as ``(n, c, ih, iw)`` and the result has shape
    ``(n, c // scale**2, ih * scale, iw * scale)``. Every intermediate
    tensor has at most five dimensions.
    """
    n, c, ih, iw = x.shape
    oc = c // (scale ** 2)
    oh, ow = ih * scale, iw * scale

    # (n, c, ih, iw) -> (n, ih, c, iw): bring rows in front of channels.
    rows_first = P.Transpose()(x, (0, 2, 1, 3))
    # Split the channel axis into (oc * scale, scale).
    split_cols = P.Reshape()(rows_first, (n, ih, oc * scale, scale, iw))
    # Put the trailing `scale` factor after the width axis ...
    cols_last = P.Transpose()(split_cols, (0, 1, 2, 4, 3))
    # ... and merge (iw, scale) into the output width while splitting (oc, scale).
    wide = P.Reshape()(cols_last, (n, ih, oc, scale, ow))
    # (n, ih, oc, scale, ow) -> (n, oc, ih, scale, ow).
    chan_first = P.Transpose()(wide, (0, 2, 1, 3, 4))
    # Merge (ih, scale) into the output height.
    output = P.Reshape()(chan_first, (n, oc, oh, ow))
    return output
| class SmallUpSampler(nn.Cell): | |||
| """ipt""" | |||
| def __init__(self, conv, upsize, n_feats, bn=False, act=False, bias=True): | |||
| def __init__(self, conv, upsize, n_feats, bias=True): | |||
| super(SmallUpSampler, self).__init__() | |||
| self.conv = conv(n_feats, upsize * upsize * n_feats, 3, bias) | |||
| self.reshape = P.Reshape() | |||
| @@ -568,7 +542,6 @@ class SmallUpSampler(nn.Cell): | |||
| self.pixelsf = _pixelsf_ | |||
def construct(self, x):
    """Apply ``self.conv`` to ``x``, then ``self.pixelsf`` with ``self.upsize``."""
    features = self.conv(x)
    return self.pixelsf(features, self.upsize)
| @@ -576,47 +549,37 @@ class SmallUpSampler(nn.Cell): | |||
class Upsampler(nn.Cell):
    """Sequence of ``SmallUpSampler`` stages chosen from ``scale``.

    A power-of-two ``scale`` yields ``log2(scale)`` x2 stages, ``scale == 3``
    yields one x3 stage, and any other value yields an empty sequence.
    """

    def __init__(self, conv, scale, n_feats, bias=True):
        super(Upsampler, self).__init__()
        if (scale & (scale - 1)) == 0:
            # Power of two (this test also accepts scale == 1, giving zero stages).
            num_stages = int(math.log(scale, 2))
            stages = [SmallUpSampler(conv, 2, n_feats, bias=bias)
                      for _ in range(num_stages)]
        elif scale == 3:
            stages = [SmallUpSampler(conv, 3, n_feats, bias=bias)]
        else:
            stages = []
        self.net = nn.SequentialCell(stages)

    def construct(self, x):
        """Run ``x`` through the stacked up-sampling stages."""
        upsampled = self.net(x)
        return upsampled
| class IPT(nn.Cell): | |||
| """ipt""" | |||
| def __init__(self, args, conv=default_conv): | |||
| super(IPT, self).__init__() | |||
| self.dytpe = mstype.float16 | |||
| self.scale_idx = 0 | |||
| self.args = args | |||
| self.con_loss = args.con_loss | |||
| n_feats = args.n_feats | |||
| kernel_size = 3 | |||
| act = nn.ReLU() | |||
| self.sub_mean = MeanShift(args.rgb_range) | |||
| self.add_mean = MeanShift(args.rgb_range, sign=1) | |||
| self.head = nn.CellList([ | |||
| nn.SequentialCell( | |||
| conv(args.n_colors, n_feats, kernel_size), | |||
| ResBlock(conv, n_feats, 5, act=act), | |||
| ResBlock(conv, n_feats, 5, act=act) | |||
| ) for _ in args.scale | |||
| ]) | |||
| nn.SequentialCell(conv(args.n_colors, n_feats, kernel_size).to_float(self.dytpe), | |||
| ResBlock(conv, n_feats, 5, act=act).to_float(self.dytpe), | |||
| ResBlock(conv, n_feats, 5, act=act).to_float(self.dytpe)) for _ in range(6)]) | |||
| self.body = VisionTransformer(img_dim=args.patch_size, | |||
| patch_dim=args.patch_dim, | |||
| @@ -630,36 +593,34 @@ class IPT(nn.Cell): | |||
| mlp=args.no_mlp, | |||
| pos_every=args.pos_every, | |||
| no_pos=args.no_pos, | |||
| idx=self.scale_idx) | |||
| con_loss=args.con_loss).to_float(self.dytpe) | |||
| self.tail = nn.CellList([ | |||
| nn.SequentialCell( | |||
| Upsampler(conv, s, n_feats, act=False), | |||
| conv(n_feats, args.n_colors, kernel_size) | |||
| ) for s in args.scale | |||
| ]) | |||
| nn.SequentialCell(Upsampler(conv, s, n_feats).to_float(self.dytpe), | |||
| conv(n_feats, args.n_colors, kernel_size).to_float(self.dytpe)) \ | |||
| for s in [2, 3, 4, 1, 1, 1]]) | |||
| self.reshape = P.Reshape() | |||
| self.tile = P.Tile() | |||
| self.transpose = P.Transpose() | |||
| self.s2t = P.ScalarToTensor() | |||
| self.cast = P.Cast() | |||
def construct(self, x, idx):
    """Run head -> transformer body -> tail for the task selected by ``idx``.

    Args:
        x: Input batch, fed to the selected head.
        idx: Task selector. Only ``idx.shape[0]`` is read; that integer indexes
            ``self.head`` / ``self.tail`` and is passed to the body as an
            int32 tensor.

    Returns:
        The tail output. When ``self.con_loss`` is set, the tuple
        ``(tail_output, x_con)`` where ``x_con`` is the second value returned
        by the body.
    """
    idx_num = idx.shape[0]
    x = self.head[idx_num](x)
    idx_tensor = self.cast(self.s2t(idx_num), mstype.int32)
    if self.con_loss:
        res, x_con = self.body(x, idx_tensor)
        res += x
        # Fix: the tail must consume the residual sum `res`. The previous code
        # fed `x` (head features) here, discarding the body output entirely
        # and diverging from the non-contrastive path below.
        x = self.tail[idx_num](res)
        return x, x_con
    res = self.body(x, idx_tensor)
    res += x
    x = self.tail[idx_num](res)
    return x
| def set_scale(self, scale_idx): | |||
| """ipt""" | |||
| self.body.query_idx = scale_idx | |||
| self.scale_idx = scale_idx | |||
| class IPT_post(): | |||
| """ipt""" | |||
| def __init__(self, model, args): | |||
| @@ -674,17 +635,13 @@ class IPT_post(): | |||
| self.cc_2 = P.Concat(axis=2) | |||
| self.cc_3 = P.Concat(axis=3) | |||
| def set_scale(self, scale_idx): | |||
| """ipt""" | |||
| self.body.query_idx = scale_idx | |||
| self.scale_idx = scale_idx | |||
| def forward(self, x, shave=12, batchsize=64): | |||
| def forward(self, x, idx, shave=12, batchsize=64): | |||
| """ipt""" | |||
| self.idx = idx | |||
| h, w = x.shape[-2:] | |||
| padsize = int(self.args.patch_size) | |||
| shave = int(self.args.patch_size / 4) | |||
| scale = self.args.scale[self.scale_idx] | |||
| scale = self.args.scale[0] | |||
| h_cut = (h - padsize) % (padsize - shave) | |||
| w_cut = (w - padsize) % (padsize - shave) | |||
| @@ -692,7 +649,7 @@ class IPT_post(): | |||
| x_unfold = unf_1.compute(x) | |||
| x_unfold = self.transpose(x_unfold, (1, 0, 2)) # transpose(0,2) | |||
| x_hw_cut = x[:, :, (h - padsize):, (w - padsize):] | |||
| y_hw_cut = self.model(x_hw_cut) | |||
| y_hw_cut = self.model(x_hw_cut, self.idx) | |||
| x_h_cut = x[:, :, (h - padsize):, :] | |||
| x_w_cut = x[:, :, :, (w - padsize):] | |||
| @@ -714,10 +671,10 @@ class IPT_post(): | |||
| for i in range(x_range): | |||
| if i == 0: | |||
| y_unfold = self.model( | |||
| x_unfold[i * batchsize:(i + 1) * batchsize, :, :, :]) | |||
| x_unfold[i * batchsize:(i + 1) * batchsize, :, :, :], self.idx) | |||
| else: | |||
| y_unfold = self.cc_0((y_unfold, self.model( | |||
| x_unfold[i * batchsize:(i + 1) * batchsize, :, :, :]))) | |||
| x_unfold[i * batchsize:(i + 1) * batchsize, :, :, :], self.idx))) | |||
| y_unf_shape_0 = y_unfold.shape[0] | |||
| fold_1 = \ | |||
| _stride_fold_(padsize * scale, output_shape=((h - h_cut) * scale, (w - w_cut) * scale), | |||
| @@ -740,17 +697,18 @@ class IPT_post(): | |||
| stride=padsize * scale - shave * scale) | |||
| y_inter = fold_2.compute(self.transpose(self.reshape( | |||
| y_unfold, (y_unf_shape_0, -1, 1)), (2, 0, 1))) | |||
| concat1 = self.cc_2((y[:, :, :int(shave / 2 * scale), int(shave / 2 * scale):(w - w_cut) * scale - int(shave / 2 * scale)], y_inter)) #pylint: disable=line-too-long | |||
| concat2 = self.cc_2((concat1, y[:, :, (h - h_cut) * scale - int(shave / 2 * scale):, int(shave / 2 * scale):(w - w_cut) * scale - int(shave / 2 * scale)])) #pylint: disable=line-too-long | |||
| concat1 = self.cc_2((y[:, :, :int(shave / 2 * scale), \ | |||
| int(shave / 2 * scale):(w - w_cut) * scale - int(shave / 2 * scale)], y_inter)) | |||
| concat2 = self.cc_2((concat1, y[:, :, (h - h_cut) * scale - int(shave / 2 * scale):, \ | |||
| int(shave / 2 * scale):(w - w_cut) * scale - int(shave / 2 * scale)])) | |||
| concat3 = self.cc_3((y[:, :, :, :int(shave / 2 * scale)], concat2)) | |||
| y = self.cc_3((concat3, y[:, :, :, (w - w_cut) * scale - int(shave / 2 * scale):])) #pylint: disable=line-too-long | |||
| y = self.cc_2((y[:, :, :y.shape[2] - int((padsize - h_cut) / 2 * scale), :], y_h_cut[:, :, int((padsize - h_cut) / 2 * scale + 0.5):, :])) #pylint: disable=line-too-long | |||
| y = self.cc_3((concat3, y[:, :, :, (w - w_cut) * scale - int(shave / 2 * scale):])) | |||
| y = self.cc_2((y[:, :, :y.shape[2] - int((padsize - h_cut) / 2 * scale), :], | |||
| y_h_cut[:, :, int((padsize - h_cut) / 2 * scale + 0.5):, :])) | |||
| y_w_cat = self.cc_2((y_w_cut[:, :, :y_w_cut.shape[2] - int((padsize - h_cut) / 2 * scale), :], | |||
| y_hw_cut[:, :, int((padsize - h_cut) / 2 * scale + 0.5):, :])) | |||
| y = self.cc_3((y[:, :, :, :y.shape[3] - int((padsize - w_cut) / 2 * scale)], | |||
| y_w_cat[:, :, :, int((padsize - w_cut) / 2 * scale + 0.5):])) | |||
| return y | |||
| def cut_h_new(self, x_h_cut, h, w, h_cut, w_cut, padsize, shave, scale, batchsize): | |||
| @@ -766,11 +724,11 @@ class IPT_post(): | |||
| for i in range(x_range): | |||
| if i == 0: | |||
| y_h_cut_unfold = self.model( | |||
| x_h_cut_unfold[i * batchsize:(i + 1) * batchsize, :, :, :]) | |||
| x_h_cut_unfold[i * batchsize:(i + 1) * batchsize, :, :, :], self.idx) | |||
| else: | |||
| y_h_cut_unfold = \ | |||
| self.cc_0((y_h_cut_unfold, self.model( | |||
| x_h_cut_unfold[i * batchsize:(i + 1) * batchsize, :, :, :]))) | |||
| x_h_cut_unfold[i * batchsize:(i + 1) * batchsize, :, :, :], self.idx))) | |||
| y_h_cut_unfold_shape_0 = y_h_cut_unfold.shape[0] | |||
| fold_1 = \ | |||
| _stride_fold_(padsize * scale, output_shape=(padsize * scale, (w - w_cut) * scale), | |||
| @@ -802,10 +760,11 @@ class IPT_post(): | |||
| for i in range(x_range): | |||
| if i == 0: | |||
| y_w_cut_unfold = self.model( | |||
| x_w_cut_unfold[i * batchsize:(i + 1) * batchsize, :, :, :]) | |||
| x_w_cut_unfold[i * batchsize:(i + 1) * batchsize, :, :, :], self.idx) | |||
| else: | |||
| y_w_cut_unfold = self.cc_0((y_w_cut_unfold, | |||
| self.model(x_w_cut_unfold[i * batchsize:(i + 1) * batchsize, :, :, :]))) | |||
| self.model(x_w_cut_unfold[i * batchsize:(i + 1) * batchsize, :, :, :], \ | |||
| self.idx))) | |||
| y_w_cut_unfold_shape_0 = y_w_cut_unfold.shape[0] | |||
| fold_1 = _stride_fold_(padsize * scale, | |||
| output_shape=((h - h_cut) * scale, | |||
| @@ -827,7 +786,6 @@ class IPT_post(): | |||
| class _stride_unfold_(): | |||
| '''stride''' | |||
| def __init__(self, | |||
| kernel_size, | |||
| stride=-1): | |||
| @@ -874,13 +832,12 @@ class _stride_unfold_(): | |||
| zeros4 = np.zeros(unf_x[:, :, :, unf_j + self.kernel_size:].shape) | |||
| concat4 = np.concatenate((concat3, zeros4), axis=3) | |||
| unf_x += concat4 | |||
| unf_x = Tensor(unf_x, mstype.float32) | |||
| unf_x = Tensor(unf_x, mstype.float16) | |||
| y = self.unfold(unf_x) | |||
| return y | |||
| class _stride_fold_(): | |||
| '''stride''' | |||
| def __init__(self, | |||
| kernel_size, | |||
| output_shape=(-1, -1), | |||
| @@ -905,7 +862,7 @@ class _stride_fold_(): | |||
| self.fold = _fold_(self.kernel_size, self.large_shape) | |||
| def compute(self, x): | |||
| '''stride''' | |||
| """ compute""" | |||
| NumBlock_x = self.NumBlock_x | |||
| NumBlock_y = self.NumBlock_y | |||
| large_x = self.fold(x) | |||
| @@ -917,7 +874,8 @@ class _stride_fold_(): | |||
| leftup_idx_x.append(i * self.kernel_size[0]) | |||
| for i in range(NumBlock_y): | |||
| leftup_idx_y.append(i * self.kernel_size[1]) | |||
| fold_x = np.zeros((N, C, (NumBlock_x - 1) * self.stride + self.kernel_size[0], (NumBlock_y - 1) * self.stride + self.kernel_size[1]), dtype=np.float32) #pylint: disable=line-too-long | |||
| fold_x = np.zeros((N, C, (NumBlock_x - 1) * self.stride + self.kernel_size[0], \ | |||
| (NumBlock_y - 1) * self.stride + self.kernel_size[1]), dtype=np.float32) | |||
| for i in range(NumBlock_x): | |||
| for j in range(NumBlock_y): | |||
| fold_i = i * self.stride | |||
| @@ -938,12 +896,11 @@ class _stride_fold_(): | |||
| zeros4 = np.zeros(t4.shape) | |||
| concat4 = np.concatenate((concat3, zeros4), axis=3) | |||
| fold_x += concat4 | |||
| y = Tensor(fold_x, mstype.float32) | |||
| y = Tensor(fold_x, mstype.float16) | |||
| return y | |||
| class _unfold_(nn.Cell): | |||
| """ipt""" | |||
| def __init__( | |||
| self, kernel_size, stride=-1): | |||
| @@ -965,8 +922,10 @@ class _unfold_(nn.Cell): | |||
| output_img = self.reshape(x, (N, C, numH, self.kernel_size, W)) | |||
| output_img = self.transpose(output_img, (0, 1, 2, 4, 3)) | |||
| output_img = self.reshape(output_img, (N, C, numH, -1, self.kernel_size, self.kernel_size)) | |||
| output_img = self.transpose(output_img, (0, 2, 3, 1, 5, 4)) | |||
| output_img = self.reshape(output_img, (N*C, numH, numW, self.kernel_size, self.kernel_size)) | |||
| output_img = self.transpose(output_img, (0, 1, 2, 4, 3)) | |||
| output_img = self.reshape(output_img, (N, C, numH * numW, self.kernel_size*self.kernel_size)) | |||
| output_img = self.transpose(output_img, (0, 2, 1, 3)) | |||
| output_img = self.reshape(output_img, (N, numH * numW, -1)) | |||
| return output_img | |||
| @@ -1002,14 +961,10 @@ class _fold_(nn.Cell): | |||
| org_W = self.output_shape[1] | |||
| numH = org_H // self.kernel_size[0] | |||
| numW = org_W // self.kernel_size[1] | |||
| output_img = self.reshape( | |||
| x, (N, C, org_C, self.kernel_size[0], self.kernel_size[1])) | |||
| output_img = self.reshape(x, (N, C, org_C, self.kernel_size[0], self.kernel_size[1])) | |||
| output_img = self.transpose(output_img, (0, 2, 3, 1, 4)) | |||
| output_img = self.reshape(output_img, (N*org_C, self.kernel_size[0], numH, numW, self.kernel_size[1])) | |||
| output_img = self.transpose(output_img, (0, 2, 1, 3, 4)) | |||
| output_img = self.reshape( | |||
| output_img, (N, org_C, numH, numW, self.kernel_size[0], self.kernel_size[1])) | |||
| output_img = self.transpose(output_img, (0, 1, 2, 4, 3, 5)) | |||
| output_img = self.reshape(output_img, (N, org_C, org_H, org_W)) | |||
| return output_img | |||
| @@ -0,0 +1,125 @@ | |||
| """loss""" | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import mindspore.nn as nn | |||
| from mindspore import dtype as mstype | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops import composite as C | |||
| from mindspore.ops import functional as F | |||
# Gradient-clipping settings. In ClipGradients.construct a clip type of 0
# selects clip-by-value and 1 selects clip-by-norm.
# NOTE(review): presumably these are the values passed to ClipGradients by the
# training wrapper -- the call site is not visible here; confirm.
GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 1.0
class SupConLoss(nn.Cell):
    """Contrastive loss over multi-view features.

    Args:
        temperature (float): Divisor applied to the anchor/contrast dot products.
        contrast_mode (str): ``'all'`` uses every view as an anchor; any other
            value uses only the first view (``features[:, 0]``).
        base_temperature (float): The loss is scaled by
            ``temperature / base_temperature``.
    """

    def __init__(self, temperature=0.07, contrast_mode='all',
                 base_temperature=0.07):
        super(SupConLoss, self).__init__()
        # Hyper-parameters.
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature

        # Shape / layout operators.
        self.unbind = P.Unstack(axis=1)
        self.cat = P.Concat(axis=0)
        self.transpose = P.Transpose()
        self.tile = P.Tile()
        self.reshape = P.Reshape()
        self.scatter = P.ScatterNd()
        self.oneslike = P.OnesLike()
        self.eye = P.Eye()

        # Arithmetic / reduction operators.
        self.normalize = P.L2Normalize(axis=2)
        self.matmul = P.MatMul()
        self.div = P.Div()
        self.maxes = P.ArgMaxWithValue(axis=1, keep_dims=True)
        self.exp = P.Exp()
        self.log = P.Log()
        self.sum = P.ReduceSum(keep_dims=True)
        self.mean = P.ReduceMean()

    def construct(self, features):
        """Compute the loss for ``features``.

        ``features`` is L2-normalised along axis 2 and unstacked along axis 1,
        so it is expected to be laid out as (batch, views, feature_dim).

        Returns:
            Tuple ``(loss, anchor_count)``.
        """
        features = self.normalize(features)
        batch_size = features.shape[0]
        contrast_count = features.shape[1]

        # Stack all views along the batch axis: (views * batch, feature_dim).
        contrast_feature = self.cat(self.unbind(features))
        if self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            anchor_feature = features[:, 0]
            anchor_count = 1

        similarity = self.matmul(anchor_feature,
                                 self.transpose(contrast_feature, (1, 0)))
        anchor_dot_contrast = self.div(similarity, self.temperature)
        # Subtract the row-wise maximum for numerical stability before exp().
        _, logits_max = self.maxes(anchor_dot_contrast)
        logits = anchor_dot_contrast - logits_max

        # Positives: same sample index in any view, excluding the anchor itself.
        positive_mask = self.tile(self.eye(batch_size, batch_size, mstype.float32),
                                  (anchor_count, contrast_count))
        not_self_mask = 1 - self.eye(positive_mask.shape[0], positive_mask.shape[1],
                                     mstype.float32)
        positive_mask = positive_mask * not_self_mask

        exp_logits = self.exp(logits) * not_self_mask
        log_prob = logits - self.log(self.sum(exp_logits, 1) + 1e-8)
        mean_log_prob_pos = self.sum((positive_mask * log_prob), 1) / self.sum(positive_mask, 1)

        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = self.mean(self.reshape(loss, (anchor_count, batch_size)))
        return loss, anchor_count
class ClipGradients(nn.Cell):
    """
    Clip gradients.

    Args:
        grads (tuple): Gradients to clip, iterated one by one.
        clip_type (int): 0 clips each gradient by value to
            ``[-clip_value, clip_value]``; 1 clips each gradient by norm.
            Any other value returns ``grads`` unchanged.
        clip_value (float): Clipping threshold.

    Returns:
        Tuple of clipped gradients, each cast back to its original dtype.
    """

    def __init__(self):
        super(ClipGradients, self).__init__()
        self.clip_by_norm = nn.ClipByNorm()
        self.cast = P.Cast()
        self.dtype = P.DType()

    def construct(self, grads, clip_type, clip_value):
        """Clip every gradient in ``grads`` according to ``clip_type``."""
        if clip_type not in (0, 1):
            return grads

        clipped_grads = ()
        for grad in grads:
            grad_dtype = self.dtype(grad)
            # The threshold is materialised in the gradient's own dtype.
            upper = self.cast(F.tuple_to_array((clip_value,)), grad_dtype)
            if clip_type == 0:
                lower = self.cast(F.tuple_to_array((-clip_value,)), grad_dtype)
                clipped = C.clip_by_value(grad, lower, upper)
            else:
                clipped = self.clip_by_norm(grad, upper)
            clipped = self.cast(clipped, grad_dtype)
            clipped_grads = clipped_grads + (clipped,)
        return clipped_grads
grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()


@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
    """Return ``grad`` multiplied by ``1 / scale``, with the factor cast to ``grad``'s dtype."""
    inverse_scale = F.cast(reciprocal(scale), F.dtype(grad))
    return grad * inverse_scale
_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
grad_overflow = P.FloatStatus()


@_grad_overflow.register("Tensor")
def _tensor_grad_overflow(grad):
    """Return the ``P.FloatStatus`` result for ``grad``."""
    status = grad_overflow(grad)
    return status
| @@ -1,4 +1,4 @@ | |||
| '''metrics''' | |||
| """metrics""" | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| @@ -13,12 +13,12 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import math | |||
| import numpy as np | |||
| def quantize(img, rgb_range): | |||
| '''metrics''' | |||
| """metrics""" | |||
| pixel_range = 255 / rgb_range | |||
| img = np.multiply(img, pixel_range) | |||
| img = np.clip(img, 0, 255) | |||
| @@ -26,15 +26,14 @@ def quantize(img, rgb_range): | |||
| return img | |||
| def calc_psnr(sr, hr, scale, rgb_range, y_only=False, dataset=None): | |||
| '''metrics''' | |||
| def calc_psnr(sr, hr, scale, rgb_range): | |||
| """metrics""" | |||
| hr = np.float32(hr) | |||
| sr = np.float32(sr) | |||
| diff = (sr - hr) / rgb_range | |||
| gray_coeffs = np.array([65.738, 129.057, 25.064] | |||
| ).reshape((1, 3, 1, 1)) / 256 | |||
| gray_coeffs = np.array([65.738, 129.057, 25.064]).reshape((1, 3, 1, 1)) / 256 | |||
| diff = np.multiply(diff, gray_coeffs).sum(1) | |||
| if np.size(hr) == 1: | |||
| if hr.size == 1: | |||
| return 0 | |||
| if scale != 1: | |||
| shave = scale | |||
| @@ -49,7 +48,7 @@ def calc_psnr(sr, hr, scale, rgb_range, y_only=False, dataset=None): | |||
| def rgb2ycbcr(img, y_only=True): | |||
| '''metrics''' | |||
| """metrics""" | |||
| img.astype(np.float32) | |||
| if y_only: | |||
| rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0 | |||
| @@ -1,67 +0,0 @@ | |||
| '''temp''' | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
def set_template(args):
    """Overwrite fields of ``args`` according to keywords found in ``args.template``.

    Every keyword that occurs as a substring of ``args.template`` applies its
    overrides, in the order listed below, so a later match wins when two
    presets set the same field. ``args`` is modified in place; nothing is
    returned.
    """
    presets = (
        ('jpeg', {
            'data_train': 'DIV2K_jpeg',
            'data_test': 'DIV2K_jpeg',
            'epochs': 200,
            'decay': '100',
        }),
        ('EDSR_paper', {
            'model': 'EDSR',
            'n_resblocks': 32,
            'n_feats': 256,
            'res_scale': 0.1,
        }),
        ('MDSR', {
            'model': 'MDSR',
            'patch_size': 48,
            'epochs': 650,
        }),
        ('DDBPN', {
            'model': 'DDBPN',
            'patch_size': 128,
            'scale': '4',
            'data_test': 'Set5',
            'batch_size': 20,
            'epochs': 1000,
            'decay': '500',
            'gamma': 0.1,
            'weight_decay': 1e-4,
            'loss': '1*MSE',
        }),
        ('GAN', {
            'epochs': 200,
            'lr': 5e-5,
            'decay': '150',
        }),
        ('RCAN', {
            'model': 'RCAN',
            'n_resgroups': 10,
            'n_resblocks': 20,
            'n_feats': 64,
            'chop': True,
        }),
        ('VDSR', {
            'model': 'VDSR',
            'n_resblocks': 20,
            'n_feats': 64,
            'patch_size': 41,
            'lr': 1e-1,
        }),
    )
    for keyword, overrides in presets:
        if keyword in args.template:
            for field, value in overrides.items():
                setattr(args, field, value)
| @@ -0,0 +1,173 @@ | |||
| """utils""" | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import os | |||
| import time | |||
| from bisect import bisect_right | |||
| import numpy as np | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.common import dtype as mstype | |||
| from mindspore.context import ParallelMode | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops import composite as C | |||
| from mindspore.ops import functional as F | |||
| from mindspore.nn.wrap.grad_reducer import DistributedGradReducer | |||
| from mindspore.parallel._utils import _get_parallel_mode | |||
| from mindspore.train.serialization import save_checkpoint | |||
| from src.loss import SupConLoss | |||
class MyTrain(nn.Cell):
    """Cell that runs the model on a batch and returns the training loss.

    With ``use_con`` enabled the model is expected to return a pair
    ``(output, features)``; the loss is then ``criterion(output, hr)`` plus
    0.1 times ``con_loss(features)``.  Otherwise the model returns only the
    output and the loss is ``criterion(output, hr)``.
    """

    def __init__(self, model, criterion, con_loss, use_con=True):
        """Store the wrapped model, both loss functions and the mode flag."""
        super().__init__(auto_prefix=True)
        self.use_con = use_con
        self.model = model
        self.con_loss = con_loss
        self.criterion = criterion
        self.p = P.Print()
        self.cast = P.Cast()

    def construct(self, lr, hr, idx):
        """Compute the loss for inputs ``lr``, targets ``hr`` and task tensor ``idx``."""
        if self.use_con:
            restored, features = self.model(lr, idx)
            # Model outputs are cast to float32 before entering the losses.
            features = self.cast(features, mstype.float32)
            restored = self.cast(restored, mstype.float32)
            pixel_loss = self.criterion(restored, hr)
            contrast_loss = self.con_loss(features)
            loss = pixel_loss + 0.1 * contrast_loss
        else:
            restored = self.model(lr, idx)
            restored = self.cast(restored, mstype.float32)
            loss = self.criterion(restored, hr)
        return loss
class MyTrainOneStepCell(nn.Cell):
    """Cell performing one optimisation step: forward, backward, update.

    Args:
        network: cell returning the loss for the given inputs.
        optimizer: optimizer whose ``parameters`` are the weights to update.
        sens: value used to fill the initial gradient fed to back-propagation
            (i.e. the loss-scale factor applied to the gradients).
    """

    def __init__(self, network, optimizer, sens=1.0):
        super(MyTrainOneStepCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens
        self.reducer_flag = False
        self.grad_reducer = None
        parallel_mode = _get_parallel_mode()
        if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
            self.reducer_flag = True
        if self.reducer_flag:
            # Local import: this module does not import ``context`` at file level.
            from mindspore import context
            # Average over the actual number of devices instead of assuming 8,
            # so gradients are scaled correctly for any cluster size.
            degree = context.get_auto_parallel_context("device_num")
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, True, degree)

    def construct(self, *args):
        """Run one training step on ``args`` and return the network's loss."""
        weights = self.weights
        loss = self.network(*args)
        # Initial gradient: a tensor shaped and typed like the loss, filled with ``sens``.
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = self.grad(self.network, weights)(*args, sens)
        if self.reducer_flag:
            grads = self.grad_reducer(grads)
        # ``depend`` ties the optimizer update to the returned loss so it is executed.
        return F.depend(loss, self.optimizer(grads))
def sub_mean(x):
    """Subtract the per-channel means from ``x`` in place and return ``x``.

    ``x`` is indexed as a 4-D array whose axis 1 holds the red, green and
    blue channels (in that order); the means are expressed on a 0-255 scale.
    """
    channel_means = (0.4488 * 255, 0.4371 * 255, 0.4040 * 255)  # red, green, blue
    for channel, mean in enumerate(channel_means):
        x[:, channel, :, :] -= mean
    return x
def add_mean(x):
    """Add the per-channel means to ``x`` in place and return ``x``.

    ``x`` is indexed as a 4-D array whose axis 1 holds the red, green and
    blue channels (in that order); the means are expressed on a 0-255 scale.
    """
    channel_means = (0.4488 * 255, 0.4371 * 255, 0.4040 * 255)  # red, green, blue
    for channel, mean in enumerate(channel_means):
        x[:, channel, :, :] += mean
    return x
class Trainer():
    """Drive training of a model: build loss/optimizer cells, run epochs, adjust LR.

    Args:
        args: configuration object; the fields read here are ``scale``, ``lr``,
            ``con_loss``, ``save``, ``rank``, ``decay`` and ``gamma``.
        loader: iterable yielding dict batches with keys ``"LR"``, ``"HR"``
            and ``"idx"``.
        my_model: the network to train.
    """

    def __init__(self, args, loader, my_model):
        self.args = args
        self.scale = args.scale
        self.trainloader = loader
        self.model = my_model
        self.model.set_train()
        self.criterion = nn.L1Loss()
        self.con_loss = SupConLoss()
        # The same loss-scale value (1024.0) is given to the optimizer and to the train-step cell.
        self.optimizer = nn.Adam(self.model.trainable_params(), learning_rate=args.lr, loss_scale=1024.0)
        self.train_net = MyTrain(self.model, self.criterion, self.con_loss, use_con=args.con_loss)
        self.bp = MyTrainOneStepCell(self.train_net, self.optimizer, 1024.0)
        # Defined up front so train() works even if update_learning_rate() was not called first.
        self.epoch = 0

    def train(self):
        """Run one pass over the loader, then save a checkpoint on rank 0."""
        for batch_idx, imgs in enumerate(self.trainloader):
            lr = Tensor(sub_mean(imgs["LR"]), mstype.float32)
            hr = Tensor(sub_mean(imgs["HR"]), mstype.float32)
            # The first "idx" entry is used as the length of a ones tensor;
            # that length is what is reported as "Task" below.
            idx = Tensor(np.ones(imgs["idx"][0]), mstype.int32)
            t1 = time.time()
            loss = self.bp(lr, hr, idx)
            t2 = time.time()
            print('Task: %g, Step: %g, loss: %f, time: %f s' % (idx.shape[0], batch_idx, loss.asnumpy(), t2 - t1),
                  flush=True)
        os.makedirs(self.args.save, exist_ok=True)
        if self.args.rank == 0:
            # Join path components so the file lands inside ``args.save``
            # whether or not that path ends with a separator.
            ckpt_name = "model_" + str(self.epoch) + '.ckpt'
            save_checkpoint(self.bp, os.path.join(self.args.save, ckpt_name))

    def update_learning_rate(self, epoch):
        """Record the current epoch and apply the step-decayed learning rate.

        ``args.decay`` is a '-'-separated list of milestone epochs; the rate is
        ``args.lr * args.gamma ** k`` where ``k`` is the number of milestones
        that are less than or equal to ``epoch``.

        :param epoch: current epoch
        :type epoch: int
        """
        self.epoch = epoch
        milestones = sorted(int(item) for item in self.args.decay.split('-'))
        print("*********** epoch: {} **********".format(epoch))
        lr = self.args.lr * self.args.gamma ** bisect_right(milestones, epoch)
        self.adjust_lr('model', self.optimizer, lr)
        print("*********************************")

    def adjust_lr(self, name, optimizer, lr):
        """Assign ``lr`` to the optimizer's learning-rate parameter and print it.

        :param name: label used in the printed message
        :type name: str
        :param optimizer: optimizer exposing ``get_lr()``
        :param lr: learning rate to be set
        :type lr: float
        """
        lr_param = optimizer.get_lr()
        lr_param.assign_value(Tensor(lr, mstype.float32))
        print('==> ' + name + ' learning rate: ', lr_param.asnumpy())
| @@ -0,0 +1,154 @@ | |||
| """train""" | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import os | |||
| import math | |||
| import mindspore.dataset as ds | |||
| from mindspore import Parameter, set_seed, context | |||
| from mindspore.context import ParallelMode | |||
| from mindspore.common.initializer import initializer, HeUniform, XavierUniform, Uniform, Normal, Zero | |||
| from mindspore.communication.management import init, get_rank, get_group_size | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from src.args import args | |||
| from src.data.bicubic import bicubic | |||
| from src.data.imagenet import ImgData | |||
| from src.ipt_model import IPT | |||
| from src.utils import Trainer | |||
| def _calculate_fan_in_and_fan_out(shape): | |||
| """ | |||
| calculate fan_in and fan_out | |||
| Args: | |||
| shape (tuple): input shape. | |||
| Returns: | |||
| Tuple, a tuple with two elements, the first element is `n_in` and the second element is `n_out`. | |||
| """ | |||
| dimensions = len(shape) | |||
| if dimensions < 2: | |||
| raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions") | |||
| if dimensions == 2: | |||
| fan_in = shape[1] | |||
| fan_out = shape[0] | |||
| else: | |||
| num_input_fmaps = shape[1] | |||
| num_output_fmaps = shape[0] | |||
| receptive_field_size = 1 | |||
| if dimensions > 2: | |||
| receptive_field_size = shape[2] * shape[3] | |||
| fan_in = num_input_fmaps * receptive_field_size | |||
| fan_out = num_output_fmaps * receptive_field_size | |||
| return fan_in, fan_out | |||
def init_weights(net, init_type='normal', init_gain=0.02):
    """
    Initialize network weights.

    Every cell yielded by ``net.cells_and_names()`` is visited:

    * an ``in_proj_layer`` attribute is re-created with HeUniform;
    * a ``weight`` attribute is re-created according to ``init_type``, and a
      non-None ``bias`` is re-created with ``Uniform(1 / sqrt(fan_in))``;
    * cells without ``weight`` whose class name contains ``BatchNorm2d`` get
      ``gamma`` from ``Normal(1.0)`` and ``beta`` set to zero.

    :param net: network to be initialized
    :type net: nn.Cell
    :param init_type: the name of an initialization method: normal | xavier | he
    :type init_type: str
    :param init_gain: argument passed to the Normal / XavierUniform initializers
        (not used by ``he``).
    :type init_gain: float
    :raises NotImplementedError: if ``init_type`` is not one of the names above
        and a cell with a ``weight`` attribute is encountered.
    """
    for _, cell in net.cells_and_names():
        classname = cell.__class__.__name__
        if hasattr(cell, 'in_proj_layer'):
            cell.in_proj_layer = Parameter(initializer(HeUniform(negative_slope=math.sqrt(5)), cell.in_proj_layer.shape,
                                                       cell.in_proj_layer.dtype), name=cell.in_proj_layer.name)
        if hasattr(cell, 'weight'):
            if init_type == 'normal':
                cell.weight = Parameter(initializer(Normal(init_gain), cell.weight.shape,
                                                    cell.weight.dtype), name=cell.weight.name)
            elif init_type == 'xavier':
                cell.weight = Parameter(initializer(XavierUniform(init_gain), cell.weight.shape,
                                                    cell.weight.dtype), name=cell.weight.name)
            elif init_type == "he":
                cell.weight = Parameter(initializer(HeUniform(negative_slope=math.sqrt(5)), cell.weight.shape,
                                                    cell.weight.dtype), name=cell.weight.name)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(cell, 'bias') and cell.bias is not None:
                fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight.shape)
                bound = 1 / math.sqrt(fan_in)
                cell.bias = Parameter(initializer(Uniform(bound), cell.bias.shape, cell.bias.dtype),
                                      name=cell.bias.name)
        elif classname.find('BatchNorm2d') != -1:
            # Read the shape through the ``shape`` property, like the branches
            # above, rather than the stale ``default_input.shape()`` accessor.
            cell.gamma = Parameter(initializer(Normal(1.0), cell.gamma.shape), name=cell.gamma.name)
            cell.beta = Parameter(initializer(Zero(), cell.beta.shape), name=cell.beta.name)
    print('initialize network weight with %s' % init_type)
def train_net(distribute, imagenet, epochs):
    """Train net

    Builds the ImageNet-based dataset pipeline, creates and initialises the
    IPT network (optionally loading ``args.pth_path``) and trains it for
    ``epochs`` epochs.

    Args:
        distribute: truthy to initialise communication and shard the dataset
            across devices in data-parallel mode.
        imagenet: must be 1; this script's batching pipeline only handles the
            columns produced by ``ImgData``.
        epochs (int): number of epochs to run.

    Raises:
        ValueError: if ``imagenet`` is not 1.
    """
    set_seed(1)
    device_id = int(os.getenv('DEVICE_ID', '0'))
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)

    if imagenet != 1:
        # The previous fallback referenced a module that is never imported,
        # and the project/batch stages below require the ImgData columns, so
        # fail early with an explicit message instead of a NameError.
        raise ValueError("train.py only supports imagenet == 1, got {!r}".format(imagenet))
    train_dataset = ImgData(args)
    source_columns = ["HR", "Rain", "LRx2", "LRx3", "LRx4", "scales", "filename"]

    if distribute:
        init()
        rank_id = get_rank()
        rank_size = get_group_size()
        parallel_mode = ParallelMode.DATA_PARALLEL
        context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=rank_size, gradients_mean=True)
        print('Rank {}, rank_size {}'.format(rank_id, rank_size))
        # Shard by the rank reported by the communication layer so every
        # device reads its own shard.
        train_de_dataset = ds.GeneratorDataset(train_dataset, source_columns,
                                               num_shards=rank_size, shard_id=rank_id, shuffle=True)
    else:
        train_de_dataset = ds.GeneratorDataset(train_dataset, source_columns, shuffle=True)

    resize_fuc = bicubic()
    # "scales" is dropped before batching; resize_fuc.forward maps the
    # remaining six columns to the four output columns.
    train_de_dataset = train_de_dataset.project(columns=["HR", "Rain", "LRx2", "LRx3", "LRx4", "filename"])
    train_de_dataset = train_de_dataset.batch(args.batch_size,
                                              input_columns=["HR", "Rain", "LRx2", "LRx3", "LRx4", "filename"],
                                              output_columns=["LR", "HR", "idx", "filename"],
                                              drop_remainder=True, per_batch_map=resize_fuc.forward)
    train_loader = train_de_dataset.create_dict_iterator(output_numpy=True)

    net_work = IPT(args)
    init_weights(net_work, init_type='he', init_gain=1.0)
    print("Init net weight successfully")
    if args.pth_path:
        param_dict = load_checkpoint(args.pth_path)
        load_param_into_net(net_work, param_dict)
        print("Load net weight successfully")

    train_func = Trainer(args, train_loader, net_work)
    for epoch in range(0, epochs):
        train_func.update_learning_rate(epoch)
        train_func.train()


if __name__ == '__main__':
    train_net(distribute=args.distribute, imagenet=args.imagenet, epochs=args.epochs)
| @@ -0,0 +1,95 @@ | |||
| """train finetune""" | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import os | |||
| from mindspore import context | |||
| from mindspore.context import ParallelMode | |||
| import mindspore.dataset as ds | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from mindspore.communication.management import init, get_rank, get_group_size | |||
| from mindspore.common import set_seed | |||
| from src.args import args | |||
| from src.data.imagenet import ImgData | |||
| from src.data.srdata import SRData | |||
| from src.data.div2k import DIV2K | |||
| from src.data.bicubic import bicubic | |||
| from src.ipt_model import IPT | |||
| from src.utils import Trainer | |||
def train_net(distribute, imagenet):
    """Train net with finetune

    Selects the dataset (``ImgData`` when ``imagenet == 1``, otherwise
    ``DIV2K`` or, when ``args.derain`` is set, ``SRData``), builds the batched
    loader, creates the IPT network (optionally loading ``args.pth_path``) and
    trains it for ``args.epochs`` epochs.

    Args:
        distribute: truthy to initialise communication and shard the dataset
            across devices in data-parallel mode.
        imagenet: 1 to use the ImageNet pipeline with on-the-fly resizing; any
            other value uses the pre-built LR/HR datasets.
    """
    set_seed(1)
    device_id = int(os.getenv('DEVICE_ID', '0'))
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)

    use_imagenet = imagenet == 1
    if use_imagenet:
        train_dataset = ImgData(args)
        column_names = ["HR", "Rain", "LRx2", "LRx3", "LRx4", "scales", "filename"]
    else:
        dataset_cls = SRData if args.derain else DIV2K
        train_dataset = dataset_cls(args, name=args.data_train, train=True, benchmark=False)
        train_dataset.set_scale(args.task_id)
        column_names = ["LR", "HR", "idx", "filename"]

    # Sharding arguments are only supplied in distributed mode.
    shard_kwargs = {}
    if distribute:
        init()
        rank_id = get_rank()
        rank_size = get_group_size()
        parallel_mode = ParallelMode.DATA_PARALLEL
        context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=rank_size, gradients_mean=True)
        print('Rank {}, group_size {}'.format(rank_id, rank_size))
        shard_kwargs = {"num_shards": rank_size, "shard_id": rank_id}
    train_de_dataset = ds.GeneratorDataset(train_dataset, column_names, shuffle=True, **shard_kwargs)

    # Decide on the function argument (not the global ``args.imagenet``) so
    # the batching stage always matches the dataset chosen above.
    if use_imagenet:
        resize_fuc = bicubic()
        train_de_dataset = train_de_dataset.batch(
            args.batch_size,
            input_columns=["HR", "Rain", "LRx2", "LRx3", "LRx4", "scales", "filename"],
            output_columns=["LR", "HR", "idx", "filename"], drop_remainder=True,
            per_batch_map=resize_fuc.forward)
    else:
        train_de_dataset = train_de_dataset.batch(args.batch_size, drop_remainder=True)
    train_loader = train_de_dataset.create_dict_iterator(output_numpy=True)

    net_m = IPT(args)
    print("Init net weights successfully")
    if args.pth_path:
        param_dict = load_checkpoint(args.pth_path)
        load_param_into_net(net_m, param_dict)
        print("Load net weight successfully")

    train_func = Trainer(args, train_loader, net_m)
    for epoch in range(0, args.epochs):
        train_func.update_learning_rate(epoch)
        train_func.train()


if __name__ == "__main__":
    train_net(distribute=args.distribute, imagenet=args.imagenet)