You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

solver.py 3.2 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  """
  Performance check of AutoGL Solver

  Runs AutoGraphClassifier (DGL backend) on a graph-classification dataset
  with fully fixed hyper-parameters, repeated over many random seeds on one
  fixed train/val/test split, and prints the test accuracy as ``mean ~ std``.
  """
  import logging
  import os
  import random

  # AutoGL chooses its graph backend from AUTOGL_BACKEND when it is imported,
  # so this must be set before any ``autogl`` import below.
  os.environ["AUTOGL_BACKEND"] = "dgl"

  import numpy as np
  from tqdm import tqdm

  from autogl.datasets import build_dataset_from_name
  from autogl.solver import AutoGraphClassifier
  from autogl.solver.utils import set_seed

  # Configure root logging at ERROR level to keep benchmark output quiet.
  logging.basicConfig(level=logging.ERROR)
  def fixed(**kwargs):
      """Build an AutoGL hyper-parameter space in which every parameter is pinned.

      Each keyword argument becomes one ``FIXED`` entry, in argument order,
      e.g. ``fixed(lr=0.01)`` gives
      ``[{"parameterName": "lr", "type": "FIXED", "value": 0.01}]``.

      Args:
          **kwargs: parameter name -> the single value it should take.

      Returns:
          list[dict]: one search-space entry per keyword argument; an empty
          list when called without arguments.
      """
      return [
          {"parameterName": name, "type": "FIXED", "value": value}
          for name, value in kwargs.items()
      ]
  def _split_dataset(dataset, seed):
      """Attach a reproducible 80/10/10 train/val/test split to ``dataset``.

      The shuffle depends only on ``seed``, so every repetition of the
      benchmark is evaluated on exactly the same test graphs.

      Args:
          dataset: an AutoGL dataset; its ``train_index``, ``val_index`` and
              ``test_index`` attributes are overwritten.
          seed: seed for the ``random``-module shuffle (this also reseeds the
              global ``random`` state as a side effect).
      """
      ids = list(range(len(dataset)))
      random.seed(seed)
      random.shuffle(ids)
      fold = int(len(dataset) * 0.1)
      dataset.train_index = ids[:fold * 8]
      dataset.val_index = ids[fold * 8:fold * 9]
      # Everything past 90% -- including the remainder lost to int()
      # truncation -- goes to the test split.
      dataset.test_index = ids[fold * 9:]


  def _trainer_hp_space(args):
      """Return the fixed trainer hyper-parameters taken from the parsed CLI ``args``."""
      return fixed(
          max_epoch=args.epoch,
          batch_size=args.batch_size,
          # Patience one larger than max_epoch, so early stopping cannot end
          # training before the last epoch (assuming patience is in epochs).
          early_stopping_round=args.epoch + 1,
          lr=args.lr,
          weight_decay=0,
      )


  def _model_hp_space(model):
      """Return the fixed architecture hyper-parameters for ``model`` ('gin' or 'topkpool')."""
      if model == 'gin':
          return fixed(
              num_layers=5,
              hidden=[64],
              dropout=0.5,
              act="relu",
              # NOTE(review): passed as the string "False", not a bool --
              # confirm this is the form AutoGL's GIN expects.
              eps="False",
              mlp_layers=2,
              neighbor_pooling_type="sum",
              graph_pooling_type="sum",
          )
      return fixed(num_layers=5, hidden=[64], dropout=0.5)


  if __name__ == '__main__':
      import argparse

      parser = argparse.ArgumentParser('dgl solver')
      parser.add_argument('--device', type=str, default='cuda')
      parser.add_argument(
          '--dataset', type=str, default='mutag',
          choices=[x.lower() for x in ['MUTAG', 'COLLAB', 'IMDBBINARY', 'IMDBMULTI', 'NCI1',
                                       'PROTEINS', 'PTC', 'REDDITBINARY', 'REDDITMULTI5K']],
      )
      parser.add_argument('--dataset_seed', type=int, default=2021,
                          help='seed for the fixed train/val/test split')
      parser.add_argument('--batch_size', type=int, default=32)
      parser.add_argument('--repeat', type=int, default=50,
                          help='number of training runs (seeds 0..repeat-1)')
      parser.add_argument('--model', type=str, choices=['gin', 'topkpool'], default='gin')
      parser.add_argument('--lr', type=float, default=0.0001)
      parser.add_argument('--epoch', type=int, default=100)
      args = parser.parse_args()

      dataset = build_dataset_from_name(args.dataset)
      # 1. split dataset once, with a fixed seed, so all repeats share the split
      _split_dataset(dataset, args.dataset_seed)
      labels = np.array([x.data['label'].item() for x in dataset.test_split])

      # 2. train/evaluate once per seed with hyper-parameters fully fixed
      #    (max_evals=1 and FIXED spaces: the HPO step has nothing to search)
      accs = []
      for seed in tqdm(range(args.repeat)):
          set_seed(seed)
          solver = AutoGraphClassifier(
              feature_module=None,
              graph_models=[args.model],
              hpo_module='random',
              ensemble_module=None,
              device=args.device,
              max_evals=1,
              trainer_hp_space=_trainer_hp_space(args),
              model_hp_spaces=[_model_hp_space(args.model)],
          )
          solver.fit(dataset, evaluation_method=['acc'])
          out = solver.predict(dataset, mask='test')
          accs.append((out == labels).astype('float').mean())
      print('{:.4f} ~ {:.4f}'.format(np.mean(accs), np.std(accs)))