| @@ -454,6 +454,8 @@ class Parameter(Tensor_): | |||
| return new_param | |||
def add_pipeline_stage(self, stage):
    """
    Record an additional pipeline stage this parameter participates in.

    Args:
        stage (int): Pipeline stage index; must be a non-negative int.

    Raises:
        TypeError: If `stage` is not an int or is negative.
            (TypeError is kept, rather than ValueError, so existing
            callers that catch TypeError keep working.)
    """
    # `stage < 0` rejects negatives but deliberately allows 0, so the
    # message must say "non-negative", not "positive".
    if not isinstance(stage, int) or stage < 0:
        raise TypeError("`stage` must be a non-negative int")
    self._pipeline_stage_list.append(stage)
| def set_data(self, data, slice_shape=False): | |||
| @@ -1313,12 +1313,12 @@ class Cell(Cell_): | |||
| """ | |||
| Infer pipeline stages of all parameters in the cell. | |||
| Notes: | |||
| Note: | |||
| - If a parameter does not belong to any cell which has been set pipeline_stage, | |||
| the parameter should use add_pipeline_stage to add its pipeline_stage information. | |||
| the parameter should use add_pipeline_stage to add its pipeline_stage information. | |||
| - If a parameter P has been used by two operator in different stages "stageA" and "stageB", | |||
| the parameter P should use P.add_pipeline_stage(stageA) and P.add_pipeline_stage(stageB) | |||
| to add its stage information before using infer_param_pipeline_stage. | |||
| the parameter P should use P.add_pipeline_stage(stageA) and P.add_pipeline_stage(stageB) | |||
| to add its stage information before using infer_param_pipeline_stage. | |||
| Returns: | |||
| The params belong to current stage in pipeline parallel. | |||
| @@ -80,23 +80,21 @@ class LayerNorm(nn.Cell): | |||
| r""" | |||
| A self-defined layer norm operation using reduce sum and reduce mean | |||
| """ | |||
| def __init__(self, normalized_shape, dp=4, eps=1e-5, scale=1e-3): | |||
| def __init__(self, normalized_shape, dp=4, eps=1e-5, parallel_optimizer=False): | |||
| super(LayerNorm, self).__init__() | |||
| self.gamma = Parameter(initializer('ones', normalized_shape), name="gamma") | |||
| self.beta = Parameter(initializer('zeros', normalized_shape), name="beta") | |||
| self.gamma = Parameter(initializer('ones', normalized_shape), name="gamma", | |||
| parallel_optimizer=parallel_optimizer) | |||
| self.beta = Parameter(initializer('zeros', normalized_shape), name="beta", | |||
| parallel_optimizer=parallel_optimizer) | |||
| self.mean = P.ReduceMean(keep_dims=True).shard(((dp, 1, 1),)) | |||
| self.square = P.Square().shard(((dp, 1, 1),)) | |||
| self.sqrt = P.Sqrt().shard(((dp, 1, 1),)) | |||
| self.sub1 = P.Sub().shard(((dp, 1, 1), (dp, 1, 1))) | |||
| self.sub2 = P.Sub().shard(((dp, 1, 1), (dp, 1, 1))) | |||
| self.add = P.TensorAdd().shard(((dp, 1, 1), ())) | |||
| self.eps = eps | |||
| self.mul = P.Mul().shard(((dp, 1, 1), (1,))) | |||
| self.add2 = P.TensorAdd().shard(((dp, 1, 1), (1,))) | |||
| self.real_div = P.RealDiv().shard(((dp, 1, 1), (dp, 1, 1))) | |||
| self.scale_div = P.RealDiv().shard(((dp, 1, 1), ())) | |||
| self.scale_mul = P.Mul().shard(((dp, 1, 1), ())) | |||
| self.scale = scale | |||
| self.eps = eps | |||
| def construct(self, x): | |||
| mean = self.mean(x, -1) | |||
| diff = self.sub1(x, mean) | |||
| @@ -278,53 +276,36 @@ class EmbeddingLookup(nn.Cell): | |||
| super(EmbeddingLookup, self).__init__() | |||
| self.vocab_size = config.vocab_size | |||
| self.embedding_size = config.embedding_size | |||
| if config.load_ckpt_path: | |||
| # Loading the embedding table from the ckpt path: | |||
| embedding_path = os.path.join(config.load_ckpt_path, 'word_embedding.npy') | |||
| if os.path.exists(embedding_path): | |||
| e_table = np.load(embedding_path) | |||
| e_table = Tensor(e_table, mstype.float32) | |||
| self.embedding_table = Parameter(e_table, name="embedding_table") | |||
| else: | |||
| raise ValueError(f"{embedding_path} file not exits, please check whether word_embedding file exist.") | |||
| else: | |||
| self.embedding_table = Parameter(initializer( | |||
| Normal(0.02), [self.vocab_size, self.embedding_size]), | |||
| name="embedding_table") | |||
| if config.word_emb_dp: | |||
| self.gather = P.GatherV2().shard(((1, 1), (config.dp, 1))) | |||
| else: | |||
| self.gather = P.GatherV2().shard(((config.mp, 1), (1, 1))) | |||
| self.shape = (-1, config.seq_length, config.embedding_size) | |||
| if config.stage_num > 1: | |||
| self.construct = self.construct_pipeline | |||
| self.gather.add_prim_attr("parameter_start", 0) | |||
| else: | |||
| if config.load_ckpt_path: | |||
| # Loading the embedding table from the ckpt path: | |||
| embedding_path = os.path.join(config.load_ckpt_path, 'word_embedding.npy') | |||
| if os.path.exists(embedding_path): | |||
| e_table = np.load(embedding_path) | |||
| e_table = Tensor(e_table, mstype.float32) | |||
| self.embedding_table = Parameter(e_table, name="embedding_table") | |||
| else: | |||
| raise ValueError(f"{embedding_path} file not exits, " | |||
| f"please check whether word_embedding file exist.") | |||
| else: | |||
| self.embedding_table = Parameter(initializer(Normal(0.02), [self.vocab_size, self.embedding_size]), | |||
| name="embedding_table") | |||
| self.construct = self.construct_no_pipeline | |||
| def construct(self, input_ids): | |||
| def construct_no_pipeline(self, input_ids): | |||
| output = self.gather(self.embedding_table, input_ids, 0) | |||
| return output, self.embedding_table | |||
| class EmbeddingLookupPipeline(nn.Cell): | |||
| """ | |||
| The embedding lookup table for vocabulary | |||
| Args: | |||
| config(PanguAlphaConfig): the config of network | |||
| Inputs: | |||
| input_ids: the tokenized inputs with datatype int32 | |||
| Returns: | |||
| output: Tensor, the embedding vector for the input with shape (batch_size, seq_length, embedding_size) | |||
| self.embedding_table: Tensor, the embedding table for the vocabulary | |||
| """ | |||
| def __init__(self, config): | |||
| super(EmbeddingLookupPipeline, self).__init__() | |||
| self.vocab_size = config.vocab_size | |||
| self.embedding_size = config.embedding_size | |||
| if config.word_emb_dp: | |||
| self.gather = P.GatherV2().shard(((1, 1), (config.dp, 1))) | |||
| else: | |||
| self.gather = P.GatherV2().shard(((config.mp, 1), (1, 1))) | |||
| self.gather.add_prim_attr("parameter_start", 0) | |||
| self.shape = (-1, config.seq_length, config.embedding_size) | |||
| def construct(self, input_ids, table): | |||
| def construct_pipeline(self, input_ids, table): | |||
| output = self.gather(table, input_ids, 0) | |||
| return output | |||
| @@ -653,12 +634,12 @@ class Decoder(nn.Cell): | |||
| self.layernorm1.layer_norm.shard(((config.dp, 1, 1), (1,), (1,))) | |||
| self.layernorm2 = nn.LayerNorm((config.embedding_size,)).to_float(mstype.float32) | |||
| self.layernorm2.layer_norm.shard(((config.dp, 1, 1), (1,), (1,))) | |||
| self.layernorm1.gamma.parallel_optimizer = False | |||
| self.layernorm1.beta.parallel_optimizer = False | |||
| self.layernorm2.gamma.parallel_optimizer = False | |||
| self.layernorm2.beta.parallel_optimizer = False | |||
| self.layernorm1.gamma.parallel_optimizer = False | |||
| self.layernorm1.beta.parallel_optimizer = False | |||
| self.attention = Attention(config, scale, layer_idx) | |||
| self.layernorm2.gamma.parallel_optimizer = False | |||
| self.layernorm2.beta.parallel_optimizer = False | |||
| # Feed Forward Network, FFN | |||
| self.output = Output(config, scale) | |||
| self.post_layernorm_residual = config.post_layernorm_residual | |||
| @@ -740,7 +721,7 @@ class PanguAlpha_EmbeddingPipeLine(nn.Cell): | |||
| """ | |||
| def __init__(self, config): | |||
| super(PanguAlpha_EmbeddingPipeLine, self).__init__() | |||
| self.word_embedding = EmbeddingLookupPipeline(config) | |||
| self.word_embedding = EmbeddingLookup(config) | |||
| self.position_embedding = nn.Embedding(config.seq_length, | |||
| config.embedding_size, | |||
| embedding_table=Normal(0.02)) | |||
| @@ -856,11 +837,7 @@ class QueryLayer(nn.Cell): | |||
| scale = 1 / math.sqrt(2.0 * config.num_layers) | |||
| self.layernorm1 = LayerNorm((config.embedding_size,), config.dp).to_float(mstype.float32) | |||
| self.layernorm2 = LayerNorm((config.embedding_size,), config.dp).to_float(mstype.float32) | |||
| self.layernorm1.gamma.parallel_optimizer = False | |||
| self.layernorm1.beta.parallel_optimizer = False | |||
| self.attention = QueryLayerAttention(config, scale) | |||
| self.layernorm2.gamma.parallel_optimizer = False | |||
| self.layernorm2.beta.parallel_optimizer = False | |||
| self.output = Output(config, scale) | |||
| self.post_layernorm_residual = config.post_layernorm_residual | |||
| self.add = P.TensorAdd().shard(((config.dp, 1, 1), (config.dp, 1, 1))) | |||
| @@ -1060,8 +1037,8 @@ class PanguAlpha_Model(nn.Cell): | |||
| mstype.float32).set_comm_fusion( | |||
| int((num_layers - 1) / fusion_group_size) + 2) | |||
| self.layernorm.layer_norm.shard(((config.dp, 1, 1), (1,), (1,))) | |||
| self.layernorm.gamma.parallel_optimizer = False | |||
| self.layernorm.beta.parallel_optimizer = False | |||
| self.layernorm.gamma.parallel_optimizer = False | |||
| self.layernorm.beta.parallel_optimizer = False | |||
| self.use_past = config.use_past | |||
| self.past = tuple([None] * config.num_layers) | |||
| self.dtype = config.compute_dtype | |||
| @@ -183,14 +183,11 @@ class PanguAlphaTrainPipelineWithLossScaleCell(nn.Cell): | |||
| self.get_status = P.NPUGetFloatStatus().add_prim_attr("_side_effect_flag", False) | |||
| self.clear_before_grad = P.NPUClearFloatStatus().add_prim_attr("_side_effect_flag", False) | |||
| self.reduce_sum = P.ReduceSum(keep_dims=False) | |||
| #self.depend_parameter_use = P.ControlDepend(depend_mode=1) | |||
| self.base = Tensor(1, mstype.float32) | |||
| self.less_equal = P.LessEqual() | |||
| self.hyper_map = C.HyperMap() | |||
| self.loss_scale = None | |||
| self.reshape = P.Reshape() | |||
| #self.control = P.ControlDepend(1) | |||
| #self.clip_norm = Tensor(1000.0, mstype.float32) | |||
| self.loss_scaling_manager = scale_update_cell | |||
| if scale_update_cell: | |||
| self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32), | |||
| @@ -216,7 +213,6 @@ class PanguAlphaTrainPipelineWithLossScaleCell(nn.Cell): | |||
| # alloc status and clear should be right before gradoperation | |||
| init = self.alloc_status() | |||
| status_clear = self.clear_before_grad(init) | |||
| #clear_depend = self.control(status_clear, self.weights) | |||
| grads = self.grad(self.network, weights)(input_ids, | |||
| input_position, | |||
| attention_mask, | |||
| @@ -123,65 +123,46 @@ def _get_pipeline_group(): | |||
| class GlobalNorm(nn.Cell): | |||
| """ | |||
| Calculate the global norm value of given tensors | |||
| """ | |||
| def __init__(self, params): | |||
| super(GlobalNorm, self).__init__() | |||
| self.norm = nn.Norm() | |||
| self.hyper_map = C.HyperMap() | |||
| self.allreduce_filter = tuple( | |||
| "projection.bias" not in x.name and "layernorm" not in x.name and "embedding_table" | |||
| not in x.name for x in params) | |||
| self.length = len(params) | |||
| self.values = [] | |||
| self.group_size = get_group_size() | |||
| for item in self.allreduce_filter: | |||
| if item: | |||
| self.values.append(1.0) | |||
| else: | |||
| self.values.append(self.group_size * 1.0) | |||
| self.values = tuple(self.values) | |||
| def construct(self, grads): | |||
| # Square sum of gradients for current rank | |||
| square_sum_dp = self.hyper_map(get_square_sum, grads, self.values) | |||
| # Global square sum of gradients | |||
| global_norms = F.sqrt(P.AllReduce()(F.addn(square_sum_dp))) | |||
| return global_norms | |||
| class GlobalNormPipline(nn.Cell): | |||
| """ | |||
| Calculate the global norm value of given tensors | |||
| """ | |||
| def __init__(self, params, config): | |||
| super(GlobalNormPipline, self).__init__() | |||
| super(GlobalNorm, self).__init__() | |||
| self.norm = nn.Norm() | |||
| self.hyper_map = C.HyperMap() | |||
| self.allreduce_filter = tuple("projection.bias" not in x.name and "layernorm" not in x.name | |||
| and "position_embedding.embedding_table" not in x.name for x in params) | |||
| self.is_pipeline = (config.stage_num > 1) | |||
| if self.is_pipeline: | |||
| group_size = config.mp | |||
| group_list, group_name = _get_model_parallel_group(config.mp) | |||
| create_group(group_name, group_list) | |||
| self.allreduce = P.AllReduce(group=group_name) | |||
| pipeline_group_list, pipeline_group_name = _get_pipeline_group() | |||
| create_group(pipeline_group_name, pipeline_group_list) | |||
| self.allreduce2 = P.AllReduce(group=pipeline_group_name) | |||
| else: | |||
| group_size = get_group_size() | |||
| if config.word_emb_dp: | |||
| self.allreduce_filter = tuple("projection.bias" not in x.name and "layernorm" not in x.name | |||
| and "embedding_table" not in x.name for x in params) | |||
| else: | |||
| self.allreduce_filter = tuple("projection.bias" not in x.name and "layernorm" not in x.name | |||
| and "position_embedding.embedding_table" not in x.name for x in params) | |||
| self.allreduce_group_size = () | |||
| for item in self.allreduce_filter: | |||
| if item: | |||
| self.allreduce_group_size = self.allreduce_group_size + (1.0,) | |||
| else: | |||
| self.allreduce_group_size = self.allreduce_group_size + (config.mp * 1.0,) | |||
| self.length = len(params) | |||
| group_list, group_name = _get_model_parallel_group(config.mp) | |||
| create_group(group_name, group_list) | |||
| self.allreduce = P.AllReduce(group=group_name) | |||
| pipeline_group_list, pipeline_group_name = _get_pipeline_group() | |||
| create_group(pipeline_group_name, pipeline_group_list) | |||
| self.allreduce2 = P.AllReduce(group=pipeline_group_name) | |||
| self.allreduce_group_size = self.allreduce_group_size + (group_size * 1.0,) | |||
| def construct(self, grads): | |||
| """Calculate global norm construct""" | |||
| square_sum = self.hyper_map(get_square_sum, grads, self.allreduce_group_size) | |||
| square_reduce_sum = F.addn(square_sum) | |||
| stage_square_reduce_sum = self.allreduce(square_reduce_sum) | |||
| global_square_reduce_sum = self.allreduce2(stage_square_reduce_sum) | |||
| global_norms = F.sqrt(global_square_reduce_sum) | |||
| if self.is_pipeline: | |||
| stage_square_reduce_sum = self.allreduce(square_reduce_sum) | |||
| global_square_reduce_sum = self.allreduce2(stage_square_reduce_sum) | |||
| global_norms = F.sqrt(global_square_reduce_sum) | |||
| else: | |||
| global_norms = F.sqrt(P.AllReduce()(square_reduce_sum)) | |||
| return global_norms | |||
| class ClipByGlobalNorm(nn.Cell): | |||
| @@ -192,14 +173,12 @@ class ClipByGlobalNorm(nn.Cell): | |||
| """ | |||
| def __init__(self, params, config, clip_norm=1.0): | |||
| super(ClipByGlobalNorm, self).__init__() | |||
| if config.stage_num > 1: | |||
| self.global_norm = GlobalNormPipline(params, config) | |||
| else: | |||
| self.global_norm = GlobalNorm(params) | |||
| self.global_norm = GlobalNorm(params, config) | |||
| self.clip_norm = Tensor([clip_norm], mstype.float32) | |||
| self.hyper_map = C.HyperMap() | |||
| def construct(self, grads): | |||
| """Clip grads by global norm construct""" | |||
| global_norm_value = self.global_norm(grads) | |||
| cond = P.GreaterEqual()(global_norm_value, self.clip_norm) | |||
| global_norm = F.select(cond, global_norm_value, self.clip_norm) | |||
| @@ -366,6 +345,10 @@ def add_training_params(opt): | |||
| default=1, | |||
| choices=[0, 1], | |||
| help="Whether do data parallel in word embedding. default 1") | |||
| opt.add_argument("--data_column_name", | |||
| type=str, | |||
| default="input_ids", | |||
| help="Column name of datasets") | |||
| def get_args(inference=False): | |||
| @@ -174,7 +174,8 @@ def run_train(args_opt): | |||
| # Dataset loading mindrecord files | |||
| ds = create_dataset(config.batch_size, data_path=cache_url, | |||
| data_start_index=0, eod_reset=config.eod_reset, full_batch=bool(args_opt.full_batch), | |||
| eod_id=args_opt.eod_id, device_num=device_num, rank=rank, epoch=epoch_num) | |||
| eod_id=args_opt.eod_id, device_num=device_num, rank=rank, | |||
| column_name=args_opt.data_column_name, epoch=epoch_num) | |||
| step_per_epoch = ds.get_dataset_size() | |||
| callback_size = args_opt.sink_size | |||
| actual_epoch_num = int(epoch_num * step_per_epoch / callback_size) | |||
| @@ -269,7 +270,7 @@ def run_train_pipeline(args_opt): | |||
| else: | |||
| optimizer = nn.AdamWeightDecay(group_params, learning_rate=lr, beta1=0.9, beta2=0.95, eps=1e-8) | |||
| ds = create_dataset(config.batch_size, data_path=args_opt.data_url, eod_reset=True, | |||
| data_start_index=0, full_batch=True) | |||
| data_start_index=0, full_batch=True, column_name=args_opt.data_column_name) | |||
| epoch_num = args_opt.epoch_size | |||
| step_per_epoch = ds.get_dataset_size() | |||
| callback_size = args_opt.sink_size | |||