You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

node_classification.py 2.7 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. import sys
  2. sys.path.append("../")
  3. from autogl.datasets import build_dataset_from_name
  4. from autogl.solver import AutoNodeClassifier
  5. from autogl.module import Acc
  6. import yaml
  7. import random
  8. import torch
  9. import numpy as np
  10. if __name__ == "__main__":
  11. from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
  12. parser = ArgumentParser(
  13. "auto node classification", formatter_class=ArgumentDefaultsHelpFormatter
  14. )
  15. parser.add_argument(
  16. "--dataset",
  17. default="cora",
  18. type=str,
  19. help="dataset to use",
  20. choices=[
  21. "cora",
  22. "pubmed",
  23. "citeseer",
  24. "coauthor_cs",
  25. "coauthor_physics",
  26. "amazon_computers",
  27. "amazon_photo",
  28. ],
  29. )
  30. parser.add_argument(
  31. "--configs",
  32. type=str,
  33. default="../configs/nodeclf_gcn_benchmark_small.yml",
  34. help="config to use",
  35. )
  36. # following arguments will override parameters in the config file
  37. parser.add_argument("--hpo", type=str, default="tpe", help="hpo methods")
  38. parser.add_argument(
  39. "--max_eval", type=int, default=50, help="max hpo evaluation times"
  40. )
  41. parser.add_argument("--seed", type=int, default=0, help="random seed")
  42. parser.add_argument("--device", default=0, type=int, help="GPU device")
  43. args = parser.parse_args()
  44. if torch.cuda.is_available():
  45. torch.cuda.set_device(args.device)
  46. seed = args.seed
  47. # set random seed
  48. random.seed(seed)
  49. np.random.seed(seed)
  50. torch.manual_seed(seed)
  51. if torch.cuda.is_available():
  52. torch.cuda.manual_seed(seed)
  53. torch.backends.cudnn.deterministic = True
  54. torch.backends.cudnn.benchmark = False
  55. dataset = build_dataset_from_name(args.dataset)
  56. configs = yaml.load(open(args.configs, "r").read(), Loader=yaml.FullLoader)
  57. configs["hpo"]["name"] = args.hpo
  58. configs["hpo"]["max_evals"] = args.max_eval
  59. autoClassifier = AutoNodeClassifier.from_config(configs)
  60. # train
  61. if args.dataset in ["cora", "citeseer", "pubmed"]:
  62. autoClassifier.fit(dataset, time_limit=3600, evaluation_method=[Acc])
  63. else:
  64. autoClassifier.fit(
  65. dataset,
  66. time_limit=3600,
  67. evaluation_method=[Acc],
  68. seed=seed,
  69. train_split=20 * dataset.num_classes,
  70. val_split=30 * dataset.num_classes,
  71. balanced=False,
  72. )
  73. autoClassifier.get_leaderboard().show()
  74. # test
  75. predict_result = autoClassifier.predict_proba()
  76. print(
  77. "test acc: %.4f"
  78. % (Acc.evaluate(predict_result, dataset.data.y[dataset.data.test_mask].numpy()))
  79. )