You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

solver.py 2.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. """
  2. Performance check of AutoGL trainer + PYG dataset
  3. """
  4. import os
  5. import numpy as np
  6. from tqdm import tqdm
  7. os.environ["AUTOGL_BACKEND"] = "pyg"
  8. from autogl.module.feature import NormalizeFeatures
  9. from autogl.solver import AutoNodeClassifier
  10. from autogl.datasets import utils, build_dataset_from_name
  11. from autogl.solver.utils import set_seed
  12. import logging
  13. logging.basicConfig(level=logging.ERROR)
  14. def fixed(**kwargs):
  15. return [{
  16. 'parameterName': k,
  17. "type": "FIXED",
  18. "value": v
  19. } for k, v in kwargs.items()]
  20. if __name__ == '__main__':
  21. import argparse
  22. parser = argparse.ArgumentParser('pyg model')
  23. parser.add_argument('--device', type=str, default='cuda')
  24. parser.add_argument('--dataset', type=str, choices=['Cora', 'CiteSeer', 'PubMed'], default='Cora')
  25. parser.add_argument('--repeat', type=int, default=50)
  26. parser.add_argument('--model', type=str, choices=['gat', 'gcn', 'sage', 'gin'], default='gat')
  27. parser.add_argument('--lr', type=float, default=0.01)
  28. parser.add_argument('--weight_decay', type=float, default=0.0)
  29. parser.add_argument('--epoch', type=int, default=200)
  30. args = parser.parse_args()
  31. # seed = 100
  32. dataset = build_dataset_from_name('cora')
  33. label = dataset[0].nodes.data['y'][dataset[0].nodes.data['test_mask']].numpy()
  34. accs = []
  35. for seed in tqdm(range(args.repeat)):
  36. if args.model == 'gat':
  37. model_hp = {
  38. # hp from model
  39. "num_layers": 2,
  40. "hidden": [8],
  41. "heads": 8,
  42. "dropout": 0.6,
  43. "act": "elu",
  44. }
  45. elif args.model == 'gcn':
  46. model_hp = {
  47. "num_layers": 2,
  48. "hidden": [16],
  49. "dropout": 0.5,
  50. "act": "relu"
  51. }
  52. elif args.model == 'sage':
  53. model_hp = {
  54. "num_layers": 2,
  55. "hidden": [64],
  56. "dropout": 0.5,
  57. "act": "relu",
  58. "agg": "mean",
  59. }
  60. solver = AutoNodeClassifier(
  61. feature_module='NormalizeFeatures',
  62. graph_models=(args.model,),
  63. ensemble_module=None,
  64. max_evals=1,
  65. hpo_module='random',
  66. trainer_hp_space=fixed(**{
  67. "max_epoch": args.epoch,
  68. "early_stopping_round": args.epoch + 1,
  69. "lr": args.lr,
  70. "weight_decay": args.weight_decay,
  71. }),
  72. model_hp_spaces=[fixed(**model_hp)]
  73. )
  74. solver.fit(dataset, seed=seed)
  75. output = solver.predict(dataset)
  76. acc = (output == label).astype('float').mean()
  77. accs.append(acc)
  78. print('{:.4f} ~ {:.4f}'.format(np.mean(accs), np.std(accs)))