|
|
|
@@ -15,6 +15,7 @@ |
|
|
|
*/ |
|
|
|
|
|
|
|
#include <algorithm> |
|
|
|
#include <typeinfo> |
|
|
|
|
|
|
|
#include "minddata/dataset/kernels/ir/data/transforms_ir.h" |
|
|
|
|
|
|
|
@@ -213,19 +214,22 @@ std::shared_ptr<TensorOp> SliceOperation::Build() { return std::make_shared<Slic |
|
|
|
#endif |
|
|
|
|
|
|
|
// TypeCastOperation |
|
|
|
TypeCastOperation::TypeCastOperation(std::string data_type) : data_type_(data_type) {} |
|
|
|
// DataType data_type - required for C++ API |
|
|
|
TypeCastOperation::TypeCastOperation(DataType data_type) : data_type_(data_type) {} |
|
|
|
|
|
|
|
// std::string data_type - required for Pybind |
|
|
|
TypeCastOperation::TypeCastOperation(std::string data_type) { |
|
|
|
// Convert from string to DEType |
|
|
|
DataType temp_data_type(data_type); |
|
|
|
data_type_ = temp_data_type; |
|
|
|
} |
|
|
|
|
|
|
|
Status TypeCastOperation::ValidateParams() { |
|
|
|
std::vector<std::string> predefine_type = {"bool", "int8", "uint8", "int16", "uint16", "int32", "uint32", |
|
|
|
"int64", "uint64", "float16", "float32", "float64", "string"}; |
|
|
|
auto itr = std::find(predefine_type.begin(), predefine_type.end(), data_type_); |
|
|
|
if (itr == predefine_type.end()) { |
|
|
|
std::string err_msg = "TypeCast: Invalid data type: " + data_type_; |
|
|
|
MS_LOG(ERROR) << "TypeCast: Only supports data type bool, int8, uint8, int16, uint16, int32, uint32, " |
|
|
|
<< "int64, uint64, float16, float32, float64, string, but got: " << data_type_; |
|
|
|
if (data_type_ == DataType::DE_UNKNOWN) { |
|
|
|
std::string err_msg = "TypeCast: Invalid data type"; |
|
|
|
MS_LOG(ERROR) << err_msg; |
|
|
|
RETURN_STATUS_SYNTAX_ERROR(err_msg); |
|
|
|
} |
|
|
|
|
|
|
|
return Status::OK(); |
|
|
|
} |
|
|
|
|
|
|
|
@@ -233,7 +237,7 @@ std::shared_ptr<TensorOp> TypeCastOperation::Build() { return std::make_shared<T |
|
|
|
|
|
|
|
Status TypeCastOperation::to_json(nlohmann::json *out_json) { |
|
|
|
nlohmann::json args; |
|
|
|
args["data_type"] = data_type_; |
|
|
|
args["data_type"] = data_type_.ToString(); |
|
|
|
*out_json = args; |
|
|
|
return Status::OK(); |
|
|
|
} |
|
|
|
|