|
|
|
@@ -226,11 +226,29 @@ int32_t Tensor::ElementsC4Num() const { |
|
|
|
if (this->category_ == CONST_SCALAR) { |
|
|
|
return 1; |
|
|
|
} |
|
|
|
int32_t result = 1; |
|
|
|
int64_t result = 1; |
|
|
|
if (this->shape_.size() == 4) { |
|
|
|
result = Batch() * Height() * Width() * ((Channel() + 3) / 4 * 4); |
|
|
|
int64_t tmp_channel = Channel() + 3; |
|
|
|
if (tmp_channel > static_cast<int64_t>(INT32_MAX) || tmp_channel < 0) { |
|
|
|
MS_LOG(ERROR) << "extend INT32_MAX: " << tmp_channel << " return INT32_MAX"; |
|
|
|
return INT32_MAX; |
|
|
|
} |
|
|
|
result = Batch() * Height() * Width() * (tmp_channel / 4 * 4); |
|
|
|
if (result > static_cast<int64_t>(INT32_MAX) || result < 0) { |
|
|
|
MS_LOG(ERROR) << "extend INT32_MAX: " << result << " return INT32_MAX"; |
|
|
|
return INT32_MAX; |
|
|
|
} |
|
|
|
} else if (this->shape_.size() == 2) { |
|
|
|
result = this->shape_[0] * ((this->shape_[1] + 3) / 4 * 4); |
|
|
|
int64_t tmp_shape = this->shape_[1] + 3; |
|
|
|
if (tmp_shape > static_cast<int64_t>(INT32_MAX) || tmp_shape < 0) { |
|
|
|
MS_LOG(ERROR) << "extend INT32_MAX: " << tmp_shape << " return INT32_MAX"; |
|
|
|
return INT32_MAX; |
|
|
|
} |
|
|
|
result = this->shape_[0] * (tmp_shape / 4 * 4); |
|
|
|
if (result > static_cast<int64_t>(INT32_MAX) || result < 0) { |
|
|
|
MS_LOG(ERROR) << "extend INT32_MAX: " << result << " return INT32_MAX"; |
|
|
|
return INT32_MAX; |
|
|
|
} |
|
|
|
} |
|
|
|
return result; |
|
|
|
} |
|
|
|
|