|
|
|
@@ -131,33 +131,33 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod |
|
|
|
|
|
|
|
const auto &onnx_conv_weight = onnx_node.input(1); |
|
|
|
if (onnx_node.op_type() == "Conv") { |
|
|
|
auto nodeIter = |
|
|
|
auto node_iter = |
|
|
|
std::find_if(onnx_graph.initializer().begin(), onnx_graph.initializer().end(), |
|
|
|
[onnx_conv_weight](const onnx::TensorProto &proto) { return proto.name() == onnx_conv_weight; }); |
|
|
|
if (nodeIter == onnx_graph.initializer().end()) { |
|
|
|
if (node_iter == onnx_graph.initializer().end()) { |
|
|
|
MS_LOG(WARNING) << "not find node: " << onnx_conv_weight; |
|
|
|
} else { |
|
|
|
std::vector<int> weight_shape; |
|
|
|
auto size = (*nodeIter).dims_size(); |
|
|
|
auto size = (*node_iter).dims_size(); |
|
|
|
weight_shape.reserve(size); |
|
|
|
for (int i = 0; i < size; ++i) { |
|
|
|
weight_shape.emplace_back((*nodeIter).dims(i)); |
|
|
|
weight_shape.emplace_back((*node_iter).dims(i)); |
|
|
|
} |
|
|
|
attr->channelOut = weight_shape[0]; |
|
|
|
attr->channelIn = weight_shape[1] * attr->group; |
|
|
|
} |
|
|
|
} else { |
|
|
|
auto nodeIter = |
|
|
|
auto node_iter = |
|
|
|
std::find_if(onnx_graph.node().begin(), onnx_graph.node().end(), |
|
|
|
[onnx_conv_weight](const onnx::NodeProto &proto) { return proto.output(0) == onnx_conv_weight; }); |
|
|
|
if (nodeIter == onnx_graph.node().end()) { |
|
|
|
if (node_iter == onnx_graph.node().end()) { |
|
|
|
MS_LOG(ERROR) << "can not find node: " << onnx_conv_weight; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
std::vector<int> dims; |
|
|
|
auto iter = std::find_if((*nodeIter).attribute().begin(), (*nodeIter).attribute().end(), |
|
|
|
auto iter = std::find_if((*node_iter).attribute().begin(), (*node_iter).attribute().end(), |
|
|
|
[](const onnx::AttributeProto &attr) { return attr.name() == "shape"; }); |
|
|
|
if (iter != (*nodeIter).attribute().end()) { |
|
|
|
if (iter != (*node_iter).attribute().end()) { |
|
|
|
if (iter->ints().begin() == nullptr || iter->ints().end() == nullptr) { |
|
|
|
MS_LOG(ERROR) << "dims insert failed"; |
|
|
|
return RET_ERROR; |
|
|
|
|