|
# Code in this file is reproduced from https://github.com/microsoft/nni with some changes.
- import copy
- from logging import Logger
- from numpy.core.fromnumeric import sort
-
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
-
- from . import register_nas_algo
- from .base import BaseNAS
- from ..space import BaseSpace
- from ..utils import (
- AverageMeterGroup,
- replace_layer_choice,
- replace_input_choice,
- get_module_order,
- sort_replaced_module,
- PathSamplingLayerChoice,
- PathSamplingInputChoice,
- )
- from tqdm import tqdm, trange
- from ....utils import get_logger
-
- import numpy as np
- LOGGER = get_logger("SPOS")
-
- import collections
- import dataclasses
- import random
-
-
@dataclasses.dataclass
class Individual:
    """A single evaluated architecture in the evolutionary population.

    Attributes
    ----------
    x : dict
        The sampled model configuration (choice name -> operator index).
    y : float
        The metric measured for ``x`` (e.g., validation accuracy).
    """

    x: dict
    y: float
-
-
class Evolution:
    """
    Algorithm for regularized evolution (i.e. aging evolution).
    Follows "Algorithm 1" in Real et al. "Regularized Evolution for Image Classifier Architecture Search".

    Parameters
    ----------
    optimize_mode : str
        Can be one of "maximize" and "minimize". Default: maximize.
    population_size : int
        The number of individuals to keep in the population. Default: 100.
    sample_size : int
        The number of individuals that should participate in each tournament. Default: 25.
    cycles : int
        The number of cycles (trials) the algorithm should run for. Default: 20000.
    mutation_prob : float
        Probability that mutation happens in each dim. Default: 0.05
    disable_progress : bool
        Whether to hide the tqdm progress bars. Default: False
    """

    def __init__(self, optimize_mode='maximize', population_size=100, sample_size=25, cycles=20000,
                 mutation_prob=0.05, disable_progress=False):
        assert optimize_mode in ['maximize', 'minimize']
        assert sample_size < population_size
        self.optimize_mode = optimize_mode
        self.population_size = population_size
        self.sample_size = sample_size
        self.cycles = cycles
        self.mutation_prob = mutation_prob
        self.disable_progress = disable_progress

        # Aging population: the oldest individual is evicted from the left
        # once the deque exceeds `population_size`.
        self._population = collections.deque()
        # Every individual ever evaluated; the final answer is picked from here.
        self._history = []

    def best_parent(self, sample_size=None):
        """Run one tournament and return the winner's configuration.

        Parameters
        ----------
        sample_size : int or None
            Number of tournament participants, drawn uniformly without
            replacement. ``None`` (or a value larger than the population)
            means the whole population competes.

        Returns
        -------
        dict
            The architecture choices (``Individual.x``) of the tournament winner.
        """
        pool = list(self._population)
        if sample_size is None:
            samples = pool
        else:
            # Equivalent to shuffling a copy and slicing the first
            # `sample_size` items, but expressed with the stdlib idiom.
            samples = random.sample(pool, min(sample_size, len(pool)))
        fitness = lambda individual: individual.y
        if self.optimize_mode == 'maximize':
            return max(samples, key=fitness).x
        return min(samples, key=fitness).x

    def _prepare(self):
        """Build the samplers used to create and mutate individuals."""
        self.uniform = UniformSampler(self.nas_modules)
        self.mutation = MutationSampler(self.nas_modules, self.mutation_prob)

    def _get_metric(self, config):
        """Evaluate ``config`` on the validation split and return its metric."""
        # Activate the sampled operator in every choice module of the supernet.
        for name, module in self.nas_modules:
            module.sampled = config[name]
        # TODO: this may be computationally expensive
        # model = self.model.parse_model(config, self.device)
        with torch.no_grad():
            metric, loss = self.estimator.infer(self.model, self.dataset, mask='val')
        return metric[0]

    def search(self, space: "BaseSpace", nas_modules, dset, estimator, device):
        """Run aging evolution on a trained supernet and return the best config.

        Parameters
        ----------
        space : BaseSpace
            The supernet whose choice modules are listed in ``nas_modules``.
        nas_modules : list
            Pairs ``(name, module)`` of replaced choice modules.
        dset :
            Dataset forwarded to the estimator.
        estimator :
            Performance estimator exposing ``infer(model, dataset, mask)``.
        device :
            Device identifier (stored for parity with other algorithms).

        Returns
        -------
        dict
            The best configuration found over the whole search history.
        """
        self.model = space
        self.dataset = dset
        self.estimator = estimator
        self.nas_modules = nas_modules
        self.device = device

        self._prepare()
        LOGGER.info('Initializing the first population.')
        with tqdm(range(self.population_size), disable=self.disable_progress) as bar:
            for i in bar:
                config = self.uniform.resample()
                metric = self._get_metric(config)
                individual = Individual(config, metric)
                self._population.append(individual)
                self._history.append(individual)
                bar.set_postfix(metric=metric, max=max(x.y for x in self._population), min=min(x.y for x in self._population))

        LOGGER.info('Running mutations.')
        with tqdm(range(self.cycles), disable=self.disable_progress) as bar:
            for i in bar:
                parent = self.best_parent(self.sample_size)
                config = self.mutation.resample(parent)
                metric = self._get_metric(config)  # TODO: add aging factor
                individual = Individual(config, metric)
                LOGGER.debug('Individual created: %s', str(individual))
                self._population.append(individual)
                self._history.append(individual)
                # Aging: evict the oldest individual once the population is full.
                if len(self._population) > self.population_size:
                    self._population.popleft()
                bar.set_postfix(metric=metric, max_h=max(x.y for x in self._history), max=max(x.y for x in self._population), min=min(x.y for x in self._population))

        # TODO: the best individuals may need to be retrained before the final pick.
        # The in-place sort is kept so tie-breaking matches the original behavior.
        self._history.sort(key=lambda x: x.y)
        if self.optimize_mode == 'maximize':
            best = self._history[-1].x
        else:
            best = self._history[0].x
        return best
