Merge pull request !504 from candanzg/tensor_assign_with_slice
@@ -18,7 +18,7 @@ Interfaces for parser module in c++.
from .parser import (Parser, create_obj_instance, generate_scope,
                     get_bprop_method_of_class, get_class_instance_type,
                     get_class_member_namespace_symbol,
                     get_class_member_namespace_symbol, create_slice_obj,
                     get_dataclass_attributes, get_dataclass_methods,
                     get_module_namespace, get_obj_type, get_object_key,
                     get_parse_method_of_class, get_scope_name,
@@ -29,4 +29,4 @@ __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class',
           'get_object_key', 'get_class_instance_type', 'is_class_member', 'get_obj_type',
           'create_obj_instance', 'get_module_namespace', 'get_class_member_namespace_symbol',
           'Parser', 'get_dataclass_attributes', 'get_dataclass_methods', 'dump_obj', 'load_obj',
           'get_dataclass_methods', 'get_scope_name']
           'get_dataclass_methods', 'get_scope_name', 'create_slice_obj']
@@ -29,6 +29,7 @@ from mindspore.common.dtype import pytype_to_dtype
from mindspore.common.api import _MindSporeFunction
from .namespace import CellNamespace, ClosureNamespace, ClassMemberNamespace
from .resources import parse_object_map, convert_object_map, trope_ns, SYMBOL_UNDEFINE, NO_IMPLEMENT
from ..utils import Slice

# define return value
RET_SUCCESS = 0
@@ -69,6 +70,10 @@ parse_expr_statement_white_list = (
    "append",
)


def create_slice_obj(start, end, step):
    """Create Slice object"""
    return Slice(start, end, step)
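Illustrative sketch (not part of this changeset): create_slice_obj is the Python hook the C++ side invokes through PYTHON_PARSE_CLASS_SLICE (see the parse_base hunk below) to rebuild a ValueSlice as a Python object. Called directly it would behave like:

    # Assumes the Slice dataclass imported above from ..utils.
    s = create_slice_obj(1, 5, 2)    # front-end analogue of the literal slice 1:5:2
    assert (s.start, s.end, s.step) == (1, 5, 2)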


def parse_cb(func, parse_method=None):
    """Implements the function of parse."""
@@ -19,6 +19,7 @@ import logging
import os
import inspect
from functools import wraps
from dataclasses import dataclass


def cal_sha256(file_path):
@@ -99,3 +100,13 @@ def cell_attr_register(fn=None, attrs=None):
    if fn is not None:
        return wrap_cell(fn)
    return wrap_cell


@dataclass
class Slice:
    """
    Slice record (start, end, step) used by the graph front end in place of the builtin slice.
    """
    start: int
    end: int
    step: int
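Illustrative note (an assumption, not stated in the diff): a dataclass gives the front end a plain, comparable record that passes through the @constexpr checks below, unlike the builtin slice object; None fields mean "unspecified", mirroring slice semantics:

    # Hypothetical example; parse.create_slice_obj above builds the same record.
    s = Slice(start=None, end=4, step=-1)    # front-end analogue of a[:4:-1]
    assert s.start is None and s.step == -1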
@@ -123,6 +123,9 @@ class ValueSlice : public Value {
  abstract::AbstractBasePtr ToAbstract() override;
  std::string DumpText() const override { return ToString(); }

  ValuePtr start() const { return start_; }
  ValuePtr stop() const { return stop_; }
  ValuePtr step() const { return step_; }

 private:
  ValuePtr start_;
@@ -79,6 +79,8 @@ const char PYTHON_PARSE_EXPAND_EXPR_STATEMENT[] = "expand_expr_statement";
const char PYTHON_PARSE_GENERATE_SCOPE[] = "generate_scope";
const char PYTHON_PARSE_GET_SCOPE_NAME[] = "get_scope_name";
const char PYTHON_PARSE_CLASS_SLICE[] = "create_slice_obj";

// define the common name
const char NAMED_PRIMITIVE_ITER[] = "iter";
const char NAMED_PRIMITIVE_NEXT[] = "next";
@@ -289,6 +289,13 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
    dic["shape"] = shape;
    dic["dtype"] = abs_base->BuildType();
    dic["value"] = BuildValue(abs_base->BuildValue());
  } else if (abs_base->isa<AbstractSlice>()) {
    auto arg_slice = dyn_cast<AbstractSlice>(abs_base);
    std::vector<int> shape;
    dic["shape"] = shape;
    dic["dtype"] = arg_slice->BuildType();
    dic["value"] = BuildValue(arg_slice->BuildValue());
  } else if (abs_base->isa<AbstractTuple>()) {
    auto arg_tuple = dyn_cast<AbstractTuple>(abs_base);
    size_t len = arg_tuple->size();
@@ -28,6 +28,7 @@
#include "ir/meta_tensor.h"
#include "pipeline/parse/parse.h"
#include "pipeline/parse/parse_base.h"
#include "ir/value.h"

namespace mindspore {
@@ -97,6 +98,13 @@ py::object ValuePtrToPyData(const ValuePtr &value) {
      i++;
    }
    ret = rets;
  } else if (value->isa<ValueSlice>()) {
    auto slice = value->cast<ValueSlicePtr>();
    auto start = ValuePtrToPyData(slice->start());
    auto end = ValuePtrToPyData(slice->stop());
    auto step = ValuePtrToPyData(slice->step());
    ret = parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_PARSE_CLASS_SLICE, start, end,
                                          step);
  } else if (value->isa<Type>()) {
    py::tuple v(1);
    v[0] = value->cast<TypePtr>();
@@ -15,7 +15,43 @@
"""constexpr util"""
import numpy as np

from ...primitive import constexpr
from ....common.tensor import Tensor
from ....common import dtype as mstype
from ...._extends.utils import Slice


@constexpr
def check_equal(param1, param2, msg="{},{}"):
    """Check that two parameters are equal; raise ValueError with the formatted message otherwise."""
    if param1 != param2:
        raise ValueError(msg.format(param1, param2))
    return param1


@constexpr
def check_tensor_setitem_index(index, element_type=None):
    """Check the index type for tensor assignment."""
    if index is None:
        raise ValueError("Tensor's index cannot be None.")
    # e.g. Tensor[Slice] = u
    if isinstance(index, Slice):
        return True
    # e.g. Tensor[Tuple] = u
    if isinstance(index, tuple):
        if not index:
            raise ValueError("Tensor's index cannot be empty.")
        # e.g. Tensor[Tuple(Slice...)] = u
        if not isinstance(index[0], Slice):
            raise ValueError("Index of type '{}' is not supported yet.".format(type(index[0])))
        return True
    # e.g. Tensor[Tensor[dtype=bool]] = u
    if index == mstype.tensor:
        if element_type is None or element_type != mstype.bool_:
            raise ValueError(
                "The index of tensor should be a bool type tensor. "
                "{} type is not supported yet.".format(element_type))
        return True
    raise ValueError("Index of type '{}' is not supported yet.".format(type(index)))
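Illustrative calls covering the accepted index forms (hypothetical, reading the checks as plain Python and ignoring the @constexpr wrapper):

    assert check_tensor_setitem_index(Slice(1, 3, None))                            # A[1:3] = u
    assert check_tensor_setitem_index((Slice(1, 3, None), Slice(None, None, None)))  # A[1:3, ::] = u
    assert check_tensor_setitem_index(mstype.tensor, mstype.bool_)                  # A[bool_tensor] = u
    # check_tensor_setitem_index(5) would raise ValueError: unsupported index type.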
@constexpr
@@ -43,3 +79,84 @@ def error_msg(msg="", format_values=""):
    """
    raise ValueError(msg.format(*format_values))

def slice_expand(input_slices, shape):
    """
    Expand a slice, or a tuple of slices, into per-dimension begin/end/stride lists.

    Inputs:
        input_slices (Slice or Tuple(Slice, ...)): A slice or a tuple of slices.
        shape (Tuple): The shape of the tensor; a tuple of integers.

    Outputs:
        (List, List, List), expressed as (begins, ends, strides).
    """
    begin = []
    end = []
    strides = []
    index = 0
    slices = None
    # Slice or Tuple(Slice...)
    if isinstance(input_slices, Slice):
        slices = (input_slices,)
    elif isinstance(input_slices, (tuple, list)) and input_slices and isinstance(input_slices[0], Slice):
        slices = input_slices
    else:
        raise ValueError("Tensor's index type is not supported yet.")
    for s in slices:
        start = 0 if (s.start is None) else s.start
        stop = shape[index] if (s.end is None) else s.end
        step = 1 if (s.step is None) else s.step
        begin.append(start)
        end.append(stop)
        strides.append(step)
        index += 1
    while index < len(shape):
        begin.append(0)
        end.append(shape[index])
        strides.append(1)
        index += 1
    return begin, end, strides
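Worked example (illustrative, for a tensor of shape (3, 4, 5)): dimensions not covered by the slices are padded with full ranges, so a[1:3] expands to a[1:3, 0:4, 0:5]:

    begin, end, strides = slice_expand(Slice(1, 3, None), (3, 4, 5))
    assert begin == [1, 0, 0] and end == [3, 4, 5] and strides == [1, 1, 1]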

@constexpr
def slice2indices(input_slices, shape):
    """
    Convert a slice, or a tuple of slices, to a tensor of flat indices.

    Inputs:
        input_slices (Slice or Tuple(Slice, ...)): A slice or a tuple of slices.
        shape (Tuple): The shape of the tensor; a tuple of integers.

    Outputs:
        Tensor, the shape is (n, 1).
    """
    begin, end, strides = slice_expand(input_slices, shape)
    np_r = []
    for i, element in enumerate(shape):
        s = begin[i] if (begin[i] >= 0) else (element + begin[i])
        e = end[i] if (end[i] >= 0) else (element + end[i])
        np_r.append(np.r_[s:e:strides[i]])
    # Reference: np.ravel_multi_index((np.ix_(np.r_[1:3:1], np.r_[0:4:1], np.r_[4:0:-1])), a.shape)
    np_ix = np.ix_(*np_r)
    ravel = np.ravel_multi_index(np_ix, shape)
    ravel = Tensor(ravel.reshape(-1, 1), dtype=mstype.int32)
    return ravel
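Illustrative sketch: the same computation in plain numpy for a[1:3] on shape (3, 4, 5); slice2indices only wraps the result in an int32 Tensor of shape (n, 1):

    np_ix = np.ix_(np.r_[1:3:1], np.r_[0:4:1], np.r_[0:5:1])
    flat = np.ravel_multi_index(np_ix, (3, 4, 5)).reshape(-1, 1)
    assert flat.shape == (40, 1) and flat.min() == 20 and flat.max() == 59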

@constexpr
def check_indices(indices_size, index):
    """Check that the expanded index selects at least one element."""
    if indices_size < 1:
        raise ValueError("The tensor index selects no elements. index: {}".format(index))
    return indices_size

@constexpr
def check_indices_value_size(indices_size, value_size):
    """Check that the assigned value's size matches the number of indexed elements."""
    if value_size < 1:
        raise ValueError("The value assigned to tensor cannot be empty.")
    if value_size > 1:
        if value_size != indices_size:
            raise ValueError(
                "The value given to tensor does not match the index size. "
                "value size: {}, indices size: {}".format(value_size, indices_size))
    return value_size
@@ -138,25 +138,23 @@ def _tensor_setitem_by_tensor_v1(data, index, value_tensor):
    Outputs:
        Tensor, element type and shape are the same as data.
    """
    result = None
    index_dtype = F.dtype(index)
    index_shape = F.shape(index)
    is_bool = mult_util.is_same_type(index_dtype, mstype.bool_)
    if not is_bool:
        return mult_util.error_msg(
            "The tensor index should be a bool type tensor. {} type tensor is not supported yet.", (index_dtype,))
    data_shape = F.shape(data)
    if index_shape != data_shape:
        return mult_util.error_msg(
            "The tensor(shape={}) and tensor index(shape={}) should be the same shape.", (data_shape, index_shape))
    size = F.size(value_tensor)
    if size != 1:
        return mult_util.error_msg(
            "When assign value is a tensor, its size should be 1, but current size is {}.", (size,))
    dtype = F.dtype(data)
    u_cast = F.cast(value_tensor, dtype)
    one_data = F.ones_like(data)
    u = F.tensor_mul(one_data, u_cast)
    return F.select(index, u, data)
    check_result = mult_util.check_tensor_setitem_index(mstype.tensor, index_dtype)
    if check_result:
        data_shape = F.shape(data)
        data_shape = mult_util.check_equal(data_shape, index_shape,
                                           "The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
        size = F.size(value_tensor)
        size = mult_util.check_equal(1, size,
                                     "When assign value is a tensor, its size should be {}, but current size is {}.")
        dtype = F.dtype(data)
        u_cast = F.cast(value_tensor, dtype)
        one_data = F.ones_like(data)
        u = F.tensor_mul(one_data, u_cast)
        result = F.select(index, u, data)
    return result

@setitem.register("Tensor", "Tensor", "Number")
@@ -179,16 +177,162 @@ def _tensor_setitem_by_tensor_v2(data, index, value):
    Outputs:
        Tensor, element type and shape are the same as data.
    """
    result = None
    index_dtype = F.dtype(index)
    index_shape = F.shape(index)
    is_bool = mult_util.is_same_type(index_dtype, mstype.bool_)
    if not is_bool:
        return mult_util.error_msg(
            "The tensor index should be a bool type tensor. {} type tensor is not supported yet.", (index_dtype,))
    shape = F.shape(data)
    if index_shape != shape:
        return mult_util.error_msg(
            "The tensor(shape={}) and tensor index(shape={}) should be the same shape.", (shape, index_shape))
    dtype = F.dtype(data)
    u = F.fill(dtype, shape, value)
    return F.select(index, u, data)
    check_result = mult_util.check_tensor_setitem_index(mstype.tensor, index_dtype)
    if check_result:
        shape = F.shape(data)
        shape = mult_util.check_equal(
            shape, index_shape, "The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
        dtype = F.dtype(data)
        u = F.fill(dtype, shape, value)
        result = F.select(index, u, data)
    return result


@setitem.register("Tensor", "Slice", "Tensor")
def _tensor_setitem_with_slice_v3(data, input_slice, value):
    """
    Tensor assignment.

    Note:
        Syntax support: A[Slice] = U
        Constraint: A is a Tensor.
            Slice like "1:3"
            U is a Tensor(size=1) or Tensor(size>1)

    Inputs:
        data (Tensor): Assigned tensor.
        input_slice (Slice): Slice expression.
        value (Tensor): Assignment value.

    Outputs:
        Tensor, element type and shape are the same as data.
    """
    return _tensor_assign_tensor(data, input_slice, value)


@setitem.register("Tensor", "Tuple", "Tensor")
def _tensor_setitem_with_slice_v4(data, input_slice, value):
    """
    Tensor assignment.

    Note:
        Syntax support: A[Tuple(Slice)] = U
        Constraint: A is a Tensor.
            Slice like "1:3, ::, :4:-1"
            U is a Tensor(size=1) or Tensor(size>1)

    Inputs:
        data (Tensor): Assigned tensor.
        input_slice (Tuple(Slice)): Slice expression.
        value (Tensor): Assignment value.

    Outputs:
        Tensor, element type and shape are the same as data.
    """
    return _tensor_assign_tensor(data, input_slice, value)


def _tensor_assign_tensor(data, input_slice, value):
    """Assign a tensor value to a tensor by slice."""
    # 1. Build the boolean mask (condition) from the slice.
    result = None
    check_result = mult_util.check_tensor_setitem_index(input_slice)
    if check_result:
        data_shape = F.shape(data)
        data_size = F.size(data)
        data_dtype = F.dtype(data)
        indices = mult_util.slice2indices(input_slice, data_shape)
        indices_size = F.size(indices)
        indices_size = mult_util.check_indices(indices_size, input_slice)
        update = F.fill(data_dtype, (indices_size,), 1)
        condition_1d = F.scatter_nd(indices, update, (data_size,))
        condition_1d = F.cast(condition_1d, mstype.bool_)
        condition = F.reshape(condition_1d, data_shape)
        # 2. Scatter the assigned value(s) into the update tensor u.
        value_fill = None
        value_size = F.size(value)
        value_size = mult_util.check_indices_value_size(indices_size, value_size)
        if value_size == 1:
            value_fill = F.fill(data_dtype, (indices_size,), 1)
            value = F.cast(value, data_dtype)
            value_fill = F.tensor_mul(value_fill, value)
        elif value_size > 1:
            value_fill = F.reshape(value, (indices_size,))
        value_1d = F.scatter_nd(indices, value_fill, (data_size,))
        u = F.reshape(value_1d, data_shape)
        # A[slice] = u  ->  A[B] = U  ->  select(B, U, A)
        result = F.select(condition, u, data)
    return result
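Illustrative sketch of this lowering in plain numpy (the graph version uses the ScatterNd and Select primitives): scatter ones at the sliced positions to build a boolean mask, scatter the assigned value(s) to build the update tensor u, then select:

    def _assign_by_indices_sketch(data, flat_idx, value):
        # data: np.ndarray; flat_idx: (n, 1) flat positions from slice2indices;
        # value: a scalar or an n-element array (assumed).
        mask = np.zeros(data.size, dtype=bool)
        mask[flat_idx.ravel()] = True
        u = np.zeros(data.size, dtype=data.dtype)
        u[flat_idx.ravel()] = value
        return np.where(mask.reshape(data.shape), u.reshape(data.shape), data)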


@setitem.register("Tensor", "Slice", "Number")
def _tensor_setitem_with_slice_v1(data, input_slice, value):
    """
    Tensor assignment.

    Note:
        Syntax support: A[Slice] = u
        Constraint: A is a Tensor.
            Slice like "1:3"
            u is a scalar

    Inputs:
        data (Tensor): Assigned tensor.
        input_slice (Slice): Slice expression.
        value (Number): Assignment value.

    Outputs:
        Tensor, element type and shape are the same as data.
    """
    return _tensor_assign_number(data, input_slice, value)


@setitem.register("Tensor", "Tuple", "Number")
def _tensor_setitem_with_slice_v2(data, input_slice, value):
    """
    Tensor assignment.

    Note:
        Syntax support: A[Tuple(Slice)] = u
        Constraint: A is a Tensor.
            Slice like "1:3, ::, :4:-1"
            u is a scalar

    Inputs:
        data (Tensor): Assigned tensor.
        input_slice (Tuple(Slice)): Slice expression.
        value (Number): Assignment value.

    Outputs:
        Tensor, element type and shape are the same as data.
    """
    return _tensor_assign_number(data, input_slice, value)


def _tensor_assign_number(data, input_slice, value):
    """Assign a scalar value to a tensor by slice."""
    # 1. Build the boolean mask (condition) from the slice.
    check_result = mult_util.check_tensor_setitem_index(input_slice)
    result = None
    if check_result:
        data_shape = F.shape(data)
        data_size = F.size(data)
        data_dtype = F.dtype(data)
        indices = mult_util.slice2indices(input_slice, data_shape)
        indices_size = F.size(indices)
        indices_size = mult_util.check_indices(indices_size, input_slice)
        update = F.fill(data_dtype, (indices_size,), 1)
        condition_1d = F.scatter_nd(indices, update, (data_size,))
        condition_1d = F.cast(condition_1d, mstype.bool_)
        condition = F.reshape(condition_1d, data_shape)
        # 2. Scatter the scalar value into the update tensor u.
        value_fill = F.fill(data_dtype, (indices_size,), value)
        value_1d = F.scatter_nd(indices, value_fill, (data_size,))
        u = F.reshape(value_1d, data_shape)
        # A[slice] = u  ->  A[B] = U  ->  select(B, U, A)
        result = F.select(condition, u, data)
    return result
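Illustrative end-to-end reading (hypothetical class name; Cell and Tensor as imported in the test file below): inside a Cell's construct, a[1:3] = 5 becomes setitem(a, Slice(1, 3, None), 5), dispatches to _tensor_setitem_with_slice_v1, and lowers to the mask-and-select pipeline above:

    class AssignScalarBySlice(Cell):
        def construct(self, a):
            a[1:3] = 5    # -> _tensor_setitem_with_slice_v1 -> _tensor_assign_number
            return a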
@@ -68,6 +68,7 @@ tuple_to_array = P.TupleToArray()
scalar_cast = P.ScalarCast()
print_ = P.Print()
expand_dims = P.ExpandDims()
scatter_nd = P.ScatterNd()

tuple_setitem = Primitive('tuple_setitem')
tuple_getitem = Primitive('tuple_getitem')
@@ -94,10 +94,101 @@ class NetWorkReduceToScalar(Cell):
        return ret


class TensorAssignWithSliceError1(Cell):
    def __init__(self):
        super(TensorAssignWithSliceError1, self).__init__()

    def construct(self, a, b):
        a[1:3:-1, ::] = b
        return a


class TensorAssignWithSliceError2(Cell):
    def __init__(self):
        super(TensorAssignWithSliceError2, self).__init__()

    def construct(self, a, b):
        a[1:3:-1] = b
        return a


class TensorAssignWithSlice2(Cell):
    def __init__(self):
        super(TensorAssignWithSlice2, self).__init__()

    def construct(self, a, b):
        a[1:5] = b
        a[3:4] = 5
        a[-1:1:-1] = b
        a[-1:3:-1] = 5
        a[::] = b
        a[::] = 9
        return a


class TensorAssignWithSlice(Cell):
    def __init__(self):
        super(TensorAssignWithSlice, self).__init__()
        self.c = 2

    def construct(self, a, b):
        a[1:3, ::] = b
        a[2:3:, 3:] = b
        a[::] = b
        a[::] = self.c
        a[::, ::] = b
        a[::, ::] = self.c
        a[2:3:, 0:, 4:1:-1] = b
        a[2:3:, 0:, 4:1:-1] = self.c
        z = a
        return z


def test_tensor_assign_with_slice():
    context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
    net = TensorAssignWithSlice()
    net2 = TensorAssignWithSlice2()
    net_e1 = TensorAssignWithSliceError1()
    net_e2 = TensorAssignWithSliceError2()
    a = np.arange(60).reshape(3, 4, 5)
    b = Tensor([1])
    Ta = Tensor(a)
    Tb = Tensor([1, 3])
    Tc = Tensor([])
    t = Tensor([1, 2, 3, 4, 5, 6, 7, 8])
    net(Ta, b)
    net2(t, b)
    # Error for A[Slice] = Number
    # 1. A[Slice] = Number, the slice selects no elements (1:3 with step -1)
    with pytest.raises(ValueError):
        net_e2(t, 2)
    # Error for A[Slice] = U, U is a Tensor
    # 1. A[Slice] = U, U.size does not match
    with pytest.raises(ValueError):
        net2(t, Tb)
    # 2. A[Slice] = U, U is empty
    with pytest.raises(ValueError):
        net2(t, Tc)
    # 3. A[Slice] = U, U.size does not match
    with pytest.raises(ValueError):
        net2(t, Tb)
    # Error for A[Tuple(Slice...)] = Tensor
    # 1. A[Tuple(Slice...)] = U, U is empty
    with pytest.raises(ValueError):
        net(Ta, Tc)
    # 2. A[Tuple(Slice...)] = U, U.size does not match
    with pytest.raises(ValueError):
        net(Ta, Tb)
    # 3. A[Tuple(Slice...)] = U, the slice selects no elements
    with pytest.raises(ValueError):
        net_e1(Ta, b)
    # Error for A[Tuple(Slice...)] = Number
    # 1. A[Tuple(Slice...)] = Number, the slice selects no elements
    with pytest.raises(ValueError):
        net_e1(Ta, 2)


class TensorAssignWithBoolTensorIndex(Cell):
    def __init__(self):
        super(TensorAssignWithBoolTensorIndex, self).__init__()
        self.t = Tensor(np.arange(6).reshape([2, 3]), dtype=mstype.float64)
        self.t = Tensor(np.arange(60).reshape([3, 4, 5]), dtype=mstype.float64)

    def construct(self, a, b, c, u_tensor, u_scalar):
        a[c] = u_scalar
@@ -119,6 +210,7 @@ class TensorAssignWithBoolTensorIndex2(Cell):
    def __init__(self):
        super(TensorAssignWithBoolTensorIndex2, self).__init__()
        self.t = Tensor(np.arange(6).reshape([2, 3]), dtype=mstype.float64)
        self.t = Tensor(np.arange(60).reshape([3, 4, 5]), dtype=mstype.float64)

    def construct(self, a, u_tensor, _scalar):
        a[a > 8] = u_tensor
@@ -139,7 +231,7 @@ class TensorAssignWithBoolTensorIndex2Error(Cell):
        return a


a = np.random.uniform(1, 10, [2, 3])
a = np.random.uniform(1, 10, [3, 4, 5])
b = a > 5
c = a < 3
Ta = Tensor(a)
@@ -148,13 +240,13 @@ Tc = Tensor(c)
Td = Tensor([True, True])
u_tensor = Tensor([1])
u_tensor_error = Tensor([1, 2])
t_1d = Tensor([1, 2, 3, 4, 5, 6, 7, 8])
u_scalar = 5


def test_tensor_assign_bool_index():
    net1 = TensorAssignWithBoolTensorIndex()
    net2 = TensorAssignWithBoolTensorIndex2()
    net1(Ta, Tb, Tc, u_tensor, u_scalar)
    with pytest.raises(ValueError):
        net1(Ta, Td, Tc, u_tensor, u_scalar)
@@ -180,8 +272,15 @@ def test_tensor_assign_bool_index():
    with pytest.raises(AttributeError):
        net4(Ta, u_scalar)


test_cases = [
    ('TensorAssignWithSlice', {
        'block': TensorAssignWithSlice(),
        'desc_inputs': [Ta, u_tensor],
    }),
    ('TensorAssignWithSlice2', {
        'block': TensorAssignWithSlice2(),
        'desc_inputs': [t_1d, u_tensor],
    }),
    ('TensorAssignWithBoolTensorIndex', {
        'block': TensorAssignWithBoolTensorIndex(),
        'desc_inputs': [Ta, Tb, Tc, u_tensor, u_scalar],