接入damyolo系列检测模型
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10377688
master
| @@ -9,7 +9,9 @@ class Models(object): | |||||
| Model name should only contain model info but not task info. | Model name should only contain model info but not task info. | ||||
| """ | """ | ||||
| # tinynas models | |||||
| tinynas_detection = 'tinynas-detection' | tinynas_detection = 'tinynas-detection' | ||||
| tinynas_damoyolo = 'tinynas-damoyolo' | |||||
| # vision models | # vision models | ||||
| detection = 'detection' | detection = 'detection' | ||||
| @@ -7,10 +7,12 @@ from modelscope.utils.import_utils import LazyImportModule | |||||
| if TYPE_CHECKING: | if TYPE_CHECKING: | ||||
| from .tinynas_detector import Tinynas_detector | from .tinynas_detector import Tinynas_detector | ||||
| from .tinynas_damoyolo import DamoYolo | |||||
| else: | else: | ||||
| _import_structure = { | _import_structure = { | ||||
| 'tinynas_detector': ['TinynasDetector'], | 'tinynas_detector': ['TinynasDetector'], | ||||
| 'tinynas_damoyolo': ['DamoYolo'], | |||||
| } | } | ||||
| import sys | import sys | ||||
| @@ -4,6 +4,7 @@ | |||||
| import torch | import torch | ||||
| import torch.nn as nn | import torch.nn as nn | ||||
| from modelscope.utils.file_utils import read_file | |||||
| from ..core.base_ops import Focus, SPPBottleneck, get_activation | from ..core.base_ops import Focus, SPPBottleneck, get_activation | ||||
| from ..core.repvgg_block import RepVggBlock | from ..core.repvgg_block import RepVggBlock | ||||
| @@ -49,12 +50,16 @@ class ResConvK1KX(nn.Module): | |||||
| kernel_size, | kernel_size, | ||||
| stride, | stride, | ||||
| force_resproj=False, | force_resproj=False, | ||||
| act='silu'): | |||||
| act='silu', | |||||
| reparam=False): | |||||
| super(ResConvK1KX, self).__init__() | super(ResConvK1KX, self).__init__() | ||||
| self.stride = stride | self.stride = stride | ||||
| self.conv1 = ConvKXBN(in_c, btn_c, 1, 1) | self.conv1 = ConvKXBN(in_c, btn_c, 1, 1) | ||||
| self.conv2 = RepVggBlock( | |||||
| btn_c, out_c, kernel_size, stride, act='identity') | |||||
| if not reparam: | |||||
| self.conv2 = ConvKXBN(btn_c, out_c, 3, stride) | |||||
| else: | |||||
| self.conv2 = RepVggBlock( | |||||
| btn_c, out_c, kernel_size, stride, act='identity') | |||||
| if act is None: | if act is None: | ||||
| self.activation_function = torch.relu | self.activation_function = torch.relu | ||||
| @@ -97,7 +102,8 @@ class SuperResConvK1KX(nn.Module): | |||||
| stride, | stride, | ||||
| num_blocks, | num_blocks, | ||||
| with_spp=False, | with_spp=False, | ||||
| act='silu'): | |||||
| act='silu', | |||||
| reparam=False): | |||||
| super(SuperResConvK1KX, self).__init__() | super(SuperResConvK1KX, self).__init__() | ||||
| if act is None: | if act is None: | ||||
| self.act = torch.relu | self.act = torch.relu | ||||
| @@ -124,7 +130,8 @@ class SuperResConvK1KX(nn.Module): | |||||
| this_kernel_size, | this_kernel_size, | ||||
| this_stride, | this_stride, | ||||
| force_resproj, | force_resproj, | ||||
| act=act) | |||||
| act=act, | |||||
| reparam=reparam) | |||||
| self.block_list.append(the_block) | self.block_list.append(the_block) | ||||
| if block_id == 0 and with_spp: | if block_id == 0 and with_spp: | ||||
| self.block_list.append( | self.block_list.append( | ||||
| @@ -248,7 +255,8 @@ class TinyNAS(nn.Module): | |||||
| with_spp=False, | with_spp=False, | ||||
| use_focus=False, | use_focus=False, | ||||
| need_conv1=True, | need_conv1=True, | ||||
| act='silu'): | |||||
| act='silu', | |||||
| reparam=False): | |||||
| super(TinyNAS, self).__init__() | super(TinyNAS, self).__init__() | ||||
| assert len(out_indices) == len(out_channels) | assert len(out_indices) == len(out_channels) | ||||
| self.out_indices = out_indices | self.out_indices = out_indices | ||||
| @@ -281,7 +289,8 @@ class TinyNAS(nn.Module): | |||||
| block_info['s'], | block_info['s'], | ||||
| block_info['L'], | block_info['L'], | ||||
| spp, | spp, | ||||
| act=act) | |||||
| act=act, | |||||
| reparam=reparam) | |||||
| self.block_list.append(the_block) | self.block_list.append(the_block) | ||||
| elif the_block_class == 'SuperResConvKXKX': | elif the_block_class == 'SuperResConvKXKX': | ||||
| spp = with_spp if idx == len(structure_info) - 1 else False | spp = with_spp if idx == len(structure_info) - 1 else False | ||||
| @@ -325,8 +334,8 @@ class TinyNAS(nn.Module): | |||||
| def load_tinynas_net(backbone_cfg): | def load_tinynas_net(backbone_cfg): | ||||
| # load masternet model to path | # load masternet model to path | ||||
| import ast | import ast | ||||
| struct_str = ''.join([x.strip() for x in backbone_cfg.net_structure_str]) | |||||
| net_structure_str = read_file(backbone_cfg.structure_file) | |||||
| struct_str = ''.join([x.strip() for x in net_structure_str]) | |||||
| struct_info = ast.literal_eval(struct_str) | struct_info = ast.literal_eval(struct_str) | ||||
| for layer in struct_info: | for layer in struct_info: | ||||
| if 'nbitsA' in layer: | if 'nbitsA' in layer: | ||||
| @@ -342,6 +351,6 @@ def load_tinynas_net(backbone_cfg): | |||||
| use_focus=backbone_cfg.use_focus, | use_focus=backbone_cfg.use_focus, | ||||
| act=backbone_cfg.act, | act=backbone_cfg.act, | ||||
| need_conv1=backbone_cfg.need_conv1, | need_conv1=backbone_cfg.need_conv1, | ||||
| ) | |||||
| reparam=backbone_cfg.reparam) | |||||
| return model | return model | ||||
| @@ -30,7 +30,7 @@ class SingleStageDetector(TorchModel): | |||||
| """ | """ | ||||
| super().__init__(model_dir, *args, **kwargs) | super().__init__(model_dir, *args, **kwargs) | ||||
| config_path = osp.join(model_dir, 'airdet_s.py') | |||||
| config_path = osp.join(model_dir, self.config_name) | |||||
| config = parse_config(config_path) | config = parse_config(config_path) | ||||
| self.cfg = config | self.cfg = config | ||||
| model_path = osp.join(model_dir, config.model.name) | model_path = osp.join(model_dir, config.model.name) | ||||
| @@ -41,6 +41,9 @@ class SingleStageDetector(TorchModel): | |||||
| self.conf_thre = config.model.head.nms_conf_thre | self.conf_thre = config.model.head.nms_conf_thre | ||||
| self.nms_thre = config.model.head.nms_iou_thre | self.nms_thre = config.model.head.nms_iou_thre | ||||
| if self.cfg.model.backbone.name == 'TinyNAS': | |||||
| self.cfg.model.backbone.structure_file = osp.join( | |||||
| model_dir, self.cfg.model.backbone.structure_file) | |||||
| self.backbone = build_backbone(self.cfg.model.backbone) | self.backbone = build_backbone(self.cfg.model.backbone) | ||||
| self.neck = build_neck(self.cfg.model.neck) | self.neck = build_neck(self.cfg.model.neck) | ||||
| self.head = build_head(self.cfg.model.head) | self.head = build_head(self.cfg.model.head) | ||||
| @@ -124,11 +124,13 @@ class GFocalHead_Tiny(nn.Module): | |||||
| simOTA_iou_weight=3.0, | simOTA_iou_weight=3.0, | ||||
| octbase=8, | octbase=8, | ||||
| simlqe=False, | simlqe=False, | ||||
| use_lqe=True, | |||||
| **kwargs): | **kwargs): | ||||
| self.simlqe = simlqe | self.simlqe = simlqe | ||||
| self.num_classes = num_classes | self.num_classes = num_classes | ||||
| self.in_channels = in_channels | self.in_channels = in_channels | ||||
| self.strides = strides | self.strides = strides | ||||
| self.use_lqe = use_lqe | |||||
| self.feat_channels = feat_channels if isinstance(feat_channels, list) \ | self.feat_channels = feat_channels if isinstance(feat_channels, list) \ | ||||
| else [feat_channels] * len(self.strides) | else [feat_channels] * len(self.strides) | ||||
| @@ -181,15 +183,20 @@ class GFocalHead_Tiny(nn.Module): | |||||
| groups=self.conv_groups, | groups=self.conv_groups, | ||||
| norm=self.norm, | norm=self.norm, | ||||
| act=self.act)) | act=self.act)) | ||||
| if not self.simlqe: | |||||
| conf_vector = [nn.Conv2d(4 * self.total_dim, self.reg_channels, 1)] | |||||
| if self.use_lqe: | |||||
| if not self.simlqe: | |||||
| conf_vector = [ | |||||
| nn.Conv2d(4 * self.total_dim, self.reg_channels, 1) | |||||
| ] | |||||
| else: | |||||
| conf_vector = [ | |||||
| nn.Conv2d(4 * (self.reg_max + 1), self.reg_channels, 1) | |||||
| ] | |||||
| conf_vector += [self.relu] | |||||
| conf_vector += [nn.Conv2d(self.reg_channels, 1, 1), nn.Sigmoid()] | |||||
| reg_conf = nn.Sequential(*conf_vector) | |||||
| else: | else: | ||||
| conf_vector = [ | |||||
| nn.Conv2d(4 * (self.reg_max + 1), self.reg_channels, 1) | |||||
| ] | |||||
| conf_vector += [self.relu] | |||||
| conf_vector += [nn.Conv2d(self.reg_channels, 1, 1), nn.Sigmoid()] | |||||
| reg_conf = nn.Sequential(*conf_vector) | |||||
| reg_conf = None | |||||
| return cls_convs, reg_convs, reg_conf | return cls_convs, reg_convs, reg_conf | ||||
| @@ -290,21 +297,27 @@ class GFocalHead_Tiny(nn.Module): | |||||
| N, C, H, W = bbox_pred.size() | N, C, H, W = bbox_pred.size() | ||||
| prob = F.softmax( | prob = F.softmax( | ||||
| bbox_pred.reshape(N, 4, self.reg_max + 1, H, W), dim=2) | bbox_pred.reshape(N, 4, self.reg_max + 1, H, W), dim=2) | ||||
| if not self.simlqe: | |||||
| prob_topk, _ = prob.topk(self.reg_topk, dim=2) | |||||
| if self.add_mean: | |||||
| stat = torch.cat( | |||||
| [prob_topk, prob_topk.mean(dim=2, keepdim=True)], dim=2) | |||||
| if self.use_lqe: | |||||
| if not self.simlqe: | |||||
| prob_topk, _ = prob.topk(self.reg_topk, dim=2) | |||||
| if self.add_mean: | |||||
| stat = torch.cat( | |||||
| [prob_topk, | |||||
| prob_topk.mean(dim=2, keepdim=True)], | |||||
| dim=2) | |||||
| else: | |||||
| stat = prob_topk | |||||
| quality_score = reg_conf( | |||||
| stat.reshape(N, 4 * self.total_dim, H, W)) | |||||
| else: | else: | ||||
| stat = prob_topk | |||||
| quality_score = reg_conf( | |||||
| bbox_pred.reshape(N, 4 * (self.reg_max + 1), H, W)) | |||||
| quality_score = reg_conf(stat.reshape(N, 4 * self.total_dim, H, W)) | |||||
| cls_score = gfl_cls(cls_feat).sigmoid() * quality_score | |||||
| else: | else: | ||||
| quality_score = reg_conf( | |||||
| bbox_pred.reshape(N, 4 * (self.reg_max + 1), H, W)) | |||||
| cls_score = gfl_cls(cls_feat).sigmoid() * quality_score | |||||
| cls_score = gfl_cls(cls_feat).sigmoid() | |||||
| flatten_cls_score = cls_score.flatten(start_dim=2).transpose(1, 2) | flatten_cls_score = cls_score.flatten(start_dim=2).transpose(1, 2) | ||||
| flatten_bbox_pred = bbox_pred.flatten(start_dim=2).transpose(1, 2) | flatten_bbox_pred = bbox_pred.flatten(start_dim=2).transpose(1, 2) | ||||
| @@ -14,7 +14,6 @@ class GiraffeNeckV2(nn.Module): | |||||
| self, | self, | ||||
| depth=1.0, | depth=1.0, | ||||
| width=1.0, | width=1.0, | ||||
| in_features=[2, 3, 4], | |||||
| in_channels=[256, 512, 1024], | in_channels=[256, 512, 1024], | ||||
| out_channels=[256, 512, 1024], | out_channels=[256, 512, 1024], | ||||
| depthwise=False, | depthwise=False, | ||||
| @@ -24,7 +23,6 @@ class GiraffeNeckV2(nn.Module): | |||||
| block_name='BasicBlock', | block_name='BasicBlock', | ||||
| ): | ): | ||||
| super().__init__() | super().__init__() | ||||
| self.in_features = in_features | |||||
| self.in_channels = in_channels | self.in_channels = in_channels | ||||
| Conv = DWConv if depthwise else BaseConv | Conv = DWConv if depthwise else BaseConv | ||||
| @@ -169,8 +167,7 @@ class GiraffeNeckV2(nn.Module): | |||||
| """ | """ | ||||
| # backbone | # backbone | ||||
| features = [out_features[f] for f in self.in_features] | |||||
| [x2, x1, x0] = features | |||||
| [x2, x1, x0] = out_features | |||||
| # node x3 | # node x3 | ||||
| x13 = self.bu_conv13(x1) | x13 = self.bu_conv13(x1) | ||||
| @@ -0,0 +1,15 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from modelscope.metainfo import Models | |||||
| from modelscope.models.builder import MODELS | |||||
| from modelscope.utils.constant import Tasks | |||||
| from .detector import SingleStageDetector | |||||
| @MODELS.register_module( | |||||
| Tasks.image_object_detection, module_name=Models.tinynas_damoyolo) | |||||
| class DamoYolo(SingleStageDetector): | |||||
| def __init__(self, model_dir, *args, **kwargs): | |||||
| self.config_name = 'damoyolo_s.py' | |||||
| super(DamoYolo, self).__init__(model_dir, *args, **kwargs) | |||||
| @@ -12,5 +12,5 @@ from .detector import SingleStageDetector | |||||
| class TinynasDetector(SingleStageDetector): | class TinynasDetector(SingleStageDetector): | ||||
| def __init__(self, model_dir, *args, **kwargs): | def __init__(self, model_dir, *args, **kwargs): | ||||
| self.config_name = 'airdet_s.py' | |||||
| super(TinynasDetector, self).__init__(model_dir, *args, **kwargs) | super(TinynasDetector, self).__init__(model_dir, *args, **kwargs) | ||||
| @@ -12,6 +12,8 @@ from modelscope.pipelines.base import Input, Pipeline | |||||
| from modelscope.pipelines.builder import PIPELINES | from modelscope.pipelines.builder import PIPELINES | ||||
| from modelscope.preprocessors import LoadImage | from modelscope.preprocessors import LoadImage | ||||
| from modelscope.utils.constant import Tasks | from modelscope.utils.constant import Tasks | ||||
| from modelscope.utils.cv.image_utils import \ | |||||
| show_image_object_detection_auto_result | |||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| logger = get_logger() | logger = get_logger() | ||||
| @@ -52,10 +54,18 @@ class TinynasDetectionPipeline(Pipeline): | |||||
| bboxes, scores, labels = self.model.postprocess(inputs['data']) | bboxes, scores, labels = self.model.postprocess(inputs['data']) | ||||
| if bboxes is None: | if bboxes is None: | ||||
| return None | |||||
| outputs = { | |||||
| OutputKeys.SCORES: scores, | |||||
| OutputKeys.LABELS: labels, | |||||
| OutputKeys.BOXES: bboxes | |||||
| } | |||||
| outputs = { | |||||
| OutputKeys.SCORES: [], | |||||
| OutputKeys.LABELS: [], | |||||
| OutputKeys.BOXES: [] | |||||
| } | |||||
| else: | |||||
| outputs = { | |||||
| OutputKeys.SCORES: scores, | |||||
| OutputKeys.LABELS: labels, | |||||
| OutputKeys.BOXES: bboxes | |||||
| } | |||||
| return outputs | return outputs | ||||
| def show_result(self, img_path, result, save_path=None): | |||||
| show_image_object_detection_auto_result(img_path, result, save_path) | |||||
| @@ -1,6 +1,7 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| import inspect | import inspect | ||||
| import os | |||||
| from pathlib import Path | from pathlib import Path | ||||
| @@ -35,3 +36,10 @@ def get_default_cache_dir(): | |||||
| """ | """ | ||||
| default_cache_dir = Path.home().joinpath('.cache', 'modelscope') | default_cache_dir = Path.home().joinpath('.cache', 'modelscope') | ||||
| return default_cache_dir | return default_cache_dir | ||||
| def read_file(path): | |||||
| with open(path, 'r') as f: | |||||
| text = f.read() | |||||
| return text | |||||
| @@ -4,22 +4,45 @@ import unittest | |||||
| from modelscope.pipelines import pipeline | from modelscope.pipelines import pipeline | ||||
| from modelscope.utils.constant import Tasks | from modelscope.utils.constant import Tasks | ||||
| from modelscope.utils.demo_utils import DemoCompatibilityCheck | |||||
| from modelscope.utils.test_utils import test_level | from modelscope.utils.test_utils import test_level | ||||
| class TinynasObjectDetectionTest(unittest.TestCase): | |||||
| class TinynasObjectDetectionTest(unittest.TestCase, DemoCompatibilityCheck): | |||||
| def setUp(self) -> None: | |||||
| self.task = Tasks.image_object_detection | |||||
| self.model_id = 'damo/cv_tinynas_object-detection_damoyolo' | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | ||||
| def test_run(self): | |||||
| def test_run_airdet(self): | |||||
| tinynas_object_detection = pipeline( | tinynas_object_detection = pipeline( | ||||
| Tasks.image_object_detection, model='damo/cv_tinynas_detection') | Tasks.image_object_detection, model='damo/cv_tinynas_detection') | ||||
| result = tinynas_object_detection( | result = tinynas_object_detection( | ||||
| 'data/test/images/image_detection.jpg') | 'data/test/images/image_detection.jpg') | ||||
| print(result) | print(result) | ||||
| @unittest.skip('will be enabled after damoyolo officially released') | |||||
| def test_run_damoyolo(self): | |||||
| tinynas_object_detection = pipeline( | |||||
| Tasks.image_object_detection, | |||||
| model='damo/cv_tinynas_object-detection_damoyolo') | |||||
| result = tinynas_object_detection( | |||||
| 'data/test/images/image_detection.jpg') | |||||
| print(result) | |||||
| @unittest.skip('demo compatibility test is only enabled on a needed-basis') | @unittest.skip('demo compatibility test is only enabled on a needed-basis') | ||||
| def test_demo_compatibility(self): | def test_demo_compatibility(self): | ||||
| self.test_demo() | |||||
| self.compatibility_check() | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_image_object_detection_auto_pipeline(self): | |||||
| test_image = 'data/test/images/image_detection.jpg' | |||||
| tinynas_object_detection = pipeline( | |||||
| Tasks.image_object_detection, model='damo/cv_tinynas_detection') | |||||
| result = tinynas_object_detection(test_image) | |||||
| tinynas_object_detection.show_result(test_image, result, | |||||
| 'demo_ret.jpg') | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||