|
- """ Migrated `train_test_split_edges` function from PyTorch-Geometric """
- import math
- import torch
- import typing as _typing
-
-
def to_undirected(
    edge_index: torch.Tensor, edge_attr: _typing.Optional[torch.Tensor] = None
) -> _typing.Union[torch.Tensor, _typing.Tuple[torch.Tensor, torch.Tensor]]:
    r"""Converts the graph given by :attr:`edge_index` to an undirected graph
    such that :math:`(j,i) \in \mathcal{E}` for every edge :math:`(i,j) \in
    \mathcal{E}`.

    Every input edge :math:`(i, j)` is kept and a mirrored copy
    :math:`(j, i)` is appended after all original edges. No deduplication or
    sorting is performed. If :attr:`edge_index` already contains both
    directions of an edge, the result will contain duplicates.

    Args:
        edge_index (LongTensor): The edge indices, unpacked as
            ``row, col = edge_index``.
        edge_attr (Tensor, optional): Edge weights or multi-dimensional
            edge features, one entry per edge along dimension 0.
            (default: :obj:`None`)

    :rtype: :class:`LongTensor` if :attr:`edge_attr` is :obj:`None`, else
        (:class:`LongTensor`, :class:`Tensor`)
    """
    row, col = edge_index
    # Originals first, then the reversed copies: [row | col] -> [col | row].
    undirected_index = torch.stack(
        [torch.cat([row, col], dim=0), torch.cat([col, row], dim=0)], dim=0
    )
    if edge_attr is None:
        return undirected_index
    # Each mirrored edge reuses the features of its source edge, in the same
    # order, so the features line up with the concatenated index.
    return undirected_index, torch.cat([edge_attr, edge_attr], dim=0)
-
-
def train_test_split_edges(data, val_ratio: float = 0.05,
                           test_ratio: float = 0.1):
    r"""Splits the edges of a :class:`torch_geometric.data.Data` object
    into positive and negative train/val/test edges.

    Only edges with ``row < col`` (the strict upper-triangular half) are
    kept. The input :obj:`edge_index` is therefore expected to store each
    undirected edge in both directions. Edges that only appear with
    ``row >= col``, including self-loops, are dropped.

    :obj:`data` is modified in place and also returned:

    * :obj:`edge_index` and :obj:`edge_attr` are set to :obj:`None`.
    * :obj:`train_pos_edge_index` (made undirected), :obj:`train_neg_adj_mask`,
      :obj:`val_pos_edge_index`, :obj:`val_neg_edge_index`,
      :obj:`test_pos_edge_index` and :obj:`test_neg_edge_index` are added.
    * If :obj:`data` has edge features named :obj:`edge_attr`, then
      :obj:`train_pos_edge_attr`, :obj:`val_pos_edge_attr` and
      :obj:`test_pos_edge_attr` are added as well.

    :obj:`train_neg_adj_mask` is a boolean ``(num_nodes, num_nodes)`` matrix.
    It is :obj:`True` for upper-triangular node pairs that are neither
    positive edges nor sampled validation/test negatives.

    Args:
        data (Data): The data object. Must provide ``num_nodes`` and
            ``edge_index``; ``edge_attr`` is optional.
        val_ratio (float, optional): The ratio of positive validation edges.
            (default: :obj:`0.05`)
        test_ratio (float, optional): The ratio of positive test edges.
            (default: :obj:`0.1`)

    :rtype: :class:`torch_geometric.data.Data`
    """
    num_nodes = data.num_nodes
    row, col = data.edge_index
    # getattr() also accepts objects that never defined ``edge_attr``.
    edge_attr = getattr(data, "edge_attr", None)
    data.edge_index = data.edge_attr = None

    # Keep the upper triangular portion: one copy of each undirected edge.
    mask = row < col
    row, col = row[mask], col[mask]
    if edge_attr is not None:
        edge_attr = edge_attr[mask]

    num_edges = row.size(0)
    n_v = int(math.floor(val_ratio * num_edges))
    n_t = int(math.floor(test_ratio * num_edges))

    # Positive edges: shuffle once, then carve out [val | test | train].
    perm = torch.randperm(num_edges)
    row, col = row[perm], col[perm]
    if edge_attr is not None:
        edge_attr = edge_attr[perm]

    data.val_pos_edge_index = torch.stack([row[:n_v], col[:n_v]], dim=0)
    if edge_attr is not None:
        data.val_pos_edge_attr = edge_attr[:n_v]

    data.test_pos_edge_index = torch.stack(
        [row[n_v:n_v + n_t], col[n_v:n_v + n_t]], dim=0)
    if edge_attr is not None:
        data.test_pos_edge_attr = edge_attr[n_v:n_v + n_t]

    train_pos_edge_index = torch.stack(
        [row[n_v + n_t:], col[n_v + n_t:]], dim=0)
    if edge_attr is not None:
        out = to_undirected(train_pos_edge_index, edge_attr[n_v + n_t:])
        data.train_pos_edge_index, data.train_pos_edge_attr = out
    else:
        data.train_pos_edge_index = to_undirected(train_pos_edge_index)

    # Negative edges: candidates are the strict upper triangle minus every
    # positive edge. Allocate the mask on the edges' device. Indexing a CPU
    # tensor with CUDA indices raises, so a CPU-only mask would break CUDA
    # inputs.
    neg_adj_mask = torch.ones(num_nodes, num_nodes, dtype=torch.uint8,
                              device=row.device)
    neg_adj_mask = neg_adj_mask.triu(diagonal=1).to(torch.bool)
    neg_adj_mask[row, col] = 0

    neg_row, neg_col = neg_adj_mask.nonzero(as_tuple=False).t()
    # NOTE: if fewer than n_v + n_t candidates exist, the test (and possibly
    # the validation) negatives below come out shorter than requested.
    perm = torch.randperm(neg_row.size(0))[:n_v + n_t]
    neg_row, neg_col = neg_row[perm], neg_col[perm]

    # Sampled val/test negatives must not be reused as training negatives.
    neg_adj_mask[neg_row, neg_col] = 0
    data.train_neg_adj_mask = neg_adj_mask

    data.val_neg_edge_index = torch.stack(
        [neg_row[:n_v], neg_col[:n_v]], dim=0)
    data.test_neg_edge_index = torch.stack(
        [neg_row[n_v:n_v + n_t], neg_col[n_v:n_v + n_t]], dim=0)

    return data
|