You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

trainer_dataset.py 3.0 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. """
  2. Performance check of AutoGL trainer + dataset
  3. """
  4. import os
  5. import numpy as np
  6. from tqdm import tqdm
  7. os.environ["AUTOGL_BACKEND"] = "dgl"
  8. from autogl.datasets import build_dataset_from_name
  9. from autogl.datasets.utils.conversion import general_static_graphs_to_dgl_dataset
  10. from autogl.module.train import NodeClassificationFullTrainer
  11. from autogl.solver.utils import set_seed
  12. import logging
  13. logging.basicConfig(level=logging.ERROR)
  14. if __name__ == '__main__':
  15. import argparse
  16. parser = argparse.ArgumentParser('dgl trainer dataset')
  17. parser.add_argument('--device', type=str, default='cuda')
  18. parser.add_argument('--dataset', type=str, choices=['Cora', 'CiteSeer', 'PubMed'], default='Cora')
  19. parser.add_argument('--repeat', type=int, default=50)
  20. parser.add_argument('--model', type=str, choices=['gat', 'gcn', 'sage'], default='gat')
  21. parser.add_argument('--lr', type=float, default=0.01)
  22. parser.add_argument('--weight_decay', type=float, default=0.0)
  23. parser.add_argument('--epoch', type=int, default=200)
  24. args = parser.parse_args()
  25. # seed = 100
  26. dataset = build_dataset_from_name(args.dataset.lower())
  27. dataset = general_static_graphs_to_dgl_dataset(dataset)
  28. data = dataset[0].to(args.device)
  29. num_features = data.ndata['feat'].size(1)
  30. num_classes = data.ndata['label'].max().item() + 1
  31. label = data.ndata['label']
  32. test_mask = data.ndata['test_mask']
  33. accs = []
  34. for seed in tqdm(range(args.repeat)):
  35. set_seed(seed)
  36. if args.model == 'gat':
  37. model_hp = {
  38. # hp from model
  39. "num_layers": 2,
  40. "hidden": [8],
  41. "heads": 8,
  42. "dropout": 0.6,
  43. "act": "elu",
  44. }
  45. elif args.model == 'gcn':
  46. model_hp = {
  47. "num_layers": 2,
  48. "hidden": [16],
  49. "dropout": 0.5,
  50. "act": "relu"
  51. }
  52. elif args.model == 'sage':
  53. model_hp = {
  54. "num_layers": 2,
  55. "hidden": [64],
  56. "dropout": 0.5,
  57. "act": "relu",
  58. "agg": "gcn",
  59. }
  60. trainer = NodeClassificationFullTrainer(
  61. model=args.model,
  62. num_features=num_features,
  63. num_classes=num_classes,
  64. device=args.device,
  65. init=False,
  66. feval=['acc'],
  67. loss="nll_loss",
  68. ).duplicate_from_hyper_parameter({
  69. "max_epoch": args.epoch,
  70. "early_stopping_round": args.epoch + 1,
  71. "lr": args.lr,
  72. "weight_decay": args.weight_decay,
  73. **model_hp
  74. })
  75. trainer.train(dataset, False)
  76. output = trainer.predict(dataset, 'test')
  77. acc = (output == label[test_mask]).float().mean().item()
  78. accs.append(acc)
  79. print('{:.4f} ~ {:.4f}'.format(np.mean(accs), np.std(accs)))