Browse Source

fix dataset change in random_split_mask

tags/v0.3.1
wondergo2017 5 years ago
parent
commit
f15d08e83f
1 changed files with 25 additions and 23 deletions
  1. +25
    -23
      autogl/datasets/utils.py

+ 25
- 23
autogl/datasets/utils.py View File

@@ -1,3 +1,4 @@
from pdb import set_trace
import torch
import numpy as np
from torch_geometric.data import DataLoader
@@ -37,32 +38,33 @@ def random_splits_mask(dataset, train_ratio=0.2, val_ratio=0.4, seed=None):
assert (
train_ratio + val_ratio <= 1
), "the sum of train_ratio and val_ratio is larger than 1"
data = dataset[0]
r_s = torch.get_rng_state()
if torch.cuda.is_available():
r_s_cuda = torch.cuda.get_rng_state()
if seed is not None:
torch.manual_seed(seed)
_dataset=[d for d in dataset]
for data in _dataset:
r_s = torch.get_rng_state()
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)

perm = torch.randperm(data.num_nodes)
train_index = perm[: int(data.num_nodes * train_ratio)]
val_index = perm[
int(data.num_nodes * train_ratio) : int(
data.num_nodes * (train_ratio + val_ratio)
)
]
test_index = perm[int(data.num_nodes * (train_ratio + val_ratio)) :]
data.train_mask = index_to_mask(train_index, size=data.num_nodes)
data.val_mask = index_to_mask(val_index, size=data.num_nodes)
data.test_mask = index_to_mask(test_index, size=data.num_nodes)
r_s_cuda = torch.cuda.get_rng_state()
if seed is not None:
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)

perm = torch.randperm(data.num_nodes)
train_index = perm[: int(data.num_nodes * train_ratio)]
val_index = perm[
int(data.num_nodes * train_ratio) : int(
data.num_nodes * (train_ratio + val_ratio)
)
]
test_index = perm[int(data.num_nodes * (train_ratio + val_ratio)) :]
data.train_mask = index_to_mask(train_index, size=data.num_nodes)
data.val_mask = index_to_mask(val_index, size=data.num_nodes)
data.test_mask = index_to_mask(test_index, size=data.num_nodes)

torch.set_rng_state(r_s)
if torch.cuda.is_available():
torch.cuda.set_rng_state(r_s_cuda)
torch.set_rng_state(r_s)
if torch.cuda.is_available():
torch.cuda.set_rng_state(r_s_cuda)

dataset.data, dataset.slices = dataset.collate([d for d in dataset])
dataset.data, dataset.slices = dataset.collate(_dataset)
# while type(dataset.data.num_nodes) == list:
# dataset.data.num_nodes = dataset.data.num_nodes[0]
# dataset.data.num_nodes = dataset.data.num_nodes[0]


Loading…
Cancel
Save