From 53e9f02561081648ead120ca719d0b9c191c781b Mon Sep 17 00:00:00 2001 From: "hemu.zp" Date: Fri, 28 Oct 2022 09:28:15 +0800 Subject: [PATCH] [to #42322933] Fix bug for bleu in text generation metrics. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 修复了使用错误算法导致 BLEU-4 值计算结果偏小的问题 Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10558494 --- modelscope/metrics/text_generation_metric.py | 22 ++++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/modelscope/metrics/text_generation_metric.py b/modelscope/metrics/text_generation_metric.py index 9bca7cf3..c2d9c6a8 100644 --- a/modelscope/metrics/text_generation_metric.py +++ b/modelscope/metrics/text_generation_metric.py @@ -2,7 +2,7 @@ from typing import Dict, Iterable, List -from nltk.translate.bleu_score import sentence_bleu +from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu from rouge import Rouge from modelscope.metainfo import Metrics @@ -63,14 +63,18 @@ class TextGenerationMetric(Metric): rouge_scores = self.rouge.get_scores(hyps=preds, refs=tgts) rouge_1 = mean(map(lambda score: score['rouge-1']['f'], rouge_scores)) rouge_l = mean(map(lambda score: score['rouge-l']['f'], rouge_scores)) - pred_split = tuple(pred.split(' ') for pred in self.preds) - tgt_split = tuple(tgt.split(' ') for tgt in self.tgts) - bleu_1 = mean( - sentence_bleu([tgt], pred, weights=(1, 0, 0, 0)) - for pred, tgt in zip(pred_split, tgt_split)) - bleu_4 = mean( - sentence_bleu([tgt], pred) - for pred, tgt in zip(pred_split, tgt_split)) + + pred_list = [each.strip().split(' ') for each in self.preds] + tgt_list = [[each.strip().split(' ')] for each in self.tgts] + bleu_1 = corpus_bleu( + tgt_list, + pred_list, + weights=(1, 0, 0, 0), + smoothing_function=SmoothingFunction().method3) + bleu_4 = corpus_bleu( + tgt_list, + pred_list, + smoothing_function=SmoothingFunction().method3) return { MetricKeys.ROUGE_1: rouge_1, MetricKeys.ROUGE_L: rouge_l,