
[to #42322933] feat(RealtimeObjectDetection): add realtime detection pipeline

Add a realtime object detection pipeline.
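
For context, a minimal usage sketch of the new pipeline, mirroring tests/pipelines/test_realtime_object_detection.py added in this change (model id and image path are taken from that test):

import cv2

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.cv.image_utils import realtime_object_detection_bbox_vis

# model id used in the unit test added by this commit
detector = pipeline(
    Tasks.image_object_detection,
    model='damo/cv_cspnet_image-object-detection_yolox')

image = cv2.imread('data/test/images/keypoints_detect/000000438862.jpg')
result = detector(image)

# boxes, scores and labels are returned under OutputKeys
bboxes = result[OutputKeys.BOXES].astype(int)
cv2.imwrite('rt_obj_out.jpg', realtime_object_detection_bbox_vis(image, bboxes))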
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9788299
master
leyuan.hjy yingda.chen 3 years ago
parent commit 285192850d
28 changed files with 1410 additions and 1 deletions
  1. modelscope/metainfo.py (+2 -0)
  2. modelscope/models/cv/__init__.py (+1 -1)
  3. modelscope/models/cv/realtime_object_detection/__init__.py (+21 -0)
  4. modelscope/models/cv/realtime_object_detection/realtime_detector.py (+85 -0)
  5. modelscope/models/cv/realtime_object_detection/yolox/__init__.py (+0 -0)
  6. modelscope/models/cv/realtime_object_detection/yolox/data/__init__.py (+0 -0)
  7. modelscope/models/cv/realtime_object_detection/yolox/data/data_augment.py (+69 -0)
  8. modelscope/models/cv/realtime_object_detection/yolox/exp/__init__.py (+5 -0)
  9. modelscope/models/cv/realtime_object_detection/yolox/exp/base_exp.py (+12 -0)
  10. modelscope/models/cv/realtime_object_detection/yolox/exp/build.py (+18 -0)
  11. modelscope/models/cv/realtime_object_detection/yolox/exp/default/__init__.py (+5 -0)
  12. modelscope/models/cv/realtime_object_detection/yolox/exp/default/yolox_nano.py (+46 -0)
  13. modelscope/models/cv/realtime_object_detection/yolox/exp/default/yolox_s.py (+13 -0)
  14. modelscope/models/cv/realtime_object_detection/yolox/exp/default/yolox_tiny.py (+20 -0)
  15. modelscope/models/cv/realtime_object_detection/yolox/exp/yolox_base.py (+59 -0)
  16. modelscope/models/cv/realtime_object_detection/yolox/models/__init__.py (+7 -0)
  17. modelscope/models/cv/realtime_object_detection/yolox/models/darknet.py (+189 -0)
  18. modelscope/models/cv/realtime_object_detection/yolox/models/network_blocks.py (+213 -0)
  19. modelscope/models/cv/realtime_object_detection/yolox/models/yolo_fpn.py (+80 -0)
  20. modelscope/models/cv/realtime_object_detection/yolox/models/yolo_head.py (+182 -0)
  21. modelscope/models/cv/realtime_object_detection/yolox/models/yolo_pafpn.py (+126 -0)
  22. modelscope/models/cv/realtime_object_detection/yolox/models/yolox.py (+33 -0)
  23. modelscope/models/cv/realtime_object_detection/yolox/utils/__init__.py (+5 -0)
  24. modelscope/models/cv/realtime_object_detection/yolox/utils/boxes.py (+107 -0)
  25. modelscope/pipelines/cv/__init__.py (+3 -0)
  26. modelscope/pipelines/cv/realtime_object_detection_pipeline.py (+50 -0)
  27. modelscope/utils/cv/image_utils.py (+7 -0)
  28. tests/pipelines/test_realtime_object_detection.py (+52 -0)

modelscope/metainfo.py (+2 -0)

@@ -11,6 +11,7 @@ class Models(object):
"""
# vision models
detection = 'detection'
realtime_object_detection = 'realtime-object-detection'
scrfd = 'scrfd'
classification_model = 'ClassificationModel'
nafnet = 'nafnet'
@@ -111,6 +112,7 @@ class Pipelines(object):
image_super_resolution = 'rrdb-image-super-resolution'
face_image_generation = 'gan-face-image-generation'
product_retrieval_embedding = 'resnet50-product-retrieval-embedding'
realtime_object_detection = 'cspnet_realtime-object-detection_yolox'
face_recognition = 'ir101-face-recognition-cfglint'
image_instance_segmentation = 'cascade-mask-rcnn-swin-image-instance-segmentation'
image2image_translation = 'image-to-image-translation'


modelscope/models/cv/__init__.py (+1 -1)

@@ -7,5 +7,5 @@ from . import (action_recognition, animal_recognition, body_2d_keypoints,
image_reid_person, image_semantic_segmentation,
image_to_image_generation, image_to_image_translation,
object_detection, product_retrieval_embedding,
salient_detection, super_resolution,
realtime_object_detection, salient_detection, super_resolution,
video_single_object_tracking, video_summarization, virual_tryon)

modelscope/models/cv/realtime_object_detection/__init__.py (+21 -0)

@@ -0,0 +1,21 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
from .realtime_detector import RealtimeDetector
else:
_import_structure = {
'realtime_detector': ['RealtimeDetector'],
}

import sys

sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)

modelscope/models/cv/realtime_object_detection/realtime_detector.py (+85 -0)

@@ -0,0 +1,85 @@
import argparse
import logging as logger
import os
import os.path as osp
import time

import cv2
import json
import torch

from modelscope.metainfo import Models
from modelscope.models.base.base_torch_model import TorchModel
from modelscope.models.builder import MODELS
from modelscope.preprocessors import LoadImage
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from .yolox.data.data_augment import ValTransform
from .yolox.exp import get_exp_by_name
from .yolox.utils import postprocess


@MODELS.register_module(
group_key=Tasks.image_object_detection,
module_name=Models.realtime_object_detection)
class RealtimeDetector(TorchModel):

def __init__(self, model_dir: str, *args, **kwargs):
super().__init__(model_dir, *args, **kwargs)
self.config = Config.from_file(
os.path.join(self.model_dir, ModelFile.CONFIGURATION))

# model type
self.exp = get_exp_by_name(self.config.model_type)

# build model
self.model = self.exp.get_model()
model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE)
ckpt = torch.load(model_path, map_location='cpu')

# load the model state dict
self.model.load_state_dict(ckpt['model'])
self.model.eval()

# params setting
self.exp.num_classes = self.config.num_classes
self.confthre = self.config.conf_thr
self.num_classes = self.exp.num_classes
self.nmsthre = self.exp.nmsthre
self.test_size = self.exp.test_size
self.preproc = ValTransform(legacy=False)

def inference(self, img):
with torch.no_grad():
outputs = self.model(img)
return outputs

def forward(self, inputs):
return self.inference(inputs)

def preprocess(self, img):
img = LoadImage.convert_to_ndarray(img)
height, width = img.shape[:2]
self.ratio = min(self.test_size[0] / img.shape[0],
self.test_size[1] / img.shape[1])

img, _ = self.preproc(img, None, self.test_size)
img = torch.from_numpy(img).unsqueeze(0)
img = img.float()

return img

def postprocess(self, input):
outputs = postprocess(
input,
self.num_classes,
self.confthre,
self.nmsthre,
class_agnostic=True)

# single-image inference: outputs[0] is None when no box survives the
# confidence and NMS thresholds
if len(outputs) == 1 and outputs[0] is not None:
bboxes = outputs[0][:, 0:4].cpu().numpy() / self.ratio
scores = outputs[0][:, 5].cpu().numpy()
labels = outputs[0][:, 6].cpu().int().numpy()
else:
bboxes, scores, labels = None, None, None

return bboxes, scores, labels
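
A worked example of the ratio bookkeeping between preprocess() and postprocess(), assuming an illustrative 1280x720 frame and the default test_size of (640, 640):

# ratio = min(test_size[0] / h, test_size[1] / w)
ratio = min(640 / 720, 640 / 1280)                      # 0.5
# the frame is resized to 640x360, padded to 640x640 and fed to the network;
# boxes predicted in that 640x640 space are divided by the ratio to map them
# back to the original 1280x720 pixel coordinates
box_model_space = [100.0, 80.0, 300.0, 240.0]
box_image_space = [v / ratio for v in box_model_space]  # [200.0, 160.0, 600.0, 480.0]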

modelscope/models/cv/realtime_object_detection/yolox/__init__.py (+0 -0)


modelscope/models/cv/realtime_object_detection/yolox/data/__init__.py (+0 -0)


modelscope/models/cv/realtime_object_detection/yolox/data/data_augment.py (+69 -0)

@@ -0,0 +1,69 @@
# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX
"""
Data augmentation functionality. Passed as callable transformations to
Dataset classes.

The data augmentation procedures were interpreted from @weiliu89's SSD paper
http://arxiv.org/abs/1512.02325
"""

import math
import random

import cv2
import numpy as np

from ..utils import xyxy2cxcywh


def preproc(img, input_size, swap=(2, 0, 1)):
if len(img.shape) == 3:
padded_img = np.ones(
(input_size[0], input_size[1], 3), dtype=np.uint8) * 114
else:
padded_img = np.ones(input_size, dtype=np.uint8) * 114

r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
resized_img = cv2.resize(
img,
(int(img.shape[1] * r), int(img.shape[0] * r)),
interpolation=cv2.INTER_LINEAR,
).astype(np.uint8)
padded_img[:int(img.shape[0] * r), :int(img.shape[1] * r)] = resized_img

padded_img = padded_img.transpose(swap)
padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
return padded_img, r


class ValTransform:
"""
Defines the transformations that should be applied to test PIL image
for input into the network

dimension -> tensorize -> color adj

Arguments:
resize (int): input dimension to SSD
rgb_means ((int,int,int)): average RGB of the dataset
(104,117,123)
swap ((int,int,int)): final order of channels

Returns:
transform (transform) : callable transform to be applied to test/val
data
"""

def __init__(self, swap=(2, 0, 1), legacy=False):
self.swap = swap
self.legacy = legacy

# assume input is cv2 img for now
def __call__(self, img, res, input_size):
img, _ = preproc(img, input_size, self.swap)
if self.legacy:
img = img[::-1, :, :].copy()
img /= 255.0
img -= np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
img /= np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)
return img, np.zeros((1, 5))
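
A shape-level sketch of what preproc and ValTransform return, assuming numpy is available and this module is importable:

import numpy as np

frame = np.zeros((720, 1280, 3), dtype=np.uint8)        # hypothetical 720p BGR frame

padded, r = preproc(frame, (640, 640))
# padded.shape == (3, 640, 640), dtype float32, r == 0.5;
# the area not covered by the resized image keeps the pad value 114

transform = ValTransform(legacy=False)
img, target = transform(frame, None, (640, 640))
# with legacy=False no normalization is applied, so values stay in [0, 255];
# target is the placeholder np.zeros((1, 5)) since no labels are needed at test time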

modelscope/models/cv/realtime_object_detection/yolox/exp/__init__.py (+5 -0)

@@ -0,0 +1,5 @@
# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX

from .base_exp import BaseExp
from .build import get_exp_by_name
from .yolox_base import Exp

modelscope/models/cv/realtime_object_detection/yolox/exp/base_exp.py (+12 -0)

@@ -0,0 +1,12 @@
# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX

from abc import ABCMeta, abstractmethod

from torch.nn import Module


class BaseExp(metaclass=ABCMeta):

@abstractmethod
def get_model(self) -> Module:
pass

modelscope/models/cv/realtime_object_detection/yolox/exp/build.py (+18 -0)

@@ -0,0 +1,18 @@
# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX

import os
import sys


def get_exp_by_name(exp_name):
exp = exp_name.replace('-',
'_') # convert string like "yolox-s" to "yolox_s"
if exp == 'yolox_s':
from .default import YoloXSExp as YoloXExp
elif exp == 'yolox_nano':
from .default import YoloXNanoExp as YoloXExp
elif exp == 'yolox_tiny':
from .default import YoloXTinyExp as YoloXExp
else:
raise ValueError('Unsupported exp name: {}'.format(exp_name))
return YoloXExp()
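
For example, hyphenated and underscored names resolve to the same exp:

exp = get_exp_by_name('yolox-nano')   # same as 'yolox_nano'
print(type(exp).__name__)             # YoloXNanoExp
print(exp.test_size)                  # (416, 416)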

modelscope/models/cv/realtime_object_detection/yolox/exp/default/__init__.py (+5 -0)

@@ -0,0 +1,5 @@
# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX

from .yolox_nano import YoloXNanoExp
from .yolox_s import YoloXSExp
from .yolox_tiny import YoloXTinyExp

modelscope/models/cv/realtime_object_detection/yolox/exp/default/yolox_nano.py (+46 -0)

@@ -0,0 +1,46 @@
# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX

import os

import torch.nn as nn

from ..yolox_base import Exp as YoloXExp


class YoloXNanoExp(YoloXExp):

def __init__(self):
super(YoloXNanoExp, self).__init__()
self.depth = 0.33
self.width = 0.25
self.input_size = (416, 416)
self.test_size = (416, 416)

def get_model(self, sublinear=False):

def init_yolo(M):
for m in M.modules():
if isinstance(m, nn.BatchNorm2d):
m.eps = 1e-3
m.momentum = 0.03

if 'model' not in self.__dict__:
from ...models import YOLOX, YOLOPAFPN, YOLOXHead
in_channels = [256, 512, 1024]
# The NANO model uses depthwise=True, which is the main difference.
backbone = YOLOPAFPN(
self.depth,
self.width,
in_channels=in_channels,
act=self.act,
depthwise=True,
)
head = YOLOXHead(
self.num_classes,
self.width,
in_channels=in_channels,
act=self.act,
depthwise=True)
self.model = YOLOX(backbone, head)

return self.model

modelscope/models/cv/realtime_object_detection/yolox/exp/default/yolox_s.py (+13 -0)

@@ -0,0 +1,13 @@
# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX

import os

from ..yolox_base import Exp as YoloXExp


class YoloXSExp(YoloXExp):

def __init__(self):
super(YoloXSExp, self).__init__()
self.depth = 0.33
self.width = 0.50

modelscope/models/cv/realtime_object_detection/yolox/exp/default/yolox_tiny.py (+20 -0)

@@ -0,0 +1,20 @@
# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX

import os

from ..yolox_base import Exp as YoloXExp


class YoloXTinyExp(YoloXExp):

def __init__(self):
super(YoloXTinyExp, self).__init__()
self.depth = 0.33
self.width = 0.375
self.input_size = (416, 416)
self.mosaic_scale = (0.5, 1.5)
self.random_size = (10, 20)
self.test_size = (416, 416)
self.exp_name = os.path.split(
os.path.realpath(__file__))[1].split('.')[0]
self.enable_mixup = False

modelscope/models/cv/realtime_object_detection/yolox/exp/yolox_base.py (+59 -0)

@@ -0,0 +1,59 @@
# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX

import os
import random

import torch
import torch.distributed as dist
import torch.nn as nn

from .base_exp import BaseExp


class Exp(BaseExp):

def __init__(self):
super().__init__()

# ---------------- model config ---------------- #
# detect classes number of model
self.num_classes = 80
# factor of model depth
self.depth = 1.00
# factor of model width
self.width = 1.00
# activation name; for example, passing "relu" replaces the default "silu" activation.
self.act = 'silu'
# ----------------- testing config ------------------ #
# output image size during evaluation/test
self.test_size = (640, 640)
# confidence threshold during evaluation/test,
# boxes whose scores are less than test_conf will be filtered
self.test_conf = 0.01
# nms threshold
self.nmsthre = 0.65

def get_model(self):
from ..models import YOLOX, YOLOPAFPN, YOLOXHead

def init_yolo(M):
for m in M.modules():
if isinstance(m, nn.BatchNorm2d):
m.eps = 1e-3
m.momentum = 0.03

if getattr(self, 'model', None) is None:
in_channels = [256, 512, 1024]
backbone = YOLOPAFPN(
self.depth, self.width, in_channels=in_channels, act=self.act)
head = YOLOXHead(
self.num_classes,
self.width,
in_channels=in_channels,
act=self.act)
self.model = YOLOX(backbone, head)

self.model.apply(init_yolo)
self.model.head.initialize_biases(1e-2)
self.model.train()
return self.model
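
A minimal smoke test of the base Exp, assuming torch is installed (get_model() leaves the model in train mode, so switch to eval before running it):

import torch

exp = Exp()
model = exp.get_model().eval()

with torch.no_grad():
    out = model(torch.zeros(1, 3, 640, 640))

# 80*80 + 40*40 + 20*20 = 8400 anchor points at strides 8/16/32, each with
# 4 box values, 1 objectness score and 80 class scores -> (1, 8400, 85)
print(out.shape)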

modelscope/models/cv/realtime_object_detection/yolox/models/__init__.py (+7 -0)

@@ -0,0 +1,7 @@
# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX

from .darknet import CSPDarknet, Darknet
from .yolo_fpn import YOLOFPN
from .yolo_head import YOLOXHead
from .yolo_pafpn import YOLOPAFPN
from .yolox import YOLOX

modelscope/models/cv/realtime_object_detection/yolox/models/darknet.py (+189 -0)

@@ -0,0 +1,189 @@
# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX

from torch import nn

from .network_blocks import (BaseConv, CSPLayer, DWConv, Focus, ResLayer,
SPPBottleneck)


class Darknet(nn.Module):
# number of blocks from dark2 to dark5.
depth2blocks = {21: [1, 2, 2, 1], 53: [2, 8, 8, 4]}

def __init__(
self,
depth,
in_channels=3,
stem_out_channels=32,
out_features=('dark3', 'dark4', 'dark5'),
):
"""
Args:
depth (int): depth of the darknet used in the model; typically 21 or 53.
in_channels (int): number of input channels, e.g. 3 for an RGB image.
stem_out_channels (int): number of output channels of the darknet stem;
it determines the channels of darknet layer2 to layer5.
out_features (Tuple[str]): names of the desired output layers.
"""
super().__init__()
assert out_features, 'please provide output features of Darknet'
self.out_features = out_features
self.stem = nn.Sequential(
BaseConv(
in_channels, stem_out_channels, ksize=3, stride=1,
act='lrelu'),
*self.make_group_layer(stem_out_channels, num_blocks=1, stride=2),
)
in_channels = stem_out_channels * 2 # 64

num_blocks = Darknet.depth2blocks[depth]
# create darknet with `stem_out_channels` and `num_blocks` layers.
# to keep the model structure explicit, we do not use a `for` loop here.
self.dark2 = nn.Sequential(
*self.make_group_layer(in_channels, num_blocks[0], stride=2))
in_channels *= 2 # 128
self.dark3 = nn.Sequential(
*self.make_group_layer(in_channels, num_blocks[1], stride=2))
in_channels *= 2 # 256
self.dark4 = nn.Sequential(
*self.make_group_layer(in_channels, num_blocks[2], stride=2))
in_channels *= 2 # 512

self.dark5 = nn.Sequential(
*self.make_group_layer(in_channels, num_blocks[3], stride=2),
*self.make_spp_block([in_channels, in_channels * 2],
in_channels * 2),
)

def make_group_layer(self,
in_channels: int,
num_blocks: int,
stride: int = 1):
'starts with a conv layer, followed by `num_blocks` `ResLayer` blocks'
return [
BaseConv(
in_channels,
in_channels * 2,
ksize=3,
stride=stride,
act='lrelu'),
*[(ResLayer(in_channels * 2)) for _ in range(num_blocks)],
]

def make_spp_block(self, filters_list, in_filters):
m = nn.Sequential(*[
BaseConv(in_filters, filters_list[0], 1, stride=1, act='lrelu'),
BaseConv(
filters_list[0], filters_list[1], 3, stride=1, act='lrelu'),
SPPBottleneck(
in_channels=filters_list[1],
out_channels=filters_list[0],
activation='lrelu',
),
BaseConv(
filters_list[0], filters_list[1], 3, stride=1, act='lrelu'),
BaseConv(
filters_list[1], filters_list[0], 1, stride=1, act='lrelu'),
])
return m

def forward(self, x):
outputs = {}
x = self.stem(x)
outputs['stem'] = x
x = self.dark2(x)
outputs['dark2'] = x
x = self.dark3(x)
outputs['dark3'] = x
x = self.dark4(x)
outputs['dark4'] = x
x = self.dark5(x)
outputs['dark5'] = x
return {k: v for k, v in outputs.items() if k in self.out_features}


class CSPDarknet(nn.Module):

def __init__(
self,
dep_mul,
wid_mul,
out_features=('dark3', 'dark4', 'dark5'),
depthwise=False,
act='silu',
):
super().__init__()
assert out_features, 'please provide output features of Darknet'
self.out_features = out_features
Conv = DWConv if depthwise else BaseConv

base_channels = int(wid_mul * 64) # 64
base_depth = max(round(dep_mul * 3), 1) # 3

# stem
self.stem = Focus(3, base_channels, ksize=3, act=act)

# dark2
self.dark2 = nn.Sequential(
Conv(base_channels, base_channels * 2, 3, 2, act=act),
CSPLayer(
base_channels * 2,
base_channels * 2,
n=base_depth,
depthwise=depthwise,
act=act,
),
)

# dark3
self.dark3 = nn.Sequential(
Conv(base_channels * 2, base_channels * 4, 3, 2, act=act),
CSPLayer(
base_channels * 4,
base_channels * 4,
n=base_depth * 3,
depthwise=depthwise,
act=act,
),
)

# dark4
self.dark4 = nn.Sequential(
Conv(base_channels * 4, base_channels * 8, 3, 2, act=act),
CSPLayer(
base_channels * 8,
base_channels * 8,
n=base_depth * 3,
depthwise=depthwise,
act=act,
),
)

# dark5
self.dark5 = nn.Sequential(
Conv(base_channels * 8, base_channels * 16, 3, 2, act=act),
SPPBottleneck(
base_channels * 16, base_channels * 16, activation=act),
CSPLayer(
base_channels * 16,
base_channels * 16,
n=base_depth,
shortcut=False,
depthwise=depthwise,
act=act,
),
)

def forward(self, x):
outputs = {}
x = self.stem(x)
outputs['stem'] = x
x = self.dark2(x)
outputs['dark2'] = x
x = self.dark3(x)
outputs['dark3'] = x
x = self.dark4(x)
outputs['dark4'] = x
x = self.dark5(x)
outputs['dark5'] = x
return {k: v for k, v in outputs.items() if k in self.out_features}
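
For orientation, the three features that the neck consumes from CSPDarknet with wid_mul=1.0 and a 640x640 input (a rough shape check, assuming torch):

import torch

backbone = CSPDarknet(dep_mul=1.0, wid_mul=1.0).eval()
feats = backbone(torch.zeros(1, 3, 640, 640))
for name, f in feats.items():
    print(name, tuple(f.shape))
# dark3 (1, 256, 80, 80)    stride 8
# dark4 (1, 512, 40, 40)    stride 16
# dark5 (1, 1024, 20, 20)   stride 32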

modelscope/models/cv/realtime_object_detection/yolox/models/network_blocks.py (+213 -0)

@@ -0,0 +1,213 @@
# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX

import torch
import torch.nn as nn


def get_activation(name='silu', inplace=True):
if name == 'silu':
module = nn.SiLU(inplace=inplace)
else:
raise AttributeError('Unsupported act type: {}'.format(name))
return module


class BaseConv(nn.Module):
"""A Conv2d -> Batchnorm -> silu/leaky relu block"""

def __init__(self,
in_channels,
out_channels,
ksize,
stride,
groups=1,
bias=False,
act='silu'):
super(BaseConv, self).__init__()
# same padding
pad = (ksize - 1) // 2
self.conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size=ksize,
stride=stride,
padding=pad,
groups=groups,
bias=bias,
)
self.bn = nn.BatchNorm2d(out_channels)
self.act = get_activation(act, inplace=True)

def forward(self, x):
return self.act(self.bn(self.conv(x)))

def fuseforward(self, x):
return self.act(self.conv(x))


class DWConv(nn.Module):
"""Depthwise Conv + Conv"""

def __init__(self, in_channels, out_channels, ksize, stride=1, act='silu'):
super(DWConv, self).__init__()
self.dconv = BaseConv(
in_channels,
in_channels,
ksize=ksize,
stride=stride,
groups=in_channels,
act=act,
)
self.pconv = BaseConv(
in_channels, out_channels, ksize=1, stride=1, groups=1, act=act)

def forward(self, x):
x = self.dconv(x)
return self.pconv(x)


class Bottleneck(nn.Module):
# Standard bottleneck
def __init__(
self,
in_channels,
out_channels,
shortcut=True,
expansion=0.5,
depthwise=False,
act='silu',
):
super().__init__()
hidden_channels = int(out_channels * expansion)
Conv = DWConv if depthwise else BaseConv
self.conv1 = BaseConv(
in_channels, hidden_channels, 1, stride=1, act=act)
self.conv2 = Conv(hidden_channels, out_channels, 3, stride=1, act=act)
self.use_add = shortcut and in_channels == out_channels

def forward(self, x):
y = self.conv2(self.conv1(x))
if self.use_add:
y = y + x
return y


class ResLayer(nn.Module):
'Residual layer with `in_channels` inputs.'

def __init__(self, in_channels: int):
super().__init__()
mid_channels = in_channels // 2
self.layer1 = BaseConv(
in_channels, mid_channels, ksize=1, stride=1, act='lrelu')
self.layer2 = BaseConv(
mid_channels, in_channels, ksize=3, stride=1, act='lrelu')

def forward(self, x):
out = self.layer2(self.layer1(x))
return x + out


class SPPBottleneck(nn.Module):
"""Spatial pyramid pooling layer used in YOLOv3-SPP"""

def __init__(self,
in_channels,
out_channels,
kernel_sizes=(5, 9, 13),
activation='silu'):
super().__init__()
hidden_channels = in_channels // 2
self.conv1 = BaseConv(
in_channels, hidden_channels, 1, stride=1, act=activation)
self.m = nn.ModuleList([
nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2)
for ks in kernel_sizes
])
conv2_channels = hidden_channels * (len(kernel_sizes) + 1)
self.conv2 = BaseConv(
conv2_channels, out_channels, 1, stride=1, act=activation)

def forward(self, x):
x = self.conv1(x)
x = torch.cat([x] + [m(x) for m in self.m], dim=1)
x = self.conv2(x)
return x


class CSPLayer(nn.Module):
"""C3 in yolov5, CSP Bottleneck with 3 convolutions"""

def __init__(
self,
in_channels,
out_channels,
n=1,
shortcut=True,
expansion=0.5,
depthwise=False,
act='silu',
):
"""
Args:
in_channels (int): input channels.
out_channels (int): output channels.
n (int): number of Bottlenecks. Default value: 1.
"""
# ch_in, ch_out, number, shortcut, groups, expansion
super().__init__()
hidden_channels = int(out_channels * expansion) # hidden channels
self.conv1 = BaseConv(
in_channels, hidden_channels, 1, stride=1, act=act)
self.conv2 = BaseConv(
in_channels, hidden_channels, 1, stride=1, act=act)
self.conv3 = BaseConv(
2 * hidden_channels, out_channels, 1, stride=1, act=act)
module_list = [
Bottleneck(
hidden_channels,
hidden_channels,
shortcut,
1.0,
depthwise,
act=act) for _ in range(n)
]
self.m = nn.Sequential(*module_list)

def forward(self, x):
x_1 = self.conv1(x)
x_2 = self.conv2(x)
x_1 = self.m(x_1)
x = torch.cat((x_1, x_2), dim=1)
return self.conv3(x)


class Focus(nn.Module):
"""Focus width and height information into channel space."""

def __init__(self,
in_channels,
out_channels,
ksize=1,
stride=1,
act='silu'):
super().__init__()
self.conv = BaseConv(
in_channels * 4, out_channels, ksize, stride, act=act)

def forward(self, x):
# shape of x (b,c,w,h) -> y(b,4c,w/2,h/2)
patch_top_left = x[..., ::2, ::2]
patch_top_right = x[..., ::2, 1::2]
patch_bot_left = x[..., 1::2, ::2]
patch_bot_right = x[..., 1::2, 1::2]
x = torch.cat(
(
patch_top_left,
patch_bot_left,
patch_top_right,
patch_bot_right,
),
dim=1,
)
return self.conv(x)
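
The Focus stem is a space-to-depth rearrangement followed by a conv; a quick shape check (assuming torch):

import torch

focus = Focus(in_channels=3, out_channels=64, ksize=3).eval()
x = torch.zeros(1, 3, 640, 640)
# the four pixel-interleaved patches are concatenated on the channel axis,
# so the conv sees (1, 12, 320, 320) and produces (1, 64, 320, 320)
print(focus(x).shape)   # torch.Size([1, 64, 320, 320])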

modelscope/models/cv/realtime_object_detection/yolox/models/yolo_fpn.py (+80 -0)

@@ -0,0 +1,80 @@
# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX

import torch
import torch.nn as nn

from .darknet import Darknet
from .network_blocks import BaseConv


class YOLOFPN(nn.Module):
"""
YOLOFPN module. Darknet 53 is the default backbone of this model.
"""

def __init__(
self,
depth=53,
in_features=['dark3', 'dark4', 'dark5'],
):
super(YOLOFPN, self).__init__()

self.backbone = Darknet(depth)
self.in_features = in_features

# out 1
self.out1_cbl = self._make_cbl(512, 256, 1)
self.out1 = self._make_embedding([256, 512], 512 + 256)

# out 2
self.out2_cbl = self._make_cbl(256, 128, 1)
self.out2 = self._make_embedding([128, 256], 256 + 128)

# upsample
self.upsample = nn.Upsample(scale_factor=2, mode='nearest')

def _make_cbl(self, _in, _out, ks):
return BaseConv(_in, _out, ks, stride=1, act='lrelu')

def _make_embedding(self, filters_list, in_filters):
m = nn.Sequential(*[
self._make_cbl(in_filters, filters_list[0], 1),
self._make_cbl(filters_list[0], filters_list[1], 3),
self._make_cbl(filters_list[1], filters_list[0], 1),
self._make_cbl(filters_list[0], filters_list[1], 3),
self._make_cbl(filters_list[1], filters_list[0], 1),
])
return m

def load_pretrained_model(self, filename='./weights/darknet53.mix.pth'):
with open(filename, 'rb') as f:
state_dict = torch.load(f, map_location='cpu')
print('loading pretrained weights...')
self.backbone.load_state_dict(state_dict)

def forward(self, inputs):
"""
Args:
inputs (Tensor): input image.

Returns:
Tuple[Tensor]: FPN output features.
"""
# backbone
out_features = self.backbone(inputs)
x2, x1, x0 = [out_features[f] for f in self.in_features]

# yolo branch 1
x1_in = self.out1_cbl(x0)
x1_in = self.upsample(x1_in)
x1_in = torch.cat([x1_in, x1], 1)
out_dark4 = self.out1(x1_in)

# yolo branch 2
x2_in = self.out2_cbl(out_dark4)
x2_in = self.upsample(x2_in)
x2_in = torch.cat([x2_in, x2], 1)
out_dark3 = self.out2(x2_in)

outputs = (out_dark3, out_dark4, x0)
return outputs

modelscope/models/cv/realtime_object_detection/yolox/models/yolo_head.py (+182 -0)

@@ -0,0 +1,182 @@
# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..utils import bboxes_iou, meshgrid
from .network_blocks import BaseConv, DWConv


class YOLOXHead(nn.Module):

def __init__(
self,
num_classes,
width=1.0,
strides=[8, 16, 32],
in_channels=[256, 512, 1024],
act='silu',
depthwise=False,
):
"""
Args:
act (str): activation type of conv. Default value: "silu".
depthwise (bool): whether to apply depthwise conv in the conv branch. Default value: False.
"""
super(YOLOXHead, self).__init__()

self.n_anchors = 1
self.num_classes = num_classes
self.decode_in_inference = True # for deploy, set to False

self.cls_convs = nn.ModuleList()
self.reg_convs = nn.ModuleList()
self.cls_preds = nn.ModuleList()
self.reg_preds = nn.ModuleList()
self.obj_preds = nn.ModuleList()
self.stems = nn.ModuleList()
Conv = DWConv if depthwise else BaseConv

for i in range(len(in_channels)):
self.stems.append(
BaseConv(
in_channels=int(in_channels[i] * width),
out_channels=int(256 * width),
ksize=1,
stride=1,
act=act,
))
self.cls_convs.append(
nn.Sequential(*[
Conv(
in_channels=int(256 * width),
out_channels=int(256 * width),
ksize=3,
stride=1,
act=act,
),
Conv(
in_channels=int(256 * width),
out_channels=int(256 * width),
ksize=3,
stride=1,
act=act,
),
]))
self.reg_convs.append(
nn.Sequential(*[
Conv(
in_channels=int(256 * width),
out_channels=int(256 * width),
ksize=3,
stride=1,
act=act,
),
Conv(
in_channels=int(256 * width),
out_channels=int(256 * width),
ksize=3,
stride=1,
act=act,
),
]))
self.cls_preds.append(
nn.Conv2d(
in_channels=int(256 * width),
out_channels=self.n_anchors * self.num_classes,
kernel_size=1,
stride=1,
padding=0,
))
self.reg_preds.append(
nn.Conv2d(
in_channels=int(256 * width),
out_channels=4,
kernel_size=1,
stride=1,
padding=0,
))
self.obj_preds.append(
nn.Conv2d(
in_channels=int(256 * width),
out_channels=self.n_anchors * 1,
kernel_size=1,
stride=1,
padding=0,
))

self.use_l1 = False
self.l1_loss = nn.L1Loss(reduction='none')
self.bcewithlog_loss = nn.BCEWithLogitsLoss(reduction='none')
# self.iou_loss = IOUloss(reduction="none")
self.strides = strides
self.grids = [torch.zeros(1)] * len(in_channels)

def initialize_biases(self, prior_prob):
for conv in self.cls_preds:
b = conv.bias.view(self.n_anchors, -1)
b.data.fill_(-math.log((1 - prior_prob) / prior_prob))
conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)

for conv in self.obj_preds:
b = conv.bias.view(self.n_anchors, -1)
b.data.fill_(-math.log((1 - prior_prob) / prior_prob))
conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)

def forward(self, xin, labels=None, imgs=None):
outputs = []

for k, (cls_conv, reg_conv, stride_this_level, x) in enumerate(
zip(self.cls_convs, self.reg_convs, self.strides, xin)):
x = self.stems[k](x)
cls_x = x
reg_x = x

cls_feat = cls_conv(cls_x)
cls_output = self.cls_preds[k](cls_feat)

reg_feat = reg_conv(reg_x)
reg_output = self.reg_preds[k](reg_feat)
obj_output = self.obj_preds[k](reg_feat)

if self.training:
pass
else:
output = torch.cat(
[reg_output,
obj_output.sigmoid(),
cls_output.sigmoid()], 1)

outputs.append(output)

if self.training:
pass
else:
self.hw = [x.shape[-2:] for x in outputs]
# [batch, n_anchors_all, 85]
outputs = torch.cat([x.flatten(start_dim=2) for x in outputs],
dim=2).permute(0, 2, 1)
if self.decode_in_inference:
return self.decode_outputs(outputs, dtype=xin[0].type())
else:
return outputs

def decode_outputs(self, outputs, dtype):
grids = []
strides = []
for (hsize, wsize), stride in zip(self.hw, self.strides):
yv, xv = meshgrid([torch.arange(hsize), torch.arange(wsize)])
grid = torch.stack((xv, yv), 2).view(1, -1, 2)
grids.append(grid)
shape = grid.shape[:2]
strides.append(torch.full((*shape, 1), stride))

grids = torch.cat(grids, dim=1).type(dtype)
strides = torch.cat(strides, dim=1).type(dtype)

outputs[..., :2] = (outputs[..., :2] + grids) * strides
outputs[..., 2:4] = torch.exp(outputs[..., 2:4]) * strides
return outputs
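
decode_outputs converts the raw head outputs into absolute (cx, cy, w, h) in network-input pixels; a small numeric check of the formula used above:

import math

# center = (predicted offset + grid index) * stride; size = exp(raw) * stride
tx, gx, stride = 0.3, 10, 16
cx = (tx + gx) * stride        # 164.8 px
w = math.exp(0.5) * stride     # ~26.4 px: widths/heights are regressed in log space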

modelscope/models/cv/realtime_object_detection/yolox/models/yolo_pafpn.py (+126 -0)

@@ -0,0 +1,126 @@
# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX

import torch
import torch.nn as nn

from .darknet import CSPDarknet
from .network_blocks import BaseConv, CSPLayer, DWConv


class YOLOPAFPN(nn.Module):
"""
YOLO PAFPN module. CSPDarknet is the default backbone of this model.
"""

def __init__(
self,
depth=1.0,
width=1.0,
in_features=('dark3', 'dark4', 'dark5'),
in_channels=[256, 512, 1024],
depthwise=False,
act='silu',
):
super(YOLOPAFPN, self).__init__()
self.backbone = CSPDarknet(depth, width, depthwise=depthwise, act=act)
self.in_features = in_features
self.in_channels = in_channels
Conv = DWConv if depthwise else BaseConv

self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
self.lateral_conv0 = BaseConv(
int(in_channels[2] * width),
int(in_channels[1] * width),
1,
1,
act=act)
self.C3_p4 = CSPLayer(
int(2 * in_channels[1] * width),
int(in_channels[1] * width),
round(3 * depth),
False,
depthwise=depthwise,
act=act,
) # cat

self.reduce_conv1 = BaseConv(
int(in_channels[1] * width),
int(in_channels[0] * width),
1,
1,
act=act)
self.C3_p3 = CSPLayer(
int(2 * in_channels[0] * width),
int(in_channels[0] * width),
round(3 * depth),
False,
depthwise=depthwise,
act=act,
)

# bottom-up conv
self.bu_conv2 = Conv(
int(in_channels[0] * width),
int(in_channels[0] * width),
3,
2,
act=act)
self.C3_n3 = CSPLayer(
int(2 * in_channels[0] * width),
int(in_channels[1] * width),
round(3 * depth),
False,
depthwise=depthwise,
act=act,
)

# bottom-up conv
self.bu_conv1 = Conv(
int(in_channels[1] * width),
int(in_channels[1] * width),
3,
2,
act=act)
self.C3_n4 = CSPLayer(
int(2 * in_channels[1] * width),
int(in_channels[2] * width),
round(3 * depth),
False,
depthwise=depthwise,
act=act,
)

def forward(self, input):
"""
Args:
input (Tensor): input image.

Returns:
Tuple[Tensor]: FPN feature.
"""

# backbone
out_features = self.backbone(input)
features = [out_features[f] for f in self.in_features]
[x2, x1, x0] = features

fpn_out0 = self.lateral_conv0(x0) # 1024->512/32
f_out0 = self.upsample(fpn_out0) # 512/16
f_out0 = torch.cat([f_out0, x1], 1) # 512->1024/16
f_out0 = self.C3_p4(f_out0) # 1024->512/16

fpn_out1 = self.reduce_conv1(f_out0) # 512->256/16
f_out1 = self.upsample(fpn_out1) # 256/8
f_out1 = torch.cat([f_out1, x2], 1) # 256->512/8
pan_out2 = self.C3_p3(f_out1) # 512->256/8

p_out1 = self.bu_conv2(pan_out2) # 256->256/16
p_out1 = torch.cat([p_out1, fpn_out1], 1) # 256->512/16
pan_out1 = self.C3_n3(p_out1) # 512->512/16

p_out0 = self.bu_conv1(pan_out1) # 512->512/32
p_out0 = torch.cat([p_out0, fpn_out0], 1) # 512->1024/32
pan_out0 = self.C3_n4(p_out0) # 1024->1024/32

outputs = (pan_out2, pan_out1, pan_out0)
return outputs
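
With the defaults (depth=width=1.0) and a 640x640 input, the three PAFPN outputs feed YOLOXHead at strides 8/16/32 (a rough shape check, assuming torch):

import torch

neck = YOLOPAFPN(depth=1.0, width=1.0).eval()
p3, p4, p5 = neck(torch.zeros(1, 3, 640, 640))
print(tuple(p3.shape))   # (1, 256, 80, 80)
print(tuple(p4.shape))   # (1, 512, 40, 40)
print(tuple(p5.shape))   # (1, 1024, 20, 20)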

modelscope/models/cv/realtime_object_detection/yolox/models/yolox.py (+33 -0)

@@ -0,0 +1,33 @@
# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX

import torch.nn as nn

from .yolo_head import YOLOXHead
from .yolo_pafpn import YOLOPAFPN


class YOLOX(nn.Module):
"""
YOLOX model module, composed of a YOLOPAFPN backbone and a YOLOXHead.
The network returns detection results at test time; training is not
supported in this port and raises NotImplementedError.
"""

def __init__(self, backbone=None, head=None):
super(YOLOX, self).__init__()
if backbone is None:
backbone = YOLOPAFPN()
if head is None:
head = YOLOXHead(80)

self.backbone = backbone
self.head = head

def forward(self, x, targets=None):
fpn_outs = self.backbone(x)
if self.training:
raise NotImplementedError('Training is not supported yet!')
else:
outputs = self.head(fpn_outs)

return outputs

modelscope/models/cv/realtime_object_detection/yolox/utils/__init__.py (+5 -0)

@@ -0,0 +1,5 @@
# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX

from .boxes import * # noqa

__all__ = ['bboxes_iou', 'meshgrid', 'postprocess', 'xyxy2cxcywh', 'xyxy2xywh']

modelscope/models/cv/realtime_object_detection/yolox/utils/boxes.py (+107 -0)

@@ -0,0 +1,107 @@
# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX

import torch
import torchvision

_TORCH_VER = [int(x) for x in torch.__version__.split('.')[:2]]


def meshgrid(*tensors):
if _TORCH_VER >= [1, 10]:
return torch.meshgrid(*tensors, indexing='ij')
else:
return torch.meshgrid(*tensors)


def xyxy2xywh(bboxes):
bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 0]
bboxes[:, 3] = bboxes[:, 3] - bboxes[:, 1]
return bboxes


def xyxy2cxcywh(bboxes):
bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 0]
bboxes[:, 3] = bboxes[:, 3] - bboxes[:, 1]
bboxes[:, 0] = bboxes[:, 0] + bboxes[:, 2] * 0.5
bboxes[:, 1] = bboxes[:, 1] + bboxes[:, 3] * 0.5
return bboxes


def postprocess(prediction,
num_classes,
conf_thre=0.7,
nms_thre=0.45,
class_agnostic=False):
box_corner = prediction.new(prediction.shape)
box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
prediction[:, :, :4] = box_corner[:, :, :4]

output = [None for _ in range(len(prediction))]
for i, image_pred in enumerate(prediction):

# If none are remaining => process next image
if not image_pred.size(0):
continue
# Get score and class with highest confidence
class_conf, class_pred = torch.max(
image_pred[:, 5:5 + num_classes], 1, keepdim=True)

conf_mask = image_pred[:, 4] * class_conf.squeeze()
conf_mask = (conf_mask >= conf_thre).squeeze()
# Detections ordered as (x1, y1, x2, y2, obj_conf, class_conf, class_pred)
detections = torch.cat(
(image_pred[:, :5], class_conf, class_pred.float()), 1)
detections = detections[conf_mask]
if not detections.size(0):
continue

if class_agnostic:
nms_out_index = torchvision.ops.nms(
detections[:, :4],
detections[:, 4] * detections[:, 5],
nms_thre,
)
else:
nms_out_index = torchvision.ops.batched_nms(
detections[:, :4],
detections[:, 4] * detections[:, 5],
detections[:, 6],
nms_thre,
)

detections = detections[nms_out_index]
if output[i] is None:
output[i] = detections
else:
output[i] = torch.cat((output[i], detections))

return output


def bboxes_iou(bboxes_a, bboxes_b, xyxy=True):
if bboxes_a.shape[1] != 4 or bboxes_b.shape[1] != 4:
raise IndexError

if xyxy:
tl = torch.max(bboxes_a[:, None, :2], bboxes_b[:, :2])
br = torch.min(bboxes_a[:, None, 2:], bboxes_b[:, 2:])
area_a = torch.prod(bboxes_a[:, 2:] - bboxes_a[:, :2], 1)
area_b = torch.prod(bboxes_b[:, 2:] - bboxes_b[:, :2], 1)
else:
tl = torch.max(
(bboxes_a[:, None, :2] - bboxes_a[:, None, 2:] / 2),
(bboxes_b[:, :2] - bboxes_b[:, 2:] / 2),
)
br = torch.min(
(bboxes_a[:, None, :2] + bboxes_a[:, None, 2:] / 2),
(bboxes_b[:, :2] + bboxes_b[:, 2:] / 2),
)

area_a = torch.prod(bboxes_a[:, 2:], 1)
area_b = torch.prod(bboxes_b[:, 2:], 1)
en = (tl < br).type(tl.type()).prod(dim=2)
area_i = torch.prod(br - tl, 2) * en # * ((tl < br).all())
return area_i / (area_a[:, None] + area_b - area_i)
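
postprocess() returns one entry per image: None when nothing passes the thresholds, otherwise an (M, 7) tensor whose columns follow the comment above; for one such entry:

# Illustrative: `outputs` denotes the list returned by postprocess()
dets = outputs[0]                     # None, or an (M, 7) tensor
if dets is not None:
    boxes = dets[:, :4]               # x1, y1, x2, y2 in network-input pixels
    scores = dets[:, 4] * dets[:, 5]  # obj_conf * class_conf (the score used for NMS)
    labels = dets[:, 6].long()        # predicted class index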

modelscope/pipelines/cv/__init__.py (+3 -0)

@@ -32,6 +32,7 @@ if TYPE_CHECKING:
from .image_to_image_generate_pipeline import Image2ImageGenerationPipeline
from .image_to_image_translation_pipeline import Image2ImageTranslationPipeline
from .product_retrieval_embedding_pipeline import ProductRetrievalEmbeddingPipeline
from .realtime_object_detection_pipeline import RealtimeObjectDetectionPipeline
from .live_category_pipeline import LiveCategoryPipeline
from .ocr_detection_pipeline import OCRDetectionPipeline
from .ocr_recognition_pipeline import OCRRecognitionPipeline
@@ -75,6 +76,8 @@ else:
['Image2ImageTranslationPipeline'],
'product_retrieval_embedding_pipeline':
['ProductRetrievalEmbeddingPipeline'],
'realtime_object_detection_pipeline':
['RealtimeObjectDetectionPipeline'],
'live_category_pipeline': ['LiveCategoryPipeline'],
'image_to_image_generation_pipeline':
['Image2ImageGenerationPipeline'],


modelscope/pipelines/cv/realtime_object_detection_pipeline.py (+50 -0)

@@ -0,0 +1,50 @@
import os.path as osp
from typing import Any, Dict, List, Union

import cv2
import json
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

from modelscope.metainfo import Pipelines
from modelscope.models.cv.realtime_object_detection import RealtimeDetector
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Input, Model, Pipeline, Tensor
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import load_image
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
Tasks.image_object_detection,
module_name=Pipelines.realtime_object_detection)
class RealtimeObjectDetectionPipeline(Pipeline):

def __init__(self, model: str, **kwargs):
super().__init__(model=model, **kwargs)
self.model = RealtimeDetector(model)

def preprocess(self, input: Input) -> Dict[Tensor, Union[str, np.ndarray]]:
output = self.model.preprocess(input)
return {'pre_output': output}

def forward(self, input: Tensor) -> Dict[Tensor, Dict[str, np.ndarray]]:
pre_output = input['pre_output']
forward_output = self.model(pre_output)
return {'forward_output': forward_output}

def postprocess(self, input: Dict[Tensor, Dict[str, np.ndarray]],
**kwargs) -> str:
forward_output = input['forward_output']
bboxes, scores, labels = forward_output
return {
OutputKeys.BOXES: bboxes,
OutputKeys.SCORES: scores,
OutputKeys.LABELS: labels,
}

modelscope/utils/cv/image_utils.py (+7 -0)

@@ -70,6 +70,13 @@ def draw_box(image, box):
(int(box[1][0]), int(box[1][1])), (0, 0, 255), 2)


def realtime_object_detection_bbox_vis(image, bboxes):
for bbox in bboxes:
cv2.rectangle(image, (bbox[0], bbox[1]), (bbox[2], bbox[3]),
(255, 0, 0), 2)
return image


def draw_keypoints(output, original_image):
poses = np.array(output[OutputKeys.POSES])
scores = np.array(output[OutputKeys.SCORES])


tests/pipelines/test_realtime_object_detection.py (+52 -0)

@@ -0,0 +1,52 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

import cv2
import numpy as np

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.cv.image_utils import realtime_object_detection_bbox_vis
from modelscope.utils.test_utils import test_level


class RealtimeObjectDetectionTest(unittest.TestCase):

def setUp(self) -> None:
self.model_id = 'damo/cv_cspnet_image-object-detection_yolox'
self.model_nano_id = 'damo/cv_cspnet_image-object-detection_yolox_nano_coco'
self.test_image = 'data/test/images/keypoints_detect/000000438862.jpg'

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_modelhub(self):
realtime_object_detection = pipeline(
Tasks.image_object_detection, model=self.model_id)

image = cv2.imread(self.test_image)
result = realtime_object_detection(image)
if result:
bboxes = result[OutputKeys.BOXES].astype(int)
image = realtime_object_detection_bbox_vis(image, bboxes)
cv2.imwrite('rt_obj_out.jpg', image)
else:
raise ValueError('process error')

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_nano(self):
realtime_object_detection = pipeline(
Tasks.image_object_detection, model=self.model_nano_id)

image = cv2.imread(self.test_image)
result = realtime_object_detection(image)
if result:
bboxes = result[OutputKeys.BOXES].astype(int)
image = realtime_object_detection_bbox_vis(image, bboxes)
cv2.imwrite('rtnano_obj_out.jpg', image)
else:
raise ValueError('process error')


if __name__ == '__main__':
unittest.main()
