From 2019315c544b44da5a3c103608095588cc9d8388 Mon Sep 17 00:00:00 2001
From: "lingcai.wl"
Date: Wed, 22 Jun 2022 23:54:38 +0800
Subject: [PATCH] [to #42463204] support Pil.Image for image_captioning_pipeline

---
 modelscope/preprocessors/__init__.py                        | 2 +-
 modelscope/preprocessors/{multi_model.py => multi_modal.py} | 6 +++---
 2 files changed, 4 insertions(+), 4 deletions(-)
 rename modelscope/preprocessors/{multi_model.py => multi_modal.py} (95%)

diff --git a/modelscope/preprocessors/__init__.py b/modelscope/preprocessors/__init__.py
index 50860514..942d17c3 100644
--- a/modelscope/preprocessors/__init__.py
+++ b/modelscope/preprocessors/__init__.py
@@ -5,6 +5,6 @@ from .base import Preprocessor
 from .builder import PREPROCESSORS, build_preprocessor
 from .common import Compose
 from .image import LoadImage, load_image
-from .multi_model import OfaImageCaptionPreprocessor
+from .multi_modal import OfaImageCaptionPreprocessor
 from .nlp import *  # noqa F403
 from .text_to_speech import *  # noqa F403
diff --git a/modelscope/preprocessors/multi_model.py b/modelscope/preprocessors/multi_modal.py
similarity index 95%
rename from modelscope/preprocessors/multi_model.py
rename to modelscope/preprocessors/multi_modal.py
index aa0bc8a7..7c8f0fab 100644
--- a/modelscope/preprocessors/multi_model.py
+++ b/modelscope/preprocessors/multi_modal.py
@@ -73,7 +73,7 @@ class OfaImageCaptionPreprocessor(Preprocessor):
         self.eos_item = torch.LongTensor([task.src_dict.eos()])
         self.pad_idx = task.src_dict.pad()

-    @type_assert(object, (str, tuple))
+    @type_assert(object, (str, tuple, Image.Image))
     def __call__(self, data: Union[str, tuple]) -> Dict[str, Any]:

         def encode_text(text, length=None, append_bos=False, append_eos=False):
@@ -89,8 +89,8 @@ class OfaImageCaptionPreprocessor(Preprocessor):
                 s = torch.cat([s, self.eos_item])
             return s

-        if isinstance(input, Image.Image):
-            patch_image = self.patch_resize_transform(input).unsqueeze(0)
+        if isinstance(data, Image.Image):
+            patch_image = self.patch_resize_transform(data).unsqueeze(0)
         else:
             patch_image = self.patch_resize_transform(
                 load_image(data)).unsqueeze(0)