Browse Source

[MNT] move logic_forward into data_sample for ABLMetric

pull/3/head
Gao Enhao 2 years ago
parent
commit
c120ce3e4b
2 changed files with 5 additions and 1 deletions
  1. +3
    -0
      abl/evaluation/__init__.py
  2. +2
    -1
      abl/evaluation/abl_metric.py

+ 3
- 0
abl/evaluation/__init__.py View File

@@ -0,0 +1,3 @@
from .base_metric import BaseMetric
from .symbol_metric import SymbolMetric
from .abl_metric import ABLMetric

+ 2
- 1
abl/evaluation/abl_metric.py View File

@@ -6,9 +6,10 @@ class ABLMetric(BaseMetric):
def __init__(self, prefix: Optional[str] = None) -> None:
super().__init__(prefix)

def process(self, data_samples: Sequence[dict], logic_forward: Callable) -> None:
def process(self, data_samples: Sequence[dict]) -> None:
pred_pseudo_label = data_samples["pred_pseudo_label"]
gt_Y = data_samples["Y"]
logic_forward = data_samples["logic_forward"]

for pred_z, y in zip(pred_pseudo_label, gt_Y):
if logic_forward(pred_z) == y:


Loading…
Cancel
Save