Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9477652master
| @@ -76,7 +76,8 @@ class Pipelines(object): | |||||
| person_image_cartoon = 'unet-person-image-cartoon' | person_image_cartoon = 'unet-person-image-cartoon' | ||||
| ocr_detection = 'resnet18-ocr-detection' | ocr_detection = 'resnet18-ocr-detection' | ||||
| action_recognition = 'TAdaConv_action-recognition' | action_recognition = 'TAdaConv_action-recognition' | ||||
| animal_recognation = 'resnet101-animal_recog' | |||||
| animal_recognition = 'resnet101-animal-recognition' | |||||
| general_recognition = 'resnet101-general-recognition' | |||||
| cmdssl_video_embedding = 'cmdssl-r2p1d_video_embedding' | cmdssl_video_embedding = 'cmdssl-r2p1d_video_embedding' | ||||
| body_2d_keypoints = 'hrnetv2w32_body-2d-keypoints_image' | body_2d_keypoints = 'hrnetv2w32_body-2d-keypoints_image' | ||||
| human_detection = 'resnet18-human-detection' | human_detection = 'resnet18-human-detection' | ||||
| @@ -81,8 +81,7 @@ class Bottleneck(nn.Module): | |||||
| norm_layer=norm_layer, | norm_layer=norm_layer, | ||||
| dropblock_prob=dropblock_prob) | dropblock_prob=dropblock_prob) | ||||
| elif rectified_conv: | elif rectified_conv: | ||||
| from rfconv import RFConv2d | |||||
| self.conv2 = RFConv2d( | |||||
| self.conv2 = nn.Conv2d( | |||||
| group_width, | group_width, | ||||
| group_width, | group_width, | ||||
| kernel_size=3, | kernel_size=3, | ||||
| @@ -90,8 +89,7 @@ class Bottleneck(nn.Module): | |||||
| padding=dilation, | padding=dilation, | ||||
| dilation=dilation, | dilation=dilation, | ||||
| groups=cardinality, | groups=cardinality, | ||||
| bias=False, | |||||
| average_mode=rectify_avg) | |||||
| bias=False) | |||||
| self.bn2 = norm_layer(group_width) | self.bn2 = norm_layer(group_width) | ||||
| else: | else: | ||||
| self.conv2 = nn.Conv2d( | self.conv2 = nn.Conv2d( | ||||
| @@ -190,8 +188,7 @@ class ResNet(nn.Module): | |||||
| self.rectified_conv = rectified_conv | self.rectified_conv = rectified_conv | ||||
| self.rectify_avg = rectify_avg | self.rectify_avg = rectify_avg | ||||
| if rectified_conv: | if rectified_conv: | ||||
| from rfconv import RFConv2d | |||||
| conv_layer = RFConv2d | |||||
| conv_layer = nn.Conv2d | |||||
| else: | else: | ||||
| conv_layer = nn.Conv2d | conv_layer = nn.Conv2d | ||||
| conv_kwargs = {'average_mode': rectify_avg} if rectified_conv else {} | conv_kwargs = {'average_mode': rectify_avg} if rectified_conv else {} | ||||
| @@ -39,8 +39,7 @@ class SplAtConv2d(Module): | |||||
| self.channels = channels | self.channels = channels | ||||
| self.dropblock_prob = dropblock_prob | self.dropblock_prob = dropblock_prob | ||||
| if self.rectify: | if self.rectify: | ||||
| from rfconv import RFConv2d | |||||
| self.conv = RFConv2d( | |||||
| self.conv = Conv2d( | |||||
| in_channels, | in_channels, | ||||
| channels * radix, | channels * radix, | ||||
| kernel_size, | kernel_size, | ||||
| @@ -49,7 +48,6 @@ class SplAtConv2d(Module): | |||||
| dilation, | dilation, | ||||
| groups=groups * radix, | groups=groups * radix, | ||||
| bias=bias, | bias=bias, | ||||
| average_mode=rectify_avg, | |||||
| **kwargs) | **kwargs) | ||||
| else: | else: | ||||
| self.conv = Conv2d( | self.conv = Conv2d( | ||||
| @@ -10,8 +10,9 @@ if TYPE_CHECKING: | |||||
| from .cmdssl_video_embedding_pipeline import CMDSSLVideoEmbeddingPipeline | from .cmdssl_video_embedding_pipeline import CMDSSLVideoEmbeddingPipeline | ||||
| from .image_detection_pipeline import ImageDetectionPipeline | from .image_detection_pipeline import ImageDetectionPipeline | ||||
| from .face_detection_pipeline import FaceDetectionPipeline | from .face_detection_pipeline import FaceDetectionPipeline | ||||
| from .face_recognition_pipeline import FaceRecognitionPipeline | |||||
| from .face_image_generation_pipeline import FaceImageGenerationPipeline | from .face_image_generation_pipeline import FaceImageGenerationPipeline | ||||
| from .face_recognition_pipeline import FaceRecognitionPipeline | |||||
| from .general_recognition_pipeline import GeneralRecognitionPipeline | |||||
| from .image_cartoon_pipeline import ImageCartoonPipeline | from .image_cartoon_pipeline import ImageCartoonPipeline | ||||
| from .image_classification_pipeline import GeneralImageClassificationPipeline | from .image_classification_pipeline import GeneralImageClassificationPipeline | ||||
| from .image_color_enhance_pipeline import ImageColorEnhancePipeline | from .image_color_enhance_pipeline import ImageColorEnhancePipeline | ||||
| @@ -23,7 +24,7 @@ if TYPE_CHECKING: | |||||
| from .image_portrait_enhancement_pipeline import ImagePortraitEnhancementPipeline | from .image_portrait_enhancement_pipeline import ImagePortraitEnhancementPipeline | ||||
| from .image_style_transfer_pipeline import ImageStyleTransferPipeline | from .image_style_transfer_pipeline import ImageStyleTransferPipeline | ||||
| from .image_super_resolution_pipeline import ImageSuperResolutionPipeline | from .image_super_resolution_pipeline import ImageSuperResolutionPipeline | ||||
| from .image_to_image_generate_pipeline import Image2ImageGenerationePipeline | |||||
| from .image_to_image_generate_pipeline import Image2ImageGenerationPipeline | |||||
| from .image_to_image_translation_pipeline import Image2ImageTranslationPipeline | from .image_to_image_translation_pipeline import Image2ImageTranslationPipeline | ||||
| from .product_retrieval_embedding_pipeline import ProductRetrievalEmbeddingPipeline | from .product_retrieval_embedding_pipeline import ProductRetrievalEmbeddingPipeline | ||||
| from .live_category_pipeline import LiveCategoryPipeline | from .live_category_pipeline import LiveCategoryPipeline | ||||
| @@ -41,6 +42,7 @@ else: | |||||
| 'face_detection_pipeline': ['FaceDetectionPipeline'], | 'face_detection_pipeline': ['FaceDetectionPipeline'], | ||||
| 'face_image_generation_pipeline': ['FaceImageGenerationPipeline'], | 'face_image_generation_pipeline': ['FaceImageGenerationPipeline'], | ||||
| 'face_recognition_pipeline': ['FaceRecognitionPipeline'], | 'face_recognition_pipeline': ['FaceRecognitionPipeline'], | ||||
| 'general_recognition_pipeline': ['GeneralRecognitionPipeline'], | |||||
| 'image_classification_pipeline': | 'image_classification_pipeline': | ||||
| ['GeneralImageClassificationPipeline', 'ImageClassificationPipeline'], | ['GeneralImageClassificationPipeline', 'ImageClassificationPipeline'], | ||||
| 'image_cartoon_pipeline': ['ImageCartoonPipeline'], | 'image_cartoon_pipeline': ['ImageCartoonPipeline'], | ||||
| @@ -60,7 +62,7 @@ else: | |||||
| ['ProductRetrievalEmbeddingPipeline'], | ['ProductRetrievalEmbeddingPipeline'], | ||||
| 'live_category_pipeline': ['LiveCategoryPipeline'], | 'live_category_pipeline': ['LiveCategoryPipeline'], | ||||
| 'image_to_image_generation_pipeline': | 'image_to_image_generation_pipeline': | ||||
| ['Image2ImageGenerationePipeline'], | |||||
| ['Image2ImageGenerationPipeline'], | |||||
| 'ocr_detection_pipeline': ['OCRDetectionPipeline'], | 'ocr_detection_pipeline': ['OCRDetectionPipeline'], | ||||
| 'skin_retouching_pipeline': ['SkinRetouchingPipeline'], | 'skin_retouching_pipeline': ['SkinRetouchingPipeline'], | ||||
| 'video_category_pipeline': ['VideoCategoryPipeline'], | 'video_category_pipeline': ['VideoCategoryPipeline'], | ||||
| @@ -21,7 +21,7 @@ logger = get_logger() | |||||
| @PIPELINES.register_module( | @PIPELINES.register_module( | ||||
| Tasks.image_classification, module_name=Pipelines.animal_recognation) | |||||
| Tasks.animal_recognition, module_name=Pipelines.animal_recognition) | |||||
| class AnimalRecognitionPipeline(Pipeline): | class AnimalRecognitionPipeline(Pipeline): | ||||
| def __init__(self, model: str, **kwargs): | def __init__(self, model: str, **kwargs): | ||||
| @@ -0,0 +1,121 @@ | |||||
| import os.path as osp | |||||
| from typing import Any, Dict | |||||
| import cv2 | |||||
| import numpy as np | |||||
| import torch | |||||
| from PIL import Image | |||||
| from torchvision import transforms | |||||
| from modelscope.hub.snapshot_download import snapshot_download | |||||
| from modelscope.metainfo import Pipelines | |||||
| from modelscope.models.cv.animal_recognition import resnet | |||||
| from modelscope.outputs import OutputKeys | |||||
| from modelscope.pipelines.base import Input, Pipeline | |||||
| from modelscope.pipelines.builder import PIPELINES | |||||
| from modelscope.preprocessors import LoadImage, load_image | |||||
| from modelscope.utils.constant import ModelFile, Tasks | |||||
| from modelscope.utils.logger import get_logger | |||||
| logger = get_logger() | |||||
| @PIPELINES.register_module( | |||||
| Tasks.general_recognition, module_name=Pipelines.general_recognition) | |||||
| class GeneralRecognitionPipeline(Pipeline): | |||||
| def __init__(self, model: str, device: str): | |||||
| """ | |||||
| use `model` to create a general recognition pipeline for prediction | |||||
| Args: | |||||
| model: model id on modelscope hub. | |||||
| """ | |||||
| super().__init__(model=model) | |||||
| import torch | |||||
| def resnest101(**kwargs): | |||||
| model = resnet.ResNet( | |||||
| resnet.Bottleneck, [3, 4, 23, 3], | |||||
| radix=2, | |||||
| groups=1, | |||||
| bottleneck_width=64, | |||||
| deep_stem=True, | |||||
| stem_width=64, | |||||
| avg_down=True, | |||||
| avd=True, | |||||
| avd_first=False, | |||||
| **kwargs) | |||||
| return model | |||||
| def filter_param(src_params, own_state): | |||||
| copied_keys = [] | |||||
| for name, param in src_params.items(): | |||||
| if 'module.' == name[0:7]: | |||||
| name = name[7:] | |||||
| if '.module.' not in list(own_state.keys())[0]: | |||||
| name = name.replace('.module.', '.') | |||||
| if (name in own_state) and (own_state[name].shape | |||||
| == param.shape): | |||||
| own_state[name].copy_(param) | |||||
| copied_keys.append(name) | |||||
| def load_pretrained(model, src_params): | |||||
| if 'state_dict' in src_params: | |||||
| src_params = src_params['state_dict'] | |||||
| own_state = model.state_dict() | |||||
| filter_param(src_params, own_state) | |||||
| model.load_state_dict(own_state) | |||||
| self.model = resnest101(num_classes=54092) | |||||
| local_model_dir = model | |||||
| device = 'cpu' | |||||
| if osp.exists(model): | |||||
| local_model_dir = model | |||||
| else: | |||||
| local_model_dir = snapshot_download(model) | |||||
| self.local_path = local_model_dir | |||||
| src_params = torch.load( | |||||
| osp.join(local_model_dir, ModelFile.TORCH_MODEL_FILE), device) | |||||
| load_pretrained(self.model, src_params) | |||||
| logger.info('load model done') | |||||
| def preprocess(self, input: Input) -> Dict[str, Any]: | |||||
| img = LoadImage.convert_to_img(input) | |||||
| normalize = transforms.Normalize( | |||||
| mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) | |||||
| transform = transforms.Compose([ | |||||
| transforms.Resize(256), | |||||
| transforms.CenterCrop(224), | |||||
| transforms.ToTensor(), normalize | |||||
| ]) | |||||
| img = transform(img) | |||||
| result = {'img': img} | |||||
| return result | |||||
| def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||||
| def set_phase(model, is_train): | |||||
| if is_train: | |||||
| model.train() | |||||
| else: | |||||
| model.eval() | |||||
| is_train = False | |||||
| set_phase(self.model, is_train) | |||||
| img = input['img'] | |||||
| input_img = torch.unsqueeze(img, 0) | |||||
| outputs = self.model(input_img) | |||||
| return {'outputs': outputs} | |||||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||||
| label_mapping_path = osp.join(self.local_path, 'meta_info.txt') | |||||
| with open(label_mapping_path, 'r') as f: | |||||
| label_mapping = f.readlines() | |||||
| score = torch.max(inputs['outputs']) | |||||
| inputs = { | |||||
| OutputKeys.SCORES: | |||||
| score.item(), | |||||
| OutputKeys.LABELS: | |||||
| label_mapping[inputs['outputs'].argmax()].split('\t')[1] | |||||
| } | |||||
| return inputs | |||||
| @@ -32,7 +32,7 @@ logger = get_logger() | |||||
| @PIPELINES.register_module( | @PIPELINES.register_module( | ||||
| Tasks.image_to_image_generation, | Tasks.image_to_image_generation, | ||||
| module_name=Pipelines.image_to_image_generation) | module_name=Pipelines.image_to_image_generation) | ||||
| class Image2ImageGenerationePipeline(Pipeline): | |||||
| class Image2ImageGenerationPipeline(Pipeline): | |||||
| def __init__(self, model: str, **kwargs): | def __init__(self, model: str, **kwargs): | ||||
| """ | """ | ||||
| @@ -17,12 +17,14 @@ class CVTasks(object): | |||||
| ocr_recognition = 'ocr-recognition' | ocr_recognition = 'ocr-recognition' | ||||
| # human face body related | # human face body related | ||||
| animal_recognition = 'animal-recognition' | |||||
| face_detection = 'face-detection' | face_detection = 'face-detection' | ||||
| face_recognition = 'face-recognition' | face_recognition = 'face-recognition' | ||||
| human_detection = 'human-detection' | human_detection = 'human-detection' | ||||
| human_object_interaction = 'human-object-interaction' | human_object_interaction = 'human-object-interaction' | ||||
| face_image_generation = 'face-image-generation' | face_image_generation = 'face-image-generation' | ||||
| body_2d_keypoints = 'body-2d-keypoints' | body_2d_keypoints = 'body-2d-keypoints' | ||||
| general_recognition = 'general-recognition' | |||||
| image_classification = 'image-classification' | image_classification = 'image-classification' | ||||
| image_multilabel_classification = 'image-multilabel-classification' | image_multilabel_classification = 'image-multilabel-classification' | ||||
| @@ -5,14 +5,14 @@ from modelscope.utils.constant import Tasks | |||||
| from modelscope.utils.test_utils import test_level | from modelscope.utils.test_utils import test_level | ||||
| class MultiModalFeatureTest(unittest.TestCase): | |||||
| class AnimalRecognitionTest(unittest.TestCase): | |||||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | ||||
| def test_run(self): | def test_run(self): | ||||
| animal_recog = pipeline( | |||||
| Tasks.image_classification, | |||||
| animal_recognition = pipeline( | |||||
| Tasks.animal_recognition, | |||||
| model='damo/cv_resnest101_animal_recognition') | model='damo/cv_resnest101_animal_recognition') | ||||
| result = animal_recog('data/test/images/dogs.jpg') | |||||
| result = animal_recognition('data/test/images/dogs.jpg') | |||||
| print(result) | print(result) | ||||
| @@ -0,0 +1,20 @@ | |||||
| import unittest | |||||
| from modelscope.pipelines import pipeline | |||||
| from modelscope.utils.constant import Tasks | |||||
| from modelscope.utils.test_utils import test_level | |||||
| class GeneralRecognitionTest(unittest.TestCase): | |||||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||||
| def test_run(self): | |||||
| general_recognition = pipeline( | |||||
| Tasks.general_recognition, | |||||
| model='damo/cv_resnest101_general_recognition') | |||||
| result = general_recognition('data/test/images/dogs.jpg') | |||||
| print(result) | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||