Add real-time object detection pipeline
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9788299
master
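
For context, a minimal usage sketch of the new pipeline, mirroring the unit test added below (the model id and test image path are the ones used in that test):

    import cv2
    from modelscope.outputs import OutputKeys
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks
    from modelscope.utils.cv.image_utils import realtime_object_detection_bbox_vis

    detector = pipeline(
        Tasks.image_object_detection,
        model='damo/cv_cspnet_image-object-detection_yolox')
    image = cv2.imread('data/test/images/keypoints_detect/000000438862.jpg')
    result = detector(image)
    # result maps OutputKeys.BOXES / SCORES / LABELS to numpy arrays in
    # original-image coordinates
    image = realtime_object_detection_bbox_vis(
        image, result[OutputKeys.BOXES].astype(int))
    cv2.imwrite('rt_obj_out.jpg', image)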
| @@ -11,6 +11,7 @@ class Models(object): | |||
| """ | |||
| # vision models | |||
| detection = 'detection' | |||
| realtime_object_detection = 'realtime-object-detection' | |||
| scrfd = 'scrfd' | |||
| classification_model = 'ClassificationModel' | |||
| nafnet = 'nafnet' | |||
| @@ -111,6 +112,7 @@ class Pipelines(object): | |||
| image_super_resolution = 'rrdb-image-super-resolution' | |||
| face_image_generation = 'gan-face-image-generation' | |||
| product_retrieval_embedding = 'resnet50-product-retrieval-embedding' | |||
| realtime_object_detection = 'cspnet_realtime-object-detection_yolox' | |||
| face_recognition = 'ir101-face-recognition-cfglint' | |||
| image_instance_segmentation = 'cascade-mask-rcnn-swin-image-instance-segmentation' | |||
| image2image_translation = 'image-to-image-translation' | |||
| @@ -7,5 +7,5 @@ from . import (action_recognition, animal_recognition, body_2d_keypoints, | |||
| image_reid_person, image_semantic_segmentation, | |||
| image_to_image_generation, image_to_image_translation, | |||
| object_detection, product_retrieval_embedding, | |||
| salient_detection, super_resolution, | |||
| realtime_object_detection, salient_detection, super_resolution, | |||
| video_single_object_tracking, video_summarization, virual_tryon) | |||
| @@ -0,0 +1,21 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from typing import TYPE_CHECKING | |||
| from modelscope.utils.import_utils import LazyImportModule | |||
| if TYPE_CHECKING: | |||
| from .realtime_detector import RealtimeDetector | |||
| else: | |||
| _import_structure = { | |||
| 'realtime_detector': ['RealtimeDetector'], | |||
| } | |||
| import sys | |||
| sys.modules[__name__] = LazyImportModule( | |||
| __name__, | |||
| globals()['__file__'], | |||
| _import_structure, | |||
| module_spec=__spec__, | |||
| extra_objects={}, | |||
| ) | |||
| @@ -0,0 +1,85 @@ | |||
 | import os | |||
 | import os.path as osp | |||
 | import numpy as np | |||
 | import torch | |||
| from modelscope.metainfo import Models | |||
| from modelscope.models.base.base_torch_model import TorchModel | |||
| from modelscope.models.builder import MODELS | |||
| from modelscope.preprocessors import LoadImage | |||
| from modelscope.utils.config import Config | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| from .yolox.data.data_augment import ValTransform | |||
| from .yolox.exp import get_exp_by_name | |||
| from .yolox.utils import postprocess | |||
| @MODELS.register_module( | |||
| group_key=Tasks.image_object_detection, | |||
| module_name=Models.realtime_object_detection) | |||
| class RealtimeDetector(TorchModel): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| super().__init__(model_dir, *args, **kwargs) | |||
| self.config = Config.from_file( | |||
| os.path.join(self.model_dir, ModelFile.CONFIGURATION)) | |||
| # model type | |||
| self.exp = get_exp_by_name(self.config.model_type) | |||
| # build model | |||
| self.model = self.exp.get_model() | |||
| model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE) | |||
| ckpt = torch.load(model_path, map_location='cpu') | |||
| # load the model state dict | |||
| self.model.load_state_dict(ckpt['model']) | |||
| self.model.eval() | |||
| # params setting | |||
| self.exp.num_classes = self.config.num_classes | |||
| self.confthre = self.config.conf_thr | |||
| self.num_classes = self.exp.num_classes | |||
| self.nmsthre = self.exp.nmsthre | |||
| self.test_size = self.exp.test_size | |||
| self.preproc = ValTransform(legacy=False) | |||
| def inference(self, img): | |||
| with torch.no_grad(): | |||
| outputs = self.model(img) | |||
| return outputs | |||
| def forward(self, inputs): | |||
| return self.inference(inputs) | |||
| def preprocess(self, img): | |||
| img = LoadImage.convert_to_ndarray(img) | |||
| height, width = img.shape[:2] | |||
| self.ratio = min(self.test_size[0] / img.shape[0], | |||
| self.test_size[1] / img.shape[1]) | |||
| img, _ = self.preproc(img, None, self.test_size) | |||
| img = torch.from_numpy(img).unsqueeze(0) | |||
| img = img.float() | |||
| return img | |||
| def postprocess(self, input): | |||
| outputs = postprocess( | |||
| input, | |||
| self.num_classes, | |||
| self.confthre, | |||
| self.nmsthre, | |||
| class_agnostic=True) | |||
 | # postprocess returns one entry per image; it is None when no box survives | |||
 | if len(outputs) == 1 and outputs[0] is not None: | |||
 | bboxes = outputs[0][:, 0:4].cpu().numpy() / self.ratio | |||
 | scores = outputs[0][:, 5].cpu().numpy() | |||
 | labels = outputs[0][:, 6].cpu().int().numpy() | |||
 | else: | |||
 | bboxes, scores, labels = np.zeros((0, 4)), np.zeros((0, )), np.zeros((0, ), dtype=int) | |||
 | return bboxes, scores, labels | |||
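
Note on coordinates: preprocess letterboxes the input to self.test_size and records the scale in self.ratio, and postprocess divides the predicted boxes by that scale to map them back to original-image pixels. A tiny worked example (numbers are illustrative only):

    # a 1280x720 image with test_size (640, 640)
    r = min(640 / 720, 640 / 1280)   # = 0.5, the value stored in self.ratio
    # a box predicted as (100, 50, 300, 200) in the 640x640 frame maps back to
    # (200, 100, 600, 400) in the original image after dividing by r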
| @@ -0,0 +1,69 @@ | |||
| # The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX | |||
| """ | |||
| Data augmentation functionality. Passed as callable transformations to | |||
| Dataset classes. | |||
 | The data augmentation procedures are adapted from @weiliu89's SSD paper | |||
| http://arxiv.org/abs/1512.02325 | |||
| """ | |||
| import math | |||
| import random | |||
| import cv2 | |||
| import numpy as np | |||
| from ..utils import xyxy2cxcywh | |||
| def preproc(img, input_size, swap=(2, 0, 1)): | |||
| if len(img.shape) == 3: | |||
| padded_img = np.ones( | |||
| (input_size[0], input_size[1], 3), dtype=np.uint8) * 114 | |||
| else: | |||
| padded_img = np.ones(input_size, dtype=np.uint8) * 114 | |||
| r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1]) | |||
| resized_img = cv2.resize( | |||
| img, | |||
| (int(img.shape[1] * r), int(img.shape[0] * r)), | |||
| interpolation=cv2.INTER_LINEAR, | |||
| ).astype(np.uint8) | |||
| padded_img[:int(img.shape[0] * r), :int(img.shape[1] * r)] = resized_img | |||
| padded_img = padded_img.transpose(swap) | |||
| padded_img = np.ascontiguousarray(padded_img, dtype=np.float32) | |||
| return padded_img, r | |||
| class ValTransform: | |||
| """ | |||
 | Defines the transformations applied to a test image (cv2 ndarray) before it | |||
 | is fed to the network: resize/pad -> tensorize -> optional color normalization | |||
 | Arguments: | |||
 | swap ((int,int,int)): final order of channels, default (2, 0, 1) (HWC -> CHW) | |||
 | legacy (bool): if True, additionally flip the channel order and apply ImageNet | |||
 | mean/std normalization after scaling to [0, 1] | |||
 | Returns: | |||
 | transform (transform): callable transform to be applied to test/val | |||
 | data | |||
| """ | |||
| def __init__(self, swap=(2, 0, 1), legacy=False): | |||
| self.swap = swap | |||
| self.legacy = legacy | |||
| # assume input is cv2 img for now | |||
| def __call__(self, img, res, input_size): | |||
| img, _ = preproc(img, input_size, self.swap) | |||
| if self.legacy: | |||
| img = img[::-1, :, :].copy() | |||
| img /= 255.0 | |||
| img -= np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1) | |||
| img /= np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1) | |||
| return img, np.zeros((1, 5)) | |||
| @@ -0,0 +1,5 @@ | |||
| # The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX | |||
| from .base_exp import BaseExp | |||
| from .build import get_exp_by_name | |||
| from .yolox_base import Exp | |||
| @@ -0,0 +1,12 @@ | |||
| # The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX | |||
| from abc import ABCMeta, abstractmethod | |||
| from torch.nn import Module | |||
| class BaseExp(metaclass=ABCMeta): | |||
| @abstractmethod | |||
| def get_model(self) -> Module: | |||
| pass | |||
| @@ -0,0 +1,18 @@ | |||
| # The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX | |||
| def get_exp_by_name(exp_name): | |||
| exp = exp_name.replace('-', | |||
| '_') # convert string like "yolox-s" to "yolox_s" | |||
| if exp == 'yolox_s': | |||
| from .default import YoloXSExp as YoloXExp | |||
| elif exp == 'yolox_nano': | |||
| from .default import YoloXNanoExp as YoloXExp | |||
| elif exp == 'yolox_tiny': | |||
| from .default import YoloXTinyExp as YoloXExp | |||
 | else: | |||
 | raise ValueError(f'Unsupported exp name: {exp_name}') | |||
| return YoloXExp() | |||
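
RealtimeDetector resolves the exp via the model_type field of the model's configuration.json, so the configuration needs at least the keys read in realtime_detector.py above. A hypothetical minimal configuration, written here as a Python dict (the real file shipped with a model repo may carry more fields):

    config = {
        'model_type': 'yolox-s',  # normalized to 'yolox_s' by get_exp_by_name
        'num_classes': 80,
        'conf_thr': 0.5,          # confidence threshold used in postprocess
    }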
| @@ -0,0 +1,5 @@ | |||
| # The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX | |||
| from .yolox_nano import YoloXNanoExp | |||
| from .yolox_s import YoloXSExp | |||
| from .yolox_tiny import YoloXTinyExp | |||
| @@ -0,0 +1,46 @@ | |||
| # The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX | |||
| import os | |||
| import torch.nn as nn | |||
| from ..yolox_base import Exp as YoloXExp | |||
| class YoloXNanoExp(YoloXExp): | |||
| def __init__(self): | |||
| super(YoloXNanoExp, self).__init__() | |||
| self.depth = 0.33 | |||
| self.width = 0.25 | |||
| self.input_size = (416, 416) | |||
| self.test_size = (416, 416) | |||
| def get_model(self, sublinear=False): | |||
| def init_yolo(M): | |||
| for m in M.modules(): | |||
| if isinstance(m, nn.BatchNorm2d): | |||
| m.eps = 1e-3 | |||
| m.momentum = 0.03 | |||
| if 'model' not in self.__dict__: | |||
| from ...models import YOLOX, YOLOPAFPN, YOLOXHead | |||
| in_channels = [256, 512, 1024] | |||
 | # The NANO model uses depthwise=True, which is the main difference from the other variants. | |||
| backbone = YOLOPAFPN( | |||
| self.depth, | |||
| self.width, | |||
| in_channels=in_channels, | |||
| act=self.act, | |||
| depthwise=True, | |||
| ) | |||
| head = YOLOXHead( | |||
| self.num_classes, | |||
| self.width, | |||
| in_channels=in_channels, | |||
| act=self.act, | |||
| depthwise=True) | |||
| self.model = YOLOX(backbone, head) | |||
 | self.model.apply(init_yolo) | |||
 | return self.model | |||
| @@ -0,0 +1,13 @@ | |||
| # The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX | |||
| import os | |||
| from ..yolox_base import Exp as YoloXExp | |||
| class YoloXSExp(YoloXExp): | |||
| def __init__(self): | |||
| super(YoloXSExp, self).__init__() | |||
| self.depth = 0.33 | |||
| self.width = 0.50 | |||
| @@ -0,0 +1,20 @@ | |||
| # The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX | |||
| import os | |||
| from ..yolox_base import Exp as YoloXExp | |||
| class YoloXTinyExp(YoloXExp): | |||
| def __init__(self): | |||
| super(YoloXTinyExp, self).__init__() | |||
| self.depth = 0.33 | |||
| self.width = 0.375 | |||
| self.input_size = (416, 416) | |||
| self.mosaic_scale = (0.5, 1.5) | |||
| self.random_size = (10, 20) | |||
| self.test_size = (416, 416) | |||
| self.exp_name = os.path.split( | |||
| os.path.realpath(__file__))[1].split('.')[0] | |||
| self.enable_mixup = False | |||
| @@ -0,0 +1,59 @@ | |||
| # The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX | |||
| import os | |||
| import random | |||
| import torch | |||
| import torch.distributed as dist | |||
| import torch.nn as nn | |||
| from .base_exp import BaseExp | |||
| class Exp(BaseExp): | |||
| def __init__(self): | |||
| super().__init__() | |||
| # ---------------- model config ---------------- # | |||
| # detect classes number of model | |||
| self.num_classes = 80 | |||
| # factor of model depth | |||
| self.depth = 1.00 | |||
| # factor of model width | |||
| self.width = 1.00 | |||
 | # activation name. For example, if set to "relu", all "silu" activations will be replaced with "relu". | |||
| self.act = 'silu' | |||
| # ----------------- testing config ------------------ # | |||
| # output image size during evaluation/test | |||
| self.test_size = (640, 640) | |||
| # confidence threshold during evaluation/test, | |||
| # boxes whose scores are less than test_conf will be filtered | |||
| self.test_conf = 0.01 | |||
| # nms threshold | |||
| self.nmsthre = 0.65 | |||
| def get_model(self): | |||
| from ..models import YOLOX, YOLOPAFPN, YOLOXHead | |||
| def init_yolo(M): | |||
| for m in M.modules(): | |||
| if isinstance(m, nn.BatchNorm2d): | |||
| m.eps = 1e-3 | |||
| m.momentum = 0.03 | |||
| if getattr(self, 'model', None) is None: | |||
| in_channels = [256, 512, 1024] | |||
| backbone = YOLOPAFPN( | |||
| self.depth, self.width, in_channels=in_channels, act=self.act) | |||
| head = YOLOXHead( | |||
| self.num_classes, | |||
| self.width, | |||
| in_channels=in_channels, | |||
| act=self.act) | |||
| self.model = YOLOX(backbone, head) | |||
| self.model.apply(init_yolo) | |||
| self.model.head.initialize_biases(1e-2) | |||
| self.model.train() | |||
| return self.model | |||
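
Together with the default exps above, the variants wired up in this PR differ only in depth/width multipliers, input size, and (for nano) depthwise convolutions: yolox_s uses 0.33/0.50 at 640x640, yolox_tiny 0.33/0.375 at 416x416, and yolox_nano 0.33/0.25 at 416x416 with depthwise=True. A quick sketch to instantiate them (the import path is assumed from the package layout):

    from modelscope.models.cv.realtime_object_detection.yolox.exp import get_exp_by_name

    for name in ('yolox_s', 'yolox_tiny', 'yolox_nano'):
        model = get_exp_by_name(name).get_model()
        n_params = sum(p.numel() for p in model.parameters())
        print(f'{name}: {n_params / 1e6:.1f}M parameters')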
| @@ -0,0 +1,7 @@ | |||
| # The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX | |||
| from .darknet import CSPDarknet, Darknet | |||
| from .yolo_fpn import YOLOFPN | |||
| from .yolo_head import YOLOXHead | |||
| from .yolo_pafpn import YOLOPAFPN | |||
| from .yolox import YOLOX | |||
| @@ -0,0 +1,189 @@ | |||
| # The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX | |||
| from torch import nn | |||
| from .network_blocks import (BaseConv, CSPLayer, DWConv, Focus, ResLayer, | |||
| SPPBottleneck) | |||
| class Darknet(nn.Module): | |||
| # number of blocks from dark2 to dark5. | |||
| depth2blocks = {21: [1, 2, 2, 1], 53: [2, 8, 8, 4]} | |||
| def __init__( | |||
| self, | |||
| depth, | |||
| in_channels=3, | |||
| stem_out_channels=32, | |||
| out_features=('dark3', 'dark4', 'dark5'), | |||
| ): | |||
| """ | |||
| Args: | |||
| depth (int): depth of darknet used in model, usually use [21, 53] for this param. | |||
| in_channels (int): number of input channels, for example, use 3 for RGB image. | |||
| stem_out_channels (int): number of output channels of darknet stem. | |||
| It decides channels of darknet layer2 to layer5. | |||
| out_features (Tuple[str]): desired output layer name. | |||
| """ | |||
| super().__init__() | |||
| assert out_features, 'please provide output features of Darknet' | |||
| self.out_features = out_features | |||
| self.stem = nn.Sequential( | |||
| BaseConv( | |||
| in_channels, stem_out_channels, ksize=3, stride=1, | |||
| act='lrelu'), | |||
| *self.make_group_layer(stem_out_channels, num_blocks=1, stride=2), | |||
| ) | |||
| in_channels = stem_out_channels * 2 # 64 | |||
| num_blocks = Darknet.depth2blocks[depth] | |||
| # create darknet with `stem_out_channels` and `num_blocks` layers. | |||
 | # to make the model structure clearer, we don't use a `for` loop here. | |||
| self.dark2 = nn.Sequential( | |||
| *self.make_group_layer(in_channels, num_blocks[0], stride=2)) | |||
| in_channels *= 2 # 128 | |||
| self.dark3 = nn.Sequential( | |||
| *self.make_group_layer(in_channels, num_blocks[1], stride=2)) | |||
| in_channels *= 2 # 256 | |||
| self.dark4 = nn.Sequential( | |||
| *self.make_group_layer(in_channels, num_blocks[2], stride=2)) | |||
| in_channels *= 2 # 512 | |||
| self.dark5 = nn.Sequential( | |||
| *self.make_group_layer(in_channels, num_blocks[3], stride=2), | |||
| *self.make_spp_block([in_channels, in_channels * 2], | |||
| in_channels * 2), | |||
| ) | |||
| def make_group_layer(self, | |||
| in_channels: int, | |||
| num_blocks: int, | |||
| stride: int = 1): | |||
 | 'starts with a conv layer, followed by `num_blocks` `ResLayer` blocks' | |||
| return [ | |||
| BaseConv( | |||
| in_channels, | |||
| in_channels * 2, | |||
| ksize=3, | |||
| stride=stride, | |||
| act='lrelu'), | |||
| *[(ResLayer(in_channels * 2)) for _ in range(num_blocks)], | |||
| ] | |||
| def make_spp_block(self, filters_list, in_filters): | |||
| m = nn.Sequential(*[ | |||
| BaseConv(in_filters, filters_list[0], 1, stride=1, act='lrelu'), | |||
| BaseConv( | |||
| filters_list[0], filters_list[1], 3, stride=1, act='lrelu'), | |||
| SPPBottleneck( | |||
| in_channels=filters_list[1], | |||
| out_channels=filters_list[0], | |||
| activation='lrelu', | |||
| ), | |||
| BaseConv( | |||
| filters_list[0], filters_list[1], 3, stride=1, act='lrelu'), | |||
| BaseConv( | |||
| filters_list[1], filters_list[0], 1, stride=1, act='lrelu'), | |||
| ]) | |||
| return m | |||
| def forward(self, x): | |||
| outputs = {} | |||
| x = self.stem(x) | |||
| outputs['stem'] = x | |||
| x = self.dark2(x) | |||
| outputs['dark2'] = x | |||
| x = self.dark3(x) | |||
| outputs['dark3'] = x | |||
| x = self.dark4(x) | |||
| outputs['dark4'] = x | |||
| x = self.dark5(x) | |||
| outputs['dark5'] = x | |||
| return {k: v for k, v in outputs.items() if k in self.out_features} | |||
| class CSPDarknet(nn.Module): | |||
| def __init__( | |||
| self, | |||
| dep_mul, | |||
| wid_mul, | |||
| out_features=('dark3', 'dark4', 'dark5'), | |||
| depthwise=False, | |||
| act='silu', | |||
| ): | |||
| super().__init__() | |||
| assert out_features, 'please provide output features of Darknet' | |||
| self.out_features = out_features | |||
| Conv = DWConv if depthwise else BaseConv | |||
| base_channels = int(wid_mul * 64) # 64 | |||
| base_depth = max(round(dep_mul * 3), 1) # 3 | |||
| # stem | |||
| self.stem = Focus(3, base_channels, ksize=3, act=act) | |||
| # dark2 | |||
| self.dark2 = nn.Sequential( | |||
| Conv(base_channels, base_channels * 2, 3, 2, act=act), | |||
| CSPLayer( | |||
| base_channels * 2, | |||
| base_channels * 2, | |||
| n=base_depth, | |||
| depthwise=depthwise, | |||
| act=act, | |||
| ), | |||
| ) | |||
| # dark3 | |||
| self.dark3 = nn.Sequential( | |||
| Conv(base_channels * 2, base_channels * 4, 3, 2, act=act), | |||
| CSPLayer( | |||
| base_channels * 4, | |||
| base_channels * 4, | |||
| n=base_depth * 3, | |||
| depthwise=depthwise, | |||
| act=act, | |||
| ), | |||
| ) | |||
| # dark4 | |||
| self.dark4 = nn.Sequential( | |||
| Conv(base_channels * 4, base_channels * 8, 3, 2, act=act), | |||
| CSPLayer( | |||
| base_channels * 8, | |||
| base_channels * 8, | |||
| n=base_depth * 3, | |||
| depthwise=depthwise, | |||
| act=act, | |||
| ), | |||
| ) | |||
| # dark5 | |||
| self.dark5 = nn.Sequential( | |||
| Conv(base_channels * 8, base_channels * 16, 3, 2, act=act), | |||
| SPPBottleneck( | |||
| base_channels * 16, base_channels * 16, activation=act), | |||
| CSPLayer( | |||
| base_channels * 16, | |||
| base_channels * 16, | |||
| n=base_depth, | |||
| shortcut=False, | |||
| depthwise=depthwise, | |||
| act=act, | |||
| ), | |||
| ) | |||
| def forward(self, x): | |||
| outputs = {} | |||
| x = self.stem(x) | |||
| outputs['stem'] = x | |||
| x = self.dark2(x) | |||
| outputs['dark2'] = x | |||
| x = self.dark3(x) | |||
| outputs['dark3'] = x | |||
| x = self.dark4(x) | |||
| outputs['dark4'] = x | |||
| x = self.dark5(x) | |||
| outputs['dark5'] = x | |||
| return {k: v for k, v in outputs.items() if k in self.out_features} | |||
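
With the default width multiplier 1.0, base_channels is 64 and the exported stages give 256/512/1024 channels at strides 8/16/32, matching the in_channels=[256, 512, 1024] defaults assumed by YOLOPAFPN and YOLOXHead below. A shape check as a sketch (import path assumed):

    import torch
    from modelscope.models.cv.realtime_object_detection.yolox.models import CSPDarknet

    backbone = CSPDarknet(dep_mul=1.0, wid_mul=1.0)
    feats = backbone(torch.randn(1, 3, 640, 640))
    for name, f in feats.items():
        print(name, tuple(f.shape))
    # expected: dark3 (1, 256, 80, 80), dark4 (1, 512, 40, 40), dark5 (1, 1024, 20, 20)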
| @@ -0,0 +1,213 @@ | |||
| # The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX | |||
| import torch | |||
| import torch.nn as nn | |||
| def get_activation(name='silu', inplace=True): | |||
| if name == 'silu': | |||
| module = nn.SiLU(inplace=inplace) | |||
| else: | |||
| raise AttributeError('Unsupported act type: {}'.format(name)) | |||
| return module | |||
| class BaseConv(nn.Module): | |||
| """A Conv2d -> Batchnorm -> silu/leaky relu block""" | |||
| def __init__(self, | |||
| in_channels, | |||
| out_channels, | |||
| ksize, | |||
| stride, | |||
| groups=1, | |||
| bias=False, | |||
| act='silu'): | |||
| super(BaseConv, self).__init__() | |||
| # same padding | |||
| pad = (ksize - 1) // 2 | |||
| self.conv = nn.Conv2d( | |||
| in_channels, | |||
| out_channels, | |||
| kernel_size=ksize, | |||
| stride=stride, | |||
| padding=pad, | |||
| groups=groups, | |||
| bias=bias, | |||
| ) | |||
| self.bn = nn.BatchNorm2d(out_channels) | |||
| self.act = get_activation(act, inplace=True) | |||
| def forward(self, x): | |||
| return self.act(self.bn(self.conv(x))) | |||
| def fuseforward(self, x): | |||
| return self.act(self.conv(x)) | |||
| class DWConv(nn.Module): | |||
| """Depthwise Conv + Conv""" | |||
| def __init__(self, in_channels, out_channels, ksize, stride=1, act='silu'): | |||
| super(DWConv, self).__init__() | |||
| self.dconv = BaseConv( | |||
| in_channels, | |||
| in_channels, | |||
| ksize=ksize, | |||
| stride=stride, | |||
| groups=in_channels, | |||
| act=act, | |||
| ) | |||
| self.pconv = BaseConv( | |||
| in_channels, out_channels, ksize=1, stride=1, groups=1, act=act) | |||
| def forward(self, x): | |||
| x = self.dconv(x) | |||
| return self.pconv(x) | |||
| class Bottleneck(nn.Module): | |||
| # Standard bottleneck | |||
| def __init__( | |||
| self, | |||
| in_channels, | |||
| out_channels, | |||
| shortcut=True, | |||
| expansion=0.5, | |||
| depthwise=False, | |||
| act='silu', | |||
| ): | |||
| super().__init__() | |||
| hidden_channels = int(out_channels * expansion) | |||
| Conv = DWConv if depthwise else BaseConv | |||
| self.conv1 = BaseConv( | |||
| in_channels, hidden_channels, 1, stride=1, act=act) | |||
| self.conv2 = Conv(hidden_channels, out_channels, 3, stride=1, act=act) | |||
| self.use_add = shortcut and in_channels == out_channels | |||
| def forward(self, x): | |||
| y = self.conv2(self.conv1(x)) | |||
| if self.use_add: | |||
| y = y + x | |||
| return y | |||
| class ResLayer(nn.Module): | |||
| 'Residual layer with `in_channels` inputs.' | |||
| def __init__(self, in_channels: int): | |||
| super().__init__() | |||
| mid_channels = in_channels // 2 | |||
| self.layer1 = BaseConv( | |||
| in_channels, mid_channels, ksize=1, stride=1, act='lrelu') | |||
| self.layer2 = BaseConv( | |||
| mid_channels, in_channels, ksize=3, stride=1, act='lrelu') | |||
| def forward(self, x): | |||
| out = self.layer2(self.layer1(x)) | |||
| return x + out | |||
| class SPPBottleneck(nn.Module): | |||
| """Spatial pyramid pooling layer used in YOLOv3-SPP""" | |||
| def __init__(self, | |||
| in_channels, | |||
| out_channels, | |||
| kernel_sizes=(5, 9, 13), | |||
| activation='silu'): | |||
| super().__init__() | |||
| hidden_channels = in_channels // 2 | |||
| self.conv1 = BaseConv( | |||
| in_channels, hidden_channels, 1, stride=1, act=activation) | |||
| self.m = nn.ModuleList([ | |||
| nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2) | |||
| for ks in kernel_sizes | |||
| ]) | |||
| conv2_channels = hidden_channels * (len(kernel_sizes) + 1) | |||
| self.conv2 = BaseConv( | |||
| conv2_channels, out_channels, 1, stride=1, act=activation) | |||
| def forward(self, x): | |||
| x = self.conv1(x) | |||
| x = torch.cat([x] + [m(x) for m in self.m], dim=1) | |||
| x = self.conv2(x) | |||
| return x | |||
| class CSPLayer(nn.Module): | |||
| """C3 in yolov5, CSP Bottleneck with 3 convolutions""" | |||
| def __init__( | |||
| self, | |||
| in_channels, | |||
| out_channels, | |||
| n=1, | |||
| shortcut=True, | |||
| expansion=0.5, | |||
| depthwise=False, | |||
| act='silu', | |||
| ): | |||
| """ | |||
| Args: | |||
| in_channels (int): input channels. | |||
| out_channels (int): output channels. | |||
| n (int): number of Bottlenecks. Default value: 1. | |||
| """ | |||
| # ch_in, ch_out, number, shortcut, groups, expansion | |||
| super().__init__() | |||
| hidden_channels = int(out_channels * expansion) # hidden channels | |||
| self.conv1 = BaseConv( | |||
| in_channels, hidden_channels, 1, stride=1, act=act) | |||
| self.conv2 = BaseConv( | |||
| in_channels, hidden_channels, 1, stride=1, act=act) | |||
| self.conv3 = BaseConv( | |||
| 2 * hidden_channels, out_channels, 1, stride=1, act=act) | |||
| module_list = [ | |||
| Bottleneck( | |||
| hidden_channels, | |||
| hidden_channels, | |||
| shortcut, | |||
| 1.0, | |||
| depthwise, | |||
| act=act) for _ in range(n) | |||
| ] | |||
| self.m = nn.Sequential(*module_list) | |||
| def forward(self, x): | |||
| x_1 = self.conv1(x) | |||
| x_2 = self.conv2(x) | |||
| x_1 = self.m(x_1) | |||
| x = torch.cat((x_1, x_2), dim=1) | |||
| return self.conv3(x) | |||
| class Focus(nn.Module): | |||
| """Focus width and height information into channel space.""" | |||
| def __init__(self, | |||
| in_channels, | |||
| out_channels, | |||
| ksize=1, | |||
| stride=1, | |||
| act='silu'): | |||
| super().__init__() | |||
| self.conv = BaseConv( | |||
| in_channels * 4, out_channels, ksize, stride, act=act) | |||
| def forward(self, x): | |||
| # shape of x (b,c,w,h) -> y(b,4c,w/2,h/2) | |||
| patch_top_left = x[..., ::2, ::2] | |||
| patch_top_right = x[..., ::2, 1::2] | |||
| patch_bot_left = x[..., 1::2, ::2] | |||
| patch_bot_right = x[..., 1::2, 1::2] | |||
| x = torch.cat( | |||
| ( | |||
| patch_top_left, | |||
| patch_bot_left, | |||
| patch_top_right, | |||
| patch_bot_right, | |||
| ), | |||
| dim=1, | |||
| ) | |||
| return self.conv(x) | |||
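
The Focus stem is a space-to-depth rearrangement: each 2x2 pixel block is stacked into the channel dimension before the first convolution, halving the spatial size while quadrupling the channels. A standalone check of that rearrangement:

    import torch

    x = torch.randn(1, 3, 640, 640)
    # the four strided slices used in Focus.forward
    patches = torch.cat((x[..., ::2, ::2], x[..., 1::2, ::2],
                         x[..., ::2, 1::2], x[..., 1::2, 1::2]), dim=1)
    assert patches.shape == (1, 12, 320, 320)  # 3 * 4 channels, half the resolution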
| @@ -0,0 +1,80 @@ | |||
| # The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX | |||
| import torch | |||
| import torch.nn as nn | |||
| from .darknet import Darknet | |||
| from .network_blocks import BaseConv | |||
| class YOLOFPN(nn.Module): | |||
| """ | |||
| YOLOFPN module. Darknet 53 is the default backbone of this model. | |||
| """ | |||
| def __init__( | |||
| self, | |||
| depth=53, | |||
| in_features=['dark3', 'dark4', 'dark5'], | |||
| ): | |||
| super(YOLOFPN, self).__init__() | |||
| self.backbone = Darknet(depth) | |||
| self.in_features = in_features | |||
| # out 1 | |||
| self.out1_cbl = self._make_cbl(512, 256, 1) | |||
| self.out1 = self._make_embedding([256, 512], 512 + 256) | |||
| # out 2 | |||
| self.out2_cbl = self._make_cbl(256, 128, 1) | |||
| self.out2 = self._make_embedding([128, 256], 256 + 128) | |||
| # upsample | |||
| self.upsample = nn.Upsample(scale_factor=2, mode='nearest') | |||
| def _make_cbl(self, _in, _out, ks): | |||
| return BaseConv(_in, _out, ks, stride=1, act='lrelu') | |||
| def _make_embedding(self, filters_list, in_filters): | |||
| m = nn.Sequential(*[ | |||
| self._make_cbl(in_filters, filters_list[0], 1), | |||
| self._make_cbl(filters_list[0], filters_list[1], 3), | |||
| self._make_cbl(filters_list[1], filters_list[0], 1), | |||
| self._make_cbl(filters_list[0], filters_list[1], 3), | |||
| self._make_cbl(filters_list[1], filters_list[0], 1), | |||
| ]) | |||
| return m | |||
| def load_pretrained_model(self, filename='./weights/darknet53.mix.pth'): | |||
| with open(filename, 'rb') as f: | |||
| state_dict = torch.load(f, map_location='cpu') | |||
| print('loading pretrained weights...') | |||
| self.backbone.load_state_dict(state_dict) | |||
| def forward(self, inputs): | |||
| """ | |||
| Args: | |||
| inputs (Tensor): input image. | |||
| Returns: | |||
 | Tuple[Tensor]: FPN output features. | |||
| """ | |||
| # backbone | |||
| out_features = self.backbone(inputs) | |||
| x2, x1, x0 = [out_features[f] for f in self.in_features] | |||
| # yolo branch 1 | |||
| x1_in = self.out1_cbl(x0) | |||
| x1_in = self.upsample(x1_in) | |||
| x1_in = torch.cat([x1_in, x1], 1) | |||
| out_dark4 = self.out1(x1_in) | |||
| # yolo branch 2 | |||
| x2_in = self.out2_cbl(out_dark4) | |||
| x2_in = self.upsample(x2_in) | |||
| x2_in = torch.cat([x2_in, x2], 1) | |||
| out_dark3 = self.out2(x2_in) | |||
| outputs = (out_dark3, out_dark4, x0) | |||
| return outputs | |||
| @@ -0,0 +1,182 @@ | |||
| # The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX | |||
| import math | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| from ..utils import bboxes_iou, meshgrid | |||
| from .network_blocks import BaseConv, DWConv | |||
| class YOLOXHead(nn.Module): | |||
| def __init__( | |||
| self, | |||
| num_classes, | |||
| width=1.0, | |||
| strides=[8, 16, 32], | |||
| in_channels=[256, 512, 1024], | |||
| act='silu', | |||
| depthwise=False, | |||
| ): | |||
| """ | |||
| Args: | |||
 | act (str): activation type of conv. Default value: "silu". | |||
 | depthwise (bool): whether to apply depthwise conv in the conv branch. Default value: False. | |||
| """ | |||
| super(YOLOXHead, self).__init__() | |||
| self.n_anchors = 1 | |||
| self.num_classes = num_classes | |||
| self.decode_in_inference = True # for deploy, set to False | |||
| self.cls_convs = nn.ModuleList() | |||
| self.reg_convs = nn.ModuleList() | |||
| self.cls_preds = nn.ModuleList() | |||
| self.reg_preds = nn.ModuleList() | |||
| self.obj_preds = nn.ModuleList() | |||
| self.stems = nn.ModuleList() | |||
| Conv = DWConv if depthwise else BaseConv | |||
| for i in range(len(in_channels)): | |||
| self.stems.append( | |||
| BaseConv( | |||
| in_channels=int(in_channels[i] * width), | |||
| out_channels=int(256 * width), | |||
| ksize=1, | |||
| stride=1, | |||
| act=act, | |||
| )) | |||
| self.cls_convs.append( | |||
| nn.Sequential(*[ | |||
| Conv( | |||
| in_channels=int(256 * width), | |||
| out_channels=int(256 * width), | |||
| ksize=3, | |||
| stride=1, | |||
| act=act, | |||
| ), | |||
| Conv( | |||
| in_channels=int(256 * width), | |||
| out_channels=int(256 * width), | |||
| ksize=3, | |||
| stride=1, | |||
| act=act, | |||
| ), | |||
| ])) | |||
| self.reg_convs.append( | |||
| nn.Sequential(*[ | |||
| Conv( | |||
| in_channels=int(256 * width), | |||
| out_channels=int(256 * width), | |||
| ksize=3, | |||
| stride=1, | |||
| act=act, | |||
| ), | |||
| Conv( | |||
| in_channels=int(256 * width), | |||
| out_channels=int(256 * width), | |||
| ksize=3, | |||
| stride=1, | |||
| act=act, | |||
| ), | |||
| ])) | |||
| self.cls_preds.append( | |||
| nn.Conv2d( | |||
| in_channels=int(256 * width), | |||
| out_channels=self.n_anchors * self.num_classes, | |||
| kernel_size=1, | |||
| stride=1, | |||
| padding=0, | |||
| )) | |||
| self.reg_preds.append( | |||
| nn.Conv2d( | |||
| in_channels=int(256 * width), | |||
| out_channels=4, | |||
| kernel_size=1, | |||
| stride=1, | |||
| padding=0, | |||
| )) | |||
| self.obj_preds.append( | |||
| nn.Conv2d( | |||
| in_channels=int(256 * width), | |||
| out_channels=self.n_anchors * 1, | |||
| kernel_size=1, | |||
| stride=1, | |||
| padding=0, | |||
| )) | |||
| self.use_l1 = False | |||
| self.l1_loss = nn.L1Loss(reduction='none') | |||
| self.bcewithlog_loss = nn.BCEWithLogitsLoss(reduction='none') | |||
| # self.iou_loss = IOUloss(reduction="none") | |||
| self.strides = strides | |||
| self.grids = [torch.zeros(1)] * len(in_channels) | |||
| def initialize_biases(self, prior_prob): | |||
| for conv in self.cls_preds: | |||
| b = conv.bias.view(self.n_anchors, -1) | |||
| b.data.fill_(-math.log((1 - prior_prob) / prior_prob)) | |||
| conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True) | |||
| for conv in self.obj_preds: | |||
| b = conv.bias.view(self.n_anchors, -1) | |||
| b.data.fill_(-math.log((1 - prior_prob) / prior_prob)) | |||
| conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True) | |||
| def forward(self, xin, labels=None, imgs=None): | |||
| outputs = [] | |||
| for k, (cls_conv, reg_conv, stride_this_level, x) in enumerate( | |||
| zip(self.cls_convs, self.reg_convs, self.strides, xin)): | |||
| x = self.stems[k](x) | |||
| cls_x = x | |||
| reg_x = x | |||
| cls_feat = cls_conv(cls_x) | |||
| cls_output = self.cls_preds[k](cls_feat) | |||
| reg_feat = reg_conv(reg_x) | |||
| reg_output = self.reg_preds[k](reg_feat) | |||
| obj_output = self.obj_preds[k](reg_feat) | |||
| if self.training: | |||
| pass | |||
| else: | |||
| output = torch.cat( | |||
| [reg_output, | |||
| obj_output.sigmoid(), | |||
| cls_output.sigmoid()], 1) | |||
| outputs.append(output) | |||
| if self.training: | |||
| pass | |||
| else: | |||
| self.hw = [x.shape[-2:] for x in outputs] | |||
| # [batch, n_anchors_all, 85] | |||
| outputs = torch.cat([x.flatten(start_dim=2) for x in outputs], | |||
| dim=2).permute(0, 2, 1) | |||
| if self.decode_in_inference: | |||
| return self.decode_outputs(outputs, dtype=xin[0].type()) | |||
| else: | |||
| return outputs | |||
| def decode_outputs(self, outputs, dtype): | |||
| grids = [] | |||
| strides = [] | |||
| for (hsize, wsize), stride in zip(self.hw, self.strides): | |||
| yv, xv = meshgrid([torch.arange(hsize), torch.arange(wsize)]) | |||
| grid = torch.stack((xv, yv), 2).view(1, -1, 2) | |||
| grids.append(grid) | |||
| shape = grid.shape[:2] | |||
| strides.append(torch.full((*shape, 1), stride)) | |||
| grids = torch.cat(grids, dim=1).type(dtype) | |||
| strides = torch.cat(strides, dim=1).type(dtype) | |||
| outputs[..., :2] = (outputs[..., :2] + grids) * strides | |||
| outputs[..., 2:4] = torch.exp(outputs[..., 2:4]) * strides | |||
| return outputs | |||
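
decode_outputs maps per-cell predictions to image coordinates: center offsets are added to the cell's grid position and multiplied by the stride, and width/height logits are exponentiated and scaled by the same stride. A worked example for a single cell (values are illustrative):

    import math

    s, gx, gy = 16, 10, 7                      # stride-16 level, cell at grid (10, 7)
    tx, ty, tw, th = 0.3, 0.6, 0.2, -0.1       # raw regression outputs for that cell
    cx, cy = (tx + gx) * s, (ty + gy) * s      # (164.8, 121.6)
    w, h = math.exp(tw) * s, math.exp(th) * s  # (~19.5, ~14.5)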
| @@ -0,0 +1,126 @@ | |||
| # The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX | |||
| import torch | |||
| import torch.nn as nn | |||
| from .darknet import CSPDarknet | |||
| from .network_blocks import BaseConv, CSPLayer, DWConv | |||
| class YOLOPAFPN(nn.Module): | |||
| """ | |||
 | YOLO PAFPN neck used by YOLOX. CSPDarknet is the default backbone of this model. | |||
| """ | |||
| def __init__( | |||
| self, | |||
| depth=1.0, | |||
| width=1.0, | |||
| in_features=('dark3', 'dark4', 'dark5'), | |||
| in_channels=[256, 512, 1024], | |||
| depthwise=False, | |||
| act='silu', | |||
| ): | |||
| super(YOLOPAFPN, self).__init__() | |||
| self.backbone = CSPDarknet(depth, width, depthwise=depthwise, act=act) | |||
| self.in_features = in_features | |||
| self.in_channels = in_channels | |||
| Conv = DWConv if depthwise else BaseConv | |||
| self.upsample = nn.Upsample(scale_factor=2, mode='nearest') | |||
| self.lateral_conv0 = BaseConv( | |||
| int(in_channels[2] * width), | |||
| int(in_channels[1] * width), | |||
| 1, | |||
| 1, | |||
| act=act) | |||
| self.C3_p4 = CSPLayer( | |||
| int(2 * in_channels[1] * width), | |||
| int(in_channels[1] * width), | |||
| round(3 * depth), | |||
| False, | |||
| depthwise=depthwise, | |||
| act=act, | |||
| ) # cat | |||
| self.reduce_conv1 = BaseConv( | |||
| int(in_channels[1] * width), | |||
| int(in_channels[0] * width), | |||
| 1, | |||
| 1, | |||
| act=act) | |||
| self.C3_p3 = CSPLayer( | |||
| int(2 * in_channels[0] * width), | |||
| int(in_channels[0] * width), | |||
| round(3 * depth), | |||
| False, | |||
| depthwise=depthwise, | |||
| act=act, | |||
| ) | |||
| # bottom-up conv | |||
| self.bu_conv2 = Conv( | |||
| int(in_channels[0] * width), | |||
| int(in_channels[0] * width), | |||
| 3, | |||
| 2, | |||
| act=act) | |||
| self.C3_n3 = CSPLayer( | |||
| int(2 * in_channels[0] * width), | |||
| int(in_channels[1] * width), | |||
| round(3 * depth), | |||
| False, | |||
| depthwise=depthwise, | |||
| act=act, | |||
| ) | |||
| # bottom-up conv | |||
| self.bu_conv1 = Conv( | |||
| int(in_channels[1] * width), | |||
| int(in_channels[1] * width), | |||
| 3, | |||
| 2, | |||
| act=act) | |||
| self.C3_n4 = CSPLayer( | |||
| int(2 * in_channels[1] * width), | |||
| int(in_channels[2] * width), | |||
| round(3 * depth), | |||
| False, | |||
| depthwise=depthwise, | |||
| act=act, | |||
| ) | |||
| def forward(self, input): | |||
| """ | |||
| Args: | |||
 | input (Tensor): input image. | |||
 | Returns: | |||
 | Tuple[Tensor]: FPN features. | |||
| """ | |||
| # backbone | |||
| out_features = self.backbone(input) | |||
| features = [out_features[f] for f in self.in_features] | |||
| [x2, x1, x0] = features | |||
| fpn_out0 = self.lateral_conv0(x0) # 1024->512/32 | |||
| f_out0 = self.upsample(fpn_out0) # 512/16 | |||
| f_out0 = torch.cat([f_out0, x1], 1) # 512->1024/16 | |||
| f_out0 = self.C3_p4(f_out0) # 1024->512/16 | |||
| fpn_out1 = self.reduce_conv1(f_out0) # 512->256/16 | |||
| f_out1 = self.upsample(fpn_out1) # 256/8 | |||
| f_out1 = torch.cat([f_out1, x2], 1) # 256->512/8 | |||
| pan_out2 = self.C3_p3(f_out1) # 512->256/8 | |||
| p_out1 = self.bu_conv2(pan_out2) # 256->256/16 | |||
| p_out1 = torch.cat([p_out1, fpn_out1], 1) # 256->512/16 | |||
| pan_out1 = self.C3_n3(p_out1) # 512->512/16 | |||
| p_out0 = self.bu_conv1(pan_out1) # 512->512/32 | |||
| p_out0 = torch.cat([p_out0, fpn_out0], 1) # 512->1024/32 | |||
| pan_out0 = self.C3_n4(p_out0) # 1024->1024/32 | |||
| outputs = (pan_out2, pan_out1, pan_out0) | |||
| return outputs | |||
| @@ -0,0 +1,33 @@ | |||
| # The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX | |||
| import torch.nn as nn | |||
| from .yolo_head import YOLOXHead | |||
| from .yolo_pafpn import YOLOPAFPN | |||
| class YOLOX(nn.Module): | |||
| """ | |||
 | YOLOX model module, composed of a YOLOPAFPN backbone and a YOLOXHead. | |||
 | The network returns loss values from the three YOLO levels during training | |||
 | and decoded detection results during test (training is not supported in this port). | |||
| """ | |||
| def __init__(self, backbone=None, head=None): | |||
| super(YOLOX, self).__init__() | |||
| if backbone is None: | |||
| backbone = YOLOPAFPN() | |||
| if head is None: | |||
| head = YOLOXHead(80) | |||
| self.backbone = backbone | |||
| self.head = head | |||
| def forward(self, x, targets=None): | |||
| fpn_outs = self.backbone(x) | |||
| if self.training: | |||
| raise NotImplementedError('Training is not supported yet!') | |||
| else: | |||
| outputs = self.head(fpn_outs) | |||
| return outputs | |||
| @@ -0,0 +1,5 @@ | |||
| # The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX | |||
| from .boxes import * # noqa | |||
| __all__ = ['bboxes_iou', 'meshgrid', 'postprocess', 'xyxy2cxcywh', 'xyxy2xywh'] | |||
| @@ -0,0 +1,107 @@ | |||
| # The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX | |||
| import torch | |||
| import torchvision | |||
| _TORCH_VER = [int(x) for x in torch.__version__.split('.')[:2]] | |||
| def meshgrid(*tensors): | |||
| if _TORCH_VER >= [1, 10]: | |||
| return torch.meshgrid(*tensors, indexing='ij') | |||
| else: | |||
| return torch.meshgrid(*tensors) | |||
| def xyxy2xywh(bboxes): | |||
| bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 0] | |||
| bboxes[:, 3] = bboxes[:, 3] - bboxes[:, 1] | |||
| return bboxes | |||
| def xyxy2cxcywh(bboxes): | |||
| bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 0] | |||
| bboxes[:, 3] = bboxes[:, 3] - bboxes[:, 1] | |||
| bboxes[:, 0] = bboxes[:, 0] + bboxes[:, 2] * 0.5 | |||
| bboxes[:, 1] = bboxes[:, 1] + bboxes[:, 3] * 0.5 | |||
| return bboxes | |||
| def postprocess(prediction, | |||
| num_classes, | |||
| conf_thre=0.7, | |||
| nms_thre=0.45, | |||
| class_agnostic=False): | |||
| box_corner = prediction.new(prediction.shape) | |||
| box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2 | |||
| box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2 | |||
| box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2 | |||
| box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2 | |||
| prediction[:, :, :4] = box_corner[:, :, :4] | |||
| output = [None for _ in range(len(prediction))] | |||
| for i, image_pred in enumerate(prediction): | |||
| # If none are remaining => process next image | |||
| if not image_pred.size(0): | |||
| continue | |||
| # Get score and class with highest confidence | |||
| class_conf, class_pred = torch.max( | |||
| image_pred[:, 5:5 + num_classes], 1, keepdim=True) | |||
| conf_mask = image_pred[:, 4] * class_conf.squeeze() | |||
| conf_mask = (conf_mask >= conf_thre).squeeze() | |||
| # Detections ordered as (x1, y1, x2, y2, obj_conf, class_conf, class_pred) | |||
| detections = torch.cat( | |||
| (image_pred[:, :5], class_conf, class_pred.float()), 1) | |||
| detections = detections[conf_mask] | |||
| if not detections.size(0): | |||
| continue | |||
| if class_agnostic: | |||
| nms_out_index = torchvision.ops.nms( | |||
| detections[:, :4], | |||
| detections[:, 4] * detections[:, 5], | |||
| nms_thre, | |||
| ) | |||
| else: | |||
| nms_out_index = torchvision.ops.batched_nms( | |||
| detections[:, :4], | |||
| detections[:, 4] * detections[:, 5], | |||
| detections[:, 6], | |||
| nms_thre, | |||
| ) | |||
| detections = detections[nms_out_index] | |||
| if output[i] is None: | |||
| output[i] = detections | |||
| else: | |||
| output[i] = torch.cat((output[i], detections)) | |||
| return output | |||
| def bboxes_iou(bboxes_a, bboxes_b, xyxy=True): | |||
| if bboxes_a.shape[1] != 4 or bboxes_b.shape[1] != 4: | |||
| raise IndexError | |||
| if xyxy: | |||
| tl = torch.max(bboxes_a[:, None, :2], bboxes_b[:, :2]) | |||
| br = torch.min(bboxes_a[:, None, 2:], bboxes_b[:, 2:]) | |||
| area_a = torch.prod(bboxes_a[:, 2:] - bboxes_a[:, :2], 1) | |||
| area_b = torch.prod(bboxes_b[:, 2:] - bboxes_b[:, :2], 1) | |||
| else: | |||
| tl = torch.max( | |||
| (bboxes_a[:, None, :2] - bboxes_a[:, None, 2:] / 2), | |||
| (bboxes_b[:, :2] - bboxes_b[:, 2:] / 2), | |||
| ) | |||
| br = torch.min( | |||
| (bboxes_a[:, None, :2] + bboxes_a[:, None, 2:] / 2), | |||
| (bboxes_b[:, :2] + bboxes_b[:, 2:] / 2), | |||
| ) | |||
| area_a = torch.prod(bboxes_a[:, 2:], 1) | |||
| area_b = torch.prod(bboxes_b[:, 2:], 1) | |||
| en = (tl < br).type(tl.type()).prod(dim=2) | |||
| area_i = torch.prod(br - tl, 2) * en # * ((tl < br).all()) | |||
| return area_i / (area_a[:, None] + area_b - area_i) | |||
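
In postprocess above, each prediction row is (cx, cy, w, h, obj_conf, cls_0, ..., cls_{num_classes-1}): boxes are converted to corner form, kept when obj_conf times the best class probability reaches conf_thre, and NMS then ranks boxes by that same product. A small self-contained example of the filtering step:

    import torch

    num_classes, conf_thre = 3, 0.5
    # two candidate boxes for one image: (cx, cy, w, h, obj_conf, class probs...)
    pred = torch.tensor([[100., 100., 20., 40., 0.9, 0.1, 0.8, 0.1],
                         [200., 150., 30., 30., 0.4, 0.6, 0.2, 0.2]])
    class_conf, class_pred = torch.max(pred[:, 5:5 + num_classes], 1, keepdim=True)
    keep = (pred[:, 4] * class_conf.squeeze()) >= conf_thre
    print(keep)  # tensor([ True, False]): 0.9*0.8=0.72 passes, 0.4*0.6=0.24 does not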
| @@ -32,6 +32,7 @@ if TYPE_CHECKING: | |||
| from .image_to_image_generate_pipeline import Image2ImageGenerationPipeline | |||
| from .image_to_image_translation_pipeline import Image2ImageTranslationPipeline | |||
| from .product_retrieval_embedding_pipeline import ProductRetrievalEmbeddingPipeline | |||
| from .realtime_object_detection_pipeline import RealtimeObjectDetectionPipeline | |||
| from .live_category_pipeline import LiveCategoryPipeline | |||
| from .ocr_detection_pipeline import OCRDetectionPipeline | |||
| from .ocr_recognition_pipeline import OCRRecognitionPipeline | |||
| @@ -75,6 +76,8 @@ else: | |||
| ['Image2ImageTranslationPipeline'], | |||
| 'product_retrieval_embedding_pipeline': | |||
| ['ProductRetrievalEmbeddingPipeline'], | |||
| 'realtime_object_detection_pipeline': | |||
| ['RealtimeObjectDetectionPipeline'], | |||
| 'live_category_pipeline': ['LiveCategoryPipeline'], | |||
| 'image_to_image_generation_pipeline': | |||
| ['Image2ImageGenerationPipeline'], | |||
| @@ -0,0 +1,50 @@ | |||
 | from typing import Dict, Union | |||
 | import numpy as np | |||
 | from modelscope.metainfo import Pipelines | |||
 | from modelscope.models.cv.realtime_object_detection import RealtimeDetector | |||
 | from modelscope.outputs import OutputKeys | |||
 | from modelscope.pipelines.base import Input, Pipeline, Tensor | |||
 | from modelscope.pipelines.builder import PIPELINES | |||
 | from modelscope.utils.constant import Tasks | |||
 | from modelscope.utils.logger import get_logger | |||
| logger = get_logger() | |||
| @PIPELINES.register_module( | |||
| Tasks.image_object_detection, | |||
| module_name=Pipelines.realtime_object_detection) | |||
| class RealtimeObjectDetectionPipeline(Pipeline): | |||
| def __init__(self, model: str, **kwargs): | |||
| super().__init__(model=model, **kwargs) | |||
| self.model = RealtimeDetector(model) | |||
| def preprocess(self, input: Input) -> Dict[Tensor, Union[str, np.ndarray]]: | |||
| output = self.model.preprocess(input) | |||
| return {'pre_output': output} | |||
| def forward(self, input: Tensor) -> Dict[Tensor, Dict[str, np.ndarray]]: | |||
| pre_output = input['pre_output'] | |||
| forward_output = self.model(pre_output) | |||
| return {'forward_output': forward_output} | |||
| def postprocess(self, input: Dict[Tensor, Dict[str, np.ndarray]], | |||
| **kwargs) -> str: | |||
| forward_output = input['forward_output'] | |||
| bboxes, scores, labels = forward_output | |||
| return { | |||
| OutputKeys.BOXES: bboxes, | |||
| OutputKeys.SCORES: scores, | |||
| OutputKeys.LABELS: labels, | |||
| } | |||
| @@ -70,6 +70,13 @@ def draw_box(image, box): | |||
| (int(box[1][0]), int(box[1][1])), (0, 0, 255), 2) | |||
| def realtime_object_detection_bbox_vis(image, bboxes): | |||
| for bbox in bboxes: | |||
| cv2.rectangle(image, (bbox[0], bbox[1]), (bbox[2], bbox[3]), | |||
| (255, 0, 0), 2) | |||
| return image | |||
| def draw_keypoints(output, original_image): | |||
| poses = np.array(output[OutputKeys.POSES]) | |||
| scores = np.array(output[OutputKeys.SCORES]) | |||
| @@ -0,0 +1,52 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import unittest | |||
| import cv2 | |||
| import numpy as np | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.pipelines.base import Pipeline | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.cv.image_utils import realtime_object_detection_bbox_vis | |||
| from modelscope.utils.test_utils import test_level | |||
| class RealtimeObjectDetectionTest(unittest.TestCase): | |||
| def setUp(self) -> None: | |||
| self.model_id = 'damo/cv_cspnet_image-object-detection_yolox' | |||
| self.model_nano_id = 'damo/cv_cspnet_image-object-detection_yolox_nano_coco' | |||
| self.test_image = 'data/test/images/keypoints_detect/000000438862.jpg' | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_modelhub(self): | |||
| realtime_object_detection = pipeline( | |||
| Tasks.image_object_detection, model=self.model_id) | |||
| image = cv2.imread(self.test_image) | |||
| result = realtime_object_detection(image) | |||
| if result: | |||
| bboxes = result[OutputKeys.BOXES].astype(int) | |||
| image = realtime_object_detection_bbox_vis(image, bboxes) | |||
| cv2.imwrite('rt_obj_out.jpg', image) | |||
| else: | |||
| raise ValueError('process error') | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_nano(self): | |||
| realtime_object_detection = pipeline( | |||
| Tasks.image_object_detection, model=self.model_nano_id) | |||
| image = cv2.imread(self.test_image) | |||
| result = realtime_object_detection(image) | |||
| if result: | |||
| bboxes = result[OutputKeys.BOXES].astype(int) | |||
| image = realtime_object_detection_bbox_vis(image, bboxes) | |||
| cv2.imwrite('rtnano_obj_out.jpg', image) | |||
| else: | |||
| raise ValueError('process error') | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||