|
- import torch.utils.data
- from torch.utils.data.dataloader import default_collate
-
- from .batch import Batch
-
-
class DataLoader(torch.utils.data.DataLoader):
    r"""Data loader which merges data objects from a
    :class:`cogdl.data.dataset` to a mini-batch.

    Args:
        dataset (Dataset): The dataset from which to load the data.
        batch_size (int, optional): How many samples per batch to load.
            (default: :obj:`1`)
        shuffle (bool, optional): If set to :obj:`True`, the data will be
            reshuffled at every epoch (default: :obj:`True`)
        **kwargs: Any other keyword arguments are forwarded unchanged to
            :class:`torch.utils.data.DataLoader` (e.g. ``num_workers``).
    """

    def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs):
        super().__init__(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            # A staticmethod (unlike a lambda) can be pickled by reference,
            # so the loader also works with ``num_workers > 0`` when worker
            # processes are started with the "spawn" method.
            collate_fn=self._collate,
            **kwargs
        )

    @staticmethod
    def _collate(data_list):
        """Merge a list of data objects into one batch via ``Batch.from_data_list``."""
        return Batch.from_data_list(data_list)
-
-
class DataListLoader(torch.utils.data.DataLoader):
    r"""Data loader which merges data objects from a
    :class:`cogdl.data.dataset` to a python list.

    .. note::

        This data loader should be used for multi-gpu support via
        :class:`cogdl.nn.DataParallel`.

    Args:
        dataset (Dataset): The dataset from which to load the data.
        batch_size (int, optional): How many samples per batch to load.
            (default: :obj:`1`)
        shuffle (bool, optional): If set to :obj:`True`, the data will be
            reshuffled at every epoch (default: :obj:`True`)
        **kwargs: Any other keyword arguments are forwarded unchanged to
            :class:`torch.utils.data.DataLoader` (e.g. ``num_workers``).
    """

    def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs):
        super().__init__(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            # A staticmethod (unlike a lambda) can be pickled by reference,
            # so the loader also works with ``num_workers > 0`` when worker
            # processes are started with the "spawn" method.
            collate_fn=self._collate,
            **kwargs
        )

    @staticmethod
    def _collate(data_list):
        """Return the sampled data objects unchanged, as the list the loader built."""
        return data_list
-
-
class DenseDataLoader(torch.utils.data.DataLoader):
    r"""Data loader which merges data objects from a
    :class:`cogdl.data.dataset` to a mini-batch.

    .. note::

        To make use of this data loader, all graphs in the dataset need to
        have the same shape for each of their attributes.
        Therefore, this data loader should only be used when working with
        *dense* adjacency matrices.

    Args:
        dataset (Dataset): The dataset from which to load the data.
        batch_size (int, optional): How many samples per batch to load.
            (default: :obj:`1`)
        shuffle (bool, optional): If set to :obj:`True`, the data will be
            reshuffled at every epoch (default: :obj:`True`)
        **kwargs: Any other keyword arguments are forwarded unchanged to
            :class:`torch.utils.data.DataLoader` (e.g. ``num_workers``).
    """

    def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs):
        super().__init__(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            # A staticmethod (unlike a function local to __init__) can be
            # pickled by reference, so the loader also works with
            # ``num_workers > 0`` when worker processes use "spawn".
            collate_fn=self._dense_collate,
            **kwargs
        )

    @staticmethod
    def _dense_collate(data_list):
        """Collate each attribute across ``data_list`` into a ``Batch``.

        The attribute names come from the first data object's ``keys``.
        Each attribute is combined with ``default_collate``, which is why the
        class note requires every graph to have same-shaped attributes.
        """
        batch = Batch()
        for key in data_list[0].keys:
            batch[key] = default_collate([d[key] for d in data_list])
        return batch
|