You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

solver.py 2.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. """
  2. Performance check of AutoGL solver
  3. """
  4. import os
  5. import numpy as np
  6. from tqdm import tqdm
  7. os.environ["AUTOGL_BACKEND"] = "dgl"
  8. from autogl.solver import AutoNodeClassifier
  9. from autogl.datasets import build_dataset_from_name
  10. from helper import get_encoder_decoder_hp
  11. import logging
  12. logging.basicConfig(level=logging.ERROR)
  13. def fixed(**kwargs):
  14. return [{
  15. 'parameterName': k,
  16. "type": "FIXED",
  17. "value": v
  18. } for k, v in kwargs.items()]
  19. if __name__ == '__main__':
  20. import argparse
  21. parser = argparse.ArgumentParser('dgl solver')
  22. parser.add_argument('--device', type=str, default='cuda')
  23. parser.add_argument('--dataset', type=str, choices=['Cora', 'CiteSeer', 'PubMed'], default='Cora')
  24. parser.add_argument('--repeat', type=int, default=50)
  25. parser.add_argument('--model', type=str, choices=['gin', 'gat', 'gcn', 'sage', 'topk'], default='gin')
  26. parser.add_argument('--lr', type=float, default=0.01)
  27. parser.add_argument('--weight_decay', type=float, default=0.0)
  28. parser.add_argument('--epoch', type=int, default=200)
  29. args = parser.parse_args()
  30. # seed = 100
  31. dataset = build_dataset_from_name(args.dataset.lower())
  32. label = dataset[0].nodes.data['label'][dataset[0].nodes.data['test_mask']].numpy()
  33. accs = []
  34. model_hp, decoder_hp = get_encoder_decoder_hp(args.model)
  35. for seed in tqdm(range(args.repeat)):
  36. solver = AutoNodeClassifier(
  37. feature_module=None,
  38. graph_models=(args.model,),
  39. ensemble_module=None,
  40. max_evals=1,
  41. hpo_module='random',
  42. trainer_hp_space=fixed(**{
  43. "max_epoch": args.epoch,
  44. "early_stopping_round": args.epoch + 1,
  45. "lr": args.lr,
  46. "weight_decay": args.weight_decay,
  47. }),
  48. model_hp_spaces=[{"encoder": fixed(**model_hp), "decoder": fixed(**decoder_hp)}]
  49. )
  50. solver.fit(dataset, evaluation_method=['acc'], seed=seed)
  51. output = solver.predict(dataset)
  52. acc = (output == label).astype('float').mean()
  53. accs.append(acc)
  54. print('{:.4f} ~ {:.4f}'.format(np.mean(accs), np.std(accs)))