|
|
@@ -31,6 +31,7 @@ |
|
|
|
|
|
|
|
|
namespace mindspore { |
|
|
namespace mindspore { |
|
|
const char kShapeSeperator[] = ","; |
|
|
const char kShapeSeperator[] = ","; |
|
|
|
|
|
const char kShapeScalar[] = "[0]"; |
|
|
static std::map<std::string, TypeId> print_type_map = { |
|
|
static std::map<std::string, TypeId> print_type_map = { |
|
|
{"int8_t", TypeId::kNumberTypeInt8}, {"uint8_t", TypeId::kNumberTypeUInt8}, |
|
|
{"int8_t", TypeId::kNumberTypeInt8}, {"uint8_t", TypeId::kNumberTypeUInt8}, |
|
|
{"int16_t", TypeId::kNumberTypeInt16}, {"uint16_t", TypeId::kNumberTypeUInt16}, |
|
|
{"int16_t", TypeId::kNumberTypeInt16}, {"uint16_t", TypeId::kNumberTypeUInt16}, |
|
|
@@ -81,6 +82,73 @@ bool PrintTensorToString(const char *str_data_ptr, mindspore::tensor::Tensor *co |
|
|
return true; |
|
|
return true; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
template <typename T> |
|
|
|
|
|
void PrintScalarToString(const char *str_data_ptr, const string &tensor_type) { |
|
|
|
|
|
const T *data_ptr = reinterpret_cast<const T *>(str_data_ptr); |
|
|
|
|
|
std::ostringstream buf_scalar; |
|
|
|
|
|
buf_scalar << "Tensor shape :1 " << tensor_type; |
|
|
|
|
|
buf_scalar << "\nval:"; |
|
|
|
|
|
buf_scalar << *data_ptr; |
|
|
|
|
|
std::cout << buf_scalar.str() << std::endl; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
void PrintScalarToBoolString(const char *str_data_ptr, const string &tensor_type) { |
|
|
|
|
|
const bool *data_ptr = reinterpret_cast<const bool *>(str_data_ptr); |
|
|
|
|
|
std::ostringstream buf_scalar; |
|
|
|
|
|
buf_scalar << "Tensor shape :1 " << tensor_type; |
|
|
|
|
|
buf_scalar << "\nval:"; |
|
|
|
|
|
if (*data_ptr == true) { |
|
|
|
|
|
buf_scalar << "True"; |
|
|
|
|
|
} else { |
|
|
|
|
|
buf_scalar << "False"; |
|
|
|
|
|
} |
|
|
|
|
|
std::cout << buf_scalar.str() << std::endl; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
void convertDataItem2Scalar(const char *str_data_ptr, const string &tensor_type) { |
|
|
|
|
|
auto type_iter = print_type_map.find(tensor_type); |
|
|
|
|
|
auto type_id = type_iter->second; |
|
|
|
|
|
if (type_id == TypeId::kNumberTypeBool) { |
|
|
|
|
|
PrintScalarToBoolString(str_data_ptr, tensor_type); |
|
|
|
|
|
} else if (type_id == TypeId::kNumberTypeInt8) { |
|
|
|
|
|
PrintScalarToString<int8_t>(str_data_ptr, tensor_type); |
|
|
|
|
|
} else if (type_id == TypeId::kNumberTypeUInt8) { |
|
|
|
|
|
PrintScalarToString<uint8_t>(str_data_ptr, tensor_type); |
|
|
|
|
|
} else if (type_id == TypeId::kNumberTypeInt16) { |
|
|
|
|
|
PrintScalarToString<int16_t>(str_data_ptr, tensor_type); |
|
|
|
|
|
} else if (type_id == TypeId::kNumberTypeUInt16) { |
|
|
|
|
|
PrintScalarToString<uint16_t>(str_data_ptr, tensor_type); |
|
|
|
|
|
} else if (type_id == TypeId::kNumberTypeInt32) { |
|
|
|
|
|
PrintScalarToString<int32_t>(str_data_ptr, tensor_type); |
|
|
|
|
|
} else if (type_id == TypeId::kNumberTypeUInt32) { |
|
|
|
|
|
PrintScalarToString<uint32_t>(str_data_ptr, tensor_type); |
|
|
|
|
|
} else if (type_id == TypeId::kNumberTypeInt64) { |
|
|
|
|
|
PrintScalarToString<int64_t>(str_data_ptr, tensor_type); |
|
|
|
|
|
} else if (type_id == TypeId::kNumberTypeUInt64) { |
|
|
|
|
|
PrintScalarToString<uint64_t>(str_data_ptr, tensor_type); |
|
|
|
|
|
} else if (type_id == TypeId::kNumberTypeFloat16) { |
|
|
|
|
|
PrintScalarToString<float16>(str_data_ptr, tensor_type); |
|
|
|
|
|
} else if (type_id == TypeId::kNumberTypeFloat32) { |
|
|
|
|
|
PrintScalarToString<float>(str_data_ptr, tensor_type); |
|
|
|
|
|
} else if (type_id == TypeId::kNumberTypeFloat64) { |
|
|
|
|
|
PrintScalarToString<double>(str_data_ptr, tensor_type); |
|
|
|
|
|
} else { |
|
|
|
|
|
MS_LOG(EXCEPTION) << "Cannot print scalar because of unsupport data type: " << tensor_type << "."; |
|
|
|
|
|
} |
|
|
|
|
|
} // namespace mindspore |
|
|
|
|
|
|
|
|
|
|
|
bool judgeLengthValid(const size_t str_len, const string &tensor_type) { |
|
|
|
|
|
auto type_iter = type_size_map.find(tensor_type); |
|
|
|
|
|
if (type_iter == type_size_map.end()) { |
|
|
|
|
|
MS_LOG(EXCEPTION) << "type of scalar to print is not support."; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
if (str_len != type_iter->second) { |
|
|
|
|
|
return false; |
|
|
|
|
|
} |
|
|
|
|
|
return true; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
#ifndef NO_DLIB |
|
|
#ifndef NO_DLIB |
|
|
bool ConvertDataItem2Tensor(const std::vector<tdt::DataItem> &items) { |
|
|
bool ConvertDataItem2Tensor(const std::vector<tdt::DataItem> &items) { |
|
|
// Acquire Python GIL |
|
|
// Acquire Python GIL |
|
|
@@ -92,14 +160,22 @@ bool ConvertDataItem2Tensor(const std::vector<tdt::DataItem> &items) { |
|
|
ret_end_sequence = true; |
|
|
ret_end_sequence = true; |
|
|
break; |
|
|
break; |
|
|
} |
|
|
} |
|
|
|
|
|
std::shared_ptr<std::string> str_data_ptr = std::static_pointer_cast<std::string>(item.dataPtr_); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(str_data_ptr); |
|
|
|
|
|
if (item.tensorShape_ == kShapeScalar) { |
|
|
|
|
|
if (!judgeLengthValid(str_data_ptr->size(), item.tensorType_)) { |
|
|
|
|
|
MS_LOG(EXCEPTION) << "Print op receive data length is invalid."; |
|
|
|
|
|
} |
|
|
|
|
|
convertDataItem2Scalar(str_data_ptr->data(), item.tensorType_); |
|
|
|
|
|
continue; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
std::vector<int> tensor_shape; |
|
|
std::vector<int> tensor_shape; |
|
|
size_t totaldims = 1; |
|
|
size_t totaldims = 1; |
|
|
if (!ParseTensorShape(item.tensorShape_, &tensor_shape, &totaldims)) { |
|
|
if (!ParseTensorShape(item.tensorShape_, &tensor_shape, &totaldims)) { |
|
|
MS_LOG(ERROR) << "Tensor print can not parse tensor shape, receive info" << item.tensorShape_; |
|
|
MS_LOG(ERROR) << "Tensor print can not parse tensor shape, receive info" << item.tensorShape_; |
|
|
continue; |
|
|
continue; |
|
|
} |
|
|
} |
|
|
std::shared_ptr<std::string> str_data_ptr = std::static_pointer_cast<std::string>(item.dataPtr_); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(str_data_ptr); |
|
|
|
|
|
|
|
|
|
|
|
if (item.tensorType_ == "string") { |
|
|
if (item.tensorType_ == "string") { |
|
|
std::string data(reinterpret_cast<const char *>(str_data_ptr->c_str()), item.dataLen_); |
|
|
std::string data(reinterpret_cast<const char *>(str_data_ptr->c_str()), item.dataLen_); |
|
|
|