| @@ -42,7 +42,7 @@ class AccuracyMetric(Metric): | |||||
| self.preds.extend(eval_results.tolist()) | self.preds.extend(eval_results.tolist()) | ||||
| self.labels.extend(ground_truths.tolist()) | self.labels.extend(ground_truths.tolist()) | ||||
| else: | else: | ||||
| raise 'only support list or np.ndarray' | |||||
| raise Exception('only support list or np.ndarray') | |||||
| def evaluate(self): | def evaluate(self): | ||||
| assert len(self.preds) == len(self.labels) | assert len(self.preds) == len(self.labels) | ||||
| @@ -14,9 +14,9 @@ from .builder import METRICS, MetricKeys | |||||
| @METRICS.register_module(group_key=default_group, module_name=Metrics.NED) | @METRICS.register_module(group_key=default_group, module_name=Metrics.NED) | ||||
| class NedMetric(Metric): | class NedMetric(Metric): | ||||
| """The metric computation class for classification classes. | |||||
| """The ned metric computation class for classification classes. | |||||
| This metric class calculates accuracy for the whole input batches. | |||||
| This metric class calculates the levenshtein distance between sentences for the whole input batches. | |||||
| """ | """ | ||||
| def __init__(self, *args, **kwargs): | def __init__(self, *args, **kwargs): | ||||
| @@ -44,13 +44,46 @@ class NedMetric(Metric): | |||||
| self.preds.extend(eval_results.tolist()) | self.preds.extend(eval_results.tolist()) | ||||
| self.labels.extend(ground_truths.tolist()) | self.labels.extend(ground_truths.tolist()) | ||||
| else: | else: | ||||
| raise 'only support list or np.ndarray' | |||||
| raise Exception('only support list or np.ndarray') | |||||
def evaluate(self):
    """Aggregate the accumulated predictions into a single NED score.

    For every (prediction, label) pair the score is ``1.0`` minus the value
    returned by ``NedMetric._distance``; the per-pair scores are averaged.

    Returns:
        dict: ``{MetricKeys.NED: <mean score as a Python float>}``.

    Raises:
        AssertionError: If the number of stored predictions differs from
            the number of stored labels.
    """
    assert len(self.preds) == len(self.labels)
    # One similarity value per sample, in the order the samples were added.
    similarities = [
        1.0 - NedMetric._distance(hyp, target)
        for hyp, target in zip(self.preds, self.labels)
    ]
    mean_score = np.asarray(similarities).mean().item()
    return {MetricKeys.NED: mean_score}
@staticmethod
def _distance(pred, ref):
    """Return the normalized Levenshtein distance between ``pred`` and ``ref``.

    The raw edit distance (insertions, deletions, substitutions, each with
    cost 1) is divided by the length of the longer input, so the result is
    always within ``[0.0, 1.0]``: ``0.0`` for equal inputs and ``1.0`` when
    every position has to be edited, which includes the case where exactly
    one of the inputs is empty.

    Args:
        pred: Predicted sequence; must support ``len`` and iteration
            (e.g. ``str`` or ``list``).
        ref: Reference sequence of the same kind.

    Returns:
        float: Edit distance divided by ``max(len(pred), len(ref))``.

    Raises:
        TypeError: If ``pred`` or ``ref`` is ``None``.
    """
    if pred is None:
        raise TypeError('Argument pred is NoneType.')
    if ref is None:
        raise TypeError('Argument ref is NoneType.')
    if pred == ref:
        return 0.0

    longest = max(len(pred), len(ref))
    if longest == 0:
        # Both inputs are empty but did not compare equal (e.g. '' vs []).
        return 0.0

    # Two-row dynamic programme: `previous[j]` holds the edit distance
    # between the already-consumed prefix of `pred` and `ref[:j]`.
    # An empty `pred` or `ref` needs no special case: the table then yields
    # the length of the other input, which normalizes to exactly 1.0.
    previous = list(range(len(ref) + 1))
    for row, pred_item in enumerate(pred, start=1):
        current = [row]
        for col, ref_item in enumerate(ref, start=1):
            substitution = 0 if pred_item == ref_item else 1
            current.append(
                min(
                    current[col - 1] + 1,  # insertion
                    previous[col] + 1,  # deletion
                    previous[col - 1] + substitution,
                ))
        previous = current

    return previous[-1] / longest
| @@ -87,7 +87,7 @@ class TestOfaTrainer(unittest.TestCase): | |||||
| 'max_image_size': 480, | 'max_image_size': 480, | ||||
| 'imagenet_default_mean_and_std': False}, | 'imagenet_default_mean_and_std': False}, | ||||
| 'pipeline': {'type': 'ofa-ocr-recognition'}, | 'pipeline': {'type': 'ofa-ocr-recognition'}, | ||||
| 'dataset': {'column_map': {'text': 'caption'}}, | |||||
| 'dataset': {'column_map': {'text': 'label'}}, | |||||
| 'train': {'work_dir': 'work/ckpts/recognition', | 'train': {'work_dir': 'work/ckpts/recognition', | ||||
| # 'launcher': 'pytorch', | # 'launcher': 'pytorch', | ||||
| 'max_epochs': 1, | 'max_epochs': 1, | ||||
| @@ -116,7 +116,6 @@ class TestOfaTrainer(unittest.TestCase): | |||||
| 'use_rdrop': True}, | 'use_rdrop': True}, | ||||
| 'hooks': [{'type': 'BestCkptSaverHook', | 'hooks': [{'type': 'BestCkptSaverHook', | ||||
| 'metric_key': 'ned', | 'metric_key': 'ned', | ||||
| 'rule': 'min', | |||||
| 'interval': 100}, | 'interval': 100}, | ||||
| {'type': 'TextLoggerHook', 'interval': 1}, | {'type': 'TextLoggerHook', 'interval': 1}, | ||||
| {'type': 'IterTimerHook'}, | {'type': 'IterTimerHook'}, | ||||
| @@ -138,11 +137,13 @@ class TestOfaTrainer(unittest.TestCase): | |||||
| model=pretrained_model, | model=pretrained_model, | ||||
| work_dir=WORKSPACE, | work_dir=WORKSPACE, | ||||
| train_dataset=MsDataset.load( | train_dataset=MsDataset.load( | ||||
| 'coco_2014_caption', | |||||
| 'ocr_fudanvi_zh', | |||||
| subset_name='scene', | |||||
| namespace='modelscope', | namespace='modelscope', | ||||
| split='train[:12]'), | split='train[:12]'), | ||||
| eval_dataset=MsDataset.load( | eval_dataset=MsDataset.load( | ||||
| 'coco_2014_caption', | |||||
| 'ocr_fudanvi_zh', | |||||
| subset_name='scene', | |||||
| namespace='modelscope', | namespace='modelscope', | ||||
| split='validation[:4]'), | split='validation[:4]'), | ||||
| cfg_file=config_file) | cfg_file=config_file) | ||||