| @@ -37,8 +37,9 @@ | |||||
| #include "dataset/kernels/image/resize_bilinear_op.h" | #include "dataset/kernels/image/resize_bilinear_op.h" | ||||
| #include "dataset/kernels/image/resize_op.h" | #include "dataset/kernels/image/resize_op.h" | ||||
| #include "dataset/kernels/image/uniform_aug_op.h" | #include "dataset/kernels/image/uniform_aug_op.h" | ||||
| #include "dataset/kernels/data/type_cast_op.h" | |||||
| #include "dataset/kernels/data/fill_op.h" | #include "dataset/kernels/data/fill_op.h" | ||||
| #include "dataset/kernels/data/slice_op.h" | |||||
| #include "dataset/kernels/data/type_cast_op.h" | |||||
| #include "dataset/engine/datasetops/source/cifar_op.h" | #include "dataset/engine/datasetops/source/cifar_op.h" | ||||
| #include "dataset/engine/datasetops/source/image_folder_op.h" | #include "dataset/engine/datasetops/source/image_folder_op.h" | ||||
| #include "dataset/engine/datasetops/source/io_block.h" | #include "dataset/engine/datasetops/source/io_block.h" | ||||
| @@ -369,6 +370,37 @@ void bindTensorOps2(py::module *m) { | |||||
| *m, "FillOp", "Tensor operation to return tensor filled with same value as input fill value.") | *m, "FillOp", "Tensor operation to return tensor filled with same value as input fill value.") | ||||
| .def(py::init<std::shared_ptr<Tensor>>()); | .def(py::init<std::shared_ptr<Tensor>>()); | ||||
| (void)py::class_<SliceOp, TensorOp, std::shared_ptr<SliceOp>>(*m, "SliceOp", "") | |||||
| .def(py::init<bool>()) | |||||
| .def(py::init([](const py::list &py_list) { | |||||
| std::vector<dsize_t> c_list; | |||||
| for (auto l : py_list) { | |||||
| if (!l.is_none()) { | |||||
| c_list.push_back(py::reinterpret_borrow<py::int_>(l)); | |||||
| } | |||||
| } | |||||
| return std::make_shared<SliceOp>(c_list); | |||||
| })) | |||||
| .def(py::init([](const py::tuple &py_slice) { | |||||
| if (py_slice.size() != 3) { | |||||
| THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Wrong slice object")); | |||||
| } | |||||
| Slice c_slice; | |||||
| if (!py_slice[0].is_none() && !py_slice[1].is_none() && !py_slice[2].is_none()) { | |||||
| c_slice = Slice(py::reinterpret_borrow<py::int_>(py_slice[0]), py::reinterpret_borrow<py::int_>(py_slice[1]), | |||||
| py::reinterpret_borrow<py::int_>(py_slice[2])); | |||||
| } else if (py_slice[0].is_none() && py_slice[2].is_none()) { | |||||
| c_slice = Slice(py::reinterpret_borrow<py::int_>(py_slice[1])); | |||||
| } else if (!py_slice[0].is_none() && !py_slice[1].is_none()) { | |||||
| c_slice = Slice(py::reinterpret_borrow<py::int_>(py_slice[0]), py::reinterpret_borrow<py::int_>(py_slice[1])); | |||||
| } | |||||
| if (!c_slice.valid()) { | |||||
| THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Wrong slice object")); | |||||
| } | |||||
| return std::make_shared<SliceOp>(c_slice); | |||||
| })); | |||||
| (void)py::class_<RandomRotationOp, TensorOp, std::shared_ptr<RandomRotationOp>>( | (void)py::class_<RandomRotationOp, TensorOp, std::shared_ptr<RandomRotationOp>>( | ||||
| *m, "RandomRotationOp", | *m, "RandomRotationOp", | ||||
| "Tensor operation to apply RandomRotation." | "Tensor operation to apply RandomRotation." | ||||
| @@ -916,6 +916,61 @@ Status Tensor::CopyLastDimAt(const std::shared_ptr<Tensor> &src, const std::vect | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(memcpy_s(dst_addr, len, src_addr, len) == 0, "memcpy error"); | CHECK_FAIL_RETURN_UNEXPECTED(memcpy_s(dst_addr, len, src_addr, len) == 0, "memcpy error"); | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status Tensor::Slice(std::shared_ptr<Tensor> *out, const std::vector<dsize_t> &indices) { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(shape_.Rank() == 1, "Currently Slice work with rank 1 tensors only."); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(!indices.empty(), "Indices are empty, generated tensor would be empty."); | |||||
| if (type_.IsNumeric()) { | |||||
| return SliceNumeric(out, indices); | |||||
| } else { | |||||
| return SliceString(out, indices); | |||||
| } | |||||
| } | |||||
| Status Tensor::SliceNumeric(std::shared_ptr<Tensor> *out, const std::vector<dsize_t> &indices) { | |||||
| RETURN_IF_NOT_OK( | |||||
| CreateTensor(out, TensorImpl::kFlexible, TensorShape({static_cast<dsize_t>(indices.size())}), type_)); | |||||
| (*out)->GetMutableBuffer(); | |||||
| dsize_t out_index = 0; | |||||
| dsize_t dim_length = shape_[0]; | |||||
| dsize_t type_size = type_.SizeInBytes(); | |||||
| dsize_t src_start = handleNeg(indices[0], dim_length); | |||||
| uchar *dst_addr = (*out)->data_; | |||||
| dsize_t count = 1; | |||||
| for (dsize_t i = 0; i < indices.size(); i++) { | |||||
| dsize_t cur_index = handleNeg(indices[i], dim_length); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED( | |||||
| cur_index >= 0 && cur_index < dim_length, | |||||
| "Index " + std::to_string(indices[i]) + " is out of bounds [0," + std::to_string(dim_length) + ")"); | |||||
| if (i < indices.size() - 1) { | |||||
| dsize_t next_index = handleNeg(indices[i + 1], dim_length); | |||||
| if (next_index == cur_index + 1) { | |||||
| count++; | |||||
| continue; | |||||
| } | |||||
| } | |||||
| memcpy_s(dst_addr + out_index * type_size, (*out)->SizeInBytes(), data_ + src_start * type_size, count * type_size); | |||||
| out_index += count; | |||||
| if (i < indices.size() - 1) { | |||||
| src_start = handleNeg(indices[i + 1], dim_length); // next index | |||||
| } | |||||
| count = 1; | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| Status Tensor::SliceString(std::shared_ptr<Tensor> *out, const std::vector<dsize_t> &indices) { | |||||
| dsize_t dim_length = shape_[0]; | |||||
| std::vector<std::string> strings; | |||||
| for (dsize_t index : indices) { | |||||
| dsize_t cur_index = handleNeg(index, dim_length); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED( | |||||
| cur_index >= 0 && cur_index < dim_length, | |||||
| "Index " + std::to_string(index) + " is out of bounds [0," + std::to_string(dim_length) + ")"); | |||||
| std::string_view sv; | |||||
| GetItemAt(&sv, {cur_index}); | |||||
| strings.emplace_back(sv); | |||||
| } | |||||
| return CreateTensor(out, strings); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -347,6 +347,22 @@ class Tensor { | |||||
| return ss.str(); | return ss.str(); | ||||
| } | } | ||||
| // Handle negative indices. | |||||
| static inline dsize_t handleNeg(dsize_t index, dsize_t length) { return (index < 0) ? (index + length) : index; } | |||||
| // Slice tensor bases on the given indicies. Copy the sliced data into out tensor. Only rank1 tensors are supported. | |||||
| // Based on the type of tensor, SliceNumeric or SliceString will be called | |||||
| // @param out Tensor | |||||
| // @param indices vector of indices | |||||
| // @return Status error code | |||||
| Status Slice(std::shared_ptr<Tensor> *out, const std::vector<dsize_t> &indices); | |||||
| // Slice numeric tensors. | |||||
| Status SliceNumeric(std::shared_ptr<Tensor> *out, const std::vector<dsize_t> &indices); | |||||
| // Slice string tensors | |||||
| Status SliceString(std::shared_ptr<Tensor> *out, const std::vector<dsize_t> &indices); | |||||
| // Constructs numpy array from input tensor | // Constructs numpy array from input tensor | ||||
| // @param data this data is the location of python data | // @param data this data is the location of python data | ||||
| // @return Status code | // @return Status code | ||||
| @@ -5,4 +5,5 @@ add_library(kernels-data OBJECT | |||||
| one_hot_op.cc | one_hot_op.cc | ||||
| type_cast_op.cc | type_cast_op.cc | ||||
| to_float16_op.cc | to_float16_op.cc | ||||
| fill_op.cc) | |||||
| fill_op.cc | |||||
| slice_op.cc) | |||||
| @@ -0,0 +1,48 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "dataset/kernels/data/slice_op.h" | |||||
| #include "dataset/core/tensor.h" | |||||
| #include "dataset/kernels/data/data_utils.h" | |||||
| #include "dataset/kernels/tensor_op.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| Status SliceOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { | |||||
| IO_CHECK(input, output); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(input->shape().Rank() == 1, "SliceOp supports 1D Tensors only for now."); | |||||
| // if `all` flag is true, output is just the input. | |||||
| if (all_) { | |||||
| *output = input; | |||||
| return Status::OK(); | |||||
| } | |||||
| // if slice object was provided, indices should be empty. Generate indices from the slice object. | |||||
| if (slice_.valid() && indices_.empty()) { | |||||
| dsize_t len = input->shape()[0]; | |||||
| indices_ = slice_.Indices(len); | |||||
| return input->Slice(output, indices_); | |||||
| } | |||||
| // if indices are not empty, slices should be invalid, use indices_ to slice | |||||
| if (!indices_.empty() && !slice_.valid()) { | |||||
| return input->Slice(output, indices_); | |||||
| } | |||||
| RETURN_STATUS_UNEXPECTED("The indexing parameters are invalid"); | |||||
| } | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,83 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef DATASET_KERNELS_DATA_SLICE_OP_H_ | |||||
| #define DATASET_KERNELS_DATA_SLICE_OP_H_ | |||||
| #include <algorithm> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "dataset/core/tensor.h" | |||||
| #include "dataset/kernels/tensor_op.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| class Slice { | |||||
| public: | |||||
| Slice() : start_(0), stop_(0), step_(0) {} | |||||
| Slice(dsize_t start, dsize_t stop, dsize_t step) : start_(start), stop_(stop), step_(step) {} | |||||
| Slice(dsize_t start, dsize_t stop) : start_(start), stop_(stop), step_(1) {} | |||||
| explicit Slice(dsize_t stop) : start_(0), stop_(stop), step_(1) {} | |||||
| std::vector<dsize_t> Indices(dsize_t length) { | |||||
| std::vector<dsize_t> indices; | |||||
| dsize_t index = std::min(Tensor::handleNeg(start_, length), length); | |||||
| dsize_t end_index = std::min(Tensor::handleNeg(stop_, length), length); | |||||
| if (step_ > 0) { | |||||
| for (; index < end_index; index += step_) { | |||||
| indices.push_back(index); | |||||
| } | |||||
| } else { | |||||
| for (; index > end_index; index += step_) { | |||||
| indices.push_back(index); | |||||
| } | |||||
| } | |||||
| return indices; | |||||
| } | |||||
| bool valid() { return !(start_ == 0 && stop_ == 0 && step_ == 0); } | |||||
| dsize_t start_; | |||||
| dsize_t stop_; | |||||
| dsize_t step_; | |||||
| }; | |||||
| class SliceOp : public TensorOp { | |||||
| public: | |||||
| explicit SliceOp(std::vector<dsize_t> indices) : indices_(std::move(indices)) {} | |||||
| explicit SliceOp(Slice slice) : slice_(slice) {} | |||||
| explicit SliceOp(bool all) : all_(all) {} | |||||
| ~SliceOp() override = default; | |||||
| void Print(std::ostream &out) const override { out << "SliceOp"; } | |||||
| Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override; | |||||
| private: | |||||
| // only on of the following will be valid | |||||
| // given indices to slice the Tensor. Empty vector if invalid. | |||||
| std::vector<dsize_t> indices_; | |||||
| // Slice object. All start, stop and step are 0 if invalid. | |||||
| Slice slice_; | |||||
| // Flag to read all indcies in the dim. | |||||
| bool all_ = false; | |||||
| }; | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // DATASET_KERNELS_DATA_ONE_HOT_OP_H_ | |||||
| @@ -17,7 +17,8 @@ This module c_transforms provides common operations, including OneHotOp and Type | |||||
| """ | """ | ||||
| import numpy as np | import numpy as np | ||||
| import mindspore._c_dataengine as cde | import mindspore._c_dataengine as cde | ||||
| from .validators import check_num_classes, check_de_type, check_fill_value | |||||
| from .validators import check_num_classes, check_de_type, check_fill_value, check_slice_op | |||||
| from ..core.datatypes import mstype_to_detype | from ..core.datatypes import mstype_to_detype | ||||
| @@ -64,3 +65,46 @@ class TypeCast(cde.TypeCastOp): | |||||
| data_type = mstype_to_detype(data_type) | data_type = mstype_to_detype(data_type) | ||||
| self.data_type = str(data_type) | self.data_type = str(data_type) | ||||
| super().__init__(data_type) | super().__init__(data_type) | ||||
| class Slice(cde.SliceOp): | |||||
| """ | |||||
| Slice operation to extract a tensor out using the given n slices. | |||||
| The functionality of Slice is similar to NumPy indexing feature. | |||||
| (Currently only rank 1 Tensors are supported) | |||||
| Args: | |||||
| *slices: Maximum n number of objects to slice a tensor of rank n. | |||||
| One object in slices can be one of: | |||||
| 1. int: slice this index only. Negative index is supported. | |||||
| 2. slice object: slice the generated indices from the slice object. Similar to `start:stop:step`. | |||||
| 3. None: slice the whole dimension. Similar to `:` in python indexing. | |||||
| 4. Ellipses ...: slice all dimensions between the two slices. | |||||
| Examples: | |||||
| >>> # Data before | |||||
| >>> # | col | | |||||
| >>> # +---------+ | |||||
| >>> # | [1,2,3] | | |||||
| >>> # +---------| | |||||
| >>> data = data.map(operations=Slice(slice(1,3))) # slice indices 1 and 2 only | |||||
| >>> # Data after | |||||
| >>> # | col | | |||||
| >>> # +------------+ | |||||
| >>> # | [1,2] | | |||||
| >>> # +------------| | |||||
| """ | |||||
| @check_slice_op | |||||
| def __init__(self, *slices): | |||||
| dim0 = slices[0] | |||||
| if isinstance(dim0, int): | |||||
| dim0 = [dim0] | |||||
| elif dim0 is None: | |||||
| dim0 = True | |||||
| elif isinstance(dim0, slice): | |||||
| dim0 = (dim0.start, dim0.stop, dim0.step) | |||||
| elif dim0 is Ellipsis: | |||||
| dim0 = True | |||||
| super().__init__(dim0) | |||||
| @@ -15,6 +15,7 @@ | |||||
| """Validators for TensorOps. | """Validators for TensorOps. | ||||
| """ | """ | ||||
| from functools import wraps | from functools import wraps | ||||
| from mindspore._c_expression import typing | from mindspore._c_expression import typing | ||||
| # POS_INT_MIN is used to limit values from starting from 0 | # POS_INT_MIN is used to limit values from starting from 0 | ||||
| @@ -195,3 +196,20 @@ def check_de_type(method): | |||||
| return method(self, **kwargs) | return method(self, **kwargs) | ||||
| return new_method | return new_method | ||||
| def check_slice_op(method): | |||||
| """Wrapper method to check the parameters of slice.""" | |||||
| @wraps(method) | |||||
| def new_method(self, *args): | |||||
| for i, arg in enumerate(args): | |||||
| if arg is not None and arg is not Ellipsis and not isinstance(arg, (int, slice, list)): | |||||
| raise TypeError("Indexing of dim " + str(i) + "is not of valid type") | |||||
| if isinstance(arg, list): | |||||
| for a in arg: | |||||
| if not isinstance(a, int): | |||||
| raise TypeError("Index " + a + " is not an int") | |||||
| return method(self, *args) | |||||
| return new_method | |||||
| @@ -28,17 +28,13 @@ using namespace mindspore::dataset; | |||||
| namespace py = pybind11; | namespace py = pybind11; | ||||
| class MindDataTestTensorDE : public UT::Common { | class MindDataTestTensorDE : public UT::Common { | ||||
| public: | public: | ||||
| MindDataTestTensorDE() {} | |||||
| MindDataTestTensorDE() {} | |||||
| void SetUp() { | |||||
| GlobalInit(); | |||||
| } | |||||
| void SetUp() { GlobalInit(); } | |||||
| }; | }; | ||||
| TEST_F(MindDataTestTensorDE, Basics) { | TEST_F(MindDataTestTensorDE, Basics) { | ||||
| std::shared_ptr<Tensor> t = std::make_shared<Tensor>(TensorShape({2, 3}), DataType(DataType::DE_UINT64)); | std::shared_ptr<Tensor> t = std::make_shared<Tensor>(TensorShape({2, 3}), DataType(DataType::DE_UINT64)); | ||||
| ASSERT_TRUE((t->AllocateBuffer(t->SizeInBytes())).IsOk()); | ASSERT_TRUE((t->AllocateBuffer(t->SizeInBytes())).IsOk()); | ||||
| @@ -167,8 +163,7 @@ TEST_F(MindDataTestTensorDE, InsertTensor) { | |||||
| // Test the bug of Tensor::ToString will exec failed for Tensor which store bool values | // Test the bug of Tensor::ToString will exec failed for Tensor which store bool values | ||||
| TEST_F(MindDataTestTensorDE, BoolTensor) { | TEST_F(MindDataTestTensorDE, BoolTensor) { | ||||
| std::shared_ptr<Tensor> t = std::make_shared<Tensor>(TensorShape({2}), | |||||
| DataType(DataType::DE_BOOL)); | |||||
| std::shared_ptr<Tensor> t = std::make_shared<Tensor>(TensorShape({2}), DataType(DataType::DE_BOOL)); | |||||
| t->SetItemAt<bool>({0}, true); | t->SetItemAt<bool>({0}, true); | ||||
| t->SetItemAt<bool>({1}, true); | t->SetItemAt<bool>({1}, true); | ||||
| std::string out = t->ToString(); | std::string out = t->ToString(); | ||||
| @@ -255,14 +250,19 @@ void checkCvMat(TensorShape shape, DataType type) { | |||||
| } else { | } else { | ||||
| ASSERT_EQ(m.size[0], shape[0]); | ASSERT_EQ(m.size[0], shape[0]); | ||||
| } | } | ||||
| if (shape.Rank() == 3) { ASSERT_EQ(m.channels(), shape[2]); } | |||||
| if (shape.Rank() == 3) { | |||||
| ASSERT_EQ(m.channels(), shape[2]); | |||||
| } | |||||
| ASSERT_EQ(m.dims, 2); | ASSERT_EQ(m.dims, 2); | ||||
| ASSERT_EQ(m.size.dims(), 2); | ASSERT_EQ(m.size.dims(), 2); | ||||
| if (shape.Rank() > 0) { ASSERT_EQ(m.rows, shape[0]); } | |||||
| if (shape.Rank() > 1) { ASSERT_EQ(m.cols, shape[1]); } | |||||
| if (shape.Rank() > 0) { | |||||
| ASSERT_EQ(m.rows, shape[0]); | |||||
| } | |||||
| if (shape.Rank() > 1) { | |||||
| ASSERT_EQ(m.cols, shape[1]); | |||||
| } | |||||
| } else { | } else { | ||||
| for (dsize_t i = 0; i < shape.Rank(); i++) | |||||
| ASSERT_EQ(m.size[static_cast<int>(i)], shape[i]); | |||||
| for (dsize_t i = 0; i < shape.Rank(); i++) ASSERT_EQ(m.size[static_cast<int>(i)], shape[i]); | |||||
| ASSERT_EQ(m.dims, shape.Rank()); | ASSERT_EQ(m.dims, shape.Rank()); | ||||
| ASSERT_EQ(m.size.dims(), shape.Rank()); | ASSERT_EQ(m.size.dims(), shape.Rank()); | ||||
| ASSERT_EQ(m.rows, -1); | ASSERT_EQ(m.rows, -1); | ||||
| @@ -394,3 +394,16 @@ TEST_F(MindDataTestTensorDE, TensorIterator) { | |||||
| } | } | ||||
| ASSERT_TRUE(ctr == 6); | ASSERT_TRUE(ctr == 6); | ||||
| } | } | ||||
| TEST_F(MindDataTestTensorDE, TensorSlice) { | |||||
| std::shared_ptr<Tensor> t; | |||||
| Tensor::CreateTensor(&t, std::vector<dsize_t>{0, 1, 2, 3, 4}); | |||||
| std::shared_ptr<Tensor> t2; | |||||
| auto x = std::vector<dsize_t>{0, 3, 4}; | |||||
| std::shared_ptr<Tensor> expected; | |||||
| Tensor::CreateTensor(&expected, x); | |||||
| t->Slice(&t2, x); | |||||
| ASSERT_EQ(*t2, *expected); | |||||
| t->Slice(&t2, std::vector<dsize_t>{0, 1, 2, 3, 4}); | |||||
| ASSERT_EQ(*t2, *t); | |||||
| } | |||||
| @@ -0,0 +1,211 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================== | |||||
| """ | |||||
| Testing TypeCast op in DE | |||||
| """ | |||||
| import numpy as np | |||||
| import pytest | |||||
| import mindspore.dataset as ds | |||||
| import mindspore.dataset.transforms.c_transforms as ops | |||||
| def slice_compare(array, indexing): | |||||
| data = ds.NumpySlicesDataset([array]) | |||||
| array = np.array(array) | |||||
| data = data.map(operations=ops.Slice(indexing)) | |||||
| for d in data: | |||||
| if indexing is None: | |||||
| array = array[:] | |||||
| else: | |||||
| array = array[indexing] | |||||
| np.testing.assert_array_equal(array, d[0]) | |||||
| def test_slice_all(): | |||||
| slice_compare([1, 2, 3, 4, 5], None) | |||||
| slice_compare([1, 2, 3, 4, 5], ...) | |||||
| def test_slice_single_index(): | |||||
| slice_compare([1, 2, 3, 4, 5], 0) | |||||
| slice_compare([1, 2, 3, 4, 5], 4) | |||||
| slice_compare([1, 2, 3, 4, 5], 2) | |||||
| slice_compare([1, 2, 3, 4, 5], -1) | |||||
| slice_compare([1, 2, 3, 4, 5], -5) | |||||
| slice_compare([1, 2, 3, 4, 5], -3) | |||||
| def test_slice_list_index(): | |||||
| slice_compare([1, 2, 3, 4, 5], [0, 1, 4]) | |||||
| slice_compare([1, 2, 3, 4, 5], [4, 1, 0]) | |||||
| slice_compare([1, 2, 3, 4, 5], [-1, 1, 0]) | |||||
| slice_compare([1, 2, 3, 4, 5], [-1, -4, -2]) | |||||
| slice_compare([1, 2, 3, 4, 5], [3, 3, 3]) | |||||
| slice_compare([1, 2, 3, 4, 5], [1, 1, 1, 1, 1]) | |||||
| def test_slice_slice_obj_2s(): | |||||
| slice_compare([1, 2, 3, 4, 5], slice(0, 2)) | |||||
| slice_compare([1, 2, 3, 4, 5], slice(2, 4)) | |||||
| slice_compare([1, 2, 3, 4, 5], slice(4, 10)) | |||||
| def test_slice_slice_obj_1s(): | |||||
| slice_compare([1, 2, 3, 4, 5], slice(1)) | |||||
| slice_compare([1, 2, 3, 4, 5], slice(4)) | |||||
| slice_compare([1, 2, 3, 4, 5], slice(10)) | |||||
| def test_slice_slice_obj_3s(): | |||||
| slice_compare([1, 2, 3, 4, 5], slice(0, 2, 1)) | |||||
| slice_compare([1, 2, 3, 4, 5], slice(0, 4, 1)) | |||||
| slice_compare([1, 2, 3, 4, 5], slice(0, 10, 1)) | |||||
| slice_compare([1, 2, 3, 4, 5], slice(0, 5, 2)) | |||||
| slice_compare([1, 2, 3, 4, 5], slice(0, 2, 2)) | |||||
| slice_compare([1, 2, 3, 4, 5], slice(0, 1, 2)) | |||||
| slice_compare([1, 2, 3, 4, 5], slice(4, 5, 1)) | |||||
| slice_compare([1, 2, 3, 4, 5], slice(2, 5, 3)) | |||||
| def test_slice_slice_obj_3s_double(): | |||||
| slice_compare([1., 2., 3., 4., 5.], slice(0, 2, 1)) | |||||
| slice_compare([1., 2., 3., 4., 5.], slice(0, 4, 1)) | |||||
| slice_compare([1., 2., 3., 4., 5.], slice(0, 10, 1)) | |||||
| slice_compare([1., 2., 3., 4., 5.], slice(0, 5, 2)) | |||||
| slice_compare([1., 2., 3., 4., 5.], slice(0, 2, 2)) | |||||
| slice_compare([1., 2., 3., 4., 5.], slice(0, 1, 2)) | |||||
| slice_compare([1., 2., 3., 4., 5.], slice(4, 5, 1)) | |||||
| slice_compare([1., 2., 3., 4., 5.], slice(2, 5, 3)) | |||||
| def test_slice_slice_obj_neg(): | |||||
| slice_compare([1, 2, 3, 4, 5], slice(-1, -5, -1)) | |||||
| slice_compare([1, 2, 3, 4, 5], slice(-1)) | |||||
| slice_compare([1, 2, 3, 4, 5], slice(-2)) | |||||
| slice_compare([1, 2, 3, 4, 5], slice(-1, -5, -2)) | |||||
| slice_compare([1, 2, 3, 4, 5], slice(-5, -1, 2)) | |||||
| slice_compare([1, 2, 3, 4, 5], slice(-5, -1)) | |||||
| def test_slice_exceptions(): | |||||
| with pytest.raises(RuntimeError) as info: | |||||
| slice_compare([1, 2, 3, 4, 5], 5) | |||||
| assert "Index 5 is out of bounds [0,5)" in str(info.value) | |||||
| with pytest.raises(RuntimeError) as info: | |||||
| slice_compare([1, 2, 3, 4, 5], slice(0)) | |||||
| assert "Indices are empty, generated tensor would be empty." in str(info.value) | |||||
| with pytest.raises(RuntimeError) as info: | |||||
| slice_compare([1, 2, 3, 4, 5], slice(5, 10, 1)) | |||||
| assert "Indices are empty, generated tensor would be empty." in str(info.value) | |||||
| with pytest.raises(RuntimeError) as info: | |||||
| slice_compare([1, 2, 3, 4, 5], slice(-1, -5, 1)) | |||||
| assert "Indices are empty, generated tensor would be empty." in str(info.value) | |||||
| def test_slice_all_str(): | |||||
| slice_compare([b"1", b"2", b"3", b"4", b"5"], None) | |||||
| slice_compare([b"1", b"2", b"3", b"4", b"5"], ...) | |||||
| def test_slice_single_index_str(): | |||||
| slice_compare([b"1", b"2", b"3", b"4", b"5"], 0) | |||||
| slice_compare([b"1", b"2", b"3", b"4", b"5"], 4) | |||||
| slice_compare([b"1", b"2", b"3", b"4", b"5"], 2) | |||||
| slice_compare([b"1", b"2", b"3", b"4", b"5"], -1) | |||||
| slice_compare([b"1", b"2", b"3", b"4", b"5"], -5) | |||||
| slice_compare([b"1", b"2", b"3", b"4", b"5"], -3) | |||||
| def test_slice_list_index_str(): | |||||
| slice_compare([b"1", b"2", b"3", b"4", b"5"], [0, 1, 4]) | |||||
| slice_compare([b"1", b"2", b"3", b"4", b"5"], [4, 1, 0]) | |||||
| slice_compare([b"1", b"2", b"3", b"4", b"5"], [-1, 1, 0]) | |||||
| slice_compare([b"1", b"2", b"3", b"4", b"5"], [-1, -4, -2]) | |||||
| slice_compare([b"1", b"2", b"3", b"4", b"5"], [3, 3, 3]) | |||||
| slice_compare([b"1", b"2", b"3", b"4", b"5"], [1, 1, 1, 1, 1]) | |||||
| def test_slice_slice_obj_2s_str(): | |||||
| slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(0, 2)) | |||||
| slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(2, 4)) | |||||
| slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(4, 10)) | |||||
| def test_slice_slice_obj_1s_str(): | |||||
| slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(1)) | |||||
| slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(4)) | |||||
| slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(10)) | |||||
| def test_slice_slice_obj_3s_str(): | |||||
| slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(0, 2, 1)) | |||||
| slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(0, 4, 1)) | |||||
| slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(0, 10, 1)) | |||||
| slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(0, 5, 2)) | |||||
| slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(0, 2, 2)) | |||||
| slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(0, 1, 2)) | |||||
| slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(4, 5, 1)) | |||||
| slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(2, 5, 3)) | |||||
| def test_slice_slice_obj_neg_str(): | |||||
| slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(-1, -5, -1)) | |||||
| slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(-1)) | |||||
| slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(-2)) | |||||
| slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(-1, -5, -2)) | |||||
| slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(-5, -1, 2)) | |||||
| slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(-5, -1)) | |||||
| def test_slice_exceptions_str(): | |||||
| with pytest.raises(RuntimeError) as info: | |||||
| slice_compare([b"1", b"2", b"3", b"4", b"5"], 5) | |||||
| assert "Index 5 is out of bounds [0,5)" in str(info.value) | |||||
| with pytest.raises(RuntimeError) as info: | |||||
| slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(0)) | |||||
| assert "Indices are empty, generated tensor would be empty." in str(info.value) | |||||
| with pytest.raises(RuntimeError) as info: | |||||
| slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(5, 10, 1)) | |||||
| assert "Indices are empty, generated tensor would be empty." in str(info.value) | |||||
| with pytest.raises(RuntimeError) as info: | |||||
| slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(-1, -5, 1)) | |||||
| assert "Indices are empty, generated tensor would be empty." in str(info.value) | |||||
| if __name__ == "__main__": | |||||
| test_slice_all() | |||||
| test_slice_single_index() | |||||
| test_slice_list_index() | |||||
| test_slice_slice_obj_3s() | |||||
| test_slice_slice_obj_2s() | |||||
| test_slice_slice_obj_1s() | |||||
| test_slice_slice_obj_neg() | |||||
| test_slice_exceptions() | |||||
| test_slice_slice_obj_3s_double() | |||||
| test_slice_all_str() | |||||
| test_slice_single_index_str() | |||||
| test_slice_list_index_str() | |||||
| test_slice_slice_obj_3s_str() | |||||
| test_slice_slice_obj_2s_str() | |||||
| test_slice_slice_obj_1s_str() | |||||
| test_slice_slice_obj_neg_str() | |||||
| test_slice_exceptions_str() | |||||