| @@ -345,22 +345,22 @@ def get_index_tensor_dtype(dtype): | |||||
| @constexpr | @constexpr | ||||
| def check_index_tensors_dtype(dtypes, op_name): | |||||
| def check_index_tensors_dtype(indexes_types, op_name): | |||||
| """Check a tuple of tensor data type.""" | """Check a tuple of tensor data type.""" | ||||
| for ele in dtypes: | |||||
| if not ele in mstype.int_type: | |||||
| raise IndexError(f"For '{op_name}', the all index tensor " | |||||
| f"data types should be mstype.int32, but got {dtypes}.") | |||||
| for index_type in indexes_types: | |||||
| if not index_type in (mstype.int32, mstype.int64): | |||||
| raise IndexError(f"For '{op_name}', the all index tensor data types should be " | |||||
| f"mstype.int32, but got {index_type}.") | |||||
| return True | return True | ||||
| @constexpr | @constexpr | ||||
| def check_index_tensor_dtype(dtype, op_name): | |||||
| def check_index_tensor_dtype(index_type, op_name): | |||||
| """Check a tensor data type.""" | """Check a tensor data type.""" | ||||
| if dtype in mstype.int_type: | |||||
| if index_type in (mstype.int32, mstype.int64): | |||||
| return True | return True | ||||
| raise IndexError( | |||||
| f"For '{op_name}', the index tensor data type should be mstype.int32, but got {dtype}.") | |||||
| raise IndexError(f"For '{op_name}', the index tensor data type should be mstype.int32, " | |||||
| f"but got {index_type}.") | |||||
| @constexpr | @constexpr | ||||