You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351
  1. import re
  2. import torch
  3. def index_to_mask(index, size):
  4. mask = torch.zeros(size, dtype=torch.bool, device=index.device)
  5. mask[index] = 1
  6. return mask
  7. class Data(object):
  8. r"""A plain python object modeling a single graph with various
  9. (optional) attributes:
  10. Args:
  11. x (Tensor, optional): Node feature matrix with shape :obj:`[num_nodes,
  12. num_node_features]`. (default: :obj:`None`)
  13. edge_index (LongTensor, optional): Graph connectivity in COO format
  14. with shape :obj:`[2, num_edges]`. (default: :obj:`None`)
  15. edge_attr (Tensor, optional): Edge feature matrix with shape
  16. :obj:`[num_edges, num_edge_features]`. (default: :obj:`None`)
  17. y (Tensor, optional): Graph or node targets with arbitrary shape.
  18. (default: :obj:`None`)
  19. pos (Tensor, optional): Node position matrix with shape
  20. :obj:`[num_nodes, num_dimensions]`. (default: :obj:`None`)
  21. The data object is not restricted to these attributes and can be extented
  22. by any other additional data.
  23. """
  24. def __init__(self, x=None, edge_index=None, edge_attr=None, y=None, pos=None):
  25. self.x = x
  26. self.edge_index = edge_index
  27. self.edge_attr = edge_attr
  28. self.y = y
  29. self.pos = pos
  30. @staticmethod
  31. def from_dict(dictionary):
  32. r"""Creates a data object from a python dictionary."""
  33. data = Data()
  34. for key, item in dictionary.items():
  35. data[key] = item
  36. return data
  37. def __getitem__(self, key):
  38. r"""Gets the data of the attribute :obj:`key`."""
  39. return getattr(self, key)
  40. def __setitem__(self, key, value):
  41. """Sets the attribute :obj:`key` to :obj:`value`."""
  42. setattr(self, key, value)
  43. @property
  44. def keys(self):
  45. r"""Returns all names of graph attributes."""
  46. keys = [key for key in self.__dict__.keys() if self[key] is not None]
  47. keys = [key for key in keys if key[:2] != "__" and key[-2:] != "__"]
  48. return keys
  49. def __len__(self):
  50. r"""Returns the number of all present attributes."""
  51. return len(self.keys)
  52. def __contains__(self, key):
  53. r"""Returns :obj:`True`, if the attribute :obj:`key` is present in the
  54. data."""
  55. return key in self.keys
  56. def __iter__(self):
  57. r"""Iterates over all present attributes in the data, yielding their
  58. attribute names and content."""
  59. for key in sorted(self.keys):
  60. yield key, self[key]
  61. def __call__(self, *keys):
  62. r"""Iterates over all attributes :obj:`*keys` in the data, yielding
  63. their attribute names and content.
  64. If :obj:`*keys` is not given this method will iterative over all
  65. present attributes."""
  66. for key in sorted(self.keys) if not keys else keys:
  67. if self[key] is not None:
  68. yield key, self[key]
  69. def cat_dim(self, key, value):
  70. r"""Returns the dimension in which the attribute :obj:`key` with
  71. content :obj:`value` gets concatenated when creating batches.
  72. .. note::
  73. This method is for internal use only, and should only be overridden
  74. if the batch concatenation process is corrupted for a specific data
  75. attribute.
  76. """
  77. # `*index*` and `*face*` should be concatenated in the last dimension,
  78. # everything else in the first dimension.
  79. return -1 if bool(re.search("(index|face)", key)) else 0
  80. # own methods for processing
  81. def get_label_number(self):
  82. r"""Get the number of labels in this dataset as dict."""
  83. label_num = {}
  84. labels = self.y.unique().cpu().detach().numpy().tolist()
  85. for label in labels:
  86. label_num[label] = (self.y == label).sum().item()
  87. return label_num
  88. def random_splits_mask(self, train_ratio, val_ratio, seed=None):
  89. r"""If the data has masks for train/val/test, return the splits with specific ratio.
  90. Parameters
  91. ----------
  92. train_ratio : float
  93. the portion of data that used for training.
  94. val_ratio : float
  95. the portion of data that used for validation.
  96. seed : int
  97. random seed for splitting dataset.
  98. """
  99. rs = torch.get_rng_state()
  100. rs_cuda = torch.cuda.get_rng_state()
  101. if seed is not None:
  102. torch.manual_seed(seed)
  103. torch.cuda.manual_seed(seed)
  104. perm = torch.randperm(self.num_nodes)
  105. train_index = perm[: int(self.num_nodes * train_ratio)]
  106. val_index = perm[
  107. int(self.num_nodes * train_ratio) : int(
  108. self.num_nodes * (train_ratio + val_ratio)
  109. )
  110. ]
  111. test_index = perm[int(self.num_nodes * (train_ratio + val_ratio)) :]
  112. self.train_mask = index_to_mask(train_index, size=self.num_nodes)
  113. self.val_mask = index_to_mask(val_index, size=self.num_nodes)
  114. self.test_mask = index_to_mask(test_index, size=self.num_nodes)
  115. torch.set_rng_state(rs)
  116. torch.cuda.set_rng_state(rs_cuda)
  117. return self
  118. def random_splits_nodes(self, train_ratio, val_ratio, seed=None):
  119. r"""If the data uses id of nodes for train/val/test, return the splits with specific ratio.
  120. Parameters
  121. ----------
  122. train_ratio : float
  123. the portion of data that used for training.
  124. val_ratio : float
  125. the portion of data that used for validation.
  126. seed : int
  127. random seed for splitting dataset.
  128. """
  129. rs = torch.get_rng_state()
  130. rs_cuda = torch.cuda.get_rng_state()
  131. if seed is not None:
  132. torch.manual_seed(seed)
  133. torch.cuda.manual_seed(seed)
  134. perm = torch.randperm(self.num_nodes)
  135. self.train_node = perm[: int(self.num_nodes * train_ratio)]
  136. self.val_node = perm[
  137. int(self.num_nodes * train_ratio) : int(
  138. self.num_nodes * (train_ratio + val_ratio)
  139. )
  140. ]
  141. self.test_node = perm[int(self.num_nodes * (train_ratio + val_ratio)) :]
  142. self.train_target = self.y[self.train_node]
  143. self.valid_target = self.y[self.valid_node]
  144. self.test_target = self.y[self.test_node]
  145. torch.set_rng_state(rs)
  146. torch.cuda.set_rng_state(rs_cuda)
  147. return self
  148. def random_splits_mask_class(
  149. self, num_train_per_class, num_val, num_test, seed=None
  150. ):
  151. r"""If the data has masks for train/val/test, return the splits with specific number of samples from every class for training.
  152. Parameters
  153. ----------
  154. num_train_per_class : int
  155. the number of samples from every class used for training.
  156. num_val : int
  157. the total number of nodes that used for validation.
  158. num_test : int
  159. the total number of nodes that used for testing.
  160. seed : int
  161. random seed for splitting dataset.
  162. """
  163. rs = torch.get_rng_state()
  164. rs_cuda = torch.cuda.get_rng_state()
  165. if seed is not None:
  166. torch.manual_seed(seed)
  167. torch.cuda.manual_seed(seed)
  168. num_classes = self.y.max().cpu().item() + 1
  169. self.train_mask.fill_(False)
  170. for c in range(num_classes):
  171. idx = (self.y == c).nonzero().view(-1)
  172. idx = idx[torch.randperm(idx.size(0))[:num_train_per_class]]
  173. self.train_mask[idx] = True
  174. remaining = (~self.train_mask).nonzero().view(-1)
  175. remaining = remaining[torch.randperm(remaining.size(0))]
  176. self.val_mask.fill_(False)
  177. self.val_mask[remaining[:num_val]] = True
  178. self.test_mask.fill_(False)
  179. self.test_mask[remaining[num_val : num_val + num_test]] = True
  180. torch.set_rng_state(rs)
  181. torch.cuda.set_rng_state(rs_cuda)
  182. return self
  183. def random_splits_nodes_class(
  184. self, num_train_per_class, num_val, num_test, seed=None
  185. ):
  186. r"""If the data uses id of nodes for train/val/test, return the splits with specific number of samples from every class for training.
  187. Parameters
  188. ----------
  189. num_train_per_class : int
  190. the number of samples from every class used for training.
  191. num_val : int
  192. the total number of nodes that used for validation.
  193. num_test : int
  194. the total number of nodes that used for testing.
  195. seed : int
  196. random seed for splitting dataset.
  197. """
  198. rs = torch.get_rng_state()
  199. rs_cuda = torch.cuda.get_rng_state()
  200. if seed is not None:
  201. torch.manual_seed(seed)
  202. torch.cuda.manual_seed(seed)
  203. num_classes = self.y.max().cpu().item() + 1
  204. train_mask = torch.zeros(
  205. self.num_nodes, dtype=torch.bool, device=self.train_node.device
  206. )
  207. sup = []
  208. for c in range(num_classes):
  209. idx = (self.y == c).nonzero().view(-1)
  210. idx = idx[torch.randperm(idx.size(0))[:num_train_per_class]]
  211. sup.append(idx)
  212. train_mask[idx] = True
  213. self.train_node = torch.cat(sup)
  214. remaining = (~train_mask).nonzero().view(-1)
  215. remaining = remaining[torch.randperm(remaining.size(0))]
  216. self.val_node = remaining[:num_val]
  217. self.test_node = remaining[num_val : num_val + num_test]
  218. self.train_target = self.y[self.train_node]
  219. self.valid_target = self.y[self.valid_node]
  220. self.test_target = self.y[self.test_node]
  221. torch.set_rng_state(rs)
  222. torch.cuda.set_rng_state(rs_cuda)
  223. return self
  224. def __inc__(self, key, value):
  225. r""" "Returns the incremental count to cumulatively increase the value
  226. of the next attribute of :obj:`key` when creating batches.
  227. .. note::
  228. This method is for internal use only, and should only be overridden
  229. if the batch concatenation process is corrupted for a specific data
  230. attribute.
  231. """
  232. # Only `*index*` and `*face*` should be cumulatively summed up when
  233. # creating batches.
  234. return self.num_nodes if bool(re.search("(index|face)", key)) else 0
  235. @property
  236. def num_edges(self):
  237. r"""Returns the number of edges in the graph."""
  238. for key, item in self("edge_index", "edge_attr"):
  239. return item.size(self.cat_dim(key, item))
  240. return None
  241. @property
  242. def num_features(self):
  243. r"""Returns the number of features per node in the graph."""
  244. return 1 if self.x.dim() == 1 else self.x.size(1)
  245. @property
  246. def num_nodes(self):
  247. if self.x is not None:
  248. return self.x.shape[0]
  249. return torch.max(self.edge_index) + 1
  250. def is_coalesced(self):
  251. r"""Returns :obj:`True`, if edge indices are ordered and do not contain
  252. duplicate entries."""
  253. row, col = self.edge_index
  254. index = self.num_nodes * row + col
  255. return row.size(0) == torch.unique(index).size(0)
  256. def apply(self, func, *keys):
  257. r"""Applies the function :obj:`func` to all attributes :obj:`*keys`.
  258. If :obj:`*keys` is not given, :obj:`func` is applied to all present
  259. attributes.
  260. """
  261. for key, item in self(*keys):
  262. self[key] = func(item)
  263. return self
  264. def contiguous(self, *keys):
  265. r"""Ensures a contiguous memory layout for all attributes :obj:`*keys`.
  266. If :obj:`*keys` is not given, all present attributes are ensured to
  267. have a contiguous memory layout."""
  268. return self.apply(lambda x: x.contiguous(), *keys)
  269. def to(self, device, *keys):
  270. r"""Performs tensor dtype and/or device conversion to all attributes
  271. :obj:`*keys`.
  272. If :obj:`*keys` is not given, the conversion is applied to all present
  273. attributes."""
  274. return self.apply(lambda x: x.to(device), *keys)
  275. def cuda(self, *keys):
  276. return self.apply(lambda x: x.cuda(), *keys)
  277. def clone(self):
  278. return Data.from_dict({k: v.clone() for k, v in self})
  279. def __repr__(self):
  280. info = [
  281. "{}={}".format(key, list(item.size()))
  282. for key, item in self
  283. if type(item) != list and type(item) != dict
  284. ]
  285. return "{}({})".format(self.__class__.__name__, ", ".join(info))