You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

random_search.py 3.1 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from . import register_nas_algo
  5. from .base import BaseNAS
  6. from ..space import BaseSpace
  7. from ..utils import (
  8. AverageMeterGroup,
  9. replace_layer_choice,
  10. replace_input_choice,
  11. get_module_order,
  12. sort_replaced_module,
  13. )
  14. from tqdm import tqdm
  15. from .rl import PathSamplingLayerChoice, PathSamplingInputChoice
  16. import numpy as np
  17. from ....utils import get_logger
# Module-level logger for the random-search NAS algorithm.
# NOTE(review): not referenced by RandomSearch in the visible code — confirm whether it is still needed.
LOGGER = get_logger("random_search_NAS")
  19. @register_nas_algo("random")
  20. class RandomSearch(BaseNAS):
  21. """
  22. Uniformly random architecture search
  23. Parameters
  24. ----------
  25. device : str or torch.device
  26. The device of the whole process, e.g. "cuda", torch.device("cpu")
  27. num_epochs : int
  28. Number of epochs planned for training.
  29. disable_progeress: boolean
  30. Control whether show the progress bar.
  31. """
  32. def __init__(self, device="auto", num_epochs=400, disable_progress=False, hardware_metric_limit=None):
  33. super().__init__(device)
  34. self.num_epochs = num_epochs
  35. self.disable_progress = disable_progress
  36. self.hardware_metric_limit = hardware_metric_limit
  37. def search(self, space: BaseSpace, dset, estimator):
  38. self.estimator = estimator
  39. self.dataset = dset
  40. self.space = space
  41. self.nas_modules = []
  42. k2o = get_module_order(self.space)
  43. replace_layer_choice(self.space, PathSamplingLayerChoice, self.nas_modules)
  44. replace_input_choice(self.space, PathSamplingInputChoice, self.nas_modules)
  45. self.nas_modules = sort_replaced_module(k2o, self.nas_modules)
  46. selection_range = {}
  47. for k, v in self.nas_modules:
  48. selection_range[k] = len(v)
  49. self.selection_dict = selection_range
  50. # space_size=np.prod(list(selection_range.values()))
  51. arch_perfs = []
  52. cache = {}
  53. with tqdm(range(self.num_epochs), disable=self.disable_progress) as bar:
  54. for i in bar:
  55. selection = self.sample()
  56. vec = tuple(list(selection.values()))
  57. if vec not in cache:
  58. self.arch = space.parse_model(selection, self.device)
  59. metric, loss, hardware_metric = self._infer(mask="val")
  60. if self.hardware_metric_limit is None or hardware_metric[0] < self.hardware_metric_limit:
  61. arch_perfs.append([metric, selection])
  62. cache[vec] = metric
  63. bar.set_postfix(acc=metric, max_acc=max(cache.values()))
  64. selection = arch_perfs[np.argmax([x[0] for x in arch_perfs])][1]
  65. arch = space.parse_model(selection, self.device)
  66. return arch
  67. def sample(self):
  68. # uniformly sample
  69. selection = {}
  70. for k, v in self.selection_dict.items():
  71. selection[k] = np.random.choice(range(v))
  72. return selection
  73. def _infer(self, mask="train"):
  74. metric, loss = self.estimator.infer(self.arch._model, self.dataset, mask=mask)
  75. return metric[0], loss, metric[1:]