@@ -73,7 +73,7 @@ class OfaImageCaptionPreprocessor(Preprocessor):
         self.eos_item = torch.LongTensor([task.src_dict.eos()])
         self.pad_idx = task.src_dict.pad()

-    @type_assert(object, (str, tuple))
+    @type_assert(object, (str, tuple, Image.Image))
     def __call__(self, data: Union[str, tuple]) -> Dict[str, Any]:

         def encode_text(text, length=None, append_bos=False, append_eos=False):
@@ -89,8 +89,8 @@ class OfaImageCaptionPreprocessor(Preprocessor):
                 s = torch.cat([s, self.eos_item])
             return s

-        if isinstance(input, Image.Image):
-            patch_image = self.patch_resize_transform(input).unsqueeze(0)
+        if isinstance(data, Image.Image):
+            patch_image = self.patch_resize_transform(data).unsqueeze(0)
         else:
             patch_image = self.patch_resize_transform(
                 load_image(data)).unsqueeze(0)
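
The first hunk widens the `type_assert` allowlist so a `PIL.Image.Image` can be passed to `__call__` directly; the second replaces the stray reference to the builtin `input` with the actual parameter `data`. Before this fix the `isinstance` check always tested the builtin `input` function, evaluated False, and fell through to `load_image(data)`, which does not accept an in-memory image. A minimal usage sketch of the new behavior follows; the `preprocessor` instance and the `demo.jpg` file name are illustrative assumptions, not part of this patch:

    from PIL import Image

    # Assumption: `preprocessor` is an already-constructed
    # OfaImageCaptionPreprocessor; how it is built is outside this patch.

    # New path: an in-memory PIL image is fed straight to
    # patch_resize_transform(), skipping the load_image() round trip.
    img = Image.open('demo.jpg')  # placeholder file name
    inputs = preprocessor(img)

    # Old path still works: a str (path or URL) goes through load_image() first.
    inputs_from_path = preprocessor('demo.jpg')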