You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

accuracy.py 3.0 kB

2 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import mmcv
  3. import torch.nn as nn
  4. @mmcv.jit(coderize=True)
  5. def accuracy(pred, target, topk=1, thresh=None):
  6. """Calculate accuracy according to the prediction and target.
  7. Args:
  8. pred (torch.Tensor): The model prediction, shape (N, num_class)
  9. target (torch.Tensor): The target of each prediction, shape (N, )
  10. topk (int | tuple[int], optional): If the predictions in ``topk``
  11. matches the target, the predictions will be regarded as
  12. correct ones. Defaults to 1.
  13. thresh (float, optional): If not None, predictions with scores under
  14. this threshold are considered incorrect. Default to None.
  15. Returns:
  16. float | tuple[float]: If the input ``topk`` is a single integer,
  17. the function will return a single float as accuracy. If
  18. ``topk`` is a tuple containing multiple integers, the
  19. function will return a tuple containing accuracies of
  20. each ``topk`` number.
  21. """
  22. assert isinstance(topk, (int, tuple))
  23. if isinstance(topk, int):
  24. topk = (topk, )
  25. return_single = True
  26. else:
  27. return_single = False
  28. maxk = max(topk)
  29. if pred.size(0) == 0:
  30. accu = [pred.new_tensor(0.) for i in range(len(topk))]
  31. return accu[0] if return_single else accu
  32. assert pred.ndim == 2 and target.ndim == 1
  33. assert pred.size(0) == target.size(0)
  34. assert maxk <= pred.size(1), \
  35. f'maxk {maxk} exceeds pred dimension {pred.size(1)}'
  36. pred_value, pred_label = pred.topk(maxk, dim=1)
  37. pred_label = pred_label.t() # transpose to shape (maxk, N)
  38. correct = pred_label.eq(target.view(1, -1).expand_as(pred_label))
  39. if thresh is not None:
  40. # Only prediction values larger than thresh are counted as correct
  41. correct = correct & (pred_value > thresh).t()
  42. res = []
  43. for k in topk:
  44. correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
  45. res.append(correct_k.mul_(100.0 / pred.size(0)))
  46. return res[0] if return_single else res
  47. class Accuracy(nn.Module):
  48. def __init__(self, topk=(1, ), thresh=None):
  49. """Module to calculate the accuracy.
  50. Args:
  51. topk (tuple, optional): The criterion used to calculate the
  52. accuracy. Defaults to (1,).
  53. thresh (float, optional): If not None, predictions with scores
  54. under this threshold are considered incorrect. Default to None.
  55. """
  56. super().__init__()
  57. self.topk = topk
  58. self.thresh = thresh
  59. def forward(self, pred, target):
  60. """Forward function to calculate accuracy.
  61. Args:
  62. pred (torch.Tensor): Prediction of models.
  63. target (torch.Tensor): Target for each prediction.
  64. Returns:
  65. tuple[float]: The accuracies under different topk criterions.
  66. """
  67. return accuracy(pred, target, self.topk, self.thresh)

No Description

Contributors (3)