Merge pull request !1133 from zhangbuxue/support_tensor_get_value_by_tensor_indextags/v0.3.0-alpha
| @@ -1172,6 +1172,12 @@ int GenerateStridedSliceParametersFromNumber(const AbstractScalarPtr &scalar, co | |||
| return 1; | |||
| } | |||
| FuncGraphPtr ExpandADim(const FuncGraphPtr &ret_graph, const AnfNodePtr &tensor_node) { | |||
| auto PrimExpandDims = GetPythonOps("expand_dims", "mindspore.ops.functional"); | |||
| ret_graph->set_output(NewCNode({NewValueNode(PrimExpandDims), tensor_node, NewValueNode(0)}, ret_graph)); | |||
| return ret_graph; | |||
| } | |||
| FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { | |||
| // slice a tensor | |||
| // args: tensor, slice or slice tuple | |||
| @@ -1229,12 +1235,6 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec | |||
| return ret_graph; | |||
| } | |||
| FuncGraphPtr TensorSlice::ExpandADim(const FuncGraphPtr &ret_graph, const AnfNodePtr &tensor_node) const { | |||
| auto PrimExpandDims = GetPythonOps("expand_dims", "mindspore.ops.functional"); | |||
| ret_graph->set_output(NewCNode({NewValueNode(PrimExpandDims), tensor_node, NewValueNode(0)}, ret_graph)); | |||
| return ret_graph; | |||
| } | |||
| FuncGraphPtr TupleGetItemTensor::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { | |||
| // select indexed item | |||
| // args: tuple of items, index | |||
| @@ -206,8 +206,6 @@ class TensorSlice : public MetaFuncGraph { | |||
| MS_DECLARE_PARENT(TensorSlice, MetaFuncGraph) | |||
| FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; | |||
| friend bool operator==(const TensorSlice &lhs, const TensorSlice &rhs) { return lhs.name_ == rhs.name_; } | |||
| FuncGraphPtr ExpandADim(const FuncGraphPtr &ret_graph, const AnfNodePtr &tensor_node) const; | |||
| }; | |||
| using TensorSlicePtr = std::shared_ptr<TensorSlice>; | |||
| @@ -101,6 +101,7 @@ const char kNameReLU6[] = "ReLU6"; | |||
| const char kNameReLU6Grad[] = "ReLU6Grad"; | |||
| const char kNameElu[] = "Elu"; | |||
| const char kNameEluGrad[] = "EluGrad"; | |||
| const char kNameScatterUpdate[] = "ScatterUpdate"; | |||
| const char kNameScatterNdUpdate[] = "ScatterNdUpdate"; | |||
| const char kNameScatterMax[] = "ScatterMax"; | |||
| const char kNameNMSWithMask[] = "NMSWithMask"; | |||
| @@ -256,6 +257,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma | |||
| {string(kNameResizeBilinear), ADPT_DESC(ResizeBilinearV2D)}, | |||
| {string(kNameZerosLike), ADPT_DESC(ZerosLike)}, | |||
| {string(kNameOnesLike), ADPT_DESC(OnesLike)}, | |||
| {string(kNameScatterUpdate), ADPT_DESC(ScatterUpdate)}, | |||
| {string(kNameScatterNdUpdate), ADPT_DESC(ScatterNdUpdate)}, | |||
| {string(kNameScatterMax), ADPT_DESC(ScatterMax)}, | |||
| {string(kNameNMSWithMask), ADPT_DESC(NMSWithMask)}, | |||
| @@ -515,6 +515,11 @@ INPUT_MAP(Unpack) = {{1, INPUT_DESC(x)}}; | |||
| ATTR_MAP(Unpack) = {{"axis", ATTR_DESC(axis, AnyTraits<int>())}, {"num", ATTR_DESC(num, AnyTraits<int>())}}; | |||
| DYN_OUTPUT_MAP(Unpack) = {{0, DYN_OUTPUT_DESC(y)}}; | |||
| // ScatterUpdate | |||
| INPUT_MAP(ScatterUpdate) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(updates)}}; | |||
| ATTR_MAP(ScatterUpdate) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}}; | |||
| OUTPUT_MAP(ScatterUpdate) = {{0, OUTPUT_DESC(var)}}; | |||
| // ScatterNdUpdate | |||
| INPUT_MAP(ScatterNdUpdate) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(updates)}}; | |||
| ATTR_MAP(ScatterNdUpdate) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}}; | |||
| @@ -132,6 +132,8 @@ DECLARE_OP_ADAPTER(ZerosLike) | |||
| DECLARE_OP_USE_OUTPUT(ZerosLike) | |||
| DECLARE_OP_ADAPTER(OnesLike) | |||
| DECLARE_OP_USE_OUTPUT(OnesLike) | |||
| DECLARE_OP_ADAPTER(ScatterUpdate) | |||
| DECLARE_OP_USE_OUTPUT(ScatterUpdate) | |||
| DECLARE_OP_ADAPTER(ScatterNdUpdate) | |||
| DECLARE_OP_USE_OUTPUT(ScatterNdUpdate) | |||
| DECLARE_OP_ADAPTER(ScatterMax) | |||
| @@ -179,14 +179,15 @@ from .bounding_box_encode import _bounding_box_encode_tbe | |||
| from .check_valid import _check_valid_tbe | |||
| from .iou import _iou_tbe | |||
| from .arg_max import _arg_max_tbe | |||
| from .nms_with_mask import nms_with_mask_op_info | |||
| from .random_choice_with_mask import random_choice_with_mask_op_info | |||
| from .sgd import sgd_op_info | |||
| from .lars_update import lars_update_op_info | |||
| from .nms_with_mask import _nms_with_mask_tbe | |||
| from .random_choice_with_mask import _random_choice_with_mask_tbe | |||
| from .sgd import _sgd_tbe | |||
| from .lars_update import _lars_update_tbe | |||
| from .bn_training_update_v2 import _bn_training_update_v2_tbe | |||
| from .square_sum_all import square_sum_all_op_info | |||
| from .square_sum_all import _square_sum_all_tbe | |||
| from .pack import _pack_tbe | |||
| from .unpack import _unpack_tbe | |||
| from .scatter_update import _scatter_update_tbe | |||
| from .prelu import _prelu_tbe | |||
| from .prelu_grad import _prelu_grad_tbe | |||
| from .binary_cross_entropy import _binary_cross_entropy_tbe | |||
| @@ -0,0 +1,42 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ScatterUpdate op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| scatter_update_op_info = TBERegOp("ScatterUpdate") \ | |||
| .fusion_type("ELEMWISE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("scatter_update.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("scatter_update") \ | |||
| .partial_flag(True) \ | |||
| .attr("use_locking", "optional", "bool", "all") \ | |||
| .input(0, "var", False, "required", "all") \ | |||
| .input(1, "indices", False, "required", "all") \ | |||
| .input(1, "updates", False, "required", "all") \ | |||
| .output(0, "var", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \ | |||
| .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default,) \ | |||
| .dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.BOOL_Default, DataType.BOOL_Default) \ | |||
| .get_op_info() | |||
| @op_info_register(scatter_update_op_info) | |||
| def _scatter_update_tbe(): | |||
| """ScatterUpdate TBE register""" | |||
| return | |||
| @@ -14,6 +14,6 @@ | |||
| # ============================================================================ | |||
| """ops utils.""" | |||
| from .utils import _get_broadcast_shape, _get_concat_offset | |||
| from .utils import get_broadcast_shape, get_concat_offset | |||
| __all__ = ['_get_broadcast_shape', '_get_concat_offset'] | |||
| __all__ = ['get_broadcast_shape', 'get_concat_offset'] | |||
| @@ -19,7 +19,8 @@ from ..._checkparam import Validator as validator | |||
| from ..._checkparam import Rel | |||
| from ...common import dtype as mstype | |||
| def _get_broadcast_shape(x_shape, y_shape, prim_name): | |||
| def get_broadcast_shape(x_shape, y_shape, prim_name): | |||
| """ | |||
| Doing broadcast between tensor x and tensor y. | |||
| @@ -37,7 +38,7 @@ def _get_broadcast_shape(x_shape, y_shape, prim_name): | |||
| Examples: | |||
| >>> x_shape = [1, 2, 3] | |||
| >>> y_shape = [1, 2] | |||
| >>> broadcast_shape = _get_broadcast_shape(x_shape, y_shape) | |||
| >>> broadcast_shape = get_broadcast_shape(x_shape, y_shape) | |||
| """ | |||
| if x_shape == y_shape: | |||
| return x_shape | |||
| @@ -54,15 +55,14 @@ def _get_broadcast_shape(x_shape, y_shape, prim_name): | |||
| elif x_shape[i] == y_shape[i]: | |||
| broadcast_shape_back.append(x_shape[i]) | |||
| else: | |||
| raise ValueError("For '{}' the x_shape {} and y_shape {} can not broadcast.".format( | |||
| prim_name, x_shape, y_shape)) | |||
| raise ValueError(f"For '{prim_name}', the x_shape {x_shape} and y_shape {y_shape} can not broadcast.") | |||
| broadcast_shape_front = y_shape[0: y_len - length] if length == x_len else x_shape[0: x_len - length] | |||
| broadcast_shape = broadcast_shape_front + broadcast_shape_back | |||
| broadcast_shape = list(broadcast_shape_front) + broadcast_shape_back | |||
| return broadcast_shape | |||
| def _get_concat_offset(x_shp, x_type, axis, prim_name): | |||
| def get_concat_offset(x_shp, x_type, axis, prim_name): | |||
| """for concat and concatoffset check args and compute offset""" | |||
| validator.check_value_type("shape", x_shp, [tuple], prim_name) | |||
| validator.check_integer("input_x rank", len(x_shp), 0, Rel.GT, prim_name) | |||
| @@ -73,7 +73,7 @@ def _get_concat_offset(x_shp, x_type, axis, prim_name): | |||
| if axis < 0: | |||
| axis = axis + rank_base | |||
| all_shp = x_shp[0][axis] | |||
| offset = [0,] | |||
| offset = [0] | |||
| for i in range(1, len(x_shp)): | |||
| v = x_shp[i] | |||
| validator.check('len of x_shp[%d]' % i, len(v), 'len of x_shp[0]', len(x_shp[0]), Rel.EQ, prim_name) | |||
| @@ -1,226 +0,0 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """constexpr util""" | |||
| from functools import reduce | |||
| import numpy as np | |||
| from ...primitive import constexpr | |||
| from ....common.tensor import Tensor | |||
| from ....common import dtype as mstype | |||
| from ...._extends.utils import Slice, Ellipsis_ | |||
| @constexpr | |||
| def check_equal(param1, param2, msg="{},{}"): | |||
| """Checks whether the two parameters are equal or not.""" | |||
| if param1 != param2: | |||
| raise ValueError(msg.format(param1, param2)) | |||
| return param1 | |||
| @constexpr | |||
| def check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size): | |||
| """Checks the shape and size of the sensor and value.""" | |||
| if data_shape == value_shape or data_size == value_size or value_size == 1: | |||
| return True | |||
| raise ValueError("The value(shape={}), can not assign to tensor(shape={}).".format(value_shape, data_shape)) | |||
| @constexpr | |||
| def check_tensor_setitem_index(index, element_type=None): | |||
| """Checks tuple index type of tensor assignment.""" | |||
| if index is None: | |||
| raise IndexError("Tensor's index cannot be None.") | |||
| # eg. Tensor[Slice] = u | |||
| if isinstance(index, Slice): | |||
| return True | |||
| # eg. Tensor[tuple] = u | |||
| if isinstance(index, tuple): | |||
| if not index: | |||
| raise IndexError("Tensor's index cannot be empty.") | |||
| # eg. Tensor[tuple(Slice...)] = u | |||
| if isinstance(index[0], (Slice, Ellipsis_, int)): | |||
| return True | |||
| raise IndexError("Index of type '{}' is not supported yet.".format(type(index[0]))) | |||
| # eg. Tensor[Tensor[dtype=bool]] = u | |||
| if index == mstype.tensor: | |||
| if element_type is None or element_type != mstype.bool_: | |||
| raise TypeError( | |||
| "The index of tensor should be a bool type tensor. " | |||
| "{} type is not supported yet.".format(element_type)) | |||
| return True | |||
| raise IndexError("Index of type '{}' is not supported yet.".format(type(index))) | |||
| @constexpr | |||
| def is_same_type(inst, type_): | |||
| """ | |||
| Checks whether an object is an instance of a target type. | |||
| Inputs: | |||
| inst (mindspore.dtype): Inspected type. | |||
| type_ (mindspore.dtype): Target type. | |||
| Outputs: | |||
| bool, the check result. | |||
| """ | |||
| return inst == type_ | |||
| def slice_expand(input_slices, shape): | |||
| """ | |||
| Converts slice to indices. | |||
| Inputs: | |||
| slices (Union[Slice, tuple[Slice]]): Slice tuple or slice. | |||
| shape (tuple): The shape of a sensor is an integer element tuple. | |||
| Outputs: | |||
| tuple[list], This is expressed as (begins, ends, strides). | |||
| """ | |||
| begin = [] | |||
| end = [] | |||
| strides = [] | |||
| index = 0 | |||
| slices = None | |||
| # Slice or tuple(Slice...) | |||
| if isinstance(input_slices, Slice): | |||
| slices = (input_slices,) | |||
| elif isinstance(input_slices, (tuple, list)) and input_slices and isinstance(input_slices[0], (Slice, Ellipsis_)): | |||
| is_have_ellipsis = False | |||
| for _, element in enumerate(input_slices): | |||
| if isinstance(element, Ellipsis_): | |||
| is_have_ellipsis = True | |||
| break | |||
| if is_have_ellipsis: | |||
| slices = ellipsis2slice(input_slices, shape) | |||
| else: | |||
| slices = input_slices | |||
| else: | |||
| raise IndexError("Tensor's index type is not supported yet.") | |||
| for s in slices: | |||
| start = 0 if (s.start is None) else s.start | |||
| stop = shape[index] if (s.end is None) else s.end | |||
| step = 1 if (s.step is None) else s.step | |||
| begin.append(start) | |||
| end.append(stop) | |||
| strides.append(step) | |||
| index += 1 | |||
| while index < len(shape): | |||
| begin.append(0) | |||
| end.append(shape[index]) | |||
| strides.append(1) | |||
| index += 1 | |||
| return begin, end, strides | |||
| def ellipsis2slice(input_, shape): | |||
| """Converts ellipsis to slice.""" | |||
| input_slice = input_ | |||
| result = [] | |||
| if isinstance(input_, Ellipsis_): | |||
| input_slice = (input_,) | |||
| ell_count = 0 | |||
| for _, element in enumerate(input_slice): | |||
| if not isinstance(element, Ellipsis_): | |||
| result.append(element) | |||
| continue | |||
| ell_count += 1 | |||
| if ell_count > 1: | |||
| raise IndexError("There cannot be more than one ellisis (...) in the index of the tensor, " | |||
| "but it is currently {}".format(input_slice)) | |||
| for _ in range(len(shape) - len(input_slice) + 1): | |||
| result.append(Slice(None, None, None)) | |||
| return tuple(result) | |||
| @constexpr | |||
| def slice2indices(input_slices, shape): | |||
| """ | |||
| Converts slice to indices. | |||
| Inputs: | |||
| slices (Union[Slice, tuple[Slice]]): Slice tuple or slice. | |||
| shape (tuple): The shape of a tensor is an integer element tuple. | |||
| Outputs: | |||
| Tensor, the shape is (n, 1). | |||
| """ | |||
| begin, end, strides = slice_expand(input_slices, shape) | |||
| np_r = [] | |||
| for i, element in enumerate(shape): | |||
| s = begin[i] if (begin[i] >= 0) else (element + begin[i]) | |||
| e = end[i] if (end[i] >= 0) else (element + end[i]) | |||
| np_r.append(np.r_[s:e:strides[i]]) | |||
| # Reference: np.ravel_multi_index((np.ix_(np.r_[1:3:1], np.r_[0:4:1], np.r_[4:0:-1])), a.shape) | |||
| np_ix = np.ix_(*np_r) | |||
| ravel = np.ravel_multi_index(np_ix, shape) | |||
| ravel = Tensor(ravel.reshape(-1, 1), dtype=mstype.int32) | |||
| return ravel | |||
| @constexpr | |||
| def check_indices(indices_size, index): | |||
| """Checks indices whether is empty.""" | |||
| if indices_size < 1: | |||
| raise IndexError("The tensor's index is unreasonable. index:{}".format(index)) | |||
| return indices_size | |||
| @constexpr | |||
| def check_indices_value_size(indices_size, value_size): | |||
| """Checks if the sizes are already matched.""" | |||
| if value_size < 1: | |||
| raise ValueError("The value assigned to tensor cannot be empty.") | |||
| if value_size > 1: | |||
| if value_size != indices_size: | |||
| raise ValueError( | |||
| "The value given to tensor does not match the index size," | |||
| " value size:{}, indics size:{}".format(value_size, indices_size)) | |||
| return value_size | |||
| @constexpr | |||
| def integer_to_indices(index, shape): | |||
| """Converts int or tuple[int] to indices.""" | |||
| size = reduce(lambda x, y: x * y, shape) | |||
| range_ = np.arange(size).reshape(shape) | |||
| value = range_[index] | |||
| value = value.reshape(-1, 1) | |||
| return Tensor(value, dtype=mstype.int32) | |||
| @constexpr | |||
| def tuple_element_is_slice(indexs): | |||
| """Judges tuple element type.""" | |||
| if not indexs: | |||
| raise IndexError("Tensor's index cannot be empty.") | |||
| if isinstance(indexs, tuple): | |||
| for _, ele in enumerate(indexs): | |||
| if not isinstance(ele, Slice): | |||
| return False | |||
| return True | |||
| return False | |||
| @constexpr | |||
| def tuple_element_is_int(indexs): | |||
| """Judges tuple element type.""" | |||
| if not indexs: | |||
| raise IndexError("Tensor's index cannot be empty.") | |||
| if isinstance(indexs, tuple): | |||
| for _, ele in enumerate(indexs): | |||
| if not isinstance(ele, int): | |||
| return False | |||
| return True | |||
| return False | |||
| @@ -0,0 +1,487 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """constexpr util""" | |||
| from functools import reduce | |||
| import numpy as np | |||
| from ...primitive import constexpr | |||
| from ....common.tensor import Tensor | |||
| from ....common import dtype as mstype | |||
| from ...._extends.utils import Slice, Ellipsis_ | |||
| from ....ops import _utils as op_utils | |||
| from ...composite import base | |||
| from .... import log as logger | |||
| from ... import functional as F | |||
| from ... import operations as P | |||
| hyper_map = base.HyperMap() | |||
| pack = P.Pack(axis=-1) | |||
| ALL_TENSOR = 0 | |||
| NO_TENSOR = 1 | |||
| CONTAIN_TENSOR = 2 | |||
| ALL_SCALAR = 3 | |||
| INT_ = 0 | |||
| BOOL_ = 1 | |||
| UNSUPPORTED_DTYPE = 2 | |||
| TENSOR_SETITEM = "tensor setitem" | |||
| TENSOR_GETITEM = "tensor getitem" | |||
| SET_ITEM_BY_ONE_TENSOR = 0 | |||
| SET_ITEM_BY_TUPLE_OF_TENSOR = 1 | |||
| @constexpr | |||
| def check_equal(param1, param2, msg="{},{}"): | |||
| """Checks whether the two parameters are equal or not.""" | |||
| if param1 != param2: | |||
| raise ValueError(msg.format(param1, param2)) | |||
| return param1 | |||
| @constexpr | |||
| def check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size): | |||
| """Checks the shape and size of the sensor and value.""" | |||
| if data_shape == value_shape or data_size == value_size or value_size == 1: | |||
| return True | |||
| raise ValueError("The value(shape={}), can not assign to tensor(shape={}).".format(value_shape, data_shape)) | |||
| @constexpr | |||
| def check_tensor_setitem_index(index, element_type=None): | |||
| """Checks tuple index type of tensor assignment.""" | |||
| if index is None: | |||
| raise IndexError("Tensor's index cannot be None.") | |||
| # eg. Tensor[Slice] = u | |||
| if isinstance(index, Slice): | |||
| return True | |||
| # eg. Tensor[tuple] = u | |||
| if isinstance(index, tuple): | |||
| if not index: | |||
| raise IndexError("Tensor's index cannot be empty.") | |||
| # eg. Tensor[tuple(Slice...)] = u | |||
| if isinstance(index[0], (Slice, Ellipsis_, int)): | |||
| return True | |||
| raise IndexError("Index of type '{}' is not supported yet.".format(type(index[0]))) | |||
| # eg. Tensor[Tensor[dtype=bool]] = u | |||
| if isinstance(index, mstype.tensor_type): | |||
| if element_type is None or element_type != mstype.bool_: | |||
| raise TypeError( | |||
| "The index of tensor should be a bool type tensor. " | |||
| "{} type is not supported yet.".format(element_type)) | |||
| return True | |||
| raise IndexError("Index of type '{}' is not supported yet.".format(type(index))) | |||
| @constexpr | |||
| def is_same_type(inst, type_): | |||
| """ | |||
| Checks whether an object is an instance of a target type. | |||
| Inputs: | |||
| inst (mindspore.dtype): Inspected type. | |||
| type_ (mindspore.dtype): Target type. | |||
| Outputs: | |||
| bool, the check result. | |||
| """ | |||
| return inst == type_ | |||
| def slice_expand(input_slices, shape): | |||
| """ | |||
| Converts slice to indices. | |||
| Inputs: | |||
| slices (Union[Slice, tuple[Slice]]): Slice tuple or slice. | |||
| shape (tuple): The shape of a sensor is an integer element tuple. | |||
| Outputs: | |||
| tuple[list], This is expressed as (begins, ends, strides). | |||
| """ | |||
| begin = [] | |||
| end = [] | |||
| strides = [] | |||
| index = 0 | |||
| slices = None | |||
| # Slice or tuple(Slice...) | |||
| if isinstance(input_slices, Slice): | |||
| slices = (input_slices,) | |||
| elif isinstance(input_slices, (tuple, list)) and input_slices and isinstance(input_slices[0], (Slice, Ellipsis_)): | |||
| is_have_ellipsis = False | |||
| for _, element in enumerate(input_slices): | |||
| if isinstance(element, Ellipsis_): | |||
| is_have_ellipsis = True | |||
| break | |||
| if is_have_ellipsis: | |||
| slices = ellipsis2slice(input_slices, shape) | |||
| else: | |||
| slices = input_slices | |||
| else: | |||
| raise IndexError("Tensor's index type is not supported yet.") | |||
| for s in slices: | |||
| start = 0 if (s.start is None) else s.start | |||
| stop = shape[index] if (s.end is None) else s.end | |||
| step = 1 if (s.step is None) else s.step | |||
| begin.append(start) | |||
| end.append(stop) | |||
| strides.append(step) | |||
| index += 1 | |||
| while index < len(shape): | |||
| begin.append(0) | |||
| end.append(shape[index]) | |||
| strides.append(1) | |||
| index += 1 | |||
| return begin, end, strides | |||
| def ellipsis2slice(input_, shape): | |||
| """Converts ellipsis to slice.""" | |||
| input_slice = input_ | |||
| result = [] | |||
| if isinstance(input_, Ellipsis_): | |||
| input_slice = (input_,) | |||
| ell_count = 0 | |||
| for _, element in enumerate(input_slice): | |||
| if not isinstance(element, Ellipsis_): | |||
| result.append(element) | |||
| continue | |||
| ell_count += 1 | |||
| if ell_count > 1: | |||
| raise IndexError("There cannot be more than one ellisis (...) in the index of the tensor, " | |||
| "but it is currently {}".format(input_slice)) | |||
| for _ in range(len(shape) - len(input_slice) + 1): | |||
| result.append(Slice(None, None, None)) | |||
| return tuple(result) | |||
| @constexpr | |||
| def slice2indices(input_slices, shape): | |||
| """ | |||
| Converts slice to indices. | |||
| Inputs: | |||
| slices (Union[Slice, tuple[Slice]]): Slice tuple or slice. | |||
| shape (tuple): The shape of a tensor is an integer element tuple. | |||
| Outputs: | |||
| Tensor, the shape is (n, 1). | |||
| """ | |||
| begin, end, strides = slice_expand(input_slices, shape) | |||
| np_r = [] | |||
| for i, element in enumerate(shape): | |||
| s = begin[i] if (begin[i] >= 0) else (element + begin[i]) | |||
| e = end[i] if (end[i] >= 0) else (element + end[i]) | |||
| np_r.append(np.r_[s:e:strides[i]]) | |||
| # Reference: np.ravel_multi_index((np.ix_(np.r_[1:3:1], np.r_[0:4:1], np.r_[4:0:-1])), a.shape) | |||
| np_ix = np.ix_(*np_r) | |||
| ravel = np.ravel_multi_index(np_ix, shape) | |||
| ravel = Tensor(ravel.reshape(-1, 1), dtype=mstype.int32) | |||
| return ravel | |||
| @constexpr | |||
| def check_indices(indices_size, index): | |||
| """Checks indices whether is empty.""" | |||
| if indices_size < 1: | |||
| raise IndexError("The tensor's index is unreasonable. index:{}".format(index)) | |||
| return indices_size | |||
| @constexpr | |||
| def check_indices_value_size(indices_size, value_size): | |||
| """Checks if the sizes are already matched.""" | |||
| if value_size < 1: | |||
| raise ValueError("The value assigned to tensor cannot be empty.") | |||
| if value_size > 1: | |||
| if value_size != indices_size: | |||
| raise ValueError( | |||
| "The value given to tensor does not match the index size," | |||
| " value size:{}, indics size:{}".format(value_size, indices_size)) | |||
| return value_size | |||
| @constexpr | |||
| def integer_to_indices(index, shape): | |||
| """Converts int or tuple[int] to indices.""" | |||
| size = reduce(lambda x, y: x * y, shape) | |||
| range_ = np.arange(size).reshape(shape) | |||
| value = range_[index] | |||
| value = value.reshape(-1, 1) | |||
| return Tensor(value, dtype=mstype.int32) | |||
| @constexpr | |||
| def tuple_element_is_slice(indexs): | |||
| """Judges tuple element type.""" | |||
| if not indexs: | |||
| raise IndexError("Tensor's index cannot be empty.") | |||
| if isinstance(indexs, tuple): | |||
| for _, ele in enumerate(indexs): | |||
| if not isinstance(ele, Slice): | |||
| return False | |||
| return True | |||
| return False | |||
| @constexpr | |||
| def tuple_element_is_int(indexs): | |||
| """Judges tuple element type.""" | |||
| if not indexs: | |||
| raise IndexError("Tensor's index cannot be empty.") | |||
| if isinstance(indexs, tuple): | |||
| for _, ele in enumerate(indexs): | |||
| if not isinstance(ele, int): | |||
| return False | |||
| return True | |||
| return False | |||
| @constexpr | |||
| def tuple_elements_type(types): | |||
| """Judges the type of all elements of the tuple.""" | |||
| tensors_number = 0 | |||
| for ele in types: | |||
| if isinstance(ele, mstype.tensor_type): | |||
| tensors_number += 1 | |||
| if tensors_number == len(types): | |||
| return ALL_TENSOR | |||
| if tensors_number == 0: | |||
| return NO_TENSOR | |||
| return CONTAIN_TENSOR | |||
| @constexpr | |||
| def check_value_elements(data_dtype, types): | |||
| """Judges the type of all elements of the tuple.""" | |||
| tensors_number = 0 | |||
| scalars_number = 0 | |||
| for i, ele in enumerate(types): | |||
| if isinstance(ele, mstype.tensor_type): | |||
| ele_dtype = ele.element_type() | |||
| if data_dtype == ele_dtype: | |||
| tensors_number += 1 | |||
| else: | |||
| raise TypeError(f"For '{TENSOR_SETITEM}', the data type of {i}th tensor '{ele_dtype}' " | |||
| f"in value tuple is not consistent with origin tensor data type '{data_dtype}'.") | |||
| elif mstype.issubclass_(ele, data_dtype): | |||
| scalars_number += 1 | |||
| else: | |||
| raise TypeError(f"For '{TENSOR_SETITEM}', the {i}th element type '{ele}' in " | |||
| f"value tuple is not consistent with origin tensor data type '{data_dtype}'.") | |||
| if tensors_number == len(types): | |||
| return ALL_TENSOR | |||
| if scalars_number == len(types): | |||
| return ALL_SCALAR | |||
| raise TypeError(f"For '{TENSOR_SETITEM}', the value does not support scalar and tensor mixing, but got {types}.") | |||
| @constexpr | |||
| def get_index_tensor_dtype(dtype): | |||
| """Check a tuple of tensor data type.""" | |||
| if dtype == mstype.int32: | |||
| return INT_ | |||
| if dtype == mstype.bool_: | |||
| return BOOL_ | |||
| raise TypeError(f"For '{TENSOR_SETITEM}', the index tensor data type '{dtype}' is not supported.") | |||
| @constexpr | |||
| def check_index_tensors_dtype(dtypes, op_name): | |||
| """Check a tuple of tensor data type.""" | |||
| if op_name == TENSOR_GETITEM: | |||
| valid_dtypes = (mstype.int32, mstype.int64) | |||
| elif op_name == TENSOR_SETITEM: | |||
| valid_dtypes = (mstype.int32,) | |||
| else: | |||
| raise ValueError("Unsupported operation.") | |||
| for ele in dtypes: | |||
| if ele in valid_dtypes and ele == dtypes[0]: | |||
| continue | |||
| raise TypeError(f"For '{op_name}', the index tensors data type must be same, " | |||
| f"and should be one of the following: {valid_dtypes}, but got {dtypes}.") | |||
| return True | |||
| @constexpr | |||
| def check_tensor_dtype_valid(dtype, valid_dtypes): | |||
| """Check a tensor data type.""" | |||
| if dtype in valid_dtypes: | |||
| return True | |||
| raise TypeError(f"The index tensor data type must be one of " | |||
| f"the following: {valid_dtypes}, but got {dtype}.") | |||
| @constexpr | |||
| def check_tensors_dtype_same(x_dtype, y_dtype, op_name): | |||
| """Check tensors data type same.""" | |||
| if x_dtype == y_dtype: | |||
| return True | |||
| raise TypeError(f"For '{op_name}', the value data type '{y_dtype}' " | |||
| f"is not consistent with origin tensor data type {x_dtype}.") | |||
| @constexpr | |||
| def broadcast_shapes(shapes, op_name): | |||
| """Broadcasts a tuple of tensor.""" | |||
| broadcast_shape = shapes[0] | |||
| for i, shape in enumerate(shapes): | |||
| logger.debug(f"Broadcasts the {i}th tensor, the shape is {shape}.") | |||
| broadcast_shape = op_utils.get_broadcast_shape(broadcast_shape, shape, op_name) | |||
| return tuple(broadcast_shape) | |||
| @constexpr | |||
| def check_two_shapes_need_broadcast(shape_x, shape_y): | |||
| """Check two shapes need broadcast.""" | |||
| error = ValueError(f"For 'tensor setitem with tensor', the value tensor shape " | |||
| f"{shape_y} could not broadcast the required updates shape {shape_x}.") | |||
| if len(shape_y) > len(shape_x): | |||
| raise error | |||
| for i in range(-len(shape_y), 0): | |||
| if shape_y[i] > shape_x[i]: | |||
| raise error | |||
| if shape_y[i] < shape_x[i] and shape_y[i] != 1: | |||
| raise error | |||
| if shape_y == shape_x: | |||
| return False | |||
| return True | |||
| @constexpr | |||
| def compute_multiples(origin_shape, broadcast_shape): | |||
| """Compute multiples between broadcast_shape with origin_shape.""" | |||
| len_gap = len(broadcast_shape) - len(origin_shape) | |||
| return broadcast_shape[0:len_gap] + tuple(map(lambda x, y: x // y, broadcast_shape[len_gap:], origin_shape)) | |||
| def tile(broadcast_shape, x): | |||
| multiples = compute_multiples(F.shape(x), broadcast_shape) | |||
| return F.tile(x, multiples) | |||
| @constexpr | |||
| def check_shapes_same(value_shapes, op_name): | |||
| """Check if the shapes in the tuple are consistent.""" | |||
| for i, shape in enumerate(value_shapes): | |||
| if shape != value_shapes[0]: | |||
| raise ValueError(f"For '{op_name}', the {i}th tensor shape in value tuple " | |||
| f"is not same as the first tensor shape.") | |||
| return True | |||
| @constexpr | |||
| def convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type): | |||
| """Convert a scalar to a tensor.""" | |||
| if op_type == SET_ITEM_BY_ONE_TENSOR: | |||
| updates_shape = indices_shape + data_shape[1:] | |||
| else: | |||
| updates_shape = indices_shape[:-1] + data_shape[indices_shape[-1]:] | |||
| if isinstance(value, mstype.dtype_to_pytype(data_dtype)): | |||
| return Tensor(np.full(updates_shape, value), dtype=data_dtype) | |||
| raise TypeError(f"For '{TENSOR_SETITEM}', the value type '{value.__class__.__name__}'" | |||
| f" is not consistent with tensor data type {data_dtype}.") | |||
| @constexpr | |||
| def convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value, op_type): | |||
| """Convert a tuple of scalar to a tensor.""" | |||
| updates_shape = generate_updates_shape(data_shape, index_shape, op_type) | |||
| if len(value) != updates_shape[-1]: | |||
| raise ValueError(f"For '{TENSOR_SETITEM}', the number of elements : {len(value)} in the updates tuple " | |||
| f"does not meet the requirements: {updates_shape[-1]}.") | |||
| array = np.array(value, dtype=mstype.dtype_to_nptype(data_dtype)) | |||
| reps = compute_multiples(updates_shape[-1:], updates_shape) | |||
| return Tensor(np.tile(array, reps)) | |||
| @constexpr | |||
| def generate_updates_shape(data_shape, index_shape, op_type): | |||
| """Generate updates shape for 'tensor setitem'.""" | |||
| if op_type == SET_ITEM_BY_ONE_TENSOR: | |||
| updates_shape = index_shape + data_shape[1:] | |||
| else: | |||
| updates_shape = index_shape[:-1] + data_shape[index_shape[-1]:] | |||
| return updates_shape | |||
| @constexpr | |||
| def check_number_of_index_tensor(data_shape, tuple_len, op_name): | |||
| """Check if the number of index tensor exceeds the dimension of the operated tensor.""" | |||
| if tuple_len <= len(data_shape): | |||
| return True | |||
| raise IndexError(f"For '{op_name}', the number {tuple_len} of index tensor " | |||
| f"is greater than the dimension {len(data_shape)} of the operated tensor.") | |||
| def generate_indeices_from_tuple_of_tensor(data, tuple_index, op_name): | |||
| """Generate an indices tensor from a tuple of tensor.""" | |||
| indices = None | |||
| check_index_tensor_number = check_number_of_index_tensor(F.shape(data), len(tuple_index), op_name) | |||
| if check_index_tensor_number: | |||
| dtype_tuple = hyper_map(F.dtype, tuple_index) | |||
| check_dtypes = check_index_tensors_dtype(dtype_tuple, op_name) | |||
| if check_dtypes: | |||
| shape_tuple = hyper_map(F.shape, tuple_index) | |||
| broadcast_shape = broadcast_shapes(shape_tuple, op_name) | |||
| broadcast_tensors = hyper_map(F.partial(tile, broadcast_shape), tuple_index) | |||
| indices = pack(broadcast_tensors) | |||
| return indices | |||
| def generate_updates_from_scalar(data, indices, value, op_type): | |||
| """Generate an updates tensor from a scalar.""" | |||
| data_shape = F.shape(data) | |||
| indices_shape = F.shape(indices) | |||
| data_dtype = F.dtype(data) | |||
| return convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type) | |||
| def generate_updates_from_tuple(data, index, value, op_type): | |||
| """Generate an updates tensor from a tuple.""" | |||
| value_types = hyper_map(F.typeof, value) | |||
| data_dtype = F.dtype(data) | |||
| value_elements_type = check_value_elements(data_dtype, value_types) | |||
| if value_elements_type == ALL_TENSOR: | |||
| value_shapes = hyper_map(F.shape, value) | |||
| shapes_same = check_shapes_same(value_shapes, TENSOR_SETITEM) | |||
| if shapes_same: | |||
| value = F.pack(value) | |||
| return generate_updates_from_tensor(data, index, value, op_type) | |||
| data_shape = F.shape(data) | |||
| index_shape = F.shape(index) | |||
| return convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value, op_type) | |||
| def generate_updates_from_tensor(data, index, value, op_type): | |||
| """Generate an updates tensor from a tensor.""" | |||
| data_shape = F.shape(data) | |||
| index_shape = F.shape(index) | |||
| value_shape = F.shape(value) | |||
| data_dtype = F.dtype(data) | |||
| value_dtype = F.dtype(value) | |||
| updates_shape = value_shape | |||
| check_dtype_same = check_tensors_dtype_same(data_dtype, value_dtype, TENSOR_SETITEM) | |||
| if check_dtype_same: | |||
| updates_shape = generate_updates_shape(data_shape, index_shape, op_type) | |||
| need_broadcast = check_two_shapes_need_broadcast(updates_shape, value_shape) | |||
| if need_broadcast: | |||
| return tile(updates_shape, value) | |||
| return value | |||
| @@ -15,9 +15,10 @@ | |||
| """Implementation for getitem.""" | |||
| from ...composite import base | |||
| from . import _utils as multi_utils | |||
| from ..import base | |||
| from ... import functional as F | |||
| from ....common import dtype as mstype | |||
| getitem = base.MultitypeFuncGraph('getitem') | |||
| """ | |||
| @@ -214,19 +215,45 @@ def _tensor_getitem_by_slice(data, slice_index): | |||
| return _tensor_slice(data, slice_index) | |||
| @getitem.register("Tensor", "Tensor") | |||
| def _tensor_getitem_by_tensor(data, tensor_index): | |||
| """ | |||
| Getting item of tensor by slice. | |||
| Inputs: | |||
| data (Tensor): A tensor. | |||
| tensor_index (Tensor): An index expressed by tensor. | |||
| Outputs: | |||
| Tensor, element type is same as the element type of data. | |||
| """ | |||
| check_dtypes = multi_utils.check_tensor_dtype_valid(F.dtype(tensor_index), (mstype.int32, mstype.int64)) | |||
| result = None | |||
| if check_dtypes: | |||
| result = F.gather(data, tensor_index, 0) | |||
| return result | |||
| @getitem.register("Tensor", "Tuple") | |||
| def _tensor_getitem_by_slice_tuple(data, slice_tuple_index): | |||
| def _tensor_getitem_by_tuple(data, tuple_index): | |||
| """ | |||
| Getting item of tensor by slice tuple. | |||
| Inputs: | |||
| data (Tensor): A tensor. | |||
| slice_tuple_index (tuple): Index in tuple. | |||
| tuple_index (tuple): Index in tuple. | |||
| Outputs: | |||
| Tensor, element type is same as the element type of data. | |||
| """ | |||
| return _tensor_slice(data, slice_tuple_index) | |||
| index_types = multi_utils.hyper_map(F.typeof, tuple_index) | |||
| index_elements_type = multi_utils.tuple_elements_type(index_types) | |||
| result = None | |||
| if index_elements_type == multi_utils.NO_TENSOR: | |||
| result = _tensor_slice(data, tuple_index) | |||
| if index_elements_type == multi_utils.ALL_TENSOR: | |||
| result = _tensor_getitem_by_tuple_of_tensor(data, tuple_index) | |||
| return result | |||
| @getitem.register("Tensor", "Ellipsis") | |||
| @@ -242,3 +269,10 @@ def _tensor_getitem_by_ellipsis(data, ellipsis_index): | |||
| Tensor, same as data. | |||
| """ | |||
| return _tensor_slice(data, ellipsis_index) | |||
| def _tensor_getitem_by_tuple_of_tensor(data, tuple_index): | |||
| """Tensor getitem by a tuple of tensor.""" | |||
| indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_GETITEM) | |||
| result = F.gather_nd(data, indices) | |||
| return result | |||
| @@ -18,10 +18,11 @@ | |||
| from ...composite import base | |||
| from ....common import dtype as mstype | |||
| from ... import functional as F | |||
| from . import _multitype_ops_util as mult_util | |||
| from . import _utils as multi_utils | |||
| setitem = base.MultitypeFuncGraph('setitem') | |||
| @setitem.register("List", "Number", "String") | |||
| def _list_setitem_with_string(data, number_index, value): | |||
| """ | |||
| @@ -118,7 +119,7 @@ def _dict_setitem_with_number(data, key, value): | |||
| @setitem.register("Tensor", "Tensor", "Tensor") | |||
| def _tensor_setitem_by_tensor_v1(data, index, value_tensor): | |||
| def _tensor_setitem_by_tensor_with_tensor(data, index, value_tensor): | |||
| """ | |||
| Tensor assignment. | |||
| @@ -137,27 +138,15 @@ def _tensor_setitem_by_tensor_v1(data, index, value_tensor): | |||
| Outputs: | |||
| Tensor, element type and shape is same as data. | |||
| """ | |||
| result = None | |||
| index_dtype = F.dtype(index) | |||
| index_shape = F.shape(index) | |||
| check_result = mult_util.check_tensor_setitem_index(mstype.tensor, index_dtype) | |||
| if check_result: | |||
| data_shape = F.shape(data) | |||
| data_shape = mult_util.check_equal(data_shape, index_shape, | |||
| "The tensor(shape={}) and tensor index(shape={}) should be the same shape.") | |||
| size = F.size(value_tensor) | |||
| size = mult_util.check_equal(1, size, | |||
| "When assign value is a tensor, its size should be {}, but current size is {}.") | |||
| dtype = F.dtype(data) | |||
| u_cast = F.cast(value_tensor, dtype) | |||
| one_data = F.ones_like(data) | |||
| u = F.tensor_mul(one_data, u_cast) | |||
| result = F.select(index, u, data) | |||
| return result | |||
| tensor_dtype = multi_utils.get_index_tensor_dtype(index_dtype) | |||
| if tensor_dtype == multi_utils.INT_: | |||
| return _tensor_setitem_by_int_tensor_with_tensor(data, index, value_tensor) | |||
| return _tensor_setitem_by_bool_tensor_with_tensor(data, index, value_tensor) | |||
| @setitem.register("Tensor", "Tensor", "Number") | |||
| def _tensor_setitem_by_tensor_v2(data, index, value): | |||
| def _tensor_setitem_by_tensor_with_number(data, index, value): | |||
| """ | |||
| Tensor assignment. | |||
| @@ -171,143 +160,167 @@ def _tensor_setitem_by_tensor_v2(data, index, value): | |||
| Inputs: | |||
| data (Tensor): Assigned tensor. | |||
| index (Tensor): Tensor of bool type. | |||
| value_tensor (Number): Assignment value. | |||
| value (Number): Assignment value. | |||
| Outputs: | |||
| Tensor, element type and shape is same as data. | |||
| """ | |||
| result = None | |||
| index_dtype = F.dtype(index) | |||
| index_shape = F.shape(index) | |||
| check_result = mult_util.check_tensor_setitem_index(mstype.tensor, index_dtype) | |||
| if check_result: | |||
| shape = F.shape(data) | |||
| shape = mult_util.check_equal( | |||
| shape, index_shape, "The tensor(shape={}) and tensor index(shape={}) should be the same shape.") | |||
| dtype = F.dtype(data) | |||
| u = F.fill(dtype, shape, value) | |||
| result = F.select(index, u, data) | |||
| return result | |||
| tensor_dtype = multi_utils.get_index_tensor_dtype(index_dtype) | |||
| if tensor_dtype == multi_utils.BOOL_: | |||
| return _tensor_setitem_by_bool_tensor_with_scalar(data, index, value) | |||
| return _tensor_setitem_by_int_tensor_with_scalar(data, index, value) | |||
| @setitem.register("Tensor", "Slice", "Tensor") | |||
| def _tensor_setitem_with_slice_v3(data, input_slice, value): | |||
| @setitem.register("Tensor", "Tuple", "Number") | |||
| def _tensor_setitem_by_tuple_with_number(data, tuple_index, value): | |||
| """ | |||
| Tensor assignment. | |||
| Note: | |||
| Syntax support: A[Slice] = U | |||
| Restraint condition: A is a Tensor | |||
| Slice like "1:3" | |||
| U is a Tensor(size=1) or Tensor(size>1) | |||
| Syntax support: A[B, C, D] = u. | |||
| Restraint condition: 1) A is a Tensor, and B, C, D are index. | |||
| 2) u is a scalar. | |||
| Inputs: | |||
| data (Tensor): Assigned tensor. | |||
| input_slice (Slice): Slice expression. | |||
| index (Tuple): An index tuple. | |||
| value (Number): Assignment value. | |||
| Outputs: | |||
| Tensor, element type and shape is same as data. | |||
| """ | |||
| return _tensor_assgin_tensor(data, input_slice, value) | |||
| index_types = multi_utils.hyper_map(F.typeof, tuple_index) | |||
| index_elements_type = multi_utils.tuple_elements_type(index_types) | |||
| result = None | |||
| if index_elements_type == multi_utils.NO_TENSOR: | |||
| result = _tensor_assgin_number(data, tuple_index, value) | |||
| if index_elements_type == multi_utils.ALL_TENSOR: | |||
| indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_SETITEM) | |||
| updates = multi_utils.generate_updates_from_scalar(data, indices, value, | |||
| multi_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) | |||
| result = F.scatter_nd_update(data, indices, updates) | |||
| return result | |||
| @setitem.register("Tensor", "Tuple", "Tensor") | |||
| def _tensor_setitem_with_slice_v4(data, input_slice, value): | |||
| def _tensor_setitem_by_tuple_with_tensor(data, tuple_index, value): | |||
| """ | |||
| Tensor assignment. | |||
| Note: | |||
| Syntax support: A[tuple(Slice)] = U, and A[tuple(Number)] = U | |||
| Restraint condition: A is a Tensor | |||
| Slice like "1:3, ::, :4:-1" | |||
| U is a Tensor(size=1) or Tensor(size>1) | |||
| Syntax support: A[B, C, D] = U. | |||
| Restraint condition: 1) A is a Tensor, and B, C, D are index Tensors. | |||
| 2) U is a Tensor. | |||
| Inputs: | |||
| data (Tensor): Assigned tensor. | |||
| input_slice (Union[tuple[Slice], tuple[Number]]): Slice expression. | |||
| value (Number): Assignment value. | |||
| index (Tuple): An index tuple. | |||
| value (Tensor): Assignment tensor, should has the same data type as 'data'. | |||
| Outputs: | |||
| Tensor, element type and shape is same as data. | |||
| """ | |||
| return _tensor_assgin_tensor(data, input_slice, value) | |||
| index_types = multi_utils.hyper_map(F.typeof, tuple_index) | |||
| index_elements_type = multi_utils.tuple_elements_type(index_types) | |||
| result = None | |||
| if index_elements_type == multi_utils.NO_TENSOR: | |||
| result = _tensor_assgin_tensor(data, tuple_index, value) | |||
| if index_elements_type == multi_utils.ALL_TENSOR: | |||
| indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_SETITEM) | |||
| updates = multi_utils.generate_updates_from_tensor(data, indices, value, | |||
| multi_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) | |||
| result = F.scatter_nd_update(data, indices, updates) | |||
| return result | |||
| def _tensor_assgin_tensor(data, input_slice, value): | |||
| """Assigns a tensor value to the tensor by slice.""" | |||
| @setitem.register("Tensor", "Tuple", "Tuple") | |||
| def _tensor_setitem_by_tuple_with_tuple(data, tuple_index, value): | |||
| """ | |||
| Tensor assignment. | |||
| Note: | |||
| Syntax support: A[B, C, D] = U. | |||
| Restraint condition: 1) A is a Tensor, and B, C, D are index Tensors. | |||
| 2) A B and C could be broadcast. | |||
| 3) U is a Tensor. | |||
| Inputs: | |||
| data (Tensor): Assigned tensor. | |||
| index (Tuple): A tuple of tensor, these tensor could be broadcast. | |||
| value (Tensor): Assignment tensor, should has the same data type as 'data'. | |||
| Outputs: | |||
| Tensor, element type and shape is same as data. | |||
| """ | |||
| index_types = multi_utils.hyper_map(F.typeof, tuple_index) | |||
| index_elements_type = multi_utils.tuple_elements_type(index_types) | |||
| result = None | |||
| check_result = mult_util.check_tensor_setitem_index(input_slice) | |||
| if check_result: | |||
| data_shape = F.shape(data) | |||
| indices = mult_util.slice2indices(input_slice, data_shape) | |||
| is_tuple_int = mult_util.tuple_element_is_int(input_slice) | |||
| if is_tuple_int: | |||
| indices = mult_util.integer_to_indices(input_slice, data_shape) | |||
| result = _tensor_indices_tensor(data, data_shape, input_slice, indices, value) | |||
| if index_elements_type == multi_utils.ALL_TENSOR: | |||
| indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_SETITEM) | |||
| updates = multi_utils.generate_updates_from_tuple(data, indices, value, | |||
| multi_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) | |||
| result = F.scatter_nd_update(data, indices, updates) | |||
| return result | |||
| def _tensor_indices_tensor(data, data_shape, index, indices, value): | |||
| """Assigns a tensor value to the tensor.""" | |||
| data_size = F.size(data) | |||
| data_dtype = F.dtype(data) | |||
| indices_size = F.size(indices) | |||
| indices_size = mult_util.check_indices(indices_size, index) | |||
| update = F.fill(mstype.int32, (indices_size,), 1) | |||
| condition_1d = F.scatter_nd(indices, update, (data_size,)) | |||
| condition = F.reshape(condition_1d, data_shape) | |||
| condition = F.cast(condition, mstype.bool_) | |||
| value_fill = None | |||
| value_size = F.size(value) | |||
| @setitem.register("Tensor", "Tensor", "Tuple") | |||
| def _tensor_setitem_by_tensor_v2(data, index, value): | |||
| """ | |||
| Tensor assignment. | |||
| value_size = mult_util.check_indices_value_size(indices_size, value_size) | |||
| if value_size == 1: | |||
| value_fill = F.fill(data_dtype, (indices_size,), 1) | |||
| value = F.cast(value, data_dtype) | |||
| value_fill = F.tensor_mul(value_fill, value) | |||
| elif value_size > 1: | |||
| value_fill = F.reshape(value, (indices_size,)) | |||
| value_1d = F.scatter_nd(indices, value_fill, (data_size,)) | |||
| u = F.reshape(value_1d, data_shape) | |||
| return F.select(condition, u, data) | |||
| Inputs: | |||
| data (Tensor): Assigned tensor. | |||
| index (Tensor): Tensor of bool type. | |||
| value (Tuple): Assignment value. | |||
| @setitem.register("Tensor", "Slice", "Number") | |||
| def _tensor_setitem_with_slice_v1(data, input_slice, value): | |||
| Outputs: | |||
| Tensor, element type and shape is same as data. | |||
| """ | |||
| index_dtype = F.dtype(index) | |||
| check_dtype = multi_utils.check_tensor_dtype_valid(index_dtype, (mstype.int32, mstype.int64)) | |||
| result = None | |||
| if check_dtype: | |||
| result = _tensor_setitem_by_tensor_with_tuple(data, index, value) | |||
| return result | |||
| @setitem.register("Tensor", "Slice", "Tensor") | |||
| def _tensor_setitem_with_slice_v3(data, input_slice, value): | |||
| """ | |||
| Tensor assignment. | |||
| Note: | |||
| Syntax support: A[Slice] = u | |||
| Restraint condition: A is a Tensor. | |||
| Syntax support: A[Slice] = U | |||
| Restraint condition: A is a Tensor | |||
| Slice like "1:3" | |||
| u is a scalar | |||
| U is a Tensor(size=1) or Tensor(size>1) | |||
| Inputs: | |||
| data (Tensor): Assigned tensor. | |||
| input_slice (Slice): slice expression. | |||
| input_slice (Slice): Slice expression. | |||
| value (Number): Assignment value. | |||
| Outputs: | |||
| Tensor, element type and shape is same as data. | |||
| """ | |||
| return _tensor_assgin_number(data, input_slice, value) | |||
| return _tensor_assgin_tensor(data, input_slice, value) | |||
| @setitem.register("Tensor", "Tuple", "Number") | |||
| def _tensor_setitem_with_slice_v2(data, input_slice, value): | |||
| @setitem.register("Tensor", "Slice", "Number") | |||
| def _tensor_setitem_with_slice_v1(data, input_slice, value): | |||
| """ | |||
| Tensor assignment. | |||
| Note: | |||
| Syntax support: A[tuple(Slice)] = u, and A[tuple(Number)] = u | |||
| Syntax support: A[Slice] = u | |||
| Restraint condition: A is a Tensor. | |||
| Slice like "1:3, ::, :4:-1" | |||
| Slice like "1:3" | |||
| u is a scalar | |||
| Inputs: | |||
| data (Tensor): Assigned tensor. | |||
| input_slice (Union[tuple[Slice], tuple[Number]]): slice expression. | |||
| input_slice (Slice): slice expression. | |||
| value (Number): Assignment value. | |||
| Outputs: | |||
| @@ -318,39 +331,23 @@ def _tensor_setitem_with_slice_v2(data, input_slice, value): | |||
| def _tensor_assgin_number(data, input_slice, value): | |||
| """Givens a scalar assign to tensor by slice""" | |||
| check_result = mult_util.check_tensor_setitem_index(input_slice) | |||
| check_result = multi_utils.check_tensor_setitem_index(input_slice) | |||
| result = None | |||
| if check_result: | |||
| data_shape = F.shape(data) | |||
| indices = mult_util.slice2indices(input_slice, data_shape) | |||
| is_tuple_int = mult_util.tuple_element_is_int(input_slice) | |||
| indices = multi_utils.slice2indices(input_slice, data_shape) | |||
| is_tuple_int = multi_utils.tuple_element_is_int(input_slice) | |||
| if is_tuple_int: | |||
| indices = mult_util.integer_to_indices(input_slice, data_shape) | |||
| indices = multi_utils.integer_to_indices(input_slice, data_shape) | |||
| result = _tensor_indices_number(data, data_shape, input_slice, indices, value) | |||
| return result | |||
| def _tensor_indices_number(data, data_shape, index, indices, value): | |||
| """Assigns a scalar value to the tensor.""" | |||
| data_size = F.size(data) | |||
| data_dtype = F.dtype(data) | |||
| indices_size = F.size(indices) | |||
| indices_size = mult_util.check_indices(indices_size, index) | |||
| update = F.fill(mstype.int32, (indices_size,), 1) | |||
| condition_1d = F.scatter_nd(indices, update, (data_size,)) | |||
| condition = F.reshape(condition_1d, data_shape) | |||
| condition = F.cast(condition, mstype.bool_) | |||
| value_fill = F.fill(data_dtype, (indices_size,), value) | |||
| value_1d = F.scatter_nd(indices, value_fill, (data_size,)) | |||
| u = F.reshape(value_1d, data_shape) | |||
| return F.select(condition, u, data) | |||
| @setitem.register("Tensor", "Number", "Number") | |||
| def _tensor_setitem_with_int_v1(data, index, value): | |||
| """Syntax: A[1] = 3""" | |||
| data_shape = F.shape(data) | |||
| indices = mult_util.integer_to_indices(index, data_shape) | |||
| indices = multi_utils.integer_to_indices(index, data_shape) | |||
| return _tensor_indices_number(data, data_shape, index, indices, value) | |||
| @@ -358,7 +355,7 @@ def _tensor_setitem_with_int_v1(data, index, value): | |||
| def _tensor_setitem_with_int_v2(data, index, value): | |||
| """Syntax: A[1] = Tensor""" | |||
| data_shape = F.shape(data) | |||
| indices = mult_util.integer_to_indices(index, data_shape) | |||
| indices = multi_utils.integer_to_indices(index, data_shape) | |||
| return _tensor_indices_tensor(data, data_shape, index, indices, value) | |||
| @@ -379,7 +376,7 @@ def _tensor_setitem_with_ellipsis_v2(data, index, value): | |||
| data_size = F.size(data) | |||
| value_shape = F.shape(value) | |||
| value_size = F.size(value) | |||
| check_result = mult_util.check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size) | |||
| check_result = multi_utils.check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size) | |||
| if check_result: | |||
| if data_size == value_size: | |||
| result = F.reshape(value, data_shape) | |||
| @@ -389,3 +386,108 @@ def _tensor_setitem_with_ellipsis_v2(data, index, value): | |||
| param2 = F.cast(value, data_dtype) | |||
| result = F.tensor_mul(param1, param2) | |||
| return result | |||
| def _tensor_assgin_tensor(data, input_slice, value): | |||
| """Assigns a tensor value to the tensor by slice.""" | |||
| result = None | |||
| check_result = multi_utils.check_tensor_setitem_index(input_slice) | |||
| if check_result: | |||
| data_shape = F.shape(data) | |||
| indices = multi_utils.slice2indices(input_slice, data_shape) | |||
| is_tuple_int = multi_utils.tuple_element_is_int(input_slice) | |||
| if is_tuple_int: | |||
| indices = multi_utils.integer_to_indices(input_slice, data_shape) | |||
| result = _tensor_indices_tensor(data, data_shape, input_slice, indices, value) | |||
| return result | |||
| def _tensor_indices_tensor(data, data_shape, index, indices, value): | |||
| """Assigns a tensor value to the tensor.""" | |||
| data_size = F.size(data) | |||
| data_dtype = F.dtype(data) | |||
| indices_size = F.size(indices) | |||
| indices_size = multi_utils.check_indices(indices_size, index) | |||
| update = F.fill(mstype.int32, (indices_size,), 1) | |||
| condition_1d = F.scatter_nd(indices, update, (data_size,)) | |||
| condition = F.reshape(condition_1d, data_shape) | |||
| condition = F.cast(condition, mstype.bool_) | |||
| value_fill = None | |||
| value_size = F.size(value) | |||
| value_size = multi_utils.check_indices_value_size(indices_size, value_size) | |||
| if value_size == 1: | |||
| value_fill = F.fill(data_dtype, (indices_size,), 1) | |||
| value = F.cast(value, data_dtype) | |||
| value_fill = F.tensor_mul(value_fill, value) | |||
| elif value_size > 1: | |||
| value_fill = F.reshape(value, (indices_size,)) | |||
| value_1d = F.scatter_nd(indices, value_fill, (data_size,)) | |||
| u = F.reshape(value_1d, data_shape) | |||
| return F.select(condition, u, data) | |||
| def _tensor_indices_number(data, data_shape, index, indices, value): | |||
| """Assigns a scalar value to the tensor.""" | |||
| data_size = F.size(data) | |||
| data_dtype = F.dtype(data) | |||
| indices_size = F.size(indices) | |||
| indices_size = multi_utils.check_indices(indices_size, index) | |||
| update = F.fill(mstype.int32, (indices_size,), 1) | |||
| condition_1d = F.scatter_nd(indices, update, (data_size,)) | |||
| condition = F.reshape(condition_1d, data_shape) | |||
| condition = F.cast(condition, mstype.bool_) | |||
| value_fill = F.fill(data_dtype, (indices_size,), value) | |||
| value_1d = F.scatter_nd(indices, value_fill, (data_size,)) | |||
| u = F.reshape(value_1d, data_shape) | |||
| return F.select(condition, u, data) | |||
| def _tensor_setitem_by_tensor_with_tuple(data, index, value): | |||
| """Set a tensor item by a tensor with a tuple.""" | |||
| updates = multi_utils.generate_updates_from_tuple(data, index, value, | |||
| multi_utils.SET_ITEM_BY_ONE_TENSOR) | |||
| result = F.scatter_update(data, index, updates) | |||
| return result | |||
| def _tensor_setitem_by_int_tensor_with_scalar(data, index, value): | |||
| """Set a tensor item by a int tensor with a scalar.""" | |||
| updates = multi_utils.generate_updates_from_scalar(data, index, value, | |||
| multi_utils.SET_ITEM_BY_ONE_TENSOR) | |||
| return F.scatter_update(data, index, updates) | |||
| def _tensor_setitem_by_bool_tensor_with_scalar(data, index, value): | |||
| """Set a tensor item by a bool tensor with a scalar.""" | |||
| index_shape = F.shape(index) | |||
| shape = F.shape(data) | |||
| shape = multi_utils.check_equal( | |||
| shape, index_shape, "The tensor(shape={}) and tensor index(shape={}) should be the same shape.") | |||
| dtype = F.dtype(data) | |||
| u = F.fill(dtype, shape, value) | |||
| return F.select(index, u, data) | |||
| def _tensor_setitem_by_int_tensor_with_tensor(data, index, value): | |||
| """Set a tensor item by a int tensor with a tensor.""" | |||
| updates = multi_utils.generate_updates_from_tensor(data, index, value, | |||
| multi_utils.SET_ITEM_BY_ONE_TENSOR) | |||
| return F.scatter_update(data, index, updates) | |||
| def _tensor_setitem_by_bool_tensor_with_tensor(data, index, value): | |||
| """Set a tensor item by a bool tensor with a tensor.""" | |||
| index_shape = F.shape(index) | |||
| data_shape = F.shape(data) | |||
| data_shape = multi_utils.check_equal(data_shape, index_shape, | |||
| "The tensor(shape={}) and tensor index(shape={}) should be the same shape.") | |||
| size = F.size(value) | |||
| size = multi_utils.check_equal(1, size, | |||
| "When assign value is a tensor, its size should be {}, but current size is {}.") | |||
| dtype = F.dtype(data) | |||
| u_cast = F.cast(value, dtype) | |||
| one_data = F.ones_like(data) | |||
| u = F.tensor_mul(one_data, u_cast) | |||
| result = F.select(index, u, data) | |||
| return result | |||
| @@ -31,6 +31,7 @@ dtype = P.DType() | |||
| issubclass_ = P.IsSubClass() | |||
| isinstance_ = P.IsInstance() | |||
| fill = P.Fill() | |||
| tile = P.Tile() | |||
| select = P.Select() | |||
| size = P.Size() | |||
| ones_like = P.OnesLike() | |||
| @@ -70,6 +71,12 @@ scalar_cast = P.ScalarCast() | |||
| print_ = P.Print() | |||
| expand_dims = P.ExpandDims() | |||
| scatter_nd = P.ScatterNd() | |||
| gather = P.GatherV2() | |||
| gather_nd = P.GatherNd() | |||
| scatter_update = P.ScatterUpdate() | |||
| scatter_nd_update = P.ScatterNdUpdate() | |||
| pack = P.Pack() | |||
| tuple_setitem = Primitive('tuple_setitem') | |||
| tuple_getitem = Primitive('tuple_getitem') | |||
| @@ -24,7 +24,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, | |||
| Fill, GatherNd, GatherV2, InvertPermutation, | |||
| IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike, | |||
| Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, | |||
| SameTypeShape, ScatterMax, | |||
| SameTypeShape, ScatterMax, ScatterUpdate, | |||
| ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select, | |||
| Shape, Size, Slice, Split, | |||
| Squeeze, StridedSlice, Tile, | |||
| @@ -193,6 +193,7 @@ __all__ = [ | |||
| 'Pad', | |||
| 'MirrorPad', | |||
| 'GatherNd', | |||
| 'ScatterUpdate', | |||
| 'ScatterNdUpdate', | |||
| 'Floor', | |||
| 'NMSWithMask', | |||
| @@ -19,7 +19,7 @@ from ..._c_expression import signature_rw as sig_rw | |||
| from ..._c_expression import signature_kind as sig_kind | |||
| from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register | |||
| from ..._checkparam import Validator as validator, Rel | |||
| from .._utils import _get_concat_offset | |||
| from .._utils import get_concat_offset | |||
| from ...common import dtype as mstype | |||
| @@ -136,7 +136,7 @@ class ConcatOffset(PrimitiveWithInfer): | |||
| axis = self.axis | |||
| x_shp = input_x['shape'] | |||
| x_type = input_x['dtype'] | |||
| offset, _, axis = _get_concat_offset(x_shp, x_type, axis, self.name) | |||
| offset, _, axis = get_concat_offset(x_shp, x_type, axis, self.name) | |||
| self.add_prim_attr('T', x_type[0].element_type()) | |||
| offset_values = [] | |||
| for i in range(len(x_shp)): | |||
| @@ -24,16 +24,15 @@ import itertools | |||
| import numbers | |||
| import numpy as np | |||
| from ..._c_expression import signature_rw as sig_rw | |||
| from ..._c_expression import signature_kind as sig_kind | |||
| from ..._checkparam import Validator as validator | |||
| from ..._checkparam import Rel | |||
| from ...common import dtype as mstype | |||
| from ...common.tensor import Tensor | |||
| from ..operations.math_ops import _infer_shape_reduce | |||
| from .._utils import _get_concat_offset | |||
| from .._utils import get_concat_offset | |||
| from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register | |||
| def _check_infer_attr_reduce(axis, keep_dims, prim_name): | |||
| validator.check_value_type('keep_dims', keep_dims, [bool], prim_name) | |||
| validator.check_value_type('axis', axis, [int, tuple], prim_name) | |||
| @@ -931,7 +930,7 @@ class InvertPermutation(PrimitiveWithInfer): | |||
| z = [x_value[i] for i in range(len(x_value))] | |||
| z.sort() | |||
| y = [None]*len(x_value) | |||
| y = [None] * len(x_value) | |||
| for i, value in enumerate(x_value): | |||
| validator.check_value_type("input[%d]" % i, value, [int], self.name) | |||
| validator.check(f'value', z[i], f'index', i, Rel.EQ, self.name) | |||
| @@ -1111,6 +1110,7 @@ class ArgMinWithValue(PrimitiveWithInfer): | |||
| >>> input_x = Tensor(np.random.rand(5)) | |||
| >>> index, output = P.ArgMinWithValue()(input_x) | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, axis=0, keep_dims=False): | |||
| """init ArgMinWithValue""" | |||
| @@ -1352,7 +1352,7 @@ class Concat(PrimitiveWithInfer): | |||
| axis = self.axis | |||
| x_shp = input_x['shape'] | |||
| x_type = input_x['dtype'] | |||
| _, all_shp, _ = _get_concat_offset(x_shp, x_type, axis, self.name) | |||
| _, all_shp, _ = get_concat_offset(x_shp, x_type, axis, self.name) | |||
| self.add_prim_attr('T', x_type[0].element_type()) | |||
| self.add_prim_attr('inputNums', len(x_shp)) | |||
| ret_shp = x_shp[0].copy() | |||
| @@ -1376,15 +1376,13 @@ def _get_pack_shape(x_shape, x_type, axis, prim_name): | |||
| if axis < 0: | |||
| axis = axis + rank_base + 1 | |||
| for i in range(1, N): | |||
| v = x_shape[i] | |||
| validator.check('len of x_shape[%d]' % i, len(v), 'len of rank_base', rank_base, Rel.EQ, prim_name) | |||
| validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0], Rel.EQ, prim_name, TypeError) | |||
| for j in range(rank_base): | |||
| if v[j] != x_shape[0][j]: | |||
| raise ValueError(f"For \'{prim_name}\' element {i} shape in input can not pack with first element") | |||
| if x_shape[i] != x_shape[0]: | |||
| raise ValueError(f"For \'{prim_name}\' element {i} shape in input can not pack with first element") | |||
| out_shape.insert(axis, N) | |||
| return out_shape | |||
| class Pack(PrimitiveWithInfer): | |||
| r""" | |||
| Packs a list of tensors in specified axis. | |||
| @@ -1831,7 +1829,7 @@ class DiagPart(PrimitiveWithInfer): | |||
| return x_type | |||
| def infer_shape(self, x_shape): | |||
| if len(x_shape)%2 != 0 or \ | |||
| if len(x_shape) % 2 != 0 or \ | |||
| not x_shape: | |||
| raise ValueError(f"For \'{self.name}\' input rank must be non-zero and even, but got rank {len(x_shape)}, " | |||
| f"with shapes {x_shape}") | |||
| @@ -2004,6 +2002,49 @@ class GatherNd(PrimitiveWithInfer): | |||
| return x_dtype | |||
| class ScatterUpdate(PrimitiveWithInfer): | |||
| """ | |||
| Update tensor value by using input indices and value. | |||
| Using given values to update tensor value, along with the input indices. | |||
| Args: | |||
| use_locking (bool): Whether protect the assignment by a lock. Default: True. | |||
| Inputs: | |||
| - **input_x** (Parameter) - The target tensor, with data type of Parameter. | |||
| - **indices** (Tensor) - The index of input tensor. | |||
| - **update** (Tensor) - The tensor to update the input tensor, has the same type as input, | |||
| and update.shape = indices.shape + input_x.shape[1:]. | |||
| Outputs: | |||
| Tensor, has the same shape and type as `input_x`. | |||
| Examples: | |||
| >>> input_x = mindspore.Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)) | |||
| >>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32) | |||
| >>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32) | |||
| >>> op = P.ScatterNdUpdate() | |||
| >>> output = op(input_x, indices, update) | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, use_locking=True): | |||
| """Init ScatterNdUpdate""" | |||
| self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y']) | |||
| def infer_shape(self, x_shape, indices_shape, value_shape): | |||
| if indices_shape + x_shape[1:] != value_shape: | |||
| raise ValueError('Input value are not match with input indices.') | |||
| return x_shape | |||
| def infer_dtype(self, x_dtype, indices_dtype, value_dtype): | |||
| validator.check_tensor_type_same({'indices': indices_dtype}, mstype.int_type, self.name) | |||
| args = {"x": x_dtype, "value": value_dtype} | |||
| validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name) | |||
| return x_dtype | |||
| class ScatterNdUpdate(PrimitiveWithInfer): | |||
| """ | |||
| Update tensor value by using input indices and value. | |||
| @@ -2028,11 +2069,6 @@ class ScatterNdUpdate(PrimitiveWithInfer): | |||
| >>> op = P.ScatterNdUpdate() | |||
| >>> output = op(input_x, indices, update) | |||
| """ | |||
| __mindspore_signature__ = ( | |||
| ('input_x', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD), | |||
| ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD), | |||
| ('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD) | |||
| ) | |||
| @prim_attr_register | |||
| def __init__(self, use_locking=True): | |||
| @@ -2142,10 +2178,10 @@ class SpaceToDepth(PrimitiveWithInfer): | |||
| validator.check('x dimension', len(x_shape), '', 4, Rel.EQ) | |||
| out_shape = copy.deepcopy(x_shape) | |||
| for i in range(2): | |||
| if out_shape[i+2] % self.block_size != 0: | |||
| raise ValueError(f'For \'{self.name}\' input shape[{i+2}] {out_shape[i+2]} should be ' | |||
| if out_shape[i + 2] % self.block_size != 0: | |||
| raise ValueError(f'For \'{self.name}\' input shape[{i + 2}] {out_shape[i + 2]} should be ' | |||
| f'fully divided by block_size {self.block_size}') | |||
| out_shape[i+2] //= self.block_size | |||
| out_shape[i + 2] //= self.block_size | |||
| out_shape[1] *= self.block_size * self.block_size | |||
| return out_shape | |||
| @@ -2199,9 +2235,10 @@ class DepthToSpace(PrimitiveWithInfer): | |||
| validator.check('x dimension', len(x_shape), '', 4, Rel.EQ) | |||
| out_shape = copy.deepcopy(x_shape) | |||
| for i in range(2): | |||
| out_shape[i+2] *= self.block_size | |||
| out_shape[i + 2] *= self.block_size | |||
| validator.check_integer('x_shape[1] % (block_size*block_size)', x_shape[1] % (self.block_size*self.block_size), | |||
| validator.check_integer('x_shape[1] % (block_size*block_size)', | |||
| x_shape[1] % (self.block_size * self.block_size), | |||
| 0, Rel.EQ, self.name) | |||
| out_shape[1] //= self.block_size * self.block_size | |||
| return out_shape | |||
| @@ -2251,6 +2288,7 @@ class SpaceToBatch(PrimitiveWithInfer): | |||
| [[[[1.]]], [[[2.]]], [[[3.]]], [[[4.]]]] | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, block_size, paddings): | |||
| """Init SpaceToBatch""" | |||
| @@ -2271,12 +2309,12 @@ class SpaceToBatch(PrimitiveWithInfer): | |||
| validator.check_integer('rank of input_x', len(x_shape), 4, Rel.EQ, self.name) | |||
| out_shape = copy.deepcopy(x_shape) | |||
| for i in range(2): | |||
| padded = out_shape[i+2] + self.paddings[i][0] + \ | |||
| padded = out_shape[i + 2] + self.paddings[i][0] + \ | |||
| self.paddings[i][1] | |||
| if padded % self.block_size != 0: | |||
| raise ValueError(f'For \'{self.name}\' padded[{i}] {padded} should be divisible by ' | |||
| f'block_size {self.block_size}') | |||
| out_shape[i+2] = padded // self.block_size | |||
| out_shape[i + 2] = padded // self.block_size | |||
| out_shape[0] *= self.block_size * self.block_size | |||
| return out_shape | |||
| @@ -2319,6 +2357,7 @@ class BatchToSpace(PrimitiveWithInfer): | |||
| [[[[1., 2.], [3., 4.]]]] | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, block_size, crops): | |||
| """Init BatchToSpace""" | |||
| @@ -2339,10 +2378,10 @@ class BatchToSpace(PrimitiveWithInfer): | |||
| validator.check('rank of input_x', len(x_shape), '', 4) | |||
| out_shape = copy.deepcopy(x_shape) | |||
| for i in range(2): | |||
| x_block_prod = out_shape[i+2] * self.block_size | |||
| x_block_prod = out_shape[i + 2] * self.block_size | |||
| crops_sum = self.crops[i][0] + self.crops[i][1] | |||
| validator.check("x block shape prod", x_block_prod, 'crops sum', crops_sum, Rel.GT, self.name) | |||
| out_shape[i+2] = x_block_prod - crops_sum | |||
| out_shape[i + 2] = x_block_prod - crops_sum | |||
| block_size_prod = self.block_size * self.block_size | |||
| if out_shape[0] % block_size_prod != 0: | |||
| raise ValueError(f'For \'{self.name}\' input_x dimension 0 {out_shape[0]} should be divisible by ' | |||
| @@ -24,7 +24,7 @@ from ..._checkparam import Validator as validator | |||
| from ..._checkparam import Rel | |||
| from ...common import dtype as mstype | |||
| from ...common.tensor import Tensor | |||
| from .._utils import _get_broadcast_shape | |||
| from .._utils import get_broadcast_shape | |||
| from ..primitive import PrimitiveWithInfer, prim_attr_register, _run_op | |||
| @@ -75,7 +75,7 @@ class _BinaryOp(PrimitiveWithInfer): | |||
| self.init_prim_io_names(inputs=['x', 'y'], outputs=['output']) | |||
| def infer_shape(self, x_shape, y_shape): | |||
| return _get_broadcast_shape(x_shape, y_shape, self.name) | |||
| return get_broadcast_shape(x_shape, y_shape, self.name) | |||
| class _MathBinaryOp(_BinaryOp): | |||
| @@ -27,9 +27,15 @@ class IdentityEC(IExectorComponent): | |||
| def __call__(self): | |||
| result_id = self.function[keyword.id] + '-' + self.inputs[keyword.id] | |||
| group = self.function[keyword.group] + '-' + self.inputs[keyword.group] | |||
| return { | |||
| ret = { | |||
| keyword.id: result_id, | |||
| keyword.group: group, | |||
| keyword.desc_inputs: self.inputs[keyword.desc_inputs], | |||
| keyword.result: self.function[keyword.block](*self.inputs[keyword.desc_inputs]) | |||
| } | |||
| print("buxue------------------------------------------------") | |||
| print("inputs") | |||
| print(ret[keyword.desc_inputs]) | |||
| print("outputs") | |||
| print(ret[keyword.result]) | |||
| return ret | |||
| @@ -1307,7 +1307,7 @@ raise_set = [ | |||
| ('ScatterNdUpdate', { | |||
| 'block': (P.ScatterNdUpdate(), {'exception': TypeError}), | |||
| 'desc_inputs': (Tensor(np.ones((2, 3), np.float32)), | |||
| Tensor(np.ones((2, 2), np.int32)), | |||
| Tensor(np.ones((2, 2), np.float32)), | |||
| Tensor(np.ones((2,), np.float32))), | |||
| 'desc_bprop': [[2, 3]]}), | |||
| ('Pack', { | |||
| @@ -16,13 +16,14 @@ | |||
| import numpy as np | |||
| import pytest | |||
| from mindspore import Tensor | |||
| from mindspore import Tensor, Parameter | |||
| from mindspore import context | |||
| from mindspore import dtype as mstype | |||
| from mindspore.nn import Cell | |||
| from ....mindspore_test_framework.mindspore_test import mindspore_test | |||
| from ....mindspore_test_framework.pipeline.forward.compile_forward \ | |||
| import pipeline_for_compile_forward_ge_graph_for_case_by_case_config | |||
| import pipeline_for_compile_forward_ge_graph_for_case_by_case_config, \ | |||
| pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception | |||
| class NetWorkSlicePositive(Cell): | |||
| @@ -145,6 +146,160 @@ class TensorAssignWithSlice(Cell): | |||
| return z | |||
| class TensorIndexByOneTensor(Cell): | |||
| def __init__(self): | |||
| super(TensorIndexByOneTensor, self).__init__() | |||
| self.const = Tensor(np.ones((5, 4, 7, 8)), mstype.int32) | |||
| def construct(self, x, index): | |||
| ret = x[index] + self.const | |||
| return ret | |||
| class TensorIndexByTwoTensors(Cell): | |||
| def __init__(self): | |||
| super(TensorIndexByTwoTensors, self).__init__() | |||
| self.const = Tensor(np.ones((3, 4, 5, 8)), mstype.int32) | |||
| def construct(self, x, index_0, index_1): | |||
| ret = x[index_0, index_1] + self.const | |||
| return ret | |||
| class TensorIndexByThreeTensors(Cell): | |||
| def __init__(self): | |||
| super(TensorIndexByThreeTensors, self).__init__() | |||
| self.const = Tensor(np.ones((5, 3, 4, 5)), mstype.int32) | |||
| def construct(self, x, index_0, index_1, index_2): | |||
| ret = x[index_0, index_1, index_2] + self.const | |||
| return ret | |||
| class TensorSetItemByOneTensorWithNumber(Cell): | |||
| def __init__(self, value): | |||
| super(TensorSetItemByOneTensorWithNumber, self).__init__() | |||
| self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | |||
| self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x") | |||
| self.value = value | |||
| def construct(self, index): | |||
| self.param[index] = self.value | |||
| ret = self.param + self.const | |||
| return ret | |||
| class TensorSetItemByOneTensorWithTensor(Cell): | |||
| def __init__(self): | |||
| super(TensorSetItemByOneTensorWithTensor, self).__init__() | |||
| self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | |||
| self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x") | |||
| def construct(self, index, value): | |||
| self.param[index] = value | |||
| ret = self.param + self.const | |||
| return ret | |||
| class TensorSetItemByOneTensorWithTupleOfNumber(Cell): | |||
| def __init__(self, value): | |||
| super(TensorSetItemByOneTensorWithTupleOfNumber, self).__init__() | |||
| self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | |||
| self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x") | |||
| self.value = value | |||
| def construct(self, index): | |||
| self.param[index] = self.value | |||
| ret = self.param + self.const | |||
| return ret | |||
| class TensorSetItemByOneTensorWithTupleOfTensor(Cell): | |||
| def __init__(self): | |||
| super(TensorSetItemByOneTensorWithTupleOfTensor, self).__init__() | |||
| self.const = Tensor(np.ones((6, 3, 8)), mstype.float32) | |||
| self.param = Parameter(Tensor(np.arange(6*3*8).reshape((6, 3, 8)), mstype.float32), name="x") | |||
| def construct(self, index, value_0, value_1, value_2): | |||
| self.param[index] = (value_0, value_1, value_2) | |||
| ret = self.param + self.const | |||
| return ret | |||
| class TensorSetItemByTensorsWithNumber(Cell): | |||
| def __init__(self, value): | |||
| super(TensorSetItemByTensorsWithNumber, self).__init__() | |||
| self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | |||
| self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x") | |||
| self.value = value | |||
| def construct(self, index_0, index_1, index_2): | |||
| self.param[index_0, index_1, index_2] = self.value | |||
| ret = self.param + self.const | |||
| return ret | |||
| class TensorSetItemByTensorsWithTensor(Cell): | |||
| def __init__(self): | |||
| super(TensorSetItemByTensorsWithTensor, self).__init__() | |||
| self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | |||
| self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x") | |||
| def construct(self, index_0, index_1, index_2, value): | |||
| self.param[index_0, index_1, index_2] = value | |||
| ret = self.param + self.const | |||
| return ret | |||
| class TensorSetItemByTensorsWithTensorNumberError(Cell): | |||
| def __init__(self): | |||
| super(TensorSetItemByTensorsWithTensorNumberError, self).__init__() | |||
| self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | |||
| self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x") | |||
| def construct(self, index_0, index_1, index_2, index_3, value): | |||
| self.param[index_0, index_1, index_2, index_3] = value | |||
| ret = self.param + self.const | |||
| return ret | |||
| class TensorSetItemByTensorsWithTupleOfNumber(Cell): | |||
| def __init__(self, value): | |||
| super(TensorSetItemByTensorsWithTupleOfNumber, self).__init__() | |||
| self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | |||
| self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x") | |||
| self.value = value | |||
| def construct(self, index_0, index_1, index_2): | |||
| self.param[index_0, index_1, index_2] = self.value | |||
| ret = self.param + self.const | |||
| return ret | |||
| class TensorSetItemByTensorsWithTupleOfTensor(Cell): | |||
| def __init__(self): | |||
| super(TensorSetItemByTensorsWithTupleOfTensor, self).__init__() | |||
| self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | |||
| self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x") | |||
| def construct(self, index_0, index_1, index_2, value_0, value_1, value_2): | |||
| self.param[index_0, index_1, index_2] = (value_0, value_1, value_2) | |||
| ret = self.param + self.const | |||
| return ret | |||
| class TensorSetItemByTensorsWithTupleOfTensorNumberError(Cell): | |||
| def __init__(self): | |||
| super(TensorSetItemByTensorsWithTupleOfTensorNumberError, self).__init__() | |||
| self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | |||
| self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x") | |||
| def construct(self, index_0, index_1, index_2, value_0, value_1): | |||
| self.param[index_0, index_1, index_2] = (value_0, value_1) | |||
| ret = self.param + self.const | |||
| return ret | |||
| def test_tensor_assign(): | |||
| context.set_context(mode=context.GRAPH_MODE, save_graphs=True) | |||
| net = TensorAssignWithSlice() | |||
| @@ -441,15 +596,206 @@ test_cases = [ | |||
| 'block': NetWorkSliceEllipsis(), | |||
| 'desc_inputs': [Tensor(np.ones([6, 7, 8, 9], np.int32))], | |||
| }), | |||
| ('TensorIndexByOneTensor', { | |||
| 'block': TensorIndexByOneTensor(), | |||
| 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), | |||
| Tensor(np.random.randint(6, size=(5, 4)), mstype.int32)], | |||
| }), | |||
| ('TensorIndexByTwoTensors', { | |||
| 'block': TensorIndexByTwoTensors(), | |||
| 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), | |||
| Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(7, size=(4, 5)), mstype.int32)], | |||
| }), | |||
| ('TensorIndexByThreeTensors', { | |||
| 'block': TensorIndexByThreeTensors(), | |||
| 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), | |||
| Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)], | |||
| }), | |||
| ('TensorSetItemByOneTensorWithNumber', { | |||
| 'block': TensorSetItemByOneTensorWithNumber(value=0.0), | |||
| 'desc_inputs': [Tensor(np.random.randint(4, size=(5, 4)), mstype.int32)], | |||
| }), | |||
| ('TensorSetItemByOneTensorWithTensor', { | |||
| 'block': TensorSetItemByOneTensorWithTensor(), | |||
| 'desc_inputs': [Tensor(np.random.randint(3, size=(5, 4)), mstype.int32), | |||
| Tensor(np.zeros((4, 7, 8)), mstype.float32)], | |||
| }), | |||
| ('TensorSetItemByOneTensorWithTupleOfNumber', { | |||
| 'block': TensorSetItemByOneTensorWithTupleOfNumber(value=(0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7)), | |||
| 'desc_inputs': [Tensor(np.random.randint(5, size=(5, 4)), mstype.int32)], | |||
| }), | |||
| ('TensorSetItemByOneTensorWithTupleOfTensor', { | |||
| 'block': TensorSetItemByOneTensorWithTupleOfTensor(), | |||
| 'desc_inputs': [Tensor(np.random.randint(6, size=(5, 4)), mstype.int32), | |||
| Tensor(np.zeros((8,), np.float32)), | |||
| Tensor(np.ones((8,), np.float32)), | |||
| Tensor(np.ones((8,), np.float32) * 2)], | |||
| }), | |||
| ('TensorSetItemByTensorsWithNumber', { | |||
| 'block': TensorSetItemByTensorsWithNumber(value=0.0), | |||
| 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)], | |||
| }), | |||
| ('TensorSetItemByTensorsWithTensor', { | |||
| 'block': TensorSetItemByTensorsWithTensor(), | |||
| 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32), | |||
| Tensor(np.zeros((4, 5)), mstype.float32)], | |||
| }), | |||
| ('TensorSetItemByTensorsWithTupleOfNumber', { | |||
| 'block': TensorSetItemByTensorsWithTupleOfNumber(value=(0.0, 1.1, 2.2, 3.3, 4.4)), | |||
| 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)], | |||
| }), | |||
| ('TensorSetItemByTensorsWithTupleOfTensor', { | |||
| 'block': TensorSetItemByTensorsWithTupleOfTensor(), | |||
| 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32), | |||
| Tensor(np.zeros((4, 5)), mstype.float32), | |||
| Tensor(np.ones((4, 5)), mstype.float32), | |||
| Tensor(np.ones((4, 5)) * 2, mstype.float32)], | |||
| }) | |||
| ] | |||
| raise_error_set = [ | |||
| ('TensorIndexByOneTensorDtypeError', { | |||
| 'block': (TensorIndexByOneTensor(), {'exception': TypeError}), | |||
| 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), | |||
| Tensor(np.random.randint(6, size=(5, 4)), mstype.int8)], | |||
| }), | |||
| ('TensorIndexByTwoTensorsShapeError', { | |||
| 'block': (TensorIndexByTwoTensors(), {'exception': ValueError}), | |||
| 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), | |||
| Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(7, size=(2, 3, 5)), mstype.int32)], | |||
| }), | |||
| ('TensorIndexByTwoTensorsDtypeError', { | |||
| 'block': (TensorIndexByTwoTensors(), {'exception': TypeError}), | |||
| 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), | |||
| Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(7, size=(4, 5)), mstype.float32)], | |||
| }), | |||
| ('TensorIndexByThreeTensorsShapeError', { | |||
| 'block': (TensorIndexByThreeTensors(), {'exception': ValueError}), | |||
| 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), | |||
| Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(8, size=(5, 2, 4, 5)), mstype.int32)], | |||
| }), | |||
| ('TensorIndexByThreeTensorsDtypeError', { | |||
| 'block': (TensorIndexByThreeTensors(), {'exception': TypeError}), | |||
| 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), | |||
| Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int64), | |||
| Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)], | |||
| }), | |||
| ('TensorSetItemByOneTensorWithNumberTypeError', { | |||
| 'block': (TensorSetItemByOneTensorWithNumber(value=0), {'exception': TypeError}), | |||
| 'desc_inputs': [Tensor(np.random.randint(4, size=(5, 4)), mstype.int32)], | |||
| }), | |||
| ('TensorSetItemByOneTensorWithTensorShapeError', { | |||
| 'block': (TensorSetItemByOneTensorWithTensor(), {'exception': ValueError}), | |||
| 'desc_inputs': [Tensor(np.random.randint(3, size=(5, 4)), mstype.int32), | |||
| Tensor(np.zeros((6, 7, 8)), mstype.float32)], | |||
| }), | |||
| ('TensorSetItemByOneTensorWithTensorDtypeError', { | |||
| 'block': (TensorSetItemByOneTensorWithTensor(), {'exception': TypeError}), | |||
| 'desc_inputs': [Tensor(np.random.randint(3, size=(5, 4)), mstype.int32), | |||
| Tensor(np.zeros((6, 7, 8)), mstype.int32)], | |||
| }), | |||
| ('TensorSetItemByOneTensorWithTupleOfNumberTypeError', { | |||
| 'block': (TensorSetItemByOneTensorWithTupleOfNumber(value=(0, 1, 2, 3, 4, 5, 6, 7)), {'exception': TypeError}), | |||
| 'desc_inputs': [Tensor(np.random.randint(5, size=(5, 4)), mstype.int32)], | |||
| }), | |||
| ('TensorSetItemByOneTensorWithTupleOfNumberNumberError', { | |||
| 'block': (TensorSetItemByOneTensorWithTupleOfNumber(value=(0.0, 1.1, 2.2)), {'exception': ValueError}), | |||
| 'desc_inputs': [Tensor(np.random.randint(5, size=(5, 4)), mstype.int32)], | |||
| }), | |||
| ('TensorSetItemByOneTensorWithTupleOfTensorDtyeError', { | |||
| 'block': (TensorSetItemByOneTensorWithTupleOfTensor(), {'exception': TypeError}), | |||
| 'desc_inputs': [Tensor(np.random.randint(6, size=(5, 4)), mstype.int32), | |||
| Tensor(np.zeros((8,), np.int32)), | |||
| Tensor(np.ones((8,), np.int32)), | |||
| Tensor(np.ones((8,), np.float32) * 2)], | |||
| }), | |||
| ('TensorSetItemByTensorsWithNumberTypeError', { | |||
| 'block': (TensorSetItemByTensorsWithNumber(value=0), {'exception': TypeError}), | |||
| 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)], | |||
| }), | |||
| ('TensorSetItemByTensorsWithTensorShapeError', { | |||
| 'block': (TensorSetItemByTensorsWithTensor(), {'exception': ValueError}), | |||
| 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32), | |||
| Tensor(np.zeros((2, 5)), mstype.float32)], | |||
| }), | |||
| ('TensorSetItemByTensorsWithTensorTypeError', { | |||
| 'block': (TensorSetItemByTensorsWithTensor(), {'exception': TypeError}), | |||
| 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32), | |||
| Tensor(np.zeros((4, 5)), mstype.int32)], | |||
| }), | |||
| ('TensorSetItemByTensorsWithTensorNumberError', { | |||
| 'block': (TensorSetItemByTensorsWithTensorNumberError(), {'exception': IndexError}), | |||
| 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(8, size=(1, 3, 4, 5)), mstype.int32), | |||
| Tensor(np.zeros((2, 5)), mstype.float32)], | |||
| }), | |||
| ('TensorSetItemByTensorsWithTupleOfNumberTypeError', { | |||
| 'block': (TensorSetItemByTensorsWithTupleOfNumber(value=(0, 1, 2, 3, 4)), {'exception': TypeError}), | |||
| 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)], | |||
| }), | |||
| ('TensorSetItemByTensorsWithTupleOfNumberNumberError', { | |||
| 'block': (TensorSetItemByTensorsWithTupleOfNumber(value=(0.0, 1.0, 2.0, 3.0)), {'exception': ValueError}), | |||
| 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)], | |||
| }), | |||
| ('TensorSetItemByTensorsWithTupleOfTensorNumberError', { | |||
| 'block': (TensorSetItemByTensorsWithTupleOfTensorNumberError(), {'exception': ValueError}), | |||
| 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32), | |||
| Tensor(np.zeros((4, 5)), mstype.float32), | |||
| Tensor(np.ones((4, 5)), mstype.float32)], | |||
| }), | |||
| ('TensorSetItemByTensorsWithTupleOfTensorTypeError', { | |||
| 'block': (TensorSetItemByTensorsWithTupleOfTensor(), {'exception': TypeError}), | |||
| 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), | |||
| Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32), | |||
| Tensor(np.zeros((4, 5)), mstype.float32), | |||
| Tensor(np.ones((4, 5)), mstype.int32), | |||
| Tensor(np.ones((4, 5)) * 2, mstype.int32)], | |||
| }) | |||
| ] | |||
| @mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config) | |||
| def test_compile(): | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| def test_exec(): | |||
| context.set_context(mode=context.GRAPH_MODE, save_graphs=True) | |||
| return test_cases | |||
| @mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception) | |||
| def test_check_exception(): | |||
| return raise_error_set | |||
| def test_tensor_slice_reduce_out_of_bounds_neg(): | |||
| class NetWork(Cell): | |||
| def __init__(self): | |||
| @@ -26,7 +26,7 @@ from mindspore.ops import functional as F | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops._grad.grad_base import bprop_getters | |||
| from mindspore.ops._grad.grad_math_ops import binop_grad_common | |||
| from mindspore.ops._utils import _get_broadcast_shape | |||
| from mindspore.ops._utils import get_broadcast_shape | |||
| from mindspore.ops.primitive import PrimitiveWithInfer, prim_attr_register | |||
| from mindspore.train.loss_scale_manager import DynamicLossScaleManager | |||
| @@ -54,7 +54,7 @@ class MockSub(PrimitiveWithInfer): | |||
| self.init_prim_io_names(inputs=['x', 'y'], outputs=['output']) | |||
| def infer_shape(self, x_shape, y_shape): | |||
| return _get_broadcast_shape(x_shape, y_shape) | |||
| return get_broadcast_shape(x_shape, y_shape) | |||
| def infer_dtype(self, x_dtype, y_dtype): | |||
| return x_dtype | |||