You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

trainer.py 3.1 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. """
  2. Performance check of AutoGL trainer + DGL dataset
  3. """
  4. import os
  5. import numpy as np
  6. from tqdm import tqdm
  7. os.environ["AUTOGL_BACKEND"] = "dgl"
  8. from dgl.data import CoraGraphDataset, PubmedGraphDataset, CiteseerGraphDataset
  9. from autogl.module.train import NodeClassificationFullTrainer
  10. from autogl.solver.utils import set_seed
  11. import logging
  12. logging.basicConfig(level=logging.ERROR)
  13. if __name__ == '__main__':
  14. import argparse
  15. parser = argparse.ArgumentParser('dgl trainer')
  16. parser.add_argument('--device', type=str, default='cuda')
  17. parser.add_argument('--dataset', type=str, choices=['Cora', 'CiteSeer', 'PubMed'], default='Cora')
  18. parser.add_argument('--repeat', type=int, default=50)
  19. parser.add_argument('--model', type=str, choices=['gat', 'gcn', 'sage'], default='gat')
  20. parser.add_argument('--lr', type=float, default=0.01)
  21. parser.add_argument('--weight_decay', type=float, default=0.0)
  22. parser.add_argument('--epoch', type=int, default=200)
  23. args = parser.parse_args()
  24. # seed = 100
  25. if args.dataset == 'Cora':
  26. dataset = CoraGraphDataset()
  27. elif args.dataset == 'CiteSeer':
  28. dataset = CiteseerGraphDataset()
  29. elif args.dataset == 'PubMed':
  30. dataset = PubmedGraphDataset()
  31. graph = dataset[0].to(args.device)
  32. label = graph.ndata['label']
  33. train_mask = graph.ndata['train_mask']
  34. val_mask = graph.ndata['val_mask']
  35. test_mask = graph.ndata['test_mask']
  36. num_features = graph.ndata['feat'].size(1)
  37. num_classes = dataset.num_classes
  38. accs = []
  39. for seed in tqdm(range(args.repeat)):
  40. set_seed(seed)
  41. if args.model == 'gat':
  42. model_hp = {
  43. # hp from model
  44. "num_layers": 2,
  45. "hidden": [8],
  46. "heads": 8,
  47. "dropout": 0.6,
  48. "act": "elu",
  49. }
  50. elif args.model == 'gcn':
  51. model_hp = {
  52. "num_layers": 2,
  53. "hidden": [16],
  54. "dropout": 0.5,
  55. "act": "relu"
  56. }
  57. elif args.model == 'sage':
  58. model_hp = {
  59. "num_layers": 2,
  60. "hidden": [64],
  61. "dropout": 0.5,
  62. "act": "relu",
  63. "agg": "gcn",
  64. }
  65. trainer = NodeClassificationFullTrainer(
  66. model=args.model,
  67. num_features=num_features,
  68. num_classes=num_classes,
  69. device=args.device,
  70. init=False,
  71. feval=['acc'],
  72. loss="nll_loss",
  73. ).duplicate_from_hyper_parameter({
  74. "max_epoch": args.epoch,
  75. "early_stopping_round": args.epoch + 1,
  76. "lr": args.lr,
  77. "weight_decay": args.weight_decay,
  78. **model_hp
  79. })
  80. trainer.train(dataset, False)
  81. output = trainer.predict(dataset, 'test')
  82. acc = (output == label[test_mask]).float().mean().item()
  83. accs.append(acc)
  84. print('{:.4f} ~ {:.4f}'.format(np.mean(accs), np.std(accs)))