From 3b1f1a0252d4fee7ecd15ac8dc7c04ec0535add0 Mon Sep 17 00:00:00 2001
From: "hemu.zp"
Date: Tue, 18 Oct 2022 15:58:33 +0800
Subject: [PATCH] [to #42322933] Add GPT3 tensor parallel inference
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Add GPT3 tensor-parallel inference code based on Megatron-v3.
Reuses DistributedPipeline and megatron-util.
Applicable models: GPT-3 pretrained generation models with 1.3B/2.7B/13B parameters.

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10416721
---
 modelscope/metainfo.py                        |    2 +
 modelscope/models/nlp/gpt3/__init__.py        |    2 +
 .../models/nlp/gpt3/configuration_gpt3.py     |   90 +-
 .../models/nlp/gpt3/distributed_gpt3.py       | 1057 +++++++++++++++++
 modelscope/models/nlp/gpt3/modeling_gpt3.py   |   54 +-
 modelscope/models/nlp/gpt3/tokenizer_gpt3.py  |   69 ++
 .../nlp/distributed_gpt3_pipeline.py          |   54 +
 modelscope/preprocessors/__init__.py          |    2 +
 modelscope/preprocessors/nlp/__init__.py      |    2 +
 modelscope/preprocessors/nlp/nlp_base.py      |   35 +
 modelscope/utils/nlp/distributed.py           |    5 +-
 tests/pipelines/test_gpt3_text_generation.py  |   58 +
 12 files changed, 1388 insertions(+), 42 deletions(-)
 create mode 100644 modelscope/models/nlp/gpt3/distributed_gpt3.py
 create mode 100644 modelscope/models/nlp/gpt3/tokenizer_gpt3.py
 create mode 100644 modelscope/pipelines/nlp/distributed_gpt3_pipeline.py
 create mode 100644 tests/pipelines/test_gpt3_text_generation.py

diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py
index e4a26303..2dbff948 100644
--- a/modelscope/metainfo.py
+++ b/modelscope/metainfo.py
@@ -227,6 +227,7 @@ class Pipelines(object):
     zero_shot_classification = 'zero-shot-classification'
     text_error_correction = 'text-error-correction'
     plug_generation = 'plug-generation'
+    gpt3_generation = 'gpt3-generation'
     faq_question_answering = 'faq-question-answering'
     conversational_text_to_sql = 'conversational-text-to-sql'
     table_question_answering_pipeline = 'table-question-answering-pipeline'
@@ -324,6 +325,7 @@ class Preprocessors(object):
     bert_seq_cls_tokenizer = 'bert-seq-cls-tokenizer'
     text_gen_tokenizer = 'text-gen-tokenizer'
     text2text_gen_preprocessor = 'text2text-gen-preprocessor'
+    text_gen_jieba_tokenizer = 'text-gen-jieba-tokenizer'
     text2text_translate_preprocessor = 'text2text-translate-preprocessor'
     token_cls_tokenizer = 'token-cls-tokenizer'
     ner_tokenizer = 'ner-tokenizer'
diff --git a/modelscope/models/nlp/gpt3/__init__.py b/modelscope/models/nlp/gpt3/__init__.py
index 076a0c6b..9cae8cc8 100644
--- a/modelscope/models/nlp/gpt3/__init__.py
+++ b/modelscope/models/nlp/gpt3/__init__.py
@@ -7,11 +7,13 @@ if TYPE_CHECKING:
     from .configuration_gpt3 import GPT3Config
     from .modeling_gpt3 import GPT3Model
     from .gpt3_for_text_generation import GPT3ForTextGeneration
+    from .tokenizer_gpt3 import JiebaBPETokenizer
 else:
     _import_structure = {
         'configuration_gpt3': ['GPT3Config'],
         'modeling_gpt3': ['GPT3Model'],
         'gpt3_for_text_generation': ['GPT3ForTextGeneration'],
+        'tokenizer_gpt3': ['JiebaBPETokenizer'],
     }
 
 import sys
diff --git a/modelscope/models/nlp/gpt3/configuration_gpt3.py b/modelscope/models/nlp/gpt3/configuration_gpt3.py
index d5a054fd..66e8b836 100644
--- a/modelscope/models/nlp/gpt3/configuration_gpt3.py
+++ b/modelscope/models/nlp/gpt3/configuration_gpt3.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import torch from transformers.configuration_utils import PretrainedConfig from transformers.utils import logging @@ -21,25 +22,48 @@ logger = logging.get_logger(__name__) class GPT3Config(PretrainedConfig): - model_type = 'gpt' - - def __init__(self, - vocab_size=25600, - hidden_size=768, - num_hidden_layers=12, - num_attention_heads=12, - intermediate_size=3072, - hidden_act='gelu', - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - max_position_embeddings=2048, - type_vocab_size=2, - layernorm_epsilon=1e-12, - **kwargs): + model_type = 'gpt3' + + def __init__( + self, + vocab_size=25600, + hidden_size=768, + ffn_hidden_size=None, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act='gelu', + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=2048, + type_vocab_size=2, + layernorm_epsilon=1e-12, + bias_gelu_fusion=True, + fp32_residual_connection=False, + sequence_parallel=False, + fp16=False, + bf16=False, + apply_query_key_layer_scaling=True, + attention_softmax_in_fp32=False, + kv_channels=None, + masked_softmax_fusion=True, + attention_dropout=0.1, + bias_dropout_fusion=True, + apply_residual_connection_post_layernorm=False, + hidden_dropout=0.1, + init_method_std=0.02, + # generate + eod_id=7, + tokens_to_generate=100, + top_k=0, + top_p=0.9, + **kwargs): super().__init__(layer_norm_eps=layernorm_epsilon, **kwargs) self.vocab_size = vocab_size self.hidden_size = hidden_size + self.ffn_hidden_size = 4 * hidden_size \ + if ffn_hidden_size is None else ffn_hidden_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.hidden_act = hidden_act @@ -49,3 +73,39 @@ class GPT3Config(PretrainedConfig): self.max_position_embeddings = max_position_embeddings self.type_vocab_size = type_vocab_size self.layernorm_epsilon = layernorm_epsilon + self.bias_gelu_fusion = bias_gelu_fusion + self.fp32_residual_connection = fp32_residual_connection + self.sequence_parallel = sequence_parallel + self.fp16 = fp16 + self.bf16 = bf16 + assert not (fp16 and bf16) + self.apply_query_key_layer_scaling = apply_query_key_layer_scaling + self.attention_softmax_in_fp32 = attention_softmax_in_fp32 + if kv_channels is None: + assert hidden_size % num_attention_heads == 0 + self.kv_channels = hidden_size // num_attention_heads + self.masked_softmax_fusion = masked_softmax_fusion + self.attention_dropout = attention_dropout + self.bias_dropout_fusion = bias_dropout_fusion + self.apply_residual_connection_post_layernorm = \ + apply_residual_connection_post_layernorm + self.hidden_dropout = hidden_dropout + self.init_method_std = init_method_std + self.eod_id = eod_id + self.tokens_to_generate = tokens_to_generate + self.top_k = top_k + self.top_p = top_p + + TORCH_MAJOR = int(torch.__version__.split('.')[0]) + TORCH_MINOR = int(torch.__version__.split('.')[1]) + self.no_persist_layer_norm = \ + TORCH_MAJOR < 1 or (TORCH_MAJOR == 1 and TORCH_MINOR < 11) + + @property + def params_dtype(self): + if self.fp16: + return torch.half + elif self.bf16: + return torch.bfloat16 + else: + return torch.float diff --git a/modelscope/models/nlp/gpt3/distributed_gpt3.py b/modelscope/models/nlp/gpt3/distributed_gpt3.py new file mode 100644 index 00000000..a0091259 --- /dev/null +++ b/modelscope/models/nlp/gpt3/distributed_gpt3.py @@ -0,0 +1,1057 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import torch +from megatron import mpu +from megatron.global_vars import get_global_memory_buffer, set_global_variables +from megatron.model import (AttnMaskType, Float16Module, LayerNorm, + bias_gelu_impl) +from megatron.model.fused_softmax import FusedScaleMaskSoftmax +from torch import nn +from torch.nn import functional as F +from transformers.modeling_utils import PreTrainedModel + +from modelscope.models import TorchModel +from modelscope.models.nlp.gpt3 import GPT3Config +from modelscope.utils.nlp.distributed import initialize_distributed +from modelscope.utils.nlp.load_checkpoint import pre_load +from modelscope.utils.torch_utils import set_random_seed_mpu + + +class GPT3ParallelMLP(nn.Module): + """MLP. + + MLP will take the input with h hidden state, project it to 4*h + hidden dimension, perform nonlinear transformation, and project the + state back into h hidden dimension. + """ + + def __init__(self, config, init_method, output_layer_init_method): + super().__init__() + + # Project to 4h. + self.dense_h_to_4h = mpu.ColumnParallelLinearV3( + config, + config.hidden_size, + config.ffn_hidden_size, + gather_output=False, + init_method=init_method, + skip_bias_add=True) + + self.bias_gelu_fusion = config.bias_gelu_fusion + self.activation_func = F.gelu + + # Project back to h. + self.dense_4h_to_h = mpu.RowParallelLinearV3( + config, + config.ffn_hidden_size, + config.hidden_size, + input_is_parallel=True, + init_method=output_layer_init_method, + skip_bias_add=True) + + def forward(self, hidden_states): + + # [s, b, 4hp] + intermediate_parallel, bias_parallel = self.dense_h_to_4h( + hidden_states) + + if self.bias_gelu_fusion: + intermediate_parallel = \ + bias_gelu_impl(intermediate_parallel, bias_parallel) + else: + intermediate_parallel = \ + self.activation_func(intermediate_parallel + bias_parallel) + + # [s, b, h] + output, output_bias = self.dense_4h_to_h(intermediate_parallel) + return output, output_bias + + +class GPT3Embedding(nn.Module): + """Language model embeddings. + + Arguments: + hidden_size: hidden size + vocab_size: vocabulary size + max_sequence_length: maximum size of sequence. This + is used for positional embedding + embedding_dropout_prob: dropout probability for embeddings + init_method: weight initialization method + num_tokentypes: size of the token-type embeddings. 0 value + will ignore this embedding + """ + + def __init__(self, config, init_method): + super().__init__() + + self.hidden_size = config.hidden_size + self.init_method = init_method + + # Word embeddings (parallel). + self.word_embeddings = mpu.VocabParallelEmbedding( + config.vocab_size, self.hidden_size, init_method=self.init_method) + + # Position embedding (serial). + self.position_embeddings = nn.Embedding(config.max_position_embeddings, + self.hidden_size) + # Initialize the position embeddings. 
+ self.init_method(self.position_embeddings.weight) + + self.fp32_residual_connection = config.fp32_residual_connection + self.sequence_parallel = config.sequence_parallel + # Embeddings dropout + self.embedding_dropout = nn.Dropout(config.hidden_dropout) + + def zero_parameters(self): + """Zero out all parameters in embedding.""" + self.word_embeddings.weight.data.fill_(0) + self.word_embeddings.weight.shared = True + self.position_embeddings.weight.data.fill_(0) + self.position_embeddings.weight.shared = True + + def forward(self, input_ids, position_ids): + # Embeddings. + words_embeddings = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + embeddings = words_embeddings + position_embeddings + + # Data format change to avoid explicit tranposes : [b s h] --> [s b h]. + embeddings = embeddings.transpose(0, 1).contiguous() + + # If the input flag for fp32 residual connection is set, convert for float. + if self.fp32_residual_connection: + embeddings = embeddings.float() + + # Dropout. + if self.sequence_parallel: + embeddings = mpu.scatter_to_sequence_parallel_region(embeddings) + with mpu.get_cuda_rng_tracker().fork(): + embeddings = self.embedding_dropout(embeddings) + else: + embeddings = self.embedding_dropout(embeddings) + return embeddings + + +class NoopTransformerLayer(nn.Module): + + def __init__(self, layer_number): + super().__init__() + self.layer_number = layer_number + + def forward(self, + hidden_states, + attention_mask, + encoder_output=None, + enc_dec_attn_mask=None, + inference_params=None): + return hidden_states.clone() + + +def attention_mask_func(attention_scores, attention_mask): + attention_scores.masked_fill_(attention_mask, -10000.0) + return attention_scores + + +class GPT3CoreAttention(nn.Module): + + def __init__(self, + config, + layer_number, + attn_mask_type=AttnMaskType.padding): + super().__init__() + self.fp16 = config.fp16 + self.bf16 = config.bf16 + + self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling + self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32 + if self.apply_query_key_layer_scaling: + self.attention_softmax_in_fp32 = True + self.layer_number = max(1, layer_number) + self.attn_mask_type = attn_mask_type + self.sequence_parallel = config.sequence_parallel + + projection_size = config.kv_channels * config.num_attention_heads + + # Per attention head and per partition values. + world_size = mpu.get_model_parallel_world_size() + self.hidden_size_per_partition = mpu.divide(projection_size, + world_size) + self.hidden_size_per_attention_head = mpu.divide( + projection_size, config.num_attention_heads) + self.num_attention_heads_per_partition = mpu.divide( + config.num_attention_heads, world_size) + + coeff = None + self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) + if self.apply_query_key_layer_scaling: + coeff = self.layer_number + self.norm_factor *= coeff + + self.scale_mask_softmax = FusedScaleMaskSoftmax( + self.fp16, self.bf16, self.attn_mask_type, + config.masked_softmax_fusion, attention_mask_func, + self.attention_softmax_in_fp32, coeff) + + # Dropout. Note that for a single iteration, this layer will generate + # different outputs on different number of parallel partitions but + # on average it should not be partition dependent. + self.attention_dropout = nn.Dropout(config.attention_dropout) + + def forward(self, query_layer, key_layer, value_layer, attention_mask): + + # =================================== + # Raw attention scores. 
[b, np, s, s] + # =================================== + + # [b, np, sq, sk] + output_size = (query_layer.size(1), query_layer.size(2), + query_layer.size(0), key_layer.size(0)) + + # [sq, b, np, hn] -> [sq, b * np, hn] + query_layer = query_layer.view(output_size[2], + output_size[0] * output_size[1], -1) + # [sk, b, np, hn] -> [sk, b * np, hn] + key_layer = key_layer.view(output_size[3], + output_size[0] * output_size[1], -1) + + # preallocting input tensor: [b * np, sq, sk] + matmul_input_buffer = get_global_memory_buffer().get_tensor( + (output_size[0] * output_size[1], output_size[2], output_size[3]), + query_layer.dtype, 'mpu') + + # Raw attention scores. [b * np, sq, sk] + matmul_result = torch.baddbmm( + matmul_input_buffer, + query_layer.transpose(0, 1), # [b * np, sq, hn] + key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + beta=0.0, + alpha=(1.0 / self.norm_factor)) + + # change view to [b, np, sq, sk] + attention_scores = matmul_result.view(*output_size) + + # =========================== + # Attention probs and dropout + # =========================== + + # attention scores and attention mask [b, np, sq, sk] + attention_probs = self.scale_mask_softmax(attention_scores, + attention_mask) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + + if not self.sequence_parallel: + with mpu.get_cuda_rng_tracker().fork(): + attention_probs = self.attention_dropout(attention_probs) + else: + attention_probs = self.attention_dropout(attention_probs) + + # ========================= + # Context layer. [sq, b, hp] + # ========================= + + # value_layer -> context layer. + # [sk, b, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = (value_layer.size(1), value_layer.size(2), + query_layer.size(0), value_layer.size(3)) + + # change view [sk, b * np, hn] + value_layer = value_layer.view( + value_layer.size(0), output_size[0] * output_size[1], -1) + + # change view [b * np, sq, sk] + attention_probs = attention_probs.view(output_size[0] * output_size[1], + output_size[2], -1) + + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) + + # change view [b, np, sq, hn] + context_layer = context_layer.view(*output_size) + + # [b, np, sq, hn] --> [sq, b, np, hn] + context_layer = context_layer.permute(2, 0, 1, 3).contiguous() + + # [sq, b, np, hn] --> [sq, b, hp] + new_context_layer_shape = context_layer.size()[:-2] + \ + (self.hidden_size_per_partition,) + context_layer = context_layer.view(*new_context_layer_shape) + + return context_layer + + +class GPT3ParallelAttention(nn.Module): + """Parallel self-attention layer abstract class. + + Self-attention layer takes input with size [s, b, h] + and returns output of the same size. + """ + + def __init__(self, config, init_method, output_layer_init_method, + layer_number): + super().__init__() + self.layer_number = max(1, layer_number) + self.params_dtype = config.params_dtype + + projection_size = config.kv_channels * config.num_attention_heads + + # Per attention head and per partition values. + world_size = mpu.get_model_parallel_world_size() + self.hidden_size_per_attention_head = mpu.divide( + projection_size, config.num_attention_heads) + self.num_attention_heads_per_partition = mpu.divide( + config.num_attention_heads, world_size) + + # Strided linear layer. 
+ self.query_key_value = mpu.ColumnParallelLinearV3( + config, + config.hidden_size, + 3 * projection_size, + gather_output=False, + init_method=init_method) + + self.core_attention = GPT3CoreAttention(config, self.layer_number) + + # Output. + self.dense = mpu.RowParallelLinearV3( + config, + projection_size, + config.hidden_size, + input_is_parallel=True, + init_method=output_layer_init_method, + skip_bias_add=True) + + def _allocate_memory(self, inference_max_sequence_len, batch_size): + return torch.empty( + inference_max_sequence_len, + batch_size, + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + dtype=self.params_dtype, + device=torch.cuda.current_device()) + + def forward(self, hidden_states, attention_mask, inference_params=None): + # hidden_states: [sq, b, h] + + # ================================================= + # Pre-allocate memory for key-values for inference. + # ================================================= + if inference_params: + if self.layer_number not in inference_params.key_value_memory_dict: + inf_max_seq_len = inference_params.max_sequence_len + inf_max_batch_size = inference_params.max_batch_size + inference_key_memory = self._allocate_memory( + inf_max_seq_len, inf_max_batch_size) + inference_value_memory = self._allocate_memory( + inf_max_seq_len, inf_max_batch_size) + inference_params.key_value_memory_dict[self.layer_number] = ( + inference_key_memory, inference_value_memory) + else: + inference_key_memory, inference_value_memory = \ + inference_params.key_value_memory_dict[self.layer_number] + + # ===================== + # Query, Key, and Value + # ===================== + # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] + mixed_x_layer, _ = self.query_key_value(hidden_states) + + # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn] + new_tensor_shape = mixed_x_layer.size()[:-1] + \ + (self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head) + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + + # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] + (query_layer, key_layer, + value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3) + + # ================================== + # Adjust key and value for inference + # ================================== + + if inference_params: + batch_start = inference_params.batch_size_offset + batch_end = batch_start + key_layer.size(1) + assert batch_end <= inference_key_memory.size(1) + sequence_start = inference_params.sequence_len_offset + sequence_end = sequence_start + key_layer.size(0) + assert sequence_end <= inference_key_memory.size(0) + # Copy key and values. + inference_key_memory[sequence_start:sequence_end, + batch_start:batch_end, ...] = key_layer + inference_value_memory[sequence_start:sequence_end, + batch_start:batch_end, ...] = value_layer + key_layer = inference_key_memory[:sequence_end, + batch_start:batch_end, ...] + value_layer = inference_value_memory[:sequence_end, + batch_start:batch_end, ...] + + # ================================== + # core attention computation + # ================================== + + context_layer = self.core_attention(query_layer, key_layer, + value_layer, attention_mask) + + # ================= + # Output. 
[sq, b, h] + # ================= + + output, bias = self.dense(context_layer) + + return output, bias + + +class nullcontext: + + def __init__(self, enter_result=None): + self.enter_result = enter_result + + def __enter__(self): + return self.enter_result + + def __exit__(self, *excinfo): + pass + + +def bias_dropout_add(x, bias, residual, prob, training): + # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor + out = torch.nn.functional.dropout(x + bias, p=prob, training=training) + out = residual + out + return out + + +def get_bias_dropout_add(training): + + def _bias_dropout_add(x, bias, residual, prob): + return bias_dropout_add(x, bias, residual, prob, training) + + return _bias_dropout_add + + +@torch.jit.script +def bias_dropout_add_fused_train(x: torch.Tensor, bias: torch.Tensor, + residual: torch.Tensor, + prob: float) -> torch.Tensor: + return bias_dropout_add(x, bias, residual, prob, True) + + +@torch.jit.script +def bias_dropout_add_fused_inference(x: torch.Tensor, bias: torch.Tensor, + residual: torch.Tensor, + prob: float) -> torch.Tensor: + return bias_dropout_add(x, bias, residual, prob, False) + + +class GPT3ParallelTransformerLayer(nn.Module): + """A single transformer layer. + + Transformer layer takes input with size [s, b, h] and returns an + output of the same size. + """ + + def __init__(self, config, init_method, output_layer_init_method, + layer_number): + + super().__init__() + self.layer_number = layer_number + + self.apply_residual_connection_post_layernorm \ + = config.apply_residual_connection_post_layernorm + + self.bf16 = config.bf16 + self.fp32_residual_connection = config.fp32_residual_connection + + # Layernorm on the input data. + self.input_layernorm = LayerNorm( + config.hidden_size, + eps=config.layernorm_epsilon, + no_persist_layer_norm=config.no_persist_layer_norm, + sequence_parallel=config.sequence_parallel) + + # Self attention. + self.self_attention = GPT3ParallelAttention(config, init_method, + output_layer_init_method, + layer_number) + self.hidden_dropout = config.hidden_dropout + self.bias_dropout_fusion = config.bias_dropout_fusion + + # Layernorm on the attention output + self.post_attention_layernorm = LayerNorm( + config.hidden_size, + eps=config.layernorm_epsilon, + no_persist_layer_norm=config.no_persist_layer_norm, + sequence_parallel=config.sequence_parallel) + + # MLP + self.mlp = GPT3ParallelMLP(config, init_method, + output_layer_init_method) + + # Set bias+dropout+add fusion grad_enable execution handler. + TORCH_MAJOR = int(torch.__version__.split('.')[0]) + TORCH_MINOR = int(torch.__version__.split('.')[1]) + use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 + and TORCH_MINOR >= 10) + self.bias_dropout_add_exec_handler = \ + nullcontext if use_nvfuser else torch.enable_grad + + def forward(self, hidden_states, attention_mask, inference_params=None): + # hidden_states: [s, b, h] + + # Layer norm at the beginning of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + # Self attention. + attention_output, attention_bias = \ + self.self_attention( + layernorm_output, + attention_mask, + inference_params=inference_params) + # Residual connection. 
+ if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = hidden_states + + if self.bias_dropout_fusion: + if self.training: + bias_dropout_add_func = bias_dropout_add_fused_train + else: + bias_dropout_add_func = bias_dropout_add_fused_inference + else: + bias_dropout_add_func = get_bias_dropout_add(self.training) + + with self.bias_dropout_add_exec_handler(): + layernorm_input = bias_dropout_add_func( + attention_output, attention_bias.expand_as(residual), residual, + self.hidden_dropout) + + # Layer norm post the self attention. + layernorm_output = self.post_attention_layernorm(layernorm_input) + + # MLP. + mlp_output, mlp_bias = self.mlp(layernorm_output) + # Second residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = layernorm_input + + with self.bias_dropout_add_exec_handler(): + output = bias_dropout_add_func(mlp_output, + mlp_bias.expand_as(residual), + residual, self.hidden_dropout) + + # Jit compiled function creates 'view' tensor. This tensor + # potentially gets saved in the MPU checkpoint function context, + # which rejects view tensors. While making a viewless tensor here + # won't result in memory savings (like the data loader, or + # p2p_communication), it serves to document the origin of this + # 'view' tensor. + output = mpu.make_viewless_tensor( + inp=output, requires_grad=output.requires_grad, keep_graph=True) + + return output + + +class GPT3ParallelTransformer(nn.Module): + """Transformer class.""" + + def __init__(self, + config, + init_method, + output_layer_init_method, + post_layer_norm=True, + pre_process=True, + post_process=True): + super().__init__() + + self.bf16 = config.bf16 + self.fp32_residual_connection = config.fp32_residual_connection + self.post_layer_norm = post_layer_norm + self.pre_process = pre_process + self.post_process = post_process + self.input_tensor = None + + self.sequence_parallel = config.sequence_parallel + + # Number of layers. + self.num_layers = config.num_hidden_layers + + # Transformer layers. + def build_layer(layer_number): + return GPT3ParallelTransformerLayer(config, init_method, + output_layer_init_method, + layer_number) + + if self.num_layers == 0: + self.num_layers = 1 + self.layers = torch.nn.ModuleList([NoopTransformerLayer(1)]) + else: + self.layers = torch.nn.ModuleList( + [build_layer(i + 1) for i in range(self.num_layers)]) + + if self.post_process and self.post_layer_norm: + # Final layer norm before output. + self.final_layernorm = LayerNorm( + config.hidden_size, + eps=config.layernorm_epsilon, + no_persist_layer_norm=config.no_persist_layer_norm, + sequence_parallel=config.sequence_parallel) + + def _get_layer(self, layer_number): + return self.layers[layer_number] + + def forward(self, hidden_states, attention_mask, inference_params=None): + # hidden_states: [s, b, h] + + if not self.pre_process: + # See set_input_tensor() + hidden_states = self.input_tensor + + # Viewless tensor. + # - We only need to create a viewless tensor in the case of micro batch + # size (mbs) == 1, since in this case, 'hidden_states.transpose()' + # above creates a view tensor, and '.contiguous()' is a pass-through. + # For mbs >= 2, '.contiguous()' creates a new tensor, eliminating + # the need to make it viewless. + # + # However, we don't explicitly check mbs == 1 here because + # make_viewless_tensor() has negligible overhead when its input + # is already viewless. 
+ # + # - For the 'else' case above, calling make_viewless_tensor() here is + # likely redundant, since p2p_communication.py (likely originator) + # already creates viewless tensors. That said, make_viewless_tensor() + # is called here to be future-proof and corner-case-proof. + hidden_states = mpu.make_viewless_tensor( + hidden_states, + requires_grad=True, + keep_graph=True, + ) + + if self.sequence_parallel: + rng_context = mpu.get_cuda_rng_tracker().fork() + else: + rng_context = nullcontext() + + with rng_context: + # Forward pass. + for index in range(self.num_layers): + layer = self._get_layer(index) + hidden_states = layer( + hidden_states, + attention_mask, + inference_params=inference_params) + + # Final layer norm. + if self.post_process and self.post_layer_norm: + hidden_states = self.final_layernorm(hidden_states) + + return hidden_states + + +class GPT3TransformerLanguageModel(nn.Module): + """Transformer language model. + + Arguments: + transformer_hparams: transformer hyperparameters + vocab_size: vocabulary size + max_sequence_length: maximum size of sequence. This + is used for positional embedding + embedding_dropout_prob: dropout probability for embeddings + num_tokentypes: size of the token-type embeddings. 0 value + will ignore this embedding + """ + + def __init__(self, config, init_method, output_layer_init_method): + super().__init__() + + self.hidden_size = config.hidden_size + self.init_method = init_method + self.encoder_hidden_state = None + + # Embeddings. + self.embedding = GPT3Embedding(config, self.init_method) + + # Transformer. + self.encoder = GPT3ParallelTransformer( + config, + self.init_method, + output_layer_init_method, + ) + + def forward(self, + enc_input_ids, + enc_position_ids, + enc_attn_mask, + inference_params=None, + enc_hidden_states=None): + + # Encoder embedding. + encoder_input = self.embedding(enc_input_ids, enc_position_ids) + + # Run encoder. 
+ if enc_hidden_states is None: + if self.encoder is not None: + encoder_output = self.encoder( + encoder_input, + enc_attn_mask, + inference_params=inference_params) + else: + encoder_output = self.encoder_hidden_state + else: + encoder_output = enc_hidden_states.to(encoder_input.dtype) + + return encoder_output + + +def init_method_normal(sigma): + """Init method based on N(0, sigma).""" + + def init_(tensor): + return nn.init.normal_(tensor, mean=0.0, std=sigma) + + return init_ + + +def scaled_init_method_normal(sigma, num_layers): + """Init method based on N(0, sigma/sqrt(2*num_layers).""" + std = sigma / math.sqrt(2.0 * num_layers) + + def init_(tensor): + return nn.init.normal_(tensor, mean=0.0, std=std) + + return init_ + + +class GPT3Model(PreTrainedModel): + + config_class = GPT3Config + + def __init__(self, config, parallel_output=False): + super().__init__(config) + + self.parallel_output = parallel_output + + self.language_model = GPT3TransformerLanguageModel( + config, init_method_normal(config.init_method_std), + scaled_init_method_normal(config.init_method_std, + config.num_hidden_layers)) + + def word_embeddings_weight(self): + return self.language_model.embedding.word_embeddings.weight + + @staticmethod + def build_attention_mask_and_position_ids(tokens): + seq_length = tokens.size(1) + attention_mask = torch.tril( + torch.ones((1, 1, seq_length, seq_length), + dtype=torch.long, + device=tokens.device)) + attention_mask = (attention_mask < 0.5) + + position_ids = torch.arange( + seq_length, dtype=torch.long, device=tokens.device) + position_ids = position_ids.unsqueeze(0).expand_as(tokens) + + return attention_mask, position_ids + + def forward(self, + input_ids, + attention_mask=None, + position_ids=None, + inference_params=None, + **kwargs): + if attention_mask is None and position_ids is None: + attention_mask, position_ids = \ + self.build_attention_mask_and_position_ids(input_ids) + + lm_output = self.language_model( + input_ids, + position_ids, + attention_mask, + inference_params=inference_params) + + logits_parallel = mpu.LinearWithGradAccumulationAndAsyncCommunication.apply( + lm_output, self.word_embeddings_weight(), None, False, True, + self.config.sequence_parallel) + # Gather if needed. + + output = logits_parallel + if not self.parallel_output: + output = mpu.gather_from_model_parallel_region(logits_parallel) + return output.transpose(0, 1).contiguous() + + +def modify_logits_for_top_k_filtering(logits, top_k): + """Set the logits for none top-k values to -inf.""" + + filter_ = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits.masked_fill_(filter_, float('-Inf')) + + +def modify_logits_for_top_p_filtering(logits, top_p): + """Set the logits for none top-p values to -inf.""" + + # First sort and calculate cumulative sum of probabilities. + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) + + # Filteration based on the cumulative sum. + filter_ = cumulative_probs > top_p + # This shift by 1 is weird and I cannot justify it. This existed + # in the original implementation: + # https://github.com/ari-holtzman/degen/blob/master/gen.py + # and I guess it is needed so keeping it for now. + filter_[:, 1:] = filter_[:, :-1].clone() + # Make sure we at least have one token to select from. 
+    filter_[..., 0] = 0
+
+    # Fill in the filtered part
+    filter_ = filter_.scatter(1, sorted_indices, filter_)
+    logits.masked_fill_(filter_, float('-Inf'))
+
+
+def sample(logits, top_k=0, top_p=0.0, temperature=1.0, vocab_size=None):
+    """ Sample and generate a token.
+    Note: logits has the dimension [b, v] where b is the batch size
+          and v is the vocabulary size.
+    If vocab_size is provided, we will make sure the sample that is
+    generated is in [0, vocab-size). This will avoid out of vocabulary
+    generations due to padding.
+    """
+
+    # Check logits for consistency.
+    assert logits.ndim == 2, 'expected the logits to be of [b, v] shape.'
+    assert logits.type() == 'torch.cuda.FloatTensor', \
+        'input logits should be floats.'
+
+    # Greedy is just simple argmax.
+    if top_k == 1:
+        assert top_p == 0.0, 'cannot set both greedy and top-p samplings.'
+        samples = torch.argmax(logits, dim=-1)
+
+    # Top-k or top-p sampling.
+    else:
+        # Clone so we do not modify the inputs.
+        logits = logits.clone()
+        # Apply temperature in place.
+        if temperature != 1.0:
+            logits.div_(temperature)
+
+        if top_k > 1:
+            assert top_p == 0.0, 'cannot set both top-k and top-p samplings.'
+            assert top_k <= logits.size(1), 'top-k is larger than logit size.'
+            if vocab_size:
+                assert top_k < vocab_size, 'top-k is larger than vocab size.'
+            modify_logits_for_top_k_filtering(logits, top_k)
+
+        elif top_p > 0.0:
+            assert top_p <= 1.0, 'top-p should be in (0, 1].'
+            modify_logits_for_top_p_filtering(logits, top_p)
+
+        # After filtering, we need to recalculate the distribution.
+        probs = logits.softmax(dim=-1)
+        samples = torch.multinomial(probs, num_samples=1).view(-1)
+
+    # If vocab size is provided, make sure the samples are in
+    # the range [0, vocab-size).
+    if vocab_size:
+        samples = torch.clamp(samples, min=0, max=(vocab_size - 1))
+
+    return samples
+
+
+class InferenceParams:
+    """Inference parameters that are passed to the main model in order
+    to efficiently calculate and store the context during inference."""
+
+    def __init__(self, max_batch_size, max_sequence_len):
+        """Note that offsets are set to zero and we always set the
+        flag to allocate memory. After the first call, make sure to
+        set this flag to False."""
+        self.max_sequence_len = max_sequence_len
+        self.max_batch_size = max_batch_size
+        self.sequence_len_offset = 0
+        self.batch_size_offset = 0
+        self.key_value_memory_dict = {}
+
+    def swap_key_value_dict(self, batch_idx):
+        'swap between batches'
+        if len(self.key_value_memory_dict) == 0:
+            raise ValueError('should not swap when dict is empty')
+
+        for layer_number in self.key_value_memory_dict.keys():
+            inference_key_memory, inference_value_memory = self.key_value_memory_dict[
+                layer_number]
+            assert len(batch_idx) == inference_key_memory.shape[
+                1]  # make sure batch size is the same
+            new_inference_key_memory = inference_key_memory[:, batch_idx]
+            new_inference_value_memory = inference_value_memory[:, batch_idx]
+            self.key_value_memory_dict[layer_number] = (
+                new_inference_key_memory, new_inference_value_memory)
+
+
+class DistributedGPT3(TorchModel):
+
+    def __init__(self,
+                 model_dir,
+                 rank,
+                 path_load_tag='model',
+                 *args,
+                 **kwargs):
+        super().__init__(model_dir, *args, **kwargs)
+        initialize_distributed(rank, mpu, kwargs['world_size'],
+                               kwargs['model_parallel_size'],
+                               kwargs['master_ip'], kwargs['master_port'])
+        seed = 0 if 'seed' not in kwargs else kwargs['seed']
+        set_random_seed_mpu(seed)
+        set_global_variables()
+
+        self.config = GPT3Config.from_pretrained(model_dir)
+        # Build model.
+        model = GPT3Model(self.config)
+
+        for param in model.parameters():
+            mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param)
+
+        # GPU allocation.
+        model.cuda(torch.cuda.current_device())
+
+        # Fp16 conversion.
+        if self.config.fp16 or self.config.bf16:
+            model = Float16Module(model, self.config)
+
+        self.dist_model = model
+        load_model = pre_load(mpu, model_dir, tag=path_load_tag)
+        self.dist_model.load_state_dict(load_model)
+
+        self.inference_params = None
+
+    def forward_step(self, tokens, attention_mask, position_ids):
+        logits = self.dist_model(
+            tokens,
+            attention_mask,
+            position_ids,
+            inference_params=self.inference_params)
+        self.inference_params.sequence_len_offset += tokens.size(1)
+        return logits
+
+    def generate(self,
+                 tokens,
+                 temperature=1.0,
+                 use_eod_token_for_early_termination=True,
+                 stop_on_double_eol=False,
+                 stop_on_eol=False):
+        lengths = torch.tensor([tokens.size(1)], device=tokens.device)
+        pads = torch.ones(
+            1, self.config.tokens_to_generate,
+            device=tokens.device).long() * self.config.eod_id
+        tokens = torch.cat((tokens, pads), dim=-1)
+
+        batch_size = tokens.size(0)
+        min_prompt_length = lengths.min().item()
+        max_sequence_length = tokens.size(1)
+        max_sequence_length = min(max_sequence_length,
+                                  self.config.max_position_embeddings)
+
+        # If the context is too big, raise an error.
+        if min_prompt_length >= max_sequence_length:
+            raise ValueError('context length + tokens_to_generate too large')
+
+        # Initialize inference parameters.
+        self.inference_params = InferenceParams(batch_size,
+                                                max_sequence_length)
+
+        # Added termination_id to support the case that we want to terminate the
+        # generation once that id is generated.
+        termination_id = self.config.eod_id
+
+        # Whether we have reached a termination id.
+        is_generation_done = torch.zeros(
+            batch_size, dtype=torch.uint8, device=torch.cuda.current_device())
+
+        # =============
+        # Run inference
+        # =============
+
+        with torch.no_grad():
+            attention_mask, position_ids = \
+                GPT3Model.build_attention_mask_and_position_ids(tokens)
+            prev_context_length = 0
+            for context_length in range(min_prompt_length,
+                                        max_sequence_length):
+
+                # Pick the slice that we need to pass through the network.
+                tokens2use = tokens[:, prev_context_length:context_length]
+                positions2use = position_ids[:, prev_context_length:
+                                             context_length]
+                attention_mask2use = attention_mask[
+                    ..., prev_context_length:context_length, :context_length]
+
+                # logits will be meaningful only in the last pipeline stage.
+                logits = self.forward_step(tokens2use, attention_mask2use,
+                                           positions2use)
+
+                # Sample.
+                last_token_logits = logits[:, -1, :]
+                new_sample = sample(
+                    last_token_logits,
+                    top_k=self.config.top_k,
+                    top_p=self.config.top_p,
+                    temperature=temperature,
+                    vocab_size=self.config.vocab_size)
+
+                # If a prompt length is smaller than or equal to the current
+                # context length, it means we have started generating tokens.
+                started = lengths <= context_length
+                # Update the tokens.
+                tokens[started, context_length] = new_sample[started]
+
+                # Update the context length for the next token generation.
+ prev_context_length = context_length + + # instead tokenization should be in the inference loop so stop sequences can be used + if stop_on_double_eol: + hit_double_eol = (new_sample + == 628).byte() & started.byte() + hit_two_eols = (new_sample == 198).byte() & ( + tokens[:, context_length - 1] + == 198).byte() & started.byte() + done_token = hit_double_eol | hit_two_eols + elif stop_on_eol: + hit_double_eol = (new_sample + == 628).byte() & started.byte() + hit_eol = (new_sample == 198).byte() & started.byte() + done_token = hit_double_eol | hit_eol + else: + done_token = (new_sample == termination_id).byte() & \ + started.byte() + + is_generation_done = is_generation_done | done_token + done = torch.all(is_generation_done) + + if use_eod_token_for_early_termination and done: + break + + tokens = tokens[:, :(context_length + 1)] + return tokens diff --git a/modelscope/models/nlp/gpt3/modeling_gpt3.py b/modelscope/models/nlp/gpt3/modeling_gpt3.py index ade36e36..2c23f5db 100644 --- a/modelscope/models/nlp/gpt3/modeling_gpt3.py +++ b/modelscope/models/nlp/gpt3/modeling_gpt3.py @@ -19,8 +19,7 @@ from typing import Optional, Union import addict import torch -from torch.nn import (CrossEntropyLoss, Dropout, Embedding, LayerNorm, Linear, - Module, Softmax) +from torch import nn from torch.nn import functional as F from transformers.modeling_utils import PreTrainedModel @@ -28,7 +27,7 @@ from modelscope.utils.constant import ModelFile from .configuration_gpt3 import GPT3Config -class GPT3SelfAttention(Module): +class GPT3SelfAttention(nn.Module): """Parallel self-attention layer abstract class. Self-attention layer takes input with size [s, b, h] @@ -44,13 +43,15 @@ class GPT3SelfAttention(Module): self.hidden_size_per_attention_head = \ self.hidden_size // self.num_attention_heads - self.query_key_value = Linear(self.hidden_size, 3 * self.hidden_size) - self.softmax = Softmax(dim=-1) - self.attention_dropout = Dropout(config.attention_probs_dropout_prob) + self.query_key_value = nn.Linear(self.hidden_size, + 3 * self.hidden_size) + self.softmax = nn.Softmax(dim=-1) + self.attention_dropout = nn.Dropout( + config.attention_probs_dropout_prob) # Output. - self.dense = Linear(self.hidden_size, self.hidden_size) - self.output_dropout = torch.nn.Dropout(config.hidden_dropout_prob) + self.dense = nn.Linear(self.hidden_size, self.hidden_size) + self.output_dropout = nn.Dropout(config.hidden_dropout_prob) def _transpose_for_scores(self, tensor): """Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with @@ -133,7 +134,7 @@ class GPT3SelfAttention(Module): return output -class GPT3MLP(Module): +class GPT3MLP(nn.Module): """MLP. MLP will take the input with h hidden state, project it to 4*h @@ -146,12 +147,12 @@ class GPT3MLP(Module): hidden_size = config.hidden_size # Project to 4h. - self.dense_h_to_4h = Linear(hidden_size, 4 * hidden_size) + self.dense_h_to_4h = nn.Linear(hidden_size, 4 * hidden_size) self.activation_func = F.gelu # Project back to h. - self.dense_4h_to_h = Linear(4 * hidden_size, hidden_size) + self.dense_4h_to_h = nn.Linear(4 * hidden_size, hidden_size) - self.dropout = Dropout(config.hidden_dropout_prob) + self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, hidden_states): @@ -164,7 +165,7 @@ class GPT3MLP(Module): return output -class GPT3TransformerLayer(Module): +class GPT3TransformerLayer(nn.Module): """A single transformer layer. 
Transformer layer takes input with size [s, b, h] and returns an @@ -175,14 +176,14 @@ class GPT3TransformerLayer(Module): super().__init__() # Layernorm on the input data. - self.input_layernorm = LayerNorm( + self.input_layernorm = nn.LayerNorm( config.hidden_size, eps=config.layernorm_epsilon) # Self attention. self.attention = GPT3SelfAttention(config) # Layernorm on the attention output - self.post_attention_layernorm = LayerNorm( + self.post_attention_layernorm = nn.LayerNorm( config.hidden_size, eps=config.layernorm_epsilon) # MLP @@ -208,7 +209,7 @@ class GPT3TransformerLayer(Module): return output -class GPT3Transformer(Module): +class GPT3Transformer(nn.Module): """Transformer class.""" def __init__(self, config): @@ -223,7 +224,7 @@ class GPT3Transformer(Module): [GPT3TransformerLayer(config) for _ in range(self.num_layers)]) # Final layer norm before output. - self.final_layernorm = LayerNorm( + self.final_layernorm = nn.LayerNorm( config.hidden_size, eps=config.layernorm_epsilon) def _get_layer(self, layer_number): @@ -242,7 +243,7 @@ class GPT3Transformer(Module): return hidden_states -class GPT3TransformerLanguageModel(Module): +class GPT3TransformerLanguageModel(nn.Module): """Transformer language model. Arguments: @@ -259,10 +260,11 @@ class GPT3TransformerLanguageModel(Module): super().__init__() # Embeddings. - self.word_embeddings = Embedding(config.vocab_size, config.hidden_size) - self.position_embeddings = Embedding(config.max_position_embeddings, - config.hidden_size) - self.embedding_dropout = Dropout(config.hidden_dropout_prob) + self.word_embeddings = nn.Embedding(config.vocab_size, + config.hidden_size) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, + config.hidden_size) + self.embedding_dropout = nn.Dropout(config.hidden_dropout_prob) # Transformer. self.transformer = GPT3Transformer(config) @@ -286,19 +288,19 @@ class GPT3Model(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" - if isinstance(module, Linear): + if isinstance(module, nn.Linear): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_( mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() - elif isinstance(module, Embedding): + elif isinstance(module, nn.Embedding): module.weight.data.normal_( mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - elif isinstance(module, LayerNorm): + elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) @@ -325,7 +327,7 @@ class GPT3Model(PreTrainedModel): logits = self.language_model(input_ids, attention_mask, position_ids) loss = None if labels is not None: - loss_fct = CrossEntropyLoss() + loss_fct = nn.CrossEntropyLoss() loss = loss_fct( logits.view(-1, self.config.vocab_size), labels.view(-1)) return addict.Dict(loss=loss, logits=logits) diff --git a/modelscope/models/nlp/gpt3/tokenizer_gpt3.py b/modelscope/models/nlp/gpt3/tokenizer_gpt3.py new file mode 100644 index 00000000..5780ddbd --- /dev/null +++ b/modelscope/models/nlp/gpt3/tokenizer_gpt3.py @@ -0,0 +1,69 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from tokenizers import Tokenizer
+
+
+class JiebaBPETokenizer:
+    """SentencePiece BPE tokenizer with Jieba integration"""
+
+    def __init__(self, tokenizer_json_file):
+        self.name = 'Jieba BPE Tokenizer'
+
+        self.tokenizer = Tokenizer.from_file(tokenizer_json_file)
+        self.eod_id = self.tokenizer.token_to_id('<|endoftext|>')
+        try:
+            import jieba
+        except ImportError:
+            raise ImportError(
+                'You need to install jieba to use JiebaBPETokenizer. '
+                'See https://pypi.org/project/jieba/ for installation.')
+        self.jieba = jieba
+        self.new_line = self.vocab['\n']
+        self.sep_token = self.vocab['<sep>']
+
+    @property
+    def vocab_size(self):
+        return self.tokenizer.get_vocab_size(with_added_tokens=True)
+
+    @property
+    def vocab(self):
+        return self.tokenizer.get_vocab(with_added_tokens=True)
+
+    @property
+    def inv_vocab(self):
+        vocab = self.vocab
+        inv_vocab = dict()
+        for key, val in vocab.items():
+            inv_vocab[val] = key
+        return inv_vocab
+
+    def tokenize(self, text, is_code=False):
+        """Jieba-segment the text first (unless is_code is True), then
+        BPE-encode it and return the list of token ids."""
+        if not is_code:
+            seg_list = [x for x in self.jieba.cut(text)]
+            return self.tokenizer.encode(
+                seg_list, is_pretokenized=True, add_special_tokens=True).ids
+        else:
+            return self.tokenizer.encode(
+                text, is_pretokenized=False, add_special_tokens=True).ids
+
+    def detokenize(self, token_ids):
+        text = self.tokenizer.decode(token_ids, skip_special_tokens=False)
+        return text
+
+    @property
+    def eod(self):
+        return self.eod_id
diff --git a/modelscope/pipelines/nlp/distributed_gpt3_pipeline.py b/modelscope/pipelines/nlp/distributed_gpt3_pipeline.py
new file mode 100644
index 00000000..325d3303
--- /dev/null
+++ b/modelscope/pipelines/nlp/distributed_gpt3_pipeline.py
@@ -0,0 +1,54 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+from typing import Any, Dict
+
+import torch
+
+from modelscope.metainfo import Pipelines
+from modelscope.models.nlp.gpt3.distributed_gpt3 import DistributedGPT3
+from modelscope.pipelines.base import DistributedPipeline
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.preprocessors import TextGenerationJiebaPreprocessor
+from modelscope.utils.constant import Tasks
+
+
+@PIPELINES.register_module(
+    Tasks.text_generation, module_name=Pipelines.gpt3_generation)
+class DistributedGPT3Pipeline(DistributedPipeline):
+    """This class is used to instantiate the gpt3 model.
+ """ + + model = None + + def __init__(self, model, preprocessor=None, **kwargs): + if preprocessor is None: + preprocessor = TextGenerationJiebaPreprocessor(model) + super().__init__(model, preprocessor=preprocessor, **kwargs) + assert hasattr(preprocessor, 'tokenizer') + + @classmethod + def _instantiate_one(cls, rank, model_dir, **kwargs): + cls.model = DistributedGPT3(model_dir, rank, **kwargs) + cls.model.eval() + + @classmethod + def _forward_one(cls, inputs: Dict[str, Any]) -> Dict[str, Any]: + tokens = inputs['inputs']['input_ids'].cuda( + torch.cuda.current_device()) + return cls.model.generate(tokens) + + def postprocess(self, inputs: Dict[str, Any], + **postprocess_params) -> Dict[str, str]: + """process the prediction results + + Args: + inputs (Dict[str, Any]): _description_ + + Returns: + Dict[str, str]: the prediction results + """ + from modelscope.outputs import OutputKeys + return { + OutputKeys.TEXT: + self.preprocessor.tokenizer.detokenize(inputs[0].tolist()) + } diff --git a/modelscope/preprocessors/__init__.py b/modelscope/preprocessors/__init__.py index 43fa64a7..f7defd92 100644 --- a/modelscope/preprocessors/__init__.py +++ b/modelscope/preprocessors/__init__.py @@ -32,6 +32,7 @@ if TYPE_CHECKING: Tokenize, WordSegmentationBlankSetToLabelPreprocessor, ZeroShotClassificationPreprocessor, + TextGenerationJiebaPreprocessor, SentencePiecePreprocessor, ) from .space import (DialogIntentPredictionPreprocessor, @@ -72,6 +73,7 @@ else: 'Text2TextGenerationPreprocessor', 'WordSegmentationBlankSetToLabelPreprocessor', 'ZeroShotClassificationPreprocessor', + 'TextGenerationJiebaPreprocessor', 'SentencePiecePreprocessor', ], 'space': [ diff --git a/modelscope/preprocessors/nlp/__init__.py b/modelscope/preprocessors/nlp/__init__.py index a753fe6c..f7478329 100644 --- a/modelscope/preprocessors/nlp/__init__.py +++ b/modelscope/preprocessors/nlp/__init__.py @@ -21,6 +21,7 @@ if TYPE_CHECKING: Tokenize, WordSegmentationBlankSetToLabelPreprocessor, ZeroShotClassificationPreprocessor, + TextGenerationJiebaPreprocessor, SentencePiecePreprocessor, ) @@ -42,6 +43,7 @@ else: 'Text2TextGenerationPreprocessor', 'WordSegmentationBlankSetToLabelPreprocessor', 'ZeroShotClassificationPreprocessor', + 'TextGenerationJiebaPreprocessor', 'SentencePiecePreprocessor', ], 'text_error_correction': [ diff --git a/modelscope/preprocessors/nlp/nlp_base.py b/modelscope/preprocessors/nlp/nlp_base.py index 3d708634..267dbb8c 100644 --- a/modelscope/preprocessors/nlp/nlp_base.py +++ b/modelscope/preprocessors/nlp/nlp_base.py @@ -494,6 +494,41 @@ class TextGenerationPreprocessor(NLPTokenizerPreprocessorBase): } +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.text_gen_jieba_tokenizer) +class TextGenerationJiebaPreprocessor(Preprocessor): + """The jieba tokenizer preprocessor used in text generation. 
+    """
+
+    def __init__(self, model_dir: str, *args, **kwargs):
+        from modelscope.models.nlp.gpt3 import JiebaBPETokenizer
+        super().__init__(*args, **kwargs)
+        self.tokenizer = JiebaBPETokenizer(
+            osp.join(model_dir, 'tokenizer.json'))
+
+    def __call__(self, data: str) -> Dict[str, Any]:
+        """process the raw input data
+
+        Args:
+            data (str): a sentence
+                Example:
+                    '深蓝的天空中挂着一轮金黄的圆月,下面是海边的沙地'
+        Returns:
+            Dict[str, Any]: the preprocessed data
+                Example:
+                    {
+                        'input_ids':
+                            tensor([[1, 2, 3, 4]])
+                    }
+        """
+        import torch
+
+        return {
+            'input_ids':
+            torch.tensor(self.tokenizer.tokenize(data)).unsqueeze_(0)
+        }
+
+
 @PREPROCESSORS.register_module(
     Fields.nlp,
     module_name=Preprocessors.word_segment_text_to_label_preprocessor)
diff --git a/modelscope/utils/nlp/distributed.py b/modelscope/utils/nlp/distributed.py
index 2b590a10..53332c0f 100755
--- a/modelscope/utils/nlp/distributed.py
+++ b/modelscope/utils/nlp/distributed.py
@@ -35,7 +35,10 @@ def initialize_distributed(rank, mpu, world_size, model_parallel_size,
     init_method = 'tcp://'
     init_method += master_ip + ':' + master_port
     torch.distributed.init_process_group(
-        backend='nccl', world_size=8, rank=rank, init_method=init_method)
+        backend='nccl',
+        world_size=world_size,
+        rank=rank,
+        init_method=init_method)
 
     # Set the model-parallel communicators.
     mpu.initialize_model_parallel(model_parallel_size)
diff --git a/tests/pipelines/test_gpt3_text_generation.py b/tests/pipelines/test_gpt3_text_generation.py
new file mode 100644
index 00000000..413b5874
--- /dev/null
+++ b/tests/pipelines/test_gpt3_text_generation.py
@@ -0,0 +1,58 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import unittest
+
+from modelscope.hub.snapshot_download import snapshot_download
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.test_utils import test_level
+
+
+class TextGPT3GenerationTest(unittest.TestCase):
+
+    def setUp(self) -> None:
+        # please make sure this local path exists.
+ self.model_id_1_3B = 'damo/nlp_gpt3_text-generation_1.3B' + self.model_id_2_7B = 'damo/nlp_gpt3_text-generation_2.7B' + self.model_id_13B = 'damo/nlp_gpt3_text-generation_13B' + self.model_dir_13B = snapshot_download(self.model_id_13B) + self.input = '好的' + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_gpt3_1_3B(self): + pipe = pipeline(Tasks.text_generation, model=self.model_id_1_3B) + print(pipe(self.input)) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_gpt3_2_7B(self): + pipe = pipeline(Tasks.text_generation, model=self.model_id_2_7B) + print(pipe(self.input)) + + @unittest.skip('distributed gpt3 13B, skipped') + def test_gpt3_13B(self): + """ The model can be downloaded from the link on + TODO: add gpt3 checkpoint link + After downloading, you should have a gpt3 model structure like this: + nlp_gpt3_text-generation_13B + |_ config.json + |_ configuration.json + |_ tokenizer.json + |_ model <-- an empty directory + + Model binaries shall be downloaded separately to populate the model directory, so that + the model directory would contain the following binaries: + |_ model + |_ mp_rank_00_model_states.pt + |_ mp_rank_01_model_states.pt + |_ mp_rank_02_model_states.pt + |_ mp_rank_03_model_states.pt + |_ mp_rank_04_model_states.pt + |_ mp_rank_05_model_states.pt + |_ mp_rank_06_model_states.pt + |_ mp_rank_07_model_states.pt + """ + pipe = pipeline(Tasks.text_generation, model=self.model_dir_13B) + print(pipe(self.input)) + + +if __name__ == '__main__': + unittest.main()
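For reference, below is a minimal, CPU-runnable sketch of the top-k / top-p filtering that sample() in distributed_gpt3.py applies before drawing the next token. It uses plain torch only, so it can be tried without a multi-GPU Megatron setup; the helper name filter_logits is introduced here purely for illustration and is not part of the patch.

# Sketch of the logit filtering used by sample() in distributed_gpt3.py.
# Assumes logits is a float tensor of shape [b, v], as in the patch.
import torch


def filter_logits(logits, top_k=0, top_p=0.0):
    logits = logits.clone()
    if top_k > 0:
        # Keep only the top_k largest logits per row; mask the rest to -inf.
        kth_value = torch.topk(logits, top_k)[0][..., -1, None]
        logits.masked_fill_(logits < kth_value, float('-inf'))
    elif top_p > 0.0:
        # Nucleus filtering: drop tokens outside the smallest set whose
        # cumulative probability exceeds top_p, always keeping the best token.
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
        filter_ = cumulative_probs > top_p
        filter_[:, 1:] = filter_[:, :-1].clone()
        filter_[..., 0] = False
        filter_ = filter_.scatter(1, sorted_indices, filter_)
        logits.masked_fill_(filter_, float('-inf'))
    return logits


if __name__ == '__main__':
    logits = torch.randn(2, 10)
    probs = filter_logits(logits, top_p=0.9).softmax(dim=-1)
    next_tokens = torch.multinomial(probs, num_samples=1).view(-1)
    print(next_tokens)

The shift of the mask by one position mirrors the patch (and the original nucleus-sampling implementation it borrows from): it keeps the first token whose cumulative probability crosses the top_p threshold rather than discarding it.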