You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

darts.py 5.9 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. # Modified from NNI
  2. import logging
  3. import torch
  4. import torch.optim
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. from . import register_nas_algo
  8. from .base import BaseNAS
  9. from ..estimator.base import BaseEstimator
  10. from ..space import BaseSpace
  11. from ..utils import replace_layer_choice, replace_input_choice
  12. # from nni.retiarii.oneshot.pytorch.darts import DartsLayerChoice, DartsInputChoice
  13. _logger = logging.getLogger(__name__)
  14. # copied from NNI 2.1 for stability
  class DartsLayerChoice(nn.Module):
      """Differentiable (DARTS-style) relaxation of a layer choice.

      Every candidate operation is evaluated and the results are blended with
      softmax-normalised architecture weights ``alpha``.  ``alpha`` is left out
      of ``parameters()`` / ``named_parameters()`` when they are called on this
      module, so those iterators yield only the candidate operations' weights.

      Parameters
      ----------
      layer_choice : nn.Module
          Source choice object.  Its ``key`` attribute becomes ``self.name`` and
          its ``named_children()`` become the candidate operations.
      """

      def __init__(self, layer_choice):
          super(DartsLayerChoice, self).__init__()
          self.name = layer_choice.key
          self.op_choices = nn.ModuleDict(layer_choice.named_children())
          # One architecture weight per candidate, initialised close to zero so
          # that the initial softmax mixture is nearly uniform.
          self.alpha = nn.Parameter(torch.randn(len(self.op_choices)) * 1e-3)

      def forward(self, *args, **kwargs):
          """Return the softmax(alpha)-weighted sum of all candidate outputs.

          All arguments are passed unchanged to every candidate.  The outputs
          are stacked, so every candidate must produce the same shape.
          """
          op_results = torch.stack(
              [op(*args, **kwargs) for op in self.op_choices.values()]
          )
          # Reshape alpha to (n_ops, 1, ..., 1) so it broadcasts over the outputs.
          alpha_shape = [-1] + [1] * (op_results.dim() - 1)
          weights = F.softmax(self.alpha, -1).view(*alpha_shape)
          return torch.sum(op_results * weights, 0)

      def parameters(self, recurse=True):
          """Yield this module's parameters except ``alpha``.

          Accepts ``recurse`` like ``nn.Module.parameters`` so that generic
          utilities calling ``module.parameters(recurse=False)`` keep working.
          """
          for _, p in self.named_parameters(recurse=recurse):
              yield p

      def named_parameters(self, prefix="", recurse=True, **kwargs):
          """Yield ``(name, parameter)`` pairs except the one for ``alpha``.

          Takes the same arguments as ``nn.Module.named_parameters``; any extra
          keyword arguments are forwarded to the base implementation.
          """
          # The base class prepends ``prefix`` to every name, so the entry to
          # hide has to be computed with the same rule.
          hidden = prefix + ("." if prefix else "") + "alpha"
          base_iter = super(DartsLayerChoice, self).named_parameters(
              prefix=prefix, recurse=recurse, **kwargs
          )
          for name, p in base_iter:
              if name == hidden:
                  continue
              yield name, p

      def export(self):
          """Return the index (int) of the candidate with the largest alpha."""
          return torch.argmax(self.alpha).item()
  class DartsInputChoice(nn.Module):
      """Differentiable (DARTS-style) relaxation of an input choice.

      The candidate inputs are blended with softmax-normalised architecture
      weights ``alpha``.  ``alpha`` is left out of ``parameters()`` /
      ``named_parameters()`` when they are called on this module.

      Parameters
      ----------
      input_choice : object
          Source choice object.  The attributes ``key``, ``n_candidates`` and
          ``n_chosen`` are read from it; a falsy ``n_chosen`` is treated as 1.
      """

      def __init__(self, input_choice):
          super(DartsInputChoice, self).__init__()
          self.name = input_choice.key
          # One architecture weight per candidate input, initialised near zero.
          self.alpha = nn.Parameter(torch.randn(input_choice.n_candidates) * 1e-3)
          self.n_chosen = input_choice.n_chosen or 1

      def forward(self, inputs):
          """Return the softmax(alpha)-weighted sum of ``inputs``.

          ``inputs`` is a sequence of tensors; they are stacked, so they must
          all have the same shape.
          """
          inputs = torch.stack(inputs)
          # Reshape alpha to (n_candidates, 1, ..., 1) so it broadcasts.
          alpha_shape = [-1] + [1] * (inputs.dim() - 1)
          weights = F.softmax(self.alpha, -1).view(*alpha_shape)
          return torch.sum(inputs * weights, 0)

      def parameters(self, recurse=True):
          """Yield this module's parameters except ``alpha``.

          Accepts ``recurse`` like ``nn.Module.parameters`` so that generic
          utilities calling ``module.parameters(recurse=False)`` keep working.
          """
          for _, p in self.named_parameters(recurse=recurse):
              yield p

      def named_parameters(self, prefix="", recurse=True, **kwargs):
          """Yield ``(name, parameter)`` pairs except the one for ``alpha``.

          Takes the same arguments as ``nn.Module.named_parameters``; any extra
          keyword arguments are forwarded to the base implementation.
          """
          # The base class prepends ``prefix`` to every name, so the entry to
          # hide has to be computed with the same rule.
          hidden = prefix + ("." if prefix else "") + "alpha"
          base_iter = super(DartsInputChoice, self).named_parameters(
              prefix=prefix, recurse=recurse, **kwargs
          )
          for name, p in base_iter:
              if name == hidden:
                  continue
              yield name, p

      def export(self):
          """Return the indices of the ``n_chosen`` largest alphas, best first."""
          # Tensor.tolist() already yields plain Python ints on any device, so
          # no detour through NumPy is required.
          return torch.argsort(-self.alpha).tolist()[: self.n_chosen]
  @register_nas_algo("darts")
  class Darts(BaseNAS):
      """
      First-order DARTS search strategy.

      Each epoch performs one architecture-weight update on the "val" mask
      followed by one model-weight update on the "train" mask.

      Parameters
      ----------
      num_epochs : int
          How many epochs (alternating update rounds) to run.
      workers : int
          Number of data-loading workers. Stored on the instance; the code in
          this class does not read it.
      gradient_clip : float
          Maximum gradient norm for the model update. Values <= 0 turn
          clipping off. Default: 5.
      model_lr : float
          Learning rate of the model-weight optimizer.
      model_wd : float
          Weight decay of the model-weight optimizer.
      arch_lr : float
          Learning rate of the architecture-weight optimizer.
      arch_wd : float
          Weight decay of the architecture-weight optimizer.
      device : str or torch.device
          Device used for the whole search.
      """

      def __init__(
          self,
          num_epochs=5,
          workers=4,
          gradient_clip=5.0,
          model_lr=1e-3,
          model_wd=5e-4,
          arch_lr=3e-4,
          arch_wd=1e-3,
          device="auto",
      ):
          super().__init__(device=device)

          # Training schedule.
          self.num_epochs = num_epochs
          self.workers = workers
          self.gradient_clip = gradient_clip

          # Optimizer classes; both are instantiated later in ``search``.
          self.model_optimizer = torch.optim.Adam
          self.arch_optimizer = torch.optim.Adam

          # Optimizer hyper-parameters.
          self.model_lr = model_lr
          self.model_wd = model_wd
          self.arch_lr = arch_lr
          self.arch_wd = arch_wd

      def search(self, space: BaseSpace, dataset, estimator):
          """
          Run the search on ``space`` and return the model parsed from the
          exported selection via ``space.parse_model``.
          """
          # Built before the choices are replaced, i.e. before any alpha exists.
          weight_optimizer = self.model_optimizer(
              space.parameters(), self.model_lr, weight_decay=self.model_wd
          )

          choice_modules = []
          replace_layer_choice(space, DartsLayerChoice, choice_modules)
          replace_input_choice(space, DartsInputChoice, choice_modules)
          space = space.to(self.device)

          # Modules carrying the same label share a single alpha parameter.
          alpha_by_label = {}
          for _, module in choice_modules:
              label = module.name
              if label not in alpha_by_label:
                  alpha_by_label[label] = module.alpha
                  continue
              shared = alpha_by_label[label]
              assert (
                  module.alpha.size() == shared.size()
              ), "Size of parameters with the same label should be same."
              module.alpha = shared

          alpha_optimizer = self.arch_optimizer(
              list(alpha_by_label.values()), self.arch_lr, weight_decay=self.arch_wd
          )

          for epoch_idx in range(self.num_epochs):
              self._train_one_epoch(
                  epoch_idx, space, dataset, estimator, weight_optimizer, alpha_optimizer
              )

          chosen = self.export(choice_modules)
          return space.parse_model(chosen, self.device)

      def _train_one_epoch(
          self,
          epoch,
          model: BaseSpace,
          dataset,
          estimator,
          model_optim: torch.optim.Optimizer,
          arch_optim: torch.optim.Optimizer,
      ):
          """Perform one architecture update followed by one model update."""
          model.train()

          # Step 1: architecture weights, using the "val" mask (no unrolling).
          arch_optim.zero_grad()
          _, val_loss = self._infer(model, dataset, estimator, "val")
          val_loss.backward()
          arch_optim.step()

          # Step 2: model weights, using the "train" mask.
          model_optim.zero_grad()
          _, train_loss = self._infer(model, dataset, estimator, "train")
          train_loss.backward()
          if self.gradient_clip > 0:
              # Clip before stepping so the update uses the bounded gradients.
              nn.utils.clip_grad_norm_(model.parameters(), self.gradient_clip)
          model_optim.step()

      def _infer(self, model: BaseSpace, dataset, estimator: BaseEstimator, mask="train"):
          """Return ``(metric, loss)`` as computed by ``estimator.infer``."""
          outcome = estimator.infer(model, dataset, mask=mask)
          metric, loss = outcome
          return metric, loss

      @torch.no_grad()
      def export(self, nas_modules) -> dict:
          """
          Map each name in ``nas_modules`` to its module's exported decision.
          When a name occurs more than once, only the first module is used.
          """
          decisions = {}
          for label, module in nas_modules:
              if label in decisions:
                  continue
              decisions[label] = module.export()
          return decisions
  129. dataset,
  130. estimator,
  131. model_optim: torch.optim.Optimizer,
  132. arch_optim: torch.optim.Optimizer,
  133. ):
  134. model.train()
  135. # phase 1. architecture step
  136. arch_optim.zero_grad()
  137. # only no unroll here
  138. _, loss = self._infer(model, dataset, estimator, "val")
  139. loss.backward()
  140. arch_optim.step()
  141. # phase 2: child network step
  142. model_optim.zero_grad()
  143. metric, loss = self._infer(model, dataset, estimator, "train")
  144. loss.backward()
  145. # gradient clipping
  146. if self.gradient_clip > 0:
  147. nn.utils.clip_grad_norm_(model.parameters(), self.gradient_clip)
  148. model_optim.step()
  149. def _infer(self, model: BaseSpace, dataset, estimator: BaseEstimator, mask="train"):
  150. metric, loss = estimator.infer(model, dataset, mask=mask)
  151. return metric, loss
  152. @torch.no_grad()
  153. def export(self, nas_modules) -> dict:
  154. result = dict()
  155. for name, module in nas_modules:
  156. if name not in result:
  157. result[name] = module.export()
  158. return result