You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

eval.py 3.8 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. """eval script"""
  2. # Copyright 2021 Huawei Technologies Co., Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. # ============================================================================
  16. import os
  17. import numpy as np
  18. import mindspore.dataset as ds
  19. from mindspore import Tensor, context
  20. from mindspore.common import dtype as mstype
  21. from mindspore.train.serialization import load_checkpoint, load_param_into_net
  22. from src.args import args
  23. import src.ipt_model as ipt
  24. from src.data.srdata import SRData
  25. from src.metrics import calc_psnr, quantize
  # Device index comes from the DEVICE_ID environment variable; "0" when unset.
  device_id = int(os.environ.get('DEVICE_ID', '0'))
  # Run in graph mode on an Ascend device without dumping compiled graphs.
  context.set_context(
      mode=context.GRAPH_MODE,
      device_target="Ascend",
      device_id=device_id,
      save_graphs=False,
  )
  # Raise MindSpore's maximum call depth to 10000 (presumably required by the
  # depth of the network graph — confirm if this value is ever tuned).
  context.set_context(max_call_depth=10000)
  def sub_mean(x):
      """Subtract the fixed per-channel mean from ``x`` in place and return it.

      Channel 0 of axis 1 is shifted by ``0.4488 * 255``, channel 1 by
      ``0.4371 * 255`` and channel 2 by ``0.4040 * 255`` (the original code
      labels these red, green and blue). Any channels beyond index 2 are left
      untouched.

      Args:
          x: array indexed as ``x[:, channel, :, :]`` — presumably a NumPy
              batch laid out as (batch, channel, height, width); TODO confirm.

      Returns:
          The same object ``x``, mutated in place.
      """
      channel_means = (0.4488 * 255, 0.4371 * 255, 0.4040 * 255)
      for channel, mean in enumerate(channel_means):
          x[:, channel, :, :] -= mean
      return x
  def add_mean(x):
      """Add the fixed per-channel mean back onto ``x`` in place and return it.

      This is the inverse shift of ``sub_mean``: channel 0 of axis 1 gains
      ``0.4488 * 255``, channel 1 ``0.4371 * 255`` and channel 2
      ``0.4040 * 255`` (labelled red, green and blue in the original code).
      Channels beyond index 2 are not modified.

      Args:
          x: array indexed as ``x[:, channel, :, :]`` — presumably a NumPy
              batch laid out as (batch, channel, height, width); TODO confirm.

      Returns:
          The same object ``x``, mutated in place.
      """
      offsets = (0.4488 * 255, 0.4371 * 255, 0.4040 * 255)
      for idx, offset in zip(range(len(offsets)), offsets):
          x[:, idx, :, :] += offset
      return x
  def _override_eval_args():
      """Force the evaluation hyper-parameters on the global ``args`` and turn 'True'/'False' strings into bools."""
      args.batch_size = 128
      args.decay = 70
      args.patch_size = 48
      args.num_queries = 6
      args.model = 'vtip'
      args.num_layers = 4
      # epochs is not read anywhere in this evaluation; 1e8 presumably means
      # "unbounded" for code elsewhere that consumes args — kept for parity.
      if args.epochs == 0:
          args.epochs = 1e8
      attrs = vars(args)
      # Only existing keys are reassigned, so mutating during iteration is safe.
      for name, value in attrs.items():
          if not isinstance(value, str):
              continue
          if value == 'True':
              attrs[name] = True
          elif value == 'False':
              attrs[name] = False


  def _build_test_loader():
      """Return ``(numpy dict iterator, number of batches)`` over ``args.data_test`` with batch size 1."""
      test_data = SRData(args, name=args.data_test, train=False, benchmark=False)
      test_set = ds.GeneratorDataset(test_data, ['LR', 'HR', "idx", "filename"], shuffle=False)
      test_set = test_set.batch(1, drop_remainder=True)
      return test_set.create_dict_iterator(output_numpy=True), test_set.get_dataset_size()


  def _load_network():
      """Build ``ipt.IPT(args)``, load ``args.pth_path`` into it when set, and switch it to inference mode."""
      net = ipt.IPT(args)
      if args.pth_path:
          load_param_into_net(net, load_checkpoint(args.pth_path))
      net.set_train(False)
      return net


  def _report_mean_psnr(mean_psnr):
      """Print the dataset-level mean PSNR with a label chosen by the active task flag."""
      if args.denoise:
          print('Mean psnr of %s DN_%s is %.4f' % (args.data_test[0], args.sigma, mean_psnr))
      elif args.derain:
          print('Mean psnr of Derain is %.4f' % mean_psnr)
      else:
          print('Mean psnr of %s x%s is %.4f' % (args.data_test[0], args.scale[0], mean_psnr))


  def eval_net():
      """Evaluate the configured model on ``args.data_test`` and print PSNR.

      Overrides several fields of the global ``args``. Builds a batch-size-1
      test pipeline from ``SRData(train=False)``, optionally loads the
      checkpoint at ``args.pth_path``, and runs ``ipt.IPT_post.forward`` on
      every image. Prints the PSNR of each image, then the mean PSNR with a
      denoise / derain / super-resolution label.
      """
      _override_eval_args()
      loader, num_imgs = _build_test_loader()
      net_m = _load_network()
      # NOTE(review): np.ones(args.task_id) is a vector of `task_id` ones (empty
      # when task_id == 0), not a scalar holding task_id. Left unchanged because
      # what IPT_post.forward expects for `idx` is not visible here — confirm.
      idx = Tensor(np.ones(args.task_id), mstype.int32)
      inference = ipt.IPT_post(net_m, args)
      print('load mindspore net successfully.')

      if not num_imgs:
          # Averaging an empty array would print "nan" with a RuntimeWarning.
          print('No test images found; skipping PSNR evaluation.')
          return

      psnrs = np.zeros((num_imgs, 1))
      for batch_idx, imgs in enumerate(loader):
          # sub_mean/add_mean mutate their argument; both arrays here are fresh.
          lr = Tensor(sub_mean(imgs['LR']), mstype.float32)
          pred = inference.forward(lr, idx)
          pred_np = quantize(add_mean(pred.asnumpy()), 255)
          # NOTE(review): the third calc_psnr argument is hard-coded to 4 rather
          # than derived from args.scale / the task — confirm this is intended.
          psnr = calc_psnr(pred_np, imgs['HR'], 4, 255.0)
          print("current psnr: ", psnr)
          psnrs[batch_idx, 0] = psnr

      # Reduce to a scalar once so every label branch formats a plain number;
      # the derain branch used to pass a shape-(1,) array to "%.4f".
      _report_mean_psnr(psnrs.mean(axis=0)[0])
  if __name__ == "__main__":
      # Script entry point: announce the start, then run the full evaluation.
      print("Start eval function!")
      eval_net()