|
|
|
@@ -213,7 +213,7 @@ class TensorDataImpl : public TensorData { |
|
|
|
std::equal(data_.get(), data_.get() + data_size_, ptr->data_.get()); |
|
|
|
} |
|
|
|
|
|
|
|
std::string ToString(const TypeId type, const ShapeVector &shape) const override { |
|
|
|
std::string ToString(const TypeId type, const ShapeVector &shape, bool use_comma) const override { |
|
|
|
constexpr auto valid = |
|
|
|
std::is_same<T, bool>::value || std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value || |
|
|
|
std::is_same<T, int16_t>::value || std::is_same<T, int32_t>::value || std::is_same<T, int64_t>::value || |
|
|
|
@@ -229,16 +229,16 @@ class TensorDataImpl : public TensorData { |
|
|
|
|
|
|
|
std::ostringstream ss; |
|
|
|
if (data_size_ == 1 && ndim_ == 0) { // Scalar |
|
|
|
OutputDataString(ss, 0, 0, 1); |
|
|
|
OutputDataString(ss, 0, 0, 1, false); |
|
|
|
return ss.str(); |
|
|
|
} |
|
|
|
ssize_t cursor = 0; |
|
|
|
SummaryStringRecursive(ss, shape, &cursor, 0); |
|
|
|
SummaryStringRecursive(ss, shape, &cursor, 0, use_comma); |
|
|
|
return ss.str(); |
|
|
|
} |
|
|
|
|
|
|
|
private: |
|
|
|
void OutputDataString(std::ostringstream &ss, ssize_t cursor, ssize_t start, ssize_t end) const { |
|
|
|
void OutputDataString(std::ostringstream &ss, ssize_t cursor, ssize_t start, ssize_t end, bool use_comma) const { |
|
|
|
const bool isScalar = ndim_ == 0 && end - start == 1; |
|
|
|
constexpr auto isFloat = |
|
|
|
std::is_same<T, float16>::value || std::is_same<T, float>::value || std::is_same<T, double>::value; |
|
|
|
@@ -265,33 +265,43 @@ class TensorDataImpl : public TensorData { |
|
|
|
ss << std::setw(5) << std::setiosflags(std::ios::right) << (value ? "True" : "False"); |
|
|
|
} |
|
|
|
} else { |
|
|
|
constexpr auto isSigned = std::is_same<T, int8_t>::value || std::is_same<T, int16_t>::value || |
|
|
|
std::is_same<T, int32_t>::value || std::is_same<T, int64_t>::value; |
|
|
|
constexpr auto isSigned = std::is_same<T, int64_t>::value; |
|
|
|
if constexpr (isSigned) { |
|
|
|
if (!isScalar && static_cast<int64_t>(value) >= 0) { |
|
|
|
ss << ' '; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// Set width and indent for different int type. |
|
|
|
// Set width and indent for different int type with signed position. |
|
|
|
// |
|
|
|
// int8/uint8 width: 3 |
|
|
|
// int16/uint16 width: 5 |
|
|
|
// int32/uint32 width: 10 |
|
|
|
// int64/uint64 width: NOT SET |
|
|
|
if constexpr (std::is_same<T, int8_t>::value) { |
|
|
|
ss << std::setw(3) << std::setiosflags(std::ios::right) << static_cast<int16_t>(value); |
|
|
|
} else if constexpr (std::is_same<T, uint8_t>::value) { |
|
|
|
// uint8 width: 3, [0, 255] |
|
|
|
// int8 width: 4, [-128, 127] |
|
|
|
// uint16 width: 5, [0, 65535] |
|
|
|
// int16 width: 6, [-32768, 32767] |
|
|
|
// uint32 width: 10, [0, 4294967295] |
|
|
|
// int32 width: 11, [-2147483648, 2147483647] |
|
|
|
// uint64 width: NOT SET (20, [0, 18446744073709551615]) |
|
|
|
// int64 width: NOT SET (20, [-9223372036854775808, 9223372036854775807]) |
|
|
|
if constexpr (std::is_same<T, uint8_t>::value) { |
|
|
|
ss << std::setw(3) << std::setiosflags(std::ios::right) << static_cast<uint16_t>(value); |
|
|
|
} else if constexpr (std::is_same<T, int16_t>::value || std::is_same<T, uint16_t>::value) { |
|
|
|
} else if constexpr (std::is_same<T, int8_t>::value) { |
|
|
|
ss << std::setw(4) << std::setiosflags(std::ios::right) << static_cast<int16_t>(value); |
|
|
|
} else if constexpr (std::is_same<T, uint16_t>::value) { |
|
|
|
ss << std::setw(5) << std::setiosflags(std::ios::right) << value; |
|
|
|
} else if constexpr (std::is_same<T, int32_t>::value || std::is_same<T, uint32_t>::value) { |
|
|
|
} else if constexpr (std::is_same<T, int16_t>::value) { |
|
|
|
ss << std::setw(6) << std::setiosflags(std::ios::right) << value; |
|
|
|
} else if constexpr (std::is_same<T, uint32_t>::value) { |
|
|
|
ss << std::setw(10) << std::setiosflags(std::ios::right) << value; |
|
|
|
} else if constexpr (std::is_same<T, int32_t>::value) { |
|
|
|
ss << std::setw(11) << std::setiosflags(std::ios::right) << value; |
|
|
|
} else { |
|
|
|
ss << value; |
|
|
|
} |
|
|
|
} |
|
|
|
if (!isScalar && i != end - 1) { |
|
|
|
if (use_comma) { |
|
|
|
ss << ','; |
|
|
|
} |
|
|
|
ss << ' '; |
|
|
|
} |
|
|
|
if (!isScalar && ndim_ == 1 && (i + 1) % linefeedThreshold == 0) { |
|
|
|
@@ -301,7 +311,8 @@ class TensorDataImpl : public TensorData { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void SummaryStringRecursive(std::ostringstream &ss, const ShapeVector &shape, ssize_t *cursor, ssize_t depth) const { |
|
|
|
void SummaryStringRecursive(std::ostringstream &ss, const ShapeVector &shape, ssize_t *cursor, ssize_t depth, |
|
|
|
bool use_comma) const { |
|
|
|
if (depth >= static_cast<ssize_t>(ndim_)) { |
|
|
|
return; |
|
|
|
} |
|
|
|
@@ -309,11 +320,11 @@ class TensorDataImpl : public TensorData { |
|
|
|
if (depth == static_cast<ssize_t>(ndim_) - 1) { // Bottom dimension |
|
|
|
ssize_t num = shape[depth]; |
|
|
|
if (num > kThreshold && ndim_ > 1) { |
|
|
|
OutputDataString(ss, *cursor, 0, kThreshold / 2); |
|
|
|
OutputDataString(ss, *cursor, 0, kThreshold / 2, use_comma); |
|
|
|
ss << ' ' << kEllipsis << ' '; |
|
|
|
OutputDataString(ss, *cursor, num - kThreshold / 2, num); |
|
|
|
OutputDataString(ss, *cursor, num - kThreshold / 2, num, use_comma); |
|
|
|
} else { |
|
|
|
OutputDataString(ss, *cursor, 0, num); |
|
|
|
OutputDataString(ss, *cursor, 0, num, use_comma); |
|
|
|
} |
|
|
|
*cursor += num; |
|
|
|
} else { // Middle dimension |
|
|
|
@@ -321,13 +332,19 @@ class TensorDataImpl : public TensorData { |
|
|
|
// Handle the first half. |
|
|
|
for (ssize_t i = 0; i < std::min(static_cast<ssize_t>(kThreshold / 2), num); i++) { |
|
|
|
if (i > 0) { |
|
|
|
if (use_comma) { |
|
|
|
ss << ','; |
|
|
|
} |
|
|
|
ss << '\n'; |
|
|
|
ss << std::setw(depth + 1) << ' '; // Add the indent. |
|
|
|
} |
|
|
|
SummaryStringRecursive(ss, shape, cursor, depth + 1); |
|
|
|
SummaryStringRecursive(ss, shape, cursor, depth + 1, use_comma); |
|
|
|
} |
|
|
|
// Handle the ignored part. |
|
|
|
if (num > kThreshold) { |
|
|
|
if (use_comma) { |
|
|
|
ss << ','; |
|
|
|
} |
|
|
|
ss << '\n'; |
|
|
|
ss << std::setw(depth + 1) << ' '; // Add the indent. |
|
|
|
ss << kEllipsis; |
|
|
|
@@ -343,10 +360,14 @@ class TensorDataImpl : public TensorData { |
|
|
|
} |
|
|
|
// Handle the second half. |
|
|
|
if (num > kThreshold / 2) { |
|
|
|
for (ssize_t i = num - kThreshold / 2; i < num; i++) { |
|
|
|
auto continue_pos = num - kThreshold / 2; |
|
|
|
for (ssize_t i = continue_pos; i < num; i++) { |
|
|
|
if (use_comma && i != continue_pos) { |
|
|
|
ss << ','; |
|
|
|
} |
|
|
|
ss << '\n'; |
|
|
|
ss << std::setw(depth + 1) << ' '; // Add the indent. |
|
|
|
SummaryStringRecursive(ss, shape, cursor, depth + 1); |
|
|
|
SummaryStringRecursive(ss, shape, cursor, depth + 1, use_comma); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -487,29 +508,35 @@ std::string Tensor::GetShapeAndDataTypeInfo() const { |
|
|
|
return buf.str(); |
|
|
|
} |
|
|
|
|
|
|
|
std::string Tensor::ToString() const { |
|
|
|
constexpr int small_tensor_size = 30; |
|
|
|
std::string Tensor::ToStringInternal(int limit_size) const { |
|
|
|
std::ostringstream buf; |
|
|
|
auto dtype = Dtype(); |
|
|
|
MS_EXCEPTION_IF_NULL(dtype); |
|
|
|
data_sync(); |
|
|
|
buf << "Tensor(shape=" << ShapeToString(shape_) << ", dtype=" << dtype->ToString() << ','; |
|
|
|
if (DataSize() < small_tensor_size) { |
|
|
|
if (limit_size <= 0 || DataSize() < limit_size) { |
|
|
|
// Only print data for small tensor. |
|
|
|
buf << ((data().ndim() > 1) ? '\n' : ' ') << data().ToString(data_type_, shape_) << ')'; |
|
|
|
buf << ((data().ndim() > 1) ? '\n' : ' ') << data().ToString(data_type_, shape_, false) << ')'; |
|
|
|
} else { |
|
|
|
buf << " [...])"; |
|
|
|
} |
|
|
|
return buf.str(); |
|
|
|
} |
|
|
|
|
|
|
|
std::string Tensor::ToString() const { |
|
|
|
constexpr int small_tensor_size = 30; |
|
|
|
return ToStringInternal(small_tensor_size); |
|
|
|
} |
|
|
|
|
|
|
|
std::string Tensor::ToStringNoLimit() const { return ToStringInternal(0); } |
|
|
|
|
|
|
|
std::string Tensor::ToStringRepr() const { |
|
|
|
std::ostringstream buf; |
|
|
|
auto dtype = Dtype(); |
|
|
|
MS_EXCEPTION_IF_NULL(dtype); |
|
|
|
data_sync(); |
|
|
|
buf << "Tensor(shape=" << ShapeToString(shape_) << ", dtype=" << dtype->ToString() << ',' |
|
|
|
<< ((data().ndim() > 1) ? '\n' : ' ') << data().ToString(data_type_, shape_) << ')'; |
|
|
|
<< ((data().ndim() > 1) ? '\n' : ' ') << data().ToString(data_type_, shape_, true) << ')'; |
|
|
|
return buf.str(); |
|
|
|
} |
|
|
|
|
|
|
|
|