| @@ -31,6 +31,7 @@ | |||
| namespace mindspore { | |||
| const char kShapeSeperator[] = ","; | |||
| const char kShapeScalar[] = "[0]"; | |||
| static std::map<std::string, TypeId> print_type_map = { | |||
| {"int8_t", TypeId::kNumberTypeInt8}, {"uint8_t", TypeId::kNumberTypeUInt8}, | |||
| {"int16_t", TypeId::kNumberTypeInt16}, {"uint16_t", TypeId::kNumberTypeUInt16}, | |||
| @@ -81,6 +82,73 @@ bool PrintTensorToString(const char *str_data_ptr, mindspore::tensor::Tensor *co | |||
| return true; | |||
| } | |||
| template <typename T> | |||
| void PrintScalarToString(const char *str_data_ptr, const string &tensor_type) { | |||
| const T *data_ptr = reinterpret_cast<const T *>(str_data_ptr); | |||
| std::ostringstream buf_scalar; | |||
| buf_scalar << "Tensor shape :1 " << tensor_type; | |||
| buf_scalar << "\nval:"; | |||
| buf_scalar << *data_ptr; | |||
| std::cout << buf_scalar.str() << std::endl; | |||
| } | |||
| void PrintScalarToBoolString(const char *str_data_ptr, const string &tensor_type) { | |||
| const bool *data_ptr = reinterpret_cast<const bool *>(str_data_ptr); | |||
| std::ostringstream buf_scalar; | |||
| buf_scalar << "Tensor shape :1 " << tensor_type; | |||
| buf_scalar << "\nval:"; | |||
| if (*data_ptr == true) { | |||
| buf_scalar << "True"; | |||
| } else { | |||
| buf_scalar << "False"; | |||
| } | |||
| std::cout << buf_scalar.str() << std::endl; | |||
| } | |||
| void convertDataItem2Scalar(const char *str_data_ptr, const string &tensor_type) { | |||
| auto type_iter = print_type_map.find(tensor_type); | |||
| auto type_id = type_iter->second; | |||
| if (type_id == TypeId::kNumberTypeBool) { | |||
| PrintScalarToBoolString(str_data_ptr, tensor_type); | |||
| } else if (type_id == TypeId::kNumberTypeInt8) { | |||
| PrintScalarToString<int8_t>(str_data_ptr, tensor_type); | |||
| } else if (type_id == TypeId::kNumberTypeUInt8) { | |||
| PrintScalarToString<uint8_t>(str_data_ptr, tensor_type); | |||
| } else if (type_id == TypeId::kNumberTypeInt16) { | |||
| PrintScalarToString<int16_t>(str_data_ptr, tensor_type); | |||
| } else if (type_id == TypeId::kNumberTypeUInt16) { | |||
| PrintScalarToString<uint16_t>(str_data_ptr, tensor_type); | |||
| } else if (type_id == TypeId::kNumberTypeInt32) { | |||
| PrintScalarToString<int32_t>(str_data_ptr, tensor_type); | |||
| } else if (type_id == TypeId::kNumberTypeUInt32) { | |||
| PrintScalarToString<uint32_t>(str_data_ptr, tensor_type); | |||
| } else if (type_id == TypeId::kNumberTypeInt64) { | |||
| PrintScalarToString<int64_t>(str_data_ptr, tensor_type); | |||
| } else if (type_id == TypeId::kNumberTypeUInt64) { | |||
| PrintScalarToString<uint64_t>(str_data_ptr, tensor_type); | |||
| } else if (type_id == TypeId::kNumberTypeFloat16) { | |||
| PrintScalarToString<float16>(str_data_ptr, tensor_type); | |||
| } else if (type_id == TypeId::kNumberTypeFloat32) { | |||
| PrintScalarToString<float>(str_data_ptr, tensor_type); | |||
| } else if (type_id == TypeId::kNumberTypeFloat64) { | |||
| PrintScalarToString<double>(str_data_ptr, tensor_type); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Cannot print scalar because of unsupport data type: " << tensor_type << "."; | |||
| } | |||
| } // namespace mindspore | |||
| bool judgeLengthValid(const size_t str_len, const string &tensor_type) { | |||
| auto type_iter = type_size_map.find(tensor_type); | |||
| if (type_iter == type_size_map.end()) { | |||
| MS_LOG(EXCEPTION) << "type of scalar to print is not support."; | |||
| } | |||
| if (str_len != type_iter->second) { | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| #ifndef NO_DLIB | |||
| bool ConvertDataItem2Tensor(const std::vector<tdt::DataItem> &items) { | |||
| // Acquire Python GIL | |||
| @@ -92,14 +160,22 @@ bool ConvertDataItem2Tensor(const std::vector<tdt::DataItem> &items) { | |||
| ret_end_sequence = true; | |||
| break; | |||
| } | |||
| std::shared_ptr<std::string> str_data_ptr = std::static_pointer_cast<std::string>(item.dataPtr_); | |||
| MS_EXCEPTION_IF_NULL(str_data_ptr); | |||
| if (item.tensorShape_ == kShapeScalar) { | |||
| if (!judgeLengthValid(str_data_ptr->size(), item.tensorType_)) { | |||
| MS_LOG(EXCEPTION) << "Print op receive data length is invalid."; | |||
| } | |||
| convertDataItem2Scalar(str_data_ptr->data(), item.tensorType_); | |||
| continue; | |||
| } | |||
| std::vector<int> tensor_shape; | |||
| size_t totaldims = 1; | |||
| if (!ParseTensorShape(item.tensorShape_, &tensor_shape, &totaldims)) { | |||
| MS_LOG(ERROR) << "Tensor print can not parse tensor shape, receive info" << item.tensorShape_; | |||
| continue; | |||
| } | |||
| std::shared_ptr<std::string> str_data_ptr = std::static_pointer_cast<std::string>(item.dataPtr_); | |||
| MS_EXCEPTION_IF_NULL(str_data_ptr); | |||
| if (item.tensorType_ == "string") { | |||
| std::string data(reinterpret_cast<const char *>(str_data_ptr->c_str()), item.dataLen_); | |||