Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10836131master^2
| @@ -80,6 +80,7 @@ class Models(object): | |||
| gcnncrf = 'gcnn-crf' | |||
| bart = 'bart' | |||
| gpt3 = 'gpt3' | |||
| gpt_moe = 'gpt-moe' | |||
| gpt_neo = 'gpt-neo' | |||
| plug = 'plug' | |||
| bert_for_ds = 'bert-for-document-segmentation' | |||
| @@ -255,6 +256,7 @@ class Pipelines(object): | |||
| text_error_correction = 'text-error-correction' | |||
| plug_generation = 'plug-generation' | |||
| gpt3_generation = 'gpt3-generation' | |||
| gpt_moe_generation = 'gpt-moe-generation' | |||
| faq_question_answering = 'faq-question-answering' | |||
| conversational_text_to_sql = 'conversational-text-to-sql' | |||
| table_question_answering_pipeline = 'table-question-answering-pipeline' | |||
| @@ -0,0 +1,27 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from typing import TYPE_CHECKING | |||
| from modelscope.utils.import_utils import LazyImportModule | |||
| if TYPE_CHECKING: | |||
| from .configuration import GPTMoEConfig | |||
| from .backbone import GPTMoEModel | |||
| from .text_generation import GPTMoEForTextGeneration | |||
| from .tokenizer import JiebaBPETokenizer | |||
| else: | |||
| _import_structure = { | |||
| 'configuration': ['GPTMoEConfig'], | |||
| 'backbone': ['GPTMoEModel'], | |||
| 'text_generation': ['GPTMoEForTextGeneration'], | |||
| 'tokenizer': ['JiebaBPETokenizer'], | |||
| } | |||
| import sys | |||
| sys.modules[__name__] = LazyImportModule( | |||
| __name__, | |||
| globals()['__file__'], | |||
| _import_structure, | |||
| module_spec=__spec__, | |||
| extra_objects={}, | |||
| ) | |||
| @@ -0,0 +1,355 @@ | |||
| # Copyright 2021-2022 The Alibaba PAI Team Authors. | |||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| import math | |||
| import os | |||
| from typing import Optional, Union | |||
| import addict | |||
| import torch | |||
| from torch import nn | |||
| from torch.nn import functional as F | |||
| from transformers.modeling_utils import PreTrainedModel | |||
| from modelscope.utils.constant import ModelFile | |||
| from .configuration import GPTMoEConfig | |||
| class GPTMoESelfAttention(nn.Module): | |||
| """Parallel self-attention layer abstract class. | |||
| Self-attention layer takes input with size [s, b, h] | |||
| and returns output of the same size. | |||
| """ | |||
| def __init__(self, config): | |||
| super().__init__() | |||
| self.hidden_size = config.hidden_size | |||
| self.num_attention_heads = config.num_attention_heads | |||
| # Per attention head | |||
| self.hidden_size_per_attention_head = \ | |||
| self.hidden_size // self.num_attention_heads | |||
| self.query_key_value = nn.Linear(self.hidden_size, | |||
| 3 * self.hidden_size) | |||
| self.softmax = nn.Softmax(dim=-1) | |||
| self.attention_dropout = nn.Dropout( | |||
| config.attention_probs_dropout_prob) | |||
| # Output. | |||
| self.dense = nn.Linear(self.hidden_size, self.hidden_size) | |||
| self.output_dropout = nn.Dropout(config.hidden_dropout_prob) | |||
| def _transpose_for_scores(self, tensor): | |||
| """Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with | |||
| size [b, np, s, hn]. | |||
| """ | |||
| new_tensor_shape = tensor.size()[:-1] + ( | |||
| self.num_attention_heads, self.hidden_size_per_attention_head) | |||
| tensor = tensor.view(*new_tensor_shape) | |||
| return tensor.permute(0, 2, 1, 3) | |||
| def _split_tensor_along_last_dim(self, | |||
| tensor, | |||
| num_partitions, | |||
| contiguous_split_chunks=False): | |||
| # Get the size and dimension. | |||
| last_dim = tensor.dim() - 1 | |||
| last_dim_size = tensor.size()[last_dim] // num_partitions | |||
| # Split. | |||
| tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) | |||
| # Note: torch.split does not create contiguous tensors by default. | |||
| if contiguous_split_chunks: | |||
| return tuple(chunk.contiguous() for chunk in tensor_list) | |||
| return tensor_list | |||
| def forward(self, hidden_states, ltor_mask, is_infer=False): | |||
| # hidden_states: [b, s, h] | |||
| # ltor_mask: [1, 1, s, s] | |||
| # Attention heads. [b, s, hp] | |||
| tgt_len = hidden_states.size(1) | |||
| ltor_mask = torch.reshape(ltor_mask, [1, 1, tgt_len, tgt_len]) | |||
| mixed_x_layer = self.query_key_value(hidden_states) | |||
| (mixed_query_layer, mixed_key_layer, mixed_value_layer) = \ | |||
| self._split_tensor_along_last_dim(mixed_x_layer, 3) | |||
| # Reshape and transpose [b, np, s, hn] | |||
| query_layer = self._transpose_for_scores(mixed_query_layer) | |||
| key_layer = self._transpose_for_scores(mixed_key_layer) | |||
| value_layer = self._transpose_for_scores(mixed_value_layer) | |||
| previous_type = value_layer.type() | |||
| # Raw attention scores. [b, np, s, s] | |||
| attention_scores = torch.matmul(query_layer, | |||
| key_layer.transpose(-1, -2)) | |||
| attention_scores = attention_scores / math.sqrt( | |||
| self.hidden_size_per_attention_head) | |||
| # Apply the left to right attention mask. | |||
| if is_infer: | |||
| src_len = key_layer.size(2) | |||
| ltor_mask = torch.tril( | |||
| torch.ones((1, tgt_len, src_len), | |||
| device=hidden_states.device)).view( | |||
| 1, 1, tgt_len, src_len).type(previous_type) | |||
| converted_mask = 10000.0 * (1.0 - ltor_mask) | |||
| attention_scores = (torch.mul(attention_scores, ltor_mask) | |||
| - converted_mask).type(previous_type) | |||
| # Attention probabilities. [b, np, s, s] | |||
| attention_probs = self.softmax(attention_scores) | |||
| # This is actually dropping out entire tokens to attend to, which might | |||
| # seem a bit unusual, but is taken from the original Transformer paper. | |||
| attention_probs = self.attention_dropout(attention_probs) | |||
| # Context layer. | |||
| # [b, np, s, hn] | |||
| context_layer = torch.matmul(attention_probs, value_layer) | |||
| # [b, s, np, hn] | |||
| context_layer = context_layer.permute(0, 2, 1, 3).contiguous() | |||
| new_context_layer_shape = context_layer.size()[:-2] + ( | |||
| self.hidden_size, ) | |||
| # [b, s, hp] | |||
| context_layer = context_layer.view(*new_context_layer_shape) | |||
| # Output. [b, s, h] | |||
| output = self.dense(context_layer) | |||
| output = self.output_dropout(output) | |||
| return output | |||
| class GPTMoEMLP(nn.Module): | |||
| """MLP. | |||
| MLP will take the input with h hidden state, project it to 4*h | |||
| hidden dimension, perform nonlinear transformation, and project the | |||
| state back into h hidden dimension. | |||
| """ | |||
| def __init__(self, config): | |||
| super().__init__() | |||
| hidden_size = config.hidden_size | |||
| # Project to 4h. | |||
| self.dense_h_to_4h = nn.Linear(hidden_size, 4 * hidden_size) | |||
| self.activation_func = F.gelu | |||
| # Project back to h. | |||
| self.dense_4h_to_h = nn.Linear(4 * hidden_size, hidden_size) | |||
| self.dropout = nn.Dropout(config.hidden_dropout_prob) | |||
| def forward(self, hidden_states): | |||
| # [s, b, 4hp] | |||
| intermediate_parallel = self.dense_h_to_4h(hidden_states) | |||
| intermediate_parallel = self.activation_func(intermediate_parallel) | |||
| # [s, b, h] | |||
| output = self.dense_4h_to_h(intermediate_parallel) | |||
| output = self.dropout(output) | |||
| return output | |||
| class GPTMoETransformerLayer(nn.Module): | |||
| """A single transformer layer. | |||
| Transformer layer takes input with size [s, b, h] and returns an | |||
| output of the same size. | |||
| """ | |||
| def __init__(self, config): | |||
| super().__init__() | |||
| # Layernorm on the input data. | |||
| self.input_layernorm = nn.LayerNorm( | |||
| config.hidden_size, eps=config.layernorm_epsilon) | |||
| # Self attention. | |||
| self.attention = GPTMoESelfAttention(config) | |||
| # Layernorm on the attention output | |||
| self.post_attention_layernorm = nn.LayerNorm( | |||
| config.hidden_size, eps=config.layernorm_epsilon) | |||
| # MLP | |||
| self.mlp = GPTMoEMLP(config) | |||
| def forward(self, hidden_states, ltor_mask): | |||
| # hidden_states: [b, s, h] | |||
| # ltor_mask: [1, 1, s, s] | |||
| # Layer norm at the begining of the transformer layer. | |||
| layernorm_output = self.input_layernorm(hidden_states) | |||
| # Self attention. | |||
| attention_output = self.attention(layernorm_output, ltor_mask) | |||
| # Residual connection. | |||
| layernorm_input = hidden_states + attention_output | |||
| # Layer norm post the self attention. | |||
| layernorm_output = self.post_attention_layernorm(layernorm_input) | |||
| # MLP. | |||
| mlp_output = self.mlp(layernorm_output) | |||
| # Second residual connection. | |||
| output = layernorm_input + mlp_output | |||
| return output | |||
| class GPTMoETransformer(nn.Module): | |||
| """Transformer class.""" | |||
| def __init__(self, config): | |||
| super().__init__() | |||
| self.input_tensor = None | |||
| # Number of layers. | |||
| self.num_layers = config.num_hidden_layers | |||
| self.layers = torch.nn.ModuleList( | |||
| [GPTMoETransformerLayer(config) for _ in range(self.num_layers)]) | |||
| # Final layer norm before output. | |||
| self.final_layernorm = nn.LayerNorm( | |||
| config.hidden_size, eps=config.layernorm_epsilon) | |||
| def _get_layer(self, layer_number): | |||
| return self.layers[layer_number] | |||
| def forward(self, hidden_states, attention_mask): | |||
| # hidden_states: [s, b, h] | |||
| for index in range(self.num_layers): | |||
| layer = self._get_layer(index) | |||
| hidden_states = layer(hidden_states, attention_mask) | |||
| # Final layer norm. | |||
| hidden_states = self.final_layernorm(hidden_states) | |||
| return hidden_states | |||
| class GPTMoETransformerLanguageModel(nn.Module): | |||
| """Transformer language model. | |||
| Arguments: | |||
| transformer_hparams: transformer hyperparameters | |||
| vocab_size: vocabulary size | |||
| max_sequence_length: maximum size of sequence. This | |||
| is used for positional embedding | |||
| embedding_dropout_prob: dropout probability for embeddings | |||
| num_tokentypes: size of the token-type embeddings. 0 value | |||
| will ignore this embedding | |||
| """ | |||
| def __init__(self, config): | |||
| super().__init__() | |||
| # Embeddings. | |||
| self.word_embeddings = nn.Embedding(config.vocab_size, | |||
| config.hidden_size) | |||
| self.position_embeddings = nn.Embedding(config.max_position_embeddings, | |||
| config.hidden_size) | |||
| self.embedding_dropout = nn.Dropout(config.hidden_dropout_prob) | |||
| # Transformer. | |||
| self.transformer = GPTMoETransformer(config) | |||
| def forward(self, input_ids, attention_mask, position_ids): | |||
| words_embeddings = self.word_embeddings(input_ids) | |||
| position_embeddings = self.position_embeddings(position_ids) | |||
| embeddings = words_embeddings + position_embeddings | |||
| transformer_input = self.embedding_dropout(embeddings) | |||
| transformer_output = self.transformer(transformer_input, | |||
| attention_mask) | |||
| logits = F.linear(transformer_output, self.word_embeddings.weight) | |||
| return logits | |||
| class GPTMoEModel(PreTrainedModel): | |||
| config_class = GPTMoEConfig | |||
| def _init_weights(self, module): | |||
| """Initialize the weights""" | |||
| if isinstance(module, nn.Linear): | |||
| # Slightly different from the TF version which uses truncated_normal for initialization | |||
| # cf https://github.com/pytorch/pytorch/pull/5617 | |||
| module.weight.data.normal_( | |||
| mean=0.0, std=self.config.initializer_range) | |||
| if module.bias is not None: | |||
| module.bias.data.zero_() | |||
| elif isinstance(module, nn.Embedding): | |||
| module.weight.data.normal_( | |||
| mean=0.0, std=self.config.initializer_range) | |||
| if module.padding_idx is not None: | |||
| module.weight.data[module.padding_idx].zero_() | |||
| elif isinstance(module, nn.LayerNorm): | |||
| module.bias.data.zero_() | |||
| module.weight.data.fill_(1.0) | |||
| def __init__(self, config): | |||
| super().__init__(config) | |||
| self.language_model = GPTMoETransformerLanguageModel(config) | |||
| def forward(self, | |||
| input_ids, | |||
| attention_mask=None, | |||
| position_ids=None, | |||
| labels=None, | |||
| **kwargs): | |||
| seq_length = input_ids.size(1) | |||
| attention_mask = torch.tril( | |||
| torch.ones((1, 1, seq_length, seq_length), | |||
| dtype=torch.long, | |||
| device=input_ids.device)) | |||
| if position_ids is None: | |||
| position_ids = torch.arange( | |||
| seq_length, dtype=torch.long, device=input_ids.device) | |||
| position_ids = position_ids.unsqueeze(0).expand_as(input_ids) | |||
| logits = self.language_model(input_ids, attention_mask, position_ids) | |||
| loss = None | |||
| if labels is not None: | |||
| loss_fct = nn.CrossEntropyLoss() | |||
| loss = loss_fct( | |||
| logits.view(-1, self.config.vocab_size), labels.view(-1)) | |||
| return addict.Dict(loss=loss, logits=logits) | |||
| @classmethod | |||
| def from_pretrained( | |||
| cls, pretrained_model_name_or_path: Optional[Union[str, | |||
| os.PathLike]]): | |||
| config = cls.config_class.from_pretrained( | |||
| pretrained_model_name_or_path) | |||
| model = cls(config) | |||
| state_dict_file = os.path.join(pretrained_model_name_or_path, | |||
| ModelFile.TORCH_MODEL_BIN_FILE) | |||
| state_dict = torch.load(state_dict_file) | |||
| if 'state_dict' in state_dict: | |||
| state_dict = state_dict['state_dict'] | |||
| state_dict = { | |||
| k.replace('model.language_model', 'language_model'): v | |||
| for k, v in state_dict.items() | |||
| } | |||
| model.load_state_dict(state_dict) | |||
| return model | |||
| def prepare_inputs_for_generation(self, input_ids, *args, **kwargs): | |||
| return {'input_ids': input_ids} | |||
| @@ -0,0 +1,145 @@ | |||
| # Copyright 2021-2022 The Alibaba PAI Team Authors. | |||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| import os | |||
| import torch | |||
| from megatron import mpu | |||
| from megatron.model import Float16Module | |||
| from torch.nn.parallel import DistributedDataParallel as torchDDP | |||
| from .configuration import logger | |||
| from .moe.layer import MoE | |||
| def unwrap_model(model, module_instances=(torchDDP)): | |||
| return_list = True | |||
| if not isinstance(model, list): | |||
| model = [model] | |||
| return_list = False | |||
| unwrapped_model = [] | |||
| for model_module in model: | |||
| while isinstance(model_module, module_instances): | |||
| model_module = model_module.module | |||
| unwrapped_model.append(model_module) | |||
| if not return_list: | |||
| return unwrapped_model[0] | |||
| return unwrapped_model | |||
| def get_checkpoint_names(checkpoints_path, | |||
| path_load_tag, | |||
| num_experts, | |||
| tensor_rank=None, | |||
| expp_rank=None): | |||
| """Determine the directory name for this rank's checkpoint.""" | |||
| if tensor_rank is None: | |||
| tensor_rank = mpu.get_model_parallel_rank() | |||
| common_path = os.path.join(checkpoints_path, path_load_tag, | |||
| f'mp_rank_{tensor_rank:02d}') | |||
| if num_experts[0] > 0: | |||
| model_name = common_path + '_model_states.pt' | |||
| optim_name = os.path.join( | |||
| checkpoints_path, path_load_tag, | |||
| f'expp_rank_{expp_rank}_mp_rank_{tensor_rank:02d}_optim_states.pt') | |||
| else: | |||
| model_name = optim_name = os.path.join(common_path, | |||
| 'model_optim_rng.pt') | |||
| return model_name, optim_name | |||
| def _get_expert_ckpt_name(checkpoints_path, layer_id, expert_id): | |||
| mp_rank = mpu.get_model_parallel_rank() | |||
| ckpt_name = os.path.join( | |||
| os.path.join(checkpoints_path, 'model'), | |||
| f'layer_{layer_id}_expert_{expert_id}_mp_rank_{mp_rank:02d}_model_states.pt' | |||
| ) | |||
| return ckpt_name | |||
| def _load_base_checkpoint(load_dir, path_load_tag=None, num_experts=None): | |||
| """ Load the base state_dict from the given directory | |||
| If rank0 is true, just loads rank 0 checkpoint, ignoring arguments. | |||
| """ | |||
| largest_group_name = mpu.get_max_expert_size_name() | |||
| expp_rank = mpu.get_expert_parallel_rank(largest_group_name) | |||
| checkpoint_names = get_checkpoint_names( | |||
| load_dir, | |||
| path_load_tag=path_load_tag, | |||
| num_experts=num_experts, | |||
| expp_rank=expp_rank) | |||
| model_checkpoint_name, optim_checkpoint_name = checkpoint_names | |||
| logger.info(f'Loading model checkpoint from {model_checkpoint_name}') | |||
| model_state_dict = torch.load(model_checkpoint_name, map_location='cpu') | |||
| return model_state_dict | |||
| def load_checkpoint(model, | |||
| load_dir, | |||
| num_experts=None, | |||
| strict=True, | |||
| path_load_tag='model', | |||
| load_ds_ckpts=True): | |||
| model = unwrap_model(model, (torchDDP, Float16Module)) | |||
| model_state_dict = _load_base_checkpoint( | |||
| load_dir, path_load_tag=path_load_tag, num_experts=num_experts) | |||
| assert model_state_dict is not None | |||
| if load_ds_ckpts: | |||
| load_moe_checkpoint(model, model_state_dict['module'], load_dir) | |||
| else: | |||
| load_moe_checkpoint(model, model_state_dict['model'], load_dir) | |||
| if load_ds_ckpts: | |||
| model.load_state_dict(model_state_dict['module'], strict=strict) | |||
| else: | |||
| model.load_state_dict(model_state_dict['model'], strict=strict) | |||
| if torch.distributed.is_initialized(): | |||
| torch.distributed.barrier() | |||
| def load_moe_checkpoint(model, state_dict, load_dir): | |||
| moe_layer_id = 0 | |||
| for n_module, module in model.named_modules(): | |||
| if isinstance(module, MoE): # and torch.distributed.get_rank() == 0: | |||
| group_name = module.expert_group_name | |||
| num_local_experts = module.num_local_experts | |||
| expp_rank = mpu.get_expert_parallel_rank(group_name) | |||
| # loop all local_experts | |||
| for local_expert_id in range(num_local_experts): | |||
| global_expert_id = expp_rank * num_local_experts + local_expert_id | |||
| moe_load_path = _get_expert_ckpt_name(load_dir, moe_layer_id, | |||
| global_expert_id) | |||
| logger.info(f'Loading expert states from {moe_load_path}') | |||
| expert_state_dict = torch.load( | |||
| moe_load_path, map_location=torch.device('cpu')) | |||
| # Updating global -> local expert ids | |||
| moe_str_prefix = '.deepspeed_moe.experts.deepspeed_experts.' | |||
| for key in list(expert_state_dict.keys()): | |||
| local_key = key.replace( | |||
| f'{moe_str_prefix}{global_expert_id}', | |||
| f'{moe_str_prefix}{local_expert_id}') | |||
| expert_state_dict[local_key] = expert_state_dict.pop(key) | |||
| state_dict.update(expert_state_dict) | |||
| moe_layer_id += 1 | |||
| @@ -0,0 +1,128 @@ | |||
| # Copyright 2021-2022 The Alibaba PAI Team Authors. | |||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| import torch | |||
| from transformers.configuration_utils import PretrainedConfig | |||
| from transformers.utils import logging | |||
| logger = logging.get_logger(__name__) | |||
| class GPTMoEConfig(PretrainedConfig): | |||
| model_type = 'gpt-moe' | |||
| def __init__( | |||
| self, | |||
| vocab_size=25600, | |||
| hidden_size=768, | |||
| ffn_hidden_size=None, | |||
| num_hidden_layers=12, | |||
| num_attention_heads=12, | |||
| intermediate_size=3072, | |||
| hidden_act='gelu', | |||
| hidden_dropout_prob=0.1, | |||
| attention_probs_dropout_prob=0.1, | |||
| max_position_embeddings=2048, | |||
| type_vocab_size=2, | |||
| layernorm_epsilon=1e-12, | |||
| bias_gelu_fusion=True, | |||
| fp32_residual_connection=False, | |||
| sequence_parallel=False, | |||
| fp16=False, | |||
| bf16=False, | |||
| apply_query_key_layer_scaling=True, | |||
| attention_softmax_in_fp32=False, | |||
| kv_channels=None, | |||
| masked_softmax_fusion=True, | |||
| attention_dropout=0.1, | |||
| bias_dropout_fusion=True, | |||
| apply_residual_connection_post_layernorm=False, | |||
| hidden_dropout=0.1, | |||
| init_method_std=0.02, | |||
| # generate | |||
| eod_id=7, | |||
| tokens_to_generate=100, | |||
| top_k=0, | |||
| top_p=0.9, | |||
| num_experts=[0], | |||
| use_tutel=False, | |||
| top_k_linear_strategy='standard', | |||
| use_expert_residual_network=False, | |||
| load_ds_ckpts=False, | |||
| model_dir=None, | |||
| **kwargs): | |||
| super().__init__(layer_norm_eps=layernorm_epsilon, **kwargs) | |||
| self.vocab_size = vocab_size | |||
| self.hidden_size = hidden_size | |||
| self.ffn_hidden_size = 4 * hidden_size \ | |||
| if ffn_hidden_size is None else ffn_hidden_size | |||
| self.num_hidden_layers = num_hidden_layers | |||
| self.num_attention_heads = num_attention_heads | |||
| self.hidden_act = hidden_act | |||
| self.intermediate_size = intermediate_size | |||
| self.hidden_dropout_prob = hidden_dropout_prob | |||
| self.attention_probs_dropout_prob = attention_probs_dropout_prob | |||
| self.max_position_embeddings = max_position_embeddings | |||
| self.type_vocab_size = type_vocab_size | |||
| self.layernorm_epsilon = layernorm_epsilon | |||
| self.bias_gelu_fusion = bias_gelu_fusion | |||
| self.fp32_residual_connection = fp32_residual_connection | |||
| self.sequence_parallel = sequence_parallel | |||
| self.fp16 = fp16 | |||
| self.bf16 = bf16 | |||
| assert not (fp16 and bf16) | |||
| self.apply_query_key_layer_scaling = apply_query_key_layer_scaling | |||
| self.attention_softmax_in_fp32 = attention_softmax_in_fp32 | |||
| if kv_channels is None: | |||
| assert hidden_size % num_attention_heads == 0 | |||
| self.kv_channels = hidden_size // num_attention_heads | |||
| self.masked_softmax_fusion = masked_softmax_fusion | |||
| self.attention_dropout = attention_dropout | |||
| self.bias_dropout_fusion = bias_dropout_fusion | |||
| self.apply_residual_connection_post_layernorm = \ | |||
| apply_residual_connection_post_layernorm | |||
| self.hidden_dropout = hidden_dropout | |||
| self.init_method_std = init_method_std | |||
| self.eod_id = eod_id | |||
| self.tokens_to_generate = tokens_to_generate | |||
| self.top_k = top_k | |||
| self.top_p = top_p | |||
| self.num_experts = num_experts | |||
| self.use_tutel = use_tutel | |||
| self.top_k_linear_strategy = top_k_linear_strategy | |||
| self.use_expert_residual_network = use_expert_residual_network | |||
| self.load_ds_ckpts = load_ds_ckpts | |||
| self.model_dir = model_dir | |||
| if self.num_experts[0] > torch.cuda.device_count(): | |||
| self.moe_expert_parallel_size = torch.cuda.device_count() | |||
| else: | |||
| self.moe_expert_parallel_size = self.num_experts[0] | |||
| TORCH_MAJOR = int(torch.__version__.split('.')[0]) | |||
| TORCH_MINOR = int(torch.__version__.split('.')[1]) | |||
| self.no_persist_layer_norm = \ | |||
| TORCH_MAJOR < 1 or (TORCH_MAJOR == 1 and TORCH_MINOR < 11) | |||
| @property | |||
| def params_dtype(self): | |||
| if self.fp16: | |||
| return torch.half | |||
| elif self.bf16: | |||
| return torch.bfloat16 | |||
| else: | |||
| return torch.float | |||
| @@ -0,0 +1,36 @@ | |||
| ''' | |||
| Copyright 2020 The Microsoft DeepSpeed Team | |||
| ''' | |||
| import copy | |||
| import torch | |||
| class Experts(torch.nn.Module): | |||
| def __init__(self, expert, num_local_experts=1, expert_group_name=None): | |||
| super(Experts, self).__init__() | |||
| self.deepspeed_experts = torch.nn.ModuleList( | |||
| [copy.deepcopy(expert) for i in range(num_local_experts)]) | |||
| self.num_local_experts = num_local_experts | |||
| # TODO: revisit allreduce for moe.gate... | |||
| for expert in self.deepspeed_experts: | |||
| # TODO: Create param groups to handle expert + data case (e.g. param.group = moe_group) | |||
| for name, param in expert.named_parameters(): | |||
| param.allreduce = False | |||
| param.group_name = expert_group_name | |||
| def forward(self, inputs): | |||
| chunks = inputs.chunk(self.num_local_experts, dim=1) | |||
| expert_outputs = [] | |||
| for chunk, expert in zip(chunks, self.deepspeed_experts): | |||
| out = expert(chunk) | |||
| if type(out) is tuple: | |||
| out = out[0] # Ignore the bias term for now | |||
| expert_outputs += [out] | |||
| expert_output = torch.cat(expert_outputs, dim=1) | |||
| return expert_output | |||
| @@ -0,0 +1,98 @@ | |||
| ''' | |||
| Copyright 2020 The Microsoft DeepSpeed Team | |||
| ''' | |||
| import typing | |||
| import torch | |||
| from megatron import mpu | |||
| from .experts import Experts | |||
| from .sharded_moe import MOELayer, TopKGate | |||
| class MoE(torch.nn.Module): | |||
| def __init__(self, | |||
| hidden_size, | |||
| expert, | |||
| num_experts=1, | |||
| ep_size=1, | |||
| k=1, | |||
| capacity_factor=1., | |||
| eval_capacity_factor=1., | |||
| min_capacity=4, | |||
| use_residual=False, | |||
| noisy_gate_policy: typing.Optional[str] = None, | |||
| drop_tokens: bool = True, | |||
| use_rts=True, | |||
| use_tutel: bool = False, | |||
| top_k_linear_strategy: str = 'normal', | |||
| use_expert_residual_network: bool = False): | |||
| super(MoE, self).__init__() | |||
| self.use_residual = use_residual | |||
| assert num_experts % ep_size == 0, f'Number of experts ({num_experts}) should ' \ | |||
| f'be divisible by expert parallel size ({ep_size})' | |||
| self.ep_size = ep_size | |||
| self.expert_group_name = f'ep_size_{self.ep_size}' | |||
| self.num_experts = num_experts | |||
| self.num_local_experts = num_experts // self.ep_size | |||
| assert noisy_gate_policy is None or noisy_gate_policy in ['None', 'Jitter', 'RSample'], \ | |||
| 'Unsupported noisy_gate_policy: ' + noisy_gate_policy | |||
| experts = Experts(expert, self.num_local_experts, | |||
| self.expert_group_name) | |||
| self.deepspeed_moe = MOELayer( | |||
| TopKGate( | |||
| hidden_size, | |||
| num_experts, | |||
| k, | |||
| capacity_factor, | |||
| eval_capacity_factor, | |||
| min_capacity, | |||
| noisy_gate_policy, | |||
| drop_tokens, | |||
| use_rts, | |||
| top_k_linear_strategy=top_k_linear_strategy), | |||
| experts, | |||
| self.expert_group_name, | |||
| self.ep_size, | |||
| self.num_local_experts, | |||
| use_tutel=use_tutel, | |||
| use_expert_residual_network=use_expert_residual_network) | |||
| self.deepspeed_moe._set_ep_group( | |||
| mpu.get_expert_parallel_group(self.expert_group_name)) | |||
| if self.use_residual: | |||
| self.mlp = expert | |||
| # coefficient is used for weighted sum of the output of expert and mlp | |||
| self.coefficient = torch.nn.Linear(hidden_size, 2) | |||
| def forward(self, hidden_states, used_token=None): | |||
| """ MoE forward | |||
| Arguments: | |||
| hidden_states (Tensor): input to the layer | |||
| used_token (Tensor, optional): default: None, mask only used tokens | |||
| Returns: | |||
| A tuple including output, gate loss, and expert count. | |||
| * output (Tensor): output of the model | |||
| * l_aux (Tensor): gate loss value | |||
| * exp_counts (int): expert count | |||
| """ | |||
| output = self.deepspeed_moe(hidden_states, used_token) | |||
| if self.use_residual: | |||
| # Residual MoE | |||
| output_mlp = self.mlp(hidden_states) | |||
| if type(output_mlp) is tuple: | |||
| output_mlp = output_mlp[0] # Ignore the bias term for now | |||
| coef = self.coefficient(hidden_states) | |||
| coef = torch.nn.functional.softmax(coef, dim=1) | |||
| output = output * coef[..., 0:1] + output_mlp * coef[..., 1:] | |||
| return output, self.deepspeed_moe.l_aux, self.deepspeed_moe.exp_counts | |||
| @@ -0,0 +1,87 @@ | |||
| ''' | |||
| Copyright 2020 The Microsoft DeepSpeed Team | |||
| ''' | |||
| import torch | |||
| from megatron import mpu | |||
| def _gather_tokens(input_, dim=0): | |||
| """Gather tensors and concatenate them along a dimension""" | |||
| input_ = input_.contiguous() | |||
| # Size and dimension. | |||
| rank = mpu.get_tensor_model_parallel_rank() | |||
| tensor_list = [ | |||
| torch.empty_like(input_) | |||
| for _ in range(mpu.get_model_parallel_world_size()) | |||
| ] | |||
| tensor_list[rank] = input_ | |||
| torch.distributed.all_gather( | |||
| tensor_list, input_, group=mpu.get_tensor_model_parallel_group()) | |||
| # Note: torch.cat already creates a contiguous tensor. | |||
| output = torch.cat(tensor_list, dim=dim).contiguous() | |||
| return output | |||
| def _drop_tokens(input_, dim=0): | |||
| """Divide a tensor among the tensor parallel ranks""" | |||
| total_chunks = mpu.get_model_parallel_world_size() | |||
| this_chunk = mpu.get_model_parallel_rank() | |||
| assert input_.shape[ | |||
| dim] % total_chunks == 0, f'input dimension {dim} ({input_.shape[dim]}) ' \ | |||
| f'is not divisible by tensor parallel world size ({total_chunks})' | |||
| chunk_size = input_.shape[dim] // total_chunks | |||
| return torch.narrow(input_, dim, this_chunk * chunk_size, chunk_size) | |||
| class _GatherTokens(torch.autograd.Function): | |||
| """All gather tokens among the tensor parallel ranks""" | |||
| @staticmethod | |||
| def symbolic(graph, input_, dim): | |||
| return _gather_tokens(input_, dim) | |||
| @staticmethod | |||
| def forward(ctx, input_, dim): | |||
| ctx.dim = dim | |||
| return _gather_tokens(input_, dim) | |||
| @staticmethod | |||
| def backward(ctx, grad_output): | |||
| return _drop_tokens(grad_output, ctx.dim), None | |||
| class _DropTokens(torch.autograd.Function): | |||
| 'Divide tokens equally among the tensor parallel ranks' | |||
| @staticmethod | |||
| def symbolic(graph, input_, dim): | |||
| return _drop_tokens(input_, dim) | |||
| @staticmethod | |||
| def forward(ctx, input_, dim): | |||
| ctx.dim = dim | |||
| return _drop_tokens(input_, dim) | |||
| @staticmethod | |||
| def backward(ctx, input_): | |||
| return _gather_tokens(input_, ctx.dim), None | |||
| def gather_tokens(input_, dim=0): | |||
| if mpu is None or mpu.get_model_parallel_world_size() == 1: | |||
| # no tensor parallelism for non-experts | |||
| return input_ | |||
| return _GatherTokens.apply(input_, dim) | |||
| def drop_tokens(input_, dim=0): | |||
| if mpu is None or mpu.get_model_parallel_world_size() == 1: | |||
| # no tensor parallelism for non-experts | |||
| return input_ | |||
| return _DropTokens.apply(input_, dim) | |||
| @@ -0,0 +1,647 @@ | |||
| ''' | |||
| Copyright 2021 The Microsoft DeepSpeed Team | |||
| ''' | |||
| # The file has been adapted from two fairscale files: | |||
| # (1) https://github.com/facebookresearch/fairscale/blob/master/fairscale/nn/moe/moe_layer.py | |||
| # (2) https://github.com/facebookresearch/fairscale/blob/master/fairscale/nn/moe/top2gate.py | |||
| # Git commit hash: 34df606902a240567a0d898037ece55c2f1336cf | |||
| # We retain the following license from the original files: | |||
| # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | |||
| # | |||
| # This source code is licensed under the BSD license found in the | |||
| # LICENSE file in the root directory of this source tree. | |||
| import math | |||
| from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple | |||
| import torch | |||
| import torch.distributed as dist | |||
| import torch.nn.functional as F | |||
| from megatron import mpu | |||
| from scipy.special import binom | |||
| from torch import Tensor, nn | |||
| from torch.nn import Module | |||
| from ..configuration import logger | |||
| from .mappings import drop_tokens, gather_tokens | |||
| try: | |||
| from apex.normalization import FusedLayerNorm as _FusedLayerNorm | |||
| has_fused_layernorm = True | |||
| class FusedLayerNorm(_FusedLayerNorm): | |||
| @torch.jit.unused | |||
| def forward(self, x): | |||
| if not x.is_cuda: | |||
| return super().forward(x) | |||
| else: | |||
| with torch.cuda.device(x.device): | |||
| return super().forward(x) | |||
| except ImportError: | |||
| has_fused_layernorm = False | |||
| if TYPE_CHECKING: | |||
| Base = Module[Tensor] | |||
| else: | |||
| Base = Module | |||
| uniform_map: Dict[torch.device, Callable] = {} | |||
| gumbel_map: Dict[torch.device, Callable] = {} | |||
| exp_selection_uniform_map: Dict[torch.device, Callable] = {} | |||
| def multiplicative_jitter(x, device: torch.device, epsilon=1e-2): | |||
| """ | |||
| Modified from switch transformer paper. mesh transformers | |||
| Multiply values by a random number between 1-epsilon and 1+epsilon. | |||
| Makes models more resilient to rounding errors introduced by bfloat16. | |||
| This seems particularly important for logits. | |||
| Args: | |||
| x: a torch.tensor | |||
| device: torch.device | |||
| epsilon: a floating point value | |||
| Returns: | |||
| a jittered x. | |||
| """ | |||
| if epsilon == 0: | |||
| return x | |||
| uniform = uniform_map.get(device) | |||
| if uniform is None: | |||
| uniform = torch.distributions.uniform.Uniform( | |||
| low=torch.tensor(1.0 - epsilon, device=device), | |||
| high=torch.tensor(1.0 + epsilon, | |||
| device=device)).rsample # type: ignore | |||
| uniform_map[device] = uniform | |||
| return x * uniform(x.shape) | |||
| def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor: | |||
| gumbel = gumbel_map.get(device) | |||
| if gumbel is None: | |||
| one = torch.tensor(1.0, device=device) | |||
| zero = torch.tensor(0.0, device=device) | |||
| gumbel = torch.distributions.gumbel.Gumbel(zero, | |||
| one).rsample # type: ignore | |||
| gumbel_map[device] = gumbel | |||
| return gumbel(shape) | |||
| # Based on https://github.com/pytorch/pytorch/pull/40762 | |||
| class _AllToAll(torch.autograd.Function): | |||
| @staticmethod | |||
| def forward(ctx: Any, group: dist.ProcessGroup, | |||
| input: Tensor) -> Tensor: # type: ignore | |||
| ctx.group = group | |||
| input = input.contiguous() | |||
| output = torch.empty_like(input) | |||
| dist.all_to_all_single(output, input, group=group) | |||
| return output | |||
| @staticmethod | |||
| def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]: | |||
| return (None, _AllToAll.apply(ctx.group, *grad_output)) | |||
| # einsum rewrites are on par or more performant | |||
| # switch can be bubbled up in future | |||
| USE_EINSUM = True | |||
| # einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity | |||
| # See https://arxiv.org/pdf/2006.16668.pdf for details. | |||
| def einsum(rule, a, b): | |||
| if USE_EINSUM: | |||
| return torch.einsum(rule, a, b) | |||
| elif rule == 's,se->se': | |||
| return a.reshape(a.shape[0], -1) * b | |||
| elif rule == 'se,sc->sec': | |||
| return a.unsqueeze(2) * b.unsqueeze(1) | |||
| elif rule == 'se,se->s': | |||
| return torch.bmm(a.unsqueeze(1), b.unsqueeze(2)).reshape(-1) | |||
| elif rule == 'sec,sm->ecm': | |||
| s = a.shape[0] | |||
| e = a.shape[1] | |||
| c = a.shape[2] | |||
| m = b.shape[1] | |||
| return torch.matmul(a.reshape(s, -1).t(), b).reshape(e, c, m) | |||
| elif rule == 'sec,ecm->sm': | |||
| return torch.matmul( | |||
| a.reshape(a.shape[0], -1), b.reshape(-1, b.shape[-1])) | |||
| elif rule == 'ks,ksm->sm': | |||
| k = b.shape[0] | |||
| s = b.shape[1] | |||
| m = b.shape[2] | |||
| # [k, s] -> [s, k] -> [s, 1, k] | |||
| a = a.t().unsqueeze(1) | |||
| # [k,s,m] -> [k, sm] -> [sm, k] -> [s, m, k] | |||
| b = b.reshape(k, -1).t().reshape(s, m, k) | |||
| # bmm([s, 1, k], [s, m, k]^t) -> [s, m, 1] | |||
| return torch.bmm(a, b.transpose(1, 2)).squeeze(2) | |||
| else: | |||
| return torch.einsum(rule, a, b) | |||
| # The following functions are extracted and scripted | |||
| # because otherwise during a torch.jit.trace, the non-Tensor | |||
| # values used in the calculations get recorded as constants. | |||
| # torch.jit.script coerces them into Tensors and preserves | |||
| # their dynamic shapes. This enables ONNX export. | |||
| # We can't script the entire top1gating function because it | |||
| # includes stateful caching logic which is incompatible with ONNX. | |||
| @torch.jit.script | |||
| def _capacity(gates: Tensor, capacity_factor: Tensor, | |||
| min_capacity: Tensor) -> Tensor: | |||
| # gates has shape of SE | |||
| num_tokens = gates.shape[0] | |||
| num_experts = gates.shape[1] | |||
| # to(torch.int64) works around a bug in torch.onnx.export: | |||
| # it should cast k to int64 when converting torch.topk but it doesn't. | |||
| capacity = torch.ceil( | |||
| (num_tokens / num_experts) * capacity_factor).to(torch.int64) | |||
| if capacity < min_capacity: | |||
| capacity = min_capacity.to(torch.int64) | |||
| return capacity | |||
| @torch.jit.script | |||
| def _top_idx(source, k): | |||
| return torch.topk(source, k=k, dim=0)[1] | |||
| @torch.jit.script | |||
| def _one_hot_to_float(x, num_classes): | |||
| return F.one_hot(x, num_classes=num_classes).float() | |||
| def top1gating( | |||
| logits: Tensor, | |||
| capacity_factor: float, | |||
| min_capacity: int, | |||
| used_token: Tensor = None, | |||
| noisy_gate_policy: Optional[str] = None, | |||
| drop_tokens: bool = True, | |||
| use_rts: bool = True, | |||
| use_tutel: bool = False) -> Tuple[Tensor, Tensor, Tensor, Tensor]: | |||
| """Implements Top1Gating on logits.""" | |||
| if noisy_gate_policy == 'RSample': | |||
| logits_w_noise = logits + gumbel_rsample( | |||
| logits.shape, device=logits.device) | |||
| # everything is in fp32 in this function | |||
| gates = F.softmax(logits, dim=1) | |||
| capacity = _capacity(gates, torch.tensor(capacity_factor), | |||
| torch.tensor(min_capacity)) | |||
| # Create a mask for 1st's expert per token | |||
| # noisy gating | |||
| indices1_s = torch.argmax( | |||
| logits_w_noise if noisy_gate_policy == 'RSample' else gates, dim=1) | |||
| num_experts = int(gates.shape[1]) | |||
| mask1 = F.one_hot(indices1_s, num_classes=num_experts) | |||
| # mask only used tokens | |||
| if used_token is not None: | |||
| mask1 = einsum('s,se->se', used_token, mask1) | |||
| # gating decisions | |||
| exp_counts = torch.sum(mask1, dim=0).detach().to('cpu') | |||
| # if we don't want to drop any tokens | |||
| if not drop_tokens: | |||
| new_capacity = torch.max(exp_counts).to(logits.device) | |||
| dist.all_reduce( | |||
| new_capacity, op=dist.ReduceOp.MAX, group=dist.group.WORLD) | |||
| capacity = new_capacity | |||
| # Compute l_aux | |||
| alpha = torch.max(gates, dim=1).values.unsqueeze(1) | |||
| me = torch.mean(gates, dim=0) | |||
| ce = torch.mean(mask1.float(), dim=0) | |||
| l_aux = torch.sum(me * ce) * num_experts | |||
| # Random Token Selection | |||
| if use_rts: | |||
| uniform = exp_selection_uniform_map.get(logits.device) | |||
| if uniform is None: | |||
| uniform = torch.distributions.uniform.Uniform( | |||
| low=torch.tensor(0.0, device=logits.device), | |||
| high=torch.tensor(1.0, device=logits.device)).rsample | |||
| exp_selection_uniform_map[logits.device] = uniform | |||
| mask1_rand = mask1 * uniform(mask1.shape) | |||
| else: | |||
| mask1_rand = mask1 | |||
| assert logits.shape[0] >= min_capacity, \ | |||
| 'No. of tokens (batch-size) should be greater than min_capacity. ' \ | |||
| 'Either set min_capacity to 0 or increase your batch size.' | |||
| top_idx = _top_idx(mask1_rand, capacity) | |||
| new_mask1 = mask1 * torch.zeros_like(mask1).scatter_(0, top_idx, 1) | |||
| mask1 = new_mask1 | |||
| if use_tutel: | |||
| # Tutel doesn't support index values masked with zero | |||
| # so we need to replace masked indices with -1 | |||
| indices_mask = mask1.sum(dim=1) * num_experts - 1 | |||
| indices1_s = torch.min(indices1_s, indices_mask) | |||
| # Compute locations in capacity buffer | |||
| if use_tutel: | |||
| locations1 = tutel_moe.fast_cumsum_sub_one(mask1) | |||
| else: | |||
| locations1 = torch.cumsum(mask1, dim=0) - 1 | |||
| if use_tutel: | |||
| gates1_s = (gates * mask1).sum(dim=1) | |||
| locations1_s = torch.sum(locations1 * mask1, dim=1) | |||
| return l_aux, capacity, num_experts, [ | |||
| indices1_s, | |||
| ], [ | |||
| locations1_s, | |||
| ], [ | |||
| gates1_s, | |||
| ], exp_counts, alpha | |||
| # Store the capacity location for each token | |||
| locations1_s = torch.sum(locations1 * mask1, dim=1) | |||
| # Normalize gate probabilities | |||
| mask1_float = mask1.float() | |||
| gates = gates * mask1_float | |||
| locations1_sc = _one_hot_to_float(locations1_s, capacity) | |||
| combine_weights = einsum('se,sc->sec', gates, locations1_sc) | |||
| dispatch_mask = combine_weights.bool() | |||
| return l_aux, combine_weights, dispatch_mask, exp_counts, alpha | |||
| class TopKGate(Module): | |||
| """Gate module which implements Top2Gating as described in Gshard_. | |||
| :: | |||
| gate = TopKGate(model_dim, num_experts) | |||
| l_aux, combine_weights, dispatch_mask = gate(input) | |||
| .. Gshard_: https://arxiv.org/pdf/2006.16668.pdf | |||
| Args: | |||
| model_dim (int): | |||
| size of model embedding dimension | |||
| num_experts (ints): | |||
| number of experts in model | |||
| """ | |||
| wg: torch.nn.Linear | |||
| def __init__(self, | |||
| model_dim: int, | |||
| num_experts: int, | |||
| k: int = 1, | |||
| capacity_factor: float = 1.0, | |||
| eval_capacity_factor: float = 1.0, | |||
| min_capacity: int = 8, | |||
| noisy_gate_policy: Optional[str] = None, | |||
| drop_tokens: bool = True, | |||
| use_rts: bool = True, | |||
| top_k_linear_strategy: str = 'standard') -> None: | |||
| super().__init__() | |||
| # Only top-1 are supported at the moment. | |||
| if k != 1: | |||
| raise ValueError('Only top-1 gatings are supported.') | |||
| if top_k_linear_strategy == 'standard': | |||
| self.wg = torch.nn.Linear( | |||
| model_dim, num_experts, bias=False).float() | |||
| elif top_k_linear_strategy == 'lsoftmax': | |||
| self.wg = LSoftmaxLinearLayer( | |||
| model_dim, num_experts, margin=1).float() | |||
| else: | |||
| raise ValueError( | |||
| 'Only standard or lsoftmax top-k-linear-strategy are supported.' | |||
| ) | |||
| self.k = k | |||
| self.capacity_factor = capacity_factor | |||
| self.eval_capacity_factor = eval_capacity_factor | |||
| self.min_capacity = min_capacity | |||
| self.noisy_gate_policy = noisy_gate_policy | |||
| self.wall_clock_breakdown = False | |||
| self.gate_time = 0.0 | |||
| self.drop_tokens = drop_tokens | |||
| self.use_rts = use_rts | |||
| self.top_k_linear_strategy = top_k_linear_strategy | |||
| def forward( | |||
| self, | |||
| input: torch.Tensor, | |||
| used_token: torch.Tensor = None, | |||
| use_tutel: bool = False | |||
| ) -> Tuple[Tensor, Tensor, Tensor]: # type: ignore | |||
| if self.wall_clock_breakdown: | |||
| self.timers('TopKGate').start() | |||
| if self.top_k_linear_strategy == 'standard': | |||
| if self.wg.weight.dtype != torch.float32: | |||
| self.wg = self.wg.float() | |||
| elif self.top_k_linear_strategy == 'lsoftmax': | |||
| if self.wg.weight.weight.dtype != torch.float32: | |||
| self.wg.weight = self.wg.weight.float() | |||
| input_fp32 = input.float() | |||
| # input jittering | |||
| if self.noisy_gate_policy == 'Jitter' and self.training: | |||
| input_fp32 = multiplicative_jitter(input_fp32, device=input.device) | |||
| if self.k == 1: | |||
| if self.top_k_linear_strategy == 'standard': | |||
| logits = self.wg(input_fp32) | |||
| elif self.top_k_linear_strategy == 'lsoftmax': | |||
| logits = self.wg(input_fp32, input_fp32.device, self.training) | |||
| gate_output = top1gating( | |||
| logits, self.capacity_factor if self.training else | |||
| self.eval_capacity_factor, self.min_capacity, used_token, | |||
| self.noisy_gate_policy if self.training else None, | |||
| self.drop_tokens, self.use_rts, use_tutel) | |||
| if self.wall_clock_breakdown: | |||
| self.timers('TopKGate').stop() | |||
| self.gate_time = self.timers('TopKGate').elapsed( | |||
| reset=False) * 1000 | |||
| return gate_output | |||
| class MOELayer(Base): | |||
| """MOELayer module which implements MixtureOfExperts as described in Gshard_. | |||
| :: | |||
| gate = TopKGate(model_dim, num_experts) | |||
| moe = MOELayer(gate, expert) | |||
| output = moe(input) | |||
| l_aux = moe.l_aux | |||
| .. Gshard_: https://arxiv.org/pdf/2006.16668.pdf | |||
| Args: | |||
| gate (torch.nn.Module): | |||
| gate network | |||
| expert (torch.nn.Module): | |||
| expert network | |||
| """ | |||
| def __init__(self, | |||
| gate: Module, | |||
| experts: Module, | |||
| ep_group_name, | |||
| ep_size, | |||
| num_local_experts: int, | |||
| use_tutel: bool = False, | |||
| use_expert_residual_network: bool = False) -> None: | |||
| super().__init__() | |||
| self.gate = gate | |||
| self.experts = experts | |||
| self.ep_group = None | |||
| self.ep_size = ep_size | |||
| self.ep_group_name = ep_group_name | |||
| self.num_local_experts = num_local_experts | |||
| self.wall_clock_breakdown = False | |||
| self.use_expert_residual_network = use_expert_residual_network | |||
| if self.use_expert_residual_network: | |||
| self.expert_network = nn.Sequential( | |||
| *([ExpertResidualLayer(self.gate.model_dim) | |||
| for _ in range(6)])) | |||
| self.use_tutel = use_tutel and TUTEL_INSTALLED | |||
| if self.use_tutel: | |||
| logger.info('Using Tutel optimizations.') | |||
| elif use_tutel and not TUTEL_INSTALLED: | |||
| logger.info( | |||
| 'Tutel optimization requested but not installed Proceeding without Tutel.' | |||
| ) | |||
| def _set_ep_group(self, ep_group): | |||
| self.ep_group = ep_group | |||
| def forward(self, *input: Tensor, **kwargs: Any) -> Tensor: | |||
| if self.wall_clock_breakdown: | |||
| self.timers('moe').start() | |||
| # Implement Algorithm 2 from GShard paper. | |||
| d_model = input[0].shape[-1] | |||
| # Initial implementation -> Reshape into S tokens by dropping sequence dimension. | |||
| # Reshape into G groups so that each group can distribute tokens equally | |||
| # group_size = kwargs['group_size'] if 'group_size' in kwargs.keys() else 1 | |||
| reshaped_input = input[0].reshape(-1, d_model) | |||
| if self.use_tutel: | |||
| self.l_aux, C, E, indices_, locations_, gates_, self.exp_counts, alpha = self.gate( | |||
| reshaped_input, input[1], True) | |||
| _, M = reshaped_input.size(0), reshaped_input.size(1) | |||
| if not hasattr(self, '_tutel_dispatcher'): | |||
| self._tutel_dispatcher = tutel_moe.fast_dispatcher( | |||
| E, C, M, dispatch_dtype=reshaped_input.dtype) | |||
| self._tutel_dispatcher.update( | |||
| indices_, locations_, gates_, capacity=C) | |||
| dispatched_input = self._tutel_dispatcher.encode(reshaped_input) | |||
| else: | |||
| self.l_aux, combine_weights, dispatch_mask, self.exp_counts, alpha = self.gate( | |||
| reshaped_input, input[1]) | |||
| dispatched_input = einsum('sec,sm->ecm', | |||
| dispatch_mask.type_as(input[0]), | |||
| reshaped_input) | |||
| if self.wall_clock_breakdown: | |||
| self.timers('falltoall').start() | |||
| if mpu.get_expert_model_parallel_world_size() == 1: | |||
| # If the non-expert is tensor-parallel, it will create | |||
| # duplicate tokens on the tensor-parallel ranks. | |||
| # Since our experts are not tensor-parallel, these duplicates | |||
| # need to be dropped to ensure correctness. | |||
| # this also doubles up as a communication optimization as we are | |||
| # reducing the all-to-all communication volume. | |||
| if self.use_tutel: | |||
| # reshape tutel's output from [e*c,m] to [e,c,m] | |||
| dispatched_input = dispatched_input.reshape( | |||
| self.ep_size * self.num_local_experts, -1, d_model) | |||
| dispatched_input = drop_tokens(dispatched_input, dim=1) | |||
| dispatched_input = _AllToAll.apply(self.ep_group, dispatched_input) | |||
| if self.wall_clock_breakdown: | |||
| self.timers('falltoall').stop() | |||
| self.time_falltoall = self.timers('falltoall').elapsed( | |||
| reset=False) * 1000 | |||
| # Re-shape after all-to-all: ecm -> gecm | |||
| dispatched_input = dispatched_input.reshape(self.ep_size, | |||
| self.num_local_experts, -1, | |||
| d_model) | |||
| expert_output = self.experts(dispatched_input) | |||
| if self.wall_clock_breakdown: | |||
| self.timers('salltoall').start() | |||
| expert_output = _AllToAll.apply(self.ep_group, expert_output) | |||
| if self.wall_clock_breakdown: | |||
| self.timers('salltoall').stop() | |||
| self.time_salltoall = self.timers('salltoall').elapsed( | |||
| reset=False) * 1000 | |||
| # Re-shape back: gecm -> ecm | |||
| expert_output = expert_output.reshape( | |||
| self.ep_size * self.num_local_experts, -1, d_model) | |||
| if mpu.get_expert_model_parallel_world_size() == 1: | |||
| # the dropped duplicate tokens need to be gathered on each | |||
| # tensor parallel rank again for the tensor-parallel | |||
| # non-expert of the next layer. | |||
| expert_output = gather_tokens(expert_output, dim=1) | |||
| if self.use_tutel: | |||
| combined_output = self._tutel_dispatcher.decode( | |||
| expert_output.view(E * C, M)) | |||
| else: | |||
| combined_output = einsum('sec,ecm->sm', | |||
| combine_weights.type_as(input[0]), | |||
| expert_output) | |||
| if self.use_expert_residual_network: | |||
| combined_output = alpha * self.expert_network(combined_output) + ( | |||
| 1 - alpha) * combined_output | |||
| a = combined_output.reshape(input[0].shape) | |||
| if self.wall_clock_breakdown: | |||
| self.timers('moe').stop() | |||
| self.time_moe = self.timers('moe').elapsed(reset=False) * 1000 | |||
| return a | |||
| class LSoftmaxLinearLayer(torch.nn.Module): | |||
| def __init__(self, input_features, output_features, margin): | |||
| super().__init__() | |||
| self.input_dim = input_features # number of input feature i.e. output of the last fc layer | |||
| self.output_dim = output_features # number of output = class numbers | |||
| self.margin = margin # m | |||
| self.beta = 100 | |||
| self.beta_min = 0 | |||
| self.scale = 0.99 | |||
| self.num_experts = output_features | |||
| # Initialize L-Softmax parameters | |||
| self.weight = torch.nn.Linear( | |||
| input_features, output_features, bias=False).float() | |||
| self.divisor = math.pi / self.margin # pi/m | |||
| self.C_m_2n = torch.Tensor(binom(margin, range(0, margin + 1, | |||
| 2))) # C_m{2n} | |||
| self.cos_powers = torch.Tensor(range(self.margin, -1, -2)) # m - 2n | |||
| self.sin2_powers = torch.Tensor(range(len(self.cos_powers))) # n | |||
| self.signs = torch.ones(margin // 2 + 1) # 1, -1, 1, -1, ... | |||
| self.signs[1::2] = -1 | |||
| def calculate_cos_m_theta(self, cos_theta, device): | |||
| sin2_theta = 1 - cos_theta**2 | |||
| cos_terms = cos_theta.unsqueeze(1)**self.cos_powers.to( | |||
| device).unsqueeze(0) # cos^{m - 2n} | |||
| sin2_terms = ( | |||
| sin2_theta.unsqueeze(1)**self.sin2_powers.to(device).unsqueeze(0)) | |||
| cos_m_theta = (self.signs.to(device).unsqueeze(0) | |||
| * self.C_m_2n.to(device).unsqueeze(0) * cos_terms | |||
| * sin2_terms).sum(1) # summation of all terms | |||
| return cos_m_theta | |||
| def reset_parameters(self): | |||
| nn.init.kaiming_normal_(self.weight.data.t()) | |||
| def find_k(self, cos): | |||
| # to account for acos numerical errors | |||
| eps = 1e-7 | |||
| cos = torch.clamp(cos, -1 + eps, 1 - eps) | |||
| acos = cos.acos() | |||
| k = (acos / self.divisor).floor().detach() | |||
| return k | |||
| def forward(self, input, device, training): | |||
| if training: | |||
| x, w = input, self.weight.float() | |||
| beta = max(self.beta, self.beta_min) | |||
| logit = w(x) | |||
| indexes = range(logit.size(0)) | |||
| # target = torch.fmod(torch.randperm(logit.size(0)), self.num_experts) | |||
| target = torch.fmod( | |||
| torch.range(0, | |||
| logit.size(0) - 1), self.num_experts).long() | |||
| logit_target = logit[indexes, target] | |||
| # cos(theta) = w * x / ||w||*||x|| | |||
| w_target_norm = w.weight[:, target].norm(p=2, dim=0) | |||
| x_norm = x.norm(p=2, dim=1) | |||
| cos_theta_target = logit_target / (w_target_norm * x_norm + 1e-10) | |||
| # equation 7 | |||
| cos_m_theta_target = self.calculate_cos_m_theta( | |||
| cos_theta_target, device) | |||
| # find k in equation 6 | |||
| k = self.find_k(cos_theta_target) | |||
| # f_y_i | |||
| logit_target_updated = w_target_norm * x_norm * (( | |||
| (-1)**k * cos_m_theta_target) - 2 * k) | |||
| logit_target_updated_beta = (logit_target_updated + beta | |||
| * logit[indexes, target]) / (1 + beta) | |||
| logit[indexes, target] = logit_target_updated_beta | |||
| self.beta *= self.scale | |||
| return logit | |||
| else: | |||
| return self.weight(input) | |||
| def LayerNorm(normalized_shape, | |||
| eps=1e-5, | |||
| elementwise_affine=True, | |||
| export=False): | |||
| if torch.jit.is_scripting() or torch.jit.is_tracing(): | |||
| export = True | |||
| if not export and torch.cuda.is_available() and has_fused_layernorm: | |||
| return FusedLayerNorm(normalized_shape, eps, elementwise_affine) | |||
| return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine) | |||
| class ExpertResidualLayer(torch.nn.Module): | |||
| def __init__(self, embed_dim): | |||
| super().__init__() | |||
| self.norm = LayerNorm(embed_dim, export=False) | |||
| self.ff1 = torch.nn.Linear(embed_dim, embed_dim * 4) | |||
| self.ff2 = torch.nn.Linear(embed_dim * 4, embed_dim) | |||
| self.ff2.weight.data.zero_() | |||
| def forward(self, xs): | |||
| return xs + self.ff2(torch.nn.functional.relu(self.ff1(self.norm(xs)))) | |||
| @@ -0,0 +1,125 @@ | |||
| ''' | |||
| Copyright 2020 The Microsoft DeepSpeed Team | |||
| ''' | |||
| from typing import Dict, List, Tuple | |||
| import torch | |||
| from .layer import MoE | |||
| def has_moe_layers(m): | |||
| has_moe = False | |||
| num_experts = 0 | |||
| for _, module in m.named_modules(): | |||
| if isinstance(module, MoE): | |||
| has_moe = True | |||
| num_experts = module.num_experts | |||
| break | |||
| return has_moe, num_experts | |||
| def is_moe_param(param: torch.Tensor) -> bool: | |||
| if hasattr(param, 'allreduce') and not param.allreduce: | |||
| return True | |||
| return False | |||
| def split_params_into_shared_and_expert_params( | |||
| params: List[torch.nn.Parameter] | |||
| ) -> Tuple[torch.nn.Parameter, torch.nn.Parameter]: | |||
| shared_params, expert_params = [], [] | |||
| for p in params: | |||
| if is_moe_param(p): | |||
| expert_params.append(p) | |||
| else: | |||
| shared_params.append(p) | |||
| return shared_params, expert_params | |||
| def split_params_grads_into_shared_and_expert_params( | |||
| group: List[torch.nn.Parameter] | |||
| ) -> Tuple[torch.nn.Parameter, torch.nn.Parameter]: | |||
| """Split grad of parameters into grads of non-expert params | |||
| and grads of expert params. This is useful while computing | |||
| grad-norms for clipping and overflow detection | |||
| group (List[torch.nn.Parameter]): | |||
| Args: | |||
| The group of parameters to split | |||
| Returns: | |||
| Tuple[List[torch.nn.Parameter], List[torch.nn.Parameter]]: | |||
| list of gradients for non MoE params, list of gradients of MoE params | |||
| """ | |||
| expert_grads = [] | |||
| shared_grads = [] | |||
| for p in group: | |||
| if p.grad is not None: | |||
| if is_moe_param(p): | |||
| expert_grads.append(p.grad.to(p.dtype)) | |||
| else: | |||
| shared_grads.append(p.grad.to(p.dtype)) | |||
| return shared_grads, expert_grads | |||
| def split_params_into_different_moe_groups_for_optimizer( | |||
| param_groups: Tuple[Dict]) -> Tuple[Dict]: | |||
| """Split parameters into different MoE groups for optimizer | |||
| Args: | |||
| param_groups (Tuple[Dict]): | |||
| The list of parameter groups to split | |||
| Returns: | |||
| Tuple[Dict]: | |||
| list of MoE/non-MoE groups for optimizer | |||
| """ | |||
| if isinstance(param_groups, tuple): | |||
| param_groups = list(param_groups) # Tuple cannot be modified | |||
| elif isinstance(param_groups, dict): | |||
| param_groups = [param_groups] | |||
| elif not isinstance(param_groups, list): | |||
| raise ValueError(f'Unknown param group type of {type(param_groups)}') | |||
| # gather all data parallel group names | |||
| data_parallel_group_names = set() | |||
| for param_group in param_groups: | |||
| for param in param_group['params']: | |||
| if is_moe_param(param): | |||
| data_parallel_group_names.add(param.group_name) | |||
| data_parallel_group_names = list(data_parallel_group_names) | |||
| group_moe = {} | |||
| # Create the param MoE groups, leave param assign to next step | |||
| for param_group in param_groups: | |||
| group_moe[param_group['name']] = {} | |||
| for key in data_parallel_group_names: | |||
| group_moe[param_group['name']][key] = {} | |||
| group_moe[param_group['name']][key]['name'] = key | |||
| group_moe[param_group['name']][key]['moe'] = True | |||
| for ori_key in param_group.keys(): | |||
| if ori_key != 'name': | |||
| if ori_key == 'params': | |||
| group_moe[param_group['name']][key][ori_key] = [] | |||
| else: | |||
| group_moe[param_group['name']][key][ | |||
| ori_key] = param_group[ori_key] | |||
| # Assign param | |||
| for param_group in param_groups: | |||
| new_params = [] | |||
| for param in param_group['params']: | |||
| if is_moe_param(param): | |||
| group_moe[param_group['name']][ | |||
| param.group_name]['params'].append(param) | |||
| # param_group['params'].remove(param) | |||
| else: | |||
| new_params.append(param) | |||
| param_group['params'] = new_params | |||
| # Flatten the moe groups | |||
| for k, v in group_moe.items(): | |||
| for k1, v1 in v.items(): | |||
| param_groups.append(v1) | |||
| return tuple(param_groups) | |||
| @@ -0,0 +1,62 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from typing import Dict | |||
| from modelscope.metainfo import Models | |||
| from modelscope.models.base import Tensor, TorchModel | |||
| from modelscope.models.builder import MODELS | |||
| from modelscope.utils.constant import Tasks | |||
| __all__ = ['GPTMoEForTextGeneration'] | |||
| @MODELS.register_module(Tasks.text_generation, module_name=Models.gpt_moe) | |||
| class GPTMoEForTextGeneration(TorchModel): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| """initialize the text generation model from the `model_dir` path. | |||
| Args: | |||
| model_dir (str): the model path. | |||
| """ | |||
| super().__init__(model_dir, *args, **kwargs) | |||
| from modelscope.models.nlp.gpt_moe import GPTMoEModel | |||
| from transformers import BertTokenizer | |||
| print('****') | |||
| print(model_dir) | |||
| self.model = GPTMoEModel.from_pretrained(model_dir) | |||
| self.tokenizer = BertTokenizer.from_pretrained(model_dir) | |||
| def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | |||
| """return the result by the model | |||
| Args: | |||
| input (Dict[str, Tensor]): the preprocessed data | |||
| Returns: | |||
| Dict[str, Tensor]: results | |||
| Example: | |||
| { | |||
| 'logits': Tensor([[0.54, 0.32...])]), # logits | |||
| } | |||
| """ | |||
| return self.model(**input) | |||
| def generate(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | |||
| assert 'input_ids' in input, "generate function must accept 'input_ids' key" | |||
| input_ids = input['input_ids'] | |||
| if 'attention_mask' in input: | |||
| attention_mask = input['attention_mask'] | |||
| input_ids = input_ids[0][attention_mask[0].nonzero()] \ | |||
| .squeeze().unsqueeze(0) | |||
| # remove sep token at the end of tokenizer output | |||
| input_ids = input_ids[:, :-1] | |||
| gen_params = dict() | |||
| gen_params['inputs'] = input_ids | |||
| gen_params['do_sample'] = input.pop('do_sample', True) | |||
| gen_params['max_length'] = input.pop('max_length', 128) | |||
| gen_params['top_k'] = input.pop('top_k', 10) | |||
| gen_params['top_p'] = input.pop('top_p', None) | |||
| sample_output = self.model.generate(**gen_params) | |||
| return {'sequences': sample_output[0]} | |||
| @@ -0,0 +1,67 @@ | |||
| # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| from tokenizers import Tokenizer | |||
| class JiebaBPETokenizer: | |||
| """SentencePiece BPE tokenizer with Jieba integration""" | |||
| def __init__(self, tokenizer_json_file): | |||
| self.name = 'Jieba BPE Tokenizer' | |||
| self.tokenizer = Tokenizer.from_file(tokenizer_json_file) | |||
| self.eod_id = self.tokenizer.token_to_id('<|endoftext|>') | |||
| try: | |||
| import jieba | |||
| except ImportError: | |||
| raise ImportError( | |||
| 'You need to install rjieba to use JiebaTokenizer. ' | |||
| 'See https://pypi.org/project/rjieba/ for installation.') | |||
| self.jieba = jieba | |||
| self.new_line = self.vocab['\n'] | |||
| self.sep_token = self.vocab['<sep>'] | |||
| @property | |||
| def vocab_size(self): | |||
| return self.tokenizer.get_vocab_size(with_added_tokens=True) | |||
| @property | |||
| def vocab(self): | |||
| return self.tokenizer.get_vocab(with_added_tokens=True) | |||
| @property | |||
| def inv_vocab(self): | |||
| vocab = self.vocab | |||
| inv_vocab = dict() | |||
| for key, val in vocab.items(): | |||
| inv_vocab[val] = key | |||
| return inv_vocab | |||
| def tokenize(self, text, is_code=False): | |||
| if not is_code: | |||
| seg_list = [x for x in self.jieba.cut(text)] | |||
| return self.tokenizer.encode( | |||
| seg_list, is_pretokenized=True, add_special_tokens=True).ids | |||
| else: | |||
| return self.tokenizer.encode( | |||
| text, is_pretokenized=False, add_special_tokens=True).ids | |||
| def detokenize(self, token_ids): | |||
| text = self.tokenizer.decode(token_ids, skip_special_tokens=False) | |||
| return text | |||
| @property | |||
| def eod(self): | |||
| return self.eod_id | |||
| @@ -0,0 +1,54 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from typing import Any, Dict | |||
| import torch | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.models.nlp.gpt_moe.distributed_gpt_moe import DistributedGPTMoE | |||
| from modelscope.pipelines.base import DistributedPipeline | |||
| from modelscope.pipelines.builder import PIPELINES | |||
| from modelscope.preprocessors import TextGenerationJiebaPreprocessor | |||
| from modelscope.utils.constant import Tasks | |||
| @PIPELINES.register_module( | |||
| Tasks.text_generation, module_name=Pipelines.gpt_moe_generation) | |||
| class DistributedGPTMoEPipeline(DistributedPipeline): | |||
| """This class is used to instantiate the gpt-moe model. | |||
| """ | |||
| model = None | |||
| def __init__(self, model, preprocessor=None, **kwargs): | |||
| if preprocessor is None: | |||
| preprocessor = TextGenerationJiebaPreprocessor(model) | |||
| super().__init__(model, preprocessor=preprocessor, **kwargs) | |||
| assert hasattr(preprocessor, 'tokenizer') | |||
| @classmethod | |||
| def _instantiate_one(cls, rank, model_dir, **kwargs): | |||
| cls.model = DistributedGPTMoE(model_dir, rank, **kwargs) | |||
| cls.model.eval() | |||
| @classmethod | |||
| def _forward_one(cls, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
| tokens = inputs['inputs']['input_ids'].cuda( | |||
| torch.cuda.current_device()) | |||
| return cls.model.generate(tokens) | |||
| def postprocess(self, inputs: Dict[str, Any], | |||
| **postprocess_params) -> Dict[str, str]: | |||
| """process the prediction results | |||
| Args: | |||
| inputs (Dict[str, Any]): _description_ | |||
| Returns: | |||
| Dict[str, str]: the prediction results | |||
| """ | |||
| from modelscope.outputs import OutputKeys | |||
| return { | |||
| OutputKeys.TEXT: | |||
| self.preprocessor.tokenizer.detokenize(inputs[0].tolist()) | |||
| } | |||
| @@ -0,0 +1,24 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import unittest | |||
| from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.test_utils import test_level | |||
| class TextGPTMoEGenerationTest(unittest.TestCase): | |||
| def setUp(self) -> None: | |||
| self.model_id_1_3B_MoE32 = 'PAI/nlp_gpt3_text-generation_1.3B_MoE-32' | |||
| self.model_dir_1_3B_MoE32 = snapshot_download(self.model_id_1_3B_MoE32) | |||
| self.input = '好的' | |||
| @unittest.skip('distributed gpt-moe 1.3B_MoE-32, skipped') | |||
| def test_gpt_moe_1_3B_MoE32(self): | |||
| pipe = pipeline(Tasks.text_generation, model=self.model_id_1_3B_MoE32) | |||
| print(pipe(self.input)) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||