|
- """
- Base solver for classification problems
- """
-
- from typing import Any
- from ..base import BaseSolver
- from ...module.ensemble import ENSEMBLE_DICT
- from ...module import BaseEnsembler
-
-
class BaseClassifier(BaseSolver):
    """
    Base solver for classification problems.

    Adds the classification-specific pieces on top of ``BaseSolver``: a
    ``predict_proba`` hook that concrete solvers must override, and a setter
    for the ensemble module.
    """

    def predict_proba(self, *args, **kwargs) -> Any:
        """
        Predict the node probability.

        This base implementation is only a placeholder; subclasses are
        expected to override it.

        Returns
        -------
        result: Any
            The predicted probability

        Raises
        ------
        NotImplementedError
            Always, when called on this base class.
        """
        raise NotImplementedError()

    def set_ensemble_module(self, ensemble_module, *args, **kwargs) -> "BaseClassifier":
        """
        Set the ensemble module used in current solver.

        Parameters
        ----------
        ensemble_module: autogl.module.ensemble.BaseEnsembler or str or None
            The (name of) ensemble module used to ensemble the multi-models found.
            Disable ensemble by setting it to ``None``.

        *args, **kwargs:
            Only used when ``ensemble_module`` is a ``str``: forwarded to the
            constructor looked up in ``ENSEMBLE_DICT``.

        Returns
        -------
        self: autogl.solver.BaseSolver
            A reference of current solver.

        Raises
        ------
        KeyError
            If ``ensemble_module`` is a ``str`` that is not a key of
            ``ENSEMBLE_DICT``.
        ValueError
            If ``ensemble_module`` is neither ``None``, a ``str`` nor a
            ``BaseEnsembler`` instance.
        """
        # load ensemble module
        if ensemble_module is None:
            self.ensemble_module = None
        elif isinstance(ensemble_module, BaseEnsembler):
            # A ready-made ensembler instance is stored as-is.
            self.ensemble_module = ensemble_module
        elif isinstance(ensemble_module, str):
            if ensemble_module not in ENSEMBLE_DICT:
                raise KeyError("cannot find ensemble module %s." % (ensemble_module))
            # Build the ensembler from its registered name.
            self.ensemble_module = ENSEMBLE_DICT[ensemble_module](*args, **kwargs)
        else:
            # The exception must actually be raised; merely constructing it
            # would silently ignore the unsupported argument.
            raise ValueError(
                "need ensemble module to be str or a BaseEnsembler instance, "
                "get %s instead." % (type(ensemble_module),)
            )
        # Return the solver itself so that setter calls can be chained, as
        # promised by the return annotation.
        return self
|