Browse Source

!721 Tensor assign by ellipsis

Merge pull request !721 from candanzg/tensor_assgin_ellipsis
tags/v0.3.0-alpha
mindspore-ci-bot Gitee 5 years ago
parent
commit
33e427e1c5
9 changed files with 254 additions and 86 deletions
  1. +2
    -2
      mindspore/_extends/parse/__init__.py
  2. +6
    -1
      mindspore/_extends/parse/parser.py
  3. +7
    -0
      mindspore/_extends/utils.py
  4. +1
    -0
      mindspore/ccsrc/pipeline/parse/parse_base.h
  5. +6
    -0
      mindspore/ccsrc/pipeline/static_analysis/prim.cc
  6. +2
    -0
      mindspore/ccsrc/utils/convert_utils.cc
  7. +62
    -19
      mindspore/ops/composite/multitype_ops/_multitype_ops_util.py
  8. +33
    -4
      mindspore/ops/composite/multitype_ops/setitem_impl.py
  9. +135
    -60
      tests/ut/python/ops/test_tensor_slice.py

+ 2
- 2
mindspore/_extends/parse/__init__.py View File

@@ -22,11 +22,11 @@ from .parser import (Parser, create_obj_instance, generate_scope,
get_dataclass_attributes, get_dataclass_methods, get_dataclass_attributes, get_dataclass_methods,
get_module_namespace, get_obj_type, get_object_key, get_module_namespace, get_obj_type, get_object_key,
get_parse_method_of_class, get_scope_name, get_parse_method_of_class, get_scope_name,
is_class_member, parse_cb, resolve_symbol)
is_class_member, parse_cb, resolve_symbol, create_ellipsis_obj)
from .serialize import * from .serialize import *


__all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol', __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol',
'get_object_key', 'get_class_instance_type', 'is_class_member', 'get_obj_type', 'get_object_key', 'get_class_instance_type', 'is_class_member', 'get_obj_type',
'create_obj_instance', 'get_module_namespace', 'get_class_member_namespace_symbol', 'create_obj_instance', 'get_module_namespace', 'get_class_member_namespace_symbol',
'Parser', 'get_dataclass_attributes', 'get_dataclass_methods', 'dump_obj', 'load_obj', 'Parser', 'get_dataclass_attributes', 'get_dataclass_methods', 'dump_obj', 'load_obj',
'get_dataclass_methods', 'get_scope_name', 'create_slice_obj']
'get_dataclass_methods', 'get_scope_name', 'create_slice_obj', 'create_ellipsis_obj']

+ 6
- 1
mindspore/_extends/parse/parser.py View File

@@ -29,7 +29,7 @@ from mindspore.common.dtype import pytype_to_dtype
from mindspore.common.api import _MindSporeFunction from mindspore.common.api import _MindSporeFunction
from .namespace import CellNamespace, ClosureNamespace, ClassMemberNamespace from .namespace import CellNamespace, ClosureNamespace, ClassMemberNamespace
from .resources import parse_object_map, convert_object_map, trope_ns, SYMBOL_UNDEFINE, NO_IMPLEMENT from .resources import parse_object_map, convert_object_map, trope_ns, SYMBOL_UNDEFINE, NO_IMPLEMENT
from ..utils import Slice
from ..utils import Slice, Ellipsis_


# define return value # define return value
RET_SUCCESS = 0 RET_SUCCESS = 0
@@ -70,6 +70,11 @@ parse_expr_statement_white_list = (
"append", "append",
) )


def create_ellipsis_obj():
"""Create Slice object"""
return Ellipsis_()


def create_slice_obj(start, end, step): def create_slice_obj(start, end, step):
"""Create Slice object""" """Create Slice object"""
return Slice(start, end, step) return Slice(start, end, step)


+ 7
- 0
mindspore/_extends/utils.py View File

@@ -110,3 +110,10 @@ class Slice:
start: int start: int
end: int end: int
step: int step: int


@dataclass
class Ellipsis_:
"""
Ellipsis class
"""

+ 1
- 0
mindspore/ccsrc/pipeline/parse/parse_base.h View File

@@ -80,6 +80,7 @@ const char PYTHON_PARSE_GENERATE_SCOPE[] = "generate_scope";
const char PYTHON_PARSE_GET_SCOPE_NAME[] = "get_scope_name"; const char PYTHON_PARSE_GET_SCOPE_NAME[] = "get_scope_name";


