|
|
|
@@ -118,10 +118,12 @@ class MPlugVisualQuestionAnsweringPreprocessor(Preprocessor): |
|
|
|
transforms.Normalize(mean=mean, std=std), |
|
|
|
]) |
|
|
|
|
|
|
|
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: |
|
|
|
image, question = data['image'], data['question'] |
|
|
|
image = Image.open(image).convert('RGB') if isinstance(image, |
|
|
|
str) else image |
|
|
|
def __call__(self, data: Union[tuple, Dict[str, Any]]) -> Dict[str, Any]: |
|
|
|
image: Image.Image = data[0] if isinstance(data, |
|
|
|
tuple) else data['image'] |
|
|
|
question: str = data[1] if isinstance(data, |
|
|
|
tuple) else data['question'] |
|
|
|
image = image.convert('RGB') |
|
|
|
image = self.patch_resize_transform(image) |
|
|
|
image = torch.stack([image], dim=0) |
|
|
|
question = self.tokenizer([question.lower()], |
|
|
|
|