From 8f060d0bc3980cae65e8ee3b0b7f8740d4b587dc Mon Sep 17 00:00:00 2001
From: "hemu.zp"
Date: Sat, 30 Jul 2022 11:03:01 +0800
Subject: [PATCH] [to #42322933] Add GPT3 base model
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Add the GPT3 base model, reusing the existing text generation pipeline.
---
 modelscope/metainfo.py | 3 +-
 modelscope/models/nlp/__init__.py | 8 +-
 modelscope/models/nlp/backbones/__init__.py | 4 +-
 .../models/nlp/backbones/gpt3/__init__.py | 23 ++
 .../nlp/backbones/gpt3/configuration_gpt3.py | 51 +++
 .../nlp/backbones/gpt3/modeling_gpt3.py | 337 ++++++++++++++++++
 .../nlp/backbones/structbert/__init__.py | 20 +-
 .../models/nlp/gpt3_for_text_generation.py | 56 +++
 .../models/nlp/palm_for_text_generation.py | 57 +--
 .../pipelines/nlp/text_generation_pipeline.py | 30 +-
 modelscope/preprocessors/nlp.py | 32 +-
 tests/pipelines/test_text_generation.py | 76 ++--
 12 files changed, 598 insertions(+), 99 deletions(-)
 create mode 100644 modelscope/models/nlp/backbones/gpt3/__init__.py
 create mode 100644 modelscope/models/nlp/backbones/gpt3/configuration_gpt3.py
 create mode 100644 modelscope/models/nlp/backbones/gpt3/modeling_gpt3.py
 create mode 100644 modelscope/models/nlp/gpt3_for_text_generation.py

diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py
index 9c5ed709..75259f43 100644
--- a/modelscope/metainfo.py
+++ b/modelscope/metainfo.py
@@ -26,6 +26,7 @@ class Models(object):
     space = 'space'
     tcrf = 'transformer-crf'
     bart = 'bart'
+    gpt3 = 'gpt3'

     # audio models
     sambert_hifigan = 'sambert-hifigan'
@@ -160,7 +161,7 @@ class Preprocessors(object):
     # nlp preprocessor
     sen_sim_tokenizer = 'sen-sim-tokenizer'
     bert_seq_cls_tokenizer = 'bert-seq-cls-tokenizer'
-    palm_text_gen_tokenizer = 'palm-text-gen-tokenizer'
+    text_gen_tokenizer = 'text-gen-tokenizer'
     token_cls_tokenizer = 'token-cls-tokenizer'
     ner_tokenizer = 'ner-tokenizer'
     nli_tokenizer = 'nli-tokenizer'
diff --git a/modelscope/models/nlp/__init__.py b/modelscope/models/nlp/__init__.py
index 23041168..f2219b0e 100644
--- a/modelscope/models/nlp/__init__.py
+++ b/modelscope/models/nlp/__init__.py
@@ -4,7 +4,8 @@ from typing import TYPE_CHECKING
 from modelscope.utils.import_utils import LazyImportModule

 if TYPE_CHECKING:
-    from .backbones import (SbertModel, SpaceGenerator, SpaceModelBase)
+    from .backbones import (SbertModel, SpaceGenerator, SpaceModelBase,
+                            GPT3Model)
     from .heads import SequenceClassificationHead
     from .bert_for_sequence_classification import BertForSequenceClassification
     from .csanmt_for_translation import CsanmtForTranslation
@@ -23,10 +24,12 @@ if TYPE_CHECKING:
     from .space_for_dialog_state_tracking import SpaceForDialogStateTracking
     from .task_model import SingleBackboneTaskModelBase
     from .bart_for_text_error_correction import BartForTextErrorCorrection
+    from .gpt3_for_text_generation import GPT3ForTextGeneration

 else:
     _import_structure = {
-        'backbones': ['SbertModel', 'SpaceGenerator', 'SpaceModelBase'],
+        'backbones':
+        ['SbertModel', 'SpaceGenerator', 'SpaceModelBase', 'GPT3Model'],
         'heads': ['SequenceClassificationHead'],
         'csanmt_for_translation': ['CsanmtForTranslation'],
         'bert_for_sequence_classification': ['BertForSequenceClassification'],
@@ -48,6 +51,7 @@ else:
         'space_for_dialog_state_tracking': ['SpaceForDialogStateTracking'],
         'task_model': ['SingleBackboneTaskModelBase'],
         'bart_for_text_error_correction': ['BartForTextErrorCorrection'],
+        'gpt3_for_text_generation': ['GPT3ForTextGeneration'],
     }

     import sys
diff --git a/modelscope/models/nlp/backbones/__init__.py
b/modelscope/models/nlp/backbones/__init__.py index a21c5d6f..ffe8ac05 100644 --- a/modelscope/models/nlp/backbones/__init__.py +++ b/modelscope/models/nlp/backbones/__init__.py @@ -6,10 +6,12 @@ from modelscope.utils.import_utils import LazyImportModule if TYPE_CHECKING: from .space import SpaceGenerator, SpaceModelBase from .structbert import SbertModel + from .gpt3 import GPT3Model else: _import_structure = { 'space': ['SpaceGenerator', 'SpaceModelBase'], - 'structbert': ['SbertModel'] + 'structbert': ['SbertModel'], + 'gpt3': ['GPT3Model'] } import sys diff --git a/modelscope/models/nlp/backbones/gpt3/__init__.py b/modelscope/models/nlp/backbones/gpt3/__init__.py new file mode 100644 index 00000000..b0739c22 --- /dev/null +++ b/modelscope/models/nlp/backbones/gpt3/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .configuration_gpt3 import GPT3Config + from .modeling_gpt3 import GPT3Model +else: + _import_structure = { + 'configuration_gpt3': ['GPT3Config'], + 'modeling_gpt3': ['GPT3Model'] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/nlp/backbones/gpt3/configuration_gpt3.py b/modelscope/models/nlp/backbones/gpt3/configuration_gpt3.py new file mode 100644 index 00000000..d5a054fd --- /dev/null +++ b/modelscope/models/nlp/backbones/gpt3/configuration_gpt3.py @@ -0,0 +1,51 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class GPT3Config(PretrainedConfig): + + model_type = 'gpt' + + def __init__(self, + vocab_size=25600, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act='gelu', + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=2048, + type_vocab_size=2, + layernorm_epsilon=1e-12, + **kwargs): + super().__init__(layer_norm_eps=layernorm_epsilon, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.layernorm_epsilon = layernorm_epsilon diff --git a/modelscope/models/nlp/backbones/gpt3/modeling_gpt3.py b/modelscope/models/nlp/backbones/gpt3/modeling_gpt3.py new file mode 100644 index 00000000..f7024713 --- /dev/null +++ b/modelscope/models/nlp/backbones/gpt3/modeling_gpt3.py @@ -0,0 +1,337 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import os +from typing import Optional, Union + +import torch +from addict import Dict +from torch.nn import Dropout, Embedding, LayerNorm, Linear, Module, Softmax +from torch.nn import functional as F +from transformers.modeling_utils import PreTrainedModel + +from modelscope.utils.constant import ModelFile +from .configuration_gpt3 import GPT3Config + + +class GPT3SelfAttention(Module): + """Parallel self-attention layer abstract class. + + Self-attention layer takes input with size [s, b, h] + and returns output of the same size. + """ + + def __init__(self, config): + super().__init__() + + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + # Per attention head + self.hidden_size_per_attention_head = \ + self.hidden_size // self.num_attention_heads + + self.query_key_value = Linear(self.hidden_size, 3 * self.hidden_size) + self.softmax = Softmax(dim=-1) + self.attention_dropout = Dropout(config.attention_probs_dropout_prob) + + # Output. + self.dense = Linear(self.hidden_size, self.hidden_size) + self.output_dropout = torch.nn.Dropout(config.hidden_dropout_prob) + + def _transpose_for_scores(self, tensor): + """Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with + size [b, np, s, hn]. 
+ """ + new_tensor_shape = tensor.size()[:-1] + ( + self.num_attention_heads, self.hidden_size_per_attention_head) + tensor = tensor.view(*new_tensor_shape) + return tensor.permute(0, 2, 1, 3) + + def _split_tensor_along_last_dim(self, + tensor, + num_partitions, + contiguous_split_chunks=False): + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = tensor.size()[last_dim] // num_partitions + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # Note: torch.split does not create contiguous tensors by default. + if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + + def forward(self, hidden_states, ltor_mask, is_infer=False): + # hidden_states: [b, s, h] + # ltor_mask: [1, 1, s, s] + + # Attention heads. [b, s, hp] + tgt_len = hidden_states.size(1) + ltor_mask = torch.reshape(ltor_mask, [1, 1, tgt_len, tgt_len]) + mixed_x_layer = self.query_key_value(hidden_states) + (mixed_query_layer, mixed_key_layer, mixed_value_layer) = \ + self._split_tensor_along_last_dim(mixed_x_layer, 3) + + # Reshape and transpose [b, np, s, hn] + query_layer = self._transpose_for_scores(mixed_query_layer) + key_layer = self._transpose_for_scores(mixed_key_layer) + value_layer = self._transpose_for_scores(mixed_value_layer) + + previous_type = value_layer.type() + + # Raw attention scores. [b, np, s, s] + attention_scores = torch.matmul(query_layer, + key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt( + self.hidden_size_per_attention_head) + # Apply the left to right attention mask. + if is_infer: + src_len = key_layer.size(2) + ltor_mask = torch.tril( + torch.ones((1, tgt_len, src_len), + device=hidden_states.device)).view( + 1, 1, tgt_len, src_len).type(previous_type) + converted_mask = 10000.0 * (1.0 - ltor_mask) + attention_scores = (torch.mul(attention_scores, ltor_mask) + - converted_mask).type(previous_type) + + # Attention probabilities. [b, np, s, s] + attention_probs = self.softmax(attention_scores) + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.attention_dropout(attention_probs) + + # Context layer. + # [b, np, s, hn] + context_layer = torch.matmul(attention_probs, value_layer) + # [b, s, np, hn] + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + ( + self.hidden_size, ) + # [b, s, hp] + context_layer = context_layer.view(*new_context_layer_shape) + + # Output. [b, s, h] + output = self.dense(context_layer) + output = self.output_dropout(output) + + return output + + +class GPT3MLP(Module): + """MLP. + + MLP will take the input with h hidden state, project it to 4*h + hidden dimension, perform nonlinear transformation, and project the + state back into h hidden dimension. + """ + + def __init__(self, config): + super().__init__() + + hidden_size = config.hidden_size + # Project to 4h. + self.dense_h_to_4h = Linear(hidden_size, 4 * hidden_size) + self.activation_func = F.gelu + # Project back to h. 
+        self.dense_4h_to_h = Linear(4 * hidden_size, hidden_size)
+
+        self.dropout = Dropout(config.hidden_dropout_prob)
+
+    def forward(self, hidden_states):
+
+        # [b, s, 4hp]
+        intermediate_parallel = self.dense_h_to_4h(hidden_states)
+        intermediate_parallel = self.activation_func(intermediate_parallel)
+        # [b, s, h]
+        output = self.dense_4h_to_h(intermediate_parallel)
+        output = self.dropout(output)
+        return output
+
+
+class GPT3TransformerLayer(Module):
+    """A single transformer layer.
+
+    Transformer layer takes input with size [b, s, h] and returns an
+    output of the same size.
+    """
+
+    def __init__(self, config):
+        super().__init__()
+
+        # Layernorm on the input data.
+        self.input_layernorm = LayerNorm(
+            config.hidden_size, eps=config.layernorm_epsilon)
+
+        # Self attention.
+        self.attention = GPT3SelfAttention(config)
+
+        # Layernorm on the attention output
+        self.post_attention_layernorm = LayerNorm(
+            config.hidden_size, eps=config.layernorm_epsilon)
+
+        # MLP
+        self.mlp = GPT3MLP(config)
+
+    def forward(self, hidden_states, ltor_mask):
+        # hidden_states: [b, s, h]
+        # ltor_mask: [1, 1, s, s]
+
+        # Layer norm at the beginning of the transformer layer.
+        layernorm_output = self.input_layernorm(hidden_states)
+        # Self attention.
+        attention_output = self.attention(layernorm_output, ltor_mask)
+        # Residual connection.
+        layernorm_input = hidden_states + attention_output
+        # Layer norm post the self attention.
+        layernorm_output = self.post_attention_layernorm(layernorm_input)
+        # MLP.
+        mlp_output = self.mlp(layernorm_output)
+        # Second residual connection.
+        output = layernorm_input + mlp_output
+
+        return output
+
+
+class GPT3Transformer(Module):
+    """Transformer class."""
+
+    def __init__(self, config):
+        super().__init__()
+
+        self.input_tensor = None
+
+        # Number of layers.
+        self.num_layers = config.num_hidden_layers
+
+        self.layers = torch.nn.ModuleList(
+            [GPT3TransformerLayer(config) for _ in range(self.num_layers)])
+
+        # Final layer norm before output.
+        self.final_layernorm = LayerNorm(
+            config.hidden_size, eps=config.layernorm_epsilon)
+
+    def _get_layer(self, layer_number):
+        return self.layers[layer_number]
+
+    def forward(self, hidden_states, attention_mask):
+        # hidden_states: [b, s, h]
+
+        for index in range(self.num_layers):
+            layer = self._get_layer(index)
+            hidden_states = layer(hidden_states, attention_mask)
+
+        # Final layer norm.
+        hidden_states = self.final_layernorm(hidden_states)
+
+        return hidden_states
+
+
+class GPT3TransformerLanguageModel(Module):
+    """Transformer language model.
+
+    Builds word and position embeddings from a GPT3Config, feeds them
+    through a GPT3Transformer stack, and projects the output back onto
+    the tied word embedding weights to produce vocabulary logits.
+    """
+
+    def __init__(self, config):
+        super().__init__()
+
+        # Embeddings.
+        self.word_embeddings = Embedding(config.vocab_size, config.hidden_size)
+        self.position_embeddings = Embedding(config.max_position_embeddings,
+                                             config.hidden_size)
+        self.embedding_dropout = Dropout(config.hidden_dropout_prob)
+
+        # Transformer.
+ self.transformer = GPT3Transformer(config) + + def forward(self, input_ids, attention_mask, position_ids): + words_embeddings = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + + embeddings = words_embeddings + position_embeddings + transformer_input = self.embedding_dropout(embeddings) + transformer_output = self.transformer(transformer_input, + attention_mask) + + logits = F.linear(transformer_output, self.word_embeddings.weight) + return logits + + +class GPT3Model(PreTrainedModel): + + config_class = GPT3Config + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_( + mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, Embedding): + module.weight.data.normal_( + mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def __init__(self, config): + super().__init__(config) + self.language_model = GPT3TransformerLanguageModel(config) + + def forward(self, + input_ids, + attention_mask=None, + position_ids=None, + **kwargs): + seq_length = input_ids.size(1) + if attention_mask is None: + attention_mask = torch.tril( + torch.ones((1, seq_length, seq_length), + dtype=torch.long, + device=input_ids.device)) + if position_ids is None: + position_ids = torch.arange( + seq_length, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + + logits = self.language_model(input_ids, attention_mask, position_ids) + return Dict(logits=logits) + + @classmethod + def from_pretrained( + cls, pretrained_model_name_or_path: Optional[Union[str, + os.PathLike]]): + config = cls.config_class.from_pretrained( + pretrained_model_name_or_path) + model = cls(config) + state_dict_file = os.path.join(pretrained_model_name_or_path, + ModelFile.TORCH_MODEL_BIN_FILE) + state_dict = torch.load(state_dict_file) + model.load_state_dict(state_dict) + return model diff --git a/modelscope/models/nlp/backbones/structbert/__init__.py b/modelscope/models/nlp/backbones/structbert/__init__.py index 7db035d8..1d147730 100644 --- a/modelscope/models/nlp/backbones/structbert/__init__.py +++ b/modelscope/models/nlp/backbones/structbert/__init__.py @@ -1 +1,19 @@ -from .modeling_sbert import SbertModel +# Copyright (c) Alibaba, Inc. and its affiliates. 
+from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .modeling_sbert import SbertModel +else: + _import_structure = {'modeling_sbert': ['SbertModel']} + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/nlp/gpt3_for_text_generation.py b/modelscope/models/nlp/gpt3_for_text_generation.py new file mode 100644 index 00000000..22a6458d --- /dev/null +++ b/modelscope/models/nlp/gpt3_for_text_generation.py @@ -0,0 +1,56 @@ +from typing import Dict + +from modelscope.metainfo import Models +from modelscope.models.base import Tensor, TorchModel +from modelscope.models.builder import MODELS +from modelscope.outputs import OutputKeys +from modelscope.utils.constant import Tasks + +__all__ = ['GPT3ForTextGeneration'] + + +@MODELS.register_module(Tasks.text_generation, module_name=Models.gpt3) +class GPT3ForTextGeneration(TorchModel): + + def __init__(self, model_dir: str, *args, **kwargs): + """initialize the text generation model from the `model_dir` path. + + Args: + model_dir (str): the model path. + """ + super().__init__(model_dir, *args, **kwargs) + + from modelscope.models.nlp import GPT3Model + from transformers import BertTokenizer + + self.model = GPT3Model.from_pretrained(model_dir) + self.tokenizer = BertTokenizer.from_pretrained(model_dir) + + def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: + """return the result by the model + + Args: + input (Dict[str, Tensor]): the preprocessed data + + Returns: + Dict[str, Tensor]: results + Example: + { + 'logits': Tensor([[0.54, 0.32...])]), # logits + } + """ + return self.model(**input) + + def generate(self, input: Dict[str, Tensor]) -> Dict[str, str]: + assert 'input_ids' in input, "generate function must accept 'input_ids' key" + gen_params = dict() + gen_params['inputs'] = input['input_ids'] + gen_params['do_sample'] = input.pop('do_sample', True) + gen_params['max_length'] = input.pop('max_length', 128) + gen_params['top_k'] = input.pop('top_k', 10) + gen_params['top_p'] = input.pop('top_p', None) + sample_output = self.model.generate(**gen_params) + return { + OutputKeys.TEXT: + self.tokenizer.decode(sample_output[0], skip_special_tokens=True) + } diff --git a/modelscope/models/nlp/palm_for_text_generation.py b/modelscope/models/nlp/palm_for_text_generation.py index 245a5fdb..23d60663 100644 --- a/modelscope/models/nlp/palm_for_text_generation.py +++ b/modelscope/models/nlp/palm_for_text_generation.py @@ -1,8 +1,9 @@ -from typing import Dict +from typing import Dict, List from modelscope.metainfo import Models from modelscope.models.base import Tensor, TorchModel from modelscope.models.builder import MODELS +from modelscope.outputs import OutputKeys from modelscope.utils.constant import Tasks __all__ = ['PalmForTextGeneration'] @@ -27,8 +28,7 @@ class PalmForTextGeneration(TorchModel): self.tokenizer = self.model.tokenizer self.generator = Translator(self.model) - def _evaluate_postprocess(self, src: Tensor, tgt: Tensor, - mask_src: Tensor) -> Dict[str, str]: + def _evaluate_postprocess(self, ids_list: List[List[int]]) -> List[str]: replace_tokens_bert = (('[unused0]', ''), ('[PAD]', ''), ('[unused1]', ''), (r' +', ' '), ('[SEP]', ''), ('[unused2]', ''), ('[CLS]', ''), ('[UNK]', '')) @@ -36,29 +36,14 @@ class PalmForTextGeneration(TorchModel): ''), ('', ''), ('', ''), ('', ' ')) - inputs = self.generator(src, 
mask_src) - pred_list = inputs['predictions'] - pred_id_list = [ - pred_batch[0].cpu().numpy().tolist() for pred_batch in pred_list - ] - tgt_id_list = tgt.cpu().numpy().tolist() - pred_strings = [ - self.tokenizer.decode(pred_ids) for pred_ids in pred_id_list - ] - tgt_strings = [ - self.tokenizer.decode(tgt_ids) for tgt_ids in tgt_id_list - ] + strings = [self.tokenizer.decode(pred_ids) for pred_ids in ids_list] for _old, _new in replace_tokens_bert: - pred_strings = [s.replace(_old, _new) for s in pred_strings] - tgt_strings = [s.replace(_old, _new) for s in tgt_strings] + strings = [s.replace(_old, _new) for s in strings] for _old, _new in replace_tokens_roberta: - pred_strings = [s.replace(_old, _new) for s in pred_strings] - tgt_strings = [s.replace(_old, _new) for s in tgt_strings] - for s in pred_strings: + strings = [s.replace(_old, _new) for s in strings] + for s in strings: s.strip() - for s in tgt_strings: - s.strip() - return {'preds': pred_strings, 'tgts': tgt_strings} + return strings def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: """return the result by the model @@ -70,12 +55,30 @@ class PalmForTextGeneration(TorchModel): Dict[str, Tensor]: results Example: { - 'predictions': Tensor([[1377, 4959, 2785, 6392...])]), # tokens need to be decode by tokenizer + 'loss': Tensor([12.34]), # loss for backward + } + or + { + 'preds': List["hello word"...] # the predicted strings + 'tgts': List["hello world"...] # target strings } """ if self.training: return {'loss': self.model(**input)} - elif 'tgt' in input: - return self._evaluate_postprocess(**input) else: - return self.generator(**input) + outputs = self.generator(input['src'], input['mask_src']) + preds = outputs['predictions'] + pred_ids_list = [ + pred_batch[0].cpu().numpy().tolist() for pred_batch in preds + ] + tgt_ids_list = input['tgt'].cpu().numpy().tolist() + return { + 'preds': self._evaluate_postprocess(pred_ids_list), + 'tgts': self._evaluate_postprocess(tgt_ids_list) + } + + def generate(self, input: Dict[str, Tensor]) -> Dict[str, str]: + outputs = self.generator(**input) + preds = outputs['predictions'] + pred_ids_list = [preds[0][0].cpu().numpy().tolist()] + return {OutputKeys.TEXT: self._evaluate_postprocess(pred_ids_list)[0]} diff --git a/modelscope/pipelines/nlp/text_generation_pipeline.py b/modelscope/pipelines/nlp/text_generation_pipeline.py index 5cf75314..85a81eba 100644 --- a/modelscope/pipelines/nlp/text_generation_pipeline.py +++ b/modelscope/pipelines/nlp/text_generation_pipeline.py @@ -3,9 +3,7 @@ from typing import Any, Dict, Optional, Union import torch from modelscope.metainfo import Pipelines -from modelscope.models import Model -from modelscope.models.nlp import PalmForTextGeneration -from modelscope.outputs import OutputKeys +from modelscope.models.base import TorchModel from modelscope.pipelines.base import Pipeline, Tensor from modelscope.pipelines.builder import PIPELINES from modelscope.preprocessors import TextGenerationPreprocessor @@ -19,7 +17,7 @@ __all__ = ['TextGenerationPipeline'] class TextGenerationPipeline(Pipeline): def __init__(self, - model: Union[PalmForTextGeneration, str], + model: Union[TorchModel, str], preprocessor: Optional[TextGenerationPreprocessor] = None, **kwargs): """use `model` and `preprocessor` to create a nlp text generation pipeline for prediction @@ -29,21 +27,19 @@ class TextGenerationPipeline(Pipeline): preprocessor (TextGenerationPreprocessor): a preprocessor instance """ model = model if isinstance( - model, PalmForTextGeneration) else 
Model.from_pretrained(model) + model, TorchModel) else TorchModel.from_pretrained(model) if preprocessor is None: preprocessor = TextGenerationPreprocessor( model.model_dir, - model.tokenizer, first_sequence='sentence', second_sequence=None) model.eval() super().__init__(model=model, preprocessor=preprocessor, **kwargs) - self.tokenizer = model.tokenizer def forward(self, inputs: Dict[str, Any], **forward_params) -> Dict[str, Any]: with torch.no_grad(): - return super().forward(inputs, **forward_params) + return self.model.generate(inputs) def postprocess(self, inputs: Dict[str, Tensor], **postprocess_params) -> Dict[str, str]: @@ -55,20 +51,4 @@ class TextGenerationPipeline(Pipeline): Returns: Dict[str, str]: the prediction results """ - replace_tokens_bert = (('[unused0]', ''), ('[PAD]', ''), - ('[unused1]', ''), (r' +', ' '), ('[SEP]', ''), - ('[unused2]', ''), ('[CLS]', ''), ('[UNK]', '')) - replace_tokens_roberta = ((r' +', ' '), ('', ''), ('', - ''), - ('', ''), ('', ''), ('', ' ')) - - pred_list = inputs['predictions'] - pred_ids = pred_list[0][0].cpu().numpy().tolist() - pred_string = self.tokenizer.decode(pred_ids) - for _old, _new in replace_tokens_bert: - pred_string = pred_string.replace(_old, _new) - pred_string.strip() - for _old, _new in replace_tokens_roberta: - pred_string = pred_string.replace(_old, _new) - pred_string.strip() - return {OutputKeys.TEXT: pred_string} + return inputs diff --git a/modelscope/preprocessors/nlp.py b/modelscope/preprocessors/nlp.py index 0da17cb0..a0a7a5b5 100644 --- a/modelscope/preprocessors/nlp.py +++ b/modelscope/preprocessors/nlp.py @@ -2,7 +2,7 @@ import os.path as osp import uuid -from typing import Any, Dict, Union +from typing import Any, Dict, Optional, Union from transformers import AutoTokenizer @@ -211,36 +211,34 @@ class SentenceSimilarityFinetunePreprocessor(SentenceSimilarityPreprocessor): @PREPROCESSORS.register_module( - Fields.nlp, module_name=Preprocessors.palm_text_gen_tokenizer) + Fields.nlp, module_name=Preprocessors.text_gen_tokenizer) class TextGenerationPreprocessor(NLPPreprocessorBase): def __init__(self, model_dir: str, tokenizer=None, *args, **kwargs): self.tokenizer = self.build_tokenizer( model_dir) if tokenizer is None else tokenizer kwargs['truncation'] = True - kwargs['padding'] = 'max_length' + kwargs['padding'] = True kwargs['return_tensors'] = 'pt' kwargs['return_token_type_ids'] = False kwargs['max_length'] = kwargs.pop('sequence_length', 128) super().__init__(model_dir, *args, **kwargs) - def build_tokenizer(self, model_dir: str): + @staticmethod + def get_roberta_tokenizer_dir(model_dir: str) -> Optional[str]: import os - from sofa.models.palm_v2 import PalmConfig + for name in os.listdir(model_dir): + full_name = os.path.join(model_dir, name) + if 'roberta' in name and os.path.isdir(full_name): + return full_name - config_file = os.path.join(model_dir, 'config.json') - config = PalmConfig.from_json_file(config_file) if os.path.isfile( - config_file) else PalmConfig() - config.encoder_pth = os.path.join(model_dir, config.encoder_pth) - if config.encoder == 'roberta': + def build_tokenizer(self, model_dir: str): + roberta_tokenizer_dir = self.get_roberta_tokenizer_dir(model_dir) + if roberta_tokenizer_dir: from transformers import RobertaTokenizer - tokenizer = RobertaTokenizer.from_pretrained( - config.encoder_pth, do_lower_case=False) - elif config.encoder == 'bert' or config.encoder == 'zh_bert': - from transformers import BertTokenizer - tokenizer = BertTokenizer.from_pretrained( - config.encoder_pth, 
do_lower_case=True) - return tokenizer + return RobertaTokenizer.from_pretrained( + roberta_tokenizer_dir, do_lower_case=False) + return super().build_tokenizer(model_dir) @PREPROCESSORS.register_module( diff --git a/tests/pipelines/test_text_generation.py b/tests/pipelines/test_text_generation.py index 61faf20c..fd397de3 100644 --- a/tests/pipelines/test_text_generation.py +++ b/tests/pipelines/test_text_generation.py @@ -3,7 +3,7 @@ import unittest from modelscope.hub.snapshot_download import snapshot_download from modelscope.models import Model -from modelscope.models.nlp import PalmForTextGeneration +from modelscope.models.nlp import GPT3ForTextGeneration, PalmForTextGeneration from modelscope.pipelines import pipeline from modelscope.pipelines.nlp import TextGenerationPipeline from modelscope.preprocessors import TextGenerationPreprocessor @@ -12,26 +12,32 @@ from modelscope.utils.test_utils import test_level class TextGenerationTest(unittest.TestCase): - model_id_zh = 'damo/nlp_palm2.0_text-generation_chinese-base' - model_id_en = 'damo/nlp_palm2.0_text-generation_english-base' - input_zh = """ - 本文总结了十个可穿戴产品的设计原则,而这些原则,同样也是笔者认为是这个行业最吸引人的地方: - 1.为人们解决重复性问题;2.从人开始,而不是从机器开始;3.要引起注意,但不要刻意;4.提升用户能力,而不是取代 - """ - input_en = """ - The Director of Public Prosecutions who let off Lord Janner over alleged child sex abuse started - her career at a legal chambers when the disgraced Labour peer was a top QC there . Alison Saunders , - 54 , sparked outrage last week when she decided the 86-year-old should not face astring of charges - of paedophilia against nine children because he has dementia . Today , newly-released documents - revealed damning evidence that abuse was covered up by police andsocial workers for more than 20 years . - And now it has emerged Mrs Saunders ' law career got off to a flying start when she secured her - pupillage -- a barrister 's training contract at 1 Garden Court Chambers in London in 1983 . - """ + + def setUp(self) -> None: + self.palm_model_id_zh = 'damo/nlp_palm2.0_text-generation_chinese-base' + self.palm_model_id_en = 'damo/nlp_palm2.0_text-generation_english-base' + self.palm_input_zh = """ + 本文总结了十个可穿戴产品的设计原则,而这些原则,同样也是笔者认为是这个行业最吸引人的地方: + 1.为人们解决重复性问题;2.从人开始,而不是从机器开始;3.要引起注意,但不要刻意;4.提升用户能力,而不是取代 + """ + self.palm_input_en = """ + The Director of Public Prosecutions who let off Lord Janner over alleged child sex abuse started + her career at a legal chambers when the disgraced Labour peer was a top QC there . Alison Saunders , + 54 , sparked outrage last week when she decided the 86-year-old should not face astring of charges + of paedophilia against nine children because he has dementia . Today , newly-released documents + revealed damning evidence that abuse was covered up by police andsocial workers for more than 20 years . + And now it has emerged Mrs Saunders ' law career got off to a flying start when she secured her + pupillage -- a barrister 's training contract at 1 Garden Court Chambers in London in 1983 . 
+ """ + + self.gpt3_base_model_id = 'damo/nlp_gpt3_text-generation_chinese-base' + self.gpt3_large_model_id = 'damo/nlp_gpt3_text-generation_chinese-large' + self.gpt3_input = '我很好奇' @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') - def test_run(self): - for model_id, input in ((self.model_id_zh, self.input_zh), - (self.model_id_en, self.input_en)): + def test_run_palm(self): + for model_id, input in ((self.palm_model_id_zh, self.palm_input_zh), + (self.palm_model_id_en, self.palm_input_en)): cache_path = snapshot_download(model_id) model = PalmForTextGeneration(cache_path) preprocessor = TextGenerationPreprocessor( @@ -46,10 +52,28 @@ class TextGenerationTest(unittest.TestCase): f'pipeline1: {pipeline1(input)}\npipeline2: {pipeline2(input)}' ) + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_gpt3(self): + cache_path = snapshot_download(self.gpt3_base_model_id) + model = GPT3ForTextGeneration(cache_path) + preprocessor = TextGenerationPreprocessor( + cache_path, + model.tokenizer, + first_sequence='sentence', + second_sequence=None) + pipeline1 = TextGenerationPipeline(model, preprocessor) + pipeline2 = pipeline( + Tasks.text_generation, model=model, preprocessor=preprocessor) + print( + f'pipeline1: {pipeline1(self.gpt3_input)}\npipeline2: {pipeline2(self.gpt3_input)}' + ) + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_run_with_model_from_modelhub(self): - for model_id, input in ((self.model_id_zh, self.input_zh), - (self.model_id_en, self.input_en)): + for model_id, input in ((self.palm_model_id_zh, self.palm_input_zh), + (self.palm_model_id_en, self.palm_input_en), + (self.gpt3_base_model_id, self.gpt3_input), + (self.gpt3_large_model_id, self.gpt3_input)): model = Model.from_pretrained(model_id) preprocessor = TextGenerationPreprocessor( model.model_dir, @@ -62,17 +86,19 @@ class TextGenerationTest(unittest.TestCase): preprocessor=preprocessor) print(pipeline_ins(input)) - @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_run_with_model_name(self): - for model_id, input in ((self.model_id_zh, self.input_zh), - (self.model_id_en, self.input_en)): + for model_id, input in ((self.palm_model_id_zh, self.palm_input_zh), + (self.palm_model_id_en, self.palm_input_en), + (self.gpt3_base_model_id, self.gpt3_input), + (self.gpt3_large_model_id, self.gpt3_input)): pipeline_ins = pipeline(task=Tasks.text_generation, model=model_id) print(pipeline_ins(input)) @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_with_default_model(self): pipeline_ins = pipeline(task=Tasks.text_generation) - print(pipeline_ins(self.input_zh)) + print(pipeline_ins(self.palm_input_zh)) if __name__ == '__main__':