diff --git a/modelscope/models/cv/tinynas_detection/__init__.py b/modelscope/models/cv/tinynas_detection/__init__.py index 6d696ac4..01c50b4b 100644 --- a/modelscope/models/cv/tinynas_detection/__init__.py +++ b/modelscope/models/cv/tinynas_detection/__init__.py @@ -1,5 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. +# The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo. from typing import TYPE_CHECKING diff --git a/modelscope/models/cv/tinynas_detection/backbone/__init__.py b/modelscope/models/cv/tinynas_detection/backbone/__init__.py index 186d06a3..22a7654f 100644 --- a/modelscope/models/cv/tinynas_detection/backbone/__init__.py +++ b/modelscope/models/cv/tinynas_detection/backbone/__init__.py @@ -1,10 +1,11 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. +# The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo. import copy from .darknet import CSPDarknet -from .tinynas import load_tinynas_net +from .tinynas_csp import load_tinynas_net as load_tinynas_net_csp +from .tinynas_res import load_tinynas_net as load_tinynas_net_res def build_backbone(cfg): @@ -12,5 +13,7 @@ def build_backbone(cfg): name = backbone_cfg.pop('name') if name == 'CSPDarknet': return CSPDarknet(**backbone_cfg) - elif name == 'TinyNAS': - return load_tinynas_net(backbone_cfg) + elif name == 'TinyNAS_csp': + return load_tinynas_net_csp(backbone_cfg) + elif name == 'TinyNAS_res': + return load_tinynas_net_res(backbone_cfg) diff --git a/modelscope/models/cv/tinynas_detection/backbone/darknet.py b/modelscope/models/cv/tinynas_detection/backbone/darknet.py index d3294f0d..d8f80e76 100644 --- a/modelscope/models/cv/tinynas_detection/backbone/darknet.py +++ b/modelscope/models/cv/tinynas_detection/backbone/darknet.py @@ -1,12 +1,11 @@ # Copyright (c) Megvii Inc. All rights reserved. # Copyright (c) Alibaba, Inc. and its affiliates. -# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. import torch from torch import nn -from ..core.base_ops import (BaseConv, CSPLayer, DWConv, Focus, ResLayer, - SPPBottleneck) +from modelscope.models.cv.tinynas_detection.core.base_ops import ( + BaseConv, CSPLayer, DWConv, Focus, ResLayer, SPPBottleneck) class CSPDarknet(nn.Module): diff --git a/modelscope/models/cv/tinynas_detection/backbone/tinynas.py b/modelscope/models/cv/tinynas_detection/backbone/tinynas.py deleted file mode 100755 index 202bdd55..00000000 --- a/modelscope/models/cv/tinynas_detection/backbone/tinynas.py +++ /dev/null @@ -1,359 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. -# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. 
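For reference, with the dispatch above the backbone variant is now selected by the `name` key of the model config. A minimal, illustrative backbone fragment might look like the following; the keys mirror what `load_tinynas_net` reads, while the values and the structure file name are made up purely for illustration.

backbone = dict(
    name='TinyNAS_res',                 # or 'TinyNAS_csp' / 'CSPDarknet'
    structure_file='tinynas_res.txt',   # hypothetical name; resolved against model_dir in detector.py
    out_indices=[2, 4, 5],
    with_spp=False,
    use_focus=True,
    act='silu',
    reparam=True,
)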
- -import torch -import torch.nn as nn - -from modelscope.utils.file_utils import read_file -from ..core.base_ops import Focus, SPPBottleneck, get_activation -from ..core.repvgg_block import RepVggBlock - - -class ConvKXBN(nn.Module): - - def __init__(self, in_c, out_c, kernel_size, stride): - super(ConvKXBN, self).__init__() - self.conv1 = nn.Conv2d( - in_c, - out_c, - kernel_size, - stride, (kernel_size - 1) // 2, - groups=1, - bias=False) - self.bn1 = nn.BatchNorm2d(out_c) - - def forward(self, x): - return self.bn1(self.conv1(x)) - - -class ConvKXBNRELU(nn.Module): - - def __init__(self, in_c, out_c, kernel_size, stride, act='silu'): - super(ConvKXBNRELU, self).__init__() - self.conv = ConvKXBN(in_c, out_c, kernel_size, stride) - if act is None: - self.activation_function = torch.relu - else: - self.activation_function = get_activation(act) - - def forward(self, x): - output = self.conv(x) - return self.activation_function(output) - - -class ResConvK1KX(nn.Module): - - def __init__(self, - in_c, - out_c, - btn_c, - kernel_size, - stride, - force_resproj=False, - act='silu', - reparam=False): - super(ResConvK1KX, self).__init__() - self.stride = stride - self.conv1 = ConvKXBN(in_c, btn_c, 1, 1) - if not reparam: - self.conv2 = ConvKXBN(btn_c, out_c, 3, stride) - else: - self.conv2 = RepVggBlock( - btn_c, out_c, kernel_size, stride, act='identity') - - if act is None: - self.activation_function = torch.relu - else: - self.activation_function = get_activation(act) - - if stride == 2: - self.residual_downsample = nn.AvgPool2d(kernel_size=2, stride=2) - else: - self.residual_downsample = nn.Identity() - - if in_c != out_c or force_resproj: - self.residual_proj = ConvKXBN(in_c, out_c, 1, 1) - else: - self.residual_proj = nn.Identity() - - def forward(self, x): - if self.stride != 2: - reslink = self.residual_downsample(x) - reslink = self.residual_proj(reslink) - - output = x - output = self.conv1(output) - output = self.activation_function(output) - output = self.conv2(output) - if self.stride != 2: - output = output + reslink - output = self.activation_function(output) - - return output - - -class SuperResConvK1KX(nn.Module): - - def __init__(self, - in_c, - out_c, - btn_c, - kernel_size, - stride, - num_blocks, - with_spp=False, - act='silu', - reparam=False): - super(SuperResConvK1KX, self).__init__() - if act is None: - self.act = torch.relu - else: - self.act = get_activation(act) - self.block_list = nn.ModuleList() - for block_id in range(num_blocks): - if block_id == 0: - in_channels = in_c - out_channels = out_c - this_stride = stride - force_resproj = False # as a part of CSPLayer, DO NOT need this flag - this_kernel_size = kernel_size - else: - in_channels = out_c - out_channels = out_c - this_stride = 1 - force_resproj = False - this_kernel_size = kernel_size - the_block = ResConvK1KX( - in_channels, - out_channels, - btn_c, - this_kernel_size, - this_stride, - force_resproj, - act=act, - reparam=reparam) - self.block_list.append(the_block) - if block_id == 0 and with_spp: - self.block_list.append( - SPPBottleneck(out_channels, out_channels)) - - def forward(self, x): - output = x - for block in self.block_list: - output = block(output) - return output - - -class ResConvKXKX(nn.Module): - - def __init__(self, - in_c, - out_c, - btn_c, - kernel_size, - stride, - force_resproj=False, - act='silu'): - super(ResConvKXKX, self).__init__() - self.stride = stride - if self.stride == 2: - self.downsampler = ConvKXBNRELU(in_c, out_c, 3, 2, act=act) - else: - self.conv1 = ConvKXBN(in_c, 
btn_c, kernel_size, 1) - self.conv2 = RepVggBlock( - btn_c, out_c, kernel_size, stride, act='identity') - - if act is None: - self.activation_function = torch.relu - else: - self.activation_function = get_activation(act) - - if stride == 2: - self.residual_downsample = nn.AvgPool2d( - kernel_size=2, stride=2) - else: - self.residual_downsample = nn.Identity() - - if in_c != out_c or force_resproj: - self.residual_proj = ConvKXBN(in_c, out_c, 1, 1) - else: - self.residual_proj = nn.Identity() - - def forward(self, x): - if self.stride == 2: - return self.downsampler(x) - reslink = self.residual_downsample(x) - reslink = self.residual_proj(reslink) - - output = x - output = self.conv1(output) - output = self.activation_function(output) - output = self.conv2(output) - - output = output + reslink - output = self.activation_function(output) - - return output - - -class SuperResConvKXKX(nn.Module): - - def __init__(self, - in_c, - out_c, - btn_c, - kernel_size, - stride, - num_blocks, - with_spp=False, - act='silu'): - super(SuperResConvKXKX, self).__init__() - if act is None: - self.act = torch.relu - else: - self.act = get_activation(act) - self.block_list = nn.ModuleList() - for block_id in range(num_blocks): - if block_id == 0: - in_channels = in_c - out_channels = out_c - this_stride = stride - force_resproj = False # as a part of CSPLayer, DO NOT need this flag - this_kernel_size = kernel_size - else: - in_channels = out_c - out_channels = out_c - this_stride = 1 - force_resproj = False - this_kernel_size = kernel_size - the_block = ResConvKXKX( - in_channels, - out_channels, - btn_c, - this_kernel_size, - this_stride, - force_resproj, - act=act) - self.block_list.append(the_block) - if block_id == 0 and with_spp: - self.block_list.append( - SPPBottleneck(out_channels, out_channels)) - - def forward(self, x): - output = x - for block in self.block_list: - output = block(output) - return output - - -class TinyNAS(nn.Module): - - def __init__(self, - structure_info=None, - out_indices=[0, 1, 2, 4, 5], - out_channels=[None, None, 128, 256, 512], - with_spp=False, - use_focus=False, - need_conv1=True, - act='silu', - reparam=False): - super(TinyNAS, self).__init__() - assert len(out_indices) == len(out_channels) - self.out_indices = out_indices - self.need_conv1 = need_conv1 - - self.block_list = nn.ModuleList() - if need_conv1: - self.conv1_list = nn.ModuleList() - for idx, block_info in enumerate(structure_info): - the_block_class = block_info['class'] - if the_block_class == 'ConvKXBNRELU': - if use_focus: - the_block = Focus( - block_info['in'], - block_info['out'], - block_info['k'], - act=act) - else: - the_block = ConvKXBNRELU( - block_info['in'], - block_info['out'], - block_info['k'], - block_info['s'], - act=act) - self.block_list.append(the_block) - elif the_block_class == 'SuperResConvK1KX': - spp = with_spp if idx == len(structure_info) - 1 else False - the_block = SuperResConvK1KX( - block_info['in'], - block_info['out'], - block_info['btn'], - block_info['k'], - block_info['s'], - block_info['L'], - spp, - act=act, - reparam=reparam) - self.block_list.append(the_block) - elif the_block_class == 'SuperResConvKXKX': - spp = with_spp if idx == len(structure_info) - 1 else False - the_block = SuperResConvKXKX( - block_info['in'], - block_info['out'], - block_info['btn'], - block_info['k'], - block_info['s'], - block_info['L'], - spp, - act=act) - self.block_list.append(the_block) - if need_conv1: - if idx in self.out_indices and out_channels[ - self.out_indices.index(idx)] is not 
None: - self.conv1_list.append( - nn.Conv2d(block_info['out'], - out_channels[self.out_indices.index(idx)], - 1)) - else: - self.conv1_list.append(None) - - def init_weights(self, pretrain=None): - pass - - def forward(self, x): - output = x - stage_feature_list = [] - for idx, block in enumerate(self.block_list): - output = block(output) - if idx in self.out_indices: - if self.need_conv1 and self.conv1_list[idx] is not None: - true_out = self.conv1_list[idx](output) - stage_feature_list.append(true_out) - else: - stage_feature_list.append(output) - return stage_feature_list - - -def load_tinynas_net(backbone_cfg): - # load masternet model to path - import ast - net_structure_str = read_file(backbone_cfg.structure_file) - struct_str = ''.join([x.strip() for x in net_structure_str]) - struct_info = ast.literal_eval(struct_str) - for layer in struct_info: - if 'nbitsA' in layer: - del layer['nbitsA'] - if 'nbitsW' in layer: - del layer['nbitsW'] - - model = TinyNAS( - structure_info=struct_info, - out_indices=backbone_cfg.out_indices, - out_channels=backbone_cfg.out_channels, - with_spp=backbone_cfg.with_spp, - use_focus=backbone_cfg.use_focus, - act=backbone_cfg.act, - need_conv1=backbone_cfg.need_conv1, - reparam=backbone_cfg.reparam) - - return model diff --git a/modelscope/models/cv/tinynas_detection/backbone/tinynas_csp.py b/modelscope/models/cv/tinynas_detection/backbone/tinynas_csp.py new file mode 100644 index 00000000..903b6900 --- /dev/null +++ b/modelscope/models/cv/tinynas_detection/backbone/tinynas_csp.py @@ -0,0 +1,295 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The DAMO-YOLO implementation is also open-sourced by the authors, and available +# at https://github.com/tinyvision/damo-yolo. + +import torch +import torch.nn as nn + +from modelscope.models.cv.tinynas_detection.core.ops import (Focus, RepConv, + SPPBottleneck, + get_activation) +from modelscope.utils.file_utils import read_file + + +class ConvKXBN(nn.Module): + + def __init__(self, in_c, out_c, kernel_size, stride): + super(ConvKXBN, self).__init__() + self.conv1 = nn.Conv2d( + in_c, + out_c, + kernel_size, + stride, (kernel_size - 1) // 2, + groups=1, + bias=False) + self.bn1 = nn.BatchNorm2d(out_c) + + def forward(self, x): + return self.bn1(self.conv1(x)) + + +class ConvKXBNRELU(nn.Module): + + def __init__(self, in_c, out_c, kernel_size, stride, act='silu'): + super(ConvKXBNRELU, self).__init__() + self.conv = ConvKXBN(in_c, out_c, kernel_size, stride) + if act is None: + self.activation_function = torch.relu + else: + self.activation_function = get_activation(act) + + def forward(self, x): + output = self.conv(x) + return self.activation_function(output) + + +class ResConvBlock(nn.Module): + + def __init__(self, + in_c, + out_c, + btn_c, + kernel_size, + stride, + act='silu', + reparam=False, + block_type='k1kx'): + super(ResConvBlock, self).__init__() + self.stride = stride + if block_type == 'k1kx': + self.conv1 = ConvKXBN(in_c, btn_c, kernel_size=1, stride=1) + else: + self.conv1 = ConvKXBN( + in_c, btn_c, kernel_size=kernel_size, stride=1) + if not reparam: + self.conv2 = ConvKXBN(btn_c, out_c, kernel_size, stride) + else: + self.conv2 = RepConv( + btn_c, out_c, kernel_size, stride, act='identity') + + self.activation_function = get_activation(act) + + if in_c != out_c and stride != 2: + self.residual_proj = ConvKXBN(in_c, out_c, kernel_size=1, stride=1) + else: + self.residual_proj = None + + def forward(self, x): + if self.residual_proj is not None: + reslink = self.residual_proj(x) + else: + 
reslink = x + x = self.conv1(x) + x = self.activation_function(x) + x = self.conv2(x) + if self.stride != 2: + x = x + reslink + x = self.activation_function(x) + return x + + +class CSPStem(nn.Module): + + def __init__(self, + in_c, + out_c, + btn_c, + stride, + kernel_size, + num_blocks, + act='silu', + reparam=False, + block_type='k1kx'): + super(CSPStem, self).__init__() + self.in_channels = in_c + self.out_channels = out_c + self.stride = stride + if self.stride == 2: + self.num_blocks = num_blocks - 1 + else: + self.num_blocks = num_blocks + self.kernel_size = kernel_size + self.act = act + self.block_type = block_type + out_c = out_c // 2 + + if act is None: + self.act = torch.relu + else: + self.act = get_activation(act) + self.block_list = nn.ModuleList() + for block_id in range(self.num_blocks): + if self.stride == 1 and block_id == 0: + in_c = in_c // 2 + else: + in_c = out_c + the_block = ResConvBlock( + in_c, + out_c, + btn_c, + kernel_size, + stride=1, + act=act, + reparam=reparam, + block_type=block_type) + self.block_list.append(the_block) + + def forward(self, x): + output = x + for block in self.block_list: + output = block(output) + return output + + +class TinyNAS(nn.Module): + + def __init__(self, + structure_info=None, + out_indices=[2, 3, 4], + with_spp=False, + use_focus=False, + act='silu', + reparam=False): + super(TinyNAS, self).__init__() + self.out_indices = out_indices + self.block_list = nn.ModuleList() + self.stride_list = [] + + for idx, block_info in enumerate(structure_info): + the_block_class = block_info['class'] + if the_block_class == 'ConvKXBNRELU': + if use_focus and idx == 0: + the_block = Focus( + block_info['in'], + block_info['out'], + block_info['k'], + act=act) + else: + the_block = ConvKXBNRELU( + block_info['in'], + block_info['out'], + block_info['k'], + block_info['s'], + act=act) + elif the_block_class == 'SuperResConvK1KX': + the_block = CSPStem( + block_info['in'], + block_info['out'], + block_info['btn'], + block_info['s'], + block_info['k'], + block_info['L'], + act=act, + reparam=reparam, + block_type='k1kx') + elif the_block_class == 'SuperResConvKXKX': + the_block = CSPStem( + block_info['in'], + block_info['out'], + block_info['btn'], + block_info['s'], + block_info['k'], + block_info['L'], + act=act, + reparam=reparam, + block_type='kxkx') + else: + raise NotImplementedError + + self.block_list.append(the_block) + + self.csp_stage = nn.ModuleList() + self.csp_stage.append(self.block_list[0]) + self.csp_stage.append(CSPWrapper(self.block_list[1])) + self.csp_stage.append(CSPWrapper(self.block_list[2])) + self.csp_stage.append( + CSPWrapper((self.block_list[3], self.block_list[4]))) + self.csp_stage.append( + CSPWrapper(self.block_list[5], with_spp=with_spp)) + del self.block_list + + def init_weights(self, pretrain=None): + pass + + def forward(self, x): + output = x + stage_feature_list = [] + for idx, block in enumerate(self.csp_stage): + output = block(output) + if idx in self.out_indices: + stage_feature_list.append(output) + return stage_feature_list + + +class CSPWrapper(nn.Module): + + def __init__(self, convstem, act='relu', reparam=False, with_spp=False): + + super(CSPWrapper, self).__init__() + self.with_spp = with_spp + if isinstance(convstem, tuple): + in_c = convstem[0].in_channels + out_c = convstem[-1].out_channels + hidden_dim = convstem[0].out_channels // 2 + _convstem = nn.ModuleList() + for modulelist in convstem: + for layer in modulelist.block_list: + _convstem.append(layer) + else: + in_c = 
convstem.in_channels + out_c = convstem.out_channels + hidden_dim = out_c // 2 + _convstem = convstem.block_list + + self.convstem = nn.ModuleList() + for layer in _convstem: + self.convstem.append(layer) + + self.act = get_activation(act) + self.downsampler = ConvKXBNRELU( + in_c, hidden_dim * 2, 3, 2, act=self.act) + if self.with_spp: + self.spp = SPPBottleneck(hidden_dim * 2, hidden_dim * 2) + if len(self.convstem) > 0: + self.conv_start = ConvKXBNRELU( + hidden_dim * 2, hidden_dim, 1, 1, act=self.act) + self.conv_shortcut = ConvKXBNRELU( + hidden_dim * 2, out_c // 2, 1, 1, act=self.act) + self.conv_fuse = ConvKXBNRELU(out_c, out_c, 1, 1, act=self.act) + + def forward(self, x): + x = self.downsampler(x) + if self.with_spp: + x = self.spp(x) + if len(self.convstem) > 0: + shortcut = self.conv_shortcut(x) + x = self.conv_start(x) + for block in self.convstem: + x = block(x) + x = torch.cat((x, shortcut), dim=1) + x = self.conv_fuse(x) + return x + + +def load_tinynas_net(backbone_cfg): + # load masternet model to path + import ast + + net_structure_str = read_file(backbone_cfg.structure_file) + struct_str = ''.join([x.strip() for x in net_structure_str]) + struct_info = ast.literal_eval(struct_str) + for layer in struct_info: + if 'nbitsA' in layer: + del layer['nbitsA'] + if 'nbitsW' in layer: + del layer['nbitsW'] + + model = TinyNAS( + structure_info=struct_info, + out_indices=backbone_cfg.out_indices, + with_spp=backbone_cfg.with_spp, + use_focus=backbone_cfg.use_focus, + act=backbone_cfg.act, + reparam=backbone_cfg.reparam) + + return model diff --git a/modelscope/models/cv/tinynas_detection/backbone/tinynas_res.py b/modelscope/models/cv/tinynas_detection/backbone/tinynas_res.py new file mode 100644 index 00000000..3fb9e573 --- /dev/null +++ b/modelscope/models/cv/tinynas_detection/backbone/tinynas_res.py @@ -0,0 +1,238 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The DAMO-YOLO implementation is also open-sourced by the authors, and available +# at https://github.com/tinyvision/damo-yolo. 
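As a quick sanity check on the CSPWrapper added in tinynas_csp.py above, the following hedged sketch (assuming modelscope with this patch applied is importable; channel sizes are arbitrary) wraps a single CSPStem and confirms the wrapper downsamples by 2 and emits the stem's declared out_channels. In TinyNAS above, the same wrapper groups pairs of stems into five csp_stage modules, of which out_indices [2, 3, 4] are returned.

import torch
from modelscope.models.cv.tinynas_detection.backbone.tinynas_csp import CSPStem, CSPWrapper

stem = CSPStem(in_c=64, out_c=128, btn_c=64, stride=2, kernel_size=3, num_blocks=2, act='silu')
wrap = CSPWrapper(stem)
x = torch.randn(1, 64, 80, 80)
print(wrap(x).shape)  # torch.Size([1, 128, 40, 40]): downsample, split, stem blocks, concat, fuse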
+ +import torch +import torch.nn as nn + +from modelscope.models.cv.tinynas_detection.core.ops import (Focus, RepConv, + SPPBottleneck, + get_activation) +from modelscope.utils.file_utils import read_file + + +class ConvKXBN(nn.Module): + + def __init__(self, in_c, out_c, kernel_size, stride): + super(ConvKXBN, self).__init__() + self.conv1 = nn.Conv2d( + in_c, + out_c, + kernel_size, + stride, (kernel_size - 1) // 2, + groups=1, + bias=False) + self.bn1 = nn.BatchNorm2d(out_c) + + def forward(self, x): + return self.bn1(self.conv1(x)) + + +class ConvKXBNRELU(nn.Module): + + def __init__(self, in_c, out_c, kernel_size, stride, act='silu'): + super(ConvKXBNRELU, self).__init__() + self.conv = ConvKXBN(in_c, out_c, kernel_size, stride) + if act is None: + self.activation_function = torch.relu + else: + self.activation_function = get_activation(act) + + def forward(self, x): + output = self.conv(x) + return self.activation_function(output) + + +class ResConvBlock(nn.Module): + + def __init__(self, + in_c, + out_c, + btn_c, + kernel_size, + stride, + act='silu', + reparam=False, + block_type='k1kx'): + super(ResConvBlock, self).__init__() + self.stride = stride + if block_type == 'k1kx': + self.conv1 = ConvKXBN(in_c, btn_c, kernel_size=1, stride=1) + else: + self.conv1 = ConvKXBN( + in_c, btn_c, kernel_size=kernel_size, stride=1) + + if not reparam: + self.conv2 = ConvKXBN(btn_c, out_c, kernel_size, stride) + else: + self.conv2 = RepConv( + btn_c, out_c, kernel_size, stride, act='identity') + + self.activation_function = get_activation(act) + + if in_c != out_c and stride != 2: + self.residual_proj = ConvKXBN(in_c, out_c, 1, 1) + else: + self.residual_proj = None + + def forward(self, x): + if self.residual_proj is not None: + reslink = self.residual_proj(x) + else: + reslink = x + x = self.conv1(x) + x = self.activation_function(x) + x = self.conv2(x) + if self.stride != 2: + x = x + reslink + x = self.activation_function(x) + return x + + +class SuperResStem(nn.Module): + + def __init__(self, + in_c, + out_c, + btn_c, + kernel_size, + stride, + num_blocks, + with_spp=False, + act='silu', + reparam=False, + block_type='k1kx'): + super(SuperResStem, self).__init__() + if act is None: + self.act = torch.relu + else: + self.act = get_activation(act) + self.block_list = nn.ModuleList() + for block_id in range(num_blocks): + if block_id == 0: + in_channels = in_c + out_channels = out_c + this_stride = stride + this_kernel_size = kernel_size + else: + in_channels = out_c + out_channels = out_c + this_stride = 1 + this_kernel_size = kernel_size + the_block = ResConvBlock( + in_channels, + out_channels, + btn_c, + this_kernel_size, + this_stride, + act=act, + reparam=reparam, + block_type=block_type) + self.block_list.append(the_block) + if block_id == 0 and with_spp: + self.block_list.append( + SPPBottleneck(out_channels, out_channels)) + + def forward(self, x): + output = x + for block in self.block_list: + output = block(output) + return output + + +class TinyNAS(nn.Module): + + def __init__(self, + structure_info=None, + out_indices=[2, 4, 5], + with_spp=False, + use_focus=False, + act='silu', + reparam=False): + super(TinyNAS, self).__init__() + self.out_indices = out_indices + self.block_list = nn.ModuleList() + + for idx, block_info in enumerate(structure_info): + the_block_class = block_info['class'] + if the_block_class == 'ConvKXBNRELU': + if use_focus: + the_block = Focus( + block_info['in'], + block_info['out'], + block_info['k'], + act=act) + else: + the_block = ConvKXBNRELU( + 
block_info['in'], + block_info['out'], + block_info['k'], + block_info['s'], + act=act) + self.block_list.append(the_block) + elif the_block_class == 'SuperResConvK1KX': + spp = with_spp if idx == len(structure_info) - 1 else False + the_block = SuperResStem( + block_info['in'], + block_info['out'], + block_info['btn'], + block_info['k'], + block_info['s'], + block_info['L'], + spp, + act=act, + reparam=reparam, + block_type='k1kx') + self.block_list.append(the_block) + elif the_block_class == 'SuperResConvKXKX': + spp = with_spp if idx == len(structure_info) - 1 else False + the_block = SuperResStem( + block_info['in'], + block_info['out'], + block_info['btn'], + block_info['k'], + block_info['s'], + block_info['L'], + spp, + act=act, + reparam=reparam, + block_type='kxkx') + self.block_list.append(the_block) + else: + raise NotImplementedError + + def init_weights(self, pretrain=None): + pass + + def forward(self, x): + output = x + stage_feature_list = [] + for idx, block in enumerate(self.block_list): + output = block(output) + if idx in self.out_indices: + stage_feature_list.append(output) + return stage_feature_list + + +def load_tinynas_net(backbone_cfg): + # load masternet model to path + import ast + + net_structure_str = read_file(backbone_cfg.structure_file) + struct_str = ''.join([x.strip() for x in net_structure_str]) + struct_info = ast.literal_eval(struct_str) + for layer in struct_info: + if 'nbitsA' in layer: + del layer['nbitsA'] + if 'nbitsW' in layer: + del layer['nbitsW'] + + model = TinyNAS( + structure_info=struct_info, + out_indices=backbone_cfg.out_indices, + with_spp=backbone_cfg.with_spp, + use_focus=backbone_cfg.use_focus, + act=backbone_cfg.act, + reparam=backbone_cfg.reparam) + + return model diff --git a/modelscope/models/cv/tinynas_detection/core/__init__.py b/modelscope/models/cv/tinynas_detection/core/__init__.py index 3dad5e72..50a10d0b 100644 --- a/modelscope/models/cv/tinynas_detection/core/__init__.py +++ b/modelscope/models/cv/tinynas_detection/core/__init__.py @@ -1,2 +1,2 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. +# The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo. diff --git a/modelscope/models/cv/tinynas_detection/core/base_ops.py b/modelscope/models/cv/tinynas_detection/core/base_ops.py index 62729ca2..daf71d05 100644 --- a/modelscope/models/cv/tinynas_detection/core/base_ops.py +++ b/modelscope/models/cv/tinynas_detection/core/base_ops.py @@ -1,5 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. +# The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo. import math import torch diff --git a/modelscope/models/cv/tinynas_detection/core/neck_ops.py b/modelscope/models/cv/tinynas_detection/core/neck_ops.py index 7f481665..b04c323d 100644 --- a/modelscope/models/cv/tinynas_detection/core/neck_ops.py +++ b/modelscope/models/cv/tinynas_detection/core/neck_ops.py @@ -1,5 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. +# The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo. 
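To make the data flow of the residual variant concrete, here is a hedged, self-contained sketch (assuming modelscope with this patch applied is importable). The toy structure below is invented purely to show the expected keys ('class', 'in', 'out', 'btn', 'k', 's', 'L'); real models load a searched structure from backbone_cfg.structure_file, parsed with ast.literal_eval in load_tinynas_net. The same structure format also feeds the CSP variant in tinynas_csp.py, which regroups the stems into CSP stages instead of keeping the flat block list.

import torch
from modelscope.models.cv.tinynas_detection.backbone.tinynas_res import TinyNAS

# Toy searched structure, invented for illustration only.
structure = [
    {'class': 'ConvKXBNRELU',     'in': 3,   'out': 16,  'k': 3, 's': 2},
    {'class': 'SuperResConvK1KX', 'in': 16,  'out': 32,  'btn': 16,  'k': 3, 's': 2, 'L': 1},
    {'class': 'SuperResConvK1KX', 'in': 32,  'out': 64,  'btn': 32,  'k': 3, 's': 2, 'L': 1},
    {'class': 'SuperResConvK1KX', 'in': 64,  'out': 128, 'btn': 64,  'k': 3, 's': 2, 'L': 1},
    {'class': 'SuperResConvK1KX', 'in': 128, 'out': 128, 'btn': 64,  'k': 3, 's': 1, 'L': 1},
    {'class': 'SuperResConvK1KX', 'in': 128, 'out': 256, 'btn': 128, 'k': 3, 's': 2, 'L': 1},
]

backbone = TinyNAS(structure_info=structure, out_indices=[2, 4, 5],
                   with_spp=False, use_focus=False, act='silu', reparam=False)
feats = backbone(torch.randn(1, 3, 256, 256))
print([tuple(f.shape) for f in feats])
# [(1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8)], i.e. strides 8/16/32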
import numpy as np import torch diff --git a/modelscope/models/cv/tinynas_detection/core/ops.py b/modelscope/models/cv/tinynas_detection/core/ops.py new file mode 100644 index 00000000..07a96c13 --- /dev/null +++ b/modelscope/models/cv/tinynas_detection/core/ops.py @@ -0,0 +1,435 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo. + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class SiLU(nn.Module): + """export-friendly version of nn.SiLU()""" + + @staticmethod + def forward(x): + return x * torch.sigmoid(x) + + +class Swish(nn.Module): + + def __init__(self, inplace=True): + super(Swish, self).__init__() + self.inplace = inplace + + def forward(self, x): + if self.inplace: + x.mul_(F.sigmoid(x)) + return x + else: + return x * F.sigmoid(x) + + +def get_activation(name='silu', inplace=True): + if name is None: + return nn.Identity() + + if isinstance(name, str): + if name == 'silu': + module = nn.SiLU(inplace=inplace) + elif name == 'relu': + module = nn.ReLU(inplace=inplace) + elif name == 'lrelu': + module = nn.LeakyReLU(0.1, inplace=inplace) + elif name == 'swish': + module = Swish(inplace=inplace) + elif name == 'hardsigmoid': + module = nn.Hardsigmoid(inplace=inplace) + elif name == 'identity': + module = nn.Identity() + else: + raise AttributeError('Unsupported act type: {}'.format(name)) + return module + + elif isinstance(name, nn.Module): + return name + + else: + raise AttributeError('Unsupported act type: {}'.format(name)) + + +def get_norm(name, out_channels, inplace=True): + if name == 'bn': + module = nn.BatchNorm2d(out_channels) + else: + raise NotImplementedError + return module + + +class ConvBNAct(nn.Module): + """A Conv2d -> Batchnorm -> silu/leaky relu block""" + + def __init__( + self, + in_channels, + out_channels, + ksize, + stride=1, + groups=1, + bias=False, + act='silu', + norm='bn', + reparam=False, + ): + super().__init__() + # same padding + pad = (ksize - 1) // 2 + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=ksize, + stride=stride, + padding=pad, + groups=groups, + bias=bias, + ) + if norm is not None: + self.bn = get_norm(norm, out_channels, inplace=True) + if act is not None: + self.act = get_activation(act, inplace=True) + self.with_norm = norm is not None + self.with_act = act is not None + + def forward(self, x): + x = self.conv(x) + if self.with_norm: + x = self.bn(x) + if self.with_act: + x = self.act(x) + return x + + def fuseforward(self, x): + return self.act(self.conv(x)) + + +class SPPBottleneck(nn.Module): + """Spatial pyramid pooling layer used in YOLOv3-SPP""" + + def __init__(self, + in_channels, + out_channels, + kernel_sizes=(5, 9, 13), + activation='silu'): + super().__init__() + hidden_channels = in_channels // 2 + self.conv1 = ConvBNAct( + in_channels, hidden_channels, 1, stride=1, act=activation) + self.m = nn.ModuleList([ + nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2) + for ks in kernel_sizes + ]) + conv2_channels = hidden_channels * (len(kernel_sizes) + 1) + self.conv2 = ConvBNAct( + conv2_channels, out_channels, 1, stride=1, act=activation) + + def forward(self, x): + x = self.conv1(x) + x = torch.cat([x] + [m(x) for m in self.m], dim=1) + x = self.conv2(x) + return x + + +class Focus(nn.Module): + """Focus width and height information into channel space.""" + + def __init__(self, + in_channels, + out_channels, + ksize=1, + stride=1, + 
act='silu'): + super().__init__() + self.conv = ConvBNAct( + in_channels * 4, out_channels, ksize, stride, act=act) + + def forward(self, x): + # shape of x (b,c,w,h) -> y(b,4c,w/2,h/2) + patch_top_left = x[..., ::2, ::2] + patch_top_right = x[..., ::2, 1::2] + patch_bot_left = x[..., 1::2, ::2] + patch_bot_right = x[..., 1::2, 1::2] + x = torch.cat( + ( + patch_top_left, + patch_bot_left, + patch_top_right, + patch_bot_right, + ), + dim=1, + ) + return self.conv(x) + + +class BasicBlock_3x3_Reverse(nn.Module): + + def __init__(self, + ch_in, + ch_hidden_ratio, + ch_out, + act='relu', + shortcut=True): + super(BasicBlock_3x3_Reverse, self).__init__() + assert ch_in == ch_out + ch_hidden = int(ch_in * ch_hidden_ratio) + self.conv1 = ConvBNAct(ch_hidden, ch_out, 3, stride=1, act=act) + self.conv2 = RepConv(ch_in, ch_hidden, 3, stride=1, act=act) + self.shortcut = shortcut + + def forward(self, x): + y = self.conv2(x) + y = self.conv1(y) + if self.shortcut: + return x + y + else: + return y + + +class SPP(nn.Module): + + def __init__( + self, + ch_in, + ch_out, + k, + pool_size, + act='swish', + ): + super(SPP, self).__init__() + self.pool = [] + for i, size in enumerate(pool_size): + pool = nn.MaxPool2d( + kernel_size=size, stride=1, padding=size // 2, ceil_mode=False) + self.add_module('pool{}'.format(i), pool) + self.pool.append(pool) + self.conv = ConvBNAct(ch_in, ch_out, k, act=act) + + def forward(self, x): + outs = [x] + + for pool in self.pool: + outs.append(pool(x)) + y = torch.cat(outs, axis=1) + + y = self.conv(y) + return y + + +class CSPStage(nn.Module): + + def __init__(self, + block_fn, + ch_in, + ch_hidden_ratio, + ch_out, + n, + act='swish', + spp=False): + super(CSPStage, self).__init__() + + split_ratio = 2 + ch_first = int(ch_out // split_ratio) + ch_mid = int(ch_out - ch_first) + self.conv1 = ConvBNAct(ch_in, ch_first, 1, act=act) + self.conv2 = ConvBNAct(ch_in, ch_mid, 1, act=act) + self.convs = nn.Sequential() + + next_ch_in = ch_mid + for i in range(n): + if block_fn == 'BasicBlock_3x3_Reverse': + self.convs.add_module( + str(i), + BasicBlock_3x3_Reverse( + next_ch_in, + ch_hidden_ratio, + ch_mid, + act=act, + shortcut=True)) + else: + raise NotImplementedError + if i == (n - 1) // 2 and spp: + self.convs.add_module( + 'spp', SPP(ch_mid * 4, ch_mid, 1, [5, 9, 13], act=act)) + next_ch_in = ch_mid + self.conv3 = ConvBNAct(ch_mid * n + ch_first, ch_out, 1, act=act) + + def forward(self, x): + y1 = self.conv1(x) + y2 = self.conv2(x) + + mid_out = [y1] + for conv in self.convs: + y2 = conv(y2) + mid_out.append(y2) + y = torch.cat(mid_out, axis=1) + y = self.conv3(y) + return y + + +def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups=1): + '''Basic cell for rep-style block, including conv and bn''' + result = nn.Sequential() + result.add_module( + 'conv', + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + groups=groups, + bias=False)) + result.add_module('bn', nn.BatchNorm2d(num_features=out_channels)) + return result + + +class RepConv(nn.Module): + '''RepConv is a basic rep-style block, including training and deploy status + Code is based on https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py + ''' + + def __init__(self, + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + dilation=1, + groups=1, + padding_mode='zeros', + deploy=False, + act='relu', + norm=None): + super(RepConv, self).__init__() + self.deploy = deploy + self.groups = groups 
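# (Illustrative note, not from the original code) RepConv trains with parallel
# 3x3 and 1x1 conv+BN branches (the identity branch stays disabled in this
# implementation); at export time switch_to_deploy() folds each BN into its
# conv, pads the 1x1 kernel to 3x3, sums the branches into a single
# rbr_reparam conv, and drops the originals. For example, in eval mode:
#   m = RepConv(32, 32, 3, 1, act='relu').eval(); y0 = m(x)
#   m.switch_to_deploy()
#   (m(x) - y0).abs().max()  # at floating-point noise level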
+ self.in_channels = in_channels + self.out_channels = out_channels + + assert kernel_size == 3 + assert padding == 1 + + padding_11 = padding - kernel_size // 2 + + if isinstance(act, str): + self.nonlinearity = get_activation(act) + else: + self.nonlinearity = act + + if deploy: + self.rbr_reparam = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=True, + padding_mode=padding_mode) + + else: + self.rbr_identity = None + self.rbr_dense = conv_bn( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + groups=groups) + self.rbr_1x1 = conv_bn( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + stride=stride, + padding=padding_11, + groups=groups) + + def forward(self, inputs): + '''Forward process''' + if hasattr(self, 'rbr_reparam'): + return self.nonlinearity(self.rbr_reparam(inputs)) + + if self.rbr_identity is None: + id_out = 0 + else: + id_out = self.rbr_identity(inputs) + + return self.nonlinearity( + self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out) + + def get_equivalent_kernel_bias(self): + kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense) + kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1) + kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity) + return kernel3x3 + self._pad_1x1_to_3x3_tensor( + kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid + + def _pad_1x1_to_3x3_tensor(self, kernel1x1): + if kernel1x1 is None: + return 0 + else: + return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1]) + + def _fuse_bn_tensor(self, branch): + if branch is None: + return 0, 0 + if isinstance(branch, nn.Sequential): + kernel = branch.conv.weight + running_mean = branch.bn.running_mean + running_var = branch.bn.running_var + gamma = branch.bn.weight + beta = branch.bn.bias + eps = branch.bn.eps + else: + assert isinstance(branch, nn.BatchNorm2d) + if not hasattr(self, 'id_tensor'): + input_dim = self.in_channels // self.groups + kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), + dtype=np.float32) + for i in range(self.in_channels): + kernel_value[i, i % input_dim, 1, 1] = 1 + self.id_tensor = torch.from_numpy(kernel_value).to( + branch.weight.device) + kernel = self.id_tensor + running_mean = branch.running_mean + running_var = branch.running_var + gamma = branch.weight + beta = branch.bias + eps = branch.eps + std = (running_var + eps).sqrt() + t = (gamma / std).reshape(-1, 1, 1, 1) + return kernel * t, beta - running_mean * gamma / std + + def switch_to_deploy(self): + if hasattr(self, 'rbr_reparam'): + return + kernel, bias = self.get_equivalent_kernel_bias() + self.rbr_reparam = nn.Conv2d( + in_channels=self.rbr_dense.conv.in_channels, + out_channels=self.rbr_dense.conv.out_channels, + kernel_size=self.rbr_dense.conv.kernel_size, + stride=self.rbr_dense.conv.stride, + padding=self.rbr_dense.conv.padding, + dilation=self.rbr_dense.conv.dilation, + groups=self.rbr_dense.conv.groups, + bias=True) + self.rbr_reparam.weight.data = kernel + self.rbr_reparam.bias.data = bias + for para in self.parameters(): + para.detach_() + self.__delattr__('rbr_dense') + self.__delattr__('rbr_1x1') + if hasattr(self, 'rbr_identity'): + self.__delattr__('rbr_identity') + if hasattr(self, 'id_tensor'): + self.__delattr__('id_tensor') + self.deploy = True diff --git a/modelscope/models/cv/tinynas_detection/core/repvgg_block.py 
b/modelscope/models/cv/tinynas_detection/core/repvgg_block.py index 06966a4e..b2c5ddc4 100644 --- a/modelscope/models/cv/tinynas_detection/core/repvgg_block.py +++ b/modelscope/models/cv/tinynas_detection/core/repvgg_block.py @@ -1,5 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. +# The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo. import numpy as np import torch diff --git a/modelscope/models/cv/tinynas_detection/core/utils.py b/modelscope/models/cv/tinynas_detection/core/utils.py index 482f12fb..29f08f05 100644 --- a/modelscope/models/cv/tinynas_detection/core/utils.py +++ b/modelscope/models/cv/tinynas_detection/core/utils.py @@ -1,5 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. +# The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo. import numpy as np import torch diff --git a/modelscope/models/cv/tinynas_detection/detector.py b/modelscope/models/cv/tinynas_detection/detector.py index 7aff2167..d7320aaa 100644 --- a/modelscope/models/cv/tinynas_detection/detector.py +++ b/modelscope/models/cv/tinynas_detection/detector.py @@ -1,5 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. +# The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo. import os.path as osp import pickle @@ -42,7 +42,7 @@ class SingleStageDetector(TorchModel): self.conf_thre = config.model.head.nms_conf_thre self.nms_thre = config.model.head.nms_iou_thre - if self.cfg.model.backbone.name == 'TinyNAS': + if 'TinyNAS' in self.cfg.model.backbone.name: self.cfg.model.backbone.structure_file = osp.join( model_dir, self.cfg.model.backbone.structure_file) self.backbone = build_backbone(self.cfg.model.backbone) diff --git a/modelscope/models/cv/tinynas_detection/head/__init__.py b/modelscope/models/cv/tinynas_detection/head/__init__.py index f870fae1..b522ef8a 100644 --- a/modelscope/models/cv/tinynas_detection/head/__init__.py +++ b/modelscope/models/cv/tinynas_detection/head/__init__.py @@ -1,9 +1,10 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. +# The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo. import copy from .gfocal_v2_tiny import GFocalHead_Tiny +from .zero_head import ZeroHead def build_head(cfg): @@ -12,5 +13,7 @@ def build_head(cfg): name = head_cfg.pop('name') if name == 'GFocalV2': return GFocalHead_Tiny(**head_cfg) + elif name == 'ZeroHead': + return ZeroHead(**head_cfg) else: raise NotImplementedError diff --git a/modelscope/models/cv/tinynas_detection/head/gfocal_v2_tiny.py b/modelscope/models/cv/tinynas_detection/head/gfocal_v2_tiny.py index 66904ed1..822efd2a 100644 --- a/modelscope/models/cv/tinynas_detection/head/gfocal_v2_tiny.py +++ b/modelscope/models/cv/tinynas_detection/head/gfocal_v2_tiny.py @@ -1,5 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
-# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. +# The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo. import functools from functools import partial @@ -9,7 +9,8 @@ import torch import torch.nn as nn import torch.nn.functional as F -from ..core.base_ops import BaseConv, DWConv +from modelscope.models.cv.tinynas_detection.core.base_ops import (BaseConv, + DWConv) class Scale(nn.Module): diff --git a/modelscope/models/cv/tinynas_detection/head/zero_head.py b/modelscope/models/cv/tinynas_detection/head/zero_head.py new file mode 100644 index 00000000..0e23ebc3 --- /dev/null +++ b/modelscope/models/cv/tinynas_detection/head/zero_head.py @@ -0,0 +1,288 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The DAMO-YOLO implementation is also open-sourced by the authors, and available +# at https://github.com/tinyvision/damo-yolo. +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from modelscope.models.cv.tinynas_detection.core.ops import ConvBNAct + + +class Scale(nn.Module): + + def __init__(self, scale=1.0): + super(Scale, self).__init__() + self.scale = nn.Parameter(torch.tensor(scale, dtype=torch.float)) + + def forward(self, x): + return x * self.scale + + +def multi_apply(func, *args, **kwargs): + + pfunc = partial(func, **kwargs) if kwargs else func + map_results = map(pfunc, *args) + return tuple(map(list, zip(*map_results))) + + +def distance2bbox(points, distance, max_shape=None): + """Decode distance prediction to bounding box. + """ + x1 = points[..., 0] - distance[..., 0] + y1 = points[..., 1] - distance[..., 1] + x2 = points[..., 0] + distance[..., 2] + y2 = points[..., 1] + distance[..., 3] + if max_shape is not None: + x1 = x1.clamp(min=0, max=max_shape[1]) + y1 = y1.clamp(min=0, max=max_shape[0]) + x2 = x2.clamp(min=0, max=max_shape[1]) + y2 = y2.clamp(min=0, max=max_shape[0]) + return torch.stack([x1, y1, x2, y2], -1) + + +def bbox2distance(points, bbox, max_dis=None, eps=0.1): + """Decode bounding box based on distances. + """ + left = points[:, 0] - bbox[:, 0] + top = points[:, 1] - bbox[:, 1] + right = bbox[:, 2] - points[:, 0] + bottom = bbox[:, 3] - points[:, 1] + if max_dis is not None: + left = left.clamp(min=0, max=max_dis - eps) + top = top.clamp(min=0, max=max_dis - eps) + right = right.clamp(min=0, max=max_dis - eps) + bottom = bottom.clamp(min=0, max=max_dis - eps) + return torch.stack([left, top, right, bottom], -1) + + +class Integral(nn.Module): + """A fixed layer for calculating integral result from distribution. + """ + + def __init__(self, reg_max=16): + super(Integral, self).__init__() + self.reg_max = reg_max + self.register_buffer('project', + torch.linspace(0, self.reg_max, self.reg_max + 1)) + + def forward(self, x): + """Forward feature from the regression head to get integral result of + bounding box location. + """ + b, hw, _, _ = x.size() + x = x.reshape(b * hw * 4, self.reg_max + 1) + y = self.project.type_as(x).unsqueeze(1) + x = torch.matmul(x, y).reshape(b, hw, 4) + return x + + +class ZeroHead(nn.Module): + """Ref to Generalized Focal Loss V2: Learning Reliable Localization Quality + Estimation for Dense Object Detection. 
+ """ + + def __init__( + self, + num_classes, + in_channels, + stacked_convs=4, # 4 + feat_channels=256, + reg_max=12, + strides=[8, 16, 32], + norm='gn', + act='relu', + nms_conf_thre=0.05, + nms_iou_thre=0.7, + nms=True, + **kwargs): + self.in_channels = in_channels + self.num_classes = num_classes + self.stacked_convs = stacked_convs + self.act = act + self.strides = strides + if stacked_convs == 0: + feat_channels = in_channels + if isinstance(feat_channels, list): + self.feat_channels = feat_channels + else: + self.feat_channels = [feat_channels] * len(self.strides) + # add 1 for keep consistance with former models + self.cls_out_channels = num_classes + 1 + self.reg_max = reg_max + + self.nms = nms + self.nms_conf_thre = nms_conf_thre + self.nms_iou_thre = nms_iou_thre + + self.feat_size = [torch.zeros(4) for _ in strides] + + super(ZeroHead, self).__init__() + self.integral = Integral(self.reg_max) + + self._init_layers() + + def _build_not_shared_convs(self, in_channel, feat_channels): + cls_convs = nn.ModuleList() + reg_convs = nn.ModuleList() + + for i in range(self.stacked_convs): + chn = feat_channels if i > 0 else in_channel + kernel_size = 3 if i > 0 else 1 + cls_convs.append( + ConvBNAct( + chn, + feat_channels, + kernel_size, + stride=1, + groups=1, + norm='bn', + act=self.act)) + reg_convs.append( + ConvBNAct( + chn, + feat_channels, + kernel_size, + stride=1, + groups=1, + norm='bn', + act=self.act)) + + return cls_convs, reg_convs + + def _init_layers(self): + """Initialize layers of the head.""" + self.cls_convs = nn.ModuleList() + self.reg_convs = nn.ModuleList() + + for i in range(len(self.strides)): + cls_convs, reg_convs = self._build_not_shared_convs( + self.in_channels[i], self.feat_channels[i]) + self.cls_convs.append(cls_convs) + self.reg_convs.append(reg_convs) + + self.gfl_cls = nn.ModuleList([ + nn.Conv2d( + self.feat_channels[i], self.cls_out_channels, 3, padding=1) + for i in range(len(self.strides)) + ]) + + self.gfl_reg = nn.ModuleList([ + nn.Conv2d( + self.feat_channels[i], 4 * (self.reg_max + 1), 3, padding=1) + for i in range(len(self.strides)) + ]) + + self.scales = nn.ModuleList([Scale(1.0) for _ in self.strides]) + + def forward(self, xin, labels=None, imgs=None, aux_targets=None): + if self.training: + return NotImplementedError + else: + return self.forward_eval(xin=xin, labels=labels, imgs=imgs) + + def forward_eval(self, xin, labels=None, imgs=None): + + # prepare priors for label assignment and bbox decode + if self.feat_size[0] != xin[0].shape: + mlvl_priors_list = [ + self.get_single_level_center_priors( + xin[i].shape[0], + xin[i].shape[-2:], + stride, + dtype=torch.float32, + device=xin[0].device) + for i, stride in enumerate(self.strides) + ] + self.mlvl_priors = torch.cat(mlvl_priors_list, dim=1) + self.feat_size[0] = xin[0].shape + + # forward for bboxes and classification prediction + cls_scores, bbox_preds = multi_apply( + self.forward_single, + xin, + self.cls_convs, + self.reg_convs, + self.gfl_cls, + self.gfl_reg, + self.scales, + ) + cls_scores = torch.cat(cls_scores, dim=1)[:, :, :self.num_classes] + bbox_preds = torch.cat(bbox_preds, dim=1) + # batch bbox decode + bbox_preds = self.integral(bbox_preds) * self.mlvl_priors[..., 2, None] + bbox_preds = distance2bbox(self.mlvl_priors[..., :2], bbox_preds) + + res = torch.cat([bbox_preds, cls_scores[..., 0:self.num_classes]], + dim=-1) + return res + + def forward_single(self, x, cls_convs, reg_convs, gfl_cls, gfl_reg, scale): + """Forward feature of a single scale level. 
+ + """ + cls_feat = x + reg_feat = x + + for cls_conv, reg_conv in zip(cls_convs, reg_convs): + cls_feat = cls_conv(cls_feat) + reg_feat = reg_conv(reg_feat) + + bbox_pred = scale(gfl_reg(reg_feat)).float() + N, C, H, W = bbox_pred.size() + if self.training: + bbox_before_softmax = bbox_pred.reshape(N, 4, self.reg_max + 1, H, + W) + bbox_before_softmax = bbox_before_softmax.flatten( + start_dim=3).permute(0, 3, 1, 2) + bbox_pred = F.softmax( + bbox_pred.reshape(N, 4, self.reg_max + 1, H, W), dim=2) + + cls_score = gfl_cls(cls_feat).sigmoid() + + cls_score = cls_score.flatten(start_dim=2).permute( + 0, 2, 1) # N, h*w, self.num_classes+1 + bbox_pred = bbox_pred.flatten(start_dim=3).permute( + 0, 3, 1, 2) # N, h*w, 4, self.reg_max+1 + if self.training: + return cls_score, bbox_pred, bbox_before_softmax + else: + return cls_score, bbox_pred + + def get_single_level_center_priors(self, batch_size, featmap_size, stride, + dtype, device): + + h, w = featmap_size + x_range = (torch.arange(0, int(w), dtype=dtype, + device=device)) * stride + y_range = (torch.arange(0, int(h), dtype=dtype, + device=device)) * stride + + x = x_range.repeat(h, 1) + y = y_range.unsqueeze(-1).repeat(1, w) + + y = y.flatten() + x = x.flatten() + strides = x.new_full((x.shape[0], ), stride) + priors = torch.stack([x, y, strides, strides], dim=-1) + + return priors.unsqueeze(0).repeat(batch_size, 1, 1) + + def sample(self, assign_result, gt_bboxes): + pos_inds = torch.nonzero( + assign_result.gt_inds > 0, as_tuple=False).squeeze(-1).unique() + neg_inds = torch.nonzero( + assign_result.gt_inds == 0, as_tuple=False).squeeze(-1).unique() + pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1 + + if gt_bboxes.numel() == 0: + # hack for index error case + assert pos_assigned_gt_inds.numel() == 0 + pos_gt_bboxes = torch.empty_like(gt_bboxes).view(-1, 4) + else: + if len(gt_bboxes.shape) < 2: + gt_bboxes = gt_bboxes.view(-1, 4) + pos_gt_bboxes = gt_bboxes[pos_assigned_gt_inds, :] + + return pos_inds, neg_inds, pos_gt_bboxes, pos_assigned_gt_inds diff --git a/modelscope/models/cv/tinynas_detection/neck/__init__.py b/modelscope/models/cv/tinynas_detection/neck/__init__.py index 3c418c29..e5b9e72a 100644 --- a/modelscope/models/cv/tinynas_detection/neck/__init__.py +++ b/modelscope/models/cv/tinynas_detection/neck/__init__.py @@ -1,10 +1,10 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. +# The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo. import copy from .giraffe_fpn import GiraffeNeck -from .giraffe_fpn_v2 import GiraffeNeckV2 +from .giraffe_fpn_btn import GiraffeNeckV2 def build_neck(cfg): diff --git a/modelscope/models/cv/tinynas_detection/neck/giraffe_config.py b/modelscope/models/cv/tinynas_detection/neck/giraffe_config.py index 289fdfd2..23994356 100644 --- a/modelscope/models/cv/tinynas_detection/neck/giraffe_config.py +++ b/modelscope/models/cv/tinynas_detection/neck/giraffe_config.py @@ -1,5 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. +# The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo. 
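The eval path above decodes boxes in two steps: Integral turns each side's discretized distance distribution into its expected value (in stride units), and distance2bbox offsets the prior centers by those distances. A small hedged check, assuming modelscope with this patch applied is importable; the numbers are chosen only for illustration.

import torch
import torch.nn.functional as F
from modelscope.models.cv.tinynas_detection.head.zero_head import Integral, distance2bbox

reg_max, stride = 12, 8
integral = Integral(reg_max)

# Uniform logits over the reg_max+1 bins: the expected distance is mean(0..12) = 6 bins.
probs = F.softmax(torch.zeros(1, 1, 4, reg_max + 1), dim=-1)   # (batch, h*w, 4, bins)
dist = integral(probs) * stride          # tensor([[[48., 48., 48., 48.]]])

center = torch.tensor([[[100., 60.]]])   # prior (x, y) in pixels
print(distance2bbox(center, dist))       # tensor([[[ 52.,  12., 148., 108.]]])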
import collections import itertools diff --git a/modelscope/models/cv/tinynas_detection/neck/giraffe_fpn.py b/modelscope/models/cv/tinynas_detection/neck/giraffe_fpn.py index b7087779..1b7db26e 100644 --- a/modelscope/models/cv/tinynas_detection/neck/giraffe_fpn.py +++ b/modelscope/models/cv/tinynas_detection/neck/giraffe_fpn.py @@ -1,5 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. +# The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo. import logging import math @@ -15,7 +15,8 @@ from timm import create_model from timm.models.layers import (Swish, create_conv2d, create_pool2d, get_act_layer) -from ..core.base_ops import CSPLayer, ShuffleBlock, ShuffleCSPLayer +from modelscope.models.cv.tinynas_detection.core.base_ops import ( + CSPLayer, ShuffleBlock, ShuffleCSPLayer) from .giraffe_config import get_graph_config _ACT_LAYER = Swish diff --git a/modelscope/models/cv/tinynas_detection/neck/giraffe_fpn_btn.py b/modelscope/models/cv/tinynas_detection/neck/giraffe_fpn_btn.py new file mode 100644 index 00000000..f8519df0 --- /dev/null +++ b/modelscope/models/cv/tinynas_detection/neck/giraffe_fpn_btn.py @@ -0,0 +1,132 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo. + +import torch +import torch.nn as nn + +from modelscope.models.cv.tinynas_detection.core.ops import ConvBNAct, CSPStage + + +class GiraffeNeckV2(nn.Module): + + def __init__( + self, + depth=1.0, + hidden_ratio=1.0, + in_features=[2, 3, 4], + in_channels=[256, 512, 1024], + out_channels=[256, 512, 1024], + act='silu', + spp=False, + block_name='BasicBlock', + ): + super().__init__() + self.in_features = in_features + self.in_channels = in_channels + self.out_channels = out_channels + Conv = ConvBNAct + + self.upsample = nn.Upsample(scale_factor=2, mode='nearest') + + # node x3: input x0, x1 + self.bu_conv13 = Conv(in_channels[1], in_channels[1], 3, 2, act=act) + self.merge_3 = CSPStage( + block_name, + in_channels[1] + in_channels[2], + hidden_ratio, + in_channels[2], + round(3 * depth), + act=act, + spp=spp) + + # node x4: input x1, x2, x3 + self.bu_conv24 = Conv(in_channels[0], in_channels[0], 3, 2, act=act) + self.merge_4 = CSPStage( + block_name, + in_channels[0] + in_channels[1] + in_channels[2], + hidden_ratio, + in_channels[1], + round(3 * depth), + act=act, + spp=spp) + + # node x5: input x2, x4 + self.merge_5 = CSPStage( + block_name, + in_channels[1] + in_channels[0], + hidden_ratio, + out_channels[0], + round(3 * depth), + act=act, + spp=spp) + + # node x7: input x4, x5 + self.bu_conv57 = Conv(out_channels[0], out_channels[0], 3, 2, act=act) + self.merge_7 = CSPStage( + block_name, + out_channels[0] + in_channels[1], + hidden_ratio, + out_channels[1], + round(3 * depth), + act=act, + spp=spp) + + # node x6: input x3, x4, x7 + self.bu_conv46 = Conv(in_channels[1], in_channels[1], 3, 2, act=act) + self.bu_conv76 = Conv(out_channels[1], out_channels[1], 3, 2, act=act) + self.merge_6 = CSPStage( + block_name, + in_channels[1] + out_channels[1] + in_channels[2], + hidden_ratio, + out_channels[2], + round(3 * depth), + act=act, + spp=spp) + + def init_weights(self): + pass + + def forward(self, out_features): + """ + Args: + inputs: input images. + + Returns: + Tuple[Tensor]: FPN feature. 
+ """ + + # backbone + [x2, x1, x0] = out_features + + # node x3 + x13 = self.bu_conv13(x1) + x3 = torch.cat([x0, x13], 1) + x3 = self.merge_3(x3) + + # node x4 + x34 = self.upsample(x3) + x24 = self.bu_conv24(x2) + x4 = torch.cat([x1, x24, x34], 1) + x4 = self.merge_4(x4) + + # node x5 + x45 = self.upsample(x4) + x5 = torch.cat([x2, x45], 1) + x5 = self.merge_5(x5) + + # node x8 + # x8 = x5 + + # node x7 + x57 = self.bu_conv57(x5) + x7 = torch.cat([x4, x57], 1) + x7 = self.merge_7(x7) + + # node x6 + x46 = self.bu_conv46(x4) + x76 = self.bu_conv76(x7) + x6 = torch.cat([x3, x46, x76], 1) + x6 = self.merge_6(x6) + + outputs = (x5, x7, x6) + return outputs diff --git a/modelscope/models/cv/tinynas_detection/neck/giraffe_fpn_v2.py b/modelscope/models/cv/tinynas_detection/neck/giraffe_fpn_v2.py deleted file mode 100644 index b88c39f2..00000000 --- a/modelscope/models/cv/tinynas_detection/neck/giraffe_fpn_v2.py +++ /dev/null @@ -1,200 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. -# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. - -import torch -import torch.nn as nn - -from ..core.base_ops import BaseConv, CSPLayer, DWConv -from ..core.neck_ops import CSPStage - - -class GiraffeNeckV2(nn.Module): - - def __init__( - self, - depth=1.0, - width=1.0, - in_channels=[256, 512, 1024], - out_channels=[256, 512, 1024], - depthwise=False, - act='silu', - spp=True, - reparam_mode=True, - block_name='BasicBlock', - ): - super().__init__() - self.in_channels = in_channels - Conv = DWConv if depthwise else BaseConv - - reparam_mode = reparam_mode - - self.upsample = nn.Upsample(scale_factor=2, mode='nearest') - - # node x3: input x0, x1 - self.bu_conv13 = Conv( - int(in_channels[1] * width), - int(in_channels[1] * width), - 3, - 2, - act=act) - if reparam_mode: - self.merge_3 = CSPStage( - block_name, - int((in_channels[1] + in_channels[2]) * width), - int(in_channels[2] * width), - round(3 * depth), - act=act, - spp=spp) - else: - self.merge_3 = CSPLayer( - int((in_channels[1] + in_channels[2]) * width), - int(in_channels[2] * width), - round(3 * depth), - False, - depthwise=depthwise, - act=act) - - # node x4: input x1, x2, x3 - self.bu_conv24 = Conv( - int(in_channels[0] * width), - int(in_channels[0] * width), - 3, - 2, - act=act) - if reparam_mode: - self.merge_4 = CSPStage( - block_name, - int((in_channels[0] + in_channels[1] + in_channels[2]) - * width), - int(in_channels[1] * width), - round(3 * depth), - act=act, - spp=spp) - else: - self.merge_4 = CSPLayer( - int((in_channels[0] + in_channels[1] + in_channels[2]) - * width), - int(in_channels[1] * width), - round(3 * depth), - False, - depthwise=depthwise, - act=act) - - # node x5: input x2, x4 - if reparam_mode: - self.merge_5 = CSPStage( - block_name, - int((in_channels[1] + in_channels[0]) * width), - int(out_channels[0] * width), - round(3 * depth), - act=act, - spp=spp) - else: - self.merge_5 = CSPLayer( - int((in_channels[1] + in_channels[0]) * width), - int(out_channels[0] * width), - round(3 * depth), - False, - depthwise=depthwise, - act=act) - - # node x7: input x4, x5 - self.bu_conv57 = Conv( - int(out_channels[0] * width), - int(out_channels[0] * width), - 3, - 2, - act=act) - if reparam_mode: - self.merge_7 = CSPStage( - block_name, - int((out_channels[0] + in_channels[1]) * width), - int(out_channels[1] * width), - round(3 * depth), - act=act, - spp=spp) - else: - self.merge_7 = CSPLayer( - int((out_channels[0] + in_channels[1]) * width), - 
int(out_channels[1] * width), - round(3 * depth), - False, - depthwise=depthwise, - act=act) - - # node x6: input x3, x4, x7 - self.bu_conv46 = Conv( - int(in_channels[1] * width), - int(in_channels[1] * width), - 3, - 2, - act=act) - self.bu_conv76 = Conv( - int(out_channels[1] * width), - int(out_channels[1] * width), - 3, - 2, - act=act) - if reparam_mode: - self.merge_6 = CSPStage( - block_name, - int((in_channels[1] + out_channels[1] + in_channels[2]) - * width), - int(out_channels[2] * width), - round(3 * depth), - act=act, - spp=spp) - else: - self.merge_6 = CSPLayer( - int((in_channels[1] + out_channels[1] + in_channels[2]) - * width), - int(out_channels[2] * width), - round(3 * depth), - False, - depthwise=depthwise, - act=act) - - def init_weights(self): - pass - - def forward(self, out_features): - """ - Args: - inputs: input images. - - Returns: - Tuple[Tensor]: FPN feature. - """ - - # backbone - [x2, x1, x0] = out_features - - # node x3 - x13 = self.bu_conv13(x1) - x3 = torch.cat([x0, x13], 1) - x3 = self.merge_3(x3) - - # node x4 - x34 = self.upsample(x3) - x24 = self.bu_conv24(x2) - x4 = torch.cat([x1, x24, x34], 1) - x4 = self.merge_4(x4) - - # node x5 - x45 = self.upsample(x4) - x5 = torch.cat([x2, x45], 1) - x5 = self.merge_5(x5) - - # node x7 - x57 = self.bu_conv57(x5) - x7 = torch.cat([x4, x57], 1) - x7 = self.merge_7(x7) - - # node x6 - x46 = self.bu_conv46(x4) - x76 = self.bu_conv76(x7) - x6 = torch.cat([x3, x46, x76], 1) - x6 = self.merge_6(x6) - - outputs = (x5, x7, x6) - return outputs diff --git a/modelscope/models/cv/tinynas_detection/tinynas_damoyolo.py b/modelscope/models/cv/tinynas_detection/tinynas_damoyolo.py index 9effad3a..181c3095 100644 --- a/modelscope/models/cv/tinynas_detection/tinynas_damoyolo.py +++ b/modelscope/models/cv/tinynas_detection/tinynas_damoyolo.py @@ -11,5 +11,5 @@ from .detector import SingleStageDetector class DamoYolo(SingleStageDetector): def __init__(self, model_dir, *args, **kwargs): - self.config_name = 'damoyolo_s.py' + self.config_name = 'damoyolo.py' super(DamoYolo, self).__init__(model_dir, *args, **kwargs) diff --git a/modelscope/models/cv/tinynas_detection/tinynas_detector.py b/modelscope/models/cv/tinynas_detection/tinynas_detector.py index 92acf3fa..37bb01da 100644 --- a/modelscope/models/cv/tinynas_detection/tinynas_detector.py +++ b/modelscope/models/cv/tinynas_detection/tinynas_detector.py @@ -1,5 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. +# The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo. from modelscope.metainfo import Models from modelscope.models.builder import MODELS diff --git a/modelscope/models/cv/tinynas_detection/utils.py b/modelscope/models/cv/tinynas_detection/utils.py index d67d3a36..984e1e4e 100644 --- a/modelscope/models/cv/tinynas_detection/utils.py +++ b/modelscope/models/cv/tinynas_detection/utils.py @@ -1,30 +1,33 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. +# The DAMO-YOLO implementation is also open-sourced by the authors, and available +# at https://github.com/tinyvision/damo-yolo. 
import importlib import os +import shutil import sys +import tempfile from os.path import dirname, join +from easydict import EasyDict -def get_config_by_file(config_file): - try: - sys.path.append(os.path.dirname(config_file)) - current_config = importlib.import_module( - os.path.basename(config_file).split('.')[0]) - exp = current_config.Config() - except Exception: - raise ImportError( - "{} doesn't contains class named 'Config'".format(config_file)) - return exp +def parse_config(filename): + filename = str(filename) + if filename.endswith('.py'): + with tempfile.TemporaryDirectory() as temp_config_dir: + shutil.copyfile(filename, join(temp_config_dir, '_tempconfig.py')) + sys.path.insert(0, temp_config_dir) + mod = importlib.import_module('_tempconfig') + sys.path.pop(0) + cfg_dict = EasyDict({ + name: value + for name, value in mod.__dict__.items() + if not name.startswith('__') + }) + # delete imported module + del sys.modules['_tempconfig'] + else: + raise IOError('Only .py type are supported now!') -def parse_config(config_file): - """ - get config object by file. - Args: - config_file (str): file path of config. - """ - assert (config_file is not None), 'plz provide config file' - if config_file is not None: - return get_config_by_file(config_file) + return cfg_dict diff --git a/tests/pipelines/test_tinynas_detection.py b/tests/pipelines/test_tinynas_detection.py index c92b5568..79ccf89f 100644 --- a/tests/pipelines/test_tinynas_detection.py +++ b/tests/pipelines/test_tinynas_detection.py @@ -29,7 +29,25 @@ class TinynasObjectDetectionTest(unittest.TestCase, DemoCompatibilityCheck): model='damo/cv_tinynas_object-detection_damoyolo') result = tinynas_object_detection( 'data/test/images/image_detection.jpg') - print('damoyolo', result) + print('damoyolo-s', result) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_damoyolo_m(self): + tinynas_object_detection = pipeline( + Tasks.image_object_detection, + model='damo/cv_tinynas_object-detection_damoyolo-m') + result = tinynas_object_detection( + 'data/test/images/image_detection.jpg') + print('damoyolo-m', result) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_damoyolo_t(self): + tinynas_object_detection = pipeline( + Tasks.image_object_detection, + model='damo/cv_tinynas_object-detection_damoyolo-t') + result = tinynas_object_detection( + 'data/test/images/image_detection.jpg') + print('damoyolo-t', result) @unittest.skip('demo compatibility test is only enabled on a needed-basis') def test_demo_compatibility(self): @@ -40,7 +58,7 @@ class TinynasObjectDetectionTest(unittest.TestCase, DemoCompatibilityCheck): test_image = 'data/test/images/image_detection.jpg' tinynas_object_detection = pipeline( Tasks.image_object_detection, - model='damo/cv_tinynas_object-detection_damoyolo') + model='damo/cv_tinynas_object-detection_damoyolo-m') result = tinynas_object_detection(test_image) tinynas_object_detection.show_result(test_image, result, 'demo_ret.jpg')
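
As a quick sanity check of the pieces added above, here is a minimal, self-contained sketch (not part of the patch) showing how the new parse_config helper and the GiraffeNeckV2 neck from giraffe_fpn_btn.py could be exercised. The import paths and default channel widths come from the diff; the dummy input resolution, the feature strides (8/16/32), and the config path are illustrative assumptions, and the expected output shapes further assume that the CSPStage and ConvBNAct blocks from core.ops (not shown in this diff) map their input channels to the out-channel argument their constructors receive.

import torch

from modelscope.models.cv.tinynas_detection.neck.giraffe_fpn_btn import GiraffeNeckV2
from modelscope.models.cv.tinynas_detection.utils import parse_config

# parse_config() copies a .py config into a temp dir, imports it, and returns its
# module-level names as an EasyDict (see the utils.py hunk above).
# cfg = parse_config('/path/to/damoyolo.py')  # hypothetical path, shown for illustration

# Channel widths below are the defaults declared in giraffe_fpn_btn.py.
neck = GiraffeNeckV2(
    depth=1.0,
    hidden_ratio=1.0,
    in_channels=[256, 512, 1024],
    out_channels=[256, 512, 1024],
)
neck.eval()

# Dummy backbone features for an assumed 640x640 input at strides 8/16/32,
# ordered shallow-to-deep as forward() unpacks them: [x2, x1, x0].
x2 = torch.randn(1, 256, 80, 80)
x1 = torch.randn(1, 512, 40, 40)
x0 = torch.randn(1, 1024, 20, 20)

with torch.no_grad():
    p3, p4, p5 = neck([x2, x1, x0])  # forward() returns the tuple (x5, x7, x6)

# Expected shapes under the assumptions above:
# p3: (1, 256, 80, 80), p4: (1, 512, 40, 40), p5: (1, 1024, 20, 20)
print(p3.shape, p4.shape, p5.shape)

Tracing the merge nodes in the new forward() with these defaults, the concatenated channel counts (1536, 1792, 768, 768, 2048) match the in-channel arguments passed to merge_3 through merge_6, and the returned tuple keeps a stride-8 / stride-16 / stride-32 ordering, mirroring the (x5, x7, x6) return in the added file.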