const char PYTHON_PARSE_CLASS_SLICE[] = "create_slice_obj"; const char PYTHON_PARSE_CLASS_SLICE[] = "create_slice_obj";
const char PYTHON_PARSE_CLASS_ELLIPSIS[] = "create_ellipsis_obj";


// define the common name // define the common name
const char NAMED_PRIMITIVE_ITER[] = "iter"; const char NAMED_PRIMITIVE_ITER[] = "iter";


+ 6
- 0
mindspore/ccsrc/pipeline/static_analysis/prim.cc View File

@@ -298,6 +298,12 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
} else if (abs_base->isa<AbstractRef>()) { } else if (abs_base->isa<AbstractRef>()) {
auto value = abs_base->cast<AbstractRefPtr>()->ref(); auto value = abs_base->cast<AbstractRefPtr>()->ref();
dic = ConvertAbstractToPython(value); dic = ConvertAbstractToPython(value);
} else if (abs_base->isa<AbstractEllipsis>()) {
auto arg_slice = dyn_cast<AbstractEllipsis>(abs_base);
std::vector<int> shape;
dic["shape"] = shape;
dic["dtype"] = arg_slice->BuildType();
dic["value"] = BuildValue(arg_slice->BuildValue());
} else if (abs_base->isa<AbstractTuple>()) { } else if (abs_base->isa<AbstractTuple>()) {
auto arg_tuple = dyn_cast<AbstractTuple>(abs_base); auto arg_tuple = dyn_cast<AbstractTuple>(abs_base);
size_t len = arg_tuple->size(); size_t len = arg_tuple->size();


+ 2
- 0
mindspore/ccsrc/utils/convert_utils.cc View File

@@ -98,6 +98,8 @@ py::object ValuePtrToPyData(const ValuePtr &value) {
i++; i++;
} }
ret = rets; ret = rets;
} else if (value->isa<EllipsisObj>()) {
ret = parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_PARSE_CLASS_ELLIPSIS);
} else if (value->isa<ValueSlice>()) { } else if (value->isa<ValueSlice>()) {
auto slice = value->cast<ValueSlicePtr>(); auto slice = value->cast<ValueSlicePtr>();
auto start = ValuePtrToPyData(slice->start()); auto start = ValuePtrToPyData(slice->start());


+ 62
- 19
mindspore/ops/composite/multitype_ops/_multitype_ops_util.py View File

@@ -20,7 +20,7 @@ import numpy as np
from ...primitive import constexpr from ...primitive import constexpr
from ....common.tensor import Tensor from ....common.tensor import Tensor
from ....common import dtype as mstype from ....common import dtype as mstype
from ...._extends.utils import Slice
from ...._extends.utils import Slice, Ellipsis_


@constexpr @constexpr
def check_equal(param1, param2, msg="{},{}"): def check_equal(param1, param2, msg="{},{}"):
@@ -29,31 +29,40 @@ def check_equal(param1, param2, msg="{},{}"):
raise ValueError(msg.format(param1, param2)) raise ValueError(msg.format(param1, param2))
return param1 return param1



@constexpr
def check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size):
"""Checks the shape and size of the sensor and value."""
if data_shape == value_shape or data_size == value_size or value_size == 1:
return True
raise ValueError("The value(shape={}), can not assign to tensor(shape={}).".format(value_shape, data_shape))


@constexpr @constexpr
def check_tensor_setitem_index(index, element_type=None): def check_tensor_setitem_index(index, element_type=None):
"""Checks tuple index type of tensor assignment.""" """Checks tuple index type of tensor assignment."""
if index is None: if index is None:
raise ValueError("Tensor's index cannot be None.")
raise IndexError("Tensor's index cannot be None.")
# eg. Tensor[Slice] = u # eg. Tensor[Slice] = u
if isinstance(index, Slice): if isinstance(index, Slice):
return True return True
# eg. Tensor[tuple] = u # eg. Tensor[tuple] = u
if isinstance(index, tuple): if isinstance(index, tuple):
if not index: if not index:
raise ValueError("Tensor's index cannot be empty.")
raise IndexError("Tensor's index cannot be empty.")
# eg. Tensor[tuple(Slice...)] = u # eg. Tensor[tuple(Slice...)] = u
if isinstance(index[0], (Slice, int)):
if isinstance(index[0], (Slice, Ellipsis_, int)):
return True return True
raise ValueError("Index of type '{}' is not supported yet.".format(type(index[0])))
raise IndexError("Index of type '{}' is not supported yet.".format(type(index[0])))
# eg. Tensor[Tensor[dtype=bool]] = u # eg. Tensor[Tensor[dtype=bool]] = u
if index == mstype.tensor: if index == mstype.tensor:
if element_type is None or element_type != mstype.bool_: if element_type is None or element_type != mstype.bool_:
raise ValueError(
"The index of tensor should be a bool type tensor. \
{} type is not supported yet.".format(element_type))
raise TypeError(
"The index of tensor should be a bool type tensor. "
"{} type is not supported yet.".format(element_type))
return True return True


