接入tinynas-detection,新增tinynas object detection pipeline以及tinynas models。
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9938220
master
| @@ -9,6 +9,8 @@ class Models(object): | |||
| Model name should only contain model info but not task info. | |||
| """ | |||
| tinynas_detection = 'tinynas-detection' | |||
| # vision models | |||
| detection = 'detection' | |||
| realtime_object_detection = 'realtime-object-detection' | |||
| @@ -133,6 +135,7 @@ class Pipelines(object): | |||
| image_to_image_generation = 'image-to-image-generation' | |||
| skin_retouching = 'unet-skin-retouching' | |||
| tinynas_classification = 'tinynas-classification' | |||
| tinynas_detection = 'tinynas-detection' | |||
| crowd_counting = 'hrnet-crowd-counting' | |||
| action_detection = 'ResNetC3D-action-detection' | |||
| video_single_object_tracking = 'ostrack-vitb-video-single-object-tracking' | |||
| @@ -0,0 +1,24 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. | |||
| from typing import TYPE_CHECKING | |||
| from modelscope.utils.import_utils import LazyImportModule | |||
| if TYPE_CHECKING: | |||
| from .tinynas_detector import Tinynas_detector | |||
| else: | |||
| _import_structure = { | |||
| 'tinynas_detector': ['TinynasDetector'], | |||
| } | |||
| import sys | |||
| sys.modules[__name__] = LazyImportModule( | |||
| __name__, | |||
| globals()['__file__'], | |||
| _import_structure, | |||
| module_spec=__spec__, | |||
| extra_objects={}, | |||
| ) | |||
| @@ -0,0 +1,16 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. | |||
| import copy | |||
| from .darknet import CSPDarknet | |||
| from .tinynas import load_tinynas_net | |||
| def build_backbone(cfg): | |||
| backbone_cfg = copy.deepcopy(cfg) | |||
| name = backbone_cfg.pop('name') | |||
| if name == 'CSPDarknet': | |||
| return CSPDarknet(**backbone_cfg) | |||
| elif name == 'TinyNAS': | |||
| return load_tinynas_net(backbone_cfg) | |||
| @@ -0,0 +1,126 @@ | |||
| # Copyright (c) Megvii Inc. All rights reserved. | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. | |||
| import torch | |||
| from torch import nn | |||
| from ..core.base_ops import (BaseConv, CSPLayer, DWConv, Focus, ResLayer, | |||
| SPPBottleneck) | |||
| class CSPDarknet(nn.Module): | |||
| def __init__( | |||
| self, | |||
| dep_mul, | |||
| wid_mul, | |||
| out_features=('dark3', 'dark4', 'dark5'), | |||
| depthwise=False, | |||
| act='silu', | |||
| reparam=False, | |||
| ): | |||
| super(CSPDarknet, self).__init__() | |||
| assert out_features, 'please provide output features of Darknet' | |||
| self.out_features = out_features | |||
| Conv = DWConv if depthwise else BaseConv | |||
| base_channels = int(wid_mul * 64) # 64 | |||
| base_depth = max(round(dep_mul * 3), 1) # 3 | |||
| # stem | |||
| # self.stem = Focus(3, base_channels, ksize=3, act=act) | |||
| self.stem = Focus(3, base_channels, 3, act=act) | |||
| # dark2 | |||
| self.dark2 = nn.Sequential( | |||
| Conv(base_channels, base_channels * 2, 3, 2, act=act), | |||
| CSPLayer( | |||
| base_channels * 2, | |||
| base_channels * 2, | |||
| n=base_depth, | |||
| depthwise=depthwise, | |||
| act=act, | |||
| reparam=reparam, | |||
| ), | |||
| ) | |||
| # dark3 | |||
| self.dark3 = nn.Sequential( | |||
| Conv(base_channels * 2, base_channels * 4, 3, 2, act=act), | |||
| CSPLayer( | |||
| base_channels * 4, | |||
| base_channels * 4, | |||
| n=base_depth * 3, | |||
| depthwise=depthwise, | |||
| act=act, | |||
| reparam=reparam, | |||
| ), | |||
| ) | |||
| # dark4 | |||
| self.dark4 = nn.Sequential( | |||
| Conv(base_channels * 4, base_channels * 8, 3, 2, act=act), | |||
| CSPLayer( | |||
| base_channels * 8, | |||
| base_channels * 8, | |||
| n=base_depth * 3, | |||
| depthwise=depthwise, | |||
| act=act, | |||
| reparam=reparam, | |||
| ), | |||
| ) | |||
| # dark5 | |||
| self.dark5 = nn.Sequential( | |||
| Conv(base_channels * 8, base_channels * 16, 3, 2, act=act), | |||
| SPPBottleneck( | |||
| base_channels * 16, base_channels * 16, activation=act), | |||
| CSPLayer( | |||
| base_channels * 16, | |||
| base_channels * 16, | |||
| n=base_depth, | |||
| shortcut=False, | |||
| depthwise=depthwise, | |||
| act=act, | |||
| reparam=reparam, | |||
| ), | |||
| ) | |||
| def init_weights(self, pretrain=None): | |||
| if pretrain is None: | |||
| return | |||
| else: | |||
| pretrained_dict = torch.load( | |||
| pretrain, map_location='cpu')['state_dict'] | |||
| new_params = self.state_dict().copy() | |||
| for k, v in pretrained_dict.items(): | |||
| ks = k.split('.') | |||
| if ks[0] == 'fc' or ks[-1] == 'total_ops' or ks[ | |||
| -1] == 'total_params': | |||
| continue | |||
| else: | |||
| new_params[k] = v | |||
| self.load_state_dict(new_params) | |||
| print(f' load pretrain backbone from {pretrain}') | |||
| def forward(self, x): | |||
| outputs = {} | |||
| x = self.stem(x) | |||
| outputs['stem'] = x | |||
| x = self.dark2(x) | |||
| outputs['dark2'] = x | |||
| x = self.dark3(x) | |||
| outputs['dark3'] = x | |||
| x = self.dark4(x) | |||
| outputs['dark4'] = x | |||
| x = self.dark5(x) | |||
| outputs['dark5'] = x | |||
| features_out = [ | |||
| outputs['stem'], outputs['dark2'], outputs['dark3'], | |||
| outputs['dark4'], outputs['dark5'] | |||
| ] | |||
| return features_out | |||
| @@ -0,0 +1,347 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. | |||
| import torch | |||
| import torch.nn as nn | |||
| from ..core.base_ops import Focus, SPPBottleneck, get_activation | |||
| from ..core.repvgg_block import RepVggBlock | |||
| class ConvKXBN(nn.Module): | |||
| def __init__(self, in_c, out_c, kernel_size, stride): | |||
| super(ConvKXBN, self).__init__() | |||
| self.conv1 = nn.Conv2d( | |||
| in_c, | |||
| out_c, | |||
| kernel_size, | |||
| stride, (kernel_size - 1) // 2, | |||
| groups=1, | |||
| bias=False) | |||
| self.bn1 = nn.BatchNorm2d(out_c) | |||
| def forward(self, x): | |||
| return self.bn1(self.conv1(x)) | |||
| class ConvKXBNRELU(nn.Module): | |||
| def __init__(self, in_c, out_c, kernel_size, stride, act='silu'): | |||
| super(ConvKXBNRELU, self).__init__() | |||
| self.conv = ConvKXBN(in_c, out_c, kernel_size, stride) | |||
| if act is None: | |||
| self.activation_function = torch.relu | |||
| else: | |||
| self.activation_function = get_activation(act) | |||
| def forward(self, x): | |||
| output = self.conv(x) | |||
| return self.activation_function(output) | |||
| class ResConvK1KX(nn.Module): | |||
| def __init__(self, | |||
| in_c, | |||
| out_c, | |||
| btn_c, | |||
| kernel_size, | |||
| stride, | |||
| force_resproj=False, | |||
| act='silu'): | |||
| super(ResConvK1KX, self).__init__() | |||
| self.stride = stride | |||
| self.conv1 = ConvKXBN(in_c, btn_c, 1, 1) | |||
| self.conv2 = RepVggBlock( | |||
| btn_c, out_c, kernel_size, stride, act='identity') | |||
| if act is None: | |||
| self.activation_function = torch.relu | |||
| else: | |||
| self.activation_function = get_activation(act) | |||
| if stride == 2: | |||
| self.residual_downsample = nn.AvgPool2d(kernel_size=2, stride=2) | |||
| else: | |||
| self.residual_downsample = nn.Identity() | |||
| if in_c != out_c or force_resproj: | |||
| self.residual_proj = ConvKXBN(in_c, out_c, 1, 1) | |||
| else: | |||
| self.residual_proj = nn.Identity() | |||
| def forward(self, x): | |||
| if self.stride != 2: | |||
| reslink = self.residual_downsample(x) | |||
| reslink = self.residual_proj(reslink) | |||
| output = x | |||
| output = self.conv1(output) | |||
| output = self.activation_function(output) | |||
| output = self.conv2(output) | |||
| if self.stride != 2: | |||
| output = output + reslink | |||
| output = self.activation_function(output) | |||
| return output | |||
| class SuperResConvK1KX(nn.Module): | |||
| def __init__(self, | |||
| in_c, | |||
| out_c, | |||
| btn_c, | |||
| kernel_size, | |||
| stride, | |||
| num_blocks, | |||
| with_spp=False, | |||
| act='silu'): | |||
| super(SuperResConvK1KX, self).__init__() | |||
| if act is None: | |||
| self.act = torch.relu | |||
| else: | |||
| self.act = get_activation(act) | |||
| self.block_list = nn.ModuleList() | |||
| for block_id in range(num_blocks): | |||
| if block_id == 0: | |||
| in_channels = in_c | |||
| out_channels = out_c | |||
| this_stride = stride | |||
| force_resproj = False # as a part of CSPLayer, DO NOT need this flag | |||
| this_kernel_size = kernel_size | |||
| else: | |||
| in_channels = out_c | |||
| out_channels = out_c | |||
| this_stride = 1 | |||
| force_resproj = False | |||
| this_kernel_size = kernel_size | |||
| the_block = ResConvK1KX( | |||
| in_channels, | |||
| out_channels, | |||
| btn_c, | |||
| this_kernel_size, | |||
| this_stride, | |||
| force_resproj, | |||
| act=act) | |||
| self.block_list.append(the_block) | |||
| if block_id == 0 and with_spp: | |||
| self.block_list.append( | |||
| SPPBottleneck(out_channels, out_channels)) | |||
| def forward(self, x): | |||
| output = x | |||
| for block in self.block_list: | |||
| output = block(output) | |||
| return output | |||
| class ResConvKXKX(nn.Module): | |||
| def __init__(self, | |||
| in_c, | |||
| out_c, | |||
| btn_c, | |||
| kernel_size, | |||
| stride, | |||
| force_resproj=False, | |||
| act='silu'): | |||
| super(ResConvKXKX, self).__init__() | |||
| self.stride = stride | |||
| if self.stride == 2: | |||
| self.downsampler = ConvKXBNRELU(in_c, out_c, 3, 2, act=act) | |||
| else: | |||
| self.conv1 = ConvKXBN(in_c, btn_c, kernel_size, 1) | |||
| self.conv2 = RepVggBlock( | |||
| btn_c, out_c, kernel_size, stride, act='identity') | |||
| if act is None: | |||
| self.activation_function = torch.relu | |||
| else: | |||
| self.activation_function = get_activation(act) | |||
| if stride == 2: | |||
| self.residual_downsample = nn.AvgPool2d( | |||
| kernel_size=2, stride=2) | |||
| else: | |||
| self.residual_downsample = nn.Identity() | |||
| if in_c != out_c or force_resproj: | |||
| self.residual_proj = ConvKXBN(in_c, out_c, 1, 1) | |||
| else: | |||
| self.residual_proj = nn.Identity() | |||
| def forward(self, x): | |||
| if self.stride == 2: | |||
| return self.downsampler(x) | |||
| reslink = self.residual_downsample(x) | |||
| reslink = self.residual_proj(reslink) | |||
| output = x | |||
| output = self.conv1(output) | |||
| output = self.activation_function(output) | |||
| output = self.conv2(output) | |||
| output = output + reslink | |||
| output = self.activation_function(output) | |||
| return output | |||
| class SuperResConvKXKX(nn.Module): | |||
| def __init__(self, | |||
| in_c, | |||
| out_c, | |||
| btn_c, | |||
| kernel_size, | |||
| stride, | |||
| num_blocks, | |||
| with_spp=False, | |||
| act='silu'): | |||
| super(SuperResConvKXKX, self).__init__() | |||
| if act is None: | |||
| self.act = torch.relu | |||
| else: | |||
| self.act = get_activation(act) | |||
| self.block_list = nn.ModuleList() | |||
| for block_id in range(num_blocks): | |||
| if block_id == 0: | |||
| in_channels = in_c | |||
| out_channels = out_c | |||
| this_stride = stride | |||
| force_resproj = False # as a part of CSPLayer, DO NOT need this flag | |||
| this_kernel_size = kernel_size | |||
| else: | |||
| in_channels = out_c | |||
| out_channels = out_c | |||
| this_stride = 1 | |||
| force_resproj = False | |||
| this_kernel_size = kernel_size | |||
| the_block = ResConvKXKX( | |||
| in_channels, | |||
| out_channels, | |||
| btn_c, | |||
| this_kernel_size, | |||
| this_stride, | |||
| force_resproj, | |||
| act=act) | |||
| self.block_list.append(the_block) | |||
| if block_id == 0 and with_spp: | |||
| self.block_list.append( | |||
| SPPBottleneck(out_channels, out_channels)) | |||
| def forward(self, x): | |||
| output = x | |||
| for block in self.block_list: | |||
| output = block(output) | |||
| return output | |||
| class TinyNAS(nn.Module): | |||
| def __init__(self, | |||
| structure_info=None, | |||
| out_indices=[0, 1, 2, 4, 5], | |||
| out_channels=[None, None, 128, 256, 512], | |||
| with_spp=False, | |||
| use_focus=False, | |||
| need_conv1=True, | |||
| act='silu'): | |||
| super(TinyNAS, self).__init__() | |||
| assert len(out_indices) == len(out_channels) | |||
| self.out_indices = out_indices | |||
| self.need_conv1 = need_conv1 | |||
| self.block_list = nn.ModuleList() | |||
| if need_conv1: | |||
| self.conv1_list = nn.ModuleList() | |||
| for idx, block_info in enumerate(structure_info): | |||
| the_block_class = block_info['class'] | |||
| if the_block_class == 'ConvKXBNRELU': | |||
| if use_focus: | |||
| the_block = Focus(block_info['in'], block_info['out'], | |||
| block_info['k']) | |||
| else: | |||
| the_block = ConvKXBNRELU( | |||
| block_info['in'], | |||
| block_info['out'], | |||
| block_info['k'], | |||
| block_info['s'], | |||
| act=act) | |||
| self.block_list.append(the_block) | |||
| elif the_block_class == 'SuperResConvK1KX': | |||
| spp = with_spp if idx == len(structure_info) - 1 else False | |||
| the_block = SuperResConvK1KX( | |||
| block_info['in'], | |||
| block_info['out'], | |||
| block_info['btn'], | |||
| block_info['k'], | |||
| block_info['s'], | |||
| block_info['L'], | |||
| spp, | |||
| act=act) | |||
| self.block_list.append(the_block) | |||
| elif the_block_class == 'SuperResConvKXKX': | |||
| spp = with_spp if idx == len(structure_info) - 1 else False | |||
| the_block = SuperResConvKXKX( | |||
| block_info['in'], | |||
| block_info['out'], | |||
| block_info['btn'], | |||
| block_info['k'], | |||
| block_info['s'], | |||
| block_info['L'], | |||
| spp, | |||
| act=act) | |||
| self.block_list.append(the_block) | |||
| if need_conv1: | |||
| if idx in self.out_indices and out_channels[ | |||
| self.out_indices.index(idx)] is not None: | |||
| self.conv1_list.append( | |||
| nn.Conv2d(block_info['out'], | |||
| out_channels[self.out_indices.index(idx)], | |||
| 1)) | |||
| else: | |||
| self.conv1_list.append(None) | |||
| def init_weights(self, pretrain=None): | |||
| pass | |||
| def forward(self, x): | |||
| output = x | |||
| stage_feature_list = [] | |||
| for idx, block in enumerate(self.block_list): | |||
| output = block(output) | |||
| if idx in self.out_indices: | |||
| if self.need_conv1 and self.conv1_list[idx] is not None: | |||
| true_out = self.conv1_list[idx](output) | |||
| stage_feature_list.append(true_out) | |||
| else: | |||
| stage_feature_list.append(output) | |||
| return stage_feature_list | |||
| def load_tinynas_net(backbone_cfg): | |||
| # load masternet model to path | |||
| import ast | |||
| struct_str = ''.join([x.strip() for x in backbone_cfg.net_structure_str]) | |||
| struct_info = ast.literal_eval(struct_str) | |||
| for layer in struct_info: | |||
| if 'nbitsA' in layer: | |||
| del layer['nbitsA'] | |||
| if 'nbitsW' in layer: | |||
| del layer['nbitsW'] | |||
| model = TinyNAS( | |||
| structure_info=struct_info, | |||
| out_indices=backbone_cfg.out_indices, | |||
| out_channels=backbone_cfg.out_channels, | |||
| with_spp=backbone_cfg.with_spp, | |||
| use_focus=backbone_cfg.use_focus, | |||
| act=backbone_cfg.act, | |||
| need_conv1=backbone_cfg.need_conv1, | |||
| ) | |||
| return model | |||
| @@ -0,0 +1,2 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. | |||
| @@ -0,0 +1,474 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. | |||
| import math | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| from .repvgg_block import RepVggBlock | |||
| class SiLU(nn.Module): | |||
| """export-friendly version of nn.SiLU()""" | |||
| @staticmethod | |||
| def forward(x): | |||
| return x * torch.sigmoid(x) | |||
| def get_activation(name='silu', inplace=True): | |||
| if name == 'silu': | |||
| module = nn.SiLU(inplace=inplace) | |||
| elif name == 'relu': | |||
| module = nn.ReLU(inplace=inplace) | |||
| elif name == 'lrelu': | |||
| module = nn.LeakyReLU(0.1, inplace=inplace) | |||
| else: | |||
| raise AttributeError('Unsupported act type: {}'.format(name)) | |||
| return module | |||
| def get_norm(name, out_channels, inplace=True): | |||
| if name == 'bn': | |||
| module = nn.BatchNorm2d(out_channels) | |||
| elif name == 'gn': | |||
| module = nn.GroupNorm(num_channels=out_channels, num_groups=32) | |||
| return module | |||
| class BaseConv(nn.Module): | |||
| """A Conv2d -> Batchnorm -> silu/leaky relu block""" | |||
| def __init__(self, | |||
| in_channels, | |||
| out_channels, | |||
| ksize, | |||
| stride=1, | |||
| groups=1, | |||
| bias=False, | |||
| act='silu', | |||
| norm='bn'): | |||
| super().__init__() | |||
| # same padding | |||
| pad = (ksize - 1) // 2 | |||
| self.conv = nn.Conv2d( | |||
| in_channels, | |||
| out_channels, | |||
| kernel_size=ksize, | |||
| stride=stride, | |||
| padding=pad, | |||
| groups=groups, | |||
| bias=bias, | |||
| ) | |||
| if norm is not None: | |||
| self.bn = get_norm(norm, out_channels, inplace=True) | |||
| if act is not None: | |||
| self.act = get_activation(act, inplace=True) | |||
| self.with_norm = norm is not None | |||
| self.with_act = act is not None | |||
| def forward(self, x): | |||
| x = self.conv(x) | |||
| if self.with_norm: | |||
| # x = self.norm(x) | |||
| x = self.bn(x) | |||
| if self.with_act: | |||
| x = self.act(x) | |||
| return x | |||
| def fuseforward(self, x): | |||
| return self.act(self.conv(x)) | |||
| class DepthWiseConv(nn.Module): | |||
| def __init__(self, | |||
| in_channels, | |||
| out_channels, | |||
| ksize, | |||
| stride=1, | |||
| groups=None, | |||
| bias=False, | |||
| act='silu', | |||
| norm='bn'): | |||
| super().__init__() | |||
| padding = (ksize - 1) // 2 | |||
| self.depthwise = nn.Conv2d( | |||
| in_channels, | |||
| in_channels, | |||
| kernel_size=ksize, | |||
| stride=stride, | |||
| padding=padding, | |||
| groups=in_channels, | |||
| bias=bias, | |||
| ) | |||
| self.pointwise = nn.Conv2d( | |||
| in_channels, | |||
| out_channels, | |||
| kernel_size=1, | |||
| stride=1, | |||
| padding=0, | |||
| bias=bias) | |||
| if norm is not None: | |||
| self.dwnorm = get_norm(norm, in_channels, inplace=True) | |||
| self.pwnorm = get_norm(norm, out_channels, inplace=True) | |||
| if act is not None: | |||
| self.act = get_activation(act, inplace=True) | |||
| self.with_norm = norm is not None | |||
| self.with_act = act is not None | |||
| self.order = ['depthwise', 'dwnorm', 'pointwise', 'act'] | |||
| def forward(self, x): | |||
| for layer_name in self.order: | |||
| layer = self.__getattr__(layer_name) | |||
| if layer is not None: | |||
| x = layer(x) | |||
| return x | |||
| class DWConv(nn.Module): | |||
| """Depthwise Conv + Conv""" | |||
| def __init__(self, in_channels, out_channels, ksize, stride=1, act='silu'): | |||
| super().__init__() | |||
| self.dconv = BaseConv( | |||
| in_channels, | |||
| in_channels, | |||
| ksize=ksize, | |||
| stride=stride, | |||
| groups=in_channels, | |||
| act=act, | |||
| ) | |||
| self.pconv = BaseConv( | |||
| in_channels, out_channels, ksize=1, stride=1, groups=1, act=act) | |||
| def forward(self, x): | |||
| x = self.dconv(x) | |||
| return self.pconv(x) | |||
| class Bottleneck(nn.Module): | |||
| # Standard bottleneck | |||
| def __init__( | |||
| self, | |||
| in_channels, | |||
| out_channels, | |||
| shortcut=True, | |||
| expansion=0.5, | |||
| depthwise=False, | |||
| act='silu', | |||
| reparam=False, | |||
| ): | |||
| super().__init__() | |||
| hidden_channels = int(out_channels * expansion) | |||
| Conv = DWConv if depthwise else BaseConv | |||
| k_conv1 = 3 if reparam else 1 | |||
| self.conv1 = BaseConv( | |||
| in_channels, hidden_channels, k_conv1, stride=1, act=act) | |||
| if reparam: | |||
| self.conv2 = RepVggBlock( | |||
| hidden_channels, out_channels, 3, stride=1, act=act) | |||
| else: | |||
| self.conv2 = Conv( | |||
| hidden_channels, out_channels, 3, stride=1, act=act) | |||
| self.use_add = shortcut and in_channels == out_channels | |||
| def forward(self, x): | |||
| y = self.conv2(self.conv1(x)) | |||
| if self.use_add: | |||
| y = y + x | |||
| return y | |||
| class ResLayer(nn.Module): | |||
| 'Residual layer with `in_channels` inputs.' | |||
| def __init__(self, in_channels: int): | |||
| super().__init__() | |||
| mid_channels = in_channels // 2 | |||
| self.layer1 = BaseConv( | |||
| in_channels, mid_channels, ksize=1, stride=1, act='lrelu') | |||
| self.layer2 = BaseConv( | |||
| mid_channels, in_channels, ksize=3, stride=1, act='lrelu') | |||
| def forward(self, x): | |||
| out = self.layer2(self.layer1(x)) | |||
| return x + out | |||
| class SPPBottleneck(nn.Module): | |||
| """Spatial pyramid pooling layer used in YOLOv3-SPP""" | |||
| def __init__(self, | |||
| in_channels, | |||
| out_channels, | |||
| kernel_sizes=(5, 9, 13), | |||
| activation='silu'): | |||
| super().__init__() | |||
| hidden_channels = in_channels // 2 | |||
| self.conv1 = BaseConv( | |||
| in_channels, hidden_channels, 1, stride=1, act=activation) | |||
| self.m = nn.ModuleList([ | |||
| nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2) | |||
| for ks in kernel_sizes | |||
| ]) | |||
| conv2_channels = hidden_channels * (len(kernel_sizes) + 1) | |||
| self.conv2 = BaseConv( | |||
| conv2_channels, out_channels, 1, stride=1, act=activation) | |||
| def forward(self, x): | |||
| x = self.conv1(x) | |||
| x = torch.cat([x] + [m(x) for m in self.m], dim=1) | |||
| x = self.conv2(x) | |||
| return x | |||
| class CSPLayer(nn.Module): | |||
| """C3 in yolov5, CSP Bottleneck with 3 convolutions""" | |||
| def __init__( | |||
| self, | |||
| in_channels, | |||
| out_channels, | |||
| n=1, | |||
| shortcut=True, | |||
| expansion=0.5, | |||
| depthwise=False, | |||
| act='silu', | |||
| reparam=False, | |||
| ): | |||
| """ | |||
| Args: | |||
| in_channels (int): input channels. | |||
| out_channels (int): output channels. | |||
| n (int): number of Bottlenecks. Default value: 1. | |||
| """ | |||
| # ch_in, ch_out, number, shortcut, groups, expansion | |||
| super().__init__() | |||
| hidden_channels = int(out_channels * expansion) # hidden channels | |||
| self.conv1 = BaseConv( | |||
| in_channels, hidden_channels, 1, stride=1, act=act) | |||
| self.conv2 = BaseConv( | |||
| in_channels, hidden_channels, 1, stride=1, act=act) | |||
| self.conv3 = BaseConv( | |||
| 2 * hidden_channels, out_channels, 1, stride=1, act=act) | |||
| module_list = [ | |||
| Bottleneck( | |||
| hidden_channels, | |||
| hidden_channels, | |||
| shortcut, | |||
| 1.0, | |||
| depthwise, | |||
| act=act, | |||
| reparam=reparam) for _ in range(n) | |||
| ] | |||
| self.m = nn.Sequential(*module_list) | |||
| def forward(self, x): | |||
| x_1 = self.conv1(x) | |||
| x_2 = self.conv2(x) | |||
| x_1 = self.m(x_1) | |||
| x = torch.cat((x_1, x_2), dim=1) | |||
| return self.conv3(x) | |||
| class Focus(nn.Module): | |||
| """Focus width and height information into channel space.""" | |||
| def __init__(self, | |||
| in_channels, | |||
| out_channels, | |||
| ksize=1, | |||
| stride=1, | |||
| act='silu'): | |||
| super().__init__() | |||
| self.conv = BaseConv( | |||
| in_channels * 4, out_channels, ksize, stride, act=act) | |||
| def forward(self, x): | |||
| # shape of x (b,c,w,h) -> y(b,4c,w/2,h/2) | |||
| patch_top_left = x[..., ::2, ::2] | |||
| patch_top_right = x[..., ::2, 1::2] | |||
| patch_bot_left = x[..., 1::2, ::2] | |||
| patch_bot_right = x[..., 1::2, 1::2] | |||
| x = torch.cat( | |||
| ( | |||
| patch_top_left, | |||
| patch_bot_left, | |||
| patch_top_right, | |||
| patch_bot_right, | |||
| ), | |||
| dim=1, | |||
| ) | |||
| return self.conv(x) | |||
| class fast_Focus(nn.Module): | |||
| def __init__(self, | |||
| in_channels, | |||
| out_channels, | |||
| ksize=1, | |||
| stride=1, | |||
| act='silu'): | |||
| super(Focus, self).__init__() | |||
| self.conv1 = self.focus_conv(w1=1.0) | |||
| self.conv2 = self.focus_conv(w3=1.0) | |||
| self.conv3 = self.focus_conv(w2=1.0) | |||
| self.conv4 = self.focus_conv(w4=1.0) | |||
| self.conv = BaseConv( | |||
| in_channels * 4, out_channels, ksize, stride, act=act) | |||
| def forward(self, x): | |||
| return self.conv( | |||
| torch.cat( | |||
| [self.conv1(x), | |||
| self.conv2(x), | |||
| self.conv3(x), | |||
| self.conv4(x)], 1)) | |||
| def focus_conv(self, w1=0.0, w2=0.0, w3=0.0, w4=0.0): | |||
| conv = nn.Conv2d(3, 3, 2, 2, groups=3, bias=False) | |||
| conv.weight = self.init_weights_constant(w1, w2, w3, w4) | |||
| conv.weight.requires_grad = False | |||
| return conv | |||
| def init_weights_constant(self, w1=0.0, w2=0.0, w3=0.0, w4=0.0): | |||
| return nn.Parameter( | |||
| torch.tensor([[[[w1, w2], [w3, w4]]], [[[w1, w2], [w3, w4]]], | |||
| [[[w1, w2], [w3, w4]]]])) | |||
| # shufflenet block | |||
| def channel_shuffle(x, groups=2): | |||
| bat_size, channels, w, h = x.shape | |||
| group_c = channels // groups | |||
| x = x.view(bat_size, groups, group_c, w, h) | |||
| x = torch.transpose(x, 1, 2).contiguous() | |||
| x = x.view(bat_size, -1, w, h) | |||
| return x | |||
| def conv_1x1_bn(in_c, out_c, stride=1): | |||
| return nn.Sequential( | |||
| nn.Conv2d(in_c, out_c, 1, stride, 0, bias=False), | |||
| nn.BatchNorm2d(out_c), nn.ReLU(True)) | |||
| def conv_bn(in_c, out_c, stride=2): | |||
| return nn.Sequential( | |||
| nn.Conv2d(in_c, out_c, 3, stride, 1, bias=False), | |||
| nn.BatchNorm2d(out_c), nn.ReLU(True)) | |||
| class ShuffleBlock(nn.Module): | |||
| def __init__(self, in_c, out_c, downsample=False): | |||
| super(ShuffleBlock, self).__init__() | |||
| self.downsample = downsample | |||
| half_c = out_c // 2 | |||
| if downsample: | |||
| self.branch1 = nn.Sequential( | |||
| # 3*3 dw conv, stride = 2 | |||
| # nn.Conv2d(in_c, in_c, 3, 2, 1, groups=in_c, bias=False), | |||
| nn.Conv2d(in_c, in_c, 3, 1, 1, groups=in_c, bias=False), | |||
| nn.BatchNorm2d(in_c), | |||
| # 1*1 pw conv | |||
| nn.Conv2d(in_c, half_c, 1, 1, 0, bias=False), | |||
| nn.BatchNorm2d(half_c), | |||
| nn.ReLU(True)) | |||
| self.branch2 = nn.Sequential( | |||
| # 1*1 pw conv | |||
| nn.Conv2d(in_c, half_c, 1, 1, 0, bias=False), | |||
| nn.BatchNorm2d(half_c), | |||
| nn.ReLU(True), | |||
| # 3*3 dw conv, stride = 2 | |||
| # nn.Conv2d(half_c, half_c, 3, 2, 1, groups=half_c, bias=False), | |||
| nn.Conv2d(half_c, half_c, 3, 1, 1, groups=half_c, bias=False), | |||
| nn.BatchNorm2d(half_c), | |||
| # 1*1 pw conv | |||
| nn.Conv2d(half_c, half_c, 1, 1, 0, bias=False), | |||
| nn.BatchNorm2d(half_c), | |||
| nn.ReLU(True)) | |||
| else: | |||
| # in_c = out_c | |||
| assert in_c == out_c | |||
| self.branch2 = nn.Sequential( | |||
| # 1*1 pw conv | |||
| nn.Conv2d(half_c, half_c, 1, 1, 0, bias=False), | |||
| nn.BatchNorm2d(half_c), | |||
| nn.ReLU(True), | |||
| # 3*3 dw conv, stride = 1 | |||
| nn.Conv2d(half_c, half_c, 3, 1, 1, groups=half_c, bias=False), | |||
| nn.BatchNorm2d(half_c), | |||
| # 1*1 pw conv | |||
| nn.Conv2d(half_c, half_c, 1, 1, 0, bias=False), | |||
| nn.BatchNorm2d(half_c), | |||
| nn.ReLU(True)) | |||
| def forward(self, x): | |||
| out = None | |||
| if self.downsample: | |||
| # if it is downsampling, we don't need to do channel split | |||
| out = torch.cat((self.branch1(x), self.branch2(x)), 1) | |||
| else: | |||
| # channel split | |||
| channels = x.shape[1] | |||
| c = channels // 2 | |||
| x1 = x[:, :c, :, :] | |||
| x2 = x[:, c:, :, :] | |||
| out = torch.cat((x1, self.branch2(x2)), 1) | |||
| return channel_shuffle(out, 2) | |||
| class ShuffleCSPLayer(nn.Module): | |||
| """C3 in yolov5, CSP Bottleneck with 3 convolutions""" | |||
| def __init__( | |||
| self, | |||
| in_channels, | |||
| out_channels, | |||
| n=1, | |||
| shortcut=True, | |||
| expansion=0.5, | |||
| depthwise=False, | |||
| act='silu', | |||
| ): | |||
| """ | |||
| Args: | |||
| in_channels (int): input channels. | |||
| out_channels (int): output channels. | |||
| n (int): number of Bottlenecks. Default value: 1. | |||
| """ | |||
| # ch_in, ch_out, number, shortcut, groups, expansion | |||
| super().__init__() | |||
| hidden_channels = int(out_channels * expansion) # hidden channels | |||
| self.conv1 = BaseConv( | |||
| in_channels, hidden_channels, 1, stride=1, act=act) | |||
| self.conv2 = BaseConv( | |||
| in_channels, hidden_channels, 1, stride=1, act=act) | |||
| module_list = [ | |||
| Bottleneck( | |||
| hidden_channels, | |||
| hidden_channels, | |||
| shortcut, | |||
| 1.0, | |||
| depthwise, | |||
| act=act) for _ in range(n) | |||
| ] | |||
| self.m = nn.Sequential(*module_list) | |||
| def forward(self, x): | |||
| x_1 = self.conv1(x) | |||
| x_2 = self.conv2(x) | |||
| x_1 = self.m(x_1) | |||
| x = torch.cat((x_1, x_2), dim=1) | |||
| # add channel shuffle | |||
| return channel_shuffle(x, 2) | |||
| @@ -0,0 +1,324 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. | |||
| import numpy as np | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| class Swish(nn.Module): | |||
| def __init__(self, inplace=True): | |||
| super(Swish, self).__init__() | |||
| self.inplace = inplace | |||
| def forward(self, x): | |||
| if self.inplace: | |||
| x.mul_(F.sigmoid(x)) | |||
| return x | |||
| else: | |||
| return x * F.sigmoid(x) | |||
| def get_activation(name='silu', inplace=True): | |||
| if name is None: | |||
| return nn.Identity() | |||
| if isinstance(name, str): | |||
| if name == 'silu': | |||
| module = nn.SiLU(inplace=inplace) | |||
| elif name == 'relu': | |||
| module = nn.ReLU(inplace=inplace) | |||
| elif name == 'lrelu': | |||
| module = nn.LeakyReLU(0.1, inplace=inplace) | |||
| elif name == 'swish': | |||
| module = Swish(inplace=inplace) | |||
| elif name == 'hardsigmoid': | |||
| module = nn.Hardsigmoid(inplace=inplace) | |||
| else: | |||
| raise AttributeError('Unsupported act type: {}'.format(name)) | |||
| return module | |||
| elif isinstance(name, nn.Module): | |||
| return name | |||
| else: | |||
| raise AttributeError('Unsupported act type: {}'.format(name)) | |||
| class ConvBNLayer(nn.Module): | |||
| def __init__(self, | |||
| ch_in, | |||
| ch_out, | |||
| filter_size=3, | |||
| stride=1, | |||
| groups=1, | |||
| padding=0, | |||
| act=None): | |||
| super(ConvBNLayer, self).__init__() | |||
| self.conv = nn.Conv2d( | |||
| in_channels=ch_in, | |||
| out_channels=ch_out, | |||
| kernel_size=filter_size, | |||
| stride=stride, | |||
| padding=padding, | |||
| groups=groups, | |||
| bias=False) | |||
| self.bn = nn.BatchNorm2d(ch_out, ) | |||
| self.act = get_activation(act, inplace=True) | |||
| def forward(self, x): | |||
| x = self.conv(x) | |||
| x = self.bn(x) | |||
| x = self.act(x) | |||
| return x | |||
| class RepVGGBlock(nn.Module): | |||
| def __init__(self, ch_in, ch_out, act='relu', deploy=False): | |||
| super(RepVGGBlock, self).__init__() | |||
| self.ch_in = ch_in | |||
| self.ch_out = ch_out | |||
| self.deploy = deploy | |||
| self.in_channels = ch_in | |||
| self.groups = 1 | |||
| if self.deploy is False: | |||
| self.rbr_dense = ConvBNLayer( | |||
| ch_in, ch_out, 3, stride=1, padding=1, act=None) | |||
| self.rbr_1x1 = ConvBNLayer( | |||
| ch_in, ch_out, 1, stride=1, padding=0, act=None) | |||
| # self.rbr_identity = nn.BatchNorm2d(num_features=ch_in) if ch_out == ch_in else None | |||
| self.rbr_identity = None | |||
| else: | |||
| self.rbr_reparam = nn.Conv2d( | |||
| in_channels=self.ch_in, | |||
| out_channels=self.ch_out, | |||
| kernel_size=3, | |||
| stride=1, | |||
| padding=1, | |||
| groups=1) | |||
| self.act = get_activation(act) if act is None or isinstance( | |||
| act, (str, dict)) else act | |||
| def forward(self, x): | |||
| if self.deploy: | |||
| print('----------deploy----------') | |||
| y = self.rbr_reparam(x) | |||
| else: | |||
| if self.rbr_identity is None: | |||
| y = self.rbr_dense(x) + self.rbr_1x1(x) | |||
| else: | |||
| y = self.rbr_dense(x) + self.rbr_1x1(x) + self.rbr_identity(x) | |||
| y = self.act(y) | |||
| return y | |||
| def switch_to_deploy(self): | |||
| print('switch') | |||
| if not hasattr(self, 'rbr_reparam'): | |||
| # return | |||
| self.rbr_reparam = nn.Conv2d( | |||
| in_channels=self.ch_in, | |||
| out_channels=self.ch_out, | |||
| kernel_size=3, | |||
| stride=1, | |||
| padding=1, | |||
| groups=1) | |||
| print('switch') | |||
| kernel, bias = self.get_equivalent_kernel_bias() | |||
| self.rbr_reparam.weight.data = kernel | |||
| self.rbr_reparam.bias.data = bias | |||
| for para in self.parameters(): | |||
| para.detach_() | |||
| # self.__delattr__(self.rbr_dense) | |||
| # self.__delattr__(self.rbr_1x1) | |||
| self.__delattr__('rbr_dense') | |||
| self.__delattr__('rbr_1x1') | |||
| if hasattr(self, 'rbr_identity'): | |||
| self.__delattr__('rbr_identity') | |||
| if hasattr(self, 'id_tensor'): | |||
| self.__delattr__('id_tensor') | |||
| self.deploy = True | |||
| def get_equivalent_kernel_bias(self): | |||
| kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense) | |||
| kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1) | |||
| kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity) | |||
| return kernel3x3 + self._pad_1x1_to_3x3_tensor( | |||
| kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid | |||
| def _pad_1x1_to_3x3_tensor(self, kernel1x1): | |||
| if kernel1x1 is None: | |||
| return 0 | |||
| else: | |||
| return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1]) | |||
| def _fuse_bn_tensor(self, branch): | |||
| if branch is None: | |||
| return 0, 0 | |||
| # if isinstance(branch, nn.Sequential): | |||
| if isinstance(branch, ConvBNLayer): | |||
| kernel = branch.conv.weight | |||
| running_mean = branch.bn.running_mean | |||
| running_var = branch.bn.running_var | |||
| gamma = branch.bn.weight | |||
| beta = branch.bn.bias | |||
| eps = branch.bn.eps | |||
| else: | |||
| assert isinstance(branch, nn.BatchNorm2d) | |||
| if not hasattr(self, 'id_tensor'): | |||
| input_dim = self.in_channels // self.groups | |||
| kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), | |||
| dtype=np.float32) | |||
| for i in range(self.in_channels): | |||
| kernel_value[i, i % input_dim, 1, 1] = 1 | |||
| self.id_tensor = torch.from_numpy(kernel_value).to( | |||
| branch.weight.device) | |||
| kernel = self.id_tensor | |||
| running_mean = branch.running_mean | |||
| running_var = branch.running_var | |||
| gamma = branch.weight | |||
| beta = branch.bias | |||
| eps = branch.eps | |||
| std = (running_var + eps).sqrt() | |||
| t = (gamma / std).reshape(-1, 1, 1, 1) | |||
| return kernel * t, beta - running_mean * gamma / std | |||
| class BasicBlock(nn.Module): | |||
| def __init__(self, ch_in, ch_out, act='relu', shortcut=True): | |||
| super(BasicBlock, self).__init__() | |||
| assert ch_in == ch_out | |||
| # self.conv1 = ConvBNLayer(ch_in, ch_out, 3, stride=1, padding=1, act=act) | |||
| # self.conv1 = ConvBNLayer(ch_in, ch_out, 1, stride=1, padding=0, act=act) | |||
| self.conv2 = RepVGGBlock(ch_in, ch_out, act=act) | |||
| self.shortcut = shortcut | |||
| def forward(self, x): | |||
| # y = self.conv1(x) | |||
| y = self.conv2(x) | |||
| if self.shortcut: | |||
| return x + y | |||
| else: | |||
| return y | |||
| class BasicBlock_3x3(nn.Module): | |||
| def __init__(self, ch_in, ch_out, act='relu', shortcut=True): | |||
| super(BasicBlock_3x3, self).__init__() | |||
| assert ch_in == ch_out | |||
| self.conv1 = ConvBNLayer( | |||
| ch_in, ch_out, 3, stride=1, padding=1, act=act) | |||
| # self.conv1 = ConvBNLayer(ch_in, ch_out, 1, stride=1, padding=0, act=act) | |||
| self.conv2 = RepVGGBlock(ch_in, ch_out, act=act) | |||
| self.shortcut = shortcut | |||
| def forward(self, x): | |||
| y = self.conv1(x) | |||
| y = self.conv2(y) | |||
| if self.shortcut: | |||
| return x + y | |||
| else: | |||
| return y | |||
| class BasicBlock_3x3_Reverse(nn.Module): | |||
| def __init__(self, ch_in, ch_out, act='relu', shortcut=True): | |||
| super(BasicBlock_3x3_Reverse, self).__init__() | |||
| assert ch_in == ch_out | |||
| self.conv1 = ConvBNLayer( | |||
| ch_in, ch_out, 3, stride=1, padding=1, act=act) | |||
| # self.conv1 = ConvBNLayer(ch_in, ch_out, 1, stride=1, padding=0, act=act) | |||
| self.conv2 = RepVGGBlock(ch_in, ch_out, act=act) | |||
| self.shortcut = shortcut | |||
| def forward(self, x): | |||
| y = self.conv2(x) | |||
| y = self.conv1(y) | |||
| if self.shortcut: | |||
| return x + y | |||
| else: | |||
| return y | |||
| class SPP(nn.Module): | |||
| def __init__( | |||
| self, | |||
| ch_in, | |||
| ch_out, | |||
| k, | |||
| pool_size, | |||
| act='swish', | |||
| ): | |||
| super(SPP, self).__init__() | |||
| self.pool = [] | |||
| for i, size in enumerate(pool_size): | |||
| pool = nn.MaxPool2d( | |||
| kernel_size=size, stride=1, padding=size // 2, ceil_mode=False) | |||
| self.add_module('pool{}'.format(i), pool) | |||
| self.pool.append(pool) | |||
| self.conv = ConvBNLayer(ch_in, ch_out, k, padding=k // 2, act=act) | |||
| def forward(self, x): | |||
| outs = [x] | |||
| for pool in self.pool: | |||
| outs.append(pool(x)) | |||
| y = torch.cat(outs, axis=1) | |||
| y = self.conv(y) | |||
| return y | |||
| class CSPStage(nn.Module): | |||
| def __init__(self, block_fn, ch_in, ch_out, n, act='swish', spp=False): | |||
| super(CSPStage, self).__init__() | |||
| ch_mid = int(ch_out // 2) | |||
| self.conv1 = ConvBNLayer(ch_in, ch_mid, 1, act=act) | |||
| self.conv2 = ConvBNLayer(ch_in, ch_mid, 1, act=act) | |||
| # self.conv2 = ConvBNLayer(ch_in, ch_mid, 3, stride=1, padding=1, act=act) | |||
| self.convs = nn.Sequential() | |||
| next_ch_in = ch_mid | |||
| for i in range(n): | |||
| if block_fn == 'BasicBlock': | |||
| self.convs.add_module( | |||
| str(i), | |||
| BasicBlock(next_ch_in, ch_mid, act=act, shortcut=False)) | |||
| elif block_fn == 'BasicBlock_3x3': | |||
| self.convs.add_module( | |||
| str(i), | |||
| BasicBlock_3x3(next_ch_in, ch_mid, act=act, shortcut=True)) | |||
| elif block_fn == 'BasicBlock_3x3_Reverse': | |||
| self.convs.add_module( | |||
| str(i), | |||
| BasicBlock_3x3_Reverse( | |||
| next_ch_in, ch_mid, act=act, shortcut=True)) | |||
| else: | |||
| raise NotImplementedError | |||
| if i == (n - 1) // 2 and spp: | |||
| self.convs.add_module( | |||
| 'spp', SPP(ch_mid * 4, ch_mid, 1, [5, 9, 13], act=act)) | |||
| next_ch_in = ch_mid | |||
| # self.convs = nn.Sequential(*convs) | |||
| self.conv3 = ConvBNLayer(ch_mid * (n + 1), ch_out, 1, act=act) | |||
| def forward(self, x): | |||
| y1 = self.conv1(x) | |||
| y2 = self.conv2(x) | |||
| mid_out = [y1] | |||
| for conv in self.convs: | |||
| y2 = conv(y2) | |||
| mid_out.append(y2) | |||
| y = torch.cat(mid_out, axis=1) | |||
| y = self.conv3(y) | |||
| return y | |||
| @@ -0,0 +1,205 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. | |||
| import numpy as np | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| import torch.nn.init as init | |||
| from torch.nn.parameter import Parameter | |||
| def get_activation(name='silu', inplace=True): | |||
| if name == 'silu': | |||
| module = nn.SiLU(inplace=inplace) | |||
| elif name == 'relu': | |||
| module = nn.ReLU(inplace=inplace) | |||
| elif name == 'lrelu': | |||
| module = nn.LeakyReLU(0.1, inplace=inplace) | |||
| elif name == 'identity': | |||
| module = nn.Identity() | |||
| else: | |||
| raise AttributeError('Unsupported act type: {}'.format(name)) | |||
| return module | |||
| def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups=1): | |||
| '''Basic cell for rep-style block, including conv and bn''' | |||
| result = nn.Sequential() | |||
| result.add_module( | |||
| 'conv', | |||
| nn.Conv2d( | |||
| in_channels=in_channels, | |||
| out_channels=out_channels, | |||
| kernel_size=kernel_size, | |||
| stride=stride, | |||
| padding=padding, | |||
| groups=groups, | |||
| bias=False)) | |||
| result.add_module('bn', nn.BatchNorm2d(num_features=out_channels)) | |||
| return result | |||
| class RepVggBlock(nn.Module): | |||
| '''RepVggBlock is a basic rep-style block, including training and deploy status | |||
| This code is based on https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py | |||
| ''' | |||
| def __init__(self, | |||
| in_channels, | |||
| out_channels, | |||
| kernel_size=3, | |||
| stride=1, | |||
| padding=1, | |||
| dilation=1, | |||
| groups=1, | |||
| padding_mode='zeros', | |||
| deploy=False, | |||
| use_se=False, | |||
| act='relu', | |||
| norm=None): | |||
| super(RepVggBlock, self).__init__() | |||
| """ Initialization of the class. | |||
| Args: | |||
| in_channels (int): Number of channels in the input image | |||
| out_channels (int): Number of channels produced by the convolution | |||
| kernel_size (int or tuple): Size of the convolving kernel | |||
| stride (int or tuple, optional): Stride of the convolution. Default: 1 | |||
| padding (int or tuple, optional): Zero-padding added to both sides of | |||
| the input. Default: 1 | |||
| dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 | |||
| groups (int, optional): Number of blocked connections from input | |||
| channels to output channels. Default: 1 | |||
| padding_mode (string, optional): Default: 'zeros' | |||
| deploy: Whether to be deploy status or training status. Default: False | |||
| use_se: Whether to use se. Default: False | |||
| """ | |||
| self.deploy = deploy | |||
| self.groups = groups | |||
| self.in_channels = in_channels | |||
| self.out_channels = out_channels | |||
| assert kernel_size == 3 | |||
| assert padding == 1 | |||
| padding_11 = padding - kernel_size // 2 | |||
| if isinstance(act, str): | |||
| self.nonlinearity = get_activation(act) | |||
| else: | |||
| self.nonlinearity = act | |||
| if use_se: | |||
| raise NotImplementedError('se block not supported yet') | |||
| else: | |||
| self.se = nn.Identity() | |||
| if deploy: | |||
| self.rbr_reparam = nn.Conv2d( | |||
| in_channels=in_channels, | |||
| out_channels=out_channels, | |||
| kernel_size=kernel_size, | |||
| stride=stride, | |||
| padding=padding, | |||
| dilation=dilation, | |||
| groups=groups, | |||
| bias=True, | |||
| padding_mode=padding_mode) | |||
| else: | |||
| self.rbr_identity = None | |||
| self.rbr_dense = conv_bn( | |||
| in_channels=in_channels, | |||
| out_channels=out_channels, | |||
| kernel_size=kernel_size, | |||
| stride=stride, | |||
| padding=padding, | |||
| groups=groups) | |||
| self.rbr_1x1 = conv_bn( | |||
| in_channels=in_channels, | |||
| out_channels=out_channels, | |||
| kernel_size=1, | |||
| stride=stride, | |||
| padding=padding_11, | |||
| groups=groups) | |||
| def forward(self, inputs): | |||
| '''Forward process''' | |||
| if hasattr(self, 'rbr_reparam'): | |||
| return self.nonlinearity(self.se(self.rbr_reparam(inputs))) | |||
| if self.rbr_identity is None: | |||
| id_out = 0 | |||
| else: | |||
| id_out = self.rbr_identity(inputs) | |||
| return self.nonlinearity( | |||
| self.se(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out)) | |||
| def get_equivalent_kernel_bias(self): | |||
| kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense) | |||
| kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1) | |||
| kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity) | |||
| return kernel3x3 + self._pad_1x1_to_3x3_tensor( | |||
| kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid | |||
| def _pad_1x1_to_3x3_tensor(self, kernel1x1): | |||
| if kernel1x1 is None: | |||
| return 0 | |||
| else: | |||
| return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1]) | |||
| def _fuse_bn_tensor(self, branch): | |||
| if branch is None: | |||
| return 0, 0 | |||
| if isinstance(branch, nn.Sequential): | |||
| kernel = branch.conv.weight | |||
| running_mean = branch.bn.running_mean | |||
| running_var = branch.bn.running_var | |||
| gamma = branch.bn.weight | |||
| beta = branch.bn.bias | |||
| eps = branch.bn.eps | |||
| else: | |||
| assert isinstance(branch, nn.BatchNorm2d) | |||
| if not hasattr(self, 'id_tensor'): | |||
| input_dim = self.in_channels // self.groups | |||
| kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), | |||
| dtype=np.float32) | |||
| for i in range(self.in_channels): | |||
| kernel_value[i, i % input_dim, 1, 1] = 1 | |||
| self.id_tensor = torch.from_numpy(kernel_value).to( | |||
| branch.weight.device) | |||
| kernel = self.id_tensor | |||
| running_mean = branch.running_mean | |||
| running_var = branch.running_var | |||
| gamma = branch.weight | |||
| beta = branch.bias | |||
| eps = branch.eps | |||
| std = (running_var + eps).sqrt() | |||
| t = (gamma / std).reshape(-1, 1, 1, 1) | |||
| return kernel * t, beta - running_mean * gamma / std | |||
| def switch_to_deploy(self): | |||
| if hasattr(self, 'rbr_reparam'): | |||
| return | |||
| kernel, bias = self.get_equivalent_kernel_bias() | |||
| self.rbr_reparam = nn.Conv2d( | |||
| in_channels=self.rbr_dense.conv.in_channels, | |||
| out_channels=self.rbr_dense.conv.out_channels, | |||
| kernel_size=self.rbr_dense.conv.kernel_size, | |||
| stride=self.rbr_dense.conv.stride, | |||
| padding=self.rbr_dense.conv.padding, | |||
| dilation=self.rbr_dense.conv.dilation, | |||
| groups=self.rbr_dense.conv.groups, | |||
| bias=True) | |||
| self.rbr_reparam.weight.data = kernel | |||
| self.rbr_reparam.bias.data = bias | |||
| for para in self.parameters(): | |||
| para.detach_() | |||
| self.__delattr__('rbr_dense') | |||
| self.__delattr__('rbr_1x1') | |||
| if hasattr(self, 'rbr_identity'): | |||
| self.__delattr__('rbr_identity') | |||
| if hasattr(self, 'id_tensor'): | |||
| self.__delattr__('id_tensor') | |||
| self.deploy = True | |||
| @@ -0,0 +1,196 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. | |||
| import numpy as np | |||
| import torch | |||
| import torchvision | |||
| __all__ = [ | |||
| 'filter_box', | |||
| 'postprocess_airdet', | |||
| 'bboxes_iou', | |||
| 'matrix_iou', | |||
| 'adjust_box_anns', | |||
| 'xyxy2xywh', | |||
| 'xyxy2cxcywh', | |||
| ] | |||
| def multiclass_nms(multi_bboxes, | |||
| multi_scores, | |||
| score_thr, | |||
| iou_thr, | |||
| max_num=100, | |||
| score_factors=None): | |||
| """NMS for multi-class bboxes. | |||
| Args: | |||
| multi_bboxes (Tensor): shape (n, #class*4) or (n, 4) | |||
| multi_scores (Tensor): shape (n, #class), where the last column | |||
| contains scores of the background class, but this will be ignored. | |||
| score_thr (float): bbox threshold, bboxes with scores lower than it | |||
| will not be considered. | |||
| nms_thr (float): NMS IoU threshold | |||
| max_num (int): if there are more than max_num bboxes after NMS, | |||
| only top max_num will be kept. | |||
| score_factors (Tensor): The factors multiplied to scores before | |||
| applying NMS | |||
| Returns: | |||
| tuple: (bboxes, labels), tensors of shape (k, 5) and (k, 1). Labels \ | |||
| are 0-based. | |||
| """ | |||
| num_classes = multi_scores.size(1) | |||
| # exclude background category | |||
| if multi_bboxes.shape[1] > 4: | |||
| bboxes = multi_bboxes.view(multi_scores.size(0), -1, 4) | |||
| else: | |||
| bboxes = multi_bboxes[:, None].expand( | |||
| multi_scores.size(0), num_classes, 4) | |||
| scores = multi_scores | |||
| # filter out boxes with low scores | |||
| valid_mask = scores > score_thr # 1000 * 80 bool | |||
| # We use masked_select for ONNX exporting purpose, | |||
| # which is equivalent to bboxes = bboxes[valid_mask] | |||
| # (TODO): as ONNX does not support repeat now, | |||
| # we have to use this ugly code | |||
| # bboxes -> 1000, 4 | |||
| bboxes = torch.masked_select( | |||
| bboxes, | |||
| torch.stack((valid_mask, valid_mask, valid_mask, valid_mask), | |||
| -1)).view(-1, 4) # mask-> 1000*80*4, 80000*4 | |||
| if score_factors is not None: | |||
| scores = scores * score_factors[:, None] | |||
| scores = torch.masked_select(scores, valid_mask) | |||
| labels = valid_mask.nonzero(as_tuple=False)[:, 1] | |||
| if bboxes.numel() == 0: | |||
| bboxes = multi_bboxes.new_zeros((0, 5)) | |||
| labels = multi_bboxes.new_zeros((0, ), dtype=torch.long) | |||
| scores = multi_bboxes.new_zeros((0, )) | |||
| return bboxes, scores, labels | |||
| keep = torchvision.ops.batched_nms(bboxes, scores, labels, iou_thr) | |||
| if max_num > 0: | |||
| keep = keep[:max_num] | |||
| return bboxes[keep], scores[keep], labels[keep] | |||
| def filter_box(output, scale_range): | |||
| """ | |||
| output: (N, 5+class) shape | |||
| """ | |||
| min_scale, max_scale = scale_range | |||
| w = output[:, 2] - output[:, 0] | |||
| h = output[:, 3] - output[:, 1] | |||
| keep = (w * h > min_scale * min_scale) & (w * h < max_scale * max_scale) | |||
| return output[keep] | |||
| def filter_results(boxlist, num_classes, nms_thre): | |||
| boxes = boxlist.bbox | |||
| scores = boxlist.get_field('scores') | |||
| cls = boxlist.get_field('labels') | |||
| nms_out_index = torchvision.ops.batched_nms( | |||
| boxes, | |||
| scores, | |||
| cls, | |||
| nms_thre, | |||
| ) | |||
| boxlist = boxlist[nms_out_index] | |||
| return boxlist | |||
| def postprocess_airdet(prediction, | |||
| num_classes, | |||
| conf_thre=0.7, | |||
| nms_thre=0.45, | |||
| imgs=None): | |||
| box_corner = prediction.new(prediction.shape) | |||
| box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2 | |||
| box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2 | |||
| box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2 | |||
| box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2 | |||
| prediction[:, :, :4] = box_corner[:, :, :4] | |||
| output = [None for _ in range(len(prediction))] | |||
| for i, image_pred in enumerate(prediction): | |||
| # If none are remaining => process next image | |||
| if not image_pred.size(0): | |||
| continue | |||
| multi_bboxes = image_pred[:, :4] | |||
| multi_scores = image_pred[:, 5:] | |||
| detections, scores, labels = multiclass_nms(multi_bboxes, multi_scores, | |||
| conf_thre, nms_thre, 500) | |||
| detections = torch.cat( | |||
| (detections, scores[:, None], scores[:, None], labels[:, None]), | |||
| dim=1) | |||
| if output[i] is None: | |||
| output[i] = detections | |||
| else: | |||
| output[i] = torch.cat((output[i], detections)) | |||
| return output | |||
| def bboxes_iou(bboxes_a, bboxes_b, xyxy=True): | |||
| if bboxes_a.shape[1] != 4 or bboxes_b.shape[1] != 4: | |||
| raise IndexError | |||
| if xyxy: | |||
| tl = torch.max(bboxes_a[:, None, :2], bboxes_b[:, :2]) | |||
| br = torch.min(bboxes_a[:, None, 2:], bboxes_b[:, 2:]) | |||
| area_a = torch.prod(bboxes_a[:, 2:] - bboxes_a[:, :2], 1) | |||
| area_b = torch.prod(bboxes_b[:, 2:] - bboxes_b[:, :2], 1) | |||
| else: | |||
| tl = torch.max( | |||
| (bboxes_a[:, None, :2] - bboxes_a[:, None, 2:] / 2), | |||
| (bboxes_b[:, :2] - bboxes_b[:, 2:] / 2), | |||
| ) | |||
| br = torch.min( | |||
| (bboxes_a[:, None, :2] + bboxes_a[:, None, 2:] / 2), | |||
| (bboxes_b[:, :2] + bboxes_b[:, 2:] / 2), | |||
| ) | |||
| area_a = torch.prod(bboxes_a[:, 2:], 1) | |||
| area_b = torch.prod(bboxes_b[:, 2:], 1) | |||
| en = (tl < br).type(tl.type()).prod(dim=2) | |||
| area_i = torch.prod(br - tl, 2) * en # * ((tl < br).all()) | |||
| return area_i / (area_a[:, None] + area_b - area_i) | |||
| def matrix_iou(a, b): | |||
| """ | |||
| return iou of a and b, numpy version for data augenmentation | |||
| """ | |||
| lt = np.maximum(a[:, np.newaxis, :2], b[:, :2]) | |||
| rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:]) | |||
| area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2) | |||
| area_a = np.prod(a[:, 2:] - a[:, :2], axis=1) | |||
| area_b = np.prod(b[:, 2:] - b[:, :2], axis=1) | |||
| return area_i / (area_a[:, np.newaxis] + area_b - area_i + 1e-12) | |||
| def adjust_box_anns(bbox, scale_ratio, padw, padh, w_max, h_max): | |||
| bbox[:, 0::2] = np.clip(bbox[:, 0::2] * scale_ratio + padw, 0, w_max) | |||
| bbox[:, 1::2] = np.clip(bbox[:, 1::2] * scale_ratio + padh, 0, h_max) | |||
| return bbox | |||
| def xyxy2xywh(bboxes): | |||
| bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 0] | |||
| bboxes[:, 3] = bboxes[:, 3] - bboxes[:, 1] | |||
| return bboxes | |||
| def xyxy2cxcywh(bboxes): | |||
| bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 0] | |||
| bboxes[:, 3] = bboxes[:, 3] - bboxes[:, 1] | |||
| bboxes[:, 0] = bboxes[:, 0] + bboxes[:, 2] * 0.5 | |||
| bboxes[:, 1] = bboxes[:, 1] + bboxes[:, 3] * 0.5 | |||
| return bboxes | |||
| @@ -0,0 +1,181 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. | |||
| import os.path as osp | |||
| import pickle | |||
| import cv2 | |||
| import torch | |||
| import torchvision | |||
| from modelscope.metainfo import Models | |||
| from modelscope.models.base.base_torch_model import TorchModel | |||
| from modelscope.models.builder import MODELS | |||
| from modelscope.utils.config import Config | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| from .backbone import build_backbone | |||
| from .head import build_head | |||
| from .neck import build_neck | |||
| from .utils import parse_config | |||
| class SingleStageDetector(TorchModel): | |||
| """ | |||
| The base class of single stage detector. | |||
| """ | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| """ | |||
| init model by cfg | |||
| """ | |||
| super().__init__(model_dir, *args, **kwargs) | |||
| config_path = osp.join(model_dir, 'airdet_s.py') | |||
| config = parse_config(config_path) | |||
| self.cfg = config | |||
| model_path = osp.join(model_dir, config.model.name) | |||
| label_map = osp.join(model_dir, config.model.class_map) | |||
| self.label_map = pickle.load(open(label_map, 'rb')) | |||
| self.size_divisible = config.dataset.size_divisibility | |||
| self.num_classes = config.model.head.num_classes | |||
| self.conf_thre = config.model.head.nms_conf_thre | |||
| self.nms_thre = config.model.head.nms_iou_thre | |||
| self.backbone = build_backbone(self.cfg.model.backbone) | |||
| self.neck = build_neck(self.cfg.model.neck) | |||
| self.head = build_head(self.cfg.model.head) | |||
| self.load_pretrain_model(model_path) | |||
| def load_pretrain_model(self, pretrain_model): | |||
| state_dict = torch.load(pretrain_model, map_location='cpu')['model'] | |||
| new_state_dict = {} | |||
| for k, v in state_dict.items(): | |||
| k = k.replace('module.', '') | |||
| new_state_dict[k] = v | |||
| self.load_state_dict(new_state_dict, strict=True) | |||
| def inference(self, x): | |||
| if self.training: | |||
| return self.forward_train(x) | |||
| else: | |||
| return self.forward_eval(x) | |||
| def forward_train(self, x): | |||
| pass | |||
| def forward_eval(self, x): | |||
| x = self.backbone(x) | |||
| x = self.neck(x) | |||
| prediction = self.head(x) | |||
| return prediction | |||
| def preprocess(self, image): | |||
| image = torch.from_numpy(image).type(torch.float32) | |||
| image = image.permute(2, 0, 1) | |||
| shape = image.shape # c, h, w | |||
| if self.size_divisible > 0: | |||
| import math | |||
| stride = self.size_divisible | |||
| shape = list(shape) | |||
| shape[1] = int(math.ceil(shape[1] / stride) * stride) | |||
| shape[2] = int(math.ceil(shape[2] / stride) * stride) | |||
| shape = tuple(shape) | |||
| pad_img = image.new(*shape).zero_() | |||
| pad_img[:, :image.shape[1], :image.shape[2]].copy_(image) | |||
| pad_img = pad_img.unsqueeze(0) | |||
| return pad_img | |||
| def postprocess(self, preds): | |||
| bboxes, scores, labels_idx = postprocess_gfocal( | |||
| preds, self.num_classes, self.conf_thre, self.nms_thre) | |||
| bboxes = bboxes.cpu().numpy() | |||
| scores = scores.cpu().numpy() | |||
| labels_idx = labels_idx.cpu().numpy() | |||
| labels = [self.label_map[idx + 1][0]['name'] for idx in labels_idx] | |||
| return (bboxes, scores, labels) | |||
| def multiclass_nms(multi_bboxes, | |||
| multi_scores, | |||
| score_thr, | |||
| iou_thr, | |||
| max_num=100, | |||
| score_factors=None): | |||
| """NMS for multi-class bboxes. | |||
| Args: | |||
| multi_bboxes (Tensor): shape (n, #class*4) or (n, 4) | |||
| multi_scores (Tensor): shape (n, #class), where the last column | |||
| contains scores of the background class, but this will be ignored. | |||
| score_thr (float): bbox threshold, bboxes with scores lower than it | |||
| will not be considered. | |||
| nms_thr (float): NMS IoU threshold | |||
| max_num (int): if there are more than max_num bboxes after NMS, | |||
| only top max_num will be kept. | |||
| score_factors (Tensor): The factors multiplied to scores before | |||
| applying NMS | |||
| Returns: | |||
| tuple: (bboxes, labels), tensors of shape (k, 5) and (k, 1). Labels \ | |||
| are 0-based. | |||
| """ | |||
| num_classes = multi_scores.size(1) | |||
| # exclude background category | |||
| if multi_bboxes.shape[1] > 4: | |||
| bboxes = multi_bboxes.view(multi_scores.size(0), -1, 4) | |||
| else: | |||
| bboxes = multi_bboxes[:, None].expand( | |||
| multi_scores.size(0), num_classes, 4) | |||
| scores = multi_scores | |||
| # filter out boxes with low scores | |||
| valid_mask = scores > score_thr # 1000 * 80 bool | |||
| # We use masked_select for ONNX exporting purpose, | |||
| # which is equivalent to bboxes = bboxes[valid_mask] | |||
| # (TODO): as ONNX does not support repeat now, | |||
| # we have to use this ugly code | |||
| # bboxes -> 1000, 4 | |||
| bboxes = torch.masked_select( | |||
| bboxes, | |||
| torch.stack((valid_mask, valid_mask, valid_mask, valid_mask), | |||
| -1)).view(-1, 4) # mask-> 1000*80*4, 80000*4 | |||
| if score_factors is not None: | |||
| scores = scores * score_factors[:, None] | |||
| scores = torch.masked_select(scores, valid_mask) | |||
| labels = valid_mask.nonzero(as_tuple=False)[:, 1] | |||
| if bboxes.numel() == 0: | |||
| bboxes = multi_bboxes.new_zeros((0, 5)) | |||
| labels = multi_bboxes.new_zeros((0, ), dtype=torch.long) | |||
| scores = multi_bboxes.new_zeros((0, )) | |||
| return bboxes, scores, labels | |||
| keep = torchvision.ops.batched_nms(bboxes, scores, labels, iou_thr) | |||
| if max_num > 0: | |||
| keep = keep[:max_num] | |||
| return bboxes[keep], scores[keep], labels[keep] | |||
| def postprocess_gfocal(prediction, num_classes, conf_thre=0.05, nms_thre=0.7): | |||
| assert prediction.shape[0] == 1 | |||
| for i, image_pred in enumerate(prediction): | |||
| # If none are remaining => process next image | |||
| if not image_pred.size(0): | |||
| continue | |||
| multi_bboxes = image_pred[:, :4] | |||
| multi_scores = image_pred[:, 4:] | |||
| detections, scores, labels = multiclass_nms(multi_bboxes, multi_scores, | |||
| conf_thre, nms_thre, 500) | |||
| return detections, scores, labels | |||
| @@ -0,0 +1,16 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. | |||
| import copy | |||
| from .gfocal_v2_tiny import GFocalHead_Tiny | |||
| def build_head(cfg): | |||
| head_cfg = copy.deepcopy(cfg) | |||
| name = head_cfg.pop('name') | |||
| if name == 'GFocalV2': | |||
| return GFocalHead_Tiny(**head_cfg) | |||
| else: | |||
| raise NotImplementedError | |||
| @@ -0,0 +1,361 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. | |||
| import functools | |||
| from functools import partial | |||
| import numpy as np | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| from ..core.base_ops import BaseConv, DWConv | |||
| class Scale(nn.Module): | |||
| def __init__(self, scale=1.0): | |||
| super(Scale, self).__init__() | |||
| self.scale = nn.Parameter(torch.tensor(scale, dtype=torch.float)) | |||
| def forward(self, x): | |||
| return x * self.scale | |||
| def multi_apply(func, *args, **kwargs): | |||
| pfunc = partial(func, **kwargs) if kwargs else func | |||
| map_results = map(pfunc, *args) | |||
| return tuple(map(list, zip(*map_results))) | |||
| def xyxy2CxCywh(xyxy, size=None): | |||
| x1 = xyxy[..., 0] | |||
| y1 = xyxy[..., 1] | |||
| x2 = xyxy[..., 2] | |||
| y2 = xyxy[..., 3] | |||
| cx = (x1 + x2) / 2 | |||
| cy = (y1 + y2) / 2 | |||
| w = x2 - x1 | |||
| h = y2 - y1 | |||
| if size is not None: | |||
| w = w.clamp(min=0, max=size[1]) | |||
| h = h.clamp(min=0, max=size[0]) | |||
| return torch.stack([cx, cy, w, h], axis=-1) | |||
| def distance2bbox(points, distance, max_shape=None): | |||
| """Decode distance prediction to bounding box. | |||
| """ | |||
| x1 = points[..., 0] - distance[..., 0] | |||
| y1 = points[..., 1] - distance[..., 1] | |||
| x2 = points[..., 0] + distance[..., 2] | |||
| y2 = points[..., 1] + distance[..., 3] | |||
| if max_shape is not None: | |||
| x1 = x1.clamp(min=0, max=max_shape[1]) | |||
| y1 = y1.clamp(min=0, max=max_shape[0]) | |||
| x2 = x2.clamp(min=0, max=max_shape[1]) | |||
| y2 = y2.clamp(min=0, max=max_shape[0]) | |||
| return torch.stack([x1, y1, x2, y2], -1) | |||
| def bbox2distance(points, bbox, max_dis=None, eps=0.1): | |||
| """Decode bounding box based on distances. | |||
| """ | |||
| left = points[:, 0] - bbox[:, 0] | |||
| top = points[:, 1] - bbox[:, 1] | |||
| right = bbox[:, 2] - points[:, 0] | |||
| bottom = bbox[:, 3] - points[:, 1] | |||
| if max_dis is not None: | |||
| left = left.clamp(min=0, max=max_dis - eps) | |||
| top = top.clamp(min=0, max=max_dis - eps) | |||
| right = right.clamp(min=0, max=max_dis - eps) | |||
| bottom = bottom.clamp(min=0, max=max_dis - eps) | |||
| return torch.stack([left, top, right, bottom], -1) | |||
| class Integral(nn.Module): | |||
| """A fixed layer for calculating integral result from distribution. | |||
| """ | |||
| def __init__(self, reg_max=16): | |||
| super(Integral, self).__init__() | |||
| self.reg_max = reg_max | |||
| self.register_buffer('project', | |||
| torch.linspace(0, self.reg_max, self.reg_max + 1)) | |||
| def forward(self, x): | |||
| """Forward feature from the regression head to get integral result of | |||
| bounding box location. | |||
| """ | |||
| shape = x.size() | |||
| x = F.softmax(x.reshape(*shape[:-1], 4, self.reg_max + 1), dim=-1) | |||
| b, nb, ne, _ = x.size() | |||
| x = x.reshape(b * nb * ne, self.reg_max + 1) | |||
| y = self.project.type_as(x).unsqueeze(1) | |||
| x = torch.matmul(x, y).reshape(b, nb, 4) | |||
| return x | |||
| class GFocalHead_Tiny(nn.Module): | |||
| """Ref to Generalized Focal Loss V2: Learning Reliable Localization Quality | |||
| Estimation for Dense Object Detection. | |||
| """ | |||
| def __init__( | |||
| self, | |||
| num_classes, | |||
| in_channels, | |||
| stacked_convs=4, # 4 | |||
| feat_channels=256, | |||
| reg_max=12, | |||
| reg_topk=4, | |||
| reg_channels=64, | |||
| strides=[8, 16, 32], | |||
| add_mean=True, | |||
| norm='gn', | |||
| act='relu', | |||
| start_kernel_size=3, | |||
| conv_groups=1, | |||
| conv_type='BaseConv', | |||
| simOTA_cls_weight=1.0, | |||
| simOTA_iou_weight=3.0, | |||
| octbase=8, | |||
| simlqe=False, | |||
| **kwargs): | |||
| self.simlqe = simlqe | |||
| self.num_classes = num_classes | |||
| self.in_channels = in_channels | |||
| self.strides = strides | |||
| self.feat_channels = feat_channels if isinstance(feat_channels, list) \ | |||
| else [feat_channels] * len(self.strides) | |||
| self.cls_out_channels = num_classes + 1 # add 1 for keep consistance with former models | |||
| # and will be deprecated in future. | |||
| self.stacked_convs = stacked_convs | |||
| self.conv_groups = conv_groups | |||
| self.reg_max = reg_max | |||
| self.reg_topk = reg_topk | |||
| self.reg_channels = reg_channels | |||
| self.add_mean = add_mean | |||
| self.total_dim = reg_topk | |||
| self.start_kernel_size = start_kernel_size | |||
| self.norm = norm | |||
| self.act = act | |||
| self.conv_module = DWConv if conv_type == 'DWConv' else BaseConv | |||
| if add_mean: | |||
| self.total_dim += 1 | |||
| super(GFocalHead_Tiny, self).__init__() | |||
| self.integral = Integral(self.reg_max) | |||
| self._init_layers() | |||
| def _build_not_shared_convs(self, in_channel, feat_channels): | |||
| self.relu = nn.ReLU(inplace=True) | |||
| cls_convs = nn.ModuleList() | |||
| reg_convs = nn.ModuleList() | |||
| for i in range(self.stacked_convs): | |||
| chn = feat_channels if i > 0 else in_channel | |||
| kernel_size = 3 if i > 0 else self.start_kernel_size | |||
| cls_convs.append( | |||
| self.conv_module( | |||
| chn, | |||
| feat_channels, | |||
| kernel_size, | |||
| stride=1, | |||
| groups=self.conv_groups, | |||
| norm=self.norm, | |||
| act=self.act)) | |||
| reg_convs.append( | |||
| self.conv_module( | |||
| chn, | |||
| feat_channels, | |||
| kernel_size, | |||
| stride=1, | |||
| groups=self.conv_groups, | |||
| norm=self.norm, | |||
| act=self.act)) | |||
| if not self.simlqe: | |||
| conf_vector = [nn.Conv2d(4 * self.total_dim, self.reg_channels, 1)] | |||
| else: | |||
| conf_vector = [ | |||
| nn.Conv2d(4 * (self.reg_max + 1), self.reg_channels, 1) | |||
| ] | |||
| conf_vector += [self.relu] | |||
| conf_vector += [nn.Conv2d(self.reg_channels, 1, 1), nn.Sigmoid()] | |||
| reg_conf = nn.Sequential(*conf_vector) | |||
| return cls_convs, reg_convs, reg_conf | |||
| def _init_layers(self): | |||
| """Initialize layers of the head.""" | |||
| self.relu = nn.ReLU(inplace=True) | |||
| self.cls_convs = nn.ModuleList() | |||
| self.reg_convs = nn.ModuleList() | |||
| self.reg_confs = nn.ModuleList() | |||
| for i in range(len(self.strides)): | |||
| cls_convs, reg_convs, reg_conf = self._build_not_shared_convs( | |||
| self.in_channels[i], self.feat_channels[i]) | |||
| self.cls_convs.append(cls_convs) | |||
| self.reg_convs.append(reg_convs) | |||
| self.reg_confs.append(reg_conf) | |||
| self.gfl_cls = nn.ModuleList([ | |||
| nn.Conv2d( | |||
| self.feat_channels[i], self.cls_out_channels, 3, padding=1) | |||
| for i in range(len(self.strides)) | |||
| ]) | |||
| self.gfl_reg = nn.ModuleList([ | |||
| nn.Conv2d( | |||
| self.feat_channels[i], 4 * (self.reg_max + 1), 3, padding=1) | |||
| for i in range(len(self.strides)) | |||
| ]) | |||
| self.scales = nn.ModuleList([Scale(1.0) for _ in self.strides]) | |||
| def forward(self, | |||
| xin, | |||
| labels=None, | |||
| imgs=None, | |||
| conf_thre=0.05, | |||
| nms_thre=0.7): | |||
| # prepare labels during training | |||
| b, c, h, w = xin[0].shape | |||
| if labels is not None: | |||
| gt_bbox_list = [] | |||
| gt_cls_list = [] | |||
| for label in labels: | |||
| gt_bbox_list.append(label.bbox) | |||
| gt_cls_list.append((label.get_field('labels') | |||
| - 1).long()) # labels starts from 1 | |||
| # prepare priors for label assignment and bbox decode | |||
| mlvl_priors_list = [ | |||
| self.get_single_level_center_priors( | |||
| xin[i].shape[0], | |||
| xin[i].shape[-2:], | |||
| stride, | |||
| dtype=torch.float32, | |||
| device=xin[0].device) for i, stride in enumerate(self.strides) | |||
| ] | |||
| mlvl_priors = torch.cat(mlvl_priors_list, dim=1) | |||
| # forward for bboxes and classification prediction | |||
| cls_scores, bbox_preds = multi_apply( | |||
| self.forward_single, | |||
| xin, | |||
| self.cls_convs, | |||
| self.reg_convs, | |||
| self.gfl_cls, | |||
| self.gfl_reg, | |||
| self.reg_confs, | |||
| self.scales, | |||
| ) | |||
| flatten_cls_scores = torch.cat(cls_scores, dim=1) | |||
| flatten_bbox_preds = torch.cat(bbox_preds, dim=1) | |||
| # calculating losses or bboxes decoded | |||
| if self.training: | |||
| loss = self.loss(flatten_cls_scores, flatten_bbox_preds, | |||
| gt_bbox_list, gt_cls_list, mlvl_priors) | |||
| return loss | |||
| else: | |||
| output = self.get_bboxes(flatten_cls_scores, flatten_bbox_preds, | |||
| mlvl_priors) | |||
| return output | |||
| def forward_single(self, x, cls_convs, reg_convs, gfl_cls, gfl_reg, | |||
| reg_conf, scale): | |||
| """Forward feature of a single scale level. | |||
| """ | |||
| cls_feat = x | |||
| reg_feat = x | |||
| for cls_conv in cls_convs: | |||
| cls_feat = cls_conv(cls_feat) | |||
| for reg_conv in reg_convs: | |||
| reg_feat = reg_conv(reg_feat) | |||
| bbox_pred = scale(gfl_reg(reg_feat)).float() | |||
| N, C, H, W = bbox_pred.size() | |||
| prob = F.softmax( | |||
| bbox_pred.reshape(N, 4, self.reg_max + 1, H, W), dim=2) | |||
| if not self.simlqe: | |||
| prob_topk, _ = prob.topk(self.reg_topk, dim=2) | |||
| if self.add_mean: | |||
| stat = torch.cat( | |||
| [prob_topk, prob_topk.mean(dim=2, keepdim=True)], dim=2) | |||
| else: | |||
| stat = prob_topk | |||
| quality_score = reg_conf(stat.reshape(N, 4 * self.total_dim, H, W)) | |||
| else: | |||
| quality_score = reg_conf( | |||
| bbox_pred.reshape(N, 4 * (self.reg_max + 1), H, W)) | |||
| cls_score = gfl_cls(cls_feat).sigmoid() * quality_score | |||
| flatten_cls_score = cls_score.flatten(start_dim=2).transpose(1, 2) | |||
| flatten_bbox_pred = bbox_pred.flatten(start_dim=2).transpose(1, 2) | |||
| return flatten_cls_score, flatten_bbox_pred | |||
| def get_single_level_center_priors(self, batch_size, featmap_size, stride, | |||
| dtype, device): | |||
| h, w = featmap_size | |||
| x_range = (torch.arange(0, int(w), dtype=dtype, | |||
| device=device)) * stride | |||
| y_range = (torch.arange(0, int(h), dtype=dtype, | |||
| device=device)) * stride | |||
| x = x_range.repeat(h, 1) | |||
| y = y_range.unsqueeze(-1).repeat(1, w) | |||
| y = y.flatten() | |||
| x = x.flatten() | |||
| strides = x.new_full((x.shape[0], ), stride) | |||
| priors = torch.stack([x, y, strides, strides], dim=-1) | |||
| return priors.unsqueeze(0).repeat(batch_size, 1, 1) | |||
| def sample(self, assign_result, gt_bboxes): | |||
| pos_inds = torch.nonzero( | |||
| assign_result.gt_inds > 0, as_tuple=False).squeeze(-1).unique() | |||
| neg_inds = torch.nonzero( | |||
| assign_result.gt_inds == 0, as_tuple=False).squeeze(-1).unique() | |||
| pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1 | |||
| if gt_bboxes.numel() == 0: | |||
| # hack for index error case | |||
| assert pos_assigned_gt_inds.numel() == 0 | |||
| pos_gt_bboxes = torch.empty_like(gt_bboxes).view(-1, 4) | |||
| else: | |||
| if len(gt_bboxes.shape) < 2: | |||
| gt_bboxes = gt_bboxes.view(-1, 4) | |||
| pos_gt_bboxes = gt_bboxes[pos_assigned_gt_inds, :] | |||
| return pos_inds, neg_inds, pos_gt_bboxes, pos_assigned_gt_inds | |||
| def get_bboxes(self, | |||
| cls_preds, | |||
| reg_preds, | |||
| mlvl_center_priors, | |||
| img_meta=None): | |||
| dis_preds = self.integral(reg_preds) * mlvl_center_priors[..., 2, None] | |||
| bboxes = distance2bbox(mlvl_center_priors[..., :2], dis_preds) | |||
| res = torch.cat([bboxes, cls_preds[..., 0:self.num_classes]], dim=-1) | |||
| return res | |||
| @@ -0,0 +1,16 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. | |||
| import copy | |||
| from .giraffe_fpn import GiraffeNeck | |||
| from .giraffe_fpn_v2 import GiraffeNeckV2 | |||
| def build_neck(cfg): | |||
| neck_cfg = copy.deepcopy(cfg) | |||
| name = neck_cfg.pop('name') | |||
| if name == 'GiraffeNeck': | |||
| return GiraffeNeck(**neck_cfg) | |||
| elif name == 'GiraffeNeckV2': | |||
| return GiraffeNeckV2(**neck_cfg) | |||
| @@ -0,0 +1,235 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. | |||
| import collections | |||
| import itertools | |||
| import os | |||
| import networkx as nx | |||
| from omegaconf import OmegaConf | |||
| Node = collections.namedtuple('Node', ['id', 'inputs', 'type']) | |||
| def get_graph_info(graph): | |||
| input_nodes = [] | |||
| output_nodes = [] | |||
| Nodes = [] | |||
| for node in range(graph.number_of_nodes()): | |||
| tmp = list(graph.neighbors(node)) | |||
| tmp.sort() | |||
| type = -1 | |||
| if node < tmp[0]: | |||
| input_nodes.append(node) | |||
| type = 0 | |||
| if node > tmp[-1]: | |||
| output_nodes.append(node) | |||
| type = 1 | |||
| Nodes.append(Node(node, [n for n in tmp if n < node], type)) | |||
| return Nodes, input_nodes, output_nodes | |||
| def nodeid_trans(id, cur_level, num_levels): | |||
| if id % 2 == 1: | |||
| gap = int(((id + 1) // 2) * num_levels * 2) | |||
| else: | |||
| a = (num_levels - cur_level) * 2 - 1 | |||
| b = ((id + 1) // 2) * num_levels * 2 | |||
| gap = int(a + b) | |||
| return cur_level + gap | |||
| def gen_log2n_graph_file(log2n_graph_file, depth_multiplier): | |||
| f = open(log2n_graph_file, 'w') | |||
| for i in range(depth_multiplier): | |||
| for j in [1, 2, 4, 8, 16, 32]: | |||
| if i - j < 0: | |||
| break | |||
| else: | |||
| f.write('%d,%d\n' % (i - j, i)) | |||
| f.close() | |||
| def get_log2n_graph(depth_multiplier): | |||
| nodes = [] | |||
| connnections = [] | |||
| for i in range(depth_multiplier): | |||
| nodes.append(i) | |||
| for j in [1, 2, 4, 8, 16, 32]: | |||
| if i - j < 0: | |||
| break | |||
| else: | |||
| connnections.append((i - j, i)) | |||
| return nodes, connnections | |||
| def get_dense_graph(depth_multiplier): | |||
| nodes = [] | |||
| connections = [] | |||
| for i in range(depth_multiplier): | |||
| nodes.append(i) | |||
| for j in range(i): | |||
| connections.append((j, i)) | |||
| return nodes, connections | |||
| def giraffeneck_config(min_level, | |||
| max_level, | |||
| weight_method=None, | |||
| depth_multiplier=5, | |||
| with_backslash=False, | |||
| with_slash=False, | |||
| with_skip_connect=False, | |||
| skip_connect_type='dense'): | |||
| """Graph config with log2n merge and panet""" | |||
| if skip_connect_type == 'dense': | |||
| nodes, connections = get_dense_graph(depth_multiplier) | |||
| elif skip_connect_type == 'log2n': | |||
| nodes, connections = get_log2n_graph(depth_multiplier) | |||
| graph = nx.Graph() | |||
| graph.add_nodes_from(nodes) | |||
| graph.add_edges_from(connections) | |||
| drop_node = [] | |||
| nodes, input_nodes, output_nodes = get_graph_info(graph) | |||
| weight_method = weight_method or 'fastattn' | |||
| num_levels = max_level - min_level + 1 | |||
| node_ids = {min_level + i: [i] for i in range(num_levels)} | |||
| node_ids_per_layer = {} | |||
| pnodes = {} | |||
| def update_drop_node(new_id, input_offsets): | |||
| if new_id not in drop_node: | |||
| new_id = new_id | |||
| else: | |||
| while new_id in drop_node: | |||
| if new_id in pnodes: | |||
| for n in pnodes[new_id]['inputs_offsets']: | |||
| if n not in input_offsets and n not in drop_node: | |||
| input_offsets.append(n) | |||
| new_id = new_id - 1 | |||
| if new_id not in input_offsets: | |||
| input_offsets.append(new_id) | |||
| # top-down layer | |||
| for i in range(max_level, min_level - 1, -1): | |||
| node_ids_per_layer[i] = [] | |||
| for id, node in enumerate(nodes): | |||
| input_offsets = [] | |||
| if id in input_nodes: | |||
| input_offsets.append(node_ids[i][0]) | |||
| else: | |||
| if with_skip_connect: | |||
| for input_id in node.inputs: | |||
| new_id = nodeid_trans(input_id, i - min_level, | |||
| num_levels) | |||
| update_drop_node(new_id, input_offsets) | |||
| # add top2down | |||
| new_id = nodeid_trans(id, i - min_level, num_levels) | |||
| # add backslash node | |||
| def cal_backslash_node(id): | |||
| ind = id // num_levels | |||
| mod = id % num_levels | |||
| if ind % 2 == 0: # even | |||
| if mod == (num_levels - 1): | |||
| last = -1 | |||
| else: | |||
| last = (ind - 1) * num_levels + ( | |||
| num_levels - 1 - mod - 1) | |||
| else: # odd | |||
| if mod == 0: | |||
| last = -1 | |||
| else: | |||
| last = (ind - 1) * num_levels + ( | |||
| num_levels - 1 - mod + 1) | |||
| return last | |||
| # add slash node | |||
| def cal_slash_node(id): | |||
| ind = id // num_levels | |||
| mod = id % num_levels | |||
| if ind % 2 == 1: # odd | |||
| if mod == (num_levels - 1): | |||
| last = -1 | |||
| else: | |||
| last = (ind - 1) * num_levels + ( | |||
| num_levels - 1 - mod - 1) | |||
| else: # even | |||
| if mod == 0: | |||
| last = -1 | |||
| else: | |||
| last = (ind - 1) * num_levels + ( | |||
| num_levels - 1 - mod + 1) | |||
| return last | |||
| # add last node | |||
| last = new_id - 1 | |||
| update_drop_node(last, input_offsets) | |||
| if with_backslash: | |||
| backslash = cal_backslash_node(new_id) | |||
| if backslash != -1 and backslash not in input_offsets: | |||
| input_offsets.append(backslash) | |||
| if with_slash: | |||
| slash = cal_slash_node(new_id) | |||
| if slash != -1 and slash not in input_offsets: | |||
| input_offsets.append(slash) | |||
| if new_id in drop_node: | |||
| input_offsets = [] | |||
| pnodes[new_id] = { | |||
| 'reduction': 1 << i, | |||
| 'inputs_offsets': input_offsets, | |||
| 'weight_method': weight_method, | |||
| 'is_out': 0, | |||
| } | |||
| input_offsets = [] | |||
| for out_id in output_nodes: | |||
| new_id = nodeid_trans(out_id, i - min_level, num_levels) | |||
| input_offsets.append(new_id) | |||
| pnodes[node_ids[i][0] + num_levels * (len(nodes) + 1)] = { | |||
| 'reduction': 1 << i, | |||
| 'inputs_offsets': input_offsets, | |||
| 'weight_method': weight_method, | |||
| 'is_out': 1, | |||
| } | |||
| pnodes = dict(sorted(pnodes.items(), key=lambda x: x[0])) | |||
| return pnodes | |||
| def get_graph_config(fpn_name, | |||
| min_level=3, | |||
| max_level=7, | |||
| weight_method='concat', | |||
| depth_multiplier=5, | |||
| with_backslash=False, | |||
| with_slash=False, | |||
| with_skip_connect=False, | |||
| skip_connect_type='dense'): | |||
| name_to_config = { | |||
| 'giraffeneck': | |||
| giraffeneck_config( | |||
| min_level=min_level, | |||
| max_level=max_level, | |||
| weight_method=weight_method, | |||
| depth_multiplier=depth_multiplier, | |||
| with_backslash=with_backslash, | |||
| with_slash=with_slash, | |||
| with_skip_connect=with_skip_connect, | |||
| skip_connect_type=skip_connect_type), | |||
| } | |||
| return name_to_config[fpn_name] | |||
| @@ -0,0 +1,661 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. | |||
| import logging | |||
| import math | |||
| from collections import OrderedDict | |||
| from functools import partial | |||
| from typing import Callable, List, Optional, Tuple, Union | |||
| import numpy as np | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| from timm import create_model | |||
| from timm.models.layers import (Swish, create_conv2d, create_pool2d, | |||
| get_act_layer) | |||
| from ..core.base_ops import CSPLayer, ShuffleBlock, ShuffleCSPLayer | |||
| from .giraffe_config import get_graph_config | |||
| _ACT_LAYER = Swish | |||
| class SequentialList(nn.Sequential): | |||
| """ This module exists to work around torchscript typing issues list -> list""" | |||
| def __init__(self, *args): | |||
| super(SequentialList, self).__init__(*args) | |||
| def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]: | |||
| for module in self: | |||
| x = module(x) | |||
| return x | |||
| class ConvBnAct2d(nn.Module): | |||
| def __init__(self, | |||
| in_channels, | |||
| out_channels, | |||
| kernel_size, | |||
| stride=1, | |||
| dilation=1, | |||
| padding='', | |||
| bias=False, | |||
| norm_layer=nn.BatchNorm2d, | |||
| act_layer=_ACT_LAYER): | |||
| super(ConvBnAct2d, self).__init__() | |||
| self.conv = create_conv2d( | |||
| in_channels, | |||
| out_channels, | |||
| kernel_size, | |||
| stride=stride, | |||
| dilation=dilation, | |||
| padding=padding, | |||
| bias=bias) | |||
| self.bn = None if norm_layer is None else norm_layer(out_channels) | |||
| self.act = None if act_layer is None else act_layer(inplace=True) | |||
| def forward(self, x): | |||
| x = self.conv(x) | |||
| if self.bn is not None: | |||
| x = self.bn(x) | |||
| if self.act is not None: | |||
| x = self.act(x) | |||
| return x | |||
| class SeparableConv2d(nn.Module): | |||
| """ Separable Conv | |||
| """ | |||
| def __init__(self, | |||
| in_channels, | |||
| out_channels, | |||
| kernel_size=3, | |||
| stride=1, | |||
| dilation=1, | |||
| padding='', | |||
| bias=False, | |||
| channel_multiplier=1.0, | |||
| pw_kernel_size=1, | |||
| norm_layer=nn.BatchNorm2d, | |||
| act_layer=_ACT_LAYER): | |||
| super(SeparableConv2d, self).__init__() | |||
| self.conv_dw = create_conv2d( | |||
| in_channels, | |||
| int(in_channels * channel_multiplier), | |||
| kernel_size, | |||
| stride=stride, | |||
| dilation=dilation, | |||
| padding=padding, | |||
| depthwise=True) | |||
| self.conv_pw = create_conv2d( | |||
| int(in_channels * channel_multiplier), | |||
| out_channels, | |||
| pw_kernel_size, | |||
| padding=padding, | |||
| bias=bias) | |||
| self.bn = None if norm_layer is None else norm_layer(out_channels) | |||
| self.act = None if act_layer is None else act_layer(inplace=True) | |||
| def forward(self, x): | |||
| x = self.conv_dw(x) | |||
| x = self.conv_pw(x) | |||
| if self.bn is not None: | |||
| x = self.bn(x) | |||
| if self.act is not None: | |||
| x = self.act(x) | |||
| return x | |||
| def _init_weight( | |||
| m, | |||
| n='', | |||
| ): | |||
| """ Weight initialization as per Tensorflow official implementations. | |||
| """ | |||
| def _fan_in_out(w, groups=1): | |||
| dimensions = w.dim() | |||
| if dimensions < 2: | |||
| raise ValueError( | |||
| 'Fan in and fan out can not be computed for tensor with fewer than 2 dimensions' | |||
| ) | |||
| num_input_fmaps = w.size(1) | |||
| num_output_fmaps = w.size(0) | |||
| receptive_field_size = 1 | |||
| if w.dim() > 2: | |||
| receptive_field_size = w[0][0].numel() | |||
| fan_in = num_input_fmaps * receptive_field_size | |||
| fan_out = num_output_fmaps * receptive_field_size | |||
| fan_out //= groups | |||
| return fan_in, fan_out | |||
| def _glorot_uniform(w, gain=1, groups=1): | |||
| fan_in, fan_out = _fan_in_out(w, groups) | |||
| gain /= max(1., (fan_in + fan_out) / 2.) # fan avg | |||
| limit = math.sqrt(3.0 * gain) | |||
| w.data.uniform_(-limit, limit) | |||
| def _variance_scaling(w, gain=1, groups=1): | |||
| fan_in, fan_out = _fan_in_out(w, groups) | |||
| gain /= max(1., fan_in) # fan in | |||
| std = math.sqrt(gain) | |||
| w.data.normal_(std=std) | |||
| if isinstance(m, SeparableConv2d): | |||
| if 'box_net' in n or 'class_net' in n: | |||
| _variance_scaling(m.conv_dw.weight, groups=m.conv_dw.groups) | |||
| _variance_scaling(m.conv_pw.weight) | |||
| if m.conv_pw.bias is not None: | |||
| if 'class_net.predict' in n: | |||
| m.conv_pw.bias.data.fill_(-math.log((1 - 0.01) / 0.01)) | |||
| else: | |||
| m.conv_pw.bias.data.zero_() | |||
| else: | |||
| _glorot_uniform(m.conv_dw.weight, groups=m.conv_dw.groups) | |||
| _glorot_uniform(m.conv_pw.weight) | |||
| if m.conv_pw.bias is not None: | |||
| m.conv_pw.bias.data.zero_() | |||
| elif isinstance(m, ConvBnAct2d): | |||
| if 'box_net' in n or 'class_net' in n: | |||
| m.conv.weight.data.normal_(std=.01) | |||
| if m.conv.bias is not None: | |||
| if 'class_net.predict' in n: | |||
| m.conv.bias.data.fill_(-math.log((1 - 0.01) / 0.01)) | |||
| else: | |||
| m.conv.bias.data.zero_() | |||
| else: | |||
| _glorot_uniform(m.conv.weight) | |||
| if m.conv.bias is not None: | |||
| m.conv.bias.data.zero_() | |||
| elif isinstance(m, nn.BatchNorm2d): | |||
| m.weight.data.fill_(1.0) | |||
| m.bias.data.zero_() | |||
| def _init_weight_alt( | |||
| m, | |||
| n='', | |||
| ): | |||
| """ Weight initialization alternative, based on EfficientNet bacbkone init w/ class bias addition | |||
| NOTE: this will likely be removed after some experimentation | |||
| """ | |||
| if isinstance(m, nn.Conv2d): | |||
| fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels | |||
| fan_out //= m.groups | |||
| m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) | |||
| if m.bias is not None: | |||
| if 'class_net.predict' in n: | |||
| m.bias.data.fill_(-math.log((1 - 0.01) / 0.01)) | |||
| else: | |||
| m.bias.data.zero_() | |||
| elif isinstance(m, nn.BatchNorm2d): | |||
| m.weight.data.fill_(1.0) | |||
| m.bias.data.zero_() | |||
| class Interpolate2d(nn.Module): | |||
| r"""Resamples a 2d Image | |||
| The input data is assumed to be of the form | |||
| `minibatch x channels x [optional depth] x [optional height] x width`. | |||
| Hence, for spatial inputs, we expect a 4D Tensor and for volumetric inputs, we expect a 5D Tensor. | |||
| The algorithms available for upsampling are nearest neighbor and linear, | |||
| bilinear, bicubic and trilinear for 3D, 4D and 5D input Tensor, | |||
| respectively. | |||
| One can either give a :attr:`scale_factor` or the target output :attr:`size` to | |||
| calculate the output size. (You cannot give both, as it is ambiguous) | |||
| Args: | |||
| size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int], optional): | |||
| output spatial sizes | |||
| scale_factor (float or Tuple[float] or Tuple[float, float] or Tuple[float, float, float], optional): | |||
| multiplier for spatial size. Has to match input size if it is a tuple. | |||
| mode (str, optional): the upsampling algorithm: one of ``'nearest'``, | |||
| ``'linear'``, ``'bilinear'``, ``'bicubic'`` and ``'trilinear'``. | |||
| Default: ``'nearest'`` | |||
| align_corners (bool, optional): if ``True``, the corner pixels of the input | |||
| and output tensors are aligned, and thus preserving the values at | |||
| those pixels. This only has effect when :attr:`mode` is | |||
| ``'linear'``, ``'bilinear'``, or ``'trilinear'``. Default: ``False`` | |||
| """ | |||
| __constants__ = ['size', 'scale_factor', 'mode', 'align_corners', 'name'] | |||
| name: str | |||
| size: Optional[Union[int, Tuple[int, int]]] | |||
| scale_factor: Optional[Union[float, Tuple[float, float]]] | |||
| mode: str | |||
| align_corners: Optional[bool] | |||
| def __init__(self, | |||
| size: Optional[Union[int, Tuple[int, int]]] = None, | |||
| scale_factor: Optional[Union[float, Tuple[float, | |||
| float]]] = None, | |||
| mode: str = 'nearest', | |||
| align_corners: bool = False) -> None: | |||
| super(Interpolate2d, self).__init__() | |||
| self.name = type(self).__name__ | |||
| self.size = size | |||
| if isinstance(scale_factor, tuple): | |||
| self.scale_factor = tuple(float(factor) for factor in scale_factor) | |||
| else: | |||
| self.scale_factor = float(scale_factor) if scale_factor else None | |||
| self.mode = mode | |||
| self.align_corners = None if mode == 'nearest' else align_corners | |||
| def forward(self, input: torch.Tensor) -> torch.Tensor: | |||
| return F.interpolate( | |||
| input, | |||
| self.size, | |||
| self.scale_factor, | |||
| self.mode, | |||
| self.align_corners, | |||
| recompute_scale_factor=False) | |||
| class ResampleFeatureMap(nn.Sequential): | |||
| def __init__(self, | |||
| in_channels, | |||
| out_channels, | |||
| reduction_ratio=1., | |||
| pad_type='', | |||
| downsample=None, | |||
| upsample=None, | |||
| norm_layer=nn.BatchNorm2d, | |||
| apply_bn=False, | |||
| conv_after_downsample=False, | |||
| redundant_bias=False): | |||
| super(ResampleFeatureMap, self).__init__() | |||
| downsample = downsample or 'max' | |||
| upsample = upsample or 'nearest' | |||
| self.in_channels = in_channels | |||
| self.out_channels = out_channels | |||
| self.reduction_ratio = reduction_ratio | |||
| self.conv_after_downsample = conv_after_downsample | |||
| conv = None | |||
| if in_channels != out_channels: | |||
| conv = ConvBnAct2d( | |||
| in_channels, | |||
| out_channels, | |||
| kernel_size=1, | |||
| padding=pad_type, | |||
| norm_layer=norm_layer if apply_bn else None, | |||
| bias=not apply_bn or redundant_bias, | |||
| act_layer=None) | |||
| if reduction_ratio > 1: | |||
| if conv is not None and not self.conv_after_downsample: | |||
| self.add_module('conv', conv) | |||
| if downsample in ('max', 'avg'): | |||
| stride_size = int(reduction_ratio) | |||
| downsample = create_pool2d( | |||
| downsample, | |||
| kernel_size=stride_size + 1, | |||
| stride=stride_size, | |||
| padding=pad_type) | |||
| else: | |||
| downsample = Interpolate2d( | |||
| scale_factor=1. / reduction_ratio, mode=downsample) | |||
| self.add_module('downsample', downsample) | |||
| if conv is not None and self.conv_after_downsample: | |||
| self.add_module('conv', conv) | |||
| else: | |||
| if conv is not None: | |||
| self.add_module('conv', conv) | |||
| if reduction_ratio < 1: | |||
| scale = int(1 // reduction_ratio) | |||
| self.add_module( | |||
| 'upsample', | |||
| Interpolate2d(scale_factor=scale, mode=upsample)) | |||
| class GiraffeCombine(nn.Module): | |||
| def __init__(self, | |||
| feature_info, | |||
| fpn_config, | |||
| fpn_channels, | |||
| inputs_offsets, | |||
| target_reduction, | |||
| pad_type='', | |||
| downsample=None, | |||
| upsample=None, | |||
| norm_layer=nn.BatchNorm2d, | |||
| apply_resample_bn=False, | |||
| conv_after_downsample=False, | |||
| redundant_bias=False, | |||
| weight_method='attn'): | |||
| super(GiraffeCombine, self).__init__() | |||
| self.inputs_offsets = inputs_offsets | |||
| self.weight_method = weight_method | |||
| self.resample = nn.ModuleDict() | |||
| reduction_base = feature_info[0]['reduction'] | |||
| target_channels_idx = int( | |||
| math.log(target_reduction // reduction_base, 2)) | |||
| for idx, offset in enumerate(inputs_offsets): | |||
| if offset < len(feature_info): | |||
| in_channels = feature_info[offset]['num_chs'] | |||
| input_reduction = feature_info[offset]['reduction'] | |||
| else: | |||
| node_idx = offset | |||
| input_reduction = fpn_config[node_idx]['reduction'] | |||
| # in_channels = fpn_config[node_idx]['num_chs'] | |||
| input_channels_idx = int( | |||
| math.log(input_reduction // reduction_base, 2)) | |||
| in_channels = feature_info[input_channels_idx]['num_chs'] | |||
| reduction_ratio = target_reduction / input_reduction | |||
| if weight_method == 'concat': | |||
| self.resample[str(offset)] = ResampleFeatureMap( | |||
| in_channels, | |||
| in_channels, | |||
| reduction_ratio=reduction_ratio, | |||
| pad_type=pad_type, | |||
| downsample=downsample, | |||
| upsample=upsample, | |||
| norm_layer=norm_layer, | |||
| apply_bn=apply_resample_bn, | |||
| conv_after_downsample=conv_after_downsample, | |||
| redundant_bias=redundant_bias) | |||
| else: | |||
| self.resample[str(offset)] = ResampleFeatureMap( | |||
| in_channels, | |||
| fpn_channels[target_channels_idx], | |||
| reduction_ratio=reduction_ratio, | |||
| pad_type=pad_type, | |||
| downsample=downsample, | |||
| upsample=upsample, | |||
| norm_layer=norm_layer, | |||
| apply_bn=apply_resample_bn, | |||
| conv_after_downsample=conv_after_downsample, | |||
| redundant_bias=redundant_bias) | |||
| if weight_method == 'attn' or weight_method == 'fastattn': | |||
| self.edge_weights = nn.Parameter( | |||
| torch.ones(len(inputs_offsets)), requires_grad=True) # WSM | |||
| else: | |||
| self.edge_weights = None | |||
| def forward(self, x: List[torch.Tensor]): | |||
| dtype = x[0].dtype | |||
| nodes = [] | |||
| if len(self.inputs_offsets) == 0: | |||
| return None | |||
| for offset, resample in zip(self.inputs_offsets, | |||
| self.resample.values()): | |||
| input_node = x[offset] | |||
| input_node = resample(input_node) | |||
| nodes.append(input_node) | |||
| if self.weight_method == 'attn': | |||
| normalized_weights = torch.softmax( | |||
| self.edge_weights.to(dtype=dtype), dim=0) | |||
| out = torch.stack(nodes, dim=-1) * normalized_weights | |||
| out = torch.sum(out, dim=-1) | |||
| elif self.weight_method == 'fastattn': | |||
| edge_weights = nn.functional.relu( | |||
| self.edge_weights.to(dtype=dtype)) | |||
| weights_sum = torch.sum(edge_weights) | |||
| weights_norm = weights_sum + 0.0001 | |||
| out = torch.stack([(nodes[i] * edge_weights[i]) / weights_norm | |||
| for i in range(len(nodes))], | |||
| dim=-1) | |||
| out = torch.sum(out, dim=-1) | |||
| elif self.weight_method == 'sum': | |||
| out = torch.stack(nodes, dim=-1) | |||
| out = torch.sum(out, dim=-1) | |||
| elif self.weight_method == 'concat': | |||
| out = torch.cat(nodes, dim=1) | |||
| else: | |||
| raise ValueError('unknown weight_method {}'.format( | |||
| self.weight_method)) | |||
| return out | |||
| class GiraffeNode(nn.Module): | |||
| """ A simple wrapper used in place of nn.Sequential for torchscript typing | |||
| Handles input type List[Tensor] -> output type Tensor | |||
| """ | |||
| def __init__(self, combine: nn.Module, after_combine: nn.Module): | |||
| super(GiraffeNode, self).__init__() | |||
| self.combine = combine | |||
| self.after_combine = after_combine | |||
| def forward(self, x: List[torch.Tensor]) -> torch.Tensor: | |||
| combine_feat = self.combine(x) | |||
| if combine_feat is None: | |||
| return None | |||
| else: | |||
| return self.after_combine(combine_feat) | |||
| class GiraffeLayer(nn.Module): | |||
| def __init__(self, | |||
| feature_info, | |||
| fpn_config, | |||
| inner_fpn_channels, | |||
| outer_fpn_channels, | |||
| num_levels=5, | |||
| pad_type='', | |||
| downsample=None, | |||
| upsample=None, | |||
| norm_layer=nn.BatchNorm2d, | |||
| act_layer=_ACT_LAYER, | |||
| apply_resample_bn=False, | |||
| conv_after_downsample=True, | |||
| conv_bn_relu_pattern=False, | |||
| separable_conv=True, | |||
| redundant_bias=False, | |||
| merge_type='conv'): | |||
| super(GiraffeLayer, self).__init__() | |||
| self.num_levels = num_levels | |||
| self.conv_bn_relu_pattern = False | |||
| self.feature_info = {} | |||
| for idx, feat in enumerate(feature_info): | |||
| self.feature_info[idx] = feat | |||
| self.fnode = nn.ModuleList() | |||
| reduction_base = feature_info[0]['reduction'] | |||
| for i, fnode_cfg in fpn_config.items(): | |||
| logging.debug('fnode {} : {}'.format(i, fnode_cfg)) | |||
| if fnode_cfg['is_out'] == 1: | |||
| fpn_channels = outer_fpn_channels | |||
| else: | |||
| fpn_channels = inner_fpn_channels | |||
| reduction = fnode_cfg['reduction'] | |||
| fpn_channels_idx = int(math.log(reduction // reduction_base, 2)) | |||
| combine = GiraffeCombine( | |||
| self.feature_info, | |||
| fpn_config, | |||
| fpn_channels, | |||
| tuple(fnode_cfg['inputs_offsets']), | |||
| target_reduction=reduction, | |||
| pad_type=pad_type, | |||
| downsample=downsample, | |||
| upsample=upsample, | |||
| norm_layer=norm_layer, | |||
| apply_resample_bn=apply_resample_bn, | |||
| conv_after_downsample=conv_after_downsample, | |||
| redundant_bias=redundant_bias, | |||
| weight_method=fnode_cfg['weight_method']) | |||
| after_combine = nn.Sequential() | |||
| in_channels = 0 | |||
| out_channels = 0 | |||
| for input_offset in fnode_cfg['inputs_offsets']: | |||
| in_channels += self.feature_info[input_offset]['num_chs'] | |||
| out_channels = fpn_channels[fpn_channels_idx] | |||
| if merge_type == 'csp': | |||
| after_combine.add_module( | |||
| 'CspLayer', | |||
| CSPLayer( | |||
| in_channels, | |||
| out_channels, | |||
| 2, | |||
| shortcut=True, | |||
| depthwise=False, | |||
| act='silu')) | |||
| elif merge_type == 'shuffle': | |||
| after_combine.add_module( | |||
| 'shuffleBlock', ShuffleBlock(in_channels, in_channels)) | |||
| after_combine.add_module( | |||
| 'conv1x1', | |||
| create_conv2d(in_channels, out_channels, kernel_size=1)) | |||
| elif merge_type == 'conv': | |||
| after_combine.add_module( | |||
| 'conv1x1', | |||
| create_conv2d(in_channels, out_channels, kernel_size=1)) | |||
| conv_kwargs = dict( | |||
| in_channels=out_channels, | |||
| out_channels=out_channels, | |||
| kernel_size=3, | |||
| padding=pad_type, | |||
| bias=False, | |||
| norm_layer=norm_layer, | |||
| act_layer=act_layer) | |||
| if not conv_bn_relu_pattern: | |||
| conv_kwargs['bias'] = redundant_bias | |||
| conv_kwargs['act_layer'] = None | |||
| after_combine.add_module('act', act_layer(inplace=True)) | |||
| after_combine.add_module( | |||
| 'conv', | |||
| SeparableConv2d(**conv_kwargs) | |||
| if separable_conv else ConvBnAct2d(**conv_kwargs)) | |||
| self.fnode.append( | |||
| GiraffeNode(combine=combine, after_combine=after_combine)) | |||
| self.feature_info[i] = dict( | |||
| num_chs=fpn_channels[fpn_channels_idx], reduction=reduction) | |||
| self.out_feature_info = [] | |||
| out_node = list(self.feature_info.keys())[-num_levels::] | |||
| for i in out_node: | |||
| self.out_feature_info.append(self.feature_info[i]) | |||
| self.feature_info = self.out_feature_info | |||
| def forward(self, x: List[torch.Tensor]): | |||
| for fn in self.fnode: | |||
| x.append(fn(x)) | |||
| return x[-self.num_levels::] | |||
| class GiraffeNeck(nn.Module): | |||
| def __init__(self, min_level, max_level, num_levels, norm_layer, | |||
| norm_kwargs, act_type, fpn_config, fpn_name, fpn_channels, | |||
| out_fpn_channels, weight_method, depth_multiplier, | |||
| width_multiplier, with_backslash, with_slash, | |||
| with_skip_connect, skip_connect_type, separable_conv, | |||
| feature_info, merge_type, pad_type, downsample_type, | |||
| upsample_type, apply_resample_bn, conv_after_downsample, | |||
| redundant_bias, conv_bn_relu_pattern, alternate_init): | |||
| super(GiraffeNeck, self).__init__() | |||
| self.num_levels = num_levels | |||
| self.min_level = min_level | |||
| self.in_features = [0, 1, 2, 3, 4, 5, | |||
| 6][self.min_level - 1:self.min_level - 1 | |||
| + num_levels] | |||
| self.alternate_init = alternate_init | |||
| norm_layer = norm_layer or nn.BatchNorm2d | |||
| if norm_kwargs: | |||
| norm_layer = partial(norm_layer, **norm_kwargs) | |||
| act_layer = get_act_layer(act_type) or _ACT_LAYER | |||
| fpn_config = fpn_config or get_graph_config( | |||
| fpn_name, | |||
| min_level=min_level, | |||
| max_level=max_level, | |||
| weight_method=weight_method, | |||
| depth_multiplier=depth_multiplier, | |||
| with_backslash=with_backslash, | |||
| with_slash=with_slash, | |||
| with_skip_connect=with_skip_connect, | |||
| skip_connect_type=skip_connect_type) | |||
| # width scale | |||
| for i in range(len(fpn_channels)): | |||
| fpn_channels[i] = int(fpn_channels[i] * width_multiplier) | |||
| self.resample = nn.ModuleDict() | |||
| for level in range(num_levels): | |||
| if level < len(feature_info): | |||
| in_chs = feature_info[level]['num_chs'] | |||
| reduction = feature_info[level]['reduction'] | |||
| else: | |||
| # Adds a coarser level by downsampling the last feature map | |||
| reduction_ratio = 2 | |||
| self.resample[str(level)] = ResampleFeatureMap( | |||
| in_channels=in_chs, | |||
| out_channels=feature_info[level - 1]['num_chs'], | |||
| pad_type=pad_type, | |||
| downsample=downsample_type, | |||
| upsample=upsample_type, | |||
| norm_layer=norm_layer, | |||
| reduction_ratio=reduction_ratio, | |||
| apply_bn=apply_resample_bn, | |||
| conv_after_downsample=conv_after_downsample, | |||
| redundant_bias=redundant_bias, | |||
| ) | |||
| in_chs = feature_info[level - 1]['num_chs'] | |||
| reduction = int(reduction * reduction_ratio) | |||
| feature_info.append(dict(num_chs=in_chs, reduction=reduction)) | |||
| self.cell = SequentialList() | |||
| logging.debug('building giraffeNeck') | |||
| giraffe_layer = GiraffeLayer( | |||
| feature_info=feature_info, | |||
| fpn_config=fpn_config, | |||
| inner_fpn_channels=fpn_channels, | |||
| outer_fpn_channels=out_fpn_channels, | |||
| num_levels=num_levels, | |||
| pad_type=pad_type, | |||
| downsample=downsample_type, | |||
| upsample=upsample_type, | |||
| norm_layer=norm_layer, | |||
| act_layer=act_layer, | |||
| separable_conv=separable_conv, | |||
| apply_resample_bn=apply_resample_bn, | |||
| conv_after_downsample=conv_after_downsample, | |||
| conv_bn_relu_pattern=conv_bn_relu_pattern, | |||
| redundant_bias=redundant_bias, | |||
| merge_type=merge_type) | |||
| self.cell.add_module('giraffeNeck', giraffe_layer) | |||
| feature_info = giraffe_layer.feature_info | |||
| def init_weights(self, pretrained=False): | |||
| for n, m in self.named_modules(): | |||
| if 'backbone' not in n: | |||
| if self.alternate_init: | |||
| _init_weight_alt(m, n) | |||
| else: | |||
| _init_weight(m, n) | |||
| def forward(self, x: List[torch.Tensor]): | |||
| if type(x) is tuple: | |||
| x = list(x) | |||
| x = [x[f] for f in self.in_features] | |||
| for resample in self.resample.values(): | |||
| x.append(resample(x[-1])) | |||
| x = self.cell(x) | |||
| return x | |||
| @@ -0,0 +1,203 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. | |||
| import torch | |||
| import torch.nn as nn | |||
| from ..core.base_ops import BaseConv, CSPLayer, DWConv | |||
| from ..core.neck_ops import CSPStage | |||
| class GiraffeNeckV2(nn.Module): | |||
| def __init__( | |||
| self, | |||
| depth=1.0, | |||
| width=1.0, | |||
| in_features=[2, 3, 4], | |||
| in_channels=[256, 512, 1024], | |||
| out_channels=[256, 512, 1024], | |||
| depthwise=False, | |||
| act='silu', | |||
| spp=True, | |||
| reparam_mode=True, | |||
| block_name='BasicBlock', | |||
| ): | |||
| super().__init__() | |||
| self.in_features = in_features | |||
| self.in_channels = in_channels | |||
| Conv = DWConv if depthwise else BaseConv | |||
| reparam_mode = reparam_mode | |||
| self.upsample = nn.Upsample(scale_factor=2, mode='nearest') | |||
| # node x3: input x0, x1 | |||
| self.bu_conv13 = Conv( | |||
| int(in_channels[1] * width), | |||
| int(in_channels[1] * width), | |||
| 3, | |||
| 2, | |||
| act=act) | |||
| if reparam_mode: | |||
| self.merge_3 = CSPStage( | |||
| block_name, | |||
| int((in_channels[1] + in_channels[2]) * width), | |||
| int(in_channels[2] * width), | |||
| round(3 * depth), | |||
| act=act, | |||
| spp=spp) | |||
| else: | |||
| self.merge_3 = CSPLayer( | |||
| int((in_channels[1] + in_channels[2]) * width), | |||
| int(in_channels[2] * width), | |||
| round(3 * depth), | |||
| False, | |||
| depthwise=depthwise, | |||
| act=act) | |||
| # node x4: input x1, x2, x3 | |||
| self.bu_conv24 = Conv( | |||
| int(in_channels[0] * width), | |||
| int(in_channels[0] * width), | |||
| 3, | |||
| 2, | |||
| act=act) | |||
| if reparam_mode: | |||
| self.merge_4 = CSPStage( | |||
| block_name, | |||
| int((in_channels[0] + in_channels[1] + in_channels[2]) | |||
| * width), | |||
| int(in_channels[1] * width), | |||
| round(3 * depth), | |||
| act=act, | |||
| spp=spp) | |||
| else: | |||
| self.merge_4 = CSPLayer( | |||
| int((in_channels[0] + in_channels[1] + in_channels[2]) | |||
| * width), | |||
| int(in_channels[1] * width), | |||
| round(3 * depth), | |||
| False, | |||
| depthwise=depthwise, | |||
| act=act) | |||
| # node x5: input x2, x4 | |||
| if reparam_mode: | |||
| self.merge_5 = CSPStage( | |||
| block_name, | |||
| int((in_channels[1] + in_channels[0]) * width), | |||
| int(out_channels[0] * width), | |||
| round(3 * depth), | |||
| act=act, | |||
| spp=spp) | |||
| else: | |||
| self.merge_5 = CSPLayer( | |||
| int((in_channels[1] + in_channels[0]) * width), | |||
| int(out_channels[0] * width), | |||
| round(3 * depth), | |||
| False, | |||
| depthwise=depthwise, | |||
| act=act) | |||
| # node x7: input x4, x5 | |||
| self.bu_conv57 = Conv( | |||
| int(out_channels[0] * width), | |||
| int(out_channels[0] * width), | |||
| 3, | |||
| 2, | |||
| act=act) | |||
| if reparam_mode: | |||
| self.merge_7 = CSPStage( | |||
| block_name, | |||
| int((out_channels[0] + in_channels[1]) * width), | |||
| int(out_channels[1] * width), | |||
| round(3 * depth), | |||
| act=act, | |||
| spp=spp) | |||
| else: | |||
| self.merge_7 = CSPLayer( | |||
| int((out_channels[0] + in_channels[1]) * width), | |||
| int(out_channels[1] * width), | |||
| round(3 * depth), | |||
| False, | |||
| depthwise=depthwise, | |||
| act=act) | |||
| # node x6: input x3, x4, x7 | |||
| self.bu_conv46 = Conv( | |||
| int(in_channels[1] * width), | |||
| int(in_channels[1] * width), | |||
| 3, | |||
| 2, | |||
| act=act) | |||
| self.bu_conv76 = Conv( | |||
| int(out_channels[1] * width), | |||
| int(out_channels[1] * width), | |||
| 3, | |||
| 2, | |||
| act=act) | |||
| if reparam_mode: | |||
| self.merge_6 = CSPStage( | |||
| block_name, | |||
| int((in_channels[1] + out_channels[1] + in_channels[2]) | |||
| * width), | |||
| int(out_channels[2] * width), | |||
| round(3 * depth), | |||
| act=act, | |||
| spp=spp) | |||
| else: | |||
| self.merge_6 = CSPLayer( | |||
| int((in_channels[1] + out_channels[1] + in_channels[2]) | |||
| * width), | |||
| int(out_channels[2] * width), | |||
| round(3 * depth), | |||
| False, | |||
| depthwise=depthwise, | |||
| act=act) | |||
| def init_weights(self): | |||
| pass | |||
| def forward(self, out_features): | |||
| """ | |||
| Args: | |||
| inputs: input images. | |||
| Returns: | |||
| Tuple[Tensor]: FPN feature. | |||
| """ | |||
| # backbone | |||
| features = [out_features[f] for f in self.in_features] | |||
| [x2, x1, x0] = features | |||
| # node x3 | |||
| x13 = self.bu_conv13(x1) | |||
| x3 = torch.cat([x0, x13], 1) | |||
| x3 = self.merge_3(x3) | |||
| # node x4 | |||
| x34 = self.upsample(x3) | |||
| x24 = self.bu_conv24(x2) | |||
| x4 = torch.cat([x1, x24, x34], 1) | |||
| x4 = self.merge_4(x4) | |||
| # node x5 | |||
| x45 = self.upsample(x4) | |||
| x5 = torch.cat([x2, x45], 1) | |||
| x5 = self.merge_5(x5) | |||
| # node x7 | |||
| x57 = self.bu_conv57(x5) | |||
| x7 = torch.cat([x4, x57], 1) | |||
| x7 = self.merge_7(x7) | |||
| # node x6 | |||
| x46 = self.bu_conv46(x4) | |||
| x76 = self.bu_conv76(x7) | |||
| x6 = torch.cat([x3, x46, x76], 1) | |||
| x6 = self.merge_6(x6) | |||
| outputs = (x5, x7, x6) | |||
| return outputs | |||
| @@ -0,0 +1,16 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. | |||
| from modelscope.metainfo import Models | |||
| from modelscope.models.builder import MODELS | |||
| from modelscope.utils.constant import Tasks | |||
| from .detector import SingleStageDetector | |||
| @MODELS.register_module( | |||
| Tasks.image_object_detection, module_name=Models.tinynas_detection) | |||
| class TinynasDetector(SingleStageDetector): | |||
| def __init__(self, model_dir, *args, **kwargs): | |||
| super(TinynasDetector, self).__init__(model_dir, *args, **kwargs) | |||
| @@ -0,0 +1,30 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. | |||
| import importlib | |||
| import os | |||
| import sys | |||
| from os.path import dirname, join | |||
| def get_config_by_file(config_file): | |||
| try: | |||
| sys.path.append(os.path.dirname(config_file)) | |||
| current_config = importlib.import_module( | |||
| os.path.basename(config_file).split('.')[0]) | |||
| exp = current_config.Config() | |||
| except Exception: | |||
| raise ImportError( | |||
| "{} doesn't contains class named 'Config'".format(config_file)) | |||
| return exp | |||
| def parse_config(config_file): | |||
| """ | |||
| get config object by file. | |||
| Args: | |||
| config_file (str): file path of config. | |||
| """ | |||
| assert (config_file is not None), 'plz provide config file' | |||
| if config_file is not None: | |||
| return get_config_by_file(config_file) | |||
| @@ -0,0 +1,61 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from typing import Any, Dict | |||
| import cv2 | |||
| import numpy as np | |||
| import torch | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines.base import Input, Pipeline | |||
| from modelscope.pipelines.builder import PIPELINES | |||
| from modelscope.preprocessors import LoadImage | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.logger import get_logger | |||
| logger = get_logger() | |||
| @PIPELINES.register_module( | |||
| Tasks.image_object_detection, module_name=Pipelines.tinynas_detection) | |||
| class TinynasDetectionPipeline(Pipeline): | |||
| def __init__(self, model: str, **kwargs): | |||
| """ | |||
| model: model id on modelscope hub. | |||
| """ | |||
| super().__init__(model=model, auto_collate=False, **kwargs) | |||
| if torch.cuda.is_available(): | |||
| self.device = 'cuda' | |||
| else: | |||
| self.device = 'cpu' | |||
| self.model.to(self.device) | |||
| self.model.eval() | |||
| def preprocess(self, input: Input) -> Dict[str, Any]: | |||
| img = LoadImage.convert_to_ndarray(input) | |||
| self.img = img | |||
| img = img.astype(np.float) | |||
| img = self.model.preprocess(img) | |||
| result = {'img': img.to(self.device)} | |||
| return result | |||
| def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
| outputs = self.model.inference(input['img']) | |||
| result = {'data': outputs} | |||
| return result | |||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
| bboxes, scores, labels = self.model.postprocess(inputs['data']) | |||
| if bboxes is None: | |||
| return None | |||
| outputs = { | |||
| OutputKeys.SCORES: scores, | |||
| OutputKeys.LABELS: labels, | |||
| OutputKeys.BOXES: bboxes | |||
| } | |||
| return outputs | |||
| @@ -0,0 +1,20 @@ | |||
| import unittest | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.test_utils import test_level | |||
| class TinynasObjectDetectionTest(unittest.TestCase): | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run(self): | |||
| tinynas_object_detection = pipeline( | |||
| Tasks.image_object_detection, model='damo/cv_tinynas_detection') | |||
| result = tinynas_object_detection( | |||
| 'data/test/images/image_detection.jpg') | |||
| print(result) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||