You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

node_classification.py 830 B

4 years ago
4 years ago
4 years ago
1234567891011121314151617181920212223242526272829
  1. import os
  2. os.environ["AUTOGL_BACKEND"] = "dgl"
  3. from autogl.datasets import build_dataset_from_name
  4. from autogl.solver import AutoNodeClassifier
  5. from autogl.module.train import NodeClassificationFullTrainer
  6. from autogl.backend import DependentBackend
  7. key = "y" if DependentBackend.is_pyg() else "label"
  8. cora = build_dataset_from_name("cora")
  9. solver = AutoNodeClassifier(
  10. graph_models=("gin",),
  11. default_trainer=NodeClassificationFullTrainer(
  12. decoder=None,
  13. init=False,
  14. max_epoch=200,
  15. early_stopping_round=201,
  16. lr=0.01,
  17. weight_decay=0.0,
  18. ),
  19. hpo_module=None,
  20. device="auto"
  21. )
  22. solver.fit(cora, evaluation_method=["acc"])
  23. result = solver.predict(cora)
  24. print((result == cora[0].nodes.data[key][cora[0].nodes.data["test_mask"]].cpu().numpy()).astype('float').mean())