Browse Source

add damoyolo-t & damoyolo-m

1. add damoyolo-t & damoyolo-m models
2. fix the configuration overlap error
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10816561
master^2
xianzhe.xxz yingda.chen 3 years ago
parent
commit
9da5091d4d
25 changed files with 1463 additions and 606 deletions
  1. +1
    -1
      modelscope/models/cv/tinynas_detection/__init__.py
  2. +7
    -4
      modelscope/models/cv/tinynas_detection/backbone/__init__.py
  3. +2
    -3
      modelscope/models/cv/tinynas_detection/backbone/darknet.py
  4. +0
    -359
      modelscope/models/cv/tinynas_detection/backbone/tinynas.py
  5. +295
    -0
      modelscope/models/cv/tinynas_detection/backbone/tinynas_csp.py
  6. +238
    -0
      modelscope/models/cv/tinynas_detection/backbone/tinynas_res.py
  7. +1
    -1
      modelscope/models/cv/tinynas_detection/core/__init__.py
  8. +1
    -1
      modelscope/models/cv/tinynas_detection/core/base_ops.py
  9. +1
    -1
      modelscope/models/cv/tinynas_detection/core/neck_ops.py
  10. +435
    -0
      modelscope/models/cv/tinynas_detection/core/ops.py
  11. +1
    -1
      modelscope/models/cv/tinynas_detection/core/repvgg_block.py
  12. +1
    -1
      modelscope/models/cv/tinynas_detection/core/utils.py
  13. +2
    -2
      modelscope/models/cv/tinynas_detection/detector.py
  14. +4
    -1
      modelscope/models/cv/tinynas_detection/head/__init__.py
  15. +3
    -2
      modelscope/models/cv/tinynas_detection/head/gfocal_v2_tiny.py
  16. +288
    -0
      modelscope/models/cv/tinynas_detection/head/zero_head.py
  17. +2
    -2
      modelscope/models/cv/tinynas_detection/neck/__init__.py
  18. +1
    -1
      modelscope/models/cv/tinynas_detection/neck/giraffe_config.py
  19. +3
    -2
      modelscope/models/cv/tinynas_detection/neck/giraffe_fpn.py
  20. +132
    -0
      modelscope/models/cv/tinynas_detection/neck/giraffe_fpn_btn.py
  21. +0
    -200
      modelscope/models/cv/tinynas_detection/neck/giraffe_fpn_v2.py
  22. +1
    -1
      modelscope/models/cv/tinynas_detection/tinynas_damoyolo.py
  23. +1
    -1
      modelscope/models/cv/tinynas_detection/tinynas_detector.py
  24. +23
    -20
      modelscope/models/cv/tinynas_detection/utils.py
  25. +20
    -2
      tests/pipelines/test_tinynas_detection.py

+ 1
- 1
modelscope/models/cv/tinynas_detection/__init__.py View File

@@ -1,5 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.
# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.
# The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo.


from typing import TYPE_CHECKING from typing import TYPE_CHECKING




+ 7
- 4
modelscope/models/cv/tinynas_detection/backbone/__init__.py View File

@@ -1,10 +1,11 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.
# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.
# The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo.


import copy import copy


from .darknet import CSPDarknet from .darknet import CSPDarknet
from .tinynas import load_tinynas_net
from .tinynas_csp import load_tinynas_net as load_tinynas_net_csp
from .tinynas_res import load_tinynas_net as load_tinynas_net_res




def build_backbone(cfg): def build_backbone(cfg):
@@ -12,5 +13,7 @@ def build_backbone(cfg):
name = backbone_cfg.pop('name') name = backbone_cfg.pop('name')
if name == 'CSPDarknet': if name == 'CSPDarknet':
return CSPDarknet(**backbone_cfg) return CSPDarknet(**backbone_cfg)
elif name == 'TinyNAS':
return load_tinynas_net(backbone_cfg)
elif name == 'TinyNAS_csp':
return load_tinynas_net_csp(backbone_cfg)
elif name == 'TinyNAS_res':
return load_tinynas_net_res(backbone_cfg)

+ 2
- 3
modelscope/models/cv/tinynas_detection/backbone/darknet.py View File

@@ -1,12 +1,11 @@
# Copyright (c) Megvii Inc. All rights reserved. # Copyright (c) Megvii Inc. All rights reserved.
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.
# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.


import torch import torch
from torch import nn from torch import nn


from ..core.base_ops import (BaseConv, CSPLayer, DWConv, Focus, ResLayer,
SPPBottleneck)
from modelscope.models.cv.tinynas_detection.core.base_ops import (
BaseConv, CSPLayer, DWConv, Focus, ResLayer, SPPBottleneck)




class CSPDarknet(nn.Module): class CSPDarknet(nn.Module):


+ 0
- 359
modelscope/models/cv/tinynas_detection/backbone/tinynas.py View File

@@ -1,359 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.

import torch
import torch.nn as nn

from modelscope.utils.file_utils import read_file
from ..core.base_ops import Focus, SPPBottleneck, get_activation
from ..core.repvgg_block import RepVggBlock


class ConvKXBN(nn.Module):
    """Bias-free KxK convolution followed by batch normalisation.

    Args:
        in_c: Number of input channels.
        out_c: Number of output channels.
        kernel_size: Kernel size; padding is ``(kernel_size - 1) // 2``.
        stride: Convolution stride.
    """

    def __init__(self, in_c, out_c, kernel_size, stride):
        super().__init__()
        same_pad = (kernel_size - 1) // 2
        self.conv1 = nn.Conv2d(
            in_channels=in_c,
            out_channels=out_c,
            kernel_size=kernel_size,
            stride=stride,
            padding=same_pad,
            groups=1,
            bias=False)
        self.bn1 = nn.BatchNorm2d(out_c)

    def forward(self, x):
        """Run the convolution and then batch norm (no activation)."""
        conv_out = self.conv1(x)
        return self.bn1(conv_out)


class ConvKXBNRELU(nn.Module):
    """``ConvKXBN`` followed by an activation.

    Args:
        in_c: Number of input channels.
        out_c: Number of output channels.
        kernel_size: Kernel size passed to ``ConvKXBN``.
        stride: Stride passed to ``ConvKXBN``.
        act: Activation spec handed to ``get_activation``; ``None`` selects
            ``torch.relu``.
    """

    def __init__(self, in_c, out_c, kernel_size, stride, act='silu'):
        super().__init__()
        self.conv = ConvKXBN(in_c, out_c, kernel_size, stride)
        self.activation_function = (
            torch.relu if act is None else get_activation(act))

    def forward(self, x):
        """Apply conv + BN, then the configured activation."""
        return self.activation_function(self.conv(x))


class ResConvK1KX(nn.Module):
    """Bottleneck block: 1x1 conv, activation, second conv, optional residual.

    The residual sum is only applied when ``stride != 2``.

    Args:
        in_c: Number of input channels.
        out_c: Number of output channels.
        btn_c: Bottleneck (intermediate) channel count.
        kernel_size: Kernel size used by the ``RepVggBlock`` branch.
        stride: Stride of the second convolution.
        force_resproj: Build the 1x1 residual projection even when
            ``in_c == out_c``.
        act: Activation spec for ``get_activation``; ``None`` selects
            ``torch.relu``.
        reparam: Use ``RepVggBlock`` instead of ``ConvKXBN`` for the second
            convolution.
    """

    def __init__(self,
                 in_c,
                 out_c,
                 btn_c,
                 kernel_size,
                 stride,
                 force_resproj=False,
                 act='silu',
                 reparam=False):
        super().__init__()
        self.stride = stride
        self.conv1 = ConvKXBN(in_c, btn_c, 1, 1)
        if reparam:
            self.conv2 = RepVggBlock(
                btn_c, out_c, kernel_size, stride, act='identity')
        else:
            # NOTE: the plain branch uses a fixed 3x3 kernel; ``kernel_size``
            # only reaches the RepVggBlock branch.
            self.conv2 = ConvKXBN(btn_c, out_c, 3, stride)

        self.activation_function = (
            torch.relu if act is None else get_activation(act))

        self.residual_downsample = (
            nn.AvgPool2d(kernel_size=2, stride=2)
            if stride == 2 else nn.Identity())

        if force_resproj or in_c != out_c:
            self.residual_proj = ConvKXBN(in_c, out_c, 1, 1)
        else:
            self.residual_proj = nn.Identity()

    def forward(self, x):
        """Run the bottleneck; add the shortcut unless the block strides by 2."""
        use_residual = self.stride != 2
        if use_residual:
            shortcut = self.residual_proj(self.residual_downsample(x))

        out = self.activation_function(self.conv1(x))
        out = self.conv2(out)
        if use_residual:
            out = out + shortcut
        return self.activation_function(out)


class SuperResConvK1KX(nn.Module):
    """Stage made of ``num_blocks`` chained ``ResConvK1KX`` blocks.

    Only the first block maps ``in_c`` to ``out_c`` and applies ``stride``;
    the remaining ones keep ``out_c`` channels at stride 1. When ``with_spp``
    is set, an ``SPPBottleneck`` is inserted right after the first block.

    Args:
        in_c: Input channels of the stage.
        out_c: Output channels of every block.
        btn_c: Bottleneck channels of every block.
        kernel_size: Kernel size forwarded to every block.
        stride: Stride of the first block.
        num_blocks: Number of residual blocks.
        with_spp: Insert an ``SPPBottleneck`` after the first block.
        act: Activation spec; ``None`` selects ``torch.relu`` for ``self.act``.
        reparam: Forwarded to every ``ResConvK1KX``.
    """

    def __init__(self,
                 in_c,
                 out_c,
                 btn_c,
                 kernel_size,
                 stride,
                 num_blocks,
                 with_spp=False,
                 act='silu',
                 reparam=False):
        super().__init__()
        self.act = torch.relu if act is None else get_activation(act)
        self.block_list = nn.ModuleList()
        for block_id in range(num_blocks):
            is_first = block_id == 0
            self.block_list.append(
                ResConvK1KX(
                    in_c if is_first else out_c,
                    out_c,
                    btn_c,
                    kernel_size,
                    stride if is_first else 1,
                    False,  # force_resproj is never requested by this stage
                    act=act,
                    reparam=reparam))
            if is_first and with_spp:
                self.block_list.append(SPPBottleneck(out_c, out_c))

    def forward(self, x):
        """Feed ``x`` through every block in order."""
        for block in self.block_list:
            x = block(x)
        return x


