diff --git a/model_zoo/official/nlp/gpt/src/gpt.py b/model_zoo/official/nlp/gpt/src/gpt.py
index 31537e4e15..016f055f42 100644
--- a/model_zoo/official/nlp/gpt/src/gpt.py
+++ b/model_zoo/official/nlp/gpt/src/gpt.py
@@ -25,6 +25,61 @@ from mindspore.common.initializer import TruncatedNormal, initializer
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 
+class LayerNorm(nn.Cell):
+    """
+    Layer Normalization
+
+    Args:
+        normalized_shape: the corresponding shape of the normalized axes
+        eps: epsilon, a small number avoiding zero division
+
+    Inputs:
+        x: input tensor
+
+    Returns:
+        rescaled_output: Tensor, returned tensor after layernorm
+    """
+    def __init__(self, normalized_shape, eps=1e-5):
+        super(LayerNorm, self).__init__()
+        self.gamma = Parameter(initializer('ones', normalized_shape), name="gamma")
+        self.beta = Parameter(initializer('zeros', normalized_shape), name="beta")
+        self.mean = P.ReduceMean(keep_dims=True)
+        self.eps = eps
+
+    def construct(self, x):
+        mean = self.mean(x, -1)
+        variance = self.mean(F.square(x - mean), -1)
+        output = (x - mean) / F.sqrt(variance + self.eps)
+        rescaled_output = output * self.gamma + self.beta
+        return rescaled_output
+
+class Softmax(nn.Cell):
+    """
+    softmax realization
+
+    Args:
+        axis: the axis to be applied softmax
+
+    Inputs:
+        x: input tensor
+
+    Returns:
+        output: Tensor, returned tensor after softmax
+    """
+    def __init__(self, axis=-1):
+        super(Softmax, self).__init__()
+        self.max = P.ArgMaxWithValue(axis=axis, keep_dims=True)
+        self.sum = P.ReduceSum(keep_dims=True)
+        self.axis = axis
+
+    def construct(self, x):
+        _, max_value = self.max(x)
+        exp_x = F.tensor_pow(np.e, x - max_value)
+        sum_x = self.sum(exp_x, self.axis)
+        output = exp_x / sum_x
+        return output
+
+
 class Mapping(nn.Cell):
     """
     A mapping function with a 3d input
@@ -162,7 +217,6 @@ class Attention(nn.Cell):
     def __init__(self, config, scale=1.0, layer_idx=None):
         super(Attention, self).__init__()
         self.get_attention_mask = AttentionMask(config)
-        self.expand_mapping = Mapping(config.embedding_size, 3*config.embedding_size, config.compute_dtype)
         self.projection = Mapping(config.embedding_size, config.embedding_size, config.compute_dtype, scale)
         self.split = P.Split(axis=-1, output_num=3)
         self.transpose = P.Transpose()
@@ -182,7 +236,6 @@ class Attention(nn.Cell):
         self.use_past = config.use_past
         self.dropout = nn.Dropout(1-config.dropout_rate)
         self.prob_dropout = nn.Dropout(1-config.dropout_rate)
-        self.softmax = nn.Softmax()
         self.dense1 = nn.Dense(config.embedding_size, config.embedding_size).to_float(config.compute_dtype)
         self.dense2 = nn.Dense(config.embedding_size, config.embedding_size).to_float(config.compute_dtype)
 
@@ -285,9 +338,7 @@ class Attention(nn.Cell):
 
         attention_scores = adder + score
         attention_scores = P.Cast()(attention_scores, ori_dtype)
-        shape = F.shape(attention_scores)
-        attention_probs = nn.Softmax()(F.reshape(attention_scores, (-1, shape[-1])))
-        attention_probs = F.reshape(attention_probs, shape)
+        attention_probs = Softmax()(attention_scores)
         attention_probs = self.prob_dropout(attention_probs)
 
         weighted_values = self.batch_matmul(attention_probs, value)
@@ -313,9 +364,9 @@ class Block(nn.Cell):
     def __init__(self, config, layer_idx):
         super(Block, self).__init__()
         scale = 1 / math.sqrt(2.0*layer_idx)
-        self.layernorm1 = nn.LayerNorm((config.embedding_size,)).to_float(config.compute_dtype)
+        self.layernorm1 = LayerNorm((config.embedding_size,)).to_float(config.compute_dtype)
         self.attention = Attention(config, scale, layer_idx)
-        self.layernorm2 = nn.LayerNorm((config.embedding_size,)).to_float(config.compute_dtype)
+        self.layernorm2 = LayerNorm((config.embedding_size,)).to_float(config.compute_dtype)
         self.output = Output(config, scale)
         self.post_layernorm_residual = config.post_layernorm_residual
 
@@ -362,7 +413,7 @@ class GPT_Model(nn.Cell):
         self.blocks = nn.CellList()
         for i in range(config.num_layers):
             self.blocks.append(Block(config, i+1))
-        self.layernorm = nn.LayerNorm((config.embedding_size,)).to_float(config.compute_dtype)
+        self.layernorm = LayerNorm((config.embedding_size,)).to_float(config.compute_dtype)
         self.use_past = config.use_past
         self.past = tuple([None]*config.num_layers)
         self.num_layers = config.num_layers
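
Note (outside the patch): the Softmax cell added above is the usual numerically stable softmax, i.e. the per-row maximum is subtracted before exponentiation so exp() cannot overflow, and the shift cancels in the normalized ratio. Below is a minimal NumPy sketch of that same trick for reference; the function name and values are illustrative only and are not part of gpt.py.

    import numpy as np

    def stable_softmax(x, axis=-1):
        # Shift by the row max: exp() of large raw scores would overflow otherwise,
        # and the constant factor exp(-max) cancels when dividing by the sum.
        shifted = x - x.max(axis=axis, keepdims=True)
        exp_x = np.exp(shifted)
        return exp_x / exp_x.sum(axis=axis, keepdims=True)

    scores = np.array([[1000.0, 1001.0, 1002.0]])  # naive exp() overflows here
    print(stable_softmax(scores))                  # ~[[0.0900, 0.2447, 0.6652]]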