Merge pull request !7559 from yankai10/1021mergetags/v1.1.0
| @@ -19,6 +19,7 @@ | |||||
| #include "schema/model_generated.h" | #include "schema/model_generated.h" | ||||
| #include "src/kernel_registry.h" | #include "src/kernel_registry.h" | ||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| #include "src/ops/conv2d.h" | |||||
| using mindspore::lite::KernelRegistrar; | using mindspore::lite::KernelRegistrar; | ||||
| using mindspore::lite::RET_ERROR; | using mindspore::lite::RET_ERROR; | ||||
| @@ -80,6 +81,11 @@ void ConvolutionBaseCPUKernel::FreeQuantParam() { | |||||
| } | } | ||||
| int ConvolutionBaseCPUKernel::Init() { | int ConvolutionBaseCPUKernel::Init() { | ||||
| auto conv2d_lite_primitive = (lite::Conv2D *)primitive_; | |||||
| conv_param_->pad_u_ = conv2d_lite_primitive->PadUp(); | |||||
| conv_param_->pad_d_ = conv2d_lite_primitive->PadDown(); | |||||
| conv_param_->pad_l_ = conv2d_lite_primitive->PadLeft(); | |||||
| conv_param_->pad_r_ = conv2d_lite_primitive->PadRight(); | |||||
| auto input = this->in_tensors_.front(); | auto input = this->in_tensors_.front(); | ||||
| auto output = this->out_tensors_.front(); | auto output = this->out_tensors_.front(); | ||||
| conv_param_->input_batch_ = input->Batch(); | conv_param_->input_batch_ = input->Batch(); | ||||
| @@ -129,6 +129,24 @@ STATUS OnnxPowParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Node | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| const auto &onnx_pow_power = onnx_node.input(1); | |||||
| auto nodeIter = | |||||
| std::find_if(onnx_graph.node().begin(), onnx_graph.node().end(), | |||||
| [onnx_pow_power](const onnx::NodeProto &proto) { return proto.output(0) == onnx_pow_power; }); | |||||
| if (nodeIter == onnx_graph.node().end()) { | |||||
| MS_LOG(ERROR) << "can not find node: " << onnx_pow_power; | |||||
| return RET_ERROR; | |||||
| } | |||||
| const float *pW = nullptr; | |||||
| for (const auto &attrPower : nodeIter->attribute()) { | |||||
| if (attrPower.name() == "value") { | |||||
| const auto &t = attrPower.t(); | |||||
| pW = reinterpret_cast<const float *>(t.raw_data().data()); | |||||
| } | |||||
| } | |||||
| attr->power = *pW; | |||||
| attr->scale = 1.0f; | |||||
| attr->shift = 0.0f; | |||||
| op->primitive->value.type = schema::PrimitiveType_Power; | op->primitive->value.type = schema::PrimitiveType_Power; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| return RET_OK; | return RET_OK; | ||||
| @@ -675,7 +693,7 @@ OnnxNodeRegistrar g_onnxInt8AddParser("Int8Add", new OnnxAddParser()); | |||||
| OnnxNodeRegistrar g_onnxSubParser("Sub", new OnnxSubParser()); | OnnxNodeRegistrar g_onnxSubParser("Sub", new OnnxSubParser()); | ||||
| OnnxNodeRegistrar g_onnxMulParser("Mul", new OnnxMulParser()); | OnnxNodeRegistrar g_onnxMulParser("Mul", new OnnxMulParser()); | ||||
| OnnxNodeRegistrar g_onnxDivParser("Div", new OnnxDivParser()); | OnnxNodeRegistrar g_onnxDivParser("Div", new OnnxDivParser()); | ||||
| OnnxNodeRegistrar g_onnxPowParser("Power", new OnnxPowParser()); | |||||
| OnnxNodeRegistrar g_onnxPowParser("Pow", new OnnxPowParser()); | |||||
| OnnxNodeRegistrar g_onnxEqualParser("Equal", new OnnxEqualParser()); | OnnxNodeRegistrar g_onnxEqualParser("Equal", new OnnxEqualParser()); | ||||
| OnnxNodeRegistrar g_onnxLessParser("Less", new OnnxLessParser()); | OnnxNodeRegistrar g_onnxLessParser("Less", new OnnxLessParser()); | ||||
| OnnxNodeRegistrar g_onnxGreaterParser("Greater", new OnnxGreaterParser()); | OnnxNodeRegistrar g_onnxGreaterParser("Greater", new OnnxGreaterParser()); | ||||
| @@ -0,0 +1,55 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tools/converter/parser/onnx/onnx_onehot_parser.h" | |||||
| #include <memory> | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS OnnxOneHotParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||||
| schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "onnx OneHotParser"; | |||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| std::unique_ptr<schema::OneHotT> attr = std::make_unique<schema::OneHotT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||||
| const auto &attribute_name = onnx_node_attr.name(); | |||||
| if (attribute_name == "axis") { | |||||
| attr->axis = static_cast<int32_t>(onnx_node_attr.i()); | |||||
| } | |||||
| } | |||||
| op->primitive->value.type = schema::PrimitiveType_OneHot; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| } | |||||
| OnnxNodeRegistrar g_onnxOneHotParser("OneHot", new OnnxOneHotParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,33 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ONEHOT_PARSER_H | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ONEHOT_PARSER_H | |||||
| #include "tools/converter/parser/onnx/onnx_node_parser.h" | |||||
| #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class OnnxOneHotParser : public OnnxNodeParser { | |||||
| public: | |||||
| OnnxOneHotParser() : OnnxNodeParser("OneHot") {} | |||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ONEHOT_PARSER_H | |||||