| @@ -25,6 +25,61 @@ from mindspore.common.initializer import TruncatedNormal, initializer | |||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from mindspore.ops import functional as F | from mindspore.ops import functional as F | ||||
| class LayerNorm(nn.Cell): | |||||
| """ | |||||
| Layer Normalization | |||||
| Args: | |||||
| normalized_shape: the corresponding shape of the normalized axes | |||||
| eps: epsilon, a small number avoiding zero division | |||||
| Inputs: | |||||
| x: input tensor | |||||
| Returns: | |||||
| rescaled_output: Tensor, returned tensor after layernorm | |||||
| """ | |||||
| def __init__(self, normalized_shape, eps=1e-5): | |||||
| super(LayerNorm, self).__init__() | |||||
| self.gamma = Parameter(initializer('ones', normalized_shape), name="gamma") | |||||
| self.beta = Parameter(initializer('zeros', normalized_shape), name="beta") | |||||
| self.mean = P.ReduceMean(keep_dims=True) | |||||
| self.eps = eps | |||||
| def construct(self, x): | |||||
| mean = self.mean(x, -1) | |||||
| variance = self.mean(F.square(x - mean)) | |||||
| output = (x - mean) / F.sqrt(variance + self.eps) | |||||
| rescaled_output = output * self.gamma + self.beta | |||||
| return rescaled_output | |||||
| class Softmax(nn.Cell): | |||||
| """ | |||||
| softmax realization | |||||
| Args: | |||||
| axis: the axis to be applied softmax | |||||
| Inputs: | |||||
| x: input tensor | |||||
| Returns: | |||||
| output: Tensor, returned tensor after softmax | |||||
| """ | |||||
| def __init__(self, axis=-1): | |||||
| super(Softmax, self).__init__() | |||||
| self.max = P.ArgMaxWithValue(axis=axis, keep_dims=True) | |||||
| self.sum = P.ReduceSum(keep_dims=True) | |||||
| self.axis = axis | |||||
| def construct(self, x): | |||||
| _, max_value = self.max(x) | |||||
| exp_x = F.tensor_pow(np.e, x - max_value) | |||||
| sum_x = self.sum(exp_x, self.axis) | |||||
| output = exp_x / sum_x | |||||
| return output | |||||
| class Mapping(nn.Cell): | class Mapping(nn.Cell): | ||||
| """ | """ | ||||
| A mapping function with a 3d input | A mapping function with a 3d input | ||||
| @@ -162,7 +217,6 @@ class Attention(nn.Cell): | |||||
| def __init__(self, config, scale=1.0, layer_idx=None): | def __init__(self, config, scale=1.0, layer_idx=None): | ||||
| super(Attention, self).__init__() | super(Attention, self).__init__() | ||||
| self.get_attention_mask = AttentionMask(config) | self.get_attention_mask = AttentionMask(config) | ||||
| self.expand_mapping = Mapping(config.embedding_size, 3*config.embedding_size, config.compute_dtype) | |||||
| self.projection = Mapping(config.embedding_size, config.embedding_size, config.compute_dtype, scale) | self.projection = Mapping(config.embedding_size, config.embedding_size, config.compute_dtype, scale) | ||||
| self.split = P.Split(axis=-1, output_num=3) | self.split = P.Split(axis=-1, output_num=3) | ||||
| self.transpose = P.Transpose() | self.transpose = P.Transpose() | ||||
| @@ -182,7 +236,6 @@ class Attention(nn.Cell): | |||||
| self.use_past = config.use_past | self.use_past = config.use_past | ||||
| self.dropout = nn.Dropout(1-config.dropout_rate) | self.dropout = nn.Dropout(1-config.dropout_rate) | ||||
| self.prob_dropout = nn.Dropout(1-config.dropout_rate) | self.prob_dropout = nn.Dropout(1-config.dropout_rate) | ||||
| self.softmax = nn.Softmax() | |||||
| self.dense1 = nn.Dense(config.embedding_size, config.embedding_size).to_float(config.compute_dtype) | self.dense1 = nn.Dense(config.embedding_size, config.embedding_size).to_float(config.compute_dtype) | ||||
| self.dense2 = nn.Dense(config.embedding_size, config.embedding_size).to_float(config.compute_dtype) | self.dense2 = nn.Dense(config.embedding_size, config.embedding_size).to_float(config.compute_dtype) | ||||
| @@ -285,9 +338,7 @@ class Attention(nn.Cell): | |||||
| attention_scores = adder + score | attention_scores = adder + score | ||||
| attention_scores = P.Cast()(attention_scores, ori_dtype) | attention_scores = P.Cast()(attention_scores, ori_dtype) | ||||
| shape = F.shape(attention_scores) | |||||
| attention_probs = nn.Softmax()(F.reshape(attention_scores, (-1, shape[-1]))) | |||||
| attention_probs = F.reshape(attention_probs, shape) | |||||
| attention_probs = Softmax()(attention_scores) | |||||
| attention_probs = self.prob_dropout(attention_probs) | attention_probs = self.prob_dropout(attention_probs) | ||||
| weighted_values = self.batch_matmul(attention_probs, value) | weighted_values = self.batch_matmul(attention_probs, value) | ||||
| @@ -313,9 +364,9 @@ class Block(nn.Cell): | |||||
| def __init__(self, config, layer_idx): | def __init__(self, config, layer_idx): | ||||
| super(Block, self).__init__() | super(Block, self).__init__() | ||||
| scale = 1 / math.sqrt(2.0*layer_idx) | scale = 1 / math.sqrt(2.0*layer_idx) | ||||
| self.layernorm1 = nn.LayerNorm((config.embedding_size,)).to_float(config.compute_dtype) | |||||
| self.layernorm1 = LayerNorm((config.embedding_size,)).to_float(config.compute_dtype) | |||||
| self.attention = Attention(config, scale, layer_idx) | self.attention = Attention(config, scale, layer_idx) | ||||
| self.layernorm2 = nn.LayerNorm((config.embedding_size,)).to_float(config.compute_dtype) | |||||
| self.layernorm2 = LayerNorm((config.embedding_size,)).to_float(config.compute_dtype) | |||||
| self.output = Output(config, scale) | self.output = Output(config, scale) | ||||
| self.post_layernorm_residual = config.post_layernorm_residual | self.post_layernorm_residual = config.post_layernorm_residual | ||||
| @@ -362,7 +413,7 @@ class GPT_Model(nn.Cell): | |||||
| self.blocks = nn.CellList() | self.blocks = nn.CellList() | ||||
| for i in range(config.num_layers): | for i in range(config.num_layers): | ||||
| self.blocks.append(Block(config, i+1)) | self.blocks.append(Block(config, i+1)) | ||||
| self.layernorm = nn.LayerNorm((config.embedding_size,)).to_float(config.compute_dtype) | |||||
| self.layernorm = LayerNorm((config.embedding_size,)).to_float(config.compute_dtype) | |||||
| self.use_past = config.use_past | self.use_past = config.use_past | ||||
| self.past = tuple([None]*config.num_layers) | self.past = tuple([None]*config.num_layers) | ||||
| self.num_layers = config.num_layers | self.num_layers = config.num_layers | ||||