graph_classification.py
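```python
from autogl.datasets import build_dataset_from_name, utils
from autogl.solver import AutoGraphClassifier

# Load the MUTAG graph-classification benchmark.
mutag = build_dataset_from_name("mutag")

# Random split: 80% train, 10% validation, the remaining 10% test.
utils.graph_random_splits(mutag, 0.8, 0.1)

# Use only the GIN model, disable hyperparameter optimization,
# and let AutoGL choose the device.
solver = AutoGraphClassifier(
    graph_models=("gin",),
    hpo_module=None,
    device="auto",
)

# Train, using accuracy as the evaluation metric.
solver.fit(mutag, evaluation_method=["acc"])

# Predicted labels, compared against the test split below.
result = solver.predict(mutag)

# Test accuracy: fraction of test graphs whose true label matches the prediction.
correct = sum(d.data["y"].item() == r for d, r in zip(mutag.test_split, result))
print("Acc:", correct / len(result))
```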
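The call `utils.graph_random_splits(mutag, 0.8, 0.1)` assigns 80% of the graphs to training, 10% to validation, and leaves the rest for testing. The sketch below illustrates that idea with plain index shuffling. It is not AutoGL's implementation: the helper `random_split_indices` is hypothetical, and AutoGL's rounding and seeding may differ. MUTAG contains 188 graphs, which is the size used in the example.

```python
import random


def random_split_indices(n, train_ratio, val_ratio, seed=None):
    """Shuffle indices 0..n-1 and cut them into train/val/test lists.

    Hypothetical helper for illustration; not AutoGL's graph_random_splits.
    """
    if train_ratio < 0 or val_ratio < 0 or train_ratio + val_ratio > 1:
        raise ValueError("ratios must be non-negative and sum to at most 1")
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # seeded shuffle for reproducibility
    n_train = int(n * train_ratio)
    n_val = int(n * val_ratio)
    train = indices[:n_train]
    val = indices[n_train:n_train + n_val]
    test = indices[n_train + n_val:]  # everything left over goes to test
    return train, val, test


train, val, test = random_split_indices(188, 0.8, 0.1, seed=0)
print(len(train), len(val), len(test))  # 150 18 20
```

Because the test set takes whatever is left after flooring, it can end up slightly larger than the nominal 10%.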
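The last line of the script computes accuracy by zipping the test graphs with the predictions. Since `zip` silently stops at the shorter input, a length mismatch would go unnoticed. A small stand-alone helper can make that check explicit. The function name `accuracy` is hypothetical and not part of AutoGL. It assumes both inputs are sequences of comparable label values, as the original comparison `d.data["y"].item() == r` implies.

```python
def accuracy(true_labels, predicted_labels):
    """Fraction of positions where the prediction equals the true label."""
    true_labels = list(true_labels)
    predicted_labels = list(predicted_labels)
    if len(true_labels) != len(predicted_labels):
        raise ValueError("label and prediction counts differ")
    if not true_labels:
        raise ValueError("cannot compute accuracy of an empty split")
    correct = sum(t == p for t, p in zip(true_labels, predicted_labels))
    return correct / len(true_labels)


print(accuracy([1, 0, 1, 1], [1, 0, 0, 1]))  # 0.75
```

In the script above, this could be called as `accuracy((d.data["y"].item() for d in mutag.test_split), result)`, assuming `result` holds one predicted label per test graph.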