# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
from collections import OrderedDict
import time

import numpy as np
import torch

import nni.retiarii.nn.pytorch as nn
from nni.nas.pytorch.mutables import Mutable, InputChoice, LayerChoice

_logger = logging.getLogger(__name__)


def _get_mask(sampled, total):
    """
    Build a multi-hot bool 1D-tensor of length ``total`` marking sampled indices.

    Parameters
    ----------
    sampled : int or list of int or None
        A single sampled index, a list of sampled indices, or None (nothing sampled).
    total : int
        Number of candidates, i.e. the length of the returned mask.

    Returns
    -------
    torch.Tensor
        Bool tensor of shape ``(total,)``; entry ``i`` is True iff ``i`` was sampled.
    """
    if isinstance(sampled, list):
        multihot = [i in sampled for i in range(total)]
    else:
        # `i == None` is False, so an unsampled choice yields an all-False mask.
        multihot = [bool(i == sampled) for i in range(total)]
    return torch.tensor(multihot, dtype=torch.bool)  # pylint: disable=not-callable


class PathSamplingLayerChoice(nn.Module):
    """
    Mixed module whose forward pass runs exactly one or several (sampled) candidate ops.

    If multiple ops are selected, their outputs are summed and returned.

    Attributes
    ----------
    sampled : int or list of int
        Sampled module indices.
    mask : tensor
        A multi-hot bool 1D-tensor representing the sampled mask.
    """

    def __init__(self, layer_choice):
        super(PathSamplingLayerChoice, self).__init__()
        self.op_names = []
        # Re-register every candidate op of the original choice on this module.
        for name, module in layer_choice.named_children():
            self.add_module(name, module)
            self.op_names.append(name)
        assert self.op_names, "There has to be at least one op to choose from."
        self.sampled = None  # sampled can be either a list of indices or an index

    def forward(self, *args, **kwargs):
        assert (
            self.sampled is not None
        ), "At least one path needs to be sampled before fprop."
        if isinstance(self.sampled, list):
            return sum(
                [getattr(self, self.op_names[i])(*args, **kwargs) for i in self.sampled]
            )  # pylint: disable=not-an-iterable
        return getattr(self, self.op_names[self.sampled])(
            *args, **kwargs
        )  # pylint: disable=invalid-sequence-index

    def sampled_choices(self):
        """Return the list of currently sampled op modules (empty if nothing sampled)."""
        if self.sampled is None:
            return []
        if isinstance(self.sampled, list):
            return [getattr(self, self.op_names[i]) for i in self.sampled]  # pylint: disable=not-an-iterable
        return [getattr(self, self.op_names[self.sampled])]  # pylint: disable=invalid-sequence-index

    def __len__(self):
        return len(self.op_names)

    @property
    def mask(self):
        return _get_mask(self.sampled, len(self))

    def __repr__(self):
        return f"PathSamplingLayerChoice(op_names={self.op_names}, chosen={self.sampled})"


class PathSamplingInputChoice(nn.Module):
    """
    Mixed input. Takes a list of tensors as input, selects some of them and returns the sum.

    Attributes
    ----------
    sampled : int or list of int
        Sampled module indices.
    mask : tensor
        A multi-hot bool 1D-tensor representing the sampled mask.
    """

    def __init__(self, input_choice):
        super(PathSamplingInputChoice, self).__init__()
        self.n_candidates = input_choice.n_candidates
        self.n_chosen = input_choice.n_chosen
        self.sampled = None

    def forward(self, input_tensors):
        # Same guard as PathSamplingLayerChoice: indexing with None would otherwise
        # fail obscurely (list) or silently add a dimension (tensor).
        assert (
            self.sampled is not None
        ), "At least one path needs to be sampled before fprop."
        if isinstance(self.sampled, list):
            return sum(
                [input_tensors[t] for t in self.sampled]
            )  # pylint: disable=not-an-iterable
        return input_tensors[self.sampled]

    def __len__(self):
        return self.n_candidates

    @property
    def mask(self):
        return _get_mask(self.sampled, len(self))

    def __repr__(self):
        return f"PathSamplingInputChoice(n_candidates={self.n_candidates}, chosen={self.sampled})"


