1. Add DAMO-YOLO-T and DAMO-YOLO-M models.
2. Fix the configuration overlap error.
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10816561
master^2
| @@ -1,5 +1,5 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. | |||
| # The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo. | |||
| from typing import TYPE_CHECKING | |||
| @@ -1,10 +1,11 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. | |||
| # The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo. | |||
| import copy | |||
| from .darknet import CSPDarknet | |||
| from .tinynas import load_tinynas_net | |||
| from .tinynas_csp import load_tinynas_net as load_tinynas_net_csp | |||
| from .tinynas_res import load_tinynas_net as load_tinynas_net_res | |||
| def build_backbone(cfg): | |||
| @@ -12,5 +13,7 @@ def build_backbone(cfg): | |||
| name = backbone_cfg.pop('name') | |||
| if name == 'CSPDarknet': | |||
| return CSPDarknet(**backbone_cfg) | |||
| elif name == 'TinyNAS': | |||
| return load_tinynas_net(backbone_cfg) | |||
| elif name == 'TinyNAS_csp': | |||
| return load_tinynas_net_csp(backbone_cfg) | |||
| elif name == 'TinyNAS_res': | |||
| return load_tinynas_net_res(backbone_cfg) | |||
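The dispatcher above routes on the config's `name` field, with the two new values going to the DAMO-YOLO loaders. A minimal sketch of what such a backbone config could look like; the field values and file name are illustrative assumptions (the real values live in the model's configuration), and any attribute-style dict such as `easydict.EasyDict` works since the loaders read fields as attributes:

```python
# Hypothetical config sketch, not the shipped configuration.
from easydict import EasyDict  # assumption: any attr-style dict would do

backbone_cfg = EasyDict(
    name='TinyNAS_res',        # routes to load_tinynas_net_res ('TinyNAS_csp' also valid)
    structure_file='tinynas_structure.txt',  # hypothetical path, resolved at load time
    out_indices=[2, 4, 5],
    with_spp=True,
    use_focus=True,
    act='silu',
    reparam=True,
)
# backbone = build_backbone(backbone_cfg)  # returns a TinyNAS nn.Module
```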
| @@ -1,12 +1,11 @@ | |||
| # Copyright (c) Megvii Inc. All rights reserved. | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. | |||
| import torch | |||
| from torch import nn | |||
| from ..core.base_ops import (BaseConv, CSPLayer, DWConv, Focus, ResLayer, | |||
| SPPBottleneck) | |||
| from modelscope.models.cv.tinynas_detection.core.base_ops import ( | |||
| BaseConv, CSPLayer, DWConv, Focus, ResLayer, SPPBottleneck) | |||
| class CSPDarknet(nn.Module): | |||
| @@ -1,359 +0,0 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. | |||
| import torch | |||
| import torch.nn as nn | |||
| from modelscope.utils.file_utils import read_file | |||
| from ..core.base_ops import Focus, SPPBottleneck, get_activation | |||
| from ..core.repvgg_block import RepVggBlock | |||
| class ConvKXBN(nn.Module): | |||
| def __init__(self, in_c, out_c, kernel_size, stride): | |||
| super(ConvKXBN, self).__init__() | |||
| self.conv1 = nn.Conv2d( | |||
| in_c, | |||
| out_c, | |||
| kernel_size, | |||
| stride, (kernel_size - 1) // 2, | |||
| groups=1, | |||
| bias=False) | |||
| self.bn1 = nn.BatchNorm2d(out_c) | |||
| def forward(self, x): | |||
| return self.bn1(self.conv1(x)) | |||
| class ConvKXBNRELU(nn.Module): | |||
| def __init__(self, in_c, out_c, kernel_size, stride, act='silu'): | |||
| super(ConvKXBNRELU, self).__init__() | |||
| self.conv = ConvKXBN(in_c, out_c, kernel_size, stride) | |||
| if act is None: | |||
| self.activation_function = torch.relu | |||
| else: | |||
| self.activation_function = get_activation(act) | |||
| def forward(self, x): | |||
| output = self.conv(x) | |||
| return self.activation_function(output) | |||
| class ResConvK1KX(nn.Module): | |||
| def __init__(self, | |||
| in_c, | |||
| out_c, | |||
| btn_c, | |||
| kernel_size, | |||
| stride, | |||
| force_resproj=False, | |||
| act='silu', | |||
| reparam=False): | |||
| super(ResConvK1KX, self).__init__() | |||
| self.stride = stride | |||
| self.conv1 = ConvKXBN(in_c, btn_c, 1, 1) | |||
| if not reparam: | |||
| self.conv2 = ConvKXBN(btn_c, out_c, 3, stride) | |||
| else: | |||
| self.conv2 = RepVggBlock( | |||
| btn_c, out_c, kernel_size, stride, act='identity') | |||
| if act is None: | |||
| self.activation_function = torch.relu | |||
| else: | |||
| self.activation_function = get_activation(act) | |||
| if stride == 2: | |||
| self.residual_downsample = nn.AvgPool2d(kernel_size=2, stride=2) | |||
| else: | |||
| self.residual_downsample = nn.Identity() | |||
| if in_c != out_c or force_resproj: | |||
| self.residual_proj = ConvKXBN(in_c, out_c, 1, 1) | |||
| else: | |||
| self.residual_proj = nn.Identity() | |||
| def forward(self, x): | |||
| if self.stride != 2: | |||
| reslink = self.residual_downsample(x) | |||
| reslink = self.residual_proj(reslink) | |||
| output = x | |||
| output = self.conv1(output) | |||
| output = self.activation_function(output) | |||
| output = self.conv2(output) | |||
| if self.stride != 2: | |||
| output = output + reslink | |||
| output = self.activation_function(output) | |||
| return output | |||
| class SuperResConvK1KX(nn.Module): | |||
| def __init__(self, | |||
| in_c, | |||
| out_c, | |||
| btn_c, | |||
| kernel_size, | |||
| stride, | |||
| num_blocks, | |||
| with_spp=False, | |||
| act='silu', | |||
| reparam=False): | |||
| super(SuperResConvK1KX, self).__init__() | |||
| if act is None: | |||
| self.act = torch.relu | |||
| else: | |||
| self.act = get_activation(act) | |||
| self.block_list = nn.ModuleList() | |||
| for block_id in range(num_blocks): | |||
| if block_id == 0: | |||
| in_channels = in_c | |||
| out_channels = out_c | |||
| this_stride = stride | |||
| force_resproj = False # not needed when used inside a CSPLayer | |||
| this_kernel_size = kernel_size | |||
| else: | |||
| in_channels = out_c | |||
| out_channels = out_c | |||
| this_stride = 1 | |||
| force_resproj = False | |||
| this_kernel_size = kernel_size | |||
| the_block = ResConvK1KX( | |||
| in_channels, | |||
| out_channels, | |||
| btn_c, | |||
| this_kernel_size, | |||
| this_stride, | |||
| force_resproj, | |||
| act=act, | |||
| reparam=reparam) | |||
| self.block_list.append(the_block) | |||
| if block_id == 0 and with_spp: | |||
| self.block_list.append( | |||
| SPPBottleneck(out_channels, out_channels)) | |||
| def forward(self, x): | |||
| output = x | |||
| for block in self.block_list: | |||
| output = block(output) | |||
| return output | |||
| class ResConvKXKX(nn.Module): | |||
| def __init__(self, | |||
| in_c, | |||
| out_c, | |||
| btn_c, | |||
| kernel_size, | |||
| stride, | |||
| force_resproj=False, | |||
| act='silu'): | |||
| super(ResConvKXKX, self).__init__() | |||
| self.stride = stride | |||
| if self.stride == 2: | |||
| self.downsampler = ConvKXBNRELU(in_c, out_c, 3, 2, act=act) | |||
| else: | |||
| self.conv1 = ConvKXBN(in_c, btn_c, kernel_size, 1) | |||
| self.conv2 = RepVggBlock( | |||
| btn_c, out_c, kernel_size, stride, act='identity') | |||
| if act is None: | |||
| self.activation_function = torch.relu | |||
| else: | |||
| self.activation_function = get_activation(act) | |||
| if stride == 2: | |||
| self.residual_downsample = nn.AvgPool2d( | |||
| kernel_size=2, stride=2) | |||
| else: | |||
| self.residual_downsample = nn.Identity() | |||
| if in_c != out_c or force_resproj: | |||
| self.residual_proj = ConvKXBN(in_c, out_c, 1, 1) | |||
| else: | |||
| self.residual_proj = nn.Identity() | |||
| def forward(self, x): | |||
| if self.stride == 2: | |||
| return self.downsampler(x) | |||
| reslink = self.residual_downsample(x) | |||
| reslink = self.residual_proj(reslink) | |||
| output = x | |||
| output = self.conv1(output) | |||
| output = self.activation_function(output) | |||
| output = self.conv2(output) | |||
| output = output + reslink | |||
| output = self.activation_function(output) | |||
| return output | |||
| class SuperResConvKXKX(nn.Module): | |||
| def __init__(self, | |||
| in_c, | |||
| out_c, | |||
| btn_c, | |||
| kernel_size, | |||
| stride, | |||
| num_blocks, | |||
| with_spp=False, | |||
| act='silu'): | |||
| super(SuperResConvKXKX, self).__init__() | |||
| if act is None: | |||
| self.act = torch.relu | |||
| else: | |||
| self.act = get_activation(act) | |||
| self.block_list = nn.ModuleList() | |||
| for block_id in range(num_blocks): | |||
| if block_id == 0: | |||
| in_channels = in_c | |||
| out_channels = out_c | |||
| this_stride = stride | |||
| force_resproj = False # not needed when used inside a CSPLayer | |||
| this_kernel_size = kernel_size | |||
| else: | |||
| in_channels = out_c | |||
| out_channels = out_c | |||
| this_stride = 1 | |||
| force_resproj = False | |||
| this_kernel_size = kernel_size | |||
| the_block = ResConvKXKX( | |||
| in_channels, | |||
| out_channels, | |||
| btn_c, | |||
| this_kernel_size, | |||
| this_stride, | |||
| force_resproj, | |||
| act=act) | |||
| self.block_list.append(the_block) | |||
| if block_id == 0 and with_spp: | |||
| self.block_list.append( | |||
| SPPBottleneck(out_channels, out_channels)) | |||
| def forward(self, x): | |||
| output = x | |||
| for block in self.block_list: | |||
| output = block(output) | |||
| return output | |||
| class TinyNAS(nn.Module): | |||
| def __init__(self, | |||
| structure_info=None, | |||
| out_indices=[0, 1, 2, 4, 5], | |||
| out_channels=[None, None, 128, 256, 512], | |||
| with_spp=False, | |||
| use_focus=False, | |||
| need_conv1=True, | |||
| act='silu', | |||
| reparam=False): | |||
| super(TinyNAS, self).__init__() | |||
| assert len(out_indices) == len(out_channels) | |||
| self.out_indices = out_indices | |||
| self.need_conv1 = need_conv1 | |||
| self.block_list = nn.ModuleList() | |||
| if need_conv1: | |||
| self.conv1_list = nn.ModuleList() | |||
| for idx, block_info in enumerate(structure_info): | |||
| the_block_class = block_info['class'] | |||
| if the_block_class == 'ConvKXBNRELU': | |||
| if use_focus: | |||
| the_block = Focus( | |||
| block_info['in'], | |||
| block_info['out'], | |||
| block_info['k'], | |||
| act=act) | |||
| else: | |||
| the_block = ConvKXBNRELU( | |||
| block_info['in'], | |||
| block_info['out'], | |||
| block_info['k'], | |||
| block_info['s'], | |||
| act=act) | |||
| self.block_list.append(the_block) | |||
| elif the_block_class == 'SuperResConvK1KX': | |||
| spp = with_spp if idx == len(structure_info) - 1 else False | |||
| the_block = SuperResConvK1KX( | |||
| block_info['in'], | |||
| block_info['out'], | |||
| block_info['btn'], | |||
| block_info['k'], | |||
| block_info['s'], | |||
| block_info['L'], | |||
| spp, | |||
| act=act, | |||
| reparam=reparam) | |||
| self.block_list.append(the_block) | |||
| elif the_block_class == 'SuperResConvKXKX': | |||
| spp = with_spp if idx == len(structure_info) - 1 else False | |||
| the_block = SuperResConvKXKX( | |||
| block_info['in'], | |||
| block_info['out'], | |||
| block_info['btn'], | |||
| block_info['k'], | |||
| block_info['s'], | |||
| block_info['L'], | |||
| spp, | |||
| act=act) | |||
| self.block_list.append(the_block) | |||
| if need_conv1: | |||
| if idx in self.out_indices and out_channels[ | |||
| self.out_indices.index(idx)] is not None: | |||
| self.conv1_list.append( | |||
| nn.Conv2d(block_info['out'], | |||
| out_channels[self.out_indices.index(idx)], | |||
| 1)) | |||
| else: | |||
| self.conv1_list.append(None) | |||
| def init_weights(self, pretrain=None): | |||
| pass | |||
| def forward(self, x): | |||
| output = x | |||
| stage_feature_list = [] | |||
| for idx, block in enumerate(self.block_list): | |||
| output = block(output) | |||
| if idx in self.out_indices: | |||
| if self.need_conv1 and self.conv1_list[idx] is not None: | |||
| true_out = self.conv1_list[idx](output) | |||
| stage_feature_list.append(true_out) | |||
| else: | |||
| stage_feature_list.append(output) | |||
| return stage_feature_list | |||
| def load_tinynas_net(backbone_cfg): | |||
| # parse the TinyNAS structure file into a list of block configs | |||
| import ast | |||
| net_structure_str = read_file(backbone_cfg.structure_file) | |||
| struct_str = ''.join([x.strip() for x in net_structure_str]) | |||
| struct_info = ast.literal_eval(struct_str) | |||
| for layer in struct_info: | |||
| if 'nbitsA' in layer: | |||
| del layer['nbitsA'] | |||
| if 'nbitsW' in layer: | |||
| del layer['nbitsW'] | |||
| model = TinyNAS( | |||
| structure_info=struct_info, | |||
| out_indices=backbone_cfg.out_indices, | |||
| out_channels=backbone_cfg.out_channels, | |||
| with_spp=backbone_cfg.with_spp, | |||
| use_focus=backbone_cfg.use_focus, | |||
| act=backbone_cfg.act, | |||
| need_conv1=backbone_cfg.need_conv1, | |||
| reparam=backbone_cfg.reparam) | |||
| return model | |||
| @@ -0,0 +1,295 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # The DAMO-YOLO implementation is also open-sourced by the authors, and available | |||
| # at https://github.com/tinyvision/damo-yolo. | |||
| import torch | |||
| import torch.nn as nn | |||
| from modelscope.models.cv.tinynas_detection.core.ops import (Focus, RepConv, | |||
| SPPBottleneck, | |||
| get_activation) | |||
| from modelscope.utils.file_utils import read_file | |||
| class ConvKXBN(nn.Module): | |||
| def __init__(self, in_c, out_c, kernel_size, stride): | |||
| super(ConvKXBN, self).__init__() | |||
| self.conv1 = nn.Conv2d( | |||
| in_c, | |||
| out_c, | |||
| kernel_size, | |||
| stride, (kernel_size - 1) // 2, | |||
| groups=1, | |||
| bias=False) | |||
| self.bn1 = nn.BatchNorm2d(out_c) | |||
| def forward(self, x): | |||
| return self.bn1(self.conv1(x)) | |||
| class ConvKXBNRELU(nn.Module): | |||
| def __init__(self, in_c, out_c, kernel_size, stride, act='silu'): | |||
| super(ConvKXBNRELU, self).__init__() | |||
| self.conv = ConvKXBN(in_c, out_c, kernel_size, stride) | |||
| if act is None: | |||
| self.activation_function = torch.relu | |||
| else: | |||
| self.activation_function = get_activation(act) | |||
| def forward(self, x): | |||
| output = self.conv(x) | |||
| return self.activation_function(output) | |||
| class ResConvBlock(nn.Module): | |||
| def __init__(self, | |||
| in_c, | |||
| out_c, | |||
| btn_c, | |||
| kernel_size, | |||
| stride, | |||
| act='silu', | |||
| reparam=False, | |||
| block_type='k1kx'): | |||
| super(ResConvBlock, self).__init__() | |||
| self.stride = stride | |||
| if block_type == 'k1kx': | |||
| self.conv1 = ConvKXBN(in_c, btn_c, kernel_size=1, stride=1) | |||
| else: | |||
| self.conv1 = ConvKXBN( | |||
| in_c, btn_c, kernel_size=kernel_size, stride=1) | |||
| if not reparam: | |||
| self.conv2 = ConvKXBN(btn_c, out_c, kernel_size, stride) | |||
| else: | |||
| self.conv2 = RepConv( | |||
| btn_c, out_c, kernel_size, stride, act='identity') | |||
| self.activation_function = get_activation(act) | |||
| if in_c != out_c and stride != 2: | |||
| self.residual_proj = ConvKXBN(in_c, out_c, kernel_size=1, stride=1) | |||
| else: | |||
| self.residual_proj = None | |||
| def forward(self, x): | |||
| if self.residual_proj is not None: | |||
| reslink = self.residual_proj(x) | |||
| else: | |||
| reslink = x | |||
| x = self.conv1(x) | |||
| x = self.activation_function(x) | |||
| x = self.conv2(x) | |||
| if self.stride != 2: | |||
| x = x + reslink | |||
| x = self.activation_function(x) | |||
| return x | |||
| class CSPStem(nn.Module): | |||
| def __init__(self, | |||
| in_c, | |||
| out_c, | |||
| btn_c, | |||
| stride, | |||
| kernel_size, | |||
| num_blocks, | |||
| act='silu', | |||
| reparam=False, | |||
| block_type='k1kx'): | |||
| super(CSPStem, self).__init__() | |||
| self.in_channels = in_c | |||
| self.out_channels = out_c | |||
| self.stride = stride | |||
| if self.stride == 2: | |||
| self.num_blocks = num_blocks - 1 | |||
| else: | |||
| self.num_blocks = num_blocks | |||
| self.kernel_size = kernel_size | |||
| self.act = act | |||
| self.block_type = block_type | |||
| out_c = out_c // 2 | |||
| if act is None: | |||
| self.act = torch.relu | |||
| else: | |||
| self.act = get_activation(act) | |||
| self.block_list = nn.ModuleList() | |||
| for block_id in range(self.num_blocks): | |||
| if self.stride == 1 and block_id == 0: | |||
| in_c = in_c // 2 | |||
| else: | |||
| in_c = out_c | |||
| the_block = ResConvBlock( | |||
| in_c, | |||
| out_c, | |||
| btn_c, | |||
| kernel_size, | |||
| stride=1, | |||
| act=act, | |||
| reparam=reparam, | |||
| block_type=block_type) | |||
| self.block_list.append(the_block) | |||
| def forward(self, x): | |||
| output = x | |||
| for block in self.block_list: | |||
| output = block(output) | |||
| return output | |||
| class TinyNAS(nn.Module): | |||
| def __init__(self, | |||
| structure_info=None, | |||
| out_indices=[2, 3, 4], | |||
| with_spp=False, | |||
| use_focus=False, | |||
| act='silu', | |||
| reparam=False): | |||
| super(TinyNAS, self).__init__() | |||
| self.out_indices = out_indices | |||
| self.block_list = nn.ModuleList() | |||
| self.stride_list = [] | |||
| for idx, block_info in enumerate(structure_info): | |||
| the_block_class = block_info['class'] | |||
| if the_block_class == 'ConvKXBNRELU': | |||
| if use_focus and idx == 0: | |||
| the_block = Focus( | |||
| block_info['in'], | |||
| block_info['out'], | |||
| block_info['k'], | |||
| act=act) | |||
| else: | |||
| the_block = ConvKXBNRELU( | |||
| block_info['in'], | |||
| block_info['out'], | |||
| block_info['k'], | |||
| block_info['s'], | |||
| act=act) | |||
| elif the_block_class == 'SuperResConvK1KX': | |||
| the_block = CSPStem( | |||
| block_info['in'], | |||
| block_info['out'], | |||
| block_info['btn'], | |||
| block_info['s'], | |||
| block_info['k'], | |||
| block_info['L'], | |||
| act=act, | |||
| reparam=reparam, | |||
| block_type='k1kx') | |||
| elif the_block_class == 'SuperResConvKXKX': | |||
| the_block = CSPStem( | |||
| block_info['in'], | |||
| block_info['out'], | |||
| block_info['btn'], | |||
| block_info['s'], | |||
| block_info['k'], | |||
| block_info['L'], | |||
| act=act, | |||
| reparam=reparam, | |||
| block_type='kxkx') | |||
| else: | |||
| raise NotImplementedError | |||
| self.block_list.append(the_block) | |||
| self.csp_stage = nn.ModuleList() | |||
| self.csp_stage.append(self.block_list[0]) | |||
| self.csp_stage.append(CSPWrapper(self.block_list[1])) | |||
| self.csp_stage.append(CSPWrapper(self.block_list[2])) | |||
| self.csp_stage.append( | |||
| CSPWrapper((self.block_list[3], self.block_list[4]))) | |||
| self.csp_stage.append( | |||
| CSPWrapper(self.block_list[5], with_spp=with_spp)) | |||
| del self.block_list | |||
| def init_weights(self, pretrain=None): | |||
| pass | |||
| def forward(self, x): | |||
| output = x | |||
| stage_feature_list = [] | |||
| for idx, block in enumerate(self.csp_stage): | |||
| output = block(output) | |||
| if idx in self.out_indices: | |||
| stage_feature_list.append(output) | |||
| return stage_feature_list | |||
| class CSPWrapper(nn.Module): | |||
| def __init__(self, convstem, act='relu', reparam=False, with_spp=False): | |||
| super(CSPWrapper, self).__init__() | |||
| self.with_spp = with_spp | |||
| if isinstance(convstem, tuple): | |||
| in_c = convstem[0].in_channels | |||
| out_c = convstem[-1].out_channels | |||
| hidden_dim = convstem[0].out_channels // 2 | |||
| _convstem = nn.ModuleList() | |||
| for modulelist in convstem: | |||
| for layer in modulelist.block_list: | |||
| _convstem.append(layer) | |||
| else: | |||
| in_c = convstem.in_channels | |||
| out_c = convstem.out_channels | |||
| hidden_dim = out_c // 2 | |||
| _convstem = convstem.block_list | |||
| self.convstem = nn.ModuleList() | |||
| for layer in _convstem: | |||
| self.convstem.append(layer) | |||
| self.act = get_activation(act) | |||
| self.downsampler = ConvKXBNRELU( | |||
| in_c, hidden_dim * 2, 3, 2, act=self.act) | |||
| if self.with_spp: | |||
| self.spp = SPPBottleneck(hidden_dim * 2, hidden_dim * 2) | |||
| if len(self.convstem) > 0: | |||
| self.conv_start = ConvKXBNRELU( | |||
| hidden_dim * 2, hidden_dim, 1, 1, act=self.act) | |||
| self.conv_shortcut = ConvKXBNRELU( | |||
| hidden_dim * 2, out_c // 2, 1, 1, act=self.act) | |||
| self.conv_fuse = ConvKXBNRELU(out_c, out_c, 1, 1, act=self.act) | |||
| def forward(self, x): | |||
| x = self.downsampler(x) | |||
| if self.with_spp: | |||
| x = self.spp(x) | |||
| if len(self.convstem) > 0: | |||
| shortcut = self.conv_shortcut(x) | |||
| x = self.conv_start(x) | |||
| for block in self.convstem: | |||
| x = block(x) | |||
| x = torch.cat((x, shortcut), dim=1) | |||
| x = self.conv_fuse(x) | |||
| return x | |||
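A quick shape sketch of how CSPWrapper composes a CSPStem: downsample to 2x hidden width, split into a stem path and a 1x1 shortcut path, then concatenate and fuse. This assumes the CSPStem and CSPWrapper classes defined above are in scope:

```python
import torch

stem = CSPStem(in_c=64, out_c=128, btn_c=64, stride=2, kernel_size=3,
               num_blocks=2, act='silu')
wrapper = CSPWrapper(stem)  # downsample -> split -> stem path + shortcut -> fuse
out = wrapper(torch.randn(1, 64, 32, 32))
print(out.shape)  # torch.Size([1, 128, 16, 16]): a stride-2 stage, 128 channels out
```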
| def load_tinynas_net(backbone_cfg): | |||
| # parse the TinyNAS structure file into a list of block configs | |||
| import ast | |||
| net_structure_str = read_file(backbone_cfg.structure_file) | |||
| struct_str = ''.join([x.strip() for x in net_structure_str]) | |||
| struct_info = ast.literal_eval(struct_str) | |||
| for layer in struct_info: | |||
| if 'nbitsA' in layer: | |||
| del layer['nbitsA'] | |||
| if 'nbitsW' in layer: | |||
| del layer['nbitsW'] | |||
| model = TinyNAS( | |||
| structure_info=struct_info, | |||
| out_indices=backbone_cfg.out_indices, | |||
| with_spp=backbone_cfg.with_spp, | |||
| use_focus=backbone_cfg.use_focus, | |||
| act=backbone_cfg.act, | |||
| reparam=backbone_cfg.reparam) | |||
| return model | |||
| @@ -0,0 +1,238 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # The DAMO-YOLO implementation is also open-sourced by the authors, and available | |||
| # at https://github.com/tinyvision/damo-yolo. | |||
| import torch | |||
| import torch.nn as nn | |||
| from modelscope.models.cv.tinynas_detection.core.ops import (Focus, RepConv, | |||
| SPPBottleneck, | |||
| get_activation) | |||
| from modelscope.utils.file_utils import read_file | |||
| class ConvKXBN(nn.Module): | |||
| def __init__(self, in_c, out_c, kernel_size, stride): | |||
| super(ConvKXBN, self).__init__() | |||
| self.conv1 = nn.Conv2d( | |||
| in_c, | |||
| out_c, | |||
| kernel_size, | |||
| stride, (kernel_size - 1) // 2, | |||
| groups=1, | |||
| bias=False) | |||
| self.bn1 = nn.BatchNorm2d(out_c) | |||
| def forward(self, x): | |||
| return self.bn1(self.conv1(x)) | |||
| class ConvKXBNRELU(nn.Module): | |||
| def __init__(self, in_c, out_c, kernel_size, stride, act='silu'): | |||
| super(ConvKXBNRELU, self).__init__() | |||
| self.conv = ConvKXBN(in_c, out_c, kernel_size, stride) | |||
| if act is None: | |||
| self.activation_function = torch.relu | |||
| else: | |||
| self.activation_function = get_activation(act) | |||
| def forward(self, x): | |||
| output = self.conv(x) | |||
| return self.activation_function(output) | |||
| class ResConvBlock(nn.Module): | |||
| def __init__(self, | |||
| in_c, | |||
| out_c, | |||
| btn_c, | |||
| kernel_size, | |||
| stride, | |||
| act='silu', | |||
| reparam=False, | |||
| block_type='k1kx'): | |||
| super(ResConvBlock, self).__init__() | |||
| self.stride = stride | |||
| if block_type == 'k1kx': | |||
| self.conv1 = ConvKXBN(in_c, btn_c, kernel_size=1, stride=1) | |||
| else: | |||
| self.conv1 = ConvKXBN( | |||
| in_c, btn_c, kernel_size=kernel_size, stride=1) | |||
| if not reparam: | |||
| self.conv2 = ConvKXBN(btn_c, out_c, kernel_size, stride) | |||
| else: | |||
| self.conv2 = RepConv( | |||
| btn_c, out_c, kernel_size, stride, act='identity') | |||
| self.activation_function = get_activation(act) | |||
| if in_c != out_c and stride != 2: | |||
| self.residual_proj = ConvKXBN(in_c, out_c, 1, 1) | |||
| else: | |||
| self.residual_proj = None | |||
| def forward(self, x): | |||
| if self.residual_proj is not None: | |||
| reslink = self.residual_proj(x) | |||
| else: | |||
| reslink = x | |||
| x = self.conv1(x) | |||
| x = self.activation_function(x) | |||
| x = self.conv2(x) | |||
| if self.stride != 2: | |||
| x = x + reslink | |||
| x = self.activation_function(x) | |||
| return x | |||
| class SuperResStem(nn.Module): | |||
| def __init__(self, | |||
| in_c, | |||
| out_c, | |||
| btn_c, | |||
| kernel_size, | |||
| stride, | |||
| num_blocks, | |||
| with_spp=False, | |||
| act='silu', | |||
| reparam=False, | |||
| block_type='k1kx'): | |||
| super(SuperResStem, self).__init__() | |||
| if act is None: | |||
| self.act = torch.relu | |||
| else: | |||
| self.act = get_activation(act) | |||
| self.block_list = nn.ModuleList() | |||
| for block_id in range(num_blocks): | |||
| if block_id == 0: | |||
| in_channels = in_c | |||
| out_channels = out_c | |||
| this_stride = stride | |||
| this_kernel_size = kernel_size | |||
| else: | |||
| in_channels = out_c | |||
| out_channels = out_c | |||
| this_stride = 1 | |||
| this_kernel_size = kernel_size | |||
| the_block = ResConvBlock( | |||
| in_channels, | |||
| out_channels, | |||
| btn_c, | |||
| this_kernel_size, | |||
| this_stride, | |||
| act=act, | |||
| reparam=reparam, | |||
| block_type=block_type) | |||
| self.block_list.append(the_block) | |||
| if block_id == 0 and with_spp: | |||
| self.block_list.append( | |||
| SPPBottleneck(out_channels, out_channels)) | |||
| def forward(self, x): | |||
| output = x | |||
| for block in self.block_list: | |||
| output = block(output) | |||
| return output | |||
| class TinyNAS(nn.Module): | |||
| def __init__(self, | |||
| structure_info=None, | |||
| out_indices=[2, 4, 5], | |||
| with_spp=False, | |||
| use_focus=False, | |||
| act='silu', | |||
| reparam=False): | |||
| super(TinyNAS, self).__init__() | |||
| self.out_indices = out_indices | |||
| self.block_list = nn.ModuleList() | |||
| for idx, block_info in enumerate(structure_info): | |||
| the_block_class = block_info['class'] | |||
| if the_block_class == 'ConvKXBNRELU': | |||
| if use_focus: | |||
| the_block = Focus( | |||
| block_info['in'], | |||
| block_info['out'], | |||
| block_info['k'], | |||
| act=act) | |||
| else: | |||
| the_block = ConvKXBNRELU( | |||
| block_info['in'], | |||
| block_info['out'], | |||
| block_info['k'], | |||
| block_info['s'], | |||
| act=act) | |||
| self.block_list.append(the_block) | |||
| elif the_block_class == 'SuperResConvK1KX': | |||
| spp = with_spp if idx == len(structure_info) - 1 else False | |||
| the_block = SuperResStem( | |||
| block_info['in'], | |||
| block_info['out'], | |||
| block_info['btn'], | |||
| block_info['k'], | |||
| block_info['s'], | |||
| block_info['L'], | |||
| spp, | |||
| act=act, | |||
| reparam=reparam, | |||
| block_type='k1kx') | |||
| self.block_list.append(the_block) | |||
| elif the_block_class == 'SuperResConvKXKX': | |||
| spp = with_spp if idx == len(structure_info) - 1 else False | |||
| the_block = SuperResStem( | |||
| block_info['in'], | |||
| block_info['out'], | |||
| block_info['btn'], | |||
| block_info['k'], | |||
| block_info['s'], | |||
| block_info['L'], | |||
| spp, | |||
| act=act, | |||
| reparam=reparam, | |||
| block_type='kxkx') | |||
| self.block_list.append(the_block) | |||
| else: | |||
| raise NotImplementedError | |||
| def init_weights(self, pretrain=None): | |||
| pass | |||
| def forward(self, x): | |||
| output = x | |||
| stage_feature_list = [] | |||
| for idx, block in enumerate(self.block_list): | |||
| output = block(output) | |||
| if idx in self.out_indices: | |||
| stage_feature_list.append(output) | |||
| return stage_feature_list | |||
| def load_tinynas_net(backbone_cfg): | |||
| # parse the TinyNAS structure file into a list of block configs | |||
| import ast | |||
| net_structure_str = read_file(backbone_cfg.structure_file) | |||
| struct_str = ''.join([x.strip() for x in net_structure_str]) | |||
| struct_info = ast.literal_eval(struct_str) | |||
| for layer in struct_info: | |||
| if 'nbitsA' in layer: | |||
| del layer['nbitsA'] | |||
| if 'nbitsW' in layer: | |||
| del layer['nbitsW'] | |||
| model = TinyNAS( | |||
| structure_info=struct_info, | |||
| out_indices=backbone_cfg.out_indices, | |||
| with_spp=backbone_cfg.with_spp, | |||
| use_focus=backbone_cfg.use_focus, | |||
| act=backbone_cfg.act, | |||
| reparam=backbone_cfg.reparam) | |||
| return model | |||
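All three load_tinynas_net variants parse the same structure-file format: a Python-literal list of per-block dicts, read with ast.literal_eval, with the quantization fields nbitsA/nbitsW stripped before model construction. A toy illustration of that format (field values are made up; the real file ships with the model):

```python
import ast

net_structure_str = """[
    {'class': 'ConvKXBNRELU', 'in': 3, 'out': 32, 'k': 3, 's': 2},
    {'class': 'SuperResConvK1KX', 'in': 32, 'out': 64, 'btn': 32,
     'k': 3, 's': 2, 'L': 2, 'nbitsA': 8, 'nbitsW': 8},
]"""
struct_str = ''.join(line.strip() for line in net_structure_str.splitlines())
struct_info = ast.literal_eval(struct_str)
for layer in struct_info:
    layer.pop('nbitsA', None)  # quantization bit-widths are not used here
    layer.pop('nbitsW', None)
print([layer['class'] for layer in struct_info])
# ['ConvKXBNRELU', 'SuperResConvK1KX']
```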
| @@ -1,2 +1,2 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. | |||
| # The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo. | |||
| @@ -1,5 +1,5 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. | |||
| # The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo. | |||
| import math | |||
| import torch | |||
| @@ -1,5 +1,5 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. | |||
| # The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo. | |||
| import numpy as np | |||
| import torch | |||
| @@ -0,0 +1,435 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo. | |||
| import numpy as np | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| class SiLU(nn.Module): | |||
| """export-friendly version of nn.SiLU()""" | |||
| @staticmethod | |||
| def forward(x): | |||
| return x * torch.sigmoid(x) | |||
| class Swish(nn.Module): | |||
| def __init__(self, inplace=True): | |||
| super(Swish, self).__init__() | |||
| self.inplace = inplace | |||
| def forward(self, x): | |||
| if self.inplace: | |||
| x.mul_(F.sigmoid(x)) | |||
| return x | |||
| else: | |||
| return x * F.sigmoid(x) | |||
| def get_activation(name='silu', inplace=True): | |||
| if name is None: | |||
| return nn.Identity() | |||
| if isinstance(name, str): | |||
| if name == 'silu': | |||
| module = nn.SiLU(inplace=inplace) | |||
| elif name == 'relu': | |||
| module = nn.ReLU(inplace=inplace) | |||
| elif name == 'lrelu': | |||
| module = nn.LeakyReLU(0.1, inplace=inplace) | |||
| elif name == 'swish': | |||
| module = Swish(inplace=inplace) | |||
| elif name == 'hardsigmoid': | |||
| module = nn.Hardsigmoid(inplace=inplace) | |||
| elif name == 'identity': | |||
| module = nn.Identity() | |||
| else: | |||
| raise AttributeError('Unsupported act type: {}'.format(name)) | |||
| return module | |||
| elif isinstance(name, nn.Module): | |||
| return name | |||
| else: | |||
| raise AttributeError('Unsupported act type: {}'.format(name)) | |||
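get_activation accepts a name string, None, or an already-constructed module; the module pass-through is what lets blocks hand their own self.act (an nn.Module) down to sub-layers:

```python
import torch.nn as nn

assert isinstance(get_activation('silu'), nn.SiLU)
assert isinstance(get_activation(None), nn.Identity)
act = nn.ReLU()
assert get_activation(act) is act  # modules pass through unchanged
```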
| def get_norm(name, out_channels, inplace=True): | |||
| if name == 'bn': | |||
| module = nn.BatchNorm2d(out_channels) | |||
| else: | |||
| raise NotImplementedError | |||
| return module | |||
| class ConvBNAct(nn.Module): | |||
| """A Conv2d -> Batchnorm -> silu/leaky relu block""" | |||
| def __init__( | |||
| self, | |||
| in_channels, | |||
| out_channels, | |||
| ksize, | |||
| stride=1, | |||
| groups=1, | |||
| bias=False, | |||
| act='silu', | |||
| norm='bn', | |||
| reparam=False, | |||
| ): | |||
| super().__init__() | |||
| # same padding | |||
| pad = (ksize - 1) // 2 | |||
| self.conv = nn.Conv2d( | |||
| in_channels, | |||
| out_channels, | |||
| kernel_size=ksize, | |||
| stride=stride, | |||
| padding=pad, | |||
| groups=groups, | |||
| bias=bias, | |||
| ) | |||
| if norm is not None: | |||
| self.bn = get_norm(norm, out_channels, inplace=True) | |||
| if act is not None: | |||
| self.act = get_activation(act, inplace=True) | |||
| self.with_norm = norm is not None | |||
| self.with_act = act is not None | |||
| def forward(self, x): | |||
| x = self.conv(x) | |||
| if self.with_norm: | |||
| x = self.bn(x) | |||
| if self.with_act: | |||
| x = self.act(x) | |||
| return x | |||
| def fuseforward(self, x): | |||
| return self.act(self.conv(x)) | |||
| class SPPBottleneck(nn.Module): | |||
| """Spatial pyramid pooling layer used in YOLOv3-SPP""" | |||
| def __init__(self, | |||
| in_channels, | |||
| out_channels, | |||
| kernel_sizes=(5, 9, 13), | |||
| activation='silu'): | |||
| super().__init__() | |||
| hidden_channels = in_channels // 2 | |||
| self.conv1 = ConvBNAct( | |||
| in_channels, hidden_channels, 1, stride=1, act=activation) | |||
| self.m = nn.ModuleList([ | |||
| nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2) | |||
| for ks in kernel_sizes | |||
| ]) | |||
| conv2_channels = hidden_channels * (len(kernel_sizes) + 1) | |||
| self.conv2 = ConvBNAct( | |||
| conv2_channels, out_channels, 1, stride=1, act=activation) | |||
| def forward(self, x): | |||
| x = self.conv1(x) | |||
| x = torch.cat([x] + [m(x) for m in self.m], dim=1) | |||
| x = self.conv2(x) | |||
| return x | |||
| class Focus(nn.Module): | |||
| """Focus width and height information into channel space.""" | |||
| def __init__(self, | |||
| in_channels, | |||
| out_channels, | |||
| ksize=1, | |||
| stride=1, | |||
| act='silu'): | |||
| super().__init__() | |||
| self.conv = ConvBNAct( | |||
| in_channels * 4, out_channels, ksize, stride, act=act) | |||
| def forward(self, x): | |||
| # shape of x (b,c,w,h) -> y(b,4c,w/2,h/2) | |||
| patch_top_left = x[..., ::2, ::2] | |||
| patch_top_right = x[..., ::2, 1::2] | |||
| patch_bot_left = x[..., 1::2, ::2] | |||
| patch_bot_right = x[..., 1::2, 1::2] | |||
| x = torch.cat( | |||
| ( | |||
| patch_top_left, | |||
| patch_bot_left, | |||
| patch_top_right, | |||
| patch_bot_right, | |||
| ), | |||
| dim=1, | |||
| ) | |||
| return self.conv(x) | |||
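Focus is a space-to-depth stem: every 2x2 spatial patch is moved into channels before the conv, so (B, C, H, W) becomes (B, 4C, H/2, W/2). A standalone shape check of the rearrangement, mirroring the slicing order in forward:

```python
import torch

x = torch.randn(1, 3, 8, 8)
patches = torch.cat(
    (x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]),
    dim=1)
print(patches.shape)  # torch.Size([1, 12, 4, 4]), fed to a 4C-input conv
```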
| class BasicBlock_3x3_Reverse(nn.Module): | |||
| def __init__(self, | |||
| ch_in, | |||
| ch_hidden_ratio, | |||
| ch_out, | |||
| act='relu', | |||
| shortcut=True): | |||
| super(BasicBlock_3x3_Reverse, self).__init__() | |||
| assert ch_in == ch_out | |||
| ch_hidden = int(ch_in * ch_hidden_ratio) | |||
| self.conv1 = ConvBNAct(ch_hidden, ch_out, 3, stride=1, act=act) | |||
| self.conv2 = RepConv(ch_in, ch_hidden, 3, stride=1, act=act) | |||
| self.shortcut = shortcut | |||
| def forward(self, x): | |||
| y = self.conv2(x) | |||
| y = self.conv1(y) | |||
| if self.shortcut: | |||
| return x + y | |||
| else: | |||
| return y | |||
| class SPP(nn.Module): | |||
| def __init__( | |||
| self, | |||
| ch_in, | |||
| ch_out, | |||
| k, | |||
| pool_size, | |||
| act='swish', | |||
| ): | |||
| super(SPP, self).__init__() | |||
| self.pool = [] | |||
| for i, size in enumerate(pool_size): | |||
| pool = nn.MaxPool2d( | |||
| kernel_size=size, stride=1, padding=size // 2, ceil_mode=False) | |||
| self.add_module('pool{}'.format(i), pool) | |||
| self.pool.append(pool) | |||
| self.conv = ConvBNAct(ch_in, ch_out, k, act=act) | |||
| def forward(self, x): | |||
| outs = [x] | |||
| for pool in self.pool: | |||
| outs.append(pool(x)) | |||
| y = torch.cat(outs, axis=1) | |||
| y = self.conv(y) | |||
| return y | |||
| class CSPStage(nn.Module): | |||
| def __init__(self, | |||
| block_fn, | |||
| ch_in, | |||
| ch_hidden_ratio, | |||
| ch_out, | |||
| n, | |||
| act='swish', | |||
| spp=False): | |||
| super(CSPStage, self).__init__() | |||
| split_ratio = 2 | |||
| ch_first = int(ch_out // split_ratio) | |||
| ch_mid = int(ch_out - ch_first) | |||
| self.conv1 = ConvBNAct(ch_in, ch_first, 1, act=act) | |||
| self.conv2 = ConvBNAct(ch_in, ch_mid, 1, act=act) | |||
| self.convs = nn.Sequential() | |||
| next_ch_in = ch_mid | |||
| for i in range(n): | |||
| if block_fn == 'BasicBlock_3x3_Reverse': | |||
| self.convs.add_module( | |||
| str(i), | |||
| BasicBlock_3x3_Reverse( | |||
| next_ch_in, | |||
| ch_hidden_ratio, | |||
| ch_mid, | |||
| act=act, | |||
| shortcut=True)) | |||
| else: | |||
| raise NotImplementedError | |||
| if i == (n - 1) // 2 and spp: | |||
| self.convs.add_module( | |||
| 'spp', SPP(ch_mid * 4, ch_mid, 1, [5, 9, 13], act=act)) | |||
| next_ch_in = ch_mid | |||
| self.conv3 = ConvBNAct(ch_mid * n + ch_first, ch_out, 1, act=act) | |||
| def forward(self, x): | |||
| y1 = self.conv1(x) | |||
| y2 = self.conv2(x) | |||
| mid_out = [y1] | |||
| for conv in self.convs: | |||
| y2 = conv(y2) | |||
| mid_out.append(y2) | |||
| y = torch.cat(mid_out, axis=1) | |||
| y = self.conv3(y) | |||
| return y | |||
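Channel bookkeeping in CSPStage: conv1 and conv2 split ch_out in half, each of the n blocks keeps ch_mid channels, and conv3 fuses the concatenation of y1 plus every intermediate y2, hence its ch_mid * n + ch_first input width. A quick sanity check, assuming the diff is applied so the core.ops module path from the imports above resolves:

```python
import torch
from modelscope.models.cv.tinynas_detection.core.ops import CSPStage

stage = CSPStage('BasicBlock_3x3_Reverse', ch_in=64, ch_hidden_ratio=1.0,
                 ch_out=128, n=2, act='silu')
out = stage(torch.randn(1, 64, 16, 16))
print(out.shape)  # torch.Size([1, 128, 16, 16])
```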
| def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups=1): | |||
| '''Basic cell for rep-style block, including conv and bn''' | |||
| result = nn.Sequential() | |||
| result.add_module( | |||
| 'conv', | |||
| nn.Conv2d( | |||
| in_channels=in_channels, | |||
| out_channels=out_channels, | |||
| kernel_size=kernel_size, | |||
| stride=stride, | |||
| padding=padding, | |||
| groups=groups, | |||
| bias=False)) | |||
| result.add_module('bn', nn.BatchNorm2d(num_features=out_channels)) | |||
| return result | |||
| class RepConv(nn.Module): | |||
| '''RepConv is a basic rep-style block, including training and deploy status | |||
| Code is based on https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py | |||
| ''' | |||
| def __init__(self, | |||
| in_channels, | |||
| out_channels, | |||
| kernel_size=3, | |||
| stride=1, | |||
| padding=1, | |||
| dilation=1, | |||
| groups=1, | |||
| padding_mode='zeros', | |||
| deploy=False, | |||
| act='relu', | |||
| norm=None): | |||
| super(RepConv, self).__init__() | |||
| self.deploy = deploy | |||
| self.groups = groups | |||
| self.in_channels = in_channels | |||
| self.out_channels = out_channels | |||
| assert kernel_size == 3 | |||
| assert padding == 1 | |||
| padding_11 = padding - kernel_size // 2 | |||
| if isinstance(act, str): | |||
| self.nonlinearity = get_activation(act) | |||
| else: | |||
| self.nonlinearity = act | |||
| if deploy: | |||
| self.rbr_reparam = nn.Conv2d( | |||
| in_channels=in_channels, | |||
| out_channels=out_channels, | |||
| kernel_size=kernel_size, | |||
| stride=stride, | |||
| padding=padding, | |||
| dilation=dilation, | |||
| groups=groups, | |||
| bias=True, | |||
| padding_mode=padding_mode) | |||
| else: | |||
| self.rbr_identity = None | |||
| self.rbr_dense = conv_bn( | |||
| in_channels=in_channels, | |||
| out_channels=out_channels, | |||
| kernel_size=kernel_size, | |||
| stride=stride, | |||
| padding=padding, | |||
| groups=groups) | |||
| self.rbr_1x1 = conv_bn( | |||
| in_channels=in_channels, | |||
| out_channels=out_channels, | |||
| kernel_size=1, | |||
| stride=stride, | |||
| padding=padding_11, | |||
| groups=groups) | |||
| def forward(self, inputs): | |||
| '''Forward process''' | |||
| if hasattr(self, 'rbr_reparam'): | |||
| return self.nonlinearity(self.rbr_reparam(inputs)) | |||
| if self.rbr_identity is None: | |||
| id_out = 0 | |||
| else: | |||
| id_out = self.rbr_identity(inputs) | |||
| return self.nonlinearity( | |||
| self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out) | |||
| def get_equivalent_kernel_bias(self): | |||
| kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense) | |||
| kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1) | |||
| kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity) | |||
| return kernel3x3 + self._pad_1x1_to_3x3_tensor( | |||
| kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid | |||
| def _pad_1x1_to_3x3_tensor(self, kernel1x1): | |||
| if kernel1x1 is None: | |||
| return 0 | |||
| else: | |||
| return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1]) | |||
| def _fuse_bn_tensor(self, branch): | |||
| if branch is None: | |||
| return 0, 0 | |||
| if isinstance(branch, nn.Sequential): | |||
| kernel = branch.conv.weight | |||
| running_mean = branch.bn.running_mean | |||
| running_var = branch.bn.running_var | |||
| gamma = branch.bn.weight | |||
| beta = branch.bn.bias | |||
| eps = branch.bn.eps | |||
| else: | |||
| assert isinstance(branch, nn.BatchNorm2d) | |||
| if not hasattr(self, 'id_tensor'): | |||
| input_dim = self.in_channels // self.groups | |||
| kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), | |||
| dtype=np.float32) | |||
| for i in range(self.in_channels): | |||
| kernel_value[i, i % input_dim, 1, 1] = 1 | |||
| self.id_tensor = torch.from_numpy(kernel_value).to( | |||
| branch.weight.device) | |||
| kernel = self.id_tensor | |||
| running_mean = branch.running_mean | |||
| running_var = branch.running_var | |||
| gamma = branch.weight | |||
| beta = branch.bias | |||
| eps = branch.eps | |||
| std = (running_var + eps).sqrt() | |||
| t = (gamma / std).reshape(-1, 1, 1, 1) | |||
| return kernel * t, beta - running_mean * gamma / std | |||
| def switch_to_deploy(self): | |||
| if hasattr(self, 'rbr_reparam'): | |||
| return | |||
| kernel, bias = self.get_equivalent_kernel_bias() | |||
| self.rbr_reparam = nn.Conv2d( | |||
| in_channels=self.rbr_dense.conv.in_channels, | |||
| out_channels=self.rbr_dense.conv.out_channels, | |||
| kernel_size=self.rbr_dense.conv.kernel_size, | |||
| stride=self.rbr_dense.conv.stride, | |||
| padding=self.rbr_dense.conv.padding, | |||
| dilation=self.rbr_dense.conv.dilation, | |||
| groups=self.rbr_dense.conv.groups, | |||
| bias=True) | |||
| self.rbr_reparam.weight.data = kernel | |||
| self.rbr_reparam.bias.data = bias | |||
| for para in self.parameters(): | |||
| para.detach_() | |||
| self.__delattr__('rbr_dense') | |||
| self.__delattr__('rbr_1x1') | |||
| if hasattr(self, 'rbr_identity'): | |||
| self.__delattr__('rbr_identity') | |||
| if hasattr(self, 'id_tensor'): | |||
| self.__delattr__('id_tensor') | |||
| self.deploy = True | |||
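Because get_equivalent_kernel_bias folds the BN statistics into the fused kernel, the deploy-mode conv should reproduce the multi-branch output in eval mode. A small numerical check (module path as imported elsewhere in this diff):

```python
import torch
from modelscope.models.cv.tinynas_detection.core.ops import RepConv

m = RepConv(16, 32, kernel_size=3, stride=1, act='relu').eval()
x = torch.randn(2, 16, 8, 8)
with torch.no_grad():
    y_multi = m(x)           # dense 3x3 branch + 1x1 branch
    m.switch_to_deploy()
    y_fused = m(x)           # single re-parameterized 3x3 conv
print(torch.allclose(y_multi, y_fused, atol=1e-5))  # True
```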
| @@ -1,5 +1,5 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. | |||
| # The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo. | |||
| import numpy as np | |||
| import torch | |||
| @@ -1,5 +1,5 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. | |||
| # The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo. | |||
| import numpy as np | |||
| import torch | |||
| @@ -1,5 +1,5 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. | |||
| # The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo. | |||
| import os.path as osp | |||
| import pickle | |||
| @@ -42,7 +42,7 @@ class SingleStageDetector(TorchModel): | |||
| self.conf_thre = config.model.head.nms_conf_thre | |||
| self.nms_thre = config.model.head.nms_iou_thre | |||
| if self.cfg.model.backbone.name == 'TinyNAS': | |||
| if 'TinyNAS' in self.cfg.model.backbone.name: | |||
| self.cfg.model.backbone.structure_file = osp.join( | |||
| model_dir, self.cfg.model.backbone.structure_file) | |||
| self.backbone = build_backbone(self.cfg.model.backbone) | |||
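The exact-match on 'TinyNAS' would have skipped the structure-file path resolution for the newly added variants; the substring check covers them all while leaving other backbones untouched:

```python
# Every backbone name that needs structure_file resolution matches:
for name in ('TinyNAS', 'TinyNAS_csp', 'TinyNAS_res'):
    assert 'TinyNAS' in name
assert 'TinyNAS' not in 'CSPDarknet'  # non-TinyNAS backbones are unaffected
```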
| @@ -1,9 +1,10 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. | |||
| # The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo. | |||
| import copy | |||
| from .gfocal_v2_tiny import GFocalHead_Tiny | |||
| from .zero_head import ZeroHead | |||
| def build_head(cfg): | |||
| @@ -12,5 +13,7 @@ def build_head(cfg): | |||
| name = head_cfg.pop('name') | |||
| if name == 'GFocalV2': | |||
| return GFocalHead_Tiny(**head_cfg) | |||
| elif name == 'ZeroHead': | |||
| return ZeroHead(**head_cfg) | |||
| else: | |||
| raise NotImplementedError | |||
| @@ -1,5 +1,5 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. | |||
| # The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo. | |||
| import functools | |||
| from functools import partial | |||
| @@ -9,7 +9,8 @@ import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| from ..core.base_ops import BaseConv, DWConv | |||
| from modelscope.models.cv.tinynas_detection.core.base_ops import (BaseConv, | |||
| DWConv) | |||
| class Scale(nn.Module): | |||
| @@ -0,0 +1,288 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # The DAMO-YOLO implementation is also open-sourced by the authors, and available | |||
| # at https://github.com/tinyvision/damo-yolo. | |||
| from functools import partial | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| from modelscope.models.cv.tinynas_detection.core.ops import ConvBNAct | |||
| class Scale(nn.Module): | |||
| def __init__(self, scale=1.0): | |||
| super(Scale, self).__init__() | |||
| self.scale = nn.Parameter(torch.tensor(scale, dtype=torch.float)) | |||
| def forward(self, x): | |||
| return x * self.scale | |||
| def multi_apply(func, *args, **kwargs): | |||
| pfunc = partial(func, **kwargs) if kwargs else func | |||
| map_results = map(pfunc, *args) | |||
| return tuple(map(list, zip(*map_results))) | |||
| def distance2bbox(points, distance, max_shape=None): | |||
| """Decode distance prediction to bounding box. | |||
| """ | |||
| x1 = points[..., 0] - distance[..., 0] | |||
| y1 = points[..., 1] - distance[..., 1] | |||
| x2 = points[..., 0] + distance[..., 2] | |||
| y2 = points[..., 1] + distance[..., 3] | |||
| if max_shape is not None: | |||
| x1 = x1.clamp(min=0, max=max_shape[1]) | |||
| y1 = y1.clamp(min=0, max=max_shape[0]) | |||
| x2 = x2.clamp(min=0, max=max_shape[1]) | |||
| y2 = y2.clamp(min=0, max=max_shape[0]) | |||
| return torch.stack([x1, y1, x2, y2], -1) | |||
| def bbox2distance(points, bbox, max_dis=None, eps=0.1): | |||
| """Encode a bounding box as distances from the given points. | |||
| """ | |||
| left = points[:, 0] - bbox[:, 0] | |||
| top = points[:, 1] - bbox[:, 1] | |||
| right = bbox[:, 2] - points[:, 0] | |||
| bottom = bbox[:, 3] - points[:, 1] | |||
| if max_dis is not None: | |||
| left = left.clamp(min=0, max=max_dis - eps) | |||
| top = top.clamp(min=0, max=max_dis - eps) | |||
| right = right.clamp(min=0, max=max_dis - eps) | |||
| bottom = bottom.clamp(min=0, max=max_dis - eps) | |||
| return torch.stack([left, top, right, bottom], -1) | |||
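For a point inside the box, bbox2distance and distance2bbox are inverses. A round-trip sketch using the two functions defined above:

```python
import torch

points = torch.tensor([[50., 60.]])
bbox = torch.tensor([[30., 40., 90., 100.]])
dist = bbox2distance(points, bbox)   # (left, top, right, bottom)
print(dist)                          # tensor([[20., 20., 40., 40.]])
print(distance2bbox(points, dist))   # tensor([[30., 40., 90., 100.]])
```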
| class Integral(nn.Module): | |||
| """A fixed layer for calculating integral result from distribution. | |||
| """ | |||
| def __init__(self, reg_max=16): | |||
| super(Integral, self).__init__() | |||
| self.reg_max = reg_max | |||
| self.register_buffer('project', | |||
| torch.linspace(0, self.reg_max, self.reg_max + 1)) | |||
| def forward(self, x): | |||
| """Forward feature from the regression head to get integral result of | |||
| bounding box location. | |||
| """ | |||
| b, hw, _, _ = x.size() | |||
| x = x.reshape(b * hw * 4, self.reg_max + 1) | |||
| y = self.project.type_as(x).unsqueeze(1) | |||
| x = torch.matmul(x, y).reshape(b, hw, 4) | |||
| return x | |||
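Integral converts each side's discrete distribution over reg_max + 1 bins into its expected offset (later scaled by the stride). A toy check with reg_max=2, using the class defined above and putting all probability mass on bin index 2:

```python
import torch

integral = Integral(reg_max=2)
x = torch.zeros(1, 1, 4, 3)  # (batch, positions, 4 sides, reg_max + 1 bins)
x[..., 2] = 1.0              # all mass on bin index 2
print(integral(x))           # tensor([[[2., 2., 2., 2.]]])
```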
| class ZeroHead(nn.Module): | |||
| """Ref to Generalized Focal Loss V2: Learning Reliable Localization Quality | |||
| Estimation for Dense Object Detection. | |||
| """ | |||
| def __init__( | |||
| self, | |||
| num_classes, | |||
| in_channels, | |||
| stacked_convs=4, # 4 | |||
| feat_channels=256, | |||
| reg_max=12, | |||
| strides=[8, 16, 32], | |||
| norm='gn', | |||
| act='relu', | |||
| nms_conf_thre=0.05, | |||
| nms_iou_thre=0.7, | |||
| nms=True, | |||
| **kwargs): | |||
| self.in_channels = in_channels | |||
| self.num_classes = num_classes | |||
| self.stacked_convs = stacked_convs | |||
| self.act = act | |||
| self.strides = strides | |||
| if stacked_convs == 0: | |||
| feat_channels = in_channels | |||
| if isinstance(feat_channels, list): | |||
| self.feat_channels = feat_channels | |||
| else: | |||
| self.feat_channels = [feat_channels] * len(self.strides) | |||
| # add 1 to keep consistency with former models | |||
| self.cls_out_channels = num_classes + 1 | |||
| self.reg_max = reg_max | |||
| self.nms = nms | |||
| self.nms_conf_thre = nms_conf_thre | |||
| self.nms_iou_thre = nms_iou_thre | |||
| self.feat_size = [torch.zeros(4) for _ in strides] | |||
| super(ZeroHead, self).__init__() | |||
| self.integral = Integral(self.reg_max) | |||
| self._init_layers() | |||
| def _build_not_shared_convs(self, in_channel, feat_channels): | |||
| cls_convs = nn.ModuleList() | |||
| reg_convs = nn.ModuleList() | |||
| for i in range(self.stacked_convs): | |||
| chn = feat_channels if i > 0 else in_channel | |||
| kernel_size = 3 if i > 0 else 1 | |||
| cls_convs.append( | |||
| ConvBNAct( | |||
| chn, | |||
| feat_channels, | |||
| kernel_size, | |||
| stride=1, | |||
| groups=1, | |||
| norm='bn', | |||
| act=self.act)) | |||
| reg_convs.append( | |||
| ConvBNAct( | |||
| chn, | |||
| feat_channels, | |||
| kernel_size, | |||
| stride=1, | |||
| groups=1, | |||
| norm='bn', | |||
| act=self.act)) | |||
| return cls_convs, reg_convs | |||
| def _init_layers(self): | |||
| """Initialize layers of the head.""" | |||
| self.cls_convs = nn.ModuleList() | |||
| self.reg_convs = nn.ModuleList() | |||
| for i in range(len(self.strides)): | |||
| cls_convs, reg_convs = self._build_not_shared_convs( | |||
| self.in_channels[i], self.feat_channels[i]) | |||
| self.cls_convs.append(cls_convs) | |||
| self.reg_convs.append(reg_convs) | |||
| self.gfl_cls = nn.ModuleList([ | |||
| nn.Conv2d( | |||
| self.feat_channels[i], self.cls_out_channels, 3, padding=1) | |||
| for i in range(len(self.strides)) | |||
| ]) | |||
| self.gfl_reg = nn.ModuleList([ | |||
| nn.Conv2d( | |||
| self.feat_channels[i], 4 * (self.reg_max + 1), 3, padding=1) | |||
| for i in range(len(self.strides)) | |||
| ]) | |||
| self.scales = nn.ModuleList([Scale(1.0) for _ in self.strides]) | |||
| def forward(self, xin, labels=None, imgs=None, aux_targets=None): | |||
| if self.training: | |||
| raise NotImplementedError | |||
| else: | |||
| return self.forward_eval(xin=xin, labels=labels, imgs=imgs) | |||
| def forward_eval(self, xin, labels=None, imgs=None): | |||
| # prepare priors for label assignment and bbox decode | |||
| if self.feat_size[0] != xin[0].shape: | |||
| mlvl_priors_list = [ | |||
| self.get_single_level_center_priors( | |||
| xin[i].shape[0], | |||
| xin[i].shape[-2:], | |||
| stride, | |||
| dtype=torch.float32, | |||
| device=xin[0].device) | |||
| for i, stride in enumerate(self.strides) | |||
| ] | |||
| self.mlvl_priors = torch.cat(mlvl_priors_list, dim=1) | |||
| self.feat_size[0] = xin[0].shape | |||
| # forward for bboxes and classification prediction | |||
| cls_scores, bbox_preds = multi_apply( | |||
| self.forward_single, | |||
| xin, | |||
| self.cls_convs, | |||
| self.reg_convs, | |||
| self.gfl_cls, | |||
| self.gfl_reg, | |||
| self.scales, | |||
| ) | |||
| cls_scores = torch.cat(cls_scores, dim=1)[:, :, :self.num_classes] | |||
| bbox_preds = torch.cat(bbox_preds, dim=1) | |||
| # batch bbox decode | |||
| bbox_preds = self.integral(bbox_preds) * self.mlvl_priors[..., 2, None] | |||
| bbox_preds = distance2bbox(self.mlvl_priors[..., :2], bbox_preds) | |||
| res = torch.cat([bbox_preds, cls_scores[..., 0:self.num_classes]], | |||
| dim=-1) | |||
| return res | |||
| def forward_single(self, x, cls_convs, reg_convs, gfl_cls, gfl_reg, scale): | |||
| """Forward feature of a single scale level. | |||
| """ | |||
| cls_feat = x | |||
| reg_feat = x | |||
| for cls_conv, reg_conv in zip(cls_convs, reg_convs): | |||
| cls_feat = cls_conv(cls_feat) | |||
| reg_feat = reg_conv(reg_feat) | |||
| bbox_pred = scale(gfl_reg(reg_feat)).float() | |||
| N, C, H, W = bbox_pred.size() | |||
| if self.training: | |||
| bbox_before_softmax = bbox_pred.reshape(N, 4, self.reg_max + 1, H, | |||
| W) | |||
| bbox_before_softmax = bbox_before_softmax.flatten( | |||
| start_dim=3).permute(0, 3, 1, 2) | |||
| bbox_pred = F.softmax( | |||
| bbox_pred.reshape(N, 4, self.reg_max + 1, H, W), dim=2) | |||
| cls_score = gfl_cls(cls_feat).sigmoid() | |||
| cls_score = cls_score.flatten(start_dim=2).permute( | |||
| 0, 2, 1) # N, h*w, self.num_classes+1 | |||
| bbox_pred = bbox_pred.flatten(start_dim=3).permute( | |||
| 0, 3, 1, 2) # N, h*w, 4, self.reg_max+1 | |||
| if self.training: | |||
| return cls_score, bbox_pred, bbox_before_softmax | |||
| else: | |||
| return cls_score, bbox_pred | |||
| def get_single_level_center_priors(self, batch_size, featmap_size, stride, | |||
| dtype, device): | |||
| h, w = featmap_size | |||
| x_range = (torch.arange(0, int(w), dtype=dtype, | |||
| device=device)) * stride | |||
| y_range = (torch.arange(0, int(h), dtype=dtype, | |||
| device=device)) * stride | |||
| x = x_range.repeat(h, 1) | |||
| y = y_range.unsqueeze(-1).repeat(1, w) | |||
| y = y.flatten() | |||
| x = x.flatten() | |||
| strides = x.new_full((x.shape[0], ), stride) | |||
| priors = torch.stack([x, y, strides, strides], dim=-1) | |||
| return priors.unsqueeze(0).repeat(batch_size, 1, 1) | |||
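Each prior is (x, y, stride, stride) in input-image pixels. The same grid logic, inlined standalone for a 2x2 feature map at stride 8:

```python
import torch

h, w, stride = 2, 2, 8
x = (torch.arange(w, dtype=torch.float32) * stride).repeat(h, 1).flatten()
y = (torch.arange(h, dtype=torch.float32) * stride).unsqueeze(-1).repeat(1, w).flatten()
priors = torch.stack([x, y, torch.full_like(x, stride), torch.full_like(x, stride)], -1)
print(priors)
# tensor([[0., 0., 8., 8.],
#         [8., 0., 8., 8.],
#         [0., 8., 8., 8.],
#         [8., 8., 8., 8.]])
```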
| def sample(self, assign_result, gt_bboxes): | |||
| pos_inds = torch.nonzero( | |||
| assign_result.gt_inds > 0, as_tuple=False).squeeze(-1).unique() | |||
| neg_inds = torch.nonzero( | |||
| assign_result.gt_inds == 0, as_tuple=False).squeeze(-1).unique() | |||
| pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1 | |||
| if gt_bboxes.numel() == 0: | |||
| # hack for index error case | |||
| assert pos_assigned_gt_inds.numel() == 0 | |||
| pos_gt_bboxes = torch.empty_like(gt_bboxes).view(-1, 4) | |||
| else: | |||
| if len(gt_bboxes.shape) < 2: | |||
| gt_bboxes = gt_bboxes.view(-1, 4) | |||
| pos_gt_bboxes = gt_bboxes[pos_assigned_gt_inds, :] | |||
| return pos_inds, neg_inds, pos_gt_bboxes, pos_assigned_gt_inds | |||
| @@ -1,10 +1,10 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. | |||
| # The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo. | |||
| import copy | |||
| from .giraffe_fpn import GiraffeNeck | |||
| from .giraffe_fpn_v2 import GiraffeNeckV2 | |||
| from .giraffe_fpn_btn import GiraffeNeckV2 | |||
| def build_neck(cfg): | |||
| @@ -1,5 +1,5 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. | |||
| # The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo. | |||
| import collections | |||
| import itertools | |||
| @@ -1,5 +1,5 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. | |||
| # The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo. | |||
| import logging | |||
| import math | |||
| @@ -15,7 +15,8 @@ from timm import create_model | |||
| from timm.models.layers import (Swish, create_conv2d, create_pool2d, | |||
| get_act_layer) | |||
| from ..core.base_ops import CSPLayer, ShuffleBlock, ShuffleCSPLayer | |||
| from modelscope.models.cv.tinynas_detection.core.base_ops import ( | |||
| CSPLayer, ShuffleBlock, ShuffleCSPLayer) | |||
| from .giraffe_config import get_graph_config | |||
| _ACT_LAYER = Swish | |||
| @@ -0,0 +1,132 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo. | |||
| import torch | |||
| import torch.nn as nn | |||
| from modelscope.models.cv.tinynas_detection.core.ops import ConvBNAct, CSPStage | |||
| class GiraffeNeckV2(nn.Module): | |||
| def __init__( | |||
| self, | |||
| depth=1.0, | |||
| hidden_ratio=1.0, | |||
| in_features=[2, 3, 4], | |||
| in_channels=[256, 512, 1024], | |||
| out_channels=[256, 512, 1024], | |||
| act='silu', | |||
| spp=False, | |||
| block_name='BasicBlock', | |||
| ): | |||
| super().__init__() | |||
| self.in_features = in_features | |||
| self.in_channels = in_channels | |||
| self.out_channels = out_channels | |||
| Conv = ConvBNAct | |||
| self.upsample = nn.Upsample(scale_factor=2, mode='nearest') | |||
| # node x3: input x0, x1 | |||
| self.bu_conv13 = Conv(in_channels[1], in_channels[1], 3, 2, act=act) | |||
| self.merge_3 = CSPStage( | |||
| block_name, | |||
| in_channels[1] + in_channels[2], | |||
| hidden_ratio, | |||
| in_channels[2], | |||
| round(3 * depth), | |||
| act=act, | |||
| spp=spp) | |||
| # node x4: input x1, x2, x3 | |||
| self.bu_conv24 = Conv(in_channels[0], in_channels[0], 3, 2, act=act) | |||
| self.merge_4 = CSPStage( | |||
| block_name, | |||
| in_channels[0] + in_channels[1] + in_channels[2], | |||
| hidden_ratio, | |||
| in_channels[1], | |||
| round(3 * depth), | |||
| act=act, | |||
| spp=spp) | |||
| # node x5: input x2, x4 | |||
| self.merge_5 = CSPStage( | |||
| block_name, | |||
| in_channels[1] + in_channels[0], | |||
| hidden_ratio, | |||
| out_channels[0], | |||
| round(3 * depth), | |||
| act=act, | |||
| spp=spp) | |||
| # node x7: input x4, x5 | |||
| self.bu_conv57 = Conv(out_channels[0], out_channels[0], 3, 2, act=act) | |||
| self.merge_7 = CSPStage( | |||
| block_name, | |||
| out_channels[0] + in_channels[1], | |||
| hidden_ratio, | |||
| out_channels[1], | |||
| round(3 * depth), | |||
| act=act, | |||
| spp=spp) | |||
| # node x6: input x3, x4, x7 | |||
| self.bu_conv46 = Conv(in_channels[1], in_channels[1], 3, 2, act=act) | |||
| self.bu_conv76 = Conv(out_channels[1], out_channels[1], 3, 2, act=act) | |||
| self.merge_6 = CSPStage( | |||
| block_name, | |||
| in_channels[1] + out_channels[1] + in_channels[2], | |||
| hidden_ratio, | |||
| out_channels[2], | |||
| round(3 * depth), | |||
| act=act, | |||
| spp=spp) | |||
| def init_weights(self): | |||
| pass | |||
| def forward(self, out_features): | |||
| """ | |||
| Args: | |||
| inputs: input images. | |||
| Returns: | |||
| Tuple[Tensor]: FPN feature. | |||
| """ | |||
| # backbone | |||
| [x2, x1, x0] = out_features | |||
| # node x3 | |||
| x13 = self.bu_conv13(x1) | |||
| x3 = torch.cat([x0, x13], 1) | |||
| x3 = self.merge_3(x3) | |||
| # node x4 | |||
| x34 = self.upsample(x3) | |||
| x24 = self.bu_conv24(x2) | |||
| x4 = torch.cat([x1, x24, x34], 1) | |||
| x4 = self.merge_4(x4) | |||
| # node x5 | |||
| x45 = self.upsample(x4) | |||
| x5 = torch.cat([x2, x45], 1) | |||
| x5 = self.merge_5(x5) | |||
| # node x8 | |||
| # x8 = x5 | |||
| # node x7 | |||
| x57 = self.bu_conv57(x5) | |||
| x7 = torch.cat([x4, x57], 1) | |||
| x7 = self.merge_7(x7) | |||
| # node x6 | |||
| x46 = self.bu_conv46(x4) | |||
| x76 = self.bu_conv76(x7) | |||
| x6 = torch.cat([x3, x46, x76], 1) | |||
| x6 = self.merge_6(x6) | |||
| outputs = (x5, x7, x6) | |||
| return outputs | |||
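A quick shape check of the new neck, assuming a 640x640 input, backbone strides 8/16/32, batch size 1, and the default channel layout (and that the `core.ops` dependencies resolve); this is a smoke-test sketch, not part of the PR:

```python
import torch

neck = GiraffeNeckV2()   # defaults: in/out channels [256, 512, 1024]
feats = [
    torch.randn(1, 256, 80, 80),    # x2: stride-8 feature
    torch.randn(1, 512, 40, 40),    # x1: stride-16 feature
    torch.randn(1, 1024, 20, 20),   # x0: stride-32 feature
]
p5, p7, p6 = neck(feats)
print(p5.shape, p7.shape, p6.shape)
# expected: (1, 256, 80, 80), (1, 512, 40, 40), (1, 1024, 20, 20)
```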
| @@ -1,200 +0,0 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. | |||
| import torch | |||
| import torch.nn as nn | |||
| from ..core.base_ops import BaseConv, CSPLayer, DWConv | |||
| from ..core.neck_ops import CSPStage | |||
| class GiraffeNeckV2(nn.Module): | |||
| def __init__( | |||
| self, | |||
| depth=1.0, | |||
| width=1.0, | |||
| in_channels=[256, 512, 1024], | |||
| out_channels=[256, 512, 1024], | |||
| depthwise=False, | |||
| act='silu', | |||
| spp=True, | |||
| reparam_mode=True, | |||
| block_name='BasicBlock', | |||
| ): | |||
| super().__init__() | |||
| self.in_channels = in_channels | |||
| Conv = DWConv if depthwise else BaseConv | |||
| reparam_mode = reparam_mode | |||
| self.upsample = nn.Upsample(scale_factor=2, mode='nearest') | |||
| # node x3: input x0, x1 | |||
| self.bu_conv13 = Conv( | |||
| int(in_channels[1] * width), | |||
| int(in_channels[1] * width), | |||
| 3, | |||
| 2, | |||
| act=act) | |||
| if reparam_mode: | |||
| self.merge_3 = CSPStage( | |||
| block_name, | |||
| int((in_channels[1] + in_channels[2]) * width), | |||
| int(in_channels[2] * width), | |||
| round(3 * depth), | |||
| act=act, | |||
| spp=spp) | |||
| else: | |||
| self.merge_3 = CSPLayer( | |||
| int((in_channels[1] + in_channels[2]) * width), | |||
| int(in_channels[2] * width), | |||
| round(3 * depth), | |||
| False, | |||
| depthwise=depthwise, | |||
| act=act) | |||
| # node x4: input x1, x2, x3 | |||
| self.bu_conv24 = Conv( | |||
| int(in_channels[0] * width), | |||
| int(in_channels[0] * width), | |||
| 3, | |||
| 2, | |||
| act=act) | |||
| if reparam_mode: | |||
| self.merge_4 = CSPStage( | |||
| block_name, | |||
| int((in_channels[0] + in_channels[1] + in_channels[2]) | |||
| * width), | |||
| int(in_channels[1] * width), | |||
| round(3 * depth), | |||
| act=act, | |||
| spp=spp) | |||
| else: | |||
| self.merge_4 = CSPLayer( | |||
| int((in_channels[0] + in_channels[1] + in_channels[2]) | |||
| * width), | |||
| int(in_channels[1] * width), | |||
| round(3 * depth), | |||
| False, | |||
| depthwise=depthwise, | |||
| act=act) | |||
| # node x5: input x2, x4 | |||
| if reparam_mode: | |||
| self.merge_5 = CSPStage( | |||
| block_name, | |||
| int((in_channels[1] + in_channels[0]) * width), | |||
| int(out_channels[0] * width), | |||
| round(3 * depth), | |||
| act=act, | |||
| spp=spp) | |||
| else: | |||
| self.merge_5 = CSPLayer( | |||
| int((in_channels[1] + in_channels[0]) * width), | |||
| int(out_channels[0] * width), | |||
| round(3 * depth), | |||
| False, | |||
| depthwise=depthwise, | |||
| act=act) | |||
| # node x7: input x4, x5 | |||
| self.bu_conv57 = Conv( | |||
| int(out_channels[0] * width), | |||
| int(out_channels[0] * width), | |||
| 3, | |||
| 2, | |||
| act=act) | |||
| if reparam_mode: | |||
| self.merge_7 = CSPStage( | |||
| block_name, | |||
| int((out_channels[0] + in_channels[1]) * width), | |||
| int(out_channels[1] * width), | |||
| round(3 * depth), | |||
| act=act, | |||
| spp=spp) | |||
| else: | |||
| self.merge_7 = CSPLayer( | |||
| int((out_channels[0] + in_channels[1]) * width), | |||
| int(out_channels[1] * width), | |||
| round(3 * depth), | |||
| False, | |||
| depthwise=depthwise, | |||
| act=act) | |||
| # node x6: input x3, x4, x7 | |||
| self.bu_conv46 = Conv( | |||
| int(in_channels[1] * width), | |||
| int(in_channels[1] * width), | |||
| 3, | |||
| 2, | |||
| act=act) | |||
| self.bu_conv76 = Conv( | |||
| int(out_channels[1] * width), | |||
| int(out_channels[1] * width), | |||
| 3, | |||
| 2, | |||
| act=act) | |||
| if reparam_mode: | |||
| self.merge_6 = CSPStage( | |||
| block_name, | |||
| int((in_channels[1] + out_channels[1] + in_channels[2]) | |||
| * width), | |||
| int(out_channels[2] * width), | |||
| round(3 * depth), | |||
| act=act, | |||
| spp=spp) | |||
| else: | |||
| self.merge_6 = CSPLayer( | |||
| int((in_channels[1] + out_channels[1] + in_channels[2]) | |||
| * width), | |||
| int(out_channels[2] * width), | |||
| round(3 * depth), | |||
| False, | |||
| depthwise=depthwise, | |||
| act=act) | |||
| def init_weights(self): | |||
| pass | |||
| def forward(self, out_features): | |||
| """ | |||
| Args: | |||
| inputs: input images. | |||
| Returns: | |||
| Tuple[Tensor]: FPN feature. | |||
| """ | |||
| # backbone | |||
| [x2, x1, x0] = out_features | |||
| # node x3 | |||
| x13 = self.bu_conv13(x1) | |||
| x3 = torch.cat([x0, x13], 1) | |||
| x3 = self.merge_3(x3) | |||
| # node x4 | |||
| x34 = self.upsample(x3) | |||
| x24 = self.bu_conv24(x2) | |||
| x4 = torch.cat([x1, x24, x34], 1) | |||
| x4 = self.merge_4(x4) | |||
| # node x5 | |||
| x45 = self.upsample(x4) | |||
| x5 = torch.cat([x2, x45], 1) | |||
| x5 = self.merge_5(x5) | |||
| # node x7 | |||
| x57 = self.bu_conv57(x5) | |||
| x7 = torch.cat([x4, x57], 1) | |||
| x7 = self.merge_7(x7) | |||
| # node x6 | |||
| x46 = self.bu_conv46(x4) | |||
| x76 = self.bu_conv76(x7) | |||
| x6 = torch.cat([x3, x46, x76], 1) | |||
| x6 = self.merge_6(x6) | |||
| outputs = (x5, x7, x6) | |||
| return outputs | |||
| @@ -11,5 +11,5 @@ from .detector import SingleStageDetector | |||
| class DamoYolo(SingleStageDetector): | |||
| def __init__(self, model_dir, *args, **kwargs): | |||
| self.config_name = 'damoyolo_s.py' | |||
| self.config_name = 'damoyolo.py' | |||
| super(DamoYolo, self).__init__(model_dir, *args, **kwargs) | |||
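This one-line rename appears to be the "configuration overlap" fix from the PR title: instead of every variant hard-coding the `-s` config file name, each released model directory presumably ships its own `damoyolo.py`, so a single detector class serves all variants. A hedged illustration (paths and layout are assumptions, not verified):

```python
# Hypothetical local layout:
#   ./damoyolo-t/  ->  damoyolo.py + tinynas weights
#   ./damoyolo-m/  ->  damoyolo.py + tinynas weights
detector_t = DamoYolo('./damoyolo-t')   # every variant goes through one class
detector_m = DamoYolo('./damoyolo-m')
```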
| @@ -1,5 +1,5 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. | |||
| # The DAMO-YOLO implementation is also open-sourced by the authors at https://github.com/tinyvision/damo-yolo. | |||
| from modelscope.metainfo import Models | |||
| from modelscope.models.builder import MODELS | |||
| @@ -1,30 +1,33 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. | |||
| # The DAMO-YOLO implementation is also open-sourced by the authors, and available | |||
| # at https://github.com/tinyvision/damo-yolo. | |||
| import importlib | |||
| import os | |||
| import shutil | |||
| import sys | |||
| import tempfile | |||
| from os.path import dirname, join | |||
| from easydict import EasyDict | |||
| def get_config_by_file(config_file): | |||
| try: | |||
| sys.path.append(os.path.dirname(config_file)) | |||
| current_config = importlib.import_module( | |||
| os.path.basename(config_file).split('.')[0]) | |||
| exp = current_config.Config() | |||
| except Exception: | |||
| raise ImportError( | |||
| "{} doesn't contains class named 'Config'".format(config_file)) | |||
| return exp | |||
| def parse_config(filename): | |||
| filename = str(filename) | |||
| if filename.endswith('.py'): | |||
| with tempfile.TemporaryDirectory() as temp_config_dir: | |||
| shutil.copyfile(filename, join(temp_config_dir, '_tempconfig.py')) | |||
| sys.path.insert(0, temp_config_dir) | |||
| mod = importlib.import_module('_tempconfig') | |||
| sys.path.pop(0) | |||
| cfg_dict = EasyDict({ | |||
| name: value | |||
| for name, value in mod.__dict__.items() | |||
| if not name.startswith('__') | |||
| }) | |||
| # delete imported module | |||
| del sys.modules['_tempconfig'] | |||
| else: | |||
| raise IOError('Only .py config files are supported for now!') | |||
| def parse_config(config_file): | |||
| """ | |||
| get config object by file. | |||
| Args: | |||
| config_file (str): file path of config. | |||
| """ | |||
| assert (config_file is not None), 'plz provide config file' | |||
| if config_file is not None: | |||
| return get_config_by_file(config_file) | |||
| return cfg_dict | |||
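The new `parse_config` imports the `.py` file through a temporary module and returns its top-level names as an `EasyDict`. A usage sketch with a toy config — the keys here are illustrative, not the real `damoyolo.py` schema:

```python
import os
import tempfile

# Write a toy py-style config and parse it back.
cfg_src = "depth = 1.0\nneck = dict(name='GiraffeNeckV2', hidden_ratio=0.75)\n"
with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, 'toy_config.py')
    with open(path, 'w') as f:
        f.write(cfg_src)
    cfg = parse_config(path)

print(cfg.depth)       # 1.0
print(cfg.neck.name)   # EasyDict converts nested dicts -> 'GiraffeNeckV2'
```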
| @@ -29,7 +29,25 @@ class TinynasObjectDetectionTest(unittest.TestCase, DemoCompatibilityCheck): | |||
| model='damo/cv_tinynas_object-detection_damoyolo') | |||
| result = tinynas_object_detection( | |||
| 'data/test/images/image_detection.jpg') | |||
| print('damoyolo', result) | |||
| print('damoyolo-s', result) | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_damoyolo_m(self): | |||
| tinynas_object_detection = pipeline( | |||
| Tasks.image_object_detection, | |||
| model='damo/cv_tinynas_object-detection_damoyolo-m') | |||
| result = tinynas_object_detection( | |||
| 'data/test/images/image_detection.jpg') | |||
| print('damoyolo-m', result) | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_damoyolo_t(self): | |||
| tinynas_object_detection = pipeline( | |||
| Tasks.image_object_detection, | |||
| model='damo/cv_tinynas_object-detection_damoyolo-t') | |||
| result = tinynas_object_detection( | |||
| 'data/test/images/image_detection.jpg') | |||
| print('damoyolo-t', result) | |||
| @unittest.skip('demo compatibility test is only enabled on an as-needed basis') | |||
| def test_demo_compatibility(self): | |||
| @@ -40,7 +58,7 @@ class TinynasObjectDetectionTest(unittest.TestCase, DemoCompatibilityCheck): | |||
| test_image = 'data/test/images/image_detection.jpg' | |||
| tinynas_object_detection = pipeline( | |||
| Tasks.image_object_detection, | |||
| model='damo/cv_tinynas_object-detection_damoyolo') | |||
| model='damo/cv_tinynas_object-detection_damoyolo-m') | |||
| result = tinynas_object_detection(test_image) | |||
| tinynas_object_detection.show_result(test_image, result, | |||
| 'demo_ret.jpg') | |||
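For completeness, the three public variants exercised by these tests can also be driven directly through the pipeline API; a sketch using the model IDs taken verbatim from the diff (downloading the models requires network access):

```python
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

for model_id in (
        'damo/cv_tinynas_object-detection_damoyolo-t',
        'damo/cv_tinynas_object-detection_damoyolo',    # the -s variant
        'damo/cv_tinynas_object-detection_damoyolo-m',
):
    detector = pipeline(Tasks.image_object_detection, model=model_id)
    print(model_id, detector('data/test/images/image_detection.jpg'))
```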