Browse Source

!7559 add onnx parser of pow

Merge pull request !7559 from yankai10/1021merge
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
de93d9bff1
4 changed files with 113 additions and 1 deletions
  1. +6
    -0
      mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc
  2. +19
    -1
      mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.cc
  3. +55
    -0
      mindspore/lite/tools/converter/parser/onnx/onnx_onehot_parser.cc
  4. +33
    -0
      mindspore/lite/tools/converter/parser/onnx/onnx_onehot_parser.h

+ 6
- 0
mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc View File

@@ -19,6 +19,7 @@
#include "schema/model_generated.h" #include "schema/model_generated.h"
#include "src/kernel_registry.h" #include "src/kernel_registry.h"
#include "include/errorcode.h" #include "include/errorcode.h"
#include "src/ops/conv2d.h"


using mindspore::lite::KernelRegistrar; using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR; using mindspore::lite::RET_ERROR;
@@ -80,6 +81,11 @@ void ConvolutionBaseCPUKernel::FreeQuantParam() {
} }


int ConvolutionBaseCPUKernel::Init() { int ConvolutionBaseCPUKernel::Init() {
auto conv2d_lite_primitive = (lite::Conv2D *)primitive_;
conv_param_->pad_u_ = conv2d_lite_primitive->PadUp();
conv_param_->pad_d_ = conv2d_lite_primitive->PadDown();
conv_param_->pad_l_ = conv2d_lite_primitive->PadLeft();
conv_param_->pad_r_ = conv2d_lite_primitive->PadRight();
auto input = this->in_tensors_.front(); auto input = this->in_tensors_.front();
auto output = this->out_tensors_.front(); auto output = this->out_tensors_.front();
conv_param_->input_batch_ = input->Batch(); conv_param_->input_batch_ = input->Batch();


+ 19
- 1
mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.cc View File

@@ -129,6 +129,24 @@ STATUS OnnxPowParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Node
return RET_NULL_PTR; return RET_NULL_PTR;
} }


const auto &onnx_pow_power = onnx_node.input(1);
auto nodeIter =
std::find_if(onnx_graph.node().begin(), onnx_graph.node().end(),
[onnx_pow_power](const onnx::NodeProto &proto) { return proto.output(0) == onnx_pow_power; });
if (nodeIter == onnx_graph.node().end()) {
MS_LOG(ERROR) << "can not find node: " << onnx_pow_power;
return RET_ERROR;
}
const float *pW = nullptr;
for (const auto &attrPower : nodeIter->attribute()) {
if (attrPower.name() == "value") {
const auto &t = attrPower.t();
pW = reinterpret_cast<const float *>(t.raw_data().data());
}
}
attr->power = *pW;
attr->scale = 1.0f;
attr->shift = 0.0f;
op->primitive->value.type = schema::PrimitiveType_Power; op->primitive->value.type = schema::PrimitiveType_Power;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK; return RET_OK;
@@ -675,7 +693,7 @@ OnnxNodeRegistrar g_onnxInt8AddParser("Int8Add", new OnnxAddParser());
OnnxNodeRegistrar g_onnxSubParser("Sub", new OnnxSubParser()); OnnxNodeRegistrar g_onnxSubParser("Sub", new OnnxSubParser());
OnnxNodeRegistrar g_onnxMulParser("Mul", new OnnxMulParser()); OnnxNodeRegistrar g_onnxMulParser("Mul", new OnnxMulParser());
OnnxNodeRegistrar g_onnxDivParser("Div", new OnnxDivParser()); OnnxNodeRegistrar g_onnxDivParser("Div", new OnnxDivParser());
OnnxNodeRegistrar g_onnxPowParser("Power", new OnnxPowParser());
OnnxNodeRegistrar g_onnxPowParser("Pow", new OnnxPowParser());
OnnxNodeRegistrar g_onnxEqualParser("Equal", new OnnxEqualParser()); OnnxNodeRegistrar g_onnxEqualParser("Equal", new OnnxEqualParser());
OnnxNodeRegistrar g_onnxLessParser("Less", new OnnxLessParser()); OnnxNodeRegistrar g_onnxLessParser("Less", new OnnxLessParser());
OnnxNodeRegistrar g_onnxGreaterParser("Greater", new OnnxGreaterParser()); OnnxNodeRegistrar g_onnxGreaterParser("Greater", new OnnxGreaterParser());


+ 55
- 0
mindspore/lite/tools/converter/parser/onnx/onnx_onehot_parser.cc View File

@@ -0,0 +1,55 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "tools/converter/parser/onnx/onnx_onehot_parser.h"
#include <memory>

namespace mindspore {
namespace lite {
STATUS OnnxOneHotParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx OneHotParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}

std::unique_ptr<schema::OneHotT> attr = std::make_unique<schema::OneHotT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}

for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "axis") {
attr->axis = static_cast<int32_t>(onnx_node_attr.i());
}
}

op->primitive->value.type = schema::PrimitiveType_OneHot;
op->primitive->value.value = attr.release();
return RET_OK;
}

OnnxNodeRegistrar g_onnxOneHotParser("OneHot", new OnnxOneHotParser());
} // namespace lite
} // namespace mindspore

+ 33
- 0
mindspore/lite/tools/converter/parser/onnx/onnx_onehot_parser.h View File

@@ -0,0 +1,33 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ONEHOT_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ONEHOT_PARSER_H

#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"

namespace mindspore {
namespace lite {
class OnnxOneHotParser : public OnnxNodeParser {
public:
OnnxOneHotParser() : OnnxNodeParser("OneHot") {}

STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ONEHOT_PARSER_H

Loading…
Cancel
Save