raise ValueError("Index of type '{}' is not supported yet.".format(type(index)))
raise IndexError("Index of type '{}' is not supported yet.".format(type(index)))




@constexpr @constexpr
@@ -90,10 +99,18 @@ def slice_expand(input_slices, shape):
# Slice or tuple(Slice...) # Slice or tuple(Slice...)
if isinstance(input_slices, Slice): if isinstance(input_slices, Slice):
slices = (input_slices,) slices = (input_slices,)
elif isinstance(input_slices, (tuple, list)) and input_slices and isinstance(input_slices[0], Slice):
slices = input_slices
elif isinstance(input_slices, (tuple, list)) and input_slices and isinstance(input_slices[0], (Slice, Ellipsis_)):
is_have_ellipsis = False
for _, element in enumerate(input_slices):
if isinstance(element, Ellipsis_):
is_have_ellipsis = True
break
if is_have_ellipsis:
slices = ellipsis2slice(input_slices, shape)
else:
slices = input_slices
else: else:
raise ValueError("Tensor's index type is not supported yet.")
raise IndexError("Tensor's index type is not supported yet.")


for s in slices: for s in slices:
start = 0 if (s.start is None) else s.start start = 0 if (s.start is None) else s.start
@@ -111,6 +128,26 @@ def slice_expand(input_slices, shape):
return begin, end, strides return begin, end, strides




def ellipsis2slice(input_, shape):
"""Converts ellipsis to slice."""
input_slice = input_
result = []
if isinstance(input_, Ellipsis_):
input_slice = (input_,)
ell_count = 0
for _, element in enumerate(input_slice):
if not isinstance(element, Ellipsis_):
result.append(element)
continue
ell_count += 1
if ell_count > 1:
raise IndexError("There cannot be more than one ellisis (...) in the index of the tensor, "
"but it is currently {}".format(input_slice))
for _ in range(len(shape) - len(input_slice) + 1):
result.append(Slice(None, None, None))
return tuple(result)


@constexpr @constexpr
def slice2indices(input_slices, shape): def slice2indices(input_slices, shape):
""" """
@@ -139,7 +176,7 @@ def slice2indices(input_slices, shape):
def check_indices(indices_size, index): def check_indices(indices_size, index):
"""Checks indices whether is empty.""" """Checks indices whether is empty."""
if indices_size < 1: if indices_size < 1:
raise ValueError("The tensor's index is unreasonable. index:{}".format(index))
raise IndexError("The tensor's index is unreasonable. index:{}".format(index))
return indices_size return indices_size




@@ -151,8 +188,8 @@ def check_indices_value_size(indices_size, value_size):
if value_size > 1: if value_size > 1:
if value_size != indices_size: if value_size != indices_size:
raise ValueError( raise ValueError(
"The value given to tensor does not match the index size. \
value size:{}, indics size:{}".format(value_size, indices_size))
"The value given to tensor does not match the index size,"
" value size:{}, indics size:{}".format(value_size, indices_size))
return value_size return value_size


@constexpr @constexpr
@@ -168,8 +205,11 @@ def integer_to_indices(index, shape):
def tuple_element_is_slice(indexs): def tuple_element_is_slice(indexs):
"""Judges tuple element type.""" """Judges tuple element type."""
if not indexs: if not indexs:
raise ValueError("Tensor's index cannot be empty.")
if isinstance(indexs, tuple) and isinstance(indexs[0], Slice):
raise IndexError("Tensor's index cannot be empty.")
if isinstance(indexs, tuple):
for _, ele in enumerate(indexs):
if not isinstance(ele, Slice):
return False
return True return True
return False return False


