
[to #42322933] Fix Rouge metrics for Chinese text

Fix incorrect Rouge results in TextGenerationMetric when evaluating Chinese text.

Add BLEU metrics for text generation.

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10254323
Branch: master
hemu.zp, yingda.chen · 3 years ago
Parent commit: a3598f8d8c
2 changed files with 51 additions and 15 deletions
  1. modelscope/metrics/builder.py (+4, -0)
  2. modelscope/metrics/text_generation_metric.py (+47, -15)

modelscope/metrics/builder.py (+4, -0)

@@ -18,6 +18,10 @@ class MetricKeys(object):
     SSIM = 'ssim'
     AVERAGE_LOSS = 'avg_loss'
     FScore = 'fscore'
+    BLEU_1 = 'bleu-1'
+    BLEU_4 = 'bleu-4'
+    ROUGE_1 = 'rouge-1'
+    ROUGE_L = 'rouge-l'
 
 
 task_default_metrics = {
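
Note: these four keys name the scores that the reworked TextGenerationMetric (second file below) reports. A minimal consumer-side sketch, assuming a modelscope install; the numeric values are invented placeholders:

    from modelscope.metrics.builder import MetricKeys

    # Shape of the dict TextGenerationMetric.evaluate() now returns;
    # the scores here are made up for illustration.
    result = {'rouge-1': 0.41, 'rouge-l': 0.38, 'bleu-1': 0.35, 'bleu-4': 0.12}
    print(result[MetricKeys.ROUGE_L])  # 0.38 -- MetricKeys members are plain strings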


modelscope/metrics/text_generation_metric.py (+47, -15)

@@ -1,11 +1,14 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 
-from typing import Dict
+from typing import Dict, Iterable, List
+
+from nltk.translate.bleu_score import sentence_bleu
+from rouge import Rouge
 
 from modelscope.metainfo import Metrics
-from modelscope.metrics.base import Metric
-from modelscope.metrics.builder import METRICS, MetricKeys
 from modelscope.utils.registry import default_group
+from .base import Metric
+from .builder import METRICS, MetricKeys
 
 
 @METRICS.register_module(
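
Note: the scorer swap is the heart of the fix. The rouge_score package's default tokenizer keeps only ASCII alphanumeric tokens, so Chinese predictions and references tokenize to nothing and the old RougeScorer returned degenerate scores; the rouge (py-rouge) package simply splits on whitespace. A small comparison sketch, assuming both pip packages are installed:

    from rouge import Rouge
    from rouge_score import rouge_scorer

    # Chinese characters pre-split on spaces, as rebuild_str below produces.
    hyp = '今 天 天 气 不 错'
    ref = '今 天 天 气 很 好'

    # rouge_score drops the CJK characters during tokenization -> empty overlap.
    print(rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True).score(ref, hyp))

    # The whitespace-based rouge package scores the same pair sensibly.
    print(Rouge().get_scores(hyps=[hyp], refs=[ref])[0]['rouge-l']['f'])
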
@@ -17,20 +20,49 @@ class TextGenerationMetric(Metric):
     """
 
     def __init__(self):
-        self.preds = []
-        self.tgts = []
-        from rouge_score import rouge_scorer
-        self.scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
+        self.preds: List[str] = []
+        self.tgts: List[str] = []
+        self.rouge = Rouge()
+
+    @staticmethod
+    def is_chinese_char(char: str):
+        # the length of char must be 1
+        return '\u4e00' <= char <= '\u9fa5'
+
+    # add space for each chinese char
+    def rebuild_str(self, string: str):
+        return ' '.join(''.join([
+            f' {char} ' if self.is_chinese_char(char) else char
+            for char in string
+        ]).split())
 
-    def add(self, outputs: Dict, inputs: Dict):
+    def add(self, outputs: Dict[str, List[str]], inputs: Dict = None):
         ground_truths = outputs['tgts']
         eval_results = outputs['preds']
-        self.preds.extend(eval_results)
-        self.tgts.extend(ground_truths)
+        for truth in ground_truths:
+            self.tgts.append(self.rebuild_str(truth))
+        for result in eval_results:
+            self.preds.append(self.rebuild_str(result))
 
     def evaluate(self):
-        scores = [
-            self.scorer.score(pred, tgt)['rougeL'].fmeasure
-            for pred, tgt in zip(self.preds, self.tgts)
-        ]
-        return {MetricKeys.F1: sum(scores) / len(scores)}
+
+        def mean(iter: Iterable) -> float:
+            return sum(iter) / len(self.preds)
+
+        rouge_scores = self.rouge.get_scores(hyps=self.preds, refs=self.tgts)
+        rouge_1 = mean(map(lambda score: score['rouge-1']['f'], rouge_scores))
+        rouge_l = mean(map(lambda score: score['rouge-l']['f'], rouge_scores))
+        pred_split = tuple(pred.split(' ') for pred in self.preds)
+        tgt_split = tuple(tgt.split(' ') for tgt in self.tgts)
+        bleu_1 = mean(
+            sentence_bleu([tgt], pred, weights=(1, 0, 0, 0))
+            for pred, tgt in zip(pred_split, tgt_split))
+        bleu_4 = mean(
+            sentence_bleu([tgt], pred)
+            for pred, tgt in zip(pred_split, tgt_split))
+        return {
+            MetricKeys.ROUGE_1: rouge_1,
+            MetricKeys.ROUGE_L: rouge_l,
+            MetricKeys.BLEU_1: bleu_1,
+            MetricKeys.BLEU_4: bleu_4
+        }
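
Taken together: rebuild_str pads every CJK character with spaces and then collapses the result (e.g. 'ModelScope真棒' becomes 'ModelScope 真 棒'), so both the rouge scorer and NLTK's sentence_bleu see one token per Chinese character. A usage sketch with invented sample sentences, assuming a working modelscope environment:

    from modelscope.metrics.text_generation_metric import TextGenerationMetric

    metric = TextGenerationMetric()
    metric.add({'preds': ['今天天气不错'], 'tgts': ['今天天气很好']})
    # Returns a dict keyed by the new MetricKeys constants, e.g.
    # {'rouge-1': ..., 'rouge-l': ..., 'bleu-1': ..., 'bleu-4': ...}
    print(metric.evaluate())

One detail worth noting: the inner mean helper always divides by len(self.preds), which is safe here because each averaged iterable yields exactly one score per prediction.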
