@@ -141,3 +141,5 @@ class Metrics(object):
     seq_cls_metric = 'seq_cls_metric'
     # metrics for token-classification task
     token_cls_metric = 'token-cls-metric'
+    # metrics for text-generation task
+    text_gen_metric = 'text-gen-metric'
@@ -1,3 +1,4 @@
 from .base import Metric
 from .builder import METRICS, build_metric, task_default_metrics
 from .sequence_classification_metric import SequenceClassificationMetric
+from .text_generation_metric import TextGenerationMetric
@@ -17,6 +17,7 @@ class MetricKeys(object):
 task_default_metrics = {
     Tasks.sentence_similarity: [Metrics.seq_cls_metric],
+    Tasks.text_generation: [Metrics.text_gen_metric],
 }
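For orientation, this mapping is what lets the trainer pick an evaluation metric when none is configured: looking up a task yields the default metric names to build. A hedged sketch (the absolute import paths are assumptions for illustration; only the lookup itself is taken from the hunk above):

    # Assumed import paths; the mapping lookup mirrors the diff above.
    from modelscope.metrics.builder import task_default_metrics
    from modelscope.utils.constant import Tasks

    metric_names = task_default_metrics[Tasks.text_generation]
    print(metric_names)  # expected: ['text-gen-metric']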
@@ -0,0 +1,34 @@
+from typing import Dict
+
+import numpy as np
+from rouge_score import rouge_scorer
+
+from ..metainfo import Metrics
+from ..utils.registry import default_group
+from .base import Metric
+from .builder import METRICS, MetricKeys
+
+
+@METRICS.register_module(
+    group_key=default_group, module_name=Metrics.text_gen_metric)
+class TextGenerationMetric(Metric):
+    """The metric computation class for text-generation tasks.
+    """
+
+    def __init__(self):
+        self.preds = []
+        self.tgts = []
+        self.scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
+
+    def add(self, outputs: Dict, inputs: Dict):
+        ground_truths = outputs['tgts']
+        eval_results = outputs['preds']
+        self.preds.extend(eval_results)
+        self.tgts.extend(ground_truths)
+
+    def evaluate(self):
+        scores = [
+            self.scorer.score(pred, tgt)['rougeL'].fmeasure
+            for pred, tgt in zip(self.preds, self.tgts)
+        ]
+        return {MetricKeys.F1: sum(scores) / len(scores)}
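To make the add/evaluate contract concrete, here is a minimal standalone sketch of the same ROUGE-L reduction the class performs; the sample sentences are invented for illustration, and the (pred, tgt) argument order mirrors the class above:

    from rouge_score import rouge_scorer

    scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
    preds = ['the cat sat on the mat', 'a fast brown fox']
    tgts = ['the cat is on the mat', 'the quick brown fox']

    # Same reduction as evaluate(): mean ROUGE-L F-measure over all pairs.
    scores = [
        scorer.score(pred, tgt)['rougeL'].fmeasure
        for pred, tgt in zip(preds, tgts)
    ]
    print(sum(scores) / len(scores))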
@@ -2,14 +2,15 @@ from typing import Dict

 from ...metainfo import Models
 from ...utils.constant import Tasks
-from ..base import Model, Tensor
+from ..base import Tensor
+from ..base_torch import TorchModel
 from ..builder import MODELS

 __all__ = ['PalmForTextGeneration']


 @MODELS.register_module(Tasks.text_generation, module_name=Models.palm)
-class PalmForTextGeneration(Model):
+class PalmForTextGeneration(TorchModel):

     def __init__(self, model_dir: str, *args, **kwargs):
         """initialize the text generation model from the `model_dir` path.
@@ -22,15 +23,42 @@ class PalmForTextGeneration(Model):
         super().__init__(model_dir, *args, **kwargs)

         from sofa.models.palm_v2 import PalmForConditionalGeneration, Translator
-        model = PalmForConditionalGeneration.from_pretrained(model_dir)
-        self.tokenizer = model.tokenizer
-        self.generator = Translator(model)
+        self.model = PalmForConditionalGeneration.from_pretrained(model_dir)
+        self.tokenizer = self.model.tokenizer
+        self.generator = Translator(self.model)

-    def train(self):
-        return self.generator.train()
+    def _evaluate_postprocess(self, src: Tensor, tgt: Tensor,
+                              mask_src: Tensor) -> Dict[str, str]:
+        replace_tokens_bert = (('[unused0]', ''), ('[PAD]', ''),
+                               ('[unused1]', ''), (r' +', ' '), ('[SEP]', ''),
+                               ('[unused2]', ''), ('[CLS]', ''), ('[UNK]', ''))
+        replace_tokens_roberta = ((r' +', ' '), ('<mask>', '<q>'),
+                                  ('<pad>', ''), ('<s>', ''), ('</s>', ''),
+                                  ('<unk>', ' '))

-    def eval(self):
-        return self.generator.eval()
+        inputs = self.generator(src, mask_src)
+        pred_list = inputs['predictions']
+        pred_id_list = [
+            pred_batch[0].cpu().numpy().tolist() for pred_batch in pred_list
+        ]
+        tgt_id_list = tgt.cpu().numpy().tolist()
+        pred_strings = [
+            self.tokenizer.decode(pred_ids) for pred_ids in pred_id_list
+        ]
+        tgt_strings = [
+            self.tokenizer.decode(tgt_ids) for tgt_ids in tgt_id_list
+        ]
+        for _old, _new in replace_tokens_bert:
+            pred_strings = [s.replace(_old, _new) for s in pred_strings]
+            tgt_strings = [s.replace(_old, _new) for s in tgt_strings]
+        for _old, _new in replace_tokens_roberta:
+            pred_strings = [s.replace(_old, _new) for s in pred_strings]
+            tgt_strings = [s.replace(_old, _new) for s in tgt_strings]
+        # str.strip() returns a new string, so the results must be reassigned
+        pred_strings = [s.strip() for s in pred_strings]
+        tgt_strings = [s.strip() for s in tgt_strings]
+        return {'preds': pred_strings, 'tgts': tgt_strings}

     def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
         """return the result by the model
@@ -45,5 +73,9 @@ class PalmForTextGeneration(Model):
                     'predictions': Tensor([[1377, 4959, 2785, 6392...]]),  # tokens need to be decoded by the tokenizer
                 }
         """
-        return self.generator(**input)
+        if self.training:
+            return {'loss': self.model(**input)}
+        elif 'tgt' in input:
+            return self._evaluate_postprocess(**input)
+        else:
+            return self.generator(**input)
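Because the class now subclasses TorchModel (an nn.Module), self.training is toggled by the inherited train()/eval(), which is what makes the hand-written train/eval overrides above removable and drives the three-way dispatch in forward(). A toy sketch of the pattern (the class and return values are stand-ins, not the real PALM model):

    import torch
    from torch import nn

    # Stand-in module illustrating the mode-dependent dispatch only.
    class ToyGeneration(nn.Module):

        def forward(self, input):
            if self.training:             # training: a loss dict
                return {'loss': torch.tensor(0.0)}
            elif 'tgt' in input:          # evaluation: decoded strings
                return {'preds': ['pred text'], 'tgts': ['tgt text']}
            else:                         # inference: raw token ids
                return {'predictions': torch.tensor([[1377, 4959]])}

    m = ToyGeneration()
    m.train()
    print(m({'src': None}))               # -> {'loss': tensor(0.)}
    m.eval()
    print(m({'src': None, 'tgt': None}))  # -> {'preds': [...], 'tgts': [...]}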
@@ -216,8 +216,9 @@ class SentenceSimilarityFinetunePreprocessor(SentenceSimilarityPreprocessor):
     Fields.nlp, module_name=Preprocessors.palm_text_gen_tokenizer)
 class TextGenerationPreprocessor(NLPPreprocessorBase):

-    def __init__(self, model_dir: str, tokenizer, *args, **kwargs):
-        self.tokenizer = tokenizer
+    def __init__(self, model_dir: str, tokenizer=None, *args, **kwargs):
+        self.tokenizer = self.build_tokenizer(
+            model_dir) if tokenizer is None else tokenizer
         kwargs['truncation'] = True
         kwargs['padding'] = 'max_length'
         kwargs['return_tensors'] = 'pt'
@@ -225,8 +226,43 @@ class TextGenerationPreprocessor(NLPPreprocessorBase):
         kwargs['max_length'] = kwargs.pop('sequence_length', 128)
         super().__init__(model_dir, *args, **kwargs)

-    def build_tokenizer(self, model_dir):
-        return self.tokenizer
+    def build_tokenizer(self, model_dir: str):
+        import os
+
+        from sofa.models.palm_v2 import PalmConfig
+
+        config_file = os.path.join(model_dir, 'config.json')
+        config = PalmConfig.from_json_file(config_file) if os.path.isfile(
+            config_file) else PalmConfig()
+        config.encoder_pth = os.path.join(model_dir, config.encoder_pth)
+        if config.encoder == 'roberta':
+            from transformers import RobertaTokenizer
+            tokenizer = RobertaTokenizer.from_pretrained(
+                config.encoder_pth, do_lower_case=False)
+        elif config.encoder == 'bert' or config.encoder == 'zh_bert':
+            from transformers import BertTokenizer
+            tokenizer = BertTokenizer.from_pretrained(
+                config.encoder_pth, do_lower_case=True)
+        else:
+            # guard against `tokenizer` being unbound for unknown encoders
+            raise ValueError(f'unsupported encoder: {config.encoder}')
+        return tokenizer
+
+
+@PREPROCESSORS.register_module(
+    Fields.nlp, module_name='palm-text-gen-tokenizer-finetune')
+class TextGenerationFinetunePreprocessor(TextGenerationPreprocessor):
+
+    @type_assert(object, dict)
+    def __call__(self, data: dict) -> Dict[str, Any]:
+        src_txt = data['src_txt']
+        tgt_txt = data['tgt_txt']
+        src_rst = super().__call__(src_txt)
+        tgt_rst = super().__call__(tgt_txt)
+        src_rst = {k: v.squeeze() for k, v in src_rst.items()}
+        tgt_rst = {k: v.squeeze() for k, v in tgt_rst.items()}
+        return {
+            'src': src_rst['input_ids'],
+            'tgt': tgt_rst['input_ids'],
+            'mask_src': src_rst['attention_mask']
+        }


 @PREPROCESSORS.register_module(Fields.nlp)
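For a picture of what the finetune preprocessor hands to the trainer, here is a hedged sketch using a stock HuggingFace tokenizer as a stand-in for the PALM encoder tokenizer built above (the 'bert-base-uncased' checkpoint is an assumption for illustration):

    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    # Mirrors the kwargs fixed in TextGenerationPreprocessor.__init__.
    def encode(text):
        return tokenizer(text, truncation=True, padding='max_length',
                         max_length=128, return_tensors='pt')

    data = {'src_txt': 'This is a source sentence',
            'tgt_txt': 'This is a target sentence'}
    src, tgt = encode(data['src_txt']), encode(data['tgt_txt'])

    batch = {
        'src': src['input_ids'].squeeze(),            # shape (128,)
        'tgt': tgt['input_ids'].squeeze(),            # shape (128,)
        'mask_src': src['attention_mask'].squeeze(),  # shape (128,)
    }
    print({k: v.shape for k, v in batch.items()})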
@@ -297,6 +297,7 @@ class EpochBasedTrainer(BaseTrainer):
         model = Model.from_pretrained(self.model_dir)
         if not isinstance(model, nn.Module) and hasattr(model, 'model'):
             return model.model
+        return model

     def collate_fn(self, data):
         """Prepare the input just before the forward function.
@@ -339,7 +340,7 @@ class EpochBasedTrainer(BaseTrainer):
         model.train()
         self._mode = ModeKeys.TRAIN
         inputs = self.collate_fn(inputs)
-        if isinstance(inputs, dict):
+        if not isinstance(model, Model) and isinstance(inputs, dict):
             train_outputs = model.forward(**inputs)
         else:
             train_outputs = model.forward(inputs)
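The new guard encodes two calling conventions: a ModelScope Model takes the whole input dict as a single positional argument, while a plain nn.Module forward expects keyword arguments; the same split appears in single_gpu_test/multi_gpu_test below. A toy illustration with stand-in classes (is_ms_model stands in for isinstance(model, Model)):

    from torch import nn

    class KwargsModule(nn.Module):        # plain nn.Module convention
        def forward(self, src, mask_src):
            return {'logits': src}

    class DictModule(nn.Module):          # ModelScope Model convention
        def forward(self, input):
            return {'logits': input['src']}

    def call(model, inputs, is_ms_model):
        # Mirrors the dispatch in train_step above.
        if not is_ms_model and isinstance(inputs, dict):
            return model(**inputs)
        return model(inputs)

    inputs = {'src': 1, 'mask_src': 1}
    print(call(KwargsModule(), inputs, False))  # called as model(**inputs)
    print(call(DictModule(), inputs, True))     # called as model(inputs)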
@@ -10,6 +10,7 @@ import torch
 from torch import distributed as dist
 from tqdm import tqdm

+from modelscope.models.base import Model
 from modelscope.utils.torch_utils import get_dist_info
@@ -35,7 +36,10 @@ def single_gpu_test(model,
         if data_collate_fn is not None:
             data = data_collate_fn(data)
         with torch.no_grad():
-            result = model(**data)
+            if not isinstance(model, Model):
+                result = model(**data)
+            else:
+                result = model(data)
         if metric_classes is not None:
             for metric_cls in metric_classes:
                 metric_cls.add(result, data)
@@ -83,7 +87,10 @@ def multi_gpu_test(model,
         if data_collate_fn is not None:
             data = data_collate_fn(data)
         with torch.no_grad():
-            result = model(**data)
+            if not isinstance(model, Model):
+                result = model(**data)
+            else:
+                result = model(data)
         results.extend(result)

     rank, world_size = get_dist_info()
@@ -1,5 +1,6 @@
 http://ait-public.oss-cn-hangzhou-zmf.aliyuncs.com/jizhu/en_core_web_sm-2.3.1.tar.gz
 pai-easynlp
+rouge_score
 sofa>=1.0.5
 spacy>=2.3.5
@@ -0,0 +1,91 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+import shutil
+import tempfile
+import unittest
+
+from modelscope.hub.snapshot_download import snapshot_download
+from modelscope.models.nlp.palm_for_text_generation import \
+    PalmForTextGeneration
+from modelscope.msdatasets import MsDataset
+from modelscope.trainers import build_trainer
+from modelscope.utils.constant import ModelFile
+from modelscope.utils.test_utils import test_level
+
+
+class TestTextGenerationTrainer(unittest.TestCase):
+
+    def setUp(self):
+        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
+        self.tmp_dir = tempfile.TemporaryDirectory().name
+        if not os.path.exists(self.tmp_dir):
+            os.makedirs(self.tmp_dir)
+
+        from datasets import Dataset
+        self.model_id = 'damo/nlp_palm2.0_text-generation_english-base'
+        dataset_dict = {
+            'src_txt': [
+                'This is test sentence1-1', 'This is test sentence2-1',
+                'This is test sentence3-1'
+            ],
+            'tgt_txt': [
+                'This is test sentence1-2', 'This is test sentence2-2',
+                'This is test sentence3-2'
+            ]
+        }
+        dataset = Dataset.from_dict(dataset_dict)
+
+        class MsDatasetDummy(MsDataset):
+
+            def __len__(self):
+                return len(self._hf_ds)
+
+        self.dataset = MsDatasetDummy(dataset)
+
+    def tearDown(self):
+        shutil.rmtree(self.tmp_dir)
+        super().tearDown()
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_trainer(self):
+        kwargs = dict(
+            model=self.model_id,
+            train_dataset=self.dataset,
+            eval_dataset=self.dataset,
+            work_dir=self.tmp_dir)
+
+        trainer = build_trainer(default_args=kwargs)
+        trainer.train()
+        results_files = os.listdir(self.tmp_dir)
+        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
+        for i in range(3):
+            self.assertIn(f'epoch_{i+1}.pth', results_files)
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_trainer_with_model_and_args(self):
+        tmp_dir = tempfile.TemporaryDirectory().name
+        if not os.path.exists(tmp_dir):
+            os.makedirs(tmp_dir)
+
+        cache_path = snapshot_download(self.model_id)
+        model = PalmForTextGeneration.from_pretrained(cache_path)
+        kwargs = dict(
+            cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION),
+            model=model,
+            train_dataset=self.dataset,
+            eval_dataset=self.dataset,
+            max_epochs=2,
+            work_dir=self.tmp_dir)
+
+        trainer = build_trainer(default_args=kwargs)
+        trainer.train()
+        results_files = os.listdir(self.tmp_dir)
+        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
+        for i in range(2):
+            self.assertIn(f'epoch_{i+1}.pth', results_files)
+
+
+if __name__ == '__main__':
+    unittest.main()