|
- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """Metric for accuracy evaluation."""
- from mindspore import nn
-
-
- class WarpCTCAccuracy(nn.Metric):
- """
- Define accuracy metric for warpctc network.
- """
-
- def __init__(self, device_target='Ascend'):
- super(WarpCTCAccuracy).__init__()
- self._correct_num = 0
- self._total_num = 0
- self._count = 0
- self.device_target = device_target
- self.blank = 10
-
- def clear(self):
- self._correct_num = 0
- self._total_num = 0
-
- def update(self, *inputs):
- if len(inputs) != 2:
- raise ValueError('WarpCTCAccuracy need 2 inputs (y_pred, y), but got {}'.format(len(inputs)))
-
- y_pred = self._convert_data(inputs[0])
- y = self._convert_data(inputs[1])
- self._count += 1
-
- pred_lbls = self._get_prediction(y_pred)
-
- for b_idx, target in enumerate(y):
- if self._is_eq(pred_lbls[b_idx], target):
- self._correct_num += 1
- self._total_num += 1
-
- def eval(self):
- if self._total_num == 0:
- raise RuntimeError('Accuracy can not be calculated, because the number of samples is 0.')
- return self._correct_num / self._total_num
-
- def _is_eq(self, pred_lbl, target):
- """
- check whether predict label is equal to target label
- """
- target = target.tolist()
- pred_diff = len(target) - len(pred_lbl)
- if pred_diff > 0:
- # padding by BLANK_LABLE
- pred_lbl.extend([self.blank] * pred_diff)
- return pred_lbl == target
-
- def _get_prediction(self, y_pred):
- """
- parse predict result to labels
- """
- seq_len, batch_size, _ = y_pred.shape
- indices = y_pred.argmax(axis=2)
-
- lens = [seq_len] * batch_size
- pred_lbls = []
- for i in range(batch_size):
- idx = indices[:, i]
- last_idx = self.blank
- pred_lbl = []
- for j in range(lens[i]):
- cur_idx = idx[j]
- if cur_idx not in [last_idx, self.blank]:
- pred_lbl.append(cur_idx)
- last_idx = cur_idx
- pred_lbls.append(pred_lbl)
- return pred_lbls
|