|
|
|
@@ -24,6 +24,23 @@ |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace lite { |
|
|
|
#if ENABLE_HIGH_PERFORMANCE |
|
|
|
#define CHECK_INT64_MUL_OVERFLOW(x, y) |
|
|
|
#else |
|
|
|
#define CHECK_INT64_MUL_OVERFLOW(x, y) \ |
|
|
|
do { \ |
|
|
|
if (INT64_MUL_OVERFLOW(x, y)) { \ |
|
|
|
MS_LOG(ERROR) << "INT64 MUL OVERFLOW"; \ |
|
|
|
return INT64_MAX; \ |
|
|
|
} \ |
|
|
|
} while (0) |
|
|
|
|
|
|
|
#define INT64_MUL_OVERFLOW(x, y) \ |
|
|
|
(((x) == 0) ? false \ |
|
|
|
: ((x) > 0 ? (((y) >= 0) ? (INT64_MAX / (x)) < (y) : (INT64_MAX / (x)) < (-1 * (y))) \ |
|
|
|
: (((y) >= 0) ? (INT64_MAX / (x)) > (-1 * (y)) : (INT64_MAX / (x)) > (y)))) |
|
|
|
#endif |
|
|
|
|
|
|
|
namespace { |
|
|
|
constexpr int kMaxMallocSize = 1024 * 1024 * 300; |
|
|
|
} // namespace |
|
|
|
@@ -203,52 +220,38 @@ size_t Tensor::Size() const { |
|
|
|
return element_size * element_num; |
|
|
|
} |
|
|
|
|
|
|
|
int Tensor::ElementsNum() const { |
|
|
|
int64_t Tensor::ElementsNum() const { |
|
|
|
if (this->category_ == CONST_SCALAR) { |
|
|
|
return 1; |
|
|
|
} |
|
|
|
int64_t num = 1; |
|
|
|
for (size_t i = 0; i < shape_.size(); ++i) { |
|
|
|
if (shape_[i] < 0) { |
|
|
|
MS_LOG(ERROR) << "shapes contains negative value: " << shape_[i] << " return 0"; |
|
|
|
return 0; |
|
|
|
} |
|
|
|
CHECK_INT64_MUL_OVERFLOW(num, shape_[i]); |
|
|
|
num *= shape_[i]; |
|
|
|
if (num > static_cast<int64_t>(INT32_MAX) || num < 0) { |
|
|
|
MS_LOG(ERROR) << "extend INT32_MAX: " << num << " return INT32_MAX"; |
|
|
|
return INT32_MAX; |
|
|
|
} |
|
|
|
} |
|
|
|
return (int32_t)num; |
|
|
|
return num; |
|
|
|
} |
|
|
|
|
|
|
|
int32_t Tensor::ElementsC4Num() const { |
|
|
|
int64_t Tensor::ElementsC4Num() const { |
|
|
|
if (this->category_ == CONST_SCALAR) { |
|
|
|
return 1; |
|
|
|
} |
|
|
|
int64_t result = 1; |
|
|
|
constexpr int kC4Align = 4; |
|
|
|
if (this->shape_.size() == 4) { |
|
|
|
int64_t tmp_channel = Channel() + 3; |
|
|
|
if (tmp_channel > static_cast<int64_t>(INT32_MAX) || tmp_channel < 0) { |
|
|
|
MS_LOG(ERROR) << "extend INT32_MAX: " << tmp_channel << " return INT32_MAX"; |
|
|
|
return INT32_MAX; |
|
|
|
} |
|
|
|
result = Batch() * Height() * Width() * (tmp_channel / 4 * 4); |
|
|
|
if (result > static_cast<int64_t>(INT32_MAX) || result < 0) { |
|
|
|
MS_LOG(ERROR) << "extend INT32_MAX: " << result << " return INT32_MAX"; |
|
|
|
return INT32_MAX; |
|
|
|
} |
|
|
|
CHECK_INT64_MUL_OVERFLOW(result, Batch()); |
|
|
|
result *= Batch(); |
|
|
|
CHECK_INT64_MUL_OVERFLOW(result, Height()); |
|
|
|
result *= Height(); |
|
|
|
CHECK_INT64_MUL_OVERFLOW(result, Width()); |
|
|
|
result *= Width(); |
|
|
|
CHECK_INT64_MUL_OVERFLOW(result, (Channel() + 3LL) / kC4Align * kC4Align); |
|
|
|
result *= (Channel() + 3LL) / kC4Align * kC4Align; |
|
|
|
} else if (this->shape_.size() == 2) { |
|
|
|
int64_t tmp_shape = this->shape_[1] + 3; |
|
|
|
if (tmp_shape > static_cast<int64_t>(INT32_MAX) || tmp_shape < 0) { |
|
|
|
MS_LOG(ERROR) << "extend INT32_MAX: " << tmp_shape << " return INT32_MAX"; |
|
|
|
return INT32_MAX; |
|
|
|
} |
|
|
|
result = this->shape_[0] * (tmp_shape / 4 * 4); |
|
|
|
if (result > static_cast<int64_t>(INT32_MAX) || result < 0) { |
|
|
|
MS_LOG(ERROR) << "extend INT32_MAX: " << result << " return INT32_MAX"; |
|
|
|
return INT32_MAX; |
|
|
|
} |
|
|
|
CHECK_INT64_MUL_OVERFLOW(result, this->shape_[0]); |
|
|
|
result *= this->shape_[0]; |
|
|
|
CHECK_INT64_MUL_OVERFLOW(result, (this->shape_[1] + 3LL) / kC4Align * kC4Align); |
|
|
|
result *= (this->shape_[1] + 3LL) / kC4Align * kC4Align; |
|
|
|
} |
|
|
|
return result; |
|
|
|
} |
|
|
|
|