Browse Source

[ENH] add ValueError in simplebridge

pull/6/head
troyyyyy 2 years ago
parent
commit
b470cb723f
1 changed files with 9 additions and 0 deletions
  1. +9
    -0
      ablkit/bridge/simple_bridge.py

+ 9
- 0
ablkit/bridge/simple_bridge.py View File

@@ -49,6 +49,15 @@ class SimpleBridge(BaseBridge):
) -> None:
super().__init__(model, reasoner)
self.metric_list = metric_list
if not hasattr(model.base_model, "predict_proba") and reasoner.dist_func in [
"confidence",
"avg_confidence",
]:
raise ValueError(
"If the base model does not implement the predict_proba method, "
+ "then the dist_func in the reasoner cannot be set to 'confidence'"
+ "or 'avg_confidence', which are related to predicted probability."
)

def predict(self, data_examples: ListData) -> Tuple[List[ndarray], List[ndarray]]:
"""


Loading…
Cancel
Save