Browse Source

add damoyolo-t & damoyolo-m

1. add damoyolo-t & damoyolo-m models
2. fix the configuration overlap error
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10816561
master^2
xianzhe.xxz yingda.chen 3 years ago
parent
commit
9da5091d4d
25 changed files with 1463 additions and 606 deletions
  1. +1
    -1
      modelscope/models/cv/tinynas_detection/__init__.py
  2. +7
    -4
      modelscope/models/cv/tinynas_detection/backbone/__init__.py
  3. +2
    -3
      modelscope/models/cv/tinynas_detection/backbone/darknet.py
  4. +0
    -359
      modelscope/models/cv/tinynas_detection/backbone/tinynas.py
  5. +295
    -0
      modelscope/models/cv/tinynas_detection/backbone/tinynas_csp.py
  6. +238
    -0
      modelscope/models/cv/tinynas_detection/backbone/tinynas_res.py
  7. +1
    -1
      modelscope/models/cv/tinynas_detection/core/__init__.py
  8. +1
    -1
      modelscope/models/cv/tinynas_detection/core/base_ops.py
  9. +1
    -1
      modelscope/models/cv/tinynas_detection/core/neck_ops.py
  10. +435
    -0
      modelscope/models/cv/tinynas_detection/core/ops.py
  11. +1
    -1
      modelscope/models/cv/tinynas_detection/core/repvgg_block.py
  12. +1
    -1
      modelscope/models/cv/tinynas_detection/core/utils.py
  13. +2
    -2
      modelscope/models/cv/tinynas_detection/detector.py
  14. +4
    -1
      modelscope/models/cv/tinynas_detection/head/__init__.py
  15. +3
    -2
      modelscope/models/cv/tinynas_detection/head/gfocal_v2_tiny.py
  16. +288
    -0
      modelscope/models/cv/tinynas_detection/head/zero_head.py
  17. +2
    -2
      modelscope/models/cv/tinynas_detection/neck/__init__.py
  18. +1
    -1
      modelscope/models/cv/tinynas_detection/neck/giraffe_config.py
  19. +3
    -2
      modelscope/models/cv/tinynas_detection/neck/giraffe_fpn.py
  20. +132
    -0
      modelscope/models/cv/tinynas_detection/neck/giraffe_fpn_btn.py
  21. +0
    -200
      modelscope/models/cv/tinynas_detection/neck/giraffe_fpn_v2.py
  22. +1
    -1
      modelscope/models/cv/tinynas_detection/tinynas_damoyolo.py
  23. +1
    -1
      modelscope/models/cv/tinynas_detection/tinynas_detector.py
  24. +23
    -20
      modelscope/models/cv/tinynas_detection/utils.py
  25. +20
    -2
      tests/pipelines/test_tinynas_detection.py

+ 1
- 1
modelscope/models/cv/tinynas_detection/__init__.py View File

@@ -1,5 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.
# The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo.

from typing import TYPE_CHECKING



+ 7
- 4
modelscope/models/cv/tinynas_detection/backbone/__init__.py View File

@@ -1,10 +1,11 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.
# The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo.

import copy

from .darknet import CSPDarknet
from .tinynas import load_tinynas_net
from .tinynas_csp import load_tinynas_net as load_tinynas_net_csp
from .tinynas_res import load_tinynas_net as load_tinynas_net_res


def build_backbone(cfg):
@@ -12,5 +13,7 @@ def build_backbone(cfg):
name = backbone_cfg.pop('name')
if name == 'CSPDarknet':
return CSPDarknet(**backbone_cfg)
elif name == 'TinyNAS':
return load_tinynas_net(backbone_cfg)
elif name == 'TinyNAS_csp':
return load_tinynas_net_csp(backbone_cfg)
elif name == 'TinyNAS_res':
return load_tinynas_net_res(backbone_cfg)

+ 2
- 3
modelscope/models/cv/tinynas_detection/backbone/darknet.py View File

@@ -1,12 +1,11 @@
# Copyright (c) Megvii Inc. All rights reserved.
# Copyright (c) Alibaba, Inc. and its affiliates.
# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.

import torch
from torch import nn

from ..core.base_ops import (BaseConv, CSPLayer, DWConv, Focus, ResLayer,
SPPBottleneck)
from modelscope.models.cv.tinynas_detection.core.base_ops import (
BaseConv, CSPLayer, DWConv, Focus, ResLayer, SPPBottleneck)


class CSPDarknet(nn.Module):


+ 0
- 359
modelscope/models/cv/tinynas_detection/backbone/tinynas.py View File

@@ -1,359 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.

import torch
import torch.nn as nn

from modelscope.utils.file_utils import read_file
from ..core.base_ops import Focus, SPPBottleneck, get_activation
from ..core.repvgg_block import RepVggBlock


class ConvKXBN(nn.Module):

def __init__(self, in_c, out_c, kernel_size, stride):
super(ConvKXBN, self).__init__()
self.conv1 = nn.Conv2d(
in_c,
out_c,
kernel_size,
stride, (kernel_size - 1) // 2,
groups=1,
bias=False)
self.bn1 = nn.BatchNorm2d(out_c)

def forward(self, x):
return self.bn1(self.conv1(x))


class ConvKXBNRELU(nn.Module):

def __init__(self, in_c, out_c, kernel_size, stride, act='silu'):
super(ConvKXBNRELU, self).__init__()
self.conv = ConvKXBN(in_c, out_c, kernel_size, stride)
if act is None:
self.activation_function = torch.relu
else:
self.activation_function = get_activation(act)

def forward(self, x):
output = self.conv(x)
return self.activation_function(output)


class ResConvK1KX(nn.Module):

def __init__(self,
in_c,
out_c,
btn_c,
kernel_size,
stride,
force_resproj=False,
act='silu',
reparam=False):
super(ResConvK1KX, self).__init__()
self.stride = stride
self.conv1 = ConvKXBN(in_c, btn_c, 1, 1)
if not reparam:
self.conv2 = ConvKXBN(btn_c, out_c, 3, stride)
else:
self.conv2 = RepVggBlock(
btn_c, out_c, kernel_size, stride, act='identity')

if act is None:
self.activation_function = torch.relu
else:
self.activation_function = get_activation(act)

if stride == 2:
self.residual_downsample = nn.AvgPool2d(kernel_size=2, stride=2)
else:
self.residual_downsample = nn.Identity()

if in_c != out_c or force_resproj:
self.residual_proj = ConvKXBN(in_c, out_c, 1, 1)
else:
self.residual_proj = nn.Identity()

def forward(self, x):
if self.stride != 2:
reslink = self.residual_downsample(x)
reslink = self.residual_proj(reslink)

output = x
output = self.conv1(output)
output = self.activation_function(output)
output = self.conv2(output)
if self.stride != 2:
output = output + reslink
output = self.activation_function(output)

return output


class SuperResConvK1KX(nn.Module):

def __init__(self,
in_c,
out_c,
btn_c,
kernel_size,
stride,
num_blocks,
with_spp=False,
act='silu',
reparam=False):
super(SuperResConvK1KX, self).__init__()
if act is None:
self.act = torch.relu
else:
self.act = get_activation(act)
self.block_list = nn.ModuleList()
for block_id in range(num_blocks):
if block_id == 0:
in_channels = in_c
out_channels = out_c
this_stride = stride
force_resproj = False # as a part of CSPLayer, DO NOT need this flag
this_kernel_size = kernel_size
else:
in_channels = out_c
out_channels = out_c
this_stride = 1
force_resproj = False
this_kernel_size = kernel_size
the_block = ResConvK1KX(
in_channels,
out_channels,
btn_c,
this_kernel_size,
this_stride,
force_resproj,
act=act,
reparam=reparam)
self.block_list.append(the_block)
if block_id == 0 and with_spp:
self.block_list.append(
SPPBottleneck(out_channels, out_channels))