class ResConvKXKX(nn.Module):
    """Residual block with two KxK convolutions, or a plain downsampler.

    With ``stride == 2`` the block is only a 3x3 stride-2 ``ConvKXBNRELU``;
    otherwise it is ``ConvKXBN`` -> activation -> ``RepVggBlock`` with a
    residual sum and a final activation.

    NOTE(review): SOURCE lost its indentation. The nesting of everything
    after ``self.conv2`` under the ``else:`` branch was reconstructed from
    the line wrapping of the original; confirm against the repository
    before relying on the exact set of sub-modules created for stride 2.

    Args:
        in_c: Number of input channels.
        out_c: Number of output channels.
        btn_c: Intermediate channel count between the two convolutions.
        kernel_size: Kernel size of both convolutions (stride != 2 path).
        stride: Block stride; 2 selects the downsampler-only path.
        force_resproj: Build the 1x1 residual projection even when
            ``in_c == out_c``.
        act: Activation spec for ``get_activation``; ``None`` selects
            ``torch.relu``.
    """

    def __init__(self,
                 in_c,
                 out_c,
                 btn_c,
                 kernel_size,
                 stride,
                 force_resproj=False,
                 act='silu'):
        super(ResConvKXKX, self).__init__()
        self.stride = stride
        if self.stride == 2:
            # Downsampling variant: a single strided conv, no residual path.
            self.downsampler = ConvKXBNRELU(in_c, out_c, 3, 2, act=act)
        else:
            self.conv1 = ConvKXBN(in_c, btn_c, kernel_size, 1)
            self.conv2 = RepVggBlock(
                btn_c, out_c, kernel_size, stride, act='identity')

            if act is None:
                self.activation_function = torch.relu
            else:
                self.activation_function = get_activation(act)

            # In this branch stride != 2, so the Identity case is the one
            # that is actually taken.
            if stride == 2:
                self.residual_downsample = nn.AvgPool2d(
                    kernel_size=2, stride=2)
            else:
                self.residual_downsample = nn.Identity()

            # Project the shortcut when channel counts differ (or on demand).
            if in_c != out_c or force_resproj:
                self.residual_proj = ConvKXBN(in_c, out_c, 1, 1)
            else:
                self.residual_proj = nn.Identity()

    def forward(self, x):
        """Return the downsampled input (stride 2) or the residual output."""
        if self.stride == 2:
            return self.downsampler(x)
        reslink = self.residual_downsample(x)
        reslink = self.residual_proj(reslink)

        output = x
        output = self.conv1(output)
        output = self.activation_function(output)
        output = self.conv2(output)

        output = output + reslink
        output = self.activation_function(output)

        return output


class SuperResConvKXKX(nn.Module):
    """Stage made of ``num_blocks`` chained ``ResConvKXKX`` blocks.

    Only the first block maps ``in_c`` to ``out_c`` and applies ``stride``;
    the remaining ones keep ``out_c`` channels at stride 1. When ``with_spp``
    is set, an ``SPPBottleneck`` is inserted right after the first block.

    Args:
        in_c: Input channels of the stage.
        out_c: Output channels of every block.
        btn_c: Intermediate channels of every block.
        kernel_size: Kernel size forwarded to every block.
        stride: Stride of the first block.
        num_blocks: Number of residual blocks.
        with_spp: Insert an ``SPPBottleneck`` after the first block.
        act: Activation spec; ``None`` selects ``torch.relu`` for ``self.act``.
    """

    def __init__(self,
                 in_c,
                 out_c,
                 btn_c,
                 kernel_size,
                 stride,
                 num_blocks,
                 with_spp=False,
                 act='silu'):
        super().__init__()
        self.act = torch.relu if act is None else get_activation(act)
        self.block_list = nn.ModuleList()
        for block_id in range(num_blocks):
            is_first = block_id == 0
            self.block_list.append(
                ResConvKXKX(
                    in_c if is_first else out_c,
                    out_c,
                    btn_c,
                    kernel_size,
                    stride if is_first else 1,
                    False,  # force_resproj is never requested by this stage
                    act=act))
            if is_first and with_spp:
                self.block_list.append(SPPBottleneck(out_c, out_c))

    def forward(self, x):
        """Feed ``x`` through every block in order."""
        for block in self.block_list:
            x = block(x)
        return x


class TinyNAS(nn.Module):
    """Backbone assembled from a list of block descriptions.

    Args:
        structure_info: Iterable of dicts. Each needs a ``'class'`` key plus
            the ``'in'``/``'out'``/``'k'``/``'s'`` entries (and ``'btn'``/
            ``'L'`` for the residual stages) read below.
        out_indices: Indices of the blocks whose output ``forward`` returns.
            Copied into a fresh list so instances never share the default.
        out_channels: One entry per ``out_indices`` entry; a number requests
            a 1x1 projection to that many channels, ``None`` disables it.
        with_spp: Forwarded to the last entry only, when it is a residual
            stage.
        use_focus: Build ``Focus`` instead of ``ConvKXBNRELU`` for
            ``'ConvKXBNRELU'`` entries.
        need_conv1: Whether to build the 1x1 projections at all.
        act: Activation spec forwarded to the blocks.
        reparam: Forwarded to ``SuperResConvK1KX``.
    """

    def __init__(self,
                 structure_info=None,
                 out_indices=(0, 1, 2, 4, 5),
                 out_channels=(None, None, 128, 256, 512),
                 with_spp=False,
                 use_focus=False,
                 need_conv1=True,
                 act='silu',
                 reparam=False):
        super(TinyNAS, self).__init__()
        assert len(out_indices) == len(out_channels)
        # Defaults are immutable tuples and the value is copied, so mutating
        # ``model.out_indices`` can no longer leak into other instances.
        self.out_indices = list(out_indices)
        out_channels = list(out_channels)
        self.need_conv1 = need_conv1

        self.block_list = nn.ModuleList()
        if need_conv1:
            self.conv1_list = nn.ModuleList()
        last_idx = len(structure_info) - 1
        for idx, block_info in enumerate(structure_info):
            the_block_class = block_info['class']
            if the_block_class == 'ConvKXBNRELU':
                if use_focus:
                    the_block = Focus(
                        block_info['in'],
                        block_info['out'],
                        block_info['k'],
                        act=act)
                else:
                    the_block = ConvKXBNRELU(
                        block_info['in'],
                        block_info['out'],
                        block_info['k'],
                        block_info['s'],
                        act=act)
                self.block_list.append(the_block)
            elif the_block_class == 'SuperResConvK1KX':
                spp = with_spp if idx == last_idx else False
                the_block = SuperResConvK1KX(
                    block_info['in'],
                    block_info['out'],
                    block_info['btn'],
                    block_info['k'],
                    block_info['s'],
                    block_info['L'],
                    spp,
                    act=act,
                    reparam=reparam)
                self.block_list.append(the_block)
            elif the_block_class == 'SuperResConvKXKX':
                spp = with_spp if idx == last_idx else False
                the_block = SuperResConvKXKX(
                    block_info['in'],
                    block_info['out'],
                    block_info['btn'],
                    block_info['k'],
                    block_info['s'],
                    block_info['L'],
                    spp,
                    act=act)
                self.block_list.append(the_block)
            # NOTE(review): an unrecognised 'class' adds no block but still
            # adds a conv1_list slot below, which would shift the indices
            # used in forward(); behaviour kept as in the original.
            if need_conv1:
                proj_channels = None
                if idx in self.out_indices:
                    proj_channels = out_channels[self.out_indices.index(idx)]
                if proj_channels is not None:
                    self.conv1_list.append(
                        nn.Conv2d(block_info['out'], proj_channels, 1))
                else:
                    # Placeholder keeps conv1_list aligned with the entries.
                    self.conv1_list.append(None)

    def init_weights(self, pretrain=None):
        """No-op; kept for interface compatibility."""
        pass

    def forward(self, x):
        """Return the list of feature maps selected by ``out_indices``."""
        output = x
        stage_feature_list = []
        for idx, block in enumerate(self.block_list):
            output = block(output)
            if idx in self.out_indices:
                if self.need_conv1 and self.conv1_list[idx] is not None:
                    stage_feature_list.append(self.conv1_list[idx](output))
                else:
                    stage_feature_list.append(output)
        return stage_feature_list


def load_tinynas_net(backbone_cfg):
    """Build a ``TinyNAS`` backbone from the structure file named in the config.

    Args:
        backbone_cfg: Config object exposing ``structure_file``,
            ``out_indices``, ``out_channels``, ``with_spp``, ``use_focus``,
            ``act``, ``need_conv1`` and ``reparam`` attributes.

    Returns:
        The constructed ``TinyNAS`` module.
    """
    import ast

    raw_structure = read_file(backbone_cfg.structure_file)
    # Strip every piece yielded by iterating the file content, then glue the
    # pieces back together before parsing it as a Python literal.
    compact = ''.join([piece.strip() for piece in raw_structure])
    struct_info = ast.literal_eval(compact)

    # These keys are not consumed by TinyNAS; drop them when present.
    for layer in struct_info:
        for unused_key in ('nbitsA', 'nbitsW'):
            if unused_key in layer:
                del layer[unused_key]

    return TinyNAS(
        structure_info=struct_info,
        out_indices=backbone_cfg.out_indices,
        out_channels=backbone_cfg.out_channels,
        with_spp=backbone_cfg.with_spp,
        use_focus=backbone_cfg.use_focus,
        act=backbone_cfg.act,
        need_conv1=backbone_cfg.need_conv1,
        reparam=backbone_cfg.reparam)

+ 295
- 0
modelscope/models/cv/tinynas_detection/backbone/tinynas_csp.py View File

@@ -0,0 +1,295 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# The DAMO-YOLO implementation is also open-sourced by the authors, and available
# at https://github.com/tinyvision/damo-yolo.

import torch
import torch.nn as nn

from modelscope.models.cv.tinynas_detection.core.ops import (Focus, RepConv,
SPPBottleneck,
get_activation)
from modelscope.utils.file_utils import read_file


class ConvKXBN(nn.Module):
    """Bias-free KxK convolution followed by batch normalisation.

    Args:
        in_c: Number of input channels.
        out_c: Number of output channels.
        kernel_size: Kernel size; padding is ``(kernel_size - 1) // 2``.
        stride: Convolution stride.
    """

    def __init__(self, in_c, out_c, kernel_size, stride):
        super().__init__()
        half_kernel = (kernel_size - 1) // 2
        self.conv1 = nn.Conv2d(
            in_channels=in_c,
            out_channels=out_c,
            kernel_size=kernel_size,
            stride=stride,
            padding=half_kernel,
            groups=1,
            bias=False)
        self.bn1 = nn.BatchNorm2d(out_c)

    def forward(self, x):
        """Run the convolution and then batch norm (no activation)."""
        features = self.conv1(x)
        return self.bn1(features)


class ConvKXBNRELU(nn.Module):
    """``ConvKXBN`` followed by an activation.

    Args:
        in_c: Number of input channels.
        out_c: Number of output channels.
        kernel_size: Kernel size passed to ``ConvKXBN``.
        stride: Stride passed to ``ConvKXBN``.
        act: Activation spec (name or module) handed to ``get_activation``;
            ``None`` selects ``torch.relu``.
    """

    def __init__(self, in_c, out_c, kernel_size, stride, act='silu'):
        super().__init__()
        self.conv = ConvKXBN(in_c, out_c, kernel_size, stride)
        self.activation_function = (
            torch.relu if act is None else get_activation(act))

    def forward(self, x):
        """Apply conv + BN, then the configured activation."""
        return self.activation_function(self.conv(x))


class ResConvBlock(nn.Module):
    """Two-convolution residual block used inside the CSP stems.

    Args:
        in_c: Number of input channels.
        out_c: Number of output channels.
        btn_c: Intermediate channel count between the two convolutions.
        kernel_size: Kernel size of the second convolution (and of the first
            one unless ``block_type == 'k1kx'``).
        stride: Stride of the second convolution; the residual sum is
            skipped when it equals 2.
        act: Activation spec handed to ``get_activation``.
        reparam: Use ``RepConv`` instead of ``ConvKXBN`` for the second
            convolution.
        block_type: ``'k1kx'`` makes the first convolution 1x1; any other
            value uses ``kernel_size`` for it.
    """

    def __init__(self,
                 in_c,
                 out_c,
                 btn_c,
                 kernel_size,
                 stride,
                 act='silu',
                 reparam=False,
                 block_type='k1kx'):
        super().__init__()
        self.stride = stride
        first_kernel = 1 if block_type == 'k1kx' else kernel_size
        self.conv1 = ConvKXBN(in_c, btn_c, kernel_size=first_kernel, stride=1)
        if reparam:
            self.conv2 = RepConv(
                btn_c, out_c, kernel_size, stride, act='identity')
        else:
            self.conv2 = ConvKXBN(btn_c, out_c, kernel_size, stride)

        self.activation_function = get_activation(act)

        # A projection is only needed when the shortcut is actually used
        # (stride != 2) and the channel counts differ.
        if in_c == out_c or stride == 2:
            self.residual_proj = None
        else:
            self.residual_proj = ConvKXBN(in_c, out_c, kernel_size=1, stride=1)

    def forward(self, x):
        """Run both convolutions and add the shortcut unless stride is 2."""
        reslink = x if self.residual_proj is None else self.residual_proj(x)
        out = self.activation_function(self.conv1(x))
        out = self.conv2(out)
        if self.stride != 2:
            out = out + reslink
        return self.activation_function(out)


