|
|
@@ -57,18 +57,14 @@ class MPlugForAllTasks(TorchModel): |
|
|
if task == Tasks.image_text_retrieval: |
|
|
if task == Tasks.image_text_retrieval: |
|
|
return {OutputKeys.SCORES: output[0].tolist()} |
|
|
return {OutputKeys.SCORES: output[0].tolist()} |
|
|
topk_ids, _ = output |
|
|
topk_ids, _ = output |
|
|
topk_ids = [topk_ids[i][0] for i in range(len(topk_ids))] |
|
|
|
|
|
pred_strings: List[str] = \ |
|
|
|
|
|
self.tokenizer.batch_decode(topk_ids, skip_special_tokens=True) |
|
|
|
|
|
output = [] |
|
|
|
|
|
for pred_string in pred_strings: |
|
|
|
|
|
for _old, _new in replace_tokens_bert: |
|
|
|
|
|
pred_string = pred_string.replace(_old, _new) |
|
|
|
|
|
pred_string = pred_string.strip() |
|
|
|
|
|
output.append(pred_string) |
|
|
|
|
|
|
|
|
pred_string: List[str] = \ |
|
|
|
|
|
self.tokenizer.decode(topk_ids[0][0]) |
|
|
|
|
|
for _old, _new in replace_tokens_bert: |
|
|
|
|
|
pred_string = pred_string.replace(_old, _new) |
|
|
|
|
|
pred_string = pred_string.strip() |
|
|
output_key = OutputKeys.CAPTION \ |
|
|
output_key = OutputKeys.CAPTION \ |
|
|
if task == Tasks.image_captioning else OutputKeys.TEXT |
|
|
if task == Tasks.image_captioning else OutputKeys.TEXT |
|
|
return {output_key: output} |
|
|
|
|
|
|
|
|
return {output_key: pred_string} |
|
|
|
|
|
|
|
|
# train and evaluate |
|
|
# train and evaluate |
|
|
import addict |
|
|
import addict |
|
|
|