|
|
|
@@ -84,8 +84,11 @@ class ImageCaptionPipeline(Pipeline): |
|
|
|
s = torch.cat([s, self.eos_item]) |
|
|
|
return s |
|
|
|
|
|
|
|
patch_image = self.patch_resize_transform( |
|
|
|
load_image(input)).unsqueeze(0) |
|
|
|
if isinstance(input, Image.Image): |
|
|
|
patch_image = self.patch_resize_transform(input).unsqueeze(0) |
|
|
|
else: |
|
|
|
patch_image = self.patch_resize_transform( |
|
|
|
load_image(input)).unsqueeze(0) |
|
|
|
patch_mask = torch.tensor([True]) |
|
|
|
text = 'what does the image describe?' |
|
|
|
src_text = encode_text( |
|
|
|
|