Browse Source

!68 add support print scalar

Merge pull request !68 from yankai10/fix_bug
tags/v0.3.0-alpha
mindspore-ci-bot Gitee 5 years ago
parent
commit
e4f300d8a1
1 changed files with 78 additions and 2 deletions
  1. +78
    -2
      mindspore/ccsrc/utils/tensorprint_utils.cc

+ 78
- 2
mindspore/ccsrc/utils/tensorprint_utils.cc View File

@@ -31,6 +31,7 @@


namespace mindspore { namespace mindspore {
const char kShapeSeperator[] = ","; const char kShapeSeperator[] = ",";
const char kShapeScalar[] = "[0]";
static std::map<std::string, TypeId> print_type_map = { static std::map<std::string, TypeId> print_type_map = {
{"int8_t", TypeId::kNumberTypeInt8}, {"uint8_t", TypeId::kNumberTypeUInt8}, {"int8_t", TypeId::kNumberTypeInt8}, {"uint8_t", TypeId::kNumberTypeUInt8},
{"int16_t", TypeId::kNumberTypeInt16}, {"uint16_t", TypeId::kNumberTypeUInt16}, {"int16_t", TypeId::kNumberTypeInt16}, {"uint16_t", TypeId::kNumberTypeUInt16},
@@ -81,6 +82,73 @@ bool PrintTensorToString(const char *str_data_ptr, mindspore::tensor::Tensor *co
return true; return true;
} }


template <typename T>
void PrintScalarToString(const char *str_data_ptr, const string &tensor_type) {
const T *data_ptr = reinterpret_cast<const T *>(str_data_ptr);
std::ostringstream buf_scalar;
buf_scalar << "Tensor shape :1 " << tensor_type;
buf_scalar << "\nval:";
buf_scalar << *data_ptr;
std::cout << buf_scalar.str() << std::endl;
}

void PrintScalarToBoolString(const char *str_data_ptr, const string &tensor_type) {
const bool *data_ptr = reinterpret_cast<const bool *>(str_data_ptr);
std::ostringstream buf_scalar;
buf_scalar << "Tensor shape :1 " << tensor_type;
buf_scalar << "\nval:";
if (*data_ptr == true) {
buf_scalar << "True";
} else {
buf_scalar << "False";
}
std::cout << buf_scalar.str() << std::endl;
}

void convertDataItem2Scalar(const char *str_data_ptr, const string &tensor_type) {
auto type_iter = print_type_map.find(tensor_type);
auto type_id = type_iter->second;
if (type_id == TypeId::kNumberTypeBool) {
PrintScalarToBoolString(str_data_ptr, tensor_type);
} else if (type_id == TypeId::kNumberTypeInt8) {
PrintScalarToString<int8_t>(str_data_ptr, tensor_type);
} else if (type_id == TypeId::kNumberTypeUInt8) {
PrintScalarToString<uint8_t>(str_data_ptr, tensor_type);
} else if (type_id == TypeId::kNumberTypeInt16) {
PrintScalarToString<int16_t>(str_data_ptr, tensor_type);
} else if (type_id == TypeId::kNumberTypeUInt16) {
PrintScalarToString<uint16_t>(str_data_ptr, tensor_type);
} else if (type_id == TypeId::kNumberTypeInt32) {
PrintScalarToString<int32_t>(str_data_ptr, tensor_type);
} else if (type_id == TypeId::kNumberTypeUInt32) {
PrintScalarToString<uint32_t>(str_data_ptr, tensor_type);
} else if (type_id == TypeId::kNumberTypeInt64) {
PrintScalarToString<int64_t>(str_data_ptr, tensor_type);
} else if (type_id == TypeId::kNumberTypeUInt64) {
PrintScalarToString<uint64_t>(str_data_ptr, tensor_type);
} else if (type_id == TypeId::kNumberTypeFloat16) {
PrintScalarToString<float16>(str_data_ptr, tensor_type);
} else if (type_id == TypeId::kNumberTypeFloat32) {
PrintScalarToString<float>(str_data_ptr, tensor_type);
} else if (type_id == TypeId::kNumberTypeFloat64) {
PrintScalarToString<double>(str_data_ptr, tensor_type);
} else {
MS_LOG(EXCEPTION) << "Cannot print scalar because of unsupport data type: " << tensor_type << ".";
}
} // namespace mindspore

bool judgeLengthValid(const size_t str_len, const string &tensor_type) {
auto type_iter = type_size_map.find(tensor_type);
if (type_iter == type_size_map.end()) {
MS_LOG(EXCEPTION) << "type of scalar to print is not support.";
}

if (str_len != type_iter->second) {
return false;
}
return true;
}

#ifndef NO_DLIB #ifndef NO_DLIB
bool ConvertDataItem2Tensor(const std::vector<tdt::DataItem> &items) { bool ConvertDataItem2Tensor(const std::vector<tdt::DataItem> &items) {
// Acquire Python GIL // Acquire Python GIL
@@ -92,14 +160,22 @@ bool ConvertDataItem2Tensor(const std::vector<tdt::DataItem> &items) {
ret_end_sequence = true; ret_end_sequence = true;
break; break;
} }
std::shared_ptr<std::string> str_data_ptr = std::static_pointer_cast<std::string>(item.dataPtr_);
MS_EXCEPTION_IF_NULL(str_data_ptr);
if (item.tensorShape_ == kShapeScalar) {
if (!judgeLengthValid(str_data_ptr->size(), item.tensorType_)) {
MS_LOG(EXCEPTION) << "Print op receive data length is invalid.";
}
convertDataItem2Scalar(str_data_ptr->data(), item.tensorType_);
continue;
}

std::vector<int> tensor_shape; std::vector<int> tensor_shape;
size_t totaldims = 1; size_t totaldims = 1;
if (!ParseTensorShape(item.tensorShape_, &tensor_shape, &totaldims)) { if (!ParseTensorShape(item.tensorShape_, &tensor_shape, &totaldims)) {
MS_LOG(ERROR) << "Tensor print can not parse tensor shape, receive info" << item.tensorShape_; MS_LOG(ERROR) << "Tensor print can not parse tensor shape, receive info" << item.tensorShape_;
continue; continue;
} }
std::shared_ptr<std::string> str_data_ptr = std::static_pointer_cast<std::string>(item.dataPtr_);
MS_EXCEPTION_IF_NULL(str_data_ptr);


if (item.tensorType_ == "string") { if (item.tensorType_ == "string") {
std::string data(reinterpret_cast<const char *>(str_data_ptr->c_str()), item.dataLen_); std::string data(reinterpret_cast<const char *>(str_data_ptr->c_str()), item.dataLen_);


Loading…
Cancel
Save