You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

solver.py 3.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. """
  2. Performance check of AutoGL Solver
  3. """
  4. import os
  5. os.environ["AUTOGL_BACKEND"] = "pyg"
  6. import random
  7. import numpy as np
  8. from tqdm import tqdm
  9. from autogl.solver import AutoGraphClassifier
  10. from autogl.datasets import build_dataset_from_name, utils
  11. from autogl.solver.utils import set_seed
  12. import logging
  13. logging.basicConfig(level=logging.ERROR)
  def graph_get_split(dataset, mask, is_loader=True, batch_size=128, num_workers=0):
      """Return the ``<mask>_split`` subset of *dataset*, optionally as a loader.

      Args:
          dataset: Object exposing a ``<mask>_split`` attribute
              (for example ``test_split``).
          mask: Prefix of the split attribute to read, e.g. ``'test'``.
          is_loader: When true, wrap the split in a ``DataLoader``; otherwise
              return the split object itself.
          batch_size: Batch size passed to the ``DataLoader``.
          num_workers: Worker count passed to the ``DataLoader``.

      Returns:
          The split attribute unchanged, or a ``DataLoader`` built over it.

      Raises:
          AttributeError: If *dataset* has no ``<mask>_split`` attribute.
          ImportError: If ``is_loader`` is true and ``torch_geometric`` cannot
              be imported.
      """
      out = getattr(dataset, f'{mask}_split')
      if is_loader:
          # Imported lazily: the raw-split path (is_loader=False) never needs
          # torch_geometric, so it should not fail when the package is absent.
          from torch_geometric.data import DataLoader
          out = DataLoader(out, batch_size, num_workers=num_workers)
      return out
  # Replace ``graph_get_split`` on the imported AutoGL ``utils`` module with the
  # version defined in this file, so any code that looks the function up on
  # ``utils`` afterwards gets this implementation.
  utils.graph_get_split = graph_get_split
  def fixed(**kwargs):
      """Describe a search space in which every parameter has a single value.

      Each keyword argument becomes one entry of the form
      ``{'parameterName': <name>, 'type': 'FIXED', 'value': <value>}``.

      Args:
          **kwargs: Parameter names mapped to the value each one is pinned to.

      Returns:
          A list with one entry per keyword argument, in the order given.
          An empty list when no keyword arguments are passed.
      """
      space = []
      for name, value in kwargs.items():
          entry = {
              'parameterName': name,
              "type": "FIXED",
              "value": value,
          }
          space.append(entry)
      return space
  if __name__ == '__main__':
      # Script entry point: train the chosen graph classifier through the AutoGL
      # solver ``--repeat`` times on one fixed dataset split and report the mean
      # and standard deviation of the test accuracy.
      import argparse

      dataset_names = [
          'MUTAG', 'COLLAB', 'IMDBBINARY', 'IMDBMULTI', 'NCI1',
          'PROTEINS', 'PTC', 'REDDITBINARY', 'REDDITMULTI5K',
      ]

      parser = argparse.ArgumentParser('pyg solver')
      parser.add_argument('--device', type=str, default='cuda')
      parser.add_argument(
          '--dataset',
          type=str,
          choices=[name.lower() for name in dataset_names],
          default='mutag',
      )
      parser.add_argument('--dataset_seed', type=int, default=2021)
      parser.add_argument('--batch_size', type=int, default=32)
      parser.add_argument('--repeat', type=int, default=50)
      parser.add_argument('--model', type=str, choices=['gin', 'topkpool'], default='gin')
      parser.add_argument('--lr', type=float, default=0.0001)
      parser.add_argument('--epoch', type=int, default=100)
      args = parser.parse_args()

      dataset = build_dataset_from_name(args.dataset)

      # 1. Split the dataset once with a dedicated seed, so every repetition
      #    below trains and evaluates on the same partition.
      num_graphs = len(dataset)
      dataids = list(range(num_graphs))
      random.seed(args.dataset_seed)
      random.shuffle(dataids)

      # One fold is 10% of the data (rounded down): 8 folds for training,
      # 1 for validation, and whatever remains for testing.
      fold = int(num_graphs * 0.1)
      train_end = fold * 8
      val_end = fold * 9
      dataset.train_index = dataids[:train_end]
      dataset.val_index = dataids[train_end:val_end]
      dataset.test_index = dataids[val_end:]
      dataset.loss = 'nll_loss'

      # Ground-truth labels of the test graphs, gathered once for all runs.
      labels = np.array([graph.data['y'].item() for graph in dataset.test_split])

      # 2. Repeat the experiment, using the loop index as the run's seed.
      accs = []
      for seed in tqdm(range(args.repeat)):
          set_seed(seed)

          # hp from trainer
          # NOTE(review): early_stopping_round exceeds max_epoch, which
          # presumably disables early stopping -- confirm against the trainer.
          trainer_space = fixed(
              max_epoch=args.epoch,
              batch_size=args.batch_size,
              early_stopping_round=args.epoch + 1,
              lr=args.lr,
              weight_decay=0,
          )

          # hp from model
          if args.model == 'gin':
              model_space = fixed(
                  num_layers=5,
                  hidden=[64, 64, 64, 64],
                  dropout=0.5,
                  act="relu",
                  eps="False",
                  mlp_layers=2,
                  neighbor_pooling_type="sum",
                  graph_pooling_type="sum",
              )
          else:
              model_space = fixed(
                  ratio=0.8,
                  dropout=0.5,
                  act="relu",
              )

          # Every hyper-parameter is fixed, so the single evaluation allowed by
          # max_evals=1 trains exactly this configuration.
          solver = AutoGraphClassifier(
              feature_module=None,
              graph_models=[args.model],
              hpo_module='random',
              ensemble_module=None,
              device=args.device,
              max_evals=1,
              trainer_hp_space=trainer_space,
              model_hp_spaces=[model_space],
          )
          solver.fit(dataset, evaluation_method=['acc'])

          predictions = solver.predict(dataset, mask='test')
          accs.append((predictions == labels).astype('float').mean())

      # 3. Summary over all repetitions: "mean ~ std" of the test accuracy.
      print('{:.4f} ~ {:.4f}'.format(np.mean(accs), np.std(accs)))