|
|
|
@@ -49,6 +49,15 @@ class SimpleBridge(BaseBridge): |
|
|
|
) -> None: |
|
|
|
super().__init__(model, reasoner) |
|
|
|
self.metric_list = metric_list |
|
|
|
if not hasattr(model.base_model, "predict_proba") and reasoner.dist_func in [ |
|
|
|
"confidence", |
|
|
|
"avg_confidence", |
|
|
|
]: |
|
|
|
raise ValueError( |
|
|
|
"If the base model does not implement the predict_proba method, " |
|
|
|
+ "then the dist_func in the reasoner cannot be set to 'confidence'" |
|
|
|
+ "or 'avg_confidence', which are related to predicted probability." |
|
|
|
) |
|
|
|
|
|
|
|
def predict(self, data_examples: ListData) -> Tuple[List[ndarray], List[ndarray]]: |
|
|
|
""" |
|
|
|
|