Debug for Default Implementation of GeneralStaticGraph Debug for OGB nodes datasets Improvement for feature generators and graph feature extractors to support any generic provided conventional datatags/v0.3.1
| @@ -154,7 +154,7 @@ class HeterogeneousNodesContainerImplementation(HeterogeneousNodesContainer): | |||
| if data.size(0) != obsolete_data.size(0): | |||
| raise ValueError | |||
| elif len(self.__nodes_data.get(__node_t)) > 0: | |||
| num_nodes: int = self.__nodes_data.get(__node_t).get(list(self.node_types)[0]).size(0) | |||
| num_nodes: int = self.__nodes_data[__node_t][list(self.__nodes_data[__node_t].keys())[0]].size(0) | |||
| if data.size(0) != num_nodes: | |||
| raise ValueError | |||
| self.__nodes_data[__node_t][data_key] = data | |||
| @@ -61,7 +61,7 @@ class HomogeneousStaticGraphToNetworkX: | |||
| homogeneous_static_graph.edges.data[data_key].squeeze().tolist() | |||
| ) | |||
| for i, (u, v) in enumerate(homogeneous_static_graph.edges.connections.t().tolist()): | |||
| if (__remove_self_loops and v == u) or (__to_undirected and v > u): | |||
| if __remove_self_loops and v == u: | |||
| continue | |||
| g.add_edge(u, v) | |||
| for data_key in edges_data: | |||
| @@ -57,11 +57,11 @@ class _OGBNDatasetUtil(_OGBDatasetUtil): | |||
| homogeneous_static_graph.nodes.data['train_mask'] = index_to_mask( | |||
| torch.from_numpy(train_index), ogbn_data['num_nodes'] | |||
| ) | |||
| if val_index not in (Ellipsis, None) and isinstance(val_index, np.ndarray): | |||
| if isinstance(val_index, np.ndarray): | |||
| homogeneous_static_graph.nodes.data['val_mask'] = index_to_mask( | |||
| torch.from_numpy(val_index), ogbn_data['num_nodes'] | |||
| ) | |||
| if test_index not in (Ellipsis, None) and isinstance(test_index, np.ndarray): | |||
| if isinstance(test_index, np.ndarray): | |||
| homogeneous_static_graph.nodes.data['test_mask'] = index_to_mask( | |||
| torch.from_numpy(test_index), ogbn_data['num_nodes'] | |||
| ) | |||
| @@ -105,8 +105,7 @@ class OGBNProductsDataset(InMemoryStaticGraphSet): | |||
| super(OGBNProductsDataset, self).__init__([ | |||
| _OGBNDatasetUtil.ogbn_dataset_to_general_static_graph( | |||
| ogbn_dataset, "y", | |||
| {"node_feat": "x"}, | |||
| {"edge_feat": "edge_feat"} | |||
| {"node_feat": "x"} | |||
| ) | |||
| ]) | |||
| @@ -119,10 +118,7 @@ class OGBNProteinsDataset(InMemoryStaticGraphSet): | |||
| super(OGBNProteinsDataset, self).__init__([ | |||
| _OGBNDatasetUtil.ogbn_dataset_to_general_static_graph( | |||
| ogbn_dataset, "label", | |||
| { | |||
| "node_feat": "feat", | |||
| "node_species": "species" | |||
| }, | |||
| {"node_species": "species"}, | |||
| {"edge_feat": "edge_feat"} | |||
| ) | |||
| ]) | |||
| @@ -130,10 +126,7 @@ class OGBNProteinsDataset(InMemoryStaticGraphSet): | |||
| super(OGBNProteinsDataset, self).__init__([ | |||
| _OGBNDatasetUtil.ogbn_dataset_to_general_static_graph( | |||
| ogbn_dataset, "y", | |||
| { | |||
| "node_feat": "x", | |||
| "node_species": "species" | |||
| }, | |||
| {"node_species": "species"}, | |||
| {"edge_feat": "edge_feat"} | |||
| ) | |||
| ]) | |||
| @@ -150,8 +143,7 @@ class OGBNArxivDataset(InMemoryStaticGraphSet): | |||
| { | |||
| "node_feat": "feat", | |||
| "node_year": "year" | |||
| }, | |||
| {"edge_feat": "edge_feat"} | |||
| } | |||
| ) | |||
| ]) | |||
| elif _backend.DependentBackend.is_pyg(): | |||
| @@ -161,8 +153,7 @@ class OGBNArxivDataset(InMemoryStaticGraphSet): | |||
| { | |||
| "node_feat": "x", | |||
| "node_year": "year" | |||
| }, | |||
| {"edge_feat": "edge_feat"} | |||
| } | |||
| ) | |||
| ]) | |||
| @@ -178,8 +169,7 @@ class OGBNPapers100MDataset(InMemoryStaticGraphSet): | |||
| { | |||
| "node_feat": "feat", | |||
| "node_year": "year" | |||
| }, | |||
| {"edge_feat": "edge_feat"} | |||
| } | |||
| ) | |||
| ]) | |||
| elif _backend.DependentBackend.is_pyg(): | |||
| @@ -189,8 +179,7 @@ class OGBNPapers100MDataset(InMemoryStaticGraphSet): | |||
| { | |||
| "node_feat": "x", | |||
| "node_year": "year" | |||
| }, | |||
| {"edge_feat": "edge_feat"} | |||
| } | |||
| ) | |||
| ]) | |||
| @@ -1,28 +1,27 @@ | |||
| import copy | |||
| import logging | |||
| import torch | |||
| import typing as _typing | |||
| from autogl.data.graph import GeneralStaticGraph | |||
| from autogl.data import InMemoryStaticGraphSet | |||
| from autogl.data import Dataset | |||
| from ...utils import get_logger | |||
| LOGGER = get_logger("FeatureEngineer") | |||
| LOGGER = logging.getLogger("FeatureEngineer") | |||
| class _BaseFeatureEngineer: | |||
| def __and__(self, other): | |||
| raise NotImplementedError | |||
| def fit( | |||
| self, in_memory_static_graph_set: InMemoryStaticGraphSet, | |||
| inplace: bool = True | |||
| ): | |||
| def fit_transform(self, dataset: Dataset, inplace=True) -> Dataset: | |||
| """ | |||
| Fit and transform dataset inplace or not w.r.t bool argument ``inplace`` | |||
| """ | |||
| dataset = self.fit(dataset) | |||
| return self.transform(dataset, inplace=inplace) | |||
| def fit(self, dataset: Dataset) -> Dataset: | |||
| raise NotImplementedError | |||
| def transform( | |||
| self, in_memory_static_graph_set: InMemoryStaticGraphSet, | |||
| inplace: bool = True | |||
| ) -> InMemoryStaticGraphSet: | |||
| def transform(self, dataset: Dataset, inplace: bool = True) -> Dataset: | |||
| raise NotImplementedError | |||
| @@ -42,73 +41,50 @@ class _ComposedFeatureEngineer(_BaseFeatureEngineer): | |||
| def __and__(self, other: _BaseFeatureEngineer): | |||
| return _ComposedFeatureEngineer((self, other)) | |||
| def fit(self, in_memory_static_graph_set, inplace: bool = True): | |||
| def fit(self, dataset) -> Dataset: | |||
| for fe in self.fe_components: | |||
| fe.fit(in_memory_static_graph_set, inplace) | |||
| dataset = fe.fit(dataset) | |||
| return dataset | |||
| def transform( | |||
| self, in_memory_static_graph_set, | |||
| inplace: bool = True | |||
| ): | |||
| def transform(self, dataset: Dataset, inplace: bool = True) -> Dataset: | |||
| for fe in self.fe_components: | |||
| in_memory_static_graph_set = fe.transform( | |||
| in_memory_static_graph_set, inplace | |||
| ) | |||
| return in_memory_static_graph_set | |||
| dataset = fe.transform(dataset, inplace) | |||
| return dataset | |||
| class BaseFeatureEngineer: | |||
| class BaseFeature(_BaseFeatureEngineer): | |||
| def __init__(self, multi_graph: bool = True, subgraph=False): | |||
| self._multi_graph: bool = multi_graph | |||
| def __and__(self, other): | |||
| return _ComposedFeatureEngineer((self, other)) | |||
| @classmethod | |||
| def __reset_graph_set( | |||
| cls, graphs: _typing.Sequence[GeneralStaticGraph], | |||
| in_memory_static_graph_set: InMemoryStaticGraphSet | |||
| ): | |||
| in_memory_static_graph_set.reset_dataset(graphs) | |||
| def _preprocess(self, static_graph: GeneralStaticGraph) -> GeneralStaticGraph: | |||
| return static_graph | |||
| def _preprocess(self, data: _typing.Any) -> _typing.Any: | |||
| return data | |||
| def _fit(self, static_graph: GeneralStaticGraph) -> GeneralStaticGraph: | |||
| return static_graph | |||
| def _fit(self, data: _typing.Any) -> _typing.Any: | |||
| return data | |||
| def _transform(self, static_graph: GeneralStaticGraph) -> GeneralStaticGraph: | |||
| return static_graph | |||
| def _transform(self, data: _typing.Any) -> _typing.Any: | |||
| return data | |||
| def _postprocess(self, static_graph: GeneralStaticGraph) -> GeneralStaticGraph: | |||
| return static_graph | |||
| def _postprocess(self, data: _typing.Any) -> _typing.Any: | |||
| return data | |||
| def fit( | |||
| self, in_memory_static_graph_set: InMemoryStaticGraphSet, | |||
| inplace: bool = True | |||
| ): | |||
| if not inplace: | |||
| in_memory_static_graph_set = copy.deepcopy(in_memory_static_graph_set) | |||
| def fit(self, dataset: Dataset) -> Dataset: | |||
| with torch.no_grad(): | |||
| __graphs: _typing.Sequence[GeneralStaticGraph] = [ | |||
| self._postprocess(self._transform(self._fit(self._preprocess(g)))) | |||
| for g in in_memory_static_graph_set | |||
| ] | |||
| self.__reset_graph_set(__graphs, in_memory_static_graph_set) | |||
| def transform( | |||
| self, in_memory_static_graph_set: InMemoryStaticGraphSet, | |||
| inplace: bool = True | |||
| ) -> InMemoryStaticGraphSet: | |||
| for i, data in enumerate(dataset): | |||
| dataset[i] = self._postprocess(self._transform(self._fit(self._preprocess(data)))) | |||
| return dataset | |||
| def transform(self, dataset: Dataset, inplace: bool = True) -> Dataset: | |||
| if not inplace: | |||
| in_memory_static_graph_set = copy.deepcopy(in_memory_static_graph_set) | |||
| dataset = copy.deepcopy(dataset) | |||
| with torch.no_grad(): | |||
| __graphs: _typing.Sequence[GeneralStaticGraph] = [ | |||
| self._postprocess(self._transform(self._preprocess(g))) | |||
| for g in in_memory_static_graph_set | |||
| ] | |||
| return in_memory_static_graph_set | |||
| for i, data in enumerate(dataset): | |||
| dataset[i] = self._postprocess(self._transform(self._preprocess(data))) | |||
| return dataset | |||
| class BaseFeature(BaseFeatureEngineer): | |||
| class BaseFeatureEngineer(BaseFeature): | |||
| ... | |||
| @@ -1,37 +0,0 @@ | |||
| import typing as _typing | |||
| from . import _base_feature_engineer | |||
| class _ComposedFeatureEngineer(_base_feature_engineer.BaseFeatureEngineer): | |||
| ... | |||
| class ComposedFeatureEngineer(_ComposedFeatureEngineer): | |||
| @property | |||
| def fe_components(self) -> _typing.Iterable[_base_feature_engineer.BaseFeatureEngineer]: | |||
| raise NotImplementedError # todo | |||
| def __init__(self, feature_engineers: _typing.Iterable[_base_feature_engineer.BaseFeatureEngineer]): | |||
| super(ComposedFeatureEngineer, self).__init__() | |||
| self.__fe_components: _typing.List[_base_feature_engineer.BaseFeatureEngineer] = [] | |||
| for fe in feature_engineers: | |||
| if isinstance(fe, ComposedFeatureEngineer): | |||
| self.__fe_components.extend(fe.fe_components) | |||
| elif isinstance(fe, _base_feature_engineer.BaseFeatureEngineer): | |||
| self.__fe_components.append(fe) | |||
| else: | |||
| raise TypeError | |||
| def fit(self, in_memory_static_graph_set, inplace: bool = True): | |||
| for fe in self.fe_components: | |||
| fe.fit(in_memory_static_graph_set, inplace) | |||
| def transform( | |||
| self, in_memory_static_graph_set, | |||
| inplace: bool = True | |||
| ): | |||
| for fe in self.fe_components: | |||
| in_memory_static_graph_set = fe.transform( | |||
| in_memory_static_graph_set, inplace | |||
| ) | |||
| return in_memory_static_graph_set | |||
| @@ -7,29 +7,29 @@ from .._feature_engineer_registry import FeatureEngineerUniversalRegistry | |||
| class BaseFeatureGenerator(BaseFeatureEngineer): | |||
| def _preprocess(self, static_graph: GeneralStaticGraph) -> GeneralStaticGraph: | |||
| if not ( | |||
| static_graph.nodes.is_homogeneous and | |||
| static_graph.edges.is_homogeneous | |||
| ): | |||
| raise ValueError("Provided static graph must be homogeneous") | |||
| else: | |||
| return static_graph | |||
| def _extract_nodes_feature(self, data: autogl.data.Data) -> torch.Tensor: | |||
| raise NotImplementedError | |||
| @classmethod | |||
| def __to_data(cls, homogeneous_static_graph: GeneralStaticGraph) -> autogl.data.Data: | |||
| def __transform_homogeneous_static_graph( | |||
| self, homogeneous_static_graph: GeneralStaticGraph | |||
| ) -> GeneralStaticGraph: | |||
| if not ( | |||
| homogeneous_static_graph.nodes.is_homogeneous and | |||
| homogeneous_static_graph.edges.is_homogeneous | |||
| ): | |||
| raise ValueError("Provided static graph must be homogeneous") | |||
| if 'x' in homogeneous_static_graph.nodes.data: | |||
| feature_key: _typing.Optional[str] = 'x' | |||
| features: _typing.Optional[torch.Tensor] = ( | |||
| homogeneous_static_graph.nodes.data['x'] | |||
| ) | |||
| elif 'feat' in homogeneous_static_graph.nodes.data: | |||
| feature_key: _typing.Optional[str] = 'feat' | |||
| features: _typing.Optional[torch.Tensor] = ( | |||
| homogeneous_static_graph.nodes.data['feat'] | |||
| ) | |||
| else: | |||
| feature_key: _typing.Optional[str] = None | |||
| features: _typing.Optional[torch.Tensor] = None | |||
| if 'y' in homogeneous_static_graph.nodes.data: | |||
| label: _typing.Optional[torch.Tensor] = ( | |||
| @@ -57,47 +57,37 @@ class BaseFeatureGenerator(BaseFeatureEngineer): | |||
| x=features, y=label | |||
| ) | |||
| setattr(data, "edge_weight", edge_weight) | |||
| return data | |||
| def _transform(self, homogeneous_static_graph: GeneralStaticGraph) -> GeneralStaticGraph: | |||
| nodes_features: torch.Tensor = self._extract_nodes_feature( | |||
| self.__to_data(homogeneous_static_graph) | |||
| ) | |||
| if not isinstance(nodes_features, torch.Tensor): | |||
| raise TypeError | |||
| elif nodes_features.dim() == 0: | |||
| raise ValueError | |||
| elif nodes_features.dim() == 1: | |||
| nodes_features = nodes_features.view(-1, 1) | |||
| if 'x' in homogeneous_static_graph.nodes.data: | |||
| x: torch.Tensor = ( | |||
| homogeneous_static_graph.nodes.data['x'].view(-1, 1) | |||
| if homogeneous_static_graph.nodes.data['x'].dim() == 1 | |||
| else homogeneous_static_graph.nodes.data['x'] | |||
| ) | |||
| assert nodes_features.size(0) == x.size(0) | |||
| assert nodes_features.dim() == x.dim() == 2 | |||
| homogeneous_static_graph.nodes.data['x'] = torch.cat( | |||
| [x, nodes_features.to(x.dtype)], dim=-1 | |||
| ) | |||
| elif 'feat' in homogeneous_static_graph.nodes.data: | |||
| x: torch.Tensor = ( | |||
| homogeneous_static_graph.nodes.data['feat'].view(-1, 1) | |||
| if homogeneous_static_graph.nodes.data['feat'].dim() == 1 | |||
| else homogeneous_static_graph.nodes.data['feat'] | |||
| extracted_features: torch.Tensor = self._extract_nodes_feature(data) | |||
| if isinstance(feature_key, str): | |||
| nodes_features: torch.Tensor = ( | |||
| homogeneous_static_graph.nodes.data[feature_key].view(-1, 1) | |||
| if homogeneous_static_graph.nodes.data[feature_key].dim() == 1 | |||
| else homogeneous_static_graph.nodes.data[feature_key] | |||
| ) | |||
| assert nodes_features.size(0) == x.size(0) | |||
| assert nodes_features.dim() == x.dim() == 2 | |||
| homogeneous_static_graph.nodes.data['feat'] = torch.cat( | |||
| [x, nodes_features.to(x.dtype)], dim=-1 | |||
| assert extracted_features.size(0) == nodes_features.size(0) | |||
| assert extracted_features.dim() == nodes_features.dim() == 2 | |||
| homogeneous_static_graph.nodes.data[feature_key] = torch.cat( | |||
| [ | |||
| nodes_features, | |||
| extracted_features.to(nodes_features.device) | |||
| ], | |||
| dim=-1 | |||
| ) | |||
| else: | |||
| if autogl.backend.DependentBackend.is_pyg(): | |||
| homogeneous_static_graph.nodes.data['x'] = nodes_features | |||
| homogeneous_static_graph.nodes.data['x'] = extracted_features | |||
| elif autogl.backend.DependentBackend.is_dgl(): | |||
| homogeneous_static_graph.nodes.data['feat'] = nodes_features | |||
| homogeneous_static_graph.nodes.data['feat'] = extracted_features | |||
| return homogeneous_static_graph | |||
| def _transform(self, data: _typing.Any) -> _typing.Any: | |||
| if isinstance(data, GeneralStaticGraph): | |||
| return self.__transform_homogeneous_static_graph(data) | |||
| else: | |||
| raise NotImplementedError( | |||
| f"Feature Generator only support instance of {GeneralStaticGraph} as provided data" | |||
| ) | |||
| @FeatureEngineerUniversalRegistry.register_feature_engineer("OneHot".lower()) | |||
| class OneHotFeatureGenerator(BaseFeatureGenerator): | |||
| @@ -1,6 +1,6 @@ | |||
| import netlsd | |||
| import networkx | |||
| import torch | |||
| from autogl.data.graph import GeneralStaticGraph | |||
| from autogl.data.graph.utils import conversion | |||
| from .._base_feature_engineer import BaseFeatureEngineer | |||
| @@ -25,17 +25,58 @@ class NetLSD(BaseFeatureEngineer): | |||
| self.__kwargs = kwargs | |||
| super(NetLSD, self).__init__() | |||
| def _transform(self, static_graph: GeneralStaticGraph) -> GeneralStaticGraph: | |||
| temp = netlsd.heat( | |||
| def __extract(self, nx_g: networkx.Graph) -> torch.Tensor: | |||
| return torch.tensor(netlsd.heat(nx_g, *self.__args, **self.__kwargs)).view(-1) | |||
| def __transform_homogeneous_static_graph( | |||
| self, homogeneous_static_graph: GeneralStaticGraph | |||
| ) -> GeneralStaticGraph: | |||
| if not ( | |||
| homogeneous_static_graph.nodes.is_homogeneous and | |||
| homogeneous_static_graph.edges.is_homogeneous | |||
| ): | |||
| raise ValueError("Provided static graph must be homogeneous") | |||
| dsc: torch.Tensor = self.__extract( | |||
| conversion.HomogeneousStaticGraphToNetworkX(to_undirected=True).__call__( | |||
| static_graph, to_undirected=True | |||
| ), | |||
| *self.__args, **self.__kwargs | |||
| homogeneous_static_graph, to_undirected=True | |||
| ) | |||
| ) | |||
| dsc: torch.Tensor = torch.tensor([temp]).view(-1) | |||
| if 'gf' in static_graph.data: | |||
| gf = static_graph.data['gf'].view(-1) | |||
| static_graph.data['gf'] = torch.cat([gf, dsc]) | |||
| if 'gf' in homogeneous_static_graph.data: | |||
| gf = homogeneous_static_graph.data['gf'].view(-1) | |||
| homogeneous_static_graph.data['gf'] = torch.cat([gf, dsc]) | |||
| else: | |||
| homogeneous_static_graph.data['gf'] = dsc | |||
| return homogeneous_static_graph | |||
| @classmethod | |||
| def __edge_index_to_nx_graph(cls, edge_index: torch.Tensor) -> networkx.Graph: | |||
| g: networkx.Graph = networkx.Graph() | |||
| for u, v in edge_index.t().tolist(): | |||
| if u == v: | |||
| continue | |||
| else: | |||
| g.add_edge(u, v) | |||
| return g | |||
| def __transform_data(self, data): | |||
| if not ( | |||
| hasattr(data, "edge_index") and | |||
| torch.is_tensor(data.edge_index) and | |||
| isinstance(data.edge_index, torch.Tensor) and | |||
| data.edge_index.dim() == data.edge_index.size(0) == 2 and | |||
| data.edge_index.dtype == torch.long | |||
| ): | |||
| raise TypeError("Unsupported provided data") | |||
| dsc: torch.Tensor = self.__extract(self.__edge_index_to_nx_graph(data.edge_index)) | |||
| if hasattr(data, 'gf') and isinstance(data.gf, torch.Tensor): | |||
| gf = data.gf.view(-1) | |||
| data.gf = torch.cat([gf, dsc]) | |||
| else: | |||
| data.gf = dsc | |||
| return data | |||
| def _transform(self, data): | |||
| if isinstance(data, GeneralStaticGraph): | |||
| return self.__transform_homogeneous_static_graph(data) | |||
| else: | |||
| static_graph.data['gf'] = dsc | |||
| return static_graph | |||
| return self.__transform_data(data) | |||
| @@ -28,17 +28,62 @@ class _NetworkXGraphFeatureEngineer(BaseFeatureEngineer): | |||
| self.__feature_extractor: _typing.Callable[[networkx.Graph], _typing.Any] = feature_extractor | |||
| super(_NetworkXGraphFeatureEngineer, self).__init__() | |||
| def _transform(self, static_graph: GeneralStaticGraph) -> GeneralStaticGraph: | |||
| dsc = self.__feature_extractor( | |||
| conversion.HomogeneousStaticGraphToNetworkX(to_undirected=True)(static_graph) | |||
| ) | |||
| dsc: torch.Tensor = torch.tensor([dsc]).view(-1) | |||
| if 'gf' in static_graph.data: | |||
| gf = static_graph.data['gf'].view(-1) | |||
| static_graph.data['gf'] = torch.cat([gf, dsc]) | |||
| def __transform_homogeneous_static_graph( | |||
| self, homogeneous_static_graph: GeneralStaticGraph | |||
| ) -> GeneralStaticGraph: | |||
| if not ( | |||
| homogeneous_static_graph.nodes.is_homogeneous and | |||
| homogeneous_static_graph.edges.is_homogeneous | |||
| ): | |||
| raise ValueError("Provided static graph must be homogeneous") | |||
| dsc: torch.Tensor = torch.tensor( | |||
| [ | |||
| self.__feature_extractor( | |||
| conversion.HomogeneousStaticGraphToNetworkX(to_undirected=True)(homogeneous_static_graph) | |||
| ) | |||
| ] | |||
| ).view(-1) | |||
| if 'gf' in homogeneous_static_graph.data: | |||
| gf = homogeneous_static_graph.data['gf'].view(-1) | |||
| homogeneous_static_graph.data['gf'] = torch.cat([gf, dsc]) | |||
| else: | |||
| static_graph.data['gf'] = dsc | |||
| return static_graph | |||
| homogeneous_static_graph.data['gf'] = dsc | |||
| return homogeneous_static_graph | |||
| @classmethod | |||
| def __edge_index_to_nx_graph(cls, edge_index: torch.Tensor) -> networkx.Graph: | |||
| g: networkx.Graph = networkx.Graph() | |||
| for u, v in edge_index.t().tolist(): | |||
| if u == v: | |||
| continue | |||
| else: | |||
| g.add_edge(u, v) | |||
| return g | |||
| def __transform_data(self, data): | |||
| if not ( | |||
| hasattr(data, "edge_index") and | |||
| torch.is_tensor(data.edge_index) and | |||
| isinstance(data.edge_index, torch.Tensor) and | |||
| data.edge_index.dim() == data.edge_index.size(0) == 2 and | |||
| data.edge_index.dtype == torch.long | |||
| ): | |||
| raise TypeError("Unsupported provided data") | |||
| dsc: torch.Tensor = torch.tensor( | |||
| [self.__feature_extractor(self.__edge_index_to_nx_graph(data.edge_index))] | |||
| ).view(-1) | |||
| if hasattr(data, 'gf') and isinstance(data.gf, torch.Tensor): | |||
| gf = data.gf.view(-1) | |||
| data.gf = torch.cat([gf, dsc]) | |||
| else: | |||
| data.gf = dsc | |||
| return data | |||
| def _transform(self, data): | |||
| if isinstance(data, GeneralStaticGraph): | |||
| return self.__transform_homogeneous_static_graph(data) | |||
| else: | |||
| return self.__transform_data(data) | |||
| @FeatureEngineerUniversalRegistry.register_feature_engineer("NXLargeCliqueSize") | |||
| @@ -12,7 +12,7 @@ from .autone_file import utils | |||
| from torch_geometric.data import GraphSAINTRandomWalkSampler | |||
| from ..feature.graph import SgNetLSD | |||
| from ..feature import NetLSD as SgNetLSD | |||
| from torch_geometric.data import InMemoryDataset | |||