def forward(self, x):
output = x
for block in self.block_list:
output = block(output)
return output


class ResConvKXKX(nn.Module):

def __init__(self,
in_c,
out_c,
btn_c,
kernel_size,
stride,
force_resproj=False,
act='silu'):
super(ResConvKXKX, self).__init__()
self.stride = stride
if self.stride == 2:
self.downsampler = ConvKXBNRELU(in_c, out_c, 3, 2, act=act)
else:
self.conv1 = ConvKXBN(in_c, btn_c, kernel_size, 1)
self.conv2 = RepVggBlock(
btn_c, out_c, kernel_size, stride, act='identity')

if act is None:
self.activation_function = torch.relu
else:
self.activation_function = get_activation(act)

if stride == 2:
self.residual_downsample = nn.AvgPool2d(
kernel_size=2, stride=2)
else:
self.residual_downsample = nn.Identity()

if in_c != out_c or force_resproj:
self.residual_proj = ConvKXBN(in_c, out_c, 1, 1)
else:
self.residual_proj = nn.Identity()

def forward(self, x):
if self.stride == 2:
return self.downsampler(x)
reslink = self.residual_downsample(x)
reslink = self.residual_proj(reslink)

output = x
output = self.conv1(output)
output = self.activation_function(output)
output = self.conv2(output)

output = output + reslink
output = self.activation_function(output)

return output


class SuperResConvKXKX(nn.Module):

def __init__(self,
in_c,
out_c,
btn_c,
kernel_size,
stride,
num_blocks,
with_spp=False,
act='silu'):
super(SuperResConvKXKX, self).__init__()
if act is None:
self.act = torch.relu
else:
self.act = get_activation(act)
self.block_list = nn.ModuleList()
for block_id in range(num_blocks):
if block_id == 0:
in_channels = in_c
out_channels = out_c
this_stride = stride
force_resproj = False # as a part of CSPLayer, DO NOT need this flag
this_kernel_size = kernel_size
else:
in_channels = out_c
out_channels = out_c
this_stride = 1
force_resproj = False
this_kernel_size = kernel_size
the_block = ResConvKXKX(
in_channels,
out_channels,
btn_c,
this_kernel_size,
this_stride,
force_resproj,
act=act)
self.block_list.append(the_block)
if block_id == 0 and with_spp:
self.block_list.append(
SPPBottleneck(out_channels, out_channels))

def forward(self, x):
output = x
for block in self.block_list:
output = block(output)
return output


class TinyNAS(nn.Module):

def __init__(self,
structure_info=None,
out_indices=[0, 1, 2, 4, 5],
out_channels=[None, None, 128, 256, 512],
with_spp=False,
use_focus=False,
need_conv1=True,
act='silu',
reparam=False):
super(TinyNAS, self).__init__()
assert len(out_indices) == len(out_channels)
self.out_indices = out_indices
self.need_conv1 = need_conv1

self.block_list = nn.ModuleList()
if need_conv1:
self.conv1_list = nn.ModuleList()
for idx, block_info in enumerate(structure_info):
the_block_class = block_info['class']
if the_block_class == 'ConvKXBNRELU':
if use_focus:
the_block = Focus(
block_info['in'],
block_info['out'],
block_info['k'],
act=act)
else:
the_block = ConvKXBNRELU(
block_info['in'],
block_info['out'],
block_info['k'],
block_info['s'],
act=act)
self.block_list.append(the_block)
elif the_block_class == 'SuperResConvK1KX':
spp = with_spp if idx == len(structure_info) - 1 else False
the_block = SuperResConvK1KX(
block_info['in'],
block_info['out'],
block_info['btn'],
block_info['k'],
block_info['s'],
block_info['L'],
spp,
act=act,
reparam=reparam)
self.block_list.append(the_block)
elif the_block_class == 'SuperResConvKXKX':
spp = with_spp if idx == len(structure_info) - 1 else False
the_block = SuperResConvKXKX(
block_info['in'],
block_info['out'],
block_info['btn'],
block_info['k'],
block_info['s'],
block_info['L'],
spp,
act=act)
self.block_list.append(the_block)
if need_conv1:
if idx in self.out_indices and out_channels[
self.out_indices.index(idx)] is not None:
self.conv1_list.append(
nn.Conv2d(block_info['out'],
out_channels[self.out_indices.index(idx)],
1))
else:
self.conv1_list.append(None)

def init_weights(self, pretrain=None):
pass

def forward(self, x):
output = x
stage_feature_list = []
for idx, block in enumerate(self.block_list):
output = block(output)
if idx in self.out_indices:
if self.need_conv1 and self.conv1_list[idx] is not None:
true_out = self.conv1_list[idx](output)
stage_feature_list.append(true_out)
else:
stage_feature_list.append(output)
return stage_feature_list


def load_tinynas_net(backbone_cfg):
# load masternet model to path
import ast
net_structure_str = read_file(backbone_cfg.structure_file)
struct_str = ''.join([x.strip() for x in net_structure_str])
struct_info = ast.literal_eval(struct_str)
for layer in struct_info:
if 'nbitsA' in layer:
del layer['nbitsA']
if 'nbitsW' in layer:
del layer['nbitsW']

model = TinyNAS(
structure_info=struct_info,
out_indices=backbone_cfg.out_indices,
out_channels=backbone_cfg.out_channels,
with_spp=backbone_cfg.with_spp,
use_focus=backbone_cfg.use_focus,
act=backbone_cfg.act,
need_conv1=backbone_cfg.need_conv1,
reparam=backbone_cfg.reparam)

return model

+ 295
- 0
modelscope/models/cv/tinynas_detection/backbone/tinynas_csp.py View File

@@ -0,0 +1,295 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# The DAMO-YOLO implementation is also open-sourced by the authors, and available
# at https://github.com/tinyvision/damo-yolo.

import torch
import torch.nn as nn

from modelscope.models.cv.tinynas_detection.core.ops import (Focus, RepConv,
SPPBottleneck,
get_activation)
from modelscope.utils.file_utils import read_file


class ConvKXBN(nn.Module):

def __init__(self, in_c, out_c, kernel_size, stride):
super(ConvKXBN, self).__init__()
self.conv1 = nn.Conv2d(
in_c,
out_c,
kernel_size,
stride, (kernel_size - 1) // 2,
groups=1,
bias=False)
self.bn1 = nn.BatchNorm2d(out_c)

def forward(self, x):
return self.bn1(self.conv1(x))


class ConvKXBNRELU(nn.Module):

def __init__(self, in_c, out_c, kernel_size, stride, act='silu'):
super(ConvKXBNRELU, self).__init__()
self.conv = ConvKXBN(in_c, out_c, kernel_size, stride)
if act is None:
self.activation_function = torch.relu
else:
self.activation_function = get_activation(act)

def forward(self, x):
output = self.conv(x)
return self.activation_function(output)


class ResConvBlock(nn.Module):

def __init__(self,
in_c,
out_c,
btn_c,
kernel_size,
stride,
act='silu',
reparam=False,
block_type='k1kx'):
super(ResConvBlock, self).__init__()
self.stride = stride
if block_type == 'k1kx':
self.conv1 = ConvKXBN(in_c, btn_c, kernel_size=1, stride=1)
else:
self.conv1 = ConvKXBN(
in_c, btn_c, kernel_size=kernel_size, stride=1)
if not reparam:
self.conv2 = ConvKXBN(btn_c, out_c, kernel_size, stride)
else:
self.conv2 = RepConv(
btn_c, out_c, kernel_size, stride, act='identity')

self.activation_function = get_activation(act)

if in_c != out_c and stride != 2:
self.residual_proj = ConvKXBN(in_c, out_c, kernel_size=1, stride=1)
else:
self.residual_proj = None

