@@ -32,7 +32,6 @@ class BertConfig:
     Configuration for `BertModel`.
 
     Args:
-        batch_size (int): Batch size of input dataset.
         seq_length (int): Length of input sequence. Default: 128.
         vocab_size (int): The shape of each embedding vector. Default: 32000.
         hidden_size (int): Size of the bert encoder layers. Default: 768.
@@ -52,15 +51,10 @@ class BertConfig:
         type_vocab_size (int): Size of token type vocab. Default: 16.
         initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
         use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
-        input_mask_from_dataset (bool): Specifies whether to use the input mask that loaded from
-            dataset. Default: True.
-        token_type_ids_from_dataset (bool): Specifies whether to use the token type ids that loaded
-            from dataset. Default: True.
         dtype (:class:`mindspore.dtype`): Data type of the input. Default: mstype.float32.
         compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32.
     """
     def __init__(self,
-                 batch_size,
                  seq_length=128,
                  vocab_size=32000,
                  hidden_size=768,
@@ -74,11 +68,8 @@ class BertConfig:
                  type_vocab_size=16,
                  initializer_range=0.02,
                  use_relative_positions=False,
-                 input_mask_from_dataset=True,
-                 token_type_ids_from_dataset=True,
                  dtype=mstype.float32,
                  compute_type=mstype.float32):
-        self.batch_size = batch_size
         self.seq_length = seq_length
         self.vocab_size = vocab_size
         self.hidden_size = hidden_size
@@ -91,8 +82,6 @@ class BertConfig:
         self.max_position_embeddings = max_position_embeddings
         self.type_vocab_size = type_vocab_size
         self.initializer_range = initializer_range
-        self.input_mask_from_dataset = input_mask_from_dataset
-        self.token_type_ids_from_dataset = token_type_ids_from_dataset
         self.use_relative_positions = use_relative_positions
         self.dtype = dtype
         self.compute_type = compute_type
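
Taken together, these BertConfig hunks drop batch_size, input_mask_from_dataset and
token_type_ids_from_dataset from the configuration entirely: the batch dimension is now
inferred from the input tensors at run time, and the masks/token-type ids always come
from the dataset. A minimal sketch of constructing a config against the new signature
(argument values are illustrative, not taken from this patch):

    import mindspore.common.dtype as mstype

    # No batch_size argument any more; the config only pins static shapes.
    config = BertConfig(seq_length=128,
                        vocab_size=32000,
                        hidden_size=768,
                        initializer_range=0.02,
                        use_relative_positions=False,
                        dtype=mstype.float32,
                        compute_type=mstype.float16)
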
@@ -390,7 +379,6 @@ class BertAttention(nn.Cell):
     Apply multi-headed attention from "from_tensor" to "to_tensor".
 
     Args:
-        batch_size (int): Batch size of input datasets.
         from_tensor_width (int): Size of last dim of from_tensor.
         to_tensor_width (int): Size of last dim of to_tensor.
         from_seq_length (int): Length of from_tensor sequence.
@@ -411,7 +399,6 @@ class BertAttention(nn.Cell):
         compute_type (:class:`mindspore.dtype`): Compute type in BertAttention. Default: mstype.float32.
     """
     def __init__(self,
-                 batch_size,
                  from_tensor_width,
                  to_tensor_width,
                  from_seq_length,
@@ -429,7 +416,6 @@ class BertAttention(nn.Cell):
                  use_relative_positions=False,
                  compute_type=mstype.float32):
         super(BertAttention, self).__init__()
-        self.batch_size = batch_size
         self.from_seq_length = from_seq_length
         self.to_seq_length = to_seq_length
         self.num_attention_heads = num_attention_heads
@@ -454,9 +440,8 @@ class BertAttention(nn.Cell):
                                     units,
                                     activation=value_act,
                                     weight_init=weight).to_float(compute_type)
-        self.shape_from = (batch_size, from_seq_length, num_attention_heads, size_per_head)
-        self.shape_to = (
-            batch_size, to_seq_length, num_attention_heads, size_per_head)
+        self.shape_from = (-1, from_seq_length, num_attention_heads, size_per_head)
+        self.shape_to = (-1, to_seq_length, num_attention_heads, size_per_head)
         self.matmul_trans_b = P.BatchMatMul(transpose_b=True)
         self.multiply = P.Mul()
         self.transpose = P.Transpose()
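
The rewritten shape_from and shape_to rely on P.Reshape inferring the leading dimension:
exactly one entry of the target shape may be -1, and it absorbs whatever batch size the
input actually has. A self-contained sketch of that behaviour (shapes are illustrative,
PyNative mode assumed):

    import numpy as np
    import mindspore as ms
    from mindspore.ops import operations as P

    reshape = P.Reshape()
    # from_tensor reaches the attention cell flattened as (batch * seq, width);
    # the batch size (7 here) is unknown to the cell itself.
    from_tensor_2d = ms.Tensor(np.zeros((7 * 128, 768), np.float32))
    # (-1, from_seq_length, num_attention_heads, size_per_head): -1 resolves to 7.
    query = reshape(from_tensor_2d, (-1, 128, 12, 64))
    print(query.shape)  # (7, 128, 12, 64)
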
@@ -464,7 +449,6 @@ class BertAttention(nn.Cell):
         self.trans_shape_relative = (2, 0, 1, 3)
         self.trans_shape_position = (1, 2, 0, 3)
         self.multiply_data = Tensor([-10000.0,], dtype=compute_type)
-        self.batch_num = batch_size * num_attention_heads
         self.matmul = P.BatchMatMul()
         self.softmax = nn.Softmax()
         self.dropout = nn.Dropout(1 - attention_probs_dropout_prob)
@@ -475,9 +459,9 @@ class BertAttention(nn.Cell):
         self.cast = P.Cast()
         self.get_dtype = P.DType()
         if do_return_2d_tensor:
-            self.shape_return = (batch_size * from_seq_length, num_attention_heads * size_per_head)
+            self.shape_return = (-1, num_attention_heads * size_per_head)
         else:
-            self.shape_return = (batch_size, from_seq_length, num_attention_heads * size_per_head)
+            self.shape_return = (-1, from_seq_length, num_attention_heads * size_per_head)
         self.cast_compute_type = SaturateCast(dst_type=compute_type)
         if self.use_relative_positions:
             self._generate_relative_positions_embeddings = \
@@ -510,7 +494,7 @@ class BertAttention(nn.Cell):
             # query_layer_r is [F, B * N, H]
             query_layer_r = self.reshape(query_layer_t,
                                          (self.from_seq_length,
-                                          self.batch_num,
+                                          -1,
                                           self.size_per_head))
             # key_position_scores is [F, B * N, F|T]
             key_position_scores = self.matmul_trans_b(query_layer_r,
@@ -518,7 +502,7 @@ class BertAttention(nn.Cell):
             # key_position_scores_r is [F, B, N, F|T]
             key_position_scores_r = self.reshape(key_position_scores,
                                                  (self.from_seq_length,
-                                                  self.batch_size,
+                                                  -1,
                                                   self.num_attention_heads,
                                                   self.from_seq_length))
             # key_position_scores_r_t is [B, N, F, F|T]
@@ -548,7 +532,7 @@ class BertAttention(nn.Cell):
             attention_probs_r = self.reshape(
                 attention_probs_t,
                 (self.from_seq_length,
-                 self.batch_num,
+                 -1,
                  self.to_seq_length))
             # value_position_scores is [F, B * N, H]
             value_position_scores = self.matmul(attention_probs_r,
@@ -556,7 +540,7 @@ class BertAttention(nn.Cell):
             # value_position_scores_r is [F, B, N, H]
             value_position_scores_r = self.reshape(value_position_scores,
                                                    (self.from_seq_length,
-                                                    self.batch_size,
+                                                    -1,
                                                     self.num_attention_heads,
                                                     self.size_per_head))
             # value_position_scores_r_t is [B, N, F, H]
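
In the relative-position branches above, the axis that used to be spelled out as
self.batch_num (batch_size * num_attention_heads) or self.batch_size is exactly the one
handed to -1, since from_seq_length, num_attention_heads and size_per_head stay static.
An illustrative check of the [F, B * N, H] -> [F, B, N, H] step (values made up):

    import numpy as np
    import mindspore as ms
    from mindspore.ops import operations as P

    F, B, N, H = 128, 7, 12, 64
    scores = ms.Tensor(np.zeros((F, B * N, H), np.float32))  # [F, B*N, H]
    # Only the batch-dependent axis is left for Reshape to infer.
    scores_r = P.Reshape()(scores, (F, -1, N, H))            # -> [F, B, N, H]
    print(scores_r.shape)  # (128, 7, 12, 64)
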
@@ -572,7 +556,6 @@ class BertSelfAttention(nn.Cell):
     Apply self-attention.
 
     Args:
-        batch_size (int): Batch size of input dataset.
         seq_length (int): Length of input sequence.
         hidden_size (int): Size of the bert encoder layers.
         num_attention_heads (int): Number of attention heads. Default: 12.
@@ -585,7 +568,6 @@ class BertSelfAttention(nn.Cell):
         compute_type (:class:`mindspore.dtype`): Compute type in BertSelfAttention. Default: mstype.float32.
     """
     def __init__(self,
-                 batch_size,
                  seq_length,
                  hidden_size,
                  num_attention_heads=12,
@@ -601,7 +583,6 @@ class BertSelfAttention(nn.Cell):
                              "of attention heads (%d)" % (hidden_size, num_attention_heads))
         self.size_per_head = int(hidden_size / num_attention_heads)
         self.attention = BertAttention(
-            batch_size=batch_size,
             from_tensor_width=hidden_size,
             to_tensor_width=hidden_size,
             from_seq_length=seq_length,
@@ -636,7 +617,6 @@ class BertEncoderCell(nn.Cell):
     Encoder cells used in BertTransformer.
 
     Args:
-        batch_size (int): Batch size of input dataset.
         hidden_size (int): Size of the bert encoder layers. Default: 768.
         seq_length (int): Length of input sequence. Default: 512.
         num_attention_heads (int): Number of attention heads. Default: 12.
@@ -651,7 +631,6 @@ class BertEncoderCell(nn.Cell):
         compute_type (:class:`mindspore.dtype`): Compute type in attention. Default: mstype.float32.
     """
     def __init__(self,
-                 batch_size,
                  hidden_size=768,
                  seq_length=512,
                  num_attention_heads=12,
@@ -665,7 +644,6 @@ class BertEncoderCell(nn.Cell):
                  compute_type=mstype.float32):
         super(BertEncoderCell, self).__init__()
         self.attention = BertSelfAttention(
-            batch_size=batch_size,
             hidden_size=hidden_size,
             seq_length=seq_length,
             num_attention_heads=num_attention_heads,
@@ -700,7 +678,6 @@ class BertTransformer(nn.Cell):
     Multi-layer bert transformer.
 
     Args:
-        batch_size (int): Batch size of input dataset.
         hidden_size (int): Size of the encoder layers.
         seq_length (int): Length of input sequence.
         num_hidden_layers (int): Number of hidden layers in encoder cells.
@@ -717,7 +694,6 @@ class BertTransformer(nn.Cell):
         return_all_encoders (bool): Specifies whether to return all encoders. Default: False.
     """
     def __init__(self,
-                 batch_size,
                  hidden_size,
                  seq_length,
                  num_hidden_layers,
@@ -735,8 +711,7 @@ class BertTransformer(nn.Cell):
         self.return_all_encoders = return_all_encoders
         layers = []
         for _ in range(num_hidden_layers):
-            layer = BertEncoderCell(batch_size=batch_size,
-                                    hidden_size=hidden_size,
+            layer = BertEncoderCell(hidden_size=hidden_size,
                                     seq_length=seq_length,
                                     num_attention_heads=num_attention_heads,
                                     intermediate_size=intermediate_size,
@@ -751,7 +726,7 @@ class BertTransformer(nn.Cell):
         self.layers = nn.CellList(layers)
         self.reshape = P.Reshape()
         self.shape = (-1, hidden_size)
-        self.out_shape = (batch_size, seq_length, hidden_size)
+        self.out_shape = (-1, seq_length, hidden_size)
     def construct(self, input_tensor, attention_mask):
         """bert transformer"""
         prev_output = self.reshape(input_tensor, self.shape)
@@ -782,22 +757,13 @@ class CreateAttentionMaskFromInputMask(nn.Cell):
     """
     def __init__(self, config):
         super(CreateAttentionMaskFromInputMask, self).__init__()
-        self.input_mask_from_dataset = config.input_mask_from_dataset
         self.input_mask = None
-        if not self.input_mask_from_dataset:
-            self.input_mask = initializer(
-                "ones", [config.batch_size, config.seq_length], mstype.int32).to_tensor()
         self.cast = P.Cast()
         self.reshape = P.Reshape()
-        self.shape = (config.batch_size, 1, config.seq_length)
-        self.broadcast_ones = initializer(
-            "ones", [config.batch_size, config.seq_length, 1], mstype.float32).to_tensor()
-        self.batch_matmul = P.BatchMatMul()
+        self.shape = (-1, 1, config.seq_length)
+
     def construct(self, input_mask):
-        if not self.input_mask_from_dataset:
-            input_mask = self.input_mask
-        input_mask = self.cast(self.reshape(input_mask, self.shape), mstype.float32)
-        attention_mask = self.batch_matmul(self.broadcast_ones, input_mask)
+        attention_mask = self.cast(self.reshape(input_mask, self.shape), mstype.float32)
         return attention_mask
 
 class BertModel(nn.Cell):
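
CreateAttentionMaskFromInputMask is reduced to a reshape plus cast: the
(batch_size, seq_length) input mask becomes a (batch_size, 1, seq_length) float mask,
and the BatchMatMul against a fixed ones tensor disappears, presumably leaving the
expansion over the query axis to broadcasting downstream. A sketch of the surviving
path, assuming a 0/1 mask coming straight from the dataset:

    import numpy as np
    import mindspore as ms
    import mindspore.common.dtype as mstype
    from mindspore.ops import operations as P

    seq_length = 128
    input_mask = ms.Tensor(np.ones((7, seq_length), np.int32))  # (batch, seq)
    attention_mask = P.Cast()(P.Reshape()(input_mask, (-1, 1, seq_length)),
                              mstype.float32)
    print(attention_mask.shape)  # (7, 1, 128)
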
@@ -818,20 +784,14 @@ class BertModel(nn.Cell):
         if not is_training:
             config.hidden_dropout_prob = 0.0
             config.attention_probs_dropout_prob = 0.0
-        self.input_mask_from_dataset = config.input_mask_from_dataset
-        self.token_type_ids_from_dataset = config.token_type_ids_from_dataset
-        self.batch_size = config.batch_size
         self.seq_length = config.seq_length
         self.hidden_size = config.hidden_size
         self.num_hidden_layers = config.num_hidden_layers
         self.embedding_size = config.hidden_size
         self.token_type_ids = None
         self.last_idx = self.num_hidden_layers - 1
-        output_embedding_shape = [self.batch_size, self.seq_length,
+        output_embedding_shape = [-1, self.seq_length,
                                   self.embedding_size]
-        if not self.token_type_ids_from_dataset:
-            self.token_type_ids = initializer(
-                "zeros", [self.batch_size, self.seq_length], mstype.int32).to_tensor()
         self.bert_embedding_lookup = EmbeddingLookup(
             vocab_size=config.vocab_size,
             embedding_size=self.embedding_size,
@@ -849,7 +809,6 @@ class BertModel(nn.Cell):
             max_position_embeddings=config.max_position_embeddings,
             dropout_prob=config.hidden_dropout_prob)
         self.bert_encoder = BertTransformer(
-            batch_size=self.batch_size,
             hidden_size=self.hidden_size,
             seq_length=self.seq_length,
             num_attention_heads=config.num_attention_heads,
@@ -876,8 +835,6 @@ class BertModel(nn.Cell):
     def construct(self, input_ids, token_type_ids, input_mask):
         """bert model"""
         # embedding
-        if not self.token_type_ids_from_dataset:
-            token_type_ids = self.token_type_ids
         word_embeddings, embedding_tables = self.bert_embedding_lookup(input_ids)
         embedding_output = self.bert_embedding_postprocessor(token_type_ids, word_embeddings)
         # attention mask [batch_size, seq_length, seq_length]
@@ -889,7 +846,7 @@ class BertModel(nn.Cell):
         # pooler
         sequence_slice = self.slice(sequence_output,
                                     (0, 0, 0),
-                                    (self.batch_size, 1, self.hidden_size),
+                                    (-1, 1, self.hidden_size),
                                     (1, 1, 1))
         first_token = self.squeeze_1(sequence_slice)
         pooled_output = self.dense(first_token)
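
The pooler hunks (here and in TinyBertModel below) apply the same pattern to the
StridedSlice bounds: the end index on the batch axis is no longer the static batch size.
The slice plus squeeze_1 is meant to keep only the first ([CLS]) token of every sequence
before the dense projection, i.e. roughly (plain-index equivalent, illustrative only):

    import numpy as np

    # sequence_output: (batch, seq_length, hidden_size)
    sequence_output = np.zeros((7, 128, 768), np.float32)
    first_token = sequence_output[:, 0, :]  # (7, 768), fed to the pooler dense layer
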
@@ -921,20 +878,14 @@ class TinyBertModel(nn.Cell):
         if not is_training:
             config.hidden_dropout_prob = 0.0
             config.attention_probs_dropout_prob = 0.0
-        self.input_mask_from_dataset = config.input_mask_from_dataset
-        self.token_type_ids_from_dataset = config.token_type_ids_from_dataset
-        self.batch_size = config.batch_size
         self.seq_length = config.seq_length
         self.hidden_size = config.hidden_size
         self.num_hidden_layers = config.num_hidden_layers
         self.embedding_size = config.hidden_size
         self.token_type_ids = None
         self.last_idx = self.num_hidden_layers - 1
-        output_embedding_shape = [self.batch_size, self.seq_length,
+        output_embedding_shape = [-1, self.seq_length,
                                   self.embedding_size]
-        if not self.token_type_ids_from_dataset:
-            self.token_type_ids = initializer(
-                "zeros", [self.batch_size, self.seq_length], mstype.int32).to_tensor()
         self.tinybert_embedding_lookup = EmbeddingLookup(
             vocab_size=config.vocab_size,
             embedding_size=self.embedding_size,
@@ -952,7 +903,6 @@ class TinyBertModel(nn.Cell):
             max_position_embeddings=config.max_position_embeddings,
             dropout_prob=config.hidden_dropout_prob)
         self.tinybert_encoder = BertTransformer(
-            batch_size=self.batch_size,
             hidden_size=self.hidden_size,
             seq_length=self.seq_length,
             num_attention_heads=config.num_attention_heads,
@@ -979,8 +929,6 @@ class TinyBertModel(nn.Cell):
     def construct(self, input_ids, token_type_ids, input_mask):
         """tiny bert model"""
         # embedding
-        if not self.token_type_ids_from_dataset:
-            token_type_ids = self.token_type_ids
         word_embeddings, embedding_tables = self.tinybert_embedding_lookup(input_ids)
         embedding_output = self.tinybert_embedding_postprocessor(token_type_ids,
                                                                  word_embeddings)
@@ -993,7 +941,7 @@ class TinyBertModel(nn.Cell):
         # pooler
         sequence_slice = self.slice(sequence_output,
                                     (0, 0, 0),
-                                    (self.batch_size, 1, self.hidden_size),
+                                    (-1, 1, self.hidden_size),
                                     (1, 1, 1))
         first_token = self.squeeze_1(sequence_slice)
         pooled_output = self.dense(first_token)