diff --git a/autogl/data/graph/_general_static_graph/utils/conversion/__init__.py b/autogl/data/graph/_general_static_graph/utils/conversion/__init__.py
index ec0f06e..2ed61c2 100644
--- a/autogl/data/graph/_general_static_graph/utils/conversion/__init__.py
+++ b/autogl/data/graph/_general_static_graph/utils/conversion/__init__.py
@@ -1,3 +1,4 @@
+from ._general import StaticGraphToGeneralData, static_graph_to_general_data
 from ._nx import (
     HomogeneousStaticGraphToNetworkX
 )
diff --git a/autogl/data/graph/_general_static_graph/utils/conversion/_general.py b/autogl/data/graph/_general_static_graph/utils/conversion/_general.py
new file mode 100644
index 0000000..647085b
--- /dev/null
+++ b/autogl/data/graph/_general_static_graph/utils/conversion/_general.py
@@ -0,0 +1,79 @@
+import torch
+import typing as _typing
+import autogl
+from ... import GeneralStaticGraph
+
+
+class StaticGraphToGeneralData:
+    """Convert a homogeneous ``GeneralStaticGraph`` to ``autogl.data.Data``."""
+
+    def __init__(self, *__args, **__kwargs):
+        pass
+
+    def __call__(
+            self, static_graph: GeneralStaticGraph,
+            *__args, **__kwargs
+    ):
+        if not isinstance(static_graph, GeneralStaticGraph):
+            raise TypeError("static_graph must be an instance of GeneralStaticGraph")
+        elif not static_graph.nodes.is_homogeneous:
+            raise ValueError("Provided static graph MUST consist of homogeneous nodes")
+        homogeneous_node_type: _typing.Optional[str] = (
+            list(static_graph.nodes)[0]
+            if len(list(static_graph.nodes)) > 0 else None
+        )
+        data: _typing.Dict[str, torch.Tensor] = dict()
+        if isinstance(homogeneous_node_type, str):
+            # Node-level and graph-level data are merged into one namespace,
+            # so their keys must not collide.
+            node_and_edge_data_keys_intersection: _typing.Set[str] = (
+                set(static_graph.nodes.data) & set(static_graph.data)
+            )
+            if len(node_and_edge_data_keys_intersection) > 0:
+                raise ValueError(
+                    f"Provided static graph contains duplicate data "
+                    f"with the same keys {node_and_edge_data_keys_intersection} "
+                    f"for homogeneous nodes data and graph-level data; "
+                    f"please refer to the documentation for more details."
+                )
+            data.update(static_graph.nodes.data)
+            data.update(static_graph.data)
+        else:
+            data.update(static_graph.data)
+
+        if len(list(static_graph.edges)) == 1:
+            # Homogeneous edges: expose the connections as `edge_index`.
+            data["edge_index"] = static_graph.edges.connections
+            if len(set(data.keys()) & set(static_graph.edges.data.keys())) > 0:
+                raise ValueError(
+                    "Provided static graph contains duplicate data with the same key; "
+                    "please refer to the documentation for more details."
+                )
+            data.update(static_graph.edges.data)
+        elif len(list(static_graph.edges)) > 1:
+            for canonical_edge_type in static_graph.edges:
+                # Only consider edge types that connect the homogeneous node
+                # type and whose relation name ends with 'edge',
+                # e.g. "train_pos_edge".
+                if isinstance(homogeneous_node_type, str) and (
+                        canonical_edge_type.source_node_type != homogeneous_node_type or
+                        canonical_edge_type.target_node_type != homogeneous_node_type
+                ):
+                    continue
+                if (
+                        len(canonical_edge_type.relation_type) < 4 or
+                        canonical_edge_type.relation_type[-4:] != 'edge'
+                ):
+                    continue
+                data[f"{canonical_edge_type.relation_type}_index"] = (
+                    static_graph.edges[canonical_edge_type].connections
+                )
+
+                edge_type_prefix: str = canonical_edge_type.relation_type[:-4]
+                for data_key in static_graph.edges[canonical_edge_type].data:
+                    if len(data_key) >= 4 and data_key[:4] == 'edge':
+                        data[f"{edge_type_prefix}{data_key}"] = (
+                            static_graph.edges[canonical_edge_type].data[data_key].detach()
+                        )
+                    else:
+                        data[f"{canonical_edge_type.relation_type}_{data_key}"] = (
+                            static_graph.edges[canonical_edge_type].data[data_key].detach()
+                        )
+
+        general_data = autogl.data.Data()
+        for key, value in data.items():
+            setattr(general_data, key, value)
+        return general_data
+
+
+def static_graph_to_general_data(static_graph: GeneralStaticGraph) -> autogl.data.Data:
+    return StaticGraphToGeneralData()(static_graph)
diff --git a/autogl/datasets/utils/__init__.py b/autogl/datasets/utils/__init__.py
index 85b58ff..a7e80d0 100644
--- a/autogl/datasets/utils/__init__.py
+++ b/autogl/datasets/utils/__init__.py
@@ -1,6 +1,6 @@
+from ._split_edges import split_edges
 from ._general import (
     index_to_mask,
-    split_edges,
     random_splits_mask,
     random_splits_mask_class,
     graph_cross_validation,
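
A minimal usage sketch of the conversion helper introduced above. The constructor call mirrors the GeneralStaticGraphGenerator.create_heterogeneous_static_graph usage removed from _general.py below; the toy tensors are illustrative only:

    import torch
    from autogl.data.graph import GeneralStaticGraphGenerator
    from autogl.data.graph.utils.conversion import static_graph_to_general_data

    # Toy homogeneous graph: 3 nodes with 8-dim features and one edge type.
    graph = GeneralStaticGraphGenerator.create_heterogeneous_static_graph(
        {'node': {'x': torch.randn(3, 8)}},
        {('node', 'edge', 'node'): torch.tensor([[0, 1, 2], [1, 2, 0]])}
    )
    data = static_graph_to_general_data(graph)
    # data is an autogl.data.Data instance carrying `x` and `edge_index`.
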
diff --git a/autogl/datasets/utils/_general.py b/autogl/datasets/utils/_general.py
index a3e5b30..77d878f 100644
--- a/autogl/datasets/utils/_general.py
+++ b/autogl/datasets/utils/_general.py
@@ -5,9 +5,7 @@ import torch.utils.data
 import typing as _typing
 from sklearn.model_selection import StratifiedKFold, KFold
 from autogl import backend as _backend
-from autogl.data import Data, Dataset, InMemoryStaticGraphSet
-from ...data.graph import GeneralStaticGraph, GeneralStaticGraphGenerator
-from . import _pyg
+from autogl.data import InMemoryStaticGraphSet
 
 
 def index_to_mask(index: torch.Tensor, size):
@@ -16,70 +14,6 @@ def index_to_mask(index: torch.Tensor, size):
     return mask
 
 
-def split_edges(
-    dataset: InMemoryStaticGraphSet,
-    train_ratio: float, val_ratio: float
-) -> InMemoryStaticGraphSet:
-    test_ratio: float = 1 - train_ratio - val_ratio
-
-    def _split_edges_for_graph(homogeneous_static_graph: GeneralStaticGraph) -> GeneralStaticGraph:
-        if not isinstance(homogeneous_static_graph, GeneralStaticGraph):
-            raise TypeError
-        elif not homogeneous_static_graph.edges.is_homogeneous:
-            raise ValueError("The provided graph MUST consist of homogeneous edges.")
-        else:
-            split_data = _pyg.train_test_split_edges(
-                Data(
-                    edge_index=homogeneous_static_graph.edges.connections.detach().clone(),
-                    edge_attr=(
-                        homogeneous_static_graph.edges.data['edge_attr'].detach().clone()
-                        if 'edge_attr' in homogeneous_static_graph.edges.data else None
-                    )
-                ),
-                val_ratio, test_ratio
-            )
-            original_edge_type = [et for et in homogeneous_static_graph.edges][0]
-
-            split_static_graph = GeneralStaticGraphGenerator.create_heterogeneous_static_graph(
-                dict([
-                    (node_type, homogeneous_static_graph.nodes[node_type].data)
-                    for node_type in homogeneous_static_graph.nodes
-                ]),
-                {
-                    (original_edge_type.source_node_type, "train_pos_edge", original_edge_type.target_node_type): (
-                        getattr(split_data, "train_pos_edge_index"),
-                        {"edge_attr": getattr(split_data, "train_pos_edge_attr")}
-                        if isinstance(getattr(split_data, "train_pos_edge_attr"), torch.Tensor)
-                        else None
-                    ),
-                    (original_edge_type.source_node_type, "val_pos_edge", original_edge_type.target_node_type): (
-                        getattr(split_data, "val_pos_edge_index"),
-                        {"edge_attr": getattr(split_data, "val_pos_edge_attr")}
-                        if isinstance(getattr(split_data, "val_pos_edge_attr"), torch.Tensor)
-                        else None
-                    ),
-                    (original_edge_type.source_node_type, "val_neg_edge", original_edge_type.target_node_type):
-                        getattr(split_data, "val_neg_edge_index"),
-                    (original_edge_type.source_node_type, "test_pos_edge", original_edge_type.target_node_type): (
-                        getattr(split_data, "test_pos_edge_index"),
-                        {"edge_attr": getattr(split_data, "test_pos_edge_attr")}
-                        if isinstance(getattr(split_data, "test_pos_edge_attr"), torch.Tensor)
-                        else None
-                    ),
-                    (original_edge_type.source_node_type, "test_neg_edge", original_edge_type.target_node_type):
-                        getattr(split_data, "test_neg_edge_index")
-                },
-                homogeneous_static_graph.data
-            )
-            return split_static_graph
-
-    if not isinstance(dataset, InMemoryStaticGraphSet):
-        raise TypeError
-    for index in range(len(dataset)):
-        dataset[index] = _split_edges_for_graph(dataset[index])
-    return dataset
-
-
 def random_splits_mask(
     dataset: InMemoryStaticGraphSet,
     train_ratio: float = 0.2, val_ratio: float = 0.4,
diff --git a/autogl/datasets/utils/_split_edges/__init__.py b/autogl/datasets/utils/_split_edges/__init__.py
new file mode 100644
index 0000000..d29abab
--- /dev/null
+++ b/autogl/datasets/utils/_split_edges/__init__.py
@@ -0,0 +1 @@
+from .split_edges import split_edges
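
The re-export keeps the public API location unchanged. A minimal usage sketch, assuming an InMemoryDataset of homogeneous graphs named dataset (ratios illustrative; see split_edges.py below for the validation rules):

    from autogl.datasets.utils import split_edges

    # Under the DGL backend, val_ratio may be None to request a train/test
    # split only; under PyG it must satisfy 0 < val_ratio < 1.
    dataset = split_edges(dataset, train_ratio=0.85, val_ratio=0.05)
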
diff --git a/autogl/datasets/utils/_split_edges/_dgl_compatible.py b/autogl/datasets/utils/_split_edges/_dgl_compatible.py
new file mode 100644
index 0000000..40b2bb7
--- /dev/null
+++ b/autogl/datasets/utils/_split_edges/_dgl_compatible.py
@@ -0,0 +1,207 @@
+import numpy as np
+import scipy.sparse as sp
+import torch
+import typing as _typing
+import dgl
+import autogl.data.graph
+from .train_test_split_edges import train_test_split_edges
+from autogl.data.graph.utils.conversion import static_graph_to_general_data
+
+
+class _SplitEdgesDGLImpl:
+    @classmethod
+    def __split_edges_train_val_test(
+            cls, g: dgl.DGLGraph,
+            train_ratio: float, val_ratio: float
+    ) -> _typing.Tuple[
+        dgl.DGLGraph, dgl.DGLGraph, dgl.DGLGraph,
+        dgl.DGLGraph, dgl.DGLGraph, dgl.DGLGraph, dgl.DGLGraph
+    ]:
+        u, v = g.edges()
+
+        # Shuffle the edge IDs so the split is random.
+        eids = np.arange(g.number_of_edges())
+        eids = np.random.permutation(eids)
+
+        valid_size = int(len(eids) * val_ratio)
+        test_size = int(len(eids) * (1 - train_ratio - val_ratio))
+
+        test_pos_u, test_pos_v = u[eids[:test_size]], v[eids[:test_size]]
+        valid_pos_u, valid_pos_v = u[eids[test_size:test_size + valid_size]], v[eids[test_size:test_size + valid_size]]
+        train_pos_u, train_pos_v = u[eids[test_size + valid_size:]], v[eids[test_size + valid_size:]]
+
+        # Find all negative edges and split them for training, validation and testing.
+        adj = sp.coo_matrix(
+            (np.ones(len(u)), (u.numpy(), v.numpy())),
+            shape=(g.number_of_nodes(), g.number_of_nodes())
+        )
+        adj_neg = 1 - adj.todense() - np.eye(g.number_of_nodes())
+        neg_u, neg_v = np.where(adj_neg != 0)
+
+        neg_eids = np.random.choice(len(neg_u), g.number_of_edges())
+        test_neg_u, test_neg_v = neg_u[neg_eids[:test_size]], neg_v[neg_eids[:test_size]]
+        valid_neg_u, valid_neg_v = neg_u[neg_eids[test_size:test_size + valid_size]], neg_v[neg_eids[test_size:test_size + valid_size]]
+        train_neg_u, train_neg_v = neg_u[neg_eids[test_size + valid_size:]], neg_v[neg_eids[test_size + valid_size:]]
+
+        # The training graph keeps every node but drops the held-out edges.
+        train_g = dgl.remove_edges(g, eids[:test_size + valid_size])
+
+        train_pos_g = dgl.graph((train_pos_u, train_pos_v), num_nodes=g.number_of_nodes())
+        train_neg_g = dgl.graph((train_neg_u, train_neg_v), num_nodes=g.number_of_nodes())
+
+        valid_pos_g = dgl.graph((valid_pos_u, valid_pos_v), num_nodes=g.number_of_nodes())
+        valid_neg_g = dgl.graph((valid_neg_u, valid_neg_v), num_nodes=g.number_of_nodes())
+
+        test_pos_g = dgl.graph((test_pos_u, test_pos_v), num_nodes=g.number_of_nodes())
+        test_neg_g = dgl.graph((test_neg_u, test_neg_v), num_nodes=g.number_of_nodes())
+
+        return (
+            train_g, train_pos_g, train_neg_g,
+            valid_pos_g, valid_neg_g, test_pos_g, test_neg_g
+        )
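+
+    # NOTE: the negative sampling above materialises a dense N x N complement
+    # of the adjacency matrix (adj_neg), so this helper is only practical for
+    # graphs with a modest number of nodes.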
+
+    @classmethod
+    def __split_edges_train_test(
+            cls, g: dgl.DGLGraph, train_ratio: float
+    ) -> _typing.Tuple[
+        dgl.DGLGraph, dgl.DGLGraph, dgl.DGLGraph,
+        dgl.DGLGraph, dgl.DGLGraph,
+    ]:
+        u, v = g.edges()
+
+        eids = np.arange(g.number_of_edges())
+        eids = np.random.permutation(eids)
+        test_size = int(len(eids) * (1 - train_ratio))
+        test_pos_u, test_pos_v = u[eids[:test_size]], v[eids[:test_size]]
+        train_pos_u, train_pos_v = u[eids[test_size:]], v[eids[test_size:]]
+
+        # Find all negative edges and split them for training and testing.
+        adj = sp.coo_matrix(
+            (np.ones(len(u)), (u.numpy(), v.numpy())),
+            shape=(g.number_of_nodes(), g.number_of_nodes())
+        )
+        adj_neg = 1 - adj.todense() - np.eye(g.number_of_nodes())
+        neg_u, neg_v = np.where(adj_neg != 0)
+
+        neg_eids = np.random.choice(len(neg_u), g.number_of_edges())
+        test_neg_u, test_neg_v = neg_u[neg_eids[:test_size]], neg_v[neg_eids[:test_size]]
+        train_neg_u, train_neg_v = neg_u[neg_eids[test_size:]], neg_v[neg_eids[test_size:]]
+
+        train_g = dgl.remove_edges(g, eids[:test_size])
+
+        train_pos_g = dgl.graph((train_pos_u, train_pos_v), num_nodes=g.number_of_nodes())
+        train_neg_g = dgl.graph((train_neg_u, train_neg_v), num_nodes=g.number_of_nodes())
+
+        test_pos_g = dgl.graph((test_pos_u, test_pos_v), num_nodes=g.number_of_nodes())
+        test_neg_g = dgl.graph((test_neg_u, test_neg_v), num_nodes=g.number_of_nodes())
+
+        return train_g, train_pos_g, train_neg_g, test_pos_g, test_neg_g
+
+    @classmethod
+    def split_edges_for_dgl_graph(
+            cls, graph: dgl.DGLGraph,
+            train_ratio: float, val_ratio: _typing.Optional[float] = ...
+    ) -> _typing.Union[
+        _typing.Tuple[
+            dgl.DGLGraph, dgl.DGLGraph,
+            dgl.DGLGraph, dgl.DGLGraph,
+            dgl.DGLGraph, dgl.DGLGraph,
+            dgl.DGLGraph
+        ],
+        _typing.Tuple[
+            dgl.DGLGraph, dgl.DGLGraph, dgl.DGLGraph,
+            dgl.DGLGraph, dgl.DGLGraph,
+        ]
+    ]:
+        if not 0 < train_ratio < 1:
+            raise ValueError(f"Invalid train_ratio as {train_ratio}")
+        if isinstance(val_ratio, float):
+            if not 0 < val_ratio < 1:
+                raise ValueError(f"Invalid val_ratio as {val_ratio}")
+            if not 0 < train_ratio + val_ratio < 1:
+                raise ValueError(
+                    f"Invalid combination (train_ratio, val_ratio) "
+                    f"as ({train_ratio}, {val_ratio})"
+                )
+            return cls.__split_edges_train_val_test(graph, train_ratio, val_ratio)
+        else:
+            return cls.__split_edges_train_test(graph, train_ratio)
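+
+# The DGL code path exposes the split as plain tuples of graphs:
+# (train_g, train_pos_g, train_neg_g, valid_pos_g, valid_neg_g,
+#  test_pos_g, test_neg_g) when val_ratio is a float, and
+# (train_g, train_pos_g, train_neg_g, test_pos_g, test_neg_g) otherwise.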
+
+
+def split_edges_for_data(
+        data: _typing.Union[
+            dgl.DGLGraph, autogl.data.graph.GeneralStaticGraph, _typing.Any
+        ],
+        train_ratio: float, val_ratio: _typing.Optional[float]
+) -> _typing.Union[
+    autogl.data.Data,
+    _typing.Tuple[
+        dgl.DGLGraph, dgl.DGLGraph,
+        dgl.DGLGraph, dgl.DGLGraph,
+        dgl.DGLGraph, dgl.DGLGraph,
+        dgl.DGLGraph
+    ],
+    _typing.Tuple[
+        dgl.DGLGraph, dgl.DGLGraph, dgl.DGLGraph,
+        dgl.DGLGraph, dgl.DGLGraph,
+    ]
+]:
+    if isinstance(data, dgl.DGLGraph):
+        if not data.is_homogeneous:
+            raise ValueError(
+                "Provided DGL graph to split edges MUST be homogeneous"
+            )
+        else:
+            return _SplitEdgesDGLImpl.split_edges_for_dgl_graph(
+                data, train_ratio, val_ratio
+            )
+    elif isinstance(data, autogl.data.graph.GeneralStaticGraph):
+        if not (data.nodes.is_homogeneous and data.edges.is_homogeneous):
+            raise ValueError(
+                "Provided instance of GeneralStaticGraph MUST be homogeneous"
+            )
+        __data = static_graph_to_general_data(data)
+        edge_index: torch.LongTensor = data.edges.connections
+        edge_attr: _typing.Optional[torch.Tensor] = (
+            data.edges.data['edge_attr'] if 'edge_attr' in data.edges.data else None
+        )
+    elif (
+            hasattr(data, 'edge_index') and
+            isinstance(data.edge_index, torch.Tensor) and
+            data.edge_index.dim() == data.edge_index.size(0) == 2
+    ):
+        edge_index: torch.LongTensor = data.edge_index
+        if (
+                hasattr(data, 'edge_attr') and
+                isinstance(data.edge_attr, torch.Tensor) and
+                data.edge_attr.size(0) == edge_index.size(1)
+        ):
+            edge_attr: _typing.Optional[torch.Tensor] = data.edge_attr
+        else:
+            edge_attr: _typing.Optional[torch.Tensor] = None
+        if hasattr(data, 'x') and isinstance(data.x, torch.Tensor):
+            x: _typing.Optional[torch.Tensor] = data.x
+        else:
+            x: _typing.Optional[torch.Tensor] = None
+        if hasattr(data, 'y') and isinstance(data.y, torch.Tensor):
+            y: _typing.Optional[torch.Tensor] = data.y
+        else:
+            y: _typing.Optional[torch.Tensor] = None
+        __data = autogl.data.Data(
+            edge_index=edge_index, edge_attr=edge_attr, x=x, y=y
+        )
+    else:
+        raise ValueError("Unsupported data type to split edges")
+
+    if isinstance(val_ratio, float) and 0 < val_ratio < 1:
+        test_ratio = 1 - train_ratio - val_ratio
+    else:
+        # No validation split requested: sample no validation edges.
+        val_ratio = 0.0
+        test_ratio = 1 - train_ratio
+    compound_results = train_test_split_edges(
+        edge_index, edge_attr,
+        val_ratio=val_ratio, test_ratio=test_ratio
+    )
+    __data.train_pos_edge_index = compound_results.train_pos_edge_index
+    __data.train_pos_edge_attr = compound_results.train_pos_edge_attr
+    __data.train_neg_adj_mask = compound_results.train_neg_adj_mask
+    __data.val_pos_edge_index = compound_results.val_pos_edge_index
+    __data.val_pos_edge_attr = compound_results.val_pos_edge_attr
+    __data.val_neg_edge_index = compound_results.val_neg_edge_index
+    __data.test_pos_edge_index = compound_results.test_pos_edge_index
+    __data.test_pos_edge_attr = compound_results.test_pos_edge_attr
+    __data.test_neg_edge_index = compound_results.test_neg_edge_index
+    return __data
diff --git a/autogl/datasets/utils/_split_edges/_pyg_compatible.py b/autogl/datasets/utils/_split_edges/_pyg_compatible.py
new file mode 100644
index 0000000..83ec892
--- /dev/null
+++ b/autogl/datasets/utils/_split_edges/_pyg_compatible.py
@@ -0,0 +1,71 @@
+import torch
+import typing as _typing
+import torch_geometric
+import autogl
+from autogl.data.graph import GeneralStaticGraph
+from autogl.data.graph.utils.conversion import static_graph_to_pyg_data
+from .train_test_split_edges import train_test_split_edges
+
+
+def split_edges_for_data(
+        data: _typing.Union[
+            torch_geometric.data.Data, autogl.data.graph.GeneralStaticGraph, _typing.Any
+        ],
+        train_ratio: float, val_ratio: float
+) -> torch_geometric.data.Data:
+    if isinstance(data, autogl.data.graph.GeneralStaticGraph):
+        if not (data.nodes.is_homogeneous and data.edges.is_homogeneous):
+            raise ValueError(
+                "Provided instance of GeneralStaticGraph MUST be homogeneous"
+            )
+        edge_index: torch.LongTensor = data.edges.connections
+        edge_attr: _typing.Optional[torch.Tensor] = (
+            data.edges.data['edge_attr'] if 'edge_attr' in data.edges.data else None
+        )
+        __data = static_graph_to_pyg_data(data)
+    elif (
+            hasattr(data, 'edge_index') and
+            isinstance(data.edge_index, torch.Tensor) and
+            data.edge_index.dim() == data.edge_index.size(0) == 2
+    ):
+        edge_index: torch.LongTensor = data.edge_index
+        if (
+                hasattr(data, 'edge_attr') and
+                isinstance(data.edge_attr, torch.Tensor) and
+                data.edge_attr.size(0) == edge_index.size(1)
+        ):
+            edge_attr: _typing.Optional[torch.Tensor] = data.edge_attr
+        else:
+            edge_attr: _typing.Optional[torch.Tensor] = None
+        if hasattr(data, 'x') and isinstance(data.x, torch.Tensor):
+            x: _typing.Optional[torch.Tensor] = data.x
+        else:
+            x: _typing.Optional[torch.Tensor] = None
+        if hasattr(data, 'y') and isinstance(data.y, torch.Tensor):
+            y: _typing.Optional[torch.Tensor] = data.y
+        else:
+            y: _typing.Optional[torch.Tensor] = None
+        __data = torch_geometric.data.Data(
+            edge_index=edge_index, edge_attr=edge_attr, x=x, y=y
+        )
+    else:
+        raise ValueError("Unsupported data type to split edges")
+
+    if isinstance(val_ratio, float) and 0 < val_ratio < 1:
+        test_ratio = 1 - train_ratio - val_ratio
+    else:
+        # No validation split requested: sample no validation edges.
+        val_ratio = 0.0
+        test_ratio = 1 - train_ratio
+    compound_results = train_test_split_edges(
+        edge_index, edge_attr,
+        val_ratio=val_ratio, test_ratio=test_ratio
+    )
+    __data.train_pos_edge_index = compound_results.train_pos_edge_index
+    __data.train_pos_edge_attr = compound_results.train_pos_edge_attr
+    __data.train_neg_adj_mask = compound_results.train_neg_adj_mask
+    __data.val_pos_edge_index = compound_results.val_pos_edge_index
+    __data.val_pos_edge_attr = compound_results.val_pos_edge_attr
+    __data.val_neg_edge_index = compound_results.val_neg_edge_index
+    __data.test_pos_edge_index = compound_results.test_pos_edge_index
+    __data.test_pos_edge_attr = compound_results.test_pos_edge_attr
+    __data.test_neg_edge_index = compound_results.test_neg_edge_index
+    return __data
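
A sketch of consuming the result on the PyG path, assuming a torch_geometric.data.Data instance named pyg_data; the attribute names follow the assignments above, and the shapes are illustrative:

    data = split_edges_for_data(pyg_data, train_ratio=0.85, val_ratio=0.05)
    data.train_pos_edge_index   # [2, num_train_pos] undirected training positives
    data.train_neg_adj_mask     # [N, N] bool mask of remaining negative candidates
    data.val_neg_edge_index     # [2, num_val] sampled negative validation edges
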
diff --git a/autogl/datasets/utils/_split_edges/split_edges.py b/autogl/datasets/utils/_split_edges/split_edges.py
new file mode 100644
index 0000000..dc5a5c1
--- /dev/null
+++ b/autogl/datasets/utils/_split_edges/split_edges.py
@@ -0,0 +1,32 @@
+import typing as _typing
+from autogl.data import InMemoryDataset
+import autogl
+
+# Select the backend-specific implementation once, at import time.
+if autogl.backend.DependentBackend.is_dgl():
+    from ._dgl_compatible import split_edges_for_data
+elif autogl.backend.DependentBackend.is_pyg():
+    from ._pyg_compatible import split_edges_for_data
+else:
+    raise NotImplementedError
+
+
+def split_edges(
+        dataset: InMemoryDataset,
+        train_ratio: float, val_ratio: _typing.Optional[float]
+) -> InMemoryDataset:
+    if not 0 < train_ratio < 1:
+        raise ValueError(f"Invalid train_ratio as {train_ratio}")
+    if isinstance(val_ratio, float) and not 0 < train_ratio + val_ratio < 1:
+        raise ValueError(
+            f"Invalid combination (train_ratio, val_ratio) "
+            f"as ({train_ratio}, {val_ratio})"
+        )
+    if (
+            autogl.backend.DependentBackend.is_pyg() and
+            not (isinstance(val_ratio, float) and 0 < val_ratio < 1)
+    ):
+        raise ValueError(
+            "For the PyG backend, val_ratio MUST be a float between 0 and 1, "
+            "i.e. 0 < val_ratio < 1"
+        )
+    return InMemoryDataset(
+        [split_edges_for_data(item, train_ratio, val_ratio) for item in dataset],
+        dataset.train_index, dataset.val_index, dataset.test_index
+    )
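
Both backend implementations derive the test ratio instead of taking it from the caller; illustrative arithmetic:

    train_ratio, val_ratio = 0.85, 0.05
    test_ratio = 1 - train_ratio - val_ratio   # 0.10
    # DGL backend with val_ratio=None: test_ratio = 1 - train_ratio  # 0.15
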
diff --git a/autogl/datasets/utils/_split_edges/train_test_split_edges.py b/autogl/datasets/utils/_split_edges/train_test_split_edges.py
new file mode 100644
index 0000000..1e85371
--- /dev/null
+++ b/autogl/datasets/utils/_split_edges/train_test_split_edges.py
@@ -0,0 +1,310 @@
+import math
+import torch
+import typing as _typing
+
+
+def _maybe_num_nodes(edge_index, num_nodes=None):
+    if isinstance(num_nodes, int):
+        return num_nodes
+    elif isinstance(edge_index, torch.Tensor):
+        return int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0
+    else:
+        return max(edge_index.size(0), edge_index.size(1))
+
+
+def __coalesce(
+        edge_index: torch.Tensor,
+        edge_attr: _typing.Union[
+            torch.Tensor, _typing.Iterable[torch.Tensor], None
+        ] = None,
+        num_nodes: _typing.Optional[int] = ...,
+        is_sorted: bool = False,
+        sort_by_row: bool = True
+) -> _typing.Union[
+    torch.Tensor, _typing.Tuple[torch.Tensor, torch.Tensor],
+    _typing.Tuple[torch.Tensor, _typing.Iterable[torch.Tensor]]
+]:
+    """
+    Row-wise sorts :obj:`edge_index` and removes its duplicated entries.
+    Duplicate entries in :obj:`edge_attr` are directly removed, instead of merged.
+
+    Args:
+        edge_index (LongTensor): The edge indices.
+        edge_attr (Tensor or List[Tensor], optional): Edge weights or multi-
+            dimensional edge features.
+            If given as a list, will re-shuffle and remove duplicates for all
+            its entries. (default: :obj:`None`)
+        num_nodes (int, optional): The number of nodes, *i.e.*
+            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)
+        is_sorted (bool, optional): If set to :obj:`True`, will expect
+            :obj:`edge_index` to be already sorted row-wise.
+        sort_by_row (bool, optional): If set to :obj:`False`, will sort
+            :obj:`edge_index` column-wise.
+
+    :rtype: :class:`LongTensor` if :attr:`edge_attr` is :obj:`None`, else
+        (:class:`LongTensor`, :obj:`Tensor` or :obj:`Iterable[Tensor]`)
+    """
+    if edge_attr is None:
+        pass
+    elif isinstance(edge_attr, torch.Tensor):
+        if edge_attr.size(0) != edge_index.size(1):
+            raise ValueError("edge_attr must provide one entry per edge")
+    elif isinstance(edge_attr, _typing.Iterable):
+        if not all([
+            (
+                isinstance(attr, torch.Tensor) and
+                attr.size(0) == edge_index.size(1)
+            ) for attr in edge_attr
+        ]):
+            raise ValueError("Invalid edge_attr argument")
+    else:
+        raise TypeError("Unsupported type of edge_attr argument")
+
+    nnz = edge_index.size(1)
+    num_nodes = _maybe_num_nodes(edge_index, num_nodes)
+
+    # Linearise each edge into a single sortable key (row * num_nodes + col
+    # when sort_by_row=True); idx[0] = -1 acts as a sentinel for the
+    # deduplication mask below.
+    idx = edge_index.new_empty(nnz + 1)
+    idx[0] = -1
+    idx[1:] = edge_index[1 - int(sort_by_row)]
+    idx[1:].mul_(num_nodes).add_(edge_index[int(sort_by_row)])
+
+    if not is_sorted:
+        idx[1:], perm = idx[1:].sort()
+        edge_index = edge_index[:, perm]
+        if edge_attr is not None and isinstance(edge_attr, torch.Tensor):
+            edge_attr = edge_attr[perm]
+        elif edge_attr is not None:
+            edge_attr = [e[perm] for e in edge_attr]
+
+    mask: _typing.Any = idx[1:] > idx[:-1]
+
+    # Only filter the edges when duplicates actually exist:
+    if mask.all():
+        return edge_index if edge_attr is None else (edge_index, edge_attr)
+
+    edge_index = edge_index[:, mask]
+    if edge_attr is None:
+        return edge_index
+    elif isinstance(edge_attr, torch.Tensor):
+        return edge_index, edge_attr[mask]
+    elif isinstance(edge_attr, _typing.Iterable):
+        return edge_index, [attr[mask] for attr in edge_attr]
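+
+# Worked example for the linearised key above (sort_by_row=True, num_nodes=3):
+# edge_index = [[0, 1, 0], [1, 0, 1]] yields keys row * 3 + col = [1, 3, 1];
+# after sorting, the duplicated key 1 fails the idx[1:] > idx[:-1] test and
+# the second (0, 1) edge is dropped, leaving edge_index = [[0, 1], [1, 0]].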
+
+
+def coalesce(
+        edge_index: torch.Tensor,
+        edge_attr: _typing.Union[
+            torch.Tensor, _typing.Iterable[torch.Tensor], None
+        ] = None,
+        num_nodes: _typing.Optional[int] = ...,
+        is_sorted: bool = False,
+        sort_by_row: bool = True
+) -> _typing.Union[
+    torch.Tensor, _typing.Tuple[torch.Tensor, torch.Tensor],
+    _typing.Tuple[torch.Tensor, _typing.Iterable[torch.Tensor]]
+]:
+    """
+    Row-wise sorts :obj:`edge_index` and removes its duplicated entries.
+    Duplicate entries in :obj:`edge_attr` are directly removed, instead of merged.
+
+    Args:
+        edge_index (LongTensor): The edge indices.
+        edge_attr (Tensor or List[Tensor], optional): Edge weights or multi-
+            dimensional edge features.
+            If given as a list, will re-shuffle and remove duplicates for all
+            its entries. (default: :obj:`None`)
+        num_nodes (int, optional): The number of nodes, *i.e.*
+            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)
+        is_sorted (bool, optional): If set to :obj:`True`, will expect
+            :obj:`edge_index` to be already sorted row-wise.
+        sort_by_row (bool, optional): If set to :obj:`False`, will sort
+            :obj:`edge_index` column-wise.
+
+    :rtype: :class:`LongTensor` if :attr:`edge_attr` is :obj:`None`, else
+        (:class:`LongTensor`, :obj:`Tensor` or :obj:`Iterable[Tensor]`)
+    """
+    if not isinstance(num_nodes, int):
+        num_nodes = None
+    try:
+        # Delegate to the PyG implementation when torch_geometric is installed.
+        import torch_geometric
+        return torch_geometric.utils.coalesce(
+            edge_index, edge_attr, num_nodes,
+            is_sorted=is_sorted,
+            sort_by_row=sort_by_row
+        )
+    except ModuleNotFoundError:
+        # Fall back to the pure-PyTorch implementation above.
+        return __coalesce(
+            edge_index, edge_attr, num_nodes,
+            is_sorted=is_sorted,
+            sort_by_row=sort_by_row
+        )
+
+
+def to_undirected(
+        edge_index: torch.Tensor,
+        edge_attr: _typing.Optional[_typing.Union[torch.Tensor, _typing.List[torch.Tensor]]] = None,
+        num_nodes: _typing.Optional[int] = ...,
+        __reduce: str = "add",
+) -> _typing.Union[
+    torch.Tensor, _typing.Tuple[torch.Tensor, torch.Tensor],
+    _typing.Tuple[torch.Tensor, _typing.List[torch.Tensor]]
+]:
+    r"""Converts the graph given by :attr:`edge_index` to an undirected graph
+    such that :math:`(j,i) \in \mathcal{E}` for every edge :math:`(i,j) \in
+    \mathcal{E}`.
+
+    Args:
+        edge_index (LongTensor): The edge indices.
+        edge_attr (Tensor or List[Tensor], optional): Edge weights or multi-
+            dimensional edge features.
+            If given as a list, will remove duplicates for all its entries.
+            (default: :obj:`None`)
+        num_nodes (int, optional): The number of nodes, *i.e.*
+            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)
+        __reduce (string, optional): Accepted for compatibility with
+            :meth:`torch_geometric.utils.to_undirected`; duplicated edges are
+            handled by :meth:`coalesce`. (default: :obj:`"add"`)
+
+    :rtype: :class:`LongTensor` if :attr:`edge_attr` is :obj:`None`, else
+        (:class:`LongTensor`, :obj:`Tensor` or :obj:`List[Tensor]`)
+    """
+    # Maintain backward compatibility to `to_undirected(edge_index, num_nodes)`,
+    # where the second positional argument used to be `num_nodes`.
+    if isinstance(edge_attr, int):
+        num_nodes = edge_attr
+        edge_attr = None
+
+    # Mirror every edge in both directions, then coalesce duplicates.
+    row, col = edge_index
+    row, col = torch.cat([row, col], dim=0), torch.cat([col, row], dim=0)
+    edge_index = torch.stack([row, col], dim=0)
+
+    if edge_attr is not None and isinstance(edge_attr, torch.Tensor):
+        edge_attr = torch.cat([edge_attr, edge_attr], dim=0)
+    elif edge_attr is not None:
+        edge_attr = [torch.cat([e, e], dim=0) for e in edge_attr]
+
+    return coalesce(edge_index, edge_attr, num_nodes)
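+
+# Example: edge_index = [[0, 1], [1, 2]] becomes [[0, 1, 1, 2], [1, 0, 2, 1]]
+# after mirroring both directions and coalescing duplicates.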
+
+
+class _SplitResult:
+    """Plain container for the tensors produced by train_test_split_edges."""
+
+    def __init__(
+            self,
+            train_pos_edge_index: torch.Tensor,
+            train_pos_edge_attr: _typing.Optional[torch.Tensor],
+            train_neg_adj_mask: torch.Tensor,
+            val_pos_edge_index: torch.Tensor,
+            val_pos_edge_attr: _typing.Optional[torch.Tensor],
+            val_neg_edge_index: torch.Tensor,
+            test_pos_edge_index: torch.Tensor,
+            test_pos_edge_attr: _typing.Optional[torch.Tensor],
+            test_neg_edge_index: torch.Tensor
+    ):
+        self.train_pos_edge_index: torch.Tensor = train_pos_edge_index
+        self.train_pos_edge_attr: _typing.Optional[torch.Tensor] = train_pos_edge_attr
+        self.train_neg_adj_mask: torch.Tensor = train_neg_adj_mask
+        self.val_pos_edge_index: torch.Tensor = val_pos_edge_index
+        self.val_pos_edge_attr: _typing.Optional[torch.Tensor] = val_pos_edge_attr
+        self.val_neg_edge_index: torch.Tensor = val_neg_edge_index
+        self.test_pos_edge_index: torch.Tensor = test_pos_edge_index
+        self.test_pos_edge_attr: _typing.Optional[torch.Tensor] = test_pos_edge_attr
+        self.test_neg_edge_index: torch.Tensor = test_neg_edge_index
+
+
+def train_test_split_edges(
+        edge_index: torch.Tensor,
+        edge_attr: _typing.Optional[torch.Tensor] = None,
+        num_nodes: _typing.Optional[int] = ...,
+        val_ratio: float = 0.05,
+        test_ratio: float = 0.1
+) -> _SplitResult:
+    r"""Splits the given edges into positive and negative train/val/test
+    edges and returns a :class:`_SplitResult` carrying
+    :obj:`train_pos_edge_index`, :obj:`train_neg_adj_mask`,
+    :obj:`val_pos_edge_index`, :obj:`val_neg_edge_index` and
+    :obj:`test_pos_edge_index`.
+    If :obj:`edge_attr` is given, :obj:`train_pos_edge_attr`,
+    :obj:`val_pos_edge_attr` and :obj:`test_pos_edge_attr` are populated
+    as well.
+
+    Adapted from the deprecated
+    :meth:`torch_geometric.utils.train_test_split_edges`; upstream recommends
+    :class:`torch_geometric.transforms.RandomLinkSplit` as its replacement.
+
+    Args:
+        edge_index (LongTensor): The edge indices.
+        edge_attr (Tensor, optional): Edge weights or multi-dimensional
+            edge features. (default: :obj:`None`)
+        num_nodes (int, optional): The number of nodes, *i.e.*
+            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)
+        val_ratio (float, optional): The ratio of positive validation edges.
+            (default: :obj:`0.05`)
+        test_ratio (float, optional): The ratio of positive test edges.
+            (default: :obj:`0.1`)
+    """
+    row, col = edge_index
+    num_nodes = _maybe_num_nodes(edge_index, num_nodes)
+
+    # Keep the upper triangular portion so each undirected edge is counted once.
+    mask = row < col
+    row, col = row[mask], col[mask]
+
+    if edge_attr is not None:
+        edge_attr = edge_attr[mask]
+
+    n_v = int(math.floor(val_ratio * row.size(0)))
+    n_t = int(math.floor(test_ratio * row.size(0)))
+
+    # Positive edges.
+    perm = torch.randperm(row.size(0))
+    row, col = row[perm], col[perm]
+    if edge_attr is not None:
+        edge_attr = edge_attr[perm]
+
+    r, c = row[:n_v], col[:n_v]
+    val_pos_edge_index = torch.stack([r, c], dim=0)
+    if edge_attr is not None:
+        val_pos_edge_attr = edge_attr[:n_v]
+    else:
+        val_pos_edge_attr = None
+
+    r, c = row[n_v:n_v + n_t], col[n_v:n_v + n_t]
+    test_pos_edge_index = torch.stack([r, c], dim=0)
+    if edge_attr is not None:
+        test_pos_edge_attr = edge_attr[n_v:n_v + n_t]
+    else:
+        test_pos_edge_attr = None
+
+    r, c = row[n_v + n_t:], col[n_v + n_t:]
+    train_pos_edge_index = torch.stack([r, c], dim=0)
+    if edge_attr is not None:
+        train_pos_edge_index, train_pos_edge_attr = to_undirected(
+            train_pos_edge_index, edge_attr[n_v + n_t:]
+        )
+    else:
+        train_pos_edge_index = to_undirected(train_pos_edge_index)
+        train_pos_edge_attr = None
+
+    # Negative edges: candidates are the upper-triangular non-edges.
+    neg_adj_mask = torch.ones(num_nodes, num_nodes, dtype=torch.uint8)
+    neg_adj_mask = neg_adj_mask.triu(diagonal=1).to(torch.bool)
+    neg_adj_mask[row, col] = 0
+
+    neg_row, neg_col = neg_adj_mask.nonzero(as_tuple=False).t()
+    perm = torch.randperm(neg_row.size(0))[:n_v + n_t]
+    neg_row, neg_col = neg_row[perm], neg_col[perm]
+
+    # Exclude the sampled val/test negatives from the training mask.
+    neg_adj_mask[neg_row, neg_col] = 0
+    train_neg_adj_mask = neg_adj_mask
+
+    row, col = neg_row[:n_v], neg_col[:n_v]
+    val_neg_edge_index = torch.stack([row, col], dim=0)
+
+    row, col = neg_row[n_v:n_v + n_t], neg_col[n_v:n_v + n_t]
+    test_neg_edge_index = torch.stack([row, col], dim=0)
+
+    return _SplitResult(
+        train_pos_edge_index, train_pos_edge_attr, train_neg_adj_mask,
+        val_pos_edge_index, val_pos_edge_attr, val_neg_edge_index,
+        test_pos_edge_index, test_pos_edge_attr, test_neg_edge_index
+    )
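
A self-contained sketch exercising the standalone splitter on a toy graph (tensor values and the variable names are illustrative only):

    import torch
    from autogl.datasets.utils._split_edges.train_test_split_edges import train_test_split_edges

    # Four undirected edges on 4 nodes, given in upper-triangular form.
    edge_index = torch.tensor([[0, 0, 1, 2], [1, 2, 3, 3]])
    result = train_test_split_edges(
        edge_index, None, num_nodes=4, val_ratio=0.25, test_ratio=0.25
    )
    result.train_pos_edge_index   # [2, 4]: 2 training edges mirrored by to_undirected
    result.val_neg_edge_index     # [2, 1]: one sampled negative validation edge
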