1. 修复 bloom 和 gpt_neo 模型更新 transformers 4.23 后后处理报错的问题
2. 统一使用 ModelOutput 作为模型输出
3. gpt_neo checkpoint 已上线,修改 ut 为 level2
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10553103
master
| @@ -8,7 +8,6 @@ from torch import nn | |||
| from modelscope.metainfo import Heads | |||
| from modelscope.models.base import TorchHead | |||
| from modelscope.models.builder import HEADS | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.utils.constant import Tasks | |||
| @@ -27,9 +26,8 @@ class TextGenerationHead(TorchHead): | |||
| def forward(self, inputs=None): | |||
| logits = self.linear(inputs) | |||
| return {OutputKeys.LOGITS: logits} | |||
| return logits | |||
| def compute_loss(self, outputs: Dict[str, torch.Tensor], | |||
| def compute_loss(self, logits: torch.Tensor, | |||
| labels) -> Dict[str, torch.Tensor]: | |||
| logits = outputs[OutputKeys.LOGITS] | |||
| return {OutputKeys.LOSS: F.cross_entropy(logits, labels)} | |||
| return F.cross_entropy(logits, labels) | |||
| @@ -1,7 +1,6 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from typing import Any, Dict | |||
| import addict | |||
| import numpy as np | |||
| from transformers.modeling_utils import PreTrainedModel | |||
| @@ -9,7 +8,8 @@ from modelscope.metainfo import TaskModels | |||
| from modelscope.models.builder import MODELS | |||
| from modelscope.models.nlp.task_models.task_model import \ | |||
| SingleBackboneTaskModelBase | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.outputs import (OutputKeys, TextGenerationModelOutput, | |||
| TokenGeneratorOutput) | |||
| from modelscope.utils.constant import Tasks | |||
| __all__ = ['TaskModelForTextGeneration'] | |||
| @@ -43,12 +43,12 @@ class TaskModelForTextGeneration(SingleBackboneTaskModelBase, PreTrainedModel): | |||
| backbone_outputs = super().forward(input) | |||
| hidden_states = backbone_outputs[0] | |||
| outputs = self.head.forward(hidden_states) | |||
| logits = self.head.forward(hidden_states) | |||
| loss = None | |||
| if labels is not None: | |||
| input[OutputKeys.LABELS] = labels | |||
| loss = self.compute_loss(outputs, labels) | |||
| outputs.update(loss) | |||
| return addict.Dict(outputs) | |||
| loss = self.compute_loss(logits, labels) | |||
| return TextGenerationModelOutput(logits=logits, loss=loss) | |||
| def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): | |||
| # only last token for inputs_ids if past is defined in kwargs | |||
| @@ -76,4 +76,12 @@ class TaskModelForTextGeneration(SingleBackboneTaskModelBase, PreTrainedModel): | |||
| def generate(self, inputs, *args, **kwargs): | |||
| input_ids = inputs['input_ids'] if isinstance(inputs, Dict) else inputs | |||
| return super().generate(input_ids, *args, **kwargs) | |||
| generate_output = super().generate(input_ids, *args, **kwargs) | |||
| if isinstance(generate_output, Dict): | |||
| return TokenGeneratorOutput( | |||
| sequences=generate_output.sequences, | |||
| scores=generate_output.scores, | |||
| attentions=generate_output.attentions, | |||
| hidden_states=generate_output.hidden_states) | |||
| else: | |||
| return TokenGeneratorOutput(sequences=generate_output) | |||
| @@ -541,3 +541,50 @@ class Seq2SeqLMOutput(ModelOutputBase): | |||
| encoder_last_hidden_state: Optional[Tensor] = None | |||
| encoder_hidden_states: Optional[Tuple[Tensor]] = None | |||
| encoder_attentions: Optional[Tuple[Tensor]] = None | |||
| @dataclass | |||
| class TextGenerationModelOutput(ModelOutputBase): | |||
| """The output class for text generation models. | |||
| Args: | |||
| logits (`Tensor`): The logits output of the model. loss (`Tensor`, | |||
| *optional*) The loss of the model, available when training. | |||
| hidden_states (`Tensor`, *optional*) Hidden-states of the model at the | |||
| output of each layer plus the optional initial embedding outputs. | |||
| """ | |||
| logits: Tensor = None | |||
| loss: Tensor = None | |||
| @dataclass | |||
| class TokenGeneratorOutput(ModelOutputBase): | |||
| """ | |||
| The output class for generate method of text generation models. | |||
| Args: | |||
| sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`): | |||
| The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter | |||
| if all batches finished early due to the `eos_token_id`. | |||
| scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` | |||
| is passed or when `config.output_scores=True`): | |||
| Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) | |||
| at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for | |||
| each generated token), with each tensor of shape `(batch_size*num_return_sequences, config.vocab_size)`. | |||
| attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` | |||
| is passed or `config.output_attentions=True`): | |||
| Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of | |||
| `torch.FloatTensor` of shape `(num_return_sequences*batch_size, num_heads, generated_length, | |||
| sequence_length)`. | |||
| hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` | |||
| is passed or when `config.output_hidden_states=True`): | |||
| Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of | |||
| `torch.FloatTensor` of shape `(num_return_sequences*batch_size, generated_length, hidden_size)`. | |||
| """ | |||
| sequences: Tensor = None | |||
| scores: Optional[Tuple[Tensor]] = None | |||
| attentions: Optional[Tuple[Tuple[Tensor]]] = None | |||
| hidden_states: Optional[Tuple[Tuple[Tensor]]] = None | |||
| @@ -104,6 +104,10 @@ class TextGenerationPipeline(Pipeline): | |||
| tokenizer = self.preprocessor.tokenizer | |||
| return tokenizer.decode(inputs.tolist(), skip_special_tokens=True) | |||
| def sentence_piece(self, inputs) -> str: | |||
| tokenizer = self.preprocessor.tokenizer | |||
| return tokenizer.decode(inputs.tolist()) | |||
| def roberta(self, inputs) -> str: | |||
| tokenizer = self.preprocessor.tokenizer | |||
| decoded = tokenizer.decode(inputs.tolist()) | |||
| @@ -121,7 +125,7 @@ class TextGenerationPipeline(Pipeline): | |||
| Dict[str, str]: the prediction results | |||
| """ | |||
| inputs = inputs['sequences'] | |||
| if isinstance(inputs, list): | |||
| if isinstance(inputs, list) or len(inputs.shape) > 1: | |||
| inputs = inputs[0] | |||
| decoded = getattr(self, self.postprocessor)(inputs) | |||
| text = self._remove_space_between_chinese_chars(decoded) | |||
| @@ -183,7 +183,7 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck): | |||
| task=Tasks.text_generation, model='langboat/bloom-1b4-zh') | |||
| print(pipe('中国的首都是')) | |||
| @unittest.skip("Langboat's checkpoint has not been uploaded to modelhub") | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_gpt_neo(self): | |||
| pipe = pipeline( | |||
| task=Tasks.text_generation, model='langboat/mengzi-gpt-neo-base') | |||