diff --git a/modelscope/models/multi_modal/gemm/gemm_model.py b/modelscope/models/multi_modal/gemm/gemm_model.py
index 55b211c0..c90b35d4 100644
--- a/modelscope/models/multi_modal/gemm/gemm_model.py
+++ b/modelscope/models/multi_modal/gemm/gemm_model.py
@@ -67,7 +67,7 @@ class GEMMForMultiModalEmbedding(TorchModel):
         return img_tensor
 
     def parse_text(self, text_str):
-        if text_str is None:
+        if text_str is None or len(text_str) == 0:
             return None
         if isinstance(text_str, str):
             text_ids_tensor = self.gemm_model.tokenize(text_str)
@@ -79,9 +79,12 @@ class GEMMForMultiModalEmbedding(TorchModel):
         return text_ids_tensor.view(1, -1)
 
     def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
-        image = self.parse_image(input.get('image', input.get('img', None)))
-        text = self.parse_text(input.get('text', input.get('txt', None)))
-        captioning = input.get('captioning', False) is True
+        image_input = input.get('image', input.get('img', None))
+        text_input = input.get('text', input.get('txt', None))
+        captioning_input = input.get('captioning', None)
+        image = self.parse_image(image_input)
+        text = self.parse_text(text_input)
+        captioning = captioning_input is True or text_input == ''
         out = self.gemm_model(image, text, captioning)
         output = {
             OutputKeys.IMG_EMBEDDING: out.get('image_feature', None),
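
Reviewer note: a minimal usage sketch of the behavior this patch enables, assuming an already-loaded GEMMForMultiModalEmbedding instance named `model` and a local image path `img_path` (both placeholders, not part of the patch). With the change, an empty text string parses to None and flips the call into captioning mode.

from modelscope.outputs import OutputKeys

# `model` is a constructed GEMMForMultiModalEmbedding; `img_path` points at
# any readable image file. Both names are illustrative assumptions.
inputs = {
    'image': img_path,  # routed through parse_image
    'text': '',         # len(text_str) == 0, so parse_text returns None
}
# Because text_input == '', forward sets captioning=True and asks the
# underlying gemm_model to generate a caption instead of a text embedding.
out = model(inputs)
img_emb = out[OutputKeys.IMG_EMBEDDING]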