Browse Source

!9990 Fix scalar tensor print

From: @huangbingjian
Reviewed-by: @zhunaipan,@ginfung,@zh_qh
Signed-off-by: @zh_qh
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
f8aada52be
2 changed files with 3 additions and 3 deletions
  1. +2
    -2
      mindspore/ccsrc/utils/tensorprint_utils.cc
  2. +1
    -1
      tests/st/ops/ascend/test_tensor_print/test_tensor_print.py

+ 2
- 2
mindspore/ccsrc/utils/tensorprint_utils.cc View File

@@ -103,7 +103,7 @@ template <typename T>
void PrintScalarToString(const char *str_data_ptr, const string &tensor_type, std::ostringstream *const buf) {
MS_EXCEPTION_IF_NULL(str_data_ptr);
MS_EXCEPTION_IF_NULL(buf);
*buf << "Tensor(shape=[1], dtype=" << GetParseType(tensor_type) << ", value=";
*buf << "Tensor(shape=[], dtype=" << GetParseType(tensor_type) << ", value=";
const T *data_ptr = reinterpret_cast<const T *>(str_data_ptr);
if constexpr (std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value) {
const int int_data = static_cast<int>(*data_ptr);
@@ -117,7 +117,7 @@ void PrintScalarToBoolString(const char *str_data_ptr, const string &tensor_type
MS_EXCEPTION_IF_NULL(str_data_ptr);
MS_EXCEPTION_IF_NULL(buf);
const bool *data_ptr = reinterpret_cast<const bool *>(str_data_ptr);
*buf << "Tensor(shape=[1], dtype=" << GetParseType(tensor_type) << ", value=";
*buf << "Tensor(shape=[], dtype=" << GetParseType(tensor_type) << ", value=";
if (*data_ptr) {
*buf << "True)\n";
} else {


+ 1
- 1
tests/st/ops/ascend/test_tensor_print/test_tensor_print.py View File

@@ -25,7 +25,7 @@ expect_array = {'Bool': '\n[[ True False]\n [False True]]', 'UInt': '\n[[1 2 3]
'[ *.********e*** **.********e*** *.********e***]]'}

def get_expect_value(res):
if res[0] == '[1]':
if res[0] == '[]':
if res[1] == 'Bool':
return expect_scalar['Bool']
if res[1] in ['Uint8', 'Uint16', 'Uint32', 'Uint64']:


Loading…
Cancel
Save