
[to #42322933] Add palm finetuning

The Palm model now supports finetuning.

Branch: master
Author: hemu.zp, 3 years ago
Commit: 0b7b964226
10 changed files with 224 additions and 18 deletions:

  1. modelscope/metainfo.py (+2, -0)
  2. modelscope/metrics/__init__.py (+1, -0)
  3. modelscope/metrics/builder.py (+1, -0)
  4. modelscope/metrics/text_generation_metric.py (+34, -0)
  5. modelscope/models/nlp/palm_for_text_generation.py (+43, -11)
  6. modelscope/preprocessors/nlp.py (+40, -4)
  7. modelscope/trainers/trainer.py (+2, -1)
  8. modelscope/trainers/utils/inference.py (+9, -2)
  9. requirements/nlp.txt (+1, -0)
  10. tests/trainers/test_text_generation_trainer.py (+91, -0)

modelscope/metainfo.py (+2, -0)

@@ -141,3 +141,5 @@ class Metrics(object):
     seq_cls_metric = 'seq_cls_metric'
     # metrics for token-classification task
     token_cls_metric = 'token-cls-metric'
+    # metrics for text-generation task
+    text_gen_metric = 'text-gen-metric'

modelscope/metrics/__init__.py (+1, -0)

@@ -1,3 +1,4 @@
 from .base import Metric
 from .builder import METRICS, build_metric, task_default_metrics
 from .sequence_classification_metric import SequenceClassificationMetric
+from .text_generation_metric import TextGenerationMetric

modelscope/metrics/builder.py (+1, -0)

@@ -17,6 +17,7 @@ class MetricKeys(object):
 
 task_default_metrics = {
     Tasks.sentence_similarity: [Metrics.seq_cls_metric],
+    Tasks.text_generation: [Metrics.text_gen_metric],
 }




modelscope/metrics/text_generation_metric.py (+34, -0)

@@ -0,0 +1,34 @@
+from typing import Dict
+
+import numpy as np
+from rouge_score import rouge_scorer
+
+from ..metainfo import Metrics
+from ..utils.registry import default_group
+from .base import Metric
+from .builder import METRICS, MetricKeys
+
+
+@METRICS.register_module(
+    group_key=default_group, module_name=Metrics.text_gen_metric)
+class TextGenerationMetric(Metric):
+    """The metric computation class for text generation tasks.
+    """
+
+    def __init__(self):
+        self.preds = []
+        self.tgts = []
+        self.scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
+
+    def add(self, outputs: Dict, inputs: Dict):
+        ground_truths = outputs['tgts']
+        eval_results = outputs['preds']
+        self.preds.extend(eval_results)
+        self.tgts.extend(ground_truths)
+
+    def evaluate(self):
+        scores = [
+            self.scorer.score(pred, tgt)['rougeL'].fmeasure
+            for pred, tgt in zip(self.preds, self.tgts)
+        ]
+        return {MetricKeys.F1: sum(scores) / len(scores)}
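
For reference, a minimal standalone sketch of what this metric computes, using only the rouge_score package (the example strings are illustrative, not from the test data):

from rouge_score import rouge_scorer

# Accumulate (prediction, target) string pairs, then average the
# per-pair ROUGE-L F-measure, mirroring TextGenerationMetric.evaluate().
scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
preds = ['the cat sat on the mat', 'a quick brown fox']
tgts = ['the cat is on the mat', 'the quick brown fox']

scores = [
    scorer.score(pred, tgt)['rougeL'].fmeasure
    for pred, tgt in zip(preds, tgts)
]
print({'f1': sum(scores) / len(scores)})

Note that rouge_scorer's score(target, prediction) signature puts the reference first; the (pred, tgt) order used by the metric still yields the same number, since ROUGE-L's F-measure is the harmonic mean of precision and recall and is symmetric under swapping the two strings.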

modelscope/models/nlp/palm_for_text_generation.py (+43, -11)

@@ -2,14 +2,15 @@ from typing import Dict
 
 from ...metainfo import Models
 from ...utils.constant import Tasks
-from ..base import Model, Tensor
+from ..base import Tensor
+from ..base_torch import TorchModel
 from ..builder import MODELS
 
 __all__ = ['PalmForTextGeneration']
 
 
 @MODELS.register_module(Tasks.text_generation, module_name=Models.palm)
-class PalmForTextGeneration(Model):
+class PalmForTextGeneration(TorchModel):
 
     def __init__(self, model_dir: str, *args, **kwargs):
         """initialize the text generation model from the `model_dir` path.
@@ -22,15 +23,42 @@ class PalmForTextGeneration(Model):
         super().__init__(model_dir, *args, **kwargs)
 
         from sofa.models.palm_v2 import PalmForConditionalGeneration, Translator
-        model = PalmForConditionalGeneration.from_pretrained(model_dir)
-        self.tokenizer = model.tokenizer
-        self.generator = Translator(model)
+        self.model = PalmForConditionalGeneration.from_pretrained(model_dir)
+        self.tokenizer = self.model.tokenizer
+        self.generator = Translator(self.model)
 
-    def train(self):
-        return self.generator.train()
+    def _evaluate_postprocess(self, src: Tensor, tgt: Tensor,
+                              mask_src: Tensor) -> Dict[str, str]:
+        replace_tokens_bert = (('[unused0]', ''), ('[PAD]', ''),
+                               ('[unused1]', ''), (r' +', ' '), ('[SEP]', ''),
+                               ('[unused2]', ''), ('[CLS]', ''), ('[UNK]', ''))
+        replace_tokens_roberta = ((r' +', ' '), ('<mask>', '<q>'),
+                                  ('<pad>', ''), ('<s>', ''), ('</s>', ''),
+                                  ('<unk>', ' '))
 
-    def eval(self):
-        return self.generator.eval()
+        inputs = self.generator(src, mask_src)
+        pred_list = inputs['predictions']
+        pred_id_list = [
+            pred_batch[0].cpu().numpy().tolist() for pred_batch in pred_list
+        ]
+        tgt_id_list = tgt.cpu().numpy().tolist()
+        pred_strings = [
+            self.tokenizer.decode(pred_ids) for pred_ids in pred_id_list
+        ]
+        tgt_strings = [
+            self.tokenizer.decode(tgt_ids) for tgt_ids in tgt_id_list
+        ]
+        for _old, _new in replace_tokens_bert:
+            pred_strings = [s.replace(_old, _new) for s in pred_strings]
+            tgt_strings = [s.replace(_old, _new) for s in tgt_strings]
+        for _old, _new in replace_tokens_roberta:
+            pred_strings = [s.replace(_old, _new) for s in pred_strings]
+            tgt_strings = [s.replace(_old, _new) for s in tgt_strings]
+        # str.strip() returns a new string, so rebuild the lists instead of
+        # calling it in a bare loop (which would discard the result)
+        pred_strings = [s.strip() for s in pred_strings]
+        tgt_strings = [s.strip() for s in tgt_strings]
+        return {'preds': pred_strings, 'tgts': tgt_strings}
 
     def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
         """return the result by the model
@@ -45,5 +73,9 @@ class PalmForTextGeneration(Model):
                 'predictions': Tensor([[1377, 4959, 2785, 6392...]]), # tokens need to be decoded by the tokenizer
             }
         """
-
-        return self.generator(**input)
+        if self.training:
+            return {'loss': self.model(**input)}
+        elif 'tgt' in input:
+            return self._evaluate_postprocess(**input)
+        else:
+            return self.generator(**input)
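
To make the new control flow concrete, here is a hedged, self-contained sketch of the same three-way dispatch (the class and tensors below are toy stand-ins, not ModelScope API): training mode returns a loss for the trainer, evaluation batches that carry a tgt tensor are routed through postprocessing so metrics receive decoded strings, and everything else falls through to generation.

import torch
from torch import nn

class ToyGenerationModel(nn.Module):
    """Toy stand-in reproducing the mode dispatch in the new forward()."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, inputs: dict) -> dict:
        if self.training:
            # training: the trainer consumes a {'loss': ...} dict
            return {'loss': self.linear(inputs['src']).sum()}
        elif 'tgt' in inputs:
            # evaluation: targets present, hand decoded strings to metrics
            return {'preds': ['toy prediction'], 'tgts': ['toy target']}
        else:
            # inference: raw generator output
            return {'predictions': self.linear(inputs['src'])}

model = ToyGenerationModel()
model.train()
print(model({'src': torch.randn(2, 4)})['loss'].item())
model.eval()
print(sorted(model({'src': torch.randn(2, 4), 'tgt': torch.zeros(2)}).keys()))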

modelscope/preprocessors/nlp.py (+40, -4)

@@ -216,8 +216,9 @@ class SentenceSimilarityFinetunePreprocessor(SentenceSimilarityPreprocessor):
     Fields.nlp, module_name=Preprocessors.palm_text_gen_tokenizer)
 class TextGenerationPreprocessor(NLPPreprocessorBase):
 
-    def __init__(self, model_dir: str, tokenizer, *args, **kwargs):
-        self.tokenizer = tokenizer
+    def __init__(self, model_dir: str, tokenizer=None, *args, **kwargs):
+        self.tokenizer = self.build_tokenizer(
+            model_dir) if tokenizer is None else tokenizer
         kwargs['truncation'] = True
         kwargs['padding'] = 'max_length'
         kwargs['return_tensors'] = 'pt'
@@ -225,8 +226,43 @@ class TextGenerationPreprocessor(NLPPreprocessorBase):
         kwargs['max_length'] = kwargs.pop('sequence_length', 128)
         super().__init__(model_dir, *args, **kwargs)
 
-    def build_tokenizer(self, model_dir):
-        return self.tokenizer
+    def build_tokenizer(self, model_dir: str):
+        import os
+        from sofa.models.palm_v2 import PalmConfig
+
+        config_file = os.path.join(model_dir, 'config.json')
+        config = PalmConfig.from_json_file(config_file) if os.path.isfile(
+            config_file) else PalmConfig()
+        config.encoder_pth = os.path.join(model_dir, config.encoder_pth)
+        if config.encoder == 'roberta':
+            from transformers import RobertaTokenizer
+            tokenizer = RobertaTokenizer.from_pretrained(
+                config.encoder_pth, do_lower_case=False)
+        elif config.encoder == 'bert' or config.encoder == 'zh_bert':
+            from transformers import BertTokenizer
+            tokenizer = BertTokenizer.from_pretrained(
+                config.encoder_pth, do_lower_case=True)
+        return tokenizer
+
+
+@PREPROCESSORS.register_module(
+    Fields.nlp, module_name='palm-text-gen-tokenizer-finetune')
+class TextGenerationFinetunePreprocessor(TextGenerationPreprocessor):
+
+    @type_assert(object, dict)
+    def __call__(self, data: dict) -> Dict[str, Any]:
+        src_txt = data['src_txt']
+        tgt_txt = data['tgt_txt']
+        src_rst = super().__call__(src_txt)
+        tgt_rst = super().__call__(tgt_txt)
+        src_rst = {k: v.squeeze() for k, v in src_rst.items()}
+        tgt_rst = {k: v.squeeze() for k, v in tgt_rst.items()}
+
+        return {
+            'src': src_rst['input_ids'],
+            'tgt': tgt_rst['input_ids'],
+            'mask_src': src_rst['attention_mask']
+        }
 
 
 @PREPROCESSORS.register_module(Fields.nlp)
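
For illustration, a hedged sketch of the sample dict TextGenerationFinetunePreprocessor produces, using a stock Hugging Face BERT tokenizer as a stand-in for the one built from the PaLM config ('bert-base-uncased' and the texts are assumptions for the example):

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

def encode(text):
    # same kwargs the preprocessor fixes in __init__
    return tokenizer(
        text,
        truncation=True,
        padding='max_length',
        max_length=128,
        return_tensors='pt')

src_rst = {k: v.squeeze() for k, v in encode('This is a source text').items()}
tgt_rst = {k: v.squeeze() for k, v in encode('This is a target text').items()}

sample = {
    'src': src_rst['input_ids'],            # token ids, shape (128,)
    'tgt': tgt_rst['input_ids'],            # token ids, shape (128,)
    'mask_src': src_rst['attention_mask']   # 1 for real tokens, 0 for padding
}
print({k: tuple(v.shape) for k, v in sample.items()})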


modelscope/trainers/trainer.py (+2, -1)

@@ -297,6 +297,7 @@ class EpochBasedTrainer(BaseTrainer):
             model = Model.from_pretrained(self.model_dir)
         if not isinstance(model, nn.Module) and hasattr(model, 'model'):
             return model.model
+        return model
 
     def collate_fn(self, data):
         """Prepare the input just before the forward function.
@@ -339,7 +340,7 @@ class EpochBasedTrainer(BaseTrainer):
         model.train()
         self._mode = ModeKeys.TRAIN
         inputs = self.collate_fn(inputs)
-        if isinstance(inputs, dict):
+        if not isinstance(model, Model) and isinstance(inputs, dict):
             train_outputs = model.forward(**inputs)
         else:
             train_outputs = model.forward(inputs)


modelscope/trainers/utils/inference.py (+9, -2)

@@ -10,6 +10,7 @@ import torch
 from torch import distributed as dist
 from tqdm import tqdm
 
+from modelscope.models.base import Model
 from modelscope.utils.torch_utils import get_dist_info


@@ -35,7 +36,10 @@ def single_gpu_test(model,
         if data_collate_fn is not None:
             data = data_collate_fn(data)
         with torch.no_grad():
-            result = model(**data)
+            if not isinstance(model, Model):
+                result = model(**data)
+            else:
+                result = model(data)
         if metric_classes is not None:
             for metric_cls in metric_classes:
                 metric_cls.add(result, data)
@@ -83,7 +87,10 @@ def multi_gpu_test(model,
         if data_collate_fn is not None:
             data = data_collate_fn(data)
         with torch.no_grad():
-            result = model(**data)
+            if not isinstance(model, Model):
+                result = model(**data)
+            else:
+                result = model(data)
         results.extend(result)
 
     rank, world_size = get_dist_info()
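
The same calling-convention dispatch appears in trainer.py above and in both test loops here: ModelScope Model subclasses take the whole batch dict as one positional argument, while plain nn.Module-style callables get it unpacked as keyword arguments. A small self-contained sketch of the idea (both model classes are hypothetical stand-ins for the isinstance(model, Model) check):

import torch

# Stand-in for a plain nn.Module-style forward: batch unpacked as kwargs.
def plain_forward(src=None, mask_src=None):
    return {'predictions': src}

# Stand-in for a ModelScope Model: forward takes the whole dict.
class DictModel:
    def __call__(self, input: dict) -> dict:
        return {'predictions': input['src']}

def run_model(model, data: dict, is_modelscope_model: bool) -> dict:
    with torch.no_grad():
        if not is_modelscope_model:
            return model(**data)  # plain module: keyword arguments
        return model(data)        # Model subclass: one positional dict

batch = {'src': torch.ones(2, 4), 'mask_src': torch.ones(2, 4)}
print(run_model(plain_forward, batch, False)['predictions'].shape)
print(run_model(DictModel(), batch, True)['predictions'].shape)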


requirements/nlp.txt (+1, -0)

@@ -1,5 +1,6 @@
 
 http://ait-public.oss-cn-hangzhou-zmf.aliyuncs.com/jizhu/en_core_web_sm-2.3.1.tar.gz
 pai-easynlp
+rouge_score
 sofa>=1.0.5
 spacy>=2.3.5

tests/trainers/test_text_generation_trainer.py (+91, -0)

@@ -0,0 +1,91 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+import shutil
+import tempfile
+import unittest
+
+from modelscope.hub.snapshot_download import snapshot_download
+from modelscope.models.nlp.palm_for_text_generation import \
+    PalmForTextGeneration
+from modelscope.msdatasets import MsDataset
+from modelscope.trainers import build_trainer
+from modelscope.utils.constant import ModelFile
+from modelscope.utils.test_utils import test_level
+
+
+class TestTextGenerationTrainer(unittest.TestCase):
+
+    def setUp(self):
+        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
+        self.tmp_dir = tempfile.TemporaryDirectory().name
+        if not os.path.exists(self.tmp_dir):
+            os.makedirs(self.tmp_dir)
+
+        from datasets import Dataset
+
+        self.model_id = 'damo/nlp_palm2.0_text-generation_english-base'
+
+        dataset_dict = {
+            'src_txt': [
+                'This is test sentence1-1', 'This is test sentence2-1',
+                'This is test sentence3-1'
+            ],
+            'tgt_txt': [
+                'This is test sentence1-2', 'This is test sentence2-2',
+                'This is test sentence3-2'
+            ]
+        }
+        dataset = Dataset.from_dict(dataset_dict)
+
+        class MsDatasetDummy(MsDataset):
+
+            def __len__(self):
+                return len(self._hf_ds)
+
+        self.dataset = MsDatasetDummy(dataset)
+
+    def tearDown(self):
+        shutil.rmtree(self.tmp_dir)
+        super().tearDown()
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_trainer(self):
+        kwargs = dict(
+            model=self.model_id,
+            train_dataset=self.dataset,
+            eval_dataset=self.dataset,
+            work_dir=self.tmp_dir)
+
+        trainer = build_trainer(default_args=kwargs)
+        trainer.train()
+        results_files = os.listdir(self.tmp_dir)
+        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
+        for i in range(3):
+            self.assertIn(f'epoch_{i+1}.pth', results_files)
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_trainer_with_model_and_args(self):
+        tmp_dir = tempfile.TemporaryDirectory().name
+        if not os.path.exists(tmp_dir):
+            os.makedirs(tmp_dir)
+
+        cache_path = snapshot_download(self.model_id)
+        model = PalmForTextGeneration.from_pretrained(cache_path)
+        kwargs = dict(
+            cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION),
+            model=model,
+            train_dataset=self.dataset,
+            eval_dataset=self.dataset,
+            max_epochs=2,
+            work_dir=self.tmp_dir)
+
+        trainer = build_trainer(default_args=kwargs)
+        trainer.train()
+        results_files = os.listdir(self.tmp_dir)
+        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
+        for i in range(2):
+            self.assertIn(f'epoch_{i+1}.pth', results_files)
+
+
+if __name__ == '__main__':
+    unittest.main()
