You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

graphnas.py 845 B

12345678910111213141516171819202122
  1. import sys
  2. sys.path.append('../')
  3. from autogl.datasets import build_dataset_from_name
  4. from autogl.solver import AutoNodeClassifier
  5. from autogl.module.train import Acc
  6. from autogl.solver.utils import set_seed
  7. import argparse
  8. if __name__ == '__main__':
  9. set_seed(202106)
  10. parser = argparse.ArgumentParser()
  11. parser.add_argument('--config', type=str, default='../configs/nodeclf_nas_macro_benchmark2.yml')
  12. parser.add_argument('--dataset', choices=['cora', 'citeseer', 'pubmed'], default='cora', type=str)
  13. args = parser.parse_args()
  14. dataset = build_dataset_from_name('cora')
  15. solver = AutoNodeClassifier.from_config(args.config)
  16. solver.fit(dataset)
  17. solver.get_leaderboard().show()
  18. out = solver.predict_proba()
  19. print('acc on dataset', Acc.evaluate(out, dataset[0].y[dataset[0].test_mask].detach().numpy()))