You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

metric.py 2.8 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Metric for accuracy evaluation."""
  16. from mindspore import nn
# Label index treated as the blank symbol: it is skipped when decoding
# predictions and used to pad decoded sequences up to the target length.
BLANK_LABLE = 10
  18. class WarpCTCAccuracy(nn.Metric):
  19. """
  20. Define accuracy metric for warpctc network.
  21. """
  22. def __init__(self):
  23. super(WarpCTCAccuracy).__init__()
  24. self._correct_num = 0
  25. self._total_num = 0
  26. self._count = 0
  27. def clear(self):
  28. self._correct_num = 0
  29. self._total_num = 0
  30. def update(self, *inputs):
  31. if len(inputs) != 2:
  32. raise ValueError('WarpCTCAccuracy need 2 inputs (y_pred, y), but got {}'.format(len(inputs)))
  33. y_pred = self._convert_data(inputs[0])
  34. y = self._convert_data(inputs[1])
  35. self._count += 1
  36. pred_lbls = self._get_prediction(y_pred)
  37. for b_idx, target in enumerate(y):
  38. if self._is_eq(pred_lbls[b_idx], target):
  39. self._correct_num += 1
  40. self._total_num += 1
  41. def eval(self):
  42. if self._total_num == 0:
  43. raise RuntimeError('Accuary can not be calculated, because the number of samples is 0.')
  44. return self._correct_num / self._total_num
  45. @staticmethod
  46. def _is_eq(pred_lbl, target):
  47. """
  48. check whether predict label is equal to target label
  49. """
  50. target = target.tolist()
  51. pred_diff = len(target) - len(pred_lbl)
  52. if pred_diff > 0:
  53. # padding by BLANK_LABLE
  54. pred_lbl.extend([BLANK_LABLE] * pred_diff)
  55. return pred_lbl == target
  56. @staticmethod
  57. def _get_prediction(y_pred):
  58. """
  59. parse predict result to labels
  60. """
  61. seq_len, batch_size, _ = y_pred.shape
  62. indices = y_pred.argmax(axis=2)
  63. lens = [seq_len] * batch_size
  64. pred_lbls = []
  65. for i in range(batch_size):
  66. idx = indices[:, i]
  67. last_idx = BLANK_LABLE
  68. pred_lbl = []
  69. for j in range(lens[i]):
  70. cur_idx = idx[j]
  71. if cur_idx not in [last_idx, BLANK_LABLE]:
  72. pred_lbl.append(cur_idx)
  73. last_idx = cur_idx
  74. pred_lbls.append(pred_lbl)
  75. return pred_lbls