You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

model.py 3.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. """
  2. Performance check of AutoGL model + PYG (trainer + dataset)
  3. """
  4. import os
  5. import pickle
  6. import numpy as np
  7. from tqdm import tqdm
  8. os.environ["AUTOGL_BACKEND"] = "pyg"
  9. import torch
  10. import torch.nn.functional as F
  11. from torch_geometric.datasets import Planetoid
  12. import torch_geometric.transforms as T
  13. from autogl.module.model.pyg import AutoGCN, AutoGAT, AutoSAGE
  14. from autogl.datasets import utils
  15. from autogl.solver.utils import set_seed
  16. from helper import get_encoder_decoder_hp
  17. import logging
  18. logging.basicConfig(level=logging.ERROR)
  19. def test(model, data, mask):
  20. model.eval()
  21. if hasattr(model, 'cls_forward'):
  22. out = model.cls_forward(data)[mask]
  23. else:
  24. out = model(data)[mask]
  25. pred = out.max(1)[1]
  26. acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
  27. return acc
  28. def train(model, data, args):
  29. optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
  30. parameters = model.state_dict()
  31. best_acc = 0.
  32. for epoch in range(args.epoch):
  33. model.train()
  34. optimizer.zero_grad()
  35. if hasattr(model, 'cls_forward'):
  36. output = model.cls_forward(data)
  37. else:
  38. output = model(data)
  39. loss = F.nll_loss(output[data.train_mask], data.y[data.train_mask])
  40. loss.backward()
  41. optimizer.step()
  42. val_acc = test(model, data, data.val_mask)
  43. if val_acc > best_acc:
  44. best_acc = val_acc
  45. parameters = pickle.dumps(model.state_dict())
  46. model.load_state_dict(pickle.loads(parameters))
  47. return model
  48. if __name__ == '__main__':
  49. import argparse
  50. parser = argparse.ArgumentParser('pyg model')
  51. parser.add_argument('--device', type=str, default='cuda')
  52. parser.add_argument('--dataset', type=str, choices=['Cora', 'CiteSeer', 'PubMed'], default='Cora')
  53. parser.add_argument('--repeat', type=int, default=50)
  54. parser.add_argument('--model', type=str, choices=['gat', 'gcn', 'sage'], default='gat')
  55. parser.add_argument('--lr', type=float, default=0.01)
  56. parser.add_argument('--weight_decay', type=float, default=0.0)
  57. parser.add_argument('--epoch', type=int, default=200)
  58. args = parser.parse_args()
  59. # seed = 100
  60. dataset = Planetoid(os.path.expanduser('~/.cache-autogl'), args.dataset, transform=T.NormalizeFeatures())
  61. data = dataset[0].to(args.device)
  62. accs = []
  63. model_hp, _ = get_encoder_decoder_hp(args.model)
  64. for seed in tqdm(range(args.repeat)):
  65. set_seed(seed)
  66. if args.model == 'gat':
  67. model = AutoGAT(
  68. num_features=dataset.num_node_features,
  69. num_classes=dataset.num_classes,
  70. device=args.device,
  71. init=False
  72. ).from_hyper_parameter(model_hp).model
  73. elif args.model == 'gcn':
  74. model = AutoGCN(
  75. num_features=dataset.num_node_features,
  76. num_classes=dataset.num_classes,
  77. device=args.device,
  78. init=False
  79. ).from_hyper_parameter(model_hp).model
  80. elif args.model == 'sage':
  81. model = AutoSAGE(
  82. num_features=dataset.num_node_features,
  83. num_classes=dataset.num_classes,
  84. device=args.device,
  85. init=False
  86. ).from_hyper_parameter(model_hp).model
  87. model.to(args.device)
  88. train(model, data, args)
  89. acc = test(model, data, data.test_mask)
  90. accs.append(acc)
  91. print('{:.4f} ~ {:.4f}'.format(np.mean(accs), np.std(accs)))