# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved import logging import torch import torch.distributed as dist from torch import nn from torch.autograd.function import Function from detectron2.utils import comm from .wrappers import BatchNorm2d class FrozenBatchNorm2d(nn.Module): """ BatchNorm2d where the batch statistics and the affine parameters are fixed. It contains non-trainable buffers called "weight" and "bias", "running_mean", "running_var", initialized to perform identity transformation. The pre-trained backbone models from Caffe2 only contain "weight" and "bias", which are computed from the original four parameters of BN. The affine transform `x * weight + bias` will perform the equivalent computation of `(x - running_mean) / sqrt(running_var) * weight + bias`. When loading a backbone model from Caffe2, "running_mean" and "running_var" will be left unchanged as identity transformation. Other pre-trained backbone models may contain all 4 parameters. The forward is implemented by `F.batch_norm(..., training=False)`. """ _version = 3 def __init__(self, num_features, eps=1e-5): super().__init__() self.num_features = num_features self.eps = eps self.register_buffer("weight", torch.ones(num_features)) self.register_buffer("bias", torch.zeros(num_features)) self.register_buffer("running_mean", torch.zeros(num_features)) self.register_buffer("running_var", torch.ones(num_features) - eps) def forward(self, x): scale = self.weight * (self.running_var + self.eps).rsqrt() bias = self.bias - self.running_mean * scale scale = scale.reshape(1, -1, 1, 1) bias = bias.reshape(1, -1, 1, 1) return x * scale + bias def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ): version = local_metadata.get("version", None) if version is None or version < 2: # No running_mean/var in early versions # This will silent the warnings if prefix + "running_mean" not in state_dict: state_dict[prefix + "running_mean"] = torch.zeros_like(self.running_mean) if prefix + "running_var" not in state_dict: state_dict[prefix + "running_var"] = torch.ones_like(self.running_var) if version is not None and version < 3: logger = logging.getLogger(__name__) logger.info("FrozenBatchNorm {} is upgraded to version 3.".format(prefix.rstrip("."))) # In version < 3, running_var are used without +eps. state_dict[prefix + "running_var"] -= self.eps super()._load_from_state_dict( state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ) def __repr__(self): return "FrozenBatchNorm2d(num_features={}, eps={})".format(self.num_features, self.eps) @classmethod def convert_frozen_batchnorm(cls, module): """ Convert BatchNorm/SyncBatchNorm in module into FrozenBatchNorm. Args: module (torch.nn.Module): Returns: If module is BatchNorm/SyncBatchNorm, returns a new module. Otherwise, in-place convert module and return it. Similar to convert_sync_batchnorm in https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py """ bn_module = nn.modules.batchnorm bn_module = (bn_module.BatchNorm2d, bn_module.SyncBatchNorm) res = module if isinstance(module, bn_module): res = cls(module.num_features) if module.affine: res.weight.data = module.weight.data.clone().detach() res.bias.data = module.bias.data.clone().detach() res.running_mean.data = module.running_mean.data res.running_var.data = module.running_var.data + module.eps else: for name, child in module.named_children(): new_child = cls.convert_frozen_batchnorm(child) if new_child is not child: res.add_module(name, new_child) return res def get_norm(norm, out_channels): """ Args: norm (str or callable): Returns: nn.Module or None: the normalization layer """ if isinstance(norm, str): if len(norm) == 0: return None norm = { "BN": BatchNorm2d, "SyncBN": NaiveSyncBatchNorm, "FrozenBN": FrozenBatchNorm2d, "GN": lambda channels: nn.GroupNorm(32, channels), "nnSyncBN": nn.SyncBatchNorm, # keep for debugging }[norm] return norm(out_channels) class AllReduce(Function): @staticmethod def forward(ctx, input): input_list = [torch.zeros_like(input) for k in range(dist.get_world_size())] # Use allgather instead of allreduce since I don't trust in-place operations .. dist.all_gather(input_list, input, async_op=False) inputs = torch.stack(input_list, dim=0) return torch.sum(inputs, dim=0) @staticmethod def backward(ctx, grad_output): dist.all_reduce(grad_output, async_op=False) return grad_output class NaiveSyncBatchNorm(BatchNorm2d): """ `torch.nn.SyncBatchNorm` has known unknown bugs. It produces significantly worse AP (and sometimes goes NaN) when the batch size on each worker is quite different (e.g., when scale augmentation is used, or when it is applied to mask head). Use this implementation before `nn.SyncBatchNorm` is fixed. It is slower than `nn.SyncBatchNorm`. """ def forward(self, input): if comm.get_world_size() == 1 or not self.training: return super().forward(input) assert input.shape[0] > 0, "SyncBatchNorm does not support empty inputs" C = input.shape[1] mean = torch.mean(input, dim=[0, 2, 3]) meansqr = torch.mean(input * input, dim=[0, 2, 3]) vec = torch.cat([mean, meansqr], dim=0) vec = AllReduce.apply(vec) * (1.0 / dist.get_world_size()) mean, meansqr = torch.split(vec, C) var = meansqr - mean * mean self.running_mean += self.momentum * (mean.detach() - self.running_mean) self.running_var += self.momentum * (var.detach() - self.running_var) invstd = torch.rsqrt(var + self.eps) scale = self.weight * invstd bias = self.bias - mean * scale scale = scale.reshape(1, -1, 1, 1) bias = bias.reshape(1, -1, 1, 1) return input * scale + bias