You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

stacking.py 6.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. """
  2. Ensemble module.
  3. """
  4. import numpy as np
  5. import torch
  6. import torch.nn.functional as F
  7. from sklearn.ensemble import GradientBoostingClassifier
  8. from sklearn.linear_model import LogisticRegression
  9. from .base import BaseEnsembler
  10. from . import register_ensembler
  11. from ...utils import get_logger
# Module-level logger used by the stacking ensembler for its debug and error messages.
STACKING_LOGGER = get_logger("stacking")
  13. @register_ensembler("stacking")
  14. class Stacking(BaseEnsembler):
  15. """
  16. A stacking ensembler. Currently we support gradient boosting as the meta-algorithm.
  17. Parameters
  18. ----------
  19. meta_model : 'gbm' or 'glm' (Optional)
  20. Type of the stacker:
  21. 'gbm' : Gradient boosting model. This is the default.
  22. 'glm' : Generalized linear model.
  23. meta_params : a ``dict`` (Optional)
  24. When ``meta_model`` is specified, you can customize the parameters of the stacker.
  25. If this argument is not provided, the stacker will be configurated with default parameters.
  26. Default ``{}``.
  27. """
  28. def __init__(self, meta_model="gbm", meta_params={}, *args, **kwargs):
  29. super().__init__()
  30. self.model_name = meta_model.lower()
  31. assert self.model_name in [
  32. "gbm",
  33. "glm",
  34. ], "Only support gbm and glm when ensemble!"
  35. self.meta_params = meta_params
  36. def fit(
  37. self, predictions, label, identifiers, feval, n_classes="auto", *args, **kwargs
  38. ):
  39. """
  40. Fit the ensembler to the given data using Stacking method.
  41. Parameters
  42. ----------
  43. predictions : a list of np.ndarray
  44. Predictions of base learners (corresponding to the elements in identifiers).
  45. label : a list of int
  46. Class labels of instances.
  47. identifiers : a list of str
  48. The names of base models.
  49. feval : (a list of) autogl.module.train.evaluate
  50. Performance evaluation metrices.
  51. n_classes : int or str (Optional)
  52. The number of classes. Default as ``'auto'``, which will use maximum label.
  53. Returns
  54. -------
  55. (a list of) float
  56. The validation performance of the final stacker.
  57. """
  58. n_classes = n_classes if not n_classes == "auto" else max(label) + 1
  59. assert n_classes > max(
  60. label
  61. ), "Detect max label passed (%d) exceeeds" " n_classes given (%d)" % (
  62. max(label),
  63. n_classes,
  64. )
  65. assert len(identifiers) == len(
  66. set(identifiers)
  67. ), "Duplicate name" " in identifiers {} !".format(identifiers)
  68. self.fit_identifiers = identifiers
  69. if not isinstance(feval, list):
  70. feval = [feval]
  71. self._re_initialize(identifiers, len(predictions))
  72. config = self.meta_params
  73. STACKING_LOGGER.debug("meta-model name %s", self.model_name)
  74. if self.model_name == "gbm":
  75. meta_X = (
  76. torch.tensor(predictions).transpose(0, 1).flatten(start_dim=1).numpy()
  77. )
  78. meta_Y = np.array(label)
  79. config = {}
  80. model = GradientBoostingClassifier(**config)
  81. model.fit(meta_X, meta_Y)
  82. self.model = model
  83. ensemble_prediction = model.predict_proba(meta_X)
  84. elif self.model_name == "glm":
  85. meta_X = (
  86. torch.tensor(predictions).transpose(0, 1).flatten(start_dim=1).numpy()
  87. )
  88. meta_Y = np.array(label)
  89. config["multi_class"] = "auto"
  90. config["solver"] = "lbfgs"
  91. model = LogisticRegression(**config)
  92. model.fit(meta_X, meta_Y)
  93. self.model = model
  94. ensemble_prediction = model.predict_proba(meta_X)
  95. elif self.model_name == "nn":
  96. meta_X = torch.tensor(predictions).transpose(0, 1).flatten(start_dim=1)
  97. meta_Y = F.one_hot(
  98. torch.tensor(label, dtype=torch.int64), n_classes
  99. ).double()
  100. # print(meta_Y.type())
  101. n_instance, n_input = meta_X.size()
  102. n_learners = len(identifiers)
  103. fc = torch.nn.Linear(n_input, n_input // n_learners).double()
  104. config["lr"] = 1e-1
  105. # config['weight_decay'] = 1e-2
  106. optimizer = torch.optim.SGD(fc.parameters(), **config)
  107. max_epoch = 100
  108. for epoch in range(max_epoch):
  109. optimizer.zero_grad()
  110. ensemble_prediction = F.normalize(fc.forward(meta_X), dim=0)
  111. loss = F.mse_loss(ensemble_prediction, meta_Y)
  112. loss.backward()
  113. optimizer.step()
  114. self.model = fc
  115. ensemble_prediction = (
  116. F.normalize(fc.forward(meta_X), dim=0).detach().numpy()
  117. )
  118. else:
  119. STACKING_LOGGER.error(
  120. "Cannot parse stacking ensemble model name %s", self.model_name
  121. )
  122. return [fx.evaluate(ensemble_prediction, label) for fx in feval]
  123. def ensemble(self, predictions, identifiers, *args, **kwargs):
  124. """
  125. Ensemble the predictions of base models.
  126. Parameters
  127. ----------
  128. predictions : a list of ``np.ndarray``
  129. Predictions of base learners (corresponding to the elements in identifiers).
  130. identifiers : a list of ``str``
  131. The names of base models.
  132. Returns
  133. -------
  134. ``np.ndarray``
  135. The ensembled predictions.
  136. """
  137. assert len(identifiers) == len(
  138. set(identifiers)
  139. ), "Duplicate name in" " identifiers {} !".format(identifiers)
  140. assert set(self.fit_identifiers) == set(
  141. identifiers
  142. ), "Different identifiers" " passed in fit {} and ensemble {} !".format(
  143. self.fit_identifiers, identifiers
  144. )
  145. # re-order predictions if needed
  146. if not self.fit_identifiers == identifiers:
  147. re_id = [
  148. identifiers.index(identifier) for identifier in self.fit_identifiers
  149. ]
  150. predictions = [predictions[i] for i in re_id]
  151. if self.model_name in ["gbm", "glm"]:
  152. pred_packed = (
  153. torch.tensor(predictions).transpose(0, 1).flatten(start_dim=1).numpy()
  154. )
  155. return self.model.predict_proba(pred_packed)
  156. elif self.model_name in ["nn"]:
  157. pred_packed = torch.tensor(predictions).transpose(0, 1).flatten(start_dim=1)
  158. return F.normalize(self.model.forward(pred_packed), dim=0).detach().numpy()
  159. def _re_initialize(self, identifiers, n_models):
  160. self.identifiers = identifiers
  161. self.model = None