def forward(self, x):
if self.residual_proj is not None:
reslink = self.residual_proj(x)
else:
reslink = x
x = self.conv1(x)
x = self.activation_function(x)
x = self.conv2(x)
if self.stride != 2:
x = x + reslink
x = self.activation_function(x)
return x


class CSPStem(nn.Module):

def __init__(self,
in_c,
out_c,
btn_c,
stride,
kernel_size,
num_blocks,
act='silu',
reparam=False,
block_type='k1kx'):
super(CSPStem, self).__init__()
self.in_channels = in_c
self.out_channels = out_c
self.stride = stride
if self.stride == 2:
self.num_blocks = num_blocks - 1
else:
self.num_blocks = num_blocks
self.kernel_size = kernel_size
self.act = act
self.block_type = block_type
out_c = out_c // 2

if act is None:
self.act = torch.relu
else:
self.act = get_activation(act)
self.block_list = nn.ModuleList()
for block_id in range(self.num_blocks):
if self.stride == 1 and block_id == 0:
in_c = in_c // 2
else:
in_c = out_c
the_block = ResConvBlock(
in_c,
out_c,
btn_c,
kernel_size,
stride=1,
act=act,
reparam=reparam,
block_type=block_type)
self.block_list.append(the_block)

def forward(self, x):
output = x
for block in self.block_list:
output = block(output)
return output


class TinyNAS(nn.Module):

def __init__(self,
structure_info=None,
out_indices=[2, 3, 4],
with_spp=False,
use_focus=False,
act='silu',
reparam=False):
super(TinyNAS, self).__init__()
self.out_indices = out_indices
self.block_list = nn.ModuleList()
self.stride_list = []

for idx, block_info in enumerate(structure_info):
the_block_class = block_info['class']
if the_block_class == 'ConvKXBNRELU':
if use_focus and idx == 0:
the_block = Focus(
block_info['in'],
block_info['out'],
block_info['k'],
act=act)
else:
the_block = ConvKXBNRELU(
block_info['in'],
block_info['out'],
block_info['k'],
block_info['s'],
act=act)
elif the_block_class == 'SuperResConvK1KX':
the_block = CSPStem(
block_info['in'],
block_info['out'],
block_info['btn'],
block_info['s'],
block_info['k'],
block_info['L'],
act=act,
reparam=reparam,
block_type='k1kx')
elif the_block_class == 'SuperResConvKXKX':
the_block = CSPStem(
block_info['in'],
block_info['out'],
block_info['btn'],
block_info['s'],
block_info['k'],
block_info['L'],
act=act,
reparam=reparam,
block_type='kxkx')
else:
raise NotImplementedError

self.block_list.append(the_block)

self.csp_stage = nn.ModuleList()
self.csp_stage.append(self.block_list[0])
self.csp_stage.append(CSPWrapper(self.block_list[1]))
self.csp_stage.append(CSPWrapper(self.block_list[2]))
self.csp_stage.append(
CSPWrapper((self.block_list[3], self.block_list[4])))
self.csp_stage.append(
CSPWrapper(self.block_list[5], with_spp=with_spp))
del self.block_list

def init_weights(self, pretrain=None):
pass

def forward(self, x):
output = x
stage_feature_list = []
for idx, block in enumerate(self.csp_stage):
output = block(output)
if idx in self.out_indices:
stage_feature_list.append(output)
return stage_feature_list


class CSPWrapper(nn.Module):

def __init__(self, convstem, act='relu', reparam=False, with_spp=False):

super(CSPWrapper, self).__init__()
self.with_spp = with_spp
if isinstance(convstem, tuple):
in_c = convstem[0].in_channels
out_c = convstem[-1].out_channels
hidden_dim = convstem[0].out_channels // 2
_convstem = nn.ModuleList()
for modulelist in convstem:
for layer in modulelist.block_list:
_convstem.append(layer)
else:
in_c = convstem.in_channels
out_c = convstem.out_channels
hidden_dim = out_c // 2
_convstem = convstem.block_list

self.convstem = nn.ModuleList()
for layer in _convstem:
self.convstem.append(layer)

