From a3598f8d8c09ced380c9393d5c5208ef65aa13dd Mon Sep 17 00:00:00 2001
From: "hemu.zp"
Date: Tue, 27 Sep 2022 23:24:58 +0800
Subject: [PATCH] [to #42322933] Fix rouge metrics for chinese text
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Fix incorrect results when the Rouge metric in TextGenerationMetric is
computed on Chinese text.
Add a BLEU metric for text generation.

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10254323
---
 modelscope/metrics/builder.py                |  4 ++
 modelscope/metrics/text_generation_metric.py | 62 +++++++++++++++-----
 2 files changed, 51 insertions(+), 15 deletions(-)

diff --git a/modelscope/metrics/builder.py b/modelscope/metrics/builder.py
index 800e3508..9e875cc4 100644
--- a/modelscope/metrics/builder.py
+++ b/modelscope/metrics/builder.py
@@ -18,6 +18,10 @@ class MetricKeys(object):
     SSIM = 'ssim'
     AVERAGE_LOSS = 'avg_loss'
     FScore = 'fscore'
+    BLEU_1 = 'bleu-1'
+    BLEU_4 = 'bleu-4'
+    ROUGE_1 = 'rouge-1'
+    ROUGE_L = 'rouge-l'
 
 
 task_default_metrics = {
diff --git a/modelscope/metrics/text_generation_metric.py b/modelscope/metrics/text_generation_metric.py
index f154281d..90b80425 100644
--- a/modelscope/metrics/text_generation_metric.py
+++ b/modelscope/metrics/text_generation_metric.py
@@ -1,11 +1,14 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 
-from typing import Dict
+from typing import Dict, Iterable, List
+
+from nltk.translate.bleu_score import sentence_bleu
+from rouge import Rouge
 
 from modelscope.metainfo import Metrics
+from modelscope.metrics.base import Metric
+from modelscope.metrics.builder import METRICS, MetricKeys
 from modelscope.utils.registry import default_group
-from .base import Metric
-from .builder import METRICS, MetricKeys
 
 
 @METRICS.register_module(
@@ -17,20 +20,49 @@ class TextGenerationMetric(Metric):
     """
 
     def __init__(self):
-        self.preds = []
-        self.tgts = []
-        from rouge_score import rouge_scorer
-        self.scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
+        self.preds: List[str] = []
+        self.tgts: List[str] = []
+        self.rouge = Rouge()
+
+    @staticmethod
+    def is_chinese_char(char: str) -> bool:
+        # `char` is expected to be a single character
+        return '\u4e00' <= char <= '\u9fa5'
+
+    # Surround every Chinese char with spaces so Rouge/BLEU score per character
+    def rebuild_str(self, string: str) -> str:
+        return ' '.join(''.join([
+            f' {char} ' if self.is_chinese_char(char) else char
+            for char in string
+        ]).split())
 
-    def add(self, outputs: Dict, inputs: Dict):
+    def add(self, outputs: Dict[str, List[str]], inputs: Dict = None):
         ground_truths = outputs['tgts']
         eval_results = outputs['preds']
-        self.preds.extend(eval_results)
-        self.tgts.extend(ground_truths)
+        for truth in ground_truths:
+            self.tgts.append(self.rebuild_str(truth))
+        for result in eval_results:
+            self.preds.append(self.rebuild_str(result))
 
     def evaluate(self):
-        scores = [
-            self.scorer.score(pred, tgt)['rougeL'].fmeasure
-            for pred, tgt in zip(self.preds, self.tgts)
-        ]
-        return {MetricKeys.F1: sum(scores) / len(scores)}
+
+        def mean(iterable: Iterable) -> float:
+            return sum(iterable) / len(self.preds)
+
+        rouge_scores = self.rouge.get_scores(hyps=self.preds, refs=self.tgts)
+        rouge_1 = mean(map(lambda score: score['rouge-1']['f'], rouge_scores))
+        rouge_l = mean(map(lambda score: score['rouge-l']['f'], rouge_scores))
+        pred_split = tuple(pred.split(' ') for pred in self.preds)
+        tgt_split = tuple(tgt.split(' ') for tgt in self.tgts)
+        bleu_1 = mean(
+            sentence_bleu([tgt], pred, weights=(1, 0, 0, 0))
+            for pred, tgt in zip(pred_split, tgt_split))
+        bleu_4 = mean(
+            sentence_bleu([tgt], pred)
+            for pred, tgt in zip(pred_split, tgt_split))
+        return {
+            MetricKeys.ROUGE_1: rouge_1,
+            MetricKeys.ROUGE_L: rouge_l,
+            MetricKeys.BLEU_1: bleu_1,
+            MetricKeys.BLEU_4: bleu_4
+        }
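
For reference, a minimal usage sketch of the updated metric (not part of the patch; the sample strings and call sequence are illustrative only, assuming the `rouge` and `nltk` packages are installed):

    from modelscope.metrics.text_generation_metric import TextGenerationMetric

    # Accumulate predictions and references, then evaluate.
    # The Chinese strings below are made-up examples.
    metric = TextGenerationMetric()
    metric.add({'preds': ['今天天气很好'], 'tgts': ['今天天气不错']})
    print(metric.evaluate())
    # -> {'rouge-1': ..., 'rouge-l': ..., 'bleu-1': ..., 'bleu-4': ...}

Because rebuild_str inserts a space around every Chinese character before scoring, the word-level Rouge and BLEU implementations effectively compute character-level overlap, which is what fixes the previously incorrect results on Chinese text.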