You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

_auto_feature_engineer.py 9.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271
  1. import time
  2. import numpy as np
  3. import torch
  4. import typing as _typing
  5. from sklearn import preprocessing
  6. from sklearn.metrics.pairwise import cosine_similarity
  7. import tqdm
  8. import tabulate
  9. import autogl.data.graph
  10. from ._feature_engineer import FeatureEngineer
  11. from .._data_preprocessor_registry import DataPreprocessorUniversalRegistry
  12. from ._selectors import GBDTFeatureSelector
  13. from ....utils import get_logger
  14. LOGGER = get_logger("Feature")
  15. @DataPreprocessorUniversalRegistry.register_data_preprocessor("identity")
  16. class IdentityFeature(FeatureEngineer):
  17. ...
  18. @DataPreprocessorUniversalRegistry.register_data_preprocessor("OnlyConst".lower())
  19. class OnlyConstFeature(FeatureEngineer):
  20. def _transform(
  21. self, data: _typing.Union[autogl.data.graph.GeneralStaticGraph, _typing.Any]
  22. ) -> _typing.Union[autogl.data.graph.GeneralStaticGraph, _typing.Any]:
  23. if isinstance(data, autogl.data.graph.GeneralStaticGraph):
  24. for node_t in data.nodes:
  25. for candidate_feature_key in ('feat', 'x'):
  26. if candidate_feature_key in data.nodes[node_t].data:
  27. data.nodes[node_t].data[candidate_feature_key] = torch.ones(
  28. (data.nodes[node_t].data[candidate_feature_key].size(0), 1)
  29. ).to(data.nodes[node_t].data[candidate_feature_key])
  30. elif len(data.nodes[node_t].data) > 0:
  31. _ref = data.nodes[node_t].data[list(data.nodes[node_t].data)[0]]
  32. data.nodes[node_t].data[candidate_feature_key] = (
  33. torch.ones((_ref.size(0), 1)).to(_ref)
  34. )
  35. else:
  36. data.nodes[node_t].data[candidate_feature_key] = torch.ones(
  37. (torch.unique(data.edges.connections).size(0), 1)
  38. )
  39. elif hasattr(data, 'x') and isinstance(data.x, torch.Tensor):
  40. data.x = torch.ones((data.x.shape[0], 1)).to(data.x)
  41. elif hasattr(data, 'edge_index') and isinstance(data.edge_index, torch.Tensor):
  42. data.x = torch.ones((torch.unique(data.edge_index).size(0), 1)).to(data.edge_index)
  43. else:
  44. raise ValueError("Unsupported provided data")
  45. return data
  46. def op_sum(x, nbs):
  47. res = np.zeros_like(x)
  48. for u in range(len(nbs)):
  49. nb = nbs[u]
  50. if len(nb != 0):
  51. res[u] = np.sum(x[nb], axis=0)
  52. return res
  53. def op_mean(x, nbs):
  54. res = np.zeros_like(x)
  55. for u in range(len(nbs)):
  56. nb = nbs[u]
  57. if len(nb != 0):
  58. res[u] = np.mean(x[nb], axis=0)
  59. return res
  60. def op_max(x, nbs):
  61. res = np.zeros_like(x)
  62. for u in range(len(nbs)):
  63. nb = nbs[u]
  64. if len(nb != 0):
  65. res[u] = np.max(x[nb], axis=0)
  66. return res
  67. def op_min(x, nbs):
  68. res = np.zeros_like(x)
  69. for u in range(len(nbs)):
  70. nb = nbs[u]
  71. if len(nb != 0):
  72. res[u] = np.min(x[nb], axis=0)
  73. return res
  74. def op_prod(x, nbs):
  75. res = np.zeros_like(x)
  76. for u in range(len(nbs)):
  77. nb = nbs[u]
  78. if len(nb != 0):
  79. res[u] = np.prod(x[nb], axis=0)
  80. return res
  81. mms = preprocessing.MinMaxScaler()
  82. ss = preprocessing.StandardScaler()
  83. def scale(x):
  84. return ss.fit_transform(x)
  85. class Timer:
  86. def __init__(self, timebudget=None):
  87. self._timebudget = timebudget
  88. self._esti_time = 0
  89. self._g_start = time.time()
  90. def start(self):
  91. self._start = time.time()
  92. def end(self):
  93. time_use = time.time() - self._start
  94. self._esti_time = (self._esti_time + time_use) / 2
  95. def is_timeout(self):
  96. timebudget = self._timebudget
  97. if timebudget:
  98. timebudget = self._timebudget - (time.time() - self._g_start)
  99. if timebudget < self._esti_time:
  100. return True
  101. return False
  102. @DataPreprocessorUniversalRegistry.register_data_preprocessor('DeepGL'.lower())
  103. class AutoFeatureEngineer(FeatureEngineer):
  104. r"""
  105. Notes
  106. -----
  107. An implementation of auto feature engineering method Deepgl [#]_ ,which iteratively generates features by aggregating neighbour features
  108. and select a fixed number of features to automatically add important graph-aware features.
  109. References
  110. ----------
  111. .. [#] Rossi, R. A., Zhou, R., & Ahmed, N. K. (2020).
  112. Deep Inductive Graph Representation Learning.
  113. IEEE Transactions on Knowledge and Data Engineering, 32(3), 438–452.
  114. https://doi.org/10.1109/TKDE.2018.2878247
  115. Parameters
  116. ----------
  117. fix_length : int
  118. fixed number of features for every epoch. The final number of features added will be
  119. ``fixlen`` \times ``max_epoch``, 200 \times 5 in default.
  120. max_epoch : int
  121. number of epochs in total process.
  122. time_budget : int
  123. timebudget(seconds) for the feature engineering process, None for no time budget . Note that
  124. this time budget is a soft budget ,which is obtained by rough time estimation through previous iterations and
  125. may finally exceed the actual timebudget
  126. y_sel_func : Callable
  127. feature selector function object for selection at each iteration ,lightgbm in default. Note that in original paper,
  128. connected components of feature graph is used , and you may implement it by yourself if you want.
  129. verbosity : int
  130. hide any infomation except error and fatal if ``verbosity`` < 1
  131. """
  132. def __init__(
  133. self,
  134. fix_length: int = 200,
  135. max_epoch: int = 5,
  136. time_budget: _typing.Optional[int] = None,
  137. feature_selector=GBDTFeatureSelector,
  138. verbosity: int = 0,
  139. *args, **kwargs
  140. ):
  141. super(AutoFeatureEngineer, self).__init__()
  142. self._ops = [op_sum, op_mean, op_max, op_min]
  143. self._sim = cosine_similarity
  144. self._fixlen = fix_length
  145. self._max_epoch = max_epoch
  146. self._timebudget = time_budget
  147. self._feature_selector = feature_selector(
  148. fix_length, verbose_eval=verbosity >= 1, *args, **kwargs
  149. )
  150. self._verbosity = verbosity
  151. def _gen(self, x) -> np.ndarray:
  152. res = []
  153. for i, op in enumerate(self._ops):
  154. res.append(op(x, self.__neighbours))
  155. res = np.concatenate(res, axis=1)
  156. return res
  157. def _fit(self, homogeneous_static_graph: autogl.data.graph.GeneralStaticGraph):
  158. if not (
  159. homogeneous_static_graph.nodes.is_homogeneous and
  160. homogeneous_static_graph.edges.is_homogeneous
  161. ):
  162. raise ValueError
  163. if 'x' in homogeneous_static_graph.nodes.data:
  164. _feature_key = 'x'
  165. _original_features: torch.Tensor = (
  166. homogeneous_static_graph.nodes.data['x']
  167. )
  168. elif 'feat' in homogeneous_static_graph.nodes.data:
  169. _feature_key = 'feat'
  170. _original_features: torch.Tensor = (
  171. homogeneous_static_graph.nodes.data['feat']
  172. )
  173. else:
  174. raise ValueError
  175. num_nodes = _original_features.size(0)
  176. neighbours = [[] for _ in range(num_nodes)]
  177. for u, v in homogeneous_static_graph.edges.connections.t().numpy():
  178. neighbours[u].append(v)
  179. self.__neighbours: _typing.Sequence[np.ndarray] = tuple(
  180. [np.array(v) for v in neighbours]
  181. )
  182. x: np.ndarray = _original_features.numpy()
  183. gx: np.ndarray = x.copy()
  184. verbs = []
  185. soft_timer = Timer(self._timebudget)
  186. self._selection = []
  187. for epoch in tqdm.tqdm(range(self._max_epoch), disable=self._verbosity <= 0):
  188. soft_timer.start()
  189. verb = [epoch, gx.shape[1]]
  190. gx = self._gen(gx)
  191. gx = scale(gx)
  192. verb.append(gx.shape[1])
  193. homogeneous_static_graph.nodes.data[_feature_key] = torch.from_numpy(gx)
  194. self._feature_selector._fit(homogeneous_static_graph)
  195. self._selection.append(self._feature_selector._selection)
  196. homogeneous_static_graph = self._feature_selector._transform(
  197. homogeneous_static_graph
  198. )
  199. gx: np.ndarray = homogeneous_static_graph.nodes.data[_feature_key].numpy()
  200. verb.append(gx.shape[1])
  201. x = np.concatenate([x, gx], axis=1)
  202. verbs.append(verb)
  203. soft_timer.end()
  204. if soft_timer.is_timeout():
  205. break
  206. if self._verbosity >= 1:
  207. LOGGER.info(
  208. tabulate.tabulate(verbs, headers="epoch origin after-gen after-sel".split())
  209. )
  210. homogeneous_static_graph.nodes.data[_feature_key] = torch.from_numpy(x)
  211. return homogeneous_static_graph
  212. def _transform(self, homogeneous_static_graph: autogl.data.graph.GeneralStaticGraph):
  213. if not (
  214. homogeneous_static_graph.nodes.is_homogeneous and
  215. homogeneous_static_graph.edges.is_homogeneous
  216. ):
  217. raise ValueError
  218. if 'x' in homogeneous_static_graph.nodes.data:
  219. _feature_key = 'x'
  220. _original_features: torch.Tensor = (
  221. homogeneous_static_graph.nodes.data['x']
  222. )
  223. elif 'feat' in homogeneous_static_graph.nodes.data:
  224. _feature_key = 'feat'
  225. _original_features: torch.Tensor = (
  226. homogeneous_static_graph.nodes.data['feat']
  227. )
  228. else:
  229. raise ValueError
  230. x: np.ndarray = _original_features.numpy()
  231. gx: np.ndarray = x.copy()
  232. for selection in self._selection:
  233. gx = scale(self._gen(gx))[:, selection]
  234. x = np.concatenate([x, gx], axis=1)
  235. homogeneous_static_graph.nodes.data[_feature_key] = torch.from_numpy(x)
  236. return homogeneous_static_graph