Merge pull request !2198 from h.farahat/mask_optags/v0.5.0-beta
| @@ -38,6 +38,7 @@ | |||
| #include "dataset/kernels/image/resize_op.h" | |||
| #include "dataset/kernels/image/uniform_aug_op.h" | |||
| #include "dataset/kernels/data/fill_op.h" | |||
| #include "dataset/kernels/data/mask_op.h" | |||
| #include "dataset/kernels/data/slice_op.h" | |||
| #include "dataset/kernels/data/type_cast_op.h" | |||
| #include "dataset/engine/datasetops/source/cifar_op.h" | |||
| @@ -383,7 +384,7 @@ void bindTensorOps2(py::module *m) { | |||
| *m, "FillOp", "Tensor operation to return tensor filled with same value as input fill value.") | |||
| .def(py::init<std::shared_ptr<Tensor>>()); | |||
| (void)py::class_<SliceOp, TensorOp, std::shared_ptr<SliceOp>>(*m, "SliceOp", "") | |||
| (void)py::class_<SliceOp, TensorOp, std::shared_ptr<SliceOp>>(*m, "SliceOp", "Tensor Slice operation.") | |||
| .def(py::init<bool>()) | |||
| .def(py::init([](const py::list &py_list) { | |||
| std::vector<dsize_t> c_list; | |||
| @@ -414,6 +415,19 @@ void bindTensorOps2(py::module *m) { | |||
| return std::make_shared<SliceOp>(c_slice); | |||
| })); | |||
| (void)py::enum_<RelationalOp>(*m, "RelationalOp", py::arithmetic()) | |||
| .value("EQ", RelationalOp::kEqual) | |||
| .value("NE", RelationalOp::kNotEqual) | |||
| .value("LT", RelationalOp::kLess) | |||
| .value("LE", RelationalOp::kLessEqual) | |||
| .value("GT", RelationalOp::kGreater) | |||
| .value("GE", RelationalOp::kGreaterEqual) | |||
| .export_values(); | |||
| (void)py::class_<MaskOp, TensorOp, std::shared_ptr<MaskOp>>(*m, "MaskOp", | |||
| "Tensor operation mask using relational comparator") | |||
| .def(py::init<RelationalOp, std::shared_ptr<Tensor>, DataType>()); | |||
| (void)py::class_<RandomRotationOp, TensorOp, std::shared_ptr<RandomRotationOp>>( | |||
| *m, "RandomRotationOp", | |||
| "Tensor operation to apply RandomRotation." | |||
| @@ -699,7 +699,7 @@ Status Tensor::GetItemAt(T *o, const std::vector<dsize_t> &index) const { | |||
| Status Tensor::GetItemAt(std::string_view *o, const std::vector<dsize_t> &index) const { | |||
| RETURN_UNEXPECTED_IF_NULL(data_); | |||
| RETURN_UNEXPECTED_IF_NULL(o); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(type_ == DataType::DE_STRING, "Type is not DE_STRING"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(type_ == DataType::DE_STRING, "Tensor type is not a string"); | |||
| uchar *start = nullptr; | |||
| offset_t length = 0; | |||
| @@ -932,17 +932,17 @@ Status Tensor::SliceNumeric(std::shared_ptr<Tensor> *out, const std::vector<dsiz | |||
| dsize_t out_index = 0; | |||
| dsize_t dim_length = shape_[0]; | |||
| dsize_t type_size = type_.SizeInBytes(); | |||
| dsize_t src_start = handleNeg(indices[0], dim_length); | |||
| dsize_t src_start = HandleNeg(indices[0], dim_length); | |||
| uchar *dst_addr = (*out)->data_; | |||
| dsize_t count = 1; | |||
| for (dsize_t i = 0; i < indices.size(); i++) { | |||
| dsize_t cur_index = handleNeg(indices[i], dim_length); | |||
| dsize_t cur_index = HandleNeg(indices[i], dim_length); | |||
| CHECK_FAIL_RETURN_UNEXPECTED( | |||
| cur_index >= 0 && cur_index < dim_length, | |||
| "Index " + std::to_string(indices[i]) + " is out of bounds [0," + std::to_string(dim_length) + ")"); | |||
| if (i < indices.size() - 1) { | |||
| dsize_t next_index = handleNeg(indices[i + 1], dim_length); | |||
| dsize_t next_index = HandleNeg(indices[i + 1], dim_length); | |||
| if (next_index == cur_index + 1) { | |||
| count++; | |||
| continue; | |||
| @@ -951,7 +951,7 @@ Status Tensor::SliceNumeric(std::shared_ptr<Tensor> *out, const std::vector<dsiz | |||
| memcpy_s(dst_addr + out_index * type_size, (*out)->SizeInBytes(), data_ + src_start * type_size, count * type_size); | |||
| out_index += count; | |||
| if (i < indices.size() - 1) { | |||
| src_start = handleNeg(indices[i + 1], dim_length); // next index | |||
| src_start = HandleNeg(indices[i + 1], dim_length); // next index | |||
| } | |||
| count = 1; | |||
| } | |||
| @@ -961,7 +961,7 @@ Status Tensor::SliceString(std::shared_ptr<Tensor> *out, const std::vector<dsize | |||
| dsize_t dim_length = shape_[0]; | |||
| std::vector<std::string> strings; | |||
| for (dsize_t index : indices) { | |||
| dsize_t cur_index = handleNeg(index, dim_length); | |||
| dsize_t cur_index = HandleNeg(index, dim_length); | |||
| CHECK_FAIL_RETURN_UNEXPECTED( | |||
| cur_index >= 0 && cur_index < dim_length, | |||
| "Index " + std::to_string(index) + " is out of bounds [0," + std::to_string(dim_length) + ")"); | |||
| @@ -348,7 +348,7 @@ class Tensor { | |||
| } | |||
| // Handle negative indices. | |||
| static inline dsize_t handleNeg(dsize_t index, dsize_t length) { return (index < 0) ? (index + length) : index; } | |||
| static inline dsize_t HandleNeg(dsize_t index, dsize_t length) { return (index < 0) ? (index + length) : index; } | |||
| // Slice tensor bases on the given indicies. Copy the sliced data into out tensor. Only rank1 tensors are supported. | |||
| // Based on the type of tensor, SliceNumeric or SliceString will be called | |||
| @@ -1,9 +1,10 @@ | |||
| file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") | |||
| set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) | |||
| add_library(kernels-data OBJECT | |||
| data_utils.cc | |||
| one_hot_op.cc | |||
| type_cast_op.cc | |||
| to_float16_op.cc | |||
| fill_op.cc | |||
| slice_op.cc) | |||
| data_utils.cc | |||
| one_hot_op.cc | |||
| type_cast_op.cc | |||
| to_float16_op.cc | |||
| fill_op.cc | |||
| slice_op.cc | |||
| mask_op.cc) | |||
| @@ -120,7 +120,7 @@ Status Fill(const std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output | |||
| std::unique_ptr<TypeCastOp> op(new TypeCastOp(to)); | |||
| std::shared_ptr<Tensor> fill_output; | |||
| op->Compute(fill_value, &fill_output); | |||
| RETURN_IF_NOT_OK(op->Compute(fill_value, &fill_output)); | |||
| RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, TensorImpl::kFlexible, input->shape(), input->type())); | |||
| @@ -344,6 +344,8 @@ Status PadEnd(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst, | |||
| return PadEndString(src, dst, pad_shape, ""); | |||
| } | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED(src->type().IsNumeric() == pad_val->type().IsNumeric(), | |||
| "Source and pad_value tensors are not of the same type."); | |||
| if (pad_val->type().IsNumeric()) { | |||
| float val = 0; | |||
| RETURN_IF_NOT_OK(pad_val->GetItemAt<float>(&val, {})); | |||
| @@ -454,5 +456,102 @@ Status PadEndStringHelper(const std::shared_ptr<Tensor> &src, std::vector<std::s | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| template <typename T> | |||
| Status MaskHelper(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &output, | |||
| const std::shared_ptr<Tensor> &value_tensor, RelationalOp op) { | |||
| T value; | |||
| RETURN_IF_NOT_OK(value_tensor->GetItemAt(&value, {})); | |||
| auto in_itr = input->begin<T>(); | |||
| auto out_itr = output->begin<bool>(); | |||
| for (; in_itr != input->end<T>(); in_itr++, out_itr++) { | |||
| switch (op) { | |||
| case RelationalOp::kEqual: | |||
| *out_itr = (*in_itr == value); | |||
| break; | |||
| case RelationalOp::kNotEqual: | |||
| *out_itr = (*in_itr != value); | |||
| break; | |||
| case RelationalOp::kGreater: | |||
| *out_itr = (*in_itr > value); | |||
| break; | |||
| case RelationalOp::kGreaterEqual: | |||
| *out_itr = (*in_itr >= value); | |||
| break; | |||
| case RelationalOp::kLess: | |||
| *out_itr = (*in_itr < value); | |||
| break; | |||
| case RelationalOp::kLessEqual: | |||
| *out_itr = (*in_itr <= value); | |||
| break; | |||
| default: | |||
| RETURN_STATUS_UNEXPECTED("Unknown relational operator."); | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status Mask(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const std::shared_ptr<Tensor> &value, | |||
| RelationalOp op) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(input->type().IsNumeric() == value->type().IsNumeric(), | |||
| "Cannot convert constant value to the type of the input tensor."); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(value->shape() == TensorShape::CreateScalar(), "Value is not a scalar"); | |||
| RETURN_IF_NOT_OK(Tensor::CreateTensor(output, TensorImpl::kFlexible, input->shape(), DataType(DataType::DE_BOOL))); | |||
| std::unique_ptr<TypeCastOp> value_cast_op(new TypeCastOp(input->type())); | |||
| std::shared_ptr<Tensor> casted_value; | |||
| if (input->type().IsNumeric()) { | |||
| RETURN_IF_NOT_OK(value_cast_op->Compute(value, &casted_value)); | |||
| } else { | |||
| casted_value = value; | |||
| } | |||
| switch (input->type().value()) { | |||
| case DataType::DE_BOOL: | |||
| RETURN_IF_NOT_OK(MaskHelper<bool>(input, *output, casted_value, op)); | |||
| break; | |||
| case DataType::DE_INT8: | |||
| RETURN_IF_NOT_OK(MaskHelper<int8_t>(input, *output, casted_value, op)); | |||
| break; | |||
| case DataType::DE_UINT8: | |||
| RETURN_IF_NOT_OK(MaskHelper<uint8_t>(input, *output, casted_value, op)); | |||
| break; | |||
| case DataType::DE_UINT16: | |||
| RETURN_IF_NOT_OK(MaskHelper<uint16_t>(input, *output, casted_value, op)); | |||
| break; | |||
| case DataType::DE_INT16: | |||
| RETURN_IF_NOT_OK(MaskHelper<int16_t>(input, *output, casted_value, op)); | |||
| break; | |||
| case DataType::DE_UINT32: | |||
| RETURN_IF_NOT_OK(MaskHelper<uint32_t>(input, *output, casted_value, op)); | |||
| break; | |||
| case DataType::DE_INT32: | |||
| RETURN_IF_NOT_OK(MaskHelper<int32_t>(input, *output, casted_value, op)); | |||
| break; | |||
| case DataType::DE_UINT64: | |||
| RETURN_IF_NOT_OK(MaskHelper<uint64_t>(input, *output, casted_value, op)); | |||
| break; | |||
| case DataType::DE_INT64: | |||
| RETURN_IF_NOT_OK(MaskHelper<int64_t>(input, *output, casted_value, op)); | |||
| break; | |||
| case DataType::DE_FLOAT16: | |||
| RETURN_IF_NOT_OK(MaskHelper<float16>(input, *output, casted_value, op)); | |||
| break; | |||
| case DataType::DE_FLOAT32: | |||
| RETURN_IF_NOT_OK(MaskHelper<float>(input, *output, casted_value, op)); | |||
| break; | |||
| case DataType::DE_FLOAT64: | |||
| RETURN_IF_NOT_OK(MaskHelper<double>(input, *output, casted_value, op)); | |||
| break; | |||
| case DataType::DE_STRING: | |||
| RETURN_IF_NOT_OK(MaskHelper<std::string_view>(input, *output, casted_value, op)); | |||
| break; | |||
| case DataType::DE_UNKNOWN: | |||
| RETURN_STATUS_UNEXPECTED("Unsupported input type."); | |||
| break; | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -119,6 +119,35 @@ Status PadEndString(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> | |||
| Status PadEndStringHelper(const std::shared_ptr<Tensor> &src, std::vector<std::string> *dst, | |||
| const TensorShape &dst_shape, std::vector<dsize_t> cur_ind, size_t cur_dim, | |||
| const std::string &pad_value); | |||
| enum class RelationalOp { | |||
| kEqual = 0, // == | |||
| kNotEqual, // != | |||
| kLess, // < | |||
| kLessEqual, // <= | |||
| kGreater, // > | |||
| kGreaterEqual, // >= | |||
| }; | |||
| /// Helper method that masks the input tensor | |||
| /// @tparam T type of the tensor | |||
| /// @param input[in] input tensor | |||
| /// @param output[out] output tensor | |||
| /// @param value_tensor[in] scalar tensor value to compared with | |||
| /// @param op[in] RelationalOp enum | |||
| /// @return Status ok/error | |||
| template <typename T> | |||
| Status MaskHelper(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &output, | |||
| const std::shared_ptr<Tensor> &value_tensor, RelationalOp op); | |||
| /// Mask the input tensor | |||
| /// @param input[in] input tensor | |||
| /// @param output[out] output tensor | |||
| /// @param value[in] scalar tensor value to compared with | |||
| /// @param op[in] RelationalOp enum | |||
| /// @return Status ok/error | |||
| Status Mask(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const std::shared_ptr<Tensor> &value, | |||
| RelationalOp op); | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,49 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "dataset/kernels/data/mask_op.h" | |||
| #include "dataset/core/tensor.h" | |||
| #include "dataset/kernels/tensor_op.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| Status MaskOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { | |||
| IO_CHECK(input, output); | |||
| std::shared_ptr<Tensor> temp_output; | |||
| CHECK_FAIL_RETURN_UNEXPECTED(type_.IsNumeric(), "Cannot generate a string mask. Type should be numeric."); | |||
| RETURN_IF_NOT_OK(Mask(input, &temp_output, value_, op_)); | |||
| // cast the output to the the required type. Skip casting if type_ is bool. | |||
| if (type_ != DataType::DE_BOOL) { | |||
| RETURN_IF_NOT_OK(cast_->Compute(temp_output, output)); | |||
| } else { | |||
| *output = temp_output; | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status MaskOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) { | |||
| RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs)); | |||
| outputs[0] = type_; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,54 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef DATASET_KERNELS_DATA_MASK_OP_H_ | |||
| #define DATASET_KERNELS_DATA_MASK_OP_H_ | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "dataset/core/tensor.h" | |||
| #include "dataset/kernels/tensor_op.h" | |||
| #include "dataset/kernels/data/type_cast_op.h" | |||
| #include "dataset/kernels/data/data_utils.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class MaskOp : public TensorOp { | |||
| public: | |||
| MaskOp(RelationalOp op, std::shared_ptr<Tensor> value, DataType type = DataType(DataType::DE_BOOL)) | |||
| : op_(op), value_(std::move(value)), type_(type), cast_(new TypeCastOp(type)) {} | |||
| ~MaskOp() override = default; | |||
| void Print(std::ostream &out) const override { out << "MaskOp"; } | |||
| Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override; | |||
| Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override; | |||
| private: | |||
| RelationalOp op_; | |||
| std::shared_ptr<Tensor> value_; | |||
| DataType type_; | |||
| std::unique_ptr<TypeCastOp> cast_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // DATASET_KERNELS_DATA_MASK_OP_H_ | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -16,7 +16,6 @@ | |||
| #include "dataset/kernels/data/slice_op.h" | |||
| #include "dataset/core/tensor.h" | |||
| #include "dataset/kernels/data/data_utils.h" | |||
| #include "dataset/kernels/tensor_op.h" | |||
| namespace mindspore { | |||
| @@ -36,8 +36,8 @@ class Slice { | |||
| std::vector<dsize_t> Indices(dsize_t length) { | |||
| std::vector<dsize_t> indices; | |||
| dsize_t index = std::min(Tensor::handleNeg(start_, length), length); | |||
| dsize_t end_index = std::min(Tensor::handleNeg(stop_, length), length); | |||
| dsize_t index = std::min(Tensor::HandleNeg(start_, length), length); | |||
| dsize_t end_index = std::min(Tensor::HandleNeg(stop_, length), length); | |||
| if (step_ > 0) { | |||
| for (; index < end_index; index += step_) { | |||
| indices.push_back(index); | |||
| @@ -80,4 +80,4 @@ class SliceOp : public TensorOp { | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // DATASET_KERNELS_DATA_ONE_HOT_OP_H_ | |||
| #endif // DATASET_KERNELS_DATA_SLICE_OP_H_ | |||
| @@ -15,10 +15,14 @@ | |||
| """ | |||
| This module c_transforms provides common operations, including OneHotOp and TypeCast. | |||
| """ | |||
| import numpy as np | |||
| from enum import IntEnum | |||
| import mindspore.common.dtype as mstype | |||
| import mindspore._c_dataengine as cde | |||
| from .validators import check_num_classes, check_de_type, check_fill_value, check_slice_op | |||
| import numpy as np | |||
| from .validators import check_num_classes, check_de_type, check_fill_value, check_slice_op, check_mask_op | |||
| from ..core.datatypes import mstype_to_detype | |||
| @@ -48,7 +52,6 @@ class Fill(cde.FillOp): | |||
| @check_fill_value | |||
| def __init__(self, fill_value): | |||
| print(fill_value) | |||
| super().__init__(cde.Tensor(np.array(fill_value))) | |||
| @@ -108,3 +111,50 @@ class Slice(cde.SliceOp): | |||
| elif dim0 is Ellipsis: | |||
| dim0 = True | |||
| super().__init__(dim0) | |||
| class Relational(IntEnum): | |||
| EQ = 0 | |||
| NE = 1 | |||
| GT = 2 | |||
| GE = 3 | |||
| LT = 4 | |||
| LE = 5 | |||
| DE_C_RELATIONAL = {Relational.EQ: cde.RelationalOp.EQ, | |||
| Relational.NE: cde.RelationalOp.NE, | |||
| Relational.GT: cde.RelationalOp.GT, | |||
| Relational.GE: cde.RelationalOp.GE, | |||
| Relational.LT: cde.RelationalOp.LT, | |||
| Relational.LE: cde.RelationalOp.LE} | |||
| class Mask(cde.MaskOp): | |||
| """ | |||
| Mask content of the input tensor with the given predicate. | |||
| Any element of the tensor that matches the predicate will be evaluated to True, otherwise False. | |||
| Args: | |||
| operator (Relational): One of the relational operator EQ, NE LT, GT, LE or GE | |||
| constant (python types (str, int, float, or bool): constant to be compared to. | |||
| Constant will be casted to the type of the input tensor | |||
| dtype (optional, mindspore.dtype): type of the generated mask. Default to bool | |||
| Examples: | |||
| >>> # Data before | |||
| >>> # | col1 | | |||
| >>> # +---------+ | |||
| >>> # | [1,2,3] | | |||
| >>> # +---------+ | |||
| >>> data = data.map(operations=Mask(Relational.EQ, 2)) | |||
| >>> # Data after | |||
| >>> # | col1 | | |||
| >>> # +--------------------+ | |||
| >>> # | [False,True,False] | | |||
| >>> # +--------------------+ | |||
| """ | |||
| @check_mask_op | |||
| def __init__(self, operator, constant, dtype=mstype.bool_): | |||
| dtype = mstype_to_detype(dtype) | |||
| constant = cde.Tensor(np.array(constant)) | |||
| super().__init__(DE_C_RELATIONAL[operator], constant, dtype) | |||
| @@ -213,3 +213,40 @@ def check_slice_op(method): | |||
| return method(self, *args) | |||
| return new_method | |||
| def check_mask_op(method): | |||
| """Wrapper method to check the parameters of slice.""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| operator, constant, dtype = (list(args) + 3 * [None])[:3] | |||
| if "operator" in kwargs: | |||
| operator = kwargs.get("operator") | |||
| if "constant" in kwargs: | |||
| constant = kwargs.get("constant") | |||
| if "dtype" in kwargs: | |||
| dtype = kwargs.get("dtype") | |||
| if operator is None: | |||
| raise ValueError("operator is not provided.") | |||
| if constant is None: | |||
| raise ValueError("constant is not provided.") | |||
| from .c_transforms import Relational | |||
| if not isinstance(operator, Relational): | |||
| raise TypeError("operator is not a Relational operator enum.") | |||
| if not isinstance(constant, (str, float, bool, int)): | |||
| raise TypeError("constant must be either a primitive python str, float, bool, or int") | |||
| if not isinstance(dtype, typing.Type): | |||
| raise TypeError("dtype is not a MindSpore data type.") | |||
| kwargs["operator"] = operator | |||
| kwargs["constant"] = constant | |||
| kwargs["dtype"] = dtype | |||
| return method(self, **kwargs) | |||
| return new_method | |||
| @@ -0,0 +1,63 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <memory> | |||
| #include <string> | |||
| #include "dataset/core/client.h" | |||
| #include "common/common.h" | |||
| #include "gtest/gtest.h" | |||
| #include "securec.h" | |||
| #include "dataset/core/tensor.h" | |||
| #include "dataset/core/cv_tensor.h" | |||
| #include "dataset/core/data_type.h" | |||
| #include "dataset/util/de_error.h" | |||
| #include "dataset/kernels/data/mask_op.h" | |||
| #include "dataset/kernels/data/data_utils.h" | |||
| using namespace mindspore::dataset; | |||
| namespace py = pybind11; | |||
| class MindDataTestMaskOp : public UT::Common { | |||
| public: | |||
| MindDataTestMaskOp() {} | |||
| void SetUp() { GlobalInit(); } | |||
| }; | |||
| TEST_F(MindDataTestMaskOp, Basics) { | |||
| std::shared_ptr<Tensor> t; | |||
| Tensor::CreateTensor(&t, std::vector<uint32_t>({1, 2, 3, 4, 5, 6})); | |||
| std::shared_ptr<Tensor> v; | |||
| Tensor::CreateTensor(&v, std::vector<uint32_t>({3}), TensorShape::CreateScalar()); | |||
| std::shared_ptr<MaskOp> op = std::make_shared<MaskOp>(RelationalOp::kEqual, v, DataType(DataType::DE_UINT16)); | |||
| std::shared_ptr<Tensor> out; | |||
| ASSERT_TRUE(op->Compute(t, &out).IsOk()); | |||
| op = std::make_shared<MaskOp>(RelationalOp::kNotEqual, v, DataType(DataType::DE_UINT16)); | |||
| ASSERT_TRUE(op->Compute(t, &out).IsOk()); | |||
| op = std::make_shared<MaskOp>(RelationalOp::kLessEqual, v, DataType(DataType::DE_UINT16)); | |||
| ASSERT_TRUE(op->Compute(t, &out).IsOk()); | |||
| op = std::make_shared<MaskOp>(RelationalOp::kLess, v, DataType(DataType::DE_UINT16)); | |||
| ASSERT_TRUE(op->Compute(t, &out).IsOk()); | |||
| op = std::make_shared<MaskOp>(RelationalOp::kGreaterEqual, v, DataType(DataType::DE_UINT16)); | |||
| ASSERT_TRUE(op->Compute(t, &out).IsOk()); | |||
| op = std::make_shared<MaskOp>(RelationalOp::kGreater, v, DataType(DataType::DE_UINT16)); | |||
| ASSERT_TRUE(op->Compute(t, &out).IsOk()); | |||
| } | |||
| @@ -0,0 +1,132 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """ | |||
| Testing Mask op in DE | |||
| """ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.common.dtype as mstype | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.transforms.c_transforms as ops | |||
| mstype_to_np_type = { | |||
| mstype.bool_: np.bool, | |||
| mstype.int8: np.int8, | |||
| mstype.uint8: np.uint8, | |||
| mstype.int16: np.int16, | |||
| mstype.uint16: np.uint16, | |||
| mstype.int32: np.int32, | |||
| mstype.uint32: np.uint32, | |||
| mstype.int64: np.int64, | |||
| mstype.uint64: np.uint64, | |||
| mstype.float16: np.float16, | |||
| mstype.float32: np.float32, | |||
| mstype.float64: np.float64, | |||
| mstype.string: np.str | |||
| } | |||
| def mask_compare(array, op, constant, dtype=mstype.bool_): | |||
| data = ds.NumpySlicesDataset([array]) | |||
| array = np.array(array) | |||
| data = data.map(operations=ops.Mask(op, constant, dtype)) | |||
| for d in data: | |||
| if op == ops.Relational.EQ: | |||
| array = array == np.array(constant, dtype=array.dtype) | |||
| elif op == ops.Relational.NE: | |||
| array = array != np.array(constant, dtype=array.dtype) | |||
| elif op == ops.Relational.GT: | |||
| array = array > np.array(constant, dtype=array.dtype) | |||
| elif op == ops.Relational.GE: | |||
| array = array >= np.array(constant, dtype=array.dtype) | |||
| elif op == ops.Relational.LT: | |||
| array = array < np.array(constant, dtype=array.dtype) | |||
| elif op == ops.Relational.LE: | |||
| array = array <= np.array(constant, dtype=array.dtype) | |||
| array = array.astype(dtype=mstype_to_np_type[dtype]) | |||
| np.testing.assert_array_equal(array, d[0]) | |||
| def test_int_comparison(): | |||
| for k in mstype_to_np_type: | |||
| if k == mstype.string: | |||
| continue | |||
| mask_compare([1, 2, 3, 4, 5], ops.Relational.EQ, 3, k) | |||
| mask_compare([1, 2, 3, 4, 5], ops.Relational.NE, 3, k) | |||
| mask_compare([1, 2, 3, 4, 5], ops.Relational.LT, 3, k) | |||
| mask_compare([1, 2, 3, 4, 5], ops.Relational.LE, 3, k) | |||
| mask_compare([1, 2, 3, 4, 5], ops.Relational.GT, 3, k) | |||
| mask_compare([1, 2, 3, 4, 5], ops.Relational.GE, 3, k) | |||
| def test_float_comparison(): | |||
| for k in mstype_to_np_type: | |||
| if k == mstype.string: | |||
| continue | |||
| mask_compare([1.5, 2.5, 3., 4.5, 5.5], ops.Relational.EQ, 3, k) | |||
| mask_compare([1.5, 2.5, 3., 4.5, 5.5], ops.Relational.NE, 3, k) | |||
| mask_compare([1.5, 2.5, 3., 4.5, 5.5], ops.Relational.LT, 3, k) | |||
| mask_compare([1.5, 2.5, 3., 4.5, 5.5], ops.Relational.LE, 3, k) | |||
| mask_compare([1.5, 2.5, 3., 4.5, 5.5], ops.Relational.GT, 3, k) | |||
| mask_compare([1.5, 2.5, 3., 4.5, 5.5], ops.Relational.GE, 3, k) | |||
| def test_float_comparison2(): | |||
| for k in mstype_to_np_type: | |||
| if k == mstype.string: | |||
| continue | |||
| mask_compare([1, 2, 3, 4, 5], ops.Relational.EQ, 3.5, k) | |||
| mask_compare([1, 2, 3, 4, 5], ops.Relational.NE, 3.5, k) | |||
| mask_compare([1, 2, 3, 4, 5], ops.Relational.LT, 3.5, k) | |||
| mask_compare([1, 2, 3, 4, 5], ops.Relational.LE, 3.5, k) | |||
| mask_compare([1, 2, 3, 4, 5], ops.Relational.GT, 3.5, k) | |||
| mask_compare([1, 2, 3, 4, 5], ops.Relational.GE, 3.5, k) | |||
| def test_string_comparison(): | |||
| for k in mstype_to_np_type: | |||
| if k == mstype.string: | |||
| continue | |||
| mask_compare(["1.5", "2.5", "3.", "4.5", "5.5"], ops.Relational.EQ, "3.", k) | |||
| mask_compare(["1.5", "2.5", "3.", "4.5", "5.5"], ops.Relational.NE, "3.", k) | |||
| mask_compare(["1.5", "2.5", "3.", "4.5", "5.5"], ops.Relational.LT, "3.", k) | |||
| mask_compare(["1.5", "2.5", "3.", "4.5", "5.5"], ops.Relational.LE, "3.", k) | |||
| mask_compare(["1.5", "2.5", "3.", "4.5", "5.5"], ops.Relational.GT, "3.", k) | |||
| mask_compare(["1.5", "2.5", "3.", "4.5", "5.5"], ops.Relational.GE, "3.", k) | |||
| def test_mask_exceptions_str(): | |||
| with pytest.raises(RuntimeError) as info: | |||
| mask_compare([1, 2, 3, 4, 5], ops.Relational.EQ, "3.5") | |||
| assert "Cannot convert constant value to the type of the input tensor." in str(info.value) | |||
| with pytest.raises(RuntimeError) as info: | |||
| mask_compare(["1", "2", "3", "4", "5"], ops.Relational.EQ, 3.5) | |||
| assert "Cannot convert constant value to the type of the input tensor." in str(info.value) | |||
| with pytest.raises(RuntimeError) as info: | |||
| mask_compare(["1", "2", "3", "4", "5"], ops.Relational.EQ, "3.5", mstype.string) | |||
| assert "Cannot generate a string mask. Type should be numeric." in str(info.value) | |||
| if __name__ == "__main__": | |||
| test_int_comparison() | |||
| test_float_comparison() | |||
| test_float_comparison2() | |||
| test_string_comparison() | |||
| test_mask_exceptions_str() | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -13,7 +13,7 @@ | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """ | |||
| Testing TypeCast op in DE | |||
| Testing Slice op in DE | |||
| """ | |||
| import numpy as np | |||
| import pytest | |||
| @@ -109,6 +109,10 @@ def test_slice_exceptions(): | |||
| slice_compare([1, 2, 3, 4, 5], slice(0)) | |||
| assert "Indices are empty, generated tensor would be empty." in str(info.value) | |||
| with pytest.raises(RuntimeError) as info: | |||
| slice_compare([1, 2, 3, 4, 5], slice(3, 1, 1)) | |||
| assert "Indices are empty, generated tensor would be empty." in str(info.value) | |||
| with pytest.raises(RuntimeError) as info: | |||
| slice_compare([1, 2, 3, 4, 5], slice(5, 10, 1)) | |||
| assert "Indices are empty, generated tensor would be empty." in str(info.value) | |||
| @@ -182,6 +186,10 @@ def test_slice_exceptions_str(): | |||
| slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(0)) | |||
| assert "Indices are empty, generated tensor would be empty." in str(info.value) | |||
| with pytest.raises(RuntimeError) as info: | |||
| slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(3, 1, 1)) | |||
| assert "Indices are empty, generated tensor would be empty." in str(info.value) | |||
| with pytest.raises(RuntimeError) as info: | |||
| slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(5, 10, 1)) | |||
| assert "Indices are empty, generated tensor would be empty." in str(info.value) | |||