You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

base.py 1.3 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. """
  2. Base class for algorithm
  3. """
  4. from ...model import BaseAutoModel
  5. import torch
  6. from abc import abstractmethod
  7. from ....utils import get_device
  8. class BaseNAS:
  9. """
  10. Base NAS algorithm class
  11. Parameters
  12. ----------
  13. device: str or torch.device
  14. The device of the whole process
  15. """
  16. def __init__(self, device="auto") -> None:
  17. self.device = get_device(device)
  18. def to(self, device):
  19. """
  20. Change the device of the whole NAS search process
  21. Parameters
  22. ----------
  23. device: str or torch.device
  24. """
  25. self.device = get_device(device)
  26. @abstractmethod
  27. def search(self, space, dataset, estimator) -> BaseAutoModel:
  28. """
  29. The search process of NAS.
  30. Parameters
  31. ----------
  32. space : autogl.module.nas.space.BaseSpace
  33. The search space. Constructed following nni.
  34. dataset : autogl.datasets
  35. Dataset to perform search on.
  36. estimator : autogl.module.nas.estimator.BaseEstimator
  37. The estimator to compute loss & metrics.
  38. Returns
  39. -------
  40. model: autogl.module.model.BaseModel
  41. The searched model.
  42. """
  43. raise NotImplementedError()