Browse Source

merge all things

tags/v0.3.1
generall 4 years ago
parent
commit
9504ebe007
2 changed files with 18 additions and 5 deletions
  1. +17
    -4
      autogl/module/hpo/autone.py
  2. +1
    -1
      examples/node_classification.py

+ 17
- 4
autogl/module/hpo/autone.py View File

@@ -7,7 +7,7 @@ import numpy as np
from tqdm import trange
from . import register_hpo
from .base import BaseHPOptimizer, TimeTooLimitedError
import random
from .autone_file import utils

from torch_geometric.data import GraphSAINTRandomWalkSampler
@@ -16,13 +16,14 @@ from ..feature.graph import SgNetLSD

from torch_geometric.data import InMemoryDataset

from autogl.backend import DependentBackend
_isdgl=DependentBackend.is_dgl()

class _MyDataset(InMemoryDataset):
def __init__(self, datalist) -> None:
super().__init__()
self.data, self.slices = self.collate(datalist)


@register_hpo("autone")
class AutoNE(BaseHPOptimizer):
"""
@@ -48,7 +49,7 @@ class AutoNE(BaseHPOptimizer):
self.subgraphs = kwargs.get("subgraphs", 2)
self.sub_evals = kwargs.get("sub_evals", 2)
self.sample_batch_size = kwargs.get("sample_batch_size", 150)
self.sample_walk_length = kwargs.get("sample_walk_length", 2)
self.sample_walk_length = kwargs.get("sample_walk_length", 100)

def optimize(self, trainer, dataset, time_limit=None, memory_limit=None):
"""
@@ -78,6 +79,15 @@ class AutoNE(BaseHPOptimizer):
results.append(in_dataset)
return results

def sample_subgraph_dgl(whole_data):
data = whole_data.data # dgl data
# find data with different labels
# random walk
start = [random.randint(0, data.num_nodes - 1) for i in range(self.subgraphs)]
traces, _ = dgl.sampling.random_walk_with_restart(data, start, length=self.sample_batch_size, restart_prob= 1 / self.sample_walk_length))
subgraphs = dgl.node_subgraph(data, traces[i, :] for i in traces.size(0))
return subgraphs

func = SgNetLSD()

def get_wne(graph):
@@ -112,7 +122,10 @@ class AutoNE(BaseHPOptimizer):
info = []
K = utils.K(len(params.type_))
gp = utils.GaussianProcessRegressor(K)
sample_graphs = sample_subgraph(dataset)
if _isdgl:
sample_graphs = sample_subgraph_dgl(dataset)
else:
sample_graphs = sample_subgraph(dataset)
print("Sample Phase:\n")
for t in trange(sampled_number):
b_t = time.time()


+ 1
- 1
examples/node_classification.py View File

@@ -37,7 +37,7 @@ if __name__ == "__main__":
help="config to use",
)
# following arguments will override parameters in the config file
parser.add_argument("--hpo", type=str, default="tpe", help="hpo methods")
parser.add_argument("--hpo", type=str, default="autone", help="hpo methods")
parser.add_argument(
"--max_eval", type=int, default=50, help="max hpo evaluation times"
)


Loading…
Cancel
Save