|
|
|
@@ -21,7 +21,7 @@ import mindspore.nn as nn |
|
|
|
from mindspore.common.tensor import Tensor
|
|
|
|
from mindspore.common.parameter import Parameter
|
|
|
|
import mindspore.common.dtype as mstype
|
|
|
|
from mindspore.common.initializer import TruncatedNormal, initializer
|
|
|
|
from mindspore.common.initializer import TruncatedNormal, initializer, Normal
|
|
|
|
from mindspore.ops import operations as P
|
|
|
|
from mindspore.ops import functional as F
|
|
|
|
|
|
|
|
@@ -48,7 +48,7 @@ class LayerNorm(nn.Cell): |
|
|
|
|
|
|
|
def construct(self, x):
    """Normalize ``x`` over its last axis and apply the learned rescale.

    Args:
        x: Input tensor. Both statistics are computed by calling
            ``self.mean`` with axis ``-1``.

    Returns:
        ``(x - mean) / sqrt(variance + self.eps)``, multiplied by
        ``self.gamma`` and offset by ``self.beta``.
    """
    # NOTE(review): the broadcasts below assume self.mean keeps the reduced
    # dimension (keep_dims=True) -- confirm in __init__, which is not
    # visible from this chunk.
    mean = self.mean(x, -1)
    # Reduce the squared deviations over the same axis (-1) as the mean so
    # both statistics are computed over the same axis. The earlier axis-less
    # reduction was overwritten right away and has been dropped.
    variance = self.mean(F.square(x - mean), -1)
    # self.eps is added before the square root, so a zero variance does not
    # produce a division by zero (given a positive eps).
    output = (x - mean) / F.sqrt(variance + self.eps)
    rescaled_output = output * self.gamma + self.beta
    return rescaled_output
|
|
|
|
@@ -100,10 +100,8 @@ class Mapping(nn.Cell): |
|
|
|
super(Mapping, self).__init__()
|
|
|
|
self.output_size = output_size
|
|
|
|
self.input_size = input_size
|
|
|
|
weight = np.random.normal(loc=0.0, scale=0.02*scale, size=(input_size, output_size))
|
|
|
|
bias = np.zeros(shape=(output_size,))
|
|
|
|
self.weight = Parameter(Tensor(weight, mstype.float32), name="mapping_weight")
|
|
|
|
self.bias = Parameter(Tensor(bias, mstype.float32), name="mapping_bias")
|
|
|
|
self.weight = Parameter(initializer(Normal(sigma=0.02*scale), [input_size, output_size]), name="mapping_weight")
|
|
|
|
self.bias = Parameter(initializer("zeros", [output_size,]), name="mapping_bias")
|
|
|
|
self.dtype = dtype
|
|
|
|
self.cast = P.Cast()
|
|
|
|
|
|
|
|
@@ -363,7 +361,6 @@ class Block(nn.Cell): |
|
|
|
"""
|
|
|
|
def __init__(self, config, layer_idx):
|
|
|
|
super(Block, self).__init__()
|
|
|
|
scale = 1 / math.sqrt(2.0*layer_idx)
|
|
|
|
self.layernorm1 = LayerNorm((config.embedding_size,)).to_float(config.compute_dtype)
|
|
|
|
self.attention = Attention(config, scale, layer_idx)
|
|
|
|
self.layernorm2 = LayerNorm((config.embedding_size,)).to_float(config.compute_dtype)
|
|
|
|
|