You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

_basic.py 3.0 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. import numpy as np
  2. import torch
  3. import typing as _typing
  4. from autogl.data.graph import GeneralStaticGraph
  5. from .._base_feature_engineer import BaseFeatureEngineer
  6. from .._feature_engineer_registry import FeatureEngineerUniversalRegistry
  7. class BaseFeatureSelector(BaseFeatureEngineer):
  8. def __init__(self):
  9. self._selection : _typing.Optional[torch.Tensor] = None
  10. super(BaseFeatureSelector, self).__init__()
  11. def __transform_homogeneous_static_graph(
  12. self, static_graph: GeneralStaticGraph
  13. ) -> GeneralStaticGraph:
  14. if (
  15. 'x' in static_graph.nodes.data and
  16. isinstance(self._selection, (torch.Tensor, np.ndarray))
  17. ):
  18. static_graph.nodes.data['x'] = static_graph.nodes.data['x'][:, self._selection]
  19. if (
  20. 'feat' in static_graph.nodes.data and
  21. isinstance(self._selection, (torch.Tensor, np.ndarray))
  22. ):
  23. static_graph.nodes.data['feat'] = static_graph.nodes.data['feat'][:, self._selection]
  24. return static_graph
  25. def _transform(
  26. self, data: _typing.Union[GeneralStaticGraph, _typing.Any]
  27. ) -> _typing.Union[GeneralStaticGraph, _typing.Any]:
  28. if isinstance(data, GeneralStaticGraph):
  29. return self.__transform_homogeneous_static_graph(data)
  30. elif (
  31. hasattr(data, 'x') and isinstance(data.x, torch.Tensor) and
  32. torch.is_tensor(data.x) and data.x.dim() == 2
  33. ):
  34. data.x = data.x[:, self._selection]
  35. return data
  36. else:
  37. return data
  38. @FeatureEngineerUniversalRegistry.register_feature_engineer("FilterConstant")
  39. class FilterConstant(BaseFeatureSelector):
  40. r"""drop constant features"""
  41. def _fit(self, static_graph: GeneralStaticGraph) -> GeneralStaticGraph:
  42. if (
  43. 'x' in static_graph.nodes.data and
  44. self._selection not in (Ellipsis, None) and
  45. isinstance(self._selection, torch.Tensor) and
  46. torch.is_tensor(self._selection) and self._selection.dim() == 1
  47. ):
  48. feature: _typing.Optional[np.ndarray] = static_graph.nodes.data['x'].numpy()
  49. elif (
  50. 'feat' in static_graph.nodes.data and
  51. self._selection not in (Ellipsis, None) and
  52. isinstance(self._selection, torch.Tensor) and
  53. torch.is_tensor(self._selection) and self._selection.dim() == 1
  54. ):
  55. feature: _typing.Optional[np.ndarray] = static_graph.nodes.data['feat'].numpy()
  56. else:
  57. feature: _typing.Optional[np.ndarray] = None
  58. self._selection: _typing.Optional[torch.Tensor] = torch.from_numpy(
  59. np.where(np.all(feature == feature[0, :], axis=0) == np.array(False))[0]
  60. if feature is not None and isinstance(feature, np.ndarray) and feature.ndim == 2
  61. else None
  62. )
  63. return static_graph