@@ -38,6 +38,7 @@
#include "dataset/kernels/image/resize_op.h"
#include "dataset/kernels/image/uniform_aug_op.h"
#include "dataset/kernels/data/type_cast_op.h"
#include "dataset/kernels/data/fill_op.h"
#include "dataset/engine/datasetops/source/cifar_op.h"
#include "dataset/engine/datasetops/source/image_folder_op.h"
#include "dataset/engine/datasetops/source/io_block.h"
@@ -350,6 +351,10 @@ void bindTensorOps2(py::module *m) {
    *m, "OneHotOp", "Tensor operation to apply one hot encoding. Takes number of classes.")
    .def(py::init<int32_t>());

  (void)py::class_<FillOp, TensorOp, std::shared_ptr<FillOp>>(
    *m, "FillOp", "Tensor operation to return a tensor of the input's shape filled with the supplied fill value.")
    .def(py::init<std::shared_ptr<Tensor>>());

  (void)py::class_<RandomRotationOp, TensorOp, std::shared_ptr<RandomRotationOp>>(
    *m, "RandomRotationOp",
    "Tensor operation to apply RandomRotation."
@@ -5,4 +5,4 @@ add_library(kernels-data OBJECT
    one_hot_op.cc
    type_cast_op.cc
    to_float16_op.cc
    fill_op.cc)
@@ -23,6 +23,7 @@
#include "dataset/core/tensor_shape.h"
#include "dataset/core/data_type.h"
#include "dataset/core/pybind_support.h"
#include "dataset/kernels/data/type_cast_op.h"

namespace mindspore {
namespace dataset {

@@ -78,6 +79,7 @@ Status OneHotEncodingSigned(const std::shared_ptr<Tensor> &input, std::shared_pt
Status OneHotEncoding(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, dsize_t num_classes) {
  input->Squeeze();

  if (input->Rank() > 1) {  // We expect the input to be in the first dimension
    RETURN_STATUS_UNEXPECTED("One hot only supports scalars or 1D shape Tensors.");
  }

@@ -106,11 +108,121 @@ Status OneHotEncoding(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *ou
  }
}

Status Fill(const std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, std::shared_ptr<Tensor> fill_value) {
  CHECK_FAIL_RETURN_UNEXPECTED(!((fill_value->type() == DataType::DE_STRING) && (input->type() != DataType::DE_STRING)),
                               "Types do not match");
  CHECK_FAIL_RETURN_UNEXPECTED(fill_value->shape() == TensorShape({}), "fill_value is not a scalar");

  std::shared_ptr<Tensor> out;
  const DataType &to = input->type();
  std::unique_ptr<TypeCastOp> op(new TypeCastOp(to));
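  // The scalar fill value is first cast to the input's type; the cast result feeds the
  // numeric branches below, while the DE_STRING branch reads fill_value directly.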
  std::shared_ptr<Tensor> fill_output;
  op->Compute(fill_value, &fill_output);

  RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, TensorImpl::kFlexible, input->shape(), input->type()));

  switch (input->type().value()) {
    case DataType::DE_BOOL: {
      bool value = false;
      RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
      out->Fill<bool>(value);
      break;
    }
    case DataType::DE_INT8: {
      int8_t value = 0;
      RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
      out->Fill<int8_t>(value);
      break;
    }
    case DataType::DE_UINT8: {
      uint8_t value = 0;
      RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
      out->Fill<uint8_t>(value);
      break;
    }
    case DataType::DE_UINT16: {
      uint16_t value = 0;
      RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
      out->Fill<uint16_t>(value);
      break;
    }
    case DataType::DE_INT16: {
      int16_t value = 0;
      RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
      out->Fill<int16_t>(value);
      break;
    }
    case DataType::DE_UINT32: {
      uint32_t value = 0;
      RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
      out->Fill<uint32_t>(value);
      break;
    }
    case DataType::DE_INT32: {
      int32_t value = 0;
      RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
      out->Fill<int32_t>(value);
      break;
    }
    case DataType::DE_UINT64: {
      uint64_t value = 0;
      RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
      out->Fill<uint64_t>(value);
      break;
    }
    case DataType::DE_INT64: {
      int64_t value = 0;
      RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
      out->Fill<int64_t>(value);
      break;
    }
    case DataType::DE_FLOAT16: {
      float value = 0;
      RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
      out->Fill<float16>(float16(value));
      break;
    }
    case DataType::DE_FLOAT32: {
      float value = 0;
      RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
      out->Fill<float>(value);
      break;
    }
    case DataType::DE_FLOAT64: {
      double value = 0;
      RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
      out->Fill<double>(value);
      break;
    }
    case DataType::DE_STRING: {
      std::vector<std::string> strings;
      std::string_view fill_string_view;
      RETURN_IF_NOT_OK(fill_value->GetItemAt(&fill_string_view, {}));
      std::string fill_string = std::string(fill_string_view);
      for (int i = 0; i < input->shape().NumOfElements(); i++) {
        strings.emplace_back(fill_string);
      }
      RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, strings, input->shape()));
      break;
    }
    case DataType::DE_UNKNOWN: {
      RETURN_STATUS_UNEXPECTED("FillOp does not support input of this type.");
      break;
    }
  }

  *output = out;
  return Status::OK();
}

template <typename FROM, typename TO>
void Cast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
  auto in_itr = input->begin<FROM>();
  auto out_itr = (*output)->begin<TO>();
  auto out_end = (*output)->end<TO>();

  for (; out_itr != out_end; static_cast<void>(in_itr++), static_cast<void>(out_itr++))
    *out_itr = static_cast<TO>(*in_itr);
}

@@ -43,6 +43,13 @@ Status OneHotEncodingUnsigned(const std::shared_ptr<Tensor> &input, std::shared_
Status OneHotEncodingSigned(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, dsize_t num_classes,
                            int64_t index);

// Returns a tensor of the same shape as the input, filled with the passed fill_value
// @param input Tensor
// @param output Tensor. The shape and type of the output tensor are the same as the input
// @param fill_value Tensor. A scalar tensor used to fill the output tensor
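// Example: an input of shape {3} and type DE_UINT64 with a fill_value holding 4 produces
//          the tensor {4, 4, 4} of type DE_UINT64 (see fill_op_test.cc)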
Status Fill(const std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, std::shared_ptr<Tensor> fill_value);

// Returns a type changed input tensor.
// Example: if input tensor is float64, the output will be the specified dataType. See DataTypes.cpp
// @param input Tensor

@@ -0,0 +1,31 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "dataset/kernels/data/fill_op.h"

#include "dataset/core/tensor.h"
#include "dataset/kernels/data/data_utils.h"
#include "dataset/kernels/tensor_op.h"

namespace mindspore {
namespace dataset {

Status FillOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
  IO_CHECK(input, output);
  Status s = Fill(input, output, fill_value_);
  return s;
}

}  // namespace dataset
}  // namespace mindspore

@@ -0,0 +1,47 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef DATASET_KERNELS_DATA_FILL_OP_H_
#define DATASET_KERNELS_DATA_FILL_OP_H_

#include <memory>
#include <string>
#include <vector>

#include "dataset/core/tensor.h"
#include "dataset/kernels/tensor_op.h"

namespace mindspore {
namespace dataset {
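
// TensorOp that produces a tensor with the input's shape and type, with every element set to
// the scalar fill value supplied at construction (see Fill() in data_utils.h).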
class FillOp : public TensorOp {
 public:
  explicit FillOp(std::shared_ptr<Tensor> value) : fill_value_(value) {}

  ~FillOp() override = default;

  void Print(std::ostream &out) const override { out << "FillOp"; }

  Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;

 private:
  std::shared_ptr<Tensor> fill_value_;
};

}  // namespace dataset
}  // namespace mindspore

#endif  // DATASET_KERNELS_DATA_FILL_OP_H_

@@ -15,9 +15,9 @@
"""
This module c_transforms provides common operations, including OneHotOp and TypeCast.
"""
import numpy as np

import mindspore._c_dataengine as cde

from .validators import check_num_classes, check_de_type, check_fill_value
from ..core.datatypes import mstype_to_detype

@@ -35,6 +35,22 @@ class OneHot(cde.OneHotOp):
        super().__init__(num_classes)


class Fill(cde.FillOp):
    """
    Tensor operation to create a tensor filled with the passed scalar value.
    The output tensor will have the same shape and type as the input tensor.

    Args:
        fill_value (str, int, float or bool): scalar value
            to fill the created tensor with.
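
    Examples:
        >>> # Illustrative usage; assumes `data` is a dataset with a numeric column "col".
        >>> import mindspore.dataset.transforms.c_transforms as c_transforms
        >>> fill_op = c_transforms.Fill(3)
        >>> data = data.map(input_columns=["col"], operations=fill_op)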
| """ | |||||
| @check_fill_value | |||||
| def __init__(self, fill_value): | |||||
| print(fill_value) | |||||
| super().__init__(cde.Tensor(np.array(fill_value))) | |||||
| class TypeCast(cde.TypeCastOp): | class TypeCast(cde.TypeCastOp): | ||||
| """ | """ | ||||
| Tensor operation to cast to a given MindSpore data type. | Tensor operation to cast to a given MindSpore data type. | ||||

@@ -17,7 +17,6 @@
from functools import wraps

from mindspore._c_expression import typing

# POS_INT_MIN is used to limit values from starting from 0
POS_INT_MIN = 1
UINT8_MAX = 255

@@ -159,6 +158,25 @@ def check_num_classes(method):
    return new_method


def check_fill_value(method):
    """Wrapper method to check the parameters of fill value."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        fill_value = (list(args) + [None])[0]
        if "fill_value" in kwargs:
            fill_value = kwargs.get("fill_value")
        if fill_value is None:
            raise ValueError("fill_value is not provided.")
        if not isinstance(fill_value, (str, float, bool, int)):
            raise TypeError("fill_value must be either a primitive python str, float, bool, or int")
        kwargs["fill_value"] = fill_value

        return method(self, **kwargs)

    return new_method


def check_de_type(method):
    """Wrapper method to check the parameters of data type."""

@@ -72,6 +72,7 @@ SET(DE_UT_SRCS
    tokenizer_op_test.cc
    gnn_graph_test.cc
    coco_op_test.cc
    fill_op_test.cc
    )

add_executable(de_ut_tests ${DE_UT_SRCS})

@@ -0,0 +1,183 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "common/common.h"
#include "dataset/kernels/data/fill_op.h"
#include "utils/log_adapter.h"

using namespace mindspore::dataset;
using mindspore::LogStream;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::MsLogLevel::INFO;

class MindDataTestFillOp : public UT::Common {
 protected:
  MindDataTestFillOp() {}
};

TEST_F(MindDataTestFillOp, TestOp) {
  MS_LOG(INFO) << "Doing MindDataTestFillOp-TestOp.";
  uint64_t labels[3] = {1, 1, 2};
  TensorShape shape({3});
  std::shared_ptr<Tensor> input =
    std::make_shared<Tensor>(shape, DataType(DataType::DE_UINT64), reinterpret_cast<unsigned char *>(labels));

  TensorShape fill_shape({});
  std::shared_ptr<Tensor> fill_tensor = std::make_shared<Tensor>(fill_shape, DataType(DataType::DE_UINT64));
  fill_tensor->SetItemAt<uint64_t>({}, 4);

  std::shared_ptr<Tensor> output;
  std::unique_ptr<FillOp> op(new FillOp(fill_tensor));
  Status s = op->Compute(input, &output);

  uint64_t out[3] = {4, 4, 4};
  std::shared_ptr<Tensor> expected =
    std::make_shared<Tensor>(TensorShape{3}, DataType(DataType::DE_UINT64), reinterpret_cast<unsigned char *>(out));

  EXPECT_TRUE(s.IsOk());
  ASSERT_TRUE(output->shape() == expected->shape());
  ASSERT_TRUE(output->type() == expected->type());
  MS_LOG(DEBUG) << *output << std::endl;
  MS_LOG(DEBUG) << *expected << std::endl;
  ASSERT_TRUE(*output == *expected);

  MS_LOG(INFO) << "MindDataTestFillOp-TestOp end.";
}

TEST_F(MindDataTestFillOp, TestCasting) {
  MS_LOG(INFO) << "Doing MindDataTestFillOp-TestCasting.";
  uint64_t labels[3] = {0, 1, 2};
  TensorShape shape({3});
  std::shared_ptr<Tensor> input =
    std::make_shared<Tensor>(shape, DataType(DataType::DE_UINT64), reinterpret_cast<unsigned char *>(labels));

  TensorShape fill_shape({});
  std::shared_ptr<Tensor> fill_tensor = std::make_shared<Tensor>(fill_shape, DataType(DataType::DE_FLOAT32));
  fill_tensor->SetItemAt<float>({}, 2.0);

  std::shared_ptr<Tensor> output;
  std::unique_ptr<FillOp> op(new FillOp(fill_tensor));
  Status s = op->Compute(input, &output);

  uint64_t out[3] = {2, 2, 2};
  std::shared_ptr<Tensor> expected =
    std::make_shared<Tensor>(TensorShape{3}, DataType(DataType::DE_UINT64), reinterpret_cast<unsigned char *>(out));

  ASSERT_TRUE(output->shape() == expected->shape());
  ASSERT_TRUE(output->type() == expected->type());
  EXPECT_TRUE(s.IsOk());
  MS_LOG(DEBUG) << *output << std::endl;
  MS_LOG(DEBUG) << *expected << std::endl;
  ASSERT_TRUE(*output == *expected);

  MS_LOG(INFO) << "MindDataTestFillOp-TestCasting end.";
}

TEST_F(MindDataTestFillOp, ScalarFill) {
  MS_LOG(INFO) << "Doing MindDataTestFillOp-ScalarFill.";
  uint64_t labels[3] = {0, 1, 2};
  TensorShape shape({3});
  std::shared_ptr<Tensor> input =
    std::make_shared<Tensor>(shape, DataType(DataType::DE_UINT64), reinterpret_cast<unsigned char *>(labels));

  // A non-scalar fill value (shape {2}) should be rejected.
  TensorShape fill_shape({2});
  uint64_t fill_labels[2] = {0, 1};
  std::shared_ptr<Tensor> fill_tensor =
    std::make_shared<Tensor>(fill_shape, DataType(DataType::DE_UINT64), reinterpret_cast<unsigned char *>(fill_labels));

  std::shared_ptr<Tensor> output;
  std::unique_ptr<FillOp> op(new FillOp(fill_tensor));
  Status s = op->Compute(input, &output);

  EXPECT_TRUE(s.IsError());
  ASSERT_TRUE(s.get_code() == StatusCode::kUnexpectedError);

  MS_LOG(INFO) << "MindDataTestFillOp-ScalarFill end.";
}

TEST_F(MindDataTestFillOp, StringFill) {
  MS_LOG(INFO) << "Doing MindDataTestFillOp-StringFill.";
  std::vector<std::string> strings = {"xyzzy", "plugh", "abracadabra"};
  TensorShape shape({3});
  std::shared_ptr<Tensor> input = std::make_shared<Tensor>(strings, shape);

  TensorShape fill_shape({});
  std::string fill_string = "hello";
  std::shared_ptr<Tensor> fill_tensor = std::make_shared<Tensor>(fill_string);

  std::shared_ptr<Tensor> output;
  std::unique_ptr<FillOp> op(new FillOp(fill_tensor));
  Status s = op->Compute(input, &output);

  std::vector<std::string> expected_strings = {"hello", "hello", "hello"};
  TensorShape expected_shape({3});
  std::shared_ptr<Tensor> expected = std::make_shared<Tensor>(expected_strings, expected_shape);

  EXPECT_TRUE(s.IsOk());
  ASSERT_TRUE(output->shape() == expected->shape());
  ASSERT_TRUE(output->type() == expected->type());
  MS_LOG(DEBUG) << *output << std::endl;
  MS_LOG(DEBUG) << *expected << std::endl;
  ASSERT_TRUE(*output == *expected);

  MS_LOG(INFO) << "MindDataTestFillOp-StringFill end.";
}

TEST_F(MindDataTestFillOp, NumericToString) {
  MS_LOG(INFO) << "Doing MindDataTestFillOp-NumericToString.";
  std::vector<std::string> strings = {"xyzzy", "plugh", "abracadabra"};
  TensorShape shape({3});
  std::shared_ptr<Tensor> input = std::make_shared<Tensor>(strings, shape);

  TensorShape fill_shape({});
  std::shared_ptr<Tensor> fill_tensor = std::make_shared<Tensor>(fill_shape, DataType(DataType::DE_FLOAT32));
  fill_tensor->SetItemAt<float>({}, 2.0);

  std::shared_ptr<Tensor> output;
  std::unique_ptr<FillOp> op(new FillOp(fill_tensor));
  Status s = op->Compute(input, &output);

  EXPECT_TRUE(s.IsError());
  ASSERT_TRUE(s.get_code() == StatusCode::kUnexpectedError);

  MS_LOG(INFO) << "MindDataTestFillOp-NumericToString end.";
}

TEST_F(MindDataTestFillOp, StringToNumeric) {
  MS_LOG(INFO) << "Doing MindDataTestFillOp-StringToNumeric.";
  uint64_t labels[3] = {0, 1, 2};
  TensorShape shape({3});
  std::shared_ptr<Tensor> input =
    std::make_shared<Tensor>(shape, DataType(DataType::DE_UINT64), reinterpret_cast<unsigned char *>(labels));

  TensorShape fill_shape({});
  std::string fill_string = "hello";
  std::shared_ptr<Tensor> fill_tensor = std::make_shared<Tensor>(fill_string);

  std::shared_ptr<Tensor> output;
  std::unique_ptr<FillOp> op(new FillOp(fill_tensor));
  Status s = op->Compute(input, &output);

  EXPECT_TRUE(s.IsError());
  ASSERT_TRUE(s.get_code() == StatusCode::kUnexpectedError);

  MS_LOG(INFO) << "MindDataTestFillOp-StringToNumeric end.";
}

@@ -13,9 +13,6 @@
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "common/common.h"
#include "gtest/gtest.h"

@@ -25,32 +22,32 @@
#include "utils/log_adapter.h"

using namespace mindspore::dataset;
using mindspore::LogStream;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::MsLogLevel::INFO;

class MindDataTestQueue : public UT::Common {
 public:
  MindDataTestQueue() {}

  void SetUp() {}
};

int gRefCountDestructorCalled;

class RefCount {
 public:
  RefCount() : v_(nullptr) {}
  explicit RefCount(int x) : v_(std::make_shared<int>(x)) {}
  explicit RefCount(const RefCount &o) : v_(o.v_) {}
  ~RefCount() {
    MS_LOG(DEBUG) << "Destructor of RefCount called" << std::endl;
    gRefCountDestructorCalled++;
  }
  RefCount &operator=(const RefCount &o) {
    v_ = o.v_;
    return *this;
  }

  std::shared_ptr<int> v_;
};

@@ -70,22 +67,22 @@ TEST_F(MindDataTestQueue, Test1) {
  // Use count should remain 2. a and b. No copy in the queue.
  ASSERT_EQ(a.use_count(), 2);
  a.reset(new int(5));
  ASSERT_EQ(a.use_count(), 1);
  // Push again but expect a is nullptr after push
  rc = que.Add(std::move(a));
  ASSERT_TRUE(rc.IsOk());
  ASSERT_EQ(a.use_count(), 0);
  rc = que.PopFront(&b);
  ASSERT_TRUE(rc.IsOk());
  ASSERT_EQ(*b, 5);
  ASSERT_EQ(b.use_count(), 1);
  // Test construct in place
  rc = que.EmplaceBack(std::make_shared<int>(100));
  ASSERT_TRUE(rc.IsOk());
  rc = que.PopFront(&b);
  ASSERT_TRUE(rc.IsOk());
  ASSERT_EQ(*b, 100);
  ASSERT_EQ(b.use_count(), 1);
  // Test the destructor of the Queue by add an element in the queue without popping it and let the queue go
  // out of scope.
  rc = que.EmplaceBack(std::make_shared<int>(2000));

@@ -127,7 +124,7 @@ TEST_F(MindDataTestQueue, Test3) {
  ASSERT_EQ(*b, 40);
}

void test4() {
  gRefCountDestructorCalled = 0;
  // Pass a structure along the queue.
  Queue<RefCount> que(3);

@@ -144,9 +141,7 @@ void test4(){
  ASSERT_TRUE(rc.IsOk());
}

TEST_F(MindDataTestQueue, Test4) { test4(); }

TEST_F(MindDataTestQueue, Test5) {
  test4();

@@ -0,0 +1,95 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing fill op
"""
import numpy as np
import pytest

import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as data_trans


def test_fillop_basic():
    def gen():
        yield (np.array([4, 5, 6, 7], dtype=np.uint8),)

    data = ds.GeneratorDataset(gen, column_names=["col"])
    fill_op = data_trans.Fill(3)

    data = data.map(input_columns=["col"], operations=fill_op)
    expected = np.array([3, 3, 3, 3], dtype=np.uint8)
    for data_row in data:
        np.testing.assert_array_equal(data_row[0], expected)


def test_fillop_down_type_cast():
    def gen():
        yield (np.array([4, 5, 6, 7], dtype=np.uint8),)

    data = ds.GeneratorDataset(gen, column_names=["col"])
    fill_op = data_trans.Fill(-3)

    data = data.map(input_columns=["col"], operations=fill_op)
    expected = np.array([253, 253, 253, 253], dtype=np.uint8)
    for data_row in data:
        np.testing.assert_array_equal(data_row[0], expected)


def test_fillop_up_type_cast():
    def gen():
        yield (np.array([4, 5, 6, 7], dtype=np.float64),)

    data = ds.GeneratorDataset(gen, column_names=["col"])
    fill_op = data_trans.Fill(3)

    data = data.map(input_columns=["col"], operations=fill_op)
    expected = np.array([3., 3., 3., 3.], dtype=np.float64)
    for data_row in data:
        np.testing.assert_array_equal(data_row[0], expected)


def test_fillop_string():
    def gen():
        yield (np.array(["45555", "45555"], dtype='S'),)

    data = ds.GeneratorDataset(gen, column_names=["col"])
    fill_op = data_trans.Fill("error")

    data = data.map(input_columns=["col"], operations=fill_op)
    expected = np.array(['error', 'error'], dtype='S')
    for data_row in data:
        np.testing.assert_array_equal(data_row[0], expected)


def test_fillop_error_handling():
    def gen():
        yield (np.array([4, 4, 4, 4]),)

    data = ds.GeneratorDataset(gen, column_names=["col"])
    fill_op = data_trans.Fill("words")
    data = data.map(input_columns=["col"], operations=fill_op)

    with pytest.raises(RuntimeError) as error_info:
        for data_row in data:
            print(data_row)
    assert "Types do not match" in repr(error_info.value)


if __name__ == "__main__":
    test_fillop_basic()
    test_fillop_up_type_cast()
    test_fillop_down_type_cast()
    test_fillop_string()
    test_fillop_error_handling()