|
|
|
@@ -168,7 +168,7 @@ def random_splits_mask_class( |
|
|
|
|
|
|
|
|
|
|
|
def graph_cross_validation(dataset, n_splits=10, shuffle=True, random_seed=42): |
|
|
|
r"""Cross validation for graph classification data, returning one fold with specific idx in autograph.datasets or pyg.Dataloader(default) |
|
|
|
r"""Cross validation for graph classification data, returning one fold with specific idx in autogl.datasets or pyg.Dataloader(default) |
|
|
|
|
|
|
|
Parameters |
|
|
|
---------- |
|
|
|
@@ -315,7 +315,7 @@ def graph_get_split(dataset, mask="train", is_loader=True, batch_size=128): |
|
|
|
return with which dataset/dataloader |
|
|
|
|
|
|
|
is_loader : bool |
|
|
|
return with autograph.datasets or pyg.Dataloader |
|
|
|
return with autogl.datasets or pyg.Dataloader |
|
|
|
|
|
|
|
batch_size : int |
|
|
|
batch_size for generateing Dataloader |
|
|
|
@@ -332,7 +332,7 @@ def graph_get_split(dataset, mask="train", is_loader=True, batch_size=128): |
|
|
|
|
|
|
|
''' |
|
|
|
def graph_cross_validation(dataset, n_splits = 10, shuffle = True, random_seed = 42, fold_idx = 0, batch_size = 32, dataloader = True): |
|
|
|
r"""Cross validation for graph classification data, returning one fold with specific idx in autograph.datasets or pyg.Dataloader(default) |
|
|
|
r"""Cross validation for graph classification data, returning one fold with specific idx in autogl.datasets or pyg.Dataloader(default) |
|
|
|
|
|
|
|
Parameters |
|
|
|
---------- |
|
|
|
@@ -355,7 +355,7 @@ def graph_cross_validation(dataset, n_splits = 10, shuffle = True, random_seed = |
|
|
|
batch_size for generateing Dataloader |
|
|
|
|
|
|
|
dataloader : bool |
|
|
|
return with autograph.datasets or pyg.Dataloader |
|
|
|
return with autogl.datasets or pyg.Dataloader |
|
|
|
""" |
|
|
|
skf = StratifiedKFold(n_splits=n_splits, shuffle = shuffle, random_state = random_seed) |
|
|
|
idx_list = [] |
|
|
|
|