self.act = get_activation(act)
self.downsampler = ConvKXBNRELU(
in_c, hidden_dim * 2, 3, 2, act=self.act)
if self.with_spp:
self.spp = SPPBottleneck(hidden_dim * 2, hidden_dim * 2)
if len(self.convstem) > 0:
self.conv_start = ConvKXBNRELU(
hidden_dim * 2, hidden_dim, 1, 1, act=self.act)
self.conv_shortcut = ConvKXBNRELU(
hidden_dim * 2, out_c // 2, 1, 1, act=self.act)
self.conv_fuse = ConvKXBNRELU(out_c, out_c, 1, 1, act=self.act)

def forward(self, x):
x = self.downsampler(x)
if self.with_spp:
x = self.spp(x)
if len(self.convstem) > 0:
shortcut = self.conv_shortcut(x)
x = self.conv_start(x)
for block in self.convstem:
x = block(x)
x = torch.cat((x, shortcut), dim=1)
x = self.conv_fuse(x)
return x


def load_tinynas_net(backbone_cfg):
# load masternet model to path
import ast

net_structure_str = read_file(backbone_cfg.structure_file)
struct_str = ''.join([x.strip() for x in net_structure_str])
struct_info = ast.literal_eval(struct_str)
for layer in struct_info:
if 'nbitsA' in layer:
del layer['nbitsA']
if 'nbitsW' in layer:
del layer['nbitsW']

model = TinyNAS(
structure_info=struct_info,
out_indices=backbone_cfg.out_indices,
with_spp=backbone_cfg.with_spp,
use_focus=backbone_cfg.use_focus,
act=backbone_cfg.act,
reparam=backbone_cfg.reparam)

return model

+ 238
- 0
modelscope/models/cv/tinynas_detection/backbone/tinynas_res.py View File

@@ -0,0 +1,238 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# The DAMO-YOLO implementation is also open-sourced by the authors, and available
# at https://github.com/tinyvision/damo-yolo.

import torch
import torch.nn as nn

from modelscope.models.cv.tinynas_detection.core.ops import (Focus, RepConv,
SPPBottleneck,
get_activation)
from modelscope.utils.file_utils import read_file


class ConvKXBN(nn.Module):

def __init__(self, in_c, out_c, kernel_size, stride):
super(ConvKXBN, self).__init__()
self.conv1 = nn.Conv2d(
in_c,
out_c,
kernel_size,
stride, (kernel_size - 1) // 2,
groups=1,
bias=False)
self.bn1 = nn.BatchNorm2d(out_c)

def forward(self, x):
return self.bn1(self.conv1(x))


class ConvKXBNRELU(nn.Module):

def __init__(self, in_c, out_c, kernel_size, stride, act='silu'):
super(ConvKXBNRELU, self).__init__()
self.conv = ConvKXBN(in_c, out_c, kernel_size, stride)
if act is None:
self.activation_function = torch.relu
else:
self.activation_function = get_activation(act)

def forward(self, x):
output = self.conv(x)
return self.activation_function(output)


class ResConvBlock(nn.Module):

def __init__(self,
in_c,
out_c,
btn_c,
kernel_size,
stride,
act='silu',
reparam=False,
block_type='k1kx'):
super(ResConvBlock, self).__init__()
self.stride = stride
if block_type == 'k1kx':
self.conv1 = ConvKXBN(in_c, btn_c, kernel_size=1, stride=1)
else:
self.conv1 = ConvKXBN(
in_c, btn_c, kernel_size=kernel_size, stride=1)

if not reparam:
self.conv2 = ConvKXBN(btn_c, out_c, kernel_size, stride)
else:
self.conv2 = RepConv(
btn_c, out_c, kernel_size, stride, act='identity')

self.activation_function = get_activation(act)

if in_c != out_c and stride != 2:
self.residual_proj = ConvKXBN(in_c, out_c, 1, 1)
else:
self.residual_proj = None

def forward(self, x):
if self.residual_proj is not None:
reslink = self.residual_proj(x)
else:
reslink = x
x = self.conv1(x)
x = self.activation_function(x)
x = self.conv2(x)
if self.stride != 2:
x = x + reslink
x = self.activation_function(x)
return x


class SuperResStem(nn.Module):

def __init__(self,
in_c,
out_c,
btn_c,
kernel_size,
stride,
num_blocks,
with_spp=False,
act='silu',
reparam=False,
block_type='k1kx'):
super(SuperResStem, self).__init__()
if act is None:
self.act = torch.relu
else:
self.act = get_activation(act)
self.block_list = nn.ModuleList()
for block_id in range(num_blocks):
if block_id == 0:
in_channels = in_c
out_channels = out_c
this_stride = stride
this_kernel_size = kernel_size
else:
in_channels = out_c
out_channels = out_c
this_stride = 1
this_kernel_size = kernel_size
the_block = ResConvBlock(
in_channels,
out_channels,
btn_c,
this_kernel_size,
this_stride,
act=act,
reparam=reparam,
block_type=block_type)
self.block_list.append(the_block)
if block_id == 0 and with_spp:
self.block_list.append(
SPPBottleneck(out_channels, out_channels))

def forward(self, x):
output = x
for block in self.block_list:
output = block(output)
return output


class TinyNAS(nn.Module):

def __init__(self,
structure_info=None,
out_indices=[2, 4, 5],
with_spp=False,
use_focus=False,
act='silu',
reparam=False):
super(TinyNAS, self).__init__()
self.out_indices = out_indices
self.block_list = nn.ModuleList()

for idx, block_info in enumerate(structure_info):
the_block_class = block_info['class']
if the_block_class == 'ConvKXBNRELU':
if use_focus:
the_block = Focus(
block_info['in'],
block_info['out'],
block_info['k'],
act=act)
else:
the_block = ConvKXBNRELU(
block_info['in'],
block_info['out'],
block_info['k'],
block_info['s'],
act=act)
self.block_list.append(the_block)
elif the_block_class == 'SuperResConvK1KX':
spp = with_spp if idx == len(structure_info) - 1 else False
the_block = SuperResStem(
block_info['in'],
block_info['out'],
block_info['btn'],
block_info['k'],
block_info['s'],
block_info['L'],
spp,
act=act,
reparam=reparam,
block_type='k1kx')
self.block_list.append(the_block)
elif the_block_class == 'SuperResConvKXKX':
spp = with_spp if idx == len(structure_info) - 1 else False
the_block = SuperResStem(
block_info['in'],
block_info['out'],
block_info['btn'],
block_info['k'],
block_info['s'],
block_info['L'],
spp,
act=act,
reparam=reparam,
block_type='kxkx')
self.block_list.append(the_block)
else:
raise NotImplementedError

def init_weights(self, pretrain=None):
pass

def forward(self, x):
output = x
stage_feature_list = []
for idx, block in enumerate(self.block_list):
output = block(output)
if idx in self.out_indices:
stage_feature_list.append(output)
return stage_feature_list


def load_tinynas_net(backbone_cfg):
# load masternet model to path
import ast

net_structure_str = read_file(backbone_cfg.structure_file)
struct_str = ''.join([x.strip() for x in net_structure_str])
struct_info = ast.literal_eval(struct_str)
for layer in struct_info:
if 'nbitsA' in layer:
del layer['nbitsA']
if 'nbitsW' in layer:
del layer['nbitsW']

model = TinyNAS(
structure_info=struct_info,
out_indices=backbone_cfg.out_indices,
with_spp=backbone_cfg.with_spp,
use_focus=backbone_cfg.use_focus,
act=backbone_cfg.act,
reparam=backbone_cfg.reparam)

return model

+ 1
- 1
modelscope/models/cv/tinynas_detection/core/__init__.py View File

@@ -1,2 +1,2 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.
# The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo.

+ 1
- 1
modelscope/models/cv/tinynas_detection/core/base_ops.py View File

@@ -1,5 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.
# The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo.
import math

import torch


+ 1
- 1
modelscope/models/cv/tinynas_detection/core/neck_ops.py View File

@@ -1,5 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.
# The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo.

import numpy as np
import torch


+ 435
- 0
modelscope/models/cv/tinynas_detection/core/ops.py View File

@@ -0,0 +1,435 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class SiLU(nn.Module):
"""export-friendly version of nn.SiLU()"""

@staticmethod
def forward(x):
return x * torch.sigmoid(x)


class Swish(nn.Module):

def __init__(self, inplace=True):
super(Swish, self).__init__()
self.inplace = inplace

def forward(self, x):
if self.inplace:
x.mul_(F.sigmoid(x))
return x
else:
return x * F.sigmoid(x)


def get_activation(name='silu', inplace=True):
if name is None:
return nn.Identity()

if isinstance(name, str):
if name == 'silu':
module = nn.SiLU(inplace=inplace)
elif name == 'relu':
module = nn.ReLU(inplace=inplace)
elif name == 'lrelu':
module = nn.LeakyReLU(0.1, inplace=inplace)
elif name == 'swish':
module = Swish(inplace=inplace)
elif name == 'hardsigmoid':
module = nn.Hardsigmoid(inplace=inplace)
elif name == 'identity':
module = nn.Identity()
else:
raise AttributeError('Unsupported act type: {}'.format(name))
return module

elif isinstance(name, nn.Module):
return name

else:
raise AttributeError('Unsupported act type: {}'.format(name))


def get_norm(name, out_channels, inplace=True):
if name == 'bn':
module = nn.BatchNorm2d(out_channels)
else:
raise NotImplementedError
return module


class ConvBNAct(nn.Module):
"""A Conv2d -> Batchnorm -> silu/leaky relu block"""

def __init__(
self,
in_channels,
out_channels,
ksize,
stride=1,
groups=1,
bias=False,
act='silu',
norm='bn',
reparam=False,
):
super().__init__()
# same padding
pad = (ksize - 1) // 2
self.conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size=ksize,
stride=stride,
padding=pad,
groups=groups,
bias=bias,
)
if norm is not None:
self.bn = get_norm(norm, out_channels, inplace=True)
if act is not None:
self.act = get_activation(act, inplace=True)
self.with_norm = norm is not None
self.with_act = act is not None

def forward(self, x):
x = self.conv(x)
if self.with_norm:
x = self.bn(x)
if self.with_act:
x = self.act(x)
return x

def fuseforward(self, x):
return self.act(self.conv(x))


class SPPBottleneck(nn.Module):
"""Spatial pyramid pooling layer used in YOLOv3-SPP"""

def __init__(self,
in_channels,
out_channels,
kernel_sizes=(5, 9, 13),
activation='silu'):
super().__init__()
hidden_channels = in_channels // 2
self.conv1 = ConvBNAct(
in_channels, hidden_channels, 1, stride=1, act=activation)
self.m = nn.ModuleList([
nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2)
for ks in kernel_sizes
])
conv2_channels = hidden_channels * (len(kernel_sizes) + 1)
self.conv2 = ConvBNAct(
conv2_channels, out_channels, 1, stride=1, act=activation)

def forward(self, x):
x = self.conv1(x)
x = torch.cat([x] + [m(x) for m in self.m], dim=1)
x = self.conv2(x)
return x


class Focus(nn.Module):
"""Focus width and height information into channel space."""

def __init__(self,
in_channels,
out_channels,
ksize=1,
stride=1,
act='silu'):
super().__init__()
self.conv = ConvBNAct(
in_channels * 4, out_channels, ksize, stride, act=act)

def forward(self, x):
# shape of x (b,c,w,h) -> y(b,4c,w/2,h/2)
patch_top_left = x[..., ::2, ::2]
patch_top_right = x[..., ::2, 1::2]
patch_bot_left = x[..., 1::2, ::2]
patch_bot_right = x[..., 1::2, 1::2]
x = torch.cat(
(
patch_top_left,
patch_bot_left,
patch_top_right,
patch_bot_right,
),
dim=1,
)
return self.conv(x)


