diff --git a/autogl/module/hpo/autone.py b/autogl/module/hpo/autone.py index 8464aa8..953cc94 100644 --- a/autogl/module/hpo/autone.py +++ b/autogl/module/hpo/autone.py @@ -7,7 +7,7 @@ import numpy as np from tqdm import trange from . import register_hpo from .base import BaseHPOptimizer, TimeTooLimitedError - +import random from .autone_file import utils from torch_geometric.data import GraphSAINTRandomWalkSampler @@ -16,13 +16,14 @@ from ..feature.graph import SgNetLSD from torch_geometric.data import InMemoryDataset +from autogl.backend import DependentBackend +_isdgl=DependentBackend.is_dgl() class _MyDataset(InMemoryDataset): def __init__(self, datalist) -> None: super().__init__() self.data, self.slices = self.collate(datalist) - @register_hpo("autone") class AutoNE(BaseHPOptimizer): """ @@ -48,7 +49,7 @@ class AutoNE(BaseHPOptimizer): self.subgraphs = kwargs.get("subgraphs", 2) self.sub_evals = kwargs.get("sub_evals", 2) self.sample_batch_size = kwargs.get("sample_batch_size", 150) - self.sample_walk_length = kwargs.get("sample_walk_length", 2) + self.sample_walk_length = kwargs.get("sample_walk_length", 100) def optimize(self, trainer, dataset, time_limit=None, memory_limit=None): """ @@ -78,6 +79,15 @@ class AutoNE(BaseHPOptimizer): results.append(in_dataset) return results + def sample_subgraph_dgl(whole_data): + data = whole_data.data # dgl data + # find data with different labels + # random walk + start = [random.randint(0, data.num_nodes - 1) for i in range(self.subgraphs)] + traces, _ = dgl.sampling.random_walk_with_restart(data, start, length=self.sample_batch_size, restart_prob= 1 / self.sample_walk_length)) + subgraphs = dgl.node_subgraph(data, traces[i, :] for i in traces.size(0)) + return subgraphs + func = SgNetLSD() def get_wne(graph): @@ -112,7 +122,10 @@ class AutoNE(BaseHPOptimizer): info = [] K = utils.K(len(params.type_)) gp = utils.GaussianProcessRegressor(K) - sample_graphs = sample_subgraph(dataset) + if _isdgl: + sample_graphs = sample_subgraph_dgl(dataset) + else: + sample_graphs = sample_subgraph(dataset) print("Sample Phase:\n") for t in trange(sampled_number): b_t = time.time() diff --git a/examples/node_classification.py b/examples/node_classification.py index f850cb5..718971e 100644 --- a/examples/node_classification.py +++ b/examples/node_classification.py @@ -37,7 +37,7 @@ if __name__ == "__main__": help="config to use", ) # following arguments will override parameters in the config file - parser.add_argument("--hpo", type=str, default="tpe", help="hpo methods") + parser.add_argument("--hpo", type=str, default="autone", help="hpo methods") parser.add_argument( "--max_eval", type=int, default=50, help="max hpo evaluation times" )