diff --git a/abl/bridge/base_bridge.py b/abl/bridge/base_bridge.py new file mode 100644 index 0000000..d09f1e0 --- /dev/null +++ b/abl/bridge/base_bridge.py @@ -0,0 +1,52 @@ +from abc import ABCMeta, abstractmethod +from typing import Any, List, Tuple + +from ..learning import ABLModel +from ..reasoning import ReasonerBase + + +class BaseBridge(metaclass=ABCMeta): + + def __init__(self, model: ABLModel, abducer: ReasonerBase) -> None: + if not isinstance(model, ABLModel): + raise TypeError("Expected an ABLModel") + if not isinstance(abducer, ReasonerBase): + raise TypeError("Expected an ReasonerBase") + + self.model = model + self.abducer = abducer + + @abstractmethod + def predict(self, X: List[List[Any]]) -> Tuple[List[List[Any]], List[List[Any]]]: + """Placeholder for predict labels from input.""" + pass + + @abstractmethod + def abduce_pseudo_label(self, pseudo_label: List[List[Any]]) -> List[List[Any]]: + """Placeholder for abduce pseudo labels.""" + + @abstractmethod + def label_to_pseudo_label(self, label: List[List[Any]]) -> List[List[Any]]: + """Placeholder for map label space to symbol space.""" + pass + + @abstractmethod + def pseudo_label_to_label(self, pseudo_label: List[List[Any]]) -> List[List[Any]]: + """Placeholder for map symbol space to label space.""" + pass + + @abstractmethod + def train(self, train_data): + """Placeholder for train loop of ABductive Learning.""" + pass + + @abstractmethod + def test(self, test_data): + """Placeholder for model test.""" + pass + + @abstractmethod + def valid(self, valid_data): + """Placeholder for model validation.""" + pass + \ No newline at end of file diff --git a/abl/bridge/simple_bridge.py b/abl/bridge/simple_bridge.py new file mode 100644 index 0000000..de77c31 --- /dev/null +++ b/abl/bridge/simple_bridge.py @@ -0,0 +1,118 @@ +from ..learning import ABLModel +from ..reasoning import ReasonerBase +from ..evaluation import BaseMetric +from .base_bridge import BaseBridge +from typing import List, Union, Any, Tuple, Dict, Optional +from numpy import ndarray + +from torch.utils.data import DataLoader +from ..dataset import BridgeDataset +from ..utils.logger import print_log + + +class SimpleBridge(BaseBridge): + def __init__( + self, + model: ABLModel, + abducer: ReasonerBase, + metric_list: BaseMetric, + ) -> None: + super().__init__(model, abducer) + self.metric_list = metric_list + + def predict(self, X) -> Tuple[List[List[Any]], ndarray]: + pred_res = self.model.predict(X) + pred_label, pred_prob = pred_res["label"], pred_res["prob"] + return pred_label, pred_prob + + def abduce_pseudo_label( + self, + pred_label: List[List[Any]], + pred_prob: ndarray, + pseudo_label: List[List[Any]], + Y: List[List[Any]], + ) -> List[List[Any]]: + return self.abducer.batch_abduce(pred_label, pred_prob, pseudo_label, Y) + + def label_to_pseudo_label( + self, label: List[List[Any]], mapping: Dict = None + ) -> List[List[Any]]: + if mapping is None: + mapping = self.abducer.mapping + return [[mapping[_label] for _label in sub_list] for sub_list in label] + + def pseudo_label_to_label( + self, pseudo_label: List[List[Any]], mapping: Dict = None + ) -> List[List[Any]]: + if mapping is None: + mapping = self.abducer.remapping + return [ + [mapping[_pseudo_label] for _pseudo_label in sub_list] + for sub_list in pseudo_label + ] + + def train( + self, + train_data: Tuple[List[List[Any]], Optional[List[List[Any]]], List[List[Any]]], + epochs: int = 50, + batch_size: Union[int, float] = -1, + eval_interval: int = 1, + ): + dataset = BridgeDataset(*train_data) + data_loader = DataLoader( + dataset, + batch_size=batch_size, + collate_fn=lambda data_list: [list(data) for data in zip(*data_list)], + ) + + for epoch in range(epochs): + for seg_idx, (X, Z, Y) in enumerate(data_loader): + pred_label, pred_prob = self.predict(X) + pred_pseudo_label = self.label_to_pseudo_label(pred_label) + abduced_pseudo_label = self.abduce_pseudo_label( + pred_label, pred_prob, pred_pseudo_label, Y + ) + abduced_label = self.pseudo_label_to_label(abduced_pseudo_label) + min_loss = self.model.train(X, abduced_label) + + print_log(f"Epoch(train) [{epoch}] [{seg_idx:3}/{len(data_loader)}] minimal_loss is {min_loss:.5f}", logger="current") + + if (epoch + 1) % eval_interval == 0 or epoch == epochs - 1: + print_log(f"Evaluation start: Epoch(val) [{epoch}]", logger="current") + self.valid(train_data) + + def test(self, test_data): + return super().test(test_data) + + def _valid(self, data_loader): + res = dict() + for X, Z, Y in data_loader: + pred_label, pred_prob = self.predict(X) + pred_pseudo_label = self.label_to_pseudo_label(pred_label) + data_samples = dict( + pred_label=pred_label, + pred_prob=pred_prob, + pred_pseudo_label=pred_pseudo_label, + gt_pseudo_label=Z, + Y=Y, + logic_forward=self.abducer.kb.logic_forward, + ) + + for metric in self.metric_list: + metric.process(data_samples) + res.update(metric.evaluate()) + msg = "Evaluation ended, " + for k, v in res.items(): + msg += k + f": {v:.3f} " + print_log(msg, logger="current") + + def valid(self, valid_data, batch_size=1000): + dataset = BridgeDataset(*valid_data) + data_loader = DataLoader( + dataset, + batch_size=batch_size, + collate_fn=lambda data_list: [list(data) for data in zip(*data_list)], + ) + self._valid(data_loader) + + return super().valid(valid_data)