You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

graph_classification.py 808 B

123456789101112131415161718192021222324252627
  1. import sys
  2. sys.path.append('../')
  3. from autogl.datasets import build_dataset_from_name, utils
  4. from autogl.solver import AutoGraphClassifier
  5. from autogl.module import Acc, BaseModel
  6. dataset = build_dataset_from_name('mutag')
  7. utils.graph_random_splits(dataset, train_ratio=0.4, val_ratio=0.4)
  8. autoClassifier = AutoGraphClassifier.from_config('../configs/graph_classification.yaml')
  9. # train
  10. autoClassifier.fit(
  11. dataset,
  12. time_limit=3600,
  13. train_split=0.8,
  14. val_split=0.1,
  15. cross_validation=True,
  16. cv_split=10,
  17. )
  18. autoClassifier.get_leaderboard().show()
  19. print('best single model:\n', autoClassifier.get_leaderboard().get_best_model(0))
  20. # test
  21. predict_result = autoClassifier.predict_proba()
  22. print(Acc.evaluate(predict_result, dataset.data.y[dataset.test_index].cpu().detach().numpy()))