@@ -177,7 +217,10 @@ def tuple_element_is_slice(indexs):
def tuple_element_is_int(indexs): def tuple_element_is_int(indexs):
"""Judges tuple element type.""" """Judges tuple element type."""
if not indexs: if not indexs:
raise ValueError("Tensor's index cannot be empty.")
if isinstance(indexs, tuple) and isinstance(indexs[0], int):
raise IndexError("Tensor's index cannot be empty.")
if isinstance(indexs, tuple):
for _, ele in enumerate(indexs):
if not isinstance(ele, int):
return False
return True return True
return False return False

+ 33
- 4
mindspore/ops/composite/multitype_ops/setitem_impl.py View File

@@ -254,10 +254,10 @@ def _tensor_indices_tensor(data, data_shape, index, indices, value):
data_dtype = F.dtype(data) data_dtype = F.dtype(data)
indices_size = F.size(indices) indices_size = F.size(indices)
indices_size = mult_util.check_indices(indices_size, index) indices_size = mult_util.check_indices(indices_size, index)
update = F.fill(data_dtype, (indices_size,), 1)
update = F.fill(mstype.int32, (indices_size,), 1)
condition_1d = F.scatter_nd(indices, update, (data_size,)) condition_1d = F.scatter_nd(indices, update, (data_size,))
condition_1d = F.cast(condition_1d, mstype.bool_)
condition = F.reshape(condition_1d, data_shape) condition = F.reshape(condition_1d, data_shape)
condition = F.cast(condition, mstype.bool_)
value_fill = None value_fill = None
value_size = F.size(value) value_size = F.size(value)


@@ -336,10 +336,10 @@ def _tensor_indices_number(data, data_shape, index, indices, value):
data_dtype = F.dtype(data) data_dtype = F.dtype(data)
indices_size = F.size(indices) indices_size = F.size(indices)
indices_size = mult_util.check_indices(indices_size, index) indices_size = mult_util.check_indices(indices_size, index)
update = F.fill(data_dtype, (indices_size,), 1)
update = F.fill(mstype.int32, (indices_size,), 1)
condition_1d = F.scatter_nd(indices, update, (data_size,)) condition_1d = F.scatter_nd(indices, update, (data_size,))
condition_1d = F.cast(condition_1d, mstype.bool_)
condition = F.reshape(condition_1d, data_shape) condition = F.reshape(condition_1d, data_shape)
condition = F.cast(condition, mstype.bool_)
value_fill = F.fill(data_dtype, (indices_size,), value) value_fill = F.fill(data_dtype, (indices_size,), value)
value_1d = F.scatter_nd(indices, value_fill, (data_size,)) value_1d = F.scatter_nd(indices, value_fill, (data_size,))
u = F.reshape(value_1d, data_shape) u = F.reshape(value_1d, data_shape)
@@ -360,3 +360,32 @@ def _tensor_setitem_with_int_v2(data, index, value):
data_shape = F.shape(data) data_shape = F.shape(data)
indices = mult_util.integer_to_indices(index, data_shape) indices = mult_util.integer_to_indices(index, data_shape)
return _tensor_indices_tensor(data, data_shape, index, indices, value) return _tensor_indices_tensor(data, data_shape, index, indices, value)


@setitem.register("Tensor", "Ellipsis", "Number")
def _tensor_setitem_with_ellipsis_v1(data, index, value):
"""Syntax: A[...] = number."""
data_shape = F.shape(data)
data_dtype = F.dtype(data)
return F.fill(data_dtype, data_shape, value)


@setitem.register("Tensor", "Ellipsis", "Tensor")
def _tensor_setitem_with_ellipsis_v2(data, index, value):
"""Syntax: A[...] = Tensor."""
result = None
data_shape = F.shape(data)
data_dtype = F.dtype(data)
data_size = F.size(data)
value_shape = F.shape(value)
value_size = F.size(value)
check_result = mult_util.check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size)
if check_result:
if data_size == value_size:
result = F.reshape(value, data_shape)
result = F.cast(result, data_dtype)
elif value_size == 1:
param1 = F.fill(data_dtype, data_shape, 1)
param2 = F.cast(value, data_dtype)
result = F.tensor_mul(param1, param2)
return result

+ 135
- 60
tests/ut/python/ops/test_tensor_slice.py View File

@@ -103,6 +103,7 @@ class TensorAssignWithSliceError1(Cell):
a[1:3:-1,::] = b a[1:3:-1,::] = b
return a return a



