import random
import typing as _typing

import numpy as np
import torch
import torch.utils.data
from sklearn.model_selection import StratifiedKFold, KFold

from autogl import backend as _backend
from autogl.data import InMemoryDataset


def index_to_mask(index: torch.Tensor, size):
    r"""Convert a tensor of indices into a boolean mask.

    Parameters
    ----------
    index : torch.Tensor
        positions that should be set to ``True``.
    size : int
        length of the resulting mask.

    Returns
    -------
    torch.Tensor
        boolean tensor of length ``size`` allocated on ``index.device``.
    """
    mask = torch.zeros(size, dtype=torch.bool, device=index.device)
    mask[index] = True
    return mask


def random_splits_mask(
        dataset: InMemoryDataset,
        train_ratio: float = 0.2,
        val_ratio: float = 0.4,
        seed: _typing.Optional[int] = None
) -> InMemoryDataset:
    r"""Randomly assign train/val/test node masks with the given ratios.

    For every graph in the dataset and every node type that carries at least
    one data entry, ``train_mask``, ``val_mask`` and ``test_mask`` are written
    into the node data. Nodes not selected for training or validation are
    used for testing.

    Parameters
    ----------
    dataset : InMemoryDataset
        graph set, modified in place.
    train_ratio : float
        the portion of data that used for training.
    val_ratio : float
        the portion of data that used for validation.
    seed : int
        random seed for splitting dataset. The global torch RNG state is
        restored after each permutation is drawn.

    Returns
    -------
    InMemoryDataset
        the same ``dataset`` object that was passed in.

    Raises
    ------
    ValueError
        if ``train_ratio + val_ratio`` exceeds 1.
    """
    if not train_ratio + val_ratio <= 1:
        raise ValueError("the sum of provided train_ratio and val_ratio is larger than 1")

    def _random_split_masks(
            num_nodes: int
    ) -> _typing.Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Draw the permutation under the requested seed, and always put the
        # caller's RNG state back, even if drawing fails.
        _rng_state: torch.Tensor = torch.get_rng_state()
        try:
            if isinstance(seed, int):
                torch.manual_seed(seed)
            perm = torch.randperm(num_nodes)
        finally:
            torch.set_rng_state(_rng_state)
        train_end = int(num_nodes * train_ratio)
        val_end = int(num_nodes * (train_ratio + val_ratio))
        return (
            index_to_mask(perm[:train_end], num_nodes),
            index_to_mask(perm[train_end:val_end], num_nodes),
            index_to_mask(perm[val_end:], num_nodes)
        )

    for index in range(len(dataset)):
        graph = dataset[index]
        for node_type in graph.nodes:
            # Look the keys up on the node type being processed (the same
            # accessor the masks are written through below), instead of the
            # type-less ``nodes.data`` accessor.
            node_data = graph.nodes[node_type].data
            first_key = next(iter(node_data), None)
            if first_key is None:
                # No data stored for this node type: nothing to infer size from.
                continue
            # assumes the first dimension of any node data entry is the
            # number of nodes -- TODO confirm for non-feature entries
            _num_nodes: int = node_data[first_key].size(0)
            _masks = _random_split_masks(_num_nodes)
            node_data["train_mask"] = _masks[0]
            node_data["val_mask"] = _masks[1]
            node_data["test_mask"] = _masks[2]
    return dataset


