| @@ -9,15 +9,17 @@ from typing import Any, List, Optional, Tuple, Union | |||
| from numpy import ndarray | |||
| import wandb | |||
| from ..data.evaluation import BaseMetric | |||
| from ..data.structures import ListData | |||
| from ..learning import ABLModel | |||
| from ..reasoning import Reasoner | |||
| from ..utils import print_log | |||
| from .base_bridge import BaseBridge | |||
| from .base_bridge import BaseBridge, M, R | |||
| class SimpleBridge(BaseBridge): | |||
| class SimpleBridge(BaseBridge[M, R]): | |||
| """ | |||
| A basic implementation for bridging machine learning and reasoning parts. | |||
| @@ -32,10 +34,10 @@ class SimpleBridge(BaseBridge): | |||
| Parameters | |||
| ---------- | |||
| model : ABLModel | |||
| model : M | |||
| The machine learning model wrapped in ``ABLModel``, which is mainly used for | |||
| prediction and model training. | |||
| reasoner : Reasoner | |||
| reasoner : R | |||
| The reasoning part wrapped in ``Reasoner``, which is used for pseudo-label revision. | |||
| metric_list : List[BaseMetric] | |||
| A list of metrics used for evaluating the model's performance. | |||
| @@ -43,12 +45,13 @@ class SimpleBridge(BaseBridge): | |||
| def __init__( | |||
| self, | |||
| model: ABLModel, | |||
| reasoner: Reasoner, | |||
| model: M, | |||
| reasoner: R, | |||
| metric_list: List[BaseMetric], | |||
| ) -> None: | |||
| super().__init__(model, reasoner) | |||
| self.metric_list = metric_list | |||
| self.use_wandb = self._check_wandb_available() | |||
| if not hasattr(model.base_model, "predict_proba") and reasoner.dist_func in [ | |||
| "confidence", | |||
| "avg_confidence", | |||
| @@ -59,6 +62,20 @@ class SimpleBridge(BaseBridge): | |||
| + "or 'avg_confidence', which are related to predicted probability." | |||
| ) | |||
def _check_wandb_available(self) -> bool:
    """
    Check if wandb is available and initialized.

    Returns
    -------
    bool
        True if ``wandb.run`` exists and is not None, False otherwise.
    """
    try:
        return wandb.run is not None
    except (ImportError, AttributeError):
        # ``wandb.run`` is an attribute lookup, so the failure to expect here is
        # AttributeError (presumably when the name ``wandb`` does not resolve to
        # the real package, e.g. a local ``wandb/`` directory imported as a
        # namespace package -- TODO confirm). ImportError is still caught so
        # that everything the previous handler covered stays covered.
        return False
| def predict(self, data_examples: ListData) -> Tuple[List[ndarray], List[ndarray]]: | |||
| """ | |||
| Predict class indices and probabilities (if ``predict_proba`` is implemented in | |||
| @@ -129,10 +146,7 @@ class SimpleBridge(BaseBridge): | |||
| A list of indices converted from pseudo-labels. | |||
| """ | |||
| abduced_idx = [ | |||
| [ | |||
| self.reasoner.label_to_idx[_abduced_pseudo_label] | |||
| for _abduced_pseudo_label in sub_list | |||
| ] | |||
| [self.reasoner.label_to_idx[_abduced_pseudo_label] for _abduced_pseudo_label in sub_list] | |||
| for sub_list in data_examples.abduced_pseudo_label | |||
| ] | |||
| data_examples.abduced_idx = abduced_idx | |||
| @@ -207,11 +221,12 @@ class SimpleBridge(BaseBridge): | |||
| def train( | |||
| self, | |||
| train_data: Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]], | |||
| label_data: Optional[ | |||
| Union[ListData, Tuple[List[List[Any]], List[List[Any]], List[Any]]] | |||
| ] = None, | |||
| label_data: Optional[Union[ListData, Tuple[List[List[Any]], List[List[Any]], List[Any]]]] = None, | |||
| val_data: Optional[ | |||
| Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]] | |||
| Union[ | |||
| ListData, | |||
| Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]], | |||
| ] | |||
| ] = None, | |||
| loops: int = 50, | |||
| segment_size: Union[int, float] = 1.0, | |||
| @@ -287,28 +302,26 @@ class SimpleBridge(BaseBridge): | |||
| logger="current", | |||
| ) | |||
| sub_data_examples = data_examples[ | |||
| seg_idx * segment_size : (seg_idx + 1) * segment_size | |||
| ] | |||
| sub_data_examples = data_examples[seg_idx * segment_size : (seg_idx + 1) * segment_size] | |||
| self.predict(sub_data_examples) | |||
| self.idx_to_pseudo_label(sub_data_examples) | |||
| self.abduce_pseudo_label(sub_data_examples) | |||
| self.filter_pseudo_label(sub_data_examples) | |||
| self.concat_data_examples(sub_data_examples, label_data_examples) | |||
| self.pseudo_label_to_idx(sub_data_examples) | |||
| if len(sub_data_examples) == 0: | |||
| continue | |||
| self.model.train(sub_data_examples) | |||
| if (loop + 1) % eval_interval == 0 or loop == loops - 1: | |||
| print_log(f"Eval start: loop(val) [{loop + 1}]", logger="current") | |||
| self._valid(val_data_examples) | |||
| self._valid(val_data_examples, prefix="val") | |||
| if save_interval is not None and ((loop + 1) % save_interval == 0 or loop == loops - 1): | |||
| print_log(f"Saving model: loop(save) [{loop + 1}]", logger="current") | |||
| self.model.save( | |||
| save_path=osp.join(save_dir, f"model_checkpoint_loop_{loop + 1}.pth") | |||
| ) | |||
| self.model.save(save_path=osp.join(save_dir, f"model_checkpoint_loop_{loop + 1}.pth")) | |||
def _valid(self, data_examples: ListData, prefix: str = "val") -> None:
    """
    Internal method for validating the model with given data examples.

    Runs prediction and index-to-pseudo-label conversion on ``data_examples``,
    feeds them to every metric in ``self.metric_list``, logs the merged results
    through ``print_log`` and, when ``self.use_wandb`` is true, also sends them
    to ``wandb.log``.

    Parameters
    ----------
    data_examples : ListData
        Data examples to be used for validation.
    prefix : str, optional
        Value assigned to the ``prefix`` attribute of every metric before it
        processes the data. Defaults to "val".
    """
    self.predict(data_examples)
    self.idx_to_pseudo_label(data_examples)

    # Set the prefix on every metric before any of them processes the data.
    for metric in self.metric_list:
        metric.prefix = prefix
    for metric in self.metric_list:
        metric.process(data_examples)

    res = dict()
    for metric in self.metric_list:
        res.update(metric.evaluate())

    msg = "Evaluation ended, "
    for k, v in res.items():
        try:
            formatted = f"{float(v):.3f}"
        except (TypeError, ValueError, OverflowError):
            # Values that cannot be converted to a float are left out of the
            # text summary on purpose; only conversion errors are swallowed so
            # that KeyboardInterrupt/SystemExit and real bugs still propagate.
            continue
        msg += f"{k}: {formatted} "

    if self.use_wandb:
        # Best-effort reporting: a wandb failure is logged but must not abort
        # the evaluation itself.
        try:
            wandb.log({f"{k}": v for k, v in res.items()})
        except Exception as e:
            print_log(f"Failed to log metrics to wandb: {e}", logger="current")

    print_log(msg, logger="current")
def valid(
    self,
    val_data: Union[
        ListData,
        Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]],
    ],
) -> None:
    """
    Validate the model with the given validation data.

    Parameters
    ----------
    val_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]]
        Validation data, either a ``ListData`` object or a 3-tuple as described
        by the annotation. It is passed through ``self.data_preprocess`` with
        the ``"val"`` tag and then evaluated by ``self._valid`` using the
        ``"val"`` prefix.
    """
    self._valid(self.data_preprocess("val", val_data), prefix="val")
def test(
    self,
    test_data: Union[
        ListData,
        Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]],
    ],
) -> None:
    """
    Test the model with the given test data.

    Parameters
    ----------
    test_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]]
        Test data, either a ``ListData`` object or a 3-tuple as described by
        the annotation. It is passed through ``self.data_preprocess`` with the
        ``"test"`` tag and then evaluated by ``self._valid`` using the
        ``"test"`` prefix.
    """
    # Announce the start before any preprocessing happens.
    print_log("Test start:", logger="current")
    self._valid(self.data_preprocess("test", test_data), prefix="test")