|
|
|
@@ -8,7 +8,7 @@ from ..learning import ABLModel |
|
|
|
from ..reasoning import Reasoner |
|
|
|
from ..structures import ListData |
|
|
|
from ..utils import print_log |
|
|
|
from .base_bridge import BaseBridge, DataSet |
|
|
|
from .base_bridge import BaseBridge |
|
|
|
|
|
|
|
|
|
|
|
class SimpleBridge(BaseBridge): |
|
|
|
@@ -55,8 +55,10 @@ class SimpleBridge(BaseBridge): |
|
|
|
|
|
|
|
def train( |
|
|
|
self, |
|
|
|
train_data: Union[ListData, DataSet], |
|
|
|
val_data: Optional[Union[ListData, DataSet]] = None, |
|
|
|
train_data: Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]], |
|
|
|
val_data: Optional[ |
|
|
|
Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]] |
|
|
|
] = None, |
|
|
|
loops: int = 50, |
|
|
|
segment_size: Union[int, float] = -1, |
|
|
|
eval_interval: int = 1, |
|
|
|
@@ -79,12 +81,12 @@ class SimpleBridge(BaseBridge): |
|
|
|
self.pseudo_label_to_idx(sub_data_samples) |
|
|
|
loss = self.model.train(sub_data_samples) |
|
|
|
|
|
|
|
print_log( |
|
|
|
f"loop(train) [{loop + 1}/{loops}] segment(train) \ |
|
|
|
[{(seg_idx + 1)}/{(len(data_samples) - 1) // segment_size + 1}] \ |
|
|
|
model loss is {loss:.5f}", |
|
|
|
logger="current", |
|
|
|
log_string = ( |
|
|
|
f"loop(train) [{loop + 1}/{loops}] segment(train) " |
|
|
|
f"[{(seg_idx + 1)}/{(len(data_samples) - 1) // segment_size + 1}] " |
|
|
|
f"model loss is {loss:.5f}" |
|
|
|
) |
|
|
|
print_log(log_string, logger="current") |
|
|
|
|
|
|
|
if (loop + 1) % eval_interval == 0 or loop == loops - 1: |
|
|
|
print_log(f"Evaluation start: loop(val) [{loop + 1}]", logger="current") |
|
|
|
@@ -114,12 +116,18 @@ class SimpleBridge(BaseBridge): |
|
|
|
msg += k + f": {v:.3f} " |
|
|
|
print_log(msg, logger="current") |
|
|
|
|
|
|
|
def valid(self, valid_data: Union[ListData, DataSet]) -> None: |
|
|
|
def valid( |
|
|
|
self, |
|
|
|
valid_data: Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]], |
|
|
|
) -> None: |
|
|
|
if not isinstance(valid_data, ListData): |
|
|
|
data_samples = self.data_preprocess(*valid_data) |
|
|
|
else: |
|
|
|
data_samples = valid_data |
|
|
|
self._valid(data_samples) |
|
|
|
|
|
|
|
def test(self, test_data: Union[ListData, DataSet]) -> None: |
|
|
|
def test( |
|
|
|
self, |
|
|
|
test_data: Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]], |
|
|
|
) -> None: |
|
|
|
self.valid(test_data) |