You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

base.py 7.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222
  1. """
  2. Performance check of PYG (model + trainer + dataset)
  3. """
  4. import os
  5. import random
  6. import numpy as np
  7. from tqdm import tqdm
  8. import torch
  9. import torch.nn.functional as F
  10. from torch.nn import Sequential, Linear, ReLU
  11. import torch_geometric
  12. from torch_geometric.datasets import TUDataset
  13. if int(torch_geometric.__version__.split(".")[0]) >= 2:
  14. from torch_geometric.loader import DataLoader
  15. else:
  16. from torch_geometric.data import DataLoader
  17. from torch_geometric.nn import GINConv, global_add_pool, GraphConv, TopKPooling
  18. from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
  19. import logging
  20. torch.backends.cudnn.deterministic = True
  21. #torch.use_deterministic_algorithms(True)
  22. logging.basicConfig(level=logging.ERROR)
  23. class GIN(torch.nn.Module):
  24. def __init__(self):
  25. super(GIN, self).__init__()
  26. num_features = dataset.num_features
  27. dim = 32
  28. nn1 = Sequential(Linear(num_features, dim), ReLU(), Linear(dim, dim))
  29. self.conv1 = GINConv(nn1)
  30. self.bn1 = torch.nn.BatchNorm1d(dim)
  31. nn2 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim))
  32. self.conv2 = GINConv(nn2)
  33. self.bn2 = torch.nn.BatchNorm1d(dim)
  34. nn3 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim))
  35. self.conv3 = GINConv(nn3)
  36. self.bn3 = torch.nn.BatchNorm1d(dim)
  37. nn4 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim))
  38. self.conv4 = GINConv(nn4)
  39. self.bn4 = torch.nn.BatchNorm1d(dim)
  40. nn5 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim))
  41. self.conv5 = GINConv(nn5)
  42. self.bn5 = torch.nn.BatchNorm1d(dim)
  43. self.fc1 = Linear(dim, dim)
  44. self.fc2 = Linear(dim, dataset.num_classes)
  45. def forward(self, data):
  46. x, edge_index, batch = data.x, data.edge_index, data.batch
  47. x = F.relu(self.conv1(x, edge_index))
  48. x = self.bn1(x)
  49. x = F.relu(self.conv2(x, edge_index))
  50. x = self.bn2(x)
  51. x = F.relu(self.conv3(x, edge_index))
  52. x = self.bn3(x)
  53. x = F.relu(self.conv4(x, edge_index))
  54. x = self.bn4(x)
  55. x = F.relu(self.conv5(x, edge_index))
  56. x = self.bn5(x)
  57. x = global_add_pool(x, batch)
  58. x = F.relu(self.fc1(x))
  59. x = F.dropout(x, p=0.5, training=self.training)
  60. x = self.fc2(x)
  61. return F.log_softmax(x, dim=-1)
  62. class TopKPool(torch.nn.Module):
  63. def __init__(self):
  64. super(TopKPool, self).__init__()
  65. self.conv1 = GraphConv(dataset.num_features, 128)
  66. self.pool1 = TopKPooling(128, ratio=0.8)
  67. self.conv2 = GraphConv(128, 128)
  68. self.pool2 = TopKPooling(128, ratio=0.8)
  69. self.conv3 = GraphConv(128, 128)
  70. self.pool3 = TopKPooling(128, ratio=0.8)
  71. self.lin1 = torch.nn.Linear(256, 128)
  72. self.lin2 = torch.nn.Linear(128, 64)
  73. self.lin3 = torch.nn.Linear(64, dataset.num_classes)
  74. def forward(self, data):
  75. x, edge_index, batch = data.x, data.edge_index, data.batch
  76. x = F.relu(self.conv1(x, edge_index))
  77. x, edge_index, _, batch, _, _ = self.pool1(x, edge_index, None, batch)
  78. x1 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)
  79. x = F.relu(self.conv2(x, edge_index))
  80. x, edge_index, _, batch, _, _ = self.pool2(x, edge_index, None, batch)
  81. x2 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)
  82. x = F.relu(self.conv3(x, edge_index))
  83. x, edge_index, _, batch, _, _ = self.pool3(x, edge_index, None, batch)
  84. x3 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)
  85. x = x1 + x2 + x3
  86. x = F.relu(self.lin1(x))
  87. x = F.dropout(x, p=0.5, training=self.training)
  88. x = F.relu(self.lin2(x))
  89. x = F.log_softmax(self.lin3(x), dim=-1)
  90. return x
  91. def test(model, loader, args):
  92. model.eval()
  93. correct = 0
  94. for data in loader:
  95. data = data.to(args.device)
  96. output = model(data)
  97. pred = output.max(dim=1)[1]
  98. correct += pred.eq(data.y).sum().item()
  99. return correct / len(loader.dataset)
  100. def train(model, train_loader, val_loader, args):
  101. optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
  102. parameters = model.state_dict()
  103. best_acc = 0.
  104. for epoch in range(args.epoch):
  105. model.train()
  106. for data in train_loader:
  107. data = data.to(args.device)
  108. optimizer.zero_grad()
  109. output = model(data)
  110. loss = F.nll_loss(output, data.y)
  111. loss.backward()
  112. optimizer.step()
  113. val_acc = test(model, val_loader, args)
  114. if val_acc > best_acc:
  115. best_acc = val_acc
  116. parameters = model.state_dict()
  117. model.load_state_dict(parameters)
  118. return model
  119. if __name__ == '__main__':
  120. import argparse
  121. parser = argparse.ArgumentParser('pyg trainer')
  122. parser.add_argument('--device', type=str, default='cuda')
  123. parser.add_argument('--dataset', type=str, choices=['MUTAG', 'COLLAB', 'IMDBBINARY', 'IMDBMULTI', 'NCI1', 'PROTEINS', 'PTC', 'REDDITBINARY', 'REDDITMULTI5K'], default='MUTAG')
  124. parser.add_argument('--dataset_seed', type=int, default=2021)
  125. parser.add_argument('--batch_size', type=int, default=32)
  126. parser.add_argument('--repeat', type=int, default=50)
  127. parser.add_argument('--model', type=str, choices=['gin', 'topkpool'], default='gin')
  128. parser.add_argument('--lr', type=float, default=0.0001)
  129. parser.add_argument('--epoch', type=int, default=100)
  130. args = parser.parse_args()
  131. # seed = 100
  132. dataset = TUDataset(os.path.expanduser('~/.pyg'), args.dataset)
  133. # 1. split dataset [fix split]
  134. dataids = list(range(len(dataset)))
  135. random.seed(args.dataset_seed)
  136. random.shuffle(dataids)
  137. torch.manual_seed(args.dataset_seed)
  138. np.random.seed(args.dataset_seed)
  139. if args.device == 'cuda':
  140. torch.cuda.manual_seed(args.dataset_seed)
  141. fold = int(len(dataset) * 0.1)
  142. train_index = dataids[:fold * 8]
  143. val_index = dataids[fold * 8: fold * 9]
  144. test_index = dataids[fold * 9: ]
  145. dataset.train_index = train_index
  146. dataset.val_index = val_index
  147. dataset.test_index = test_index
  148. dataset.train_split = dataset[dataset.train_index]
  149. dataset.val_split = dataset[dataset.val_index]
  150. dataset.test_split = dataset[dataset.test_index]
  151. labels = np.array([data.y.item() for data in dataset.test_split])
  152. def seed_worker(worker_id):
  153. #seed = torch.initial_seed()
  154. torch.manual_seed(args.dataset_seed)
  155. np.random.seed(args.dataset_seed)
  156. random.seed(args.dataset_seed)
  157. g = torch.Generator()
  158. g.manual_seed(args.dataset_seed)
  159. train_loader = DataLoader(dataset.train_split, batch_size=args.batch_size, worker_init_fn=seed_worker, generator=g)
  160. val_loader = DataLoader(dataset.val_split, batch_size=args.batch_size, worker_init_fn=seed_worker, generator=g)
  161. test_loader = DataLoader(dataset.test_split, batch_size=args.batch_size, worker_init_fn=seed_worker, generator=g)
  162. #train_loader = DataLoader(dataset.train_split, batch_size=args.batch_size, shuffle=False)
  163. #val_loader = DataLoader(dataset.val_split, batch_size=args.batch_size, shuffle=False)
  164. #test_loader = DataLoader(dataset.test_split, batch_size=args.batch_size, shuffle=False)
  165. accs = []
  166. for seed in tqdm(range(args.repeat)):
  167. torch.manual_seed(seed)
  168. np.random.seed(seed)
  169. #random.seed(seed)
  170. if args.device == 'cuda':
  171. torch.cuda.manual_seed(seed)
  172. if args.model == 'gin':
  173. model = GIN()
  174. elif args.model == 'topkpool':
  175. model = TopKPool()
  176. model.to(args.device)
  177. train(model, train_loader, val_loader, args)
  178. acc = test(model, test_loader, args)
  179. accs.append(acc)
  180. print('{:.4f} ~ {:.4f}'.format(np.mean(accs), np.std(accs)))