Browse Source

!15439 add IPT net

From: @Somnus2020
Reviewed-by: @oacjiewen,@c_34,@linqingke
Signed-off-by: @linqingke
pull/15439/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
33fd9ca01b
21 changed files with 1385 additions and 756 deletions
  1. +51
    -18
      model_zoo/research/cv/IPT/eval.py
  2. +0
    -26
      model_zoo/research/cv/IPT/mindpsore_hub_conf.py
  3. +57
    -18
      model_zoo/research/cv/IPT/readme.md
  4. +40
    -0
      model_zoo/research/cv/IPT/scripts/run_distributed.sh
  5. +43
    -12
      model_zoo/research/cv/IPT/scripts/run_eval.sh
  6. +43
    -0
      model_zoo/research/cv/IPT/scripts/run_finetune_distributed.sh
  7. +12
    -12
      model_zoo/research/cv/IPT/src/args.py
  8. +0
    -35
      model_zoo/research/cv/IPT/src/data/__init__.py
  9. +132
    -0
      model_zoo/research/cv/IPT/src/data/bicubic.py
  10. +24
    -19
      model_zoo/research/cv/IPT/src/data/common.py
  11. +45
    -0
      model_zoo/research/cv/IPT/src/data/div2k.py
  12. +171
    -0
      model_zoo/research/cv/IPT/src/data/imagenet.py
  13. +64
    -106
      model_zoo/research/cv/IPT/src/data/srdata.py
  14. +0
    -241
      model_zoo/research/cv/IPT/src/foldunfold_stride.py
  15. +148
    -193
      model_zoo/research/cv/IPT/src/ipt_model.py
  16. +125
    -0
      model_zoo/research/cv/IPT/src/loss.py
  17. +8
    -9
      model_zoo/research/cv/IPT/src/metrics.py
  18. +0
    -67
      model_zoo/research/cv/IPT/src/template.py
  19. +173
    -0
      model_zoo/research/cv/IPT/src/utils.py
  20. +154
    -0
      model_zoo/research/cv/IPT/train.py
  21. +95
    -0
      model_zoo/research/cv/IPT/train_finetune.py

+ 51
- 18
model_zoo/research/cv/IPT/eval.py View File

