You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

_basic.py 2.6 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. import numpy as np
  2. import torch
  3. import typing as _typing
  4. from autogl.data.graph import GeneralStaticGraph
  5. from .._base_feature_engineer import BaseFeatureEngineer
  6. from .._feature_engineer_registry import FeatureEngineerUniversalRegistry
  7. class BaseFeatureSelector(BaseFeatureEngineer):
  8. def __init__(self):
  9. self._selection = _typing.Optional[torch.Tensor] = None
  10. super(BaseFeatureSelector, self).__init__()
  11. def _transform(self, static_graph: GeneralStaticGraph) -> GeneralStaticGraph:
  12. if (
  13. 'x' in static_graph.nodes.data and
  14. self._selection not in (Ellipsis, None) and
  15. isinstance(self._selection, torch.Tensor) and
  16. torch.is_tensor(self._selection) and self._selection.dim() == 1
  17. ):
  18. static_graph.nodes.data['x'] = static_graph.nodes.data['x'][:, self._selection]
  19. if (
  20. 'feat' in static_graph.nodes.data and
  21. self._selection not in (Ellipsis, None) and
  22. isinstance(self._selection, torch.Tensor) and
  23. torch.is_tensor(self._selection) and self._selection.dim() == 1
  24. ):
  25. static_graph.nodes.data['feat'] = static_graph.nodes.data['feat'][:, self._selection]
  26. return static_graph
  27. @FeatureEngineerUniversalRegistry.register_feature_engineer("FilterConstant")
  28. class FilterConstant(BaseFeatureSelector):
  29. r"""drop constant features"""
  30. def _fit(self, static_graph: GeneralStaticGraph) -> GeneralStaticGraph:
  31. if (
  32. 'x' in static_graph.nodes.data and
  33. self._selection not in (Ellipsis, None) and
  34. isinstance(self._selection, torch.Tensor) and
  35. torch.is_tensor(self._selection) and self._selection.dim() == 1
  36. ):
  37. feature: _typing.Optional[np.ndarray] = static_graph.nodes.data['x'].numpy()
  38. elif (
  39. 'feat' in static_graph.nodes.data and
  40. self._selection not in (Ellipsis, None) and
  41. isinstance(self._selection, torch.Tensor) and
  42. torch.is_tensor(self._selection) and self._selection.dim() == 1
  43. ):
  44. feature: _typing.Optional[np.ndarray] = static_graph.nodes.data['feat'].numpy()
  45. else:
  46. feature: _typing.Optional[np.ndarray] = None
  47. self._selection: _typing.Optional[torch.Tensor] = torch.from_numpy(
  48. np.where(np.all(feature == feature[0, :], axis=0) == np.array(False))[0]
  49. if feature is not None and isinstance(feature, np.ndarray) and feature.ndim == 2
  50. else None
  51. )
  52. return static_graph