def get_hardware_aware_metric(model, hardware_metric):
    """
    Get an architecture's hardware-aware metric.

    Parameters
    ----------
    model : BaseSpace
        The architecture to be evaluated.
    hardware_metric : str
        The name of the hardware-aware metric. Can be 'parameter' or 'latency'.

    Raises
    ------
    ValueError
        If ``hardware_metric`` is not one of the supported names.
    """
    if hardware_metric == 'parameter':
        return count_parameters(model)
    elif hardware_metric == 'latency':
        return measure_latency(model, 20, warmup_iters=5)
    else:
        raise ValueError('Unsupported hardware-aware metric')


def count_parameters(module, only_trainable=False):
    """
    Recursively count parameter elements of ``module``.

    For a ``PathSamplingLayerChoice`` only the currently sampled ops are counted, so the
    result reflects the sampled sub-architecture rather than the whole search space.

    Parameters
    ----------
    module : nn.Module
        Root module to count.
    only_trainable : bool
        If True, only parameters with ``requires_grad`` are counted (at every depth).
    """
    s = sum(
        p.numel()
        for p in module.parameters(recurse=False)
        if not only_trainable or p.requires_grad
    )
    if isinstance(module, PathSamplingLayerChoice):
        children = module.sampled_choices()
    else:
        children = module.children()
    # Propagate `only_trainable` so the filter applies to the whole subtree.
    s += sum(count_parameters(m, only_trainable) for m in children)
    return s


def measure_latency(model, num_iters=200, *, warmup_iters=50):
    """
    Measure the mean forward-pass wall-clock time (seconds) of ``model`` on random data.

    The model's train/eval mode is restored on exit. If a RuntimeError whose message
    mentions CUDA is raised (e.g. out of memory), 100 is returned as a sentinel "infinite"
    latency; other RuntimeErrors are re-raised.

    NOTE(review): assumes ``model`` has at least one parameter and an ``input_dim``
    attribute — confirm for all callers.
    """
    device = next(model.parameters()).device
    num_feat = model.input_dim
    was_training = model.training
    model.eval()
    latencys = []
    data = _build_random_data(device, num_feat)
    try:
        with torch.no_grad():
            for i in range(warmup_iters + num_iters):
                # CUDA kernels run asynchronously; synchronize so the timer covers the
                # whole forward pass.
                if device.type == 'cuda':
                    torch.cuda.synchronize()
                start = time.perf_counter()
                model(data)
                if device.type == 'cuda':
                    torch.cuda.synchronize()
                dt = time.perf_counter() - start
                if i >= warmup_iters:
                    latencys.append(dt)
    except RuntimeError as e:
        if "cuda" in str(e) or "CUDA" in str(e):
            INF = 100
            return INF
        raise
    finally:
        # Do not leave the caller's model stuck in eval mode.
        model.train(was_training)
    return np.mean(latencys)


def _build_random_data(device, num_feat):
    """Build a random input with ``x`` (3000 x num_feat) and ``edge_index`` (2 x 10000) on ``device``."""
    node_nums = 3000
    edge_nums = 10000

    class Data:
        pass

    data = Data()
    data.x = torch.randn((node_nums, num_feat), device=device)
    data.edge_index = torch.randint(0, node_nums, (2, edge_nums), device=device)
    data.num_features = num_feat
    return data


def to_device(obj, device):
    """
    Move a tensor, tuple, list, or dict onto device.

    ints, floats and strs are returned unchanged; any other type raises ValueError.
    """
    if torch.is_tensor(obj):
        return obj.to(device)
    if isinstance(obj, tuple):
        return tuple(to_device(t, device) for t in obj)
    if isinstance(obj, list):
        return [to_device(t, device) for t in obj]
    if isinstance(obj, dict):
        return {k: to_device(v, device) for k, v in obj.items()}
    if isinstance(obj, (int, float, str)):
        return obj
    raise ValueError("'%s' has unsupported type '%s'" % (obj, type(obj)))


def to_list(arr):
    """Convert a tensor / ndarray / tuple to a plain Python list; other values pass through."""
    if torch.is_tensor(arr):
        return arr.cpu().numpy().tolist()
    if isinstance(arr, np.ndarray):
        return arr.tolist()
    if isinstance(arr, (list, tuple)):
        return list(arr)
    return arr