def random_splits_mask_class(
        dataset: InMemoryDataset,
        num_train_per_class: int = 20,
        num_val_per_class: int = 30,
        total_num_val: _typing.Optional[int] = ...,
        total_num_test: _typing.Optional[int] = ...,
        seed: _typing.Optional[int] = ...
):
    r"""Assign train/val/test node masks with a fixed number of samples per class.

    Follows the semi-supervised splitting protocol suggested in
    Pitfalls of graph neural network evaluation [#]_. Node types that carry
    neither ``'y'`` nor ``'label'`` data are skipped.

    References
    ----------
    .. [#] Shchur, O., Mumme, M., Bojchevski, A., & Günnemann, S. (2018).
        Pitfalls of graph neural network evaluation.
        arXiv preprint arXiv:1811.05868.

    Parameters
    ----------
    dataset: InMemoryDataset
        instance of ``InMemoryDataset``, modified in place.
    num_train_per_class : int
        the number of samples from every class used for training.
    num_val_per_class : int
        the number of samples from every class used for validation.
        Ignored when ``total_num_val`` is a positive int.
    total_num_val : int
        the total number of nodes that used for validation as alternative
        to ``num_val_per_class``. The nodes are drawn from all non-training
        nodes, so validation and test sets stay disjoint.
    total_num_test : int
        the total number of nodes that used for testing as alternative.
        Only honoured together with ``total_num_val``; otherwise the rest
        of the data is selected as test set.
    seed : int
        random seed for splitting dataset.

    Returns
    -------
    InMemoryDataset
        the same ``dataset`` object that was passed in.

    Raises
    ------
    ValueError
        if a node type carries both ``'y'`` and ``'label'`` data.
    AssertionError
        if a class does not contain enough samples for the requested split.
    """
    use_total_val: bool = isinstance(total_num_val, int) and total_num_val > 0
    use_total_test: bool = isinstance(total_num_test, int) and total_num_test > 0
    # Per-class validation sampling is replaced by the global draw when
    # ``total_num_val`` is supplied.
    val_per_class: int = 0 if use_total_val else num_val_per_class

    for graph_index in range(len(dataset)):
        graph = dataset[graph_index]
        for node_type in graph.nodes:
            node_data = graph.nodes[node_type].data
            has_y: bool = 'y' in node_data
            has_label: bool = 'label' in node_data
            if has_y and has_label:
                raise ValueError(
                    f"Both 'y' and 'label' data exist "
                    f"for node type [{node_type}] in "
                    f"graph with index [{graph_index}]."
                )
            if not (has_y or has_label):
                continue
            label: torch.Tensor = node_data['y' if has_y else 'label']

            num_nodes: int = label.size(0)
            num_classes: int = label.cpu().max().item() + 1

            train_mask = torch.zeros(num_nodes, dtype=torch.bool, device=label.device)
            val_mask = torch.zeros(num_nodes, dtype=torch.bool, device=label.device)
            test_mask = torch.zeros(num_nodes, dtype=torch.bool, device=label.device)

            _rng_state: torch.Tensor = torch.get_rng_state()
            try:
                if isinstance(seed, int):
                    torch.manual_seed(seed)
                for class_index in range(num_classes):
                    idx = (label == class_index).nonzero().view(-1)
                    # Kept as an assert so callers observing AssertionError
                    # keep working.
                    assert num_train_per_class + val_per_class < idx.size(0), (
                        f"the total number of samples from every class "
                        f"used for training and validation is larger than "
                        f"the total samples in class [{class_index}] for node type [{node_type}] "
                        f"in graph with index [{graph_index}]"
                    )
                    randomized_index: torch.Tensor = torch.randperm(idx.size(0))
                    train_mask[idx[randomized_index[:num_train_per_class]]] = True
                    if val_per_class > 0:
                        val_mask[
                            idx[
                                randomized_index[
                                    num_train_per_class:
                                    (num_train_per_class + val_per_class)
                                ]
                            ]
                        ] = True

                if use_total_val:
                    # Validation (and test) nodes are drawn from everything
                    # that is not a training node. Because no per-class
                    # validation nodes were marked above, the validation and
                    # test slices below cannot overlap.
                    remaining = (~train_mask).nonzero().view(-1)
                    remaining = remaining[torch.randperm(remaining.size(0))]
                    val_mask[remaining[:total_num_val]] = True
                    if use_total_test:
                        test_mask[
                            remaining[total_num_val: (total_num_val + total_num_test)]
                        ] = True
                    else:
                        test_mask[remaining[total_num_val:]] = True
                else:
                    remaining = (~(train_mask | val_mask)).nonzero().view(-1)
                    test_mask[remaining] = True
            finally:
                # Never leak the temporary seed into the caller's RNG stream.
                torch.set_rng_state(_rng_state)

            node_data["train_mask"] = train_mask
            node_data["val_mask"] = val_mask
            node_data["test_mask"] = test_mask
    return dataset


