You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

metrics.py 1.7 kB

4 years ago
4 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. '''metrics'''
  2. # Copyright 2021 Huawei Technologies Co., Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. # ============================================================================
  16. import math
  17. import numpy as np
  18. def quantize(img, rgb_range):
  19. '''metrics'''
  20. pixel_range = 255 / rgb_range
  21. img = np.multiply(img, pixel_range)
  22. img = np.clip(img, 0, 255)
  23. img = np.round(img) / pixel_range
  24. return img
  25. def calc_psnr(sr, hr, scale, rgb_range, y_only=False, dataset=None):
  26. '''metrics'''
  27. hr = np.float32(hr)
  28. sr = np.float32(sr)
  29. diff = (sr - hr) / rgb_range
  30. gray_coeffs = np.array([65.738, 129.057, 25.064]
  31. ).reshape((1, 3, 1, 1)) / 256
  32. diff = np.multiply(diff, gray_coeffs).sum(1)
  33. if np.size(hr) == 1:
  34. return 0
  35. if scale != 1:
  36. shave = scale
  37. else:
  38. shave = scale + 6
  39. if scale == 1:
  40. valid = diff
  41. else:
  42. valid = diff[..., shave:-shave, shave:-shave]
  43. mse = np.mean(pow(valid, 2))
  44. return -10 * math.log10(mse)
  45. def rgb2ycbcr(img, y_only=True):
  46. '''metrics'''
  47. img.astype(np.float32)
  48. if y_only:
  49. rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
  50. return rlt