You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

solver.py 2.3 kB

4 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. """
  2. Performance check of AutoGL solver + PyG dataset
  3. """
  4. import os
  5. import numpy as np
  6. from tqdm import tqdm
  7. os.environ["AUTOGL_BACKEND"] = "pyg"
  8. from autogl.module.feature import NormalizeFeatures
  9. from autogl.solver import AutoNodeClassifier
  10. from autogl.datasets import utils, build_dataset_from_name
  11. from autogl.solver.utils import set_seed
  12. from helper import get_encoder_decoder_hp
  13. import logging
  14. logging.basicConfig(level=logging.ERROR)
  15. def fixed(**kwargs):
  16. return [{
  17. 'parameterName': k,
  18. "type": "FIXED",
  19. "value": v
  20. } for k, v in kwargs.items()]
  21. if __name__ == '__main__':
  22. import argparse
  23. parser = argparse.ArgumentParser('pyg model')
  24. parser.add_argument('--device', type=str, default='cuda')
  25. parser.add_argument('--dataset', type=str, choices=['Cora', 'CiteSeer', 'PubMed'], default='Cora')
  26. parser.add_argument('--repeat', type=int, default=50)
  27. parser.add_argument('--model', type=str, choices=['gat', 'gcn', 'sage', 'gin'], default='gat')
  28. parser.add_argument('--lr', type=float, default=0.01)
  29. parser.add_argument('--weight_decay', type=float, default=0.0)
  30. parser.add_argument('--epoch', type=int, default=200)
  31. args = parser.parse_args()
  32. # seed = 100
  33. dataset = build_dataset_from_name(args.dataset.lower())
  34. label = dataset[0].nodes.data['y'][dataset[0].nodes.data['test_mask']].numpy()
  35. accs = []
  36. model_hp, decoder_hp = get_encoder_decoder_hp(args.model, decoupled=True)
  37. for seed in tqdm(range(args.repeat)):
  38. solver = AutoNodeClassifier(
  39. feature_module='NormalizeFeatures',
  40. graph_models=(args.model,),
  41. ensemble_module=None,
  42. max_evals=1,
  43. hpo_module='random',
  44. trainer_hp_space=fixed(**{
  45. "max_epoch": args.epoch,
  46. "early_stopping_round": args.epoch + 1,
  47. "lr": args.lr,
  48. "weight_decay": args.weight_decay,
  49. }),
  50. model_hp_spaces=[{"encoder": fixed(**model_hp), "decoder": fixed(**decoder_hp)}]
  51. )
  52. solver.fit(dataset, seed=seed)
  53. output = solver.predict(dataset)
  54. acc = (output == label).astype('float').mean()
  55. accs.append(acc)
  56. print('{:.4f} ~ {:.4f}'.format(np.mean(accs), np.std(accs)))