class AverageMeterGroup:
    """
    Average meter group for multiple average meters.
    """

    def __init__(self):
        self.meters = OrderedDict()

    def update(self, data):
        """
        Update the meter group with a dict of metrics.
        Non-existent average meters will be automatically created.
        """
        for k, v in data.items():
            if k not in self.meters:
                # ".4f" = 4 decimal places (the previous ":4f" was a width spec).
                self.meters[k] = AverageMeter(k, ":.4f")
            self.meters[k].update(v)

    def __getattr__(self, item):
        # Only invoked when normal attribute lookup fails. Guard `meters` itself so an
        # instance created without __init__ (copy/unpickle) does not recurse forever,
        # and raise AttributeError so hasattr()/getattr(default) behave correctly.
        if item == "meters":
            raise AttributeError(item)
        try:
            return self.meters[item]
        except KeyError:
            raise AttributeError(item) from None

    def __getitem__(self, item):
        return self.meters[item]

    def __str__(self):
        return " ".join(str(v) for v in self.meters.values())

    def summary(self):
        """
        Return a summary string of group data.
        """
        return " ".join(v.summary() for v in self.meters.values())


class AverageMeter:
    """
    Computes and stores the average and current value.

    Parameters
    ----------
    name : str
        Name to display.
    fmt : str
        Format string to print the values.
    """

    def __init__(self, name, fmt=":f"):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """
        Reset the meter.
        """
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """
        Update with value and weight.

        Parameters
        ----------
        val : float or int
            The new value to be accounted in.
        n : int
            The weight of the new value.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
        return fmtstr.format(**self.__dict__)

    def summary(self):
        fmtstr = "{name}: {avg" + self.fmt + "}"
        return fmtstr.format(**self.__dict__)


def get_module_order(root_module):
    """Map each ``Mutable``'s key to its ``order``, searching the tree without descending into Mutables."""
    key2order = {}

    def apply(m):
        for child in m.children():
            if isinstance(child, Mutable):
                key2order[child.key] = child.order
            else:
                apply(child)

    apply(root_module)
    return key2order


def sort_replaced_module(k2o, modules):
    """Sort ``(key, module)`` pairs by the order that ``k2o`` assigns to each key."""
    return sorted(modules, key=lambda x: k2o[x[0]])


def _replace_module_with_type(root_module, init_fn, type_name, modules):
    """
    Replace every submodule that is an instance of ``type_name`` with ``init_fn(child)``.

    Replaced modules are not searched further. ``(child.key, new_module)`` pairs are
    appended to ``modules`` (a new list if None), which is returned.
    """
    if modules is None:
        modules = []

    def apply(m):
        for name, child in m.named_children():
            if isinstance(child, type_name):
                setattr(m, name, init_fn(child))
                modules.append((child.key, getattr(m, name)))
            else:
                apply(child)

    apply(root_module)
    return modules


def replace_layer_choice(root_module, init_fn, modules=None):
    """
    Replace layer choice modules with modules that are initiated with init_fn.

    Parameters
    ----------
    root_module : nn.Module
        Root module to traverse.
    init_fn : Callable
        Initializing function.
    modules : list, optional
        If provided, the replaced ``(key, module)`` pairs are appended to this list.

    Returns
    -------
    List[Tuple[str, nn.Module]]
        A list of layer choice keys (names) and replaced modules.
    """
    return _replace_module_with_type(
        root_module, init_fn, (LayerChoice, nn.LayerChoice), modules
    )


def replace_input_choice(root_module, init_fn, modules=None):
    """
    Replace input choice modules with modules that are initiated with init_fn.

    Parameters
    ----------
    root_module : nn.Module
        Root module to traverse.
    init_fn : Callable
        Initializing function.
    modules : list, optional
        If provided, the replaced ``(key, module)`` pairs are appended to this list.

    Returns
    -------
    List[Tuple[str, nn.Module]]
        A list of input choice keys (names) and replaced modules.
    """
    return _replace_module_with_type(
        root_module, init_fn, (InputChoice, nn.InputChoice), modules
    )