
Modify datasets.utils.split_edges

tags/v0.3.1
CoreLeader 4 years ago
commit b076761ddc
9 changed files with 703 additions and 68 deletions
  1. +1 -0 autogl/data/graph/_general_static_graph/utils/conversion/__init__.py
  2. +79 -0 autogl/data/graph/_general_static_graph/utils/conversion/_general.py
  3. +1 -1 autogl/datasets/utils/__init__.py
  4. +1 -67 autogl/datasets/utils/_general.py
  5. +1 -0 autogl/datasets/utils/_split_edges/__init__.py
  6. +207 -0 autogl/datasets/utils/_split_edges/_dgl_compatible.py
  7. +71 -0 autogl/datasets/utils/_split_edges/_pyg_compatible.py
  8. +32 -0 autogl/datasets/utils/_split_edges/split_edges.py
  9. +310 -0 autogl/datasets/utils/_split_edges/train_test_split_edges.py
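Taken together, the new modules route dataset edge splitting through a single backend-aware entry point. A minimal usage sketch (the dataset builder call and dataset name are illustrative and not part of this commit):

# Sketch: split every graph in a dataset into train/val/test edge sets.
from autogl.datasets import build_dataset_from_name
from autogl.datasets.utils import split_edges

dataset = build_dataset_from_name('cora')  # illustrative; any homogeneous graph dataset
dataset = split_edges(dataset, train_ratio=0.85, val_ratio=0.05)
# Each item in the returned dataset now carries positive/negative train/val/test
# edge splits, either as extra DGL graphs or as *_pos_edge_index / *_neg_edge_index
# attributes, depending on the item type and the active backend.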

+1 -0 autogl/data/graph/_general_static_graph/utils/conversion/__init__.py

@@ -1,3 +1,4 @@
from ._general import StaticGraphToGeneralData, static_graph_to_general_data
from ._nx import (
HomogeneousStaticGraphToNetworkX
)


+79 -0 autogl/data/graph/_general_static_graph/utils/conversion/_general.py

@@ -0,0 +1,79 @@
import torch
import typing as _typing
import autogl
from ... import GeneralStaticGraph


class StaticGraphToGeneralData:
def __init__(self, *__args, **__kwargs):
pass

def __call__(
self, static_graph: GeneralStaticGraph,
*__args, **__kwargs
):
if not isinstance(static_graph, GeneralStaticGraph):
raise TypeError
elif not static_graph.nodes.is_homogeneous:
raise ValueError("Provided static graph MUST consist of homogeneous nodes")
homogeneous_node_type: _typing.Optional[str] = (
list(static_graph.nodes)[0]
if len(list(static_graph.nodes)) > 0 else None
)
data: _typing.Dict[str, torch.Tensor] = dict()
if isinstance(homogeneous_node_type, str):
node_and_edge_data_keys_intersection: _typing.Set[str] = (
set(static_graph.nodes.data) & set(static_graph.data)
)
if len(node_and_edge_data_keys_intersection) > 0:
raise ValueError(
f"Provided static graph contains duplicate data "
f"with same keys {node_and_edge_data_keys_intersection}"
f"for homogeneous nodes data and graph-level data, "
f"please refer to doc for more details."
)
data.update(static_graph.nodes.data)
data.update(static_graph.data)
else:
data.update(static_graph.data)

if len(list(static_graph.edges)) == 1:
data["edge_index"] = static_graph.edges.connections
if len(set(data.keys()) & set(static_graph.edges.data.keys())) > 0:
raise ValueError(
"Provided static graph contains duplicate data with same key, "
"please refer to doc for more details."
)
data.update(static_graph.edges.data)
elif len(list(static_graph.edges)) > 1:
for canonical_edge_type in static_graph.edges:
if homogeneous_node_type is not None and isinstance(homogeneous_node_type, str) and (
canonical_edge_type.source_node_type != homogeneous_node_type or
canonical_edge_type.target_node_type != homogeneous_node_type
):
continue
if len(canonical_edge_type.relation_type) < 4 or canonical_edge_type.relation_type[-4:] != 'edge':
continue
data[f"{canonical_edge_type.relation_type}_index"] = (
static_graph.edges[canonical_edge_type].connections
)

edge_type_prefix: str = canonical_edge_type.relation_type[:-4]
for data_key in static_graph.edges[canonical_edge_type].data:
if len(data_key) >= 4 and data_key[:4] == 'edge':
data[f"{edge_type_prefix}{data_key}"] = (
static_graph.edges[canonical_edge_type].data[data_key].detach()
)
else:
data[f"{canonical_edge_type.relation_type}_{data_key}"] = (
static_graph.edges[canonical_edge_type].data[data_key].detach()
)

general_data = autogl.data.Data()
for key, value in data.items():
setattr(general_data, key, value)
return general_data


def static_graph_to_general_data(static_graph: GeneralStaticGraph) -> autogl.data.Data:
return StaticGraphToGeneralData().__call__(static_graph)
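A quick illustration of how the new conversion helper is meant to be called (a sketch assuming graph is a homogeneous GeneralStaticGraph; the attribute names simply mirror whatever keys the graph stores):

# Sketch: flatten a homogeneous GeneralStaticGraph into a plain Data object.
from autogl.data.graph.utils.conversion import static_graph_to_general_data

data = static_graph_to_general_data(graph)
# data.edge_index       <- graph.edges.connections (single edge type)
# data.x, data.y, ...   <- node-level and graph-level entries, subject to the
#                          duplicate-key checks implemented above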

+1 -1 autogl/datasets/utils/__init__.py

@@ -1,6 +1,6 @@
from ._split_edges import split_edges
from ._general import (
index_to_mask,
split_edges,
random_splits_mask,
random_splits_mask_class,
graph_cross_validation,


+1 -67 autogl/datasets/utils/_general.py

@@ -5,9 +5,7 @@ import torch.utils.data
import typing as _typing
from sklearn.model_selection import StratifiedKFold, KFold
from autogl import backend as _backend
from autogl.data import Data, Dataset, InMemoryStaticGraphSet
from ...data.graph import GeneralStaticGraph, GeneralStaticGraphGenerator
from . import _pyg
from autogl.data import InMemoryStaticGraphSet


def index_to_mask(index: torch.Tensor, size):
@@ -16,70 +14,6 @@ def index_to_mask(index: torch.Tensor, size):
return mask


def split_edges(
dataset: InMemoryStaticGraphSet,
train_ratio: float, val_ratio: float
) -> InMemoryStaticGraphSet:
test_ratio: float = 1 - train_ratio - val_ratio

def _split_edges_for_graph(homogeneous_static_graph: GeneralStaticGraph) -> GeneralStaticGraph:
if not isinstance(homogeneous_static_graph, GeneralStaticGraph):
raise TypeError
elif not homogeneous_static_graph.edges.is_homogeneous:
raise ValueError("The provided graph MUST consist of homogeneous edges.")
else:
split_data = _pyg.train_test_split_edges(
Data(
edge_index=homogeneous_static_graph.edges.connections.detach().clone(),
edge_attr=(
homogeneous_static_graph.edges.data['edge_attr'].detach().clone()
if 'edge_attr' in homogeneous_static_graph.edges.data else None
)
),
val_ratio, test_ratio
)
original_edge_type = [et for et in homogeneous_static_graph.edges][0]

split_static_graph = GeneralStaticGraphGenerator.create_heterogeneous_static_graph(
dict([
(node_type, homogeneous_static_graph.nodes[node_type].data)
for node_type in homogeneous_static_graph.nodes
]),
{
(original_edge_type.source_node_type, "train_pos_edge", original_edge_type.target_node_type): (
getattr(split_data, "train_pos_edge_index"),
{"edge_attr": getattr(split_data, "train_pos_edge_attr")}
if isinstance(getattr(split_data, "train_pos_edge_attr"), torch.Tensor)
else None
),
(original_edge_type.source_node_type, "val_pos_edge", original_edge_type.target_node_type): (
getattr(split_data, "val_pos_edge_index"),
{"edge_attr": getattr(split_data, "val_pos_edge_attr")}
if isinstance(getattr(split_data, "val_pos_edge_attr"), torch.Tensor)
else None
),
(original_edge_type.source_node_type, "val_neg_edge", original_edge_type.target_node_type):
getattr(split_data, "val_neg_edge_index"),
(original_edge_type.source_node_type, "test_pos_edge", original_edge_type.target_node_type): (
getattr(split_data, "test_pos_edge_index"),
{"edge_attr": getattr(split_data, "test_pos_edge_attr")}
if isinstance(getattr(split_data, "test_pos_edge_attr"), torch.Tensor)
else None
),
(original_edge_type.source_node_type, "test_neg_edge", original_edge_type.target_node_type):
getattr(split_data, "test_neg_edge_index")
},
homogeneous_static_graph.data
)
return split_static_graph

if not isinstance(dataset, InMemoryStaticGraphSet):
raise TypeError
for index in range(len(dataset)):
dataset[index] = _split_edges_for_graph(dataset[index])
return dataset


def random_splits_mask(
dataset: InMemoryStaticGraphSet,
train_ratio: float = 0.2, val_ratio: float = 0.4,


+1 -0 autogl/datasets/utils/_split_edges/__init__.py

@@ -0,0 +1 @@
from .split_edges import split_edges

+207 -0 autogl/datasets/utils/_split_edges/_dgl_compatible.py

@@ -0,0 +1,207 @@
import numpy as np
import scipy.sparse as sp
import torch
import typing as _typing
import dgl
import autogl.data.graph
from .train_test_split_edges import train_test_split_edges
from autogl.data.graph.utils.conversion import static_graph_to_general_data


class _SplitEdgesDGLImpl:
@classmethod
def __split_edges_train_val_test(
cls, g: dgl.DGLGraph,
train_ratio: float, val_ratio: float
) -> _typing.Tuple[
dgl.DGLGraph, dgl.DGLGraph, dgl.DGLGraph,
dgl.DGLGraph, dgl.DGLGraph, dgl.DGLGraph, dgl.DGLGraph
]:
u, v = g.edges()

eids = np.arange(g.number_of_edges())
eids = np.random.permutation(eids)

valid_size = int(len(eids) * val_ratio)
test_size = int(len(eids) * (1 - train_ratio - val_ratio))
train_size = g.number_of_edges() - test_size - valid_size

test_pos_u, test_pos_v = u[eids[:test_size]], v[eids[:test_size]]
valid_pos_u, valid_pos_v = u[eids[test_size:test_size + valid_size]], v[eids[test_size:test_size + valid_size]]
train_pos_u, train_pos_v = u[eids[test_size + valid_size:]], v[eids[test_size + valid_size:]]

# Find all negative edges and split them for training and testing
adj = sp.coo_matrix((np.ones(len(u)), (u.numpy(), v.numpy())), shape=(g.number_of_nodes(), g.number_of_nodes()))
adj_neg = 1 - adj.todense() - np.eye(g.number_of_nodes())
neg_u, neg_v = np.where(adj_neg != 0)

neg_eids = np.random.choice(len(neg_u), g.number_of_edges())
test_neg_u, test_neg_v = neg_u[neg_eids[:test_size]], neg_v[neg_eids[:test_size]]
valid_neg_u, valid_neg_v = neg_u[neg_eids[test_size:test_size + valid_size]], neg_v[neg_eids[test_size:test_size + valid_size]]
train_neg_u, train_neg_v = neg_u[neg_eids[test_size + valid_size:]], neg_v[neg_eids[test_size + valid_size:]]

train_g = dgl.remove_edges(g, eids[:test_size + valid_size])

train_pos_g = dgl.graph((train_pos_u, train_pos_v), num_nodes=g.number_of_nodes())
train_neg_g = dgl.graph((train_neg_u, train_neg_v), num_nodes=g.number_of_nodes())

valid_pos_g = dgl.graph((valid_pos_u, valid_pos_v), num_nodes=g.number_of_nodes())
valid_neg_g = dgl.graph((valid_neg_u, valid_neg_v), num_nodes=g.number_of_nodes())

test_pos_g = dgl.graph((test_pos_u, test_pos_v), num_nodes=g.number_of_nodes())
test_neg_g = dgl.graph((test_neg_u, test_neg_v), num_nodes=g.number_of_nodes())

return (
train_g, train_pos_g, train_neg_g,
valid_pos_g, valid_neg_g, test_pos_g, test_neg_g
)

@classmethod
def __split_edges_train_test(
cls, g: dgl.DGLGraph, train_ratio: float
) -> _typing.Tuple[
dgl.DGLGraph, dgl.DGLGraph, dgl.DGLGraph,
dgl.DGLGraph, dgl.DGLGraph,
]:
u, v = g.edges()

eids = np.arange(g.number_of_edges())
eids = np.random.permutation(eids)
test_size = int(len(eids) * (1 - train_ratio))
train_size = g.number_of_edges() - test_size
test_pos_u, test_pos_v = u[eids[:test_size]], v[eids[:test_size]]
train_pos_u, train_pos_v = u[eids[test_size:]], v[eids[test_size:]]

# Find all negative edges and split them for training and testing
adj = sp.coo_matrix((np.ones(len(u)), (u.numpy(), v.numpy())), shape=(g.number_of_nodes(), g.number_of_nodes()))
adj_neg = 1 - adj.todense() - np.eye(g.number_of_nodes())
neg_u, neg_v = np.where(adj_neg != 0)

neg_eids = np.random.choice(len(neg_u), g.number_of_edges())
test_neg_u, test_neg_v = neg_u[neg_eids[:test_size]], neg_v[neg_eids[:test_size]]
train_neg_u, train_neg_v = neg_u[neg_eids[test_size:]], neg_v[neg_eids[test_size:]]

train_g = dgl.remove_edges(g, eids[:test_size])

train_pos_g = dgl.graph((train_pos_u, train_pos_v), num_nodes=g.number_of_nodes())
train_neg_g = dgl.graph((train_neg_u, train_neg_v), num_nodes=g.number_of_nodes())

test_pos_g = dgl.graph((test_pos_u, test_pos_v), num_nodes=g.number_of_nodes())
test_neg_g = dgl.graph((test_neg_u, test_neg_v), num_nodes=g.number_of_nodes())

return train_g, train_pos_g, train_neg_g, test_pos_g, test_neg_g

@classmethod
def split_edges_for_dgl_graph(
cls, graph: dgl.DGLGraph,
train_ratio: float, val_ratio: _typing.Optional[float] = ...
) -> _typing.Union[
_typing.Tuple[
dgl.DGLGraph, dgl.DGLGraph,
dgl.DGLGraph, dgl.DGLGraph,
dgl.DGLGraph, dgl.DGLGraph,
dgl.DGLGraph
],
_typing.Tuple[
dgl.DGLGraph, dgl.DGLGraph, dgl.DGLGraph,
dgl.DGLGraph, dgl.DGLGraph,
]
]:
if not 0 < train_ratio < 1:
raise ValueError(f"Invalid train_ratio as {train_ratio}")
if isinstance(val_ratio, float):
if not 0 < val_ratio < 1:
raise ValueError(f"Invalid val_ratio as {val_ratio}")
if not 0 < train_ratio + val_ratio < 1:
raise ValueError(
f"Invalid combination (train_ratio, val_ratio) "
f"as ({train_ratio}, {val_ratio})"
)
return cls.__split_edges_train_val_test(graph, train_ratio, val_ratio)
else:
return cls.__split_edges_train_test(graph, train_ratio)


def split_edges_for_data(
data: _typing.Union[
dgl.DGLGraph, autogl.data.graph.GeneralStaticGraph, _typing.Any
],
train_ratio: float, val_ratio: _typing.Optional[float]
) -> _typing.Union[
autogl.data.Data,
_typing.Tuple[
dgl.DGLGraph, dgl.DGLGraph,
dgl.DGLGraph, dgl.DGLGraph,
dgl.DGLGraph, dgl.DGLGraph,
dgl.DGLGraph
],
_typing.Tuple[
dgl.DGLGraph, dgl.DGLGraph, dgl.DGLGraph,
dgl.DGLGraph, dgl.DGLGraph,
]
]:
if isinstance(data, dgl.DGLGraph):
if not data.is_homogeneous:
raise ValueError(
"provided DGL graph to split edges MUST be homogeneous"
)
else:
return _SplitEdgesDGLImpl.split_edges_for_dgl_graph(
data, train_ratio, val_ratio
)
elif isinstance(data, autogl.data.graph.GeneralStaticGraph):
if not (data.nodes.is_homogeneous and data.edges.is_homogeneous):
raise ValueError(
"Provided instance of GeneralStaticGraph MUST be homogeneous"
)
__data = static_graph_to_general_data(data)
edge_index: torch.LongTensor = data.edges.connections
edge_attr: _typing.Optional[torch.Tensor] = (
data.edges.data['edge_attr'] if 'edge_attr' in data.edges.data else None
)
elif (
hasattr(data, 'edge_index') and
isinstance(data.edge_index, torch.Tensor) and
data.edge_index.dim() == data.edge_index.size(0) == 2
):
edge_index: torch.LongTensor = data.edge_index
if (
hasattr(data, 'edge_attr') and
isinstance(data.edge_attr, torch.Tensor) and
data.edge_attr.size(0) == edge_index.size(1)
):
edge_attr: _typing.Optional[torch.Tensor] = data.edge_attr
else:
edge_attr: _typing.Optional[torch.Tensor] = None
if hasattr(data, 'x') and isinstance(data.x, torch.Tensor):
x: _typing.Optional[torch.Tensor] = data.x
else:
x: _typing.Optional[torch.Tensor] = None
if hasattr(data, 'y') and isinstance(data.y, torch.Tensor):
y: _typing.Optional[torch.Tensor] = data.y
else:
y: _typing.Optional[torch.Tensor] = None
__data = autogl.data.Data(
edge_index=edge_index, edge_attr=edge_attr, x=x, y=y
)
else:
raise ValueError

if isinstance(val_ratio, float) and 0 < val_ratio < 1:
test_ratio = 1 - train_ratio - val_ratio
else:
val_ratio = 0.0
test_ratio = 1 - train_ratio
compound_results = train_test_split_edges(
edge_index, edge_attr,
val_ratio=val_ratio, test_ratio=test_ratio
)
__data.train_pos_edge_index = compound_results.train_pos_edge_index
__data.train_pos_edge_attr = compound_results.train_pos_edge_attr
__data.train_neg_adj_mask = compound_results.train_neg_adj_mask
__data.val_pos_edge_index = compound_results.val_pos_edge_index
__data.val_pos_edge_attr = compound_results.val_pos_edge_attr
__data.val_neg_edge_index = compound_results.val_neg_edge_index
__data.test_pos_edge_index = compound_results.test_pos_edge_index
__data.test_pos_edge_attr = compound_results.test_pos_edge_attr
__data.test_neg_edge_index = compound_results.test_neg_edge_index
return __data
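Under the DGL backend, passing a homogeneous DGLGraph together with a val_ratio yields seven graphs (training graph plus positive/negative train/val/test edge graphs); without val_ratio it yields five. A sketch with a toy graph (the graph itself is illustrative):

# Sketch: edge split of a small homogeneous DGL graph.
import torch
import dgl
from autogl.datasets.utils._split_edges._dgl_compatible import split_edges_for_data

g = dgl.graph((torch.tensor([0, 1, 2, 3]), torch.tensor([1, 2, 3, 0])), num_nodes=4)
(train_g, train_pos_g, train_neg_g,
 val_pos_g, val_neg_g,
 test_pos_g, test_neg_g) = split_edges_for_data(g, train_ratio=0.5, val_ratio=0.25)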

+71 -0 autogl/datasets/utils/_split_edges/_pyg_compatible.py

@@ -0,0 +1,71 @@
import torch
import typing as _typing
import torch_geometric
import autogl
from autogl.data.graph import GeneralStaticGraph
from autogl.data.graph.utils.conversion import static_graph_to_pyg_data
from .train_test_split_edges import train_test_split_edges


def split_edges_for_data(
data: _typing.Union[
torch_geometric.data.Data, autogl.data.graph.GeneralStaticGraph, _typing.Any
],
train_ratio: float, val_ratio: float
) -> torch_geometric.data.Data:
if isinstance(data, autogl.data.graph.GeneralStaticGraph):
if not (data.nodes.is_homogeneous and data.edges.is_homogeneous):
raise ValueError(
"Provided instance of GeneralStaticGraph MUST be homogeneous"
)
edge_index: torch.LongTensor = data.edges.connections
edge_attr: _typing.Optional[torch.Tensor] = (
data.edges.data['edge_attr'] if 'edge_attr' in data.edges.data else None
)
__data = static_graph_to_pyg_data(data)
elif (
hasattr(data, 'edge_index') and
isinstance(data.edge_index, torch.Tensor) and
data.edge_index.dim() == data.edge_index.size(0) == 2
):
edge_index: torch.LongTensor = data.edge_index
if (
hasattr(data, 'edge_attr') and
isinstance(data.edge_attr, torch.Tensor) and
data.edge_attr.size(0) == edge_index.size(1)
):
edge_attr: _typing.Optional[torch.Tensor] = data.edge_attr
else:
edge_attr: _typing.Optional[torch.Tensor] = None
if hasattr(data, 'x') and isinstance(data.x, torch.Tensor):
x: _typing.Optional[torch.Tensor] = data.x
else:
x: _typing.Optional[torch.Tensor] = None
if hasattr(data, 'y') and isinstance(data.y, torch.Tensor):
y: _typing.Optional[torch.Tensor] = data.y
else:
y: _typing.Optional[torch.Tensor] = None
__data = torch_geometric.data.Data(
edge_index=edge_index, edge_attr=edge_attr, x=x, y=y
)
else:
raise ValueError

if isinstance(val_ratio, float) and 0 < val_ratio < 1:
test_ratio = 1 - train_ratio - val_ratio
else:
test_ratio = 1 - train_ratio
compound_results = train_test_split_edges(
edge_index, edge_attr,
val_ratio=val_ratio, test_ratio=test_ratio
)
__data.train_pos_edge_index = compound_results.train_pos_edge_index
__data.train_pos_edge_attr = compound_results.train_pos_edge_attr
__data.train_neg_adj_mask = compound_results.train_neg_adj_mask
__data.val_pos_edge_index = compound_results.val_pos_edge_index
__data.val_pos_edge_attr = compound_results.val_pos_edge_attr
__data.val_neg_edge_index = compound_results.val_neg_edge_index
__data.test_pos_edge_index = compound_results.test_pos_edge_index
__data.test_pos_edge_attr = compound_results.test_pos_edge_attr
__data.test_neg_edge_index = compound_results.test_neg_edge_index
return __data

+32 -0 autogl/datasets/utils/_split_edges/split_edges.py

@@ -0,0 +1,32 @@
import typing as _typing
from autogl.data import InMemoryDataset
import autogl

if autogl.backend.DependentBackend.is_dgl():
from ._dgl_compatible import split_edges_for_data
elif autogl.backend.DependentBackend.is_pyg():
from ._pyg_compatible import split_edges_for_data
else:
raise NotImplementedError


def split_edges(
dataset: InMemoryDataset,
train_ratio: float, val_ratio: _typing.Optional[float]
) -> InMemoryDataset:
if isinstance(val_ratio, float) and not 0 < train_ratio + val_ratio < 1:
raise ValueError
elif not 0 < train_ratio < 1:
raise ValueError
if (
autogl.backend.DependentBackend.is_pyg() and
not (isinstance(val_ratio, float) and 0 < val_ratio < 1)
):
raise ValueError(
"For PyG as backend, val_ratio MUST be specific float between 0 and 1, "
"i.e. 0 < val_ratio < 1"
)
return InMemoryDataset(
[split_edges_for_data(item, train_ratio, val_ratio) for item in dataset],
dataset.train_index, dataset.val_index, dataset.test_index
)

+310 -0 autogl/datasets/utils/_split_edges/train_test_split_edges.py

@@ -0,0 +1,310 @@
import math
import torch
import typing as _typing


def _maybe_num_nodes(edge_index, num_nodes=None):
if isinstance(num_nodes, int):
return num_nodes
elif isinstance(edge_index, torch.Tensor):
return int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0
else:
return max(edge_index.size(0), edge_index.size(1))


def __coalesce(
edge_index: torch.Tensor,
edge_attr: _typing.Union[
torch.Tensor, _typing.Iterable[torch.Tensor], None
] = None,
num_nodes: _typing.Optional[int] = ...,
is_sorted: bool = False,
sort_by_row: bool = True
) -> _typing.Union[
torch.Tensor, _typing.Tuple[torch.Tensor, torch.Tensor],
_typing.Tuple[torch.Tensor, _typing.Iterable[torch.Tensor]]
]:
"""
Row-wise sorts :obj:`edge_index` and removes its duplicated entries.
Duplicate entries in :obj:`edge_attr` are directly removed, instead of merged.
Args:
edge_index (LongTensor): The edge indices.
edge_attr (Tensor or List[Tensor], optional): Edge weights or multi-
dimensional edge features.
If given as a list, will re-shuffle and remove duplicates for all
its entries. (default: :obj:`None`)
num_nodes (int, optional): The number of nodes, *i.e.*
:obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)
is_sorted (bool, optional): If set to :obj:`True`, will expect
:obj:`edge_index` to be already sorted row-wise.
sort_by_row (bool, optional): If set to :obj:`False`, will sort
:obj:`edge_index` column-wise.
:rtype: :class:`LongTensor` if :attr:`edge_attr` is :obj:`None`, else
(:class:`LongTensor`, :obj:`Tensor` or :obj:`Iterable[Tensor]`)
"""

if edge_attr is None:
pass
elif isinstance(edge_attr, torch.Tensor) and torch.is_tensor(edge_attr):
if edge_attr.size(0) != edge_index.size(1):
raise ValueError
elif isinstance(edge_attr, _typing.Iterable):
if not all([
(
isinstance(attr, torch.Tensor) and
attr.size(0) == edge_index.size(1)
) for attr in edge_attr
]):
raise ValueError("Invalid edge_attr argument")
else:
raise TypeError("Unsupported type of edge_attr argument")

nnz = edge_index.size(1)
num_nodes = _maybe_num_nodes(edge_index, num_nodes)

idx = edge_index.new_empty(nnz + 1)
idx[0] = -1
idx[1:] = edge_index[1 - int(sort_by_row)]
idx[1:].mul_(num_nodes).add_(edge_index[int(sort_by_row)])

if not is_sorted:
idx[1:], perm = idx[1:].sort()
edge_index = edge_index[:, perm]
if edge_attr is not None and isinstance(edge_attr, torch.Tensor):
edge_attr = edge_attr[perm]
elif edge_attr is not None:
edge_attr = [e[perm] for e in edge_attr]

mask: _typing.Any = idx[1:] > idx[:-1]

# Only perform expensive merging in case there exists duplicates:
if mask.all():
return edge_index if edge_attr is None else (edge_index, edge_attr)

edge_index = edge_index[:, mask]
if edge_attr is None:
return edge_index
elif isinstance(edge_attr, torch.Tensor):
return edge_index, edge_attr[mask]
elif isinstance(edge_attr, _typing.Iterable):
return edge_index, [attr[mask] for attr in edge_attr]


def coalesce(
edge_index: torch.Tensor,
edge_attr: _typing.Union[
torch.Tensor, _typing.Iterable[torch.Tensor], None
] = None,
num_nodes: _typing.Optional[int] = ...,
is_sorted: bool = False,
sort_by_row: bool = True
) -> _typing.Union[
torch.Tensor, _typing.Tuple[torch.Tensor, torch.Tensor],
_typing.Tuple[torch.Tensor, _typing.Iterable[torch.Tensor]]
]:
"""
Row-wise sorts :obj:`edge_index` and removes its duplicated entries.
Duplicate entries in :obj:`edge_attr` are directly removed, instead of merged.
Args:
edge_index (LongTensor): The edge indices.
edge_attr (Tensor or List[Tensor], optional): Edge weights or multi-
dimensional edge features.
If given as a list, will re-shuffle and remove duplicates for all
its entries. (default: :obj:`None`)
num_nodes (int, optional): The number of nodes, *i.e.*
:obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)
is_sorted (bool, optional): If set to :obj:`True`, will expect
:obj:`edge_index` to be already sorted row-wise.
sort_by_row (bool, optional): If set to :obj:`False`, will sort
:obj:`edge_index` column-wise.
:rtype: :class:`LongTensor` if :attr:`edge_attr` is :obj:`None`, else
(:class:`LongTensor`, :obj:`Tensor` or :obj:`Iterable[Tensor]`)
"""
if not isinstance(num_nodes, int):
num_nodes = None
try:
import torch_geometric
return torch_geometric.utils.coalesce(
edge_index, edge_attr, num_nodes,
is_sorted=is_sorted,
sort_by_row=sort_by_row
)
except ModuleNotFoundError:
return __coalesce(
edge_index, edge_attr, num_nodes,
is_sorted=is_sorted,
sort_by_row=sort_by_row
)


def to_undirected(
edge_index: torch.Tensor,
edge_attr: _typing.Optional[_typing.Union[torch.Tensor, _typing.List[torch.Tensor]]] = None,
num_nodes: _typing.Optional[int] = ...,
__reduce: str = "add",
) -> _typing.Union[
torch.Tensor, _typing.Tuple[torch.Tensor, torch.Tensor],
_typing.Tuple[torch.Tensor, _typing.List[torch.Tensor]]
]:
r"""Converts the graph given by :attr:`edge_index` to an undirected graph
such that :math:`(j,i) \in \mathcal{E}` for every edge :math:`(i,j) \in
\mathcal{E}`.

Args:
edge_index (LongTensor): The edge indices.
edge_attr (Tensor or List[Tensor], optional): Edge weights or multi-
dimensional edge features.
If given as a list, will remove duplicates for all its entries.
(default: :obj:`None`)
num_nodes (int, optional): The number of nodes, *i.e.*
:obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)
__reduce (string, optional): The reduce operation to use for merging edge
features (:obj:`"add"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`,
:obj:`"mul"`). (default: :obj:`"add"`)

:rtype: :class:`LongTensor` if :attr:`edge_attr` is :obj:`None`, else
(:class:`LongTensor`, :obj:`Tensor` or :obj:`List[Tensor]`)
"""
# Maintain backward compatibility to `to_undirected(edge_index, num_nodes)`
if isinstance(edge_attr, int):
num_nodes = edge_attr
edge_attr = None

row, col = edge_index
row, col = torch.cat([row, col], dim=0), torch.cat([col, row], dim=0)
edge_index = torch.stack([row, col], dim=0)

if edge_attr is not None and isinstance(edge_attr, torch.Tensor):
edge_attr = torch.cat([edge_attr, edge_attr], dim=0)
elif edge_attr is not None:
edge_attr = [torch.cat([e, e], dim=0) for e in edge_attr]

return coalesce(edge_index, edge_attr, num_nodes)


class _SplitResult:
def __init__(
self,
train_pos_edge_index: torch.Tensor,
train_pos_edge_attr: _typing.Optional[torch.Tensor],
train_neg_adj_mask: torch.Tensor,
val_pos_edge_index: torch.Tensor,
val_pos_edge_attr: _typing.Optional[torch.Tensor],
val_neg_edge_index: torch.Tensor,
test_pos_edge_index: torch.Tensor,
test_pos_edge_attr: _typing.Optional[torch.Tensor],
test_neg_edge_index: torch.Tensor
):
self.train_pos_edge_index: torch.Tensor = train_pos_edge_index
self.train_pos_edge_attr: _typing.Optional[torch.Tensor] = train_pos_edge_attr
self.train_neg_adj_mask: torch.Tensor = train_neg_adj_mask
self.val_pos_edge_index: torch.Tensor = val_pos_edge_index
self.val_pos_edge_attr: _typing.Optional[torch.Tensor] = val_pos_edge_attr
self.val_neg_edge_index: torch.Tensor = val_neg_edge_index
self.test_pos_edge_index: torch.Tensor = test_pos_edge_index
self.test_pos_edge_attr: _typing.Optional[torch.Tensor] = test_pos_edge_attr
self.test_neg_edge_index: torch.Tensor = test_neg_edge_index


def train_test_split_edges(
edge_index: torch.Tensor,
edge_attr: _typing.Optional[_typing.Union[torch.Tensor, _typing.List[torch.Tensor]]] = None,
num_nodes: _typing.Optional[int] = ...,
val_ratio: float = 0.05,
test_ratio: float = 0.1
):
r"""Splits the edges of a :class:`torch_geometric.data.Data` object
into positive and negative train/val/test edges.
As such, it will replace the :obj:`edge_index` attribute with
:obj:`train_pos_edge_index`, :obj:`train_pos_neg_adj_mask`,
:obj:`val_pos_edge_index`, :obj:`val_neg_edge_index` and
:obj:`test_pos_edge_index` attributes.
If :obj:`data` has edge features named :obj:`edge_attr`, then
:obj:`train_pos_edge_attr`, :obj:`val_pos_edge_attr` and
:obj:`test_pos_edge_attr` will be added as well.

.. warning::

This helper mirrors :meth:`torch_geometric.utils.train_test_split_edges`, which is
deprecated upstream and will be removed in a future PyG release;
upstream recommends :class:`torch_geometric.transforms.RandomLinkSplit` instead.

Args:
edge_index (LongTensor): The edge indices.
edge_attr (Tensor or List[Tensor], optional): Edge weights or multi-
dimensional edge features.
If given as a list, will remove duplicates for all its entries.
(default: :obj:`None`)
num_nodes (int, optional): The number of nodes, *i.e.*
:obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)
val_ratio (float, optional): The ratio of positive validation edges.
(default: :obj:`0.05`)
test_ratio (float, optional): The ratio of positive test edges.
(default: :obj:`0.1`)
"""
row, col = edge_index
num_nodes = _maybe_num_nodes(edge_index, num_nodes)

# Return upper triangular portion.
mask = row < col
row, col = row[mask], col[mask]

if edge_attr is not None:
edge_attr = edge_attr[mask]

n_v = int(math.floor(val_ratio * row.size(0)))
n_t = int(math.floor(test_ratio * row.size(0)))

# Positive edges.
perm = torch.randperm(row.size(0))
row, col = row[perm], col[perm]
if edge_attr is not None:
edge_attr = edge_attr[perm]

r, c = row[:n_v], col[:n_v]
val_pos_edge_index = torch.stack([r, c], dim=0)
if edge_attr is not None:
val_pos_edge_attr = edge_attr[:n_v]
else:
val_pos_edge_attr = None

r, c = row[n_v:n_v + n_t], col[n_v:n_v + n_t]
test_pos_edge_index = torch.stack([r, c], dim=0)
if edge_attr is not None:
test_pos_edge_attr = edge_attr[n_v:n_v + n_t]
else:
test_pos_edge_attr = None

r, c = row[n_v + n_t:], col[n_v + n_t:]
train_pos_edge_index = torch.stack([r, c], dim=0)
if edge_attr is not None:
train_pos_edge_index, train_pos_edge_attr = to_undirected(
train_pos_edge_index, edge_attr[n_v + n_t:]
)
else:
train_pos_edge_index = to_undirected(train_pos_edge_index)
train_pos_edge_attr = None

# Negative edges.
neg_adj_mask = torch.ones(num_nodes, num_nodes, dtype=torch.uint8)
neg_adj_mask = neg_adj_mask.triu(diagonal=1).to(torch.bool)
neg_adj_mask[row, col] = 0

neg_row, neg_col = neg_adj_mask.nonzero(as_tuple=False).t()
perm = torch.randperm(neg_row.size(0))[:n_v + n_t]
neg_row, neg_col = neg_row[perm], neg_col[perm]

neg_adj_mask[neg_row, neg_col] = 0
train_neg_adj_mask = neg_adj_mask

row, col = neg_row[:n_v], neg_col[:n_v]
val_neg_edge_index = torch.stack([row, col], dim=0)

row, col = neg_row[n_v:n_v + n_t], neg_col[n_v:n_v + n_t]
test_neg_edge_index = torch.stack([row, col], dim=0)

return _SplitResult(
train_pos_edge_index, train_pos_edge_attr, train_neg_adj_mask,
val_pos_edge_index, val_pos_edge_attr, val_neg_edge_index,
test_pos_edge_index, test_pos_edge_attr, test_neg_edge_index
)
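For reference, the standalone splitter can also be exercised directly on an edge_index tensor (a sketch; the toy graph is illustrative):

# Sketch: split the undirected edges of a small 6-node graph.
import torch
from autogl.datasets.utils._split_edges.train_test_split_edges import train_test_split_edges

edge_index = torch.tensor([[0, 0, 1, 1, 2, 3, 3, 4],
                           [1, 2, 2, 3, 4, 4, 5, 5]])
result = train_test_split_edges(edge_index, val_ratio=0.2, test_ratio=0.2)
# result.train_pos_edge_index is returned in undirected form; val/test positives keep
# the held-out edges, and val/test negatives are sampled from the remaining non-edges.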
