diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py
index 9c5ed709..75259f43 100644
--- a/modelscope/metainfo.py
+++ b/modelscope/metainfo.py
@@ -26,6 +26,7 @@ class Models(object):
space = 'space'
tcrf = 'transformer-crf'
bart = 'bart'
+ gpt3 = 'gpt3'
# audio models
sambert_hifigan = 'sambert-hifigan'
@@ -160,7 +161,7 @@ class Preprocessors(object):
# nlp preprocessor
sen_sim_tokenizer = 'sen-sim-tokenizer'
bert_seq_cls_tokenizer = 'bert-seq-cls-tokenizer'
- palm_text_gen_tokenizer = 'palm-text-gen-tokenizer'
+ text_gen_tokenizer = 'text-gen-tokenizer'
token_cls_tokenizer = 'token-cls-tokenizer'
ner_tokenizer = 'ner-tokenizer'
nli_tokenizer = 'nli-tokenizer'
diff --git a/modelscope/models/nlp/__init__.py b/modelscope/models/nlp/__init__.py
index 23041168..f2219b0e 100644
--- a/modelscope/models/nlp/__init__.py
+++ b/modelscope/models/nlp/__init__.py
@@ -4,7 +4,8 @@ from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
- from .backbones import (SbertModel, SpaceGenerator, SpaceModelBase)
+ from .backbones import (SbertModel, SpaceGenerator, SpaceModelBase,
+ GPT3Model)
from .heads import SequenceClassificationHead
from .bert_for_sequence_classification import BertForSequenceClassification
from .csanmt_for_translation import CsanmtForTranslation
@@ -23,10 +24,12 @@ if TYPE_CHECKING:
from .space_for_dialog_state_tracking import SpaceForDialogStateTracking
from .task_model import SingleBackboneTaskModelBase
from .bart_for_text_error_correction import BartForTextErrorCorrection
+ from .gpt3_for_text_generation import GPT3ForTextGeneration
else:
_import_structure = {
- 'backbones': ['SbertModel', 'SpaceGenerator', 'SpaceModelBase'],
+ 'backbones':
+ ['SbertModel', 'SpaceGenerator', 'SpaceModelBase', 'GPT3Model'],
'heads': ['SequenceClassificationHead'],
'csanmt_for_translation': ['CsanmtForTranslation'],
'bert_for_sequence_classification': ['BertForSequenceClassification'],
@@ -48,6 +51,7 @@ else:
'space_for_dialog_state_tracking': ['SpaceForDialogStateTracking'],
'task_model': ['SingleBackboneTaskModelBase'],
'bart_for_text_error_correction': ['BartForTextErrorCorrection'],
+ 'gpt3_for_text_generation': ['GPT3ForTextGeneration'],
}
import sys
diff --git a/modelscope/models/nlp/backbones/__init__.py b/modelscope/models/nlp/backbones/__init__.py
index a21c5d6f..ffe8ac05 100644
--- a/modelscope/models/nlp/backbones/__init__.py
+++ b/modelscope/models/nlp/backbones/__init__.py
@@ -6,10 +6,12 @@ from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
from .space import SpaceGenerator, SpaceModelBase
from .structbert import SbertModel
+ from .gpt3 import GPT3Model
else:
_import_structure = {
'space': ['SpaceGenerator', 'SpaceModelBase'],
- 'structbert': ['SbertModel']
+ 'structbert': ['SbertModel'],
+ 'gpt3': ['GPT3Model']
}
import sys
diff --git a/modelscope/models/nlp/backbones/gpt3/__init__.py b/modelscope/models/nlp/backbones/gpt3/__init__.py
new file mode 100644
index 00000000..b0739c22
--- /dev/null
+++ b/modelscope/models/nlp/backbones/gpt3/__init__.py
@@ -0,0 +1,23 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import TYPE_CHECKING
+
+from modelscope.utils.import_utils import LazyImportModule
+
+if TYPE_CHECKING:
+ from .configuration_gpt3 import GPT3Config
+ from .modeling_gpt3 import GPT3Model
+else:
+ _import_structure = {
+ 'configuration_gpt3': ['GPT3Config'],
+ 'modeling_gpt3': ['GPT3Model']
+ }
+
+ import sys
+
+ sys.modules[__name__] = LazyImportModule(
+ __name__,
+ globals()['__file__'],
+ _import_structure,
+ module_spec=__spec__,
+ extra_objects={},
+ )
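The lazy-import scaffolding mirrors the other NLP subpackages: the table above is only a name map, and nothing heavy runs at package-import time. A minimal sketch of the intended behavior, assuming modelscope and its torch/transformers dependencies are installed:

# Importing the symbol resolves it through LazyImportModule, so
# configuration_gpt3 (and its transformers dependency) is only loaded
# on first access; a bare `import modelscope` stays cheap.
from modelscope.models.nlp.backbones.gpt3 import GPT3Config

config = GPT3Config()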
diff --git a/modelscope/models/nlp/backbones/gpt3/configuration_gpt3.py b/modelscope/models/nlp/backbones/gpt3/configuration_gpt3.py
new file mode 100644
index 00000000..d5a054fd
--- /dev/null
+++ b/modelscope/models/nlp/backbones/gpt3/configuration_gpt3.py
@@ -0,0 +1,53 @@
+# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors.
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+
+class GPT3Config(PretrainedConfig):
+
+ model_type = 'gpt'
+
+ def __init__(self,
+ vocab_size=25600,
+ hidden_size=768,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ intermediate_size=3072,
+ hidden_act='gelu',
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=2048,
+ type_vocab_size=2,
+ layernorm_epsilon=1e-12,
+ initializer_range=0.02,
+ **kwargs):
+ super().__init__(layer_norm_eps=layernorm_epsilon, **kwargs)
+
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.hidden_act = hidden_act
+ self.intermediate_size = intermediate_size
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.max_position_embeddings = max_position_embeddings
+ self.type_vocab_size = type_vocab_size
+ self.layernorm_epsilon = layernorm_epsilon
+ self.initializer_range = initializer_range
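A construction sketch for the config; the override values are illustrative. `layernorm_epsilon` is also forwarded to PretrainedConfig as `layer_norm_eps`, and `initializer_range` is read later by GPT3Model._init_weights:

from modelscope.models.nlp.backbones.gpt3 import GPT3Config

config = GPT3Config(vocab_size=25600, hidden_size=1024, num_hidden_layers=24)
print(config.layernorm_epsilon)  # 1e-12
print(config.layer_norm_eps)     # 1e-12, set via super().__init__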
diff --git a/modelscope/models/nlp/backbones/gpt3/modeling_gpt3.py b/modelscope/models/nlp/backbones/gpt3/modeling_gpt3.py
new file mode 100644
index 00000000..f7024713
--- /dev/null
+++ b/modelscope/models/nlp/backbones/gpt3/modeling_gpt3.py
@@ -0,0 +1,337 @@
+# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import os
+from typing import Optional, Union
+
+import torch
+from addict import Dict
+from torch.nn import Dropout, Embedding, LayerNorm, Linear, Module, Softmax
+from torch.nn import functional as F
+from transformers.modeling_utils import PreTrainedModel
+
+from modelscope.utils.constant import ModelFile
+from .configuration_gpt3 import GPT3Config
+
+
+class GPT3SelfAttention(Module):
+ """Self-attention layer.
+
+ Self-attention layer takes input with size [b, s, h]
+ and returns output of the same size.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+
+ self.hidden_size = config.hidden_size
+ self.num_attention_heads = config.num_attention_heads
+ # Per attention head
+ self.hidden_size_per_attention_head = \
+ self.hidden_size // self.num_attention_heads
+
+ self.query_key_value = Linear(self.hidden_size, 3 * self.hidden_size)
+ self.softmax = Softmax(dim=-1)
+ self.attention_dropout = Dropout(config.attention_probs_dropout_prob)
+
+ # Output.
+ self.dense = Linear(self.hidden_size, self.hidden_size)
+ self.output_dropout = torch.nn.Dropout(config.hidden_dropout_prob)
+
+ def _transpose_for_scores(self, tensor):
+ """Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with
+ size [b, np, s, hn].
+ """
+ new_tensor_shape = tensor.size()[:-1] + (
+ self.num_attention_heads, self.hidden_size_per_attention_head)
+ tensor = tensor.view(*new_tensor_shape)
+ return tensor.permute(0, 2, 1, 3)
+
+ def _split_tensor_along_last_dim(self,
+ tensor,
+ num_partitions,
+ contiguous_split_chunks=False):
+ # Get the size and dimension.
+ last_dim = tensor.dim() - 1
+ last_dim_size = tensor.size()[last_dim] // num_partitions
+ # Split.
+ tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
+ # Note: torch.split does not create contiguous tensors by default.
+ if contiguous_split_chunks:
+ return tuple(chunk.contiguous() for chunk in tensor_list)
+
+ return tensor_list
+
+ def forward(self, hidden_states, ltor_mask, is_infer=False):
+ # hidden_states: [b, s, h]
+ # ltor_mask: [1, 1, s, s]
+
+ # Attention heads. [b, s, 3h]
+ tgt_len = hidden_states.size(1)
+ ltor_mask = torch.reshape(ltor_mask, [1, 1, tgt_len, tgt_len])
+ mixed_x_layer = self.query_key_value(hidden_states)
+ (mixed_query_layer, mixed_key_layer, mixed_value_layer) = \
+ self._split_tensor_along_last_dim(mixed_x_layer, 3)
+
+ # Reshape and transpose [b, np, s, hn]
+ query_layer = self._transpose_for_scores(mixed_query_layer)
+ key_layer = self._transpose_for_scores(mixed_key_layer)
+ value_layer = self._transpose_for_scores(mixed_value_layer)
+
+ previous_type = value_layer.type()
+
+ # Raw attention scores. [b, np, s, s]
+ attention_scores = torch.matmul(query_layer,
+ key_layer.transpose(-1, -2))
+ attention_scores = attention_scores / math.sqrt(
+ self.hidden_size_per_attention_head)
+ # Apply the left to right attention mask.
+ if is_infer:
+ src_len = key_layer.size(2)
+ ltor_mask = torch.tril(
+ torch.ones((1, tgt_len, src_len),
+ device=hidden_states.device)).view(
+ 1, 1, tgt_len, src_len).type(previous_type)
+ converted_mask = 10000.0 * (1.0 - ltor_mask)
+ attention_scores = (torch.mul(attention_scores, ltor_mask)
+ - converted_mask).type(previous_type)
+
+ # Attention probabilities. [b, np, s, s]
+ attention_probs = self.softmax(attention_scores)
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.attention_dropout(attention_probs)
+
+ # Context layer.
+ # [b, np, s, hn]
+ context_layer = torch.matmul(attention_probs, value_layer)
+ # [b, s, np, hn]
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (
+ self.hidden_size, )
+ # [b, s, h]
+ context_layer = context_layer.view(*new_context_layer_shape)
+
+ # Output. [b, s, h]
+ output = self.dense(context_layer)
+ output = self.output_dropout(output)
+
+ return output
+
+
+class GPT3MLP(Module):
+ """MLP.
+
+ The MLP takes an input of hidden size h, projects it to 4*h,
+ applies a GeLU nonlinearity, and projects the result back
+ to the h hidden dimension.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+
+ hidden_size = config.hidden_size
+ # Project to 4h.
+ self.dense_h_to_4h = Linear(hidden_size, 4 * hidden_size)
+ self.activation_func = F.gelu
+ # Project back to h.
+ self.dense_4h_to_h = Linear(4 * hidden_size, hidden_size)
+
+ self.dropout = Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states):
+
+ # [b, s, 4h]
+ intermediate_parallel = self.dense_h_to_4h(hidden_states)
+ intermediate_parallel = self.activation_func(intermediate_parallel)
+ # [b, s, h]
+ output = self.dense_4h_to_h(intermediate_parallel)
+ output = self.dropout(output)
+ return output
+
+
+class GPT3TransformerLayer(Module):
+ """A single transformer layer.
+
+ Transformer layer takes input with size [b, s, h] and returns an
+ output of the same size.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+
+ # Layernorm on the input data.
+ self.input_layernorm = LayerNorm(
+ config.hidden_size, eps=config.layernorm_epsilon)
+
+ # Self attention.
+ self.attention = GPT3SelfAttention(config)
+
+ # Layernorm on the attention output
+ self.post_attention_layernorm = LayerNorm(
+ config.hidden_size, eps=config.layernorm_epsilon)
+
+ # MLP
+ self.mlp = GPT3MLP(config)
+
+ def forward(self, hidden_states, ltor_mask):
+ # hidden_states: [b, s, h]
+ # ltor_mask: [1, 1, s, s]
+
+ # Layer norm at the beginning of the transformer layer.
+ layernorm_output = self.input_layernorm(hidden_states)
+ # Self attention.
+ attention_output = self.attention(layernorm_output, ltor_mask)
+ # Residual connection.
+ layernorm_input = hidden_states + attention_output
+ # Layer norm post the self attention.
+ layernorm_output = self.post_attention_layernorm(layernorm_input)
+ # MLP.
+ mlp_output = self.mlp(layernorm_output)
+ # Second residual connection.
+ output = layernorm_input + mlp_output
+
+ return output
+
+
+class GPT3Transformer(Module):
+ """Transformer class."""
+
+ def __init__(self, config):
+ super().__init__()
+
+ self.input_tensor = None
+
+ # Number of layers.
+ self.num_layers = config.num_hidden_layers
+
+ self.layers = torch.nn.ModuleList(
+ [GPT3TransformerLayer(config) for _ in range(self.num_layers)])
+
+ # Final layer norm before output.
+ self.final_layernorm = LayerNorm(
+ config.hidden_size, eps=config.layernorm_epsilon)
+
+ def _get_layer(self, layer_number):
+ return self.layers[layer_number]
+
+ def forward(self, hidden_states, attention_mask):
+ # hidden_states: [b, s, h]
+
+ for index in range(self.num_layers):
+ layer = self._get_layer(index)
+ hidden_states = layer(hidden_states, attention_mask)
+
+ # Final layer norm.
+ hidden_states = self.final_layernorm(hidden_states)
+
+ return hidden_states
+
+
+class GPT3TransformerLanguageModel(Module):
+ """Transformer language model.
+
+ Arguments:
+ config: a GPT3Config instance. The fields used here are
+ vocab_size: vocabulary size of the word embedding
+ hidden_size: dimension of the embeddings and hidden states
+ max_position_embeddings: maximum sequence length, used
+ for the learned positional embedding
+ hidden_dropout_prob: dropout probability applied to the
+ summed word and position embeddings
+ """
+
+ def __init__(self, config):
+ super().__init__()
+
+ # Embeddings.
+ self.word_embeddings = Embedding(config.vocab_size, config.hidden_size)
+ self.position_embeddings = Embedding(config.max_position_embeddings,
+ config.hidden_size)
+ self.embedding_dropout = Dropout(config.hidden_dropout_prob)
+
+ # Transformer.
+ self.transformer = GPT3Transformer(config)
+
+ def forward(self, input_ids, attention_mask, position_ids):
+ words_embeddings = self.word_embeddings(input_ids)
+ position_embeddings = self.position_embeddings(position_ids)
+
+ embeddings = words_embeddings + position_embeddings
+ transformer_input = self.embedding_dropout(embeddings)
+ transformer_output = self.transformer(transformer_input,
+ attention_mask)
+
+ logits = F.linear(transformer_output, self.word_embeddings.weight)
+ return logits
+
+
+class GPT3Model(PreTrainedModel):
+
+ config_class = GPT3Config
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, Linear):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(
+ mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, Embedding):
+ module.weight.data.normal_(
+ mean=0.0, std=self.config.initializer_range)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.language_model = GPT3TransformerLanguageModel(config)
+
+ def forward(self,
+ input_ids,
+ attention_mask=None,
+ position_ids=None,
+ **kwargs):
+ seq_length = input_ids.size(1)
+ if attention_mask is None:
+ attention_mask = torch.tril(
+ torch.ones((1, seq_length, seq_length),
+ dtype=torch.long,
+ device=input_ids.device))
+ if position_ids is None:
+ position_ids = torch.arange(
+ seq_length, dtype=torch.long, device=input_ids.device)
+ position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
+
+ logits = self.language_model(input_ids, attention_mask, position_ids)
+ return Dict(logits=logits)
+
+ @classmethod
+ def from_pretrained(
+ cls, pretrained_model_name_or_path: Optional[Union[str,
+ os.PathLike]]):
+ config = cls.config_class.from_pretrained(
+ pretrained_model_name_or_path)
+ model = cls(config)
+ state_dict_file = os.path.join(pretrained_model_name_or_path,
+ ModelFile.TORCH_MODEL_BIN_FILE)
+ state_dict = torch.load(state_dict_file)
+ model.load_state_dict(state_dict)
+ return model
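A shape smoke test for the backbone, as a sketch with a deliberately tiny, illustrative config; forward() builds the causal mask and position ids itself when they are not supplied:

import torch

from modelscope.models.nlp.backbones.gpt3 import GPT3Config, GPT3Model

config = GPT3Config(vocab_size=128, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=4, max_position_embeddings=64)
model = GPT3Model(config).eval()  # eval() disables the dropout layers

input_ids = torch.randint(0, config.vocab_size, (2, 10))  # [b, s]
with torch.no_grad():
    outputs = model(input_ids)

# Output logits are tied to the word embedding: [b, s, vocab_size]
print(outputs['logits'].shape)  # torch.Size([2, 10, 128])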
diff --git a/modelscope/models/nlp/backbones/structbert/__init__.py b/modelscope/models/nlp/backbones/structbert/__init__.py
index 7db035d8..1d147730 100644
--- a/modelscope/models/nlp/backbones/structbert/__init__.py
+++ b/modelscope/models/nlp/backbones/structbert/__init__.py
@@ -1 +1,19 @@
-from .modeling_sbert import SbertModel
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import TYPE_CHECKING
+
+from modelscope.utils.import_utils import LazyImportModule
+
+if TYPE_CHECKING:
+ from .modeling_sbert import SbertModel
+else:
+ _import_structure = {'modeling_sbert': ['SbertModel']}
+
+ import sys
+
+ sys.modules[__name__] = LazyImportModule(
+ __name__,
+ globals()['__file__'],
+ _import_structure,
+ module_spec=__spec__,
+ extra_objects={},
+ )
diff --git a/modelscope/models/nlp/gpt3_for_text_generation.py b/modelscope/models/nlp/gpt3_for_text_generation.py
new file mode 100644
index 00000000..22a6458d
--- /dev/null
+++ b/modelscope/models/nlp/gpt3_for_text_generation.py
@@ -0,0 +1,56 @@
+from typing import Dict
+
+from modelscope.metainfo import Models
+from modelscope.models.base import Tensor, TorchModel
+from modelscope.models.builder import MODELS
+from modelscope.outputs import OutputKeys
+from modelscope.utils.constant import Tasks
+
+__all__ = ['GPT3ForTextGeneration']
+
+
+@MODELS.register_module(Tasks.text_generation, module_name=Models.gpt3)
+class GPT3ForTextGeneration(TorchModel):
+
+ def __init__(self, model_dir: str, *args, **kwargs):
+ """Initialize the text generation model from the `model_dir` path.
+
+ Args:
+ model_dir (str): the model path.
+ """
+ super().__init__(model_dir, *args, **kwargs)
+
+ from modelscope.models.nlp import GPT3Model
+ from transformers import BertTokenizer
+
+ self.model = GPT3Model.from_pretrained(model_dir)
+ self.tokenizer = BertTokenizer.from_pretrained(model_dir)
+
+ def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
+ """Return the result produced by the model.
+
+ Args:
+ input (Dict[str, Tensor]): the preprocessed data
+
+ Returns:
+ Dict[str, Tensor]: results
+ Example:
+ {
+ 'logits': Tensor([[0.54, 0.32...]]), # the model logits
+ }
+ """
+ return self.model(**input)
+
+ def generate(self, input: Dict[str, Tensor]) -> Dict[str, str]:
+ assert 'input_ids' in input, "generate requires an 'input_ids' key in its input"
+ gen_params = dict()
+ gen_params['inputs'] = input['input_ids']
+ gen_params['do_sample'] = input.pop('do_sample', True)
+ gen_params['max_length'] = input.pop('max_length', 128)
+ gen_params['top_k'] = input.pop('top_k', 10)
+ gen_params['top_p'] = input.pop('top_p', None)
+ sample_output = self.model.generate(**gen_params)
+ return {
+ OutputKeys.TEXT:
+ self.tokenizer.decode(sample_output[0], skip_special_tokens=True)
+ }
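A usage sketch for the wrapper, assuming a checkpoint directory that holds config.json, pytorch_model.bin and BERT tokenizer files; the model id is the one exercised in the tests below, and decoding goes through the generate() that GPT3Model inherits from transformers:

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models.nlp import GPT3ForTextGeneration

model_dir = snapshot_download('damo/nlp_gpt3_text-generation_chinese-base')
model = GPT3ForTextGeneration(model_dir)

# Only 'input_ids' is required; do_sample, max_length, top_k and top_p
# fall back to the defaults above (True, 128, 10, None).
inputs = model.tokenizer('我很好奇', return_tensors='pt')
print(model.generate({'input_ids': inputs['input_ids']}))  # {'text': '...'}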
diff --git a/modelscope/models/nlp/palm_for_text_generation.py b/modelscope/models/nlp/palm_for_text_generation.py
index 245a5fdb..23d60663 100644
--- a/modelscope/models/nlp/palm_for_text_generation.py
+++ b/modelscope/models/nlp/palm_for_text_generation.py
@@ -1,8 +1,9 @@
-from typing import Dict
+from typing import Dict, List
from modelscope.metainfo import Models
from modelscope.models.base import Tensor, TorchModel
from modelscope.models.builder import MODELS
+from modelscope.outputs import OutputKeys
from modelscope.utils.constant import Tasks
__all__ = ['PalmForTextGeneration']
@@ -27,8 +28,7 @@ class PalmForTextGeneration(TorchModel):
self.tokenizer = self.model.tokenizer
self.generator = Translator(self.model)
- def _evaluate_postprocess(self, src: Tensor, tgt: Tensor,
- mask_src: Tensor) -> Dict[str, str]:
+ def _evaluate_postprocess(self, ids_list: List[List[int]]) -> List[str]:
replace_tokens_bert = (('[unused0]', ''), ('[PAD]', ''),
('[unused1]', ''), (r' +', ' '), ('[SEP]', ''),
('[unused2]', ''), ('[CLS]', ''), ('[UNK]', ''))
@@ -36,29 +36,14 @@ class PalmForTextGeneration(TorchModel):
''),
('<s>', ''), ('</s>', ''), ('<unk>', ' '))
- inputs = self.generator(src, mask_src)
- pred_list = inputs['predictions']
- pred_id_list = [
- pred_batch[0].cpu().numpy().tolist() for pred_batch in pred_list
- ]
- tgt_id_list = tgt.cpu().numpy().tolist()
- pred_strings = [
- self.tokenizer.decode(pred_ids) for pred_ids in pred_id_list
- ]
- tgt_strings = [
- self.tokenizer.decode(tgt_ids) for tgt_ids in tgt_id_list
- ]
+ strings = [self.tokenizer.decode(pred_ids) for pred_ids in ids_list]
for _old, _new in replace_tokens_bert:
- pred_strings = [s.replace(_old, _new) for s in pred_strings]
- tgt_strings = [s.replace(_old, _new) for s in tgt_strings]
+ strings = [s.replace(_old, _new) for s in strings]
for _old, _new in replace_tokens_roberta:
- pred_strings = [s.replace(_old, _new) for s in pred_strings]
- tgt_strings = [s.replace(_old, _new) for s in tgt_strings]
- for s in pred_strings:
+ strings = [s.replace(_old, _new) for s in strings]
- s.strip()
+ strings = [s.strip() for s in strings]
- for s in tgt_strings:
- s.strip()
- return {'preds': pred_strings, 'tgts': tgt_strings}
+ return strings
def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
"""return the result by the model
@@ -70,12 +55,30 @@ class PalmForTextGeneration(TorchModel):
Dict[str, Tensor]: results
Example:
{
- 'predictions': Tensor([[1377, 4959, 2785, 6392...])]), # tokens need to be decode by tokenizer
+ 'loss': Tensor([12.34]), # loss for backward
+ }
+ or
+ {
+ 'preds': List["hello word"...] # the predicted strings
+ 'tgts': List["hello world"...] # target strings
}
"""
if self.training:
return {'loss': self.model(**input)}
- elif 'tgt' in input:
- return self._evaluate_postprocess(**input)
else:
- return self.generator(**input)
+ outputs = self.generator(input['src'], input['mask_src'])
+ preds = outputs['predictions']
+ pred_ids_list = [
+ pred_batch[0].cpu().numpy().tolist() for pred_batch in preds
+ ]
+ tgt_ids_list = input['tgt'].cpu().numpy().tolist()
+ return {
+ 'preds': self._evaluate_postprocess(pred_ids_list),
+ 'tgts': self._evaluate_postprocess(tgt_ids_list)
+ }
+
+ def generate(self, input: Dict[str, Tensor]) -> Dict[str, str]:
+ outputs = self.generator(**input)
+ preds = outputs['predictions']
+ pred_ids_list = [preds[0][0].cpu().numpy().tolist()]
+ return {OutputKeys.TEXT: self._evaluate_postprocess(pred_ids_list)[0]}
diff --git a/modelscope/pipelines/nlp/text_generation_pipeline.py b/modelscope/pipelines/nlp/text_generation_pipeline.py
index 5cf75314..85a81eba 100644
--- a/modelscope/pipelines/nlp/text_generation_pipeline.py
+++ b/modelscope/pipelines/nlp/text_generation_pipeline.py
@@ -3,9 +3,7 @@ from typing import Any, Dict, Optional, Union
import torch
from modelscope.metainfo import Pipelines
-from modelscope.models import Model
-from modelscope.models.nlp import PalmForTextGeneration
-from modelscope.outputs import OutputKeys
+from modelscope.models.base import TorchModel
from modelscope.pipelines.base import Pipeline, Tensor
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import TextGenerationPreprocessor
@@ -19,7 +17,7 @@ __all__ = ['TextGenerationPipeline']
class TextGenerationPipeline(Pipeline):
def __init__(self,
- model: Union[PalmForTextGeneration, str],
+ model: Union[TorchModel, str],
preprocessor: Optional[TextGenerationPreprocessor] = None,
**kwargs):
"""Use `model` and `preprocessor` to create an NLP text generation pipeline for prediction.
@@ -29,21 +27,19 @@ class TextGenerationPipeline(Pipeline):
preprocessor (TextGenerationPreprocessor): a preprocessor instance
"""
model = model if isinstance(
- model, PalmForTextGeneration) else Model.from_pretrained(model)
+ model, TorchModel) else TorchModel.from_pretrained(model)
if preprocessor is None:
preprocessor = TextGenerationPreprocessor(
model.model_dir,
- model.tokenizer,
first_sequence='sentence',
second_sequence=None)
model.eval()
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
- self.tokenizer = model.tokenizer
def forward(self, inputs: Dict[str, Any],
**forward_params) -> Dict[str, Any]:
with torch.no_grad():
- return super().forward(inputs, **forward_params)
+ return self.model.generate(inputs)
def postprocess(self, inputs: Dict[str, Tensor],
**postprocess_params) -> Dict[str, str]:
@@ -55,20 +51,4 @@ class TextGenerationPipeline(Pipeline):
Returns:
Dict[str, str]: the prediction results
"""
- replace_tokens_bert = (('[unused0]', ''), ('[PAD]', ''),
- ('[unused1]', ''), (r' +', ' '), ('[SEP]', ''),
- ('[unused2]', ''), ('[CLS]', ''), ('[UNK]', ''))
- replace_tokens_roberta = ((r' +', ' '), ('<mask>', '<q>'), ('<pad>',
- ''),
- ('<s>', ''), ('</s>', ''), ('<unk>', ' '))
-
- pred_list = inputs['predictions']
- pred_ids = pred_list[0][0].cpu().numpy().tolist()
- pred_string = self.tokenizer.decode(pred_ids)
- for _old, _new in replace_tokens_bert:
- pred_string = pred_string.replace(_old, _new)
- pred_string.strip()
- for _old, _new in replace_tokens_roberta:
- pred_string = pred_string.replace(_old, _new)
- pred_string.strip()
- return {OutputKeys.TEXT: pred_string}
+ return inputs
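The pipeline is now model-agnostic end to end: forward() delegates to the model's generate(), which already returns {OutputKeys.TEXT: ...}, so postprocess simply passes the dict through. A sketch using the GPT3 model id from the tests below:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

pipe = pipeline(Tasks.text_generation,
                model='damo/nlp_gpt3_text-generation_chinese-base')
print(pipe('我很好奇'))  # {'text': '<sampled continuation>'}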
diff --git a/modelscope/preprocessors/nlp.py b/modelscope/preprocessors/nlp.py
index 0da17cb0..a0a7a5b5 100644
--- a/modelscope/preprocessors/nlp.py
+++ b/modelscope/preprocessors/nlp.py
@@ -2,7 +2,7 @@
import os.path as osp
import uuid
-from typing import Any, Dict, Union
+from typing import Any, Dict, Optional, Union
from transformers import AutoTokenizer
@@ -211,36 +211,34 @@ class SentenceSimilarityFinetunePreprocessor(SentenceSimilarityPreprocessor):
@PREPROCESSORS.register_module(
- Fields.nlp, module_name=Preprocessors.palm_text_gen_tokenizer)
+ Fields.nlp, module_name=Preprocessors.text_gen_tokenizer)
class TextGenerationPreprocessor(NLPPreprocessorBase):
def __init__(self, model_dir: str, tokenizer=None, *args, **kwargs):
self.tokenizer = self.build_tokenizer(
model_dir) if tokenizer is None else tokenizer
kwargs['truncation'] = True
- kwargs['padding'] = 'max_length'
+ kwargs['padding'] = True
kwargs['return_tensors'] = 'pt'
kwargs['return_token_type_ids'] = False
kwargs['max_length'] = kwargs.pop('sequence_length', 128)
super().__init__(model_dir, *args, **kwargs)
- def build_tokenizer(self, model_dir: str):
+ @staticmethod
+ def get_roberta_tokenizer_dir(model_dir: str) -> Optional[str]:
import os
- from sofa.models.palm_v2 import PalmConfig
+ for name in os.listdir(model_dir):
+ full_name = os.path.join(model_dir, name)
+ if 'roberta' in name and os.path.isdir(full_name):
+ return full_name
- config_file = os.path.join(model_dir, 'config.json')
- config = PalmConfig.from_json_file(config_file) if os.path.isfile(
- config_file) else PalmConfig()
- config.encoder_pth = os.path.join(model_dir, config.encoder_pth)
- if config.encoder == 'roberta':
+ def build_tokenizer(self, model_dir: str):
+ roberta_tokenizer_dir = self.get_roberta_tokenizer_dir(model_dir)
+ if roberta_tokenizer_dir:
from transformers import RobertaTokenizer
- tokenizer = RobertaTokenizer.from_pretrained(
- config.encoder_pth, do_lower_case=False)
- elif config.encoder == 'bert' or config.encoder == 'zh_bert':
- from transformers import BertTokenizer
- tokenizer = BertTokenizer.from_pretrained(
- config.encoder_pth, do_lower_case=True)
- return tokenizer
+ return RobertaTokenizer.from_pretrained(
+ roberta_tokenizer_dir, do_lower_case=False)
+ return super().build_tokenizer(model_dir)
@PREPROCESSORS.register_module(
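A sketch of the new tokenizer resolution; the directory below is a placeholder. A subdirectory whose name contains 'roberta' selects RobertaTokenizer, and anything else falls back to the base class, which covers the GPT3 checkpoints shipping BERT vocabularies:

from modelscope.preprocessors import TextGenerationPreprocessor

model_dir = '/path/to/local/checkpoint'  # placeholder, not a real path
preprocessor = TextGenerationPreprocessor(
    model_dir, first_sequence='sentence', second_sequence=None)

# padding=True pads per batch instead of always to max_length;
# truncation still caps inputs at 128 tokens.
features = preprocessor('我很好奇')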
diff --git a/tests/pipelines/test_text_generation.py b/tests/pipelines/test_text_generation.py
index 61faf20c..fd397de3 100644
--- a/tests/pipelines/test_text_generation.py
+++ b/tests/pipelines/test_text_generation.py
@@ -3,7 +3,7 @@ import unittest
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
-from modelscope.models.nlp import PalmForTextGeneration
+from modelscope.models.nlp import GPT3ForTextGeneration, PalmForTextGeneration
from modelscope.pipelines import pipeline
from modelscope.pipelines.nlp import TextGenerationPipeline
from modelscope.preprocessors import TextGenerationPreprocessor
@@ -12,26 +12,32 @@ from modelscope.utils.test_utils import test_level
class TextGenerationTest(unittest.TestCase):
- model_id_zh = 'damo/nlp_palm2.0_text-generation_chinese-base'
- model_id_en = 'damo/nlp_palm2.0_text-generation_english-base'
- input_zh = """
- 本文总结了十个可穿戴产品的设计原则,而这些原则,同样也是笔者认为是这个行业最吸引人的地方:
- 1.为人们解决重复性问题;2.从人开始,而不是从机器开始;3.要引起注意,但不要刻意;4.提升用户能力,而不是取代
- """
- input_en = """
- The Director of Public Prosecutions who let off Lord Janner over alleged child sex abuse started
- her career at a legal chambers when the disgraced Labour peer was a top QC there . Alison Saunders ,
- 54 , sparked outrage last week when she decided the 86-year-old should not face astring of charges
- of paedophilia against nine children because he has dementia . Today , newly-released documents
- revealed damning evidence that abuse was covered up by police andsocial workers for more than 20 years .
- And now it has emerged Mrs Saunders ' law career got off to a flying start when she secured her
- pupillage -- a barrister 's training contract at 1 Garden Court Chambers in London in 1983 .
- """
+
+ def setUp(self) -> None:
+ self.palm_model_id_zh = 'damo/nlp_palm2.0_text-generation_chinese-base'
+ self.palm_model_id_en = 'damo/nlp_palm2.0_text-generation_english-base'
+ self.palm_input_zh = """
+ 本文总结了十个可穿戴产品的设计原则,而这些原则,同样也是笔者认为是这个行业最吸引人的地方:
+ 1.为人们解决重复性问题;2.从人开始,而不是从机器开始;3.要引起注意,但不要刻意;4.提升用户能力,而不是取代
+ """
+ self.palm_input_en = """
+ The Director of Public Prosecutions who let off Lord Janner over alleged child sex abuse started
+ her career at a legal chambers when the disgraced Labour peer was a top QC there . Alison Saunders ,
+ 54 , sparked outrage last week when she decided the 86-year-old should not face astring of charges
+ of paedophilia against nine children because he has dementia . Today , newly-released documents
+ revealed damning evidence that abuse was covered up by police andsocial workers for more than 20 years .
+ And now it has emerged Mrs Saunders ' law career got off to a flying start when she secured her
+ pupillage -- a barrister 's training contract at 1 Garden Court Chambers in London in 1983 .
+ """
+
+ self.gpt3_base_model_id = 'damo/nlp_gpt3_text-generation_chinese-base'
+ self.gpt3_large_model_id = 'damo/nlp_gpt3_text-generation_chinese-large'
+ self.gpt3_input = '我很好奇'
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
- def test_run(self):
- for model_id, input in ((self.model_id_zh, self.input_zh),
- (self.model_id_en, self.input_en)):
+ def test_run_palm(self):
+ for model_id, input in ((self.palm_model_id_zh, self.palm_input_zh),
+ (self.palm_model_id_en, self.palm_input_en)):
cache_path = snapshot_download(model_id)
model = PalmForTextGeneration(cache_path)
preprocessor = TextGenerationPreprocessor(
@@ -46,10 +52,28 @@ class TextGenerationTest(unittest.TestCase):
f'pipeline1: {pipeline1(input)}\npipeline2: {pipeline2(input)}'
)
+ @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+ def test_run_gpt3(self):
+ cache_path = snapshot_download(self.gpt3_base_model_id)
+ model = GPT3ForTextGeneration(cache_path)
+ preprocessor = TextGenerationPreprocessor(
+ cache_path,
+ model.tokenizer,
+ first_sequence='sentence',
+ second_sequence=None)
+ pipeline1 = TextGenerationPipeline(model, preprocessor)
+ pipeline2 = pipeline(
+ Tasks.text_generation, model=model, preprocessor=preprocessor)
+ print(
+ f'pipeline1: {pipeline1(self.gpt3_input)}\npipeline2: {pipeline2(self.gpt3_input)}'
+ )
+
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_model_from_modelhub(self):
- for model_id, input in ((self.model_id_zh, self.input_zh),
- (self.model_id_en, self.input_en)):
+ for model_id, input in ((self.palm_model_id_zh, self.palm_input_zh),
+ (self.palm_model_id_en, self.palm_input_en),
+ (self.gpt3_base_model_id, self.gpt3_input),
+ (self.gpt3_large_model_id, self.gpt3_input)):
model = Model.from_pretrained(model_id)
preprocessor = TextGenerationPreprocessor(
model.model_dir,
@@ -62,17 +86,19 @@ class TextGenerationTest(unittest.TestCase):
preprocessor=preprocessor)
print(pipeline_ins(input))
- @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_name(self):
- for model_id, input in ((self.model_id_zh, self.input_zh),
- (self.model_id_en, self.input_en)):
+ for model_id, input in ((self.palm_model_id_zh, self.palm_input_zh),
+ (self.palm_model_id_en, self.palm_input_en),
+ (self.gpt3_base_model_id, self.gpt3_input),
+ (self.gpt3_large_model_id, self.gpt3_input)):
pipeline_ins = pipeline(task=Tasks.text_generation, model=model_id)
print(pipeline_ins(input))
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_default_model(self):
pipeline_ins = pipeline(task=Tasks.text_generation)
- print(pipeline_ins(self.input_zh))
+ print(pipeline_ins(self.palm_input_zh))
if __name__ == '__main__':