|
|
|
@@ -359,22 +359,22 @@ def get_index_tensor_dtype(dtype): |
|
|
|
|
|
|
|
|
|
|
|
@constexpr |
|
|
|
def check_index_tensors_dtype(dtypes, op_name): |
|
|
|
def check_index_tensors_dtype(indexes_types, op_name): |
|
|
|
"""Check a tuple of tensor data type.""" |
|
|
|
for ele in dtypes: |
|
|
|
if not ele in mstype.int_type: |
|
|
|
raise IndexError(f"For '{op_name}', the all index tensor " |
|
|
|
f"data types should be mstype.int32, but got {dtypes}.") |
|
|
|
for index_type in indexes_types: |
|
|
|
if not index_type in (mstype.int32, mstype.int64): |
|
|
|
raise IndexError(f"For '{op_name}', the all index tensor data types should be " |
|
|
|
f"mstype.int32, but got {index_type}.") |
|
|
|
return True |
|
|
|
|
|
|
|
|
|
|
|
@constexpr |
|
|
|
def check_index_tensor_dtype(dtype, op_name): |
|
|
|
def check_index_tensor_dtype(index_type, op_name): |
|
|
|
"""Check a tensor data type.""" |
|
|
|
if dtype in mstype.int_type: |
|
|
|
if index_type in (mstype.int32, mstype.int64): |
|
|
|
return True |
|
|
|
raise IndexError( |
|
|
|
f"For '{op_name}', the index tensor data type should be mstype.int32, but got {dtype}.") |
|
|
|
raise IndexError(f"For '{op_name}', the index tensor data type should be mstype.int32, " |
|
|
|
f"but got {index_type}.") |
|
|
|
|
|
|
|
|
|
|
|
@constexpr |
|
|
|
|