Add live category model
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9476982
master
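
This review registers a new `live-category` task end to end: pipeline name, task constant, default model mapping, task outputs, a ResNet-50 inference pipeline, and a unit test. A minimal usage sketch, mirroring the unit test added below (the model id `damo/cv_resnet50_live-category` and the test video path are taken from this diff):

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # Build the pipeline from the default model registered in this review.
    category_pipeline = pipeline(
        Tasks.live_category, model='damo/cv_resnet50_live-category')

    # Accepts a local video path; returns top-3 'scores' and '>>'-joined 'labels'.
    result = category_pipeline('data/test/videos/live_category_test_video.mp4')
    print(result)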
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5c09750586d1693b6c521d98907c3290d78635a2fb33c76db0132cd2b8ef90f0
+size 1019267
@@ -67,6 +67,7 @@ class Pipelines(object):
     action_recognition = 'TAdaConv_action-recognition'
     animal_recognation = 'resnet101-animal_recog'
     cmdssl_video_embedding = 'cmdssl-r2p1d_video_embedding'
+    live_category = 'live-category'
     general_image_classification = 'vit-base_image-classification_ImageNet-labels'
     daily_image_classification = 'vit-base_image-classification_Dailylife-labels'
     image_color_enhance = 'csrnet-image-color-enhance'
@@ -76,6 +77,7 @@ class Pipelines(object):
     face_image_generation = 'gan-face-image-generation'
     style_transfer = 'AAMS-style-transfer'
     image_instance_segmentation = 'cascade-mask-rcnn-swin-image-instance-segmentation'
+    live_category = 'live-category'
     video_category = 'video-category'

     # nlp tasks
@@ -85,6 +85,12 @@ TASK_OUTPUTS = {
     # }
     Tasks.action_recognition: [OutputKeys.LABELS],

+    # live category recognition result for single video
+    # {
+    #     "scores": [0.885272, 0.014790631, 0.014558001],
+    #     "labels": ['女装/女士精品>>棉衣/棉服', '女装/女士精品>>牛仔裤', '女装/女士精品>>裤子>>休闲裤'],
+    # }
+    Tasks.live_category: [OutputKeys.SCORES, OutputKeys.LABELS],
     # video category recognition result for single video
     # {
     #     "scores": [0.7716429233551025]
@@ -61,6 +61,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
     Tasks.fill_mask: (Pipelines.fill_mask, 'damo/nlp_veco_fill-mask-large'),
     Tasks.action_recognition: (Pipelines.action_recognition,
                                'damo/cv_TAdaConv_action-recognition'),
+    Tasks.live_category: (Pipelines.live_category,
+                          'damo/cv_resnet50_live-category'),
     Tasks.video_category: (Pipelines.video_category,
                            'damo/cv_resnet50_video-category'),
     Tasks.multi_modal_embedding:
@@ -7,6 +7,7 @@ if TYPE_CHECKING:
     from .action_recognition_pipeline import ActionRecognitionPipeline
     from .animal_recog_pipeline import AnimalRecogPipeline
     from .cmdssl_video_embedding_pipleline import CMDSSLVideoEmbeddingPipeline
+    from .live_category_pipeline import LiveCategoryPipeline
     from .image_classification_pipeline import GeneralImageClassificationPipeline
     from .face_image_generation_pipeline import FaceImageGenerationPipeline
     from .image_cartoon_pipeline import ImageCartoonPipeline
@@ -40,6 +41,7 @@ else:
         'image_instance_segmentation_pipeline':
         ['ImageInstanceSegmentationPipeline'],
         'video_category_pipeline': ['VideoCategoryPipeline'],
+        'live_category_pipeline': ['LiveCategoryPipeline'],
     }

     import sys
@@ -0,0 +1,156 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os.path as osp
+from typing import Any, Dict
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.models as models
+import torchvision.transforms.functional as TF
+from PIL import Image
+
+from modelscope.metainfo import Pipelines
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines.base import Input, Pipeline
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.utils.config import Config
+from modelscope.utils.constant import ModelFile, Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+
+@PIPELINES.register_module(
+    Tasks.live_category, module_name=Pipelines.live_category)
+class LiveCategoryPipeline(Pipeline):
+
+    def __init__(self, model: str, **kwargs):
+        """
+        use `model` to create a live-category pipeline for prediction
+
+        Args:
+            model: model id on modelscope hub.
+        """
+        super().__init__(model=model, **kwargs)
+        model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE)
+        logger.info(f'loading model from {model_path}')
+        # ResNet-50 backbone with an 8613-way classification head
+        self.infer_model = models.resnet50(pretrained=False)
+        self.infer_model.fc = nn.Linear(2048, 8613)
+        self.device = torch.device(
+            'cuda' if torch.cuda.is_available() else 'cpu')
+        self.infer_model = self.infer_model.to(self.device).eval()
+        self.infer_model.load_state_dict(
+            torch.load(model_path, map_location=self.device))
+        logger.info('load model done')
+        config_path = osp.join(self.model, ModelFile.CONFIGURATION)
+        logger.info(f'loading config from {config_path}')
+        self.cfg = Config.from_file(config_path)
+        self.label_mapping = self.cfg.label_mapping
+        logger.info('load config done')
+        # ImageNet-style preprocessing: rescale, center-crop, normalize
+        self.transforms = VCompose([
+            VRescale(size=256),
+            VCenterCrop(size=224),
+            VToTensor(),
+            VNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+        ])
+
+    def preprocess(self, input: Input) -> Dict[str, Any]:
+        if isinstance(input, str):
+            import decord
+            from decord import VideoReader, cpu
+            decord.bridge.set_bridge('native')
+            vr = VideoReader(input, ctx=cpu(0))
+            # sample 4 frames evenly spaced over the whole video
+            indices = np.linspace(0, len(vr) - 1, 4).astype(int)
+            frames = vr.get_batch(indices).asnumpy()
+            video_input_data = self.transforms(
+                [Image.fromarray(f) for f in frames])
+        else:
+            raise TypeError(f'input should be a str,'
+                            f' but got {type(input)}')
+        result = {'video_data': video_input_data}
+        return result
+
+    @torch.no_grad()
+    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
+        # per-frame softmax scores averaged over frames, then top-3 classes
+        logits = self.infer_model(input['video_data'].to(self.device))
+        softmax_out = F.softmax(logits, dim=1).mean(dim=0).cpu()
+        scores, ids = softmax_out.topk(3, 0, True, True)
+        scores = scores.numpy()
+        ids = ids.numpy()
+        labels = []
+        for i in ids:
+            label_info = self.label_mapping[str(i)]
+            # join distinct category levels into a 'level1>>level2>>leaf' path
+            label_keys = ['cate_level1_name', 'cate_level2_name', 'cate_name']
+            label_str = []
+            for label_key in label_keys:
+                if label_info[label_key] not in label_str:
+                    label_str.append(label_info[label_key])
+            labels.append('>>'.join(label_str))
+        return {OutputKeys.SCORES: list(scores), OutputKeys.LABELS: labels}
+
+    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        return inputs
+
+
+class VCompose(object):
+
+    def __init__(self, transforms):
+        self.transforms = transforms
+
+    def __call__(self, item):
+        for t in self.transforms:
+            item = t(item)
+        return item
+
+
+class VRescale(object):
+
+    def __init__(self, size=128):
+        self.size = size
+
+    def __call__(self, vclip):
+        vclip = [
+            u.resize((self.size, self.size), Image.BILINEAR) for u in vclip
+        ]
+        return vclip
+
+
+class VCenterCrop(object):
+
+    def __init__(self, size=112):
+        self.size = size
+
+    def __call__(self, vclip):
+        w, h = vclip[0].size
+        assert min(w, h) >= self.size
+        x1 = (w - self.size) // 2
+        y1 = (h - self.size) // 2
+        vclip = [
+            u.crop((x1, y1, x1 + self.size, y1 + self.size)) for u in vclip
+        ]
+        return vclip
+
+
+class VToTensor(object):
+
+    def __call__(self, vclip):
+        vclip = torch.stack([TF.to_tensor(u) for u in vclip], dim=0)
+        return vclip
+
+
+class VNormalize(object):
+
+    def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
+        self.mean = mean
+        self.std = std
+
+    def __call__(self, vclip):
+        assert vclip.min() > -0.1 and vclip.max() < 1.1, \
+            'vclip values should be in [0, 1]'
+        vclip = vclip.clone()
+        if not isinstance(self.mean, torch.Tensor):
+            self.mean = vclip.new_tensor(self.mean).view(1, -1, 1, 1)
+        if not isinstance(self.std, torch.Tensor):
+            self.std = vclip.new_tensor(self.std).view(1, -1, 1, 1)
+        vclip.sub_(self.mean).div_(self.std)
+        return vclip
@@ -34,6 +34,7 @@ class CVTasks(object):
     face_image_generation = 'face-image-generation'
     image_super_resolution = 'image-super-resolution'
     style_transfer = 'style-transfer'
+    live_category = 'live-category'
     video_category = 'video-category'
     image_classification_imagenet = 'image-classification-imagenet'
     image_classification_dailylife = 'image-classification-dailylife'
@@ -0,0 +1,22 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import unittest
+
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.test_utils import test_level
+
+
+class LiveCategoryTest(unittest.TestCase):
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_modelhub(self):
+        category_pipeline = pipeline(
+            Tasks.live_category, model='damo/cv_resnet50_live-category')
+        result = category_pipeline(
+            'data/test/videos/live_category_test_video.mp4')
+        print(f'live category output: {result}.')
+
+
+if __name__ == '__main__':
+    unittest.main()