Browse Source

!24811 [MSLITE] change ElementsNum data type to int64_t

Merge pull request !24811 from zhanyuan/super_large
tags/v1.6.0
i-robot Gitee 4 years ago
parent
commit
2c3ce1d73c
8 changed files with 46 additions and 43 deletions
  1. +1
    -1
      mindspore/lite/include/ms_tensor.h
  2. +3
    -3
      mindspore/lite/java/native/runtime/ms_tensor.cpp
  3. +1
    -1
      mindspore/lite/micro/coder/generator/component/const_blocks/benchmark.cc
  4. +3
    -3
      mindspore/lite/micro/coder/generator/component/const_blocks/mtensor.cc
  5. +1
    -1
      mindspore/lite/src/cxx_api/tensor/tensor_impl.h
  6. +34
    -31
      mindspore/lite/src/tensor.cc
  7. +2
    -2
      mindspore/lite/src/tensor.h
  8. +1
    -1
      mindspore/lite/tools/benchmark/benchmark.cc

+ 1
- 1
mindspore/lite/include/ms_tensor.h View File

@@ -86,7 +86,7 @@ class MS_API MSTensor {
/// \brief Get number of element in MSTensor.
///
/// \return Number of element in MSTensor.
virtual int ElementsNum() const = 0;
virtual int64_t ElementsNum() const = 0;

/// \brief Get byte size of data in MSTensor.
///


+ 3
- 3
mindspore/lite/java/native/runtime/ms_tensor.cpp View File

@@ -97,7 +97,7 @@ extern "C" JNIEXPORT jlongArray JNICALL Java_com_mindspore_lite_MSTensor_getLong
}
auto local_element_num = ms_tensor_ptr->ElementsNum();
if (local_element_num <= 0) {
MS_LOGE("ElementsNum of tensor is negative: %d", local_element_num);
MS_LOGE("ElementsNum of tensor is negative: %d", static_cast<int>(local_element_num));
return env->NewLongArray(0);
}
auto ret = env->NewLongArray(local_element_num);
@@ -127,7 +127,7 @@ extern "C" JNIEXPORT jintArray JNICALL Java_com_mindspore_lite_MSTensor_getIntDa
}
auto local_element_num = ms_tensor_ptr->ElementsNum();
if (local_element_num <= 0) {
MS_LOGE("ElementsNum of tensor is negative: %d", local_element_num);
MS_LOGE("ElementsNum of tensor is negative: %d", static_cast<int>(local_element_num));
return env->NewIntArray(0);
}
auto ret = env->NewIntArray(local_element_num);
@@ -157,7 +157,7 @@ extern "C" JNIEXPORT jfloatArray JNICALL Java_com_mindspore_lite_MSTensor_getFlo
}
auto local_element_num = ms_tensor_ptr->ElementsNum();
if (local_element_num <= 0) {
MS_LOGE("ElementsNum of tensor is negative: %d", local_element_num);
MS_LOGE("ElementsNum of tensor is negative: %d", static_cast<int>(local_element_num));
return env->NewFloatArray(0);
}
auto ret = env->NewFloatArray(local_element_num);


+ 1
- 1
mindspore/lite/micro/coder/generator/component/const_blocks/benchmark.cc View File

