|
- """
- NAS algorithms
- """
-
- import importlib
- import os
- from .base import BaseNAS
-
# Registry mapping a NAS algorithm name to its class; populated by the
# ``register_nas_algo`` decorator and read by ``build_nas_algo_from_name``.
NAS_ALGO_DICT = dict()
-
-
def register_nas_algo(name):
    """
    Create a class decorator that records a NAS algorithm under ``name``.

    Parameters
    ----------
    name: ``str``
        the key under which the decorated class is stored in ``NAS_ALGO_DICT``.

    Returns
    -------
    callable
        a decorator that registers the class it receives and returns that
        class unchanged.

    Raises
    ------
    ValueError
        raised by the returned decorator when ``name`` is already present in
        ``NAS_ALGO_DICT`` or when the decorated class does not subclass
        ``BaseNAS``.
    """

    def register_nas_algo_cls(cls):
        # The duplicate-name check runs first, before the class is inspected.
        already_registered = name in NAS_ALGO_DICT
        if already_registered:
            duplicate_msg = "Cannot register duplicate NAS algorithm ({})".format(
                name
            )
            raise ValueError(duplicate_msg)

        if issubclass(cls, BaseNAS):
            NAS_ALGO_DICT[name] = cls
            return cls

        raise ValueError(
            "Model ({}: {}) must extend NAS algorithm".format(name, cls.__name__)
        )

    return register_nas_algo_cls
-
-
- from .darts import Darts
- from .enas import Enas
- from .random_search import RandomSearch
- from .rl import RL, GraphNasRL
- from .spos import Spos
-
def build_nas_algo_from_name(name: str) -> BaseNAS:
    """
    Instantiate a registered NAS algorithm by name.

    Parameters
    ----------
    name: ``str``
        the name of nas algorithm, as registered in ``NAS_ALGO_DICT``.

    Returns
    -------
    BaseNAS:
        the NAS algorithm built using default parameters (the registered
        class is called with no arguments)

    Raises
    ------
    AssertionError
        If an invalid name is passed in
    """
    # Raise explicitly rather than using an ``assert`` statement: asserts are
    # removed under ``python -O``, which would let an unknown name fall
    # through to a bare ``KeyError``. ``AssertionError`` is kept as the
    # exception type because it is the documented contract of this function.
    if name not in NAS_ALGO_DICT:
        raise AssertionError("NAS module does not have name {}".format(name))
    return NAS_ALGO_DICT[name]()
-
-
# Names exported by ``from <package> import *``.
__all__ = [
    "BaseNAS",
    "Darts",
    "Enas",
    "RandomSearch",
    "RL",
    "GraphNasRL",
    "Spos",
]
|