From 921edf2cafb7312e150b618cab18e41aa745048b Mon Sep 17 00:00:00 2001 From: wondergo2017 Date: Wed, 13 Apr 2022 16:12:55 +0800 Subject: [PATCH] fix nas nclf backend --- autogl/module/nas/space/autoattend.py | 3 +++ test/nas/node_classification.py | 17 +++++------------ 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/autogl/module/nas/space/autoattend.py b/autogl/module/nas/space/autoattend.py index 51d12eb..c0e6477 100644 --- a/autogl/module/nas/space/autoattend.py +++ b/autogl/module/nas/space/autoattend.py @@ -57,6 +57,9 @@ class AutoAttendNodeClassificationSpace(BaseSpace): ): super().__init__() + + from autogl.backend import DependentBackend;assert not DependentBackend.is_dgl(),"Now AutoAttend only support pyg" + self.layer_number = layer_number self.hidden_dim = hidden_dim self.input_dim = input_dim diff --git a/test/nas/node_classification.py b/test/nas/node_classification.py index 4b5e3c2..e28bbac 100644 --- a/test/nas/node_classification.py +++ b/test/nas/node_classification.py @@ -20,7 +20,7 @@ if DependentBackend.is_dgl(): elif DependentBackend.is_pyg(): from torch_geometric.datasets import Planetoid from autogl.module.model.pyg import BaseAutoModel - +from autogl.datasets import build_dataset_from_name import torch import torch.nn.functional as F from autogl.module.nas.space.single_path import SinglePathNodeClassificationSpace @@ -71,12 +71,13 @@ def test_model(model, data=None, check_children=False): if __name__ == "__main__": print("Testing backend: {}".format("dgl" if DependentBackend.is_dgl() else "pyg")) - if DependentBackend.is_dgl(): - dataset = CoraGraphDataset() + from autogl.datasets.utils.conversion._to_dgl_dataset import to_dgl_dataset as convert_dataset else: - dataset = Planetoid(os.path.expanduser("~/.cache-autogl"), "Cora") + from autogl.datasets.utils.conversion._to_pyg_dataset import to_pyg_dataset as convert_dataset + dataset = build_dataset_from_name('cora') + dataset = convert_dataset(dataset) data = dataset[0] di = bk_feat(data).shape[1] @@ -125,14 +126,6 @@ if __name__ == "__main__": model = algo.search(space, dataset, esti) test_model(model, data, True) - print("darts + graphnas ") - space = AutoAttendNodeClassificationSpace().cuda() - space.instantiate(input_dim=di, output_dim=do) - esti = OneShotEstimator() - algo = Darts(num_epochs=10) - model = algo.search(space, dataset, esti) - test_model(model, data, True) - print("Random search + graphnas ") space = GraphNasNodeClassificationSpace().cuda() space.instantiate(input_dim=di, output_dim=do)