You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

graph_cv.py 3.6 kB

5 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. """
  2. Auto graph classification using cross validation methods proposed in
  3. paper `A Fair Comparison of Graph Neural Networks for Graph Classification`
  4. """
  5. import sys
  6. import random
  7. import torch
  8. import numpy as np
  9. from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
  10. sys.path.append("../")
  11. from autogl.datasets import build_dataset_from_name, utils
  12. from autogl.solver import AutoGraphClassifier
  13. from autogl.module import Acc
  14. from autogl.backend import DependentBackend
  15. if DependentBackend.is_pyg():
  16. from autogl.datasets.utils.conversion import to_pyg_dataset as convert_dataset
  17. else:
  18. from autogl.datasets.utils.conversion import to_dgl_dataset as convert_dataset
  19. if __name__ == "__main__":
  20. parser = ArgumentParser(
  21. "auto graph classification", formatter_class=ArgumentDefaultsHelpFormatter
  22. )
  23. parser.add_argument(
  24. "--dataset",
  25. default="mutag",
  26. type=str,
  27. help="graph classification dataset",
  28. choices=["mutag", "imdb-b", "imdb-m", "proteins", "collab"],
  29. )
  30. parser.add_argument(
  31. "--configs", default="../configs/graphclf_full.yml", help="config files"
  32. )
  33. parser.add_argument("--device", type=int, default=0, help="device to run on")
  34. parser.add_argument("--seed", type=int, default=0, help="random seed")
  35. parser.add_argument("--folds", type=int, default=10, help="fold number")
  36. args = parser.parse_args()
  37. if torch.cuda.is_available():
  38. torch.cuda.set_device(args.device)
  39. seed = args.seed
  40. # set random seed
  41. random.seed(seed)
  42. np.random.seed(seed)
  43. torch.manual_seed(seed)
  44. if torch.cuda.is_available():
  45. torch.cuda.manual_seed(seed)
  46. torch.backends.cudnn.deterministic = True
  47. torch.backends.cudnn.benchmark = False
  48. print("begin processing dataset", args.dataset, "into", args.folds, "folds.")
  49. dataset = build_dataset_from_name(args.dataset)
  50. _converted_dataset = convert_dataset(dataset)
  51. if args.dataset.startswith("imdb"):
  52. from autogl.module.feature import OneHotDegreeGenerator
  53. if DependentBackend.is_pyg():
  54. from torch_geometric.utils import degree
  55. max_degree = 0
  56. for data in _converted_dataset:
  57. deg_max = int(degree(data.edge_index[0], data.num_nodes).max().item())
  58. max_degree = max(max_degree, deg_max)
  59. else:
  60. max_degree = 0
  61. for data, _ in _converted_dataset:
  62. deg_max = data.in_degrees().max().item()
  63. max_degree = max(max_degree, deg_max)
  64. dataset = OneHotDegreeGenerator(max_degree).fit_transform(dataset, inplace=False)
  65. elif args.dataset == "collab":
  66. from autogl.module.feature._auto_feature import OnlyConstFeature
  67. dataset = OnlyConstFeature().fit_transform(dataset, inplace=False)
  68. utils.graph_cross_validation(dataset, args.folds, random_seed=args.seed)
  69. accs = []
  70. for fold in range(args.folds):
  71. print("evaluating on fold number:", fold)
  72. utils.set_fold(dataset, fold)
  73. train_dataset = utils.graph_get_split(dataset, "train", False)
  74. autoClassifier = AutoGraphClassifier.from_config(args.configs)
  75. autoClassifier.fit(
  76. train_dataset,
  77. train_split=0.9,
  78. val_split=0.1,
  79. seed=args.seed,
  80. evaluation_method=[Acc],
  81. )
  82. acc = autoClassifier.evaluate(dataset, mask="val", metric="acc")
  83. print("test acc fold {:d}: {:.4f}".format(fold, acc))
  84. accs.append(acc)
  85. print("Average acc on", args.dataset, ":", np.mean(accs), "~", np.std(accs))