You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

model.py 6.5 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. """
  2. Performance check of AutoGL model + DGL (dataset + trainer)
  3. """
  4. import os
  5. import pickle
  6. os.environ["AUTOGL_BACKEND"] = "dgl"
  7. # from dgl.dataloading.pytorch.dataloader import GraphDataLoader
  8. from dgl.dataloading import GraphDataLoader
  9. import numpy as np
  10. from tqdm import tqdm
  11. import random
  12. import torch
  13. import torch.nn as nn
  14. import torch.optim as optim
  15. from dgl.data import GINDataset
  16. import torch
  17. import torch.nn as nn
  18. from autogl.module.model.dgl.gin import AutoGIN
  19. from autogl.module.model.dgl.topkpool import AutoTopkpool
  20. from autogl.solver.utils import set_seed
  21. import argparse
  22. class DatasetAbstraction():
  23. def __init__(self, graphs, labels):
  24. for g in graphs:
  25. g.ndata['feat'] = g.ndata['attr']
  26. self.graphs, self.labels = [], []
  27. for g, l in zip(graphs, labels):
  28. self.graphs.append(g)
  29. self.labels.append(l)
  30. self.gclasses = max(self.labels).item() + 1
  31. self.graph = self.graphs
  32. def __len__(self):
  33. return len(self.graphs)
  34. def __getitem__(self, idx):
  35. if isinstance(idx, int):
  36. return self.graphs[idx], self.labels[idx]
  37. elif isinstance(idx, torch.BoolTensor):
  38. idx = [i for i in range(len(idx)) if idx[i]]
  39. elif isinstance(idx, torch.Tensor) and idx.unique()[0].sum().item() == 1:
  40. idx = [i for i in range(len(idx)) if idx[i]]
  41. return DatasetAbstraction([self.graphs[i] for i in idx], [self.labels[i] for i in idx])
  42. def train(net, trainloader, validloader, optimizer, criterion, epoch, device):
  43. best_model = pickle.dumps(net.state_dict())
  44. best_acc = 0.
  45. for e in range(epoch):
  46. net.train()
  47. for graphs, labels in trainloader:
  48. labels = labels.to(device)
  49. graphs = graphs.to(device)
  50. # outputs = net((graphs, labels))
  51. # feat = graphs.ndata.pop('attr')
  52. # outputs = net(graphs, feat)
  53. outputs = net(graphs)
  54. loss = criterion(outputs, labels)
  55. # backprop
  56. optimizer.zero_grad()
  57. loss.backward()
  58. optimizer.step()
  59. gt = []
  60. pr = []
  61. net.eval()
  62. for graphs, labels in validloader:
  63. labels = labels.to(device)
  64. graphs = graphs.to(device)
  65. gt.append(labels)
  66. # feat = graphs.ndata.pop('attr')
  67. # outputs = net(graphs, feat)
  68. # outputs = net((graphs, labels))
  69. outputs = net(graphs)
  70. pr.append(outputs.argmax(1))
  71. gt = torch.cat(gt, dim=0)
  72. pr = torch.cat(pr, dim=0)
  73. acc = (gt == pr).float().mean().item()
  74. if acc > best_acc:
  75. best_acc = acc
  76. best_model = pickle.dumps(net.state_dict())
  77. net.load_state_dict(pickle.loads(best_model))
  78. return net
  79. def eval_net(net, dataloader, device):
  80. net.eval()
  81. total = 0
  82. total_correct = 0
  83. for data in dataloader:
  84. graphs, labels = data
  85. graphs = graphs.to(device)
  86. labels = labels.to(device)
  87. # feat = graphs.ndata.pop('attr')
  88. total += len(labels)
  89. # outputs = net(graphs, feat)
  90. # outputs = net((graphs, labels))
  91. outputs = net(graphs)
  92. _, predicted = torch.max(outputs.data, 1)
  93. total_correct += (predicted == labels.data).sum().item()
  94. acc = 1.0 * total_correct / total
  95. net.train()
  96. return acc
  97. def main(args):
  98. device = torch.device(args.device)
  99. dataset_ = GINDataset(args.dataset, False)
  100. dataset = DatasetAbstraction([g[0] for g in dataset_], [g[1] for g in dataset_])
  101. # 1. split dataset [fix split]
  102. dataids = list(range(len(dataset)))
  103. random.seed(args.dataset_seed)
  104. random.shuffle(dataids)
  105. fold = int(len(dataset) * 0.1)
  106. train_dataset = dataset[dataids[:fold * 8]]
  107. val_dataset = dataset[dataids[fold * 8: fold * 9]]
  108. test_dataset = dataset[dataids[fold * 9: ]]
  109. trainloader = GraphDataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
  110. valloader = GraphDataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)
  111. testloader = GraphDataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)
  112. accs = []
  113. for seed in tqdm(range(args.repeat)):
  114. # set up seeds, args.seed supported
  115. set_seed(seed)
  116. if args.model == 'gin':
  117. model = AutoGIN(
  118. num_features=dataset_.dim_nfeats,
  119. num_classes=dataset_.gclasses,
  120. device=device,
  121. ).from_hyper_parameter({
  122. "num_layers": 5,
  123. "hidden": [64,64,64,64],
  124. "dropout": 0.5,
  125. "act": "relu",
  126. "eps": "False",
  127. "mlp_layers": 2,
  128. "neighbor_pooling_type": "sum",
  129. "graph_pooling_type": "sum"
  130. }).model
  131. elif args.model == 'topkpool':
  132. model = AutoTopkpool(
  133. num_features=dataset_.dim_nfeats,
  134. num_classes=dataset_.gclasses,
  135. device=device,
  136. ).from_hyper_parameter({
  137. "num_layers": 5,
  138. "hidden": [64,64,64,64],
  139. "dropout": 0.5
  140. }).model
  141. model = model.to(device)
  142. criterion = nn.CrossEntropyLoss() # defaul reduce is true
  143. optimizer = optim.Adam(model.parameters(), lr=args.lr)
  144. model = train(model, trainloader, valloader, optimizer, criterion, args.epoch, device)
  145. acc = eval_net(model, testloader, device)
  146. accs.append(acc)
  147. print('{:.2f} ~ {:.2f}'.format(np.mean(accs) * 100, np.std(accs) * 100))
  148. if __name__ == '__main__':
  149. parser = argparse.ArgumentParser('model parser')
  150. parser.add_argument('--device', type=str, default='cuda')
  151. parser.add_argument('--dataset', type=str, choices=['MUTAG', 'COLLAB', 'IMDBBINARY', 'IMDBMULTI', 'NCI1', 'PROTEINS', 'PTC', 'REDDITBINARY', 'REDDITMULTI5K'], default='MUTAG')
  152. parser.add_argument('--dataset_seed', type=int, default=2021)
  153. parser.add_argument('--batch_size', type=int, default=32)
  154. parser.add_argument('--repeat', type=int, default=50)
  155. parser.add_argument('--model', type=str, choices=['gin', 'topkpool'], default='gin')
  156. parser.add_argument('--lr', type=float, default=0.0001)
  157. parser.add_argument('--epoch', type=int, default=100)
  158. args = parser.parse_args()
  159. main(args)