You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

graph_classification.py 843 B

12345678910111213141516171819202122232425
from autogl.module.train import GraphClassificationFullTrainer
from autogl.datasets import build_dataset_from_name, utils
from autogl.datasets.utils.conversion._to_pyg_dataset import to_pyg_dataset
  4. def test_graph_trainer():
  5. dataset = build_dataset_from_name("mutag")
  6. utils.graph_random_splits(dataset, 0.8, 0.1)
  7. dataset = to_pyg_dataset(dataset)
  8. lp_trainer = GraphClassificationFullTrainer(model='gin', init=False)
  9. lp_trainer.num_features = dataset[0].x.size(1)
  10. lp_trainer.num_classes = max([d.y for d in dataset]).item() + 1
  11. lp_trainer.num_graph_features = 0
  12. lp_trainer.initialize()
  13. print(lp_trainer.encoder.encoder)
  14. print(lp_trainer.decoder.decoder)
  15. lp_trainer.train(dataset, True)
  16. result = lp_trainer.evaluate(dataset, "test", "acc")
  17. print("Acc:", result)
  18. test_graph_trainer()