!10385 Modified data type requirement of bias of DynamicGRUV2.

From: @liu_xiao_93
Reviewed-by: @liangchenghui,@wuxuejian
Signed-off-by: @liangchenghui
tags/v1.1.0
mindspore-ci-bot committed 5 years ago
commit 173265f11f
1 changed file with 8 additions and 16 deletions:
mindspore/ops/operations/nn_ops.py (+8, -16)

@@ -6519,9 +6519,9 @@ class DynamicGRUV2(PrimitiveWithInfer):
           Tensor of shape :math:`(\text{hidden_size}, 3 \times \text{hidden_size})`.
           The data type must be float16.
         - **bias_input** (Tensor) - Input-hidden bias. Tensor of shape :math:`(3 \times \text{hidden_size})`, or None.
-          The data type must be float16 or float32.
+          Has the same data type with input `init_h`.
         - **bias_hidden** (Tensor) - Hidden-hidden bias. Tensor of shape :math:`(3 \times \text{hidden_size})`, or None.
-          The data type must be float16 or float32.
+          Has the same data type with input `init_h`.
         - **seq_length** (Tensor) - The length of each batch. Tensor of shape :math:`(\text{batch_size})`.
           Only `None` is currently supported.
         - **init_h** (Tensor) - Hidden state of initial time.
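
In practice the new docstring contract reads: choose the dtype of `init_h` first, then give both bias tensors that same dtype. A minimal usage sketch of the constraint (a sketch only; the shapes are illustrative, and it assumes an Ascend target, where DynamicGRUV2 is implemented):

    import numpy as np
    from mindspore import Tensor
    from mindspore.ops import operations as P

    # Illustrative sizes: num_step=2, batch_size=8, input_size=64, hidden_size=16.
    x = Tensor(np.random.rand(2, 8, 64).astype(np.float16))
    weight_input = Tensor(np.random.rand(64, 48).astype(np.float16))   # (input_size, 3 * hidden_size)
    weight_hidden = Tensor(np.random.rand(16, 48).astype(np.float16))  # (hidden_size, 3 * hidden_size)
    init_h = Tensor(np.random.rand(8, 16).astype(np.float16))          # dtype anchor for the biases
    bias_input = Tensor(np.random.rand(48).astype(np.float16))         # must now match init_h's dtype
    bias_hidden = Tensor(np.random.rand(48).astype(np.float16))        # must now match init_h's dtype

    gru = P.DynamicGRUV2()
    outputs = gru(x, weight_input, weight_hidden, bias_input, bias_hidden, None, init_h)
    print(outputs[0].shape)  # (2, 8, 16): (num_step, batch_size, hidden_size)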
@@ -6563,15 +6563,6 @@ class DynamicGRUV2(PrimitiveWithInfer):
         >>> print(output[0].shape)
         (2, 8, 16)
     """
-    __mindspore_signature__ = (
-        sig.make_sig('x', dtype=sig.sig_dtype.T1),
-        sig.make_sig('weight_input', dtype=sig.sig_dtype.T2),
-        sig.make_sig('weight_hidden', dtype=sig.sig_dtype.T3),
-        sig.make_sig('bias_input', dtype=sig.sig_dtype.T),
-        sig.make_sig('bias_hidden', dtype=sig.sig_dtype.T),
-        sig.make_sig('seq_length', dtype=sig.sig_dtype.T4),
-        sig.make_sig('init_h', dtype=sig.sig_dtype.T),
-    )

     @prim_attr_register
     def __init__(self,
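
For context on the deleted block: arguments that share one signature dtype group (bias_input, bias_hidden and init_h all used sig.sig_dtype.T) are, roughly, cast to a common dtype implicitly before the infer checks run, so a float32 bias next to a float16 init_h could be promoted silently rather than reported. A toy stand-in for that promotion idea in plain NumPy (hypothetical helper, not the MindSpore machinery):

    import numpy as np

    def promote_group(tensors):
        """Toy stand-in: cast every member of a signature dtype group to a common dtype."""
        target = np.result_type(*(t.dtype for t in tensors))
        return [t.astype(target) for t in tensors]

    bias_input = np.zeros(48, dtype=np.float16)
    init_h = np.zeros((8, 16), dtype=np.float32)
    bias_input, init_h = promote_group([bias_input, init_h])
    print(bias_input.dtype, init_h.dtype)  # both float32; the mismatch disappears silently

With the signature block gone, the mismatch is surfaced explicitly in the dtype checks of the next hunk instead.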
@@ -6639,15 +6630,16 @@ class DynamicGRUV2(PrimitiveWithInfer):
         validator.check_tensor_dtype_valid("x dtype", x_dtype, [mstype.float16], self.name)
         validator.check_tensor_dtype_valid("weight input dtype", winput_dtype, [mstype.float16], self.name)
         validator.check_tensor_dtype_valid("weight hidden dtype", whidden_dtype, [mstype.float16], self.name)
-        validator.check_tensor_dtype_valid("init_h dtype", h_dtype, (mstype.float16, mstype.float32), self.name)
+        valid_dtypes = [mstype.float16, mstype.float32]
+        validator.check_tensor_dtype_valid("init_h dtype", h_dtype, valid_dtypes, self.name)
         b_dtype = h_dtype
         if binput_dtype is not None:
-            validator.check_tensor_dtype_valid("bias input dtype", binput_dtype,
-                                               (mstype.float16, mstype.float32), self.name)
+            args = {'init_h': h_dtype, 'bias_input': binput_dtype}
+            validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
             b_dtype = binput_dtype
         elif bhidden_dtype is not None:
-            validator.check_tensor_dtype_valid("bias hidden dtype", bhidden_dtype,
-                                               (mstype.float16, mstype.float32), self.name)
+            args = {'init_h': h_dtype, 'bias_hidden': bhidden_dtype}
+            validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
             b_dtype = bhidden_dtype

         return b_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype
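
The switch from per-tensor checks to check_tensors_dtypes_same_and_valid is what ties each bias to init_h: the call restricts the dtype to valid_dtypes and additionally requires every tensor in args to share it. A minimal standalone re-implementation of that idea (a hedged sketch; the function name is real, but the body and error wording here are assumptions, not MindSpore's code):

    def check_tensors_dtypes_same_and_valid(args, valid_dtypes, prim_name):
        """Sketch: all dtypes in `args` must be identical and listed in `valid_dtypes`."""
        dtypes = set(args.values())
        if len(dtypes) != 1:
            raise TypeError(f"For '{prim_name}', inputs {sorted(args)} must share one dtype, got {args}.")
        (dtype,) = dtypes
        if dtype not in valid_dtypes:
            raise TypeError(f"For '{prim_name}', dtype must be one of {valid_dtypes}, got {dtype}.")

Net effect: a float32 bias_input paired with a float16 init_h, which the old per-tensor checks accepted, now fails at infer time.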

