You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

hetero_node_classification.py 656 B

123456789101112131415161718192021222324
  # Example: AutoGL heterogeneous node classification (HAN or HGT) on ACM.
  #
  # Usage: python hetero_node_classification.py --model {han,hgt} [--max_evals N]
  import argparse
  import os

  # Select AutoGL's DGL backend; this must be set before autogl is imported.
  os.environ["AUTOGL_BACKEND"] = 'dgl'


  def parse_args(argv=None):
      """Parse the command-line options for this example.

      Args:
          argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

      Returns:
          An ``argparse.Namespace`` with ``model`` ("han" or "hgt") and
          ``max_evals`` (int, default 10).

      Raises:
          SystemExit: On missing or invalid arguments (standard argparse behavior).
      """
      parser = argparse.ArgumentParser()
      # Required: the dataset name is derived from the model, so a missing
      # value would otherwise silently become "hetero-acm-None".
      parser.add_argument("--model", type=str, choices=["han", "hgt"], required=True)
      parser.add_argument("--max_evals", type=int, default=10)
      return parser.parse_args(argv)


  def main(argv=None):
      """Train and evaluate an AutoHeteroNodeClassifier; print and return accuracy.

      Args:
          argv: Argument list forwarded to ``parse_args``; ``None`` uses sys.argv.

      Returns:
          The value returned by ``solver.evaluate(metric='acc')``.
      """
      args = parse_args(argv)

      # Imported here so `--help` and argument errors don't need the AutoGL/DGL
      # stack; AUTOGL_BACKEND was already set at module import time above.
      from autogl.datasets import build_dataset_from_name
      from autogl.solver import AutoHeteroNodeClassifier

      dataset = build_dataset_from_name(f"hetero-acm-{args.model}")
      solver = AutoHeteroNodeClassifier(
          graph_models=(args.model,),
          # Honor the CLI flag (previously hard-coded to 10, ignoring --max_evals).
          max_evals=args.max_evals,
      )
      solver.fit(dataset)
      acc = solver.evaluate(metric='acc')
      print("acc: ", acc)
      return acc


  if __name__ == '__main__':
      main()