Browse Source

!7838 Enumerate function enable tensor as input

Merge pull request !7838 from LiangZhibo/master
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
b3855530e3
2 changed files with 129 additions and 30 deletions
  1. +18
    -7
      mindspore/_extends/parse/standard_method.py
  2. +111
    -23
      tests/ut/python/pipeline/parse/test_enumerate.py

+ 18
- 7
mindspore/_extends/parse/standard_method.py View File

@@ -30,7 +30,6 @@ trans = P.Transpose()
shape_ = P.Shape() shape_ = P.Shape()
dtype_ = P.DType() dtype_ = P.DType()



def all_(x, axis=(), keep_dims=False): def all_(x, axis=(), keep_dims=False):
""" """
Check all array elements along a given axis evaluate to True. Check all array elements along a given axis evaluate to True.
@@ -144,12 +143,16 @@ def bool_(x):




def enumerate_(x, start=0): def enumerate_(x, start=0):
"""Enumerate list or tuple."""
"""Enumerate list or tuple or tensor."""
x_type = F.typeof(x) x_type = F.typeof(x)
ret = () ret = ()
op_name = "enumerate" op_name = "enumerate"
if check_is_tuple_or_list(x_type, op_name, "first input") and check_is_const_int(start, op_name, "start"):
ret = zip(range(start, start + len(x)), x)
if check_is_tuple_or_list_or_tensor(x_type, op_name, "first input") and check_is_const_int(start, op_name, "start"):
if check_is_tensor(x_type):
for i in range(x.shape[0]):
ret += ((start + i, x[i]),)
else:
ret = zip(range(start, start + len(x)), x)
return ret return ret




@@ -177,11 +180,19 @@ def check_type_same(x_type, base_type):




@constexpr @constexpr
def check_is_tuple_or_list(x, op_name, arg_name):
def check_is_tensor(x):
"""check whether x is list or tuple.""" """check whether x is list or tuple."""
if isinstance(x, (mstype.list_type, mstype.tuple_type)):
if isinstance(x, mstype.tensor_type):
return True
return False


@constexpr
def check_is_tuple_or_list_or_tensor(x, op_name, arg_name):
"""check whether x is list or tuple or tensor."""
if isinstance(x, (mstype.list_type, mstype.tuple_type, mstype.tensor_type)):
return True return True
raise TypeError(f"For '{op_name}', the '{arg_name}' should be tuple or list, but got {x}.")
raise TypeError(f"For '{op_name}', the '{arg_name}' should be tuple or list or tensor, but got {x}.")




@constexpr @constexpr


+ 111
- 23
tests/ut/python/pipeline/parse/test_enumerate.py View File

@@ -59,23 +59,36 @@ def test_enumerate_tuple_const():
assert net() == (6, 110) assert net() == (6, 110)




def test_enumerate_tensor_const():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = Tensor(np.arange(2 * 3).reshape(2, 3))

def construct(self):
return enumerate(self.value)

net = Net()
net()


def test_enumerate_list_parameter(): def test_enumerate_list_parameter():
class Net(nn.Cell): class Net(nn.Cell):
def __init__(self): def __init__(self):
super(Net, self).__init__() super(Net, self).__init__()


def construct(self, x, y, z):
def construct(self, x, y):
index_sum = 0 index_sum = 0
value = [x, y, z]
value = [x, y]
ret = () ret = ()
for i, j in enumerate(value): for i, j in enumerate(value):
index_sum += i index_sum += i
ret += (j,) ret += (j,)
return index_sum, ret return index_sum, ret


x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)))
x = Tensor(np.arange(4))
net = Net() net = Net()
net(x, x, x)
net(x, x)




def test_enumerate_tuple_parameter(): def test_enumerate_tuple_parameter():
@@ -83,18 +96,36 @@ def test_enumerate_tuple_parameter():
def __init__(self): def __init__(self):
super(Net, self).__init__() super(Net, self).__init__()


def construct(self, x, y, z):
def construct(self, x, y):
index_sum = 0 index_sum = 0
value = (x, y, z)
value = (x, y)
ret = () ret = ()
for i, j in enumerate(value): for i, j in enumerate(value):
index_sum += i index_sum += i
ret += (j,) ret += (j,)
return index_sum, ret return index_sum, ret


