From a837e3cb7d20c27f45e0b55e426a86e9eda86e82 Mon Sep 17 00:00:00 2001 From: yankai Date: Sun, 26 Apr 2020 14:50:45 +0800 Subject: [PATCH] support print scalar --- mindspore/ccsrc/utils/tensorprint_utils.cc | 80 +++++++++++++++++++++- 1 file changed, 78 insertions(+), 2 deletions(-) diff --git a/mindspore/ccsrc/utils/tensorprint_utils.cc b/mindspore/ccsrc/utils/tensorprint_utils.cc index 1036b424ba..ec1953c56d 100644 --- a/mindspore/ccsrc/utils/tensorprint_utils.cc +++ b/mindspore/ccsrc/utils/tensorprint_utils.cc @@ -31,6 +31,7 @@ namespace mindspore { const char kShapeSeperator[] = ","; +const char kShapeScalar[] = "[0]"; static std::map print_type_map = { {"int8_t", TypeId::kNumberTypeInt8}, {"uint8_t", TypeId::kNumberTypeUInt8}, {"int16_t", TypeId::kNumberTypeInt16}, {"uint16_t", TypeId::kNumberTypeUInt16}, @@ -81,6 +82,73 @@ bool PrintTensorToString(const char *str_data_ptr, mindspore::tensor::Tensor *co return true; } +template +void PrintScalarToString(const char *str_data_ptr, const string &tensor_type) { + const T *data_ptr = reinterpret_cast(str_data_ptr); + std::ostringstream buf_scalar; + buf_scalar << "Tensor shape :1 " << tensor_type; + buf_scalar << "\nval:"; + buf_scalar << *data_ptr; + std::cout << buf_scalar.str() << std::endl; +} + +void PrintScalarToBoolString(const char *str_data_ptr, const string &tensor_type) { + const bool *data_ptr = reinterpret_cast(str_data_ptr); + std::ostringstream buf_scalar; + buf_scalar << "Tensor shape :1 " << tensor_type; + buf_scalar << "\nval:"; + if (*data_ptr == true) { + buf_scalar << "True"; + } else { + buf_scalar << "False"; + } + std::cout << buf_scalar.str() << std::endl; +} + +void convertDataItem2Scalar(const char *str_data_ptr, const string &tensor_type) { + auto type_iter = print_type_map.find(tensor_type); + auto type_id = type_iter->second; + if (type_id == TypeId::kNumberTypeBool) { + PrintScalarToBoolString(str_data_ptr, tensor_type); + } else if (type_id == TypeId::kNumberTypeInt8) { + PrintScalarToString(str_data_ptr, tensor_type); + } else if (type_id == TypeId::kNumberTypeUInt8) { + PrintScalarToString(str_data_ptr, tensor_type); + } else if (type_id == TypeId::kNumberTypeInt16) { + PrintScalarToString(str_data_ptr, tensor_type); + } else if (type_id == TypeId::kNumberTypeUInt16) { + PrintScalarToString(str_data_ptr, tensor_type); + } else if (type_id == TypeId::kNumberTypeInt32) { + PrintScalarToString(str_data_ptr, tensor_type); + } else if (type_id == TypeId::kNumberTypeUInt32) { + PrintScalarToString(str_data_ptr, tensor_type); + } else if (type_id == TypeId::kNumberTypeInt64) { + PrintScalarToString(str_data_ptr, tensor_type); + } else if (type_id == TypeId::kNumberTypeUInt64) { + PrintScalarToString(str_data_ptr, tensor_type); + } else if (type_id == TypeId::kNumberTypeFloat16) { + PrintScalarToString(str_data_ptr, tensor_type); + } else if (type_id == TypeId::kNumberTypeFloat32) { + PrintScalarToString(str_data_ptr, tensor_type); + } else if (type_id == TypeId::kNumberTypeFloat64) { + PrintScalarToString(str_data_ptr, tensor_type); + } else { + MS_LOG(EXCEPTION) << "Cannot print scalar because of unsupport data type: " << tensor_type << "."; + } +} // namespace mindspore + +bool judgeLengthValid(const size_t str_len, const string &tensor_type) { + auto type_iter = type_size_map.find(tensor_type); + if (type_iter == type_size_map.end()) { + MS_LOG(EXCEPTION) << "type of scalar to print is not support."; + } + + if (str_len != type_iter->second) { + return false; + } + return true; +} + #ifndef NO_DLIB bool ConvertDataItem2Tensor(const std::vector &items) { // Acquire Python GIL @@ -92,14 +160,22 @@ bool ConvertDataItem2Tensor(const std::vector &items) { ret_end_sequence = true; break; } + std::shared_ptr str_data_ptr = std::static_pointer_cast(item.dataPtr_); + MS_EXCEPTION_IF_NULL(str_data_ptr); + if (item.tensorShape_ == kShapeScalar) { + if (!judgeLengthValid(str_data_ptr->size(), item.tensorType_)) { + MS_LOG(EXCEPTION) << "Print op receive data length is invalid."; + } + convertDataItem2Scalar(str_data_ptr->data(), item.tensorType_); + continue; + } + std::vector tensor_shape; size_t totaldims = 1; if (!ParseTensorShape(item.tensorShape_, &tensor_shape, &totaldims)) { MS_LOG(ERROR) << "Tensor print can not parse tensor shape, receive info" << item.tensorShape_; continue; } - std::shared_ptr str_data_ptr = std::static_pointer_cast(item.dataPtr_); - MS_EXCEPTION_IF_NULL(str_data_ptr); if (item.tensorType_ == "string") { std::string data(reinterpret_cast(str_data_ptr->c_str()), item.dataLen_);