| @@ -19,6 +19,7 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| constexpr int MAX_DIMS = 8; | |||
| void SliceCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| CheckParam(kernel_node); | |||
| input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| @@ -205,7 +206,7 @@ void SliceCPUKernel::CheckParam(const CNodePtr &kernel_node) const { | |||
| MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but SliceCPUKernel needs 1 output."; | |||
| } | |||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| if (input_shape.size() > 4) { | |||
| if (input_shape.size() > MAX_DIMS) { | |||
| MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() << ", but SliceCPUKernel olny support 4d or lower."; | |||
| } | |||
| if (input_shape.size() == 0) { | |||
| @@ -26,7 +26,7 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| constexpr int MAX_DIMS = 7; | |||
| constexpr int MAX_DIMS = 8; | |||
| template <typename T> | |||
| class StridedSliceGpuKernel : public GpuKernel { | |||
| public: | |||
| @@ -51,7 +51,8 @@ class StridedSliceGpuKernel : public GpuKernel { | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| if (input_shape_.size() > MAX_DIMS) { | |||
| MS_LOG(ERROR) << "StridedSlice support support dims less than " << input_shape_.size(); | |||
| MS_LOG(ERROR) << "StridedSlice support dims no more than " << MAX_DIMS << ", but the input shape is " | |||
| << input_shape_.size(); | |||
| return false; | |||
| } | |||
| @@ -255,28 +255,38 @@ def tensor_index_by_tensor(data, tensor_index): | |||
| "the index tensor data type only support mstype.int32.") | |||
| def _tensor_index_by_tuple_slice(data, t): | |||
| def _tensor_index_by_tuple_slice(data, tuple_index): | |||
| """Tensor getitem by a tuple of slice""" | |||
| shape = F.shape(data) | |||
| if len(t) > len(shape): | |||
| if len(tuple_index) > len(shape): | |||
| const_utils.raise_index_error("When tensor is indexed by a tuple, " | |||
| "the length of the tuple cannot be greater than the dimension of the tensor.") | |||
| begin_strides, end_strides, step_strides, shrink_axis_mask = \ | |||
| const_utils.get_stride_info_from_tuple(shape, t) | |||
| const_utils.get_stride_info_from_tuple(shape, tuple_index) | |||
| return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides) | |||
| def tensor_expand_dims(data, tuple_index): | |||
| """Expand tensor dims by tuple contains None and replace the None by slice in tuple_index """ | |||
| none_positions, tuple_index_without_none = const_utils.split_tuple_index_for_none(tuple_index) | |||
| for position in none_positions: | |||
| data = F.expand_dims(data, position) | |||
| return data, tuple_index_without_none | |||
| def tensor_index_by_tuple(data, tuple_index): | |||
| """Tensor getitem by tuple of various types""" | |||
| """Tensor getitem by tuple of various types with None""" | |||
| # data, tuple_index_without_none = tensor_expand_dims(data, tuple_index) | |||
| tuple_index_without_none = tuple_index | |||
| if len(tuple_index) == 1: | |||
| return data[tuple_index[0]] | |||
| indexes_types = hyper_map(F.typeof, tuple_index) | |||
| index_elements_type = const_utils.tuple_index_tensor_cnt(indexes_types, const_utils.TENSOR_GETITEM) | |||
| if index_elements_type == const_utils.NO_TENSOR: | |||
| return _tensor_index_by_tuple_slice(data, tuple_index) | |||
| if index_elements_type == const_utils.ALL_TENSOR: | |||
| return _tensor_getitem_by_tuple_of_tensor(data, tuple_index) | |||
| return _tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index) | |||
| return data[tuple_index_without_none[0]] | |||
| indexes_types = hyper_map(F.typeof, tuple_index_without_none) | |||
| tensor_cnt = const_utils.tuple_index_tensor_cnt(indexes_types, const_utils.TENSOR_GETITEM) | |||
| if tensor_cnt == const_utils.NO_TENSOR: | |||
| return _tensor_index_by_tuple_slice(data, tuple_index_without_none) | |||
| if tensor_cnt == const_utils.ALL_TENSOR: | |||
| return _tensor_getitem_by_tuple_of_tensor(data, tuple_index_without_none) | |||
| return _tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index_without_none) | |||
| def _tensor_setitem(self, index, value): | |||
| @@ -66,6 +66,19 @@ def check_equal(param1, param2, msg="{},{}"): | |||
| return param1 | |||
| @constexpr | |||
| def split_tuple_index_for_none(tuple_index): | |||
| """return the none_positions and the tuple_index_without_none whose None index is replaced by slice.""" | |||
| none_positions, tuple_index_without_none = (), () | |||
| for idx, item in enumerate(tuple_index): | |||
| if item is None: | |||
| none_positions += (idx,) | |||
| tuple_index_without_none += (slice(None, None, None),) | |||
| else: | |||
| tuple_index_without_none += (item,) | |||
| return none_positions, tuple_index_without_none | |||
| @constexpr | |||
| def check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size): | |||
| """Checks the shape and size of the sensor and value.""" | |||
| @@ -75,35 +88,6 @@ def check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size): | |||
| value_shape, data_shape)) | |||
| @constexpr | |||
| def restrict_int_index(data_shape, tuple_indexes): | |||
| """ | |||
| Check the int index of tuple_indexes if value of index is out of the corresponding data shape | |||
| and turn the negtive int index to positive int index. | |||
| Inputs: | |||
| data_shape: the shape of data. | |||
| tuple_indexes(tuple[mstype.int32]): the tuple of index which will be used in setitem or getitem. | |||
| Outputs: | |||
| tuple_indexes_new(tuple[mstype.int32]): same purpose with tuple_indexes but only contain positive. | |||
| """ | |||
| if tuple_indexes is None: | |||
| return tuple_indexes | |||
| tuple_indexes_new = () | |||
| for i, index in enumerate(tuple_indexes): | |||
| if isinstance(index, mstype.Int): | |||
| if index < -data_shape[i] or index >= data_shape[i]: | |||
| raise_index_error("The index is out of the data's special dimension range.") | |||
| elif index < 0: | |||
| tuple_indexes_new += (tuple_indexes[i]+data_shape[i],) | |||
| else: | |||
| tuple_indexes_new += (tuple_indexes[i],) | |||
| else: | |||
| tuple_indexes_new += (tuple_indexes[i],) | |||
| return tuple_indexes_new | |||
| @constexpr | |||
| def check_tensor_setitem_index(index, element_type=None): | |||
| """Checks tuple index type of tensor assignment.""" | |||
| @@ -213,7 +213,7 @@ def _tensor_getitem_by_tuple(data, tuple_index): | |||
| Inputs: | |||
| data (Tensor): A tensor. | |||
| tuple_index (tuple): Index in tuple. | |||
| tuple_index (tuple): Index in tuple which include ellipsis, slice, int, Tensor, None, list, tuple. | |||
| Outputs: | |||
| Tensor, element type is the same as the element type of data. | |||