You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

enas.py 7.5 kB

5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
  1. # Code in this file is reproduced from https://github.com/microsoft/nni with some changes.
  2. import copy
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. from . import register_nas_algo
  7. from .base import BaseNAS
  8. from ..space import BaseSpace
  9. from ..utils import (
  10. AverageMeterGroup,
  11. replace_layer_choice,
  12. replace_input_choice,
  13. get_module_order,
  14. sort_replaced_module,
  15. )
  16. from tqdm import tqdm, trange
  17. from .rl import (
  18. PathSamplingLayerChoice,
  19. PathSamplingInputChoice,
  20. ReinforceField,
  21. ReinforceController,
  22. )
  23. from ....utils import get_logger
  24. LOGGER = get_logger("ENAS")
  25. @register_nas_algo("enas")
  26. class Enas(BaseNAS):
  27. """
  28. ENAS trainer.
  29. Parameters
  30. ----------
  31. num_epochs : int
  32. Number of epochs planned for training.
  33. n_warmup : int
  34. Number of epochs for training super network.
  35. log_frequency : int
  36. Step count per logging.
  37. grad_clip : float
  38. Gradient clipping. Set to 0 to disable. Default: 5.
  39. entropy_weight : float
  40. Weight of sample entropy loss.
  41. skip_weight : float
  42. Weight of skip penalty loss.
  43. baseline_decay : float
  44. Decay factor of baseline. New baseline will be equal to ``baseline_decay * baseline_old + reward * (1 - baseline_decay)``.
  45. ctrl_lr : float
  46. Learning rate for RL controller.
  47. ctrl_steps_aggregate : int
  48. Number of steps that will be aggregated into one mini-batch for RL controller.
  49. ctrl_kwargs : dict
  50. Optional kwargs that will be passed to :class:`ReinforceController`.
  51. model_lr : float
  52. Learning rate for super network.
  53. model_wd : float
  54. Weight decay for super network.
  55. disable_progeress: boolean
  56. Control whether show the progress bar.
  57. device : str or torch.device
  58. The device of the whole process, e.g. "cuda", torch.device("cpu")
  59. """
  60. def __init__(
  61. self,
  62. num_epochs=5,
  63. n_warmup=100,
  64. log_frequency=None,
  65. grad_clip=5.0,
  66. entropy_weight=0.0001,
  67. skip_weight=0.8,
  68. baseline_decay=0.999,
  69. ctrl_lr=0.00035,
  70. ctrl_steps_aggregate=20,
  71. ctrl_kwargs=None,
  72. model_lr=5e-3,
  73. model_wd=5e-4,
  74. disable_progress=True,
  75. device="auto",
  76. ):
  77. super().__init__(device)
  78. self.num_epochs = num_epochs
  79. self.log_frequency = log_frequency
  80. self.entropy_weight = entropy_weight
  81. self.skip_weight = skip_weight
  82. self.baseline_decay = baseline_decay
  83. self.baseline = 0.0
  84. self.ctrl_steps_aggregate = ctrl_steps_aggregate
  85. self.grad_clip = grad_clip
  86. self.ctrl_kwargs = ctrl_kwargs
  87. self.ctrl_lr = ctrl_lr
  88. self.n_warmup = n_warmup
  89. self.model_lr = model_lr
  90. self.model_wd = model_wd
  91. self.disable_progress = disable_progress
  92. def search(self, space: BaseSpace, dset, estimator):
  93. self.model = space
  94. self.dataset = dset # .to(self.device)
  95. self.estimator = estimator
  96. # replace choice
  97. self.nas_modules = []
  98. k2o = get_module_order(self.model)
  99. replace_layer_choice(self.model, PathSamplingLayerChoice, self.nas_modules)
  100. replace_input_choice(self.model, PathSamplingInputChoice, self.nas_modules)
  101. self.nas_modules = sort_replaced_module(k2o, self.nas_modules)
  102. # to device
  103. self.model = self.model.to(self.device)
  104. self.model_optim = torch.optim.Adam(
  105. self.model.parameters(), lr=self.model_lr, weight_decay=self.model_wd
  106. )
  107. # fields
  108. self.nas_fields = [
  109. ReinforceField(
  110. name,
  111. len(module),
  112. isinstance(module, PathSamplingLayerChoice) or module.n_chosen == 1,
  113. )
  114. for name, module in self.nas_modules
  115. ]
  116. self.controller = ReinforceController(
  117. self.nas_fields, **(self.ctrl_kwargs or {})
  118. )
  119. self.ctrl_optim = torch.optim.Adam(
  120. self.controller.parameters(), lr=self.ctrl_lr
  121. )
  122. # warm up supernet
  123. with tqdm(range(self.n_warmup), disable=self.disable_progress) as bar:
  124. for i in bar:
  125. acc, l1 = self._train_model(i)
  126. with torch.no_grad():
  127. val_acc, val_loss = self._infer("val")
  128. bar.set_postfix(loss=l1, acc=acc, val_acc=val_acc, val_loss=val_loss)
  129. # train
  130. with tqdm(range(self.num_epochs), disable=self.disable_progress) as bar:
  131. for i in bar:
  132. #try:
  133. l1 = self._train_model(i)
  134. l2 = self._train_controller(i)
  135. """except Exception as e:
  136. print(e)
  137. nm = self.nas_modules
  138. for i in range(len(nm)):
  139. print(nm[i][1].sampled)"""
  140. bar.set_postfix(loss_model=l1, reward_controller=l2)
  141. selection = self.export()
  142. # print(selection)
  143. return space.parse_model(selection, self.device)
  144. def _train_model(self, epoch):
  145. self.model.train()
  146. self.controller.eval()
  147. self.model_optim.zero_grad()
  148. self._resample()
  149. metric, loss = self._infer()
  150. loss.backward()
  151. if self.grad_clip > 0:
  152. nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)
  153. self.model_optim.step()
  154. return metric, loss.item()
  155. def _train_controller(self, epoch):
  156. self.model.eval()
  157. self.controller.train()
  158. self.ctrl_optim.zero_grad()
  159. rewards = []
  160. for ctrl_step in range(self.ctrl_steps_aggregate):
  161. self._resample()
  162. with torch.no_grad():
  163. metric, loss = self._infer(mask="val")
  164. reward = metric
  165. rewards.append(reward)
  166. if self.entropy_weight:
  167. reward += self.entropy_weight * self.controller.sample_entropy.item()
  168. self.baseline = self.baseline * self.baseline_decay + reward * (
  169. 1 - self.baseline_decay
  170. )
  171. loss = self.controller.sample_log_prob * (reward - self.baseline)
  172. if self.skip_weight:
  173. loss += self.skip_weight * self.controller.sample_skip_penalty
  174. loss /= self.ctrl_steps_aggregate
  175. loss.backward()
  176. if (ctrl_step + 1) % self.ctrl_steps_aggregate == 0:
  177. if self.grad_clip > 0:
  178. nn.utils.clip_grad_norm_(
  179. self.controller.parameters(), self.grad_clip
  180. )
  181. self.ctrl_optim.step()
  182. self.ctrl_optim.zero_grad()
  183. if self.log_frequency is not None and ctrl_step % self.log_frequency == 0:
  184. LOGGER.info(
  185. "RL Epoch [%d/%d] Step [%d/%d] %s",
  186. epoch + 1,
  187. self.num_epochs,
  188. ctrl_step + 1,
  189. self.ctrl_steps_aggregate,
  190. )
  191. return sum(rewards) / len(rewards)
  192. def _resample(self):
  193. result = self.controller.resample()
  194. for name, module in self.nas_modules:
  195. module.sampled = result[name]
  196. def export(self):
  197. self.controller.eval()
  198. with torch.no_grad():
  199. return self.controller.resample()
  200. def _infer(self, mask="train"):
  201. metric, loss = self.estimator.infer(self.model, self.dataset, mask=mask)
  202. return metric[0], loss