Add GPT3 tensor-parallel inference code based on Megatron-v3
Reuses DistributedPipeline and megatron-util
Applicable models: GPT-3 pretrained generation models with 1.3B/2.7B/13B parameters
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10416721
master
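For reference, a minimal usage sketch of the new pipeline (the model id and input come from the test case added at the end of this diff; the call goes through the standard ModelScope pipeline API):

# Usage sketch: the gpt3-generation pipeline is resolved automatically for these models.
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

pipe = pipeline(Tasks.text_generation, model='damo/nlp_gpt3_text-generation_1.3B')
print(pipe('好的'))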
@@ -227,6 +227,7 @@ class Pipelines(object):
    zero_shot_classification = 'zero-shot-classification'
    text_error_correction = 'text-error-correction'
    plug_generation = 'plug-generation'
    gpt3_generation = 'gpt3-generation'
    faq_question_answering = 'faq-question-answering'
    conversational_text_to_sql = 'conversational-text-to-sql'
    table_question_answering_pipeline = 'table-question-answering-pipeline'
@@ -324,6 +325,7 @@ class Preprocessors(object):
    bert_seq_cls_tokenizer = 'bert-seq-cls-tokenizer'
    text_gen_tokenizer = 'text-gen-tokenizer'
    text2text_gen_preprocessor = 'text2text-gen-preprocessor'
    text_gen_jieba_tokenizer = 'text-gen-jieba-tokenizer'
    text2text_translate_preprocessor = 'text2text-translate-preprocessor'
    token_cls_tokenizer = 'token-cls-tokenizer'
    ner_tokenizer = 'ner-tokenizer'
@@ -7,11 +7,13 @@ if TYPE_CHECKING:
    from .configuration_gpt3 import GPT3Config
    from .modeling_gpt3 import GPT3Model
    from .gpt3_for_text_generation import GPT3ForTextGeneration
    from .tokenizer_gpt3 import JiebaBPETokenizer
else:
    _import_structure = {
        'configuration_gpt3': ['GPT3Config'],
        'modeling_gpt3': ['GPT3Model'],
        'gpt3_for_text_generation': ['GPT3ForTextGeneration'],
        'tokenizer_gpt3': ['JiebaBPETokenizer'],
    }

    import sys
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
@@ -21,25 +22,48 @@ logger = logging.get_logger(__name__)
class GPT3Config(PretrainedConfig):

    model_type = 'gpt'

    def __init__(self,
                 vocab_size=25600,
                 hidden_size=768,
                 num_hidden_layers=12,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 hidden_act='gelu',
                 hidden_dropout_prob=0.1,
                 attention_probs_dropout_prob=0.1,
                 max_position_embeddings=2048,
                 type_vocab_size=2,
                 layernorm_epsilon=1e-12,
                 **kwargs):
    model_type = 'gpt3'

    def __init__(
            self,
            vocab_size=25600,
            hidden_size=768,
            ffn_hidden_size=None,
            num_hidden_layers=12,
            num_attention_heads=12,
            intermediate_size=3072,
            hidden_act='gelu',
            hidden_dropout_prob=0.1,
            attention_probs_dropout_prob=0.1,
            max_position_embeddings=2048,
            type_vocab_size=2,
            layernorm_epsilon=1e-12,
            bias_gelu_fusion=True,
            fp32_residual_connection=False,
            sequence_parallel=False,
            fp16=False,
            bf16=False,
            apply_query_key_layer_scaling=True,
            attention_softmax_in_fp32=False,
            kv_channels=None,
            masked_softmax_fusion=True,
            attention_dropout=0.1,
            bias_dropout_fusion=True,
            apply_residual_connection_post_layernorm=False,
            hidden_dropout=0.1,
            init_method_std=0.02,
            # generate
            eod_id=7,
            tokens_to_generate=100,
            top_k=0,
            top_p=0.9,
            **kwargs):
        super().__init__(layer_norm_eps=layernorm_epsilon, **kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.ffn_hidden_size = 4 * hidden_size \
            if ffn_hidden_size is None else ffn_hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
@@ -49,3 +73,39 @@ class GPT3Config(PretrainedConfig):
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.layernorm_epsilon = layernorm_epsilon
        self.bias_gelu_fusion = bias_gelu_fusion
        self.fp32_residual_connection = fp32_residual_connection
        self.sequence_parallel = sequence_parallel
        self.fp16 = fp16
        self.bf16 = bf16
        assert not (fp16 and bf16)
        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = attention_softmax_in_fp32
        if kv_channels is None:
            assert hidden_size % num_attention_heads == 0
            self.kv_channels = hidden_size // num_attention_heads
        self.masked_softmax_fusion = masked_softmax_fusion
        self.attention_dropout = attention_dropout
        self.bias_dropout_fusion = bias_dropout_fusion
        self.apply_residual_connection_post_layernorm = \
            apply_residual_connection_post_layernorm
        self.hidden_dropout = hidden_dropout
        self.init_method_std = init_method_std
        self.eod_id = eod_id
        self.tokens_to_generate = tokens_to_generate
        self.top_k = top_k
        self.top_p = top_p

        TORCH_MAJOR = int(torch.__version__.split('.')[0])
        TORCH_MINOR = int(torch.__version__.split('.')[1])
        self.no_persist_layer_norm = \
            TORCH_MAJOR < 1 or (TORCH_MAJOR == 1 and TORCH_MINOR < 11)

    @property
    def params_dtype(self):
        if self.fp16:
            return torch.half
        elif self.bf16:
            return torch.bfloat16
        else:
            return torch.float
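To make the new fields concrete, a small sketch of how the extended config behaves (the values are the defaults defined above; this is an illustration, not code from the diff):

# Sketch: derived and new GPT3Config fields introduced in this change.
import torch

from modelscope.models.nlp.gpt3 import GPT3Config

config = GPT3Config(vocab_size=25600, hidden_size=768, fp16=True)
assert config.ffn_hidden_size == 4 * config.hidden_size  # derived when not passed explicitly
assert config.params_dtype == torch.half                 # follows the fp16/bf16 flags
print(config.tokens_to_generate, config.top_k, config.top_p)  # generation defaults: 100 0 0.9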
@@ -19,8 +19,7 @@ from typing import Optional, Union
import addict
import torch
from torch.nn import (CrossEntropyLoss, Dropout, Embedding, LayerNorm, Linear,
                      Module, Softmax)
from torch import nn
from torch.nn import functional as F
from transformers.modeling_utils import PreTrainedModel
@@ -28,7 +27,7 @@ from modelscope.utils.constant import ModelFile
from .configuration_gpt3 import GPT3Config


class GPT3SelfAttention(Module):
class GPT3SelfAttention(nn.Module):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [s, b, h]
@@ -44,13 +43,15 @@ class GPT3SelfAttention(Module):
        self.hidden_size_per_attention_head = \
            self.hidden_size // self.num_attention_heads

        self.query_key_value = Linear(self.hidden_size, 3 * self.hidden_size)
        self.softmax = Softmax(dim=-1)
        self.attention_dropout = Dropout(config.attention_probs_dropout_prob)
        self.query_key_value = nn.Linear(self.hidden_size,
                                         3 * self.hidden_size)
        self.softmax = nn.Softmax(dim=-1)
        self.attention_dropout = nn.Dropout(
            config.attention_probs_dropout_prob)

        # Output.
        self.dense = Linear(self.hidden_size, self.hidden_size)
        self.output_dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.dense = nn.Linear(self.hidden_size, self.hidden_size)
        self.output_dropout = nn.Dropout(config.hidden_dropout_prob)

    def _transpose_for_scores(self, tensor):
        """Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with
@@ -133,7 +134,7 @@ class GPT3SelfAttention(Module):
        return output


class GPT3MLP(Module):
class GPT3MLP(nn.Module):
    """MLP.

    MLP will take the input with h hidden state, project it to 4*h
@@ -146,12 +147,12 @@ class GPT3MLP(Module):
        hidden_size = config.hidden_size

        # Project to 4h.
        self.dense_h_to_4h = Linear(hidden_size, 4 * hidden_size)
        self.dense_h_to_4h = nn.Linear(hidden_size, 4 * hidden_size)
        self.activation_func = F.gelu

        # Project back to h.
        self.dense_4h_to_h = Linear(4 * hidden_size, hidden_size)
        self.dense_4h_to_h = nn.Linear(4 * hidden_size, hidden_size)
        self.dropout = Dropout(config.hidden_dropout_prob)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states):
@@ -164,7 +165,7 @@ class GPT3MLP(Module):
        return output


class GPT3TransformerLayer(Module):
class GPT3TransformerLayer(nn.Module):
    """A single transformer layer.

    Transformer layer takes input with size [s, b, h] and returns an
@@ -175,14 +176,14 @@ class GPT3TransformerLayer(Module):
        super().__init__()

        # Layernorm on the input data.
        self.input_layernorm = LayerNorm(
        self.input_layernorm = nn.LayerNorm(
            config.hidden_size, eps=config.layernorm_epsilon)

        # Self attention.
        self.attention = GPT3SelfAttention(config)

        # Layernorm on the attention output
        self.post_attention_layernorm = LayerNorm(
        self.post_attention_layernorm = nn.LayerNorm(
            config.hidden_size, eps=config.layernorm_epsilon)

        # MLP
@@ -208,7 +209,7 @@ class GPT3TransformerLayer(Module):
        return output


class GPT3Transformer(Module):
class GPT3Transformer(nn.Module):
    """Transformer class."""

    def __init__(self, config):
@@ -223,7 +224,7 @@ class GPT3Transformer(Module):
            [GPT3TransformerLayer(config) for _ in range(self.num_layers)])

        # Final layer norm before output.
        self.final_layernorm = LayerNorm(
        self.final_layernorm = nn.LayerNorm(
            config.hidden_size, eps=config.layernorm_epsilon)

    def _get_layer(self, layer_number):
@@ -242,7 +243,7 @@ class GPT3Transformer(Module):
        return hidden_states


class GPT3TransformerLanguageModel(Module):
class GPT3TransformerLanguageModel(nn.Module):
    """Transformer language model.

    Arguments:
@@ -259,10 +260,11 @@ class GPT3TransformerLanguageModel(Module):
        super().__init__()

        # Embeddings.
        self.word_embeddings = Embedding(config.vocab_size, config.hidden_size)
        self.position_embeddings = Embedding(config.max_position_embeddings,
                                             config.hidden_size)
        self.embedding_dropout = Dropout(config.hidden_dropout_prob)
        self.word_embeddings = nn.Embedding(config.vocab_size,
                                            config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings,
                                                config.hidden_size)
        self.embedding_dropout = nn.Dropout(config.hidden_dropout_prob)

        # Transformer.
        self.transformer = GPT3Transformer(config)
@@ -286,19 +288,19 @@ class GPT3Model(PreTrainedModel):

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, Linear):
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(
                mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, Embedding):
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(
                mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, LayerNorm):
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
@@ -325,7 +327,7 @@ class GPT3Model(PreTrainedModel):
        logits = self.language_model(input_ids, attention_mask, position_ids)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                logits.view(-1, self.config.vocab_size), labels.view(-1))
        return addict.Dict(loss=loss, logits=logits)
@@ -0,0 +1,69 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from tokenizers import Tokenizer


class JiebaBPETokenizer:
    """SentencePiece BPE tokenizer with Jieba integration"""

    def __init__(self, tokenizer_json_file):
        self.name = 'Jieba BPE Tokenizer'

        self.tokenizer = Tokenizer.from_file(tokenizer_json_file)
        self.eod_id = self.tokenizer.token_to_id('<|endoftext|>')
        try:
            import jieba
        except ImportError:
            raise ImportError(
                'You need to install jieba to use JiebaBPETokenizer. '
                'See https://pypi.org/project/jieba/ for installation.')
        self.jieba = jieba
        self.new_line = self.vocab['\n']
        self.sep_token = self.vocab['<sep>']

    @property
    def vocab_size(self):
        return self.tokenizer.get_vocab_size(with_added_tokens=True)

    @property
    def vocab(self):
        return self.tokenizer.get_vocab(with_added_tokens=True)

    @property
    def inv_vocab(self):
        vocab = self.vocab
        inv_vocab = dict()
        for key, val in vocab.items():
            inv_vocab[val] = key
        return inv_vocab
    def tokenize(self, text, is_code=False):
        """Tokenize text into BPE token ids.

        Plain text is first segmented with jieba and then BPE-encoded as
        pre-tokenized input; code (is_code=True) is BPE-encoded directly.
        """
        if not is_code:
            seg_list = [x for x in self.jieba.cut(text)]
            return self.tokenizer.encode(
                seg_list, is_pretokenized=True, add_special_tokens=True).ids
        else:
            return self.tokenizer.encode(
                text, is_pretokenized=False, add_special_tokens=True).ids

    def detokenize(self, token_ids):
        text = self.tokenizer.decode(token_ids, skip_special_tokens=False)
        return text

    @property
    def eod(self):
        return self.eod_id
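A short usage sketch of the tokenizer above; the tokenizer.json path is illustrative and must point to the file shipped with the model:

# Sketch: round-tripping text through JiebaBPETokenizer.
from modelscope.models.nlp.gpt3 import JiebaBPETokenizer

tokenizer = JiebaBPETokenizer('tokenizer.json')  # illustrative path
ids = tokenizer.tokenize('深蓝的天空中挂着一轮金黄的圆月')  # jieba segmentation, then BPE
print(len(ids), tokenizer.vocab_size, tokenizer.eod)
print(tokenizer.detokenize(ids))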
@@ -0,0 +1,54 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict

import torch

from modelscope.metainfo import Pipelines
from modelscope.models.nlp.gpt3.distributed_gpt3 import DistributedGPT3
from modelscope.pipelines.base import DistributedPipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import TextGenerationJiebaPreprocessor
from modelscope.utils.constant import Tasks


@PIPELINES.register_module(
    Tasks.text_generation, module_name=Pipelines.gpt3_generation)
class DistributedGPT3Pipeline(DistributedPipeline):
| """This class is used to instantiate the gpt3 model. | |||
| """ | |||
    model = None

    def __init__(self, model, preprocessor=None, **kwargs):
        if preprocessor is None:
            preprocessor = TextGenerationJiebaPreprocessor(model)
        super().__init__(model, preprocessor=preprocessor, **kwargs)
        assert hasattr(preprocessor, 'tokenizer')

    @classmethod
    def _instantiate_one(cls, rank, model_dir, **kwargs):
        cls.model = DistributedGPT3(model_dir, rank, **kwargs)
        cls.model.eval()

    @classmethod
    def _forward_one(cls, inputs: Dict[str, Any]) -> Dict[str, Any]:
        tokens = inputs['inputs']['input_ids'].cuda(
            torch.cuda.current_device())
        return cls.model.generate(tokens)

    def postprocess(self, inputs: Dict[str, Any],
                    **postprocess_params) -> Dict[str, str]:
| """process the prediction results | |||
| Args: | |||
| inputs (Dict[str, Any]): _description_ | |||
| Returns: | |||
| Dict[str, str]: the prediction results | |||
| """ | |||
        from modelscope.outputs import OutputKeys
        return {
            OutputKeys.TEXT:
            self.preprocessor.tokenizer.detokenize(inputs[0].tolist())
        }
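The pipeline delegates decoding to DistributedGPT3.generate, whose sampling behaviour is governed by the tokens_to_generate, top_k and top_p fields added to GPT3Config. The generate implementation itself is not part of this diff; the following is only a generic sketch of the top-k / nucleus (top-p) filtering step those two parameters usually control:

# Sketch: generic top-k / top-p filtering over a [vocab_size] logits vector.
# Illustration only; not the code used by DistributedGPT3.generate.
import torch


def top_k_top_p_filtering(logits: torch.Tensor,
                          top_k: int = 0,
                          top_p: float = 0.9,
                          filter_value: float = -float('inf')) -> torch.Tensor:
    if top_k > 0:
        # Keep only the k highest-scoring tokens.
        kth_value = torch.topk(logits, top_k).values[-1]
        logits[logits < kth_value] = filter_value
    if 0.0 < top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(
            torch.softmax(sorted_logits, dim=-1), dim=-1)
        # Drop tokens once the cumulative probability exceeds top_p,
        # always keeping at least the most likely token.
        remove_mask = cumulative_probs > top_p
        remove_mask[1:] = remove_mask[:-1].clone()
        remove_mask[0] = False
        logits[sorted_indices[remove_mask]] = filter_value
    return logits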
@@ -32,6 +32,7 @@ if TYPE_CHECKING:
        Tokenize,
        WordSegmentationBlankSetToLabelPreprocessor,
        ZeroShotClassificationPreprocessor,
        TextGenerationJiebaPreprocessor,
        SentencePiecePreprocessor,
    )
    from .space import (DialogIntentPredictionPreprocessor,
@@ -72,6 +73,7 @@ else:
            'Text2TextGenerationPreprocessor',
            'WordSegmentationBlankSetToLabelPreprocessor',
            'ZeroShotClassificationPreprocessor',
            'TextGenerationJiebaPreprocessor',
            'SentencePiecePreprocessor',
        ],
        'space': [
@@ -21,6 +21,7 @@ if TYPE_CHECKING:
        Tokenize,
        WordSegmentationBlankSetToLabelPreprocessor,
        ZeroShotClassificationPreprocessor,
        TextGenerationJiebaPreprocessor,
        SentencePiecePreprocessor,
    )
@@ -42,6 +43,7 @@ else:
            'Text2TextGenerationPreprocessor',
            'WordSegmentationBlankSetToLabelPreprocessor',
            'ZeroShotClassificationPreprocessor',
            'TextGenerationJiebaPreprocessor',
            'SentencePiecePreprocessor',
        ],
        'text_error_correction': [
@@ -494,6 +494,41 @@ class TextGenerationPreprocessor(NLPTokenizerPreprocessorBase):
        }


@PREPROCESSORS.register_module(
    Fields.nlp, module_name=Preprocessors.text_gen_jieba_tokenizer)
class TextGenerationJiebaPreprocessor(Preprocessor):
    """The jieba tokenizer preprocessor used in text generation.
    """

    def __init__(self, model_dir: str, *args, **kwargs):
        from modelscope.models.nlp.gpt3 import JiebaBPETokenizer
        super().__init__(*args, **kwargs)
        self.tokenizer = JiebaBPETokenizer(
            osp.join(model_dir, 'tokenizer.json'))

    def __call__(self, data: str) -> Dict[str, Any]:
        """process the raw input data

        Args:
            data (str): a sentence
                Example:
                    '深蓝的天空中挂着一轮金黄的圆月,下面是海边的沙地'

        Returns:
            Dict[str, Any]: the preprocessed data
                Example:
                    {'input_ids': tensor([[1, 2, 3, 4]])}
| """ | |||
| import torch | |||
| return { | |||
| 'input_ids': | |||
| torch.tensor(self.tokenizer.tokenize(data)).unsqueeze_(0) | |||
| } | |||
| @PREPROCESSORS.register_module( | |||
| Fields.nlp, | |||
| module_name=Preprocessors.word_segment_text_to_label_preprocessor) | |||
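For clarity, this preprocessor produces exactly the batch consumed by DistributedGPT3Pipeline._forward_one above; a minimal sketch (the local model directory is illustrative and must contain tokenizer.json):

# Sketch: what TextGenerationJiebaPreprocessor feeds into the pipeline.
from modelscope.preprocessors import TextGenerationJiebaPreprocessor

preprocessor = TextGenerationJiebaPreprocessor('/path/to/nlp_gpt3_text-generation_1.3B')
batch = preprocessor('深蓝的天空中挂着一轮金黄的圆月,下面是海边的沙地')
print(batch['input_ids'].shape)  # torch.Size([1, seq_len]) -- a single-sample batch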
@@ -35,7 +35,10 @@ def initialize_distributed(rank, mpu, world_size, model_parallel_size,
    init_method = 'tcp://'
    init_method += master_ip + ':' + master_port
    torch.distributed.init_process_group(
        backend='nccl', world_size=8, rank=rank, init_method=init_method)
        backend='nccl',
        world_size=world_size,
        rank=rank,
        init_method=init_method)
    # Set the model-parallel communicators.
    mpu.initialize_model_parallel(model_parallel_size)
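The change above replaces the hard-coded world size of 8 with the world_size argument, so the initializer also works for setups with fewer GPUs. A hedged sketch of how such an initializer is typically driven, one process per GPU (the worker function and address below are stand-ins, not the megatron-util code):

# Sketch only: spawning one process per GPU and initializing the NCCL group,
# mirroring the pattern of the patched initialize_distributed helper.
import torch
import torch.multiprocessing as mp


def _worker(rank: int, world_size: int):
    torch.cuda.set_device(rank)
    torch.distributed.init_process_group(
        backend='nccl',
        world_size=world_size,      # no longer hard-coded to 8
        rank=rank,
        init_method='tcp://127.0.0.1:29500')
    # mpu.initialize_model_parallel(model_parallel_size) would follow here.


if __name__ == '__main__':
    world_size = torch.cuda.device_count()
    mp.spawn(_worker, args=(world_size, ), nprocs=world_size)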
@@ -0,0 +1,58 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class TextGPT3GenerationTest(unittest.TestCase):

    def setUp(self) -> None:
        # please make sure the 13B local model path exists (see test_gpt3_13B below).
        self.model_id_1_3B = 'damo/nlp_gpt3_text-generation_1.3B'
        self.model_id_2_7B = 'damo/nlp_gpt3_text-generation_2.7B'
        self.model_id_13B = 'damo/nlp_gpt3_text-generation_13B'
        self.model_dir_13B = snapshot_download(self.model_id_13B)
        self.input = '好的'

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_gpt3_1_3B(self):
        pipe = pipeline(Tasks.text_generation, model=self.model_id_1_3B)
        print(pipe(self.input))

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_gpt3_2_7B(self):
        pipe = pipeline(Tasks.text_generation, model=self.model_id_2_7B)
        print(pipe(self.input))

    @unittest.skip('distributed gpt3 13B, skipped')
    def test_gpt3_13B(self):
| """ The model can be downloaded from the link on | |||
| TODO: add gpt3 checkpoint link | |||
| After downloading, you should have a gpt3 model structure like this: | |||
| nlp_gpt3_text-generation_13B | |||
| |_ config.json | |||
| |_ configuration.json | |||
| |_ tokenizer.json | |||
| |_ model <-- an empty directory | |||
| Model binaries shall be downloaded separately to populate the model directory, so that | |||
| the model directory would contain the following binaries: | |||
| |_ model | |||
| |_ mp_rank_00_model_states.pt | |||
| |_ mp_rank_01_model_states.pt | |||
| |_ mp_rank_02_model_states.pt | |||
| |_ mp_rank_03_model_states.pt | |||
| |_ mp_rank_04_model_states.pt | |||
| |_ mp_rank_05_model_states.pt | |||
| |_ mp_rank_06_model_states.pt | |||
| |_ mp_rank_07_model_states.pt | |||
| """ | |||
        pipe = pipeline(Tasks.text_generation, model=self.model_dir_13B)
        print(pipe(self.input))


if __name__ == '__main__':
    unittest.main()