|
|
|
@@ -15,7 +15,9 @@ |
|
|
|
*/ |
|
|
|
|
|
|
|
#include "tools/converter/parser/onnx/onnx_arithmetic_operation_parser.h" |
|
|
|
#include "tools/converter/parser/onnx/onnx_tensor_parser.h" |
|
|
|
#include <memory> |
|
|
|
#include <numeric> |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace lite { |
|
|
|
@@ -130,21 +132,21 @@ STATUS OnnxPowParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Node |
|
|
|
} |
|
|
|
|
|
|
|
const auto &onnx_pow_power = onnx_node.input(1); |
|
|
|
auto nodeIter = |
|
|
|
std::find_if(onnx_graph.node().begin(), onnx_graph.node().end(), |
|
|
|
[onnx_pow_power](const onnx::NodeProto &proto) { return proto.output(0) == onnx_pow_power; }); |
|
|
|
if (nodeIter == onnx_graph.node().end()) { |
|
|
|
int index = OnnxTensorParser::GetInstance()->GetTensorCache()->FindTensor(onnx_pow_power); |
|
|
|
if (index == -1) { |
|
|
|
MS_LOG(ERROR) << "can not find node: " << onnx_pow_power; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
const float *pW = nullptr; |
|
|
|
for (const auto &attrPower : nodeIter->attribute()) { |
|
|
|
if (attrPower.name() == "value") { |
|
|
|
const auto &t = attrPower.t(); |
|
|
|
pW = reinterpret_cast<const float *>(t.raw_data().data()); |
|
|
|
} |
|
|
|
auto pow_attr = OnnxTensorParser::GetInstance()->GetTensorCache()->GetCachedTensor()[index]; |
|
|
|
if (std::accumulate(pow_attr->dims.begin(), pow_attr->dims.end(), 1, std::multiplies<int>()) != 1) { |
|
|
|
MS_LOG(ERROR) << "the exponent element num is bigger than 1, which don't support now."; |
|
|
|
return RET_NOT_SUPPORT; |
|
|
|
} |
|
|
|
attr->power = *pW; |
|
|
|
if (pow_attr->data.data() == nullptr) { |
|
|
|
MS_LOG(ERROR) << "power's attr pow can't be obtained."; |
|
|
|
return RET_INVALID_OP_ATTR; |
|
|
|
} |
|
|
|
attr->power = *reinterpret_cast<float *>(pow_attr->data.data()); |
|
|
|
attr->scale = 1.0f; |
|
|
|
attr->shift = 0.0f; |
|
|
|
op->primitive->value.type = schema::PrimitiveType_Power; |
|
|
|
|