# Code in this file is reproduced from https://github.com/microsoft/nni with some changes.
import torch
import torch.nn as nn
import torch.nn.functional as F

from . import register_nas_algo
from .base import BaseNAS
from ..space import BaseSpace
from ..utils import (
    AverageMeterGroup,
    replace_layer_choice,
    replace_input_choice,
    get_module_order,
    sort_replaced_module,
    PathSamplingInputChoice,
    PathSamplingLayerChoice,
)
from nni.nas.pytorch.fixed import apply_fixed_architecture
from tqdm import tqdm
from datetime import datetime
import numpy as np

from ....utils import get_logger

LOGGER = get_logger("RL_NAS")


def _get_mask(sampled, total):
    multihot = [
        i == sampled or (isinstance(sampled, list) and i in sampled)
        for i in range(total)
    ]
    return torch.tensor(multihot, dtype=torch.bool)  # pylint: disable=not-callable
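
# For instance: a single sampled index yields a one-hot mask, a list yields a
# multi-hot mask, e.g.
#   _get_mask(1, 3)      -> tensor([False,  True, False])
#   _get_mask([0, 2], 3) -> tensor([ True, False,  True])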


class StackedLSTMCell(nn.Module):
    def __init__(self, layers, size, bias):
        super().__init__()
        self.lstm_num_layers = layers
        self.lstm_modules = nn.ModuleList(
            [nn.LSTMCell(size, size, bias=bias) for _ in range(self.lstm_num_layers)]
        )

    def forward(self, inputs, hidden):
        prev_h, prev_c = hidden
        next_h, next_c = [], []
        for i, m in enumerate(self.lstm_modules):
            curr_h, curr_c = m(inputs, (prev_h[i], prev_c[i]))
            next_c.append(curr_c)
            next_h.append(curr_h)
            # current implementation only supports batch size equals 1,
            # but the algorithm does not necessarily have this limitation
            inputs = curr_h[-1].view(1, -1)
        return next_h, next_c
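
    # Shape convention (inferred from the batch-size-1 usage above): ``inputs``
    # is a (1, size) tensor and ``hidden`` is a pair of per-layer lists
    # ([h_1, ..., h_L], [c_1, ..., c_L]) with each tensor of shape (1, size).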


class ReinforceField:
    """
    A field named ``name`` with ``total`` choices. ``choose_one`` is true if one and only one
    choice is meant to be selected; otherwise, any number of choices can be chosen.
    """

    def __init__(self, name, total, choose_one):
        self.name = name
        self.total = total
        self.choose_one = choose_one

    def __repr__(self):
        return f"ReinforceField(name={self.name}, total={self.total}, choose_one={self.choose_one})"
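
# For instance (hypothetical field names): ``ReinforceField("op_1", 4, True)``
# describes a slot where exactly one of four candidate operators is picked,
# while ``ReinforceField("skip_2", 2, False)`` lets any subset of two inputs
# be chosen.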


class ReinforceController(nn.Module):
    """
    A controller that mutates the graph with RL.

    Parameters
    ----------
    fields : list of ReinforceField
        List of fields to choose.
    lstm_size : int
        Controller LSTM hidden units.
    lstm_num_layers : int
        Number of layers for the stacked LSTM.
    tanh_constant : float
        Logits will be equal to ``tanh_constant * tanh(logits)``. ``tanh`` is not applied if this value is ``None``.
    skip_target : float
        Target probability that a skip connection will appear.
    temperature : float
        Temperature constant that divides the logits.
    entropy_reduction : str
        One of ``sum`` and ``mean``; how the entropy of a multi-input choice is reduced.
    """

    def __init__(
        self,
        fields,
        lstm_size=64,
        lstm_num_layers=1,
        tanh_constant=1.5,
        skip_target=0.4,
        temperature=None,
        entropy_reduction="sum",
    ):
        super(ReinforceController, self).__init__()
        self.fields = fields
        self.lstm_size = lstm_size
        self.lstm_num_layers = lstm_num_layers
        self.tanh_constant = tanh_constant
        self.temperature = temperature
        self.skip_target = skip_target

        self.lstm = StackedLSTMCell(self.lstm_num_layers, self.lstm_size, False)
        self.attn_anchor = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
        self.attn_query = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
        self.v_attn = nn.Linear(self.lstm_size, 1, bias=False)
        self.g_emb = nn.Parameter(torch.randn(1, self.lstm_size) * 0.1)
        self.skip_targets = nn.Parameter(
            torch.tensor(
                [1.0 - self.skip_target, self.skip_target]
            ),  # pylint: disable=not-callable
            requires_grad=False,
        )
        assert entropy_reduction in [
            "sum",
            "mean",
        ], "Entropy reduction must be one of sum and mean."
        self.entropy_reduction = torch.sum if entropy_reduction == "sum" else torch.mean
        self.cross_entropy_loss = nn.CrossEntropyLoss(reduction="none")
        self.soft = nn.ModuleDict(
            {
                field.name: nn.Linear(self.lstm_size, field.total, bias=False)
                for field in fields
            }
        )
        self.embedding = nn.ModuleDict(
            {field.name: nn.Embedding(field.total, self.lstm_size) for field in fields}
        )

    def resample(self):
        self._initialize()
        result = dict()
        for field in self.fields:
            result[field.name] = self._sample_single(field)
        return result

    def _initialize(self):
        self._inputs = self.g_emb.data
        self._c = [
            torch.zeros(
                (1, self.lstm_size),
                dtype=self._inputs.dtype,
                device=self._inputs.device,
            )
            for _ in range(self.lstm_num_layers)
        ]
        self._h = [
            torch.zeros(
                (1, self.lstm_size),
                dtype=self._inputs.dtype,
                device=self._inputs.device,
            )
            for _ in range(self.lstm_num_layers)
        ]
        self.sample_log_prob = 0
        self.sample_entropy = 0
        self.sample_skip_penalty = 0

    def _lstm_next_step(self):
        self._h, self._c = self.lstm(self._inputs, (self._h, self._c))

    def _sample_single(self, field):
        self._lstm_next_step()
        logit = self.soft[field.name](self._h[-1])
        if self.temperature is not None:
            logit /= self.temperature
        if self.tanh_constant is not None:
            logit = self.tanh_constant * torch.tanh(logit)
        if field.choose_one:
            sampled = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
            log_prob = self.cross_entropy_loss(logit, sampled)
            self._inputs = self.embedding[field.name](sampled)
        else:
            logit = logit.view(-1, 1)
            logit = torch.cat(
                [-logit, logit], 1
            )  # pylint: disable=invalid-unary-operand-type
            sampled = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
            skip_prob = torch.sigmoid(logit)
            kl = torch.sum(skip_prob * torch.log(skip_prob / self.skip_targets))
            self.sample_skip_penalty += kl
            log_prob = self.cross_entropy_loss(logit, sampled)
            sampled = sampled.nonzero().view(-1)
            if sampled.sum().item():
                self._inputs = (
                    torch.sum(self.embedding[field.name](sampled.view(-1)), 0)
                    / (1.0 + torch.sum(sampled))
                ).unsqueeze(0)
            else:
                self._inputs = torch.zeros(
                    1, self.lstm_size, device=self.embedding[field.name].weight.device
                )
        sampled = sampled.detach().cpu().numpy().tolist()  # .cpu() added so this also works when the controller lives on GPU
        self.sample_log_prob += self.entropy_reduction(log_prob)
        entropy = (
            log_prob * torch.exp(-log_prob)
        ).detach()  # pylint: disable=invalid-unary-operand-type
        self.sample_entropy += self.entropy_reduction(entropy)
        if len(sampled) == 1:
            sampled = sampled[0]
        return sampled
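
# A brief sketch of one controller round (hypothetical field names; not part of
# the original file):
#   fields = [ReinforceField("op_1", 4, True), ReinforceField("skip_2", 2, False)]
#   controller = ReinforceController(fields)
#   choices = controller.resample()  # e.g. {"op_1": 2, "skip_2": [0, 1]}
#   # after evaluating the sampled architecture to obtain ``reward``:
#   loss = controller.sample_log_prob * (reward - baseline)
#   loss.backward()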


@register_nas_algo("rl")
class RL(BaseNAS):
    """
    RL in GraphNAS.

    Parameters
    ----------
    num_epochs : int
        Number of epochs planned for training.
    device : torch.device
        ``torch.device("cpu")`` or ``torch.device("cuda")``.
    log_frequency : int
        Step count per logging.
    grad_clip : float
        Gradient clipping. Set to 0 to disable. Default: 5.
    entropy_weight : float
        Weight of sample entropy loss.
    skip_weight : float
        Weight of skip penalty loss.
    baseline_decay : float
        Decay factor of baseline. New baseline will be equal to ``baseline_decay * baseline_old + reward * (1 - baseline_decay)``.
    ctrl_lr : float
        Learning rate for the RL controller.
    ctrl_steps_aggregate : int
        Number of steps that will be aggregated into one mini-batch for the RL controller.
    ctrl_kwargs : dict
        Optional kwargs that will be passed to :class:`ReinforceController`.
    n_warmup : int
        Number of epochs for training the super network.
    model_lr : float
        Learning rate for the super network.
    model_wd : float
        Weight decay for the super network.
    disable_progress : boolean
        Whether to disable the progress bar.
    """

    def __init__(
        self,
        num_epochs=5,
        device="auto",
        log_frequency=None,
        grad_clip=5.0,
        entropy_weight=0.0001,
        skip_weight=0.8,
        baseline_decay=0.999,
        ctrl_lr=0.00035,
        ctrl_steps_aggregate=20,
        ctrl_kwargs=None,
        n_warmup=100,
        model_lr=5e-3,
        model_wd=5e-4,
        disable_progress=False,
    ):
        super().__init__(device)
        self.num_epochs = num_epochs
        self.log_frequency = log_frequency
        self.entropy_weight = entropy_weight
        self.skip_weight = skip_weight
        self.baseline_decay = baseline_decay
        self.baseline = 0.0
        self.ctrl_steps_aggregate = ctrl_steps_aggregate
        self.grad_clip = grad_clip
        self.ctrl_kwargs = ctrl_kwargs
        self.ctrl_lr = ctrl_lr
        self.n_warmup = n_warmup
        self.model_lr = model_lr
        self.model_wd = model_wd
        self.disable_progress = disable_progress

    def search(self, space: BaseSpace, dset, estimator):
        self.model = space
        self.dataset = dset  # .to(self.device)
        self.estimator = estimator
        # replace choice
        self.nas_modules = []
        k2o = get_module_order(self.model)
        replace_layer_choice(self.model, PathSamplingLayerChoice, self.nas_modules)
        replace_input_choice(self.model, PathSamplingInputChoice, self.nas_modules)
        self.nas_modules = sort_replaced_module(k2o, self.nas_modules)
        # to device
        self.model = self.model.to(self.device)
        # fields
        self.nas_fields = [
            ReinforceField(
                name,
                len(module),
                isinstance(module, PathSamplingLayerChoice) or module.n_chosen == 1,
            )
            for name, module in self.nas_modules
        ]
        self.controller = ReinforceController(
            self.nas_fields, **(self.ctrl_kwargs or {})
        )
        self.ctrl_optim = torch.optim.Adam(
            self.controller.parameters(), lr=self.ctrl_lr
        )
        # train
        with tqdm(range(self.num_epochs), disable=self.disable_progress) as bar:
            for i in bar:
                l2 = self._train_controller(i)
                bar.set_postfix(reward_controller=l2)
        selection = self.export()
        arch = space.parse_model(selection, self.device)
        # print(selection, arch)
        return arch

    def _train_controller(self, epoch):
        self.model.eval()
        self.controller.train()
        self.ctrl_optim.zero_grad()
        rewards = []
        with tqdm(
            range(self.ctrl_steps_aggregate), disable=self.disable_progress
        ) as bar:
            for ctrl_step in bar:
                self._resample()
                metric, loss = self._infer(mask="val")
                reward = metric
                bar.set_postfix(acc=metric, loss=loss.item())
                LOGGER.debug(f"{self.arch}\n{self.selection}\n{metric},{loss}")
                rewards.append(reward)
                if self.entropy_weight:
                    reward += (
                        self.entropy_weight * self.controller.sample_entropy.item()
                    )
                self.baseline = self.baseline * self.baseline_decay + reward * (
                    1 - self.baseline_decay
                )
                loss = self.controller.sample_log_prob * (reward - self.baseline)
                if self.skip_weight:
                    loss += self.skip_weight * self.controller.sample_skip_penalty
                loss /= self.ctrl_steps_aggregate
                loss.backward()

                if (ctrl_step + 1) % self.ctrl_steps_aggregate == 0:
                    if self.grad_clip > 0:
                        nn.utils.clip_grad_norm_(
                            self.controller.parameters(), self.grad_clip
                        )
                    self.ctrl_optim.step()
                    self.ctrl_optim.zero_grad()
                if (
                    self.log_frequency is not None
                    and ctrl_step % self.log_frequency == 0
                ):
                    LOGGER.debug(
                        "RL Epoch [%d/%d] Step [%d/%d]",
                        epoch + 1,
                        self.num_epochs,
                        ctrl_step + 1,
                        self.ctrl_steps_aggregate,
                    )
        return sum(rewards) / len(rewards)
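
    # Note on the update above: ``sample_log_prob`` accumulates
    # ``cross_entropy_loss`` terms, i.e. the *negative* log-probability of the
    # sampled architecture, so minimizing ``sample_log_prob * (reward - baseline)``
    # raises the likelihood of above-baseline architectures. This is standard
    # REINFORCE with a moving-average baseline; the entropy and skip penalty
    # terms act as regularizers.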

    def _resample(self):
        result = self.controller.resample()
        self.arch = self.model.parse_model(result, device=self.device)
        self.selection = result

    def export(self):
        self.controller.eval()
        with torch.no_grad():
            return self.controller.resample()

    def _infer(self, mask="train"):
        metric, loss = self.estimator.infer(self.arch._model, self.dataset, mask=mask)
        return metric[0], loss
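
# A minimal usage sketch (hypothetical; ``space``, ``dataset`` and ``estimator``
# are assumed to come from the surrounding framework, e.g. a ``BaseSpace``
# subclass and a matching estimator):
#   algo = RL(num_epochs=5, device="cuda")
#   best_model = algo.search(space, dataset, estimator)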


@register_nas_algo("graphnas")
class GraphNasRL(BaseNAS):
    """
    RL in GraphNAS.

    Parameters
    ----------
    device : torch.device
        ``torch.device("cpu")`` or ``torch.device("cuda")``.
    num_epochs : int
        Number of epochs planned for training.
    log_frequency : int
        Step count per logging.
    grad_clip : float
        Gradient clipping. Set to 0 to disable. Default: 5.
    entropy_weight : float
        Weight of sample entropy loss.
    skip_weight : float
        Weight of skip penalty loss.
    baseline_decay : float
        Decay factor of baseline. New baseline will be equal to ``baseline_decay * baseline_old + reward * (1 - baseline_decay)``.
    ctrl_lr : float
        Learning rate for the RL controller.
    ctrl_steps_aggregate : int
        Number of steps that will be aggregated into one mini-batch for the RL controller.
    ctrl_kwargs : dict
        Optional kwargs that will be passed to :class:`ReinforceController`.
    n_warmup : int
        Number of epochs for training the super network.
    model_lr : float
        Learning rate for the super network.
    model_wd : float
        Weight decay for the super network.
    topk : int
        Number of architectures kept during the training process.
    disable_progress : boolean
        Whether to disable the progress bar.
    hardware_metric_limit : float, optional
        If set, sampled architectures whose first hardware metric is at or above this limit are excluded from the top-k history.
    """

    def __init__(
        self,
        device="auto",
        num_epochs=10,
        log_frequency=None,
        grad_clip=5.0,
        entropy_weight=0.0001,
        skip_weight=0,
        baseline_decay=0.95,
        ctrl_lr=0.00035,
        ctrl_steps_aggregate=100,
        ctrl_kwargs=None,
        n_warmup=100,
        model_lr=5e-3,
        model_wd=5e-4,
        topk=5,
        disable_progress=False,
        hardware_metric_limit=None,
    ):
        super().__init__(device)
        self.num_epochs = num_epochs
        self.log_frequency = log_frequency
        self.entropy_weight = entropy_weight
        self.skip_weight = skip_weight
        self.baseline_decay = baseline_decay
        self.ctrl_steps_aggregate = ctrl_steps_aggregate
        self.grad_clip = grad_clip
        self.ctrl_kwargs = ctrl_kwargs
        self.ctrl_lr = ctrl_lr
        self.n_warmup = n_warmup
        self.model_lr = model_lr
        self.model_wd = model_wd
        self.hist = []
        self.topk = topk
        self.disable_progress = disable_progress
        self.hardware_metric_limit = hardware_metric_limit
        self.allhist = []

    def search(self, space: BaseSpace, dset, estimator):
        self.model = space
        self.dataset = dset  # .to(self.device)
        self.estimator = estimator
        # replace choice
        self.nas_modules = []
        k2o = get_module_order(self.model)
        replace_layer_choice(self.model, PathSamplingLayerChoice, self.nas_modules)
        replace_input_choice(self.model, PathSamplingInputChoice, self.nas_modules)
        self.nas_modules = sort_replaced_module(k2o, self.nas_modules)
        # to device
        self.model = self.model.to(self.device)
        # fields
        self.nas_fields = [
            ReinforceField(
                name,
                len(module),
                isinstance(module, PathSamplingLayerChoice) or module.n_chosen == 1,
            )
            for name, module in self.nas_modules
        ]
        self.controller = ReinforceController(
            self.nas_fields,
            lstm_size=100,
            temperature=5.0,
            tanh_constant=2.5,
            **(self.ctrl_kwargs or {}),
        )
        self.ctrl_optim = torch.optim.Adam(
            self.controller.parameters(), lr=self.ctrl_lr
        )
        # train
        with tqdm(range(self.num_epochs), disable=self.disable_progress) as bar:
            for i in bar:
                l2 = self._train_controller(i)
                bar.set_postfix(reward_controller=l2)
        # selection = self.export()
        selections = [x[1] for x in self.hist]
        candidate_accs = [-x[0] for x in self.hist]
        # print('candidate accuracies', candidate_accs)
        selection = self._choose_best(selections)
        arch = space.parse_model(selection, self.device)
        # print(selection, arch)
        return arch

    def _choose_best(self, selections):
        # GraphNAS keeps the top 5 models, evaluates each 20 times, and chooses the best.
        results = []
        for selection in selections:
            accs = []
            for i in tqdm(range(20), disable=self.disable_progress):
                self.arch = self.model.parse_model(selection, device=self.device)
                metric, loss, _ = self._infer(mask="val")
                accs.append(metric)
            result = np.mean(accs)
            LOGGER.info(
                "selection {} \n acc {:.4f} +- {:.4f}".format(
                    selection, np.mean(accs), np.std(accs) / np.sqrt(20)
                )
            )
            results.append(result)
        best_selection = selections[np.argmax(results)]
        return best_selection

    def _train_controller(self, epoch):
        self.model.eval()
        self.controller.train()
        self.ctrl_optim.zero_grad()
        rewards = []
        baseline = None
        # diff: GraphNAS trains 100 and derives 100 architectures per epoch (10 epochs);
        # we just train 100 per epoch (20 epochs). The total number of samples is the same (2000).
        with tqdm(
            range(self.ctrl_steps_aggregate), disable=self.disable_progress
        ) as bar:
            for ctrl_step in bar:
                self._resample()
                metric, loss, hardware_metric = self._infer(mask="val")
                reward = metric
                # bar.set_postfix(acc=metric, loss=loss.item())
                LOGGER.debug(f"{self.arch}\n{self.selection}\n{metric},{loss}")
                # diff: no reward shaping, unlike the GraphNAS code
                if (
                    self.hardware_metric_limit is None
                    or hardware_metric[0] < self.hardware_metric_limit
                ):
                    self.hist.append([-metric, self.selection])
                    self.allhist.append([-metric, self.selection])
                    if len(self.hist) > self.topk:
                        self.hist.sort(key=lambda x: x[0])
                        self.hist.pop()
                rewards.append(reward)
                if self.entropy_weight:
                    reward += (
                        self.entropy_weight * self.controller.sample_entropy.item()
                    )
                if not baseline:
                    baseline = reward
                else:
                    baseline = baseline * self.baseline_decay + reward * (
                        1 - self.baseline_decay
                    )
                loss = self.controller.sample_log_prob * (reward - baseline)
                self.ctrl_optim.zero_grad()
                loss.backward()
                self.ctrl_optim.step()
                bar.set_postfix(acc=metric, max_acc=max(rewards))
        LOGGER.info("epoch:{}, mean rewards:{}".format(epoch, sum(rewards) / len(rewards)))
        return sum(rewards) / len(rewards)

    def _resample(self):
        result = self.controller.resample()
        self.arch = self.model.parse_model(result, device=self.device)
        self.selection = result

    def export(self):
        self.controller.eval()
        with torch.no_grad():
            return self.controller.resample()

    def _infer(self, mask="train"):
        metric, loss = self.estimator.infer(self.arch._model, self.dataset, mask=mask)
        return metric[0], loss, metric[1:]
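
# A minimal end-to-end sketch for the GraphNAS variant (hypothetical names; the
# search space, dataset and estimator objects come from the surrounding
# framework and are not defined in this file):
#   algo = GraphNasRL(num_epochs=10, topk=5, device="cuda")
#   best_model = algo.search(space, dataset, estimator)
# ``search`` re-evaluates each of the top-k sampled architectures 20 times and
# returns the parsed model of the best one.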