""" This module contains the ReasoningMetric, which is used for evaluating the model performance on tasks that need reasoning. Copyright (c) 2024 LAMDA. All rights reserved. """ from typing import Optional from ...reasoning import KBBase from ..structures import ListData from .base_metric import BaseMetric class ReasoningMetric(BaseMetric): """ A metrics class for evaluating the model performance on tasks that need reasoning. This class is designed to calculate the accuracy of the reasoing results. Reasoning results are generated by first using the learning part to predict pseudo-labels and then using a knowledge base (KB) to perform logical reasoning. The reasoning results are then compared with the ground truth to calculate the accuracy. Parameters ---------- kb : KBBase An instance of a knowledge base, used for logical reasoning and validation. If not provided, reasoning checks are not performed. Defaults to None. prefix : str, optional The prefix that will be added to the metrics names to disambiguate homonymous metrics of different tasks. Inherits from BaseMetric. Defaults to None. Notes ----- The `ReasoningMetric` expects data_examples to have the attributes `pred_pseudo_label`, `Y`, and `X`, corresponding to the predicted pseduo labels, ground truth of reasoning results, and input data, respectively. """ def __init__(self, kb: KBBase, prefix: Optional[str] = None) -> None: super().__init__(prefix) self.kb = kb # pylint: disable=protected-access def process(self, data_examples: ListData) -> None: """ Process a batch of data examples. This method takes in a batch of data examples, each containing predicted pseudo-labels (pred_pseudo_label), ground truth of reasoning results (Y), and input data (X). It evaluates the reasoning accuracy of each example by comparing the logical reasoning result (derived using the knowledge base) of the predicted pseudo-labels against Y The result of this comparison (1 for correct reasoning, 0 for incorrect) is appended to ``self.results``. Parameters ---------- data_examples : ListData A batch of data examples. """ pred_pseudo_label_list = data_examples.pred_pseudo_label y_list = data_examples.Y x_list = data_examples.X for pred_pseudo_label, y, x in zip(pred_pseudo_label_list, y_list, x_list): if self.kb._check_equal( self.kb.logic_forward(pred_pseudo_label, *(x,) if self.kb._num_args == 2 else ()), y ): self.results.append(1) else: self.results.append(0) def compute_metrics(self) -> dict: """ Compute the reasoning accuracy metrics from ``self.results``. It calculates the percentage of correctly reasoned examples over all examples. Returns ------- dict A dictionary containing the computed metrics. It includes the key 'reasoning_accuracy' which maps to the calculated reasoning accuracy, represented as a float between 0 and 1. """ results = self.results metrics = dict() metrics["reasoning_accuracy"] = sum(results) / len(results) return metrics