From 03f54af905b0ac42a018b4ec5cf92ccf61774706 Mon Sep 17 00:00:00 2001 From: gongdaguo Date: Mon, 26 Oct 2020 17:22:19 +0800 Subject: [PATCH] fix softmax parser, fix mul bug, fix lrn bug --- mindspore/lite/schema/ops.fbs | 2 +- .../src/runtime/kernel/arm/fp32/arithmetic.cc | 27 ++++++++++++++----- mindspore/lite/test/models_onnx.cfg | 5 ++++ mindspore/lite/tools/common/node_util.cc | 1 + .../parser/onnx/onnx_model_parser.cc | 1 + .../converter/parser/onnx/onnx_node_parser.cc | 2 ++ .../converter/parser/onnx/onnx_node_parser.h | 9 +++++++ .../converter/parser/onnx/onnx_pool_parser.cc | 2 +- .../parser/onnx/onnx_softmax_parser.cc | 10 ++++++- 9 files changed, 50 insertions(+), 9 deletions(-) diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs index 4ad406898e..44ee28a214 100644 --- a/mindspore/lite/schema/ops.fbs +++ b/mindspore/lite/schema/ops.fbs @@ -172,7 +172,7 @@ table Concat { } table SoftMax { - axis: int; + axis: int = -1; } table Activation { diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc index 410fdcfa08..b47cc9becb 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc @@ -73,12 +73,27 @@ int ArithmeticCPUKernel::ReSize() { arithmeticParameter_->in_elements_num0_ = in_tensors_[0]->ElementsNum(); arithmeticParameter_->in_elements_num1_ = in_tensors_[1]->ElementsNum(); arithmeticParameter_->out_elements_num_ = out_tensors_[0]->ElementsNum(); - memcpy(arithmeticParameter_->in_shape0_, static_cast(in_tensors_[0]->shape().data()), - in_tensors_[0]->shape().size() * sizeof(int)); - memcpy(arithmeticParameter_->in_shape1_, static_cast(in_tensors_[1]->shape().data()), - in_tensors_[1]->shape().size() * sizeof(int)); - memcpy(arithmeticParameter_->out_shape_, static_cast(out_tensors_[0]->shape().data()), - out_tensors_[0]->shape().size() * sizeof(int)); + for (size_t i = 0; i < in_tensors_[0]->shape().size(); i++) { + if (arithmeticParameter_->in_shape0_[i] == -1) { + memcpy(arithmeticParameter_->in_shape0_, static_cast(in_tensors_[0]->shape().data()), + in_tensors_[0]->shape().size() * sizeof(int)); + break; + } + } + for (size_t i = 0; i < in_tensors_[1]->shape().size(); i++) { + if (arithmeticParameter_->in_shape1_[i] == -1) { + memcpy(arithmeticParameter_->in_shape1_, static_cast(in_tensors_[1]->shape().data()), + in_tensors_[1]->shape().size() * sizeof(int)); + break; + } + } + for (size_t i = 0; i < out_tensors_[0]->shape().size(); i++) { + if (arithmeticParameter_->out_shape_[i] == -1) { + memcpy(arithmeticParameter_->out_shape_, static_cast(out_tensors_[0]->shape().data()), + out_tensors_[0]->shape().size() * sizeof(int)); + break; + } + } if (arithmeticParameter_->in_elements_num0_ == 1 || arithmeticParameter_->in_elements_num1_ == 1) { switch (arithmeticParameter_->op_parameter_.type_) { diff --git a/mindspore/lite/test/models_onnx.cfg b/mindspore/lite/test/models_onnx.cfg index 8dc7ecc8a3..c1e0891784 100644 --- a/mindspore/lite/test/models_onnx.cfg +++ b/mindspore/lite/test/models_onnx.cfg @@ -2,6 +2,11 @@ mtk_detect-mbv2-shortcut-400-400-simplified.onnx mtk_emotions-d2012-75.8%.onnx mtk_face_features_v3.onnx emotion-ferplus-8.onnx +#rcnn-ilsvrc13-9.onnx +efficientnet-lite4-11.onnx +mobilenetv2-7.onnx +shufflenet-v2-10.onnx +squeezenet1.1-7.onnx ml_face_3d.onnx gts_version-RFB-320_simplified.onnx mnist-8.onnx diff --git a/mindspore/lite/tools/common/node_util.cc b/mindspore/lite/tools/common/node_util.cc index f6775e8b09..8d1c3cbc41 100644 --- a/mindspore/lite/tools/common/node_util.cc +++ b/mindspore/lite/tools/common/node_util.cc @@ -42,6 +42,7 @@ static const std::vector nhwcOpList = { schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_DeDepthwiseConv2D, schema::PrimitiveType_Pooling, + schema::PrimitiveType_LocalResponseNormalization, schema::PrimitiveType_Resize, schema::PrimitiveType_BatchNorm, schema::PrimitiveType_FusedBatchNorm, diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc index 66c62d809a..a4a48662b3 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc @@ -608,6 +608,7 @@ schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, con ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); return nullptr; } + OnnxNodeParser::set_opset_version(onnx_model.opset_import().Get(0).version()); const onnx::GraphProto &onnx_graph = onnx_model.graph(); MS_LOG(INFO) << "model producer name: " << onnx_model.producer_name() << ", graph name: " << onnx_graph.name(); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc index 84fca4c794..4ffc5fe9fd 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc @@ -20,6 +20,8 @@ namespace mindspore { namespace lite { +int OnnxNodeParser::opset_version_ = 0; + schema::PadMode OnnxNodeParser::GetOnnxPadMode(const onnx::AttributeProto &onnx_node_attr) { if (onnx_node_attr.s() == "NOTSET") { return schema::PadMode_NOTSET; diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h index 43c7eb1032..f9b6a2b8f8 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h @@ -37,12 +37,21 @@ class OnnxNodeParser { STATUS GetTensorDataFromOnnx(const onnx::TensorProto &onnx_tensor, std::vector *value, int *type); + static STATUS set_opset_version(int version) { + opset_version_ = version; + return RET_OK; + } + static int opset_version() { return opset_version_; } + protected: schema::PadMode GetOnnxPadMode(const onnx::AttributeProto &onnx_node_attr); void Split(const std::string &src_str, std::vector *dst_str, const std::string &chr); const std::string &name; + + private: + static int opset_version_; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.cc index d97fb94fb4..ea041867a6 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.cc @@ -94,7 +94,7 @@ STATUS OnnxPoolParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod } } if (attribute_name == "ceil_mode") { - if (onnx_node_attr.f() == 0) { + if (onnx_node_attr.i() == 0) { attr->roundMode = schema::RoundMode_FLOOR; } else { attr->roundMode = schema::RoundMode_CEIL; diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_softmax_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_softmax_parser.cc index 8e35db742d..5e136e1583 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_softmax_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_softmax_parser.cc @@ -38,13 +38,21 @@ STATUS OnnxSoftMaxParser::Parse(const onnx::GraphProto &onnx_graph, const onnx:: return RET_NULL_PTR; } + bool axis_is_def = true; for (const auto &onnx_node_attr : onnx_node.attribute()) { const auto &attribute_name = onnx_node_attr.name(); if (attribute_name == "axis") { attr->axis = static_cast(onnx_node_attr.i()); + axis_is_def = false; + } + } + if (axis_is_def) { + if (OnnxNodeParser::opset_version() >= 13) { + attr->axis = -1; + } else { + attr->axis = 1; } } - op->primitive->value.type = schema::PrimitiveType_SoftMax; op->primitive->value.value = attr.release(); return RET_OK;