def graph_cross_validation(
        dataset: InMemoryDataset,
        n_splits: int = 10,
        shuffle: bool = True,
        random_seed: _typing.Optional[int] = ...,
        stratify: bool = False
) -> InMemoryDataset:
    r"""Cross validation for graph classification data.

    Stores the generated folds in ``dataset.folds`` and selects the first
    fold through ``dataset.train_index`` / ``dataset.val_index``.

    Parameters
    ----------
    dataset : InMemoryDataset
        dataset with multiple graphs.
    n_splits : int
        the number of folds to split.
    shuffle : bool
        shuffle or not for the scikit-learn splitter.
    random_seed : int
        random_state for the scikit-learn splitter. A random seed is drawn
        when it is not a non-negative int. Only used when ``shuffle`` is True.
    stratify: bool
        use ``StratifiedKFold`` instead of ``KFold``.

    Returns
    -------
    InMemoryDataset
        the same ``dataset`` object that was passed in.

    Raises
    ------
    TypeError
        if an argument has an unexpected type.
    ValueError
        if ``n_splits`` is not positive.
    """
    if not isinstance(n_splits, int):
        raise TypeError
    elif not n_splits > 0:
        raise ValueError
    if not isinstance(shuffle, bool):
        raise TypeError
    if not (random_seed in (Ellipsis, None) or isinstance(random_seed, int)):
        raise TypeError
    elif isinstance(random_seed, int) and random_seed >= 0:
        _random_seed: int = random_seed
    else:
        _random_seed: int = random.randrange(0, 65536)
    if not isinstance(stratify, bool):
        raise TypeError

    # scikit-learn rejects a non-None random_state when shuffle is False,
    # so only forward the seed when it is actually going to be used.
    random_state: _typing.Optional[int] = _random_seed if shuffle else None
    splitter_class = StratifiedKFold if stratify else KFold
    kf = splitter_class(n_splits=n_splits, shuffle=shuffle, random_state=random_state)

    dataset_y = [g.data['y' if 'y' in g.data else 'label'].item() for g in dataset]
    idx_list = [
        (train_index.tolist(), test_index.tolist())
        for train_index, test_index
        in kf.split(np.zeros(len(dataset)), np.array(dataset_y))
    ]

    dataset.folds = idx_list
    dataset.train_index = idx_list[0][0]
    dataset.val_index = idx_list[0][1]
    return dataset


def set_fold(dataset: InMemoryDataset, fold_id: int) -> InMemoryDataset:
    r"""Set fold for graph dataset consist of multiple graphs.

    Parameters
    ----------
    dataset: `autogl.data.InMemoryDataset`
        dataset with multiple graphs.
    fold_id: `int`
        The fold in to use, MUST be in [0, len(dataset.folds))

    Returns
    -------
    `autogl.data.InMemoryDataset`
        The reference of original dataset.

    Raises
    ------
    ValueError
        if the dataset has no folds or ``fold_id`` is out of range.
    """
    if not (hasattr(dataset, 'folds') and dataset.folds is not None):
        raise ValueError("Dataset do NOT contain folds")
    if not 0 <= fold_id < len(dataset.folds):
        raise ValueError(
            f"Fold id {fold_id} exceed total cross validation split number {len(dataset.folds)}"
        )
    # NOTE(review): relies on entries of ``dataset.folds`` exposing
    # ``train_index`` / ``val_index`` attributes -- presumably provided by
    # the ``InMemoryDataset.folds`` property; confirm against that class.
    fold = dataset.folds[fold_id]
    dataset.train_index = fold.train_index
    dataset.val_index = fold.val_index
    return dataset


