diff --git a/abl/bridge/base_bridge.py b/abl/bridge/base_bridge.py index ba1a663..c9261fb 100644 --- a/abl/bridge/base_bridge.py +++ b/abl/bridge/base_bridge.py @@ -5,8 +5,6 @@ from ..learning import ABLModel from ..reasoning import Reasoner from ..structures import ListData -DataSet = Tuple[List[List[Any]], Optional[List[List[Any]]], List[List[Any]]] - class BaseBridge(metaclass=ABCMeta): def __init__(self, model: ABLModel, reasoner: Reasoner) -> None: @@ -24,28 +22,37 @@ class BaseBridge(metaclass=ABCMeta): @abstractmethod def predict(self, data_samples: ListData) -> Tuple[List[List[Any]], List[List[Any]]]: - """Placeholder for predict labels from input.""" + """Placeholder for predicting labels from input.""" @abstractmethod def abduce_pseudo_label(self, data_samples: ListData) -> List[List[Any]]: - """Placeholder for abduce pseudo labels.""" + """Placeholder for abducing pseudo labels.""" @abstractmethod def idx_to_pseudo_label(self, data_samples: ListData) -> List[List[Any]]: - """Placeholder for map label space to symbol space.""" + """Placeholder for mapping indexes to pseudo labels.""" @abstractmethod def pseudo_label_to_idx(self, data_samples: ListData) -> List[List[Any]]: - """Placeholder for map symbol space to label space.""" + """Placeholder for mapping pseudo labels to indexes.""" @abstractmethod - def train(self, train_data: Union[ListData, DataSet]): - """Placeholder for train loop of ABductive Learning.""" + def train( + self, + train_data: Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]], + ): + """Placeholder for training loop of ABductive Learning.""" @abstractmethod - def valid(self, valid_data: Union[ListData, DataSet]) -> None: + def valid( + self, + valid_data: Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]], + ) -> None: """Placeholder for model test.""" @abstractmethod - def test(self, test_data: Union[ListData, DataSet]) -> None: + def test( + self, + test_data: Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]], + ) -> None: """Placeholder for model validation.""" diff --git a/abl/bridge/simple_bridge.py b/abl/bridge/simple_bridge.py index 508c106..df4accd 100644 --- a/abl/bridge/simple_bridge.py +++ b/abl/bridge/simple_bridge.py @@ -8,7 +8,7 @@ from ..learning import ABLModel from ..reasoning import Reasoner from ..structures import ListData from ..utils import print_log -from .base_bridge import BaseBridge, DataSet +from .base_bridge import BaseBridge class SimpleBridge(BaseBridge): @@ -55,8 +55,10 @@ class SimpleBridge(BaseBridge): def train( self, - train_data: Union[ListData, DataSet], - val_data: Optional[Union[ListData, DataSet]] = None, + train_data: Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]], + val_data: Optional[ + Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]] + ] = None, loops: int = 50, segment_size: Union[int, float] = -1, eval_interval: int = 1, @@ -79,12 +81,12 @@ class SimpleBridge(BaseBridge): self.pseudo_label_to_idx(sub_data_samples) loss = self.model.train(sub_data_samples) - print_log( - f"loop(train) [{loop + 1}/{loops}] segment(train) \ - [{(seg_idx + 1)}/{(len(data_samples) - 1) // segment_size + 1}] \ - model loss is {loss:.5f}", - logger="current", + log_string = ( + f"loop(train) [{loop + 1}/{loops}] segment(train) " + f"[{(seg_idx + 1)}/{(len(data_samples) - 1) // segment_size + 1}] " + f"model loss is {loss:.5f}" ) + print_log(log_string, logger="current") if (loop + 1) % eval_interval == 0 or loop == loops - 1: print_log(f"Evaluation start: loop(val) [{loop + 1}]", logger="current") @@ -114,12 +116,18 @@ class SimpleBridge(BaseBridge): msg += k + f": {v:.3f} " print_log(msg, logger="current") - def valid(self, valid_data: Union[ListData, DataSet]) -> None: + def valid( + self, + valid_data: Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]], + ) -> None: if not isinstance(valid_data, ListData): data_samples = self.data_preprocess(*valid_data) else: data_samples = valid_data self._valid(data_samples) - def test(self, test_data: Union[ListData, DataSet]) -> None: + def test( + self, + test_data: Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]], + ) -> None: self.valid(test_data)