You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

node_classification.py 2.7 kB

4 years ago
4 years ago
4 years ago
4 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. import yaml
  2. import random
  3. import torch.backends.cudnn
  4. import numpy as np
  5. from autogl.datasets import build_dataset_from_name
  6. from autogl.solver import AutoNodeClassifier
  7. from autogl.module import Acc
  8. from autogl.backend import DependentBackend
  9. if __name__ == "__main__":
  10. from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
  11. parser = ArgumentParser(
  12. "auto node classification", formatter_class=ArgumentDefaultsHelpFormatter
  13. )
  14. parser.add_argument(
  15. "--dataset",
  16. default="cora",
  17. type=str,
  18. help="dataset to use",
  19. choices=[
  20. "cora",
  21. "pubmed",
  22. "citeseer",
  23. "coauthor_cs",
  24. "coauthor_physics",
  25. "amazon_computers",
  26. "amazon_photo",
  27. ],
  28. )
  29. parser.add_argument(
  30. "--configs",
  31. type=str,
  32. default="../configs/nodeclf_gcn_benchmark_small.yml",
  33. help="config to use",
  34. )
  35. # following arguments will override parameters in the config file
  36. parser.add_argument("--hpo", type=str, default="tpe", help="hpo methods")
  37. parser.add_argument(
  38. "--max_eval", type=int, default=50, help="max hpo evaluation times"
  39. )
  40. parser.add_argument("--seed", type=int, default=0, help="random seed")
  41. parser.add_argument("--device", default=0, type=int, help="GPU device")
  42. args = parser.parse_args()
  43. if torch.cuda.is_available():
  44. torch.cuda.set_device(args.device)
  45. seed = args.seed
  46. # set random seed
  47. random.seed(seed)
  48. np.random.seed(seed)
  49. torch.manual_seed(seed)
  50. if torch.cuda.is_available():
  51. torch.cuda.manual_seed(seed)
  52. torch.backends.cudnn.deterministic = True
  53. torch.backends.cudnn.benchmark = False
  54. dataset = build_dataset_from_name(args.dataset)
  55. label = dataset[0].nodes.data["y" if DependentBackend.is_pyg() else "label"]
  56. num_classes = len(np.unique(label.numpy()))
  57. configs = yaml.load(open(args.configs, "r").read(), Loader=yaml.FullLoader)
  58. configs["hpo"]["name"] = args.hpo
  59. configs["hpo"]["max_evals"] = args.max_eval
  60. autoClassifier = AutoNodeClassifier.from_config(configs)
  61. # train
  62. if args.dataset in ["cora", "citeseer", "pubmed"]:
  63. autoClassifier.fit(dataset, time_limit=3600, evaluation_method=[Acc])
  64. else:
  65. autoClassifier.fit(
  66. dataset,
  67. time_limit=3600,
  68. evaluation_method=[Acc],
  69. seed=seed,
  70. train_split=20 * num_classes,
  71. val_split=30 * num_classes,
  72. balanced=False,
  73. )
  74. autoClassifier.get_leaderboard().show()
  75. acc = autoClassifier.evaluate(metric="acc")
  76. print("test acc: {:.4f}".format(acc))