From 313762b8dcc3d70091c66317e95ed551cdc9fdb3 Mon Sep 17 00:00:00 2001 From: wang_shaocong Date: Mon, 16 Nov 2020 02:08:04 +0800 Subject: [PATCH] check validation of input shape before running graphes. --- mindspore/lite/nnacl/split_parameter.h | 2 +- mindspore/lite/src/executor.cc | 7 +++++++ mindspore/lite/src/ops/populate/split_populate.cc | 8 +++++++- .../legacy_optimizer/graph/trans_format_insert_pass.cc | 5 ++++- .../lite/tools/converter/parser/onnx/onnx_model_parser.cc | 3 +++ .../tools/converter/parser/onnx/onnx_quantize_parser.cc | 4 ++-- 6 files changed, 24 insertions(+), 5 deletions(-) diff --git a/mindspore/lite/nnacl/split_parameter.h b/mindspore/lite/nnacl/split_parameter.h index fec09595d2..469f498f3a 100644 --- a/mindspore/lite/nnacl/split_parameter.h +++ b/mindspore/lite/nnacl/split_parameter.h @@ -24,7 +24,7 @@ typedef struct SplitParameter { OpParameter op_parameter_; SplitQuantArg quant_arg_; int num_split_; - int split_sizes_[32]; + int *split_sizes_; int strides_[32]; int split_dim_; int n_dims_; diff --git a/mindspore/lite/src/executor.cc b/mindspore/lite/src/executor.cc index 96d5d4b383..e01af935f9 100644 --- a/mindspore/lite/src/executor.cc +++ b/mindspore/lite/src/executor.cc @@ -29,6 +29,13 @@ int Executor::CheckInputs(std::vector &in_tensors) { MS_LOG(ERROR) << "Graph input tensor data is nullptr"; return RET_ERROR; } + auto shape = inTensor->shape(); + bool valid = all_of(shape.begin(), shape.end(), [](int i) { return i > 0; }); + if (!valid) { + MS_LOG(ERROR) << "The shape of input tensor contains zero or negative dimension," + << "check the model and assign the input shape with method Resize()."; + return RET_ERROR; + } } return RET_OK; } diff --git a/mindspore/lite/src/ops/populate/split_populate.cc b/mindspore/lite/src/ops/populate/split_populate.cc index eeaeed8349..cad5d71f5c 100644 --- a/mindspore/lite/src/ops/populate/split_populate.cc +++ b/mindspore/lite/src/ops/populate/split_populate.cc @@ -32,13 +32,19 @@ OpParameter *PopulateSplitParameter(const mindspore::lite::PrimitiveC *primitive auto param = reinterpret_cast(const_cast(primitive)); split_param->op_parameter_.type_ = primitive->Type(); split_param->num_split_ = param->GetNumberSplit(); + int *split_sizes = reinterpret_cast(malloc(split_param->num_split_ * sizeof(int))); + if (split_sizes == nullptr) { + MS_LOG(ERROR) << "malloc split size of SplitParameter failed."; + return nullptr; + } + memset(split_sizes, 0, split_param->num_split_ * sizeof(int)); + split_param->split_sizes_ = split_sizes; auto split_sizes_vector_ = param->GetSizeSplits(); int i = 0; for (auto iter = split_sizes_vector_.begin(); iter != split_sizes_vector_.end(); iter++) { split_param->split_sizes_[i++] = *iter; } split_param->split_dim_ = param->GetSplitDim(); - split_param->num_split_ = param->GetNumberSplit(); return reinterpret_cast(split_param); } Registry SplitParameterRegistry(schema::PrimitiveType_Split, PopulateSplitParameter); diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.cc index 85256efbb4..263be6b2d2 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.cc @@ -294,7 +294,10 @@ STATUS TransOpInsertPass::Run(schema::MetaGraphT *graph) { return ret; } ret = ChangeOpAxis(graph, node); - if (ret != RET_OK) { + if (ret == RET_NOT_SUPPORT) { + MS_LOG(INFO) << "not support to ChangeOpAxis"; + return RET_OK; + } else if (ret != RET_OK) { MS_LOG(INFO) << "no need to ChangeOpAxis"; return ret; } diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc index 0d3125e32b..41cebcc8e5 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc @@ -251,6 +251,7 @@ STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node) quant_param->zeroPoint = static_cast(onnx_node_attr.i()); } } + quant_param->inited = true; tensor->quantParams.emplace_back(std::move(quant_param)); } else { MS_LOG(ERROR) << "unsupported data type " << tensor->dataType; @@ -369,6 +370,7 @@ void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph, const MS_LOG(ERROR) << "new QuantParamT failed, node: " << dst_op->name; return; } + quant_param->inited = true; int argNum = 0; for (const auto &onnx_node_attr : node.attribute()) { if (onnx_node_attr.name() == "Y_scale") { @@ -384,6 +386,7 @@ void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph, const quant_param->zeroPoint = 0; quant_param->min = FLT_MAX; quant_param->max = FLT_MAX; + quant_param->inited = false; } dst_tensor->quantParams.emplace_back(std::move(quant_param)); if (argNum == 2) { diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_quantize_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_quantize_parser.cc index bc7bb87931..bfb45dd89a 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_quantize_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_quantize_parser.cc @@ -39,9 +39,9 @@ STATUS OnnxQuantizeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx: } if (onnx_node.op_type() == "Int8Quantize") { attr->srcT = kNumberTypeFloat32; - attr->dstT = kNumberTypeInt8; + attr->dstT = kNumberTypeUInt8; } else if (onnx_node.op_type() == "Int8Dequantize") { - attr->srcT = kNumberTypeInt8; + attr->srcT = kNumberTypeUInt8; attr->dstT = kNumberTypeFloat32; } else { MS_LOG(ERROR) << "Unsupported nodeType: " << onnx_node.op_type().c_str();