@@ -1,6 +1,6 @@
"""eval script""" """eval script"""
# Copyright 2021 Huawei Technologies Co., Ltd # Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
@@ -13,48 +13,82 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
import os
import numpy as np import numpy as np
from src import ipt
import mindspore.dataset as ds
from mindspore import Tensor, context
from mindspore.common import dtype as mstype
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.args import args from src.args import args
import src.ipt_model as ipt
from src.data.srdata import SRData from src.data.srdata import SRData
from src.metrics import calc_psnr, quantize from src.metrics import calc_psnr, quantize
from mindspore import context
import mindspore.dataset as de
from mindspore.train.serialization import load_checkpoint, load_param_into_net
device_id = int(os.getenv('DEVICE_ID', '0'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id, save_graphs=False)
context.set_context(max_call_depth=10000)
context.set_context(mode=context.GRAPH_MODE, device_target="ASCEND", device_id=0)
def sub_mean(x):
red_channel_mean = 0.4488 * 255
green_channel_mean = 0.4371 * 255
blue_channel_mean = 0.4040 * 255
x[:, 0, :, :] -= red_channel_mean
x[:, 1, :, :] -= green_channel_mean
x[:, 2, :, :] -= blue_channel_mean
return x
def add_mean(x):
red_channel_mean = 0.4488 * 255
green_channel_mean = 0.4371 * 255
blue_channel_mean = 0.4040 * 255
x[:, 0, :, :] += red_channel_mean
x[:, 1, :, :] += green_channel_mean
x[:, 2, :, :] += blue_channel_mean
return x
def main():
def eval_net():
"""eval""" """eval"""
args.batch_size = 128
args.decay = 70
args.patch_size = 48
args.num_queries = 6
args.model = 'vtip'
args.num_layers = 4
if args.epochs == 0:
args.epochs = 1e8
for arg in vars(args): for arg in vars(args):
if vars(args)[arg] == 'True': if vars(args)[arg] == 'True':
vars(args)[arg] = True vars(args)[arg] = True
elif vars(args)[arg] == 'False': elif vars(args)[arg] == 'False':
vars(args)[arg] = False vars(args)[arg] = False
train_dataset = SRData(args, name=args.data_test, train=False, benchmark=False) train_dataset = SRData(args, name=args.data_test, train=False, benchmark=False)
train_de_dataset = de.GeneratorDataset(train_dataset, ['LR', "HR"], shuffle=False)
train_de_dataset = ds.GeneratorDataset(train_dataset, ['LR', 'HR', "idx", "filename"], shuffle=False)
train_de_dataset = train_de_dataset.batch(1, drop_remainder=True) train_de_dataset = train_de_dataset.batch(1, drop_remainder=True)
train_loader = train_de_dataset.create_dict_iterator()
train_loader = train_de_dataset.create_dict_iterator(output_numpy=True)
net_m = ipt.IPT(args) net_m = ipt.IPT(args)
print('load mindspore net successfully.')
if args.pth_path: if args.pth_path:
param_dict = load_checkpoint(args.pth_path) param_dict = load_checkpoint(args.pth_path)
load_param_into_net(net_m, param_dict) load_param_into_net(net_m, param_dict)
net_m.set_train(False) net_m.set_train(False)
idx = Tensor(np.ones(args.task_id), mstype.int32)
inference = ipt.IPT_post(net_m, args)
print('load mindspore net successfully.')
num_imgs = train_de_dataset.get_dataset_size() num_imgs = train_de_dataset.get_dataset_size()
psnrs = np.zeros((num_imgs, 1)) psnrs = np.zeros((num_imgs, 1))
inference = ipt.IPT_post(net_m, args)
for batch_idx, imgs in enumerate(train_loader): for batch_idx, imgs in enumerate(train_loader):
lr = imgs['LR'] lr = imgs['LR']
hr = imgs['HR'] hr = imgs['HR']
hr_np = np.float32(hr.asnumpy())
pred = inference.forward(lr)
pred_np = np.float32(pred.asnumpy())
lr = sub_mean(lr)
lr = Tensor(lr, mstype.float32)
pred = inference.forward(lr, idx)
pred_np = add_mean(pred.asnumpy())
pred_np = quantize(pred_np, 255) pred_np = quantize(pred_np, 255)
psnr = calc_psnr(pred_np, hr_np, 4, 255.0, y_only=True)
psnr = calc_psnr(pred_np, hr, 4, 255.0)
print("current psnr: ", psnr)
psnrs[batch_idx, 0] = psnr psnrs[batch_idx, 0] = psnr
if args.denoise: if args.denoise:
print('Mean psnr of %s DN_%s is %.4f' % (args.data_test[0], args.sigma, psnrs.mean(axis=0)[0])) print('Mean psnr of %s DN_%s is %.4f' % (args.data_test[0], args.sigma, psnrs.mean(axis=0)[0]))
@@ -63,7 +97,6 @@ def main():
else: else:
print('Mean psnr of %s x%s is %.4f' % (args.data_test[0], args.scale[0], psnrs.mean(axis=0)[0])) print('Mean psnr of %s x%s is %.4f' % (args.data_test[0], args.scale[0], psnrs.mean(axis=0)[0]))
if __name__ == '__main__': if __name__ == '__main__':
print("Start main function!")
main()
print("Start eval function!")
eval_net()

+ 0
- 26
model_zoo/research/cv/IPT/mindpsore_hub_conf.py View File

@@ -1,26 +0,0 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hub config."""
from src.vitm import ViT


def IPT(*args, **kwargs):
return ViT(*args, **kwargs)


def create_network(name, *args, **kwargs):
if name == 'IPT':
return IPT(*args, **kwargs)
raise NotImplementedError(f"{name} is not implemented in the repo")

+ 57
- 18
model_zoo/research/cv/IPT/readme.md View File

@@ -45,9 +45,9 @@ The result images are converted into YCbCr color space. The PSNR is evaluated on
## Requirements ## Requirements
### Hardware (GPU)
### Hardware (Ascend)
> Prepare hardware environment with GPU.
> Prepare hardware environment with Ascend.
### Framework ### Framework
@@ -67,34 +67,73 @@ The result images are converted into YCbCr color space. The PSNR is evaluated on
```bash ```bash
IPT IPT
├── eval.py # inference entry ├── eval.py # inference entry
├── train.py # pre-training entry
├── train_finetune.py # fine-tuning entry
├── image ├── image
│   └── ipt.png # the illustration of IPT network │   └── ipt.png # the illustration of IPT network
├── model
│   ├── IPT_denoise30.ckpt # denoise model weights for noise level 30
│   ├── IPT_denoise50.ckpt # denoise model weights for noise level 50
│   ├── IPT_derain.ckpt # derain model weights
│   ├── IPT_sr2.ckpt # X2 super-resolution model weights
│   ├── IPT_sr3.ckpt # X3 super-resolution model weights
│   └── IPT_sr4.ckpt # X4 super-resolution model weights
├── readme.md # Readme ├── readme.md # Readme
├── scripts ├── scripts
│   └── run_eval.sh # inference script for all tasks
│   ├── run_eval.sh # inference script for all tasks
│   ├── run_distributed.sh # pre-training script for all tasks
│   └── run_finetune_distributed.sh # fine-tuning script for all tasks
└── src └── src
├── args.py # options/hyper-parameters of IPT ├── args.py # options/hyper-parameters of IPT
├── data ├── data
│   ├── common.py # common dataset │   ├── common.py # common dataset
│   ├── __init__.py # Class data init function
│   └── srdata.py # flow of loading sr data
├── foldunfold_stride.py # function of fold and unfold operations for images
│   ├── bicubic.py # scripts for data pre-processing
│   ├── div2k.py # DIV2K dataset
│   ├── imagenet.py # Imagenet data for pre-training
│   └── srdata.py # All dataset
├── metrics.py # PSNR calculator ├── metrics.py # PSNR calculator
├── template.py # setting of model selection
└── vitm.py # IPT network
├── utils.py # training scripts
├── loss.py # contrastive_loss
└── ipt_model.py # IPT network
``` ```
### Script Parameter ### Script Parameter
> For details about hyperparameters, see src/args.py. > For details about hyperparameters, see src/args.py.
## Training Process
### For pre-training
```bash
python train.py --distribute --imagenet 1 --batch_size 64 --lr 5e-5 --scale 2+3+4+1+1+1 --alltask --react --model vtip --num_queries 6 --chop_new --num_layers 4 --data_train imagenet --dir_data $DATA_PATH --derain --save $SAVE_PATH
```
> Or one can run following script for all tasks.
```bash
sh scripts/run_distributed.sh RANK_TABLE_FILE DATA_PATH
```
### For fine-tuning
> For SR tasks:
```bash
python train_finetune.py --distribute --imagenet 0 --batch_size 64 --lr 2e-5 --scale 2+3+4+1+1+1 --model vtip --num_queries 6 --chop_new --num_layers 4 --task_id $TASK_ID --dir_data $DATA_PATH --pth_path $MODEL --epochs 50
```
> For Denoising tasks:
```bash
python train_finetune.py --distribute --imagenet 0 --batch_size 64 --lr 2e-5 --scale 2+3+4+1+1+1 --model vtip --num_queries 6 --chop_new --num_layers 4 --task_id $TASK_ID --dir_data $DATA_PATH --pth_path $MODEL --denoise --sigma $Noise --epochs 50
```
> For deraining tasks:
```bash
python train_finetune.py --distribute --imagenet 0 --batch_size 64 --lr 2e-5 --scale 2+3+4+1+1+1 --model vtip --num_queries 6 --chop_new --num_layers 4 --task_id $TASK_ID --dir_data $DATA_PATH --pth_path $MODEL --derain --epochs 50
```
> Or one can run following script for all tasks.
```bash
sh scripts/run_finetune_distributed.sh RANK_TABLE_FILE DATA_PATH MODEL TASK_ID
```
## Evaluation ## Evaluation
### Evaluation Process ### Evaluation Process
@@ -103,13 +142,13 @@ IPT
> For SR x4: > For SR x4:
```bash ```bash
python eval.py --dir_data ../../data/ --data_test Set14 --nochange --test_only --ext img --chop_new --scale 4 --pth_path ./model/IPT_sr4.ckpt
python eval.py --dir_data $DATA_PATH --data_test $DATA_TEST --test_only --ext img --pth_path $MODEL --task_id $TASK_ID --scale $SCALE
``` ```
> Or one can run following script for all tasks. > Or one can run following script for all tasks.
```bash ```bash
sh scripts/run_eval.sh
sh scripts/run_eval.sh DATA_PATH DATA_TEST MODEL TASK_ID
``` ```
### Evaluation Result ### Evaluation Result
@@ -117,7 +156,7 @@ sh scripts/run_eval.sh
The result are evaluated by the value of PSNR (Peak Signal-to-Noise Ratio), and the format is as following. The result are evaluated by the value of PSNR (Peak Signal-to-Noise Ratio), and the format is as following.
```bash ```bash
result: {"Mean psnr of Se5 x4 is 32.68"}
result: {"Mean psnr of Set5 x4 is 32.68"}
``` ```
## Performance ## Performance


+ 40
- 0
model_zoo/research/cv/IPT/scripts/run_distributed.sh View File

@@ -0,0 +1,40 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
RANK_TABLE_FILE=$(realpath $1)
export RANK_TABLE_FILE
export DATA_PATH=$2
echo "RANK_TABLE_FILE=${RANK_TABLE_FILE}"

export SERVER_ID=0
rank_start=$((DEVICE_NUM * SERVER_ID))
for((i=0; i<${DEVICE_NUM}; i++))
do
export DEVICE_ID=$i
export RANK_ID=$((rank_start + i))
rm -rf ./train_parallel$i
mkdir ./train_parallel$i
cp -r ../src ./train_parallel$i
cp ../*.py ./train_parallel$i
echo "start training for rank $RANK_ID, device $DEVICE_ID"
cd ./train_parallel$i ||exit
env > env$i.log
python train.py --distribute --imagenet 1 --batch_size 64 --lr 5e-5 --scale 2+3+4+1+1+1 --alltask --react --model vtip --num_queries 6 --chop_new --num_layers 4 --data_train imagenet --dir_data $DATA_PATH --derain --save experiments/ckpt_new_init/ > log 2>&1 &
cd ..
done

+ 43
- 12
model_zoo/research/cv/IPT/scripts/run_eval.sh View File

@@ -14,18 +14,49 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================


export DEVICE_ID=$1
DATA_DIR=$2
DATA_SET=$3
PATH_CHECKPOINT=$4
ulimit -u unlimited
export DATA_PATH=$1
export DATA_TEST=$2
export MODEL=$3
export TASK_ID=$4


python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 4 --pth_path=$PATH_CHECKPOINT > eval.log 2>&1 &
python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 3 --pth_path=$PATH_CHECKPOINT > eval.log 2>&1 &
python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 2 --pth_path=$PATH_CHECKPOINT > eval.log 2>&1 &
if [[ $TASK_ID -lt 3 ]]; then
mkdir ./run_eval$TASK_ID
cp -r ../src ./run_eval$TASK_ID
cp ../*.py ./run_eval$TASK_ID
echo "start evaluation for Task $TASK_ID, device $DEVICE_ID"
cd ./run_eval$TASK_ID ||exit
env > env$TASK_ID.log
SCALE=$[$TASK_ID+2]
python eval.py --dir_data $DATA_PATH --data_test $DATA_TEST --test_only --ext img --pth_path $MODEL --task_id $TASK_ID --scale $SCALE > log 2>&1 &
fi


##denoise
python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 1 --denoise --sigma 30 --pth_path=$PATH_CHECKPOINT > eval.log 2>&1 &
python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 1 --denoise --sigma 50 --pth_path=$PATH_CHECKPOINT > eval.log 2>&1 &
if [[ $TASK_ID -eq 3 ]]; then
mkdir ./run_eval$TASK_ID
cp -r ../src ./run_eval$TASK_ID
cp ../*.py ./run_eval$TASK_ID
echo "start evaluation for Task $TASK_ID, device $DEVICE_ID"
cd ./run_eval$TASK_ID ||exit
env > env$TASK_ID.log
python eval.py --dir_data $DATA_PATH --data_test $DATA_TEST --test_only --ext img --pth_path $MODEL --task_id $TASK_ID --scale 1 --derain > log 2>&1 &
fi


##derain
python ../eval.py --dir_data=$DATA_DIR --data_test=$DATA_SET --nochange --test_only --ext img --chop_new --scale 1 --derain --derain_test 1 --pth_path=$PATH_CHECKPOINT > eval.log 2>&1 &
if [[ $TASK_ID -eq 4 ]]; then
mkdir ./run_eval$TASK_ID
cp -r ../src ./run_eval$TASK_ID
cp ../*.py ./run_eval$TASK_ID
echo "start evaluation for Task $TASK_ID, device $DEVICE_ID"
cd ./run_eval$TASK_ID ||exit
env > env$TASK_ID.log
python eval.py --dir_data $DATA_PATH --data_test $DATA_TEST --test_only --ext img --pth_path $MODEL --task_id $TASK_ID --scale 1 --denoise --sigma 30 > log 2>&1 &
fi

if [[ $TASK_ID -eq 5 ]]; then
mkdir ./run_eval$TASK_ID
cp -r ../src ./run_eval$TASK_ID
cp ../*.py ./run_eval$TASK_ID
echo "start evaluation for Task $TASK_ID, device $DEVICE_ID"
cd ./run_eval$TASK_ID ||exit
env > env$TASK_ID.log
python eval.py --dir_data $DATA_PATH --data_test $DATA_TEST --test_only --ext img --pth_path $MODEL --task_id $TASK_ID --scale 1 --denoise --sigma 50 > log 2>&1 &
fi

+ 43
- 0
model_zoo/research/cv/IPT/scripts/run_finetune_distributed.sh View File

@@ -0,0 +1,43 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
RANK_TABLE_FILE=$(realpath $1)
export RANK_TABLE_FILE
export DATA_PATH=$2
export MODEL=$3
export TASK_ID=$4
echo "RANK_TABLE_FILE=${RANK_TABLE_FILE}"


export SERVER_ID=0
rank_start=$((DEVICE_NUM * SERVER_ID))
for((i=0; i<${DEVICE_NUM}; i++))
do
export DEVICE_ID=$i
export RANK_ID=$((rank_start + i))
rm -rf ./train_parallel$i
mkdir ./train_parallel$i
cp -r ../src ./train_parallel$i
cp ../*.py ./train_parallel$i
echo "start training for rank $RANK_ID, device $DEVICE_ID"
cd ./train_parallel$i ||exit
env > env$i.log
python train_finetune.py --distribute --imagenet 0 --batch_size 64 --lr 2e-5 --scale 2+3+4+1+1+1 --model vtip --num_queries 6 --chop_new --num_layers 4 --task_id $TASK_ID --dir_data $DATA_PATH --pth_path $MODEL --epochs 100 > log 2>&1 &
cd ..
done

+ 12
- 12
model_zoo/research/cv/IPT/src/args.py View File

@@ -1,4 +1,4 @@
'''args'''
"""args"""
# Copyright 2021 Huawei Technologies Co., Ltd # Copyright 2021 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,8 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================

import argparse import argparse
from src import template


parser = argparse.ArgumentParser(description='EDSR and MDSR') parser = argparse.ArgumentParser(description='EDSR and MDSR')


@@ -24,12 +24,6 @@ parser.add_argument('--template', default='.',
help='You can set various templates in option.py') help='You can set various templates in option.py')


# Hardware specifications # Hardware specifications
parser.add_argument('--n_threads', type=int, default=6,
help='number of threads for data loading')
parser.add_argument('--cpu', action='store_true',
help='use cpu only')
parser.add_argument('--n_GPUs', type=int, default=1,
help='number of GPUs')
parser.add_argument('--seed', type=int, default=1, parser.add_argument('--seed', type=int, default=1,
help='random seed') help='random seed')


@@ -60,9 +54,8 @@ parser.add_argument('--no_augment', action='store_true',
help='do not use data augmentation') help='do not use data augmentation')


# Model specifications # Model specifications
parser.add_argument('--model', default='vtip',
parser.add_argument('--model', default='EDSR',
help='model name') help='model name')

parser.add_argument('--act', type=str, default='relu', parser.add_argument('--act', type=str, default='relu',
help='activation function') help='activation function')
parser.add_argument('--pre_train', type=str, default='', parser.add_argument('--pre_train', type=str, default='',
@@ -139,6 +132,7 @@ parser.add_argument('--gclip', type=float, default=0,
help='gradient clipping threshold (0 = no clipping)') help='gradient clipping threshold (0 = no clipping)')


# Loss specifications # Loss specifications
parser.add_argument('--con_loss', action='store_true')
parser.add_argument('--loss', type=str, default='1*L1', parser.add_argument('--loss', type=str, default='1*L1',
help='loss function configuration') help='loss function configuration')
parser.add_argument('--skip_threshold', type=float, default='1e8', parser.add_argument('--skip_threshold', type=float, default='1e8',
@@ -161,6 +155,7 @@ parser.add_argument('--save_gt', action='store_true',
help='save low-resolution and high-resolution images together') help='save low-resolution and high-resolution images together')


parser.add_argument('--scalelr', type=int, default=0) parser.add_argument('--scalelr', type=int, default=0)

# cloud # cloud
parser.add_argument('--moxfile', type=int, default=1) parser.add_argument('--moxfile', type=int, default=1)
parser.add_argument('--imagenet', type=int, default=0) parser.add_argument('--imagenet', type=int, default=0)
@@ -169,10 +164,11 @@ parser.add_argument('--train_url', type=str, help='train_dir')
parser.add_argument('--pretrain', type=str, default='') parser.add_argument('--pretrain', type=str, default='')
parser.add_argument('--pth_path', type=str, default='') parser.add_argument('--pth_path', type=str, default='')
parser.add_argument('--load_query', type=int, default=0) parser.add_argument('--load_query', type=int, default=0)

# transformer # transformer
parser.add_argument('--patch_dim', type=int, default=3) parser.add_argument('--patch_dim', type=int, default=3)
parser.add_argument('--num_heads', type=int, default=12) parser.add_argument('--num_heads', type=int, default=12)
parser.add_argument('--num_layers', type=int, default=12)
parser.add_argument('--num_layers', type=int, default=4)
parser.add_argument('--dropout_rate', type=float, default=0) parser.add_argument('--dropout_rate', type=float, default=0)
parser.add_argument('--no_norm', action='store_true') parser.add_argument('--no_norm', action='store_true')
parser.add_argument('--post_norm', action='store_true') parser.add_argument('--post_norm', action='store_true')
@@ -192,8 +188,10 @@ parser.add_argument('--sigma', type=float, default=25)
parser.add_argument('--derain', action='store_true') parser.add_argument('--derain', action='store_true')
parser.add_argument('--finetune', action='store_true') parser.add_argument('--finetune', action='store_true')
parser.add_argument('--derain_test', type=int, default=10) parser.add_argument('--derain_test', type=int, default=10)

# alltask # alltask
parser.add_argument('--alltask', action='store_true') parser.add_argument('--alltask', action='store_true')
parser.add_argument('--task_id', type=int, default=0)


# dehaze # dehaze
parser.add_argument('--dehaze', action='store_true') parser.add_argument('--dehaze', action='store_true')
@@ -201,6 +199,7 @@ parser.add_argument('--dehaze_test', type=int, default=100)
parser.add_argument('--indoor', action='store_true') parser.add_argument('--indoor', action='store_true')
parser.add_argument('--outdoor', action='store_true') parser.add_argument('--outdoor', action='store_true')
parser.add_argument('--nochange', action='store_true') parser.add_argument('--nochange', action='store_true')

# deblur # deblur
parser.add_argument('--deblur', action='store_true') parser.add_argument('--deblur', action='store_true')
parser.add_argument('--deblur_test', type=int, default=1000) parser.add_argument('--deblur_test', type=int, default=1000)
@@ -210,6 +209,8 @@ parser.add_argument('--init_method', type=str,
default=None, help='master address') default=None, help='master address')
parser.add_argument('--rank', type=int, default=0, parser.add_argument('--rank', type=int, default=0,
help='Index of current task') help='Index of current task')
parser.add_argument('--group_size', type=int, default=0,
help='Index of current task')
parser.add_argument('--world_size', type=int, default=1, parser.add_argument('--world_size', type=int, default=1,
help='Total number of tasks') help='Total number of tasks')
parser.add_argument('--gpu', default=None, type=int, parser.add_argument('--gpu', default=None, type=int,
@@ -223,7 +224,6 @@ parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
parser.add_argument('--distribute', action='store_true') parser.add_argument('--distribute', action='store_true')


args, unparsed = parser.parse_known_args() args, unparsed = parser.parse_known_args()
template.set_template(args)


args.scale = [int(x) for x in args.scale.split("+")] args.scale = [int(x) for x in args.scale.split("+")]
args.data_train = args.data_train.split('+') args.data_train = args.data_train.split('+')


+ 0
- 35
model_zoo/research/cv/IPT/src/data/__init__.py View File

@@ -1,35 +0,0 @@
"""data"""
# Copyright 2021 Huawei Technologies Co., Ltd"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from importlib import import_module


class Data:
"""data"""

def __init__(self, args):
self.loader_train = None
self.loader_test = []
for d in args.data_test:
if d in ['Set5', 'Set14', 'B100', 'Urban100', 'Manga109', 'CBSD68', 'Rain100L', 'GOPRO_Large']:
m = import_module('data.benchmark')
testset = getattr(m, 'Benchmark')(args, train=False, name=d)
else:
module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
m = import_module('data.' + module_name.lower())
testset = getattr(m, module_name)(args, train=False, name=d)

self.loader_test.append(
testset
)

+ 132
- 0
model_zoo/research/cv/IPT/src/data/bicubic.py View File

@@ -0,0 +1,132 @@
"""bicubic"""
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np

class bicubic:
"""bicubic"""
def __init__(self, seed=0):
self.seed = seed
self.rand_fn = np.random.RandomState(self.seed)

def cubic(self, x):
absx2 = np.abs(x) * np.abs(x)
absx3 = np.abs(x) * np.abs(x) * np.abs(x)

condition1 = (np.abs(x) <= 1).astype(np.float32)
condition2 = ((np.abs(x) > 1) & (np.abs(x) <= 2)).astype(np.float32)

f = (1.5 * absx3 - 2.5 * absx2 + 1) * condition1 + (-0.5 * absx3 + 2.5 * absx2 - 4 * np.abs(x) + 2) * condition2
return f

def contribute(self, in_size, out_size, scale):
"""bicubic"""
kernel_width = 4
if scale < 1:
kernel_width = 4 / scale
x0 = np.arange(start=1, stop=out_size[0]+1).astype(np.float32)
x1 = np.arange(start=1, stop=out_size[1]+1).astype(np.float32)

u0 = x0 / scale + 0.5 * (1 - 1 / scale)
u1 = x1 / scale + 0.5 * (1 - 1 / scale)

left0 = np.floor(u0 - kernel_width / 2)
left1 = np.floor(u1 - kernel_width / 2)

width = np.ceil(kernel_width) + 2
indice0 = np.expand_dims(left0, axis=1) + \
np.expand_dims(np.arange(start=0, stop=width).astype(np.float32), axis=0)
indice1 = np.expand_dims(left1, axis=1) + \
np.expand_dims(np.arange(start=0, stop=width).astype(np.float32), axis=0)

mid0 = np.expand_dims(u0, axis=1) - np.expand_dims(indice0, axis=0)
mid1 = np.expand_dims(u1, axis=1) - np.expand_dims(indice1, axis=0)

if scale < 1:
weight0 = scale * self.cubic(mid0 * scale)
weight1 = scale * self.cubic(mid1 * scale)
else:
weight0 = self.cubic(mid0)
weight1 = self.cubic(mid1)

weight0 = weight0 / (np.expand_dims(np.sum(weight0, axis=2), 2))
weight1 = weight1 / (np.expand_dims(np.sum(weight1, axis=2), 2))

indice0 = np.expand_dims(np.minimum(np.maximum(1, indice0), in_size[0]), axis=0)
indice1 = np.expand_dims(np.minimum(np.maximum(1, indice1), in_size[1]), axis=0)
kill0 = np.equal(weight0, 0)[0][0]
kill1 = np.equal(weight1, 0)[0][0]

weight0 = weight0[:, :, kill0 == 0]
weight1 = weight1[:, :, kill1 == 0]

indice0 = indice0[:, :, kill0 == 0]
indice1 = indice1[:, :, kill1 == 0]

return weight0, weight1, indice0, indice1

def forward(self, hr, rain, lrx2, lrx3, lrx4, filename, batchInfo):
"""bicubic"""
idx = self.rand_fn.randint(0, 6)
if idx < 3:
if idx == 0:
scale = 1/2
hr = lrx2
elif idx == 1:
scale = 1/3
hr = lrx3
elif idx == 2:
scale = 1/4
hr = lrx4
hr = np.array(hr)
[_, _, h, w] = hr.shape
weight0, weight1, indice0, indice1 = self.contribute([h, w], [int(h * scale), int(w * scale)], scale)
weight0 = np.asarray(weight0[0], dtype=np.float32)

indice0 = np.asarray(indice0[0], dtype=np.float32).astype(np.long)
weight0 = np.expand_dims(np.expand_dims(np.expand_dims(weight0, axis=0), axis=1), axis=4)
out = hr[:, :, (indice0-1), :] * weight0
out = np.sum(out, axis=3)
A = np.transpose(out, (0, 1, 3, 2))

weight1 = np.asarray(weight1[0], dtype=np.float32)
weight1 = np.expand_dims(np.expand_dims(np.expand_dims(weight1, axis=0), axis=1), axis=4)
indice1 = np.asarray(indice1[0], dtype=np.float32).astype(np.long)
out = A[:, :, (indice1-1), :] * weight1
out = np.round(255 * np.transpose(np.sum(out, axis=3), (0, 1, 3, 2)))/255
out = np.clip(np.round(out), 0, 255)
lr = list(out)
hr = list(hr)
else:
if idx == 3:
hr = np.array(hr)
rain = np.array(rain)
lr = np.clip((rain + hr), 0, 255)
hr = list(hr)
lr = list(lr)
elif idx == 4:
hr = np.array(hr)
noise = np.random.randn(*hr.shape) * 30
lr = np.clip(noise + hr, 0, 255)
hr = list(hr)
lr = list(lr)
elif idx == 5:
hr = np.array(hr)
noise = np.random.randn(*hr.shape) * 50
lr = np.clip(noise + hr, 0, 255)
hr = list(hr)
lr = list(lr)
return lr, hr, [idx] * len(hr), filename

+ 24
- 19
model_zoo/research/cv/IPT/src/data/common.py View File

@@ -1,6 +1,6 @@
"""common""" """common"""
# Copyright 2021 Huawei Technologies Co., Ltd # Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
@@ -13,13 +13,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
import random


import random
import numpy as np import numpy as np
import skimage.color as sc



def get_patch(*args, patch_size=96, scale=2, multi=False, input_large=False):
def get_patch(*args, patch_size=96, scale=2, input_large=False):
"""common""" """common"""
ih, iw = args[0].shape[:2] ih, iw = args[0].shape[:2]


@@ -34,25 +32,19 @@ def get_patch(*args, patch_size=96, scale=2, multi=False, input_large=False):
else: else:
tx, ty = ix, iy tx, ty = ix, iy


ret = [
args[0][iy:iy + ip, ix:ix + ip, :],
*[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]]
]
ret = [args[0][iy:iy + ip, ix:ix + ip, :], *[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]]]


return ret return ret




def set_channel(*args, n_channels=3): def set_channel(*args, n_channels=3):
"""common""" """common"""

def _set_channel(img): def _set_channel(img):
if img.ndim == 2: if img.ndim == 2:
img = np.expand_dims(img, axis=2) img = np.expand_dims(img, axis=2)


c = img.shape[2] c = img.shape[2]
if n_channels == 1 and c == 3:
img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2)
elif n_channels == 3 and c == 1:
if n_channels == 3 and c == 1:
img = np.concatenate([img] * n_channels, 2) img = np.concatenate([img] * n_channels, 2)


return img[:, :, :n_channels] return img[:, :, :n_channels]
@@ -61,14 +53,11 @@ def set_channel(*args, n_channels=3):




def np2Tensor(*args, rgb_range=255): def np2Tensor(*args, rgb_range=255):
"""common"""

def _np2Tensor(img): def _np2Tensor(img):
np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1))) np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))
tensor = np_transpose.astype(np.float32)
tensor = tensor * (rgb_range / 255)
return tensor

input_data = np_transpose.astype(np.float32)
output = input_data * (rgb_range / 255)
return output
return [_np2Tensor(a) for a in args] return [_np2Tensor(a) for a in args]




@@ -79,6 +68,7 @@ def augment(*args, hflip=True, rot=True):
rot90 = rot and random.random() < 0.5 rot90 = rot and random.random() < 0.5


def _augment(img): def _augment(img):
"""common"""
if hflip: if hflip:
img = img[:, ::-1, :] img = img[:, ::-1, :]
if vflip: if vflip:
@@ -88,3 +78,18 @@ def augment(*args, hflip=True, rot=True):
return img return img


return [_augment(a) for a in args] return [_augment(a) for a in args]


def search(root, target="JPEG"):
"""srdata"""
item_list = []
items = os.listdir(root)
for item in items:
path = os.path.join(root, item)
if os.path.isdir(path):
item_list.extend(search(path, target))
elif path.split('/')[-1].startswith(target):
item_list.append(path)
elif target in (path.split('/')[-2], path.split('/')[-3], path.split('/')[-4]):
item_list.append(path)
return item_list

+ 45
- 0
model_zoo/research/cv/IPT/src/data/div2k.py View File

@@ -0,0 +1,45 @@
"""div2k"""
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import os
from src.data.srdata import SRData

class DIV2K(SRData):
"""DIV2K"""
def __init__(self, args, name='DIV2K', train=True, benchmark=False):
data_range = [r.split('-') for r in args.data_range.split('/')]
if train:
data_range = data_range[0]
else:
if args.test_only and len(data_range) == 1:
data_range = data_range[0]
else:
data_range = data_range[1]

self.begin, self.end = list(map(int, data_range))
super(DIV2K, self).__init__(args, name=name, train=train, benchmark=benchmark)

def _scan(self):
names_hr, names_lr = super(DIV2K, self)._scan()
names_hr = names_hr[self.begin - 1:self.end]
names_lr = [n[self.begin - 1:self.end] for n in names_lr]

return names_hr, names_lr

def _set_filesystem(self, dir_data):
super(DIV2K, self)._set_filesystem(dir_data)
self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic')

+ 171
- 0
model_zoo/research/cv/IPT/src/data/imagenet.py View File

@@ -0,0 +1,171 @@
"""imagent"""
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import os
import random
import io
from PIL import Image

import numpy as np
import imageio

def search(root, target="JPEG"):
"""imagent"""
item_list = []
items = os.listdir(root)
for item in items:
path = os.path.join(root, item)
if os.path.isdir(path):
item_list.extend(search(path, target))
elif path.split('.')[-1] == target:
item_list.append(path)
elif path.split('/')[-1].startswith(target):
item_list.append(path)
return item_list


def get_patch_img(img, patch_size=96, scale=2):
"""imagent"""
ih, iw = img.shape[:2]
tp = scale * patch_size
if (iw - tp) > -1 and (ih-tp) > 1:
ix = random.randrange(0, iw-tp+1)
iy = random.randrange(0, ih-tp+1)
hr = img[iy:iy+tp, ix:ix+tp, :3]
elif (iw - tp) > -1 and (ih - tp) <= -1:
ix = random.randrange(0, iw-tp+1)
hr = img[:, ix:ix+tp, :3]
pil_img = Image.fromarray(hr).resize((tp, tp), Image.BILINEAR)
hr = np.array(pil_img)
elif (iw - tp) <= -1 and (ih - tp) > -1:
iy = random.randrange(0, ih-tp+1)
hr = img[iy:iy+tp, :, :3]
pil_img = Image.fromarray(hr).resize((tp, tp), Image.BILINEAR)
hr = np.array(pil_img)
else:
pil_img = Image.fromarray(img).resize((tp, tp), Image.BILINEAR)
hr = np.array(pil_img)
return hr


class ImgData():
"""imagent"""
def __init__(self, args, train=True):
self.input_large = (args.model == 'VDSR')
self.scale = args.scale
self.idx_scale = 0
self.dataroot = args.dir_data
self.img_list = search(os.path.join(self.dataroot, "train"), "JPEG")
self.img_list.extend(search(os.path.join(self.dataroot, "val"), "JPEG"))
self.img_list = sorted(self.img_list)
self.train = train
self.args = args
self.len = len(self.img_list)
print("data length:", len(self.img_list))
if self.args.derain:
self.derain_dataroot = os.path.join(self.dataroot, "RainTrainL")
self.derain_img_list = search(self.derain_dataroot, "rainstreak")

def __len__(self):
return len(self.img_list)

def _get_index(self, idx):
return idx % len(self.img_list)

def _load_file(self, idx):
idx = self._get_index(idx)
f_lr = self.img_list[idx]
lr = imageio.imread(f_lr)
if len(lr.shape) == 2:
lr = np.dstack([lr, lr, lr])
return lr, f_lr

def _np2Tensor(self, img, rgb_range):
np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))
tensor = np_transpose.astype(np.float32)
tensor = tensor * (rgb_range / 255)
return tensor

def __getitem__(self, idx):
if self.args.model == 'vtip' and self.train and self.args.alltask:
lr, filename = self._load_file(idx % self.len)
pair_list = []
rain = self._load_rain()
rain = np.expand_dims(rain, axis=2)
rain = self.get_patch(rain, 1)
rain = self._np2Tensor(rain, rgb_range=self.args.rgb_range)
for idx_scale in range(4):
self.idx_scale = idx_scale
pair = self.get_patch(lr)
pair_t = self._np2Tensor(pair, rgb_range=self.args.rgb_range)
pair_list.append(pair_t)
return pair_list[3], rain, pair_list[0], pair_list[1], pair_list[2], [self.scale], [filename]
if self.args.model == 'vtip' and self.train and len(self.scale) > 1:
lr, filename = self._load_file(idx % self.len)
pair_list = []
for idx_scale in range(3):
self.idx_scale = idx_scale
pair = self.get_patch(lr)
pair_t = self._np2Tensor(pair, rgb_range=self.args.rgb_range)
pair_list.append(pair_t)
return pair_list[0], pair_list[1], pair_list[2], filename
if self.args.model == 'vtip' and self.args.derain and self.scale[self.idx_scale] == 1:
lr, filename = self._load_file(idx % self.len)
rain = self._load_rain()
rain = np.expand_dims(rain, axis=2)
rain = self.get_patch(rain, 1)
rain = self._np2Tensor(rain, rgb_range=self.args.rgb_range)
pair = self.get_patch(lr)
pair_t = self._np2Tensor(pair, rgb_range=self.args.rgb_range)
return pair_t, rain, filename
if self.args.jpeg:
hr, filename = self._load_file(idx % self.len)
buffer = io.BytesIO()
width, height = hr.size
patch_size = self.scale[self.idx_scale]*self.args.patch_size
if width < patch_size:
hr = hr.resize((patch_size, height), Image.ANTIALIAS)
width, height = hr.size
if height < patch_size:
hr = hr.resize((width, patch_size), Image.ANTIALIAS)
hr.save(buffer, format='jpeg', quality=25)
lr = Image.open(buffer)
lr = np.array(lr).astype(np.float32)
hr = np.array(hr).astype(np.float32)
lr = self.get_patch(lr)
hr = self.get_patch(hr)
lr = self._np2Tensor(lr, rgb_range=self.args.rgb_range)
hr = self._np2Tensor(hr, rgb_range=self.args.rgb_range)
return lr, hr, filename
lr, filename = self._load_file(idx % self.len)
pair = self.get_patch(lr)
pair_t = self._np2Tensor(pair, rgb_range=self.args.rgb_range)
return pair_t, filename

def _load_rain(self):
idx = random.randint(0, len(self.derain_img_list) - 1)
f_lr = self.derain_img_list[idx]
lr = imageio.imread(f_lr)
return lr

def get_patch(self, lr, scale=0):
if scale == 0:
scale = self.scale[self.idx_scale]
lr = get_patch_img(lr, patch_size=self.args.patch_size, scale=scale)
return lr

def set_scale(self, idx_scale):
self.idx_scale = idx_scale

+ 64
- 106
model_zoo/research/cv/IPT/src/data/srdata.py View File

@@ -1,6 +1,6 @@
"""srdata""" """srdata"""
# Copyright 2021 Huawei Technologies Co., Ltd # Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================

import os import os
import glob import glob
import random import random
@@ -20,43 +21,12 @@ import pickle
import numpy as np import numpy as np
import imageio import imageio
from src.data import common from src.data import common
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True




def search(root, target="JPEG"):
"""srdata"""
item_list = []
items = os.listdir(root)
for item in items:
path = os.path.join(root, item)
if os.path.isdir(path):
item_list.extend(search(path, target))
elif path.split('/')[-1].startswith(target):
item_list.append(path)
elif target in (path.split('/')[-2], path.split('/')[-3], path.split('/')[-4]):
item_list.append(path)
else:
item_list = []
return item_list


def search_dehaze(root, target="JPEG"):
class SRData:
"""srdata""" """srdata"""
item_list = []
items = os.listdir(root)
for item in items:
path = os.path.join(root, item)
if os.path.isdir(path):
extend_list = search_dehaze(path, target)
if extend_list is not None:
item_list.extend(extend_list)
elif path.split('/')[-2].endswith(target):
item_list.append(path)
return item_list


class SRData():
"""srdata"""

def __init__(self, args, name='', train=True, benchmark=False): def __init__(self, args, name='', train=True, benchmark=False):
self.args = args self.args = args
self.name = name self.name = name
@@ -69,37 +39,46 @@ class SRData():
self.idx_scale = 0 self.idx_scale = 0


if self.args.derain: if self.args.derain:
self.derain_test = os.path.join(args.dir_data, "Rain100L")
self.derain_lr_test = search(self.derain_test, "rain")
self.derain_hr_test = [path.replace(
"rainy/", "no") for path in self.derain_lr_test]
if self.train:
self.derain_dataroot = os.path.join(args.dir_data, "RainTrainL")
self.clear_train = common.search(self.derain_dataroot, "norain")
self.rain_train = []
for path in self.clear_train:
change_path = path.split('/')
change_path[-1] = change_path[-1][2:]
self.rain_train.append('/'.join(change_path))
self.derain_test = os.path.join(args.dir_data, "Rain100L")
self.deblur_lr_test = common.search(self.derain_test, "rain")
self.deblur_hr_test = [path.replace("rainy/", "no") for path in self.deblur_lr_test]
self.derain_hr_test = self.deblur_hr_test
else:
self.derain_test = os.path.join(args.dir_data, "Rain100L")
self.derain_lr_test = common.search(self.derain_test, "rain")
self.derain_hr_test = [path.replace("rainy/", "no") for path in self.derain_lr_test]
self._set_filesystem(args.dir_data) self._set_filesystem(args.dir_data)
self._set_img(args)
if self.args.derain and self.train:
self.images_hr, self.images_lr = self.clear_train, self.rain_train
if train:
self._repeat(args)

def _set_img(self, args):
"""srdata"""
if args.ext.find('img') < 0: if args.ext.find('img') < 0:
path_bin = os.path.join(self.apath, 'bin') path_bin = os.path.join(self.apath, 'bin')
os.makedirs(path_bin, exist_ok=True) os.makedirs(path_bin, exist_ok=True)


list_hr, list_lr = self._scan() list_hr, list_lr = self._scan()
if args.ext.find('img') >= 0 or benchmark:
if args.ext.find('img') >= 0 or self.benchmark:
self.images_hr, self.images_lr = list_hr, list_lr self.images_hr, self.images_lr = list_hr, list_lr
elif args.ext.find('sep') >= 0: elif args.ext.find('sep') >= 0:
os.makedirs(
self.dir_hr.replace(self.apath, path_bin),
exist_ok=True
)
os.makedirs(self.dir_hr.replace(self.apath, path_bin), exist_ok=True)
for s in self.scale: for s in self.scale:
if s == 1: if s == 1:
os.makedirs(
os.path.join(self.dir_hr),
exist_ok=True
)
os.makedirs(os.path.join(self.dir_hr), exist_ok=True)
else: else:
os.makedirs( os.makedirs(
os.path.join(
self.dir_lr.replace(self.apath, path_bin),
'X{}'.format(s)
),
exist_ok=True
)
os.path.join(self.dir_lr.replace(self.apath, path_bin), 'X{}'.format(s)), exist_ok=True)


self.images_hr, self.images_lr = [], [[] for _ in self.scale] self.images_hr, self.images_lr = [], [[] for _ in self.scale]
for h in list_hr: for h in list_hr:
@@ -114,23 +93,27 @@ class SRData():
self.images_lr[i].append(b) self.images_lr[i].append(b)
self._check_and_load(args.ext, l, b, verbose=True) self._check_and_load(args.ext, l, b, verbose=True)


# Below functions as used to prepare images
def _repeat(self, args):
"""srdata"""
n_patches = args.batch_size * args.test_every
n_images = len(args.data_train) * len(self.images_hr)
if n_images == 0:
self.repeat = 0
else:
self.repeat = max(n_patches // n_images, 1)

def _scan(self): def _scan(self):
"""srdata""" """srdata"""
names_hr = sorted( names_hr = sorted(
glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0]))
)
glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0])))
names_lr = [[] for _ in self.scale] names_lr = [[] for _ in self.scale]
for f in names_hr: for f in names_hr:
filename, _ = os.path.splitext(os.path.basename(f)) filename, _ = os.path.splitext(os.path.basename(f))
for si, s in enumerate(self.scale): for si, s in enumerate(self.scale):
if s != 1: if s != 1:
scale = s scale = s
names_lr[si].append(os.path.join(
self.dir_lr, 'X{}/{}x{}{}'.format(
s, filename, scale, self.ext[1]
)
))
names_lr[si].append(os.path.join(self.dir_lr, 'X{}/{}x{}{}' \
.format(s, filename, scale, self.ext[1])))
for si, s in enumerate(self.scale): for si, s in enumerate(self.scale):
if s == 1: if s == 1:
names_lr[si] = names_hr names_lr[si] = names_hr
@@ -150,28 +133,33 @@ class SRData():
pickle.dump(imageio.imread(img), _f) pickle.dump(imageio.imread(img), _f)


def __getitem__(self, idx): def __getitem__(self, idx):
if self.args.model == 'vtip' and self.args.derain and self.scale[
self.idx_scale] == 1 and not self.args.finetune:
norain, rain, _ = self._load_rain_test(idx)
pair = common.set_channel(
*[rain, norain], n_channels=self.args.n_colors)
pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range)
return pair_t[0], pair_t[1]
if self.args.model == 'vtip' and self.args.denoise and self.scale[self.idx_scale] == 1:
hr, _ = self._load_file_hr(idx)
if self.args.derain and self.scale[self.idx_scale] == 1:
if self.train:
lr, hr, filename = self._load_file_deblur(idx)
pair = self.get_patch(lr, hr)
pair = common.set_channel(*pair, n_channels=self.args.n_colors)
pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range)
else:
norain, rain, filename = self._load_rain_test(idx)
pair = common.set_channel(*[rain, norain], n_channels=self.args.n_colors)
pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range)
return pair_t[0], pair_t[1], [self.idx_scale], [filename]

if self.args.denoise and self.scale[self.idx_scale] == 1:
hr, filename = self._load_file_hr(idx)
pair = self.get_patch_hr(hr) pair = self.get_patch_hr(hr)
pair = common.set_channel(*[pair], n_channels=self.args.n_colors) pair = common.set_channel(*[pair], n_channels=self.args.n_colors)
pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range) pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range)
noise = np.random.randn(*pair_t[0].shape) * self.args.sigma noise = np.random.randn(*pair_t[0].shape) * self.args.sigma
lr = pair_t[0] + noise lr = pair_t[0] + noise
lr = np.float32(np.clip(lr, 0, 255)) lr = np.float32(np.clip(lr, 0, 255))
return lr, pair_t[0]
lr, hr, _ = self._load_file(idx)
return lr, pair_t[0], [self.idx_scale], [filename]
lr, hr, filename = self._load_file(idx)
pair = self.get_patch(lr, hr) pair = self.get_patch(lr, hr)
pair = common.set_channel(*pair, n_channels=self.args.n_colors) pair = common.set_channel(*pair, n_channels=self.args.n_colors)
pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range) pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range)


return pair_t[0], pair_t[1]
return pair_t[0], pair_t[1], [self.idx_scale], [filename]


def __len__(self): def __len__(self):
if self.train: if self.train:
@@ -182,7 +170,6 @@ class SRData():
return len(self.images_hr) return len(self.images_hr)


def _get_index(self, idx): def _get_index(self, idx):
"""srdata"""
if self.train: if self.train:
return idx % len(self.images_hr) return idx % len(self.images_hr)
return idx return idx
@@ -198,22 +185,9 @@ class SRData():
elif self.args.ext.find('sep') >= 0: elif self.args.ext.find('sep') >= 0:
with open(f_hr, 'rb') as _f: with open(f_hr, 'rb') as _f:
hr = pickle.load(_f) hr = pickle.load(_f)

return hr, filename return hr, filename


def _load_rain(self, idx, rain_img=False):
"""srdata"""
idx = random.randint(0, len(self.derain_img_list) - 1)
f_lr = self.derain_img_list[idx]
if rain_img:
norain = imageio.imread(f_lr.replace("rainstreak", "norain"))
rain = imageio.imread(f_lr.replace("rainstreak", "rain"))
return norain, rain
lr = imageio.imread(f_lr)
return lr

def _load_rain_test(self, idx): def _load_rain_test(self, idx):
"""srdata"""
f_hr = self.derain_hr_test[idx] f_hr = self.derain_hr_test[idx]
f_lr = self.derain_lr_test[idx] f_lr = self.derain_lr_test[idx]
filename, _ = os.path.splitext(os.path.basename(f_lr)) filename, _ = os.path.splitext(os.path.basename(f_lr))
@@ -221,14 +195,6 @@ class SRData():
rain = imageio.imread(f_lr) rain = imageio.imread(f_lr)
return norain, rain, filename return norain, rain, filename


def _load_denoise(self, idx):
"""srdata"""
idx = self._get_index(idx)
f_lr = self.images_hr[idx]
norain = imageio.imread(f_lr)
rain = imageio.imread(f_lr.replace("HR", "LR_bicubic"))
return norain, rain

def _load_file(self, idx): def _load_file(self, idx):
"""srdata""" """srdata"""
idx = self._get_index(idx) idx = self._get_index(idx)
@@ -251,12 +217,7 @@ class SRData():
def get_patch_hr(self, hr): def get_patch_hr(self, hr):
"""srdata""" """srdata"""
if self.train: if self.train:
hr = self.get_patch_img_hr(
hr,
patch_size=self.args.patch_size,
scale=1
)

hr = self.get_patch_img_hr(hr, patch_size=self.args.patch_size, scale=1)
return hr return hr


def get_patch_img_hr(self, img, patch_size=96, scale=2): def get_patch_img_hr(self, img, patch_size=96, scale=2):
@@ -280,9 +241,7 @@ class SRData():
lr, hr = common.get_patch( lr, hr = common.get_patch(
lr, hr, lr, hr,
patch_size=self.args.patch_size * scale, patch_size=self.args.patch_size * scale,
scale=scale,
multi=(len(self.scale) > 1)
)
scale=scale)
if not self.args.no_augment: if not self.args.no_augment:
lr, hr = common.augment(lr, hr) lr, hr = common.augment(lr, hr)
else: else:
@@ -292,7 +251,6 @@ class SRData():
return lr, hr return lr, hr


def set_scale(self, idx_scale): def set_scale(self, idx_scale):
"""srdata"""
if not self.input_large: if not self.input_large:
self.idx_scale = idx_scale self.idx_scale = idx_scale
else: else:


+ 0
- 241
model_zoo/research/cv/IPT/src/foldunfold_stride.py View File

@@ -1,241 +0,0 @@
'''stride'''
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
from mindspore import nn
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype


class _stride_unfold_(nn.Cell):
'''stride'''

def __init__(self,
kernel_size,
stride=-1):

super(_stride_unfold_, self).__init__()
if stride == -1:
self.stride = kernel_size
else:
self.stride = stride
self.kernel_size = kernel_size

self.reshape = P.Reshape()
self.transpose = P.Transpose()
self.unfold = _unfold_(kernel_size)

def construct(self, x):
"""stride"""
N, C, H, W = x.shape
leftup_idx_x = []
leftup_idx_y = []
nh = int(H / self.stride)
nw = int(W / self.stride)
for i in range(nh):
leftup_idx_x.append(i * self.stride)
for i in range(nw):
leftup_idx_y.append(i * self.stride)
NumBlock_x = len(leftup_idx_x)
NumBlock_y = len(leftup_idx_y)
zeroslike = P.ZerosLike()
cc_2 = P.Concat(axis=2)
cc_3 = P.Concat(axis=3)
unf_x = P.Zeros()((N, C, NumBlock_x * self.kernel_size,
NumBlock_y * self.kernel_size), mstype.float32)
N, C, H, W = unf_x.shape
for i in range(NumBlock_x):
for j in range(NumBlock_y):
unf_i = i * self.kernel_size
unf_j = j * self.kernel_size
org_i = leftup_idx_x[i]
org_j = leftup_idx_y[j]
fills = x[:, :, org_i:org_i + self.kernel_size,
org_j:org_j + self.kernel_size]
unf_x += cc_3((cc_3((zeroslike(unf_x[:, :, :, :unf_j]), cc_2((cc_2(
(zeroslike(unf_x[:, :, :unf_i, unf_j:unf_j + self.kernel_size]), fills)), zeroslike(
unf_x[:, :, unf_i + self.kernel_size:, unf_j:unf_j + self.kernel_size]))))),
zeroslike(unf_x[:, :, :, unf_j + self.kernel_size:])))
y = self.unfold(unf_x)
return y


class _stride_fold_(nn.Cell):
'''stride'''

def __init__(self,
kernel_size,
output_shape=(-1, -1),
stride=-1):

super(_stride_fold_, self).__init__()
if isinstance(kernel_size, (list, tuple)):
self.kernel_size = kernel_size
else:
self.kernel_size = [kernel_size, kernel_size]

if stride == -1:
self.stride = kernel_size[0]
else:
self.stride = stride

self.output_shape = output_shape

self.reshape = P.Reshape()
self.transpose = P.Transpose()
self.fold = _fold_(kernel_size)

def construct(self, x):
'''stride'''
if self.output_shape[0] == -1:
large_x = self.fold(x)
N, C, H, _ = large_x.shape
leftup_idx = []
for i in range(0, H, self.kernel_size[0]):
leftup_idx.append(i)
NumBlock = len(leftup_idx)
fold_x = P.Zeros()((N, C, (NumBlock - 1) * self.stride + self.kernel_size[0],
(NumBlock - 1) * self.stride + self.kernel_size[0]), mstype.float32)

for i in range(NumBlock):
for j in range(NumBlock):
fold_i = i * self.stride
fold_j = j * self.stride
org_i = leftup_idx[i]
org_j = leftup_idx[j]
fills = x[:, :, org_i:org_i + self.kernel_size[0],
org_j:org_j + self.kernel_size[1]]
fold_x += cc_3((cc_3((zeroslike(fold_x[:, :, :, :fold_j]), cc_2((cc_2(
(zeroslike(fold_x[:, :, :fold_i, fold_j:fold_j + self.kernel_size[1]]), fills)), zeroslike(
fold_x[:, :, fold_i + self.kernel_size[0]:, fold_j:fold_j + self.kernel_size[1]]))))),
zeroslike(fold_x[:, :, :, fold_j + self.kernel_size[1]:])))
y = fold_x
else:
NumBlock_x = int(
(self.output_shape[0] - self.kernel_size[0]) / self.stride + 1)
NumBlock_y = int(
(self.output_shape[1] - self.kernel_size[1]) / self.stride + 1)
large_shape = [NumBlock_x * self.kernel_size[0],
NumBlock_y * self.kernel_size[1]]
self.fold = _fold_(self.kernel_size, large_shape)
large_x = self.fold(x)
N, C, H, _ = large_x.shape
leftup_idx_x = []
leftup_idx_y = []
for i in range(NumBlock_x):
leftup_idx_x.append(i * self.kernel_size[0])
for i in range(NumBlock_y):
leftup_idx_y.append(i * self.kernel_size[1])
fold_x = P.Zeros()((N, C, (NumBlock_x - 1) * self.stride + self.kernel_size[0],
(NumBlock_y - 1) * self.stride + self.kernel_size[1]), mstype.float32)
for i in range(NumBlock_x):
for j in range(NumBlock_y):
fold_i = i * self.stride
fold_j = j * self.stride
org_i = leftup_idx_x[i]
org_j = leftup_idx_y[j]
fills = x[:, :, org_i:org_i + self.kernel_size[0],
org_j:org_j + self.kernel_size[1]]
fold_x += cc_3((cc_3((zeroslike(fold_x[:, :, :, :fold_j]), cc_2((cc_2(
(zeroslike(fold_x[:, :, :fold_i, fold_j:fold_j + self.kernel_size[1]]), fills)), zeroslike(
fold_x[:, :, fold_i + self.kernel_size[0]:, fold_j:fold_j + self.kernel_size[1]]))))),
zeroslike(fold_x[:, :, :, fold_j + self.kernel_size[1]:])))
y = fold_x
return y


class _unfold_(nn.Cell):
'''stride'''

def __init__(self,
kernel_size,
stride=-1):

super(_unfold_, self).__init__()
if stride == -1:
self.stride = kernel_size
self.kernel_size = kernel_size

self.reshape = P.Reshape()
self.transpose = P.Transpose()

def construct(self, x):
'''stride'''
N, C, H, W = x.shape
numH = int(H / self.kernel_size)
numW = int(W / self.kernel_size)
if numH * self.kernel_size != H or numW * self.kernel_size != W:
x = x[:, :, :numH * self.kernel_size, :, numW * self.kernel_size]
output_img = self.reshape(x, (N, C, numH, self.kernel_size, W))

output_img = self.transpose(output_img, (0, 1, 2, 4, 3))

output_img = self.reshape(output_img, (N, C, int(
numH * numW), self.kernel_size, self.kernel_size))

output_img = self.transpose(output_img, (0, 2, 1, 4, 3))

output_img = self.reshape(output_img, (N, int(numH * numW), -1))
return output_img


class _fold_(nn.Cell):
'''stride'''

def __init__(self,
kernel_size,
output_shape=(-1, -1),
stride=-1):

super(_fold_, self).__init__()

if isinstance(kernel_size, (list, tuple)):
self.kernel_size = kernel_size
else:
self.kernel_size = [kernel_size, kernel_size]

if stride == -1:
self.stride = kernel_size[0]
self.output_shape = output_shape

self.reshape = P.Reshape()
self.transpose = P.Transpose()

def construct(self, x):
'''stride'''
N, C, L = x.shape
org_C = int(L / self.kernel_size[0] / self.kernel_size[1])
if self.output_shape[0] == -1:
numH = int(np.sqrt(C))
numW = int(np.sqrt(C))
org_H = int(numH * self.kernel_size[0])
org_W = org_H
else:
org_H = int(self.output_shape[0])
org_W = int(self.output_shape[1])
numH = int(org_H / self.kernel_size[0])
numW = int(org_W / self.kernel_size[1])

output_img = self.reshape(
x, (N, C, org_C, self.kernel_size[0], self.kernel_size[1]))

output_img = self.transpose(output_img, (0, 2, 1, 3, 4))
output_img = self.reshape(
output_img, (N, org_C, numH, numW, self.kernel_size[0], self.kernel_size[1]))

output_img = self.transpose(output_img, (0, 1, 2, 4, 3, 5))

output_img = self.reshape(output_img, (N, org_C, org_H, org_W))
return output_img

model_zoo/research/cv/IPT/src/ipt.py → model_zoo/research/cv/IPT/src/ipt_model.py View File

@@ -13,15 +13,30 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================

import math import math
import copy import copy
import numpy as np import numpy as np
from mindspore import nn from mindspore import nn
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from mindspore.common import Tensor, Parameter

class LayerPreprocess(nn.Cell):
"""
Preprocess input of each layer
"""
def __init__(self, in_channels=None):
super(LayerPreprocess, self).__init__()
self.layernorm = nn.LayerNorm((in_channels,))
self.cast = P.Cast()
self.get_dtype = P.DType()


def construct(self, input_tensor):
output = self.cast(input_tensor, mstype.float32)
output = self.layernorm(output)
output = self.cast(output, self.get_dtype(input_tensor))
return output


class MultiheadAttention(nn.Cell): class MultiheadAttention(nn.Cell):
""" """
@@ -45,7 +60,7 @@ class MultiheadAttention(nn.Cell):
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
do_return_2d_tensor (bool): True for return 2d tensor. False for return 3d do_return_2d_tensor (bool): True for return 2d tensor. False for return 3d
tensor. Default: False. tensor. Default: False.
compute_type (:class:`mindspore.dtype`): Compute type in MultiheadAttention. Default: mstype.float32.
compute_type (:class:`mindspore.dtype`): Compute type in MultiheadAttention. Default: mstype.float16.
""" """


def __init__(self, def __init__(self,
@@ -64,13 +79,12 @@ class MultiheadAttention(nn.Cell):
use_one_hot_embeddings=False, use_one_hot_embeddings=False,
initializer_range=0.02, initializer_range=0.02,
do_return_2d_tensor=False, do_return_2d_tensor=False,
compute_type=mstype.float32,
compute_type=mstype.float16,
same_dim=True): same_dim=True):
super(MultiheadAttention, self).__init__() super(MultiheadAttention, self).__init__()
self.num_attention_heads = num_attention_heads self.num_attention_heads = num_attention_heads
self.size_per_head = int(hidden_width / num_attention_heads) self.size_per_head = int(hidden_width / num_attention_heads)
self.has_attention_mask = has_attention_mask self.has_attention_mask = has_attention_mask
assert has_attention_mask
self.use_one_hot_embeddings = use_one_hot_embeddings self.use_one_hot_embeddings = use_one_hot_embeddings
self.initializer_range = initializer_range self.initializer_range = initializer_range
self.do_return_2d_tensor = do_return_2d_tensor self.do_return_2d_tensor = do_return_2d_tensor
@@ -83,11 +97,9 @@ class MultiheadAttention(nn.Cell):
self.shape_k_2d = (-1, k_tensor_width) self.shape_k_2d = (-1, k_tensor_width)
self.shape_v_2d = (-1, v_tensor_width) self.shape_v_2d = (-1, v_tensor_width)
self.hidden_width = int(hidden_width) self.hidden_width = int(hidden_width)
# units = num_attention_heads * self.size_per_head
if self.same_dim: if self.same_dim:
self.in_proj_layer = \
Parameter(Tensor(np.random.rand(hidden_width * 3,
q_tensor_width), dtype=compute_type), name="weight")
self.in_proj_layer = Parameter(Tensor(np.random.rand(hidden_width * 3,
q_tensor_width), dtype=mstype.float32), name="weight")
else: else:
self.query_layer = nn.Dense(q_tensor_width, self.query_layer = nn.Dense(q_tensor_width,
hidden_width, hidden_width,
@@ -132,8 +144,10 @@ class MultiheadAttention(nn.Cell):
self.equal = P.Equal() self.equal = P.Equal()
self.shape = P.Shape() self.shape = P.Shape()


def construct(self, tensor_q, tensor_k, tensor_v, attention_mask=None):
"""Apply multihead attention."""
def construct(self, tensor_q, tensor_k, tensor_v):
"""
Apply multihead attention.
"""
batch_size, seq_length, _ = self.shape(tensor_q) batch_size, seq_length, _ = self.shape(tensor_q)
shape_qkv = (batch_size, -1, shape_qkv = (batch_size, -1,
self.num_attention_heads, self.size_per_head) self.num_attention_heads, self.size_per_head)
@@ -161,20 +175,14 @@ class MultiheadAttention(nn.Cell):
_start = 0 _start = 0
_end = self.hidden_width _end = self.hidden_width
_w = self.in_proj_layer[_start:_end, :] _w = self.in_proj_layer[_start:_end, :]
# _b = None
query_out = self.matmul_dense(_w, tensor_q_2d) query_out = self.matmul_dense(_w, tensor_q_2d)

_start = self.hidden_width _start = self.hidden_width
_end = self.hidden_width * 2 _end = self.hidden_width * 2
_w = self.in_proj_layer[_start:_end, :] _w = self.in_proj_layer[_start:_end, :]
# _b = None
key_out = self.matmul_dense(_w, tensor_k_2d) key_out = self.matmul_dense(_w, tensor_k_2d)

_start = self.hidden_width * 2 _start = self.hidden_width * 2

_end = None _end = None
_w = self.in_proj_layer[_start:] _w = self.in_proj_layer[_start:]
# _b = None
value_out = self.matmul_dense(_w, tensor_v_2d) value_out = self.matmul_dense(_w, tensor_v_2d)
else: else:
query_out = self.query_layer(tensor_q_2d) query_out = self.query_layer(tensor_q_2d)
@@ -193,8 +201,7 @@ class MultiheadAttention(nn.Cell):


attention_scores = self.softmax_cast(attention_scores, mstype.float32) attention_scores = self.softmax_cast(attention_scores, mstype.float32)
attention_probs = self.softmax(attention_scores) attention_probs = self.softmax(attention_scores)
attention_probs = self.softmax_cast(
attention_probs, self.get_dtype(key_layer))
attention_probs = self.softmax_cast(attention_probs, mstype.float16)
if self.use_dropout: if self.use_dropout:
attention_probs = self.dropout(attention_probs) attention_probs = self.dropout(attention_probs)


@@ -212,11 +219,8 @@ class MultiheadAttention(nn.Cell):


class TransformerEncoderLayer(nn.Cell): class TransformerEncoderLayer(nn.Cell):
"""ipt""" """ipt"""

def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
activation="relu"):
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, compute_type=mstype.float16):
super().__init__() super().__init__()

self.self_attn = MultiheadAttention(q_tensor_width=d_model, self.self_attn = MultiheadAttention(q_tensor_width=d_model,
k_tensor_width=d_model, k_tensor_width=d_model,
v_tensor_width=d_model, v_tensor_width=d_model,
@@ -224,12 +228,12 @@ class TransformerEncoderLayer(nn.Cell):
out_tensor_width=d_model, out_tensor_width=d_model,
num_attention_heads=nhead, num_attention_heads=nhead,
attention_probs_dropout_prob=dropout) attention_probs_dropout_prob=dropout)
self.linear1 = nn.Dense(d_model, dim_feedforward)
self.linear1 = nn.Dense(d_model, dim_feedforward).to_float(compute_type)
self.dropout = nn.Dropout(1. - dropout) self.dropout = nn.Dropout(1. - dropout)
self.linear2 = nn.Dense(dim_feedforward, d_model) self.linear2 = nn.Dense(dim_feedforward, d_model)


self.norm1 = nn.LayerNorm([d_model])
self.norm2 = nn.LayerNorm([d_model])
self.norm1 = LayerPreprocess(d_model)
self.norm2 = LayerPreprocess(d_model)
self.dropout1 = nn.Dropout(1. - dropout) self.dropout1 = nn.Dropout(1. - dropout)
self.dropout2 = nn.Dropout(1. - dropout) self.dropout2 = nn.Dropout(1. - dropout)
self.reshape = P.Reshape() self.reshape = P.Reshape()
@@ -237,7 +241,6 @@ class TransformerEncoderLayer(nn.Cell):
self.activation = P.ReLU() self.activation = P.ReLU()


def with_pos_embed(self, tensor, pos): def with_pos_embed(self, tensor, pos):
"""ipt"""
return tensor if pos is None else tensor + pos return tensor if pos is None else tensor + pos


def construct(self, src, pos=None): def construct(self, src, pos=None):
@@ -258,10 +261,8 @@ class TransformerEncoderLayer(nn.Cell):




class TransformerDecoderLayer(nn.Cell): class TransformerDecoderLayer(nn.Cell):
"""ipt"""

def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
activation="relu"):
""" ipt"""
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
super().__init__() super().__init__()
self.self_attn = MultiheadAttention(q_tensor_width=d_model, self.self_attn = MultiheadAttention(q_tensor_width=d_model,
k_tensor_width=d_model, k_tensor_width=d_model,
@@ -281,9 +282,9 @@ class TransformerDecoderLayer(nn.Cell):
self.dropout = nn.Dropout(1. - dropout) self.dropout = nn.Dropout(1. - dropout)
self.linear2 = nn.Dense(dim_feedforward, d_model) self.linear2 = nn.Dense(dim_feedforward, d_model)


self.norm1 = nn.LayerNorm([d_model])
self.norm2 = nn.LayerNorm([d_model])
self.norm3 = nn.LayerNorm([d_model])
self.norm1 = LayerPreprocess(d_model)
self.norm2 = LayerPreprocess(d_model)
self.norm3 = LayerPreprocess(d_model)
self.dropout1 = nn.Dropout(1. - dropout) self.dropout1 = nn.Dropout(1. - dropout)
self.dropout2 = nn.Dropout(1. - dropout) self.dropout2 = nn.Dropout(1. - dropout)
self.dropout3 = nn.Dropout(1. - dropout) self.dropout3 = nn.Dropout(1. - dropout)
@@ -291,7 +292,6 @@ class TransformerDecoderLayer(nn.Cell):
self.activation = P.ReLU() self.activation = P.ReLU()


def with_pos_embed(self, tensor, pos): def with_pos_embed(self, tensor, pos):
"""ipt"""
return tensor if pos is None else tensor + pos return tensor if pos is None else tensor + pos


def construct(self, tgt, memory, pos=None, query_pos=None): def construct(self, tgt, memory, pos=None, query_pos=None):
@@ -306,7 +306,7 @@ class TransformerDecoderLayer(nn.Cell):
tgt2 = self.norm2(tgt) tgt2 = self.norm2(tgt)
tgt2 = self.multihead_attn(tensor_q=self.with_pos_embed(tgt2, query_pos), tgt2 = self.multihead_attn(tensor_q=self.with_pos_embed(tgt2, query_pos),
tensor_k=self.with_pos_embed(memory, pos), tensor_k=self.with_pos_embed(memory, pos),
tensor_v=memory,)
tensor_v=memory)
tgt = tgt + self.dropout2(tgt2) tgt = tgt + self.dropout2(tgt2)
tgt2 = self.norm3(tgt) tgt2 = self.norm3(tgt)
tgt2 = self.reshape(tgt2, permute_linear) tgt2 = self.reshape(tgt2, permute_linear)
@@ -318,47 +318,38 @@ class TransformerDecoderLayer(nn.Cell):


class TransformerEncoder(nn.Cell): class TransformerEncoder(nn.Cell):
"""ipt""" """ipt"""

def __init__(self, encoder_layer, num_layers): def __init__(self, encoder_layer, num_layers):
super().__init__() super().__init__()
self.layers = _get_clones(encoder_layer, num_layers) self.layers = _get_clones(encoder_layer, num_layers)
self.num_layers = num_layers self.num_layers = num_layers


def construct(self, src, pos=None): def construct(self, src, pos=None):
"""ipt"""
output = src output = src

for layer in self.layers: for layer in self.layers:
output = layer(output, pos=pos) output = layer(output, pos=pos)

return output return output




class TransformerDecoder(nn.Cell): class TransformerDecoder(nn.Cell):
"""ipt""" """ipt"""

def __init__(self, decoder_layer, num_layers): def __init__(self, decoder_layer, num_layers):
super().__init__() super().__init__()
self.layers = _get_clones(decoder_layer, num_layers) self.layers = _get_clones(decoder_layer, num_layers)
self.num_layers = num_layers self.num_layers = num_layers


def construct(self, tgt, memory, pos=None, query_pos=None): def construct(self, tgt, memory, pos=None, query_pos=None):
"""ipt"""
output = tgt output = tgt

for layer in self.layers: for layer in self.layers:
output = layer(output, memory, pos=pos, query_pos=query_pos) output = layer(output, memory, pos=pos, query_pos=query_pos)
return output return output




def _get_clones(module, N):
"""ipt"""
return nn.CellList([copy.deepcopy(module) for i in range(N)])
def _get_clones(module, n):
return nn.CellList([copy.deepcopy(module) for i in range(n)])




class LearnedPositionalEncoding(nn.Cell): class LearnedPositionalEncoding(nn.Cell):
"""ipt""" """ipt"""

def __init__(self, max_position_embeddings, embedding_dim, seq_length): def __init__(self, max_position_embeddings, embedding_dim, seq_length):
super(LearnedPositionalEncoding, self).__init__() super(LearnedPositionalEncoding, self).__init__()
self.pe = nn.Embedding( self.pe = nn.Embedding(
@@ -370,8 +361,7 @@ class LearnedPositionalEncoding(nn.Cell):
self.position_ids = self.reshape( self.position_ids = self.reshape(
self.position_ids, (1, self.seq_length)) self.position_ids, (1, self.seq_length))


def construct(self, x, position_ids=None):
"""ipt"""
def construct(self, position_ids=None):
if position_ids is None: if position_ids is None:
position_ids = self.position_ids[:, : self.seq_length] position_ids = self.position_ids[:, : self.seq_length]


@@ -381,46 +371,35 @@ class LearnedPositionalEncoding(nn.Cell):


class VisionTransformer(nn.Cell): class VisionTransformer(nn.Cell):
"""ipt""" """ipt"""

def __init__(
self,
img_dim,
patch_dim,
num_channels,
embedding_dim,
num_heads,
num_layers,
hidden_dim,
num_queries,
idx,
positional_encoding_type="learned",
dropout_rate=0,
norm=False,
mlp=False,
pos_every=False,
no_pos=False
):
def __init__(self,
img_dim,
patch_dim,
num_channels,
embedding_dim,
num_heads,
num_layers,
hidden_dim,
num_queries,
dropout_rate=0,
norm=False,
mlp=False,
pos_every=False,
no_pos=False,
con_loss=False):
super(VisionTransformer, self).__init__() super(VisionTransformer, self).__init__()

assert embedding_dim % num_heads == 0
assert img_dim % patch_dim == 0
self.norm = norm self.norm = norm
self.mlp = mlp self.mlp = mlp
self.embedding_dim = embedding_dim self.embedding_dim = embedding_dim
self.num_heads = num_heads self.num_heads = num_heads
self.patch_dim = patch_dim self.patch_dim = patch_dim
self.num_channels = num_channels self.num_channels = num_channels

self.img_dim = img_dim self.img_dim = img_dim
self.pos_every = pos_every self.pos_every = pos_every
self.num_patches = int((img_dim // patch_dim) ** 2) self.num_patches = int((img_dim // patch_dim) ** 2)
self.seq_length = self.num_patches self.seq_length = self.num_patches
self.flatten_dim = patch_dim * patch_dim * num_channels self.flatten_dim = patch_dim * patch_dim * num_channels

self.out_dim = patch_dim * patch_dim * num_channels self.out_dim = patch_dim * patch_dim * num_channels

self.no_pos = no_pos self.no_pos = no_pos

self.unf = _unfold_(patch_dim) self.unf = _unfold_(patch_dim)
self.fold = _fold_(patch_dim, output_shape=(img_dim, img_dim)) self.fold = _fold_(patch_dim, output_shape=(img_dim, img_dim))


@@ -432,8 +411,7 @@ class VisionTransformer(nn.Cell):
nn.Dropout(1. - dropout_rate), nn.Dropout(1. - dropout_rate),
nn.ReLU(), nn.ReLU(),
nn.Dense(hidden_dim, self.out_dim), nn.Dense(hidden_dim, self.out_dim),
nn.Dropout(1. - dropout_rate)
)
nn.Dropout(1. - dropout_rate))


self.query_embed = nn.Embedding( self.query_embed = nn.Embedding(
num_queries, embedding_dim * self.seq_length) num_queries, embedding_dim * self.seq_length)
@@ -449,55 +427,54 @@ class VisionTransformer(nn.Cell):
self.tile = P.Tile() self.tile = P.Tile()
self.transpose = P.Transpose() self.transpose = P.Transpose()
if not self.no_pos: if not self.no_pos:
self.position_encoding = LearnedPositionalEncoding(
self.seq_length, self.embedding_dim, self.seq_length
)
self.position_encoding = LearnedPositionalEncoding(self.seq_length, self.embedding_dim, self.seq_length)


self.dropout_layer1 = nn.Dropout(1. - dropout_rate) self.dropout_layer1 = nn.Dropout(1. - dropout_rate)
self.query_idx = idx
self.query_idx_tensor = Tensor(idx, mstype.int32)
def construct(self, x):
self.con_loss = con_loss
def construct(self, x, query_idx_tensor):
"""ipt""" """ipt"""
B, _, _, _ = x.shape
x = self.unf(x) x = self.unf(x)
B, N, _ = x.shape
b, n, _ = x.shape


if self.mlp is not True: if self.mlp is not True:
x = self.reshape(x, (B * N, -1))
x = self.reshape(x, (b * n, -1))
x = self.dropout_layer1(self.linear_encoding(x)) + x x = self.dropout_layer1(self.linear_encoding(x)) + x
x = self.reshape(x, (B, N, -1))
x = self.reshape(x, (b, n, -1))
query_embed = self.tile( query_embed = self.tile(
self.reshape(self.query_embed(self.query_idx_tensor), (1, self.seq_length, self.embedding_dim)),
(B, 1, 1))
self.reshape(self.query_embed(query_idx_tensor), (1, self.seq_length, self.embedding_dim)), (b, 1, 1))


if not self.no_pos: if not self.no_pos:
pos = self.position_encoding(x)
pos = self.position_encoding()
x = self.encoder(x + pos) x = self.encoder(x + pos)
else: else:
x = self.encoder(x) x = self.encoder(x)
x = self.decoder(x, x, query_pos=query_embed) x = self.decoder(x, x, query_pos=query_embed)


if self.mlp is not True: if self.mlp is not True:
x = self.reshape(x, (B * N, -1))
x = self.reshape(x, (b * n, -1))
x = self.mlp_head(x) + x x = self.mlp_head(x) + x
x = self.reshape(x, (B, N, -1))
x = self.reshape(x, (b, n, -1))
if self.con_loss:
con_x = x
x = self.fold(x)
return x, con_x
x = self.fold(x) x = self.fold(x)


return x return x




def default_conv(in_channels, out_channels, kernel_size, has_bias=True): def default_conv(in_channels, out_channels, kernel_size, has_bias=True):
"""ipt"""
return nn.Conv2d(
in_channels, out_channels, kernel_size, has_bias=has_bias)
return nn.Conv2d(in_channels, out_channels, kernel_size, has_bias=has_bias)




class MeanShift(nn.Conv2d): class MeanShift(nn.Conv2d):
"""ipt""" """ipt"""

def __init__(
self, rgb_range,
rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):
def __init__(self,
rgb_range,
rgb_mean=(0.4488, 0.4371, 0.4040),
rgb_std=(1.0, 1.0, 1.0),
sign=-1):
super(MeanShift, self).__init__(3, 3, kernel_size=1) super(MeanShift, self).__init__(3, 3, kernel_size=1)
self.reshape = P.Reshape() self.reshape = P.Reshape()
self.eye = P.Eye() self.eye = P.Eye()
@@ -512,10 +489,14 @@ class MeanShift(nn.Conv2d):


class ResBlock(nn.Cell): class ResBlock(nn.Cell):
"""ipt""" """ipt"""

def __init__(
self, conv, n_feats, kernel_size,
bias=True, bn=False, act=nn.ReLU(), res_scale=1):
def __init__(self,
conv,
n_feats,
kernel_size,
bias=True,
bn=False,
act=nn.ReLU(),
res_scale=1):


super(ResBlock, self).__init__() super(ResBlock, self).__init__()
m = [] m = []
@@ -532,35 +513,28 @@ class ResBlock(nn.Cell):
self.mul = P.Mul() self.mul = P.Mul()


def construct(self, x): def construct(self, x):
"""ipt"""
res = self.mul(self.body(x), self.res_scale) res = self.mul(self.body(x), self.res_scale)
res += x res += x

return res return res




def _pixelsf_(x, scale): def _pixelsf_(x, scale):
"""ipt""" """ipt"""
N, C, iH, iW = x.shape
oH = iH * scale
oW = iW * scale
oC = C // (scale ** 2)

output = P.Reshape()(x, (N, oC, scale, scale, iH, iW))

output = P.Transpose()(output, (0, 1, 5, 3, 4, 2))

output = P.Reshape()(output, (N, oC, oH, oW))

output = P.Transpose()(output, (0, 1, 3, 2))

n, c, ih, iw = x.shape
oh = ih * scale
ow = iw * scale
oc = c // (scale ** 2)
output = P.Transpose()(x, (0, 2, 1, 3))
output = P.Reshape()(output, (n, ih, oc*scale, scale, iw))
output = P.Transpose()(output, (0, 1, 2, 4, 3))
output = P.Reshape()(output, (n, ih, oc, scale, ow))
output = P.Transpose()(output, (0, 2, 1, 3, 4))
output = P.Reshape()(output, (n, oc, oh, ow))
return output return output



class SmallUpSampler(nn.Cell): class SmallUpSampler(nn.Cell):
"""ipt""" """ipt"""

def __init__(self, conv, upsize, n_feats, bn=False, act=False, bias=True):
def __init__(self, conv, upsize, n_feats, bias=True):
super(SmallUpSampler, self).__init__() super(SmallUpSampler, self).__init__()
self.conv = conv(n_feats, upsize * upsize * n_feats, 3, bias) self.conv = conv(n_feats, upsize * upsize * n_feats, 3, bias)
self.reshape = P.Reshape() self.reshape = P.Reshape()
@@ -568,7 +542,6 @@ class SmallUpSampler(nn.Cell):
self.pixelsf = _pixelsf_ self.pixelsf = _pixelsf_


def construct(self, x): def construct(self, x):
"""ipt"""
x = self.conv(x) x = self.conv(x)
output = self.pixelsf(x, self.upsize) output = self.pixelsf(x, self.upsize)
return output return output
@@ -576,47 +549,37 @@ class SmallUpSampler(nn.Cell):


class Upsampler(nn.Cell): class Upsampler(nn.Cell):
"""ipt""" """ipt"""

def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):
def __init__(self, conv, scale, n_feats, bias=True):
super(Upsampler, self).__init__() super(Upsampler, self).__init__()
m = [] m = []
if (scale & (scale - 1)) == 0: if (scale & (scale - 1)) == 0:
for _ in range(int(math.log(scale, 2))): for _ in range(int(math.log(scale, 2))):
m.append(SmallUpSampler(conv, 2, n_feats, bias=bias)) m.append(SmallUpSampler(conv, 2, n_feats, bias=bias))

elif scale == 3: elif scale == 3:
m.append(SmallUpSampler(conv, 3, n_feats, bias=bias)) m.append(SmallUpSampler(conv, 3, n_feats, bias=bias))
self.net = nn.SequentialCell(m) self.net = nn.SequentialCell(m)


def construct(self, x): def construct(self, x):
"""ipt"""
return self.net(x) return self.net(x)




class IPT(nn.Cell): class IPT(nn.Cell):
"""ipt""" """ipt"""

def __init__(self, args, conv=default_conv): def __init__(self, args, conv=default_conv):
super(IPT, self).__init__() super(IPT, self).__init__()
self.dytpe = mstype.float16
self.scale_idx = 0 self.scale_idx = 0


self.args = args self.args = args
self.con_loss = args.con_loss
n_feats = args.n_feats n_feats = args.n_feats
kernel_size = 3 kernel_size = 3
act = nn.ReLU() act = nn.ReLU()


self.sub_mean = MeanShift(args.rgb_range)
self.add_mean = MeanShift(args.rgb_range, sign=1)

self.head = nn.CellList([ self.head = nn.CellList([
nn.SequentialCell(
conv(args.n_colors, n_feats, kernel_size),
ResBlock(conv, n_feats, 5, act=act),
ResBlock(conv, n_feats, 5, act=act)
) for _ in args.scale
])
nn.SequentialCell(conv(args.n_colors, n_feats, kernel_size).to_float(self.dytpe),
ResBlock(conv, n_feats, 5, act=act).to_float(self.dytpe),
ResBlock(conv, n_feats, 5, act=act).to_float(self.dytpe)) for _ in range(6)])


self.body = VisionTransformer(img_dim=args.patch_size, self.body = VisionTransformer(img_dim=args.patch_size,
patch_dim=args.patch_dim, patch_dim=args.patch_dim,
@@ -630,36 +593,34 @@ class IPT(nn.Cell):
mlp=args.no_mlp, mlp=args.no_mlp,
pos_every=args.pos_every, pos_every=args.pos_every,
no_pos=args.no_pos, no_pos=args.no_pos,
idx=self.scale_idx)
con_loss=args.con_loss).to_float(self.dytpe)


self.tail = nn.CellList([ self.tail = nn.CellList([
nn.SequentialCell(
Upsampler(conv, s, n_feats, act=False),
conv(n_feats, args.n_colors, kernel_size)
) for s in args.scale
])
nn.SequentialCell(Upsampler(conv, s, n_feats).to_float(self.dytpe),
conv(n_feats, args.n_colors, kernel_size).to_float(self.dytpe)) \
for s in [2, 3, 4, 1, 1, 1]])


self.reshape = P.Reshape() self.reshape = P.Reshape()
self.tile = P.Tile() self.tile = P.Tile()
self.transpose = P.Transpose() self.transpose = P.Transpose()
self.s2t = P.ScalarToTensor()
self.cast = P.Cast()


def construct(self, x):
def construct(self, x, idx):
"""ipt""" """ipt"""
x = self.sub_mean(x)
x = self.head[self.scale_idx](x)
res = self.body(x)
idx_num = idx.shape[0]
x = self.head[idx_num](x)
idx_tensor = self.cast(self.s2t(idx_num), mstype.int32)
if self.con_loss:
res, x_con = self.body(x, idx_tensor)
res += x
x = self.tail[idx_num](x)
return x, x_con
res = self.body(x, idx_tensor)
res += x res += x
x = self.tail[self.scale_idx](res)
x = self.add_mean(x)

x = self.tail[idx_num](res)
return x return x


def set_scale(self, scale_idx):
"""ipt"""
self.body.query_idx = scale_idx
self.scale_idx = scale_idx


class IPT_post(): class IPT_post():
"""ipt""" """ipt"""
def __init__(self, model, args): def __init__(self, model, args):
@@ -674,17 +635,13 @@ class IPT_post():
self.cc_2 = P.Concat(axis=2) self.cc_2 = P.Concat(axis=2)
self.cc_3 = P.Concat(axis=3) self.cc_3 = P.Concat(axis=3)


def set_scale(self, scale_idx):
"""ipt"""
self.body.query_idx = scale_idx
self.scale_idx = scale_idx

def forward(self, x, shave=12, batchsize=64):
def forward(self, x, idx, shave=12, batchsize=64):
"""ipt""" """ipt"""
self.idx = idx
h, w = x.shape[-2:] h, w = x.shape[-2:]
padsize = int(self.args.patch_size) padsize = int(self.args.patch_size)
shave = int(self.args.patch_size / 4) shave = int(self.args.patch_size / 4)
scale = self.args.scale[self.scale_idx]
scale = self.args.scale[0]
h_cut = (h - padsize) % (padsize - shave) h_cut = (h - padsize) % (padsize - shave)
w_cut = (w - padsize) % (padsize - shave) w_cut = (w - padsize) % (padsize - shave)


@@ -692,7 +649,7 @@ class IPT_post():
x_unfold = unf_1.compute(x) x_unfold = unf_1.compute(x)
x_unfold = self.transpose(x_unfold, (1, 0, 2)) # transpose(0,2) x_unfold = self.transpose(x_unfold, (1, 0, 2)) # transpose(0,2)
x_hw_cut = x[:, :, (h - padsize):, (w - padsize):] x_hw_cut = x[:, :, (h - padsize):, (w - padsize):]
y_hw_cut = self.model(x_hw_cut)
y_hw_cut = self.model(x_hw_cut, self.idx)


x_h_cut = x[:, :, (h - padsize):, :] x_h_cut = x[:, :, (h - padsize):, :]
x_w_cut = x[:, :, :, (w - padsize):] x_w_cut = x[:, :, :, (w - padsize):]
@@ -714,10 +671,10 @@ class IPT_post():
for i in range(x_range): for i in range(x_range):
if i == 0: if i == 0:
y_unfold = self.model( y_unfold = self.model(
x_unfold[i * batchsize:(i + 1) * batchsize, :, :, :])
x_unfold[i * batchsize:(i + 1) * batchsize, :, :, :], self.idx)
else: else:
y_unfold = self.cc_0((y_unfold, self.model( y_unfold = self.cc_0((y_unfold, self.model(
x_unfold[i * batchsize:(i + 1) * batchsize, :, :, :])))
x_unfold[i * batchsize:(i + 1) * batchsize, :, :, :], self.idx)))
y_unf_shape_0 = y_unfold.shape[0] y_unf_shape_0 = y_unfold.shape[0]
fold_1 = \ fold_1 = \
_stride_fold_(padsize * scale, output_shape=((h - h_cut) * scale, (w - w_cut) * scale), _stride_fold_(padsize * scale, output_shape=((h - h_cut) * scale, (w - w_cut) * scale),
@@ -740,17 +697,18 @@ class IPT_post():
stride=padsize * scale - shave * scale) stride=padsize * scale - shave * scale)
y_inter = fold_2.compute(self.transpose(self.reshape( y_inter = fold_2.compute(self.transpose(self.reshape(
y_unfold, (y_unf_shape_0, -1, 1)), (2, 0, 1))) y_unfold, (y_unf_shape_0, -1, 1)), (2, 0, 1)))
concat1 = self.cc_2((y[:, :, :int(shave / 2 * scale), int(shave / 2 * scale):(w - w_cut) * scale - int(shave / 2 * scale)], y_inter)) #pylint: disable=line-too-long
concat2 = self.cc_2((concat1, y[:, :, (h - h_cut) * scale - int(shave / 2 * scale):, int(shave / 2 * scale):(w - w_cut) * scale - int(shave / 2 * scale)])) #pylint: disable=line-too-long
concat1 = self.cc_2((y[:, :, :int(shave / 2 * scale), \
int(shave / 2 * scale):(w - w_cut) * scale - int(shave / 2 * scale)], y_inter))
concat2 = self.cc_2((concat1, y[:, :, (h - h_cut) * scale - int(shave / 2 * scale):, \
int(shave / 2 * scale):(w - w_cut) * scale - int(shave / 2 * scale)]))
concat3 = self.cc_3((y[:, :, :, :int(shave / 2 * scale)], concat2)) concat3 = self.cc_3((y[:, :, :, :int(shave / 2 * scale)], concat2))
y = self.cc_3((concat3, y[:, :, :, (w - w_cut) * scale - int(shave / 2 * scale):])) #pylint: disable=line-too-long
y = self.cc_2((y[:, :, :y.shape[2] - int((padsize - h_cut) / 2 * scale), :], y_h_cut[:, :, int((padsize - h_cut) / 2 * scale + 0.5):, :])) #pylint: disable=line-too-long
y = self.cc_3((concat3, y[:, :, :, (w - w_cut) * scale - int(shave / 2 * scale):]))
y = self.cc_2((y[:, :, :y.shape[2] - int((padsize - h_cut) / 2 * scale), :],
y_h_cut[:, :, int((padsize - h_cut) / 2 * scale + 0.5):, :]))
y_w_cat = self.cc_2((y_w_cut[:, :, :y_w_cut.shape[2] - int((padsize - h_cut) / 2 * scale), :], y_w_cat = self.cc_2((y_w_cut[:, :, :y_w_cut.shape[2] - int((padsize - h_cut) / 2 * scale), :],
y_hw_cut[:, :, int((padsize - h_cut) / 2 * scale + 0.5):, :])) y_hw_cut[:, :, int((padsize - h_cut) / 2 * scale + 0.5):, :]))
y = self.cc_3((y[:, :, :, :y.shape[3] - int((padsize - w_cut) / 2 * scale)], y = self.cc_3((y[:, :, :, :y.shape[3] - int((padsize - w_cut) / 2 * scale)],
y_w_cat[:, :, :, int((padsize - w_cut) / 2 * scale + 0.5):])) y_w_cat[:, :, :, int((padsize - w_cut) / 2 * scale + 0.5):]))

return y return y


def cut_h_new(self, x_h_cut, h, w, h_cut, w_cut, padsize, shave, scale, batchsize): def cut_h_new(self, x_h_cut, h, w, h_cut, w_cut, padsize, shave, scale, batchsize):
@@ -766,11 +724,11 @@ class IPT_post():
for i in range(x_range): for i in range(x_range):
if i == 0: if i == 0:
y_h_cut_unfold = self.model( y_h_cut_unfold = self.model(
x_h_cut_unfold[i * batchsize:(i + 1) * batchsize, :, :, :])
x_h_cut_unfold[i * batchsize:(i + 1) * batchsize, :, :, :], self.idx)
else: else:
y_h_cut_unfold = \ y_h_cut_unfold = \
self.cc_0((y_h_cut_unfold, self.model( self.cc_0((y_h_cut_unfold, self.model(
x_h_cut_unfold[i * batchsize:(i + 1) * batchsize, :, :, :])))
x_h_cut_unfold[i * batchsize:(i + 1) * batchsize, :, :, :], self.idx)))
y_h_cut_unfold_shape_0 = y_h_cut_unfold.shape[0] y_h_cut_unfold_shape_0 = y_h_cut_unfold.shape[0]
fold_1 = \ fold_1 = \
_stride_fold_(padsize * scale, output_shape=(padsize * scale, (w - w_cut) * scale), _stride_fold_(padsize * scale, output_shape=(padsize * scale, (w - w_cut) * scale),
@@ -802,10 +760,11 @@ class IPT_post():
for i in range(x_range): for i in range(x_range):
if i == 0: if i == 0:
y_w_cut_unfold = self.model( y_w_cut_unfold = self.model(
x_w_cut_unfold[i * batchsize:(i + 1) * batchsize, :, :, :])
x_w_cut_unfold[i * batchsize:(i + 1) * batchsize, :, :, :], self.idx)
else: else:
y_w_cut_unfold = self.cc_0((y_w_cut_unfold, y_w_cut_unfold = self.cc_0((y_w_cut_unfold,
self.model(x_w_cut_unfold[i * batchsize:(i + 1) * batchsize, :, :, :])))
self.model(x_w_cut_unfold[i * batchsize:(i + 1) * batchsize, :, :, :], \
self.idx)))
y_w_cut_unfold_shape_0 = y_w_cut_unfold.shape[0] y_w_cut_unfold_shape_0 = y_w_cut_unfold.shape[0]
fold_1 = _stride_fold_(padsize * scale, fold_1 = _stride_fold_(padsize * scale,
output_shape=((h - h_cut) * scale, output_shape=((h - h_cut) * scale,
@@ -827,7 +786,6 @@ class IPT_post():


class _stride_unfold_(): class _stride_unfold_():
'''stride''' '''stride'''

def __init__(self, def __init__(self,
kernel_size, kernel_size,
stride=-1): stride=-1):
@@ -874,13 +832,12 @@ class _stride_unfold_():
zeros4 = np.zeros(unf_x[:, :, :, unf_j + self.kernel_size:].shape) zeros4 = np.zeros(unf_x[:, :, :, unf_j + self.kernel_size:].shape)
concat4 = np.concatenate((concat3, zeros4), axis=3) concat4 = np.concatenate((concat3, zeros4), axis=3)
unf_x += concat4 unf_x += concat4
unf_x = Tensor(unf_x, mstype.float32)
unf_x = Tensor(unf_x, mstype.float16)
y = self.unfold(unf_x) y = self.unfold(unf_x)
return y return y


class _stride_fold_(): class _stride_fold_():
'''stride''' '''stride'''

def __init__(self, def __init__(self,
kernel_size, kernel_size,
output_shape=(-1, -1), output_shape=(-1, -1),
@@ -905,7 +862,7 @@ class _stride_fold_():
self.fold = _fold_(self.kernel_size, self.large_shape) self.fold = _fold_(self.kernel_size, self.large_shape)


def compute(self, x): def compute(self, x):
'''stride'''
""" compute"""
NumBlock_x = self.NumBlock_x NumBlock_x = self.NumBlock_x
NumBlock_y = self.NumBlock_y NumBlock_y = self.NumBlock_y
large_x = self.fold(x) large_x = self.fold(x)
@@ -917,7 +874,8 @@ class _stride_fold_():
leftup_idx_x.append(i * self.kernel_size[0]) leftup_idx_x.append(i * self.kernel_size[0])
for i in range(NumBlock_y): for i in range(NumBlock_y):
leftup_idx_y.append(i * self.kernel_size[1]) leftup_idx_y.append(i * self.kernel_size[1])
fold_x = np.zeros((N, C, (NumBlock_x - 1) * self.stride + self.kernel_size[0], (NumBlock_y - 1) * self.stride + self.kernel_size[1]), dtype=np.float32) #pylint: disable=line-too-long
fold_x = np.zeros((N, C, (NumBlock_x - 1) * self.stride + self.kernel_size[0], \
(NumBlock_y - 1) * self.stride + self.kernel_size[1]), dtype=np.float32)
for i in range(NumBlock_x): for i in range(NumBlock_x):
for j in range(NumBlock_y): for j in range(NumBlock_y):
fold_i = i * self.stride fold_i = i * self.stride
@@ -938,12 +896,11 @@ class _stride_fold_():
zeros4 = np.zeros(t4.shape) zeros4 = np.zeros(t4.shape)
concat4 = np.concatenate((concat3, zeros4), axis=3) concat4 = np.concatenate((concat3, zeros4), axis=3)
fold_x += concat4 fold_x += concat4
y = Tensor(fold_x, mstype.float32)
y = Tensor(fold_x, mstype.float16)
return y return y


class _unfold_(nn.Cell): class _unfold_(nn.Cell):
"""ipt""" """ipt"""

def __init__( def __init__(
self, kernel_size, stride=-1): self, kernel_size, stride=-1):


@@ -965,8 +922,10 @@ class _unfold_(nn.Cell):
output_img = self.reshape(x, (N, C, numH, self.kernel_size, W)) output_img = self.reshape(x, (N, C, numH, self.kernel_size, W))


output_img = self.transpose(output_img, (0, 1, 2, 4, 3)) output_img = self.transpose(output_img, (0, 1, 2, 4, 3))
output_img = self.reshape(output_img, (N, C, numH, -1, self.kernel_size, self.kernel_size))
output_img = self.transpose(output_img, (0, 2, 3, 1, 5, 4))
output_img = self.reshape(output_img, (N*C, numH, numW, self.kernel_size, self.kernel_size))
output_img = self.transpose(output_img, (0, 1, 2, 4, 3))
output_img = self.reshape(output_img, (N, C, numH * numW, self.kernel_size*self.kernel_size))
output_img = self.transpose(output_img, (0, 2, 1, 3))
output_img = self.reshape(output_img, (N, numH * numW, -1)) output_img = self.reshape(output_img, (N, numH * numW, -1))
return output_img return output_img


@@ -1002,14 +961,10 @@ class _fold_(nn.Cell):
org_W = self.output_shape[1] org_W = self.output_shape[1]
numH = org_H // self.kernel_size[0] numH = org_H // self.kernel_size[0]
numW = org_W // self.kernel_size[1] numW = org_W // self.kernel_size[1]
output_img = self.reshape(
x, (N, C, org_C, self.kernel_size[0], self.kernel_size[1]))
output_img = self.reshape(x, (N, C, org_C, self.kernel_size[0], self.kernel_size[1]))
output_img = self.transpose(output_img, (0, 2, 3, 1, 4))
output_img = self.reshape(output_img, (N*org_C, self.kernel_size[0], numH, numW, self.kernel_size[1]))
output_img = self.transpose(output_img, (0, 2, 1, 3, 4)) output_img = self.transpose(output_img, (0, 2, 1, 3, 4))
output_img = self.reshape(
output_img, (N, org_C, numH, numW, self.kernel_size[0], self.kernel_size[1]))

output_img = self.transpose(output_img, (0, 1, 2, 4, 3, 5))


output_img = self.reshape(output_img, (N, org_C, org_H, org_W)) output_img = self.reshape(output_img, (N, org_C, org_H, org_W))
return output_img return output_img

+ 125
- 0
model_zoo/research/cv/IPT/src/loss.py View File

@@ -0,0 +1,125 @@
"""loss"""
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import mindspore.nn as nn
from mindspore import dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F

GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 1.0

class SupConLoss(nn.Cell):
"""SupConLoss"""
def __init__(self, temperature=0.07, contrast_mode='all',
base_temperature=0.07):
super(SupConLoss, self).__init__()
self.temperature = temperature
self.contrast_mode = contrast_mode
self.base_temperature = base_temperature
self.normalize = P.L2Normalize(axis=2)
self.eye = P.Eye()
self.unbind = P.Unstack(axis=1)
self.cat = P.Concat(axis=0)
self.matmul = P.MatMul()
self.div = P.Div()
self.transpose = P.Transpose()
self.maxes = P.ArgMaxWithValue(axis=1, keep_dims=True)
self.tile = P.Tile()
self.scatter = P.ScatterNd()
self.oneslike = P.OnesLike()
self.exp = P.Exp()
self.sum = P.ReduceSum(keep_dims=True)
self.log = P.Log()
self.reshape = P.Reshape()
self.mean = P.ReduceMean()

def construct(self, features):
"""SupConLoss"""
features = self.normalize(features)
batch_size = features.shape[0]
mask = self.eye(batch_size, batch_size, mstype.float32)
contrast_count = features.shape[1]
contrast_feature = self.cat(self.unbind(features))
if self.contrast_mode == 'all':
anchor_feature = contrast_feature
anchor_count = contrast_count
else:
anchor_feature = features[:, 0]
anchor_count = 1
anchor_dot_contrast = self.div(self.matmul(anchor_feature, self.transpose(contrast_feature, (1, 0))), \
self.temperature)
_, logits_max = self.maxes(anchor_dot_contrast)
logits = anchor_dot_contrast - logits_max
mask = self.tile(mask, (anchor_count, contrast_count))
logits_mask = 1 - self.eye(mask.shape[0], mask.shape[1], mstype.float32)
mask = mask * logits_mask
exp_logits = self.exp(logits) * logits_mask
log_prob = logits - self.log(self.sum(exp_logits, 1) + 1e-8)
mean_log_prob_pos = self.sum((mask * log_prob), 1) / self.sum(mask, 1)
loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
loss = self.mean(self.reshape(loss, (anchor_count, batch_size)))
return loss, anchor_count


class ClipGradients(nn.Cell):
"""
Clip gradients.

Args:
grads (list): List of gradient tuples.
clip_type (Tensor): The way to clip, 'value' or 'norm'.
clip_value (Tensor): Specifies how much to clip.

Returns:
List, a list of clipped_grad tuples.
"""
def __init__(self):
super(ClipGradients, self).__init__()
self.clip_by_norm = nn.ClipByNorm()
self.cast = P.Cast()
self.dtype = P.DType()

def construct(self, grads, clip_type, clip_value):
"""ClipGradients"""
if clip_type not in (0, 1):
return grads
new_grads = ()
for grad in grads:
dt = self.dtype(grad)
if clip_type == 0:
t = C.clip_by_value(grad, self.cast(F.tuple_to_array((-clip_value,)), dt),
self.cast(F.tuple_to_array((clip_value,)), dt))
else:
t = self.clip_by_norm(grad, self.cast(F.tuple_to_array((clip_value,)), dt))
t = self.cast(t, dt)
new_grads = new_grads + (t,)
return new_grads

grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()

@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
return grad * F.cast(reciprocal(scale), F.dtype(grad))

_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
grad_overflow = P.FloatStatus()

@_grad_overflow.register("Tensor")
def _tensor_grad_overflow(grad):
return grad_overflow(grad)

+ 8
- 9
model_zoo/research/cv/IPT/src/metrics.py View File

@@ -1,4 +1,4 @@
'''metrics'''
"""metrics"""
# Copyright 2021 Huawei Technologies Co., Ltd # Copyright 2021 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,12 +13,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================

import math import math
import numpy as np import numpy as np



def quantize(img, rgb_range): def quantize(img, rgb_range):
'''metrics'''
"""metrics"""
pixel_range = 255 / rgb_range pixel_range = 255 / rgb_range
img = np.multiply(img, pixel_range) img = np.multiply(img, pixel_range)
img = np.clip(img, 0, 255) img = np.clip(img, 0, 255)
@@ -26,15 +26,14 @@ def quantize(img, rgb_range):
return img return img




def calc_psnr(sr, hr, scale, rgb_range, y_only=False, dataset=None):
'''metrics'''
def calc_psnr(sr, hr, scale, rgb_range):
"""metrics"""
hr = np.float32(hr) hr = np.float32(hr)
sr = np.float32(sr) sr = np.float32(sr)
diff = (sr - hr) / rgb_range diff = (sr - hr) / rgb_range
gray_coeffs = np.array([65.738, 129.057, 25.064]
).reshape((1, 3, 1, 1)) / 256
gray_coeffs = np.array([65.738, 129.057, 25.064]).reshape((1, 3, 1, 1)) / 256
diff = np.multiply(diff, gray_coeffs).sum(1) diff = np.multiply(diff, gray_coeffs).sum(1)
if np.size(hr) == 1:
if hr.size == 1:
return 0 return 0
if scale != 1: if scale != 1:
shave = scale shave = scale
@@ -49,7 +48,7 @@ def calc_psnr(sr, hr, scale, rgb_range, y_only=False, dataset=None):




def rgb2ycbcr(img, y_only=True): def rgb2ycbcr(img, y_only=True):
'''metrics'''
"""metrics"""
img.astype(np.float32) img.astype(np.float32)
if y_only: if y_only:
rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0 rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0


+ 0
- 67
model_zoo/research/cv/IPT/src/template.py View File

@@ -1,67 +0,0 @@
'''temp'''
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
def set_template(args):
'''temp'''
if args.template.find('jpeg') >= 0:
args.data_train = 'DIV2K_jpeg'
args.data_test = 'DIV2K_jpeg'
args.epochs = 200
args.decay = '100'
if args.template.find('EDSR_paper') >= 0:
args.model = 'EDSR'
args.n_resblocks = 32
args.n_feats = 256
args.res_scale = 0.1
if args.template.find('MDSR') >= 0:
args.model = 'MDSR'
args.patch_size = 48
args.epochs = 650
if args.template.find('DDBPN') >= 0:
args.model = 'DDBPN'
args.patch_size = 128
args.scale = '4'
args.data_test = 'Set5'
args.batch_size = 20
args.epochs = 1000
args.decay = '500'
args.gamma = 0.1
args.weight_decay = 1e-4
args.loss = '1*MSE'
if args.template.find('GAN') >= 0:
args.epochs = 200
args.lr = 5e-5
args.decay = '150'
if args.template.find('RCAN') >= 0:
args.model = 'RCAN'
args.n_resgroups = 10
args.n_resblocks = 20
args.n_feats = 64
args.chop = True
if args.template.find('VDSR') >= 0:
args.model = 'VDSR'
args.n_resblocks = 20
args.n_feats = 64
args.patch_size = 41
args.lr = 1e-1

+ 173
- 0
model_zoo/research/cv/IPT/src/utils.py View File

@@ -0,0 +1,173 @@
"""utils"""
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import os
import time
from bisect import bisect_right
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.context import ParallelMode
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.parallel._utils import _get_parallel_mode
from mindspore.train.serialization import save_checkpoint
from src.loss import SupConLoss


class MyTrain(nn.Cell):
"""MyTrain"""
def __init__(self, model, criterion, con_loss, use_con=True):
super(MyTrain, self).__init__(auto_prefix=True)
self.use_con = use_con
self.model = model
self.con_loss = con_loss
self.criterion = criterion
self.p = P.Print()
self.cast = P.Cast()

def construct(self, lr, hr, idx):
"""MyTrain"""
if self.use_con:
sr, x_con = self.model(lr, idx)
x_con = self.cast(x_con, mstype.float32)
sr = self.cast(sr, mstype.float32)
loss1 = self.criterion(sr, hr)
loss2 = self.con_loss(x_con)
loss = loss1 + 0.1 * loss2
else:
sr = self.model(lr, idx)
sr = self.cast(sr, mstype.float32)
loss = self.criterion(sr, hr)
return loss


class MyTrainOneStepCell(nn.Cell):
"""MyTrainOneStepCell"""
def __init__(self, network, optimizer, sens=1.0):
super(MyTrainOneStepCell, self).__init__(auto_prefix=False)
self.network = network
self.network.set_grad()
self.weights = optimizer.parameters
self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.sens = sens
self.reducer_flag = False
self.grad_reducer = None
parallel_mode = _get_parallel_mode()
if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
self.reducer_flag = True
if self.reducer_flag:
self.grad_reducer = DistributedGradReducer(optimizer.parameters, True, 8)

def construct(self, *args):
weights = self.weights
loss = self.network(*args)
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
grads = self.grad(self.network, weights)(*args, sens)
if self.reducer_flag:
grads = self.grad_reducer(grads)
return F.depend(loss, self.optimizer(grads))


def sub_mean(x):
red_channel_mean = 0.4488 * 255
green_channel_mean = 0.4371 * 255
blue_channel_mean = 0.4040 * 255
x[:, 0, :, :] -= red_channel_mean
x[:, 1, :, :] -= green_channel_mean
x[:, 2, :, :] -= blue_channel_mean
return x


def add_mean(x):
red_channel_mean = 0.4488 * 255
green_channel_mean = 0.4371 * 255
blue_channel_mean = 0.4040 * 255
x[:, 0, :, :] += red_channel_mean
x[:, 1, :, :] += green_channel_mean
x[:, 2, :, :] += blue_channel_mean
return x


class Trainer():
"""Trainer"""
def __init__(self, args, loader, my_model):
self.args = args
self.scale = args.scale
self.trainloader = loader
self.model = my_model
self.model.set_train()
self.criterion = nn.L1Loss()
self.con_loss = SupConLoss()
self.optimizer = nn.Adam(self.model.trainable_params(), learning_rate=args.lr, loss_scale=1024.0)
self.train_net = MyTrain(self.model, self.criterion, self.con_loss, use_con=args.con_loss)
self.bp = MyTrainOneStepCell(self.train_net, self.optimizer, 1024.0)

def train(self):
"""Trainer"""
losses = 0
for batch_idx, imgs in enumerate(self.trainloader):
lr = imgs["LR"]
hr = imgs["HR"]
lr = Tensor(sub_mean(lr), mstype.float32)
hr = Tensor(sub_mean(hr), mstype.float32)
idx = Tensor(np.ones(imgs["idx"][0]), mstype.int32)
t1 = time.time()
loss = self.bp(lr, hr, idx)
t2 = time.time()
losses += loss.asnumpy()
print('Task: %g, Step: %g, loss: %f, time: %f s' % (idx.shape[0], batch_idx, loss.asnumpy(), t2 - t1),
flush=True)
os.makedirs(self.args.save, exist_ok=True)
if self.args.rank == 0:
save_checkpoint(self.bp, self.args.save + "model_" + str(self.epoch) + '.ckpt')

def update_learning_rate(self, epoch):
"""Update learning rates for all the networks; called at the end of every epoch.
:param epoch: current epoch
:type epoch: int
:param lr: learning rate of cyclegan
:type lr: float
:param niter: number of epochs with the initial learning rate
:type niter: int
:param niter_decay: number of epochs to linearly decay learning rate to zero
:type niter_decay: int
"""
self.epoch = epoch
value = self.args.decay.split('-')
value.sort(key=int)
milestones = list(map(int, value))
print("*********** epoch: {} **********".format(epoch))
lr = self.args.lr * self.args.gamma ** bisect_right(milestones, epoch)
self.adjust_lr('model', self.optimizer, lr)
print("*********************************")

def adjust_lr(self, name, optimizer, lr):
"""Adjust learning rate for the corresponding model.
:param name: name of model
:type name: str
:param optimizer: the optimizer of the corresponding model
:type optimizer: torch.optim
:param lr: learning rate to be adjusted
:type lr: float
"""
lr_param = optimizer.get_lr()
lr_param.assign_value(Tensor(lr, mstype.float32))
print('==> ' + name + ' learning rate: ', lr_param.asnumpy())

+ 154
- 0
model_zoo/research/cv/IPT/train.py View File

@@ -0,0 +1,154 @@
"""train"""
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import os
import math
import mindspore.dataset as ds
from mindspore import Parameter, set_seed, context
from mindspore.context import ParallelMode
from mindspore.common.initializer import initializer, HeUniform, XavierUniform, Uniform, Normal, Zero
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.args import args
from src.data.bicubic import bicubic
from src.data.imagenet import ImgData
from src.ipt_model import IPT
from src.utils import Trainer


def _calculate_fan_in_and_fan_out(shape):
"""
calculate fan_in and fan_out

Args:
shape (tuple): input shape.

Returns:
Tuple, a tuple with two elements, the first element is `n_in` and the second element is `n_out`.
"""
dimensions = len(shape)
if dimensions < 2:
raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")
if dimensions == 2:
fan_in = shape[1]
fan_out = shape[0]
else:
num_input_fmaps = shape[1]
num_output_fmaps = shape[0]
receptive_field_size = 1
if dimensions > 2:
receptive_field_size = shape[2] * shape[3]
fan_in = num_input_fmaps * receptive_field_size
fan_out = num_output_fmaps * receptive_field_size
return fan_in, fan_out

def init_weights(net, init_type='normal', init_gain=0.02):
"""
Initialize network weights.

:param net: network to be initialized
:type net: nn.Module
:param init_type: the name of an initialization method: normal | xavier | kaiming | orthogonal
:type init_type: str
:param init_gain: scaling factor for normal, xavier and orthogonal.
:type init_gain: float
"""

for _, cell in net.cells_and_names():
classname = cell.__class__.__name__
if hasattr(cell, 'in_proj_layer'):
cell.in_proj_layer = Parameter(initializer(HeUniform(negative_slope=math.sqrt(5)), cell.in_proj_layer.shape,
cell.in_proj_layer.dtype), name=cell.in_proj_layer.name)
if hasattr(cell, 'weight'):
if init_type == 'normal':
cell.weight = Parameter(initializer(Normal(init_gain), cell.weight.shape,
cell.weight.dtype), name=cell.weight.name)
elif init_type == 'xavier':
cell.weight = Parameter(initializer(XavierUniform(init_gain), cell.weight.shape,
cell.weight.dtype), name=cell.weight.name)
elif init_type == "he":
cell.weight = Parameter(initializer(HeUniform(negative_slope=math.sqrt(5)), cell.weight.shape,
cell.weight.dtype), name=cell.weight.name)
else:
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)

if hasattr(cell, 'bias') and cell.bias is not None:
fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight.shape)
bound = 1 / math.sqrt(fan_in)
cell.bias = Parameter(initializer(Uniform(bound), cell.bias.shape, cell.bias.dtype),
name=cell.bias.name)
elif classname.find('BatchNorm2d') != -1:
cell.gamma = Parameter(initializer(Normal(1.0), cell.gamma.default_input.shape()), name=cell.gamma.name)
cell.beta = Parameter(initializer(Zero(), cell.beta.default_input.shape()), name=cell.beta.name)

print('initialize network weight with %s' % init_type)

def train_net(distribute, imagenet, epochs):
"""Train net"""
set_seed(1)
device_id = int(os.getenv('DEVICE_ID', '0'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)

if imagenet == 1:
train_dataset = ImgData(args)
else:
train_dataset = data.Data(args).loader_train

if distribute:
init()
rank_id = get_rank()
rank_size = get_group_size()
parallel_mode = ParallelMode.DATA_PARALLEL
context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=rank_size, gradients_mean=True)
print('Rank {}, rank_size {}'.format(rank_id, rank_size))
if imagenet == 1:
train_de_dataset = ds.GeneratorDataset(train_dataset,
["HR", "Rain", "LRx2", "LRx3", "LRx4", "scales", "filename"],
num_shards=rank_size, shard_id=args.rank, shuffle=True)
else:
train_de_dataset = ds.GeneratorDataset(train_dataset, ["LR", "HR"], num_shards=rank_size,
shard_id=rank_id, shuffle=True)
else:
if imagenet == 1:
train_de_dataset = ds.GeneratorDataset(train_dataset,
["HR", "Rain", "LRx2", "LRx3", "LRx4", "scales", "filename"],
shuffle=True)
else:
train_de_dataset = ds.GeneratorDataset(train_dataset, ["LR", "HR"], shuffle=True)

resize_fuc = bicubic()
train_de_dataset = train_de_dataset.project(columns=["HR", "Rain", "LRx2", "LRx3", "LRx4", "filename"])
train_de_dataset = train_de_dataset.batch(args.batch_size,
input_columns=["HR", "Rain", "LRx2", "LRx3", "LRx4", "filename"],
output_columns=["LR", "HR", "idx", "filename"],
drop_remainder=True, per_batch_map=resize_fuc.forward)
train_loader = train_de_dataset.create_dict_iterator(output_numpy=True)

net_work = IPT(args)
init_weights(net_work, init_type='he', init_gain=1.0)
print("Init net weight successfully")
if args.pth_path:
param_dict = load_checkpoint(args.pth_path)
load_param_into_net(net_work, param_dict)
print("Load net weight successfully")

train_func = Trainer(args, train_loader, net_work)
for epoch in range(0, epochs):
train_func.update_learning_rate(epoch)
train_func.train()

if __name__ == '__main__':
train_net(distribute=args.distribute, imagenet=args.imagenet, epochs=args.epochs)

+ 95
- 0
model_zoo/research/cv/IPT/train_finetune.py View File

@@ -0,0 +1,95 @@
"""train finetune"""
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import os
from mindspore import context
from mindspore.context import ParallelMode
import mindspore.dataset as ds
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.common import set_seed
from src.args import args
from src.data.imagenet import ImgData
from src.data.srdata import SRData
from src.data.div2k import DIV2K
from src.data.bicubic import bicubic
from src.ipt_model import IPT
from src.utils import Trainer

def train_net(distribute, imagenet):
"""Train net with finetune"""
set_seed(1)
device_id = int(os.getenv('DEVICE_ID', '0'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)

if imagenet == 1:
train_dataset = ImgData(args)
elif not args.derain:
train_dataset = DIV2K(args, name=args.data_train, train=True, benchmark=False)
train_dataset.set_scale(args.task_id)
else:
train_dataset = SRData(args, name=args.data_train, train=True, benchmark=False)
train_dataset.set_scale(args.task_id)

if distribute:
init()
rank_id = get_rank()
rank_size = get_group_size()
parallel_mode = ParallelMode.DATA_PARALLEL
context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=rank_size, gradients_mean=True)
print('Rank {}, group_size {}'.format(rank_id, rank_size))
if imagenet == 1:
train_de_dataset = ds.GeneratorDataset(train_dataset,
["HR", "Rain", "LRx2", "LRx3", "LRx4", "scales", "filename"],
num_shards=rank_size, shard_id=rank_id, shuffle=True)
else:
train_de_dataset = ds.GeneratorDataset(train_dataset, ["LR", "HR", "idx", "filename"],
num_shards=rank_size, shard_id=rank_id, shuffle=True)
else:
if imagenet == 1:
train_de_dataset = ds.GeneratorDataset(train_dataset,
["HR", "Rain", "LRx2", "LRx3", "LRx4", "scales", "filename"],
shuffle=True)
else:
train_de_dataset = ds.GeneratorDataset(train_dataset, ["LR", "HR", "idx", "filename"], shuffle=True)

if args.imagenet == 1:
resize_fuc = bicubic()
train_de_dataset = train_de_dataset.batch(
args.batch_size,
input_columns=["HR", "Rain", "LRx2", "LRx3", "LRx4", "scales", "filename"],
output_columns=["LR", "HR", "idx", "filename"], drop_remainder=True,
per_batch_map=resize_fuc.forward)
else:
train_de_dataset = train_de_dataset.batch(args.batch_size, drop_remainder=True)

train_loader = train_de_dataset.create_dict_iterator(output_numpy=True)
net_m = IPT(args)
print("Init net weights successfully")

if args.pth_path:
param_dict = load_checkpoint(args.pth_path)
load_param_into_net(net_m, param_dict)
print("Load net weight successfully")

train_func = Trainer(args, train_loader, net_m)

for epoch in range(0, args.epochs):
train_func.update_learning_rate(epoch)
train_func.train()

if __name__ == "__main__":
train_net(distribute=args.distribute, imagenet=args.imagenet)

Loading…
Cancel
Save