From 4824727a6383db4057b5cb713ac3311cadd7bc86 Mon Sep 17 00:00:00 2001 From: yankai Date: Wed, 28 Oct 2020 14:33:14 +0800 Subject: [PATCH] fix broadcast parser of onnx --- mindspore/lite/src/ops/broadcast_to.cc | 17 ++++++++++--- .../parser/onnx/onnx_expand_parser.cc | 24 +++++++++++++++++-- 2 files changed, 36 insertions(+), 5 deletions(-) diff --git a/mindspore/lite/src/ops/broadcast_to.cc b/mindspore/lite/src/ops/broadcast_to.cc index 2ce4a5d260..49edd976f4 100644 --- a/mindspore/lite/src/ops/broadcast_to.cc +++ b/mindspore/lite/src/ops/broadcast_to.cc @@ -62,21 +62,32 @@ Registry BroadcastToRegistry(schema::PrimitiveType_BroadcastTo, BroadcastToCreat namespace { constexpr int kBroadcastToInputNum = 1; +constexpr int kBroadcastToOnnxInputNum = 2; constexpr int kBroadcastToOutputNum = 1; } // namespace int BroadcastTo::InferShape(std::vector inputs, std::vector outputs) { - if (inputs.size() != kBroadcastToInputNum || outputs.size() != kBroadcastToOutputNum) { - MS_LOG(ERROR) << "input size:" << inputs.size() << ", output size:" << outputs.size(); + if (inputs.size() != kBroadcastToInputNum && inputs.size() != kBroadcastToOnnxInputNum) { + MS_LOG(ERROR) << "input size:" << inputs.size(); return RET_PARAM_INVALID; } + if (outputs.size() != kBroadcastToOutputNum) { + MS_LOG(ERROR) << "output size:" << outputs.size(); + return RET_PARAM_INVALID; + } + auto input = inputs.at(0); outputs[0]->SetFormat(input->GetFormat()); outputs[0]->set_data_type(input->data_type()); if (!GetInferFlag()) { return RET_OK; } - std::vector dst_shape(GetDstShape().begin(), GetDstShape().end()); + std::vector dst_shape(GetDstShape()); + for (size_t i = 0; i < dst_shape.size(); ++i) { + if (dst_shape[i] == -1) { + dst_shape[i] = inputs[0]->shape()[i]; + } + } auto input_shape = input->shape(); std::vector shape(dst_shape.size()); int input_shape_index = input_shape.size() - 1; diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_expand_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_expand_parser.cc index 7994faf770..767120d571 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_expand_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_expand_parser.cc @@ -32,13 +32,33 @@ STATUS OnnxExpandParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N return RET_NULL_PTR; } - std::unique_ptr attr = std::make_unique(); + std::unique_ptr attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; return RET_NULL_PTR; } - op->primitive->value.type = schema::PrimitiveType_Broadcast; + std::vector dst_shape; + const auto &onnx_expand_power = onnx_node.input(1); + auto nodeIter = + std::find_if(onnx_graph.node().begin(), onnx_graph.node().end(), + [onnx_expand_power](const onnx::NodeProto &proto) { return proto.output(0) == onnx_expand_power; }); + if (nodeIter == onnx_graph.node().end()) { + MS_LOG(ERROR) << "can not find node: " << onnx_expand_power; + return RET_ERROR; + } + const int64_t *dataPtr = nullptr; + for (const auto &attrPower : nodeIter->attribute()) { + if (attrPower.name() == "value") { + const auto &t = attrPower.t(); + dataPtr = reinterpret_cast(t.raw_data().data()); + for (int i = 0; i < t.dims(0); ++i) { + dst_shape.emplace_back(dataPtr[i]); + } + } + } + attr->dst_shape = dst_shape; + op->primitive->value.type = schema::PrimitiveType_BroadcastTo; op->primitive->value.value = attr.release(); return RET_OK; }