Add real-time object detection pipeline
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9788299
master
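A minimal usage sketch of the new pipeline (it mirrors the unit test at the end of this review and assumes the 'damo/cv_cspnet_image-object-detection_yolox' model and the test image are available):

import cv2

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Build the detector from a hub model id; the task/pipeline mapping comes from
# the registrations added in this PR.
detector = pipeline(
    Tasks.image_object_detection,
    model='damo/cv_cspnet_image-object-detection_yolox')
result = detector(cv2.imread('data/test/images/keypoints_detect/000000438862.jpg'))
print(result[OutputKeys.BOXES], result[OutputKeys.SCORES], result[OutputKeys.LABELS])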
@@ -11,6 +11,7 @@ class Models(object):
     """
     # vision models
     detection = 'detection'
+    realtime_object_detection = 'realtime-object-detection'
     scrfd = 'scrfd'
     classification_model = 'ClassificationModel'
     nafnet = 'nafnet'
@@ -111,6 +112,7 @@ class Pipelines(object):
     image_super_resolution = 'rrdb-image-super-resolution'
     face_image_generation = 'gan-face-image-generation'
     product_retrieval_embedding = 'resnet50-product-retrieval-embedding'
+    realtime_object_detection = 'cspnet_realtime-object-detection_yolox'
     face_recognition = 'ir101-face-recognition-cfglint'
     image_instance_segmentation = 'cascade-mask-rcnn-swin-image-instance-segmentation'
     image2image_translation = 'image-to-image-translation'
@@ -7,5 +7,5 @@ from . import (action_recognition, animal_recognition, body_2d_keypoints,
                image_reid_person, image_semantic_segmentation,
                image_to_image_generation, image_to_image_translation,
                object_detection, product_retrieval_embedding,
-               salient_detection, super_resolution,
+               realtime_object_detection, salient_detection, super_resolution,
                video_single_object_tracking, video_summarization, virual_tryon)
@@ -0,0 +1,21 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .realtime_detector import RealtimeDetector
else:
    _import_structure = {
        'realtime_detector': ['RealtimeDetector'],
    }

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
@@ -0,0 +1,85 @@
import os
import os.path as osp

import numpy as np
import torch

from modelscope.metainfo import Models
from modelscope.models.base.base_torch_model import TorchModel
from modelscope.models.builder import MODELS
from modelscope.preprocessors import LoadImage
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from .yolox.data.data_augment import ValTransform
from .yolox.exp import get_exp_by_name
from .yolox.utils import postprocess


@MODELS.register_module(
    group_key=Tasks.image_object_detection,
    module_name=Models.realtime_object_detection)
class RealtimeDetector(TorchModel):

    def __init__(self, model_dir: str, *args, **kwargs):
        super().__init__(model_dir, *args, **kwargs)
        self.config = Config.from_file(
            os.path.join(self.model_dir, ModelFile.CONFIGURATION))

        # model type
        self.exp = get_exp_by_name(self.config.model_type)

        # build model
        self.model = self.exp.get_model()
        model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE)
        ckpt = torch.load(model_path, map_location='cpu')

        # load the model state dict
        self.model.load_state_dict(ckpt['model'])
        self.model.eval()

        # params setting
        self.exp.num_classes = self.config.num_classes
        self.confthre = self.config.conf_thr
        self.num_classes = self.exp.num_classes
        self.nmsthre = self.exp.nmsthre
        self.test_size = self.exp.test_size
        self.preproc = ValTransform(legacy=False)

    def inference(self, img):
        with torch.no_grad():
            outputs = self.model(img)
        return outputs

    def forward(self, inputs):
        return self.inference(inputs)

    def preprocess(self, img):
        img = LoadImage.convert_to_ndarray(img)
        # keep the letterbox ratio so boxes can be mapped back in postprocess
        self.ratio = min(self.test_size[0] / img.shape[0],
                         self.test_size[1] / img.shape[1])
        img, _ = self.preproc(img, None, self.test_size)
        img = torch.from_numpy(img).unsqueeze(0)
        img = img.float()
        return img

    def postprocess(self, input):
        outputs = postprocess(
            input,
            self.num_classes,
            self.confthre,
            self.nmsthre,
            class_agnostic=True)
        # postprocess returns one entry per image; the entry is None when no
        # detection survives the confidence threshold.
        if len(outputs) != 1 or outputs[0] is None:
            return np.zeros((0, 4)), np.zeros((0, )), np.zeros((0, ), dtype=int)
        bboxes = outputs[0][:, 0:4].cpu().numpy() / self.ratio
        scores = outputs[0][:, 5].cpu().numpy()
        labels = outputs[0][:, 6].cpu().int().numpy()
        return bboxes, scores, labels
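For reference, a small numeric sketch of the ratio bookkeeping above: boxes are predicted on the padded test_size canvas and mapped back to original-image coordinates by dividing by the letterbox ratio.

# Sketch only: a 960x1280 (h x w) image letterboxed into the 640x640 test canvas.
test_size = (640, 640)
ratio = min(test_size[0] / 960, test_size[1] / 1280)   # 0.5
box_on_canvas = [100.0, 50.0, 300.0, 250.0]            # x1, y1, x2, y2 on the canvas
box_on_image = [v / ratio for v in box_on_canvas]      # [200.0, 100.0, 600.0, 500.0]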
@@ -0,0 +1,69 @@
# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX
"""
Data augmentation functionality. Passed as callable transformations to
Dataset classes.

The data augmentation procedures were interpreted from @weiliu89's SSD paper
http://arxiv.org/abs/1512.02325
"""
import cv2
import numpy as np


def preproc(img, input_size, swap=(2, 0, 1)):
    # letterbox the image onto a grey (value 114) canvas of `input_size`
    if len(img.shape) == 3:
        padded_img = np.ones(
            (input_size[0], input_size[1], 3), dtype=np.uint8) * 114
    else:
        padded_img = np.ones(input_size, dtype=np.uint8) * 114

    r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
    resized_img = cv2.resize(
        img,
        (int(img.shape[1] * r), int(img.shape[0] * r)),
        interpolation=cv2.INTER_LINEAR,
    ).astype(np.uint8)
    padded_img[:int(img.shape[0] * r), :int(img.shape[1] * r)] = resized_img

    padded_img = padded_img.transpose(swap)
    padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
    return padded_img, r


class ValTransform:
    """
    Transformation applied to test/val images before they are fed to the
    network: letterbox resize -> channel reorder -> (optional) legacy
    normalization.

    Arguments:
        swap ((int, int, int)): final order of channels.
        legacy (bool): if True, convert BGR to RGB and apply ImageNet
            mean/std normalization (kept for older checkpoints).
    """

    def __init__(self, swap=(2, 0, 1), legacy=False):
        self.swap = swap
        self.legacy = legacy

    # assume input is cv2 img for now
    def __call__(self, img, res, input_size):
        img, _ = preproc(img, input_size, self.swap)
        if self.legacy:
            img = img[::-1, :, :].copy()
            img /= 255.0
            img -= np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
            img /= np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)
        return img, np.zeros((1, 5))
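A standalone sketch (NumPy/OpenCV only) of what preproc does to an input image before it reaches the network:

import cv2
import numpy as np

# Sketch of the letterbox step in preproc(): a 300x500 BGR image into a (640, 640) canvas.
img = np.random.randint(0, 255, (300, 500, 3), dtype=np.uint8)
input_size = (640, 640)
r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])   # 1.28
canvas = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
resized = cv2.resize(img, (int(img.shape[1] * r), int(img.shape[0] * r)))
canvas[:resized.shape[0], :resized.shape[1]] = resized
chw = canvas.transpose(2, 0, 1).astype(np.float32)                    # shape (3, 640, 640)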
@@ -0,0 +1,5 @@
# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX
from .base_exp import BaseExp
from .build import get_exp_by_name
from .yolox_base import Exp
@@ -0,0 +1,12 @@
# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX
from abc import ABCMeta, abstractmethod

from torch.nn import Module


class BaseExp(metaclass=ABCMeta):

    @abstractmethod
    def get_model(self) -> Module:
        pass
@@ -0,0 +1,18 @@
# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX


def get_exp_by_name(exp_name):
    # convert a name like "yolox-s" to "yolox_s"
    exp = exp_name.replace('-', '_')
    if exp == 'yolox_s':
        from .default import YoloXSExp as YoloXExp
    elif exp == 'yolox_nano':
        from .default import YoloXNanoExp as YoloXExp
    elif exp == 'yolox_tiny':
        from .default import YoloXTinyExp as YoloXExp
    else:
        raise ValueError(f'Unsupported model type: {exp_name}')
    return YoloXExp()
@@ -0,0 +1,5 @@
# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX
from .yolox_nano import YoloXNanoExp
from .yolox_s import YoloXSExp
from .yolox_tiny import YoloXTinyExp
@@ -0,0 +1,46 @@
# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX
import torch.nn as nn

from ..yolox_base import Exp as YoloXExp


class YoloXNanoExp(YoloXExp):

    def __init__(self):
        super(YoloXNanoExp, self).__init__()
        self.depth = 0.33
        self.width = 0.25
        self.input_size = (416, 416)
        self.test_size = (416, 416)

    def get_model(self, sublinear=False):

        def init_yolo(M):
            for m in M.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eps = 1e-3
                    m.momentum = 0.03

        if 'model' not in self.__dict__:
            from ...models import YOLOX, YOLOPAFPN, YOLOXHead
            in_channels = [256, 512, 1024]
            # The NANO model uses depthwise=True, which is the main difference.
            backbone = YOLOPAFPN(
                self.depth,
                self.width,
                in_channels=in_channels,
                act=self.act,
                depthwise=True,
            )
            head = YOLOXHead(
                self.num_classes,
                self.width,
                in_channels=in_channels,
                act=self.act,
                depthwise=True)
            self.model = YOLOX(backbone, head)

        return self.model
@@ -0,0 +1,13 @@
# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX
from ..yolox_base import Exp as YoloXExp


class YoloXSExp(YoloXExp):

    def __init__(self):
        super(YoloXSExp, self).__init__()
        self.depth = 0.33
        self.width = 0.50
@@ -0,0 +1,20 @@
# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX
import os

from ..yolox_base import Exp as YoloXExp


class YoloXTinyExp(YoloXExp):

    def __init__(self):
        super(YoloXTinyExp, self).__init__()
        self.depth = 0.33
        self.width = 0.375
        self.input_size = (416, 416)
        self.mosaic_scale = (0.5, 1.5)
        self.random_size = (10, 20)
        self.test_size = (416, 416)
        self.exp_name = os.path.split(
            os.path.realpath(__file__))[1].split('.')[0]
        self.enable_mixup = False
@@ -0,0 +1,59 @@
# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX
import torch.nn as nn

from .base_exp import BaseExp


class Exp(BaseExp):

    def __init__(self):
        super().__init__()

        # ---------------- model config ---------------- #
        # number of classes the model detects
        self.num_classes = 80
        # factor of model depth
        self.depth = 1.00
        # factor of model width
        self.width = 1.00
        # activation name, e.g. set to 'relu' to replace the default 'silu'
        self.act = 'silu'

        # ----------------- testing config ------------------ #
        # input image size during evaluation/test
        self.test_size = (640, 640)
        # confidence threshold during evaluation/test:
        # boxes whose scores are below test_conf will be filtered out
        self.test_conf = 0.01
        # nms threshold
        self.nmsthre = 0.65

    def get_model(self):
        from ..models import YOLOX, YOLOPAFPN, YOLOXHead

        def init_yolo(M):
            for m in M.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eps = 1e-3
                    m.momentum = 0.03

        if getattr(self, 'model', None) is None:
            in_channels = [256, 512, 1024]
            backbone = YOLOPAFPN(
                self.depth, self.width, in_channels=in_channels, act=self.act)
            head = YOLOXHead(
                self.num_classes,
                self.width,
                in_channels=in_channels,
                act=self.act)
            self.model = YOLOX(backbone, head)

        self.model.apply(init_yolo)
        self.model.head.initialize_biases(1e-2)
        self.model.train()
        return self.model
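How the experiment classes are meant to be used, as a sketch; the import path assumes the package location introduced in this PR.

# Sketch: how RealtimeDetector builds the network from the `model_type` field
# in configuration.json (import path assumed from this PR's package layout).
from modelscope.models.cv.realtime_object_detection.yolox.exp import get_exp_by_name

exp = get_exp_by_name('yolox-s')    # "yolox-s" -> YoloXSExp (depth 0.33, width 0.50)
model = exp.get_model().eval()      # YOLOPAFPN backbone + YOLOXHead, 80 classes by default
print(exp.test_size, exp.nmsthre)   # (640, 640) 0.65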
@@ -0,0 +1,7 @@
# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX
from .darknet import CSPDarknet, Darknet
from .yolo_fpn import YOLOFPN
from .yolo_head import YOLOXHead
from .yolo_pafpn import YOLOPAFPN
from .yolox import YOLOX
@@ -0,0 +1,189 @@
# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX
from torch import nn

from .network_blocks import (BaseConv, CSPLayer, DWConv, Focus, ResLayer,
                             SPPBottleneck)


class Darknet(nn.Module):
    # number of blocks from dark2 to dark5.
    depth2blocks = {21: [1, 2, 2, 1], 53: [2, 8, 8, 4]}

    def __init__(
        self,
        depth,
        in_channels=3,
        stem_out_channels=32,
        out_features=('dark3', 'dark4', 'dark5'),
    ):
        """
        Args:
            depth (int): depth of darknet used in the model, usually 21 or 53.
            in_channels (int): number of input channels, e.g. 3 for an RGB image.
            stem_out_channels (int): number of output channels of the darknet stem.
                It decides the channels of darknet layer2 to layer5.
            out_features (Tuple[str]): desired output layer names.
        """
        super().__init__()
        assert out_features, 'please provide output features of Darknet'
        self.out_features = out_features
        self.stem = nn.Sequential(
            BaseConv(
                in_channels, stem_out_channels, ksize=3, stride=1,
                act='lrelu'),
            *self.make_group_layer(stem_out_channels, num_blocks=1, stride=2),
        )
        in_channels = stem_out_channels * 2  # 64

        num_blocks = Darknet.depth2blocks[depth]
        # Create darknet with `stem_out_channels` and `num_blocks` layers.
        # The stages are written out explicitly to keep the structure clear.
        self.dark2 = nn.Sequential(
            *self.make_group_layer(in_channels, num_blocks[0], stride=2))
        in_channels *= 2  # 128
        self.dark3 = nn.Sequential(
            *self.make_group_layer(in_channels, num_blocks[1], stride=2))
        in_channels *= 2  # 256
        self.dark4 = nn.Sequential(
            *self.make_group_layer(in_channels, num_blocks[2], stride=2))
        in_channels *= 2  # 512
        self.dark5 = nn.Sequential(
            *self.make_group_layer(in_channels, num_blocks[3], stride=2),
            *self.make_spp_block([in_channels, in_channels * 2],
                                 in_channels * 2),
        )

    def make_group_layer(self,
                         in_channels: int,
                         num_blocks: int,
                         stride: int = 1):
        'starts with conv layer then has `num_blocks` `ResLayer`'
        return [
            BaseConv(
                in_channels,
                in_channels * 2,
                ksize=3,
                stride=stride,
                act='lrelu'),
            *[(ResLayer(in_channels * 2)) for _ in range(num_blocks)],
        ]

    def make_spp_block(self, filters_list, in_filters):
        m = nn.Sequential(*[
            BaseConv(in_filters, filters_list[0], 1, stride=1, act='lrelu'),
            BaseConv(
                filters_list[0], filters_list[1], 3, stride=1, act='lrelu'),
            SPPBottleneck(
                in_channels=filters_list[1],
                out_channels=filters_list[0],
                activation='lrelu',
            ),
            BaseConv(
                filters_list[0], filters_list[1], 3, stride=1, act='lrelu'),
            BaseConv(
                filters_list[1], filters_list[0], 1, stride=1, act='lrelu'),
        ])
        return m

    def forward(self, x):
        outputs = {}
        x = self.stem(x)
        outputs['stem'] = x
        x = self.dark2(x)
        outputs['dark2'] = x
        x = self.dark3(x)
        outputs['dark3'] = x
        x = self.dark4(x)
        outputs['dark4'] = x
        x = self.dark5(x)
        outputs['dark5'] = x
        return {k: v for k, v in outputs.items() if k in self.out_features}


class CSPDarknet(nn.Module):

    def __init__(
        self,
        dep_mul,
        wid_mul,
        out_features=('dark3', 'dark4', 'dark5'),
        depthwise=False,
        act='silu',
    ):
        super().__init__()
        assert out_features, 'please provide output features of Darknet'
        self.out_features = out_features
        Conv = DWConv if depthwise else BaseConv

        base_channels = int(wid_mul * 64)  # 64
        base_depth = max(round(dep_mul * 3), 1)  # 3

        # stem
        self.stem = Focus(3, base_channels, ksize=3, act=act)

        # dark2
        self.dark2 = nn.Sequential(
            Conv(base_channels, base_channels * 2, 3, 2, act=act),
            CSPLayer(
                base_channels * 2,
                base_channels * 2,
                n=base_depth,
                depthwise=depthwise,
                act=act,
            ),
        )

        # dark3
        self.dark3 = nn.Sequential(
            Conv(base_channels * 2, base_channels * 4, 3, 2, act=act),
            CSPLayer(
                base_channels * 4,
                base_channels * 4,
                n=base_depth * 3,
                depthwise=depthwise,
                act=act,
            ),
        )

        # dark4
        self.dark4 = nn.Sequential(
            Conv(base_channels * 4, base_channels * 8, 3, 2, act=act),
            CSPLayer(
                base_channels * 8,
                base_channels * 8,
                n=base_depth * 3,
                depthwise=depthwise,
                act=act,
            ),
        )

        # dark5
        self.dark5 = nn.Sequential(
            Conv(base_channels * 8, base_channels * 16, 3, 2, act=act),
            SPPBottleneck(
                base_channels * 16, base_channels * 16, activation=act),
            CSPLayer(
                base_channels * 16,
                base_channels * 16,
                n=base_depth,
                shortcut=False,
                depthwise=depthwise,
                act=act,
            ),
        )

    def forward(self, x):
        outputs = {}
        x = self.stem(x)
        outputs['stem'] = x
        x = self.dark2(x)
        outputs['dark2'] = x
        x = self.dark3(x)
        outputs['dark3'] = x
        x = self.dark4(x)
        outputs['dark4'] = x
        x = self.dark5(x)
        outputs['dark5'] = x
        return {k: v for k, v in outputs.items() if k in self.out_features}
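A quick shape check for CSPDarknet, as a sketch only; it assumes the package path introduced in this PR is importable and simply instantiates a randomly initialized backbone.

import torch

from modelscope.models.cv.realtime_object_detection.yolox.models import CSPDarknet

backbone = CSPDarknet(dep_mul=0.33, wid_mul=0.50)   # yolox-s sized backbone
feats = backbone(torch.randn(1, 3, 640, 640))
for name, feat in feats.items():
    # dark3 -> stride 8, dark4 -> stride 16, dark5 -> stride 32
    print(name, tuple(feat.shape))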
@@ -0,0 +1,213 @@
# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX
import torch
import torch.nn as nn


def get_activation(name='silu', inplace=True):
    if name == 'silu':
        module = nn.SiLU(inplace=inplace)
    elif name == 'relu':
        module = nn.ReLU(inplace=inplace)
    elif name == 'lrelu':
        # Darknet/YOLOFPN blocks in this package are built with act='lrelu'
        module = nn.LeakyReLU(0.1, inplace=inplace)
    else:
        raise AttributeError('Unsupported act type: {}'.format(name))
    return module


class BaseConv(nn.Module):
    """A Conv2d -> BatchNorm -> silu/leaky-relu block."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 ksize,
                 stride,
                 groups=1,
                 bias=False,
                 act='silu'):
        super(BaseConv, self).__init__()
        # same padding
        pad = (ksize - 1) // 2
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=ksize,
            stride=stride,
            padding=pad,
            groups=groups,
            bias=bias,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = get_activation(act, inplace=True)

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

    def fuseforward(self, x):
        return self.act(self.conv(x))


class DWConv(nn.Module):
    """Depthwise Conv + Conv"""

    def __init__(self, in_channels, out_channels, ksize, stride=1, act='silu'):
        super(DWConv, self).__init__()
        self.dconv = BaseConv(
            in_channels,
            in_channels,
            ksize=ksize,
            stride=stride,
            groups=in_channels,
            act=act,
        )
        self.pconv = BaseConv(
            in_channels, out_channels, ksize=1, stride=1, groups=1, act=act)

    def forward(self, x):
        x = self.dconv(x)
        return self.pconv(x)


class Bottleneck(nn.Module):
    # Standard bottleneck
    def __init__(
        self,
        in_channels,
        out_channels,
        shortcut=True,
        expansion=0.5,
        depthwise=False,
        act='silu',
    ):
        super().__init__()
        hidden_channels = int(out_channels * expansion)
        Conv = DWConv if depthwise else BaseConv
        self.conv1 = BaseConv(
            in_channels, hidden_channels, 1, stride=1, act=act)
        self.conv2 = Conv(hidden_channels, out_channels, 3, stride=1, act=act)
        self.use_add = shortcut and in_channels == out_channels

    def forward(self, x):
        y = self.conv2(self.conv1(x))
        if self.use_add:
            y = y + x
        return y


class ResLayer(nn.Module):
    'Residual layer with `in_channels` inputs.'

    def __init__(self, in_channels: int):
        super().__init__()
        mid_channels = in_channels // 2
        self.layer1 = BaseConv(
            in_channels, mid_channels, ksize=1, stride=1, act='lrelu')
        self.layer2 = BaseConv(
            mid_channels, in_channels, ksize=3, stride=1, act='lrelu')

    def forward(self, x):
        out = self.layer2(self.layer1(x))
        return x + out


class SPPBottleneck(nn.Module):
    """Spatial pyramid pooling layer used in YOLOv3-SPP"""

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_sizes=(5, 9, 13),
                 activation='silu'):
        super().__init__()
        hidden_channels = in_channels // 2
        self.conv1 = BaseConv(
            in_channels, hidden_channels, 1, stride=1, act=activation)
        self.m = nn.ModuleList([
            nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2)
            for ks in kernel_sizes
        ])
        conv2_channels = hidden_channels * (len(kernel_sizes) + 1)
        self.conv2 = BaseConv(
            conv2_channels, out_channels, 1, stride=1, act=activation)

    def forward(self, x):
        x = self.conv1(x)
        x = torch.cat([x] + [m(x) for m in self.m], dim=1)
        x = self.conv2(x)
        return x


class CSPLayer(nn.Module):
    """C3 in yolov5, CSP Bottleneck with 3 convolutions"""

    def __init__(
        self,
        in_channels,
        out_channels,
        n=1,
        shortcut=True,
        expansion=0.5,
        depthwise=False,
        act='silu',
    ):
        """
        Args:
            in_channels (int): input channels.
            out_channels (int): output channels.
            n (int): number of Bottlenecks. Default value: 1.
        """
        # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        hidden_channels = int(out_channels * expansion)  # hidden channels
        self.conv1 = BaseConv(
            in_channels, hidden_channels, 1, stride=1, act=act)
        self.conv2 = BaseConv(
            in_channels, hidden_channels, 1, stride=1, act=act)
        self.conv3 = BaseConv(
            2 * hidden_channels, out_channels, 1, stride=1, act=act)
        module_list = [
            Bottleneck(
                hidden_channels,
                hidden_channels,
                shortcut,
                1.0,
                depthwise,
                act=act) for _ in range(n)
        ]
        self.m = nn.Sequential(*module_list)

    def forward(self, x):
        x_1 = self.conv1(x)
        x_2 = self.conv2(x)
        x_1 = self.m(x_1)
        x = torch.cat((x_1, x_2), dim=1)
        return self.conv3(x)


class Focus(nn.Module):
    """Focus width and height information into channel space."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 ksize=1,
                 stride=1,
                 act='silu'):
        super().__init__()
        self.conv = BaseConv(
            in_channels * 4, out_channels, ksize, stride, act=act)

    def forward(self, x):
        # shape of x (b,c,h,w) -> y (b,4c,h/2,w/2)
        patch_top_left = x[..., ::2, ::2]
        patch_top_right = x[..., ::2, 1::2]
        patch_bot_left = x[..., 1::2, ::2]
        patch_bot_right = x[..., 1::2, 1::2]
        x = torch.cat(
            (
                patch_top_left,
                patch_bot_left,
                patch_top_right,
                patch_bot_right,
            ),
            dim=1,
        )
        return self.conv(x)
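A tiny demonstration of the space-to-depth rearrangement that Focus.forward performs before its convolution:

import torch

# Space-to-depth by a factor of 2, as in Focus.forward (before the convolution).
x = torch.arange(16.0).reshape(1, 1, 4, 4)
patches = torch.cat(
    (x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]),
    dim=1)
print(patches.shape)   # torch.Size([1, 4, 2, 2]): (b, c, h, w) -> (b, 4c, h/2, w/2)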
@@ -0,0 +1,80 @@
# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX
import torch
import torch.nn as nn

from .darknet import Darknet
from .network_blocks import BaseConv


class YOLOFPN(nn.Module):
    """
    YOLOFPN module. Darknet 53 is the default backbone of this model.
    """

    def __init__(
        self,
        depth=53,
        in_features=['dark3', 'dark4', 'dark5'],
    ):
        super(YOLOFPN, self).__init__()

        self.backbone = Darknet(depth)
        self.in_features = in_features

        # out 1
        self.out1_cbl = self._make_cbl(512, 256, 1)
        self.out1 = self._make_embedding([256, 512], 512 + 256)

        # out 2
        self.out2_cbl = self._make_cbl(256, 128, 1)
        self.out2 = self._make_embedding([128, 256], 256 + 128)

        # upsample
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')

    def _make_cbl(self, _in, _out, ks):
        return BaseConv(_in, _out, ks, stride=1, act='lrelu')

    def _make_embedding(self, filters_list, in_filters):
        m = nn.Sequential(*[
            self._make_cbl(in_filters, filters_list[0], 1),
            self._make_cbl(filters_list[0], filters_list[1], 3),
            self._make_cbl(filters_list[1], filters_list[0], 1),
            self._make_cbl(filters_list[0], filters_list[1], 3),
            self._make_cbl(filters_list[1], filters_list[0], 1),
        ])
        return m

    def load_pretrained_model(self, filename='./weights/darknet53.mix.pth'):
        with open(filename, 'rb') as f:
            state_dict = torch.load(f, map_location='cpu')
        print('loading pretrained weights...')
        self.backbone.load_state_dict(state_dict)

    def forward(self, inputs):
        """
        Args:
            inputs (Tensor): input image.

        Returns:
            Tuple[Tensor]: FPN output features.
        """
        # backbone
        out_features = self.backbone(inputs)
        x2, x1, x0 = [out_features[f] for f in self.in_features]

        # yolo branch 1
        x1_in = self.out1_cbl(x0)
        x1_in = self.upsample(x1_in)
        x1_in = torch.cat([x1_in, x1], 1)
        out_dark4 = self.out1(x1_in)

        # yolo branch 2
        x2_in = self.out2_cbl(out_dark4)
        x2_in = self.upsample(x2_in)
        x2_in = torch.cat([x2_in, x2], 1)
        out_dark3 = self.out2(x2_in)

        outputs = (out_dark3, out_dark4, x0)
        return outputs
@@ -0,0 +1,182 @@
# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX
import math

import torch
import torch.nn as nn

from ..utils import meshgrid
from .network_blocks import BaseConv, DWConv


class YOLOXHead(nn.Module):

    def __init__(
        self,
        num_classes,
        width=1.0,
        strides=[8, 16, 32],
        in_channels=[256, 512, 1024],
        act='silu',
        depthwise=False,
    ):
        """
        Args:
            act (str): activation type of conv. Default value: "silu".
            depthwise (bool): whether to apply depthwise conv in the conv branch. Default value: False.
        """
        super(YOLOXHead, self).__init__()

        self.n_anchors = 1
        self.num_classes = num_classes
        self.decode_in_inference = True  # for deploy, set to False

        self.cls_convs = nn.ModuleList()
        self.reg_convs = nn.ModuleList()
        self.cls_preds = nn.ModuleList()
        self.reg_preds = nn.ModuleList()
        self.obj_preds = nn.ModuleList()
        self.stems = nn.ModuleList()
        Conv = DWConv if depthwise else BaseConv

        for i in range(len(in_channels)):
            self.stems.append(
                BaseConv(
                    in_channels=int(in_channels[i] * width),
                    out_channels=int(256 * width),
                    ksize=1,
                    stride=1,
                    act=act,
                ))
            self.cls_convs.append(
                nn.Sequential(*[
                    Conv(
                        in_channels=int(256 * width),
                        out_channels=int(256 * width),
                        ksize=3,
                        stride=1,
                        act=act,
                    ),
                    Conv(
                        in_channels=int(256 * width),
                        out_channels=int(256 * width),
                        ksize=3,
                        stride=1,
                        act=act,
                    ),
                ]))
            self.reg_convs.append(
                nn.Sequential(*[
                    Conv(
                        in_channels=int(256 * width),
                        out_channels=int(256 * width),
                        ksize=3,
                        stride=1,
                        act=act,
                    ),
                    Conv(
                        in_channels=int(256 * width),
                        out_channels=int(256 * width),
                        ksize=3,
                        stride=1,
                        act=act,
                    ),
                ]))
            self.cls_preds.append(
                nn.Conv2d(
                    in_channels=int(256 * width),
                    out_channels=self.n_anchors * self.num_classes,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                ))
            self.reg_preds.append(
                nn.Conv2d(
                    in_channels=int(256 * width),
                    out_channels=4,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                ))
            self.obj_preds.append(
                nn.Conv2d(
                    in_channels=int(256 * width),
                    out_channels=self.n_anchors * 1,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                ))

        self.use_l1 = False
        self.l1_loss = nn.L1Loss(reduction='none')
        self.bcewithlog_loss = nn.BCEWithLogitsLoss(reduction='none')
        # self.iou_loss = IOUloss(reduction="none")
        self.strides = strides
        self.grids = [torch.zeros(1)] * len(in_channels)

    def initialize_biases(self, prior_prob):
        for conv in self.cls_preds:
            b = conv.bias.view(self.n_anchors, -1)
            b.data.fill_(-math.log((1 - prior_prob) / prior_prob))
            conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)

        for conv in self.obj_preds:
            b = conv.bias.view(self.n_anchors, -1)
            b.data.fill_(-math.log((1 - prior_prob) / prior_prob))
            conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)

    def forward(self, xin, labels=None, imgs=None):
        # Only the inference path is kept in this port; training is rejected
        # at the YOLOX model level.
        outputs = []
        for k, (cls_conv, reg_conv, stride_this_level, x) in enumerate(
                zip(self.cls_convs, self.reg_convs, self.strides, xin)):
            x = self.stems[k](x)
            cls_x = x
            reg_x = x

            cls_feat = cls_conv(cls_x)
            cls_output = self.cls_preds[k](cls_feat)

            reg_feat = reg_conv(reg_x)
            reg_output = self.reg_preds[k](reg_feat)
            obj_output = self.obj_preds[k](reg_feat)

            output = torch.cat(
                [reg_output,
                 obj_output.sigmoid(),
                 cls_output.sigmoid()], 1)
            outputs.append(output)

        self.hw = [x.shape[-2:] for x in outputs]
        # [batch, n_anchors_all, 85]
        outputs = torch.cat([x.flatten(start_dim=2) for x in outputs],
                            dim=2).permute(0, 2, 1)
        if self.decode_in_inference:
            return self.decode_outputs(outputs, dtype=xin[0].type())
        else:
            return outputs

    def decode_outputs(self, outputs, dtype):
        grids = []
        strides = []
        for (hsize, wsize), stride in zip(self.hw, self.strides):
            yv, xv = meshgrid([torch.arange(hsize), torch.arange(wsize)])
            grid = torch.stack((xv, yv), 2).view(1, -1, 2)
            grids.append(grid)
            shape = grid.shape[:2]
            strides.append(torch.full((*shape, 1), stride))

        grids = torch.cat(grids, dim=1).type(dtype)
        strides = torch.cat(strides, dim=1).type(dtype)

        outputs[..., :2] = (outputs[..., :2] + grids) * strides
        outputs[..., 2:4] = torch.exp(outputs[..., 2:4]) * strides
        return outputs
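The decoding rule in decode_outputs, as a worked example for a single prediction:

import torch

# Decoding rule from decode_outputs(): centre = (raw_xy + grid) * stride,
# size = exp(raw_wh) * stride, for one prediction on the stride-16 feature map.
raw = torch.tensor([0.25, 0.75, 0.1, 0.2])   # raw x, y, w, h
grid = torch.tensor([3.0, 5.0])              # grid cell (x, y)
stride = 16.0
cx, cy = (raw[:2] + grid) * stride           # 52.0, 92.0
w, h = torch.exp(raw[2:]) * stride           # ~17.7, ~19.5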
@@ -0,0 +1,126 @@
# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX
import torch
import torch.nn as nn

from .darknet import CSPDarknet
from .network_blocks import BaseConv, CSPLayer, DWConv


class YOLOPAFPN(nn.Module):
    """
    YOLOX PAFPN neck. CSPDarknet is the default backbone of this module.
    """

    def __init__(
        self,
        depth=1.0,
        width=1.0,
        in_features=('dark3', 'dark4', 'dark5'),
        in_channels=[256, 512, 1024],
        depthwise=False,
        act='silu',
    ):
        super(YOLOPAFPN, self).__init__()
        self.backbone = CSPDarknet(depth, width, depthwise=depthwise, act=act)
        self.in_features = in_features
        self.in_channels = in_channels
        Conv = DWConv if depthwise else BaseConv

        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.lateral_conv0 = BaseConv(
            int(in_channels[2] * width),
            int(in_channels[1] * width),
            1,
            1,
            act=act)
        self.C3_p4 = CSPLayer(
            int(2 * in_channels[1] * width),
            int(in_channels[1] * width),
            round(3 * depth),
            False,
            depthwise=depthwise,
            act=act,
        )  # cat

        self.reduce_conv1 = BaseConv(
            int(in_channels[1] * width),
            int(in_channels[0] * width),
            1,
            1,
            act=act)
        self.C3_p3 = CSPLayer(
            int(2 * in_channels[0] * width),
            int(in_channels[0] * width),
            round(3 * depth),
            False,
            depthwise=depthwise,
            act=act,
        )

        # bottom-up conv
        self.bu_conv2 = Conv(
            int(in_channels[0] * width),
            int(in_channels[0] * width),
            3,
            2,
            act=act)
        self.C3_n3 = CSPLayer(
            int(2 * in_channels[0] * width),
            int(in_channels[1] * width),
            round(3 * depth),
            False,
            depthwise=depthwise,
            act=act,
        )

        # bottom-up conv
        self.bu_conv1 = Conv(
            int(in_channels[1] * width),
            int(in_channels[1] * width),
            3,
            2,
            act=act)
        self.C3_n4 = CSPLayer(
            int(2 * in_channels[1] * width),
            int(in_channels[2] * width),
            round(3 * depth),
            False,
            depthwise=depthwise,
            act=act,
        )

    def forward(self, input):
        """
        Args:
            input (Tensor): input images.

        Returns:
            Tuple[Tensor]: FPN features.
        """
        # backbone
        out_features = self.backbone(input)
        features = [out_features[f] for f in self.in_features]
        [x2, x1, x0] = features

        fpn_out0 = self.lateral_conv0(x0)  # 1024->512/32
        f_out0 = self.upsample(fpn_out0)  # 512/16
        f_out0 = torch.cat([f_out0, x1], 1)  # 512->1024/16
        f_out0 = self.C3_p4(f_out0)  # 1024->512/16

        fpn_out1 = self.reduce_conv1(f_out0)  # 512->256/16
        f_out1 = self.upsample(fpn_out1)  # 256/8
        f_out1 = torch.cat([f_out1, x2], 1)  # 256->512/8
        pan_out2 = self.C3_p3(f_out1)  # 512->256/8

        p_out1 = self.bu_conv2(pan_out2)  # 256->256/16
        p_out1 = torch.cat([p_out1, fpn_out1], 1)  # 256->512/16
        pan_out1 = self.C3_n3(p_out1)  # 512->512/16

        p_out0 = self.bu_conv1(pan_out1)  # 512->512/32
        p_out0 = torch.cat([p_out0, fpn_out0], 1)  # 512->1024/32
        pan_out0 = self.C3_n4(p_out0)  # 1024->1024/32

        outputs = (pan_out2, pan_out1, pan_out0)
        return outputs
@@ -0,0 +1,33 @@
# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX
import torch.nn as nn

from .yolo_head import YOLOXHead
from .yolo_pafpn import YOLOPAFPN


class YOLOX(nn.Module):
    """
    YOLOX model module, composed of a YOLOPAFPN backbone and a YOLOXHead.
    Only inference is supported in this port; calling it in training mode
    raises NotImplementedError.
    """

    def __init__(self, backbone=None, head=None):
        super(YOLOX, self).__init__()
        if backbone is None:
            backbone = YOLOPAFPN()
        if head is None:
            head = YOLOXHead(80)

        self.backbone = backbone
        self.head = head

    def forward(self, x, targets=None):
        # fpn outputs correspond to the dark3/dark4/dark5 feature levels
        fpn_outs = self.backbone(x)

        if self.training:
            raise NotImplementedError('Training is not supported yet!')
        else:
            outputs = self.head(fpn_outs)

        return outputs
@@ -0,0 +1,5 @@
# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX
from .boxes import *  # noqa

__all__ = ['bboxes_iou', 'meshgrid', 'postprocess', 'xyxy2cxcywh', 'xyxy2xywh']
@@ -0,0 +1,107 @@
# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX
import torch
import torchvision

_TORCH_VER = [int(x) for x in torch.__version__.split('.')[:2]]


def meshgrid(*tensors):
    if _TORCH_VER >= [1, 10]:
        return torch.meshgrid(*tensors, indexing='ij')
    else:
        return torch.meshgrid(*tensors)


def xyxy2xywh(bboxes):
    bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 0]
    bboxes[:, 3] = bboxes[:, 3] - bboxes[:, 1]
    return bboxes


def xyxy2cxcywh(bboxes):
    bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 0]
    bboxes[:, 3] = bboxes[:, 3] - bboxes[:, 1]
    bboxes[:, 0] = bboxes[:, 0] + bboxes[:, 2] * 0.5
    bboxes[:, 1] = bboxes[:, 1] + bboxes[:, 3] * 0.5
    return bboxes


def postprocess(prediction,
                num_classes,
                conf_thre=0.7,
                nms_thre=0.45,
                class_agnostic=False):
    box_corner = prediction.new(prediction.shape)
    box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
    box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
    box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
    box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
    prediction[:, :, :4] = box_corner[:, :, :4]

    output = [None for _ in range(len(prediction))]
    for i, image_pred in enumerate(prediction):

        # If none are remaining => process next image
        if not image_pred.size(0):
            continue
        # Get score and class with highest confidence
        class_conf, class_pred = torch.max(
            image_pred[:, 5:5 + num_classes], 1, keepdim=True)

        conf_mask = image_pred[:, 4] * class_conf.squeeze()
        conf_mask = (conf_mask >= conf_thre).squeeze()
        # Detections ordered as (x1, y1, x2, y2, obj_conf, class_conf, class_pred)
        detections = torch.cat(
            (image_pred[:, :5], class_conf, class_pred.float()), 1)
        detections = detections[conf_mask]
        if not detections.size(0):
            continue

        if class_agnostic:
            nms_out_index = torchvision.ops.nms(
                detections[:, :4],
                detections[:, 4] * detections[:, 5],
                nms_thre,
            )
        else:
            nms_out_index = torchvision.ops.batched_nms(
                detections[:, :4],
                detections[:, 4] * detections[:, 5],
                detections[:, 6],
                nms_thre,
            )

        detections = detections[nms_out_index]
        if output[i] is None:
            output[i] = detections
        else:
            output[i] = torch.cat((output[i], detections))

    return output


def bboxes_iou(bboxes_a, bboxes_b, xyxy=True):
    if bboxes_a.shape[1] != 4 or bboxes_b.shape[1] != 4:
        raise IndexError

    if xyxy:
        tl = torch.max(bboxes_a[:, None, :2], bboxes_b[:, :2])
        br = torch.min(bboxes_a[:, None, 2:], bboxes_b[:, 2:])
        area_a = torch.prod(bboxes_a[:, 2:] - bboxes_a[:, :2], 1)
        area_b = torch.prod(bboxes_b[:, 2:] - bboxes_b[:, :2], 1)
    else:
        tl = torch.max(
            (bboxes_a[:, None, :2] - bboxes_a[:, None, 2:] / 2),
            (bboxes_b[:, :2] - bboxes_b[:, 2:] / 2),
        )
        br = torch.min(
            (bboxes_a[:, None, :2] + bboxes_a[:, None, 2:] / 2),
            (bboxes_b[:, :2] + bboxes_b[:, 2:] / 2),
        )
        area_a = torch.prod(bboxes_a[:, 2:], 1)
        area_b = torch.prod(bboxes_b[:, 2:], 1)

    en = (tl < br).type(tl.type()).prod(dim=2)
    area_i = torch.prod(br - tl, 2) * en  # * ((tl < br).all())
    return area_i / (area_a[:, None] + area_b - area_i)
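A small example of the cxcywh-to-corner conversion done at the top of postprocess:

import torch

# cxcywh -> xyxy corner conversion, as done at the top of postprocess().
pred = torch.tensor([[100.0, 80.0, 40.0, 20.0]])   # cx, cy, w, h
x1y1 = pred[:, :2] - pred[:, 2:] / 2               # tensor([[ 80., 70.]])
x2y2 = pred[:, :2] + pred[:, 2:] / 2               # tensor([[120., 90.]])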
@@ -32,6 +32,7 @@ if TYPE_CHECKING:
     from .image_to_image_generate_pipeline import Image2ImageGenerationPipeline
     from .image_to_image_translation_pipeline import Image2ImageTranslationPipeline
     from .product_retrieval_embedding_pipeline import ProductRetrievalEmbeddingPipeline
+    from .realtime_object_detection_pipeline import RealtimeObjectDetectionPipeline
     from .live_category_pipeline import LiveCategoryPipeline
     from .ocr_detection_pipeline import OCRDetectionPipeline
     from .ocr_recognition_pipeline import OCRRecognitionPipeline
@@ -75,6 +76,8 @@ else:
         ['Image2ImageTranslationPipeline'],
         'product_retrieval_embedding_pipeline':
         ['ProductRetrievalEmbeddingPipeline'],
+        'realtime_object_detection_pipeline':
+        ['RealtimeObjectDetectionPipeline'],
         'live_category_pipeline': ['LiveCategoryPipeline'],
         'image_to_image_generation_pipeline':
         ['Image2ImageGenerationPipeline'],
@@ -0,0 +1,50 @@
from typing import Any, Dict

import numpy as np

from modelscope.metainfo import Pipelines
from modelscope.models.cv.realtime_object_detection import RealtimeDetector
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.image_object_detection,
    module_name=Pipelines.realtime_object_detection)
class RealtimeObjectDetectionPipeline(Pipeline):

    def __init__(self, model: str, **kwargs):
        super().__init__(model=model, **kwargs)
        self.model = RealtimeDetector(model)

    def preprocess(self, input: Input) -> Dict[str, Any]:
        output = self.model.preprocess(input)
        return {'pre_output': output}

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        pre_output = input['pre_output']
        forward_output = self.model(pre_output)
        return {'forward_output': forward_output}

    def postprocess(self, input: Dict[str, Any],
                    **kwargs) -> Dict[str, np.ndarray]:
        forward_output = input['forward_output']
        bboxes, scores, labels = forward_output

        return {
            OutputKeys.BOXES: bboxes,
            OutputKeys.SCORES: scores,
            OutputKeys.LABELS: labels,
        }
@@ -70,6 +70,13 @@ def draw_box(image, box):
                   (int(box[1][0]), int(box[1][1])), (0, 0, 255), 2)


+def realtime_object_detection_bbox_vis(image, bboxes):
+    for bbox in bboxes:
+        cv2.rectangle(image, (bbox[0], bbox[1]), (bbox[2], bbox[3]),
+                      (255, 0, 0), 2)
+    return image
+
+
 def draw_keypoints(output, original_image):
     poses = np.array(output[OutputKeys.POSES])
     scores = np.array(output[OutputKeys.SCORES])
@@ -0,0 +1,52 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

import cv2
import numpy as np

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.cv.image_utils import realtime_object_detection_bbox_vis
from modelscope.utils.test_utils import test_level


class RealtimeObjectDetectionTest(unittest.TestCase):

    def setUp(self) -> None:
        self.model_id = 'damo/cv_cspnet_image-object-detection_yolox'
        self.model_nano_id = 'damo/cv_cspnet_image-object-detection_yolox_nano_coco'
        self.test_image = 'data/test/images/keypoints_detect/000000438862.jpg'

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_modelhub(self):
        realtime_object_detection = pipeline(
            Tasks.image_object_detection, model=self.model_id)
        image = cv2.imread(self.test_image)
        result = realtime_object_detection(image)
        if result:
            bboxes = result[OutputKeys.BOXES].astype(int)
            image = realtime_object_detection_bbox_vis(image, bboxes)
            cv2.imwrite('rt_obj_out.jpg', image)
        else:
            raise ValueError('process error')

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_nano(self):
        realtime_object_detection = pipeline(
            Tasks.image_object_detection, model=self.model_nano_id)
        image = cv2.imread(self.test_image)
        result = realtime_object_detection(image)
        if result:
            bboxes = result[OutputKeys.BOXES].astype(int)
            image = realtime_object_detection_bbox_vis(image, bboxes)
            cv2.imwrite('rtnano_obj_out.jpg', image)
        else:
            raise ValueError('process error')


if __name__ == '__main__':
    unittest.main()