@@ -225,8 +225,9 @@ bool CompFunc(const AnfNodePtr &node1, const AnfNodePtr &node2) {
  if (rank_tag2 == nullptr) {
    rank_tag2 = prim2->GetAttr(DEST_RANK);
  }
- MS_EXCEPTION_IF_NULL(rank_tag1);
- MS_EXCEPTION_IF_NULL(rank_tag2);
+ if (!rank_tag1 || !rank_tag2) {
+   return false;
+ }
  auto rank1_value = GetValue<int64_t>(rank_tag1);
  auto rank2_value = GetValue<int64_t>(rank_tag2);
  if (rank1_value == rank2_value) {
@@ -136,7 +136,7 @@ const char PYTHON_EXTERN_MINDSPORE_FLAG[] = "_mindspore_flags";
// define the parse constant
const int64_t MAX_COMPARISON_OPS_SUPPORTED = 1;
const char CUSTOM_BPROP_NAME[] = "bprop";
-const char STAGE_NAME[] = "pipeline_stage";
+const char STAGE_NAME[] = "_pipeline_stage";
// define the Namespace name
const char RESOLVE_NAMESPACE_NAME_AST[] = "Ast";  // for ast type namespace
@@ -154,6 +154,7 @@ class Parameter(Tensor_):
        self._cast_type = None
        self._unique = False
        self.is_in_parallel = _is_in_parallel_mode()
        self._pipeline_stage_list = []
        if isinstance(default_input, (Tensor_, Tensor)):
            Tensor_.__init__(self, default_input.dtype, default_input.shape)
        elif isinstance(default_input, int):
@@ -452,6 +453,9 @@ class Parameter(Tensor_):
        new_param.param_info = self.param_info
        return new_param

    def add_pipeline_stage(self, stage):
        self._pipeline_stage_list.append(stage)

    def set_data(self, data, slice_shape=False):
        """
        Set Parameter's data.
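A parameter shared across stages must be tagged with every stage that touches it, mirroring the embedding table that is wired into both the backbone and the head later in this change. A minimal sketch (shapes and names hypothetical):

    shared_table = Parameter(initializer(Normal(0.02), [vocab_size, embedding_size]),
                             name="embedding_table")
    shared_table.add_pipeline_stage(0)  # consumed by the word embedding in stage 0
    shared_table.add_pipeline_stage(3)  # reused by the head in the last stage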
@@ -230,6 +230,18 @@ class Cell(Cell_):
            raise TypeError("'parallel_parameter_name_list' must be list type.")
        self._parallel_parameter_name_list = value

    @property
    def pipeline_stage(self):
        return self._pipeline_stage

    @pipeline_stage.setter
    def pipeline_stage(self, value):
        if not isinstance(value, int):
            raise TypeError("'pipeline_stage' must be int type.")
        self._pipeline_stage = value
        for item in self.trainable_params():
            item.add_pipeline_stage(value)

    @property
    def parallel_parameter_merge_net_dict(self):
        return self._parallel_parameter_merge_net_dict
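Setting the property on a cell fans the stage out to every trainable parameter the cell owns, so a whole sub-network can be placed with one assignment. A minimal sketch (the cell is hypothetical):

    block = nn.Dense(1024, 1024)
    block.pipeline_stage = 1  # also calls add_pipeline_stage(1) on weight and bias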
@@ -1297,6 +1309,41 @@ class Cell(Cell_):
            for cell in self.cells():
                cell.recompute(mode, True)

    def infer_param_pipeline_stage(self):
        """
        Infer the pipeline stages of all parameters in the cell.

        Notes:
            - If a parameter does not belong to any cell which has been given a pipeline_stage,
              the parameter should use add_pipeline_stage to add its pipeline stage information.
            - If a parameter P is used by operators in two different stages "stageA" and "stageB",
              the parameter P should use P.add_pipeline_stage(stageA) and P.add_pipeline_stage(stageB)
              to add its stage information before infer_param_pipeline_stage is called.

        Returns:
            The parameters that belong to the current stage in pipeline parallelism.

        Raises:
            RuntimeError: If a parameter does not belong to any stage.
        """
        from mindspore.communication import get_group_size, get_rank
        stage_num = context.get_auto_parallel_context("pipeline_stages")
        device_num = get_group_size()
        rank_id = get_rank()
        per_stage_devices = device_num // stage_num
        current_stage = rank_id // per_stage_devices
        params = []
        for param in self.trainable_params():
            if not param._pipeline_stage_list:
                raise RuntimeError("The parameter {} does not belong to any stage; "
                                   "please check whether the cell containing this parameter "
                                   "has had its pipeline_stage set. Otherwise, use "
                                   "add_pipeline_stage on the parameter to add its stage "
                                   "information.".format(param.name))
            if current_stage in param._pipeline_stage_list:
                params.append(param)
        return params
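The stage a rank belongs to follows directly from the rank layout: devices are split into stage_num contiguous groups. A worked example under an assumed 64-device, 8-stage job:

    device_num, stage_num = 64, 8
    per_stage_devices = device_num // stage_num  # 8 devices per stage
    current_stage = 19 // per_stage_devices      # rank 19 lands in stage 2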
class GraphKernel(Cell):
    """
@@ -302,6 +302,32 @@ class EmbeddingLookup(nn.Cell):
        return output, self.embedding_table

class EmbeddingLookupPipeline(nn.Cell):
    """
    The embedding lookup table for the vocabulary.

    Args:
        config(PanguAlphaConfig): the config of the network

    Inputs:
        input_ids: the tokenized inputs with datatype int32
        table: the embedding table for the vocabulary

    Returns:
        output: Tensor, the embedding vector for the input with shape
            (batch_size, seq_length, embedding_size)
    """
    def __init__(self, config):
        super(EmbeddingLookupPipeline, self).__init__()
        self.vocab_size = config.vocab_size
        self.embedding_size = config.embedding_size
        if config.word_emb_dp:
            self.gather = P.GatherV2().shard(((1, 1), (config.dp, 1)))
        else:
            self.gather = P.GatherV2().shard(((config.mp, 1), (1, 1)))
        self.gather.add_prim_attr("parameter_start", 0)
        self.shape = (-1, config.seq_length, config.embedding_size)

    def construct(self, input_ids, table):
        output = self.gather(table, input_ids, 0)
        return output
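The shard tuples follow GatherV2's input order, (table, input_ids). A reading of the two strategies, assuming dp data-parallel and mp model-parallel devices:

    # word_emb_dp:  table (1, 1) replicated; input_ids (dp, 1) split on the batch dim
    # otherwise:    table (mp, 1) split on the vocab dim; input_ids (1, 1) replicated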
class Attention(nn.Cell):
    """
    Self-Attention module for each layer
@@ -593,6 +619,44 @@ class Block(nn.Cell):
        output = self.add(x, mlp_logit)
        return output, layer_present

class PanguAlpha_EmbeddingPipeLine(nn.Cell):
    """
    PanguAlpha_EmbeddingPipeLine
    """
    def __init__(self, config):
        super(PanguAlpha_EmbeddingPipeLine, self).__init__()
        self.word_embedding = EmbeddingLookupPipeline(config)
        self.position_embedding = nn.Embedding(config.seq_length,
                                               config.embedding_size,
                                               embedding_table=Normal(0.02))
        self.position_embedding.gather.shard(((1, 1), (config.dp,)))
        self.position_embedding.expand.shard(((config.dp, 1),))
        self.add = P.TensorAdd().shard(((config.dp, 1, 1), (config.dp, 1, 1)))
        self.dropout = Dropout(1 - config.dropout_rate)
        self.dropout.dropout_gen_mask.shard(((config.dp, 1, 1),))
        self.dropout.dropout_do_mask.shard(((config.dp, 1, 1),))

    def construct(self, input_ids, table, input_position):
        input_embedding = self.word_embedding(input_ids, table)
        position_embedding = self.position_embedding(input_position)
        hidden_states = self.add(input_embedding, position_embedding)
        hidden_states = self.dropout(hidden_states)
        hidden_states = P.Cast()(hidden_states, mstype.float16)
        return hidden_states

class PanguAlpha_Mask(nn.Cell):
    """
    PanguAlpha_Mask
    """
    def __init__(self, config):
        super(PanguAlpha_Mask, self).__init__()
        self.get_attention_mask = AttentionMask(config)
        self.dtype = config.compute_dtype
        self.expand_dims = P.ExpandDims().shard(((config.dp, 1, 1),))

    def construct(self, input_mask, attention_mask):
        attention_mask = self.expand_dims(attention_mask, 1)
        return attention_mask

class QueryLayerAttention(Attention):
    r"""
@@ -828,6 +892,74 @@ class PanguAlpha_Model(nn.Cell):
        present_layer = present_layer + (present,)
        return output_state, present_layer, embedding_table

class PanguAlpha_ModelPipeline(nn.Cell):
    """
    The backbone of the PanguAlpha network

    Args:
        config(PanguAlphaConfig): the config of the network

    Inputs:
        input_ids: the tokenized inputs with datatype int32
        input_mask: the mask indicating whether each position is a valid input
        table: the embedding table for the vocabulary
        input_position: the position index of each token
        attention_mask: the attention mask used by the self-attention blocks
        layer_past: the previous feature map

    Returns:
        output_state: Tensor, the output logit of the backbone
        present_layer: Tensor, the current feature map
    """
    def __init__(self, config):
        super(PanguAlpha_ModelPipeline, self).__init__()
        self.pangu_alpha_embedding = PanguAlpha_EmbeddingPipeLine(config).set_comm_fusion(1)
        self.pangu_alpha_embedding.pipeline_stage = 0
        self.pangu_alpha_mask = PanguAlpha_Mask(config)
        self.blocks = nn.CellList()
        self.top_query_embedding = nn.Embedding(config.seq_length, config.embedding_size,
                                                embedding_table=TruncatedNormal(0.02))
        self.top_query_embedding.gather.shard(((1, 1), (config.dp,)))
        self.top_query_embedding.expand.shard(((config.dp, 1),))
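        # Stage placement sketch: i * config.stage_num // config.num_layers spreads
        # the layers evenly; e.g. with num_layers=32 and stage_num=4, layers 0-7 go
        # to stage 0, 8-15 to stage 1, 16-23 to stage 2, and 24-31 to stage 3.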
        for i in range(config.num_layers):
            if i == config.num_layers - 1:
                self.top_query_embedding.set_comm_fusion(2)
                self.top_query_embedding.pipeline_stage = i * config.stage_num // config.num_layers
                per_block = QueryLayer(config).set_comm_fusion(2)
            else:
                per_block = Block(config, i + 1).set_comm_fusion(2)
            per_block.pipeline_stage = i * config.stage_num // config.num_layers
            per_block.recompute()
            self.blocks.append(per_block)

        if config.self_layernorm:
            self.layernorm = LayerNorm((config.embedding_size,), config.dp).to_float(mstype.float32)
        else:
            self.layernorm = nn.LayerNorm(
                (config.embedding_size,)).to_float(mstype.float32)
            self.layernorm.layer_norm.shard(((config.dp, 1, 1), (1,), (1,)))
        self.layernorm.set_comm_fusion(2)
        self.layernorm.pipeline_stage = config.stage_num - 1
        self.use_past = config.use_past
        self.past = tuple([None] * config.num_layers)
        self.dtype = config.compute_dtype
        self.num_layers = config.num_layers

    def construct(self, input_ids, input_mask, table, input_position, attention_mask, layer_past=None):
        """PanguAlpha model"""
        if not self.use_past:
            layer_past = self.past
        hidden_states = self.pangu_alpha_embedding(input_ids, table, input_position)
        attention_mask = self.pangu_alpha_mask(input_mask, attention_mask)
        present_layer = ()
        for i in range(self.num_layers - 1):
            hidden_states, present = self.blocks[i](hidden_states,
                                                    attention_mask, layer_past)
            present_layer = present_layer + (present,)
        top_query_hidden_states = self.top_query_embedding(input_position)
        hidden_states, present = self.blocks[self.num_layers - 1](hidden_states, top_query_hidden_states,
                                                                  attention_mask, layer_past)
        present_layer = present_layer + (present,)
        output_state = self.layernorm(hidden_states)
        output_state = F.cast(output_state, self.dtype)
        return output_state, present_layer

class PanguAlpha_Head(nn.Cell):
    """
@@ -883,6 +1015,35 @@ class PanguAlpha(nn.Cell):
        logits = self.head(output_states, embedding_table)
        return logits

class PanguAlphaPipeline(nn.Cell):
    """
    The PanguAlpha network, consisting of two parts: the backbone and the head

    Args:
        config(PanguAlphaConfig): the config of the network

    Inputs:
        input_ids: the tokenized inputs
        input_mask: the mask indicating whether each position is a valid input
        input_position: the position index of each token
        attention_mask: the attention mask used by the self-attention blocks
        past: the previous feature map

    Returns:
        logits: Tensor, the logits of the corresponding inputs with shape
            (batch_size, seq_length, vocab_size)
    """
    def __init__(self, config):
        super(PanguAlphaPipeline, self).__init__()
        self.backbone = PanguAlpha_ModelPipeline(config)
        self.head = PanguAlpha_Head(config)
        self.head.pipeline_stage = config.stage_num - 1
        self.vocab_size = config.vocab_size
        self.embedding_size = config.embedding_size
        self.embedding_table = Parameter(initializer(Normal(0.02), [self.vocab_size, self.embedding_size]),
                                         name="embedding_table")
        self.embedding_table.add_pipeline_stage(self.backbone.blocks[0].pipeline_stage)
        self.embedding_table.add_pipeline_stage(self.head.pipeline_stage)
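        # The shared table is tagged with both the first block's stage and the
        # head's stage, so infer_param_pipeline_stage returns it on ranks of
        # either stage (see Parameter.add_pipeline_stage above).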
    def construct(self, input_ids, input_mask, input_position, attention_mask, past=None):
        output_states, _ = self.backbone(input_ids, input_mask, self.embedding_table,
                                         input_position, attention_mask, past)
        logits = self.head(output_states, self.embedding_table)
        return logits

class CrossEntropyLoss(nn.Cell):
    """
@@ -1010,6 +1171,38 @@ class PanguAlphaWithLoss(nn.Cell):
        output = self.loss(logits, labels, input_mask)
        return output

class PanguAlphaWithLossPipeline(nn.Cell):
    """
    PanguAlpha training loss

    Args:
        config(PanguAlphaConfig): the config of the network
        network: backbone network of PanguAlpha
        loss: loss function, e.g., cross entropy
        eos_token: the end-of-sentence token

    Inputs:
        input_ids: the tokenized inputs
        input_position: the position index of each token
        attention_mask: the attention mask used by the self-attention blocks

    Returns:
        output: Tensor, the loss of the network
    """
    def __init__(self, config, network, loss, eos_token=6):
        super(PanguAlphaWithLossPipeline, self).__init__(auto_prefix=False)
        self.network = network
        self.loss = loss
        self.eos_token = eos_token
        self.slice = P.StridedSlice().shard(((config.dp, 1),))
        self.not_equal = P.NotEqual().shard(((config.dp, 1), ()))
        self.batch_size = config.batch_size
        self.len = config.seq_length
        self.micro_batch_step = config.micro_size

    def construct(self, input_ids, input_position, attention_mask):
        tokens = self.slice(input_ids, (0, 0), (self.batch_size // self.micro_batch_step, -1), (1, 1))
        input_mask = F.cast(self.not_equal(tokens, self.eos_token), mstype.float32)
        logits = self.network(tokens, input_mask, input_position, attention_mask)
        labels = self.slice(input_ids, (0, 1), (self.batch_size // self.micro_batch_step,
                                                self.len + 1), (1, 1))
        output = self.loss(logits, labels, input_mask)
        return output
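The slicing implements next-token prediction per micro-batch: tokens are the first seq_length columns, labels the same window shifted right by one. A worked example under an assumed config where the dataset yields seq_length + 1 tokens per row:

    batch_size, micro_size, seq_length = 2048, 32, 1024
    rows = batch_size // micro_size       # 64 sequences handled per micro-batch
    # tokens = input_ids[:rows, 0:seq_length]
    # labels = input_ids[:rows, 1:seq_length + 1]  (targets shifted by one position)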
class EvalNet(nn.Cell):
    """
@@ -89,20 +89,27 @@ def set_parse(args_opt):
        args_opt.embedding_size = 16384
        args_opt.num_layers = 64
        args_opt.num_heads = 128
        args_opt.per_batch_size = 1
        args_opt.word_emb_dp = 0
        if args_opt.run_type == "train":
            args_opt.start_lr = 6e-5
            args_opt.end_lr = 6e-6
            args_opt.optimizer_shard = 0
            args_opt.stage_num = 16
            args_opt.micro_size = 32
            args_opt.op_level_model_parallel_num = 16
            if args_opt.optimizer_shard == 1:
                args_opt.op_level_model_parallel_num = 8
        elif args_opt.run_type == "predict":
            args_opt.stage_num = 4
            args_opt.micro_size = 1
            args_opt.op_level_model_parallel_num = 16
            if args_opt.optimizer_shard == 1:
                args_opt.op_level_model_parallel_num = 8
    elif args_opt.mode == "13B":
        args_opt.embedding_size = 5120
        args_opt.num_layers = 40
        args_opt.num_heads = 40
        args_opt.word_emb_dp = 1
        args_opt.op_level_model_parallel_num = 8
        if args_opt.run_type == "train":
            args_opt.start_lr = 5e-5
@@ -20,8 +20,11 @@ from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from mindspore.ops.operations.comm_ops import _VirtualDataset
from mindspore.nn.wrap.loss_scale import TrainOneStepWithLossScaleCell
from mindspore import context, Parameter
from mindspore.context import ParallelMode
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.communication.management import get_group_size
from src.utils import ClipByGlobalNorm

GRADIENT_CLIP_TYPE = 1
@@ -65,16 +68,14 @@ reciprocal = P.Reciprocal()
def tensor_grad_scale(scale, grad):
    return grad * reciprocal(scale)

-class VirtualDatasetOneInputCell(nn.Cell):
-    def __init__(self, backbone):
-        super(VirtualDatasetOneInputCell, self).__init__(auto_prefix=False)
-        self._backbone = backbone
-        self._virtual_dataset = _VirtualDataset()
-
-    def construct(self, *data):
-        data_ = self._virtual_dataset(*data)
-        return self._backbone(*data_)

+@grad_scale.register("Tensor", "Tensor", "Tensor")
+def tensor_grad_scale_pipeline(scale, grad, accu_grad):
+    accu_grad = F.depend(accu_grad, grad)
+    new_grad = accu_grad * reciprocal(scale)
+    accu_grad = F.depend(accu_grad, new_grad)
+    zeros = F.tensor_mul(accu_grad, 0.0)
+    _ = F.assign(accu_grad, zeros)
+    return new_grad
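The F.depend calls only pin the execution order; the dataflow itself unscales the accumulated gradient and clears the buffer for the next accumulation window. A plain-Python sketch of the effect (illustrative, NumPy arrays standing in for graph tensors):

    import numpy as np

    def scale_and_reset(scale, accu_grad):
        new_grad = accu_grad * (1.0 / scale)  # grad * reciprocal(scale)
        accu_grad[...] = 0.0                  # F.assign(accu_grad, zeros)
        return new_grad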
class PanguAlphaTrainOneStepWithLossScaleCell(TrainOneStepWithLossScaleCell):
    """
@@ -102,7 +103,7 @@ class PanguAlphaTrainOneStepWithLossScaleCell(TrainOneStepWithLossScaleCell):
        self.optimizer = optimizer
        self.default_lr = Tensor([0.0], dtype=mstype.float32)
        self.enable_global_norm = enable_global_norm
-       self.clip = ClipByGlobalNorm(self.weights)
+       self.clip = ClipByGlobalNorm(self.weights, config)
        self.cast = P.Cast()

    def construct(self, input_ids, input_position=None, attention_mask=None, layer_past=None, sens=None):
@@ -142,3 +143,111 @@ class PanguAlphaTrainOneStepWithLossScaleCell(TrainOneStepWithLossScaleCell):
        else:
            succ = self.optimizer(grads)
        return F.depend(loss, succ), cond, scaling_sens

class PanguAlphaTrainPipelineWithLossScaleCell(nn.Cell):
    """
    Encapsulation class for PanguAlpha network training.

    Appends an optimizer to the training network; after that, the construct
    function can be called to create the backward graph.

    Args:
        network (Cell): The training network. Note that the loss function should have been added.
        optimizer (Optimizer): Optimizer for updating the weights.
        scale_update_cell (Cell): Cell to do the loss scale. Default: None.
    """
    def __init__(self, network, optimizer, config, scale_update_cell=None, enable_global_norm=True):
        super(PanguAlphaTrainPipelineWithLossScaleCell, self).__init__(auto_prefix=False)
        self.config = config
        self.network = network
        self.network.add_flags(defer_inline=True)
        self.weights = optimizer.parameters
        self.accu_grads = self.weights.clone(prefix="accu_grads", init="zeros")
        self.optimizer = optimizer
        self.enable_global_norm = enable_global_norm
        self.grad = C.GradOperation(get_by_list=True,
                                    sens_param=True)
        self.reducer_flag = False
        self.allreduce = P.AllReduce()
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        self.grad_reducer = F.identity
        self.degree = 1
        if self.reducer_flag:
            self.degree = get_group_size()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
        self.cast = P.Cast()
        self.alloc_status = P.NPUAllocFloatStatus().add_prim_attr("_side_effect_flag", False)
        self.get_status = P.NPUGetFloatStatus().add_prim_attr("_side_effect_flag", False)
        self.clear_before_grad = P.NPUClearFloatStatus().add_prim_attr("_side_effect_flag", False)
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.base = Tensor(1, mstype.float32)
        self.less_equal = P.LessEqual()
        self.hyper_map = C.HyperMap()
        self.loss_scale = None
        self.reshape = P.Reshape()
        self.loss_scaling_manager = scale_update_cell
        if scale_update_cell:
            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
                                        name="loss_scale")
        self.clip = ClipByGlobalNorm(self.weights, self.config)
        self.micro_size = config.micro_size

    @C.add_flags(has_effect=True)
    def construct(self,
                  input_ids,
                  input_position,
                  attention_mask,
                  past=None,
                  sens=None):
        """Defines the computation performed."""
        weights = self.weights
        loss = self.network(input_ids, input_position, attention_mask)
        if sens is None:
            scaling_sens = self.loss_scale
            scaling_sens = self.reshape(scaling_sens, (1,))
        else:
            scaling_sens = sens
        # Allocating and clearing the float status must happen right before the grad operation.
        init = self.alloc_status()
        status_clear = self.clear_before_grad(init)
        grads = self.grad(self.network, weights)(input_ids,
                                                 input_position,
                                                 attention_mask,
                                                 self.cast(scaling_sens / self.micro_size,
                                                           mstype.float32))
        init = F.depend(init, grads)
        get_status = self.get_status(init)
        init = F.depend(init, get_status)
        flag_sum = self.reduce_sum(init, (0,))
        loss = F.depend(loss, status_clear)
        # Apply the grad reducer on the accumulated grads.
        accu_grads = self.grad_reducer(self.accu_grads)
        grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads, accu_grads)
        if self.enable_global_norm:
            grads, _ = self.clip(grads)
        else:
            grads = self.hyper_map(
                F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE),
                grads)
        if self.is_distributed:
            # Sum the overflow flag over devices.
            flag_reduce = self.allreduce(flag_sum)
            cond = self.less_equal(self.base, flag_reduce)
        else:
            cond = self.less_equal(self.base, flag_sum)
        overflow = cond
        if sens is None:
            overflow = self.loss_scaling_manager(self.loss_scale, cond)
        if overflow:
            succ = False
        else:
            succ = self.optimizer(grads)
        ret = (loss, overflow, scaling_sens)
        return F.depend(ret, succ)
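Two numbers in construct are worth pinning down: each micro-batch backward pass is seeded with scaling_sens / micro_size, and the accumulated grads are later divided by scaling_sens * degree inside grad_scale, so the loss scaling and the accumulation cancel out. A worked example under assumed settings:

    loss_scale, micro_size, degree = 2.0 ** 32, 32, 64
    sens_per_micro = loss_scale / micro_size  # seed for each micro-batch backward
    # grad_scale then divides the accumulated grad by loss_scale * degree,
    # undoing both the loss scaling and the data-parallel AllReduce sum.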
@@ -26,8 +26,8 @@ from mindspore.ops import functional as F
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR, CosineDecayLR
-from mindspore.parallel._utils import _get_global_rank
-from mindspore.communication.management import get_group_size
+from mindspore.parallel._auto_parallel_context import auto_parallel_context
+from mindspore.communication.management import get_rank, get_group_size, create_group
from mindspore.nn import AdamWeightDecay
from mindspore.common import Parameter, ParameterTuple
from mindspore.common.initializer import initializer
@@ -71,14 +71,12 @@ class FP32StateAdamWeightDecay(AdamWeightDecay):
get_square_sum = C.MultitypeFuncGraph("get_square_sum")

-@get_square_sum.register("Tensor", "Tensor")
+@get_square_sum.register("Tensor", "Number")
def _get_square_sum(grad, value):
-   norm = P.ReduceSum(False)(F.square(grad) / value, ())
+   norm = P.ReduceSum(False)(F.square(grad), ()) / value
    norm = F.expand_dims(F.cast(norm, mstype.float32), 0)
    return norm
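# Note: with value now a scalar Number, sum(g * g) / value == sum(g * g / value),
# so dividing after the reduction is equivalent and touches one scalar instead of
# every element of the gradient tensor.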
apply_global_norm = C.MultitypeFuncGraph("apply_global_norm")
@@ -87,6 +85,41 @@ def _apply_global_norm(clip_norm, global_norm, grad):
    grad = grad * clip_norm / global_norm
    return grad

def _get_model_parallel_group(mp):
    """
    Calculate the communication group of the model-parallel dimension within one pipeline stage
    """
    rank = get_rank()
    stage_nums = auto_parallel_context().get_pipeline_stages()
    device_nums = get_group_size()
    per_stage_device_nums = device_nums // stage_nums
    stage_id = rank // per_stage_device_nums
    local_stage_rank_id = rank % per_stage_device_nums
    index = local_stage_rank_id // mp
    group = range(0, mp)
    rank_str_list = [str(x + index * mp + stage_id * per_stage_device_nums) for x in group]
    rank_list_str = "-".join(rank_str_list)
    rank_list = [x + index * mp + stage_id * per_stage_device_nums for x in group]
    return rank_list, rank_list_str

def _get_pipeline_group():
    """
    Calculate the communication group across all pipeline stages
    """
    rank = get_rank()
    stage_nums = auto_parallel_context().get_pipeline_stages()
    device_nums = get_group_size()
    per_stage_device_nums = device_nums // stage_nums
    local_stage_rank_id = rank % per_stage_device_nums
    group = range(0, stage_nums)
    rank_list = [local_stage_rank_id + x * per_stage_device_nums for x in group]
    rank_str_list = [str(local_stage_rank_id + x * per_stage_device_nums) for x in group]
    rank_list_str = "-".join(rank_str_list)
    return rank_list, rank_list_str
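A worked example of both group computations, assuming 16 devices, 2 pipeline stages, and mp=4, seen from rank 13:

    # per_stage_device_nums = 16 // 2 = 8 -> stage_id = 1, local_stage_rank_id = 5
    # model-parallel group: index = 5 // 4 = 1 -> ranks [12, 13, 14, 15]
    # pipeline group:       [5 + 0 * 8, 5 + 1 * 8] -> ranks [5, 13]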
class GlobalNorm(nn.Cell):
    """
@@ -107,9 +140,9 @@ class GlobalNorm(nn.Cell):
        self.group_size = get_group_size()
        for item in self.allreduce_filter:
            if item:
-               self.values.append(Tensor([1.0], mstype.float32))
+               self.values.append(1.0)
            else:
-               self.values.append(Tensor([self.group_size * 1.0], mstype.float32))
+               self.values.append(self.group_size * 1.0)
        self.values = tuple(self.values)

    def construct(self, grads):
@@ -119,6 +152,37 @@ class GlobalNorm(nn.Cell):
        global_norms = F.sqrt(P.AllReduce()(F.addn(square_sum_dp)))
        return global_norms
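The per-tensor divisor keeps each parameter from being counted more than once after the AllReduce. A reading of the weighting (illustrative):

    # replicated tensor, divisor N: N devices each contribute sum(g*g) / N,
    #   so the AllReduce total is sum(g*g) -- counted exactly once
    # sharded tensor, divisor 1.0: each device contributes only its shard's
    #   partial sum, and the AllReduce stitches the full sum together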
class GlobalNormPipline(nn.Cell):
    """
    Calculate the global norm value of the given tensors
    """
    def __init__(self, params, config):
        super(GlobalNormPipline, self).__init__()
        self.norm = nn.Norm()
        self.hyper_map = C.HyperMap()
        self.allreduce_filter = tuple("projection.bias" not in x.name and "layernorm" not in x.name
                                      and "position_embedding.embedding_table" not in x.name for x in params)
        self.allreduce_group_size = ()
        for item in self.allreduce_filter:
            if item:
                self.allreduce_group_size = self.allreduce_group_size + (1.0,)
            else:
                self.allreduce_group_size = self.allreduce_group_size + (config.mp * 1.0,)
        self.length = len(params)
        group_list, group_name = _get_model_parallel_group(config.mp)
        create_group(group_name, group_list)
        self.allreduce = P.AllReduce(group=group_name)
        pipeline_group_list, pipeline_group_name = _get_pipeline_group()
        create_group(pipeline_group_name, pipeline_group_list)
        self.allreduce2 = P.AllReduce(group=pipeline_group_name)

    def construct(self, grads):
        square_sum = self.hyper_map(get_square_sum, grads, self.allreduce_group_size)
        square_reduce_sum = F.addn(square_sum)
        stage_square_reduce_sum = self.allreduce(square_reduce_sum)
        global_square_reduce_sum = self.allreduce2(stage_square_reduce_sum)
        global_norms = F.sqrt(global_square_reduce_sum)
        return global_norms
class ClipByGlobalNorm(nn.Cell):
    """
@@ -126,10 +190,12 @@ class ClipByGlobalNorm(nn.Cell):
    Clip grads by global norm
    """
-   def __init__(self, params, clip_norm=1.0):
+   def __init__(self, params, config, clip_norm=1.0):
        super(ClipByGlobalNorm, self).__init__()
-       self.global_norm = GlobalNorm(params)
+       if config.stage_num > 1:
+           self.global_norm = GlobalNormPipline(params, config)
+       else:
+           self.global_norm = GlobalNorm(params)
        self.clip_norm = Tensor([clip_norm], mstype.float32)
        self.hyper_map = C.HyperMap()
@@ -140,14 +206,6 @@ class ClipByGlobalNorm(nn.Cell):
        grads = self.hyper_map(F.partial(apply_global_norm, self.clip_norm, global_norm), grads)
        return grads, global_norm_value

-def _get_model_parallel_group(dp, mp):
-    rank = _get_global_rank()
-    group = range(0, mp)
-    index = rank // dp
-    return [x + index * mp for x in group]
class LearningRate(LearningRateSchedule):
    """
    Warmup-decay learning rate for the PanguAlpha network.
@@ -258,6 +316,10 @@ def add_training_params(opt):
                     type=int,
                     default=2000,
                     help="Warmup step, default is 2000.")
    opt.add_argument("--decay_steps",
                     type=int,
                     default=200000,
                     help="Decay step, default is 200000.")
    opt.add_argument("--optimizer",
                     type=str,
                     default="adam",
@@ -298,6 +360,11 @@ def add_training_params(opt):
                     type=int,
                     default=8,
                     help="The model parallel way. default 8")
    opt.add_argument("--word_emb_dp",
                     type=int,
                     default=1,
                     choices=[0, 1],
                     help="Whether to do data parallel in the word embedding. default 1")

def get_args(inference=False):
@@ -29,10 +29,11 @@ from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
import mindspore.common.dtype as mstype
from mindspore.parallel import set_algo_parameters
from mindspore.parallel._cost_model_context import _set_multi_subgraphs
+from mindspore.nn.wrap.cell_wrapper import PipelineCell, _VirtualDatasetCell
from src.dataset import create_dataset
-from src.pangu_alpha import PanguAlpha, PanguAlphaWithLoss, CrossEntropyLoss
-from src.pangu_alpha_wrapcell import PanguAlphaTrainOneStepWithLossScaleCell, VirtualDatasetOneInputCell
+from src.pangu_alpha import PanguAlpha, PanguAlphaWithLoss,\
+    PanguAlphaPipeline, PanguAlphaWithLossPipeline, CrossEntropyLoss
+from src.pangu_alpha_wrapcell import PanguAlphaTrainOneStepWithLossScaleCell, PanguAlphaTrainPipelineWithLossScaleCell
from src.pangu_alpha_config import PANGUALPHAConfig, set_parse
from src.utils import LearningRate, get_args, FP32StateAdamWeightDecay
from src.utils import download_data
@@ -132,14 +133,14 @@ def run_train(args_opt):
        micro_size=args_opt.micro_size,
        eod_reset=bool(args_opt.eod_reset),
        param_init_type=mstype.float32 if args_opt.param_init_type == 'fp32' else mstype.float16,
-       word_emb_dp=True)
+       word_emb_dp=bool(args_opt.word_emb_dp))
    print("===config is: ", config, flush=True)
    # Define network
    pangu_alpha = PanguAlpha(config)
    loss = CrossEntropyLoss(config)
    pangu_alpha_with_loss = PanguAlphaWithLoss(config, pangu_alpha, loss)
-   pangu_alpha_with_loss = VirtualDatasetOneInputCell(pangu_alpha_with_loss)
+   pangu_alpha_with_loss = _VirtualDatasetCell(pangu_alpha_with_loss)
    print("=====args_opt is: ", args_opt, flush=True)
@@ -189,8 +190,111 @@ def run_train(args_opt):
    print("Dataset size: {}, actual_epoch_num: {}".format(ds.get_dataset_size(), actual_epoch_num), flush=True)
    model.train(actual_epoch_num, ds, callbacks=callback, sink_size=callback_size, dataset_sink_mode=True)

def run_train_pipeline(args_opt):
    r"""
    The main training process for the pipeline-parallel mode.
    """
    device_id = int(os.getenv("DEVICE_ID"))
    context.set_context(save_graphs=False,
                        mode=context.GRAPH_MODE,
                        device_target="Ascend",
                        device_id=device_id)
    context.set_context(variable_memory_max_size="31GB")
    if args_opt.distribute == "true":
        D.init()
        device_num = D.get_group_size()
        rank_id = D.get_rank()
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
            gradients_mean=False,
            device_num=device_num,
            full_batch=True,
            loss_repeated_mean=True,
            enable_parallel_optimizer=bool(args_opt.optimizer_shard),
            pipeline_stages=args_opt.stage_num)
        set_algo_parameters(elementwise_op_strategy_follow=True)
        _set_multi_subgraphs()
    else:
        rank_id = int(os.getenv("RANK_ID"))
        device_num = 1
    model_parallel_num = args_opt.op_level_model_parallel_num
    stage_device_num = int(device_num / args_opt.stage_num)
    data_parallel_num = int(stage_device_num / model_parallel_num)
    per_batch_size = args_opt.per_batch_size
    batch_size = per_batch_size * data_parallel_num * args_opt.micro_size
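    # Sizing sketch (illustrative): 64 devices with stage_num=4 give 16 devices
    # per stage; mp=8 then leaves dp=2, and per_batch_size=1 with micro_size=32
    # yields a global batch of 1 * 2 * 32 = 64 sequences.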
    config = PANGUALPHAConfig(
        data_parallel_num=data_parallel_num,
        model_parallel_num=model_parallel_num,
        batch_size=batch_size,
        seq_length=args_opt.seq_length,
        vocab_size=args_opt.vocab_size,
        embedding_size=args_opt.embedding_size,
        num_layers=args_opt.num_layers,
        num_heads=args_opt.num_heads,
        expand_ratio=4,
        post_layernorm_residual=False,
        dropout_rate=0.1,
        compute_dtype=mstype.float16,
        use_past=False,
        self_layernorm=True,
        stage_num=args_opt.stage_num,
        micro_size=args_opt.micro_size,
        word_emb_dp=bool(args_opt.word_emb_dp))
    print("===config is: ", config, flush=True)
    pangu_alpha = PanguAlphaPipeline(config)
    loss = CrossEntropyLoss(config)
    pangu_alpha_with_loss = PipelineCell(PanguAlphaWithLossPipeline(config, pangu_alpha, loss), config.micro_size)
    pangu_alpha_with_loss = _VirtualDatasetCell(pangu_alpha_with_loss)
    print("=====args_opt is: ", args_opt, flush=True)
    lr = LearningRate(learning_rate=args_opt.start_lr,
                      end_learning_rate=args_opt.end_lr,
                      warmup_steps=args_opt.warmup_step,
                      decay_steps=args_opt.decay_steps)
    params = pangu_alpha.infer_param_pipeline_stage()
    decay_filter = lambda x: 'layernorm' not in x.name.lower() and "bias" not in x.name.lower()
    decay_params = list(filter(decay_filter, params))
    other_params = list(filter(lambda x: not decay_filter(x), params))
    group_params = [{
        'params': decay_params,
        'weight_decay': 1e-1
    }, {
        'params': other_params,
        'weight_decay': 0.0
    }, {
        'order_params': params
    }]
    if args_opt.optimizer == "lamb":
        optimizer = nn.Lamb(group_params, learning_rate=lr)
    else:
        optimizer = nn.AdamWeightDecay(group_params, learning_rate=lr, beta1=0.9, beta2=0.95, eps=1e-8)
    ds = create_dataset(config.batch_size, data_path=args_opt.data_url, eod_reset=True,
                        data_start_index=0, full_batch=True)
    epoch_num = args_opt.epoch_size
    step_per_epoch = ds.get_dataset_size()
    callback_size = args_opt.sink_size
    actual_epoch_num = int(epoch_num * step_per_epoch / callback_size)
    callback = [
        TimeMonitor(callback_size),
        LossCallBack(callback_size, rank_id, config.stage_num)
    ]
    loss_scale_value = math.pow(2, 32)
    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=loss_scale_value,
                                             scale_factor=2,
                                             scale_window=1000)
    pangu_alpha_with_grads = PanguAlphaTrainPipelineWithLossScaleCell(
        pangu_alpha_with_loss, optimizer=optimizer, config=config, scale_update_cell=update_cell)
    model = Model(pangu_alpha_with_grads)
    model.train(actual_epoch_num,
                ds,
                callbacks=callback,
                sink_size=callback_size,
                dataset_sink_mode=True)

if __name__ == "__main__":
    opt = get_args()
    set_parse(opt)
-   run_train(opt)
+   if opt.stage_num > 1:
+       run_train_pipeline(opt)
+   else:
+       run_train(opt)