Browse Source

Fix integer multiplication and addition overflow bug

tags/v1.6.0
mengyuanli 4 years ago
parent
commit
ba0af7cbad
5 changed files with 36 additions and 33 deletions
  1. +4
    -8
      mindspore/lite/src/runtime/kernel/arm/fp16/cast_fp16.cc
  2. +4
    -7
      mindspore/lite/src/runtime/kernel/arm/fp16_grad/strided_slice_fp16_grad.cc
  3. +4
    -7
      mindspore/lite/src/runtime/kernel/arm/fp32_grad/strided_slice_grad.cc
  4. +3
    -8
      mindspore/lite/src/scheduler.cc
  5. +21
    -3
      mindspore/lite/src/tensor.cc

+ 4
- 8
mindspore/lite/src/runtime/kernel/arm/fp16/cast_fp16.cc View File

@@ -125,15 +125,11 @@ int CastFp16CPUKernel::DoCast(int thread_id) {
return RET_ERROR;
}
} else if (input_data_type == kNumberTypeInt32) {
switch (output_data_type) {
case kNumberTypeFloat32:
Int32ToFloat32(static_cast<int32_t *>(input_data) + offset, static_cast<float *>(output_data) + offset,
data_num);
break;
default:
MS_LOG(ERROR) << "Unsupported output data type " << output_data_type;
return RET_ERROR;
if (output_data_type != kNumberTypeFloat32) {
MS_LOG(ERROR) << "Unsupported output data type " << output_data_type;
return RET_ERROR;
}
Int32ToFloat32(static_cast<int32_t *>(input_data) + offset, static_cast<float *>(output_data) + offset, data_num);
} else {
MS_LOG(ERROR) << "Unsupported input data type " << input_data_type;
return RET_ERROR;


+ 4
- 7
mindspore/lite/src/runtime/kernel/arm/fp16_grad/strided_slice_fp16_grad.cc View File

@@ -41,14 +41,11 @@ int StridedSliceGradCPUKernelFp16::Prepare() {
auto input = in_tensors_.at(0);
CHECK_NULL_RETURN(input);
CHECK_NULL_RETURN(out_tensors_.at(0));
switch (input->data_type()) {
case kNumberTypeFloat16:
param_->data_type = kDataTypeFloat16;
break;
default:
MS_LOG(ERROR) << "Not supported data type: " << input->data_type();
return RET_ERROR;
if (input->data_type() != kNumberTypeFloat16) {
MS_LOG(ERROR) << "Not supported data type: " << input->data_type();
return RET_ERROR;
}
param_->data_type = kDataTypeFloat16;
FillEmptyDims();
FillOutputDim();
return ReSize();


+ 4
- 7
mindspore/lite/src/runtime/kernel/arm/fp32_grad/strided_slice_grad.cc View File

@@ -42,14 +42,11 @@ int StridedSliceGradCPUKernel::Prepare() {
CHECK_NULL_RETURN(out_tensors_.at(0));
auto input = in_tensors_.at(0);
CHECK_NULL_RETURN(input);
switch (input->data_type()) {
case kNumberTypeFloat32:
param_->data_type = kDataTypeFloat;
break;
default:
MS_LOG(ERROR) << "Not supported data type: " << input->data_type();
return RET_ERROR;
if (input->data_type() != kNumberTypeFloat32) {
MS_LOG(ERROR) << "Not supported data type: " << input->data_type();
return RET_ERROR;
}
param_->data_type = kDataTypeFloat;
return ReSize();
}



+ 3
- 8
mindspore/lite/src/scheduler.cc View File

@@ -655,15 +655,10 @@ int Scheduler::CopyPartialShapeToSubGraph(const lite::Model::Node *partial_node)
for (size_t i = 0; i < partial_node->input_indices_.size(); ++i) {
auto &subgraph_input = src_tensors_->at(subgraph->input_indices_[i]);
auto &partial_input = src_tensors_->at(partial_node->input_indices_[i]);
switch (partial_input->data_type()) {
case kObjectTypeTensorType: {
return RET_INFER_INVALID;
}
default: {
CopyCommonTensor(subgraph_input, partial_input);
break;
}
if (partial_input->data_type() == kObjectTypeTensorType) {
return RET_INFER_INVALID;
}
CopyCommonTensor(subgraph_input, partial_input);
}

return RET_OK;


+ 21
- 3
mindspore/lite/src/tensor.cc View File

@@ -226,11 +226,29 @@ int32_t Tensor::ElementsC4Num() const {
if (this->category_ == CONST_SCALAR) {
return 1;
}
int32_t result = 1;
int64_t result = 1;
if (this->shape_.size() == 4) {
result = Batch() * Height() * Width() * ((Channel() + 3) / 4 * 4);
int64_t tmp_channel = Channel() + 3;
if (tmp_channel > static_cast<int64_t>(INT32_MAX) || tmp_channel < 0) {
MS_LOG(ERROR) << "extend INT32_MAX: " << tmp_channel << " return INT32_MAX";
return INT32_MAX;
}
result = Batch() * Height() * Width() * (tmp_channel / 4 * 4);
if (result > static_cast<int64_t>(INT32_MAX) || result < 0) {
MS_LOG(ERROR) << "extend INT32_MAX: " << result << " return INT32_MAX";
return INT32_MAX;
}
} else if (this->shape_.size() == 2) {
result = this->shape_[0] * ((this->shape_[1] + 3) / 4 * 4);
int64_t tmp_shape = this->shape_[1] + 3;
if (tmp_shape > static_cast<int64_t>(INT32_MAX) || tmp_shape < 0) {
MS_LOG(ERROR) << "extend INT32_MAX: " << tmp_shape << " return INT32_MAX";
return INT32_MAX;
}
result = this->shape_[0] * (tmp_shape / 4 * 4);
if (result > static_cast<int64_t>(INT32_MAX) || result < 0) {
MS_LOG(ERROR) << "extend INT32_MAX: " << result << " return INT32_MAX";
return INT32_MAX;
}
}
return result;
}


Loading…
Cancel
Save