class TensorAssignWithSliceError2(Cell): class TensorAssignWithSliceError2(Cell):
def __init__(self): def __init__(self):
super(TensorAssignWithSliceError2, self).__init__() super(TensorAssignWithSliceError2, self).__init__()
@@ -110,24 +111,29 @@ class TensorAssignWithSliceError2(Cell):
def construct(self, a, b): def construct(self, a, b):
a[1:3:-1] = b a[1:3:-1] = b
return a return a


class TensorAssignWithSlice2(Cell): class TensorAssignWithSlice2(Cell):
def __init__(self): def __init__(self):
super(TensorAssignWithSlice2, self).__init__() super(TensorAssignWithSlice2, self).__init__()


def construct(self, a, b):
def construct(self, a, b, ck):
a[1:5] = b a[1:5] = b
a[3:4] = 5 a[3:4] = 5
a[-1:1:-1] = b a[-1:1:-1] = b
a[-1:3:-1] = 5 a[-1:3:-1] = 5
a[::] = b a[::] = b
a[::] = 9 a[::] = 9
return a
z = a + ck
return z


class TensorAssignWithSlice(Cell): class TensorAssignWithSlice(Cell):
def __init__(self): def __init__(self):
super(TensorAssignWithSlice, self).__init__() super(TensorAssignWithSlice, self).__init__()
self.c = 2 self.c = 2


def construct(self, a, b):
def construct(self, a, b, ck):
a[1:3,::] = b a[1:3,::] = b
a[2:3:,3:] = b a[2:3:,3:] = b
a[::] = b a[::] = b
@@ -136,9 +142,10 @@ class TensorAssignWithSlice(Cell):
a[::,::] = self.c a[::,::] = self.c
a[2:3:,0:, 4:1:-1] = b a[2:3:,0:, 4:1:-1] = b
a[2:3:,0:, 4:1:-1] = self.c a[2:3:,0:, 4:1:-1] = self.c
z = a
z = a + ck
return z return z



def test_tensor_assign(): def test_tensor_assign():
context.set_context(mode=context.GRAPH_MODE, save_graphs=True) context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
net = TensorAssignWithSlice() net = TensorAssignWithSlice()
@@ -146,95 +153,145 @@ def test_tensor_assign():
net_e1 = TensorAssignWithSliceError1() net_e1 = TensorAssignWithSliceError1()
net_e2 = TensorAssignWithSliceError2() net_e2 = TensorAssignWithSliceError2()
a = np.arange(60).reshape(3,4,5) a = np.arange(60).reshape(3,4,5)
b = Tensor([1])
Ta = Tensor(a)
Ta4d = Tensor(a.reshape(1,3,4,5))
Tb= Tensor([1,3])
Tc= Tensor([])
t = Tensor([1, 2, 3, 4, 5, 6, 7, 8])
net(Ta, b)
net2(t, b)
ck = np.arange(60).reshape(3,4,5)
b = Tensor([1], dtype=mstype.float32)
Ta = Tensor(a, dtype=mstype.float32)
Tck = Tensor(ck, dtype=mstype.float32)
Ta4d = Tensor(a.reshape(1,3,4,5), dtype=mstype.float32)
Ta4d_ck = Tensor(ck.reshape(1,3,4,5), dtype=mstype.float32)
Tb= Tensor([1,3], dtype=mstype.float32)
Tc= Tensor([], dtype=mstype.float32)
t = Tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=mstype.float32)
tck = Tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=mstype.float32)
net(Ta, b, Tck)
net2(t, b, tck)
# Error for A[Slice] = Number # Error for A[Slice] = Number
# 1. A[Slice] = Number, Slice error # 1. A[Slice] = Number, Slice error
with pytest.raises(ValueError):
with pytest.raises(IndexError):
net_e2(t, 2) net_e2(t, 2)


# Error for A[Slice] = U, U is a Tensor # Error for A[Slice] = U, U is a Tensor
# 1. A[Slice] = U, u.size is error # 1. A[Slice] = U, u.size is error
with pytest.raises(ValueError): with pytest.raises(ValueError):
net2(t, Tb)
net2(t, Tb, tck)
# 2. A[Slice] = U, U is empty # 2. A[Slice] = U, U is empty
with pytest.raises(ValueError): with pytest.raises(ValueError):
net2(t, Tc)
net2(t, Tc, tck)
# 3. A[Slice] = U, U.size error # 3. A[Slice] = U, U.size error
with pytest.raises(ValueError): with pytest.raises(ValueError):
net2(t, Tb)
net2(t, Tb, tck)