x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)))
x = Tensor(np.arange(4))
net = Net()
net(x, x)


def test_enumerate_tensor_parameter():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()

def construct(self, x):
index_sum = 0
ret = ()
for i, j in enumerate(x):
index_sum += i
ret += (j,)
return index_sum, ret

x = Tensor(np.arange(2 * 3).reshape(2, 3))
net = Net() net = Net()
net(x, x, x)
net(x)




def test_enumerate_tuple_const_1(): def test_enumerate_tuple_const_1():
@@ -115,23 +146,59 @@ def test_enumerate_tuple_const_1():
assert net() == (6, 110) assert net() == (6, 110)




def test_enumerate_tensor_const_1():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = Tensor(np.arange(2*3).reshape(2, 3))

def construct(self):
index_sum = 0
ret = ()
for i in enumerate(self.value):
index_sum += i[0]
ret += (i[1],)
return index_sum, ret

net = Net()
net()


def test_enumerate_tuple_parameter_1(): def test_enumerate_tuple_parameter_1():
class Net(nn.Cell): class Net(nn.Cell):
def __init__(self): def __init__(self):
super(Net, self).__init__() super(Net, self).__init__()


def construct(self, x, y, z):
def construct(self, x, y):
index_sum = 0 index_sum = 0
value = (x, y, z)
value = (x, y)
ret = () ret = ()
for i in enumerate(value): for i in enumerate(value):
index_sum += i[0] index_sum += i[0]
ret += (i[1],) ret += (i[1],)
return index_sum, ret return index_sum, ret


x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)))
x = Tensor(np.arange(4))
net = Net()
net(x, x)


def test_enumerate_tensor_parameter_1():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()

def construct(self, x):
index_sum = 0
ret = ()
for i in enumerate(x):
index_sum += i[0]
ret += (i[1],)
return index_sum, ret

x = Tensor(np.arange(2 * 3).reshape(2, 3))
net = Net() net = Net()
net(x, x, x)
net(x)




def test_enumerate_tuple_const_2(): def test_enumerate_tuple_const_2():
@@ -152,38 +219,59 @@ def test_enumerate_tuple_const_2():
assert net() == (10, 110) assert net() == (10, 110)




def test_enumerate_tensor_const_2():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = Tensor(np.arange(2 * 3).reshape(2, 3))

def construct(self):
index_sum = 0
ret = ()
for i in enumerate(self.value, 1):
index_sum += i[0]
ret += (i[1],)
return index_sum, ret

net = Net()
net()


def test_enumerate_tuple_parameter_2(): def test_enumerate_tuple_parameter_2():
class Net(nn.Cell): class Net(nn.Cell):
def __init__(self): def __init__(self):
super(Net, self).__init__() super(Net, self).__init__()


def construct(self, x, y, z):
def construct(self, x, y):
index_sum = 0 index_sum = 0
value = (x, y, z)
value = (x, y)
ret = () ret = ()
for i in enumerate(value, 2):
for i in enumerate(value, 1):
index_sum += i[0] index_sum += i[0]
ret += (i[1],) ret += (i[1],)
return index_sum, ret return index_sum, ret


x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)))
x = Tensor(np.arange(4))
net = Net() net = Net()
net(x, x, x)
net(x, x)




def test_enumerate_first_input_type_error():
def test_enumerate_tensor_parameter_2():
class Net(nn.Cell): class Net(nn.Cell):
def __init__(self): def __init__(self):
super(Net, self).__init__() super(Net, self).__init__()


def construct(self, x): def construct(self, x):
return enumerate(x)
index_sum = 0
ret = ()
for i, j in enumerate(x, 1):
index_sum += i
ret += (j,)
return index_sum, ret


x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)))
x = Tensor(np.arange(2 * 3).reshape(2, 3))
net = Net() net = Net()
with pytest.raises(TypeError) as ex:
net(x)
assert "For 'enumerate', the 'first input'" in str(ex.value)
net(x)




def test_enumerate_start_type_error(): def test_enumerate_start_type_error():


Loading…
Cancel
Save