From f398d98cd2357382e91d11460a93a766ebcf4dc8 Mon Sep 17 00:00:00 2001 From: Gao Enhao Date: Sun, 10 Dec 2023 01:11:32 +0800 Subject: [PATCH] [MNT] resolve comments of metrics --- abl/evaluation/base_metric.py | 21 ++++++--------------- abl/evaluation/semantics_metric.py | 8 +++++--- abl/evaluation/symbol_metric.py | 28 ++++++++++++++-------------- 3 files changed, 25 insertions(+), 32 deletions(-) diff --git a/abl/evaluation/base_metric.py b/abl/evaluation/base_metric.py index 03b1997..1fbcf0f 100644 --- a/abl/evaluation/base_metric.py +++ b/abl/evaluation/base_metric.py @@ -1,7 +1,8 @@ import logging from abc import ABCMeta, abstractmethod -from typing import Any, List, Optional, Sequence +from typing import Any, List, Optional +from ..structures import ListData from ..utils import print_log @@ -28,23 +29,20 @@ class BaseMetric(metaclass=ABCMeta): self.prefix = prefix or self.default_prefix @abstractmethod - def process(self, data_samples: Sequence[dict]) -> None: + def process(self, data_samples: ListData) -> None: """Process one batch of data samples and predictions. The processed results should be stored in ``self.results``, which will be used to compute the metrics when all batches have been processed. Args: - data_samples (Sequence[dict]): A batch of outputs from + data_samples (ListData): A batch of outputs from the model. """ @abstractmethod - def compute_metrics(self, results: list) -> dict: + def compute_metrics(self) -> dict: """Compute the metrics from processed results. - Args: - results (list): The processed results of each batch. - Returns: dict: The computed metrics. The keys are the names of the metrics, and the values are corresponding results. @@ -54,13 +52,6 @@ class BaseMetric(metaclass=ABCMeta): """Evaluate the model performance of the whole dataset after processing all batches. - Args: - size (int): Length of the entire validation dataset. When batch - size > 1, the dataloader may pad some data samples to make - sure all ranks have the same length of dataset slice. The - ``collect_results`` function will drop the padded data based on - this size. - Returns: dict: Evaluation metrics dict on the val dataset. The keys are the names of the metrics, and the values are corresponding results. @@ -74,7 +65,7 @@ class BaseMetric(metaclass=ABCMeta): level=logging.WARNING, ) - metrics = self.compute_metrics(self.results) + metrics = self.compute_metrics() # Add prefix to metric names if self.prefix: metrics = {"/".join((self.prefix, k)): v for k, v in metrics.items()} diff --git a/abl/evaluation/semantics_metric.py b/abl/evaluation/semantics_metric.py index 14c4f46..7254f34 100644 --- a/abl/evaluation/semantics_metric.py +++ b/abl/evaluation/semantics_metric.py @@ -1,6 +1,7 @@ -from typing import Optional, Sequence +from typing import Optional from ..reasoning import KBBase +from ..structures import ListData from .base_metric import BaseMetric @@ -9,7 +10,7 @@ class SemanticsMetric(BaseMetric): super().__init__(prefix) self.kb = kb - def process(self, data_samples: Sequence[dict]) -> None: + def process(self, data_samples: ListData) -> None: pred_pseudo_label_list = data_samples.pred_pseudo_label y_list = data_samples.Y for pred_pseudo_label, y in zip(pred_pseudo_label_list, y_list): @@ -18,7 +19,8 @@ class SemanticsMetric(BaseMetric): else: self.results.append(0) - def compute_metrics(self, results: list) -> dict: + def compute_metrics(self) -> dict: + results = self.results metrics = dict() metrics["semantics_accuracy"] = sum(results) / len(results) return metrics diff --git a/abl/evaluation/symbol_metric.py b/abl/evaluation/symbol_metric.py index 112dc8b..46b0a70 100644 --- a/abl/evaluation/symbol_metric.py +++ b/abl/evaluation/symbol_metric.py @@ -1,5 +1,6 @@ -from typing import Optional, Sequence +from typing import Optional +from ..structures import ListData from .base_metric import BaseMetric @@ -7,22 +8,21 @@ class SymbolMetric(BaseMetric): def __init__(self, prefix: Optional[str] = None) -> None: super().__init__(prefix) - def process(self, data_samples: Sequence[dict]) -> None: - pred_pseudo_label = data_samples.pred_pseudo_label + def process(self, data_samples: ListData) -> None: + pred_pseudo_label_list = data_samples.flatten("pred_pseudo_label") + gt_pseudo_label_list = data_samples.flatten("gt_pseudo_label") - gt_pseudo_label = data_samples.gt_pseudo_label - - if not len(pred_pseudo_label) == len(gt_pseudo_label): + if not len(pred_pseudo_label_list) == len(gt_pseudo_label_list): raise ValueError("lengthes of pred_pseudo_label and gt_pseudo_label should be equal") - for pred_z, z in zip(pred_pseudo_label, gt_pseudo_label): - correct_num = 0 - for pred_symbol, symbol in zip(pred_z, z): - if pred_symbol == symbol: - correct_num += 1 - self.results.append(correct_num / len(z)) + correct_num = 0 + for pred_pseudo_label, gt_pseudo_label in zip(pred_pseudo_label_list, gt_pseudo_label_list): + if pred_pseudo_label == gt_pseudo_label: + correct_num += 1 + self.results.append((correct_num, len(pred_pseudo_label_list))) - def compute_metrics(self, results: list) -> dict: + def compute_metrics(self) -> dict: + results = self.results metrics = dict() - metrics["character_accuracy"] = sum(results) / len(results) + metrics["character_accuracy"] = sum(t[0] for t in results) / sum(t[1] for t in results) return metrics