class BasicBlock_3x3_Reverse(nn.Module):

def __init__(self,
ch_in,
ch_hidden_ratio,
ch_out,
act='relu',
shortcut=True):
super(BasicBlock_3x3_Reverse, self).__init__()
assert ch_in == ch_out
ch_hidden = int(ch_in * ch_hidden_ratio)
self.conv1 = ConvBNAct(ch_hidden, ch_out, 3, stride=1, act=act)
self.conv2 = RepConv(ch_in, ch_hidden, 3, stride=1, act=act)
self.shortcut = shortcut

def forward(self, x):
y = self.conv2(x)
y = self.conv1(y)
if self.shortcut:
return x + y
else:
return y


class SPP(nn.Module):

def __init__(
self,
ch_in,
ch_out,
k,
pool_size,
act='swish',
):
super(SPP, self).__init__()
self.pool = []
for i, size in enumerate(pool_size):
pool = nn.MaxPool2d(
kernel_size=size, stride=1, padding=size // 2, ceil_mode=False)
self.add_module('pool{}'.format(i), pool)
self.pool.append(pool)
self.conv = ConvBNAct(ch_in, ch_out, k, act=act)

def forward(self, x):
outs = [x]

for pool in self.pool:
outs.append(pool(x))
y = torch.cat(outs, axis=1)

y = self.conv(y)
return y


class CSPStage(nn.Module):

def __init__(self,
block_fn,
ch_in,
ch_hidden_ratio,
ch_out,
n,
act='swish',
spp=False):
super(CSPStage, self).__init__()

split_ratio = 2
ch_first = int(ch_out // split_ratio)
ch_mid = int(ch_out - ch_first)
self.conv1 = ConvBNAct(ch_in, ch_first, 1, act=act)
self.conv2 = ConvBNAct(ch_in, ch_mid, 1, act=act)
self.convs = nn.Sequential()

next_ch_in = ch_mid
for i in range(n):
if block_fn == 'BasicBlock_3x3_Reverse':
self.convs.add_module(
str(i),
BasicBlock_3x3_Reverse(
next_ch_in,
ch_hidden_ratio,
ch_mid,
act=act,
shortcut=True))
else:
raise NotImplementedError
if i == (n - 1) // 2 and spp:
self.convs.add_module(
'spp', SPP(ch_mid * 4, ch_mid, 1, [5, 9, 13], act=act))
next_ch_in = ch_mid
self.conv3 = ConvBNAct(ch_mid * n + ch_first, ch_out, 1, act=act)

def forward(self, x):
y1 = self.conv1(x)
y2 = self.conv2(x)

mid_out = [y1]
for conv in self.convs:
y2 = conv(y2)
mid_out.append(y2)
y = torch.cat(mid_out, axis=1)
y = self.conv3(y)
return y


def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups=1):
'''Basic cell for rep-style block, including conv and bn'''
result = nn.Sequential()
result.add_module(
'conv',
nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
bias=False))
result.add_module('bn', nn.BatchNorm2d(num_features=out_channels))
return result


class RepConv(nn.Module):
'''RepConv is a basic rep-style block, including training and deploy status
Code is based on https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py
'''

def __init__(self,
in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1,
dilation=1,
groups=1,
padding_mode='zeros',
deploy=False,
act='relu',
norm=None):
super(RepConv, self).__init__()
self.deploy = deploy
self.groups = groups
self.in_channels = in_channels
self.out_channels = out_channels

assert kernel_size == 3
assert padding == 1

padding_11 = padding - kernel_size // 2

if isinstance(act, str):
self.nonlinearity = get_activation(act)
else:
self.nonlinearity = act

if deploy:
self.rbr_reparam = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=True,
padding_mode=padding_mode)

else:
self.rbr_identity = None
self.rbr_dense = conv_bn(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups)
self.rbr_1x1 = conv_bn(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
stride=stride,
padding=padding_11,
groups=groups)

def forward(self, inputs):
'''Forward process'''
if hasattr(self, 'rbr_reparam'):
return self.nonlinearity(self.rbr_reparam(inputs))

if self.rbr_identity is None:
id_out = 0
else:
id_out = self.rbr_identity(inputs)

return self.nonlinearity(
self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out)

def get_equivalent_kernel_bias(self):
kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
return kernel3x3 + self._pad_1x1_to_3x3_tensor(
kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid

def _pad_1x1_to_3x3_tensor(self, kernel1x1):
if kernel1x1 is None:
return 0
else:
return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])

def _fuse_bn_tensor(self, branch):
if branch is None:
return 0, 0
if isinstance(branch, nn.Sequential):
kernel = branch.conv.weight
running_mean = branch.bn.running_mean
running_var = branch.bn.running_var
gamma = branch.bn.weight
beta = branch.bn.bias
eps = branch.bn.eps
else:
assert isinstance(branch, nn.BatchNorm2d)
if not hasattr(self, 'id_tensor'):
input_dim = self.in_channels // self.groups
kernel_value = np.zeros((self.in_channels, input_dim, 3, 3),
dtype=np.float32)
for i in range(self.in_channels):
kernel_value[i, i % input_dim, 1, 1] = 1
self.id_tensor = torch.from_numpy(kernel_value).to(
branch.weight.device)
kernel = self.id_tensor
running_mean = branch.running_mean
running_var = branch.running_var
gamma = branch.weight
beta = branch.bias
eps = branch.eps
std = (running_var + eps).sqrt()
t = (gamma / std).reshape(-1, 1, 1, 1)
return kernel * t, beta - running_mean * gamma / std

def switch_to_deploy(self):
if hasattr(self, 'rbr_reparam'):
return
kernel, bias = self.get_equivalent_kernel_bias()
self.rbr_reparam = nn.Conv2d(
in_channels=self.rbr_dense.conv.in_channels,
out_channels=self.rbr_dense.conv.out_channels,
kernel_size=self.rbr_dense.conv.kernel_size,
stride=self.rbr_dense.conv.stride,
padding=self.rbr_dense.conv.padding,
dilation=self.rbr_dense.conv.dilation,
groups=self.rbr_dense.conv.groups,
bias=True)
self.rbr_reparam.weight.data = kernel
self.rbr_reparam.bias.data = bias
for para in self.parameters():
para.detach_()
self.__delattr__('rbr_dense')
self.__delattr__('rbr_1x1')
if hasattr(self, 'rbr_identity'):
self.__delattr__('rbr_identity')
if hasattr(self, 'id_tensor'):
self.__delattr__('id_tensor')
self.deploy = True

+ 1
- 1
modelscope/models/cv/tinynas_detection/core/repvgg_block.py View File

@@ -1,5 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.
# The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo.

import numpy as np
import torch


+ 1
- 1
modelscope/models/cv/tinynas_detection/core/utils.py View File

@@ -1,5 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.
# The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo.

import numpy as np
import torch


+ 2
- 2
modelscope/models/cv/tinynas_detection/detector.py View File

@@ -1,5 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.
# The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo.

import os.path as osp
import pickle
@@ -42,7 +42,7 @@ class SingleStageDetector(TorchModel):
self.conf_thre = config.model.head.nms_conf_thre
self.nms_thre = config.model.head.nms_iou_thre

if self.cfg.model.backbone.name == 'TinyNAS':
if 'TinyNAS' in self.cfg.model.backbone.name:
self.cfg.model.backbone.structure_file = osp.join(
model_dir, self.cfg.model.backbone.structure_file)
self.backbone = build_backbone(self.cfg.model.backbone)


+ 4
- 1
modelscope/models/cv/tinynas_detection/head/__init__.py View File

@@ -1,9 +1,10 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.
# The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo.