class CSPStem(nn.Module):
    """Chain of ``ResConvBlock`` modules operating on half of ``out_c``.

    The stem records the full ``in_c``/``out_c`` as ``in_channels``/
    ``out_channels`` while its blocks are built with ``out_c // 2`` output
    channels. With ``stride == 2`` one block fewer than ``num_blocks`` is
    created.

    Args:
        in_c: Input channels recorded for the stage.
        out_c: Output channels recorded for the stage.
        btn_c: Bottleneck channels of every block.
        stride: Stage stride; only affects the block count and the input
            width of the first block (every block itself runs at stride 1).
        kernel_size: Kernel size forwarded to every block.
        num_blocks: Requested number of blocks.
        act: Activation spec; ``None`` selects ``torch.relu`` for ``self.act``.
        reparam: Forwarded to every ``ResConvBlock``.
        block_type: Forwarded to every ``ResConvBlock``.
    """

    def __init__(self,
                 in_c,
                 out_c,
                 btn_c,
                 stride,
                 kernel_size,
                 num_blocks,
                 act='silu',
                 reparam=False,
                 block_type='k1kx'):
        super().__init__()
        self.in_channels = in_c
        self.out_channels = out_c
        self.stride = stride
        self.num_blocks = num_blocks - 1 if stride == 2 else num_blocks
        self.kernel_size = kernel_size
        self.block_type = block_type
        self.act = torch.relu if act is None else get_activation(act)

        branch_c = out_c // 2
        self.block_list = nn.ModuleList()
        for block_id in range(self.num_blocks):
            # Only the very first block of a stride-1 stem sees half of the
            # stage input width; every other block sees ``branch_c``.
            if stride == 1 and block_id == 0:
                block_in = in_c // 2
            else:
                block_in = branch_c
            self.block_list.append(
                ResConvBlock(
                    block_in,
                    branch_c,
                    btn_c,
                    kernel_size,
                    stride=1,
                    act=act,
                    reparam=reparam,
                    block_type=block_type))

    def forward(self, x):
        """Feed ``x`` through every block in order."""
        for block in self.block_list:
            x = block(x)
        return x


class TinyNAS(nn.Module):
    """CSP-style backbone assembled from a list of block descriptions.

    The description must yield at least six blocks: block 0 is used as is,
    blocks 1, 2 and 5 are each wrapped in a ``CSPWrapper``, and blocks 3 and
    4 share one wrapper, giving five stages in ``csp_stage``.

    Args:
        structure_info: Iterable of dicts with a ``'class'`` key plus the
            ``'in'``/``'out'``/``'k'``/``'s'`` entries (and ``'btn'``/``'L'``
            for the residual stages) read below.
        out_indices: Indices into ``csp_stage`` whose output ``forward``
            returns. Copied into a fresh list so instances never share the
            default.
        with_spp: Forwarded to the last ``CSPWrapper``.
        use_focus: Build ``Focus`` for a ``'ConvKXBNRELU'`` entry at index 0.
        act: Activation spec forwarded to the blocks.
        reparam: Forwarded to the ``CSPStem`` blocks.

    Raises:
        NotImplementedError: If an entry has an unknown ``'class'``.
    """

    def __init__(self,
                 structure_info=None,
                 out_indices=(2, 3, 4),
                 with_spp=False,
                 use_focus=False,
                 act='silu',
                 reparam=False):
        super(TinyNAS, self).__init__()
        # Immutable default + copy: mutating ``model.out_indices`` can no
        # longer leak into other instances through a shared default list.
        self.out_indices = list(out_indices)
        self.block_list = nn.ModuleList()
        self.stride_list = []

        for idx, block_info in enumerate(structure_info):
            the_block_class = block_info['class']
            if the_block_class == 'ConvKXBNRELU':
                if use_focus and idx == 0:
                    the_block = Focus(
                        block_info['in'],
                        block_info['out'],
                        block_info['k'],
                        act=act)
                else:
                    the_block = ConvKXBNRELU(
                        block_info['in'],
                        block_info['out'],
                        block_info['k'],
                        block_info['s'],
                        act=act)
            elif the_block_class == 'SuperResConvK1KX':
                the_block = CSPStem(
                    block_info['in'],
                    block_info['out'],
                    block_info['btn'],
                    block_info['s'],
                    block_info['k'],
                    block_info['L'],
                    act=act,
                    reparam=reparam,
                    block_type='k1kx')
            elif the_block_class == 'SuperResConvKXKX':
                the_block = CSPStem(
                    block_info['in'],
                    block_info['out'],
                    block_info['btn'],
                    block_info['s'],
                    block_info['k'],
                    block_info['L'],
                    act=act,
                    reparam=reparam,
                    block_type='kxkx')
            else:
                raise NotImplementedError

            self.block_list.append(the_block)

        # Regroup the six blocks into five stages; blocks 3 and 4 are merged.
        self.csp_stage = nn.ModuleList()
        self.csp_stage.append(self.block_list[0])
        self.csp_stage.append(CSPWrapper(self.block_list[1]))
        self.csp_stage.append(CSPWrapper(self.block_list[2]))
        self.csp_stage.append(
            CSPWrapper((self.block_list[3], self.block_list[4])))
        self.csp_stage.append(
            CSPWrapper(self.block_list[5], with_spp=with_spp))
        # The blocks now live inside csp_stage; drop the temporary container
        # so they are not registered twice.
        del self.block_list

    def init_weights(self, pretrain=None):
        """No-op; kept for interface compatibility."""
        pass

    def forward(self, x):
        """Return the list of stage outputs selected by ``out_indices``."""
        output = x
        stage_feature_list = []
        for idx, block in enumerate(self.csp_stage):
            output = block(output)
            if idx in self.out_indices:
                stage_feature_list.append(output)
        return stage_feature_list


