Browse Source

[MNT] resolve comments of metrics

pull/1/head
Gao Enhao 2 years ago
parent
commit
f398d98cd2
3 changed files with 25 additions and 32 deletions
  1. +6
    -15
      abl/evaluation/base_metric.py
  2. +5
    -3
      abl/evaluation/semantics_metric.py
  3. +14
    -14
      abl/evaluation/symbol_metric.py

+ 6
- 15
abl/evaluation/base_metric.py View File

@@ -1,7 +1,8 @@
import logging
from abc import ABCMeta, abstractmethod
from typing import Any, List, Optional, Sequence
from typing import Any, List, Optional

from ..structures import ListData
from ..utils import print_log


@@ -28,23 +29,20 @@ class BaseMetric(metaclass=ABCMeta):
self.prefix = prefix or self.default_prefix

@abstractmethod
def process(self, data_samples: Sequence[dict]) -> None:
def process(self, data_samples: ListData) -> None:
"""Process one batch of data samples and predictions. The processed
results should be stored in ``self.results``, which will be used to
compute the metrics when all batches have been processed.

Args:
data_samples (Sequence[dict]): A batch of outputs from
data_samples (ListData): A batch of outputs from
the model.
"""

@abstractmethod
def compute_metrics(self, results: list) -> dict:
def compute_metrics(self) -> dict:
"""Compute the metrics from processed results.

Args:
results (list): The processed results of each batch.

Returns:
dict: The computed metrics. The keys are the names of the metrics,
and the values are corresponding results.
@@ -54,13 +52,6 @@ class BaseMetric(metaclass=ABCMeta):
"""Evaluate the model performance of the whole dataset after processing
all batches.

Args:
size (int): Length of the entire validation dataset. When batch
size > 1, the dataloader may pad some data samples to make
sure all ranks have the same length of dataset slice. The
``collect_results`` function will drop the padded data based on
this size.

Returns:
dict: Evaluation metrics dict on the val dataset. The keys are the
names of the metrics, and the values are corresponding results.
@@ -74,7 +65,7 @@ class BaseMetric(metaclass=ABCMeta):
level=logging.WARNING,
)

metrics = self.compute_metrics(self.results)
metrics = self.compute_metrics()
# Add prefix to metric names
if self.prefix:
metrics = {"/".join((self.prefix, k)): v for k, v in metrics.items()}


+ 5
- 3
abl/evaluation/semantics_metric.py View File

@@ -1,6 +1,7 @@
from typing import Optional, Sequence
from typing import Optional

from ..reasoning import KBBase
from ..structures import ListData
from .base_metric import BaseMetric


@@ -9,7 +10,7 @@ class SemanticsMetric(BaseMetric):
super().__init__(prefix)
self.kb = kb

def process(self, data_samples: Sequence[dict]) -> None:
def process(self, data_samples: ListData) -> None:
pred_pseudo_label_list = data_samples.pred_pseudo_label
y_list = data_samples.Y
for pred_pseudo_label, y in zip(pred_pseudo_label_list, y_list):
@@ -18,7 +19,8 @@ class SemanticsMetric(BaseMetric):
else:
self.results.append(0)

def compute_metrics(self, results: list) -> dict:
def compute_metrics(self) -> dict:
results = self.results
metrics = dict()
metrics["semantics_accuracy"] = sum(results) / len(results)
return metrics

+ 14
- 14
abl/evaluation/symbol_metric.py View File

@@ -1,5 +1,6 @@
from typing import Optional, Sequence
from typing import Optional

from ..structures import ListData
from .base_metric import BaseMetric


@@ -7,22 +8,21 @@ class SymbolMetric(BaseMetric):
def __init__(self, prefix: Optional[str] = None) -> None:
super().__init__(prefix)

def process(self, data_samples: Sequence[dict]) -> None:
pred_pseudo_label = data_samples.pred_pseudo_label
def process(self, data_samples: ListData) -> None:
pred_pseudo_label_list = data_samples.flatten("pred_pseudo_label")
gt_pseudo_label_list = data_samples.flatten("gt_pseudo_label")

gt_pseudo_label = data_samples.gt_pseudo_label

if not len(pred_pseudo_label) == len(gt_pseudo_label):
if not len(pred_pseudo_label_list) == len(gt_pseudo_label_list):
raise ValueError("lengthes of pred_pseudo_label and gt_pseudo_label should be equal")

for pred_z, z in zip(pred_pseudo_label, gt_pseudo_label):
correct_num = 0
for pred_symbol, symbol in zip(pred_z, z):
if pred_symbol == symbol:
correct_num += 1
self.results.append(correct_num / len(z))
correct_num = 0
for pred_pseudo_label, gt_pseudo_label in zip(pred_pseudo_label_list, gt_pseudo_label_list):
if pred_pseudo_label == gt_pseudo_label:
correct_num += 1
self.results.append((correct_num, len(pred_pseudo_label_list)))

def compute_metrics(self, results: list) -> dict:
def compute_metrics(self) -> dict:
results = self.results
metrics = dict()
metrics["character_accuracy"] = sum(results) / len(results)
metrics["character_accuracy"] = sum(t[0] for t in results) / sum(t[1] for t in results)
return metrics

Loading…
Cancel
Save