You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

trainer.py 2.4 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
"""
Performance check of AutoGL trainer + PYG dataset

Trains an AutoGL NodeClassificationFullTrainer on a Planetoid dataset
(Cora / CiteSeer / PubMed) once per seed and prints the test accuracy as
``mean ~ std`` over all runs.
"""
import os
import numpy as np
from tqdm import tqdm
# Select AutoGL's PyG backend. This is assigned before the `autogl` imports
# below, presumably because AutoGL reads it at import time -- keep the order.
os.environ["AUTOGL_BACKEND"] = "pyg"
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from autogl.module.train import NodeClassificationFullTrainer
from autogl.datasets import utils
from autogl.solver.utils import set_seed
# Local helper module that supplies encoder/decoder hyper-parameters per model name.
from helper import get_encoder_decoder_hp
import logging
# Root logger only emits ERROR and above (presumably to quiet library output).
logging.basicConfig(level=logging.ERROR)
  16. if __name__ == '__main__':
  17. import argparse
  18. parser = argparse.ArgumentParser('pyg model')
  19. parser.add_argument('--device', type=str, default='cuda')
  20. parser.add_argument('--dataset', type=str, choices=['Cora', 'CiteSeer', 'PubMed'], default='Cora')
  21. parser.add_argument('--repeat', type=int, default=50)
  22. parser.add_argument('--model', type=str, choices=['gat', 'gcn', 'sage', 'gin'], default='gat')
  23. parser.add_argument('--lr', type=float, default=0.01)
  24. parser.add_argument('--weight_decay', type=float, default=0.0)
  25. parser.add_argument('--epoch', type=int, default=200)
  26. args = parser.parse_args()
  27. # seed = 100
  28. dataset = Planetoid(os.path.expanduser('~/.cache-autogl'), args.dataset, transform=T.NormalizeFeatures())
  29. data = dataset[0].to(args.device)
  30. num_features = dataset.num_node_features
  31. num_classes = dataset.num_classes
  32. dataset = [data]
  33. accs = []
  34. model_hp, decoder_hp = get_encoder_decoder_hp(args.model, decoupled=True)
  35. for seed in tqdm(range(args.repeat)):
  36. set_seed(seed)
  37. trainer = NodeClassificationFullTrainer(
  38. model=args.model,
  39. num_features=num_features,
  40. num_classes=num_classes,
  41. device=args.device,
  42. init=False,
  43. feval=['acc'],
  44. loss="nll_loss",
  45. ).duplicate_from_hyper_parameter({
  46. "trainer": {
  47. "max_epoch": args.epoch,
  48. "early_stopping_round": args.epoch + 1,
  49. "lr": args.lr,
  50. "weight_decay": args.weight_decay,
  51. },
  52. "encoder": model_hp,
  53. "decoder": decoder_hp
  54. })
  55. trainer.train(dataset, False)
  56. output = trainer.predict(dataset, 'test')
  57. acc = (output == data.y[data.test_mask]).float().mean().item()
  58. accs.append(acc)
  59. print('{:.4f} ~ {:.4f}'.format(np.mean(accs), np.std(accs)))