You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

node_classification.py 875 B

123456789101112131415161718192021222324252627282930
from autogl.datasets import build_dataset_from_name
from autogl.datasets.utils.conversion._to_pyg_dataset import to_pyg_dataset
from autogl.module.train import NodeClassificationFullTrainer
  4. def test_node_trainer():
  5. dataset = build_dataset_from_name("cora")
  6. dataset = to_pyg_dataset(dataset)
  7. node_trainer = NodeClassificationFullTrainer(
  8. model='gcn',
  9. init=False,
  10. lr=1e-2,
  11. weight_decay=5e-4,
  12. max_epoch=200,
  13. early_stopping_round=200,
  14. )
  15. node_trainer.num_features = dataset[0].x.size(1)
  16. node_trainer.num_classes = dataset[0].y.max().item() + 1
  17. node_trainer.initialize()
  18. print(node_trainer.encoder.encoder)
  19. print(node_trainer.decoder.decoder)
  20. node_trainer.train(dataset, True)
  21. result = node_trainer.evaluate(dataset, "test", "acc")
  22. print("Acc:", result)
  23. test_node_trainer()