import copy

from .gfocal_v2_tiny import GFocalHead_Tiny
from .zero_head import ZeroHead


def build_head(cfg):
@@ -12,5 +13,7 @@ def build_head(cfg):
name = head_cfg.pop('name')
if name == 'GFocalV2':
return GFocalHead_Tiny(**head_cfg)
elif name == 'ZeroHead':
return ZeroHead(**head_cfg)
else:
raise NotImplementedError

+ 3
- 2
modelscope/models/cv/tinynas_detection/head/gfocal_v2_tiny.py View File

@@ -1,5 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.
# The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo.

import functools
from functools import partial
@@ -9,7 +9,8 @@ import torch
import torch.nn as nn
import torch.nn.functional as F

from ..core.base_ops import BaseConv, DWConv
from modelscope.models.cv.tinynas_detection.core.base_ops import (BaseConv,
DWConv)


class Scale(nn.Module):


+ 288
- 0
modelscope/models/cv/tinynas_detection/head/zero_head.py View File

@@ -0,0 +1,288 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# The DAMO-YOLO implementation is also open-sourced by the authors, and available
# at https://github.com/tinyvision/damo-yolo.
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F

from modelscope.models.cv.tinynas_detection.core.ops import ConvBNAct


class Scale(nn.Module):

def __init__(self, scale=1.0):
super(Scale, self).__init__()
self.scale = nn.Parameter(torch.tensor(scale, dtype=torch.float))

def forward(self, x):
return x * self.scale


def multi_apply(func, *args, **kwargs):

pfunc = partial(func, **kwargs) if kwargs else func
map_results = map(pfunc, *args)
return tuple(map(list, zip(*map_results)))


def distance2bbox(points, distance, max_shape=None):
"""Decode distance prediction to bounding box.
"""
x1 = points[..., 0] - distance[..., 0]
y1 = points[..., 1] - distance[..., 1]
x2 = points[..., 0] + distance[..., 2]
y2 = points[..., 1] + distance[..., 3]
if max_shape is not None:
x1 = x1.clamp(min=0, max=max_shape[1])
y1 = y1.clamp(min=0, max=max_shape[0])
x2 = x2.clamp(min=0, max=max_shape[1])
y2 = y2.clamp(min=0, max=max_shape[0])
return torch.stack([x1, y1, x2, y2], -1)


def bbox2distance(points, bbox, max_dis=None, eps=0.1):
"""Decode bounding box based on distances.
"""
left = points[:, 0] - bbox[:, 0]
top = points[:, 1] - bbox[:, 1]
right = bbox[:, 2] - points[:, 0]
bottom = bbox[:, 3] - points[:, 1]
if max_dis is not None:
left = left.clamp(min=0, max=max_dis - eps)
top = top.clamp(min=0, max=max_dis - eps)
right = right.clamp(min=0, max=max_dis - eps)
bottom = bottom.clamp(min=0, max=max_dis - eps)
return torch.stack([left, top, right, bottom], -1)


class Integral(nn.Module):
"""A fixed layer for calculating integral result from distribution.
"""

def __init__(self, reg_max=16):
super(Integral, self).__init__()
self.reg_max = reg_max
self.register_buffer('project',
torch.linspace(0, self.reg_max, self.reg_max + 1))

def forward(self, x):
"""Forward feature from the regression head to get integral result of
bounding box location.
"""
b, hw, _, _ = x.size()
x = x.reshape(b * hw * 4, self.reg_max + 1)
y = self.project.type_as(x).unsqueeze(1)
x = torch.matmul(x, y).reshape(b, hw, 4)
return x


class ZeroHead(nn.Module):
"""Ref to Generalized Focal Loss V2: Learning Reliable Localization Quality
Estimation for Dense Object Detection.
"""

def __init__(
self,
num_classes,
in_channels,
stacked_convs=4, # 4
feat_channels=256,
reg_max=12,
strides=[8, 16, 32],
norm='gn',
act='relu',
nms_conf_thre=0.05,
nms_iou_thre=0.7,
nms=True,
**kwargs):
self.in_channels = in_channels
self.num_classes = num_classes
self.stacked_convs = stacked_convs
self.act = act
self.strides = strides
if stacked_convs == 0:
feat_channels = in_channels
if isinstance(feat_channels, list):
self.feat_channels = feat_channels
else:
self.feat_channels = [feat_channels] * len(self.strides)
# add 1 for keep consistance with former models
self.cls_out_channels = num_classes + 1
self.reg_max = reg_max

self.nms = nms
self.nms_conf_thre = nms_conf_thre
self.nms_iou_thre = nms_iou_thre

self.feat_size = [torch.zeros(4) for _ in strides]

super(ZeroHead, self).__init__()
self.integral = Integral(self.reg_max)

self._init_layers()

def _build_not_shared_convs(self, in_channel, feat_channels):
cls_convs = nn.ModuleList()
reg_convs = nn.ModuleList()

for i in range(self.stacked_convs):
chn = feat_channels if i > 0 else in_channel
kernel_size = 3 if i > 0 else 1
cls_convs.append(
ConvBNAct(
chn,
feat_channels,
kernel_size,
stride=1,
groups=1,
norm='bn',
act=self.act))
reg_convs.append(
ConvBNAct(
chn,
feat_channels,
kernel_size,
stride=1,
groups=1,
norm='bn',
act=self.act))

return cls_convs, reg_convs

def _init_layers(self):
"""Initialize layers of the head."""
self.cls_convs = nn.ModuleList()
self.reg_convs = nn.ModuleList()

for i in range(len(self.strides)):
cls_convs, reg_convs = self._build_not_shared_convs(
self.in_channels[i], self.feat_channels[i])
self.cls_convs.append(cls_convs)
self.reg_convs.append(reg_convs)

self.gfl_cls = nn.ModuleList([
nn.Conv2d(
self.feat_channels[i], self.cls_out_channels, 3, padding=1)
for i in range(len(self.strides))
])

self.gfl_reg = nn.ModuleList([
nn.Conv2d(
self.feat_channels[i], 4 * (self.reg_max + 1), 3, padding=1)
for i in range(len(self.strides))
])

self.scales = nn.ModuleList([Scale(1.0) for _ in self.strides])

def forward(self, xin, labels=None, imgs=None, aux_targets=None):
if self.training:
return NotImplementedError
else:
return self.forward_eval(xin=xin, labels=labels, imgs=imgs)

def forward_eval(self, xin, labels=None, imgs=None):

# prepare priors for label assignment and bbox decode
if self.feat_size[0] != xin[0].shape:
mlvl_priors_list = [
self.get_single_level_center_priors(
xin[i].shape[0],
xin[i].shape[-2:],
stride,
dtype=torch.float32,
device=xin[0].device)
for i, stride in enumerate(self.strides)
]
self.mlvl_priors = torch.cat(mlvl_priors_list, dim=1)
self.feat_size[0] = xin[0].shape

# forward for bboxes and classification prediction
cls_scores, bbox_preds = multi_apply(
self.forward_single,
xin,
self.cls_convs,
self.reg_convs,
self.gfl_cls,
self.gfl_reg,
self.scales,
)
cls_scores = torch.cat(cls_scores, dim=1)[:, :, :self.num_classes]
bbox_preds = torch.cat(bbox_preds, dim=1)
# batch bbox decode
bbox_preds = self.integral(bbox_preds) * self.mlvl_priors[..., 2, None]
bbox_preds = distance2bbox(self.mlvl_priors[..., :2], bbox_preds)

res = torch.cat([bbox_preds, cls_scores[..., 0:self.num_classes]],
dim=-1)
return res

def forward_single(self, x, cls_convs, reg_convs, gfl_cls, gfl_reg, scale):
"""Forward feature of a single scale level.

"""
cls_feat = x
reg_feat = x

