You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

graphnas.py 896 B

4 years ago
4 years ago
123456789101112131415161718192021
  1. from autogl.datasets import build_dataset_from_name
  2. from autogl.solver import AutoNodeClassifier
  3. from autogl.solver.utils import set_seed
  4. import argparse
  5. from autogl.backend import DependentBackend
  6. if __name__ == '__main__':
  7. set_seed(202106)
  8. parser = argparse.ArgumentParser()
  9. parser.add_argument('--config', type=str, default='../configs/nodeclf_nas_macro_benchmark2.yml')
  10. parser.add_argument('--dataset', choices=['cora', 'citeseer', 'pubmed'], default='cora', type=str)
  11. args = parser.parse_args()
  12. dataset = build_dataset_from_name(args.dataset)
  13. label = dataset[0].nodes.data["y" if DependentBackend.is_pyg() else "label"][dataset[0].nodes.data["test_mask"]].cpu().numpy()
  14. solver = AutoNodeClassifier.from_config(args.config)
  15. solver.fit(dataset)
  16. solver.get_leaderboard().show()
  17. acc = solver.evaluate(metric="acc")
  18. print('acc on dataset', acc)