From: @cathwong Reviewed-by: @robingrosman,@pandoublefeng Signed-off-by: @pandoublefengpull/13789/MERGE
| @@ -20,15 +20,11 @@ | |||||
| #include "minddata/dataset/api/python/pybind_register.h" | #include "minddata/dataset/api/python/pybind_register.h" | ||||
| #include "minddata/dataset/core/tensor_helpers.h" | #include "minddata/dataset/core/tensor_helpers.h" | ||||
| #include "minddata/dataset/kernels/data/concatenate_op.h" | #include "minddata/dataset/kernels/data/concatenate_op.h" | ||||
| #include "minddata/dataset/kernels/data/duplicate_op.h" | |||||
| #include "minddata/dataset/kernels/data/fill_op.h" | #include "minddata/dataset/kernels/data/fill_op.h" | ||||
| #include "minddata/dataset/kernels/data/mask_op.h" | #include "minddata/dataset/kernels/data/mask_op.h" | ||||
| #include "minddata/dataset/kernels/data/one_hot_op.h" | |||||
| #include "minddata/dataset/kernels/data/pad_end_op.h" | #include "minddata/dataset/kernels/data/pad_end_op.h" | ||||
| #include "minddata/dataset/kernels/data/slice_op.h" | #include "minddata/dataset/kernels/data/slice_op.h" | ||||
| #include "minddata/dataset/kernels/data/to_float16_op.h" | #include "minddata/dataset/kernels/data/to_float16_op.h" | ||||
| #include "minddata/dataset/kernels/data/type_cast_op.h" | |||||
| #include "minddata/dataset/kernels/data/unique_op.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| @@ -38,15 +34,6 @@ PYBIND_REGISTER(ConcatenateOp, 1, ([](const py::module *m) { | |||||
| .def(py::init<int8_t, std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>()); | .def(py::init<int8_t, std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>()); | ||||
| })); | })); | ||||
| PYBIND_REGISTER( | |||||
| DuplicateOp, 1, ([](const py::module *m) { | |||||
| (void)py::class_<DuplicateOp, TensorOp, std::shared_ptr<DuplicateOp>>(*m, "DuplicateOp").def(py::init<>()); | |||||
| })); | |||||
| PYBIND_REGISTER(UniqueOp, 1, ([](const py::module *m) { | |||||
| (void)py::class_<UniqueOp, TensorOp, std::shared_ptr<UniqueOp>>(*m, "UniqueOp").def(py::init<>()); | |||||
| })); | |||||
| PYBIND_REGISTER( | PYBIND_REGISTER( | ||||
| FillOp, 1, ([](const py::module *m) { | FillOp, 1, ([](const py::module *m) { | ||||
| (void)py::class_<FillOp, TensorOp, std::shared_ptr<FillOp>>(*m, "FillOp").def(py::init<std::shared_ptr<Tensor>>()); | (void)py::class_<FillOp, TensorOp, std::shared_ptr<FillOp>>(*m, "FillOp").def(py::init<std::shared_ptr<Tensor>>()); | ||||
| @@ -57,11 +44,6 @@ PYBIND_REGISTER(MaskOp, 1, ([](const py::module *m) { | |||||
| .def(py::init<RelationalOp, std::shared_ptr<Tensor>, DataType>()); | .def(py::init<RelationalOp, std::shared_ptr<Tensor>, DataType>()); | ||||
| })); | })); | ||||
| PYBIND_REGISTER( | |||||
| OneHotOp, 1, ([](const py::module *m) { | |||||
| (void)py::class_<OneHotOp, TensorOp, std::shared_ptr<OneHotOp>>(*m, "OneHotOp").def(py::init<int32_t>()); | |||||
| })); | |||||
| PYBIND_REGISTER(PadEndOp, 1, ([](const py::module *m) { | PYBIND_REGISTER(PadEndOp, 1, ([](const py::module *m) { | ||||
| (void)py::class_<PadEndOp, TensorOp, std::shared_ptr<PadEndOp>>(*m, "PadEndOp") | (void)py::class_<PadEndOp, TensorOp, std::shared_ptr<PadEndOp>>(*m, "PadEndOp") | ||||
| .def(py::init<TensorShape, std::shared_ptr<Tensor>>()); | .def(py::init<TensorShape, std::shared_ptr<Tensor>>()); | ||||
| @@ -111,12 +93,6 @@ PYBIND_REGISTER(ToFloat16Op, 1, ([](const py::module *m) { | |||||
| .def(py::init<>()); | .def(py::init<>()); | ||||
| })); | })); | ||||
| PYBIND_REGISTER(TypeCastOp, 1, ([](const py::module *m) { | |||||
| (void)py::class_<TypeCastOp, TensorOp, std::shared_ptr<TypeCastOp>>(*m, "TypeCastOp") | |||||
| .def(py::init<DataType>()) | |||||
| .def(py::init<std::string>()); | |||||
| })); | |||||
| PYBIND_REGISTER(RelationalOp, 0, ([](const py::module *m) { | PYBIND_REGISTER(RelationalOp, 0, ([](const py::module *m) { | ||||
| (void)py::enum_<RelationalOp>(*m, "RelationalOp", py::arithmetic()) | (void)py::enum_<RelationalOp>(*m, "RelationalOp", py::arithmetic()) | ||||
| .value("EQ", RelationalOp::kEqual) | .value("EQ", RelationalOp::kEqual) | ||||
| @@ -64,6 +64,28 @@ PYBIND_REGISTER( | |||||
| })); | })); | ||||
| })); | })); | ||||
| PYBIND_REGISTER( | |||||
| DuplicateOperation, 1, ([](const py::module *m) { | |||||
| (void)py::class_<transforms::DuplicateOperation, TensorOperation, std::shared_ptr<transforms::DuplicateOperation>>( | |||||
| *m, "DuplicateOperation") | |||||
| .def(py::init([]() { | |||||
| auto duplicate = std::make_shared<transforms::DuplicateOperation>(); | |||||
| THROW_IF_ERROR(duplicate->ValidateParams()); | |||||
| return duplicate; | |||||
| })); | |||||
| })); | |||||
| PYBIND_REGISTER( | |||||
| OneHotOperation, 1, ([](const py::module *m) { | |||||
| (void)py::class_<transforms::OneHotOperation, TensorOperation, std::shared_ptr<transforms::OneHotOperation>>( | |||||
| *m, "OneHotOperation") | |||||
| .def(py::init([](int32_t num_classes) { | |||||
| auto one_hot = std::make_shared<transforms::OneHotOperation>(num_classes); | |||||
| THROW_IF_ERROR(one_hot->ValidateParams()); | |||||
| return one_hot; | |||||
| })); | |||||
| })); | |||||
| PYBIND_REGISTER(RandomChoiceOperation, 1, ([](const py::module *m) { | PYBIND_REGISTER(RandomChoiceOperation, 1, ([](const py::module *m) { | ||||
| (void)py::class_<transforms::RandomChoiceOperation, TensorOperation, | (void)py::class_<transforms::RandomChoiceOperation, TensorOperation, | ||||
| std::shared_ptr<transforms::RandomChoiceOperation>>(*m, "RandomChoiceOperation") | std::shared_ptr<transforms::RandomChoiceOperation>>(*m, "RandomChoiceOperation") | ||||
| @@ -87,5 +109,28 @@ PYBIND_REGISTER(RandomApplyOperation, 1, ([](const py::module *m) { | |||||
| return random_apply; | return random_apply; | ||||
| })); | })); | ||||
| })); | })); | ||||
| PYBIND_REGISTER( | |||||
| TypeCastOperation, 1, ([](const py::module *m) { | |||||
| (void)py::class_<transforms::TypeCastOperation, TensorOperation, std::shared_ptr<transforms::TypeCastOperation>>( | |||||
| *m, "TypeCastOperation") | |||||
| .def(py::init([](std::string data_type) { | |||||
| auto type_cast = std::make_shared<transforms::TypeCastOperation>(data_type); | |||||
| THROW_IF_ERROR(type_cast->ValidateParams()); | |||||
| return type_cast; | |||||
| })); | |||||
| })); | |||||
| PYBIND_REGISTER( | |||||
| UniqueOperation, 1, ([](const py::module *m) { | |||||
| (void)py::class_<transforms::UniqueOperation, TensorOperation, std::shared_ptr<transforms::UniqueOperation>>( | |||||
| *m, "UniqueOperation") | |||||
| .def(py::init([]() { | |||||
| auto unique = std::make_shared<transforms::UniqueOperation>(); | |||||
| THROW_IF_ERROR(unique->ValidateParams()); | |||||
| return unique; | |||||
| })); | |||||
| })); | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -64,7 +64,7 @@ std::shared_ptr<TensorOperation> Duplicate::Parse() { return std::make_shared<Du | |||||
| // Constructor to OneHot | // Constructor to OneHot | ||||
| struct OneHot::Data { | struct OneHot::Data { | ||||
| explicit Data(int32_t num_classes) : num_classes_(num_classes) {} | explicit Data(int32_t num_classes) : num_classes_(num_classes) {} | ||||
| float num_classes_; | |||||
| int32_t num_classes_; | |||||
| }; | }; | ||||
| OneHot::OneHot(int32_t num_classes) : data_(std::make_shared<Data>(num_classes)) {} | OneHot::OneHot(int32_t num_classes) : data_(std::make_shared<Data>(num_classes)) {} | ||||
| @@ -351,7 +351,7 @@ class SentencePieceTokenizer final : public TensorTransform { | |||||
| /// \param[in] vocab a SentencePieceVocab object. | /// \param[in] vocab a SentencePieceVocab object. | ||||
| /// \param[in] out_type The type of output. | /// \param[in] out_type The type of output. | ||||
| SentencePieceTokenizer(const std::shared_ptr<SentencePieceVocab> &vocab, | SentencePieceTokenizer(const std::shared_ptr<SentencePieceVocab> &vocab, | ||||
| mindspore::dataset::SPieceTokenizerOutType out_typee); | |||||
| mindspore::dataset::SPieceTokenizerOutType out_type); | |||||
| /// \brief Constructor. | /// \brief Constructor. | ||||
| /// \param[in] vocab_path vocab model file path. | /// \param[in] vocab_path vocab model file path. | ||||
| @@ -398,14 +398,14 @@ class SlidingWindow final : public TensorTransform { | |||||
| }; | }; | ||||
| /// \brief Tensor operation to convert every element of a string tensor to a number. | /// \brief Tensor operation to convert every element of a string tensor to a number. | ||||
| /// Strings are casted according to the rules specified in the following links: | |||||
| /// Strings are cast according to the rules specified in the following links: | |||||
| /// https://en.cppreference.com/w/cpp/string/basic_string/stof, | /// https://en.cppreference.com/w/cpp/string/basic_string/stof, | ||||
| /// https://en.cppreference.com/w/cpp/string/basic_string/stoul, | /// https://en.cppreference.com/w/cpp/string/basic_string/stoul, | ||||
| /// except that any strings which represent negative numbers cannot be cast to an unsigned integer type. | /// except that any strings which represent negative numbers cannot be cast to an unsigned integer type. | ||||
| class ToNumber final : public TensorTransform { | class ToNumber final : public TensorTransform { | ||||
| public: | public: | ||||
| /// \brief Constructor. | /// \brief Constructor. | ||||
| /// \param[in] data_type of the tensor to be casted to. Must be a numeric type. | |||||
| /// \param[in] data_type of the tensor to be cast to. Must be a numeric type. | |||||
| explicit ToNumber(const std::string &data_type) : ToNumber(StringToChar(data_type)) {} | explicit ToNumber(const std::string &data_type) : ToNumber(StringToChar(data_type)) {} | ||||
| explicit ToNumber(const std::vector<char> &data_type); | explicit ToNumber(const std::vector<char> &data_type); | ||||
| @@ -38,11 +38,5 @@ Status OneHotOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector | |||||
| return Status(StatusCode::kMDUnexpectedError, "OneHot: invalid input shape."); | return Status(StatusCode::kMDUnexpectedError, "OneHot: invalid input shape."); | ||||
| } | } | ||||
| Status OneHotOp::to_json(nlohmann::json *out_json) { | |||||
| nlohmann::json args; | |||||
| args["num_classes"] = num_classes_; | |||||
| *out_json = args; | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -37,8 +37,6 @@ class OneHotOp : public TensorOp { | |||||
| std::string Name() const override { return kOneHotOp; } | std::string Name() const override { return kOneHotOp; } | ||||
| Status to_json(nlohmann::json *out_json) override; | |||||
| private: | private: | ||||
| int num_classes_; | int num_classes_; | ||||
| }; | }; | ||||
| @@ -34,11 +34,5 @@ Status TypeCastOp::OutputType(const std::vector<DataType> &inputs, std::vector<D | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status TypeCastOp::to_json(nlohmann::json *out_json) { | |||||
| nlohmann::json args; | |||||
| args["data_type"] = type_.ToString(); | |||||
| *out_json = args; | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -43,8 +43,6 @@ class TypeCastOp : public TensorOp { | |||||
| std::string Name() const override { return kTypeCastOp; } | std::string Name() const override { return kTypeCastOp; } | ||||
| Status to_json(nlohmann::json *out_json) override; | |||||
| private: | private: | ||||
| DataType type_; | DataType type_; | ||||
| }; | }; | ||||
| @@ -78,6 +78,13 @@ Status OneHotOperation::ValidateParams() { | |||||
| std::shared_ptr<TensorOp> OneHotOperation::Build() { return std::make_shared<OneHotOp>(num_classes_); } | std::shared_ptr<TensorOp> OneHotOperation::Build() { return std::make_shared<OneHotOp>(num_classes_); } | ||||
| Status OneHotOperation::to_json(nlohmann::json *out_json) { | |||||
| nlohmann::json args; | |||||
| args["num_classes"] = num_classes_; | |||||
| *out_json = args; | |||||
| return Status::OK(); | |||||
| } | |||||
| // PreBuiltOperation | // PreBuiltOperation | ||||
| PreBuiltOperation::PreBuiltOperation(std::shared_ptr<TensorOp> tensor_op) : op_(tensor_op) { | PreBuiltOperation::PreBuiltOperation(std::shared_ptr<TensorOp> tensor_op) : op_(tensor_op) { | ||||
| #ifdef ENABLE_PYTHON | #ifdef ENABLE_PYTHON | ||||
| @@ -149,6 +156,13 @@ Status TypeCastOperation::ValidateParams() { | |||||
| std::shared_ptr<TensorOp> TypeCastOperation::Build() { return std::make_shared<TypeCastOp>(data_type_); } | std::shared_ptr<TensorOp> TypeCastOperation::Build() { return std::make_shared<TypeCastOp>(data_type_); } | ||||
| Status TypeCastOperation::to_json(nlohmann::json *out_json) { | |||||
| nlohmann::json args; | |||||
| args["data_type"] = data_type_; | |||||
| *out_json = args; | |||||
| return Status::OK(); | |||||
| } | |||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| // UniqueOperation | // UniqueOperation | ||||
| Status UniqueOperation::ValidateParams() { return Status::OK(); } | Status UniqueOperation::ValidateParams() { return Status::OK(); } | ||||
| @@ -81,8 +81,10 @@ class OneHotOperation : public TensorOperation { | |||||
| std::string Name() const override { return kOneHotOperation; } | std::string Name() const override { return kOneHotOperation; } | ||||
| Status to_json(nlohmann::json *out_json) override; | |||||
| private: | private: | ||||
| float num_classes_; | |||||
| int32_t num_classes_; | |||||
| }; | }; | ||||
| class PreBuiltOperation : public TensorOperation { | class PreBuiltOperation : public TensorOperation { | ||||
| @@ -147,6 +149,8 @@ class TypeCastOperation : public TensorOperation { | |||||
| std::string Name() const override { return kTypeCastOperation; } | std::string Name() const override { return kTypeCastOperation; } | ||||
| Status to_json(nlohmann::json *out_json) override; | |||||
| private: | private: | ||||
| std::string data_type_; | std::string data_type_; | ||||
| }; | }; | ||||
| @@ -362,8 +362,7 @@ def construct_tensor_ops(operations): | |||||
| if hasattr(op_module_vis, op_name): | if hasattr(op_module_vis, op_name): | ||||
| op_class = getattr(op_module_vis, op_name, None) | op_class = getattr(op_module_vis, op_name, None) | ||||
| elif hasattr(op_module_trans, op_name[:-2]): | |||||
| op_name = op_name[:-2] # to remove op from the back of the name | |||||
| elif hasattr(op_module_trans, op_name): | |||||
| op_class = getattr(op_module_trans, op_name, None) | op_class = getattr(op_module_trans, op_name, None) | ||||
| else: | else: | ||||
| raise RuntimeError(op_name + " is not yet supported by deserialize().") | raise RuntimeError(op_name + " is not yet supported by deserialize().") | ||||
| @@ -387,18 +387,18 @@ class ToNumber(TextTensorOperation): | |||||
| """ | """ | ||||
| Tensor operation to convert every element of a string tensor to a number. | Tensor operation to convert every element of a string tensor to a number. | ||||
| Strings are casted according to the rules specified in the following links: | |||||
| Strings are cast according to the rules specified in the following links: | |||||
| https://en.cppreference.com/w/cpp/string/basic_string/stof, | https://en.cppreference.com/w/cpp/string/basic_string/stof, | ||||
| https://en.cppreference.com/w/cpp/string/basic_string/stoul, | https://en.cppreference.com/w/cpp/string/basic_string/stoul, | ||||
| except that any strings which represent negative numbers cannot be cast to an | except that any strings which represent negative numbers cannot be cast to an | ||||
| unsigned integer type. | unsigned integer type. | ||||
| Args: | Args: | ||||
| data_type (mindspore.dtype): mindspore.dtype to be casted to. Must be | |||||
| data_type (mindspore.dtype): mindspore.dtype to be cast to. Must be | |||||
| a numeric type. | a numeric type. | ||||
| Raises: | Raises: | ||||
| RuntimeError: If strings are invalid to cast, or are out of range after being casted. | |||||
| RuntimeError: If strings are invalid to cast, or are out of range after being cast. | |||||
| Examples: | Examples: | ||||
| >>> import mindspore.common.dtype as mstype | >>> import mindspore.common.dtype as mstype | ||||
| @@ -21,7 +21,7 @@ import numpy as np | |||||
| import mindspore.common.dtype as mstype | import mindspore.common.dtype as mstype | ||||
| import mindspore._c_dataengine as cde | import mindspore._c_dataengine as cde | ||||
| from .validators import check_num_classes, check_de_type, check_fill_value, check_slice_option, check_slice_op, \ | |||||
| from .validators import check_num_classes, check_ms_type, check_fill_value, check_slice_option, check_slice_op, \ | |||||
| check_mask_op, check_pad_end, check_concat_type, check_random_transform_ops | check_mask_op, check_pad_end, check_concat_type, check_random_transform_ops | ||||
| from ..core.datatypes import mstype_to_detype | from ..core.datatypes import mstype_to_detype | ||||
| @@ -52,7 +52,7 @@ class TensorOperation: | |||||
| raise NotImplementedError("TensorOperation has to implement parse() method.") | raise NotImplementedError("TensorOperation has to implement parse() method.") | ||||
| class OneHot(cde.OneHotOp): | |||||
| class OneHot(TensorOperation): | |||||
| """ | """ | ||||
| Tensor operation to apply one hot encoding. | Tensor operation to apply one hot encoding. | ||||
| @@ -72,7 +72,9 @@ class OneHot(cde.OneHotOp): | |||||
| @check_num_classes | @check_num_classes | ||||
| def __init__(self, num_classes): | def __init__(self, num_classes): | ||||
| self.num_classes = num_classes | self.num_classes = num_classes | ||||
| super().__init__(num_classes) | |||||
| def parse(self): | |||||
| return cde.OneHotOperation(self.num_classes) | |||||
| class Fill(cde.FillOp): | class Fill(cde.FillOp): | ||||
| @@ -102,7 +104,7 @@ class Fill(cde.FillOp): | |||||
| super().__init__(cde.Tensor(np.array(fill_value))) | super().__init__(cde.Tensor(np.array(fill_value))) | ||||
| class TypeCast(cde.TypeCastOp): | |||||
| class TypeCast(TensorOperation): | |||||
| """ | """ | ||||
| Tensor operation to cast to a given MindSpore data type. | Tensor operation to cast to a given MindSpore data type. | ||||
| @@ -123,11 +125,13 @@ class TypeCast(cde.TypeCastOp): | |||||
| >>> dataset = dataset.map(operations=type_cast_op) | >>> dataset = dataset.map(operations=type_cast_op) | ||||
| """ | """ | ||||
| @check_de_type | |||||
| @check_ms_type | |||||
| def __init__(self, data_type): | def __init__(self, data_type): | ||||
| data_type = mstype_to_detype(data_type) | data_type = mstype_to_detype(data_type) | ||||
| self.data_type = str(data_type) | self.data_type = str(data_type) | ||||
| super().__init__(data_type) | |||||
| def parse(self): | |||||
| return cde.TypeCastOperation(self.data_type) | |||||
| class _SliceOption(cde.SliceOption): | class _SliceOption(cde.SliceOption): | ||||
| @@ -314,7 +318,7 @@ class Concatenate(cde.ConcatenateOp): | |||||
| super().__init__(axis, prepend, append) | super().__init__(axis, prepend, append) | ||||
| class Duplicate(cde.DuplicateOp): | |||||
| class Duplicate(TensorOperation): | |||||
| """ | """ | ||||
| Duplicate the input tensor to output, only support transform one column each time. | Duplicate the input tensor to output, only support transform one column each time. | ||||
| @@ -337,8 +341,11 @@ class Duplicate(cde.DuplicateOp): | |||||
| >>> # +---------+---------+ | >>> # +---------+---------+ | ||||
| """ | """ | ||||
| def parse(self): | |||||
| return cde.DuplicateOperation() | |||||
| class Unique(cde.UniqueOp): | |||||
| class Unique(TensorOperation): | |||||
| """ | """ | ||||
| Perform the unique operation on the input tensor, only support transform one column each time. | Perform the unique operation on the input tensor, only support transform one column each time. | ||||
| @@ -373,9 +380,11 @@ class Unique(cde.UniqueOp): | |||||
| >>> # +---------+-----------------+---------+ | >>> # +---------+-----------------+---------+ | ||||
| """ | """ | ||||
| def parse(self): | |||||
| return cde.UniqueOperation() | |||||
| class Compose(): | |||||
| class Compose(TensorOperation): | |||||
| """ | """ | ||||
| Compose a list of transforms into a single transform. | Compose a list of transforms into a single transform. | ||||
| @@ -401,7 +410,7 @@ class Compose(): | |||||
| return cde.ComposeOperation(operations) | return cde.ComposeOperation(operations) | ||||
| class RandomApply(): | |||||
| class RandomApply(TensorOperation): | |||||
| """ | """ | ||||
| Randomly perform a series of transforms with a given probability. | Randomly perform a series of transforms with a given probability. | ||||
| @@ -429,7 +438,7 @@ class RandomApply(): | |||||
| return cde.RandomApplyOperation(self.prob, operations) | return cde.RandomApplyOperation(self.prob, operations) | ||||
| class RandomChoice(): | |||||
| class RandomChoice(TensorOperation): | |||||
| """ | """ | ||||
| Randomly select one transform from a list of transforms to perform operation. | Randomly select one transform from a list of transforms to perform operation. | ||||
| @@ -87,7 +87,7 @@ def check_num_classes(method): | |||||
| return new_method | return new_method | ||||
| def check_de_type(method): | |||||
| def check_ms_type(method): | |||||
| """Wrapper method to check the parameters of data type.""" | """Wrapper method to check the parameters of data type.""" | ||||
| @wraps(method) | @wraps(method) | ||||