diff --git a/mindspore/_checkparam.py b/mindspore/_checkparam.py index 7c408bce01..462615653a 100644 --- a/mindspore/_checkparam.py +++ b/mindspore/_checkparam.py @@ -448,12 +448,14 @@ class Validator: @staticmethod def check_tensors_dtypes_same_and_valid(args, valid_dtypes, prim_name): """Checks whether the element types of input tensors are the same and valid.""" + valid_dtypes = valid_dtypes if isinstance(valid_dtypes, Iterable) else [valid_dtypes] tensor_types = [mstype.tensor_type(t) for t in valid_dtypes] Validator.check_types_same_and_valid(args, tensor_types, prim_name) @staticmethod def check_tensor_dtype_valid(arg_name, arg_type, valid_dtypes, prim_name): """Checks whether the element types of input tensors are valid.""" + valid_dtypes = valid_dtypes if isinstance(valid_dtypes, Iterable) else [valid_dtypes] tensor_types = [mstype.tensor_type(t) for t in valid_dtypes] Validator.check_subclass(arg_name, arg_type, tensor_types, prim_name) diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index fb34e62dbf..fa34601793 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -5716,7 +5716,7 @@ class DynamicRNN(PrimitiveWithInfer): return y_shape, y_shape, y_shape, y_shape, y_shape, y_shape, y_shape, y_shape def infer_dtype(self, x_dtype, w_dtype, b_dtype, seq_dtype, h_dtype, c_dtype): - tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=mstype.float16, prim_name=self.name), + tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=[mstype.float16], prim_name=self.name), ("x", "w", "h", "c"), (x_dtype, w_dtype, h_dtype, c_dtype))) validator.check_tensor_dtype_valid("b", b_dtype, (mstype.float16, mstype.float32), self.name)