You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

link_prediction.py 2.7 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
import random
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser

import numpy as np
import torch
import yaml

from autogl.datasets import build_dataset_from_name
from autogl.solver.classifier.link_predictor import AutoLinkPredictor
from autogl.module.train.evaluation import Auc
from autogl.datasets.utils import split_edges
from autogl.backend import DependentBackend
  # AutoGL example: automated link prediction driven by a YAML config plus CLI overrides.


def parse_args(argv=None):
    """Parse the command-line options of the link-prediction example.

    Args:
        argv: Argument list to parse, without the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with the attributes ``dataset``,
        ``configs``, ``hpo``, ``max_eval``, ``seed``, ``device`` and
        ``time_limit``.
    """
    parser = ArgumentParser(
        "auto link prediction", formatter_class=ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "--dataset",
        default="cora",
        type=str,
        help="dataset to use",
        choices=[
            "cora",
            "pubmed",
            "citeseer",
            "coauthor_cs",
            "coauthor_physics",
            "amazon_computers",
            "amazon_photo",
        ],
    )
    parser.add_argument(
        "--configs",
        type=str,
        default="../configs/lp_gcn_benchmark.yml",
        help="config to use",
    )
    # --hpo and --max_eval override the "hpo" section of the config file
    # (see apply_cli_overrides).
    parser.add_argument("--hpo", type=str, default="tpe", help="hpo methods")
    parser.add_argument(
        "--max_eval", type=int, default=50, help="max hpo evaluation times"
    )
    parser.add_argument("--seed", type=int, default=0, help="random seed")
    parser.add_argument("--device", default=0, type=int, help="GPU device")
    # Was hard-coded to 3600 in the fit() call. The same default keeps old
    # invocations unchanged while allowing short smoke runs.
    parser.add_argument(
        "--time_limit",
        type=int,
        default=3600,
        help="time limit passed to AutoLinkPredictor.fit",
    )
    return parser.parse_args(argv)


def set_random_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs (and CUDA when available)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        # Trade cuDNN autotuning speed for run-to-run reproducibility.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


def load_config(path):
    """Load and return the YAML configuration stored at ``path``.

    The file is closed deterministically. The original code leaked the handle
    via ``open(...).read()``.
    """
    with open(path, "r", encoding="utf-8") as config_file:
        # NOTE(review): FullLoader can construct some Python objects. Config
        # files are assumed to be trusted local files; switch to
        # yaml.safe_load if configs may come from untrusted sources.
        return yaml.load(config_file, Loader=yaml.FullLoader)


def apply_cli_overrides(configs, hpo, max_evals):
    """Write the CLI HPO settings into ``configs["hpo"]`` and return ``configs``.

    ``configs`` is modified in place. A missing or empty ``hpo`` section is
    created instead of raising ``KeyError`` / ``TypeError``. Any other keys
    already in the section are kept.

    Args:
        configs: Parsed configuration mapping.
        hpo: Value stored as ``configs["hpo"]["name"]``.
        max_evals: Value stored as ``configs["hpo"]["max_evals"]``.
    """
    hpo_section = configs.get("hpo") or {}
    hpo_section["name"] = hpo
    hpo_section["max_evals"] = max_evals
    configs["hpo"] = hpo_section
    return configs


def main(argv=None):
    """Run the AutoGL link-prediction pipeline and print the test AUC.

    Args:
        argv: Optional argument list forwarded to :func:`parse_args`.
    """
    args = parse_args(argv)
    if torch.cuda.is_available():
        torch.cuda.set_device(args.device)
    seed = args.seed
    set_random_seed(seed)

    # Load the config before the (potentially slow) dataset build, so a bad
    # --configs path fails immediately.
    configs = apply_cli_overrides(
        load_config(args.configs), args.hpo, args.max_eval
    )

    dataset = build_dataset_from_name(args.dataset)
    # Split the edges of the dataset. 0.8 / 0.05 are presumably the train and
    # validation fractions; confirm against autogl.datasets.utils.split_edges.
    dataset = split_edges(dataset, 0.8, 0.05)

    if DependentBackend.is_dgl():
        # DGL backend only: add self-loops to element 0 of the first dataset
        # item (presumably the training graph; TODO confirm).
        import dgl

        data = list(dataset[0])
        data[0] = dgl.add_self_loop(data[0])
        dataset = [data]

    auto_predictor = AutoLinkPredictor.from_config(configs)
    # train
    auto_predictor.fit(
        dataset,
        time_limit=args.time_limit,
        evaluation_method=[Auc],
        seed=seed,
    )
    auto_predictor.get_leaderboard().show()
    auc = auto_predictor.evaluate(metric="auc")
    print("test auc: {:.4f}".format(auc))


if __name__ == "__main__":
    main()