You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

trainer_dataset.py 2.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. """
  2. Performance check of AutoGL trainer + PYG dataset
  3. """
  4. import os
  5. import numpy as np
  6. from tqdm import tqdm
  7. os.environ["AUTOGL_BACKEND"] = "pyg"
  8. from autogl.module.feature import NormalizeFeatures
  9. from autogl.module.train import NodeClassificationFullTrainer
  10. from autogl.datasets import utils, build_dataset_from_name
  11. from autogl.solver.utils import set_seed
  12. import logging
  13. logging.basicConfig(level=logging.ERROR)
  14. if __name__ == '__main__':
  15. import argparse
  16. parser = argparse.ArgumentParser('pyg model')
  17. parser.add_argument('--device', type=str, default='cuda')
  18. parser.add_argument('--dataset', type=str, choices=['Cora', 'CiteSeer', 'PubMed'], default='Cora')
  19. parser.add_argument('--repeat', type=int, default=50)
  20. parser.add_argument('--model', type=str, choices=['gat', 'gcn', 'sage'], default='gat')
  21. parser.add_argument('--lr', type=float, default=0.01)
  22. parser.add_argument('--weight_decay', type=float, default=0.0)
  23. parser.add_argument('--epoch', type=int, default=200)
  24. args = parser.parse_args()
  25. # seed = 100
  26. dataset = build_dataset_from_name('cora')
  27. dataset = NormalizeFeatures().fit_transform(dataset)
  28. dataset = utils.conversion.general_static_graphs_to_pyg_dataset(dataset)
  29. data = dataset[0].to(args.device)
  30. num_features = data.x.size(1)
  31. num_classes = max([label.item() for label in data.y]) + 1
  32. accs = []
  33. for seed in tqdm(range(args.repeat)):
  34. set_seed(seed)
  35. if args.model == 'gat':
  36. model_hp = {
  37. # hp from model
  38. "num_layers": 2,
  39. "hidden": [8],
  40. "heads": 8,
  41. "dropout": 0.6,
  42. "act": "elu",
  43. }
  44. elif args.model == 'gcn':
  45. model_hp = {
  46. "num_layers": 2,
  47. "hidden": [16],
  48. "dropout": 0.5,
  49. "act": "relu"
  50. }
  51. elif args.model == 'sage':
  52. model_hp = {
  53. "num_layers": 2,
  54. "hidden": [64],
  55. "dropout": 0.5,
  56. "act": "relu",
  57. "agg": "mean",
  58. }
  59. trainer = NodeClassificationFullTrainer(
  60. model=args.model,
  61. num_features=num_features,
  62. num_classes=num_classes,
  63. device=args.device,
  64. init=False,
  65. feval=['acc'],
  66. loss="nll_loss",
  67. ).duplicate_from_hyper_parameter({
  68. "max_epoch": args.epoch,
  69. "early_stopping_round": args.epoch + 1,
  70. "lr": args.lr,
  71. "weight_decay": args.weight_decay,
  72. **model_hp
  73. })
  74. trainer.train(dataset, False)
  75. output = trainer.predict(dataset, 'test')
  76. acc = (output == data.y[data.test_mask]).float().mean().item()
  77. accs.append(acc)
  78. print('{:.4f} ~ {:.4f}'.format(np.mean(accs), np.std(accs)))