| @@ -1,6 +1,6 @@ | |||
| from .data import Data | |||
| from .batch import Batch | |||
| from .dataset import Dataset | |||
| from ._dataset import Dataset, InMemoryDataset, InMemoryStaticGraphSet | |||
| from .dataloader import DataLoader, DataListLoader, DenseDataLoader | |||
| from .download import download_url | |||
| from .extract import extract_tar, extract_zip, extract_bz2, extract_gz | |||
| @@ -9,6 +9,8 @@ __all__ = [ | |||
| "Data", | |||
| "Batch", | |||
| "Dataset", | |||
| "InMemoryDataset", | |||
| "InMemoryStaticGraphSet", | |||
| "DataLoader", | |||
| "DataListLoader", | |||
| "DenseDataLoader", | |||
| @@ -0,0 +1,2 @@ | |||
| from ._dataset import Dataset, InMemoryDataset | |||
| from ._in_memory_static_graph_set import InMemoryStaticGraphSet | |||
| @@ -0,0 +1,243 @@ | |||
| import typing as _typing | |||
| _D = _typing.TypeVar('_D') | |||
| class Dataset(_typing.Iterable[_D], _typing.Sized): | |||
| def __len__(self) -> int: | |||
| raise NotImplementedError | |||
| def __iter__(self) -> _typing.Iterator[_D]: | |||
| raise NotImplementedError | |||
| def __getitem__(self, index: int) -> _D: | |||
| raise NotImplementedError | |||
| def __setitem__(self, index: int, data: _D): | |||
| raise NotImplementedError | |||
| @property | |||
| def train_split(self) -> _typing.Optional[_typing.Iterable[_D]]: | |||
| raise NotImplementedError | |||
| @property | |||
| def val_split(self) -> _typing.Optional[_typing.Iterable[_D]]: | |||
| raise NotImplementedError | |||
| @property | |||
| def test_split(self) -> _typing.Optional[_typing.Iterable[_D]]: | |||
| raise NotImplementedError | |||
| @property | |||
| def train_index(self) -> _typing.Optional[_typing.Iterable[int]]: | |||
| raise NotImplementedError | |||
| @property | |||
| def val_index(self) -> _typing.Optional[_typing.Iterable[int]]: | |||
| raise NotImplementedError | |||
| @property | |||
| def test_index(self) -> _typing.Optional[_typing.Iterable[int]]: | |||
| raise NotImplementedError | |||
| @train_index.setter | |||
| def train_index(self, train_index: _typing.Optional[_typing.Iterable[int]]): | |||
| raise NotImplementedError | |||
| @val_index.setter | |||
| def val_index(self, val_index: _typing.Optional[_typing.Iterable[int]]): | |||
| raise NotImplementedError | |||
| @test_index.setter | |||
| def test_index(self, test_index: _typing.Optional[_typing.Iterable[int]]): | |||
| raise NotImplementedError | |||
| class _FoldsContainer: | |||
| def __init__( | |||
| self, | |||
| folds: _typing.Optional[_typing.Iterable[_typing.Tuple[_typing.Sequence[int], _typing.Sequence[int]]]] = ... | |||
| ): | |||
| self._folds: _typing.Optional[_typing.List[_typing.Tuple[_typing.Sequence[int], _typing.Sequence[int]]]] = ( | |||
| list(folds) if isinstance(folds, _typing.Iterable) else None | |||
| ) | |||
| if self._folds is not None and len(self._folds) == 0: | |||
| self._folds = None | |||
| @property | |||
| def folds(self) -> _typing.Optional[_typing.Sequence[_typing.Tuple[_typing.Sequence[int], _typing.Sequence[int]]]]: | |||
| if self._folds is not None and len(self._folds) == 0: | |||
| self._folds = None | |||
| return self._folds | |||
| @folds.setter | |||
| def folds(self, folds: _typing.Optional[_typing.Iterable[_typing.Tuple[_typing.Sequence[int], _typing.Sequence[int]]]]): | |||
| self._folds: _typing.Optional[_typing.List[_typing.Tuple[_typing.Sequence[int], _typing.Sequence[int]]]] = ( | |||
| list(folds) if isinstance(folds, _typing.Iterable) else None | |||
| ) | |||
| if self._folds is not None and len(self._folds) == 0: | |||
| self._folds = None | |||
| class _FoldView: | |||
| def __init__(self, folds_container: _FoldsContainer, fold_index: int): | |||
| self._folds_container: _FoldsContainer = folds_container | |||
| self._fold_index: int = fold_index | |||
| @property | |||
| def train_index(self) -> _typing.Sequence[int]: | |||
| return self._folds_container.folds[self._fold_index][0] | |||
| @property | |||
| def val_index(self) -> _typing.Sequence[int]: | |||
| return self._folds_container.folds[self._fold_index][1] | |||
| class _FoldsView(_typing.Sequence[_FoldView]): | |||
| def __init__(self, folds_container: _FoldsContainer): | |||
| self._folds_container = folds_container | |||
| def __len__(self) -> int: | |||
| return ( | |||
| len(self._folds_container.folds) | |||
| if self._folds_container.folds is not None | |||
| else 0 | |||
| ) | |||
| def __getitem__(self, fold_index: int) -> _FoldView: | |||
| return _FoldView(self._folds_container, fold_index) | |||
| class InMemoryDataset(Dataset[_D]): | |||
| def __init__( | |||
| self, data: _typing.Iterable[_D], | |||
| train_index: _typing.Optional[_typing.Iterable[int]] = ..., | |||
| val_index: _typing.Optional[_typing.Iterable[int]] = ..., | |||
| test_index: _typing.Optional[_typing.Iterable[int]] = ... | |||
| ): | |||
| self.__data: _typing.MutableSequence[_D] = list(data) | |||
| # Materialize indices so one-shot iterables (e.g. generators) survive repeated use. | |||
| self.__train_index: _typing.Optional[_typing.Iterable[int]] = ( | |||
| tuple(train_index) if isinstance(train_index, _typing.Iterable) else None | |||
| ) | |||
| self.__val_index: _typing.Optional[_typing.Iterable[int]] = ( | |||
| tuple(val_index) if isinstance(val_index, _typing.Iterable) else None | |||
| ) | |||
| self.__test_index: _typing.Optional[_typing.Iterable[int]] = ( | |||
| tuple(test_index) if isinstance(test_index, _typing.Iterable) else None | |||
| ) | |||
| self.__folds_container: _FoldsContainer = _FoldsContainer() | |||
| @property | |||
| def folds(self) -> _typing.Optional[_FoldsView]: | |||
| return ( | |||
| _FoldsView(self.__folds_container) | |||
| if ( | |||
| self.__folds_container.folds is not None and | |||
| len(self.__folds_container.folds) > 0 | |||
| ) | |||
| else None | |||
| ) | |||
| @folds.setter | |||
| def folds( | |||
| self, | |||
| folds: _typing.Optional[ | |||
| _typing.Iterable[ | |||
| _typing.Tuple[_typing.Sequence[int], _typing.Sequence[int]] | |||
| ] | |||
| ] | |||
| ): | |||
| self.__folds_container.folds = folds | |||
| def __len__(self) -> int: | |||
| return len(self.__data) | |||
| def __iter__(self) -> _typing.Iterator[_D]: | |||
| return iter(self.__data) | |||
| def __getitem__(self, index: int) -> _D: | |||
| return self.__data[index] | |||
| def __setitem__(self, index: int, data: _D): | |||
| self.__data[index] = data | |||
| def reset_dataset(self, data: _typing.Iterable[_D]): | |||
| if not isinstance(data, _typing.Iterable): | |||
| raise TypeError | |||
| __data: _typing.MutableSequence[_D] = list(data) | |||
| # Preserve existing splits only when the dataset size is unchanged. | |||
| __preserve_info: bool = len(__data) == len(self) | |||
| self.__data: _typing.MutableSequence[_D] = __data | |||
| if not __preserve_info: | |||
| self.train_index = self.val_index = self.test_index = None | |||
| @property | |||
| def train_split(self) -> _typing.Optional[_typing.Iterable[_D]]: | |||
| return ( | |||
| [self.__data[i] for i in self.__train_index] | |||
| if isinstance(self.__train_index, _typing.Iterable) else None | |||
| ) | |||
| @property | |||
| def val_split(self) -> _typing.Optional[_typing.Iterable[_D]]: | |||
| return ( | |||
| [self.__data[i] for i in self.__val_index] | |||
| if isinstance(self.__val_index, _typing.Iterable) else None | |||
| ) | |||
| @property | |||
| def test_split(self) -> _typing.Optional[_typing.Iterable[_D]]: | |||
| return ( | |||
| [self.__data[i] for i in self.__test_index] | |||
| if isinstance(self.__test_index, _typing.Iterable) else None | |||
| ) | |||
| @property | |||
| def train_index(self) -> _typing.Optional[_typing.Iterable[int]]: | |||
| return self.__train_index | |||
| @property | |||
| def val_index(self) -> _typing.Optional[_typing.Iterable[int]]: | |||
| return self.__val_index | |||
| @property | |||
| def test_index(self) -> _typing.Optional[_typing.Iterable[int]]: | |||
| return self.__test_index | |||
| @train_index.setter | |||
| def train_index(self, train_index: _typing.Optional[_typing.Iterable[int]]): | |||
| if not (train_index is None or isinstance(train_index, _typing.Iterable)): | |||
| raise TypeError | |||
| elif train_index is None: | |||
| self.__train_index: _typing.Optional[_typing.Iterable[int]] = None | |||
| else: | |||
| # Materialize once: the checks below iterate the argument several times. | |||
| train_index = tuple(train_index) | |||
| if not all(isinstance(i, int) for i in train_index): | |||
| raise TypeError | |||
| if len(train_index) > 0 and not (0 <= min(train_index) <= max(train_index) < len(self)): | |||
| raise ValueError | |||
| self.__train_index: _typing.Optional[_typing.Iterable[int]] = train_index | |||
| @val_index.setter | |||
| def val_index(self, val_index: _typing.Optional[_typing.Iterable[int]]): | |||
| if not (val_index is None or isinstance(val_index, _typing.Iterable)): | |||
| raise TypeError | |||
| elif val_index is None: | |||
| self.__val_index: _typing.Optional[_typing.Iterable[int]] = None | |||
| else: | |||
| # Materialize once: the checks below iterate the argument several times. | |||
| val_index = tuple(val_index) | |||
| if not all(isinstance(i, int) for i in val_index): | |||
| raise TypeError | |||
| if len(val_index) > 0 and not (0 <= min(val_index) <= max(val_index) < len(self)): | |||
| raise ValueError | |||
| self.__val_index: _typing.Optional[_typing.Iterable[int]] = val_index | |||
| @test_index.setter | |||
| def test_index(self, test_index: _typing.Optional[_typing.Iterable[int]]): | |||
| if not (test_index is None or isinstance(test_index, _typing.Iterable)): | |||
| raise TypeError | |||
| elif test_index is None: | |||
| self.__test_index: _typing.Optional[_typing.Iterable[int]] = None | |||
| else: | |||
| # Materialize once: the checks below iterate the argument several times. | |||
| test_index = tuple(test_index) | |||
| if not all(isinstance(i, int) for i in test_index): | |||
| raise TypeError | |||
| if len(test_index) > 0 and not (0 <= min(test_index) <= max(test_index) < len(self)): | |||
| raise ValueError | |||
| self.__test_index: _typing.Optional[_typing.Iterable[int]] = test_index | |||
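| # --- Illustrative usage (a minimal sketch; the data objects and split | |||
| # indices below are hypothetical, not part of this module) --- | |||
| def _example_in_memory_dataset() -> InMemoryDataset: | |||
|     _examples = [object() for _ in range(4)] | |||
|     dataset = InMemoryDataset(_examples, train_index=[0, 1], val_index=[2], test_index=[3]) | |||
|     assert len(dataset) == 4 and list(dataset.train_split) == _examples[:2] | |||
|     dataset.folds = [([0, 1], [2]), ([0, 2], [1])]  # two (train, val) folds | |||
|     assert dataset.folds[0].train_index == [0, 1] | |||
|     return dataset | |||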
| @@ -0,0 +1,24 @@ | |||
| import typing as _typing | |||
| from ._dataset import InMemoryDataset | |||
| from ..graph import GeneralStaticGraph | |||
| class InMemoryStaticGraphSet(InMemoryDataset[GeneralStaticGraph]): | |||
| def __init__( | |||
| self, graphs: _typing.Iterable[GeneralStaticGraph], | |||
| train_index: _typing.Optional[_typing.Iterable[int]] = ..., | |||
| val_index: _typing.Optional[_typing.Iterable[int]] = ..., | |||
| test_index: _typing.Optional[_typing.Iterable[int]] = ... | |||
| ): | |||
| super(InMemoryStaticGraphSet, self).__init__( | |||
| graphs, train_index, val_index, test_index | |||
| ) | |||
| def __iter__(self) -> _typing.Iterator[GeneralStaticGraph]: | |||
| return super(InMemoryStaticGraphSet, self).__iter__() | |||
| def __getitem__(self, index: int) -> GeneralStaticGraph: | |||
| return super(InMemoryStaticGraphSet, self).__getitem__(index) | |||
| def __setitem__(self, index: int, data: GeneralStaticGraph): | |||
| super(InMemoryStaticGraphSet, self).__setitem__(index, data) | |||
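| # Illustrative note: an InMemoryStaticGraphSet is simply an InMemoryDataset | |||
| # whose items are GeneralStaticGraph instances, e.g. (hypothetical graphs): | |||
| #     graph_set = InMemoryStaticGraphSet([g0, g1], train_index=[0], val_index=[1]) | |||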
| @@ -1,134 +0,0 @@ | |||
| import collections | |||
| import os.path as osp | |||
| import torch.utils.data | |||
| from .makedirs import makedirs | |||
| def to_list(x): | |||
| if not isinstance(x, collections.Iterable) or isinstance(x, str): | |||
| x = [x] | |||
| return x | |||
| def files_exist(files): | |||
| return all([osp.exists(f) for f in files]) | |||
| class Dataset(torch.utils.data.Dataset): | |||
| r"""Dataset base class for creating graph datasets. | |||
| See `here <https://rusty1s.github.io/pycogdl/build/html/notes/ | |||
| create_dataset.html>`__ for the accompanying tutorial. | |||
| Args: | |||
| root (string): Root directory where the dataset should be saved. | |||
| transform (callable, optional): A function/transform that takes in an | |||
| :obj:`cogdl.data.Data` object and returns a transformed | |||
| version. The data object will be transformed before every access. | |||
| (default: :obj:`None`) | |||
| pre_transform (callable, optional): A function/transform that takes in | |||
| an :obj:`cogdl.data.Data` object and returns a | |||
| transformed version. The data object will be transformed before | |||
| being saved to disk. (default: :obj:`None`) | |||
| pre_filter (callable, optional): A function that takes in an | |||
| :obj:`cogdl.data.Data` object and returns a boolean | |||
| value, indicating whether the data object should be included in the | |||
| final dataset. (default: :obj:`None`) | |||
| """ | |||
| @property | |||
| def raw_file_names(self): | |||
| r"""The name of the files to find in the :obj:`self.raw_dir` folder in | |||
| order to skip the download.""" | |||
| raise NotImplementedError | |||
| @property | |||
| def processed_file_names(self): | |||
| r"""The name of the files to find in the :obj:`self.processed_dir` | |||
| folder in order to skip the processing.""" | |||
| raise NotImplementedError | |||
| def download(self): | |||
| r"""Downloads the dataset to the :obj:`self.raw_dir` folder.""" | |||
| raise NotImplementedError | |||
| def process(self): | |||
| r"""Processes the dataset to the :obj:`self.processed_dir` folder.""" | |||
| raise NotImplementedError | |||
| def __len__(self): | |||
| r"""The number of examples in the dataset.""" | |||
| raise NotImplementedError | |||
| def get(self, idx): | |||
| r"""Gets the data object at index :obj:`idx`.""" | |||
| raise NotImplementedError | |||
| def __init__(self, root, transform=None, pre_transform=None, pre_filter=None): | |||
| super(Dataset, self).__init__() | |||
| self.root = osp.expanduser(osp.normpath(root)) | |||
| self.raw_dir = osp.join(self.root, "raw") | |||
| self.processed_dir = osp.join(self.root, "processed") | |||
| self.transform = transform | |||
| self.pre_transform = pre_transform | |||
| self.pre_filter = pre_filter | |||
| self._download() | |||
| self._process() | |||
| @property | |||
| def get_label_number(self): | |||
| r"""Get the number of labels in this dataset as dict.""" | |||
| label_num = {} | |||
| labels = self[0].y.unique().cpu().detach().numpy().tolist() | |||
| for label in labels: | |||
| label_num[label] = (self[0].y == label).sum().item() | |||
| return label_num | |||
| @property | |||
| def num_features(self): | |||
| r"""Returns the number of features per node in the graph.""" | |||
| return self[0].num_features | |||
| @property | |||
| def raw_paths(self): | |||
| r"""The filepaths to find in order to skip the download.""" | |||
| files = to_list(self.raw_file_names) | |||
| return [osp.join(self.raw_dir, f) for f in files] | |||
| @property | |||
| def processed_paths(self): | |||
| r"""The filepaths to find in the :obj:`self.processed_dir` | |||
| folder in order to skip the processing.""" | |||
| files = to_list(self.processed_file_names) | |||
| return [osp.join(self.processed_dir, f) for f in files] | |||
| def _download(self): | |||
| if files_exist(self.raw_paths): # pragma: no cover | |||
| return | |||
| makedirs(self.raw_dir) | |||
| self.download() | |||
| def _process(self): | |||
| if files_exist(self.processed_paths): # pragma: no cover | |||
| return | |||
| print("Processing...") | |||
| makedirs(self.processed_dir) | |||
| self.process() | |||
| print("Done!") | |||
| def __getitem__(self, idx): # pragma: no cover | |||
| r"""Gets the data object at index :obj:`idx` and transforms it (in case | |||
| a :obj:`self.transform` is given).""" | |||
| data = self.get(idx) | |||
| data = data if self.transform is None else self.transform(data) | |||
| return data | |||
| def __repr__(self): # pragma: no cover | |||
| return "{}({})".format(self.__class__.__name__, len(self)) | |||
| @@ -0,0 +1,4 @@ | |||
| from ._general_static_graph import ( | |||
| GeneralStaticGraph, GeneralStaticGraphGenerator | |||
| ) | |||
| from . import utils | |||
| @@ -0,0 +1,2 @@ | |||
| from ._general_static_graph import GeneralStaticGraph | |||
| from ._general_static_graph_generator import GeneralStaticGraphGenerator | |||
| @@ -0,0 +1,162 @@ | |||
| import torch | |||
| import typing as _typing | |||
| from . import _canonical_edge_type | |||
| class SpecificTypedNodeDataView(_typing.MutableMapping[str, torch.Tensor]): | |||
| def __getitem__(self, data_key: str) -> torch.Tensor: | |||
| raise NotImplementedError | |||
| def __setitem__(self, data_key: str, value: torch.Tensor): | |||
| raise NotImplementedError | |||
| def __delitem__(self, data_key: str) -> None: | |||
| raise NotImplementedError | |||
| def __len__(self) -> int: | |||
| raise NotImplementedError | |||
| def __iter__(self) -> _typing.Iterator[str]: | |||
| raise NotImplementedError | |||
| class SpecificTypedNodeView: | |||
| @property | |||
| def data(self) -> SpecificTypedNodeDataView: | |||
| raise NotImplementedError | |||
| @data.setter | |||
| def data(self, nodes_data: _typing.Mapping[str, torch.Tensor]): | |||
| raise NotImplementedError | |||
| class HeterogeneousNodeView(_typing.Iterable[str]): | |||
| @property | |||
| def data(self) -> SpecificTypedNodeDataView: | |||
| raise NotImplementedError | |||
| @data.setter | |||
| def data(self, nodes_data: _typing.Mapping[str, torch.Tensor]): | |||
| raise NotImplementedError | |||
| def __getitem__(self, node_type: _typing.Optional[str]) -> SpecificTypedNodeView: | |||
| raise NotImplementedError | |||
| def __setitem__( | |||
| self, node_t: _typing.Optional[str], | |||
| nodes_data: _typing.Mapping[str, torch.Tensor] | |||
| ): | |||
| raise NotImplementedError | |||
| def __delitem__(self, node_t: _typing.Optional[str]): | |||
| raise NotImplementedError | |||
| def __iter__(self) -> _typing.Iterator[str]: | |||
| raise NotImplementedError | |||
| @property | |||
| def is_homogeneous(self) -> bool: | |||
| raise NotImplementedError | |||
| class HomogeneousEdgesDataView(_typing.MutableMapping[str, torch.Tensor]): | |||
| def __getitem__(self, data_key: str) -> torch.Tensor: | |||
| raise NotImplementedError | |||
| def __setitem__(self, data_key: str, value: torch.Tensor): | |||
| raise NotImplementedError | |||
| def __delitem__(self, data_key: str): | |||
| raise NotImplementedError | |||
| def __len__(self) -> int: | |||
| raise NotImplementedError | |||
| def __iter__(self) -> _typing.Iterator[str]: | |||
| raise NotImplementedError | |||
| class HomogeneousEdgesView: | |||
| @property | |||
| def connections(self) -> torch.LongTensor: | |||
| raise NotImplementedError | |||
| @property | |||
| def data(self) -> HomogeneousEdgesDataView: | |||
| raise NotImplementedError | |||
| class HeterogeneousEdgesView(_typing.Collection[_canonical_edge_type.CanonicalEdgeType]): | |||
| @property | |||
| def connections(self) -> torch.LongTensor: | |||
| raise NotImplementedError | |||
| @property | |||
| def data(self) -> HomogeneousEdgesDataView: | |||
| raise NotImplementedError | |||
| @property | |||
| def is_homogeneous(self) -> bool: | |||
| raise NotImplementedError | |||
| def set( | |||
| self, edge_t: _typing.Union[None, str, _typing.Tuple[str, str, str]], | |||
| connections: torch.LongTensor, data: _typing.Optional[_typing.Mapping[str, torch.Tensor]] = ... | |||
| ): | |||
| raise NotImplementedError | |||
| def __getitem__( | |||
| self, | |||
| edge_t: _typing.Union[ | |||
| None, str, _typing.Tuple[str, str, str], _canonical_edge_type.CanonicalEdgeType | |||
| ] | |||
| ) -> HomogeneousEdgesView: | |||
| raise NotImplementedError | |||
| def __setitem__( | |||
| self, | |||
| edge_t: _typing.Union[ | |||
| None, str, _typing.Tuple[str, str, str], _canonical_edge_type.CanonicalEdgeType | |||
| ], | |||
| edges: torch.LongTensor | |||
| ): | |||
| raise NotImplementedError | |||
| def __delitem__( | |||
| self, | |||
| edge_t: _typing.Union[ | |||
| None, str, _typing.Tuple[str, str, str], _canonical_edge_type.CanonicalEdgeType | |||
| ] | |||
| ): | |||
| raise NotImplementedError | |||
| def __len__(self) -> int: | |||
| raise NotImplementedError | |||
| def __iter__(self) -> _typing.Iterator[_canonical_edge_type.CanonicalEdgeType]: | |||
| raise NotImplementedError | |||
| def __contains__( | |||
| self, | |||
| edge_type: _typing.Union[ | |||
| str, _typing.Tuple[str, str, str], _canonical_edge_type.CanonicalEdgeType | |||
| ] | |||
| ) -> bool: | |||
| raise NotImplementedError | |||
| class GraphDataView(_typing.MutableMapping[str, torch.Tensor]): | |||
| def __setitem__(self, data_key: str, data: torch.Tensor) -> None: | |||
| raise NotImplementedError | |||
| def __delitem__(self, data_key: str) -> None: | |||
| raise NotImplementedError | |||
| def __getitem__(self, data_key: str) -> torch.Tensor: | |||
| raise NotImplementedError | |||
| def __len__(self) -> int: | |||
| raise NotImplementedError | |||
| def __iter__(self) -> _typing.Iterator[str]: | |||
| raise NotImplementedError | |||
| @@ -0,0 +1,56 @@ | |||
| import typing as _typing | |||
| class CanonicalEdgeType(_typing.Sequence[str]): | |||
| def __init__(self, source_node_type: str, relation_type: str, target_node_type: str): | |||
| if not isinstance(source_node_type, str): | |||
| raise TypeError | |||
| elif ' ' in source_node_type: | |||
| raise ValueError | |||
| if not isinstance(relation_type, str): | |||
| raise TypeError | |||
| elif ' ' in relation_type: | |||
| raise ValueError | |||
| if not isinstance(target_node_type, str): | |||
| raise TypeError | |||
| elif ' ' in target_node_type: | |||
| raise ValueError | |||
| self.__source_node_type: str = source_node_type | |||
| self.__relation_type: str = relation_type | |||
| self.__destination_node_type: str = target_node_type | |||
| @property | |||
| def source_node_type(self) -> str: | |||
| return self.__source_node_type | |||
| @property | |||
| def relation_type(self) -> str: | |||
| return self.__relation_type | |||
| @property | |||
| def target_node_type(self) -> str: | |||
| return self.__destination_node_type | |||
| def __eq__(self, other): | |||
| # Check CanonicalEdgeType first: it is itself a Sequence, so the generic | |||
| # Sequence branch below would otherwise shadow this one. | |||
| if isinstance(other, CanonicalEdgeType): | |||
| return ( | |||
| other.source_node_type == self.source_node_type and | |||
| other.relation_type == self.relation_type and | |||
| other.target_node_type == self.target_node_type | |||
| ) | |||
| elif isinstance(other, _typing.Sequence): | |||
| if not (len(other) == 3 and all((isinstance(t, str) and ' ' not in t) for t in other)): | |||
| return False | |||
| return ( | |||
| other[0] == self.source_node_type and | |||
| other[1] == self.relation_type and | |||
| other[2] == self.target_node_type | |||
| ) | |||
| else: | |||
| return False | |||
| def __hash__(self) -> int: | |||
| # Defining __eq__ suppresses the inherited __hash__; restore hashability. | |||
| return hash((self.source_node_type, self.relation_type, self.target_node_type)) | |||
| def __getitem__(self, index: int): | |||
| return (self.source_node_type, self.relation_type, self.target_node_type)[index] | |||
| def __len__(self) -> int: | |||
| return 3 | |||
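| # A quick sanity sketch of the equality semantics (hypothetical values): | |||
| #     et = CanonicalEdgeType('paper', 'cites', 'paper') | |||
| #     assert et == ('paper', 'cites', 'paper') | |||
| #     assert tuple(et) == ('paper', 'cites', 'paper') and len(et) == 3 | |||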
| @@ -0,0 +1,15 @@ | |||
| from . import _abstract_views | |||
| class GeneralStaticGraph: | |||
| @property | |||
| def nodes(self) -> _abstract_views.HeterogeneousNodeView: | |||
| raise NotImplementedError | |||
| @property | |||
| def edges(self) -> _abstract_views.HeterogeneousEdgesView: | |||
| raise NotImplementedError | |||
| @property | |||
| def data(self) -> _abstract_views.GraphDataView: | |||
| raise NotImplementedError | |||
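| # Intended access pattern through the three views (a sketch, assuming a | |||
| # concrete subclass such as the GeneralStaticGraphImplementation added in | |||
| # this PR; the keys 'x' and 'edge_weight' are hypothetical): | |||
| #     graph.nodes.data['x'] = torch.zeros(10, 16)             # node features | |||
| #     graph.edges.set(None, connections, {'edge_weight': w})  # [2, E] LongTensor | |||
| #     graph.data['graph_label'] = torch.tensor([1])           # graph-level data | |||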
| @@ -0,0 +1,940 @@ | |||
| import pandas as pd | |||
| import torch | |||
| import typing as _typing | |||
| from . import ( | |||
| _abstract_views, | |||
| _canonical_edge_type, | |||
| _general_static_graph | |||
| ) | |||
| class HeterogeneousNodesContainer: | |||
| @property | |||
| def node_types(self) -> _typing.AbstractSet[str]: | |||
| raise NotImplementedError | |||
| def remove_nodes(self, node_t: _typing.Optional[str]) -> 'HeterogeneousNodesContainer': | |||
| raise NotImplementedError | |||
| def reset_nodes( | |||
| self, node_t: _typing.Optional[str], | |||
| nodes_data: _typing.Mapping[str, torch.Tensor] | |||
| ) -> 'HeterogeneousNodesContainer': | |||
| raise NotImplementedError | |||
| def set_data( | |||
| self, node_t: _typing.Optional[str], data_key: str, data: torch.Tensor | |||
| ) -> 'HeterogeneousNodesContainer': | |||
| raise NotImplementedError | |||
| def get_data( | |||
| self, node_t: _typing.Optional[str] = ..., | |||
| data_key: _typing.Optional[str] = ... | |||
| ) -> _typing.Union[torch.Tensor, _typing.Mapping[str, torch.Tensor]]: | |||
| raise NotImplementedError | |||
| def delete_data( | |||
| self, node_t: _typing.Optional[str], data_key: str | |||
| ) -> 'HeterogeneousNodesContainer': | |||
| raise NotImplementedError | |||
| def remove_data( | |||
| self, node_t: _typing.Optional[str], data_key: str | |||
| ) -> 'HeterogeneousNodesContainer': | |||
| return self.delete_data(node_t, data_key) | |||
| class HeterogeneousNodesContainerImplementation(HeterogeneousNodesContainer): | |||
| def __init__(self, data: _typing.Optional[_typing.Mapping[str, _typing.Mapping[str, torch.Tensor]]] = ...): | |||
| self.__nodes_data: _typing.MutableMapping[str, _typing.MutableMapping[str, torch.Tensor]] = {} | |||
| if data not in (None, Ellipsis) and isinstance(data, _typing.Mapping): | |||
| for node_t, nodes_data in data.items(): | |||
| self.reset_nodes(node_t, nodes_data) | |||
| @property | |||
| def node_types(self) -> _typing.AbstractSet[str]: | |||
| return self.__nodes_data.keys() | |||
| def remove_nodes(self, node_t: _typing.Optional[str]) -> HeterogeneousNodesContainer: | |||
| if not (node_t in (Ellipsis, None) or isinstance(node_t, str)): | |||
| raise TypeError | |||
| elif node_t in (Ellipsis, None): | |||
| if len(self.node_types) == 0: | |||
| return self | |||
| elif len(self.node_types) == 1: | |||
| del self.__nodes_data[tuple(self.node_types)[0]] | |||
| else: | |||
| _error_message: str = ' '.join(( | |||
| "Unable to determine node type automatically,", | |||
| "possible cause is that the graph contains heterogeneous nodes,", | |||
| "node type must be specified for graph containing heterogeneous nodes." | |||
| )) | |||
| raise TypeError(_error_message) | |||
| elif isinstance(node_t, str): | |||
| try: | |||
| del self.__nodes_data[node_t] | |||
| except KeyError: | |||
| raise ValueError(f"nodes with type [{node_t}] do NOT exist") | |||
| return self | |||
| def reset_nodes( | |||
| self, node_t: _typing.Optional[str], | |||
| nodes_data: _typing.Mapping[str, torch.Tensor] | |||
| ) -> HeterogeneousNodesContainer: | |||
| if not (node_t in (Ellipsis, None) or isinstance(node_t, str)): | |||
| raise TypeError | |||
| elif node_t in (Ellipsis, None) and len(self.node_types) > 1: | |||
| _error_message: str = ' '.join(( | |||
| "Unable to determine node type automatically,", | |||
| "possible cause is that the graph contains heterogeneous nodes,", | |||
| "node type must be specified for graph containing heterogeneous nodes." | |||
| )) | |||
| raise TypeError(_error_message) | |||
| elif isinstance(node_t, str) and ' ' in node_t: | |||
| raise ValueError("node type must NOT contain space character (\' \').") | |||
| __node_t: str = "" if node_t is Ellipsis else node_t | |||
| num_nodes: int = ... | |||
| for data_key, data_item in nodes_data.items(): | |||
| if not isinstance(data_key, str): | |||
| raise TypeError | |||
| if ' ' in data_key: | |||
| raise ValueError("data key must NOT contain space character (\' \').") | |||
| if not isinstance(data_item, torch.Tensor): | |||
| raise TypeError | |||
| if not data_item.dim() > 0: | |||
| raise ValueError( | |||
| "data item MUST have at least one dimension, " | |||
| "and the first dimension corresponds to data for diverse nodes." | |||
| ) | |||
| if not isinstance(num_nodes, int): | |||
| num_nodes: int = data_item.size(0) | |||
| if data_item.size(0) != num_nodes: | |||
| raise ValueError | |||
| self.__nodes_data[__node_t] = dict(nodes_data) | |||
| return self | |||
| def set_data( | |||
| self, node_t: _typing.Optional[str], data_key: str, data: torch.Tensor | |||
| ) -> HeterogeneousNodesContainer: | |||
| if node_t in (Ellipsis, None): | |||
| if len(self.node_types) == 0: | |||
| __node_t: str = "" # Default node type for homogeneous graph | |||
| elif len(self.node_types) == 1: | |||
| __node_t: str = list(self.node_types)[0] | |||
| else: | |||
| _error_message: str = ' '.join(( | |||
| "Unable to determine node type automatically,", | |||
| "possible cause is that the graph contains heterogeneous nodes,", | |||
| "node type must be specified for graph containing heterogeneous nodes." | |||
| )) | |||
| raise TypeError(_error_message) | |||
| elif isinstance(node_t, str): | |||
| __node_t: str = node_t | |||
| else: | |||
| raise TypeError | |||
| if not isinstance(data_key, str): | |||
| raise TypeError | |||
| if not isinstance(data, torch.Tensor): | |||
| raise TypeError | |||
| if ' ' in __node_t: | |||
| raise ValueError | |||
| if ' ' in data_key: | |||
| raise ValueError | |||
| if not data.dim() > 0: | |||
| raise ValueError( | |||
| "data item MUST have at least one dimension, " | |||
| "and the first dimension corresponds to data for diverse nodes." | |||
| ) | |||
| if __node_t not in self.node_types: | |||
| self.__nodes_data[__node_t] = dict([(data_key, data)]) | |||
| else: | |||
| obsolete_data: _typing.Optional[torch.Tensor] = self.__nodes_data[__node_t].get(data_key) | |||
| if obsolete_data is not None and isinstance(obsolete_data, torch.Tensor): | |||
| if data.size(0) != obsolete_data.size(0): | |||
| raise ValueError | |||
| elif len(self.__nodes_data.get(__node_t)) > 0: | |||
| num_nodes: int = self.__nodes_data[__node_t][list(self.__nodes_data[__node_t].keys())[0]].size(0) | |||
| if data.size(0) != num_nodes: | |||
| raise ValueError | |||
| self.__nodes_data[__node_t][data_key] = data | |||
| return self | |||
| def __get_data_for_specific_node_type( | |||
| self, node_t: str, data_key: _typing.Optional[str] = ... | |||
| ) -> _typing.Union[torch.Tensor, _typing.Mapping[str, torch.Tensor]]: | |||
| if not isinstance(node_t, str): | |||
| raise TypeError | |||
| elif ' ' in node_t: | |||
| raise ValueError | |||
| if not (data_key in (Ellipsis, None) or isinstance(data_key, str)): | |||
| raise TypeError | |||
| elif isinstance(data_key, str) and ' ' in data_key: | |||
| raise ValueError | |||
| if node_t not in self.node_types: | |||
| raise ValueError("Node type NOT exists") | |||
| elif isinstance(data_key, str): | |||
| data: _typing.Optional[torch.Tensor] = self.__nodes_data[node_t].get(data_key) | |||
| if data is not None: | |||
| return data | |||
| else: | |||
| raise KeyError( | |||
| f"Data with key [{data_key}] NOT exists " | |||
| f"for nodes with specific type [{node_t}]" | |||
| ) | |||
| else: | |||
| return self.__nodes_data[node_t] | |||
| def __get_data_for_specific_data_key( | |||
| self, data_key: str, node_t: _typing.Optional[str] = ... | |||
| ) -> _typing.Union[torch.Tensor, _typing.Mapping[str, torch.Tensor]]: | |||
| if not isinstance(data_key, str): | |||
| raise TypeError | |||
| elif ' ' in data_key: | |||
| raise ValueError | |||
| if not (node_t in (Ellipsis, None) or isinstance(node_t, str)): | |||
| raise TypeError | |||
| elif isinstance(node_t, str) and ' ' in node_t: | |||
| raise ValueError | |||
| if isinstance(node_t, str): | |||
| if node_t not in self.node_types: | |||
| raise ValueError("Node type NOT exists") | |||
| else: | |||
| data: _typing.Optional[torch.Tensor] = ( | |||
| self.__nodes_data[node_t].get(data_key) | |||
| ) | |||
| if data is not None: | |||
| return data | |||
| else: | |||
| raise KeyError( | |||
| f"Data with key [{data_key}] NOT exists " | |||
| f"for nodes with specific type [{node_t}]" | |||
| ) | |||
| else: | |||
| if len(self.node_types) == 0: | |||
| raise RuntimeError("Unable to get data from empty graph") | |||
| elif len(self.node_types) == 1: | |||
| __node_t: str = tuple(self.node_types)[0] | |||
| __optional_data: _typing.Optional[torch.Tensor] = ( | |||
| self.__nodes_data[__node_t].get(data_key) | |||
| ) | |||
| if __optional_data is not None: | |||
| return __optional_data | |||
| else: | |||
| raise KeyError(f"Data with key [{data_key}] NOT exists") | |||
| else: | |||
| __result: _typing.Dict[str, torch.Tensor] = {} | |||
| for __node_t, __nodes_data in self.__nodes_data.items(): | |||
| __optional_data: _typing.Optional[torch.Tensor] = ( | |||
| __nodes_data.get(data_key) | |||
| ) | |||
| if ( | |||
| __optional_data is not None and | |||
| isinstance(__optional_data, torch.Tensor) | |||
| ): | |||
| __result[__node_t] = __optional_data | |||
| if len(__result): | |||
| return __result | |||
| else: | |||
| raise KeyError(f"Data with key [{data_key}] NOT exists") | |||
| def get_data( | |||
| self, node_t: _typing.Optional[str] = ..., | |||
| data_key: _typing.Optional[str] = ... | |||
| ) -> _typing.Union[torch.Tensor, _typing.Mapping[str, torch.Tensor]]: | |||
| if not (node_t in (Ellipsis, None) or isinstance(node_t, str)): | |||
| raise TypeError | |||
| elif isinstance(node_t, str) and ' ' in node_t: | |||
| raise ValueError | |||
| if not (data_key in (Ellipsis, None) or isinstance(data_key, str)): | |||
| raise TypeError | |||
| elif isinstance(data_key, str) and ' ' in data_key: | |||
| raise ValueError | |||
| if isinstance(node_t, str): | |||
| return self.__get_data_for_specific_node_type(node_t, data_key) | |||
| elif node_t in (Ellipsis, None) and isinstance(data_key, str): | |||
| return self.__get_data_for_specific_data_key(data_key) | |||
| elif node_t in (Ellipsis, None) and data_key in (Ellipsis, None): | |||
| if len(self.node_types) == 1: | |||
| __node_t: str = tuple(self.node_types)[0] | |||
| return self.__get_data_for_specific_node_type(__node_t) | |||
| else: | |||
| raise TypeError( | |||
| "Unable to determine node type automatically, " | |||
| "possible cause is that the graph contains heterogeneous nodes or is empty, " | |||
| "node type must be specified for graph containing heterogeneous nodes." | |||
| ) | |||
| def delete_data( | |||
| self, node_t: _typing.Optional[str], data_key: str | |||
| ) -> HeterogeneousNodesContainer: | |||
| if not (node_t in (Ellipsis, None) or isinstance(node_t, str)): | |||
| raise TypeError | |||
| elif node_t in (Ellipsis, None): | |||
| if len(self.node_types) == 1: | |||
| __node_t: str = tuple(self.node_types)[0] | |||
| else: | |||
| raise TypeError( | |||
| "Unable to determine node type automatically, " | |||
| "possible cause is that the graph contains heterogeneous nodes or is empty, " | |||
| "node type must be specified for graph containing heterogeneous nodes." | |||
| ) | |||
| elif isinstance(node_t, str): | |||
| if node_t in self.node_types: | |||
| __node_t: str = node_t | |||
| else: | |||
| raise ValueError("node type NOT exists") | |||
| else: | |||
| raise TypeError | |||
| if not isinstance(data_key, str): | |||
| raise TypeError | |||
| elif data_key not in self.__nodes_data.get(__node_t): | |||
| raise KeyError( | |||
| f"Data with key [{data_key}] NOT exists for nodes with type [{__node_t}]" | |||
| ) | |||
| else: | |||
| self.__nodes_data[__node_t].__delitem__(data_key) | |||
| if len(self.__nodes_data.get(__node_t)) == 0: | |||
| del self.__nodes_data[__node_t] | |||
| return self | |||
| class _SpecificTypedNodeDataView(_abstract_views.SpecificTypedNodeDataView): | |||
| def __init__( | |||
| self, heterogeneous_nodes_container: HeterogeneousNodesContainer, | |||
| node_type: _typing.Optional[str] | |||
| ): | |||
| if not isinstance(heterogeneous_nodes_container, HeterogeneousNodesContainer): | |||
| raise TypeError | |||
| else: | |||
| self._heterogeneous_nodes_container: HeterogeneousNodesContainer = ( | |||
| heterogeneous_nodes_container | |||
| ) | |||
| if not (isinstance(node_type, str) or node_type in (Ellipsis, None)): | |||
| raise TypeError | |||
| elif isinstance(node_type, str): | |||
| if node_type not in self._heterogeneous_nodes_container.node_types: | |||
| raise ValueError("Invalid node type") | |||
| self.__node_t: _typing.Optional[str] = node_type | |||
| def __getitem__(self, data_key: str) -> torch.Tensor: | |||
| return self._heterogeneous_nodes_container.get_data(self.__node_t, data_key) | |||
| def __setitem__(self, data_key: str, value: torch.Tensor): | |||
| self._heterogeneous_nodes_container.set_data(self.__node_t, data_key, value) | |||
| def __delitem__(self, data_key: str) -> None: | |||
| self._heterogeneous_nodes_container.delete_data(self.__node_t, data_key) | |||
| def __len__(self) -> int: | |||
| return len(self._heterogeneous_nodes_container.get_data(self.__node_t)) | |||
| def __iter__(self) -> _typing.Iterator[str]: | |||
| return iter(self._heterogeneous_nodes_container.get_data(self.__node_t)) | |||
| class _SpecificTypedNodeView(_abstract_views.SpecificTypedNodeView): | |||
| def __init__( | |||
| self, nodes_container: HeterogeneousNodesContainer, | |||
| node_t: _typing.Optional[str] | |||
| ): | |||
| self._heterogeneous_nodes_container: HeterogeneousNodesContainer = nodes_container | |||
| self.__node_t: _typing.Optional[str] = node_t | |||
| @property | |||
| def data(self) -> _SpecificTypedNodeDataView: | |||
| return _SpecificTypedNodeDataView(self._heterogeneous_nodes_container, self.__node_t) | |||
| @data.setter | |||
| def data(self, nodes_data: _typing.Mapping[str, torch.Tensor]): | |||
| self._heterogeneous_nodes_container.reset_nodes(self.__node_t, nodes_data) | |||
| class _HeterogeneousNodeView(_abstract_views.HeterogeneousNodeView): | |||
| def __init__(self, nodes_container: HeterogeneousNodesContainer): | |||
| self._heterogeneous_nodes_container: HeterogeneousNodesContainer = nodes_container | |||
| def __getitem__(self, node_type: _typing.Optional[str]) -> _SpecificTypedNodeView: | |||
| return _SpecificTypedNodeView(self._heterogeneous_nodes_container, node_type) | |||
| def __setitem__( | |||
| self, node_t: _typing.Optional[str], | |||
| nodes_data: _typing.Mapping[str, torch.Tensor] | |||
| ) -> None: | |||
| self._heterogeneous_nodes_container.reset_nodes(node_t, nodes_data) | |||
| def __delitem__(self, node_t: _typing.Optional[str]): | |||
| self._heterogeneous_nodes_container.remove_nodes(node_t) | |||
| def __iter__(self) -> _typing.Iterator[str]: | |||
| return iter(self._heterogeneous_nodes_container.node_types) | |||
| @property | |||
| def data(self) -> _SpecificTypedNodeDataView: | |||
| return _SpecificTypedNodeDataView(self._heterogeneous_nodes_container, ...) | |||
| @data.setter | |||
| def data(self, nodes_data: _typing.Mapping[str, torch.Tensor]): | |||
| self._heterogeneous_nodes_container.reset_nodes(..., nodes_data) | |||
| @property | |||
| def is_homogeneous(self) -> bool: | |||
| return len(self._heterogeneous_nodes_container.node_types) <= 1 | |||
| class HomogeneousEdgesContainer: | |||
| @property | |||
| def connections(self) -> torch.Tensor: | |||
| raise NotImplementedError | |||
| @property | |||
| def data_keys(self) -> _typing.Iterable[str]: | |||
| raise NotImplementedError | |||
| def get_data( | |||
| self, data_key: _typing.Optional[str] = ... | |||
| ) -> _typing.Union[torch.Tensor, _typing.Mapping[str, torch.Tensor]]: | |||
| raise NotImplementedError | |||
| def set_data(self, data_key: str, data: torch.Tensor): | |||
| raise NotImplementedError | |||
| def delete_data(self, data_key: str): | |||
| raise NotImplementedError | |||
| class HomogeneousEdgesContainerImplementation(HomogeneousEdgesContainer): | |||
| def __init__( | |||
| self, edge_connections: torch.Tensor, | |||
| data: _typing.Optional[_typing.Mapping[str, torch.Tensor]] = ... | |||
| ): | |||
| if not isinstance(edge_connections, torch.Tensor): | |||
| raise TypeError | |||
| if not (data in (Ellipsis, None) or isinstance(data, _typing.Mapping)): | |||
| raise TypeError | |||
| if not ( | |||
| edge_connections.dtype == torch.int64 and | |||
| edge_connections.dim() == 2 and edge_connections.size(0) == 2 | |||
| ): | |||
| raise ValueError("edge connections must be a [2, num_edges] LongTensor") | |||
| self.__connections: torch.Tensor = edge_connections | |||
| if not isinstance(data, _typing.Mapping): | |||
| self.__data: _typing.MutableMapping[str, torch.Tensor] = {} | |||
| else: | |||
| for data_key, data_item in data.items(): | |||
| if not isinstance(data_key, str): | |||
| raise TypeError | |||
| if not isinstance(data_item, torch.Tensor): | |||
| raise TypeError | |||
| if ' ' in data_key: | |||
| raise ValueError | |||
| if not data_item.dim() > 0: | |||
| raise ValueError | |||
| if data_item.size(0) != self.__connections.size(1): | |||
| raise ValueError | |||
| self.__data: _typing.MutableMapping[str, torch.Tensor] = dict(data) | |||
| @property | |||
| def connections(self) -> torch.Tensor: | |||
| return self.__connections | |||
| @property | |||
| def data_keys(self) -> _typing.Iterable[str]: | |||
| return self.__data.keys() | |||
| def set_data(self, data_key: str, data: torch.Tensor) -> HomogeneousEdgesContainer: | |||
| if not isinstance(data_key, str): | |||
| raise TypeError | |||
| if not isinstance(data, torch.Tensor): | |||
| raise TypeError | |||
| if ' ' in data_key: | |||
| raise ValueError | |||
| if data.dim() == 0 or data.size(0) != self.__connections.size(1): | |||
| raise ValueError | |||
| self.__data[data_key] = data | |||
| return self | |||
| def get_data( | |||
| self, data_key: _typing.Optional[str] = ... | |||
| ) -> _typing.Union[torch.Tensor, _typing.Mapping[str, torch.Tensor]]: | |||
| if not (data_key in (Ellipsis, None) or isinstance(data_key, str)): | |||
| raise TypeError | |||
| if isinstance(data_key, str): | |||
| if ' ' in data_key: | |||
| raise ValueError | |||
| temp: _typing.Optional[torch.Tensor] = self.__data.get(data_key) | |||
| if temp is None: | |||
| raise KeyError(f"Data with key [{data_key}] NOT exists") | |||
| else: | |||
| return temp | |||
| else: | |||
| return dict(self.__data) | |||
| def delete_data(self, data_key: str) -> HomogeneousEdgesContainer: | |||
| if not isinstance(data_key, str): | |||
| raise TypeError | |||
| if ' ' in data_key: | |||
| raise ValueError | |||
| try: | |||
| del self.__data[data_key] | |||
| except KeyError: | |||
| pass  # deleting an absent key is a no-op | |||
| return self | |||
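| # Illustrative usage (hypothetical values): the container pairs a fixed | |||
| # [2, num_edges] connection tensor with per-edge tensors keyed by name: | |||
| #     edges = HomogeneousEdgesContainerImplementation( | |||
| #         torch.tensor([[0, 1], [1, 2]]), {'w': torch.ones(2)} | |||
| #     ) | |||
| #     edges.set_data('mask', torch.zeros(2, dtype=torch.bool)) | |||
| #     edges.get_data('w')  # -> tensor([1., 1.]) | |||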
| class HeterogeneousEdgesAggregation( | |||
| _typing.MutableMapping[ | |||
| _typing.Union[str, _typing.Tuple[str, str, str], _canonical_edge_type.CanonicalEdgeType], | |||
| HomogeneousEdgesContainer | |||
| ] | |||
| ): | |||
| def __setitem__( | |||
| self, edge_t: _typing.Union[None, str, _typing.Tuple[str, str, str], _canonical_edge_type.CanonicalEdgeType], | |||
| edges: _typing.Union[HomogeneousEdgesContainer, torch.LongTensor] | |||
| ) -> None: | |||
| self._set_edges(edge_t, edges) | |||
| def __delitem__( | |||
| self, edge_t: _typing.Union[None, str, _typing.Tuple[str, str, str], _canonical_edge_type.CanonicalEdgeType] | |||
| ) -> None: | |||
| self._delete_edges(edge_t) | |||
| def __getitem__( | |||
| self, edge_t: _typing.Union[None, str, _typing.Tuple[str, str, str], _canonical_edge_type.CanonicalEdgeType] = ... | |||
| ) -> HomogeneousEdgesContainer: | |||
| return self._get_edges(edge_t) | |||
| def __len__(self) -> int: | |||
| return len(list(self._edge_types)) | |||
| def __iter__(self) -> _typing.Iterator[_canonical_edge_type.CanonicalEdgeType]: | |||
| return iter(self._edge_types) | |||
| @property | |||
| def _edge_types(self) -> _typing.Iterable[_canonical_edge_type.CanonicalEdgeType]: | |||
| raise NotImplementedError | |||
| def _get_edges( | |||
| self, edge_t: _typing.Union[None, str, _typing.Tuple[str, str, str], _canonical_edge_type.CanonicalEdgeType] = ... | |||
| ) -> HomogeneousEdgesContainer: | |||
| raise NotImplementedError | |||
| def _set_edges( | |||
| self, edge_t: _typing.Union[None, str, _typing.Tuple[str, str, str], _canonical_edge_type.CanonicalEdgeType], | |||
| edges: _typing.Union[HomogeneousEdgesContainer, torch.LongTensor] | |||
| ): | |||
| raise NotImplementedError | |||
| def _delete_edges( | |||
| self, edge_t: _typing.Union[None, str, _typing.Tuple[str, str, str], _canonical_edge_type.CanonicalEdgeType] | |||
| ) -> None: | |||
| raise NotImplementedError | |||
| class HeterogeneousEdgesAggregationImplementation(HeterogeneousEdgesAggregation): | |||
| def __init__(self): | |||
| self.__heterogeneous_edges_data_frame: pd.DataFrame = pd.DataFrame( | |||
| columns=('s', 'r', 't', 'edges'), | |||
| ) | |||
| @property | |||
| def _edge_types(self) -> _typing.Iterable[_canonical_edge_type.CanonicalEdgeType]: | |||
| return [ | |||
| _canonical_edge_type.CanonicalEdgeType(getattr(row_tuple, 's'), getattr(row_tuple, 'r'), getattr(row_tuple, 't')) | |||
| for row_tuple in self.__heterogeneous_edges_data_frame.itertuples(False, name="Edge") | |||
| ] | |||
| def _get_edges( | |||
| self, edge_t: _typing.Union[None, str, _typing.Tuple[str, str, str], _canonical_edge_type.CanonicalEdgeType] = ... | |||
| ) -> HomogeneousEdgesContainer: | |||
| if edge_t in (Ellipsis, None): | |||
| if len(self.__heterogeneous_edges_data_frame) == 1: | |||
| return self.__heterogeneous_edges_data_frame.iloc[0]['edges'] | |||
| else: | |||
| raise RuntimeError("Unable to determine edge type for a graph with zero or multiple edge types") | |||
| elif isinstance(edge_t, str): | |||
| if ' ' in edge_t: | |||
| raise ValueError | |||
| if len( | |||
| self.__heterogeneous_edges_data_frame.loc[ | |||
| self.__heterogeneous_edges_data_frame['r'] == edge_t | |||
| ] | |||
| ) != 1: | |||
| raise ValueError(f"relation type [{edge_t}] does not identify a unique edge type") | |||
| else: | |||
| # ``.loc[mask, 'edges']`` yields a Series; take its single element. | |||
| temp: HomogeneousEdgesContainer = self.__heterogeneous_edges_data_frame.loc[ | |||
| self.__heterogeneous_edges_data_frame['r'] == edge_t, 'edges' | |||
| ].iloc[0] | |||
| if not isinstance(temp, HomogeneousEdgesContainer): | |||
| raise RuntimeError | |||
| else: | |||
| return temp | |||
| elif isinstance(edge_t, _typing.Tuple) or isinstance(edge_t, _canonical_edge_type.CanonicalEdgeType): | |||
| if isinstance(edge_t, _typing.Tuple) and not ( | |||
| len(edge_t) == 3 and | |||
| isinstance(edge_t[0], str) and | |||
| isinstance(edge_t[1], str) and | |||
| isinstance(edge_t[2], str) and | |||
| ' ' not in edge_t[0] and ' ' not in edge_t[1] and ' ' not in edge_t[2] | |||
| ): | |||
| raise TypeError("Illegal canonical edge type") | |||
| __edge_t: _typing.Tuple[str, str, str] = ( | |||
| (edge_t.source_node_type, edge_t.relation_type, edge_t.target_node_type) | |||
| if isinstance(edge_t, _canonical_edge_type.CanonicalEdgeType) else edge_t | |||
| ) | |||
| partial_data_frame: pd.DataFrame = self.__heterogeneous_edges_data_frame.loc[ | |||
| (self.__heterogeneous_edges_data_frame['s'] == __edge_t[0]) & | |||
| (self.__heterogeneous_edges_data_frame['r'] == __edge_t[1]) & | |||
| (self.__heterogeneous_edges_data_frame['t'] == __edge_t[2]) | |||
| ] | |||
| if len(partial_data_frame) == 0: | |||
| raise ValueError | |||
| elif len(partial_data_frame) == 1: | |||
| temp: HomogeneousEdgesContainer = partial_data_frame.iloc[0]['edges'] | |||
| if not isinstance(temp, HomogeneousEdgesContainer): | |||
| raise RuntimeError | |||
| else: | |||
| return temp | |||
| else: | |||
| raise RuntimeError | |||
| def _set_edges( | |||
| self, | |||
| edge_t: _typing.Union[None, str, _typing.Tuple[str, str, str], _canonical_edge_type.CanonicalEdgeType], | |||
| edges: _typing.Union[HomogeneousEdgesContainer, torch.LongTensor] | |||
| ): | |||
| if not (isinstance(edges, HomogeneousEdgesContainer) or isinstance(edges, torch.Tensor)): | |||
| raise TypeError | |||
| if edge_t in (Ellipsis, None): | |||
| if len(self.__heterogeneous_edges_data_frame) == 0: | |||
| # ``DataFrame.append`` was removed in pandas 2.0; build the row with ``pd.concat``. | |||
| self.__heterogeneous_edges_data_frame: pd.DataFrame = pd.concat( | |||
| ( | |||
| self.__heterogeneous_edges_data_frame, | |||
| pd.DataFrame( | |||
| { | |||
| 's': [''], 'r': [''], 't': [''], | |||
| 'edges': [ | |||
| edges if isinstance(edges, HomogeneousEdgesContainer) | |||
| else HomogeneousEdgesContainerImplementation(edges) | |||
| ] | |||
| } | |||
| ) | |||
| ), | |||
| ignore_index=True | |||
| ) | |||
| elif len(self.__heterogeneous_edges_data_frame) == 1: | |||
| # Assign through ``iat`` to avoid chained-indexing assignment. | |||
| self.__heterogeneous_edges_data_frame.iat[ | |||
| 0, self.__heterogeneous_edges_data_frame.columns.get_loc('edges') | |||
| ] = ( | |||
| edges if isinstance(edges, HomogeneousEdgesContainer) | |||
| else HomogeneousEdgesContainerImplementation(edges) | |||
| ) | |||
| else: | |||
| raise RuntimeError("Edge type must be specified for a graph containing heterogeneous edges") | |||
| elif isinstance(edge_t, str): | |||
| if ' ' in edge_t: | |||
| raise ValueError | |||
| if len( | |||
| self.__heterogeneous_edges_data_frame.loc[ | |||
| self.__heterogeneous_edges_data_frame['r'] == edge_t | |||
| ] | |||
| ) == 1: | |||
| self.__heterogeneous_edges_data_frame.loc[ | |||
| self.__heterogeneous_edges_data_frame['r'] == edge_t, 'edges' | |||
| ] = ( | |||
| edges if isinstance(edges, HomogeneousEdgesContainer) | |||
| else HomogeneousEdgesContainerImplementation(edges) | |||
| ) | |||
| else: | |||
| raise RuntimeError(f"relation type [{edge_t}] does not identify a unique existing edge type") | |||
| elif isinstance(edge_t, _typing.Tuple) or isinstance(edge_t, _canonical_edge_type.CanonicalEdgeType): | |||
| if isinstance(edge_t, _typing.Tuple) and not ( | |||
| len(edge_t) == 3 and | |||
| isinstance(edge_t[0], str) and | |||
| isinstance(edge_t[1], str) and | |||
| isinstance(edge_t[2], str) and | |||
| ' ' not in edge_t[0] and ' ' not in edge_t[1] and ' ' not in edge_t[2] | |||
| ): | |||
| raise TypeError("Illegal canonical edge type") | |||
| __edge_t: _typing.Tuple[str, str, str] = ( | |||
| (edge_t.source_node_type, edge_t.relation_type, edge_t.target_node_type) | |||
| if isinstance(edge_t, _canonical_edge_type.CanonicalEdgeType) else edge_t | |||
| ) | |||
| if len( | |||
| self.__heterogeneous_edges_data_frame.loc[ | |||
| (self.__heterogeneous_edges_data_frame['s'] == __edge_t[0]) & | |||
| (self.__heterogeneous_edges_data_frame['r'] == __edge_t[1]) & | |||
| (self.__heterogeneous_edges_data_frame['t'] == __edge_t[2]) | |||
| ] | |||
| ) == 0: | |||
| self.__heterogeneous_edges_data_frame: pd.DataFrame = pd.concat( | |||
| ( | |||
| self.__heterogeneous_edges_data_frame, | |||
| pd.DataFrame( | |||
| { | |||
| 's': [__edge_t[0]], | |||
| 'r': [__edge_t[1]], | |||
| 't': [__edge_t[2]], | |||
| 'edges': [ | |||
| edges if isinstance(edges, HomogeneousEdgesContainer) | |||
| else HomogeneousEdgesContainerImplementation(edges) | |||
| ] | |||
| } | |||
| ) | |||
| ), | |||
| ignore_index=True | |||
| ) | |||
| elif len( | |||
| self.__heterogeneous_edges_data_frame.loc[ | |||
| (self.__heterogeneous_edges_data_frame['s'] == __edge_t[0]) & | |||
| (self.__heterogeneous_edges_data_frame['r'] == __edge_t[1]) & | |||
| (self.__heterogeneous_edges_data_frame['t'] == __edge_t[2]) | |||
| ] | |||
| ) == 1: | |||
| self.__heterogeneous_edges_data_frame.loc[ | |||
| (self.__heterogeneous_edges_data_frame['s'] == __edge_t[0]) & | |||
| (self.__heterogeneous_edges_data_frame['r'] == __edge_t[1]) & | |||
| (self.__heterogeneous_edges_data_frame['t'] == __edge_t[2]), | |||
| 'edges' | |||
| ] = ( | |||
| edges if isinstance(edges, HomogeneousEdgesContainer) | |||
| else HomogeneousEdgesContainerImplementation(edges) | |||
| ) | |||
| else: | |||
| raise RuntimeError("Multiple edge types match the given canonical edge type") | |||
| else: | |||
| raise TypeError(f"Illegal edge type: {repr(edge_t)}") | |||
| def _delete_edges( | |||
| self, edge_t: _typing.Union[None, str, _typing.Tuple[str, str, str], _canonical_edge_type.CanonicalEdgeType] = ... | |||
| ) -> None: | |||
| if edge_t in (Ellipsis, None): | |||
| if len(self.__heterogeneous_edges_data_frame) == 1: | |||
| self.__heterogeneous_edges_data_frame.drop( | |||
| self.__heterogeneous_edges_data_frame.index[0], inplace=True | |||
| ) | |||
| elif len(self.__heterogeneous_edges_data_frame) > 1: | |||
| raise ValueError("Edge type must be specified for a graph containing heterogeneous edges") | |||
| return | |||
| raise NotImplementedError  # todo: deletion by an explicit edge type is not implemented yet | |||
| class _HomogeneousEdgesDataView(_abstract_views.HomogeneousEdgesDataView): | |||
| def __init__(self, homogeneous_edges_container: HomogeneousEdgesContainer): | |||
| if not isinstance(homogeneous_edges_container, HomogeneousEdgesContainer): | |||
| raise TypeError | |||
| self._homogeneous_edges_container: HomogeneousEdgesContainer = homogeneous_edges_container | |||
| def __getitem__(self, data_key: str) -> torch.Tensor: | |||
| if not isinstance(data_key, str): | |||
| raise TypeError | |||
| if ' ' in data_key: | |||
| raise ValueError | |||
| return self._homogeneous_edges_container.get_data(data_key) | |||
| def __setitem__(self, data_key: str, data: torch.Tensor): | |||
| if not isinstance(data_key, str): | |||
| raise TypeError | |||
| elif ' ' in data_key: | |||
| raise ValueError | |||
| if not isinstance(data, torch.Tensor): | |||
| raise TypeError | |||
| elif not data.dim() > 0: | |||
| raise ValueError | |||
| self._homogeneous_edges_container.set_data(data_key, data) | |||
| def __delitem__(self, data_key: str): | |||
| if not isinstance(data_key, str): | |||
| raise TypeError | |||
| elif ' ' in data_key: | |||
| raise ValueError | |||
| self._homogeneous_edges_container.delete_data(data_key) | |||
| def __len__(self): | |||
| return len(list(self._homogeneous_edges_container.data_keys)) | |||
| def __iter__(self) -> _typing.Iterator[str]: | |||
| return iter(self._homogeneous_edges_container.data_keys) | |||
| class _SpecificTypedHomogeneousEdgesView(_abstract_views.HomogeneousEdgesView): | |||
| def __init__(self, homogeneous_edges_container: HomogeneousEdgesContainer): | |||
| if not isinstance(homogeneous_edges_container, HomogeneousEdgesContainer): | |||
| raise TypeError | |||
| self._homogeneous_edges_container: HomogeneousEdgesContainer = homogeneous_edges_container | |||
| @property | |||
| def connections(self) -> torch.Tensor: | |||
| return self._homogeneous_edges_container.connections | |||
| @property | |||
| def data(self) -> _HomogeneousEdgesDataView: | |||
| return _HomogeneousEdgesDataView(self._homogeneous_edges_container) | |||
| class _HeterogeneousEdgesView(_abstract_views.HeterogeneousEdgesView): | |||
| def __init__(self, _heterogeneous_edges_aggregation: HeterogeneousEdgesAggregation): | |||
| if not isinstance(_heterogeneous_edges_aggregation, HeterogeneousEdgesAggregation): | |||
| raise TypeError | |||
| self._heterogeneous_edges_aggregation: HeterogeneousEdgesAggregation = ( | |||
| _heterogeneous_edges_aggregation | |||
| ) | |||
| def __getitem__( | |||
| self, edge_t: _typing.Union[None, str, _typing.Tuple[str, str, str], _canonical_edge_type.CanonicalEdgeType] | |||
| ) -> _SpecificTypedHomogeneousEdgesView: | |||
| return _SpecificTypedHomogeneousEdgesView(self._heterogeneous_edges_aggregation[edge_t]) | |||
| def __setitem__( | |||
| self, edge_t: _typing.Union[None, str, _typing.Tuple[str, str, str], _canonical_edge_type.CanonicalEdgeType], | |||
| edges: _typing.Union[HomogeneousEdgesContainer, torch.LongTensor] | |||
| ): | |||
| self._heterogeneous_edges_aggregation[edge_t] = edges | |||
| def __delitem__( | |||
| self, edge_t: _typing.Union[None, str, _typing.Tuple[str, str, str], _canonical_edge_type.CanonicalEdgeType] | |||
| ): | |||
| del self._heterogeneous_edges_aggregation[edge_t] | |||
| def __len__(self) -> int: | |||
| return len(self._heterogeneous_edges_aggregation) | |||
| def __iter__(self) -> _typing.Iterator[_canonical_edge_type.CanonicalEdgeType]: | |||
| return iter(self._heterogeneous_edges_aggregation) | |||
| def __contains__(self, edge_type: _typing.Union[str, _typing.Tuple[str, str, str], _canonical_edge_type.CanonicalEdgeType]) -> bool: | |||
| if isinstance(edge_type, str): | |||
| if ' ' in edge_type: | |||
| raise ValueError | |||
| else: | |||
| for existing_edge_type in self: | |||
| if existing_edge_type.relation_type == edge_type: | |||
| return True | |||
| return False | |||
| elif isinstance(edge_type, _typing.Tuple): | |||
| if not ( | |||
| len(edge_type) == 3 and | |||
| all([(isinstance(t, str) and ' ' not in t) for t in edge_type]) | |||
| ): | |||
| raise TypeError | |||
| else: | |||
| for existing_edge_type in self: | |||
| if existing_edge_type == edge_type: | |||
| return True | |||
| return False | |||
| elif isinstance(edge_type, _canonical_edge_type.CanonicalEdgeType): | |||
| for existing_edge_type in self: | |||
| if existing_edge_type == edge_type: | |||
| return True | |||
| return False | |||
| else: | |||
| raise TypeError | |||
| @property | |||
| def connections(self) -> torch.Tensor: | |||
| return self[...].connections | |||
| @property | |||
| def data(self) -> _HomogeneousEdgesDataView: | |||
| return self[...].data | |||
| @property | |||
| def is_homogeneous(self) -> bool: | |||
| return len(self) <= 1 | |||
| def set( | |||
| self, edge_t: _typing.Union[None, str, _typing.Tuple[str, str, str]], | |||
| connections: torch.LongTensor, data: _typing.Optional[_typing.Mapping[str, torch.Tensor]] = ... | |||
| ): | |||
| self[edge_t] = HomogeneousEdgesContainerImplementation(connections, data) | |||
| class _StaticGraphDataContainer(_typing.MutableMapping[str, torch.Tensor]): | |||
| def __setitem__(self, data_key: str, data: torch.Tensor) -> None: | |||
| raise NotImplementedError | |||
| def __delitem__(self, data_key: str) -> None: | |||
| raise NotImplementedError | |||
| def __getitem__(self, data_key: str) -> torch.Tensor: | |||
| raise NotImplementedError | |||
| def __len__(self) -> int: | |||
| raise NotImplementedError | |||
| def __iter__(self) -> _typing.Iterator[str]: | |||
| raise NotImplementedError | |||
| class StaticGraphDataAggregation(_StaticGraphDataContainer): | |||
| def __init__( | |||
| self, graph_data: _typing.Optional[_typing.Mapping[str, torch.Tensor]] = ... | |||
| ): | |||
| self.__data: _typing.MutableMapping[str, torch.Tensor] = ( | |||
| dict(graph_data) if isinstance(graph_data, _typing.Mapping) | |||
| else {} | |||
| ) | |||
| def __setitem__(self, data_key: str, data: torch.Tensor) -> None: | |||
| self.__data[data_key] = data | |||
| def __delitem__(self, data_key: str) -> None: | |||
| del self.__data[data_key] | |||
| def __getitem__(self, data_key: str) -> torch.Tensor: | |||
| return self.__data[data_key] | |||
| def __len__(self) -> int: | |||
| return len(self.__data) | |||
| def __iter__(self) -> _typing.Iterator[str]: | |||
| return iter(self.__data) | |||
| class _StaticGraphDataView(_abstract_views.GraphDataView): | |||
| def __init__(self, graph_data_container: _StaticGraphDataContainer): | |||
| self.__graph_data_container: _StaticGraphDataContainer = ( | |||
| graph_data_container | |||
| ) | |||
| def __setitem__(self, data_key: str, data: torch.Tensor) -> None: | |||
| self.__graph_data_container[data_key] = data | |||
| def __delitem__(self, data_key: str) -> None: | |||
| del self.__graph_data_container[data_key] | |||
| def __getitem__(self, data_key: str) -> torch.Tensor: | |||
| return self.__graph_data_container[data_key] | |||
| def __len__(self) -> int: | |||
| return len(self.__graph_data_container) | |||
| def __iter__(self) -> _typing.Iterator[str]: | |||
| return iter(self.__graph_data_container) | |||
| class GeneralStaticGraphImplementation(_general_static_graph.GeneralStaticGraph): | |||
| def __init__( | |||
| self, _heterogeneous_nodes_container: _typing.Optional[HeterogeneousNodesContainer] = ..., | |||
| _heterogeneous_edges_aggregation: _typing.Optional[HeterogeneousEdgesAggregation] = ..., | |||
| graph_data_container: _typing.Optional[_StaticGraphDataContainer] = ... | |||
| ): | |||
| self._static_graph_data_container: _StaticGraphDataContainer = ( | |||
| graph_data_container | |||
| if isinstance(graph_data_container, _StaticGraphDataContainer) | |||
| else StaticGraphDataAggregation() | |||
| ) | |||
| self._heterogeneous_nodes_container: HeterogeneousNodesContainer = ( | |||
| _heterogeneous_nodes_container | |||
| if isinstance(_heterogeneous_nodes_container, HeterogeneousNodesContainer) | |||
| else HeterogeneousNodesContainerImplementation() | |||
| ) | |||
| self._heterogeneous_edges_aggregation: HeterogeneousEdgesAggregation = ( | |||
| _heterogeneous_edges_aggregation | |||
| if isinstance(_heterogeneous_edges_aggregation, HeterogeneousEdgesAggregation) | |||
| else HeterogeneousEdgesAggregationImplementation() | |||
| ) | |||
| @property | |||
| def nodes(self) -> _HeterogeneousNodeView: | |||
| return _HeterogeneousNodeView(self._heterogeneous_nodes_container) | |||
| @property | |||
| def edges(self) -> _HeterogeneousEdgesView: | |||
| return _HeterogeneousEdgesView(self._heterogeneous_edges_aggregation) | |||
| @property | |||
| def data(self) -> _StaticGraphDataView: | |||
| return _StaticGraphDataView(self._static_graph_data_container) | |||
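# A minimal usage sketch of the default implementation (illustrative only;
# the edge connections and data key below are toy values, and the containers
# are assumed to resolve the default edge type as elsewhere in this module):
if __name__ == "__main__":
    _graph = GeneralStaticGraphImplementation()
    _graph.edges.set(None, torch.tensor([[0, 1], [1, 2]]))
    _graph.data['graph_label'] = torch.tensor([1])
    print(_graph.edges.connections)      # 2 x num_edges connection tensor
    print(_graph.edges.is_homogeneous)   # True: a single edge type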
| @@ -0,0 +1,651 @@ | |||
| import dgl | |||
| import torch | |||
| import typing as _typing | |||
| from . import ( | |||
| _abstract_views, | |||
| _canonical_edge_type, | |||
| _general_static_graph | |||
| ) | |||
| class _DGLGraphHolder: | |||
| def __init__(self, dgl_graph: dgl.DGLGraph): | |||
| if not isinstance(dgl_graph, dgl.DGLGraph): | |||
| raise TypeError | |||
| self.__graph: dgl.DGLGraph = dgl_graph | |||
| @property | |||
| def graph(self) -> dgl.DGLGraph: | |||
| return self.__graph | |||
| @graph.setter | |||
| def graph(self, dgl_graph: dgl.DGLGraph): | |||
| if not isinstance(dgl_graph, dgl.DGLGraph): | |||
| raise TypeError | |||
| else: | |||
| self.__graph = dgl_graph | |||
| class _SpecificTypedNodeDataView(_abstract_views.SpecificTypedNodeDataView): | |||
| def __init__( | |||
| self, dgl_graph_holder: _DGLGraphHolder, | |||
| node_type: _typing.Optional[str] = ... | |||
| ): | |||
| if not isinstance(dgl_graph_holder, _DGLGraphHolder): | |||
| raise TypeError | |||
| if not (node_type in (Ellipsis, None) or isinstance(node_type, str)): | |||
| raise TypeError | |||
| elif isinstance(node_type, str) and ' ' in node_type: | |||
| raise ValueError("Illegal node type") | |||
| self.__dgl_graph_holder: _DGLGraphHolder = dgl_graph_holder | |||
| self.__optional_node_type: _typing.Optional[str] = ( | |||
| node_type if isinstance(node_type, str) else None | |||
| ) | |||
def __get_node_type(self) -> str:
    # Resolve the effective node type, falling back to the unique type
    # of a graph with homogeneous nodes.
    if isinstance(self.__optional_node_type, str):
        return self.__optional_node_type
    if len(self.__dgl_graph_holder.graph.ntypes) == 0:
        raise ValueError("the graph is empty")
    elif len(self.__dgl_graph_holder.graph.ntypes) > 1:
        raise ValueError(
            "Unable to automatically determine node type, "
            "the graph consists of heterogeneous node types"
        )
    else:
        return self.__dgl_graph_holder.graph.ntypes[0]
def __getitem__(self, data_key: str) -> torch.Tensor:
    if not isinstance(data_key, str):
        raise TypeError
    elif ' ' in data_key:
        raise ValueError("Illegal data key")
    node_type: str = self.__get_node_type()
| if data_key in self.__dgl_graph_holder.graph.nodes[node_type].data: | |||
| return self.__dgl_graph_holder.graph.nodes[node_type].data[data_key] | |||
| else: | |||
raise KeyError(f"Node data with key \"{data_key}\" not found")
| def __setitem__(self, data_key: str, value: torch.Tensor): | |||
| if not isinstance(data_key, str): | |||
| raise TypeError | |||
| elif ' ' in data_key: | |||
| raise ValueError("Illegal data key") | |||
| if not isinstance(value, torch.Tensor): | |||
| raise TypeError | |||
| elif value.dim() == 0: | |||
| raise ValueError | |||
node_type: str = self.__get_node_type()
| if value.size(0) != self.__dgl_graph_holder.graph.num_nodes(node_type): | |||
raise ValueError("Size of the first dimension of value must equal the number of nodes of the corresponding type")
| else: | |||
# todo: this method currently does not handle the case where node_type does not exist in the graph
| self.__dgl_graph_holder.graph.nodes[node_type].data[data_key] = value | |||
| def __delitem__(self, data_key: str) -> None: | |||
| if not isinstance(data_key, str): | |||
| raise TypeError | |||
| elif ' ' in data_key: | |||
| raise ValueError("Illegal data key") | |||
node_type: str = self.__get_node_type()
| if data_key in self.__dgl_graph_holder.graph.nodes[node_type].data: | |||
| try: | |||
| del self.__dgl_graph_holder.graph.nodes[node_type].data[data_key] | |||
| except KeyError: | |||
| pass # todo: Use logger to warn | |||
| def __len__(self) -> int: | |||
node_type: str = self.__get_node_type()
| return len(self.__dgl_graph_holder.graph.nodes[node_type].data) | |||
| def __iter__(self) -> _typing.Iterator[str]: | |||
node_type: str = self.__get_node_type()
| return iter(self.__dgl_graph_holder.graph.nodes[node_type].data) | |||
| class _SpecificTypedNodeView(_abstract_views.SpecificTypedNodeView): | |||
| def __init__( | |||
| self, dgl_graph_holder: _DGLGraphHolder, | |||
| node_type: _typing.Optional[str] = ... | |||
| ): | |||
| if not isinstance(dgl_graph_holder, _DGLGraphHolder): | |||
| raise TypeError | |||
| if not (node_type in (Ellipsis, None) or isinstance(node_type, str)): | |||
| raise TypeError | |||
| elif isinstance(node_type, str) and ' ' in node_type: | |||
| raise ValueError("Illegal node type") | |||
| self.__dgl_graph_holder: _DGLGraphHolder = dgl_graph_holder | |||
| self.__optional_node_type: _typing.Optional[str] = ( | |||
| node_type if isinstance(node_type, str) else None | |||
| ) | |||
| @property | |||
| def data(self) -> _SpecificTypedNodeDataView: | |||
| return _SpecificTypedNodeDataView( | |||
| self.__dgl_graph_holder, self.__optional_node_type | |||
| ) | |||
| @data.setter | |||
| def data(self, nodes_data: _typing.Mapping[str, torch.Tensor]): | |||
raise NotImplementedError # todo: DGL does not currently support this operation
| class _HeterogeneousNodeView(_abstract_views.HeterogeneousNodeView): | |||
| def __init__(self, dgl_graph_holder: _DGLGraphHolder): | |||
| if not isinstance(dgl_graph_holder, _DGLGraphHolder): | |||
| raise TypeError | |||
| self.__dgl_graph_holder: _DGLGraphHolder = dgl_graph_holder | |||
| @property | |||
| def data(self) -> _SpecificTypedNodeDataView: | |||
| return _SpecificTypedNodeDataView(self.__dgl_graph_holder, ...) | |||
| @data.setter | |||
| def data(self, nodes_data: _typing.Mapping[str, torch.Tensor]): | |||
| if not isinstance(nodes_data, _typing.Mapping): | |||
| raise TypeError | |||
| _SpecificTypedNodeView(self.__dgl_graph_holder, ...).data = nodes_data | |||
| def __getitem__(self, node_type: _typing.Optional[str]) -> _SpecificTypedNodeView: | |||
| if not (node_type in (Ellipsis, None) or isinstance(node_type, str)): | |||
| raise TypeError | |||
| elif isinstance(node_type, str) and ' ' in node_type: | |||
| raise ValueError("Illegal edge type") | |||
| return _SpecificTypedNodeView(self.__dgl_graph_holder, node_type) | |||
| def __setitem__( | |||
| self, node_type: _typing.Optional[str], | |||
| nodes_data: _typing.Mapping[str, torch.Tensor] | |||
| ): | |||
| if not (node_type in (Ellipsis, None) or isinstance(node_type, str)): | |||
| raise TypeError | |||
| elif isinstance(node_type, str) and ' ' in node_type: | |||
| raise ValueError("Illegal edge type") | |||
| if not isinstance(nodes_data, _typing.Mapping): | |||
| raise TypeError | |||
| _SpecificTypedNodeView( | |||
| self.__dgl_graph_holder, node_type if isinstance(node_type, str) else None | |||
| ).data = nodes_data | |||
| def __delitem__(self, node_t: _typing.Optional[str]): | |||
raise NotImplementedError # todo: DGL does not currently support this operation
| def __iter__(self) -> _typing.Iterator[str]: | |||
| return iter(self.__dgl_graph_holder.graph.ntypes) | |||
| @property | |||
| def is_homogeneous(self) -> bool: | |||
| return len(self.__dgl_graph_holder.graph.ntypes) <= 1 | |||
| class _HomogeneousEdgesDataView(_abstract_views.HomogeneousEdgesDataView): | |||
| def __init__( | |||
| self, dgl_graph_holder: _DGLGraphHolder, | |||
| edge_type: _typing.Union[ | |||
| None, str, _typing.Tuple[str, str, str], | |||
| _canonical_edge_type.CanonicalEdgeType | |||
| ] = ... | |||
| ): | |||
| if not isinstance(dgl_graph_holder, _DGLGraphHolder): | |||
| raise TypeError | |||
| self.__dgl_graph_holder: _DGLGraphHolder = dgl_graph_holder | |||
| if edge_type in (Ellipsis, None): | |||
| self.__optional_edge_type: _typing.Union[None, str, _typing.Tuple[str, str, str]] = None | |||
| elif isinstance(edge_type, str): | |||
| if ' ' in edge_type: | |||
| raise ValueError("Illegal edge type") | |||
| self.__optional_edge_type: _typing.Union[None, str, _typing.Tuple[str, str, str]] = edge_type | |||
| elif isinstance(edge_type, _typing.Sequence) and not isinstance(edge_type, str): | |||
| if not ( | |||
| len(edge_type) == 3 and | |||
| isinstance(edge_type[0], str) and ' ' not in edge_type[0] and | |||
| isinstance(edge_type[1], str) and ' ' not in edge_type[1] and | |||
| isinstance(edge_type[2], str) and ' ' not in edge_type[2] | |||
| ): | |||
| raise ValueError("Illegal edge type") | |||
| self.__optional_edge_type: _typing.Union[None, str, _typing.Tuple[str, str, str]] = tuple(edge_type) | |||
| elif isinstance(edge_type, _canonical_edge_type.CanonicalEdgeType): | |||
| self.__optional_edge_type: _typing.Union[None, str, _typing.Tuple[str, str, str]] = ( | |||
| edge_type.source_node_type, edge_type.relation_type, edge_type.target_node_type | |||
| ) | |||
| else: | |||
| raise TypeError | |||
| def __get_canonical_edge_type(self) -> _typing.Tuple[str, str, str]: | |||
| if self.__optional_edge_type in (Ellipsis, None): | |||
| if len(self.__dgl_graph_holder.graph.canonical_etypes) == 0: | |||
| raise ValueError("The graph is empty") | |||
| elif len(self.__dgl_graph_holder.graph.canonical_etypes) > 1: | |||
| raise ValueError( | |||
| "Unable to automatically determine edge type, " | |||
| "the graph consists of heterogeneous edge types." | |||
| ) | |||
| else: | |||
| return self.__dgl_graph_holder.graph.canonical_etypes[0] | |||
| elif isinstance(self.__optional_edge_type, str): | |||
return self.__dgl_graph_holder.graph.to_canonical_etype(self.__optional_edge_type)
| else: | |||
| return self.__optional_edge_type | |||
def __verify_edge_type_exists(self, edge_type: _typing.Tuple[str, str, str]) -> None:
    # Ensure the resolved canonical edge type actually exists in the graph.
    for et in self.__dgl_graph_holder.graph.canonical_etypes:
        if all([a == b for a, b in zip(et, edge_type)]):
            return
    raise ValueError("edge type does not exist")
def __getitem__(self, data_key: str) -> torch.Tensor:
    if not isinstance(data_key, str):
        raise TypeError
    elif ' ' in data_key:
        raise ValueError("Illegal data key")
    edge_type: _typing.Tuple[str, str, str] = self.__get_canonical_edge_type()
    self.__verify_edge_type_exists(edge_type)
| if data_key in self.__dgl_graph_holder.graph.edges[edge_type].data: | |||
| return self.__dgl_graph_holder.graph.edges[edge_type].data[data_key] | |||
| else: | |||
raise KeyError(f"Edge data with key \"{data_key}\" not found")
| def __setitem__(self, data_key: str, value: torch.Tensor): | |||
| if not isinstance(data_key, str): | |||
| raise TypeError | |||
| elif ' ' in data_key: | |||
| raise ValueError("Illegal data key") | |||
| if not isinstance(value, torch.Tensor): | |||
| raise TypeError | |||
| if value.dim() == 0: | |||
| raise ValueError | |||
edge_type: _typing.Tuple[str, str, str] = self.__get_canonical_edge_type()
self.__verify_edge_type_exists(edge_type)
| self.__dgl_graph_holder.graph.edges[edge_type].data[data_key] = value | |||
| def __delitem__(self, data_key: str): | |||
| if not isinstance(data_key, str): | |||
| raise TypeError | |||
| elif ' ' in data_key: | |||
| raise ValueError("Illegal data key") | |||
edge_type: _typing.Tuple[str, str, str] = self.__get_canonical_edge_type()
self.__verify_edge_type_exists(edge_type)
| if data_key in self.__dgl_graph_holder.graph.edges[edge_type].data: | |||
| del self.__dgl_graph_holder.graph.edges[edge_type].data[data_key] | |||
| else: | |||
raise KeyError(f"Edge data with key \"{data_key}\" not found")
| def __len__(self) -> int: | |||
edge_type: _typing.Tuple[str, str, str] = self.__get_canonical_edge_type()
self.__verify_edge_type_exists(edge_type)
| return len(self.__dgl_graph_holder.graph.edges[edge_type].data) | |||
| def __iter__(self) -> _typing.Iterator[str]: | |||
edge_type: _typing.Tuple[str, str, str] = self.__get_canonical_edge_type()
self.__verify_edge_type_exists(edge_type)
| return iter(self.__dgl_graph_holder.graph.edges[edge_type].data) | |||
| class _HomogeneousEdgesView(_abstract_views.HomogeneousEdgesView): | |||
| def __init__( | |||
| self, dgl_graph_holder: _DGLGraphHolder, | |||
| edge_type: _typing.Union[ | |||
| None, str, _typing.Tuple[str, str, str], | |||
| _canonical_edge_type.CanonicalEdgeType | |||
| ] = ... | |||
| ): | |||
| if not isinstance(dgl_graph_holder, _DGLGraphHolder): | |||
| raise TypeError | |||
| self.__dgl_graph_holder: _DGLGraphHolder = dgl_graph_holder | |||
| if edge_type in (Ellipsis, None): | |||
| self.__optional_edge_type: _typing.Union[None, str, _typing.Tuple[str, str, str]] = None | |||
| elif isinstance(edge_type, str): | |||
| if ' ' in edge_type: | |||
| raise ValueError("Illegal edge type") | |||
| self.__optional_edge_type: _typing.Union[None, str, _typing.Tuple[str, str, str]] = edge_type | |||
| elif isinstance(edge_type, _typing.Sequence) and not isinstance(edge_type, str): | |||
| if not ( | |||
| len(edge_type) == 3 and | |||
| isinstance(edge_type[0], str) and ' ' not in edge_type[0] and | |||
| isinstance(edge_type[1], str) and ' ' not in edge_type[1] and | |||
| isinstance(edge_type[2], str) and ' ' not in edge_type[2] | |||
| ): | |||
| raise ValueError("Illegal edge type") | |||
| self.__optional_edge_type: _typing.Union[None, str, _typing.Tuple[str, str, str]] = tuple(edge_type) | |||
| elif isinstance(edge_type, _canonical_edge_type.CanonicalEdgeType): | |||
| self.__optional_edge_type: _typing.Union[None, str, _typing.Tuple[str, str, str]] = ( | |||
| edge_type.source_node_type, edge_type.relation_type, edge_type.target_node_type | |||
| ) | |||
| else: | |||
| raise TypeError | |||
| def __get_canonical_edge_type(self) -> _typing.Tuple[str, str, str]: | |||
| if self.__optional_edge_type in (Ellipsis, None): | |||
| if len(self.__dgl_graph_holder.graph.canonical_etypes) == 0: | |||
| raise ValueError("The graph is empty") | |||
| elif len(self.__dgl_graph_holder.graph.canonical_etypes) > 1: | |||
| raise ValueError( | |||
| "Unable to automatically determine edge type, " | |||
| "the graph consists of heterogeneous edge types." | |||
| ) | |||
| else: | |||
| return self.__dgl_graph_holder.graph.canonical_etypes[0] | |||
| elif isinstance(self.__optional_edge_type, str): | |||
return self.__dgl_graph_holder.graph.to_canonical_etype(self.__optional_edge_type)
| else: | |||
| return self.__optional_edge_type | |||
| @property | |||
| def connections(self) -> torch.Tensor: | |||
| return torch.vstack( | |||
| self.__dgl_graph_holder.graph.edges(etype=self.__get_canonical_edge_type()) | |||
| ) | |||
| @property | |||
| def data(self) -> _HomogeneousEdgesDataView: | |||
| return _HomogeneousEdgesDataView(self.__dgl_graph_holder, self.__optional_edge_type) | |||
| class _HeterogeneousEdgesView(_abstract_views.HeterogeneousEdgesView): | |||
| def __init__(self, dgl_graph_holder: _DGLGraphHolder): | |||
| if not isinstance(dgl_graph_holder, _DGLGraphHolder): | |||
| raise TypeError | |||
| self.__dgl_graph_holder: _DGLGraphHolder = dgl_graph_holder | |||
| self.__optional_edge_type: _typing.Union[None, str, _typing.Tuple[str, str, str]] = None | |||
| def __get_canonical_edge_type(self) -> _typing.Tuple[str, str, str]: | |||
| if self.__optional_edge_type in (Ellipsis, None): | |||
| if len(self.__dgl_graph_holder.graph.canonical_etypes) == 0: | |||
| raise ValueError("The graph is empty") | |||
| elif len(self.__dgl_graph_holder.graph.canonical_etypes) > 1: | |||
| raise ValueError( | |||
| "Unable to automatically determine edge type, " | |||
| "the graph consists of heterogeneous edge types." | |||
| ) | |||
| else: | |||
| return self.__dgl_graph_holder.graph.canonical_etypes[0] | |||
| elif isinstance(self.__optional_edge_type, str): | |||
return self.__dgl_graph_holder.graph.to_canonical_etype(self.__optional_edge_type)
| else: | |||
| return self.__optional_edge_type | |||
| @property | |||
| def connections(self) -> torch.Tensor: | |||
| return _HomogeneousEdgesView(self.__dgl_graph_holder, ...).connections | |||
| @property | |||
| def data(self) -> _HomogeneousEdgesDataView: | |||
| return _HomogeneousEdgesView(self.__dgl_graph_holder, ...).data | |||
| @property | |||
| def is_homogeneous(self) -> bool: | |||
| return len(self.__dgl_graph_holder.graph.canonical_etypes) <= 1 | |||
| def set( | |||
| self, edge_t: _typing.Union[None, str, _typing.Tuple[str, str, str]], | |||
| connections: torch.LongTensor, | |||
| data: _typing.Optional[_typing.Mapping[str, torch.Tensor]] = ... | |||
| ): | |||
raise NotImplementedError # todo: implement this operation for a DGL-backed graph
| def __getitem__( | |||
| self, | |||
| edge_t: _typing.Union[ | |||
| None, str, _typing.Tuple[str, str, str], _canonical_edge_type.CanonicalEdgeType | |||
| ] | |||
| ) -> _HomogeneousEdgesView: | |||
| return _HomogeneousEdgesView(self.__dgl_graph_holder, edge_t) | |||
| def __setitem__( | |||
| self, | |||
| edge_t: _typing.Union[ | |||
| None, str, _typing.Tuple[str, str, str], _canonical_edge_type.CanonicalEdgeType | |||
| ], | |||
edges: torch.LongTensor
| ): | |||
raise NotImplementedError # todo: implement this operation for a DGL-backed graph
| def __delitem__( | |||
| self, | |||
| edge_t: _typing.Union[ | |||
| None, str, _typing.Tuple[str, str, str], _canonical_edge_type.CanonicalEdgeType | |||
| ] | |||
| ): | |||
raise NotImplementedError # todo: implement this operation for a DGL-backed graph
| def __len__(self) -> int: | |||
| return len(self.__dgl_graph_holder.graph.canonical_etypes) | |||
| def __iter__(self) -> _typing.Iterator[_canonical_edge_type.CanonicalEdgeType]: | |||
| return iter([ | |||
| _canonical_edge_type.CanonicalEdgeType(et[0], et[1], et[2]) | |||
| for et in self.__dgl_graph_holder.graph.canonical_etypes | |||
| ]) | |||
| def __contains__( | |||
| self, | |||
| edge_type: _typing.Union[ | |||
| str, _typing.Tuple[str, str, str], _canonical_edge_type.CanonicalEdgeType | |||
| ] | |||
| ) -> bool: | |||
| if isinstance(edge_type, str): | |||
| if ' ' in edge_type: | |||
| raise ValueError("Illegal edge type") | |||
| else: | |||
| return edge_type in self.__dgl_graph_holder.graph.etypes | |||
| elif isinstance(edge_type, _typing.Sequence) and not isinstance(edge_type, str): | |||
| if not ( | |||
| len(edge_type) == 3 and | |||
| isinstance(edge_type[0], str) and ' ' not in edge_type[0] and | |||
| isinstance(edge_type[1], str) and ' ' not in edge_type[1] and | |||
| isinstance(edge_type[2], str) and ' ' not in edge_type[2] | |||
| ): | |||
| raise ValueError("Illegal edge type") | |||
| found = False | |||
| for et in self.__dgl_graph_holder.graph.canonical_etypes: | |||
| if all([a == b for a, b in zip(et, edge_type)]): | |||
| found = True | |||
| break | |||
| return found | |||
| elif isinstance(edge_type, _canonical_edge_type.CanonicalEdgeType): | |||
| found = False | |||
| for et in self.__dgl_graph_holder.graph.canonical_etypes: | |||
| if ( | |||
| et[0] == edge_type.source_node_type and | |||
| et[1] == edge_type.relation_type and | |||
| et[2] == edge_type.target_node_type | |||
| ): | |||
| found = True | |||
| break | |||
| return found | |||
| else: | |||
| raise TypeError | |||
| class _StaticGraphDataContainer(_typing.MutableMapping[str, torch.Tensor]): | |||
| def __setitem__(self, data_key: str, data: torch.Tensor) -> None: | |||
| raise NotImplementedError | |||
| def __delitem__(self, data_key: str) -> None: | |||
| raise NotImplementedError | |||
| def __getitem__(self, data_key: str) -> torch.Tensor: | |||
| raise NotImplementedError | |||
| def __len__(self) -> int: | |||
| raise NotImplementedError | |||
| def __iter__(self) -> _typing.Iterator[str]: | |||
| raise NotImplementedError | |||
| class StaticGraphDataAggregation(_StaticGraphDataContainer): | |||
| def __init__( | |||
| self, graph_data: _typing.Optional[_typing.Mapping[str, torch.Tensor]] = ... | |||
| ): | |||
| self.__data: _typing.MutableMapping[str, torch.Tensor] = ( | |||
| dict(graph_data) if isinstance(graph_data, _typing.Mapping) | |||
| else {} | |||
| ) | |||
| def __setitem__(self, data_key: str, data: torch.Tensor) -> None: | |||
| self.__data[data_key] = data | |||
| def __delitem__(self, data_key: str) -> None: | |||
| del self.__data[data_key] | |||
| def __getitem__(self, data_key: str) -> torch.Tensor: | |||
| return self.__data[data_key] | |||
| def __len__(self) -> int: | |||
| return len(self.__data) | |||
| def __iter__(self) -> _typing.Iterator[str]: | |||
| return iter(self.__data) | |||
| class _StaticGraphDataView(_abstract_views.GraphDataView): | |||
| def __init__(self, graph_data_container: _StaticGraphDataContainer): | |||
| self.__graph_data_container: _StaticGraphDataContainer = ( | |||
| graph_data_container | |||
| ) | |||
| def __setitem__(self, data_key: str, data: torch.Tensor) -> None: | |||
| self.__graph_data_container[data_key] = data | |||
| def __delitem__(self, data_key: str) -> None: | |||
| del self.__graph_data_container[data_key] | |||
| def __getitem__(self, data_key: str) -> torch.Tensor: | |||
| return self.__graph_data_container[data_key] | |||
| def __len__(self) -> int: | |||
| return len(self.__graph_data_container) | |||
| def __iter__(self) -> _typing.Iterator[str]: | |||
| return iter(self.__graph_data_container) | |||
| class GeneralStaticGraphDGLImplementation( | |||
| _general_static_graph.GeneralStaticGraph | |||
| ): | |||
| def __init__( | |||
| self, dgl_graph: dgl.DGLGraph, | |||
| graph_data: _typing.Optional[_typing.Mapping[str, torch.Tensor]] = ... | |||
| ): | |||
if not isinstance(dgl_graph, dgl.DGLGraph):
    raise TypeError
if not (
    graph_data in (Ellipsis, None) or
    isinstance(graph_data, _typing.Mapping)
):
    raise TypeError
| self.__dgl_graph_holder: _DGLGraphHolder = _DGLGraphHolder(dgl_graph) | |||
| self.__graph_data_container: _StaticGraphDataContainer = ( | |||
| StaticGraphDataAggregation( | |||
| graph_data if isinstance(graph_data, _typing.Mapping) else None | |||
| ) | |||
| ) | |||
| @property | |||
| def nodes(self) -> _abstract_views.HeterogeneousNodeView: | |||
| return _HeterogeneousNodeView(self.__dgl_graph_holder) | |||
| @property | |||
| def edges(self) -> _abstract_views.HeterogeneousEdgesView: | |||
| return _HeterogeneousEdgesView(self.__dgl_graph_holder) | |||
| @property | |||
| def data(self) -> _abstract_views.GraphDataView: | |||
| return _StaticGraphDataView(self.__graph_data_container) | |||
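# A small demonstration sketch (assumes a working DGL install; the toy graph
# and feature tensor below are illustrative, not part of the public API):
if __name__ == "__main__":
    _toy_graph = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))
    _toy_graph.ndata['feat'] = torch.randn(3, 4)
    _wrapped = GeneralStaticGraphDGLImplementation(_toy_graph)
    print(_wrapped.nodes.is_homogeneous)       # True: a single node type
    print(_wrapped.nodes.data['feat'].shape)   # torch.Size([3, 4])
    print(_wrapped.edges.connections)          # 2 x num_edges tensor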
| @@ -0,0 +1,80 @@ | |||
| import torch | |||
| import typing as _typing | |||
| from . import _general_static_graph | |||
| from ._general_static_graph_default_implementation import ( | |||
| HeterogeneousNodesContainer, HeterogeneousNodesContainerImplementation, | |||
| HomogeneousEdgesContainerImplementation, | |||
| HeterogeneousEdgesAggregation, HeterogeneousEdgesAggregationImplementation, | |||
| StaticGraphDataAggregation, GeneralStaticGraphImplementation | |||
| ) | |||
| class GeneralStaticGraphGenerator: | |||
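    """Factory with helpers for building `GeneralStaticGraph` instances from plain tensors."""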
| @classmethod | |||
| def create_heterogeneous_static_graph( | |||
| cls, heterogeneous_nodes_data: _typing.Mapping[str, _typing.Mapping[str, torch.Tensor]], | |||
| heterogeneous_edges: _typing.Mapping[ | |||
| _typing.Tuple[str, str, str], | |||
| _typing.Union[ | |||
| torch.Tensor, | |||
| _typing.Tuple[ | |||
| torch.Tensor, | |||
| _typing.Optional[_typing.Mapping[str, torch.Tensor]] | |||
| ] | |||
| ] | |||
| ], | |||
| graph_data: _typing.Optional[_typing.Mapping[str, torch.Tensor]] = ... | |||
| ) -> _general_static_graph.GeneralStaticGraph: | |||
| _heterogeneous_nodes_container: HeterogeneousNodesContainer = ( | |||
| HeterogeneousNodesContainerImplementation(heterogeneous_nodes_data) | |||
| ) | |||
| _heterogeneous_edges_aggregation: HeterogeneousEdgesAggregation = ( | |||
| HeterogeneousEdgesAggregationImplementation() | |||
| ) | |||
| for canonical_edge_type, specific_typed_edges in heterogeneous_edges.items(): | |||
| if isinstance(specific_typed_edges, torch.Tensor): | |||
| connections = specific_typed_edges | |||
| data = None | |||
| elif ( | |||
| isinstance(specific_typed_edges, _typing.Sequence) and | |||
| len(specific_typed_edges) == 2 and | |||
| isinstance(specific_typed_edges[0], torch.Tensor) and | |||
| ( | |||
| specific_typed_edges[1] is None or | |||
| isinstance(specific_typed_edges[1], _typing.Mapping) | |||
| ) | |||
| ): | |||
| connections = specific_typed_edges[0] | |||
| data = specific_typed_edges[1] | |||
| else: | |||
| raise TypeError | |||
| _heterogeneous_edges_aggregation[canonical_edge_type] = ( | |||
| HomogeneousEdgesContainerImplementation(connections, data) | |||
| ) | |||
| return GeneralStaticGraphImplementation( | |||
| _heterogeneous_nodes_container, | |||
| _heterogeneous_edges_aggregation, | |||
| StaticGraphDataAggregation(graph_data) | |||
| ) | |||
| @classmethod | |||
| def create_homogeneous_static_graph( | |||
| cls, nodes_data: _typing.Mapping[str, torch.Tensor], | |||
| edges_connections: torch.Tensor, | |||
| edges_data: _typing.Optional[_typing.Mapping[str, torch.Tensor]] = ..., | |||
| graph_data: _typing.Optional[_typing.Mapping[str, torch.Tensor]] = ... | |||
| ) -> _general_static_graph.GeneralStaticGraph: | |||
| _heterogeneous_nodes_container: HeterogeneousNodesContainer = ( | |||
| HeterogeneousNodesContainerImplementation({'': nodes_data}) | |||
| ) | |||
| _heterogeneous_edges_aggregation: HeterogeneousEdgesAggregation = ( | |||
| HeterogeneousEdgesAggregationImplementation() | |||
| ) | |||
| _heterogeneous_edges_aggregation[('', '', '')] = ( | |||
| HomogeneousEdgesContainerImplementation(edges_connections, edges_data) | |||
| ) | |||
| return GeneralStaticGraphImplementation( | |||
| _heterogeneous_nodes_container, | |||
| _heterogeneous_edges_aggregation, | |||
| StaticGraphDataAggregation(graph_data) | |||
| ) | |||
| @@ -0,0 +1,19 @@ | |||
| from ._nx import ( | |||
| HomogeneousStaticGraphToNetworkX | |||
| ) | |||
| try: | |||
| import dgl | |||
| except ModuleNotFoundError: | |||
| dgl = None | |||
| else: | |||
| from ._dgl import ( | |||
| DGLGraphToGeneralStaticGraph, dgl_graph_to_general_static_graph, | |||
| GeneralStaticGraphToDGLGraph, general_static_graph_to_dgl_graph, | |||
| ) | |||
| try: | |||
| import torch_geometric | |||
| except ModuleNotFoundError: | |||
| torch_geometric = None | |||
| else: | |||
| from ._pyg import StaticGraphToPyGData, static_graph_to_pyg_data | |||
| @@ -0,0 +1,136 @@ | |||
| import dgl | |||
| import torch | |||
| import typing as _typing | |||
| from ..._general_static_graph import GeneralStaticGraph | |||
| from ... import ( | |||
| _general_static_graph_generator, _general_static_graph_dgl_implementation | |||
| ) | |||
| class GeneralStaticGraphToDGLGraph: | |||
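    """
    Convert a `GeneralStaticGraph` into a `dgl.DGLGraph`: the heterograph is
    rebuilt from edge connections, node and edge data are copied over, and
    graph-level data is attached via `setattr` since DGLGraph has no
    dedicated graph-level storage.
    """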
| def __init__(self, *__args, **__kwargs): | |||
| pass | |||
| def __call__(self, static_graph: GeneralStaticGraph) -> dgl.DGLGraph: | |||
| dgl_graph: dgl.DGLGraph = dgl.heterograph( | |||
| dict([ | |||
| ( | |||
| ( | |||
| canonical_edge_type.source_node_type, | |||
| canonical_edge_type.relation_type, | |||
| canonical_edge_type.target_node_type | |||
| ), | |||
| ( | |||
| static_graph.edges[canonical_edge_type].connections[0], | |||
| static_graph.edges[canonical_edge_type].connections[1] | |||
| ) | |||
| ) | |||
| for canonical_edge_type in static_graph.edges | |||
| ]) | |||
| ) | |||
| for node_type in static_graph.nodes: | |||
| for data_key in static_graph.nodes[node_type].data: | |||
| dgl_graph.nodes[node_type].data[data_key] = ( | |||
| static_graph.nodes[node_type].data[data_key] | |||
| ) | |||
| for canonical_edge_type in static_graph.edges: | |||
| for data_key in static_graph.edges[canonical_edge_type].data: | |||
| dgl_graph.edges[ | |||
| ( | |||
| canonical_edge_type.source_node_type, | |||
| canonical_edge_type.relation_type, | |||
| canonical_edge_type.target_node_type | |||
| ) | |||
| ].data[data_key] = ( | |||
| static_graph.edges[canonical_edge_type].data[data_key] | |||
| ) | |||
# Attach graph-level data via setattr, as DGLGraph has no dedicated graph-level storage
| if len(static_graph.data) > 0: | |||
| setattr(dgl_graph, "graph_data", dict(static_graph.data)) | |||
| if "gf" in static_graph.data: | |||
| setattr(dgl_graph, "gf", static_graph.data["gf"].detach().clone()) | |||
| return dgl_graph | |||
| class DGLGraphToGeneralStaticGraph: | |||
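    """
    Convert a `dgl.DGLGraph` into a `GeneralStaticGraph`: by default a thin
    DGL-backed view is returned; with `as_universal_storage_format` the graph
    is copied into the backend-independent default implementation.
    """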
| def __init__( | |||
| self, as_universal_storage_format: bool = False, | |||
| *__args, **__kwargs | |||
| ): | |||
| if not isinstance(as_universal_storage_format, bool): | |||
| raise TypeError | |||
| else: | |||
| self._as_universal_storage_format: bool = as_universal_storage_format | |||
| def __call__( | |||
| self, dgl_graph: dgl.DGLGraph, | |||
| as_universal_storage_format: _typing.Optional[bool] = ..., | |||
| *__args, **__kwargs | |||
| ) -> GeneralStaticGraph: | |||
| if not ( | |||
| as_universal_storage_format in (Ellipsis, None) or | |||
| isinstance(as_universal_storage_format, bool) | |||
| ): | |||
| raise TypeError | |||
| _as_universal_storage_format: bool = ( | |||
| as_universal_storage_format | |||
| if isinstance(as_universal_storage_format, bool) | |||
| else self._as_universal_storage_format | |||
| ) | |||
| if not _as_universal_storage_format: | |||
| general_static_graph: GeneralStaticGraph = ( | |||
| _general_static_graph_dgl_implementation.GeneralStaticGraphDGLImplementation(dgl_graph) | |||
| ) | |||
| else: | |||
| general_static_graph: GeneralStaticGraph = ( | |||
| _general_static_graph_generator.GeneralStaticGraphGenerator.create_heterogeneous_static_graph( | |||
| dict([(node_type, dgl_graph.nodes[node_type].data) for node_type in dgl_graph.ntypes]), | |||
| dict([ | |||
| ( | |||
| canonical_edge_type, | |||
| ( | |||
| torch.vstack(dgl_graph.edges(etype=canonical_edge_type)), | |||
| dgl_graph.edges[canonical_edge_type].data | |||
| ) | |||
| ) | |||
| for canonical_edge_type in dgl_graph.canonical_etypes] | |||
| ) | |||
| ) | |||
| ) | |||
| if ( | |||
| hasattr(dgl_graph, "graph_data") and | |||
| isinstance(getattr(dgl_graph, "graph_data"), _typing.Mapping) | |||
| ): | |||
| graph_data: _typing.Mapping[str, torch.Tensor] = getattr(dgl_graph, "graph_data") | |||
| for k, v in graph_data.items(): | |||
| if ( | |||
| isinstance(k, str) and ' ' not in k and | |||
| isinstance(v, torch.Tensor) | |||
| ): | |||
| general_static_graph.data[k] = v | |||
| for k in ("gf",): | |||
| if ( | |||
| hasattr(dgl_graph, k) and | |||
| isinstance(getattr(dgl_graph, k), torch.Tensor) | |||
| ): | |||
| general_static_graph.data[k] = getattr(dgl_graph, k) | |||
| return general_static_graph | |||
| def general_static_graph_to_dgl_graph( | |||
| general_static_graph: GeneralStaticGraph, *__args, **__kwargs | |||
| ) -> dgl.DGLGraph: | |||
return GeneralStaticGraphToDGLGraph(*__args, **__kwargs)(general_static_graph)
| def dgl_graph_to_general_static_graph( | |||
| dgl_graph: dgl.DGLGraph, as_universal_storage_format: bool = False, | |||
| *__args, **__kwargs | |||
| ) -> GeneralStaticGraph: | |||
return DGLGraphToGeneralStaticGraph(as_universal_storage_format)(
    dgl_graph, as_universal_storage_format
)
| @@ -0,0 +1,74 @@ | |||
| import typing as _typing | |||
| import networkx as nx | |||
| from autogl.data.graph._general_static_graph import GeneralStaticGraph | |||
| class HomogeneousStaticGraphToNetworkX: | |||
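    """
    Convert a homogeneous `GeneralStaticGraph` into a `networkx.DiGraph`
    (or `networkx.Graph` when `to_undirected` is set), copying node and
    edge data onto the corresponding NetworkX attributes.
    """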
| def __init__( | |||
| self, remove_self_loops: bool = False, to_undirected: bool = False, | |||
| *__args, **__kwargs | |||
| ): | |||
| if not isinstance(remove_self_loops, bool): | |||
| raise TypeError | |||
| if not isinstance(to_undirected, bool): | |||
| raise TypeError | |||
| self.__remove_self_loops: bool = remove_self_loops | |||
| self.__to_undirected: bool = to_undirected | |||
| def __call__( | |||
| self, homogeneous_static_graph: GeneralStaticGraph, | |||
| remove_self_loops: _typing.Optional[bool] = ..., | |||
| to_undirected: _typing.Optional[bool] = ..., | |||
| *args, **kwargs | |||
| ): | |||
| if not isinstance(homogeneous_static_graph, GeneralStaticGraph): | |||
| raise TypeError | |||
| elif not ( | |||
| homogeneous_static_graph.nodes.is_homogeneous and | |||
| homogeneous_static_graph.edges.is_homogeneous | |||
| ): | |||
| raise ValueError("Only homogeneous static graph can be converted to NetworkX") | |||
| if not (remove_self_loops in (Ellipsis, None) or isinstance(remove_self_loops, bool)): | |||
| raise TypeError | |||
| else: | |||
| __remove_self_loops: bool = ( | |||
| remove_self_loops if isinstance(remove_self_loops, bool) | |||
| else self.__remove_self_loops | |||
| ) | |||
| if not (to_undirected in (Ellipsis, None) or isinstance(to_undirected, bool)): | |||
| raise TypeError | |||
| else: | |||
| __to_undirected: bool = ( | |||
| to_undirected if isinstance(to_undirected, bool) | |||
| else self.__to_undirected | |||
| ) | |||
| num_nodes: int = homogeneous_static_graph.edges.connections.max().item() + 1 | |||
# Note: this assumes node indices are contiguous from 0; isolated nodes with indices beyond the maximum connected index will be missed.
| g: nx.Graph = nx.Graph() if __to_undirected else nx.DiGraph() | |||
| g.add_nodes_from(range(num_nodes)) | |||
| nodes_data: _typing.MutableMapping[str, list] = {} | |||
| for data_key in homogeneous_static_graph.nodes.data: | |||
| nodes_data[data_key] = ( | |||
| homogeneous_static_graph.nodes.data[data_key].squeeze().tolist() | |||
| ) | |||
| edges_data: _typing.MutableMapping[str, list] = {} | |||
| for data_key in homogeneous_static_graph.edges.data: | |||
| edges_data[data_key] = ( | |||
| homogeneous_static_graph.edges.data[data_key].squeeze().tolist() | |||
| ) | |||
| for i, (u, v) in enumerate(homogeneous_static_graph.edges.connections.t().tolist()): | |||
| if __remove_self_loops and v == u: | |||
| continue | |||
| g.add_edge(u, v) | |||
| for data_key in edges_data: | |||
| g[u][v][data_key] = edges_data[data_key][i] | |||
| for data_key in nodes_data: | |||
| for i, feature_dict in g.nodes(data=True): | |||
| feature_dict.update( | |||
| {data_key: nodes_data[data_key][i]} | |||
| ) | |||
| return g | |||
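# Usage sketch (toy homogeneous graph; the `GeneralStaticGraphGenerator`
# import path follows the commented references elsewhere in this package):
if __name__ == "__main__":
    import torch
    from autogl.data.graph import GeneralStaticGraphGenerator
    _graph = GeneralStaticGraphGenerator.create_homogeneous_static_graph(
        {'x': torch.randn(3, 2)}, torch.tensor([[0, 1], [1, 2]])
    )
    _nx_graph = HomogeneousStaticGraphToNetworkX(to_undirected=True)(_graph)
    print(_nx_graph.number_of_nodes(), _nx_graph.number_of_edges())  # 3 2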
| @@ -0,0 +1,77 @@ | |||
| import torch | |||
| import typing as _typing | |||
| import torch_geometric | |||
| from ... import GeneralStaticGraph | |||
| class StaticGraphToPyGData: | |||
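    """
    Convert a `GeneralStaticGraph` with homogeneous nodes into a
    `torch_geometric.data.Data` instance, merging node data, edge data,
    and graph-level data into a single keyword namespace.
    """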
| def __init__(self, *__args, **__kwargs): | |||
| pass | |||
| def __call__( | |||
| self, static_graph: GeneralStaticGraph, | |||
| *__args, **__kwargs | |||
| ): | |||
| if not isinstance(static_graph, GeneralStaticGraph): | |||
| raise TypeError | |||
| elif not static_graph.nodes.is_homogeneous: | |||
| raise ValueError("Provided static graph MUST consist of homogeneous nodes") | |||
| homogeneous_node_type: _typing.Optional[str] = ( | |||
| list(static_graph.nodes)[0] | |||
| if len(list(static_graph.nodes)) > 0 else None | |||
| ) | |||
| data: _typing.Dict[str, torch.Tensor] = dict() | |||
| if isinstance(homogeneous_node_type, str): | |||
node_and_graph_data_keys_intersection: _typing.Set[str] = (
    set(static_graph.nodes.data) & set(static_graph.data)
)
if len(node_and_graph_data_keys_intersection) > 0:
    raise ValueError(
        f"Provided static graph contains duplicate data "
        f"with the same keys {node_and_graph_data_keys_intersection} "
        f"for homogeneous nodes data and graph-level data; "
        f"please refer to the documentation for more details."
    )
| data.update(static_graph.nodes.data) | |||
| data.update(static_graph.data) | |||
| else: | |||
| data.update(static_graph.data) | |||
| if len(list(static_graph.edges)) == 1: | |||
| data["edge_index"] = static_graph.edges.connections | |||
if len(set(data.keys()) & set(static_graph.edges.data.keys())) > 0:
    raise ValueError(
        "Provided static graph contains duplicate data with the same key; "
        "please refer to the documentation for more details."
    )
| data.update(static_graph.edges.data) | |||
| elif len(list(static_graph.edges)) > 1: | |||
| for canonical_edge_type in static_graph.edges: | |||
if isinstance(homogeneous_node_type, str) and (
    canonical_edge_type.source_node_type != homogeneous_node_type or
    canonical_edge_type.target_node_type != homogeneous_node_type
):
    continue
if not canonical_edge_type.relation_type.endswith('edge'):
    continue
| data[f"{canonical_edge_type.relation_type}_index"] = ( | |||
| static_graph.edges[canonical_edge_type].connections | |||
| ) | |||
| edge_type_prefix: str = canonical_edge_type.relation_type[:-4] | |||
| for data_key in static_graph.edges[canonical_edge_type].data: | |||
if data_key.startswith('edge'):
| data[f"{edge_type_prefix}{data_key}"] = ( | |||
| static_graph.edges[canonical_edge_type].data[data_key].detach() | |||
| ) | |||
| else: | |||
| data[f"{canonical_edge_type.relation_type}_{data_key}"] = ( | |||
| static_graph.edges[canonical_edge_type].data[data_key].detach() | |||
| ) | |||
| pyg_data: torch_geometric.data.Data = torch_geometric.data.Data(**data) | |||
| return pyg_data | |||
| def static_graph_to_pyg_data(static_graph: GeneralStaticGraph) -> torch_geometric.data.Data: | |||
return StaticGraphToPyGData()(static_graph)
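# Usage sketch (toy graph; assumes torch_geometric is installed, and that
# the data keys 'x' and 'y' are illustrative):
if __name__ == "__main__":
    from autogl.data.graph import GeneralStaticGraphGenerator
    _graph = GeneralStaticGraphGenerator.create_homogeneous_static_graph(
        {'x': torch.randn(3, 2), 'y': torch.tensor([0, 1, 0])},
        torch.tensor([[0, 1], [1, 2]])
    )
    _pyg_data = static_graph_to_pyg_data(_graph)
    print(_pyg_data.x.shape, _pyg_data.edge_index.shape)  # (3, 2) and (2, 2)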
| @@ -0,0 +1 @@ | |||
| from .._general_static_graph.utils.conversion import * | |||
| @@ -1,65 +0,0 @@ | |||
| Datasets are derived from PyG, OGB and CogDL. | |||
| ================= | |||
| AutoGL now supports the following benchmarks for different tasks: | |||
- semi-supervised node classification: Cora, Citeseer, Pubmed, Amazon Computers\*, Amazon Photo\*, Coauthor CS\*, Coauthor Physics\*, Reddit (\*: using `utils.random_splits_mask_class` to split the dataset is recommended.)
| | Dataset | PyG | CogDL | x | y | edge_index | edge_attr | train/val/test node | train/val/test mask | | |||
| | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | | |||
| | Cora | ✓ | | ✓ | ✓ | ✓ | ✓ | | ✓ | | |||
| | Citeseer | ✓ | | ✓ | ✓ | ✓ | ✓ | | ✓ | | |||
| | Pubmed | ✓ | | ✓ | ✓ | ✓ | ✓ | | ✓ | | |||
| | Amazon Computers | ✓ | | ✓ | ✓ | ✓ | ✓ | | | | |||
| | Amazon Photo | ✓ | | ✓ | ✓ | ✓ | ✓ | | | | |||
| | Coauthor CS | ✓ | | ✓ | ✓ | ✓ | ✓ | | | | |||
| | Coauthor Physics | ✓ | | ✓ | ✓ | ✓ | ✓ | | | | |||
| | Reddit | ✓ | | ✓ | ✓ | ✓ | ✓ | | ✓ | | |||
| - supervised graph classification: MUTAG, IMDB-B, IMDB-M, PROTEINS, COLLAB | |||
| | Dataset | PyG | CogDL | x | y | edge_index | edge_attr | train/val/test node | train/val/test mask | adj| | |||
| | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | | |||
| | Mutag | ✓ | | ✓ | ✓ | ✓ | ✓ | | | | | |||
| | IMDB-B | ✓ | | | ✓ | ✓ | | | | | | |||
| | IMDB-M | ✓ | | | ✓ | ✓ | | | | | | |||
| | PROTEINS | ✓ | | ✓ | ✓ | ✓ | | | | | | |||
| | COLLAB | ✓ | | | ✓ | ✓ | | | | | | |||
| - node classification datasets from OGB: ogbn-products, ogbn-proteins, ogbn-arxiv, ogbn-papers100M and ogbn-mag. | |||
| - graph classification datasets from OGB: ogbg-molhiv, ogbg-molpcba, ogbg-ppa and ogbg-code. | |||
| --- | |||
| TODO: | |||
In future versions, AutoGL will support the following benchmarks for different tasks:
| - unsupervised node classification: PPI, Blogcatalog, Wikipedia | |||
| - heterogeneous node classification: DBLP, ACM, IMDB | |||
| - link prediction: PPI, Wikipedia, Blogcatalog | |||
| - multiplex link prediction: Amazon, YouTube, Twitter | |||
| - link prediction datasets from OGB: ogbl-ppa, ogbl-collab, ogbl-ddi, ogbl-citation, ogbl-wikikg and ogbl-biokg. | |||
| <!-- | |||
| | Dataset | PyG | CogDL | x | y | edge_index | edge_attr | train/val/test node | train/val/test mask | adj| | |||
| | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | | |||
| | ACM | | ✓ | ✓ | ✓ | ✓ | | ✓ | | ✓ list | | |||
| | DBLP | | ✓ | ✓ | ✓ | ✓ | | ✓ | | ✓ list | | |||
| | IMDB | | ✓ | ✓ | ✓ | ✓ | | ✓ | | ✓ list | | |||
| | Flickr | | ✓ | | ✓ | ✓ | ✓ | | | | | |||
| | Blogcatalog | | ✓ | | ✓ | ✓ | ✓ | | | | | |||
| | PPI | | ✓ | | ✓ | ✓ | ✓ | | | | | |||
| | Wikipedia | | ✓ | | ✓ | ✓ | ✓ | | | | | |||
| | Amazon | | ✓ | | | | | ✓ data | | | | |||
| | Twitter | | ✓ | | | | | ✓ data | | | | |||
| | Youtube | | ✓ | | | | | ✓ data | | | | |||
| | NCI1 | ✓ | | ✓ | ✓ | ✓ | | | | | | |||
| | NCI109 | ✓ | | ✓ | ✓ | ✓ | | | | | | |||
| | Enzyme | ✓ | | ✓ | ✓ | ✓ | | | | | | |||
| | Reddit-B | ✓ | | | ✓ | ✓ | | | | | | |||
| | Reddit-Multi-5k | ✓ | | | ✓ | ✓ | | | | | | |||
| | Reddit-Multi-12k | ✓ | | | ✓ | ✓ | | | | | | |||
| | PTC-MR | ✓ | | ✓ | ✓ | ✓ | ✓ | | | | | |||
| --> | |||
| @@ -1,203 +1,70 @@ | |||
| import os.path as osp | |||
| import os | |||
| import torch | |||
| from ..data.dataset import Dataset | |||
| try: | |||
| import torch_geometric | |||
| except ImportError: | |||
| torch_geometric = None | |||
| pyg = False | |||
| else: | |||
| pyg = True | |||
| DATASET_DICT = {} | |||
| def register_dataset(name): | |||
| """ | |||
| New dataset types can be added to autogl with the :func:`register_dataset` | |||
| function decorator. | |||
| For example:: | |||
| @register_dataset('my_dataset') | |||
| class MyDataset(): | |||
| (...) | |||
| Args: | |||
| name (str): the name of the dataset | |||
| """ | |||
| def register_dataset_cls(cls): | |||
| if name in DATASET_DICT: | |||
| raise ValueError("Cannot register duplicate dataset ({})".format(name)) | |||
| if not issubclass(cls, Dataset) and ( | |||
| pyg and not issubclass(cls, torch_geometric.data.Dataset) | |||
| ): | |||
| raise ValueError( | |||
| "Dataset ({}: {}) must extend autogl.data.Dataset".format( | |||
| name, cls.__name__ | |||
| ) | |||
| ) | |||
| DATASET_DICT[name] = cls | |||
| return cls | |||
| return register_dataset_cls | |||
| from .pyg import ( | |||
| AmazonComputersDataset, | |||
| AmazonPhotoDataset, | |||
| CoauthorPhysicsDataset, | |||
| CoauthorCSDataset, | |||
| CoraDataset, | |||
| CiteSeerDataset, | |||
| PubMedDataset, | |||
| RedditDataset, | |||
| MUTAGDataset, | |||
| IMDBBinaryDataset, | |||
| IMDBMultiDataset, | |||
| CollabDataset, | |||
| ProteinsDataset, | |||
| REDDITBinary, | |||
| REDDITMulti5K, | |||
| REDDITMulti12K, | |||
| PTCMRDataset, | |||
| NCI1Dataset, | |||
| ENZYMES, | |||
| QM9Dataset, | |||
| ) | |||
| from .ogb import ( | |||
| OGBNproductsDataset, | |||
| OGBNproteinsDataset, | |||
| OGBNarxivDataset, | |||
| OGBNpapers100MDataset, | |||
| OGBNmagDataset, | |||
| OGBGmolhivDataset, | |||
| OGBGmolpcbaDataset, | |||
| OGBGppaDataset, | |||
| OGBGcodeDataset, | |||
| OGBLppaDataset, | |||
| OGBLcollabDataset, | |||
| OGBLddiDataset, | |||
| OGBLcitationDataset, | |||
| OGBLwikikgDataset, | |||
| OGBLbiokgDataset, | |||
| ) | |||
| from .gatne import GatneDataset, AmazonDataset, TwitterDataset, YouTubeDataset | |||
| from .gtn_data import GTNDataset, ACM_GTNDataset, DBLP_GTNDataset, IMDB_GTNDataset | |||
| from .han_data import HANDataset, ACM_HANDataset, DBLP_HANDataset, IMDB_HANDataset | |||
| from .matlab_matrix import ( | |||
| MatlabMatrix, | |||
| BlogcatalogDataset, | |||
| WikipediaDataset, | |||
| PPIDataset, | |||
| ) | |||
| from .modelnet import ( | |||
| ModelNet10, | |||
| ModelNet40, | |||
| ModelNet10Train, | |||
| ModelNet10Test, | |||
| ModelNet40Train, | |||
| ModelNet40Test, | |||
| ) | |||
| from .utils import ( | |||
| get_label_number, | |||
| random_splits_mask, | |||
| random_splits_mask_class, | |||
| graph_cross_validation, | |||
| graph_set_fold_id, | |||
| graph_random_splits, | |||
| graph_get_split, | |||
| from autogl import backend as _backend | |||
| from ._dataset_registry import ( | |||
| DatasetUniversalRegistry, | |||
| build_dataset_from_name | |||
| ) | |||
| from ._gtn_data import ( | |||
| GTNACMDataset, | |||
| GTNDBLPDataset, | |||
| GTNIMDBDataset, | |||
| ) | |||
| def build_dataset(args, path="~/.cache-autogl/"): | |||
| path = osp.join(path, "data", args.dataset) | |||
| path = os.path.expanduser(path) | |||
| return DATASET_DICT[args.dataset](path) | |||
| def build_dataset_from_name(dataset_name, path="~/.cache-autogl/"): | |||
| path = osp.join(path, "data", dataset_name) | |||
| path = os.path.expanduser(path) | |||
| dataset = DATASET_DICT[dataset_name](path) | |||
| if "ogbn" in dataset_name: | |||
| # dataset.data, dataset.slices = dataset.collate([dataset.data]) | |||
| # dataset.data.num_nodes = dataset.data.num_nodes[0] | |||
| if dataset.data.y.shape[-1] == 1: | |||
| dataset.data.y = torch.squeeze(dataset.data.y) | |||
| return dataset | |||
| from ._matlab_matrix import BlogCatalogDataset, WIKIPEDIADataset | |||
| from ._ogb import ( | |||
| OGBNProductsDataset, OGBNProteinsDataset, OGBNArxivDataset, OGBNPapers100MDataset, | |||
| OGBLPPADataset, OGBLCOLLABDataset, OGBLDDIDataset, OGBLCitation2Dataset, | |||
| OGBGMOLHIVDataset, OGBGMOLPCBADataset, OGBGPPADataset, OGBGCode2Dataset | |||
| ) | |||
| __all__ = [ | |||
| "register_dataset", | |||
| "build_dataset", | |||
| "build_dataset_from_name", | |||
| "get_label_number", | |||
| "random_splits_mask", | |||
| "random_splits_mask_class", | |||
| "graph_cross_validation", | |||
| "graph_set_fold_id", | |||
| "graph_random_splits", | |||
| "graph_get_split", | |||
| "AmazonComputersDataset", | |||
| "AmazonPhotoDataset", | |||
| "CoauthorPhysicsDataset", | |||
| "CoauthorCSDataset", | |||
| "CoraDataset", | |||
| "CiteSeerDataset", | |||
| "PubMedDataset", | |||
| "RedditDataset", | |||
| "MUTAGDataset", | |||
| "IMDBBinaryDataset", | |||
| "IMDBMultiDataset", | |||
| "CollabDataset", | |||
| "ProteinsDataset", | |||
| "REDDITBinary", | |||
| "REDDITMulti5K", | |||
| "REDDITMulti12K", | |||
| "PTCMRDataset", | |||
| "NCI1Dataset", | |||
| "ENZYMES", | |||
| "QM9Dataset", | |||
| "OGBNproductsDataset", | |||
| "OGBNproteinsDataset", | |||
| "OGBNarxivDataset", | |||
| "OGBNpapers100MDataset", | |||
| "OGBNmagDataset", | |||
| "OGBGmolhivDataset", | |||
| "OGBGmolpcbaDataset", | |||
| "OGBGppaDataset", | |||
| "OGBGcodeDataset", | |||
| "OGBLppaDataset", | |||
| "OGBLcollabDataset", | |||
| "OGBLddiDataset", | |||
| "OGBLcitationDataset", | |||
| "OGBLwikikgDataset", | |||
| "OGBLbiokgDataset", | |||
| "GatneDataset", | |||
| "AmazonDataset", | |||
| "TwitterDataset", | |||
| "YouTubeDataset", | |||
| "GTNDataset", | |||
| "ACM_GTNDataset", | |||
| "DBLP_GTNDataset", | |||
| "IMDB_GTNDataset", | |||
| "HANDataset", | |||
| "ACM_HANDataset", | |||
| "DBLP_HANDataset", | |||
| "IMDB_HANDataset", | |||
| "MatlabMatrix", | |||
| "BlogcatalogDataset", | |||
| "WikipediaDataset", | |||
| "PPIDataset", | |||
| "ModelNet10", | |||
| "ModelNet40", | |||
| "ModelNet10Train", | |||
| "ModelNet10Test", | |||
| "ModelNet40Train", | |||
| "ModelNet40Test", | |||
| ] | |||
| if _backend.DependentBackend.is_dgl(): | |||
| from ._dgl import ( | |||
| CoraDataset, | |||
| CiteSeerDataset, | |||
| PubMedDataset, | |||
| RedditDataset, | |||
| AmazonComputersDataset, | |||
| AmazonPhotoDataset, | |||
| CoauthorPhysicsDataset, | |||
| CoauthorCSDataset, | |||
| MUTAGDataset, | |||
| ENZYMESDataset, | |||
| IMDBBinaryDataset, | |||
| IMDBMultiDataset, | |||
| RedditBinaryDataset, | |||
| REDDITMulti5KDataset, | |||
| COLLABDataset, | |||
| ProteinsDataset, | |||
| PTCMRDataset, | |||
| NCI1Dataset | |||
| ) | |||
| elif _backend.DependentBackend.is_pyg(): | |||
| from ._pyg import ( | |||
| CoraDataset, | |||
| CiteSeerDataset, | |||
| PubMedDataset, | |||
| FlickrDataset, | |||
| RedditDataset, | |||
| AmazonComputersDataset, | |||
| AmazonPhotoDataset, | |||
| CoauthorPhysicsDataset, | |||
| CoauthorCSDataset, | |||
| PPIDataset, | |||
| QM9Dataset, | |||
| MUTAGDataset, | |||
| ENZYMESDataset, | |||
| IMDBBinaryDataset, | |||
| IMDBMultiDataset, | |||
| RedditBinaryDataset, | |||
| REDDITMulti5KDataset, | |||
| REDDITMulti12KDataset, | |||
| COLLABDataset, | |||
| ProteinsDataset, | |||
| PTCMRDataset, | |||
| NCI1Dataset, | |||
| NCI109Dataset, | |||
| ModelNet10TrainingDataset, | |||
| ModelNet10TestDataset, | |||
| ModelNet40TrainingDataset, | |||
| ModelNet40TestDataset | |||
| ) | |||
| @@ -0,0 +1,80 @@ | |||
| import os | |||
| import typing as _typing | |||
| class OnlineDataSource: | |||
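    """
    Base class for a dataset source stored on disk under `path`: raw files
    are downloaded into `<path>/raw` via `_fetch` and transformed into
    `<path>/processed` via `_process`, with each step skipped when its
    files already exist.
    """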
| @property | |||
| def _raw_directory(self) -> str: | |||
| return os.path.join(self.__path, "raw") | |||
| @property | |||
| def _processed_directory(self) -> str: | |||
| return os.path.join(self.__path, "processed") | |||
| @property | |||
| def _raw_filenames(self) -> _typing.Iterable[str]: | |||
| raise NotImplementedError | |||
| @property | |||
| def _processed_filenames(self) -> _typing.Iterable[str]: | |||
| raise NotImplementedError | |||
| @property | |||
| def _raw_file_paths(self) -> _typing.Iterable[str]: | |||
| return [ | |||
| os.path.join(self._raw_directory, raw_filename) | |||
| for raw_filename in self._raw_filenames | |||
| ] | |||
| @property | |||
| def _processed_file_paths(self) -> _typing.Iterable[str]: | |||
| return [ | |||
| os.path.join(self._processed_directory, processed_filename) | |||
| for processed_filename in self._processed_filenames | |||
| ] | |||
| @classmethod | |||
| def __files_exist(cls, files: _typing.Iterable[str]) -> bool: | |||
| return all([os.path.exists(file) for file in files]) | |||
| @classmethod | |||
| def __make_directory(cls, path): | |||
| import errno | |||
| try: | |||
| os.makedirs(os.path.expanduser(os.path.normpath(path))) | |||
| except OSError as e: | |||
if e.errno != errno.EEXIST or not os.path.isdir(path):
    raise e
| def _fetch(self): | |||
| raise NotImplementedError | |||
| def __fetch(self): | |||
| if not self.__files_exist(self._raw_file_paths): | |||
| self.__make_directory(self._raw_directory) | |||
| self._fetch() | |||
| def _process(self): | |||
| raise NotImplementedError | |||
| def __preprocess(self): | |||
| if not self.__files_exist(self._processed_file_paths): | |||
| self.__make_directory(self._processed_directory) | |||
| self._process() | |||
| def __getitem__(self, index: int) -> _typing.Any: | |||
| raise NotImplementedError | |||
| def __len__(self) -> int: | |||
| raise NotImplementedError | |||
| def __init__( | |||
| self, path: str, | |||
| # transform: _typing.Optional[_typing.Callable[[_typing.Any], _typing.Any]] = ... | |||
| ): | |||
| self.__path: str = os.path.expanduser(os.path.normpath(path)) | |||
| # self.__transform: _typing.Optional[_typing.Callable[[_typing.Any], _typing.Any]] = ( | |||
| # transform if transform not in (Ellipsis, None) and callable(transform) else None | |||
| # ) | |||
| self.__fetch() | |||
| self.__preprocess() | |||
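# Subclassing sketch (illustrative; the file names below are placeholders
# and the `_fetch`/`_process` bodies are elided):
#
#     class ExampleDataSource(OnlineDataSource):
#         @property
#         def _raw_filenames(self) -> _typing.Iterable[str]:
#             return ["example.raw"]
#         @property
#         def _processed_filenames(self) -> _typing.Iterable[str]:
#             return ["example.pt"]
#         def _fetch(self):
#             ...  # download raw files into self._raw_directory
#         def _process(self):
#             ...  # write processed artifacts into self._processed_directory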
| @@ -0,0 +1,45 @@ | |||
| import os | |||
| import typing as _typing | |||
| from autogl.data import Dataset | |||
| class _DatasetUniversalRegistryMetaclass(type): | |||
| def __new__( | |||
| mcs, name: str, bases: _typing.Tuple[type, ...], | |||
| namespace: _typing.Dict[str, _typing.Any] | |||
| ): | |||
| return super(_DatasetUniversalRegistryMetaclass, mcs).__new__( | |||
| mcs, name, bases, namespace | |||
| ) | |||
| def __init__( | |||
| cls, name: str, bases: _typing.Tuple[type, ...], | |||
| namespace: _typing.Dict[str, _typing.Any] | |||
| ): | |||
| super(_DatasetUniversalRegistryMetaclass, cls).__init__(name, bases, namespace) | |||
| cls._dataset_universal_registry: _typing.MutableMapping[str, _typing.Type[Dataset]] = {} | |||
| class DatasetUniversalRegistry(metaclass=_DatasetUniversalRegistryMetaclass): | |||
| @classmethod | |||
| def register_dataset(cls, dataset_name: str): | |||
| def register_dataset_cls(dataset: _typing.Type[Dataset]): | |||
| if dataset_name in cls._dataset_universal_registry: | |||
| raise ValueError(f"Dataset with name \"{dataset_name}\" already exists!") | |||
| elif not issubclass(dataset, Dataset): | |||
| raise TypeError | |||
| else: | |||
| cls._dataset_universal_registry[dataset_name] = dataset | |||
| return dataset | |||
| return register_dataset_cls | |||
| @classmethod | |||
| def get_dataset(cls, dataset_name: str) -> _typing.Type[Dataset]: | |||
| if dataset_name not in cls._dataset_universal_registry: | |||
| raise ValueError(f"Dataset with name \"{dataset_name}\" not registered!") | |||
| return cls._dataset_universal_registry[dataset_name] | |||
| def build_dataset_from_name(dataset_name: str, path: str = "~/.cache-autogl/"): | |||
| path = os.path.expanduser(os.path.join(path, "data", dataset_name)) | |||
| _dataset = DatasetUniversalRegistry.get_dataset(dataset_name) | |||
| return _dataset(path) | |||
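| # Minimal usage sketch (hypothetical names; assumes a dataset class has been | |||
| # registered under "cora" as in the dataset modules below): | |||
| # | |||
| #     @DatasetUniversalRegistry.register_dataset("cora") | |||
| #     class CoraDataset(Dataset): | |||
| #         ... | |||
| # | |||
| #     dataset = build_dataset_from_name("cora")  # resolves the registered | |||
| #     # class and instantiates it with the per-dataset cache path | |||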
| @@ -0,0 +1,544 @@ | |||
| import os | |||
| import torch | |||
| import dgl | |||
| # from autogl.data.graph import GeneralStaticGraphGenerator | |||
| from autogl.data.graph.utils import conversion as _conversion | |||
| from autogl.data import InMemoryStaticGraphSet | |||
| from ._dataset_registry import DatasetUniversalRegistry | |||
| @DatasetUniversalRegistry.register_dataset("cora") | |||
| class CoraDataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| dgl_dataset = dgl.data.CoraGraphDataset( | |||
| raw_dir=os.path.join(path, '_dgl') | |||
| ) | |||
| dgl_graph: dgl.DGLGraph = dgl_dataset[0] | |||
| super(CoraDataset, self).__init__( | |||
| [_conversion.dgl_graph_to_general_static_graph(dgl_graph)] | |||
| ) | |||
| # super(CoraDataset, self).__init__( | |||
| # [ | |||
| # GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| # { | |||
| # 'feat': dgl_graph.ndata['feat'], | |||
| # 'label': dgl_graph.ndata['label'], | |||
| # 'train_mask': dgl_graph.ndata['train_mask'], | |||
| # 'val_mask': dgl_graph.ndata['val_mask'], | |||
| # 'test_mask': dgl_graph.ndata['test_mask'] | |||
| # }, | |||
| # torch.vstack(dgl_graph.edges()) | |||
| # ) | |||
| # ] | |||
| # ) | |||
| @DatasetUniversalRegistry.register_dataset("CiteSeer".lower()) | |||
| class CiteSeerDataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| dgl_dataset = dgl.data.CiteseerGraphDataset( | |||
| raw_dir=os.path.join(path, '_dgl') | |||
| ) | |||
| dgl_graph: dgl.DGLGraph = dgl_dataset[0] | |||
| super(CiteSeerDataset, self).__init__( | |||
| [_conversion.dgl_graph_to_general_static_graph(dgl_graph)] | |||
| ) | |||
| # super(CiteSeerDataset, self).__init__( | |||
| # [ | |||
| # GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| # { | |||
| # 'feat': dgl_graph.ndata['feat'], | |||
| # 'label': dgl_graph.ndata['label'], | |||
| # 'train_mask': dgl_graph.ndata['train_mask'], | |||
| # 'val_mask': dgl_graph.ndata['val_mask'], | |||
| # 'test_mask': dgl_graph.ndata['test_mask'] | |||
| # }, | |||
| # torch.vstack(dgl_graph.edges()) | |||
| # ) | |||
| # ] | |||
| # ) | |||
| @DatasetUniversalRegistry.register_dataset("PubMed".lower()) | |||
| class PubMedDataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| dgl_dataset = dgl.data.PubmedGraphDataset( | |||
| raw_dir=os.path.join(path, '_dgl') | |||
| ) | |||
| dgl_graph: dgl.DGLGraph = dgl_dataset[0] | |||
| super(PubMedDataset, self).__init__( | |||
| [_conversion.dgl_graph_to_general_static_graph(dgl_graph)] | |||
| ) | |||
| # super(PubMedDataset, self).__init__( | |||
| # [ | |||
| # GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| # { | |||
| # 'feat': dgl_graph.ndata['feat'], | |||
| # 'label': dgl_graph.ndata['label'], | |||
| # 'train_mask': dgl_graph.ndata['train_mask'], | |||
| # 'val_mask': dgl_graph.ndata['val_mask'], | |||
| # 'test_mask': dgl_graph.ndata['test_mask'] | |||
| # }, | |||
| # torch.vstack(dgl_graph.edges()) | |||
| # ) | |||
| # ] | |||
| # ) | |||
| @DatasetUniversalRegistry.register_dataset("reddit") | |||
| class RedditDataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| dgl_dataset = dgl.data.RedditDataset( | |||
| raw_dir=os.path.join(path, '_dgl') | |||
| ) | |||
| dgl_graph: dgl.DGLGraph = dgl_dataset[0] | |||
| super(RedditDataset, self).__init__( | |||
| [_conversion.dgl_graph_to_general_static_graph(dgl_graph)] | |||
| ) | |||
| # super(RedditDataset, self).__init__( | |||
| # [ | |||
| # GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| # { | |||
| # 'feat': dgl_graph.ndata['feat'], | |||
| # 'label': dgl_graph.ndata['label'], | |||
| # 'train_mask': dgl_graph.ndata['train_mask'], | |||
| # 'val_mask': dgl_graph.ndata['val_mask'], | |||
| # 'test_mask': dgl_graph.ndata['test_mask'] | |||
| # }, | |||
| # torch.vstack(dgl_graph.edges()) | |||
| # ) | |||
| # ] | |||
| # ) | |||
| @DatasetUniversalRegistry.register_dataset("amazon_computers") | |||
| class AmazonComputersDataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| dgl_dataset = dgl.data.AmazonCoBuyComputerDataset( | |||
| raw_dir=os.path.join(path, '_dgl') | |||
| ) | |||
| dgl_graph: dgl.DGLGraph = dgl_dataset[0] | |||
| super(AmazonComputersDataset, self).__init__( | |||
| [_conversion.dgl_graph_to_general_static_graph(dgl_graph)] | |||
| ) | |||
| # super(AmazonComputersDataset, self).__init__( | |||
| # [ | |||
| # GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| # { | |||
| # 'feat': dgl_graph.ndata['feat'], | |||
| # 'label': dgl_graph.ndata['label'] | |||
| # }, | |||
| # torch.vstack(dgl_graph.edges()) | |||
| # ) | |||
| # ] | |||
| # ) | |||
| @DatasetUniversalRegistry.register_dataset("amazon_photo") | |||
| class AmazonPhotoDataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| dgl_dataset = dgl.data.AmazonCoBuyPhotoDataset( | |||
| raw_dir=os.path.join(path, '_dgl') | |||
| ) | |||
| dgl_graph: dgl.DGLGraph = dgl_dataset[0] | |||
| super(AmazonPhotoDataset, self).__init__( | |||
| [_conversion.dgl_graph_to_general_static_graph(dgl_graph)] | |||
| ) | |||
| # super(AmazonPhotoDataset, self).__init__( | |||
| # [ | |||
| # GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| # { | |||
| # 'feat': dgl_graph.ndata['feat'], | |||
| # 'label': dgl_graph.ndata['label'] | |||
| # }, | |||
| # torch.vstack(dgl_graph.edges()) | |||
| # ) | |||
| # ] | |||
| # ) | |||
| @DatasetUniversalRegistry.register_dataset("coauthor_physics") | |||
| class CoauthorPhysicsDataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| dgl_dataset = dgl.data.CoauthorPhysicsDataset( | |||
| raw_dir=os.path.join(path, '_dgl') | |||
| ) | |||
| dgl_graph: dgl.DGLGraph = dgl_dataset[0] | |||
| super(CoauthorPhysicsDataset, self).__init__( | |||
| [_conversion.dgl_graph_to_general_static_graph(dgl_graph)] | |||
| ) | |||
| # super(CoauthorPhysicsDataset, self).__init__( | |||
| # [ | |||
| # GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| # { | |||
| # 'feat': dgl_graph.ndata['feat'], | |||
| # 'label': dgl_graph.ndata['label'] | |||
| # }, | |||
| # torch.vstack(dgl_graph.edges()) | |||
| # ) | |||
| # ] | |||
| # ) | |||
| @DatasetUniversalRegistry.register_dataset("coauthor_cs") | |||
| class CoauthorCSDataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| dgl_dataset = dgl.data.CoauthorCSDataset( | |||
| raw_dir=os.path.join(path, '_dgl') | |||
| ) | |||
| dgl_graph: dgl.DGLGraph = dgl_dataset[0] | |||
| super(CoauthorCSDataset, self).__init__( | |||
| [_conversion.dgl_graph_to_general_static_graph(dgl_graph)] | |||
| ) | |||
| # super(CoauthorCSDataset, self).__init__( | |||
| # [ | |||
| # GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| # { | |||
| # 'feat': dgl_graph.ndata['feat'], | |||
| # 'label': dgl_graph.ndata['label'] | |||
| # }, | |||
| # torch.vstack(dgl_graph.edges()) | |||
| # ) | |||
| # ] | |||
| # ) | |||
| @DatasetUniversalRegistry.register_dataset("mutag") | |||
| class MUTAGDataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| dgl_dataset = dgl.data.GINDataset( | |||
| "MUTAG", False, raw_dir=os.path.join(path, "_dgl") | |||
| ) | |||
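| # Rename the GIN dataset's node data ('attr' -> 'feat', 'label' -> | |||
| # 'node_label') and attach the graph-level label before conversion; the | |||
| # same transform is repeated for every GINDataset-backed class below. | |||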
| def _transform(dgl_graph: dgl.DGLGraph, label: torch.Tensor): | |||
| dgl_graph.ndata['feat'] = dgl_graph.ndata['attr'] | |||
| dgl_graph.ndata['node_label'] = dgl_graph.ndata['label'] | |||
| del dgl_graph.ndata['attr'] | |||
| del dgl_graph.ndata['label'] | |||
| static_graph = _conversion.dgl_graph_to_general_static_graph(dgl_graph) | |||
| static_graph.data['label'] = label | |||
| return static_graph | |||
| super(MUTAGDataset, self).__init__( | |||
| [_transform(dgl_graph, label) for (dgl_graph, label) in dgl_dataset] | |||
| ) | |||
| # super(MUTAGDataset, self).__init__( | |||
| # [ | |||
| # GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| # { | |||
| # 'label': dgl_graph.ndata['label'], | |||
| # 'attr': dgl_graph.ndata['attr'] | |||
| # }, | |||
| # torch.vstack(dgl_graph.edges()), | |||
| # graph_data={'label': graph_label} | |||
| # ) | |||
| # for (dgl_graph, graph_label) in dgl_dataset | |||
| # ] | |||
| # ) | |||
| @DatasetUniversalRegistry.register_dataset("enzymes") | |||
| class ENZYMESDataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| dgl_dataset = dgl.data.TUDataset( | |||
| "ENZYMES", raw_dir=os.path.join(path, "_dgl") | |||
| ) | |||
| def _transform(dgl_graph: dgl.DGLGraph, label: torch.Tensor): | |||
| dgl_graph.ndata['feat'] = dgl_graph.ndata['node_attr'] | |||
| dgl_graph.ndata['node_label'] = dgl_graph.ndata['node_labels'] | |||
| del dgl_graph.ndata['node_attr'] | |||
| del dgl_graph.ndata['node_labels'] | |||
| static_graph = _conversion.dgl_graph_to_general_static_graph(dgl_graph) | |||
| static_graph.data['label'] = label | |||
| return static_graph | |||
| super(ENZYMESDataset, self).__init__( | |||
| [_transform(dgl_graph, label) for (dgl_graph, label) in dgl_dataset] | |||
| ) | |||
| # super(ENZYMESDataset, self).__init__( | |||
| # [ | |||
| # GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| # { | |||
| # 'node_labels': dgl_graph.ndata['node_labels'], | |||
| # 'node_attr': dgl_graph.ndata['node_attr'] | |||
| # }, | |||
| # torch.vstack(dgl_graph.edges()), | |||
| # graph_data={'label': label} | |||
| # ) for (dgl_graph, label) in dgl_dataset | |||
| # ] | |||
| # ) | |||
| @DatasetUniversalRegistry.register_dataset("imdb-b") | |||
| class IMDBBinaryDataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| dgl_dataset = dgl.data.GINDataset( | |||
| "IMDBBINARY", False, raw_dir=os.path.join(path, "_dgl") | |||
| ) | |||
| def _transform(dgl_graph: dgl.DGLGraph, label: torch.Tensor): | |||
| dgl_graph.ndata['feat'] = dgl_graph.ndata['attr'] | |||
| dgl_graph.ndata['node_label'] = dgl_graph.ndata['label'] | |||
| del dgl_graph.ndata['attr'] | |||
| del dgl_graph.ndata['label'] | |||
| static_graph = _conversion.dgl_graph_to_general_static_graph(dgl_graph) | |||
| static_graph.data['label'] = label | |||
| return static_graph | |||
| super(IMDBBinaryDataset, self).__init__( | |||
| [_transform(dgl_graph, label) for (dgl_graph, label) in dgl_dataset] | |||
| ) | |||
| # super(IMDBBinaryDataset, self).__init__( | |||
| # [ | |||
| # GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| # { | |||
| # 'label': dgl_graph.ndata['label'], | |||
| # 'attr': dgl_graph.ndata['attr'] | |||
| # }, | |||
| # torch.vstack(dgl_graph.edges()), | |||
| # graph_data={'label': graph_label} | |||
| # ) | |||
| # for (dgl_graph, graph_label) in dgl_dataset | |||
| # ] | |||
| # ) | |||
| @DatasetUniversalRegistry.register_dataset("imdb-m") | |||
| class IMDBMultiDataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| dgl_dataset = dgl.data.GINDataset( | |||
| "IMDBMULTI", False, raw_dir=os.path.join(path, "_dgl") | |||
| ) | |||
| def _transform(dgl_graph: dgl.DGLGraph, label: torch.Tensor): | |||
| dgl_graph.ndata['feat'] = dgl_graph.ndata['attr'] | |||
| dgl_graph.ndata['node_label'] = dgl_graph.ndata['label'] | |||
| del dgl_graph.ndata['attr'] | |||
| del dgl_graph.ndata['label'] | |||
| static_graph = _conversion.dgl_graph_to_general_static_graph(dgl_graph) | |||
| static_graph.data['label'] = label | |||
| return static_graph | |||
| super(IMDBMultiDataset, self).__init__( | |||
| [_transform(dgl_graph, label) for (dgl_graph, label) in dgl_dataset] | |||
| ) | |||
| # super(IMDBMultiDataset, self).__init__( | |||
| # [ | |||
| # GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| # { | |||
| # 'label': dgl_graph.ndata['label'], | |||
| # 'attr': dgl_graph.ndata['attr'] | |||
| # }, | |||
| # torch.vstack(dgl_graph.edges()), | |||
| # graph_data={'label': graph_label} | |||
| # ) | |||
| # for (dgl_graph, graph_label) in dgl_dataset | |||
| # ] | |||
| # ) | |||
| @DatasetUniversalRegistry.register_dataset("reddit-b") | |||
| class RedditBinaryDataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| dgl_dataset = dgl.data.GINDataset( | |||
| "REDDITBINARY", False, raw_dir=os.path.join(path, "_dgl") | |||
| ) | |||
| def _transform(dgl_graph: dgl.DGLGraph, label: torch.Tensor): | |||
| dgl_graph.ndata['feat'] = dgl_graph.ndata['attr'] | |||
| dgl_graph.ndata['node_label'] = dgl_graph.ndata['label'] | |||
| del dgl_graph.ndata['attr'] | |||
| del dgl_graph.ndata['label'] | |||
| static_graph = _conversion.dgl_graph_to_general_static_graph(dgl_graph) | |||
| static_graph.data['label'] = label | |||
| return static_graph | |||
| super(RedditBinaryDataset, self).__init__( | |||
| [_transform(dgl_graph, label) for (dgl_graph, label) in dgl_dataset] | |||
| ) | |||
| # super(RedditBinaryDataset, self).__init__( | |||
| # [ | |||
| # GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| # { | |||
| # 'label': dgl_graph.ndata['label'], | |||
| # 'attr': dgl_graph.ndata['attr'] | |||
| # }, | |||
| # torch.vstack(dgl_graph.edges()), | |||
| # graph_data={'label': graph_label} | |||
| # ) | |||
| # for (dgl_graph, graph_label) in dgl_dataset | |||
| # ] | |||
| # ) | |||
| @DatasetUniversalRegistry.register_dataset("reddit-multi-5k") | |||
| class REDDITMulti5KDataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| dgl_dataset = dgl.data.GINDataset( | |||
| "REDDITMULTI5K", False, raw_dir=os.path.join(path, "_dgl") | |||
| ) | |||
| def _transform(dgl_graph: dgl.DGLGraph, label: torch.Tensor): | |||
| dgl_graph.ndata['feat'] = dgl_graph.ndata['attr'] | |||
| dgl_graph.ndata['node_label'] = dgl_graph.ndata['label'] | |||
| del dgl_graph.ndata['attr'] | |||
| del dgl_graph.ndata['label'] | |||
| static_graph = _conversion.dgl_graph_to_general_static_graph(dgl_graph) | |||
| static_graph.data['label'] = label | |||
| return static_graph | |||
| super(REDDITMulti5KDataset, self).__init__( | |||
| [_transform(dgl_graph, label) for (dgl_graph, label) in dgl_dataset] | |||
| ) | |||
| # super(REDDITMulti5KDataset, self).__init__( | |||
| # [ | |||
| # GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| # { | |||
| # 'label': dgl_graph.ndata['label'], | |||
| # 'attr': dgl_graph.ndata['attr'] | |||
| # }, | |||
| # torch.vstack(dgl_graph.edges()), | |||
| # graph_data={'label': graph_label} | |||
| # ) | |||
| # for (dgl_graph, graph_label) in dgl_dataset | |||
| # ] | |||
| # ) | |||
| @DatasetUniversalRegistry.register_dataset("collab") | |||
| class COLLABDataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| dgl_dataset = dgl.data.GINDataset( | |||
| "COLLAB", False, raw_dir=os.path.join(path, "_dgl") | |||
| ) | |||
| def _transform(dgl_graph: dgl.DGLGraph, label: torch.Tensor): | |||
| dgl_graph.ndata['feat'] = dgl_graph.ndata['attr'] | |||
| dgl_graph.ndata['node_label'] = dgl_graph.ndata['label'] | |||
| del dgl_graph.ndata['attr'] | |||
| del dgl_graph.ndata['label'] | |||
| static_graph = _conversion.dgl_graph_to_general_static_graph(dgl_graph) | |||
| static_graph.data['label'] = label | |||
| return static_graph | |||
| super(COLLABDataset, self).__init__( | |||
| [_transform(dgl_graph, label) for (dgl_graph, label) in dgl_dataset] | |||
| ) | |||
| # super(COLLABDataset, self).__init__( | |||
| # [ | |||
| # GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| # { | |||
| # 'label': dgl_graph.ndata['label'], | |||
| # 'attr': dgl_graph.ndata['attr'] | |||
| # }, | |||
| # torch.vstack(dgl_graph.edges()), | |||
| # graph_data={'label': graph_label} | |||
| # ) | |||
| # for (dgl_graph, graph_label) in dgl_dataset | |||
| # ] | |||
| # ) | |||
| @DatasetUniversalRegistry.register_dataset("proteins") | |||
| class ProteinsDataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| dgl_dataset = dgl.data.GINDataset( | |||
| "PROTEINS", False, raw_dir=os.path.join(path, "_dgl") | |||
| ) | |||
| def _transform(dgl_graph: dgl.DGLGraph, label: torch.Tensor): | |||
| dgl_graph.ndata['feat'] = dgl_graph.ndata['attr'] | |||
| dgl_graph.ndata['node_label'] = dgl_graph.ndata['label'] | |||
| del dgl_graph.ndata['attr'] | |||
| del dgl_graph.ndata['label'] | |||
| static_graph = _conversion.dgl_graph_to_general_static_graph(dgl_graph) | |||
| static_graph.data['label'] = label | |||
| return static_graph | |||
| super(ProteinsDataset, self).__init__( | |||
| [_transform(dgl_graph, label) for (dgl_graph, label) in dgl_dataset] | |||
| ) | |||
| # super(ProteinsDataset, self).__init__( | |||
| # [ | |||
| # GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| # { | |||
| # 'label': dgl_graph.ndata['label'], | |||
| # 'attr': dgl_graph.ndata['attr'] | |||
| # }, | |||
| # torch.vstack(dgl_graph.edges()), | |||
| # graph_data={'label': graph_label} | |||
| # ) | |||
| # for (dgl_graph, graph_label) in dgl_dataset | |||
| # ] | |||
| # ) | |||
| @DatasetUniversalRegistry.register_dataset("ptc-mr") | |||
| class PTCMRDataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| dgl_dataset = dgl.data.GINDataset( | |||
| "PTC", False, raw_dir=os.path.join(path, "_dgl") | |||
| ) | |||
| def _transform(dgl_graph: dgl.DGLGraph, label: torch.Tensor): | |||
| dgl_graph.ndata['feat'] = dgl_graph.ndata['attr'] | |||
| dgl_graph.ndata['node_label'] = dgl_graph.ndata['label'] | |||
| del dgl_graph.ndata['attr'] | |||
| del dgl_graph.ndata['label'] | |||
| static_graph = _conversion.dgl_graph_to_general_static_graph(dgl_graph) | |||
| static_graph.data['label'] = label | |||
| return static_graph | |||
| super(PTCMRDataset, self).__init__( | |||
| [_transform(dgl_graph, label) for (dgl_graph, label) in dgl_dataset] | |||
| ) | |||
| # super(PTCMRDataset, self).__init__( | |||
| # [ | |||
| # GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| # { | |||
| # 'label': dgl_graph.ndata['label'], | |||
| # 'attr': dgl_graph.ndata['attr'] | |||
| # }, | |||
| # torch.vstack(dgl_graph.edges()), | |||
| # graph_data={'label': graph_label} | |||
| # ) | |||
| # for (dgl_graph, graph_label) in dgl_dataset | |||
| # ] | |||
| # ) | |||
| @DatasetUniversalRegistry.register_dataset("nci1") | |||
| class NCI1Dataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| dgl_dataset = dgl.data.GINDataset( | |||
| "NCI1", False, raw_dir=os.path.join(path, "_dgl") | |||
| ) | |||
| def _transform(dgl_graph: dgl.DGLGraph, label: torch.Tensor): | |||
| dgl_graph.ndata['feat'] = dgl_graph.ndata['attr'] | |||
| dgl_graph.ndata['node_label'] = dgl_graph.ndata['label'] | |||
| del dgl_graph.ndata['attr'] | |||
| del dgl_graph.ndata['label'] | |||
| static_graph = _conversion.dgl_graph_to_general_static_graph(dgl_graph) | |||
| static_graph.data['label'] = label | |||
| return static_graph | |||
| super(NCI1Dataset, self).__init__( | |||
| [_transform(dgl_graph, label) for (dgl_graph, label) in dgl_dataset] | |||
| ) | |||
| # super(NCI1Dataset, self).__init__( | |||
| # [ | |||
| # GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| # { | |||
| # 'label': dgl_graph.ndata['label'], | |||
| # 'attr': dgl_graph.ndata['attr'] | |||
| # }, | |||
| # torch.vstack(dgl_graph.edges()), | |||
| # graph_data={'label': graph_label} | |||
| # ) | |||
| # for (dgl_graph, graph_label) in dgl_dataset | |||
| # ] | |||
| # ) | |||
| @@ -0,0 +1,244 @@ | |||
| import os | |||
| import os.path as osp | |||
| import shutil | |||
| import pickle | |||
| import numpy as np | |||
| import torch | |||
| import typing as _typing | |||
| from autogl.data import Data, download_url, InMemoryStaticGraphSet | |||
| from autogl.data.graph import GeneralStaticGraphGenerator | |||
| from ._dataset_registry import DatasetUniversalRegistry | |||
| from ._data_source import OnlineDataSource | |||
| from .. import backend as _backend | |||
| def _untar(path, fname, delete_tar=True): | |||
| """ | |||
| Unpacks the given archive file to the same directory, then (by default) | |||
| deletes the archive file. | |||
| """ | |||
| print("unpacking " + fname) | |||
| full_path = os.path.join(path, fname) | |||
| shutil.unpack_archive(full_path, path) | |||
| if delete_tar: | |||
| os.remove(full_path) | |||
| class _GTNDataSource(OnlineDataSource): | |||
| def __init__(self, path: str, name: str): | |||
| self.__name: str = name | |||
| self.__url: str = ( | |||
| f"https://github.com/cenyk1230/gtn-data/blob/master/{name}.zip?raw=true" | |||
| ) | |||
| super(_GTNDataSource, self).__init__(path) | |||
| self.__data = torch.load(list(self._processed_file_paths)[0]) | |||
| @property | |||
| def _raw_filenames(self) -> _typing.Iterable[str]: | |||
| return ["edges.pkl", "labels.pkl", "node_features.pkl"] | |||
| @property | |||
| def _processed_filenames(self) -> _typing.Iterable[str]: | |||
| return ["data.pt"] | |||
| def __read_gtn_data(self, directory): | |||
| edges = pickle.load(open(osp.join(directory, "edges.pkl"), "rb")) | |||
| labels = pickle.load(open(osp.join(directory, "labels.pkl"), "rb")) | |||
| node_features = pickle.load(open(osp.join(directory, "node_features.pkl"), "rb")) | |||
| data = Data() | |||
| data.x = torch.from_numpy(node_features).float() | |||
| num_nodes = edges[0].shape[0] | |||
| node_type = np.zeros(num_nodes, dtype=int) | |||
| assert len(edges) == 4 | |||
| assert len(edges[0].nonzero()) == 2 | |||
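| # Infer node types from the endpoints of the four relation-specific | |||
| # adjacency matrices: relations 0/1 connect types 0 and 1, relations 2/3 | |||
| # connect types 0 and 2. | |||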
| node_type[edges[0].nonzero()[0]] = 0 | |||
| node_type[edges[0].nonzero()[1]] = 1 | |||
| node_type[edges[1].nonzero()[0]] = 1 | |||
| node_type[edges[1].nonzero()[1]] = 0 | |||
| node_type[edges[2].nonzero()[0]] = 0 | |||
| node_type[edges[2].nonzero()[1]] = 2 | |||
| node_type[edges[3].nonzero()[0]] = 2 | |||
| node_type[edges[3].nonzero()[1]] = 0 | |||
| data.pos = torch.from_numpy(node_type) | |||
| edge_list = [] | |||
| for edge in edges: | |||
| edge_tmp = torch.from_numpy( | |||
| np.vstack((edge.nonzero()[0], edge.nonzero()[1])) | |||
| ).long() | |||
| edge_list.append(edge_tmp) | |||
| data.edge_index = torch.cat(edge_list, 1) | |||
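| # Store each relation as a sparse (indices, values) pair and append | |||
| # self-loops over all nodes as an extra relation. | |||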
| A = [] | |||
| for edge in edges: | |||
| edge_tmp = torch.from_numpy( | |||
| np.vstack((edge.nonzero()[0], edge.nonzero()[1])) | |||
| ).long() | |||
| value_tmp = torch.ones(edge_tmp.shape[1]).float() | |||
| A.append((edge_tmp, value_tmp)) | |||
| edge_tmp = torch.stack( | |||
| (torch.arange(0, num_nodes), torch.arange(0, num_nodes)) | |||
| ).long() | |||
| value_tmp = torch.ones(num_nodes).float() | |||
| A.append((edge_tmp, value_tmp)) | |||
| data.adj = A | |||
| data.train_node = torch.from_numpy(np.array(labels[0])[:, 0]).long() | |||
| data.train_target = torch.from_numpy(np.array(labels[0])[:, 1]).long() | |||
| data.valid_node = torch.from_numpy(np.array(labels[1])[:, 0]).long() | |||
| data.valid_target = torch.from_numpy(np.array(labels[1])[:, 1]).long() | |||
| data.test_node = torch.from_numpy(np.array(labels[2])[:, 0]).long() | |||
| data.test_target = torch.from_numpy(np.array(labels[2])[:, 1]).long() | |||
| y = np.zeros(num_nodes, dtype=int) | |||
| x_index = torch.cat((data.train_node, data.valid_node, data.test_node)) | |||
| y_index = torch.cat((data.train_target, data.valid_target, data.test_target)) | |||
| y[x_index.numpy()] = y_index.numpy() | |||
| data.y = torch.from_numpy(y) | |||
| self.__data = data | |||
| def __transform_gtn_data(self): | |||
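| # Convert the train/valid/test node index tensors into boolean masks. | |||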
| self.__data.train_mask: torch.Tensor = torch.zeros(self.__data.x.size(0), dtype=torch.bool) | |||
| self.__data.val_mask: torch.Tensor = torch.zeros(self.__data.x.size(0), dtype=torch.bool) | |||
| self.__data.test_mask: torch.Tensor = torch.zeros(self.__data.x.size(0), dtype=torch.bool) | |||
| self.__data.train_mask[getattr(self.__data, "train_node")] = True | |||
| self.__data.val_mask[getattr(self.__data, "valid_node")] = True | |||
| self.__data.test_mask[getattr(self.__data, "test_node")] = True | |||
| def _fetch(self): | |||
| download_url(self.__url, self._raw_directory, name=f"{self.__name}.zip") | |||
| _untar(self._raw_directory, f"{self.__name}.zip") | |||
| def _process(self): | |||
| self.__read_gtn_data(self._raw_directory) | |||
| self.__transform_gtn_data() | |||
| torch.save(self.__data, list(self._processed_file_paths)[0]) | |||
| def __len__(self) -> int: | |||
| return 1 | |||
| def __getitem__(self, index): | |||
| if index != 0: | |||
| raise IndexError | |||
| return self.__data | |||
| @DatasetUniversalRegistry.register_dataset("gtn-acm") | |||
| class GTNACMDataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| data = _GTNDataSource(path, "gtn-acm")[0] | |||
| if _backend.DependentBackend.is_dgl(): | |||
| super(GTNACMDataset, self).__init__( | |||
| [ | |||
| GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| { | |||
| 'feat': getattr(data, 'x'), | |||
| 'label': getattr(data, 'y'), | |||
| 'pos': getattr(data, 'pos'), | |||
| 'train_mask': getattr(data, 'train_mask'), | |||
| 'val_mask': getattr(data, 'val_mask'), | |||
| 'test_mask': getattr(data, 'test_mask') | |||
| }, | |||
| getattr(data, 'edge_index') | |||
| ) | |||
| ] | |||
| ) | |||
| elif _backend.DependentBackend.is_pyg(): | |||
| super(GTNACMDataset, self).__init__( | |||
| [ | |||
| GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| { | |||
| 'x': getattr(data, 'x'), | |||
| 'y': getattr(data, 'y'), | |||
| 'pos': getattr(data, 'pos'), | |||
| 'train_mask': getattr(data, 'train_mask'), | |||
| 'val_mask': getattr(data, 'val_mask'), | |||
| 'test_mask': getattr(data, 'test_mask') | |||
| }, | |||
| getattr(data, 'edge_index') | |||
| ) | |||
| ] | |||
| ) | |||
| @DatasetUniversalRegistry.register_dataset("gtn-dblp") | |||
| class GTNDBLPDataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| data = _GTNDataSource(path, "gtn-dblp")[0] | |||
| if _backend.DependentBackend.is_dgl(): | |||
| super(GTNDBLPDataset, self).__init__( | |||
| [ | |||
| GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| { | |||
| 'feat': getattr(data, 'x'), | |||
| 'label': getattr(data, 'y'), | |||
| 'pos': getattr(data, 'pos'), | |||
| 'train_mask': getattr(data, 'train_mask'), | |||
| 'val_mask': getattr(data, 'val_mask'), | |||
| 'test_mask': getattr(data, 'test_mask') | |||
| }, | |||
| getattr(data, 'edge_index') | |||
| ) | |||
| ] | |||
| ) | |||
| elif _backend.DependentBackend.is_pyg(): | |||
| super(GTNDBLPDataset, self).__init__( | |||
| [ | |||
| GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| { | |||
| 'x': getattr(data, 'x'), | |||
| 'y': getattr(data, 'y'), | |||
| 'pos': getattr(data, 'pos'), | |||
| 'train_mask': getattr(data, 'train_mask'), | |||
| 'val_mask': getattr(data, 'val_mask'), | |||
| 'test_mask': getattr(data, 'test_mask') | |||
| }, | |||
| getattr(data, 'edge_index') | |||
| ) | |||
| ] | |||
| ) | |||
| @DatasetUniversalRegistry.register_dataset("gtn-imdb") | |||
| class GTNIMDBDataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| data = _GTNDataSource(path, "gtn-imdb")[0] | |||
| if _backend.DependentBackend.is_dgl(): | |||
| super(GTNIMDBDataset, self).__init__( | |||
| [ | |||
| GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| { | |||
| 'feat': getattr(data, 'x'), | |||
| 'label': getattr(data, 'y'), | |||
| 'pos': getattr(data, 'pos'), | |||
| 'train_mask': getattr(data, 'train_mask'), | |||
| 'val_mask': getattr(data, 'val_mask'), | |||
| 'test_mask': getattr(data, 'test_mask') | |||
| }, | |||
| getattr(data, 'edge_index') | |||
| ) | |||
| ] | |||
| ) | |||
| elif _backend.DependentBackend.is_pyg(): | |||
| super(GTNIMDBDataset, self).__init__( | |||
| [ | |||
| GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| { | |||
| 'x': getattr(data, 'x'), | |||
| 'y': getattr(data, 'y'), | |||
| 'pos': getattr(data, 'pos'), | |||
| 'train_mask': getattr(data, 'train_mask'), | |||
| 'val_mask': getattr(data, 'val_mask'), | |||
| 'test_mask': getattr(data, 'test_mask') | |||
| }, | |||
| getattr(data, 'edge_index') | |||
| ) | |||
| ] | |||
| ) | |||
| @@ -0,0 +1,112 @@ | |||
| import os | |||
| import scipy.io | |||
| import torch | |||
| import typing as _typing | |||
| from autogl.data import Data, download_url, InMemoryStaticGraphSet | |||
| from autogl.data.graph import GeneralStaticGraphGenerator | |||
| from ._dataset_registry import DatasetUniversalRegistry | |||
| from ._data_source import OnlineDataSource | |||
| from .. import backend as _backend | |||
| class _MATLABMatrix(OnlineDataSource): | |||
| @property | |||
| def _raw_filenames(self) -> _typing.Iterable[str]: | |||
| return [f"{self.__name}.mat"] | |||
| @property | |||
| def _processed_filenames(self) -> _typing.Iterable[str]: | |||
| return ["data.pt"] | |||
| def _fetch(self): | |||
| for name in self._raw_filenames: | |||
| download_url(self.__url + name, self._raw_directory) | |||
| def _process(self): | |||
| path = os.path.join(self._raw_directory, f"{self.__name}.mat") | |||
| mat = scipy.io.loadmat(path) | |||
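| # 'network' holds the sparse adjacency matrix and 'group' the sparse | |||
| # multi-label node annotations of the MATLAB file. | |||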
| adj_matrix, group = mat["network"], mat["group"] | |||
| y = torch.from_numpy(group.toarray()).to(torch.float) | |||
| row_ind, col_ind = adj_matrix.nonzero() | |||
| edge_index = torch.stack([torch.tensor(row_ind), torch.tensor(col_ind)], dim=0) | |||
| edge_attr = torch.tensor(adj_matrix[row_ind, col_ind]) | |||
| data = Data(edge_index=edge_index, edge_attr=edge_attr, x=None, y=y) | |||
| torch.save(data, list(self._processed_file_paths)[0]) | |||
| def __len__(self) -> int: | |||
| return 1 | |||
| def __getitem__(self, index: int): | |||
| if index != 0: | |||
| raise IndexError | |||
| return self.__data | |||
| def __init__(self, path: str, name: str, url: str): | |||
| self.__name: str = name | |||
| self.__url: str = url | |||
| super(_MATLABMatrix, self).__init__(path) | |||
| self.__data = torch.load( | |||
| list(self._processed_file_paths)[0] | |||
| ) | |||
| @DatasetUniversalRegistry.register_dataset("BlogCatalog".lower()) | |||
| class BlogCatalogDataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| filename: str = "BlogCatalog".lower() | |||
| url: str = "http://leitang.net/code/social-dimension/data/" | |||
| data = _MATLABMatrix(path, filename, url)[0] | |||
| if _backend.DependentBackend.is_dgl(): | |||
| super(BlogCatalogDataset, self).__init__( | |||
| [ | |||
| GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| {'label': data.y}, data.edge_index, | |||
| {'edge_attr': data.edge_attr} | |||
| ) | |||
| ] | |||
| ) | |||
| elif _backend.DependentBackend.is_pyg(): | |||
| super(BlogCatalogDataset, self).__init__( | |||
| [ | |||
| GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| {'y': data.y}, data.edge_index, | |||
| {'edge_attr': data.edge_attr} | |||
| ) | |||
| ] | |||
| ) | |||
| @DatasetUniversalRegistry.register_dataset("WikiPEDIA".lower()) | |||
| class WIKIPEDIADataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| filename: str = "POS" | |||
| url = "http://snap.stanford.edu/node2vec/" | |||
| data = _MATLABMatrix(path, filename, url)[0] | |||
| if _backend.DependentBackend.is_dgl(): | |||
| super(WIKIPEDIADataset, self).__init__( | |||
| [ | |||
| GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| {'label': data.y}, data.edge_index, | |||
| {'attr': data.edge_attr} | |||
| ) | |||
| ] | |||
| ) | |||
| elif _backend.DependentBackend.is_pyg(): | |||
| super(WIKIPEDIADataset, self).__init__( | |||
| [ | |||
| GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| {'y': data.y}, data.edge_index, | |||
| {'attr': data.edge_attr} | |||
| ) | |||
| ] | |||
| ) | |||
| @@ -0,0 +1,445 @@ | |||
| import numpy as np | |||
| import torch | |||
| import typing as _typing | |||
| from ogb.nodeproppred import NodePropPredDataset | |||
| from ogb.linkproppred import LinkPropPredDataset | |||
| from ogb.graphproppred import GraphPropPredDataset | |||
| from autogl import backend as _backend | |||
| from autogl.data import InMemoryStaticGraphSet | |||
| from autogl.data.graph import ( | |||
| GeneralStaticGraph, GeneralStaticGraphGenerator | |||
| ) | |||
| from ._dataset_registry import DatasetUniversalRegistry | |||
| from .utils import index_to_mask | |||
| class _OGBDatasetUtil: | |||
| ... | |||
| class _OGBNDatasetUtil(_OGBDatasetUtil): | |||
| @classmethod | |||
| def ogbn_data_to_general_static_graph( | |||
| cls, ogbn_data: _typing.Mapping[str, _typing.Union[np.ndarray, int]], | |||
| nodes_label: np.ndarray = ..., nodes_label_key: str = ..., | |||
| train_index: _typing.Optional[np.ndarray] = ..., | |||
| val_index: _typing.Optional[np.ndarray] = ..., | |||
| test_index: _typing.Optional[np.ndarray] = ..., | |||
| nodes_data_key_mapping: _typing.Optional[_typing.Mapping[str, str]] = ..., | |||
| edges_data_key_mapping: _typing.Optional[_typing.Mapping[str, str]] = ..., | |||
| graph_data_key_mapping: _typing.Optional[_typing.Mapping[str, str]] = ... | |||
| ) -> GeneralStaticGraph: | |||
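| # Note: nodes_data_key_mapping is effectively required here; its Ellipsis | |||
| # default only keeps the signature uniform with the optional mappings. | |||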
| homogeneous_static_graph: GeneralStaticGraph = ( | |||
| GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| dict([ | |||
| (target_key, torch.from_numpy(ogbn_data[source_key])) | |||
| for source_key, target_key in nodes_data_key_mapping.items() | |||
| ]), | |||
| torch.from_numpy(ogbn_data['edge_index']), | |||
| dict([ | |||
| (target_key, torch.from_numpy(ogbn_data[source_key])) | |||
| for source_key, target_key in edges_data_key_mapping.items() | |||
| ]) if isinstance(edges_data_key_mapping, _typing.Mapping) else ..., | |||
| dict([ | |||
| (target_key, torch.from_numpy(ogbn_data[source_key])) | |||
| for source_key, target_key in graph_data_key_mapping.items() | |||
| ]) if isinstance(graph_data_key_mapping, _typing.Mapping) else ... | |||
| ) | |||
| ) | |||
| if isinstance(nodes_label, np.ndarray) and isinstance(nodes_label_key, str): | |||
| if ' ' in nodes_label_key: | |||
| raise ValueError("Illegal nodes label key") | |||
| homogeneous_static_graph.nodes.data[nodes_label_key] = ( | |||
| torch.from_numpy(nodes_label).squeeze() | |||
| ) | |||
| if isinstance(train_index, np.ndarray): | |||
| homogeneous_static_graph.nodes.data['train_mask'] = index_to_mask( | |||
| torch.from_numpy(train_index), ogbn_data['num_nodes'] | |||
| ) | |||
| if isinstance(val_index, np.ndarray): | |||
| homogeneous_static_graph.nodes.data['val_mask'] = index_to_mask( | |||
| torch.from_numpy(val_index), ogbn_data['num_nodes'] | |||
| ) | |||
| if isinstance(test_index, np.ndarray): | |||
| homogeneous_static_graph.nodes.data['test_mask'] = index_to_mask( | |||
| torch.from_numpy(test_index), ogbn_data['num_nodes'] | |||
| ) | |||
| return homogeneous_static_graph | |||
| @classmethod | |||
| def ogbn_dataset_to_general_static_graph( | |||
| cls, ogbn_dataset: NodePropPredDataset, | |||
| nodes_label_key: str, | |||
| nodes_data_key_mapping: _typing.Optional[_typing.Mapping[str, str]] = ..., | |||
| edges_data_key_mapping: _typing.Optional[_typing.Mapping[str, str]] = ..., | |||
| graph_data_key_mapping: _typing.Optional[_typing.Mapping[str, str]] = ... | |||
| ) -> GeneralStaticGraph: | |||
| split_idx = ogbn_dataset.get_idx_split() | |||
| return cls.ogbn_data_to_general_static_graph( | |||
| ogbn_dataset[0][0], | |||
| ogbn_dataset[0][1], | |||
| nodes_label_key, | |||
| split_idx["train"], | |||
| split_idx["valid"], | |||
| split_idx["test"], | |||
| nodes_data_key_mapping, | |||
| edges_data_key_mapping, | |||
| graph_data_key_mapping | |||
| ) | |||
| @DatasetUniversalRegistry.register_dataset("ogbn-products") | |||
| class OGBNProductsDataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| ogbn_dataset = NodePropPredDataset("ogbn-products", path) | |||
| if _backend.DependentBackend.is_dgl(): | |||
| super(OGBNProductsDataset, self).__init__([ | |||
| _OGBNDatasetUtil.ogbn_dataset_to_general_static_graph( | |||
| ogbn_dataset, "label", | |||
| {"node_feat": "feat"}, | |||
| {"edge_feat": "edge_feat"} | |||
| ) | |||
| ]) | |||
| elif _backend.DependentBackend.is_pyg(): | |||
| super(OGBNProductsDataset, self).__init__([ | |||
| _OGBNDatasetUtil.ogbn_dataset_to_general_static_graph( | |||
| ogbn_dataset, "y", | |||
| {"node_feat": "x"} | |||
| ) | |||
| ]) | |||
| @DatasetUniversalRegistry.register_dataset("ogbn-proteins") | |||
| class OGBNProteinsDataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| ogbn_dataset = NodePropPredDataset("ogbn-proteins", path) | |||
| if _backend.DependentBackend.is_dgl(): | |||
| super(OGBNProteinsDataset, self).__init__([ | |||
| _OGBNDatasetUtil.ogbn_dataset_to_general_static_graph( | |||
| ogbn_dataset, "label", | |||
| {"node_species": "species"}, | |||
| {"edge_feat": "edge_feat"} | |||
| ) | |||
| ]) | |||
| elif _backend.DependentBackend.is_pyg(): | |||
| super(OGBNProteinsDataset, self).__init__([ | |||
| _OGBNDatasetUtil.ogbn_dataset_to_general_static_graph( | |||
| ogbn_dataset, "y", | |||
| {"node_species": "species"}, | |||
| {"edge_feat": "edge_feat"} | |||
| ) | |||
| ]) | |||
| @DatasetUniversalRegistry.register_dataset("ogbn-arxiv") | |||
| class OGBNArxivDataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| ogbn_dataset = NodePropPredDataset("ogbn-arxiv", path) | |||
| if _backend.DependentBackend.is_dgl(): | |||
| super(OGBNArxivDataset, self).__init__([ | |||
| _OGBNDatasetUtil.ogbn_dataset_to_general_static_graph( | |||
| ogbn_dataset, "label", | |||
| { | |||
| "node_feat": "feat", | |||
| "node_year": "year" | |||
| } | |||
| ) | |||
| ]) | |||
| elif _backend.DependentBackend.is_pyg(): | |||
| super(OGBNArxivDataset, self).__init__([ | |||
| _OGBNDatasetUtil.ogbn_dataset_to_general_static_graph( | |||
| ogbn_dataset, "y", | |||
| { | |||
| "node_feat": "x", | |||
| "node_year": "year" | |||
| } | |||
| ) | |||
| ]) | |||
| @DatasetUniversalRegistry.register_dataset("ogbn-papers100M") | |||
| class OGBNPapers100MDataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| ogbn_dataset = NodePropPredDataset("ogbn-papers100M", path) | |||
| if _backend.DependentBackend.is_dgl(): | |||
| super(OGBNPapers100MDataset, self).__init__([ | |||
| _OGBNDatasetUtil.ogbn_dataset_to_general_static_graph( | |||
| ogbn_dataset, "label", | |||
| { | |||
| "node_feat": "feat", | |||
| "node_year": "year" | |||
| } | |||
| ) | |||
| ]) | |||
| elif _backend.DependentBackend.is_pyg(): | |||
| super(OGBNPapers100MDataset, self).__init__([ | |||
| _OGBNDatasetUtil.ogbn_dataset_to_general_static_graph( | |||
| ogbn_dataset, "y", | |||
| { | |||
| "node_feat": "x", | |||
| "node_year": "year" | |||
| } | |||
| ) | |||
| ]) | |||
| # todo: currently homogeneous dataset `ogbn-mag` NOT supported | |||
| class _OGBLDatasetUtil(_OGBDatasetUtil): | |||
| @classmethod | |||
| def ogbl_data_to_general_static_graph( | |||
| cls, ogbl_data: _typing.Mapping[str, _typing.Union[np.ndarray, int]], | |||
| heterogeneous_edges: _typing.Mapping[ | |||
| _typing.Tuple[str, str, str], | |||
| _typing.Union[ | |||
| torch.Tensor, | |||
| _typing.Tuple[torch.Tensor, _typing.Optional[_typing.Mapping[str, torch.Tensor]]] | |||
| ] | |||
| ] = ..., | |||
| nodes_data_key_mapping: _typing.Optional[_typing.Mapping[str, str]] = ..., | |||
| graph_data_key_mapping: _typing.Optional[_typing.Mapping[str, str]] = ... | |||
| ) -> GeneralStaticGraph: | |||
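| # Each OGB link-prediction split is attached as its own canonical edge | |||
| # type ('', relation, '') on a graph with a single anonymous node type. | |||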
| return GeneralStaticGraphGenerator.create_heterogeneous_static_graph( | |||
| { | |||
| '': dict([ | |||
| (target_data_key, torch.from_numpy(ogbl_data[source_data_key]).squeeze()) | |||
| for source_data_key, target_data_key in nodes_data_key_mapping.items() | |||
| ]) | |||
| }, | |||
| heterogeneous_edges, | |||
| dict([ | |||
| (target_data_key, torch.from_numpy(ogbl_data[source_data_key]).squeeze()) | |||
| for source_data_key, target_data_key in graph_data_key_mapping.items() | |||
| ]) if isinstance(graph_data_key_mapping, _typing.Mapping) else ... | |||
| ) | |||
| @DatasetUniversalRegistry.register_dataset("ogbl-ppa") | |||
| class OGBLPPADataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| ogbl_dataset = LinkPropPredDataset("ogbl-ppa", path) | |||
| edge_split = ogbl_dataset.get_edge_split() | |||
| super(OGBLPPADataset, self).__init__([ | |||
| _OGBLDatasetUtil.ogbl_data_to_general_static_graph( | |||
| ogbl_dataset[0], { | |||
| ('', '', ''): torch.from_numpy(ogbl_dataset[0]['edge_index']), | |||
| ('', 'train_pos_edge', ''): torch.from_numpy(edge_split['train']['edge']), | |||
| ('', 'val_pos_edge', ''): torch.from_numpy(edge_split['valid']['edge']), | |||
| ('', 'val_neg_edge', ''): torch.from_numpy(edge_split['valid']['edge_neg']), | |||
| ('', 'test_pos_edge', ''): torch.from_numpy(edge_split['test']['edge']), | |||
| ('', 'test_neg_edge', ''): torch.from_numpy(edge_split['test']['edge_neg']) | |||
| }, | |||
| {'node_feat': 'feat'} if _backend.DependentBackend.is_dgl() else {'node_feat': 'x'} | |||
| ) | |||
| ]) | |||
| @DatasetUniversalRegistry.register_dataset("ogbl-collab") | |||
| class OGBLCOLLABDataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| ogbl_dataset = LinkPropPredDataset("ogbl-collab", path) | |||
| edge_split = ogbl_dataset.get_edge_split() | |||
| super(OGBLCOLLABDataset, self).__init__([ | |||
| _OGBLDatasetUtil.ogbl_data_to_general_static_graph( | |||
| ogbl_dataset[0], { | |||
| ('', '', ''): torch.from_numpy(ogbl_dataset[0]['edge_index']), | |||
| ('', 'train_pos_edge', ''): ( | |||
| torch.from_numpy(edge_split['train']['edge']), | |||
| { | |||
| 'weight': torch.from_numpy(edge_split['train']['weight']), | |||
| 'year': torch.from_numpy(edge_split['train']['year']) | |||
| } | |||
| ), | |||
| ('', 'val_pos_edge', ''): ( | |||
| torch.from_numpy(edge_split['valid']['edge']), | |||
| { | |||
| 'weight': torch.from_numpy(edge_split['valid']['weight']), | |||
| 'year': torch.from_numpy(edge_split['valid']['year']) | |||
| } | |||
| ), | |||
| ('', 'val_neg_edge', ''): torch.from_numpy(edge_split['valid']['edge_neg']), | |||
| ('', 'test_pos_edge', ''): ( | |||
| torch.from_numpy(edge_split['test']['edge']), | |||
| { | |||
| 'weight': torch.from_numpy(edge_split['test']['weight']), | |||
| 'year': torch.from_numpy(edge_split['test']['year']) | |||
| } | |||
| ), | |||
| ('', 'test_neg_edge', ''): torch.from_numpy(edge_split['test']['edge_neg']) | |||
| }, | |||
| {'node_feat': 'feat'} if _backend.DependentBackend.is_dgl() else {'node_feat': 'x'} | |||
| ) | |||
| ]) | |||
| @DatasetUniversalRegistry.register_dataset("ogbl-ddi") | |||
| class OGBLDDIDataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| ogbl_dataset = LinkPropPredDataset("ogbl-ddi", path) | |||
| edge_split = ogbl_dataset.get_edge_split() | |||
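| # ogbl-ddi ships no node features, so the nodes are registered with | |||
| # identifiers only. | |||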
| super(OGBLDDIDataset, self).__init__([ | |||
| GeneralStaticGraphGenerator.create_heterogeneous_static_graph( | |||
| {'': {'_NID': torch.arange(ogbl_dataset[0]['num_nodes'])}}, | |||
| { | |||
| ('', '', ''): torch.from_numpy(ogbl_dataset[0]['edge_index']), | |||
| ('', 'train_pos_edge', ''): torch.from_numpy(edge_split['train']['edge']), | |||
| ('', 'val_pos_edge', ''): torch.from_numpy(edge_split['valid']['edge']), | |||
| ('', 'val_neg_edge', ''): torch.from_numpy(edge_split['valid']['edge_neg']), | |||
| ('', 'test_pos_edge', ''): torch.from_numpy(edge_split['test']['edge']), | |||
| ('', 'test_neg_edge', ''): torch.from_numpy(edge_split['test']['edge_neg']) | |||
| } | |||
| ) | |||
| ]) | |||
| @DatasetUniversalRegistry.register_dataset("ogbl-citation") | |||
| @DatasetUniversalRegistry.register_dataset("ogbl-citation2") | |||
| class OGBLCitation2Dataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| ogbl_dataset = LinkPropPredDataset("ogbl-citation2", path) | |||
| edge_split = ogbl_dataset.get_edge_split() | |||
| super(OGBLCitation2Dataset, self).__init__([ | |||
| _OGBLDatasetUtil.ogbl_data_to_general_static_graph( | |||
| ogbl_dataset[0], | |||
| { | |||
| ('', '', ''): torch.from_numpy(ogbl_dataset[0]['edge_index']), | |||
| ('', 'train_pos_edge', ''): torch.from_numpy(edge_split['train']['edge']), | |||
| ('', 'val_pos_edge', ''): torch.from_numpy(edge_split['valid']['edge']), | |||
| ('', 'val_neg_edge', ''): torch.from_numpy(edge_split['valid']['edge_neg']), | |||
| ('', 'test_pos_edge', ''): torch.from_numpy(edge_split['test']['edge']), | |||
| ('', 'test_neg_edge', ''): torch.from_numpy(edge_split['test']['edge_neg']) | |||
| }, | |||
| ( | |||
| {'node_feat': 'feat', 'node_year': 'year'} | |||
| if _backend.DependentBackend.is_dgl() | |||
| else {'node_feat': 'x', 'node_year': 'year'} | |||
| ) | |||
| ) | |||
| ]) | |||
| # todo: currently homogeneous dataset `ogbl-wikikg2` and `ogbl-biokg` NOT supported | |||
| class _OGBGDatasetUtil: | |||
| ... | |||
| @DatasetUniversalRegistry.register_dataset("ogbg-molhiv") | |||
| class OGBGMOLHIVDataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| ogbl_dataset = GraphPropPredDataset("ogbg-molhiv", path) | |||
| idx_split: _typing.Mapping[str, np.ndarray] = ogbl_dataset.get_idx_split() | |||
| train_index: _typing.Any = idx_split['train'].tolist() | |||
| test_index: _typing.Any = idx_split['test'].tolist() | |||
| val_index: _typing.Any = idx_split['valid'].tolist() | |||
| super(OGBGMOLHIVDataset, self).__init__( | |||
| [ | |||
| GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| ( | |||
| {"feat": torch.from_numpy(data['node_feat'])} | |||
| if _backend.DependentBackend.is_dgl() | |||
| else {"x": torch.from_numpy(data['node_feat'])} | |||
| ), | |||
| torch.from_numpy(data['edge_index']), | |||
| {'edge_feat': torch.from_numpy(data['edge_feat'])}, | |||
| ( | |||
| {'label': torch.from_numpy(label)} | |||
| if _backend.DependentBackend.is_dgl() | |||
| else {'y': torch.from_numpy(label)} | |||
| ) | |||
| ) for data, label in ogbl_dataset | |||
| ], | |||
| train_index, val_index, test_index | |||
| ) | |||
| @DatasetUniversalRegistry.register_dataset("ogbg-molpcba") | |||
| class OGBGMOLPCBADataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| ogbl_dataset = GraphPropPredDataset("ogbg-molhiv", path) | |||
| idx_split: _typing.Mapping[str, np.ndarray] = ogbl_dataset.get_idx_split() | |||
| train_index: _typing.Any = idx_split['train'].tolist() | |||
| test_index: _typing.Any = idx_split['test'].tolist() | |||
| val_index: _typing.Any = idx_split['valid'].tolist() | |||
| super(OGBGMOLPCBADataset, self).__init__( | |||
| [ | |||
| GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| ( | |||
| {"feat": torch.from_numpy(data['node_feat'])} | |||
| if _backend.DependentBackend.is_dgl() | |||
| else {"x": torch.from_numpy(data['node_feat'])} | |||
| ), | |||
| torch.from_numpy(data['edge_index']), | |||
| {'edge_feat': torch.from_numpy(data['edge_feat'])}, | |||
| ( | |||
| {'label': torch.from_numpy(label)} | |||
| if _backend.DependentBackend.is_dgl() | |||
| else {'y': torch.from_numpy(label)} | |||
| ) | |||
| ) for data, label in ogbl_dataset | |||
| ], | |||
| train_index, val_index, test_index | |||
| ) | |||
| @DatasetUniversalRegistry.register_dataset("ogbg-ppa") | |||
| class OGBGPPADataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| ogbl_dataset = GraphPropPredDataset("ogbg-molhiv", path) | |||
| idx_split: _typing.Mapping[str, np.ndarray] = ogbl_dataset.get_idx_split() | |||
| train_index: _typing.Any = idx_split['train'].tolist() | |||
| test_index: _typing.Any = idx_split['test'].tolist() | |||
| val_index: _typing.Any = idx_split['valid'].tolist() | |||
| super(OGBGPPADataset, self).__init__( | |||
| [ | |||
| GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| {'_NID': torch.arange(data['num_nodes'])}, | |||
| torch.from_numpy(data['edge_index']), | |||
| {'edge_feat': torch.from_numpy(data['edge_feat'])}, | |||
| ( | |||
| {'label': torch.from_numpy(label)} | |||
| if _backend.DependentBackend.is_dgl() | |||
| else {'y': torch.from_numpy(label)} | |||
| ) | |||
| ) for data, label in ogbl_dataset | |||
| ], | |||
| train_index, val_index, test_index | |||
| ) | |||
| @DatasetUniversalRegistry.register_dataset("ogbg-code") | |||
| @DatasetUniversalRegistry.register_dataset("ogbg-code2") | |||
| class OGBGCode2Dataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| ogbl_dataset = GraphPropPredDataset("ogbg-molhiv", path) | |||
| idx_split: _typing.Mapping[str, np.ndarray] = ogbl_dataset.get_idx_split() | |||
| train_index: _typing.Any = idx_split['train'].tolist() | |||
| test_index: _typing.Any = idx_split['test'].tolist() | |||
| val_index: _typing.Any = idx_split['valid'].tolist() | |||
| super(OGBGCode2Dataset, self).__init__( | |||
| [ | |||
| GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| ( | |||
| { | |||
| "feat": torch.from_numpy(data['node_feat']), | |||
| "node_is_attributed": torch.from_numpy(data["node_is_attributed"]), | |||
| "node_dfs_order": torch.from_numpy(data["node_dfs_order"]), | |||
| "node_depth": torch.from_numpy(data["node_depth"]) | |||
| } | |||
| if _backend.DependentBackend.is_dgl() | |||
| else | |||
| { | |||
| "x": torch.from_numpy(data['node_feat']), | |||
| "node_is_attributed": torch.from_numpy(data["node_is_attributed"]), | |||
| "node_dfs_order": torch.from_numpy(data["node_dfs_order"]), | |||
| "node_depth": torch.from_numpy(data["node_depth"]) | |||
| } | |||
| ), | |||
| torch.from_numpy(data['edge_index']) | |||
| ) for data, label in ogbl_dataset | |||
| ], | |||
| train_index, val_index, test_index | |||
| ) | |||
| @@ -0,0 +1,567 @@ | |||
| import os | |||
| from autogl.data.graph import GeneralStaticGraphGenerator | |||
| from autogl.data import InMemoryStaticGraphSet | |||
| from ._dataset_registry import DatasetUniversalRegistry | |||
| import torch_geometric | |||
| from torch_geometric.datasets import ( | |||
| Amazon, Coauthor, Flickr, ModelNet, | |||
| Planetoid, PPI, QM9, Reddit, TUDataset | |||
| ) | |||
| @DatasetUniversalRegistry.register_dataset("cora") | |||
| class CoraDataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| pyg_dataset = Planetoid(os.path.join(path, '_pyg'), "Cora") | |||
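| # Drop PyG's cached per-item data lists (when present) so the dataset's | |||
| # underlying storage is read directly; the same cleanup is repeated for | |||
| # every PyG-backed dataset below. | |||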
| if hasattr(pyg_dataset, "__data_list__"): | |||
| delattr(pyg_dataset, "__data_list__") | |||
| if hasattr(pyg_dataset, "_data_list"): | |||
| delattr(pyg_dataset, "_data_list") | |||
| pyg_data = pyg_dataset[0] | |||
| static_graph = GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| { | |||
| 'x': pyg_data.x, | |||
| 'y': pyg_data.y, | |||
| 'train_mask': getattr(pyg_data, 'train_mask'), | |||
| 'val_mask': getattr(pyg_data, 'val_mask'), | |||
| 'test_mask': getattr(pyg_data, 'test_mask') | |||
| }, | |||
| pyg_data.edge_index | |||
| ) | |||
| super(CoraDataset, self).__init__([static_graph]) | |||
| @DatasetUniversalRegistry.register_dataset("CiteSeer".lower()) | |||
| class CiteSeerDataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| pyg_dataset = Planetoid(os.path.join(path, '_pyg'), "CiteSeer") | |||
| if hasattr(pyg_dataset, "__data_list__"): | |||
| delattr(pyg_dataset, "__data_list__") | |||
| if hasattr(pyg_dataset, "_data_list"): | |||
| delattr(pyg_dataset, "_data_list") | |||
| pyg_data = pyg_dataset[0] | |||
| static_graph = GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| { | |||
| 'x': pyg_data.x, | |||
| 'y': pyg_data.y, | |||
| 'train_mask': getattr(pyg_data, 'train_mask'), | |||
| 'val_mask': getattr(pyg_data, 'val_mask'), | |||
| 'test_mask': getattr(pyg_data, 'test_mask') | |||
| }, | |||
| pyg_data.edge_index | |||
| ) | |||
| super(CiteSeerDataset, self).__init__([static_graph]) | |||
| @DatasetUniversalRegistry.register_dataset("PubMed".lower()) | |||
| class PubMedDataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| pyg_dataset = Planetoid(os.path.join(path, '_pyg'), "PubMed") | |||
| if hasattr(pyg_dataset, "__data_list__"): | |||
| delattr(pyg_dataset, "__data_list__") | |||
| if hasattr(pyg_dataset, "_data_list"): | |||
| delattr(pyg_dataset, "_data_list") | |||
| pyg_data = pyg_dataset[0] | |||
| static_graph = GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| { | |||
| 'x': pyg_data.x, | |||
| 'y': pyg_data.y, | |||
| 'train_mask': getattr(pyg_data, 'train_mask'), | |||
| 'val_mask': getattr(pyg_data, 'val_mask'), | |||
| 'test_mask': getattr(pyg_data, 'test_mask') | |||
| }, | |||
| pyg_data.edge_index | |||
| ) | |||
| super(PubMedDataset, self).__init__([static_graph]) | |||
| @DatasetUniversalRegistry.register_dataset("flickr") | |||
| class FlickrDataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| pyg_dataset = Flickr(os.path.join(path, '_pyg')) | |||
| if hasattr(pyg_dataset, "__data_list__"): | |||
| delattr(pyg_dataset, "__data_list__") | |||
| if hasattr(pyg_dataset, "_data_list"): | |||
| delattr(pyg_dataset, "_data_list") | |||
| pyg_data = pyg_dataset[0] | |||
| static_graph = GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| { | |||
| 'x': pyg_data.x, | |||
| 'y': pyg_data.y, | |||
| 'train_mask': getattr(pyg_data, 'train_mask'), | |||
| 'val_mask': getattr(pyg_data, 'val_mask'), | |||
| 'test_mask': getattr(pyg_data, 'test_mask') | |||
| }, | |||
| pyg_data.edge_index | |||
| ) | |||
| super(FlickrDataset, self).__init__([static_graph]) | |||
| @DatasetUniversalRegistry.register_dataset("reddit") | |||
| class RedditDataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| pyg_dataset = Reddit(os.path.join(path, '_pyg')) | |||
| if hasattr(pyg_dataset, "__data_list__"): | |||
| delattr(pyg_dataset, "__data_list__") | |||
| if hasattr(pyg_dataset, "_data_list"): | |||
| delattr(pyg_dataset, "_data_list") | |||
| pyg_data = pyg_dataset[0] | |||
| static_graph = GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| { | |||
| 'x': pyg_data.x, | |||
| 'y': pyg_data.y, | |||
| 'train_mask': getattr(pyg_data, 'train_mask'), | |||
| 'val_mask': getattr(pyg_data, 'val_mask'), | |||
| 'test_mask': getattr(pyg_data, 'test_mask') | |||
| }, | |||
| pyg_data.edge_index | |||
| ) | |||
| super(RedditDataset, self).__init__([static_graph]) | |||
| @DatasetUniversalRegistry.register_dataset("amazon_computers") | |||
| class AmazonComputersDataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| pyg_dataset = Amazon(os.path.join(path, '_pyg'), "Computers") | |||
| if hasattr(pyg_dataset, "__data_list__"): | |||
| delattr(pyg_dataset, "__data_list__") | |||
| if hasattr(pyg_dataset, "_data_list"): | |||
| delattr(pyg_dataset, "_data_list") | |||
| pyg_data = pyg_dataset[0] | |||
| static_graph = GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| {'x': pyg_data.x, 'y': pyg_data.y}, | |||
| pyg_data.edge_index | |||
| ) | |||
| super(AmazonComputersDataset, self).__init__([static_graph]) | |||
| @DatasetUniversalRegistry.register_dataset("amazon_photo") | |||
| class AmazonPhotoDataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| pyg_dataset = Amazon(os.path.join(path, '_pyg'), "Photo") | |||
| if hasattr(pyg_dataset, "__data_list__"): | |||
| delattr(pyg_dataset, "__data_list__") | |||
| if hasattr(pyg_dataset, "_data_list"): | |||
| delattr(pyg_dataset, "_data_list") | |||
| pyg_data = pyg_dataset[0] | |||
| static_graph = GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| {'x': pyg_data.x, 'y': pyg_data.y}, | |||
| pyg_data.edge_index | |||
| ) | |||
| super(AmazonPhotoDataset, self).__init__([static_graph]) | |||
| @DatasetUniversalRegistry.register_dataset("coauthor_physics") | |||
| class CoauthorPhysicsDataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| pyg_dataset = Coauthor(os.path.join(path, '_pyg'), "Physics") | |||
| if hasattr(pyg_dataset, "__data_list__"): | |||
| delattr(pyg_dataset, "__data_list__") | |||
| if hasattr(pyg_dataset, "_data_list"): | |||
| delattr(pyg_dataset, "_data_list") | |||
| pyg_data = pyg_dataset[0] | |||
| static_graph = GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| {'x': pyg_data.x, 'y': pyg_data.y}, | |||
| pyg_data.edge_index | |||
| ) | |||
| super(CoauthorPhysicsDataset, self).__init__([static_graph]) | |||
| @DatasetUniversalRegistry.register_dataset("coauthor_cs") | |||
| class CoauthorCSDataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| pyg_dataset = Coauthor(os.path.join(path, '_pyg'), "CS") | |||
| if hasattr(pyg_dataset, "__data_list__"): | |||
| delattr(pyg_dataset, "__data_list__") | |||
| if hasattr(pyg_dataset, "_data_list"): | |||
| delattr(pyg_dataset, "_data_list") | |||
| pyg_data = pyg_dataset[0] | |||
| static_graph = GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| {'x': pyg_data.x, 'y': pyg_data.y}, | |||
| pyg_data.edge_index | |||
| ) | |||
| super(CoauthorCSDataset, self).__init__([static_graph]) | |||
| @DatasetUniversalRegistry.register_dataset("ppi") | |||
| class PPIDataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| train_dataset = PPI(os.path.join(path, '_pyg'), 'train') | |||
| if hasattr(train_dataset, "__data_list__"): | |||
| delattr(train_dataset, "__data_list__") | |||
| if hasattr(train_dataset, "_data_list"): | |||
| delattr(train_dataset, "_data_list") | |||
| val_dataset = PPI(os.path.join(path, '_pyg'), 'val') | |||
| if hasattr(val_dataset, "__data_list__"): | |||
| delattr(val_dataset, "__data_list__") | |||
| if hasattr(val_dataset, "_data_list"): | |||
| delattr(val_dataset, "_data_list") | |||
| test_dataset = PPI(os.path.join(path, '_pyg'), 'test') | |||
| if hasattr(test_dataset, "__data_list__"): | |||
| delattr(test_dataset, "__data_list__") | |||
| if hasattr(test_dataset, "_data_list"): | |||
| delattr(test_dataset, "_data_list") | |||
| train_index = range(len(train_dataset)) | |||
| val_index = range(len(train_dataset), len(train_dataset) + len(val_dataset)) | |||
| test_index = range( | |||
| len(train_dataset) + len(val_dataset), | |||
| len(train_dataset) + len(val_dataset) + len(test_dataset) | |||
| ) | |||
| super(PPIDataset, self).__init__( | |||
| [ | |||
| GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| {'x': data.x, 'y': data.y}, data.edge_index | |||
| ) for data in train_dataset | |||
| ] + | |||
| [ | |||
| GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| {'x': data.x, 'y': data.y}, data.edge_index | |||
| ) for data in val_dataset | |||
| ] + | |||
| [ | |||
| GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| {'x': data.x, 'y': data.y}, data.edge_index | |||
| ) for data in test_dataset | |||
| ], | |||
| train_index, val_index, test_index | |||
| ) | |||
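| # Usage sketch, assuming InMemoryStaticGraphSet exposes the index iterables | |||
| # passed above as train_index / val_index / test_index (path is hypothetical): | |||
| #     dataset = PPIDataset('./data/ppi') | |||
| #     train_graphs = [dataset[i] for i in dataset.train_index] | |||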
| @DatasetUniversalRegistry.register_dataset("qm9") | |||
| class QM9Dataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| pyg_dataset = QM9(os.path.join(path, '_pyg')) | |||
| if hasattr(pyg_dataset, "__data_list__"): | |||
| delattr(pyg_dataset, "__data_list__") | |||
| if hasattr(pyg_dataset, "_data_list"): | |||
| delattr(pyg_dataset, "_data_list") | |||
| super(QM9Dataset, self).__init__( | |||
| [ | |||
| GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| {'x': data.x, 'pos': data.pos, 'z': data.z}, | |||
| data.edge_index, | |||
| edges_data={'edge_attr': data.edge_attr}, | |||
| graph_data={'idx': data.idx, 'y': data.y} | |||
| ) for data in pyg_dataset | |||
| ] | |||
| ) | |||
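| # Each QM9 molecule keeps node-level fields ('x', 'pos', 'z'), edge-level | |||
| # fields ('edge_attr'), and graph-level fields ('idx', 'y') in separate | |||
| # namespaces of the resulting static graph. | |||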
| @DatasetUniversalRegistry.register_dataset("mutag") | |||
| class MUTAGDataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| pyg_dataset = TUDataset(os.path.join(path, '_pyg'), "MUTAG") | |||
| if hasattr(pyg_dataset, "__data_list__"): | |||
| delattr(pyg_dataset, "__data_list__") | |||
| if hasattr(pyg_dataset, "_data_list"): | |||
| delattr(pyg_dataset, "_data_list") | |||
| super(MUTAGDataset, self).__init__( | |||
| [ | |||
| GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| {'x': pyg_data.x}, pyg_data.edge_index, | |||
| edges_data={'edge_attr': pyg_data.edge_attr}, | |||
| graph_data={'y': pyg_data.y} | |||
| ) | |||
| for pyg_data in pyg_dataset | |||
| ] | |||
| ) | |||
| @DatasetUniversalRegistry.register_dataset("enzymes") | |||
| class ENZYMESDataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| pyg_dataset = TUDataset(os.path.join(path, '_pyg'), "ENZYMES") | |||
| if hasattr(pyg_dataset, "__data_list__"): | |||
| delattr(pyg_dataset, "__data_list__") | |||
| if hasattr(pyg_dataset, "_data_list"): | |||
| delattr(pyg_dataset, "_data_list") | |||
| super(ENZYMESDataset, self).__init__( | |||
| [ | |||
| GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| {'x': pyg_data.x}, pyg_data.edge_index, | |||
| graph_data={'y': pyg_data.y} | |||
| ) | |||
| for pyg_data in pyg_dataset | |||
| ] | |||
| ) | |||
| @DatasetUniversalRegistry.register_dataset("imdb-b") | |||
| class IMDBBinaryDataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| pyg_dataset = TUDataset(os.path.join(path, '_pyg'), "IMDB-BINARY") | |||
| if hasattr(pyg_dataset, "__data_list__"): | |||
| delattr(pyg_dataset, "__data_list__") | |||
| if hasattr(pyg_dataset, "_data_list"): | |||
| delattr(pyg_dataset, "_data_list") | |||
| super(IMDBBinaryDataset, self).__init__( | |||
| [ | |||
| GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| {}, pyg_data.edge_index, graph_data={'y': pyg_data.y} | |||
| ) | |||
| for pyg_data in pyg_dataset | |||
| ] | |||
| ) | |||
| @DatasetUniversalRegistry.register_dataset("imdb-m") | |||
| class IMDBMultiDataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| pyg_dataset = TUDataset(os.path.join(path, '_pyg'), "IMDB-MULTI") | |||
| if hasattr(pyg_dataset, "__data_list__"): | |||
| delattr(pyg_dataset, "__data_list__") | |||
| if hasattr(pyg_dataset, "_data_list"): | |||
| delattr(pyg_dataset, "_data_list") | |||
| super(IMDBMultiDataset, self).__init__( | |||
| [ | |||
| GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| {}, pyg_data.edge_index, graph_data={'y': pyg_data.y} | |||
| ) | |||
| for pyg_data in pyg_dataset | |||
| ] | |||
| ) | |||
| @DatasetUniversalRegistry.register_dataset("reddit-b") | |||
| class RedditBinaryDataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| pyg_dataset = TUDataset(os.path.join(path, '_pyg'), "REDDIT-BINARY") | |||
| if hasattr(pyg_dataset, "__data_list__"): | |||
| delattr(pyg_dataset, "__data_list__") | |||
| if hasattr(pyg_dataset, "_data_list"): | |||
| delattr(pyg_dataset, "_data_list") | |||
| super(RedditBinaryDataset, self).__init__( | |||
| [ | |||
| GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| {}, pyg_data.edge_index, graph_data={'y': pyg_data.y} | |||
| ) | |||
| for pyg_data in pyg_dataset | |||
| ] | |||
| ) | |||
| @DatasetUniversalRegistry.register_dataset("reddit-multi-5k") | |||
| class REDDITMulti5KDataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| pyg_dataset = TUDataset(os.path.join(path, '_pyg'), "REDDIT-MULTI-5K") | |||
| if hasattr(pyg_dataset, "__data_list__"): | |||
| delattr(pyg_dataset, "__data_list__") | |||
| if hasattr(pyg_dataset, "_data_list"): | |||
| delattr(pyg_dataset, "_data_list") | |||
| super(REDDITMulti5KDataset, self).__init__( | |||
| [ | |||
| GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| {}, pyg_data.edge_index, graph_data={'y': pyg_data.y} | |||
| ) | |||
| for pyg_data in pyg_dataset | |||
| ] | |||
| ) | |||
| @DatasetUniversalRegistry.register_dataset("reddit-multi-12k") | |||
| class REDDITMulti12KDataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| pyg_dataset = TUDataset(os.path.join(path, '_pyg'), "REDDIT-MULTI-12K") | |||
| if hasattr(pyg_dataset, "__data_list__"): | |||
| delattr(pyg_dataset, "__data_list__") | |||
| if hasattr(pyg_dataset, "_data_list"): | |||
| delattr(pyg_dataset, "_data_list") | |||
| super(REDDITMulti12KDataset, self).__init__( | |||
| [ | |||
| GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| {}, pyg_data.edge_index, graph_data={'y': pyg_data.y} | |||
| ) | |||
| for pyg_data in pyg_dataset | |||
| ] | |||
| ) | |||
| @DatasetUniversalRegistry.register_dataset("collab") | |||
| class COLLABDataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| pyg_dataset = TUDataset(os.path.join(path, '_pyg'), "COLLAB") | |||
| if hasattr(pyg_dataset, "__data_list__"): | |||
| delattr(pyg_dataset, "__data_list__") | |||
| if hasattr(pyg_dataset, "_data_list"): | |||
| delattr(pyg_dataset, "_data_list") | |||
| super(COLLABDataset, self).__init__( | |||
| [ | |||
| GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| {}, pyg_data.edge_index, graph_data={'y': pyg_data.y} | |||
| ) | |||
| for pyg_data in pyg_dataset | |||
| ] | |||
| ) | |||
| @DatasetUniversalRegistry.register_dataset("proteins") | |||
| class ProteinsDataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| pyg_dataset = TUDataset(os.path.join(path, '_pyg'), "PROTEINS") | |||
| if hasattr(pyg_dataset, "__data_list__"): | |||
| delattr(pyg_dataset, "__data_list__") | |||
| if hasattr(pyg_dataset, "_data_list"): | |||
| delattr(pyg_dataset, "_data_list") | |||
| super(ProteinsDataset, self).__init__( | |||
| [ | |||
| GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| {'x': pyg_data.x}, pyg_data.edge_index, graph_data={'y': pyg_data.y} | |||
| ) | |||
| for pyg_data in pyg_dataset | |||
| ] | |||
| ) | |||
| @DatasetUniversalRegistry.register_dataset("ptc-mr") | |||
| class PTCMRDataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| pyg_dataset = TUDataset(os.path.join(path, '_pyg'), "PTC_MR") | |||
| if hasattr(pyg_dataset, "__data_list__"): | |||
| delattr(pyg_dataset, "__data_list__") | |||
| if hasattr(pyg_dataset, "_data_list"): | |||
| delattr(pyg_dataset, "_data_list") | |||
| super(PTCMRDataset, self).__init__( | |||
| [ | |||
| GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| {'x': pyg_data.x}, pyg_data.edge_index, | |||
| edges_data={'edge_attr': pyg_data.edge_attr}, | |||
| graph_data={'y': pyg_data.y} | |||
| ) | |||
| for pyg_data in pyg_dataset | |||
| ] | |||
| ) | |||
| @DatasetUniversalRegistry.register_dataset("nci1") | |||
| class NCI1Dataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| pyg_dataset = TUDataset(os.path.join(path, '_pyg'), "NCI1") | |||
| if hasattr(pyg_dataset, "__data_list__"): | |||
| delattr(pyg_dataset, "__data_list__") | |||
| if hasattr(pyg_dataset, "_data_list"): | |||
| delattr(pyg_dataset, "_data_list") | |||
| super(NCI1Dataset, self).__init__( | |||
| [ | |||
| GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| {'x': pyg_data.x}, pyg_data.edge_index, | |||
| graph_data={'y': pyg_data.y} | |||
| ) | |||
| for pyg_data in pyg_dataset | |||
| ] | |||
| ) | |||
| @DatasetUniversalRegistry.register_dataset("nci109") | |||
| class NCI109Dataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| pyg_dataset = TUDataset(os.path.join(path, '_pyg'), "NCI109") | |||
| if hasattr(pyg_dataset, "__data_list__"): | |||
| delattr(pyg_dataset, "__data_list__") | |||
| if hasattr(pyg_dataset, "_data_list"): | |||
| delattr(pyg_dataset, "_data_list") | |||
| super(NCI109Dataset, self).__init__( | |||
| [ | |||
| GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| {'x': pyg_data.x}, pyg_data.edge_index, | |||
| graph_data={'y': pyg_data.y} | |||
| ) | |||
| for pyg_data in pyg_dataset | |||
| ] | |||
| ) | |||
| @DatasetUniversalRegistry.register_dataset("ModelNet10Training") | |||
| class ModelNet10TrainingDataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| pyg_dataset = ModelNet( | |||
| os.path.join(path, '_pyg'), '10', True, | |||
| pre_transform=torch_geometric.transforms.FaceToEdge() | |||
| ) | |||
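| # FaceToEdge converts the mesh's triangular faces into graph edges, so only | |||
| # node positions ('pos') remain as node features. | |||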
| if hasattr(pyg_dataset, "__data_list__"): | |||
| delattr(pyg_dataset, "__data_list__") | |||
| if hasattr(pyg_dataset, "_data_list"): | |||
| delattr(pyg_dataset, "_data_list") | |||
| super(ModelNet10TrainingDataset, self).__init__( | |||
| [ | |||
| GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| {'pos': pyg_data.pos}, | |||
| pyg_data.edge_index, | |||
| graph_data={'y': pyg_data.y} | |||
| ) | |||
| for pyg_data in pyg_dataset | |||
| ] | |||
| ) | |||
| @DatasetUniversalRegistry.register_dataset("ModelNet10Test") | |||
| class ModelNet10TestDataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| pyg_dataset = ModelNet( | |||
| os.path.join(path, '_pyg'), '10', False, | |||
| pre_transform=torch_geometric.transforms.FaceToEdge() | |||
| ) | |||
| if hasattr(pyg_dataset, "__data_list__"): | |||
| delattr(pyg_dataset, "__data_list__") | |||
| if hasattr(pyg_dataset, "_data_list"): | |||
| delattr(pyg_dataset, "_data_list") | |||
| super(ModelNet10TestDataset, self).__init__( | |||
| [ | |||
| GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| {'pos': pyg_data.pos}, | |||
| pyg_data.edge_index, | |||
| graph_data={'y': pyg_data.y} | |||
| ) | |||
| for pyg_data in pyg_dataset | |||
| ] | |||
| ) | |||
| @DatasetUniversalRegistry.register_dataset("ModelNet40Training") | |||
| class ModelNet40TrainingDataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| pyg_dataset = ModelNet( | |||
| os.path.join(path, '_pyg'), '40', True, | |||
| pre_transform=torch_geometric.transforms.FaceToEdge() | |||
| ) | |||
| if hasattr(pyg_dataset, "__data_list__"): | |||
| delattr(pyg_dataset, "__data_list__") | |||
| if hasattr(pyg_dataset, "_data_list"): | |||
| delattr(pyg_dataset, "_data_list") | |||
| super(ModelNet40TrainingDataset, self).__init__( | |||
| [ | |||
| GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| {'pos': pyg_data.pos}, | |||
| pyg_data.edge_index, | |||
| graph_data={'y': pyg_data.y} | |||
| ) | |||
| for pyg_data in pyg_dataset | |||
| ] | |||
| ) | |||
| @DatasetUniversalRegistry.register_dataset("ModelNet40Test") | |||
| class ModelNet40TestDataset(InMemoryStaticGraphSet): | |||
| def __init__(self, path: str): | |||
| pyg_dataset = ModelNet( | |||
| os.path.join(path, '_pyg'), '40', False, | |||
| pre_transform=torch_geometric.transforms.FaceToEdge() | |||
| ) | |||
| if hasattr(pyg_dataset, "__data_list__"): | |||
| delattr(pyg_dataset, "__data_list__") | |||
| if hasattr(pyg_dataset, "_data_list"): | |||
| delattr(pyg_dataset, "_data_list") | |||
| super(ModelNet40TestDataset, self).__init__( | |||
| [ | |||
| GeneralStaticGraphGenerator.create_homogeneous_static_graph( | |||
| {'pos': pyg_data.pos}, | |||
| pyg_data.edge_index, | |||
| graph_data={'y': pyg_data.y} | |||
| ) | |||
| for pyg_data in pyg_dataset | |||
| ] | |||
| ) | |||
| @@ -1,113 +0,0 @@ | |||
| import os.path as osp | |||
| import torch | |||
| from ..data import Data, Dataset, download_url | |||
| from . import register_dataset | |||
| def read_gatne_data(folder): | |||
| train_data = {} | |||
| with open(osp.join(folder, "{}".format("train.txt")), "r") as f: | |||
| for line in f: | |||
| items = line.strip().split() | |||
| if items[0] not in train_data: | |||
| train_data[items[0]] = [] | |||
| train_data[items[0]].append([int(items[1]), int(items[2])]) | |||
| valid_data = {} | |||
| with open(osp.join(folder, "{}".format("valid.txt")), "r") as f: | |||
| for line in f: | |||
| items = line.strip().split() | |||
| if items[0] not in valid_data: | |||
| valid_data[items[0]] = [[], []] | |||
| valid_data[items[0]][1 - int(items[3])].append( | |||
| [int(items[1]), int(items[2])] | |||
| ) | |||
| test_data = {} | |||
| with open(osp.join(folder, "{}".format("test.txt")), "r") as f: | |||
| for line in f: | |||
| items = line.strip().split() | |||
| if items[0] not in test_data: | |||
| test_data[items[0]] = [[], []] | |||
| test_data[items[0]][1 - int(items[3])].append( | |||
| [int(items[1]), int(items[2])] | |||
| ) | |||
| data = Data() | |||
| data.train_data = train_data | |||
| data.valid_data = valid_data | |||
| data.test_data = test_data | |||
| return data | |||
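| # Raw file format implied by the parsing above (whitespace separated): | |||
| #     train.txt:            <edge_type> <source_node> <target_node> | |||
| #     valid.txt / test.txt: <edge_type> <source_node> <target_node> <label 0/1> | |||
| # Positive pairs (label 1) go into sublist [0], negatives into sublist [1]. | |||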
| class GatneDataset(Dataset): | |||
| r"""The network datasets "Amazon", "Twitter" and "YouTube" from the | |||
| `"Representation Learning for Attributed Multiplex Heterogeneous Network" | |||
| <https://arxiv.org/abs/1905.01669>`_ paper. | |||
| Args: | |||
| root (string): Root directory where the dataset should be saved. | |||
| name (string): The name of the dataset (:obj:`"Amazon"`, | |||
| :obj:`"Twitter"`, :obj:`"YouTube"`). | |||
| """ | |||
| url = "https://github.com/THUDM/GATNE/raw/master/data" | |||
| def __init__(self, root, name): | |||
| self.name = name | |||
| super(GatneDataset, self).__init__(root) | |||
| self.data = torch.load(self.processed_paths[0]) | |||
| @property | |||
| def raw_file_names(self): | |||
| names = ["train.txt", "valid.txt", "test.txt"] | |||
| return names | |||
| @property | |||
| def processed_file_names(self): | |||
| return ["data.pt"] | |||
| def get(self, idx): | |||
| assert idx == 0 | |||
| return self.data | |||
| def download(self): | |||
| for name in self.raw_file_names: | |||
| download_url( | |||
| "{}/{}/{}".format(self.url, self.name.lower(), name), self.raw_dir | |||
| ) | |||
| def process(self): | |||
| data = read_gatne_data(self.raw_dir) | |||
| torch.save(data, self.processed_paths[0]) | |||
| def __repr__(self): | |||
| return "{}()".format(self.name) | |||
| @register_dataset("amazon") | |||
| class AmazonDataset(GatneDataset): | |||
| def __init__(self, path): | |||
| dataset = "amazon" | |||
| # path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) | |||
| super(AmazonDataset, self).__init__(path, dataset) | |||
| @register_dataset("twitter") | |||
| class TwitterDataset(GatneDataset): | |||
| def __init__(self, path): | |||
| dataset = "twitter" | |||
| # path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) | |||
| super(TwitterDataset, self).__init__(path, dataset) | |||
| @register_dataset("youtube") | |||
| class YouTubeDataset(GatneDataset): | |||
| def __init__(self, path): | |||
| dataset = "youtube" | |||
| # path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) | |||
| super(YouTubeDataset, self).__init__(path, dataset) | |||
| @@ -1,188 +0,0 @@ | |||
| import os | |||
| import os.path as osp | |||
| import shutil | |||
| import pickle | |||
| import numpy as np | |||
| import torch | |||
| from ..data import Data, Dataset, download_url | |||
| from . import register_dataset | |||
| def untar(path, fname, deleteTar=True): | |||
| """ | |||
| Unpacks the given archive file to the same directory, then (by default) | |||
| deletes the archive file. | |||
| """ | |||
| print("unpacking " + fname) | |||
| fullpath = os.path.join(path, fname) | |||
| shutil.unpack_archive(fullpath, path) | |||
| if deleteTar: | |||
| os.remove(fullpath) | |||
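| # e.g. untar('/tmp/gtn-acm/raw', 'gtn-acm.zip') unpacks the archive next to | |||
| # itself and then deletes the zip. | |||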
| class GTNDataset(Dataset): | |||
| r"""The network datasets "ACM", "DBLP" and "IMDB" from the | |||
| `"Graph Transformer Networks" | |||
| <https://arxiv.org/abs/1911.06455>`_ paper. | |||
| Args: | |||
| root (string): Root directory where the dataset should be saved. | |||
| name (string): The name of the dataset (:obj:`"gtn-acm"`, | |||
| :obj:`"gtn-dblp"`, :obj:`"gtn-imdb"`). | |||
| """ | |||
| def __init__(self, root, name): | |||
| self.name = name | |||
| self.url = ( | |||
| f"https://github.com/cenyk1230/gtn-data/blob/master/{name}.zip?raw=true" | |||
| ) | |||
| super(GTNDataset, self).__init__(root) | |||
| self.data = torch.load(self.processed_paths[0]) | |||
| self.num_classes = torch.max(self.data.train_target).item() + 1 | |||
| self.num_edge = len(self.data.adj) | |||
| self.num_nodes = self.data.x.shape[0] | |||
| @property | |||
| def raw_file_names(self): | |||
| names = ["edges.pkl", "labels.pkl", "node_features.pkl"] | |||
| return names | |||
| @property | |||
| def processed_file_names(self): | |||
| return ["data.pt"] | |||
| def read_gtn_data(self, folder): | |||
| edges = pickle.load(open(osp.join(folder, "edges.pkl"), "rb")) | |||
| labels = pickle.load(open(osp.join(folder, "labels.pkl"), "rb")) | |||
| node_features = pickle.load(open(osp.join(folder, "node_features.pkl"), "rb")) | |||
| data = Data() | |||
| data.x = torch.from_numpy(node_features).type(torch.FloatTensor) | |||
| num_nodes = edges[0].shape[0] | |||
| node_type = np.zeros(num_nodes, dtype=int) | |||
| assert len(edges) == 4 | |||
| assert len(edges[0].nonzero()) == 2 | |||
| node_type[edges[0].nonzero()[0]] = 0 | |||
| node_type[edges[0].nonzero()[1]] = 1 | |||
| node_type[edges[1].nonzero()[0]] = 1 | |||
| node_type[edges[1].nonzero()[1]] = 0 | |||
| node_type[edges[2].nonzero()[0]] = 0 | |||
| node_type[edges[2].nonzero()[1]] = 2 | |||
| node_type[edges[3].nonzero()[0]] = 2 | |||
| node_type[edges[3].nonzero()[1]] = 0 | |||
| data.pos = torch.from_numpy(node_type) | |||
| edge_list = [] | |||
| for i, edge in enumerate(edges): | |||
| edge_tmp = torch.from_numpy( | |||
| np.vstack((edge.nonzero()[0], edge.nonzero()[1])) | |||
| ).type(torch.LongTensor) | |||
| edge_list.append(edge_tmp) | |||
| data.edge_index = torch.cat(edge_list, 1) | |||
| A = [] | |||
| for i, edge in enumerate(edges): | |||
| edge_tmp = torch.from_numpy( | |||
| np.vstack((edge.nonzero()[0], edge.nonzero()[1])) | |||
| ).type(torch.LongTensor) | |||
| value_tmp = torch.ones(edge_tmp.shape[1]).type(torch.FloatTensor) | |||
| A.append((edge_tmp, value_tmp)) | |||
| edge_tmp = torch.stack( | |||
| (torch.arange(0, num_nodes), torch.arange(0, num_nodes)) | |||
| ).type(torch.LongTensor) | |||
| value_tmp = torch.ones(num_nodes).type(torch.FloatTensor) | |||
| A.append((edge_tmp, value_tmp)) | |||
| data.adj = A | |||
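| # data.adj now holds one (edge_index, edge_weight) pair per edge type, plus | |||
| # a trailing identity (self-loop) adjacency that GTN uses when composing | |||
| # meta-paths. | |||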
| data.train_node = torch.from_numpy(np.array(labels[0])[:, 0]).type( | |||
| torch.LongTensor | |||
| ) | |||
| data.train_target = torch.from_numpy(np.array(labels[0])[:, 1]).type( | |||
| torch.LongTensor | |||
| ) | |||
| data.valid_node = torch.from_numpy(np.array(labels[1])[:, 0]).type( | |||
| torch.LongTensor | |||
| ) | |||
| data.valid_target = torch.from_numpy(np.array(labels[1])[:, 1]).type( | |||
| torch.LongTensor | |||
| ) | |||
| data.test_node = torch.from_numpy(np.array(labels[2])[:, 0]).type( | |||
| torch.LongTensor | |||
| ) | |||
| data.test_target = torch.from_numpy(np.array(labels[2])[:, 1]).type( | |||
| torch.LongTensor | |||
| ) | |||
| y = np.zeros(num_nodes, dtype=int) | |||
| x_index = torch.cat((data.train_node, data.valid_node, data.test_node)) | |||
| y_index = torch.cat((data.train_target, data.valid_target, data.test_target)) | |||
| y[x_index.numpy()] = y_index.numpy() | |||
| data.y = torch.from_numpy(y) | |||
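| # The split-specific labels are scattered back into one dense vector over | |||
| # all nodes; nodes outside every split keep the default label 0. | |||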
| self.data = data | |||
| def get(self, idx): | |||
| assert idx == 0 | |||
| return self.data | |||
| def apply_to_device(self, device): | |||
| self.data.x = self.data.x.to(device) | |||
| self.data.train_node = self.data.train_node.to(device) | |||
| self.data.valid_node = self.data.valid_node.to(device) | |||
| self.data.test_node = self.data.test_node.to(device) | |||
| self.data.train_target = self.data.train_target.to(device) | |||
| self.data.valid_target = self.data.valid_target.to(device) | |||
| self.data.test_target = self.data.test_target.to(device) | |||
| new_adj = [] | |||
| for (t1, t2) in self.data.adj: | |||
| new_adj.append((t1.to(device), t2.to(device))) | |||
| self.data.adj = new_adj | |||
| def download(self): | |||
| download_url(self.url, self.raw_dir, name=self.name + ".zip") | |||
| untar(self.raw_dir, self.name + ".zip") | |||
| def process(self): | |||
| self.read_gtn_data(self.raw_dir) | |||
| torch.save(self.data, self.processed_paths[0]) | |||
| def __repr__(self): | |||
| return "{}()".format(self.name) | |||
| @register_dataset("gtn-acm") | |||
| class ACM_GTNDataset(GTNDataset): | |||
| def __init__(self, path): | |||
| dataset = "gtn-acm" | |||
| # path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) | |||
| super(ACM_GTNDataset, self).__init__(path, dataset) | |||
| @register_dataset("gtn-dblp") | |||
| class DBLP_GTNDataset(GTNDataset): | |||
| def __init__(self, path): | |||
| dataset = "gtn-dblp" | |||
| # path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) | |||
| super(DBLP_GTNDataset, self).__init__(path, dataset) | |||
| @register_dataset("gtn-imdb") | |||
| class IMDB_GTNDataset(GTNDataset): | |||
| def __init__(self, path): | |||
| dataset = "gtn-imdb" | |||
| # path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) | |||
| super(IMDB_GTNDataset, self).__init__(path, dataset) | |||
| @@ -1,187 +0,0 @@ | |||
| import os | |||
| import os.path as osp | |||
| import shutil | |||
| import numpy as np | |||
| import scipy.io as sio | |||
| import torch | |||
| from ..data import Data, Dataset, download_url | |||
| from . import register_dataset | |||
| def untar(path, fname, deleteTar=True): | |||
| """ | |||
| Unpacks the given archive file to the same directory, then (by default) | |||
| deletes the archive file. | |||
| """ | |||
| print("unpacking " + fname) | |||
| fullpath = os.path.join(path, fname) | |||
| shutil.unpack_archive(fullpath, path) | |||
| if deleteTar: | |||
| os.remove(fullpath) | |||
| def sample_mask(idx, l): | |||
| """Create mask.""" | |||
| mask = np.zeros(l) | |||
| mask[idx] = 1 | |||
| return np.array(mask, dtype=np.bool) | |||
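| # e.g. sample_mask([0, 2], 4) -> array([ True, False,  True, False]) | |||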
| class HANDataset(Dataset): | |||
| r"""The network datasets "ACM", "DBLP" and "IMDB" from the | |||
| `"Heterogeneous Graph Attention Network" | |||
| <https://arxiv.org/abs/1903.07293>`_ paper. | |||
| Args: | |||
| root (string): Root directory where the dataset should be saved. | |||
| name (string): The name of the dataset (:obj:`"han-acm"`, | |||
| :obj:`"han-dblp"`, :obj:`"han-imdb"`). | |||
| """ | |||
| def __init__(self, root, name): | |||
| self.name = name | |||
| self.url = ( | |||
| f"https://github.com/cenyk1230/han-data/blob/master/{name}.zip?raw=true" | |||
| ) | |||
| super(HANDataset, self).__init__(root) | |||
| self.data = torch.load(self.processed_paths[0]) | |||
| self.num_classes = torch.max(self.data.train_target).item() + 1 | |||
| self.num_edge = len(self.data.adj) | |||
| self.num_nodes = self.data.x.shape[0] | |||
| @property | |||
| def raw_file_names(self): | |||
| names = ["data.mat"] | |||
| return names | |||
| @property | |||
| def processed_file_names(self): | |||
| return ["data.pt"] | |||
| def read_gtn_data(self, folder): | |||
| data = sio.loadmat(osp.join(folder, "data.mat")) | |||
| if self.name == "han-acm" or self.name == "han-imdb": | |||
| truelabels, truefeatures = data["label"], data["feature"].astype(float) | |||
| elif self.name == "han-dblp": | |||
| truelabels, truefeatures = data["label"], data["features"].astype(float) | |||
| num_nodes = truefeatures.shape[0] | |||
| if self.name == "han-acm": | |||
| rownetworks = [ | |||
| data["PAP"] - np.eye(num_nodes), | |||
| data["PLP"] - np.eye(num_nodes), | |||
| ] | |||
| elif self.name == "han-dblp": | |||
| rownetworks = [ | |||
| data["net_APA"] - np.eye(num_nodes), | |||
| data["net_APCPA"] - np.eye(num_nodes), | |||
| data["net_APTPA"] - np.eye(num_nodes), | |||
| ] | |||
| elif self.name == "han-imdb": | |||
| rownetworks = [ | |||
| data["MAM"] - np.eye(num_nodes), | |||
| data["MDM"] - np.eye(num_nodes), | |||
| data["MYM"] - np.eye(num_nodes), | |||
| ] | |||
| y = truelabels | |||
| train_idx = data["train_idx"] | |||
| val_idx = data["val_idx"] | |||
| test_idx = data["test_idx"] | |||
| train_mask = sample_mask(train_idx, y.shape[0]) | |||
| val_mask = sample_mask(val_idx, y.shape[0]) | |||
| test_mask = sample_mask(test_idx, y.shape[0]) | |||
| y_train = np.argmax(y[train_mask, :], axis=1) | |||
| y_val = np.argmax(y[val_mask, :], axis=1) | |||
| y_test = np.argmax(y[test_mask, :], axis=1) | |||
| data = Data() | |||
| A = [] | |||
| for i, edge in enumerate(rownetworks): | |||
| edge_tmp = torch.from_numpy( | |||
| np.vstack((edge.nonzero()[0], edge.nonzero()[1])) | |||
| ).type(torch.LongTensor) | |||
| value_tmp = torch.ones(edge_tmp.shape[1]).type(torch.FloatTensor) | |||
| A.append((edge_tmp, value_tmp)) | |||
| edge_tmp = torch.stack( | |||
| (torch.arange(0, num_nodes), torch.arange(0, num_nodes)) | |||
| ).type(torch.LongTensor) | |||
| value_tmp = torch.ones(num_nodes).type(torch.FloatTensor) | |||
| A.append((edge_tmp, value_tmp)) | |||
| data.adj = A | |||
| data.x = torch.from_numpy(truefeatures).type(torch.FloatTensor) | |||
| data.train_node = torch.from_numpy(train_idx[0]).type(torch.LongTensor) | |||
| data.train_target = torch.from_numpy(y_train).type(torch.LongTensor) | |||
| data.valid_node = torch.from_numpy(val_idx[0]).type(torch.LongTensor) | |||
| data.valid_target = torch.from_numpy(y_val).type(torch.LongTensor) | |||
| data.test_node = torch.from_numpy(test_idx[0]).type(torch.LongTensor) | |||
| data.test_target = torch.from_numpy(y_test).type(torch.LongTensor) | |||
| self.data = data | |||
| def get(self, idx): | |||
| assert idx == 0 | |||
| return self.data | |||
| def apply_to_device(self, device): | |||
| self.data.x = self.data.x.to(device) | |||
| self.data.train_node = self.data.train_node.to(device) | |||
| self.data.valid_node = self.data.valid_node.to(device) | |||
| self.data.test_node = self.data.test_node.to(device) | |||
| self.data.train_target = self.data.train_target.to(device) | |||
| self.data.valid_target = self.data.valid_target.to(device) | |||
| self.data.test_target = self.data.test_target.to(device) | |||
| new_adj = [] | |||
| for (t1, t2) in self.data.adj: | |||
| new_adj.append((t1.to(device), t2.to(device))) | |||
| self.data.adj = new_adj | |||
| def download(self): | |||
| download_url(self.url, self.raw_dir, name=self.name + ".zip") | |||
| untar(self.raw_dir, self.name + ".zip") | |||
| def process(self): | |||
| self.read_gtn_data(self.raw_dir) | |||
| torch.save(self.data, self.processed_paths[0]) | |||
| def __repr__(self): | |||
| return "{}()".format(self.name) | |||
| @register_dataset("han-acm") | |||
| class ACM_HANDataset(HANDataset): | |||
| def __init__(self, path): | |||
| dataset = "han-acm" | |||
| # path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) | |||
| super(ACM_HANDataset, self).__init__(path, dataset) | |||
| @register_dataset("han-dblp") | |||
| class DBLP_HANDataset(HANDataset): | |||
| def __init__(self, path): | |||
| dataset = "han-dblp" | |||
| # path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) | |||
| super(DBLP_HANDataset, self).__init__(path, dataset) | |||
| @register_dataset("han-imdb") | |||
| class IMDB_HANDataset(HANDataset): | |||
| def __init__(self, path): | |||
| dataset = "han-imdb" | |||
| # path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) | |||
| super(IMDB_HANDataset, self).__init__(path, dataset) | |||
| @@ -1,96 +0,0 @@ | |||
| import os.path as osp | |||
| from itertools import product | |||
| import scipy.io | |||
| import torch | |||
| from ..data import Data, Dataset, download_url | |||
| from . import register_dataset | |||
| class MatlabMatrix(Dataset): | |||
| r"""The networks datasets "Blogcatalog", "Flickr", "Wikipedia" and "PPI" from the http://leitang.net/code/social-dimension/data/ or http://snap.stanford.edu/node2vec/ | |||
| Args: | |||
| root (string): Root directory where the dataset should be saved. | |||
| name (string): The name of the dataset (:obj:`"Blogcatalog"`). | |||
| """ | |||
| def __init__(self, root, name, url): | |||
| self.name = name | |||
| self.url = url | |||
| super(MatlabMatrix, self).__init__(root) | |||
| self.data = torch.load(self.processed_paths[0]) | |||
| @property | |||
| def raw_file_names(self): | |||
| splits = [self.name] | |||
| files = ["mat"] | |||
| return ["{}.{}".format(s, f) for s, f in product(splits, files)] | |||
| @property | |||
| def processed_file_names(self): | |||
| return ["data.pt"] | |||
| def download(self): | |||
| for name in self.raw_file_names: | |||
| download_url("{}{}".format(self.url, name), self.raw_dir) | |||
| def get(self, idx): | |||
| assert idx == 0 | |||
| return self.data | |||
| def process(self): | |||
| path = osp.join(self.raw_dir, "{}.mat".format(self.name)) | |||
| smat = scipy.io.loadmat(path) | |||
| adj_matrix, group = smat["network"], smat["group"] | |||
| y = torch.from_numpy(group.todense()).to(torch.float) | |||
| row_ind, col_ind = adj_matrix.nonzero() | |||
| edge_index = torch.stack([torch.tensor(row_ind), torch.tensor(col_ind)], dim=0) | |||
| edge_attr = torch.tensor(adj_matrix[row_ind, col_ind]) | |||
| data = Data(edge_index=edge_index, edge_attr=edge_attr, x=None, y=y) | |||
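| # 'group' is a num_nodes x num_classes multi-label indicator matrix, so y | |||
| # is a dense float tensor, and edge_attr carries the adjacency's nonzero | |||
| # values. | |||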
| torch.save(data, self.processed_paths[0]) | |||
| @register_dataset("blogcatalog") | |||
| class BlogcatalogDataset(MatlabMatrix): | |||
| def __init__(self, path): | |||
| dataset, filename = "blogcatalog", "blogcatalog" | |||
| url = "http://leitang.net/code/social-dimension/data/" | |||
| # path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) | |||
| super(BlogcatalogDataset, self).__init__(path, filename, url) | |||
| # @register_dataset("flickr") | |||
| # class FlickrDataset(MatlabMatrix): | |||
| # def __init__(self, path): | |||
| # dataset, filename = "flickr", "flickr" | |||
| # url = "http://leitang.net/code/social-dimension/data/" | |||
| # # path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) | |||
| # super(FlickrDataset, self).__init__(path, filename, url) | |||
| @register_dataset("wikipedia") | |||
| class WikipediaDataset(MatlabMatrix): | |||
| def __init__(self, path): | |||
| dataset, filename = "wikipedia", "POS" | |||
| url = "http://snap.stanford.edu/node2vec/" | |||
| # path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) | |||
| super(WikipediaDataset, self).__init__(path, filename, url) | |||
| @register_dataset("ppi") | |||
| class PPIDataset(MatlabMatrix): | |||
| def __init__(self, path): | |||
| dataset, filename = "ppi", "Homo_sapiens" | |||
| url = "http://snap.stanford.edu/node2vec/" | |||
| # path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) | |||
| super(PPIDataset, self).__init__(path, filename, url) | |||
| @@ -1,70 +0,0 @@ | |||
| # import os.path as osp | |||
| # import torch_geometric.transforms as T | |||
| from torch_geometric.datasets import ModelNet | |||
| from . import register_dataset | |||
| class ModelNet10(ModelNet): | |||
| def __init__(self, path: str, train: bool): | |||
| # pre_transform, transform = T.NormalizeScale(), T.SamplePoints(1024) | |||
| # path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) | |||
| super(ModelNet10, self).__init__(path, name="10", train=train) | |||
| class ModelNet40(ModelNet): | |||
| def __init__(self, path: str, train: bool): | |||
| # pre_transform, transform = T.NormalizeScale(), T.SamplePoints(1024) | |||
| # path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) | |||
| super(ModelNet40, self).__init__(path, name="40", train=train) | |||
| @register_dataset("ModelNet10Train") | |||
| class ModelNet10Train(ModelNet): | |||
| def __init__(self, path: str): | |||
| super(ModelNet10Train, self).__init__(path, "10", train=True) | |||
| def get(self, idx): | |||
| if hasattr(self, "__data_list__"): | |||
| delattr(self, "__data_list__") | |||
| if hasattr(self, "_data_list"): | |||
| delattr(self, "_data_list") | |||
| return super(ModelNet10Train, self).get(idx) | |||
| @register_dataset("ModelNet10Test") | |||
| class ModelNet10Test(ModelNet): | |||
| def __init__(self, path: str): | |||
| super(ModelNet10Test, self).__init__(path, "10", train=False) | |||
| def get(self, idx): | |||
| if hasattr(self, "__data_list__"): | |||
| delattr(self, "__data_list__") | |||
| if hasattr(self, "_data_list"): | |||
| delattr(self, "_data_list") | |||
| return super(ModelNet10Test, self).get(idx) | |||
| @register_dataset("ModelNet40Train") | |||
| class ModelNet40Train(ModelNet): | |||
| def __init__(self, path: str): | |||
| super(ModelNet40Train, self).__init__(path, "40", train=True) | |||
| def get(self, idx): | |||
| if hasattr(self, "__data_list__"): | |||
| delattr(self, "__data_list__") | |||
| if hasattr(self, "_data_list"): | |||
| delattr(self, "_data_list") | |||
| return super(ModelNet40Train, self).get(idx) | |||
| @register_dataset("ModelNet40Test") | |||
| class ModelNet40Test(ModelNet): | |||
| def __init__(self, path: str): | |||
| super(ModelNet40Test, self).__init__(path, "40", train=False) | |||
| def get(self, idx): | |||
| if hasattr(self, "__data_list__"): | |||
| delattr(self, "__data_list__") | |||
| if hasattr(self, "_data_list"): | |||
| delattr(self, "_data_list") | |||
| return super(ModelNet40Test, self).get(idx) | |||
| @@ -1,358 +0,0 @@ | |||
| import torch_geometric.transforms as T | |||
| from ogb.nodeproppred import PygNodePropPredDataset | |||
| from ogb.graphproppred import PygGraphPropPredDataset | |||
| from ogb.linkproppred import PygLinkPropPredDataset | |||
| from . import register_dataset | |||
| from .utils import index_to_mask | |||
| from torch_geometric.data import Data | |||
| # OGBN | |||
| @register_dataset("ogbn-products") | |||
| class OGBNproductsDataset(PygNodePropPredDataset): | |||
| def __init__(self, path): | |||
| dataset = "ogbn-products" | |||
| # path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) | |||
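| # Constructing the dataset once up front (and discarding the instance) | |||
| # triggers OGB's download/processing so the superclass __init__ below finds | |||
| # the files in place; the same pattern recurs throughout this module. | |||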
| PygNodePropPredDataset(name=dataset, root=path) | |||
| super(OGBNproductsDataset, self).__init__(dataset, path) | |||
| # Pre-compute GCN normalization. | |||
| # adj_t = self.data.adj_t.set_diag() | |||
| # deg = adj_t.sum(dim=1).to(torch.float) | |||
| # deg_inv_sqrt = deg.pow(-0.5) | |||
| # deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 | |||
| # adj_t = deg_inv_sqrt.view(-1, 1) * adj_t * deg_inv_sqrt.view(1, -1) | |||
| # self.data.adj_t = adj_t | |||
| setattr(OGBNproductsDataset, "metric", "Accuracy") | |||
| setattr(OGBNproductsDataset, "loss", "nll_loss") | |||
| split_idx = self.get_idx_split() | |||
| datalist = [] | |||
| for d in self: | |||
| setattr(d, "train_mask", index_to_mask(split_idx["train"], d.y.shape[0])) | |||
| setattr(d, "val_mask", index_to_mask(split_idx["valid"], d.y.shape[0])) | |||
| setattr(d, "test_mask", index_to_mask(split_idx["test"], d.y.shape[0])) | |||
| datalist.append(d) | |||
| self.data, self.slices = self.collate(datalist) | |||
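| # index_to_mask (imported from .utils) presumably turns a LongTensor of | |||
| # indices into a boolean mask of the given length; a minimal equivalent: | |||
| #     def index_to_mask(index, size): | |||
| #         mask = torch.zeros(size, dtype=torch.bool) | |||
| #         mask[index] = True | |||
| #         return mask | |||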
| def get(self, idx): | |||
| if hasattr(self, "__data_list__"): | |||
| delattr(self, "__data_list__") | |||
| if hasattr(self, "_data_list"): | |||
| delattr(self, "_data_list") | |||
| return super(OGBNproductsDataset, self).get(idx) | |||
| @register_dataset("ogbn-proteins") | |||
| class OGBNproteinsDataset(PygNodePropPredDataset): | |||
| def __init__(self, path): | |||
| dataset = "ogbn-proteins" | |||
| # path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) | |||
| PygNodePropPredDataset(name=dataset, root=path) | |||
| super(OGBNproteinsDataset, self).__init__(dataset, path) | |||
| dataset_t = PygNodePropPredDataset( | |||
| name=dataset, root=path, transform=T.ToSparseTensor() | |||
| ) | |||
| # Move edge features to node features. | |||
| self.data.x = dataset_t[0].adj_t.mean(dim=1) | |||
| # dataset_t[0].adj_t.set_value_(None) | |||
| del dataset_t | |||
| setattr(OGBNproteinsDataset, "metric", "ROC-AUC") | |||
| setattr(OGBNproteinsDataset, "loss", "binary_cross_entropy_with_logits") | |||
| split_idx = self.get_idx_split() | |||
| datalist = [] | |||
| for d in self: | |||
| setattr(d, "train_mask", index_to_mask(split_idx["train"], d.y.shape[0])) | |||
| setattr(d, "val_mask", index_to_mask(split_idx["valid"], d.y.shape[0])) | |||
| setattr(d, "test_mask", index_to_mask(split_idx["test"], d.y.shape[0])) | |||
| datalist.append(d) | |||
| self.data, self.slices = self.collate(datalist) | |||
| def get(self, idx): | |||
| if hasattr(self, "__data_list__"): | |||
| delattr(self, "__data_list__") | |||
| if hasattr(self, "_data_list"): | |||
| delattr(self, "_data_list") | |||
| return super(OGBNproteinsDataset, self).get(idx) | |||
| @register_dataset("ogbn-arxiv") | |||
| class OGBNarxivDataset(PygNodePropPredDataset): | |||
| def __init__(self, path): | |||
| dataset = "ogbn-arxiv" | |||
| # path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) | |||
| PygNodePropPredDataset(name=dataset, root=path) | |||
| super(OGBNarxivDataset, self).__init__(dataset, path) | |||
| setattr(OGBNarxivDataset, "metric", "Accuracy") | |||
| setattr(OGBNarxivDataset, "loss", "nll_loss") | |||
| split_idx = self.get_idx_split() | |||
| datalist = [] | |||
| for d in self: | |||
| setattr(d, "train_mask", index_to_mask(split_idx["train"], d.y.shape[0])) | |||
| setattr(d, "val_mask", index_to_mask(split_idx["valid"], d.y.shape[0])) | |||
| setattr(d, "test_mask", index_to_mask(split_idx["test"], d.y.shape[0])) | |||
| datalist.append(d) | |||
| self.data, self.slices = self.collate(datalist) | |||
| def get(self, idx): | |||
| if hasattr(self, "__data_list__"): | |||
| delattr(self, "__data_list__") | |||
| if hasattr(self, "_data_list"): | |||
| delattr(self, "_data_list") | |||
| return super(OGBNarxivDataset, self).get(idx) | |||
| @register_dataset("ogbn-papers100M") | |||
| class OGBNpapers100MDataset(PygNodePropPredDataset): | |||
| def __init__(self, path): | |||
| dataset = "ogbn-papers100M" | |||
| # path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) | |||
| PygNodePropPredDataset(name=dataset, root=path) | |||
| super(OGBNpapers100MDataset, self).__init__(dataset, path) | |||
| setattr(OGBNpapers100MDataset, "metric", "Accuracy") | |||
| setattr(OGBNpapers100MDataset, "loss", "nll_loss") | |||
| split_idx = self.get_idx_split() | |||
| datalist = [] | |||
| for d in self: | |||
| setattr(d, "train_mask", index_to_mask(split_idx["train"], d.y.shape[0])) | |||
| setattr(d, "val_mask", index_to_mask(split_idx["valid"], d.y.shape[0])) | |||
| setattr(d, "test_mask", index_to_mask(split_idx["test"], d.y.shape[0])) | |||
| datalist.append(d) | |||
| self.data, self.slices = self.collate(datalist) | |||
| def get(self, idx): | |||
| if hasattr(self, "__data_list__"): | |||
| delattr(self, "__data_list__") | |||
| if hasattr(self, "_data_list"): | |||
| delattr(self, "_data_list") | |||
| return super(OGBNpapers100MDataset, self).get(idx) | |||
| @register_dataset("ogbn-mag") | |||
| class OGBNmagDataset(PygNodePropPredDataset): | |||
| def __init__(self, path): | |||
| dataset = "ogbn-mag" | |||
| # path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) | |||
| PygNodePropPredDataset(name=dataset, root=path) | |||
| super(OGBNmagDataset, self).__init__(dataset, path) | |||
| # Preprocess | |||
| rel_data = self[0] | |||
| # We are only interested in paper <-> paper relations. | |||
| self.data = Data( | |||
| x=rel_data.x_dict["paper"], | |||
| edge_index=rel_data.edge_index_dict[("paper", "cites", "paper")], | |||
| y=rel_data.y_dict["paper"], | |||
| ) | |||
| # self.data = T.ToSparseTensor()(data) | |||
| # self[0].adj_t = self[0].adj_t.to_symmetric() | |||
| setattr(OGBNmagDataset, "metric", "Accuracy") | |||
| setattr(OGBNmagDataset, "loss", "nll_loss") | |||
| split_idx = self.get_idx_split() | |||
| datalist = [] | |||
| for d in self: | |||
| setattr(d, "train_mask", index_to_mask(split_idx["train"], d.y.shape[0])) | |||
| setattr(d, "val_mask", index_to_mask(split_idx["valid"], d.y.shape[0])) | |||
| setattr(d, "test_mask", index_to_mask(split_idx["test"], d.y.shape[0])) | |||
| datalist.append(d) | |||
| self.data, self.slices = self.collate(datalist) | |||
| def get(self, idx): | |||
| if hasattr(self, "__data_list__"): | |||
| delattr(self, "__data_list__") | |||
| if hasattr(self, "_data_list"): | |||
| delattr(self, "_data_list") | |||
| return super(OGBNmagDataset, self).get(idx) | |||
| # OGBG | |||
| @register_dataset("ogbg-molhiv") | |||
| class OGBGmolhivDataset(PygGraphPropPredDataset): | |||
| def __init__(self, path): | |||
| dataset = "ogbg-molhiv" | |||
| # path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) | |||
| PygGraphPropPredDataset(name=dataset, root=path) | |||
| super(OGBGmolhivDataset, self).__init__(dataset, path) | |||
| setattr(OGBGmolhivDataset, "metric", "ROC-AUC") | |||
| setattr(OGBGmolhivDataset, "loss", "binary_cross_entropy_with_logits") | |||
| def get(self, idx): | |||
| if hasattr(self, "__data_list__"): | |||
| delattr(self, "__data_list__") | |||
| if hasattr(self, "_data_list"): | |||
| delattr(self, "_data_list") | |||
| return super(OGBGmolhivDataset, self).get(idx) | |||
| @register_dataset("ogbg-molpcba") | |||
| class OGBGmolpcbaDataset(PygGraphPropPredDataset): | |||
| def __init__(self, path): | |||
| dataset = "ogbg-molpcba" | |||
| # path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) | |||
| PygGraphPropPredDataset(name=dataset, root=path) | |||
| super(OGBGmolpcbaDataset, self).__init__(dataset, path) | |||
| setattr(OGBGmolpcbaDataset, "metric", "AP") | |||
| setattr(OGBGmolpcbaDataset, "loss", "binary_cross_entropy_with_logits") | |||
| def get(self, idx): | |||
| if hasattr(self, "__data_list__"): | |||
| delattr(self, "__data_list__") | |||
| if hasattr(self, "_data_list"): | |||
| delattr(self, "_data_list") | |||
| return super(OGBGmolpcbaDataset, self).get(idx) | |||
| @register_dataset("ogbg-ppa") | |||
| class OGBGppaDataset(PygGraphPropPredDataset): | |||
| def __init__(self, path): | |||
| dataset = "ogbg-ppa" | |||
| # path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) | |||
| PygGraphPropPredDataset(name=dataset, root=path) | |||
| super(OGBGppaDataset, self).__init__(dataset, path) | |||
| setattr(OGBGppaDataset, "metric", "Accuracy") | |||
| setattr(OGBGppaDataset, "loss", "cross_entropy") | |||
| def get(self, idx): | |||
| if hasattr(self, "__data_list__"): | |||
| delattr(self, "__data_list__") | |||
| if hasattr(self, "_data_list"): | |||
| delattr(self, "_data_list") | |||
| return super(OGBGppaDataset, self).get(idx) | |||
| @register_dataset("ogbg-code") | |||
| class OGBGcodeDataset(PygGraphPropPredDataset): | |||
| def __init__(self, path): | |||
| dataset = "ogbg-code" | |||
| # path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) | |||
| PygGraphPropPredDataset(name=dataset, root=path) | |||
| super(OGBGcodeDataset, self).__init__(dataset, path) | |||
| setattr(OGBGcodeDataset, "metric", "F1 score") | |||
| setattr(OGBGcodeDataset, "loss", "cross_entropy") | |||
| def get(self, idx): | |||
| if hasattr(self, "__data_list__"): | |||
| delattr(self, "__data_list__") | |||
| if hasattr(self, "_data_list"): | |||
| delattr(self, "_data_list") | |||
| return super(OGBGcodeDataset, self).get(idx) | |||
| # OGBL | |||
| @register_dataset("ogbl-ppa") | |||
| class OGBLppaDataset(PygLinkPropPredDataset): | |||
| def __init__(self, path): | |||
| dataset = "ogbl-ppa" | |||
| # path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) | |||
| PygLinkPropPredDataset(name=dataset, root=path) | |||
| super(OGBLppaDataset, self).__init__(dataset, path) | |||
| setattr(OGBLppaDataset, "metric", "Hits@100") | |||
| setattr(OGBLppaDataset, "loss", "pos_neg_loss") | |||
| def get(self, idx): | |||
| if hasattr(self, "__data_list__"): | |||
| delattr(self, "__data_list__") | |||
| if hasattr(self, "_data_list"): | |||
| delattr(self, "_data_list") | |||
| return super(OGBLppaDataset, self).get(idx) | |||
| @register_dataset("ogbl-collab") | |||
| class OGBLcollabDataset(PygLinkPropPredDataset): | |||
| def __init__(self, path): | |||
| dataset = "ogbl-collab" | |||
| # path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) | |||
| PygLinkPropPredDataset(name=dataset, root=path) | |||
| super(OGBLcollabDataset, self).__init__(dataset, path) | |||
| setattr(OGBLcollabDataset, "metric", "Hits@50") | |||
| setattr(OGBLcollabDataset, "loss", "pos_neg_loss") | |||
| def get(self, idx): | |||
| if hasattr(self, "__data_list__"): | |||
| delattr(self, "__data_list__") | |||
| if hasattr(self, "_data_list"): | |||
| delattr(self, "_data_list") | |||
| return super(OGBLcollabDataset, self).get(idx) | |||
| @register_dataset("ogbl-ddi") | |||
| class OGBLddiDataset(PygLinkPropPredDataset): | |||
| def __init__(self, path): | |||
| dataset = "ogbl-ddi" | |||
| # path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) | |||
| PygLinkPropPredDataset(name=dataset, root=path) | |||
| super(OGBLddiDataset, self).__init__(dataset, path) | |||
| setattr(OGBLddiDataset, "metric", "Hits@20") | |||
| setattr(OGBLddiDataset, "loss", "pos_neg_loss") | |||
| def get(self, idx): | |||
| if hasattr(self, "__data_list__"): | |||
| delattr(self, "__data_list__") | |||
| if hasattr(self, "_data_list"): | |||
| delattr(self, "_data_list") | |||
| return super(OGBLddiDataset, self).get(idx) | |||
| @register_dataset("ogbl-citation") | |||
| class OGBLcitationDataset(PygLinkPropPredDataset): | |||
| def __init__(self, path): | |||
| dataset = "ogbl-citation" | |||
| # path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) | |||
| PygLinkPropPredDataset(name=dataset, root=path) | |||
| super(OGBLcitationDataset, self).__init__(dataset, path) | |||
| setattr(OGBLcitationDataset, "metric", "MRR") | |||
| setattr(OGBLcitationDataset, "loss", "pos_neg_loss") | |||
| def get(self, idx): | |||
| if hasattr(self, "__data_list__"): | |||
| delattr(self, "__data_list__") | |||
| if hasattr(self, "_data_list"): | |||
| delattr(self, "_data_list") | |||
| return super(OGBLcitationDataset, self).get(idx) | |||
| @register_dataset("ogbl-wikikg") | |||
| class OGBLwikikgDataset(PygLinkPropPredDataset): | |||
| def __init__(self, path): | |||
| dataset = "ogbl-wikikg" | |||
| # path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) | |||
| PygLinkPropPredDataset(name=dataset, root=path) | |||
| super(OGBLwikikgDataset, self).__init__(dataset, path) | |||
| setattr(OGBLwikikgDataset, "metric", "MRR") | |||
| setattr(OGBLwikikgDataset, "loss", "pos_neg_loss") | |||
| def get(self, idx): | |||
| if hasattr(self, "__data_list__"): | |||
| delattr(self, "__data_list__") | |||
| if hasattr(self, "_data_list"): | |||
| delattr(self, "_data_list") | |||
| return super(OGBLwikikgDataset, self).get(idx) | |||
| @register_dataset("ogbl-biokg") | |||
| class OGBLbiokgDataset(PygLinkPropPredDataset): | |||
| def __init__(self, path): | |||
| dataset = "ogbl-biokg" | |||
| # path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) | |||
| PygLinkPropPredDataset(name=dataset, root=path) | |||
| super(OGBLbiokgDataset, self).__init__(dataset, path) | |||
| setattr(OGBLbiokgDataset, "metric", "MRR") | |||
| setattr(OGBLbiokgDataset, "loss", "pos_neg_loss") | |||
| def get(self, idx): | |||
| if hasattr(self, "__data_list__"): | |||
| delattr(self, "__data_list__") | |||
| if hasattr(self, "_data_list"): | |||
| delattr(self, "_data_list") | |||
| return super(OGBLbiokgDataset, self).get(idx) | |||
| @@ -1,407 +0,0 @@ | |||
| import os.path as osp | |||
| import torch | |||
| # import torch_geometric.transforms as T | |||
| from torch_geometric.datasets import ( | |||
| Planetoid, | |||
| Reddit, | |||
| TUDataset, | |||
| QM9, | |||
| Amazon, | |||
| Coauthor, | |||
| Flickr, | |||
| ) | |||
| from torch_geometric.utils import remove_self_loops | |||
| from . import register_dataset | |||
| @register_dataset("amazon_computers") | |||
| class AmazonComputersDataset(Amazon): | |||
| def __init__(self, path): | |||
| dataset = "Computers" | |||
| # path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) | |||
| Amazon(path, dataset) | |||
| super(AmazonComputersDataset, self).__init__(path, dataset) | |||
| def get(self, idx): | |||
| if hasattr(self, "__data_list__"): | |||
| delattr(self, "__data_list__") | |||
| if hasattr(self, "_data_list"): | |||
| delattr(self, "_data_list") | |||
| return super(AmazonComputersDataset, self).get(idx) | |||
| @register_dataset("amazon_photo") | |||
| class AmazonPhotoDataset(Amazon): | |||
| def __init__(self, path): | |||
| dataset = "Photo" | |||
| # path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) | |||
| Amazon(path, dataset) | |||
| super(AmazonPhotoDataset, self).__init__(path, dataset) | |||
| def get(self, idx): | |||
| if hasattr(self, "__data_list__"): | |||
| delattr(self, "__data_list__") | |||
| if hasattr(self, "_data_list"): | |||
| delattr(self, "_data_list") | |||
| return super(AmazonPhotoDataset, self).get(idx) | |||
| @register_dataset("coauthor_physics") | |||
| class CoauthorPhysicsDataset(Coauthor): | |||
| def __init__(self, path): | |||
| dataset = "Physics" | |||
| # path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) | |||
| Coauthor(path, dataset) | |||
| super(CoauthorPhysicsDataset, self).__init__(path, dataset) | |||
| def get(self, idx): | |||
| if hasattr(self, "__data_list__"): | |||
| delattr(self, "__data_list__") | |||
| if hasattr(self, "_data_list"): | |||
| delattr(self, "_data_list") | |||
| return super(CoauthorPhysicsDataset, self).get(idx) | |||
| @register_dataset("coauthor_cs") | |||
| class CoauthorCSDataset(Coauthor): | |||
| def __init__(self, path): | |||
| dataset = "CS" | |||
| # path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) | |||
| Coauthor(path, dataset) | |||
| super(CoauthorCSDataset, self).__init__(path, dataset) | |||
| def get(self, idx): | |||
| if hasattr(self, "__data_list__"): | |||
| delattr(self, "__data_list__") | |||
| if hasattr(self, "_data_list"): | |||
| delattr(self, "_data_list") | |||
| return super(CoauthorCSDataset, self).get(idx) | |||
| @register_dataset("cora") | |||
| class CoraDataset(Planetoid): | |||
| def __init__(self, path): | |||
| dataset = "Cora" | |||
| # path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) | |||
| Planetoid(path, dataset) | |||
| super(CoraDataset, self).__init__(path, dataset) | |||
| def get(self, idx): | |||
| if hasattr(self, "__data_list__"): | |||
| delattr(self, "__data_list__") | |||
| if hasattr(self, "_data_list"): | |||
| delattr(self, "_data_list") | |||
| return super(CoraDataset, self).get(idx) | |||
| @register_dataset("citeseer") | |||
| class CiteSeerDataset(Planetoid): | |||
| def __init__(self, path): | |||
| dataset = "CiteSeer" | |||
| # path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) | |||
| Planetoid(path, dataset) | |||
| super(CiteSeerDataset, self).__init__(path, dataset) | |||
| def get(self, idx): | |||
| if hasattr(self, "__data_list__"): | |||
| delattr(self, "__data_list__") | |||
| if hasattr(self, "_data_list"): | |||
| delattr(self, "_data_list") | |||
| return super(CiteSeerDataset, self).get(idx) | |||
| @register_dataset("pubmed") | |||
| class PubMedDataset(Planetoid): | |||
| def __init__(self, path): | |||
| dataset = "PubMed" | |||
| # path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) | |||
| Planetoid(path, dataset) | |||
| super(PubMedDataset, self).__init__(path, dataset) | |||
| def get(self, idx): | |||
| if hasattr(self, "__data_list__"): | |||
| delattr(self, "__data_list__") | |||
| if hasattr(self, "_data_list"): | |||
| delattr(self, "_data_list") | |||
| return super(PubMedDataset, self).get(idx) | |||
| @register_dataset("reddit") | |||
| class RedditDataset(Reddit): | |||
| def __init__(self, path): | |||
| dataset = "Reddit" | |||
| # path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) | |||
| Reddit(path) | |||
| super(RedditDataset, self).__init__(path) | |||
| def get(self, idx): | |||
| if hasattr(self, "__data_list__"): | |||
| delattr(self, "__data_list__") | |||
| if hasattr(self, "_data_list"): | |||
| delattr(self, "_data_list") | |||
| return super(RedditDataset, self).get(idx) | |||
| @register_dataset("flickr") | |||
| class FlickrDataset(Flickr): | |||
| def __init__(self, path): | |||
| Flickr(path) | |||
| super(FlickrDataset, self).__init__(path) | |||
| def get(self, idx): | |||
| if hasattr(self, "__data_list__"): | |||
| delattr(self, "__data_list__") | |||
| if hasattr(self, "_data_list"): | |||
| delattr(self, "_data_list") | |||
| return super(FlickrDataset, self).get(idx) | |||
| @register_dataset("mutag") | |||
| class MUTAGDataset(TUDataset): | |||
| def __init__(self, path): | |||
| dataset = "MUTAG" | |||
| # path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) | |||
| TUDataset(path, name=dataset) | |||
| super(MUTAGDataset, self).__init__(path, name=dataset) | |||
| def get(self, idx): | |||
| if hasattr(self, "__data_list__"): | |||
| delattr(self, "__data_list__") | |||
| if hasattr(self, "_data_list"): | |||
| delattr(self, "_data_list") | |||
| return super(MUTAGDataset, self).get(idx) | |||
| @register_dataset("imdb-b") | |||
| class IMDBBinaryDataset(TUDataset): | |||
| def __init__(self, path): | |||
| dataset = "IMDB-BINARY" | |||
| # path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) | |||
| TUDataset(path, name=dataset) | |||
| super(IMDBBinaryDataset, self).__init__(path, name=dataset) | |||
| def get(self, idx): | |||
| if hasattr(self, "__data_list__"): | |||
| delattr(self, "__data_list__") | |||
| if hasattr(self, "_data_list"): | |||
| delattr(self, "_data_list") | |||
| return super(IMDBBinaryDataset, self).get(idx) | |||
| @register_dataset("imdb-m") | |||
| class IMDBMultiDataset(TUDataset): | |||
| def __init__(self, path): | |||
| dataset = "IMDB-MULTI" | |||
| # path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) | |||
| TUDataset(path, name=dataset) | |||
| super(IMDBMultiDataset, self).__init__(path, name=dataset) | |||
| def get(self, idx): | |||
| if hasattr(self, "__data_list__"): | |||
| delattr(self, "__data_list__") | |||
| if hasattr(self, "_data_list"): | |||
| delattr(self, "_data_list") | |||
| return super(IMDBMultiDataset, self).get(idx) | |||
| @register_dataset("collab") | |||
| class CollabDataset(TUDataset): | |||
| def __init__(self, path): | |||
| dataset = "COLLAB" | |||
| # path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) | |||
| TUDataset(path, name=dataset) | |||
| super(CollabDataset, self).__init__(path, name=dataset) | |||
| def get(self, idx): | |||
| if hasattr(self, "__data_list__"): | |||
| delattr(self, "__data_list__") | |||
| if hasattr(self, "_data_list"): | |||
| delattr(self, "_data_list") | |||
| return super(CollabDataset, self).get(idx) | |||
| @register_dataset("proteins") | |||
| class ProteinsDataset(TUDataset): | |||
| def __init__(self, path): | |||
| dataset = "PROTEINS" | |||
| # path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) | |||
| TUDataset(path, name=dataset) | |||
| super(ProteinsDataset, self).__init__(path, name=dataset) | |||
| def get(self, idx): | |||
| if hasattr(self, "__data_list__"): | |||
| delattr(self, "__data_list__") | |||
| if hasattr(self, "_data_list"): | |||
| delattr(self, "_data_list") | |||
| return super(ProteinsDataset, self).get(idx) | |||
| @register_dataset("reddit-b") | |||
| class REDDITBinary(TUDataset): | |||
| def __init__(self, path): | |||
| dataset = "REDDIT-BINARY" | |||
| # path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) | |||
| TUDataset(path, name=dataset) | |||
| super(REDDITBinary, self).__init__(path, name=dataset) | |||
| def get(self, idx): | |||
| if hasattr(self, "__data_list__"): | |||
| delattr(self, "__data_list__") | |||
| if hasattr(self, "_data_list"): | |||
| delattr(self, "_data_list") | |||
| return super(REDDITBinary, self).get(idx) | |||
| @register_dataset("reddit-multi-5k") | |||
| class REDDITMulti5K(TUDataset): | |||
| def __init__(self, path): | |||
| dataset = "REDDIT-MULTI-5K" | |||
| # path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) | |||
| TUDataset(path, name=dataset) | |||
| super(REDDITMulti5K, self).__init__(path, name=dataset) | |||
| def get(self, idx): | |||
| if hasattr(self, "__data_list__"): | |||
| delattr(self, "__data_list__") | |||
| if hasattr(self, "_data_list"): | |||
| delattr(self, "_data_list") | |||
| return super(REDDITMulti5K, self).get(idx) | |||
| @register_dataset("reddit-multi-12k") | |||
| class REDDITMulti12K(TUDataset): | |||
| def __init__(self, path): | |||
| dataset = "REDDIT-MULTI-12K" | |||
| # path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) | |||
| TUDataset(path, name=dataset) | |||
| super(REDDITMulti12K, self).__init__(path, name=dataset) | |||
| def get(self, idx): | |||
| if hasattr(self, "__data_list__"): | |||
| delattr(self, "__data_list__") | |||
| if hasattr(self, "_data_list"): | |||
| delattr(self, "_data_list") | |||
| return super(REDDITMulti12K, self).get(idx) | |||
| @register_dataset("ptc-mr") | |||
| class PTCMRDataset(TUDataset): | |||
| def __init__(self, path): | |||
| dataset = "PTC_MR" | |||
| # path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) | |||
| TUDataset(path, name=dataset) | |||
| super(PTCMRDataset, self).__init__(path, name=dataset) | |||
| def get(self, idx): | |||
| if hasattr(self, "__data_list__"): | |||
| delattr(self, "__data_list__") | |||
| if hasattr(self, "_data_list"): | |||
| delattr(self, "_data_list") | |||
| return super(PTCMRDataset, self).get(idx) | |||
| @register_dataset("nci1") | |||
| class NCI1Dataset(TUDataset): | |||
| def __init__(self, path): | |||
| dataset = "NCI1" | |||
| # path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) | |||
| TUDataset(path, name=dataset) | |||
| super(NCI1Dataset, self).__init__(path, name=dataset) | |||
| def get(self, idx): | |||
| if hasattr(self, "__data_list__"): | |||
| delattr(self, "__data_list__") | |||
| if hasattr(self, "_data_list"): | |||
| delattr(self, "_data_list") | |||
| return super(NCI1Dataset, self).get(idx) | |||
| @register_dataset("nci109") | |||
| class NCI109Dataset(TUDataset): | |||
| def __init__(self, path): | |||
| dataset = "NCI109" | |||
| # path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) | |||
| TUDataset(path, name=dataset) | |||
| super(NCI109Dataset, self).__init__(path, name=dataset) | |||
| def get(self, idx): | |||
| if hasattr(self, "__data_list__"): | |||
| delattr(self, "__data_list__") | |||
| if hasattr(self, "_data_list"): | |||
| delattr(self, "_data_list") | |||
| return super(NCI109Dataset, self).get(idx) | |||
| @register_dataset("enzymes") | |||
| class ENZYMES(TUDataset): | |||
| def __init__(self, path): | |||
| dataset = "ENZYMES" | |||
| # path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) | |||
| TUDataset(path, name=dataset) | |||
| super(ENZYMES, self).__init__(path, name=dataset) | |||
| def __getitem__(self, idx): | |||
| if isinstance(idx, int): | |||
| data = self.get(self.indices()[idx]) | |||
| # drop trailing feature rows for nodes never referenced by any edge | |||
| edge_nodes = data.edge_index.max() + 1 | |||
| if edge_nodes < data.x.size(0): | |||
| data.x = data.x[:edge_nodes] | |||
| return data | |||
| else: | |||
| return self.index_select(idx) | |||
| def get(self, idx): | |||
| if hasattr(self, "__data_list__"): | |||
| delattr(self, "__data_list__") | |||
| if hasattr(self, "_data_list"): | |||
| delattr(self, "_data_list") | |||
| return super(ENZYMES, self).get(idx) | |||
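| # Sketch of why ENZYMES overrides __getitem__ (path is an assumed cache | |||
| # directory): some stored graphs carry feature rows for trailing nodes that | |||
| # never appear in edge_index, so x is trimmed to the highest referenced index + 1. | |||
| enzymes = ENZYMES("./data/ENZYMES") | |||
| enzymes_graph = enzymes[0] | |||
| assert enzymes_graph.x.size(0) <= enzymes_graph.edge_index.max().item() + 1 | |||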
| @register_dataset("qm9") | |||
| class QM9Dataset(QM9): | |||
| def __init__(self, path): | |||
| dataset = "QM9" | |||
| # path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset) | |||
| target = 0 | |||
| # NOTE: MyTransform and Complete are defined here for reference but are not | |||
| # passed to the underlying QM9 constructor below. | |||
| class MyTransform(object): | |||
| def __call__(self, data): | |||
| # Specify target. | |||
| data.y = data.y[:, target] | |||
| return data | |||
| class Complete(object): | |||
| def __call__(self, data): | |||
| device = data.edge_index.device | |||
| row = torch.arange(data.num_nodes, dtype=torch.long, device=device) | |||
| col = torch.arange(data.num_nodes, dtype=torch.long, device=device) | |||
| row = row.view(-1, 1).repeat(1, data.num_nodes).view(-1) | |||
| col = col.repeat(data.num_nodes) | |||
| edge_index = torch.stack([row, col], dim=0) | |||
| edge_attr = None | |||
| if data.edge_attr is not None: | |||
| idx = data.edge_index[0] * data.num_nodes + data.edge_index[1] | |||
| size = list(data.edge_attr.size()) | |||
| size[0] = data.num_nodes * data.num_nodes | |||
| edge_attr = data.edge_attr.new_zeros(size) | |||
| edge_attr[idx] = data.edge_attr | |||
| edge_index, edge_attr = remove_self_loops(edge_index, edge_attr) | |||
| data.edge_attr = edge_attr | |||
| data.edge_index = edge_index | |||
| return data | |||
| if not osp.exists(path): | |||
| QM9(path) | |||
| super(QM9Dataset, self).__init__(path) | |||
| def get(self, idx): | |||
| if hasattr(self, "__data_list__"): | |||
| delattr(self, "__data_list__") | |||
| if hasattr(self, "_data_list"): | |||
| delattr(self, "_data_list") | |||
| return super(QM9Dataset, self).get(idx) | |||
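| # Worked sketch of the Complete transform above (illustrative; relies on the | |||
| # torch / remove_self_loops imports already used in this file): for n = 3 it | |||
| # rebuilds edge_index as every ordered pair minus the self-loops. | |||
| n = 3 | |||
| demo_row = torch.arange(n).view(-1, 1).repeat(1, n).view(-1)  # 0 0 0 1 1 1 2 2 2 | |||
| demo_col = torch.arange(n).repeat(n)                          # 0 1 2 0 1 2 0 1 2 | |||
| dense_edge_index, _ = remove_self_loops(torch.stack([demo_row, demo_col], dim=0)) | |||
| # dense_edge_index now holds the 6 off-diagonal pairs of a complete 3-node graph | |||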
| @@ -1,453 +0,0 @@ | |||
| from pdb import set_trace | |||
| import torch | |||
| import numpy as np | |||
| from torch_geometric.data import DataLoader | |||
| from torch_geometric.utils import train_test_split_edges | |||
| from sklearn.model_selection import StratifiedKFold, KFold | |||
| def split_edges(dataset, train_ratio, val_ratio): | |||
| datas = [data for data in dataset] | |||
| for i in range(len(datas)): | |||
| datas[i] = train_test_split_edges( | |||
| datas[i], val_ratio, 1 - train_ratio - val_ratio | |||
| ) | |||
| dataset.data, dataset.slices = dataset.collate(datas) | |||
| def get_label_number(dataset): | |||
| r"""Get the number of labels in this dataset as dict.""" | |||
| label_num = {} | |||
| labels = dataset.data.y.unique().cpu().detach().numpy().tolist() | |||
| for label in labels: | |||
| label_num[label] = (dataset.data.y == label).sum().item() | |||
| return label_num | |||
| def index_to_mask(index, size): | |||
| mask = torch.zeros(size, dtype=torch.bool, device=index.device) | |||
| mask[index] = 1 | |||
| return mask | |||
| def random_splits_mask(dataset, train_ratio=0.2, val_ratio=0.4, seed=None): | |||
| r"""If the data has masks for train/val/test, return the splits with specific ratio. | |||
| Parameters | |||
| ---------- | |||
| train_ratio : float | |||
| the portion of data that used for training. | |||
| val_ratio : float | |||
| the portion of data that used for validation. | |||
| seed : int | |||
| random seed for splitting dataset. | |||
| """ | |||
| assert ( | |||
| train_ratio + val_ratio <= 1 | |||
| ), "the sum of train_ratio and val_ratio is larger than 1" | |||
| _dataset = [d for d in dataset] | |||
| for data in _dataset: | |||
| r_s = torch.get_rng_state() | |||
| if torch.cuda.is_available(): | |||
| r_s_cuda = torch.cuda.get_rng_state() | |||
| if seed is not None: | |||
| torch.manual_seed(seed) | |||
| if torch.cuda.is_available(): | |||
| torch.cuda.manual_seed(seed) | |||
| perm = torch.randperm(data.num_nodes) | |||
| train_index = perm[: int(data.num_nodes * train_ratio)] | |||
| val_index = perm[ | |||
| int(data.num_nodes * train_ratio) : int( | |||
| data.num_nodes * (train_ratio + val_ratio) | |||
| ) | |||
| ] | |||
| test_index = perm[int(data.num_nodes * (train_ratio + val_ratio)) :] | |||
| data.train_mask = index_to_mask(train_index, size=data.num_nodes) | |||
| data.val_mask = index_to_mask(val_index, size=data.num_nodes) | |||
| data.test_mask = index_to_mask(test_index, size=data.num_nodes) | |||
| torch.set_rng_state(r_s) | |||
| if torch.cuda.is_available(): | |||
| torch.cuda.set_rng_state(r_s_cuda) | |||
| dataset.data, dataset.slices = dataset.collate(_dataset) | |||
| if hasattr(dataset, "__data_list__"): | |||
| delattr(dataset, "__data_list__") | |||
| # while type(dataset.data.num_nodes) == list: | |||
| # dataset.data.num_nodes = dataset.data.num_nodes[0] | |||
| # dataset.data.num_nodes = dataset.data.num_nodes[0] | |||
| return dataset | |||
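| # Worked example of the split arithmetic above (illustrative): with | |||
| # num_nodes = 100, train_ratio = 0.2 and val_ratio = 0.4 the permutation is | |||
| # cut at indices 20 and 60, giving 20 train / 40 val / 40 test nodes. | |||
| demo_perm = torch.randperm(100) | |||
| demo_train, demo_val, demo_test = demo_perm[:20], demo_perm[20:60], demo_perm[60:] | |||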
| def random_splits_mask_class( | |||
| dataset, | |||
| num_train_per_class=20, | |||
| num_val_per_class=30, | |||
| num_val=None, | |||
| num_test=None, | |||
| seed=None, | |||
| ): | |||
| r"""If the data has masks for train/val/test, return the splits with specific number of samples from every class for training as suggested in Pitfalls of graph neural network evaluation [#]_ for semi-supervised learning. | |||
| References | |||
| ---------- | |||
| .. [#] Shchur, O., Mumme, M., Bojchevski, A., & Günnemann, S. (2018). | |||
| Pitfalls of graph neural network evaluation. | |||
| arXiv preprint arXiv:1811.05868. | |||
| Parameters | |||
| ---------- | |||
| num_train_per_class : int | |||
| the number of samples from every class used for training. | |||
| num_val_per_class : int | |||
| the number of samples from every class used for validation. | |||
| num_val : int | |||
| the total number of nodes used for validation as an alternative. | |||
| num_test : int | |||
| the total number of nodes used for testing as an alternative. The rest of the data will be selected as the test set if num_test is set to None. | |||
| seed : int | |||
| random seed for splitting the dataset. | |||
| """ | |||
| data = dataset[0] | |||
| r_s = torch.get_rng_state() | |||
| if torch.cuda.is_available(): | |||
| r_s_cuda = torch.cuda.get_rng_state() | |||
| if seed is not None: | |||
| torch.manual_seed(seed) | |||
| if torch.cuda.is_available(): | |||
| torch.cuda.manual_seed(seed) | |||
| num_classes = data.y.max().cpu().item() + 1 | |||
| try: | |||
| data.train_mask.fill_(False) | |||
| data.val_mask.fill_(False) | |||
| data.test_mask.fill_(False) | |||
| except AttributeError: | |||
| train_mask = torch.zeros( | |||
| data.num_nodes, dtype=torch.bool, device=data.edge_index.device | |||
| ) | |||
| val_mask = torch.zeros( | |||
| data.num_nodes, dtype=torch.bool, device=data.edge_index.device | |||
| ) | |||
| test_mask = torch.zeros( | |||
| data.num_nodes, dtype=torch.bool, device=data.edge_index.device | |||
| ) | |||
| setattr(data, "train_mask", train_mask) | |||
| setattr(data, "val_mask", val_mask) | |||
| setattr(data, "test_mask", test_mask) | |||
| for c_i in range(num_classes): | |||
| idx = (data.y == c_i).nonzero().view(-1) | |||
| assert num_train_per_class + num_val_per_class < idx.size(0), ( | |||
| "the total number of samples from every class used for training and validation is larger than the total samples in class " | |||
| + str(c_i) | |||
| ) | |||
| idx_idx_rand = torch.randperm(idx.size(0)) | |||
| idx_train = idx[idx_idx_rand[:num_train_per_class]] | |||
| idx_val = idx[ | |||
| idx_idx_rand[num_train_per_class : num_train_per_class + num_val_per_class] | |||
| ] | |||
| data.train_mask[idx_train] = True | |||
| data.val_mask[idx_val] = True | |||
| if num_val is not None: | |||
| remaining = (~data.train_mask).nonzero().view(-1) | |||
| remaining = remaining[torch.randperm(remaining.size(0))] | |||
| data.val_mask[remaining[:num_val]] = True | |||
| if num_test is not None: | |||
| data.test_mask[remaining[num_val : num_val + num_test]] = True | |||
| else: | |||
| data.test_mask[remaining[num_val:]] = True | |||
| else: | |||
| remaining = (~(data.train_mask + data.val_mask)).nonzero().view(-1) | |||
| data.test_mask[remaining] = True | |||
| torch.set_rng_state(r_s) | |||
| if torch.cuda.is_available(): | |||
| torch.cuda.set_rng_state(r_s_cuda) | |||
| datalist = [] | |||
| for d in dataset: | |||
| setattr(d, "train_mask", data.train_mask) | |||
| setattr(d, "val_mask", data.val_mask) | |||
| setattr(d, "test_mask", data.test_mask) | |||
| datalist.append(d) | |||
| dataset.data, dataset.slices = dataset.collate(datalist) | |||
| if hasattr(dataset, "__data_list__"): | |||
| delattr(dataset, "__data_list__") | |||
| # while type(dataset.data.num_nodes) == list: | |||
| # dataset.data.num_nodes = dataset.data.num_nodes[0] | |||
| # dataset.data.num_nodes = dataset.data.num_nodes[0] | |||
| return dataset | |||
| def graph_cross_validation( | |||
| dataset, n_splits=10, shuffle=True, random_seed=42, stratify=False | |||
| ): | |||
| r"""Cross validation for graph classification data, returning one fold with specific idx in autogl.datasets or pyg.Dataloader(default) | |||
| Parameters | |||
| ---------- | |||
| dataset : str | |||
| dataset with multiple graphs. | |||
| n_splits : int | |||
| the number of how many folds will be splitted. | |||
| shuffle : bool | |||
| shuffle or not for sklearn.model_selection.StratifiedKFold | |||
| random_seed : int | |||
| random_state for sklearn.model_selection.StratifiedKFold | |||
| """ | |||
| if stratify: | |||
| skf = StratifiedKFold( | |||
| n_splits=n_splits, shuffle=shuffle, random_state=random_seed | |||
| ) | |||
| else: | |||
| skf = KFold(n_splits=n_splits, shuffle=shuffle, random_state=random_seed) | |||
| idx_list = [] | |||
| # NOTE: possible bug in pytorch_geometric - dataset.data returns the data of the | |||
| # original dataset even when the input dataset is a subset of it. We currently | |||
| # work around this by iterating over y directly. | |||
| dataset_y = [data.y[0].tolist() for data in dataset] | |||
| for idx in skf.split(np.zeros(len(dataset_y)), dataset_y): | |||
| idx_list.append(idx) | |||
| dataset.idx_list = idx_list | |||
| dataset.n_splits = n_splits | |||
| # NOTE: saving only the indices would produce different subset references on | |||
| # every call, so the splits themselves are also materialized in advance. | |||
| dataset.cv_dict = [ | |||
| { | |||
| "train": dataset[dataset.idx_list[idx][0].tolist()], | |||
| "val": dataset[dataset.idx_list[idx][1].tolist()], | |||
| } | |||
| for idx in range(n_splits) | |||
| ] | |||
| graph_set_fold_id(dataset, 0) | |||
| return dataset | |||
| def graph_set_fold_id(dataset, fold_id): | |||
| r"""Set the current fold id of graph dataset. | |||
| Parameters | |||
| ---------- | |||
| dataset: ``torch_geometric.data.dataset.Dataset`` | |||
| dataset with multiple graphs. | |||
| fold_id: ``int`` | |||
| The current fold id this dataset uses. Should be in [0, dataset.n_splits) | |||
| Returns | |||
| ------- | |||
| ``torch_geometric.data.dataset.Dataset`` | |||
| The reference to the original dataset. | |||
| """ | |||
| if not hasattr(dataset, "n_splits"): | |||
| raise ValueError("Cannot set fold id before the dataset is cross-validated!") | |||
| assert ( | |||
| 0 <= fold_id < dataset.n_splits | |||
| ), "Fold id %d exceeds total cross validation split number %d" % ( | |||
| fold_id, | |||
| dataset.n_splits, | |||
| ) | |||
| dataset.current_fold_id = fold_id | |||
| dataset.train_split = dataset.cv_dict[dataset.current_fold_id]["train"] | |||
| dataset.val_split = dataset.cv_dict[dataset.current_fold_id]["val"] | |||
| dataset.train_index = dataset.idx_list[dataset.current_fold_id][0] | |||
| dataset.val_index = dataset.idx_list[dataset.current_fold_id][1] | |||
| return dataset | |||
| def graph_random_splits(dataset, train_ratio=0.2, val_ratio=0.4, seed=None): | |||
| r"""Splitting graph dataset with specific ratio for train/val/test. | |||
| Parameters | |||
| ---------- | |||
| dataset: ``torch_geometric.data.dataset.Dataset`` | |||
| dataset with multiple graphs. | |||
| train_ratio : float | |||
| the portion of data used for training. | |||
| val_ratio : float | |||
| the portion of data used for validation. | |||
| seed : int | |||
| random seed for splitting the dataset. | |||
| Returns | |||
| ------- | |||
| ``torch_geometric.data.dataset.Dataset`` | |||
| The reference to the original dataset. | |||
| """ | |||
| assert ( | |||
| train_ratio + val_ratio <= 1 | |||
| ), "the sum of train_ratio and val_ratio is larger than 1" | |||
| r_s = torch.get_rng_state() | |||
| if torch.cuda.is_available(): | |||
| r_s_cuda = torch.cuda.get_rng_state() | |||
| if seed is not None: | |||
| torch.manual_seed(seed) | |||
| if torch.cuda.is_available(): | |||
| torch.cuda.manual_seed(seed) | |||
| perm = torch.randperm(len(dataset)) | |||
| train_index = perm[: int(len(dataset) * train_ratio)] | |||
| val_index = perm[ | |||
| int(len(dataset) * train_ratio) : int(len(dataset) * (train_ratio + val_ratio)) | |||
| ] | |||
| test_index = perm[int(len(dataset) * (train_ratio + val_ratio)) :] | |||
| train_dataset = dataset[train_index] | |||
| val_dataset = dataset[val_index] | |||
| test_dataset = dataset[test_index] | |||
| # set train_idx, val_idx and test_idx as dataset attribute | |||
| dataset.train_split = train_dataset | |||
| dataset.val_split = val_dataset | |||
| dataset.test_split = test_dataset | |||
| dataset.train_index = train_index | |||
| dataset.val_index = val_index | |||
| dataset.test_index = test_index | |||
| torch.set_rng_state(r_s) | |||
| if torch.cuda.is_available(): | |||
| torch.cuda.set_rng_state(r_s_cuda) | |||
| return dataset | |||
| def graph_get_split( | |||
| dataset, mask="train", is_loader=True, batch_size=128, num_workers=0 | |||
| ): | |||
| r"""Get train/test dataset/dataloader after cross validation. | |||
| Parameters | |||
| ---------- | |||
| dataset: ``torch_geometric.data.dataset.Dataset`` | |||
| dataset with multiple graphs. | |||
| mask : str | |||
| which split ("train", "val" or "test") to return. | |||
| is_loader : bool | |||
| whether to return a pyg DataLoader instead of the raw dataset split. | |||
| batch_size : int | |||
| batch_size for generating the DataLoader | |||
| """ | |||
| assert hasattr( | |||
| dataset, "%s_split" % (mask) | |||
| ), "Given dataset does not have a %s split" % (mask) | |||
| if is_loader: | |||
| return DataLoader( | |||
| getattr(dataset, "%s_split" % (mask)), | |||
| batch_size=batch_size, | |||
| num_workers=num_workers, | |||
| ) | |||
| else: | |||
| return getattr(dataset, "%s_split" % (mask)) | |||
| ''' | |||
| def graph_cross_validation(dataset, n_splits = 10, shuffle = True, random_seed = 42, fold_idx = 0, batch_size = 32, dataloader = True): | |||
| r"""Cross validation for graph classification data, returning one fold with specific idx in autogl.datasets or pyg.Dataloader(default) | |||
| Parameters | |||
| ---------- | |||
| dataset : str | |||
| dataset with multiple graphs. | |||
| n_splits : int | |||
| the number of folds to split the dataset into. | |||
| shuffle : bool | |||
| shuffle or not for sklearn.model_selection.StratifiedKFold | |||
| random_seed : int | |||
| random_state for sklearn.model_selection.StratifiedKFold | |||
| fold_idx : int | |||
| specific fold id from 0 to n_splits-1 | |||
| batch_size : int | |||
| batch_size for generating Dataloader | |||
| dataloader : bool | |||
| return with autogl.datasets or pyg.Dataloader | |||
| """ | |||
| skf = StratifiedKFold(n_splits=n_splits, shuffle = shuffle, random_state = random_seed) | |||
| idx_list = [] | |||
| for idx in skf.split(np.zeros(len(dataset.data.y)), dataset.data.y): | |||
| idx_list.append(idx) | |||
| assert 0 <= fold_idx and fold_idx < n_splits, "fold_idx must be from 0 to " + str(n_splits-1) | |||
| train_idx, test_idx = idx_list[fold_idx] | |||
| test_dataset = dataset[test_idx.tolist()] | |||
| train_dataset = dataset[train_idx.tolist()] | |||
| if dataloader: | |||
| return DataLoader(train_dataset, batch_size=128), DataLoader(test_dataset, batch_size=128) | |||
| else: | |||
| return train_dataset, test_dataset | |||
| ''' | |||
| def train_test_split(self, method="auto", ratio=None): | |||
| raise NotImplementedError() | |||
| def train_valid_split(self, method="auto", ratio=None): | |||
| raise NotImplementedError() | |||
| def cross_validation_split(self, method="auto", cv_fold_num=5): | |||
| raise NotImplementedError() | |||
| # below get_* can also be set as property | |||
| def get_train_dataset(self): | |||
| raise NotImplementedError() | |||
| def get_test_dataset(self): | |||
| raise NotImplementedError() | |||
| def get_valid_dataset(self): | |||
| raise NotImplementedError() | |||
| def get_train_generator(self, batch_size): | |||
| """ | |||
| should return a torch.utils.data.Dataloader | |||
| """ | |||
| raise NotImplementedError() | |||
| def get_test_generator(self, batch_size): | |||
| """ | |||
| should return a torch.utils.data.Dataloader | |||
| """ | |||
| raise NotImplementedError() | |||
| def get_valid_generator(self, batch_size): | |||
| """ | |||
| should return a torch.utils.data.Dataloader | |||
| """ | |||
| raise NotImplementedError() | |||
| @@ -0,0 +1,9 @@ | |||
| from ._general import ( | |||
| index_to_mask, | |||
| split_edges, | |||
| random_splits_mask, | |||
| random_splits_mask_class, | |||
| graph_cross_validation, | |||
| graph_random_splits, | |||
| graph_get_split | |||
| ) | |||
| @@ -0,0 +1,412 @@ | |||
| import numpy as np | |||
| import random | |||
| import torch | |||
| import torch.utils.data | |||
| import typing as _typing | |||
| from sklearn.model_selection import StratifiedKFold, KFold | |||
| from autogl import backend as _backend | |||
| from autogl.data import Data, Dataset, InMemoryStaticGraphSet | |||
| from ...data.graph import GeneralStaticGraph, GeneralStaticGraphGenerator | |||
| from . import _pyg | |||
| def index_to_mask(index: torch.Tensor, size): | |||
| mask = torch.zeros(size, dtype=torch.bool, device=index.device) | |||
| mask[index] = True | |||
| return mask | |||
| def split_edges( | |||
| dataset: InMemoryStaticGraphSet, | |||
| train_ratio: float, val_ratio: float | |||
| ) -> InMemoryStaticGraphSet: | |||
| test_ratio: float = 1 - train_ratio - val_ratio | |||
| def _split_edges_for_graph(homogeneous_static_graph: GeneralStaticGraph) -> GeneralStaticGraph: | |||
| if not isinstance(homogeneous_static_graph, GeneralStaticGraph): | |||
| raise TypeError | |||
| elif not homogeneous_static_graph.edges.is_homogeneous: | |||
| raise ValueError("The provided graph MUST consist of homogeneous edges.") | |||
| else: | |||
| split_data = _pyg.train_test_split_edges( | |||
| Data( | |||
| edge_index=homogeneous_static_graph.edges.connections.detach().clone(), | |||
| edge_attr=( | |||
| homogeneous_static_graph.edges.data['edge_attr'].detach().clone() | |||
| if 'edge_attr' in homogeneous_static_graph.edges.data else None | |||
| ) | |||
| ), | |||
| val_ratio, test_ratio | |||
| ) | |||
| original_edge_type = [et for et in homogeneous_static_graph.edges][0] | |||
| split_static_graph = GeneralStaticGraphGenerator.create_heterogeneous_static_graph( | |||
| dict([ | |||
| (node_type, homogeneous_static_graph.nodes[node_type].data) | |||
| for node_type in homogeneous_static_graph.nodes | |||
| ]), | |||
| { | |||
| (original_edge_type.source_node_type, "train_pos_edge", original_edge_type.target_node_type): ( | |||
| getattr(split_data, "train_pos_edge_index"), | |||
| {"edge_attr": getattr(split_data, "train_pos_edge_attr")} | |||
| if isinstance(getattr(split_data, "train_pos_edge_attr"), torch.Tensor) | |||
| else None | |||
| ), | |||
| (original_edge_type.source_node_type, "val_pos_edge", original_edge_type.target_node_type): ( | |||
| getattr(split_data, "val_pos_edge_index"), | |||
| {"edge_attr": getattr(split_data, "val_pos_edge_attr")} | |||
| if isinstance(getattr(split_data, "val_pos_edge_attr"), torch.Tensor) | |||
| else None | |||
| ), | |||
| (original_edge_type.source_node_type, "val_neg_edge", original_edge_type.target_node_type): | |||
| getattr(split_data, "val_neg_edge_index"), | |||
| (original_edge_type.source_node_type, "test_pos_edge", original_edge_type.target_node_type): ( | |||
| getattr(split_data, "test_pos_edge_index"), | |||
| {"edge_attr": getattr(split_data, "test_pos_edge_attr")} | |||
| if isinstance(getattr(split_data, "test_pos_edge_attr"), torch.Tensor) | |||
| else None | |||
| ), | |||
| (original_edge_type.source_node_type, "test_neg_edge", original_edge_type.target_node_type): | |||
| getattr(split_data, "test_neg_edge_index") | |||
| }, | |||
| homogeneous_static_graph.data | |||
| ) | |||
| return split_static_graph | |||
| if not isinstance(dataset, InMemoryStaticGraphSet): | |||
| raise TypeError | |||
| for index in range(len(dataset)): | |||
| dataset[index] = _split_edges_for_graph(dataset[index]) | |||
| return dataset | |||
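| # Sketch of the result (assumes `graph_set` is a pre-built InMemoryStaticGraphSet | |||
| # of homogeneous graphs): each graph is rewritten with five typed edge sets - | |||
| # train_pos_edge, val_pos_edge, val_neg_edge, test_pos_edge and test_neg_edge - | |||
| # with positive edges keeping their edge_attr when present. | |||
| graph_set = split_edges(graph_set, train_ratio=0.85, val_ratio=0.05) | |||
| link_pred_edge_types = [edge_type for edge_type in graph_set[0].edges] | |||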
| def random_splits_mask( | |||
| dataset: InMemoryStaticGraphSet, | |||
| train_ratio: float = 0.2, val_ratio: float = 0.4, | |||
| seed: _typing.Optional[int] = None | |||
| ) -> InMemoryStaticGraphSet: | |||
| r"""If the data has masks for train/val/test, return the splits with specific ratio. | |||
| Parameters | |||
| ---------- | |||
| dataset : InMemoryStaticGraphSet | |||
| graph set | |||
| train_ratio : float | |||
| the portion of data that used for training. | |||
| val_ratio : float | |||
| the portion of data that used for validation. | |||
| seed : int | |||
| random seed for splitting dataset. | |||
| """ | |||
| if not train_ratio + val_ratio <= 1: | |||
| raise ValueError("the sum of provided train_ratio and val_ratio is larger than 1") | |||
| def __random_split_masks( | |||
| num_nodes: int | |||
| ) -> _typing.Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: | |||
| _rng_state: torch.Tensor = torch.get_rng_state() | |||
| if seed is not None and isinstance(seed, int): | |||
| torch.manual_seed(seed) | |||
| perm = torch.randperm(num_nodes) | |||
| train_index = perm[:int(num_nodes * train_ratio)] | |||
| val_index = perm[int(num_nodes * train_ratio): int(num_nodes * (train_ratio + val_ratio))] | |||
| test_index = perm[int(num_nodes * (train_ratio + val_ratio)):] | |||
| torch.set_rng_state(_rng_state) | |||
| return ( | |||
| index_to_mask(train_index, num_nodes), | |||
| index_to_mask(val_index, num_nodes), | |||
| index_to_mask(test_index, num_nodes) | |||
| ) | |||
| for index in range(len(dataset)): | |||
| for node_type in dataset[index].nodes: | |||
| data_keys = [data_key for data_key in dataset[index].nodes.data] | |||
| if len(data_keys) > 0: | |||
| _num_nodes: int = dataset[index].nodes[node_type].data[data_keys[0]].size(0) | |||
| _masks: _typing.Tuple[torch.Tensor, torch.Tensor, torch.Tensor] = ( | |||
| __random_split_masks(_num_nodes) | |||
| ) | |||
| dataset[index].nodes[node_type].data["train_mask"] = _masks[0] | |||
| dataset[index].nodes[node_type].data["val_mask"] = _masks[1] | |||
| dataset[index].nodes[node_type].data["test_mask"] = _masks[2] | |||
| return dataset | |||
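| # Usage sketch (assumes `graph_set` is a pre-built InMemoryStaticGraphSet and | |||
| # "node" is one of its node types): every typed node set gains boolean masks. | |||
| graph_set = random_splits_mask(graph_set, train_ratio=0.2, val_ratio=0.4, seed=0) | |||
| node_train_mask = graph_set[0].nodes["node"].data["train_mask"] | |||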
| def random_splits_mask_class( | |||
| dataset: InMemoryStaticGraphSet, | |||
| num_train_per_class: int = 20, | |||
| num_val_per_class: int = 30, | |||
| total_num_val: _typing.Optional[int] = ..., | |||
| total_num_test: _typing.Optional[int] = ..., | |||
| seed: _typing.Optional[int] = ... | |||
| ): | |||
| r"""If the data has masks for train/val/test, return the splits with specific number of samples from every class for training as suggested in Pitfalls of graph neural network evaluation [#]_ for semi-supervised learning. | |||
| References | |||
| ---------- | |||
| .. [#] Shchur, O., Mumme, M., Bojchevski, A., & Günnemann, S. (2018). | |||
| Pitfalls of graph neural network evaluation. | |||
| arXiv preprint arXiv:1811.05868. | |||
| Parameters | |||
| ---------- | |||
| dataset: InMemoryStaticGraphSet | |||
| instance of InMemoryStaticGraphSet | |||
| num_train_per_class : int | |||
| the number of samples from every class used for training. | |||
| num_val_per_class : int | |||
| the number of samples from every class used for validation. | |||
| total_num_val : int | |||
| the total number of nodes used for validation as an alternative. | |||
| total_num_test : int | |||
| the total number of nodes used for testing as an alternative. The rest of the data will be selected as the test set if total_num_test is set to None. | |||
| seed : int | |||
| random seed for splitting the dataset. | |||
| """ | |||
| for graph_index in range(len(dataset)): | |||
| for node_type in dataset[graph_index].nodes: | |||
| if ( | |||
| 'y' in dataset[graph_index].nodes[node_type].data and | |||
| 'label' in dataset[graph_index].nodes[node_type].data | |||
| ): | |||
| raise ValueError( | |||
| f"Both 'y' and 'label' data exist " | |||
| f"for node type [{node_type}] in " | |||
| f"graph with index [{graph_index}]." | |||
| ) | |||
| elif ( | |||
| 'y' not in dataset[graph_index].nodes[node_type].data and | |||
| 'label' not in dataset[graph_index].nodes[node_type].data | |||
| ): | |||
| continue | |||
| elif 'y' in dataset[graph_index].nodes[node_type].data: | |||
| label: torch.Tensor = dataset[graph_index].nodes[node_type].data['y'] | |||
| elif 'label' in dataset[graph_index].nodes[node_type].data: | |||
| label: torch.Tensor = dataset[graph_index].nodes[node_type].data['label'] | |||
| else: | |||
| raise RuntimeError | |||
| num_nodes: int = label.size(0) | |||
| num_classes: int = label.cpu().max().item() + 1 | |||
| _rng_state: torch.Tensor = torch.get_rng_state() | |||
| if seed not in (Ellipsis, None) and isinstance(seed, int): | |||
| torch.manual_seed(seed) | |||
| train_mask = torch.zeros(num_nodes, dtype=torch.bool, device=label.device) | |||
| val_mask = torch.zeros(num_nodes, dtype=torch.bool, device=label.device) | |||
| test_mask = torch.zeros(num_nodes, dtype=torch.bool, device=label.device) | |||
| for class_index in range(num_classes): | |||
| idx = (label == class_index).nonzero().view(-1) | |||
| assert num_train_per_class + num_val_per_class < idx.size(0), ( | |||
| f"the total number of samples from every class " | |||
| f"used for training and validation is larger than " | |||
| f"the total samples in class [{class_index}] for node type [{node_type}] " | |||
| f"in graph with index [{graph_index}]" | |||
| ) | |||
| randomized_index: torch.Tensor = torch.randperm(idx.size(0)) | |||
| train_idx = idx[randomized_index[:num_train_per_class]] | |||
| val_idx = idx[ | |||
| randomized_index[num_train_per_class: (num_train_per_class + num_val_per_class)] | |||
| ] | |||
| train_mask[train_idx] = True | |||
| val_mask[val_idx] = True | |||
| if isinstance(total_num_val, int) and total_num_val > 0: | |||
| remaining = (~train_mask).nonzero().view(-1) | |||
| remaining = remaining[torch.randperm(remaining.size(0))] | |||
| val_mask[remaining[:total_num_val]] = True | |||
| if isinstance(total_num_test, int) and total_num_test > 0: | |||
| test_mask[remaining[total_num_val: (total_num_val + total_num_test)]] = True | |||
| else: | |||
| test_mask[remaining[total_num_val:]] = True | |||
| else: | |||
| remaining = (~(train_mask + val_mask)).nonzero().view(-1) | |||
| test_mask[remaining] = True | |||
| torch.set_rng_state(_rng_state) | |||
| dataset[graph_index].nodes[node_type].data["train_mask"] = train_mask | |||
| dataset[graph_index].nodes[node_type].data["val_mask"] = val_mask | |||
| dataset[graph_index].nodes[node_type].data["test_mask"] = test_mask | |||
| return dataset | |||
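| # Usage sketch (assumes `graph_set` nodes carry a 'y' or 'label' tensor): take | |||
| # 20 train and 30 val nodes per class; remaining nodes fall into the test mask. | |||
| graph_set = random_splits_mask_class(graph_set, num_train_per_class=20, num_val_per_class=30) | |||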
| def graph_cross_validation( | |||
| dataset: InMemoryStaticGraphSet, | |||
| n_splits: int = 10, shuffle: bool = True, | |||
| random_seed: _typing.Optional[int] = ..., | |||
| stratify: bool = False | |||
| ) -> InMemoryStaticGraphSet: | |||
| r"""Cross validation for graph classification data, returning one fold with specific idx in autogl.datasets or pyg.Dataloader(default) | |||
| Parameters | |||
| ---------- | |||
| dataset : str | |||
| dataset with multiple graphs. | |||
| n_splits : int | |||
| the number of how many folds will be splitted. | |||
| shuffle : bool | |||
| shuffle or not for sklearn.model_selection.StratifiedKFold | |||
| random_seed : int | |||
| random_state for sklearn.model_selection.StratifiedKFold | |||
| stratify: bool | |||
| """ | |||
| if not isinstance(dataset, InMemoryStaticGraphSet): | |||
| raise TypeError | |||
| if not isinstance(n_splits, int): | |||
| raise TypeError | |||
| elif not n_splits > 0: | |||
| raise ValueError | |||
| if not isinstance(shuffle, bool): | |||
| raise TypeError | |||
| if not (random_seed in (Ellipsis, None) or isinstance(random_seed, int)): | |||
| raise TypeError | |||
| elif isinstance(random_seed, int) and random_seed >= 0: | |||
| _random_seed: int = random_seed | |||
| else: | |||
| _random_seed: int = random.randrange(0, 65536) | |||
| if not isinstance(stratify, bool): | |||
| raise TypeError | |||
| if stratify: | |||
| kf = StratifiedKFold( | |||
| n_splits=n_splits, shuffle=shuffle, random_state=_random_seed | |||
| ) | |||
| else: | |||
| kf = KFold( | |||
| n_splits=n_splits, shuffle=shuffle, random_state=_random_seed | |||
| ) | |||
| dataset_y = [g.data['y'].item() for g in dataset] | |||
| idx_list = [ | |||
| (train_index.tolist(), test_index.tolist()) | |||
| for train_index, test_index | |||
| in kf.split(np.zeros(len(dataset)), np.array(dataset_y)) | |||
| ] | |||
| dataset.folds = idx_list | |||
| dataset.train_index = idx_list[0][0] | |||
| dataset.val_index = idx_list[0][1] | |||
| return dataset | |||
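| # Usage sketch (assumes `graph_set` is an InMemoryStaticGraphSet of labeled | |||
| # graphs): ten stratified folds; fold 0 becomes the initial train/val indices. | |||
| graph_set = graph_cross_validation(graph_set, n_splits=10, stratify=True, random_seed=42) | |||
| fold_train, fold_val = graph_set.folds[3]  # switch to another fold by hand | |||
| graph_set.train_index, graph_set.val_index = fold_train, fold_val | |||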
| def graph_random_splits( | |||
| dataset: InMemoryStaticGraphSet, | |||
| train_ratio: float = 0.2, | |||
| val_ratio: float = 0.4, | |||
| seed: _typing.Optional[int] = ... | |||
| ): | |||
| r"""Splitting graph dataset with specific ratio for train/val/test. | |||
| Parameters | |||
| ---------- | |||
| dataset : ``InMemoryStaticGraphSet`` | |||
| dataset with multiple graphs. | |||
| train_ratio : float | |||
| the portion of data used for training. | |||
| val_ratio : float | |||
| the portion of data used for validation. | |||
| seed : int | |||
| random seed for splitting the dataset. | |||
| """ | |||
| _rng_state = torch.get_rng_state() | |||
| if isinstance(seed, int): | |||
| torch.manual_seed(seed) | |||
| perm = torch.randperm(len(dataset)) | |||
| train_index = perm[: int(len(dataset) * train_ratio)] | |||
| val_index = ( | |||
| perm[int(len(dataset) * train_ratio): int(len(dataset) * (train_ratio + val_ratio))] | |||
| ) | |||
| test_index = perm[int(len(dataset) * (train_ratio + val_ratio)):] | |||
| dataset.train_index = train_index | |||
| dataset.val_index = val_index | |||
| dataset.test_index = test_index | |||
| torch.set_rng_state(_rng_state) | |||
| return dataset | |||
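| # Usage sketch (assumes `graph_set` is an InMemoryStaticGraphSet with 100 | |||
| # graphs): the permutation is cut at 20 and 60, giving 20/40/40 graphs per split. | |||
| graph_set = graph_random_splits(graph_set, train_ratio=0.2, val_ratio=0.4, seed=0) | |||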
| def graph_get_split( | |||
| dataset: Dataset, mask: str = "train", | |||
| is_loader: bool = True, batch_size: int = 128, | |||
| num_workers: int = 0 | |||
| ) -> _typing.Union[torch.utils.data.DataLoader, _typing.Iterable]: | |||
| r"""Get train/test dataset/dataloader after cross validation. | |||
| Parameters | |||
| ---------- | |||
| dataset : Dataset | |||
| dataset with multiple graphs. | |||
| mask : str | |||
| which split ("train", "val" or "test") to return. | |||
| is_loader : bool | |||
| whether to return a data loader instead of the raw dataset split. | |||
| batch_size : int | |||
| batch_size for generating the data loader | |||
| num_workers : int | |||
| number of workers for the data loader | |||
| """ | |||
| if not isinstance(dataset, Dataset): | |||
| raise TypeError | |||
| if not isinstance(mask, str): | |||
| raise TypeError | |||
| elif mask.lower() not in ("train", "val", "test"): | |||
| raise ValueError | |||
| if not isinstance(is_loader, bool): | |||
| raise TypeError | |||
| if not isinstance(batch_size, int): | |||
| raise TypeError | |||
| elif not batch_size > 0: | |||
| raise ValueError | |||
| if not isinstance(num_workers, int): | |||
| raise TypeError | |||
| elif not num_workers >= 0: | |||
| raise ValueError | |||
| # mask was already validated above, so dispatch directly on its value | |||
| if mask.lower() == "train": | |||
| optional_dataset_split = dataset.train_split | |||
| elif mask.lower() == "val": | |||
| optional_dataset_split = dataset.val_split | |||
| elif mask.lower() == "test": | |||
| optional_dataset_split = dataset.test_split | |||
| else: | |||
| raise ValueError( | |||
| f"The provided mask parameter must be a str in ['train', 'val', 'test'], " | |||
| f"illegal provided value is [{mask}]" | |||
| ) | |||
| if ( | |||
| optional_dataset_split is None or | |||
| not isinstance(optional_dataset_split, _typing.Iterable) | |||
| ): | |||
| raise ValueError( | |||
| f"Provided dataset do NOT have {mask} split" | |||
| ) | |||
| if is_loader: | |||
| if not (_backend.DependentBackend.is_dgl() or _backend.DependentBackend.is_pyg()): | |||
| raise RuntimeError("Unsupported backend") | |||
| elif _backend.DependentBackend.is_dgl(): | |||
| from dgl.dataloading.pytorch import GraphDataLoader | |||
| return GraphDataLoader( | |||
| optional_dataset_split, | |||
| **{"batch_size": batch_size, "num_workers": num_workers} | |||
| ) | |||
| elif _backend.DependentBackend.is_pyg(): | |||
| dataset_split: _typing.Any = optional_dataset_split | |||
| import torch_geometric | |||
| return torch_geometric.data.DataLoader( | |||
| dataset_split, batch_size=batch_size, num_workers=num_workers | |||
| ) | |||
| else: | |||
| return optional_dataset_split | |||
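| # Usage sketch (assumes `graph_set` already carries train/val splits, e.g. set | |||
| # by graph_cross_validation above; batch_size is illustrative): | |||
| train_loader = graph_get_split(graph_set, "train", is_loader=True, batch_size=64) | |||
| val_split = graph_get_split(graph_set, "val", is_loader=False) | |||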
| @@ -0,0 +1,116 @@ | |||
| """ Migrated `train_test_split_edges` function from PyTorch-Geometric """ | |||
| import math | |||
| import torch | |||
| import typing as _typing | |||
| def to_undirected( | |||
| edge_index: torch.Tensor, edge_attr: _typing.Optional[torch.Tensor] = None | |||
| ) -> _typing.Union[torch.Tensor, _typing.Tuple[torch.Tensor, torch.Tensor]]: | |||
| r"""Converts the graph given by :attr:`edge_index` to an undirected graph | |||
| such that :math:`(j,i) \in \mathcal{E}` for every edge :math:`(i,j) \in | |||
| \mathcal{E}`. | |||
| Args: | |||
| edge_index (LongTensor): The edge indices. | |||
| edge_attr (Tensor, optional): Edge weights or multi-dimensional | |||
| edge features. (default: :obj:`None`) | |||
| :rtype: :class:`LongTensor` if :attr:`edge_attr` is :obj:`None`, else | |||
| (:class:`LongTensor`, :class:`Tensor`) | |||
| """ | |||
| row, col = edge_index | |||
| row, col = torch.cat([row, col], dim=0), torch.cat([col, row], dim=0) | |||
| edge_index = torch.stack([row, col], dim=0) | |||
| if edge_attr is not None: | |||
| edge_attr = torch.cat([edge_attr, edge_attr], dim=0) | |||
| return edge_index, edge_attr | |||
| else: | |||
| return edge_index | |||
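| # Worked example for to_undirected (self-contained; note this migrated variant | |||
| # mirrors edges without coalescing duplicates): | |||
| demo_edge_index = torch.tensor([[0, 1], [1, 2]])  # edges 0->1 and 1->2 | |||
| demo_undirected = to_undirected(demo_edge_index) | |||
| # -> tensor([[0, 1, 1, 2], [1, 2, 0, 1]]) | |||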
| def train_test_split_edges(data, val_ratio: float = 0.05, | |||
| test_ratio: float = 0.1): | |||
| r"""Splits the edges of a :class:`torch_geometric.data.Data` object | |||
| into positive and negative train/val/test edges. | |||
| As such, it will replace the :obj:`edge_index` attribute with | |||
| :obj:`train_pos_edge_index`, :obj:`train_neg_adj_mask`, | |||
| :obj:`val_pos_edge_index`, :obj:`val_neg_edge_index`, | |||
| :obj:`test_pos_edge_index` and :obj:`test_neg_edge_index` attributes. | |||
| If :obj:`data` has edge features named :obj:`edge_attr`, then | |||
| :obj:`train_pos_edge_attr`, :obj:`val_pos_edge_attr` and | |||
| :obj:`test_pos_edge_attr` will be added as well. | |||
| Args: | |||
| data (Data): The data object. | |||
| val_ratio (float, optional): The ratio of positive validation edges. | |||
| (default: :obj:`0.05`) | |||
| test_ratio (float, optional): The ratio of positive test edges. | |||
| (default: :obj:`0.1`) | |||
| :rtype: :class:`torch_geometric.data.Data` | |||
| """ | |||
| num_nodes = data.num_nodes | |||
| row, col = data.edge_index | |||
| edge_attr = data.edge_attr | |||
| data.edge_index = data.edge_attr = None | |||
| # Return upper triangular portion. | |||
| mask = row < col | |||
| row, col = row[mask], col[mask] | |||
| if edge_attr is not None: | |||
| edge_attr = edge_attr[mask] | |||
| n_v = int(math.floor(val_ratio * row.size(0))) | |||
| n_t = int(math.floor(test_ratio * row.size(0))) | |||
| # Positive edges. | |||
| perm = torch.randperm(row.size(0)) | |||
| row, col = row[perm], col[perm] | |||
| if edge_attr is not None: | |||
| edge_attr = edge_attr[perm] | |||
| r, c = row[:n_v], col[:n_v] | |||
| data.val_pos_edge_index = torch.stack([r, c], dim=0) | |||
| if edge_attr is not None: | |||
| data.val_pos_edge_attr = edge_attr[:n_v] | |||
| r, c = row[n_v:n_v + n_t], col[n_v:n_v + n_t] | |||
| data.test_pos_edge_index = torch.stack([r, c], dim=0) | |||
| if edge_attr is not None: | |||
| data.test_pos_edge_attr = edge_attr[n_v:n_v + n_t] | |||
| r, c = row[n_v + n_t:], col[n_v + n_t:] | |||
| data.train_pos_edge_index = torch.stack([r, c], dim=0) | |||
| if edge_attr is not None: | |||
| out = to_undirected(data.train_pos_edge_index, edge_attr[n_v + n_t:]) | |||
| data.train_pos_edge_index, data.train_pos_edge_attr = out | |||
| else: | |||
| data.train_pos_edge_index = to_undirected(data.train_pos_edge_index) | |||
| # Negative edges. | |||
| neg_adj_mask = torch.ones(num_nodes, num_nodes, dtype=torch.uint8) | |||
| neg_adj_mask = neg_adj_mask.triu(diagonal=1).to(torch.bool) | |||
| neg_adj_mask[row, col] = 0 | |||
| neg_row, neg_col = neg_adj_mask.nonzero().t() | |||
| perm = torch.randperm(neg_row.size(0))[:n_v + n_t] | |||
| neg_row, neg_col = neg_row[perm], neg_col[perm] | |||
| neg_adj_mask[neg_row, neg_col] = 0 | |||
| data.train_neg_adj_mask = neg_adj_mask | |||
| row, col = neg_row[:n_v], neg_col[:n_v] | |||
| data.val_neg_edge_index = torch.stack([row, col], dim=0) | |||
| row, col = neg_row[n_v:n_v + n_t], neg_col[n_v:n_v + n_t] | |||
| data.test_neg_edge_index = torch.stack([row, col], dim=0) | |||
| return data | |||
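| # Usage sketch (assumes `data` is a homogeneous pyg-style Data object; ratios | |||
| # are illustrative): after the call, edge_index is replaced by | |||
| # train_pos_edge_index (undirected) plus train_neg_adj_mask, together with the | |||
| # val/test positive and negative edge index attributes listed above. | |||
| data = train_test_split_edges(data, val_ratio=0.05, test_ratio=0.1) | |||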
| @@ -0,0 +1,12 @@ | |||
| try: | |||
| import dgl | |||
| except ModuleNotFoundError: | |||
| dgl = None | |||
| else: | |||
| from ._to_dgl_dataset import general_static_graphs_to_dgl_dataset | |||
| try: | |||
| import torch_geometric | |||
| except ModuleNotFoundError: | |||
| torch_geometric = None | |||
| else: | |||
| from ._to_pyg_dataset import general_static_graphs_to_pyg_dataset | |||
| @@ -0,0 +1,36 @@ | |||
| import dgl | |||
| import torch | |||
| import typing as _typing | |||
| from autogl.data import Dataset, InMemoryDataset | |||
| from autogl.data.graph import GeneralStaticGraph | |||
| from autogl.data.graph.utils import conversion | |||
| def general_static_graphs_to_dgl_dataset( | |||
| general_static_graphs: _typing.Iterable[GeneralStaticGraph] | |||
| ) -> Dataset[_typing.Union[dgl.DGLGraph, _typing.Tuple[dgl.DGLGraph, torch.Tensor]]]: | |||
| def _transform( | |||
| general_static_graph: GeneralStaticGraph | |||
| ) -> _typing.Union[dgl.DGLGraph, _typing.Tuple[dgl.DGLGraph, torch.Tensor]]: | |||
| if not isinstance(general_static_graph, GeneralStaticGraph): | |||
| raise TypeError | |||
| if 'label' in general_static_graph.data: | |||
| label: _typing.Optional[torch.Tensor] = general_static_graph.data['label'] | |||
| elif 'y' in general_static_graph.data: | |||
| label: _typing.Optional[torch.Tensor] = general_static_graph.data['y'] | |||
| else: | |||
| label: _typing.Optional[torch.Tensor] = None | |||
| if isinstance(label, torch.Tensor): | |||
| return conversion.general_static_graph_to_dgl_graph(general_static_graph), label | |||
| else: | |||
| return conversion.general_static_graph_to_dgl_graph(general_static_graph) | |||
| if isinstance(general_static_graphs, Dataset): | |||
| return InMemoryDataset( | |||
| [_transform(g) for g in general_static_graphs], | |||
| general_static_graphs.train_index, | |||
| general_static_graphs.val_index, | |||
| general_static_graphs.test_index | |||
| ) | |||
| else: | |||
| return InMemoryDataset([_transform(g) for g in general_static_graphs]) | |||
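| # Conversion sketch (assumes `graph_set` is an InMemoryStaticGraphSet whose | |||
| # graphs carry a graph-level 'label'/'y' tensor): | |||
| dgl_dataset = general_static_graphs_to_dgl_dataset(graph_set) | |||
| dgl_graph, dgl_label = dgl_dataset[0]  # a DGLGraph plus its label tensor | |||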
| @@ -0,0 +1,18 @@ | |||
| import typing as _typing | |||
| from autogl.data import Data, Dataset, InMemoryDataset | |||
| from autogl.data.graph import GeneralStaticGraph | |||
| from autogl.data.graph.utils import conversion | |||
| def general_static_graphs_to_pyg_dataset( | |||
| graphs: _typing.Iterable[GeneralStaticGraph] | |||
| ) -> Dataset[Data]: | |||
| if isinstance(graphs, Dataset): | |||
| return InMemoryDataset( | |||
| [conversion.static_graph_to_pyg_data(g) for g in graphs], | |||
| graphs.train_index, graphs.val_index, graphs.test_index | |||
| ) | |||
| else: | |||
| return InMemoryDataset( | |||
| [conversion.static_graph_to_pyg_data(g) for g in graphs] | |||
| ) | |||
| @@ -0,0 +1,103 @@ | |||
| from .base import BaseFeature | |||
| from .base import BaseFeatureEngineer | |||
| FEATURE_DICT = {} | |||
| def register_feature(name): | |||
| def register_feature_cls(cls): | |||
| if name in FEATURE_DICT: | |||
| raise ValueError( | |||
| "Cannot register duplicate feature engineer ({})".format(name) | |||
| ) | |||
| # if not issubclass(cls, BaseFeatureEngineer): | |||
| if not issubclass(cls, BaseFeature): | |||
| raise ValueError( | |||
| "Trainer ({}: {}) must extend BaseFeatureEngineer".format( | |||
| name, cls.__name__ | |||
| ) | |||
| ) | |||
| FEATURE_DICT[name] = cls | |||
| return cls | |||
| return register_feature_cls | |||
| from .auto_feature import AutoFeatureEngineer | |||
| from .generators import ( | |||
| BaseGenerator, | |||
| GeGraphlet, | |||
| GeEigen, | |||
| GePageRank, | |||
| register_pyg, | |||
| pygfunc, | |||
| PYGGenerator, | |||
| PYGLocalDegreeProfile, | |||
| PYGNormalizeFeatures, | |||
| PYGOneHotDegree, | |||
| ) | |||
| from .selectors import BaseSelector, SeFilterConstant, SeGBDT | |||
| from .graph import ( | |||
| BaseGraph, | |||
| SgNetLSD, | |||
| register_nx, | |||
| NxGraph, | |||
| nxfunc, | |||
| NxLargeCliqueSize, | |||
| NxAverageClusteringApproximate, | |||
| NxDegreeAssortativityCoefficient, | |||
| NxDegreePearsonCorrelationCoefficient, | |||
| NxHasBridge, | |||
| NxGraphCliqueNumber, | |||
| NxGraphNumberOfCliques, | |||
| NxTransitivity, | |||
| NxAverageClustering, | |||
| NxIsConnected, | |||
| NxNumberConnectedComponents, | |||
| NxIsDistanceRegular, | |||
| NxLocalEfficiency, | |||
| NxGlobalEfficiency, | |||
| NxIsEulerian, | |||
| ) | |||
| __all__ = [ | |||
| "BaseFeatureEngineer", | |||
| "AutoFeatureEngineer", | |||
| "BaseFeature", | |||
| "BaseGenerator", | |||
| "GeGraphlet", | |||
| "GeEigen", | |||
| "GePageRank", | |||
| "register_pyg", | |||
| "pygfunc", | |||
| "PYGGenerator", | |||
| "PYGLocalDegreeProfile", | |||
| "PYGNormalizeFeatures", | |||
| "PYGOneHotDegree", | |||
| "BaseSelector", | |||
| "SeFilterConstant", | |||
| "SeGBDT", | |||
| "BaseGraph", | |||
| "SgNetLSD", | |||
| "register_nx", | |||
| "NxGraph", | |||
| "nxfunc", | |||
| "NxLargeCliqueSize", | |||
| "NxAverageClusteringApproximate", | |||
| "NxDegreeAssortativityCoefficient", | |||
| "NxDegreePearsonCorrelationCoefficient", | |||
| "NxHasBridge", | |||
| "NxGraphCliqueNumber", | |||
| "NxGraphNumberOfCliques", | |||
| "NxTransitivity", | |||
| "NxAverageClustering", | |||
| "NxIsConnected", | |||
| "NxNumberConnectedComponents", | |||
| "NxIsDistanceRegular", | |||
| "NxLocalEfficiency", | |||
| "NxGlobalEfficiency", | |||
| "NxIsEulerian", | |||
| ] | |||
| @@ -1,103 +1,35 @@ | |||
| from .base import BaseFeature | |||
| from .base import BaseFeatureEngineer | |||
| FEATURE_DICT = {} | |||
| def register_feature(name): | |||
| def register_feature_cls(cls): | |||
| if name in FEATURE_DICT: | |||
| raise ValueError( | |||
| "Cannot register duplicate feature engineer ({})".format(name) | |||
| ) | |||
| # if not issubclass(cls, BaseFeatureEngineer): | |||
| if not issubclass(cls, BaseFeature): | |||
| raise ValueError( | |||
| "Trainer ({}: {}) must extend BaseFeatureEngineer".format( | |||
| name, cls.__name__ | |||
| ) | |||
| ) | |||
| FEATURE_DICT[name] = cls | |||
| return cls | |||
| return register_feature_cls | |||
| from .auto_feature import AutoFeatureEngineer | |||
| from .generators import ( | |||
| BaseGenerator, | |||
| GeGraphlet, | |||
| GeEigen, | |||
| GePageRank, | |||
| register_pyg, | |||
| pygfunc, | |||
| PYGGenerator, | |||
| PYGLocalDegreeProfile, | |||
| PYGNormalizeFeatures, | |||
| PYGOneHotDegree, | |||
| from ._base_feature_engineer import ( | |||
| BaseFeatureEngineer, BaseFeature | |||
| ) | |||
| from .selectors import BaseSelector, SeFilterConstant, SeGBDT | |||
| from .graph import ( | |||
| BaseGraph, | |||
| SgNetLSD, | |||
| register_nx, | |||
| NxGraph, | |||
| nxfunc, | |||
| NxLargeCliqueSize, | |||
| NxAverageClusteringApproximate, | |||
| NxDegreeAssortativityCoefficient, | |||
| NxDegreePearsonCorrelationCoefficient, | |||
| NxHasBridge, | |||
| NxGraphCliqueNumber, | |||
| NxGraphNumberOfCliques, | |||
| NxTransitivity, | |||
| NxAverageClustering, | |||
| NxIsConnected, | |||
| NxNumberConnectedComponents, | |||
| NxIsDistanceRegular, | |||
| NxLocalEfficiency, | |||
| NxGlobalEfficiency, | |||
| NxIsEulerian, | |||
| from ._feature_engineer_registry import ( | |||
| FeatureEngineerUniversalRegistry, FEATURE_DICT | |||
| ) | |||
| from ._generators import ( | |||
| OneHotFeatureGenerator, | |||
| EigenFeatureGenerator, | |||
| GraphletGenerator, | |||
| PageRankFeatureGenerator, | |||
| LocalDegreeProfileGenerator, | |||
| NormalizeFeatures, | |||
| OneHotDegreeGenerator | |||
| ) | |||
| from ._graph import ( | |||
| NetLSD, | |||
| NXLargeCliqueSize, | |||
| NXDegreeAssortativityCoefficient, | |||
| NXDegreePearsonCorrelationCoefficient, | |||
| NXHasBridges, | |||
| NXGraphCliqueNumber, | |||
| NXGraphNumberOfCliques, | |||
| NXTransitivity, | |||
| NXAverageClustering, | |||
| NXIsConnected, | |||
| NXNumberConnectedComponents, | |||
| NXIsDistanceRegular, | |||
| NXLocalEfficiency, | |||
| NXGlobalEfficiency, | |||
| NXIsEulerian, | |||
| ) | |||
| from ._selectors import ( | |||
| FilterConstant, GBDTFeatureSelector | |||
| ) | |||
| __all__ = [ | |||
| "BaseFeatureEngineer", | |||
| "AutoFeatureEngineer", | |||
| "BaseFeature", | |||
| "BaseGenerator", | |||
| "GeGraphlet", | |||
| "GeEigen", | |||
| "GePageRank", | |||
| "register_pyg", | |||
| "pygfunc", | |||
| "PYGGenerator", | |||
| "PYGLocalDegreeProfile", | |||
| "PYGNormalizeFeatures", | |||
| "PYGOneHotDegree", | |||
| "BaseSelector", | |||
| "SeFilterConstant", | |||
| "SeGBDT", | |||
| "BaseGraph", | |||
| "SgNetLSD", | |||
| "register_nx", | |||
| "NxGraph", | |||
| "nxfunc", | |||
| "NxLargeCliqueSize", | |||
| "NxAverageClusteringApproximate", | |||
| "NxDegreeAssortativityCoefficient", | |||
| "NxDegreePearsonCorrelationCoefficient", | |||
| "NxHasBridge", | |||
| "NxGraphCliqueNumber", | |||
| "NxGraphNumberOfCliques", | |||
| "NxTransitivity", | |||
| "NxAverageClustering", | |||
| "NxIsConnected", | |||
| "NxNumberConnectedComponents", | |||
| "NxIsDistanceRegular", | |||
| "NxLocalEfficiency", | |||
| "NxGlobalEfficiency", | |||
| "NxIsEulerian", | |||
| ] | |||
| @@ -0,0 +1,90 @@ | |||
| import copy | |||
| import logging | |||
| import torch | |||
| import typing as _typing | |||
| from autogl.data import Dataset | |||
| LOGGER = logging.getLogger("FeatureEngineer") | |||
| class _BaseFeatureEngineer: | |||
| def __and__(self, other): | |||
| raise NotImplementedError | |||
| def fit_transform(self, dataset: Dataset, inplace=True) -> Dataset: | |||
| """ | |||
| Fit and transform the dataset, in place or not according to the bool argument ``inplace``. | |||
| """ | |||
| dataset = self.fit(dataset) | |||
| return self.transform(dataset, inplace=inplace) | |||
| def fit(self, dataset: Dataset) -> Dataset: | |||
| raise NotImplementedError | |||
| def transform(self, dataset: Dataset, inplace: bool = True) -> Dataset: | |||
| raise NotImplementedError | |||
| class _ComposedFeatureEngineer(_BaseFeatureEngineer): | |||
| @property | |||
| def fe_components(self) -> _typing.Iterable[_BaseFeatureEngineer]: | |||
| return self.__fe_components | |||
| def __init__(self, feature_engineers: _typing.Iterable[_BaseFeatureEngineer]): | |||
| self.__fe_components: _typing.List[_BaseFeatureEngineer] = [] | |||
| for fe in feature_engineers: | |||
| if isinstance(fe, _ComposedFeatureEngineer): | |||
| self.__fe_components.extend(fe.fe_components) | |||
| else: | |||
| self.__fe_components.append(fe) | |||
| def __and__(self, other: _BaseFeatureEngineer): | |||
| return _ComposedFeatureEngineer((self, other)) | |||
| def fit(self, dataset) -> Dataset: | |||
| for fe in self.fe_components: | |||
| dataset = fe.fit(dataset) | |||
| return dataset | |||
| def transform(self, dataset: Dataset, inplace: bool = True) -> Dataset: | |||
| for fe in self.fe_components: | |||
| dataset = fe.transform(dataset, inplace) | |||
| return dataset | |||
| class BaseFeature(_BaseFeatureEngineer): | |||
def __init__(self, multi_graph: bool = True, subgraph: bool = False):
self._multi_graph: bool = multi_graph
self._subgraph: bool = subgraph
| def __and__(self, other): | |||
| return _ComposedFeatureEngineer((self, other)) | |||
| def _preprocess(self, data: _typing.Any) -> _typing.Any: | |||
| return data | |||
| def _fit(self, data: _typing.Any) -> _typing.Any: | |||
| return data | |||
| def _transform(self, data: _typing.Any) -> _typing.Any: | |||
| return data | |||
| def _postprocess(self, data: _typing.Any) -> _typing.Any: | |||
| return data | |||
| def fit(self, dataset: Dataset) -> Dataset: | |||
| with torch.no_grad(): | |||
| for i, data in enumerate(dataset): | |||
| dataset[i] = self._postprocess(self._transform(self._fit(self._preprocess(data)))) | |||
| return dataset | |||
| def transform(self, dataset: Dataset, inplace: bool = True) -> Dataset: | |||
| if not inplace: | |||
| dataset = copy.deepcopy(dataset) | |||
| with torch.no_grad(): | |||
| for i, data in enumerate(dataset): | |||
| dataset[i] = self._postprocess(self._transform(self._preprocess(data))) | |||
| return dataset | |||
| class BaseFeatureEngineer(BaseFeature): | |||
| ... | |||
| @@ -0,0 +1,62 @@ | |||
| import typing as _typing | |||
| from ._base_feature_engineer import BaseFeatureEngineer | |||
| class _FeatureEngineerUniversalRegistryMetaclass(type): | |||
| def __new__( | |||
| mcs, name: str, bases: _typing.Tuple[type, ...], | |||
| namespace: _typing.Dict[str, _typing.Any] | |||
| ): | |||
| return super(_FeatureEngineerUniversalRegistryMetaclass, mcs).__new__( | |||
| mcs, name, bases, namespace | |||
| ) | |||
| def __init__( | |||
| cls, name: str, bases: _typing.Tuple[type, ...], | |||
| namespace: _typing.Dict[str, _typing.Any] | |||
| ): | |||
| super(_FeatureEngineerUniversalRegistryMetaclass, cls).__init__( | |||
| name, bases, namespace | |||
| ) | |||
| cls._feature_engineer_universal_registry: _typing.MutableMapping[ | |||
| str, _typing.Type[BaseFeatureEngineer] | |||
| ] = {} | |||
| class FeatureEngineerUniversalRegistry(metaclass=_FeatureEngineerUniversalRegistryMetaclass): | |||
| @classmethod | |||
| def register_feature_engineer(cls, name: str) -> _typing.Callable[ | |||
| [_typing.Type[BaseFeatureEngineer]], _typing.Type[BaseFeatureEngineer] | |||
| ]: | |||
| def register_fe( | |||
| fe: _typing.Type[BaseFeatureEngineer] | |||
| ) -> _typing.Type[BaseFeatureEngineer]: | |||
| if name in cls._feature_engineer_universal_registry: | |||
| raise ValueError( | |||
| f"Feature Engineer with name \"{name}\" already exists!" | |||
| ) | |||
elif not issubclass(fe, BaseFeatureEngineer):
raise TypeError(f"{fe} is not a subclass of BaseFeatureEngineer")
| else: | |||
| cls._feature_engineer_universal_registry[name] = fe | |||
| return fe | |||
| return register_fe | |||
| @classmethod | |||
| def get_feature_engineer(cls, name: str) -> _typing.Type[BaseFeatureEngineer]: | |||
| if name in cls._feature_engineer_universal_registry: | |||
| return cls._feature_engineer_universal_registry[name] | |||
| else: | |||
| raise ValueError(f"cannot find feature engineer {name}") | |||
| class _DeprecatedFeatureDict: | |||
| def __contains__(self, name: str) -> bool: | |||
| return name in FeatureEngineerUniversalRegistry._feature_engineer_universal_registry | |||
| def __getitem__(self, name: str) -> _typing.Type[BaseFeatureEngineer]: | |||
| return FeatureEngineerUniversalRegistry.get_feature_engineer(name) | |||
| FEATURE_DICT = _DeprecatedFeatureDict() | |||
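# --- Illustrative usage sketch (added for clarity; not part of the module) ---
# Registering a feature engineer makes it retrievable by name, both through the
# registry and through the deprecated FEATURE_DICT facade. ``IdentityFeature``
# is a hypothetical example class.
#
# @FeatureEngineerUniversalRegistry.register_feature_engineer("identity")
# class IdentityFeature(BaseFeatureEngineer):
#     ...
#
# assert FeatureEngineerUniversalRegistry.get_feature_engineer("identity") is IdentityFeature
# assert FEATURE_DICT["identity"] is IdentityFeature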
| @@ -0,0 +1,19 @@ | |||
| from ._basic import OneHotFeatureGenerator | |||
| from ._eigen import EigenFeatureGenerator | |||
| from ._graphlet import GraphletGenerator | |||
| from ._page_rank import PageRankFeatureGenerator | |||
| from ._pyg import ( | |||
| LocalDegreeProfileGenerator, | |||
| NormalizeFeatures, | |||
| OneHotDegreeGenerator | |||
| ) | |||
| __all__ = [ | |||
| "OneHotFeatureGenerator", | |||
| "EigenFeatureGenerator", | |||
| "GraphletGenerator", | |||
| "PageRankFeatureGenerator", | |||
| "LocalDegreeProfileGenerator", | |||
| "NormalizeFeatures", | |||
| "OneHotDegreeGenerator" | |||
| ] | |||
| @@ -0,0 +1,107 @@ | |||
| import torch | |||
| import typing as _typing | |||
| import autogl | |||
| from autogl.data.graph import GeneralStaticGraph | |||
| from .._base_feature_engineer import BaseFeatureEngineer | |||
| from .._feature_engineer_registry import FeatureEngineerUniversalRegistry | |||
| class BaseFeatureGenerator(BaseFeatureEngineer): | |||
| def __init__(self, override_features: bool = False): | |||
| super(BaseFeatureGenerator, self).__init__() | |||
| if not isinstance(override_features, bool): | |||
| raise TypeError | |||
| else: | |||
| self._override_features: bool = override_features | |||
| def _extract_nodes_feature(self, data: autogl.data.Data) -> torch.Tensor: | |||
| raise NotImplementedError | |||
| def __transform_homogeneous_static_graph( | |||
| self, homogeneous_static_graph: GeneralStaticGraph | |||
| ) -> GeneralStaticGraph: | |||
| if not ( | |||
| homogeneous_static_graph.nodes.is_homogeneous and | |||
| homogeneous_static_graph.edges.is_homogeneous | |||
| ): | |||
| raise ValueError("Provided static graph must be homogeneous") | |||
| if 'x' in homogeneous_static_graph.nodes.data: | |||
| feature_key: _typing.Optional[str] = 'x' | |||
| features: _typing.Optional[torch.Tensor] = ( | |||
| homogeneous_static_graph.nodes.data['x'] | |||
| ) | |||
| elif 'feat' in homogeneous_static_graph.nodes.data: | |||
| feature_key: _typing.Optional[str] = 'feat' | |||
| features: _typing.Optional[torch.Tensor] = ( | |||
| homogeneous_static_graph.nodes.data['feat'] | |||
| ) | |||
| else: | |||
| feature_key: _typing.Optional[str] = None | |||
| features: _typing.Optional[torch.Tensor] = None | |||
| if 'y' in homogeneous_static_graph.nodes.data: | |||
| label: _typing.Optional[torch.Tensor] = ( | |||
| homogeneous_static_graph.nodes.data['y'] | |||
| ) | |||
| elif 'label' in homogeneous_static_graph.nodes.data: | |||
| label: _typing.Optional[torch.Tensor] = ( | |||
| homogeneous_static_graph.nodes.data['label'] | |||
| ) | |||
| else: | |||
| label: _typing.Optional[torch.Tensor] = None | |||
| if ( | |||
| 'edge_weight' in homogeneous_static_graph.edges.data and | |||
| homogeneous_static_graph.edges.data['edge_weight'].dim() == 1 | |||
| ): | |||
| edge_weight: torch.Tensor = ( | |||
| homogeneous_static_graph.edges.data['edge_weight'] | |||
| ) | |||
| else: | |||
| edge_weight: torch.Tensor = torch.ones( | |||
| homogeneous_static_graph.edges.connections.size(1) | |||
| ) | |||
| data = autogl.data.Data( | |||
| edge_index=homogeneous_static_graph.edges.connections, | |||
| x=features, y=label | |||
| ) | |||
| setattr(data, "edge_weight", edge_weight) | |||
| extracted_features: torch.Tensor = self._extract_nodes_feature(data) | |||
| if isinstance(feature_key, str): | |||
| nodes_features: torch.Tensor = ( | |||
| homogeneous_static_graph.nodes.data[feature_key].view(-1, 1) | |||
| if homogeneous_static_graph.nodes.data[feature_key].dim() == 1 | |||
| else homogeneous_static_graph.nodes.data[feature_key] | |||
| ) | |||
| assert extracted_features.size(0) == nodes_features.size(0) | |||
| assert extracted_features.dim() == nodes_features.dim() == 2 | |||
| homogeneous_static_graph.nodes.data[feature_key] = ( | |||
| extracted_features.to(nodes_features.device) | |||
| if self._override_features | |||
| else torch.cat( | |||
| [nodes_features, extracted_features.to(nodes_features.device)], dim=-1 | |||
| ) | |||
| ) | |||
| else: | |||
| if autogl.backend.DependentBackend.is_pyg(): | |||
| homogeneous_static_graph.nodes.data['x'] = extracted_features | |||
| elif autogl.backend.DependentBackend.is_dgl(): | |||
| homogeneous_static_graph.nodes.data['feat'] = extracted_features | |||
| return homogeneous_static_graph | |||
| def _transform(self, data: _typing.Any) -> _typing.Any: | |||
| if isinstance(data, GeneralStaticGraph): | |||
| return self.__transform_homogeneous_static_graph(data) | |||
| else: | |||
| raise NotImplementedError( | |||
| f"Feature Generator only support instance of {GeneralStaticGraph} as provided data" | |||
| ) | |||
| @FeatureEngineerUniversalRegistry.register_feature_engineer("OneHot".lower()) | |||
| class OneHotFeatureGenerator(BaseFeatureGenerator): | |||
| def _extract_nodes_feature(self, data: autogl.data.Data) -> torch.Tensor: | |||
| num_nodes: int = ( | |||
| data.x.size(0) | |||
| if data.x is not None and isinstance(data.x, torch.Tensor) | |||
| else (data.edge_index.max().item() + 1) | |||
| ) | |||
| return torch.eye(num_nodes) | |||
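# --- Illustrative sketch (added for clarity; not part of the module) ---
# A custom generator only needs to override ``_extract_nodes_feature``; the
# base class decides whether the returned tensor replaces or is concatenated
# to the existing node features. A hypothetical constant-feature generator:
#
# class ConstantFeatureGenerator(BaseFeatureGenerator):
#     def _extract_nodes_feature(self, data: autogl.data.Data) -> torch.Tensor:
#         num_nodes = int(data.edge_index.max().item()) + 1
#         return torch.ones(num_nodes, 1)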
| @@ -0,0 +1,92 @@ | |||
| import autogl | |||
| import numpy as np | |||
| from scipy.sparse import csr_matrix | |||
| import scipy.sparse as ssp | |||
| import scipy.sparse.linalg | |||
| import networkx as nx | |||
| import torch | |||
| from ._basic import BaseFeatureGenerator | |||
| from .._feature_engineer_registry import FeatureEngineerUniversalRegistry | |||
| class _Eigen: | |||
| def __init__(self): | |||
| ... | |||
| @classmethod | |||
| def __normalize_adj(cls, adj): | |||
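# Symmetrically normalize the adjacency: with D the diagonal degree matrix,
# this returns D^{-1/2} A^T D^{-1/2}, which equals D^{-1/2} A D^{-1/2} for the
# symmetric adjacency used here.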
| row_sum = np.array(adj.sum(1)) | |||
| d_inv_sqrt = np.power(row_sum, -0.5).flatten() | |||
| d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0 | |||
| d_inv_sqrt = ssp.diags(d_inv_sqrt) | |||
| return adj.dot(d_inv_sqrt).transpose().dot(d_inv_sqrt) | |||
| def __call__(self, adj, d, use_eigenvalues=0, adj_norm=1): | |||
| G = nx.from_scipy_sparse_matrix(adj) | |||
| comp = list(nx.connected_components(G)) | |||
| results = np.zeros((adj.shape[0], d)) | |||
| for i in range(len(comp)): | |||
| node_index = np.array(list(comp[i])) | |||
| d_temp = min(len(node_index) - 2, d) | |||
| if d_temp <= 0: | |||
| continue | |||
| temp_adj = adj[node_index, :][:, node_index].asfptype() | |||
| if adj_norm == 1: | |||
| temp_adj = self.__normalize_adj(temp_adj) | |||
| lamb, X = scipy.sparse.linalg.eigs(temp_adj, d_temp) | |||
| lamb, X = lamb.real, X.real | |||
| temp_order = np.argsort(lamb) | |||
| lamb, X = lamb[temp_order], X[:, temp_order] | |||
for j in range(X.shape[1]):
if np.sum(X[:, j]) < 0:
X[:, j] = -X[:, j]
| if use_eigenvalues == 1: | |||
| X = X.dot(np.diag(np.sqrt(np.absolute(lamb)))) | |||
| elif use_eigenvalues == 2: | |||
| X = X.dot(np.diag(lamb)) | |||
| results[node_index, :d_temp] = X | |||
| return results | |||
| @FeatureEngineerUniversalRegistry.register_feature_engineer("eigen") | |||
| class EigenFeatureGenerator(BaseFeatureGenerator): | |||
| r""" | |||
Concatenate Eigen (EigenGNN) features to the node features.
| Notes | |||
| ----- | |||
| An implementation of [#]_ | |||
| References | |||
| ---------- | |||
| .. [#] Ziwei Zhang, Peng Cui, Jian Pei, Xin Wang, Wenwu Zhu: | |||
| Eigen-GNN: A Graph Structure Preserving Plug-in for GNNs. CoRR abs/2006.04330 (2020) | |||
| https://arxiv.org/abs/2006.04330 | |||
| Parameters | |||
| ---------- | |||
| size : int | |||
| EigenGNN hidden size | |||
| """ | |||
| def __init__(self, size: int = 32): | |||
| super(EigenFeatureGenerator, self).__init__() | |||
| self.__size: int = size | |||
| def _extract_nodes_feature(self, data: autogl.data.Data) -> torch.Tensor: | |||
| edge_index: np.ndarray = data.edge_index.numpy() | |||
| edge_weight: np.ndarray = getattr(data, "edge_weight").numpy() | |||
| num_nodes: int = ( | |||
| data.x.size(0) | |||
| if data.x is not None and isinstance(data.x, torch.Tensor) | |||
| else (data.edge_index.max().item() + 1) | |||
| ) | |||
| adj = csr_matrix( | |||
| (edge_weight, (edge_index[0, :], edge_index[1, :])), | |||
| shape=(num_nodes, num_nodes) | |||
| ) | |||
| if np.max(adj - adj.T) > 1e-5: | |||
| adj = adj + adj.T | |||
| mf = _Eigen() | |||
| features: np.ndarray = mf(adj, self.__size) | |||
return torch.from_numpy(features).to(torch.float)
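# --- Illustrative usage sketch (added for clarity; not part of the module) ---
# ``my_dataset`` is a hypothetical dataset of homogeneous static graphs; each
# node gains 32 eigen-dimensions appended to its feature vector.
#
# fe = EigenFeatureGenerator(size=32)
# my_dataset = fe.fit_transform(my_dataset)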
| @@ -0,0 +1,247 @@ | |||
| import logging | |||
| import numpy as np | |||
| import torch | |||
| from tqdm import tqdm | |||
| import autogl | |||
| from ._basic import BaseFeatureGenerator | |||
| from .._feature_engineer_registry import FeatureEngineerUniversalRegistry | |||
| _LOGGER = logging.getLogger("FE") | |||
| class _Graphlet: | |||
| def __init__(self, data, sample_error=0.1, sample_confidence=0.1): | |||
| self._data = data | |||
| self._init() | |||
| self._sample_error = sample_error | |||
| self._sample_confidence = sample_confidence | |||
| self._dw = int( | |||
| np.ceil( | |||
| 0.5 * (self._sample_error ** -2) * np.log(2 / self._sample_confidence) | |||
| ) | |||
| ) | |||
| _LOGGER.info( | |||
| "sample error {} , confidence {},num {}".format( | |||
| self._sample_error, self._sample_confidence, self._dw | |||
| ) | |||
| ) | |||
| def _init(self): | |||
self._edges = [self._data.edge_index[0], self._data.edge_index[1]]
| self._num_nodes = self._data.x.shape[0] | |||
| self._num_edges = len(self._edges[0]) | |||
| self._neighbours = [[] for _ in range(self._num_nodes)] | |||
| for i in range(len(self._edges[0])): | |||
| u, v = self._edges[0][i], self._edges[1][i] | |||
| self._neighbours[u].append(v) | |||
| _LOGGER.info("nodes {} , edges {}".format(self._num_nodes, self._num_edges)) | |||
| # sorting | |||
| self._node_degrees = np.array([len(x) for x in self._neighbours]) | |||
| self._nodes = np.argsort(self._node_degrees) | |||
| for i in self._nodes: | |||
| self._neighbours[i] = [ | |||
| x | |||
| for _, x in sorted( | |||
| zip(self._node_degrees[self._neighbours[i]], self._neighbours[i]), | |||
| reverse=True, | |||
| ) | |||
| ] | |||
| self._neighbours = [np.array(x) for x in self._neighbours] | |||
| def _get_gdv(self, v, u): | |||
| if self._node_degrees[v] >= self._node_degrees[u]: | |||
| pass | |||
| else: | |||
| u, v = v, u | |||
| Sv, Su, Te = set(), set(), set() | |||
| sigma1, sigma2 = 0, 0 | |||
| nb = self._neighbours | |||
| N = self._num_nodes | |||
| M = self._num_edges | |||
| phi = np.zeros(self._num_nodes, dtype=int) | |||
| c1, c2, c3, c4 = 1, 2, 3, 4 | |||
| x = np.zeros(16, dtype=int) | |||
| # p1 | |||
| for w in nb[v]: | |||
| if w != u: | |||
| Sv.add(w) | |||
| phi[w] = c1 | |||
| # p2 | |||
| for w in nb[u]: | |||
| if w != v: | |||
| if phi[w] == c1: | |||
| Te.add(w) | |||
| phi[w] = c3 | |||
| Sv.remove(w) | |||
| else: | |||
| Su.add(w) | |||
| phi[w] = c2 | |||
| # p3 | |||
| for w in Te: | |||
| for r in nb[w]: | |||
| if phi[r] == c3: | |||
| x[5] += 1 | |||
| phi[w] = c4 | |||
| sigma2 = sigma2 + len(nb[w]) - 2 | |||
| # p4 | |||
| for w in Su: | |||
| for r in nb[w]: | |||
| if phi[r] == c1: | |||
| x[8] += 1 | |||
| if phi[r] == c2: | |||
| x[7] += 1 | |||
| if phi[r] == c4: | |||
| sigma1 += 1 | |||
| phi[w] = 0 | |||
| sigma2 = sigma2 + len(nb[w]) - 1 | |||
| # p5 | |||
| for w in Sv: | |||
| for r in nb[w]: | |||
| if phi[r] == c1: | |||
| x[7] += 1 | |||
| if phi[r] == c4: | |||
| sigma1 += 1 | |||
| phi[w] = 0 | |||
| sigma2 = sigma2 + len(nb[w]) - 1 | |||
| lsv, lsu, lte, du, dv = len(Sv), len(Su), len(Te), len(nb[u]), len(nb[v]) | |||
| # 3-graphlet | |||
| x[1] = lte | |||
| x[2] = du + dv - 2 - 2 * x[1] | |||
| x[3] = N - x[2] - x[1] - 2 | |||
| x[4] = N * (N - 1) * (N - 2) / 6 - (x[1] + x[2] + x[3]) | |||
| # 4 connected graphlets | |||
| x[6] = x[1] * (x[1] - 1) / 2 - x[5] | |||
| x[10] = lsv * lsu - x[8] | |||
| x[9] = lsv * (lsv - 1) / 2 + lsu * (lsu - 1) / 2 - x[7] | |||
| # 4 disconnected graphlets | |||
| t1 = N - (lte + lsu + lsv + 2) | |||
| x[11] = x[1] * t1 | |||
| x[12] = M - (du + dv - 1) - (sigma2 - sigma1 - x[5] - x[8] - x[7]) | |||
| x[13] = (lsu + lsv) * t1 | |||
| x[14] = t1 * (t1 - 1) / 2 - x[12] | |||
| x[15] = N * (N - 1) * (N - 2) * (N - 3) / 24 - np.sum(x[5:15]) | |||
| return x | |||
| def _get_gdv_sample(self, v, u): | |||
| if self._node_degrees[v] >= self._node_degrees[u]: | |||
| pass | |||
| else: | |||
| u, v = v, u | |||
| Sv = set() | |||
| sigma1, sigma2 = 0, 0 | |||
| nb = self._neighbours | |||
| N = self._num_nodes | |||
| M = self._num_edges | |||
| phi = np.zeros(self._num_nodes, dtype=int) | |||
| c1, c2, c3, c4 = 1, 2, 3, 4 | |||
| x = np.zeros(16) | |||
| dw = self._dw | |||
| # p1 | |||
| Sv = set(nb[v][nb[v] != u]) | |||
| phi[list(Sv)] = c1 | |||
| # p2 | |||
p2w = nb[u][nb[u] != v]
| p2w1 = p2w[phi[p2w] == c1] | |||
| p2w2 = p2w[phi[p2w] != c1] | |||
| Te = p2w1 | |||
| phi[p2w1] = c3 | |||
| Sv -= set(list(p2w1)) | |||
| Su = p2w2 | |||
| phi[p2w2] = c2 | |||
| # p3 | |||
| for w in Te: | |||
| if dw >= len(nb[w]): | |||
| region = nb[w] | |||
| inc = 1 | |||
| else: | |||
| region = np.random.choice(nb[w], dw, replace=False) | |||
| inc = self._node_degrees[w] / dw | |||
| phir = phi[region] | |||
| x[5] += inc * np.sum(phir == c3) | |||
| phi[w] = c4 | |||
| sigma2 = sigma2 + len(nb[w]) - 2 | |||
| # p4 | |||
| for w in Su: | |||
| if dw >= len(nb[w]): | |||
| region = nb[w] | |||
| inc = 1 | |||
| else: | |||
| region = np.random.choice(nb[w], dw, replace=False) | |||
| inc = self._node_degrees[w] / dw | |||
| phir = phi[region] | |||
| x[8] += inc * np.sum(phir == c1) | |||
| x[7] += inc * np.sum(phir == c2) | |||
| sigma1 += inc * np.sum(phir == c4) | |||
| phi[w] = 0 | |||
| sigma2 = sigma2 + len(nb[w]) - 1 | |||
| # p5 | |||
| for w in Sv: | |||
| if dw >= len(nb[w]): | |||
| region = nb[w] | |||
| inc = 1 | |||
| else: | |||
| region = np.random.choice(nb[w], dw, replace=False) | |||
| inc = self._node_degrees[w] / dw | |||
| phir = phi[region] | |||
| x[7] += inc * np.sum(phir == c1) | |||
| sigma1 += inc * np.sum(phir == c4) | |||
| phi[w] = 0 | |||
| sigma2 = sigma2 + len(nb[w]) - 1 | |||
| lsv, lsu, lte, du, dv = len(Sv), len(Su), len(Te), len(nb[u]), len(nb[v]) | |||
| # 3-graphlet | |||
| x[1] = lte | |||
| x[2] = du + dv - 2 - 2 * x[1] | |||
| x[3] = N - x[2] - x[1] - 2 | |||
| x[4] = N * (N - 1) * (N - 2) / 6 - (x[1] + x[2] + x[3]) | |||
| # 4 connected graphlets | |||
| x[6] = x[1] * (x[1] - 1) / 2 - x[5] | |||
| x[10] = lsv * lsu - x[8] | |||
| x[9] = lsv * (lsv - 1) / 2 + lsu * (lsu - 1) / 2 - x[7] | |||
| # 4 disconnected graphlets | |||
| t1 = N - (lte + lsu + lsv + 2) | |||
| x[11] = x[1] * t1 | |||
| x[12] = M - (du + dv - 1) - (sigma2 - sigma1 - x[5] - x[8] - x[7]) | |||
| x[13] = (lsu + lsv) * t1 | |||
| x[14] = t1 * (t1 - 1) / 2 - x[12] | |||
| x[15] = N * (N - 1) * (N - 2) * (N - 3) / 24 - np.sum(x[5:15]) | |||
| return x | |||
| def get_gdvs(self, sample=True): | |||
| res = np.zeros((self._num_nodes, 15)) | |||
| for u in tqdm(range(self._num_nodes)): | |||
| vs = self._neighbours[u] | |||
| if len(vs) != 0: | |||
| gdvs = [] | |||
| for v in tqdm(vs, disable=len(vs) < 100): | |||
| if sample: | |||
| gdvs.append(self._get_gdv_sample(u, v)) | |||
| else: | |||
| gdvs.append(self._get_gdv(u, v)) | |||
| res[u, :] = np.mean(gdvs, axis=0)[1:] | |||
| return res | |||
| @FeatureEngineerUniversalRegistry.register_feature_engineer("graph" + "let") | |||
| class GraphletGenerator(BaseFeatureGenerator): | |||
| r"""generate local graphlet numbers as features. The implementation refers to [#]_ . | |||
| References | |||
| ---------- | |||
| .. [#] Ahmed, N. K., Willke, T. L., & Rossi, R. A. (2016). | |||
| Estimation of local subgraph counts. Proceedings - 2016 IEEE International Conference on Big Data, Big Data 2016, 586–595. | |||
| https://doi.org/10.1109/BigData.2016.7840651 | |||
| """ | |||
| def _extract_nodes_feature(self, data: autogl.data.Data) -> torch.Tensor: | |||
| result: np.ndarray = _Graphlet(data).get_gdvs() | |||
return torch.from_numpy(result).to(torch.float)
| @@ -0,0 +1,29 @@ | |||
| import numpy as np | |||
| import networkx as nx | |||
| import torch | |||
| import autogl | |||
| from ._basic import BaseFeatureGenerator | |||
| from .._feature_engineer_registry import FeatureEngineerUniversalRegistry | |||
| @FeatureEngineerUniversalRegistry.register_feature_engineer("PageRank".lower()) | |||
| class PageRankFeatureGenerator(BaseFeatureGenerator): | |||
| def _extract_nodes_feature(self, data: autogl.data.Data) -> torch.Tensor: | |||
| edge_weight = getattr(data, "edge_weight").tolist() | |||
| g = nx.DiGraph() | |||
| g.add_weighted_edges_from( | |||
| [ | |||
| (u, v, edge_weight[i]) | |||
| for i, (u, v) in enumerate(data.edge_index.t().tolist()) | |||
| ] | |||
| ) | |||
| page_rank = nx.pagerank(g) | |||
| num_nodes: int = ( | |||
| data.x.size(0) | |||
| if data.x is not None and isinstance(data.x, torch.Tensor) | |||
| else (data.edge_index.max().item() + 1) | |||
| ) | |||
pr = np.zeros(num_nodes)
for i, v in page_rank.items():
pr[i] = v
# return one column per node so the base class can concatenate it to existing features
return torch.from_numpy(pr).view(-1, 1).to(torch.float)
| @@ -0,0 +1,81 @@ | |||
| import torch.nn.functional | |||
| import autogl | |||
| from ._basic import BaseFeatureGenerator | |||
| from ._pyg_impl import degree, scatter_min, scatter_max, scatter_mean, scatter_std | |||
| from .._feature_engineer_registry import FeatureEngineerUniversalRegistry | |||
| @FeatureEngineerUniversalRegistry.register_feature_engineer("LocalDegreeProfile") | |||
| class LocalDegreeProfileGenerator(BaseFeatureGenerator): | |||
| r"""Appends the Local Degree Profile (LDP) from the `"A Simple yet | |||
| Effective Baseline for Non-attribute Graph Classification" | |||
| <https://arxiv.org/abs/1811.03508>`_ paper | |||
| .. math:: | |||
| \mathbf{x}_i = \mathbf{x}_i \, \Vert \, (\deg(i), \min(DN(i)), | |||
| \max(DN(i)), \textrm{mean}(DN(i)), \textrm{std}(DN(i))) | |||
| to the node features, where :math:`DN(i) = \{ \deg(j) \mid j \in | |||
| \mathcal{N}(i) \}`. | |||
| """ | |||
| def _extract_nodes_feature(self, data: autogl.data.Data) -> torch.Tensor: | |||
| row, col = data.edge_index | |||
| if data.x is not None and isinstance(data.x, torch.Tensor): | |||
| N = data.x.size(0) | |||
| else: | |||
| N = (torch.max(data.edge_index).item() + 1) | |||
| deg = degree(row, N, dtype=torch.float) | |||
| deg_col = deg[col] | |||
| min_deg, _ = scatter_min(deg_col, row, dim_size=N) | |||
| min_deg[min_deg > 10000] = 0 | |||
| max_deg, _ = scatter_max(deg_col, row, dim_size=N) | |||
| max_deg[max_deg < -10000] = 0 | |||
| mean_deg = scatter_mean(deg_col, row, dim_size=N) | |||
| std_deg = scatter_std(deg_col, row, dim_size=N) | |||
| x = torch.stack([deg, min_deg, max_deg, mean_deg, std_deg], dim=1) | |||
| return x | |||
| @FeatureEngineerUniversalRegistry.register_feature_engineer("NormalizeFeatures") | |||
| class NormalizeFeatures(BaseFeatureGenerator): | |||
| def __init__(self): | |||
| super(NormalizeFeatures, self).__init__(override_features=True) | |||
| def _extract_nodes_feature(self, data: autogl.data.Data) -> torch.Tensor: | |||
| if data.x is not None and isinstance(data.x, torch.Tensor): | |||
| data.x.div_(data.x.sum(dim=-1, keepdim=True).clamp_(min=1.)) | |||
| return data.x | |||
| @FeatureEngineerUniversalRegistry.register_feature_engineer("OneHotDegree") | |||
| class OneHotDegreeGenerator(BaseFeatureGenerator): | |||
| r"""Adds the node degree as one hot encodings to the node features. | |||
| Args: | |||
| max_degree (int): Maximum degree. | |||
| in_degree (bool, optional): If set to :obj:`True`, will compute the | |||
| in-degree of nodes instead of the out-degree. | |||
| (default: :obj:`False`) | |||
| cat (bool, optional): Concat node degrees to node features instead | |||
| of replacing them. (default: :obj:`True`) | |||
| """ | |||
def __init__(
self, max_degree: int = 1000,
in_degree: bool = False, cat: bool = True
):
self.__max_degree: int = max_degree
self.__in_degree: bool = in_degree
self.__cat: bool = cat
# ``cat=False`` replaces the existing features instead of concatenating to them
super(OneHotDegreeGenerator, self).__init__(override_features=not cat)
| def _extract_nodes_feature(self, data: autogl.data.Data) -> torch.Tensor: | |||
| idx, x = data.edge_index[1 if self.__in_degree else 0], data.x | |||
| deg = degree(idx, data.num_nodes, dtype=torch.long) | |||
| deg = torch.nn.functional.one_hot( | |||
| deg, num_classes=self.__max_degree + 1 | |||
| ).to(torch.float) | |||
| return deg | |||
| @@ -0,0 +1,234 @@ | |||
| import torch | |||
| import typing as _typing | |||
| from typing import Optional, Tuple | |||
| def degree(index, num_nodes: _typing.Optional[int] = None, | |||
| dtype: _typing.Optional[torch.dtype] = None): | |||
| r"""Computes the (unweighted) degree of a given one-dimensional index | |||
| tensor. | |||
| Args: | |||
| index (LongTensor): Index tensor. | |||
| num_nodes (int, optional): The number of nodes, *i.e.* | |||
| :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`) | |||
| dtype (:obj:`torch.dtype`, optional): The desired data type of the | |||
| returned tensor. | |||
| :rtype: :class:`Tensor` | |||
| """ | |||
| def maybe_num_nodes(edge_index, __num_nodes=None): | |||
| if __num_nodes is not None: | |||
| return __num_nodes | |||
| elif isinstance(edge_index, torch.Tensor): | |||
| return int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0 | |||
| else: | |||
| return max(edge_index.size(0), edge_index.size(1)) | |||
| N = maybe_num_nodes(index, num_nodes) | |||
| out = torch.zeros((N,), dtype=dtype, device=index.device) | |||
| one = torch.ones((index.size(0),), dtype=out.dtype, device=out.device) | |||
| return out.scatter_add_(0, index, one) | |||
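# Example: degree(torch.tensor([0, 1, 1, 2]), num_nodes=4, dtype=torch.float)
# returns tensor([1., 2., 1., 0.]).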
| def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int): | |||
| if dim < 0: | |||
| dim = other.dim() + dim | |||
| if src.dim() == 1: | |||
| for _ in range(0, dim): | |||
| src = src.unsqueeze(0) | |||
| for _ in range(src.dim(), other.dim()): | |||
| src = src.unsqueeze(-1) | |||
| src = src.expand_as(other) | |||
| return src | |||
| def scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1, | |||
| out: Optional[torch.Tensor] = None, | |||
| dim_size: Optional[int] = None) -> torch.Tensor: | |||
| index = broadcast(index, src, dim) | |||
| if out is None: | |||
| size = list(src.size()) | |||
| if dim_size is not None: | |||
| size[dim] = dim_size | |||
| elif index.numel() == 0: | |||
| size[dim] = 0 | |||
| else: | |||
| size[dim] = int(index.max()) + 1 | |||
| out = torch.zeros(size, dtype=src.dtype, device=src.device) | |||
| return out.scatter_add_(dim, index, src) | |||
| else: | |||
| return out.scatter_add_(dim, index, src) | |||
| def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1, | |||
| out: Optional[torch.Tensor] = None, | |||
| dim_size: Optional[int] = None) -> torch.Tensor: | |||
| return scatter_sum(src, index, dim, out, dim_size) | |||
| def scatter_mul(src: torch.Tensor, index: torch.Tensor, dim: int = -1, | |||
| out: Optional[torch.Tensor] = None, | |||
| dim_size: Optional[int] = None) -> torch.Tensor: | |||
| return torch.ops.torch_scatter.scatter_mul(src, index, dim, out, dim_size) | |||
| def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1, | |||
| out: Optional[torch.Tensor] = None, | |||
| dim_size: Optional[int] = None) -> torch.Tensor: | |||
| out = scatter_sum(src, index, dim, out, dim_size) | |||
| dim_size = out.size(dim) | |||
| index_dim = dim | |||
| if index_dim < 0: | |||
| index_dim = index_dim + src.dim() | |||
| if index.dim() <= index_dim: | |||
| index_dim = index.dim() - 1 | |||
| ones = torch.ones(index.size(), dtype=src.dtype, device=src.device) | |||
| count = scatter_sum(ones, index, index_dim, None, dim_size) | |||
| count[count < 1] = 1 | |||
| count = broadcast(count, out, dim) | |||
| if out.is_floating_point(): | |||
| out.true_divide_(count) | |||
| else: | |||
| out.floor_divide_(count) | |||
| return out | |||
| def scatter_min( | |||
| src: torch.Tensor, index: torch.Tensor, dim: int = -1, | |||
| out: Optional[torch.Tensor] = None, | |||
| dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: | |||
| return torch.ops.torch_scatter.scatter_min(src, index, dim, out, dim_size) | |||
| def scatter_max( | |||
| src: torch.Tensor, index: torch.Tensor, dim: int = -1, | |||
| out: Optional[torch.Tensor] = None, | |||
| dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: | |||
| return torch.ops.torch_scatter.scatter_max(src, index, dim, out, dim_size) | |||
| def scatter_std(src: torch.Tensor, index: torch.Tensor, dim: int = -1, | |||
| out: Optional[torch.Tensor] = None, | |||
| dim_size: Optional[int] = None, | |||
| unbiased: bool = True) -> torch.Tensor: | |||
| if out is not None: | |||
| dim_size = out.size(dim) | |||
| if dim < 0: | |||
| dim = src.dim() + dim | |||
| count_dim = dim | |||
| if index.dim() <= dim: | |||
| count_dim = index.dim() - 1 | |||
| ones = torch.ones(index.size(), dtype=src.dtype, device=src.device) | |||
| count = scatter_sum(ones, index, count_dim, dim_size=dim_size) | |||
| index = broadcast(index, src, dim) | |||
| tmp = scatter_sum(src, index, dim, dim_size=dim_size) | |||
| count = broadcast(count, tmp, dim).clamp(1) | |||
| mean = tmp.div(count) | |||
| var = (src - mean.gather(dim, index)) | |||
| var = var * var | |||
| out = scatter_sum(var, index, dim, out, dim_size) | |||
| if unbiased: | |||
| count = count.sub(1).clamp_(1) | |||
| out = out.div(count + 1e-6).sqrt() | |||
| return out | |||
| def scatter(src: torch.Tensor, index: torch.Tensor, dim: int = -1, | |||
| out: Optional[torch.Tensor] = None, dim_size: Optional[int] = None, | |||
| reduce: str = "sum") -> torch.Tensor: | |||
| r""" | |||
| | | |||
| .. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/ | |||
| master/docs/source/_figures/add.svg?sanitize=true | |||
| :align: center | |||
| :width: 400px | |||
| | | |||
| Reduces all values from the :attr:`src` tensor into :attr:`out` at the | |||
| indices specified in the :attr:`index` tensor along a given axis | |||
| :attr:`dim`. | |||
| For each value in :attr:`src`, its output index is specified by its index | |||
| in :attr:`src` for dimensions outside of :attr:`dim` and by the | |||
| corresponding value in :attr:`index` for dimension :attr:`dim`. | |||
| The applied reduction is defined via the :attr:`reduce` argument. | |||
| Formally, if :attr:`src` and :attr:`index` are :math:`n`-dimensional | |||
| tensors with size :math:`(x_0, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1})` | |||
| and :attr:`dim` = `i`, then :attr:`out` must be an :math:`n`-dimensional | |||
| tensor with size :math:`(x_0, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})`. | |||
| Moreover, the values of :attr:`index` must be between :math:`0` and | |||
| :math:`y - 1`, although no specific ordering of indices is required. | |||
| The :attr:`index` tensor supports broadcasting in case its dimensions do | |||
| not match with :attr:`src`. | |||
| For one-dimensional tensors with :obj:`reduce="sum"`, the operation | |||
| computes | |||
| .. math:: | |||
| \mathrm{out}_i = \mathrm{out}_i + \sum_j~\mathrm{src}_j | |||
| where :math:`\sum_j` is over :math:`j` such that | |||
| :math:`\mathrm{index}_j = i`. | |||
| .. note:: | |||
| This operation is implemented via atomic operations on the GPU and is | |||
| therefore **non-deterministic** since the order of parallel operations | |||
| to the same value is undetermined. | |||
| For floating-point variables, this results in a source of variance in | |||
| the result. | |||
| :param src: The source tensor. | |||
| :param index: The indices of elements to scatter. | |||
| :param dim: The axis along which to index. (default: :obj:`-1`) | |||
| :param out: The destination tensor. | |||
| :param dim_size: If :attr:`out` is not given, automatically create output | |||
| with size :attr:`dim_size` at dimension :attr:`dim`. | |||
| If :attr:`dim_size` is not given, a minimal sized output tensor | |||
| according to :obj:`index.max() + 1` is returned. | |||
| :param reduce: The reduce operation (:obj:`"sum"`, :obj:`"mul"`, | |||
| :obj:`"mean"`, :obj:`"min"` or :obj:`"max"`). (default: :obj:`"sum"`) | |||
| :rtype: :class:`Tensor` | |||
| .. code-block:: python | |||
| from torch_scatter import scatter | |||
| src = torch.randn(10, 6, 64) | |||
| index = torch.tensor([0, 1, 0, 1, 2, 1]) | |||
| # Broadcasting in the first and last dim. | |||
| out = scatter(src, index, dim=1, reduce="sum") | |||
| print(out.size()) | |||
| .. code-block:: | |||
| torch.Size([10, 3, 64]) | |||
| """ | |||
| if reduce == 'sum' or reduce == 'add': | |||
| return scatter_sum(src, index, dim, out, dim_size) | |||
| if reduce == 'mul': | |||
| return scatter_mul(src, index, dim, out, dim_size) | |||
| elif reduce == 'mean': | |||
| return scatter_mean(src, index, dim, out, dim_size) | |||
| elif reduce == 'min': | |||
| return scatter_min(src, index, dim, out, dim_size)[0] | |||
| elif reduce == 'max': | |||
| return scatter_max(src, index, dim, out, dim_size)[0] | |||
else:
raise ValueError(f"Unsupported reduce operation '{reduce}'")
| @@ -0,0 +1,17 @@ | |||
| from ._netlsd import NetLSD | |||
| from ._networkx import ( | |||
| NXLargeCliqueSize, | |||
| NXDegreeAssortativityCoefficient, | |||
| NXDegreePearsonCorrelationCoefficient, | |||
| NXHasBridges, | |||
| NXGraphCliqueNumber, | |||
| NXGraphNumberOfCliques, | |||
| NXTransitivity, | |||
| NXAverageClustering, | |||
| NXIsConnected, | |||
| NXNumberConnectedComponents, | |||
| NXIsDistanceRegular, | |||
| NXLocalEfficiency, | |||
| NXGlobalEfficiency, | |||
| NXIsEulerian, | |||
| ) | |||
| @@ -0,0 +1,82 @@ | |||
| import netlsd | |||
| import networkx | |||
| import torch | |||
| from autogl.data.graph import GeneralStaticGraph | |||
| from autogl.data.graph.utils import conversion | |||
| from .._base_feature_engineer import BaseFeatureEngineer | |||
| from .._feature_engineer_registry import FeatureEngineerUniversalRegistry | |||
| @FeatureEngineerUniversalRegistry.register_feature_engineer("NetLSD".lower()) | |||
| class NetLSD(BaseFeatureEngineer): | |||
| r""" | |||
| Notes | |||
| ----- | |||
A graph-level feature generation method; a simple wrapper of NetLSD [#]_.
| References | |||
| ---------- | |||
| .. [#] A. Tsitsulin, D. Mottin, P. Karras, A. Bronstein, and E. Müller, “NetLSD: Hearing the shape of a graph,” | |||
| Proc. ACM SIGKDD Int. Conf. Knowl. Discov. Data Min., pp. 2347–2356, 2018. | |||
| """ | |||
| def __init__(self, *args, **kwargs): | |||
| self.__args = args | |||
| self.__kwargs = kwargs | |||
| super(NetLSD, self).__init__() | |||
| def __extract(self, nx_g: networkx.Graph) -> torch.Tensor: | |||
| return torch.tensor(netlsd.heat(nx_g, *self.__args, **self.__kwargs)).view(-1) | |||
| def __transform_homogeneous_static_graph( | |||
| self, homogeneous_static_graph: GeneralStaticGraph | |||
| ) -> GeneralStaticGraph: | |||
| if not ( | |||
| homogeneous_static_graph.nodes.is_homogeneous and | |||
| homogeneous_static_graph.edges.is_homogeneous | |||
| ): | |||
| raise ValueError("Provided static graph must be homogeneous") | |||
dsc: torch.Tensor = self.__extract(
conversion.HomogeneousStaticGraphToNetworkX(to_undirected=True)(
homogeneous_static_graph
)
)
| if 'gf' in homogeneous_static_graph.data: | |||
| gf = homogeneous_static_graph.data['gf'].view(-1) | |||
| homogeneous_static_graph.data['gf'] = torch.cat([gf, dsc]) | |||
| else: | |||
| homogeneous_static_graph.data['gf'] = dsc | |||
| return homogeneous_static_graph | |||
| @classmethod | |||
| def __edge_index_to_nx_graph(cls, edge_index: torch.Tensor) -> networkx.Graph: | |||
| g: networkx.Graph = networkx.Graph() | |||
| for u, v in edge_index.t().tolist(): | |||
| if u == v: | |||
| continue | |||
| else: | |||
| g.add_edge(u, v) | |||
| return g | |||
| def __transform_data(self, data): | |||
if not (
hasattr(data, "edge_index") and
isinstance(data.edge_index, torch.Tensor) and
data.edge_index.dim() == data.edge_index.size(0) == 2 and
data.edge_index.dtype == torch.long
):
| raise TypeError("Unsupported provided data") | |||
| dsc: torch.Tensor = self.__extract(self.__edge_index_to_nx_graph(data.edge_index)) | |||
| if hasattr(data, 'gf') and isinstance(data.gf, torch.Tensor): | |||
| gf = data.gf.view(-1) | |||
| data.gf = torch.cat([gf, dsc]) | |||
| else: | |||
| data.gf = dsc | |||
| return data | |||
| def _transform(self, data): | |||
| if isinstance(data, GeneralStaticGraph): | |||
| return self.__transform_homogeneous_static_graph(data) | |||
| else: | |||
| return self.__transform_data(data) | |||
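# --- Illustrative usage sketch (added for clarity; not part of the module) ---
# NetLSD appends a heat-trace signature to the graph-level feature field 'gf';
# extra positional/keyword arguments are forwarded verbatim to netlsd.heat.
# ``my_dataset`` is a hypothetical dataset of graphs.
#
# fe = NetLSD()
# my_dataset = fe.fit_transform(my_dataset)  # each graph gains or extends 'gf'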
| @@ -0,0 +1,176 @@ | |||
| import torch | |||
| import typing as _typing | |||
| import networkx | |||
| from networkx.algorithms.euler import is_eulerian | |||
| from networkx.algorithms.efficiency_measures import global_efficiency | |||
| from networkx.algorithms.efficiency_measures import local_efficiency | |||
| from networkx.algorithms.distance_regular import is_distance_regular | |||
| from networkx.algorithms.components import number_connected_components | |||
| from networkx.algorithms.components import is_connected | |||
from networkx.algorithms.cluster import average_clustering
from networkx.algorithms.cluster import transitivity
| from networkx.algorithms.clique import graph_number_of_cliques | |||
| from networkx.algorithms.clique import graph_clique_number | |||
| from networkx.algorithms.bridges import has_bridges | |||
| from networkx.algorithms.assortativity import degree_pearson_correlation_coefficient | |||
| from networkx.algorithms.assortativity import degree_assortativity_coefficient | |||
from networkx.algorithms.approximation.clustering_coefficient import average_clustering as approximate_average_clustering
| from networkx.algorithms.approximation.clique import large_clique_size | |||
| from autogl.data.graph import GeneralStaticGraph | |||
| from autogl.data.graph.utils import conversion | |||
| from .._base_feature_engineer import BaseFeatureEngineer | |||
| from .._feature_engineer_registry import FeatureEngineerUniversalRegistry | |||
| class _NetworkXGraphFeatureEngineer(BaseFeatureEngineer): | |||
| def __init__(self, feature_extractor: _typing.Callable[[networkx.Graph], _typing.Any]): | |||
| self.__feature_extractor: _typing.Callable[[networkx.Graph], _typing.Any] = feature_extractor | |||
| super(_NetworkXGraphFeatureEngineer, self).__init__() | |||
| def __transform_homogeneous_static_graph( | |||
| self, homogeneous_static_graph: GeneralStaticGraph | |||
| ) -> GeneralStaticGraph: | |||
| if not ( | |||
| homogeneous_static_graph.nodes.is_homogeneous and | |||
| homogeneous_static_graph.edges.is_homogeneous | |||
| ): | |||
| raise ValueError("Provided static graph must be homogeneous") | |||
| dsc: torch.Tensor = torch.tensor( | |||
| [ | |||
| self.__feature_extractor( | |||
| conversion.HomogeneousStaticGraphToNetworkX(to_undirected=True)(homogeneous_static_graph) | |||
| ) | |||
| ] | |||
| ).view(-1) | |||
| if 'gf' in homogeneous_static_graph.data: | |||
| gf = homogeneous_static_graph.data['gf'].view(-1) | |||
| homogeneous_static_graph.data['gf'] = torch.cat([gf, dsc]) | |||
| else: | |||
| homogeneous_static_graph.data['gf'] = dsc | |||
| return homogeneous_static_graph | |||
| @classmethod | |||
| def __edge_index_to_nx_graph(cls, edge_index: torch.Tensor) -> networkx.Graph: | |||
| g: networkx.Graph = networkx.Graph() | |||
| for u, v in edge_index.t().tolist(): | |||
| if u == v: | |||
| continue | |||
| else: | |||
| g.add_edge(u, v) | |||
| return g | |||
| def __transform_data(self, data): | |||
if not (
hasattr(data, "edge_index") and
isinstance(data.edge_index, torch.Tensor) and
data.edge_index.dim() == data.edge_index.size(0) == 2 and
data.edge_index.dtype == torch.long
):
| raise TypeError("Unsupported provided data") | |||
| dsc: torch.Tensor = torch.tensor( | |||
| [self.__feature_extractor(self.__edge_index_to_nx_graph(data.edge_index))] | |||
| ).view(-1) | |||
| if hasattr(data, 'gf') and isinstance(data.gf, torch.Tensor): | |||
| gf = data.gf.view(-1) | |||
| data.gf = torch.cat([gf, dsc]) | |||
| else: | |||
| data.gf = dsc | |||
| return data | |||
| def _transform(self, data): | |||
| if isinstance(data, GeneralStaticGraph): | |||
| return self.__transform_homogeneous_static_graph(data) | |||
| else: | |||
| return self.__transform_data(data) | |||
| @FeatureEngineerUniversalRegistry.register_feature_engineer("NXLargeCliqueSize") | |||
| class NXLargeCliqueSize(_NetworkXGraphFeatureEngineer): | |||
| def __init__(self): | |||
| super(NXLargeCliqueSize, self).__init__(large_clique_size) | |||
| # @FeatureEngineerUniversalRegistry.register_feature_engineer("NXAverageClusteringApproximate") | |||
| # class NXAverageClusteringApproximate(_NetworkXGraphFeatureEngineer): | |||
| # def __init__(self): | |||
| # super(NXAverageClusteringApproximate, self).__init__(average_clustering) | |||
| @FeatureEngineerUniversalRegistry.register_feature_engineer("NXDegreeAssortativityCoefficient") | |||
| class NXDegreeAssortativityCoefficient(_NetworkXGraphFeatureEngineer): | |||
| def __init__(self): | |||
| super(NXDegreeAssortativityCoefficient, self).__init__(degree_assortativity_coefficient) | |||
| @FeatureEngineerUniversalRegistry.register_feature_engineer("NXDegreePearsonCorrelationCoefficient") | |||
| class NXDegreePearsonCorrelationCoefficient(_NetworkXGraphFeatureEngineer): | |||
| def __init__(self): | |||
| super(NXDegreePearsonCorrelationCoefficient, self).__init__(degree_pearson_correlation_coefficient) | |||
| @FeatureEngineerUniversalRegistry.register_feature_engineer("NXHasBridges") | |||
| class NXHasBridges(_NetworkXGraphFeatureEngineer): | |||
| def __init__(self): | |||
| super(NXHasBridges, self).__init__(has_bridges) | |||
| @FeatureEngineerUniversalRegistry.register_feature_engineer("NXGraphCliqueNumber") | |||
| class NXGraphCliqueNumber(_NetworkXGraphFeatureEngineer): | |||
| def __init__(self): | |||
| super(NXGraphCliqueNumber, self).__init__(graph_clique_number) | |||
| @FeatureEngineerUniversalRegistry.register_feature_engineer("NXGraphNumberOfCliques") | |||
| class NXGraphNumberOfCliques(_NetworkXGraphFeatureEngineer): | |||
| def __init__(self): | |||
| super(NXGraphNumberOfCliques, self).__init__(graph_number_of_cliques) | |||
| @FeatureEngineerUniversalRegistry.register_feature_engineer("NXTransitivity") | |||
| class NXTransitivity(_NetworkXGraphFeatureEngineer): | |||
| def __init__(self): | |||
| super(NXTransitivity, self).__init__(transitivity) | |||
| @FeatureEngineerUniversalRegistry.register_feature_engineer("NXAverageClustering") | |||
| class NXAverageClustering(_NetworkXGraphFeatureEngineer): | |||
| def __init__(self): | |||
| super(NXAverageClustering, self).__init__(average_clustering) | |||
| @FeatureEngineerUniversalRegistry.register_feature_engineer("NXIsConnected") | |||
| class NXIsConnected(_NetworkXGraphFeatureEngineer): | |||
| def __init__(self): | |||
| super(NXIsConnected, self).__init__(is_connected) | |||
| @FeatureEngineerUniversalRegistry.register_feature_engineer("NXNumberConnectedComponents") | |||
| class NXNumberConnectedComponents(_NetworkXGraphFeatureEngineer): | |||
| def __init__(self): | |||
| super(NXNumberConnectedComponents, self).__init__(number_connected_components) | |||
| @FeatureEngineerUniversalRegistry.register_feature_engineer("NXIsDistanceRegular") | |||
| class NXIsDistanceRegular(_NetworkXGraphFeatureEngineer): | |||
| def __init__(self): | |||
| super(NXIsDistanceRegular, self).__init__(is_distance_regular) | |||
| @FeatureEngineerUniversalRegistry.register_feature_engineer("NXLocalEfficiency") | |||
| class NXLocalEfficiency(_NetworkXGraphFeatureEngineer): | |||
| def __init__(self): | |||
| super(NXLocalEfficiency, self).__init__(local_efficiency) | |||
| @FeatureEngineerUniversalRegistry.register_feature_engineer("NXGlobalEfficiency") | |||
| class NXGlobalEfficiency(_NetworkXGraphFeatureEngineer): | |||
| def __init__(self): | |||
| super(NXGlobalEfficiency, self).__init__(global_efficiency) | |||
| @FeatureEngineerUniversalRegistry.register_feature_engineer("NXIsEulerian") | |||
| class NXIsEulerian(_NetworkXGraphFeatureEngineer): | |||
| def __init__(self): | |||
| super(NXIsEulerian, self).__init__(is_eulerian) | |||
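# --- Illustrative usage sketch (added for clarity; not part of the module) ---
# Each wrapper appends one scalar to the graph-level 'gf' field, so several
# structural descriptors can be stacked with ``&``. ``my_dataset`` is a
# hypothetical dataset of graphs.
#
# fe = NXTransitivity() & NXIsConnected() & NXNumberConnectedComponents()
# my_dataset = fe.fit_transform(my_dataset)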
| @@ -0,0 +1,2 @@ | |||
| from ._basic import FilterConstant | |||
| from ._gbdt import GBDTFeatureSelector | |||
| @@ -0,0 +1,58 @@ | |||
| import numpy as np | |||
| import torch | |||
| import typing as _typing | |||
| from autogl.data.graph import GeneralStaticGraph | |||
| from .._base_feature_engineer import BaseFeatureEngineer | |||
| from .._feature_engineer_registry import FeatureEngineerUniversalRegistry | |||
| class BaseFeatureSelector(BaseFeatureEngineer): | |||
| def __init__(self): | |||
self._selection: _typing.Optional[torch.Tensor] = None
| super(BaseFeatureSelector, self).__init__() | |||
def _transform(self, static_graph: GeneralStaticGraph) -> GeneralStaticGraph:
# a valid selection is the one-dimensional index tensor produced by ``_fit``
selection_is_valid: bool = (
isinstance(self._selection, torch.Tensor) and self._selection.dim() == 1
)
if selection_is_valid and 'x' in static_graph.nodes.data:
static_graph.nodes.data['x'] = static_graph.nodes.data['x'][:, self._selection]
if selection_is_valid and 'feat' in static_graph.nodes.data:
static_graph.nodes.data['feat'] = static_graph.nodes.data['feat'][:, self._selection]
return static_graph
| @FeatureEngineerUniversalRegistry.register_feature_engineer("FilterConstant") | |||
| class FilterConstant(BaseFeatureSelector): | |||
| r"""drop constant features""" | |||
def _fit(self, static_graph: GeneralStaticGraph) -> GeneralStaticGraph:
if 'x' in static_graph.nodes.data:
feature: _typing.Optional[np.ndarray] = static_graph.nodes.data['x'].numpy()
elif 'feat' in static_graph.nodes.data:
feature: _typing.Optional[np.ndarray] = static_graph.nodes.data['feat'].numpy()
else:
feature: _typing.Optional[np.ndarray] = None
if feature is not None and feature.ndim == 2:
# keep only the columns that are not constant across all nodes
self._selection = torch.from_numpy(
np.where(~np.all(feature == feature[0, :], axis=0))[0]
)
return static_graph
| @@ -0,0 +1,139 @@ | |||
| import numpy as np | |||
| import pandas as pd | |||
| import torch | |||
| import typing as _typing | |||
| import autogl | |||
| from autogl.data.graph import GeneralStaticGraph | |||
| from .. import _feature_engineer_registry | |||
| import lightgbm | |||
| from sklearn.model_selection import train_test_split | |||
| from ._basic import BaseFeatureSelector | |||
| def _gbdt_generator( | |||
| data: autogl.data.Data, fixlen: int = 1000, | |||
| params: _typing.Mapping[str, _typing.Any] = ..., | |||
| is_val: bool = True, train_val_ratio: float = 0.2, | |||
| **optimizer_parameters | |||
| ) -> _typing.Optional[np.ndarray]: | |||
| parameters: _typing.Dict[str, _typing.Any] = ( | |||
| dict(params) | |||
| if ( | |||
| params not in (Ellipsis, None) and | |||
| isinstance(params, _typing.Mapping) | |||
| ) | |||
| else { | |||
| "boosting_type": "gbdt", | |||
| "verbosity": -1, | |||
| "random_state": 47, | |||
| "objective": "multiclass", | |||
| "metric": ["multi_logloss"], | |||
| "max_bin": 63, | |||
| "save_binary": True, | |||
| "num_threads": 20, | |||
| "num_leaves": 16, | |||
| "subsample": 0.9, | |||
| "subsample_freq": 1, | |||
| "colsample_bytree": 0.8, | |||
| # 'is_training_metric': True, | |||
| # 'metric_freq': 1, | |||
| } | |||
| ) | |||
num_classes: int = torch.max(data.y).item() + 1
# LightGBM's multiclass objective requires the number of classes to be set
parameters.setdefault("num_class", num_classes)
| __optimizer_parameters = { | |||
| "num_boost_round": 100, | |||
| "early_stopping_rounds": 5, | |||
| "verbose_eval": False | |||
| } | |||
| __optimizer_parameters.update(optimizer_parameters) | |||
| if hasattr(data, "train_mask") and data.train_mask is not None and ( | |||
| isinstance(data.train_mask, np.ndarray) or | |||
| isinstance(data.train_mask, torch.Tensor) | |||
| ): | |||
| x: np.ndarray = data.x[data.train_mask].numpy() | |||
| label: np.ndarray = data.y[data.train_mask].numpy() | |||
| else: | |||
| x: np.ndarray = data.x.numpy() | |||
| label: np.ndarray = data.y.numpy() | |||
| is_val: bool = False | |||
| _, num_features = x.shape | |||
| if num_features < fixlen: | |||
| return None | |||
| feature_index: np.ndarray = np.array( | |||
| [f"f{i}" for i in range(num_features)] | |||
| ) | |||
| if is_val: | |||
| x_train, x_val, y_train, y_val = train_test_split( | |||
| x, label, test_size=train_val_ratio, stratify=label, random_state=47 | |||
| ) | |||
| dtrain = lightgbm.Dataset(x_train, label=y_train) | |||
| dval = lightgbm.Dataset(x_val, label=y_val) | |||
| clf = lightgbm.train( | |||
| train_set=dtrain, params=parameters, valid_sets=dval, | |||
| **__optimizer_parameters | |||
| ) | |||
| else: | |||
| train_x = pd.DataFrame(x, columns=feature_index, index=None) | |||
| dtrain = lightgbm.Dataset(train_x, label=label) | |||
clf = lightgbm.train(
train_set=dtrain, params=parameters,
**__optimizer_parameters
)
| imp = np.array(list(clf.feature_importance())) | |||
| return np.argsort(imp)[-fixlen:] | |||
| @_feature_engineer_registry.FeatureEngineerUniversalRegistry.register_feature_engineer("gbdt") | |||
| class GBDTFeatureSelector(BaseFeatureSelector): | |||
| r"""simple wrapper of lightgbm , using importance ranking to select top-k features. | |||
| Parameters | |||
| ---------- | |||
| fixlen : int | |||
| K for top-K important features. | |||
| """ | |||
| def __init__(self, fixlen: int = 10, *args, **kwargs): | |||
| super(GBDTFeatureSelector, self).__init__() | |||
| self.__fixlen = fixlen | |||
| self.__args = args | |||
| self.__kwargs = kwargs | |||
| def _fit(self, homogeneous_static_graph: GeneralStaticGraph) -> GeneralStaticGraph: | |||
if not isinstance(homogeneous_static_graph, GeneralStaticGraph):
raise TypeError("GBDTFeatureSelector expects a GeneralStaticGraph")
elif not (
homogeneous_static_graph.nodes.is_homogeneous and
homogeneous_static_graph.edges.is_homogeneous
):
raise ValueError("Provided static graph must be homogeneous")
| if 'x' in homogeneous_static_graph.nodes.data: | |||
| features: torch.Tensor = homogeneous_static_graph.nodes.data['x'] | |||
| elif 'feat' in homogeneous_static_graph.nodes.data: | |||
| features: torch.Tensor = homogeneous_static_graph.nodes.data['feat'] | |||
| else: | |||
| raise ValueError("Node features not exists") | |||
| if 'y' in homogeneous_static_graph.nodes.data: | |||
| label: torch.Tensor = homogeneous_static_graph.nodes.data['y'] | |||
| elif 'label' in homogeneous_static_graph.nodes.data: | |||
| label: torch.Tensor = homogeneous_static_graph.nodes.data['label'] | |||
| else: | |||
| raise ValueError("Node label not exists") | |||
| if 'train_mask' in homogeneous_static_graph.nodes.data: | |||
| train_mask: _typing.Optional[torch.Tensor] = ( | |||
| homogeneous_static_graph.nodes.data['train_mask'] | |||
| ) | |||
| else: | |||
| train_mask: _typing.Optional[torch.Tensor] = None | |||
| data = autogl.data.Data( | |||
| edge_index=homogeneous_static_graph.edges.connections, | |||
| x=features, y=label | |||
| ) | |||
| data.train_mask = train_mask | |||
selection = _gbdt_generator(
data, self.__fixlen, *self.__args, **self.__kwargs
)
# ``_transform`` expects a torch tensor, while the generator returns numpy indices
self._selection = (
torch.from_numpy(selection) if selection is not None else None
)
| return homogeneous_static_graph | |||
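# --- Illustrative usage sketch (added for clarity; not part of the module) ---
# Keep the 100 most important node-feature columns according to LightGBM's
# importance ranking (graphs with fewer than 100 columns are left untouched).
# ``my_dataset`` is a hypothetical dataset of homogeneous static graphs.
#
# selector = GBDTFeatureSelector(fixlen=100)
# my_dataset = selector.fit_transform(my_dataset)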
| @@ -12,7 +12,7 @@ from .autone_file import utils | |||
| from torch_geometric.data import GraphSAINTRandomWalkSampler | |||
| from ..feature.graph import SgNetLSD | |||
| from ..feature import NetLSD as SgNetLSD | |||
| from torch_geometric.data import InMemoryDataset | |||
| @@ -1,22 +1,12 @@ | |||
| from ._model_registry import MODEL_DICT, ModelUniversalRegistry, register_model | |||
| from .base import BaseModel | |||
| from .topkpool import AutoTopkpool | |||
| import importlib | |||
| import sys | |||
| from ...backend import DependentBackend | |||
| # from .graph_sage import AutoSAGE | |||
| from .graphsage import AutoSAGE | |||
| from .graph_saint import GraphSAINTAggregationModel | |||
| from .gcn import AutoGCN | |||
| from .gat import AutoGAT | |||
| from .gin import AutoGIN | |||
| # load corresponding backend of subclass | |||
| def _load_subclass_backend(backend): | |||
| sub_module = importlib.import_module(f'.{backend.get_backend_name()}', __name__) | |||
| this = sys.modules[__name__] | |||
| for api, obj in sub_module.__dict__.items(): | |||
| setattr(this, api, obj) | |||
| __all__ = [ | |||
| "ModelUniversalRegistry", | |||
| "register_model", | |||
| "BaseModel", | |||
| "AutoTopkpool", | |||
| "AutoSAGE", | |||
| "GraphSAINTAggregationModel", | |||
| "AutoGCN", | |||
| "AutoGAT", | |||
| "AutoGIN", | |||
| ] | |||
| _load_subclass_backend(DependentBackend) | |||
| @@ -0,0 +1,25 @@ | |||
| from ._model_registry import MODEL_DICT, ModelUniversalRegistry, register_model | |||
| from .base import BaseModel | |||
| from .topkpool import AutoTopkpool | |||
| from .graph_saint import GraphSAINTAggregationModel | |||
| from .gcn import GCN, AutoGCN | |||
| from .graphsage import GraphSAGE, AutoSAGE | |||
from .gat import GAT, AutoGAT
| from .gin import AutoGIN | |||
| __all__ = [ | |||
| "ModelUniversalRegistry", | |||
| "register_model", | |||
| "BaseModel", | |||
| "AutoTopkpool", | |||
| "GraphSAINTAggregationModel", | |||
| "GCN", | |||
| "AutoGCN", | |||
| "GraphSAGE", | |||
| "AutoSAGE", | |||
| "GAT", | |||
| "AutoGAT", | |||
| "AutoGIN" | |||
| ] | |||
| @@ -0,0 +1,212 @@ | |||
| import torch | |||
| import torch.nn.functional as F | |||
| from dgl.nn.pytorch.conv import GATConv | |||
| from . import register_model | |||
| from .base import BaseModel, activate_func | |||
| from ....utils import get_logger | |||
| LOGGER = get_logger("GATModel") | |||
| def set_default(args, d): | |||
| for k, v in d.items(): | |||
| if k not in args: | |||
| args[k] = v | |||
| return args | |||
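# Example: set_default({"hidden": [32]}, {"hidden": [16], "act": "relu"})
# returns {"hidden": [32], "act": "relu"}: existing keys win, missing keys
# are filled in from the defaults.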
| class GAT(torch.nn.Module): | |||
| def __init__(self, args): | |||
| super(GAT, self).__init__() | |||
| self.args = args | |||
| self.num_layer = int(self.args["num_layers"]) | |||
| missing_keys = list( | |||
| set( | |||
| [ | |||
| "features_num", | |||
| "num_class", | |||
| "num_layers", | |||
| "hidden", | |||
| "heads", | |||
| "dropout", | |||
| "act", | |||
| ] | |||
| ) | |||
| - set(self.args.keys()) | |||
| ) | |||
| if len(missing_keys) > 0: | |||
| raise Exception("Missing keys: %s." % ",".join(missing_keys)) | |||
| if not self.num_layer == len(self.args["hidden"]) + 1: | |||
| LOGGER.warn("Warning: layer size does not match the length of hidden units") | |||
| self.convs = torch.nn.ModuleList() | |||
| self.convs.append( | |||
| GATConv( | |||
| self.args["features_num"], | |||
| self.args["hidden"][0], | |||
| num_heads =self.args["heads"], | |||
| attn_drop=self.args["dropout"], | |||
| ) | |||
| ) | |||
| last_dim = self.args["hidden"][0] * self.args["heads"] | |||
| for i in range(self.num_layer - 2): | |||
| self.convs.append( | |||
| GATConv( | |||
| last_dim, | |||
| self.args["hidden"][i + 1], | |||
| num_heads=self.args["heads"], | |||
| attn_drop=self.args["dropout"], | |||
| ) | |||
| ) | |||
| last_dim = self.args["hidden"][i + 1] * self.args["heads"] | |||
| self.convs.append( | |||
| GATConv( | |||
| last_dim, | |||
| self.args["num_class"], | |||
| num_heads=1, | |||
| attn_drop=self.args["dropout"], | |||
| ) | |||
| ) | |||
def forward(self, data):
if 'feat' not in data.ndata:
raise KeyError("Provided DGL graph has no node feature 'feat'")
x = data.ndata['feat']
| for i in range(self.num_layer): | |||
| x = F.dropout(x, p=self.args["dropout"], training=self.training) | |||
| x = self.convs[i](data, x).flatten(1) | |||
| if i != self.num_layer - 1: | |||
| x = activate_func(x, self.args["act"]) | |||
| return F.log_softmax(x, dim=1) | |||
def lp_encode(self, data):
x = data.ndata['feat']
for i in range(self.num_layer - 1):
# DGL convolutions take the graph first; ``data`` is expected to already
# be restricted to the training edges for link prediction
x = self.convs[i](data, x).flatten(1)
| if i != self.num_layer - 2: | |||
| x = activate_func(x, self.args["act"]) | |||
| # x = F.dropout(x, p=self.args["dropout"], training=self.training) | |||
| return x | |||
| def lp_decode(self, z, pos_edge_index, neg_edge_index): | |||
| edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1) | |||
| logits = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1) | |||
| return logits | |||
| def lp_decode_all(self, z): | |||
| prob_adj = z @ z.t() | |||
| return (prob_adj > 0).nonzero(as_tuple=False).t() | |||
| @register_model("gat") | |||
| class AutoGAT(BaseModel): | |||
| r""" | |||
| AutoGAT. The model used in this automodel is GAT, i.e., the graph attentional network from the `"Graph Attention Networks" | |||
| <https://arxiv.org/abs/1710.10903>`_ paper. The layer is | |||
| .. math:: | |||
| \mathbf{x}^{\prime}_i = \alpha_{i,i}\mathbf{\Theta}\mathbf{x}_{i} + | |||
| \sum_{j \in \mathcal{N}(i)} \alpha_{i,j}\mathbf{\Theta}\mathbf{x}_{j} | |||
| where the attention coefficients :math:`\alpha_{i,j}` are computed as | |||
| .. math:: | |||
| \alpha_{i,j} = | |||
| \frac{ | |||
| \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top} | |||
| [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_j] | |||
| \right)\right)} | |||
| {\sum_{k \in \mathcal{N}(i) \cup \{ i \}} | |||
| \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top} | |||
| [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_k] | |||
| \right)\right)}. | |||
| Parameters | |||
| ---------- | |||
| num_features: `int`. | |||
| The dimension of features. | |||
| num_classes: `int`. | |||
| The number of classes. | |||
| device: `torch.device` or `str` | |||
| The device where the model will run. | |||
| init: `bool`. | |||
| If `True`, the model will be initialized on construction. | |||
| args: Other parameters. | |||
| """ | |||
| def __init__( | |||
| self, num_features=None, num_classes=None, device=None, init=False, **args | |||
| ): | |||
| super(AutoGAT, self).__init__() | |||
| self.num_features = num_features if num_features is not None else 0 | |||
| self.num_classes = int(num_classes) if num_classes is not None else 0 | |||
| self.device = device if device is not None else "cpu" | |||
| self.init = True | |||
| self.params = { | |||
| "features_num": self.num_features, | |||
| "num_class": self.num_classes, | |||
| } | |||
| self.space = [ | |||
| { | |||
| "parameterName": "num_layers", | |||
| "type": "DISCRETE", | |||
| "feasiblePoints": "2,3,4", | |||
| }, | |||
| { | |||
| "parameterName": "hidden", | |||
| "type": "NUMERICAL_LIST", | |||
| "numericalType": "INTEGER", | |||
| "length": 3, | |||
| "minValue": [8, 8, 8], | |||
| "maxValue": [64, 64, 64], | |||
| "scalingType": "LOG", | |||
| "cutPara": ("num_layers",), | |||
| "cutFunc": lambda x: x[0] - 1, | |||
| }, | |||
| { | |||
| "parameterName": "dropout", | |||
| "type": "DOUBLE", | |||
| "maxValue": 0.8, | |||
| "minValue": 0.2, | |||
| "scalingType": "LINEAR", | |||
| }, | |||
| { | |||
| "parameterName": "heads", | |||
| "type": "DISCRETE", | |||
| "feasiblePoints": "2,4,8,16", | |||
| }, | |||
| { | |||
| "parameterName": "act", | |||
| "type": "CATEGORICAL", | |||
| "feasiblePoints": ["leaky_relu", "relu", "elu", "tanh"], | |||
| }, | |||
| ] | |||
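| # NOTE: "hidden" is declared with maximum length 3; the HPO side is expected | |||
| # to truncate it to cutFunc(num_layers) = num_layers - 1 entries, i.e. one | |||
| # hidden size per non-output layer. | |||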
| self.hyperparams = { | |||
| "num_layers": 2, | |||
| "hidden": [32], | |||
| "heads": 4, | |||
| "dropout": 0.2, | |||
| "act": "leaky_relu", | |||
| } | |||
| self.initialized = False | |||
| if init is True: | |||
| self.initialize() | |||
| def initialize(self): | |||
| # """Initialize model.""" | |||
| if self.initialized: | |||
| return | |||
| self.initialized = True | |||
| self.model = GAT({**self.params, **self.hyperparams}).to(self.device) | |||
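| A minimal usage sketch for the auto-model above; the feature/class counts and the DGL graph `g` are illustrative assumptions, not part of this diff: | |||
| auto_gat = AutoGAT(num_features=1433, num_classes=7, device="cpu", init=True) | |||
| log_probs = auto_gat.model(g)  # expects node features in g.ndata['feat'] | |||
| # re-instantiate with a concrete sample from the search space defined above | |||
| tuned = auto_gat.from_hyper_parameter( | |||
|     {"num_layers": 2, "hidden": [16], "heads": 8, "dropout": 0.4, "act": "elu"} | |||
| ) | |||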
| @@ -0,0 +1,395 @@ | |||
| import torch | |||
| import torch.nn.functional as F | |||
| from typing import Sequence, Optional, Union, Tuple | |||
| from numbers import Real | |||
| from dgl.nn.pytorch.conv import GraphConv | |||
| from dgl import remove_self_loop, add_self_loop | |||
| import autogl.data | |||
| from . import register_model | |||
| from .base import BaseModel, activate_func, ClassificationSupportedSequentialModel | |||
| from ....utils import get_logger | |||
| LOGGER = get_logger("GCNModel") | |||
| class GCN(ClassificationSupportedSequentialModel): | |||
| class _GCNLayer(torch.nn.Module): | |||
| def __init__( | |||
| self, | |||
| input_channels: int, | |||
| output_channels: int, | |||
| add_self_loops: bool = True, | |||
| normalize: bool = True, | |||
| activation_name: Optional[str] = None, | |||
| dropout_probability: Optional[Real] = None, | |||
| ): | |||
| super().__init__() | |||
| self._convolution: GraphConv = GraphConv( | |||
| input_channels, | |||
| output_channels, | |||
| norm='both' if normalize else 'none', | |||
| ) | |||
| self.add_self_loops = bool(add_self_loops) | |||
| if isinstance(activation_name, str): | |||
| self._activation_name = activation_name | |||
| else: | |||
| self._activation_name = None | |||
| if isinstance(dropout_probability, Real): | |||
| if dropout_probability < 0: | |||
| dropout_probability = 0 | |||
| if dropout_probability > 1: | |||
| dropout_probability = 1 | |||
| self._dropout = torch.nn.Dropout(dropout_probability) | |||
| else: | |||
| self._dropout = None | |||
| def forward(self, data, x, enable_activation: bool = True) -> torch.Tensor: | |||
| if self.add_self_loops: | |||
| data = remove_self_loop(data) | |||
| data = add_self_loop(data) | |||
| x: torch.Tensor = self._convolution.forward(data, x) | |||
| if self._activation_name is not None and enable_activation: | |||
| x: torch.Tensor = activate_func(x, self._activation_name) | |||
| if self._dropout is not None: | |||
| x: torch.Tensor = self._dropout.forward(x) | |||
| return x | |||
| def __init__( | |||
| self, | |||
| num_features: int, | |||
| num_classes: int, | |||
| hidden_features: Sequence[int], | |||
| activation_name: str, | |||
| dropout: Union[Real, Sequence[Optional[Real]], None] = None, | |||
| add_self_loops: bool = True, | |||
| normalize: bool = True, | |||
| ): | |||
| if isinstance(dropout, Sequence): | |||
| if len(dropout) != len(hidden_features) + 1: | |||
| raise TypeError( | |||
| "When the dropout argument is a sequence, " | |||
| "The sequence length must equal to the number of layers to construct." | |||
| ) | |||
| for _dropout in dropout: | |||
| if _dropout is not None and not isinstance(_dropout, Real): | |||
| raise TypeError( | |||
| "When the dropout argument is a sequence, " | |||
| "every item in the sequence must be float or None" | |||
| ) | |||
| dropout_list: Sequence[Optional[Real]] = dropout | |||
| elif isinstance(dropout, Real): | |||
| if dropout < 0: | |||
| dropout = 0 | |||
| if dropout > 1: | |||
| dropout = 1 | |||
| dropout_list: Sequence[Real] = [ | |||
| dropout for _ in range(len(hidden_features)) | |||
| ] + [None] | |||
| elif dropout is None: | |||
| dropout_list: Sequence[None] = [ | |||
| None for _ in range(len(hidden_features) + 1) | |||
| ] | |||
| else: | |||
| raise TypeError( | |||
| "The provided dropout argument must be a float number or None or " | |||
| "a sequence in which each item is either a float Number or None." | |||
| ) | |||
| super().__init__() | |||
| if len(hidden_features) == 0: | |||
| self.__sequential_encoding_layers: torch.nn.ModuleList = ( | |||
| torch.nn.ModuleList( | |||
| ( | |||
| self._GCNLayer( | |||
| num_features, | |||
| num_classes, | |||
| add_self_loops, | |||
| normalize, | |||
| dropout_probability=dropout_list[0], | |||
| ), | |||
| ) | |||
| ) | |||
| ) | |||
| else: | |||
| self.__sequential_encoding_layers = torch.nn.ModuleList() | |||
| self.__sequential_encoding_layers.append( | |||
| self._GCNLayer( | |||
| num_features, | |||
| hidden_features[0], | |||
| add_self_loops, | |||
| normalize, | |||
| activation_name, | |||
| dropout_list[0], | |||
| ) | |||
| ) | |||
| for hidden_feature_index in range(len(hidden_features)): | |||
| if hidden_feature_index + 1 < len(hidden_features): | |||
| self.__sequential_encoding_layers.append( | |||
| self._GCNLayer( | |||
| hidden_features[hidden_feature_index], | |||
| hidden_features[hidden_feature_index + 1], | |||
| add_self_loops, | |||
| normalize, | |||
| activation_name, | |||
| dropout_list[hidden_feature_index + 1], | |||
| ) | |||
| ) | |||
| else: | |||
| self.__sequential_encoding_layers.append( | |||
| self._GCNLayer( | |||
| hidden_features[hidden_feature_index], | |||
| num_classes, | |||
| add_self_loops, | |||
| normalize, | |||
| dropout_probability=dropout_list[-1], | |||
| ) | |||
| ) | |||
| @property | |||
| def sequential_encoding_layers(self) -> torch.nn.ModuleList: | |||
| return self.__sequential_encoding_layers | |||
| def __extract_edge_indexes_and_weights( | |||
| self, data | |||
| ) -> Union[ | |||
| Sequence[Tuple[torch.LongTensor, Optional[torch.Tensor]]], | |||
| Tuple[torch.LongTensor, Optional[torch.Tensor]], | |||
| ]: | |||
| def __compose_edge_index_and_weight( | |||
| _edge_index: torch.LongTensor, | |||
| _edge_weight: Optional[torch.Tensor] = None, | |||
| ) -> Tuple[torch.LongTensor, Optional[torch.Tensor]]: | |||
| if type(_edge_index) != torch.Tensor or _edge_index.dtype != torch.int64: | |||
| raise TypeError | |||
| if _edge_weight is not None and ( | |||
| type(_edge_weight) != torch.Tensor | |||
| or _edge_index.size() != (2, _edge_weight.size(0)) | |||
| ): | |||
| _edge_weight: Optional[torch.Tensor] = None | |||
| return _edge_index, _edge_weight | |||
| if not ( | |||
| hasattr(data, "edge_indexes") | |||
| and isinstance(getattr(data, "edge_indexes"), Sequence) | |||
| and len(getattr(data, "edge_indexes")) | |||
| == len(self.__sequential_encoding_layers) | |||
| ): | |||
| edge_weights = data.edata['edge_weights'] if 'edge_weights' in data.edata else None | |||
| return __compose_edge_index_and_weight(data.edges(), edge_weights) | |||
| # for __edge_index in getattr(data, "edge_indexes"): | |||
| # if type(__edge_index) != torch.Tensor or __edge_index.dtype != torch.int64: | |||
| # return __compose_edge_index_and_weight( | |||
| # data.edges(), getattr(data, "edge_weight", None) | |||
| # ) | |||
| if ( | |||
| 'edge_weights' in data.edata | |||
| and isinstance(data.edata['edge_weights'], Sequence) | |||
| and len(data.edata['edge_weights']) | |||
| == len(self.__sequential_encoding_layers) | |||
| ): | |||
| return [ | |||
| __compose_edge_index_and_weight(_edge_index, _edge_weight) | |||
| for _edge_index, _edge_weight in zip( | |||
| getattr(data, "edge_indexes"), getattr(data, "edge_weights") | |||
| ) | |||
| ] | |||
| else: | |||
| return [ | |||
| __compose_edge_index_and_weight(__edge_index) | |||
| for __edge_index in getattr(data, "edge_indexes") | |||
| ] | |||
| def forward(self, data): | |||
| x = data.ndata['feat'] | |||
| for gcn in self.__sequential_encoding_layers: | |||
| x = gcn(data, x) | |||
| return F.log_softmax(x, dim=-1) | |||
| def cls_encode(self, data) -> torch.Tensor: | |||
| return self(data) | |||
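| # NOTE: everything below this return is unreachable legacy code retained | |||
| # from the PyG-based implementation of cls_encode. | |||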
| edge_indexes_and_weights: Union[ | |||
| Sequence[Tuple[torch.LongTensor, Optional[torch.Tensor]]], | |||
| Tuple[torch.LongTensor, Optional[torch.Tensor]], | |||
| ] = self.__extract_edge_indexes_and_weights(data) | |||
| if (not isinstance(edge_indexes_and_weights, tuple)) and isinstance( | |||
| edge_indexes_and_weights[0], tuple | |||
| ): | |||
| """ edge_indexes_and_weights is sequence of (edge_index, edge_weight) """ | |||
| assert len(edge_indexes_and_weights) == len( | |||
| self.__sequential_encoding_layers | |||
| ) | |||
| x: torch.Tensor = data.ndata['feat'] | |||
| for _edge_index_and_weight, gcn in zip( | |||
| edge_indexes_and_weights, self.__sequential_encoding_layers | |||
| ): | |||
| _temp_data = autogl.data.Data(x=x, edge_index=_edge_index_and_weight[0]) | |||
| _temp_data.edge_weight = _edge_index_and_weight[1] | |||
| x = gcn(_temp_data) | |||
| return x | |||
| else: | |||
| """ edge_indexes_and_weights is (edge_index, edge_weight) """ | |||
| x = data.ndata['feat'] | |||
| for gcn in self.__sequential_encoding_layers: | |||
| _temp_data = autogl.data.Data( | |||
| x=x, edge_index=edge_indexes_and_weights[0] | |||
| ) | |||
| _temp_data.edge_weight = edge_indexes_and_weights[1] | |||
| x = gcn(_temp_data) | |||
| return x | |||
| def cls_decode(self, x: torch.Tensor) -> torch.Tensor: | |||
| return torch.nn.functional.log_softmax(x, dim=1) | |||
| def lp_encode(self, data): | |||
| x: torch.Tensor = data.ndata['feat'] | |||
| for i in range(len(self.__sequential_encoding_layers) - 2): | |||
| x = self.__sequential_encoding_layers[i](data, x) | |||
| x = self.__sequential_encoding_layers[-2](data, x, enable_activation=False) | |||
| return x | |||
| def lp_decode(self, z, pos_edge_index, neg_edge_index): | |||
| edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1) | |||
| logits = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1) | |||
| return logits | |||
| def lp_decode_all(self, z): | |||
| prob_adj = z @ z.t() | |||
| return (prob_adj > 0).nonzero(as_tuple=False).t() | |||
| @register_model("gcn") | |||
| class AutoGCN(BaseModel): | |||
| r""" | |||
| AutoGCN. | |||
| The model used in this automodel is GCN, i.e., the graph convolutional network from the | |||
| `"Semi-supervised Classification with Graph Convolutional | |||
| Networks" <https://arxiv.org/abs/1609.02907>`_ paper. The layer is | |||
| .. math:: | |||
| \mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} | |||
| \mathbf{\hat{D}}^{-1/2} \mathbf{X} \mathbf{\Theta}, | |||
| where :math:`\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}` denotes the | |||
| adjacency matrix with inserted self-loops and | |||
| :math:`\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}` its diagonal degree matrix. | |||
| Parameters | |||
| ---------- | |||
| num_features: ``int`` | |||
| The dimension of features. | |||
| num_classes: ``int`` | |||
| The number of classes. | |||
| device: ``torch.device`` or ``str`` | |||
| The device where the model will run. | |||
| init: ``bool`` | |||
| If ``True``, the model will be initialized on construction. | |||
| """ | |||
| def __init__( | |||
| self, | |||
| num_features: Optional[int] = None, | |||
| num_classes: Optional[int] = None, | |||
| device: Union[str, torch.device] = 'cpu', | |||
| init: bool = False, | |||
| **kwargs | |||
| ) -> None: | |||
| super().__init__() | |||
| self.num_features = num_features | |||
| self.num_classes = num_classes | |||
| self.device = device | |||
| self.params = { | |||
| "features_num": self.num_features, | |||
| "num_class": self.num_classes, | |||
| } | |||
| self.space = [ | |||
| { | |||
| "parameterName": "add_self_loops", | |||
| "type": "CATEGORICAL", | |||
| "feasiblePoints": [1], | |||
| }, | |||
| { | |||
| "parameterName": "normalize", | |||
| "type": "CATEGORICAL", | |||
| "feasiblePoints": [1], | |||
| }, | |||
| { | |||
| "parameterName": "num_layers", | |||
| "type": "DISCRETE", | |||
| "feasiblePoints": "2,3,4", | |||
| }, | |||
| { | |||
| "parameterName": "hidden", | |||
| "type": "NUMERICAL_LIST", | |||
| "numericalType": "INTEGER", | |||
| "length": 3, | |||
| "minValue": [8, 8, 8], | |||
| "maxValue": [128, 128, 128], | |||
| "scalingType": "LOG", | |||
| "cutPara": ("num_layers",), | |||
| "cutFunc": lambda x: x[0] - 1, | |||
| }, | |||
| { | |||
| "parameterName": "dropout", | |||
| "type": "DOUBLE", | |||
| "maxValue": 0.8, | |||
| "minValue": 0.2, | |||
| "scalingType": "LINEAR", | |||
| }, | |||
| { | |||
| "parameterName": "act", | |||
| "type": "CATEGORICAL", | |||
| "feasiblePoints": ["leaky_relu", "relu", "elu", "tanh"], | |||
| }, | |||
| ] | |||
| # initial point of hp search | |||
| # self.hyperparams = { | |||
| # "num_layers": 2, | |||
| # "hidden": [16], | |||
| # "dropout": 0.2, | |||
| # "act": "leaky_relu", | |||
| # } | |||
| self.hyperparams = { | |||
| "num_layers": 3, | |||
| "hidden": [128, 64], | |||
| "dropout": 0., | |||
| "act": "relu", | |||
| } | |||
| self.initialized = False | |||
| if init is True: | |||
| self.initialize() | |||
| def initialize(self): | |||
| if self.initialized: | |||
| return | |||
| self.initialized = True | |||
| self.model = GCN( | |||
| self.num_features, | |||
| self.num_classes, | |||
| self.hyperparams.get("hidden"), | |||
| self.hyperparams.get("act"), | |||
| self.hyperparams.get("dropout", None), | |||
| bool(self.hyperparams.get("add_self_loops", True)), | |||
| bool(self.hyperparams.get("normalize", True)), | |||
| ).to(self.device) | |||
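| For reference, a hedged sketch constructing the underlying GCN directly; the counts and the DGL graph `g` are placeholder assumptions: | |||
| gcn = GCN(num_features=1433, num_classes=7, hidden_features=[128, 64], | |||
|           activation_name="relu", dropout=0.5) | |||
| log_probs = gcn(g)  # reads g.ndata['feat']; returns log-softmax class scores | |||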
| @@ -0,0 +1,345 @@ | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| from torch.nn import Linear, ReLU, Sequential, LeakyReLU, Tanh, ELU | |||
| from dgl.nn.pytorch.conv import GINConv | |||
| from dgl.nn.pytorch.glob import SumPooling, AvgPooling, MaxPooling | |||
| from torch.nn import BatchNorm1d | |||
| from . import register_model | |||
| from .base import BaseModel, activate_func | |||
| from copy import deepcopy | |||
| from ....utils import get_logger | |||
| LOGGER = get_logger("GINModel") | |||
| def set_default(args, d): | |||
| for k, v in d.items(): | |||
| if k not in args: | |||
| args[k] = v | |||
| return args | |||
| class ApplyNodeFunc(nn.Module): | |||
| """Update the node feature hv with MLP, BN and ReLU.""" | |||
| def __init__(self, mlp): | |||
| super(ApplyNodeFunc, self).__init__() | |||
| self.mlp = mlp | |||
| self.bn = nn.BatchNorm1d(self.mlp.output_dim) | |||
| def forward(self, h): | |||
| h = self.mlp(h) | |||
| h = self.bn(h) | |||
| h = F.relu(h) | |||
| return h | |||
| class MLP(nn.Module): | |||
| """MLP with linear output""" | |||
| def __init__(self, num_layers, input_dim, hidden_dim, output_dim): | |||
| """MLP layers construction | |||
| Parameters | |||
| ---------- | |||
| num_layers: int | |||
| The number of linear layers | |||
| input_dim: int | |||
| The dimensionality of input features | |||
| hidden_dim: int | |||
| The dimensionality of hidden units at ALL layers | |||
| output_dim: int | |||
| The number of classes for prediction | |||
| """ | |||
| super(MLP, self).__init__() | |||
| self.linear_or_not = True # default is linear model | |||
| self.num_layers = num_layers | |||
| self.output_dim = output_dim | |||
| if num_layers < 1: | |||
| raise ValueError("number of layers should be positive!") | |||
| elif num_layers == 1: | |||
| # Linear model | |||
| self.linear = nn.Linear(input_dim, output_dim) | |||
| else: | |||
| # Multi-layer model | |||
| self.linear_or_not = False | |||
| self.linears = torch.nn.ModuleList() | |||
| self.batch_norms = torch.nn.ModuleList() | |||
| self.linears.append(nn.Linear(input_dim, hidden_dim)) | |||
| for layer in range(num_layers - 2): | |||
| self.linears.append(nn.Linear(hidden_dim, hidden_dim)) | |||
| self.linears.append(nn.Linear(hidden_dim, output_dim)) | |||
| for layer in range(num_layers - 1): | |||
| self.batch_norms.append(nn.BatchNorm1d((hidden_dim))) | |||
| def forward(self, x): | |||
| if self.linear_or_not: | |||
| # If linear model | |||
| return self.linear(x) | |||
| else: | |||
| # If MLP | |||
| h = x | |||
| for i in range(self.num_layers - 1): | |||
| h = F.relu(self.batch_norms[i](self.linears[i](h))) | |||
| return self.linears[-1](h) | |||
| class GIN(torch.nn.Module): | |||
| """GIN model""" | |||
| def __init__(self, args): | |||
| """model parameters setting | |||
| Parameters | |||
| ---------- | |||
| num_layers: int | |||
| The number of linear layers in the neural network | |||
| num_mlp_layers: int | |||
| The number of linear layers in mlps | |||
| input_dim: int | |||
| The dimensionality of input features | |||
| hidden_dim: int | |||
| The dimensionality of hidden units at ALL layers | |||
| output_dim: int | |||
| The number of classes for prediction | |||
| final_dropout: float | |||
| dropout ratio on the final linear layer | |||
| learn_eps: boolean | |||
| If True, learn epsilon to distinguish center nodes from neighbors | |||
| If False, aggregate neighbors and center nodes altogether. | |||
| neighbor_pooling_type: str | |||
| how to aggregate neighbors (sum, mean, or max) | |||
| graph_pooling_type: str | |||
| how to aggregate entire nodes in a graph (sum, mean or max) | |||
| """ | |||
| super(GIN, self).__init__() | |||
| self.args = args | |||
| missing_keys = list( | |||
| set( | |||
| [ | |||
| "features_num", | |||
| "num_class", | |||
| "num_graph_features", | |||
| "num_layers", | |||
| "hidden", | |||
| "dropout", | |||
| "act", | |||
| "mlp_layers", | |||
| "eps", | |||
| ] | |||
| ) | |||
| - set(self.args.keys()) | |||
| ) | |||
| if len(missing_keys) > 0: | |||
| raise Exception("Missing keys: %s." % ",".join(missing_keys)) | |||
| #if not self.num_layer == len(self.args["hidden"]) + 1: | |||
| # LOGGER.warn("Warning: layer size does not match the length of hidden units") | |||
| self.num_graph_features = self.args["num_graph_features"] | |||
| self.num_layers = self.args["num_layers"] | |||
| assert self.num_layers > 2, "Number of layers in GIN should not be less than 3" | |||
| self.learn_eps = self.args["eps"] == "True"  # "eps" arrives as the string "True"/"False" | |||
| self.num_mlp_layers = self.args["mlp_layers"] | |||
| input_dim = self.args["features_num"] | |||
| hidden_dim = self.args["hidden"][0] | |||
| neighbor_pooling_type = self.args["neighbor_pooling_type"] | |||
| graph_pooling_type = self.args["graph_pooling_type"] | |||
| if self.args["act"] == "leaky_relu": | |||
| act = LeakyReLU() | |||
| elif self.args["act"] == "relu": | |||
| act = ReLU() | |||
| elif self.args["act"] == "elu": | |||
| act = ELU() | |||
| elif self.args["act"] == "tanh": | |||
| act = Tanh() | |||
| else: | |||
| act = ReLU() | |||
| final_dropout = self.args["dropout"] | |||
| output_dim = self.args["num_class"] | |||
| # List of MLPs | |||
| self.ginlayers = torch.nn.ModuleList() | |||
| self.batch_norms = torch.nn.ModuleList() | |||
| for layer in range(self.num_layers - 1): | |||
| if layer == 0: | |||
| mlp = MLP(self.num_mlp_layers, input_dim, hidden_dim, hidden_dim) | |||
| else: | |||
| mlp = MLP(self.num_mlp_layers, hidden_dim, hidden_dim, hidden_dim) | |||
| self.ginlayers.append( | |||
| GINConv(ApplyNodeFunc(mlp), neighbor_pooling_type, 0, self.learn_eps)) | |||
| self.batch_norms.append(nn.BatchNorm1d(hidden_dim)) | |||
| # Linear function for graph poolings of output of each layer | |||
| # which maps the output of different layers into a prediction score | |||
| self.linears_prediction = torch.nn.ModuleList() | |||
| for layer in range(self.num_layers): | |||
| if layer == 0: | |||
| self.linears_prediction.append( | |||
| nn.Linear(input_dim, output_dim)) | |||
| else: | |||
| self.linears_prediction.append( | |||
| nn.Linear(hidden_dim, output_dim)) | |||
| self.drop = nn.Dropout(final_dropout) | |||
| if graph_pooling_type == 'sum': | |||
| self.pool = SumPooling() | |||
| elif graph_pooling_type == 'mean': | |||
| self.pool = AvgPooling() | |||
| elif graph_pooling_type == 'max': | |||
| self.pool = MaxPooling() | |||
| else: | |||
| raise NotImplementedError | |||
| #def forward(self, g, h): | |||
| def forward(self, data): | |||
| g, _ = data | |||
| h = g.ndata.pop('feat') | |||
| # list of hidden representation at each layer (including input) | |||
| hidden_rep = [h] | |||
| for i in range(self.num_layers - 1): | |||
| h = self.ginlayers[i](g, h) | |||
| h = self.batch_norms[i](h) | |||
| h = F.relu(h) | |||
| hidden_rep.append(h) | |||
| score_over_layer = 0 | |||
| # perform pooling over all nodes in each graph in every layer | |||
| for i, h in enumerate(hidden_rep): | |||
| pooled_h = self.pool(g, h) | |||
| score_over_layer += self.drop(self.linears_prediction[i](pooled_h)) | |||
| return score_over_layer | |||
| @register_model("gin") | |||
| class AutoGIN(BaseModel): | |||
| r""" | |||
| AutoGIN. The model used in this automodel is GIN, i.e., the graph isomorphism network from the `"How Powerful are | |||
| Graph Neural Networks?" <https://arxiv.org/abs/1810.00826>`_ paper. The layer is | |||
| .. math:: | |||
| \mathbf{x}^{\prime}_i = h_{\mathbf{\Theta}} \left( (1 + \epsilon) \cdot | |||
| \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \right) | |||
| or | |||
| .. math:: | |||
| \mathbf{X}^{\prime} = h_{\mathbf{\Theta}} \left( \left( \mathbf{A} + | |||
| (1 + \epsilon) \cdot \mathbf{I} \right) \cdot \mathbf{X} \right), | |||
| where :math:`h_{\mathbf{\Theta}}` denotes a neural network, *i.e.* an MLP. | |||
| Parameters | |||
| ---------- | |||
| num_features: `int`. | |||
| The dimension of features. | |||
| num_classes: `int`. | |||
| The number of classes. | |||
| device: `torch.device` or `str` | |||
| The device where the model will run. | |||
| init: `bool`. | |||
| If `True`, the model will be initialized on construction. | |||
| """ | |||
| def __init__( | |||
| self, | |||
| num_features=None, | |||
| num_classes=None, | |||
| device=None, | |||
| init=False, | |||
| num_graph_features=None, | |||
| **args | |||
| ): | |||
| super(AutoGIN, self).__init__() | |||
| self.num_features = num_features if num_features is not None else 0 | |||
| self.num_classes = int(num_classes) if num_classes is not None else 0 | |||
| self.num_graph_features = ( | |||
| int(num_graph_features) if num_graph_features is not None else 0 | |||
| ) | |||
| self.device = device if device is not None else "cpu" | |||
| self.params = { | |||
| "features_num": self.num_features, | |||
| "num_class": self.num_classes, | |||
| "num_graph_features": self.num_graph_features, | |||
| } | |||
| self.space = [ | |||
| { | |||
| "parameterName": "num_layers", | |||
| "type": "DISCRETE", | |||
| "feasiblePoints": "4,5,6", | |||
| }, | |||
| { | |||
| "parameterName": "hidden", | |||
| "type": "NUMERICAL_LIST", | |||
| "numericalType": "INTEGER", | |||
| "length": 5, | |||
| "minValue": [8, 8, 8, 8, 8], | |||
| "maxValue": [64, 64, 64, 64, 64], | |||
| "scalingType": "LOG", | |||
| "cutPara": ("num_layers",), | |||
| "cutFunc": lambda x: x[0] - 1, | |||
| }, | |||
| { | |||
| "parameterName": "dropout", | |||
| "type": "DOUBLE", | |||
| "maxValue": 0.9, | |||
| "minValue": 0.1, | |||
| "scalingType": "LINEAR", | |||
| }, | |||
| { | |||
| "parameterName": "act", | |||
| "type": "CATEGORICAL", | |||
| "feasiblePoints": ["leaky_relu", "relu", "elu", "tanh"], | |||
| }, | |||
| { | |||
| "parameterName": "eps", | |||
| "type": "CATEGORICAL", | |||
| "feasiblePoints": ["True", "False"], | |||
| }, | |||
| { | |||
| "parameterName": "mlp_layers", | |||
| "type": "DISCRETE", | |||
| "feasiblePoints": "2,3,4", | |||
| }, | |||
| ] | |||
| self.hyperparams = { | |||
| "num_layers": 5, | |||
| "hidden": [64], | |||
| "dropout": 0.5, | |||
| "act": "relu", | |||
| "eps": "False", | |||
| "mlp_layers": 2, | |||
| "neighbor_pooling_type": "sum", | |||
| "graph_pooling_type": "sum" | |||
| } | |||
| self.initialized = False | |||
| if init is True: | |||
| self.initialize() | |||
| def initialize(self): | |||
| # """Initialize model.""" | |||
| if self.initialized: | |||
| return | |||
| self.initialized = True | |||
| self.model = GIN({**self.params, **self.hyperparams}).to(self.device) | |||
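| GIN.forward unpacks its input as a (graph, label) pair, matching what a graph-classification dataloader yields per batch; a hedged sketch in which the dataloader and counts are assumptions: | |||
| auto_gin = AutoGIN(num_features=7, num_classes=2, device="cpu", init=True) | |||
| for batched_graph, labels in dataloader:  # e.g. a DGL GraphDataLoader | |||
|     scores = auto_gin.model((batched_graph, labels)) | |||
|     loss = torch.nn.functional.cross_entropy(scores, labels) | |||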
| @@ -0,0 +1,299 @@ | |||
| import math | |||
| import os | |||
| import random | |||
| import time | |||
| import numpy as np | |||
| import torch as th | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| import dgl | |||
| import dgl.function as fn | |||
| from dgl.sampling import random_walk, pack_traces | |||
| class GCNLayer(nn.Module): | |||
| def __init__(self, in_dim, out_dim, order=1, act=None, | |||
| dropout=0, batch_norm=False, aggr="concat"): | |||
| super(GCNLayer, self).__init__() | |||
| self.lins = nn.ModuleList() | |||
| self.bias = nn.ParameterList() | |||
| for _ in range(order + 1): | |||
| self.lins.append(nn.Linear(in_dim, out_dim, bias=False)) | |||
| self.bias.append(nn.Parameter(th.zeros(out_dim))) | |||
| self.order = order | |||
| self.act = act | |||
| self.dropout = nn.Dropout(dropout) | |||
| self.batch_norm = batch_norm | |||
| if batch_norm: | |||
| self.offset, self.scale = nn.ParameterList(), nn.ParameterList() | |||
| for _ in range(order + 1): | |||
| self.offset.append(nn.Parameter(th.zeros(out_dim))) | |||
| self.scale.append(nn.Parameter(th.ones(out_dim))) | |||
| self.aggr = aggr | |||
| self.reset_parameters() | |||
| def reset_parameters(self): | |||
| for lin in self.lins: | |||
| nn.init.xavier_normal_(lin.weight) | |||
| def feat_trans(self, features, idx): | |||
| h = self.lins[idx](features) + self.bias[idx] | |||
| if self.act is not None: | |||
| h = self.act(h) | |||
| if self.batch_norm: | |||
| mean = h.mean(dim=1).view(h.shape[0], 1) | |||
| var = h.var(dim=1, unbiased=False).view(h.shape[0], 1) + 1e-9 | |||
| h = (h - mean) * self.scale[idx] * th.rsqrt(var) + self.offset[idx] | |||
| return h | |||
| def forward(self, graph, features): | |||
| g = graph.local_var() | |||
| h_in = self.dropout(features) | |||
| h_hop = [h_in] | |||
| D_norm = g.ndata['train_D_norm'] if 'train_D_norm' in g.ndata else g.ndata['full_D_norm'] | |||
| for _ in range(self.order): | |||
| g.ndata['h'] = h_hop[-1] | |||
| if 'w' not in g.edata: | |||
| g.edata['w'] = th.ones((g.num_edges(), )).to(features.device) | |||
| g.update_all(fn.u_mul_e('h', 'w', 'm'), | |||
| fn.sum('m', 'h')) | |||
| h = g.ndata.pop('h') | |||
| h = h * D_norm | |||
| h_hop.append(h) | |||
| h_part = [self.feat_trans(ft, idx) for idx, ft in enumerate(h_hop)] | |||
| if self.aggr == "mean": | |||
| h_out = h_part[0] | |||
| for i in range(len(h_part) - 1): | |||
| h_out = h_out + h_part[i + 1] | |||
| elif self.aggr == "concat": | |||
| h_out = th.cat(h_part, 1) | |||
| else: | |||
| raise NotImplementedError | |||
| return h_out | |||
| class GCNNet(nn.Module): | |||
| def __init__(self, in_dim, hid_dim, out_dim, arch="1-1-0", | |||
| act=F.relu, dropout=0, batch_norm=False, aggr="concat"): | |||
| super(GCNNet, self).__init__() | |||
| self.gcn = nn.ModuleList() | |||
| orders = list(map(int, arch.split('-'))) | |||
| self.gcn.append(GCNLayer(in_dim=in_dim, out_dim=hid_dim, order=orders[0], | |||
| act=act, dropout=dropout, batch_norm=batch_norm, aggr=aggr)) | |||
| pre_out = ((aggr == "concat") * orders[0] + 1) * hid_dim | |||
| for i in range(1, len(orders)-1): | |||
| self.gcn.append(GCNLayer(in_dim=pre_out, out_dim=hid_dim, order=orders[i], | |||
| act=act, dropout=dropout, batch_norm=batch_norm, aggr=aggr)) | |||
| pre_out = ((aggr == "concat") * orders[i] + 1) * hid_dim | |||
| self.gcn.append(GCNLayer(in_dim=pre_out, out_dim=hid_dim, order=orders[-1], | |||
| act=act, dropout=dropout, batch_norm=batch_norm, aggr=aggr)) | |||
| pre_out = ((aggr == "concat") * orders[-1] + 1) * hid_dim | |||
| self.out_layer = GCNLayer(in_dim=pre_out, out_dim=out_dim, order=0, | |||
| act=None, dropout=dropout, batch_norm=False, aggr=aggr) | |||
| def forward(self, graph): | |||
| h = graph.ndata['feat'] | |||
| for layer in self.gcn: | |||
| h = layer(graph, h) | |||
| h = F.normalize(h, p=2, dim=1) | |||
| h = self.out_layer(graph, h) | |||
| return h | |||
| # The base class of sampler | |||
| # (TODO): online sampling | |||
| class SAINTSampler(object): | |||
| def __init__(self, dn, g, train_nid, node_budget, num_repeat=50): | |||
| """ | |||
| :param dn: name of dataset | |||
| :param g: full graph | |||
| :param train_nid: ids of training nodes | |||
| :param node_budget: expected number of sampled nodes | |||
| :param num_repeat: number of times of repeating sampling one node | |||
| """ | |||
| self.g = g | |||
| self.train_g: dgl.graph = g.subgraph(train_nid) | |||
| self.dn, self.num_repeat = dn, num_repeat | |||
| self.node_counter = th.zeros((self.train_g.num_nodes(),)) | |||
| self.edge_counter = th.zeros((self.train_g.num_edges(),)) | |||
| self.prob = None | |||
| graph_fn, norm_fn = self.__generate_fn__() | |||
| if os.path.exists(graph_fn): | |||
| self.subgraphs = np.load(graph_fn, allow_pickle=True) | |||
| aggr_norm, loss_norm = np.load(norm_fn, allow_pickle=True) | |||
| else: | |||
| os.makedirs('./subgraphs/', exist_ok=True) | |||
| self.subgraphs = [] | |||
| self.N, sampled_nodes = 0, 0 | |||
| t = time.perf_counter() | |||
| while sampled_nodes <= self.train_g.num_nodes() * num_repeat: | |||
| subgraph = self.__sample__() | |||
| self.subgraphs.append(subgraph) | |||
| sampled_nodes += subgraph.shape[0] | |||
| self.N += 1 | |||
| print(f'Sampling time: [{time.perf_counter() - t:.2f}s]') | |||
| np.save(graph_fn, self.subgraphs) | |||
| t = time.perf_counter() | |||
| self.__counter__() | |||
| aggr_norm, loss_norm = self.__compute_norm__() | |||
| print(f'Normalization time: [{time.perf_counter() - t:.2f}s]') | |||
| np.save(norm_fn, (aggr_norm, loss_norm)) | |||
| self.train_g.ndata['l_n'] = th.Tensor(loss_norm) | |||
| self.train_g.edata['w'] = th.Tensor(aggr_norm) | |||
| self.__compute_degree_norm() | |||
| self.num_batch = math.ceil(self.train_g.num_nodes() / node_budget) | |||
| random.shuffle(self.subgraphs) | |||
| self.__clear__() | |||
| print("The number of subgraphs is: ", len(self.subgraphs)) | |||
| print("The size of subgraphs is about: ", len(self.subgraphs[-1])) | |||
| def __clear__(self): | |||
| self.prob = None | |||
| self.node_counter = None | |||
| self.edge_counter = None | |||
| self.g = None | |||
| def __counter__(self): | |||
| for sampled_nodes in self.subgraphs: | |||
| sampled_nodes = th.from_numpy(sampled_nodes) | |||
| self.node_counter[sampled_nodes] += 1 | |||
| subg = self.train_g.subgraph(sampled_nodes) | |||
| sampled_edges = subg.edata[dgl.EID] | |||
| self.edge_counter[sampled_edges] += 1 | |||
| def __generate_fn__(self): | |||
| raise NotImplementedError | |||
| def __compute_norm__(self): | |||
| self.node_counter[self.node_counter == 0] = 1 | |||
| self.edge_counter[self.edge_counter == 0] = 1 | |||
| loss_norm = self.N / self.node_counter / self.train_g.num_nodes() | |||
| self.train_g.ndata['n_c'] = self.node_counter | |||
| self.train_g.edata['e_c'] = self.edge_counter | |||
| self.train_g.apply_edges(fn.v_div_e('n_c', 'e_c', 'a_n')) | |||
| aggr_norm = self.train_g.edata.pop('a_n') | |||
| self.train_g.ndata.pop('n_c') | |||
| self.train_g.edata.pop('e_c') | |||
| return aggr_norm.numpy(), loss_norm.numpy() | |||
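| # Normalization follows GraphSAINT: with C_v and C_e the node / edge sample | |||
| # counters accumulated over the N sampled subgraphs, the loss normalizer is | |||
| # l_n(v) = N / (C_v * |V_train|) and the aggregation weight of edge (u, v) is | |||
| # C_v / C_(u,v), computed above via fn.v_div_e('n_c', 'e_c', 'a_n'). | |||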
| def __compute_degree_norm(self): | |||
| self.train_g.ndata['train_D_norm'] = 1. / self.train_g.in_degrees().float().clamp(min=1).unsqueeze(1) | |||
| self.g.ndata['full_D_norm'] = 1. / self.g.in_degrees().float().clamp(min=1).unsqueeze(1) | |||
| def __sample__(self): | |||
| raise NotImplementedError | |||
| def __len__(self): | |||
| return self.num_batch | |||
| def __iter__(self): | |||
| self.n = 0 | |||
| return self | |||
| def __next__(self): | |||
| if self.n < self.num_batch: | |||
| result = self.train_g.subgraph(self.subgraphs[self.n]) | |||
| self.n += 1 | |||
| return result | |||
| else: | |||
| random.shuffle(self.subgraphs) | |||
| raise StopIteration() | |||
| class SAINTNodeSampler(SAINTSampler): | |||
| def __init__(self, node_budget, dn, g, train_nid, num_repeat=50): | |||
| self.node_budget = node_budget | |||
| super(SAINTNodeSampler, self).__init__(dn, g, train_nid, node_budget, num_repeat) | |||
| def __generate_fn__(self): | |||
| graph_fn = os.path.join('./subgraphs/{}_Node_{}_{}.npy'.format(self.dn, self.node_budget, | |||
| self.num_repeat)) | |||
| norm_fn = os.path.join('./subgraphs/{}_Node_{}_{}_norm.npy'.format(self.dn, self.node_budget, | |||
| self.num_repeat)) | |||
| return graph_fn, norm_fn | |||
| def __sample__(self): | |||
| if self.prob is None: | |||
| self.prob = self.train_g.in_degrees().float().clamp(min=1) | |||
| sampled_nodes = th.multinomial(self.prob, num_samples=self.node_budget, replacement=True).unique() | |||
| return sampled_nodes.numpy() | |||
| class SAINTEdgeSampler(SAINTSampler): | |||
| def __init__(self, edge_budget, dn, g, train_nid, num_repeat=50): | |||
| self.edge_budget = edge_budget | |||
| super(SAINTEdgeSampler, self).__init__(dn, g, train_nid, edge_budget * 2, num_repeat) | |||
| def __generate_fn__(self): | |||
| graph_fn = os.path.join('./subgraphs/{}_Edge_{}_{}.npy'.format(self.dn, self.edge_budget, | |||
| self.num_repeat)) | |||
| norm_fn = os.path.join('./subgraphs/{}_Edge_{}_{}_norm.npy'.format(self.dn, self.edge_budget, | |||
| self.num_repeat)) | |||
| return graph_fn, norm_fn | |||
| def __sample__(self): | |||
| if self.prob is None: | |||
| src, dst = self.train_g.edges() | |||
| src_degrees, dst_degrees = self.train_g.in_degrees(src).float().clamp(min=1),\ | |||
| self.train_g.in_degrees(dst).float().clamp(min=1) | |||
| self.prob = 1. / src_degrees + 1. / dst_degrees | |||
| sampled_edges = th.multinomial(self.prob, num_samples=self.edge_budget, replacement=True).unique() | |||
| sampled_src, sampled_dst = self.train_g.find_edges(sampled_edges) | |||
| sampled_nodes = th.cat([sampled_src, sampled_dst]).unique() | |||
| return sampled_nodes.numpy() | |||
| class SAINTRandomWalkSampler(SAINTSampler): | |||
| def __init__(self, num_roots, length, dn, g, train_nid, num_repeat=50): | |||
| self.num_roots, self.length = num_roots, length | |||
| super(SAINTRandomWalkSampler, self).__init__(dn, g, train_nid, num_roots * length, num_repeat) | |||
| def __generate_fn__(self): | |||
| graph_fn = os.path.join('./subgraphs/{}_RW_{}_{}_{}.npy'.format(self.dn, self.num_roots, | |||
| self.length, self.num_repeat)) | |||
| norm_fn = os.path.join('./subgraphs/{}_RW_{}_{}_{}_norm.npy'.format(self.dn, self.num_roots, | |||
| self.length, self.num_repeat)) | |||
| return graph_fn, norm_fn | |||
| def __sample__(self): | |||
| sampled_roots = th.randint(0, self.train_g.num_nodes(), (self.num_roots, )) | |||
| traces, types = random_walk(self.train_g, nodes=sampled_roots, length=self.length) | |||
| sampled_nodes, _, _, _ = pack_traces(traces, types) | |||
| sampled_nodes = sampled_nodes.unique() | |||
| return sampled_nodes.numpy() | |||
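| A hedged end-to-end sketch composing the sampler with GCNNet; the dataset name, budget, learning rate, and 'label' field are illustrative assumptions: | |||
| sampler = SAINTNodeSampler(6000, "cora", g, train_nid, num_repeat=50) | |||
| model = GCNNet(in_dim=g.ndata['feat'].shape[1], hid_dim=128, out_dim=num_classes) | |||
| optimizer = th.optim.Adam(model.parameters(), lr=0.01) | |||
| for subg in sampler:  # each item is a training subgraph carrying 'l_n' and 'w' | |||
|     logits = model(subg) | |||
|     loss = (F.cross_entropy(logits, subg.ndata['label'], reduction='none') | |||
|             * subg.ndata['l_n']).sum()  # per-node loss re-weighted as in GraphSAINT | |||
|     optimizer.zero_grad(); loss.backward(); optimizer.step() | |||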
| @@ -0,0 +1,314 @@ | |||
| import torch | |||
| import typing as _typing | |||
| import torch.nn.functional as F | |||
| from dgl.nn.pytorch.conv import SAGEConv | |||
| import torch.nn.functional | |||
| import autogl.data | |||
| from . import register_model | |||
| from .base import BaseModel, activate_func, ClassificationSupportedSequentialModel | |||
| from ....utils import get_logger | |||
| LOGGER = get_logger("SAGEModel") | |||
| class GraphSAGE(ClassificationSupportedSequentialModel): | |||
| class _SAGELayer(torch.nn.Module): | |||
| def __init__( | |||
| self, | |||
| input_channels: int, | |||
| output_channels: int, | |||
| aggr: str, | |||
| activation_name: _typing.Optional[str] = ..., | |||
| dropout_probability: _typing.Optional[float] = ..., | |||
| ): | |||
| super().__init__() | |||
| self._convolution: SAGEConv = SAGEConv( | |||
| input_channels, output_channels, aggregator_type=aggr | |||
| ) | |||
| if ( | |||
| activation_name is not Ellipsis | |||
| and activation_name is not None | |||
| and type(activation_name) == str | |||
| ): | |||
| self._activation_name: _typing.Optional[str] = activation_name | |||
| else: | |||
| self._activation_name: _typing.Optional[str] = None | |||
| if ( | |||
| dropout_probability is not Ellipsis | |||
| and dropout_probability is not None | |||
| and type(dropout_probability) == float | |||
| ): | |||
| if dropout_probability < 0: | |||
| dropout_probability = 0 | |||
| if dropout_probability > 1: | |||
| dropout_probability = 1 | |||
| self._dropout: _typing.Optional[torch.nn.Dropout] = torch.nn.Dropout( | |||
| dropout_probability | |||
| ) | |||
| else: | |||
| self._dropout: _typing.Optional[torch.nn.Dropout] = None | |||
| def forward(self, data, x, enable_activation: bool = True) -> torch.Tensor: | |||
| # x = data.ndata['feat'] | |||
| x: torch.Tensor = self._convolution.forward(data, x) | |||
| if (self._activation_name is not None) and enable_activation: | |||
| x: torch.Tensor = activate_func(x, self._activation_name) | |||
| if self._dropout is not None: | |||
| x: torch.Tensor = self._dropout.forward(x) | |||
| return x | |||
| def __init__( | |||
| self, | |||
| num_features: int, | |||
| num_classes: int, | |||
| hidden_features: _typing.Sequence[int], | |||
| activation_name: str, | |||
| layers_dropout: _typing.Union[ | |||
| _typing.Optional[float], _typing.Sequence[_typing.Optional[float]] | |||
| ] = None, | |||
| aggr: str = "mean", | |||
| ): | |||
| super().__init__() | |||
| if not type(num_features) == type(num_classes) == int: | |||
| raise TypeError | |||
| if not isinstance(hidden_features, _typing.Sequence): | |||
| raise TypeError | |||
| for hidden_feature in hidden_features: | |||
| if type(hidden_feature) != int: | |||
| raise TypeError | |||
| elif hidden_feature <= 0: | |||
| raise ValueError | |||
| if isinstance(layers_dropout, _typing.Sequence): | |||
| if len(layers_dropout) != (len(hidden_features) + 1): | |||
| raise TypeError | |||
| for d in layers_dropout: | |||
| if d is not None and type(d) != float: | |||
| raise TypeError | |||
| _layers_dropout: _typing.Sequence[_typing.Optional[float]] = layers_dropout | |||
| elif layers_dropout is None or type(layers_dropout) == float: | |||
| _layers_dropout: _typing.Sequence[_typing.Optional[float]] = [ | |||
| layers_dropout for _ in range(len(hidden_features)) | |||
| ] + [None] | |||
| else: | |||
| raise TypeError | |||
| if not type(activation_name) == type(aggr) == str: | |||
| raise TypeError | |||
| if aggr not in ("add", "max", "mean"): | |||
| aggr = "mean" | |||
| if len(hidden_features) == 0: | |||
| self.__sequential_encoding_layers: torch.nn.ModuleList = ( | |||
| torch.nn.ModuleList( | |||
| [ | |||
| self._SAGELayer( | |||
| num_features, | |||
| num_classes, | |||
| aggr, | |||
| activation_name, | |||
| _layers_dropout[0], | |||
| ) | |||
| ] | |||
| ) | |||
| ) | |||
| else: | |||
| self.__sequential_encoding_layers: torch.nn.ModuleList = ( | |||
| torch.nn.ModuleList( | |||
| [ | |||
| self._SAGELayer( | |||
| num_features, | |||
| hidden_features[0], | |||
| aggr, | |||
| activation_name, | |||
| _layers_dropout[0], | |||
| ) | |||
| ] | |||
| ) | |||
| ) | |||
| for i in range(len(hidden_features)): | |||
| if i + 1 < len(hidden_features): | |||
| self.__sequential_encoding_layers.append( | |||
| self._SAGELayer( | |||
| hidden_features[i], | |||
| hidden_features[i + 1], | |||
| aggr, | |||
| activation_name, | |||
| _layers_dropout[i + 1], | |||
| ) | |||
| ) | |||
| else: | |||
| self.__sequential_encoding_layers.append( | |||
| self._SAGELayer( | |||
| hidden_features[i], | |||
| num_classes, | |||
| aggr, | |||
| dropout_probability=_layers_dropout[i + 1], | |||
| ) | |||
| ) | |||
| @property | |||
| def sequential_encoding_layers(self) -> torch.nn.ModuleList: | |||
| return self.__sequential_encoding_layers | |||
| def cls_encode(self, data) -> torch.Tensor: | |||
| return self(data) | |||
| # if ( | |||
| # hasattr(data, "edge_indexes") | |||
| # and isinstance(getattr(data, "edge_indexes"), _typing.Sequence) | |||
| # and len(getattr(data, "edge_indexes")) | |||
| # == len(self.__sequential_encoding_layers) | |||
| # ): | |||
| # for __edge_index in getattr(data, "edge_indexes"): | |||
| # if type(__edge_index) != torch.Tensor: | |||
| # raise TypeError | |||
| # """ Layer-wise encode """ | |||
| # x: torch.Tensor = getattr(data, "x") | |||
| # for i, __edge_index in enumerate(getattr(data, "edge_indexes")): | |||
| # x: torch.Tensor = self.__sequential_encoding_layers[i]( | |||
| # autogl.data.Data(x=x, edge_index=__edge_index) | |||
| # ) | |||
| # return x | |||
| # else: | |||
| x: torch.Tensor = data.ndata['feat'] | |||
| for i in range(len(self.__sequential_encoding_layers)): | |||
| x = self.__sequential_encoding_layers[i](data, x) | |||
| return x | |||
| def cls_decode(self, x: torch.Tensor) -> torch.Tensor: | |||
| return torch.nn.functional.log_softmax(x, dim=1) | |||
| def lp_encode(self, data): | |||
| x: torch.Tensor = data.ndata['feat'] | |||
| for i in range(len(self.__sequential_encoding_layers) - 2): | |||
| x = self.__sequential_encoding_layers[i](data, x) | |||
| x = self.__sequential_encoding_layers[-2](data, x, enable_activation=False) | |||
| return x | |||
| def lp_decode(self, z, pos_edge_index, neg_edge_index): | |||
| edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1) | |||
| logits = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1) | |||
| return logits | |||
| def lp_decode_all(self, z): | |||
| prob_adj = z @ z.t() | |||
| return (prob_adj > 0).nonzero(as_tuple=False).t() | |||
| def forward(self, data): | |||
| # only for test | |||
| x = data.ndata['feat'] | |||
| for i in range(len(self.__sequential_encoding_layers)): | |||
| x = self.__sequential_encoding_layers[i](data, x) | |||
| return F.log_softmax(x, dim=1) | |||
| @register_model("sage") | |||
| class AutoSAGE(BaseModel): | |||
| r""" | |||
| AutoSAGE. The model used in this automodel is GraphSAGE, i.e., the GraphSAGE from the `"Inductive Representation Learning on | |||
| Large Graphs" <https://arxiv.org/abs/1706.02216>`_ paper. The layer is | |||
| .. math:: | |||
| \mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i + \mathbf{W_2} \cdot | |||
| \mathrm{mean}_{j \in \mathcal{N}(i)} \mathbf{x}_j | |||
| Parameters | |||
| ---------- | |||
| num_features: `int`. | |||
| The dimension of features. | |||
| num_classes: `int`. | |||
| The number of classes. | |||
| device: `torch.device` or `str` | |||
| The device where the model will run. | |||
| init: `bool`. | |||
| If `True`, the model will be initialized on construction. | |||
| """ | |||
| def __init__( | |||
| self, num_features=None, num_classes=None, device=None, init=False, **args | |||
| ): | |||
| super(AutoSAGE, self).__init__() | |||
| self.num_features = num_features if num_features is not None else 0 | |||
| self.num_classes = int(num_classes) if num_classes is not None else 0 | |||
| self.device = device if device is not None else "cpu" | |||
| self.init = True | |||
| self.params = { | |||
| "features_num": self.num_features, | |||
| "num_class": self.num_classes, | |||
| } | |||
| self.space = [ | |||
| { | |||
| "parameterName": "num_layers", | |||
| "type": "DISCRETE", | |||
| "feasiblePoints": "2,3,4", | |||
| }, | |||
| { | |||
| "parameterName": "hidden", | |||
| "type": "NUMERICAL_LIST", | |||
| "numericalType": "INTEGER", | |||
| "length": 3, | |||
| "minValue": [8, 8, 8], | |||
| "maxValue": [128, 128, 128], | |||
| "scalingType": "LOG", | |||
| "cutPara": ("num_layers",), | |||
| "cutFunc": lambda x: x[0] - 1, | |||
| }, | |||
| { | |||
| "parameterName": "dropout", | |||
| "type": "DOUBLE", | |||
| "maxValue": 0.8, | |||
| "minValue": 0.2, | |||
| "scalingType": "LINEAR", | |||
| }, | |||
| { | |||
| "parameterName": "act", | |||
| "type": "CATEGORICAL", | |||
| "feasiblePoints": ["leaky_relu", "relu", "elu", "tanh"], | |||
| }, | |||
| { | |||
| "parameterName": "agg", | |||
| "type": "CATEGORICAL", | |||
| "feasiblePoints": ["mean", "add", "max"], | |||
| }, | |||
| ] | |||
| self.hyperparams = { | |||
| "num_layers": 3, | |||
| "hidden": [64, 32], | |||
| "dropout": 0.5, | |||
| "act": "relu", | |||
| "agg": "mean", | |||
| } | |||
| self.initialized = False | |||
| if init is True: | |||
| self.initialize() | |||
| def initialize(self): | |||
| if self.initialized: | |||
| return | |||
| self.initialized = True | |||
| self.model = GraphSAGE( | |||
| self.num_features, | |||
| self.num_classes, | |||
| self.hyperparams.get("hidden"), | |||
| self.hyperparams.get("act", "relu"), | |||
| self.hyperparams.get("dropout", None), | |||
| self.hyperparams.get("agg", "mean"), | |||
| ).to(self.device) | |||
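| A minimal usage sketch mirroring AutoGAT above; the counts and the DGL graph `g` are placeholder assumptions: | |||
| auto_sage = AutoSAGE(num_features=1433, num_classes=7, device="cpu", init=True) | |||
| log_probs = auto_sage.model(g)  # GraphSAGE.forward returns log-softmax scores | |||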
| @@ -0,0 +1,286 @@ | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| from torch.nn import Linear, ReLU, Sequential, LeakyReLU, Tanh, ELU | |||
| from dgl.nn.pytorch.conv import GraphConv | |||
| from dgl.nn.pytorch.glob import SortPooling | |||
| from torch.nn import BatchNorm1d | |||
| from . import register_model | |||
| from .base import BaseModel, activate_func | |||
| from copy import deepcopy | |||
| from ....utils import get_logger | |||
| LOGGER = get_logger("TopkModel") | |||
| def set_default(args, d): | |||
| for k, v in d.items(): | |||
| if k not in args: | |||
| args[k] = v | |||
| return args | |||
| class ApplyNodeFunc(nn.Module): | |||
| """Update the node feature hv with MLP, BN and ReLU.""" | |||
| def __init__(self, mlp): | |||
| super(ApplyNodeFunc, self).__init__() | |||
| self.mlp = mlp | |||
| self.bn = nn.BatchNorm1d(self.mlp.output_dim) | |||
| def forward(self, h): | |||
| h = self.mlp(h) | |||
| h = self.bn(h) | |||
| h = F.relu(h) | |||
| return h | |||
| class MLP(nn.Module): | |||
| """MLP with linear output""" | |||
| def __init__(self, num_layers, input_dim, hidden_dim, output_dim): | |||
| """MLP layers construction | |||
| Parameters | |||
| ---------- | |||
| num_layers: int | |||
| The number of linear layers | |||
| input_dim: int | |||
| The dimensionality of input features | |||
| hidden_dim: int | |||
| The dimensionality of hidden units at ALL layers | |||
| output_dim: int | |||
| The number of classes for prediction | |||
| """ | |||
| super(MLP, self).__init__() | |||
| self.linear_or_not = True # default is linear model | |||
| self.num_layers = num_layers | |||
| self.output_dim = output_dim | |||
| if num_layers < 1: | |||
| raise ValueError("number of layers should be positive!") | |||
| elif num_layers == 1: | |||
| # Linear model | |||
| self.linear = nn.Linear(input_dim, output_dim) | |||
| else: | |||
| # Multi-layer model | |||
| self.linear_or_not = False | |||
| self.linears = torch.nn.ModuleList() | |||
| self.batch_norms = torch.nn.ModuleList() | |||
| self.linears.append(nn.Linear(input_dim, hidden_dim)) | |||
| for layer in range(num_layers - 2): | |||
| self.linears.append(nn.Linear(hidden_dim, hidden_dim)) | |||
| self.linears.append(nn.Linear(hidden_dim, output_dim)) | |||
| for layer in range(num_layers - 1): | |||
| self.batch_norms.append(nn.BatchNorm1d((hidden_dim))) | |||
| def forward(self, x): | |||
| if self.linear_or_not: | |||
| # If linear model | |||
| return self.linear(x) | |||
| else: | |||
| # If MLP | |||
| h = x | |||
| for i in range(self.num_layers - 1): | |||
| h = F.relu(self.batch_norms[i](self.linears[i](h))) | |||
| return self.linears[-1](h) | |||
| class Topkpool(torch.nn.Module): | |||
| """Topkpool model""" | |||
| def __init__(self, args): | |||
| """model parameters setting | |||
| Parameters | |||
| ---------- | |||
| num_layers: int | |||
| The number of linear layers in the neural network | |||
| num_mlp_layers: int | |||
| The number of linear layers in mlps | |||
| input_dim: int | |||
| The dimensionality of input features | |||
| hidden_dim: int | |||
| The dimensionality of hidden units at ALL layers | |||
| output_dim: int | |||
| The number of classes for prediction | |||
| final_dropout: float | |||
| dropout ratio on the final linear layer | |||
| """ | |||
| super(Topkpool, self).__init__() | |||
| self.args = args | |||
| missing_keys = list( | |||
| set( | |||
| [ | |||
| "features_num", | |||
| "num_class", | |||
| "num_graph_features", | |||
| "num_layers", | |||
| "hidden", | |||
| "dropout", | |||
| ] | |||
| ) | |||
| - set(self.args.keys()) | |||
| ) | |||
| if len(missing_keys) > 0: | |||
| raise Exception("Missing keys: %s." % ",".join(missing_keys)) | |||
| #if not self.num_layer == len(self.args["hidden"]) + 1: | |||
| # LOGGER.warn("Warning: layer size does not match the length of hidden units") | |||
| self.num_graph_features = self.args["num_graph_features"] | |||
| self.num_layers = self.args["num_layers"] | |||
| assert self.num_layers > 2, "Number of layers in Topkpool should not be less than 3" | |||
| input_dim = self.args["features_num"] | |||
| hidden_dim = self.args["hidden"][0] | |||
| final_dropout = self.args["dropout"] | |||
| output_dim = self.args["num_class"] | |||
| # List of MLPs | |||
| self.gcnlayers = torch.nn.ModuleList() | |||
| self.batch_norms = torch.nn.ModuleList() | |||
| for layer in range(self.num_layers - 1): | |||
| if layer == 0: | |||
| self.gcnlayers.append(GraphConv(input_dim, hidden_dim)) | |||
| else: | |||
| self.gcnlayers.append(GraphConv(hidden_dim, hidden_dim)) | |||
| #self.gcnlayers.append(GraphConv(input_dim, hidden_dim)) | |||
| self.batch_norms.append(nn.BatchNorm1d(hidden_dim)) | |||
| # Linear function for graph poolings of output of each layer | |||
| # which maps the output of different layers into a prediction score | |||
| self.linears_prediction = torch.nn.ModuleList() | |||
| # sort pooling (top-k readout over each graph) | |||
| k = 3 | |||
| self.pool = SortPooling(k) | |||
| for layer in range(self.num_layers): | |||
| if layer == 0: | |||
| self.linears_prediction.append( | |||
| nn.Linear(input_dim * k, output_dim)) | |||
| else: | |||
| self.linears_prediction.append( | |||
| nn.Linear(hidden_dim * k, output_dim)) | |||
| self.drop = nn.Dropout(final_dropout) | |||
| #def forward(self, g, h): | |||
| def forward(self, data): | |||
| g, _ = data | |||
| h = g.ndata.pop('feat') | |||
| # list of hidden representation at each layer (including input) | |||
| hidden_rep = [h] | |||
| for i in range(self.num_layers - 1): | |||
| h = self.gcnlayers[i](g, h) | |||
| h = self.batch_norms[i](h) | |||
| h = F.relu(h) | |||
| hidden_rep.append(h) | |||
| score_over_layer = 0 | |||
| # perform pooling over all nodes in each graph in every layer | |||
| for i, h in enumerate(hidden_rep): | |||
| pooled_h = self.pool(g, h) | |||
| #import pdb; pdb.set_trace() | |||
| score_over_layer += self.drop(self.linears_prediction[i](pooled_h)) | |||
| return score_over_layer | |||
| @register_model("topkpool") | |||
| class AutoTopkpool(BaseModel): | |||
| r""" | |||
| AutoTopkpool. The model used in this automodel is a top-k graph pooling model following https://arxiv.org/abs/1905.05178 and https://arxiv.org/abs/1905.02850 | |||
| Parameters | |||
| ---------- | |||
| num_features: `int`. | |||
| The dimension of features. | |||
| num_classes: `int`. | |||
| The number of classes. | |||
| device: `torch.device` or `str` | |||
| The device where the model will run. | |||
| init: `bool`. | |||
| If `True`, the model will be initialized on construction. | |||
| """ | |||
| def __init__( | |||
| self, | |||
| num_features=None, | |||
| num_classes=None, | |||
| device=None, | |||
| init=False, | |||
| num_graph_features=None, | |||
| **args | |||
| ): | |||
| super(AutoTopkpool, self).__init__() | |||
| LOGGER.debug( | |||
| "topkpool __init__ get params num_graph_features {}".format( | |||
| num_graph_features | |||
| ) | |||
| ) | |||
| self.num_features = num_features if num_features is not None else 0 | |||
| self.num_classes = int(num_classes) if num_classes is not None else 0 | |||
| self.num_graph_features = ( | |||
| int(num_graph_features) if num_graph_features is not None else 0 | |||
| ) | |||
| self.device = device if device is not None else "cpu" | |||
| self.params = { | |||
| "features_num": self.num_features, | |||
| "num_class": self.num_classes, | |||
| "num_graph_features": self.num_graph_features, | |||
| } | |||
| self.space = [ | |||
| { | |||
| "parameterName": "hidden", | |||
| "type": "NUMERICAL_LIST", | |||
| "numericalType": "INTEGER", | |||
| "length": 1, | |||
| "minValue": [128], | |||
| "maxValue": [32], | |||
| "scalingType": "LOG", | |||
| "cutPara": (), | |||
| "cutFunc": lambda:1, | |||
| }, | |||
| { | |||
| "parameterName": "dropout", | |||
| "type": "DOUBLE", | |||
| "maxValue": 0.9, | |||
| "minValue": 0.1, | |||
| "scalingType": "LINEAR", | |||
| }, | |||
| { | |||
| "parameterName": "num_layers", | |||
| "type": "INTEGER", | |||
| "minValue": 7, | |||
| "maxValue": 2, | |||
| "scalingType": "LINEAR" | |||
| }, | |||
| ] | |||
| self.hyperparams = { | |||
| "num_layers": 5, | |||
| "hidden": [64], | |||
| "dropout": 0.5 | |||
| } | |||
| self.initialized = False | |||
| if init is True: | |||
| self.initialize() | |||
| def initialize(self): | |||
| if self.initialized: | |||
| return | |||
| self.initialized = True | |||
| LOGGER.debug("topkpool initialize with parameters {}".format(self.params)) | |||
| self.model = Topkpool({**self.params, **self.hyperparams}).to(self.device) | |||
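| As with GIN above, Topkpool.forward unpacks a (graph, label) pair; a hedged sketch with assumed inputs: | |||
| auto_topk = AutoTopkpool(num_features=7, num_classes=2, device="cpu", init=True) | |||
| scores = auto_topk.model((batched_graph, labels))  # SortPooling(k=3) readout | |||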
| @@ -0,0 +1,22 @@ | |||
| from ._model_registry import MODEL_DICT, ModelUniversalRegistry, register_model | |||
| from .base import BaseModel | |||
| from .topkpool import AutoTopkpool | |||
| # from .graph_sage import AutoSAGE | |||
| from .graphsage import AutoSAGE | |||
| from .graph_saint import GraphSAINTAggregationModel | |||
| from .gcn import AutoGCN | |||
| from .gat import AutoGAT | |||
| from .gin import AutoGIN | |||
| __all__ = [ | |||
| "ModelUniversalRegistry", | |||
| "register_model", | |||
| "BaseModel", | |||
| "AutoTopkpool", | |||
| "AutoSAGE", | |||
| "GraphSAINTAggregationModel", | |||
| "AutoGCN", | |||
| "AutoGAT", | |||
| "AutoGIN", | |||
| ] | |||
| @@ -0,0 +1,28 @@ | |||
| import typing as _typing | |||
| from .base import BaseModel | |||
| MODEL_DICT: _typing.Dict[str, _typing.Type[BaseModel]] = {} | |||
| def register_model(name): | |||
| def register_model_cls(cls): | |||
| if name in MODEL_DICT: | |||
| raise ValueError("Cannot register duplicate trainer ({})".format(name)) | |||
| if not issubclass(cls, BaseModel): | |||
| raise ValueError( | |||
| "Trainer ({}: {}) must extend BaseModel".format(name, cls.__name__) | |||
| ) | |||
| MODEL_DICT[name] = cls | |||
| return cls | |||
| return register_model_cls | |||
| class ModelUniversalRegistry: | |||
| @classmethod | |||
| def get_model(cls, name: str) -> _typing.Type[BaseModel]: | |||
| if not isinstance(name, str): | |||
| raise TypeError("model name must be a str") | |||
| if name not in MODEL_DICT: | |||
| raise KeyError("model ({}) not registered".format(name)) | |||
| return MODEL_DICT[name] | |||
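| # Usage sketch (hypothetical names): the decorator populates MODEL_DICT | |||
| # and the registry looks the class back up by its registered name. | |||
| # | |||
| #     @register_model("my_model") | |||
| #     class AutoMyModel(BaseModel): | |||
| #         pass | |||
| # | |||
| #     model_cls = ModelUniversalRegistry.get_model("my_model") | |||
| #     assert model_cls is AutoMyModel | |||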
| @@ -0,0 +1,413 @@ | |||
| """ | |||
| auto graph model | |||
| a list of models with their hyper parameters | |||
| NOTE: neural architecture search (NAS) may be included here | |||
| """ | |||
| import copy | |||
| import logging | |||
| import typing as _typing | |||
| import torch | |||
| import torch.nn.functional as F | |||
| from copy import deepcopy | |||
| base_approach_logger: logging.Logger = logging.getLogger("BaseModel") | |||
| def activate_func(x, func): | |||
| if func == "tanh": | |||
| return torch.tanh(x) | |||
| elif hasattr(F, func): | |||
| return getattr(F, func)(x) | |||
| elif func == "": | |||
| pass | |||
| else: | |||
| raise TypeError("PyTorch does not support activation function {}".format(func)) | |||
| return x | |||
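| # Example: the function dispatches by name through torch.nn.functional, | |||
| # so any attribute of torch.nn.functional is accepted, e.g. | |||
| # | |||
| #     y = activate_func(torch.randn(4, 16), "relu")        # F.relu | |||
| #     y = activate_func(torch.randn(4, 16), "leaky_relu")  # F.leaky_relu | |||
| #     y = activate_func(torch.randn(4, 16), "")            # identity | |||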
| class BaseModel: | |||
| def __init__(self, init=False, *args, **kwargs): | |||
| super(BaseModel, self).__init__() | |||
| def get_hyper_parameter(self): | |||
| return deepcopy(self.hyperparams) | |||
| @property | |||
| def hyper_parameter_space(self): | |||
| return self.space | |||
| @hyper_parameter_space.setter | |||
| def hyper_parameter_space(self, space): | |||
| self.space = space | |||
| def initialize(self): | |||
| pass | |||
| def forward(self): | |||
| pass | |||
| def to(self, device): | |||
| if isinstance(device, (str, torch.device)): | |||
| self.device = device | |||
| if ( | |||
| hasattr(self, "model") | |||
| and self.model is not None | |||
| and isinstance(self.model, torch.nn.Module) | |||
| ): | |||
| self.model.to(self.device) | |||
| return self | |||
| def from_hyper_parameter(self, hp): | |||
| ret_self = self.__class__( | |||
| num_features=self.num_features, | |||
| num_classes=self.num_classes, | |||
| device=self.device, | |||
| init=False, | |||
| ) | |||
| ret_self.hyperparams.update(hp) | |||
| ret_self.params.update(self.params) | |||
| ret_self.initialize() | |||
| return ret_self | |||
| def get_num_classes(self): | |||
| return self.num_classes | |||
| def set_num_classes(self, num_classes): | |||
| self.num_classes = num_classes | |||
| self.params["num_class"] = num_classes | |||
| def get_num_features(self): | |||
| return self.num_features | |||
| def set_num_features(self, num_features): | |||
| self.num_features = num_features | |||
| self.params["features_num"] = self.num_features | |||
| def set_num_graph_features(self, num_graph_features): | |||
| assert hasattr( | |||
| self, "num_graph_features" | |||
| ), "Cannot set graph features for tasks other than graph classification" | |||
| self.num_graph_features = num_graph_features | |||
| self.params["num_graph_features"] = num_graph_features | |||
| class _BaseBaseModel: | |||
| # todo: after renaming the experimental base class _BaseModel to BaseModel, | |||
| # rename this class to _BaseModel | |||
| """ | |||
| The base class for class BaseModel, | |||
| designed to implement some basic functionality of BaseModel. | |||
| -- Designed by ZiXin Sun | |||
| """ | |||
| @classmethod | |||
| def __formulate_device( | |||
| cls, device: _typing.Union[str, torch.device] = ... | |||
| ) -> torch.device: | |||
| if type(device) == torch.device or ( | |||
| type(device) == str and device.strip().lower() != "auto" | |||
| ): | |||
| return torch.device(device) | |||
| elif torch.cuda.is_available() and torch.cuda.device_count() > 0: | |||
| return torch.device("cuda") | |||
| else: | |||
| return torch.device("cpu") | |||
| @property | |||
| def device(self) -> torch.device: | |||
| return self.__device | |||
| @device.setter | |||
| def device(self, __device: _typing.Union[str, torch.device, None]): | |||
| self.__device: torch.device = self.__formulate_device(__device) | |||
| @property | |||
| def model(self) -> _typing.Optional[torch.nn.Module]: | |||
| if self._model is None: | |||
| base_approach_logger.debug( | |||
| "property of model NOT initialized before accessing" | |||
| ) | |||
| return self._model | |||
| @model.setter | |||
| def model(self, _model: torch.nn.Module) -> None: | |||
| if not isinstance(_model, torch.nn.Module): | |||
| raise TypeError( | |||
| "the property of model MUST be an instance of " "torch.nn.Module" | |||
| ) | |||
| self._model = _model | |||
| def _initialize(self): | |||
| raise NotImplementedError | |||
| def initialize(self) -> bool: | |||
| """ | |||
| Initialize the model if it has NOT been initialized yet. | |||
| :return: whether the self._initialize() method was actually called | |||
| """ | |||
| if not self.__is_initialized: | |||
| self._initialize() | |||
| self.__is_initialized = True | |||
| return True | |||
| return False | |||
| # def to(self, *args, **kwargs): | |||
| # """ | |||
| # Because the signature of the to() method in class BaseApproach | |||
| # is inconsistent with the signature of the method | |||
| # in the base class torch.nn.Module, | |||
| # this intermediate overridden method is necessary to | |||
| # work around (bypass) the inspection of the | |||
| # overriding method's signature. | |||
| # :param args: positional arguments list | |||
| # :param kwargs: keyword arguments dict | |||
| # :return: self | |||
| # """ | |||
| # return super(_BaseBaseModel, self).to(*args, **kwargs) | |||
| def forward(self, *args, **kwargs): | |||
| if self.model is not None and isinstance(self.model, torch.nn.Module): | |||
| return self.model(*args, **kwargs) | |||
| else: | |||
| raise NotImplementedError | |||
| def __init__( | |||
| self, | |||
| model: _typing.Optional[torch.nn.Module] = None, | |||
| initialize: bool = False, | |||
| device: _typing.Union[str, torch.device] = ..., | |||
| ): | |||
| if type(initialize) != bool: | |||
| raise TypeError | |||
| super(_BaseBaseModel, self).__init__() | |||
| self.__device: torch.device = self.__formulate_device(device) | |||
| self._model: _typing.Optional[torch.nn.Module] = model | |||
| self.__is_initialized: bool = False | |||
| if initialize: | |||
| self.initialize() | |||
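| # The lazy-initialization contract, sketched with a hypothetical | |||
| # concrete subclass implementing _initialize(): | |||
| # | |||
| #     m = SomeConcreteModel(initialize=False) | |||
| #     m.initialize()  # -> True, _initialize() executed once | |||
| #     m.initialize()  # -> False, already initialized, no-op | |||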
| class _BaseModel(_BaseBaseModel, BaseModel): | |||
| """ | |||
| The upcoming root base class for Model, i.e. BaseModel | |||
| -- Designed by ZiXin Sun | |||
| """ | |||
| # todo: Deprecate and remove the legacy class "BaseModel", | |||
| # then rename this class to "BaseModel", | |||
| # correspondingly, this class will no longer extend | |||
| # the legacy class "BaseModel" after the removal. | |||
| def _initialize(self): | |||
| raise NotImplementedError | |||
| def to(self, device: torch.device): | |||
| self.device = device | |||
| if self.model is not None and isinstance(self.model, torch.nn.Module): | |||
| self.model.to(self.device) | |||
| return super().to(device) | |||
| @property | |||
| def space(self) -> _typing.Sequence[_typing.Dict[str, _typing.Any]]: | |||
| # todo: deprecate and remove in future major version | |||
| return self.__hyper_parameter_space | |||
| @property | |||
| def hyper_parameter_space(self): | |||
| return self.__hyper_parameter_space | |||
| @hyper_parameter_space.setter | |||
| def hyper_parameter_space( | |||
| self, space: _typing.Sequence[_typing.Dict[str, _typing.Any]] | |||
| ): | |||
| self.__hyper_parameter_space = space | |||
| @property | |||
| def hyper_parameter(self) -> _typing.Dict[str, _typing.Any]: | |||
| return self.__hyper_parameter | |||
| @hyper_parameter.setter | |||
| def hyper_parameter(self, _hyper_parameter: _typing.Dict[str, _typing.Any]): | |||
| if not isinstance(_hyper_parameter, dict): | |||
| raise TypeError | |||
| self.__hyper_parameter = _hyper_parameter | |||
| def get_hyper_parameter(self) -> _typing.Dict[str, _typing.Any]: | |||
| """ | |||
| todo: consider deprecating this trivial getter method in the future | |||
| :return: copied hyper parameter | |||
| """ | |||
| return copy.deepcopy(self.__hyper_parameter) | |||
| def __init__( | |||
| self, | |||
| model: _typing.Optional[torch.nn.Module] = None, | |||
| initialize: bool = False, | |||
| hyper_parameter_space: _typing.Sequence[_typing.Any] = ..., | |||
| hyper_parameter: _typing.Dict[str, _typing.Any] = ..., | |||
| device: _typing.Union[str, torch.device] = ..., | |||
| ): | |||
| if type(initialize) != bool: | |||
| raise TypeError | |||
| super(_BaseModel, self).__init__(model, initialize, device) | |||
| if hyper_parameter_space != Ellipsis and isinstance( | |||
| hyper_parameter_space, _typing.Sequence | |||
| ): | |||
| self.__hyper_parameter_space: _typing.Sequence[ | |||
| _typing.Dict[str, _typing.Any] | |||
| ] = hyper_parameter_space | |||
| else: | |||
| self.__hyper_parameter_space: _typing.Sequence[ | |||
| _typing.Dict[str, _typing.Any] | |||
| ] = [] | |||
| if hyper_parameter != Ellipsis and isinstance(hyper_parameter, dict): | |||
| self.__hyper_parameter: _typing.Dict[str, _typing.Any] = hyper_parameter | |||
| else: | |||
| self.__hyper_parameter: _typing.Dict[str, _typing.Any] = {} | |||
| def from_hyper_parameter(self, hyper_parameter: _typing.Dict[str, _typing.Any]): | |||
| raise NotImplementedError | |||
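| # A minimal concrete subclass sketch (hypothetical; real subclasses such | |||
| # as ClassificationModel below carry task-specific state as well): | |||
| # | |||
| #     class _LinearDemoModel(_BaseModel): | |||
| #         def _initialize(self): | |||
| #             self.model = torch.nn.Linear(4, 2) | |||
| # | |||
| #         def from_hyper_parameter(self, hyper_parameter): | |||
| #             new = _LinearDemoModel(initialize=False, device=self.device) | |||
| #             new.hyper_parameter = {**self.hyper_parameter, **hyper_parameter} | |||
| #             new.initialize() | |||
| #             return new | |||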
| class ClassificationModel(_BaseModel): | |||
| def _initialize(self): | |||
| raise NotImplementedError | |||
| def from_hyper_parameter( | |||
| self, hyper_parameter: _typing.Dict[str, _typing.Any] | |||
| ) -> "ClassificationModel": | |||
| new_model: ClassificationModel = self.__class__( | |||
| num_features=self.num_features, | |||
| num_classes=self.num_classes, | |||
| device=self.device, | |||
| init=False, | |||
| ) | |||
| _hyper_parameter = self.hyper_parameter | |||
| _hyper_parameter.update(hyper_parameter) | |||
| new_model.hyper_parameter = _hyper_parameter | |||
| new_model.initialize() | |||
| return new_model | |||
| def __init__( | |||
| self, | |||
| num_features: int = ..., | |||
| num_classes: int = ..., | |||
| num_graph_features: int = ..., | |||
| device: _typing.Union[str, torch.device] = ..., | |||
| hyper_parameter_space: _typing.Sequence[_typing.Any] = ..., | |||
| hyper_parameter: _typing.Dict[str, _typing.Any] = ..., | |||
| init: bool = False, | |||
| **kwargs | |||
| ): | |||
| if "initialize" in kwargs: | |||
| del kwargs["initialize"] | |||
| super(ClassificationModel, self).__init__( | |||
| initialize=init, | |||
| hyper_parameter_space=hyper_parameter_space, | |||
| hyper_parameter=hyper_parameter, | |||
| device=device, | |||
| **kwargs | |||
| ) | |||
| if num_classes != Ellipsis and type(num_classes) == int: | |||
| self.__num_classes: int = num_classes if num_classes > 0 else 0 | |||
| else: | |||
| self.__num_classes: int = 0 | |||
| if num_features != Ellipsis and type(num_features) == int: | |||
| self.__num_features: int = num_features if num_features > 0 else 0 | |||
| else: | |||
| self.__num_features: int = 0 | |||
| if num_graph_features != Ellipsis and type(num_graph_features) == int: | |||
| if num_graph_features > 0: | |||
| self.__num_graph_features: int = num_graph_features | |||
| else: | |||
| self.__num_graph_features: int = 0 | |||
| else: | |||
| self.__num_graph_features: int = 0 | |||
| def __repr__(self) -> str: | |||
| import yaml | |||
| return yaml.dump(self.hyper_parameter) | |||
| @property | |||
| def num_classes(self) -> int: | |||
| return self.__num_classes | |||
| @num_classes.setter | |||
| def num_classes(self, __num_classes: int): | |||
| if type(__num_classes) != int: | |||
| raise TypeError | |||
| if not __num_classes > 0: | |||
| raise ValueError | |||
| self.__num_classes = __num_classes  # positivity enforced above | |||
| @property | |||
| def num_features(self) -> int: | |||
| return self.__num_features | |||
| @num_features.setter | |||
| def num_features(self, __num_features: int): | |||
| if type(__num_features) != int: | |||
| raise TypeError | |||
| if not __num_features > 0: | |||
| raise ValueError | |||
| self.__num_features = __num_features  # positivity enforced above | |||
| def get_num_classes(self) -> int: | |||
| # todo: consider replacing with property with getter and setter | |||
| return self.__num_classes | |||
| def set_num_classes(self, num_classes: int) -> None: | |||
| # todo: consider replacing with property with getter and setter | |||
| if type(num_classes) != int: | |||
| raise TypeError | |||
| self.__num_classes = num_classes if num_classes > 0 else 0 | |||
| def get_num_features(self) -> int: | |||
| # todo: consider replacing with property with getter and setter | |||
| return self.__num_features | |||
| def set_num_features(self, num_features: int): | |||
| # todo: consider replacing with property with getter and setter | |||
| if type(num_features) != int: | |||
| raise TypeError | |||
| self.__num_features = num_features if num_features > 0 else 0 | |||
| def set_num_graph_features(self, num_graph_features: int): | |||
| # todo: consider replacing with property with getter and setter | |||
| if type(num_graph_features) != int: | |||
| raise TypeError | |||
| else: | |||
| if num_graph_features > 0: | |||
| self.__num_graph_features = num_graph_features | |||
| else: | |||
| self.__num_graph_features = 0 | |||
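| # Cloning sketch (hypothetical values; assumes SomeClassifier implements | |||
| # _initialize): from_hyper_parameter above yields an initialized copy | |||
| # with the merged hyper parameters while the original stays untouched. | |||
| # | |||
| #     base = SomeClassifier(num_features=16, num_classes=7, device="cpu") | |||
| #     tuned = base.from_hyper_parameter({"dropout": 0.3}) | |||
| #     assert tuned is not base and tuned.num_classes == 7 | |||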
| class _ClassificationModel(torch.nn.Module): | |||
| def __init__(self): | |||
| super(_ClassificationModel, self).__init__() | |||
| def cls_encode(self, data) -> torch.Tensor: | |||
| raise NotImplementedError | |||
| def cls_decode(self, x: torch.Tensor) -> torch.Tensor: | |||
| raise NotImplementedError | |||
| def cls_forward(self, data) -> torch.Tensor: | |||
| return self.cls_decode(self.cls_encode(data)) | |||
| class ClassificationSupportedSequentialModel(_ClassificationModel): | |||
| def __init__(self): | |||
| super(ClassificationSupportedSequentialModel, self).__init__() | |||
| @property | |||
| def sequential_encoding_layers(self) -> torch.nn.ModuleList: | |||
| raise NotImplementedError | |||
| def cls_encode(self, data) -> torch.Tensor: | |||
| raise NotImplementedError | |||
| def cls_decode(self, x: torch.Tensor) -> torch.Tensor: | |||
| raise NotImplementedError | |||
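| # The encode/decode split allows layer-wise trainers to step through | |||
| # sequential_encoding_layers one layer at a time; a schematic subclass | |||
| # (hypothetical, operating on a PyG-style data object with a .x tensor): | |||
| # | |||
| #     class TwoLayerDemo(ClassificationSupportedSequentialModel): | |||
| #         def __init__(self): | |||
| #             super().__init__() | |||
| #             self._layers = torch.nn.ModuleList( | |||
| #                 [torch.nn.Linear(16, 8), torch.nn.Linear(8, 4)] | |||
| #             ) | |||
| # | |||
| #         @property | |||
| #         def sequential_encoding_layers(self) -> torch.nn.ModuleList: | |||
| #             return self._layers | |||
| # | |||
| #         def cls_encode(self, data) -> torch.Tensor: | |||
| #             x = data.x | |||
| #             for layer in self._layers: | |||
| #                 x = layer(x) | |||
| #             return x | |||
| # | |||
| #         def cls_decode(self, x: torch.Tensor) -> torch.Tensor: | |||
| #             return torch.nn.functional.log_softmax(x, dim=1) | |||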
| @@ -3,7 +3,7 @@ import torch.nn.functional as F | |||
| from torch_geometric.nn import GATConv | |||
| from . import register_model | |||
| from .base import BaseModel, activate_func | |||
| from ...utils import get_logger | |||
| from ....utils import get_logger | |||
| LOGGER = get_logger("GATModel") | |||
| @@ -161,7 +161,6 @@ class AutoGAT(BaseModel): | |||
| self.num_features = num_features if num_features is not None else 0 | |||
| self.num_classes = int(num_classes) if num_classes is not None else 0 | |||
| self.device = device if device is not None else "cpu" | |||
| self.init = True | |||
| self.params = { | |||
| "features_num": self.num_features, | |||
| @@ -6,7 +6,7 @@ from torch_geometric.nn.conv import GCNConv | |||
| import autogl.data | |||
| from . import register_model | |||
| from .base import BaseModel, activate_func, ClassificationSupportedSequentialModel | |||
| from ...utils import get_logger | |||
| from ....utils import get_logger | |||
| LOGGER = get_logger("GCNModel") | |||
| @@ -6,7 +6,7 @@ from torch.nn import BatchNorm1d | |||
| from . import register_model | |||
| from .base import BaseModel, activate_func | |||
| from copy import deepcopy | |||
| from ...utils import get_logger | |||
| from ....utils import get_logger | |||
| LOGGER = get_logger("GINModel") | |||
| @@ -163,7 +163,6 @@ class AutoGIN(BaseModel): | |||
| int(num_graph_features) if num_graph_features is not None else 0 | |||
| ) | |||
| self.device = device if device is not None else "cpu" | |||
| self.init = True | |||
| self.params = { | |||
| "features_num": self.num_features, | |||
| @@ -0,0 +1,407 @@ | |||
| import typing as _typing | |||
| import torch.nn.functional | |||
| from torch_geometric.nn.conv import MessagePassing | |||
| from torch_sparse import SparseTensor, matmul | |||
| from . import register_model | |||
| from .base import ClassificationModel, ClassificationSupportedSequentialModel | |||
| class _GraphSAINTAggregationLayers: | |||
| class MultiOrderAggregationLayer(torch.nn.Module): | |||
| class Order0Aggregator(torch.nn.Module): | |||
| def __init__( | |||
| self, | |||
| input_dimension: int, | |||
| output_dimension: int, | |||
| bias: bool = True, | |||
| activation: _typing.Optional[str] = "ReLU", | |||
| batch_norm: bool = True, | |||
| ): | |||
| super().__init__() | |||
| if not type(input_dimension) == type(output_dimension) == int: | |||
| raise TypeError | |||
| if not (input_dimension > 0 and output_dimension > 0): | |||
| raise ValueError | |||
| if not type(bias) == bool: | |||
| raise TypeError | |||
| self.__linear_transform = torch.nn.Linear( | |||
| input_dimension, output_dimension, bias | |||
| ) | |||
| self.__linear_transform.reset_parameters() | |||
| if type(activation) == str: | |||
| if activation.lower() == "ReLU".lower(): | |||
| self.__activation = torch.nn.functional.relu | |||
| elif activation.lower() == "elu": | |||
| self.__activation = torch.nn.functional.elu | |||
| elif hasattr(torch.nn.functional, activation) and callable( | |||
| getattr(torch.nn.functional, activation) | |||
| ): | |||
| self.__activation = getattr(torch.nn.functional, activation) | |||
| else: | |||
| self.__activation = lambda x: x | |||
| else: | |||
| self.__activation = lambda x: x | |||
| if type(batch_norm) != bool: | |||
| raise TypeError | |||
| else: | |||
| self.__optional_batch_normalization: _typing.Optional[ | |||
| torch.nn.BatchNorm1d | |||
| ] = ( | |||
| torch.nn.BatchNorm1d(output_dimension, 1e-8) | |||
| if batch_norm | |||
| else None | |||
| ) | |||
| def forward( | |||
| self, | |||
| x: _typing.Union[ | |||
| torch.Tensor, _typing.Tuple[torch.Tensor, torch.Tensor] | |||
| ], | |||
| _edge_index: torch.Tensor, | |||
| _edge_weight: _typing.Optional[torch.Tensor] = None, | |||
| _size: _typing.Optional[_typing.Tuple[int, int]] = None, | |||
| ) -> torch.Tensor: | |||
| __output: torch.Tensor = self.__linear_transform(x) | |||
| if self.__activation is not None and callable(self.__activation): | |||
| __output: torch.Tensor = self.__activation(__output) | |||
| if self.__optional_batch_normalization is not None and isinstance( | |||
| self.__optional_batch_normalization, torch.nn.BatchNorm1d | |||
| ): | |||
| __output: torch.Tensor = self.__optional_batch_normalization( | |||
| __output | |||
| ) | |||
| return __output | |||
| class Order1Aggregator(MessagePassing): | |||
| def __init__( | |||
| self, | |||
| input_dimension: int, | |||
| output_dimension: int, | |||
| bias: bool = True, | |||
| activation: _typing.Optional[str] = "ReLU", | |||
| batch_norm: bool = True, | |||
| ): | |||
| super().__init__(aggr="add") | |||
| if not type(input_dimension) == type(output_dimension) == int: | |||
| raise TypeError | |||
| if not (input_dimension > 0 and output_dimension > 0): | |||
| raise ValueError | |||
| if not type(bias) == bool: | |||
| raise TypeError | |||
| self.__linear_transform = torch.nn.Linear( | |||
| input_dimension, output_dimension, bias | |||
| ) | |||
| self.__linear_transform.reset_parameters() | |||
| if type(activation) == str: | |||
| if activation.lower() == "ReLU".lower(): | |||
| self.__activation = torch.nn.functional.relu | |||
| elif activation.lower() == "elu": | |||
| self.__activation = torch.nn.functional.elu | |||
| elif hasattr(torch.nn.functional, activation) and callable( | |||
| getattr(torch.nn.functional, activation) | |||
| ): | |||
| self.__activation = getattr(torch.nn.functional, activation) | |||
| else: | |||
| self.__activation = lambda x: x | |||
| else: | |||
| self.__activation = lambda x: x | |||
| if type(batch_norm) != bool: | |||
| raise TypeError | |||
| else: | |||
| self.__optional_batch_normalization: _typing.Optional[ | |||
| torch.nn.BatchNorm1d | |||
| ] = ( | |||
| torch.nn.BatchNorm1d(output_dimension, 1e-8) | |||
| if batch_norm | |||
| else None | |||
| ) | |||
| def forward( | |||
| self, | |||
| x: _typing.Union[ | |||
| torch.Tensor, _typing.Tuple[torch.Tensor, torch.Tensor] | |||
| ], | |||
| _edge_index: torch.Tensor, | |||
| _edge_weight: _typing.Optional[torch.Tensor] = None, | |||
| _size: _typing.Optional[_typing.Tuple[int, int]] = None, | |||
| ) -> torch.Tensor: | |||
| if type(x) == torch.Tensor: | |||
| x: _typing.Tuple[torch.Tensor, torch.Tensor] = (x, x) | |||
| __output = self.propagate( | |||
| _edge_index, x=x, edge_weight=_edge_weight, size=_size | |||
| ) | |||
| __output: torch.Tensor = self.__linear_transform(__output) | |||
| if self.__activation is not None and callable(self.__activation): | |||
| __output: torch.Tensor = self.__activation(__output) | |||
| if self.__optional_batch_normalization is not None and isinstance( | |||
| self.__optional_batch_normalization, torch.nn.BatchNorm1d | |||
| ): | |||
| __output: torch.Tensor = self.__optional_batch_normalization( | |||
| __output | |||
| ) | |||
| return __output | |||
| def message( | |||
| self, x_j: torch.Tensor, edge_weight: _typing.Optional[torch.Tensor] | |||
| ) -> torch.Tensor: | |||
| return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j | |||
| def message_and_aggregate( | |||
| self, | |||
| adj_t: SparseTensor, | |||
| x: _typing.Union[ | |||
| torch.Tensor, _typing.Tuple[torch.Tensor, torch.Tensor] | |||
| ], | |||
| ) -> torch.Tensor: | |||
| return matmul(adj_t, x[0], reduce=self.aggr) | |||
| @property | |||
| def integral_output_dimension(self) -> int: | |||
| return (self._order + 1) * self._each_order_output_dimension | |||
| def __init__( | |||
| self, | |||
| _input_dimension: int, | |||
| _each_order_output_dimension: int, | |||
| _order: int, | |||
| bias: bool = True, | |||
| activation: _typing.Optional[str] = "ReLU", | |||
| batch_norm: bool = True, | |||
| _dropout: _typing.Optional[float] = ..., | |||
| ): | |||
| super().__init__() | |||
| if not ( | |||
| type(_input_dimension) == type(_order) == int | |||
| and type(_each_order_output_dimension) == int | |||
| ): | |||
| raise TypeError | |||
| if _input_dimension <= 0 or _each_order_output_dimension <= 0: | |||
| raise ValueError | |||
| if _order not in (0, 1): | |||
| raise ValueError("Unsupported order number") | |||
| self._input_dimension: int = _input_dimension | |||
| self._each_order_output_dimension: int = _each_order_output_dimension | |||
| self._order: int = _order | |||
| if type(bias) != bool: | |||
| raise TypeError | |||
| self.__order0_transform = self.Order0Aggregator( | |||
| self._input_dimension, | |||
| self._each_order_output_dimension, | |||
| bias, | |||
| activation, | |||
| batch_norm, | |||
| ) | |||
| if _order == 1: | |||
| self.__order1_transform = self.Order1Aggregator( | |||
| self._input_dimension, | |||
| self._each_order_output_dimension, | |||
| bias, | |||
| activation, | |||
| batch_norm, | |||
| ) | |||
| else: | |||
| self.__order1_transform = None | |||
| if _dropout is not None and type(_dropout) == float: | |||
| if _dropout < 0: | |||
| _dropout = 0 | |||
| if _dropout > 1: | |||
| _dropout = 1 | |||
| self.__optional_dropout: _typing.Optional[ | |||
| torch.nn.Dropout | |||
| ] = torch.nn.Dropout(_dropout) | |||
| else: | |||
| self.__optional_dropout: _typing.Optional[torch.nn.Dropout] = None | |||
| def _forward( | |||
| self, | |||
| x: _typing.Union[torch.Tensor, _typing.Tuple[torch.Tensor, torch.Tensor]], | |||
| edge_index: torch.Tensor, | |||
| edge_weight: _typing.Optional[torch.Tensor] = None, | |||
| size: _typing.Optional[_typing.Tuple[int, int]] = None, | |||
| ) -> torch.Tensor: | |||
| if self.__order1_transform is not None and isinstance( | |||
| self.__order1_transform, self.Order1Aggregator | |||
| ): | |||
| __output: torch.Tensor = torch.cat( | |||
| [ | |||
| self.__order0_transform(x, edge_index, edge_weight, size), | |||
| self.__order1_transform(x, edge_index, edge_weight, size), | |||
| ], | |||
| dim=1, | |||
| ) | |||
| else: | |||
| __output: torch.Tensor = self.__order0_transform( | |||
| x, edge_index, edge_weight, size | |||
| ) | |||
| if self.__optional_dropout is not None and isinstance( | |||
| self.__optional_dropout, torch.nn.Dropout | |||
| ): | |||
| __output: torch.Tensor = self.__optional_dropout(__output) | |||
| return __output | |||
| def forward(self, data) -> torch.Tensor: | |||
| x: torch.Tensor = getattr(data, "x") | |||
| if type(x) != torch.Tensor: | |||
| raise TypeError | |||
| edge_index: torch.LongTensor = getattr(data, "edge_index") | |||
| if type(edge_index) != torch.Tensor: | |||
| raise TypeError | |||
| edge_weight: _typing.Optional[torch.Tensor] = getattr( | |||
| data, "edge_weight", None | |||
| ) | |||
| if edge_weight is not None and type(edge_weight) != torch.Tensor: | |||
| raise TypeError | |||
| return self._forward(x, edge_index, edge_weight) | |||
| class WrappedDropout(torch.nn.Module): | |||
| def __init__(self, dropout_module: torch.nn.Dropout): | |||
| super().__init__() | |||
| self.__dropout_module: torch.nn.Dropout = dropout_module | |||
| def forward(self, tensor_or_data) -> torch.Tensor: | |||
| if type(tensor_or_data) == torch.Tensor: | |||
| return self.__dropout_module(tensor_or_data) | |||
| elif ( | |||
| hasattr(tensor_or_data, "x") | |||
| and type(getattr(tensor_or_data, "x")) == torch.Tensor | |||
| ): | |||
| return self.__dropout_module(getattr(tensor_or_data, "x")) | |||
| else: | |||
| raise TypeError | |||
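| # WrappedDropout lets a plain Dropout sit first in the encoder ModuleList | |||
| # below by accepting either a raw tensor or a data object exposing .x: | |||
| # | |||
| #     drop = _GraphSAINTAggregationLayers.WrappedDropout(torch.nn.Dropout(0.2)) | |||
| #     out = drop(torch.randn(5, 3))  # tensor in, tensor out | |||
| #     out = drop(data)               # or reads data.x from a Data object | |||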
| class GraphSAINTMultiOrderAggregationModel(ClassificationSupportedSequentialModel): | |||
| def __init__( | |||
| self, | |||
| num_features: int, | |||
| num_classes: int, | |||
| _output_dimension_for_each_order: int, | |||
| _layers_order_list: _typing.Sequence[int], | |||
| _pre_dropout: float, | |||
| _layers_dropout: _typing.Union[float, _typing.Sequence[float]], | |||
| activation: _typing.Optional[str] = "ReLU", | |||
| bias: bool = True, | |||
| batch_norm: bool = True, | |||
| normalize: bool = True, | |||
| ): | |||
| super(GraphSAINTMultiOrderAggregationModel, self).__init__() | |||
| if type(_output_dimension_for_each_order) != int: | |||
| raise TypeError | |||
| if not _output_dimension_for_each_order > 0: | |||
| raise ValueError | |||
| self._layers_order_list: _typing.Sequence[int] = _layers_order_list | |||
| if isinstance(_layers_dropout, _typing.Sequence): | |||
| if len(_layers_dropout) != len(_layers_order_list): | |||
| raise ValueError | |||
| else: | |||
| self._layers_dropout: _typing.Sequence[float] = _layers_dropout | |||
| elif type(_layers_dropout) == float: | |||
| if _layers_dropout < 0: | |||
| _layers_dropout = 0 | |||
| if _layers_dropout > 1: | |||
| _layers_dropout = 1 | |||
| self._layers_dropout: _typing.Sequence[float] = [ | |||
| _layers_dropout for _ in _layers_order_list | |||
| ] | |||
| else: | |||
| raise TypeError | |||
| if type(_pre_dropout) != float: | |||
| raise TypeError | |||
| else: | |||
| if _pre_dropout < 0: | |||
| _pre_dropout = 0 | |||
| if _pre_dropout > 1: | |||
| _pre_dropout = 1 | |||
| self.__sequential_encoding_layers: torch.nn.ModuleList = torch.nn.ModuleList( | |||
| ( | |||
| _GraphSAINTAggregationLayers.WrappedDropout( | |||
| torch.nn.Dropout(_pre_dropout) | |||
| ), | |||
| _GraphSAINTAggregationLayers.MultiOrderAggregationLayer( | |||
| num_features, | |||
| _output_dimension_for_each_order, | |||
| _layers_order_list[0], | |||
| bias, | |||
| activation, | |||
| batch_norm, | |||
| _layers_dropout[0], | |||
| ), | |||
| ) | |||
| ) | |||
| for _layer_index in range(1, len(_layers_order_list)): | |||
| self.__sequential_encoding_layers.append( | |||
| _GraphSAINTAggregationLayers.MultiOrderAggregationLayer( | |||
| self.__sequential_encoding_layers[-1].integral_output_dimension, | |||
| _output_dimension_for_each_order, | |||
| _layers_order_list[_layer_index], | |||
| bias, | |||
| activation, | |||
| batch_norm, | |||
| _layers_dropout[_layer_index], | |||
| ) | |||
| ) | |||
| self.__apply_normalize: bool = normalize | |||
| self.__linear_transform: torch.nn.Linear = torch.nn.Linear( | |||
| self.__sequential_encoding_layers[-1].integral_output_dimension, | |||
| num_classes, | |||
| bias, | |||
| ) | |||
| self.__linear_transform.reset_parameters() | |||
| def cls_decode(self, x: torch.Tensor) -> torch.Tensor: | |||
| if self.__apply_normalize: | |||
| x: torch.Tensor = torch.nn.functional.normalize(x, p=2, dim=1) | |||
| return torch.nn.functional.log_softmax(self.__linear_transform(x), dim=1) | |||
| def cls_encode(self, data) -> torch.Tensor: | |||
| if type(getattr(data, "x")) != torch.Tensor: | |||
| raise TypeError | |||
| if type(getattr(data, "edge_index")) != torch.Tensor: | |||
| raise TypeError | |||
| if ( | |||
| getattr(data, "edge_weight", None) is not None | |||
| and type(getattr(data, "edge_weight")) != torch.Tensor | |||
| ): | |||
| raise TypeError | |||
| for encoding_layer in self.__sequential_encoding_layers: | |||
| setattr(data, "x", encoding_layer(data)) | |||
| return getattr(data, "x") | |||
| @property | |||
| def sequential_encoding_layers(self) -> torch.nn.ModuleList: | |||
| return self.__sequential_encoding_layers | |||
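| # Instantiation sketch with hypothetical sizes: two aggregation layers of | |||
| # order 1, each emitting 64 features per order, i.e. 128 features after | |||
| # the order-0/order-1 concatenation performed in _forward above: | |||
| # | |||
| #     encoder = GraphSAINTMultiOrderAggregationModel( | |||
| #         num_features=32, num_classes=7, | |||
| #         _output_dimension_for_each_order=64, | |||
| #         _layers_order_list=(1, 1), | |||
| #         _pre_dropout=0.2, _layers_dropout=0.1, | |||
| #     ) | |||
| #     # log_probs = encoder.cls_decode(encoder.cls_encode(data)) | |||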
| @register_model("GraphSAINTAggregationModel") | |||
| class GraphSAINTAggregationModel(ClassificationModel): | |||
| def __init__( | |||
| self, | |||
| num_features: int = ..., | |||
| num_classes: int = ..., | |||
| device: _typing.Union[str, torch.device] = ..., | |||
| init: bool = False, | |||
| **kwargs | |||
| ): | |||
| super(GraphSAINTAggregationModel, self).__init__( | |||
| num_features, num_classes, device=device, init=init, **kwargs | |||
| ) | |||
| # todo: Initialize with default hyper parameter space and hyper parameter | |||
| def _initialize(self): | |||
| """ Initialize model """ | |||
| self.model = GraphSAINTMultiOrderAggregationModel( | |||
| self.num_features, | |||
| self.num_classes, | |||
| self.hyper_parameter.get("output_dimension_for_each_order"), | |||
| self.hyper_parameter.get("layers_order_list"), | |||
| self.hyper_parameter.get("pre_dropout"), | |||
| self.hyper_parameter.get("layers_dropout"), | |||
| self.hyper_parameter.get("activation", "ReLU"), | |||
| bool(self.hyper_parameter.get("bias", True)), | |||
| bool(self.hyper_parameter.get("batch_norm", True)), | |||
| bool(self.hyper_parameter.get("normalize", True)), | |||
| ).to(self.device) | |||
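| # _initialize above reads everything from self.hyper_parameter; a | |||
| # hypothetical configuration sketch: | |||
| # | |||
| #     model = GraphSAINTAggregationModel( | |||
| #         num_features=32, num_classes=7, device="cpu", init=False) | |||
| #     model.hyper_parameter = { | |||
| #         "output_dimension_for_each_order": 64, | |||
| #         "layers_order_list": (1, 1), | |||
| #         "pre_dropout": 0.2, | |||
| #         "layers_dropout": 0.1, | |||
| #     } | |||
| #     model.initialize()  # builds the aggregation model on self.device | |||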
| @@ -6,7 +6,7 @@ import torch.nn.functional | |||
| import autogl.data | |||
| from . import register_model | |||
| from .base import BaseModel, activate_func, ClassificationSupportedSequentialModel | |||
| from ...utils import get_logger | |||
| from ....utils import get_logger | |||
| LOGGER = get_logger("SAGEModel") | |||
| @@ -238,7 +238,6 @@ class AutoSAGE(BaseModel): | |||
| self.num_features = num_features if num_features is not None else 0 | |||
| self.num_classes = int(num_classes) if num_classes is not None else 0 | |||
| self.device = device if device is not None else "cpu" | |||
| self.init = True | |||
| self.params = { | |||
| "features_num": self.num_features, | |||
| @@ -4,7 +4,7 @@ from torch_geometric.nn import GraphConv, TopKPooling | |||
| from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp | |||
| from . import register_model | |||
| from .base import BaseModel, activate_func | |||
| from ...utils import get_logger | |||
| from ....utils import get_logger | |||
| LOGGER = get_logger("TopkModel") | |||
| @@ -126,7 +126,6 @@ class AutoTopkpool(BaseModel): | |||
| int(num_graph_features) if num_graph_features is not None else 0 | |||
| ) | |||
| self.device = device if device is not None else "cpu" | |||
| self.init = True | |||
| self.params = { | |||
| "features_num": self.num_features, | |||
| @@ -17,6 +17,8 @@ import torch.multiprocessing as mp | |||
| from ...utils import get_logger | |||
| from ...backend import DependentBackend | |||
| LOGGER = get_logger("graph classification solver") | |||
| @@ -71,6 +73,7 @@ class GraphClassificationFullTrainer(BaseGraphClassificationTrainer): | |||
| feval=[Logloss], | |||
| loss="nll_loss", | |||
| lr_scheduler_type=None, | |||
| criterion=None, | |||
| *args, | |||
| **kwargs | |||
| ): | |||
| @@ -124,6 +127,9 @@ class GraphClassificationFullTrainer(BaseGraphClassificationTrainer): | |||
| self.initialized = False | |||
| self.device = device | |||
| self.pyg_dgl = DependentBackend.get_backend_name() | |||
| self.criterion = criterion | |||
| self.space = [ | |||
| { | |||
| "parameterName": "max_epoch", | |||
| @@ -233,39 +239,61 @@ class GraphClassificationFullTrainer(BaseGraphClassificationTrainer): | |||
| self.model.model.train() | |||
| loss_all = 0 | |||
| for data in train_loader: | |||
| data = data.to(self.device) | |||
| optimizer.zero_grad() | |||
| output = self.model.model(data) | |||
| # loss = F.nll_loss(output, data.y) | |||
| if hasattr(F, self.loss): | |||
| loss = getattr(F, self.loss)(output, data.y) | |||
| else: | |||
| raise TypeError( | |||
| "PyTorch does not support loss type {}".format(self.loss) | |||
| ) | |||
| loss.backward() | |||
| loss_all += data.num_graphs * loss.item() | |||
| if self.pyg_dgl == 'pyg': | |||
| data = data.to(self.device) | |||
| optimizer.zero_grad() | |||
| output = self.model.model(data) | |||
| # loss = F.nll_loss(output, data.y) | |||
| if hasattr(F, self.loss): | |||
| loss = getattr(F, self.loss)(output, data.y) | |||
| else: | |||
| raise TypeError( | |||
| "PyTorch does not support loss type {}".format(self.loss) | |||
| ) | |||
| loss.backward() | |||
| loss_all += data.num_graphs * loss.item() | |||
| elif self.pyg_dgl == 'dgl': | |||
| data = [data[i].to(self.device) for i in range(len(data))] | |||
| _, labels = data | |||
| optimizer.zero_grad() | |||
| output = self.model.model(data) | |||
| if hasattr(F, self.loss): | |||
| loss = getattr(F, self.loss)(output, labels) | |||
| else: | |||
| raise TypeError( | |||
| "PyTorch does not support loss type {}".format(self.loss) | |||
| ) | |||
| loss.backward() | |||
| loss_all += len(labels) * loss.item() | |||
| optimizer.step() | |||
| if self.lr_scheduler_type: | |||
| scheduler.step() | |||
| # loss = loss_all / len(train_loader.dataset) | |||
| # train_loss = self.evaluate(train_loader) | |||
| if valid_loader is not None: | |||
| eval_func = ( | |||
| self.feval if not isinstance(self.feval, list) else self.feval[0] | |||
| ) | |||
| val_loss = self._evaluate(valid_loader, eval_func) | |||
| if eval_func.is_higher_better(): | |||
| val_loss = -val_loss | |||
| self.early_stopping(val_loss, self.model.model) | |||
| if self.early_stopping.early_stop: | |||
| LOGGER.debug("Early stopping at", epoch) | |||
| break | |||
| if valid_loader is not None: | |||
| self.early_stopping.load_checkpoint(self.model.model) | |||
| def predict_only(self, loader): | |||
| def predict_only(self, loader, return_label=False): | |||
| """ | |||
| The function of predicting on the given dataset and mask. | |||
| @@ -281,11 +309,25 @@ class GraphClassificationFullTrainer(BaseGraphClassificationTrainer): | |||
| """ | |||
| self.model.model.eval() | |||
| pred = [] | |||
| label = [] | |||
| for data in loader: | |||
| data = data.to(self.device) | |||
| pred.append(self.model.model(data)) | |||
| if self.pyg_dgl == 'pyg': | |||
| data = data.to(self.device) | |||
| pred.append(self.model.model(data)) | |||
| label.append(data.y) | |||
| elif self.pyg_dgl == 'dgl': | |||
| data = [data[i].to(self.device) for i in range(len(data))] | |||
| _, labels = data | |||
| output = self.model.model(data) | |||
| pred.append(output) | |||
| label.append(labels) | |||
| ret = torch.cat(pred, 0) | |||
| return ret | |||
| label = torch.cat(label, 0) | |||
| if return_label: | |||
| return ret, label | |||
| else: | |||
| return ret | |||
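| # With return_label=True the raw scores come back together with the | |||
| # labels gathered from the same loader, which _evaluate below relies on; | |||
| # a sketch (trainer denotes an initialized instance of this class): | |||
| # | |||
| #     scores, labels = trainer.predict_only(loader, return_label=True) | |||
| #     accuracy = (scores.max(1)[1] == labels).float().mean().item() | |||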
| def train(self, dataset, keep_valid_result=True): | |||
| """ | |||
| @@ -332,6 +374,7 @@ class GraphClassificationFullTrainer(BaseGraphClassificationTrainer): | |||
| ------- | |||
| The prediction result of ``predict_proba``. | |||
| """ | |||
| loader = utils.graph_get_split( | |||
| dataset, mask, batch_size=self.batch_size, num_workers=self.num_workers | |||
| ) | |||
| @@ -360,12 +403,23 @@ class GraphClassificationFullTrainer(BaseGraphClassificationTrainer): | |||
| ) | |||
| return self._predict_proba(loader, in_log_format) | |||
| def _predict_proba(self, loader, in_log_format=False): | |||
| ret = self.predict_only(loader) | |||
| if in_log_format is True: | |||
| return ret | |||
| def _predict_proba(self, loader, in_log_format=False, return_label=False): | |||
| if return_label: | |||
| ret, label = self.predict_only(loader, return_label=True) | |||
| else: | |||
| return torch.exp(ret) | |||
| ret = self.predict_only(loader, return_label=False) | |||
| if self.pyg_dgl == 'dgl': | |||
| ret = F.log_softmax(ret, dim=1) | |||
| if in_log_format is False: | |||
| ret = torch.exp(ret) | |||
| if return_label: | |||
| return ret, label | |||
| else: | |||
| return ret | |||
| def get_valid_predict(self): | |||
| # """Get the valid result.""" | |||
| @@ -430,23 +484,33 @@ class GraphClassificationFullTrainer(BaseGraphClassificationTrainer): | |||
| res: The evaluation result on the given dataset. | |||
| """ | |||
| loader = utils.graph_get_split( | |||
| dataset, mask, batch_size=self.batch_size, num_workers=self.num_workers | |||
| ) | |||
| return self._evaluate(loader, feval) | |||
| def _evaluate(self, loader, feval=None): | |||
| if feval is None: | |||
| feval = self.feval | |||
| else: | |||
| feval = get_feval(feval) | |||
| y_pred_prob = self._predict_proba(loader=loader) | |||
| y_pred_prob, y_true = self._predict_proba(loader=loader, return_label=True) | |||
| y_pred = y_pred_prob.max(1)[1] | |||
| y_true_tmp = [] | |||
| for data in loader: | |||
| y_true_tmp.append(data.y) | |||
| y_true = torch.cat(y_true_tmp, 0) | |||
| if not isinstance(feval, list): | |||
| feval = [feval] | |||