You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

solver.py 2.7 kB

(line numbers 1–88)
  1. """
  2. Performance check of AutoGL solver
  3. """
  4. import os
  5. import numpy as np
  6. from tqdm import tqdm
  7. os.environ["AUTOGL_BACKEND"] = "dgl"
  8. from autogl.solver import AutoNodeClassifier
  9. from autogl.datasets import build_dataset_from_name
  10. import logging
  11. logging.basicConfig(level=logging.ERROR)
  12. def fixed(**kwargs):
  13. return [{
  14. 'parameterName': k,
  15. "type": "FIXED",
  16. "value": v
  17. } for k, v in kwargs.items()]
  18. if __name__ == '__main__':
  19. import argparse
  20. parser = argparse.ArgumentParser('dgl solver')
  21. parser.add_argument('--device', type=str, default='cuda')
  22. parser.add_argument('--dataset', type=str, choices=['Cora', 'CiteSeer', 'PubMed'], default='Cora')
  23. parser.add_argument('--repeat', type=int, default=50)
  24. parser.add_argument('--model', type=str, choices=['gat', 'gcn', 'sage'], default='gat')
  25. parser.add_argument('--lr', type=float, default=0.01)
  26. parser.add_argument('--weight_decay', type=float, default=0.0)
  27. parser.add_argument('--epoch', type=int, default=200)
  28. args = parser.parse_args()
  29. # seed = 100
  30. dataset = build_dataset_from_name(args.dataset.lower())
  31. label = dataset[0].nodes.data['label'][dataset[0].nodes.data['test_mask']].numpy()
  32. accs = []
  33. for seed in tqdm(range(args.repeat)):
  34. if args.model == 'gat':
  35. model_hp = {
  36. # hp from model
  37. "num_layers": 2,
  38. "hidden": [8],
  39. "heads": 8,
  40. "dropout": 0.6,
  41. "act": "elu",
  42. }
  43. elif args.model == 'gcn':
  44. model_hp = {
  45. "num_layers": 2,
  46. "hidden": [16],
  47. "dropout": 0.5,
  48. "act": "relu"
  49. }
  50. elif args.model == 'sage':
  51. model_hp = {
  52. "num_layers": 2,
  53. "hidden": [64],
  54. "dropout": 0.5,
  55. "act": "relu",
  56. "agg": "gcn",
  57. }
  58. solver = AutoNodeClassifier(
  59. feature_module=None,
  60. graph_models=(args.model,),
  61. ensemble_module=None,
  62. max_evals=1,
  63. hpo_module='random',
  64. trainer_hp_space=fixed(**{
  65. "max_epoch": args.epoch,
  66. "early_stopping_round": args.epoch + 1,
  67. "lr": args.lr,
  68. "weight_decay": args.weight_decay,
  69. }),
  70. model_hp_spaces=[fixed(**model_hp)]
  71. )
  72. solver.fit(dataset, evaluation_method=['acc'], seed=seed)
  73. output = solver.predict(dataset)
  74. acc = (output == label).astype('float').mean()
  75. accs.append(acc)
  76. print('{:.4f} ~ {:.4f}'.format(np.mean(accs), np.std(accs)))