You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

metric.py 2.9 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Metric for accuracy evaluation."""
  16. from mindspore import nn
  17. class WarpCTCAccuracy(nn.Metric):
  18. """
  19. Define accuracy metric for warpctc network.
  20. """
  21. def __init__(self, device_target='Ascend'):
  22. super(WarpCTCAccuracy).__init__()
  23. self._correct_num = 0
  24. self._total_num = 0
  25. self._count = 0
  26. self.device_target = device_target
  27. self.blank = 10
  28. def clear(self):
  29. self._correct_num = 0
  30. self._total_num = 0
  31. def update(self, *inputs):
  32. if len(inputs) != 2:
  33. raise ValueError('WarpCTCAccuracy need 2 inputs (y_pred, y), but got {}'.format(len(inputs)))
  34. y_pred = self._convert_data(inputs[0])
  35. y = self._convert_data(inputs[1])
  36. self._count += 1
  37. pred_lbls = self._get_prediction(y_pred)
  38. for b_idx, target in enumerate(y):
  39. if self._is_eq(pred_lbls[b_idx], target):
  40. self._correct_num += 1
  41. self._total_num += 1
  42. def eval(self):
  43. if self._total_num == 0:
  44. raise RuntimeError('Accuracy can not be calculated, because the number of samples is 0.')
  45. return self._correct_num / self._total_num
  46. def _is_eq(self, pred_lbl, target):
  47. """
  48. check whether predict label is equal to target label
  49. """
  50. target = target.tolist()
  51. pred_diff = len(target) - len(pred_lbl)
  52. if pred_diff > 0:
  53. # padding by BLANK_LABLE
  54. pred_lbl.extend([self.blank] * pred_diff)
  55. return pred_lbl == target
  56. def _get_prediction(self, y_pred):
  57. """
  58. parse predict result to labels
  59. """
  60. seq_len, batch_size, _ = y_pred.shape
  61. indices = y_pred.argmax(axis=2)
  62. lens = [seq_len] * batch_size
  63. pred_lbls = []
  64. for i in range(batch_size):
  65. idx = indices[:, i]
  66. last_idx = self.blank
  67. pred_lbl = []
  68. for j in range(lens[i]):
  69. cur_idx = idx[j]
  70. if cur_idx not in [last_idx, self.blank]:
  71. pred_lbl.append(cur_idx)
  72. last_idx = cur_idx
  73. pred_lbls.append(pred_lbl)
  74. return pred_lbls