|
|
|
@@ -5716,7 +5716,7 @@ class DynamicRNN(PrimitiveWithInfer): |
|
|
|
return y_shape, y_shape, y_shape, y_shape, y_shape, y_shape, y_shape, y_shape |
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype, w_dtype, b_dtype, seq_dtype, h_dtype, c_dtype): |
|
|
|
tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=mstype.float16, prim_name=self.name), |
|
|
|
tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=[mstype.float16], prim_name=self.name), |
|
|
|
("x", "w", "h", "c"), |
|
|
|
(x_dtype, w_dtype, h_dtype, c_dtype))) |
|
|
|
validator.check_tensor_dtype_valid("b", b_dtype, (mstype.float16, mstype.float32), self.name) |
|
|
|
|