From ff3d7819168c6aef09ad56e134071086699b67a9 Mon Sep 17 00:00:00 2001 From: yvette Date: Sat, 28 Nov 2020 15:50:06 +0800 Subject: [PATCH] fix onnx parser bug --- .../parser/onnx/onnx_reduce_parser.cc | 5 +++ .../converter/parser/onnx/onnx_relu_parser.cc | 34 +++++++++++-------- 2 files changed, 24 insertions(+), 15 deletions(-) diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_reduce_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_reduce_parser.cc index fa2293ae06..74eec74153 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_reduce_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_reduce_parser.cc @@ -38,6 +38,7 @@ STATUS OnnxReduceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N return RET_NULL_PTR; } + attr->keepDims = 1; for (const auto &onnx_node_attr : onnx_node.attribute()) { const auto &attribute_name = onnx_node_attr.name(); if (attribute_name == "axes") { @@ -58,6 +59,10 @@ STATUS OnnxReduceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N attr->mode = schema::ReduceMode_ReduceMin; } else if (type == "ReduceSum") { attr->mode = schema::ReduceMode_ReduceSum; + } else if (type == "ReduceProd") { + attr->mode = schema::ReduceMode_ReduceProd; + } else if (type == "ReduceSumSquare") { + attr->mode = schema::ReduceMode_ReduceSumSquare; } else { MS_LOG(ERROR) << "unsupported type"; return RET_ERROR; diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_relu_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_relu_parser.cc index 2079a900c9..c4b260b43e 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_relu_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_relu_parser.cc @@ -86,23 +86,27 @@ STATUS OnnxPReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No } } - const onnx::TensorProto *slope = ¶ms[0]; - if (slope == nullptr) { - MS_LOG(ERROR) << "input error: params[0] is null"; - return RET_ERROR; - } - const auto slope_raw_data = reinterpret_cast(slope->raw_data().data()); - const int64_t slope_size = slope->raw_data().size() / sizeof(float); - if (slope_size == 1) { - attr->slope.push_back(*slope_raw_data); - attr->channelShared = true; - } else { - attr->slope.resize(slope_size); - attr->channelShared = false; - if (memcpy_s(attr->slope.data(), slope_size * sizeof(float), slope_raw_data, slope_size * sizeof(float)) != 0) { - MS_LOG(ERROR) << "memcpy_s failed"; + if (!params.empty()) { + const onnx::TensorProto *slope = ¶ms[0]; + if (slope == nullptr) { + MS_LOG(ERROR) << "input error: params[0] is null"; return RET_ERROR; } + const auto slope_raw_data = reinterpret_cast(slope->raw_data().data()); + const int64_t slope_size = slope->raw_data().size() / sizeof(float); + if (slope_size == 1) { + attr->slope.push_back(*slope_raw_data); + attr->channelShared = true; + } else { + attr->slope.resize(slope_size); + attr->channelShared = false; + if (memcpy_s(attr->slope.data(), slope_size * sizeof(float), slope_raw_data, slope_size * sizeof(float)) != 0) { + MS_LOG(ERROR) << "memcpy_s failed"; + return RET_ERROR; + } + } + } else { + MS_LOG(WARNING) << "The slope pf prelu is null, which may cause errors."; } op->primitive->value.type = schema::PrimitiveType_PReLU;