diff --git a/mindspore/common/parameter.py b/mindspore/common/parameter.py
index a311a33644..3c6482b231 100644
--- a/mindspore/common/parameter.py
+++ b/mindspore/common/parameter.py
@@ -454,6 +454,8 @@ class Parameter(Tensor_):
         return new_param
 
     def add_pipeline_stage(self, stage):
+        if not isinstance(stage, int) or stage < 0:
+            raise TypeError("`stage` must be a non-negative integer")
         self._pipeline_stage_list.append(stage)
 
     def set_data(self, data, slice_shape=False):
diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py
index 1ae544ab91..7474e6705f 100755
--- a/mindspore/nn/cell.py
+++ b/mindspore/nn/cell.py
@@ -1313,12 +1313,12 @@ class Cell(Cell_):
         """
         Infer pipeline stages of all parameters in the cell.
 
-        Notes:
+        Note:
            - If a parameter does not belong to any cell which has been set pipeline_stage,
-            the parameter should use add_pipeline_stage to add it's pipeline_stage information.
+              the parameter should use add_pipeline_stage to add its pipeline_stage information.
            - If a parameter P has been used by two operator in different stages "stageA" and "stageB",
-            the parameter P should use P.add_pipeline_stage(stageA) and P.add_pipeline_stage(stageB)
-            to add it's stage information before use infer_param_pipeline_stage.
+              the parameter P should use P.add_pipeline_stage(stageA) and P.add_pipeline_stage(stageB)
+              to add its stage information before using infer_param_pipeline_stage.
 
         Returns:
             The params belong to current stage in pipeline parallel.
diff --git a/model_zoo/official/nlp/pangu_alpha/src/pangu_alpha.py b/model_zoo/official/nlp/pangu_alpha/src/pangu_alpha.py
index e1cb9cc3aa..1857c646d0 100644
--- a/model_zoo/official/nlp/pangu_alpha/src/pangu_alpha.py
+++ b/model_zoo/official/nlp/pangu_alpha/src/pangu_alpha.py
@@ -80,23 +80,21 @@ class LayerNorm(nn.Cell):
     r"""
     A self-defined layer norm operation using reduce sum and reduce mean
     """
-    def __init__(self, normalized_shape, dp=4, eps=1e-5, scale=1e-3):
+    def __init__(self, normalized_shape, dp=4, eps=1e-5, parallel_optimizer=False):
         super(LayerNorm, self).__init__()
-        self.gamma = Parameter(initializer('ones', normalized_shape), name="gamma")
-        self.beta = Parameter(initializer('zeros', normalized_shape), name="beta")
+        self.gamma = Parameter(initializer('ones', normalized_shape), name="gamma",
+                               parallel_optimizer=parallel_optimizer)
+        self.beta = Parameter(initializer('zeros', normalized_shape), name="beta",
+                              parallel_optimizer=parallel_optimizer)
         self.mean = P.ReduceMean(keep_dims=True).shard(((dp, 1, 1),))
         self.square = P.Square().shard(((dp, 1, 1),))
         self.sqrt = P.Sqrt().shard(((dp, 1, 1),))
         self.sub1 = P.Sub().shard(((dp, 1, 1), (dp, 1, 1)))
-        self.sub2 = P.Sub().shard(((dp, 1, 1), (dp, 1, 1)))
         self.add = P.TensorAdd().shard(((dp, 1, 1), ()))
-        self.eps = eps
         self.mul = P.Mul().shard(((dp, 1, 1), (1,)))
         self.add2 = P.TensorAdd().shard(((dp, 1, 1), (1,)))
         self.real_div = P.RealDiv().shard(((dp, 1, 1), (dp, 1, 1)))
-        self.scale_div = P.RealDiv().shard(((dp, 1, 1), ()))
-        self.scale_mul = P.Mul().shard(((dp, 1, 1), ()))
-        self.scale = scale
+        self.eps = eps
     def construct(self, x):
         mean = self.mean(x, -1)
         diff = self.sub1(x, mean)
@@ -278,53 +276,36 @@ class EmbeddingLookup(nn.Cell):
         super(EmbeddingLookup, self).__init__()
         self.vocab_size = config.vocab_size
         self.embedding_size = config.embedding_size
-        if config.load_ckpt_path:
-            # Loading the embedding table from the ckpt path:
-            embedding_path = os.path.join(config.load_ckpt_path, 'word_embedding.npy')
-            if os.path.exists(embedding_path):
-                e_table = np.load(embedding_path)
-                e_table = Tensor(e_table, mstype.float32)
-                self.embedding_table = Parameter(e_table, name="embedding_table")
-            else:
-                raise ValueError(f"{embedding_path} file not exits, please check whether word_embedding file exist.")
-        else:
-            self.embedding_table = Parameter(initializer(
-                Normal(0.02), [self.vocab_size, self.embedding_size]),
-                name="embedding_table")
+
         if config.word_emb_dp:
             self.gather = P.GatherV2().shard(((1, 1), (config.dp, 1)))
         else:
             self.gather = P.GatherV2().shard(((config.mp, 1), (1, 1)))
         self.shape = (-1, config.seq_length, config.embedding_size)
+        if config.stage_num > 1:
+            self.construct = self.construct_pipeline
+            self.gather.add_prim_attr("parameter_start", 0)
+        else:
+            if config.load_ckpt_path:
+                # Loading the embedding table from the ckpt path:
+                embedding_path = os.path.join(config.load_ckpt_path, 'word_embedding.npy')
+                if os.path.exists(embedding_path):
+                    e_table = np.load(embedding_path)
+                    e_table = Tensor(e_table, mstype.float32)
+                    self.embedding_table = Parameter(e_table, name="embedding_table")
+                else:
+                    raise ValueError(f"{embedding_path} file does not exist, "
+                                     f"please check whether the word_embedding file exists.")
+            else:
+                self.embedding_table = Parameter(initializer(Normal(0.02), [self.vocab_size, self.embedding_size]),
+                                                 name="embedding_table")
+            self.construct = self.construct_no_pipeline
 
-    def construct(self, input_ids):
+    def construct_no_pipeline(self, input_ids):
         output = self.gather(self.embedding_table, input_ids, 0)
         return output, self.embedding_table
 
-
-class EmbeddingLookupPipeline(nn.Cell):
-    """
-    The embedding lookup table for vocabulary
-    Args:
-        config(PanguAlphaConfig): the config of network
-    Inputs:
-        input_ids: the tokenized inputs with datatype int32
-    Returns:
-        output: Tensor, the embedding vector for the input with shape (batch_size, seq_length, embedding_size)
-        self.embedding_table: Tensor, the embedding table for the vocabulary
-    """
-    def __init__(self, config):
-        super(EmbeddingLookupPipeline, self).__init__()
-        self.vocab_size = config.vocab_size
-        self.embedding_size = config.embedding_size
-        if config.word_emb_dp:
-            self.gather = P.GatherV2().shard(((1, 1), (config.dp, 1)))
-        else:
-            self.gather = P.GatherV2().shard(((config.mp, 1), (1, 1)))
-        self.gather.add_prim_attr("parameter_start", 0)
-        self.shape = (-1, config.seq_length, config.embedding_size)
-
-    def construct(self, input_ids, table):
+    def construct_pipeline(self, input_ids, table):
         output = self.gather(table, input_ids, 0)
         return output
 
@@ -653,12 +634,12 @@ class Decoder(nn.Cell):
         self.layernorm1.layer_norm.shard(((config.dp, 1, 1), (1,), (1,)))
         self.layernorm2 = nn.LayerNorm((config.embedding_size,)).to_float(mstype.float32)
         self.layernorm2.layer_norm.shard(((config.dp, 1, 1), (1,), (1,)))
+        self.layernorm1.gamma.parallel_optimizer = False
+        self.layernorm1.beta.parallel_optimizer = False
+        self.layernorm2.gamma.parallel_optimizer = False
+        self.layernorm2.beta.parallel_optimizer = False
 
-        self.layernorm1.gamma.parallel_optimizer = False
-        self.layernorm1.beta.parallel_optimizer = False
         self.attention = Attention(config, scale, layer_idx)
-        self.layernorm2.gamma.parallel_optimizer = False
-        self.layernorm2.beta.parallel_optimizer = False
         # Feed Forward Network, FFN
         self.output = Output(config, scale)
         self.post_layernorm_residual = config.post_layernorm_residual
@@ -740,7 +721,7 @@ class PanguAlpha_EmbeddingPipeLine(nn.Cell):
     """
     def __init__(self, config):
         super(PanguAlpha_EmbeddingPipeLine, self).__init__()
