diff --git a/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc index 6e84bc1610..116976e6c3 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc @@ -19,6 +19,7 @@ #include "schema/model_generated.h" #include "src/kernel_registry.h" #include "include/errorcode.h" +#include "src/ops/conv2d.h" using mindspore::lite::KernelRegistrar; using mindspore::lite::RET_ERROR; @@ -80,6 +81,11 @@ void ConvolutionBaseCPUKernel::FreeQuantParam() { } int ConvolutionBaseCPUKernel::Init() { + auto conv2d_lite_primitive = (lite::Conv2D *)primitive_; + conv_param_->pad_u_ = conv2d_lite_primitive->PadUp(); + conv_param_->pad_d_ = conv2d_lite_primitive->PadDown(); + conv_param_->pad_l_ = conv2d_lite_primitive->PadLeft(); + conv_param_->pad_r_ = conv2d_lite_primitive->PadRight(); auto input = this->in_tensors_.front(); auto output = this->out_tensors_.front(); conv_param_->input_batch_ = input->Batch(); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.cc index 12bdde3af9..abb8172fa6 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.cc @@ -129,6 +129,24 @@ STATUS OnnxPowParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Node return RET_NULL_PTR; } + const auto &onnx_pow_power = onnx_node.input(1); + auto nodeIter = + std::find_if(onnx_graph.node().begin(), onnx_graph.node().end(), + [onnx_pow_power](const onnx::NodeProto &proto) { return proto.output(0) == onnx_pow_power; }); + if (nodeIter == onnx_graph.node().end()) { + MS_LOG(ERROR) << "can not find node: " << onnx_pow_power; + return RET_ERROR; + } + const float *pW = nullptr; + for (const auto &attrPower : nodeIter->attribute()) { + if (attrPower.name() == "value") { + const auto &t = attrPower.t(); + pW = reinterpret_cast(t.raw_data().data()); + } + } + attr->power = *pW; + attr->scale = 1.0f; + attr->shift = 0.0f; op->primitive->value.type = schema::PrimitiveType_Power; op->primitive->value.value = attr.release(); return RET_OK; @@ -675,7 +693,7 @@ OnnxNodeRegistrar g_onnxInt8AddParser("Int8Add", new OnnxAddParser()); OnnxNodeRegistrar g_onnxSubParser("Sub", new OnnxSubParser()); OnnxNodeRegistrar g_onnxMulParser("Mul", new OnnxMulParser()); OnnxNodeRegistrar g_onnxDivParser("Div", new OnnxDivParser()); -OnnxNodeRegistrar g_onnxPowParser("Power", new OnnxPowParser()); +OnnxNodeRegistrar g_onnxPowParser("Pow", new OnnxPowParser()); OnnxNodeRegistrar g_onnxEqualParser("Equal", new OnnxEqualParser()); OnnxNodeRegistrar g_onnxLessParser("Less", new OnnxLessParser()); OnnxNodeRegistrar g_onnxGreaterParser("Greater", new OnnxGreaterParser()); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_onehot_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_onehot_parser.cc new file mode 100644 index 0000000000..b001199d98 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_onehot_parser.cc @@ -0,0 +1,55 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/parser/onnx/onnx_onehot_parser.h" +#include + +namespace mindspore { +namespace lite { +STATUS OnnxOneHotParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + MS_LOG(DEBUG) << "onnx OneHotParser"; + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; + } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + + for (const auto &onnx_node_attr : onnx_node.attribute()) { + const auto &attribute_name = onnx_node_attr.name(); + if (attribute_name == "axis") { + attr->axis = static_cast(onnx_node_attr.i()); + } + } + + op->primitive->value.type = schema::PrimitiveType_OneHot; + op->primitive->value.value = attr.release(); + return RET_OK; +} + +OnnxNodeRegistrar g_onnxOneHotParser("OneHot", new OnnxOneHotParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_onehot_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_onehot_parser.h new file mode 100644 index 0000000000..49f183d6a7 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_onehot_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ONEHOT_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ONEHOT_PARSER_H + +#include "tools/converter/parser/onnx/onnx_node_parser.h" +#include "tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxOneHotParser : public OnnxNodeParser { + public: + OnnxOneHotParser() : OnnxNodeParser("OneHot") {} + + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ONEHOT_PARSER_H