diff --git a/ablkit/bridge/simple_bridge.py b/ablkit/bridge/simple_bridge.py index 79b6c81..5c2cbfb 100644 --- a/ablkit/bridge/simple_bridge.py +++ b/ablkit/bridge/simple_bridge.py @@ -49,6 +49,15 @@ class SimpleBridge(BaseBridge): ) -> None: super().__init__(model, reasoner) self.metric_list = metric_list + if not hasattr(model.base_model, "predict_proba") and reasoner.dist_func in [ + "confidence", + "avg_confidence", + ]: + raise ValueError( + "If the base model does not implement the predict_proba method, " + + "then the dist_func in the reasoner cannot be set to 'confidence'" + + "or 'avg_confidence', which are related to predicted probability." + ) def predict(self, data_examples: ListData) -> Tuple[List[ndarray], List[ndarray]]: """