# Error for A[Tuple(Slice...)] = Tensor # Error for A[Tuple(Slice...)] = Tensor
# 1. A[Tuple(Slice...)] = U, U is empty # 1. A[Tuple(Slice...)] = U, U is empty
with pytest.raises(ValueError): with pytest.raises(ValueError):
net(Ta, Tc)
net(Ta, Tc, Tck)
# 2. A[Tuple(Slice...)] = U, U.size error # 2. A[Tuple(Slice...)] = U, U.size error
with pytest.raises(ValueError): with pytest.raises(ValueError):
net(Ta, Tb)
net(Ta, Tb, Tck)
# 3. A[Tuple(Slice...)] = U, Slice error # 3. A[Tuple(Slice...)] = U, Slice error
with pytest.raises(ValueError):
with pytest.raises(IndexError):
net_e1(Ta, b) net_e1(Ta, b)


# Error for A[Tuple(Slice...)] = Number # Error for A[Tuple(Slice...)] = Number
# 1. A[Tuple(Slice...)] = Number, Slice error # 1. A[Tuple(Slice...)] = Number, Slice error
with pytest.raises(ValueError):
with pytest.raises(IndexError):
net_e1(Ta, 2) net_e1(Ta, 2)


net = TensorAssignWithInteger() net = TensorAssignWithInteger()
# Error for A[Number] = scalar/Tensor # Error for A[Number] = scalar/Tensor
# 1. A[Number] = U, U is a Tensor, u.size not match # 1. A[Number] = U, U is a Tensor, u.size not match
with pytest.raises(ValueError): with pytest.raises(ValueError):
net(Ta, Tb)
net(Ta, Tb, Tck)
with pytest.raises(ValueError): with pytest.raises(ValueError):
net(Ta, Tc)
net(Ta, Tc, Tck)
# 2. A[Number] = U, the number index error # 2. A[Number] = U, the number index error
with pytest.raises(IndexError): with pytest.raises(IndexError):
net(Ta4d, b)
net(Ta4d, b, Ta4d_ck)


# Error for A[(n,m)] = scalar/Tensor # Error for A[(n,m)] = scalar/Tensor
# 1. A[(n,m)] = U, U is a tensor. u.size not match # 1. A[(n,m)] = U, U is a tensor. u.size not match
net = TensorAssignWithTupleInteger() net = TensorAssignWithTupleInteger()
with pytest.raises(ValueError): with pytest.raises(ValueError):
net(Ta, Tc)
net(Ta, Tc, Tck)
with pytest.raises(ValueError): with pytest.raises(ValueError):
net(Ta, Tb)
net(Ta, Tb, Tck)
# 2. A[(n,m)] = U, the number index error # 2. A[(n,m)] = U, the number index error
with pytest.raises(IndexError): with pytest.raises(IndexError):
net(Ta4d, b)
net(Ta4d, b, Ta4d_ck)

#Error for A[...] = U or A[1:, ...] = u
#1. A[...] = scalar/tensor
net = TensorAssignWithEllipsis()
net(Ta, Ta4d)
with pytest.raises(ValueError):
net(Ta, Tc)
with pytest.raises(ValueError):
net(Ta, Tb)
#2. A[::, 1:, ...] = scalar/tensor
net = TensorAssignWithTupleEllipsis()
net(Ta, b)
with pytest.raises(ValueError):
net(Ta, Tc)
with pytest.raises(ValueError):
net(Ta, Tb)


class TensorAssignWithTupleEllipsis2(Cell):
def __init__(self):
super(TensorAssignWithTupleEllipsis2, self).__init__()
def construct(self, a, b):
a[1:, ..., ::] = b
return a


class TensorAssignWithTupleEllipsis(Cell):
def __init__(self):
super(TensorAssignWithTupleEllipsis, self).__init__()
def construct(self, a, b):
a[:2, ...] = 1
a[1:, ...] = b
return a


class TensorAssignWithEllipsis(Cell):
def __init__(self):
super(TensorAssignWithEllipsis, self).__init__()
def construct(self, a, b):
a[...] = 1
a[...] = b
return a



class TensorAssignWithInteger(Cell): class TensorAssignWithInteger(Cell):
def __init__(self): def __init__(self):
super(TensorAssignWithInteger, self).__init__() super(TensorAssignWithInteger, self).__init__()


def construct(self, a, b):
def construct(self, a, b, ck):
a[1] = 1 a[1] = 1
a[0] = b a[0] = b
return a
z = a + ck
return z


