diff --git a/ablkit/bridge/simple_bridge.py b/ablkit/bridge/simple_bridge.py index 5c2cbfb..045b2a8 100644 --- a/ablkit/bridge/simple_bridge.py +++ b/ablkit/bridge/simple_bridge.py @@ -9,15 +9,17 @@ from typing import Any, List, Optional, Tuple, Union from numpy import ndarray +import wandb + from ..data.evaluation import BaseMetric from ..data.structures import ListData from ..learning import ABLModel from ..reasoning import Reasoner from ..utils import print_log -from .base_bridge import BaseBridge +from .base_bridge import BaseBridge, M, R -class SimpleBridge(BaseBridge): +class SimpleBridge(BaseBridge[M, R]): """ A basic implementation for bridging machine learning and reasoning parts. @@ -32,10 +34,10 @@ class SimpleBridge(BaseBridge): Parameters ---------- - model : ABLModel + model : M The machine learning model wrapped in ``ABLModel``, which is mainly used for prediction and model training. - reasoner : Reasoner + reasoner : R The reasoning part wrapped in ``Reasoner``, which is used for pseudo-label revision. metric_list : List[BaseMetric] A list of metrics used for evaluating the model's performance. @@ -43,12 +45,13 @@ class SimpleBridge(BaseBridge): def __init__( self, - model: ABLModel, - reasoner: Reasoner, + model: M, + reasoner: R, metric_list: List[BaseMetric], ) -> None: super().__init__(model, reasoner) self.metric_list = metric_list + self.use_wandb = self._check_wandb_available() if not hasattr(model.base_model, "predict_proba") and reasoner.dist_func in [ "confidence", "avg_confidence", @@ -59,6 +62,20 @@ class SimpleBridge(BaseBridge): + "or 'avg_confidence', which are related to predicted probability." ) + def _check_wandb_available(self): + """ + Check if wandb is available and initialized. + + Returns + ------- + bool + True if wandb is available and initialized, False otherwise. + """ + try: + return wandb.run is not None + except ImportError: + return False + def predict(self, data_examples: ListData) -> Tuple[List[ndarray], List[ndarray]]: """ Predict class indices and probabilities (if ``predict_proba`` is implemented in @@ -129,10 +146,7 @@ class SimpleBridge(BaseBridge): A list of indices converted from pseudo-labels. """ abduced_idx = [ - [ - self.reasoner.label_to_idx[_abduced_pseudo_label] - for _abduced_pseudo_label in sub_list - ] + [self.reasoner.label_to_idx[_abduced_pseudo_label] for _abduced_pseudo_label in sub_list] for sub_list in data_examples.abduced_pseudo_label ] data_examples.abduced_idx = abduced_idx @@ -207,11 +221,12 @@ class SimpleBridge(BaseBridge): def train( self, train_data: Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]], - label_data: Optional[ - Union[ListData, Tuple[List[List[Any]], List[List[Any]], List[Any]]] - ] = None, + label_data: Optional[Union[ListData, Tuple[List[List[Any]], List[List[Any]], List[Any]]]] = None, val_data: Optional[ - Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]] + Union[ + ListData, + Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]], + ] ] = None, loops: int = 50, segment_size: Union[int, float] = 1.0, @@ -287,28 +302,26 @@ class SimpleBridge(BaseBridge): logger="current", ) - sub_data_examples = data_examples[ - seg_idx * segment_size : (seg_idx + 1) * segment_size - ] + sub_data_examples = data_examples[seg_idx * segment_size : (seg_idx + 1) * segment_size] self.predict(sub_data_examples) self.idx_to_pseudo_label(sub_data_examples) self.abduce_pseudo_label(sub_data_examples) self.filter_pseudo_label(sub_data_examples) self.concat_data_examples(sub_data_examples, label_data_examples) self.pseudo_label_to_idx(sub_data_examples) + if len(sub_data_examples) == 0: + continue self.model.train(sub_data_examples) if (loop + 1) % eval_interval == 0 or loop == loops - 1: print_log(f"Eval start: loop(val) [{loop + 1}]", logger="current") - self._valid(val_data_examples) + self._valid(val_data_examples, prefix="val") if save_interval is not None and ((loop + 1) % save_interval == 0 or loop == loops - 1): print_log(f"Saving model: loop(save) [{loop + 1}]", logger="current") - self.model.save( - save_path=osp.join(save_dir, f"model_checkpoint_loop_{loop + 1}.pth") - ) + self.model.save(save_path=osp.join(save_dir, f"model_checkpoint_loop_{loop + 1}.pth")) - def _valid(self, data_examples: ListData) -> None: + def _valid(self, data_examples: ListData, prefix: str = "val") -> None: """ Internal method for validating the model with given data examples. @@ -320,21 +333,40 @@ class SimpleBridge(BaseBridge): self.predict(data_examples) self.idx_to_pseudo_label(data_examples) + for metric in self.metric_list: + metric.prefix = prefix + for metric in self.metric_list: metric.process(data_examples) res = dict() for metric in self.metric_list: res.update(metric.evaluate()) + msg = "Evaluation ended, " for k, v in res.items(): - msg += k + f": {v:.3f} " + try: + v = float(v) + msg += k + f": {v:.3f} " + except: + pass + + if self.use_wandb: + try: + wandb_metrics = {} + for k, v in res.items(): + wandb_metrics[f"{k}"] = v + wandb.log(wandb_metrics) + except Exception as e: + print_log(f"Failed to log metrics to wandb: {e}", logger="current") + print_log(msg, logger="current") def valid( self, val_data: Union[ - ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]] + ListData, + Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]], ], ) -> None: """ @@ -349,12 +381,13 @@ class SimpleBridge(BaseBridge): ``self.metric_list``. """ val_data_examples = self.data_preprocess("val", val_data) - self._valid(val_data_examples) + self._valid(val_data_examples, prefix="val") def test( self, test_data: Union[ - ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]] + ListData, + Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]], ], ) -> None: """ @@ -370,4 +403,4 @@ class SimpleBridge(BaseBridge): """ print_log("Test start:", logger="current") test_data_examples = self.data_preprocess("test", test_data) - self._valid(test_data_examples) + self._valid(test_data_examples, prefix="test")