You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318
  1. # codes in this file are reproduced from https://github.com/microsoft/nni with some changes.
  2. import copy
  3. from logging import Logger
  4. from numpy.core.fromnumeric import sort
  5. import torch
  6. import torch.nn as nn
  7. import torch.nn.functional as F
  8. from . import register_nas_algo
  9. from .base import BaseNAS
  10. from ..space import BaseSpace
  11. from ..utils import (
  12. AverageMeterGroup,
  13. replace_layer_choice,
  14. replace_input_choice,
  15. get_module_order,
  16. sort_replaced_module,
  17. PathSamplingLayerChoice,
  18. PathSamplingInputChoice,
  19. )
  20. from tqdm import tqdm, trange
  21. from ....utils import get_logger
  22. import numpy as np
  23. LOGGER = get_logger("SPOS")
  24. import collections
  25. import dataclasses
  26. import random
  27. @dataclasses.dataclass
  28. class Individual:
  29. """
  30. A class that represents an individual.
  31. Holds two attributes, where ``x`` is the model and ``y`` is the metric (e.g., accuracy).
  32. """
  33. x: dict
  34. y: float
  35. class Evolution:
  36. """
  37. Algorithm for regularized evolution (i.e. aging evolution).
  38. Follows "Algorithm 1" in Real et al. "Regularized Evolution for Image Classifier Architecture Search".
  39. Parameters
  40. ----------
  41. optimize_mode : str
  42. Can be one of "maximize" and "minimize". Default: maximize.
  43. population_size : int
  44. The number of individuals to keep in the population. Default: 100.
  45. cycles : int
  46. The number of cycles (trials) the algorithm should run for. Default: 20000.
  47. sample_size : int
  48. The number of individuals that should participate in each tournament. Default: 25.
  49. mutation_prob : float
  50. Probability that mutation happens in each dim. Default: 0.05
  51. """
  52. def __init__(self, optimize_mode='maximize', population_size=100, sample_size=25, cycles=20000,
  53. mutation_prob=0.05,disable_progress=False):
  54. assert optimize_mode in ['maximize', 'minimize']
  55. assert sample_size < population_size
  56. self.optimize_mode = optimize_mode
  57. self.population_size = population_size
  58. self.sample_size = sample_size
  59. self.cycles = cycles
  60. self.mutation_prob = mutation_prob
  61. self.disable_progress= disable_progress
  62. self._worst = float('-inf') if self.optimize_mode == 'maximize' else float('inf')
  63. self._success_count = 0
  64. self._population = collections.deque()
  65. self._running_models = []
  66. self._polling_interval = 2.
  67. self._history = []
  68. def best_parent(self,sample_size=None):
  69. """get the config of the best parent
  70. """
  71. samples = [p for p in self._population] # copy population
  72. random.shuffle(samples)
  73. if sample_size is not None:
  74. samples = list(samples)[:sample_size]
  75. if self.optimize_mode == 'maximize':
  76. parent = max(samples, key=lambda sample: sample.y)
  77. else:
  78. parent = min(samples, key=lambda sample: sample.y)
  79. return parent.x
  80. def _prepare(self):
  81. self.uniform=UniformSampler(self.nas_modules)
  82. self.mutation=MutationSampler(self.nas_modules,self.mutation_prob)
  83. def _get_metric(self,config):
  84. for name, module in self.nas_modules:
  85. module.sampled = config[name]
  86. # todo: this may be computational expensive
  87. # model=self.model.parse_model(config,self.device)
  88. with torch.no_grad():
  89. metric, loss = self.estimator.infer(self.model, self.dataset, mask='val')
  90. return metric[0]
  91. def search(self, space: BaseSpace,nas_modules,dset,estimator,device):
  92. self.model = space
  93. self.dataset = dset
  94. self.estimator = estimator
  95. self.nas_modules = nas_modules
  96. self.device = device
  97. self._prepare()
  98. LOGGER.info('Initializing the first population.')
  99. with tqdm(range(self.population_size), disable=self.disable_progress) as bar:
  100. for i in bar:
  101. config = self.uniform.resample()
  102. metric=self._get_metric(config)
  103. individual = Individual(config, metric)
  104. # LOGGER.debug('Individual created: %s', str(individual))
  105. self._population.append(individual)
  106. self._history.append(individual)
  107. bar.set_postfix(metric=metric,max=max(x.y for x in self._population),min=min(x.y for x in self._population))
  108. LOGGER.info('Running mutations.')
  109. with tqdm(range(self.cycles), disable=self.disable_progress) as bar:
  110. for i in bar:
  111. parent=self.best_parent(self.sample_size)
  112. config=self.mutation.resample(parent)
  113. metric=self._get_metric(config) # todo : add aging factor
  114. individual = Individual(config, metric)
  115. LOGGER.debug('Individual created: %s', str(individual))
  116. self._population.append(individual)
  117. self._history.append(individual)
  118. if len(self._population) > self.population_size:
  119. self._population.popleft()
  120. bar.set_postfix(metric=metric,max_h=max(x.y for x in self._history),max=max(x.y for x in self._population),min=min(x.y for x in self._population))
  121. # todo: origin is best in history | or the population may need to be retrained
  122. self._history.sort(key=lambda x: x.y)
  123. # best=self.best_parent()
  124. if self.optimize_mode == 'maximize':
  125. best=self._history[-1].x
  126. else:
  127. best=self._history[0].x
  128. return best
  129. class MutationSampler:
  130. """uniform mutator
  131. Parameters
  132. ----------
  133. nas_modules:
  134. nas_modules in NAS algorithms , including choices of modules
  135. mutation_prob: float
  136. probability of doing mutation in each choice.
  137. parent : dict
  138. parent individual's choices
  139. """
  140. def __init__(self,nas_modules,mutation_prob):
  141. selection_range = {}
  142. for k, v in nas_modules:
  143. selection_range[k] = len(v)
  144. self.selection_dict = selection_range
  145. self.mutation_prob = mutation_prob
  146. def resample(self, parent):
  147. search_space=self.selection_dict
  148. child = {}
  149. for k, v in parent.items():
  150. if random.uniform(0, 1) < self.mutation_prob:
  151. child[k] = np.random.choice(range(search_space[k])) # do not exclude the original operator
  152. else:
  153. child[k] = v
  154. return child
  155. class UniformSampler:
  156. """Uniform Sampler
  157. Parameters
  158. ----------
  159. nas_modules:
  160. nas_modules in NAS algorithms , including choices of modules
  161. """
  162. def __init__(self,nas_modules):
  163. selection_range = {}
  164. for k, v in nas_modules:
  165. selection_range[k] = len(v)
  166. self.selection_dict = selection_range
  167. def resample(self):
  168. selection = {}
  169. for k, v in self.selection_dict.items():
  170. selection[k] = np.random.choice(range(v))
  171. return selection
  172. @register_nas_algo("spos")
  173. class Spos(BaseNAS):
  174. """
  175. SPOS trainer.
  176. Parameters
  177. ----------
  178. n_warmup : int
  179. Number of epochs for training super network.
  180. model_lr : float
  181. Learning rate for super network.
  182. model_wd : float
  183. Weight decay for super network.
  184. Other parameters see Evolution
  185. """
  186. def __init__(
  187. self,
  188. n_warmup=1000,
  189. grad_clip=5.0,
  190. disable_progress=False,
  191. optimize_mode='maximize',
  192. population_size=100,
  193. sample_size=25,
  194. cycles=20000,
  195. mutation_prob=0.05,
  196. device="cuda",
  197. ):
  198. super().__init__(device)
  199. self.model_lr=5e-3
  200. self.model_wd=5e-4
  201. self.n_warmup = n_warmup
  202. self.disable_progress= disable_progress
  203. self.grad_clip = grad_clip
  204. self.optimize_mode = optimize_mode
  205. self.population_size = population_size
  206. self.sample_size = sample_size
  207. self.cycles = cycles
  208. self.mutation_prob = mutation_prob
  209. def _prepare(self):
  210. # replace choice
  211. self.nas_modules = []
  212. k2o = get_module_order(self.model)
  213. replace_layer_choice(self.model, PathSamplingLayerChoice, self.nas_modules)
  214. replace_input_choice(self.model, PathSamplingInputChoice, self.nas_modules)
  215. self.nas_modules = sort_replaced_module(k2o, self.nas_modules)
  216. # to device
  217. self.model = self.model.to(self.device)
  218. self.model_optim = torch.optim.Adam(
  219. self.model.parameters(), lr=self.model_lr, weight_decay=self.model_wd
  220. )
  221. # controller
  222. self.controller=UniformSampler(self.nas_modules)
  223. # Evolution
  224. self.evolve = Evolution(
  225. optimize_mode='maximize',
  226. population_size=self.population_size,
  227. sample_size=self.sample_size,
  228. cycles=self.cycles,
  229. mutation_prob=self.mutation_prob,
  230. disable_progress=self.disable_progress
  231. )
  232. def search(self, space: BaseSpace, dset, estimator):
  233. self.model = space
  234. self.dataset = dset
  235. self.estimator = estimator
  236. self._prepare()
  237. self._train() # train using uniform sampling
  238. self._search() # search using evolutionary algorithm
  239. selection = self.export()
  240. # here may sample N , retrain N ,and get best
  241. print(selection)
  242. return space.parse_model(selection, self.device)
  243. def _search(self):
  244. self.best_config=self.evolve.search(
  245. self.model,
  246. self.nas_modules,
  247. self.dataset,
  248. self.estimator,
  249. self.device,
  250. )
  251. def _train(self):
  252. with tqdm(range(self.n_warmup), disable=self.disable_progress) as bar:
  253. for i in bar:
  254. acc, l1 = self._train_one_epoch(i)
  255. with torch.no_grad():
  256. val_acc, val_loss = self._infer("val")
  257. bar.set_postfix(loss=l1, acc=acc, val_acc=val_acc, val_loss=val_loss.item())
  258. def _train_one_epoch(self, epoch):
  259. self.model.train()
  260. self.model_optim.zero_grad()
  261. self._resample() # uniform sampling
  262. metric, loss = self._infer(mask="train")
  263. loss.backward()
  264. if self.grad_clip > 0:
  265. nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)
  266. self.model_optim.step()
  267. return metric, loss.item()
  268. def _resample(self):
  269. result=self.controller.resample()
  270. for name, module in self.nas_modules:
  271. module.sampled = result[name]
  272. def export(self):
  273. return self.best_config
  274. def _infer(self, mask="train"):
  275. metric, loss = self.estimator.infer(self.model, self.dataset, mask=mask)
  276. return metric[0], loss