def graph_random_splits(
        dataset: InMemoryDataset,
        train_ratio: float = 0.2,
        val_ratio: float = 0.4,
        seed: _typing.Optional[int] = ...
):
    r"""Splitting graph dataset with specific ratio for train/val/test.

    Writes ``train_index``, ``val_index`` and ``test_index`` (lists of int)
    onto the dataset. Graphs not selected for training or validation are
    used for testing.

    Parameters
    ----------
    dataset: ``InMemoryStaticGraphSet``
        dataset with multiple graphs, modified in place.
    train_ratio : float
        the portion of data that used for training.
    val_ratio : float
        the portion of data that used for validation.
    seed : int
        random seed for splitting dataset. The global torch RNG state is
        restored afterwards.

    Returns
    -------
    InMemoryDataset
        the same ``dataset`` object that was passed in.
    """
    num_graphs: int = len(dataset)
    _rng_state = torch.get_rng_state()
    try:
        if isinstance(seed, int):
            torch.manual_seed(seed)
        perm = torch.randperm(num_graphs)
    finally:
        # Restore the caller's RNG stream even if drawing fails.
        torch.set_rng_state(_rng_state)

    train_end = int(num_graphs * train_ratio)
    val_end = int(num_graphs * (train_ratio + val_ratio))
    dataset.train_index = perm[:train_end].tolist()
    dataset.val_index = perm[train_end:val_end].tolist()
    dataset.test_index = perm[val_end:].tolist()
    return dataset


def graph_get_split(
        dataset, mask: str = "train",
        is_loader: bool = True, batch_size: int = 128,
        num_workers: int = 0, shuffle: bool = False
) -> _typing.Union[torch.utils.data.DataLoader, _typing.Iterable]:
    r"""Get train/val/test dataset/dataloader after the dataset has been split.

    Parameters
    ----------
    dataset:
        dataset with multiple graphs.
    mask : str
        which split to fetch, one of ``"train"``, ``"val"`` or ``"test"``
        (case-insensitive).
    is_loader : bool
        return original dataset or data loader
    batch_size : int
        batch_size for generating Dataloader
    num_workers : int
        number of workers parameter for data loader
    shuffle: bool
        whether shuffle the dataloader

    Returns
    -------
    torch.utils.data.DataLoader or Iterable
        a backend-specific data loader when ``is_loader`` is True, otherwise
        an ``InMemoryDataset`` wrapping the requested split.

    Raises
    ------
    TypeError
        if an argument has an unexpected type.
    ValueError
        if ``mask`` is not a known split name, a numeric argument is out of
        range, or the dataset does not have the requested split.
    RuntimeError
        if a loader is requested and the backend is neither DGL nor PyG.
    """
    if not isinstance(mask, str):
        raise TypeError
    split_name: str = mask.lower()
    if split_name not in ("train", "val", "test"):
        raise ValueError(
            f"The provided mask parameter must be a str in ['train', 'val', 'test'], "
            f"illegal provided value is [{mask}]"
        )
    if not isinstance(is_loader, bool):
        raise TypeError
    if not isinstance(batch_size, int):
        raise TypeError
    elif not batch_size > 0:
        raise ValueError
    if not isinstance(num_workers, int):
        raise TypeError
    elif not num_workers >= 0:
        raise ValueError

    # Resolves to dataset.train_split / dataset.val_split / dataset.test_split.
    optional_dataset_split = getattr(dataset, f"{split_name}_split")
    if optional_dataset_split is None:
        raise ValueError(f"Provided dataset do NOT have {mask} split")
    # Every element of the wrapped split belongs to the requested partition,
    # passed as train_index= / val_index= / test_index= respectively.
    sub_dataset = InMemoryDataset(
        optional_dataset_split,
        **{f"{split_name}_index": list(range(len(optional_dataset_split)))}
    )

    if not is_loader:
        return sub_dataset

    if _backend.DependentBackend.is_dgl():
        try:
            from dgl.dataloading.pytorch import GraphDataLoader
        except ImportError:
            # Fallback for DGL releases without the ``pytorch`` submodule.
            from dgl.dataloading import GraphDataLoader
        return GraphDataLoader(
            sub_dataset,
            batch_size=batch_size, num_workers=num_workers, shuffle=shuffle
        )
    if _backend.DependentBackend.is_pyg():
        # The PyG loader is fed the raw split rather than the wrapper.
        _sub_dataset: _typing.Any = optional_dataset_split
        import torch_geometric
        if int(torch_geometric.__version__.split('.')[0]) >= 2:
            # version 2.x
            from torch_geometric.loader import DataLoader
        else:
            from torch_geometric.data import DataLoader
        return DataLoader(
            _sub_dataset,
            batch_size=batch_size, num_workers=num_workers, shuffle=shuffle
        )
    raise RuntimeError("Unsupported backend")