* Add image-classification-imagenet and image-classification-dailylife pipelines * Add models.cv.mmcls_model.ClassificationModel as a wrapper class for mmclsmaster
| @@ -2,3 +2,4 @@ | |||
| *.jpg filter=lfs diff=lfs merge=lfs -text | |||
| *.mp4 filter=lfs diff=lfs merge=lfs -text | |||
| *.wav filter=lfs diff=lfs merge=lfs -text | |||
| *.JPEG filter=lfs diff=lfs merge=lfs -text | |||
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:19fb781a44aec9349a8b73850e53b7eb9b0623d54ebd0cd8577c13bf463b5004 | |||
| size 74237 | |||
| @@ -10,6 +10,7 @@ class Models(object): | |||
| Model name should only contain model info but not task info. | |||
| """ | |||
| # vision models | |||
| classification_model = 'ClassificationModel' | |||
| nafnet = 'nafnet' | |||
| csrnet = 'csrnet' | |||
| cascade_mask_rcnn_swin = 'cascade_mask_rcnn_swin' | |||
| @@ -66,6 +67,8 @@ class Pipelines(object): | |||
| action_recognition = 'TAdaConv_action-recognition' | |||
| animal_recognation = 'resnet101-animal_recog' | |||
| cmdssl_video_embedding = 'cmdssl-r2p1d_video_embedding' | |||
| general_image_classification = 'vit-base_image-classification_ImageNet-labels' | |||
| daily_image_classification = 'vit-base_image-classification_Dailylife-labels' | |||
| image_color_enhance = 'csrnet-image-color-enhance' | |||
| virtual_tryon = 'virtual_tryon' | |||
| image_colorization = 'unet-image-colorization' | |||
| @@ -1,5 +1,5 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from . import (action_recognition, animal_recognition, cartoon, | |||
| cmdssl_video_embedding, face_generation, image_color_enhance, | |||
| image_colorization, image_denoise, image_instance_segmentation, | |||
| super_resolution, virual_tryon) | |||
| cmdssl_video_embedding, face_generation, image_classification, | |||
| image_color_enhance, image_colorization, image_denoise, | |||
| image_instance_segmentation, super_resolution, virual_tryon) | |||
| @@ -0,0 +1,22 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from typing import TYPE_CHECKING | |||
| from modelscope.utils.import_utils import LazyImportModule | |||
if not TYPE_CHECKING:
    import sys

    # Submodule name -> public names it exports; handed to LazyImportModule
    # below instead of importing the submodule here.
    _import_structure = {
        'mmcls_model': ['ClassificationModel'],
    }

    # Replace this package's entry in sys.modules with the LazyImportModule
    # instance built from the structure above.
    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
else:
    # Under static type checking the real symbol is imported directly.
    from .mmcls_model import ClassificationModel
| @@ -0,0 +1,46 @@ | |||
| import os | |||
| from modelscope.metainfo import Models | |||
| from modelscope.models.base.base_torch_model import TorchModel | |||
| from modelscope.models.builder import MODELS | |||
| from modelscope.utils.constant import Tasks | |||
@MODELS.register_module(
    Tasks.image_classification_imagenet,
    module_name=Models.classification_model)
@MODELS.register_module(
    Tasks.image_classification_dailylife,
    module_name=Models.classification_model)
class ClassificationModel(TorchModel):
    """Wrapper around an mmcls classifier built from a model directory.

    The class is registered under both the ImageNet-label and the
    Dailylife-label image classification tasks.
    """

    def __init__(self, model_dir: str):
        """Build the classifier described by ``<model_dir>/config.py``.

        Args:
            model_dir: directory containing ``config.py`` and, optionally,
                ``checkpoints.pth``.
        """
        import mmcv
        from mmcls.models import build_classifier

        super().__init__(model_dir)
        self.ms_model_dir = model_dir

        cfg = mmcv.Config.fromfile(os.path.join(model_dir, 'config.py'))
        # Clear the config's `pretrained` entry; weights are loaded from the
        # local checkpoint file by load_pretrained_checkpoint() instead.
        cfg.model.pretrained = None
        self.cfg = cfg
        self.cls_model = build_classifier(cfg.model)

        self.load_pretrained_checkpoint()

    def forward(self, Inputs):
        """Call the wrapped classifier with ``Inputs`` expanded as keywords."""
        return self.cls_model(**Inputs)

    def load_pretrained_checkpoint(self):
        """Load ``checkpoints.pth`` from the model directory if it exists.

        When the file is present, the class names stored under the
        checkpoint's ``meta`` entry (if any) are attached to the wrapped
        classifier and then mirrored on this wrapper as ``CLASSES``.
        Nothing happens when the file is absent.
        """
        import mmcv

        weights_file = os.path.join(self.ms_model_dir, 'checkpoints.pth')
        if not os.path.exists(weights_file):
            return

        loaded = mmcv.runner.load_checkpoint(
            self.cls_model, weights_file, map_location='cpu')
        meta = loaded.get('meta', {})
        if 'CLASSES' in meta:
            self.cls_model.CLASSES = meta['CLASSES']
        self.CLASSES = self.cls_model.CLASSES
| @@ -94,6 +94,12 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||
| 'damo/cv_gan_face-image-generation'), | |||
| Tasks.image_super_resolution: (Pipelines.image_super_resolution, | |||
| 'damo/cv_rrdb_image-super-resolution'), | |||
| Tasks.image_classification_imagenet: | |||
| (Pipelines.general_image_classification, | |||
| 'damo/cv_vit-base_image-classification_ImageNet-labels'), | |||
| Tasks.image_classification_dailylife: | |||
| (Pipelines.daily_image_classification, | |||
| 'damo/cv_vit-base_image-classification_Dailylife-labels'), | |||
| } | |||
| @@ -7,6 +7,7 @@ if TYPE_CHECKING: | |||
| from .action_recognition_pipeline import ActionRecognitionPipeline | |||
| from .animal_recog_pipeline import AnimalRecogPipeline | |||
| from .cmdssl_video_embedding_pipleline import CMDSSLVideoEmbeddingPipeline | |||
| from .image_classification_pipeline import GeneralImageClassificationPipeline | |||
| from .face_image_generation_pipeline import FaceImageGenerationPipeline | |||
| from .image_cartoon_pipeline import ImageCartoonPipeline | |||
| from .image_denoise_pipeline import ImageDenoisePipeline | |||
| @@ -23,6 +24,8 @@ else: | |||
| 'action_recognition_pipeline': ['ActionRecognitionPipeline'], | |||
| 'animal_recog_pipeline': ['AnimalRecogPipeline'], | |||
| 'cmdssl_video_embedding_pipleline': ['CMDSSLVideoEmbeddingPipeline'], | |||
| 'image_classification_pipeline': | |||
| ['GeneralImageClassificationPipeline'], | |||
| 'image_color_enhance_pipeline': ['ImageColorEnhancePipeline'], | |||
| 'virtual_tryon_pipeline': ['VirtualTryonPipeline'], | |||
| 'image_colorization_pipeline': ['ImageColorizationPipeline'], | |||
| @@ -0,0 +1,87 @@ | |||
| from typing import Any, Dict | |||
| import cv2 | |||
| import numpy as np | |||
| import PIL | |||
| import torch | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines.base import Input | |||
| from modelscope.preprocessors import load_image | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.logger import get_logger | |||
| from ..base import Pipeline | |||
| from ..builder import PIPELINES | |||
| logger = get_logger() | |||
@PIPELINES.register_module(
    Tasks.image_classification_imagenet,
    module_name=Pipelines.general_image_classification)
@PIPELINES.register_module(
    Tasks.image_classification_dailylife,
    module_name=Pipelines.daily_image_classification)
class GeneralImageClassificationPipeline(Pipeline):
    """Single-image classification pipeline backed by an mmcls model.

    Registered for both the ImageNet-label and Dailylife-label tasks.
    """

    def __init__(self, model: str, **kwargs):
        """
        use `model` to create an image classification pipeline for prediction
        Args:
            model: model id on modelscope hub.
        """
        super().__init__(model=model, **kwargs)
        logger.info('load model done')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        """Turn a path, PIL image or ndarray into a collated mmcls batch.

        Args:
            input: an image location accepted by ``load_image``, a
                ``PIL.Image.Image``, or a ``np.ndarray``. A 2-D array is
                expanded to three channels; the channel axis of an array is
                then reversed (presumably BGR -> RGB, as produced by OpenCV
                readers -- confirm against callers).

        Returns:
            The batch produced by the model's test pipeline and ``collate``,
            scattered to the model's device when the model is on CUDA.

        Raises:
            TypeError: if ``input`` is none of the supported types.
        """
        from mmcls.datasets.pipelines import Compose
        from mmcv.parallel import collate, scatter

        if isinstance(input, str):
            img = np.array(load_image(input))
        elif isinstance(input, PIL.Image.Image):
            img = np.array(input.convert('RGB'))
        elif isinstance(input, np.ndarray):
            img = input
            if img.ndim == 2:
                # Grayscale: expand to three channels first so the channel
                # reversal below has a third axis to operate on.
                img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
            img = img[:, :, ::-1]  # in rgb order
        else:
            raise TypeError(f'input should be either str, PIL.Image,'
                            f' np.array, but got {type(input)}')

        # build the data pipeline
        pipeline_cfg = self.model.cfg.data.test.pipeline
        if pipeline_cfg[0]['type'] == 'LoadImageFromFile':
            # The image is already decoded, so skip the file-loading step.
            # Slice instead of pop() so the model's config is not mutated.
            pipeline_cfg = pipeline_cfg[1:]
        test_pipeline = Compose(pipeline_cfg)

        data = test_pipeline(dict(img=img))
        data = collate([data], samples_per_gpu=1)

        device = next(self.model.parameters()).device
        if device.type == 'cuda':
            # scatter to specified GPU
            data = scatter(data, [device])[0]
        return data

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Run the model without gradients and return its raw scores.

        Note that ``input`` is modified in place: ``return_loss`` is set to
        ``False`` before the model is called.
        """
        with torch.no_grad():
            input['return_loss'] = False
            scores = self.model(input)

        return {'scores': scores}

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Pick the top-scoring class of the first sample.

        Args:
            inputs: mapping with a ``'scores'`` entry; presumably of shape
                (batch, num_classes), since it is reduced over axis 1.

        Returns:
            A dict with ``OutputKeys.SCORES`` (one float) and
            ``OutputKeys.LABELS`` (one class name taken from
            ``self.model.CLASSES``), each wrapped in a list.
        """
        scores = inputs['scores']
        pred_score = np.max(scores, axis=1)[0]
        pred_label = np.argmax(scores, axis=1)[0]

        outputs = {
            OutputKeys.SCORES: [float(pred_score)],
            OutputKeys.LABELS: [self.model.CLASSES[pred_label]]
        }
        return outputs
| @@ -34,6 +34,8 @@ class CVTasks(object): | |||
| face_image_generation = 'face-image-generation' | |||
| image_super_resolution = 'image-super-resolution' | |||
| style_transfer = 'style-transfer' | |||
| image_classification_imagenet = 'image-classification-imagenet' | |||
| image_classification_dailylife = 'image-classification-dailylife' | |||
| class NLPTasks(object): | |||
| @@ -0,0 +1,42 @@ | |||
| import unittest | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.test_utils import test_level | |||
class GeneralImageClassificationTest(unittest.TestCase):
    """Smoke tests for the ImageNet and Dailylife classification pipelines."""

    def _classify_sample(self, task, **pipeline_kwargs):
        """Build a pipeline for ``task`` and print its result on the sample image."""
        general_image_classification = pipeline(task, **pipeline_kwargs)
        result = general_image_classification('data/test/images/bird.JPEG')
        print(result)

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_ImageNet(self):
        self._classify_sample(
            Tasks.image_classification_imagenet,
            model='damo/cv_vit-base_image-classification_ImageNet-labels')

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_Dailylife(self):
        self._classify_sample(
            Tasks.image_classification_dailylife,
            model='damo/cv_vit-base_image-classification_Dailylife-labels')

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_ImageNet_default_task(self):
        # No explicit model: relies on the task's default model.
        self._classify_sample(Tasks.image_classification_imagenet)

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_Dailylife_default_task(self):
        # No explicit model: relies on the task's default model.
        self._classify_sample(Tasks.image_classification_dailylife)
| if __name__ == '__main__': | |||
| unittest.main() | |||