Merge pull request !504 from candanzg/tensor_assign_with_slicetags/v0.2.0-alpha
| @@ -18,7 +18,7 @@ Interfaces for parser module in c++. | |||||
| from .parser import (Parser, create_obj_instance, generate_scope, | from .parser import (Parser, create_obj_instance, generate_scope, | ||||
| get_bprop_method_of_class, get_class_instance_type, | get_bprop_method_of_class, get_class_instance_type, | ||||
| get_class_member_namespace_symbol, | |||||
| get_class_member_namespace_symbol, create_slice_obj, | |||||
| get_dataclass_attributes, get_dataclass_methods, | get_dataclass_attributes, get_dataclass_methods, | ||||
| get_module_namespace, get_obj_type, get_object_key, | get_module_namespace, get_obj_type, get_object_key, | ||||
| get_parse_method_of_class, get_scope_name, | get_parse_method_of_class, get_scope_name, | ||||
| @@ -29,4 +29,4 @@ __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', | |||||
| 'get_object_key', 'get_class_instance_type', 'is_class_member', 'get_obj_type', | 'get_object_key', 'get_class_instance_type', 'is_class_member', 'get_obj_type', | ||||
| 'create_obj_instance', 'get_module_namespace', 'get_class_member_namespace_symbol', | 'create_obj_instance', 'get_module_namespace', 'get_class_member_namespace_symbol', | ||||
| 'Parser', 'get_dataclass_attributes', 'get_dataclass_methods', 'dump_obj', 'load_obj', | 'Parser', 'get_dataclass_attributes', 'get_dataclass_methods', 'dump_obj', 'load_obj', | ||||
| 'get_dataclass_methods', 'get_scope_name'] | |||||
| 'get_dataclass_methods', 'get_scope_name', 'create_slice_obj'] | |||||
| @@ -29,6 +29,7 @@ from mindspore.common.dtype import pytype_to_dtype | |||||
| from mindspore.common.api import _MindSporeFunction | from mindspore.common.api import _MindSporeFunction | ||||
| from .namespace import CellNamespace, ClosureNamespace, ClassMemberNamespace | from .namespace import CellNamespace, ClosureNamespace, ClassMemberNamespace | ||||
| from .resources import parse_object_map, convert_object_map, trope_ns, SYMBOL_UNDEFINE, NO_IMPLEMENT | from .resources import parse_object_map, convert_object_map, trope_ns, SYMBOL_UNDEFINE, NO_IMPLEMENT | ||||
| from ..utils import Slice | |||||
| # define return value | # define return value | ||||
| RET_SUCCESS = 0 | RET_SUCCESS = 0 | ||||
| @@ -69,6 +70,10 @@ parse_expr_statement_white_list = ( | |||||
| "append", | "append", | ||||
| ) | ) | ||||
| def create_slice_obj(start, end, step): | |||||
| """Create Slice object""" | |||||
| return Slice(start, end, step) | |||||
| def parse_cb(func, parse_method=None): | def parse_cb(func, parse_method=None): | ||||
| """Implements the function of parse.""" | """Implements the function of parse.""" | ||||
| @@ -19,6 +19,7 @@ import logging | |||||
| import os | import os | ||||
| import inspect | import inspect | ||||
| from functools import wraps | from functools import wraps | ||||
| from dataclasses import dataclass | |||||
| def cal_sha256(file_path): | def cal_sha256(file_path): | ||||
| @@ -99,3 +100,13 @@ def cell_attr_register(fn=None, attrs=None): | |||||
| if fn is not None: | if fn is not None: | ||||
| return wrap_cell(fn) | return wrap_cell(fn) | ||||
| return wrap_cell | return wrap_cell | ||||
| @dataclass | |||||
| class Slice: | |||||
| """ | |||||
| Slice class | |||||
| """ | |||||
| start: int | |||||
| end: int | |||||
| step: int | |||||
| @@ -123,6 +123,9 @@ class ValueSlice : public Value { | |||||
| abstract::AbstractBasePtr ToAbstract() override; | abstract::AbstractBasePtr ToAbstract() override; | ||||
| std::string DumpText() const override { return ToString(); } | std::string DumpText() const override { return ToString(); } | ||||
| ValuePtr start() const { return start_; } | |||||
| ValuePtr stop() const { return stop_; } | |||||
| ValuePtr step() const { return step_; } | |||||
| private: | private: | ||||
| ValuePtr start_; | ValuePtr start_; | ||||
| @@ -79,6 +79,8 @@ const char PYTHON_PARSE_EXPAND_EXPR_STATEMENT[] = "expand_expr_statement"; | |||||
| const char PYTHON_PARSE_GENERATE_SCOPE[] = "generate_scope"; | const char PYTHON_PARSE_GENERATE_SCOPE[] = "generate_scope"; | ||||
| const char PYTHON_PARSE_GET_SCOPE_NAME[] = "get_scope_name"; | const char PYTHON_PARSE_GET_SCOPE_NAME[] = "get_scope_name"; | ||||
| const char PYTHON_PARSE_CLASS_SLICE[] = "create_slice_obj"; | |||||
| // define the common name | // define the common name | ||||
| const char NAMED_PRIMITIVE_ITER[] = "iter"; | const char NAMED_PRIMITIVE_ITER[] = "iter"; | ||||
| const char NAMED_PRIMITIVE_NEXT[] = "next"; | const char NAMED_PRIMITIVE_NEXT[] = "next"; | ||||
| @@ -289,6 +289,13 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) { | |||||
| dic["shape"] = shape; | dic["shape"] = shape; | ||||
| dic["dtype"] = abs_base->BuildType(); | dic["dtype"] = abs_base->BuildType(); | ||||
| dic["value"] = BuildValue(abs_base->BuildValue()); | dic["value"] = BuildValue(abs_base->BuildValue()); | ||||
| } else if (abs_base->isa<AbstractSlice>()) { | |||||
| auto arg_slice = dyn_cast<AbstractSlice>(abs_base); | |||||
| std::vector<int> shape; | |||||
| dic["shape"] = shape; | |||||
| dic["dtype"] = arg_slice->BuildType(); | |||||
| dic["value"] = BuildValue(arg_slice->BuildValue()); | |||||
| } else if (abs_base->isa<AbstractTuple>()) { | } else if (abs_base->isa<AbstractTuple>()) { | ||||
| auto arg_tuple = dyn_cast<AbstractTuple>(abs_base); | auto arg_tuple = dyn_cast<AbstractTuple>(abs_base); | ||||
| size_t len = arg_tuple->size(); | size_t len = arg_tuple->size(); | ||||
| @@ -28,6 +28,7 @@ | |||||
| #include "ir/meta_tensor.h" | #include "ir/meta_tensor.h" | ||||
| #include "pipeline/parse/parse.h" | #include "pipeline/parse/parse.h" | ||||
| #include "pipeline/parse/parse_base.h" | |||||
| #include "ir/value.h" | #include "ir/value.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -97,6 +98,13 @@ py::object ValuePtrToPyData(const ValuePtr &value) { | |||||
| i++; | i++; | ||||
| } | } | ||||
| ret = rets; | ret = rets; | ||||
| } else if (value->isa<ValueSlice>()) { | |||||
| auto slice = value->cast<ValueSlicePtr>(); | |||||
| auto start = ValuePtrToPyData(slice->start()); | |||||
| auto end = ValuePtrToPyData(slice->stop()); | |||||
| auto step = ValuePtrToPyData(slice->step()); | |||||
| ret = parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_PARSE_CLASS_SLICE, start, end, | |||||
| step); | |||||
| } else if (value->isa<Type>()) { | } else if (value->isa<Type>()) { | ||||
| py::tuple v(1); | py::tuple v(1); | ||||
| v[0] = value->cast<TypePtr>(); | v[0] = value->cast<TypePtr>(); | ||||
| @@ -15,7 +15,43 @@ | |||||
| """constexpr util""" | """constexpr util""" | ||||
| import numpy as np | |||||
| from ...primitive import constexpr | from ...primitive import constexpr | ||||
| from ....common.tensor import Tensor | |||||
| from ....common import dtype as mstype | |||||
| from ...._extends.utils import Slice | |||||
| @constexpr | |||||
| def check_equal(param1, param2, msg="{},{}"): | |||||
| if param1 != param2: | |||||
| raise ValueError(msg.format(param1, param2)) | |||||
| return param1 | |||||
| @constexpr | |||||
| def check_tensor_setitem_index(index, element_type=None): | |||||
| """Check tuple index type of tensor assignment.""" | |||||
| if index is None: | |||||
| raise ValueError("Tensor's index cannot be None.") | |||||
| # eg. Tensor[Slice] = u | |||||
| if isinstance(index, Slice): | |||||
| return True | |||||
| # eg. Tensor[Tuple] = u | |||||
| if isinstance(index, tuple): | |||||
| if not index: | |||||
| raise ValueError("Tensor's index cannot be empty.") | |||||
| # eg. Tensor[Tuple(Slice...)] = u | |||||
| if not isinstance(index[0], Slice): | |||||
| raise ValueError("Index of type '{}' is not supported yet.".format(type(index[0]))) | |||||
| return True | |||||
| # eg. Tensor[Tensor[dtype=bool]] = u | |||||
| if index == mstype.tensor: | |||||
| if element_type is None or element_type != mstype.bool_: | |||||
| raise ValueError( | |||||
| "The index of tensor should be a bool type tensor. \ | |||||
| {} type is not supported yet.".format(element_type)) | |||||
| return True | |||||
| raise ValueError("Index of type '{}' is not supported yet.".format(type(index))) | |||||
| @constexpr | @constexpr | ||||
| @@ -43,3 +79,84 @@ def error_msg(msg="", format_values=""): | |||||
| """ | """ | ||||
| raise ValueError(msg.format(*format_values)) | raise ValueError(msg.format(*format_values)) | ||||
| def slice_expand(input_slices, shape): | |||||
| """ | |||||
| Convert slice to indices. | |||||
| Inputs: | |||||
| slices (List or Tuple(List, ...)): Slice tuple or slice. | |||||
| shape (Tuple): The shape of a sensor is an integer element tuple. | |||||
| Outputs: | |||||
| (List, List, List), This is expressed as (begins, ends, strides). | |||||
| """ | |||||
| begin = [] | |||||
| end = [] | |||||
| strides = [] | |||||
| index = 0 | |||||
| slices = None | |||||
| # Slice or Tuple(Slice...) | |||||
| if isinstance(input_slices, Slice): | |||||
| slices = (input_slices,) | |||||
| elif isinstance(input_slices, (tuple, list)) and input_slices and isinstance(input_slices[0], Slice): | |||||
| slices = input_slices | |||||
| else: | |||||
| raise ValueError("Tensor's index type is not supported yet.") | |||||
| for s in slices: | |||||
| start = 0 if (s.start is None) else s.start | |||||
| stop = shape[index] if (s.end is None) else s.end | |||||
| step = 1 if (s.step is None) else s.step | |||||
| begin.append(start) | |||||
| end.append(stop) | |||||
| strides.append(step) | |||||
| index += 1 | |||||
| while index < len(shape): | |||||
| begin.append(0) | |||||
| end.append(shape[index]) | |||||
| strides.append(1) | |||||
| index += 1 | |||||
| return begin, end, strides | |||||
| @constexpr | |||||
| def slice2indices(input_slices, shape): | |||||
| """ | |||||
| Convert slice to indices. | |||||
| Inputs: | |||||
| slices (List or Tuple(List, ...)): Slice tuple or slice. | |||||
| shape (Tuple): The shape of a sensor is an integer element tuple. | |||||
| Outputs: | |||||
| Tensor, the shape is (n, 1). | |||||
| """ | |||||
| begin, end, strides = slice_expand(input_slices, shape) | |||||
| np_r = [] | |||||
| for i, element in enumerate(shape): | |||||
| s = begin[i] if (begin[i] >= 0) else (element + begin[i]) | |||||
| e = end[i] if (end[i] >= 0) else (element + end[i]) | |||||
| np_r.append(np.r_[s:e:strides[i]]) | |||||
| # Reference: np.ravel_multi_index((np.ix_(np.r_[1:3:1], np.r_[0:4:1], np.r_[4:0:-1])), a.shape) | |||||
| np_ix = np.ix_(*np_r) | |||||
| ravel = np.ravel_multi_index(np_ix, shape) | |||||
| ravel = Tensor(ravel.reshape(-1, 1), dtype=mstype.int32) | |||||
| return ravel | |||||
| @constexpr | |||||
| def check_indices(indices_size, index): | |||||
| if indices_size < 1: | |||||
| raise ValueError("The tensor's index is unreasonable. index:{}".format(index)) | |||||
| return indices_size | |||||
| @constexpr | |||||
| def check_indices_value_size(indices_size, value_size): | |||||
| if value_size < 1: | |||||
| raise ValueError("The value assigned to tensor cannot be empty.") | |||||
| if value_size > 1: | |||||
| if value_size != indices_size: | |||||
| raise ValueError( | |||||
| "The value given to tensor does not match the index size. \ | |||||
| value size:{}, indics size:{}".format(value_size, indices_size)) | |||||
| return value_size | |||||
| @@ -138,25 +138,23 @@ def _tensor_setitem_by_tensor_v1(data, index, value_tensor): | |||||
| Outputs: | Outputs: | ||||
| Tensor, element type and shape is same as data. | Tensor, element type and shape is same as data. | ||||
| """ | """ | ||||
| result = None | |||||
| index_dtype = F.dtype(index) | index_dtype = F.dtype(index) | ||||
| index_shape = F.shape(index) | index_shape = F.shape(index) | ||||
| is_bool = mult_util.is_same_type(index_dtype, mstype.bool_) | |||||
| if not is_bool: | |||||
| return mult_util.error_msg( | |||||
| "The tensor index should be a bool type tensor. {} type tensor is not supported yet.", (index_dtype,)) | |||||
| data_shape = F.shape(data) | |||||
| if index_shape != data_shape: | |||||
| return mult_util.error_msg( | |||||
| "The tensor(shape={}) and tensor index(shape={}) should be the same shape.", (data_shape, index_shape)) | |||||
| size = F.size(value_tensor) | |||||
| if size != 1: | |||||
| return mult_util.error_msg( | |||||
| "When assign value is a tensor, its size should be 1, but current size is {}.", (size,)) | |||||
| dtype = F.dtype(data) | |||||
| u_cast = F.cast(value_tensor, dtype) | |||||
| one_data = F.ones_like(data) | |||||
| u = F.tensor_mul(one_data, u_cast) | |||||
| return F.select(index, u, data) | |||||
| check_result = mult_util.check_tensor_setitem_index(mstype.tensor, index_dtype) | |||||
| if check_result: | |||||
| data_shape = F.shape(data) | |||||
| data_shape = mult_util.check_equal(data_shape, index_shape, | |||||
| "The tensor(shape={}) and tensor index(shape={}) should be the same shape.") | |||||
| size = F.size(value_tensor) | |||||
| size = mult_util.check_equal(1, size, | |||||
| "When assign value is a tensor, its size should be {}, but current size is {}.") | |||||
| dtype = F.dtype(data) | |||||
| u_cast = F.cast(value_tensor, dtype) | |||||
| one_data = F.ones_like(data) | |||||
| u = F.tensor_mul(one_data, u_cast) | |||||
| result = F.select(index, u, data) | |||||
| return result | |||||
| @setitem.register("Tensor", "Tensor", "Number") | @setitem.register("Tensor", "Tensor", "Number") | ||||
| @@ -179,16 +177,162 @@ def _tensor_setitem_by_tensor_v2(data, index, value): | |||||
| Outputs: | Outputs: | ||||
| Tensor, element type and shape is same as data. | Tensor, element type and shape is same as data. | ||||
| """ | """ | ||||
| result = None | |||||
| index_dtype = F.dtype(index) | index_dtype = F.dtype(index) | ||||
| index_shape = F.shape(index) | index_shape = F.shape(index) | ||||
| is_bool = mult_util.is_same_type(index_dtype, mstype.bool_) | |||||
| if not is_bool: | |||||
| return mult_util.error_msg( | |||||
| "The tensor index should be a bool type tensor. {} type tensor is not supported yet.", (index_dtype,)) | |||||
| shape = F.shape(data) | |||||
| if index_shape != shape: | |||||
| return mult_util.error_msg( | |||||
| "The tensor(shape={}) and tensor index(shape={}) should be the same shape.", (shape, index_shape)) | |||||
| dtype = F.dtype(data) | |||||
| u = F.fill(dtype, shape, value) | |||||
| return F.select(index, u, data) | |||||
| check_result = mult_util.check_tensor_setitem_index(mstype.tensor, index_dtype) | |||||
| if check_result: | |||||
| shape = F.shape(data) | |||||
| shape = mult_util.check_equal( | |||||
| shape, index_shape, "The tensor(shape={}) and tensor index(shape={}) should be the same shape.") | |||||
| dtype = F.dtype(data) | |||||
| u = F.fill(dtype, shape, value) | |||||
| result = F.select(index, u, data) | |||||
| return result | |||||
| @setitem.register("Tensor", "Slice", "Tensor") | |||||
| def _tensor_setitem_with_slice_v3(data, input_slice, value): | |||||
| """ | |||||
| Tensor assignment. | |||||
| Note: | |||||
| Syntax support: A[Slice] = U | |||||
| Restraint condition: A is a Tensor | |||||
| Slice like "1:3" | |||||
| U is a Tensor(size=1) or Tensor(size>1) | |||||
| Inputs: | |||||
| data (Tensor): Assigned tensor. | |||||
| input_slice (Slice): Slice expression. | |||||
| value (Number): Assignment value. | |||||
| Outputs: | |||||
| Tensor, element type and shape is same as data. | |||||
| """ | |||||
| return _tensor_assgin_tensor(data, input_slice, value) | |||||
| @setitem.register("Tensor", "Tuple", "Tensor") | |||||
| def _tensor_setitem_with_slice_v4(data, input_slice, value): | |||||
| """ | |||||
| Tensor assignment. | |||||
| Note: | |||||
| Syntax support: A[Slice] = U | |||||
| Restraint condition: A is a Tensor | |||||
| Slice like "1:3, ::, :4:-1" | |||||
| U is a Tensor(size=1) or Tensor(size>1) | |||||
| Inputs: | |||||
| data (Tensor): Assigned tensor. | |||||
| input_slice (Tuple(Slice)): Slice expression. | |||||
| value (Number): Assignment value. | |||||
| Outputs: | |||||
| Tensor, element type and shape is same as data. | |||||
| """ | |||||
| return _tensor_assgin_tensor(data, input_slice, value) | |||||
| def _tensor_assgin_tensor(data, input_slice, value): | |||||
| """Given a tensor value assign to tensor by slice""" | |||||
| # 1. condition | |||||
| result = None | |||||
| check_result = mult_util.check_tensor_setitem_index(input_slice) | |||||
| if check_result: | |||||
| data_shape = F.shape(data) | |||||
| data_size = F.size(data) | |||||
| data_dtype = F.dtype(data) | |||||
| indices = mult_util.slice2indices(input_slice, data_shape) | |||||
| indices_size = F.size(indices) | |||||
| indices_size = mult_util.check_indices(indices_size, input_slice) | |||||
| update = F.fill(data_dtype, (indices_size,), 1) | |||||
| condition_1d = F.scatter_nd(indices, update, (data_size,)) | |||||
| condition_1d = F.cast(condition_1d, mstype.bool_) | |||||
| condition = F.reshape(condition_1d, data_shape) | |||||
| # 2. u | |||||
| value_fill = None | |||||
| value_size = F.size(value) | |||||
| value_size = mult_util.check_indices_value_size(indices_size, value_size) | |||||
| if value_size == 1: | |||||
| value_fill = F.fill(data_dtype, (indices_size,), 1) | |||||
| value = F.cast(value, data_dtype) | |||||
| value_fill = F.tensor_mul(value_fill, value) | |||||
| elif value_size > 1: | |||||
| value_fill = F.reshape(value, (indices_size,)) | |||||
| value_1d = F.scatter_nd(indices, value_fill, (data_size,)) | |||||
| u = F.reshape(value_1d, data_shape) | |||||
| # A[slice]= u -> A[B]=U -> select(B, U, A) | |||||
| result = F.select(condition, u, data) | |||||
| return result | |||||
| @setitem.register("Tensor", "Slice", "Number") | |||||
| def _tensor_setitem_with_slice_v1(data, input_slice, value): | |||||
| """ | |||||
| Tensor assignment. | |||||
| Note: | |||||
| Syntax support: A[Slice] = u | |||||
| Restraint condition: A is a Tensor. | |||||
| Slice like "1:3" | |||||
| u is a scalar | |||||
| Inputs: | |||||
| data (Tensor): Assigned tensor. | |||||
| input_slice (Slice): slice expression. | |||||
| value (Number): Assignment value. | |||||
| Outputs: | |||||
| Tensor, element type and shape is same as data. | |||||
| """ | |||||
| return _tensor_assgin_number(data, input_slice, value) | |||||
| @setitem.register("Tensor", "Tuple", "Number") | |||||
| def _tensor_setitem_with_slice_v2(data, input_slice, value): | |||||
| """ | |||||
| Tensor assignment. | |||||
| Note: | |||||
| Syntax support: A[Slice] = u | |||||
| Restraint condition: A is a Tensor. | |||||
| Slice like "1:3, ::, :4:-1" | |||||
| u is a scalar | |||||
| Inputs: | |||||
| data (Tensor): Assigned tensor. | |||||
| input_slice (Tuple(Slice)): slice expression. | |||||
| value (Number): Assignment value. | |||||
| Outputs: | |||||
| Tensor, element type and shape is same as data. | |||||
| """ | |||||
| return _tensor_assgin_number(data, input_slice, value) | |||||
| def _tensor_assgin_number(data, input_slice, value): | |||||
| """Given a scalar assign to tensor by slice""" | |||||
| # 1. condition | |||||
| check_result = mult_util.check_tensor_setitem_index(input_slice) | |||||
| result = None | |||||
| if check_result: | |||||
| data_shape = F.shape(data) | |||||
| data_size = F.size(data) | |||||
| data_dtype = F.dtype(data) | |||||
| indices = mult_util.slice2indices(input_slice, data_shape) | |||||
| indices_size = F.size(indices) | |||||
| indices_size = mult_util.check_indices(indices_size, input_slice) | |||||
| update = F.fill(data_dtype, (indices_size,), 1) | |||||
| condition_1d = F.scatter_nd(indices, update, (data_size,)) | |||||
| condition_1d = F.cast(condition_1d, mstype.bool_) | |||||
| condition = F.reshape(condition_1d, data_shape) | |||||
| # 2. u | |||||
| value_fill = F.fill(data_dtype, (indices_size,), value) | |||||
| value_1d = F.scatter_nd(indices, value_fill, (data_size,)) | |||||
| u = F.reshape(value_1d, data_shape) | |||||
| # A[slice]= u -> A[B]=U -> select(B, U, A) | |||||
| result = F.select(condition, u, data) | |||||
| return result | |||||
| @@ -68,6 +68,7 @@ tuple_to_array = P.TupleToArray() | |||||
| scalar_cast = P.ScalarCast() | scalar_cast = P.ScalarCast() | ||||
| print_ = P.Print() | print_ = P.Print() | ||||
| expand_dims = P.ExpandDims() | expand_dims = P.ExpandDims() | ||||
| scatter_nd = P.ScatterNd() | |||||
| tuple_setitem = Primitive('tuple_setitem') | tuple_setitem = Primitive('tuple_setitem') | ||||
| tuple_getitem = Primitive('tuple_getitem') | tuple_getitem = Primitive('tuple_getitem') | ||||
| @@ -94,10 +94,101 @@ class NetWorkReduceToScalar(Cell): | |||||
| return ret | return ret | ||||
| class TensorAssignWithSliceError1(Cell): | |||||
| def __init__(self): | |||||
| super(TensorAssignWithSliceError1, self).__init__() | |||||
| def construct(self, a, b): | |||||
| a[1:3:-1,::] = b | |||||
| return a | |||||
| class TensorAssignWithSliceError2(Cell): | |||||
| def __init__(self): | |||||
| super(TensorAssignWithSliceError2, self).__init__() | |||||
| def construct(self, a, b): | |||||
| a[1:3:-1] = b | |||||
| return a | |||||
| class TensorAssignWithSlice2(Cell): | |||||
| def __init__(self): | |||||
| super(TensorAssignWithSlice2, self).__init__() | |||||
| def construct(self, a, b): | |||||
| a[1:5] = b | |||||
| a[3:4] = 5 | |||||
| a[-1:1:-1] = b | |||||
| a[-1:3:-1] = 5 | |||||
| a[::] = b | |||||
| a[::] = 9 | |||||
| return a | |||||
| class TensorAssignWithSlice(Cell): | |||||
| def __init__(self): | |||||
| super(TensorAssignWithSlice, self).__init__() | |||||
| self.c = 2 | |||||
| def construct(self, a, b): | |||||
| a[1:3,::] = b | |||||
| a[2:3:,3:] = b | |||||
| a[::] = b | |||||
| a[::] = self.c | |||||
| a[::,::] = b | |||||
| a[::,::] = self.c | |||||
| a[2:3:,0:, 4:1:-1] = b | |||||
| a[2:3:,0:, 4:1:-1] = self.c | |||||
| z = a | |||||
| return z | |||||
| def test_tensor_assign_with_slice(): | |||||
| context.set_context(mode=context.GRAPH_MODE, save_graphs=True) | |||||
| net = TensorAssignWithSlice() | |||||
| net2= TensorAssignWithSlice2() | |||||
| net_e1 = TensorAssignWithSliceError1() | |||||
| net_e2 = TensorAssignWithSliceError2() | |||||
| a = np.arange(60).reshape(3,4,5) | |||||
| b = Tensor([1]) | |||||
| Ta = Tensor(a) | |||||
| Tb= Tensor([1,3]) | |||||
| Tc= Tensor([]) | |||||
| t = Tensor([1, 2, 3, 4, 5, 6, 7, 8]) | |||||
| net(Ta, b) | |||||
| net2(t, b) | |||||
| # Error for A[Slice] = Number | |||||
| # 1. A[Slice] = Number, Slice error | |||||
| with pytest.raises(ValueError): | |||||
| net_e2(t, 2) | |||||
| # Error for A[Slice] = U, U is a Tensor | |||||
| # 1. A[Slice] = U, u.size is error | |||||
| with pytest.raises(ValueError): | |||||
| net2(t, Tb) | |||||
| # 2. A[Slice] = U, U is empty | |||||
| with pytest.raises(ValueError): | |||||
| net2(t, Tc) | |||||
| # 3. A[Slice] = U, U.size error | |||||
| with pytest.raises(ValueError): | |||||
| net2(t, Tb) | |||||
| # Error for A[Tuple(Slice...)] = Tensor | |||||
| # 1. A[Tuple(Slice...)] = U, U is empty | |||||
| with pytest.raises(ValueError): | |||||
| net(Ta, Tc) | |||||
| # 2. A[Tuple(Slice...)] = U, U.size error | |||||
| with pytest.raises(ValueError): | |||||
| net(Ta, Tb) | |||||
| # 3. A[Tuple(Slice...)] = U, Slice error | |||||
| with pytest.raises(ValueError): | |||||
| net_e1(Ta, b) | |||||
| # Error for A[Tuple(Slice...)] = Number | |||||
| # 1. A[Tuple(Slice...)] = Number, Slice error | |||||
| with pytest.raises(ValueError): | |||||
| net_e1(Ta, 2) | |||||
| class TensorAssignWithBoolTensorIndex(Cell): | class TensorAssignWithBoolTensorIndex(Cell): | ||||
| def __init__(self): | def __init__(self): | ||||
| super(TensorAssignWithBoolTensorIndex, self).__init__() | super(TensorAssignWithBoolTensorIndex, self).__init__() | ||||
| self.t = Tensor(np.arange(6).reshape([2, 3]), dtype=mstype.float64) | |||||
| self.t = Tensor(np.arange(60).reshape([3,4,5]), dtype = mstype.float64) | |||||
| def construct(self, a, b, c, u_tensor, _scalar): | def construct(self, a, b, c, u_tensor, _scalar): | ||||
| a[c] = u_scalar | a[c] = u_scalar | ||||
| @@ -119,6 +210,7 @@ class TensorAssignWithBoolTensorIndex2(Cell): | |||||
| def __init__(self): | def __init__(self): | ||||
| super(TensorAssignWithBoolTensorIndex2, self).__init__() | super(TensorAssignWithBoolTensorIndex2, self).__init__() | ||||
| self.t = Tensor(np.arange(6).reshape([2, 3]), dtype=mstype.float64) | self.t = Tensor(np.arange(6).reshape([2, 3]), dtype=mstype.float64) | ||||
| self.t = Tensor(np.arange(60).reshape([3,4,5]), dtype = mstype.float64) | |||||
| def construct(self, a, u_tensor, _scalar): | def construct(self, a, u_tensor, _scalar): | ||||
| a[a > 8] = u_tensor | a[a > 8] = u_tensor | ||||
| @@ -139,7 +231,7 @@ class TensorAssignWithBoolTensorIndex2Error(Cell): | |||||
| return a | return a | ||||
| a = np.random.uniform(1, 10, [2, 3]) | |||||
| a = np.random.uniform(1,10,[3,4,5]) | |||||
| b = a > 5 | b = a > 5 | ||||
| c = a < 3 | c = a < 3 | ||||
| Ta = Tensor(a) | Ta = Tensor(a) | ||||
| @@ -148,13 +240,13 @@ Tc = Tensor(c) | |||||
| Td = Tensor([True, True]) | Td = Tensor([True, True]) | ||||
| u_tensor = Tensor([1]) | u_tensor = Tensor([1]) | ||||
| u_tensor_error = Tensor([1, 2]) | u_tensor_error = Tensor([1, 2]) | ||||
| t_1d = Tensor([1, 2, 3, 4, 5, 6, 7, 8]) | |||||
| u_scalar = 5 | u_scalar = 5 | ||||
| def test_tensor_assign_bool_index(): | def test_tensor_assign_bool_index(): | ||||
| net1 = TensorAssignWithBoolTensorIndex() | net1 = TensorAssignWithBoolTensorIndex() | ||||
| net2 = TensorAssignWithBoolTensorIndex2() | net2 = TensorAssignWithBoolTensorIndex2() | ||||
| net1(Ta, Tb, Tc, u_tensor, u_scalar) | |||||
| net1(Ta, Tb, Tc, u_tensor, u_scalar) | net1(Ta, Tb, Tc, u_tensor, u_scalar) | ||||
| with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
| net1(Ta, Td, Tc, u_tensor, u_scalar) | net1(Ta, Td, Tc, u_tensor, u_scalar) | ||||
| @@ -180,8 +272,15 @@ def test_tensor_assign_bool_index(): | |||||
| with pytest.raises(AttributeError): | with pytest.raises(AttributeError): | ||||
| net4(Ta, u_scalar) | net4(Ta, u_scalar) | ||||
| test_cases = [ | test_cases = [ | ||||
| ('TensorAssignWithSlice', { | |||||
| 'block': TensorAssignWithSlice(), | |||||
| 'desc_inputs': [Ta, u_tensor], | |||||
| }), | |||||
| ('TensorAssignWithSlice2', { | |||||
| 'block': TensorAssignWithSlice2(), | |||||
| 'desc_inputs': [t_1d, u_tensor], | |||||
| }), | |||||
| ('TensorAssignWithBoolTensorIndex', { | ('TensorAssignWithBoolTensorIndex', { | ||||
| 'block': TensorAssignWithBoolTensorIndex(), | 'block': TensorAssignWithBoolTensorIndex(), | ||||
| 'desc_inputs': [Ta, Tb, Tc, u_tensor, u_scalar], | 'desc_inputs': [Ta, Tb, Tc, u_tensor, u_scalar], | ||||