""" This module contains the base class used for evaluation. Copyright (c) 2024 LAMDA. All rights reserved. """ import logging from abc import ABCMeta, abstractmethod from typing import Any, List, Optional from ...utils import print_log from ..structures import ListData class BaseMetric(metaclass=ABCMeta): """ Base class for a metrics. The metrics first processes each batch of data_examples and appends the processed results to the results list. Then, it computes the metrics of the entire dataset. Parameters ---------- prefix : str, optional The prefix that will be added in the metrics names to disambiguate homonymous metrics of different tasks. If prefix is not provided in the argument, self.default_prefix will be used instead. Defaults to None. """ def __init__( self, prefix: Optional[str] = None, ) -> None: self.default_prefix = "" self.results: List[Any] = [] self.prefix = prefix or self.default_prefix @abstractmethod def process(self, data_examples: ListData) -> None: """ Process one batch of data examples. The processed results should be stored in ``self.results``, which will be used to compute the metrics when all batches have been processed. Parameters ---------- data_examples : ListData A batch of data examples. """ @abstractmethod def compute_metrics(self) -> dict: """ Compute the metrics from processed results. Returns ------- dict The computed metrics. The keys are the names of the metrics, and the values are the corresponding results. """ def evaluate(self) -> dict: """ Evaluate the model performance of the whole dataset after processing all batches. Returns ------- dict Evaluation metrics dict on the val dataset. The keys are the names of the metrics, and the values are the corresponding results. """ if len(self.results) == 0: print_log( f"{self.__class__.__name__} got empty `self.results`. Please " "ensure that the processed results are properly added into " "`self.results` in `process` method.", logger="current", level=logging.WARNING, ) metrics = self.compute_metrics() # Add prefix to metrics names if self.prefix: metrics = {"/".join((self.prefix, k)): v for k, v in metrics.items()} # reset the results list self.results.clear() return metrics