| @@ -19,6 +19,8 @@ | |||
| #include <algorithm> | |||
| #include "mindspore/ccsrc/minddata/dataset/core/type_id.h" | |||
| #include "mindspore/core/ir/dtype/type_id.h" | |||
| #include "minddata/dataset/core/type_id.h" | |||
| #include "minddata/dataset/kernels/ir/data/transforms_ir.h" | |||
| namespace mindspore { | |||
| @@ -211,11 +213,12 @@ std::shared_ptr<TensorOperation> Slice::Parse() { return std::make_shared<SliceO | |||
| // Constructor to TypeCast | |||
| struct TypeCast::Data { | |||
| explicit Data(const std::vector<char> &data_type) : data_type_(CharToString(data_type)) {} | |||
| std::string data_type_; | |||
| dataset::DataType data_type_; | |||
| }; | |||
| TypeCast::TypeCast(const std::vector<char> &data_type) : data_(std::make_shared<Data>(data_type)) {} | |||
| TypeCast::TypeCast(mindspore::DataType data_type) : data_(std::make_shared<Data>()) { | |||
| data_->data_type_ = dataset::MSTypeToDEType(static_cast<TypeId>(data_type)); | |||
| } | |||
| std::shared_ptr<TensorOperation> TypeCast::Parse() { return std::make_shared<TypeCastOperation>(data_->data_type_); } | |||
| @@ -24,6 +24,7 @@ | |||
| #include "include/api/dual_abi_helper.h" | |||
| #include "include/api/status.h" | |||
| #include "include/api/types.h" | |||
| #include "minddata/dataset/include/constants.h" | |||
| namespace mindspore { | |||
| @@ -349,10 +350,8 @@ class Slice final : public TensorTransform { | |||
| class TypeCast final : public TensorTransform { | |||
| public: | |||
| /// \brief Constructor. | |||
| /// \param[in] data_type mindspore.dtype to be cast to. | |||
| explicit TypeCast(std::string data_type) : TypeCast(StringToChar(data_type)) {} | |||
| explicit TypeCast(const std::vector<char> &data_type); | |||
| /// \param[in] data_type mindspore::DataType to be cast to. | |||
| explicit TypeCast(mindspore::DataType data_type); | |||
| /// \brief Destructor | |||
| ~TypeCast() = default; | |||
| @@ -15,6 +15,7 @@ | |||
| */ | |||
| #include <algorithm> | |||
| #include <typeinfo> | |||
| #include "minddata/dataset/kernels/ir/data/transforms_ir.h" | |||
| @@ -213,19 +214,22 @@ std::shared_ptr<TensorOp> SliceOperation::Build() { return std::make_shared<Slic | |||
| #endif | |||
| // TypeCastOperation | |||
| TypeCastOperation::TypeCastOperation(std::string data_type) : data_type_(data_type) {} | |||
| // DataType data_type - required for C++ API | |||
| TypeCastOperation::TypeCastOperation(DataType data_type) : data_type_(data_type) {} | |||
| // std::string data_type - required for Pybind | |||
| TypeCastOperation::TypeCastOperation(std::string data_type) { | |||
| // Convert from string to DEType | |||
| DataType temp_data_type(data_type); | |||
| data_type_ = temp_data_type; | |||
| } | |||
| Status TypeCastOperation::ValidateParams() { | |||
| std::vector<std::string> predefine_type = {"bool", "int8", "uint8", "int16", "uint16", "int32", "uint32", | |||
| "int64", "uint64", "float16", "float32", "float64", "string"}; | |||
| auto itr = std::find(predefine_type.begin(), predefine_type.end(), data_type_); | |||
| if (itr == predefine_type.end()) { | |||
| std::string err_msg = "TypeCast: Invalid data type: " + data_type_; | |||
| MS_LOG(ERROR) << "TypeCast: Only supports data type bool, int8, uint8, int16, uint16, int32, uint32, " | |||
| << "int64, uint64, float16, float32, float64, string, but got: " << data_type_; | |||
| if (data_type_ == DataType::DE_UNKNOWN) { | |||
| std::string err_msg = "TypeCast: Invalid data type"; | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| @@ -233,7 +237,7 @@ std::shared_ptr<TensorOp> TypeCastOperation::Build() { return std::make_shared<T | |||
| Status TypeCastOperation::to_json(nlohmann::json *out_json) { | |||
| nlohmann::json args; | |||
| args["data_type"] = data_type_; | |||
| args["data_type"] = data_type_.ToString(); | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| @@ -22,6 +22,7 @@ | |||
| #include <string> | |||
| #include <vector> | |||
| #include "minddata/dataset/core/data_type.h" | |||
| #include "minddata/dataset/kernels/ir/tensor_operation.h" | |||
| namespace mindspore { | |||
| @@ -214,7 +215,8 @@ class SliceOperation : public TensorOperation { | |||
| class TypeCastOperation : public TensorOperation { | |||
| public: | |||
| explicit TypeCastOperation(std::string data_type); | |||
| explicit TypeCastOperation(DataType data_type); // Used for C++ API | |||
| explicit TypeCastOperation(std::string data_type); // Used for Pybind | |||
| ~TypeCastOperation() = default; | |||
| @@ -227,7 +229,7 @@ class TypeCastOperation : public TensorOperation { | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| private: | |||
| std::string data_type_; | |||
| DataType data_type_; | |||
| }; | |||
| #ifndef ENABLE_ANDROID | |||
| @@ -75,7 +75,8 @@ TEST_F(MindDataTestPipeline, TestSaveCifar10AndLoad) { | |||
| // Create objects for the tensor ops | |||
| // uint32 will be casted to int64 implicitly in mindrecord file, so we have to cast it back to uint32 | |||
| std::shared_ptr<TensorTransform> type_cast = std::make_shared<transforms::TypeCast>("uint32"); | |||
| std::shared_ptr<TensorTransform> type_cast = | |||
| std::make_shared<transforms::TypeCast>(mindspore::DataType::kNumberTypeUInt32); | |||
| EXPECT_NE(type_cast, nullptr); | |||
| // Create a Map operation on ds | |||
| @@ -825,7 +825,8 @@ TEST_F(MindDataTestPipeline, TestTypeCastSuccess) { | |||
| iter->Stop(); | |||
| // Create objects for the tensor ops | |||
| std::shared_ptr<TensorTransform> type_cast = std::make_shared<transforms::TypeCast>("uint16"); | |||
| std::shared_ptr<TensorTransform> type_cast = | |||
| std::make_shared<transforms::TypeCast>(mindspore::DataType::kNumberTypeUInt16); | |||
| // Create a Map operation on ds | |||
| std::shared_ptr<Dataset> ds2 = ds->Map({type_cast}, {"image"}); | |||
| @@ -848,7 +849,7 @@ TEST_F(MindDataTestPipeline, TestTypeCastSuccess) { | |||
| } | |||
| TEST_F(MindDataTestPipeline, TestTypeCastFail) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTypeCastFail with invalid params."; | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTypeCastFail with invalid param."; | |||
| // Create a Cifar10 Dataset | |||
| std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; | |||
| @@ -856,7 +857,7 @@ TEST_F(MindDataTestPipeline, TestTypeCastFail) { | |||
| EXPECT_NE(ds, nullptr); | |||
| // incorrect data type | |||
| std::shared_ptr<TensorTransform> type_cast = std::make_shared<transforms::TypeCast>("char"); | |||
| std::shared_ptr<TensorTransform> type_cast = std::make_shared<transforms::TypeCast>(mindspore::DataType::kTypeUnknown); | |||
| // Create a Map operation on ds | |||
| ds = ds->Map({type_cast}, {"image", "label"}); | |||
| @@ -865,4 +866,4 @@ TEST_F(MindDataTestPipeline, TestTypeCastFail) { | |||
| std::shared_ptr<Iterator> iter = ds->CreateIterator(); | |||
| // Expect failure: invalid TypeCast input | |||
| EXPECT_EQ(iter, nullptr); | |||
| } | |||
| } | |||
| @@ -49,7 +49,7 @@ TEST_F(MindDataTestPipeline, TestRescaleSucess1) { | |||
| // Note: No need to check for output after calling API class constructor | |||
| // Convert to the same type | |||
| std::shared_ptr<TensorTransform> type_cast(new transforms::TypeCast("uint8")); | |||
| std::shared_ptr<TensorTransform> type_cast(new transforms::TypeCast(mindspore::DataType::kNumberTypeUInt8)); | |||
| // Note: No need to check for output after calling API class constructor | |||
| ds = ds->Map({rescale, type_cast}, {"image"}); | |||
| @@ -332,7 +332,7 @@ TEST_F(MindDataTestCallback, TestCAPICallback) { | |||
| ASSERT_OK(schema->add_column("label", mindspore::DataType::kNumberTypeUInt32, {})); | |||
| std::shared_ptr<Dataset> ds = RandomData(44, schema); | |||
| ASSERT_NE(ds, nullptr); | |||
| ds = ds->Map({std::make_shared<transforms::TypeCast>("uint64")}, {"label"}, {}, {}, nullptr, {cb1}); | |||
| ds = ds->Map({std::make_shared<transforms::TypeCast>(mindspore::DataType::kNumberTypeUInt64)}, {"label"}, {}, {}, nullptr, {cb1}); | |||
| ASSERT_NE(ds, nullptr); | |||
| ds = ds->Repeat(2); | |||
| ASSERT_NE(ds, nullptr); | |||