@@ -491,7 +491,9 @@ class GEVL(nn.Module):
             gen_logits = self.to_logits(out_embs[-1:, ...])
             probs = F.softmax(self.gen_logit_scale.exp() * gen_logits, dim=-1)
             pred = torch.argmax(
-                probs * (1.0 + torch.rand_like(probs)), axis=-1)
+                probs * (2.0 + torch.rand_like(probs)), axis=-1)
             if int(pred) >= eot_token or int(pred) <= 0:
                 break
             pred_tokens.append(pred)
             text_input = torch.cat(
                 [text_input, pred.permute(1, 0).contiguous()], axis=1)
@@ -500,8 +502,6 @@ class GEVL(nn.Module):
         for out_tokens in pred_text_tokens:
             tokens = []
             for x in out_tokens:
                 if x >= eot_token or x <= 0:
                     break
                 tokens.append(int(x))
             out_text = self.tokenizer.decode(tokens)
             out_text = out_text.strip()
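
For reference, the behavioural effect of the change in the first hunk: the next token is chosen by an argmax over the softmax probabilities after they are multiplied by a random factor, and the floor of that factor moves from 1.0 to 2.0. With a multiplier drawn uniformly from [c, c+1), a larger c narrows the relative spread of the noise, so raising the floor makes the selection closer to plain greedy decoding. The sketch below isolates that step outside the model; the function name, its arguments, and the toy inputs are illustrative, not part of GEVL.

import torch
import torch.nn.functional as F

def sample_next_token(gen_logits, logit_scale, noise_floor=2.0):
    # Softmax over scaled logits, as in the diff:
    #   probs = F.softmax(self.gen_logit_scale.exp() * gen_logits, dim=-1)
    probs = F.softmax(logit_scale.exp() * gen_logits, dim=-1)
    # Multiplicative noise in [noise_floor, noise_floor + 1); the diff
    # raises noise_floor from 1.0 to 2.0 before taking the argmax.
    perturbed = probs * (noise_floor + torch.rand_like(probs))
    return torch.argmax(perturbed, dim=-1)

# Toy usage: one decoding step over a vocabulary of 8 tokens.
logits = torch.randn(1, 8)
scale = torch.tensor(0.0)  # exp(0.0) == 1.0, i.e. no extra scaling
print(sample_next_token(logits, scale))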
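
The second hunk post-processes each predicted token sequence: ids are kept up to the first EOT or non-positive id, decoded, and whitespace-stripped. A minimal standalone version of that loop, assuming only a tokenizer object with a decode(ids) -> str method as used in the diff (the function name and the toy tokenizer here are illustrative):

def decode_predictions(pred_text_tokens, tokenizer, eot_token):
    # Mirror of the second hunk: truncate each sequence at the first
    # EOT / non-positive id, decode the prefix, strip whitespace.
    texts = []
    for out_tokens in pred_text_tokens:
        tokens = []
        for x in out_tokens:
            if int(x) >= eot_token or int(x) <= 0:
                break
            tokens.append(int(x))
        texts.append(tokenizer.decode(tokens).strip())
    return texts

# Toy usage with a dummy tokenizer.
class _ToyTokenizer:
    def decode(self, ids):
        return " ".join(f"tok{i}" for i in ids)

print(decode_predictions([[5, 7, 0, 3], [2, 9]], _ToyTokenizer(), eot_token=9))
# -> ['tok5 tok7', 'tok2']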