
[to #42322933] Add tinynas-detection pipeline and models

Integrate tinynas-detection: add a tinynas object detection pipeline and tinynas models.
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9938220
master · xianzhe.xxz, yingda.chen · 3 years ago
commit 1bac4f3349
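
For context, a minimal usage sketch of the new pipeline (the task constant and model ID below are assumptions, not taken from this diff):

# Hedged sketch: exercising the new tinynas-detection pipeline.
# 'damo/cv_tinynas_detection' is an assumed model ID; substitute a real one.
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

detector = pipeline(Tasks.image_object_detection,
                    model='damo/cv_tinynas_detection')
result = detector('test.jpg')
print(result)  # bounding boxes, scores and class labels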
21 changed files with 3517 additions and 0 deletions
 1. +3 -0    modelscope/metainfo.py
 2. +24 -0   modelscope/models/cv/tinynas_detection/__init__.py
 3. +16 -0   modelscope/models/cv/tinynas_detection/backbone/__init__.py
 4. +126 -0  modelscope/models/cv/tinynas_detection/backbone/darknet.py
 5. +347 -0  modelscope/models/cv/tinynas_detection/backbone/tinynas.py
 6. +2 -0    modelscope/models/cv/tinynas_detection/core/__init__.py
 7. +474 -0  modelscope/models/cv/tinynas_detection/core/base_ops.py
 8. +324 -0  modelscope/models/cv/tinynas_detection/core/neck_ops.py
 9. +205 -0  modelscope/models/cv/tinynas_detection/core/repvgg_block.py
10. +196 -0  modelscope/models/cv/tinynas_detection/core/utils.py
11. +181 -0  modelscope/models/cv/tinynas_detection/detector.py
12. +16 -0   modelscope/models/cv/tinynas_detection/head/__init__.py
13. +361 -0  modelscope/models/cv/tinynas_detection/head/gfocal_v2_tiny.py
14. +16 -0   modelscope/models/cv/tinynas_detection/neck/__init__.py
15. +235 -0  modelscope/models/cv/tinynas_detection/neck/giraffe_config.py
16. +661 -0  modelscope/models/cv/tinynas_detection/neck/giraffe_fpn.py
17. +203 -0  modelscope/models/cv/tinynas_detection/neck/giraffe_fpn_v2.py
18. +16 -0   modelscope/models/cv/tinynas_detection/tinynas_detector.py
19. +30 -0   modelscope/models/cv/tinynas_detection/utils.py
20. +61 -0   modelscope/pipelines/cv/tinynas_detection_pipeline.py
21. +20 -0   tests/pipelines/test_tinynas_detection.py

modelscope/metainfo.py  +3 -0

@@ -9,6 +9,8 @@ class Models(object):

Model name should only contain model info but not task info.
"""
tinynas_detection = 'tinynas-detection'

# vision models
detection = 'detection'
realtime_object_detection = 'realtime-object-detection'
@@ -133,6 +135,7 @@ class Pipelines(object):
image_to_image_generation = 'image-to-image-generation'
skin_retouching = 'unet-skin-retouching'
tinynas_classification = 'tinynas-classification'
tinynas_detection = 'tinynas-detection'
crowd_counting = 'hrnet-crowd-counting'
action_detection = 'ResNetC3D-action-detection'
video_single_object_tracking = 'ostrack-vitb-video-single-object-tracking'


modelscope/models/cv/tinynas_detection/__init__.py  +24 -0

@@ -0,0 +1,24 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.

from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .tinynas_detector import TinynasDetector

else:
_import_structure = {
'tinynas_detector': ['TinynasDetector'],
}

import sys

sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)
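
With the LazyImportModule indirection above, importing the package stays cheap; the torch-backed submodule is only loaded on first attribute access. A small sketch of the effect:

# Hedged sketch: the real import of .tinynas_detector happens lazily,
# on first attribute access rather than at package import time.
from modelscope.models.cv import tinynas_detection

detector_cls = tinynas_detection.TinynasDetector  # triggers the lazy import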

modelscope/models/cv/tinynas_detection/backbone/__init__.py  +16 -0

@@ -0,0 +1,16 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.

import copy

from .darknet import CSPDarknet
from .tinynas import load_tinynas_net


def build_backbone(cfg):
    backbone_cfg = copy.deepcopy(cfg)
    name = backbone_cfg.pop('name')
    if name == 'CSPDarknet':
        return CSPDarknet(**backbone_cfg)
    elif name == 'TinyNAS':
        return load_tinynas_net(backbone_cfg)
    else:
        raise NotImplementedError('Unsupported backbone: {}'.format(name))
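
A hedged sketch of the config build_backbone consumes; a plain dict is enough for the CSPDarknet branch (the TinyNAS branch instead needs an attribute-style config, since load_tinynas_net reads fields such as backbone_cfg.net_structure_str). Values are illustrative:

# Hedged sketch: illustrative config for the CSPDarknet branch.
cfg = {
    'name': 'CSPDarknet',  # popped by build_backbone to select the class
    'dep_mul': 0.33,       # depth multiplier
    'wid_mul': 0.5,        # width multiplier
    'out_features': ('dark3', 'dark4', 'dark5'),
}
backbone = build_backbone(cfg)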

modelscope/models/cv/tinynas_detection/backbone/darknet.py  +126 -0

@@ -0,0 +1,126 @@
# Copyright (c) Megvii Inc. All rights reserved.
# Copyright (c) Alibaba, Inc. and its affiliates.
# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.

import torch
from torch import nn

from ..core.base_ops import (BaseConv, CSPLayer, DWConv, Focus, ResLayer,
SPPBottleneck)


class CSPDarknet(nn.Module):

def __init__(
self,
dep_mul,
wid_mul,
out_features=('dark3', 'dark4', 'dark5'),
depthwise=False,
act='silu',
reparam=False,
):
super(CSPDarknet, self).__init__()
assert out_features, 'please provide output features of Darknet'
self.out_features = out_features
Conv = DWConv if depthwise else BaseConv

base_channels = int(wid_mul * 64) # 64
base_depth = max(round(dep_mul * 3), 1) # 3

# stem
# self.stem = Focus(3, base_channels, ksize=3, act=act)
self.stem = Focus(3, base_channels, 3, act=act)

# dark2
self.dark2 = nn.Sequential(
Conv(base_channels, base_channels * 2, 3, 2, act=act),
CSPLayer(
base_channels * 2,
base_channels * 2,
n=base_depth,
depthwise=depthwise,
act=act,
reparam=reparam,
),
)

# dark3
self.dark3 = nn.Sequential(
Conv(base_channels * 2, base_channels * 4, 3, 2, act=act),
CSPLayer(
base_channels * 4,
base_channels * 4,
n=base_depth * 3,
depthwise=depthwise,
act=act,
reparam=reparam,
),
)

# dark4
self.dark4 = nn.Sequential(
Conv(base_channels * 4, base_channels * 8, 3, 2, act=act),
CSPLayer(
base_channels * 8,
base_channels * 8,
n=base_depth * 3,
depthwise=depthwise,
act=act,
reparam=reparam,
),
)

# dark5
self.dark5 = nn.Sequential(
Conv(base_channels * 8, base_channels * 16, 3, 2, act=act),
SPPBottleneck(
base_channels * 16, base_channels * 16, activation=act),
CSPLayer(
base_channels * 16,
base_channels * 16,
n=base_depth,
shortcut=False,
depthwise=depthwise,
act=act,
reparam=reparam,
),
)

def init_weights(self, pretrain=None):

if pretrain is None:
return
else:
pretrained_dict = torch.load(
pretrain, map_location='cpu')['state_dict']
new_params = self.state_dict().copy()
for k, v in pretrained_dict.items():
ks = k.split('.')
if ks[0] == 'fc' or ks[-1] == 'total_ops' or ks[
-1] == 'total_params':
continue
else:
new_params[k] = v

self.load_state_dict(new_params)
print(f' load pretrain backbone from {pretrain}')

def forward(self, x):
outputs = {}
x = self.stem(x)
outputs['stem'] = x
x = self.dark2(x)
outputs['dark2'] = x
x = self.dark3(x)
outputs['dark3'] = x
x = self.dark4(x)
outputs['dark4'] = x
x = self.dark5(x)
outputs['dark5'] = x
features_out = [
outputs['stem'], outputs['dark2'], outputs['dark3'],
outputs['dark4'], outputs['dark5']
]

return features_out
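
A quick shape check of the five feature maps forward returns (a sketch; Focus plus the four dark stages give strides 2, 4, 8, 16 and 32):

# Hedged sketch: verify CSPDarknet's multi-scale outputs on a dummy input.
import torch

net = CSPDarknet(dep_mul=0.33, wid_mul=0.5)
net.eval()
feats = net(torch.randn(1, 3, 256, 256))
for name, f in zip(['stem', 'dark2', 'dark3', 'dark4', 'dark5'], feats):
    print(name, tuple(f.shape))
# (1, 32, 128, 128) ... (1, 512, 8, 8) for this width multiplier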

modelscope/models/cv/tinynas_detection/backbone/tinynas.py  +347 -0

@@ -0,0 +1,347 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.

import torch
import torch.nn as nn

from ..core.base_ops import Focus, SPPBottleneck, get_activation
from ..core.repvgg_block import RepVggBlock


class ConvKXBN(nn.Module):

def __init__(self, in_c, out_c, kernel_size, stride):
super(ConvKXBN, self).__init__()
self.conv1 = nn.Conv2d(
in_c,
out_c,
kernel_size,
stride, (kernel_size - 1) // 2,
groups=1,
bias=False)
self.bn1 = nn.BatchNorm2d(out_c)

def forward(self, x):
return self.bn1(self.conv1(x))


class ConvKXBNRELU(nn.Module):

def __init__(self, in_c, out_c, kernel_size, stride, act='silu'):
super(ConvKXBNRELU, self).__init__()
self.conv = ConvKXBN(in_c, out_c, kernel_size, stride)
if act is None:
self.activation_function = torch.relu
else:
self.activation_function = get_activation(act)

def forward(self, x):
output = self.conv(x)
return self.activation_function(output)


class ResConvK1KX(nn.Module):

def __init__(self,
in_c,
out_c,
btn_c,
kernel_size,
stride,
force_resproj=False,
act='silu'):
super(ResConvK1KX, self).__init__()
self.stride = stride
self.conv1 = ConvKXBN(in_c, btn_c, 1, 1)
self.conv2 = RepVggBlock(
btn_c, out_c, kernel_size, stride, act='identity')

if act is None:
self.activation_function = torch.relu
else:
self.activation_function = get_activation(act)

if stride == 2:
self.residual_downsample = nn.AvgPool2d(kernel_size=2, stride=2)
else:
self.residual_downsample = nn.Identity()

if in_c != out_c or force_resproj:
self.residual_proj = ConvKXBN(in_c, out_c, 1, 1)
else:
self.residual_proj = nn.Identity()

def forward(self, x):
if self.stride != 2:
reslink = self.residual_downsample(x)
reslink = self.residual_proj(reslink)

output = x
output = self.conv1(output)
output = self.activation_function(output)
output = self.conv2(output)
if self.stride != 2:
output = output + reslink
output = self.activation_function(output)

return output


class SuperResConvK1KX(nn.Module):

def __init__(self,
in_c,
out_c,
btn_c,
kernel_size,
stride,
num_blocks,
with_spp=False,
act='silu'):
super(SuperResConvK1KX, self).__init__()
if act is None:
self.act = torch.relu
else:
self.act = get_activation(act)
self.block_list = nn.ModuleList()
for block_id in range(num_blocks):
if block_id == 0:
in_channels = in_c
out_channels = out_c
this_stride = stride
                force_resproj = False  # as part of a CSPLayer, this flag is not needed
this_kernel_size = kernel_size
else:
in_channels = out_c
out_channels = out_c
this_stride = 1
force_resproj = False
this_kernel_size = kernel_size
the_block = ResConvK1KX(
in_channels,
out_channels,
btn_c,
this_kernel_size,
this_stride,
force_resproj,
act=act)
self.block_list.append(the_block)
if block_id == 0 and with_spp:
self.block_list.append(
SPPBottleneck(out_channels, out_channels))

def forward(self, x):
output = x
for block in self.block_list:
output = block(output)
return output


class ResConvKXKX(nn.Module):

def __init__(self,
in_c,
out_c,
btn_c,
kernel_size,
stride,
force_resproj=False,
act='silu'):
super(ResConvKXKX, self).__init__()
self.stride = stride
if self.stride == 2:
self.downsampler = ConvKXBNRELU(in_c, out_c, 3, 2, act=act)
else:
self.conv1 = ConvKXBN(in_c, btn_c, kernel_size, 1)
self.conv2 = RepVggBlock(
btn_c, out_c, kernel_size, stride, act='identity')

if act is None:
self.activation_function = torch.relu
else:
self.activation_function = get_activation(act)

if stride == 2:
self.residual_downsample = nn.AvgPool2d(
kernel_size=2, stride=2)
else:
self.residual_downsample = nn.Identity()

if in_c != out_c or force_resproj:
self.residual_proj = ConvKXBN(in_c, out_c, 1, 1)
else:
self.residual_proj = nn.Identity()

def forward(self, x):
if self.stride == 2:
return self.downsampler(x)
reslink = self.residual_downsample(x)
reslink = self.residual_proj(reslink)

output = x
output = self.conv1(output)
output = self.activation_function(output)
output = self.conv2(output)

output = output + reslink
output = self.activation_function(output)

return output


class SuperResConvKXKX(nn.Module):

def __init__(self,
in_c,
out_c,
btn_c,
kernel_size,
stride,
num_blocks,
with_spp=False,
act='silu'):
super(SuperResConvKXKX, self).__init__()
if act is None:
self.act = torch.relu
else:
self.act = get_activation(act)
self.block_list = nn.ModuleList()
for block_id in range(num_blocks):
if block_id == 0:
in_channels = in_c
out_channels = out_c
this_stride = stride
                force_resproj = False  # as part of a CSPLayer, this flag is not needed
this_kernel_size = kernel_size
else:
in_channels = out_c
out_channels = out_c
this_stride = 1
force_resproj = False
this_kernel_size = kernel_size
the_block = ResConvKXKX(
in_channels,
out_channels,
btn_c,
this_kernel_size,
this_stride,
force_resproj,
act=act)
self.block_list.append(the_block)
if block_id == 0 and with_spp:
self.block_list.append(
SPPBottleneck(out_channels, out_channels))

def forward(self, x):
output = x
for block in self.block_list:
output = block(output)
return output


class TinyNAS(nn.Module):

def __init__(self,
structure_info=None,
out_indices=[0, 1, 2, 4, 5],
out_channels=[None, None, 128, 256, 512],
with_spp=False,
use_focus=False,
need_conv1=True,
act='silu'):
super(TinyNAS, self).__init__()
assert len(out_indices) == len(out_channels)
self.out_indices = out_indices
self.need_conv1 = need_conv1

self.block_list = nn.ModuleList()
if need_conv1:
self.conv1_list = nn.ModuleList()
for idx, block_info in enumerate(structure_info):
the_block_class = block_info['class']
if the_block_class == 'ConvKXBNRELU':
if use_focus:
the_block = Focus(block_info['in'], block_info['out'],
block_info['k'])
else:
the_block = ConvKXBNRELU(
block_info['in'],
block_info['out'],
block_info['k'],
block_info['s'],
act=act)
self.block_list.append(the_block)
elif the_block_class == 'SuperResConvK1KX':
spp = with_spp if idx == len(structure_info) - 1 else False
the_block = SuperResConvK1KX(
block_info['in'],
block_info['out'],
block_info['btn'],
block_info['k'],
block_info['s'],
block_info['L'],
spp,
act=act)
self.block_list.append(the_block)
elif the_block_class == 'SuperResConvKXKX':
spp = with_spp if idx == len(structure_info) - 1 else False
the_block = SuperResConvKXKX(
block_info['in'],
block_info['out'],
block_info['btn'],
block_info['k'],
block_info['s'],
block_info['L'],
spp,
act=act)
self.block_list.append(the_block)
if need_conv1:
if idx in self.out_indices and out_channels[
self.out_indices.index(idx)] is not None:
self.conv1_list.append(
nn.Conv2d(block_info['out'],
out_channels[self.out_indices.index(idx)],
1))
else:
self.conv1_list.append(None)

def init_weights(self, pretrain=None):
pass

def forward(self, x):
output = x
stage_feature_list = []
for idx, block in enumerate(self.block_list):
output = block(output)
if idx in self.out_indices:
if self.need_conv1 and self.conv1_list[idx] is not None:
true_out = self.conv1_list[idx](output)
stage_feature_list.append(true_out)
else:
stage_feature_list.append(output)
return stage_feature_list


def load_tinynas_net(backbone_cfg):
# load masternet model to path
import ast

struct_str = ''.join([x.strip() for x in backbone_cfg.net_structure_str])
struct_info = ast.literal_eval(struct_str)
for layer in struct_info:
if 'nbitsA' in layer:
del layer['nbitsA']
if 'nbitsW' in layer:
del layer['nbitsW']

model = TinyNAS(
structure_info=struct_info,
out_indices=backbone_cfg.out_indices,
out_channels=backbone_cfg.out_channels,
with_spp=backbone_cfg.with_spp,
use_focus=backbone_cfg.use_focus,
act=backbone_cfg.act,
need_conv1=backbone_cfg.need_conv1,
)

return model
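
net_structure_str holds a Python-literal list of per-block dicts, so ast.literal_eval is enough to parse it. A sketch with an illustrative two-block structure (keys follow the block classes above):

# Hedged sketch: parsing an illustrative TinyNAS structure string.
import ast

net_structure_str = (
    "[{'class': 'ConvKXBNRELU', 'in': 3, 'out': 32, 'k': 3, 's': 2},"
    " {'class': 'SuperResConvK1KX', 'in': 32, 'out': 64,"
    " 'btn': 32, 'k': 3, 's': 2, 'L': 1}]")
struct_info = ast.literal_eval(net_structure_str)
print(struct_info[1]['class'])  # SuperResConvK1KX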

modelscope/models/cv/tinynas_detection/core/__init__.py  +2 -0

@@ -0,0 +1,2 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.

modelscope/models/cv/tinynas_detection/core/base_ops.py  +474 -0

@@ -0,0 +1,474 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from .repvgg_block import RepVggBlock


class SiLU(nn.Module):
"""export-friendly version of nn.SiLU()"""

@staticmethod
def forward(x):
return x * torch.sigmoid(x)


def get_activation(name='silu', inplace=True):
if name == 'silu':
module = nn.SiLU(inplace=inplace)
elif name == 'relu':
module = nn.ReLU(inplace=inplace)
elif name == 'lrelu':
module = nn.LeakyReLU(0.1, inplace=inplace)
else:
raise AttributeError('Unsupported act type: {}'.format(name))
return module


def get_norm(name, out_channels, inplace=True):
    if name == 'bn':
        module = nn.BatchNorm2d(out_channels)
    elif name == 'gn':
        module = nn.GroupNorm(num_channels=out_channels, num_groups=32)
    else:
        raise AttributeError('Unsupported norm type: {}'.format(name))
    return module


class BaseConv(nn.Module):
"""A Conv2d -> Batchnorm -> silu/leaky relu block"""

def __init__(self,
in_channels,
out_channels,
ksize,
stride=1,
groups=1,
bias=False,
act='silu',
norm='bn'):
super().__init__()
# same padding
pad = (ksize - 1) // 2
self.conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size=ksize,
stride=stride,
padding=pad,
groups=groups,
bias=bias,
)
if norm is not None:
self.bn = get_norm(norm, out_channels, inplace=True)
if act is not None:
self.act = get_activation(act, inplace=True)
self.with_norm = norm is not None
self.with_act = act is not None

def forward(self, x):
x = self.conv(x)
if self.with_norm:
# x = self.norm(x)
x = self.bn(x)
if self.with_act:
x = self.act(x)
return x

def fuseforward(self, x):
return self.act(self.conv(x))


class DepthWiseConv(nn.Module):

def __init__(self,
in_channels,
out_channels,
ksize,
stride=1,
groups=None,
bias=False,
act='silu',
norm='bn'):
super().__init__()
padding = (ksize - 1) // 2
self.depthwise = nn.Conv2d(
in_channels,
in_channels,
kernel_size=ksize,
stride=stride,
padding=padding,
groups=in_channels,
bias=bias,
)

self.pointwise = nn.Conv2d(
in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0,
bias=bias)
if norm is not None:
self.dwnorm = get_norm(norm, in_channels, inplace=True)
self.pwnorm = get_norm(norm, out_channels, inplace=True)
if act is not None:
self.act = get_activation(act, inplace=True)

self.with_norm = norm is not None
self.with_act = act is not None
self.order = ['depthwise', 'dwnorm', 'pointwise', 'act']

def forward(self, x):

for layer_name in self.order:
layer = self.__getattr__(layer_name)
if layer is not None:
x = layer(x)
return x


class DWConv(nn.Module):
"""Depthwise Conv + Conv"""

def __init__(self, in_channels, out_channels, ksize, stride=1, act='silu'):
super().__init__()
self.dconv = BaseConv(
in_channels,
in_channels,
ksize=ksize,
stride=stride,
groups=in_channels,
act=act,
)
self.pconv = BaseConv(
in_channels, out_channels, ksize=1, stride=1, groups=1, act=act)

def forward(self, x):
x = self.dconv(x)
return self.pconv(x)


class Bottleneck(nn.Module):
# Standard bottleneck
def __init__(
self,
in_channels,
out_channels,
shortcut=True,
expansion=0.5,
depthwise=False,
act='silu',
reparam=False,
):
super().__init__()
hidden_channels = int(out_channels * expansion)
Conv = DWConv if depthwise else BaseConv
k_conv1 = 3 if reparam else 1
self.conv1 = BaseConv(
in_channels, hidden_channels, k_conv1, stride=1, act=act)
if reparam:
self.conv2 = RepVggBlock(
hidden_channels, out_channels, 3, stride=1, act=act)
else:
self.conv2 = Conv(
hidden_channels, out_channels, 3, stride=1, act=act)
self.use_add = shortcut and in_channels == out_channels

def forward(self, x):
y = self.conv2(self.conv1(x))
if self.use_add:
y = y + x
return y


class ResLayer(nn.Module):
'Residual layer with `in_channels` inputs.'

def __init__(self, in_channels: int):
super().__init__()
mid_channels = in_channels // 2
self.layer1 = BaseConv(
in_channels, mid_channels, ksize=1, stride=1, act='lrelu')
self.layer2 = BaseConv(
mid_channels, in_channels, ksize=3, stride=1, act='lrelu')

def forward(self, x):
out = self.layer2(self.layer1(x))
return x + out


class SPPBottleneck(nn.Module):
"""Spatial pyramid pooling layer used in YOLOv3-SPP"""

def __init__(self,
in_channels,
out_channels,
kernel_sizes=(5, 9, 13),
activation='silu'):
super().__init__()
hidden_channels = in_channels // 2
self.conv1 = BaseConv(
in_channels, hidden_channels, 1, stride=1, act=activation)
self.m = nn.ModuleList([
nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2)
for ks in kernel_sizes
])
conv2_channels = hidden_channels * (len(kernel_sizes) + 1)
self.conv2 = BaseConv(
conv2_channels, out_channels, 1, stride=1, act=activation)

def forward(self, x):
x = self.conv1(x)
x = torch.cat([x] + [m(x) for m in self.m], dim=1)
x = self.conv2(x)
return x


class CSPLayer(nn.Module):
"""C3 in yolov5, CSP Bottleneck with 3 convolutions"""

def __init__(
self,
in_channels,
out_channels,
n=1,
shortcut=True,
expansion=0.5,
depthwise=False,
act='silu',
reparam=False,
):
"""
Args:
in_channels (int): input channels.
out_channels (int): output channels.
n (int): number of Bottlenecks. Default value: 1.
"""
# ch_in, ch_out, number, shortcut, groups, expansion
super().__init__()
hidden_channels = int(out_channels * expansion) # hidden channels
self.conv1 = BaseConv(
in_channels, hidden_channels, 1, stride=1, act=act)
self.conv2 = BaseConv(
in_channels, hidden_channels, 1, stride=1, act=act)
self.conv3 = BaseConv(
2 * hidden_channels, out_channels, 1, stride=1, act=act)
module_list = [
Bottleneck(
hidden_channels,
hidden_channels,
shortcut,
1.0,
depthwise,
act=act,
reparam=reparam) for _ in range(n)
]
self.m = nn.Sequential(*module_list)

def forward(self, x):
x_1 = self.conv1(x)
x_2 = self.conv2(x)
x_1 = self.m(x_1)
x = torch.cat((x_1, x_2), dim=1)
return self.conv3(x)


class Focus(nn.Module):
"""Focus width and height information into channel space."""

def __init__(self,
in_channels,
out_channels,
ksize=1,
stride=1,
act='silu'):
super().__init__()
self.conv = BaseConv(
in_channels * 4, out_channels, ksize, stride, act=act)

def forward(self, x):
# shape of x (b,c,w,h) -> y(b,4c,w/2,h/2)
patch_top_left = x[..., ::2, ::2]
patch_top_right = x[..., ::2, 1::2]
patch_bot_left = x[..., 1::2, ::2]
patch_bot_right = x[..., 1::2, 1::2]
x = torch.cat(
(
patch_top_left,
patch_bot_left,
patch_top_right,
patch_bot_right,
),
dim=1,
)
return self.conv(x)
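
Focus is a space-to-depth transform: each 2x2 spatial patch is moved into the channel axis before the convolution, so a (b, c, h, w) input becomes (b, 4c, h/2, w/2). A quick check of the slicing (sketch):

# Hedged sketch: the four strided slices quadruple channels and halve H, W.
import torch

x = torch.randn(1, 3, 64, 64)
patches = torch.cat((x[..., ::2, ::2], x[..., 1::2, ::2],
                     x[..., ::2, 1::2], x[..., 1::2, 1::2]), dim=1)
print(patches.shape)  # torch.Size([1, 12, 32, 32])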


class fast_Focus(nn.Module):

def __init__(self,
in_channels,
out_channels,
ksize=1,
stride=1,
act='silu'):
        super(fast_Focus, self).__init__()
self.conv1 = self.focus_conv(w1=1.0)
self.conv2 = self.focus_conv(w3=1.0)
self.conv3 = self.focus_conv(w2=1.0)
self.conv4 = self.focus_conv(w4=1.0)

self.conv = BaseConv(
in_channels * 4, out_channels, ksize, stride, act=act)

def forward(self, x):
return self.conv(
torch.cat(
[self.conv1(x),
self.conv2(x),
self.conv3(x),
self.conv4(x)], 1))

def focus_conv(self, w1=0.0, w2=0.0, w3=0.0, w4=0.0):
conv = nn.Conv2d(3, 3, 2, 2, groups=3, bias=False)
conv.weight = self.init_weights_constant(w1, w2, w3, w4)
conv.weight.requires_grad = False
return conv

def init_weights_constant(self, w1=0.0, w2=0.0, w3=0.0, w4=0.0):
return nn.Parameter(
torch.tensor([[[[w1, w2], [w3, w4]]], [[[w1, w2], [w3, w4]]],
[[[w1, w2], [w3, w4]]]]))


# shufflenet block
def channel_shuffle(x, groups=2):
bat_size, channels, w, h = x.shape
group_c = channels // groups
x = x.view(bat_size, groups, group_c, w, h)
x = torch.transpose(x, 1, 2).contiguous()
x = x.view(bat_size, -1, w, h)
return x
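
channel_shuffle interleaves the channels of the two groups so information can cross branch boundaries. A tiny demonstration (sketch):

# Hedged sketch: with 2 groups, channel order [0, 1, 2, 3] becomes [0, 2, 1, 3].
import torch

x = torch.arange(4.0).view(1, 4, 1, 1)
print(channel_shuffle(x, groups=2).flatten().tolist())  # [0.0, 2.0, 1.0, 3.0]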


def conv_1x1_bn(in_c, out_c, stride=1):
return nn.Sequential(
nn.Conv2d(in_c, out_c, 1, stride, 0, bias=False),
nn.BatchNorm2d(out_c), nn.ReLU(True))


def conv_bn(in_c, out_c, stride=2):
return nn.Sequential(
nn.Conv2d(in_c, out_c, 3, stride, 1, bias=False),
nn.BatchNorm2d(out_c), nn.ReLU(True))


class ShuffleBlock(nn.Module):

def __init__(self, in_c, out_c, downsample=False):
super(ShuffleBlock, self).__init__()
self.downsample = downsample
half_c = out_c // 2
if downsample:
self.branch1 = nn.Sequential(
                # 3*3 dw conv (stride 1 here; original stride-2 variant commented out)
                # nn.Conv2d(in_c, in_c, 3, 2, 1, groups=in_c, bias=False),
nn.Conv2d(in_c, in_c, 3, 1, 1, groups=in_c, bias=False),
nn.BatchNorm2d(in_c),
# 1*1 pw conv
nn.Conv2d(in_c, half_c, 1, 1, 0, bias=False),
nn.BatchNorm2d(half_c),
nn.ReLU(True))

self.branch2 = nn.Sequential(
# 1*1 pw conv
nn.Conv2d(in_c, half_c, 1, 1, 0, bias=False),
nn.BatchNorm2d(half_c),
nn.ReLU(True),
                # 3*3 dw conv (stride 1 here; original stride-2 variant commented out)
                # nn.Conv2d(half_c, half_c, 3, 2, 1, groups=half_c, bias=False),
nn.Conv2d(half_c, half_c, 3, 1, 1, groups=half_c, bias=False),
nn.BatchNorm2d(half_c),
# 1*1 pw conv
nn.Conv2d(half_c, half_c, 1, 1, 0, bias=False),
nn.BatchNorm2d(half_c),
nn.ReLU(True))
else:
# in_c = out_c
assert in_c == out_c

self.branch2 = nn.Sequential(
# 1*1 pw conv
nn.Conv2d(half_c, half_c, 1, 1, 0, bias=False),
nn.BatchNorm2d(half_c),
nn.ReLU(True),
# 3*3 dw conv, stride = 1
nn.Conv2d(half_c, half_c, 3, 1, 1, groups=half_c, bias=False),
nn.BatchNorm2d(half_c),
# 1*1 pw conv
nn.Conv2d(half_c, half_c, 1, 1, 0, bias=False),
nn.BatchNorm2d(half_c),
nn.ReLU(True))

def forward(self, x):
out = None
if self.downsample:
# if it is downsampling, we don't need to do channel split
out = torch.cat((self.branch1(x), self.branch2(x)), 1)
else:
# channel split
channels = x.shape[1]
c = channels // 2
x1 = x[:, :c, :, :]
x2 = x[:, c:, :, :]
out = torch.cat((x1, self.branch2(x2)), 1)
return channel_shuffle(out, 2)


class ShuffleCSPLayer(nn.Module):
"""C3 in yolov5, CSP Bottleneck with 3 convolutions"""

def __init__(
self,
in_channels,
out_channels,
n=1,
shortcut=True,
expansion=0.5,
depthwise=False,
act='silu',
):
"""
Args:
in_channels (int): input channels.
out_channels (int): output channels.
n (int): number of Bottlenecks. Default value: 1.
"""
# ch_in, ch_out, number, shortcut, groups, expansion
super().__init__()
hidden_channels = int(out_channels * expansion) # hidden channels
self.conv1 = BaseConv(
in_channels, hidden_channels, 1, stride=1, act=act)
self.conv2 = BaseConv(
in_channels, hidden_channels, 1, stride=1, act=act)
module_list = [
Bottleneck(
hidden_channels,
hidden_channels,
shortcut,
1.0,
depthwise,
act=act) for _ in range(n)
]
self.m = nn.Sequential(*module_list)

def forward(self, x):
x_1 = self.conv1(x)
x_2 = self.conv2(x)
x_1 = self.m(x_1)
x = torch.cat((x_1, x_2), dim=1)
# add channel shuffle
return channel_shuffle(x, 2)

modelscope/models/cv/tinynas_detection/core/neck_ops.py  +324 -0

@@ -0,0 +1,324 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class Swish(nn.Module):

def __init__(self, inplace=True):
super(Swish, self).__init__()
self.inplace = inplace

def forward(self, x):
if self.inplace:
x.mul_(F.sigmoid(x))
return x
else:
return x * F.sigmoid(x)


def get_activation(name='silu', inplace=True):
if name is None:
return nn.Identity()

if isinstance(name, str):
if name == 'silu':
module = nn.SiLU(inplace=inplace)
elif name == 'relu':
module = nn.ReLU(inplace=inplace)
elif name == 'lrelu':
module = nn.LeakyReLU(0.1, inplace=inplace)
elif name == 'swish':
module = Swish(inplace=inplace)
elif name == 'hardsigmoid':
module = nn.Hardsigmoid(inplace=inplace)
else:
raise AttributeError('Unsupported act type: {}'.format(name))
return module
elif isinstance(name, nn.Module):
return name
else:
raise AttributeError('Unsupported act type: {}'.format(name))


class ConvBNLayer(nn.Module):

def __init__(self,
ch_in,
ch_out,
filter_size=3,
stride=1,
groups=1,
padding=0,
act=None):
super(ConvBNLayer, self).__init__()
self.conv = nn.Conv2d(
in_channels=ch_in,
out_channels=ch_out,
kernel_size=filter_size,
stride=stride,
padding=padding,
groups=groups,
bias=False)
self.bn = nn.BatchNorm2d(ch_out, )
self.act = get_activation(act, inplace=True)

def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.act(x)

return x


class RepVGGBlock(nn.Module):

def __init__(self, ch_in, ch_out, act='relu', deploy=False):
super(RepVGGBlock, self).__init__()
self.ch_in = ch_in
self.ch_out = ch_out
self.deploy = deploy
self.in_channels = ch_in
self.groups = 1
if self.deploy is False:
self.rbr_dense = ConvBNLayer(
ch_in, ch_out, 3, stride=1, padding=1, act=None)
self.rbr_1x1 = ConvBNLayer(
ch_in, ch_out, 1, stride=1, padding=0, act=None)
# self.rbr_identity = nn.BatchNorm2d(num_features=ch_in) if ch_out == ch_in else None
self.rbr_identity = None
else:
self.rbr_reparam = nn.Conv2d(
in_channels=self.ch_in,
out_channels=self.ch_out,
kernel_size=3,
stride=1,
padding=1,
groups=1)
self.act = get_activation(act) if act is None or isinstance(
act, (str, dict)) else act

    def forward(self, x):
        if self.deploy:
            y = self.rbr_reparam(x)
else:
if self.rbr_identity is None:
y = self.rbr_dense(x) + self.rbr_1x1(x)
else:
y = self.rbr_dense(x) + self.rbr_1x1(x) + self.rbr_identity(x)

y = self.act(y)
return y

    def switch_to_deploy(self):
        if hasattr(self, 'rbr_reparam'):
            return
        self.rbr_reparam = nn.Conv2d(
            in_channels=self.ch_in,
            out_channels=self.ch_out,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=1)
kernel, bias = self.get_equivalent_kernel_bias()
self.rbr_reparam.weight.data = kernel
self.rbr_reparam.bias.data = bias
for para in self.parameters():
para.detach_()
# self.__delattr__(self.rbr_dense)
# self.__delattr__(self.rbr_1x1)
self.__delattr__('rbr_dense')
self.__delattr__('rbr_1x1')
if hasattr(self, 'rbr_identity'):
self.__delattr__('rbr_identity')
if hasattr(self, 'id_tensor'):
self.__delattr__('id_tensor')
self.deploy = True

def get_equivalent_kernel_bias(self):
kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
return kernel3x3 + self._pad_1x1_to_3x3_tensor(
kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid

def _pad_1x1_to_3x3_tensor(self, kernel1x1):
if kernel1x1 is None:
return 0
else:
return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])

def _fuse_bn_tensor(self, branch):
if branch is None:
return 0, 0
# if isinstance(branch, nn.Sequential):
if isinstance(branch, ConvBNLayer):
kernel = branch.conv.weight
running_mean = branch.bn.running_mean
running_var = branch.bn.running_var
gamma = branch.bn.weight
beta = branch.bn.bias
eps = branch.bn.eps
else:
assert isinstance(branch, nn.BatchNorm2d)
if not hasattr(self, 'id_tensor'):
input_dim = self.in_channels // self.groups
kernel_value = np.zeros((self.in_channels, input_dim, 3, 3),
dtype=np.float32)
for i in range(self.in_channels):
kernel_value[i, i % input_dim, 1, 1] = 1
self.id_tensor = torch.from_numpy(kernel_value).to(
branch.weight.device)
kernel = self.id_tensor
running_mean = branch.running_mean
running_var = branch.running_var
gamma = branch.weight
beta = branch.bias
eps = branch.eps
std = (running_var + eps).sqrt()
t = (gamma / std).reshape(-1, 1, 1, 1)
return kernel * t, beta - running_mean * gamma / std


class BasicBlock(nn.Module):

def __init__(self, ch_in, ch_out, act='relu', shortcut=True):
super(BasicBlock, self).__init__()
assert ch_in == ch_out
# self.conv1 = ConvBNLayer(ch_in, ch_out, 3, stride=1, padding=1, act=act)
# self.conv1 = ConvBNLayer(ch_in, ch_out, 1, stride=1, padding=0, act=act)
self.conv2 = RepVGGBlock(ch_in, ch_out, act=act)
self.shortcut = shortcut

def forward(self, x):
# y = self.conv1(x)
y = self.conv2(x)
if self.shortcut:
return x + y
else:
return y


class BasicBlock_3x3(nn.Module):

def __init__(self, ch_in, ch_out, act='relu', shortcut=True):
super(BasicBlock_3x3, self).__init__()
assert ch_in == ch_out
self.conv1 = ConvBNLayer(
ch_in, ch_out, 3, stride=1, padding=1, act=act)
# self.conv1 = ConvBNLayer(ch_in, ch_out, 1, stride=1, padding=0, act=act)
self.conv2 = RepVGGBlock(ch_in, ch_out, act=act)
self.shortcut = shortcut

def forward(self, x):
y = self.conv1(x)
y = self.conv2(y)
if self.shortcut:
return x + y
else:
return y


class BasicBlock_3x3_Reverse(nn.Module):

def __init__(self, ch_in, ch_out, act='relu', shortcut=True):
super(BasicBlock_3x3_Reverse, self).__init__()
assert ch_in == ch_out
self.conv1 = ConvBNLayer(
ch_in, ch_out, 3, stride=1, padding=1, act=act)
# self.conv1 = ConvBNLayer(ch_in, ch_out, 1, stride=1, padding=0, act=act)
self.conv2 = RepVGGBlock(ch_in, ch_out, act=act)
self.shortcut = shortcut

def forward(self, x):
y = self.conv2(x)
y = self.conv1(y)
if self.shortcut:
return x + y
else:
return y


class SPP(nn.Module):

def __init__(
self,
ch_in,
ch_out,
k,
pool_size,
act='swish',
):
super(SPP, self).__init__()
self.pool = []
for i, size in enumerate(pool_size):
pool = nn.MaxPool2d(
kernel_size=size, stride=1, padding=size // 2, ceil_mode=False)
self.add_module('pool{}'.format(i), pool)
self.pool.append(pool)
self.conv = ConvBNLayer(ch_in, ch_out, k, padding=k // 2, act=act)

def forward(self, x):
outs = [x]

for pool in self.pool:
outs.append(pool(x))
y = torch.cat(outs, axis=1)

y = self.conv(y)
return y


class CSPStage(nn.Module):

def __init__(self, block_fn, ch_in, ch_out, n, act='swish', spp=False):
super(CSPStage, self).__init__()

ch_mid = int(ch_out // 2)
self.conv1 = ConvBNLayer(ch_in, ch_mid, 1, act=act)
self.conv2 = ConvBNLayer(ch_in, ch_mid, 1, act=act)
# self.conv2 = ConvBNLayer(ch_in, ch_mid, 3, stride=1, padding=1, act=act)
self.convs = nn.Sequential()

next_ch_in = ch_mid
for i in range(n):
if block_fn == 'BasicBlock':
self.convs.add_module(
str(i),
BasicBlock(next_ch_in, ch_mid, act=act, shortcut=False))
elif block_fn == 'BasicBlock_3x3':
self.convs.add_module(
str(i),
BasicBlock_3x3(next_ch_in, ch_mid, act=act, shortcut=True))
elif block_fn == 'BasicBlock_3x3_Reverse':
self.convs.add_module(
str(i),
BasicBlock_3x3_Reverse(
next_ch_in, ch_mid, act=act, shortcut=True))
else:
raise NotImplementedError
if i == (n - 1) // 2 and spp:
self.convs.add_module(
'spp', SPP(ch_mid * 4, ch_mid, 1, [5, 9, 13], act=act))
next_ch_in = ch_mid
# self.convs = nn.Sequential(*convs)
self.conv3 = ConvBNLayer(ch_mid * (n + 1), ch_out, 1, act=act)

def forward(self, x):
y1 = self.conv1(x)
y2 = self.conv2(x)

mid_out = [y1]
for conv in self.convs:
y2 = conv(y2)
mid_out.append(y2)
y = torch.cat(mid_out, axis=1)
y = self.conv3(y)
return y

modelscope/models/cv/tinynas_detection/core/repvgg_block.py  +205 -0

@@ -0,0 +1,205 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.nn.parameter import Parameter


def get_activation(name='silu', inplace=True):
if name == 'silu':
module = nn.SiLU(inplace=inplace)
elif name == 'relu':
module = nn.ReLU(inplace=inplace)
elif name == 'lrelu':
module = nn.LeakyReLU(0.1, inplace=inplace)
elif name == 'identity':
module = nn.Identity()
else:
raise AttributeError('Unsupported act type: {}'.format(name))
return module


def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups=1):
'''Basic cell for rep-style block, including conv and bn'''
result = nn.Sequential()
result.add_module(
'conv',
nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
bias=False))
result.add_module('bn', nn.BatchNorm2d(num_features=out_channels))
return result


class RepVggBlock(nn.Module):
'''RepVggBlock is a basic rep-style block, including training and deploy status
This code is based on https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py
'''

def __init__(self,
in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1,
dilation=1,
groups=1,
padding_mode='zeros',
deploy=False,
use_se=False,
act='relu',
norm=None):
super(RepVggBlock, self).__init__()
""" Initialization of the class.
Args:
in_channels (int): Number of channels in the input image
out_channels (int): Number of channels produced by the convolution
kernel_size (int or tuple): Size of the convolving kernel
stride (int or tuple, optional): Stride of the convolution. Default: 1
padding (int or tuple, optional): Zero-padding added to both sides of
the input. Default: 1
dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
groups (int, optional): Number of blocked connections from input
channels to output channels. Default: 1
padding_mode (string, optional): Default: 'zeros'
deploy: Whether to be deploy status or training status. Default: False
use_se: Whether to use se. Default: False
"""
self.deploy = deploy
self.groups = groups
self.in_channels = in_channels
self.out_channels = out_channels

assert kernel_size == 3
assert padding == 1

padding_11 = padding - kernel_size // 2

if isinstance(act, str):
self.nonlinearity = get_activation(act)
else:
self.nonlinearity = act

if use_se:
raise NotImplementedError('se block not supported yet')
else:
self.se = nn.Identity()

if deploy:
self.rbr_reparam = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=True,
padding_mode=padding_mode)

else:
self.rbr_identity = None
self.rbr_dense = conv_bn(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups)
self.rbr_1x1 = conv_bn(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
stride=stride,
padding=padding_11,
groups=groups)

def forward(self, inputs):
'''Forward process'''
if hasattr(self, 'rbr_reparam'):
return self.nonlinearity(self.se(self.rbr_reparam(inputs)))

if self.rbr_identity is None:
id_out = 0
else:
id_out = self.rbr_identity(inputs)

return self.nonlinearity(
self.se(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out))

def get_equivalent_kernel_bias(self):
kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
return kernel3x3 + self._pad_1x1_to_3x3_tensor(
kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid

def _pad_1x1_to_3x3_tensor(self, kernel1x1):
if kernel1x1 is None:
return 0
else:
return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])

def _fuse_bn_tensor(self, branch):
if branch is None:
return 0, 0
if isinstance(branch, nn.Sequential):
kernel = branch.conv.weight
running_mean = branch.bn.running_mean
running_var = branch.bn.running_var
gamma = branch.bn.weight
beta = branch.bn.bias
eps = branch.bn.eps
else:
assert isinstance(branch, nn.BatchNorm2d)
if not hasattr(self, 'id_tensor'):
input_dim = self.in_channels // self.groups
kernel_value = np.zeros((self.in_channels, input_dim, 3, 3),
dtype=np.float32)
for i in range(self.in_channels):
kernel_value[i, i % input_dim, 1, 1] = 1
self.id_tensor = torch.from_numpy(kernel_value).to(
branch.weight.device)
kernel = self.id_tensor
running_mean = branch.running_mean
running_var = branch.running_var
gamma = branch.weight
beta = branch.bias
eps = branch.eps
std = (running_var + eps).sqrt()
t = (gamma / std).reshape(-1, 1, 1, 1)
return kernel * t, beta - running_mean * gamma / std

def switch_to_deploy(self):
if hasattr(self, 'rbr_reparam'):
return
kernel, bias = self.get_equivalent_kernel_bias()
self.rbr_reparam = nn.Conv2d(
in_channels=self.rbr_dense.conv.in_channels,
out_channels=self.rbr_dense.conv.out_channels,
kernel_size=self.rbr_dense.conv.kernel_size,
stride=self.rbr_dense.conv.stride,
padding=self.rbr_dense.conv.padding,
dilation=self.rbr_dense.conv.dilation,
groups=self.rbr_dense.conv.groups,
bias=True)
self.rbr_reparam.weight.data = kernel
self.rbr_reparam.bias.data = bias
for para in self.parameters():
para.detach_()
self.__delattr__('rbr_dense')
self.__delattr__('rbr_1x1')
if hasattr(self, 'rbr_identity'):
self.__delattr__('rbr_identity')
if hasattr(self, 'id_tensor'):
self.__delattr__('id_tensor')
self.deploy = True
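
After fusion, the single 3x3 convolution must reproduce the training-time branch sum exactly; the comparison has to run in eval mode because _fuse_bn_tensor uses the BN running statistics. A hedged equivalence check:

# Hedged sketch: reparameterization should not change the forward output.
import torch

block = RepVggBlock(8, 8)
block.eval()  # fusion uses BN running stats, so compare in eval mode
x = torch.randn(2, 8, 16, 16)
y_before = block(x)
block.switch_to_deploy()
y_after = block(x)
print(torch.allclose(y_before, y_after, atol=1e-5))  # True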

modelscope/models/cv/tinynas_detection/core/utils.py  +196 -0

@@ -0,0 +1,196 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.

import numpy as np
import torch
import torchvision

__all__ = [
'filter_box',
'postprocess_airdet',
'bboxes_iou',
'matrix_iou',
'adjust_box_anns',
'xyxy2xywh',
'xyxy2cxcywh',
]


def multiclass_nms(multi_bboxes,
multi_scores,
score_thr,
iou_thr,
max_num=100,
score_factors=None):
"""NMS for multi-class bboxes.

    Args:
        multi_bboxes (Tensor): shape (n, #class*4) or (n, 4)
        multi_scores (Tensor): shape (n, #class), one score per class
            (no background column is expected).
        score_thr (float): bbox threshold; bboxes with scores lower than it
            will not be considered.
        iou_thr (float): NMS IoU threshold.
        max_num (int): if there are more than max_num bboxes after NMS,
            only the top max_num will be kept.
        score_factors (Tensor): factors multiplied to scores before
            applying NMS.

    Returns:
        tuple: (bboxes, scores, labels), tensors of shape (k, 4), (k,)
            and (k,). Labels are 0-based.
"""
num_classes = multi_scores.size(1)
# exclude background category
if multi_bboxes.shape[1] > 4:
bboxes = multi_bboxes.view(multi_scores.size(0), -1, 4)
else:
bboxes = multi_bboxes[:, None].expand(
multi_scores.size(0), num_classes, 4)
scores = multi_scores
# filter out boxes with low scores
valid_mask = scores > score_thr # 1000 * 80 bool

# We use masked_select for ONNX exporting purpose,
# which is equivalent to bboxes = bboxes[valid_mask]
# (TODO): as ONNX does not support repeat now,
# we have to use this ugly code
# bboxes -> 1000, 4
bboxes = torch.masked_select(
bboxes,
torch.stack((valid_mask, valid_mask, valid_mask, valid_mask),
-1)).view(-1, 4) # mask-> 1000*80*4, 80000*4
if score_factors is not None:
scores = scores * score_factors[:, None]
scores = torch.masked_select(scores, valid_mask)
labels = valid_mask.nonzero(as_tuple=False)[:, 1]

if bboxes.numel() == 0:
bboxes = multi_bboxes.new_zeros((0, 5))
labels = multi_bboxes.new_zeros((0, ), dtype=torch.long)
scores = multi_bboxes.new_zeros((0, ))

return bboxes, scores, labels

keep = torchvision.ops.batched_nms(bboxes, scores, labels, iou_thr)

if max_num > 0:
keep = keep[:max_num]

return bboxes[keep], scores[keep], labels[keep]
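
A hedged usage sketch of multiclass_nms with shared class-agnostic boxes, which exercises the (n, 4) expand path above:

# Hedged sketch: class-agnostic boxes broadcast across 80 class scores.
import torch

boxes = torch.rand(100, 4) * 100
boxes[:, 2:] += boxes[:, :2] + 1  # ensure x2 > x1 and y2 > y1
scores = torch.rand(100, 80)      # per-class scores, no background column
kept_boxes, kept_scores, kept_labels = multiclass_nms(
    boxes, scores, score_thr=0.5, iou_thr=0.6, max_num=100)
print(kept_boxes.shape, kept_scores.shape, kept_labels.shape)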


def filter_box(output, scale_range):
"""
output: (N, 5+class) shape
"""
min_scale, max_scale = scale_range
w = output[:, 2] - output[:, 0]
h = output[:, 3] - output[:, 1]
keep = (w * h > min_scale * min_scale) & (w * h < max_scale * max_scale)
return output[keep]


def filter_results(boxlist, num_classes, nms_thre):
boxes = boxlist.bbox
scores = boxlist.get_field('scores')
cls = boxlist.get_field('labels')
nms_out_index = torchvision.ops.batched_nms(
boxes,
scores,
cls,
nms_thre,
)
boxlist = boxlist[nms_out_index]

return boxlist


def postprocess_airdet(prediction,
num_classes,
conf_thre=0.7,
nms_thre=0.45,
imgs=None):
box_corner = prediction.new(prediction.shape)
box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
prediction[:, :, :4] = box_corner[:, :, :4]
output = [None for _ in range(len(prediction))]
for i, image_pred in enumerate(prediction):
# If none are remaining => process next image
if not image_pred.size(0):
continue
multi_bboxes = image_pred[:, :4]
multi_scores = image_pred[:, 5:]
detections, scores, labels = multiclass_nms(multi_bboxes, multi_scores,
conf_thre, nms_thre, 500)
detections = torch.cat(
(detections, scores[:, None], scores[:, None], labels[:, None]),
dim=1)

if output[i] is None:
output[i] = detections
else:
output[i] = torch.cat((output[i], detections))
return output


def bboxes_iou(bboxes_a, bboxes_b, xyxy=True):
if bboxes_a.shape[1] != 4 or bboxes_b.shape[1] != 4:
raise IndexError

if xyxy:
tl = torch.max(bboxes_a[:, None, :2], bboxes_b[:, :2])
br = torch.min(bboxes_a[:, None, 2:], bboxes_b[:, 2:])
area_a = torch.prod(bboxes_a[:, 2:] - bboxes_a[:, :2], 1)
area_b = torch.prod(bboxes_b[:, 2:] - bboxes_b[:, :2], 1)
else:
tl = torch.max(
(bboxes_a[:, None, :2] - bboxes_a[:, None, 2:] / 2),
(bboxes_b[:, :2] - bboxes_b[:, 2:] / 2),
)
br = torch.min(
(bboxes_a[:, None, :2] + bboxes_a[:, None, 2:] / 2),
(bboxes_b[:, :2] + bboxes_b[:, 2:] / 2),
)

area_a = torch.prod(bboxes_a[:, 2:], 1)
area_b = torch.prod(bboxes_b[:, 2:], 1)
en = (tl < br).type(tl.type()).prod(dim=2)
area_i = torch.prod(br - tl, 2) * en # * ((tl < br).all())
return area_i / (area_a[:, None] + area_b - area_i)


def matrix_iou(a, b):
"""
    Return IoU of a and b (numpy version, used for data augmentation).
"""
lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])

area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
return area_i / (area_a[:, np.newaxis] + area_b - area_i + 1e-12)


def adjust_box_anns(bbox, scale_ratio, padw, padh, w_max, h_max):
bbox[:, 0::2] = np.clip(bbox[:, 0::2] * scale_ratio + padw, 0, w_max)
bbox[:, 1::2] = np.clip(bbox[:, 1::2] * scale_ratio + padh, 0, h_max)
return bbox


def xyxy2xywh(bboxes):
bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 0]
bboxes[:, 3] = bboxes[:, 3] - bboxes[:, 1]
return bboxes


def xyxy2cxcywh(bboxes):
bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 0]
bboxes[:, 3] = bboxes[:, 3] - bboxes[:, 1]
bboxes[:, 0] = bboxes[:, 0] + bboxes[:, 2] * 0.5
bboxes[:, 1] = bboxes[:, 1] + bboxes[:, 3] * 0.5
return bboxes

modelscope/models/cv/tinynas_detection/detector.py  +181 -0

@@ -0,0 +1,181 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.

import os.path as osp
import pickle

import cv2
import torch
import torchvision

from modelscope.metainfo import Models
from modelscope.models.base.base_torch_model import TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from .backbone import build_backbone
from .head import build_head
from .neck import build_neck
from .utils import parse_config


class SingleStageDetector(TorchModel):
"""
The base class of single stage detector.
"""

def __init__(self, model_dir: str, *args, **kwargs):
"""
init model by cfg
"""
super().__init__(model_dir, *args, **kwargs)

config_path = osp.join(model_dir, 'airdet_s.py')
config = parse_config(config_path)
self.cfg = config
model_path = osp.join(model_dir, config.model.name)
label_map = osp.join(model_dir, config.model.class_map)
self.label_map = pickle.load(open(label_map, 'rb'))
self.size_divisible = config.dataset.size_divisibility
self.num_classes = config.model.head.num_classes
self.conf_thre = config.model.head.nms_conf_thre
self.nms_thre = config.model.head.nms_iou_thre

self.backbone = build_backbone(self.cfg.model.backbone)
self.neck = build_neck(self.cfg.model.neck)
self.head = build_head(self.cfg.model.head)

self.load_pretrain_model(model_path)

def load_pretrain_model(self, pretrain_model):

state_dict = torch.load(pretrain_model, map_location='cpu')['model']
new_state_dict = {}
for k, v in state_dict.items():
k = k.replace('module.', '')
new_state_dict[k] = v
self.load_state_dict(new_state_dict, strict=True)

def inference(self, x):

if self.training:
return self.forward_train(x)
else:
return self.forward_eval(x)

def forward_train(self, x):

pass

def forward_eval(self, x):

x = self.backbone(x)
x = self.neck(x)
prediction = self.head(x)

return prediction

def preprocess(self, image):
image = torch.from_numpy(image).type(torch.float32)
image = image.permute(2, 0, 1)
shape = image.shape # c, h, w
if self.size_divisible > 0:
import math
stride = self.size_divisible
shape = list(shape)
shape[1] = int(math.ceil(shape[1] / stride) * stride)
shape[2] = int(math.ceil(shape[2] / stride) * stride)
shape = tuple(shape)
pad_img = image.new(*shape).zero_()
pad_img[:, :image.shape[1], :image.shape[2]].copy_(image)
pad_img = pad_img.unsqueeze(0)

return pad_img
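
The size_divisible branch rounds height and width up to the nearest multiple of the stride and zero-pads the difference; for example (sketch):

# Hedged sketch: rounding a 480x500 image up to a stride of 32.
import math

h, w, stride = 480, 500, 32
print(math.ceil(h / stride) * stride,
      math.ceil(w / stride) * stride)  # 480 512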

def postprocess(self, preds):
bboxes, scores, labels_idx = postprocess_gfocal(
preds, self.num_classes, self.conf_thre, self.nms_thre)
bboxes = bboxes.cpu().numpy()
scores = scores.cpu().numpy()
labels_idx = labels_idx.cpu().numpy()
labels = [self.label_map[idx + 1][0]['name'] for idx in labels_idx]

return (bboxes, scores, labels)


def multiclass_nms(multi_bboxes,
multi_scores,
score_thr,
iou_thr,
max_num=100,
score_factors=None):
"""NMS for multi-class bboxes.

    Args:
        multi_bboxes (Tensor): shape (n, #class*4) or (n, 4)
        multi_scores (Tensor): shape (n, #class), one score per class
            (no background column is expected).
        score_thr (float): bbox threshold; bboxes with scores lower than it
            will not be considered.
        iou_thr (float): NMS IoU threshold.
        max_num (int): if there are more than max_num bboxes after NMS,
            only the top max_num will be kept.
        score_factors (Tensor): factors multiplied to scores before
            applying NMS.

    Returns:
        tuple: (bboxes, scores, labels), tensors of shape (k, 4), (k,)
            and (k,). Labels are 0-based.
"""
num_classes = multi_scores.size(1)
# exclude background category
if multi_bboxes.shape[1] > 4:
bboxes = multi_bboxes.view(multi_scores.size(0), -1, 4)
else:
bboxes = multi_bboxes[:, None].expand(
multi_scores.size(0), num_classes, 4)
scores = multi_scores
# filter out boxes with low scores
valid_mask = scores > score_thr # 1000 * 80 bool

# We use masked_select for ONNX exporting purpose,
# which is equivalent to bboxes = bboxes[valid_mask]
# (TODO): as ONNX does not support repeat now,
# we have to use this ugly code
# bboxes -> 1000, 4
bboxes = torch.masked_select(
bboxes,
torch.stack((valid_mask, valid_mask, valid_mask, valid_mask),
-1)).view(-1, 4) # mask-> 1000*80*4, 80000*4
if score_factors is not None:
scores = scores * score_factors[:, None]
scores = torch.masked_select(scores, valid_mask)
labels = valid_mask.nonzero(as_tuple=False)[:, 1]

if bboxes.numel() == 0:
bboxes = multi_bboxes.new_zeros((0, 5))
labels = multi_bboxes.new_zeros((0, ), dtype=torch.long)
scores = multi_bboxes.new_zeros((0, ))

return bboxes, scores, labels

keep = torchvision.ops.batched_nms(bboxes, scores, labels, iou_thr)

if max_num > 0:
keep = keep[:max_num]

return bboxes[keep], scores[keep], labels[keep]


def postprocess_gfocal(prediction, num_classes, conf_thre=0.05, nms_thre=0.7):
assert prediction.shape[0] == 1
for i, image_pred in enumerate(prediction):
# If none are remaining => process next image
if not image_pred.size(0):
continue
multi_bboxes = image_pred[:, :4]
multi_scores = image_pred[:, 4:]
detections, scores, labels = multiclass_nms(multi_bboxes, multi_scores,
conf_thre, nms_thre, 500)

return detections, scores, labels

modelscope/models/cv/tinynas_detection/head/__init__.py  +16 -0

@@ -0,0 +1,16 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.

import copy

from .gfocal_v2_tiny import GFocalHead_Tiny


def build_head(cfg):

head_cfg = copy.deepcopy(cfg)
name = head_cfg.pop('name')
if name == 'GFocalV2':
return GFocalHead_Tiny(**head_cfg)
else:
raise NotImplementedError

modelscope/models/cv/tinynas_detection/head/gfocal_v2_tiny.py  +361 -0

@@ -0,0 +1,361 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.

import functools
from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..core.base_ops import BaseConv, DWConv


class Scale(nn.Module):

def __init__(self, scale=1.0):
super(Scale, self).__init__()
self.scale = nn.Parameter(torch.tensor(scale, dtype=torch.float))

def forward(self, x):
return x * self.scale


def multi_apply(func, *args, **kwargs):

pfunc = partial(func, **kwargs) if kwargs else func
map_results = map(pfunc, *args)
return tuple(map(list, zip(*map_results)))


def xyxy2CxCywh(xyxy, size=None):
x1 = xyxy[..., 0]
y1 = xyxy[..., 1]
x2 = xyxy[..., 2]
y2 = xyxy[..., 3]

cx = (x1 + x2) / 2
cy = (y1 + y2) / 2

w = x2 - x1
h = y2 - y1
if size is not None:
w = w.clamp(min=0, max=size[1])
h = h.clamp(min=0, max=size[0])
return torch.stack([cx, cy, w, h], axis=-1)


def distance2bbox(points, distance, max_shape=None):
"""Decode distance prediction to bounding box.
"""
x1 = points[..., 0] - distance[..., 0]
y1 = points[..., 1] - distance[..., 1]
x2 = points[..., 0] + distance[..., 2]
y2 = points[..., 1] + distance[..., 3]
if max_shape is not None:
x1 = x1.clamp(min=0, max=max_shape[1])
y1 = y1.clamp(min=0, max=max_shape[0])
x2 = x2.clamp(min=0, max=max_shape[1])
y2 = y2.clamp(min=0, max=max_shape[0])
return torch.stack([x1, y1, x2, y2], -1)


def bbox2distance(points, bbox, max_dis=None, eps=0.1):
"""Decode bounding box based on distances.
"""
left = points[:, 0] - bbox[:, 0]
top = points[:, 1] - bbox[:, 1]
right = bbox[:, 2] - points[:, 0]
bottom = bbox[:, 3] - points[:, 1]
if max_dis is not None:
left = left.clamp(min=0, max=max_dis - eps)
top = top.clamp(min=0, max=max_dis - eps)
right = right.clamp(min=0, max=max_dis - eps)
bottom = bottom.clamp(min=0, max=max_dis - eps)
return torch.stack([left, top, right, bottom], -1)


class Integral(nn.Module):
"""A fixed layer for calculating integral result from distribution.
"""

def __init__(self, reg_max=16):
super(Integral, self).__init__()
self.reg_max = reg_max
self.register_buffer('project',
torch.linspace(0, self.reg_max, self.reg_max + 1))

def forward(self, x):
"""Forward feature from the regression head to get integral result of
bounding box location.
"""
shape = x.size()
x = F.softmax(x.reshape(*shape[:-1], 4, self.reg_max + 1), dim=-1)
b, nb, ne, _ = x.size()
x = x.reshape(b * nb * ne, self.reg_max + 1)
y = self.project.type_as(x).unsqueeze(1)
x = torch.matmul(x, y).reshape(b, nb, 4)
return x
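
Integral performs the distribution-to-expectation decoding used by GFocal: each side's discrete distribution over {0, ..., reg_max} is reduced to its expected value via the fixed project vector. A numeric sketch for a single side:

# Hedged sketch: the projected expectation recovers the dominant bin.
import torch
import torch.nn.functional as F

reg_max = 12
logits = torch.zeros(reg_max + 1)
logits[3] = 10.0  # put nearly all probability mass on bin 3
probs = F.softmax(logits, dim=0)
project = torch.linspace(0, reg_max, reg_max + 1)
print((probs * project).sum().item())  # ~3.0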


class GFocalHead_Tiny(nn.Module):
"""Ref to Generalized Focal Loss V2: Learning Reliable Localization Quality
Estimation for Dense Object Detection.
"""

def __init__(
self,
num_classes,
in_channels,
stacked_convs=4, # 4
feat_channels=256,
reg_max=12,
reg_topk=4,
reg_channels=64,
strides=[8, 16, 32],
add_mean=True,
norm='gn',
act='relu',
start_kernel_size=3,
conv_groups=1,
conv_type='BaseConv',
simOTA_cls_weight=1.0,
simOTA_iou_weight=3.0,
octbase=8,
simlqe=False,
**kwargs):
self.simlqe = simlqe
self.num_classes = num_classes
self.in_channels = in_channels
self.strides = strides
self.feat_channels = feat_channels if isinstance(feat_channels, list) \
else [feat_channels] * len(self.strides)

        self.cls_out_channels = num_classes + 1  # +1 kept for consistency with
        # former models; the extra channel will be deprecated in the future.
self.stacked_convs = stacked_convs
self.conv_groups = conv_groups
self.reg_max = reg_max
self.reg_topk = reg_topk
self.reg_channels = reg_channels
self.add_mean = add_mean
self.total_dim = reg_topk
self.start_kernel_size = start_kernel_size

self.norm = norm
self.act = act
self.conv_module = DWConv if conv_type == 'DWConv' else BaseConv

if add_mean:
self.total_dim += 1

super(GFocalHead_Tiny, self).__init__()
self.integral = Integral(self.reg_max)

self._init_layers()

def _build_not_shared_convs(self, in_channel, feat_channels):
self.relu = nn.ReLU(inplace=True)
cls_convs = nn.ModuleList()
reg_convs = nn.ModuleList()

for i in range(self.stacked_convs):
chn = feat_channels if i > 0 else in_channel
kernel_size = 3 if i > 0 else self.start_kernel_size
cls_convs.append(
self.conv_module(
chn,
feat_channels,
kernel_size,
stride=1,
groups=self.conv_groups,
norm=self.norm,
act=self.act))
reg_convs.append(
self.conv_module(
chn,
feat_channels,
kernel_size,
stride=1,
groups=self.conv_groups,
norm=self.norm,
act=self.act))
if not self.simlqe:
conf_vector = [nn.Conv2d(4 * self.total_dim, self.reg_channels, 1)]
else:
conf_vector = [
nn.Conv2d(4 * (self.reg_max + 1), self.reg_channels, 1)
]
conf_vector += [self.relu]
conf_vector += [nn.Conv2d(self.reg_channels, 1, 1), nn.Sigmoid()]
reg_conf = nn.Sequential(*conf_vector)

return cls_convs, reg_convs, reg_conf

def _init_layers(self):
"""Initialize layers of the head."""
self.relu = nn.ReLU(inplace=True)
self.cls_convs = nn.ModuleList()
self.reg_convs = nn.ModuleList()
self.reg_confs = nn.ModuleList()

for i in range(len(self.strides)):
cls_convs, reg_convs, reg_conf = self._build_not_shared_convs(
self.in_channels[i], self.feat_channels[i])
self.cls_convs.append(cls_convs)
self.reg_convs.append(reg_convs)
self.reg_confs.append(reg_conf)

self.gfl_cls = nn.ModuleList([
nn.Conv2d(
self.feat_channels[i], self.cls_out_channels, 3, padding=1)
for i in range(len(self.strides))
])

self.gfl_reg = nn.ModuleList([
nn.Conv2d(
self.feat_channels[i], 4 * (self.reg_max + 1), 3, padding=1)
for i in range(len(self.strides))
])

self.scales = nn.ModuleList([Scale(1.0) for _ in self.strides])

def forward(self,
xin,
labels=None,
imgs=None,
conf_thre=0.05,
nms_thre=0.7):

# prepare labels during training
b, c, h, w = xin[0].shape
if labels is not None:
gt_bbox_list = []
gt_cls_list = []
for label in labels:
gt_bbox_list.append(label.bbox)
gt_cls_list.append((label.get_field('labels')
- 1).long()) # labels starts from 1

# prepare priors for label assignment and bbox decode
mlvl_priors_list = [
self.get_single_level_center_priors(
xin[i].shape[0],
xin[i].shape[-2:],
stride,
dtype=torch.float32,
device=xin[0].device) for i, stride in enumerate(self.strides)
]
mlvl_priors = torch.cat(mlvl_priors_list, dim=1)

# forward for bboxes and classification prediction
cls_scores, bbox_preds = multi_apply(
self.forward_single,
xin,
self.cls_convs,
self.reg_convs,
self.gfl_cls,
self.gfl_reg,
self.reg_confs,
self.scales,
)
flatten_cls_scores = torch.cat(cls_scores, dim=1)
flatten_bbox_preds = torch.cat(bbox_preds, dim=1)

# calculating losses or bboxes decoded
if self.training:
loss = self.loss(flatten_cls_scores, flatten_bbox_preds,
gt_bbox_list, gt_cls_list, mlvl_priors)
return loss
else:
output = self.get_bboxes(flatten_cls_scores, flatten_bbox_preds,
mlvl_priors)
return output

def forward_single(self, x, cls_convs, reg_convs, gfl_cls, gfl_reg,
reg_conf, scale):
"""Forward feature of a single scale level.

"""
cls_feat = x
reg_feat = x

for cls_conv in cls_convs:
cls_feat = cls_conv(cls_feat)
for reg_conv in reg_convs:
reg_feat = reg_conv(reg_feat)

bbox_pred = scale(gfl_reg(reg_feat)).float()
N, C, H, W = bbox_pred.size()
prob = F.softmax(
bbox_pred.reshape(N, 4, self.reg_max + 1, H, W), dim=2)
if not self.simlqe:
prob_topk, _ = prob.topk(self.reg_topk, dim=2)

if self.add_mean:
stat = torch.cat(
[prob_topk, prob_topk.mean(dim=2, keepdim=True)], dim=2)
else:
stat = prob_topk

quality_score = reg_conf(stat.reshape(N, 4 * self.total_dim, H, W))
else:
quality_score = reg_conf(
bbox_pred.reshape(N, 4 * (self.reg_max + 1), H, W))

cls_score = gfl_cls(cls_feat).sigmoid() * quality_score

flatten_cls_score = cls_score.flatten(start_dim=2).transpose(1, 2)
flatten_bbox_pred = bbox_pred.flatten(start_dim=2).transpose(1, 2)
return flatten_cls_score, flatten_bbox_pred

def get_single_level_center_priors(self, batch_size, featmap_size, stride,
dtype, device):

h, w = featmap_size
x_range = (torch.arange(0, int(w), dtype=dtype,
device=device)) * stride
y_range = (torch.arange(0, int(h), dtype=dtype,
device=device)) * stride

x = x_range.repeat(h, 1)
y = y_range.unsqueeze(-1).repeat(1, w)

y = y.flatten()
x = x.flatten()
strides = x.new_full((x.shape[0], ), stride)
priors = torch.stack([x, y, strides, strides], dim=-1)

return priors.unsqueeze(0).repeat(batch_size, 1, 1)

def sample(self, assign_result, gt_bboxes):
pos_inds = torch.nonzero(
assign_result.gt_inds > 0, as_tuple=False).squeeze(-1).unique()
neg_inds = torch.nonzero(
assign_result.gt_inds == 0, as_tuple=False).squeeze(-1).unique()
pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1

if gt_bboxes.numel() == 0:
# hack for index error case
assert pos_assigned_gt_inds.numel() == 0
pos_gt_bboxes = torch.empty_like(gt_bboxes).view(-1, 4)
else:
if len(gt_bboxes.shape) < 2:
gt_bboxes = gt_bboxes.view(-1, 4)
pos_gt_bboxes = gt_bboxes[pos_assigned_gt_inds, :]

return pos_inds, neg_inds, pos_gt_bboxes, pos_assigned_gt_inds

def get_bboxes(self,
cls_preds,
reg_preds,
mlvl_center_priors,
img_meta=None):

dis_preds = self.integral(reg_preds) * mlvl_center_priors[..., 2, None]
bboxes = distance2bbox(mlvl_center_priors[..., :2], dis_preds)

res = torch.cat([bboxes, cls_preds[..., 0:self.num_classes]], dim=-1)

return res
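
For reference, the decode step above first turns each side's discretized distance distribution into its expectation (the Integral module), then scales by the prior's stride. A minimal standalone sketch of that expectation, assuming the common reg_max = 16 and purely illustrative shapes:

import torch
import torch.nn.functional as F

reg_max = 16
reg_preds = torch.randn(1, 8400, 4 * (reg_max + 1))  # flattened head output
prob = F.softmax(reg_preds.reshape(1, 8400, 4, reg_max + 1), dim=-1)
bins = torch.arange(reg_max + 1, dtype=torch.float32)
dis_preds = (prob * bins).sum(dim=-1)  # expected (l, t, r, b) per prior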

+ 16
- 0
modelscope/models/cv/tinynas_detection/neck/__init__.py View File

@@ -0,0 +1,16 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.

import copy

from .giraffe_fpn import GiraffeNeck
from .giraffe_fpn_v2 import GiraffeNeckV2


def build_neck(cfg):
    neck_cfg = copy.deepcopy(cfg)
    name = neck_cfg.pop('name')
    if name == 'GiraffeNeck':
        return GiraffeNeck(**neck_cfg)
    elif name == 'GiraffeNeckV2':
        return GiraffeNeckV2(**neck_cfg)
    raise ValueError('unknown neck name: {}'.format(name))
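
A hypothetical call, for illustration only; the real keys come from the model's configuration file, and GiraffeNeckV2's defaults cover the rest:

neck = build_neck({
    'name': 'GiraffeNeckV2',
    'depth': 1.0,
    'width': 1.0,
    'in_channels': [256, 512, 1024],
    'out_channels': [256, 512, 1024],
})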

+ 235
- 0
modelscope/models/cv/tinynas_detection/neck/giraffe_config.py View File

@@ -0,0 +1,235 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.

import collections
import itertools
import os

import networkx as nx
from omegaconf import OmegaConf

Node = collections.namedtuple('Node', ['id', 'inputs', 'type'])


def get_graph_info(graph):
input_nodes = []
output_nodes = []
Nodes = []
    for node in range(graph.number_of_nodes()):
        tmp = sorted(graph.neighbors(node))
        node_type = -1
        if node < tmp[0]:
            input_nodes.append(node)
            node_type = 0
        if node > tmp[-1]:
            output_nodes.append(node)
            node_type = 1
        Nodes.append(Node(node, [n for n in tmp if n < node], node_type))
return Nodes, input_nodes, output_nodes


def nodeid_trans(id, cur_level, num_levels):
if id % 2 == 1:
gap = int(((id + 1) // 2) * num_levels * 2)
else:
a = (num_levels - cur_level) * 2 - 1
b = ((id + 1) // 2) * num_levels * 2
gap = int(a + b)
return cur_level + gap


def gen_log2n_graph_file(log2n_graph_file, depth_multiplier):
    with open(log2n_graph_file, 'w') as f:
        for i in range(depth_multiplier):
            for j in [1, 2, 4, 8, 16, 32]:
                if i - j < 0:
                    break
                f.write('%d,%d\n' % (i - j, i))


def get_log2n_graph(depth_multiplier):
    nodes = []
    connections = []

    for i in range(depth_multiplier):
        nodes.append(i)
        for j in [1, 2, 4, 8, 16, 32]:
            if i - j < 0:
                break
            connections.append((i - j, i))
    return nodes, connections


def get_dense_graph(depth_multiplier):
nodes = []
connections = []

for i in range(depth_multiplier):
nodes.append(i)
for j in range(i):
connections.append((j, i))
return nodes, connections
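
# For example, get_dense_graph(4) returns nodes [0, 1, 2, 3] and connections
# [(0, 1), (0, 2), (1, 2), (0, 3), (1, 3), (2, 3)]: every node is linked to
# all earlier nodes, whereas get_log2n_graph only links node i back to
# i - 1, i - 2, i - 4, ... (power-of-two gaps).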


def giraffeneck_config(min_level,
max_level,
weight_method=None,
depth_multiplier=5,
with_backslash=False,
with_slash=False,
with_skip_connect=False,
skip_connect_type='dense'):
"""Graph config with log2n merge and panet"""
if skip_connect_type == 'dense':
nodes, connections = get_dense_graph(depth_multiplier)
elif skip_connect_type == 'log2n':
nodes, connections = get_log2n_graph(depth_multiplier)
graph = nx.Graph()
graph.add_nodes_from(nodes)
graph.add_edges_from(connections)

drop_node = []
nodes, input_nodes, output_nodes = get_graph_info(graph)

weight_method = weight_method or 'fastattn'

num_levels = max_level - min_level + 1
node_ids = {min_level + i: [i] for i in range(num_levels)}
node_ids_per_layer = {}

pnodes = {}

    def update_drop_node(new_id, input_offsets):
        while new_id in drop_node:
            if new_id in pnodes:
                for n in pnodes[new_id]['inputs_offsets']:
                    if n not in input_offsets and n not in drop_node:
                        input_offsets.append(n)
            new_id = new_id - 1
        if new_id not in input_offsets:
            input_offsets.append(new_id)

# top-down layer
for i in range(max_level, min_level - 1, -1):
node_ids_per_layer[i] = []
for id, node in enumerate(nodes):
input_offsets = []
if id in input_nodes:
input_offsets.append(node_ids[i][0])
else:
if with_skip_connect:
for input_id in node.inputs:
new_id = nodeid_trans(input_id, i - min_level,
num_levels)
update_drop_node(new_id, input_offsets)

            # add the top-down node
new_id = nodeid_trans(id, i - min_level, num_levels)

# add backslash node
def cal_backslash_node(id):
ind = id // num_levels
mod = id % num_levels
if ind % 2 == 0: # even
if mod == (num_levels - 1):
last = -1
else:
last = (ind - 1) * num_levels + (
num_levels - 1 - mod - 1)
else: # odd
if mod == 0:
last = -1
else:
last = (ind - 1) * num_levels + (
num_levels - 1 - mod + 1)

return last

# add slash node
def cal_slash_node(id):
ind = id // num_levels
mod = id % num_levels
if ind % 2 == 1: # odd
if mod == (num_levels - 1):
last = -1
else:
last = (ind - 1) * num_levels + (
num_levels - 1 - mod - 1)
else: # even
if mod == 0:
last = -1
else:
last = (ind - 1) * num_levels + (
num_levels - 1 - mod + 1)

return last

# add last node
last = new_id - 1
update_drop_node(last, input_offsets)

if with_backslash:
backslash = cal_backslash_node(new_id)
if backslash != -1 and backslash not in input_offsets:
input_offsets.append(backslash)

if with_slash:
slash = cal_slash_node(new_id)
if slash != -1 and slash not in input_offsets:
input_offsets.append(slash)

if new_id in drop_node:
input_offsets = []

pnodes[new_id] = {
'reduction': 1 << i,
'inputs_offsets': input_offsets,
'weight_method': weight_method,
'is_out': 0,
}

input_offsets = []
for out_id in output_nodes:
new_id = nodeid_trans(out_id, i - min_level, num_levels)
input_offsets.append(new_id)

pnodes[node_ids[i][0] + num_levels * (len(nodes) + 1)] = {
'reduction': 1 << i,
'inputs_offsets': input_offsets,
'weight_method': weight_method,
'is_out': 1,
}

pnodes = dict(sorted(pnodes.items(), key=lambda x: x[0]))
return pnodes


def get_graph_config(fpn_name,
min_level=3,
max_level=7,
weight_method='concat',
depth_multiplier=5,
with_backslash=False,
with_slash=False,
with_skip_connect=False,
skip_connect_type='dense'):
name_to_config = {
'giraffeneck':
giraffeneck_config(
min_level=min_level,
max_level=max_level,
weight_method=weight_method,
depth_multiplier=depth_multiplier,
with_backslash=with_backslash,
with_slash=with_slash,
with_skip_connect=with_skip_connect,
skip_connect_type=skip_connect_type),
}
return name_to_config[fpn_name]
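
A quick way to inspect the generated graph, with illustrative parameters (requires networkx; each entry carries the reduction, inputs_offsets, weight_method and is_out fields built above):

cfg = get_graph_config(
    'giraffeneck',
    min_level=3,
    max_level=5,
    weight_method='concat',
    depth_multiplier=3,
    with_skip_connect=True,
    skip_connect_type='dense')
for node_id, node in cfg.items():
    print(node_id, node['reduction'], node['inputs_offsets'], node['is_out'])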

+ 661
- 0
modelscope/models/cv/tinynas_detection/neck/giraffe_fpn.py View File

@@ -0,0 +1,661 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.

import logging
import math
from collections import OrderedDict
from functools import partial
from typing import Callable, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm import create_model
from timm.models.layers import (Swish, create_conv2d, create_pool2d,
get_act_layer)

from ..core.base_ops import CSPLayer, ShuffleBlock, ShuffleCSPLayer
from .giraffe_config import get_graph_config

_ACT_LAYER = Swish


class SequentialList(nn.Sequential):
""" This module exists to work around torchscript typing issues list -> list"""

def __init__(self, *args):
super(SequentialList, self).__init__(*args)

def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
for module in self:
x = module(x)
return x


class ConvBnAct2d(nn.Module):

def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
dilation=1,
padding='',
bias=False,
norm_layer=nn.BatchNorm2d,
act_layer=_ACT_LAYER):
super(ConvBnAct2d, self).__init__()

self.conv = create_conv2d(
in_channels,
out_channels,
kernel_size,
stride=stride,
dilation=dilation,
padding=padding,
bias=bias)
self.bn = None if norm_layer is None else norm_layer(out_channels)
self.act = None if act_layer is None else act_layer(inplace=True)

def forward(self, x):
x = self.conv(x)
if self.bn is not None:
x = self.bn(x)
if self.act is not None:
x = self.act(x)
return x


class SeparableConv2d(nn.Module):
""" Separable Conv
"""

def __init__(self,
in_channels,
out_channels,
kernel_size=3,
stride=1,
dilation=1,
padding='',
bias=False,
channel_multiplier=1.0,
pw_kernel_size=1,
norm_layer=nn.BatchNorm2d,
act_layer=_ACT_LAYER):
super(SeparableConv2d, self).__init__()
self.conv_dw = create_conv2d(
in_channels,
int(in_channels * channel_multiplier),
kernel_size,
stride=stride,
dilation=dilation,
padding=padding,
depthwise=True)

self.conv_pw = create_conv2d(
int(in_channels * channel_multiplier),
out_channels,
pw_kernel_size,
padding=padding,
bias=bias)

self.bn = None if norm_layer is None else norm_layer(out_channels)
self.act = None if act_layer is None else act_layer(inplace=True)

def forward(self, x):
x = self.conv_dw(x)
x = self.conv_pw(x)
if self.bn is not None:
x = self.bn(x)
if self.act is not None:
x = self.act(x)
return x


def _init_weight(
m,
n='',
):
""" Weight initialization as per Tensorflow official implementations.
"""

def _fan_in_out(w, groups=1):
dimensions = w.dim()
if dimensions < 2:
raise ValueError(
'Fan in and fan out can not be computed for tensor with fewer than 2 dimensions'
)
num_input_fmaps = w.size(1)
num_output_fmaps = w.size(0)
receptive_field_size = 1
if w.dim() > 2:
receptive_field_size = w[0][0].numel()
fan_in = num_input_fmaps * receptive_field_size
fan_out = num_output_fmaps * receptive_field_size
fan_out //= groups
return fan_in, fan_out

def _glorot_uniform(w, gain=1, groups=1):
fan_in, fan_out = _fan_in_out(w, groups)
gain /= max(1., (fan_in + fan_out) / 2.) # fan avg
limit = math.sqrt(3.0 * gain)
w.data.uniform_(-limit, limit)

def _variance_scaling(w, gain=1, groups=1):
fan_in, fan_out = _fan_in_out(w, groups)
gain /= max(1., fan_in) # fan in
std = math.sqrt(gain)
w.data.normal_(std=std)

if isinstance(m, SeparableConv2d):
if 'box_net' in n or 'class_net' in n:
_variance_scaling(m.conv_dw.weight, groups=m.conv_dw.groups)
_variance_scaling(m.conv_pw.weight)
if m.conv_pw.bias is not None:
if 'class_net.predict' in n:
m.conv_pw.bias.data.fill_(-math.log((1 - 0.01) / 0.01))
else:
m.conv_pw.bias.data.zero_()
else:
_glorot_uniform(m.conv_dw.weight, groups=m.conv_dw.groups)
_glorot_uniform(m.conv_pw.weight)
if m.conv_pw.bias is not None:
m.conv_pw.bias.data.zero_()
elif isinstance(m, ConvBnAct2d):
if 'box_net' in n or 'class_net' in n:
m.conv.weight.data.normal_(std=.01)
if m.conv.bias is not None:
if 'class_net.predict' in n:
m.conv.bias.data.fill_(-math.log((1 - 0.01) / 0.01))
else:
m.conv.bias.data.zero_()
else:
_glorot_uniform(m.conv.weight)
if m.conv.bias is not None:
m.conv.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1.0)
m.bias.data.zero_()


def _init_weight_alt(
m,
n='',
):
""" Weight initialization alternative, based on EfficientNet bacbkone init w/ class bias addition
NOTE: this will likely be removed after some experimentation
"""
if isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
if 'class_net.predict' in n:
m.bias.data.fill_(-math.log((1 - 0.01) / 0.01))
else:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1.0)
m.bias.data.zero_()


class Interpolate2d(nn.Module):
r"""Resamples a 2d Image

The input data is assumed to be of the form
`minibatch x channels x [optional depth] x [optional height] x width`.
Hence, for spatial inputs, we expect a 4D Tensor and for volumetric inputs, we expect a 5D Tensor.

The algorithms available for upsampling are nearest neighbor and linear,
bilinear, bicubic and trilinear for 3D, 4D and 5D input Tensor,
respectively.

One can either give a :attr:`scale_factor` or the target output :attr:`size` to
calculate the output size. (You cannot give both, as it is ambiguous)

Args:
size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int], optional):
output spatial sizes
scale_factor (float or Tuple[float] or Tuple[float, float] or Tuple[float, float, float], optional):
multiplier for spatial size. Has to match input size if it is a tuple.
mode (str, optional): the upsampling algorithm: one of ``'nearest'``,
``'linear'``, ``'bilinear'``, ``'bicubic'`` and ``'trilinear'``.
Default: ``'nearest'``
align_corners (bool, optional): if ``True``, the corner pixels of the input
and output tensors are aligned, and thus preserving the values at
those pixels. This only has effect when :attr:`mode` is
``'linear'``, ``'bilinear'``, or ``'trilinear'``. Default: ``False``
"""
__constants__ = ['size', 'scale_factor', 'mode', 'align_corners', 'name']
name: str
size: Optional[Union[int, Tuple[int, int]]]
scale_factor: Optional[Union[float, Tuple[float, float]]]
mode: str
align_corners: Optional[bool]

def __init__(self,
size: Optional[Union[int, Tuple[int, int]]] = None,
scale_factor: Optional[Union[float, Tuple[float,
float]]] = None,
mode: str = 'nearest',
align_corners: bool = False) -> None:
super(Interpolate2d, self).__init__()
self.name = type(self).__name__
self.size = size
if isinstance(scale_factor, tuple):
self.scale_factor = tuple(float(factor) for factor in scale_factor)
else:
self.scale_factor = float(scale_factor) if scale_factor else None
self.mode = mode
self.align_corners = None if mode == 'nearest' else align_corners

def forward(self, input: torch.Tensor) -> torch.Tensor:
return F.interpolate(
input,
self.size,
self.scale_factor,
self.mode,
self.align_corners,
recompute_scale_factor=False)


class ResampleFeatureMap(nn.Sequential):

def __init__(self,
in_channels,
out_channels,
reduction_ratio=1.,
pad_type='',
downsample=None,
upsample=None,
norm_layer=nn.BatchNorm2d,
apply_bn=False,
conv_after_downsample=False,
redundant_bias=False):
super(ResampleFeatureMap, self).__init__()
downsample = downsample or 'max'
upsample = upsample or 'nearest'
self.in_channels = in_channels
self.out_channels = out_channels
self.reduction_ratio = reduction_ratio
self.conv_after_downsample = conv_after_downsample

conv = None
if in_channels != out_channels:
conv = ConvBnAct2d(
in_channels,
out_channels,
kernel_size=1,
padding=pad_type,
norm_layer=norm_layer if apply_bn else None,
bias=not apply_bn or redundant_bias,
act_layer=None)

if reduction_ratio > 1:
if conv is not None and not self.conv_after_downsample:
self.add_module('conv', conv)
if downsample in ('max', 'avg'):
stride_size = int(reduction_ratio)
downsample = create_pool2d(
downsample,
kernel_size=stride_size + 1,
stride=stride_size,
padding=pad_type)
else:
downsample = Interpolate2d(
scale_factor=1. / reduction_ratio, mode=downsample)
self.add_module('downsample', downsample)
if conv is not None and self.conv_after_downsample:
self.add_module('conv', conv)
else:
if conv is not None:
self.add_module('conv', conv)
if reduction_ratio < 1:
scale = int(1 // reduction_ratio)
self.add_module(
'upsample',
Interpolate2d(scale_factor=scale, mode=upsample))


class GiraffeCombine(nn.Module):

def __init__(self,
feature_info,
fpn_config,
fpn_channels,
inputs_offsets,
target_reduction,
pad_type='',
downsample=None,
upsample=None,
norm_layer=nn.BatchNorm2d,
apply_resample_bn=False,
conv_after_downsample=False,
redundant_bias=False,
weight_method='attn'):
super(GiraffeCombine, self).__init__()
self.inputs_offsets = inputs_offsets
self.weight_method = weight_method

self.resample = nn.ModuleDict()
reduction_base = feature_info[0]['reduction']

target_channels_idx = int(
math.log(target_reduction // reduction_base, 2))
for idx, offset in enumerate(inputs_offsets):
if offset < len(feature_info):
in_channels = feature_info[offset]['num_chs']
input_reduction = feature_info[offset]['reduction']
else:
node_idx = offset
input_reduction = fpn_config[node_idx]['reduction']
# in_channels = fpn_config[node_idx]['num_chs']
input_channels_idx = int(
math.log(input_reduction // reduction_base, 2))
in_channels = feature_info[input_channels_idx]['num_chs']

reduction_ratio = target_reduction / input_reduction
if weight_method == 'concat':
self.resample[str(offset)] = ResampleFeatureMap(
in_channels,
in_channels,
reduction_ratio=reduction_ratio,
pad_type=pad_type,
downsample=downsample,
upsample=upsample,
norm_layer=norm_layer,
apply_bn=apply_resample_bn,
conv_after_downsample=conv_after_downsample,
redundant_bias=redundant_bias)
else:
self.resample[str(offset)] = ResampleFeatureMap(
in_channels,
fpn_channels[target_channels_idx],
reduction_ratio=reduction_ratio,
pad_type=pad_type,
downsample=downsample,
upsample=upsample,
norm_layer=norm_layer,
apply_bn=apply_resample_bn,
conv_after_downsample=conv_after_downsample,
redundant_bias=redundant_bias)

if weight_method == 'attn' or weight_method == 'fastattn':
self.edge_weights = nn.Parameter(
torch.ones(len(inputs_offsets)), requires_grad=True) # WSM
else:
self.edge_weights = None

def forward(self, x: List[torch.Tensor]):
dtype = x[0].dtype
nodes = []
if len(self.inputs_offsets) == 0:
return None
for offset, resample in zip(self.inputs_offsets,
self.resample.values()):
input_node = x[offset]
input_node = resample(input_node)
nodes.append(input_node)

if self.weight_method == 'attn':
normalized_weights = torch.softmax(
self.edge_weights.to(dtype=dtype), dim=0)
out = torch.stack(nodes, dim=-1) * normalized_weights
out = torch.sum(out, dim=-1)
elif self.weight_method == 'fastattn':
edge_weights = nn.functional.relu(
self.edge_weights.to(dtype=dtype))
weights_sum = torch.sum(edge_weights)
weights_norm = weights_sum + 0.0001
out = torch.stack([(nodes[i] * edge_weights[i]) / weights_norm
for i in range(len(nodes))],
dim=-1)

out = torch.sum(out, dim=-1)
elif self.weight_method == 'sum':
out = torch.stack(nodes, dim=-1)
out = torch.sum(out, dim=-1)
elif self.weight_method == 'concat':
out = torch.cat(nodes, dim=1)
else:
raise ValueError('unknown weight_method {}'.format(
self.weight_method))
return out


class GiraffeNode(nn.Module):
""" A simple wrapper used in place of nn.Sequential for torchscript typing
Handles input type List[Tensor] -> output type Tensor
"""

def __init__(self, combine: nn.Module, after_combine: nn.Module):
super(GiraffeNode, self).__init__()
self.combine = combine
self.after_combine = after_combine

def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
combine_feat = self.combine(x)
if combine_feat is None:
return None
else:
return self.after_combine(combine_feat)


class GiraffeLayer(nn.Module):

def __init__(self,
feature_info,
fpn_config,
inner_fpn_channels,
outer_fpn_channels,
num_levels=5,
pad_type='',
downsample=None,
upsample=None,
norm_layer=nn.BatchNorm2d,
act_layer=_ACT_LAYER,
apply_resample_bn=False,
conv_after_downsample=True,
conv_bn_relu_pattern=False,
separable_conv=True,
redundant_bias=False,
merge_type='conv'):
super(GiraffeLayer, self).__init__()
self.num_levels = num_levels
self.conv_bn_relu_pattern = False

self.feature_info = {}
for idx, feat in enumerate(feature_info):
self.feature_info[idx] = feat

self.fnode = nn.ModuleList()
reduction_base = feature_info[0]['reduction']
for i, fnode_cfg in fpn_config.items():
logging.debug('fnode {} : {}'.format(i, fnode_cfg))

if fnode_cfg['is_out'] == 1:
fpn_channels = outer_fpn_channels
else:
fpn_channels = inner_fpn_channels

reduction = fnode_cfg['reduction']
fpn_channels_idx = int(math.log(reduction // reduction_base, 2))
combine = GiraffeCombine(
self.feature_info,
fpn_config,
fpn_channels,
tuple(fnode_cfg['inputs_offsets']),
target_reduction=reduction,
pad_type=pad_type,
downsample=downsample,
upsample=upsample,
norm_layer=norm_layer,
apply_resample_bn=apply_resample_bn,
conv_after_downsample=conv_after_downsample,
redundant_bias=redundant_bias,
weight_method=fnode_cfg['weight_method'])

after_combine = nn.Sequential()

in_channels = 0
out_channels = 0
for input_offset in fnode_cfg['inputs_offsets']:
in_channels += self.feature_info[input_offset]['num_chs']

out_channels = fpn_channels[fpn_channels_idx]

if merge_type == 'csp':
after_combine.add_module(
'CspLayer',
CSPLayer(
in_channels,
out_channels,
2,
shortcut=True,
depthwise=False,
act='silu'))
elif merge_type == 'shuffle':
after_combine.add_module(
'shuffleBlock', ShuffleBlock(in_channels, in_channels))
after_combine.add_module(
'conv1x1',
create_conv2d(in_channels, out_channels, kernel_size=1))
elif merge_type == 'conv':
after_combine.add_module(
'conv1x1',
create_conv2d(in_channels, out_channels, kernel_size=1))
conv_kwargs = dict(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
padding=pad_type,
bias=False,
norm_layer=norm_layer,
act_layer=act_layer)
if not conv_bn_relu_pattern:
conv_kwargs['bias'] = redundant_bias
conv_kwargs['act_layer'] = None
after_combine.add_module('act', act_layer(inplace=True))
after_combine.add_module(
'conv',
SeparableConv2d(**conv_kwargs)
if separable_conv else ConvBnAct2d(**conv_kwargs))

self.fnode.append(
GiraffeNode(combine=combine, after_combine=after_combine))
self.feature_info[i] = dict(
num_chs=fpn_channels[fpn_channels_idx], reduction=reduction)

self.out_feature_info = []
out_node = list(self.feature_info.keys())[-num_levels::]
for i in out_node:
self.out_feature_info.append(self.feature_info[i])

self.feature_info = self.out_feature_info

def forward(self, x: List[torch.Tensor]):
for fn in self.fnode:
x.append(fn(x))
return x[-self.num_levels::]


class GiraffeNeck(nn.Module):

def __init__(self, min_level, max_level, num_levels, norm_layer,
norm_kwargs, act_type, fpn_config, fpn_name, fpn_channels,
out_fpn_channels, weight_method, depth_multiplier,
width_multiplier, with_backslash, with_slash,
with_skip_connect, skip_connect_type, separable_conv,
feature_info, merge_type, pad_type, downsample_type,
upsample_type, apply_resample_bn, conv_after_downsample,
redundant_bias, conv_bn_relu_pattern, alternate_init):
super(GiraffeNeck, self).__init__()

self.num_levels = num_levels
self.min_level = min_level
self.in_features = [0, 1, 2, 3, 4, 5,
6][self.min_level - 1:self.min_level - 1
+ num_levels]
self.alternate_init = alternate_init
norm_layer = norm_layer or nn.BatchNorm2d
if norm_kwargs:
norm_layer = partial(norm_layer, **norm_kwargs)
act_layer = get_act_layer(act_type) or _ACT_LAYER
fpn_config = fpn_config or get_graph_config(
fpn_name,
min_level=min_level,
max_level=max_level,
weight_method=weight_method,
depth_multiplier=depth_multiplier,
with_backslash=with_backslash,
with_slash=with_slash,
with_skip_connect=with_skip_connect,
skip_connect_type=skip_connect_type)

# width scale
for i in range(len(fpn_channels)):
fpn_channels[i] = int(fpn_channels[i] * width_multiplier)

self.resample = nn.ModuleDict()
for level in range(num_levels):
if level < len(feature_info):
in_chs = feature_info[level]['num_chs']
reduction = feature_info[level]['reduction']
else:
# Adds a coarser level by downsampling the last feature map
reduction_ratio = 2
self.resample[str(level)] = ResampleFeatureMap(
in_channels=in_chs,
out_channels=feature_info[level - 1]['num_chs'],
pad_type=pad_type,
downsample=downsample_type,
upsample=upsample_type,
norm_layer=norm_layer,
reduction_ratio=reduction_ratio,
apply_bn=apply_resample_bn,
conv_after_downsample=conv_after_downsample,
redundant_bias=redundant_bias,
)
in_chs = feature_info[level - 1]['num_chs']
reduction = int(reduction * reduction_ratio)
feature_info.append(dict(num_chs=in_chs, reduction=reduction))

self.cell = SequentialList()
logging.debug('building giraffeNeck')
giraffe_layer = GiraffeLayer(
feature_info=feature_info,
fpn_config=fpn_config,
inner_fpn_channels=fpn_channels,
outer_fpn_channels=out_fpn_channels,
num_levels=num_levels,
pad_type=pad_type,
downsample=downsample_type,
upsample=upsample_type,
norm_layer=norm_layer,
act_layer=act_layer,
separable_conv=separable_conv,
apply_resample_bn=apply_resample_bn,
conv_after_downsample=conv_after_downsample,
conv_bn_relu_pattern=conv_bn_relu_pattern,
redundant_bias=redundant_bias,
merge_type=merge_type)
self.cell.add_module('giraffeNeck', giraffe_layer)
feature_info = giraffe_layer.feature_info

def init_weights(self, pretrained=False):
for n, m in self.named_modules():
if 'backbone' not in n:
if self.alternate_init:
_init_weight_alt(m, n)
else:
_init_weight(m, n)

def forward(self, x: List[torch.Tensor]):
        if isinstance(x, tuple):
x = list(x)
x = [x[f] for f in self.in_features]
for resample in self.resample.values():
x.append(resample(x[-1]))
x = self.cell(x)
return x
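
The 'fastattn' branch in GiraffeCombine reduces to a ReLU-gated, normalized weighted sum of the resampled inputs; a standalone sketch with illustrative shapes:

import torch
import torch.nn.functional as F

feats = [torch.randn(2, 64, 32, 32) for _ in range(3)]  # resampled inputs
edge_weights = F.relu(torch.ones(3))  # a learnable nn.Parameter in the module
norm = edge_weights.sum() + 0.0001
fused = sum(f * w for f, w in zip(feats, edge_weights)) / norm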

+ 203
- 0
modelscope/models/cv/tinynas_detection/neck/giraffe_fpn_v2.py View File

@@ -0,0 +1,203 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.

import torch
import torch.nn as nn

from ..core.base_ops import BaseConv, CSPLayer, DWConv
from ..core.neck_ops import CSPStage


class GiraffeNeckV2(nn.Module):

def __init__(
self,
depth=1.0,
width=1.0,
in_features=[2, 3, 4],
in_channels=[256, 512, 1024],
out_channels=[256, 512, 1024],
depthwise=False,
act='silu',
spp=True,
reparam_mode=True,
block_name='BasicBlock',
):
super().__init__()
self.in_features = in_features
self.in_channels = in_channels
Conv = DWConv if depthwise else BaseConv

self.upsample = nn.Upsample(scale_factor=2, mode='nearest')

# node x3: input x0, x1
self.bu_conv13 = Conv(
int(in_channels[1] * width),
int(in_channels[1] * width),
3,
2,
act=act)
if reparam_mode:
self.merge_3 = CSPStage(
block_name,
int((in_channels[1] + in_channels[2]) * width),
int(in_channels[2] * width),
round(3 * depth),
act=act,
spp=spp)
else:
self.merge_3 = CSPLayer(
int((in_channels[1] + in_channels[2]) * width),
int(in_channels[2] * width),
round(3 * depth),
False,
depthwise=depthwise,
act=act)

# node x4: input x1, x2, x3
self.bu_conv24 = Conv(
int(in_channels[0] * width),
int(in_channels[0] * width),
3,
2,
act=act)
if reparam_mode:
self.merge_4 = CSPStage(
block_name,
int((in_channels[0] + in_channels[1] + in_channels[2])
* width),
int(in_channels[1] * width),
round(3 * depth),
act=act,
spp=spp)
else:
self.merge_4 = CSPLayer(
int((in_channels[0] + in_channels[1] + in_channels[2])
* width),
int(in_channels[1] * width),
round(3 * depth),
False,
depthwise=depthwise,
act=act)

# node x5: input x2, x4
if reparam_mode:
self.merge_5 = CSPStage(
block_name,
int((in_channels[1] + in_channels[0]) * width),
int(out_channels[0] * width),
round(3 * depth),
act=act,
spp=spp)
else:
self.merge_5 = CSPLayer(
int((in_channels[1] + in_channels[0]) * width),
int(out_channels[0] * width),
round(3 * depth),
False,
depthwise=depthwise,
act=act)

# node x7: input x4, x5
self.bu_conv57 = Conv(
int(out_channels[0] * width),
int(out_channels[0] * width),
3,
2,
act=act)
if reparam_mode:
self.merge_7 = CSPStage(
block_name,
int((out_channels[0] + in_channels[1]) * width),
int(out_channels[1] * width),
round(3 * depth),
act=act,
spp=spp)
else:
self.merge_7 = CSPLayer(
int((out_channels[0] + in_channels[1]) * width),
int(out_channels[1] * width),
round(3 * depth),
False,
depthwise=depthwise,
act=act)

# node x6: input x3, x4, x7
self.bu_conv46 = Conv(
int(in_channels[1] * width),
int(in_channels[1] * width),
3,
2,
act=act)
self.bu_conv76 = Conv(
int(out_channels[1] * width),
int(out_channels[1] * width),
3,
2,
act=act)
if reparam_mode:
self.merge_6 = CSPStage(
block_name,
int((in_channels[1] + out_channels[1] + in_channels[2])
* width),
int(out_channels[2] * width),
round(3 * depth),
act=act,
spp=spp)
else:
self.merge_6 = CSPLayer(
int((in_channels[1] + out_channels[1] + in_channels[2])
* width),
int(out_channels[2] * width),
round(3 * depth),
False,
depthwise=depthwise,
act=act)

def init_weights(self):
pass

def forward(self, out_features):
"""
Args:
inputs: input images.

Returns:
Tuple[Tensor]: FPN feature.
"""

# backbone
features = [out_features[f] for f in self.in_features]
[x2, x1, x0] = features

# node x3
x13 = self.bu_conv13(x1)
x3 = torch.cat([x0, x13], 1)
x3 = self.merge_3(x3)

# node x4
x34 = self.upsample(x3)
x24 = self.bu_conv24(x2)
x4 = torch.cat([x1, x24, x34], 1)
x4 = self.merge_4(x4)

# node x5
x45 = self.upsample(x4)
x5 = torch.cat([x2, x45], 1)
x5 = self.merge_5(x5)

# node x7
x57 = self.bu_conv57(x5)
x7 = torch.cat([x4, x57], 1)
x7 = self.merge_7(x7)

# node x6
x46 = self.bu_conv46(x4)
x76 = self.bu_conv76(x7)
x6 = torch.cat([x3, x46, x76], 1)
x6 = self.merge_6(x6)

outputs = (x5, x7, x6)
return outputs
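
A hypothetical smoke test of the wiring above, assuming the default channel layout and a 640x640 input (CSPStage comes from ..core.neck_ops, so this needs the package importable):

import torch

neck = GiraffeNeckV2(depth=1.0, width=1.0)
features = [
    torch.randn(1, 64, 320, 320),   # indices 0 and 1 are placeholders;
    torch.randn(1, 128, 160, 160),  # in_features=[2, 3, 4] skips them
    torch.randn(1, 256, 80, 80),    # stride 8
    torch.randn(1, 512, 40, 40),    # stride 16
    torch.randn(1, 1024, 20, 20),   # stride 32
]
x5, x7, x6 = neck(features)  # strides 8, 16 and 32, respectively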

+ 16
- 0
modelscope/models/cv/tinynas_detection/tinynas_detector.py View File

@@ -0,0 +1,16 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.

from modelscope.metainfo import Models
from modelscope.models.builder import MODELS
from modelscope.utils.constant import Tasks
from .detector import SingleStageDetector


@MODELS.register_module(
Tasks.image_object_detection, module_name=Models.tinynas_detection)
class TinynasDetector(SingleStageDetector):

def __init__(self, model_dir, *args, **kwargs):

super(TinynasDetector, self).__init__(model_dir, *args, **kwargs)

+ 30
- 0
modelscope/models/cv/tinynas_detection/utils.py View File

@@ -0,0 +1,30 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.

import importlib
import os
import sys
from os.path import dirname, join


def get_config_by_file(config_file):
try:
sys.path.append(os.path.dirname(config_file))
current_config = importlib.import_module(
os.path.basename(config_file).split('.')[0])
exp = current_config.Config()
except Exception:
        raise ImportError(
            "{} doesn't contain a class named 'Config'".format(config_file))
return exp


def parse_config(config_file):
    """Get a config object from a file.

    Args:
        config_file (str): file path of the config.
    """
    assert config_file is not None, 'please provide a config file'
    return get_config_by_file(config_file)
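
For illustration, parse_config expects a Python file that defines a class literally named Config; a minimal hypothetical example (the attribute names here are made up, not the real model schema):

# contents of a hypothetical my_detector_config.py
class Config:

    def __init__(self):
        self.backbone = {'name': 'TinyNAS'}
        self.neck = {'name': 'GiraffeNeckV2'}
        self.head = {'name': 'GFocalV2'}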

+ 61
- 0
modelscope/pipelines/cv/tinynas_detection_pipeline.py View File

@@ -0,0 +1,61 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from typing import Any, Dict

import cv2
import numpy as np
import torch

from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
Tasks.image_object_detection, module_name=Pipelines.tinynas_detection)
class TinynasDetectionPipeline(Pipeline):

def __init__(self, model: str, **kwargs):
"""
model: model id on modelscope hub.
"""
super().__init__(model=model, auto_collate=False, **kwargs)
if torch.cuda.is_available():
self.device = 'cuda'
else:
self.device = 'cpu'
self.model.to(self.device)
self.model.eval()

def preprocess(self, input: Input) -> Dict[str, Any]:

img = LoadImage.convert_to_ndarray(input)
self.img = img
        img = img.astype(float)  # the np.float alias was removed in NumPy 1.24+
img = self.model.preprocess(img)
result = {'img': img.to(self.device)}
return result

def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:

outputs = self.model.inference(input['img'])
result = {'data': outputs}
return result

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:

bboxes, scores, labels = self.model.postprocess(inputs['data'])
if bboxes is None:
return None
outputs = {
OutputKeys.SCORES: scores,
OutputKeys.LABELS: labels,
OutputKeys.BOXES: bboxes
}
return outputs
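
Downstream code can then unpack the detections; a hypothetical snippet, with the pipeline constructed as in the test below:

result = tinynas_object_detection('data/test/images/image_detection.jpg')
if result is not None:
    for box, score, label in zip(result[OutputKeys.BOXES],
                                 result[OutputKeys.SCORES],
                                 result[OutputKeys.LABELS]):
        print(label, float(score), box)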

+ 20
- 0
tests/pipelines/test_tinynas_detection.py View File

@@ -0,0 +1,20 @@
import unittest

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class TinynasObjectDetectionTest(unittest.TestCase):

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run(self):
tinynas_object_detection = pipeline(
Tasks.image_object_detection, model='damo/cv_tinynas_detection')
result = tinynas_object_detection(
'data/test/images/image_detection.jpg')
print(result)


if __name__ == '__main__':
unittest.main()
