diff --git a/mindspore/common/tensor.py b/mindspore/common/tensor.py index 5ac9e05626..225b5c6d2d 100644 --- a/mindspore/common/tensor.py +++ b/mindspore/common/tensor.py @@ -161,6 +161,20 @@ class Tensor(Tensor_): return bool(data[0]) raise ValueError("The truth value of an array with several elements is ambiguous.") + def __index__(self): + data = self.asnumpy() + if not (data.dtype == "int8" + or data.dtype == "int16" + or data.dtype == "int32" + or data.dtype == "int64" + or data.dtype == "bool"): + raise ValueError("Only integer tensors of a single element can be converted to an index.") + if data.shape == (): + return int(data) + if data.shape == (1,): + return int(data[0]) + raise ValueError("Only integer tensors of a single element can be converted to an index.") + def __pos__(self): return self