"""Metric classes from the patch "[ENH] add metric folder and three basic
metric classes".

The patch creates three modules under ``abl/evaluation``:

* ``base_metric.py``   -> :class:`BaseMetric`
* ``abl_metric.py``    -> :class:`ABLMetric`
* ``symbol_metric.py`` -> :class:`SymbolMetric`

They are shown here as one unit, with the base class first so that the
subclasses can resolve it.
"""
import logging
from abc import ABCMeta, abstractmethod
from typing import Any, Callable, List, Optional, Sequence

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# abl/evaluation/base_metric.py
# ---------------------------------------------------------------------------
class BaseMetric(metaclass=ABCMeta):
    """Base class for a metric.

    A metric processes batches of data samples through :meth:`process`,
    which appends per-sample results to ``self.results``. Once all batches
    have been processed, :meth:`evaluate` computes the metrics over the
    accumulated results and resets the accumulator.

    Args:
        prefix (str, optional): The prefix that will be added to the metric
            names to disambiguate homonymous metrics of different
            evaluators. If it is not provided (or is falsy),
            ``self.default_prefix`` is used instead. Default: None
    """

    # Fallback used when no ``prefix`` is passed to ``__init__``. It must
    # exist on the class: ``__init__`` reads it whenever ``prefix`` is falsy,
    # and without it constructing a metric with the default arguments raises
    # ``AttributeError``. Subclasses may override it.
    default_prefix: Optional[str] = None

    def __init__(self, prefix: Optional[str] = None) -> None:
        self.results: List[Any] = []
        self.prefix = prefix or self.default_prefix

    @abstractmethod
    def process(self, data_samples: Sequence[dict]) -> None:
        """Process one batch of data samples and predictions.

        The processed results should be stored in ``self.results``, which
        will be used to compute the metrics when all batches have been
        processed.

        Args:
            data_samples (Sequence[dict]): A batch of outputs from
                the model.
        """

    @abstractmethod
    def compute_metrics(self, results: list) -> dict:
        """Compute the metrics from processed results.

        Args:
            results (list): The processed results of each batch.

        Returns:
            dict: The computed metrics. The keys are the names of the
            metrics, and the values are corresponding results.
        """

    def evaluate(self) -> dict:
        """Compute the metrics over everything accumulated by ``process``.

        ``self.results`` is cleared afterwards, so the same metric object
        can be reused for the next evaluation round.

        Returns:
            dict: Evaluation metrics dict. The keys are the names of the
            metrics (prefixed with ``"<prefix>/"`` when ``self.prefix`` is
            set), and the values are corresponding results.
        """
        if not self.results:
            logger.warning(
                "%s got empty `self.results`. Please ensure that the "
                "processed results are properly added into `self.results` "
                "in `process` method.",
                self.__class__.__name__,
            )

        metrics = self.compute_metrics(self.results)
        # Add prefix to metric names
        if self.prefix:
            metrics = {
                "/".join((self.prefix, name)): value
                for name, value in metrics.items()
            }

        # reset the results list
        self.results.clear()
        return metrics


# ---------------------------------------------------------------------------
# abl/evaluation/abl_metric.py
# ---------------------------------------------------------------------------
class ABLMetric(BaseMetric):
    """Accuracy of the final output obtained from predicted pseudo labels.

    A sample counts as correct when ``logic_forward`` applied to its
    predicted pseudo label equals the ground-truth ``Y``.
    """

    def __init__(self, prefix: Optional[str] = None) -> None:
        super().__init__(prefix)

    def process(self, data_samples: Sequence[dict], logic_forward: Callable) -> None:
        """Record one 0/1 correctness flag per sample in ``self.results``.

        Args:
            data_samples: Indexed by the string keys ``"pred_pseudo_label"``
                and ``"Y"``, so despite the annotation it is used as a
                mapping holding two parallel iterables.
            logic_forward (Callable): Maps one predicted pseudo label to a
                value that is compared against the matching entry of ``Y``.
        """
        pred_pseudo_label = data_samples["pred_pseudo_label"]
        gt_Y = data_samples["Y"]

        self.results.extend(
            1 if logic_forward(pred_z) == y else 0
            for pred_z, y in zip(pred_pseudo_label, gt_Y)
        )

    def compute_metrics(self, results: list) -> dict:
        """Return the mean of the recorded correctness flags.

        Args:
            results (list): The 0/1 flags collected by ``process``.

        Returns:
            dict: ``{"abl_accuracy": mean}``; the mean is ``0.0`` when
            ``results`` is empty instead of raising ``ZeroDivisionError``.
        """
        accuracy = sum(results) / len(results) if results else 0.0
        return {"abl_accuracy": accuracy}


# ---------------------------------------------------------------------------
# abl/evaluation/symbol_metric.py
# ---------------------------------------------------------------------------
class SymbolMetric(BaseMetric):
    """Per-symbol accuracy of predicted pseudo labels against ground truth."""

    def __init__(self, prefix: Optional[str] = None) -> None:
        super().__init__(prefix)

    def process(self, data_samples: Sequence[dict]) -> None:
        """Record the fraction of matching symbols for every sample.

        Args:
            data_samples: Indexed by the string keys ``"pred_pseudo_label"``
                and ``"gt_pseudo_label"``, so despite the annotation it is
                used as a mapping holding two parallel sequences of symbol
                sequences.

        Raises:
            ValueError: If the two sequences contain a different number of
                samples.
        """
        pred_pseudo_label = data_samples["pred_pseudo_label"]
        gt_pseudo_label = data_samples["gt_pseudo_label"]

        if len(pred_pseudo_label) != len(gt_pseudo_label):
            raise ValueError("lengthes of pred_pseudo_label and gt_pseudo_label should be equal")

        for pred_z, z in zip(pred_pseudo_label, gt_pseudo_label):
            # Symbols are compared position by position; the ratio is taken
            # over the ground-truth length.
            correct_num = sum(
                1 for pred_symbol, symbol in zip(pred_z, z) if pred_symbol == symbol
            )
            self.results.append(correct_num / len(z))

    def compute_metrics(self, results: list) -> dict:
        """Return the mean of the per-sample symbol accuracies.

        Args:
            results (list): The ratios collected by ``process``.

        Returns:
            dict: ``{"character_accuracy": mean}``; the mean is ``0.0`` when
            ``results`` is empty instead of raising ``ZeroDivisionError``.
        """
        accuracy = sum(results) / len(results) if results else 0.0
        return {"character_accuracy": accuracy}