You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

rl.py 21 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597
  1. # codes in this file are reproduced from https://github.com/microsoft/nni with some changes.
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. from . import register_nas_algo
  6. from .base import BaseNAS
  7. from ..space import BaseSpace
  8. from ..utils import (
  9. AverageMeterGroup,
  10. replace_layer_choice,
  11. replace_input_choice,
  12. get_module_order,
  13. sort_replaced_module,
  14. PathSamplingInputChoice,
  15. PathSamplingLayerChoice,
  16. )
  17. from nni.nas.pytorch.fixed import apply_fixed_architecture
  18. from tqdm import tqdm
  19. from datetime import datetime
  20. import numpy as np
  21. from ....utils import get_logger
  # Module-level logger shared by the RL-based NAS algorithms in this file.
  LOGGER = get_logger("RL_NAS")
  23. def _get_mask(sampled, total):
  24. multihot = [
  25. i == sampled or (isinstance(sampled, list) and i in sampled)
  26. for i in range(total)
  27. ]
  28. return torch.tensor(multihot, dtype=torch.bool) # pylint: disable=not-callable
  class StackedLSTMCell(nn.Module):
      """A stack of ``nn.LSTMCell`` layers that advances one time step per call.

      Parameters
      ----------
      layers : int
          Number of stacked cells.
      size : int
          Input and hidden size of every cell. They are equal so that the output
          of one layer can be fed directly into the next.
      bias : bool
          Whether the cells use bias terms.
      """

      def __init__(self, layers, size, bias):
          super().__init__()
          self.lstm_num_layers = layers
          self.lstm_modules = nn.ModuleList(
              [nn.LSTMCell(size, size, bias=bias) for _ in range(self.lstm_num_layers)]
          )

      def forward(self, inputs, hidden):
          """Advance every layer by one step.

          Parameters
          ----------
          inputs : torch.Tensor
              Input to the first layer; presumably ``(batch, size)``.
          hidden : tuple of (list of torch.Tensor, list of torch.Tensor)
              Per-layer previous hidden states and previous cell states.

          Returns
          -------
          tuple of (list of torch.Tensor, list of torch.Tensor)
              Per-layer next hidden states and next cell states.
          """
          prev_h, prev_c = hidden
          next_h, next_c = [], []
          for i, m in enumerate(self.lstm_modules):
              curr_h, curr_c = m(inputs, (prev_h[i], prev_c[i]))
              next_c.append(curr_c)
              next_h.append(curr_h)
              # Feed this layer's hidden state into the next layer. For batch size 1
              # this equals the former ``curr_h[-1].view(1, -1)``; unlike that
              # form, it also works for batch sizes greater than 1.
              inputs = curr_h
          return next_h, next_c
  class ReinforceField:
      """A named decision offering ``total`` candidate options.

      If ``choose_one`` is true, exactly one option is picked; otherwise any
      subset of the options may be picked.
      """

      def __init__(self, name, total, choose_one):
          self.name, self.total, self.choose_one = name, total, choose_one

      def __repr__(self):
          return f"ReinforceField(name={self.name}, total={self.total}, choose_one={self.choose_one})"
  class ReinforceController(nn.Module):
      """
      A controller that mutates the graph with RL.

      An LSTM visits ``fields`` in order. For each field it produces logits,
      samples a choice, and feeds an embedding of that choice back as the next
      LSTM input. :meth:`resample` resets and then accumulates three values:

      - ``sample_log_prob``: the cross-entropy, i.e. negative log-probability,
        of the sampled choices.
      - ``sample_entropy``: a detached entropy estimate.
      - ``sample_skip_penalty``: the KL penalty for multi-choice fields.

      Parameters
      ----------
      fields : list of ReinforceField
          List of fields to choose.
      lstm_size : int
          Controller LSTM hidden units.
      lstm_num_layers : int
          Number of layers for stacked LSTM.
      tanh_constant : float
          Logits will be equal to ``tanh_constant * tanh(logits)``. Don't use ``tanh`` if this value is ``None``.
      skip_target : float
          Target probability that skipconnect will appear.
      temperature : float
          Temperature constant that divides the logits.
      entropy_reduction : str
          Can be one of ``sum`` and ``mean``. How the per-option log-probability
          and entropy of a multi-input choice are reduced.
      """

      def __init__(
          self,
          fields,
          lstm_size=64,
          lstm_num_layers=1,
          tanh_constant=1.5,
          skip_target=0.4,
          temperature=None,
          entropy_reduction="sum",
      ):
          super().__init__()
          self.fields = fields
          self.lstm_size = lstm_size
          self.lstm_num_layers = lstm_num_layers
          self.tanh_constant = tanh_constant
          self.temperature = temperature
          self.skip_target = skip_target
          self.lstm = StackedLSTMCell(self.lstm_num_layers, self.lstm_size, False)
          # The attention layers below are not used by the sampling code in this class.
          self.attn_anchor = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
          self.attn_query = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
          self.v_attn = nn.Linear(self.lstm_size, 1, bias=False)
          self.g_emb = nn.Parameter(torch.randn(1, self.lstm_size) * 0.1)
          # Target [P(skip), P(pick)] distribution for the KL skip penalty (not trained).
          self.skip_targets = nn.Parameter(
              torch.tensor([1.0 - self.skip_target, self.skip_target]),  # pylint: disable=not-callable
              requires_grad=False,
          )
          assert entropy_reduction in [
              "sum",
              "mean",
          ], "Entropy reduction must be one of sum and mean."
          self.entropy_reduction = torch.sum if entropy_reduction == "sum" else torch.mean
          self.cross_entropy_loss = nn.CrossEntropyLoss(reduction="none")
          self.soft = nn.ModuleDict(
              {
                  field.name: nn.Linear(self.lstm_size, field.total, bias=False)
                  for field in fields
              }
          )
          self.embedding = nn.ModuleDict(
              {field.name: nn.Embedding(field.total, self.lstm_size) for field in fields}
          )

      def resample(self):
          """Sample one choice per field, in field order.

          Returns
          -------
          dict
              Maps each field name to the sampled index (``int``). For
              multi-choice fields the value is a list of indices, collapsed to an
              ``int`` when exactly one index was picked.
          """
          self._initialize()
          result = dict()
          for field in self.fields:
              result[field.name] = self._sample_single(field)
          return result

      def _initialize(self):
          """Reset the LSTM state and the per-sample accumulators."""
          # ``.data`` detaches the first input from ``g_emb``.
          self._inputs = self.g_emb.data
          self._c = [
              torch.zeros(
                  (1, self.lstm_size),
                  dtype=self._inputs.dtype,
                  device=self._inputs.device,
              )
              for _ in range(self.lstm_num_layers)
          ]
          self._h = [
              torch.zeros(
                  (1, self.lstm_size),
                  dtype=self._inputs.dtype,
                  device=self._inputs.device,
              )
              for _ in range(self.lstm_num_layers)
          ]
          self.sample_log_prob = 0
          self.sample_entropy = 0
          self.sample_skip_penalty = 0

      def _lstm_next_step(self):
          """Advance the stacked LSTM by one step on ``self._inputs``."""
          self._h, self._c = self.lstm(self._inputs, (self._h, self._c))

      @staticmethod
      def _pool_selected(embedding, selected):
          """Pool the embeddings of the selected option indices into a ``(1, dim)`` input.

          The summed embeddings are divided by ``1 + k``, where ``k`` is the
          number of selected options. If nothing is selected, a zero tensor is
          returned. Emptiness is decided by the count, not by the sum of the
          indices, so that selecting only option 0 is not mistaken for an empty
          selection.
          """
          count = selected.numel()
          if count == 0:
              return embedding.weight.new_zeros(1, embedding.embedding_dim)
          return (embedding(selected).sum(0) / (1.0 + count)).unsqueeze(0)

      def _sample_single(self, field):
          """Sample a choice for ``field`` and update the LSTM input and accumulators.

          Returns an ``int`` index, or a list of ``int`` indices for a
          multi-choice field (collapsed to an ``int`` when it has length 1).
          """
          self._lstm_next_step()
          logit = self.soft[field.name](self._h[-1])
          if self.temperature is not None:
              logit /= self.temperature
          if self.tanh_constant is not None:
              logit = self.tanh_constant * torch.tanh(logit)
          if field.choose_one:
              sampled = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
              log_prob = self.cross_entropy_loss(logit, sampled)
              self._inputs = self.embedding[field.name](sampled)
          else:
              # One independent binary decision per option: column 0 = skip, column 1 = pick.
              logit = logit.view(-1, 1)
              logit = torch.cat([-logit, logit], 1)  # pylint: disable=invalid-unary-operand-type
              sampled = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
              skip_prob = torch.sigmoid(logit)
              kl = torch.sum(skip_prob * torch.log(skip_prob / self.skip_targets))
              self.sample_skip_penalty += kl
              log_prob = self.cross_entropy_loss(logit, sampled)
              # Indices of the options that were picked.
              sampled = sampled.nonzero().view(-1)
              self._inputs = self._pool_selected(self.embedding[field.name], sampled)
          # ``.cpu().tolist()`` works on any device; ``.numpy()`` raises for CUDA tensors.
          sampled = sampled.detach().cpu().tolist()
          self.sample_log_prob += self.entropy_reduction(log_prob)
          # log_prob is -log(p), so this is -p*log(p), i.e. the entropy contribution.
          entropy = (log_prob * torch.exp(-log_prob)).detach()  # pylint: disable=invalid-unary-operand-type
          self.sample_entropy += self.entropy_reduction(entropy)
          if len(sampled) == 1:
              sampled = sampled[0]
          return sampled
  @register_nas_algo("rl")
  class RL(BaseNAS):
      """
      RL in GraphNas.

      Samples architectures from a :class:`ReinforceController` and trains the
      controller with REINFORCE. The reward is the first value of the
      estimator's validation metrics, and an exponential-moving-average
      baseline is subtracted from it.

      Parameters
      ----------
      num_epochs : int
          Number of epochs planned for training.
      device : torch.device
          ``torch.device("cpu")`` or ``torch.device("cuda")``.
      log_frequency : int
          Step count per logging. ``None`` disables the periodic debug log.
      grad_clip : float
          Gradient clipping. Set to 0 to disable. Default: 5.
      entropy_weight : float
          Weight of the sample-entropy bonus added to the reward.
      skip_weight : float
          Weight of skip penalty loss.
      baseline_decay : float
          Decay factor of baseline. New baseline will be equal to ``baseline_decay * baseline_old + reward * (1 - baseline_decay)``.
      ctrl_lr : float
          Learning rate for RL controller.
      ctrl_steps_aggregate : int
          Number of architectures sampled per epoch. Their gradients are
          accumulated into a single controller update at the end of the epoch.
      ctrl_kwargs : dict
          Optional kwargs that will be passed to :class:`ReinforceController`.
      n_warmup : int
          Number of epochs for training super network (stored; not used by
          this class's search loop).
      model_lr : float
          Learning rate for super network (stored; not used by this class's
          search loop).
      model_wd : float
          Weight decay for super network (stored; not used by this class's
          search loop).
      disable_progress : bool
          If true, hide the progress bars.
      """

      def __init__(
          self,
          num_epochs=5,
          device="cuda",
          log_frequency=None,
          grad_clip=5.0,
          entropy_weight=0.0001,
          skip_weight=0.8,
          baseline_decay=0.999,
          ctrl_lr=0.00035,
          ctrl_steps_aggregate=20,
          ctrl_kwargs=None,
          n_warmup=100,
          model_lr=5e-3,
          model_wd=5e-4,
          disable_progress=False,
      ):
          super().__init__(device)
          self.device = device
          self.num_epochs = num_epochs
          self.log_frequency = log_frequency
          self.entropy_weight = entropy_weight
          self.skip_weight = skip_weight
          self.baseline_decay = baseline_decay
          # Persists across epochs (and across calls to ``search``).
          self.baseline = 0.0
          self.ctrl_steps_aggregate = ctrl_steps_aggregate
          self.grad_clip = grad_clip
          self.ctrl_kwargs = ctrl_kwargs
          self.ctrl_lr = ctrl_lr
          self.n_warmup = n_warmup
          self.model_lr = model_lr
          self.model_wd = model_wd
          self.disable_progress = disable_progress

      def search(self, space: BaseSpace, dset, estimator):
          """Train the controller on ``space`` and return the exported architecture.

          The returned architecture is parsed from a final controller sample
          drawn without gradients.
          """
          self.model = space
          self.dataset = dset  # .to(self.device)
          self.estimator = estimator
          # Replace layer/input choices with path-sampling modules, kept in model order.
          self.nas_modules = []
          k2o = get_module_order(self.model)
          replace_layer_choice(self.model, PathSamplingLayerChoice, self.nas_modules)
          replace_input_choice(self.model, PathSamplingInputChoice, self.nas_modules)
          self.nas_modules = sort_replaced_module(k2o, self.nas_modules)
          self.model = self.model.to(self.device)
          # One controller field per choice; layer choices always pick exactly one option.
          self.nas_fields = [
              ReinforceField(
                  name,
                  len(module),
                  isinstance(module, PathSamplingLayerChoice) or module.n_chosen == 1,
              )
              for name, module in self.nas_modules
          ]
          self.controller = ReinforceController(
              self.nas_fields, **(self.ctrl_kwargs or {})
          )
          self.ctrl_optim = torch.optim.Adam(
              self.controller.parameters(), lr=self.ctrl_lr
          )
          with tqdm(range(self.num_epochs), disable=self.disable_progress) as bar:
              for i in bar:
                  l2 = self._train_controller(i)
                  bar.set_postfix(reward_controller=l2)
          selection = self.export()
          arch = space.parse_model(selection, self.device)
          return arch

      def _train_controller(self, epoch):
          """Run one controller epoch and return the mean raw reward (before the entropy bonus)."""
          self.model.eval()
          self.controller.train()
          self.ctrl_optim.zero_grad()
          rewards = []
          with tqdm(
              range(self.ctrl_steps_aggregate), disable=self.disable_progress
          ) as bar:
              for ctrl_step in bar:
                  self._resample()
                  metric, loss = self._infer(mask="val")
                  reward = metric
                  bar.set_postfix(acc=metric, loss=loss.item())
                  LOGGER.debug(f"{self.arch}\n{self.selection}\n{metric},{loss}")
                  rewards.append(reward)
                  if self.entropy_weight:
                      reward += (
                          self.entropy_weight * self.controller.sample_entropy.item()
                      )
                  self.baseline = self.baseline * self.baseline_decay + reward * (
                      1 - self.baseline_decay
                  )
                  # sample_log_prob holds -log p, so minimising this raises the
                  # probability of above-baseline samples (REINFORCE).
                  loss = self.controller.sample_log_prob * (reward - self.baseline)
                  if self.skip_weight:
                      loss += self.skip_weight * self.controller.sample_skip_penalty
                  # Scale so that the accumulated gradient is an average over the steps.
                  loss /= self.ctrl_steps_aggregate
                  loss.backward()
                  if (ctrl_step + 1) % self.ctrl_steps_aggregate == 0:
                      if self.grad_clip > 0:
                          nn.utils.clip_grad_norm_(
                              self.controller.parameters(), self.grad_clip
                          )
                      self.ctrl_optim.step()
                      self.ctrl_optim.zero_grad()
                  if (
                      self.log_frequency is not None
                      and ctrl_step % self.log_frequency == 0
                  ):
                      # The format string must have exactly one placeholder per argument.
                      LOGGER.debug(
                          "RL Epoch [%d/%d] Step [%d/%d]",
                          epoch + 1,
                          self.num_epochs,
                          ctrl_step + 1,
                          self.ctrl_steps_aggregate,
                      )
          return sum(rewards) / len(rewards)

      def _resample(self):
          """Draw a new selection from the controller and parse it into ``self.arch``."""
          result = self.controller.resample()
          self.arch = self.model.parse_model(result, device=self.device)
          self.selection = result

      def export(self):
          """Return a controller sample drawn in eval mode without gradients."""
          self.controller.eval()
          with torch.no_grad():
              return self.controller.resample()

      def _infer(self, mask="train"):
          """Evaluate ``self.arch``; return the first metric and the loss."""
          metric, loss = self.estimator.infer(self.arch._model, self.dataset, mask=mask)
          return metric[0], loss
  @register_nas_algo("graphnas")
  class GraphNasRL(BaseNAS):
      """
      RL in GraphNas.

      Trains a :class:`ReinforceController` with REINFORCE. The reward is the
      first value of the estimator's validation metrics. The ``topk`` best
      sampled architectures are kept as candidates, and at the end of the
      search they are re-evaluated to pick the final one.

      Parameters
      ----------
      device : torch.device
          ``torch.device("cpu")`` or ``torch.device("cuda")``.
      num_epochs : int
          Number of epochs planned for training.
      log_frequency : int
          Step count per logging (stored; not used by this class's training loop).
      grad_clip : float
          Gradient clipping threshold. NOTE: accepted for interface
          compatibility but not applied by this class's training loop.
      entropy_weight : float
          Weight of the sample-entropy bonus added to the reward.
      skip_weight : float
          Weight of skip penalty loss (0 disables it).
      baseline_decay : float
          Decay factor of baseline. New baseline will be equal to ``baseline_decay * baseline_old + reward * (1 - baseline_decay)``.
          The baseline is reset each epoch and seeded with the first reward.
      ctrl_lr : float
          Learning rate for RL controller.
      ctrl_steps_aggregate : int
          Number of architectures sampled per epoch. The controller is updated
          after every sample.
      ctrl_kwargs : dict
          Optional kwargs passed to :class:`ReinforceController`. They override
          this class's defaults ``lstm_size=100``, ``temperature=5.0`` and
          ``tanh_constant=2.5``.
      n_warmup : int
          Number of epochs for training super network (stored; not used here).
      model_lr : float
          Learning rate for super network (stored; not used here).
      model_wd : float
          Weight decay for super network (stored; not used here).
      topk : int
          Number of architectures kept in training process.
      disable_progress : bool
          If true, hide the progress bars.
      hardware_metric_limit : float or None
          If set, a sampled architecture is kept as a candidate only when its
          first hardware metric (``metric[1]`` returned by the estimator) is
          below this limit.
      """

      def __init__(
          self,
          device="cuda",
          num_epochs=10,
          log_frequency=None,
          grad_clip=5.0,
          entropy_weight=0.0001,
          skip_weight=0,
          baseline_decay=0.95,
          ctrl_lr=0.00035,
          ctrl_steps_aggregate=100,
          ctrl_kwargs=None,
          n_warmup=100,
          model_lr=5e-3,
          model_wd=5e-4,
          topk=5,
          disable_progress=False,
          hardware_metric_limit=None,
      ):
          super().__init__(device)
          self.device = device
          self.num_epochs = num_epochs
          self.log_frequency = log_frequency
          self.entropy_weight = entropy_weight
          self.skip_weight = skip_weight
          self.baseline_decay = baseline_decay
          self.ctrl_steps_aggregate = ctrl_steps_aggregate
          self.grad_clip = grad_clip
          self.ctrl_kwargs = ctrl_kwargs
          self.ctrl_lr = ctrl_lr
          self.n_warmup = n_warmup
          self.model_lr = model_lr
          self.model_wd = model_wd
          # Candidate list of [-metric, selection]; holds at most ``topk`` entries.
          self.hist = []
          self.topk = topk
          self.disable_progress = disable_progress
          self.hardware_metric_limit = hardware_metric_limit

      def search(self, space: BaseSpace, dset, estimator):
          """Train the controller on ``space`` and return the best candidate architecture.

          Raises
          ------
          ValueError
              If no candidate architecture was recorded, e.g. because every
              sample exceeded ``hardware_metric_limit``.
          """
          self.model = space
          self.dataset = dset  # .to(self.device)
          self.estimator = estimator
          # Candidates are per search; don't mix in results from earlier calls.
          self.hist = []
          # Replace layer/input choices with path-sampling modules, kept in model order.
          self.nas_modules = []
          k2o = get_module_order(self.model)
          replace_layer_choice(self.model, PathSamplingLayerChoice, self.nas_modules)
          replace_input_choice(self.model, PathSamplingInputChoice, self.nas_modules)
          self.nas_modules = sort_replaced_module(k2o, self.nas_modules)
          self.model = self.model.to(self.device)
          # One controller field per choice; layer choices always pick exactly one option.
          self.nas_fields = [
              ReinforceField(
                  name,
                  len(module),
                  isinstance(module, PathSamplingLayerChoice) or module.n_chosen == 1,
              )
              for name, module in self.nas_modules
          ]
          # Merge user kwargs over the defaults. Passing the defaults as explicit
          # keywords next to ``**ctrl_kwargs`` raised a TypeError on any overlap.
          controller_kwargs = {"lstm_size": 100, "temperature": 5.0, "tanh_constant": 2.5}
          controller_kwargs.update(self.ctrl_kwargs or {})
          self.controller = ReinforceController(self.nas_fields, **controller_kwargs)
          self.ctrl_optim = torch.optim.Adam(
              self.controller.parameters(), lr=self.ctrl_lr
          )
          with tqdm(range(self.num_epochs), disable=self.disable_progress) as bar:
              for i in bar:
                  l2 = self._train_controller(i)
                  bar.set_postfix(reward_controller=l2)
          if not self.hist:
              raise ValueError(
                  "GraphNasRL recorded no candidate architecture "
                  f"(num_epochs={self.num_epochs}, "
                  f"hardware_metric_limit={self.hardware_metric_limit})"
              )
          selections = [x[1] for x in self.hist]
          selection = self._choose_best(selections)
          arch = space.parse_model(selection, self.device)
          return arch

      def _choose_best(self, selections, n_repeats=20):
          """Re-evaluate each candidate ``n_repeats`` times and return the best one.

          Each repetition re-parses the selection into ``self.arch`` and
          evaluates it on the validation mask. Candidates are ranked by their
          mean metric.
          """
          # GraphNAS keeps the top models; each is evaluated repeatedly and the best mean wins.
          results = []
          for selection in selections:
              accs = []
              for _ in tqdm(range(n_repeats), disable=self.disable_progress):
                  self.arch = self.model.parse_model(selection, device=self.device)
                  metric, _, _ = self._infer(mask="val")
                  accs.append(metric)
              mean_acc = np.mean(accs)
              LOGGER.info(
                  "selection {} \n acc {:.4f} +- {:.4f}".format(
                      selection, mean_acc, np.std(accs) / np.sqrt(len(accs))
                  )
              )
              results.append(mean_acc)
          return selections[int(np.argmax(results))]

      def _train_controller(self, epoch):
          """Run one controller epoch and return the mean raw reward (before the entropy bonus).

          Every step samples one architecture and evaluates it on the validation
          mask. If it satisfies ``hardware_metric_limit`` it is recorded as a
          candidate. One REINFORCE update is then applied immediately.
          """
          self.model.eval()
          self.controller.train()
          self.ctrl_optim.zero_grad()
          rewards = []
          baseline = None
          # Difference from GraphNAS: GraphNAS trains on 100 samples and derives 100
          # more per epoch (10 epochs). Here each epoch only trains; per the original
          # author's note (100 samples x 20 epochs) the total sample count is the
          # same (2000).
          with tqdm(
              range(self.ctrl_steps_aggregate), disable=self.disable_progress
          ) as bar:
              for ctrl_step in bar:
                  self._resample()
                  metric, loss, hardware_metric = self._infer(mask="val")
                  reward = metric
                  LOGGER.debug(f"{self.arch}\n{self.selection}\n{metric},{loss}")
                  # Difference from GraphNAS: no reward shaping is applied.
                  if (
                      self.hardware_metric_limit is None
                      or hardware_metric[0] < self.hardware_metric_limit
                  ):
                      self.hist.append([-metric, self.selection])
                      if len(self.hist) > self.topk:
                          # Sorted by negated metric, so the last entry is the worst.
                          self.hist.sort(key=lambda x: x[0])
                          self.hist.pop()
                  rewards.append(reward)
                  if self.entropy_weight:
                      reward += (
                          self.entropy_weight * self.controller.sample_entropy.item()
                      )
                  # ``is None`` rather than truthiness: a baseline of 0.0 must not reset it.
                  if baseline is None:
                      baseline = reward
                  else:
                      baseline = baseline * self.baseline_decay + reward * (
                          1 - self.baseline_decay
                      )
                  # sample_log_prob holds -log p (REINFORCE with a baseline).
                  loss = self.controller.sample_log_prob * (reward - baseline)
                  if self.skip_weight:
                      loss += self.skip_weight * self.controller.sample_skip_penalty
                  self.ctrl_optim.zero_grad()
                  loss.backward()
                  self.ctrl_optim.step()
                  bar.set_postfix(acc=metric, max_acc=max(rewards))
          return sum(rewards) / len(rewards)

      def _resample(self):
          """Draw a new selection from the controller and parse it into ``self.arch``."""
          result = self.controller.resample()
          self.arch = self.model.parse_model(result, device=self.device)
          self.selection = result

      def export(self):
          """Return a controller sample drawn in eval mode without gradients."""
          self.controller.eval()
          with torch.no_grad():
              return self.controller.resample()

      def _infer(self, mask="train"):
          """Evaluate ``self.arch``; return (first metric, loss, remaining metrics)."""
          metric, loss = self.estimator.infer(self.arch._model, self.dataset, mask=mask)
          return metric[0], loss, metric[1:]