Merge pull request !8196 from zhangbuxue/support_non-Iterable_valid_dtypes_when_check_tensor_type_validtags/v1.1.0
| @@ -448,12 +448,14 @@ class Validator: | |||||
| @staticmethod | @staticmethod | ||||
| def check_tensors_dtypes_same_and_valid(args, valid_dtypes, prim_name): | def check_tensors_dtypes_same_and_valid(args, valid_dtypes, prim_name): | ||||
| """Checks whether the element types of input tensors are the same and valid.""" | """Checks whether the element types of input tensors are the same and valid.""" | ||||
| valid_dtypes = valid_dtypes if isinstance(valid_dtypes, Iterable) else [valid_dtypes] | |||||
| tensor_types = [mstype.tensor_type(t) for t in valid_dtypes] | tensor_types = [mstype.tensor_type(t) for t in valid_dtypes] | ||||
| Validator.check_types_same_and_valid(args, tensor_types, prim_name) | Validator.check_types_same_and_valid(args, tensor_types, prim_name) | ||||
| @staticmethod | @staticmethod | ||||
| def check_tensor_dtype_valid(arg_name, arg_type, valid_dtypes, prim_name): | def check_tensor_dtype_valid(arg_name, arg_type, valid_dtypes, prim_name): | ||||
| """Checks whether the element types of input tensors are valid.""" | """Checks whether the element types of input tensors are valid.""" | ||||
| valid_dtypes = valid_dtypes if isinstance(valid_dtypes, Iterable) else [valid_dtypes] | |||||
| tensor_types = [mstype.tensor_type(t) for t in valid_dtypes] | tensor_types = [mstype.tensor_type(t) for t in valid_dtypes] | ||||
| Validator.check_subclass(arg_name, arg_type, tensor_types, prim_name) | Validator.check_subclass(arg_name, arg_type, tensor_types, prim_name) | ||||
| @@ -5716,7 +5716,7 @@ class DynamicRNN(PrimitiveWithInfer): | |||||
| return y_shape, y_shape, y_shape, y_shape, y_shape, y_shape, y_shape, y_shape | return y_shape, y_shape, y_shape, y_shape, y_shape, y_shape, y_shape, y_shape | ||||
| def infer_dtype(self, x_dtype, w_dtype, b_dtype, seq_dtype, h_dtype, c_dtype): | def infer_dtype(self, x_dtype, w_dtype, b_dtype, seq_dtype, h_dtype, c_dtype): | ||||
| tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=mstype.float16, prim_name=self.name), | |||||
| tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=[mstype.float16], prim_name=self.name), | |||||
| ("x", "w", "h", "c"), | ("x", "w", "h", "c"), | ||||
| (x_dtype, w_dtype, h_dtype, c_dtype))) | (x_dtype, w_dtype, h_dtype, c_dtype))) | ||||
| validator.check_tensor_dtype_valid("b", b_dtype, (mstype.float16, mstype.float32), self.name) | validator.check_tensor_dtype_valid("b", b_dtype, (mstype.float16, mstype.float32), self.name) | ||||