@@ -28,6 +28,7 @@
#include "minddata/dataset/kernels/data/slice_op.h"
#include "minddata/dataset/kernels/data/to_float16_op.h"
#include "minddata/dataset/kernels/data/type_cast_op.h"
#include "minddata/dataset/kernels/data/unique_op.h"

namespace mindspore {
namespace dataset {

@@ -42,6 +43,10 @@ PYBIND_REGISTER(
                  (void)py::class_<DuplicateOp, TensorOp, std::shared_ptr<DuplicateOp>>(*m, "DuplicateOp").def(py::init<>());
                }));

PYBIND_REGISTER(UniqueOp, 1, ([](const py::module *m) {
                  (void)py::class_<UniqueOp, TensorOp, std::shared_ptr<UniqueOp>>(*m, "UniqueOp").def(py::init<>());
                }));

PYBIND_REGISTER(
  FillOp, 1, ([](const py::module *m) {
    (void)py::class_<FillOp, TensorOp, std::shared_ptr<FillOp>>(*m, "FillOp").def(py::init<std::shared_ptr<Tensor>>());
@@ -11,4 +11,5 @@ add_library(kernels-data OBJECT
    mask_op.cc
    concatenate_op.cc
    duplicate_op.cc
    unique_op.cc
)
@@ -706,6 +706,78 @@ Status TensorVectorToBatchTensor(const std::vector<std::shared_ptr<Tensor>> &inp
  }
  return Status::OK();
}

template <typename T>
struct UniqueOpHashMap {
  using map_type = std::unordered_map<T, int32_t>;
};

template <typename T>
Status UniqueHelper(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
                    std::shared_ptr<Tensor> *output_idx, std::shared_ptr<Tensor> *output_cnt) {
  const dsize_t N = input->Size();
  RETURN_IF_NOT_OK(Tensor::CreateEmpty(input->shape(), DataType(DataType::DE_INT32), output_idx));

  typename UniqueOpHashMap<T>::map_type uniq;
  uniq.reserve(2 * N);
  auto in_iter = input->begin<T>();
  auto out_idx_iter = (*output_idx)->begin<int32_t>();
  int32_t i = 0;
  // First pass: map every element to the index of its first occurrence.
  for (; in_iter != input->end<T>(); ++in_iter, ++out_idx_iter) {
    auto it = uniq.emplace(*in_iter, i);
    *out_idx_iter = it.first->second;
    if (it.second) {
      ++i;
    }
  }

  // Scatter the unique values back into first-occurrence order.
  auto uniq_size = uniq.size();
  RETURN_IF_NOT_OK(Tensor::CreateEmpty(TensorShape({static_cast<int32_t>(uniq_size)}), input->type(), output));
  auto out_iter = (*output)->begin<T>();
  for (const auto &it : uniq) {
    *(out_iter + static_cast<ptrdiff_t>(it.second)) = it.first;
  }

  // Second pass over the index tensor: accumulate the count of each unique value.
  RETURN_IF_NOT_OK(
    Tensor::CreateEmpty(TensorShape({static_cast<int32_t>(uniq_size)}), DataType(DataType::DE_INT32), output_cnt));
  RETURN_IF_NOT_OK((*output_cnt)->Zero());
  auto out_cnt_iter = (*output_cnt)->begin<int32_t>();
  out_idx_iter = (*output_idx)->begin<int32_t>();
  for (int32_t j = 0; j < N; ++j) {
    auto idx = *(out_idx_iter + static_cast<ptrdiff_t>(j));
    ++*(out_cnt_iter + static_cast<ptrdiff_t>(idx));
  }
  return Status::OK();
}

Status Unique(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
              std::shared_ptr<Tensor> *output_idx, std::shared_ptr<Tensor> *output_cnt) {
  CHECK_FAIL_RETURN_UNEXPECTED(input->shape().Rank() == 1, "Only 1D tensors supported.");
  if (input->type() == DataType::DE_INT64) {
    RETURN_IF_NOT_OK(UniqueHelper<int64_t>(input, output, output_idx, output_cnt));
  } else if (input->type() == DataType::DE_INT32) {
    RETURN_IF_NOT_OK(UniqueHelper<int32_t>(input, output, output_idx, output_cnt));
  } else if (input->type() == DataType::DE_INT16) {
    RETURN_IF_NOT_OK(UniqueHelper<int16_t>(input, output, output_idx, output_cnt));
  } else if (input->type() == DataType::DE_INT8) {
    RETURN_IF_NOT_OK(UniqueHelper<int8_t>(input, output, output_idx, output_cnt));
  } else if (input->type() == DataType::DE_UINT64) {
    RETURN_IF_NOT_OK(UniqueHelper<uint64_t>(input, output, output_idx, output_cnt));
  } else if (input->type() == DataType::DE_UINT32) {
    RETURN_IF_NOT_OK(UniqueHelper<uint32_t>(input, output, output_idx, output_cnt));
  } else if (input->type() == DataType::DE_UINT16) {
    RETURN_IF_NOT_OK(UniqueHelper<uint16_t>(input, output, output_idx, output_cnt));
  } else if (input->type() == DataType::DE_UINT8) {
    RETURN_IF_NOT_OK(UniqueHelper<uint8_t>(input, output, output_idx, output_cnt));
  } else if (input->type() == DataType::DE_FLOAT16) {
    RETURN_IF_NOT_OK(UniqueHelper<float16>(input, output, output_idx, output_cnt));
  } else if (input->type() == DataType::DE_FLOAT32) {
    RETURN_IF_NOT_OK(UniqueHelper<float>(input, output, output_idx, output_cnt));
  } else if (input->type() == DataType::DE_FLOAT64) {
    RETURN_IF_NOT_OK(UniqueHelper<double>(input, output, output_idx, output_cnt));
  } else {
    RETURN_STATUS_UNEXPECTED("Unique op only supports numeric input.");
  }
  return Status::OK();
}
}  // namespace dataset
}  // namespace mindspore
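As an aside, here is a minimal standalone sketch of the single-pass technique UniqueHelper relies on, written against the C++ standard library only (no MindSpore Tensor API; all names are illustrative): unordered_map::emplace inserts a value only on its first occurrence and, either way, returns an iterator to the stored entry, so one pass yields both the first-occurrence index of every element and the set of unique values, and a second pass over the index array accumulates the counts.

// Standalone sketch of the unique/index/count bookkeeping; names are illustrative.
#include <cstdint>
#include <iostream>
#include <unordered_map>
#include <vector>

template <typename T>
void UniqueSketch(const std::vector<T> &in, std::vector<T> *out,
                  std::vector<int32_t> *out_idx, std::vector<int32_t> *out_cnt) {
  std::unordered_map<T, int32_t> uniq;
  uniq.reserve(2 * in.size());
  out_idx->reserve(in.size());
  int32_t next_id = 0;
  for (const T &v : in) {
    // emplace() inserts only on the first occurrence; either way, .first->second
    // is the id assigned at the first occurrence, which becomes this element's index.
    auto it = uniq.emplace(v, next_id);
    out_idx->push_back(it.first->second);
    if (it.second) {
      ++next_id;
    }
  }
  // Scatter unique values back into first-occurrence order, then count occurrences.
  out->resize(uniq.size());
  for (const auto &kv : uniq) {
    (*out)[kv.second] = kv.first;
  }
  out_cnt->assign(uniq.size(), 0);
  for (int32_t idx : *out_idx) {
    ++(*out_cnt)[idx];
  }
}

int main() {
  std::vector<int32_t> out, idx, cnt;
  UniqueSketch<int32_t>({0, 1, 2, 1, 2, 3}, &out, &idx, &cnt);
  for (auto v : out) std::cout << v << ' ';   // 0 1 2 3
  std::cout << '\n';
  for (auto v : idx) std::cout << v << ' ';   // 0 1 2 1 2 3
  std::cout << '\n';
  for (auto v : cnt) std::cout << v << ' ';   // 1 2 2 1
  std::cout << '\n';
  return 0;
}

For the input {0, 1, 2, 1, 2, 3} this prints the same unique/index/count triple that the Python test at the end of this change expects.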
@@ -19,6 +19,7 @@
#include <memory>
#include <string>
#include <vector>
#include <unordered_map>

#include "minddata/dataset/core/constants.h"
#include "minddata/dataset/core/cv_tensor.h"
#include "minddata/dataset/core/data_type.h"

@@ -176,6 +177,27 @@ Status BatchTensorToTensorVector(const std::shared_ptr<Tensor> &input, std::vect
/// \return Status ok/error
Status TensorVectorToBatchTensor(const std::vector<std::shared_ptr<Tensor>> &input, std::shared_ptr<Tensor> *output);

/// Helper method that finds the unique elements of the input tensor.
/// \tparam T type of the tensor
/// \param input[in] input 1d tensor
/// \param output[out] output tensor of unique elements
/// \param output_idx[out] output tensor of the index of each input element in the unique tensor
/// \param output_cnt[out] output tensor of the count of each unique element in the input
/// \return Status ok/error
template <typename T>
Status UniqueHelper(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
                    std::shared_ptr<Tensor> *output_idx, std::shared_ptr<Tensor> *output_cnt);

/// Find the unique elements of the input tensor, dispatching on its numeric type.
/// \param input[in] input 1d tensor
/// \param output[out] output tensor of unique elements
/// \param output_idx[out] output tensor of the index of each input element in the unique tensor
/// \param output_cnt[out] output tensor of the count of each unique element in the input
/// \return Status ok/error
Status Unique(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
              std::shared_ptr<Tensor> *output_idx, std::shared_ptr<Tensor> *output_cnt);
}  // namespace dataset
}  // namespace mindspore
@@ -0,0 +1,53 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "minddata/dataset/kernels/data/unique_op.h"

#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"

namespace mindspore {
namespace dataset {
Status UniqueOp::Compute(const TensorRow &input, TensorRow *output) {
  IO_CHECK_VECTOR(input, output);
  CHECK_FAIL_RETURN_UNEXPECTED(input.size() == 1, "Input should be one tensor");

  auto in_tensor = input[0];
  auto in_tensor_shape = in_tensor->shape();
  auto in_tensor_type = in_tensor->type();

  CHECK_FAIL_RETURN_UNEXPECTED(in_tensor_type.IsNumeric(), "Tensor type must be numeric.");
  CHECK_FAIL_RETURN_UNEXPECTED(in_tensor_shape.Rank() >= 2, "Tensor must be at least 2-D in order to do unique op.");
  CHECK_FAIL_RETURN_UNEXPECTED(
    in_tensor->Size() <= std::numeric_limits<int32_t>::max(),
    "UniqueOp does not support input tensors larger than " + std::to_string(std::numeric_limits<int32_t>::max()));

  // Flatten the batched input to 1-D before computing uniqueness.
  RETURN_IF_NOT_OK(in_tensor->Reshape(TensorShape({in_tensor->Size()})));

  std::shared_ptr<Tensor> out;
  std::shared_ptr<Tensor> out_idx;
  std::shared_ptr<Tensor> out_cnt;

  RETURN_IF_NOT_OK(Unique(in_tensor, &out, &out_idx, &out_cnt));
  output->push_back(out);
  output->push_back(out_idx);
  output->push_back(out_cnt);
  return Status::OK();
}
}  // namespace dataset
}  // namespace mindspore
@@ -0,0 +1,45 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_DATA_UNIQUE_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_DATA_UNIQUE_OP_H_

#include <limits>
#include <vector>
#include <memory>
#include <string>

#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/kernels/data/data_utils.h"

namespace mindspore {
namespace dataset {
class UniqueOp : public TensorOp {
 public:
  UniqueOp() = default;

  ~UniqueOp() override = default;

  Status Compute(const TensorRow &input, TensorRow *output) override;

  uint32_t NumOutput() override { return 0; }

  std::string Name() const override { return kUniqueOp; }
};
}  // namespace dataset
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_DATA_UNIQUE_OP_H_
@@ -125,6 +125,7 @@ constexpr char kPadEndOp[] = "PadEndOp";
constexpr char kSliceOp[] = "SliceOp";
constexpr char kToFloat16Op[] = "ToFloat16Op";
constexpr char kTypeCastOp[] = "TypeCastOp";
constexpr char kUniqueOp[] = "UniqueOp";

// other
constexpr char kCFuncOp[] = "CFuncOp";
@@ -296,6 +296,37 @@ class Duplicate(cde.DuplicateOp):
    """


class Unique(cde.UniqueOp):
    """
    Return an output tensor containing all the unique elements of the input tensor in
    the same order in which they occur in the input tensor.

    Also return an index tensor that contains the index of each element of the
    input tensor in the unique output tensor.

    Finally, return a count tensor that contains the count of each element of
    the output tensor in the input tensor.

    Note:
        Call the batch op before calling this function.

    Examples:
        >>> import mindspore.dataset.transforms.c_transforms as c_transforms
        >>>
        >>> # Data before
        >>> # | x                  |
        >>> # +--------------------+
        >>> # | [[0,1,2], [1,2,3]] |
        >>> # +--------------------+
        >>> data1 = data1.map(operations=c_transforms.Unique(), input_columns=["x"],
        >>>                   output_columns=["x", "y", "z"], column_order=["x", "y", "z"])
        >>> # Data after
        >>> # | x         | y             | z         |
        >>> # +-----------+---------------+-----------+
        >>> # | [0,1,2,3] | [0,1,2,1,2,3] | [1,2,2,1] |
        >>> # +-----------+---------------+-----------+
    """


class Compose(cde.ComposeOp):
    """
    Compose a list of transforms into a single transform.
@@ -0,0 +1,45 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing unique op in DE
"""
import numpy as np

import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as ops
def compare(array, res, idx, cnt):
    data = ds.NumpySlicesDataset([array], column_names="x")
    data = data.batch(2)
    data = data.map(operations=ops.Unique(), input_columns=["x"], output_columns=["x", "y", "z"],
                    column_order=["x", "y", "z"])
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        np.testing.assert_array_equal(res, d["x"])
        np.testing.assert_array_equal(idx, d["y"])
        np.testing.assert_array_equal(cnt, d["z"])
def test_unique_basics():
    compare([0, 1, 2, 1, 2, 3], np.array([0, 1, 2, 3]),
            np.array([0, 1, 2, 1, 2, 3]), np.array([1, 2, 2, 1]))
    compare([0.0, 1.0, 2.0, 1.0, 2.0, 3.0], np.array([0.0, 1.0, 2.0, 3.0]),
            np.array([0, 1, 2, 1, 2, 3]), np.array([1, 2, 2, 1]))
    compare([1, 1, 1, 1, 1, 1], np.array([1]),
            np.array([0, 0, 0, 0, 0, 0]), np.array([6]))


if __name__ == "__main__":
    test_unique_basics()