Browse Source

[MNT] flatten DataSet type

pull/1/head
Gao Enhao 2 years ago
parent
commit
72289caccd
2 changed files with 35 additions and 20 deletions
  1. +17
    -10
      abl/bridge/base_bridge.py
  2. +18
    -10
      abl/bridge/simple_bridge.py

+ 17
- 10
abl/bridge/base_bridge.py View File

@@ -5,8 +5,6 @@ from ..learning import ABLModel
from ..reasoning import Reasoner
from ..structures import ListData

DataSet = Tuple[List[List[Any]], Optional[List[List[Any]]], List[List[Any]]]


class BaseBridge(metaclass=ABCMeta):
def __init__(self, model: ABLModel, reasoner: Reasoner) -> None:
@@ -24,28 +22,37 @@ class BaseBridge(metaclass=ABCMeta):

@abstractmethod
def predict(self, data_samples: ListData) -> Tuple[List[List[Any]], List[List[Any]]]:
"""Placeholder for predict labels from input."""
"""Placeholder for predicting labels from input."""

@abstractmethod
def abduce_pseudo_label(self, data_samples: ListData) -> List[List[Any]]:
"""Placeholder for abduce pseudo labels."""
"""Placeholder for abducing pseudo labels."""

@abstractmethod
def idx_to_pseudo_label(self, data_samples: ListData) -> List[List[Any]]:
"""Placeholder for map label space to symbol space."""
"""Placeholder for mapping indexes to pseudo labels."""

@abstractmethod
def pseudo_label_to_idx(self, data_samples: ListData) -> List[List[Any]]:
"""Placeholder for map symbol space to label space."""
"""Placeholder for mapping pseudo labels to indexes."""

@abstractmethod
def train(self, train_data: Union[ListData, DataSet]):
"""Placeholder for train loop of ABductive Learning."""
def train(
self,
train_data: Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]],
):
"""Placeholder for training loop of ABductive Learning."""

@abstractmethod
def valid(self, valid_data: Union[ListData, DataSet]) -> None:
def valid(
self,
valid_data: Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]],
) -> None:
"""Placeholder for model test."""

@abstractmethod
def test(self, test_data: Union[ListData, DataSet]) -> None:
def test(
self,
test_data: Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]],
) -> None:
"""Placeholder for model validation."""

+ 18
- 10
abl/bridge/simple_bridge.py View File

@@ -8,7 +8,7 @@ from ..learning import ABLModel
from ..reasoning import Reasoner
from ..structures import ListData
from ..utils import print_log
from .base_bridge import BaseBridge, DataSet
from .base_bridge import BaseBridge


class SimpleBridge(BaseBridge):
@@ -55,8 +55,10 @@ class SimpleBridge(BaseBridge):

def train(
self,
train_data: Union[ListData, DataSet],
val_data: Optional[Union[ListData, DataSet]] = None,
train_data: Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]],
val_data: Optional[
Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]]
] = None,
loops: int = 50,
segment_size: Union[int, float] = -1,
eval_interval: int = 1,
@@ -79,12 +81,12 @@ class SimpleBridge(BaseBridge):
self.pseudo_label_to_idx(sub_data_samples)
loss = self.model.train(sub_data_samples)

print_log(
f"loop(train) [{loop + 1}/{loops}] segment(train) \
[{(seg_idx + 1)}/{(len(data_samples) - 1) // segment_size + 1}] \
model loss is {loss:.5f}",
logger="current",
log_string = (
f"loop(train) [{loop + 1}/{loops}] segment(train) "
f"[{(seg_idx + 1)}/{(len(data_samples) - 1) // segment_size + 1}] "
f"model loss is {loss:.5f}"
)
print_log(log_string, logger="current")

if (loop + 1) % eval_interval == 0 or loop == loops - 1:
print_log(f"Evaluation start: loop(val) [{loop + 1}]", logger="current")
@@ -114,12 +116,18 @@ class SimpleBridge(BaseBridge):
msg += k + f": {v:.3f} "
print_log(msg, logger="current")

def valid(self, valid_data: Union[ListData, DataSet]) -> None:
def valid(
self,
valid_data: Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]],
) -> None:
if not isinstance(valid_data, ListData):
data_samples = self.data_preprocess(*valid_data)
else:
data_samples = valid_data
self._valid(data_samples)

def test(self, test_data: Union[ListData, DataSet]) -> None:
def test(
self,
test_data: Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]],
) -> None:
self.valid(test_data)

Loading…
Cancel
Save