Browse Source

support non-Iterable valid_dtypes when check tensor type valid

tags/v1.1.0
buxue 5 years ago
parent
commit
32610ab63f
2 changed files with 3 additions and 1 deletions
  1. +2
    -0
      mindspore/_checkparam.py
  2. +1
    -1
      mindspore/ops/operations/nn_ops.py

+ 2
- 0
mindspore/_checkparam.py View File

@@ -448,12 +448,14 @@ class Validator:
@staticmethod
def check_tensors_dtypes_same_and_valid(args, valid_dtypes, prim_name):
"""Checks whether the element types of input tensors are the same and valid."""
valid_dtypes = valid_dtypes if isinstance(valid_dtypes, Iterable) else [valid_dtypes]
tensor_types = [mstype.tensor_type(t) for t in valid_dtypes]
Validator.check_types_same_and_valid(args, tensor_types, prim_name)

@staticmethod
def check_tensor_dtype_valid(arg_name, arg_type, valid_dtypes, prim_name):
"""Checks whether the element types of input tensors are valid."""
valid_dtypes = valid_dtypes if isinstance(valid_dtypes, Iterable) else [valid_dtypes]
tensor_types = [mstype.tensor_type(t) for t in valid_dtypes]
Validator.check_subclass(arg_name, arg_type, tensor_types, prim_name)



+ 1
- 1
mindspore/ops/operations/nn_ops.py View File

@@ -5716,7 +5716,7 @@ class DynamicRNN(PrimitiveWithInfer):
return y_shape, y_shape, y_shape, y_shape, y_shape, y_shape, y_shape, y_shape

def infer_dtype(self, x_dtype, w_dtype, b_dtype, seq_dtype, h_dtype, c_dtype):
tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=mstype.float16, prim_name=self.name),
tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=[mstype.float16], prim_name=self.name),
("x", "w", "h", "c"),
(x_dtype, w_dtype, h_dtype, c_dtype)))
validator.check_tensor_dtype_valid("b", b_dtype, (mstype.float16, mstype.float32), self.name)


Loading…
Cancel
Save