Browse Source

check validation of input shape before running graphes.

tags/v1.1.0
wang_shaocong 5 years ago
parent
commit
313762b8dc
6 changed files with 24 additions and 5 deletions
  1. +1
    -1
      mindspore/lite/nnacl/split_parameter.h
  2. +7
    -0
      mindspore/lite/src/executor.cc
  3. +7
    -1
      mindspore/lite/src/ops/populate/split_populate.cc
  4. +4
    -1
      mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.cc
  5. +3
    -0
      mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc
  6. +2
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_quantize_parser.cc

+ 1
- 1
mindspore/lite/nnacl/split_parameter.h View File

@@ -24,7 +24,7 @@ typedef struct SplitParameter {
OpParameter op_parameter_;
SplitQuantArg quant_arg_;
int num_split_;
int split_sizes_[32];
int *split_sizes_;
int strides_[32];
int split_dim_;
int n_dims_;


+ 7
- 0
mindspore/lite/src/executor.cc View File

@@ -29,6 +29,13 @@ int Executor::CheckInputs(std::vector<Tensor *> &in_tensors) {
MS_LOG(ERROR) << "Graph input tensor data is nullptr";
return RET_ERROR;
}
auto shape = inTensor->shape();
bool valid = all_of(shape.begin(), shape.end(), [](int i) { return i > 0; });
if (!valid) {
MS_LOG(ERROR) << "The shape of input tensor contains zero or negative dimension,"
<< "check the model and assign the input shape with method Resize().";
return RET_ERROR;
}
}
return RET_OK;
}


+ 7
- 1
mindspore/lite/src/ops/populate/split_populate.cc View File

@@ -32,13 +32,19 @@ OpParameter *PopulateSplitParameter(const mindspore::lite::PrimitiveC *primitive
auto param = reinterpret_cast<mindspore::lite::Split *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
split_param->op_parameter_.type_ = primitive->Type();
split_param->num_split_ = param->GetNumberSplit();
int *split_sizes = reinterpret_cast<int *>(malloc(split_param->num_split_ * sizeof(int)));
if (split_sizes == nullptr) {
MS_LOG(ERROR) << "malloc split size of SplitParameter failed.";
return nullptr;
}
memset(split_sizes, 0, split_param->num_split_ * sizeof(int));
split_param->split_sizes_ = split_sizes;
auto split_sizes_vector_ = param->GetSizeSplits();
int i = 0;
for (auto iter = split_sizes_vector_.begin(); iter != split_sizes_vector_.end(); iter++) {
split_param->split_sizes_[i++] = *iter;
}
split_param->split_dim_ = param->GetSplitDim();
split_param->num_split_ = param->GetNumberSplit();
return reinterpret_cast<OpParameter *>(split_param);
}
Registry SplitParameterRegistry(schema::PrimitiveType_Split, PopulateSplitParameter);


+ 4
- 1
mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.cc View File

@@ -294,7 +294,10 @@ STATUS TransOpInsertPass::Run(schema::MetaGraphT *graph) {
return ret;
}
ret = ChangeOpAxis(graph, node);
if (ret != RET_OK) {
if (ret == RET_NOT_SUPPORT) {
MS_LOG(INFO) << "not support to ChangeOpAxis";
return RET_OK;
} else if (ret != RET_OK) {
MS_LOG(INFO) << "no need to ChangeOpAxis";
return ret;
}


+ 3
- 0
mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc View File

@@ -251,6 +251,7 @@ STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node)
quant_param->zeroPoint = static_cast<int32_t>(onnx_node_attr.i());
}
}
quant_param->inited = true;
tensor->quantParams.emplace_back(std::move(quant_param));
} else {
MS_LOG(ERROR) << "unsupported data type " << tensor->dataType;
@@ -369,6 +370,7 @@ void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph, const
MS_LOG(ERROR) << "new QuantParamT failed, node: " << dst_op->name;
return;
}
quant_param->inited = true;
int argNum = 0;
for (const auto &onnx_node_attr : node.attribute()) {
if (onnx_node_attr.name() == "Y_scale") {
@@ -384,6 +386,7 @@ void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph, const
quant_param->zeroPoint = 0;
quant_param->min = FLT_MAX;
quant_param->max = FLT_MAX;
quant_param->inited = false;
}
dst_tensor->quantParams.emplace_back(std::move(quant_param));
if (argNum == 2) {


+ 2
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_quantize_parser.cc View File

@@ -39,9 +39,9 @@ STATUS OnnxQuantizeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx:
}
if (onnx_node.op_type() == "Int8Quantize") {
attr->srcT = kNumberTypeFloat32;
attr->dstT = kNumberTypeInt8;
attr->dstT = kNumberTypeUInt8;
} else if (onnx_node.op_type() == "Int8Dequantize") {
attr->srcT = kNumberTypeInt8;
attr->srcT = kNumberTypeUInt8;
attr->dstT = kNumberTypeFloat32;
} else {
MS_LOG(ERROR) << "Unsupported nodeType: " << onnx_node.op_type().c_str();


Loading…
Cancel
Save