| @@ -1,3 +1,4 @@ | |||
| from pdb import set_trace | |||
| import torch | |||
| import numpy as np | |||
| from torch_geometric.data import DataLoader | |||
| @@ -37,32 +38,33 @@ def random_splits_mask(dataset, train_ratio=0.2, val_ratio=0.4, seed=None): | |||
| assert ( | |||
| train_ratio + val_ratio <= 1 | |||
| ), "the sum of train_ratio and val_ratio is larger than 1" | |||
| data = dataset[0] | |||
| r_s = torch.get_rng_state() | |||
| if torch.cuda.is_available(): | |||
| r_s_cuda = torch.cuda.get_rng_state() | |||
| if seed is not None: | |||
| torch.manual_seed(seed) | |||
| _dataset=[d for d in dataset] | |||
| for data in _dataset: | |||
| r_s = torch.get_rng_state() | |||
| if torch.cuda.is_available(): | |||
| torch.cuda.manual_seed(seed) | |||
| perm = torch.randperm(data.num_nodes) | |||
| train_index = perm[: int(data.num_nodes * train_ratio)] | |||
| val_index = perm[ | |||
| int(data.num_nodes * train_ratio) : int( | |||
| data.num_nodes * (train_ratio + val_ratio) | |||
| ) | |||
| ] | |||
| test_index = perm[int(data.num_nodes * (train_ratio + val_ratio)) :] | |||
| data.train_mask = index_to_mask(train_index, size=data.num_nodes) | |||
| data.val_mask = index_to_mask(val_index, size=data.num_nodes) | |||
| data.test_mask = index_to_mask(test_index, size=data.num_nodes) | |||
| r_s_cuda = torch.cuda.get_rng_state() | |||
| if seed is not None: | |||
| torch.manual_seed(seed) | |||
| if torch.cuda.is_available(): | |||
| torch.cuda.manual_seed(seed) | |||
| perm = torch.randperm(data.num_nodes) | |||
| train_index = perm[: int(data.num_nodes * train_ratio)] | |||
| val_index = perm[ | |||
| int(data.num_nodes * train_ratio) : int( | |||
| data.num_nodes * (train_ratio + val_ratio) | |||
| ) | |||
| ] | |||
| test_index = perm[int(data.num_nodes * (train_ratio + val_ratio)) :] | |||
| data.train_mask = index_to_mask(train_index, size=data.num_nodes) | |||
| data.val_mask = index_to_mask(val_index, size=data.num_nodes) | |||
| data.test_mask = index_to_mask(test_index, size=data.num_nodes) | |||
| torch.set_rng_state(r_s) | |||
| if torch.cuda.is_available(): | |||
| torch.cuda.set_rng_state(r_s_cuda) | |||
| torch.set_rng_state(r_s) | |||
| if torch.cuda.is_available(): | |||
| torch.cuda.set_rng_state(r_s_cuda) | |||
| dataset.data, dataset.slices = dataset.collate([d for d in dataset]) | |||
| dataset.data, dataset.slices = dataset.collate(_dataset) | |||
| # while type(dataset.data.num_nodes) == list: | |||
| # dataset.data.num_nodes = dataset.data.num_nodes[0] | |||
| # dataset.data.num_nodes = dataset.data.num_nodes[0] | |||