Browse Source

support swtichlayer for pynative

tags/v1.2.0-rc1
chujinjin 4 years ago
parent
commit
96229b7358
1 changed files with 14 additions and 0 deletions
  1. +14
    -0
      mindspore/common/tensor.py

+ 14
- 0
mindspore/common/tensor.py View File

@@ -161,6 +161,20 @@ class Tensor(Tensor_):
return bool(data[0]) return bool(data[0])
raise ValueError("The truth value of an array with several elements is ambiguous.") raise ValueError("The truth value of an array with several elements is ambiguous.")


def __index__(self):
data = self.asnumpy()
if not (data.dtype == "int8"
or data.dtype == "int16"
or data.dtype == "int32"
or data.dtype == "int64"
or data.dtype == "bool"):
raise ValueError("Only integer tensors of a single element can be converted to an index.")
if data.shape == ():
return int(data)
if data.shape == (1,):
return int(data[0])
raise ValueError("Only integer tensors of a single element can be converted to an index.")

def __pos__(self): def __pos__(self):
return self return self




Loading…
Cancel
Save