|
- """eval script"""
- # Copyright 2021 Huawei Technologies Co., Ltd
-
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- import numpy as np
- from src import ipt
- from src.args import args
- from src.data.srdata import SRData
- from src.metrics import calc_psnr, quantize
-
- from mindspore import context
- import mindspore.dataset as de
- from mindspore.train.serialization import load_checkpoint, load_param_into_net
-
- context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU", device_id=0)
-
-
def _normalize_bool_args(namespace):
    """Replace the literal strings 'True'/'False' on *namespace* with real booleans, in place.

    Flags such as ``args.denoise`` are later tested for truthiness, and the
    string 'False' would otherwise evaluate as true.
    """
    attrs = vars(namespace)
    for name, value in list(attrs.items()):
        if value == 'True':
            attrs[name] = True
        elif value == 'False':
            attrs[name] = False


def _build_eval_loader():
    """Build the unshuffled, batch-size-1 evaluation dataset and a dict iterator over it.

    Returns:
        tuple: ``(dataset, loader)``. ``dataset`` is the batched MindSpore
        dataset (used for its size). ``loader`` yields dicts keyed by
        'LR' and 'HR'.
    """
    eval_dataset = SRData(args, name=args.data_test, train=False, benchmark=False)
    eval_de_dataset = de.GeneratorDataset(eval_dataset, ['LR', "HR"], shuffle=False)
    eval_de_dataset = eval_de_dataset.batch(1, drop_remainder=True)
    return eval_de_dataset, eval_de_dataset.create_dict_iterator()


def _report_mean_psnr(mean_psnr):
    """Print *mean_psnr* (a scalar) with a label matching the evaluated task."""
    if args.denoise:
        print('Mean psnr of %s DN_%s is %.4f' % (args.data_test[0], args.sigma, mean_psnr))
    elif args.derain:
        print('Mean psnr of Derain is %.4f' % mean_psnr)
    else:
        print('Mean psnr of %s x%s is %.4f' % (args.data_test[0], args.scale[0], mean_psnr))


def main():
    """Evaluate the IPT network on ``args.data_test`` and print the mean PSNR.

    Weights are loaded from ``args.pth_path`` when it is set. Otherwise the
    freshly constructed network is evaluated as-is.
    """
    _normalize_bool_args(args)
    eval_de_dataset, eval_loader = _build_eval_loader()

    net_m = ipt.IPT(args)
    print('load mindspore net successfully.')
    if args.pth_path:
        param_dict = load_checkpoint(args.pth_path)
        load_param_into_net(net_m, param_dict)
    net_m.set_train(False)

    num_imgs = eval_de_dataset.get_dataset_size()
    psnrs = np.zeros((num_imgs, 1))
    for batch_idx, imgs in enumerate(eval_loader):
        hr_np = np.float32(imgs['HR'].asnumpy())
        pred_np = np.float32(net_m.infrc(imgs['LR']).asnumpy())
        pred_np = quantize(pred_np, 255)
        # NOTE(review): the third argument is hard-coded to 4 regardless of args.scale;
        # presumably a scale/border parameter -- confirm against src.metrics.calc_psnr.
        psnrs[batch_idx, 0] = calc_psnr(pred_np, hr_np, 4, 255.0, y_only=True)

    # Index [0] yields a Python scalar for every branch; formatting the raw
    # shape-(1,) array with %.4f relies on deprecated array-to-scalar conversion.
    _report_mean_psnr(psnrs.mean(axis=0)[0])
-
-
if __name__ == '__main__':
    # Script entry point: announce start, then run the evaluation in main().
    print("Start main function!")
    main()
|