| @@ -1,3 +1,4 @@ | |||
| from ._general import StaticGraphToGeneralData, static_graph_to_general_data | |||
| from ._nx import ( | |||
| HomogeneousStaticGraphToNetworkX | |||
| ) | |||
| @@ -0,0 +1,79 @@ | |||
| import torch | |||
| import typing as _typing | |||
| import autogl | |||
| from ... import GeneralStaticGraph | |||
class StaticGraphToGeneralData:
    """Callable converter from :class:`GeneralStaticGraph` to ``autogl.data.Data``.

    Node data, graph-level data and edge data of a static graph with
    homogeneous nodes are flattened into a single flat ``Data`` object.
    """

    def __init__(self, *__args, **__kwargs):
        pass

    def __call__(
            self, static_graph: GeneralStaticGraph,
            *__args, **__kwargs
    ):
        """Convert ``static_graph`` to an ``autogl.data.Data`` instance.

        :param static_graph: graph to convert; MUST consist of homogeneous nodes.
        :raises TypeError: if ``static_graph`` is not a ``GeneralStaticGraph``.
        :raises ValueError: if nodes are heterogeneous or data keys collide.
        """
        if not isinstance(static_graph, GeneralStaticGraph):
            raise TypeError
        elif not static_graph.nodes.is_homogeneous:
            raise ValueError("Provided static graph MUST consist of homogeneous nodes")
        homogeneous_node_type: _typing.Optional[str] = (
            list(static_graph.nodes)[0]
            if len(list(static_graph.nodes)) > 0 else None
        )
        data: _typing.Dict[str, torch.Tensor] = dict()
        if isinstance(homogeneous_node_type, str):
            # Node data and graph-level data share one flat namespace in the
            # result; duplicate keys would silently overwrite -- reject them.
            node_and_edge_data_keys_intersection: _typing.Set[str] = (
                set(static_graph.nodes.data) & set(static_graph.data)
            )
            if len(node_and_edge_data_keys_intersection) > 0:
                raise ValueError(
                    f"Provided static graph contains duplicate data "
                    f"with same keys {node_and_edge_data_keys_intersection} "
                    f"for homogeneous nodes data and graph-level data, "
                    f"please refer to doc for more details."
                )
            data.update(static_graph.nodes.data)
            data.update(static_graph.data)
        else:
            data.update(static_graph.data)
        if len(list(static_graph.edges)) == 1:
            # Single (homogeneous) edge type: export under canonical names.
            data["edge_index"] = static_graph.edges.connections
            if len(set(data.keys()) & set(static_graph.edges.data.keys())) > 0:
                raise ValueError(
                    "Provided static graph contains duplicate data with same key, "
                    "please refer to doc for more details."
                )
            data.update(static_graph.edges.data)
        elif len(list(static_graph.edges)) > 1:
            for canonical_edge_type in static_graph.edges:
                # Only edge types connecting the homogeneous node type are kept.
                if isinstance(homogeneous_node_type, str) and (
                        canonical_edge_type.source_node_type != homogeneous_node_type or
                        canonical_edge_type.target_node_type != homogeneous_node_type
                ):
                    continue
                relation_type: str = canonical_edge_type.relation_type
                # BUGFIX: the original tested ``canonical_edge_type[-4:] != 'edge'``,
                # slicing the canonical edge-type tuple itself -- a tuple slice can
                # never equal the string 'edge', so every edge type was skipped.
                # The 'edge' suffix check belongs on the relation-type string.
                if len(relation_type) < 4 or relation_type[-4:] != 'edge':
                    continue
                data[f"{relation_type}_index"] = (
                    static_graph.edges[canonical_edge_type].connections
                )
                edge_type_prefix: str = relation_type[:-4]
                for data_key in static_graph.edges[canonical_edge_type].data:
                    if len(data_key) >= 4 and data_key[:4] == 'edge':
                        # e.g. relation 'train_pos_edge' + key 'edge_attr'
                        # -> 'train_pos_edge_attr'
                        data[f"{edge_type_prefix}{data_key}"] = (
                            static_graph.edges[canonical_edge_type].data[data_key].detach()
                        )
                    else:
                        data[f"{relation_type}_{data_key}"] = (
                            static_graph.edges[canonical_edge_type].data[data_key].detach()
                        )
        general_data = autogl.data.Data()
        for key, value in data.items():
            setattr(general_data, key, value)
        return general_data
def static_graph_to_general_data(static_graph: GeneralStaticGraph) -> autogl.data.Data:
    """Functional shorthand for :class:`StaticGraphToGeneralData`."""
    converter = StaticGraphToGeneralData()
    return converter(static_graph)