-        self.word_embedding = EmbeddingLookupPipeline(config)
+        self.word_embedding = EmbeddingLookup(config)
         self.position_embedding = nn.Embedding(config.seq_length,
                                                config.embedding_size,
                                                embedding_table=Normal(0.02))
@@ -856,11 +837,7 @@ class QueryLayer(nn.Cell):
         scale = 1 / math.sqrt(2.0 * config.num_layers)
         self.layernorm1 = LayerNorm((config.embedding_size,), config.dp).to_float(mstype.float32)
         self.layernorm2 = LayerNorm((config.embedding_size,), config.dp).to_float(mstype.float32)
-        self.layernorm1.gamma.parallel_optimizer = False
-        self.layernorm1.beta.parallel_optimizer = False
         self.attention = QueryLayerAttention(config, scale)
-        self.layernorm2.gamma.parallel_optimizer = False
-        self.layernorm2.beta.parallel_optimizer = False
         self.output = Output(config, scale)
         self.post_layernorm_residual = config.post_layernorm_residual
         self.add = P.TensorAdd().shard(((config.dp, 1, 1), (config.dp, 1, 1)))
@@ -1060,8 +1037,8 @@ class PanguAlpha_Model(nn.Cell):
                                               mstype.float32).set_comm_fusion(
                                                   int((num_layers - 1) / fusion_group_size) + 2)
             self.layernorm.layer_norm.shard(((config.dp, 1, 1), (1,), (1,)))
