| @@ -1,7 +1,8 @@ | |||
| import logging | |||
| from abc import ABCMeta, abstractmethod | |||
| from typing import Any, List, Optional, Sequence | |||
| from typing import Any, List, Optional | |||
| from ..structures import ListData | |||
| from ..utils import print_log | |||
| @@ -28,23 +29,20 @@ class BaseMetric(metaclass=ABCMeta): | |||
| self.prefix = prefix or self.default_prefix | |||
@abstractmethod
def process(self, data_samples: ListData) -> None:
    """Process one batch of data samples and predictions.

    Implementations should store the processed results in
    ``self.results``, which will be used to compute the metrics when all
    batches have been processed.

    Args:
        data_samples (ListData): A batch of outputs from the model.
    """
| @abstractmethod | |||
| def compute_metrics(self, results: list) -> dict: | |||
| def compute_metrics(self) -> dict: | |||
| """Compute the metrics from processed results. | |||
| Args: | |||
| results (list): The processed results of each batch. | |||
| Returns: | |||
| dict: The computed metrics. The keys are the names of the metrics, | |||
| and the values are corresponding results. | |||
| @@ -54,13 +52,6 @@ class BaseMetric(metaclass=ABCMeta): | |||
| """Evaluate the model performance of the whole dataset after processing | |||
| all batches. | |||
| Args: | |||
| size (int): Length of the entire validation dataset. When batch | |||
| size > 1, the dataloader may pad some data samples to make | |||
| sure all ranks have the same length of dataset slice. The | |||
| ``collect_results`` function will drop the padded data based on | |||
| this size. | |||
| Returns: | |||
| dict: Evaluation metrics dict on the val dataset. The keys are the | |||
| names of the metrics, and the values are corresponding results. | |||
| @@ -74,7 +65,7 @@ class BaseMetric(metaclass=ABCMeta): | |||
| level=logging.WARNING, | |||
| ) | |||
| metrics = self.compute_metrics(self.results) | |||
| metrics = self.compute_metrics() | |||
| # Add prefix to metric names | |||
| if self.prefix: | |||
| metrics = {"/".join((self.prefix, k)): v for k, v in metrics.items()} | |||
| @@ -1,6 +1,7 @@ | |||
| from typing import Optional, Sequence | |||
| from typing import Optional | |||
| from ..reasoning import KBBase | |||
| from ..structures import ListData | |||
| from .base_metric import BaseMetric | |||
| @@ -9,7 +10,7 @@ class SemanticsMetric(BaseMetric): | |||
| super().__init__(prefix) | |||
| self.kb = kb | |||
| def process(self, data_samples: Sequence[dict]) -> None: | |||
| def process(self, data_samples: ListData) -> None: | |||
| pred_pseudo_label_list = data_samples.pred_pseudo_label | |||
| y_list = data_samples.Y | |||
| for pred_pseudo_label, y in zip(pred_pseudo_label_list, y_list): | |||
| @@ -18,7 +19,8 @@ class SemanticsMetric(BaseMetric): | |||
| else: | |||
| self.results.append(0) | |||
def compute_metrics(self) -> dict:
    """Compute the semantics accuracy from the accumulated results.

    The accuracy is the mean of the per-sample scores that ``process``
    appended to ``self.results``.

    Returns:
        dict: A dict with the single key ``"semantics_accuracy"``. The
        value is ``0.0`` when no results have been recorded.
    """
    results = self.results
    sample_count = len(results)
    if sample_count == 0:
        # Nothing was processed; report 0.0 instead of raising
        # ZeroDivisionError.
        return {"semantics_accuracy": 0.0}
    return {"semantics_accuracy": sum(results) / sample_count}
| @@ -1,5 +1,6 @@ | |||
| from typing import Optional, Sequence | |||
| from typing import Optional | |||
| from ..structures import ListData | |||
| from .base_metric import BaseMetric | |||
| @@ -7,22 +8,21 @@ class SymbolMetric(BaseMetric): | |||
def __init__(self, prefix: Optional[str] = None) -> None:
    """Initialize the metric by forwarding ``prefix`` to the base class.

    Args:
        prefix (str, optional): Passed unchanged to the parent
            initializer. Defaults to None.
    """
    super().__init__(prefix)
def process(self, data_samples: ListData) -> None:
    """Count matching pseudo-labels in one batch and record the tally.

    The predicted and ground-truth pseudo-labels are read through
    ``data_samples.flatten(...)`` and compared element by element. One
    ``(correct_num, total_num)`` tuple is appended to ``self.results``
    per call, so ``compute_metrics`` can aggregate over all batches.

    Args:
        data_samples (ListData): A batch of outputs from the model. It
            must provide ``flatten("pred_pseudo_label")`` and
            ``flatten("gt_pseudo_label")``.

    Raises:
        ValueError: If the two flattened label lists differ in length.
    """
    pred_pseudo_label_list = data_samples.flatten("pred_pseudo_label")
    gt_pseudo_label_list = data_samples.flatten("gt_pseudo_label")

    if len(pred_pseudo_label_list) != len(gt_pseudo_label_list):
        raise ValueError("lengthes of pred_pseudo_label and gt_pseudo_label should be equal")

    # Count via a filtered generator so each comparison is still evaluated
    # for truthiness, exactly as an ``if pred == gt`` statement would.
    correct_num = sum(
        1
        for pred_pseudo_label, gt_pseudo_label in zip(
            pred_pseudo_label_list, gt_pseudo_label_list
        )
        if pred_pseudo_label == gt_pseudo_label
    )
    self.results.append((correct_num, len(pred_pseudo_label_list)))
def compute_metrics(self) -> dict:
    """Compute the character accuracy over all processed batches.

    Each entry of ``self.results`` is a ``(correct_num, total_num)``
    tuple appended by ``process``. The accuracy is the total number of
    correct symbols divided by the total number of symbols, so batches
    of different sizes are weighted by their symbol count.

    Returns:
        dict: A dict with the single key ``"character_accuracy"``. The
        value is ``0.0`` when no symbols have been recorded.
    """
    results = self.results
    correct_total = sum(t[0] for t in results)
    symbol_total = sum(t[1] for t in results)
    if symbol_total == 0:
        # No batches (or only empty batches) were processed; report 0.0
        # instead of raising ZeroDivisionError.
        return {"character_accuracy": 0.0}
    return {"character_accuracy": correct_total / symbol_total}