Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9931748 (target branch: master)
@@ -55,6 +55,7 @@ class Models(object):
     lcrf = 'lstm-crf'
     bart = 'bart'
     gpt3 = 'gpt3'
+    plug = 'plug'
     bert_for_ds = 'bert-for-document-segmentation'

     # audio models
@@ -172,6 +173,7 @@ class Pipelines(object):
     dialog_state_tracking = 'dialog-state-tracking'
     zero_shot_classification = 'zero-shot-classification'
     text_error_correction = 'text-error-correction'
+    plug_generation = 'plug-generation'
     faq_question_answering = 'faq-question-answering'
     conversational_text_to_sql = 'conversational-text-to-sql'
     relation_extraction = 'relation-extraction'
@@ -28,6 +28,7 @@ if TYPE_CHECKING:
                                     SingleBackboneTaskModelBase)
     from .bart_for_text_error_correction import BartForTextErrorCorrection
     from .gpt3 import GPT3ForTextGeneration
+    from .plug import PlugForTextGeneration
     from .sbert_for_faq_question_answering import SbertForFaqQuestionAnswering

 else:
@@ -60,6 +61,7 @@ else:
         ],
         'bart_for_text_error_correction': ['BartForTextErrorCorrection'],
         'gpt3': ['GPT3ForTextGeneration'],
+        'plug': ['PlugForTextGeneration'],
         'sbert_for_faq_question_answering': ['SbertForFaqQuestionAnswering'],
     }
@@ -0,0 +1,27 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .configuration_plug import PlugNLGConfig
    from .modeling_plug import PlugModel
    from .distributed_plug import DistributedPlug
    from .plug_for_text_generation import PlugForTextGeneration
else:
    _import_structure = {
        'configuration_plug': ['PlugNLGConfig'],
        'modeling_plug': ['PlugModel'],
        'distributed_plug': ['DistributedPlug'],
        'plug_for_text_generation': ['PlugForTextGeneration'],
    }
    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
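For readers unfamiliar with the lazy-import pattern used here: `LazyImportModule` replaces the package module in `sys.modules` so that every symbol listed in `_import_structure` is imported only on first attribute access. A minimal behavioral sketch (this `LazyModule` is an illustrative stand-in, not the real `LazyImportModule` implementation):

    import importlib
    import types


    class LazyModule(types.ModuleType):
        # Illustrative stand-in for LazyImportModule, not the real code.

        def __init__(self, name, import_structure):
            super().__init__(name)
            # Map each exported symbol to the submodule that defines it.
            self._symbol_to_module = {
                sym: mod
                for mod, syms in import_structure.items() for sym in syms
            }

        def __getattr__(self, item):
            # Import the defining submodule only on first access.
            module = importlib.import_module(
                f'{self.__name__}.{self._symbol_to_module[item]}')
            value = getattr(module, item)
            setattr(self, item, value)  # cache for subsequent lookups
            return value

With this in place, `from modelscope.models.nlp.plug import PlugForTextGeneration` pays the heavy torch/megatron import cost only when the symbol is actually touched.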
@@ -0,0 +1,232 @@
# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors.
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import json

from transformers import PretrainedConfig

from modelscope.utils import logger as logging

logger = logging.get_logger(__name__)


class PlugNLUConfig(PretrainedConfig):
    model_type = 'plugNLU'

    def __init__(self,
                 vocab_size=21504,
                 original_vocab_size=21128,
                 hidden_size=8192,
                 num_hidden_layers=24,
                 num_attention_heads=128,
                 intermediate_size=32768,
                 hidden_act='gelu',
                 hidden_dropout_prob=0.1,
                 attention_probs_dropout_prob=0.1,
                 max_position_embeddings=2048,
                 type_vocab_size=3,
                 initializer_range=0.00707,
                 deep_init=False,
                 deepspeed=False,
                 lr_decay_style='linear',
                 weight_decay=1e-2,
                 clip_grad=1.0,
                 warmup=0.0333,
                 pre_ln=True,
                 fp16=True,
                 fp32_layernorm=True,
                 fp32_embedding=False,
                 fp32_tokentypes=False,
                 layernorm_epsilon=1e-5,
                 dec_hidden_layers=6,
                 pruning_method=None,
                 pruning_mask_init='constant',
                 pruning_mask_scale=0.0,
                 pruning_initial_threshold=1.0,
                 pruning_final_threshold=0.01,
                 pruning_initial_warmup=1,
                 pruning_final_warmup=20,
                 pruning_module='decoder',
                 pruning_decay_step=50,
                 pruning_decay_type='exp',
                 ft_module=None,
                 attn_separate=False,
                 LR_weight_rank=8,
                 LR_mask_rank=8,
                 **kwargs):
        super().__init__(layer_norm_eps=layernorm_epsilon, **kwargs)
        self.vocab_size = vocab_size
        self.original_vocab_size = original_vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.deep_init = deep_init
        self.deepspeed = deepspeed
        self.lr_decay_style = lr_decay_style
        self.weight_decay = weight_decay
        self.clip_grad = clip_grad
        self.warmup = warmup
        self.pre_ln = pre_ln
        self.fp16 = fp16
        self.fp32_layernorm = fp32_layernorm
        self.fp32_embedding = fp32_embedding
        self.layernorm_epsilon = layernorm_epsilon
        self.fp32_tokentypes = fp32_tokentypes
        self.dec_hidden_layers = dec_hidden_layers
        self.pruning_method = pruning_method
        self.pruning_mask_init = pruning_mask_init
        self.pruning_mask_scale = pruning_mask_scale
        self.pruning_module = pruning_module
        self.pruning_initial_threshold = pruning_initial_threshold
        self.pruning_final_threshold = pruning_final_threshold
        self.pruning_initial_warmup = pruning_initial_warmup
        self.pruning_final_warmup = pruning_final_warmup
        self.pruning_decay_step = pruning_decay_step
        self.pruning_decay_type = pruning_decay_type
        self.ft_module = ft_module
        self.attn_separate = attn_separate
        self.LR_weight_rank = LR_weight_rank
        self.LR_mask_rank = LR_mask_rank

    @classmethod
    def from_dict(cls, json_object):
        """Constructs a config from a Python dictionary of parameters."""
        config = cls()
        for key, value in json_object.items():
            config.__dict__[key] = value
        return config

    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a config from a json file of parameters."""
        with open(json_file, 'r', encoding='utf-8') as reader:
            text = reader.read()
        return cls.from_dict(json.loads(text))

    def merge_args(self, args):
        """Merges values from an argument namespace into this config,
        without overriding keys that already exist."""
        local_keys = self.__dict__.keys()
        for key, value in args.__dict__.items():
            if key in local_keys:
                continue
            self.__dict__[key] = value
        return self

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a Python dictionary."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + '\n'

class PlugNLGConfig(PlugNLUConfig):
    model_type = 'plugNLG'

    def __init__(self,
                 vocab_size=21504,
                 hidden_size=768,
                 num_hidden_layers=12,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 hidden_act='gelu',
                 hidden_dropout_prob=0.1,
                 attention_probs_dropout_prob=0.1,
                 max_position_embeddings=512,
                 type_vocab_size=2,
                 initializer_range=0.00707,
                 deep_init=False,
                 deepspeed=False,
                 lr_decay_style='linear',
                 weight_decay=1e-2,
                 clip_grad=1.0,
                 warmup=0.01,
                 pre_ln=False,
                 fp16=False,
                 fp32_layernorm=False,
                 fp32_embedding=False,
                 fp32_tokentypes=False,
                 layernorm_epsilon=1e-12,
                 dec_hidden_layers=6,
                 pruning_method=None,
                 pruning_mask_init='constant',
                 pruning_mask_scale=0.0,
                 pruning_initial_threshold=1.0,
                 pruning_final_threshold=0.01,
                 pruning_initial_warmup=1,
                 pruning_final_warmup=20,
                 pruning_module='decoder',
                 pruning_decay_step=50,
                 pruning_decay_type='exp',
                 ft_module=None,
                 attn_separate=False,
                 LR_weight_rank=8,
                 LR_mask_rank=8,
                 **kwargs):
        super().__init__(layer_norm_eps=layernorm_epsilon, **kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.deep_init = deep_init
        self.deepspeed = deepspeed
        self.lr_decay_style = lr_decay_style
        self.weight_decay = weight_decay
        self.clip_grad = clip_grad
        self.warmup = warmup
        self.pre_ln = pre_ln
        self.fp16 = fp16
        self.fp32_layernorm = fp32_layernorm
        self.fp32_embedding = fp32_embedding
        self.layernorm_epsilon = layernorm_epsilon
        self.fp32_tokentypes = fp32_tokentypes
        self.dec_hidden_layers = dec_hidden_layers
        self.pruning_method = pruning_method
        self.pruning_mask_init = pruning_mask_init
        self.pruning_mask_scale = pruning_mask_scale
        self.pruning_module = pruning_module
        self.pruning_initial_threshold = pruning_initial_threshold
        self.pruning_final_threshold = pruning_final_threshold
        self.pruning_initial_warmup = pruning_initial_warmup
        self.pruning_final_warmup = pruning_final_warmup
        self.pruning_decay_step = pruning_decay_step
        self.pruning_decay_type = pruning_decay_type
        self.ft_module = ft_module
        self.attn_separate = attn_separate
        self.LR_weight_rank = LR_weight_rank
        self.LR_mask_rank = LR_mask_rank
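Since both classes inherit from transformers' `PretrainedConfig` and add the JSON helpers above, a config can be round-tripped through a file. A small usage sketch (the local file name `plug_config.json` is illustrative):

    from modelscope.models.nlp.plug import PlugNLGConfig

    config = PlugNLGConfig(dec_hidden_layers=8)  # override one default
    with open('plug_config.json', 'w', encoding='utf-8') as f:
        f.write(config.to_json_string())

    restored = PlugNLGConfig.from_json_file('plug_config.json')
    assert restored.dec_hidden_layers == 8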
@@ -0,0 +1,191 @@
import os
from typing import Dict

import torch
import torch.nn.functional as F
from megatron import mpu
from megatron.fp16 import FP16_Module
from megatron.utils import print_rank_0

from modelscope.models import TorchModel
from modelscope.models.base import Tensor
from modelscope.utils.logger import get_logger
from modelscope.utils.nlp.distributed import initialize_distributed
from modelscope.utils.nlp.load_checkpoint import pre_load
from modelscope.utils.torch_utils import set_random_seed_mpu
from . import PlugModel
from .configuration_plug import PlugNLGConfig

logger = get_logger(__name__)


class DistributedPlug(TorchModel):

    def __init__(self, model_dir, rank, **kwargs):
        super().__init__(model_dir, **kwargs)
        self.rank = rank
        self.model_cfg = kwargs
        self.config = PlugNLGConfig.from_pretrained(model_dir)
        initialize_distributed(rank, mpu, kwargs['world_size'],
                               kwargs['model_parallel_size'],
                               kwargs['master_ip'], kwargs['master_port'])
        seed = kwargs.get('seed', 0)
        set_random_seed_mpu(seed)
        self.iteration = 0
        self.dist_model = self.initialize_model(path_load_tag='model')

    def initialize_model(self, path_load_tag='model'):
        """Build the model."""
        print_rank_0('Building Plug model. It will take a few minutes ...')
        model = PlugModel(self.config)

        if mpu.get_data_parallel_rank() == 0:
            logger.info(
                ' > number of parameters on model parallel rank {}: {}'.format(
                    mpu.get_model_parallel_rank(),
                    sum([p.nelement() for p in model.parameters()])))

        if self.config.deepspeed and self.config.fp16:
            model.half()

        # GPU allocation.
        model.cuda(torch.cuda.current_device())

        # Fp16 conversion.
        if self.config.fp16:
            model = FP16_Module(model)
            if self.config.fp32_embedding:
                model.module.model.bert.embeddings.word_embeddings.float()
                model.module.model.bert.embeddings.position_embeddings.float()
                model.module.model.bert.embeddings.token_type_embeddings.float()
            if self.config.fp32_tokentypes:
                model.module.model.bert.embeddings.token_type_embeddings.float()
            if self.config.fp32_layernorm:
                for name, _module in model.named_modules():
                    if 'LayerNorm' in name:
                        _module.float()

        load_model = pre_load(mpu, self.model_dir, tag=path_load_tag)
        model_dict = model.module.model.state_dict()
        for key in load_model:
            if key not in model_dict.keys():
                print_rank_0('Skip key: ' + key)
            else:
                print_rank_0('Loading key: ' + key)
        model.module.model.load_state_dict(load_model, strict=False)
        return model

    @staticmethod
    def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
        # This function has been mostly taken from huggingface conversational ai code at
        # https://medium.com/huggingface/how-to-build-a-state-of-the-art-
        # conversational-ai-with-transfer-learning-2d818ac26313
        if top_k > 0:
            # Remove all tokens with a probability less than the
            # last token of the top-k.
            indices_to_remove = logits < torch.topk(logits,
                                                    top_k)[0][..., -1, None]
            logits[indices_to_remove] = filter_value

        if top_p > 0.0:
            # Convert to 1D.
            logits = logits.view(logits.size()[1]).contiguous()
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(
                F.softmax(sorted_logits, dim=-1), dim=-1)

            # Remove tokens with cumulative probability above the threshold.
            sorted_indices_to_remove = cumulative_probs > top_p
            # Shift the indices to the right to keep also the first token
            # above the threshold.
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
                ..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0
            indices_to_remove = sorted_indices[sorted_indices_to_remove]
            logits[indices_to_remove] = filter_value
            # Going back to 2D.
            logits = logits.view(1, -1).contiguous()
        return logits

    def generate(self, input: Dict[str, Tensor], out_length=128, **kwargs):
        device = torch.cuda.current_device()
        batch_size = input['input_ids'].shape[0]
        tokens = input['input_ids'].view(1, -1).contiguous().to(device)
        dec_input_ids = input['dec_input_ids'].to(device)
        attention_mask = input['attention_mask'].to(device)
        self.dist_model.eval()
        with torch.no_grad():
            # Only supports batch_size=1.
            all_generate_tokens = []
            generate_tokens = []
            counter = 0
            sequence_output = None
            vocab_size = self.config.original_vocab_size
            sep_token_idx = 102  # index of [SEP] token in BertTokenizer
            unk_token_idx = 100  # index of [UNK] token in BertTokenizer
            while counter < out_length:
                if counter % 128 == 0 and counter != 0:
                    # Sliding window: fold the tokens generated so far back
                    # into the encoder input and restart decoding.
                    generate_tokens.append(sep_token_idx)
                    start = (tokens == sep_token_idx).nonzero(
                        as_tuple=True)[-1]
                    if start + len(generate_tokens) >= 512:
                        tokens = torch.cat([
                            tokens[:start],
                            torch.cuda.LongTensor(generate_tokens)
                        ], -1)[-512:]
                    else:
                        tokens[0][start:start + len(
                            generate_tokens)] = torch.cuda.LongTensor(
                                generate_tokens)
                    attention_mask = (tokens != 0)
                    dec_input_ids = input['dec_input_ids'].to(device)
                    generate_tokens = []
                    sequence_output = None

                position_ids = torch.full([batch_size, 1],
                                          len(generate_tokens),
                                          dtype=torch.long,
                                          device=device)
                _, logits, sequence_output = self.dist_model(
                    tokens,
                    None,
                    attention_mask,
                    dec_input_ids,
                    attention_mask,
                    position_ids,
                    is_infer=True,
                    sequence_output=sequence_output,
                    parallel_output=False)
                logits = logits[:, -1, :]
                logits = logits / self.model_cfg['temperature']
                logits = self.top_k_logits(
                    logits,
                    top_k=self.model_cfg['top_k'],
                    top_p=self.model_cfg['top_p'])
                log_probs = F.softmax(logits, dim=-1)
                prev = torch.multinomial(log_probs, num_samples=1)
                prev_token = prev[0].item()
                if prev_token >= vocab_size:
                    # Map tokens outside the original vocab to [UNK].
                    prev_token = unk_token_idx
                    prev[0] = unk_token_idx
                if prev_token == sep_token_idx and len(
                        all_generate_tokens) > int(max(1, out_length) * 0.8):
                    break
                if prev_token == sep_token_idx:
                    counter += 1
                    continue
                dec_input_ids = torch.cat([dec_input_ids, prev], dim=1)
                generate_tokens.append(prev_token)
                all_generate_tokens.append(prev_token)
                counter += 1

            generate_context = []
            for token in all_generate_tokens:
                # Collapse runs of consecutive [UNK] tokens.
                if generate_context and generate_context[
                        -1] == unk_token_idx and token == unk_token_idx:
                    continue
                generate_context.append(token)
            return {'generate_context': generate_context}
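The top-k branch of `top_k_logits` masks every logit below the k-th largest to negative infinity before softmax. A toy check of that behavior, runnable with only torch:

    import torch
    import torch.nn.functional as F

    logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])

    # top_k=2: everything below the 2nd-largest logit is masked out.
    filtered = logits.clone()
    filtered[filtered < torch.topk(filtered, 2)[0][..., -1, None]] = -float('Inf')
    probs = F.softmax(filtered, dim=-1)
    assert probs[0, 2] == 0.0 and probs[0, 3] == 0.0  # only two tokens survive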
@@ -1,7 +1,10 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+import os
 import os.path as osp
 from abc import ABC, abstractmethod
+from functools import partial
+from multiprocessing import Pool
 from threading import Lock
 from typing import Any, Dict, Generator, List, Mapping, Union
@@ -15,8 +18,10 @@ from modelscope.utils.config import Config
 from modelscope.utils.constant import Frameworks, ModelFile
 from modelscope.utils.device import (create_device, device_placement,
                                      verify_device)
+from modelscope.utils.hub import read_config, snapshot_download
 from modelscope.utils.import_utils import is_tf_available, is_torch_available
 from modelscope.utils.logger import get_logger
+from modelscope.utils.torch_utils import _find_free_port, _is_free_port
 from .util import is_model, is_official_hub_path

 if is_torch_available():
@@ -302,3 +307,106 @@ class Pipeline(ABC):
            output should have the standard output name.
        """
        raise NotImplementedError('postprocess')

class DistributedPipeline(Pipeline):
    """This pipeline is used to load multi-GPU models.

    What this class does:
    1. Read the global config from configuration.json.
    2. Set the multiprocessing start method to spawn.
    3. Open a multiprocessing pool of size world_size to instantiate the
       model pieces.
    4. Set the master ip and port.
    5. Call _instantiate_one to instantiate one model piece;
       this method should be implemented by the derived class.
    6. When forward is called, do the preprocessing in the main process,
       call _forward_one to collect the results, and do the postprocessing
       in the main process.

    NOTE: _instantiate_one and _forward_one are class methods; any derived
    class should implement them and store the model handler in a class field.
    """

    def __init__(self,
                 model: str = None,
                 preprocessor: Union[Preprocessor, List[Preprocessor]] = None,
                 auto_collate=True,
                 **kwargs):
        self.preprocessor = preprocessor
        self._model_prepare = False
        self._model_prepare_lock = Lock()
        self._auto_collate = auto_collate

        if os.path.exists(model):
            self.model_dir = model
        else:
            self.model_dir = snapshot_download(model)
        self.cfg = read_config(self.model_dir)
        self.world_size = self.cfg.model.world_size
        self.model_pool = None
        self.device_name = 'cpu'
        self.device = create_device(self.device_name)
        self.has_multiple_models = False
        self.framework = self.cfg.framework

        if torch.multiprocessing.get_start_method(allow_none=True) is None:
            torch.multiprocessing.set_start_method('spawn')
        ranks = list(range(self.world_size))
        self.model_pool = Pool(self.world_size)
        master_ip = kwargs.get('master_ip', '127.0.0.1')
        master_port = kwargs.get('master_port', '29500')
        if not _is_free_port(int(master_port)):
            master_port = str(_find_free_port())
        self.model_pool.map(
            partial(
                self.__class__._instantiate_one,
                model_dir=self.model_dir,
                master_ip=master_ip,
                master_port=master_port,
                **self.cfg.model,
                **kwargs), ranks)

    def __del__(self):
        if hasattr(self, 'model_pool') and self.model_pool is not None:
            self.model_pool.terminate()

    def __getstate__(self):
        self_dict = self.__dict__.copy()
        del self_dict['model_pool']
        del self_dict['preprocessor']
        del self_dict['_model_prepare_lock']
        return self_dict

    @classmethod
    def _instantiate_one(cls, rank, model_dir, **kwargs):
        """Instantiate one model piece.

        @param rank: The model rank.
        @param model_dir: The model_dir in the node.
        @param kwargs: Any extra args.
        @return: None. The model handler should be kept in the class field.
        """
        pass

    def forward(self, inputs: Dict[str, Any],
                **forward_params) -> Dict[str, Any]:
        inputs = {
            'inputs': inputs,
            'forward_params': forward_params,
        }
        res = self.model_pool.map(self.__class__._forward_one,
                                  [inputs] * self.world_size)
        return res[0]

    @classmethod
    def _forward_one(cls, inputs):
        """Forward the inputs to one model piece.

        Use the model handler kept in the class field to forward.

        @param inputs: The inputs after the preprocessing.
        @return: The forward results.
        """
        pass
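The pool-based design relies on per-process state: whatever `_instantiate_one` creates must live at class or module scope inside each worker, because only picklable arguments cross the process boundary. The standard-library equivalent of that idea, as a runnable sketch (names are illustrative):

    from multiprocessing import get_context


    def _init(prefix):
        # Per-process state, set once in each worker
        # (stands in for a model shard).
        global _HANDLE
        _HANDLE = prefix


    def _run(payload):
        return f'{_HANDLE}:{payload}'


    if __name__ == '__main__':
        ctx = get_context('spawn')
        with ctx.Pool(4, initializer=_init, initargs=('shard',)) as pool:
            print(pool.map(_run, range(4)))  # ['shard:0', ..., 'shard:3']

DistributedPipeline instead maps `_instantiate_one` over the rank list so that each worker also learns its distributed rank, which a plain `initializer` could not pass per-process.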
@@ -0,0 +1,107 @@
from typing import Any, Dict

import torch

from modelscope.metainfo import Pipelines
from modelscope.models.nlp.plug import DistributedPlug
from modelscope.pipelines.base import DistributedPipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import TextGenerationPreprocessor
from modelscope.utils.constant import Tasks


@PIPELINES.register_module(
    Tasks.text_generation, module_name=Pipelines.plug_generation)
class DistributedPlugPipeline(DistributedPipeline):
    """This class is used to instantiate the plug model."""

    model = None

    def __init__(self,
                 model,
                 preprocessor=None,
                 first_sequence='sentence',
                 **kwargs):
        """Create a PLUG pipeline instance.

        @param model: The model_id of PLUG (damo/nlp_plug_text-generation_27B).
            The default path to damo/nlp_plug_text-generation_27B can be obtained by the function
            get_cache_dir("damo/nlp_plug_text-generation_27B"); the model should be downloaded to
            this path before this class is called with the model_id.
            The model can be downloaded from the link on
            https://modelscope.cn/models/damo/nlp_plug_text-generation_27B/summary.
            After downloading, you should have a plug model structure like this:

            /your/path/to/damo/nlp_plug_text-generation_27B
                |_ config.json
                |_ configuration.json
                |_ ds_zero-offload_10B_config.json
                |_ vocab.txt
                |_ model  <-- an empty directory

            Model binaries shall be downloaded separately to populate the model directory, so that
            the model directory contains the following binaries:

                |_ model
                    |_ mp_rank_00_model_states.pt
                    |_ mp_rank_01_model_states.pt
                    |_ mp_rank_02_model_states.pt
                    |_ mp_rank_03_model_states.pt
                    |_ mp_rank_04_model_states.pt
                    |_ mp_rank_05_model_states.pt
                    |_ mp_rank_06_model_states.pt
                    |_ mp_rank_07_model_states.pt
        @param preprocessor: The optional preprocessor; if not passed in, a TextGenerationPreprocessor
            will be used by default.
        @param first_sequence: The first_sequence key name if the input format is a dict.
        @param kwargs:
            sequence_length: The input sequence_length.
        """
        if preprocessor is None:
            preprocessor = TextGenerationPreprocessor(
                model,
                first_sequence=first_sequence,
                sequence_length=kwargs.pop('sequence_length', 512))
        super().__init__(model, preprocessor=preprocessor, **kwargs)
        assert hasattr(preprocessor, 'tokenizer')
        self.cls_token_id = preprocessor.tokenizer.cls_token_id

    @classmethod
    def _forward_one(cls, inputs: Dict[str, Any]) -> Dict[str, Any]:
        with torch.no_grad():
            return cls.model.generate(inputs['inputs'],
                                      **inputs['forward_params'])

    def _sanitize_parameters(self, **pipeline_parameters):
        return {}, pipeline_parameters, {}

    def forward(self, inputs: Dict[str, Any],
                **forward_params) -> Dict[str, Any]:
        batch_size = inputs['input_ids'].shape[0]
        # Seed the decoder with a single [CLS] token per sample.
        dec_input_ids = torch.full([batch_size, 1],
                                   self.cls_token_id,
                                   dtype=torch.long)
        inputs['dec_input_ids'] = dec_input_ids
        res = super().forward(inputs, **forward_params)
        return res

    @classmethod
    def _instantiate_one(cls, rank, model_dir, **kwargs):
        cls.model = DistributedPlug(model_dir, rank, **kwargs)
        cls.model.eval()

    def postprocess(self, inputs: Dict[str, Any],
                    **postprocess_params) -> Dict[str, str]:
        """Process the prediction results.

        Args:
            inputs (Dict[str, Any]): the forward results, containing the
                generated token ids under the key 'generate_context'.

        Returns:
            Dict[str, str]: the prediction results
        """
        from modelscope.outputs import OutputKeys
        generate_context = inputs['generate_context']
        # Decode ids to text; ids outside the original vocab were mapped to
        # [UNK] during generation, so render them as a quotation mark and
        # strip wordpiece markers.
        generate_context = ''.join(
            self.preprocessor.tokenizer.convert_ids_to_tokens(
                generate_context)).replace('[UNK]', '“').replace('##', '')
        return {OutputKeys.TEXT: generate_context}
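End to end, the registered pipeline is invoked like any other text-generation pipeline. Mirroring the unit test at the bottom of this review, usage looks like this (requires the downloaded 27B checkpoints and enough GPUs for world_size=8):

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    pipe = pipeline(
        Tasks.text_generation, model='damo/nlp_plug_text-generation_27B')
    result = pipe('段誉轻挥折扇,摇了摇头', out_length=256)
    print(result)  # {'text': '...generated continuation...'}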
@@ -164,7 +164,8 @@ class NLPTokenizerPreprocessorBase(Preprocessor):
         """
         model_type = get_model_type(model_dir)
-        if model_type in (Models.structbert, Models.gpt3, Models.palm):
+        if model_type in (Models.structbert, Models.gpt3, Models.palm,
+                          Models.plug):
             from modelscope.models.nlp.structbert import SbertTokenizer
             return SbertTokenizer.from_pretrained(model_dir, use_fast=False)
         elif model_type == Models.veco:
@@ -39,7 +39,8 @@ from modelscope.utils.device import create_device, verify_device
 from modelscope.utils.file_utils import func_receive_dict_inputs
 from modelscope.utils.logger import get_logger
 from modelscope.utils.registry import build_from_cfg
-from modelscope.utils.torch_utils import get_dist_info, init_dist
+from modelscope.utils.torch_utils import (get_dist_info, init_dist,
+                                          set_random_seed)
 from .base import BaseTrainer
 from .builder import TRAINERS
 from .default_config import DEFAULT_CONFIG
@@ -922,6 +923,4 @@ def worker_init_fn(worker_id, num_workers, rank, seed):
     # The seed of each worker equals to
     # num_worker * rank + worker_id + user_seed
     worker_seed = num_workers * rank + worker_id + seed
-    np.random.seed(worker_seed)
-    random.seed(worker_seed)
-    torch.manual_seed(worker_seed)
+    set_random_seed(worker_seed)
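The refactor routes per-worker seeding through the shared `set_random_seed` helper; the seed formula `num_workers * rank + worker_id + seed` stays the same. For context, `worker_init_fn` is typically wired into a `DataLoader` like this (a sketch; `dataset` and the partial arguments are assumptions about the surrounding trainer code):

    from functools import partial

    from torch.utils.data import DataLoader

    init_fn = partial(worker_init_fn, num_workers=4, rank=0, seed=42)
    loader = DataLoader(dataset, num_workers=4, worker_init_fn=init_fn)
    # worker w then seeds itself with 4 * 0 + w + 42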
@@ -0,0 +1,130 @@
# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors.
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import torch
import torch.distributed as dist
from megatron import mpu
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.autograd import Variable
from torch.nn.modules import Module

from modelscope.utils.torch_utils import init_dist

def initialize_distributed(rank, mpu, world_size, model_parallel_size,
                           master_ip, master_port):
    """Initialize torch.distributed."""
    # Manually set the device ids.
    device = rank % torch.cuda.device_count()
    torch.cuda.set_device(device)
    # Call the init process.
    init_method = 'tcp://' + master_ip + ':' + master_port
    torch.distributed.init_process_group(
        backend='nccl',
        world_size=world_size,
        rank=rank,
        init_method=init_method)
    # Set the model-parallel communicators.
    mpu.initialize_model_parallel(model_parallel_size)


def normal_init_method(mean, std):

    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=mean, std=std)

    return init_


def scaled_init_method(mean, std, num_layers):
    """Init method based on N(0, std/sqrt(2*num_layers))."""
    std = std / math.sqrt(2.0 * num_layers)

    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=mean, std=std)

    return init_

class DistributedDataParallel(Module):

    def __init__(self, module):
        super(DistributedDataParallel, self).__init__()
        self.warn_on_half = (
            True if dist._backend == dist.dist_backend.GLOO else False)

        self.module = module
        self.data_parallel_group = mpu.get_data_parallel_group()
        src_rank = mpu.get_model_parallel_rank()
        for p in self.module.parameters():
            if torch.is_tensor(p):
                dist.broadcast(p, src_rank, group=self.data_parallel_group)

        def allreduce_params(reduce_after=True,
                             no_scale=False,
                             fp32_allreduce=False):
            if self.needs_reduction:
                self.needs_reduction = False
                # Bucket gradients by tensor type before flattening.
                buckets = {}
                for name, param in self.module.named_parameters():
                    if param.requires_grad and param.grad is not None:
                        tp = param.data.type()
                        if tp not in buckets:
                            buckets[tp] = []
                        buckets[tp].append(param)
                if self.warn_on_half:
                    if torch.cuda.HalfTensor in buckets:
                        print(
                            'WARNING: gloo dist backend for half parameters may be extremely slow.',
                            'It is recommended to use the NCCL backend in this case.'
                        )
                        self.warn_on_half = False
                for tp in buckets:
                    bucket = buckets[tp]
                    grads = [param.grad.data for param in bucket]
                    coalesced = _flatten_dense_tensors(grads)
                    if fp32_allreduce:
                        coalesced = coalesced.float()
                    if not no_scale and not reduce_after:
                        coalesced /= dist.get_world_size(
                            group=self.data_parallel_group)
                    dist.all_reduce(coalesced, group=self.data_parallel_group)
                    torch.cuda.synchronize()
                    if not no_scale and reduce_after:
                        coalesced /= dist.get_world_size(
                            group=self.data_parallel_group)
                    for buf, synced in zip(
                            grads, _unflatten_dense_tensors(coalesced, grads)):
                        buf.copy_(synced)

        self.hook_handles = []
        self.hooks = []
        for param in list(self.module.parameters()):

            def allreduce_hook(*unused):
                # Queue the all-reduce to run once backward has finished;
                # needs_reduction guards against running it more than once.
                Variable._execution_engine.queue_callback(allreduce_params)

            # Register the hook so allreduce_params actually runs; without
            # this registration the hook defined above would never fire.
            if param.requires_grad:
                self.hook_handles.append(param.register_hook(allreduce_hook))

        self.allreduce_params = allreduce_params

    def forward(self, *inputs, **kwargs):
        self.needs_reduction = True
        return self.module(*inputs, **kwargs)

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        sd = self.module.state_dict(destination, prefix, keep_vars)
        return sd

    def load_state_dict(self, state_dict, strict=True):
        self.module.load_state_dict(state_dict, strict=strict)
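`initialize_distributed` is meant to be called once per spawned rank before the model shard is built. A schematic driver showing that call order (the `run_rank` entry point and the single-node settings are illustrative):

    import torch.multiprocessing as mp
    from megatron import mpu

    from modelscope.utils.nlp.distributed import initialize_distributed


    def run_rank(rank, world_size, model_parallel_size):
        initialize_distributed(rank, mpu, world_size, model_parallel_size,
                               '127.0.0.1', '29500')
        # ... build the model shard for this rank here ...


    if __name__ == '__main__':
        mp.spawn(run_rank, args=(8, 8), nprocs=8)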
@@ -0,0 +1,117 @@
# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import torch


def load_checkpoint(model,
                    load_dir,
                    tag,
                    load_module_strict=True,
                    load_optimizer_states=True,
                    load_lr_scheduler_states=True):
    r"""Load a training checkpoint.

    Arguments:
        model: Required. The model engine to load the checkpoint into.
        load_dir: Required. Directory to load the checkpoint from.
        tag: Required. Checkpoint tag used as a unique identifier for the checkpoint. Ex. Global Step.
        load_module_strict: Optional. Boolean to strictly enforce that the keys in the state_dict of the
            module and the checkpoint match.
        load_optimizer_states: Optional. Boolean to load the training optimizer states from the checkpoint.
            Ex. ADAM's momentum and variance.
        load_lr_scheduler_states: Optional. Boolean to load the learning rate scheduler states from the
            checkpoint.

    Return:
        load_path: Path of the loaded checkpoint. None if loading the checkpoint failed.
        client_state: State dictionary used for loading required training states in the client code.
    """
    load_path, client_states = _load_checkpoint(
        model,
        load_dir,
        tag,
        load_module_strict=load_module_strict,
        load_optimizer_states=load_optimizer_states,
        load_lr_scheduler_states=load_lr_scheduler_states)

    if load_optimizer_states:
        if model.zero_optimization() and load_path is not None:
            model._load_zero_checkpoint(
                load_dir, tag, load_optimizer_states=load_optimizer_states)

    return load_path, client_states

def _get_ckpt_name(mpu, checkpoints_path, tag):
    mp_rank = 0 if mpu is None else mpu.get_model_parallel_rank()
    ckpt_name = os.path.join(
        checkpoints_path, str(tag),
        'mp_rank_{:02d}'.format(mp_rank) + '_model_states.pt')
    return ckpt_name


def pre_load(mpu, load_dir, tag=''):
    load_path = _get_ckpt_name(mpu, load_dir, tag)
    checkpoint = torch.load(
        load_path, map_location=lambda storage, loc: storage)
    return checkpoint['module']


def _load_checkpoint(model,
                     load_dir,
                     tag,
                     load_module_strict=True,
                     load_optimizer_states=True,
                     load_lr_scheduler_states=True):
    load_path = model._get_ckpt_name(load_dir, tag)
    if not os.path.exists(load_path):
        return None, None

    checkpoint = torch.load(
        load_path, map_location=lambda storage, loc: storage)
    model.load_module_state_dict(
        state_dict=checkpoint['module'], strict=load_module_strict)
    if not model.zero_optimization() and load_optimizer_states:
        if model.fp16_enabled():
            model.optimizer.load_state_dict(
                checkpoint['optimizer'],
                load_optimizer_states=load_optimizer_states)
        elif load_optimizer_states:
            model.optimizer.load_state_dict(checkpoint['optimizer'])
    if load_lr_scheduler_states and model.lr_scheduler is not None:
        model.lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
    model.csr_tensor_module_names = checkpoint['csr_tensor_module_names']
    model.global_steps = checkpoint['global_steps']
    model.global_samples = checkpoint.get(
        'global_samples', model.global_steps * model.train_batch_size())
    model.skipped_steps = checkpoint['skipped_steps']
    model.loaded_checkpoint_mp_world_size = checkpoint['mp_world_size']
    model.loaded_checkpoint_dp_world_size = checkpoint['dp_world_size']
    deepspeed_states = [
        'module', 'optimizer', 'lr_scheduler', 'csr_tensor_module_names',
        'skipped_steps', 'global_steps', 'dp_world_size', 'mp_world_size'
    ]
    client_state = {
        key: value
        for key, value in checkpoint.items() if key not in deepspeed_states
    }
    return load_path, client_state
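`pre_load` resolves each model-parallel rank to its own shard file, which is how `DistributedPlug.initialize_model` reads the `mp_rank_XX_model_states.pt` binaries. A usage sketch (assumes mpu has already been initialized and the checkpoint layout documented earlier):

    from megatron import mpu

    from modelscope.utils.nlp.load_checkpoint import pre_load

    state_dict = pre_load(mpu, '/path/to/nlp_plug_text-generation_27B',
                          tag='model')
    # On model-parallel rank 3 this reads .../model/mp_rank_03_model_states.pt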
@@ -3,16 +3,16 @@
 import functools
 import os
 import pickle
+import random
 import socket
 import subprocess
 import tempfile
 from typing import Callable, List, Optional, Tuple

+import numpy as np
 import torch
 import torch.multiprocessing as mp
 from torch import distributed as dist
-from torch._utils import (_flatten_dense_tensors, _take_tensors,
-                          _unflatten_dense_tensors)

 def _find_free_port() -> str:
@@ -49,7 +49,6 @@ def init_dist(launcher: str, backend: str = 'nccl', **kwargs) -> None:

 def _init_dist_pytorch(backend: str, **kwargs) -> None:
     # rank = int(os.environ['RANK'])
     local_rank = int(os.environ['LOCAL_RANK'])
     torch.cuda.set_device(local_rank)
     dist.init_process_group(backend=backend, **kwargs)
@@ -180,3 +179,19 @@ def broadcast(inputs, src):
    dist.broadcast(inputs_tensor, src)
    return pickle.loads(inputs_tensor.cpu().numpy().tobytes())


def set_random_seed(seed):
    if seed is not None and seed >= 0:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
    else:
        raise ValueError(
            f'Random seed should be non-negative, current seed is {seed}')


def set_random_seed_mpu(seed):
    from megatron import mpu
    set_random_seed(seed)
    mpu.model_parallel_cuda_manual_seed(seed)
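`set_random_seed` seeds the Python, NumPy and torch RNGs together, so reseeding reproduces the same stream, while negative or missing seeds are rejected:

    import torch

    from modelscope.utils.torch_utils import set_random_seed

    set_random_seed(42)
    a = torch.rand(3)
    set_random_seed(42)
    assert torch.equal(a, torch.rand(3))  # same seed, same draws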
@@ -1,6 +1,8 @@
+deepspeed
 en_core_web_sm>=2.3.5
 fairseq>=0.10.2
 jieba>=0.42.1
+megatron_util
 pai-easynlp
 # rough-score was just recently updated from 0.0.4 to 0.0.7
 # which introduced compatibility issues that are being investigated
@@ -0,0 +1,49 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks


class TextPlugGenerationTest(unittest.TestCase):

    def setUp(self) -> None:
        # Please make sure this local path exists.
        self.model_id = 'damo/nlp_plug_text-generation_27B'
        self.model_dir = snapshot_download(self.model_id)
        self.plug_input = '段誉轻挥折扇,摇了摇头,说道:“你师父是你的师父,你师父可不是我的师父。"'

    @unittest.skip('distributed plug, skipped')
    def test_plug(self):
        """The model can be downloaded from the link on
        https://modelscope.cn/models/damo/nlp_plug_text-generation_27B/summary.
        After downloading, you should have a plug model structure like this:

        nlp_plug_text-generation_27B
            |_ config.json
            |_ configuration.json
            |_ ds_zero-offload_10B_config.json
            |_ vocab.txt
            |_ model  <-- an empty directory

        Model binaries shall be downloaded separately to populate the model
        directory, so that the model directory contains the following binaries:

            |_ model
                |_ mp_rank_00_model_states.pt
                |_ mp_rank_01_model_states.pt
                |_ mp_rank_02_model_states.pt
                |_ mp_rank_03_model_states.pt
                |_ mp_rank_04_model_states.pt
                |_ mp_rank_05_model_states.pt
                |_ mp_rank_06_model_states.pt
                |_ mp_rank_07_model_states.pt
        """
        # Download model binaries to <model_dir>/model before running this test.
        pipe = pipeline(Tasks.text_generation, model=self.model_id)
        print(
            f'input: {self.plug_input}\noutput: {pipe(self.plug_input, out_length=256)}'
        )


if __name__ == '__main__':
    unittest.main()