Browse Source

[ENH] add metric folder and three basic metric classes

pull/3/head
Gao Enhao 3 years ago
parent
commit
dafbe2cb43
3 changed files with 130 additions and 0 deletions
  1. +22
    -0
      abl/evaluation/abl_metric.py
  2. +81
    -0
      abl/evaluation/base_metric.py
  3. +27
    -0
      abl/evaluation/symbol_metric.py

+ 22
- 0
abl/evaluation/abl_metric.py View File

@@ -0,0 +1,22 @@
from typing import Optional, Sequence, Callable
from .base_metric import BaseMetric


class ABLMetric(BaseMetric):
def __init__(self, prefix: Optional[str] = None) -> None:
super().__init__(prefix)

def process(self, data_samples: Sequence[dict], logic_forward: Callable) -> None:
pred_pseudo_label = data_samples["pred_pseudo_label"]
gt_Y = data_samples["Y"]

for pred_z, y in zip(pred_pseudo_label, gt_Y):
if logic_forward(pred_z) == y:
self.results.append(1)
else:
self.results.append(0)
def compute_metrics(self, results: list) -> dict:
metrics = dict()
metrics["abl_accuracy"] = sum(results) / len(results)
return metrics

+ 81
- 0
abl/evaluation/base_metric.py View File

@@ -0,0 +1,81 @@
from abc import ABCMeta, abstractmethod
from typing import Any, List, Optional, Sequence


class BaseMetric(metaclass=ABCMeta):
"""Base class for a metric.

The metric first processes each batch of data_samples and predictions,
and appends the processed results to the results list. Then it
collects all results together from all ranks if distributed training
is used. Finally, it computes the metrics of the entire dataset.

Args:
prefix (str, optional): The prefix that will be added in the metric
names to disambiguate homonymous metrics of different evaluators.
If prefix is not provided in the argument, self.default_prefix
will be used instead. Default: None
"""

def __init__(self,
prefix: Optional[str] = None,) -> None:
self.results: List[Any] = []
self.prefix = prefix or self.default_prefix

@abstractmethod
def process(self, data_samples: Sequence[dict]) -> None:
"""Process one batch of data samples and predictions. The processed
results should be stored in ``self.results``, which will be used to
compute the metrics when all batches have been processed.

Args:
data_samples (Sequence[dict]): A batch of outputs from
the model.
"""

@abstractmethod
def compute_metrics(self, results: list) -> dict:
"""Compute the metrics from processed results.

Args:
results (list): The processed results of each batch.

Returns:
dict: The computed metrics. The keys are the names of the metrics,
and the values are corresponding results.
"""

def evaluate(self) -> dict:
"""Evaluate the model performance of the whole dataset after processing
all batches.

Args:
size (int): Length of the entire validation dataset. When batch
size > 1, the dataloader may pad some data samples to make
sure all ranks have the same length of dataset slice. The
``collect_results`` function will drop the padded data based on
this size.

Returns:
dict: Evaluation metrics dict on the val dataset. The keys are the
names of the metrics, and the values are corresponding results.
"""
# if len(self.results) == 0:
# print_log(
# f'{self.__class__.__name__} got empty `self.results`. Please '
# 'ensure that the processed results are properly added into '
# '`self.results` in `process` method.',
# logger='current',
# level=logging.WARNING)

metrics = self.compute_metrics(self.results)
# Add prefix to metric names
if self.prefix:
metrics = {
'/'.join((self.prefix, k)): v
for k, v in metrics.items()
}

# reset the results list
self.results.clear()
return metrics

+ 27
- 0
abl/evaluation/symbol_metric.py View File

@@ -0,0 +1,27 @@
from typing import Optional, Sequence, Callable
from .base_metric import BaseMetric


class SymbolMetric(BaseMetric):
def __init__(self, prefix: Optional[str] = None) -> None:
super().__init__(prefix)

def process(self, data_samples: Sequence[dict]) -> None:
pred_pseudo_label = data_samples["pred_pseudo_label"]

gt_pseudo_label = data_samples["gt_pseudo_label"]

if not len(pred_pseudo_label) == len(gt_pseudo_label):
raise ValueError("lengthes of pred_pseudo_label and gt_pseudo_label should be equal")
for pred_z, z in zip(pred_pseudo_label, gt_pseudo_label):
correct_num = 0
for pred_symbol, symbol in zip(pred_z, z):
if pred_symbol == symbol:
correct_num += 1
self.results.append(correct_num / len(z))
def compute_metrics(self, results: list) -> dict:
metrics = dict()
metrics["character_accuracy"] = sum(results) / len(results)
return metrics

Loading…
Cancel
Save