-            self.layernorm.gamma.parallel_optimizer = False
-            self.layernorm.beta.parallel_optimizer = False
+        self.layernorm.gamma.parallel_optimizer = False
+        self.layernorm.beta.parallel_optimizer = False
         self.use_past = config.use_past
         self.past = tuple([None] * config.num_layers)
         self.dtype = config.compute_dtype
diff --git a/model_zoo/official/nlp/pangu_alpha/src/pangu_alpha_wrapcell.py b/model_zoo/official/nlp/pangu_alpha/src/pangu_alpha_wrapcell.py
index 2fb79399f1..c03ba4a288 100644
--- a/model_zoo/official/nlp/pangu_alpha/src/pangu_alpha_wrapcell.py
+++ b/model_zoo/official/nlp/pangu_alpha/src/pangu_alpha_wrapcell.py
@@ -183,14 +183,11 @@ class PanguAlphaTrainPipelineWithLossScaleCell(nn.Cell):
         self.get_status = P.NPUGetFloatStatus().add_prim_attr("_side_effect_flag", False)
         self.clear_before_grad = P.NPUClearFloatStatus().add_prim_attr("_side_effect_flag", False)
         self.reduce_sum = P.ReduceSum(keep_dims=False)
-        #self.depend_parameter_use = P.ControlDepend(depend_mode=1)
         self.base = Tensor(1, mstype.float32)
         self.less_equal = P.LessEqual()
         self.hyper_map = C.HyperMap()
         self.loss_scale = None
         self.reshape = P.Reshape()