class TensorAssignWithTupleInteger(Cell): class TensorAssignWithTupleInteger(Cell):
def __init__(self): def __init__(self):
super(TensorAssignWithTupleInteger, self).__init__() super(TensorAssignWithTupleInteger, self).__init__()


def construct(self, a, b):
def construct(self, a, b, ck):
a[(1)] = 1 a[(1)] = 1
a[(1)] = b a[(1)] = b
a[(1,1)] = b a[(1,1)] = b
a[(1,1)] = 1 a[(1,1)] = 1
return a
z = a + ck
return z


class TensorAssignWithBoolTensorIndex(Cell): class TensorAssignWithBoolTensorIndex(Cell):
def __init__(self): def __init__(self):
super(TensorAssignWithBoolTensorIndex, self).__init__() super(TensorAssignWithBoolTensorIndex, self).__init__()
self.t = Tensor(np.arange(60).reshape([3,4,5]), dtype = mstype.float64)
self.t = Tensor(np.arange(60).reshape([3,4,5]), dtype = mstype.float32)
self.u_scalar = 5


def construct(self, a, b, c, u_tensor, _scalar):
a[c] = u_scalar
def construct(self, a, b, c, u_tensor):
a[c] = self.u_scalar
a[b] = u_tensor a[b] = u_tensor
z = a + self.t z = a + self.t
return z return z
@@ -252,15 +309,16 @@ class TensorAssignWithBoolTensorIndexError(Cell):
class TensorAssignWithBoolTensorIndex2(Cell): class TensorAssignWithBoolTensorIndex2(Cell):
def __init__(self): def __init__(self):
super(TensorAssignWithBoolTensorIndex2, self).__init__() super(TensorAssignWithBoolTensorIndex2, self).__init__()
self.t = Tensor(np.arange(6).reshape([2, 3]), dtype=mstype.float64)
self.t = Tensor(np.arange(60).reshape([3,4,5]), dtype = mstype.float64)
self.t = Tensor(np.arange(6).reshape([2, 3]), dtype=mstype.float32)
self.t = Tensor(np.arange(60).reshape([3,4,5]), dtype = mstype.float32)
self.u_scalar = 5


def construct(self, a, u_tensor, _scalar):
def construct(self, a, u_tensor):
a[a > 8] = u_tensor a[a > 8] = u_tensor
a[a >= 6] = u_scalar
a[a < 3] = u_scalar
a[a >= 6] = self.u_scalar
a[a < 3] = self.u_scalar
a[a <= 5] = u_tensor a[a <= 5] = u_tensor
a[a == 5] = u_scalar
a[a == 5] = self.u_scalar
z = a + self.t z = a + self.t
return z return z


@@ -274,36 +332,41 @@ class TensorAssignWithBoolTensorIndex2Error(Cell):
return a return a




a = np.random.uniform(1,10,[3,4,5])
a = np.arange(60).reshape(3, 4, 5)
ck = np.arange(60).reshape(3, 4, 5)
a4 = np.arange(60).reshape(3, 2, 2, 5)
b = a > 5 b = a > 5
c = a < 3 c = a < 3
Ta = Tensor(a)
Ta = Tensor(a, dtype=mstype.float32)
Tck = Tensor(ck, dtype=mstype.float32)
Ta4 = Tensor(a4, dtype=mstype.float32)
Tb = Tensor(b) Tb = Tensor(b)
Tc = Tensor(c) Tc = Tensor(c)
Td = Tensor([True, True]) Td = Tensor([True, True])
u_tensor = Tensor([1])
u_tensor_error = Tensor([1, 2])
t_1d = Tensor([1, 2, 3, 4, 5, 6, 7, 8])
u_tensor = Tensor([1], dtype=mstype.float32)
u_tensor_error = Tensor([1, 2], dtype=mstype.float32)
t_1d = Tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=mstype.float32)
tck_1d = Tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=mstype.float32)
u_scalar = 5 u_scalar = 5