for cls_conv, reg_conv in zip(cls_convs, reg_convs):
cls_feat = cls_conv(cls_feat)
reg_feat = reg_conv(reg_feat)

bbox_pred = scale(gfl_reg(reg_feat)).float()
N, C, H, W = bbox_pred.size()
if self.training:
bbox_before_softmax = bbox_pred.reshape(N, 4, self.reg_max + 1, H,
W)
bbox_before_softmax = bbox_before_softmax.flatten(
start_dim=3).permute(0, 3, 1, 2)
bbox_pred = F.softmax(
bbox_pred.reshape(N, 4, self.reg_max + 1, H, W), dim=2)

cls_score = gfl_cls(cls_feat).sigmoid()

cls_score = cls_score.flatten(start_dim=2).permute(
0, 2, 1) # N, h*w, self.num_classes+1
bbox_pred = bbox_pred.flatten(start_dim=3).permute(
0, 3, 1, 2) # N, h*w, 4, self.reg_max+1
if self.training:
return cls_score, bbox_pred, bbox_before_softmax
else:
return cls_score, bbox_pred

def get_single_level_center_priors(self, batch_size, featmap_size, stride,
dtype, device):

h, w = featmap_size
x_range = (torch.arange(0, int(w), dtype=dtype,
device=device)) * stride
y_range = (torch.arange(0, int(h), dtype=dtype,
device=device)) * stride

x = x_range.repeat(h, 1)
y = y_range.unsqueeze(-1).repeat(1, w)

y = y.flatten()
x = x.flatten()
strides = x.new_full((x.shape[0], ), stride)
priors = torch.stack([x, y, strides, strides], dim=-1)

return priors.unsqueeze(0).repeat(batch_size, 1, 1)

def sample(self, assign_result, gt_bboxes):
pos_inds = torch.nonzero(
assign_result.gt_inds > 0, as_tuple=False).squeeze(-1).unique()
neg_inds = torch.nonzero(
assign_result.gt_inds == 0, as_tuple=False).squeeze(-1).unique()
pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1

if gt_bboxes.numel() == 0:
# hack for index error case
assert pos_assigned_gt_inds.numel() == 0
pos_gt_bboxes = torch.empty_like(gt_bboxes).view(-1, 4)
else:
if len(gt_bboxes.shape) < 2:
gt_bboxes = gt_bboxes.view(-1, 4)
pos_gt_bboxes = gt_bboxes[pos_assigned_gt_inds, :]

return pos_inds, neg_inds, pos_gt_bboxes, pos_assigned_gt_inds

+ 2
- 2
modelscope/models/cv/tinynas_detection/neck/__init__.py View File

@@ -1,10 +1,10 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.
# The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo.

import copy

from .giraffe_fpn import GiraffeNeck
from .giraffe_fpn_v2 import GiraffeNeckV2
from .giraffe_fpn_btn import GiraffeNeckV2


def build_neck(cfg):


+ 1
- 1
modelscope/models/cv/tinynas_detection/neck/giraffe_config.py View File

@@ -1,5 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.
# The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo.

import collections
import itertools


+ 3
- 2
modelscope/models/cv/tinynas_detection/neck/giraffe_fpn.py View File

@@ -1,5 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.
# The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo.

import logging
import math
@@ -15,7 +15,8 @@ from timm import create_model
from timm.models.layers import (Swish, create_conv2d, create_pool2d,
get_act_layer)

from ..core.base_ops import CSPLayer, ShuffleBlock, ShuffleCSPLayer
from modelscope.models.cv.tinynas_detection.core.base_ops import (
CSPLayer, ShuffleBlock, ShuffleCSPLayer)
from .giraffe_config import get_graph_config

_ACT_LAYER = Swish


+ 132
- 0
modelscope/models/cv/tinynas_detection/neck/giraffe_fpn_btn.py View File

@@ -0,0 +1,132 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo.

import torch
import torch.nn as nn

from modelscope.models.cv.tinynas_detection.core.ops import ConvBNAct, CSPStage


class GiraffeNeckV2(nn.Module):

def __init__(
self,
depth=1.0,
hidden_ratio=1.0,
in_features=[2, 3, 4],
in_channels=[256, 512, 1024],
out_channels=[256, 512, 1024],
act='silu',
spp=False,
block_name='BasicBlock',
):
super().__init__()
self.in_features = in_features
self.in_channels = in_channels
self.out_channels = out_channels
Conv = ConvBNAct

self.upsample = nn.Upsample(scale_factor=2, mode='nearest')

# node x3: input x0, x1
self.bu_conv13 = Conv(in_channels[1], in_channels[1], 3, 2, act=act)
self.merge_3 = CSPStage(
block_name,
in_channels[1] + in_channels[2],
hidden_ratio,
in_channels[2],
round(3 * depth),
act=act,
spp=spp)

# node x4: input x1, x2, x3
self.bu_conv24 = Conv(in_channels[0], in_channels[0], 3, 2, act=act)
self.merge_4 = CSPStage(
block_name,
in_channels[0] + in_channels[1] + in_channels[2],
hidden_ratio,
in_channels[1],
round(3 * depth),
act=act,
spp=spp)

# node x5: input x2, x4
self.merge_5 = CSPStage(
block_name,
in_channels[1] + in_channels[0],
hidden_ratio,
out_channels[0],
round(3 * depth),
act=act,
spp=spp)

# node x7: input x4, x5
self.bu_conv57 = Conv(out_channels[0], out_channels[0], 3, 2, act=act)
self.merge_7 = CSPStage(
block_name,
out_channels[0] + in_channels[1],
hidden_ratio,
out_channels[1],
round(3 * depth),
act=act,
spp=spp)

# node x6: input x3, x4, x7
self.bu_conv46 = Conv(in_channels[1], in_channels[1], 3, 2, act=act)
self.bu_conv76 = Conv(out_channels[1], out_channels[1], 3, 2, act=act)
self.merge_6 = CSPStage(
block_name,
in_channels[1] + out_channels[1] + in_channels[2],
hidden_ratio,
out_channels[2],
round(3 * depth),
act=act,
spp=spp)

def init_weights(self):
pass

def forward(self, out_features):
"""
Args:
inputs: input images.

Returns:
Tuple[Tensor]: FPN feature.
"""

# backbone
[x2, x1, x0] = out_features

# node x3
x13 = self.bu_conv13(x1)
x3 = torch.cat([x0, x13], 1)
x3 = self.merge_3(x3)

# node x4
x34 = self.upsample(x3)
x24 = self.bu_conv24(x2)
x4 = torch.cat([x1, x24, x34], 1)
x4 = self.merge_4(x4)

# node x5
x45 = self.upsample(x4)
x5 = torch.cat([x2, x45], 1)
x5 = self.merge_5(x5)

# node x8
# x8 = x5

# node x7
x57 = self.bu_conv57(x5)
x7 = torch.cat([x4, x57], 1)
x7 = self.merge_7(x7)

# node x6
x46 = self.bu_conv46(x4)
x76 = self.bu_conv76(x7)
x6 = torch.cat([x3, x46, x76], 1)
x6 = self.merge_6(x6)

outputs = (x5, x7, x6)
return outputs

+ 0
- 200
modelscope/models/cv/tinynas_detection/neck/giraffe_fpn_v2.py View File

@@ -1,200 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.

import torch
import torch.nn as nn

from ..core.base_ops import BaseConv, CSPLayer, DWConv
from ..core.neck_ops import CSPStage


class GiraffeNeckV2(nn.Module):

def __init__(
self,
depth=1.0,
width=1.0,
in_channels=[256, 512, 1024],
out_channels=[256, 512, 1024],
depthwise=False,
act='silu',
spp=True,
reparam_mode=True,
block_name='BasicBlock',
):
super().__init__()
self.in_channels = in_channels
Conv = DWConv if depthwise else BaseConv

reparam_mode = reparam_mode

