You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

graph_cv.py 3.4 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. """
  2. Auto graph classification using cross validation methods proposed in
  3. paper `A Fair Comparison of Graph Neural Networks for Graph Classification`
  4. """
  5. import sys
  6. import random
  7. import torch
  8. import numpy as np
  9. from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
  10. sys.path.append("../")
  11. from autogl.datasets import build_dataset_from_name, utils
  12. from autogl.solver import AutoGraphClassifier
  13. from autogl.module import Acc
  14. if __name__ == "__main__":
  15. parser = ArgumentParser(
  16. "auto graph classification", formatter_class=ArgumentDefaultsHelpFormatter
  17. )
  18. parser.add_argument(
  19. "--dataset",
  20. default="mutag",
  21. type=str,
  22. help="graph classification dataset",
  23. choices=["mutag", "imdb-b", "imdb-m", "proteins", "collab"],
  24. )
  25. parser.add_argument(
  26. "--configs", default="../configs/graphclf_full.yml", help="config files"
  27. )
  28. parser.add_argument("--device", type=int, default=0, help="device to run on")
  29. parser.add_argument("--seed", type=int, default=0, help="random seed")
  30. parser.add_argument("--folds", type=int, default=10, help="fold number")
  31. args = parser.parse_args()
  32. if torch.cuda.is_available():
  33. torch.cuda.set_device(args.device)
  34. seed = args.seed
  35. # set random seed
  36. random.seed(seed)
  37. np.random.seed(seed)
  38. torch.manual_seed(seed)
  39. if torch.cuda.is_available():
  40. torch.cuda.manual_seed(seed)
  41. torch.backends.cudnn.deterministic = True
  42. torch.backends.cudnn.benchmark = False
  43. print("begin processing dataset", args.dataset, "into", args.folds, "folds.")
  44. dataset = build_dataset_from_name(args.dataset)
  45. if args.dataset.startswith("imdb"):
  46. from autogl.module.feature.generators import PYGOneHotDegree
  47. # get max degree
  48. from torch_geometric.utils import degree
  49. max_degree = 0
  50. for data in dataset:
  51. deg_max = int(degree(data.edge_index[0], data.num_nodes).max().item())
  52. max_degree = max(max_degree, deg_max)
  53. dataset = PYGOneHotDegree(max_degree).fit_transform(dataset, inplace=False)
  54. elif args.dataset == "collab":
  55. from autogl.module.feature.auto_feature import Onlyconst
  56. dataset = Onlyconst().fit_transform(dataset, inplace=False)
  57. utils.graph_cross_validation(dataset, args.folds, random_seed=args.seed)
  58. accs = []
  59. for fold in range(args.folds):
  60. print("evaluating on fold number:", fold)
  61. utils.graph_set_fold_id(dataset, fold)
  62. train_dataset = utils.graph_get_split(dataset, "train", False)
  63. autoClassifier = AutoGraphClassifier.from_config(args.configs)
  64. autoClassifier.fit(
  65. train_dataset,
  66. train_split=0.9,
  67. val_split=0.1,
  68. seed=args.seed,
  69. evaluation_method=[Acc],
  70. )
  71. predict_result = autoClassifier.predict_proba(dataset, mask="val")
  72. acc = Acc.evaluate(
  73. predict_result, dataset.data.y[dataset.val_index].cpu().detach().numpy()
  74. )
  75. print(
  76. "test acc %.4f"
  77. % (
  78. Acc.evaluate(
  79. predict_result,
  80. dataset.data.y[dataset.val_index].cpu().detach().numpy(),
  81. )
  82. )
  83. )
  84. accs.append(acc)
  85. print("Average acc on", args.dataset, ":", np.mean(accs), "~", np.std(accs))