
!4435 Refactor Tensor bool type and ToString()

Merge pull request !4435 from hewei/refactory_tensor_bool_tostring

Tag: tags/v0.7.0-beta
Committed by mindspore-ci-bot (Gitee), 5 years ago
Commit 9d55ac62c8
1 changed file with 47 additions and 29 deletions:

mindspore/core/ir/tensor.cc  (+47, -29)

@@ -54,6 +54,18 @@ static size_t SizeOf(const std::vector<int> &shape) {
   return std::accumulate(shape.begin(), shape.end(), size_t(1), std::multiplies<size_t>());
 }
 
+static std::string ShapeToString(const std::vector<int> &shape) {
+  std::string str = "[";
+  const size_t count = shape.size();
+  for (size_t i = 0; i < count; ++i) {
+    if (i > 0) {
+      str.append(", ");
+    }
+    str.append(std::to_string(shape[i]));
+  }
+  return str.append("]");
+}
+
 template <typename T, typename U>
 std::unique_ptr<T[]> NewData(const U *input, size_t size) {
   if (input == nullptr || size == 0) {
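
For reference, the added ShapeToString renders a shape as a bracketed, comma-separated list. A minimal standalone sketch with the same logic (the main() harness is illustrative, not part of the commit):

    #include <iostream>
    #include <string>
    #include <vector>

    // Same logic as the ShapeToString added above: "[2, 3, 4]" style output.
    static std::string ShapeToString(const std::vector<int> &shape) {
      std::string str = "[";
      for (size_t i = 0; i < shape.size(); ++i) {
        if (i > 0) {
          str.append(", ");
        }
        str.append(std::to_string(shape[i]));
      }
      return str.append("]");
    }

    int main() {
      std::cout << ShapeToString({2, 3, 4}) << '\n';  // prints: [2, 3, 4]
      std::cout << ShapeToString({}) << '\n';         // prints: [] (scalar shape)
    }
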
@@ -84,7 +96,10 @@ template <typename T>
 std::unique_ptr<T[]> CopyData(const std::vector<int> &shape, void *const data, TypeId data_type) {
   const size_t size = SizeOf(shape);
   switch (data_type) {
-    case kNumberTypeBool:
+    case kNumberTypeBool: {
+      auto buf = static_cast<bool *>(data);
+      return NewData<T>(buf, size);
+    }
     case kNumberTypeUInt8: {
       auto buf = static_cast<uint8_t *>(data);
       return NewData<T>(buf, size);
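
Before this hunk, kNumberTypeBool fell through to the kNumberTypeUInt8 case, so the source buffer was read as uint8_t; the new case casts it to bool * so that NewData<T> converts genuine bool elements. A self-contained sketch of that converting copy (this free-standing NewData mirrors the signature shown above; the exact implementation in tensor.cc may differ):

    #include <cstddef>
    #include <iostream>
    #include <memory>

    // Element-wise converting copy, as NewData<T>(const U *, size_t) above.
    template <typename T, typename U>
    std::unique_ptr<T[]> NewData(const U *input, size_t size) {
      if (input == nullptr || size == 0) {
        return nullptr;
      }
      auto data = std::make_unique<T[]>(size);
      for (size_t i = 0; i < size; ++i) {
        data[i] = static_cast<T>(input[i]);  // bool -> T conversion happens here
      }
      return data;
    }

    int main() {
      bool src[] = {true, false, true};
      auto dst = NewData<float>(src, 3);
      std::cout << dst[0] << ' ' << dst[1] << ' ' << dst[2] << '\n';  // 1 0 1
    }
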
@@ -200,7 +215,7 @@ class TensorDataImpl : public TensorData {
 
   std::string ToString(const TypeId type, const std::vector<int> &shape) const override {
     constexpr auto valid =
-      std::is_same<T, Bool>::value || std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value ||
+      std::is_same<T, bool>::value || std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value ||
       std::is_same<T, int16_t>::value || std::is_same<T, int32_t>::value || std::is_same<T, int64_t>::value ||
       std::is_same<T, uint16_t>::value || std::is_same<T, uint32_t>::value || std::is_same<T, uint64_t>::value ||
       std::is_same<T, float16>::value || std::is_same<T, float>::value || std::is_same<T, double>::value;
@@ -214,27 +229,28 @@ class TensorDataImpl : public TensorData {
 
     std::ostringstream ss;
     if (data_size_ == 1 && ndim_ == 0) {  // Scalar
-      OutputDataString(ss, type, 0, 0, 1);
+      OutputDataString(ss, 0, 0, 1);
       return ss.str();
     }
     ssize_t cursor = 0;
-    SummaryStringRecursive(ss, type, shape, &cursor, 0);
+    SummaryStringRecursive(ss, shape, &cursor, 0);
     return ss.str();
   }
 
  private:
-  void OutputDataString(std::ostringstream &ss, const TypeId type, ssize_t cursor, ssize_t start, ssize_t end) const {
-    bool isScalar = ndim_ == 0 && end - start == 1;
-    int linefeedThreshold;
+  void OutputDataString(std::ostringstream &ss, ssize_t cursor, ssize_t start, ssize_t end) const {
+    const bool isScalar = ndim_ == 0 && end - start == 1;
     constexpr auto isFloat =
       std::is_same<T, float16>::value || std::is_same<T, float>::value || std::is_same<T, double>::value;
+    constexpr auto isBool = std::is_same<T, bool>::value;
+    constexpr int linefeedThreshold = isFloat ? kThreshold1DFloat : (isBool ? kThreshold1DBool : kThreshold1DInt);
     for (ssize_t i = start; i < end && (cursor + i) < static_cast<ssize_t>(data_size_); i++) {
       const auto value = data_[cursor + i];
       if constexpr (isFloat) {
         if (isScalar) {
           ss << value;
         } else {
-          if (std::is_same<T, float16>::value) {
+          if constexpr (std::is_same<T, float16>::value) {
             ss << std::setw(11) << std::setprecision(4) << std::setiosflags(std::ios::scientific | std::ios::right)
                << value;
           } else {
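
The refactor replaces the mutable linefeedThreshold, previously assigned inside each branch of the loop, with a constexpr value selected from T once at compile time; the inner if constexpr likewise drops the float16 formatting path from instantiations where T is not float16. A reduced sketch of the compile-time selection (the kThreshold1D* values below are placeholders, not the real constants, and LinefeedThreshold is a hypothetical helper):

    #include <iostream>
    #include <type_traits>

    // Placeholder values; the real kThreshold1D* constants live in tensor.cc.
    constexpr int kThreshold1DFloat = 50;
    constexpr int kThreshold1DInt = 100;
    constexpr int kThreshold1DBool = 120;

    template <typename T>
    constexpr int LinefeedThreshold() {
      constexpr bool isFloat = std::is_same<T, float>::value || std::is_same<T, double>::value;
      constexpr bool isBool = std::is_same<T, bool>::value;
      // Same ternary as the diff: float, then bool, then integer default.
      return isFloat ? kThreshold1DFloat : (isBool ? kThreshold1DBool : kThreshold1DInt);
    }

    int main() {
      std::cout << LinefeedThreshold<double>() << '\n';  // 50
      std::cout << LinefeedThreshold<bool>() << '\n';    // 120
      std::cout << LinefeedThreshold<int>() << '\n';     // 100
    }
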
@@ -242,14 +258,12 @@ class TensorDataImpl : public TensorData {
                << value;
           }
         }
-        linefeedThreshold = kThreshold1DFloat;
-      } else if (type == kNumberTypeBool) {
+      } else if (std::is_same<T, bool>::value) {
         if (isScalar) {
-          ss << (value == 0 ? "False" : "True");
+          ss << (value ? "True" : "False");
         } else {
-          ss << std::setw(5) << std::setiosflags(std::ios::right) << (value == 0 ? "False" : "True");
+          ss << std::setw(5) << std::setiosflags(std::ios::right) << (value ? "True" : "False");
         }
-        linefeedThreshold = kThreshold1DBool;
       } else {
         constexpr auto isSigned = std::is_same<T, int8_t>::value || std::is_same<T, int16_t>::value ||
                                   std::is_same<T, int32_t>::value || std::is_same<T, int64_t>::value;
@@ -276,7 +290,6 @@ class TensorDataImpl : public TensorData {
         } else {
           ss << value;
         }
-        linefeedThreshold = kThreshold1DInt;
       }
       if (!isScalar && i != end - 1) {
         ss << ' ';
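
The bool branch now tests std::is_same<T, bool> instead of the runtime TypeId, and prints the value directly as "True"/"False", right-aligned in a field of width 5 so that columns stay aligned ("False" is the widest word). A small illustration of that formatting:

    #include <iomanip>
    #include <iostream>
    #include <sstream>

    int main() {
      bool values[] = {true, false, true};
      std::ostringstream ss;
      for (bool value : values) {
        // Width 5 fits "False"; "True" is right-padded to match.
        ss << std::setw(5) << std::setiosflags(std::ios::right) << (value ? "True" : "False") << ' ';
      }
      std::cout << ss.str() << '\n';  // prints: " True False  True "
    }
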
@@ -288,7 +301,7 @@ class TensorDataImpl : public TensorData {
       }
     }
   }
 
-  void SummaryStringRecursive(std::ostringstream &ss, const TypeId type, const std::vector<int> &shape, ssize_t *cursor,
+  void SummaryStringRecursive(std::ostringstream &ss, const std::vector<int> &shape, ssize_t *cursor,
                               ssize_t depth) const {
     if (depth >= static_cast<ssize_t>(ndim_)) {
       return;
@@ -297,11 +310,11 @@ class TensorDataImpl : public TensorData {
     if (depth == static_cast<ssize_t>(ndim_) - 1) {  // Bottom dimension
       ssize_t num = shape[depth];
       if (num > kThreshold && ndim_ > 1) {
-        OutputDataString(ss, type, *cursor, 0, kThreshold / 2);
+        OutputDataString(ss, *cursor, 0, kThreshold / 2);
         ss << ' ' << kEllipsis << ' ';
-        OutputDataString(ss, type, *cursor, num - kThreshold / 2, num);
+        OutputDataString(ss, *cursor, num - kThreshold / 2, num);
       } else {
-        OutputDataString(ss, type, *cursor, 0, num);
+        OutputDataString(ss, *cursor, 0, num);
       }
       *cursor += num;
     } else {  // Middle dimension
@@ -312,7 +325,7 @@ class TensorDataImpl : public TensorData {
         ss << '\n';
         ss << std::setw(depth + 1) << ' ';  // Add the indent.
       }
-      SummaryStringRecursive(ss, type, shape, cursor, depth + 1);
+      SummaryStringRecursive(ss, shape, cursor, depth + 1);
     }
     // Handle the ignored part.
     if (num > kThreshold) {
@@ -334,7 +347,7 @@ class TensorDataImpl : public TensorData {
       for (ssize_t i = num - kThreshold / 2; i < num; i++) {
         ss << '\n';
         ss << std::setw(depth + 1) << ' ';  // Add the indent.
-        SummaryStringRecursive(ss, type, shape, cursor, depth + 1);
+        SummaryStringRecursive(ss, shape, cursor, depth + 1);
       }
     }
   }
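
SummaryStringRecursive descends the shape depth-first; whenever a dimension is longer than kThreshold it emits only the first and last kThreshold / 2 entries around kEllipsis, numpy-style. A simplified 1-D sketch of the truncation idea (Summarize is a hypothetical helper, and this kThreshold value is a placeholder, not the real constant):

    #include <iostream>
    #include <sstream>
    #include <string>
    #include <vector>

    constexpr int kThreshold = 6;  // placeholder; the real value lives in tensor.cc
    constexpr const char *kEllipsis = "...";

    // 1-D version of the summarization: print head, ellipsis, tail when too long.
    std::string Summarize(const std::vector<int> &data) {
      std::ostringstream ss;
      ss << '[';
      const int num = static_cast<int>(data.size());
      if (num > kThreshold) {
        for (int i = 0; i < kThreshold / 2; ++i) ss << data[i] << ' ';
        ss << kEllipsis;
        for (int i = num - kThreshold / 2; i < num; ++i) ss << ' ' << data[i];
      } else {
        for (int i = 0; i < num; ++i) ss << (i ? " " : "") << data[i];
      }
      ss << ']';
      return ss.str();
    }

    int main() {
      std::cout << Summarize({1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) << '\n';  // [1 2 3 ... 8 9 10]
      std::cout << Summarize({1, 2, 3}) << '\n';                        // [1 2 3]
    }
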
@@ -350,6 +363,7 @@ template <typename... Args>
 TensorDataPtr MakeTensorData(TypeId data_type, const std::vector<int> &shape, const Args... args) {
   switch (data_type) {
     case kNumberTypeBool:
+      return std::make_shared<TensorDataImpl<bool>>(shape, args...);
     case kNumberTypeUInt8:
       return std::make_shared<TensorDataImpl<uint8_t>>(shape, args...);
     case kNumberTypeInt8:
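
Here, too, kNumberTypeBool previously fell through to the uint8_t case, so bool tensors were backed by TensorDataImpl<uint8_t>; the added return gives them genuine bool storage. The dispatch pattern, reduced to a compilable sketch (Storage/StorageImpl are stand-ins for TensorData/TensorDataImpl, and this two-value TypeId is not the real enum):

    #include <cstdint>
    #include <iostream>
    #include <memory>

    enum TypeId { kNumberTypeBool, kNumberTypeUInt8 };

    struct Storage {  // stand-in for TensorData
      virtual ~Storage() = default;
      virtual const char *type_name() const = 0;
    };

    template <typename T>
    struct StorageImpl : Storage {  // stand-in for TensorDataImpl<T>
      const char *type_name() const override;
    };

    template <> const char *StorageImpl<bool>::type_name() const { return "bool"; }
    template <> const char *StorageImpl<uint8_t>::type_name() const { return "uint8_t"; }

    std::shared_ptr<Storage> MakeStorage(TypeId data_type) {
      switch (data_type) {
        case kNumberTypeBool:
          return std::make_shared<StorageImpl<bool>>();  // no longer falls through
        case kNumberTypeUInt8:
          return std::make_shared<StorageImpl<uint8_t>>();
      }
      return nullptr;
    }

    int main() {
      std::cout << MakeStorage(kNumberTypeBool)->type_name() << '\n';   // bool
      std::cout << MakeStorage(kNumberTypeUInt8)->type_name() << '\n';  // uint8_t
    }
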
@@ -466,31 +480,35 @@ std::string Tensor::GetShapeAndDataTypeInfo() const {
 }
 
 std::string Tensor::ToString() const {
-  const int small_tensor_size = 30;
+  constexpr int small_tensor_size = 30;
   std::ostringstream buf;
+  auto dtype = Dtype();
+  MS_EXCEPTION_IF_NULL(dtype);
   data_sync();
-  buf << "Tensor shape:[" << shape() << "]" << this->Dtype()->ToString();
-  // only print small tensor
+  buf << "Tensor(shape=" << ShapeToString(shape_) << ", dtype=" << dtype->ToString() << ",\n";
   if (DataSize() < small_tensor_size) {
-    buf << ", value:" << data().ToString(data_type_, shape());
+    // Only print data for small tensor.
+    buf << data().ToString(data_type_, shape_) << ')';
+  } else {
+    buf << "[...])";
   }
   return buf.str();
 }
 
 std::string Tensor::ToStringRepr() const {
   std::ostringstream buf;
-  auto type_ptr = this->Dtype();
-  MS_EXCEPTION_IF_NULL(type_ptr);
+  auto dtype = Dtype();
+  MS_EXCEPTION_IF_NULL(dtype);
   data_sync();
-  buf << "Tensor shape:[" << shape() << "]" << type_ptr->ToString();
-  buf << "\nvalue:" << data().ToString(data_type_, shape());
+  buf << "Tensor(shape=" << ShapeToString(shape_) << ", dtype=" << dtype->ToString() << ",\n"
+      << data().ToString(data_type_, shape_) << ')';
   return buf.str();
 }
 
 void Tensor::data_sync() const {
   if (device_sync_ != nullptr) {
     if (!device_sync_->SyncDeviceToHost(shape(), static_cast<size_t>(data().nbytes()), data_type(), data_c())) {
-      MS_LOG(EXCEPTION) << "SyncDeviceToHost when asnumpy.";
+      MS_LOG(EXCEPTION) << "SyncDeviceToHost failed.";
     }
   }
 }
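
With these changes, ToString() for a small bool tensor should yield output roughly like the following (an illustrative guess assembled from the format strings above, not captured output; exact spacing and the dtype spelling come from OutputDataString and TypePtr::ToString):

    Tensor(shape=[2, 2], dtype=Bool,
    [[ True False]
     [False  True]])

Tensors with 30 or more elements print "[...]" in place of the data.
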

