|
|
|
@@ -40,16 +40,14 @@ ops::PrimitiveC *OnnxExpandParser::Parse(const onnx::GraphProto &onnx_graph, con |
|
|
|
auto node_iter = |
|
|
|
std::find_if(onnx_graph.node().begin(), onnx_graph.node().end(), |
|
|
|
[onnx_expand_power](const onnx::NodeProto &proto) { return proto.output(0) == onnx_expand_power; }); |
|
|
|
if (node_iter == onnx_graph.node().end()) { |
|
|
|
MS_LOG(ERROR) << "can not find node: " << onnx_expand_power; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
for (const auto &attr_power : node_iter->attribute()) { |
|
|
|
if (attr_power.name() == "value") { |
|
|
|
const auto &t = attr_power.t(); |
|
|
|
auto *data_ptr = reinterpret_cast<const int64_t *>(t.raw_data().data()); |
|
|
|
for (int i = 0; i < t.dims(0); ++i) { |
|
|
|
dst_shape.emplace_back(data_ptr[i]); |
|
|
|
if (node_iter != onnx_graph.node().end()) { |
|
|
|
for (const auto &attr_power : node_iter->attribute()) { |
|
|
|
if (attr_power.name() == "value") { |
|
|
|
const auto &t = attr_power.t(); |
|
|
|
auto *data_ptr = reinterpret_cast<const int64_t *>(t.raw_data().data()); |
|
|
|
for (int i = 0; i < t.dims(0); ++i) { |
|
|
|
dst_shape.emplace_back(data_ptr[i]); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|