Fixed mplug evaluation using the wrong metrics, moved some of the Chinese text-processing code out into shared utils, and added a trainer for mplug.
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10612875
Target branch: master
@@ -299,6 +299,7 @@ class Trainers(object):
     # multi-modal trainers
     clip_multi_modal_embedding = 'clip-multi-modal-embedding'
     ofa = 'ofa'
+    mplug = 'mplug'

     # cv trainers
     image_instance_segmentation = 'image-instance-segmentation'
@@ -6,6 +6,7 @@ import numpy as np
 from modelscope.metainfo import Metrics
 from modelscope.outputs import OutputKeys
+from modelscope.utils.chinese_utils import remove_space_between_chinese_chars
 from modelscope.utils.registry import default_group
 from .base import Metric
 from .builder import METRICS, MetricKeys
@@ -26,10 +27,10 @@ class AccuracyMetric(Metric):
     def add(self, outputs: Dict, inputs: Dict):
         label_name = OutputKeys.LABEL if OutputKeys.LABEL in inputs else OutputKeys.LABELS
         ground_truths = inputs[label_name]
-        eval_results = outputs[label_name]
+        eval_results = None
         for key in [
                 OutputKeys.CAPTION, OutputKeys.TEXT, OutputKeys.BOXES,
-                OutputKeys.LABELS, OutputKeys.SCORES
+                OutputKeys.LABEL, OutputKeys.LABELS, OutputKeys.SCORES
         ]:
             if key in outputs and outputs[key] is not None:
                 eval_results = outputs[key]
@@ -39,7 +40,7 @@ class AccuracyMetric(Metric):
             self.labels.append(truth)
         for result in eval_results:
             if isinstance(truth, str):
-                self.preds.append(result.strip().replace(' ', ''))
+                self.preds.append(remove_space_between_chinese_chars(result))
             else:
                 self.preds.append(result)
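Review note: the old `eval_results = outputs[label_name]` read predictions from the label key, so caption/VQA evaluation compared the wrong field (mplug emits predictions under CAPTION or TEXT). A minimal sketch of the patched lookup; the literal key strings stand in for the OutputKeys constants and are assumptions for illustration, not values from this diff:

```python
# Sketch of the patched prediction lookup in AccuracyMetric.add():
# scan the candidate output keys in order and take the first one that
# is present and non-None.
outputs = {'caption': ['一只猫坐在垫子上'], 'labels': None}

eval_results = None
for key in ('caption', 'text', 'boxes', 'label', 'labels', 'scores'):
    if key in outputs and outputs[key] is not None:
        eval_results = outputs[key]
        break  # first populated key wins

assert eval_results == ['一只猫坐在垫子上']
```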
@@ -41,8 +41,8 @@ task_default_metrics = {
     Tasks.image_portrait_enhancement:
     [Metrics.image_portrait_enhancement_metric],
     Tasks.video_summarization: [Metrics.video_summarization_metric],
-    Tasks.image_captioning: [Metrics.text_gen_metric],
-    Tasks.visual_question_answering: [Metrics.text_gen_metric],
+    Tasks.image_captioning: [Metrics.accuracy],
+    Tasks.visual_question_answering: [Metrics.accuracy],
     Tasks.movie_scene_segmentation: [Metrics.movie_scene_segmentation_metric],
     Tasks.image_inpainting: [Metrics.image_inpainting_metric],
     Tasks.referring_video_object_segmentation:
@@ -8,6 +8,7 @@ from rouge import Rouge
 from modelscope.metainfo import Metrics
 from modelscope.metrics.base import Metric
 from modelscope.metrics.builder import METRICS, MetricKeys
+from modelscope.utils.chinese_utils import rebuild_chinese_str
 from modelscope.utils.registry import default_group
@@ -24,25 +25,13 @@ class TextGenerationMetric(Metric):
         self.tgts: List[str] = []
         self.rouge = Rouge()

-    @staticmethod
-    def is_chinese_char(char: str):
-        # the length of char must be 1
-        return '\u4e00' <= char <= '\u9fa5'
-
-    # add space for each chinese char
-    def rebuild_str(self, string: str):
-        return ' '.join(''.join([
-            f' {char} ' if self.is_chinese_char(char) else char
-            for char in string
-        ]).split())
-
     def add(self, outputs: Dict[str, List[str]], inputs: Dict[str, List[str]]):
         ground_truths = inputs['tgts']
         eval_results = outputs['preds']
         for truth in ground_truths:
-            self.tgts.append(self.rebuild_str(truth))
+            self.tgts.append(rebuild_chinese_str(truth))
         for result in eval_results:
-            self.preds.append(self.rebuild_str(result))
+            self.preds.append(rebuild_chinese_str(result))

     def _check(self, pred: str, tgt: str) -> bool:
@@ -45,10 +45,6 @@ class MPlugForAllTasks(TorchModel):
                 }
         """
-        replace_tokens_bert = (('[unused0]', ''), ('[PAD]', ''),
-                               ('[unused1]', ''), (r' +', ' '), ('[SEP]', ''),
-                               ('[unused2]', ''), ('[CLS]', ''), ('[UNK]', ''))
-
         # get task from config file
         task = Config.from_file(
             osp.join(self.model_dir, ModelFile.CONFIGURATION)).task
@@ -60,10 +56,7 @@ class MPlugForAllTasks(TorchModel):
             return {OutputKeys.SCORES: output[0].tolist()}
         topk_ids, _ = output
         pred_string: List[str] = \
-            self.tokenizer.decode(topk_ids[0][0])
-        for _old, _new in replace_tokens_bert:
-            pred_string = pred_string.replace(_old, _new)
-        pred_string = pred_string.strip()
+            self.tokenizer.decode(topk_ids[0][0], skip_special_tokens=True)
         output_key = OutputKeys.CAPTION \
             if task == Tasks.image_captioning else OutputKeys.TEXT
         return {output_key: pred_string}
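Review note: `skip_special_tokens=True` asks the tokenizer itself to drop every token registered as special, which is what the deleted `replace_tokens_bert` table approximated by hand (it also retires the stray `(r' +', ' ')` entry, which `str.replace` treated as a literal string rather than a regex). A rough illustration with a stock BERT tokenizer; the checkpoint name is illustrative, and the `[unusedN]` markers are only dropped if the mplug tokenizer registers them as special tokens:

```python
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
ids = tokenizer.encode('你好世界')  # wraps the text in [CLS] ... [SEP]

print(tokenizer.decode(ids))
# [CLS] 你 好 世 界 [SEP]
print(tokenizer.decode(ids, skip_special_tokens=True))
# 你 好 世 界
```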
@@ -87,19 +80,4 @@ class MPlugForAllTasks(TorchModel):
         # evaluate
         topk_ids, _ = output
-        preds: List[str] = [
-            self.tokenizer.decode(batch[0]) for batch in topk_ids
-        ]
-        for i in range(len(preds)):
-            for _old, _new in replace_tokens_bert:
-                preds[i] = preds[i].replace(_old, _new)
-            preds[i] = preds[i].strip()
-        tgts: List[str] = [
-            self.tokenizer.decode(batch)
-            for batch in input['answer_input_ids'].cpu().numpy().tolist()
-        ]
-        for i in range(len(tgts)):
-            for _old, _new in replace_tokens_bert:
-                tgts[i] = tgts[i].replace(_old, _new)
-                preds[i] = preds[i].strip()
-        return {'preds': preds, 'tgts': tgts}
+        return {'sequences': [list_tensor[0] for list_tensor in topk_ids]}
@@ -2,7 +2,7 @@
 from typing import Any, Dict

 import numpy as np
-from transformers.modeling_utils import PreTrainedModel
+from transformers.modeling_utils import GenerationMixin

 from modelscope.metainfo import TaskModels
 from modelscope.models.builder import MODELS
@@ -17,7 +17,8 @@ __all__ = ['TaskModelForTextGeneration']

 @MODELS.register_module(
     Tasks.text_generation, module_name=TaskModels.text_generation)
-class TaskModelForTextGeneration(SingleBackboneTaskModelBase, PreTrainedModel):
+class TaskModelForTextGeneration(SingleBackboneTaskModelBase, GenerationMixin):
+    main_input_name = 'input_ids'

     def __init__(self, model_dir: str, *args, **kwargs):
         """initialize the text generation model from the `model_dir` path.
@@ -10,6 +10,7 @@ from modelscope.outputs import OutputKeys
 from modelscope.pipelines.base import Pipeline, Tensor
 from modelscope.pipelines.builder import PIPELINES
 from modelscope.preprocessors import Preprocessor, build_preprocessor
+from modelscope.utils.chinese_utils import remove_space_between_chinese_chars
 from modelscope.utils.constant import Fields, Tasks
 from modelscope.utils.hub import read_config
@@ -78,28 +79,6 @@ class TextGenerationPipeline(Pipeline):
         with torch.no_grad():
             return self.model.generate(inputs, **forward_params)

-    def _is_chinese_char(self, word: str):
-        chinese_punctuations = (',', '。', ';', ':' '!', '?', '《', '》')
-        return len(word) == 1 \
-            and ('\u4e00' <= word <= '\u9fa5' or word in chinese_punctuations)
-
-    def _remove_space_between_chinese_chars(self, decoded: str):
-        old_word_list = decoded.split(' ')
-        new_word_list = []
-        start = -1
-        for i, word in enumerate(old_word_list):
-            if self._is_chinese_char(word):
-                if start == -1:
-                    start = i
-            else:
-                if start != -1:
-                    new_word_list.append(''.join(old_word_list[start:i]))
-                    start = -1
-                new_word_list.append(word)
-        if start != -1:
-            new_word_list.append(''.join(old_word_list[start:]))
-        return ' '.join(new_word_list)
-
     def decode(self, inputs) -> str:
         tokenizer = self.preprocessor.tokenizer
         return tokenizer.decode(inputs.tolist(), skip_special_tokens=True)
@@ -128,5 +107,5 @@ class TextGenerationPipeline(Pipeline):
         if isinstance(inputs, list) or len(inputs.shape) > 1:
             inputs = inputs[0]
         decoded = getattr(self, self.postprocessor)(inputs)
-        text = self._remove_space_between_chinese_chars(decoded)
+        text = remove_space_between_chinese_chars(decoded)
         return {OutputKeys.TEXT: text}
@@ -6,11 +6,15 @@ from modelscope.utils.import_utils import LazyImportModule

 if TYPE_CHECKING:
     from .clip import CLIPTrainer
     from .team import TEAMImgClsTrainer
+    from .ofa import OFATrainer
+    from .mplug import MPlugTrainer

 else:
     _import_structure = {
         'clip': ['CLIPTrainer'],
-        'team': ['TEAMImgClsTrainer']
+        'team': ['TEAMImgClsTrainer'],
+        'ofa': ['OFATrainer'],
+        'mplug': ['MPlugTrainer'],
     }

     import sys
@@ -0,0 +1,3 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from .mplug_trainer import MPlugTrainer
@@ -0,0 +1,40 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from collections.abc import Mapping
+
+import torch
+
+from modelscope.metainfo import Trainers
+from modelscope.outputs import OutputKeys
+from modelscope.trainers import NlpEpochBasedTrainer
+from modelscope.trainers.builder import TRAINERS
+from modelscope.utils.file_utils import func_receive_dict_inputs
+
+
+@TRAINERS.register_module(module_name=Trainers.mplug)
+class MPlugTrainer(NlpEpochBasedTrainer):
+
+    def _decode(self, tokens):
+        tokenizer = self.eval_preprocessor.tokenizer
+        return tokenizer.decode(tokens, skip_special_tokens=True)
+
+    def evaluation_step(self, data):
+        model = self.model.module if self._dist else self.model
+        model.eval()
+
+        with torch.no_grad():
+            if isinstance(
+                    data,
+                    Mapping) and not func_receive_dict_inputs(model.forward):
+                result = model.forward(**data)
+            else:
+                result = model.forward(data)
+
+        result[OutputKeys.TEXT] = [
+            self._decode(seq) for seq in result['sequences']
+        ]
+        data[OutputKeys.LABELS] = [
+            self._decode(seq) for seq in data['answer_input_ids']
+        ]
+
+        return result
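Review note: with this registration, `build_trainer(name=Trainers.mplug, ...)` replaces the `nlp_base_trainer` + `cfg_modify_fn` combination the tests used before; the model now returns raw `sequences` in eval and the trainer decodes them here. A usage sketch mirroring the updated tests below; `train_ds`/`eval_ds` stand in for the remapped MsDataset splits built in the test's setUp, and the model id is a placeholder, not taken from this diff:

```python
from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer

kwargs = dict(
    model='damo/mplug_image-captioning_coco_base_zh',  # placeholder id
    train_dataset=train_ds,  # MsDataset splits, as in the tests below
    eval_dataset=eval_ds,
    max_epochs=1,
    work_dir='./mplug_work_dir')
trainer = build_trainer(name=Trainers.mplug, default_args=kwargs)
trainer.train()
```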
@@ -0,0 +1,35 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+
+def is_chinese_char(word: str):
+    chinese_punctuations = {
+        ',', '。', ';', ':', '!', '?', '《', '》', '‘', '’', '“', '”',
+        '(', ')', '【', '】'
+    }
+    return len(word) == 1 \
+        and ('\u4e00' <= word <= '\u9fa5' or word in chinese_punctuations)
+
+
+def remove_space_between_chinese_chars(decoded_str: str):
+    old_word_list = decoded_str.split(' ')
+    new_word_list = []
+    start = -1
+    for i, word in enumerate(old_word_list):
+        if is_chinese_char(word):
+            if start == -1:
+                start = i
+        else:
+            if start != -1:
+                new_word_list.append(''.join(old_word_list[start:i]))
+                start = -1
+            new_word_list.append(word)
+    if start != -1:
+        new_word_list.append(''.join(old_word_list[start:]))
+    return ' '.join(new_word_list).strip()
+
+
+# add space for each chinese char
+def rebuild_chinese_str(string: str):
+    return ' '.join(''.join([
+        f' {char} ' if is_chinese_char(char) else char for char in string
+    ]).split())
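Review note: the two helpers are rough inverses — `rebuild_chinese_str` pads every CJK char with spaces (so ROUGE can score character-level tokens), while `remove_space_between_chinese_chars` glues tokenizer-spaced CJK output back together without touching non-Chinese words. Expected behavior, traced from the code above:

```python
from modelscope.utils.chinese_utils import (
    rebuild_chinese_str, remove_space_between_chinese_chars)

# Spaces are removed only between runs of Chinese chars/punctuation;
# the Latin word survives untouched.
assert remove_space_between_chinese_chars('你 好 , world') == '你好, world'

# Every Chinese char becomes its own space-separated token.
assert rebuild_chinese_str('你好,world') == '你 好 , world'
```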
@@ -20,10 +20,7 @@ class TestFinetuneMPlug(unittest.TestCase):
         self.tmp_dir = tempfile.TemporaryDirectory().name
         if not os.path.exists(self.tmp_dir):
             os.makedirs(self.tmp_dir)
-        from modelscope.utils.constant import DownloadMode
-        datadict = MsDataset.load(
-            'coco_captions_small_slice',
-            download_mode=DownloadMode.FORCE_REDOWNLOAD)
+        datadict = MsDataset.load('coco_captions_small_slice')
         self.train_dataset = MsDataset(
             datadict['train'].remap_columns({
                 'image:FILE': 'image',
@@ -40,18 +37,6 @@ class TestFinetuneMPlug(unittest.TestCase):
             shutil.rmtree(self.tmp_dir)
         super().tearDown()

-    def _cfg_modify_fn(self, cfg):
-        cfg.train.hooks = [{
-            'type': 'CheckpointHook',
-            'interval': self.max_epochs
-        }, {
-            'type': 'TextLoggerHook',
-            'interval': 1
-        }, {
-            'type': 'IterTimerHook'
-        }]
-        return cfg
-
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_trainer_with_caption(self):
         kwargs = dict(
@@ -59,11 +44,10 @@ class TestFinetuneMPlug(unittest.TestCase):
             train_dataset=self.train_dataset,
             eval_dataset=self.test_dataset,
             max_epochs=self.max_epochs,
-            work_dir=self.tmp_dir,
-            cfg_modify_fn=self._cfg_modify_fn)
+            work_dir=self.tmp_dir)
         trainer: EpochBasedTrainer = build_trainer(
-            name=Trainers.nlp_base_trainer, default_args=kwargs)
+            name=Trainers.mplug, default_args=kwargs)
         trainer.train()

     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
@@ -80,7 +64,7 @@ class TestFinetuneMPlug(unittest.TestCase):
             work_dir=self.tmp_dir)
         trainer: EpochBasedTrainer = build_trainer(
-            name=Trainers.nlp_base_trainer, default_args=kwargs)
+            name=Trainers.mplug, default_args=kwargs)
         trainer.train()
         results_files = os.listdir(self.tmp_dir)
         self.assertIn(f'{trainer.timestamp}.log.json', results_files)
@@ -94,11 +78,10 @@ class TestFinetuneMPlug(unittest.TestCase):
             train_dataset=self.train_dataset,
             eval_dataset=self.test_dataset,
             max_epochs=self.max_epochs,
-            work_dir=self.tmp_dir,
-            cfg_modify_fn=self._cfg_modify_fn)
+            work_dir=self.tmp_dir)
         trainer: EpochBasedTrainer = build_trainer(
-            name=Trainers.nlp_base_trainer, default_args=kwargs)
+            name=Trainers.mplug, default_args=kwargs)
         trainer.train()

     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
@@ -115,7 +98,7 @@ class TestFinetuneMPlug(unittest.TestCase):
             work_dir=self.tmp_dir)
         trainer: EpochBasedTrainer = build_trainer(
-            name=Trainers.nlp_base_trainer, default_args=kwargs)
+            name=Trainers.mplug, default_args=kwargs)
         trainer.train()
         results_files = os.listdir(self.tmp_dir)
         self.assertIn(f'{trainer.timestamp}.log.json', results_files)
@@ -129,11 +112,10 @@ class TestFinetuneMPlug(unittest.TestCase):
             train_dataset=self.train_dataset,
             eval_dataset=self.test_dataset,
             max_epochs=self.max_epochs,
-            work_dir=self.tmp_dir,
-            cfg_modify_fn=self._cfg_modify_fn)
+            work_dir=self.tmp_dir)
         trainer: EpochBasedTrainer = build_trainer(
-            name=Trainers.nlp_base_trainer, default_args=kwargs)
+            name=Trainers.mplug, default_args=kwargs)
         trainer.train()

     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
@@ -150,7 +132,7 @@ class TestFinetuneMPlug(unittest.TestCase):
             work_dir=self.tmp_dir)
         trainer: EpochBasedTrainer = build_trainer(
-            name=Trainers.nlp_base_trainer, default_args=kwargs)
+            name=Trainers.mplug, default_args=kwargs)
         trainer.train()
         results_files = os.listdir(self.tmp_dir)
         self.assertIn(f'{trainer.timestamp}.log.json', results_files)
@@ -41,7 +41,7 @@ class AstScaningTest(unittest.TestCase):
         self.assertIsInstance(from_imports, dict)
         self.assertIsInstance(decorators, list)
         self.assertListEqual(list(set(imports.keys()) - set(['torch'])), [])
-        self.assertEqual(len(from_imports.keys()), 9)
+        self.assertEqual(len(from_imports.keys()), 10)
         self.assertTrue(from_imports['modelscope.metainfo'] is not None)
         self.assertEqual(from_imports['modelscope.metainfo'], ['Pipelines'])
         self.assertEqual(decorators,