self.upsample = nn.Upsample(scale_factor=2, mode='nearest')

# node x3: input x0, x1
self.bu_conv13 = Conv(
int(in_channels[1] * width),
int(in_channels[1] * width),
3,
2,
act=act)
if reparam_mode:
self.merge_3 = CSPStage(
block_name,
int((in_channels[1] + in_channels[2]) * width),
int(in_channels[2] * width),
round(3 * depth),
act=act,
spp=spp)
else:
self.merge_3 = CSPLayer(
int((in_channels[1] + in_channels[2]) * width),
int(in_channels[2] * width),
round(3 * depth),
False,
depthwise=depthwise,
act=act)

# node x4: input x1, x2, x3
self.bu_conv24 = Conv(
int(in_channels[0] * width),
int(in_channels[0] * width),
3,
2,
act=act)
if reparam_mode:
self.merge_4 = CSPStage(
block_name,
int((in_channels[0] + in_channels[1] + in_channels[2])
* width),
int(in_channels[1] * width),
round(3 * depth),
act=act,
spp=spp)
else:
self.merge_4 = CSPLayer(
int((in_channels[0] + in_channels[1] + in_channels[2])
* width),
int(in_channels[1] * width),
round(3 * depth),
False,
depthwise=depthwise,
act=act)

# node x5: input x2, x4
if reparam_mode:
self.merge_5 = CSPStage(
block_name,
int((in_channels[1] + in_channels[0]) * width),
int(out_channels[0] * width),
round(3 * depth),
act=act,
spp=spp)
else:
self.merge_5 = CSPLayer(
int((in_channels[1] + in_channels[0]) * width),
int(out_channels[0] * width),
round(3 * depth),
False,
depthwise=depthwise,
act=act)

# node x7: input x4, x5
self.bu_conv57 = Conv(
int(out_channels[0] * width),
int(out_channels[0] * width),
3,
2,
act=act)
if reparam_mode:
self.merge_7 = CSPStage(
block_name,
int((out_channels[0] + in_channels[1]) * width),
int(out_channels[1] * width),
round(3 * depth),
act=act,
spp=spp)
else:
self.merge_7 = CSPLayer(
int((out_channels[0] + in_channels[1]) * width),
int(out_channels[1] * width),
round(3 * depth),
False,
depthwise=depthwise,
act=act)

# node x6: input x3, x4, x7
self.bu_conv46 = Conv(
int(in_channels[1] * width),
int(in_channels[1] * width),
3,
2,
act=act)
self.bu_conv76 = Conv(
int(out_channels[1] * width),
int(out_channels[1] * width),
3,
2,
act=act)
if reparam_mode:
self.merge_6 = CSPStage(
block_name,
int((in_channels[1] + out_channels[1] + in_channels[2])
* width),
int(out_channels[2] * width),
round(3 * depth),
act=act,
spp=spp)
else:
self.merge_6 = CSPLayer(
int((in_channels[1] + out_channels[1] + in_channels[2])
* width),
int(out_channels[2] * width),
round(3 * depth),
False,
depthwise=depthwise,
act=act)

def init_weights(self):
pass

def forward(self, out_features):
"""
Args:
inputs: input images.

Returns:
Tuple[Tensor]: FPN feature.
"""

# backbone
[x2, x1, x0] = out_features

# node x3
x13 = self.bu_conv13(x1)
x3 = torch.cat([x0, x13], 1)
x3 = self.merge_3(x3)

# node x4
x34 = self.upsample(x3)
x24 = self.bu_conv24(x2)
x4 = torch.cat([x1, x24, x34], 1)
x4 = self.merge_4(x4)

# node x5
x45 = self.upsample(x4)
x5 = torch.cat([x2, x45], 1)
x5 = self.merge_5(x5)

# node x7
x57 = self.bu_conv57(x5)
x7 = torch.cat([x4, x57], 1)
x7 = self.merge_7(x7)

# node x6
x46 = self.bu_conv46(x4)
x76 = self.bu_conv76(x7)
x6 = torch.cat([x3, x46, x76], 1)
x6 = self.merge_6(x6)

outputs = (x5, x7, x6)
return outputs

+ 1
- 1
modelscope/models/cv/tinynas_detection/tinynas_damoyolo.py View File

@@ -11,5 +11,5 @@ from .detector import SingleStageDetector
class DamoYolo(SingleStageDetector):

def __init__(self, model_dir, *args, **kwargs):
self.config_name = 'damoyolo_s.py'
self.config_name = 'damoyolo.py'
super(DamoYolo, self).__init__(model_dir, *args, **kwargs)

+ 1
- 1
modelscope/models/cv/tinynas_detection/tinynas_detector.py View File

@@ -1,5 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.
# The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo.

from modelscope.metainfo import Models
from modelscope.models.builder import MODELS


+ 23
- 20
modelscope/models/cv/tinynas_detection/utils.py View File

@@ -1,30 +1,33 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.
# The DAMO-YOLO implementation is also open-sourced by the authors, and available
# at https://github.com/tinyvision/damo-yolo.

import importlib
import os
import shutil
import sys
import tempfile
from os.path import dirname, join

from easydict import EasyDict

def get_config_by_file(config_file):
try:
sys.path.append(os.path.dirname(config_file))
current_config = importlib.import_module(
os.path.basename(config_file).split('.')[0])
exp = current_config.Config()
except Exception:
raise ImportError(
"{} doesn't contains class named 'Config'".format(config_file))
return exp

def parse_config(filename):
filename = str(filename)
if filename.endswith('.py'):
with tempfile.TemporaryDirectory() as temp_config_dir:
shutil.copyfile(filename, join(temp_config_dir, '_tempconfig.py'))
sys.path.insert(0, temp_config_dir)
mod = importlib.import_module('_tempconfig')
sys.path.pop(0)
cfg_dict = EasyDict({
name: value
for name, value in mod.__dict__.items()
if not name.startswith('__')
})
# delete imported module
del sys.modules['_tempconfig']
else:
raise IOError('Only .py type are supported now!')

def parse_config(config_file):
"""
get config object by file.
Args:
config_file (str): file path of config.
"""
assert (config_file is not None), 'plz provide config file'
if config_file is not None:
return get_config_by_file(config_file)
return cfg_dict

+ 20
- 2
tests/pipelines/test_tinynas_detection.py View File

@@ -29,7 +29,25 @@ class TinynasObjectDetectionTest(unittest.TestCase, DemoCompatibilityCheck):
model='damo/cv_tinynas_object-detection_damoyolo')
result = tinynas_object_detection(
'data/test/images/image_detection.jpg')
print('damoyolo', result)
print('damoyolo-s', result)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_damoyolo_m(self):
tinynas_object_detection = pipeline(
Tasks.image_object_detection,
model='damo/cv_tinynas_object-detection_damoyolo-m')
result = tinynas_object_detection(
'data/test/images/image_detection.jpg')
print('damoyolo-m', result)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_damoyolo_t(self):
tinynas_object_detection = pipeline(
Tasks.image_object_detection,
model='damo/cv_tinynas_object-detection_damoyolo-t')
result = tinynas_object_detection(
'data/test/images/image_detection.jpg')
print('damoyolo-t', result)

@unittest.skip('demo compatibility test is only enabled on a needed-basis')
def test_demo_compatibility(self):
@@ -40,7 +58,7 @@ class TinynasObjectDetectionTest(unittest.TestCase, DemoCompatibilityCheck):
test_image = 'data/test/images/image_detection.jpg'
tinynas_object_detection = pipeline(
Tasks.image_object_detection,
model='damo/cv_tinynas_object-detection_damoyolo')
model='damo/cv_tinynas_object-detection_damoyolo-m')
result = tinynas_object_detection(test_image)
tinynas_object_detection.show_result(test_image, result,
'demo_ret.jpg')


Loading…
Cancel
Save