-
-
-
class MutationSampler:
    """Uniformly mutating sampler.

    Each choice of a parent configuration is independently re-drawn
    uniformly at random with probability ``mutation_prob``; otherwise the
    parent's choice is kept unchanged.

    Parameters
    ----------
    nas_modules:
        nas_modules in NAS algorithms, including choices of modules
    mutation_prob: float
        probability of doing mutation in each choice.
    """

    def __init__(self, nas_modules, mutation_prob):
        # Map each choice name to the number of candidate operators.
        self.selection_dict = {name: len(module) for name, module in nas_modules}
        self.mutation_prob = mutation_prob

    def resample(self, parent):
        """Return a child configuration derived from ``parent`` by mutation."""
        child = {}
        for name, choice in parent.items():
            mutate = random.uniform(0, 1) < self.mutation_prob
            # A mutated choice does not exclude the original operator.
            child[name] = np.random.choice(range(self.selection_dict[name])) if mutate else choice
        return child
-
class UniformSampler:
    """Sampler that draws every architecture choice uniformly at random.

    Parameters
    ----------
    nas_modules:
        nas_modules in NAS algorithms, including choices of modules
    """

    def __init__(self, nas_modules):
        # Map each choice name to the number of candidate operators.
        self.selection_dict = {name: len(module) for name, module in nas_modules}

    def resample(self):
        """Return a dict mapping each choice name to a random operator index."""
        return {name: np.random.choice(range(size))
                for name, size in self.selection_dict.items()}
-
- @register_nas_algo("spos")
- class Spos(BaseNAS):
- """
- SPOS trainer.
-
- Parameters
- ----------
- n_warmup : int
- Number of epochs for training super network.
- model_lr : float
- Learning rate for super network.
- model_wd : float
- Weight decay for super network.
-
- Other parameters see Evolution
- """
-
- def __init__(
- self,
- n_warmup=1000,
- grad_clip=5.0,
- disable_progress=False,
- optimize_mode='maximize',
- population_size=100,
- sample_size=25,
- cycles=20000,
- mutation_prob=0.05,
- device="cuda",
- ):
- super().__init__(device)
- self.model_lr=5e-3
- self.model_wd=5e-4
- self.n_warmup = n_warmup
- self.disable_progress= disable_progress
- self.grad_clip = grad_clip
- self.optimize_mode = optimize_mode
- self.population_size = population_size
- self.sample_size = sample_size
- self.cycles = cycles
- self.mutation_prob = mutation_prob
-
- def _prepare(self):
- # replace choice
- self.nas_modules = []
- k2o = get_module_order(self.model)
- replace_layer_choice(self.model, PathSamplingLayerChoice, self.nas_modules)
- replace_input_choice(self.model, PathSamplingInputChoice, self.nas_modules)
- self.nas_modules = sort_replaced_module(k2o, self.nas_modules)
- # to device
- self.model = self.model.to(self.device)
- self.model_optim = torch.optim.Adam(
- self.model.parameters(), lr=self.model_lr, weight_decay=self.model_wd
- )
- # controller
- self.controller=UniformSampler(self.nas_modules)
-
- # Evolution
- self.evolve = Evolution(
- optimize_mode='maximize',
- population_size=self.population_size,
- sample_size=self.sample_size,
- cycles=self.cycles,
- mutation_prob=self.mutation_prob,
- disable_progress=self.disable_progress
- )
-
- def search(self, space: BaseSpace, dset, estimator):
- self.model = space
- self.dataset = dset
- self.estimator = estimator
-
- self._prepare()
- self._train() # train using uniform sampling
- self._search() # search using evolutionary algorithm
-
- selection = self.export()
-
- # here may sample N , retrain N ,and get best
- print(selection)
- return space.parse_model(selection, self.device)
-
- def _search(self):
- self.best_config=self.evolve.search(
- self.model,
- self.nas_modules,
- self.dataset,
- self.estimator,
- self.device,
- )
-
-
- def _train(self):
- with tqdm(range(self.n_warmup), disable=self.disable_progress) as bar:
- for i in bar:
- acc, l1 = self._train_one_epoch(i)
- with torch.no_grad():
- val_acc, val_loss = self._infer("val")
- bar.set_postfix(loss=l1, acc=acc, val_acc=val_acc, val_loss=val_loss.item())
-
- def _train_one_epoch(self, epoch):
- self.model.train()
- self.model_optim.zero_grad()
- self._resample() # uniform sampling
- metric, loss = self._infer(mask="train")
- loss.backward()
- if self.grad_clip > 0:
- nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)
- self.model_optim.step()
- return metric, loss.item()
-
- def _resample(self):
- result=self.controller.resample()
- for name, module in self.nas_modules:
- module.sampled = result[name]
-
- def export(self):
- return self.best_config
-
- def _infer(self, mask="train"):
- metric, loss = self.estimator.infer(self.model, self.dataset, mask=mask)
- return metric[0], loss
|