class CSPWrapper(nn.Module):
    """Wrap one stem (or a tuple of stems) into a cross-stage-partial stage.

    The stage downsamples with a 3x3 stride-2 convolution, optionally applies
    an ``SPPBottleneck``, and, when the stems contribute at least one block,
    runs those blocks on a 1x1-projected branch that is concatenated with a
    1x1 shortcut and fused by another 1x1 convolution.

    Args:
        convstem: A stem exposing ``in_channels``, ``out_channels`` and
            ``block_list``, or a tuple of such stems whose blocks are chained.
        act: Activation spec handed to ``get_activation``; the resulting
            module is passed to every ``ConvKXBNRELU`` built here.
        reparam: Accepted but not used by this class.
        with_spp: Insert an ``SPPBottleneck`` after the downsampler.
    """

    def __init__(self, convstem, act='relu', reparam=False, with_spp=False):
        super().__init__()
        self.with_spp = with_spp
        if isinstance(convstem, tuple):
            first_stem, last_stem = convstem[0], convstem[-1]
            in_c = first_stem.in_channels
            out_c = last_stem.out_channels
            hidden_dim = first_stem.out_channels // 2
            layers = [
                layer for stem in convstem for layer in stem.block_list
            ]
        else:
            in_c = convstem.in_channels
            out_c = convstem.out_channels
            hidden_dim = out_c // 2
            layers = list(convstem.block_list)

        self.convstem = nn.ModuleList()
        for layer in layers:
            self.convstem.append(layer)

        self.act = get_activation(act)
        wide_dim = hidden_dim * 2
        self.downsampler = ConvKXBNRELU(in_c, wide_dim, 3, 2, act=self.act)
        if self.with_spp:
            self.spp = SPPBottleneck(wide_dim, wide_dim)
        if len(self.convstem) > 0:
            self.conv_start = ConvKXBNRELU(
                wide_dim, hidden_dim, 1, 1, act=self.act)
            self.conv_shortcut = ConvKXBNRELU(
                wide_dim, out_c // 2, 1, 1, act=self.act)
            self.conv_fuse = ConvKXBNRELU(out_c, out_c, 1, 1, act=self.act)

    def forward(self, x):
        """Downsample, then run the CSP branches when any block exists."""
        x = self.downsampler(x)
        if self.with_spp:
            x = self.spp(x)
        if len(self.convstem) == 0:
            return x

        shortcut = self.conv_shortcut(x)
        x = self.conv_start(x)
        for block in self.convstem:
            x = block(x)
        merged = torch.cat((x, shortcut), dim=1)
        return self.conv_fuse(merged)


def load_tinynas_net(backbone_cfg):
    """Build the CSP ``TinyNAS`` backbone from the configured structure file.

    Args:
        backbone_cfg: Config object exposing ``structure_file``,
            ``out_indices``, ``with_spp``, ``use_focus``, ``act`` and
            ``reparam`` attributes.

    Returns:
        The constructed ``TinyNAS`` module.
    """
    import ast

    raw_structure = read_file(backbone_cfg.structure_file)
    # Strip every piece yielded by iterating the file content, then glue the
    # pieces back together before parsing it as a Python literal.
    compact = ''.join([piece.strip() for piece in raw_structure])
    struct_info = ast.literal_eval(compact)

    # These keys are not consumed by TinyNAS; drop them when present.
    for layer in struct_info:
        for unused_key in ('nbitsA', 'nbitsW'):
            if unused_key in layer:
                del layer[unused_key]

    return TinyNAS(
        structure_info=struct_info,
        out_indices=backbone_cfg.out_indices,
        with_spp=backbone_cfg.with_spp,
        use_focus=backbone_cfg.use_focus,
        act=backbone_cfg.act,
        reparam=backbone_cfg.reparam)

+ 238
- 0
modelscope/models/cv/tinynas_detection/backbone/tinynas_res.py View File

@@ -0,0 +1,238 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# The DAMO-YOLO implementation is also open-sourced by the authors, and available
# at https://github.com/tinyvision/damo-yolo.

import torch
import torch.nn as nn

from modelscope.models.cv.tinynas_detection.core.ops import (Focus, RepConv,
SPPBottleneck,
get_activation)
from modelscope.utils.file_utils import read_file


class ConvKXBN(nn.Module):
    """Bias-free KxK convolution followed by batch normalisation.

    Args:
        in_c: Number of input channels.
        out_c: Number of output channels.
        kernel_size: Kernel size; padding is ``(kernel_size - 1) // 2``.
        stride: Convolution stride.
    """

    def __init__(self, in_c, out_c, kernel_size, stride):
        super().__init__()
        pad = (kernel_size - 1) // 2
        self.conv1 = nn.Conv2d(
            in_channels=in_c,
            out_channels=out_c,
            kernel_size=kernel_size,
            stride=stride,
            padding=pad,
            groups=1,
            bias=False)
        self.bn1 = nn.BatchNorm2d(out_c)

    def forward(self, x):
        """Run the convolution and then batch norm (no activation)."""
        convolved = self.conv1(x)
        return self.bn1(convolved)


class ConvKXBNRELU(nn.Module):
    """``ConvKXBN`` followed by an activation.

    Args:
        in_c: Number of input channels.
        out_c: Number of output channels.
        kernel_size: Kernel size passed to ``ConvKXBN``.
        stride: Stride passed to ``ConvKXBN``.
        act: Activation spec handed to ``get_activation``; ``None`` selects
            ``torch.relu``.
    """

    def __init__(self, in_c, out_c, kernel_size, stride, act='silu'):
        super().__init__()
        self.conv = ConvKXBN(in_c, out_c, kernel_size, stride)
        self.activation_function = (
            torch.relu if act is None else get_activation(act))

    def forward(self, x):
        """Apply conv + BN, then the configured activation."""
        return self.activation_function(self.conv(x))


class ResConvBlock(nn.Module):
    """Two-convolution residual block used by ``SuperResStem``.

    Args:
        in_c: Number of input channels.
        out_c: Number of output channels.
        btn_c: Intermediate channel count between the two convolutions.
        kernel_size: Kernel size of the second convolution (and of the first
            one unless ``block_type == 'k1kx'``).
        stride: Stride of the second convolution; the residual sum is
            skipped when it equals 2.
        act: Activation spec handed to ``get_activation``.
        reparam: Use ``RepConv`` instead of ``ConvKXBN`` for the second
            convolution.
        block_type: ``'k1kx'`` makes the first convolution 1x1; any other
            value uses ``kernel_size`` for it.
    """

    def __init__(self,
                 in_c,
                 out_c,
                 btn_c,
                 kernel_size,
                 stride,
                 act='silu',
                 reparam=False,
                 block_type='k1kx'):
        super().__init__()
        self.stride = stride
        first_kernel = 1 if block_type == 'k1kx' else kernel_size
        self.conv1 = ConvKXBN(in_c, btn_c, kernel_size=first_kernel, stride=1)

        if reparam:
            self.conv2 = RepConv(
                btn_c, out_c, kernel_size, stride, act='identity')
        else:
            self.conv2 = ConvKXBN(btn_c, out_c, kernel_size, stride)

        self.activation_function = get_activation(act)

        # A projection is only needed when the shortcut is actually used
        # (stride != 2) and the channel counts differ.
        if in_c == out_c or stride == 2:
            self.residual_proj = None
        else:
            self.residual_proj = ConvKXBN(in_c, out_c, 1, 1)

    def forward(self, x):
        """Run both convolutions and add the shortcut unless stride is 2."""
        reslink = x if self.residual_proj is None else self.residual_proj(x)
        out = self.activation_function(self.conv1(x))
        out = self.conv2(out)
        if self.stride != 2:
            out = out + reslink
        return self.activation_function(out)


class SuperResStem(nn.Module):
    """Stage made of ``num_blocks`` chained ``ResConvBlock`` modules.

    Only the first block maps ``in_c`` to ``out_c`` and applies ``stride``;
    the remaining ones keep ``out_c`` channels at stride 1. When ``with_spp``
    is set, an ``SPPBottleneck`` is inserted right after the first block.

    Args:
        in_c: Input channels of the stage.
        out_c: Output channels of every block.
        btn_c: Bottleneck channels of every block.
        kernel_size: Kernel size forwarded to every block.
        stride: Stride of the first block.
        num_blocks: Number of residual blocks.
        with_spp: Insert an ``SPPBottleneck`` after the first block.
        act: Activation spec; ``None`` selects ``torch.relu`` for ``self.act``.
        reparam: Forwarded to every ``ResConvBlock``.
        block_type: Forwarded to every ``ResConvBlock``.
    """

    def __init__(self,
                 in_c,
                 out_c,
                 btn_c,
                 kernel_size,
                 stride,
                 num_blocks,
                 with_spp=False,
                 act='silu',
                 reparam=False,
                 block_type='k1kx'):
        super().__init__()
        self.act = torch.relu if act is None else get_activation(act)
        self.block_list = nn.ModuleList()
        for block_id in range(num_blocks):
            is_first = block_id == 0
            self.block_list.append(
                ResConvBlock(
                    in_c if is_first else out_c,
                    out_c,
                    btn_c,
                    kernel_size,
                    stride if is_first else 1,
                    act=act,
                    reparam=reparam,
                    block_type=block_type))
            if is_first and with_spp:
                self.block_list.append(SPPBottleneck(out_c, out_c))

    def forward(self, x):
        """Feed ``x`` through every block in order."""
        for block in self.block_list:
            x = block(x)
        return x


class TinyNAS(nn.Module):
    """Residual backbone assembled from a list of block descriptions.

    Args:
        structure_info: Iterable of dicts with a ``'class'`` key plus the
            ``'in'``/``'out'``/``'k'``/``'s'`` entries (and ``'btn'``/``'L'``
            for the residual stages) read below.
        out_indices: Indices of the blocks whose output ``forward`` returns.
            Copied into a fresh list so instances never share the default.
        with_spp: Forwarded to the last entry only, when it is a residual
            stage.
        use_focus: Build ``Focus`` instead of ``ConvKXBNRELU`` for every
            ``'ConvKXBNRELU'`` entry.
        act: Activation spec forwarded to the blocks.
        reparam: Forwarded to the ``SuperResStem`` blocks.

    Raises:
        NotImplementedError: If an entry has an unknown ``'class'``.
    """

    def __init__(self,
                 structure_info=None,
                 out_indices=(2, 4, 5),
                 with_spp=False,
                 use_focus=False,
                 act='silu',
                 reparam=False):
        super(TinyNAS, self).__init__()
        # Immutable default + copy: mutating ``model.out_indices`` can no
        # longer leak into other instances through a shared default list.
        self.out_indices = list(out_indices)
        self.block_list = nn.ModuleList()

        last_idx = len(structure_info) - 1
        for idx, block_info in enumerate(structure_info):
            the_block_class = block_info['class']
            if the_block_class == 'ConvKXBNRELU':
                if use_focus:
                    the_block = Focus(
                        block_info['in'],
                        block_info['out'],
                        block_info['k'],
                        act=act)
                else:
                    the_block = ConvKXBNRELU(
                        block_info['in'],
                        block_info['out'],
                        block_info['k'],
                        block_info['s'],
                        act=act)
                self.block_list.append(the_block)
            elif the_block_class == 'SuperResConvK1KX':
                spp = with_spp if idx == last_idx else False
                the_block = SuperResStem(
                    block_info['in'],
                    block_info['out'],
                    block_info['btn'],
                    block_info['k'],
                    block_info['s'],
                    block_info['L'],
                    spp,
                    act=act,
                    reparam=reparam,
                    block_type='k1kx')
                self.block_list.append(the_block)
            elif the_block_class == 'SuperResConvKXKX':
                spp = with_spp if idx == last_idx else False
                the_block = SuperResStem(
                    block_info['in'],
                    block_info['out'],
                    block_info['btn'],
                    block_info['k'],
                    block_info['s'],
                    block_info['L'],
                    spp,
                    act=act,
                    reparam=reparam,
                    block_type='kxkx')
                self.block_list.append(the_block)
            else:
                raise NotImplementedError

    def init_weights(self, pretrain=None):
        """No-op; kept for interface compatibility."""
        pass

    def forward(self, x):
        """Return the list of block outputs selected by ``out_indices``."""
        output = x
        stage_feature_list = []
        for idx, block in enumerate(self.block_list):
            output = block(output)
            if idx in self.out_indices:
                stage_feature_list.append(output)
        return stage_feature_list


def load_tinynas_net(backbone_cfg):
    """Build the residual ``TinyNAS`` backbone from the configured structure file.

    Args:
        backbone_cfg: Config object exposing ``structure_file``,
            ``out_indices``, ``with_spp``, ``use_focus``, ``act`` and
            ``reparam`` attributes.

    Returns:
        The constructed ``TinyNAS`` module.
    """
    import ast

    raw_structure = read_file(backbone_cfg.structure_file)
    # Strip every piece yielded by iterating the file content, then glue the
    # pieces back together before parsing it as a Python literal.
    compact = ''.join([piece.strip() for piece in raw_structure])
    struct_info = ast.literal_eval(compact)

    # These keys are not consumed by TinyNAS; drop them when present.
    for layer in struct_info:
        for unused_key in ('nbitsA', 'nbitsW'):
            if unused_key in layer:
                del layer[unused_key]

    return TinyNAS(
        structure_info=struct_info,
        out_indices=backbone_cfg.out_indices,
        with_spp=backbone_cfg.with_spp,
        use_focus=backbone_cfg.use_focus,
        act=backbone_cfg.act,
        reparam=backbone_cfg.reparam)

+ 1
- 1
modelscope/models/cv/tinynas_detection/core/__init__.py View File

@@ -1,2 +1,2 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.
# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.
# The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo.

+ 1
- 1
modelscope/models/cv/tinynas_detection/core/base_ops.py View File

@@ -1,5 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.
# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.
# The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo.
import math import math


import torch import torch


+ 1
- 1
modelscope/models/cv/tinynas_detection/core/neck_ops.py View File

@@ -1,5 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.
# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.
# The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo.


import numpy as np import numpy as np
import torch import torch


+ 435
- 0
modelscope/models/cv/tinynas_detection/core/ops.py View File

@@ -0,0 +1,435 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class SiLU(nn.Module):
    """Export-friendly SiLU: ``x * sigmoid(x)`` built from basic ops."""

    @staticmethod
    def forward(x):
        """Return ``x`` scaled element-wise by its own sigmoid."""
        gate = torch.sigmoid(x)
        return x * gate


class Swish(nn.Module):
    """Swish activation, ``x * sigmoid(x)``, with an optional in-place mode.

    Args:
        inplace: When true (the default) the input tensor itself is
            multiplied in place and returned; otherwise a new tensor is
            returned and the input is left untouched.
    """

    def __init__(self, inplace=True):
        super(Swish, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        """Apply Swish to ``x`` (mutating it when ``inplace`` is set)."""
        # torch.sigmoid replaces the deprecated torch.nn.functional.sigmoid.
        gate = torch.sigmoid(x)
        if self.inplace:
            x.mul_(gate)
            return x
        return x * gate


def get_activation(name='silu', inplace=True):
    """Resolve an activation spec to a module.

    Args:
        name: ``None`` (yields ``nn.Identity``), an ``nn.Module`` (returned
            unchanged), or one of the strings ``'silu'``, ``'relu'``,
            ``'lrelu'``, ``'swish'``, ``'hardsigmoid'``, ``'identity'``.
        inplace: Passed to the activations that accept it.

    Raises:
        AttributeError: For an unknown string or an unsupported type.
    """
    if name is None:
        return nn.Identity()
    if isinstance(name, nn.Module):
        return name
    if not isinstance(name, str):
        raise AttributeError('Unsupported act type: {}'.format(name))

    # Builders are lazy so only the requested activation is instantiated.
    builders = {
        'silu': lambda: nn.SiLU(inplace=inplace),
        'relu': lambda: nn.ReLU(inplace=inplace),
        'lrelu': lambda: nn.LeakyReLU(0.1, inplace=inplace),
        'swish': lambda: Swish(inplace=inplace),
        'hardsigmoid': lambda: nn.Hardsigmoid(inplace=inplace),
        'identity': lambda: nn.Identity(),
    }
    if name not in builders:
        raise AttributeError('Unsupported act type: {}'.format(name))
    return builders[name]()


def get_norm(name, out_channels, inplace=True):
    """Return the normalisation layer for ``name``.

    Only ``'bn'`` (``nn.BatchNorm2d``) is supported; ``inplace`` is accepted
    but not used.

    Raises:
        NotImplementedError: For any other ``name``.
    """
    if name != 'bn':
        raise NotImplementedError
    return nn.BatchNorm2d(out_channels)


class ConvBNAct(nn.Module):
    """Conv2d followed by an optional norm and an optional activation.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        ksize: Kernel size; padding is ``(ksize - 1) // 2``.
        stride: Convolution stride.
        groups: Convolution groups.
        bias: Whether the convolution has a bias term.
        act: Activation spec for ``get_activation``; ``None`` disables the
            activation entirely.
        norm: Norm spec for ``get_norm``; ``None`` disables normalisation.
        reparam: Accepted but not used by this class.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        ksize,
        stride=1,
        groups=1,
        bias=False,
        act='silu',
        norm='bn',
        reparam=False,
    ):
        super().__init__()
        # same padding
        pad = (ksize - 1) // 2
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=ksize,
            stride=stride,
            padding=pad,
            groups=groups,
            bias=bias,
        )
        if norm is not None:
            self.bn = get_norm(norm, out_channels, inplace=True)
        if act is not None:
            self.act = get_activation(act, inplace=True)
        self.with_norm = norm is not None
        self.with_act = act is not None

    def forward(self, x):
        """Apply conv, then the norm and activation that were configured."""
        x = self.conv(x)
        if self.with_norm:
            x = self.bn(x)
        if self.with_act:
            x = self.act(x)
        return x

    def fuseforward(self, x):
        """Apply conv and activation only, skipping the norm layer.

        ``self.act`` does not exist when the block was built with
        ``act=None``, so the activation is skipped in that case instead of
        raising ``AttributeError``.
        """
        x = self.conv(x)
        if self.with_act:
            x = self.act(x)
        return x


class SPPBottleneck(nn.Module):
    """Spatial pyramid pooling block in the style of YOLOv3-SPP.

    A 1x1 convolution halves the channels, the result is max-pooled at each
    of ``kernel_sizes`` (stride 1, padding ``ks // 2``), and the un-pooled
    and pooled maps are concatenated and fused by a second 1x1 convolution.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_sizes: Max-pool kernel sizes.
        activation: Activation spec for both convolutions.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_sizes=(5, 9, 13),
                 activation='silu'):
        super().__init__()
        hidden_channels = in_channels // 2
        self.conv1 = ConvBNAct(
            in_channels, hidden_channels, 1, stride=1, act=activation)
        pools = []
        for ks in kernel_sizes:
            pools.append(
                nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2))
        self.m = nn.ModuleList(pools)
        # One un-pooled branch plus one branch per pooling size.
        branch_count = len(kernel_sizes) + 1
        self.conv2 = ConvBNAct(
            hidden_channels * branch_count,
            out_channels,
            1,
            stride=1,
            act=activation)

    def forward(self, x):
        """Reduce, pool at every scale, concatenate on channels and fuse."""
        reduced = self.conv1(x)
        branches = [reduced]
        for pool in self.m:
            branches.append(pool(reduced))
        return self.conv2(torch.cat(branches, dim=1))


class Focus(nn.Module):
    """Fold 2x2 spatial neighbourhoods into channels, then convolve.

    Args:
        in_channels: Channels of the incoming tensor (the convolution sees
            four times as many).
        out_channels: Number of output channels.
        ksize: Kernel size of the convolution.
        stride: Stride of the convolution.
        act: Activation spec for the convolution block.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 ksize=1,
                 stride=1,
                 act='silu'):
        super().__init__()
        self.conv = ConvBNAct(
            in_channels * 4, out_channels, ksize, stride, act=act)

    def forward(self, x):
        """Stack the four strided sub-grids on the channel axis and convolve."""
        even = slice(None, None, 2)
        odd = slice(1, None, 2)
        # Channel order: top-left, bottom-left, top-right, bottom-right.
        patches = [
            x[..., rows, cols] for cols in (even, odd) for rows in (even, odd)
        ]
        return self.conv(torch.cat(patches, dim=1))


class BasicBlock_3x3_Reverse(nn.Module):
    """``RepConv`` 3x3 followed by a ``ConvBNAct`` 3x3, with optional skip.

    Args:
        ch_in: Number of input channels; must equal ``ch_out``.
        ch_hidden_ratio: Ratio of hidden to input channels.
        ch_out: Number of output channels.
        act: Activation spec for both convolutions.
        shortcut: Add the input to the output when true.
    """

    def __init__(self,
                 ch_in,
                 ch_hidden_ratio,
                 ch_out,
                 act='relu',
                 shortcut=True):
        super().__init__()
        assert ch_in == ch_out
        ch_hidden = int(ch_in * ch_hidden_ratio)
        # ``conv2`` runs first in forward(); creation order is kept as is.
        self.conv1 = ConvBNAct(ch_hidden, ch_out, 3, stride=1, act=act)
        self.conv2 = RepConv(ch_in, ch_hidden, 3, stride=1, act=act)
        self.shortcut = shortcut

    def forward(self, x):
        """Apply ``conv2`` then ``conv1``; add ``x`` when shortcut is on."""
        y = self.conv1(self.conv2(x))
        return x + y if self.shortcut else y


class SPP(nn.Module):
    """Concatenate the input with max-pooled copies, then convolve.

    Args:
        ch_in: Channels seen by the convolution, i.e. after concatenation.
        ch_out: Number of output channels.
        k: Kernel size of the convolution.
        pool_size: Iterable of max-pool kernel sizes (stride 1, padding
            ``size // 2``).
        act: Activation spec for the convolution block.
    """

    def __init__(
        self,
        ch_in,
        ch_out,
        k,
        pool_size,
        act='swish',
    ):
        super().__init__()
        pools = [
            nn.MaxPool2d(
                kernel_size=size, stride=1, padding=size // 2, ceil_mode=False)
            for size in pool_size
        ]
        # Register each pool under 'pool<i>' and keep a plain list for
        # iteration in forward().
        for idx, layer in enumerate(pools):
            self.add_module('pool{}'.format(idx), layer)
        self.pool = pools
        self.conv = ConvBNAct(ch_in, ch_out, k, act=act)

    def forward(self, x):
        """Pool ``x`` at every size, concatenate with ``x`` and convolve."""
        outs = [x] + [layer(x) for layer in self.pool]
        return self.conv(torch.cat(outs, dim=1))


class CSPStage(nn.Module):
    """Cross-stage-partial stage that concatenates every intermediate output.

    The input is projected by two 1x1 convolutions. The second projection is
    passed through ``n`` blocks in sequence; the first projection and the
    output of every block are concatenated on the channel axis and fused by
    a final 1x1 convolution.

    Args:
        block_fn: Name of the block type; only ``'BasicBlock_3x3_Reverse'``
            is supported.
        ch_in: Number of input channels.
        ch_hidden_ratio: Forwarded to each block.
        ch_out: Number of output channels.
        n: Number of blocks.
        act: Activation spec for all convolutions.
        spp: Insert an ``SPP`` module after block ``(n - 1) // 2``.

    Raises:
        NotImplementedError: For any other ``block_fn`` (when ``n > 0``).
    """

    def __init__(self,
                 block_fn,
                 ch_in,
                 ch_hidden_ratio,
                 ch_out,
                 n,
                 act='swish',
                 spp=False):
        super(CSPStage, self).__init__()

        split_ratio = 2
        # ch_first + ch_mid == ch_out, also when ch_out is odd.
        ch_first = int(ch_out // split_ratio)
        ch_mid = int(ch_out - ch_first)
        self.conv1 = ConvBNAct(ch_in, ch_first, 1, act=act)
        self.conv2 = ConvBNAct(ch_in, ch_mid, 1, act=act)
        self.convs = nn.Sequential()

        next_ch_in = ch_mid
        for i in range(n):
            if block_fn == 'BasicBlock_3x3_Reverse':
                self.convs.add_module(
                    str(i),
                    BasicBlock_3x3_Reverse(
                        next_ch_in,
                        ch_hidden_ratio,
                        ch_mid,
                        act=act,
                        shortcut=True))
            else:
                raise NotImplementedError
            if i == (n - 1) // 2 and spp:
                self.convs.add_module(
                    'spp', SPP(ch_mid * 4, ch_mid, 1, [5, 9, 13], act=act))
            next_ch_in = ch_mid
        # NOTE(review): the fuse conv expects ch_mid * n + ch_first channels,
        # but with spp=True forward() also concatenates the SPP output, which
        # looks like one extra ch_mid — confirm before enabling spp.
        self.conv3 = ConvBNAct(ch_mid * n + ch_first, ch_out, 1, act=act)

    def forward(self, x):
        """Run both projections, collect every block output and fuse them."""
        y1 = self.conv1(x)
        y2 = self.conv2(x)

        mid_out = [y1]
        # Iterating the Sequential visits each sub-module so every
        # intermediate result can be kept for the concatenation.
        for conv in self.convs:
            y2 = conv(y2)
            mid_out.append(y2)
        y = torch.cat(mid_out, axis=1)
        y = self.conv3(y)
        return y


def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups=1):
    """Build the basic rep-style cell: a bias-free conv followed by BN.

    Returns:
        An ``nn.Sequential`` whose sub-modules are named ``'conv'`` and
        ``'bn'``.
    """
    conv = nn.Conv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        groups=groups,
        bias=False)
    bn = nn.BatchNorm2d(num_features=out_channels)

    cell = nn.Sequential()
    for layer_name, layer in (('conv', conv), ('bn', bn)):
        cell.add_module(layer_name, layer)
    return cell


class RepConv(nn.Module):
    '''Rep-style 3x3 convolution block with a training and a deploy form.

    Training form: a 3x3 conv+BN branch (``rbr_dense``) and a 1x1 conv+BN
    branch (``rbr_1x1``) are summed and passed through the activation. The
    identity-branch slot ``rbr_identity`` exists but is always ``None`` in
    this constructor.
    Deploy form: one biased 3x3 convolution, ``rbr_reparam``, followed by
    the activation. ``switch_to_deploy`` converts the former into the latter.
    Code is based on https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py
    '''

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 dilation=1,
                 groups=1,
                 padding_mode='zeros',
                 deploy=False,
                 act='relu',
                 norm=None):
        super(RepConv, self).__init__()
        self.deploy = deploy
        self.groups = groups
        self.in_channels = in_channels
        self.out_channels = out_channels

        # Only the 3x3 / padding 1 configuration is supported.
        assert kernel_size == 3
        assert padding == 1

        # A string is resolved through get_activation; anything else is used
        # directly as the activation callable.
        self.nonlinearity = get_activation(act) if isinstance(
            act, str) else act

        if deploy:
            self.rbr_reparam = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
                bias=True,
                padding_mode=padding_mode)
            return

        # Padding that keeps the 1x1 branch aligned with the 3x3 branch.
        padding_11 = padding - kernel_size // 2

        self.rbr_identity = None
        self.rbr_dense = conv_bn(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups)
        self.rbr_1x1 = conv_bn(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=stride,
            padding=padding_11,
            groups=groups)

    def forward(self, inputs):
        '''Apply the fused conv if present, otherwise the summed branches.'''
        if hasattr(self, 'rbr_reparam'):
            return self.nonlinearity(self.rbr_reparam(inputs))

        identity = self.rbr_identity
        id_out = 0 if identity is None else identity(inputs)
        summed = self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out
        return self.nonlinearity(summed)

    def get_equivalent_kernel_bias(self):
        '''Return the single 3x3 kernel and bias equal to all branches.'''
        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
        kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
        kernel = kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid
        bias = bias3x3 + bias1x1 + biasid
        return kernel, bias

    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
        '''Zero-pad a 1x1 kernel by one on each spatial side; None -> 0.'''
        if kernel1x1 is None:
            return 0
        return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])

    def _fuse_bn_tensor(self, branch):
        '''Fold a branch's batch-norm statistics into a kernel and a bias.

        ``branch`` is ``None`` (returns ``(0, 0)``), a conv+BN
        ``nn.Sequential`` with children ``conv`` and ``bn``, or a bare
        ``nn.BatchNorm2d`` acting as an identity branch.
        '''
        if branch is None:
            return 0, 0

        if isinstance(branch, nn.Sequential):
            kernel, bn = branch.conv.weight, branch.bn
        else:
            assert isinstance(branch, nn.BatchNorm2d)
            if not hasattr(self, 'id_tensor'):
                # Build (once) a 3x3 kernel with a single 1 at the centre of
                # each output channel's matching input channel.
                input_dim = self.in_channels // self.groups
                kernel_value = np.zeros((self.in_channels, input_dim, 3, 3),
                                        dtype=np.float32)
                out_idx = np.arange(self.in_channels)
                kernel_value[out_idx, out_idx % input_dim, 1, 1] = 1
                self.id_tensor = torch.from_numpy(kernel_value).to(
                    branch.weight.device)
            kernel, bn = self.id_tensor, branch

        std = (bn.running_var + bn.eps).sqrt()
        t = (bn.weight / std).reshape(-1, 1, 1, 1)
        return kernel * t, bn.bias - bn.running_mean * bn.weight / std

    def switch_to_deploy(self):
        '''Replace the training branches with one fused convolution.

        Does nothing when ``rbr_reparam`` already exists.
        '''
        if hasattr(self, 'rbr_reparam'):
            return

        kernel, bias = self.get_equivalent_kernel_bias()
        dense_conv = self.rbr_dense.conv
        self.rbr_reparam = nn.Conv2d(
            in_channels=dense_conv.in_channels,
            out_channels=dense_conv.out_channels,
            kernel_size=dense_conv.kernel_size,
            stride=dense_conv.stride,
            padding=dense_conv.padding,
            dilation=dense_conv.dilation,
            groups=dense_conv.groups,
            bias=True)
        self.rbr_reparam.weight.data = kernel
        self.rbr_reparam.bias.data = bias
        for para in self.parameters():
            para.detach_()

        # Drop the branches; the last two attributes only exist sometimes.
        for attr in ('rbr_dense', 'rbr_1x1'):
            delattr(self, attr)
        for attr in ('rbr_identity', 'id_tensor'):
            if hasattr(self, attr):
                delattr(self, attr)
        self.deploy = True

+ 1
- 1
modelscope/models/cv/tinynas_detection/core/repvgg_block.py View File

@@ -1,5 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.
# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.
# The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo.


import numpy as np import numpy as np
import torch import torch


+ 1
- 1
modelscope/models/cv/tinynas_detection/core/utils.py View File

@@ -1,5 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.
# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.
# The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo.


import numpy as np import numpy as np
import torch import torch


+ 2
- 2
modelscope/models/cv/tinynas_detection/detector.py View File

@@ -1,5 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.
# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.
# The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo.


import os.path as osp import os.path as osp
import pickle import pickle
@@ -42,7 +42,7 @@ class SingleStageDetector(TorchModel):
self.conf_thre = config.model.head.nms_conf_thre self.conf_thre = config.model.head.nms_conf_thre
self.nms_thre = config.model.head.nms_iou_thre self.nms_thre = config.model.head.nms_iou_thre


if self.cfg.model.backbone.name == 'TinyNAS':
if 'TinyNAS' in self.cfg.model.backbone.name:
self.cfg.model.backbone.structure_file = osp.join( self.cfg.model.backbone.structure_file = osp.join(
model_dir, self.cfg.model.backbone.structure_file) model_dir, self.cfg.model.backbone.structure_file)
self.backbone = build_backbone(self.cfg.model.backbone) self.backbone = build_backbone(self.cfg.model.backbone)


+ 4
- 1
modelscope/models/cv/tinynas_detection/head/__init__.py View File

@@ -1,9 +1,10 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.
# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.
# The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo.


import copy import copy


from .gfocal_v2_tiny import GFocalHead_Tiny from .gfocal_v2_tiny import GFocalHead_Tiny
from .zero_head import ZeroHead




def build_head(cfg): def build_head(cfg):
@@ -12,5 +13,7 @@ def build_head(cfg):
name = head_cfg.pop('name') name = head_cfg.pop('name')
if name == 'GFocalV2': if name == 'GFocalV2':
return GFocalHead_Tiny(**head_cfg) return GFocalHead_Tiny(**head_cfg)
elif name == 'ZeroHead':
return ZeroHead(**head_cfg)
else: else:
raise NotImplementedError raise NotImplementedError

+ 3
- 2
modelscope/models/cv/tinynas_detection/head/gfocal_v2_tiny.py View File

@@ -1,5 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.
# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.
# The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo.


import functools import functools
from functools import partial from functools import partial
@@ -9,7 +9,8 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F


from ..core.base_ops import BaseConv, DWConv
from modelscope.models.cv.tinynas_detection.core.base_ops import (BaseConv,
DWConv)




class Scale(nn.Module): class Scale(nn.Module):


+ 288
- 0
modelscope/models/cv/tinynas_detection/head/zero_head.py View File

@@ -0,0 +1,288 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# The DAMO-YOLO implementation is also open-sourced by the authors, and available
# at https://github.com/tinyvision/damo-yolo.
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F

from modelscope.models.cv.tinynas_detection.core.ops import ConvBNAct


class Scale(nn.Module):
    """Multiply the input by a single learnable scalar.

    Args:
        scale (float): Initial value of the learnable factor.
    """

    def __init__(self, scale=1.0):
        super(Scale, self).__init__()
        initial = torch.tensor(scale, dtype=torch.float)
        self.scale = nn.Parameter(initial)

    def forward(self, x):
        """Return ``x`` scaled by the learnable factor."""
        scaled = x * self.scale
        return scaled


def multi_apply(func, *args, **kwargs):
    """Map ``func`` over parallel iterables and regroup its outputs.

    ``func`` is called once per position with one item from each iterable
    in ``args`` (keyword arguments, if any, are bound to every call). Each
    call is expected to return a sequence; the i-th items of all calls are
    collected into the i-th list of the returned tuple.
    """
    if kwargs:
        func = partial(func, **kwargs)
    per_call = map(func, *args)
    return tuple(list(column) for column in zip(*per_call))


def distance2bbox(points, distance, max_shape=None):
    """Convert per-point side distances into corner-format boxes.

    Args:
        points: Tensor whose last axis starts with ``(x, y)``.
        distance: Tensor whose last axis holds the distances to the
            ``(left, top, right, bottom)`` sides.
        max_shape: Optional pair indexed as ``(max_shape[0], max_shape[1])``;
            y coordinates are clamped to ``[0, max_shape[0]]`` and x
            coordinates to ``[0, max_shape[1]]``.

    Returns:
        Tensor with ``(x1, y1, x2, y2)`` stacked on a new last axis.
    """
    px, py = points[..., 0], points[..., 1]
    corners = [
        px - distance[..., 0],
        py - distance[..., 1],
        px + distance[..., 2],
        py + distance[..., 3],
    ]
    if max_shape is not None:
        upper = (max_shape[1], max_shape[0], max_shape[1], max_shape[0])
        corners = [
            coord.clamp(min=0, max=limit)
            for coord, limit in zip(corners, upper)
        ]
    return torch.stack(corners, -1)


def bbox2distance(points, bbox, max_dis=None, eps=0.1):
    """Encode corner-format boxes as distances from points to box sides.

    Args:
        points: 2-D tensor whose columns 0 and 1 are ``x`` and ``y``.
        bbox: 2-D tensor whose columns are ``(x1, y1, x2, y2)``.
        max_dis: Optional upper bound; when given, every distance is clamped
            to ``[0, max_dis - eps]``.
        eps: Margin subtracted from ``max_dis`` for the clamp.

    Returns:
        Tensor with ``(left, top, right, bottom)`` stacked on a new last
        axis.
    """
    px, py = points[:, 0], points[:, 1]
    sides = [
        px - bbox[:, 0],
        py - bbox[:, 1],
        bbox[:, 2] - px,
        bbox[:, 3] - py,
    ]
    if max_dis is not None:
        sides = [side.clamp(min=0, max=max_dis - eps) for side in sides]
    return torch.stack(sides, -1)


class Integral(nn.Module):
    """Fixed (non-learned) layer turning a distribution into its weighted sum.

    The buffer ``project`` holds the values ``0, 1, ..., reg_max``; the
    forward pass takes the dot product of every distribution with it.

    Args:
        reg_max (int): Largest bin value; each distribution has
            ``reg_max + 1`` bins.
    """

    def __init__(self, reg_max=16):
        super(Integral, self).__init__()
        self.reg_max = reg_max
        bin_values = torch.linspace(0, self.reg_max, self.reg_max + 1)
        self.register_buffer('project', bin_values)

    def forward(self, x):
        """Reduce the last axis of a 4-D input against ``project``.

        The input is unpacked as four dimensions, flattened to rows of
        ``reg_max + 1`` bins, multiplied by the bin values, and returned
        with shape ``(x.size(0), x.size(1), 4)``.
        """
        batch, num_points, _, _ = x.size()
        rows = x.reshape(batch * num_points * 4, self.reg_max + 1)
        weights = self.project.type_as(rows).unsqueeze(1)
        return torch.matmul(rows, weights).reshape(batch, num_points, 4)


class ZeroHead(nn.Module):
    """Detection head with separate conv towers per feature level.

    Ref to Generalized Focal Loss V2: Learning Reliable Localization Quality
    Estimation for Dense Object Detection.

    Only the inference path is implemented: calling the head while it is in
    training mode raises ``NotImplementedError``.

    Args:
        num_classes (int): Number of object classes.
        in_channels (Sequence[int]): Input channels, one entry per level.
        stacked_convs (int): Number of ``ConvBNAct`` layers in each
            classification / regression tower. With 0, ``feat_channels`` is
            replaced by ``in_channels``.
        feat_channels (int | list[int]): Tower width; an int is repeated for
            every level.
        reg_max (int): Largest bin of the box distribution; each box side is
            predicted over ``reg_max + 1`` bins.
        strides (Sequence[int]): One stride per feature level.
        norm: Accepted but not used by this class.
        act: Activation spec forwarded to ``ConvBNAct``.
        nms_conf_thre (float): Stored on the instance; not used here.
        nms_iou_thre (float): Stored on the instance; not used here.
        nms (bool): Stored on the instance; not used here.
        **kwargs: Ignored.
    """

    def __init__(
            self,
            num_classes,
            in_channels,
            stacked_convs=4,
            feat_channels=256,
            reg_max=12,
            strides=[8, 16, 32],
            norm='gn',
            act='relu',
            nms_conf_thre=0.05,
            nms_iou_thre=0.7,
            nms=True,
            **kwargs):
        # Initialise nn.Module before assigning any attribute.
        super(ZeroHead, self).__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.stacked_convs = stacked_convs
        self.act = act
        # Copy so the shared default list is never aliased by an instance.
        self.strides = list(strides)
        if stacked_convs == 0:
            feat_channels = in_channels
        if isinstance(feat_channels, list):
            self.feat_channels = feat_channels
        else:
            self.feat_channels = [feat_channels] * len(self.strides)
        # add 1 to keep consistency with former models
        self.cls_out_channels = num_classes + 1
        self.reg_max = reg_max

        self.nms = nms
        self.nms_conf_thre = nms_conf_thre
        self.nms_iou_thre = nms_iou_thre

        # Shape of the first-level input the cached priors were built for;
        # None means "nothing cached yet".
        self.feat_size = [None for _ in self.strides]

        self.integral = Integral(self.reg_max)

        self._init_layers()

    def _build_not_shared_convs(self, in_channel, feat_channels):
        """Build the classification and regression towers for one level.

        The first layer of each tower is a 1x1 conv from ``in_channel``;
        the remaining layers are 3x3 convs at ``feat_channels`` width.
        """
        cls_convs = nn.ModuleList()
        reg_convs = nn.ModuleList()

        for i in range(self.stacked_convs):
            chn = feat_channels if i > 0 else in_channel
            kernel_size = 3 if i > 0 else 1
            cls_convs.append(
                ConvBNAct(
                    chn,
                    feat_channels,
                    kernel_size,
                    stride=1,
                    groups=1,
                    norm='bn',
                    act=self.act))
            reg_convs.append(
                ConvBNAct(
                    chn,
                    feat_channels,
                    kernel_size,
                    stride=1,
                    groups=1,
                    norm='bn',
                    act=self.act))

        return cls_convs, reg_convs

    def _init_layers(self):
        """Initialize layers of the head."""
        self.cls_convs = nn.ModuleList()
        self.reg_convs = nn.ModuleList()

        for i in range(len(self.strides)):
            cls_convs, reg_convs = self._build_not_shared_convs(
                self.in_channels[i], self.feat_channels[i])
            self.cls_convs.append(cls_convs)
            self.reg_convs.append(reg_convs)

        self.gfl_cls = nn.ModuleList([
            nn.Conv2d(
                self.feat_channels[i], self.cls_out_channels, 3, padding=1)
            for i in range(len(self.strides))
        ])

        self.gfl_reg = nn.ModuleList([
            nn.Conv2d(
                self.feat_channels[i], 4 * (self.reg_max + 1), 3, padding=1)
            for i in range(len(self.strides))
        ])

        self.scales = nn.ModuleList([Scale(1.0) for _ in self.strides])

    def forward(self, xin, labels=None, imgs=None, aux_targets=None):
        """Run inference on the per-level features ``xin``.

        Raises:
            NotImplementedError: If the module is in training mode.
        """
        if self.training:
            # The original returned the exception class instead of raising
            # it, silently handing callers a non-tensor result.
            raise NotImplementedError(
                'ZeroHead only implements the inference path')
        return self.forward_eval(xin=xin, labels=labels, imgs=imgs)

    def forward_eval(self, xin, labels=None, imgs=None):
        """Predict boxes and class scores for all levels.

        Returns:
            Tensor whose last axis is the 4 decoded box coordinates followed
            by ``num_classes`` scores; axis 1 runs over all levels' points.
        """
        # prepare priors for bbox decode; rebuilt only when the shape of the
        # first-level input changes
        if self.feat_size[0] != xin[0].shape:
            mlvl_priors_list = [
                self.get_single_level_center_priors(
                    xin[i].shape[0],
                    xin[i].shape[-2:],
                    stride,
                    dtype=torch.float32,
                    device=xin[0].device)
                for i, stride in enumerate(self.strides)
            ]
            self.mlvl_priors = torch.cat(mlvl_priors_list, dim=1)
            self.feat_size[0] = xin[0].shape

        # forward for bboxes and classification prediction
        cls_scores, bbox_preds = multi_apply(
            self.forward_single,
            xin,
            self.cls_convs,
            self.reg_convs,
            self.gfl_cls,
            self.gfl_reg,
            self.scales,
        )
        # drop the extra output channel added in cls_out_channels
        cls_scores = torch.cat(cls_scores, dim=1)[:, :, :self.num_classes]
        bbox_preds = torch.cat(bbox_preds, dim=1)
        # batch bbox decode: expected bin index times the level stride
        bbox_preds = self.integral(bbox_preds) * self.mlvl_priors[..., 2, None]
        bbox_preds = distance2bbox(self.mlvl_priors[..., :2], bbox_preds)

        return torch.cat([bbox_preds, cls_scores], dim=-1)

    def forward_single(self, x, cls_convs, reg_convs, gfl_cls, gfl_reg, scale):
        """Forward feature of a single scale level.

        Returns ``(cls_score, bbox_pred)`` in eval mode, plus the
        pre-softmax box logits as a third item in training mode.
        """
        cls_feat = x
        reg_feat = x

        for cls_conv, reg_conv in zip(cls_convs, reg_convs):
            cls_feat = cls_conv(cls_feat)
            reg_feat = reg_conv(reg_feat)

        bbox_pred = scale(gfl_reg(reg_feat)).float()
        N, C, H, W = bbox_pred.size()
        if self.training:
            bbox_before_softmax = bbox_pred.reshape(N, 4, self.reg_max + 1, H,
                                                    W)
            bbox_before_softmax = bbox_before_softmax.flatten(
                start_dim=3).permute(0, 3, 1, 2)
        bbox_pred = F.softmax(
            bbox_pred.reshape(N, 4, self.reg_max + 1, H, W), dim=2)

        cls_score = gfl_cls(cls_feat).sigmoid()

        cls_score = cls_score.flatten(start_dim=2).permute(
            0, 2, 1)  # N, h*w, self.num_classes+1
        bbox_pred = bbox_pred.flatten(start_dim=3).permute(
            0, 3, 1, 2)  # N, h*w, 4, self.reg_max+1
        if self.training:
            return cls_score, bbox_pred, bbox_before_softmax
        return cls_score, bbox_pred

    def get_single_level_center_priors(self, batch_size, featmap_size, stride,
                                       dtype, device):
        """Build ``(x, y, stride, stride)`` priors for one feature level.

        Coordinates are the grid indices multiplied by ``stride``; the
        result is repeated ``batch_size`` times along a new first axis.
        """
        h, w = featmap_size
        x_range = (torch.arange(0, int(w), dtype=dtype,
                                device=device)) * stride
        y_range = (torch.arange(0, int(h), dtype=dtype,
                                device=device)) * stride

        x = x_range.repeat(h, 1)
        y = y_range.unsqueeze(-1).repeat(1, w)

        y = y.flatten()
        x = x.flatten()
        strides = x.new_full((x.shape[0], ), stride)
        priors = torch.stack([x, y, strides, strides], dim=-1)

        return priors.unsqueeze(0).repeat(batch_size, 1, 1)

    def sample(self, assign_result, gt_bboxes):
        """Split an assignment into positive / negative prior indices.

        Positives are entries with ``gt_inds > 0``, negatives those with
        ``gt_inds == 0``; positive ground-truth indices are shifted to be
        zero-based before indexing ``gt_bboxes``.
        """
        pos_inds = torch.nonzero(
            assign_result.gt_inds > 0, as_tuple=False).squeeze(-1).unique()
        neg_inds = torch.nonzero(
            assign_result.gt_inds == 0, as_tuple=False).squeeze(-1).unique()
        pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1

        if gt_bboxes.numel() == 0:
            # hack for index error case
            assert pos_assigned_gt_inds.numel() == 0
            pos_gt_bboxes = torch.empty_like(gt_bboxes).view(-1, 4)
        else:
            if len(gt_bboxes.shape) < 2:
                gt_bboxes = gt_bboxes.view(-1, 4)
            pos_gt_bboxes = gt_bboxes[pos_assigned_gt_inds, :]

        return pos_inds, neg_inds, pos_gt_bboxes, pos_assigned_gt_inds

+ 2
- 2
modelscope/models/cv/tinynas_detection/neck/__init__.py View File

@@ -1,10 +1,10 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.
# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.
# The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo.


import copy import copy


from .giraffe_fpn import GiraffeNeck from .giraffe_fpn import GiraffeNeck
from .giraffe_fpn_v2 import GiraffeNeckV2
from .giraffe_fpn_btn import GiraffeNeckV2




def build_neck(cfg): def build_neck(cfg):


+ 1
- 1
modelscope/models/cv/tinynas_detection/neck/giraffe_config.py View File

@@ -1,5 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.
# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.
# The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo.


import collections import collections
import itertools import itertools


+ 3
- 2
modelscope/models/cv/tinynas_detection/neck/giraffe_fpn.py View File

@@ -1,5 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.
# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.
# The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo.


import logging import logging
import math import math
@@ -15,7 +15,8 @@ from timm import create_model
from timm.models.layers import (Swish, create_conv2d, create_pool2d, from timm.models.layers import (Swish, create_conv2d, create_pool2d,
get_act_layer) get_act_layer)


from ..core.base_ops import CSPLayer, ShuffleBlock, ShuffleCSPLayer
from modelscope.models.cv.tinynas_detection.core.base_ops import (
CSPLayer, ShuffleBlock, ShuffleCSPLayer)
from .giraffe_config import get_graph_config from .giraffe_config import get_graph_config


_ACT_LAYER = Swish _ACT_LAYER = Swish


+ 132
- 0
modelscope/models/cv/tinynas_detection/neck/giraffe_fpn_btn.py View File

@@ -0,0 +1,132 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo.

import torch
import torch.nn as nn

from modelscope.models.cv.tinynas_detection.core.ops import ConvBNAct, CSPStage


class GiraffeNeckV2(nn.Module):
    """Feature-fusion neck built from stride-2 convs and ``CSPStage`` merges.

    Three backbone features are fused through five merge nodes (x3..x7);
    each merge concatenates its inputs on the channel axis and runs a
    ``CSPStage`` with ``round(3 * depth)`` blocks.

    Args:
        depth (float): Multiplier for the number of blocks per merge.
        hidden_ratio: Forwarded to every ``CSPStage``.
        in_features: Stored on the instance; not used by this class.
        in_channels (Sequence[int]): Channels of the three input features.
        out_channels (Sequence[int]): Channels of the three outputs.
        act: Activation spec forwarded to every conv / stage.
        spp (bool): Forwarded to every ``CSPStage``.
        block_name (str): Block type name forwarded to every ``CSPStage``.
    """

    def __init__(
        self,
        depth=1.0,
        hidden_ratio=1.0,
        in_features=[2, 3, 4],
        in_channels=[256, 512, 1024],
        out_channels=[256, 512, 1024],
        act='silu',
        spp=False,
        block_name='BasicBlock',
    ):
        super().__init__()
        self.in_features = in_features
        self.in_channels = in_channels
        self.out_channels = out_channels

        def downsample(channels):
            # 3x3 stride-2 conv that keeps the channel count.
            return ConvBNAct(channels, channels, 3, 2, act=act)

        def merge(ch_in, ch_out):
            # CSPStage fusing a concatenation of ch_in channels into ch_out.
            return CSPStage(
                block_name,
                ch_in,
                hidden_ratio,
                ch_out,
                round(3 * depth),
                act=act,
                spp=spp)

        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')

        # node x3: input x0, x1
        self.bu_conv13 = downsample(in_channels[1])
        self.merge_3 = merge(in_channels[1] + in_channels[2], in_channels[2])

        # node x4: input x1, x2, x3
        self.bu_conv24 = downsample(in_channels[0])
        self.merge_4 = merge(
            in_channels[0] + in_channels[1] + in_channels[2], in_channels[1])

        # node x5: input x2, x4
        self.merge_5 = merge(in_channels[1] + in_channels[0], out_channels[0])

        # node x7: input x4, x5
        self.bu_conv57 = downsample(out_channels[0])
        self.merge_7 = merge(out_channels[0] + in_channels[1],
                             out_channels[1])

        # node x6: input x3, x4, x7
        self.bu_conv46 = downsample(in_channels[1])
        self.bu_conv76 = downsample(out_channels[1])
        self.merge_6 = merge(
            in_channels[1] + out_channels[1] + in_channels[2],
            out_channels[2])

    def init_weights(self):
        """No-op; kept so callers can invoke it uniformly."""
        pass

    def forward(self, out_features):
        """Fuse three backbone features.

        Args:
            out_features: Sequence of exactly three feature maps, unpacked as
                ``(x2, x1, x0)``. ``x1`` is downsampled by 2 before being
                concatenated with ``x0``, and ``x2`` likewise before ``x1``.

        Returns:
            Tuple[Tensor]: The fused features ``(x5, x7, x6)``.
        """
        x2, x1, x0 = out_features

        # node x3
        down_x1 = self.bu_conv13(x1)
        x3 = self.merge_3(torch.cat([x0, down_x1], 1))

        # node x4
        up_x3 = self.upsample(x3)
        down_x2 = self.bu_conv24(x2)
        x4 = self.merge_4(torch.cat([x1, down_x2, up_x3], 1))

        # node x5
        up_x4 = self.upsample(x4)
        x5 = self.merge_5(torch.cat([x2, up_x4], 1))

        # node x7
        down_x5 = self.bu_conv57(x5)
        x7 = self.merge_7(torch.cat([x4, down_x5], 1))

        # node x6
        down_x4 = self.bu_conv46(x4)
        down_x7 = self.bu_conv76(x7)
        x6 = self.merge_6(torch.cat([x3, down_x4, down_x7], 1))

        return (x5, x7, x6)

+ 0
- 200
modelscope/models/cv/tinynas_detection/neck/giraffe_fpn_v2.py View File

@@ -1,200 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.

import torch
import torch.nn as nn

from ..core.base_ops import BaseConv, CSPLayer, DWConv
from ..core.neck_ops import CSPStage


class GiraffeNeckV2(nn.Module):
    """Feature-fusion neck with width scaling and two merge flavours.

    Three backbone features are fused through five merge nodes (x3..x7).
    Every channel count is multiplied by ``width`` and truncated with
    ``int``. Merges are ``CSPStage`` modules when ``reparam_mode`` is true
    and ``CSPLayer`` modules otherwise; stride-2 convs are ``DWConv`` when
    ``depthwise`` is true and ``BaseConv`` otherwise.

    Args:
        depth (float): Multiplier for the number of blocks per merge
            (``round(3 * depth)``).
        width (float): Channel multiplier.
        in_channels (Sequence[int]): Channels of the three input features.
        out_channels (Sequence[int]): Channels of the three outputs.
        depthwise (bool): Selects ``DWConv`` and is forwarded to ``CSPLayer``.
        act: Activation spec forwarded to every conv / merge.
        spp (bool): Forwarded to ``CSPStage``.
        reparam_mode (bool): Selects ``CSPStage`` over ``CSPLayer``.
        block_name (str): Block type name forwarded to ``CSPStage``.
    """

    def __init__(
        self,
        depth=1.0,
        width=1.0,
        in_channels=[256, 512, 1024],
        out_channels=[256, 512, 1024],
        depthwise=False,
        act='silu',
        spp=True,
        reparam_mode=True,
        block_name='BasicBlock',
    ):
        super().__init__()
        self.in_channels = in_channels
        Conv = DWConv if depthwise else BaseConv

        def downsample(channels):
            # 3x3 stride-2 conv that keeps the (width-scaled) channel count.
            scaled = int(channels * width)
            return Conv(scaled, scaled, 3, 2, act=act)

        def merge(ch_in, ch_out):
            # Merge node over ch_in concatenated channels producing ch_out.
            scaled_in = int(ch_in * width)
            scaled_out = int(ch_out * width)
            if reparam_mode:
                return CSPStage(
                    block_name,
                    scaled_in,
                    scaled_out,
                    round(3 * depth),
                    act=act,
                    spp=spp)
            return CSPLayer(
                scaled_in,
                scaled_out,
                round(3 * depth),
                False,
                depthwise=depthwise,
                act=act)

        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')

        # node x3: input x0, x1
        self.bu_conv13 = downsample(in_channels[1])
        self.merge_3 = merge(in_channels[1] + in_channels[2], in_channels[2])

        # node x4: input x1, x2, x3
        self.bu_conv24 = downsample(in_channels[0])
        self.merge_4 = merge(
            in_channels[0] + in_channels[1] + in_channels[2], in_channels[1])

        # node x5: input x2, x4
        self.merge_5 = merge(in_channels[1] + in_channels[0], out_channels[0])

        # node x7: input x4, x5
        self.bu_conv57 = downsample(out_channels[0])
        self.merge_7 = merge(out_channels[0] + in_channels[1],
                             out_channels[1])

        # node x6: input x3, x4, x7
        self.bu_conv46 = downsample(in_channels[1])
        self.bu_conv76 = downsample(out_channels[1])
        self.merge_6 = merge(
            in_channels[1] + out_channels[1] + in_channels[2],
            out_channels[2])

    def init_weights(self):
        """No-op; kept so callers can invoke it uniformly."""
        pass

    def forward(self, out_features):
        """Fuse three backbone features.

        Args:
            out_features: Sequence of exactly three feature maps, unpacked as
                ``(x2, x1, x0)``. ``x1`` is downsampled by 2 before being
                concatenated with ``x0``, and ``x2`` likewise before ``x1``.

        Returns:
            Tuple[Tensor]: The fused features ``(x5, x7, x6)``.
        """
        x2, x1, x0 = out_features

        # node x3
        down_x1 = self.bu_conv13(x1)
        x3 = self.merge_3(torch.cat([x0, down_x1], 1))

        # node x4
        up_x3 = self.upsample(x3)
        down_x2 = self.bu_conv24(x2)
        x4 = self.merge_4(torch.cat([x1, down_x2, up_x3], 1))

        # node x5
        up_x4 = self.upsample(x4)
        x5 = self.merge_5(torch.cat([x2, up_x4], 1))

        # node x7
        down_x5 = self.bu_conv57(x5)
        x7 = self.merge_7(torch.cat([x4, down_x5], 1))

        # node x6
        down_x4 = self.bu_conv46(x4)
        down_x7 = self.bu_conv76(x7)
        x6 = self.merge_6(torch.cat([x3, down_x4, down_x7], 1))

        return (x5, x7, x6)

+ 1
- 1
modelscope/models/cv/tinynas_detection/tinynas_damoyolo.py View File

@@ -11,5 +11,5 @@ from .detector import SingleStageDetector
class DamoYolo(SingleStageDetector): class DamoYolo(SingleStageDetector):


def __init__(self, model_dir, *args, **kwargs): def __init__(self, model_dir, *args, **kwargs):
self.config_name = 'damoyolo_s.py'
self.config_name = 'damoyolo.py'
super(DamoYolo, self).__init__(model_dir, *args, **kwargs) super(DamoYolo, self).__init__(model_dir, *args, **kwargs)

+ 1
- 1
modelscope/models/cv/tinynas_detection/tinynas_detector.py View File

@@ -1,5 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.
# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.
# The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo.


from modelscope.metainfo import Models from modelscope.metainfo import Models
from modelscope.models.builder import MODELS from modelscope.models.builder import MODELS


+ 23
- 20
modelscope/models/cv/tinynas_detection/utils.py View File

@@ -1,30 +1,33 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.
# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet.
# The DAMO-YOLO implementation is also open-sourced by the authors, and available
# at https://github.com/tinyvision/damo-yolo.


import importlib import importlib
import os import os
import shutil
import sys import sys
import tempfile
from os.path import dirname, join from os.path import dirname, join


from easydict import EasyDict


def get_config_by_file(config_file):
    """Load a config file and return an instance of its ``Config`` class.

    The module is executed directly from ``config_file`` instead of being
    imported by its bare name. Importing by name made two config files that
    share a file name (for example one per model directory) resolve to
    whichever was imported first, because the import system caches modules
    by name.

    The file's directory is still added to ``sys.path`` (once) so that a
    config importing sibling modules keeps working.

    Args:
        config_file (str): Path to a ``.py`` file defining ``Config``.

    Returns:
        A new ``Config()`` instance from that file.

    Raises:
        ImportError: If the file cannot be loaded or ``Config()`` fails; the
            underlying error is attached as the cause.
    """
    import importlib.util

    try:
        config_dir = os.path.dirname(config_file)
        if config_dir not in sys.path:
            sys.path.append(config_dir)
        module_name = os.path.basename(config_file).split('.')[0]
        spec = importlib.util.spec_from_file_location(module_name,
                                                      config_file)
        current_config = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(current_config)
        exp = current_config.Config()
    except Exception as err:
        raise ImportError(
            "{} doesn't contains class named 'Config'".format(
                config_file)) from err
    return exp


def parse_config(filename):
filename = str(filename)
if filename.endswith('.py'):
with tempfile.TemporaryDirectory() as temp_config_dir:
shutil.copyfile(filename, join(temp_config_dir, '_tempconfig.py'))
sys.path.insert(0, temp_config_dir)
mod = importlib.import_module('_tempconfig')
sys.path.pop(0)
cfg_dict = EasyDict({
name: value
for name, value in mod.__dict__.items()
if not name.startswith('__')
})
# delete imported module
del sys.modules['_tempconfig']
else:
raise IOError('Only .py type are supported now!')


def parse_config(config_file):
"""
get config object by file.
Args:
config_file (str): file path of config.
"""
assert (config_file is not None), 'plz provide config file'
if config_file is not None:
return get_config_by_file(config_file)
return cfg_dict

+ 20
- 2
tests/pipelines/test_tinynas_detection.py View File

@@ -29,7 +29,25 @@ class TinynasObjectDetectionTest(unittest.TestCase, DemoCompatibilityCheck):
model='damo/cv_tinynas_object-detection_damoyolo') model='damo/cv_tinynas_object-detection_damoyolo')
result = tinynas_object_detection( result = tinynas_object_detection(
'data/test/images/image_detection.jpg') 'data/test/images/image_detection.jpg')
print('damoyolo', result)
print('damoyolo-s', result)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_damoyolo_m(self):
tinynas_object_detection = pipeline(
Tasks.image_object_detection,
model='damo/cv_tinynas_object-detection_damoyolo-m')
result = tinynas_object_detection(
'data/test/images/image_detection.jpg')
print('damoyolo-m', result)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_damoyolo_t(self):
tinynas_object_detection = pipeline(
Tasks.image_object_detection,
model='damo/cv_tinynas_object-detection_damoyolo-t')
result = tinynas_object_detection(
'data/test/images/image_detection.jpg')
print('damoyolo-t', result)


@unittest.skip('demo compatibility test is only enabled on a needed-basis') @unittest.skip('demo compatibility test is only enabled on a needed-basis')
def test_demo_compatibility(self): def test_demo_compatibility(self):
@@ -40,7 +58,7 @@ class TinynasObjectDetectionTest(unittest.TestCase, DemoCompatibilityCheck):
test_image = 'data/test/images/image_detection.jpg' test_image = 'data/test/images/image_detection.jpg'
tinynas_object_detection = pipeline( tinynas_object_detection = pipeline(
Tasks.image_object_detection, Tasks.image_object_detection,
model='damo/cv_tinynas_object-detection_damoyolo')
model='damo/cv_tinynas_object-detection_damoyolo-m')
result = tinynas_object_detection(test_image) result = tinynas_object_detection(test_image)
tinynas_object_detection.show_result(test_image, result, tinynas_object_detection.show_result(test_image, result,
'demo_ret.jpg') 'demo_ret.jpg')


Loading…
Cancel
Save