You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

gasso.py 5.0 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. # "Graph differentiable architecture search with structure optimization" NeurIPS 21'
  2. import logging
  3. import torch
  4. import torch.optim
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. from . import register_nas_algo
  8. from .base import BaseNAS
  9. from ..estimator.base import BaseEstimator
  10. from ..space import BaseSpace
  11. from ..utils import replace_layer_choice, replace_input_choice
  12. from ...model.base import BaseAutoModel
  13. from torch.autograd import Variable
  14. import numpy as np
  15. import time
  16. import copy
  17. import torch.optim as optim
  18. import scipy.sparse as sp
  19. _logger = logging.getLogger(__name__)
  20. @register_nas_algo("gasso")
  21. class Gasso(BaseNAS):
  22. """
  23. GASSO trainer.
  24. Parameters
  25. ----------
  26. num_epochs : int
  27. Number of epochs planned for training.
  28. warmup_epochs : int
  29. Number of epochs planned for warming up.
  30. workers : int
  31. Workers for data loading.
  32. model_lr : float
  33. Learning rate to optimize the model.
  34. model_wd : float
  35. Weight decay to optimize the model.
  36. arch_lr : float
  37. Learning rate to optimize the architecture.
  38. stru_lr : float
  39. Learning rate to optimize the structure.
  40. lamb : float
  41. The parameter to control the influence of hidden feature smoothness
  42. device : str or torch.device
  43. The device of the whole process
  44. """
  45. def __init__(
  46. self,
  47. num_epochs=250,
  48. warmup_epochs=10,
  49. model_lr=0.01,
  50. model_wd=1e-4,
  51. arch_lr = 0.03,
  52. stru_lr = 0.04,
  53. lamb = 0.6,
  54. device="auto",
  55. ):
  56. super().__init__(device=device)
  57. self.num_epochs = num_epochs
  58. self.warmup_epochs = warmup_epochs
  59. self.model_lr = model_lr
  60. self.model_wd = model_wd
  61. self.arch_lr = arch_lr
  62. self.stru_lr = stru_lr
  63. self.lamb = lamb
  64. def train_stru(self, model, optimizer, data):
  65. # forward
  66. model.train()
  67. data[0].adj = self.adjs
  68. logits = model(data[0]).detach()
  69. loss = 0
  70. for adj in self.adjs:
  71. e1 = adj[0][0]
  72. e2 = adj[0][1]
  73. ew = adj[1]
  74. diff = (logits[e1] - logits[e2]).pow(2).sum(1)
  75. smooth = (diff * torch.sigmoid(ew)).sum()
  76. dist = (ew * ew).sum()
  77. loss += self.lamb * smooth + dist
  78. optimizer.zero_grad()
  79. loss.backward()
  80. optimizer.step()
  81. train_loss = loss.item()
  82. del logits
  83. def _infer(self, model: BaseSpace, dataset, estimator: BaseEstimator, mask="train"):
  84. dataset[0].adj = self.adjs
  85. metric, loss = estimator.infer(model, dataset, mask=mask)
  86. return metric, loss
  87. def prepare(self, dset):
  88. """Train Pro-GNN.
  89. """
  90. data = dset[0]
  91. self.ews = []
  92. self.edges = data.edge_index.to(self.device)
  93. edge_weight = torch.ones(self.edges.size(1)).to(self.device)
  94. self.adjs = []
  95. for i in range(self.steps):
  96. edge_weight = Variable(edge_weight * 1.0, requires_grad = True).to(self.device)
  97. self.ews.append(edge_weight)
  98. self.adjs.append((self.edges, edge_weight))
  99. def fit(self, data):
  100. self.optimizer = optim.Adam(self.space.parameters(), lr=self.model_lr, weight_decay=self.model_wd)
  101. self.arch_optimizer = optim.Adam(self.space.arch_parameters(),
  102. lr=self.arch_lr, betas=(0.5, 0.999))
  103. self.stru_optimizer = optim.SGD(self.ews, lr=self.stru_lr)
  104. # Train model
  105. best_performance = 0
  106. min_val_loss = float("inf")
  107. min_train_loss = float("inf")
  108. t_total = time.time()
  109. for epoch in range(self.num_epochs):
  110. self.space.train()
  111. self.optimizer.zero_grad()
  112. _, loss = self._infer(self.space, data, self.estimator, "train")
  113. loss.backward()
  114. self.optimizer.step()
  115. if epoch <20:
  116. continue
  117. self.train_stru(self.space, self.stru_optimizer, data)
  118. self.arch_optimizer.zero_grad()
  119. _, loss = self._infer(self.space, data, self.estimator, "train")
  120. loss.backward()
  121. self.arch_optimizer.step()
  122. self.space.eval()
  123. train_acc, _ = self._infer(self.space, data, self.estimator, "train")
  124. val_acc, val_loss = self._infer(self.space, data, self.estimator, "val")
  125. if val_loss < min_val_loss:
  126. min_val_loss = val_loss
  127. best_performance = val_acc
  128. self.space.keep_prediction()
  129. #print("acc:" + str(train_acc) + " val_acc" + str(val_acc))
  130. return best_performance, min_val_loss
  131. def search(self, space: BaseSpace, dataset, estimator):
  132. self.estimator = estimator
  133. self.space = space.to(self.device)
  134. self.steps = space.steps
  135. self.prepare(dataset)
  136. perf, val_loss = self.fit(dataset)
  137. return space.parse_model(None, self.device)