-        #self.control = P.ControlDepend(1)
-        #self.clip_norm = Tensor(1000.0, mstype.float32)
         self.loss_scaling_manager = scale_update_cell
         if scale_update_cell:
             self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
@@ -216,7 +213,6 @@
         # alloc status and clear should be right before gradoperation
         init = self.alloc_status()
         status_clear = self.clear_before_grad(init)
-        #clear_depend = self.control(status_clear, self.weights)
         grads = self.grad(self.network, weights)(input_ids,
                                                  input_position,
                                                  attention_mask,
diff --git a/model_zoo/official/nlp/pangu_alpha/src/utils.py b/model_zoo/official/nlp/pangu_alpha/src/utils.py
index 7ef9133f53..b2d96596d7 100644
--- a/model_zoo/official/nlp/pangu_alpha/src/utils.py
+++ b/model_zoo/official/nlp/pangu_alpha/src/utils.py
@@ -123,65 +123,46 @@ def _get_pipeline_group():
 
 class GlobalNorm(nn.Cell):
     """
-
-    Calculate the global norm value of given tensors
-
-    """
-
-    def __init__(self, params):
-        super(GlobalNorm, self).__init__()
-        self.norm = nn.Norm()
-        self.hyper_map = C.HyperMap()
-        self.allreduce_filter = tuple(
-            "projection.bias" not in x.name and "layernorm" not in x.name and "embedding_table"
-            not in x.name for x in params)
-        self.length = len(params)
-        self.values = []
-        self.group_size = get_group_size()
-        for item in self.allreduce_filter:
-            if item:
-                self.values.append(1.0)
-            else:
-                self.values.append(self.group_size * 1.0)
-        self.values = tuple(self.values)
-
-    def construct(self, grads):
-        # Square sum of gradients for current rank
-        square_sum_dp = self.hyper_map(get_square_sum, grads, self.values)
-        # Global square sum of gradients
-        global_norms = F.sqrt(P.AllReduce()(F.addn(square_sum_dp)))
-        return global_norms
-
-class GlobalNormPipline(nn.Cell):
-    """
     Calculate the global norm value of given tensors
     """
     def __init__(self, params, config):
-        super(GlobalNormPipline, self).__init__()
+        super(GlobalNorm, self).__init__()
         self.norm = nn.Norm()
         self.hyper_map = C.HyperMap()
-        self.allreduce_filter = tuple("projection.bias" not in x.name and "layernorm" not in x.name
-                                      and "position_embedding.embedding_table" not in x.name for x in params)
+        self.is_pipeline = (config.stage_num > 1)
+        if self.is_pipeline:
+            group_size = config.mp
+            group_list, group_name = _get_model_parallel_group(config.mp)
+            create_group(group_name, group_list)
+            self.allreduce = P.AllReduce(group=group_name)
+            pipeline_group_list, pipeline_group_name = _get_pipeline_group()
+            create_group(pipeline_group_name, pipeline_group_list)
+            self.allreduce2 = P.AllReduce(group=pipeline_group_name)
+        else:
+            group_size = get_group_size()
+        if config.word_emb_dp:
+            self.allreduce_filter = tuple("projection.bias" not in x.name and "layernorm" not in x.name
+                                          and "embedding_table" not in x.name for x in params)
+        else:
+            self.allreduce_filter = tuple("projection.bias" not in x.name and "layernorm" not in x.name
+                                          and "position_embedding.embedding_table" not in x.name for x in params)
         self.allreduce_group_size = ()
         for item in self.allreduce_filter:
             if item:
                 self.allreduce_group_size = self.allreduce_group_size + (1.0,)
             else:
-                self.allreduce_group_size = self.allreduce_group_size + (config.mp * 1.0,)
-        self.length = len(params)
-        group_list, group_name = _get_model_parallel_group(config.mp)
-        create_group(group_name, group_list)
-        self.allreduce = P.AllReduce(group=group_name)
-        pipeline_group_list, pipeline_group_name = _get_pipeline_group()
-        create_group(pipeline_group_name, pipeline_group_list)
-        self.allreduce2 = P.AllReduce(group=pipeline_group_name)
+                self.allreduce_group_size = self.allreduce_group_size + (group_size * 1.0,)
 
     def construct(self, grads):
+        """Calculate global norm construct"""
         square_sum = self.hyper_map(get_square_sum, grads, self.allreduce_group_size)
         square_reduce_sum = F.addn(square_sum)
-        stage_square_reduce_sum = self.allreduce(square_reduce_sum)
-        global_square_reduce_sum = self.allreduce2(stage_square_reduce_sum)
-        global_norms = F.sqrt(global_square_reduce_sum)
+        if self.is_pipeline:
+            stage_square_reduce_sum = self.allreduce(square_reduce_sum)
+            global_square_reduce_sum = self.allreduce2(stage_square_reduce_sum)
+            global_norms = F.sqrt(global_square_reduce_sum)
+        else:
+            global_norms = F.sqrt(P.AllReduce()(square_reduce_sum))
         return global_norms
 
 class ClipByGlobalNorm(nn.Cell):
@@ -192,14 +173,12 @@ class ClipByGlobalNorm(nn.Cell):
     """
    def __init__(self, params, config, clip_norm=1.0):
        super(ClipByGlobalNorm, self).__init__()
-        if config.stage_num > 1:
-            self.global_norm = GlobalNormPipline(params, config)
-        else:
-            self.global_norm = GlobalNorm(params)
+        self.global_norm = GlobalNorm(params, config)
         self.clip_norm = Tensor([clip_norm], mstype.float32)
         self.hyper_map = C.HyperMap()
 
     def construct(self, grads):
+        """Clip grads by global norm construct"""
         global_norm_value = self.global_norm(grads)
         cond = P.GreaterEqual()(global_norm_value, self.clip_norm)
         global_norm = F.select(cond, global_norm_value, self.clip_norm)
@@ -366,6 +345,10 @@ def add_training_params(opt):
                      default=1,
                      choices=[0, 1],
                      help="Whether do data parallel in word embedding. default 1")
+    opt.add_argument("--data_column_name",
+                     type=str,
+                     default="input_ids",
+                     help="Column name of datasets")
 
 
 def get_args(inference=False):
diff --git a/model_zoo/official/nlp/pangu_alpha/train.py b/model_zoo/official/nlp/pangu_alpha/train.py
index 9553673cdd..a26b21bf08 100644
--- a/model_zoo/official/nlp/pangu_alpha/train.py
+++ b/model_zoo/official/nlp/pangu_alpha/train.py
@@ -174,7 +174,8 @@ def run_train(args_opt):
     # Dataset loading mindrecord files
     ds = create_dataset(config.batch_size, data_path=cache_url, data_start_index=0,
                         eod_reset=config.eod_reset, full_batch=bool(args_opt.full_batch),
-                        eod_id=args_opt.eod_id, device_num=device_num, rank=rank, epoch=epoch_num)
+                        eod_id=args_opt.eod_id, device_num=device_num, rank=rank,
+                        column_name=args_opt.data_column_name, epoch=epoch_num)
     step_per_epoch = ds.get_dataset_size()
     callback_size = args_opt.sink_size
     actual_epoch_num = int(epoch_num * step_per_epoch / callback_size)
@@ -269,7 +270,7 @@ def run_train_pipeline(args_opt):
     else:
         optimizer = nn.AdamWeightDecay(group_params, learning_rate=lr, beta1=0.9, beta2=0.95, eps=1e-8)
     ds = create_dataset(config.batch_size, data_path=args_opt.data_url, eod_reset=True,
-                        data_start_index=0, full_batch=True)
+                        data_start_index=0, full_batch=True, column_name=args_opt.data_column_name)
     epoch_num = args_opt.epoch_size
     step_per_epoch = ds.get_dataset_size()
     callback_size = args_opt.sink_size