Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9931748master
| @@ -55,6 +55,7 @@ class Models(object): | |||
| lcrf = 'lstm-crf' | |||
| bart = 'bart' | |||
| gpt3 = 'gpt3' | |||
| plug = 'plug' | |||
| bert_for_ds = 'bert-for-document-segmentation' | |||
| # audio models | |||
| @@ -172,6 +173,7 @@ class Pipelines(object): | |||
| dialog_state_tracking = 'dialog-state-tracking' | |||
| zero_shot_classification = 'zero-shot-classification' | |||
| text_error_correction = 'text-error-correction' | |||
| plug_generation = 'plug-generation' | |||
| faq_question_answering = 'faq-question-answering' | |||
| conversational_text_to_sql = 'conversational-text-to-sql' | |||
| relation_extraction = 'relation-extraction' | |||
| @@ -28,6 +28,7 @@ if TYPE_CHECKING: | |||
| SingleBackboneTaskModelBase) | |||
| from .bart_for_text_error_correction import BartForTextErrorCorrection | |||
| from .gpt3 import GPT3ForTextGeneration | |||
| from .plug import PlugForTextGeneration | |||
| from .sbert_for_faq_question_answering import SbertForFaqQuestionAnswering | |||
| else: | |||
| @@ -60,6 +61,7 @@ else: | |||
| ], | |||
| 'bart_for_text_error_correction': ['BartForTextErrorCorrection'], | |||
| 'gpt3': ['GPT3ForTextGeneration'], | |||
| 'plug': ['PlugForTextGeneration'], | |||
| 'sbert_for_faq_question_answering': ['SbertForFaqQuestionAnswering'], | |||
| } | |||
| @@ -0,0 +1,27 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from typing import TYPE_CHECKING | |||
| from modelscope.utils.import_utils import LazyImportModule | |||
| if TYPE_CHECKING: | |||
| from .configuration_plug import PlugNLGConfig | |||
| from .modeling_plug import PlugModel | |||
| from .distributed_plug import DistributedPlug | |||
| from .plug_for_text_generation import PlugForTextGeneration | |||
| else: | |||
| _import_structure = { | |||
| 'configuration_plug': ['PlugNLGConfig'], | |||
| 'modeling_plug': ['PlugModel'], | |||
| 'distributed_plug': ['DistributedPlug'], | |||
| 'plug_for_text_generation': ['PlugForTextGeneration'], | |||
| } | |||
| import sys | |||
| sys.modules[__name__] = LazyImportModule( | |||
| __name__, | |||
| globals()['__file__'], | |||
| _import_structure, | |||
| module_spec=__spec__, | |||
| extra_objects={}, | |||
| ) | |||
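Note: with the lazy-import table above, these symbols are only resolved on first attribute access. A minimal sketch of the intended import path, assuming modelscope and transformers are installed (DistributedPlug and PlugForTextGeneration additionally pull in megatron_util, which this review adds to the requirements):

from modelscope.models.nlp.plug import PlugNLGConfig

# resolving the attribute triggers the import of configuration_plug only
print(PlugNLGConfig.model_type)  # 'plugNLG'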
| @@ -0,0 +1,232 @@ | |||
| # Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. | |||
| # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. | |||
| # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| import copy | |||
| import json | |||
| from transformers import PretrainedConfig | |||
| from modelscope.utils import logger as logging | |||
| logger = logging.get_logger(__name__) | |||
| class PlugNLUConfig(PretrainedConfig): | |||
| model_type = 'plugNLU' | |||
| def __init__(self, | |||
| vocab_size=21504, | |||
| original_vocab_size=21128, | |||
| hidden_size=8192, | |||
| num_hidden_layers=24, | |||
| num_attention_heads=128, | |||
| intermediate_size=32768, | |||
| hidden_act='gelu', | |||
| hidden_dropout_prob=0.1, | |||
| attention_probs_dropout_prob=0.1, | |||
| max_position_embeddings=2048, | |||
| type_vocab_size=3, | |||
| initializer_range=0.00707, | |||
| deep_init=False, | |||
| deepspeed=False, | |||
| lr_decay_style='linear', | |||
| weight_decay=1e-2, | |||
| clip_grad=1.0, | |||
| warmup=0.0333, | |||
| pre_ln=True, | |||
| fp16=True, | |||
| fp32_layernorm=True, | |||
| fp32_embedding=False, | |||
| fp32_tokentypes=False, | |||
| layernorm_epsilon=1e-5, | |||
| dec_hidden_layers=6, | |||
| pruning_method=None, | |||
| pruning_mask_init='constant', | |||
| pruning_mask_scale=0.0, | |||
| pruning_initial_threshold=1.0, | |||
| pruning_final_threshold=0.01, | |||
| pruning_initial_warmup=1, | |||
| pruning_final_warmup=20, | |||
| pruning_module='decoder', | |||
| pruning_decay_step=50, | |||
| pruning_decay_type='exp', | |||
| ft_module=None, | |||
| attn_separate=False, | |||
| LR_weight_rank=8, | |||
| LR_mask_rank=8, | |||
| **kwargs): | |||
| super().__init__(layer_norm_eps=layernorm_epsilon, **kwargs) | |||
| self.vocab_size = vocab_size | |||
| self.original_vocab_size = original_vocab_size | |||
| self.hidden_size = hidden_size | |||
| self.num_hidden_layers = num_hidden_layers | |||
| self.num_attention_heads = num_attention_heads | |||
| self.hidden_act = hidden_act | |||
| self.intermediate_size = intermediate_size | |||
| self.hidden_dropout_prob = hidden_dropout_prob | |||
| self.attention_probs_dropout_prob = attention_probs_dropout_prob | |||
| self.max_position_embeddings = max_position_embeddings | |||
| self.type_vocab_size = type_vocab_size | |||
| self.initializer_range = initializer_range | |||
| self.deep_init = deep_init | |||
| self.deepspeed = deepspeed | |||
| self.lr_decay_style = lr_decay_style | |||
| self.weight_decay = weight_decay | |||
| self.clip_grad = clip_grad | |||
| self.warmup = warmup | |||
| self.pre_ln = pre_ln | |||
| self.fp16 = fp16 | |||
| self.fp32_layernorm = fp32_layernorm | |||
| self.fp32_embedding = fp32_embedding | |||
| self.layernorm_epsilon = layernorm_epsilon | |||
| self.fp32_tokentypes = fp32_tokentypes | |||
| self.dec_hidden_layers = dec_hidden_layers | |||
| self.pruning_method = pruning_method | |||
| self.pruning_mask_init = pruning_mask_init | |||
| self.pruning_mask_scale = pruning_mask_scale | |||
| self.pruning_module = pruning_module | |||
| self.pruning_initial_threshold = pruning_initial_threshold | |||
| self.pruning_final_threshold = pruning_final_threshold | |||
| self.pruning_initial_warmup = pruning_initial_warmup | |||
| self.pruning_final_warmup = pruning_final_warmup | |||
| self.pruning_decay_step = pruning_decay_step | |||
| self.pruning_decay_type = pruning_decay_type | |||
| self.ft_module = ft_module | |||
| self.attn_separate = attn_separate | |||
| self.LR_weight_rank = LR_weight_rank | |||
| self.LR_mask_rank = LR_mask_rank | |||
| @classmethod | |||
| def from_dict(cls, json_object): | |||
| """Constructs a `BertConfig` from a Python dictionary of parameters.""" | |||
| config = PlugNLUConfig() | |||
| for key, value in json_object.items(): | |||
| config.__dict__[key] = value | |||
| return config | |||
| @classmethod | |||
| def from_json_file(cls, json_file): | |||
| """Constructs a `BertConfig` from a json file of parameters.""" | |||
| with open(json_file, 'r', encoding='utf-8') as reader: | |||
| text = reader.read() | |||
| return cls.from_dict(json.loads(text)) | |||
| def merge_args(self, args): | |||
| """merge values a `BertConfig` from a json file of parameters.""" | |||
| local_keys = self.__dict__.keys() | |||
| for key, value in args.__dict__.items(): | |||
| if key in local_keys: | |||
| continue | |||
| self.__dict__[key] = value | |||
| return self | |||
| def __repr__(self): | |||
| return str(self.to_json_string()) | |||
| def to_dict(self): | |||
| """Serializes this instance to a Python dictionary.""" | |||
| output = copy.deepcopy(self.__dict__) | |||
| return output | |||
| def to_json_string(self): | |||
| """Serializes this instance to a JSON string.""" | |||
| return json.dumps(self.to_dict(), indent=2, sort_keys=True) + '\n' | |||
| class PlugNLGConfig(PlugNLUConfig): | |||
| model_type = 'plugNLG' | |||
| def __init__(self, | |||
| vocab_size=21504, | |||
| hidden_size=768, | |||
| num_hidden_layers=12, | |||
| num_attention_heads=12, | |||
| intermediate_size=3072, | |||
| hidden_act='gelu', | |||
| hidden_dropout_prob=0.1, | |||
| attention_probs_dropout_prob=0.1, | |||
| max_position_embeddings=512, | |||
| type_vocab_size=2, | |||
| initializer_range=0.00707, | |||
| deep_init=False, | |||
| deepspeed=False, | |||
| lr_decay_style='linear', | |||
| weight_decay=1e-2, | |||
| clip_grad=1.0, | |||
| warmup=0.01, | |||
| pre_ln=False, | |||
| fp16=False, | |||
| fp32_layernorm=False, | |||
| fp32_embedding=False, | |||
| fp32_tokentypes=False, | |||
| layernorm_epsilon=1e-12, | |||
| dec_hidden_layers=6, | |||
| pruning_method=None, | |||
| pruning_mask_init='constant', | |||
| pruning_mask_scale=0.0, | |||
| pruning_initial_threshold=1.0, | |||
| pruning_final_threshold=0.01, | |||
| pruning_initial_warmup=1, | |||
| pruning_final_warmup=20, | |||
| pruning_module='decoder', | |||
| pruning_decay_step=50, | |||
| pruning_decay_type='exp', | |||
| ft_module=None, | |||
| attn_separate=False, | |||
| LR_weight_rank=8, | |||
| LR_mask_rank=8, | |||
| **kwargs): | |||
| # pass layernorm_epsilon through under its declared name so the parent __init__ | |||
| # does not receive a duplicate 'layer_norm_eps' keyword via **kwargs | |||
| super().__init__(layernorm_epsilon=layernorm_epsilon, **kwargs) | |||
| self.vocab_size = vocab_size | |||
| self.hidden_size = hidden_size | |||
| self.num_hidden_layers = num_hidden_layers | |||
| self.num_attention_heads = num_attention_heads | |||
| self.hidden_act = hidden_act | |||
| self.intermediate_size = intermediate_size | |||
| self.hidden_dropout_prob = hidden_dropout_prob | |||
| self.attention_probs_dropout_prob = attention_probs_dropout_prob | |||
| self.max_position_embeddings = max_position_embeddings | |||
| self.type_vocab_size = type_vocab_size | |||
| self.initializer_range = initializer_range | |||
| self.deep_init = deep_init | |||
| self.deepspeed = deepspeed | |||
| self.lr_decay_style = lr_decay_style | |||
| self.weight_decay = weight_decay | |||
| self.clip_grad = clip_grad | |||
| self.warmup = warmup | |||
| self.pre_ln = pre_ln | |||
| self.fp16 = fp16 | |||
| self.fp32_layernorm = fp32_layernorm | |||
| self.fp32_embedding = fp32_embedding | |||
| self.layernorm_epsilon = layernorm_epsilon | |||
| self.fp32_tokentypes = fp32_tokentypes | |||
| self.dec_hidden_layers = dec_hidden_layers | |||
| self.pruning_method = pruning_method | |||
| self.pruning_mask_init = pruning_mask_init | |||
| self.pruning_mask_scale = pruning_mask_scale | |||
| self.pruning_module = pruning_module | |||
| self.pruning_initial_threshold = pruning_initial_threshold | |||
| self.pruning_final_threshold = pruning_final_threshold | |||
| self.pruning_initial_warmup = pruning_initial_warmup | |||
| self.pruning_final_warmup = pruning_final_warmup | |||
| self.pruning_decay_step = pruning_decay_step | |||
| self.pruning_decay_type = pruning_decay_type | |||
| self.ft_module = ft_module | |||
| self.attn_separate = attn_separate | |||
| self.LR_weight_rank = LR_weight_rank | |||
| self.LR_mask_rank = LR_mask_rank | |||
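A small, hedged sketch of the serialization helpers defined above; the file name and values are made up for illustration (the real model ships its config as config.json inside the model directory):

import json
import tempfile

from modelscope.models.nlp.plug import PlugNLGConfig

# write a tiny JSON config and read it back through from_json_file
with tempfile.NamedTemporaryFile('w', suffix='.json', delete=False) as f:
    json.dump({'hidden_size': 1024, 'num_hidden_layers': 2}, f)
    path = f.name

# note: as written, from_dict builds a PlugNLUConfig instance regardless of cls
config = PlugNLGConfig.from_json_file(path)
assert config.hidden_size == 1024
print(config.to_json_string())  # every field as sorted, indented JSON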
| @@ -0,0 +1,191 @@ | |||
| import os | |||
| from typing import Dict | |||
| import torch | |||
| import torch.nn.functional as F | |||
| from megatron import mpu | |||
| from megatron.fp16 import FP16_Module | |||
| from megatron.utils import print_rank_0 | |||
| from modelscope.models import TorchModel | |||
| from modelscope.models.base import Tensor | |||
| from modelscope.utils.logger import get_logger | |||
| from modelscope.utils.nlp.distributed import initialize_distributed | |||
| from modelscope.utils.nlp.load_checkpoint import pre_load | |||
| from modelscope.utils.torch_utils import set_random_seed_mpu | |||
| from . import PlugModel | |||
| from .configuration_plug import PlugNLGConfig | |||
| logger = get_logger(__name__) | |||
| class DistributedPlug(TorchModel): | |||
| def __init__(self, model_dir, rank, **kwargs): | |||
| super().__init__(model_dir, **kwargs) | |||
| self.rank = rank | |||
| self.model_cfg = kwargs | |||
| self.config = PlugNLGConfig.from_pretrained(model_dir) | |||
| initialize_distributed(rank, mpu, kwargs['world_size'], | |||
| kwargs['model_parallel_size'], | |||
| kwargs['master_ip'], kwargs['master_port']) | |||
| seed = 0 if 'seed' not in kwargs else kwargs['seed'] | |||
| set_random_seed_mpu(seed) | |||
| self.iteration = 0 | |||
| self.dist_model = self.initialize_model(path_load_tag='model') | |||
| def initialize_model(self, path_load_tag='model'): | |||
| """Build the model.""" | |||
| print_rank_0('Building Plug model. It will take a few minutes ...') | |||
| model = PlugModel(self.config) | |||
| if mpu.get_data_parallel_rank() == 0: | |||
| logger.info( | |||
| ' > number of parameters on model parallel rank {}: {}'.format( | |||
| mpu.get_model_parallel_rank(), | |||
| sum([p.nelement() for p in model.parameters()]))) | |||
| if self.config.deepspeed and self.config.fp16: | |||
| model.half() | |||
| # GPU allocation. | |||
| model.cuda(torch.cuda.current_device()) | |||
| # Fp16 conversion. | |||
| if self.config.fp16: | |||
| model = FP16_Module(model) | |||
| if self.config.fp32_embedding: | |||
| model.module.model.bert.embeddings.word_embeddings.float() | |||
| model.module.model.bert.embeddings.position_embeddings.float() | |||
| model.module.model.bert.embeddings.token_type_embeddings.float( | |||
| ) | |||
| if self.config.fp32_tokentypes: | |||
| model.module.model.bert.embeddings.token_type_embeddings.float( | |||
| ) | |||
| if self.config.fp32_layernorm: | |||
| for name, _module in model.named_modules(): | |||
| if 'LayerNorm' in name: | |||
| _module.float() | |||
| load_model = pre_load(mpu, self.model_dir, tag=path_load_tag) | |||
| model_dict = model.module.model.state_dict() | |||
| for key in load_model: | |||
| if key not in model_dict.keys(): | |||
| print_rank_0('Skip key: ' + key) | |||
| else: | |||
| print_rank_0('Loading key: ' + key) | |||
| model.module.model.load_state_dict(load_model, strict=False) | |||
| return model | |||
| @staticmethod | |||
| def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')): | |||
| # This function has been mostly taken from huggingface conversational ai code at | |||
| # https://medium.com/huggingface/how-to-build-a-state-of-the-art- | |||
| # conversational-ai-with-transfer-learning-2d818ac26313 | |||
| if top_k > 0: | |||
| # Remove all tokens with a probability less than the last token of the top-k | |||
| indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, | |||
| None] | |||
| logits[indices_to_remove] = filter_value | |||
| if top_p > 0.0: | |||
| # convert to 1D | |||
| logits = logits.view(logits.size()[1]).contiguous() | |||
| sorted_logits, sorted_indices = torch.sort(logits, descending=True) | |||
| cumulative_probs = torch.cumsum( | |||
| F.softmax(sorted_logits, dim=-1), dim=-1) | |||
| # Remove tokens with cumulative probability above the threshold | |||
| sorted_indices_to_remove = cumulative_probs > top_p | |||
| # Shift the indices to the right to keep also the first token above the threshold | |||
| sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[ | |||
| ..., :-1].clone() | |||
| sorted_indices_to_remove[..., 0] = 0 | |||
| indices_to_remove = sorted_indices[sorted_indices_to_remove] | |||
| logits[indices_to_remove] = filter_value | |||
| # going back to 2D | |||
| logits = logits.view(1, -1).contiguous() | |||
| return logits | |||
| def generate(self, input: Dict[str, Tensor], out_length=128, **kwargs): | |||
| device = torch.cuda.current_device() | |||
| batch_size = input['input_ids'].shape[0] | |||
| tokens = input['input_ids'].view(1, -1).contiguous().to(device) | |||
| dec_input_ids = input['dec_input_ids'].to(device) | |||
| attention_mask = input['attention_mask'].to(device) | |||
| self.dist_model.eval() | |||
| with torch.no_grad(): | |||
| # Only supports batch_size=1 | |||
| all_generate_tokens = [] | |||
| generate_tokens = [] | |||
| counter = 0 | |||
| sequence_output = None | |||
| vocab_size = self.config.original_vocab_size | |||
| sep_token_idx = 102 # index of [SEP] token in BertTokenizer | |||
| while counter < out_length: | |||
| if counter % 128 == 0 and counter != 0: | |||
| # Sliding window | |||
| generate_tokens.append(sep_token_idx) | |||
| start = (tokens == sep_token_idx).nonzero( | |||
| as_tuple=True)[-1] | |||
| if start + len(generate_tokens) >= 512: | |||
| tokens = torch.cat([ | |||
| tokens[:start], | |||
| torch.cuda.LongTensor(generate_tokens) | |||
| ], -1)[-512:] | |||
| else: | |||
| tokens[0][start:start + len(generate_tokens | |||
| )] = torch.cuda.LongTensor( | |||
| generate_tokens) | |||
| attention_mask = (tokens != 0) | |||
| dec_input_ids = input['dec_input_ids'].to(device) | |||
| generate_tokens = [] | |||
| sequence_output = None | |||
| position_ids = torch.full([batch_size, 1], | |||
| len(generate_tokens), | |||
| dtype=torch.long, | |||
| device=device) | |||
| _, logits, sequence_output = self.dist_model( | |||
| tokens, | |||
| None, | |||
| attention_mask, | |||
| dec_input_ids, | |||
| attention_mask, | |||
| position_ids, | |||
| is_infer=True, | |||
| sequence_output=sequence_output, | |||
| parallel_output=False) | |||
| logits = logits[:, -1, :] | |||
| logits = logits / self.model_cfg['temperature'] | |||
| logits = self.top_k_logits( | |||
| logits, | |||
| top_k=self.model_cfg['top_k'], | |||
| top_p=self.model_cfg['top_p']) | |||
| log_probs = F.softmax(logits, dim=-1) | |||
| prev = torch.multinomial(log_probs, num_samples=1) | |||
| prev_token = prev[0].item() | |||
| if prev_token >= vocab_size:  # out-of-vocab id: fall back to [UNK] (100) | |||
| prev_token = 100 | |||
| prev[0] = 100 | |||
| if prev_token == 102 and len(all_generate_tokens) > int( | |||
| max(1, out_length) * 0.8): | |||
| break | |||
| if prev_token == 102: | |||
| counter += 1 | |||
| continue | |||
| dec_input_ids = torch.cat([dec_input_ids, prev], dim=1) | |||
| generate_tokens.append(prev_token) | |||
| all_generate_tokens.append(prev_token) | |||
| counter += 1 | |||
| generate_context = [] | |||
| for token in all_generate_tokens: | |||
| if generate_context and generate_context[ | |||
| -1] == 100 and token == 100: | |||
| continue | |||
| else: | |||
| generate_context.append(token) | |||
| return {'generate_context': generate_context} | |||
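To make the sampling step above concrete, a small sketch of what the top-k filter does to a logits row; the values are made up, and since top_k_logits is a @staticmethod no model instance or distributed setup is needed (importing DistributedPlug does require megatron_util):

import torch

from modelscope.models.nlp.plug import DistributedPlug

logits = torch.tensor([[4.0, 3.0, 1.0, 0.5]])
filtered = DistributedPlug.top_k_logits(logits.clone(), top_k=2)
# everything outside the two largest logits is set to -inf, so the subsequent
# multinomial sampling can only pick one of those two token ids
print(filtered)  # tensor([[4., 3., -inf, -inf]])

With top_p > 0 the same method additionally reshapes the row to 1D and masks the tail of the cumulative softmax, consistent with the batch_size=1 restriction noted in generate.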
| @@ -1,7 +1,10 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| import os.path as osp | |||
| from abc import ABC, abstractmethod | |||
| from functools import partial | |||
| from multiprocessing import Pool | |||
| from threading import Lock | |||
| from typing import Any, Dict, Generator, List, Mapping, Union | |||
| @@ -15,8 +18,10 @@ from modelscope.utils.config import Config | |||
| from modelscope.utils.constant import Frameworks, ModelFile | |||
| from modelscope.utils.device import (create_device, device_placement, | |||
| verify_device) | |||
| from modelscope.utils.hub import read_config, snapshot_download | |||
| from modelscope.utils.import_utils import is_tf_available, is_torch_available | |||
| from modelscope.utils.logger import get_logger | |||
| from modelscope.utils.torch_utils import _find_free_port, _is_free_port | |||
| from .util import is_model, is_official_hub_path | |||
| if is_torch_available(): | |||
| @@ -302,3 +307,106 @@ class Pipeline(ABC): | |||
| output should have the standard output name. | |||
| """ | |||
| raise NotImplementedError('postprocess') | |||
| class DistributedPipeline(Pipeline): | |||
| """This pipeline is used to load multi gpu models. | |||
| What will this class do: | |||
| 1. Read the global config from the configuration.json | |||
| 2. Set the multiprocessing method to spawn | |||
| 3. Open a multiprocessing pool of the world_size to instantiate model pieces. | |||
| 4. Set the master port and ip | |||
| 5. Call _instantiate_one to instantiate one model piece | |||
| This method should be implemented by the derived class. | |||
| 6. After the forward method is called, do preprocess in main process | |||
| and call _forward_one to collect results, and do | |||
| post process in main process. | |||
| NOTE: _instantiate_one and _forward_one are class methods, any derived class should implement them and | |||
| store the model handler in the class field. | |||
| """ | |||
| def __init__(self, | |||
| model: str = None, | |||
| preprocessor: Union[Preprocessor, List[Preprocessor]] = None, | |||
| auto_collate=True, | |||
| **kwargs): | |||
| self.preprocessor = preprocessor | |||
| self._model_prepare = False | |||
| self._model_prepare_lock = Lock() | |||
| self._auto_collate = auto_collate | |||
| if os.path.exists(model): | |||
| self.model_dir = model | |||
| else: | |||
| self.model_dir = snapshot_download(model) | |||
| self.cfg = read_config(self.model_dir) | |||
| self.world_size = self.cfg.model.world_size | |||
| self.model_pool = None | |||
| self.device_name = 'cpu' | |||
| self.device = create_device(self.device_name) | |||
| self.has_multiple_models = False | |||
| self.framework = self.cfg.framework | |||
| if torch.multiprocessing.get_start_method(allow_none=True) is None: | |||
| torch.multiprocessing.set_start_method('spawn') | |||
| ranks = list(range(self.world_size)) | |||
| self.model_pool = Pool(self.world_size) | |||
| master_ip = '127.0.0.1' if 'master_ip' not in kwargs else kwargs[ | |||
| 'master_ip'] | |||
| master_port = '29500' if 'master_port' not in kwargs else kwargs[ | |||
| 'master_port'] | |||
| if not _is_free_port(int(master_port)): | |||
| master_port = str(_find_free_port()) | |||
| self.model_pool.map( | |||
| partial( | |||
| self.__class__._instantiate_one, | |||
| model_dir=self.model_dir, | |||
| master_ip=master_ip, | |||
| master_port=master_port, | |||
| **self.cfg.model, | |||
| **kwargs), ranks) | |||
| def __del__(self): | |||
| if hasattr(self, 'model_pool') and self.model_pool is not None: | |||
| self.model_pool.terminate() | |||
| def __getstate__(self): | |||
| self_dict = self.__dict__.copy() | |||
| del self_dict['model_pool'] | |||
| del self_dict['preprocessor'] | |||
| del self_dict['_model_prepare_lock'] | |||
| return self_dict | |||
| @classmethod | |||
| def _instantiate_one(cls, rank, model_dir, **kwargs): | |||
| """Instantiate one model piece. | |||
| @param rank: The model rank. | |||
| @param model_dir: The model_dir in the node. | |||
| @param kwargs: Any extra args. | |||
| @return: None. The model handler should be kept in the class field. | |||
| """ | |||
| pass | |||
| def forward(self, inputs: Dict[str, Any], | |||
| **forward_params) -> Dict[str, Any]: | |||
| inputs = { | |||
| 'inputs': inputs, | |||
| 'forward_params': forward_params, | |||
| } | |||
| res = self.model_pool.map(self.__class__._forward_one, | |||
| [inputs] * self.world_size) | |||
| return res[0] | |||
| @classmethod | |||
| def _forward_one(cls, inputs): | |||
| """Forward the inputs to one model piece. | |||
| Use the model handler kept in the class field to forward. | |||
| @param inputs: The inputs after the preprocessing. | |||
| @return: The forward results. | |||
| """ | |||
| pass | |||
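For reference, a minimal sketch of the contract described in the class docstring; the subclass name and build_model_piece are hypothetical (DistributedPlugPipeline later in this review is the concrete implementation):

from typing import Any, Dict

from modelscope.pipelines.base import DistributedPipeline


class MyShardedPipeline(DistributedPipeline):
    """Hypothetical subclass, for illustration only."""

    model = None  # each pool worker keeps its own model piece in this class field

    @classmethod
    def _instantiate_one(cls, rank, model_dir, **kwargs):
        # runs once in every worker process: build this rank's model piece
        cls.model = build_model_piece(model_dir, rank, **kwargs)  # hypothetical builder
        cls.model.eval()

    @classmethod
    def _forward_one(cls, inputs: Dict[str, Any]) -> Dict[str, Any]:
        # runs in every worker; the base class fans the preprocessed inputs out via the pool
        return cls.model.generate(inputs['inputs'], **inputs['forward_params'])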
| @@ -0,0 +1,107 @@ | |||
| from typing import Any, Dict | |||
| import torch | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.models.nlp.plug import DistributedPlug | |||
| from modelscope.pipelines.base import DistributedPipeline | |||
| from modelscope.pipelines.builder import PIPELINES | |||
| from modelscope.preprocessors import TextGenerationPreprocessor | |||
| from modelscope.utils.constant import Tasks | |||
| @PIPELINES.register_module( | |||
| Tasks.text_generation, module_name=Pipelines.plug_generation) | |||
| class DistributedPlugPipeline(DistributedPipeline): | |||
| """This class is used to instantiate the plug model. | |||
| """ | |||
| model = None | |||
| def __init__(self, | |||
| model, | |||
| preprocessor=None, | |||
| first_sequence='sentence', | |||
| **kwargs): | |||
| """Create a plug pipeline instance. | |||
| @param model: The model_id of plug (damo/nlp_plug_text-generation_27B). | |||
| The default path for damo/nlp_plug_text-generation_27B can be obtained via | |||
| get_cache_dir("damo/nlp_plug_text-generation_27B"); the model should be downloaded to | |||
| this path before this class is instantiated by model_id. | |||
| The model can be downloaded from the link on | |||
| https://modelscope.cn/models/damo/nlp_plug_text-generation_27B/summary. | |||
| After downloading, you should have a plug model structure like this: | |||
| /your/path/to/damo/nlp_plug_text-generation_27B | |||
| |_ config.json | |||
| |_ configuration.json | |||
| |_ ds_zero-offload_10B_config.json | |||
| |_ vocab.txt | |||
| |_ model <-- an empty directory | |||
| Model binaries shall be downloaded separately to populate the model directory, so that | |||
| the model directory would contain the following binaries: | |||
| |_ model | |||
| |_ mp_rank_00_model_states.pt | |||
| |_ mp_rank_01_model_states.pt | |||
| |_ mp_rank_02_model_states.pt | |||
| |_ mp_rank_03_model_states.pt | |||
| |_ mp_rank_04_model_states.pt | |||
| |_ mp_rank_05_model_states.pt | |||
| |_ mp_rank_06_model_states.pt | |||
| |_ mp_rank_07_model_states.pt | |||
| @param preprocessor: The optional preprocessor. If not passed in, a TextGenerationPreprocessor will | |||
| be used by default. | |||
| @param first_sequence: The first_sequence key name if the input format is a dict. | |||
| @param kwargs: | |||
| sequence_length: The input sequence_length. | |||
| """ | |||
| if preprocessor is None: | |||
| preprocessor = TextGenerationPreprocessor( | |||
| model, | |||
| first_sequence=first_sequence, | |||
| sequence_length=kwargs.pop('sequence_length', 512)) | |||
| super().__init__(model, preprocessor=preprocessor, **kwargs) | |||
| assert hasattr(preprocessor, 'tokenizer') | |||
| self.cls_token_id = preprocessor.tokenizer.cls_token_id | |||
| @classmethod | |||
| def _forward_one(cls, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
| with torch.no_grad(): | |||
| return cls.model.generate(inputs['inputs'], | |||
| **inputs['forward_params']) | |||
| def _sanitize_parameters(self, **pipeline_parameters): | |||
| return {}, pipeline_parameters, {} | |||
| def forward(self, inputs: Dict[str, Any], | |||
| **forward_params) -> Dict[str, Any]: | |||
| batch_size = inputs['input_ids'].shape[0] | |||
| dec_input_ids = torch.full([batch_size, 1], | |||
| self.cls_token_id, | |||
| dtype=torch.long) | |||
| inputs['dec_input_ids'] = dec_input_ids | |||
| res = super().forward(inputs, **forward_params) | |||
| return res | |||
| @classmethod | |||
| def _instantiate_one(cls, rank, model_dir, **kwargs): | |||
| cls.model = DistributedPlug(model_dir, rank, **kwargs) | |||
| cls.model.eval() | |||
| def postprocess(self, inputs: Dict[str, Any], | |||
| **postprocess_params) -> Dict[str, str]: | |||
| """process the prediction results | |||
| Args: | |||
| inputs (Dict[str, Any]): _description_ | |||
| Returns: | |||
| Dict[str, str]: the prediction results | |||
| """ | |||
| from modelscope.outputs import OutputKeys | |||
| generate_context = inputs['generate_context'] | |||
| generate_context = ''.join( | |||
| self.preprocessor.tokenizer.convert_ids_to_tokens( | |||
| generate_context)).replace('[UNK]', '“').replace('##', '') | |||
| return {OutputKeys.TEXT: generate_context} | |||
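A hedged end-to-end sketch of how this pipeline is meant to be driven; it mirrors the new unit test at the end of this review and assumes the eight mp_rank_0X_model_states.pt shards are already in place under <model_dir>/model, with enough GPUs for the configured world_size:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

pipe = pipeline(Tasks.text_generation, model='damo/nlp_plug_text-generation_27B')

# out_length is passed through forward_params down to DistributedPlug.generate
input_text = '段誉轻挥折扇，摇了摇头，说道：“你师父是你的师父，你师父可不是我的师父。"'
print(pipe(input_text, out_length=256))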
| @@ -164,7 +164,8 @@ class NLPTokenizerPreprocessorBase(Preprocessor): | |||
| """ | |||
| model_type = get_model_type(model_dir) | |||
| if model_type in (Models.structbert, Models.gpt3, Models.palm): | |||
| if model_type in (Models.structbert, Models.gpt3, Models.palm, | |||
| Models.plug): | |||
| from modelscope.models.nlp.structbert import SbertTokenizer | |||
| return SbertTokenizer.from_pretrained(model_dir, use_fast=False) | |||
| elif model_type == Models.veco: | |||
| @@ -39,7 +39,8 @@ from modelscope.utils.device import create_device, verify_device | |||
| from modelscope.utils.file_utils import func_receive_dict_inputs | |||
| from modelscope.utils.logger import get_logger | |||
| from modelscope.utils.registry import build_from_cfg | |||
| from modelscope.utils.torch_utils import get_dist_info, init_dist | |||
| from modelscope.utils.torch_utils import (get_dist_info, init_dist, | |||
| set_random_seed) | |||
| from .base import BaseTrainer | |||
| from .builder import TRAINERS | |||
| from .default_config import DEFAULT_CONFIG | |||
| @@ -922,6 +923,4 @@ def worker_init_fn(worker_id, num_workers, rank, seed): | |||
| # The seed of each worker equals to | |||
| # num_worker * rank + worker_id + user_seed | |||
| worker_seed = num_workers * rank + worker_id + seed | |||
| np.random.seed(worker_seed) | |||
| random.seed(worker_seed) | |||
| torch.manual_seed(worker_seed) | |||
| set_random_seed(worker_seed) | |||
| @@ -0,0 +1,130 @@ | |||
| # Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. | |||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| import math | |||
| import torch | |||
| import torch.distributed as dist | |||
| from megatron import mpu | |||
| from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors | |||
| from torch.autograd import Variable | |||
| from torch.nn.modules import Module | |||
| from modelscope.utils.torch_utils import init_dist | |||
| def initialize_distributed(rank, mpu, world_size, model_parallel_size, | |||
| master_ip, master_port): | |||
| """Initialize torch.distributed.""" | |||
| # Manually set the device ids. | |||
| device = rank % torch.cuda.device_count() | |||
| torch.cuda.set_device(device) | |||
| # Call the init process | |||
| init_method = 'tcp://' | |||
| init_method += master_ip + ':' + master_port | |||
| torch.distributed.init_process_group( | |||
| backend='nccl', world_size=world_size, rank=rank, init_method=init_method) | |||
| # Set the model-parallel communicators. | |||
| mpu.initialize_model_parallel(model_parallel_size) | |||
| def normal_init_method(mean, std): | |||
| def init_(tensor): | |||
| return torch.nn.init.normal_(tensor, mean=mean, std=std) | |||
| return init_ | |||
| def scaled_init_method(mean, std, num_layers): | |||
| """Init method based on N(0, sigma/sqrt(2*num_layers).""" | |||
| std = std / math.sqrt(2.0 * num_layers) | |||
| def init_(tensor): | |||
| return torch.nn.init.normal_(tensor, mean=mean, std=std) | |||
| return init_ | |||
| class DistributedDataParallel(Module): | |||
| def __init__(self, module): | |||
| super(DistributedDataParallel, self).__init__() | |||
| self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False | |||
| self.module = module | |||
| self.data_parallel_group = mpu.get_data_parallel_group() | |||
| src_rank = mpu.get_model_parallel_rank() | |||
| for p in self.module.parameters(): | |||
| if torch.is_tensor(p): | |||
| dist.broadcast(p, src_rank, group=self.data_parallel_group) | |||
| def allreduce_params(reduce_after=True, | |||
| no_scale=False, | |||
| fp32_allreduce=False): | |||
| if (self.needs_reduction): | |||
| self.needs_reduction = False | |||
| buckets = {} | |||
| for name, param in self.module.named_parameters(): | |||
| if param.requires_grad and param.grad is not None: | |||
| tp = (param.data.type()) | |||
| if tp not in buckets: | |||
| buckets[tp] = [] | |||
| buckets[tp].append(param) | |||
| if self.warn_on_half: | |||
| if torch.cuda.HalfTensor in buckets: | |||
| print( | |||
| 'WARNING: gloo dist backend for half parameters may be extremely slow.', | |||
| 'It is recommended to use the NCCL backend in this case.' | |||
| ) | |||
| self.warn_on_half = False | |||
| for tp in buckets: | |||
| bucket = buckets[tp] | |||
| grads = [param.grad.data for param in bucket] | |||
| coalesced = _flatten_dense_tensors(grads) | |||
| if fp32_allreduce: | |||
| coalesced = coalesced.float() | |||
| if not no_scale and not reduce_after: | |||
| coalesced /= dist.get_world_size( | |||
| group=self.data_parallel_group) | |||
| dist.all_reduce(coalesced, group=self.data_parallel_group) | |||
| torch.cuda.synchronize() | |||
| if not no_scale and reduce_after: | |||
| coalesced /= dist.get_world_size( | |||
| group=self.data_parallel_group) | |||
| for buf, synced in zip( | |||
| grads, _unflatten_dense_tensors(coalesced, grads)): | |||
| buf.copy_(synced) | |||
| self.hook_handles = [] | |||
| self.hooks = [] | |||
| for param in list(self.module.parameters()): | |||
| def allreduce_hook(*unused): | |||
| Variable._execution_engine.queue_callback(allreduce_params) | |||
| self.allreduce_params = allreduce_params | |||
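| # The hook above is defined but not registered on the parameters; gradients are | |||
| # presumably reduced by calling self.allreduce_params() explicitly after backward. | |||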
| def forward(self, *inputs, **kwargs): | |||
| self.needs_reduction = True | |||
| return self.module(*inputs, **kwargs) | |||
| def state_dict(self, destination=None, prefix='', keep_vars=False): | |||
| sd = self.module.state_dict(destination, prefix, keep_vars) | |||
| return sd | |||
| def load_state_dict(self, state_dict, strict=True): | |||
| self.module.load_state_dict(state_dict, strict=strict) | |||
| @@ -0,0 +1,117 @@ | |||
| # Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. | |||
| # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| import os | |||
| import torch | |||
| def load_checkpoint(model, | |||
| load_dir, | |||
| tag, | |||
| load_module_strict=True, | |||
| load_optimizer_states=True, | |||
| load_lr_scheduler_states=True): | |||
| r"""Load training checkpoint | |||
| Arguments: | |||
| load_dir: Required. Directory to load the checkpoint from | |||
| tag: Required. Checkpoint tag used as a unique identifier for the checkpoint. Ex. Global Step. | |||
| load_module_strict: Optional. Boolean to strictly enforce that the keys in state_dict of module and | |||
| checkpoint match. | |||
| load_optimizer_states: Optional. Boolean to load the training optimizer states from Checkpoint. | |||
| Ex. ADAM's momentum and variance | |||
| load_lr_scheduler_states: Optional. Boolean to add the learning rate scheduler states from Checkpoint. | |||
| Return: | |||
| load_path: Path of the loaded checkpoint. None if loading the checkpoint failed | |||
| client_state: State dictionary used for loading required training states in the client code. | |||
| """ | |||
| load_path, client_states = _load_checkpoint( | |||
| model, | |||
| load_dir, | |||
| tag, | |||
| load_module_strict=load_module_strict, | |||
| load_optimizer_states=load_optimizer_states, | |||
| load_lr_scheduler_states=load_lr_scheduler_states) | |||
| if load_optimizer_states: | |||
| if model.zero_optimization() and load_path is not None: | |||
| model._load_zero_checkpoint( | |||
| load_dir, tag, load_optimizer_states=load_optimizer_states) | |||
| return load_path, client_states | |||
| def _get_ckpt_name(mpu, checkpoints_path, tag): | |||
| mp_rank = 0 if mpu is None else mpu.get_model_parallel_rank() | |||
| ckpt_name = os.path.join( | |||
| checkpoints_path, str(tag), | |||
| 'mp_rank_{:02d}'.format(mp_rank) + '_model_states.pt') | |||
| return ckpt_name | |||
| def pre_load(mpu, load_dir, tag=''): | |||
| load_path = _get_ckpt_name(mpu, load_dir, tag) | |||
| checkpoint = torch.load( | |||
| load_path, map_location=lambda storage, loc: storage) | |||
| return checkpoint['module'] | |||
| def _load_checkpoint(model, | |||
| load_dir, | |||
| tag, | |||
| load_module_strict=True, | |||
| load_optimizer_states=True, | |||
| load_lr_scheduler_states=True): | |||
| load_path = model._get_ckpt_name(load_dir, tag) | |||
| if not os.path.exists(load_path): | |||
| return None, None | |||
| checkpoint = torch.load( | |||
| load_path, map_location=lambda storage, loc: storage) | |||
| model.load_module_state_dict( | |||
| state_dict=checkpoint['module'], strict=load_module_strict) | |||
| if not model.zero_optimization() and load_optimizer_states: | |||
| if model.fp16_enabled(): | |||
| model.optimizer.load_state_dict( | |||
| checkpoint['optimizer'], | |||
| load_optimizer_states=load_optimizer_states) | |||
| elif load_optimizer_states: | |||
| model.optimizer.load_state_dict(checkpoint['optimizer']) | |||
| if load_lr_scheduler_states and model.lr_scheduler is not None: | |||
| model.lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) | |||
| model.csr_tensor_module_names = checkpoint['csr_tensor_module_names'] | |||
| model.global_steps = checkpoint['global_steps'] | |||
| model.global_samples = checkpoint.get( | |||
| 'global_samples', model.global_steps * model.train_batch_size()) | |||
| model.skipped_steps = checkpoint['skipped_steps'] | |||
| model.loaded_checkpoint_mp_world_size = checkpoint['mp_world_size'] | |||
| model.loaded_checkpoint_dp_world_size = checkpoint['dp_world_size'] | |||
| deepspeed_states = [ | |||
| 'module', 'optimizer', 'lr_scheduler', 'csr_tensor_module_names', | |||
| 'skipped_steps', 'global_steps', 'dp_world_size', 'mp_world_size' | |||
| ] | |||
| client_state = { | |||
| key: value | |||
| for key, value in checkpoint.items() if key not in deepspeed_states | |||
| } | |||
| return load_path, client_state | |||
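A small sketch of how the checkpoint naming above lines up with the mp_rank_0X_model_states.pt shards listed in the pipeline docstring; _get_ckpt_name is an internal helper and the path is illustrative:

from modelscope.utils.nlp.load_checkpoint import _get_ckpt_name

# with mpu=None the helper falls back to model-parallel rank 0 and resolves
# <load_dir>/<tag>/mp_rank_00_model_states.pt
print(_get_ckpt_name(None, '/path/to/nlp_plug_text-generation_27B', 'model'))

# DistributedPlug calls pre_load(mpu, model_dir, tag='model') in every worker,
# so each model-parallel rank loads the 'module' state dict of its own shard.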
| @@ -3,16 +3,16 @@ | |||
| import functools | |||
| import os | |||
| import pickle | |||
| import random | |||
| import socket | |||
| import subprocess | |||
| import tempfile | |||
| from typing import Callable, List, Optional, Tuple | |||
| import numpy as np | |||
| import torch | |||
| import torch.multiprocessing as mp | |||
| from torch import distributed as dist | |||
| from torch._utils import (_flatten_dense_tensors, _take_tensors, | |||
| _unflatten_dense_tensors) | |||
| def _find_free_port() -> str: | |||
| @@ -49,7 +49,6 @@ def init_dist(launcher: str, backend: str = 'nccl', **kwargs) -> None: | |||
| def _init_dist_pytorch(backend: str, **kwargs) -> None: | |||
| # rank = int(os.environ['RANK']) | |||
| local_rank = int(os.environ['LOCAL_RANK']) | |||
| torch.cuda.set_device(local_rank) | |||
| dist.init_process_group(backend=backend, **kwargs) | |||
| @@ -180,3 +179,19 @@ def broadcast(inputs, src): | |||
| dist.broadcast(inputs_tensor, src) | |||
| return pickle.loads(inputs_tensor.cpu().numpy().tobytes()) | |||
| def set_random_seed(seed): | |||
| if seed is not None and seed >= 0: | |||
| random.seed(seed) | |||
| np.random.seed(seed) | |||
| torch.manual_seed(seed) | |||
| else: | |||
| raise ValueError( | |||
| f'Random seed should be a non-negative integer, current seed is {seed}') | |||
| def set_random_seed_mpu(seed): | |||
| from megatron import mpu | |||
| set_random_seed(seed) | |||
| mpu.model_parallel_cuda_manual_seed(seed) | |||
| @@ -1,6 +1,8 @@ | |||
| deepspeed | |||
| en_core_web_sm>=2.3.5 | |||
| fairseq>=0.10.2 | |||
| jieba>=0.42.1 | |||
| megatron_util | |||
| pai-easynlp | |||
| # rough-score was just recently updated from 0.0.4 to 0.0.7 | |||
| # which introduced compatibility issues that are being investigated | |||
| @@ -0,0 +1,49 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import unittest | |||
| from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.utils.constant import Tasks | |||
| class TextPlugGenerationTest(unittest.TestCase): | |||
| def setUp(self) -> None: | |||
| # please make sure this local path exists. | |||
| self.model_id = 'damo/nlp_plug_text-generation_27B' | |||
| self.model_dir = snapshot_download(self.model_id) | |||
| self.plug_input = '段誉轻挥折扇,摇了摇头,说道:“你师父是你的师父,你师父可不是我的师父。"' | |||
| @unittest.skip('distributed plug, skipped') | |||
| def test_plug(self): | |||
| """ The model can be downloaded from the link on | |||
| https://modelscope.cn/models/damo/nlp_plug_text-generation_27B/summary. | |||
| After downloading, you should have a plug model structure like this: | |||
| nlp_plug_text-generation_27B | |||
| |_ config.json | |||
| |_ configuration.json | |||
| |_ ds_zero-offload_10B_config.json | |||
| |_ vocab.txt | |||
| |_ model <-- an empty directory | |||
| Model binaries shall be downloaded separately to populate the model directory, so that | |||
| the model directory would contain the following binaries: | |||
| |_ model | |||
| |_ mp_rank_00_model_states.pt | |||
| |_ mp_rank_01_model_states.pt | |||
| |_ mp_rank_02_model_states.pt | |||
| |_ mp_rank_03_model_states.pt | |||
| |_ mp_rank_04_model_states.pt | |||
| |_ mp_rank_05_model_states.pt | |||
| |_ mp_rank_06_model_states.pt | |||
| |_ mp_rank_07_model_states.pt | |||
| """ | |||
| # download model binaries to <model_dir>/model | |||
| pipe = pipeline(Tasks.text_generation, model=self.model_id) | |||
| print( | |||
| f'input: {self.plug_input}\noutput: {pipe(self.plug_input, out_length=256)}' | |||
| ) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||