You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

trainer.py 2.6 kB

4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. """
  2. Performance check of AutoGL trainer + DGL dataset
  3. """
  4. import os
  5. import numpy as np
  6. from tqdm import tqdm
  7. os.environ["AUTOGL_BACKEND"] = "dgl"
  8. from dgl.data import CoraGraphDataset, PubmedGraphDataset, CiteseerGraphDataset
  9. from autogl.module.train import NodeClassificationFullTrainer
  10. from autogl.solver.utils import set_seed
  11. import logging
  12. from helper import get_encoder_decoder_hp
  13. logging.basicConfig(level=logging.ERROR)
  14. if __name__ == '__main__':
  15. import argparse
  16. parser = argparse.ArgumentParser('dgl trainer')
  17. parser.add_argument('--device', type=str, default='cuda')
  18. parser.add_argument('--dataset', type=str, choices=['Cora', 'CiteSeer', 'PubMed'], default='Cora')
  19. parser.add_argument('--repeat', type=int, default=50)
  20. parser.add_argument('--model', type=str, choices=['gat', 'gcn', 'sage'], default='gat')
  21. parser.add_argument('--lr', type=float, default=0.01)
  22. parser.add_argument('--weight_decay', type=float, default=0.0)
  23. parser.add_argument('--epoch', type=int, default=200)
  24. args = parser.parse_args()
  25. # seed = 100
  26. if args.dataset == 'Cora':
  27. dataset = CoraGraphDataset()
  28. elif args.dataset == 'CiteSeer':
  29. dataset = CiteseerGraphDataset()
  30. elif args.dataset == 'PubMed':
  31. dataset = PubmedGraphDataset()
  32. graph = dataset[0].to(args.device)
  33. label = graph.ndata['label']
  34. train_mask = graph.ndata['train_mask']
  35. val_mask = graph.ndata['val_mask']
  36. test_mask = graph.ndata['test_mask']
  37. num_features = graph.ndata['feat'].size(1)
  38. num_classes = dataset.num_classes
  39. accs = []
  40. model_hp, decoder_hp = get_encoder_decoder_hp(args.model)
  41. for seed in tqdm(range(args.repeat)):
  42. set_seed(seed)
  43. trainer = NodeClassificationFullTrainer(
  44. model=args.model,
  45. num_features=num_features,
  46. num_classes=num_classes,
  47. device=args.device,
  48. init=False,
  49. feval=['acc'],
  50. loss="nll_loss",
  51. ).duplicate_from_hyper_parameter({
  52. "trainer": {
  53. "max_epoch": args.epoch,
  54. "early_stopping_round": args.epoch + 1,
  55. "lr": args.lr,
  56. "weight_decay": args.weight_decay
  57. },
  58. "encoder": model_hp,
  59. "decoder": decoder_hp
  60. })
  61. trainer.train(dataset, False)
  62. output = trainer.predict(dataset, 'test')
  63. acc = (output == label[test_mask]).float().mean().item()
  64. accs.append(acc)
  65. print('{:.4f} ~ {:.4f}'.format(np.mean(accs), np.std(accs)))