From 01ec568ce1bd9d9f6a522407a4d6598435f83b32 Mon Sep 17 00:00:00 2001
From: "lingchen.zlm"
Date: Wed, 17 Aug 2022 15:53:54 +0800
Subject: [PATCH] [to #42322933] inference speedup for multimodal model GEMM

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9786728
---
 modelscope/models/multi_modal/gemm/gemm_base.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/modelscope/models/multi_modal/gemm/gemm_base.py b/modelscope/models/multi_modal/gemm/gemm_base.py
index 26eea0d5..db928212 100644
--- a/modelscope/models/multi_modal/gemm/gemm_base.py
+++ b/modelscope/models/multi_modal/gemm/gemm_base.py
@@ -491,7 +491,9 @@ class GEVL(nn.Module):
             gen_logits = self.to_logits(out_embs[-1:, ...])
             probs = F.softmax(self.gen_logit_scale.exp() * gen_logits, dim=-1)
             pred = torch.argmax(
-                probs * (1.0 + torch.rand_like(probs)), axis=-1)
+                probs * (2.0 + torch.rand_like(probs)), axis=-1)
+            if int(pred) >= eot_token or int(pred) <= 0:
+                break
             pred_tokens.append(pred)
             text_input = torch.cat(
                 [text_input, pred.permute(1, 0).contiguous()], axis=1)
@@ -500,8 +502,6 @@ class GEVL(nn.Module):
         for out_tokens in pred_text_tokens:
             tokens = []
             for x in out_tokens:
-                if x >= eot_token or x <= 0:
-                    break
                 tokens.append(int(x))
             out_text = self.tokenizer.decode(tokens)
             out_text = out_text.strip()
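
The speedup comes from moving the end-of-text check into the generation loop: instead of always running the full autoregressive budget and trimming EOT during decoding, the loop now breaks as soon as EOT (or an invalid token) is predicted, skipping the remaining forward passes. Below is a minimal, self-contained sketch of that pattern, not the modelscope API; model_step, prompt_tokens, eot_token, and max_len are hypothetical placeholders.

import torch

def greedy_generate(model_step, prompt_tokens, eot_token, max_len=77):
    """Generate tokens one at a time, stopping early on EOT."""
    tokens = list(prompt_tokens)
    for _ in range(max_len - len(tokens)):
        # model_step: hypothetical callable returning next-token logits, shape (vocab,)
        logits = model_step(torch.tensor(tokens))
        probs = torch.softmax(logits, dim=-1)
        # Scaling by (2.0 + U[0,1)) keeps the argmax mostly greedy with only
        # mild randomness, mirroring the patched sampling rule: the factor now
        # lies in [2, 3) rather than [1, 2), so perturbation is weaker.
        pred = int(torch.argmax(probs * (2.0 + torch.rand_like(probs))))
        if pred >= eot_token or pred <= 0:
            break  # early stop: saves the remaining forward passes
        tokens.append(pred)
    return tokens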