You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

base.py 1.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. """
  2. Base solver for classification problems
  3. """
  4. from typing import Any
  5. from ..base import BaseSolver
  6. from ...module.ensemble import ENSEMBLE_DICT
  7. from ...module import BaseEnsembler
  8. class BaseClassifier(BaseSolver):
  9. """
  10. Base solver for classification problems
  11. """
  12. def predict_proba(self, *args, **kwargs) -> Any:
  13. """
  14. Predict the node probability.
  15. Returns
  16. -------
  17. result: Any
  18. The predicted probability
  19. """
  20. raise NotImplementedError()
  21. def set_ensemble_module(self, ensemble_module, *args, **kwargs) -> "BaseClassifier":
  22. """
  23. Set the ensemble module used in current solver.
  24. Parameters
  25. ----------
  26. ensemble_module: autogl.module.ensemble.BaseEnsembler or str or None
  27. The (name of) ensemble module used to ensemble the multi-models found.
  28. Disable ensemble by setting it to ``None``.
  29. Returns
  30. -------
  31. self: autogl.solver.BaseSolver
  32. A reference of current solver.
  33. """
  34. # load ensemble module
  35. if ensemble_module is None:
  36. self.ensemble_module = None
  37. elif isinstance(ensemble_module, BaseEnsembler):
  38. self.ensemble_module = ensemble_module
  39. elif isinstance(ensemble_module, str):
  40. if ensemble_module in ENSEMBLE_DICT:
  41. self.ensemble_module = ENSEMBLE_DICT[ensemble_module](*args, **kwargs)
  42. else:
  43. raise KeyError("cannot find ensemble module %s." % (ensemble_module))
  44. else:
  45. ValueError(
  46. "need ensemble module to be str or a BaseEnsembler instance, get",
  47. type(ensemble_module),
  48. "instead.",
  49. )