You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

_basic.py 3.2 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. import numpy as np
  2. import torch
  3. import typing as _typing
  4. from autogl.data.graph import GeneralStaticGraph
  5. from .._base_feature_engineer import BaseFeatureEngineer
  6. from .._feature_engineer_registry import FeatureEngineerUniversalRegistry
  7. class BaseFeatureSelector(BaseFeatureEngineer):
  8. def __init__(self):
  9. self._selection : _typing.Optional[torch.Tensor] = None
  10. super(BaseFeatureSelector, self).__init__()
  11. def __transform_homogeneous_static_graph(
  12. self, static_graph: GeneralStaticGraph
  13. ) -> GeneralStaticGraph:
  14. if (
  15. 'x' in static_graph.nodes.data and
  16. self._selection not in (Ellipsis, None) and
  17. isinstance(self._selection, torch.Tensor) and
  18. torch.is_tensor(self._selection) and self._selection.dim() == 1
  19. ):
  20. static_graph.nodes.data['x'] = static_graph.nodes.data['x'][:, self._selection]
  21. if (
  22. 'feat' in static_graph.nodes.data and
  23. self._selection not in (Ellipsis, None) and
  24. isinstance(self._selection, torch.Tensor) and
  25. torch.is_tensor(self._selection) and self._selection.dim() == 1
  26. ):
  27. static_graph.nodes.data['feat'] = static_graph.nodes.data['feat'][:, self._selection]
  28. return static_graph
  29. def _transform(
  30. self, data: _typing.Union[GeneralStaticGraph, _typing.Any]
  31. ) -> _typing.Union[GeneralStaticGraph, _typing.Any]:
  32. if isinstance(data, GeneralStaticGraph):
  33. return self.__transform_homogeneous_static_graph(data)
  34. elif (
  35. hasattr(data, 'x') and isinstance(data.x, torch.Tensor) and
  36. torch.is_tensor(data.x) and data.x.dim() == 2
  37. ):
  38. data.x = data.x[:, self._selection]
  39. return data
  40. else:
  41. return data
  42. @FeatureEngineerUniversalRegistry.register_feature_engineer("FilterConstant")
  43. class FilterConstant(BaseFeatureSelector):
  44. r"""drop constant features"""
  45. def _fit(self, static_graph: GeneralStaticGraph) -> GeneralStaticGraph:
  46. if (
  47. 'x' in static_graph.nodes.data and
  48. self._selection not in (Ellipsis, None) and
  49. isinstance(self._selection, torch.Tensor) and
  50. torch.is_tensor(self._selection) and self._selection.dim() == 1
  51. ):
  52. feature: _typing.Optional[np.ndarray] = static_graph.nodes.data['x'].numpy()
  53. elif (
  54. 'feat' in static_graph.nodes.data and
  55. self._selection not in (Ellipsis, None) and
  56. isinstance(self._selection, torch.Tensor) and
  57. torch.is_tensor(self._selection) and self._selection.dim() == 1
  58. ):
  59. feature: _typing.Optional[np.ndarray] = static_graph.nodes.data['feat'].numpy()
  60. else:
  61. feature: _typing.Optional[np.ndarray] = None
  62. self._selection: _typing.Optional[torch.Tensor] = torch.from_numpy(
  63. np.where(np.all(feature == feature[0, :], axis=0) == np.array(False))[0]
  64. if feature is not None and isinstance(feature, np.ndarray) and feature.ndim == 2
  65. else None
  66. )
  67. return static_graph