You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

solver.py 2.8 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. """
  2. Performance check of AutoGL trainer + PYG dataset
  3. """
  4. import os
  5. import numpy as np
  6. from tqdm import tqdm
  7. os.environ["AUTOGL_BACKEND"] = "pyg"
  8. from autogl.module.feature import NormalizeFeatures
  9. from autogl.solver import AutoNodeClassifier
  10. from autogl.datasets import utils, build_dataset_from_name
  11. from autogl.solver.utils import set_seed
  12. import logging
  13. logging.basicConfig(level=logging.ERROR)
  14. def fixed(**kwargs):
  15. return [{
  16. 'parameterName': k,
  17. "type": "FIXED",
  18. "value": v
  19. } for k, v in kwargs.items()]
  20. if __name__ == '__main__':
  21. import argparse
  22. parser = argparse.ArgumentParser('pyg model')
  23. parser.add_argument('--device', type=str, default='cuda')
  24. parser.add_argument('--dataset', type=str, choices=['Cora', 'CiteSeer', 'PubMed'], default='Cora')
  25. parser.add_argument('--repeat', type=int, default=50)
  26. parser.add_argument('--model', type=str, choices=['gat', 'gcn', 'sage'], default='gat')
  27. parser.add_argument('--lr', type=float, default=0.01)
  28. parser.add_argument('--weight_decay', type=float, default=0.0)
  29. parser.add_argument('--epoch', type=int, default=200)
  30. args = parser.parse_args()
  31. # seed = 100
  32. dataset = build_dataset_from_name('cora')
  33. label = dataset[0].nodes.data['y'][dataset[0].nodes.data['test_mask']].numpy()
  34. accs = []
  35. for seed in tqdm(range(args.repeat)):
  36. set_seed(seed)
  37. if args.model == 'gat':
  38. model_hp = {
  39. # hp from model
  40. "num_layers": 2,
  41. "hidden": [8],
  42. "heads": 8,
  43. "dropout": 0.6,
  44. "act": "elu",
  45. }
  46. elif args.model == 'gcn':
  47. model_hp = {
  48. "num_layers": 2,
  49. "hidden": [16],
  50. "dropout": 0.5,
  51. "act": "relu"
  52. }
  53. elif args.model == 'sage':
  54. model_hp = {
  55. "num_layers": 2,
  56. "hidden": [64],
  57. "dropout": 0.5,
  58. "act": "relu",
  59. "agg": "mean",
  60. }
  61. solver = AutoNodeClassifier(
  62. feature_module='NormalizeFeatures',
  63. graph_models=(args.model,),
  64. ensemble_module=None,
  65. max_evals=1,
  66. hpo_module='random',
  67. trainer_hp_space=fixed(**{
  68. "max_epoch": args.epoch,
  69. "early_stopping_round": args.epoch + 1,
  70. "lr": args.lr,
  71. "weight_decay": args.weight_decay,
  72. }),
  73. model_hp_spaces=[fixed(**model_hp)]
  74. )
  75. solver.fit(dataset)
  76. output = solver.predict(dataset)
  77. acc = (output == label).astype('float').mean()
  78. accs.append(acc)
  79. print('{:.4f} ~ {:.4f}'.format(np.mean(accs), np.std(accs)))