Browse Source

Merge pull request #97 from THUMNLab/fix#96

fix nas nclf backend
develop/0.4/predevelop
Generall GitHub 4 years ago
parent
commit
e65faf3400
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 12 deletions
  1. +3
    -0
      autogl/module/nas/space/autoattend.py
  2. +5
    -12
      test/nas/node_classification.py

+ 3
- 0
autogl/module/nas/space/autoattend.py View File

@@ -57,6 +57,9 @@ class AutoAttendNodeClassificationSpace(BaseSpace):

):
super().__init__()

from autogl.backend import DependentBackend;assert not DependentBackend.is_dgl(),"Now AutoAttend only support pyg"
self.layer_number = layer_number
self.hidden_dim = hidden_dim
self.input_dim = input_dim


+ 5
- 12
test/nas/node_classification.py View File

@@ -20,7 +20,7 @@ if DependentBackend.is_dgl():
elif DependentBackend.is_pyg():
from torch_geometric.datasets import Planetoid
from autogl.module.model.pyg import BaseAutoModel
from autogl.datasets import build_dataset_from_name
import torch
import torch.nn.functional as F
from autogl.module.nas.space.single_path import SinglePathNodeClassificationSpace
@@ -71,12 +71,13 @@ def test_model(model, data=None, check_children=False):
if __name__ == "__main__":

print("Testing backend: {}".format("dgl" if DependentBackend.is_dgl() else "pyg"))

if DependentBackend.is_dgl():
dataset = CoraGraphDataset()
from autogl.datasets.utils.conversion._to_dgl_dataset import to_dgl_dataset as convert_dataset
else:
dataset = Planetoid(os.path.expanduser("~/.cache-autogl"), "Cora")
from autogl.datasets.utils.conversion._to_pyg_dataset import to_pyg_dataset as convert_dataset

dataset = build_dataset_from_name('cora')
dataset = convert_dataset(dataset)
data = dataset[0]

di = bk_feat(data).shape[1]
@@ -125,14 +126,6 @@ if __name__ == "__main__":
model = algo.search(space, dataset, esti)
test_model(model, data, True)

print("darts + graphnas ")
space = AutoAttendNodeClassificationSpace().cuda()
space.instantiate(input_dim=di, output_dim=do)
esti = OneShotEstimator()
algo = Darts(num_epochs=10)
model = algo.search(space, dataset, esti)
test_model(model, data, True)

print("Random search + graphnas ")
space = GraphNasNodeClassificationSpace().cuda()
space.instantiate(input_dim=di, output_dim=do)


Loading…
Cancel
Save