def test_tensor_assign_bool_index(): def test_tensor_assign_bool_index():
net1 = TensorAssignWithBoolTensorIndex() net1 = TensorAssignWithBoolTensorIndex()
net2 = TensorAssignWithBoolTensorIndex2() net2 = TensorAssignWithBoolTensorIndex2()
net1(Ta, Tb, Tc, u_tensor, u_scalar)
net1(Ta, Tb, Tc, u_tensor, u_scalar)
with pytest.raises(ValueError):
net1(Ta, Td, Tc, u_tensor, u_scalar)
with pytest.raises(ValueError):
net1(Ta, u_tensor, Tc, u_tensor, u_scalar)
net1(Ta, Tb, Tc, u_tensor)
net1(Ta, Tb, Tc, u_tensor)
with pytest.raises(ValueError): with pytest.raises(ValueError):
net1(Ta, Tb, Td, u_tensor, u_scalar)
net1(Ta, Td, Tc, u_tensor)
with pytest.raises(TypeError):
net1(Ta, u_tensor, Tc, u_tensor)
with pytest.raises(ValueError): with pytest.raises(ValueError):
net1(Ta, Tb, Ta, u_tensor, u_scalar)
net1(Ta, Tb, Td, u_tensor)
with pytest.raises(TypeError):
net1(Ta, Tb, Ta, u_tensor)
with pytest.raises(ValueError): with pytest.raises(ValueError):
net1(Ta, Tb, Tc, u_tensor_error, u_scalar)
net1(Ta, Tb, Tc, u_tensor_error)
# net1(Ta, u_tensor, Tc, u_tensor_error, u_scalar) # net1(Ta, u_tensor, Tc, u_tensor_error, u_scalar)
with pytest.raises(ValueError): with pytest.raises(ValueError):
net2(Ta, u_tensor_error, u_scalar)
net2(Ta, u_tensor_error)
net3 = TensorAssignWithBoolTensorIndexError() net3 = TensorAssignWithBoolTensorIndexError()
with pytest.raises(AttributeError): with pytest.raises(AttributeError):
net3(Ta, Tb, Tc, u_tensor) net3(Ta, Tb, Tc, u_tensor)
@@ -316,29 +379,41 @@ def test_tensor_assign_bool_index():
net4(Ta, u_scalar) net4(Ta, u_scalar)


test_cases = [ test_cases = [
('TensorAssignWithTupleEllipsis2', {
'block': TensorAssignWithTupleEllipsis2(),
'desc_inputs': [Ta4, u_tensor],
}),
('TensorAssignWithTupleEllipsis', {
'block': TensorAssignWithTupleEllipsis(),
'desc_inputs': [Ta, u_tensor],
}),
('TensorAssignWithEllipsis', {
'block': TensorAssignWithEllipsis(),
'desc_inputs': [Ta, u_tensor],
}),
('TensorAssignWithTupleInteger', { ('TensorAssignWithTupleInteger', {
'block': TensorAssignWithTupleInteger(), 'block': TensorAssignWithTupleInteger(),
'desc_inputs': [Ta, u_tensor],
'desc_inputs': [Ta, u_tensor, Tck],
}), }),
('TensorAssignWithInteger', { ('TensorAssignWithInteger', {
'block': TensorAssignWithInteger(), 'block': TensorAssignWithInteger(),
'desc_inputs': [Ta, u_tensor],
'desc_inputs': [Ta, u_tensor, Tck],
}), }),
('TensorAssignWithSlice', { ('TensorAssignWithSlice', {
'block': TensorAssignWithSlice(), 'block': TensorAssignWithSlice(),
'desc_inputs': [Ta, u_tensor],
'desc_inputs': [Ta, u_tensor, Tck],
}), }),
('TensorAssignWithSlice2', { ('TensorAssignWithSlice2', {
'block': TensorAssignWithSlice2(), 'block': TensorAssignWithSlice2(),
'desc_inputs': [t_1d, u_tensor],
'desc_inputs': [t_1d, u_tensor, tck_1d],
}), }),
('TensorAssignWithBoolTensorIndex', { ('TensorAssignWithBoolTensorIndex', {
'block': TensorAssignWithBoolTensorIndex(), 'block': TensorAssignWithBoolTensorIndex(),
'desc_inputs': [Ta, Tb, Tc, u_tensor, u_scalar],
'desc_inputs': [Ta, Tb, Tc, u_tensor],
}), }),
('TensorAssignWithBoolTensorIndex2', { ('TensorAssignWithBoolTensorIndex2', {
'block': TensorAssignWithBoolTensorIndex2(), 'block': TensorAssignWithBoolTensorIndex2(),
'desc_inputs': [Ta, u_tensor, u_scalar],
'desc_inputs': [Ta, u_tensor],
}), }),
('SlicePositive', { ('SlicePositive', {
'block': NetWorkSlicePositive(), 'block': NetWorkSlicePositive(),


Loading…
Cancel
Save