From b470cb723f667892afb8d6705ae659433c69da8a Mon Sep 17 00:00:00 2001 From: troyyyyy Date: Thu, 11 Jan 2024 11:13:16 +0800 Subject: [PATCH] [ENH] add ValueError in simplebridge --- ablkit/bridge/simple_bridge.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/ablkit/bridge/simple_bridge.py b/ablkit/bridge/simple_bridge.py index 79b6c81..5c2cbfb 100644 --- a/ablkit/bridge/simple_bridge.py +++ b/ablkit/bridge/simple_bridge.py @@ -49,6 +49,15 @@ class SimpleBridge(BaseBridge): ) -> None: super().__init__(model, reasoner) self.metric_list = metric_list + if not hasattr(model.base_model, "predict_proba") and reasoner.dist_func in [ + "confidence", + "avg_confidence", + ]: + raise ValueError( + "If the base model does not implement the predict_proba method, " + + "then the dist_func in the reasoner cannot be set to 'confidence'" + + "or 'avg_confidence', which are related to predicted probability." + ) def predict(self, data_examples: ListData) -> Tuple[List[ndarray], List[ndarray]]: """