|
|
@@ -6519,9 +6519,9 @@ class DynamicGRUV2(PrimitiveWithInfer): |
|
|
Tensor of shape :math:`(\text{hidden_size}, 3 \times \text{hidden_size})`. |
|
|
Tensor of shape :math:`(\text{hidden_size}, 3 \times \text{hidden_size})`. |
|
|
The data type must be float16. |
|
|
The data type must be float16. |
|
|
- **bias_input** (Tensor) - Input-hidden bias. Tensor of shape :math:`(3 \times \text{hidden_size})`, or None. |
|
|
- **bias_input** (Tensor) - Input-hidden bias. Tensor of shape :math:`(3 \times \text{hidden_size})`, or None. |
|
|
The data type must be float16 or float32. |
|
|
|
|
|
|
|
|
Has the same data type with input `init_h`. |
|
|
- **bias_hidden** (Tensor) - Hidden-hidden bias. Tensor of shape :math:`(3 \times \text{hidden_size})`, or None. |
|
|
- **bias_hidden** (Tensor) - Hidden-hidden bias. Tensor of shape :math:`(3 \times \text{hidden_size})`, or None. |
|
|
The data type must be float16 or float32. |
|
|
|
|
|
|
|
|
Has the same data type with input `init_h`. |
|
|
- **seq_length** (Tensor) - The length of each batch. Tensor of shape :math:`(\text{batch_size})`. |
|
|
- **seq_length** (Tensor) - The length of each batch. Tensor of shape :math:`(\text{batch_size})`. |
|
|
Only `None` is currently supported. |
|
|
Only `None` is currently supported. |
|
|
- **init_h** (Tensor) - Hidden state of initial time. |
|
|
- **init_h** (Tensor) - Hidden state of initial time. |
|
|
@@ -6563,15 +6563,6 @@ class DynamicGRUV2(PrimitiveWithInfer): |
|
|
>>> print(output[0].shape) |
|
|
>>> print(output[0].shape) |
|
|
(2, 8, 16) |
|
|
(2, 8, 16) |
|
|
""" |
|
|
""" |
|
|
__mindspore_signature__ = ( |
|
|
|
|
|
sig.make_sig('x', dtype=sig.sig_dtype.T1), |
|
|
|
|
|
sig.make_sig('weight_input', dtype=sig.sig_dtype.T2), |
|
|
|
|
|
sig.make_sig('weight_hidden', dtype=sig.sig_dtype.T3), |
|
|
|
|
|
sig.make_sig('bias_input', dtype=sig.sig_dtype.T), |
|
|
|
|
|
sig.make_sig('bias_hidden', dtype=sig.sig_dtype.T), |
|
|
|
|
|
sig.make_sig('seq_length', dtype=sig.sig_dtype.T4), |
|
|
|
|
|
sig.make_sig('init_h', dtype=sig.sig_dtype.T), |
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
@prim_attr_register |
|
|
@prim_attr_register |
|
|
def __init__(self, |
|
|
def __init__(self, |
|
|
@@ -6639,15 +6630,16 @@ class DynamicGRUV2(PrimitiveWithInfer): |
|
|
validator.check_tensor_dtype_valid("x dtype", x_dtype, [mstype.float16], self.name) |
|
|
validator.check_tensor_dtype_valid("x dtype", x_dtype, [mstype.float16], self.name) |
|
|
validator.check_tensor_dtype_valid("weight input dtype", winput_dtype, [mstype.float16], self.name) |
|
|
validator.check_tensor_dtype_valid("weight input dtype", winput_dtype, [mstype.float16], self.name) |
|
|
validator.check_tensor_dtype_valid("weight hidden dtype", whidden_dtype, [mstype.float16], self.name) |
|
|
validator.check_tensor_dtype_valid("weight hidden dtype", whidden_dtype, [mstype.float16], self.name) |
|
|
validator.check_tensor_dtype_valid("init_h dtype", h_dtype, (mstype.float16, mstype.float32), self.name) |
|
|
|
|
|
|
|
|
valid_dtypes = [mstype.float16, mstype.float32] |
|
|
|
|
|
validator.check_tensor_dtype_valid("init_h dtype", h_dtype, valid_dtypes, self.name) |
|
|
b_dtype = h_dtype |
|
|
b_dtype = h_dtype |
|
|
if binput_dtype is not None: |
|
|
if binput_dtype is not None: |
|
|
validator.check_tensor_dtype_valid("bias input dtype", binput_dtype, |
|
|
|
|
|
(mstype.float16, mstype.float32), self.name) |
|
|
|
|
|
|
|
|
args = {'init_h': h_dtype, 'bias_input': binput_dtype} |
|
|
|
|
|
validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name) |
|
|
b_dtype = binput_dtype |
|
|
b_dtype = binput_dtype |
|
|
elif bhidden_dtype is not None: |
|
|
elif bhidden_dtype is not None: |
|
|
validator.check_tensor_dtype_valid("bias hidden dtype", bhidden_dtype, |
|
|
|
|
|
(mstype.float16, mstype.float32), self.name) |
|
|
|
|
|
|
|
|
args = {'init_h': h_dtype, 'bias_hidden': bhidden_dtype} |
|
|
|
|
|
validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name) |
|
|
b_dtype = bhidden_dtype |
|
|
b_dtype = bhidden_dtype |
|
|
|
|
|
|
|
|
return b_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype |
|
|
return b_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype |
|
|
|