
[to #42322933] Add GPT3 tensor parallel inference

Adds tensor-parallel inference code for GPT3, based on Megatron-v3.
Reuses DistributedPipeline and megatron-util.
Applicable models: GPT-3 pretrained generation models with 1.3B, 2.7B, and 13B parameters.
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10416721
master · hemu.zp, yingda.chen · 3 years ago · parent commit 3b1f1a0252
12 changed files with 1388 additions and 42 deletions
 1. modelscope/metainfo.py (+2 −0)
 2. modelscope/models/nlp/gpt3/__init__.py (+2 −0)
 3. modelscope/models/nlp/gpt3/configuration_gpt3.py (+75 −15)
 4. modelscope/models/nlp/gpt3/distributed_gpt3.py (+1057 −0)
 5. modelscope/models/nlp/gpt3/modeling_gpt3.py (+28 −26)
 6. modelscope/models/nlp/gpt3/tokenizer_gpt3.py (+69 −0)
 7. modelscope/pipelines/nlp/distributed_gpt3_pipeline.py (+54 −0)
 8. modelscope/preprocessors/__init__.py (+2 −0)
 9. modelscope/preprocessors/nlp/__init__.py (+2 −0)
10. modelscope/preprocessors/nlp/nlp_base.py (+35 −0)
11. modelscope/utils/nlp/distributed.py (+4 −1)
12. tests/pipelines/test_gpt3_text_generation.py (+58 −0)

modelscope/metainfo.py  (+2 −0)

@@ -227,6 +227,7 @@ class Pipelines(object):
     zero_shot_classification = 'zero-shot-classification'
     text_error_correction = 'text-error-correction'
     plug_generation = 'plug-generation'
+    gpt3_generation = 'gpt3-generation'
     faq_question_answering = 'faq-question-answering'
     conversational_text_to_sql = 'conversational-text-to-sql'
     table_question_answering_pipeline = 'table-question-answering-pipeline'
@@ -324,6 +325,7 @@ class Preprocessors(object):
     bert_seq_cls_tokenizer = 'bert-seq-cls-tokenizer'
     text_gen_tokenizer = 'text-gen-tokenizer'
     text2text_gen_preprocessor = 'text2text-gen-preprocessor'
+    text_gen_jieba_tokenizer = 'text-gen-jieba-tokenizer'
     text2text_translate_preprocessor = 'text2text-translate-preprocessor'
     token_cls_tokenizer = 'token-cls-tokenizer'
     ner_tokenizer = 'ner-tokenizer'


modelscope/models/nlp/gpt3/__init__.py  (+2 −0)

@@ -7,11 +7,13 @@ if TYPE_CHECKING:
     from .configuration_gpt3 import GPT3Config
     from .modeling_gpt3 import GPT3Model
     from .gpt3_for_text_generation import GPT3ForTextGeneration
+    from .tokenizer_gpt3 import JiebaBPETokenizer
 else:
     _import_structure = {
         'configuration_gpt3': ['GPT3Config'],
         'modeling_gpt3': ['GPT3Model'],
         'gpt3_for_text_generation': ['GPT3ForTextGeneration'],
+        'tokenizer_gpt3': ['JiebaBPETokenizer'],
     }

     import sys


modelscope/models/nlp/gpt3/configuration_gpt3.py  (+75 −15)

@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import torch
 from transformers.configuration_utils import PretrainedConfig
 from transformers.utils import logging

@@ -21,25 +22,48 @@ logger = logging.get_logger(__name__)

 class GPT3Config(PretrainedConfig):

-    model_type = 'gpt'
-
-    def __init__(self,
-                 vocab_size=25600,
-                 hidden_size=768,
-                 num_hidden_layers=12,
-                 num_attention_heads=12,
-                 intermediate_size=3072,
-                 hidden_act='gelu',
-                 hidden_dropout_prob=0.1,
-                 attention_probs_dropout_prob=0.1,
-                 max_position_embeddings=2048,
-                 type_vocab_size=2,
-                 layernorm_epsilon=1e-12,
-                 **kwargs):
+    model_type = 'gpt3'
+
+    def __init__(
+            self,
+            vocab_size=25600,
+            hidden_size=768,
+            ffn_hidden_size=None,
+            num_hidden_layers=12,
+            num_attention_heads=12,
+            intermediate_size=3072,
+            hidden_act='gelu',
+            hidden_dropout_prob=0.1,
+            attention_probs_dropout_prob=0.1,
+            max_position_embeddings=2048,
+            type_vocab_size=2,
+            layernorm_epsilon=1e-12,
+            bias_gelu_fusion=True,
+            fp32_residual_connection=False,
+            sequence_parallel=False,
+            fp16=False,
+            bf16=False,
+            apply_query_key_layer_scaling=True,
+            attention_softmax_in_fp32=False,
+            kv_channels=None,
+            masked_softmax_fusion=True,
+            attention_dropout=0.1,
+            bias_dropout_fusion=True,
+            apply_residual_connection_post_layernorm=False,
+            hidden_dropout=0.1,
+            init_method_std=0.02,
+            # generate
+            eod_id=7,
+            tokens_to_generate=100,
+            top_k=0,
+            top_p=0.9,
+            **kwargs):
         super().__init__(layer_norm_eps=layernorm_epsilon, **kwargs)

         self.vocab_size = vocab_size
         self.hidden_size = hidden_size
+        self.ffn_hidden_size = 4 * hidden_size \
+            if ffn_hidden_size is None else ffn_hidden_size
         self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads
         self.hidden_act = hidden_act
@@ -49,3 +73,39 @@ class GPT3Config(PretrainedConfig):
         self.max_position_embeddings = max_position_embeddings
         self.type_vocab_size = type_vocab_size
         self.layernorm_epsilon = layernorm_epsilon
+        self.bias_gelu_fusion = bias_gelu_fusion
+        self.fp32_residual_connection = fp32_residual_connection
+        self.sequence_parallel = sequence_parallel
+        self.fp16 = fp16
+        self.bf16 = bf16
+        assert not (fp16 and bf16)
+        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
+        self.attention_softmax_in_fp32 = attention_softmax_in_fp32
+        if kv_channels is None:
+            assert hidden_size % num_attention_heads == 0
+            kv_channels = hidden_size // num_attention_heads
+        self.kv_channels = kv_channels
+        self.masked_softmax_fusion = masked_softmax_fusion
+        self.attention_dropout = attention_dropout
+        self.bias_dropout_fusion = bias_dropout_fusion
+        self.apply_residual_connection_post_layernorm = \
+            apply_residual_connection_post_layernorm
+        self.hidden_dropout = hidden_dropout
+        self.init_method_std = init_method_std
+        self.eod_id = eod_id
+        self.tokens_to_generate = tokens_to_generate
+        self.top_k = top_k
+        self.top_p = top_p
+
+        TORCH_MAJOR = int(torch.__version__.split('.')[0])
+        TORCH_MINOR = int(torch.__version__.split('.')[1])
+        self.no_persist_layer_norm = \
+            TORCH_MAJOR < 1 or (TORCH_MAJOR == 1 and TORCH_MINOR < 11)
+
+    @property
+    def params_dtype(self):
+        if self.fp16:
+            return torch.half
+        elif self.bf16:
+            return torch.bfloat16
+        else:
+            return torch.float
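
A quick sketch of the new config surface; it uses only fields shown in the diff above, with hypothetical sizes:

    import torch
    from modelscope.models.nlp.gpt3 import GPT3Config

    config = GPT3Config(hidden_size=2048, num_attention_heads=32, fp16=True)
    assert config.ffn_hidden_size == 4 * config.hidden_size  # default when ffn_hidden_size is None
    assert config.kv_channels == config.hidden_size // config.num_attention_heads
    assert config.params_dtype is torch.half  # torch.bfloat16 if bf16=True, torch.float otherwise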

modelscope/models/nlp/gpt3/distributed_gpt3.py  (+1057 −0)
(file diff suppressed because it is too large)
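
The suppressed file holds the Megatron-style tensor-parallel model. Its surface, as exercised by distributed_gpt3_pipeline.py below, amounts to the following sketch; only the calls visible in that file are assumed to exist, and the paths and tensor values are hypothetical:

    import torch
    from modelscope.models.nlp.gpt3.distributed_gpt3 import DistributedGPT3

    model_dir, rank = '/path/to/nlp_gpt3_text-generation_13B', 0  # one process per model-parallel rank
    model = DistributedGPT3(model_dir, rank)  # loads this rank's mp_rank_0*_model_states.pt shard
    model.eval()                              # inference only
    tokens = torch.tensor([[1, 2, 3, 4]]).cuda(torch.cuda.current_device())
    output_ids = model.generate(tokens)       # sampling driven by eod_id/top_k/top_p/tokens_to_generate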


modelscope/models/nlp/gpt3/modeling_gpt3.py  (+28 −26)

@@ -19,8 +19,7 @@ from typing import Optional, Union

 import addict
 import torch
-from torch.nn import (CrossEntropyLoss, Dropout, Embedding, LayerNorm, Linear,
-                      Module, Softmax)
+from torch import nn
 from torch.nn import functional as F
 from transformers.modeling_utils import PreTrainedModel

@@ -28,7 +27,7 @@ from modelscope.utils.constant import ModelFile
 from .configuration_gpt3 import GPT3Config


-class GPT3SelfAttention(Module):
+class GPT3SelfAttention(nn.Module):
     """Parallel self-attention layer abstract class.

     Self-attention layer takes input with size [s, b, h]
@@ -44,13 +43,15 @@ class GPT3SelfAttention(Module):
         self.hidden_size_per_attention_head = \
             self.hidden_size // self.num_attention_heads

-        self.query_key_value = Linear(self.hidden_size, 3 * self.hidden_size)
-        self.softmax = Softmax(dim=-1)
-        self.attention_dropout = Dropout(config.attention_probs_dropout_prob)
+        self.query_key_value = nn.Linear(self.hidden_size,
+                                         3 * self.hidden_size)
+        self.softmax = nn.Softmax(dim=-1)
+        self.attention_dropout = nn.Dropout(
+            config.attention_probs_dropout_prob)

         # Output.
-        self.dense = Linear(self.hidden_size, self.hidden_size)
-        self.output_dropout = torch.nn.Dropout(config.hidden_dropout_prob)
+        self.dense = nn.Linear(self.hidden_size, self.hidden_size)
+        self.output_dropout = nn.Dropout(config.hidden_dropout_prob)

     def _transpose_for_scores(self, tensor):
         """Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with
@@ -133,7 +134,7 @@ class GPT3SelfAttention(Module):
         return output


-class GPT3MLP(Module):
+class GPT3MLP(nn.Module):
     """MLP.

     MLP will take the input with h hidden state, project it to 4*h
@@ -146,12 +147,12 @@ class GPT3MLP(Module):

         hidden_size = config.hidden_size
         # Project to 4h.
-        self.dense_h_to_4h = Linear(hidden_size, 4 * hidden_size)
+        self.dense_h_to_4h = nn.Linear(hidden_size, 4 * hidden_size)
         self.activation_func = F.gelu
         # Project back to h.
-        self.dense_4h_to_h = Linear(4 * hidden_size, hidden_size)
+        self.dense_4h_to_h = nn.Linear(4 * hidden_size, hidden_size)

-        self.dropout = Dropout(config.hidden_dropout_prob)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)

     def forward(self, hidden_states):

@@ -164,7 +165,7 @@ class GPT3MLP(Module):
         return output


-class GPT3TransformerLayer(Module):
+class GPT3TransformerLayer(nn.Module):
     """A single transformer layer.

     Transformer layer takes input with size [s, b, h] and returns an
@@ -175,14 +176,14 @@ class GPT3TransformerLayer(Module):
         super().__init__()

         # Layernorm on the input data.
-        self.input_layernorm = LayerNorm(
+        self.input_layernorm = nn.LayerNorm(
             config.hidden_size, eps=config.layernorm_epsilon)

         # Self attention.
         self.attention = GPT3SelfAttention(config)

         # Layernorm on the attention output
-        self.post_attention_layernorm = LayerNorm(
+        self.post_attention_layernorm = nn.LayerNorm(
             config.hidden_size, eps=config.layernorm_epsilon)

         # MLP
@@ -208,7 +209,7 @@ class GPT3TransformerLayer(Module):
         return output


-class GPT3Transformer(Module):
+class GPT3Transformer(nn.Module):
     """Transformer class."""

     def __init__(self, config):
@@ -223,7 +224,7 @@ class GPT3Transformer(Module):
             [GPT3TransformerLayer(config) for _ in range(self.num_layers)])

         # Final layer norm before output.
-        self.final_layernorm = LayerNorm(
+        self.final_layernorm = nn.LayerNorm(
             config.hidden_size, eps=config.layernorm_epsilon)

     def _get_layer(self, layer_number):
@@ -242,7 +243,7 @@ class GPT3Transformer(Module):
         return hidden_states


-class GPT3TransformerLanguageModel(Module):
+class GPT3TransformerLanguageModel(nn.Module):
     """Transformer language model.

     Arguments:
@@ -259,10 +260,11 @@ class GPT3TransformerLanguageModel(Module):
         super().__init__()

         # Embeddings.
-        self.word_embeddings = Embedding(config.vocab_size, config.hidden_size)
-        self.position_embeddings = Embedding(config.max_position_embeddings,
-                                             config.hidden_size)
-        self.embedding_dropout = Dropout(config.hidden_dropout_prob)
+        self.word_embeddings = nn.Embedding(config.vocab_size,
+                                            config.hidden_size)
+        self.position_embeddings = nn.Embedding(config.max_position_embeddings,
+                                                config.hidden_size)
+        self.embedding_dropout = nn.Dropout(config.hidden_dropout_prob)

         # Transformer.
         self.transformer = GPT3Transformer(config)
@@ -286,19 +288,19 @@ class GPT3Model(PreTrainedModel):

     def _init_weights(self, module):
         """Initialize the weights"""
-        if isinstance(module, Linear):
+        if isinstance(module, nn.Linear):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
             module.weight.data.normal_(
                 mean=0.0, std=self.config.initializer_range)
             if module.bias is not None:
                 module.bias.data.zero_()
-        elif isinstance(module, Embedding):
+        elif isinstance(module, nn.Embedding):
             module.weight.data.normal_(
                 mean=0.0, std=self.config.initializer_range)
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
-        elif isinstance(module, LayerNorm):
+        elif isinstance(module, nn.LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)

@@ -325,7 +327,7 @@ class GPT3Model(PreTrainedModel):
         logits = self.language_model(input_ids, attention_mask, position_ids)
         loss = None
         if labels is not None:
-            loss_fct = CrossEntropyLoss()
+            loss_fct = nn.CrossEntropyLoss()
             loss = loss_fct(
                 logits.view(-1, self.config.vocab_size), labels.view(-1))
         return addict.Dict(loss=loss, logits=logits)
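
As a shapes-only illustration (not the module's code), the [b, s, np*hn] to [b, np, s, hn] reshuffle described in GPT3SelfAttention._transpose_for_scores can be reproduced standalone:

    import torch

    b, s, num_heads, head_dim = 2, 8, 12, 64       # batch, sequence, np, hn
    x = torch.randn(b, s, num_heads * head_dim)    # [b, s, np*hn]
    y = x.view(b, s, num_heads, head_dim).permute(0, 2, 1, 3)
    assert y.shape == (b, num_heads, s, head_dim)  # [b, np, s, hn]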


modelscope/models/nlp/gpt3/tokenizer_gpt3.py  (+69 −0)

@@ -0,0 +1,69 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from tokenizers import Tokenizer


class JiebaBPETokenizer:
    """SentencePiece BPE tokenizer with jieba integration"""

    def __init__(self, tokenizer_json_file):
        self.name = 'Jieba BPE Tokenizer'

        self.tokenizer = Tokenizer.from_file(tokenizer_json_file)
        self.eod_id = self.tokenizer.token_to_id('<|endoftext|>')
        try:
            import jieba
        except ImportError:
            raise ImportError(
                'You need to install jieba to use JiebaBPETokenizer. '
                'See https://pypi.org/project/jieba/ for installation.')
        self.jieba = jieba
        self.new_line = self.vocab['\n']
        self.sep_token = self.vocab['<sep>']

    @property
    def vocab_size(self):
        return self.tokenizer.get_vocab_size(with_added_tokens=True)

    @property
    def vocab(self):
        return self.tokenizer.get_vocab(with_added_tokens=True)

    @property
    def inv_vocab(self):
        vocab = self.vocab
        inv_vocab = dict()
        for key, val in vocab.items():
            inv_vocab[val] = key
        return inv_vocab

    def tokenize(self, text, is_code=False):
        """Segment text with jieba, then BPE-encode the segments;
        code is BPE-encoded directly without pre-segmentation.
        """
        if not is_code:
            seg_list = [x for x in self.jieba.cut(text)]
            return self.tokenizer.encode(
                seg_list, is_pretokenized=True, add_special_tokens=True).ids
        else:
            return self.tokenizer.encode(
                text, is_pretokenized=False, add_special_tokens=True).ids

    def detokenize(self, token_ids):
        text = self.tokenizer.decode(token_ids, skip_special_tokens=False)
        return text

    @property
    def eod(self):
        return self.eod_id
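
A minimal usage sketch; it assumes a local tokenizer.json defining the '<|endoftext|>', '<sep>' and newline entries the constructor looks up:

    from modelscope.models.nlp.gpt3 import JiebaBPETokenizer

    tokenizer = JiebaBPETokenizer('tokenizer.json')  # hypothetical local path
    ids = tokenizer.tokenize('深蓝的天空中挂着一轮金黄的圆月')  # jieba cut, then BPE over the segments
    text = tokenizer.detokenize(ids)  # decodes back to a string
    print(tokenizer.vocab_size, tokenizer.eod, ids, text)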

modelscope/pipelines/nlp/distributed_gpt3_pipeline.py  (+54 −0)

@@ -0,0 +1,54 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from typing import Any, Dict

import torch

from modelscope.metainfo import Pipelines
from modelscope.models.nlp.gpt3.distributed_gpt3 import DistributedGPT3
from modelscope.pipelines.base import DistributedPipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import TextGenerationJiebaPreprocessor
from modelscope.utils.constant import Tasks


@PIPELINES.register_module(
    Tasks.text_generation, module_name=Pipelines.gpt3_generation)
class DistributedGPT3Pipeline(DistributedPipeline):
    """Text generation pipeline for the distributed (tensor parallel) GPT3 model.
    """

    model = None

    def __init__(self, model, preprocessor=None, **kwargs):
        if preprocessor is None:
            preprocessor = TextGenerationJiebaPreprocessor(model)
        super().__init__(model, preprocessor=preprocessor, **kwargs)
        assert hasattr(preprocessor, 'tokenizer')

    @classmethod
    def _instantiate_one(cls, rank, model_dir, **kwargs):
        cls.model = DistributedGPT3(model_dir, rank, **kwargs)
        cls.model.eval()

    @classmethod
    def _forward_one(cls, inputs: Dict[str, Any]) -> Dict[str, Any]:
        tokens = inputs['inputs']['input_ids'].cuda(
            torch.cuda.current_device())
        return cls.model.generate(tokens)

    def postprocess(self, inputs: Dict[str, Any],
                    **postprocess_params) -> Dict[str, str]:
        """Process the prediction results.

        Args:
            inputs (Dict[str, Any]): the generated token ids from the model

        Returns:
            Dict[str, str]: the prediction results
        """
        from modelscope.outputs import OutputKeys
        return {
            OutputKeys.TEXT:
            self.preprocessor.tokenizer.detokenize(inputs[0].tolist())
        }
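
With the registration above, the pipeline is reachable through the standard entry point; a minimal sketch based on the test at the end of this change (the 13B variant additionally needs the sharded checkpoints described there):

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    pipe = pipeline(Tasks.text_generation, model='damo/nlp_gpt3_text-generation_1.3B')
    print(pipe('好的'))  # {OutputKeys.TEXT: '...'} via postprocess above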

modelscope/preprocessors/__init__.py  (+2 −0)

@@ -32,6 +32,7 @@ if TYPE_CHECKING:
         Tokenize,
         WordSegmentationBlankSetToLabelPreprocessor,
         ZeroShotClassificationPreprocessor,
+        TextGenerationJiebaPreprocessor,
         SentencePiecePreprocessor,
     )
     from .space import (DialogIntentPredictionPreprocessor,
@@ -72,6 +73,7 @@ else:
             'Text2TextGenerationPreprocessor',
             'WordSegmentationBlankSetToLabelPreprocessor',
             'ZeroShotClassificationPreprocessor',
+            'TextGenerationJiebaPreprocessor',
             'SentencePiecePreprocessor',
         ],
         'space': [


modelscope/preprocessors/nlp/__init__.py  (+2 −0)

@@ -21,6 +21,7 @@ if TYPE_CHECKING:
         Tokenize,
         WordSegmentationBlankSetToLabelPreprocessor,
         ZeroShotClassificationPreprocessor,
+        TextGenerationJiebaPreprocessor,
         SentencePiecePreprocessor,
     )

@@ -42,6 +43,7 @@ else:
             'Text2TextGenerationPreprocessor',
             'WordSegmentationBlankSetToLabelPreprocessor',
             'ZeroShotClassificationPreprocessor',
+            'TextGenerationJiebaPreprocessor',
             'SentencePiecePreprocessor',
         ],
         'text_error_correction': [


modelscope/preprocessors/nlp/nlp_base.py  (+35 −0)

@@ -494,6 +494,41 @@ class TextGenerationPreprocessor(NLPTokenizerPreprocessorBase):
         }


+@PREPROCESSORS.register_module(
+    Fields.nlp, module_name=Preprocessors.text_gen_jieba_tokenizer)
+class TextGenerationJiebaPreprocessor(Preprocessor):
+    """The jieba tokenizer preprocessor used in text generation.
+    """
+
+    def __init__(self, model_dir: str, *args, **kwargs):
+        from modelscope.models.nlp.gpt3 import JiebaBPETokenizer
+        super().__init__(*args, **kwargs)
+        self.tokenizer = JiebaBPETokenizer(
+            osp.join(model_dir, 'tokenizer.json'))
+
+    def __call__(self, data: str) -> Dict[str, Any]:
+        """Tokenize the raw input text with jieba and BPE.
+
+        Args:
+            data (str): a sentence
+                Example:
+                    '深蓝的天空中挂着一轮金黄的圆月,下面是海边的沙地'
+        Returns:
+            Dict[str, Any]: the preprocessed data
+                Example:
+                    {'input_ids': tensor([[1, 2, 3, 4]])}
+        """
+        import torch
+
+        return {
+            'input_ids':
+            torch.tensor(self.tokenizer.tokenize(data)).unsqueeze_(0)
+        }


 @PREPROCESSORS.register_module(
     Fields.nlp,
     module_name=Preprocessors.word_segment_text_to_label_preprocessor)
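
A short sketch of calling the new preprocessor directly; the model directory is a hypothetical path assumed to contain the tokenizer.json shipped with the GPT3 checkpoints:

    from modelscope.preprocessors import TextGenerationJiebaPreprocessor

    preprocessor = TextGenerationJiebaPreprocessor('/path/to/model')
    batch = preprocessor('深蓝的天空中挂着一轮金黄的圆月')
    print(batch['input_ids'].shape)  # torch.Size([1, seq_len])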


modelscope/utils/nlp/distributed.py  (+4 −1)

@@ -35,7 +35,10 @@ def initialize_distributed(rank, mpu, world_size, model_parallel_size,
     init_method = 'tcp://'
     init_method += master_ip + ':' + master_port
     torch.distributed.init_process_group(
-        backend='nccl', world_size=8, rank=rank, init_method=init_method)
+        backend='nccl',
+        world_size=world_size,
+        rank=rank,
+        init_method=init_method)
     # Set the model-parallel communicators.
     mpu.initialize_model_parallel(model_parallel_size)
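
A hedged illustration of why this fix matters: with the hardcoded value, launching fewer than eight ranks would still try to build an eight-process group. The spawn harness below is hypothetical; the megatron_util import path and passing master_ip/master_port as keyword arguments are assumptions based on the commit message and the truncated signature above:

    import torch.multiprocessing as mp

    from megatron_util import mpu  # assumption: the megatron-util package reused by this change
    from modelscope.utils.nlp.distributed import initialize_distributed

    def _worker(rank, world_size):
        # Before the fix, this created an 8-rank group regardless of world_size.
        initialize_distributed(rank, mpu, world_size, model_parallel_size=world_size,
                               master_ip='127.0.0.1', master_port='29500')

    if __name__ == '__main__':
        world_size = 4
        mp.spawn(_worker, args=(world_size,), nprocs=world_size)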



tests/pipelines/test_gpt3_text_generation.py  (+58 −0)

@@ -0,0 +1,58 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class TextGPT3GenerationTest(unittest.TestCase):

    def setUp(self) -> None:
        # please make sure this local path exists.
        self.model_id_1_3B = 'damo/nlp_gpt3_text-generation_1.3B'
        self.model_id_2_7B = 'damo/nlp_gpt3_text-generation_2.7B'
        self.model_id_13B = 'damo/nlp_gpt3_text-generation_13B'
        self.model_dir_13B = snapshot_download(self.model_id_13B)
        self.input = '好的'

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_gpt3_1_3B(self):
        pipe = pipeline(Tasks.text_generation, model=self.model_id_1_3B)
        print(pipe(self.input))

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_gpt3_2_7B(self):
        pipe = pipeline(Tasks.text_generation, model=self.model_id_2_7B)
        print(pipe(self.input))

    @unittest.skip('distributed gpt3 13B, skipped')
    def test_gpt3_13B(self):
        """The model can be downloaded from the link on
        TODO: add gpt3 checkpoint link
        After downloading, you should have a gpt3 model structure like this:
        nlp_gpt3_text-generation_13B
            |_ config.json
            |_ configuration.json
            |_ tokenizer.json
            |_ model  <-- an empty directory

        Model binaries shall be downloaded separately to populate the model
        directory, so that the model directory would contain the following
        binaries:
            |_ model
                |_ mp_rank_00_model_states.pt
                |_ mp_rank_01_model_states.pt
                |_ mp_rank_02_model_states.pt
                |_ mp_rank_03_model_states.pt
                |_ mp_rank_04_model_states.pt
                |_ mp_rank_05_model_states.pt
                |_ mp_rank_06_model_states.pt
                |_ mp_rank_07_model_states.pt
        """
        pipe = pipeline(Tasks.text_generation, model=self.model_dir_13B)
        print(pipe(self.input))


if __name__ == '__main__':
    unittest.main()
