Browse Source

[to #42322933] inference speedup for multimodal model GEMM

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9786728
master
lingchen.zlm yingda.chen 3 years ago
parent
commit
01ec568ce1
1 changed files with 3 additions and 3 deletions
  1. +3
    -3
      modelscope/models/multi_modal/gemm/gemm_base.py

+ 3
- 3
modelscope/models/multi_modal/gemm/gemm_base.py View File

@@ -491,7 +491,9 @@ class GEVL(nn.Module):
gen_logits = self.to_logits(out_embs[-1:, ...])
probs = F.softmax(self.gen_logit_scale.exp() * gen_logits, dim=-1)
        pred = torch.argmax(
-           probs * (1.0 + torch.rand_like(probs)), axis=-1)
+           probs * (2.0 + torch.rand_like(probs)), axis=-1)
+       if int(pred) >= eot_token or int(pred) <= 0:
+           break
pred_tokens.append(pred)
text_input = torch.cat(
[text_input, pred.permute(1, 0).contiguous()], axis=1)
@@ -500,8 +502,6 @@ class GEVL(nn.Module):
for out_tokens in pred_text_tokens:
tokens = []
for x in out_tokens:
-           if x >= eot_token or x <= 0:
-               break
tokens.append(int(x))
out_text = self.tokenizer.decode(tokens)
out_text = out_text.strip()


Loading…
Cancel
Save