From 0b7b964226923a748dbffbfbc731a1f94a6fb1e1 Mon Sep 17 00:00:00 2001
From: "hemu.zp"
Date: Wed, 20 Jul 2022 16:36:11 +0800
Subject: [PATCH] [to #42322933] Add palm finetuning
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Support finetuning for the Palm model.
---
 modelscope/metainfo.py                         |  2 +
 modelscope/metrics/__init__.py                 |  1 +
 modelscope/metrics/builder.py                  |  1 +
 modelscope/metrics/text_generation_metric.py   | 34 +++++++
 .../models/nlp/palm_for_text_generation.py     | 52 ++++++++---
 modelscope/preprocessors/nlp.py                | 44 ++++++++-
 modelscope/trainers/trainer.py                 |  3 +-
 modelscope/trainers/utils/inference.py         | 11 ++-
 requirements/nlp.txt                           |  1 +
 .../trainers/test_text_generation_trainer.py   | 87 +++++++++++++++++++
 10 files changed, 218 insertions(+), 18 deletions(-)
 create mode 100644 modelscope/metrics/text_generation_metric.py
 create mode 100644 tests/trainers/test_text_generation_trainer.py

diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py
index f4f100dd..33d62084 100644
--- a/modelscope/metainfo.py
+++ b/modelscope/metainfo.py
@@ -141,3 +141,5 @@ class Metrics(object):
     seq_cls_metric = 'seq_cls_metric'
     # metrics for token-classification task
     token_cls_metric = 'token-cls-metric'
+    # metrics for text-generation task
+    text_gen_metric = 'text-gen-metric'
diff --git a/modelscope/metrics/__init__.py b/modelscope/metrics/__init__.py
index 681c65c4..9a0ca94a 100644
--- a/modelscope/metrics/__init__.py
+++ b/modelscope/metrics/__init__.py
@@ -1,3 +1,4 @@
 from .base import Metric
 from .builder import METRICS, build_metric, task_default_metrics
 from .sequence_classification_metric import SequenceClassificationMetric
+from .text_generation_metric import TextGenerationMetric
diff --git a/modelscope/metrics/builder.py b/modelscope/metrics/builder.py
index 03657b89..860a3295 100644
--- a/modelscope/metrics/builder.py
+++ b/modelscope/metrics/builder.py
@@ -17,6 +17,7 @@ class MetricKeys(object):
 
 
 task_default_metrics = {
     Tasks.sentence_similarity: [Metrics.seq_cls_metric],
+    Tasks.text_generation: [Metrics.text_gen_metric],
 }
diff --git a/modelscope/metrics/text_generation_metric.py b/modelscope/metrics/text_generation_metric.py
new file mode 100644
index 00000000..ae61d225
--- /dev/null
+++ b/modelscope/metrics/text_generation_metric.py
@@ -0,0 +1,34 @@
+from typing import Dict
+
+import numpy as np
+from rouge_score import rouge_scorer
+
+from ..metainfo import Metrics
+from ..utils.registry import default_group
+from .base import Metric
+from .builder import METRICS, MetricKeys
+
+
+@METRICS.register_module(
+    group_key=default_group, module_name=Metrics.text_gen_metric)
+class TextGenerationMetric(Metric):
+    """The metric computation class for text generation tasks.
+ """ + + def __init__(self): + self.preds = [] + self.tgts = [] + self.scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True) + + def add(self, outputs: Dict, inputs: Dict): + ground_truths = outputs['tgts'] + eval_results = outputs['preds'] + self.preds.extend(eval_results) + self.tgts.extend(ground_truths) + + def evaluate(self): + scores = [ + self.scorer.score(pred, tgt)['rougeL'].fmeasure + for pred, tgt in zip(self.preds, self.tgts) + ] + return {MetricKeys.F1: sum(scores) / len(scores)} diff --git a/modelscope/models/nlp/palm_for_text_generation.py b/modelscope/models/nlp/palm_for_text_generation.py index f6c15387..1d5de894 100644 --- a/modelscope/models/nlp/palm_for_text_generation.py +++ b/modelscope/models/nlp/palm_for_text_generation.py @@ -2,14 +2,15 @@ from typing import Dict from ...metainfo import Models from ...utils.constant import Tasks -from ..base import Model, Tensor +from ..base import Tensor +from ..base_torch import TorchModel from ..builder import MODELS __all__ = ['PalmForTextGeneration'] @MODELS.register_module(Tasks.text_generation, module_name=Models.palm) -class PalmForTextGeneration(Model): +class PalmForTextGeneration(TorchModel): def __init__(self, model_dir: str, *args, **kwargs): """initialize the text generation model from the `model_dir` path. @@ -22,15 +23,42 @@ class PalmForTextGeneration(Model): super().__init__(model_dir, *args, **kwargs) from sofa.models.palm_v2 import PalmForConditionalGeneration, Translator - model = PalmForConditionalGeneration.from_pretrained(model_dir) - self.tokenizer = model.tokenizer - self.generator = Translator(model) + self.model = PalmForConditionalGeneration.from_pretrained(model_dir) + self.tokenizer = self.model.tokenizer + self.generator = Translator(self.model) - def train(self): - return self.generator.train() + def _evaluate_postprocess(self, src: Tensor, tgt: Tensor, + mask_src: Tensor) -> Dict[str, str]: + replace_tokens_bert = (('[unused0]', ''), ('[PAD]', ''), + ('[unused1]', ''), (r' +', ' '), ('[SEP]', ''), + ('[unused2]', ''), ('[CLS]', ''), ('[UNK]', '')) + replace_tokens_roberta = ((r' +', ' '), ('', ''), ('', + ''), + ('', ''), ('', ''), ('', ' ')) - def eval(self): - return self.generator.eval() + inputs = self.generator(src, mask_src) + pred_list = inputs['predictions'] + pred_id_list = [ + pred_batch[0].cpu().numpy().tolist() for pred_batch in pred_list + ] + tgt_id_list = tgt.cpu().numpy().tolist() + pred_strings = [ + self.tokenizer.decode(pred_ids) for pred_ids in pred_id_list + ] + tgt_strings = [ + self.tokenizer.decode(tgt_ids) for tgt_ids in tgt_id_list + ] + for _old, _new in replace_tokens_bert: + pred_strings = [s.replace(_old, _new) for s in pred_strings] + tgt_strings = [s.replace(_old, _new) for s in tgt_strings] + for _old, _new in replace_tokens_roberta: + pred_strings = [s.replace(_old, _new) for s in pred_strings] + tgt_strings = [s.replace(_old, _new) for s in tgt_strings] + for s in pred_strings: + s.strip() + for s in tgt_strings: + s.strip() + return {'preds': pred_strings, 'tgts': tgt_strings} def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: """return the result by the model @@ -45,5 +73,9 @@ class PalmForTextGeneration(Model): 'predictions': Tensor([[1377, 4959, 2785, 6392...])]), # tokens need to be decode by tokenizer } """ - - return self.generator(**input) + if self.training: + return {'loss': self.model(**input)} + elif 'tgt' in input: + return self._evaluate_postprocess(**input) + else: + return self.generator(**input) diff --git 
index c4c3aa71..910aed6a 100644
--- a/modelscope/preprocessors/nlp.py
+++ b/modelscope/preprocessors/nlp.py
@@ -216,8 +216,9 @@ class SentenceSimilarityFinetunePreprocessor(SentenceSimilarityPreprocessor):
     Fields.nlp, module_name=Preprocessors.palm_text_gen_tokenizer)
 class TextGenerationPreprocessor(NLPPreprocessorBase):
 
-    def __init__(self, model_dir: str, tokenizer, *args, **kwargs):
-        self.tokenizer = tokenizer
+    def __init__(self, model_dir: str, tokenizer=None, *args, **kwargs):
+        self.tokenizer = self.build_tokenizer(
+            model_dir) if tokenizer is None else tokenizer
         kwargs['truncation'] = True
         kwargs['padding'] = 'max_length'
         kwargs['return_tensors'] = 'pt'
@@ -225,8 +226,43 @@ class TextGenerationPreprocessor(NLPPreprocessorBase):
         kwargs['max_length'] = kwargs.pop('sequence_length', 128)
         super().__init__(model_dir, *args, **kwargs)
 
-    def build_tokenizer(self, model_dir):
-        return self.tokenizer
+    def build_tokenizer(self, model_dir: str):
+        import os
+        from sofa.models.palm_v2 import PalmConfig
+
+        config_file = os.path.join(model_dir, 'config.json')
+        config = PalmConfig.from_json_file(config_file) if os.path.isfile(
+            config_file) else PalmConfig()
+        config.encoder_pth = os.path.join(model_dir, config.encoder_pth)
+        if config.encoder == 'roberta':
+            from transformers import RobertaTokenizer
+            tokenizer = RobertaTokenizer.from_pretrained(
+                config.encoder_pth, do_lower_case=False)
+        elif config.encoder == 'bert' or config.encoder == 'zh_bert':
+            from transformers import BertTokenizer
+            tokenizer = BertTokenizer.from_pretrained(
+                config.encoder_pth, do_lower_case=True)
+        return tokenizer
+
+
+@PREPROCESSORS.register_module(
+    Fields.nlp, module_name='palm-text-gen-tokenizer-finetune')
+class TextGenerationFinetunePreprocessor(TextGenerationPreprocessor):
+
+    @type_assert(object, dict)
+    def __call__(self, data: dict) -> Dict[str, Any]:
+        src_txt = data['src_txt']
+        tgt_txt = data['tgt_txt']
+        src_rst = super().__call__(src_txt)
+        tgt_rst = super().__call__(tgt_txt)
+        src_rst = {k: v.squeeze() for k, v in src_rst.items()}
+        tgt_rst = {k: v.squeeze() for k, v in tgt_rst.items()}
+
+        return {
+            'src': src_rst['input_ids'],
+            'tgt': tgt_rst['input_ids'],
+            'mask_src': src_rst['attention_mask']
+        }
 
 
 @PREPROCESSORS.register_module(Fields.nlp)
diff --git a/modelscope/trainers/trainer.py b/modelscope/trainers/trainer.py
index 6249c82d..6a178104 100644
--- a/modelscope/trainers/trainer.py
+++ b/modelscope/trainers/trainer.py
@@ -297,6 +297,7 @@ class EpochBasedTrainer(BaseTrainer):
             model = Model.from_pretrained(self.model_dir)
         if not isinstance(model, nn.Module) and hasattr(model, 'model'):
             return model.model
+        return model
 
     def collate_fn(self, data):
         """Prepare the input just before the forward function.
@@ -339,7 +340,7 @@ class EpochBasedTrainer(BaseTrainer):
         model.train()
         self._mode = ModeKeys.TRAIN
         inputs = self.collate_fn(inputs)
-        if isinstance(inputs, dict):
+        if not isinstance(model, Model) and isinstance(inputs, dict):
             train_outputs = model.forward(**inputs)
         else:
             train_outputs = model.forward(inputs)
diff --git a/modelscope/trainers/utils/inference.py b/modelscope/trainers/utils/inference.py
index 4a455b5e..f056fb08 100644
--- a/modelscope/trainers/utils/inference.py
+++ b/modelscope/trainers/utils/inference.py
@@ -10,6 +10,7 @@ import torch
 from torch import distributed as dist
 from tqdm import tqdm
 
+from modelscope.models.base import Model
 from modelscope.utils.torch_utils import get_dist_info
 
 
@@ -35,7 +36,10 @@ def single_gpu_test(model,
         if data_collate_fn is not None:
             data = data_collate_fn(data)
         with torch.no_grad():
-            result = model(**data)
+            if not isinstance(model, Model):
+                result = model(**data)
+            else:
+                result = model(data)
         if metric_classes is not None:
             for metric_cls in metric_classes:
                 metric_cls.add(result, data)
@@ -83,7 +87,10 @@ def multi_gpu_test(model,
         if data_collate_fn is not None:
             data = data_collate_fn(data)
         with torch.no_grad():
-            result = model(**data)
+            if not isinstance(model, Model):
+                result = model(**data)
+            else:
+                result = model(data)
         results.extend(result)
 
     rank, world_size = get_dist_info()
diff --git a/requirements/nlp.txt b/requirements/nlp.txt
index 0c7f1b59..827bd512 100644
--- a/requirements/nlp.txt
+++ b/requirements/nlp.txt
@@ -1,5 +1,6 @@
 http://ait-public.oss-cn-hangzhou-zmf.aliyuncs.com/jizhu/en_core_web_sm-2.3.1.tar.gz
 pai-easynlp
+rouge_score
 sofa>=1.0.5
 spacy>=2.3.5
diff --git a/tests/trainers/test_text_generation_trainer.py b/tests/trainers/test_text_generation_trainer.py
new file mode 100644
index 00000000..28f08c97
--- /dev/null
+++ b/tests/trainers/test_text_generation_trainer.py
@@ -0,0 +1,87 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+import shutil
+import tempfile
+import unittest
+
+from modelscope.hub.snapshot_download import snapshot_download
+from modelscope.models.nlp.palm_for_text_generation import \
+    PalmForTextGeneration
+from modelscope.msdatasets import MsDataset
+from modelscope.trainers import build_trainer
+from modelscope.utils.constant import ModelFile
+from modelscope.utils.test_utils import test_level
+
+
+class TestTextGenerationTrainer(unittest.TestCase):
+
+    def setUp(self):
+        print('Testing %s.%s' % (type(self).__name__, self._testMethodName))
+        self.tmp_dir = tempfile.TemporaryDirectory().name
+        if not os.path.exists(self.tmp_dir):
+            os.makedirs(self.tmp_dir)
+
+        from datasets import Dataset
+
+        self.model_id = 'damo/nlp_palm2.0_text-generation_english-base'
+
+        dataset_dict = {
+            'src_txt': [
+                'This is test sentence1-1', 'This is test sentence2-1',
+                'This is test sentence3-1'
+            ],
+            'tgt_txt': [
+                'This is test sentence1-2', 'This is test sentence2-2',
+                'This is test sentence3-2'
+            ]
+        }
+        dataset = Dataset.from_dict(dataset_dict)
+
+        class MsDatasetDummy(MsDataset):
+
+            def __len__(self):
+                return len(self._hf_ds)
+
+        self.dataset = MsDatasetDummy(dataset)
+
+    def tearDown(self):
+        shutil.rmtree(self.tmp_dir)
+        super().tearDown()
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_trainer(self):
+        kwargs = dict(
+            model=self.model_id,
+            train_dataset=self.dataset,
+            eval_dataset=self.dataset,
+            work_dir=self.tmp_dir)
+
+        trainer = build_trainer(default_args=kwargs)
+        trainer.train()
+        results_files = os.listdir(self.tmp_dir)
+        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
+        for i in range(3):
+            self.assertIn(f'epoch_{i+1}.pth', results_files)
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_trainer_with_model_and_args(self):
+        cache_path = snapshot_download(self.model_id)
+        model = PalmForTextGeneration.from_pretrained(cache_path)
+        kwargs = dict(
+            cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION),
+            model=model,
+            train_dataset=self.dataset,
+            eval_dataset=self.dataset,
+            max_epochs=2,
+            work_dir=self.tmp_dir)
+
+        trainer = build_trainer(default_args=kwargs)
+        trainer.train()
+        results_files = os.listdir(self.tmp_dir)
+        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
+        for i in range(2):
+            self.assertIn(f'epoch_{i+1}.pth', results_files)
+
+
+if __name__ == '__main__':
+    unittest.main()
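
Usage note (not part of the patch): the snippet below strings the new pieces together the same way tests/trainers/test_text_generation_trainer.py does. It is a minimal sketch under the test's own assumptions: the damo/nlp_palm2.0_text-generation_english-base model id, the 'src_txt'/'tgt_txt' columns consumed by TextGenerationFinetunePreprocessor, and the same MsDataset subclass workaround for the missing __len__. The MsDatasetWithLen name is introduced here for illustration only.

    # Minimal finetuning sketch; assumes a ModelScope checkout with this
    # patch applied and hub access to download the pretrained model.
    import os
    import tempfile

    from datasets import Dataset

    from modelscope.msdatasets import MsDataset
    from modelscope.trainers import build_trainer


    class MsDatasetWithLen(MsDataset):
        """Hypothetical helper mirroring the test's MsDatasetDummy."""

        def __len__(self):
            return len(self._hf_ds)


    # TextGenerationFinetunePreprocessor reads 'src_txt'/'tgt_txt' and feeds
    # the model the 'src', 'tgt' and 'mask_src' tensors it trains on.
    dataset = MsDatasetWithLen(
        Dataset.from_dict({
            'src_txt': ['An example source sentence.'],
            'tgt_txt': ['An example target sentence.'],
        }))

    work_dir = tempfile.mkdtemp()
    trainer = build_trainer(
        default_args=dict(
            model='damo/nlp_palm2.0_text-generation_english-base',
            train_dataset=dataset,
            eval_dataset=dataset,
            work_dir=work_dir))

    # Evaluation during training resolves TextGenerationMetric through the
    # new Tasks.text_generation entry in task_default_metrics.
    trainer.train()
    print(os.listdir(work_dir))  # expect epoch_<n>.pth checkpoints and a *.log.json

The metric can also be exercised on its own; it averages the ROUGE-L F-measure over prediction/target pairs, and add() ignores its inputs argument:

    from modelscope.metrics import TextGenerationMetric

    metric = TextGenerationMetric()
    metric.add({'preds': ['the cat sat'], 'tgts': ['the cat sat down']}, inputs={})
    print(metric.evaluate())  # dict keyed by MetricKeys.F1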