You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

graph_classification.py 2.9 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. """
  2. Example of graph classification on given datasets.
  3. This version uses a random split only to show the usage of AutoGraphClassifier.
  4. Refer to `graph_cv.py` for cross-validation evaluation of the whole system,
  5. following the paper `A Fair Comparison of Graph Neural Networks for Graph Classification`.
  6. """
  7. import sys
  8. sys.path.append("../")
  9. import random
  10. import torch
  11. import numpy as np
  12. from autogl.datasets import build_dataset_from_name, utils
  13. from autogl.solver import AutoGraphClassifier
  14. from autogl.module import Acc
  15. from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
  16. if __name__ == "__main__":
  17. parser = ArgumentParser(
  18. "auto graph classification", formatter_class=ArgumentDefaultsHelpFormatter
  19. )
  20. parser.add_argument(
  21. "--dataset",
  22. default="mutag",
  23. type=str,
  24. help="graph classification dataset",
  25. choices=["mutag", "imdb-b", "imdb-m", "proteins", "collab"],
  26. )
  27. parser.add_argument(
  28. "--configs", default="../configs/graphclf_full.yml", help="config files"
  29. )
  30. parser.add_argument("--device", type=int, default=0, help="device to run on")
  31. parser.add_argument("--seed", type=int, default=0, help="random seed")
  32. args = parser.parse_args()
  33. if torch.cuda.is_available():
  34. torch.cuda.set_device(args.device)
  35. seed = args.seed
  36. # set random seed
  37. random.seed(seed)
  38. np.random.seed(seed)
  39. torch.manual_seed(seed)
  40. if torch.cuda.is_available():
  41. torch.cuda.manual_seed(seed)
  42. torch.backends.cudnn.deterministic = True
  43. torch.backends.cudnn.benchmark = False
  44. dataset = build_dataset_from_name(args.dataset)
  45. if args.dataset.startswith("imdb"):
  46. from autogl.module.feature.generators import PYGOneHotDegree
  47. # get max degree
  48. from torch_geometric.utils import degree
  49. max_degree = 0
  50. for data in dataset:
  51. deg_max = int(degree(data.edge_index[0], data.num_nodes).max().item())
  52. max_degree = max(max_degree, deg_max)
  53. dataset = PYGOneHotDegree(max_degree).fit_transform(dataset, inplace=False)
  54. elif args.dataset == "collab":
  55. from autogl.module.feature.auto_feature import Onlyconst
  56. dataset = Onlyconst().fit_transform(dataset, inplace=False)
  57. utils.graph_random_splits(dataset, train_ratio=0.8, val_ratio=0.1, seed=args.seed)
  58. autoClassifier = AutoGraphClassifier.from_config(args.configs)
  59. # train
  60. autoClassifier.fit(dataset, evaluation_method=[Acc], seed=args.seed)
  61. autoClassifier.get_leaderboard().show()
  62. print("best single model:\n", autoClassifier.get_leaderboard().get_best_model(0))
  63. # test
  64. predict_result = autoClassifier.predict_proba()
  65. print(
  66. "test acc %.4f"
  67. % (
  68. Acc.evaluate(
  69. predict_result,
  70. dataset.data.y[dataset.test_index].cpu().detach().numpy(),
  71. )
  72. )
  73. )