@@ -85,7 +85,7 @@ void PrintData(void *data, size_t data_number) {
void TensorToString(tensor::MSTensor *tensor) {
printf("name: %s, ", tensor->tensor_name().c_str());
printf("DataType: %d, ", tensor->data_type());
printf("Elements: %d, ", tensor->ElementsNum());
printf("Elements: %d, ", static_cast<int>(tensor->ElementsNum()));
printf("Shape: [");
for (auto &dim : tensor->shape()) {
printf("%d ", dim);


+ 3
- 3
mindspore/lite/micro/coder/generator/component/const_blocks/mtensor.cc View File

@@ -69,7 +69,7 @@ class MTensor : public mindspore::tensor::MSTensor {
mindspore::Format format() const override { return mindspore::NHWC; }
Vector<int> shape() const override { return shape_; }
void set_shape(const Vector<int> &shape) override { shape_ = shape; }
int ElementsNum() const override;
int64_t ElementsNum() const override;
size_t Size() const override;
String tensor_name() const override { return tensor_name_; }
void set_tensor_name(const String &name) override { tensor_name_ = name; }
@@ -156,8 +156,8 @@ MTensor::~MTensor() {
}
}

int MTensor::ElementsNum() const {
int elements = 1;
int64_t MTensor::ElementsNum() const {
int64_t elements = 1;
for (int i : shape_) {
elements *= i;
}


+ 1
- 1
mindspore/lite/src/cxx_api/tensor/tensor_impl.h View File

@@ -107,7 +107,7 @@ class MSTensor::Impl {
MS_LOG(ERROR) << "Invalid tensor.";
return -1;
}
return static_cast<int64_t>(lite_tensor_->ElementsNum());
return lite_tensor_->ElementsNum();
}

virtual const std::vector<int64_t> &Shape() const {


+ 34
- 31
mindspore/lite/src/tensor.cc View File

@@ -24,6 +24,23 @@

namespace mindspore {
namespace lite {
#if ENABLE_HIGH_PERFORMANCE
#define CHECK_INT64_MUL_OVERFLOW(x, y)
#else
#define CHECK_INT64_MUL_OVERFLOW(x, y) \
do { \
if (INT64_MUL_OVERFLOW(x, y)) { \
MS_LOG(ERROR) << "INT64 MUL OVERFLOW"; \
return INT64_MAX; \
} \
} while (0)

#define INT64_MUL_OVERFLOW(x, y) \
(((x) == 0) ? false \
: ((x) > 0 ? (((y) >= 0) ? (INT64_MAX / (x)) < (y) : (INT64_MAX / (x)) < (-1 * (y))) \
: (((y) >= 0) ? (INT64_MAX / (x)) > (-1 * (y)) : (INT64_MAX / (x)) > (y))))
#endif

namespace {
constexpr int kMaxMallocSize = 1024 * 1024 * 300;
} // namespace
@@ -203,52 +220,38 @@ size_t Tensor::Size() const {
return element_size * element_num;
}

int Tensor::ElementsNum() const {
int64_t Tensor::ElementsNum() const {
if (this->category_ == CONST_SCALAR) {
return 1;
}
int64_t num = 1;
for (size_t i = 0; i < shape_.size(); ++i) {
if (shape_[i] < 0) {
MS_LOG(ERROR) << "shapes contains negative value: " << shape_[i] << " return 0";
return 0;
}
CHECK_INT64_MUL_OVERFLOW(num, shape_[i]);
num *= shape_[i];
if (num > static_cast<int64_t>(INT32_MAX) || num < 0) {
MS_LOG(ERROR) << "extend INT32_MAX: " << num << " return INT32_MAX";
return INT32_MAX;
}
}
return (int32_t)num;
return num;
}

int32_t Tensor::ElementsC4Num() const {
int64_t Tensor::ElementsC4Num() const {
if (this->category_ == CONST_SCALAR) {
return 1;
}
int64_t result = 1;
constexpr int kC4Align = 4;
if (this->shape_.size() == 4) {
int64_t tmp_channel = Channel() + 3;
if (tmp_channel > static_cast<int64_t>(INT32_MAX) || tmp_channel < 0) {
MS_LOG(ERROR) << "extend INT32_MAX: " << tmp_channel << " return INT32_MAX";
return INT32_MAX;
}
result = Batch() * Height() * Width() * (tmp_channel / 4 * 4);
if (result > static_cast<int64_t>(INT32_MAX) || result < 0) {
MS_LOG(ERROR) << "extend INT32_MAX: " << result << " return INT32_MAX";
return INT32_MAX;
}
CHECK_INT64_MUL_OVERFLOW(result, Batch());
result *= Batch();
CHECK_INT64_MUL_OVERFLOW(result, Height());
result *= Height();
CHECK_INT64_MUL_OVERFLOW(result, Width());
result *= Width();
CHECK_INT64_MUL_OVERFLOW(result, (Channel() + 3LL) / kC4Align * kC4Align);
result *= (Channel() + 3LL) / kC4Align * kC4Align;
} else if (this->shape_.size() == 2) {
int64_t tmp_shape = this->shape_[1] + 3;
if (tmp_shape > static_cast<int64_t>(INT32_MAX) || tmp_shape < 0) {
MS_LOG(ERROR) << "extend INT32_MAX: " << tmp_shape << " return INT32_MAX";
return INT32_MAX;
}
result = this->shape_[0] * (tmp_shape / 4 * 4);
if (result > static_cast<int64_t>(INT32_MAX) || result < 0) {
MS_LOG(ERROR) << "extend INT32_MAX: " << result << " return INT32_MAX";
return INT32_MAX;
}
CHECK_INT64_MUL_OVERFLOW(result, this->shape_[0]);
result *= this->shape_[0];
CHECK_INT64_MUL_OVERFLOW(result, (this->shape_[1] + 3LL) / kC4Align * kC4Align);
result *= (this->shape_[1] + 3LL) / kC4Align * kC4Align;
}
return result;
}


+ 2
- 2
mindspore/lite/src/tensor.h View File

@@ -88,7 +88,7 @@ class Tensor : public mindspore::tensor::MSTensor {

int DimensionSize(size_t index) const;

int ElementsNum() const override;
int64_t ElementsNum() const override;

int32_t Batch() const;

@@ -98,7 +98,7 @@ class Tensor : public mindspore::tensor::MSTensor {

int32_t Width() const;

int32_t ElementsC4Num() const;
int64_t ElementsC4Num() const;

size_t Size() const override;



+ 1
- 1
mindspore/lite/tools/benchmark/benchmark.cc View File

@@ -398,7 +398,7 @@ int Benchmark::PrintInputData() {
}
continue;
}
size_t print_num = std::min(input->ElementsNum(), 20);
size_t print_num = std::min(static_cast<int>(input->ElementsNum()), 20);
const void *in_data = input->MutableData();
if (in_data == nullptr) {
MS_LOG(ERROR) << "in_data is nullptr.";


Loading…
Cancel
Save