| @@ -38,6 +38,7 @@ | |||
| #include "dataset/kernels/image/resize_op.h" | |||
| #include "dataset/kernels/image/uniform_aug_op.h" | |||
| #include "dataset/kernels/data/type_cast_op.h" | |||
| #include "dataset/kernels/data/fill_op.h" | |||
| #include "dataset/engine/datasetops/source/cifar_op.h" | |||
| #include "dataset/engine/datasetops/source/image_folder_op.h" | |||
| #include "dataset/engine/datasetops/source/io_block.h" | |||
| @@ -350,6 +351,10 @@ void bindTensorOps2(py::module *m) { | |||
| *m, "OneHotOp", "Tensor operation to apply one hot encoding. Takes number of classes.") | |||
| .def(py::init<int32_t>()); | |||
| (void)py::class_<FillOp, TensorOp, std::shared_ptr<FillOp>>( | |||
| *m, "FillOp", "Tensor operation to return tensor filled with same value as input fill value.") | |||
| .def(py::init<std::shared_ptr<Tensor>>()); | |||
| (void)py::class_<RandomRotationOp, TensorOp, std::shared_ptr<RandomRotationOp>>( | |||
| *m, "RandomRotationOp", | |||
| "Tensor operation to apply RandomRotation." | |||
| @@ -5,4 +5,4 @@ add_library(kernels-data OBJECT | |||
| one_hot_op.cc | |||
| type_cast_op.cc | |||
| to_float16_op.cc | |||
| ) | |||
| fill_op.cc) | |||
| @@ -23,6 +23,7 @@ | |||
| #include "dataset/core/tensor_shape.h" | |||
| #include "dataset/core/data_type.h" | |||
| #include "dataset/core/pybind_support.h" | |||
| #include "dataset/kernels/data/type_cast_op.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| @@ -78,6 +79,7 @@ Status OneHotEncodingSigned(const std::shared_ptr<Tensor> &input, std::shared_pt | |||
| Status OneHotEncoding(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, dsize_t num_classes) { | |||
| input->Squeeze(); | |||
| if (input->Rank() > 1) { // We expect the input to be int he first dimension | |||
| RETURN_STATUS_UNEXPECTED("One hot only supports scalars or 1D shape Tensors."); | |||
| } | |||
| @@ -106,11 +108,121 @@ Status OneHotEncoding(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *ou | |||
| } | |||
| } | |||
| Status Fill(const std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, std::shared_ptr<Tensor> fill_value) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(!((fill_value->type() == DataType::DE_STRING) && (input->type() != DataType::DE_STRING)), | |||
| "Types do not match"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(fill_value->shape() == TensorShape({}), "fill_value is not a scalar"); | |||
| std::shared_ptr<Tensor> out; | |||
| const DataType &to = input->type(); | |||
| std::unique_ptr<TypeCastOp> op(new TypeCastOp(to)); | |||
| std::shared_ptr<Tensor> fill_output; | |||
| op->Compute(fill_value, &fill_output); | |||
| RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, TensorImpl::kFlexible, input->shape(), input->type())); | |||
| switch (input->type().value()) { | |||
| case DataType::DE_BOOL: { | |||
| bool value = 0; | |||
| RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {})); | |||
| out->Fill<bool>(value); | |||
| break; | |||
| } | |||
| case DataType::DE_INT8: { | |||
| int8_t value = 0; | |||
| RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {})); | |||
| out->Fill<int8_t>(value); | |||
| break; | |||
| } | |||
| case DataType::DE_UINT8: { | |||
| uint8_t value = 0; | |||
| RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {})); | |||
| out->Fill<uint8_t>(value); | |||
| break; | |||
| } | |||
| case DataType::DE_UINT16: { | |||
| uint16_t value = 0; | |||
| RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {})); | |||
| out->Fill<uint16_t>(value); | |||
| break; | |||
| } | |||
| case DataType::DE_INT16: { | |||
| int16_t value = 0; | |||
| RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {})); | |||
| out->Fill<int16_t>(value); | |||
| break; | |||
| } | |||
| case DataType::DE_UINT32: { | |||
| uint32_t value = 0; | |||
| RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {})); | |||
| out->Fill<uint32_t>(value); | |||
| break; | |||
| } | |||
| case DataType::DE_INT32: { | |||
| int32_t value = 0; | |||
| RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {})); | |||
| out->Fill<int32_t>(value); | |||
| break; | |||
| } | |||
| case DataType::DE_UINT64: { | |||
| uint64_t value = 0; | |||
| RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {})); | |||
| out->Fill<uint64_t>(value); | |||
| break; | |||
| } | |||
| case DataType::DE_INT64: { | |||
| int64_t value = 0; | |||
| RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {})); | |||
| out->Fill<int64_t>(value); | |||
| break; | |||
| } | |||
| case DataType::DE_FLOAT16: { | |||
| int64_t value = 0; | |||
| RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {})); | |||
| out->Fill<float>(value); | |||
| break; | |||
| } | |||
| case DataType::DE_FLOAT32: { | |||
| float value = 0; | |||
| RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {})); | |||
| out->Fill<float>(value); | |||
| break; | |||
| } | |||
| case DataType::DE_FLOAT64: { | |||
| double value = 0; | |||
| RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {})); | |||
| out->Fill<double>(value); | |||
| break; | |||
| } | |||
| case DataType::DE_STRING: { | |||
| std::vector<std::string> strings; | |||
| std::string_view fill_string_view; | |||
| RETURN_IF_NOT_OK(fill_value->GetItemAt(&fill_string_view, {})); | |||
| std::string fill_string = std::string(fill_string_view); | |||
| for (int i = 0; i < input->shape().NumOfElements(); i++) { | |||
| strings.emplace_back(fill_string); | |||
| } | |||
| RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, strings, input->shape())); | |||
| break; | |||
| } | |||
| case DataType::DE_UNKNOWN: { | |||
| RETURN_STATUS_UNEXPECTED("FillOp does not support input of this type."); | |||
| break; | |||
| } | |||
| } | |||
| *output = out; | |||
| return Status::OK(); | |||
| } | |||
| template <typename FROM, typename TO> | |||
| void Cast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { | |||
| auto in_itr = input->begin<FROM>(); | |||
| auto out_itr = (*output)->begin<TO>(); | |||
| auto out_end = (*output)->end<TO>(); | |||
| for (; out_itr != out_end; static_cast<void>(in_itr++), static_cast<void>(out_itr++)) | |||
| *out_itr = static_cast<TO>(*in_itr); | |||
| } | |||
| @@ -43,6 +43,13 @@ Status OneHotEncodingUnsigned(const std::shared_ptr<Tensor> &input, std::shared_ | |||
| Status OneHotEncodingSigned(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, dsize_t num_classes, | |||
| int64_t index); | |||
| // Returns a tensor of shape input filled with the passed fill_value | |||
| // @param input Tensor | |||
| // @param output Tensor. The shape and type of the output tensor is same as input | |||
| // @param fill_value Tensor. A scalar tensor used to fill the output tensor | |||
| Status Fill(const std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, std::shared_ptr<Tensor> fill_value); | |||
| // Returns a type changed input tensor. | |||
| // Example: if input tensor is float64, the output will the specified dataType. See DataTypes.cpp | |||
| // @param input Tensor | |||
| @@ -0,0 +1,31 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "dataset/kernels/data/fill_op.h" | |||
| #include "dataset/core/tensor.h" | |||
| #include "dataset/kernels/data/data_utils.h" | |||
| #include "dataset/kernels/tensor_op.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| Status FillOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { | |||
| IO_CHECK(input, output); | |||
| Status s = Fill(input, output, fill_value_); | |||
| return s; | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,47 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef DATASET_KERNELS_DATA_FILL_OP_H_ | |||
| #define DATASET_KERNELS_DATA_FILL_OP_H_ | |||
| #include <string> | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "dataset/core/tensor.h" | |||
| #include "dataset/kernels/tensor_op.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class FillOp : public TensorOp { | |||
| public: | |||
| explicit FillOp(std::shared_ptr<Tensor> value) : fill_value_(value) {} | |||
| ~FillOp() override = default; | |||
| void Print(std::ostream &out) const override { out << "FillOp"; } | |||
| Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override; | |||
| private: | |||
| std::shared_ptr<Tensor> fill_value_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_FILL_OP_H | |||
| @@ -15,9 +15,9 @@ | |||
| """ | |||
| This module c_transforms provides common operations, including OneHotOp and TypeCast. | |||
| """ | |||
| import numpy as np | |||
| import mindspore._c_dataengine as cde | |||
| from .validators import check_num_classes, check_de_type | |||
| from .validators import check_num_classes, check_de_type, check_fill_value | |||
| from ..core.datatypes import mstype_to_detype | |||
| @@ -35,6 +35,22 @@ class OneHot(cde.OneHotOp): | |||
| super().__init__(num_classes) | |||
| class Fill(cde.FillOp): | |||
| """ | |||
| Tensor operation to create a tensor filled with passed scalar value. | |||
| The output tensor will have the same shape and type as the input tensor. | |||
| Args: | |||
| fill_value (python types (str, int, float, or bool)) : scalar value | |||
| to fill created tensor with. | |||
| """ | |||
| @check_fill_value | |||
| def __init__(self, fill_value): | |||
| print(fill_value) | |||
| super().__init__(cde.Tensor(np.array(fill_value))) | |||
| class TypeCast(cde.TypeCastOp): | |||
| """ | |||
| Tensor operation to cast to a given MindSpore data type. | |||
| @@ -17,7 +17,6 @@ | |||
| from functools import wraps | |||
| from mindspore._c_expression import typing | |||
| # POS_INT_MIN is used to limit values from starting from 0 | |||
| POS_INT_MIN = 1 | |||
| UINT8_MAX = 255 | |||
| @@ -159,6 +158,25 @@ def check_num_classes(method): | |||
| return new_method | |||
| def check_fill_value(method): | |||
| """Wrapper method to check the parameters of fill value.""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| fill_value = (list(args) + [None])[0] | |||
| if "fill_value" in kwargs: | |||
| fill_value = kwargs.get("fill_value") | |||
| if fill_value is None: | |||
| raise ValueError("fill_value is not provided.") | |||
| if not isinstance(fill_value, (str, float, bool, int)): | |||
| raise TypeError("fill_value must be either a primitive python str, float, bool, or int") | |||
| kwargs["fill_value"] = fill_value | |||
| return method(self, **kwargs) | |||
| return new_method | |||
| def check_de_type(method): | |||
| """Wrapper method to check the parameters of data type.""" | |||
| @@ -72,6 +72,7 @@ SET(DE_UT_SRCS | |||
| tokenizer_op_test.cc | |||
| gnn_graph_test.cc | |||
| coco_op_test.cc | |||
| fill_op_test.cc | |||
| ) | |||
| add_executable(de_ut_tests ${DE_UT_SRCS}) | |||
| @@ -0,0 +1,183 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "common/common.h" | |||
| #include "dataset/kernels/data/fill_op.h" | |||
| #include "utils/log_adapter.h" | |||
| using namespace mindspore::dataset; | |||
| using mindspore::LogStream; | |||
| using mindspore::ExceptionType::NoExceptionType; | |||
| using mindspore::MsLogLevel::INFO; | |||
| class MindDataTestFillOp : public UT::Common { | |||
| protected: | |||
| MindDataTestFillOp() {} | |||
| }; | |||
| TEST_F(MindDataTestFillOp, TestOp) { | |||
| MS_LOG(INFO) << "Doing MindDataTestFillOp-TestOp."; | |||
| uint64_t labels[3] = {1, 1, 2}; | |||
| TensorShape shape({3}); | |||
| std::shared_ptr<Tensor> input = | |||
| std::make_shared<Tensor>(shape, DataType(DataType::DE_UINT64), reinterpret_cast<unsigned char *>(labels)); | |||
| TensorShape fill_shape({}); | |||
| std::shared_ptr<Tensor> fill_tensor = std::make_shared<Tensor>(fill_shape, DataType(DataType::DE_UINT64)); | |||
| fill_tensor->SetItemAt<uint64_t>({}, 4); | |||
| std::shared_ptr<Tensor> output; | |||
| std::unique_ptr<FillOp> op(new FillOp(fill_tensor)); | |||
| Status s = op->Compute(input, &output); | |||
| uint64_t out[3] = {4, 4, 4}; | |||
| std::shared_ptr<Tensor> expected = | |||
| std::make_shared<Tensor>(TensorShape{3}, DataType(DataType::DE_UINT64), reinterpret_cast<unsigned char *>(out)); | |||
| EXPECT_TRUE(s.IsOk()); | |||
| ASSERT_TRUE(output->shape() == expected->shape()); | |||
| ASSERT_TRUE(output->type() == expected->type()); | |||
| MS_LOG(DEBUG) << *output << std::endl; | |||
| MS_LOG(DEBUG) << *expected << std::endl; | |||
| ASSERT_TRUE(*output == *expected); | |||
| MS_LOG(INFO) << "MindDataTestFillOp-TestOp end."; | |||
| } | |||
| TEST_F(MindDataTestFillOp, TestCasting) { | |||
| MS_LOG(INFO) << "Doing MindDataTestFillOp-TestCasting."; | |||
| uint64_t labels[3] = {0, 1, 2}; | |||
| TensorShape shape({3}); | |||
| std::shared_ptr<Tensor> input = | |||
| std::make_shared<Tensor>(shape, DataType(DataType::DE_UINT64), reinterpret_cast<unsigned char *>(labels)); | |||
| TensorShape fill_shape({}); | |||
| std::shared_ptr<Tensor> fill_tensor = std::make_shared<Tensor>(fill_shape, DataType(DataType::DE_FLOAT32)); | |||
| fill_tensor->SetItemAt<float>({}, 2.0); | |||
| std::shared_ptr<Tensor> output; | |||
| std::unique_ptr<FillOp> op(new FillOp(fill_tensor)); | |||
| Status s = op->Compute(input, &output); | |||
| uint64_t out[3] = {2, 2, 2}; | |||
| std::shared_ptr<Tensor> expected = | |||
| std::make_shared<Tensor>(TensorShape{3}, DataType(DataType::DE_UINT64), reinterpret_cast<unsigned char *>(out)); | |||
| ASSERT_TRUE(output->shape() == expected->shape()); | |||
| ASSERT_TRUE(output->type() == expected->type()); | |||
| EXPECT_TRUE(s.IsOk()); | |||
| MS_LOG(DEBUG) << *output << std::endl; | |||
| MS_LOG(DEBUG) << *expected << std::endl; | |||
| ASSERT_TRUE(*output == *expected); | |||
| MS_LOG(INFO) << "MindDataTestFillOp-TestCasting end."; | |||
| } | |||
| TEST_F(MindDataTestFillOp, ScalarFill) { | |||
| MS_LOG(INFO) << "Doing MindDataTestFillOp-ScalarFill."; | |||
| uint64_t labels[3] = {0, 1, 2}; | |||
| TensorShape shape({3}); | |||
| std::shared_ptr<Tensor> input = | |||
| std::make_shared<Tensor>(shape, DataType(DataType::DE_UINT64), reinterpret_cast<unsigned char *>(labels)); | |||
| TensorShape fill_shape({2}); | |||
| uint64_t fill_labels[3] = {0, 1}; | |||
| std::shared_ptr<Tensor> fill_tensor = | |||
| std::make_shared<Tensor>(fill_shape, DataType(DataType::DE_UINT64), reinterpret_cast<unsigned char *>(fill_labels)); | |||
| std::shared_ptr<Tensor> output; | |||
| std::unique_ptr<FillOp> op(new FillOp(fill_tensor)); | |||
| Status s = op->Compute(input, &output); | |||
| EXPECT_TRUE(s.IsError()); | |||
| ASSERT_TRUE(s.get_code() == StatusCode::kUnexpectedError); | |||
| MS_LOG(INFO) << "MindDataTestFillOp-ScalarFill end."; | |||
| } | |||
| TEST_F(MindDataTestFillOp, StringFill) { | |||
| MS_LOG(INFO) << "Doing MindDataTestFillOp-StringFill."; | |||
| std::vector<std::string> strings = {"xyzzy", "plugh", "abracadabra"}; | |||
| TensorShape shape({3}); | |||
| std::shared_ptr<Tensor> input = std::make_shared<Tensor>(strings, shape); | |||
| TensorShape fill_shape({}); | |||
| std::string fill_string = "hello"; | |||
| std::shared_ptr<Tensor> fill_tensor = std::make_shared<Tensor>(fill_string); | |||
| std::shared_ptr<Tensor> output; | |||
| std::unique_ptr<FillOp> op(new FillOp(fill_tensor)); | |||
| Status s = op->Compute(input, &output); | |||
| std::vector<std::string> expected_strings = {"hello", "hello", "hello"}; | |||
| TensorShape expected_shape({3}); | |||
| std::shared_ptr<Tensor> expected = std::make_shared<Tensor>(expected_strings, expected_shape); | |||
| EXPECT_TRUE(s.IsOk()); | |||
| ASSERT_TRUE(output->shape() == expected->shape()); | |||
| ASSERT_TRUE(output->type() == expected->type()); | |||
| MS_LOG(DEBUG) << *output << std::endl; | |||
| MS_LOG(DEBUG) << *expected << std::endl; | |||
| ASSERT_TRUE(*output == *expected); | |||
| MS_LOG(INFO) << "MindDataTestFillOp-StringFill end."; | |||
| } | |||
| TEST_F(MindDataTestFillOp, NumericToString) { | |||
| MS_LOG(INFO) << "Doing MindDataTestFillOp-NumericToString."; | |||
| std::vector<std::string> strings = {"xyzzy", "plugh", "abracadabra"}; | |||
| TensorShape shape({3}); | |||
| std::shared_ptr<Tensor> input = std::make_shared<Tensor>(strings, shape); | |||
| TensorShape fill_shape({}); | |||
| std::shared_ptr<Tensor> fill_tensor = std::make_shared<Tensor>(fill_shape, DataType(DataType::DE_FLOAT32)); | |||
| fill_tensor->SetItemAt<float>({}, 2.0); | |||
| std::shared_ptr<Tensor> output; | |||
| std::unique_ptr<FillOp> op(new FillOp(fill_tensor)); | |||
| Status s = op->Compute(input, &output); | |||
| EXPECT_TRUE(s.IsError()); | |||
| ASSERT_TRUE(s.get_code() == StatusCode::kUnexpectedError); | |||
| MS_LOG(INFO) << "MindDataTestFillOp-NumericToString end."; | |||
| } | |||
| TEST_F(MindDataTestFillOp, StringToNumeric) { | |||
| MS_LOG(INFO) << "Doing MindDataTestFillOp-StringToNumeric."; | |||
| uint64_t labels[3] = {0, 1, 2}; | |||
| TensorShape shape({3}); | |||
| std::shared_ptr<Tensor> input = | |||
| std::make_shared<Tensor>(shape, DataType(DataType::DE_UINT64), reinterpret_cast<unsigned char *>(labels)); | |||
| TensorShape fill_shape({}); | |||
| std::string fill_string = "hello"; | |||
| std::shared_ptr<Tensor> fill_tensor = std::make_shared<Tensor>(fill_string); | |||
| std::shared_ptr<Tensor> output; | |||
| std::unique_ptr<FillOp> op(new FillOp(fill_tensor)); | |||
| Status s = op->Compute(input, &output); | |||
| EXPECT_TRUE(s.IsError()); | |||
| ASSERT_TRUE(s.get_code() == StatusCode::kUnexpectedError); | |||
| MS_LOG(INFO) << "MindDataTestFillOp-StringToNumeric end."; | |||
| } | |||
| @@ -13,9 +13,6 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| // | |||
| // Created by jesse on 10/3/19. | |||
| // | |||
| #include "common/common.h" | |||
| #include "gtest/gtest.h" | |||
| @@ -25,32 +22,32 @@ | |||
| #include "utils/log_adapter.h" | |||
| using namespace mindspore::dataset; | |||
| using mindspore::MsLogLevel::INFO; | |||
| using mindspore::ExceptionType::NoExceptionType; | |||
| using mindspore::LogStream; | |||
| using mindspore::ExceptionType::NoExceptionType; | |||
| using mindspore::MsLogLevel::INFO; | |||
| class MindDataTestQueue : public UT::Common { | |||
| public: | |||
| MindDataTestQueue() {} | |||
| MindDataTestQueue() {} | |||
| void SetUp() {} | |||
| void SetUp() {} | |||
| }; | |||
| int gRefCountDestructorCalled; | |||
| class RefCount { | |||
| public: | |||
| RefCount() : v_(nullptr) {} | |||
| explicit RefCount(int x) : v_(std::make_shared<int>(x)) {} | |||
| explicit RefCount(const RefCount &o) : v_(o.v_) {} | |||
| ~RefCount() { | |||
| MS_LOG(DEBUG) << "Destructor of RefCount called" << std::endl; | |||
| gRefCountDestructorCalled++; | |||
| } | |||
| RefCount& operator=(const RefCount &o) { | |||
| v_ = o.v_; | |||
| return *this; | |||
| } | |||
| RefCount() : v_(nullptr) {} | |||
| explicit RefCount(int x) : v_(std::make_shared<int>(x)) {} | |||
| explicit RefCount(const RefCount &o) : v_(o.v_) {} | |||
| ~RefCount() { | |||
| MS_LOG(DEBUG) << "Destructor of RefCount called" << std::endl; | |||
| gRefCountDestructorCalled++; | |||
| } | |||
| RefCount &operator=(const RefCount &o) { | |||
| v_ = o.v_; | |||
| return *this; | |||
| } | |||
| std::shared_ptr<int> v_; | |||
| }; | |||
| @@ -70,22 +67,22 @@ TEST_F(MindDataTestQueue, Test1) { | |||
| // Use count should remain 2. a and b. No copy in the queue. | |||
| ASSERT_EQ(a.use_count(), 2); | |||
| a.reset(new int(5)); | |||
| ASSERT_EQ(a.use_count(),1); | |||
| ASSERT_EQ(a.use_count(), 1); | |||
| // Push again but expect a is nullptr after push | |||
| rc = que.Add(std::move(a)); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| ASSERT_EQ(a.use_count(),0); | |||
| ASSERT_EQ(a.use_count(), 0); | |||
| rc = que.PopFront(&b); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| ASSERT_EQ(*b, 5); | |||
| ASSERT_EQ(b.use_count(),1); | |||
| ASSERT_EQ(b.use_count(), 1); | |||
| // Test construct in place | |||
| rc = que.EmplaceBack(std::make_shared<int>(100)); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = que.PopFront(&b); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| ASSERT_EQ(*b, 100); | |||
| ASSERT_EQ(b.use_count(),1); | |||
| ASSERT_EQ(b.use_count(), 1); | |||
| // Test the destructor of the Queue by add an element in the queue without popping it and let the queue go | |||
| // out of scope. | |||
| rc = que.EmplaceBack(std::make_shared<int>(2000)); | |||
| @@ -127,7 +124,7 @@ TEST_F(MindDataTestQueue, Test3) { | |||
| ASSERT_EQ(*b, 40); | |||
| } | |||
| void test4(){ | |||
| void test4() { | |||
| gRefCountDestructorCalled = 0; | |||
| // Pass a structure along the queue. | |||
| Queue<RefCount> que(3); | |||
| @@ -144,9 +141,7 @@ void test4(){ | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| } | |||
| TEST_F(MindDataTestQueue, Test4) { | |||
| test4(); | |||
| } | |||
| TEST_F(MindDataTestQueue, Test4) { test4(); } | |||
| TEST_F(MindDataTestQueue, Test5) { | |||
| test4(); | |||
| @@ -0,0 +1,95 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """ | |||
| Testing fill op | |||
| """ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.transforms.c_transforms as data_trans | |||
| def test_fillop_basic(): | |||
| def gen(): | |||
| yield (np.array([4, 5, 6, 7], dtype=np.uint8),) | |||
| data = ds.GeneratorDataset(gen, column_names=["col"]) | |||
| fill_op = data_trans.Fill(3) | |||
| data = data.map(input_columns=["col"], operations=fill_op) | |||
| expected = np.array([3, 3, 3, 3], dtype=np.uint8) | |||
| for data_row in data: | |||
| np.testing.assert_array_equal(data_row[0], expected) | |||
| def test_fillop_down_type_cast(): | |||
| def gen(): | |||
| yield (np.array([4, 5, 6, 7], dtype=np.uint8),) | |||
| data = ds.GeneratorDataset(gen, column_names=["col"]) | |||
| fill_op = data_trans.Fill(-3) | |||
| data = data.map(input_columns=["col"], operations=fill_op) | |||
| expected = np.array([253, 253, 253, 253], dtype=np.uint8) | |||
| for data_row in data: | |||
| np.testing.assert_array_equal(data_row[0], expected) | |||
| def test_fillop_up_type_cast(): | |||
| def gen(): | |||
| yield (np.array([4, 5, 6, 7], dtype=np.float),) | |||
| data = ds.GeneratorDataset(gen, column_names=["col"]) | |||
| fill_op = data_trans.Fill(3) | |||
| data = data.map(input_columns=["col"], operations=fill_op) | |||
| expected = np.array([3., 3., 3., 3.], dtype=np.float) | |||
| for data_row in data: | |||
| np.testing.assert_array_equal(data_row[0], expected) | |||
| def test_fillop_string(): | |||
| def gen(): | |||
| yield (np.array(["45555", "45555"], dtype='S'),) | |||
| data = ds.GeneratorDataset(gen, column_names=["col"]) | |||
| fill_op = data_trans.Fill("error") | |||
| data = data.map(input_columns=["col"], operations=fill_op) | |||
| expected = np.array(['error', 'error'], dtype='S') | |||
| for data_row in data: | |||
| np.testing.assert_array_equal(data_row[0], expected) | |||
| def test_fillop_error_handling(): | |||
| def gen(): | |||
| yield (np.array([4, 4, 4, 4]),) | |||
| data = ds.GeneratorDataset(gen, column_names=["col"]) | |||
| fill_op = data_trans.Fill("words") | |||
| data = data.map(input_columns=["col"], operations=fill_op) | |||
| with pytest.raises(RuntimeError) as error_info: | |||
| for data_row in data: | |||
| print(data_row) | |||
| assert "Types do not match" in repr(error_info.value) | |||
| if __name__ == "__main__": | |||
| test_fillop_basic() | |||
| test_fillop_up_type_cast() | |||
| test_fillop_down_type_cast() | |||
| test_fillop_string() | |||
| test_fillop_error_handling() | |||