| @@ -1,14 +1,14 @@ | |||
| from typing import List, Union, Any, Tuple, Dict, Optional | |||
| import os.path as osp | |||
| from typing import Any, Dict, List, Optional, Tuple, Union | |||
| from numpy import ndarray | |||
| from .base_bridge import BaseBridge, DataSet | |||
| from ..evaluation import BaseMetric | |||
| from ..learning import ABLModel | |||
| from ..reasoning import ReasonerBase | |||
| from ..evaluation import BaseMetric | |||
| from ..structures import ListData | |||
| from ..utils.logger import print_log | |||
| from ..utils import print_log | |||
| from .base_bridge import BaseBridge, DataSet | |||
| class SimpleBridge(BaseBridge): | |||
| @@ -21,11 +21,13 @@ class SimpleBridge(BaseBridge): | |||
| super().__init__(model, abducer) | |||
| self.metric_list = metric_list | |||
| def predict(self, data_samples: ListData) -> Tuple[List[ndarray], ndarray]: | |||
| # TODO: add abducer.mapping to the property of SimpleBridge | |||
| def predict(self, data_samples: ListData) -> Tuple[List[ndarray], List[ndarray]]: | |||
| pred_res = self.model.predict(data_samples) | |||
| data_samples.pred_idx = pred_res["label"] | |||
| data_samples.pred_prob = pred_res["prob"] | |||
| return data_samples["pred_idx"], ["data_samples.pred_prob"] | |||
| return data_samples["pred_idx"], data_samples["pred_prob"] | |||
| def abduce_pseudo_label( | |||
| self, | |||
| @@ -37,7 +39,7 @@ class SimpleBridge(BaseBridge): | |||
| return data_samples["abduced_pseudo_label"] | |||
| def idx_to_pseudo_label( | |||
| self, data_samples: ListData, mapping: Dict = None | |||
| self, data_samples: ListData, mapping: Optional[Dict] = None | |||
| ) -> List[List[Any]]: | |||
| if mapping is None: | |||
| mapping = self.abducer.mapping | |||
| @@ -48,7 +50,7 @@ class SimpleBridge(BaseBridge): | |||
| return data_samples["pred_pseudo_label"] | |||
| def pseudo_label_to_idx( | |||
| self, data_samples: ListData, mapping: Dict = None | |||
| self, data_samples: ListData, mapping: Optional[Dict] = None | |||
| ) -> List[List[Any]]: | |||
| if mapping is None: | |||
| mapping = self.abducer.remapping | |||
| @@ -59,9 +61,7 @@ class SimpleBridge(BaseBridge): | |||
| data_samples.abduced_idx = abduced_idx | |||
| return data_samples["abduced_idx"] | |||
| def data_preprocess( | |||
| self, X: List[Any], gt_pseudo_label: List[Any], Y: List[Any] | |||
| ) -> ListData: | |||
| def data_preprocess(self, X: List[Any], gt_pseudo_label: List[Any], Y: List[Any]) -> ListData: | |||
| data_samples = ListData() | |||
| data_samples.X = X | |||
| @@ -72,17 +72,22 @@ class SimpleBridge(BaseBridge): | |||
| def train( | |||
| self, | |||
| train_data: DataSet, | |||
| epochs: int = 50, | |||
| batch_size: Union[int, float] = -1, | |||
| train_data: Union[ListData, DataSet], | |||
| loops: int = 50, | |||
| segment_size: Union[int, float] = -1, | |||
| eval_interval: int = 1, | |||
| save_interval: Optional[int] = None, | |||
| save_dir: Optional[str] = None, | |||
| ): | |||
| data_samples = self.data_preprocess(*train_data) | |||
| if isinstance(train_data, ListData): | |||
| data_samples = train_data | |||
| else: | |||
| data_samples = self.data_preprocess(*train_data) | |||
| for epoch in range(epochs): | |||
| for seg_idx in range((len(data_samples) - 1) // batch_size + 1): | |||
| for loop in range(loops): | |||
| for seg_idx in range((len(data_samples) - 1) // segment_size + 1): | |||
| sub_data_samples = data_samples[ | |||
| seg_idx * batch_size : (seg_idx + 1) * batch_size | |||
| seg_idx * segment_size : (seg_idx + 1) * segment_size | |||
| ] | |||
| self.predict(sub_data_samples) | |||
| self.idx_to_pseudo_label(sub_data_samples) | |||
| @@ -91,25 +96,25 @@ class SimpleBridge(BaseBridge): | |||
| loss = self.model.train(sub_data_samples) | |||
| print_log( | |||
| f"Epoch(train) [{epoch + 1}] [{(seg_idx + 1):3}/{(len(data_samples) - 1) // batch_size + 1}] model loss is {loss:.5f}", | |||
| f"loop(train) [{loop + 1}/{loops}] segment(train) [{(seg_idx + 1)}/{(len(data_samples) - 1) // segment_size + 1}] model loss is {loss:.5f}", | |||
| logger="current", | |||
| ) | |||
| if (epoch + 1) % eval_interval == 0 or epoch == epochs - 1: | |||
| print_log(f"Evaluation start: Epoch(val) [{epoch}]", logger="current") | |||
| if (loop + 1) % eval_interval == 0 or loop == loops - 1: | |||
| print_log(f"Evaluation start: loop(val) [{loop + 1}]", logger="current") | |||
| self.valid(train_data) | |||
| if save_interval is not None and ((loop + 1) % save_interval == 0 or loop == loops - 1): | |||
| print_log(f"Saving model: loop(save) [{loop + 1}]", logger="current") | |||
| self.model.save(save_path=osp.join(save_dir, f"model_checkpoint_loop_{loop + 1}.pth")) | |||
| def _valid(self, data_samples: ListData, batch_size: int = 128) -> None: | |||
| for seg_idx in range((len(data_samples) - 1) // batch_size + 1): | |||
| sub_data_samples = data_samples[ | |||
| seg_idx * batch_size : (seg_idx + 1) * batch_size | |||
| ] | |||
| sub_data_samples = data_samples[seg_idx * batch_size : (seg_idx + 1) * batch_size] | |||
| self.predict(sub_data_samples) | |||
| self.idx_to_pseudo_label(sub_data_samples) | |||
| sub_data_samples.set_metainfo( | |||
| dict(logic_forward=self.abducer.kb.logic_forward) | |||
| ) | |||
| sub_data_samples.set_metainfo(dict(logic_forward=self.abducer.kb.logic_forward)) | |||
| for metric in self.metric_list: | |||
| metric.process(sub_data_samples) | |||