diff --git a/autogl/datasets/utils.py b/autogl/datasets/utils.py index 8b677ee..afd1128 100644 --- a/autogl/datasets/utils.py +++ b/autogl/datasets/utils.py @@ -1,3 +1,4 @@ +from pdb import set_trace import torch import numpy as np from torch_geometric.data import DataLoader @@ -37,32 +38,33 @@ def random_splits_mask(dataset, train_ratio=0.2, val_ratio=0.4, seed=None): assert ( train_ratio + val_ratio <= 1 ), "the sum of train_ratio and val_ratio is larger than 1" - data = dataset[0] - r_s = torch.get_rng_state() - if torch.cuda.is_available(): - r_s_cuda = torch.cuda.get_rng_state() - if seed is not None: - torch.manual_seed(seed) + _dataset=[d for d in dataset] + for data in _dataset: + r_s = torch.get_rng_state() if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - - perm = torch.randperm(data.num_nodes) - train_index = perm[: int(data.num_nodes * train_ratio)] - val_index = perm[ - int(data.num_nodes * train_ratio) : int( - data.num_nodes * (train_ratio + val_ratio) - ) - ] - test_index = perm[int(data.num_nodes * (train_ratio + val_ratio)) :] - data.train_mask = index_to_mask(train_index, size=data.num_nodes) - data.val_mask = index_to_mask(val_index, size=data.num_nodes) - data.test_mask = index_to_mask(test_index, size=data.num_nodes) + r_s_cuda = torch.cuda.get_rng_state() + if seed is not None: + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + + perm = torch.randperm(data.num_nodes) + train_index = perm[: int(data.num_nodes * train_ratio)] + val_index = perm[ + int(data.num_nodes * train_ratio) : int( + data.num_nodes * (train_ratio + val_ratio) + ) + ] + test_index = perm[int(data.num_nodes * (train_ratio + val_ratio)) :] + data.train_mask = index_to_mask(train_index, size=data.num_nodes) + data.val_mask = index_to_mask(val_index, size=data.num_nodes) + data.test_mask = index_to_mask(test_index, size=data.num_nodes) - torch.set_rng_state(r_s) - if torch.cuda.is_available(): - torch.cuda.set_rng_state(r_s_cuda) + torch.set_rng_state(r_s) + if torch.cuda.is_available(): + torch.cuda.set_rng_state(r_s_cuda) - dataset.data, dataset.slices = dataset.collate([d for d in dataset]) + dataset.data, dataset.slices = dataset.collate(_dataset) # while type(dataset.data.num_nodes) == list: # dataset.data.num_nodes = dataset.data.num_nodes[0] # dataset.data.num_nodes = dataset.data.num_nodes[0]