@@ -277,8 +277,8 @@ class RelaPosMatrixGenerator(nn.Cell):
     def __init__(self, length, max_relative_position):
         super(RelaPosMatrixGenerator, self).__init__()
         self._length = length
-        self._max_relative_position = Tensor(max_relative_position, dtype=mstype.int32)
-        self._min_relative_position = Tensor(-max_relative_position, dtype=mstype.int32)
+        self._max_relative_position = max_relative_position
+        self._min_relative_position = -max_relative_position
         self.range_length = -length + 1
 
         self.tile = P.Tile()
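
Note: the change above works because MindSpore operators can take plain Python numbers and convert them implicitly, so a Cell can store scalar bounds instead of pre-built Tensor constants. A minimal sketch of the pattern, assuming a MindSpore build whose Minimum/Maximum ops accept (Tensor, Number) inputs; ClipToRange is a hypothetical name, not part of this file:

import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor
from mindspore.ops import operations as P

class ClipToRange(nn.Cell):
    """Clamp a tensor into [-max_relative_position, max_relative_position]."""
    def __init__(self, max_relative_position):
        super(ClipToRange, self).__init__()
        self._max = max_relative_position      # plain int, no Tensor(...) wrapper
        self._min = -max_relative_position
        self.min_op = P.Minimum()
        self.max_op = P.Maximum()

    def construct(self, x):
        # scalar bounds are implicitly converted and broadcast against x
        return self.max_op(self.min_op(x, self._max), self._min)

x = Tensor(np.array([-9, 0, 9]), mstype.int32)
print(ClipToRange(4)(x))                       # expected: [-4  0  4]
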
@@ -336,9 +336,7 @@ class RelaPosEmbeddingsGenerator(nn.Cell):
         self.relative_positions_matrix = RelaPosMatrixGenerator(length=length,
                                                                 max_relative_position=max_relative_position)
         self.reshape = P.Reshape()
-        self.one_hot = P.OneHot()
-        self.on_value = Tensor(1.0, mstype.float32)
-        self.off_value = Tensor(0.0, mstype.float32)
+        self.one_hot = nn.OneHot(depth=self.vocab_size)
         self.shape = P.Shape()
         self.gather = P.GatherV2()  # index_select
         self.matmul = P.BatchMatMul()
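
Note: nn.OneHot binds depth, on_value, and off_value at construction time (defaulting to 1.0 and 0.0), which is why the three separate attributes above collapse into a single one. A quick sketch under that assumption:

import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor

one_hot = nn.OneHot(depth=5)                   # on_value=1.0, off_value=0.0 by default
indices = Tensor(np.array([0, 2, 4]), mstype.int32)
print(one_hot(indices).shape)                  # (3, 5), float32
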
@@ -350,7 +348,7 @@ class RelaPosEmbeddingsGenerator(nn.Cell):
         if self.use_one_hot_embeddings:
             flat_relative_positions_matrix = self.reshape(relative_positions_matrix_out, (-1,))
             one_hot_relative_positions_matrix = self.one_hot(
-                flat_relative_positions_matrix, self.vocab_size, self.on_value, self.off_value)
+                flat_relative_positions_matrix)
             embeddings = self.matmul(one_hot_relative_positions_matrix, self.embeddings_table)
             my_shape = self.shape(relative_positions_matrix_out) + (self.depth,)
             embeddings = self.reshape(embeddings, my_shape)
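
Note: this branch implements an embedding lookup as a matrix product: each one-hot row selects one row of the embedding table, equivalent to a gather. A self-contained sketch with made-up sizes (vocab_size=7 and depth=4 are illustrative, not the model's values):

import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor
from mindspore.ops import operations as P

vocab_size, depth = 7, 4                       # illustrative sizes only
table = Tensor(np.random.randn(vocab_size, depth).astype(np.float32))
one_hot = nn.OneHot(depth=vocab_size)
matmul = P.MatMul()

ids = Tensor(np.array([1, 3, 5]), mstype.int32)
embeddings = matmul(one_hot(ids), table)       # shape (3, 4): rows 1, 3, 5 of table
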
@@ -372,11 +370,9 @@ class SaturateCast(nn.Cell):
     def __init__(self, src_type=mstype.float32, dst_type=mstype.float32):
         super(SaturateCast, self).__init__()
         np_type = mstype.dtype_to_nptype(dst_type)
-        min_type = np.finfo(np_type).min
-        max_type = np.finfo(np_type).max
 
-        self.tensor_min_type = Tensor([min_type], dtype=src_type)
-        self.tensor_max_type = Tensor([max_type], dtype=src_type)
+        self.tensor_min_type = float(np.finfo(np_type).min)
+        self.tensor_max_type = float(np.finfo(np_type).max)
 
         self.min_op = P.Minimum()
         self.max_op = P.Maximum()
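
Note: SaturateCast clamps values into the destination type's representable range before casting, so an out-of-range float32 saturates instead of overflowing to inf in float16. With the bounds now plain floats, the pattern looks roughly like this sketch (not the class itself):

import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Tensor
from mindspore.ops import operations as P

np_type = mstype.dtype_to_nptype(mstype.float16)
lo = float(np.finfo(np_type).min)              # -65504.0 for float16
hi = float(np.finfo(np_type).max)              # +65504.0

x = Tensor(np.array([-1e6, 0.0, 1e6], np.float32))
clipped = P.Maximum()(P.Minimum()(x, hi), lo)  # saturate first
y = P.Cast()(clipped, mstype.float16)          # [-65504., 0., 65504.], no inf
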
@@ -442,7 +438,7 @@ class BertAttention(nn.Cell):
         self.has_attention_mask = has_attention_mask
         self.use_relative_positions = use_relative_positions
 
-        self.scores_mul = Tensor([1.0 / math.sqrt(float(self.size_per_head))], dtype=compute_type)
+        self.scores_mul = 1.0 / math.sqrt(float(self.size_per_head))
         self.reshape = P.Reshape()
         self.shape_from_2d = (-1, from_tensor_width)
         self.shape_to_2d = (-1, to_tensor_width)
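
Note: scores_mul is the standard 1/sqrt(d) attention scaling, and as a plain float it can multiply the score tensor directly inside construct. A quick check of the arithmetic, assuming a typical BERT-base head size of 64 (not stated in this diff):

import math

size_per_head = 64                             # e.g. 768 hidden units / 12 heads
scores_mul = 1.0 / math.sqrt(float(size_per_head))
assert scores_mul == 0.125
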
@@ -471,7 +467,7 @@ class BertAttention(nn.Cell):
         self.trans_shape = (0, 2, 1, 3)
         self.trans_shape_relative = (2, 0, 1, 3)
         self.trans_shape_position = (1, 2, 0, 3)
-        self.multiply_data = Tensor([-10000.0,], dtype=compute_type)
+        self.multiply_data = -10000.0
         self.batch_num = batch_size * num_attention_heads
         self.matmul = P.BatchMatMul()
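
Note: multiply_data is the additive attention-mask constant: padded positions receive a score near -10000, which softmax maps to a probability of roughly zero. A numpy-only illustration of the effect:

import numpy as np

mask = np.array([1.0, 1.0, 0.0], np.float32)   # 0 marks a padded position
scores = np.array([2.0, 1.0, 3.0], np.float32)
masked = scores + (1.0 - mask) * -10000.0
probs = np.exp(masked) / np.exp(masked).sum()
print(probs.round(3))                          # [0.731 0.269 0.   ]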