Browse Source

!9165 fix onnx parser bug

From: @lyvette
Reviewed-by: @hangangqiang,@zhanghaibo5
Signed-off-by: @hangangqiang
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
83ff5f5ff3
2 changed files with 24 additions and 15 deletions
  1. +5
    -0
      mindspore/lite/tools/converter/parser/onnx/onnx_reduce_parser.cc
  2. +19
    -15
      mindspore/lite/tools/converter/parser/onnx/onnx_relu_parser.cc

+ 5
- 0
mindspore/lite/tools/converter/parser/onnx/onnx_reduce_parser.cc View File

@@ -38,6 +38,7 @@ STATUS OnnxReduceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
return RET_NULL_PTR;
}

attr->keepDims = 1;
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "axes") {
@@ -58,6 +59,10 @@ STATUS OnnxReduceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
attr->mode = schema::ReduceMode_ReduceMin;
} else if (type == "ReduceSum") {
attr->mode = schema::ReduceMode_ReduceSum;
} else if (type == "ReduceProd") {
attr->mode = schema::ReduceMode_ReduceProd;
} else if (type == "ReduceSumSquare") {
attr->mode = schema::ReduceMode_ReduceSumSquare;
} else {
MS_LOG(ERROR) << "unsupported type";
return RET_ERROR;


+ 19
- 15
mindspore/lite/tools/converter/parser/onnx/onnx_relu_parser.cc View File

@@ -86,23 +86,27 @@ STATUS OnnxPReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No
}
}

const onnx::TensorProto *slope = &params[0];
if (slope == nullptr) {
MS_LOG(ERROR) << "input error: params[0] is null";
return RET_ERROR;
}
const auto slope_raw_data = reinterpret_cast<const float *>(slope->raw_data().data());
const int64_t slope_size = slope->raw_data().size() / sizeof(float);
if (slope_size == 1) {
attr->slope.push_back(*slope_raw_data);
attr->channelShared = true;
} else {
attr->slope.resize(slope_size);
attr->channelShared = false;
if (memcpy_s(attr->slope.data(), slope_size * sizeof(float), slope_raw_data, slope_size * sizeof(float)) != 0) {
MS_LOG(ERROR) << "memcpy_s failed";
if (!params.empty()) {
const onnx::TensorProto *slope = &params[0];
if (slope == nullptr) {
MS_LOG(ERROR) << "input error: params[0] is null";
return RET_ERROR;
}
const auto slope_raw_data = reinterpret_cast<const float *>(slope->raw_data().data());
const int64_t slope_size = slope->raw_data().size() / sizeof(float);
if (slope_size == 1) {
attr->slope.push_back(*slope_raw_data);
attr->channelShared = true;
} else {
attr->slope.resize(slope_size);
attr->channelShared = false;
if (memcpy_s(attr->slope.data(), slope_size * sizeof(float), slope_raw_data, slope_size * sizeof(float)) != 0) {
MS_LOG(ERROR) << "memcpy_s failed";
return RET_ERROR;
}
}
} else {
MS_LOG(WARNING) << "The slope pf prelu is null, which may cause errors.";
}

op->primitive->value.type = schema::PrimitiveType_PReLU;


Loading…
Cancel
Save