Browse Source

[ENH] add base and simple bridge which are used for combining learning and reasoning

pull/3/head
Gao Enhao 3 years ago
parent
commit
757e485864
2 changed files with 170 additions and 0 deletions
  1. +52
    -0
      abl/bridge/base_bridge.py
  2. +118
    -0
      abl/bridge/simple_bridge.py

+ 52
- 0
abl/bridge/base_bridge.py View File

@@ -0,0 +1,52 @@
from abc import ABCMeta, abstractmethod
from typing import Any, List, Tuple

from ..learning import ABLModel
from ..reasoning import ReasonerBase


class BaseBridge(metaclass=ABCMeta):
    """Abstract interface that couples a learning model with a reasoner.

    Concrete bridges implement prediction, abduction, the conversions
    between label space and symbol (pseudo-label) space, and the
    train/test/valid entry points.

    Attributes are set in ``__init__``:
        model: the ``ABLModel`` instance passed by the caller.
        abducer: the ``ReasonerBase`` instance passed by the caller.
    """

    def __init__(self, model: ABLModel, abducer: ReasonerBase) -> None:
        """Type-check and store the learning model and the reasoner.

        Args:
            model: Learning component; must be an ``ABLModel``.
            abducer: Reasoning component; must be a ``ReasonerBase``.

        Raises:
            TypeError: If either argument is not an instance of the
                expected class.
        """
        if not isinstance(model, ABLModel):
            raise TypeError(
                f"Expected an ABLModel, but received {type(model).__name__}"
            )
        if not isinstance(abducer, ReasonerBase):
            raise TypeError(
                f"Expected a ReasonerBase, but received {type(abducer).__name__}"
            )
        self.model = model
        self.abducer = abducer

    @abstractmethod
    def predict(self, X: List[List[Any]]) -> Tuple[List[List[Any]], List[List[Any]]]:
        """Predict labels from the input ``X``; subclasses must implement."""

    @abstractmethod
    def abduce_pseudo_label(self, pseudo_label: List[List[Any]]) -> List[List[Any]]:
        """Abduce (revise) pseudo labels; subclasses must implement."""

    @abstractmethod
    def label_to_pseudo_label(self, label: List[List[Any]]) -> List[List[Any]]:
        """Map from label space to symbol space; subclasses must implement."""

    @abstractmethod
    def pseudo_label_to_label(self, pseudo_label: List[List[Any]]) -> List[List[Any]]:
        """Map from symbol space to label space; subclasses must implement."""

    @abstractmethod
    def train(self, train_data):
        """Run the training loop of ABductive Learning; subclasses must implement."""

    @abstractmethod
    def test(self, test_data):
        """Test the model; subclasses must implement."""

    @abstractmethod
    def valid(self, valid_data):
        """Validate the model; subclasses must implement."""

+ 118
- 0
abl/bridge/simple_bridge.py View File

@@ -0,0 +1,118 @@
from ..learning import ABLModel
from ..reasoning import ReasonerBase
from ..evaluation import BaseMetric
from .base_bridge import BaseBridge
from typing import List, Union, Any, Tuple, Dict, Optional
from numpy import ndarray

from torch.utils.data import DataLoader
from ..dataset import BridgeDataset
from ..utils.logger import print_log


