|
- import re
-
- import torch
- from .data import Data
-
-
class Batch(Data):
    r"""A plain old python object modeling a batch of graphs as one big
    (disconnected) graph. With :class:`cogdl.data.Data` being the
    base class, all its methods can also be used here.
    In addition, single graphs can be reconstructed via the assignment vector
    :obj:`batch`, which maps each node to its respective graph identifier.
    """

    def __init__(self, batch=None, **kwargs):
        r"""Creates a batch object.

        Args:
            batch: Assignment vector mapping each node to its graph
                identifier, or :obj:`None`.
            **kwargs: Forwarded unchanged to the base class constructor.
        """
        super(Batch, self).__init__(**kwargs)
        self.batch = batch
        # Class used by `to_data_list` to instantiate the reconstructed graphs.
        self.__data_class__ = Data
        # Per-key slice boundaries; only filled in by `from_data_list`.
        self.__slices__ = None

    @staticmethod
    def from_data_list(data_list, follow_batch=None):
        r"""Constructs a batch object from a python list holding
        :class:`cogdl.data.Data` objects.
        The assignment vector :obj:`batch` is created on the fly.
        Additionally, creates assignment batch vectors for each key in
        :obj:`follow_batch`.

        Args:
            data_list: Non-empty list of data objects to merge.
            follow_batch: Keys for which an additional ``"<key>_batch"``
                assignment vector is created. :obj:`None` (the default)
                means no additional vectors.

        Returns:
            The result of calling ``contiguous()`` on the assembled batch.

        Raises:
            ValueError: If :obj:`data_list` is empty.
            AssertionError: If any data object already has a ``"batch"`` key.
        """
        if len(data_list) == 0:
            raise ValueError("Cannot create a batch from an empty data list")
        if follow_batch is None:
            follow_batch = ()

        # Union of the keys of all graphs; a graph may lack some of them.
        keys = list(set.union(*(set(data.keys) for data in data_list)))
        # "batch" is reserved for the assignment vector created below.
        assert "batch" not in keys

        batch = Batch()
        batch.__data_class__ = data_list[0].__class__
        batch.__slices__ = {key: [0] for key in keys}

        for key in keys:
            batch[key] = []

        for key in follow_batch:
            batch["{}_batch".format(key)] = []

        # Running per-key offset added to the non-boolean tensors of each graph.
        cumsum = {key: 0 for key in keys}
        batch.batch = []
        for i, data in enumerate(data_list):
            for key in data.keys:
                item = data[key]
                if torch.is_tensor(item) and item.dtype != torch.bool:
                    item = item + cumsum[key]
                if torch.is_tensor(item):
                    size = item.size(data.cat_dim(key, data[key]))
                else:
                    # Non-tensor attributes occupy a single slot per graph.
                    size = 1
                batch.__slices__[key].append(size + batch.__slices__[key][-1])
                # The increment is decided by the data object's own `__inc__`.
                cumsum[key] = cumsum[key] + data.__inc__(key, item)
                batch[key].append(item)

                if key in follow_batch:
                    item = torch.full((size,), i, dtype=torch.long)
                    batch["{}_batch".format(key)].append(item)

            num_nodes = data.num_nodes
            if num_nodes is not None:
                item = torch.full((num_nodes,), i, dtype=torch.long)
                batch.batch.append(item)

        # Only the last graph's `num_nodes` decides whether the assignment
        # vector is dropped.
        if num_nodes is None:
            batch.batch = None

        # Collapse the per-graph lists into single tensors where possible;
        # anything that is neither a tensor nor an int/float stays a list.
        for key in batch.keys:
            item = batch[key][0]
            if torch.is_tensor(item):
                batch[key] = torch.cat(batch[key], dim=data_list[0].cat_dim(key, item))
            elif isinstance(item, (int, float)):
                batch[key] = torch.tensor(batch[key])
        return batch.contiguous()

    def cumsum(self, key, item):
        r"""If :obj:`True`, the attribute :obj:`key` with content :obj:`item`
        should be added up cumulatively before concatenated together.

        Returns :obj:`True` exactly when :obj:`key` contains ``"index"`` or
        ``"face"``; :obj:`item` is not inspected.

        .. note::

            This method is for internal use only, and should only be overridden
            if the batch concatenation process is corrupted for a specific data
            attribute.
        """
        return bool(re.search("(index|face)", key))

    def to_data_list(self):
        r"""Reconstructs the list of :class:`cogdl.data.Data` objects
        from the batch object.
        The batch object must have been created via :meth:`from_data_list` in
        order to be able reconstruct the initial objects.

        Returns:
            A list with one reconstructed data object per graph.

        Raises:
            RuntimeError: If the batch was not created via
                :meth:`from_data_list` (no slice information is available).
        """

        if self.__slices__ is None:
            raise RuntimeError(
                (
                    "Cannot reconstruct data list from batch because the batch "
                    "object was not created using Batch.from_data_list()"
                )
            )

        # Skip the assignment vectors (every key ending in "batch").
        keys = [key for key in self.keys if key[-5:] != "batch"]
        # Running per-key offset that undoes the shift applied when batching.
        cumsum = {key: 0 for key in keys}
        data_list = []
        # The number of graphs is taken from the slices of the first key.
        for i in range(len(self.__slices__[keys[0]]) - 1):
            data = self.__data_class__()
            for key in keys:
                value = self[key]
                start = self.__slices__[key][i]
                end = self.__slices__[key][i + 1]
                if torch.is_tensor(value):
                    data[key] = value.narrow(
                        data.cat_dim(key, value), start, end - start
                    )
                    if value.dtype != torch.bool:
                        data[key] = data[key] - cumsum[key]
                else:
                    data[key] = value[start:end]
                cumsum[key] = cumsum[key] + data.__inc__(key, data[key])
            data_list.append(data)

        return data_list

    @property
    def num_graphs(self):
        """Returns the number of graphs in the batch.

        Computed as the last entry of :obj:`batch` plus one, so :obj:`batch`
        must be a non-empty tensor.
        """
        return self.batch[-1].item() + 1
|