| @@ -1,6 +1,6 @@ | |||
| from ._split_edges import split_edges | |||
| from ._general import ( | |||
| index_to_mask, | |||
| split_edges, | |||
| random_splits_mask, | |||
| random_splits_mask_class, | |||
| graph_cross_validation, | |||
| @@ -5,9 +5,7 @@ import torch.utils.data | |||
| import typing as _typing | |||
| from sklearn.model_selection import StratifiedKFold, KFold | |||
| from autogl import backend as _backend | |||
| from autogl.data import Data, Dataset, InMemoryStaticGraphSet | |||
| from ...data.graph import GeneralStaticGraph, GeneralStaticGraphGenerator | |||
| from . import _pyg | |||
| from autogl.data import InMemoryStaticGraphSet | |||
| def index_to_mask(index: torch.Tensor, size): | |||
| @@ -16,70 +14,6 @@ def index_to_mask(index: torch.Tensor, size): | |||
| return mask | |||
def split_edges(
        dataset: InMemoryStaticGraphSet,
        train_ratio: float, val_ratio: float
) -> InMemoryStaticGraphSet:
    """Split the edges of every graph in ``dataset`` for link prediction.

    Each homogeneous-edge graph is re-assembled as a heterogeneous static
    graph whose edge types encode the train/val/test positive/negative
    splits produced by ``_pyg.train_test_split_edges``.

    :param dataset: in-memory static graph set, modified in place.
    :param train_ratio: fraction of edges kept for training.
    :param val_ratio: fraction of edges held out for validation.
    :return: the same ``dataset`` instance, every graph replaced by its split.
    :raises TypeError: if ``dataset`` is not an ``InMemoryStaticGraphSet``
        or a contained graph is not a ``GeneralStaticGraph``.
    :raises ValueError: if a contained graph has heterogeneous edges.
    """
    test_ratio: float = 1 - train_ratio - val_ratio

    def _split_edges_for_graph(homogeneous_static_graph: GeneralStaticGraph) -> GeneralStaticGraph:
        # Split one graph: run the PyG-style edge splitter on a temporary
        # ``Data`` object, then rebuild a static graph whose edge types are
        # named after the produced splits (train_pos_edge, val_pos_edge, ...).
        if not isinstance(homogeneous_static_graph, GeneralStaticGraph):
            raise TypeError
        elif not homogeneous_static_graph.edges.is_homogeneous:
            raise ValueError("The provided graph MUST consist of homogeneous edges.")
        else:
            split_data = _pyg.train_test_split_edges(
                Data(
                    # detach().clone() so the splitter never mutates tensors
                    # owned by the original graph
                    edge_index=homogeneous_static_graph.edges.connections.detach().clone(),
                    edge_attr=(
                        homogeneous_static_graph.edges.data['edge_attr'].detach().clone()
                        if 'edge_attr' in homogeneous_static_graph.edges.data else None
                    )
                ),
                val_ratio, test_ratio
            )
            # The graph has homogeneous edges, hence exactly one canonical edge type.
            original_edge_type = [et for et in homogeneous_static_graph.edges][0]
            split_static_graph = GeneralStaticGraphGenerator.create_heterogeneous_static_graph(
                dict([
                    (node_type, homogeneous_static_graph.nodes[node_type].data)
                    for node_type in homogeneous_static_graph.nodes
                ]),
                {
                    # Positive splits carry 'edge_attr' when the splitter
                    # produced one; negative splits carry connections only.
                    (original_edge_type.source_node_type, "train_pos_edge", original_edge_type.target_node_type): (
                        getattr(split_data, "train_pos_edge_index"),
                        {"edge_attr": getattr(split_data, "train_pos_edge_attr")}
                        if isinstance(getattr(split_data, "train_pos_edge_attr"), torch.Tensor)
                        else None
                    ),
                    (original_edge_type.source_node_type, "val_pos_edge", original_edge_type.target_node_type): (
                        getattr(split_data, "val_pos_edge_index"),
                        {"edge_attr": getattr(split_data, "val_pos_edge_attr")}
                        if isinstance(getattr(split_data, "val_pos_edge_attr"), torch.Tensor)
                        else None
                    ),
                    (original_edge_type.source_node_type, "val_neg_edge", original_edge_type.target_node_type):
                        getattr(split_data, "val_neg_edge_index"),
                    (original_edge_type.source_node_type, "test_pos_edge", original_edge_type.target_node_type): (
                        getattr(split_data, "test_pos_edge_index"),
                        {"edge_attr": getattr(split_data, "test_pos_edge_attr")}
                        if isinstance(getattr(split_data, "test_pos_edge_attr"), torch.Tensor)
                        else None
                    ),
                    (original_edge_type.source_node_type, "test_neg_edge", original_edge_type.target_node_type):
                        getattr(split_data, "test_neg_edge_index")
                },
                homogeneous_static_graph.data
            )
            return split_static_graph

    if not isinstance(dataset, InMemoryStaticGraphSet):
        raise TypeError
    # Replace every graph in the set with its edge-split version, in place.
    for index in range(len(dataset)):
        dataset[index] = _split_edges_for_graph(dataset[index])
    return dataset
| def random_splits_mask( | |||
| dataset: InMemoryStaticGraphSet, | |||
| train_ratio: float = 0.2, val_ratio: float = 0.4, | |||
| @@ -0,0 +1 @@ | |||
| from .split_edges import split_edges | |||
| @@ -0,0 +1,207 @@ | |||
| import numpy as np | |||
| import scipy.sparse as sp | |||
| import torch | |||
| import typing as _typing | |||
| import dgl | |||
| import autogl.data.graph | |||
| from .train_test_split_edges import train_test_split_edges | |||
| from autogl.data.graph.utils.conversion import static_graph_to_general_data | |||
class _SplitEdgesDGLImpl:
    """Edge-splitting helpers for homogeneous ``dgl.DGLGraph`` instances.

    Both splitters permute the positive edge ids and sample an equal number
    of negative (absent) edges, then partition both id sequences into
    contiguous test / [validation /] train slices, wrapping each slice in
    a fresh graph over the same node set.
    """

    @classmethod
    def __split_edges_train_val_test(
            cls, g: dgl.DGLGraph,
            train_ratio: float, val_ratio: float
    ) -> _typing.Tuple[
        dgl.DGLGraph, dgl.DGLGraph, dgl.DGLGraph,
        dgl.DGLGraph, dgl.DGLGraph, dgl.DGLGraph, dgl.DGLGraph
    ]:
        """Three-way split returning
        (train_g, train_pos, train_neg, val_pos, val_neg, test_pos, test_neg)."""
        u, v = g.edges()
        eids = np.arange(g.number_of_edges())
        eids = np.random.permutation(eids)
        valid_size = int(len(eids) * val_ratio)
        test_size = int(len(eids) * (1 - train_ratio - val_ratio))
        # Positive edges: permuted ids are partitioned as [test | valid | train].
        test_pos_u, test_pos_v = u[eids[:test_size]], v[eids[:test_size]]
        valid_pos_u, valid_pos_v = u[eids[test_size:test_size + valid_size]], v[eids[test_size:test_size + valid_size]]
        train_pos_u, train_pos_v = u[eids[test_size + valid_size:]], v[eids[test_size + valid_size:]]
        # Find all negative edges and split them for training/validation/testing
        adj = sp.coo_matrix((np.ones(len(u)), (u.numpy(), v.numpy())))
        adj_neg = 1 - adj.todense() - np.eye(g.number_of_nodes())
        neg_u, neg_v = np.where(adj_neg != 0)
        neg_eids = np.random.choice(len(neg_u), g.number_of_edges())
        test_neg_u, test_neg_v = neg_u[neg_eids[:test_size]], neg_v[neg_eids[:test_size]]
        valid_neg_u, valid_neg_v = neg_u[neg_eids[test_size:test_size + valid_size]], neg_v[neg_eids[test_size:test_size + valid_size]]
        # BUGFIX: the negatives must follow the same [test | valid | train]
        # partition as the positives.  The original sliced from ``train_size``
        # onward, which overlaps the test/validation negatives and yields a
        # wrongly sized training negative set.
        train_neg_u, train_neg_v = neg_u[neg_eids[test_size + valid_size:]], neg_v[neg_eids[test_size + valid_size:]]
        # The training message-passing graph drops all held-out edges.
        train_g = dgl.remove_edges(g, eids[:test_size + valid_size])
        train_pos_g = dgl.graph((train_pos_u, train_pos_v), num_nodes=g.number_of_nodes())
        train_neg_g = dgl.graph((train_neg_u, train_neg_v), num_nodes=g.number_of_nodes())
        valid_pos_g = dgl.graph((valid_pos_u, valid_pos_v), num_nodes=g.number_of_nodes())
        valid_neg_g = dgl.graph((valid_neg_u, valid_neg_v), num_nodes=g.number_of_nodes())
        test_pos_g = dgl.graph((test_pos_u, test_pos_v), num_nodes=g.number_of_nodes())
        test_neg_g = dgl.graph((test_neg_u, test_neg_v), num_nodes=g.number_of_nodes())
        return (
            train_g, train_pos_g, train_neg_g,
            valid_pos_g, valid_neg_g, test_pos_g, test_neg_g
        )

    @classmethod
    def __split_edges_train_test(
            cls, g: dgl.DGLGraph, train_ratio: float
    ) -> _typing.Tuple[
        dgl.DGLGraph, dgl.DGLGraph, dgl.DGLGraph,
        dgl.DGLGraph, dgl.DGLGraph,
    ]:
        """Two-way split returning (train_g, train_pos, train_neg, test_pos, test_neg)."""
        u, v = g.edges()
        eids = np.arange(g.number_of_edges())
        eids = np.random.permutation(eids)
        test_size = int(len(eids) * (1 - train_ratio))
        # Positive edges: permuted ids are partitioned as [test | train].
        test_pos_u, test_pos_v = u[eids[:test_size]], v[eids[:test_size]]
        train_pos_u, train_pos_v = u[eids[test_size:]], v[eids[test_size:]]
        # Find all negative edges and split them for training and testing
        adj = sp.coo_matrix((np.ones(len(u)), (u.numpy(), v.numpy())))
        adj_neg = 1 - adj.todense() - np.eye(g.number_of_nodes())
        neg_u, neg_v = np.where(adj_neg != 0)
        neg_eids = np.random.choice(len(neg_u), g.number_of_edges())
        test_neg_u, test_neg_v = neg_u[neg_eids[:test_size]], neg_v[neg_eids[:test_size]]
        # BUGFIX: the training negatives are the remainder after the test
        # slice; the original sliced from ``train_size``, producing only the
        # tail ``test_size`` samples (overlapping the test negatives)
        # regardless of the requested ratio.
        train_neg_u, train_neg_v = neg_u[neg_eids[test_size:]], neg_v[neg_eids[test_size:]]
        train_g = dgl.remove_edges(g, eids[:test_size])
        train_pos_g = dgl.graph((train_pos_u, train_pos_v), num_nodes=g.number_of_nodes())
        train_neg_g = dgl.graph((train_neg_u, train_neg_v), num_nodes=g.number_of_nodes())
        test_pos_g = dgl.graph((test_pos_u, test_pos_v), num_nodes=g.number_of_nodes())
        test_neg_g = dgl.graph((test_neg_u, test_neg_v), num_nodes=g.number_of_nodes())
        return train_g, train_pos_g, train_neg_g, test_pos_g, test_neg_g

    @classmethod
    def split_edges_for_dgl_graph(
            cls, graph: dgl.DGLGraph,
            train_ratio: float, val_ratio: _typing.Optional[float] = ...
    ) -> _typing.Union[
        _typing.Tuple[
            dgl.DGLGraph, dgl.DGLGraph,
            dgl.DGLGraph, dgl.DGLGraph,
            dgl.DGLGraph, dgl.DGLGraph,
            dgl.DGLGraph
        ],
        _typing.Tuple[
            dgl.DGLGraph, dgl.DGLGraph, dgl.DGLGraph,
            dgl.DGLGraph, dgl.DGLGraph,
        ]
    ]:
        """Validate the ratios and dispatch to the 2-way or 3-way splitter.

        :param graph: homogeneous DGL graph to split.
        :param train_ratio: fraction of edges for training, in (0, 1).
        :param val_ratio: optional validation fraction; when not a float,
            only a train/test split is performed.
        :raises ValueError: on out-of-range ratios or an invalid combination.
        """
        if not 0 < train_ratio < 1:
            raise ValueError(f"Invalid train_ratio as {train_ratio}")
        if isinstance(val_ratio, float):
            if not 0 < val_ratio < 1:
                raise ValueError(f"Invalid val_ratio as {val_ratio}")
            if not 0 < train_ratio + val_ratio < 1:
                raise ValueError(
                    f"Invalid combination (train_ratio, val_ratio) "
                    f"as ({train_ratio}, {val_ratio})"
                )
            return cls.__split_edges_train_val_test(graph, train_ratio, val_ratio)
        else:
            return cls.__split_edges_train_test(graph, train_ratio)
def split_edges_for_data(
        data: _typing.Union[
            dgl.DGLGraph, autogl.data.graph.GeneralStaticGraph, _typing.Any
        ],
        train_ratio: float, val_ratio: _typing.Optional[float]
) -> _typing.Union[
    autogl.data.Data,
    _typing.Tuple[
        dgl.DGLGraph, dgl.DGLGraph,
        dgl.DGLGraph, dgl.DGLGraph,
        dgl.DGLGraph, dgl.DGLGraph,
        dgl.DGLGraph
    ],
    _typing.Tuple[
        dgl.DGLGraph, dgl.DGLGraph, dgl.DGLGraph,
        dgl.DGLGraph, dgl.DGLGraph,
    ]
]:
    """Split the edges of one graph-like object for link prediction (DGL backend).

    :param data: a homogeneous ``dgl.DGLGraph``, a homogeneous
        ``GeneralStaticGraph``, or any object exposing PyG-style
        ``edge_index`` (and optionally ``edge_attr``/``x``/``y``) attributes.
    :param train_ratio: fraction of edges used for training.
    :param val_ratio: optional fraction of edges used for validation.
    :return: for a DGL graph, the tuple of split sub-graphs; otherwise an
        ``autogl.data.Data`` object augmented with the split edge tensors.
    :raises ValueError: on heterogeneous input or an unsupported type.
    """
    if isinstance(data, dgl.DGLGraph):
        if not data.is_homogeneous:
            raise ValueError(
                "provided DGL graph to split edges MUST be homogeneous"
            )
        else:
            return _SplitEdgesDGLImpl.split_edges_for_dgl_graph(
                data, train_ratio, val_ratio
            )
    elif isinstance(data, autogl.data.graph.GeneralStaticGraph):
        if not (data.nodes.is_homogeneous and data.edges.is_homogeneous):
            raise ValueError(
                "Provided instance of GeneralStaticGraph MUST be homogeneous"
            )
        __data = static_graph_to_general_data(data)
        edge_index: torch.LongTensor = data.edges.connections
        edge_attr: _typing.Optional[torch.Tensor] = (
            data.edges.data['edge_attr'] if 'edge_attr' in data.edges.data else None
        )
    elif (
            hasattr(data, 'edge_index') and
            isinstance(data.edge_index, torch.Tensor) and
            data.edge_index.dim() == data.edge_index.size(0) == 2
    ):
        edge_index: torch.LongTensor = data.edge_index
        # edge_attr is only taken over when it matches the edge count.
        if (
                hasattr(data, 'edge_attr') and
                isinstance(data.edge_attr, torch.Tensor) and
                data.edge_attr.size(0) == edge_index.size(1)
        ):
            edge_attr: _typing.Optional[torch.Tensor] = data.edge_attr
        else:
            edge_attr: _typing.Optional[torch.Tensor] = None
        if hasattr(data, 'x') and isinstance(data.x, torch.Tensor):
            x: _typing.Optional[torch.Tensor] = data.x
        else:
            x: _typing.Optional[torch.Tensor] = None
        if hasattr(data, 'y') and isinstance(data.y, torch.Tensor):
            # BUGFIX: the original assigned ``data.x`` here, silently
            # replacing the labels with the node features.
            y: _typing.Optional[torch.Tensor] = data.y
        else:
            y: _typing.Optional[torch.Tensor] = None
        __data = autogl.data.Data(
            edge_index=edge_index, edge_attr=edge_attr, x=x, y=y
        )
    else:
        raise ValueError
    if isinstance(val_ratio, float) and 0 < val_ratio < 1:
        test_ratio = 1 - train_ratio - val_ratio
    else:
        # No validation split requested; normalize val_ratio so that the
        # downstream splitter always receives a float (the original passed
        # None through, which crashed inside train_test_split_edges).
        val_ratio = 0.0
        test_ratio = 1 - train_ratio
    compound_results = train_test_split_edges(
        edge_index, edge_attr,
        val_ratio=val_ratio, test_ratio=test_ratio
    )
    __data.train_pos_edge_index = compound_results.train_pos_edge_index
    __data.train_pos_edge_attr = compound_results.train_pos_edge_attr
    __data.train_neg_adj_mask = compound_results.train_neg_adj_mask
    __data.val_pos_edge_index = compound_results.val_pos_edge_index
    __data.val_pos_edge_attr = compound_results.val_pos_edge_attr
    __data.val_neg_edge_index = compound_results.val_neg_edge_index
    __data.test_pos_edge_index = compound_results.test_pos_edge_index
    __data.test_pos_edge_attr = compound_results.test_pos_edge_attr
    __data.test_neg_edge_index = compound_results.test_neg_edge_index
    return __data
| @@ -0,0 +1,71 @@ | |||
| import torch | |||
| import typing as _typing | |||
| import torch_geometric | |||
| import autogl | |||
| from autogl.data.graph import GeneralStaticGraph | |||
| from autogl.data.graph.utils.conversion import static_graph_to_pyg_data | |||
| from .train_test_split_edges import train_test_split_edges | |||
def split_edges_for_data(
        data: _typing.Union[
            torch_geometric.data.Data, autogl.data.graph.GeneralStaticGraph, _typing.Any
        ],
        train_ratio: float, val_ratio: float
) -> torch_geometric.data.Data:
    """Split the edges of one graph-like object for link prediction (PyG backend).

    :param data: a homogeneous ``GeneralStaticGraph`` or any object exposing
        PyG-style ``edge_index`` (and optionally ``edge_attr``/``x``/``y``).
    :param train_ratio: fraction of edges used for training.
    :param val_ratio: fraction of edges used for validation.
    :return: a ``torch_geometric.data.Data`` augmented with the split
        positive/negative edge tensors.
    :raises ValueError: on heterogeneous input or an unsupported type.
    """
    if isinstance(data, autogl.data.graph.GeneralStaticGraph):
        if not (data.nodes.is_homogeneous and data.edges.is_homogeneous):
            raise ValueError(
                "Provided instance of GeneralStaticGraph MUST be homogeneous"
            )
        edge_index: torch.LongTensor = data.edges.connections
        edge_attr: _typing.Optional[torch.Tensor] = (
            data.edges.data['edge_attr'] if 'edge_attr' in data.edges.data else None
        )
        __data = static_graph_to_pyg_data(data)
    elif (
            hasattr(data, 'edge_index') and
            isinstance(data.edge_index, torch.Tensor) and
            data.edge_index.dim() == data.edge_index.size(0) == 2
    ):
        edge_index: torch.LongTensor = data.edge_index
        # edge_attr is only taken over when it matches the edge count.
        if (
                hasattr(data, 'edge_attr') and
                isinstance(data.edge_attr, torch.Tensor) and
                data.edge_attr.size(0) == edge_index.size(1)
        ):
            edge_attr: _typing.Optional[torch.Tensor] = data.edge_attr
        else:
            edge_attr: _typing.Optional[torch.Tensor] = None
        if hasattr(data, 'x') and isinstance(data.x, torch.Tensor):
            x: _typing.Optional[torch.Tensor] = data.x
        else:
            x: _typing.Optional[torch.Tensor] = None
        if hasattr(data, 'y') and isinstance(data.y, torch.Tensor):
            # BUGFIX: the original assigned ``data.x`` here, silently
            # replacing the labels with the node features.
            y: _typing.Optional[torch.Tensor] = data.y
        else:
            y: _typing.Optional[torch.Tensor] = None
        __data = torch_geometric.data.Data(
            edge_index=edge_index, edge_attr=edge_attr, x=x, y=y
        )
    else:
        raise ValueError
    if isinstance(val_ratio, float) and 0 < val_ratio < 1:
        test_ratio = 1 - train_ratio - val_ratio
    else:
        test_ratio = 1 - train_ratio
    compound_results = train_test_split_edges(
        edge_index, edge_attr,
        val_ratio=val_ratio, test_ratio=test_ratio
    )
    __data.train_pos_edge_index = compound_results.train_pos_edge_index
    __data.train_pos_edge_attr = compound_results.train_pos_edge_attr
    __data.train_neg_adj_mask = compound_results.train_neg_adj_mask
    __data.val_pos_edge_index = compound_results.val_pos_edge_index
    __data.val_pos_edge_attr = compound_results.val_pos_edge_attr
    __data.val_neg_edge_index = compound_results.val_neg_edge_index
    __data.test_pos_edge_index = compound_results.test_pos_edge_index
    __data.test_pos_edge_attr = compound_results.test_pos_edge_attr
    __data.test_neg_edge_index = compound_results.test_neg_edge_index
    return __data
| @@ -0,0 +1,32 @@ | |||
| import typing as _typing | |||
| from autogl.data import InMemoryDataset | |||
| import autogl | |||
| if autogl.backend.DependentBackend.is_dgl(): | |||
| from ._dgl_compatible import split_edges_for_data | |||
| elif autogl.backend.DependentBackend.is_pyg(): | |||
| from ._pyg_compatible import split_edges_for_data | |||
| else: | |||
| raise NotImplementedError | |||
def split_edges(
        dataset: InMemoryDataset,
        train_ratio: float, val_ratio: _typing.Optional[float]
) -> InMemoryDataset:
    """Split the edges of every graph in ``dataset`` for link prediction.

    :param dataset: dataset whose items are split one by one.
    :param train_ratio: fraction of edges for training, in (0, 1).
    :param val_ratio: optional validation fraction; REQUIRED (as a float in
        (0, 1)) when the PyG backend is active.
    :return: a new ``InMemoryDataset`` of split items, preserving the
        original train/val/test indices.
    :raises ValueError: on invalid ratios.
    """
    # BUGFIX: validate train_ratio unconditionally.  The original used an
    # if/elif chain, so a float val_ratio with a valid sum skipped the
    # train_ratio range check entirely (e.g. train_ratio=-0.1, val_ratio=0.5).
    if not 0 < train_ratio < 1:
        raise ValueError
    if isinstance(val_ratio, float) and not 0 < train_ratio + val_ratio < 1:
        raise ValueError
    if (
            autogl.backend.DependentBackend.is_pyg() and
            not (isinstance(val_ratio, float) and 0 < val_ratio < 1)
    ):
        raise ValueError(
            "For PyG as backend, val_ratio MUST be specific float between 0 and 1, "
            "i.e. 0 < val_ratio < 1"
        )
    return InMemoryDataset(
        [split_edges_for_data(item, train_ratio, val_ratio) for item in dataset],
        dataset.train_index, dataset.val_index, dataset.test_index
    )
| @@ -0,0 +1,310 @@ | |||
| import math | |||
| import torch | |||
| import typing as _typing | |||
def _maybe_num_nodes(edge_index, num_nodes=None):
    """Return ``num_nodes`` when given, otherwise infer it from ``edge_index``.

    For a tensor of edge indices the node count is ``max index + 1``
    (0 for an empty tensor); for other index-like objects the larger of
    the two sizes is used.
    """
    if isinstance(num_nodes, int):
        return num_nodes
    if isinstance(edge_index, torch.Tensor):
        if edge_index.numel() == 0:
            return 0
        return int(edge_index.max()) + 1
    return max(edge_index.size(0), edge_index.size(1))
def __coalesce(
        edge_index: torch.Tensor,
        edge_attr: _typing.Union[
            torch.Tensor, _typing.Iterable[torch.Tensor], None
        ] = None,
        num_nodes: _typing.Optional[int] = ...,
        is_sorted: bool = False,
        sort_by_row: bool = True
) -> _typing.Union[
    torch.Tensor, _typing.Tuple[torch.Tensor, torch.Tensor],
    _typing.Tuple[torch.Tensor, _typing.Iterable[torch.Tensor]]
]:
    """
    Row-wise sorts :obj:`edge_index` and removes its duplicated entries.
    Duplicate entries in :obj:`edge_attr` are directly removed, instead of merged.
    Args:
        edge_index (LongTensor): The edge indices.
        edge_attr (Tensor or List[Tensor], optional): Edge weights or multi-
            dimensional edge features.
            If given as a list, will re-shuffle and remove duplicates for all
            its entries. (default: :obj:`None`)
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)
        is_sorted (bool, optional): If set to :obj:`True`, will expect
            :obj:`edge_index` to be already sorted row-wise.
        sort_by_row (bool, optional): If set to :obj:`False`, will sort
            :obj:`edge_index` column-wise.
    :rtype: :class:`LongTensor` if :attr:`edge_attr` is :obj:`None`, else
        (:class:`LongTensor`, :obj:`Tensor` or :obj:`Iterable[Tensor]]`)
    """
    # --- validate edge_attr against the number of edges -------------------
    if edge_attr is None:
        pass
    elif isinstance(edge_attr, torch.Tensor) and torch.is_tensor(edge_attr):
        if edge_attr.size(0) != edge_index.size(1):
            raise ValueError
    elif isinstance(edge_attr, _typing.Iterable):
        # Every entry of a list/iterable of attributes must align with edges.
        if not all([
            (
                isinstance(attr, torch.Tensor) and
                attr.size(0) == edge_index.size(1)
            ) for attr in edge_attr
        ]):
            raise ValueError("Invalid edge_attr argument")
    else:
        raise TypeError("Unsupported type of edge_attr argument")
    nnz = edge_index.size(1)  # number of edges (non-zeros)
    num_nodes = _maybe_num_nodes(edge_index, num_nodes)
    # Encode each edge as one integer key (major * num_nodes + minor), so
    # that sorting the keys sorts the edges and equal keys mark duplicates.
    idx = edge_index.new_empty(nnz + 1)
    idx[0] = -1  # sentinel: the first real edge always differs from idx[0]
    idx[1:] = edge_index[1 - int(sort_by_row)]
    idx[1:].mul_(num_nodes).add_(edge_index[int(sort_by_row)])
    if not is_sorted:
        idx[1:], perm = idx[1:].sort()
        edge_index = edge_index[:, perm]
        # Keep the attribute rows aligned with the re-ordered edges.
        if edge_attr is not None and isinstance(edge_attr, torch.Tensor):
            edge_attr = edge_attr[perm]
        elif edge_attr is not None:
            edge_attr = [e[perm] for e in edge_attr]
    # mask[i] is True when edge i differs from its (sorted) predecessor,
    # i.e. it is the first occurrence and must be kept.
    mask: _typing.Any = idx[1:] > idx[:-1]
    # Only perform expensive merging in case there exists duplicates:
    if mask.all():
        return edge_index if edge_attr is None else (edge_index, edge_attr)
    edge_index = edge_index[:, mask]
    if edge_attr is None:
        return edge_index
    elif isinstance(edge_attr, torch.Tensor):
        return edge_index, edge_attr[mask]
    elif isinstance(edge_attr, _typing.Iterable):
        return edge_index, [attr[mask] for attr in edge_attr]
def coalesce(
        edge_index: torch.Tensor,
        edge_attr: _typing.Union[
            torch.Tensor, _typing.Iterable[torch.Tensor], None
        ] = None,
        num_nodes: _typing.Optional[int] = ...,
        is_sorted: bool = False,
        sort_by_row: bool = True
) -> _typing.Union[
    torch.Tensor, _typing.Tuple[torch.Tensor, torch.Tensor],
    _typing.Tuple[torch.Tensor, _typing.Iterable[torch.Tensor]]
]:
    """
    Row-wise sorts :obj:`edge_index` and removes its duplicated entries.
    Duplicate entries in :obj:`edge_attr` are directly removed, instead of merged.

    Delegates to :func:`torch_geometric.utils.coalesce` when PyG is
    installed, otherwise falls back to the bundled implementation.

    Args:
        edge_index (LongTensor): The edge indices.
        edge_attr (Tensor or List[Tensor], optional): Edge weights or multi-
            dimensional edge features.
            If given as a list, will re-shuffle and remove duplicates for all
            its entries. (default: :obj:`None`)
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)
        is_sorted (bool, optional): If set to :obj:`True`, will expect
            :obj:`edge_index` to be already sorted row-wise.
        sort_by_row (bool, optional): If set to :obj:`False`, will sort
            :obj:`edge_index` column-wise.
    :rtype: :class:`LongTensor` if :attr:`edge_attr` is :obj:`None`, else
        (:class:`LongTensor`, :obj:`Tensor` or :obj:`Iterable[Tensor]]`)
    """
    # Normalize the Ellipsis default (and any non-int) to None.
    if not isinstance(num_nodes, int):
        num_nodes = None
    try:
        import torch_geometric
        return torch_geometric.utils.coalesce(
            edge_index, edge_attr, num_nodes,
            is_sorted=is_sorted,
            sort_by_row=sort_by_row
        )
    except ModuleNotFoundError:
        # PyG unavailable: use the local pure-torch fallback.
        return __coalesce(
            edge_index, edge_attr, num_nodes,
            is_sorted=is_sorted,
            sort_by_row=sort_by_row
        )
def to_undirected(
        edge_index: torch.Tensor,
        edge_attr: _typing.Optional[_typing.Union[torch.Tensor, _typing.List[torch.Tensor]]] = None,
        num_nodes: _typing.Optional[int] = ...,
        __reduce: str = "add",
) -> _typing.Union[
    torch.Tensor, _typing.Tuple[torch.Tensor, torch.Tensor],
    _typing.Tuple[torch.Tensor, _typing.List[torch.Tensor]]
]:
    r"""Converts the graph given by :attr:`edge_index` to an undirected graph
    such that :math:`(j,i) \in \mathcal{E}` for every edge :math:`(i,j) \in
    \mathcal{E}`.
    Args:
        edge_index (LongTensor): The edge indices.
        edge_attr (Tensor or List[Tensor], optional): Edge weights or multi-
            dimensional edge features.
            If given as a list, will remove duplicates for all its entries.
            (default: :obj:`None`)
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)
        __reduce (string, optional): The reduce operation to use for merging edge
            features (:obj:`"add"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`,
            :obj:`"mul"`). (default: :obj:`"add"`)
    :rtype: :class:`LongTensor` if :attr:`edge_attr` is :obj:`None`, else
        (:class:`LongTensor`, :obj:`Tensor` or :obj:`List[Tensor]]`)
    """
    # Maintain backward compatibility to `to_undirected(edge_index, num_nodes)`
    if isinstance(edge_attr, int):
        # BUGFIX: the original executed ``edge_attr = None`` first and then
        # ``num_nodes = edge_attr``, which always reset num_nodes to None and
        # dropped the caller's node count.  Capture the value before clearing.
        num_nodes = edge_attr
        edge_attr = None
    # Mirror every edge (i, j) into (j, i) ...
    row, col = edge_index
    row, col = torch.cat([row, col], dim=0), torch.cat([col, row], dim=0)
    edge_index = torch.stack([row, col], dim=0)
    # ... duplicating the attributes alongside.
    if edge_attr is not None and isinstance(edge_attr, torch.Tensor):
        edge_attr = torch.cat([edge_attr, edge_attr], dim=0)
    elif edge_attr is not None:
        edge_attr = [torch.cat([e, e], dim=0) for e in edge_attr]
    # coalesce removes the duplicates created by mirroring self-symmetric edges.
    return coalesce(edge_index, edge_attr, num_nodes)
class _SplitResult:
    """Plain value container for the tensors produced by
    ``train_test_split_edges``: positive edge indices/attributes for the
    train/val/test splits, negative edge indices for val/test, and the
    adjacency mask from which training negatives may be sampled."""

    def __init__(
            self,
            train_pos_edge_index: torch.Tensor,
            train_pos_edge_attr: _typing.Optional[torch.Tensor],
            train_neg_adj_mask: torch.Tensor,
            val_pos_edge_index: torch.Tensor,
            val_pos_edge_attr: _typing.Optional[torch.Tensor],
            val_neg_edge_index: torch.Tensor,
            test_pos_edge_index: torch.Tensor,
            test_pos_edge_attr: _typing.Optional[torch.Tensor],
            test_neg_edge_index: torch.Tensor
    ):
        # Every constructor argument is stored verbatim under the same name.
        self.train_pos_edge_index = train_pos_edge_index
        self.train_pos_edge_attr = train_pos_edge_attr
        self.train_neg_adj_mask = train_neg_adj_mask
        self.val_pos_edge_index = val_pos_edge_index
        self.val_pos_edge_attr = val_pos_edge_attr
        self.val_neg_edge_index = val_neg_edge_index
        self.test_pos_edge_index = test_pos_edge_index
        self.test_pos_edge_attr = test_pos_edge_attr
        self.test_neg_edge_index = test_neg_edge_index
def train_test_split_edges(
        edge_index: torch.Tensor,
        edge_attr: _typing.Optional[_typing.Union[torch.Tensor, _typing.List[torch.Tensor]]] = None,
        num_nodes: _typing.Optional[int] = ...,
        val_ratio: float = 0.05,
        test_ratio: float = 0.1
) -> "_SplitResult":
    r"""Splits the edges of a :class:`torch_geometric.data.Data` object
    into positive and negative train/val/test edges.
    As such, it will replace the :obj:`edge_index` attribute with
    :obj:`train_pos_edge_index`, :obj:`train_pos_neg_adj_mask`,
    :obj:`val_pos_edge_index`, :obj:`val_neg_edge_index` and
    :obj:`test_pos_edge_index` attributes.
    If :obj:`data` has edge features named :obj:`edge_attr`, then
    :obj:`train_pos_edge_attr`, :obj:`val_pos_edge_attr` and
    :obj:`test_pos_edge_attr` will be added as well.
    .. warning::
        :meth:`~torch_geometric.utils.train_test_split_edges` is deprecated and
        will be removed in a future release.
        Use :class:`torch_geometric.transforms.RandomLinkSplit` instead.
    Args:
        edge_index (LongTensor): The edge indices.
        edge_attr (Tensor or List[Tensor], optional): Edge weights or multi-
            dimensional edge features.
            If given as a list, will remove duplicates for all its entries.
            (default: :obj:`None`)
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)
        val_ratio (float, optional): The ratio of positive validation edges.
            (default: :obj:`0.05`)
        test_ratio (float, optional): The ratio of positive test edges.
            (default: :obj:`0.1`)
    """
    row, col = edge_index
    num_nodes = _maybe_num_nodes(edge_index, num_nodes)
    # Return upper triangular portion.
    # NOTE(review): assumes an undirected input where each edge appears in
    # both directions, so keeping row < col keeps one copy of each edge.
    mask = row < col
    row, col = row[mask], col[mask]
    if edge_attr is not None:
        edge_attr = edge_attr[mask]
    n_v = int(math.floor(val_ratio * row.size(0)))   # number of validation edges
    n_t = int(math.floor(test_ratio * row.size(0)))  # number of test edges
    # Positive edges.
    # Shuffle once, then slice contiguously: [val | test | train].
    perm = torch.randperm(row.size(0))
    row, col = row[perm], col[perm]
    if edge_attr is not None:
        edge_attr = edge_attr[perm]
    r, c = row[:n_v], col[:n_v]
    val_pos_edge_index = torch.stack([r, c], dim=0)
    if edge_attr is not None:
        val_pos_edge_attr = edge_attr[:n_v]
    else:
        val_pos_edge_attr = None
    r, c = row[n_v:n_v + n_t], col[n_v:n_v + n_t]
    test_pos_edge_index = torch.stack([r, c], dim=0)
    if edge_attr is not None:
        test_pos_edge_attr = edge_attr[n_v:n_v + n_t]
    else:
        test_pos_edge_attr = None
    r, c = row[n_v + n_t:], col[n_v + n_t:]
    train_pos_edge_index = torch.stack([r, c], dim=0)
    if edge_attr is not None:
        # Training edges are mirrored back to both directions.
        train_pos_edge_index, train_pos_edge_attr = to_undirected(
            train_pos_edge_index, edge_attr[n_v + n_t:]
        )
    else:
        train_pos_edge_index = to_undirected(train_pos_edge_index)
        train_pos_edge_attr = None
    # Negative edges.
    # Start from the full upper-triangular adjacency complement ...
    neg_adj_mask = torch.ones(num_nodes, num_nodes, dtype=torch.uint8)
    neg_adj_mask = neg_adj_mask.triu(diagonal=1).to(torch.bool)
    neg_adj_mask[row, col] = 0  # ... positive edges cannot be negative samples
    neg_row, neg_col = neg_adj_mask.nonzero(as_tuple=False).t()
    # Sample n_v + n_t negatives for validation/test ...
    perm = torch.randperm(neg_row.size(0))[:n_v + n_t]
    neg_row, neg_col = neg_row[perm], neg_col[perm]
    # ... and exclude them from the mask used to draw training negatives.
    neg_adj_mask[neg_row, neg_col] = 0
    train_neg_adj_mask = neg_adj_mask
    row, col = neg_row[:n_v], neg_col[:n_v]
    val_neg_edge_index = torch.stack([row, col], dim=0)
    row, col = neg_row[n_v:n_v + n_t], neg_col[n_v:n_v + n_t]
    test_neg_edge_index = torch.stack([row, col], dim=0)
    return _SplitResult(
        train_pos_edge_index, train_pos_edge_attr, train_neg_adj_mask,
        val_pos_edge_index, val_pos_edge_attr, val_neg_edge_index,
        test_pos_edge_index, test_pos_edge_attr, test_neg_edge_index
    )