class SimpleBridge(BaseBridge):
    """Bridge that alternates prediction, abduction and model retraining.

    For every batch, ``train`` predicts labels with the model, converts
    them to pseudo labels, asks the abducer to revise them against ``Y``,
    converts the revised pseudo labels back to labels, and trains the
    model on those labels.
    """

    def __init__(
        self,
        model: ABLModel,
        abducer: ReasonerBase,
        metric_list: List[BaseMetric],
    ) -> None:
        """Store the model, the abducer and the evaluation metrics.

        Args:
            model: Learning component, validated by ``BaseBridge``.
            abducer: Reasoning component, validated by ``BaseBridge``.
            metric_list: Metrics iterated during validation; each one is
                called via ``process(data_samples)`` and ``evaluate()``.
        """
        super().__init__(model, abducer)
        self.metric_list = metric_list

    def predict(self, X) -> Tuple[List[List[Any]], ndarray]:
        """Return the ``"label"`` and ``"prob"`` entries of ``model.predict(X)``."""
        pred_res = self.model.predict(X)
        return pred_res["label"], pred_res["prob"]

    def abduce_pseudo_label(
        self,
        pred_label: List[List[Any]],
        pred_prob: ndarray,
        pseudo_label: List[List[Any]],
        Y: List[List[Any]],
    ) -> List[List[Any]]:
        """Delegate to ``abducer.batch_abduce`` and return its result."""
        return self.abducer.batch_abduce(pred_label, pred_prob, pseudo_label, Y)

    def label_to_pseudo_label(
        self, label: List[List[Any]], mapping: Optional[Dict] = None
    ) -> List[List[Any]]:
        """Translate every label of every sub-list through ``mapping``.

        Args:
            label: Nested list of labels.
            mapping: Lookup table from label to pseudo label; defaults to
                ``self.abducer.mapping`` when ``None``.
        """
        if mapping is None:
            mapping = self.abducer.mapping
        return [[mapping[_label] for _label in sub_list] for sub_list in label]

    def pseudo_label_to_label(
        self, pseudo_label: List[List[Any]], mapping: Optional[Dict] = None
    ) -> List[List[Any]]:
        """Translate every pseudo label of every sub-list through ``mapping``.

        Args:
            pseudo_label: Nested list of pseudo labels.
            mapping: Lookup table from pseudo label to label; defaults to
                ``self.abducer.remapping`` when ``None``.
        """
        if mapping is None:
            mapping = self.abducer.remapping
        return [
            [mapping[_pseudo_label] for _pseudo_label in sub_list]
            for sub_list in pseudo_label
        ]

    @staticmethod
    def _collate(data_list):
        """Transpose a list of samples into one list per sample field."""
        return [list(field) for field in zip(*data_list)]

    def _build_data_loader(self, data, batch_size):
        """Wrap ``data`` in a ``BridgeDataset`` and return a ``DataLoader``.

        A non-positive integer ``batch_size`` (such as the ``-1`` default
        of ``train``) selects the whole dataset as a single batch, because
        ``DataLoader`` rejects non-positive batch sizes.
        """
        dataset = BridgeDataset(*data)
        if isinstance(batch_size, int) and batch_size <= 0:
            # max(..., 1) keeps DataLoader constructible for an empty dataset.
            batch_size = max(len(dataset), 1)
        return DataLoader(dataset, batch_size=batch_size, collate_fn=self._collate)

    def train(
        self,
        train_data: Tuple[List[List[Any]], Optional[List[List[Any]]], List[List[Any]]],
        epochs: int = 50,
        batch_size: Union[int, float] = -1,
        eval_interval: int = 1,
    ):
        """Run the abductive training loop.

        Args:
            train_data: Positional arguments for ``BridgeDataset``; each
                batch is unpacked as ``(X, Z, Y)``.
            epochs: Number of passes over the data.
            batch_size: Samples per batch. A non-positive int means "use
                the whole dataset as one batch".
                NOTE(review): the annotation also advertises ``float``;
                such values are passed to ``DataLoader`` unchanged, which
                is not expected to accept them -- confirm intended meaning.
            eval_interval: Validation runs after every ``eval_interval``
                epochs and always after the last epoch.
        """
        data_loader = self._build_data_loader(train_data, batch_size)
        num_segments = len(data_loader)

        for epoch in range(epochs):
            for seg_idx, (X, Z, Y) in enumerate(data_loader):
                pred_label, pred_prob = self.predict(X)
                pred_pseudo_label = self.label_to_pseudo_label(pred_label)
                abduced_pseudo_label = self.abduce_pseudo_label(
                    pred_label, pred_prob, pred_pseudo_label, Y
                )
                abduced_label = self.pseudo_label_to_label(abduced_pseudo_label)
                min_loss = self.model.train(X, abduced_label)

                print_log(
                    f"Epoch(train) [{epoch}] [{seg_idx:3}/{num_segments}] "
                    f"minimal_loss is {min_loss:.5f}",
                    logger="current",
                )

            if (epoch + 1) % eval_interval == 0 or epoch == epochs - 1:
                print_log(f"Evaluation start: Epoch(val) [{epoch}]", logger="current")
                # Validation is run on the training data itself.
                self.valid(train_data)

    def test(self, test_data):
        """Delegate to the parent class's ``test``.

        NOTE(review): the parent implementation is presumably only the
        abstract placeholder, in which case no evaluation happens here --
        TODO implement or confirm.
        """
        return super().test(test_data)

    def _valid(self, data_loader):
        """Feed every batch to every metric, then log the evaluated results.

        All batches are processed before any metric is evaluated, so each
        metric's ``evaluate()`` is called once, after it has seen the
        complete data, and every metric in ``metric_list`` is evaluated.
        """
        saw_batch = False
        for X, Z, Y in data_loader:
            pred_label, pred_prob = self.predict(X)
            pred_pseudo_label = self.label_to_pseudo_label(pred_label)
            data_samples = dict(
                pred_label=pred_label,
                pred_prob=pred_prob,
                pred_pseudo_label=pred_pseudo_label,
                gt_pseudo_label=Z,
                Y=Y,
                logic_forward=self.abducer.kb.logic_forward,
            )
            for metric in self.metric_list:
                metric.process(data_samples)
            saw_batch = True

        res = dict()
        # Skip evaluation when nothing was processed (empty loader).
        if saw_batch:
            for metric in self.metric_list:
                res.update(metric.evaluate())

        msg = "Evaluation ended, " + "".join(f"{k}: {v:.3f} " for k, v in res.items())
        print_log(msg, logger="current")

    def valid(self, valid_data, batch_size=1000):
        """Evaluate the metrics on ``valid_data`` and log the results.

        Args:
            valid_data: Positional arguments for ``BridgeDataset``.
            batch_size: Samples per evaluation batch.
        """
        self._valid(self._build_data_loader(valid_data, batch_size))

Loading…
Cancel
Save