
[to #42322933] Add palm finetuning

Add finetuning support for the Palm model.
master
hemu.zp 3 years ago
parent commit 0b7b964226
10 changed files with 224 additions and 18 deletions
 1. modelscope/metainfo.py (+2, -0)
 2. modelscope/metrics/__init__.py (+1, -0)
 3. modelscope/metrics/builder.py (+1, -0)
 4. modelscope/metrics/text_generation_metric.py (+34, -0)
 5. modelscope/models/nlp/palm_for_text_generation.py (+43, -11)
 6. modelscope/preprocessors/nlp.py (+40, -4)
 7. modelscope/trainers/trainer.py (+2, -1)
 8. modelscope/trainers/utils/inference.py (+9, -2)
 9. requirements/nlp.txt (+1, -0)
10. tests/trainers/test_text_generation_trainer.py (+91, -0)

modelscope/metainfo.py (+2, -0)

@@ -141,3 +141,5 @@ class Metrics(object):
     seq_cls_metric = 'seq_cls_metric'
     # metrics for token-classification task
     token_cls_metric = 'token-cls-metric'
+    # metrics for text-generation task
+    text_gen_metric = 'text-gen-metric'

modelscope/metrics/__init__.py (+1, -0)

@@ -1,3 +1,4 @@
 from .base import Metric
 from .builder import METRICS, build_metric, task_default_metrics
 from .sequence_classification_metric import SequenceClassificationMetric
+from .text_generation_metric import TextGenerationMetric

modelscope/metrics/builder.py (+1, -0)

@@ -17,6 +17,7 @@ class MetricKeys(object):
 
 
 task_default_metrics = {
     Tasks.sentence_similarity: [Metrics.seq_cls_metric],
+    Tasks.text_generation: [Metrics.text_gen_metric],
 }
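With this default registered, the trainer can resolve the text-generation metric by task name alone. A minimal lookup sketch, assuming `build_metric` accepts the registered module name (its exact signature is not shown in this diff):

from modelscope.metrics import build_metric, task_default_metrics
from modelscope.utils.constant import Tasks

# Resolve the default metric name for the task, then build it via the registry.
metric_names = task_default_metrics[Tasks.text_generation]  # ['text-gen-metric']
metrics = [build_metric(name) for name in metric_names]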






modelscope/metrics/text_generation_metric.py (+34, -0)

@@ -0,0 +1,34 @@
+from typing import Dict
+
+import numpy as np
+from rouge_score import rouge_scorer
+
+from ..metainfo import Metrics
+from ..utils.registry import default_group
+from .base import Metric
+from .builder import METRICS, MetricKeys
+
+
+@METRICS.register_module(
+    group_key=default_group, module_name=Metrics.text_gen_metric)
+class TextGenerationMetric(Metric):
+    """The metric computation class for text generation tasks.
+    """
+
+    def __init__(self):
+        self.preds = []
+        self.tgts = []
+        self.scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
+
+    def add(self, outputs: Dict, inputs: Dict):
+        ground_truths = outputs['tgts']
+        eval_results = outputs['preds']
+        self.preds.extend(eval_results)
+        self.tgts.extend(ground_truths)
+
+    def evaluate(self):
+        scores = [
+            self.scorer.score(pred, tgt)['rougeL'].fmeasure
+            for pred, tgt in zip(self.preds, self.tgts)
+        ]
+        return {MetricKeys.F1: sum(scores) / len(scores)}
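For context (not part of the commit): a minimal sketch of how this metric accumulates batches during evaluation. The example strings are invented; real `preds`/`tgts` come from PalmForTextGeneration._evaluate_postprocess below.

from modelscope.metrics.text_generation_metric import TextGenerationMetric

metric = TextGenerationMetric()
# Each evaluation step hands the metric the model outputs; only the decoded
# 'preds'/'tgts' strings are read, so `inputs` can be empty here.
metric.add({'preds': ['the cat sat on the mat'],
            'tgts': ['a cat sat on the mat']}, inputs={})
# evaluate() averages the per-sample ROUGE-L F-measure over all added pairs.
print(metric.evaluate())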

modelscope/models/nlp/palm_for_text_generation.py (+43, -11)

@@ -2,14 +2,15 @@ from typing import Dict
 
 from ...metainfo import Models
 from ...utils.constant import Tasks
-from ..base import Model, Tensor
+from ..base import Tensor
+from ..base_torch import TorchModel
 from ..builder import MODELS
 
 __all__ = ['PalmForTextGeneration']
 
 
 @MODELS.register_module(Tasks.text_generation, module_name=Models.palm)
-class PalmForTextGeneration(Model):
+class PalmForTextGeneration(TorchModel):
 
     def __init__(self, model_dir: str, *args, **kwargs):
         """initialize the text generation model from the `model_dir` path.
@@ -22,15 +23,42 @@ class PalmForTextGeneration(Model):
         super().__init__(model_dir, *args, **kwargs)
 
         from sofa.models.palm_v2 import PalmForConditionalGeneration, Translator
-        model = PalmForConditionalGeneration.from_pretrained(model_dir)
-        self.tokenizer = model.tokenizer
-        self.generator = Translator(model)
+        self.model = PalmForConditionalGeneration.from_pretrained(model_dir)
+        self.tokenizer = self.model.tokenizer
+        self.generator = Translator(self.model)
 
-    def train(self):
-        return self.generator.train()
+    def _evaluate_postprocess(self, src: Tensor, tgt: Tensor,
+                              mask_src: Tensor) -> Dict[str, str]:
+        replace_tokens_bert = (('[unused0]', ''), ('[PAD]', ''),
+                               ('[unused1]', ''), (r' +', ' '), ('[SEP]', ''),
+                               ('[unused2]', ''), ('[CLS]', ''), ('[UNK]', ''))
+        replace_tokens_roberta = ((r' +', ' '), ('<mask>', '<q>'),
+                                  ('<pad>', ''), ('<s>', ''), ('</s>', ''),
+                                  ('<unk>', ' '))
 
-    def eval(self):
-        return self.generator.eval()
+        inputs = self.generator(src, mask_src)
+        pred_list = inputs['predictions']
+        pred_id_list = [
+            pred_batch[0].cpu().numpy().tolist() for pred_batch in pred_list
+        ]
+        tgt_id_list = tgt.cpu().numpy().tolist()
+        pred_strings = [
+            self.tokenizer.decode(pred_ids) for pred_ids in pred_id_list
+        ]
+        tgt_strings = [
+            self.tokenizer.decode(tgt_ids) for tgt_ids in tgt_id_list
+        ]
+        for _old, _new in replace_tokens_bert:
+            pred_strings = [s.replace(_old, _new) for s in pred_strings]
+            tgt_strings = [s.replace(_old, _new) for s in tgt_strings]
+        for _old, _new in replace_tokens_roberta:
+            pred_strings = [s.replace(_old, _new) for s in pred_strings]
+            tgt_strings = [s.replace(_old, _new) for s in tgt_strings]
+        # str.strip() returns a new string, so bind the result; the original
+        # loop `for s in pred_strings: s.strip()` discarded it and was a no-op.
+        pred_strings = [s.strip() for s in pred_strings]
+        tgt_strings = [s.strip() for s in tgt_strings]
+        return {'preds': pred_strings, 'tgts': tgt_strings}
 
     def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
         """return the result by the model
@@ -45,5 +73,9 @@ class PalmForTextGeneration(Model):
             'predictions': Tensor([[1377, 4959, 2785, 6392...])]), # tokens need to be decoded by the tokenizer
         }
         """
-        return self.generator(**input)
+        if self.training:
+            return {'loss': self.model(**input)}
+        elif 'tgt' in input:
+            return self._evaluate_postprocess(**input)
+        else:
+            return self.generator(**input)
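For context: `forward` now dispatches on mode, so one entry point covers finetuning loss, evaluation decoding, and plain inference. A rough usage sketch, assuming a local Palm checkpoint directory (path hypothetical) and 128-token batches; the random token ids are illustrative stand-ins for preprocessor output:

import torch
from modelscope.models.nlp.palm_for_text_generation import PalmForTextGeneration

model = PalmForTextGeneration('/path/to/palm_checkpoint')  # hypothetical path
batch = {
    'src': torch.randint(0, 100, (2, 128)),            # source token ids
    'mask_src': torch.ones(2, 128, dtype=torch.long),  # source attention mask
    'tgt': torch.randint(0, 100, (2, 128)),            # target token ids
}

model.train()
out = model(batch)   # training mode -> {'loss': ...}
model.eval()
out = model(batch)   # eval with 'tgt' -> {'preds': [...], 'tgts': [...]}
del batch['tgt']
out = model(batch)   # plain inference -> raw generator output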

modelscope/preprocessors/nlp.py (+40, -4)

@@ -216,8 +216,9 @@ class SentenceSimilarityFinetunePreprocessor(SentenceSimilarityPreprocessor):
     Fields.nlp, module_name=Preprocessors.palm_text_gen_tokenizer)
 class TextGenerationPreprocessor(NLPPreprocessorBase):
 
-    def __init__(self, model_dir: str, tokenizer, *args, **kwargs):
-        self.tokenizer = tokenizer
+    def __init__(self, model_dir: str, tokenizer=None, *args, **kwargs):
+        self.tokenizer = self.build_tokenizer(
+            model_dir) if tokenizer is None else tokenizer
         kwargs['truncation'] = True
         kwargs['padding'] = 'max_length'
         kwargs['return_tensors'] = 'pt'
@@ -225,8 +226,43 @@ class TextGenerationPreprocessor(NLPPreprocessorBase):
         kwargs['max_length'] = kwargs.pop('sequence_length', 128)
         super().__init__(model_dir, *args, **kwargs)
 
-    def build_tokenizer(self, model_dir):
-        return self.tokenizer
+    def build_tokenizer(self, model_dir: str):
+        import os
+        from sofa.models.palm_v2 import PalmConfig
+
+        config_file = os.path.join(model_dir, 'config.json')
+        config = PalmConfig.from_json_file(config_file) if os.path.isfile(
+            config_file) else PalmConfig()
+        config.encoder_pth = os.path.join(model_dir, config.encoder_pth)
+        if config.encoder == 'roberta':
+            from transformers import RobertaTokenizer
+            tokenizer = RobertaTokenizer.from_pretrained(
+                config.encoder_pth, do_lower_case=False)
+        elif config.encoder == 'bert' or config.encoder == 'zh_bert':
+            from transformers import BertTokenizer
+            tokenizer = BertTokenizer.from_pretrained(
+                config.encoder_pth, do_lower_case=True)
+        return tokenizer
+
+
+@PREPROCESSORS.register_module(
+    Fields.nlp, module_name='palm-text-gen-tokenizer-finetune')
+class TextGenerationFinetunePreprocessor(TextGenerationPreprocessor):
+
+    @type_assert(object, dict)
+    def __call__(self, data: dict) -> Dict[str, Any]:
+        src_txt = data['src_txt']
+        tgt_txt = data['tgt_txt']
+        src_rst = super().__call__(src_txt)
+        tgt_rst = super().__call__(tgt_txt)
+        src_rst = {k: v.squeeze() for k, v in src_rst.items()}
+        tgt_rst = {k: v.squeeze() for k, v in tgt_rst.items()}
+
+        return {
+            'src': src_rst['input_ids'],
+            'tgt': tgt_rst['input_ids'],
+            'mask_src': src_rst['attention_mask']
+        }
 
 
 @PREPROCESSORS.register_module(Fields.nlp)
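For context: a minimal sketch of the finetune preprocessing path, assuming a model directory containing the Palm config.json (path hypothetical):

from modelscope.preprocessors.nlp import TextGenerationFinetunePreprocessor

preprocessor = TextGenerationFinetunePreprocessor('/path/to/palm_checkpoint')
features = preprocessor({
    'src_txt': 'This is a source sentence',
    'tgt_txt': 'This is a target sentence',
})
# 'src'/'tgt' are squeezed 1-D input_ids tensors and 'mask_src' is the source
# attention mask -- exactly the keys PalmForTextGeneration.forward consumes.
print(sorted(features))  # ['mask_src', 'src', 'tgt']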


modelscope/trainers/trainer.py (+2, -1)

@@ -297,6 +297,7 @@ class EpochBasedTrainer(BaseTrainer):
             model = Model.from_pretrained(self.model_dir)
         if not isinstance(model, nn.Module) and hasattr(model, 'model'):
             return model.model
+        return model
 
     def collate_fn(self, data):
         """Prepare the input just before the forward function.
@@ -339,7 +340,7 @@ class EpochBasedTrainer(BaseTrainer):
         model.train()
         self._mode = ModeKeys.TRAIN
         inputs = self.collate_fn(inputs)
-        if isinstance(inputs, dict):
+        if not isinstance(model, Model) and isinstance(inputs, dict):
             train_outputs = model.forward(**inputs)
         else:
             train_outputs = model.forward(inputs)
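For context: the dispatch rule introduced here, restated as a standalone helper (a sketch, not trainer code):

from modelscope.models.base import Model

def call_model(model, inputs):
    # ModelScope Model subclasses (e.g. the new TorchModel-based Palm) take the
    # whole batch dict as a single argument; bare nn.Modules get it unpacked.
    if not isinstance(model, Model) and isinstance(inputs, dict):
        return model.forward(**inputs)
    return model.forward(inputs)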


modelscope/trainers/utils/inference.py (+9, -2)

@@ -10,6 +10,7 @@ import torch
 from torch import distributed as dist
 from tqdm import tqdm
 
+from modelscope.models.base import Model
 from modelscope.utils.torch_utils import get_dist_info
 
 
@@ -35,7 +36,10 @@ def single_gpu_test(model,
         if data_collate_fn is not None:
             data = data_collate_fn(data)
         with torch.no_grad():
-            result = model(**data)
+            if not isinstance(model, Model):
+                result = model(**data)
+            else:
+                result = model(data)
         if metric_classes is not None:
             for metric_cls in metric_classes:
                 metric_cls.add(result, data)
@@ -83,7 +87,10 @@ def multi_gpu_test(model,
         if data_collate_fn is not None:
             data = data_collate_fn(data)
         with torch.no_grad():
-            result = model(**data)
+            if not isinstance(model, Model):
+                result = model(**data)
+            else:
+                result = model(data)
         results.extend(result)
 
     rank, world_size = get_dist_info()


requirements/nlp.txt (+1, -0)

@@ -1,5 +1,6 @@
 http://ait-public.oss-cn-hangzhou-zmf.aliyuncs.com/jizhu/en_core_web_sm-2.3.1.tar.gz
 pai-easynlp
+rouge_score
 sofa>=1.0.5
 spacy>=2.3.5
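rouge_score backs the new TextGenerationMetric; a quick sanity check of the API used above:

from rouge_score import rouge_scorer

scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
score = scorer.score('the cat sat on the mat', 'a cat sat on the mat')
print(score['rougeL'].fmeasure)  # longest-common-subsequence F-measure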

tests/trainers/test_text_generation_trainer.py (+91, -0)

@@ -0,0 +1,91 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+import shutil
+import tempfile
+import unittest
+
+from modelscope.hub.snapshot_download import snapshot_download
+from modelscope.models.nlp.palm_for_text_generation import \
+    PalmForTextGeneration
+from modelscope.msdatasets import MsDataset
+from modelscope.trainers import build_trainer
+from modelscope.utils.constant import ModelFile
+from modelscope.utils.test_utils import test_level
+
+
+class TestTextGenerationTrainer(unittest.TestCase):
+
+    def setUp(self):
+        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
+        self.tmp_dir = tempfile.TemporaryDirectory().name
+        if not os.path.exists(self.tmp_dir):
+            os.makedirs(self.tmp_dir)
+
+        from datasets import Dataset
+
+        self.model_id = 'damo/nlp_palm2.0_text-generation_english-base'
+
+        dataset_dict = {
+            'src_txt': [
+                'This is test sentence1-1', 'This is test sentence2-1',
+                'This is test sentence3-1'
+            ],
+            'tgt_txt': [
+                'This is test sentence1-2', 'This is test sentence2-2',
+                'This is test sentence3-2'
+            ]
+        }
+        dataset = Dataset.from_dict(dataset_dict)
+
+        class MsDatasetDummy(MsDataset):
+
+            def __len__(self):
+                return len(self._hf_ds)
+
+        self.dataset = MsDatasetDummy(dataset)
+
+    def tearDown(self):
+        shutil.rmtree(self.tmp_dir)
+        super().tearDown()
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_trainer(self):
+        kwargs = dict(
+            model=self.model_id,
+            train_dataset=self.dataset,
+            eval_dataset=self.dataset,
+            work_dir=self.tmp_dir)
+
+        trainer = build_trainer(default_args=kwargs)
+        trainer.train()
+        results_files = os.listdir(self.tmp_dir)
+        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
+        for i in range(3):
+            self.assertIn(f'epoch_{i+1}.pth', results_files)
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_trainer_with_model_and_args(self):
+        tmp_dir = tempfile.TemporaryDirectory().name
+        if not os.path.exists(tmp_dir):
+            os.makedirs(tmp_dir)
+
+        cache_path = snapshot_download(self.model_id)
+        model = PalmForTextGeneration.from_pretrained(cache_path)
+        kwargs = dict(
+            cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION),
+            model=model,
+            train_dataset=self.dataset,
+            eval_dataset=self.dataset,
+            max_epochs=2,
+            work_dir=self.tmp_dir)
+
+        trainer = build_trainer(default_args=kwargs)
+        trainer.train()
+        results_files = os.listdir(self.tmp_dir)
+        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
+        for i in range(2):
+            self.assertIn(f'epoch_{i+1}.pth', results_files)
+
+
+if __name__ == '__main__':
+    unittest.main()
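For context: the same finetune flow outside unittest, mirroring test_trainer above. The model id comes from the test; constructing MsDataset directly from a Hugging Face Dataset is an assumption suggested by the dummy subclass, and the work dir is hypothetical:

from datasets import Dataset
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer

hf_ds = Dataset.from_dict({
    'src_txt': ['This is a source sentence'],
    'tgt_txt': ['This is a target sentence'],
})
dataset = MsDataset(hf_ds)  # assumption: MsDataset wraps a HF Dataset this way

trainer = build_trainer(default_args=dict(
    model='damo/nlp_palm2.0_text-generation_english-base',
    train_dataset=dataset,
    eval_dataset=dataset,
    work_dir='./palm_work_dir'))
trainer.train()  # checkpoints land in work_dir as epoch_<n>.pth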
