You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

miou_precision.py 3.0 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """mIou."""
  16. import numpy as np
  17. from mindspore.nn.metrics.metric import Metric
  18. def confuse_matrix(target, pred, n):
  19. k = (target >= 0) & (target < n)
  20. return np.bincount(n * target[k].astype(int) + pred[k], minlength=n ** 2).reshape(n, n)
  21. def iou(hist):
  22. denominator = hist.sum(1) + hist.sum(0) - np.diag(hist)
  23. res = np.diag(hist) / np.where(denominator > 0, denominator, 1)
  24. res = np.sum(res) / np.count_nonzero(denominator)
  25. return res
  26. class MiouPrecision(Metric):
  27. """Calculate miou precision."""
  28. def __init__(self, num_class=21):
  29. super(MiouPrecision, self).__init__()
  30. if not isinstance(num_class, int):
  31. raise TypeError('num_class should be integer type, but got {}'.format(type(num_class)))
  32. if num_class < 1:
  33. raise ValueError('num_class must be at least 1, but got {}'.format(num_class))
  34. self._num_class = num_class
  35. self._mIoU = []
  36. self.clear()
  37. def clear(self):
  38. self._hist = np.zeros((self._num_class, self._num_class))
  39. self._mIoU = []
  40. def update(self, *inputs):
  41. if len(inputs) != 2:
  42. raise ValueError('Need 2 inputs (y_pred, y), but got {}'.format(len(inputs)))
  43. predict_in = self._convert_data(inputs[0])
  44. label_in = self._convert_data(inputs[1])
  45. if predict_in.shape[1] != self._num_class:
  46. raise ValueError('Class number not match, last input data contain {} classes, but current data contain {} '
  47. 'classes'.format(self._num_class, predict_in.shape[1]))
  48. pred = np.argmax(predict_in, axis=1)
  49. label = label_in
  50. if len(label.flatten()) != len(pred.flatten()):
  51. print('Skipping: len(gt) = {:d}, len(pred) = {:d}'.format(len(label.flatten()), len(pred.flatten())))
  52. raise ValueError('Class number not match, last input data contain {} classes, but current data contain {} '
  53. 'classes'.format(self._num_class, predict_in.shape[1]))
  54. self._hist = confuse_matrix(label.flatten(), pred.flatten(), self._num_class)
  55. mIoUs = iou(self._hist)
  56. self._mIoU.append(mIoUs)
  57. def eval(self):
  58. """
  59. Computes the mIoU categorical accuracy.
  60. """
  61. mIoU = np.nanmean(self._mIoU)
  62. print('mIoU = {}'.format(mIoU))
  63. return mIoU