|
- import numpy as np
- import torch
- import typing as _typing
- from autogl.data.graph import GeneralStaticGraph
- from .._base_feature_engineer import BaseFeatureEngineer
- from .._feature_engineer_registry import FeatureEngineerUniversalRegistry
-
-
- class BaseFeatureSelector(BaseFeatureEngineer):
- def __init__(self):
- self._selection = _typing.Optional[torch.Tensor] = None
- super(BaseFeatureSelector, self).__init__()
-
- def _transform(self, static_graph: GeneralStaticGraph) -> GeneralStaticGraph:
- if (
- 'x' in static_graph.nodes.data and
- self._selection not in (Ellipsis, None) and
- isinstance(self._selection, torch.Tensor) and
- torch.is_tensor(self._selection) and self._selection.dim() == 1
- ):
- static_graph.nodes.data['x'] = static_graph.nodes.data['x'][:, self._selection]
- if (
- 'feat' in static_graph.nodes.data and
- self._selection not in (Ellipsis, None) and
- isinstance(self._selection, torch.Tensor) and
- torch.is_tensor(self._selection) and self._selection.dim() == 1
- ):
- static_graph.nodes.data['feat'] = static_graph.nodes.data['feat'][:, self._selection]
- return static_graph
-
-
- @FeatureEngineerUniversalRegistry.register_feature_engineer("FilterConstant")
- class FilterConstant(BaseFeatureSelector):
- r"""drop constant features"""
-
- def _fit(self, static_graph: GeneralStaticGraph) -> GeneralStaticGraph:
- if (
- 'x' in static_graph.nodes.data and
- self._selection not in (Ellipsis, None) and
- isinstance(self._selection, torch.Tensor) and
- torch.is_tensor(self._selection) and self._selection.dim() == 1
- ):
- feature: _typing.Optional[np.ndarray] = static_graph.nodes.data['x'].numpy()
- elif (
- 'feat' in static_graph.nodes.data and
- self._selection not in (Ellipsis, None) and
- isinstance(self._selection, torch.Tensor) and
- torch.is_tensor(self._selection) and self._selection.dim() == 1
- ):
- feature: _typing.Optional[np.ndarray] = static_graph.nodes.data['feat'].numpy()
- else:
- feature: _typing.Optional[np.ndarray] = None
- self._selection: _typing.Optional[torch.Tensor] = torch.from_numpy(
- np.where(np.all(feature == feature[0, :], axis=0) == np.array(False))[0]
- if feature is not None and isinstance(feature, np.ndarray) and feature.ndim == 2
- else None
- )
- return static_graph
|