|
- import torch
- import typing as _typing
- import autogl
- from autogl.data.graph import GeneralStaticGraph
- from .._base_feature_engineer import BaseFeatureEngineer
- from .._feature_engineer_registry import FeatureEngineerUniversalRegistry
-
-
class BaseFeatureGenerator(BaseFeatureEngineer):
    """Base class for feature engineers that generate node features.

    Subclasses implement :meth:`_extract_nodes_feature`.  This class unpacks a
    homogeneous ``GeneralStaticGraph`` into an ``autogl.data.Data`` object,
    calls the extractor, and writes the generated features back into the
    graph's node data (the graph is modified in place and returned).

    Parameters
    ----------
    override_features:
        When ``True`` the generated features replace the existing node
        features; when ``False`` (the default) they are concatenated after
        the existing features along the last dimension.

    Raises
    ------
    TypeError
        If ``override_features`` is not a ``bool``.
    """

    def __init__(self, override_features: bool = False):
        super(BaseFeatureGenerator, self).__init__()
        if not isinstance(override_features, bool):
            raise TypeError(
                "override_features must be a bool, "
                f"got {type(override_features).__name__}"
            )
        self._override_features: bool = override_features

    def _extract_nodes_feature(self, data: autogl.data.Data) -> torch.Tensor:
        """Compute the generated node features for ``data``.

        Must be overridden by subclasses.  When the graph already has node
        features, the returned tensor is expected to be 2-D with one row per
        node (this is what the caller checks before merging).
        """
        raise NotImplementedError

    @staticmethod
    def _first_present_key(
        container: _typing.Any, candidate_keys: _typing.Sequence[str]
    ) -> _typing.Optional[str]:
        """Return the first of ``candidate_keys`` contained in ``container``, else ``None``."""
        for key in candidate_keys:
            if key in container:
                return key
        return None

    def __transform_homogeneous_static_graph(
        self, homogeneous_static_graph: GeneralStaticGraph
    ) -> GeneralStaticGraph:
        """Generate node features for a homogeneous static graph.

        Existing node features are looked up under ``'x'`` then ``'feat'``,
        labels under ``'y'`` then ``'label'``.  A 1-D ``'edge_weight'`` entry in
        the edge data is forwarded to the extractor; otherwise a vector of
        ones (one entry per edge) is used.

        Returns the same graph object, with its node features updated.

        Raises
        ------
        ValueError
            If the nodes or the edges of the graph are not homogeneous.
        AssertionError
            If the graph already has features and the generated tensor is not
            2-D or its row count differs from the existing features.
        """
        if not (
            homogeneous_static_graph.nodes.is_homogeneous
            and homogeneous_static_graph.edges.is_homogeneous
        ):
            raise ValueError("Provided static graph must be homogeneous")

        node_data = homogeneous_static_graph.nodes.data
        edge_data = homogeneous_static_graph.edges.data
        connections = homogeneous_static_graph.edges.connections

        # Locate existing features / labels under either naming convention.
        feature_key: _typing.Optional[str] = self._first_present_key(
            node_data, ('x', 'feat')
        )
        features: _typing.Optional[torch.Tensor] = (
            None if feature_key is None else node_data[feature_key]
        )
        label_key: _typing.Optional[str] = self._first_present_key(
            node_data, ('y', 'label')
        )
        label: _typing.Optional[torch.Tensor] = (
            None if label_key is None else node_data[label_key]
        )

        if 'edge_weight' in edge_data and edge_data['edge_weight'].dim() == 1:
            edge_weight: torch.Tensor = edge_data['edge_weight']
        else:
            # Missing or non-1-D weights: fall back to a unit weight per edge.
            # NOTE(review): torch.ones is called without a device argument, so
            # the fallback lives on torch's default device -- confirm this is
            # intended for graphs whose connections are stored elsewhere.
            edge_weight = torch.ones(connections.size(1))

        data = autogl.data.Data(edge_index=connections, x=features, y=label)
        setattr(data, "edge_weight", edge_weight)
        extracted_features: torch.Tensor = self._extract_nodes_feature(data)

        if feature_key is None:
            # No pre-existing features: store the generated ones under the key
            # matching the active backend.
            # NOTE(review): if neither backend predicate is true the generated
            # features are discarded silently -- presumably unreachable; verify.
            if autogl.backend.DependentBackend.is_pyg():
                node_data['x'] = extracted_features
            elif autogl.backend.DependentBackend.is_dgl():
                node_data['feat'] = extracted_features
            return homogeneous_static_graph

        # Re-read the stored features after extraction and treat a 1-D feature
        # vector as a single-column matrix so it can be concatenated.
        current_features: torch.Tensor = node_data[feature_key]
        nodes_features: torch.Tensor = (
            current_features.view(-1, 1)
            if current_features.dim() == 1
            else current_features
        )
        assert extracted_features.size(0) == nodes_features.size(0), (
            f"generated features have {extracted_features.size(0)} rows but "
            f"the graph has {nodes_features.size(0)} rows of node features"
        )
        assert extracted_features.dim() == nodes_features.dim() == 2, (
            "generated and existing node features must both be 2-D, got "
            f"{extracted_features.dim()}-D and {nodes_features.dim()}-D"
        )
        generated: torch.Tensor = extracted_features.to(nodes_features.device)
        node_data[feature_key] = (
            generated
            if self._override_features
            else torch.cat([nodes_features, generated], dim=-1)
        )
        return homogeneous_static_graph

    def _transform(self, data: _typing.Any) -> _typing.Any:
        """Apply the generator to ``data``.

        Only ``GeneralStaticGraph`` instances are supported; anything else
        raises ``NotImplementedError``.
        """
        if isinstance(data, GeneralStaticGraph):
            return self.__transform_homogeneous_static_graph(data)
        raise NotImplementedError(
            f"Feature Generator only support instance of {GeneralStaticGraph} as provided data"
        )
-
-
@FeatureEngineerUniversalRegistry.register_feature_engineer("OneHot".lower())
class OneHotFeatureGenerator(BaseFeatureGenerator):
    """Feature generator that gives every node a one-hot identity vector."""

    def _extract_nodes_feature(self, data: autogl.data.Data) -> torch.Tensor:
        """Return an identity matrix with one row (and column) per node.

        The node count is the first dimension of ``data.x`` when that is a
        tensor; otherwise it is the largest index in ``data.edge_index``
        plus one.
        """
        existing = data.x
        if isinstance(existing, torch.Tensor):
            node_count: int = existing.size(0)
        else:
            # No usable feature tensor: infer the count from the edge list.
            node_count = data.edge_index.max().item() + 1
        return torch.eye(node_count)
|