import copy
import typing as _typing

import torch
import torch_geometric

import autogl
from autogl.data.graph import GeneralStaticGraph
from autogl.data.graph.utils.conversion import static_graph_to_pyg_data
from .train_test_split_edges import train_test_split_edges

# Slack used when checking that train_ratio + val_ratio does not exceed 1.
# Without it, e.g. 1 - 0.8 - 0.2 == -5.55e-17 would reject a valid split.
_RATIO_TOLERANCE: float = 1e-8

# Attributes copied from the result of ``train_test_split_edges`` onto the
# returned data object, in the order they were originally assigned.
_SPLIT_ATTRIBUTES: _typing.Tuple[str, ...] = (
    'train_pos_edge_index',
    'train_pos_edge_attr',
    'train_neg_adj_mask',
    'val_pos_edge_index',
    'val_pos_edge_attr',
    'val_neg_edge_index',
    'test_pos_edge_index',
    'test_pos_edge_attr',
    'test_neg_edge_index',
)

_EdgeBundle = _typing.Tuple[
    torch.Tensor, _typing.Optional[torch.Tensor], torch_geometric.data.Data
]


def _is_edge_index(value: _typing.Any) -> bool:
    """Return True if *value* is a 2-D tensor whose first dimension is 2."""
    # dim() is tested before size(0): size(0) raises IndexError on a 0-dim
    # tensor, and a chained ``dim() == size(0) == 2`` always evaluates both.
    return (
        isinstance(value, torch.Tensor)
        and value.dim() == 2
        and value.size(0) == 2
    )


def _optional_tensor(data: _typing.Any, name: str) -> _typing.Optional[torch.Tensor]:
    """Return ``data.<name>`` if it exists and is a tensor, else None."""
    value = getattr(data, name, None)
    return value if isinstance(value, torch.Tensor) else None


def _resolve_split_ratios(
    train_ratio: float, val_ratio: _typing.Any
) -> _typing.Tuple[float, float]:
    """Return the ``(val_ratio, test_ratio)`` pair to forward to the splitter.

    A validation split is enabled only by a float strictly inside (0, 1); any
    other value is normalised to ``0.0`` so the forwarded ratios agree with
    the computed test ratio.

    Raises:
        ValueError: if ``train_ratio + val_ratio`` exceeds 1 (beyond float
            rounding slack).
    """
    if isinstance(val_ratio, float) and 0 < val_ratio < 1:
        test_ratio = 1 - train_ratio - val_ratio
    else:
        # test_ratio assumes no validation split, so forward a zero val_ratio
        # instead of the caller's (ignored) value.
        val_ratio = 0.0
        test_ratio = 1 - train_ratio
    if test_ratio < -_RATIO_TOLERANCE:
        raise ValueError(
            f"train_ratio ({train_ratio}) + val_ratio ({val_ratio}) "
            f"must not exceed 1"
        )
    # Clamp tiny negative float residue (e.g. 0.8 / 0.2 splits) to exactly 0.
    return val_ratio, max(test_ratio, 0.0)


def _unpack_pyg_data(data: torch_geometric.data.Data) -> _EdgeBundle:
    """Extract edges from a PyG ``Data`` and return a shallow copy to annotate."""
    if not _is_edge_index(data.edge_index):
        raise ValueError(
            "edge_index of the provided torch_geometric.data.Data MUST be "
            "a torch.Tensor of shape (2, num_edges)"
        )
    # Shallow copy so the split attributes are set on a new object rather than
    # the caller's `data`; tensors themselves are shared, not cloned.
    return data.edge_index, data.edge_attr, copy.copy(data)


def _unpack_static_graph(data: GeneralStaticGraph) -> _EdgeBundle:
    """Extract edges from a homogeneous GeneralStaticGraph and convert it to PyG."""
    if not (data.nodes.is_homogeneous and data.edges.is_homogeneous):
        raise ValueError(
            "Provided instance of GeneralStaticGraph MUST be homogeneous"
        )
    edge_index = data.edges.connections
    edge_attr = (
        data.edges.data['edge_attr'] if 'edge_attr' in data.edges.data else None
    )
    return edge_index, edge_attr, static_graph_to_pyg_data(data)


def _unpack_duck_typed(data: _typing.Any) -> _EdgeBundle:
    """Build a fresh PyG ``Data`` from an object exposing ``edge_index``.

    ``edge_attr`` is kept only if it is a tensor with one row per edge;
    ``x`` and ``y`` are kept only if they are tensors.
    """
    edge_index: torch.Tensor = data.edge_index
    edge_attr = _optional_tensor(data, 'edge_attr')
    # dim() guard: size(0) would raise IndexError on a 0-dim tensor.
    if edge_attr is not None and not (
        edge_attr.dim() > 0 and edge_attr.size(0) == edge_index.size(1)
    ):
        edge_attr = None
    result = torch_geometric.data.Data(
        edge_index=edge_index,
        edge_attr=edge_attr,
        x=_optional_tensor(data, 'x'),
        y=_optional_tensor(data, 'y'),
    )
    return edge_index, edge_attr, result


def split_edges_for_data(
        data: _typing.Union[
            torch_geometric.data.Data,
            autogl.data.graph.GeneralStaticGraph,
            _typing.Any
        ],
        train_ratio: float, val_ratio: float
) -> torch_geometric.data.Data:
    """Split the edges of a homogeneous graph into train/val/test sets.

    Args:
        data: A ``torch_geometric.data.Data``, a homogeneous
            ``GeneralStaticGraph``, or any object exposing an ``edge_index``
            tensor of shape ``(2, num_edges)`` (optionally with ``edge_attr``,
            ``x`` and ``y`` tensors).
        train_ratio: Fraction of edges intended for training. It is used
            here only to derive the test fraction.
        val_ratio: Fraction of edges for validation. Only a float strictly
            between 0 and 1 enables a validation split; any other value means
            "no validation split" and ``0.0`` is forwarded instead.

    Returns:
        A ``torch_geometric.data.Data`` with the input's graph fields and the
        ``train_*``/``val_*``/``test_*`` attributes produced by
        ``train_test_split_edges``. For a PyG input this is a shallow copy;
        for a duck-typed input only ``edge_index``, ``edge_attr``, ``x`` and
        ``y`` are carried over.

    Raises:
        ValueError: if ``data`` is of an unsupported type, its ``edge_index``
            is malformed, a ``GeneralStaticGraph`` is heterogeneous, or
            ``train_ratio + val_ratio`` exceeds 1.
    """
    # Validate ratios first so bad arguments fail before any graph conversion.
    val_ratio, test_ratio = _resolve_split_ratios(train_ratio, val_ratio)

    if isinstance(data, torch_geometric.data.Data):
        edge_index, edge_attr, result = _unpack_pyg_data(data)
    elif isinstance(data, GeneralStaticGraph):
        edge_index, edge_attr, result = _unpack_static_graph(data)
    elif _is_edge_index(getattr(data, 'edge_index', None)):
        edge_index, edge_attr, result = _unpack_duck_typed(data)
    else:
        raise ValueError(
            f"Unsupported data of type {type(data)!r}: expected "
            "torch_geometric.data.Data, GeneralStaticGraph, or an object "
            "with an `edge_index` tensor of shape (2, num_edges)"
        )

    split = train_test_split_edges(
        edge_index, edge_attr, val_ratio=val_ratio, test_ratio=test_ratio
    )
    for name in _SPLIT_ATTRIBUTES:
        setattr(result, name, getattr(split, name))
    return result