You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

graph_classification.py 3.4 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. """
  2. Example of graph classification on given datasets.
  3. This version use random split to only show the usage of AutoGraphClassifier.
  4. Refer to `graph_cv.py` for cross validation evaluation of the whole system
  5. following paper `A Fair Comparison of Graph Neural Networks for Graph Classification`
  6. """
  7. import random
  8. import torch
  9. import numpy as np
  10. from autogl.datasets import build_dataset_from_name, utils
  11. from autogl.solver import AutoGraphClassifier
  12. from autogl.module import Acc
  13. from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
  14. from autogl.backend import DependentBackend
  15. if DependentBackend.is_pyg():
  16. from autogl.datasets.utils.conversion import to_pyg_dataset as convert_dataset
  17. else:
  18. from autogl.datasets.utils.conversion import to_dgl_dataset as convert_dataset
  19. backend = DependentBackend.get_backend_name()
  20. if __name__ == "__main__":
  21. parser = ArgumentParser(
  22. "auto graph classification", formatter_class=ArgumentDefaultsHelpFormatter
  23. )
  24. parser.add_argument(
  25. "--dataset",
  26. default="mutag",
  27. type=str,
  28. help="graph classification dataset",
  29. choices=["mutag", "imdb-b", "imdb-m", "proteins", "collab"],
  30. )
  31. parser.add_argument(
  32. "--configs", default="../configs/graphclf_gin_benchmark.yml", help="config files"
  33. )
  34. parser.add_argument("--device", type=int, default=-1, help="device to run on, -1 means cpu")
  35. parser.add_argument("--seed", type=int, default=0, help="random seed")
  36. args = parser.parse_args()
  37. if args.device == -1:
  38. args.device = "cpu"
  39. if torch.cuda.is_available() and args.device != "cpu":
  40. torch.cuda.set_device(args.device)
  41. seed = args.seed
  42. # set random seed
  43. random.seed(seed)
  44. np.random.seed(seed)
  45. torch.manual_seed(seed)
  46. if torch.cuda.is_available():
  47. torch.cuda.manual_seed(seed)
  48. torch.backends.cudnn.deterministic = True
  49. torch.backends.cudnn.benchmark = False
  50. dataset = build_dataset_from_name(args.dataset)
  51. _converted_dataset = convert_dataset(dataset)
  52. if args.dataset.startswith("imdb"):
  53. from autogl.module.feature import OneHotDegreeGenerator
  54. if DependentBackend.is_pyg():
  55. from torch_geometric.utils import degree
  56. max_degree = 0
  57. for data in _converted_dataset:
  58. deg_max = int(degree(data.edge_index[0], data.num_nodes).max().item())
  59. max_degree = max(max_degree, deg_max)
  60. else:
  61. max_degree = 0
  62. for data, _ in _converted_dataset:
  63. deg_max = data.in_degrees().max().item()
  64. max_degree = max(max_degree, deg_max)
  65. dataset = OneHotDegreeGenerator(max_degree).fit_transform(dataset, inplace=False)
  66. elif args.dataset == "collab":
  67. from autogl.module.feature._auto_feature import OnlyConstFeature
  68. dataset = OnlyConstFeature().fit_transform(dataset, inplace=False)
  69. utils.graph_random_splits(dataset, train_ratio=0.8, val_ratio=0.1, seed=args.seed)
  70. autoClassifier = AutoGraphClassifier.from_config(args.configs)
  71. # train
  72. autoClassifier.fit(dataset, evaluation_method=[Acc], seed=args.seed)
  73. autoClassifier.get_leaderboard().show()
  74. print("best single model:\n", autoClassifier.get_leaderboard().get_best_model(0))
  75. # test
  76. acc = autoClassifier.evaluate(metric="acc")
  77. print("test acc {:.4f}".format(acc))