|
|
|
@@ -146,9 +146,9 @@ class TensorAssignWithSlice(Cell): |
|
|
|
return z |
|
|
|
|
|
|
|
|
|
|
|
class TensorIndexByOneTensor(Cell): |
|
|
|
class TensorGetItemByOneTensor(Cell): |
|
|
|
def __init__(self): |
|
|
|
super(TensorIndexByOneTensor, self).__init__() |
|
|
|
super(TensorGetItemByOneTensor, self).__init__() |
|
|
|
self.const = Tensor(np.ones((5, 4, 7, 8)), mstype.int32) |
|
|
|
|
|
|
|
def construct(self, x, index): |
|
|
|
@@ -156,9 +156,9 @@ class TensorIndexByOneTensor(Cell): |
|
|
|
return ret |
|
|
|
|
|
|
|
|
|
|
|
class TensorIndexByTwoTensors(Cell): |
|
|
|
class TensorGetItemByTwoTensors(Cell): |
|
|
|
def __init__(self): |
|
|
|
super(TensorIndexByTwoTensors, self).__init__() |
|
|
|
super(TensorGetItemByTwoTensors, self).__init__() |
|
|
|
self.const = Tensor(np.ones((3, 4, 5, 8)), mstype.int32) |
|
|
|
|
|
|
|
def construct(self, x, index_0, index_1): |
|
|
|
@@ -166,9 +166,9 @@ class TensorIndexByTwoTensors(Cell): |
|
|
|
return ret |
|
|
|
|
|
|
|
|
|
|
|
class TensorIndexByThreeTensors(Cell): |
|
|
|
class TensorGetItemByThreeTensors(Cell): |
|
|
|
def __init__(self): |
|
|
|
super(TensorIndexByThreeTensors, self).__init__() |
|
|
|
super(TensorGetItemByThreeTensors, self).__init__() |
|
|
|
self.const = Tensor(np.ones((5, 3, 4, 5)), mstype.int32) |
|
|
|
|
|
|
|
def construct(self, x, index_0, index_1, index_2): |
|
|
|
@@ -176,6 +176,15 @@ class TensorIndexByThreeTensors(Cell): |
|
|
|
return ret |
|
|
|
|
|
|
|
|
|
|
|
class TensorGetItemByMixedTensors(Cell): |
|
|
|
def __init__(self): |
|
|
|
super(TensorGetItemByMixedTensors, self).__init__() |
|
|
|
|
|
|
|
def construct(self, x, index_0, index_1): |
|
|
|
ret = x[index_0, index_1, 0:6] |
|
|
|
return ret |
|
|
|
|
|
|
|
|
|
|
|
class TensorSetItemByOneTensorWithNumber(Cell): |
|
|
|
def __init__(self, value): |
|
|
|
super(TensorSetItemByOneTensorWithNumber, self).__init__() |
|
|
|
@@ -300,6 +309,19 @@ class TensorSetItemByTensorsWithTupleOfTensorNumberError(Cell): |
|
|
|
return ret |
|
|
|
|
|
|
|
|
|
|
|
class TensorSetItemByMixedTensors(Cell): |
|
|
|
def __init__(self): |
|
|
|
super(TensorSetItemByMixedTensors, self).__init__() |
|
|
|
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) |
|
|
|
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x") |
|
|
|
self.value = 99.0 |
|
|
|
|
|
|
|
def construct(self, index_0, index_1): |
|
|
|
self.param[index_0, index_1, 0:6] = self.value |
|
|
|
ret = self.param + self.const |
|
|
|
return ret |
|
|
|
|
|
|
|
|
|
|
|
def test_tensor_assign(): |
|
|
|
context.set_context(mode=context.GRAPH_MODE, save_graphs=True) |
|
|
|
net = TensorAssignWithSlice() |
|
|
|
@@ -596,19 +618,19 @@ test_cases = [ |
|
|
|
'block': NetWorkSliceEllipsis(), |
|
|
|
'desc_inputs': [Tensor(np.ones([6, 7, 8, 9], np.int32))], |
|
|
|
}), |
|
|
|
('TensorIndexByOneTensor', { |
|
|
|
'block': TensorIndexByOneTensor(), |
|
|
|
('TensorGetItemByOneTensor', { |
|
|
|
'block': TensorGetItemByOneTensor(), |
|
|
|
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), |
|
|
|
Tensor(np.random.randint(6, size=(5, 4)), mstype.int32)], |
|
|
|
}), |
|
|
|
('TensorIndexByTwoTensors', { |
|
|
|
'block': TensorIndexByTwoTensors(), |
|
|
|
('TensorGetItemByTwoTensors', { |
|
|
|
'block': TensorGetItemByTwoTensors(), |
|
|
|
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), |
|
|
|
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), |
|
|
|
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32)], |
|
|
|
}), |
|
|
|
('TensorIndexByThreeTensors', { |
|
|
|
'block': TensorIndexByThreeTensors(), |
|
|
|
('TensorGetItemByThreeTensors', { |
|
|
|
'block': TensorGetItemByThreeTensors(), |
|
|
|
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), |
|
|
|
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), |
|
|
|
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), |
|
|
|
@@ -665,37 +687,43 @@ test_cases = [ |
|
|
|
] |
|
|
|
|
|
|
|
raise_error_set = [ |
|
|
|
('TensorIndexByOneTensorDtypeError', { |
|
|
|
'block': (TensorIndexByOneTensor(), {'exception': TypeError}), |
|
|
|
('TensorGetItemByOneTensorDtypeError', { |
|
|
|
'block': (TensorGetItemByOneTensor(), {'exception': TypeError}), |
|
|
|
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), |
|
|
|
Tensor(np.random.randint(6, size=(5, 4)), mstype.int8)], |
|
|
|
}), |
|
|
|
('TensorIndexByTwoTensorsShapeError', { |
|
|
|
'block': (TensorIndexByTwoTensors(), {'exception': ValueError}), |
|
|
|
('TensorGetItemByTwoTensorsShapeError', { |
|
|
|
'block': (TensorGetItemByTwoTensors(), {'exception': ValueError}), |
|
|
|
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), |
|
|
|
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), |
|
|
|
Tensor(np.random.randint(7, size=(2, 3, 5)), mstype.int32)], |
|
|
|
}), |
|
|
|
('TensorIndexByTwoTensorsDtypeError', { |
|
|
|
'block': (TensorIndexByTwoTensors(), {'exception': TypeError}), |
|
|
|
('TensorGetItemByTwoTensorsDtypeError', { |
|
|
|
'block': (TensorGetItemByTwoTensors(), {'exception': TypeError}), |
|
|
|
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), |
|
|
|
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), |
|
|
|
Tensor(np.random.randint(7, size=(4, 5)), mstype.float32)], |
|
|
|
}), |
|
|
|
('TensorIndexByThreeTensorsShapeError', { |
|
|
|
'block': (TensorIndexByThreeTensors(), {'exception': ValueError}), |
|
|
|
('TensorGetItemByThreeTensorsShapeError', { |
|
|
|
'block': (TensorGetItemByThreeTensors(), {'exception': ValueError}), |
|
|
|
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), |
|
|
|
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), |
|
|
|
Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int32), |
|
|
|
Tensor(np.random.randint(8, size=(5, 2, 4, 5)), mstype.int32)], |
|
|
|
}), |
|
|
|
('TensorIndexByThreeTensorsDtypeError', { |
|
|
|
'block': (TensorIndexByThreeTensors(), {'exception': TypeError}), |
|
|
|
('TensorGetItemByThreeTensorsDtypeError', { |
|
|
|
'block': (TensorGetItemByThreeTensors(), {'exception': TypeError}), |
|
|
|
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), |
|
|
|
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), |
|
|
|
Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int64), |
|
|
|
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)], |
|
|
|
}), |
|
|
|
('TensorGetItemByMixedTensors', { |
|
|
|
'block': (TensorGetItemByMixedTensors(), {'exception': IndexError}), |
|
|
|
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), |
|
|
|
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), |
|
|
|
Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int64)], |
|
|
|
}), |
|
|
|
('TensorSetItemByOneTensorWithNumberTypeError', { |
|
|
|
'block': (TensorSetItemByOneTensorWithNumber(value=0), {'exception': TypeError}), |
|
|
|
'desc_inputs': [Tensor(np.random.randint(4, size=(5, 4)), mstype.int32)], |
|
|
|
@@ -781,6 +809,11 @@ raise_error_set = [ |
|
|
|
Tensor(np.zeros((4, 5)), mstype.float32), |
|
|
|
Tensor(np.ones((4, 5)), mstype.int32), |
|
|
|
Tensor(np.ones((4, 5)) * 2, mstype.int32)], |
|
|
|
}), |
|
|
|
('TensorSetItemByMixedTensors', { |
|
|
|
'block': (TensorSetItemByMixedTensors(), {'exception': IndexError}), |
|
|
|
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), |
|
|
|
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32)], |
|
|
|
}) |
|
|
|
] |
|
|
|
|
|
|
|
|