You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

solver.py 3.0 kB

4 years ago
4 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. """
  2. Performance check of AutoGL Solver
  3. """
  4. import os
  5. os.environ["AUTOGL_BACKEND"] = "dgl"
  6. import random
  7. import numpy as np
  8. from tqdm import tqdm
  9. from autogl.solver import AutoGraphClassifier
  10. from autogl.datasets import build_dataset_from_name
  11. from autogl.solver.utils import set_seed
  12. from helper import get_encoder_decoder_hp
  13. import logging
  14. logging.basicConfig(level=logging.ERROR)
  15. def fixed(**kwargs):
  16. return [{
  17. 'parameterName': k,
  18. "type": "FIXED",
  19. "value": v
  20. } for k, v in kwargs.items()]
  21. if __name__ == '__main__':
  22. import argparse
  23. parser = argparse.ArgumentParser('dgl solver')
  24. parser.add_argument('--device', type=str, default='cuda')
  25. parser.add_argument('--dataset', type=str, choices=['MUTAG', 'COLLAB', 'IMDBBINARY', 'IMDBMULTI', 'NCI1', 'PROTEINS', 'PTC', 'REDDITBINARY', 'REDDITMULTI5K'], default='MUTAG')
  26. parser.add_argument('--dataset_seed', type=int, default=2021)
  27. parser.add_argument('--batch_size', type=int, default=32)
  28. parser.add_argument('--repeat', type=int, default=50)
  29. parser.add_argument('--model', type=str, choices=['gin', 'gat', 'gcn', 'sage'], default='gin')
  30. parser.add_argument('--lr', type=float, default=0.0001)
  31. parser.add_argument('--epoch', type=int, default=100)
  32. args = parser.parse_args()
  33. # seed = 100
  34. dataset = build_dataset_from_name(args.dataset.lower())
  35. # 1. split dataset [fix split]
  36. dataids = list(range(len(dataset)))
  37. random.seed(args.dataset_seed)
  38. random.shuffle(dataids)
  39. fold = int(len(dataset) * 0.1)
  40. dataset.train_index = dataids[:fold * 8]
  41. dataset.val_index = dataids[fold * 8: fold * 9]
  42. dataset.test_index = dataids[fold * 9: ]
  43. labels = np.array([x.data['label'].item() for x in dataset.test_split])
  44. if args.model == "gin":
  45. decoder = "JKSumPoolMLP"
  46. else:
  47. decoder = "sumpoolmlp"
  48. model_hp, decoder_hp = get_encoder_decoder_hp(args.model, decoder)
  49. accs = []
  50. for seed in tqdm(range(args.repeat)):
  51. solver = AutoGraphClassifier(
  52. feature_module=None,
  53. graph_models=[(args.model, decoder)],
  54. hpo_module='random',
  55. ensemble_module=None,
  56. device=args.device, max_evals=1,
  57. trainer_hp_space = fixed(**{
  58. # hp from trainer
  59. "max_epoch": args.epoch,
  60. "batch_size": args.batch_size,
  61. "early_stopping_round": args.epoch + 1,
  62. "lr": args.lr,
  63. "weight_decay": 0,
  64. }),
  65. model_hp_spaces=[{"encoder": fixed(**model_hp), "decoder": fixed(**decoder_hp)}]
  66. )
  67. solver.fit(dataset, evaluation_method=['acc'], seed=seed)
  68. out = solver.predict(dataset, mask='test')
  69. acc = (out == labels).astype('float').mean()
  70. accs.append(acc)
  71. print('{:.2f} ~ {:.2f}'.format(np.mean(accs) * 100, np.std(accs) * 100))