diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 24f2f748..153ca9b4 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -11,6 +11,7 @@ class Models(object): """ # vision models detection = 'detection' + realtime_object_detection = 'realtime-object-detection' scrfd = 'scrfd' classification_model = 'ClassificationModel' nafnet = 'nafnet' @@ -111,6 +112,7 @@ class Pipelines(object): image_super_resolution = 'rrdb-image-super-resolution' face_image_generation = 'gan-face-image-generation' product_retrieval_embedding = 'resnet50-product-retrieval-embedding' + realtime_object_detection = 'cspnet_realtime-object-detection_yolox' face_recognition = 'ir101-face-recognition-cfglint' image_instance_segmentation = 'cascade-mask-rcnn-swin-image-instance-segmentation' image2image_translation = 'image-to-image-translation' diff --git a/modelscope/models/cv/__init__.py b/modelscope/models/cv/__init__.py index 227be2c7..74451c31 100644 --- a/modelscope/models/cv/__init__.py +++ b/modelscope/models/cv/__init__.py @@ -7,5 +7,5 @@ from . import (action_recognition, animal_recognition, body_2d_keypoints, image_reid_person, image_semantic_segmentation, image_to_image_generation, image_to_image_translation, object_detection, product_retrieval_embedding, - salient_detection, super_resolution, + realtime_object_detection, salient_detection, super_resolution, video_single_object_tracking, video_summarization, virual_tryon) diff --git a/modelscope/models/cv/realtime_object_detection/__init__.py b/modelscope/models/cv/realtime_object_detection/__init__.py new file mode 100644 index 00000000..aed13cec --- /dev/null +++ b/modelscope/models/cv/realtime_object_detection/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
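The two registry keys added to metainfo.py above are what let callers construct this detector by task name. A minimal usage sketch (the hub model id and image path are taken from the test added at the end of this patch; availability of that id is assumed):

    from modelscope.outputs import OutputKeys
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    detector = pipeline(Tasks.image_object_detection,
                        model='damo/cv_cspnet_image-object-detection_yolox')
    result = detector('data/test/images/keypoints_detect/000000438862.jpg')
    print(result[OutputKeys.BOXES], result[OutputKeys.SCORES], result[OutputKeys.LABELS])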
+from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .realtime_detector import RealtimeDetector +else: + _import_structure = { + 'realtime_detector': ['RealtimeDetector'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/realtime_object_detection/realtime_detector.py b/modelscope/models/cv/realtime_object_detection/realtime_detector.py new file mode 100644 index 00000000..b147f769 --- /dev/null +++ b/modelscope/models/cv/realtime_object_detection/realtime_detector.py @@ -0,0 +1,85 @@ +import argparse +import logging as logger +import os +import os.path as osp +import time + +import cv2 +import json +import torch + +from modelscope.metainfo import Models +from modelscope.models.base.base_torch_model import TorchModel +from modelscope.models.builder import MODELS +from modelscope.preprocessors import LoadImage +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks +from .yolox.data.data_augment import ValTransform +from .yolox.exp import get_exp_by_name +from .yolox.utils import postprocess + + +@MODELS.register_module( + group_key=Tasks.image_object_detection, + module_name=Models.realtime_object_detection) +class RealtimeDetector(TorchModel): + + def __init__(self, model_dir: str, *args, **kwargs): + super().__init__(model_dir, *args, **kwargs) + self.config = Config.from_file( + os.path.join(self.model_dir, ModelFile.CONFIGURATION)) + + # model type + self.exp = get_exp_by_name(self.config.model_type) + + # build model + self.model = self.exp.get_model() + model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE) + ckpt = torch.load(model_path, map_location='cpu') + + # load the model state dict + self.model.load_state_dict(ckpt['model']) + self.model.eval() + + # params setting + self.exp.num_classes = self.config.num_classes + self.confthre = self.config.conf_thr + self.num_classes = self.exp.num_classes + self.nmsthre = self.exp.nmsthre + self.test_size = self.exp.test_size + self.preproc = ValTransform(legacy=False) + + def inference(self, img): + with torch.no_grad(): + outputs = self.model(img) + return outputs + + def forward(self, inputs): + return self.inference(inputs) + + def preprocess(self, img): + img = LoadImage.convert_to_ndarray(img) + height, width = img.shape[:2] + self.ratio = min(self.test_size[0] / img.shape[0], + self.test_size[1] / img.shape[1]) + + img, _ = self.preproc(img, None, self.test_size) + img = torch.from_numpy(img).unsqueeze(0) + img = img.float() + + return img + + def postprocess(self, input): + outputs = postprocess( + input, + self.num_classes, + self.confthre, + self.nmsthre, + class_agnostic=True) + + if len(outputs) == 1: + bboxes = outputs[0][:, 0:4].cpu().numpy() / self.ratio + scores = outputs[0][:, 5].cpu().numpy() + labels = outputs[0][:, 6].cpu().int().numpy() + + return bboxes, scores, labels diff --git a/modelscope/models/cv/realtime_object_detection/yolox/__init__.py b/modelscope/models/cv/realtime_object_detection/yolox/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/cv/realtime_object_detection/yolox/data/__init__.py b/modelscope/models/cv/realtime_object_detection/yolox/data/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/cv/realtime_object_detection/yolox/data/data_augment.py 
b/modelscope/models/cv/realtime_object_detection/yolox/data/data_augment.py new file mode 100644 index 00000000..b52a65fe --- /dev/null +++ b/modelscope/models/cv/realtime_object_detection/yolox/data/data_augment.py @@ -0,0 +1,69 @@ +# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX +""" +Data augmentation functionality. Passed as callable transformations to +Dataset classes. + +The data augmentation procedures were interpreted from @weiliu89's SSD paper +http://arxiv.org/abs/1512.02325 +""" + +import math +import random + +import cv2 +import numpy as np + +from ..utils import xyxy2cxcywh + + +def preproc(img, input_size, swap=(2, 0, 1)): + if len(img.shape) == 3: + padded_img = np.ones( + (input_size[0], input_size[1], 3), dtype=np.uint8) * 114 + else: + padded_img = np.ones(input_size, dtype=np.uint8) * 114 + + r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1]) + resized_img = cv2.resize( + img, + (int(img.shape[1] * r), int(img.shape[0] * r)), + interpolation=cv2.INTER_LINEAR, + ).astype(np.uint8) + padded_img[:int(img.shape[0] * r), :int(img.shape[1] * r)] = resized_img + + padded_img = padded_img.transpose(swap) + padded_img = np.ascontiguousarray(padded_img, dtype=np.float32) + return padded_img, r + + +class ValTransform: + """ + Defines the transformations that should be applied to test PIL image + for input into the network + + dimension -> tensorize -> color adj + + Arguments: + resize (int): input dimension to SSD + rgb_means ((int,int,int)): average RGB of the dataset + (104,117,123) + swap ((int,int,int)): final order of channels + + Returns: + transform (transform) : callable transform to be applied to test/val + data + """ + + def __init__(self, swap=(2, 0, 1), legacy=False): + self.swap = swap + self.legacy = legacy + + # assume input is cv2 img for now + def __call__(self, img, res, input_size): + img, _ = preproc(img, input_size, self.swap) + if self.legacy: + img = img[::-1, :, :].copy() + img /= 255.0 + img -= np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1) + img /= np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1) + return img, np.zeros((1, 5)) diff --git a/modelscope/models/cv/realtime_object_detection/yolox/exp/__init__.py b/modelscope/models/cv/realtime_object_detection/yolox/exp/__init__.py new file mode 100644 index 00000000..e8e3be15 --- /dev/null +++ b/modelscope/models/cv/realtime_object_detection/yolox/exp/__init__.py @@ -0,0 +1,5 @@ +# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX + +from .base_exp import BaseExp +from .build import get_exp_by_name +from .yolox_base import Exp diff --git a/modelscope/models/cv/realtime_object_detection/yolox/exp/base_exp.py b/modelscope/models/cv/realtime_object_detection/yolox/exp/base_exp.py new file mode 100644 index 00000000..a4278cbf --- /dev/null +++ b/modelscope/models/cv/realtime_object_detection/yolox/exp/base_exp.py @@ -0,0 +1,12 @@ +# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX + +from abc import ABCMeta, abstractmethod + +from torch.nn import Module + + +class BaseExp(metaclass=ABCMeta): + + @abstractmethod + def get_model(self) -> Module: + pass diff --git a/modelscope/models/cv/realtime_object_detection/yolox/exp/build.py b/modelscope/models/cv/realtime_object_detection/yolox/exp/build.py new file mode 100644 index 00000000..4858100c --- /dev/null +++ b/modelscope/models/cv/realtime_object_detection/yolox/exp/build.py @@ -0,0 +1,18 @@ +# The 
implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX + +import os +import sys + + +def get_exp_by_name(exp_name): + exp = exp_name.replace('-', + '_') # convert string like "yolox-s" to "yolox_s" + if exp == 'yolox_s': + from .default import YoloXSExp as YoloXExp + elif exp == 'yolox_nano': + from .default import YoloXNanoExp as YoloXExp + elif exp == 'yolox_tiny': + from .default import YoloXTinyExp as YoloXExp + else: + pass + return YoloXExp() diff --git a/modelscope/models/cv/realtime_object_detection/yolox/exp/default/__init__.py b/modelscope/models/cv/realtime_object_detection/yolox/exp/default/__init__.py new file mode 100644 index 00000000..552bbccd --- /dev/null +++ b/modelscope/models/cv/realtime_object_detection/yolox/exp/default/__init__.py @@ -0,0 +1,5 @@ +# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX + +from .yolox_nano import YoloXNanoExp +from .yolox_s import YoloXSExp +from .yolox_tiny import YoloXTinyExp diff --git a/modelscope/models/cv/realtime_object_detection/yolox/exp/default/yolox_nano.py b/modelscope/models/cv/realtime_object_detection/yolox/exp/default/yolox_nano.py new file mode 100644 index 00000000..330eef16 --- /dev/null +++ b/modelscope/models/cv/realtime_object_detection/yolox/exp/default/yolox_nano.py @@ -0,0 +1,46 @@ +# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX + +import os + +import torch.nn as nn + +from ..yolox_base import Exp as YoloXExp + + +class YoloXNanoExp(YoloXExp): + + def __init__(self): + super(YoloXNanoExp, self).__init__() + self.depth = 0.33 + self.width = 0.25 + self.input_size = (416, 416) + self.test_size = (416, 416) + + def get_model(self, sublinear=False): + + def init_yolo(M): + for m in M.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eps = 1e-3 + m.momentum = 0.03 + + if 'model' not in self.__dict__: + from ...models import YOLOX, YOLOPAFPN, YOLOXHead + in_channels = [256, 512, 1024] + # NANO model use depthwise = True, which is main difference. 
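For reference, a sketch of how the experiment factory above is driven: RealtimeDetector (earlier in this patch) passes the configuration's model_type string to get_exp_by_name and builds the network from the returned experiment. Values here are illustrative:

    from modelscope.models.cv.realtime_object_detection.yolox.exp import get_exp_by_name

    exp = get_exp_by_name('yolox-s')   # 'yolox-s' is normalised to 'yolox_s' -> YoloXSExp
    exp.num_classes = 80               # RealtimeDetector copies this from configuration.json
    model = exp.get_model().eval()     # CSPDarknet/PAFPN backbone with a decoupled YOLOX head
    # Note: any name outside {yolox-s, yolox-nano, yolox-tiny} falls through the bare
    # `else: pass` branch above, so `return YoloXExp()` raises UnboundLocalError
    # rather than reporting the unsupported model_type.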
+ backbone = YOLOPAFPN( + self.depth, + self.width, + in_channels=in_channels, + act=self.act, + depthwise=True, + ) + head = YOLOXHead( + self.num_classes, + self.width, + in_channels=in_channels, + act=self.act, + depthwise=True) + self.model = YOLOX(backbone, head) + + return self.model diff --git a/modelscope/models/cv/realtime_object_detection/yolox/exp/default/yolox_s.py b/modelscope/models/cv/realtime_object_detection/yolox/exp/default/yolox_s.py new file mode 100644 index 00000000..5a123b37 --- /dev/null +++ b/modelscope/models/cv/realtime_object_detection/yolox/exp/default/yolox_s.py @@ -0,0 +1,13 @@ +# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX + +import os + +from ..yolox_base import Exp as YoloXExp + + +class YoloXSExp(YoloXExp): + + def __init__(self): + super(YoloXSExp, self).__init__() + self.depth = 0.33 + self.width = 0.50 diff --git a/modelscope/models/cv/realtime_object_detection/yolox/exp/default/yolox_tiny.py b/modelscope/models/cv/realtime_object_detection/yolox/exp/default/yolox_tiny.py new file mode 100644 index 00000000..a80d0f2d --- /dev/null +++ b/modelscope/models/cv/realtime_object_detection/yolox/exp/default/yolox_tiny.py @@ -0,0 +1,20 @@ +# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX + +import os + +from ..yolox_base import Exp as YoloXExp + + +class YoloXTinyExp(YoloXExp): + + def __init__(self): + super(YoloXTinyExp, self).__init__() + self.depth = 0.33 + self.width = 0.375 + self.input_size = (416, 416) + self.mosaic_scale = (0.5, 1.5) + self.random_size = (10, 20) + self.test_size = (416, 416) + self.exp_name = os.path.split( + os.path.realpath(__file__))[1].split('.')[0] + self.enable_mixup = False diff --git a/modelscope/models/cv/realtime_object_detection/yolox/exp/yolox_base.py b/modelscope/models/cv/realtime_object_detection/yolox/exp/yolox_base.py new file mode 100644 index 00000000..a2a41535 --- /dev/null +++ b/modelscope/models/cv/realtime_object_detection/yolox/exp/yolox_base.py @@ -0,0 +1,59 @@ +# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX + +import os +import random + +import torch +import torch.distributed as dist +import torch.nn as nn + +from .base_exp import BaseExp + + +class Exp(BaseExp): + + def __init__(self): + super().__init__() + + # ---------------- model config ---------------- # + # detect classes number of model + self.num_classes = 80 + # factor of model depth + self.depth = 1.00 + # factor of model width + self.width = 1.00 + # activation name. For example, if using "relu", then "silu" will be replaced to "relu". 
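Several of these experiment defaults are overridden at load time: RealtimeDetector (earlier in this patch) reads model_type, num_classes and conf_thr from the model's configuration.json. An illustrative sketch of those fields (the actual file ships with each model on the hub, so the values below are assumptions):

    # Illustrative only -- not the real configuration.json.
    config = {
        'model_type': 'yolox-s',   # passed to get_exp_by_name()
        'num_classes': 80,         # copied onto exp.num_classes
        'conf_thr': 0.25,          # assumed value; becomes self.confthre for postprocess()
    }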
+ self.act = 'silu' + # ----------------- testing config ------------------ # + # output image size during evaluation/test + self.test_size = (640, 640) + # confidence threshold during evaluation/test, + # boxes whose scores are less than test_conf will be filtered + self.test_conf = 0.01 + # nms threshold + self.nmsthre = 0.65 + + def get_model(self): + from ..models import YOLOX, YOLOPAFPN, YOLOXHead + + def init_yolo(M): + for m in M.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eps = 1e-3 + m.momentum = 0.03 + + if getattr(self, 'model', None) is None: + in_channels = [256, 512, 1024] + backbone = YOLOPAFPN( + self.depth, self.width, in_channels=in_channels, act=self.act) + head = YOLOXHead( + self.num_classes, + self.width, + in_channels=in_channels, + act=self.act) + self.model = YOLOX(backbone, head) + + self.model.apply(init_yolo) + self.model.head.initialize_biases(1e-2) + self.model.train() + return self.model diff --git a/modelscope/models/cv/realtime_object_detection/yolox/models/__init__.py b/modelscope/models/cv/realtime_object_detection/yolox/models/__init__.py new file mode 100644 index 00000000..20b1a0d1 --- /dev/null +++ b/modelscope/models/cv/realtime_object_detection/yolox/models/__init__.py @@ -0,0 +1,7 @@ +# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX + +from .darknet import CSPDarknet, Darknet +from .yolo_fpn import YOLOFPN +from .yolo_head import YOLOXHead +from .yolo_pafpn import YOLOPAFPN +from .yolox import YOLOX diff --git a/modelscope/models/cv/realtime_object_detection/yolox/models/darknet.py b/modelscope/models/cv/realtime_object_detection/yolox/models/darknet.py new file mode 100644 index 00000000..8ece2a1e --- /dev/null +++ b/modelscope/models/cv/realtime_object_detection/yolox/models/darknet.py @@ -0,0 +1,189 @@ +# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX + +from torch import nn + +from .network_blocks import (BaseConv, CSPLayer, DWConv, Focus, ResLayer, + SPPBottleneck) + + +class Darknet(nn.Module): + # number of blocks from dark2 to dark5. + depth2blocks = {21: [1, 2, 2, 1], 53: [2, 8, 8, 4]} + + def __init__( + self, + depth, + in_channels=3, + stem_out_channels=32, + out_features=('dark3', 'dark4', 'dark5'), + ): + """ + Args: + depth (int): depth of darknet used in model, usually use [21, 53] for this param. + in_channels (int): number of input channels, for example, use 3 for RGB image. + stem_out_channels (int): number of output channels of darknet stem. + It decides channels of darknet layer2 to layer5. + out_features (Tuple[str]): desired output layer name. + """ + super().__init__() + assert out_features, 'please provide output features of Darknet' + self.out_features = out_features + self.stem = nn.Sequential( + BaseConv( + in_channels, stem_out_channels, ksize=3, stride=1, + act='lrelu'), + *self.make_group_layer(stem_out_channels, num_blocks=1, stride=2), + ) + in_channels = stem_out_channels * 2 # 64 + + num_blocks = Darknet.depth2blocks[depth] + # create darknet with `stem_out_channels` and `num_blocks` layers. + # to make model structure more clear, we don't use `for` statement in python. 
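The depth/width multipliers chosen by the Exp subclasses earlier scale CSPDarknet (defined later in this file) as follows; a quick arithmetic check, not part of the patch:

    # base_channels = int(width * 64), base_depth = max(round(depth * 3), 1) in CSPDarknet:
    # yolox_s (width 0.50) -> 32 stem channels, yolox_tiny (0.375) -> 24, yolox_nano (0.25) -> 16;
    # all three use depth 0.33, i.e. base_depth 1.
    assert int(0.50 * 64) == 32 and int(0.375 * 64) == 24 and int(0.25 * 64) == 16
    assert max(round(0.33 * 3), 1) == 1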
+ self.dark2 = nn.Sequential( + *self.make_group_layer(in_channels, num_blocks[0], stride=2)) + in_channels *= 2 # 128 + self.dark3 = nn.Sequential( + *self.make_group_layer(in_channels, num_blocks[1], stride=2)) + in_channels *= 2 # 256 + self.dark4 = nn.Sequential( + *self.make_group_layer(in_channels, num_blocks[2], stride=2)) + in_channels *= 2 # 512 + + self.dark5 = nn.Sequential( + *self.make_group_layer(in_channels, num_blocks[3], stride=2), + *self.make_spp_block([in_channels, in_channels * 2], + in_channels * 2), + ) + + def make_group_layer(self, + in_channels: int, + num_blocks: int, + stride: int = 1): + 'starts with conv layer then has `num_blocks` `ResLayer`' + return [ + BaseConv( + in_channels, + in_channels * 2, + ksize=3, + stride=stride, + act='lrelu'), + *[(ResLayer(in_channels * 2)) for _ in range(num_blocks)], + ] + + def make_spp_block(self, filters_list, in_filters): + m = nn.Sequential(*[ + BaseConv(in_filters, filters_list[0], 1, stride=1, act='lrelu'), + BaseConv( + filters_list[0], filters_list[1], 3, stride=1, act='lrelu'), + SPPBottleneck( + in_channels=filters_list[1], + out_channels=filters_list[0], + activation='lrelu', + ), + BaseConv( + filters_list[0], filters_list[1], 3, stride=1, act='lrelu'), + BaseConv( + filters_list[1], filters_list[0], 1, stride=1, act='lrelu'), + ]) + return m + + def forward(self, x): + outputs = {} + x = self.stem(x) + outputs['stem'] = x + x = self.dark2(x) + outputs['dark2'] = x + x = self.dark3(x) + outputs['dark3'] = x + x = self.dark4(x) + outputs['dark4'] = x + x = self.dark5(x) + outputs['dark5'] = x + return {k: v for k, v in outputs.items() if k in self.out_features} + + +class CSPDarknet(nn.Module): + + def __init__( + self, + dep_mul, + wid_mul, + out_features=('dark3', 'dark4', 'dark5'), + depthwise=False, + act='silu', + ): + super().__init__() + assert out_features, 'please provide output features of Darknet' + self.out_features = out_features + Conv = DWConv if depthwise else BaseConv + + base_channels = int(wid_mul * 64) # 64 + base_depth = max(round(dep_mul * 3), 1) # 3 + + # stem + self.stem = Focus(3, base_channels, ksize=3, act=act) + + # dark2 + self.dark2 = nn.Sequential( + Conv(base_channels, base_channels * 2, 3, 2, act=act), + CSPLayer( + base_channels * 2, + base_channels * 2, + n=base_depth, + depthwise=depthwise, + act=act, + ), + ) + + # dark3 + self.dark3 = nn.Sequential( + Conv(base_channels * 2, base_channels * 4, 3, 2, act=act), + CSPLayer( + base_channels * 4, + base_channels * 4, + n=base_depth * 3, + depthwise=depthwise, + act=act, + ), + ) + + # dark4 + self.dark4 = nn.Sequential( + Conv(base_channels * 4, base_channels * 8, 3, 2, act=act), + CSPLayer( + base_channels * 8, + base_channels * 8, + n=base_depth * 3, + depthwise=depthwise, + act=act, + ), + ) + + # dark5 + self.dark5 = nn.Sequential( + Conv(base_channels * 8, base_channels * 16, 3, 2, act=act), + SPPBottleneck( + base_channels * 16, base_channels * 16, activation=act), + CSPLayer( + base_channels * 16, + base_channels * 16, + n=base_depth, + shortcut=False, + depthwise=depthwise, + act=act, + ), + ) + + def forward(self, x): + outputs = {} + x = self.stem(x) + outputs['stem'] = x + x = self.dark2(x) + outputs['dark2'] = x + x = self.dark3(x) + outputs['dark3'] = x + x = self.dark4(x) + outputs['dark4'] = x + x = self.dark5(x) + outputs['dark5'] = x + return {k: v for k, v in outputs.items() if k in self.out_features} diff --git a/modelscope/models/cv/realtime_object_detection/yolox/models/network_blocks.py 
b/modelscope/models/cv/realtime_object_detection/yolox/models/network_blocks.py new file mode 100644 index 00000000..fd15c1c1 --- /dev/null +++ b/modelscope/models/cv/realtime_object_detection/yolox/models/network_blocks.py @@ -0,0 +1,213 @@ +# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX + +import torch +import torch.nn as nn + + +def get_activation(name='silu', inplace=True): + if name == 'silu': + module = nn.SiLU(inplace=inplace) + else: + raise AttributeError('Unsupported act type: {}'.format(name)) + return module + + +class BaseConv(nn.Module): + """A Conv2d -> Batchnorm -> silu/leaky relu block""" + + def __init__(self, + in_channels, + out_channels, + ksize, + stride, + groups=1, + bias=False, + act='silu'): + super(BaseConv, self).__init__() + # same padding + pad = (ksize - 1) // 2 + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=ksize, + stride=stride, + padding=pad, + groups=groups, + bias=bias, + ) + self.bn = nn.BatchNorm2d(out_channels) + self.act = get_activation(act, inplace=True) + + def forward(self, x): + return self.act(self.bn(self.conv(x))) + + def fuseforward(self, x): + return self.act(self.conv(x)) + + +class DWConv(nn.Module): + """Depthwise Conv + Conv""" + + def __init__(self, in_channels, out_channels, ksize, stride=1, act='silu'): + super(DWConv, self).__init__() + self.dconv = BaseConv( + in_channels, + in_channels, + ksize=ksize, + stride=stride, + groups=in_channels, + act=act, + ) + self.pconv = BaseConv( + in_channels, out_channels, ksize=1, stride=1, groups=1, act=act) + + def forward(self, x): + x = self.dconv(x) + return self.pconv(x) + + +class Bottleneck(nn.Module): + # Standard bottleneck + def __init__( + self, + in_channels, + out_channels, + shortcut=True, + expansion=0.5, + depthwise=False, + act='silu', + ): + super().__init__() + hidden_channels = int(out_channels * expansion) + Conv = DWConv if depthwise else BaseConv + self.conv1 = BaseConv( + in_channels, hidden_channels, 1, stride=1, act=act) + self.conv2 = Conv(hidden_channels, out_channels, 3, stride=1, act=act) + self.use_add = shortcut and in_channels == out_channels + + def forward(self, x): + y = self.conv2(self.conv1(x)) + if self.use_add: + y = y + x + return y + + +class ResLayer(nn.Module): + 'Residual layer with `in_channels` inputs.' 
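A small shape check for the two convolution wrappers defined above (a sketch, not part of the patch; the import path assumes this new module):

    import torch
    from modelscope.models.cv.realtime_object_detection.yolox.models.network_blocks import (
        BaseConv, DWConv)

    x = torch.randn(1, 64, 32, 32)
    conv = BaseConv(64, 128, ksize=3, stride=2, act='silu')   # "same" padding: (3 - 1) // 2 = 1
    dw = DWConv(64, 128, ksize=3, stride=2)                   # depthwise 3x3 followed by pointwise 1x1
    assert conv(x).shape == dw(x).shape == (1, 128, 16, 16)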
+ + def __init__(self, in_channels: int): + super().__init__() + mid_channels = in_channels // 2 + self.layer1 = BaseConv( + in_channels, mid_channels, ksize=1, stride=1, act='lrelu') + self.layer2 = BaseConv( + mid_channels, in_channels, ksize=3, stride=1, act='lrelu') + + def forward(self, x): + out = self.layer2(self.layer1(x)) + return x + out + + +class SPPBottleneck(nn.Module): + """Spatial pyramid pooling layer used in YOLOv3-SPP""" + + def __init__(self, + in_channels, + out_channels, + kernel_sizes=(5, 9, 13), + activation='silu'): + super().__init__() + hidden_channels = in_channels // 2 + self.conv1 = BaseConv( + in_channels, hidden_channels, 1, stride=1, act=activation) + self.m = nn.ModuleList([ + nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2) + for ks in kernel_sizes + ]) + conv2_channels = hidden_channels * (len(kernel_sizes) + 1) + self.conv2 = BaseConv( + conv2_channels, out_channels, 1, stride=1, act=activation) + + def forward(self, x): + x = self.conv1(x) + x = torch.cat([x] + [m(x) for m in self.m], dim=1) + x = self.conv2(x) + return x + + +class CSPLayer(nn.Module): + """C3 in yolov5, CSP Bottleneck with 3 convolutions""" + + def __init__( + self, + in_channels, + out_channels, + n=1, + shortcut=True, + expansion=0.5, + depthwise=False, + act='silu', + ): + """ + Args: + in_channels (int): input channels. + out_channels (int): output channels. + n (int): number of Bottlenecks. Default value: 1. + """ + # ch_in, ch_out, number, shortcut, groups, expansion + super().__init__() + hidden_channels = int(out_channels * expansion) # hidden channels + self.conv1 = BaseConv( + in_channels, hidden_channels, 1, stride=1, act=act) + self.conv2 = BaseConv( + in_channels, hidden_channels, 1, stride=1, act=act) + self.conv3 = BaseConv( + 2 * hidden_channels, out_channels, 1, stride=1, act=act) + module_list = [ + Bottleneck( + hidden_channels, + hidden_channels, + shortcut, + 1.0, + depthwise, + act=act) for _ in range(n) + ] + self.m = nn.Sequential(*module_list) + + def forward(self, x): + x_1 = self.conv1(x) + x_2 = self.conv2(x) + x_1 = self.m(x_1) + x = torch.cat((x_1, x_2), dim=1) + return self.conv3(x) + + +class Focus(nn.Module): + """Focus width and height information into channel space.""" + + def __init__(self, + in_channels, + out_channels, + ksize=1, + stride=1, + act='silu'): + super().__init__() + self.conv = BaseConv( + in_channels * 4, out_channels, ksize, stride, act=act) + + def forward(self, x): + # shape of x (b,c,w,h) -> y(b,4c,w/2,h/2) + patch_top_left = x[..., ::2, ::2] + patch_top_right = x[..., ::2, 1::2] + patch_bot_left = x[..., 1::2, ::2] + patch_bot_right = x[..., 1::2, 1::2] + x = torch.cat( + ( + patch_top_left, + patch_bot_left, + patch_top_right, + patch_bot_right, + ), + dim=1, + ) + return self.conv(x) diff --git a/modelscope/models/cv/realtime_object_detection/yolox/models/yolo_fpn.py b/modelscope/models/cv/realtime_object_detection/yolox/models/yolo_fpn.py new file mode 100644 index 00000000..0cbebb09 --- /dev/null +++ b/modelscope/models/cv/realtime_object_detection/yolox/models/yolo_fpn.py @@ -0,0 +1,80 @@ +# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX + +import torch +import torch.nn as nn + +from .darknet import Darknet +from .network_blocks import BaseConv + + +class YOLOFPN(nn.Module): + """ + YOLOFPN module. Darknet 53 is the default backbone of this model. 
+ """ + + def __init__( + self, + depth=53, + in_features=['dark3', 'dark4', 'dark5'], + ): + super(YOLOFPN, self).__init__() + + self.backbone = Darknet(depth) + self.in_features = in_features + + # out 1 + self.out1_cbl = self._make_cbl(512, 256, 1) + self.out1 = self._make_embedding([256, 512], 512 + 256) + + # out 2 + self.out2_cbl = self._make_cbl(256, 128, 1) + self.out2 = self._make_embedding([128, 256], 256 + 128) + + # upsample + self.upsample = nn.Upsample(scale_factor=2, mode='nearest') + + def _make_cbl(self, _in, _out, ks): + return BaseConv(_in, _out, ks, stride=1, act='lrelu') + + def _make_embedding(self, filters_list, in_filters): + m = nn.Sequential(*[ + self._make_cbl(in_filters, filters_list[0], 1), + self._make_cbl(filters_list[0], filters_list[1], 3), + self._make_cbl(filters_list[1], filters_list[0], 1), + self._make_cbl(filters_list[0], filters_list[1], 3), + self._make_cbl(filters_list[1], filters_list[0], 1), + ]) + return m + + def load_pretrained_model(self, filename='./weights/darknet53.mix.pth'): + with open(filename, 'rb') as f: + state_dict = torch.load(f, map_location='cpu') + print('loading pretrained weights...') + self.backbone.load_state_dict(state_dict) + + def forward(self, inputs): + """ + Args: + inputs (Tensor): input image. + + Returns: + Tuple[Tensor]: FPN output features.. + """ + # backbone + out_features = self.backbone(inputs) + x2, x1, x0 = [out_features[f] for f in self.in_features] + + # yolo branch 1 + x1_in = self.out1_cbl(x0) + x1_in = self.upsample(x1_in) + x1_in = torch.cat([x1_in, x1], 1) + out_dark4 = self.out1(x1_in) + + # yolo branch 2 + x2_in = self.out2_cbl(out_dark4) + x2_in = self.upsample(x2_in) + x2_in = torch.cat([x2_in, x2], 1) + out_dark3 = self.out2(x2_in) + + outputs = (out_dark3, out_dark4, x0) + return outputs diff --git a/modelscope/models/cv/realtime_object_detection/yolox/models/yolo_head.py b/modelscope/models/cv/realtime_object_detection/yolox/models/yolo_head.py new file mode 100644 index 00000000..1eef93a4 --- /dev/null +++ b/modelscope/models/cv/realtime_object_detection/yolox/models/yolo_head.py @@ -0,0 +1,182 @@ +# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..utils import bboxes_iou, meshgrid +from .network_blocks import BaseConv, DWConv + + +class YOLOXHead(nn.Module): + + def __init__( + self, + num_classes, + width=1.0, + strides=[8, 16, 32], + in_channels=[256, 512, 1024], + act='silu', + depthwise=False, + ): + """ + Args: + act (str): activation type of conv. Defalut value: "silu". + depthwise (bool): whether apply depthwise conv in conv branch. Defalut value: False. 
+ """ + super(YOLOXHead, self).__init__() + + self.n_anchors = 1 + self.num_classes = num_classes + self.decode_in_inference = True # for deploy, set to False + + self.cls_convs = nn.ModuleList() + self.reg_convs = nn.ModuleList() + self.cls_preds = nn.ModuleList() + self.reg_preds = nn.ModuleList() + self.obj_preds = nn.ModuleList() + self.stems = nn.ModuleList() + Conv = DWConv if depthwise else BaseConv + + for i in range(len(in_channels)): + self.stems.append( + BaseConv( + in_channels=int(in_channels[i] * width), + out_channels=int(256 * width), + ksize=1, + stride=1, + act=act, + )) + self.cls_convs.append( + nn.Sequential(*[ + Conv( + in_channels=int(256 * width), + out_channels=int(256 * width), + ksize=3, + stride=1, + act=act, + ), + Conv( + in_channels=int(256 * width), + out_channels=int(256 * width), + ksize=3, + stride=1, + act=act, + ), + ])) + self.reg_convs.append( + nn.Sequential(*[ + Conv( + in_channels=int(256 * width), + out_channels=int(256 * width), + ksize=3, + stride=1, + act=act, + ), + Conv( + in_channels=int(256 * width), + out_channels=int(256 * width), + ksize=3, + stride=1, + act=act, + ), + ])) + self.cls_preds.append( + nn.Conv2d( + in_channels=int(256 * width), + out_channels=self.n_anchors * self.num_classes, + kernel_size=1, + stride=1, + padding=0, + )) + self.reg_preds.append( + nn.Conv2d( + in_channels=int(256 * width), + out_channels=4, + kernel_size=1, + stride=1, + padding=0, + )) + self.obj_preds.append( + nn.Conv2d( + in_channels=int(256 * width), + out_channels=self.n_anchors * 1, + kernel_size=1, + stride=1, + padding=0, + )) + + self.use_l1 = False + self.l1_loss = nn.L1Loss(reduction='none') + self.bcewithlog_loss = nn.BCEWithLogitsLoss(reduction='none') + # self.iou_loss = IOUloss(reduction="none") + self.strides = strides + self.grids = [torch.zeros(1)] * len(in_channels) + + def initialize_biases(self, prior_prob): + for conv in self.cls_preds: + b = conv.bias.view(self.n_anchors, -1) + b.data.fill_(-math.log((1 - prior_prob) / prior_prob)) + conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True) + + for conv in self.obj_preds: + b = conv.bias.view(self.n_anchors, -1) + b.data.fill_(-math.log((1 - prior_prob) / prior_prob)) + conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True) + + def forward(self, xin, labels=None, imgs=None): + outputs = [] + + for k, (cls_conv, reg_conv, stride_this_level, x) in enumerate( + zip(self.cls_convs, self.reg_convs, self.strides, xin)): + x = self.stems[k](x) + cls_x = x + reg_x = x + + cls_feat = cls_conv(cls_x) + cls_output = self.cls_preds[k](cls_feat) + + reg_feat = reg_conv(reg_x) + reg_output = self.reg_preds[k](reg_feat) + obj_output = self.obj_preds[k](reg_feat) + + if self.training: + pass + else: + output = torch.cat( + [reg_output, + obj_output.sigmoid(), + cls_output.sigmoid()], 1) + + outputs.append(output) + + if self.training: + pass + else: + self.hw = [x.shape[-2:] for x in outputs] + # [batch, n_anchors_all, 85] + outputs = torch.cat([x.flatten(start_dim=2) for x in outputs], + dim=2).permute(0, 2, 1) + if self.decode_in_inference: + return self.decode_outputs(outputs, dtype=xin[0].type()) + else: + return outputs + + def decode_outputs(self, outputs, dtype): + grids = [] + strides = [] + for (hsize, wsize), stride in zip(self.hw, self.strides): + yv, xv = meshgrid([torch.arange(hsize), torch.arange(wsize)]) + grid = torch.stack((xv, yv), 2).view(1, -1, 2) + grids.append(grid) + shape = grid.shape[:2] + strides.append(torch.full((*shape, 1), stride)) + + grids = 
torch.cat(grids, dim=1).type(dtype) + strides = torch.cat(strides, dim=1).type(dtype) + + outputs[..., :2] = (outputs[..., :2] + grids) * strides + outputs[..., 2:4] = torch.exp(outputs[..., 2:4]) * strides + return outputs diff --git a/modelscope/models/cv/realtime_object_detection/yolox/models/yolo_pafpn.py b/modelscope/models/cv/realtime_object_detection/yolox/models/yolo_pafpn.py new file mode 100644 index 00000000..cd4258bf --- /dev/null +++ b/modelscope/models/cv/realtime_object_detection/yolox/models/yolo_pafpn.py @@ -0,0 +1,126 @@ +# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX + +import torch +import torch.nn as nn + +from .darknet import CSPDarknet +from .network_blocks import BaseConv, CSPLayer, DWConv + + +class YOLOPAFPN(nn.Module): + """ + YOLOv3 model. Darknet 53 is the default backbone of this model. + """ + + def __init__( + self, + depth=1.0, + width=1.0, + in_features=('dark3', 'dark4', 'dark5'), + in_channels=[256, 512, 1024], + depthwise=False, + act='silu', + ): + super(YOLOPAFPN, self).__init__() + self.backbone = CSPDarknet(depth, width, depthwise=depthwise, act=act) + self.in_features = in_features + self.in_channels = in_channels + Conv = DWConv if depthwise else BaseConv + + self.upsample = nn.Upsample(scale_factor=2, mode='nearest') + self.lateral_conv0 = BaseConv( + int(in_channels[2] * width), + int(in_channels[1] * width), + 1, + 1, + act=act) + self.C3_p4 = CSPLayer( + int(2 * in_channels[1] * width), + int(in_channels[1] * width), + round(3 * depth), + False, + depthwise=depthwise, + act=act, + ) # cat + + self.reduce_conv1 = BaseConv( + int(in_channels[1] * width), + int(in_channels[0] * width), + 1, + 1, + act=act) + self.C3_p3 = CSPLayer( + int(2 * in_channels[0] * width), + int(in_channels[0] * width), + round(3 * depth), + False, + depthwise=depthwise, + act=act, + ) + + # bottom-up conv + self.bu_conv2 = Conv( + int(in_channels[0] * width), + int(in_channels[0] * width), + 3, + 2, + act=act) + self.C3_n3 = CSPLayer( + int(2 * in_channels[0] * width), + int(in_channels[1] * width), + round(3 * depth), + False, + depthwise=depthwise, + act=act, + ) + + # bottom-up conv + self.bu_conv1 = Conv( + int(in_channels[1] * width), + int(in_channels[1] * width), + 3, + 2, + act=act) + self.C3_n4 = CSPLayer( + int(2 * in_channels[1] * width), + int(in_channels[2] * width), + round(3 * depth), + False, + depthwise=depthwise, + act=act, + ) + + def forward(self, input): + """ + Args: + inputs: input images. + + Returns: + Tuple[Tensor]: FPN feature. 
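A quick count of the anchor points produced by decode_outputs above for the default strides (8, 16, 32) and test_size (640, 640); a worked check, not part of the patch:

    assert 640 // 8 == 80 and 640 // 16 == 40 and 640 // 32 == 20
    assert 80 * 80 + 40 * 40 + 20 * 20 == 8400   # n_anchors_all in [batch, n_anchors_all, 85]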
+ """ + + # backbone + out_features = self.backbone(input) + features = [out_features[f] for f in self.in_features] + [x2, x1, x0] = features + + fpn_out0 = self.lateral_conv0(x0) # 1024->512/32 + f_out0 = self.upsample(fpn_out0) # 512/16 + f_out0 = torch.cat([f_out0, x1], 1) # 512->1024/16 + f_out0 = self.C3_p4(f_out0) # 1024->512/16 + + fpn_out1 = self.reduce_conv1(f_out0) # 512->256/16 + f_out1 = self.upsample(fpn_out1) # 256/8 + f_out1 = torch.cat([f_out1, x2], 1) # 256->512/8 + pan_out2 = self.C3_p3(f_out1) # 512->256/8 + + p_out1 = self.bu_conv2(pan_out2) # 256->256/16 + p_out1 = torch.cat([p_out1, fpn_out1], 1) # 256->512/16 + pan_out1 = self.C3_n3(p_out1) # 512->512/16 + + p_out0 = self.bu_conv1(pan_out1) # 512->512/32 + p_out0 = torch.cat([p_out0, fpn_out0], 1) # 512->1024/32 + pan_out0 = self.C3_n4(p_out0) # 1024->1024/32 + + outputs = (pan_out2, pan_out1, pan_out0) + return outputs diff --git a/modelscope/models/cv/realtime_object_detection/yolox/models/yolox.py b/modelscope/models/cv/realtime_object_detection/yolox/models/yolox.py new file mode 100644 index 00000000..181c368b --- /dev/null +++ b/modelscope/models/cv/realtime_object_detection/yolox/models/yolox.py @@ -0,0 +1,33 @@ +# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX + +import torch.nn as nn + +from .yolo_head import YOLOXHead +from .yolo_pafpn import YOLOPAFPN + + +class YOLOX(nn.Module): + """ + YOLOX model module. The module list is defined by create_yolov3_modules function. + The network returns loss values from three YOLO layers during training + and detection results during test. + """ + + def __init__(self, backbone=None, head=None): + super(YOLOX, self).__init__() + if backbone is None: + backbone = YOLOPAFPN() + if head is None: + head = YOLOXHead(80) + + self.backbone = backbone + self.head = head + + def forward(self, x, targets=None): + fpn_outs = self.backbone(x) + if self.training: + raise NotImplementedError('Training is not supported yet!') + else: + outputs = self.head(fpn_outs) + + return outputs diff --git a/modelscope/models/cv/realtime_object_detection/yolox/utils/__init__.py b/modelscope/models/cv/realtime_object_detection/yolox/utils/__init__.py new file mode 100644 index 00000000..2c1ea489 --- /dev/null +++ b/modelscope/models/cv/realtime_object_detection/yolox/utils/__init__.py @@ -0,0 +1,5 @@ +# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX + +from .boxes import * # noqa + +__all__ = ['bboxes_iou', 'meshgrid', 'postprocess', 'xyxy2cxcywh', 'xyxy2xywh'] diff --git a/modelscope/models/cv/realtime_object_detection/yolox/utils/boxes.py b/modelscope/models/cv/realtime_object_detection/yolox/utils/boxes.py new file mode 100644 index 00000000..b29a3a04 --- /dev/null +++ b/modelscope/models/cv/realtime_object_detection/yolox/utils/boxes.py @@ -0,0 +1,107 @@ +# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX + +import torch +import torchvision + +_TORCH_VER = [int(x) for x in torch.__version__.split('.')[:2]] + + +def meshgrid(*tensors): + if _TORCH_VER >= [1, 10]: + return torch.meshgrid(*tensors, indexing='ij') + else: + return torch.meshgrid(*tensors) + + +def xyxy2xywh(bboxes): + bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 0] + bboxes[:, 3] = bboxes[:, 3] - bboxes[:, 1] + return bboxes + + +def xyxy2cxcywh(bboxes): + bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 0] + bboxes[:, 3] = bboxes[:, 3] - bboxes[:, 1] + bboxes[:, 0] = bboxes[:, 0] + bboxes[:, 2] * 
0.5 + bboxes[:, 1] = bboxes[:, 1] + bboxes[:, 3] * 0.5 + return bboxes + + +def postprocess(prediction, + num_classes, + conf_thre=0.7, + nms_thre=0.45, + class_agnostic=False): + box_corner = prediction.new(prediction.shape) + box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2 + box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2 + box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2 + box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2 + prediction[:, :, :4] = box_corner[:, :, :4] + + output = [None for _ in range(len(prediction))] + for i, image_pred in enumerate(prediction): + + # If none are remaining => process next image + if not image_pred.size(0): + continue + # Get score and class with highest confidence + class_conf, class_pred = torch.max( + image_pred[:, 5:5 + num_classes], 1, keepdim=True) + + conf_mask = image_pred[:, 4] * class_conf.squeeze() + conf_mask = (conf_mask >= conf_thre).squeeze() + # Detections ordered as (x1, y1, x2, y2, obj_conf, class_conf, class_pred) + detections = torch.cat( + (image_pred[:, :5], class_conf, class_pred.float()), 1) + detections = detections[conf_mask] + if not detections.size(0): + continue + + if class_agnostic: + nms_out_index = torchvision.ops.nms( + detections[:, :4], + detections[:, 4] * detections[:, 5], + nms_thre, + ) + else: + nms_out_index = torchvision.ops.batched_nms( + detections[:, :4], + detections[:, 4] * detections[:, 5], + detections[:, 6], + nms_thre, + ) + + detections = detections[nms_out_index] + if output[i] is None: + output[i] = detections + else: + output[i] = torch.cat((output[i], detections)) + + return output + + +def bboxes_iou(bboxes_a, bboxes_b, xyxy=True): + if bboxes_a.shape[1] != 4 or bboxes_b.shape[1] != 4: + raise IndexError + + if xyxy: + tl = torch.max(bboxes_a[:, None, :2], bboxes_b[:, :2]) + br = torch.min(bboxes_a[:, None, 2:], bboxes_b[:, 2:]) + area_a = torch.prod(bboxes_a[:, 2:] - bboxes_a[:, :2], 1) + area_b = torch.prod(bboxes_b[:, 2:] - bboxes_b[:, :2], 1) + else: + tl = torch.max( + (bboxes_a[:, None, :2] - bboxes_a[:, None, 2:] / 2), + (bboxes_b[:, :2] - bboxes_b[:, 2:] / 2), + ) + br = torch.min( + (bboxes_a[:, None, :2] + bboxes_a[:, None, 2:] / 2), + (bboxes_b[:, :2] + bboxes_b[:, 2:] / 2), + ) + + area_a = torch.prod(bboxes_a[:, 2:], 1) + area_b = torch.prod(bboxes_b[:, 2:], 1) + en = (tl < br).type(tl.type()).prod(dim=2) + area_i = torch.prod(br - tl, 2) * en # * ((tl < br).all()) + return area_i / (area_a[:, None] + area_b - area_i) diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py index b1a513e5..2c062226 100644 --- a/modelscope/pipelines/cv/__init__.py +++ b/modelscope/pipelines/cv/__init__.py @@ -32,6 +32,7 @@ if TYPE_CHECKING: from .image_to_image_generate_pipeline import Image2ImageGenerationPipeline from .image_to_image_translation_pipeline import Image2ImageTranslationPipeline from .product_retrieval_embedding_pipeline import ProductRetrievalEmbeddingPipeline + from .realtime_object_detection_pipeline import RealtimeObjectDetectionPipeline from .live_category_pipeline import LiveCategoryPipeline from .ocr_detection_pipeline import OCRDetectionPipeline from .ocr_recognition_pipeline import OCRRecognitionPipeline @@ -75,6 +76,8 @@ else: ['Image2ImageTranslationPipeline'], 'product_retrieval_embedding_pipeline': ['ProductRetrievalEmbeddingPipeline'], + 'realtime_object_detection_pipeline': + ['RealtimeObjectDetectionPipeline'], 'live_category_pipeline': ['LiveCategoryPipeline'], 
'image_to_image_generation_pipeline': ['Image2ImageGenerationPipeline'], diff --git a/modelscope/pipelines/cv/realtime_object_detection_pipeline.py b/modelscope/pipelines/cv/realtime_object_detection_pipeline.py new file mode 100644 index 00000000..629720d1 --- /dev/null +++ b/modelscope/pipelines/cv/realtime_object_detection_pipeline.py @@ -0,0 +1,50 @@ +import os.path as osp +from typing import Any, Dict, List, Union + +import cv2 +import json +import numpy as np +import torch +from PIL import Image +from torchvision import transforms + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.realtime_object_detection import RealtimeDetector +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.pipelines.base import Input, Model, Pipeline, Tensor +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import load_image +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.image_object_detection, + module_name=Pipelines.realtime_object_detection) +class RealtimeObjectDetectionPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + super().__init__(model=model, **kwargs) + self.model = RealtimeDetector(model) + + def preprocess(self, input: Input) -> Dict[Tensor, Union[str, np.ndarray]]: + output = self.model.preprocess(input) + return {'pre_output': output} + + def forward(self, input: Tensor) -> Dict[Tensor, Dict[str, np.ndarray]]: + pre_output = input['pre_output'] + forward_output = self.model(pre_output) + return {'forward_output': forward_output} + + def postprocess(self, input: Dict[Tensor, Dict[str, np.ndarray]], + **kwargs) -> str: + forward_output = input['forward_output'] + bboxes, scores, labels = forward_output + return { + OutputKeys.BOXES: bboxes, + OutputKeys.SCORES: scores, + OutputKeys.LABELS: labels, + } diff --git a/modelscope/utils/cv/image_utils.py b/modelscope/utils/cv/image_utils.py index 0ad0ef8f..ea1d95b5 100644 --- a/modelscope/utils/cv/image_utils.py +++ b/modelscope/utils/cv/image_utils.py @@ -70,6 +70,13 @@ def draw_box(image, box): (int(box[1][0]), int(box[1][1])), (0, 0, 255), 2) +def realtime_object_detection_bbox_vis(image, bboxes): + for bbox in bboxes: + cv2.rectangle(image, (bbox[0], bbox[1]), (bbox[2], bbox[3]), + (255, 0, 0), 2) + return image + + def draw_keypoints(output, original_image): poses = np.array(output[OutputKeys.POSES]) scores = np.array(output[OutputKeys.SCORES]) diff --git a/tests/pipelines/test_realtime_object_detection.py b/tests/pipelines/test_realtime_object_detection.py new file mode 100644 index 00000000..03ddacf4 --- /dev/null +++ b/tests/pipelines/test_realtime_object_detection.py @@ -0,0 +1,52 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import unittest + +import cv2 +import numpy as np + +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.pipelines.base import Pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.cv.image_utils import realtime_object_detection_bbox_vis +from modelscope.utils.test_utils import test_level + + +class RealtimeObjectDetectionTest(unittest.TestCase): + + def setUp(self) -> None: + self.model_id = 'damo/cv_cspnet_image-object-detection_yolox' + self.model_nano_id = 'damo/cv_cspnet_image-object-detection_yolox_nano_coco' + self.test_image = 'data/test/images/keypoints_detect/000000438862.jpg' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + realtime_object_detection = pipeline( + Tasks.image_object_detection, model=self.model_id) + + image = cv2.imread(self.test_image) + result = realtime_object_detection(image) + if result: + bboxes = result[OutputKeys.BOXES].astype(int) + image = realtime_object_detection_bbox_vis(image, bboxes) + cv2.imwrite('rt_obj_out.jpg', image) + else: + raise ValueError('process error') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_nano(self): + realtime_object_detection = pipeline( + Tasks.image_object_detection, model=self.model_nano_id) + + image = cv2.imread(self.test_image) + result = realtime_object_detection(image) + if result: + bboxes = result[OutputKeys.BOXES].astype(int) + image = realtime_object_detection_bbox_vis(image, bboxes) + cv2.imwrite('rtnano_obj_out.jpg', image) + else: + raise ValueError('process error') + + +if __name__ == '__main__': + unittest.main()
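For completeness, the output contract these tests exercise, following RealtimeDetector.postprocess and the detections layout built in postprocess() in this patch (a summary sketch; `result` and `image` are as in the tests above):

    # result[OutputKeys.BOXES]  -> (N, 4) float [x1, y1, x2, y2], rescaled by 1/ratio to the original image
    # result[OutputKeys.SCORES] -> (N,)   class confidences (objectness only affects filtering and NMS)
    # result[OutputKeys.LABELS] -> (N,)   integer class indices (class set depends on the checkpoint)
    bboxes = result[OutputKeys.BOXES].astype(int)
    image = realtime_object_detection_bbox_vis(image, bboxes)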