Merge pull request !1133 from zhangbuxue/support_tensor_get_value_by_tensor_indextags/v0.3.0-alpha
| @@ -1172,6 +1172,12 @@ int GenerateStridedSliceParametersFromNumber(const AbstractScalarPtr &scalar, co | |||||
| return 1; | return 1; | ||||
| } | } | ||||
| FuncGraphPtr ExpandADim(const FuncGraphPtr &ret_graph, const AnfNodePtr &tensor_node) { | |||||
| auto PrimExpandDims = GetPythonOps("expand_dims", "mindspore.ops.functional"); | |||||
| ret_graph->set_output(NewCNode({NewValueNode(PrimExpandDims), tensor_node, NewValueNode(0)}, ret_graph)); | |||||
| return ret_graph; | |||||
| } | |||||
| FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { | FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { | ||||
| // slice a tensor | // slice a tensor | ||||
| // args: tensor, slice or slice tuple | // args: tensor, slice or slice tuple | ||||
| @@ -1229,12 +1235,6 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec | |||||
| return ret_graph; | return ret_graph; | ||||
| } | } | ||||
| FuncGraphPtr TensorSlice::ExpandADim(const FuncGraphPtr &ret_graph, const AnfNodePtr &tensor_node) const { | |||||
| auto PrimExpandDims = GetPythonOps("expand_dims", "mindspore.ops.functional"); | |||||
| ret_graph->set_output(NewCNode({NewValueNode(PrimExpandDims), tensor_node, NewValueNode(0)}, ret_graph)); | |||||
| return ret_graph; | |||||
| } | |||||
| FuncGraphPtr TupleGetItemTensor::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { | FuncGraphPtr TupleGetItemTensor::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { | ||||
| // select indexed item | // select indexed item | ||||
| // args: tuple of items, index | // args: tuple of items, index | ||||
| @@ -206,8 +206,6 @@ class TensorSlice : public MetaFuncGraph { | |||||
| MS_DECLARE_PARENT(TensorSlice, MetaFuncGraph) | MS_DECLARE_PARENT(TensorSlice, MetaFuncGraph) | ||||
| FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; | FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; | ||||
| friend bool operator==(const TensorSlice &lhs, const TensorSlice &rhs) { return lhs.name_ == rhs.name_; } | friend bool operator==(const TensorSlice &lhs, const TensorSlice &rhs) { return lhs.name_ == rhs.name_; } | ||||
| FuncGraphPtr ExpandADim(const FuncGraphPtr &ret_graph, const AnfNodePtr &tensor_node) const; | |||||
| }; | }; | ||||
| using TensorSlicePtr = std::shared_ptr<TensorSlice>; | using TensorSlicePtr = std::shared_ptr<TensorSlice>; | ||||
| @@ -101,6 +101,7 @@ const char kNameReLU6[] = "ReLU6"; | |||||
| const char kNameReLU6Grad[] = "ReLU6Grad"; | const char kNameReLU6Grad[] = "ReLU6Grad"; | ||||
| const char kNameElu[] = "Elu"; | const char kNameElu[] = "Elu"; | ||||
| const char kNameEluGrad[] = "EluGrad"; | const char kNameEluGrad[] = "EluGrad"; | ||||
| const char kNameScatterUpdate[] = "ScatterUpdate"; | |||||
| const char kNameScatterNdUpdate[] = "ScatterNdUpdate"; | const char kNameScatterNdUpdate[] = "ScatterNdUpdate"; | ||||
| const char kNameScatterMax[] = "ScatterMax"; | const char kNameScatterMax[] = "ScatterMax"; | ||||
| const char kNameNMSWithMask[] = "NMSWithMask"; | const char kNameNMSWithMask[] = "NMSWithMask"; | ||||
| @@ -256,6 +257,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma | |||||
| {string(kNameResizeBilinear), ADPT_DESC(ResizeBilinearV2D)}, | {string(kNameResizeBilinear), ADPT_DESC(ResizeBilinearV2D)}, | ||||
| {string(kNameZerosLike), ADPT_DESC(ZerosLike)}, | {string(kNameZerosLike), ADPT_DESC(ZerosLike)}, | ||||
| {string(kNameOnesLike), ADPT_DESC(OnesLike)}, | {string(kNameOnesLike), ADPT_DESC(OnesLike)}, | ||||
| {string(kNameScatterUpdate), ADPT_DESC(ScatterUpdate)}, | |||||
| {string(kNameScatterNdUpdate), ADPT_DESC(ScatterNdUpdate)}, | {string(kNameScatterNdUpdate), ADPT_DESC(ScatterNdUpdate)}, | ||||
| {string(kNameScatterMax), ADPT_DESC(ScatterMax)}, | {string(kNameScatterMax), ADPT_DESC(ScatterMax)}, | ||||
| {string(kNameNMSWithMask), ADPT_DESC(NMSWithMask)}, | {string(kNameNMSWithMask), ADPT_DESC(NMSWithMask)}, | ||||
| @@ -515,6 +515,11 @@ INPUT_MAP(Unpack) = {{1, INPUT_DESC(x)}}; | |||||
| ATTR_MAP(Unpack) = {{"axis", ATTR_DESC(axis, AnyTraits<int>())}, {"num", ATTR_DESC(num, AnyTraits<int>())}}; | ATTR_MAP(Unpack) = {{"axis", ATTR_DESC(axis, AnyTraits<int>())}, {"num", ATTR_DESC(num, AnyTraits<int>())}}; | ||||
| DYN_OUTPUT_MAP(Unpack) = {{0, DYN_OUTPUT_DESC(y)}}; | DYN_OUTPUT_MAP(Unpack) = {{0, DYN_OUTPUT_DESC(y)}}; | ||||
| // ScatterUpdate | |||||
| INPUT_MAP(ScatterUpdate) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(updates)}}; | |||||
| ATTR_MAP(ScatterUpdate) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}}; | |||||
| OUTPUT_MAP(ScatterUpdate) = {{0, OUTPUT_DESC(var)}}; | |||||
| // ScatterNdUpdate | // ScatterNdUpdate | ||||
| INPUT_MAP(ScatterNdUpdate) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(updates)}}; | INPUT_MAP(ScatterNdUpdate) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(updates)}}; | ||||
| ATTR_MAP(ScatterNdUpdate) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}}; | ATTR_MAP(ScatterNdUpdate) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}}; | ||||
| @@ -132,6 +132,8 @@ DECLARE_OP_ADAPTER(ZerosLike) | |||||
| DECLARE_OP_USE_OUTPUT(ZerosLike) | DECLARE_OP_USE_OUTPUT(ZerosLike) | ||||
| DECLARE_OP_ADAPTER(OnesLike) | DECLARE_OP_ADAPTER(OnesLike) | ||||
| DECLARE_OP_USE_OUTPUT(OnesLike) | DECLARE_OP_USE_OUTPUT(OnesLike) | ||||
| DECLARE_OP_ADAPTER(ScatterUpdate) | |||||
| DECLARE_OP_USE_OUTPUT(ScatterUpdate) | |||||
| DECLARE_OP_ADAPTER(ScatterNdUpdate) | DECLARE_OP_ADAPTER(ScatterNdUpdate) | ||||
| DECLARE_OP_USE_OUTPUT(ScatterNdUpdate) | DECLARE_OP_USE_OUTPUT(ScatterNdUpdate) | ||||
| DECLARE_OP_ADAPTER(ScatterMax) | DECLARE_OP_ADAPTER(ScatterMax) | ||||
| @@ -179,14 +179,15 @@ from .bounding_box_encode import _bounding_box_encode_tbe | |||||
| from .check_valid import _check_valid_tbe | from .check_valid import _check_valid_tbe | ||||
| from .iou import _iou_tbe | from .iou import _iou_tbe | ||||
| from .arg_max import _arg_max_tbe | from .arg_max import _arg_max_tbe | ||||
| from .nms_with_mask import nms_with_mask_op_info | |||||
| from .random_choice_with_mask import random_choice_with_mask_op_info | |||||
| from .sgd import sgd_op_info | |||||
| from .lars_update import lars_update_op_info | |||||
| from .nms_with_mask import _nms_with_mask_tbe | |||||
| from .random_choice_with_mask import _random_choice_with_mask_tbe | |||||
| from .sgd import _sgd_tbe | |||||
| from .lars_update import _lars_update_tbe | |||||
| from .bn_training_update_v2 import _bn_training_update_v2_tbe | from .bn_training_update_v2 import _bn_training_update_v2_tbe | ||||
| from .square_sum_all import square_sum_all_op_info | |||||
| from .square_sum_all import _square_sum_all_tbe | |||||
| from .pack import _pack_tbe | from .pack import _pack_tbe | ||||
| from .unpack import _unpack_tbe | from .unpack import _unpack_tbe | ||||
| from .scatter_update import _scatter_update_tbe | |||||
| from .prelu import _prelu_tbe | from .prelu import _prelu_tbe | ||||
| from .prelu_grad import _prelu_grad_tbe | from .prelu_grad import _prelu_grad_tbe | ||||
| from .binary_cross_entropy import _binary_cross_entropy_tbe | from .binary_cross_entropy import _binary_cross_entropy_tbe | ||||
| @@ -0,0 +1,42 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ScatterUpdate op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||||
| scatter_update_op_info = TBERegOp("ScatterUpdate") \ | |||||
| .fusion_type("ELEMWISE") \ | |||||
| .async_flag(False) \ | |||||
| .binfile_name("scatter_update.so") \ | |||||
| .compute_cost(10) \ | |||||
| .kernel_name("scatter_update") \ | |||||
| .partial_flag(True) \ | |||||
| .attr("use_locking", "optional", "bool", "all") \ | |||||
| .input(0, "var", False, "required", "all") \ | |||||
| .input(1, "indices", False, "required", "all") \ | |||||
| .input(1, "updates", False, "required", "all") \ | |||||
| .output(0, "var", False, "required", "all") \ | |||||
| .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \ | |||||
| .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default,) \ | |||||
| .dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.BOOL_Default, DataType.BOOL_Default) \ | |||||
| .get_op_info() | |||||
| @op_info_register(scatter_update_op_info) | |||||
| def _scatter_update_tbe(): | |||||
| """ScatterUpdate TBE register""" | |||||
| return | |||||
| @@ -14,6 +14,6 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| """ops utils.""" | """ops utils.""" | ||||
| from .utils import _get_broadcast_shape, _get_concat_offset | |||||
| from .utils import get_broadcast_shape, get_concat_offset | |||||
| __all__ = ['_get_broadcast_shape', '_get_concat_offset'] | |||||
| __all__ = ['get_broadcast_shape', 'get_concat_offset'] | |||||
| @@ -19,7 +19,8 @@ from ..._checkparam import Validator as validator | |||||
| from ..._checkparam import Rel | from ..._checkparam import Rel | ||||
| from ...common import dtype as mstype | from ...common import dtype as mstype | ||||
| def _get_broadcast_shape(x_shape, y_shape, prim_name): | |||||
| def get_broadcast_shape(x_shape, y_shape, prim_name): | |||||
| """ | """ | ||||
| Doing broadcast between tensor x and tensor y. | Doing broadcast between tensor x and tensor y. | ||||
| @@ -37,7 +38,7 @@ def _get_broadcast_shape(x_shape, y_shape, prim_name): | |||||
| Examples: | Examples: | ||||
| >>> x_shape = [1, 2, 3] | >>> x_shape = [1, 2, 3] | ||||
| >>> y_shape = [1, 2] | >>> y_shape = [1, 2] | ||||
| >>> broadcast_shape = _get_broadcast_shape(x_shape, y_shape) | |||||
| >>> broadcast_shape = get_broadcast_shape(x_shape, y_shape) | |||||
| """ | """ | ||||
| if x_shape == y_shape: | if x_shape == y_shape: | ||||
| return x_shape | return x_shape | ||||
| @@ -54,15 +55,14 @@ def _get_broadcast_shape(x_shape, y_shape, prim_name): | |||||
| elif x_shape[i] == y_shape[i]: | elif x_shape[i] == y_shape[i]: | ||||
| broadcast_shape_back.append(x_shape[i]) | broadcast_shape_back.append(x_shape[i]) | ||||
| else: | else: | ||||
| raise ValueError("For '{}' the x_shape {} and y_shape {} can not broadcast.".format( | |||||
| prim_name, x_shape, y_shape)) | |||||
| raise ValueError(f"For '{prim_name}', the x_shape {x_shape} and y_shape {y_shape} can not broadcast.") | |||||
| broadcast_shape_front = y_shape[0: y_len - length] if length == x_len else x_shape[0: x_len - length] | broadcast_shape_front = y_shape[0: y_len - length] if length == x_len else x_shape[0: x_len - length] | ||||
| broadcast_shape = broadcast_shape_front + broadcast_shape_back | |||||
| broadcast_shape = list(broadcast_shape_front) + broadcast_shape_back | |||||
| return broadcast_shape | return broadcast_shape | ||||
| def _get_concat_offset(x_shp, x_type, axis, prim_name): | |||||
| def get_concat_offset(x_shp, x_type, axis, prim_name): | |||||
| """for concat and concatoffset check args and compute offset""" | """for concat and concatoffset check args and compute offset""" | ||||
| validator.check_value_type("shape", x_shp, [tuple], prim_name) | validator.check_value_type("shape", x_shp, [tuple], prim_name) | ||||
| validator.check_integer("input_x rank", len(x_shp), 0, Rel.GT, prim_name) | validator.check_integer("input_x rank", len(x_shp), 0, Rel.GT, prim_name) | ||||
| @@ -73,7 +73,7 @@ def _get_concat_offset(x_shp, x_type, axis, prim_name): | |||||
| if axis < 0: | if axis < 0: | ||||
| axis = axis + rank_base | axis = axis + rank_base | ||||
| all_shp = x_shp[0][axis] | all_shp = x_shp[0][axis] | ||||
| offset = [0,] | |||||
| offset = [0] | |||||
| for i in range(1, len(x_shp)): | for i in range(1, len(x_shp)): | ||||
| v = x_shp[i] | v = x_shp[i] | ||||
| validator.check('len of x_shp[%d]' % i, len(v), 'len of x_shp[0]', len(x_shp[0]), Rel.EQ, prim_name) | validator.check('len of x_shp[%d]' % i, len(v), 'len of x_shp[0]', len(x_shp[0]), Rel.EQ, prim_name) | ||||
| @@ -1,226 +0,0 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """constexpr util""" | |||||
| from functools import reduce | |||||
| import numpy as np | |||||
| from ...primitive import constexpr | |||||
| from ....common.tensor import Tensor | |||||
| from ....common import dtype as mstype | |||||
| from ...._extends.utils import Slice, Ellipsis_ | |||||
| @constexpr | |||||
| def check_equal(param1, param2, msg="{},{}"): | |||||
| """Checks whether the two parameters are equal or not.""" | |||||
| if param1 != param2: | |||||
| raise ValueError(msg.format(param1, param2)) | |||||
| return param1 | |||||
| @constexpr | |||||
| def check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size): | |||||
| """Checks the shape and size of the sensor and value.""" | |||||
| if data_shape == value_shape or data_size == value_size or value_size == 1: | |||||
| return True | |||||
| raise ValueError("The value(shape={}), can not assign to tensor(shape={}).".format(value_shape, data_shape)) | |||||
| @constexpr | |||||
| def check_tensor_setitem_index(index, element_type=None): | |||||
| """Checks tuple index type of tensor assignment.""" | |||||
| if index is None: | |||||
| raise IndexError("Tensor's index cannot be None.") | |||||
| # eg. Tensor[Slice] = u | |||||
| if isinstance(index, Slice): | |||||
| return True | |||||
| # eg. Tensor[tuple] = u | |||||
| if isinstance(index, tuple): | |||||
| if not index: | |||||
| raise IndexError("Tensor's index cannot be empty.") | |||||
| # eg. Tensor[tuple(Slice...)] = u | |||||
| if isinstance(index[0], (Slice, Ellipsis_, int)): | |||||
| return True | |||||
| raise IndexError("Index of type '{}' is not supported yet.".format(type(index[0]))) | |||||
| # eg. Tensor[Tensor[dtype=bool]] = u | |||||
| if index == mstype.tensor: | |||||
| if element_type is None or element_type != mstype.bool_: | |||||
| raise TypeError( | |||||
| "The index of tensor should be a bool type tensor. " | |||||
| "{} type is not supported yet.".format(element_type)) | |||||
| return True | |||||
| raise IndexError("Index of type '{}' is not supported yet.".format(type(index))) | |||||
| @constexpr | |||||
| def is_same_type(inst, type_): | |||||
| """ | |||||
| Checks whether an object is an instance of a target type. | |||||
| Inputs: | |||||
| inst (mindspore.dtype): Inspected type. | |||||
| type_ (mindspore.dtype): Target type. | |||||
| Outputs: | |||||
| bool, the check result. | |||||
| """ | |||||
| return inst == type_ | |||||
| def slice_expand(input_slices, shape): | |||||
| """ | |||||
| Converts slice to indices. | |||||
| Inputs: | |||||
| slices (Union[Slice, tuple[Slice]]): Slice tuple or slice. | |||||
| shape (tuple): The shape of a sensor is an integer element tuple. | |||||
| Outputs: | |||||
| tuple[list], This is expressed as (begins, ends, strides). | |||||
| """ | |||||
| begin = [] | |||||
| end = [] | |||||
| strides = [] | |||||
| index = 0 | |||||
| slices = None | |||||
| # Slice or tuple(Slice...) | |||||
| if isinstance(input_slices, Slice): | |||||
| slices = (input_slices,) | |||||
| elif isinstance(input_slices, (tuple, list)) and input_slices and isinstance(input_slices[0], (Slice, Ellipsis_)): | |||||
| is_have_ellipsis = False | |||||
| for _, element in enumerate(input_slices): | |||||
| if isinstance(element, Ellipsis_): | |||||
| is_have_ellipsis = True | |||||
| break | |||||
| if is_have_ellipsis: | |||||
| slices = ellipsis2slice(input_slices, shape) | |||||
| else: | |||||
| slices = input_slices | |||||
| else: | |||||
| raise IndexError("Tensor's index type is not supported yet.") | |||||
| for s in slices: | |||||
| start = 0 if (s.start is None) else s.start | |||||
| stop = shape[index] if (s.end is None) else s.end | |||||
| step = 1 if (s.step is None) else s.step | |||||
| begin.append(start) | |||||
| end.append(stop) | |||||
| strides.append(step) | |||||
| index += 1 | |||||
| while index < len(shape): | |||||
| begin.append(0) | |||||
| end.append(shape[index]) | |||||
| strides.append(1) | |||||
| index += 1 | |||||
| return begin, end, strides | |||||
| def ellipsis2slice(input_, shape): | |||||
| """Converts ellipsis to slice.""" | |||||
| input_slice = input_ | |||||
| result = [] | |||||
| if isinstance(input_, Ellipsis_): | |||||
| input_slice = (input_,) | |||||
| ell_count = 0 | |||||
| for _, element in enumerate(input_slice): | |||||
| if not isinstance(element, Ellipsis_): | |||||
| result.append(element) | |||||
| continue | |||||
| ell_count += 1 | |||||
| if ell_count > 1: | |||||
| raise IndexError("There cannot be more than one ellisis (...) in the index of the tensor, " | |||||
| "but it is currently {}".format(input_slice)) | |||||
| for _ in range(len(shape) - len(input_slice) + 1): | |||||
| result.append(Slice(None, None, None)) | |||||
| return tuple(result) | |||||
| @constexpr | |||||
| def slice2indices(input_slices, shape): | |||||
| """ | |||||
| Converts slice to indices. | |||||
| Inputs: | |||||
| slices (Union[Slice, tuple[Slice]]): Slice tuple or slice. | |||||
| shape (tuple): The shape of a tensor is an integer element tuple. | |||||
| Outputs: | |||||
| Tensor, the shape is (n, 1). | |||||
| """ | |||||
| begin, end, strides = slice_expand(input_slices, shape) | |||||
| np_r = [] | |||||
| for i, element in enumerate(shape): | |||||
| s = begin[i] if (begin[i] >= 0) else (element + begin[i]) | |||||
| e = end[i] if (end[i] >= 0) else (element + end[i]) | |||||
| np_r.append(np.r_[s:e:strides[i]]) | |||||
| # Reference: np.ravel_multi_index((np.ix_(np.r_[1:3:1], np.r_[0:4:1], np.r_[4:0:-1])), a.shape) | |||||
| np_ix = np.ix_(*np_r) | |||||
| ravel = np.ravel_multi_index(np_ix, shape) | |||||
| ravel = Tensor(ravel.reshape(-1, 1), dtype=mstype.int32) | |||||
| return ravel | |||||
| @constexpr | |||||
| def check_indices(indices_size, index): | |||||
| """Checks indices whether is empty.""" | |||||
| if indices_size < 1: | |||||
| raise IndexError("The tensor's index is unreasonable. index:{}".format(index)) | |||||
| return indices_size | |||||
| @constexpr | |||||
| def check_indices_value_size(indices_size, value_size): | |||||
| """Checks if the sizes are already matched.""" | |||||
| if value_size < 1: | |||||
| raise ValueError("The value assigned to tensor cannot be empty.") | |||||
| if value_size > 1: | |||||
| if value_size != indices_size: | |||||
| raise ValueError( | |||||
| "The value given to tensor does not match the index size," | |||||
| " value size:{}, indics size:{}".format(value_size, indices_size)) | |||||
| return value_size | |||||
| @constexpr | |||||
| def integer_to_indices(index, shape): | |||||
| """Converts int or tuple[int] to indices.""" | |||||
| size = reduce(lambda x, y: x * y, shape) | |||||
| range_ = np.arange(size).reshape(shape) | |||||
| value = range_[index] | |||||
| value = value.reshape(-1, 1) | |||||
| return Tensor(value, dtype=mstype.int32) | |||||
| @constexpr | |||||
| def tuple_element_is_slice(indexs): | |||||
| """Judges tuple element type.""" | |||||
| if not indexs: | |||||
| raise IndexError("Tensor's index cannot be empty.") | |||||
| if isinstance(indexs, tuple): | |||||
| for _, ele in enumerate(indexs): | |||||
| if not isinstance(ele, Slice): | |||||
| return False | |||||
| return True | |||||
| return False | |||||
| @constexpr | |||||
| def tuple_element_is_int(indexs): | |||||
| """Judges tuple element type.""" | |||||
| if not indexs: | |||||
| raise IndexError("Tensor's index cannot be empty.") | |||||
| if isinstance(indexs, tuple): | |||||
| for _, ele in enumerate(indexs): | |||||
| if not isinstance(ele, int): | |||||
| return False | |||||
| return True | |||||
| return False | |||||
| @@ -0,0 +1,487 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """constexpr util""" | |||||
| from functools import reduce | |||||
| import numpy as np | |||||
| from ...primitive import constexpr | |||||
| from ....common.tensor import Tensor | |||||
| from ....common import dtype as mstype | |||||
| from ...._extends.utils import Slice, Ellipsis_ | |||||
| from ....ops import _utils as op_utils | |||||
| from ...composite import base | |||||
| from .... import log as logger | |||||
| from ... import functional as F | |||||
| from ... import operations as P | |||||
| hyper_map = base.HyperMap() | |||||
| pack = P.Pack(axis=-1) | |||||
| ALL_TENSOR = 0 | |||||
| NO_TENSOR = 1 | |||||
| CONTAIN_TENSOR = 2 | |||||
| ALL_SCALAR = 3 | |||||
| INT_ = 0 | |||||
| BOOL_ = 1 | |||||
| UNSUPPORTED_DTYPE = 2 | |||||
| TENSOR_SETITEM = "tensor setitem" | |||||
| TENSOR_GETITEM = "tensor getitem" | |||||
| SET_ITEM_BY_ONE_TENSOR = 0 | |||||
| SET_ITEM_BY_TUPLE_OF_TENSOR = 1 | |||||
| @constexpr | |||||
| def check_equal(param1, param2, msg="{},{}"): | |||||
| """Checks whether the two parameters are equal or not.""" | |||||
| if param1 != param2: | |||||
| raise ValueError(msg.format(param1, param2)) | |||||
| return param1 | |||||
| @constexpr | |||||
| def check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size): | |||||
| """Checks the shape and size of the sensor and value.""" | |||||
| if data_shape == value_shape or data_size == value_size or value_size == 1: | |||||
| return True | |||||
| raise ValueError("The value(shape={}), can not assign to tensor(shape={}).".format(value_shape, data_shape)) | |||||
| @constexpr | |||||
| def check_tensor_setitem_index(index, element_type=None): | |||||
| """Checks tuple index type of tensor assignment.""" | |||||
| if index is None: | |||||
| raise IndexError("Tensor's index cannot be None.") | |||||
| # eg. Tensor[Slice] = u | |||||
| if isinstance(index, Slice): | |||||
| return True | |||||
| # eg. Tensor[tuple] = u | |||||
| if isinstance(index, tuple): | |||||
| if not index: | |||||
| raise IndexError("Tensor's index cannot be empty.") | |||||
| # eg. Tensor[tuple(Slice...)] = u | |||||
| if isinstance(index[0], (Slice, Ellipsis_, int)): | |||||
| return True | |||||
| raise IndexError("Index of type '{}' is not supported yet.".format(type(index[0]))) | |||||
| # eg. Tensor[Tensor[dtype=bool]] = u | |||||
| if isinstance(index, mstype.tensor_type): | |||||
| if element_type is None or element_type != mstype.bool_: | |||||
| raise TypeError( | |||||
| "The index of tensor should be a bool type tensor. " | |||||
| "{} type is not supported yet.".format(element_type)) | |||||
| return True | |||||
| raise IndexError("Index of type '{}' is not supported yet.".format(type(index))) | |||||
| @constexpr | |||||
| def is_same_type(inst, type_): | |||||
| """ | |||||
| Checks whether an object is an instance of a target type. | |||||
| Inputs: | |||||
| inst (mindspore.dtype): Inspected type. | |||||
| type_ (mindspore.dtype): Target type. | |||||
| Outputs: | |||||
| bool, the check result. | |||||
| """ | |||||
| return inst == type_ | |||||
| def slice_expand(input_slices, shape): | |||||
| """ | |||||
| Converts slice to indices. | |||||
| Inputs: | |||||
| slices (Union[Slice, tuple[Slice]]): Slice tuple or slice. | |||||
| shape (tuple): The shape of a sensor is an integer element tuple. | |||||
| Outputs: | |||||
| tuple[list], This is expressed as (begins, ends, strides). | |||||
| """ | |||||
| begin = [] | |||||
| end = [] | |||||
| strides = [] | |||||
| index = 0 | |||||
| slices = None | |||||
| # Slice or tuple(Slice...) | |||||
| if isinstance(input_slices, Slice): | |||||
| slices = (input_slices,) | |||||
| elif isinstance(input_slices, (tuple, list)) and input_slices and isinstance(input_slices[0], (Slice, Ellipsis_)): | |||||
| is_have_ellipsis = False | |||||
| for _, element in enumerate(input_slices): | |||||
| if isinstance(element, Ellipsis_): | |||||
| is_have_ellipsis = True | |||||
| break | |||||
| if is_have_ellipsis: | |||||
| slices = ellipsis2slice(input_slices, shape) | |||||
| else: | |||||
| slices = input_slices | |||||
| else: | |||||
| raise IndexError("Tensor's index type is not supported yet.") | |||||
| for s in slices: | |||||
| start = 0 if (s.start is None) else s.start | |||||
| stop = shape[index] if (s.end is None) else s.end | |||||
| step = 1 if (s.step is None) else s.step | |||||
| begin.append(start) | |||||
| end.append(stop) | |||||
| strides.append(step) | |||||
| index += 1 | |||||
| while index < len(shape): | |||||
| begin.append(0) | |||||
| end.append(shape[index]) | |||||
| strides.append(1) | |||||
| index += 1 | |||||
| return begin, end, strides | |||||
| def ellipsis2slice(input_, shape): | |||||
| """Converts ellipsis to slice.""" | |||||
| input_slice = input_ | |||||
| result = [] | |||||
| if isinstance(input_, Ellipsis_): | |||||
| input_slice = (input_,) | |||||
| ell_count = 0 | |||||
| for _, element in enumerate(input_slice): | |||||
| if not isinstance(element, Ellipsis_): | |||||
| result.append(element) | |||||
| continue | |||||
| ell_count += 1 | |||||
| if ell_count > 1: | |||||
| raise IndexError("There cannot be more than one ellisis (...) in the index of the tensor, " | |||||
| "but it is currently {}".format(input_slice)) | |||||
| for _ in range(len(shape) - len(input_slice) + 1): | |||||
| result.append(Slice(None, None, None)) | |||||
| return tuple(result) | |||||
| @constexpr | |||||
| def slice2indices(input_slices, shape): | |||||
| """ | |||||
| Converts slice to indices. | |||||
| Inputs: | |||||
| slices (Union[Slice, tuple[Slice]]): Slice tuple or slice. | |||||
| shape (tuple): The shape of a tensor is an integer element tuple. | |||||
| Outputs: | |||||
| Tensor, the shape is (n, 1). | |||||
| """ | |||||
| begin, end, strides = slice_expand(input_slices, shape) | |||||
| np_r = [] | |||||
| for i, element in enumerate(shape): | |||||
| s = begin[i] if (begin[i] >= 0) else (element + begin[i]) | |||||
| e = end[i] if (end[i] >= 0) else (element + end[i]) | |||||
| np_r.append(np.r_[s:e:strides[i]]) | |||||
| # Reference: np.ravel_multi_index((np.ix_(np.r_[1:3:1], np.r_[0:4:1], np.r_[4:0:-1])), a.shape) | |||||
| np_ix = np.ix_(*np_r) | |||||
| ravel = np.ravel_multi_index(np_ix, shape) | |||||
| ravel = Tensor(ravel.reshape(-1, 1), dtype=mstype.int32) | |||||
| return ravel | |||||
| @constexpr | |||||
| def check_indices(indices_size, index): | |||||
| """Checks indices whether is empty.""" | |||||
| if indices_size < 1: | |||||
| raise IndexError("The tensor's index is unreasonable. index:{}".format(index)) | |||||
| return indices_size | |||||
| @constexpr | |||||
| def check_indices_value_size(indices_size, value_size): | |||||
| """Checks if the sizes are already matched.""" | |||||
| if value_size < 1: | |||||
| raise ValueError("The value assigned to tensor cannot be empty.") | |||||
| if value_size > 1: | |||||
| if value_size != indices_size: | |||||
| raise ValueError( | |||||
| "The value given to tensor does not match the index size," | |||||
| " value size:{}, indics size:{}".format(value_size, indices_size)) | |||||
| return value_size | |||||
| @constexpr | |||||
| def integer_to_indices(index, shape): | |||||
| """Converts int or tuple[int] to indices.""" | |||||
| size = reduce(lambda x, y: x * y, shape) | |||||
| range_ = np.arange(size).reshape(shape) | |||||
| value = range_[index] | |||||
| value = value.reshape(-1, 1) | |||||
| return Tensor(value, dtype=mstype.int32) | |||||
| @constexpr | |||||
| def tuple_element_is_slice(indexs): | |||||
| """Judges tuple element type.""" | |||||
| if not indexs: | |||||
| raise IndexError("Tensor's index cannot be empty.") | |||||
| if isinstance(indexs, tuple): | |||||
| for _, ele in enumerate(indexs): | |||||
| if not isinstance(ele, Slice): | |||||
| return False | |||||
| return True | |||||
| return False | |||||
| @constexpr | |||||
| def tuple_element_is_int(indexs): | |||||
| """Judges tuple element type.""" | |||||
| if not indexs: | |||||
| raise IndexError("Tensor's index cannot be empty.") | |||||
| if isinstance(indexs, tuple): | |||||
| for _, ele in enumerate(indexs): | |||||
| if not isinstance(ele, int): | |||||
| return False | |||||
| return True | |||||
| return False | |||||
| @constexpr | |||||
| def tuple_elements_type(types): | |||||
| """Judges the type of all elements of the tuple.""" | |||||
| tensors_number = 0 | |||||
| for ele in types: | |||||
| if isinstance(ele, mstype.tensor_type): | |||||
| tensors_number += 1 | |||||
| if tensors_number == len(types): | |||||
| return ALL_TENSOR | |||||
| if tensors_number == 0: | |||||
| return NO_TENSOR | |||||
| return CONTAIN_TENSOR | |||||
| @constexpr | |||||
| def check_value_elements(data_dtype, types): | |||||
| """Judges the type of all elements of the tuple.""" | |||||
| tensors_number = 0 | |||||
| scalars_number = 0 | |||||
| for i, ele in enumerate(types): | |||||
| if isinstance(ele, mstype.tensor_type): | |||||
| ele_dtype = ele.element_type() | |||||
| if data_dtype == ele_dtype: | |||||
| tensors_number += 1 | |||||
| else: | |||||
| raise TypeError(f"For '{TENSOR_SETITEM}', the data type of {i}th tensor '{ele_dtype}' " | |||||
| f"in value tuple is not consistent with origin tensor data type '{data_dtype}'.") | |||||
| elif mstype.issubclass_(ele, data_dtype): | |||||
| scalars_number += 1 | |||||
| else: | |||||
| raise TypeError(f"For '{TENSOR_SETITEM}', the {i}th element type '{ele}' in " | |||||
| f"value tuple is not consistent with origin tensor data type '{data_dtype}'.") | |||||
| if tensors_number == len(types): | |||||
| return ALL_TENSOR | |||||
| if scalars_number == len(types): | |||||
| return ALL_SCALAR | |||||
| raise TypeError(f"For '{TENSOR_SETITEM}', the value does not support scalar and tensor mixing, but got {types}.") | |||||
| @constexpr | |||||
| def get_index_tensor_dtype(dtype): | |||||
| """Check a tuple of tensor data type.""" | |||||
| if dtype == mstype.int32: | |||||
| return INT_ | |||||
| if dtype == mstype.bool_: | |||||
| return BOOL_ | |||||
| raise TypeError(f"For '{TENSOR_SETITEM}', the index tensor data type '{dtype}' is not supported.") | |||||
| @constexpr | |||||
| def check_index_tensors_dtype(dtypes, op_name): | |||||
| """Check a tuple of tensor data type.""" | |||||
| if op_name == TENSOR_GETITEM: | |||||
| valid_dtypes = (mstype.int32, mstype.int64) | |||||
| elif op_name == TENSOR_SETITEM: | |||||
| valid_dtypes = (mstype.int32,) | |||||
| else: | |||||
| raise ValueError("Unsupported operation.") | |||||
| for ele in dtypes: | |||||
| if ele in valid_dtypes and ele == dtypes[0]: | |||||
| continue | |||||
| raise TypeError(f"For '{op_name}', the index tensors data type must be same, " | |||||
| f"and should be one of the following: {valid_dtypes}, but got {dtypes}.") | |||||
| return True | |||||
| @constexpr | |||||
| def check_tensor_dtype_valid(dtype, valid_dtypes): | |||||
| """Check a tensor data type.""" | |||||
| if dtype in valid_dtypes: | |||||
| return True | |||||
| raise TypeError(f"The index tensor data type must be one of " | |||||
| f"the following: {valid_dtypes}, but got {dtype}.") | |||||
| @constexpr | |||||
| def check_tensors_dtype_same(x_dtype, y_dtype, op_name): | |||||
| """Check tensors data type same.""" | |||||
| if x_dtype == y_dtype: | |||||
| return True | |||||
| raise TypeError(f"For '{op_name}', the value data type '{y_dtype}' " | |||||
| f"is not consistent with origin tensor data type {x_dtype}.") | |||||
| @constexpr | |||||
| def broadcast_shapes(shapes, op_name): | |||||
| """Broadcasts a tuple of tensor.""" | |||||
| broadcast_shape = shapes[0] | |||||
| for i, shape in enumerate(shapes): | |||||
| logger.debug(f"Broadcasts the {i}th tensor, the shape is {shape}.") | |||||
| broadcast_shape = op_utils.get_broadcast_shape(broadcast_shape, shape, op_name) | |||||
| return tuple(broadcast_shape) | |||||
| @constexpr | |||||
| def check_two_shapes_need_broadcast(shape_x, shape_y): | |||||
| """Check two shapes need broadcast.""" | |||||
| error = ValueError(f"For 'tensor setitem with tensor', the value tensor shape " | |||||
| f"{shape_y} could not broadcast the required updates shape {shape_x}.") | |||||
| if len(shape_y) > len(shape_x): | |||||
| raise error | |||||
| for i in range(-len(shape_y), 0): | |||||
| if shape_y[i] > shape_x[i]: | |||||
| raise error | |||||
| if shape_y[i] < shape_x[i] and shape_y[i] != 1: | |||||
| raise error | |||||
| if shape_y == shape_x: | |||||
| return False | |||||
| return True | |||||
| @constexpr | |||||
| def compute_multiples(origin_shape, broadcast_shape): | |||||
| """Compute multiples between broadcast_shape with origin_shape.""" | |||||
| len_gap = len(broadcast_shape) - len(origin_shape) | |||||
| return broadcast_shape[0:len_gap] + tuple(map(lambda x, y: x // y, broadcast_shape[len_gap:], origin_shape)) | |||||
| def tile(broadcast_shape, x): | |||||
| multiples = compute_multiples(F.shape(x), broadcast_shape) | |||||
| return F.tile(x, multiples) | |||||
| @constexpr | |||||
| def check_shapes_same(value_shapes, op_name): | |||||
| """Check if the shapes in the tuple are consistent.""" | |||||
| for i, shape in enumerate(value_shapes): | |||||
| if shape != value_shapes[0]: | |||||
| raise ValueError(f"For '{op_name}', the {i}th tensor shape in value tuple " | |||||
| f"is not same as the first tensor shape.") | |||||
| return True | |||||
| @constexpr | |||||
| def convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type): | |||||
| """Convert a scalar to a tensor.""" | |||||
| if op_type == SET_ITEM_BY_ONE_TENSOR: | |||||
| updates_shape = indices_shape + data_shape[1:] | |||||
| else: | |||||
| updates_shape = indices_shape[:-1] + data_shape[indices_shape[-1]:] | |||||
| if isinstance(value, mstype.dtype_to_pytype(data_dtype)): | |||||
| return Tensor(np.full(updates_shape, value), dtype=data_dtype) | |||||
| raise TypeError(f"For '{TENSOR_SETITEM}', the value type '{value.__class__.__name__}'" | |||||
| f" is not consistent with tensor data type {data_dtype}.") | |||||
| @constexpr | |||||
| def convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value, op_type): | |||||
| """Convert a tuple of scalar to a tensor.""" | |||||
| updates_shape = generate_updates_shape(data_shape, index_shape, op_type) | |||||
| if len(value) != updates_shape[-1]: | |||||
| raise ValueError(f"For '{TENSOR_SETITEM}', the number of elements : {len(value)} in the updates tuple " | |||||
| f"does not meet the requirements: {updates_shape[-1]}.") | |||||
| array = np.array(value, dtype=mstype.dtype_to_nptype(data_dtype)) | |||||
| reps = compute_multiples(updates_shape[-1:], updates_shape) | |||||
| return Tensor(np.tile(array, reps)) | |||||
| @constexpr | |||||
| def generate_updates_shape(data_shape, index_shape, op_type): | |||||
| """Generate updates shape for 'tensor setitem'.""" | |||||
| if op_type == SET_ITEM_BY_ONE_TENSOR: | |||||
| updates_shape = index_shape + data_shape[1:] | |||||
| else: | |||||
| updates_shape = index_shape[:-1] + data_shape[index_shape[-1]:] | |||||
| return updates_shape | |||||
| @constexpr | |||||
| def check_number_of_index_tensor(data_shape, tuple_len, op_name): | |||||
| """Check if the number of index tensor exceeds the dimension of the operated tensor.""" | |||||
| if tuple_len <= len(data_shape): | |||||
| return True | |||||
| raise IndexError(f"For '{op_name}', the number {tuple_len} of index tensor " | |||||
| f"is greater than the dimension {len(data_shape)} of the operated tensor.") | |||||
| def generate_indeices_from_tuple_of_tensor(data, tuple_index, op_name): | |||||
| """Generate an indices tensor from a tuple of tensor.""" | |||||
| indices = None | |||||
| check_index_tensor_number = check_number_of_index_tensor(F.shape(data), len(tuple_index), op_name) | |||||
| if check_index_tensor_number: | |||||
| dtype_tuple = hyper_map(F.dtype, tuple_index) | |||||
| check_dtypes = check_index_tensors_dtype(dtype_tuple, op_name) | |||||
| if check_dtypes: | |||||
| shape_tuple = hyper_map(F.shape, tuple_index) | |||||
| broadcast_shape = broadcast_shapes(shape_tuple, op_name) | |||||
| broadcast_tensors = hyper_map(F.partial(tile, broadcast_shape), tuple_index) | |||||
| indices = pack(broadcast_tensors) | |||||
| return indices | |||||
| def generate_updates_from_scalar(data, indices, value, op_type): | |||||
| """Generate an updates tensor from a scalar.""" | |||||
| data_shape = F.shape(data) | |||||
| indices_shape = F.shape(indices) | |||||
| data_dtype = F.dtype(data) | |||||
| return convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type) | |||||
| def generate_updates_from_tuple(data, index, value, op_type): | |||||
| """Generate an updates tensor from a tuple.""" | |||||
| value_types = hyper_map(F.typeof, value) | |||||
| data_dtype = F.dtype(data) | |||||
| value_elements_type = check_value_elements(data_dtype, value_types) | |||||
| if value_elements_type == ALL_TENSOR: | |||||
| value_shapes = hyper_map(F.shape, value) | |||||
| shapes_same = check_shapes_same(value_shapes, TENSOR_SETITEM) | |||||
| if shapes_same: | |||||
| value = F.pack(value) | |||||
| return generate_updates_from_tensor(data, index, value, op_type) | |||||
| data_shape = F.shape(data) | |||||
| index_shape = F.shape(index) | |||||
| return convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value, op_type) | |||||
| def generate_updates_from_tensor(data, index, value, op_type): | |||||
| """Generate an updates tensor from a tensor.""" | |||||
| data_shape = F.shape(data) | |||||
| index_shape = F.shape(index) | |||||
| value_shape = F.shape(value) | |||||
| data_dtype = F.dtype(data) | |||||
| value_dtype = F.dtype(value) | |||||
| updates_shape = value_shape | |||||
| check_dtype_same = check_tensors_dtype_same(data_dtype, value_dtype, TENSOR_SETITEM) | |||||
| if check_dtype_same: | |||||
| updates_shape = generate_updates_shape(data_shape, index_shape, op_type) | |||||
| need_broadcast = check_two_shapes_need_broadcast(updates_shape, value_shape) | |||||
| if need_broadcast: | |||||
| return tile(updates_shape, value) | |||||
| return value | |||||
| @@ -15,9 +15,10 @@ | |||||
| """Implementation for getitem.""" | """Implementation for getitem.""" | ||||
| from ...composite import base | |||||
| from . import _utils as multi_utils | |||||
| from ..import base | |||||
| from ... import functional as F | from ... import functional as F | ||||
| from ....common import dtype as mstype | |||||
| getitem = base.MultitypeFuncGraph('getitem') | getitem = base.MultitypeFuncGraph('getitem') | ||||
| """ | """ | ||||
| @@ -214,19 +215,45 @@ def _tensor_getitem_by_slice(data, slice_index): | |||||
| return _tensor_slice(data, slice_index) | return _tensor_slice(data, slice_index) | ||||
| @getitem.register("Tensor", "Tensor") | |||||
| def _tensor_getitem_by_tensor(data, tensor_index): | |||||
| """ | |||||
| Getting item of tensor by slice. | |||||
| Inputs: | |||||
| data (Tensor): A tensor. | |||||
| tensor_index (Tensor): An index expressed by tensor. | |||||
| Outputs: | |||||
| Tensor, element type is same as the element type of data. | |||||
| """ | |||||
| check_dtypes = multi_utils.check_tensor_dtype_valid(F.dtype(tensor_index), (mstype.int32, mstype.int64)) | |||||
| result = None | |||||
| if check_dtypes: | |||||
| result = F.gather(data, tensor_index, 0) | |||||
| return result | |||||
| @getitem.register("Tensor", "Tuple") | @getitem.register("Tensor", "Tuple") | ||||
| def _tensor_getitem_by_slice_tuple(data, slice_tuple_index): | |||||
| def _tensor_getitem_by_tuple(data, tuple_index): | |||||
| """ | """ | ||||
| Getting item of tensor by slice tuple. | Getting item of tensor by slice tuple. | ||||
| Inputs: | Inputs: | ||||
| data (Tensor): A tensor. | data (Tensor): A tensor. | ||||
| slice_tuple_index (tuple): Index in tuple. | |||||
| tuple_index (tuple): Index in tuple. | |||||
| Outputs: | Outputs: | ||||
| Tensor, element type is same as the element type of data. | Tensor, element type is same as the element type of data. | ||||
| """ | """ | ||||
| return _tensor_slice(data, slice_tuple_index) | |||||
| index_types = multi_utils.hyper_map(F.typeof, tuple_index) | |||||
| index_elements_type = multi_utils.tuple_elements_type(index_types) | |||||
| result = None | |||||
| if index_elements_type == multi_utils.NO_TENSOR: | |||||
| result = _tensor_slice(data, tuple_index) | |||||
| if index_elements_type == multi_utils.ALL_TENSOR: | |||||
| result = _tensor_getitem_by_tuple_of_tensor(data, tuple_index) | |||||
| return result | |||||
| @getitem.register("Tensor", "Ellipsis") | @getitem.register("Tensor", "Ellipsis") | ||||
| @@ -242,3 +269,10 @@ def _tensor_getitem_by_ellipsis(data, ellipsis_index): | |||||
| Tensor, same as data. | Tensor, same as data. | ||||
| """ | """ | ||||
| return _tensor_slice(data, ellipsis_index) | return _tensor_slice(data, ellipsis_index) | ||||
| def _tensor_getitem_by_tuple_of_tensor(data, tuple_index): | |||||
| """Tensor getitem by a tuple of tensor.""" | |||||
| indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_GETITEM) | |||||
| result = F.gather_nd(data, indices) | |||||
| return result | |||||
| @@ -18,10 +18,11 @@ | |||||
| from ...composite import base | from ...composite import base | ||||
| from ....common import dtype as mstype | from ....common import dtype as mstype | ||||
| from ... import functional as F | from ... import functional as F | ||||
| from . import _multitype_ops_util as mult_util | |||||
| from . import _utils as multi_utils | |||||
| setitem = base.MultitypeFuncGraph('setitem') | setitem = base.MultitypeFuncGraph('setitem') | ||||
| @setitem.register("List", "Number", "String") | @setitem.register("List", "Number", "String") | ||||
| def _list_setitem_with_string(data, number_index, value): | def _list_setitem_with_string(data, number_index, value): | ||||
| """ | """ | ||||
| @@ -118,7 +119,7 @@ def _dict_setitem_with_number(data, key, value): | |||||
| @setitem.register("Tensor", "Tensor", "Tensor") | @setitem.register("Tensor", "Tensor", "Tensor") | ||||
| def _tensor_setitem_by_tensor_v1(data, index, value_tensor): | |||||
| def _tensor_setitem_by_tensor_with_tensor(data, index, value_tensor): | |||||
| """ | """ | ||||
| Tensor assignment. | Tensor assignment. | ||||
| @@ -137,27 +138,15 @@ def _tensor_setitem_by_tensor_v1(data, index, value_tensor): | |||||
| Outputs: | Outputs: | ||||
| Tensor, element type and shape is same as data. | Tensor, element type and shape is same as data. | ||||
| """ | """ | ||||
| result = None | |||||
| index_dtype = F.dtype(index) | index_dtype = F.dtype(index) | ||||
| index_shape = F.shape(index) | |||||
| check_result = mult_util.check_tensor_setitem_index(mstype.tensor, index_dtype) | |||||
| if check_result: | |||||
| data_shape = F.shape(data) | |||||
| data_shape = mult_util.check_equal(data_shape, index_shape, | |||||
| "The tensor(shape={}) and tensor index(shape={}) should be the same shape.") | |||||
| size = F.size(value_tensor) | |||||
| size = mult_util.check_equal(1, size, | |||||
| "When assign value is a tensor, its size should be {}, but current size is {}.") | |||||
| dtype = F.dtype(data) | |||||
| u_cast = F.cast(value_tensor, dtype) | |||||
| one_data = F.ones_like(data) | |||||
| u = F.tensor_mul(one_data, u_cast) | |||||
| result = F.select(index, u, data) | |||||
| return result | |||||
| tensor_dtype = multi_utils.get_index_tensor_dtype(index_dtype) | |||||
| if tensor_dtype == multi_utils.INT_: | |||||
| return _tensor_setitem_by_int_tensor_with_tensor(data, index, value_tensor) | |||||
| return _tensor_setitem_by_bool_tensor_with_tensor(data, index, value_tensor) | |||||
| @setitem.register("Tensor", "Tensor", "Number") | @setitem.register("Tensor", "Tensor", "Number") | ||||
| def _tensor_setitem_by_tensor_v2(data, index, value): | |||||
| def _tensor_setitem_by_tensor_with_number(data, index, value): | |||||
| """ | """ | ||||
| Tensor assignment. | Tensor assignment. | ||||
| @@ -171,143 +160,167 @@ def _tensor_setitem_by_tensor_v2(data, index, value): | |||||
| Inputs: | Inputs: | ||||
| data (Tensor): Assigned tensor. | data (Tensor): Assigned tensor. | ||||
| index (Tensor): Tensor of bool type. | index (Tensor): Tensor of bool type. | ||||
| value_tensor (Number): Assignment value. | |||||
| value (Number): Assignment value. | |||||
| Outputs: | Outputs: | ||||
| Tensor, element type and shape is same as data. | Tensor, element type and shape is same as data. | ||||
| """ | """ | ||||
| result = None | |||||
| index_dtype = F.dtype(index) | index_dtype = F.dtype(index) | ||||
| index_shape = F.shape(index) | |||||
| check_result = mult_util.check_tensor_setitem_index(mstype.tensor, index_dtype) | |||||
| if check_result: | |||||
| shape = F.shape(data) | |||||
| shape = mult_util.check_equal( | |||||
| shape, index_shape, "The tensor(shape={}) and tensor index(shape={}) should be the same shape.") | |||||
| dtype = F.dtype(data) | |||||
| u = F.fill(dtype, shape, value) | |||||
| result = F.select(index, u, data) | |||||
| return result | |||||
| tensor_dtype = multi_utils.get_index_tensor_dtype(index_dtype) | |||||
| if tensor_dtype == multi_utils.BOOL_: | |||||
| return _tensor_setitem_by_bool_tensor_with_scalar(data, index, value) | |||||
| return _tensor_setitem_by_int_tensor_with_scalar(data, index, value) | |||||
| @setitem.register("Tensor", "Slice", "Tensor") | |||||
| def _tensor_setitem_with_slice_v3(data, input_slice, value): | |||||
| @setitem.register("Tensor", "Tuple", "Number") | |||||
| def _tensor_setitem_by_tuple_with_number(data, tuple_index, value): | |||||
| """ | """ | ||||
| Tensor assignment. | Tensor assignment. | ||||
| Note: | Note: | ||||
| Syntax support: A[Slice] = U | |||||
| Restraint condition: A is a Tensor | |||||
| Slice like "1:3" | |||||
| U is a Tensor(size=1) or Tensor(size>1) | |||||
| Syntax support: A[B, C, D] = u. | |||||
| Restraint condition: 1) A is a Tensor, and B, C, D are index. | |||||
| 2) u is a scalar. | |||||
| Inputs: | Inputs: | ||||
| data (Tensor): Assigned tensor. | data (Tensor): Assigned tensor. | ||||
| input_slice (Slice): Slice expression. | |||||
| index (Tuple): An index tuple. | |||||
| value (Number): Assignment value. | value (Number): Assignment value. | ||||
| Outputs: | Outputs: | ||||
| Tensor, element type and shape is same as data. | Tensor, element type and shape is same as data. | ||||
| """ | """ | ||||
| return _tensor_assgin_tensor(data, input_slice, value) | |||||
| index_types = multi_utils.hyper_map(F.typeof, tuple_index) | |||||
| index_elements_type = multi_utils.tuple_elements_type(index_types) | |||||
| result = None | |||||
| if index_elements_type == multi_utils.NO_TENSOR: | |||||
| result = _tensor_assgin_number(data, tuple_index, value) | |||||
| if index_elements_type == multi_utils.ALL_TENSOR: | |||||
| indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_SETITEM) | |||||
| updates = multi_utils.generate_updates_from_scalar(data, indices, value, | |||||
| multi_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) | |||||
| result = F.scatter_nd_update(data, indices, updates) | |||||
| return result | |||||
| @setitem.register("Tensor", "Tuple", "Tensor") | @setitem.register("Tensor", "Tuple", "Tensor") | ||||
| def _tensor_setitem_with_slice_v4(data, input_slice, value): | |||||
| def _tensor_setitem_by_tuple_with_tensor(data, tuple_index, value): | |||||
| """ | """ | ||||
| Tensor assignment. | Tensor assignment. | ||||
| Note: | Note: | ||||
| Syntax support: A[tuple(Slice)] = U, and A[tuple(Number)] = U | |||||
| Restraint condition: A is a Tensor | |||||
| Slice like "1:3, ::, :4:-1" | |||||
| U is a Tensor(size=1) or Tensor(size>1) | |||||
| Syntax support: A[B, C, D] = U. | |||||
| Restraint condition: 1) A is a Tensor, and B, C, D are index Tensors. | |||||
| 2) U is a Tensor. | |||||
| Inputs: | Inputs: | ||||
| data (Tensor): Assigned tensor. | data (Tensor): Assigned tensor. | ||||
| input_slice (Union[tuple[Slice], tuple[Number]]): Slice expression. | |||||
| value (Number): Assignment value. | |||||
| index (Tuple): An index tuple. | |||||
| value (Tensor): Assignment tensor, should has the same data type as 'data'. | |||||
| Outputs: | Outputs: | ||||
| Tensor, element type and shape is same as data. | Tensor, element type and shape is same as data. | ||||
| """ | """ | ||||
| return _tensor_assgin_tensor(data, input_slice, value) | |||||
| index_types = multi_utils.hyper_map(F.typeof, tuple_index) | |||||
| index_elements_type = multi_utils.tuple_elements_type(index_types) | |||||
| result = None | |||||
| if index_elements_type == multi_utils.NO_TENSOR: | |||||
| result = _tensor_assgin_tensor(data, tuple_index, value) | |||||
| if index_elements_type == multi_utils.ALL_TENSOR: | |||||
| indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_SETITEM) | |||||
| updates = multi_utils.generate_updates_from_tensor(data, indices, value, | |||||
| multi_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) | |||||
| result = F.scatter_nd_update(data, indices, updates) | |||||
| return result | |||||
| def _tensor_assgin_tensor(data, input_slice, value): | |||||
| """Assigns a tensor value to the tensor by slice.""" | |||||
| @setitem.register("Tensor", "Tuple", "Tuple") | |||||
| def _tensor_setitem_by_tuple_with_tuple(data, tuple_index, value): | |||||
| """ | |||||
| Tensor assignment. | |||||
| Note: | |||||
| Syntax support: A[B, C, D] = U. | |||||
| Restraint condition: 1) A is a Tensor, and B, C, D are index Tensors. | |||||
| 2) A B and C could be broadcast. | |||||
| 3) U is a Tensor. | |||||
| Inputs: | |||||
| data (Tensor): Assigned tensor. | |||||
| index (Tuple): A tuple of tensor, these tensor could be broadcast. | |||||
| value (Tensor): Assignment tensor, should has the same data type as 'data'. | |||||
| Outputs: | |||||
| Tensor, element type and shape is same as data. | |||||
| """ | |||||
| index_types = multi_utils.hyper_map(F.typeof, tuple_index) | |||||
| index_elements_type = multi_utils.tuple_elements_type(index_types) | |||||
| result = None | result = None | ||||
| check_result = mult_util.check_tensor_setitem_index(input_slice) | |||||
| if check_result: | |||||
| data_shape = F.shape(data) | |||||
| indices = mult_util.slice2indices(input_slice, data_shape) | |||||
| is_tuple_int = mult_util.tuple_element_is_int(input_slice) | |||||
| if is_tuple_int: | |||||
| indices = mult_util.integer_to_indices(input_slice, data_shape) | |||||
| result = _tensor_indices_tensor(data, data_shape, input_slice, indices, value) | |||||
| if index_elements_type == multi_utils.ALL_TENSOR: | |||||
| indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_SETITEM) | |||||
| updates = multi_utils.generate_updates_from_tuple(data, indices, value, | |||||
| multi_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) | |||||
| result = F.scatter_nd_update(data, indices, updates) | |||||
| return result | return result | ||||
| def _tensor_indices_tensor(data, data_shape, index, indices, value): | |||||
| """Assigns a tensor value to the tensor.""" | |||||
| data_size = F.size(data) | |||||
| data_dtype = F.dtype(data) | |||||
| indices_size = F.size(indices) | |||||
| indices_size = mult_util.check_indices(indices_size, index) | |||||
| update = F.fill(mstype.int32, (indices_size,), 1) | |||||
| condition_1d = F.scatter_nd(indices, update, (data_size,)) | |||||
| condition = F.reshape(condition_1d, data_shape) | |||||
| condition = F.cast(condition, mstype.bool_) | |||||
| value_fill = None | |||||
| value_size = F.size(value) | |||||
| @setitem.register("Tensor", "Tensor", "Tuple") | |||||
| def _tensor_setitem_by_tensor_v2(data, index, value): | |||||
| """ | |||||
| Tensor assignment. | |||||
| value_size = mult_util.check_indices_value_size(indices_size, value_size) | |||||
| if value_size == 1: | |||||
| value_fill = F.fill(data_dtype, (indices_size,), 1) | |||||
| value = F.cast(value, data_dtype) | |||||
| value_fill = F.tensor_mul(value_fill, value) | |||||
| elif value_size > 1: | |||||
| value_fill = F.reshape(value, (indices_size,)) | |||||
| value_1d = F.scatter_nd(indices, value_fill, (data_size,)) | |||||
| u = F.reshape(value_1d, data_shape) | |||||
| return F.select(condition, u, data) | |||||
| Inputs: | |||||
| data (Tensor): Assigned tensor. | |||||
| index (Tensor): Tensor of bool type. | |||||
| value (Tuple): Assignment value. | |||||
| @setitem.register("Tensor", "Slice", "Number") | |||||
| def _tensor_setitem_with_slice_v1(data, input_slice, value): | |||||
| Outputs: | |||||
| Tensor, element type and shape is same as data. | |||||
| """ | |||||
| index_dtype = F.dtype(index) | |||||
| check_dtype = multi_utils.check_tensor_dtype_valid(index_dtype, (mstype.int32, mstype.int64)) | |||||
| result = None | |||||
| if check_dtype: | |||||
| result = _tensor_setitem_by_tensor_with_tuple(data, index, value) | |||||
| return result | |||||
| @setitem.register("Tensor", "Slice", "Tensor") | |||||
| def _tensor_setitem_with_slice_v3(data, input_slice, value): | |||||
| """ | """ | ||||
| Tensor assignment. | Tensor assignment. | ||||
| Note: | Note: | ||||
| Syntax support: A[Slice] = u | |||||
| Restraint condition: A is a Tensor. | |||||
| Syntax support: A[Slice] = U | |||||
| Restraint condition: A is a Tensor | |||||
| Slice like "1:3" | Slice like "1:3" | ||||
| u is a scalar | |||||
| U is a Tensor(size=1) or Tensor(size>1) | |||||
| Inputs: | Inputs: | ||||
| data (Tensor): Assigned tensor. | data (Tensor): Assigned tensor. | ||||
| input_slice (Slice): slice expression. | |||||
| input_slice (Slice): Slice expression. | |||||
| value (Number): Assignment value. | value (Number): Assignment value. | ||||
| Outputs: | Outputs: | ||||
| Tensor, element type and shape is same as data. | Tensor, element type and shape is same as data. | ||||
| """ | """ | ||||
| return _tensor_assgin_number(data, input_slice, value) | |||||
| return _tensor_assgin_tensor(data, input_slice, value) | |||||
| @setitem.register("Tensor", "Tuple", "Number") | |||||
| def _tensor_setitem_with_slice_v2(data, input_slice, value): | |||||
| @setitem.register("Tensor", "Slice", "Number") | |||||
| def _tensor_setitem_with_slice_v1(data, input_slice, value): | |||||
| """ | """ | ||||
| Tensor assignment. | Tensor assignment. | ||||
| Note: | Note: | ||||
| Syntax support: A[tuple(Slice)] = u, and A[tuple(Number)] = u | |||||
| Syntax support: A[Slice] = u | |||||
| Restraint condition: A is a Tensor. | Restraint condition: A is a Tensor. | ||||
| Slice like "1:3, ::, :4:-1" | |||||
| Slice like "1:3" | |||||
| u is a scalar | u is a scalar | ||||
| Inputs: | Inputs: | ||||
| data (Tensor): Assigned tensor. | data (Tensor): Assigned tensor. | ||||
| input_slice (Union[tuple[Slice], tuple[Number]]): slice expression. | |||||
| input_slice (Slice): slice expression. | |||||
| value (Number): Assignment value. | value (Number): Assignment value. | ||||
| Outputs: | Outputs: | ||||
| @@ -318,39 +331,23 @@ def _tensor_setitem_with_slice_v2(data, input_slice, value): | |||||
| def _tensor_assgin_number(data, input_slice, value): | def _tensor_assgin_number(data, input_slice, value): | ||||
| """Givens a scalar assign to tensor by slice""" | """Givens a scalar assign to tensor by slice""" | ||||
| check_result = mult_util.check_tensor_setitem_index(input_slice) | |||||
| check_result = multi_utils.check_tensor_setitem_index(input_slice) | |||||
| result = None | result = None | ||||
| if check_result: | if check_result: | ||||
| data_shape = F.shape(data) | data_shape = F.shape(data) | ||||
| indices = mult_util.slice2indices(input_slice, data_shape) | |||||
| is_tuple_int = mult_util.tuple_element_is_int(input_slice) | |||||
| indices = multi_utils.slice2indices(input_slice, data_shape) | |||||
| is_tuple_int = multi_utils.tuple_element_is_int(input_slice) | |||||
| if is_tuple_int: | if is_tuple_int: | ||||
| indices = mult_util.integer_to_indices(input_slice, data_shape) | |||||
| indices = multi_utils.integer_to_indices(input_slice, data_shape) | |||||
| result = _tensor_indices_number(data, data_shape, input_slice, indices, value) | result = _tensor_indices_number(data, data_shape, input_slice, indices, value) | ||||
| return result | return result | ||||
| def _tensor_indices_number(data, data_shape, index, indices, value): | |||||
| """Assigns a scalar value to the tensor.""" | |||||
| data_size = F.size(data) | |||||
| data_dtype = F.dtype(data) | |||||
| indices_size = F.size(indices) | |||||
| indices_size = mult_util.check_indices(indices_size, index) | |||||
| update = F.fill(mstype.int32, (indices_size,), 1) | |||||
| condition_1d = F.scatter_nd(indices, update, (data_size,)) | |||||
| condition = F.reshape(condition_1d, data_shape) | |||||
| condition = F.cast(condition, mstype.bool_) | |||||
| value_fill = F.fill(data_dtype, (indices_size,), value) | |||||
| value_1d = F.scatter_nd(indices, value_fill, (data_size,)) | |||||
| u = F.reshape(value_1d, data_shape) | |||||
| return F.select(condition, u, data) | |||||
| @setitem.register("Tensor", "Number", "Number") | @setitem.register("Tensor", "Number", "Number") | ||||
| def _tensor_setitem_with_int_v1(data, index, value): | def _tensor_setitem_with_int_v1(data, index, value): | ||||
| """Syntax: A[1] = 3""" | """Syntax: A[1] = 3""" | ||||
| data_shape = F.shape(data) | data_shape = F.shape(data) | ||||
| indices = mult_util.integer_to_indices(index, data_shape) | |||||
| indices = multi_utils.integer_to_indices(index, data_shape) | |||||
| return _tensor_indices_number(data, data_shape, index, indices, value) | return _tensor_indices_number(data, data_shape, index, indices, value) | ||||
| @@ -358,7 +355,7 @@ def _tensor_setitem_with_int_v1(data, index, value): | |||||
| def _tensor_setitem_with_int_v2(data, index, value): | def _tensor_setitem_with_int_v2(data, index, value): | ||||
| """Syntax: A[1] = Tensor""" | """Syntax: A[1] = Tensor""" | ||||
| data_shape = F.shape(data) | data_shape = F.shape(data) | ||||
| indices = mult_util.integer_to_indices(index, data_shape) | |||||
| indices = multi_utils.integer_to_indices(index, data_shape) | |||||
| return _tensor_indices_tensor(data, data_shape, index, indices, value) | return _tensor_indices_tensor(data, data_shape, index, indices, value) | ||||
| @@ -379,7 +376,7 @@ def _tensor_setitem_with_ellipsis_v2(data, index, value): | |||||
| data_size = F.size(data) | data_size = F.size(data) | ||||
| value_shape = F.shape(value) | value_shape = F.shape(value) | ||||
| value_size = F.size(value) | value_size = F.size(value) | ||||
| check_result = mult_util.check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size) | |||||
| check_result = multi_utils.check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size) | |||||
| if check_result: | if check_result: | ||||
| if data_size == value_size: | if data_size == value_size: | ||||
| result = F.reshape(value, data_shape) | result = F.reshape(value, data_shape) | ||||
| @@ -389,3 +386,108 @@ def _tensor_setitem_with_ellipsis_v2(data, index, value): | |||||
| param2 = F.cast(value, data_dtype) | param2 = F.cast(value, data_dtype) | ||||
| result = F.tensor_mul(param1, param2) | result = F.tensor_mul(param1, param2) | ||||
| return result | return result | ||||
| def _tensor_assgin_tensor(data, input_slice, value): | |||||
| """Assigns a tensor value to the tensor by slice.""" | |||||
| result = None | |||||
| check_result = multi_utils.check_tensor_setitem_index(input_slice) | |||||
| if check_result: | |||||
| data_shape = F.shape(data) | |||||
| indices = multi_utils.slice2indices(input_slice, data_shape) | |||||
| is_tuple_int = multi_utils.tuple_element_is_int(input_slice) | |||||
| if is_tuple_int: | |||||
| indices = multi_utils.integer_to_indices(input_slice, data_shape) | |||||
| result = _tensor_indices_tensor(data, data_shape, input_slice, indices, value) | |||||
| return result | |||||
| def _tensor_indices_tensor(data, data_shape, index, indices, value): | |||||
| """Assigns a tensor value to the tensor.""" | |||||
| data_size = F.size(data) | |||||
| data_dtype = F.dtype(data) | |||||
| indices_size = F.size(indices) | |||||
| indices_size = multi_utils.check_indices(indices_size, index) | |||||
| update = F.fill(mstype.int32, (indices_size,), 1) | |||||
| condition_1d = F.scatter_nd(indices, update, (data_size,)) | |||||
| condition = F.reshape(condition_1d, data_shape) | |||||
| condition = F.cast(condition, mstype.bool_) | |||||
| value_fill = None | |||||
| value_size = F.size(value) | |||||
| value_size = multi_utils.check_indices_value_size(indices_size, value_size) | |||||
| if value_size == 1: | |||||
| value_fill = F.fill(data_dtype, (indices_size,), 1) | |||||
| value = F.cast(value, data_dtype) | |||||
| value_fill = F.tensor_mul(value_fill, value) | |||||
| elif value_size > 1: | |||||
| value_fill = F.reshape(value, (indices_size,)) | |||||
| value_1d = F.scatter_nd(indices, value_fill, (data_size,)) | |||||
| u = F.reshape(value_1d, data_shape) | |||||
| return F.select(condition, u, data) | |||||
| def _tensor_indices_number(data, data_shape, index, indices, value): | |||||
| """Assigns a scalar value to the tensor.""" | |||||
| data_size = F.size(data) | |||||
| data_dtype = F.dtype(data) | |||||
| indices_size = F.size(indices) | |||||
| indices_size = multi_utils.check_indices(indices_size, index) | |||||
| update = F.fill(mstype.int32, (indices_size,), 1) | |||||
| condition_1d = F.scatter_nd(indices, update, (data_size,)) | |||||
| condition = F.reshape(condition_1d, data_shape) | |||||
| condition = F.cast(condition, mstype.bool_) | |||||
| value_fill = F.fill(data_dtype, (indices_size,), value) | |||||
| value_1d = F.scatter_nd(indices, value_fill, (data_size,)) | |||||
| u = F.reshape(value_1d, data_shape) | |||||
| return F.select(condition, u, data) | |||||
| def _tensor_setitem_by_tensor_with_tuple(data, index, value): | |||||
| """Set a tensor item by a tensor with a tuple.""" | |||||
| updates = multi_utils.generate_updates_from_tuple(data, index, value, | |||||
| multi_utils.SET_ITEM_BY_ONE_TENSOR) | |||||
| result = F.scatter_update(data, index, updates) | |||||
| return result | |||||
| def _tensor_setitem_by_int_tensor_with_scalar(data, index, value): | |||||
| """Set a tensor item by a int tensor with a scalar.""" | |||||
| updates = multi_utils.generate_updates_from_scalar(data, index, value, | |||||
| multi_utils.SET_ITEM_BY_ONE_TENSOR) | |||||
| return F.scatter_update(data, index, updates) | |||||
| def _tensor_setitem_by_bool_tensor_with_scalar(data, index, value): | |||||
| """Set a tensor item by a bool tensor with a scalar.""" | |||||
| index_shape = F.shape(index) | |||||
| shape = F.shape(data) | |||||
| shape = multi_utils.check_equal( | |||||
| shape, index_shape, "The tensor(shape={}) and tensor index(shape={}) should be the same shape.") | |||||
| dtype = F.dtype(data) | |||||
| u = F.fill(dtype, shape, value) | |||||
| return F.select(index, u, data) | |||||
| def _tensor_setitem_by_int_tensor_with_tensor(data, index, value): | |||||
| """Set a tensor item by a int tensor with a tensor.""" | |||||
| updates = multi_utils.generate_updates_from_tensor(data, index, value, | |||||
| multi_utils.SET_ITEM_BY_ONE_TENSOR) | |||||
| return F.scatter_update(data, index, updates) | |||||
| def _tensor_setitem_by_bool_tensor_with_tensor(data, index, value): | |||||
| """Set a tensor item by a bool tensor with a tensor.""" | |||||
| index_shape = F.shape(index) | |||||
| data_shape = F.shape(data) | |||||
| data_shape = multi_utils.check_equal(data_shape, index_shape, | |||||
| "The tensor(shape={}) and tensor index(shape={}) should be the same shape.") | |||||
| size = F.size(value) | |||||
| size = multi_utils.check_equal(1, size, | |||||
| "When assign value is a tensor, its size should be {}, but current size is {}.") | |||||
| dtype = F.dtype(data) | |||||
| u_cast = F.cast(value, dtype) | |||||
| one_data = F.ones_like(data) | |||||
| u = F.tensor_mul(one_data, u_cast) | |||||
| result = F.select(index, u, data) | |||||
| return result | |||||
| @@ -31,6 +31,7 @@ dtype = P.DType() | |||||
| issubclass_ = P.IsSubClass() | issubclass_ = P.IsSubClass() | ||||
| isinstance_ = P.IsInstance() | isinstance_ = P.IsInstance() | ||||
| fill = P.Fill() | fill = P.Fill() | ||||
| tile = P.Tile() | |||||
| select = P.Select() | select = P.Select() | ||||
| size = P.Size() | size = P.Size() | ||||
| ones_like = P.OnesLike() | ones_like = P.OnesLike() | ||||
| @@ -70,6 +71,12 @@ scalar_cast = P.ScalarCast() | |||||
| print_ = P.Print() | print_ = P.Print() | ||||
| expand_dims = P.ExpandDims() | expand_dims = P.ExpandDims() | ||||
| scatter_nd = P.ScatterNd() | scatter_nd = P.ScatterNd() | ||||
| gather = P.GatherV2() | |||||
| gather_nd = P.GatherNd() | |||||
| scatter_update = P.ScatterUpdate() | |||||
| scatter_nd_update = P.ScatterNdUpdate() | |||||
| pack = P.Pack() | |||||
| tuple_setitem = Primitive('tuple_setitem') | tuple_setitem = Primitive('tuple_setitem') | ||||
| tuple_getitem = Primitive('tuple_getitem') | tuple_getitem = Primitive('tuple_getitem') | ||||
| @@ -24,7 +24,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, | |||||
| Fill, GatherNd, GatherV2, InvertPermutation, | Fill, GatherNd, GatherV2, InvertPermutation, | ||||
| IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike, | IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike, | ||||
| Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, | Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, | ||||
| SameTypeShape, ScatterMax, | |||||
| SameTypeShape, ScatterMax, ScatterUpdate, | |||||
| ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select, | ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select, | ||||
| Shape, Size, Slice, Split, | Shape, Size, Slice, Split, | ||||
| Squeeze, StridedSlice, Tile, | Squeeze, StridedSlice, Tile, | ||||
| @@ -193,6 +193,7 @@ __all__ = [ | |||||
| 'Pad', | 'Pad', | ||||
| 'MirrorPad', | 'MirrorPad', | ||||
| 'GatherNd', | 'GatherNd', | ||||
| 'ScatterUpdate', | |||||
| 'ScatterNdUpdate', | 'ScatterNdUpdate', | ||||
| 'Floor', | 'Floor', | ||||
| 'NMSWithMask', | 'NMSWithMask', | ||||
| @@ -19,7 +19,7 @@ from ..._c_expression import signature_rw as sig_rw | |||||
| from ..._c_expression import signature_kind as sig_kind | from ..._c_expression import signature_kind as sig_kind | ||||
| from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register | from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register | ||||
| from ..._checkparam import Validator as validator, Rel | from ..._checkparam import Validator as validator, Rel | ||||
| from .._utils import _get_concat_offset | |||||
| from .._utils import get_concat_offset | |||||
| from ...common import dtype as mstype | from ...common import dtype as mstype | ||||
| @@ -136,7 +136,7 @@ class ConcatOffset(PrimitiveWithInfer): | |||||
| axis = self.axis | axis = self.axis | ||||
| x_shp = input_x['shape'] | x_shp = input_x['shape'] | ||||
| x_type = input_x['dtype'] | x_type = input_x['dtype'] | ||||
| offset, _, axis = _get_concat_offset(x_shp, x_type, axis, self.name) | |||||
| offset, _, axis = get_concat_offset(x_shp, x_type, axis, self.name) | |||||
| self.add_prim_attr('T', x_type[0].element_type()) | self.add_prim_attr('T', x_type[0].element_type()) | ||||
| offset_values = [] | offset_values = [] | ||||
| for i in range(len(x_shp)): | for i in range(len(x_shp)): | ||||
| @@ -24,16 +24,15 @@ import itertools | |||||
| import numbers | import numbers | ||||
| import numpy as np | import numpy as np | ||||
| from ..._c_expression import signature_rw as sig_rw | |||||
| from ..._c_expression import signature_kind as sig_kind | |||||
| from ..._checkparam import Validator as validator | from ..._checkparam import Validator as validator | ||||
| from ..._checkparam import Rel | from ..._checkparam import Rel | ||||
| from ...common import dtype as mstype | from ...common import dtype as mstype | ||||
| from ...common.tensor import Tensor | from ...common.tensor import Tensor | ||||
| from ..operations.math_ops import _infer_shape_reduce | from ..operations.math_ops import _infer_shape_reduce | ||||
| from .._utils import _get_concat_offset | |||||
| from .._utils import get_concat_offset | |||||
| from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register | from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register | ||||
| def _check_infer_attr_reduce(axis, keep_dims, prim_name): | def _check_infer_attr_reduce(axis, keep_dims, prim_name): | ||||
| validator.check_value_type('keep_dims', keep_dims, [bool], prim_name) | validator.check_value_type('keep_dims', keep_dims, [bool], prim_name) | ||||
| validator.check_value_type('axis', axis, [int, tuple], prim_name) | validator.check_value_type('axis', axis, [int, tuple], prim_name) | ||||
| @@ -931,7 +930,7 @@ class InvertPermutation(PrimitiveWithInfer): | |||||
| z = [x_value[i] for i in range(len(x_value))] | z = [x_value[i] for i in range(len(x_value))] | ||||
| z.sort() | z.sort() | ||||
| y = [None]*len(x_value) | |||||
| y = [None] * len(x_value) | |||||
| for i, value in enumerate(x_value): | for i, value in enumerate(x_value): | ||||
| validator.check_value_type("input[%d]" % i, value, [int], self.name) | validator.check_value_type("input[%d]" % i, value, [int], self.name) | ||||
| validator.check(f'value', z[i], f'index', i, Rel.EQ, self.name) | validator.check(f'value', z[i], f'index', i, Rel.EQ, self.name) | ||||
| @@ -1111,6 +1110,7 @@ class ArgMinWithValue(PrimitiveWithInfer): | |||||
| >>> input_x = Tensor(np.random.rand(5)) | >>> input_x = Tensor(np.random.rand(5)) | ||||
| >>> index, output = P.ArgMinWithValue()(input_x) | >>> index, output = P.ArgMinWithValue()(input_x) | ||||
| """ | """ | ||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, axis=0, keep_dims=False): | def __init__(self, axis=0, keep_dims=False): | ||||
| """init ArgMinWithValue""" | """init ArgMinWithValue""" | ||||
| @@ -1352,7 +1352,7 @@ class Concat(PrimitiveWithInfer): | |||||
| axis = self.axis | axis = self.axis | ||||
| x_shp = input_x['shape'] | x_shp = input_x['shape'] | ||||
| x_type = input_x['dtype'] | x_type = input_x['dtype'] | ||||
| _, all_shp, _ = _get_concat_offset(x_shp, x_type, axis, self.name) | |||||
| _, all_shp, _ = get_concat_offset(x_shp, x_type, axis, self.name) | |||||
| self.add_prim_attr('T', x_type[0].element_type()) | self.add_prim_attr('T', x_type[0].element_type()) | ||||
| self.add_prim_attr('inputNums', len(x_shp)) | self.add_prim_attr('inputNums', len(x_shp)) | ||||
| ret_shp = x_shp[0].copy() | ret_shp = x_shp[0].copy() | ||||
| @@ -1376,15 +1376,13 @@ def _get_pack_shape(x_shape, x_type, axis, prim_name): | |||||
| if axis < 0: | if axis < 0: | ||||
| axis = axis + rank_base + 1 | axis = axis + rank_base + 1 | ||||
| for i in range(1, N): | for i in range(1, N): | ||||
| v = x_shape[i] | |||||
| validator.check('len of x_shape[%d]' % i, len(v), 'len of rank_base', rank_base, Rel.EQ, prim_name) | |||||
| validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0], Rel.EQ, prim_name, TypeError) | validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0], Rel.EQ, prim_name, TypeError) | ||||
| for j in range(rank_base): | |||||
| if v[j] != x_shape[0][j]: | |||||
| raise ValueError(f"For \'{prim_name}\' element {i} shape in input can not pack with first element") | |||||
| if x_shape[i] != x_shape[0]: | |||||
| raise ValueError(f"For \'{prim_name}\' element {i} shape in input can not pack with first element") | |||||
| out_shape.insert(axis, N) | out_shape.insert(axis, N) | ||||
| return out_shape | return out_shape | ||||
| class Pack(PrimitiveWithInfer): | class Pack(PrimitiveWithInfer): | ||||
| r""" | r""" | ||||
| Packs a list of tensors in specified axis. | Packs a list of tensors in specified axis. | ||||
| @@ -1831,7 +1829,7 @@ class DiagPart(PrimitiveWithInfer): | |||||
| return x_type | return x_type | ||||
| def infer_shape(self, x_shape): | def infer_shape(self, x_shape): | ||||
| if len(x_shape)%2 != 0 or \ | |||||
| if len(x_shape) % 2 != 0 or \ | |||||
| not x_shape: | not x_shape: | ||||
| raise ValueError(f"For \'{self.name}\' input rank must be non-zero and even, but got rank {len(x_shape)}, " | raise ValueError(f"For \'{self.name}\' input rank must be non-zero and even, but got rank {len(x_shape)}, " | ||||
| f"with shapes {x_shape}") | f"with shapes {x_shape}") | ||||
| @@ -2004,6 +2002,49 @@ class GatherNd(PrimitiveWithInfer): | |||||
| return x_dtype | return x_dtype | ||||
| class ScatterUpdate(PrimitiveWithInfer): | |||||
| """ | |||||
| Update tensor value by using input indices and value. | |||||
| Using given values to update tensor value, along with the input indices. | |||||
| Args: | |||||
| use_locking (bool): Whether protect the assignment by a lock. Default: True. | |||||
| Inputs: | |||||
| - **input_x** (Parameter) - The target tensor, with data type of Parameter. | |||||
| - **indices** (Tensor) - The index of input tensor. | |||||
| - **update** (Tensor) - The tensor to update the input tensor, has the same type as input, | |||||
| and update.shape = indices.shape + input_x.shape[1:]. | |||||
| Outputs: | |||||
| Tensor, has the same shape and type as `input_x`. | |||||
| Examples: | |||||
| >>> input_x = mindspore.Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)) | |||||
| >>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32) | |||||
| >>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32) | |||||
| >>> op = P.ScatterNdUpdate() | |||||
| >>> output = op(input_x, indices, update) | |||||
| """ | |||||
| @prim_attr_register | |||||
| def __init__(self, use_locking=True): | |||||
| """Init ScatterNdUpdate""" | |||||
| self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y']) | |||||
| def infer_shape(self, x_shape, indices_shape, value_shape): | |||||
| if indices_shape + x_shape[1:] != value_shape: | |||||
| raise ValueError('Input value are not match with input indices.') | |||||
| return x_shape | |||||
| def infer_dtype(self, x_dtype, indices_dtype, value_dtype): | |||||
| validator.check_tensor_type_same({'indices': indices_dtype}, mstype.int_type, self.name) | |||||
| args = {"x": x_dtype, "value": value_dtype} | |||||
| validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name) | |||||
| return x_dtype | |||||
| class ScatterNdUpdate(PrimitiveWithInfer): | class ScatterNdUpdate(PrimitiveWithInfer): | ||||
| """ | """ | ||||
| Update tensor value by using input indices and value. | Update tensor value by using input indices and value. | ||||
| @@ -2028,11 +2069,6 @@ class ScatterNdUpdate(PrimitiveWithInfer): | |||||
| >>> op = P.ScatterNdUpdate() | >>> op = P.ScatterNdUpdate() | ||||
| >>> output = op(input_x, indices, update) | >>> output = op(input_x, indices, update) | ||||
| """ | """ | ||||
| __mindspore_signature__ = ( | |||||
| ('input_x', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD), | |||||
| ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD), | |||||
| ('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD) | |||||
| ) | |||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, use_locking=True): | def __init__(self, use_locking=True): | ||||
| @@ -2142,10 +2178,10 @@ class SpaceToDepth(PrimitiveWithInfer): | |||||
| validator.check('x dimension', len(x_shape), '', 4, Rel.EQ) | validator.check('x dimension', len(x_shape), '', 4, Rel.EQ) | ||||
| out_shape = copy.deepcopy(x_shape) | out_shape = copy.deepcopy(x_shape) | ||||
| for i in range(2): | for i in range(2): | ||||
| if out_shape[i+2] % self.block_size != 0: | |||||
| raise ValueError(f'For \'{self.name}\' input shape[{i+2}] {out_shape[i+2]} should be ' | |||||
| if out_shape[i + 2] % self.block_size != 0: | |||||
| raise ValueError(f'For \'{self.name}\' input shape[{i + 2}] {out_shape[i + 2]} should be ' | |||||
| f'fully divided by block_size {self.block_size}') | f'fully divided by block_size {self.block_size}') | ||||
| out_shape[i+2] //= self.block_size | |||||
| out_shape[i + 2] //= self.block_size | |||||
| out_shape[1] *= self.block_size * self.block_size | out_shape[1] *= self.block_size * self.block_size | ||||
| return out_shape | return out_shape | ||||
| @@ -2199,9 +2235,10 @@ class DepthToSpace(PrimitiveWithInfer): | |||||
| validator.check('x dimension', len(x_shape), '', 4, Rel.EQ) | validator.check('x dimension', len(x_shape), '', 4, Rel.EQ) | ||||
| out_shape = copy.deepcopy(x_shape) | out_shape = copy.deepcopy(x_shape) | ||||
| for i in range(2): | for i in range(2): | ||||
| out_shape[i+2] *= self.block_size | |||||
| out_shape[i + 2] *= self.block_size | |||||
| validator.check_integer('x_shape[1] % (block_size*block_size)', x_shape[1] % (self.block_size*self.block_size), | |||||
| validator.check_integer('x_shape[1] % (block_size*block_size)', | |||||
| x_shape[1] % (self.block_size * self.block_size), | |||||
| 0, Rel.EQ, self.name) | 0, Rel.EQ, self.name) | ||||
| out_shape[1] //= self.block_size * self.block_size | out_shape[1] //= self.block_size * self.block_size | ||||
| return out_shape | return out_shape | ||||
| @@ -2251,6 +2288,7 @@ class SpaceToBatch(PrimitiveWithInfer): | |||||
| [[[[1.]]], [[[2.]]], [[[3.]]], [[[4.]]]] | [[[[1.]]], [[[2.]]], [[[3.]]], [[[4.]]]] | ||||
| """ | """ | ||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, block_size, paddings): | def __init__(self, block_size, paddings): | ||||
| """Init SpaceToBatch""" | """Init SpaceToBatch""" | ||||
| @@ -2271,12 +2309,12 @@ class SpaceToBatch(PrimitiveWithInfer): | |||||
| validator.check_integer('rank of input_x', len(x_shape), 4, Rel.EQ, self.name) | validator.check_integer('rank of input_x', len(x_shape), 4, Rel.EQ, self.name) | ||||
| out_shape = copy.deepcopy(x_shape) | out_shape = copy.deepcopy(x_shape) | ||||
| for i in range(2): | for i in range(2): | ||||
| padded = out_shape[i+2] + self.paddings[i][0] + \ | |||||
| padded = out_shape[i + 2] + self.paddings[i][0] + \ | |||||
| self.paddings[i][1] | self.paddings[i][1] | ||||
| if padded % self.block_size != 0: | if padded % self.block_size != 0: | ||||
| raise ValueError(f'For \'{self.name}\' padded[{i}] {padded} should be divisible by ' | raise ValueError(f'For \'{self.name}\' padded[{i}] {padded} should be divisible by ' | ||||
| f'block_size {self.block_size}') | f'block_size {self.block_size}') | ||||
| out_shape[i+2] = padded // self.block_size | |||||
| out_shape[i + 2] = padded // self.block_size | |||||
| out_shape[0] *= self.block_size * self.block_size | out_shape[0] *= self.block_size * self.block_size | ||||
| return out_shape | return out_shape | ||||
| @@ -2319,6 +2357,7 @@ class BatchToSpace(PrimitiveWithInfer): | |||||
| [[[[1., 2.], [3., 4.]]]] | [[[[1., 2.], [3., 4.]]]] | ||||
| """ | """ | ||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, block_size, crops): | def __init__(self, block_size, crops): | ||||
| """Init BatchToSpace""" | """Init BatchToSpace""" | ||||
| @@ -2339,10 +2378,10 @@ class BatchToSpace(PrimitiveWithInfer): | |||||
| validator.check('rank of input_x', len(x_shape), '', 4) | validator.check('rank of input_x', len(x_shape), '', 4) | ||||
| out_shape = copy.deepcopy(x_shape) | out_shape = copy.deepcopy(x_shape) | ||||
| for i in range(2): | for i in range(2): | ||||
| x_block_prod = out_shape[i+2] * self.block_size | |||||
| x_block_prod = out_shape[i + 2] * self.block_size | |||||
| crops_sum = self.crops[i][0] + self.crops[i][1] | crops_sum = self.crops[i][0] + self.crops[i][1] | ||||
| validator.check("x block shape prod", x_block_prod, 'crops sum', crops_sum, Rel.GT, self.name) | validator.check("x block shape prod", x_block_prod, 'crops sum', crops_sum, Rel.GT, self.name) | ||||
| out_shape[i+2] = x_block_prod - crops_sum | |||||
| out_shape[i + 2] = x_block_prod - crops_sum | |||||
| block_size_prod = self.block_size * self.block_size | block_size_prod = self.block_size * self.block_size | ||||
| if out_shape[0] % block_size_prod != 0: | if out_shape[0] % block_size_prod != 0: | ||||
| raise ValueError(f'For \'{self.name}\' input_x dimension 0 {out_shape[0]} should be divisible by ' | raise ValueError(f'For \'{self.name}\' input_x dimension 0 {out_shape[0]} should be divisible by ' | ||||
| @@ -24,7 +24,7 @@ from ..._checkparam import Validator as validator | |||||
| from ..._checkparam import Rel | from ..._checkparam import Rel | ||||
| from ...common import dtype as mstype | from ...common import dtype as mstype | ||||
| from ...common.tensor import Tensor | from ...common.tensor import Tensor | ||||
| from .._utils import _get_broadcast_shape | |||||
| from .._utils import get_broadcast_shape | |||||
| from ..primitive import PrimitiveWithInfer, prim_attr_register, _run_op | from ..primitive import PrimitiveWithInfer, prim_attr_register, _run_op | ||||
| @@ -75,7 +75,7 @@ class _BinaryOp(PrimitiveWithInfer): | |||||
| self.init_prim_io_names(inputs=['x', 'y'], outputs=['output']) | self.init_prim_io_names(inputs=['x', 'y'], outputs=['output']) | ||||
| def infer_shape(self, x_shape, y_shape): | def infer_shape(self, x_shape, y_shape): | ||||
| return _get_broadcast_shape(x_shape, y_shape, self.name) | |||||
| return get_broadcast_shape(x_shape, y_shape, self.name) | |||||
| class _MathBinaryOp(_BinaryOp): | class _MathBinaryOp(_BinaryOp): | ||||
| @@ -27,9 +27,15 @@ class IdentityEC(IExectorComponent): | |||||
| def __call__(self): | def __call__(self): | ||||
| result_id = self.function[keyword.id] + '-' + self.inputs[keyword.id] | result_id = self.function[keyword.id] + '-' + self.inputs[keyword.id] | ||||
| group = self.function[keyword.group] + '-' + self.inputs[keyword.group] | group = self.function[keyword.group] + '-' + self.inputs[keyword.group] | ||||
| return { | |||||
| ret = { | |||||
| keyword.id: result_id, | keyword.id: result_id, | ||||
| keyword.group: group, | keyword.group: group, | ||||
| keyword.desc_inputs: self.inputs[keyword.desc_inputs], | keyword.desc_inputs: self.inputs[keyword.desc_inputs], | ||||
| keyword.result: self.function[keyword.block](*self.inputs[keyword.desc_inputs]) | keyword.result: self.function[keyword.block](*self.inputs[keyword.desc_inputs]) | ||||
| } | } | ||||
| print("buxue------------------------------------------------") | |||||
| print("inputs") | |||||
| print(ret[keyword.desc_inputs]) | |||||
| print("outputs") | |||||
| print(ret[keyword.result]) | |||||
| return ret | |||||
| @@ -1307,7 +1307,7 @@ raise_set = [ | |||||
| ('ScatterNdUpdate', { | ('ScatterNdUpdate', { | ||||
| 'block': (P.ScatterNdUpdate(), {'exception': TypeError}), | 'block': (P.ScatterNdUpdate(), {'exception': TypeError}), | ||||
| 'desc_inputs': (Tensor(np.ones((2, 3), np.float32)), | 'desc_inputs': (Tensor(np.ones((2, 3), np.float32)), | ||||
| Tensor(np.ones((2, 2), np.int32)), | |||||
| Tensor(np.ones((2, 2), np.float32)), | |||||
| Tensor(np.ones((2,), np.float32))), | Tensor(np.ones((2,), np.float32))), | ||||
| 'desc_bprop': [[2, 3]]}), | 'desc_bprop': [[2, 3]]}), | ||||
| ('Pack', { | ('Pack', { | ||||
| @@ -16,13 +16,14 @@ | |||||
| import numpy as np | import numpy as np | ||||
| import pytest | import pytest | ||||
| from mindspore import Tensor | |||||
| from mindspore import Tensor, Parameter | |||||
| from mindspore import context | from mindspore import context | ||||
| from mindspore import dtype as mstype | from mindspore import dtype as mstype | ||||
| from mindspore.nn import Cell | from mindspore.nn import Cell | ||||
| from ....mindspore_test_framework.mindspore_test import mindspore_test | from ....mindspore_test_framework.mindspore_test import mindspore_test | ||||
| from ....mindspore_test_framework.pipeline.forward.compile_forward \ | from ....mindspore_test_framework.pipeline.forward.compile_forward \ | ||||
| import pipeline_for_compile_forward_ge_graph_for_case_by_case_config | |||||
| import pipeline_for_compile_forward_ge_graph_for_case_by_case_config, \ | |||||
| pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception | |||||
| class NetWorkSlicePositive(Cell): | class NetWorkSlicePositive(Cell): | ||||
| @@ -145,6 +146,160 @@ class TensorAssignWithSlice(Cell): | |||||
| return z | return z | ||||
| class TensorIndexByOneTensor(Cell): | |||||
| def __init__(self): | |||||
| super(TensorIndexByOneTensor, self).__init__() | |||||
| self.const = Tensor(np.ones((5, 4, 7, 8)), mstype.int32) | |||||
| def construct(self, x, index): | |||||
| ret = x[index] + self.const | |||||
| return ret | |||||
| class TensorIndexByTwoTensors(Cell): | |||||
| def __init__(self): | |||||
| super(TensorIndexByTwoTensors, self).__init__() | |||||
| self.const = Tensor(np.ones((3, 4, 5, 8)), mstype.int32) | |||||
| def construct(self, x, index_0, index_1): | |||||
| ret = x[index_0, index_1] + self.const | |||||
| return ret | |||||
| class TensorIndexByThreeTensors(Cell): | |||||
| def __init__(self): | |||||
| super(TensorIndexByThreeTensors, self).__init__() | |||||
| self.const = Tensor(np.ones((5, 3, 4, 5)), mstype.int32) | |||||
| def construct(self, x, index_0, index_1, index_2): | |||||
| ret = x[index_0, index_1, index_2] + self.const | |||||
| return ret | |||||
| class TensorSetItemByOneTensorWithNumber(Cell): | |||||
| def __init__(self, value): | |||||
| super(TensorSetItemByOneTensorWithNumber, self).__init__() | |||||
| self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | |||||
| self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x") | |||||
| self.value = value | |||||
| def construct(self, index): | |||||
| self.param[index] = self.value | |||||
| ret = self.param + self.const | |||||
| return ret | |||||
| class TensorSetItemByOneTensorWithTensor(Cell): | |||||
| def __init__(self): | |||||
| super(TensorSetItemByOneTensorWithTensor, self).__init__() | |||||
| self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | |||||
| self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x") | |||||
| def construct(self, index, value): | |||||
| self.param[index] = value | |||||
| ret = self.param + self.const | |||||
| return ret | |||||
| class TensorSetItemByOneTensorWithTupleOfNumber(Cell): | |||||
| def __init__(self, value): | |||||
| super(TensorSetItemByOneTensorWithTupleOfNumber, self).__init__() | |||||
| self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | |||||
| self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x") | |||||
| self.value = value | |||||
| def construct(self, index): | |||||
| self.param[index] = self.value | |||||
| ret = self.param + self.const | |||||
| return ret | |||||
| class TensorSetItemByOneTensorWithTupleOfTensor(Cell): | |||||
| def __init__(self): | |||||
| super(TensorSetItemByOneTensorWithTupleOfTensor, self).__init__() | |||||
| self.const = Tensor(np.ones((6, 3, 8)), mstype.float32) | |||||
| self.param = Parameter(Tensor(np.arange(6*3*8).reshape((6, 3, 8)), mstype.float32), name="x") | |||||
| def construct(self, index, value_0, value_1, value_2): | |||||
| self.param[index] = (value_0, value_1, value_2) | |||||
| ret = self.param + self.const | |||||
| return ret | |||||
| class TensorSetItemByTensorsWithNumber(Cell): | |||||
| def __init__(self, value): | |||||
| super(TensorSetItemByTensorsWithNumber, self).__init__() | |||||
| self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | |||||
| self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x") | |||||
| self.value = value | |||||
| def construct(self, index_0, index_1, index_2): | |||||
| self.param[index_0, index_1, index_2] = self.value | |||||
| ret = self.param + self.const | |||||
| return ret | |||||
| class TensorSetItemByTensorsWithTensor(Cell): | |||||
| def __init__(self): | |||||
| super(TensorSetItemByTensorsWithTensor, self).__init__() | |||||
| self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | |||||
| self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x") | |||||
| def construct(self, index_0, index_1, index_2, value): | |||||
| self.param[index_0, index_1, index_2] = value | |||||
| ret = self.param + self.const | |||||
| return ret | |||||
| class TensorSetItemByTensorsWithTensorNumberError(Cell): | |||||
| def __init__(self): | |||||
| super(TensorSetItemByTensorsWithTensorNumberError, self).__init__() | |||||
| self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | |||||
| self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x") | |||||
| def construct(self, index_0, index_1, index_2, index_3, value): | |||||
| self.param[index_0, index_1, index_2, index_3] = value | |||||
| ret = self.param + self.const | |||||
| return ret | |||||
| class TensorSetItemByTensorsWithTupleOfNumber(Cell): | |||||
| def __init__(self, value): | |||||
| super(TensorSetItemByTensorsWithTupleOfNumber, self).__init__() | |||||
| self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | |||||
| self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x") | |||||
| self.value = value | |||||
| def construct(self, index_0, index_1, index_2): | |||||
| self.param[index_0, index_1, index_2] = self.value | |||||
| ret = self.param + self.const | |||||
| return ret | |||||
| class TensorSetItemByTensorsWithTupleOfTensor(Cell): | |||||
| def __init__(self): | |||||
| super(TensorSetItemByTensorsWithTupleOfTensor, self).__init__() | |||||
| self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | |||||
| self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x") | |||||
| def construct(self, index_0, index_1, index_2, value_0, value_1, value_2): | |||||
| self.param[index_0, index_1, index_2] = (value_0, value_1, value_2) | |||||
| ret = self.param + self.const | |||||
| return ret | |||||
| class TensorSetItemByTensorsWithTupleOfTensorNumberError(Cell): | |||||
| def __init__(self): | |||||
| super(TensorSetItemByTensorsWithTupleOfTensorNumberError, self).__init__() | |||||
| self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | |||||
| self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x") | |||||
| def construct(self, index_0, index_1, index_2, value_0, value_1): | |||||
| self.param[index_0, index_1, index_2] = (value_0, value_1) | |||||
| ret = self.param + self.const | |||||
| return ret | |||||
| def test_tensor_assign(): | def test_tensor_assign(): | ||||
| context.set_context(mode=context.GRAPH_MODE, save_graphs=True) | context.set_context(mode=context.GRAPH_MODE, save_graphs=True) | ||||
| net = TensorAssignWithSlice() | net = TensorAssignWithSlice() | ||||
| @@ -441,15 +596,206 @@ test_cases = [ | |||||
| 'block': NetWorkSliceEllipsis(), | 'block': NetWorkSliceEllipsis(), | ||||
| 'desc_inputs': [Tensor(np.ones([6, 7, 8, 9], np.int32))], | 'desc_inputs': [Tensor(np.ones([6, 7, 8, 9], np.int32))], | ||||
| }), | }), | ||||
| ('TensorIndexByOneTensor', { | |||||
| 'block': TensorIndexByOneTensor(), | |||||
| 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), | |||||
| Tensor(np.random.randint(6, size=(5, 4)), mstype.int32)], | |||||
| }), | |||||
| ('TensorIndexByTwoTensors', { | |||||
| 'block': TensorIndexByTwoTensors(), | |||||
| 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), | |||||
| Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(7, size=(4, 5)), mstype.int32)], | |||||
| }), | |||||
| ('TensorIndexByThreeTensors', { | |||||
| 'block': TensorIndexByThreeTensors(), | |||||
| 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), | |||||
| Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)], | |||||
| }), | |||||
| ('TensorSetItemByOneTensorWithNumber', { | |||||
| 'block': TensorSetItemByOneTensorWithNumber(value=0.0), | |||||
| 'desc_inputs': [Tensor(np.random.randint(4, size=(5, 4)), mstype.int32)], | |||||
| }), | |||||
| ('TensorSetItemByOneTensorWithTensor', { | |||||
| 'block': TensorSetItemByOneTensorWithTensor(), | |||||
| 'desc_inputs': [Tensor(np.random.randint(3, size=(5, 4)), mstype.int32), | |||||
| Tensor(np.zeros((4, 7, 8)), mstype.float32)], | |||||
| }), | |||||
| ('TensorSetItemByOneTensorWithTupleOfNumber', { | |||||
| 'block': TensorSetItemByOneTensorWithTupleOfNumber(value=(0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7)), | |||||
| 'desc_inputs': [Tensor(np.random.randint(5, size=(5, 4)), mstype.int32)], | |||||
| }), | |||||
| ('TensorSetItemByOneTensorWithTupleOfTensor', { | |||||
| 'block': TensorSetItemByOneTensorWithTupleOfTensor(), | |||||
| 'desc_inputs': [Tensor(np.random.randint(6, size=(5, 4)), mstype.int32), | |||||
| Tensor(np.zeros((8,), np.float32)), | |||||
| Tensor(np.ones((8,), np.float32)), | |||||
| Tensor(np.ones((8,), np.float32) * 2)], | |||||
| }), | |||||
| ('TensorSetItemByTensorsWithNumber', { | |||||
| 'block': TensorSetItemByTensorsWithNumber(value=0.0), | |||||
| 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)], | |||||
| }), | |||||
| ('TensorSetItemByTensorsWithTensor', { | |||||
| 'block': TensorSetItemByTensorsWithTensor(), | |||||
| 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32), | |||||
| Tensor(np.zeros((4, 5)), mstype.float32)], | |||||
| }), | |||||
| ('TensorSetItemByTensorsWithTupleOfNumber', { | |||||
| 'block': TensorSetItemByTensorsWithTupleOfNumber(value=(0.0, 1.1, 2.2, 3.3, 4.4)), | |||||
| 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)], | |||||
| }), | |||||
| ('TensorSetItemByTensorsWithTupleOfTensor', { | |||||
| 'block': TensorSetItemByTensorsWithTupleOfTensor(), | |||||
| 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32), | |||||
| Tensor(np.zeros((4, 5)), mstype.float32), | |||||
| Tensor(np.ones((4, 5)), mstype.float32), | |||||
| Tensor(np.ones((4, 5)) * 2, mstype.float32)], | |||||
| }) | |||||
| ] | |||||
| raise_error_set = [ | |||||
| ('TensorIndexByOneTensorDtypeError', { | |||||
| 'block': (TensorIndexByOneTensor(), {'exception': TypeError}), | |||||
| 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), | |||||
| Tensor(np.random.randint(6, size=(5, 4)), mstype.int8)], | |||||
| }), | |||||
| ('TensorIndexByTwoTensorsShapeError', { | |||||
| 'block': (TensorIndexByTwoTensors(), {'exception': ValueError}), | |||||
| 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), | |||||
| Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(7, size=(2, 3, 5)), mstype.int32)], | |||||
| }), | |||||
| ('TensorIndexByTwoTensorsDtypeError', { | |||||
| 'block': (TensorIndexByTwoTensors(), {'exception': TypeError}), | |||||
| 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), | |||||
| Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(7, size=(4, 5)), mstype.float32)], | |||||
| }), | |||||
| ('TensorIndexByThreeTensorsShapeError', { | |||||
| 'block': (TensorIndexByThreeTensors(), {'exception': ValueError}), | |||||
| 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), | |||||
| Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(8, size=(5, 2, 4, 5)), mstype.int32)], | |||||
| }), | |||||
| ('TensorIndexByThreeTensorsDtypeError', { | |||||
| 'block': (TensorIndexByThreeTensors(), {'exception': TypeError}), | |||||
| 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), | |||||
| Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int64), | |||||
| Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)], | |||||
| }), | |||||
| ('TensorSetItemByOneTensorWithNumberTypeError', { | |||||
| 'block': (TensorSetItemByOneTensorWithNumber(value=0), {'exception': TypeError}), | |||||
| 'desc_inputs': [Tensor(np.random.randint(4, size=(5, 4)), mstype.int32)], | |||||
| }), | |||||
| ('TensorSetItemByOneTensorWithTensorShapeError', { | |||||
| 'block': (TensorSetItemByOneTensorWithTensor(), {'exception': ValueError}), | |||||
| 'desc_inputs': [Tensor(np.random.randint(3, size=(5, 4)), mstype.int32), | |||||
| Tensor(np.zeros((6, 7, 8)), mstype.float32)], | |||||
| }), | |||||
| ('TensorSetItemByOneTensorWithTensorDtypeError', { | |||||
| 'block': (TensorSetItemByOneTensorWithTensor(), {'exception': TypeError}), | |||||
| 'desc_inputs': [Tensor(np.random.randint(3, size=(5, 4)), mstype.int32), | |||||
| Tensor(np.zeros((6, 7, 8)), mstype.int32)], | |||||
| }), | |||||
| ('TensorSetItemByOneTensorWithTupleOfNumberTypeError', { | |||||
| 'block': (TensorSetItemByOneTensorWithTupleOfNumber(value=(0, 1, 2, 3, 4, 5, 6, 7)), {'exception': TypeError}), | |||||
| 'desc_inputs': [Tensor(np.random.randint(5, size=(5, 4)), mstype.int32)], | |||||
| }), | |||||
| ('TensorSetItemByOneTensorWithTupleOfNumberNumberError', { | |||||
| 'block': (TensorSetItemByOneTensorWithTupleOfNumber(value=(0.0, 1.1, 2.2)), {'exception': ValueError}), | |||||
| 'desc_inputs': [Tensor(np.random.randint(5, size=(5, 4)), mstype.int32)], | |||||
| }), | |||||
| ('TensorSetItemByOneTensorWithTupleOfTensorDtyeError', { | |||||
| 'block': (TensorSetItemByOneTensorWithTupleOfTensor(), {'exception': TypeError}), | |||||
| 'desc_inputs': [Tensor(np.random.randint(6, size=(5, 4)), mstype.int32), | |||||
| Tensor(np.zeros((8,), np.int32)), | |||||
| Tensor(np.ones((8,), np.int32)), | |||||
| Tensor(np.ones((8,), np.float32) * 2)], | |||||
| }), | |||||
| ('TensorSetItemByTensorsWithNumberTypeError', { | |||||
| 'block': (TensorSetItemByTensorsWithNumber(value=0), {'exception': TypeError}), | |||||
| 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)], | |||||
| }), | |||||
| ('TensorSetItemByTensorsWithTensorShapeError', { | |||||
| 'block': (TensorSetItemByTensorsWithTensor(), {'exception': ValueError}), | |||||
| 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32), | |||||
| Tensor(np.zeros((2, 5)), mstype.float32)], | |||||
| }), | |||||
| ('TensorSetItemByTensorsWithTensorTypeError', { | |||||
| 'block': (TensorSetItemByTensorsWithTensor(), {'exception': TypeError}), | |||||
| 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32), | |||||
| Tensor(np.zeros((4, 5)), mstype.int32)], | |||||
| }), | |||||
| ('TensorSetItemByTensorsWithTensorNumberError', { | |||||
| 'block': (TensorSetItemByTensorsWithTensorNumberError(), {'exception': IndexError}), | |||||
| 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(8, size=(1, 3, 4, 5)), mstype.int32), | |||||
| Tensor(np.zeros((2, 5)), mstype.float32)], | |||||
| }), | |||||
| ('TensorSetItemByTensorsWithTupleOfNumberTypeError', { | |||||
| 'block': (TensorSetItemByTensorsWithTupleOfNumber(value=(0, 1, 2, 3, 4)), {'exception': TypeError}), | |||||
| 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)], | |||||
| }), | |||||
| ('TensorSetItemByTensorsWithTupleOfNumberNumberError', { | |||||
| 'block': (TensorSetItemByTensorsWithTupleOfNumber(value=(0.0, 1.0, 2.0, 3.0)), {'exception': ValueError}), | |||||
| 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)], | |||||
| }), | |||||
| ('TensorSetItemByTensorsWithTupleOfTensorNumberError', { | |||||
| 'block': (TensorSetItemByTensorsWithTupleOfTensorNumberError(), {'exception': ValueError}), | |||||
| 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32), | |||||
| Tensor(np.zeros((4, 5)), mstype.float32), | |||||
| Tensor(np.ones((4, 5)), mstype.float32)], | |||||
| }), | |||||
| ('TensorSetItemByTensorsWithTupleOfTensorTypeError', { | |||||
| 'block': (TensorSetItemByTensorsWithTupleOfTensor(), {'exception': TypeError}), | |||||
| 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), | |||||
| Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32), | |||||
| Tensor(np.zeros((4, 5)), mstype.float32), | |||||
| Tensor(np.ones((4, 5)), mstype.int32), | |||||
| Tensor(np.ones((4, 5)) * 2, mstype.int32)], | |||||
| }) | |||||
| ] | ] | ||||
| @mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config) | @mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config) | ||||
| def test_compile(): | |||||
| context.set_context(mode=context.GRAPH_MODE) | |||||
| def test_exec(): | |||||
| context.set_context(mode=context.GRAPH_MODE, save_graphs=True) | |||||
| return test_cases | return test_cases | ||||
| @mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception) | |||||
| def test_check_exception(): | |||||
| return raise_error_set | |||||
| def test_tensor_slice_reduce_out_of_bounds_neg(): | def test_tensor_slice_reduce_out_of_bounds_neg(): | ||||
| class NetWork(Cell): | class NetWork(Cell): | ||||
| def __init__(self): | def __init__(self): | ||||
| @@ -26,7 +26,7 @@ from mindspore.ops import functional as F | |||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from mindspore.ops._grad.grad_base import bprop_getters | from mindspore.ops._grad.grad_base import bprop_getters | ||||
| from mindspore.ops._grad.grad_math_ops import binop_grad_common | from mindspore.ops._grad.grad_math_ops import binop_grad_common | ||||
| from mindspore.ops._utils import _get_broadcast_shape | |||||
| from mindspore.ops._utils import get_broadcast_shape | |||||
| from mindspore.ops.primitive import PrimitiveWithInfer, prim_attr_register | from mindspore.ops.primitive import PrimitiveWithInfer, prim_attr_register | ||||
| from mindspore.train.loss_scale_manager import DynamicLossScaleManager | from mindspore.train.loss_scale_manager import DynamicLossScaleManager | ||||
| @@ -54,7 +54,7 @@ class MockSub(PrimitiveWithInfer): | |||||
| self.init_prim_io_names(inputs=['x', 'y'], outputs=['output']) | self.init_prim_io_names(inputs=['x', 'y'], outputs=['output']) | ||||
| def infer_shape(self, x_shape, y_shape): | def infer_shape(self, x_shape, y_shape): | ||||
| return _get_broadcast_shape(x_shape, y_shape) | |||||
| return get_broadcast_shape(x_shape, y_shape) | |||||
| def infer_dtype(self, x_dtype, y_dtype): | def infer_dtype(self, x_dtype, y_dtype): | ||||
| return x_dtype | return x_dtype | ||||