Add an OFA task for text recognition (OCR) in everyday scenes. Main changes:
1. Add the new pipeline and task name definitions;
2. Add the core class and method logic for the pipeline, task, model, and preprocessor;
3. Other small details fixed in the same change.
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10471089
Target branch: master
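
For reference, a minimal usage sketch of the new pipeline (it mirrors the unit test added at the end of this change and assumes the model damo/ofa_ocr-recognition_scene_base_zh is available on the ModelScope hub):

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

ocr_recognize = pipeline(
    Tasks.ofa_ocr_recognition,
    model='damo/ofa_ocr-recognition_scene_base_zh')
result = ocr_recognize('data/test/images/image_ocr_recognition.jpg')
print(result[OutputKeys.TEXT])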
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:772b19f76c98044e39330853928624f10e085106a4292b4dd19f865531080747
+size 959
@@ -263,6 +263,7 @@ class Pipelines(object):
     text_to_image_synthesis = 'text-to-image-synthesis'
     video_multi_modal_embedding = 'video-multi-modal-embedding'
     image_text_retrieval = 'image-text-retrieval'
+    ofa_ocr_recognition = 'ofa-ocr-recognition'


 class Trainers(object):
@@ -3,6 +3,7 @@ from modelscope.outputs import OutputKeys
 from modelscope.utils.constant import Tasks

 OFA_TASK_KEY_MAPPING = {
+    Tasks.ofa_ocr_recognition: OutputKeys.TEXT,
     Tasks.image_captioning: OutputKeys.CAPTION,
     Tasks.summarization: OutputKeys.TEXT,
     Tasks.visual_question_answering: OutputKeys.TEXT,
@@ -27,6 +27,7 @@ __all__ = ['OfaForAllTasks']

 @MODELS.register_module(Tasks.image_captioning, module_name=Models.ofa)
+@MODELS.register_module(Tasks.ofa_ocr_recognition, module_name=Models.ofa)
 @MODELS.register_module(Tasks.visual_grounding, module_name=Models.ofa)
 @MODELS.register_module(
     Tasks.visual_question_answering, module_name=Models.ofa)
@@ -96,6 +97,7 @@ class OfaForAllTasks(TorchModel):
             'traverse': self._traverse_inference,
         }
         self.task_inference_mapping = {
+            Tasks.ofa_ocr_recognition: self._text_gen_inference,
             Tasks.image_captioning: self._text_gen_inference,
             Tasks.summarization: self._text_gen_inference,
             Tasks.visual_grounding: self._visual_grounding_inference,
@@ -661,6 +661,7 @@ TASK_OUTPUTS = {
     #   "caption": "this is an image caption text."
     # }
     Tasks.image_captioning: [OutputKeys.CAPTION],
+    Tasks.ofa_ocr_recognition: [OutputKeys.TEXT],

     # visual grounding result for single sample
     # {
@@ -0,0 +1,52 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import Any, Dict, Optional, Union
+
+import torch
+
+from modelscope.metainfo import Pipelines
+from modelscope.models.multi_modal import OfaForAllTasks
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines.base import Model, Pipeline
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.preprocessors import OfaPreprocessor, Preprocessor
+from modelscope.utils.constant import Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+
+@PIPELINES.register_module(
+    Tasks.ofa_ocr_recognition, module_name=Pipelines.ofa_ocr_recognition)
+class OcrRecognitionPipeline(Pipeline):
+
+    def __init__(self,
+                 model: Union[Model, str],
+                 preprocessor: Optional[Preprocessor] = None,
+                 **kwargs):
+        """
+        Use `model` and `preprocessor` to create an OCR recognition pipeline for prediction.
+
+        Args:
+            model: a model id on the ModelScope hub, or a loaded Model instance.
+        """
+        super().__init__(model=model)
+        assert isinstance(model, str) or isinstance(model, Model), \
+            'model must be a single str or OfaForAllTasks'
+        if isinstance(model, str):
+            pipe_model = Model.from_pretrained(model)
+        elif isinstance(model, Model):
+            pipe_model = model
+        else:
+            raise NotImplementedError
+        pipe_model.model.eval()
+        if preprocessor is None:
+            if isinstance(pipe_model, OfaForAllTasks):
+                preprocessor = OfaPreprocessor(pipe_model.model_dir)
+        super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs)
+
+    def forward(self, inputs: Dict[str, Any],
+                **forward_params) -> Dict[str, Any]:
+        with torch.no_grad():
+            return super().forward(inputs, **forward_params)
+
+    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        return inputs
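
As a sketch of the other construction path the constructor above supports (an assumption based on the Union[Model, str] signature and on the Model.from_pretrained pattern used elsewhere in the tests, not an official example), the pipeline can also be built from an already-loaded Model instance:

from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Load the model first, then hand the instance to the pipeline factory.
model = Model.from_pretrained('damo/ofa_ocr-recognition_scene_base_zh')
ocr_recognize = pipeline(Tasks.ofa_ocr_recognition, model=model)
result = ocr_recognize('data/test/images/image_ocr_recognition.jpg')
print(result[OutputKeys.TEXT])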
@@ -34,6 +34,7 @@ class OfaPreprocessor(Preprocessor):
         """
         super().__init__(*args, **kwargs)
         preprocess_mapping = {
+            Tasks.ofa_ocr_recognition: OfaOcrRecognitionPreprocessor,
             Tasks.image_captioning: OfaImageCaptioningPreprocessor,
             Tasks.visual_grounding: OfaVisualGroundingPreprocessor,
             Tasks.visual_question_answering:
@@ -45,6 +46,7 @@ class OfaPreprocessor(Preprocessor):
             Tasks.text_to_image_synthesis: OfaTextToImageSynthesisPreprocessor
         }
         input_key_mapping = {
+            Tasks.ofa_ocr_recognition: ['image'],
             Tasks.image_captioning: ['image'],
             Tasks.image_classification: ['image'],
             Tasks.summarization: ['text'],
@@ -1,6 +1,7 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 from .image_captioning import OfaImageCaptioningPreprocessor
 from .image_classification import OfaImageClassificationPreprocessor
+from .ocr_recognition import OfaOcrRecognitionPreprocessor
 from .summarization import OfaSummarizationPreprocessor
 from .text_classification import OfaTextClassificationPreprocessor
 from .text_to_image_synthesis import OfaTextToImageSynthesisPreprocessor
@@ -0,0 +1,99 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import random
+import unicodedata
+from typing import Any, Dict, Union
+
+import torch
+from PIL import Image
+from torchvision import transforms
+from torchvision.transforms import InterpolationMode
+from torchvision.transforms import functional as F
+
+from modelscope.preprocessors.image import load_image
+from .base import OfaBasePreprocessor
+
+IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
+
+
+def ocr_resize(img, patch_image_size, is_document=False):
+    img = img.convert('RGB')
+    width, height = img.size
+
+    if is_document:
+        new_height, new_width = 64, 1920
+    else:
+        if width >= height:
+            new_width = max(64, patch_image_size)
+            new_height = max(64, int(patch_image_size * (height / width)))
+            top = (patch_image_size - new_height) // 2
+            bottom = patch_image_size - new_height - top
+            left, right = 0, 0
+        else:
+            new_height = max(64, patch_image_size)
+            new_width = max(64, int(patch_image_size * (width / height)))
+            left = (patch_image_size - new_width) // 2
+            right = patch_image_size - new_width - left
+            top, bottom = 0, 0
+
+    img_new = F.resize(
+        img,
+        (new_height, new_width),
+        interpolation=InterpolationMode.BICUBIC,
+    )
+
+    if is_document:
+        img_split = transforms.ToTensor()(img_new).chunk(4, dim=-1)
+        img_new = transforms.ToPILImage()(torch.cat(img_split, dim=-2))
+        new_width, new_height = img_new.size
+        top = (patch_image_size - new_height) // 2
+        bottom = patch_image_size - new_height - top
+        left, right = 0, 0
+
+    img_new = F.pad(
+        img_new, padding=[left, top, right, bottom], padding_mode='edge')
+    assert img_new.size == (patch_image_size, patch_image_size)
+
+    return img_new
+
+
+class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor):
+
+    def __init__(self, cfg, model_dir):
+        """Preprocess the data.
+
+        Args:
+            cfg (modelscope.utils.config.ConfigDict): model config
+            model_dir (str): model path
+        """
+        super(OfaOcrRecognitionPreprocessor, self).__init__(cfg, model_dir)
+        # Initialize transform
+        if self.cfg.model.imagenet_default_mean_and_std:
+            mean = IMAGENET_DEFAULT_MEAN
+            std = IMAGENET_DEFAULT_STD
+        else:
+            mean = [0.5, 0.5, 0.5]
+            std = [0.5, 0.5, 0.5]
+
+        self.patch_resize_transform = transforms.Compose([
+            lambda image: ocr_resize(
+                image,
+                self.cfg.model.patch_image_size,
+                is_document=self.cfg.model.is_document),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=mean, std=std),
+        ])
+
+    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        image = data['image'] if isinstance(
+            data['image'], Image.Image) else load_image(data['image'])
+        patch_image = self.patch_resize_transform(image)
+        prompt = self.cfg.model.get('prompt', '图片上的文字是什么?')
+        inputs = self.get_inputs(prompt)
+        sample = {
+            'source': inputs,
+            'patch_image': patch_image,
+            'patch_mask': torch.tensor([True])
+        }
+        return sample
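
A self-contained sketch of the resize-and-pad geometry that ocr_resize applies in the non-document case; the patch_image_size value here is an assumption for illustration, in practice it comes from cfg.model.patch_image_size:

from PIL import Image

img = Image.new('RGB', (200, 50))  # a wide, scene-text-like crop
width, height = img.size
patch_image_size = 480  # assumed value for illustration only

# width >= height: scale the width to patch_image_size, scale the height
# proportionally, then pad the top and bottom edges to obtain a square patch.
new_width = max(64, patch_image_size)                           # 480
new_height = max(64, int(patch_image_size * (height / width)))  # 120
top = (patch_image_size - new_height) // 2                      # 180
bottom = patch_image_size - new_height - top                    # 180
assert new_height + top + bottom == patch_image_size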
@@ -151,6 +151,7 @@ class MultiModalTasks(object):
     visual_entailment = 'visual-entailment'
     video_multi_modal_embedding = 'video-multi-modal-embedding'
     image_text_retrieval = 'image-text-retrieval'
+    ofa_ocr_recognition = 'ofa-ocr-recognition'


 class TasksIODescriptions(object):
@@ -45,6 +45,14 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck):
         result = img_captioning('data/test/images/image_captioning.png')
         print(result[OutputKeys.CAPTION])

+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_ocr_recognize_with_name(self):
+        ocr_recognize = pipeline(
+            Tasks.ofa_ocr_recognition,
+            model='damo/ofa_ocr-recognition_scene_base_zh')
+        result = ocr_recognize('data/test/images/image_ocr_recognition.jpg')
+        print(result[OutputKeys.TEXT])
+
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_run_with_image_classification_with_model(self):
         model = Model.from_pretrained(