diff --git a/modelscope/metrics/text_generation_metric.py b/modelscope/metrics/text_generation_metric.py index 9bca7cf3..c2d9c6a8 100644 --- a/modelscope/metrics/text_generation_metric.py +++ b/modelscope/metrics/text_generation_metric.py @@ -2,7 +2,7 @@ from typing import Dict, Iterable, List -from nltk.translate.bleu_score import sentence_bleu +from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu from rouge import Rouge from modelscope.metainfo import Metrics @@ -63,14 +63,18 @@ class TextGenerationMetric(Metric): rouge_scores = self.rouge.get_scores(hyps=preds, refs=tgts) rouge_1 = mean(map(lambda score: score['rouge-1']['f'], rouge_scores)) rouge_l = mean(map(lambda score: score['rouge-l']['f'], rouge_scores)) - pred_split = tuple(pred.split(' ') for pred in self.preds) - tgt_split = tuple(tgt.split(' ') for tgt in self.tgts) - bleu_1 = mean( - sentence_bleu([tgt], pred, weights=(1, 0, 0, 0)) - for pred, tgt in zip(pred_split, tgt_split)) - bleu_4 = mean( - sentence_bleu([tgt], pred) - for pred, tgt in zip(pred_split, tgt_split)) + + pred_list = [each.strip().split(' ') for each in self.preds] + tgt_list = [[each.strip().split(' ')] for each in self.tgts] + bleu_1 = corpus_bleu( + tgt_list, + pred_list, + weights=(1, 0, 0, 0), + smoothing_function=SmoothingFunction().method3) + bleu_4 = corpus_bleu( + tgt_list, + pred_list, + smoothing_function=SmoothingFunction().method3) return { MetricKeys.ROUGE_1: rouge_1, MetricKeys.ROUGE_L: rouge_l,