|
|
|
@@ -14,9 +14,9 @@ from .builder import METRICS, MetricKeys |
|
|
|
|
|
|
|
@METRICS.register_module(group_key=default_group, module_name=Metrics.NED) |
|
|
|
class NedMetric(Metric): |
|
|
|
"""The metric computation class for classification classes. |
|
|
|
"""The ned metric computation class for classification classes. |
|
|
|
|
|
|
|
This metric class calculates accuracy for the whole input batches. |
|
|
|
This metric class calculates the levenshtein distance between sentences for the whole input batches. |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, *args, **kwargs): |
|
|
|
@@ -44,13 +44,46 @@ class NedMetric(Metric): |
|
|
|
self.preds.extend(eval_results.tolist()) |
|
|
|
self.labels.extend(ground_truths.tolist()) |
|
|
|
else: |
|
|
|
raise 'only support list or np.ndarray' |
|
|
|
raise Exception('only support list or np.ndarray') |
|
|
|
|
|
|
|
def evaluate(self): |
|
|
|
assert len(self.preds) == len(self.labels) |
|
|
|
return { |
|
|
|
MetricKeys.NED: (np.asarray([ |
|
|
|
self.ned.distance(pred, ref) |
|
|
|
1.0 - NedMetric._distance(pred, ref) |
|
|
|
for pred, ref in zip(self.preds, self.labels) |
|
|
|
])).mean().item() |
|
|
|
} |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def _distance(pred, ref): |
|
|
|
if pred is None or ref is None: |
|
|
|
raise TypeError('Argument s0 is NoneType.') |
|
|
|
if pred == ref: |
|
|
|
return 0.0 |
|
|
|
if len(pred) == 0: |
|
|
|
return len(ref) |
|
|
|
if len(ref) == 0: |
|
|
|
return len(pred) |
|
|
|
m_len = max(len(pred), len(ref)) |
|
|
|
if m_len == 0: |
|
|
|
return 0.0 |
|
|
|
|
|
|
|
def levenshtein(s0, s1): |
|
|
|
v0 = [0] * (len(s1) + 1) |
|
|
|
v1 = [0] * (len(s1) + 1) |
|
|
|
|
|
|
|
for i in range(len(v0)): |
|
|
|
v0[i] = i |
|
|
|
|
|
|
|
for i in range(len(s0)): |
|
|
|
v1[0] = i + 1 |
|
|
|
for j in range(len(s1)): |
|
|
|
cost = 1 |
|
|
|
if s0[i] == s1[j]: |
|
|
|
cost = 0 |
|
|
|
v1[j + 1] = min(v1[j] + 1, v0[j + 1] + 1, v0[j] + cost) |
|
|
|
v0, v1 = v1, v0 |
|
|
|
return v0[len(s1)] |
|
|
|
|
|
|
|
return levenshtein(pred, ref) / m_len |