|
- import collections
- import os.path as osp
-
- import torch.utils.data
-
- from .makedirs import makedirs
-
-
def to_list(x):
    """Wrap a single value in a list; pass real iterables through unchanged.

    Args:
        x: Any object. Strings are treated as single values, not as
            iterables of characters.

    Returns:
        ``[x]`` if ``x`` is a string or is not iterable, otherwise ``x``
        itself (the same object, not a copy and not necessarily a ``list``).
    """
    # ``collections.Iterable`` was removed in Python 3.10; the ABC lives in
    # ``collections.abc``. Imported locally so this function does not rely on
    # the submodule having been loaded as a side effect of other imports.
    from collections.abc import Iterable

    if isinstance(x, str) or not isinstance(x, Iterable):
        return [x]
    return x
-
-
def files_exist(files):
    """Return whether every path in ``files`` exists on disk.

    Args:
        files: Iterable of filesystem paths.

    Returns:
        bool: ``True`` if ``osp.exists`` holds for every path (vacuously
        ``True`` for an empty iterable), ``False`` otherwise.
    """
    # Generator instead of a list so ``all`` stops at the first missing path
    # rather than stat-ing every remaining file.
    return all(osp.exists(f) for f in files)
-
-
class Dataset(torch.utils.data.Dataset):
    r"""Dataset base class for creating graph datasets.
    See `here <https://rusty1s.github.io/pycogdl/build/html/notes/
    create_dataset.html>`__ for the accompanying tutorial.

    Subclasses provide :meth:`raw_file_names`, :meth:`processed_file_names`,
    :meth:`download`, :meth:`process`, :meth:`__len__` and :meth:`get`; every
    one of them raises :class:`NotImplementedError` here.

    On construction the dataset calls :meth:`download` if any raw file is
    missing and then :meth:`process` if any processed file is missing.

    Args:
        root (string): Root directory where the dataset should be saved.
        transform (callable, optional): A function/transform that takes in an
            :obj:`cogdl.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`cogdl.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`cogdl.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
    """

    @property
    def raw_file_names(self):
        r"""The name of the files to find in the :obj:`self.raw_dir` folder in
        order to skip the download."""
        raise NotImplementedError

    @property
    def processed_file_names(self):
        r"""The name of the files to find in the :obj:`self.processed_dir`
        folder in order to skip the processing."""
        raise NotImplementedError

    def download(self):
        r"""Downloads the dataset to the :obj:`self.raw_dir` folder."""
        raise NotImplementedError

    def process(self):
        r"""Processes the dataset to the :obj:`self.processed_dir` folder."""
        raise NotImplementedError

    def __len__(self):
        r"""The number of examples in the dataset."""
        raise NotImplementedError

    def get(self, idx):
        r"""Gets the data object at index :obj:`idx`."""
        raise NotImplementedError

    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        super(Dataset, self).__init__()

        # The path is normalised first and ``~`` is expanded afterwards.
        self.root = osp.expanduser(osp.normpath(root))
        self.raw_dir = osp.join(self.root, "raw")
        self.processed_dir = osp.join(self.root, "processed")
        # This class only applies ``transform`` (in ``__getitem__``);
        # ``pre_transform`` and ``pre_filter`` are stored for subclasses.
        self.transform = transform
        self.pre_transform = pre_transform
        self.pre_filter = pre_filter

        # Both steps are no-ops when their target files already exist.
        self._download()
        self._process()

    @property
    def get_label_number(self):
        r"""Get the number of labels in this dataset as dict.

        Only the first data object (:obj:`self[0]`) is inspected.

        Returns:
            dict: Maps every distinct value in :obj:`self[0].y` to the number
            of entries of :obj:`y` holding that value.
        """
        # Fetch the item exactly once: each ``self[0]`` re-runs ``get`` and
        # ``transform``, so indexing inside a per-label loop is wasteful and
        # can disagree with itself when the transform is not deterministic.
        y = self[0].y
        # ``unique`` returns the sorted distinct values together with their
        # occurrence counts in a single pass.
        labels, counts = y.unique(return_counts=True)
        return dict(zip(labels.tolist(), counts.tolist()))

    @property
    def num_features(self):
        r"""Returns the number of features per node in the graph."""
        return self[0].num_features

    @property
    def raw_paths(self):
        r"""The filepaths to find in order to skip the download."""
        files = to_list(self.raw_file_names)
        return [osp.join(self.raw_dir, f) for f in files]

    @property
    def processed_paths(self):
        r"""The filepaths to find in the :obj:`self.processed_dir`
        folder in order to skip the processing."""
        files = to_list(self.processed_file_names)
        return [osp.join(self.processed_dir, f) for f in files]

    def _download(self):
        r"""Creates :obj:`self.raw_dir` and calls :meth:`download`, unless
        every path in :obj:`self.raw_paths` already exists."""
        if files_exist(self.raw_paths):  # pragma: no cover
            return

        makedirs(self.raw_dir)
        self.download()

    def _process(self):
        r"""Creates :obj:`self.processed_dir` and calls :meth:`process`,
        unless every path in :obj:`self.processed_paths` already exists.
        Progress messages are printed around the processing step."""
        if files_exist(self.processed_paths):  # pragma: no cover
            return

        print("Processing...")

        makedirs(self.processed_dir)
        self.process()

        print("Done!")

    def __getitem__(self, idx):  # pragma: no cover
        r"""Gets the data object at index :obj:`idx` and transforms it (in case
        a :obj:`self.transform` is given)."""
        data = self.get(idx)
        if self.transform is None:
            return data
        return self.transform(data)

    def __repr__(self):  # pragma: no cover
        return "{}({})".format(self.__class__.__name__, len(self))
|