diff --git a/modelscope/pipelines/multi_modal/image_captioning.py b/modelscope/pipelines/multi_modal/image_captioning.py index 91180e23..3e5f49d0 100644 --- a/modelscope/pipelines/multi_modal/image_captioning.py +++ b/modelscope/pipelines/multi_modal/image_captioning.py @@ -84,8 +84,11 @@ class ImageCaptionPipeline(Pipeline): s = torch.cat([s, self.eos_item]) return s - patch_image = self.patch_resize_transform( - load_image(input)).unsqueeze(0) + if isinstance(input, Image.Image): + patch_image = self.patch_resize_transform(input).unsqueeze(0) + else: + patch_image = self.patch_resize_transform( + load_image(input)).unsqueeze(0) patch_mask = torch.tensor([True]) text = 'what does the image describe?' src_text = encode_text(