You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

batch.py 4.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. import re
  2. import torch
  3. from .data import Data
  4. class Batch(Data):
  r"""A plain old python object modeling a batch of graphs as one big
  (disconnected) graph. With :class:`cogdl.data.Data` being the
  base class, all its methods can also be used here.
  In addition, single graphs can be reconstructed via the assignment vector
  :obj:`batch`, which maps each node to its respective graph identifier.
  """

  def __init__(self, batch=None, **kwargs):
      """Create a batch container.

      Args:
          batch: stored as ``self.batch`` (the node-to-graph assignment
              vector described in the class docstring); may be ``None``.
          **kwargs: forwarded unchanged to the ``Data`` constructor.
      """
      super().__init__(**kwargs)
      self.batch = batch
      # Defaults for a batch built by hand; from_data_list() overwrites both
      # so that to_data_list() can later rebuild the original objects.
      self.__data_class__ = Data
      self.__slices__ = None
  @staticmethod
  def from_data_list(data_list, follow_batch=None):
      r"""Constructs a batch object from a python list holding
      :class:`cogdl.data.Data` objects.

      The assignment vector :obj:`batch` is created on the fly.
      Additionally, creates assignment batch vectors for each key in
      :obj:`follow_batch`, stored as ``"<key>_batch"``.

      Args:
          data_list: non-empty indexable sequence of data objects; the class
              of ``data_list[0]`` is remembered for :meth:`to_data_list`.
          follow_batch: optional iterable of attribute names that should get
              their own per-element graph-assignment vector. Defaults to none.

      Returns:
          The populated :class:`Batch` (after ``contiguous()``).

      Raises:
          ValueError: if ``data_list`` is empty.
          AssertionError: if any data object already has a ``batch`` key.
      """
      if follow_batch is None:
          follow_batch = []
      if len(data_list) == 0:
          raise ValueError("Cannot create a Batch from an empty data_list")

      keys = list(set.union(*[set(data.keys) for data in data_list]))
      assert "batch" not in keys

      batch = Batch()
      batch.__data_class__ = data_list[0].__class__
      # __slices__[key] holds running boundaries along the concatenation
      # dimension: entries i and i + 1 delimit the i-th contribution to key.
      batch.__slices__ = {key: [0] for key in keys}
      for key in keys:
          batch[key] = []
      for key in follow_batch:
          batch["{}_batch".format(key)] = []

      # Running offset per key, accumulated from data.__inc__ (non-zero for
      # attributes that must be shifted, e.g. node indices).
      cumsum = {key: 0 for key in keys}
      batch.batch = []
      # The assignment vector is only meaningful if EVERY graph reports its
      # node count; checking just the last graph could leave it misaligned.
      all_have_num_nodes = True
      for i, data in enumerate(data_list):
          for key in data.keys:
              item = data[key]
              if torch.is_tensor(item) and item.dtype != torch.bool:
                  item = item + cumsum[key]
              if torch.is_tensor(item):
                  size = item.size(data.cat_dim(key, data[key]))
              else:
                  size = 1
              batch.__slices__[key].append(size + batch.__slices__[key][-1])
              cumsum[key] = cumsum[key] + data.__inc__(key, item)
              batch[key].append(item)

              if key in follow_batch:
                  owner = torch.full((size,), i, dtype=torch.long)
                  batch["{}_batch".format(key)].append(owner)

          num_nodes = data.num_nodes
          if num_nodes is None:
              all_have_num_nodes = False
          else:
              batch.batch.append(torch.full((num_nodes,), i, dtype=torch.long))

      if not all_have_num_nodes:
          batch.batch = None

      # Collapse each per-graph list: tensors are concatenated along their
      # cat dimension, python scalars become a 1-D tensor; other values are
      # left as lists.
      for key in batch.keys:
          item = batch[key][0]
          if torch.is_tensor(item):
              batch[key] = torch.cat(batch[key], dim=data_list[0].cat_dim(key, item))
          elif isinstance(item, (int, float)):
              batch[key] = torch.tensor(batch[key])
      return batch.contiguous()
  def cumsum(self, key, item):
      r"""Tell whether attribute ``key`` should be summed up cumulatively
      before being concatenated together.

      The decision is purely name-based: ``True`` when ``key`` contains
      ``"index"`` or ``"face"`` anywhere (e.g. ``edge_index``). ``item`` is
      accepted for overriding implementations but is not inspected here.

      .. note::
          This method is for internal use only, and should only be overridden
          if the batch concatenation process is corrupted for a specific data
          attribute. The batching code in this class computes offsets via
          ``data.__inc__`` rather than by calling this method.
      """
      found = re.search("(index|face)", key)
      return found is not None
  def to_data_list(self):
      r"""Reconstructs the list of :class:`cogdl.data.Data` objects
      from the batch object.

      The batch object must have been created via :meth:`from_data_list` in
      order to be able to reconstruct the initial objects.

      Returns:
          A list of instances of the class recorded in ``__data_class__``,
          one per graph. Index-like offsets added during batching are
          subtracted again (bool tensors are left untouched).

      Raises:
          RuntimeError: if the batch was not created by ``from_data_list``.
      """
      if self.__slices__ is None:
          raise RuntimeError(
              "Cannot reconstruct data list from batch because the batch "
              "object was not created using Batch.from_data_list()"
          )

      # __slices__ records exactly the attributes of the original data
      # objects. Filtering on it skips the `batch` vector and any
      # `<key>_batch` follow vectors, as well as attributes attached later,
      # without dropping real attributes whose names happen to end in "batch".
      keys = [key for key in self.keys if key in self.__slices__]
      cumsum = {key: 0 for key in keys}
      data_list = []
      num_graphs = len(self.__slices__[keys[0]]) - 1
      for i in range(num_graphs):
          data = self.__data_class__()
          for key in keys:
              value = self[key]
              start = self.__slices__[key][i]
              end = self.__slices__[key][i + 1]
              if torch.is_tensor(value):
                  part = value.narrow(data.cat_dim(key, value), start, end - start)
                  if value.dtype != torch.bool:
                      # Undo the per-graph offset applied in from_data_list.
                      part = part - cumsum[key]
              else:
                  part = value[start:end]
              data[key] = part
              cumsum[key] = cumsum[key] + data.__inc__(key, data[key])
          data_list.append(data)
      return data_list
  @property
  def num_graphs(self):
      """Returns the number of graphs in the batch.

      For a batch built by :meth:`from_data_list`, the count comes from the
      ``__slices__`` boundaries, so graphs that contributed zero nodes (and
      therefore nothing to the assignment vector) are still counted.
      Otherwise it falls back to ``batch[-1] + 1``, which assumes the
      assignment vector is sorted and its last graph is non-empty.
      """
      slices = self.__slices__
      if slices:
          # Each boundary list starts at 0 and grows by one entry per graph
          # that has the attribute; take the longest in case some graphs
          # lacked an attribute.
          return max(len(bounds) for bounds in slices.values()) - 1
      return self.batch[-1].item() + 1