diff --git a/abl/evaluation/__init__.py b/abl/evaluation/__init__.py new file mode 100644 index 0000000..552197d --- /dev/null +++ b/abl/evaluation/__init__.py @@ -0,0 +1,3 @@ +from .base_metric import BaseMetric +from .symbol_metric import SymbolMetric +from .abl_metric import ABLMetric diff --git a/abl/evaluation/abl_metric.py b/abl/evaluation/abl_metric.py index f0547d0..7e59a31 100644 --- a/abl/evaluation/abl_metric.py +++ b/abl/evaluation/abl_metric.py @@ -6,9 +6,10 @@ class ABLMetric(BaseMetric): def __init__(self, prefix: Optional[str] = None) -> None: super().__init__(prefix) - def process(self, data_samples: Sequence[dict], logic_forward: Callable) -> None: + def process(self, data_samples: Sequence[dict]) -> None: pred_pseudo_label = data_samples["pred_pseudo_label"] gt_Y = data_samples["Y"] + logic_forward = data_samples["logic_forward"] for pred_z, y in zip(pred_pseudo_label, gt_Y): if logic_forward(pred_z) == y: