|
|
|
@@ -29,8 +29,9 @@ from mindspore.ops import functional as F |
|
|
|
from mindspore.nn.cell import Cell |
|
|
|
from mindspore._checkparam import Validator |
|
|
|
from mindspore import log as logger |
|
|
|
from .layers import _LayerNorm, _Linear, _check_input_shape,\ |
|
|
|
_check_shape_equal, _check_past_none_input_none, _check_input_dtype, _check_input_shape_value |
|
|
|
from .layers import _LayerNorm, _Linear, _check_input_shape, \ |
|
|
|
_check_shape_equal, _check_past_none_input_none, _check_input_dtype, _check_input_shape_value, \ |
|
|
|
_args_type_validator_check, _valid_type_checks, _valid_value_checks |
|
|
|
from .op_parallel_config import default_dpmp_config, _PipeLineConfig, OpParallelConfig, _Config, _check_config |
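# The validator helpers imported above (_args_type_validator_check,
# _valid_type_checks, _valid_value_checks) let each Cell check its constructor
# arguments through a decorator instead of repeated inline Validator calls.
# A minimal sketch of that pattern, assuming nothing about the real helpers
# beyond the call shape used later in this file; every name below is a
# hypothetical stand-in, not the actual implementation.
from functools import wraps


def _args_validator_sketch(**validators):
    """Hypothetical decorator factory: run one validator per keyword argument."""
    def decorator(init_fn):
        @wraps(init_fn)
        def wrapper(self, *args, **kwargs):
            for name, validate in validators.items():
                if name in kwargs:
                    validate(kwargs[name], name)
            return init_fn(self, *args, **kwargs)
        return wrapper
    return decorator


def _check_positive_int_sketch(value, name):
    """Hypothetical validator in the spirit of Validator.check_positive_int."""
    if not isinstance(value, int) or isinstance(value, bool) or value <= 0:
        raise ValueError(f"{name} must be a positive int, but got {value!r}")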
|
|
|
|
|
|
|
__all__ = [ |
|
|
|
@@ -292,6 +293,14 @@ class FeedForward(Cell): |
|
|
|
(2, 20, 15) |
|
|
|
""" |
|
|
|
|
|
|
|
@_args_type_validator_check(hidden_size=Validator.check_positive_int, |
|
|
|
ffn_hidden_size=Validator.check_positive_int, |
|
|
|
dropout_rate=Validator.check_non_negative_float, |
|
|
|
hidden_act=_valid_type_checks([str], "FeedForward"), |
|
|
|
param_init_type=_valid_value_checks([mstype.float32, mstype.float16], |
|
|
|
"FeedForward"), |
|
|
|
parallel_config=_valid_type_checks([OpParallelConfig], |
|
|
|
"FeedForward")) |
|
|
|
def __init__(self, hidden_size, |
|
|
|
ffn_hidden_size, |
|
|
|
dropout_rate, |
|
|
|
@@ -300,13 +309,6 @@ class FeedForward(Cell): |
|
|
|
parallel_config=default_dpmp_config): |
|
|
|
super(FeedForward, self).__init__() |
|
|
|
_check_config(parallel_config) |
|
|
|
Validator.check_positive_int(hidden_size, "hidden_size") |
|
|
|
Validator.check_positive_int(ffn_hidden_size, "ffn_hidden_size") |
|
|
|
if not isinstance(hidden_act, str): |
|
|
|
raise ValueError(f"The hidden_act should be a str type, but found {type(hidden_act)}") |
|
|
|
if not isinstance(parallel_config, OpParallelConfig): |
|
|
|
raise ValueError( |
|
|
|
f"The parallel_config should be a OpParallelConfig type, but found {type(parallel_config)}") |
|
|
|
dp = parallel_config.data_parallel |
|
|
|
mp = parallel_config.model_parallel |
|
|
|
if ffn_hidden_size % mp != 0: |
|
|
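# Hedged usage sketch for the FeedForward cell above. The concrete sizes are
# illustrative and only mirror the (2, 20, 15) example shape in the truncated
# docstring; the helper name is hypothetical and the sketch is wrapped in a
# function so it does not run at import time. The block maps hidden_size ->
# ffn_hidden_size -> hidden_size, so the output keeps the input shape.
def _feed_forward_usage_sketch():
    import numpy as np
    from mindspore import Tensor
    from mindspore.common import dtype as mstype

    ffn = FeedForward(hidden_size=15, ffn_hidden_size=30, dropout_rate=0.1)
    x = Tensor(np.ones((2, 20, 15)), mstype.float32)
    return ffn(x)  # expected shape: (2, 20, 15)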
|
@@ -388,12 +390,10 @@ class AttentionMask(Cell): |
|
|
|
[0, 0, 0, 0]]]) |
|
|
|
""" |
|
|
|
|
|
|
|
@_args_type_validator_check(seq_length=Validator.check_positive_int, |
|
|
|
parallel_config=_valid_type_checks([OpParallelConfig], "AttentionMask")) |
|
|
|
def __init__(self, seq_length, parallel_config=default_dpmp_config): |
|
|
|
super(AttentionMask, self).__init__() |
|
|
|
Validator.check_positive_int(seq_length, "seq_length") |
|
|
|
if not isinstance(parallel_config, OpParallelConfig): |
|
|
|
raise ValueError( |
|
|
|
f"The parallel_config should be a OpParallelConfig type, but found {type(parallel_config)}") |
|
|
|
self.seq_length = seq_length |
|
|
|
self.not_equal = P.NotEqual().shard(((parallel_config.data_parallel, 1), ())) |
|
|
|
self.reshape = P.Reshape() |
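# AttentionMask above turns a (batch, seq_length) padding mask into a
# (batch, seq_length, seq_length) causal attention mask; the truncated
# docstring example above ends with an all-zero row for a padded position.
# A NumPy sketch of the intended math (an illustration of the idea only,
# not the Cell's implementation):
def _causal_mask_sketch(input_mask):
    """input_mask: (batch, seq_length) array of 0/1 padding flags."""
    import numpy as np

    seq_length = input_mask.shape[1]
    # The outer product keeps only pairs where both query and key positions are valid.
    pairwise = input_mask[:, :, None] * input_mask[:, None, :]
    # The lower-triangular matrix forbids attending to future positions.
    causal = np.tril(np.ones((seq_length, seq_length), dtype=input_mask.dtype))
    return pairwise * causal
# For example, _causal_mask_sketch(np.array([[1, 1, 1, 0]])) yields the rows
# [1,0,0,0], [1,1,0,0], [1,1,1,0], [0,0,0,0] for the single batch element.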
|
|
|
@@ -470,15 +470,13 @@ class VocabEmbedding(Cell): |
|
|
|
(30, 30) |
|
|
|
""" |
|
|
|
|
|
|
|
@_args_type_validator_check(vocab_size=Validator.check_positive_int, |
|
|
|
embedding_size=Validator.check_positive_int, |
|
|
|
parallel_config=_valid_type_checks([EmbeddingOpParallelConfig], "VocabEmbedding")) |
|
|
|
def __init__(self, vocab_size, embedding_size, parallel_config=default_embedding_parallel_config, |
|
|
|
param_init='normal'): |
|
|
|
super(VocabEmbedding, self).__init__() |
|
|
|
_check_config(parallel_config) |
|
|
|
Validator.check_positive_int(vocab_size, "vocab_size") |
|
|
|
Validator.check_positive_int(embedding_size, "embedding_size") |
|
|
|
if not isinstance(parallel_config, EmbeddingOpParallelConfig): |
|
|
|
raise ValueError(f"The parallel_config should be a VocabEmbedding type, but found {type(parallel_config)}") |
|
|
|
|
|
|
|
self.vocab_size = vocab_size |
|
|
|
self.embedding_size = embedding_size |
|
|
|
self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size]), |
|
|
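# Hedged usage sketch for VocabEmbedding. The values are illustrative: the
# (30, 30) shape in the truncated docstring above is the embedding_table shape
# for vocab_size=30, embedding_size=30. The helper name is hypothetical, and
# the return structure (lookup result together with the embedding table) is an
# assumption drawn from that example rather than a documented guarantee here.
def _vocab_embedding_usage_sketch():
    import numpy as np
    from mindspore import Tensor
    from mindspore.common import dtype as mstype

    embed = VocabEmbedding(vocab_size=30, embedding_size=30)
    # A (batch, seq_length) tensor of ids becomes (batch, seq_length, embedding_size).
    ids = Tensor(np.random.randint(0, 30, size=(2, 20)), mstype.int32)
    return embed(ids)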
|
@@ -564,6 +562,20 @@ class MultiHeadAttention(Cell): |
|
|
|
(2, 3, 20, 5) |
|
|
|
""" |
|
|
|
|
|
|
|
@_args_type_validator_check(batch_size=Validator.check_positive_int, |
|
|
|
hidden_size=Validator.check_positive_int, |
|
|
|
num_heads=Validator.check_positive_int, |
|
|
|
src_seq_length=Validator.check_positive_int, |
|
|
|
tgt_seq_length=Validator.check_positive_int, |
|
|
|
attention_dropout_rate=Validator.check_non_negative_float, |
|
|
|
hidden_dropout_rate=Validator.check_non_negative_float, |
|
|
|
softmax_comptue_type=_valid_value_checks([mstype.float32, mstype.float16], |
|
|
|
"MultiHeadAttention"), |
|
|
|
param_init_type=_valid_value_checks([mstype.float32, mstype.float16], |
|
|
|
"MultiHeadAttention"), |
|
|
|
parallel_config=_valid_type_checks([OpParallelConfig], |
|
|
|
"MultiHeadAttention"), |
|
|
|
use_past=Validator.check_bool) |
|
|
|
def __init__(self, batch_size, |
|
|
|
src_seq_length, |
|
|
|
tgt_seq_length, |
|
|
|
@@ -582,11 +594,6 @@ class MultiHeadAttention(Cell): |
|
|
|
self.tgt_seq_length = tgt_seq_length |
|
|
|
self.hidden_size = hidden_size |
|
|
|
self.batch_size = batch_size |
|
|
|
Validator.check_positive_int(num_heads, "num_heads") |
|
|
|
Validator.check_positive_int(batch_size, "batch_size") |
|
|
|
Validator.check_positive_int(src_seq_length, "src_seq_length") |
|
|
|
Validator.check_positive_int(tgt_seq_length, "tgt_seq_length") |
|
|
|
Validator.check_positive_int(num_heads, "num_heads") |
|
|
|
if hidden_dropout_rate < 0 or hidden_dropout_rate >= 1: |
|
|
|
raise ValueError("hidden_dropout_rate probability should be a number in range [0, 1.0), " |
|
|
|
"but got {}".format(hidden_dropout_rate)) |
|
|
|
@@ -787,11 +794,13 @@ class MultiHeadAttention(Cell): |
|
|
|
_check_input_dtype(F.dtype(query_tensor), "query_tensor", [mstype.float32, mstype.float16], self.cls_name) |
|
|
|
_check_input_dtype(F.dtype(key_tensor), "key_tensor", [mstype.float32, mstype.float16], self.cls_name) |
|
|
|
_check_input_dtype(F.dtype(value_tensor), "value_tensor", [mstype.float32, mstype.float16], self.cls_name) |
|
|
|
_check_input_dtype(F.dtype(attention_mask), "attention_mask", [mstype.float32, mstype.float16], self.cls_name) |
|
|
|
|
|
|
|
_check_past_none_input_none(self.use_past, "key_past", self.cls_name, key_past) |
|
|
|
_check_past_none_input_none(self.use_past, "value_past", self.cls_name, value_past) |
|
|
|
_check_past_none_input_none(self.use_past, "batch_valid_length", self.cls_name, batch_valid_length) |
|
|
|
return True |
|
|
|
|
|
|
|
def _merge_heads(self, x): |
|
|
|
""" |
|
|
|
convert a 4d input to a 3d output |
|
|
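# _merge_heads above folds the per-head dimension back into the hidden
# dimension ("convert a 4d input to a 3d output"). A NumPy sketch of that
# transpose-and-reshape (an illustration of the idea, not the method's code):
def _merge_heads_sketch(x):
    """x: (batch, num_heads, seq_length, size_per_head) -> (batch, seq_length, hidden)."""
    import numpy as np

    batch, num_heads, seq_length, size_per_head = x.shape
    # Put the sequence axis before the head axis, then fuse heads into hidden.
    x = np.transpose(x, (0, 2, 1, 3))
    return np.reshape(x, (batch, seq_length, num_heads * size_per_head))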
|
@@ -932,6 +941,24 @@ class TransformerEncoderLayer(Cell): |
|
|
|
(2, 2, 16, 4) |
|
|
|
""" |
|
|
|
|
|
|
|
@_args_type_validator_check(batch_size=Validator.check_positive_int, |
|
|
|
hidden_size=Validator.check_positive_int, |
|
|
|
num_heads=Validator.check_positive_int, |
|
|
|
ffn_hidden_size=Validator.check_positive_int, |
|
|
|
seq_length=Validator.check_positive_int, |
|
|
|
attention_dropout_rate=Validator.check_non_negative_float, |
|
|
|
hidden_dropout_rate=Validator.check_non_negative_float, |
|
|
|
hidden_act=_valid_type_checks([str], "TransformerEncoderLayer"), |
|
|
|
post_layernorm_residual=Validator.check_bool, |
|
|
|
layernorm_compute_type=_valid_value_checks([mstype.float32, mstype.float16], |
|
|
|
"TransformerEncoderLayer"), |
|
|
|
softmax_comptue_type=_valid_value_checks([mstype.float32, mstype.float16], |
|
|
|
"TransformerEncoderLayer"), |
|
|
|
param_init_type=_valid_value_checks([mstype.float32, mstype.float16], |
|
|
|
"TransformerEncoderLayer"), |
|
|
|
parallel_config=_valid_type_checks([OpParallelConfig], |
|
|
|
"TransformerEncoderLayer"), |
|
|
|
use_past=Validator.check_bool) |
|
|
|
def __init__(self, |
|
|
|
batch_size, |
|
|
|
hidden_size, |
|
|
|
@@ -953,9 +980,6 @@ class TransformerEncoderLayer(Cell): |
|
|
|
raise ValueError( |
|
|
|
f"num heads must be divisibled by the model parallel way {parallel_config.model_parallel}," |
|
|
|
f"but found {num_heads}") |
|
|
|
Validator.check_bool(post_layernorm_residual, "post_layernorm_residual") |
|
|
|
if not isinstance(hidden_act, str): |
|
|
|
raise ValueError(f"The hidden_act should be a str type, but found {type(hidden_act)}") |
|
|
|
self.use_past = use_past |
|
|
|
self.seq_length = seq_length |
|
|
|
self.hidden_size = hidden_size |
|
|
|
@@ -986,6 +1010,8 @@ class TransformerEncoderLayer(Cell): |
|
|
|
self.post_layernorm_residual = post_layernorm_residual |
|
|
|
self.add = P.TensorAdd().shard(((parallel_config.data_parallel, 1, 1), (parallel_config.data_parallel, 1, 1))) |
|
|
|
self.dtype = mstype.float16 |
|
|
|
self.key_past = None |
|
|
|
self.value_past = None |
|
|
|
|
|
|
|
if self.use_past: |
|
|
|
# operator used for state reuse |
|
|
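# post_layernorm_residual above selects where the residual connection starts.
# A schematic sketch of the two orderings (hedged: simplified to a single
# sublayer; the real layer also applies dropout and a second LayerNorm/FFN,
# and the helper name is hypothetical):
def _residual_order_sketch(x, sublayer, layer_norm, post_layernorm_residual):
    normed = layer_norm(x)
    branch = sublayer(normed)
    if post_layernorm_residual:
        # Residual taken from the normalized tensor.
        return normed + branch
    # Default: residual taken from the raw input (pre-LayerNorm style).
    return x + branch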
|
@@ -1058,8 +1084,10 @@ class TransformerEncoderLayer(Cell): |
|
|
|
r"""Check inputs""" |
|
|
|
_check_shape_equal(F.shape(x), "x", self.cls_name, |
|
|
|
[self.batch_size, self.seq_length, self.hidden_size]) |
|
|
|
_check_input_dtype(F.dtype(x), "x", [mstype.float32, mstype.float16], self.cls_name) |
|
|
|
_check_shape_equal(F.shape(input_mask), "input_mask", self.cls_name, |
|
|
|
[self.batch_size, self.seq_length, self.seq_length]) |
|
|
|
_check_input_dtype(F.dtype(input_mask), "input_mask", [mstype.float32, mstype.float16], self.cls_name) |
|
|
|
_check_past_none_input_none(self.use_past, "init_reset", self.cls_name, init_reset, True) |
|
|
|
if init_reset is not True: |
|
|
|
_check_input_dtype(F.dtype(init_reset), "init_reset", [mstype.bool_], self.cls_name) |
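# Hedged usage sketch for TransformerEncoderLayer. The sizes are illustrative
# and only mirror the (2, 2, 16, 4) example shape in the truncated docstring
# (batch 2, 2 heads of size 4, seq_length 16, hence hidden_size 8); the input
# shapes follow the _check_inputs checks above. The helper name is
# hypothetical and the return structure is not asserted.
def _encoder_layer_usage_sketch():
    import numpy as np
    from mindspore import Tensor
    from mindspore.common import dtype as mstype

    layer = TransformerEncoderLayer(batch_size=2, hidden_size=8, num_heads=2,
                                    ffn_hidden_size=64, seq_length=16)
    x = Tensor(np.ones((2, 16, 8)), mstype.float32)
    input_mask = Tensor(np.ones((2, 16, 16)), mstype.float16)
    # Expected to return the layer output plus the cached key/value state.
    return layer(x, input_mask)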
|
|
|
@@ -1143,6 +1171,25 @@ class TransformerDecoderLayer(Cell): |
|
|
|
(2, 2, 20, 32) |
|
|
|
""" |
|
|
|
|
|
|
|
@_args_type_validator_check(batch_size=Validator.check_positive_int, |
|
|
|
hidden_size=Validator.check_positive_int, |
|
|
|
num_heads=Validator.check_positive_int, |
|
|
|
ffn_hidden_size=Validator.check_positive_int, |
|
|
|
src_seq_length=Validator.check_positive_int, |
|
|
|
tgt_seq_length=Validator.check_positive_int, |
|
|
|
attention_dropout_rate=Validator.check_non_negative_float, |
|
|
|
hidden_dropout_rate=Validator.check_non_negative_float, |
|
|
|
hidden_act=_valid_type_checks([str], "TransformerDecoderLayer"), |
|
|
|
post_layernorm_residual=Validator.check_bool, |
|
|
|
layernorm_compute_type=_valid_value_checks([mstype.float32, mstype.float16], |
|
|
|
"TransformerDecoderLayer"), |
|
|
|
softmax_comptue_type=_valid_value_checks([mstype.float32, mstype.float16], |
|
|
|
"TransformerDecoderLayer"), |
|
|
|
param_init_type=_valid_value_checks([mstype.float32, mstype.float16], |
|
|
|
"TransformerDecoderLayer"), |
|
|
|
parallel_config=_valid_type_checks([OpParallelConfig], |
|
|
|
"TransformerDecoderLayer"), |
|
|
|
use_past=Validator.check_bool) |
|
|
|
def __init__(self, hidden_size, |
|
|
|
ffn_hidden_size, |
|
|
|
num_heads, |
|
|
|
@@ -1163,13 +1210,7 @@ class TransformerDecoderLayer(Cell): |
|
|
|
self.batch_size = batch_size |
|
|
|
self.use_past = use_past |
|
|
|
self.softmax_comptue_type = softmax_comptue_type |
|
|
|
if num_heads % parallel_config.model_parallel != 0: |
|
|
|
raise ValueError( |
|
|
|
f"num heads must be divisibled by the model parallel way {parallel_config.model_parallel}," |
|
|
|
f"but found {num_heads}") |
|
|
|
Validator.check_bool(post_layernorm_residual, "post_layernorm_residual") |
|
|
|
if not isinstance(hidden_act, str): |
|
|
|
raise ValueError(f"The hidden_act should be a str type, but found {type(hidden_act)}") |
|
|
|
|
|
|
|
self.src_seq_length = src_seq_length |
|
|
|
self.tgt_seq_length = tgt_seq_length |
|
|
|
self.use_past = use_past |
|
|
|
@@ -1217,6 +1258,8 @@ class TransformerDecoderLayer(Cell): |
|
|
|
self.post_layernorm_residual = post_layernorm_residual |
|
|
|
self.add = P.TensorAdd().shard(((parallel_config.data_parallel, 1, 1), (parallel_config.data_parallel, 1, 1))) |
|
|
|
self.dtype = mstype.float16 |
|
|
|
self.key_past = None |
|
|
|
self.value_past = None |
|
|
|
if self.use_past: |
|
|
|
# operator used for state reuse |
|
|
|
self.reducesum = P.ReduceSum().shard(((1, 1, 1, 1),)) |
|
|
|
@@ -1308,12 +1351,18 @@ class TransformerDecoderLayer(Cell): |
|
|
|
[self.batch_size, self.tgt_seq_length, self.hidden_size]) |
|
|
|
_check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name, |
|
|
|
[self.batch_size, self.tgt_seq_length, self.tgt_seq_length]) |
|
|
|
_check_input_dtype(F.dtype(hidden_states), "hidden_size", [mstype.float32, mstype.float16], self.cls_name) |
|
|
|
_check_input_dtype(F.dtype(attention_mask), "attention_mask", [mstype.float32, mstype.float16], self.cls_name) |
|
|
|
if encoder_output is not None: |
|
|
|
_check_shape_equal(F.shape(encoder_output), "encoder_output", self.cls_name, |
|
|
|
[self.batch_size, self.src_seq_length, self.hidden_size]) |
|
|
|
_check_input_dtype(F.dtype(encoder_output), "encoder_output", |
|
|
|
[mstype.float32, mstype.float16], self.cls_name) |
|
|
|
if memory_mask is not None: |
|
|
|
_check_shape_equal(F.shape(memory_mask), "memory_mask", self.cls_name, |
|
|
|
[self.batch_size, self.tgt_seq_length, self.src_seq_length]) |
|
|
|
_check_input_dtype(F.dtype(memory_mask), "memory_mask", |
|
|
|
[mstype.float32, mstype.float16], self.cls_name) |
|
|
|
|
|
|
|
_check_past_none_input_none(self.use_past, "init_reset", self.cls_name, init_reset, True) |
|
|
|
if init_reset is not True: |
|
|
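# Hedged usage sketch for TransformerDecoderLayer. The sizes are illustrative
# (hidden_size 64 split across 2 heads of size 32, consistent with the
# (2, 2, 20, 32) example shape in the truncated docstring above); the four
# input shapes follow the _check_inputs checks above. The helper name and the
# positional order of the call are assumptions.
def _decoder_layer_usage_sketch():
    import numpy as np
    from mindspore import Tensor
    from mindspore.common import dtype as mstype

    layer = TransformerDecoderLayer(batch_size=2, hidden_size=64, num_heads=2,
                                    ffn_hidden_size=64, src_seq_length=20,
                                    tgt_seq_length=10)
    hidden_states = Tensor(np.ones((2, 10, 64)), mstype.float32)
    attention_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
    encoder_output = Tensor(np.ones((2, 20, 64)), mstype.float32)
    memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
    # Expected to return the decoder output plus the cached attention states.
    return layer(hidden_states, attention_mask, encoder_output, memory_mask)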
|
@@ -1437,6 +1486,26 @@ class TransformerEncoder(Cell): |
|
|
|
(2, 2, 16, 4) |
|
|
|
""" |
|
|
|
|
|
|
|
@_args_type_validator_check(batch_size=Validator.check_positive_int, |
|
|
|
hidden_size=Validator.check_positive_int, |
|
|
|
num_heads=Validator.check_positive_int, |
|
|
|
ffn_hidden_size=Validator.check_positive_int, |
|
|
|
seq_length=Validator.check_positive_int, |
|
|
|
num_layers=Validator.check_positive_int, |
|
|
|
offset=Validator.check_non_negative_int, |
|
|
|
attention_dropout_rate=Validator.check_non_negative_float, |
|
|
|
hidden_dropout_rate=Validator.check_non_negative_float, |
|
|
|
hidden_act=_valid_type_checks([str], "TransformerEncoder"), |
|
|
|
post_layernorm_residual=Validator.check_bool, |
|
|
|
layernorm_compute_type=_valid_value_checks([mstype.float32, mstype.float16], |
|
|
|
"TransformerEncoder"), |
|
|
|
softmax_comptue_type=_valid_value_checks([mstype.float32, mstype.float16], |
|
|
|
"TransformerEncoder"), |
|
|
|
param_init_type=_valid_value_checks([mstype.float32, mstype.float16], |
|
|
|
"TransformerEncoder"), |
|
|
|
parallel_config=_valid_type_checks([TransformerOpParallelConfig], |
|
|
|
"TransformerEncoder"), |
|
|
|
use_past=Validator.check_bool) |
|
|
|
def __init__(self, |
|
|
|
batch_size, |
|
|
|
num_layers, |
|
|
|
@@ -1457,15 +1526,6 @@ class TransformerEncoder(Cell): |
|
|
|
parallel_config=default_transformer_config): |
|
|
|
super(TransformerEncoder, self).__init__() |
|
|
|
_check_config(parallel_config) |
|
|
|
Validator.check_positive_int(num_layers, "num_layers") |
|
|
|
Validator.check_non_negative_int(offset, "offset") |
|
|
|
if num_heads % parallel_config.model_parallel != 0: |
|
|
|
raise ValueError( |
|
|
|
f"num heads must be divisibled by the model parallel way {parallel_config.model_parallel}," |
|
|
|
f"but found {num_heads}") |
|
|
|
Validator.check_bool(post_layernorm_residual, "post_layernorm_residual") |
|
|
|
if not isinstance(hidden_act, str): |
|
|
|
raise ValueError(f"The hidden_act should be a str type, but found {type(hidden_act)}") |
|
|
|
|
|
|
|
self.num_layers = num_layers |
|
|
|
self.blocks = nn.CellList() |
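# Hedged usage sketch for TransformerEncoder: a stack of num_layers encoder
# layers over the same (batch, seq_length, hidden_size) input as a single
# layer. The sizes are illustrative and reuse the encoder-layer example above;
# the helper name is hypothetical and the return structure is not asserted.
def _transformer_encoder_usage_sketch():
    import numpy as np
    from mindspore import Tensor
    from mindspore.common import dtype as mstype

    encoder = TransformerEncoder(batch_size=2, num_layers=2, hidden_size=8,
                                 num_heads=2, ffn_hidden_size=64, seq_length=16)
    x = Tensor(np.ones((2, 16, 8)), mstype.float32)
    input_mask = Tensor(np.ones((2, 16, 16)), mstype.float16)
    # Expected to return the stack output plus one cached state per layer.
    return encoder(x, input_mask)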
|
|
|
@@ -1587,6 +1647,27 @@ class TransformerDecoder(Cell): |
|
|
|
|
|
|
|
""" |
|
|
|
|
|
|
|
@_args_type_validator_check(batch_size=Validator.check_positive_int, |
|
|
|
hidden_size=Validator.check_positive_int, |
|
|
|
num_heads=Validator.check_positive_int, |
|
|
|
ffn_hidden_size=Validator.check_positive_int, |
|
|
|
src_seq_length=Validator.check_positive_int, |
|
|
|
num_layers=Validator.check_positive_int, |
|
|
|
tgt_seq_length=Validator.check_positive_int, |
|
|
|
offset=Validator.check_non_negative_int, |
|
|
|
attention_dropout_rate=Validator.check_non_negative_float, |
|
|
|
hidden_dropout_rate=Validator.check_non_negative_float, |
|
|
|
hidden_act=_valid_type_checks([str], "TransformerDecoder"), |
|
|
|
post_layernorm_residual=Validator.check_bool, |
|
|
|
layernorm_compute_type=_valid_value_checks([mstype.float32, mstype.float16], |
|
|
|
"TransformerDecoder"), |
|
|
|
softmax_comptue_type=_valid_value_checks([mstype.float32, mstype.float16], |
|
|
|
"TransformerDecoder"), |
|
|
|
param_init_type=_valid_value_checks([mstype.float32, mstype.float16], |
|
|
|
"TransformerDecoder"), |
|
|
|
parallel_config=_valid_type_checks([TransformerOpParallelConfig], |
|
|
|
"TransformerDecoder"), |
|
|
|
use_past=Validator.check_bool) |
|
|
|
def __init__(self, |
|
|
|
num_layers, |
|
|
|
batch_size, |
|
|
|
@@ -1608,15 +1689,6 @@ class TransformerDecoder(Cell): |
|
|
|
parallel_config=default_transformer_config): |
|
|
|
super(TransformerDecoder, self).__init__() |
|
|
|
_check_config(parallel_config) |
|
|
|
Validator.check_positive_int(num_layers, "num_layers") |
|
|
|
Validator.check_non_negative_int(offset, "offset") |
|
|
|
if num_heads % parallel_config.model_parallel != 0: |
|
|
|
raise ValueError( |
|
|
|
f"num heads must be divisibled by the model parallel way {parallel_config.model_parallel}," |
|
|
|
f"but found {num_heads}") |
|
|
|
Validator.check_bool(post_layernorm_residual, "post_layernorm_residual") |
|
|
|
if not isinstance(hidden_act, str): |
|
|
|
raise ValueError(f"The hidden_act should be a str type, but found {type(hidden_act)}") |
|
|
|
|
|
|
|
self.num_layers = num_layers |
|
|
|
self.blocks = nn.CellList() |
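# Hedged usage sketch for TransformerDecoder: a stack of num_layers decoder
# layers driven by the same four inputs as a single decoder layer. The sizes
# are illustrative and reuse the decoder-layer example above; the helper name
# and the positional order of the call are assumptions.
def _transformer_decoder_usage_sketch():
    import numpy as np
    from mindspore import Tensor
    from mindspore.common import dtype as mstype

    decoder = TransformerDecoder(num_layers=2, batch_size=2, hidden_size=64,
                                 num_heads=2, ffn_hidden_size=64,
                                 src_seq_length=20, tgt_seq_length=10)
    hidden_states = Tensor(np.ones((2, 10, 64)), mstype.float32)
    attention_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
    encoder_output = Tensor(np.ones((2, 20, 64)), mstype.float32)
    memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
    return decoder(hidden_states, attention_mask, encoder_output, memory_mask)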
|
|
|
@@ -1762,6 +1834,25 @@ class Transformer(Cell): |
|
|
|
|
|
|
|
""" |
|
|
|
|
|
|
|
@_args_type_validator_check(batch_size=Validator.check_positive_int, |
|
|
|
hidden_size=Validator.check_positive_int, |
|
|
|
num_heads=Validator.check_positive_int, |
|
|
|
ffn_hidden_size=Validator.check_positive_int, |
|
|
|
src_seq_length=Validator.check_positive_int, |
|
|
|
encoder_layers=Validator.check_positive_int, |
|
|
|
decoder_layers=Validator.check_non_negative_int, |
|
|
|
tgt_seq_length=Validator.check_positive_int, |
|
|
|
attention_dropout_rate=Validator.check_non_negative_float, |
|
|
|
hidden_dropout_rate=Validator.check_non_negative_float, |
|
|
|
hidden_act=_valid_type_checks([str], "Transformer"), |
|
|
|
post_layernorm_residual=Validator.check_bool, |
|
|
|
layernorm_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
|
|
|
"Transformer"), |
|
|
|
softmax_comptue_type=_valid_value_checks([mstype.float32, mstype.float16],
|
|
|
"Transformer"), |
|
|
|
param_init_type=_valid_value_checks([mstype.float32, mstype.float16], "Transformer"),
|
|
|
parallel_config=_valid_type_checks([TransformerOpParallelConfig], "Transformer"), |
|
|
|
use_past=Validator.check_bool) |
|
|
|
def __init__(self, |
|
|
|
hidden_size, |
|
|
|
batch_size, |
|
|
|
|