From 1bac4f3349cbd1c343f4fbe1d9ec80198afd1a32 Mon Sep 17 00:00:00 2001
From: "xianzhe.xxz"
Date: Fri, 2 Sep 2022 13:10:31 +0800
Subject: [PATCH] [to #42322933] add tinynas-detection pipeline and models
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Integrate tinynas-detection: add the TinyNAS object detection pipeline and the TinyNAS models.

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9938220
---
 modelscope/metainfo.py                        |   3 +
 .../models/cv/tinynas_detection/__init__.py   |  24 +
 .../cv/tinynas_detection/backbone/__init__.py |  16 +
 .../cv/tinynas_detection/backbone/darknet.py  | 126 ++++
 .../cv/tinynas_detection/backbone/tinynas.py  | 347 +++++++++
 .../cv/tinynas_detection/core/__init__.py     |   2 +
 .../cv/tinynas_detection/core/base_ops.py     | 474 +++++++++++++
 .../cv/tinynas_detection/core/neck_ops.py     | 324 +++++++++
 .../cv/tinynas_detection/core/repvgg_block.py | 205 ++++++
 .../models/cv/tinynas_detection/core/utils.py | 196 ++++++
 .../models/cv/tinynas_detection/detector.py   | 181 +++++
 .../cv/tinynas_detection/head/__init__.py     |  16 +
 .../tinynas_detection/head/gfocal_v2_tiny.py  | 361 ++++++++++
 .../cv/tinynas_detection/neck/__init__.py     |  16 +
 .../tinynas_detection/neck/giraffe_config.py  | 235 +++++++
 .../cv/tinynas_detection/neck/giraffe_fpn.py  | 661 ++++++++++++++++++
 .../tinynas_detection/neck/giraffe_fpn_v2.py  | 203 ++++++
 .../cv/tinynas_detection/tinynas_detector.py  |  16 +
 .../models/cv/tinynas_detection/utils.py      |  30 +
 .../cv/tinynas_detection_pipeline.py          |  61 ++
 tests/pipelines/test_tinynas_detection.py     |  20 +
 21 files changed, 3517 insertions(+)
 create mode 100644 modelscope/models/cv/tinynas_detection/__init__.py
 create mode 100644 modelscope/models/cv/tinynas_detection/backbone/__init__.py
 create mode 100644 modelscope/models/cv/tinynas_detection/backbone/darknet.py
 create mode 100755 modelscope/models/cv/tinynas_detection/backbone/tinynas.py
 create mode 100644 modelscope/models/cv/tinynas_detection/core/__init__.py
 create mode 100644 modelscope/models/cv/tinynas_detection/core/base_ops.py
 create mode 100644 modelscope/models/cv/tinynas_detection/core/neck_ops.py
 create mode 100644 modelscope/models/cv/tinynas_detection/core/repvgg_block.py
 create mode 100644 modelscope/models/cv/tinynas_detection/core/utils.py
 create mode 100644 modelscope/models/cv/tinynas_detection/detector.py
 create mode 100644 modelscope/models/cv/tinynas_detection/head/__init__.py
 create mode 100644 modelscope/models/cv/tinynas_detection/head/gfocal_v2_tiny.py
 create mode 100644 modelscope/models/cv/tinynas_detection/neck/__init__.py
 create mode 100644 modelscope/models/cv/tinynas_detection/neck/giraffe_config.py
 create mode 100644 modelscope/models/cv/tinynas_detection/neck/giraffe_fpn.py
 create mode 100644 modelscope/models/cv/tinynas_detection/neck/giraffe_fpn_v2.py
 create mode 100644 modelscope/models/cv/tinynas_detection/tinynas_detector.py
 create mode 100644 modelscope/models/cv/tinynas_detection/utils.py
 create mode 100644 modelscope/pipelines/cv/tinynas_detection_pipeline.py
 create mode 100644 tests/pipelines/test_tinynas_detection.py

diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py
index 971dd3f1..fd653bac 100644
--- a/modelscope/metainfo.py
+++ b/modelscope/metainfo.py
@@ -9,6 +9,8 @@ class Models(object):
     Model name should only contain model info but not task info.
""" + tinynas_detection = 'tinynas-detection' + # vision models detection = 'detection' realtime_object_detection = 'realtime-object-detection' @@ -133,6 +135,7 @@ class Pipelines(object): image_to_image_generation = 'image-to-image-generation' skin_retouching = 'unet-skin-retouching' tinynas_classification = 'tinynas-classification' + tinynas_detection = 'tinynas-detection' crowd_counting = 'hrnet-crowd-counting' action_detection = 'ResNetC3D-action-detection' video_single_object_tracking = 'ostrack-vitb-video-single-object-tracking' diff --git a/modelscope/models/cv/tinynas_detection/__init__.py b/modelscope/models/cv/tinynas_detection/__init__.py new file mode 100644 index 00000000..13532d10 --- /dev/null +++ b/modelscope/models/cv/tinynas_detection/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. + +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .tinynas_detector import Tinynas_detector + +else: + _import_structure = { + 'tinynas_detector': ['TinynasDetector'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/tinynas_detection/backbone/__init__.py b/modelscope/models/cv/tinynas_detection/backbone/__init__.py new file mode 100644 index 00000000..186d06a3 --- /dev/null +++ b/modelscope/models/cv/tinynas_detection/backbone/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. + +import copy + +from .darknet import CSPDarknet +from .tinynas import load_tinynas_net + + +def build_backbone(cfg): + backbone_cfg = copy.deepcopy(cfg) + name = backbone_cfg.pop('name') + if name == 'CSPDarknet': + return CSPDarknet(**backbone_cfg) + elif name == 'TinyNAS': + return load_tinynas_net(backbone_cfg) diff --git a/modelscope/models/cv/tinynas_detection/backbone/darknet.py b/modelscope/models/cv/tinynas_detection/backbone/darknet.py new file mode 100644 index 00000000..d3294f0d --- /dev/null +++ b/modelscope/models/cv/tinynas_detection/backbone/darknet.py @@ -0,0 +1,126 @@ +# Copyright (c) Megvii Inc. All rights reserved. +# Copyright (c) Alibaba, Inc. and its affiliates. +# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. 
+ +import torch +from torch import nn + +from ..core.base_ops import (BaseConv, CSPLayer, DWConv, Focus, ResLayer, + SPPBottleneck) + + +class CSPDarknet(nn.Module): + + def __init__( + self, + dep_mul, + wid_mul, + out_features=('dark3', 'dark4', 'dark5'), + depthwise=False, + act='silu', + reparam=False, + ): + super(CSPDarknet, self).__init__() + assert out_features, 'please provide output features of Darknet' + self.out_features = out_features + Conv = DWConv if depthwise else BaseConv + + base_channels = int(wid_mul * 64) # 64 + base_depth = max(round(dep_mul * 3), 1) # 3 + + # stem + # self.stem = Focus(3, base_channels, ksize=3, act=act) + self.stem = Focus(3, base_channels, 3, act=act) + + # dark2 + self.dark2 = nn.Sequential( + Conv(base_channels, base_channels * 2, 3, 2, act=act), + CSPLayer( + base_channels * 2, + base_channels * 2, + n=base_depth, + depthwise=depthwise, + act=act, + reparam=reparam, + ), + ) + + # dark3 + self.dark3 = nn.Sequential( + Conv(base_channels * 2, base_channels * 4, 3, 2, act=act), + CSPLayer( + base_channels * 4, + base_channels * 4, + n=base_depth * 3, + depthwise=depthwise, + act=act, + reparam=reparam, + ), + ) + + # dark4 + self.dark4 = nn.Sequential( + Conv(base_channels * 4, base_channels * 8, 3, 2, act=act), + CSPLayer( + base_channels * 8, + base_channels * 8, + n=base_depth * 3, + depthwise=depthwise, + act=act, + reparam=reparam, + ), + ) + + # dark5 + self.dark5 = nn.Sequential( + Conv(base_channels * 8, base_channels * 16, 3, 2, act=act), + SPPBottleneck( + base_channels * 16, base_channels * 16, activation=act), + CSPLayer( + base_channels * 16, + base_channels * 16, + n=base_depth, + shortcut=False, + depthwise=depthwise, + act=act, + reparam=reparam, + ), + ) + + def init_weights(self, pretrain=None): + + if pretrain is None: + return + else: + pretrained_dict = torch.load( + pretrain, map_location='cpu')['state_dict'] + new_params = self.state_dict().copy() + for k, v in pretrained_dict.items(): + ks = k.split('.') + if ks[0] == 'fc' or ks[-1] == 'total_ops' or ks[ + -1] == 'total_params': + continue + else: + new_params[k] = v + + self.load_state_dict(new_params) + print(f' load pretrain backbone from {pretrain}') + + def forward(self, x): + outputs = {} + x = self.stem(x) + outputs['stem'] = x + x = self.dark2(x) + outputs['dark2'] = x + x = self.dark3(x) + outputs['dark3'] = x + x = self.dark4(x) + outputs['dark4'] = x + x = self.dark5(x) + outputs['dark5'] = x + features_out = [ + outputs['stem'], outputs['dark2'], outputs['dark3'], + outputs['dark4'], outputs['dark5'] + ] + + return features_out diff --git a/modelscope/models/cv/tinynas_detection/backbone/tinynas.py b/modelscope/models/cv/tinynas_detection/backbone/tinynas.py new file mode 100755 index 00000000..814ee550 --- /dev/null +++ b/modelscope/models/cv/tinynas_detection/backbone/tinynas.py @@ -0,0 +1,347 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. 
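+#
+# TinyNAS assembles a searched backbone from a `structure_info` list, one dict
+# per stage; load_tinynas_net() below parses that list out of the backbone
+# config. A minimal sketch of the entry format consumed here (values are
+# illustrative; the real searched structures ship with the model assets --
+# 'L' is the block count and 'btn' the bottleneck width):
+#
+#   structure_info = [
+#       {'class': 'ConvKXBNRELU', 'in': 3, 'out': 32, 'k': 3, 's': 2},
+#       {'class': 'SuperResConvK1KX', 'in': 32, 'out': 64, 'btn': 32,
+#        'k': 3, 's': 2, 'L': 2},
+#   ]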
+ +import torch +import torch.nn as nn + +from ..core.base_ops import Focus, SPPBottleneck, get_activation +from ..core.repvgg_block import RepVggBlock + + +class ConvKXBN(nn.Module): + + def __init__(self, in_c, out_c, kernel_size, stride): + super(ConvKXBN, self).__init__() + self.conv1 = nn.Conv2d( + in_c, + out_c, + kernel_size, + stride, (kernel_size - 1) // 2, + groups=1, + bias=False) + self.bn1 = nn.BatchNorm2d(out_c) + + def forward(self, x): + return self.bn1(self.conv1(x)) + + +class ConvKXBNRELU(nn.Module): + + def __init__(self, in_c, out_c, kernel_size, stride, act='silu'): + super(ConvKXBNRELU, self).__init__() + self.conv = ConvKXBN(in_c, out_c, kernel_size, stride) + if act is None: + self.activation_function = torch.relu + else: + self.activation_function = get_activation(act) + + def forward(self, x): + output = self.conv(x) + return self.activation_function(output) + + +class ResConvK1KX(nn.Module): + + def __init__(self, + in_c, + out_c, + btn_c, + kernel_size, + stride, + force_resproj=False, + act='silu'): + super(ResConvK1KX, self).__init__() + self.stride = stride + self.conv1 = ConvKXBN(in_c, btn_c, 1, 1) + self.conv2 = RepVggBlock( + btn_c, out_c, kernel_size, stride, act='identity') + + if act is None: + self.activation_function = torch.relu + else: + self.activation_function = get_activation(act) + + if stride == 2: + self.residual_downsample = nn.AvgPool2d(kernel_size=2, stride=2) + else: + self.residual_downsample = nn.Identity() + + if in_c != out_c or force_resproj: + self.residual_proj = ConvKXBN(in_c, out_c, 1, 1) + else: + self.residual_proj = nn.Identity() + + def forward(self, x): + if self.stride != 2: + reslink = self.residual_downsample(x) + reslink = self.residual_proj(reslink) + + output = x + output = self.conv1(output) + output = self.activation_function(output) + output = self.conv2(output) + if self.stride != 2: + output = output + reslink + output = self.activation_function(output) + + return output + + +class SuperResConvK1KX(nn.Module): + + def __init__(self, + in_c, + out_c, + btn_c, + kernel_size, + stride, + num_blocks, + with_spp=False, + act='silu'): + super(SuperResConvK1KX, self).__init__() + if act is None: + self.act = torch.relu + else: + self.act = get_activation(act) + self.block_list = nn.ModuleList() + for block_id in range(num_blocks): + if block_id == 0: + in_channels = in_c + out_channels = out_c + this_stride = stride + force_resproj = False # as a part of CSPLayer, DO NOT need this flag + this_kernel_size = kernel_size + else: + in_channels = out_c + out_channels = out_c + this_stride = 1 + force_resproj = False + this_kernel_size = kernel_size + the_block = ResConvK1KX( + in_channels, + out_channels, + btn_c, + this_kernel_size, + this_stride, + force_resproj, + act=act) + self.block_list.append(the_block) + if block_id == 0 and with_spp: + self.block_list.append( + SPPBottleneck(out_channels, out_channels)) + + def forward(self, x): + output = x + for block in self.block_list: + output = block(output) + return output + + +class ResConvKXKX(nn.Module): + + def __init__(self, + in_c, + out_c, + btn_c, + kernel_size, + stride, + force_resproj=False, + act='silu'): + super(ResConvKXKX, self).__init__() + self.stride = stride + if self.stride == 2: + self.downsampler = ConvKXBNRELU(in_c, out_c, 3, 2, act=act) + else: + self.conv1 = ConvKXBN(in_c, btn_c, kernel_size, 1) + self.conv2 = RepVggBlock( + btn_c, out_c, kernel_size, stride, act='identity') + + if act is None: + self.activation_function = torch.relu + else: + 
self.activation_function = get_activation(act) + + if stride == 2: + self.residual_downsample = nn.AvgPool2d( + kernel_size=2, stride=2) + else: + self.residual_downsample = nn.Identity() + + if in_c != out_c or force_resproj: + self.residual_proj = ConvKXBN(in_c, out_c, 1, 1) + else: + self.residual_proj = nn.Identity() + + def forward(self, x): + if self.stride == 2: + return self.downsampler(x) + reslink = self.residual_downsample(x) + reslink = self.residual_proj(reslink) + + output = x + output = self.conv1(output) + output = self.activation_function(output) + output = self.conv2(output) + + output = output + reslink + output = self.activation_function(output) + + return output + + +class SuperResConvKXKX(nn.Module): + + def __init__(self, + in_c, + out_c, + btn_c, + kernel_size, + stride, + num_blocks, + with_spp=False, + act='silu'): + super(SuperResConvKXKX, self).__init__() + if act is None: + self.act = torch.relu + else: + self.act = get_activation(act) + self.block_list = nn.ModuleList() + for block_id in range(num_blocks): + if block_id == 0: + in_channels = in_c + out_channels = out_c + this_stride = stride + force_resproj = False # as a part of CSPLayer, DO NOT need this flag + this_kernel_size = kernel_size + else: + in_channels = out_c + out_channels = out_c + this_stride = 1 + force_resproj = False + this_kernel_size = kernel_size + the_block = ResConvKXKX( + in_channels, + out_channels, + btn_c, + this_kernel_size, + this_stride, + force_resproj, + act=act) + self.block_list.append(the_block) + if block_id == 0 and with_spp: + self.block_list.append( + SPPBottleneck(out_channels, out_channels)) + + def forward(self, x): + output = x + for block in self.block_list: + output = block(output) + return output + + +class TinyNAS(nn.Module): + + def __init__(self, + structure_info=None, + out_indices=[0, 1, 2, 4, 5], + out_channels=[None, None, 128, 256, 512], + with_spp=False, + use_focus=False, + need_conv1=True, + act='silu'): + super(TinyNAS, self).__init__() + assert len(out_indices) == len(out_channels) + self.out_indices = out_indices + self.need_conv1 = need_conv1 + + self.block_list = nn.ModuleList() + if need_conv1: + self.conv1_list = nn.ModuleList() + for idx, block_info in enumerate(structure_info): + the_block_class = block_info['class'] + if the_block_class == 'ConvKXBNRELU': + if use_focus: + the_block = Focus(block_info['in'], block_info['out'], + block_info['k']) + else: + the_block = ConvKXBNRELU( + block_info['in'], + block_info['out'], + block_info['k'], + block_info['s'], + act=act) + self.block_list.append(the_block) + elif the_block_class == 'SuperResConvK1KX': + spp = with_spp if idx == len(structure_info) - 1 else False + the_block = SuperResConvK1KX( + block_info['in'], + block_info['out'], + block_info['btn'], + block_info['k'], + block_info['s'], + block_info['L'], + spp, + act=act) + self.block_list.append(the_block) + elif the_block_class == 'SuperResConvKXKX': + spp = with_spp if idx == len(structure_info) - 1 else False + the_block = SuperResConvKXKX( + block_info['in'], + block_info['out'], + block_info['btn'], + block_info['k'], + block_info['s'], + block_info['L'], + spp, + act=act) + self.block_list.append(the_block) + if need_conv1: + if idx in self.out_indices and out_channels[ + self.out_indices.index(idx)] is not None: + self.conv1_list.append( + nn.Conv2d(block_info['out'], + out_channels[self.out_indices.index(idx)], + 1)) + else: + self.conv1_list.append(None) + + def init_weights(self, pretrain=None): + pass + + def forward(self, x): 
+        output = x
+        stage_feature_list = []
+        for idx, block in enumerate(self.block_list):
+            output = block(output)
+            if idx in self.out_indices:
+                if self.need_conv1 and self.conv1_list[idx] is not None:
+                    true_out = self.conv1_list[idx](output)
+                    stage_feature_list.append(true_out)
+                else:
+                    stage_feature_list.append(output)
+        return stage_feature_list
+
+
+def load_tinynas_net(backbone_cfg):
+    # parse the searched TinyNAS structure string from the backbone config
+    import ast
+
+    struct_str = ''.join([x.strip() for x in backbone_cfg.net_structure_str])
+    struct_info = ast.literal_eval(struct_str)
+    for layer in struct_info:
+        if 'nbitsA' in layer:
+            del layer['nbitsA']
+        if 'nbitsW' in layer:
+            del layer['nbitsW']
+
+    model = TinyNAS(
+        structure_info=struct_info,
+        out_indices=backbone_cfg.out_indices,
+        out_channels=backbone_cfg.out_channels,
+        with_spp=backbone_cfg.with_spp,
+        use_focus=backbone_cfg.use_focus,
+        act=backbone_cfg.act,
+        need_conv1=backbone_cfg.need_conv1,
+    )
+
+    return model
diff --git a/modelscope/models/cv/tinynas_detection/core/__init__.py b/modelscope/models/cv/tinynas_detection/core/__init__.py
new file mode 100644
index 00000000..3dad5e72
--- /dev/null
+++ b/modelscope/models/cv/tinynas_detection/core/__init__.py
@@ -0,0 +1,2 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.
diff --git a/modelscope/models/cv/tinynas_detection/core/base_ops.py b/modelscope/models/cv/tinynas_detection/core/base_ops.py
new file mode 100644
index 00000000..62729ca2
--- /dev/null
+++ b/modelscope/models/cv/tinynas_detection/core/base_ops.py
@@ -0,0 +1,474 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.
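+#
+# Shared building blocks: Conv-BN-act (BaseConv), depthwise convs, CSP layers,
+# SPP and the Focus stem. As a shape sketch, Focus rearranges
+# (B, C, H, W) -> (B, 4C, H/2, W/2) by interleaved 2x2 slicing before its
+# convolution, e.g.:
+#
+#   x = torch.randn(1, 3, 64, 64)
+#   y = Focus(3, 32, ksize=3)(x)  # -> torch.Size([1, 32, 32, 32])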
+import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .repvgg_block import RepVggBlock + + +class SiLU(nn.Module): + """export-friendly version of nn.SiLU()""" + + @staticmethod + def forward(x): + return x * torch.sigmoid(x) + + +def get_activation(name='silu', inplace=True): + if name == 'silu': + module = nn.SiLU(inplace=inplace) + elif name == 'relu': + module = nn.ReLU(inplace=inplace) + elif name == 'lrelu': + module = nn.LeakyReLU(0.1, inplace=inplace) + else: + raise AttributeError('Unsupported act type: {}'.format(name)) + return module + + +def get_norm(name, out_channels, inplace=True): + if name == 'bn': + module = nn.BatchNorm2d(out_channels) + elif name == 'gn': + module = nn.GroupNorm(num_channels=out_channels, num_groups=32) + return module + + +class BaseConv(nn.Module): + """A Conv2d -> Batchnorm -> silu/leaky relu block""" + + def __init__(self, + in_channels, + out_channels, + ksize, + stride=1, + groups=1, + bias=False, + act='silu', + norm='bn'): + super().__init__() + # same padding + pad = (ksize - 1) // 2 + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=ksize, + stride=stride, + padding=pad, + groups=groups, + bias=bias, + ) + if norm is not None: + self.bn = get_norm(norm, out_channels, inplace=True) + if act is not None: + self.act = get_activation(act, inplace=True) + self.with_norm = norm is not None + self.with_act = act is not None + + def forward(self, x): + x = self.conv(x) + if self.with_norm: + # x = self.norm(x) + x = self.bn(x) + if self.with_act: + x = self.act(x) + return x + + def fuseforward(self, x): + return self.act(self.conv(x)) + + +class DepthWiseConv(nn.Module): + + def __init__(self, + in_channels, + out_channels, + ksize, + stride=1, + groups=None, + bias=False, + act='silu', + norm='bn'): + super().__init__() + padding = (ksize - 1) // 2 + self.depthwise = nn.Conv2d( + in_channels, + in_channels, + kernel_size=ksize, + stride=stride, + padding=padding, + groups=in_channels, + bias=bias, + ) + + self.pointwise = nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias) + if norm is not None: + self.dwnorm = get_norm(norm, in_channels, inplace=True) + self.pwnorm = get_norm(norm, out_channels, inplace=True) + if act is not None: + self.act = get_activation(act, inplace=True) + + self.with_norm = norm is not None + self.with_act = act is not None + self.order = ['depthwise', 'dwnorm', 'pointwise', 'act'] + + def forward(self, x): + + for layer_name in self.order: + layer = self.__getattr__(layer_name) + if layer is not None: + x = layer(x) + return x + + +class DWConv(nn.Module): + """Depthwise Conv + Conv""" + + def __init__(self, in_channels, out_channels, ksize, stride=1, act='silu'): + super().__init__() + self.dconv = BaseConv( + in_channels, + in_channels, + ksize=ksize, + stride=stride, + groups=in_channels, + act=act, + ) + self.pconv = BaseConv( + in_channels, out_channels, ksize=1, stride=1, groups=1, act=act) + + def forward(self, x): + x = self.dconv(x) + return self.pconv(x) + + +class Bottleneck(nn.Module): + # Standard bottleneck + def __init__( + self, + in_channels, + out_channels, + shortcut=True, + expansion=0.5, + depthwise=False, + act='silu', + reparam=False, + ): + super().__init__() + hidden_channels = int(out_channels * expansion) + Conv = DWConv if depthwise else BaseConv + k_conv1 = 3 if reparam else 1 + self.conv1 = BaseConv( + in_channels, hidden_channels, k_conv1, stride=1, act=act) + if reparam: + self.conv2 = 
RepVggBlock( + hidden_channels, out_channels, 3, stride=1, act=act) + else: + self.conv2 = Conv( + hidden_channels, out_channels, 3, stride=1, act=act) + self.use_add = shortcut and in_channels == out_channels + + def forward(self, x): + y = self.conv2(self.conv1(x)) + if self.use_add: + y = y + x + return y + + +class ResLayer(nn.Module): + 'Residual layer with `in_channels` inputs.' + + def __init__(self, in_channels: int): + super().__init__() + mid_channels = in_channels // 2 + self.layer1 = BaseConv( + in_channels, mid_channels, ksize=1, stride=1, act='lrelu') + self.layer2 = BaseConv( + mid_channels, in_channels, ksize=3, stride=1, act='lrelu') + + def forward(self, x): + out = self.layer2(self.layer1(x)) + return x + out + + +class SPPBottleneck(nn.Module): + """Spatial pyramid pooling layer used in YOLOv3-SPP""" + + def __init__(self, + in_channels, + out_channels, + kernel_sizes=(5, 9, 13), + activation='silu'): + super().__init__() + hidden_channels = in_channels // 2 + self.conv1 = BaseConv( + in_channels, hidden_channels, 1, stride=1, act=activation) + self.m = nn.ModuleList([ + nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2) + for ks in kernel_sizes + ]) + conv2_channels = hidden_channels * (len(kernel_sizes) + 1) + self.conv2 = BaseConv( + conv2_channels, out_channels, 1, stride=1, act=activation) + + def forward(self, x): + x = self.conv1(x) + x = torch.cat([x] + [m(x) for m in self.m], dim=1) + x = self.conv2(x) + return x + + +class CSPLayer(nn.Module): + """C3 in yolov5, CSP Bottleneck with 3 convolutions""" + + def __init__( + self, + in_channels, + out_channels, + n=1, + shortcut=True, + expansion=0.5, + depthwise=False, + act='silu', + reparam=False, + ): + """ + Args: + in_channels (int): input channels. + out_channels (int): output channels. + n (int): number of Bottlenecks. Default value: 1. 
+        """
+        # ch_in, ch_out, number, shortcut, groups, expansion
+        super().__init__()
+        hidden_channels = int(out_channels * expansion)  # hidden channels
+        self.conv1 = BaseConv(
+            in_channels, hidden_channels, 1, stride=1, act=act)
+        self.conv2 = BaseConv(
+            in_channels, hidden_channels, 1, stride=1, act=act)
+        self.conv3 = BaseConv(
+            2 * hidden_channels, out_channels, 1, stride=1, act=act)
+        module_list = [
+            Bottleneck(
+                hidden_channels,
+                hidden_channels,
+                shortcut,
+                1.0,
+                depthwise,
+                act=act,
+                reparam=reparam) for _ in range(n)
+        ]
+        self.m = nn.Sequential(*module_list)
+
+    def forward(self, x):
+        x_1 = self.conv1(x)
+        x_2 = self.conv2(x)
+        x_1 = self.m(x_1)
+        x = torch.cat((x_1, x_2), dim=1)
+        return self.conv3(x)
+
+
+class Focus(nn.Module):
+    """Focus width and height information into channel space."""
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 ksize=1,
+                 stride=1,
+                 act='silu'):
+        super().__init__()
+        self.conv = BaseConv(
+            in_channels * 4, out_channels, ksize, stride, act=act)
+
+    def forward(self, x):
+        # shape of x (b,c,w,h) -> y(b,4c,w/2,h/2)
+        patch_top_left = x[..., ::2, ::2]
+        patch_top_right = x[..., ::2, 1::2]
+        patch_bot_left = x[..., 1::2, ::2]
+        patch_bot_right = x[..., 1::2, 1::2]
+        x = torch.cat(
+            (
+                patch_top_left,
+                patch_bot_left,
+                patch_top_right,
+                patch_bot_right,
+            ),
+            dim=1,
+        )
+        return self.conv(x)
+
+
+class fast_Focus(nn.Module):
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 ksize=1,
+                 stride=1,
+                 act='silu'):
+        super(fast_Focus, self).__init__()
+        self.conv1 = self.focus_conv(w1=1.0)
+        self.conv2 = self.focus_conv(w3=1.0)
+        self.conv3 = self.focus_conv(w2=1.0)
+        self.conv4 = self.focus_conv(w4=1.0)
+
+        self.conv = BaseConv(
+            in_channels * 4, out_channels, ksize, stride, act=act)
+
+    def forward(self, x):
+        return self.conv(
+            torch.cat(
+                [self.conv1(x),
+                 self.conv2(x),
+                 self.conv3(x),
+                 self.conv4(x)], 1))
+
+    def focus_conv(self, w1=0.0, w2=0.0, w3=0.0, w4=0.0):
+        conv = nn.Conv2d(3, 3, 2, 2, groups=3, bias=False)
+        conv.weight = self.init_weights_constant(w1, w2, w3, w4)
+        conv.weight.requires_grad = False
+        return conv
+
+    def init_weights_constant(self, w1=0.0, w2=0.0, w3=0.0, w4=0.0):
+        return nn.Parameter(
+            torch.tensor([[[[w1, w2], [w3, w4]]], [[[w1, w2], [w3, w4]]],
+                          [[[w1, w2], [w3, w4]]]]))
+
+
+# shufflenet block
+def channel_shuffle(x, groups=2):
+    bat_size, channels, w, h = x.shape
+    group_c = channels // groups
+    x = x.view(bat_size, groups, group_c, w, h)
+    x = torch.transpose(x, 1, 2).contiguous()
+    x = x.view(bat_size, -1, w, h)
+    return x
+
+
+def conv_1x1_bn(in_c, out_c, stride=1):
+    return nn.Sequential(
+        nn.Conv2d(in_c, out_c, 1, stride, 0, bias=False),
+        nn.BatchNorm2d(out_c), nn.ReLU(True))
+
+
+def conv_bn(in_c, out_c, stride=2):
+    return nn.Sequential(
+        nn.Conv2d(in_c, out_c, 3, stride, 1, bias=False),
+        nn.BatchNorm2d(out_c), nn.ReLU(True))
+
+
+class ShuffleBlock(nn.Module):
+
+    def __init__(self, in_c, out_c, downsample=False):
+        super(ShuffleBlock, self).__init__()
+        self.downsample = downsample
+        half_c = out_c // 2
+        if downsample:
+            self.branch1 = nn.Sequential(
+                # 3*3 dw conv, stride = 2
+                # nn.Conv2d(in_c, in_c, 3, 2, 1, groups=in_c, bias=False),
+                nn.Conv2d(in_c, in_c, 3, 1, 1, groups=in_c, bias=False),
+                nn.BatchNorm2d(in_c),
+                # 1*1 pw conv
+                nn.Conv2d(in_c, half_c, 1, 1, 0, bias=False),
+                nn.BatchNorm2d(half_c),
+                nn.ReLU(True))
+
+            self.branch2 = nn.Sequential(
+                # 1*1 pw conv
+                nn.Conv2d(in_c, half_c, 1, 1, 0, bias=False),
+                nn.BatchNorm2d(half_c),
+                nn.ReLU(True),
+                # 3*3 dw
conv, stride = 2 + # nn.Conv2d(half_c, half_c, 3, 2, 1, groups=half_c, bias=False), + nn.Conv2d(half_c, half_c, 3, 1, 1, groups=half_c, bias=False), + nn.BatchNorm2d(half_c), + # 1*1 pw conv + nn.Conv2d(half_c, half_c, 1, 1, 0, bias=False), + nn.BatchNorm2d(half_c), + nn.ReLU(True)) + else: + # in_c = out_c + assert in_c == out_c + + self.branch2 = nn.Sequential( + # 1*1 pw conv + nn.Conv2d(half_c, half_c, 1, 1, 0, bias=False), + nn.BatchNorm2d(half_c), + nn.ReLU(True), + # 3*3 dw conv, stride = 1 + nn.Conv2d(half_c, half_c, 3, 1, 1, groups=half_c, bias=False), + nn.BatchNorm2d(half_c), + # 1*1 pw conv + nn.Conv2d(half_c, half_c, 1, 1, 0, bias=False), + nn.BatchNorm2d(half_c), + nn.ReLU(True)) + + def forward(self, x): + out = None + if self.downsample: + # if it is downsampling, we don't need to do channel split + out = torch.cat((self.branch1(x), self.branch2(x)), 1) + else: + # channel split + channels = x.shape[1] + c = channels // 2 + x1 = x[:, :c, :, :] + x2 = x[:, c:, :, :] + out = torch.cat((x1, self.branch2(x2)), 1) + return channel_shuffle(out, 2) + + +class ShuffleCSPLayer(nn.Module): + """C3 in yolov5, CSP Bottleneck with 3 convolutions""" + + def __init__( + self, + in_channels, + out_channels, + n=1, + shortcut=True, + expansion=0.5, + depthwise=False, + act='silu', + ): + """ + Args: + in_channels (int): input channels. + out_channels (int): output channels. + n (int): number of Bottlenecks. Default value: 1. + """ + # ch_in, ch_out, number, shortcut, groups, expansion + super().__init__() + hidden_channels = int(out_channels * expansion) # hidden channels + self.conv1 = BaseConv( + in_channels, hidden_channels, 1, stride=1, act=act) + self.conv2 = BaseConv( + in_channels, hidden_channels, 1, stride=1, act=act) + module_list = [ + Bottleneck( + hidden_channels, + hidden_channels, + shortcut, + 1.0, + depthwise, + act=act) for _ in range(n) + ] + self.m = nn.Sequential(*module_list) + + def forward(self, x): + x_1 = self.conv1(x) + x_2 = self.conv2(x) + x_1 = self.m(x_1) + x = torch.cat((x_1, x_2), dim=1) + # add channel shuffle + return channel_shuffle(x, 2) diff --git a/modelscope/models/cv/tinynas_detection/core/neck_ops.py b/modelscope/models/cv/tinynas_detection/core/neck_ops.py new file mode 100644 index 00000000..7f481665 --- /dev/null +++ b/modelscope/models/cv/tinynas_detection/core/neck_ops.py @@ -0,0 +1,324 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. 
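+#
+# Ops used by the Giraffe neck: ConvBNLayer, a RepVGG-style block with a
+# fusable 3x3 + 1x1 branch pair, and CSPStage. CSPStage projects the input
+# into two halves with 1x1 convs, runs `n` blocks on one half while keeping
+# every intermediate output, then fuses them, so conv3 below sees
+# ch_mid * (n + 1) input channels.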
+ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class Swish(nn.Module): + + def __init__(self, inplace=True): + super(Swish, self).__init__() + self.inplace = inplace + + def forward(self, x): + if self.inplace: + x.mul_(F.sigmoid(x)) + return x + else: + return x * F.sigmoid(x) + + +def get_activation(name='silu', inplace=True): + if name is None: + return nn.Identity() + + if isinstance(name, str): + if name == 'silu': + module = nn.SiLU(inplace=inplace) + elif name == 'relu': + module = nn.ReLU(inplace=inplace) + elif name == 'lrelu': + module = nn.LeakyReLU(0.1, inplace=inplace) + elif name == 'swish': + module = Swish(inplace=inplace) + elif name == 'hardsigmoid': + module = nn.Hardsigmoid(inplace=inplace) + else: + raise AttributeError('Unsupported act type: {}'.format(name)) + return module + elif isinstance(name, nn.Module): + return name + else: + raise AttributeError('Unsupported act type: {}'.format(name)) + + +class ConvBNLayer(nn.Module): + + def __init__(self, + ch_in, + ch_out, + filter_size=3, + stride=1, + groups=1, + padding=0, + act=None): + super(ConvBNLayer, self).__init__() + self.conv = nn.Conv2d( + in_channels=ch_in, + out_channels=ch_out, + kernel_size=filter_size, + stride=stride, + padding=padding, + groups=groups, + bias=False) + self.bn = nn.BatchNorm2d(ch_out, ) + self.act = get_activation(act, inplace=True) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.act(x) + + return x + + +class RepVGGBlock(nn.Module): + + def __init__(self, ch_in, ch_out, act='relu', deploy=False): + super(RepVGGBlock, self).__init__() + self.ch_in = ch_in + self.ch_out = ch_out + self.deploy = deploy + self.in_channels = ch_in + self.groups = 1 + if self.deploy is False: + self.rbr_dense = ConvBNLayer( + ch_in, ch_out, 3, stride=1, padding=1, act=None) + self.rbr_1x1 = ConvBNLayer( + ch_in, ch_out, 1, stride=1, padding=0, act=None) + # self.rbr_identity = nn.BatchNorm2d(num_features=ch_in) if ch_out == ch_in else None + self.rbr_identity = None + else: + self.rbr_reparam = nn.Conv2d( + in_channels=self.ch_in, + out_channels=self.ch_out, + kernel_size=3, + stride=1, + padding=1, + groups=1) + self.act = get_activation(act) if act is None or isinstance( + act, (str, dict)) else act + + def forward(self, x): + if self.deploy: + print('----------deploy----------') + y = self.rbr_reparam(x) + else: + if self.rbr_identity is None: + y = self.rbr_dense(x) + self.rbr_1x1(x) + else: + y = self.rbr_dense(x) + self.rbr_1x1(x) + self.rbr_identity(x) + + y = self.act(y) + return y + + def switch_to_deploy(self): + print('switch') + if not hasattr(self, 'rbr_reparam'): + # return + self.rbr_reparam = nn.Conv2d( + in_channels=self.ch_in, + out_channels=self.ch_out, + kernel_size=3, + stride=1, + padding=1, + groups=1) + print('switch') + kernel, bias = self.get_equivalent_kernel_bias() + self.rbr_reparam.weight.data = kernel + self.rbr_reparam.bias.data = bias + for para in self.parameters(): + para.detach_() + # self.__delattr__(self.rbr_dense) + # self.__delattr__(self.rbr_1x1) + self.__delattr__('rbr_dense') + self.__delattr__('rbr_1x1') + if hasattr(self, 'rbr_identity'): + self.__delattr__('rbr_identity') + if hasattr(self, 'id_tensor'): + self.__delattr__('id_tensor') + self.deploy = True + + def get_equivalent_kernel_bias(self): + kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense) + kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1) + kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity) + return 
kernel3x3 + self._pad_1x1_to_3x3_tensor( + kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid + + def _pad_1x1_to_3x3_tensor(self, kernel1x1): + if kernel1x1 is None: + return 0 + else: + return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1]) + + def _fuse_bn_tensor(self, branch): + if branch is None: + return 0, 0 + # if isinstance(branch, nn.Sequential): + if isinstance(branch, ConvBNLayer): + kernel = branch.conv.weight + running_mean = branch.bn.running_mean + running_var = branch.bn.running_var + gamma = branch.bn.weight + beta = branch.bn.bias + eps = branch.bn.eps + else: + assert isinstance(branch, nn.BatchNorm2d) + if not hasattr(self, 'id_tensor'): + input_dim = self.in_channels // self.groups + kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), + dtype=np.float32) + for i in range(self.in_channels): + kernel_value[i, i % input_dim, 1, 1] = 1 + self.id_tensor = torch.from_numpy(kernel_value).to( + branch.weight.device) + kernel = self.id_tensor + running_mean = branch.running_mean + running_var = branch.running_var + gamma = branch.weight + beta = branch.bias + eps = branch.eps + std = (running_var + eps).sqrt() + t = (gamma / std).reshape(-1, 1, 1, 1) + return kernel * t, beta - running_mean * gamma / std + + +class BasicBlock(nn.Module): + + def __init__(self, ch_in, ch_out, act='relu', shortcut=True): + super(BasicBlock, self).__init__() + assert ch_in == ch_out + # self.conv1 = ConvBNLayer(ch_in, ch_out, 3, stride=1, padding=1, act=act) + # self.conv1 = ConvBNLayer(ch_in, ch_out, 1, stride=1, padding=0, act=act) + self.conv2 = RepVGGBlock(ch_in, ch_out, act=act) + self.shortcut = shortcut + + def forward(self, x): + # y = self.conv1(x) + y = self.conv2(x) + if self.shortcut: + return x + y + else: + return y + + +class BasicBlock_3x3(nn.Module): + + def __init__(self, ch_in, ch_out, act='relu', shortcut=True): + super(BasicBlock_3x3, self).__init__() + assert ch_in == ch_out + self.conv1 = ConvBNLayer( + ch_in, ch_out, 3, stride=1, padding=1, act=act) + # self.conv1 = ConvBNLayer(ch_in, ch_out, 1, stride=1, padding=0, act=act) + self.conv2 = RepVGGBlock(ch_in, ch_out, act=act) + self.shortcut = shortcut + + def forward(self, x): + y = self.conv1(x) + y = self.conv2(y) + if self.shortcut: + return x + y + else: + return y + + +class BasicBlock_3x3_Reverse(nn.Module): + + def __init__(self, ch_in, ch_out, act='relu', shortcut=True): + super(BasicBlock_3x3_Reverse, self).__init__() + assert ch_in == ch_out + self.conv1 = ConvBNLayer( + ch_in, ch_out, 3, stride=1, padding=1, act=act) + # self.conv1 = ConvBNLayer(ch_in, ch_out, 1, stride=1, padding=0, act=act) + self.conv2 = RepVGGBlock(ch_in, ch_out, act=act) + self.shortcut = shortcut + + def forward(self, x): + y = self.conv2(x) + y = self.conv1(y) + if self.shortcut: + return x + y + else: + return y + + +class SPP(nn.Module): + + def __init__( + self, + ch_in, + ch_out, + k, + pool_size, + act='swish', + ): + super(SPP, self).__init__() + self.pool = [] + for i, size in enumerate(pool_size): + pool = nn.MaxPool2d( + kernel_size=size, stride=1, padding=size // 2, ceil_mode=False) + self.add_module('pool{}'.format(i), pool) + self.pool.append(pool) + self.conv = ConvBNLayer(ch_in, ch_out, k, padding=k // 2, act=act) + + def forward(self, x): + outs = [x] + + for pool in self.pool: + outs.append(pool(x)) + y = torch.cat(outs, axis=1) + + y = self.conv(y) + return y + + +class CSPStage(nn.Module): + + def __init__(self, block_fn, ch_in, ch_out, n, act='swish', spp=False): + super(CSPStage, self).__init__() + + ch_mid = 
int(ch_out // 2) + self.conv1 = ConvBNLayer(ch_in, ch_mid, 1, act=act) + self.conv2 = ConvBNLayer(ch_in, ch_mid, 1, act=act) + # self.conv2 = ConvBNLayer(ch_in, ch_mid, 3, stride=1, padding=1, act=act) + self.convs = nn.Sequential() + + next_ch_in = ch_mid + for i in range(n): + if block_fn == 'BasicBlock': + self.convs.add_module( + str(i), + BasicBlock(next_ch_in, ch_mid, act=act, shortcut=False)) + elif block_fn == 'BasicBlock_3x3': + self.convs.add_module( + str(i), + BasicBlock_3x3(next_ch_in, ch_mid, act=act, shortcut=True)) + elif block_fn == 'BasicBlock_3x3_Reverse': + self.convs.add_module( + str(i), + BasicBlock_3x3_Reverse( + next_ch_in, ch_mid, act=act, shortcut=True)) + else: + raise NotImplementedError + if i == (n - 1) // 2 and spp: + self.convs.add_module( + 'spp', SPP(ch_mid * 4, ch_mid, 1, [5, 9, 13], act=act)) + next_ch_in = ch_mid + # self.convs = nn.Sequential(*convs) + self.conv3 = ConvBNLayer(ch_mid * (n + 1), ch_out, 1, act=act) + + def forward(self, x): + y1 = self.conv1(x) + y2 = self.conv2(x) + + mid_out = [y1] + for conv in self.convs: + y2 = conv(y2) + mid_out.append(y2) + y = torch.cat(mid_out, axis=1) + y = self.conv3(y) + return y diff --git a/modelscope/models/cv/tinynas_detection/core/repvgg_block.py b/modelscope/models/cv/tinynas_detection/core/repvgg_block.py new file mode 100644 index 00000000..06966a4e --- /dev/null +++ b/modelscope/models/cv/tinynas_detection/core/repvgg_block.py @@ -0,0 +1,205 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.init as init +from torch.nn.parameter import Parameter + + +def get_activation(name='silu', inplace=True): + if name == 'silu': + module = nn.SiLU(inplace=inplace) + elif name == 'relu': + module = nn.ReLU(inplace=inplace) + elif name == 'lrelu': + module = nn.LeakyReLU(0.1, inplace=inplace) + elif name == 'identity': + module = nn.Identity() + else: + raise AttributeError('Unsupported act type: {}'.format(name)) + return module + + +def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups=1): + '''Basic cell for rep-style block, including conv and bn''' + result = nn.Sequential() + result.add_module( + 'conv', + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + groups=groups, + bias=False)) + result.add_module('bn', nn.BatchNorm2d(num_features=out_channels)) + return result + + +class RepVggBlock(nn.Module): + '''RepVggBlock is a basic rep-style block, including training and deploy status + This code is based on https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py + ''' + + def __init__(self, + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + dilation=1, + groups=1, + padding_mode='zeros', + deploy=False, + use_se=False, + act='relu', + norm=None): + super(RepVggBlock, self).__init__() + """ Initialization of the class. + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to both sides of + the input. Default: 1 + dilation (int or tuple, optional): Spacing between kernel elements. 
Default: 1 + groups (int, optional): Number of blocked connections from input + channels to output channels. Default: 1 + padding_mode (string, optional): Default: 'zeros' + deploy: Whether to be deploy status or training status. Default: False + use_se: Whether to use se. Default: False + """ + self.deploy = deploy + self.groups = groups + self.in_channels = in_channels + self.out_channels = out_channels + + assert kernel_size == 3 + assert padding == 1 + + padding_11 = padding - kernel_size // 2 + + if isinstance(act, str): + self.nonlinearity = get_activation(act) + else: + self.nonlinearity = act + + if use_se: + raise NotImplementedError('se block not supported yet') + else: + self.se = nn.Identity() + + if deploy: + self.rbr_reparam = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=True, + padding_mode=padding_mode) + + else: + self.rbr_identity = None + self.rbr_dense = conv_bn( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + groups=groups) + self.rbr_1x1 = conv_bn( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + stride=stride, + padding=padding_11, + groups=groups) + + def forward(self, inputs): + '''Forward process''' + if hasattr(self, 'rbr_reparam'): + return self.nonlinearity(self.se(self.rbr_reparam(inputs))) + + if self.rbr_identity is None: + id_out = 0 + else: + id_out = self.rbr_identity(inputs) + + return self.nonlinearity( + self.se(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out)) + + def get_equivalent_kernel_bias(self): + kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense) + kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1) + kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity) + return kernel3x3 + self._pad_1x1_to_3x3_tensor( + kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid + + def _pad_1x1_to_3x3_tensor(self, kernel1x1): + if kernel1x1 is None: + return 0 + else: + return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1]) + + def _fuse_bn_tensor(self, branch): + if branch is None: + return 0, 0 + if isinstance(branch, nn.Sequential): + kernel = branch.conv.weight + running_mean = branch.bn.running_mean + running_var = branch.bn.running_var + gamma = branch.bn.weight + beta = branch.bn.bias + eps = branch.bn.eps + else: + assert isinstance(branch, nn.BatchNorm2d) + if not hasattr(self, 'id_tensor'): + input_dim = self.in_channels // self.groups + kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), + dtype=np.float32) + for i in range(self.in_channels): + kernel_value[i, i % input_dim, 1, 1] = 1 + self.id_tensor = torch.from_numpy(kernel_value).to( + branch.weight.device) + kernel = self.id_tensor + running_mean = branch.running_mean + running_var = branch.running_var + gamma = branch.weight + beta = branch.bias + eps = branch.eps + std = (running_var + eps).sqrt() + t = (gamma / std).reshape(-1, 1, 1, 1) + return kernel * t, beta - running_mean * gamma / std + + def switch_to_deploy(self): + if hasattr(self, 'rbr_reparam'): + return + kernel, bias = self.get_equivalent_kernel_bias() + self.rbr_reparam = nn.Conv2d( + in_channels=self.rbr_dense.conv.in_channels, + out_channels=self.rbr_dense.conv.out_channels, + kernel_size=self.rbr_dense.conv.kernel_size, + stride=self.rbr_dense.conv.stride, + padding=self.rbr_dense.conv.padding, + dilation=self.rbr_dense.conv.dilation, + groups=self.rbr_dense.conv.groups, + 
bias=True) + self.rbr_reparam.weight.data = kernel + self.rbr_reparam.bias.data = bias + for para in self.parameters(): + para.detach_() + self.__delattr__('rbr_dense') + self.__delattr__('rbr_1x1') + if hasattr(self, 'rbr_identity'): + self.__delattr__('rbr_identity') + if hasattr(self, 'id_tensor'): + self.__delattr__('id_tensor') + self.deploy = True diff --git a/modelscope/models/cv/tinynas_detection/core/utils.py b/modelscope/models/cv/tinynas_detection/core/utils.py new file mode 100644 index 00000000..482f12fb --- /dev/null +++ b/modelscope/models/cv/tinynas_detection/core/utils.py @@ -0,0 +1,196 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. + +import numpy as np +import torch +import torchvision + +__all__ = [ + 'filter_box', + 'postprocess_airdet', + 'bboxes_iou', + 'matrix_iou', + 'adjust_box_anns', + 'xyxy2xywh', + 'xyxy2cxcywh', +] + + +def multiclass_nms(multi_bboxes, + multi_scores, + score_thr, + iou_thr, + max_num=100, + score_factors=None): + """NMS for multi-class bboxes. + + Args: + multi_bboxes (Tensor): shape (n, #class*4) or (n, 4) + multi_scores (Tensor): shape (n, #class), where the last column + contains scores of the background class, but this will be ignored. + score_thr (float): bbox threshold, bboxes with scores lower than it + will not be considered. + nms_thr (float): NMS IoU threshold + max_num (int): if there are more than max_num bboxes after NMS, + only top max_num will be kept. + score_factors (Tensor): The factors multiplied to scores before + applying NMS + + Returns: + tuple: (bboxes, labels), tensors of shape (k, 5) and (k, 1). Labels \ + are 0-based. + """ + num_classes = multi_scores.size(1) + # exclude background category + if multi_bboxes.shape[1] > 4: + bboxes = multi_bboxes.view(multi_scores.size(0), -1, 4) + else: + bboxes = multi_bboxes[:, None].expand( + multi_scores.size(0), num_classes, 4) + scores = multi_scores + # filter out boxes with low scores + valid_mask = scores > score_thr # 1000 * 80 bool + + # We use masked_select for ONNX exporting purpose, + # which is equivalent to bboxes = bboxes[valid_mask] + # (TODO): as ONNX does not support repeat now, + # we have to use this ugly code + # bboxes -> 1000, 4 + bboxes = torch.masked_select( + bboxes, + torch.stack((valid_mask, valid_mask, valid_mask, valid_mask), + -1)).view(-1, 4) # mask-> 1000*80*4, 80000*4 + if score_factors is not None: + scores = scores * score_factors[:, None] + scores = torch.masked_select(scores, valid_mask) + labels = valid_mask.nonzero(as_tuple=False)[:, 1] + + if bboxes.numel() == 0: + bboxes = multi_bboxes.new_zeros((0, 5)) + labels = multi_bboxes.new_zeros((0, ), dtype=torch.long) + scores = multi_bboxes.new_zeros((0, )) + + return bboxes, scores, labels + + keep = torchvision.ops.batched_nms(bboxes, scores, labels, iou_thr) + + if max_num > 0: + keep = keep[:max_num] + + return bboxes[keep], scores[keep], labels[keep] + + +def filter_box(output, scale_range): + """ + output: (N, 5+class) shape + """ + min_scale, max_scale = scale_range + w = output[:, 2] - output[:, 0] + h = output[:, 3] - output[:, 1] + keep = (w * h > min_scale * min_scale) & (w * h < max_scale * max_scale) + return output[keep] + + +def filter_results(boxlist, num_classes, nms_thre): + boxes = boxlist.bbox + scores = boxlist.get_field('scores') + cls = boxlist.get_field('labels') + nms_out_index = torchvision.ops.batched_nms( + boxes, + scores, + cls, + 
nms_thre,
+    )
+    boxlist = boxlist[nms_out_index]
+
+    return boxlist
+
+
+def postprocess_airdet(prediction,
+                       num_classes,
+                       conf_thre=0.7,
+                       nms_thre=0.45,
+                       imgs=None):
+    box_corner = prediction.new(prediction.shape)
+    box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
+    box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
+    box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
+    box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
+    prediction[:, :, :4] = box_corner[:, :, :4]
+    output = [None for _ in range(len(prediction))]
+    for i, image_pred in enumerate(prediction):
+        # If none are remaining => process next image
+        if not image_pred.size(0):
+            continue
+        multi_bboxes = image_pred[:, :4]
+        multi_scores = image_pred[:, 5:]
+        detections, scores, labels = multiclass_nms(multi_bboxes, multi_scores,
+                                                    conf_thre, nms_thre, 500)
+        detections = torch.cat(
+            (detections, scores[:, None], scores[:, None], labels[:, None]),
+            dim=1)
+
+        if output[i] is None:
+            output[i] = detections
+        else:
+            output[i] = torch.cat((output[i], detections))
+    return output
+
+
+def bboxes_iou(bboxes_a, bboxes_b, xyxy=True):
+    if bboxes_a.shape[1] != 4 or bboxes_b.shape[1] != 4:
+        raise IndexError
+
+    if xyxy:
+        tl = torch.max(bboxes_a[:, None, :2], bboxes_b[:, :2])
+        br = torch.min(bboxes_a[:, None, 2:], bboxes_b[:, 2:])
+        area_a = torch.prod(bboxes_a[:, 2:] - bboxes_a[:, :2], 1)
+        area_b = torch.prod(bboxes_b[:, 2:] - bboxes_b[:, :2], 1)
+    else:
+        tl = torch.max(
+            (bboxes_a[:, None, :2] - bboxes_a[:, None, 2:] / 2),
+            (bboxes_b[:, :2] - bboxes_b[:, 2:] / 2),
+        )
+        br = torch.min(
+            (bboxes_a[:, None, :2] + bboxes_a[:, None, 2:] / 2),
+            (bboxes_b[:, :2] + bboxes_b[:, 2:] / 2),
+        )
+
+        area_a = torch.prod(bboxes_a[:, 2:], 1)
+        area_b = torch.prod(bboxes_b[:, 2:], 1)
+    en = (tl < br).type(tl.type()).prod(dim=2)
+    area_i = torch.prod(br - tl, 2) * en  # * ((tl < br).all())
+    return area_i / (area_a[:, None] + area_b - area_i)
+
+
+def matrix_iou(a, b):
+    """
+    return iou of a and b, numpy version for data augmentation
+    """
+    lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
+    rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
+
+    area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
+    area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
+    area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
+    return area_i / (area_a[:, np.newaxis] + area_b - area_i + 1e-12)
+
+
+def adjust_box_anns(bbox, scale_ratio, padw, padh, w_max, h_max):
+    bbox[:, 0::2] = np.clip(bbox[:, 0::2] * scale_ratio + padw, 0, w_max)
+    bbox[:, 1::2] = np.clip(bbox[:, 1::2] * scale_ratio + padh, 0, h_max)
+    return bbox
+
+
+def xyxy2xywh(bboxes):
+    bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 0]
+    bboxes[:, 3] = bboxes[:, 3] - bboxes[:, 1]
+    return bboxes
+
+
+def xyxy2cxcywh(bboxes):
+    bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 0]
+    bboxes[:, 3] = bboxes[:, 3] - bboxes[:, 1]
+    bboxes[:, 0] = bboxes[:, 0] + bboxes[:, 2] * 0.5
+    bboxes[:, 1] = bboxes[:, 1] + bboxes[:, 3] * 0.5
+    return bboxes
diff --git a/modelscope/models/cv/tinynas_detection/detector.py b/modelscope/models/cv/tinynas_detection/detector.py
new file mode 100644
index 00000000..615b13a8
--- /dev/null
+++ b/modelscope/models/cv/tinynas_detection/detector.py
@@ -0,0 +1,181 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.
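+#
+# SingleStageDetector wires a backbone, neck and head from the airdet config
+# shipped in `model_dir` and loads the released weights. A usage sketch via
+# the ModelScope pipeline API (the model id is a placeholder, not a
+# guaranteed hub name):
+#
+#   from modelscope.pipelines import pipeline
+#   from modelscope.utils.constant import Tasks
+#
+#   detector = pipeline(Tasks.image_object_detection, model='<model-id>')
+#   result = detector('test.jpg')  # boxes, scores and class labels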
+ +import os.path as osp +import pickle + +import cv2 +import torch +import torchvision + +from modelscope.metainfo import Models +from modelscope.models.base.base_torch_model import TorchModel +from modelscope.models.builder import MODELS +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks +from .backbone import build_backbone +from .head import build_head +from .neck import build_neck +from .utils import parse_config + + +class SingleStageDetector(TorchModel): + """ + The base class of single stage detector. + """ + + def __init__(self, model_dir: str, *args, **kwargs): + """ + init model by cfg + """ + super().__init__(model_dir, *args, **kwargs) + + config_path = osp.join(model_dir, 'airdet_s.py') + config = parse_config(config_path) + self.cfg = config + model_path = osp.join(model_dir, config.model.name) + label_map = osp.join(model_dir, config.model.class_map) + self.label_map = pickle.load(open(label_map, 'rb')) + self.size_divisible = config.dataset.size_divisibility + self.num_classes = config.model.head.num_classes + self.conf_thre = config.model.head.nms_conf_thre + self.nms_thre = config.model.head.nms_iou_thre + + self.backbone = build_backbone(self.cfg.model.backbone) + self.neck = build_neck(self.cfg.model.neck) + self.head = build_head(self.cfg.model.head) + + self.load_pretrain_model(model_path) + + def load_pretrain_model(self, pretrain_model): + + state_dict = torch.load(pretrain_model, map_location='cpu')['model'] + new_state_dict = {} + for k, v in state_dict.items(): + k = k.replace('module.', '') + new_state_dict[k] = v + self.load_state_dict(new_state_dict, strict=True) + + def inference(self, x): + + if self.training: + return self.forward_train(x) + else: + return self.forward_eval(x) + + def forward_train(self, x): + + pass + + def forward_eval(self, x): + + x = self.backbone(x) + x = self.neck(x) + prediction = self.head(x) + + return prediction + + def preprocess(self, image): + image = torch.from_numpy(image).type(torch.float32) + image = image.permute(2, 0, 1) + shape = image.shape # c, h, w + if self.size_divisible > 0: + import math + stride = self.size_divisible + shape = list(shape) + shape[1] = int(math.ceil(shape[1] / stride) * stride) + shape[2] = int(math.ceil(shape[2] / stride) * stride) + shape = tuple(shape) + pad_img = image.new(*shape).zero_() + pad_img[:, :image.shape[1], :image.shape[2]].copy_(image) + pad_img = pad_img.unsqueeze(0) + + return pad_img + + def postprocess(self, preds): + bboxes, scores, labels_idx = postprocess_gfocal( + preds, self.num_classes, self.conf_thre, self.nms_thre) + bboxes = bboxes.cpu().numpy() + scores = scores.cpu().numpy() + labels_idx = labels_idx.cpu().numpy() + labels = [self.label_map[idx + 1][0]['name'] for idx in labels_idx] + + return (bboxes, scores, labels) + + +def multiclass_nms(multi_bboxes, + multi_scores, + score_thr, + iou_thr, + max_num=100, + score_factors=None): + """NMS for multi-class bboxes. + + Args: + multi_bboxes (Tensor): shape (n, #class*4) or (n, 4) + multi_scores (Tensor): shape (n, #class), where the last column + contains scores of the background class, but this will be ignored. + score_thr (float): bbox threshold, bboxes with scores lower than it + will not be considered. + nms_thr (float): NMS IoU threshold + max_num (int): if there are more than max_num bboxes after NMS, + only top max_num will be kept. 
+ score_factors (Tensor): The factors multiplied to scores before + applying NMS + + Returns: + tuple: (bboxes, labels), tensors of shape (k, 5) and (k, 1). Labels \ + are 0-based. + """ + num_classes = multi_scores.size(1) + # exclude background category + if multi_bboxes.shape[1] > 4: + bboxes = multi_bboxes.view(multi_scores.size(0), -1, 4) + else: + bboxes = multi_bboxes[:, None].expand( + multi_scores.size(0), num_classes, 4) + scores = multi_scores + # filter out boxes with low scores + valid_mask = scores > score_thr # 1000 * 80 bool + + # We use masked_select for ONNX exporting purpose, + # which is equivalent to bboxes = bboxes[valid_mask] + # (TODO): as ONNX does not support repeat now, + # we have to use this ugly code + # bboxes -> 1000, 4 + bboxes = torch.masked_select( + bboxes, + torch.stack((valid_mask, valid_mask, valid_mask, valid_mask), + -1)).view(-1, 4) # mask-> 1000*80*4, 80000*4 + if score_factors is not None: + scores = scores * score_factors[:, None] + scores = torch.masked_select(scores, valid_mask) + labels = valid_mask.nonzero(as_tuple=False)[:, 1] + + if bboxes.numel() == 0: + bboxes = multi_bboxes.new_zeros((0, 5)) + labels = multi_bboxes.new_zeros((0, ), dtype=torch.long) + scores = multi_bboxes.new_zeros((0, )) + + return bboxes, scores, labels + + keep = torchvision.ops.batched_nms(bboxes, scores, labels, iou_thr) + + if max_num > 0: + keep = keep[:max_num] + + return bboxes[keep], scores[keep], labels[keep] + + +def postprocess_gfocal(prediction, num_classes, conf_thre=0.05, nms_thre=0.7): + assert prediction.shape[0] == 1 + for i, image_pred in enumerate(prediction): + # If none are remaining => process next image + if not image_pred.size(0): + continue + multi_bboxes = image_pred[:, :4] + multi_scores = image_pred[:, 4:] + detections, scores, labels = multiclass_nms(multi_bboxes, multi_scores, + conf_thre, nms_thre, 500) + + return detections, scores, labels diff --git a/modelscope/models/cv/tinynas_detection/head/__init__.py b/modelscope/models/cv/tinynas_detection/head/__init__.py new file mode 100644 index 00000000..f870fae1 --- /dev/null +++ b/modelscope/models/cv/tinynas_detection/head/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. + +import copy + +from .gfocal_v2_tiny import GFocalHead_Tiny + + +def build_head(cfg): + + head_cfg = copy.deepcopy(cfg) + name = head_cfg.pop('name') + if name == 'GFocalV2': + return GFocalHead_Tiny(**head_cfg) + else: + raise NotImplementedError diff --git a/modelscope/models/cv/tinynas_detection/head/gfocal_v2_tiny.py b/modelscope/models/cv/tinynas_detection/head/gfocal_v2_tiny.py new file mode 100644 index 00000000..41f35968 --- /dev/null +++ b/modelscope/models/cv/tinynas_detection/head/gfocal_v2_tiny.py @@ -0,0 +1,361 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. 
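+#
+# GFocal V2 head: each box side is predicted as a discrete distribution over
+# {0, ..., reg_max} and decoded by the Integral layer below as its
+# expectation. Equivalent sketch for a single side (names illustrative):
+#
+#   probs = torch.softmax(logits, dim=-1)              # (reg_max + 1,)
+#   side = (torch.arange(reg_max + 1.0) * probs).sum()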
+ +import functools +from functools import partial + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..core.base_ops import BaseConv, DWConv + + +class Scale(nn.Module): + + def __init__(self, scale=1.0): + super(Scale, self).__init__() + self.scale = nn.Parameter(torch.tensor(scale, dtype=torch.float)) + + def forward(self, x): + return x * self.scale + + +def multi_apply(func, *args, **kwargs): + + pfunc = partial(func, **kwargs) if kwargs else func + map_results = map(pfunc, *args) + return tuple(map(list, zip(*map_results))) + + +def xyxy2CxCywh(xyxy, size=None): + x1 = xyxy[..., 0] + y1 = xyxy[..., 1] + x2 = xyxy[..., 2] + y2 = xyxy[..., 3] + + cx = (x1 + x2) / 2 + cy = (y1 + y2) / 2 + + w = x2 - x1 + h = y2 - y1 + if size is not None: + w = w.clamp(min=0, max=size[1]) + h = h.clamp(min=0, max=size[0]) + return torch.stack([cx, cy, w, h], axis=-1) + + +def distance2bbox(points, distance, max_shape=None): + """Decode distance prediction to bounding box. + """ + x1 = points[..., 0] - distance[..., 0] + y1 = points[..., 1] - distance[..., 1] + x2 = points[..., 0] + distance[..., 2] + y2 = points[..., 1] + distance[..., 3] + if max_shape is not None: + x1 = x1.clamp(min=0, max=max_shape[1]) + y1 = y1.clamp(min=0, max=max_shape[0]) + x2 = x2.clamp(min=0, max=max_shape[1]) + y2 = y2.clamp(min=0, max=max_shape[0]) + return torch.stack([x1, y1, x2, y2], -1) + + +def bbox2distance(points, bbox, max_dis=None, eps=0.1): + """Decode bounding box based on distances. + """ + left = points[:, 0] - bbox[:, 0] + top = points[:, 1] - bbox[:, 1] + right = bbox[:, 2] - points[:, 0] + bottom = bbox[:, 3] - points[:, 1] + if max_dis is not None: + left = left.clamp(min=0, max=max_dis - eps) + top = top.clamp(min=0, max=max_dis - eps) + right = right.clamp(min=0, max=max_dis - eps) + bottom = bottom.clamp(min=0, max=max_dis - eps) + return torch.stack([left, top, right, bottom], -1) + + +class Integral(nn.Module): + """A fixed layer for calculating integral result from distribution. + """ + + def __init__(self, reg_max=16): + super(Integral, self).__init__() + self.reg_max = reg_max + self.register_buffer('project', + torch.linspace(0, self.reg_max, self.reg_max + 1)) + + def forward(self, x): + """Forward feature from the regression head to get integral result of + bounding box location. + """ + shape = x.size() + x = F.softmax(x.reshape(*shape[:-1], 4, self.reg_max + 1), dim=-1) + b, nb, ne, _ = x.size() + x = x.reshape(b * nb * ne, self.reg_max + 1) + y = self.project.type_as(x).unsqueeze(1) + x = torch.matmul(x, y).reshape(b, nb, 4) + return x + + +class GFocalHead_Tiny(nn.Module): + """Ref to Generalized Focal Loss V2: Learning Reliable Localization Quality + Estimation for Dense Object Detection. + """ + + def __init__( + self, + num_classes, + in_channels, + stacked_convs=4, # 4 + feat_channels=256, + reg_max=12, + reg_topk=4, + reg_channels=64, + strides=[8, 16, 32], + add_mean=True, + norm='gn', + act='relu', + start_kernel_size=3, + conv_groups=1, + conv_type='BaseConv', + simOTA_cls_weight=1.0, + simOTA_iou_weight=3.0, + octbase=8, + simlqe=False, + **kwargs): + self.simlqe = simlqe + self.num_classes = num_classes + self.in_channels = in_channels + self.strides = strides + self.feat_channels = feat_channels if isinstance(feat_channels, list) \ + else [feat_channels] * len(self.strides) + + self.cls_out_channels = num_classes + 1 # add 1 for keep consistance with former models + # and will be deprecated in future. 
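+        # NB: only the first num_classes channels are consumed downstream
+        # (get_bboxes slices cls_preds[..., 0:num_classes]); the extra
+        # channel exists only for checkpoint compatibility.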
+ self.stacked_convs = stacked_convs + self.conv_groups = conv_groups + self.reg_max = reg_max + self.reg_topk = reg_topk + self.reg_channels = reg_channels + self.add_mean = add_mean + self.total_dim = reg_topk + self.start_kernel_size = start_kernel_size + + self.norm = norm + self.act = act + self.conv_module = DWConv if conv_type == 'DWConv' else BaseConv + + if add_mean: + self.total_dim += 1 + + super(GFocalHead_Tiny, self).__init__() + self.integral = Integral(self.reg_max) + + self._init_layers() + + def _build_not_shared_convs(self, in_channel, feat_channels): + self.relu = nn.ReLU(inplace=True) + cls_convs = nn.ModuleList() + reg_convs = nn.ModuleList() + + for i in range(self.stacked_convs): + chn = feat_channels if i > 0 else in_channel + kernel_size = 3 if i > 0 else self.start_kernel_size + cls_convs.append( + self.conv_module( + chn, + feat_channels, + kernel_size, + stride=1, + groups=self.conv_groups, + norm=self.norm, + act=self.act)) + reg_convs.append( + self.conv_module( + chn, + feat_channels, + kernel_size, + stride=1, + groups=self.conv_groups, + norm=self.norm, + act=self.act)) + if not self.simlqe: + conf_vector = [nn.Conv2d(4 * self.total_dim, self.reg_channels, 1)] + else: + conf_vector = [ + nn.Conv2d(4 * (self.reg_max + 1), self.reg_channels, 1) + ] + conf_vector += [self.relu] + conf_vector += [nn.Conv2d(self.reg_channels, 1, 1), nn.Sigmoid()] + reg_conf = nn.Sequential(*conf_vector) + + return cls_convs, reg_convs, reg_conf + + def _init_layers(self): + """Initialize layers of the head.""" + self.relu = nn.ReLU(inplace=True) + self.cls_convs = nn.ModuleList() + self.reg_convs = nn.ModuleList() + self.reg_confs = nn.ModuleList() + + for i in range(len(self.strides)): + cls_convs, reg_convs, reg_conf = self._build_not_shared_convs( + self.in_channels[i], self.feat_channels[i]) + self.cls_convs.append(cls_convs) + self.reg_convs.append(reg_convs) + self.reg_confs.append(reg_conf) + + self.gfl_cls = nn.ModuleList([ + nn.Conv2d( + self.feat_channels[i], self.cls_out_channels, 3, padding=1) + for i in range(len(self.strides)) + ]) + + self.gfl_reg = nn.ModuleList([ + nn.Conv2d( + self.feat_channels[i], 4 * (self.reg_max + 1), 3, padding=1) + for i in range(len(self.strides)) + ]) + + self.scales = nn.ModuleList([Scale(1.0) for _ in self.strides]) + + def forward(self, + xin, + labels=None, + imgs=None, + conf_thre=0.05, + nms_thre=0.7): + + # prepare labels during training + b, c, h, w = xin[0].shape + if labels is not None: + gt_bbox_list = [] + gt_cls_list = [] + for label in labels: + gt_bbox_list.append(label.bbox) + gt_cls_list.append((label.get_field('labels') + - 1).long()) # labels starts from 1 + + # prepare priors for label assignment and bbox decode + mlvl_priors_list = [ + self.get_single_level_center_priors( + xin[i].shape[0], + xin[i].shape[-2:], + stride, + dtype=torch.float32, + device=xin[0].device) for i, stride in enumerate(self.strides) + ] + mlvl_priors = torch.cat(mlvl_priors_list, dim=1) + + # forward for bboxes and classification prediction + cls_scores, bbox_preds = multi_apply( + self.forward_single, + xin, + self.cls_convs, + self.reg_convs, + self.gfl_cls, + self.gfl_reg, + self.reg_confs, + self.scales, + ) + flatten_cls_scores = torch.cat(cls_scores, dim=1) + flatten_bbox_preds = torch.cat(bbox_preds, dim=1) + + # calculating losses or bboxes decoded + if self.training: + loss = self.loss(flatten_cls_scores, flatten_bbox_preds, + gt_bbox_list, gt_cls_list, mlvl_priors) + return loss + else: + output = 
self.get_bboxes(flatten_cls_scores, flatten_bbox_preds, + mlvl_priors) + return output + + def forward_single(self, x, cls_convs, reg_convs, gfl_cls, gfl_reg, + reg_conf, scale): + """Forward feature of a single scale level. + + """ + cls_feat = x + reg_feat = x + + for cls_conv in cls_convs: + cls_feat = cls_conv(cls_feat) + for reg_conv in reg_convs: + reg_feat = reg_conv(reg_feat) + + bbox_pred = scale(gfl_reg(reg_feat)).float() + N, C, H, W = bbox_pred.size() + prob = F.softmax( + bbox_pred.reshape(N, 4, self.reg_max + 1, H, W), dim=2) + if not self.simlqe: + prob_topk, _ = prob.topk(self.reg_topk, dim=2) + + if self.add_mean: + stat = torch.cat( + [prob_topk, prob_topk.mean(dim=2, keepdim=True)], dim=2) + else: + stat = prob_topk + + quality_score = reg_conf(stat.reshape(N, 4 * self.total_dim, H, W)) + else: + quality_score = reg_conf( + bbox_pred.reshape(N, 4 * (self.reg_max + 1), H, W)) + + cls_score = gfl_cls(cls_feat).sigmoid() * quality_score + + flatten_cls_score = cls_score.flatten(start_dim=2).transpose(1, 2) + flatten_bbox_pred = bbox_pred.flatten(start_dim=2).transpose(1, 2) + return flatten_cls_score, flatten_bbox_pred + + def get_single_level_center_priors(self, batch_size, featmap_size, stride, + dtype, device): + + h, w = featmap_size + x_range = (torch.arange(0, int(w), dtype=dtype, + device=device)) * stride + y_range = (torch.arange(0, int(h), dtype=dtype, + device=device)) * stride + + x = x_range.repeat(h, 1) + y = y_range.unsqueeze(-1).repeat(1, w) + + y = y.flatten() + x = x.flatten() + strides = x.new_full((x.shape[0], ), stride) + priors = torch.stack([x, y, strides, strides], dim=-1) + + return priors.unsqueeze(0).repeat(batch_size, 1, 1) + + def sample(self, assign_result, gt_bboxes): + pos_inds = torch.nonzero( + assign_result.gt_inds > 0, as_tuple=False).squeeze(-1).unique() + neg_inds = torch.nonzero( + assign_result.gt_inds == 0, as_tuple=False).squeeze(-1).unique() + pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1 + + if gt_bboxes.numel() == 0: + # hack for index error case + assert pos_assigned_gt_inds.numel() == 0 + pos_gt_bboxes = torch.empty_like(gt_bboxes).view(-1, 4) + else: + if len(gt_bboxes.shape) < 2: + gt_bboxes = gt_bboxes.view(-1, 4) + pos_gt_bboxes = gt_bboxes[pos_assigned_gt_inds, :] + + return pos_inds, neg_inds, pos_gt_bboxes, pos_assigned_gt_inds + + def get_bboxes(self, + cls_preds, + reg_preds, + mlvl_center_priors, + img_meta=None): + + dis_preds = self.integral(reg_preds) * mlvl_center_priors[..., 2, None] + bboxes = distance2bbox(mlvl_center_priors[..., :2], dis_preds) + + res = torch.cat([bboxes, cls_preds[..., 0:self.num_classes]], dim=-1) + + return res diff --git a/modelscope/models/cv/tinynas_detection/neck/__init__.py b/modelscope/models/cv/tinynas_detection/neck/__init__.py new file mode 100644 index 00000000..3c418c29 --- /dev/null +++ b/modelscope/models/cv/tinynas_detection/neck/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. 
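
[Editor's note] The Integral module and get_bboxes above implement GFocal's
distribution-based box decoding: each side of a box is predicted as a softmax
distribution over reg_max + 1 bins, decoded as its expectation, then scaled by
the level stride (mlvl_center_priors[..., 2]). A small numeric check with
hypothetical logits:

    import torch
    import torch.nn.functional as F

    reg_max = 12
    logits = torch.zeros(reg_max + 1)
    logits[4] = 5.0  # concentrate most of the mass on bin 4

    prob = F.softmax(logits, dim=0)
    project = torch.linspace(0, reg_max, reg_max + 1)  # Integral's buffer
    distance = (prob * project).sum()
    print(float(distance))  # ~4.16: the expected bin index, in stride units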
+ +import copy + +from .giraffe_fpn import GiraffeNeck +from .giraffe_fpn_v2 import GiraffeNeckV2 + + +def build_neck(cfg): + neck_cfg = copy.deepcopy(cfg) + name = neck_cfg.pop('name') + if name == 'GiraffeNeck': + return GiraffeNeck(**neck_cfg) + elif name == 'GiraffeNeckV2': + return GiraffeNeckV2(**neck_cfg) diff --git a/modelscope/models/cv/tinynas_detection/neck/giraffe_config.py b/modelscope/models/cv/tinynas_detection/neck/giraffe_config.py new file mode 100644 index 00000000..289fdfd2 --- /dev/null +++ b/modelscope/models/cv/tinynas_detection/neck/giraffe_config.py @@ -0,0 +1,235 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. + +import collections +import itertools +import os + +import networkx as nx +from omegaconf import OmegaConf + +Node = collections.namedtuple('Node', ['id', 'inputs', 'type']) + + +def get_graph_info(graph): + input_nodes = [] + output_nodes = [] + Nodes = [] + for node in range(graph.number_of_nodes()): + tmp = list(graph.neighbors(node)) + tmp.sort() + type = -1 + if node < tmp[0]: + input_nodes.append(node) + type = 0 + if node > tmp[-1]: + output_nodes.append(node) + type = 1 + Nodes.append(Node(node, [n for n in tmp if n < node], type)) + return Nodes, input_nodes, output_nodes + + +def nodeid_trans(id, cur_level, num_levels): + if id % 2 == 1: + gap = int(((id + 1) // 2) * num_levels * 2) + else: + a = (num_levels - cur_level) * 2 - 1 + b = ((id + 1) // 2) * num_levels * 2 + gap = int(a + b) + return cur_level + gap + + +def gen_log2n_graph_file(log2n_graph_file, depth_multiplier): + f = open(log2n_graph_file, 'w') + for i in range(depth_multiplier): + for j in [1, 2, 4, 8, 16, 32]: + if i - j < 0: + break + else: + f.write('%d,%d\n' % (i - j, i)) + f.close() + + +def get_log2n_graph(depth_multiplier): + nodes = [] + connnections = [] + + for i in range(depth_multiplier): + nodes.append(i) + for j in [1, 2, 4, 8, 16, 32]: + if i - j < 0: + break + else: + connnections.append((i - j, i)) + return nodes, connnections + + +def get_dense_graph(depth_multiplier): + nodes = [] + connections = [] + + for i in range(depth_multiplier): + nodes.append(i) + for j in range(i): + connections.append((j, i)) + return nodes, connections + + +def giraffeneck_config(min_level, + max_level, + weight_method=None, + depth_multiplier=5, + with_backslash=False, + with_slash=False, + with_skip_connect=False, + skip_connect_type='dense'): + """Graph config with log2n merge and panet""" + if skip_connect_type == 'dense': + nodes, connections = get_dense_graph(depth_multiplier) + elif skip_connect_type == 'log2n': + nodes, connections = get_log2n_graph(depth_multiplier) + graph = nx.Graph() + graph.add_nodes_from(nodes) + graph.add_edges_from(connections) + + drop_node = [] + nodes, input_nodes, output_nodes = get_graph_info(graph) + + weight_method = weight_method or 'fastattn' + + num_levels = max_level - min_level + 1 + node_ids = {min_level + i: [i] for i in range(num_levels)} + node_ids_per_layer = {} + + pnodes = {} + + def update_drop_node(new_id, input_offsets): + if new_id not in drop_node: + new_id = new_id + else: + while new_id in drop_node: + if new_id in pnodes: + for n in pnodes[new_id]['inputs_offsets']: + if n not in input_offsets and n not in drop_node: + input_offsets.append(n) + new_id = new_id - 1 + if new_id not in input_offsets: + input_offsets.append(new_id) + + # top-down layer + for i in range(max_level, min_level - 1, -1): + 
node_ids_per_layer[i] = [] + for id, node in enumerate(nodes): + input_offsets = [] + if id in input_nodes: + input_offsets.append(node_ids[i][0]) + else: + if with_skip_connect: + for input_id in node.inputs: + new_id = nodeid_trans(input_id, i - min_level, + num_levels) + update_drop_node(new_id, input_offsets) + + # add top2down + new_id = nodeid_trans(id, i - min_level, num_levels) + + # add backslash node + def cal_backslash_node(id): + ind = id // num_levels + mod = id % num_levels + if ind % 2 == 0: # even + if mod == (num_levels - 1): + last = -1 + else: + last = (ind - 1) * num_levels + ( + num_levels - 1 - mod - 1) + else: # odd + if mod == 0: + last = -1 + else: + last = (ind - 1) * num_levels + ( + num_levels - 1 - mod + 1) + + return last + + # add slash node + def cal_slash_node(id): + ind = id // num_levels + mod = id % num_levels + if ind % 2 == 1: # odd + if mod == (num_levels - 1): + last = -1 + else: + last = (ind - 1) * num_levels + ( + num_levels - 1 - mod - 1) + else: # even + if mod == 0: + last = -1 + else: + last = (ind - 1) * num_levels + ( + num_levels - 1 - mod + 1) + + return last + + # add last node + last = new_id - 1 + update_drop_node(last, input_offsets) + + if with_backslash: + backslash = cal_backslash_node(new_id) + if backslash != -1 and backslash not in input_offsets: + input_offsets.append(backslash) + + if with_slash: + slash = cal_slash_node(new_id) + if slash != -1 and slash not in input_offsets: + input_offsets.append(slash) + + if new_id in drop_node: + input_offsets = [] + + pnodes[new_id] = { + 'reduction': 1 << i, + 'inputs_offsets': input_offsets, + 'weight_method': weight_method, + 'is_out': 0, + } + + input_offsets = [] + for out_id in output_nodes: + new_id = nodeid_trans(out_id, i - min_level, num_levels) + input_offsets.append(new_id) + + pnodes[node_ids[i][0] + num_levels * (len(nodes) + 1)] = { + 'reduction': 1 << i, + 'inputs_offsets': input_offsets, + 'weight_method': weight_method, + 'is_out': 1, + } + + pnodes = dict(sorted(pnodes.items(), key=lambda x: x[0])) + return pnodes + + +def get_graph_config(fpn_name, + min_level=3, + max_level=7, + weight_method='concat', + depth_multiplier=5, + with_backslash=False, + with_slash=False, + with_skip_connect=False, + skip_connect_type='dense'): + name_to_config = { + 'giraffeneck': + giraffeneck_config( + min_level=min_level, + max_level=max_level, + weight_method=weight_method, + depth_multiplier=depth_multiplier, + with_backslash=with_backslash, + with_slash=with_slash, + with_skip_connect=with_skip_connect, + skip_connect_type=skip_connect_type), + } + return name_to_config[fpn_name] diff --git a/modelscope/models/cv/tinynas_detection/neck/giraffe_fpn.py b/modelscope/models/cv/tinynas_detection/neck/giraffe_fpn.py new file mode 100644 index 00000000..b7087779 --- /dev/null +++ b/modelscope/models/cv/tinynas_detection/neck/giraffe_fpn.py @@ -0,0 +1,661 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. 
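
[Editor's note] giraffeneck_config starts from a connectivity graph over
depth_multiplier abstract nodes: 'dense' means node i receives an edge from
every node j < i, while 'log2n' connects only offsets 1, 2, 4, 8, ... A few
lines of networkx reproduce the dense variant that get_dense_graph builds:

    import networkx as nx

    depth_multiplier = 5
    graph = nx.Graph()
    graph.add_nodes_from(range(depth_multiplier))
    graph.add_edges_from(
        (j, i) for i in range(depth_multiplier) for j in range(i))

    # get_graph_info then classifies node 0 (no lower-numbered neighbour) as
    # an input and node 4 (no higher-numbered neighbour) as an output.
    print(sorted(graph.edges()))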
+ +import logging +import math +from collections import OrderedDict +from functools import partial +from typing import Callable, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm import create_model +from timm.models.layers import (Swish, create_conv2d, create_pool2d, + get_act_layer) + +from ..core.base_ops import CSPLayer, ShuffleBlock, ShuffleCSPLayer +from .giraffe_config import get_graph_config + +_ACT_LAYER = Swish + + +class SequentialList(nn.Sequential): + """ This module exists to work around torchscript typing issues list -> list""" + + def __init__(self, *args): + super(SequentialList, self).__init__(*args) + + def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]: + for module in self: + x = module(x) + return x + + +class ConvBnAct2d(nn.Module): + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + dilation=1, + padding='', + bias=False, + norm_layer=nn.BatchNorm2d, + act_layer=_ACT_LAYER): + super(ConvBnAct2d, self).__init__() + + self.conv = create_conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + dilation=dilation, + padding=padding, + bias=bias) + self.bn = None if norm_layer is None else norm_layer(out_channels) + self.act = None if act_layer is None else act_layer(inplace=True) + + def forward(self, x): + x = self.conv(x) + if self.bn is not None: + x = self.bn(x) + if self.act is not None: + x = self.act(x) + return x + + +class SeparableConv2d(nn.Module): + """ Separable Conv + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size=3, + stride=1, + dilation=1, + padding='', + bias=False, + channel_multiplier=1.0, + pw_kernel_size=1, + norm_layer=nn.BatchNorm2d, + act_layer=_ACT_LAYER): + super(SeparableConv2d, self).__init__() + self.conv_dw = create_conv2d( + in_channels, + int(in_channels * channel_multiplier), + kernel_size, + stride=stride, + dilation=dilation, + padding=padding, + depthwise=True) + + self.conv_pw = create_conv2d( + int(in_channels * channel_multiplier), + out_channels, + pw_kernel_size, + padding=padding, + bias=bias) + + self.bn = None if norm_layer is None else norm_layer(out_channels) + self.act = None if act_layer is None else act_layer(inplace=True) + + def forward(self, x): + x = self.conv_dw(x) + x = self.conv_pw(x) + if self.bn is not None: + x = self.bn(x) + if self.act is not None: + x = self.act(x) + return x + + +def _init_weight( + m, + n='', +): + """ Weight initialization as per Tensorflow official implementations. + """ + + def _fan_in_out(w, groups=1): + dimensions = w.dim() + if dimensions < 2: + raise ValueError( + 'Fan in and fan out can not be computed for tensor with fewer than 2 dimensions' + ) + num_input_fmaps = w.size(1) + num_output_fmaps = w.size(0) + receptive_field_size = 1 + if w.dim() > 2: + receptive_field_size = w[0][0].numel() + fan_in = num_input_fmaps * receptive_field_size + fan_out = num_output_fmaps * receptive_field_size + fan_out //= groups + return fan_in, fan_out + + def _glorot_uniform(w, gain=1, groups=1): + fan_in, fan_out = _fan_in_out(w, groups) + gain /= max(1., (fan_in + fan_out) / 2.) 
# fan avg + limit = math.sqrt(3.0 * gain) + w.data.uniform_(-limit, limit) + + def _variance_scaling(w, gain=1, groups=1): + fan_in, fan_out = _fan_in_out(w, groups) + gain /= max(1., fan_in) # fan in + std = math.sqrt(gain) + w.data.normal_(std=std) + + if isinstance(m, SeparableConv2d): + if 'box_net' in n or 'class_net' in n: + _variance_scaling(m.conv_dw.weight, groups=m.conv_dw.groups) + _variance_scaling(m.conv_pw.weight) + if m.conv_pw.bias is not None: + if 'class_net.predict' in n: + m.conv_pw.bias.data.fill_(-math.log((1 - 0.01) / 0.01)) + else: + m.conv_pw.bias.data.zero_() + else: + _glorot_uniform(m.conv_dw.weight, groups=m.conv_dw.groups) + _glorot_uniform(m.conv_pw.weight) + if m.conv_pw.bias is not None: + m.conv_pw.bias.data.zero_() + elif isinstance(m, ConvBnAct2d): + if 'box_net' in n or 'class_net' in n: + m.conv.weight.data.normal_(std=.01) + if m.conv.bias is not None: + if 'class_net.predict' in n: + m.conv.bias.data.fill_(-math.log((1 - 0.01) / 0.01)) + else: + m.conv.bias.data.zero_() + else: + _glorot_uniform(m.conv.weight) + if m.conv.bias is not None: + m.conv.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1.0) + m.bias.data.zero_() + + +def _init_weight_alt( + m, + n='', +): + """ Weight initialization alternative, based on EfficientNet bacbkone init w/ class bias addition + NOTE: this will likely be removed after some experimentation + """ + if isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + if 'class_net.predict' in n: + m.bias.data.fill_(-math.log((1 - 0.01) / 0.01)) + else: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1.0) + m.bias.data.zero_() + + +class Interpolate2d(nn.Module): + r"""Resamples a 2d Image + + The input data is assumed to be of the form + `minibatch x channels x [optional depth] x [optional height] x width`. + Hence, for spatial inputs, we expect a 4D Tensor and for volumetric inputs, we expect a 5D Tensor. + + The algorithms available for upsampling are nearest neighbor and linear, + bilinear, bicubic and trilinear for 3D, 4D and 5D input Tensor, + respectively. + + One can either give a :attr:`scale_factor` or the target output :attr:`size` to + calculate the output size. (You cannot give both, as it is ambiguous) + + Args: + size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int], optional): + output spatial sizes + scale_factor (float or Tuple[float] or Tuple[float, float] or Tuple[float, float, float], optional): + multiplier for spatial size. Has to match input size if it is a tuple. + mode (str, optional): the upsampling algorithm: one of ``'nearest'``, + ``'linear'``, ``'bilinear'``, ``'bicubic'`` and ``'trilinear'``. + Default: ``'nearest'`` + align_corners (bool, optional): if ``True``, the corner pixels of the input + and output tensors are aligned, and thus preserving the values at + those pixels. This only has effect when :attr:`mode` is + ``'linear'``, ``'bilinear'``, or ``'trilinear'``. 
Default: ``False`` + """ + __constants__ = ['size', 'scale_factor', 'mode', 'align_corners', 'name'] + name: str + size: Optional[Union[int, Tuple[int, int]]] + scale_factor: Optional[Union[float, Tuple[float, float]]] + mode: str + align_corners: Optional[bool] + + def __init__(self, + size: Optional[Union[int, Tuple[int, int]]] = None, + scale_factor: Optional[Union[float, Tuple[float, + float]]] = None, + mode: str = 'nearest', + align_corners: bool = False) -> None: + super(Interpolate2d, self).__init__() + self.name = type(self).__name__ + self.size = size + if isinstance(scale_factor, tuple): + self.scale_factor = tuple(float(factor) for factor in scale_factor) + else: + self.scale_factor = float(scale_factor) if scale_factor else None + self.mode = mode + self.align_corners = None if mode == 'nearest' else align_corners + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return F.interpolate( + input, + self.size, + self.scale_factor, + self.mode, + self.align_corners, + recompute_scale_factor=False) + + +class ResampleFeatureMap(nn.Sequential): + + def __init__(self, + in_channels, + out_channels, + reduction_ratio=1., + pad_type='', + downsample=None, + upsample=None, + norm_layer=nn.BatchNorm2d, + apply_bn=False, + conv_after_downsample=False, + redundant_bias=False): + super(ResampleFeatureMap, self).__init__() + downsample = downsample or 'max' + upsample = upsample or 'nearest' + self.in_channels = in_channels + self.out_channels = out_channels + self.reduction_ratio = reduction_ratio + self.conv_after_downsample = conv_after_downsample + + conv = None + if in_channels != out_channels: + conv = ConvBnAct2d( + in_channels, + out_channels, + kernel_size=1, + padding=pad_type, + norm_layer=norm_layer if apply_bn else None, + bias=not apply_bn or redundant_bias, + act_layer=None) + + if reduction_ratio > 1: + if conv is not None and not self.conv_after_downsample: + self.add_module('conv', conv) + if downsample in ('max', 'avg'): + stride_size = int(reduction_ratio) + downsample = create_pool2d( + downsample, + kernel_size=stride_size + 1, + stride=stride_size, + padding=pad_type) + else: + downsample = Interpolate2d( + scale_factor=1. 
/ reduction_ratio, mode=downsample) + self.add_module('downsample', downsample) + if conv is not None and self.conv_after_downsample: + self.add_module('conv', conv) + else: + if conv is not None: + self.add_module('conv', conv) + if reduction_ratio < 1: + scale = int(1 // reduction_ratio) + self.add_module( + 'upsample', + Interpolate2d(scale_factor=scale, mode=upsample)) + + +class GiraffeCombine(nn.Module): + + def __init__(self, + feature_info, + fpn_config, + fpn_channels, + inputs_offsets, + target_reduction, + pad_type='', + downsample=None, + upsample=None, + norm_layer=nn.BatchNorm2d, + apply_resample_bn=False, + conv_after_downsample=False, + redundant_bias=False, + weight_method='attn'): + super(GiraffeCombine, self).__init__() + self.inputs_offsets = inputs_offsets + self.weight_method = weight_method + + self.resample = nn.ModuleDict() + reduction_base = feature_info[0]['reduction'] + + target_channels_idx = int( + math.log(target_reduction // reduction_base, 2)) + for idx, offset in enumerate(inputs_offsets): + if offset < len(feature_info): + in_channels = feature_info[offset]['num_chs'] + input_reduction = feature_info[offset]['reduction'] + else: + node_idx = offset + input_reduction = fpn_config[node_idx]['reduction'] + # in_channels = fpn_config[node_idx]['num_chs'] + input_channels_idx = int( + math.log(input_reduction // reduction_base, 2)) + in_channels = feature_info[input_channels_idx]['num_chs'] + + reduction_ratio = target_reduction / input_reduction + if weight_method == 'concat': + self.resample[str(offset)] = ResampleFeatureMap( + in_channels, + in_channels, + reduction_ratio=reduction_ratio, + pad_type=pad_type, + downsample=downsample, + upsample=upsample, + norm_layer=norm_layer, + apply_bn=apply_resample_bn, + conv_after_downsample=conv_after_downsample, + redundant_bias=redundant_bias) + else: + self.resample[str(offset)] = ResampleFeatureMap( + in_channels, + fpn_channels[target_channels_idx], + reduction_ratio=reduction_ratio, + pad_type=pad_type, + downsample=downsample, + upsample=upsample, + norm_layer=norm_layer, + apply_bn=apply_resample_bn, + conv_after_downsample=conv_after_downsample, + redundant_bias=redundant_bias) + + if weight_method == 'attn' or weight_method == 'fastattn': + self.edge_weights = nn.Parameter( + torch.ones(len(inputs_offsets)), requires_grad=True) # WSM + else: + self.edge_weights = None + + def forward(self, x: List[torch.Tensor]): + dtype = x[0].dtype + nodes = [] + if len(self.inputs_offsets) == 0: + return None + for offset, resample in zip(self.inputs_offsets, + self.resample.values()): + input_node = x[offset] + input_node = resample(input_node) + nodes.append(input_node) + + if self.weight_method == 'attn': + normalized_weights = torch.softmax( + self.edge_weights.to(dtype=dtype), dim=0) + out = torch.stack(nodes, dim=-1) * normalized_weights + out = torch.sum(out, dim=-1) + elif self.weight_method == 'fastattn': + edge_weights = nn.functional.relu( + self.edge_weights.to(dtype=dtype)) + weights_sum = torch.sum(edge_weights) + weights_norm = weights_sum + 0.0001 + out = torch.stack([(nodes[i] * edge_weights[i]) / weights_norm + for i in range(len(nodes))], + dim=-1) + + out = torch.sum(out, dim=-1) + elif self.weight_method == 'sum': + out = torch.stack(nodes, dim=-1) + out = torch.sum(out, dim=-1) + elif self.weight_method == 'concat': + out = torch.cat(nodes, dim=1) + else: + raise ValueError('unknown weight_method {}'.format( + self.weight_method)) + return out + + +class GiraffeNode(nn.Module): + """ A simple 
wrapper used in place of nn.Sequential for torchscript typing + Handles input type List[Tensor] -> output type Tensor + """ + + def __init__(self, combine: nn.Module, after_combine: nn.Module): + super(GiraffeNode, self).__init__() + self.combine = combine + self.after_combine = after_combine + + def forward(self, x: List[torch.Tensor]) -> torch.Tensor: + combine_feat = self.combine(x) + if combine_feat is None: + return None + else: + return self.after_combine(combine_feat) + + +class GiraffeLayer(nn.Module): + + def __init__(self, + feature_info, + fpn_config, + inner_fpn_channels, + outer_fpn_channels, + num_levels=5, + pad_type='', + downsample=None, + upsample=None, + norm_layer=nn.BatchNorm2d, + act_layer=_ACT_LAYER, + apply_resample_bn=False, + conv_after_downsample=True, + conv_bn_relu_pattern=False, + separable_conv=True, + redundant_bias=False, + merge_type='conv'): + super(GiraffeLayer, self).__init__() + self.num_levels = num_levels + self.conv_bn_relu_pattern = False + + self.feature_info = {} + for idx, feat in enumerate(feature_info): + self.feature_info[idx] = feat + + self.fnode = nn.ModuleList() + reduction_base = feature_info[0]['reduction'] + for i, fnode_cfg in fpn_config.items(): + logging.debug('fnode {} : {}'.format(i, fnode_cfg)) + + if fnode_cfg['is_out'] == 1: + fpn_channels = outer_fpn_channels + else: + fpn_channels = inner_fpn_channels + + reduction = fnode_cfg['reduction'] + fpn_channels_idx = int(math.log(reduction // reduction_base, 2)) + combine = GiraffeCombine( + self.feature_info, + fpn_config, + fpn_channels, + tuple(fnode_cfg['inputs_offsets']), + target_reduction=reduction, + pad_type=pad_type, + downsample=downsample, + upsample=upsample, + norm_layer=norm_layer, + apply_resample_bn=apply_resample_bn, + conv_after_downsample=conv_after_downsample, + redundant_bias=redundant_bias, + weight_method=fnode_cfg['weight_method']) + + after_combine = nn.Sequential() + + in_channels = 0 + out_channels = 0 + for input_offset in fnode_cfg['inputs_offsets']: + in_channels += self.feature_info[input_offset]['num_chs'] + + out_channels = fpn_channels[fpn_channels_idx] + + if merge_type == 'csp': + after_combine.add_module( + 'CspLayer', + CSPLayer( + in_channels, + out_channels, + 2, + shortcut=True, + depthwise=False, + act='silu')) + elif merge_type == 'shuffle': + after_combine.add_module( + 'shuffleBlock', ShuffleBlock(in_channels, in_channels)) + after_combine.add_module( + 'conv1x1', + create_conv2d(in_channels, out_channels, kernel_size=1)) + elif merge_type == 'conv': + after_combine.add_module( + 'conv1x1', + create_conv2d(in_channels, out_channels, kernel_size=1)) + conv_kwargs = dict( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + padding=pad_type, + bias=False, + norm_layer=norm_layer, + act_layer=act_layer) + if not conv_bn_relu_pattern: + conv_kwargs['bias'] = redundant_bias + conv_kwargs['act_layer'] = None + after_combine.add_module('act', act_layer(inplace=True)) + after_combine.add_module( + 'conv', + SeparableConv2d(**conv_kwargs) + if separable_conv else ConvBnAct2d(**conv_kwargs)) + + self.fnode.append( + GiraffeNode(combine=combine, after_combine=after_combine)) + self.feature_info[i] = dict( + num_chs=fpn_channels[fpn_channels_idx], reduction=reduction) + + self.out_feature_info = [] + out_node = list(self.feature_info.keys())[-num_levels::] + for i in out_node: + self.out_feature_info.append(self.feature_info[i]) + + self.feature_info = self.out_feature_info + + def forward(self, x: List[torch.Tensor]): + for 
fn in self.fnode: + x.append(fn(x)) + return x[-self.num_levels::] + + +class GiraffeNeck(nn.Module): + + def __init__(self, min_level, max_level, num_levels, norm_layer, + norm_kwargs, act_type, fpn_config, fpn_name, fpn_channels, + out_fpn_channels, weight_method, depth_multiplier, + width_multiplier, with_backslash, with_slash, + with_skip_connect, skip_connect_type, separable_conv, + feature_info, merge_type, pad_type, downsample_type, + upsample_type, apply_resample_bn, conv_after_downsample, + redundant_bias, conv_bn_relu_pattern, alternate_init): + super(GiraffeNeck, self).__init__() + + self.num_levels = num_levels + self.min_level = min_level + self.in_features = [0, 1, 2, 3, 4, 5, + 6][self.min_level - 1:self.min_level - 1 + + num_levels] + self.alternate_init = alternate_init + norm_layer = norm_layer or nn.BatchNorm2d + if norm_kwargs: + norm_layer = partial(norm_layer, **norm_kwargs) + act_layer = get_act_layer(act_type) or _ACT_LAYER + fpn_config = fpn_config or get_graph_config( + fpn_name, + min_level=min_level, + max_level=max_level, + weight_method=weight_method, + depth_multiplier=depth_multiplier, + with_backslash=with_backslash, + with_slash=with_slash, + with_skip_connect=with_skip_connect, + skip_connect_type=skip_connect_type) + + # width scale + for i in range(len(fpn_channels)): + fpn_channels[i] = int(fpn_channels[i] * width_multiplier) + + self.resample = nn.ModuleDict() + for level in range(num_levels): + if level < len(feature_info): + in_chs = feature_info[level]['num_chs'] + reduction = feature_info[level]['reduction'] + else: + # Adds a coarser level by downsampling the last feature map + reduction_ratio = 2 + self.resample[str(level)] = ResampleFeatureMap( + in_channels=in_chs, + out_channels=feature_info[level - 1]['num_chs'], + pad_type=pad_type, + downsample=downsample_type, + upsample=upsample_type, + norm_layer=norm_layer, + reduction_ratio=reduction_ratio, + apply_bn=apply_resample_bn, + conv_after_downsample=conv_after_downsample, + redundant_bias=redundant_bias, + ) + in_chs = feature_info[level - 1]['num_chs'] + reduction = int(reduction * reduction_ratio) + feature_info.append(dict(num_chs=in_chs, reduction=reduction)) + + self.cell = SequentialList() + logging.debug('building giraffeNeck') + giraffe_layer = GiraffeLayer( + feature_info=feature_info, + fpn_config=fpn_config, + inner_fpn_channels=fpn_channels, + outer_fpn_channels=out_fpn_channels, + num_levels=num_levels, + pad_type=pad_type, + downsample=downsample_type, + upsample=upsample_type, + norm_layer=norm_layer, + act_layer=act_layer, + separable_conv=separable_conv, + apply_resample_bn=apply_resample_bn, + conv_after_downsample=conv_after_downsample, + conv_bn_relu_pattern=conv_bn_relu_pattern, + redundant_bias=redundant_bias, + merge_type=merge_type) + self.cell.add_module('giraffeNeck', giraffe_layer) + feature_info = giraffe_layer.feature_info + + def init_weights(self, pretrained=False): + for n, m in self.named_modules(): + if 'backbone' not in n: + if self.alternate_init: + _init_weight_alt(m, n) + else: + _init_weight(m, n) + + def forward(self, x: List[torch.Tensor]): + if type(x) is tuple: + x = list(x) + x = [x[f] for f in self.in_features] + for resample in self.resample.values(): + x.append(resample(x[-1])) + x = self.cell(x) + return x diff --git a/modelscope/models/cv/tinynas_detection/neck/giraffe_fpn_v2.py b/modelscope/models/cv/tinynas_detection/neck/giraffe_fpn_v2.py new file mode 100644 index 00000000..b710572f --- /dev/null +++ 
b/modelscope/models/cv/tinynas_detection/neck/giraffe_fpn_v2.py @@ -0,0 +1,203 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. + +import torch +import torch.nn as nn + +from ..core.base_ops import BaseConv, CSPLayer, DWConv +from ..core.neck_ops import CSPStage + + +class GiraffeNeckV2(nn.Module): + + def __init__( + self, + depth=1.0, + width=1.0, + in_features=[2, 3, 4], + in_channels=[256, 512, 1024], + out_channels=[256, 512, 1024], + depthwise=False, + act='silu', + spp=True, + reparam_mode=True, + block_name='BasicBlock', + ): + super().__init__() + self.in_features = in_features + self.in_channels = in_channels + Conv = DWConv if depthwise else BaseConv + + reparam_mode = reparam_mode + + self.upsample = nn.Upsample(scale_factor=2, mode='nearest') + + # node x3: input x0, x1 + self.bu_conv13 = Conv( + int(in_channels[1] * width), + int(in_channels[1] * width), + 3, + 2, + act=act) + if reparam_mode: + self.merge_3 = CSPStage( + block_name, + int((in_channels[1] + in_channels[2]) * width), + int(in_channels[2] * width), + round(3 * depth), + act=act, + spp=spp) + else: + self.merge_3 = CSPLayer( + int((in_channels[1] + in_channels[2]) * width), + int(in_channels[2] * width), + round(3 * depth), + False, + depthwise=depthwise, + act=act) + + # node x4: input x1, x2, x3 + self.bu_conv24 = Conv( + int(in_channels[0] * width), + int(in_channels[0] * width), + 3, + 2, + act=act) + if reparam_mode: + self.merge_4 = CSPStage( + block_name, + int((in_channels[0] + in_channels[1] + in_channels[2]) + * width), + int(in_channels[1] * width), + round(3 * depth), + act=act, + spp=spp) + else: + self.merge_4 = CSPLayer( + int((in_channels[0] + in_channels[1] + in_channels[2]) + * width), + int(in_channels[1] * width), + round(3 * depth), + False, + depthwise=depthwise, + act=act) + + # node x5: input x2, x4 + if reparam_mode: + self.merge_5 = CSPStage( + block_name, + int((in_channels[1] + in_channels[0]) * width), + int(out_channels[0] * width), + round(3 * depth), + act=act, + spp=spp) + else: + self.merge_5 = CSPLayer( + int((in_channels[1] + in_channels[0]) * width), + int(out_channels[0] * width), + round(3 * depth), + False, + depthwise=depthwise, + act=act) + + # node x7: input x4, x5 + self.bu_conv57 = Conv( + int(out_channels[0] * width), + int(out_channels[0] * width), + 3, + 2, + act=act) + if reparam_mode: + self.merge_7 = CSPStage( + block_name, + int((out_channels[0] + in_channels[1]) * width), + int(out_channels[1] * width), + round(3 * depth), + act=act, + spp=spp) + else: + self.merge_7 = CSPLayer( + int((out_channels[0] + in_channels[1]) * width), + int(out_channels[1] * width), + round(3 * depth), + False, + depthwise=depthwise, + act=act) + + # node x6: input x3, x4, x7 + self.bu_conv46 = Conv( + int(in_channels[1] * width), + int(in_channels[1] * width), + 3, + 2, + act=act) + self.bu_conv76 = Conv( + int(out_channels[1] * width), + int(out_channels[1] * width), + 3, + 2, + act=act) + if reparam_mode: + self.merge_6 = CSPStage( + block_name, + int((in_channels[1] + out_channels[1] + in_channels[2]) + * width), + int(out_channels[2] * width), + round(3 * depth), + act=act, + spp=spp) + else: + self.merge_6 = CSPLayer( + int((in_channels[1] + out_channels[1] + in_channels[2]) + * width), + int(out_channels[2] * width), + round(3 * depth), + False, + depthwise=depthwise, + act=act) + + def init_weights(self): + pass + + def forward(self, out_features): + 
""" + Args: + inputs: input images. + + Returns: + Tuple[Tensor]: FPN feature. + """ + + # backbone + features = [out_features[f] for f in self.in_features] + [x2, x1, x0] = features + + # node x3 + x13 = self.bu_conv13(x1) + x3 = torch.cat([x0, x13], 1) + x3 = self.merge_3(x3) + + # node x4 + x34 = self.upsample(x3) + x24 = self.bu_conv24(x2) + x4 = torch.cat([x1, x24, x34], 1) + x4 = self.merge_4(x4) + + # node x5 + x45 = self.upsample(x4) + x5 = torch.cat([x2, x45], 1) + x5 = self.merge_5(x5) + + # node x7 + x57 = self.bu_conv57(x5) + x7 = torch.cat([x4, x57], 1) + x7 = self.merge_7(x7) + + # node x6 + x46 = self.bu_conv46(x4) + x76 = self.bu_conv76(x7) + x6 = torch.cat([x3, x46, x76], 1) + x6 = self.merge_6(x6) + + outputs = (x5, x7, x6) + return outputs diff --git a/modelscope/models/cv/tinynas_detection/tinynas_detector.py b/modelscope/models/cv/tinynas_detection/tinynas_detector.py new file mode 100644 index 00000000..e6f144df --- /dev/null +++ b/modelscope/models/cv/tinynas_detection/tinynas_detector.py @@ -0,0 +1,16 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. + +from modelscope.metainfo import Models +from modelscope.models.builder import MODELS +from modelscope.utils.constant import Tasks +from .detector import SingleStageDetector + + +@MODELS.register_module( + Tasks.image_object_detection, module_name=Models.tinynas_detection) +class TinynasDetector(SingleStageDetector): + + def __init__(self, model_dir, *args, **kwargs): + + super(TinynasDetector, self).__init__(model_dir, *args, **kwargs) diff --git a/modelscope/models/cv/tinynas_detection/utils.py b/modelscope/models/cv/tinynas_detection/utils.py new file mode 100644 index 00000000..d67d3a36 --- /dev/null +++ b/modelscope/models/cv/tinynas_detection/utils.py @@ -0,0 +1,30 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. + +import importlib +import os +import sys +from os.path import dirname, join + + +def get_config_by_file(config_file): + try: + sys.path.append(os.path.dirname(config_file)) + current_config = importlib.import_module( + os.path.basename(config_file).split('.')[0]) + exp = current_config.Config() + except Exception: + raise ImportError( + "{} doesn't contains class named 'Config'".format(config_file)) + return exp + + +def parse_config(config_file): + """ + get config object by file. + Args: + config_file (str): file path of config. + """ + assert (config_file is not None), 'plz provide config file' + if config_file is not None: + return get_config_by_file(config_file) diff --git a/modelscope/pipelines/cv/tinynas_detection_pipeline.py b/modelscope/pipelines/cv/tinynas_detection_pipeline.py new file mode 100644 index 00000000..b2063629 --- /dev/null +++ b/modelscope/pipelines/cv/tinynas_detection_pipeline.py @@ -0,0 +1,61 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+
+from typing import Any, Dict
+
+import numpy as np
+import torch
+
+from modelscope.metainfo import Pipelines
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines.base import Input, Pipeline
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.preprocessors import LoadImage
+from modelscope.utils.constant import Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+
+@PIPELINES.register_module(
+    Tasks.image_object_detection, module_name=Pipelines.tinynas_detection)
+class TinynasDetectionPipeline(Pipeline):
+
+    def __init__(self, model: str, **kwargs):
+        """
+        Args:
+            model: model id on the ModelScope hub.
+        """
+        super().__init__(model=model, auto_collate=False, **kwargs)
+        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        self.model.to(self.device)
+        self.model.eval()
+
+    def preprocess(self, input: Input) -> Dict[str, Any]:
+        img = LoadImage.convert_to_ndarray(input)
+        self.img = img
+        # np.float is a deprecated alias of the builtin float; the model
+        # converts to float32 anyway
+        img = img.astype(np.float32)
+        img = self.model.preprocess(img)
+        result = {'img': img.to(self.device)}
+        return result
+
+    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
+        outputs = self.model.inference(input['img'])
+        result = {'data': outputs}
+        return result
+
+    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        bboxes, scores, labels = self.model.postprocess(inputs['data'])
+        if bboxes is None:
+            return None
+        outputs = {
+            OutputKeys.SCORES: scores,
+            OutputKeys.LABELS: labels,
+            OutputKeys.BOXES: bboxes
+        }
+        return outputs
diff --git a/tests/pipelines/test_tinynas_detection.py b/tests/pipelines/test_tinynas_detection.py
new file mode 100644
index 00000000..6b2ecd0b
--- /dev/null
+++ b/tests/pipelines/test_tinynas_detection.py
@@ -0,0 +1,20 @@
+import unittest
+
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.test_utils import test_level
+
+
+class TinynasObjectDetectionTest(unittest.TestCase):
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run(self):
+        tinynas_object_detection = pipeline(
+            Tasks.image_object_detection, model='damo/cv_tinynas_detection')
+        result = tinynas_object_detection(
+            'data/test/images/image_detection.jpg')
+        print(result)
+
+
+if __name__ == '__main__':
+    unittest.main()
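
[Editor's note] End-to-end usage mirrors the test above; a sketch of consuming
the pipeline outputs, with key names following OutputKeys as assembled in
TinynasDetectionPipeline.postprocess:

    from modelscope.outputs import OutputKeys
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    detector = pipeline(
        Tasks.image_object_detection, model='damo/cv_tinynas_detection')
    result = detector('data/test/images/image_detection.jpg')

    # parallel arrays: one score, one label string and one xyxy box per
    # detection
    for score, label, box in zip(result[OutputKeys.SCORES],
                                 result[OutputKeys.LABELS],
                                 result[OutputKeys.BOXES]):
        print(f'{label}: {score:.2f} at {box}')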