| @@ -0,0 +1,36 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifdef PRIMITIVE_WRITEABLE | |||||
| #ifndef LITE_MINDSPORE_LITE_C_OPS_CONSTANT_H_ | |||||
| #define LITE_MINDSPORE_LITE_C_OPS_CONSTANT_H_ | |||||
| #include "src/ops/primitive_c.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class Constant : public PrimitiveC { | |||||
| public: | |||||
| Constant() = default; | |||||
| ~Constant() = default; | |||||
| MS_DECLARE_PARENT(Constant, PrimitiveC); | |||||
| explicit Constant(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // LITE_MINDSPORE_LITE_C_OPS_CONSTANT_H_ | |||||
| #endif | |||||
| @@ -149,6 +149,7 @@ | |||||
| #include "src/ops/oneslike.h" | #include "src/ops/oneslike.h" | ||||
| #include "src/ops/unsorted_segment_sum.h" | #include "src/ops/unsorted_segment_sum.h" | ||||
| #include "src/ops/reciprocal.h" | #include "src/ops/reciprocal.h" | ||||
| #include "src/ops/constant.h" | |||||
| #ifdef SUPPORT_TRAIN | #ifdef SUPPORT_TRAIN | ||||
| #include "src/ops/neg_grad.h" | #include "src/ops/neg_grad.h" | ||||
| @@ -186,7 +187,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| std::vector<int> CastToInt(const ValuePtr value) { | |||||
| std::vector<int> CastToInt(const ValuePtr &value) { | |||||
| if (value == nullptr) { | if (value == nullptr) { | ||||
| MS_LOG(WARNING) << "valueptr is nullptr."; | MS_LOG(WARNING) << "valueptr is nullptr."; | ||||
| return {}; | return {}; | ||||
| @@ -903,6 +904,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { | |||||
| return new (std::nothrow) Dequant(primitive); | return new (std::nothrow) Dequant(primitive); | ||||
| case schema::PrimitiveType_Reciprocal: | case schema::PrimitiveType_Reciprocal: | ||||
| return new (std::nothrow) Reciprocal(primitive); | return new (std::nothrow) Reciprocal(primitive); | ||||
| case schema::PrimitiveType_Constant: | |||||
| return new (std::nothrow) Constant(primitive); | |||||
| #ifdef SUPPORT_TRAIN | #ifdef SUPPORT_TRAIN | ||||
| case schema::PrimitiveType_ActivationGrad: | case schema::PrimitiveType_ActivationGrad: | ||||
| @@ -57,7 +57,7 @@ static std::map<std::string, schema::ActivationType> kActivationTypeMap{ | |||||
| {"LeakyRelu", schema::ActivationType_LEAKY_RELU}, | {"LeakyRelu", schema::ActivationType_LEAKY_RELU}, | ||||
| {"Tanh", schema::ActivationType_TANH}, | {"Tanh", schema::ActivationType_TANH}, | ||||
| {"Logistic", schema::ActivationType_SIGMOID}}; | {"Logistic", schema::ActivationType_SIGMOID}}; | ||||
| std::vector<int> CastToInt(const ValuePtr value); | |||||
| std::vector<int> CastToInt(const ValuePtr &value); | |||||
| class PrimitiveC : public mindspore::Primitive { | class PrimitiveC : public mindspore::Primitive { | ||||
| public: | public: | ||||
| // Argument primitive is deliverd into PrimitiveC and will be deleted in ~PrimitiveC(). | // Argument primitive is deliverd into PrimitiveC and will be deleted in ~PrimitiveC(). | ||||
| @@ -204,6 +204,7 @@ if(ENABLE_CONVERTER) | |||||
| ${LITE_DIR}/tools/optimizer/graph/infershape_pass.cc | ${LITE_DIR}/tools/optimizer/graph/infershape_pass.cc | ||||
| ${LITE_DIR}/tools/optimizer/graph/slice_prepose_pass.cc | ${LITE_DIR}/tools/optimizer/graph/slice_prepose_pass.cc | ||||
| ${LITE_DIR}/tools/optimizer/graph/mindir_adjust_pass.cc | ${LITE_DIR}/tools/optimizer/graph/mindir_adjust_pass.cc | ||||
| ${LITE_DIR}/tools/optimizer/graph/onnx_inputs_adjust_pass.cc | |||||
| ) | ) | ||||
| endif() | endif() | ||||
| ### train | ### train | ||||
| @@ -58,6 +58,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||||
| ../optimizer/graph/infershape_pass.cc | ../optimizer/graph/infershape_pass.cc | ||||
| ../optimizer/graph/slice_prepose_pass.cc | ../optimizer/graph/slice_prepose_pass.cc | ||||
| ../optimizer/graph/mindir_adjust_pass.cc | ../optimizer/graph/mindir_adjust_pass.cc | ||||
| ../optimizer/graph/onnx_inputs_adjust_pass.cc | |||||
| ) | ) | ||||
| add_subdirectory(../anf_importer anf_importer) | add_subdirectory(../anf_importer anf_importer) | ||||
| @@ -36,6 +36,7 @@ | |||||
| #include "tools/optimizer/graph/clip_convert_activation_pass.h" | #include "tools/optimizer/graph/clip_convert_activation_pass.h" | ||||
| #include "tools/optimizer/graph/group_depthwise_op_convert_pass.h" | #include "tools/optimizer/graph/group_depthwise_op_convert_pass.h" | ||||
| #include "tools/optimizer/graph/tflite_inputs_order_exchange_pass.h" | #include "tools/optimizer/graph/tflite_inputs_order_exchange_pass.h" | ||||
| #include "tools/optimizer/graph/onnx_inputs_adjust_pass.h" | |||||
| #include "tools/optimizer/graph/update_conv2d_param_pass.h" | #include "tools/optimizer/graph/update_conv2d_param_pass.h" | ||||
| #include "tools/optimizer/graph/unused_cast_node_remove_pass.h" | #include "tools/optimizer/graph/unused_cast_node_remove_pass.h" | ||||
| #include "tools/optimizer/graph/unused_transpose_node_remove_pass.h" | #include "tools/optimizer/graph/unused_transpose_node_remove_pass.h" | ||||
| @@ -74,6 +75,16 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver | |||||
| } | } | ||||
| } | } | ||||
| // onnx pre adjustment | |||||
| if (config->fmk == converter::FmkType_ONNX) { | |||||
| auto onnx_adjust_pass = std::make_shared<opt::OnnxInputAdjustOpPass>(); | |||||
| if (!onnx_adjust_pass->Run(old_graph)) { | |||||
| MS_LOG(ERROR) << "onnx adjust failed."; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); | |||||
| return nullptr; | |||||
| } | |||||
| } | |||||
| // for now - trainning is not supporting fuse operations | // for now - trainning is not supporting fuse operations | ||||
| if (!config->trainModel) { | if (!config->trainModel) { | ||||
| // remove quantdtype when awaretraining | // remove quantdtype when awaretraining | ||||
| @@ -90,6 +90,7 @@ STATUS CaffeModelParser::ConvertLayers() { | |||||
| auto primitive_c = node_parser->ParseLitePrimitive(layer, weight); | auto primitive_c = node_parser->ParseLitePrimitive(layer, weight); | ||||
| if (primitive_c == nullptr) { | if (primitive_c == nullptr) { | ||||
| MS_LOG(ERROR) << "parse node " << layer.name() << " failed."; | MS_LOG(ERROR) << "parse node " << layer.name() << " failed."; | ||||
| status = RET_ERROR; | |||||
| continue; | continue; | ||||
| } | } | ||||
| @@ -98,8 +99,7 @@ STATUS CaffeModelParser::ConvertLayers() { | |||||
| status = ConvertBottom(layer, &input_nodes); | status = ConvertBottom(layer, &input_nodes); | ||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "Convert layer bottom for " << layer.name() << " failed."; | MS_LOG(ERROR) << "Convert layer bottom for " << layer.name() << " failed."; | ||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||||
| return status; | |||||
| continue; | |||||
| } | } | ||||
| // build weights | // build weights | ||||
| @@ -107,8 +107,7 @@ STATUS CaffeModelParser::ConvertLayers() { | |||||
| status = ConvertBlobs(weight, &const_parameters); | status = ConvertBlobs(weight, &const_parameters); | ||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "Convert blobs for " << layer.name() << " failed."; | MS_LOG(ERROR) << "Convert blobs for " << layer.name() << " failed."; | ||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||||
| return status; | |||||
| continue; | |||||
| } | } | ||||
| // build cnode | // build cnode | ||||
| @@ -122,15 +121,13 @@ STATUS CaffeModelParser::ConvertLayers() { | |||||
| status = ConvertTop(layer, new_cnode); | status = ConvertTop(layer, new_cnode); | ||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "Convert outputs for " << layer.name() << " failed."; | MS_LOG(ERROR) << "Convert outputs for " << layer.name() << " failed."; | ||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||||
| return status; | |||||
| continue; | |||||
| } | } | ||||
| status = ConvertLayerQuantParams(layer, weight, primitive_c); | status = ConvertLayerQuantParams(layer, weight, primitive_c); | ||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "Convert quant params for " << layer.name() << " failed."; | MS_LOG(ERROR) << "Convert quant params for " << layer.name() << " failed."; | ||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||||
| return status; | |||||
| continue; | |||||
| } | } | ||||
| } | } | ||||
| return status; | return status; | ||||
| @@ -19,27 +19,22 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS OnnxAdderParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||||
| schema::CNodeT *op) { | |||||
| lite::PrimitiveC *OnnxAdderParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||||
| const onnx::NodeProto &onnx_node) { | |||||
| MS_LOG(DEBUG) << "onnx AdderParser"; | MS_LOG(DEBUG) << "onnx AdderParser"; | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto attr = std::make_unique<schema::AdderT>(); | auto attr = std::make_unique<schema::AdderT>(); | ||||
| if (attr == nullptr) { | if (attr == nullptr) { | ||||
| MS_LOG(ERROR) << "new op failed"; | MS_LOG(ERROR) << "new op failed"; | ||||
| return RET_NULL_PTR; | |||||
| return nullptr; | |||||
| } | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitive failed"; | |||||
| return nullptr; | |||||
| } | } | ||||
| op->primitive->value.type = schema::PrimitiveType_Adder; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| primitive->value.type = schema::PrimitiveType_Adder; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | } | ||||
| OnnxNodeRegistrar g_onnxAdderParser("adder_f", new OnnxAdderParser()); | OnnxNodeRegistrar g_onnxAdderParser("adder_f", new OnnxAdderParser()); | ||||
| @@ -26,8 +26,7 @@ class OnnxAdderParser : public OnnxNodeParser { | |||||
| public: | public: | ||||
| OnnxAdderParser() : OnnxNodeParser("Adder") {} | OnnxAdderParser() : OnnxNodeParser("Adder") {} | ||||
| ~OnnxAdderParser() override = default; | ~OnnxAdderParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,23 +19,14 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS OnnxArgMaxParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||||
| schema::CNodeT *op) { | |||||
| lite::PrimitiveC *OnnxArgMaxParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||||
| const onnx::NodeProto &onnx_node) { | |||||
| MS_LOG(DEBUG) << "onnx ArgMaxParser"; | MS_LOG(DEBUG) << "onnx ArgMaxParser"; | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| std::unique_ptr<schema::ArgMaxT> attr = std::make_unique<schema::ArgMaxT>(); | |||||
| auto attr = std::make_unique<schema::ArgMaxT>(); | |||||
| if (attr == nullptr) { | if (attr == nullptr) { | ||||
| MS_LOG(ERROR) << "new op failed"; | MS_LOG(ERROR) << "new op failed"; | ||||
| return RET_NULL_PTR; | |||||
| return nullptr; | |||||
| } | } | ||||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | for (const auto &onnx_node_attr : onnx_node.attribute()) { | ||||
| @@ -46,10 +37,14 @@ STATUS OnnxArgMaxParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N | |||||
| attr->keepDims = static_cast<bool>(onnx_node_attr.i()); | attr->keepDims = static_cast<bool>(onnx_node_attr.i()); | ||||
| } | } | ||||
| } | } | ||||
| op->primitive->value.type = schema::PrimitiveType_ArgMax; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitive failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_ArgMax; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | } | ||||
| OnnxNodeRegistrar g_onnxArgMaxParser("ArgMax", new OnnxArgMaxParser()); | OnnxNodeRegistrar g_onnxArgMaxParser("ArgMax", new OnnxArgMaxParser()); | ||||
| @@ -27,7 +27,7 @@ class OnnxArgMaxParser : public OnnxNodeParser { | |||||
| OnnxArgMaxParser() : OnnxNodeParser("ArgMax") {} | OnnxArgMaxParser() : OnnxNodeParser("ArgMax") {} | ||||
| ~OnnxArgMaxParser() override = default; | ~OnnxArgMaxParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -26,203 +26,203 @@ class OnnxAddParser : public OnnxNodeParser { | |||||
| public: | public: | ||||
| OnnxAddParser() : OnnxNodeParser("Add") {} | OnnxAddParser() : OnnxNodeParser("Add") {} | ||||
| ~OnnxAddParser() override = default; | ~OnnxAddParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| class OnnxSubParser : public OnnxNodeParser { | class OnnxSubParser : public OnnxNodeParser { | ||||
| public: | public: | ||||
| OnnxSubParser() : OnnxNodeParser("Sub") {} | OnnxSubParser() : OnnxNodeParser("Sub") {} | ||||
| ~OnnxSubParser() override = default; | ~OnnxSubParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| class OnnxMulParser : public OnnxNodeParser { | class OnnxMulParser : public OnnxNodeParser { | ||||
| public: | public: | ||||
| OnnxMulParser() : OnnxNodeParser("Mul") {} | OnnxMulParser() : OnnxNodeParser("Mul") {} | ||||
| ~OnnxMulParser() override = default; | ~OnnxMulParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| class OnnxDivParser : public OnnxNodeParser { | class OnnxDivParser : public OnnxNodeParser { | ||||
| public: | public: | ||||
| OnnxDivParser() : OnnxNodeParser("Div") {} | OnnxDivParser() : OnnxNodeParser("Div") {} | ||||
| ~OnnxDivParser() override = default; | ~OnnxDivParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| class OnnxPowParser : public OnnxNodeParser { | class OnnxPowParser : public OnnxNodeParser { | ||||
| public: | public: | ||||
| OnnxPowParser() : OnnxNodeParser("Power") {} | OnnxPowParser() : OnnxNodeParser("Power") {} | ||||
| ~OnnxPowParser() override = default; | ~OnnxPowParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| class OnnxEqualParser : public OnnxNodeParser { | class OnnxEqualParser : public OnnxNodeParser { | ||||
| public: | public: | ||||
| OnnxEqualParser() : OnnxNodeParser("Equal") {} | OnnxEqualParser() : OnnxNodeParser("Equal") {} | ||||
| ~OnnxEqualParser() override = default; | ~OnnxEqualParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| class OnnxLessParser : public OnnxNodeParser { | class OnnxLessParser : public OnnxNodeParser { | ||||
| public: | public: | ||||
| OnnxLessParser() : OnnxNodeParser("Less") {} | OnnxLessParser() : OnnxNodeParser("Less") {} | ||||
| ~OnnxLessParser() override = default; | ~OnnxLessParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| class OnnxGreaterParser : public OnnxNodeParser { | class OnnxGreaterParser : public OnnxNodeParser { | ||||
| public: | public: | ||||
| OnnxGreaterParser() : OnnxNodeParser("Greater") {} | OnnxGreaterParser() : OnnxNodeParser("Greater") {} | ||||
| ~OnnxGreaterParser() override = default; | ~OnnxGreaterParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| class OnnxMinParser : public OnnxNodeParser { | class OnnxMinParser : public OnnxNodeParser { | ||||
| public: | public: | ||||
| OnnxMinParser() : OnnxNodeParser("Min") {} | OnnxMinParser() : OnnxNodeParser("Min") {} | ||||
| ~OnnxMinParser() override = default; | ~OnnxMinParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| class OnnxEltwiseParser : public OnnxNodeParser { | class OnnxEltwiseParser : public OnnxNodeParser { | ||||
| public: | public: | ||||
| OnnxEltwiseParser() : OnnxNodeParser("Eltwise") {} | OnnxEltwiseParser() : OnnxNodeParser("Eltwise") {} | ||||
| ~OnnxEltwiseParser() override = default; | ~OnnxEltwiseParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| class OnnxFloorParser : public OnnxNodeParser { | class OnnxFloorParser : public OnnxNodeParser { | ||||
| public: | public: | ||||
| OnnxFloorParser() : OnnxNodeParser("Floor") {} | OnnxFloorParser() : OnnxNodeParser("Floor") {} | ||||
| ~OnnxFloorParser() override = default; | ~OnnxFloorParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| class OnnxAbsParser : public OnnxNodeParser { | class OnnxAbsParser : public OnnxNodeParser { | ||||
| public: | public: | ||||
| OnnxAbsParser() : OnnxNodeParser("Abs") {} | OnnxAbsParser() : OnnxNodeParser("Abs") {} | ||||
| ~OnnxAbsParser() override = default; | ~OnnxAbsParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| class OnnxNegParser : public OnnxNodeParser { | class OnnxNegParser : public OnnxNodeParser { | ||||
| public: | public: | ||||
| OnnxNegParser() : OnnxNodeParser("Neg") {} | OnnxNegParser() : OnnxNodeParser("Neg") {} | ||||
| ~OnnxNegParser() override = default; | ~OnnxNegParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| class OnnxExpParser : public OnnxNodeParser { | class OnnxExpParser : public OnnxNodeParser { | ||||
| public: | public: | ||||
| OnnxExpParser() : OnnxNodeParser("Exp") {} | OnnxExpParser() : OnnxNodeParser("Exp") {} | ||||
| ~OnnxExpParser() override = default; | ~OnnxExpParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| class OnnxCosParser : public OnnxNodeParser { | class OnnxCosParser : public OnnxNodeParser { | ||||
| public: | public: | ||||
| OnnxCosParser() : OnnxNodeParser("Cos") {} | OnnxCosParser() : OnnxNodeParser("Cos") {} | ||||
| ~OnnxCosParser() override = default; | ~OnnxCosParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| class OnnxSinParser : public OnnxNodeParser { | class OnnxSinParser : public OnnxNodeParser { | ||||
| public: | public: | ||||
| OnnxSinParser() : OnnxNodeParser("Sin") {} | OnnxSinParser() : OnnxNodeParser("Sin") {} | ||||
| ~OnnxSinParser() override = default; | ~OnnxSinParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| class OnnxSqrtParser : public OnnxNodeParser { | class OnnxSqrtParser : public OnnxNodeParser { | ||||
| public: | public: | ||||
| OnnxSqrtParser() : OnnxNodeParser("Sqrt") {} | OnnxSqrtParser() : OnnxNodeParser("Sqrt") {} | ||||
| ~OnnxSqrtParser() override = default; | ~OnnxSqrtParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| class OnnxCeilParser : public OnnxNodeParser { | class OnnxCeilParser : public OnnxNodeParser { | ||||
| public: | public: | ||||
| OnnxCeilParser() : OnnxNodeParser("Ceil") {} | OnnxCeilParser() : OnnxNodeParser("Ceil") {} | ||||
| ~OnnxCeilParser() override = default; | ~OnnxCeilParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| class OnnxLogParser : public OnnxNodeParser { | class OnnxLogParser : public OnnxNodeParser { | ||||
| public: | public: | ||||
| OnnxLogParser() : OnnxNodeParser("Log") {} | OnnxLogParser() : OnnxNodeParser("Log") {} | ||||
| ~OnnxLogParser() override = default; | ~OnnxLogParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| class OnnxTanParser : public OnnxNodeParser { | class OnnxTanParser : public OnnxNodeParser { | ||||
| public: | public: | ||||
| OnnxTanParser() : OnnxNodeParser("Tan") {} | OnnxTanParser() : OnnxNodeParser("Tan") {} | ||||
| ~OnnxTanParser() override = default; | ~OnnxTanParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| class OnnxAtanParser : public OnnxNodeParser { | class OnnxAtanParser : public OnnxNodeParser { | ||||
| public: | public: | ||||
| OnnxAtanParser() : OnnxNodeParser("Atan") {} | OnnxAtanParser() : OnnxNodeParser("Atan") {} | ||||
| ~OnnxAtanParser() override = default; | ~OnnxAtanParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| class OnnxAsinParser : public OnnxNodeParser { | class OnnxAsinParser : public OnnxNodeParser { | ||||
| public: | public: | ||||
| OnnxAsinParser() : OnnxNodeParser("Asin") {} | OnnxAsinParser() : OnnxNodeParser("Asin") {} | ||||
| ~OnnxAsinParser() override = default; | ~OnnxAsinParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| class OnnxTanhParser : public OnnxNodeParser { | class OnnxTanhParser : public OnnxNodeParser { | ||||
| public: | public: | ||||
| OnnxTanhParser() : OnnxNodeParser("Tanh") {} | OnnxTanhParser() : OnnxNodeParser("Tanh") {} | ||||
| ~OnnxTanhParser() override = default; | ~OnnxTanhParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| class OnnxSignParser : public OnnxNodeParser { | class OnnxSignParser : public OnnxNodeParser { | ||||
| public: | public: | ||||
| OnnxSignParser() : OnnxNodeParser("Sign") {} | OnnxSignParser() : OnnxNodeParser("Sign") {} | ||||
| ~OnnxSignParser() override = default; | ~OnnxSignParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| class OnnxAndParser : public OnnxNodeParser { | class OnnxAndParser : public OnnxNodeParser { | ||||
| public: | public: | ||||
| OnnxAndParser() : OnnxNodeParser("And") {} | OnnxAndParser() : OnnxNodeParser("And") {} | ||||
| ~OnnxAndParser() override = default; | ~OnnxAndParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| class OnnxOrParser : public OnnxNodeParser { | class OnnxOrParser : public OnnxNodeParser { | ||||
| public: | public: | ||||
| OnnxOrParser() : OnnxNodeParser("Or") {} | OnnxOrParser() : OnnxNodeParser("Or") {} | ||||
| ~OnnxOrParser() override = default; | ~OnnxOrParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| class OnnxNotParser : public OnnxNodeParser { | class OnnxNotParser : public OnnxNodeParser { | ||||
| public: | public: | ||||
| OnnxNotParser() : OnnxNodeParser("Not") {} | OnnxNotParser() : OnnxNodeParser("Not") {} | ||||
| ~OnnxNotParser() override = default; | ~OnnxNotParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| class OnnxRoundParser : public OnnxNodeParser { | class OnnxRoundParser : public OnnxNodeParser { | ||||
| public: | public: | ||||
| OnnxRoundParser() : OnnxNodeParser("Round") {} | OnnxRoundParser() : OnnxNodeParser("Round") {} | ||||
| ~OnnxRoundParser() override = default; | ~OnnxRoundParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| class OnnxReciprocalParser : public OnnxNodeParser { | class OnnxReciprocalParser : public OnnxNodeParser { | ||||
| public: | public: | ||||
| OnnxReciprocalParser() : OnnxNodeParser("Reciprocal") {} | OnnxReciprocalParser() : OnnxNodeParser("Reciprocal") {} | ||||
| ~OnnxReciprocalParser() override = default; | ~OnnxReciprocalParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,23 +19,13 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS OnnxBatchNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||||
| schema::CNodeT *op) { | |||||
| lite::PrimitiveC *OnnxBatchNormParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||||
| const onnx::NodeProto &onnx_node) { | |||||
| MS_LOG(DEBUG) << "onnx BatchNormParser"; | MS_LOG(DEBUG) << "onnx BatchNormParser"; | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| std::unique_ptr<schema::FusedBatchNormT> attr = std::make_unique<schema::FusedBatchNormT>(); | |||||
| auto attr = std::make_unique<schema::FusedBatchNormT>(); | |||||
| if (attr == nullptr) { | if (attr == nullptr) { | ||||
| MS_LOG(ERROR) << "new op failed"; | MS_LOG(ERROR) << "new op failed"; | ||||
| return RET_NULL_PTR; | |||||
| return nullptr; | |||||
| } | } | ||||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | for (const auto &onnx_node_attr : onnx_node.attribute()) { | ||||
| @@ -47,10 +37,14 @@ STATUS OnnxBatchNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx | |||||
| attr->spatial = static_cast<int32_t>(onnx_node_attr.i()); | attr->spatial = static_cast<int32_t>(onnx_node_attr.i()); | ||||
| } | } | ||||
| } | } | ||||
| op->primitive->value.type = schema::PrimitiveType_FusedBatchNorm; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitive failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_FusedBatchNorm; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | } | ||||
| OnnxNodeRegistrar g_onnxBatchNormParser("BatchNormalization", new OnnxBatchNormParser()); | OnnxNodeRegistrar g_onnxBatchNormParser("BatchNormalization", new OnnxBatchNormParser()); | ||||
| @@ -27,7 +27,7 @@ class OnnxBatchNormParser : public OnnxNodeParser { | |||||
| OnnxBatchNormParser() : OnnxNodeParser("BatchNormalization") {} | OnnxBatchNormParser() : OnnxNodeParser("BatchNormalization") {} | ||||
| ~OnnxBatchNormParser() override = default; | ~OnnxBatchNormParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,30 +19,25 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS OnnxBiasAddParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||||
| schema::CNodeT *op) { | |||||
| lite::PrimitiveC *OnnxBiasAddParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||||
| const onnx::NodeProto &onnx_node) { | |||||
| MS_LOG(DEBUG) << "onnx BiasAddParser"; | MS_LOG(DEBUG) << "onnx BiasAddParser"; | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| std::unique_ptr<schema::BiasAddT> attr = std::make_unique<schema::BiasAddT>(); | |||||
| auto attr = std::make_unique<schema::BiasAddT>(); | |||||
| if (attr == nullptr) { | if (attr == nullptr) { | ||||
| MS_LOG(ERROR) << "new op failed"; | MS_LOG(ERROR) << "new op failed"; | ||||
| return RET_NULL_PTR; | |||||
| return nullptr; | |||||
| } | } | ||||
| attr->axis = {1}; | attr->axis = {1}; | ||||
| op->primitive->value.type = schema::PrimitiveType_BiasAdd; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitive failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_BiasAdd; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | } | ||||
| OnnxNodeRegistrar g_onnxBiasAddParser("BiasAdd", new OnnxBiasAddParser()); | OnnxNodeRegistrar g_onnxBiasAddParser("BiasAdd", new OnnxBiasAddParser()); | ||||
| @@ -27,7 +27,7 @@ class OnnxBiasAddParser : public OnnxNodeParser { | |||||
| OnnxBiasAddParser() : OnnxNodeParser("BiasAdd") {} | OnnxBiasAddParser() : OnnxNodeParser("BiasAdd") {} | ||||
| ~OnnxBiasAddParser() override = default; | ~OnnxBiasAddParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -20,22 +20,13 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS OnnxCastParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { | |||||
| lite::PrimitiveC *OnnxCastParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||||
| const onnx::NodeProto &onnx_node) { | |||||
| MS_LOG(DEBUG) << "onnx CastParser"; | MS_LOG(DEBUG) << "onnx CastParser"; | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| std::unique_ptr<schema::CastT> attr = std::make_unique<schema::CastT>(); | |||||
| auto attr = std::make_unique<schema::CastT>(); | |||||
| if (attr == nullptr) { | if (attr == nullptr) { | ||||
| MS_LOG(ERROR) << "new op failed"; | MS_LOG(ERROR) << "new op failed"; | ||||
| return RET_NULL_PTR; | |||||
| return nullptr; | |||||
| } | } | ||||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | for (const auto &onnx_node_attr : onnx_node.attribute()) { | ||||
| @@ -48,10 +39,14 @@ STATUS OnnxCastParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod | |||||
| attr->dstT = static_cast<int>(dst_type); | attr->dstT = static_cast<int>(dst_type); | ||||
| } | } | ||||
| } | } | ||||
| op->primitive->value.type = schema::PrimitiveType_Cast; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitive failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Cast; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | } | ||||
| OnnxNodeRegistrar g_onnxCastParser("Cast", new OnnxCastParser()); | OnnxNodeRegistrar g_onnxCastParser("Cast", new OnnxCastParser()); | ||||
| @@ -27,7 +27,7 @@ class OnnxCastParser : public OnnxNodeParser { | |||||
| OnnxCastParser() : OnnxNodeParser("Cast") {} | OnnxCastParser() : OnnxNodeParser("Cast") {} | ||||
| ~OnnxCastParser() override = default; | ~OnnxCastParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,39 +19,32 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS OnnxClipParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { | |||||
| lite::PrimitiveC *OnnxClipParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||||
| const onnx::NodeProto &onnx_node) { | |||||
| MS_LOG(DEBUG) << "onnx ClipParser"; | MS_LOG(DEBUG) << "onnx ClipParser"; | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| auto attr = std::make_unique<schema::ClipT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return nullptr; | |||||
| } | } | ||||
| float min = -1, max = -1; | |||||
| attr->max = -1; | |||||
| attr->min = -1; | |||||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | for (const auto &onnx_node_attr : onnx_node.attribute()) { | ||||
| const auto &attribute_name = onnx_node_attr.name(); | const auto &attribute_name = onnx_node_attr.name(); | ||||
| if (attribute_name == "max") { | if (attribute_name == "max") { | ||||
| max = onnx_node_attr.f(); | |||||
| attr->max = onnx_node_attr.f(); | |||||
| } else if (attribute_name == "min") { | } else if (attribute_name == "min") { | ||||
| min = onnx_node_attr.f(); | |||||
| attr->min = onnx_node_attr.f(); | |||||
| } | } | ||||
| } | } | ||||
| std::unique_ptr<schema::ClipT> attr = std::make_unique<schema::ClipT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return RET_NULL_PTR; | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitive failed"; | |||||
| return nullptr; | |||||
| } | } | ||||
| attr->max = max; | |||||
| attr->min = min; | |||||
| op->primitive->value.type = schema::PrimitiveType_Clip; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| primitive->value.type = schema::PrimitiveType_Clip; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | } | ||||
| OnnxNodeRegistrar g_onnxClipParser("Clip", new OnnxClipParser()); | OnnxNodeRegistrar g_onnxClipParser("Clip", new OnnxClipParser()); | ||||
| @@ -27,7 +27,7 @@ class OnnxClipParser : public OnnxNodeParser { | |||||
| OnnxClipParser() : OnnxNodeParser("Clip") {} | OnnxClipParser() : OnnxNodeParser("Clip") {} | ||||
| ~OnnxClipParser() override = default; | ~OnnxClipParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,23 +19,13 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS OnnxConcatParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||||
| schema::CNodeT *op) { | |||||
| lite::PrimitiveC *OnnxConcatParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||||
| const onnx::NodeProto &onnx_node) { | |||||
| MS_LOG(DEBUG) << "onnx ConcatParser"; | MS_LOG(DEBUG) << "onnx ConcatParser"; | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| std::unique_ptr<schema::ConcatT> attr = std::make_unique<schema::ConcatT>(); | |||||
| auto attr = std::make_unique<schema::ConcatT>(); | |||||
| if (attr == nullptr) { | if (attr == nullptr) { | ||||
| MS_LOG(ERROR) << "new op failed"; | MS_LOG(ERROR) << "new op failed"; | ||||
| return RET_NULL_PTR; | |||||
| return nullptr; | |||||
| } | } | ||||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | for (const auto &onnx_node_attr : onnx_node.attribute()) { | ||||
| @@ -44,10 +34,14 @@ STATUS OnnxConcatParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N | |||||
| attr->axis = static_cast<int32_t>(onnx_node_attr.i()); | attr->axis = static_cast<int32_t>(onnx_node_attr.i()); | ||||
| } | } | ||||
| } | } | ||||
| op->primitive->value.type = schema::PrimitiveType_Concat; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitive failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Concat; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | } | ||||
| OnnxNodeRegistrar g_onnxConcatParser("Concat", new OnnxConcatParser()); | OnnxNodeRegistrar g_onnxConcatParser("Concat", new OnnxConcatParser()); | ||||
| @@ -27,7 +27,7 @@ class OnnxConcatParser : public OnnxNodeParser { | |||||
| OnnxConcatParser() : OnnxNodeParser("Concat") {} | OnnxConcatParser() : OnnxNodeParser("Concat") {} | ||||
| ~OnnxConcatParser() override = default; | ~OnnxConcatParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -20,23 +20,13 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS OnnxConstantOfShapeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||||
| schema::CNodeT *op) { | |||||
| lite::PrimitiveC *OnnxConstantOfShapeParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||||
| const onnx::NodeProto &onnx_node) { | |||||
| MS_LOG(DEBUG) << "onnx ConstantOfShapeParser"; | MS_LOG(DEBUG) << "onnx ConstantOfShapeParser"; | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| std::unique_ptr<schema::ConstantOfShapeT> attr = std::make_unique<schema::ConstantOfShapeT>(); | |||||
| auto attr = std::make_unique<schema::ConstantOfShapeT>(); | |||||
| if (attr == nullptr) { | if (attr == nullptr) { | ||||
| MS_LOG(ERROR) << "new op failed"; | MS_LOG(ERROR) << "new op failed"; | ||||
| return RET_NULL_PTR; | |||||
| return nullptr; | |||||
| } | } | ||||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | for (const auto &onnx_node_attr : onnx_node.attribute()) { | ||||
| @@ -55,19 +45,24 @@ STATUS OnnxConstantOfShapeParser::Parse(const onnx::GraphProto &onnx_graph, cons | |||||
| const auto &tensor = onnx_node_attr.t(); | const auto &tensor = onnx_node_attr.t(); | ||||
| auto ret = GetTensorDataFromOnnx(tensor, &attr->value, &attr->dataType); | auto ret = GetTensorDataFromOnnx(tensor, &attr->value, &attr->dataType); | ||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| return ret; | |||||
| MS_LOG(ERROR) << "get data from tensor failed"; | |||||
| return nullptr; | |||||
| } | } | ||||
| } break; | } break; | ||||
| default: | default: | ||||
| MS_LOG(ERROR) << "The data type is not supported."; | MS_LOG(ERROR) << "The data type is not supported."; | ||||
| return RET_ERROR; | |||||
| return nullptr; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| op->primitive->value.type = schema::PrimitiveType_ConstantOfShape; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitive failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_ConstantOfShape; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | } | ||||
| OnnxNodeRegistrar g_onnxConstantOfShapeParser("ConstantOfShape", new OnnxConstantOfShapeParser()); | OnnxNodeRegistrar g_onnxConstantOfShapeParser("ConstantOfShape", new OnnxConstantOfShapeParser()); | ||||
| @@ -27,7 +27,7 @@ class OnnxConstantOfShapeParser : public OnnxNodeParser { | |||||
| OnnxConstantOfShapeParser() : OnnxNodeParser("ConstantOfShape") {} | OnnxConstantOfShapeParser() : OnnxNodeParser("ConstantOfShape") {} | ||||
| ~OnnxConstantOfShapeParser() override = default; | ~OnnxConstantOfShapeParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -16,33 +16,75 @@ | |||||
| #include "tools/converter/parser/onnx/onnx_constant_parser.h" | #include "tools/converter/parser/onnx/onnx_constant_parser.h" | ||||
| #include <memory> | #include <memory> | ||||
| #include <vector> | |||||
| #include <algorithm> | |||||
| #include "tools/converter/parser/onnx/onnx_model_parser.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS OnnxConstantParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||||
| schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "onnx ConstantParser"; | |||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| STATUS OnnxConstantParser::AddDataInfoAttr(const onnx::TensorProto &onnx_const_tensor, lite::PrimitiveC *primitive_c) { | |||||
| ParamValueLitePtr param_value = std::make_shared<ParamValueLite>(); | |||||
| if (param_value == nullptr) { | |||||
| MS_LOG(ERROR) << "new a paramValueLite failed."; | |||||
| return RET_ERROR; | |||||
| } | } | ||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| auto data_type = | |||||
| OnnxModelParser::GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(onnx_const_tensor.data_type())); | |||||
| if (data_type == kTypeUnknown) { | |||||
| MS_LOG(ERROR) << "not support onnx data type " | |||||
| << static_cast<onnx::TensorProto_DataType>(onnx_const_tensor.data_type()); | |||||
| return RET_ERROR; | |||||
| } | } | ||||
| std::unique_ptr<schema::ConstantT> attr = std::make_unique<schema::ConstantT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return RET_NULL_PTR; | |||||
| std::vector<int64_t> shape_vector(onnx_const_tensor.dims().begin(), onnx_const_tensor.dims().end()); | |||||
| std::vector<int> shape; | |||||
| std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(shape), | |||||
| [](const int64_t &val) { return static_cast<int32_t>(val); }); | |||||
| param_value->set_tensor_type(data_type); | |||||
| param_value->set_tensor_shape(shape); | |||||
| param_value->set_format(schema::Format_NCHW); | |||||
| if (OnnxModelParser::CopyOnnxTensorData(onnx_const_tensor, param_value) != RET_OK) { | |||||
| MS_LOG(ERROR) << "get value failed."; | |||||
| return RET_ERROR; | |||||
| } | } | ||||
| op->primitive->value.type = schema::PrimitiveType_Constant; | |||||
| op->primitive->value.value = attr.release(); | |||||
| primitive_c->set_attr("const_data", param_value); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| lite::PrimitiveC *OnnxConstantParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||||
| const onnx::NodeProto &onnx_node) { | |||||
| MS_LOG(DEBUG) << "onnx ConstantParser"; | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitive failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Constant; | |||||
| auto primitive_c = PrimitiveC::Create(primitive.release()); | |||||
| if (primitive_c == nullptr) { | |||||
| MS_LOG(ERROR) << "create primitiveC failed."; | |||||
| return nullptr; | |||||
| } | |||||
| for (const auto &attr : onnx_node.attribute()) { | |||||
| if (attr.name() == "sparse_value") { | |||||
| MS_LOG(WARNING) << "sparse_value"; | |||||
| continue; | |||||
| } | |||||
| if (attr.name() == "value") { | |||||
| const auto &const_tensor = attr.t(); | |||||
| if (AddDataInfoAttr(const_tensor, primitive_c) != RET_OK) { | |||||
| MS_LOG(ERROR) << "add basic attr failed."; | |||||
| delete primitive_c; | |||||
| return nullptr; | |||||
| } | |||||
| } else { | |||||
| MS_LOG(ERROR) << "processing Constant op attr " << attr.name() << " not implemented"; | |||||
| delete primitive_c; | |||||
| return nullptr; | |||||
| } | |||||
| } | |||||
| return primitive_c; | |||||
| } | |||||
| OnnxNodeRegistrar g_onnxConstantParser("Constant", new OnnxConstantParser()); | OnnxNodeRegistrar g_onnxConstantParser("Constant", new OnnxConstantParser()); | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -27,7 +27,8 @@ class OnnxConstantParser : public OnnxNodeParser { | |||||
| OnnxConstantParser() : OnnxNodeParser("Constant") {} | OnnxConstantParser() : OnnxNodeParser("Constant") {} | ||||
| ~OnnxConstantParser() override = default; | ~OnnxConstantParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| STATUS AddDataInfoAttr(const onnx::TensorProto &onnx_const_tensor, lite::PrimitiveC *primitive_c); | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -21,9 +21,14 @@ | |||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| constexpr int32_t kSingleGroup = 1; | constexpr int32_t kSingleGroup = 1; | ||||
| bool OnnxConvParser::ParseGroupConvolution(const std::unique_ptr<schema::Conv2DT> &attr, schema::CNodeT *op) { | |||||
| bool OnnxConvParser::ParseGroupConvolution(const std::unique_ptr<schema::Conv2DT> &attr, | |||||
| schema::PrimitiveT *primitive) { | |||||
| MS_LOG(DEBUG) << "onnx DepthwiseConvParser"; | MS_LOG(DEBUG) << "onnx DepthwiseConvParser"; | ||||
| std::unique_ptr<schema::DepthwiseConv2DT> depthwiseConv2DParam = std::make_unique<schema::DepthwiseConv2DT>(); | |||||
| if (attr == nullptr || primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "input parameter is nullptr"; | |||||
| return false; | |||||
| } | |||||
| auto depthwiseConv2DParam = std::make_unique<schema::DepthwiseConv2DT>(); | |||||
| if (depthwiseConv2DParam == nullptr) { | if (depthwiseConv2DParam == nullptr) { | ||||
| MS_LOG(ERROR) << "new op failed"; | MS_LOG(ERROR) << "new op failed"; | ||||
| return false; | return false; | ||||
| @@ -45,27 +50,18 @@ bool OnnxConvParser::ParseGroupConvolution(const std::unique_ptr<schema::Conv2DT | |||||
| depthwiseConv2DParam->hasBias = attr->hasBias; | depthwiseConv2DParam->hasBias = attr->hasBias; | ||||
| depthwiseConv2DParam->activationType = attr->activationType; | depthwiseConv2DParam->activationType = attr->activationType; | ||||
| op->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; | |||||
| op->primitive->value.value = depthwiseConv2DParam.release(); | |||||
| primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; | |||||
| primitive->value.value = depthwiseConv2DParam.release(); | |||||
| return true; | return true; | ||||
| } | } | ||||
| STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { | |||||
| lite::PrimitiveC *OnnxConvParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||||
| const onnx::NodeProto &onnx_node) { | |||||
| MS_LOG(DEBUG) << "onnx ConvParser"; | MS_LOG(DEBUG) << "onnx ConvParser"; | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| std::unique_ptr<schema::Conv2DT> attr = std::make_unique<schema::Conv2DT>(); | |||||
| auto attr = std::make_unique<schema::Conv2DT>(); | |||||
| if (attr == nullptr) { | if (attr == nullptr) { | ||||
| MS_LOG(ERROR) << "new op failed"; | MS_LOG(ERROR) << "new op failed"; | ||||
| return RET_NULL_PTR; | |||||
| return nullptr; | |||||
| } | } | ||||
| attr->strideH = 1; | attr->strideH = 1; | ||||
| @@ -83,21 +79,21 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod | |||||
| } else if (onnx_node_attr.name() == "dilations") { | } else if (onnx_node_attr.name() == "dilations") { | ||||
| if (onnx_node_attr.ints().size() != 2) { | if (onnx_node_attr.ints().size() != 2) { | ||||
| MS_LOG(ERROR) << "dilations size " << onnx_node_attr.ints().size() << " is not 2"; | MS_LOG(ERROR) << "dilations size " << onnx_node_attr.ints().size() << " is not 2"; | ||||
| return RET_ERROR; | |||||
| return nullptr; | |||||
| } | } | ||||
| attr->dilateH = static_cast<int32_t>(onnx_node_attr.ints(0)); | attr->dilateH = static_cast<int32_t>(onnx_node_attr.ints(0)); | ||||
| attr->dilateW = static_cast<int32_t>(onnx_node_attr.ints(1)); | attr->dilateW = static_cast<int32_t>(onnx_node_attr.ints(1)); | ||||
| } else if (onnx_node_attr.name() == "kernels") { | } else if (onnx_node_attr.name() == "kernels") { | ||||
| if (onnx_node_attr.ints().size() != 2) { | if (onnx_node_attr.ints().size() != 2) { | ||||
| MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2"; | MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2"; | ||||
| return RET_ERROR; | |||||
| return nullptr; | |||||
| } | } | ||||
| attr->kernelH = static_cast<int32_t>(onnx_node_attr.ints(0)); | attr->kernelH = static_cast<int32_t>(onnx_node_attr.ints(0)); | ||||
| attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(1)); | attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(1)); | ||||
| } else if (onnx_node_attr.name() == "kernel_shape") { | } else if (onnx_node_attr.name() == "kernel_shape") { | ||||
| if (onnx_node_attr.ints().size() != 2) { | if (onnx_node_attr.ints().size() != 2) { | ||||
| MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2"; | MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2"; | ||||
| return RET_ERROR; | |||||
| return nullptr; | |||||
| } | } | ||||
| attr->kernelH = static_cast<int32_t>(onnx_node_attr.ints(0)); | attr->kernelH = static_cast<int32_t>(onnx_node_attr.ints(0)); | ||||
| attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(1)); | attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(1)); | ||||
| @@ -106,7 +102,7 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod | |||||
| } else if (onnx_node_attr.name() == "pads") { | } else if (onnx_node_attr.name() == "pads") { | ||||
| if (onnx_node_attr.ints().size() != 4) { | if (onnx_node_attr.ints().size() != 4) { | ||||
| MS_LOG(ERROR) << "pads size " << onnx_node_attr.ints().size() << " is not 4"; | MS_LOG(ERROR) << "pads size " << onnx_node_attr.ints().size() << " is not 4"; | ||||
| return RET_ERROR; | |||||
| return nullptr; | |||||
| } | } | ||||
| attr->padUp = static_cast<int32_t>(onnx_node_attr.ints(0)); | attr->padUp = static_cast<int32_t>(onnx_node_attr.ints(0)); | ||||
| attr->padLeft = static_cast<int32_t>(onnx_node_attr.ints(1)); | attr->padLeft = static_cast<int32_t>(onnx_node_attr.ints(1)); | ||||
| @@ -115,7 +111,7 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod | |||||
| } else if (onnx_node_attr.name() == "strides") { | } else if (onnx_node_attr.name() == "strides") { | ||||
| if (onnx_node_attr.ints().size() != 2) { | if (onnx_node_attr.ints().size() != 2) { | ||||
| MS_LOG(ERROR) << "strides size " << onnx_node_attr.ints().size() << " is not 2"; | MS_LOG(ERROR) << "strides size " << onnx_node_attr.ints().size() << " is not 2"; | ||||
| return RET_ERROR; | |||||
| return nullptr; | |||||
| } | } | ||||
| attr->strideH = static_cast<int32_t>(onnx_node_attr.ints(0)); | attr->strideH = static_cast<int32_t>(onnx_node_attr.ints(0)); | ||||
| attr->strideW = static_cast<int32_t>(onnx_node_attr.ints(1)); | attr->strideW = static_cast<int32_t>(onnx_node_attr.ints(1)); | ||||
| @@ -124,7 +120,7 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod | |||||
| attr->format = schema::Format::Format_NHWC; | attr->format = schema::Format::Format_NHWC; | ||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Unsupported format: " << onnx_node_attr.s(); | MS_LOG(ERROR) << "Unsupported format: " << onnx_node_attr.s(); | ||||
| return RET_ERROR; | |||||
| return nullptr; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -152,7 +148,7 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod | |||||
| [onnx_conv_weight](const onnx::NodeProto &proto) { return proto.output(0) == onnx_conv_weight; }); | [onnx_conv_weight](const onnx::NodeProto &proto) { return proto.output(0) == onnx_conv_weight; }); | ||||
| if (node_iter == onnx_graph.node().end()) { | if (node_iter == onnx_graph.node().end()) { | ||||
| MS_LOG(ERROR) << "can not find node: " << onnx_conv_weight; | MS_LOG(ERROR) << "can not find node: " << onnx_conv_weight; | ||||
| return RET_ERROR; | |||||
| return nullptr; | |||||
| } | } | ||||
| std::vector<int> dims; | std::vector<int> dims; | ||||
| auto iter = std::find_if((*node_iter).attribute().begin(), (*node_iter).attribute().end(), | auto iter = std::find_if((*node_iter).attribute().begin(), (*node_iter).attribute().end(), | ||||
| @@ -160,7 +156,7 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod | |||||
| if (iter != (*node_iter).attribute().end()) { | if (iter != (*node_iter).attribute().end()) { | ||||
| if (iter->ints().begin() == nullptr || iter->ints().end() == nullptr) { | if (iter->ints().begin() == nullptr || iter->ints().end() == nullptr) { | ||||
| MS_LOG(ERROR) << "dims insert failed"; | MS_LOG(ERROR) << "dims insert failed"; | ||||
| return RET_ERROR; | |||||
| return nullptr; | |||||
| } | } | ||||
| dims.insert(dims.begin(), iter->ints().begin(), iter->ints().end()); | dims.insert(dims.begin(), iter->ints().begin(), iter->ints().end()); | ||||
| } | } | ||||
| @@ -174,16 +170,21 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod | |||||
| attr->activationType = schema::ActivationType_NO_ACTIVATION; | attr->activationType = schema::ActivationType_NO_ACTIVATION; | ||||
| } | } | ||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitive failed"; | |||||
| return nullptr; | |||||
| } | |||||
| if (attr->group > kSingleGroup && attr->group == attr->channelIn) { | if (attr->group > kSingleGroup && attr->group == attr->channelIn) { | ||||
| if (!ParseGroupConvolution(attr, op)) { | |||||
| if (!ParseGroupConvolution(attr, primitive.get())) { | |||||
| MS_LOG(ERROR) << "Convert Convolution to Depthwise failed"; | MS_LOG(ERROR) << "Convert Convolution to Depthwise failed"; | ||||
| return RET_ERROR; | |||||
| return nullptr; | |||||
| } | } | ||||
| } else { | } else { | ||||
| op->primitive->value.type = schema::PrimitiveType_Conv2D; | |||||
| op->primitive->value.value = attr.release(); | |||||
| primitive->value.type = schema::PrimitiveType_Conv2D; | |||||
| primitive->value.value = attr.release(); | |||||
| } | } | ||||
| return RET_OK; | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | } | ||||
| OnnxNodeRegistrar g_onnxConvParser("Conv", new OnnxConvParser()); | OnnxNodeRegistrar g_onnxConvParser("Conv", new OnnxConvParser()); | ||||
| @@ -28,10 +28,10 @@ class OnnxConvParser : public OnnxNodeParser { | |||||
| OnnxConvParser() : OnnxNodeParser("Conv") {} | OnnxConvParser() : OnnxNodeParser("Conv") {} | ||||
| ~OnnxConvParser() override = default; | ~OnnxConvParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| private: | private: | ||||
| static bool ParseGroupConvolution(const std::unique_ptr<schema::Conv2DT> &attr, schema::CNodeT *op); | |||||
| static bool ParseGroupConvolution(const std::unique_ptr<schema::Conv2DT> &attr, schema::PrimitiveT *primitive); | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -21,11 +21,13 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| bool OnnxDeConvParser::ParseGroupDeConvolution(const std::unique_ptr<schema::DeConv2DT> &attr, schema::CNodeT *op) { | |||||
| if (attr == nullptr || attr->group != attr->channelOut) { | |||||
| bool OnnxDeConvParser::ParseGroupDeConvolution(const std::unique_ptr<schema::DeConv2DT> &attr, | |||||
| schema::PrimitiveT *primitive) { | |||||
| if (attr == nullptr || attr->group != attr->channelOut || primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "input parameter is nullptr"; | |||||
| return false; | return false; | ||||
| } | } | ||||
| std::unique_ptr<schema::DeDepthwiseConv2DT> deDepthwiseConv2DParam = std::make_unique<schema::DeDepthwiseConv2DT>(); | |||||
| auto deDepthwiseConv2DParam = std::make_unique<schema::DeDepthwiseConv2DT>(); | |||||
| if (deDepthwiseConv2DParam == nullptr) { | if (deDepthwiseConv2DParam == nullptr) { | ||||
| MS_LOG(ERROR) << "new op failed"; | MS_LOG(ERROR) << "new op failed"; | ||||
| return false; | return false; | ||||
| @@ -47,28 +49,18 @@ bool OnnxDeConvParser::ParseGroupDeConvolution(const std::unique_ptr<schema::DeC | |||||
| deDepthwiseConv2DParam->hasBias = attr->hasBias; | deDepthwiseConv2DParam->hasBias = attr->hasBias; | ||||
| deDepthwiseConv2DParam->activationType = attr->activationType; | deDepthwiseConv2DParam->activationType = attr->activationType; | ||||
| op->primitive->value.type = schema::PrimitiveType_DeDepthwiseConv2D; | |||||
| op->primitive->value.value = deDepthwiseConv2DParam.release(); | |||||
| primitive->value.type = schema::PrimitiveType_DeDepthwiseConv2D; | |||||
| primitive->value.value = deDepthwiseConv2DParam.release(); | |||||
| return true; | return true; | ||||
| } | } | ||||
| STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||||
| schema::CNodeT *op) { | |||||
| lite::PrimitiveC *OnnxDeConvParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||||
| const onnx::NodeProto &onnx_node) { | |||||
| MS_LOG(DEBUG) << "onnx DeConvParser"; | MS_LOG(DEBUG) << "onnx DeConvParser"; | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| std::unique_ptr<schema::DeConv2DT> attr = std::make_unique<schema::DeConv2DT>(); | |||||
| auto attr = std::make_unique<schema::DeConv2DT>(); | |||||
| if (attr == nullptr) { | if (attr == nullptr) { | ||||
| MS_LOG(ERROR) << "new op failed"; | MS_LOG(ERROR) << "new op failed"; | ||||
| return RET_NULL_PTR; | |||||
| return nullptr; | |||||
| } | } | ||||
| attr->padMode = schema::PadMode_NOTSET; | attr->padMode = schema::PadMode_NOTSET; | ||||
| @@ -83,21 +75,21 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N | |||||
| } else if (onnx_node_attr.name() == "dilations") { | } else if (onnx_node_attr.name() == "dilations") { | ||||
| if (onnx_node_attr.ints().size() != 2) { | if (onnx_node_attr.ints().size() != 2) { | ||||
| MS_LOG(ERROR) << "dilations size " << onnx_node_attr.ints().size() << " is not 2"; | MS_LOG(ERROR) << "dilations size " << onnx_node_attr.ints().size() << " is not 2"; | ||||
| return RET_ERROR; | |||||
| return nullptr; | |||||
| } | } | ||||
| attr->dilateH = static_cast<int32_t>(onnx_node_attr.ints(0)); | attr->dilateH = static_cast<int32_t>(onnx_node_attr.ints(0)); | ||||
| attr->dilateW = static_cast<int32_t>(onnx_node_attr.ints(1)); | attr->dilateW = static_cast<int32_t>(onnx_node_attr.ints(1)); | ||||
| } else if (onnx_node_attr.name() == "kernels") { | } else if (onnx_node_attr.name() == "kernels") { | ||||
| if (onnx_node_attr.ints().size() != 2) { | if (onnx_node_attr.ints().size() != 2) { | ||||
| MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2"; | MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2"; | ||||
| return RET_ERROR; | |||||
| return nullptr; | |||||
| } | } | ||||
| attr->kernelH = static_cast<int32_t>(onnx_node_attr.ints(0)); | attr->kernelH = static_cast<int32_t>(onnx_node_attr.ints(0)); | ||||
| attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(1)); | attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(1)); | ||||
| } else if (onnx_node_attr.name() == "kernel_shape") { | } else if (onnx_node_attr.name() == "kernel_shape") { | ||||
| if (onnx_node_attr.ints().size() != 2) { | if (onnx_node_attr.ints().size() != 2) { | ||||
| MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2"; | MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2"; | ||||
| return RET_ERROR; | |||||
| return nullptr; | |||||
| } | } | ||||
| attr->kernelH = static_cast<int32_t>(onnx_node_attr.ints(0)); | attr->kernelH = static_cast<int32_t>(onnx_node_attr.ints(0)); | ||||
| attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(1)); | attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(1)); | ||||
| @@ -106,7 +98,7 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N | |||||
| } else if (onnx_node_attr.name() == "pads") { | } else if (onnx_node_attr.name() == "pads") { | ||||
| if (onnx_node_attr.ints().size() != 4) { | if (onnx_node_attr.ints().size() != 4) { | ||||
| MS_LOG(ERROR) << "pads size " << onnx_node_attr.ints().size() << " is not 4"; | MS_LOG(ERROR) << "pads size " << onnx_node_attr.ints().size() << " is not 4"; | ||||
| return RET_ERROR; | |||||
| return nullptr; | |||||
| } | } | ||||
| attr->padUp = static_cast<int32_t>(onnx_node_attr.ints(0)); | attr->padUp = static_cast<int32_t>(onnx_node_attr.ints(0)); | ||||
| attr->padLeft = static_cast<int32_t>(onnx_node_attr.ints(1)); | attr->padLeft = static_cast<int32_t>(onnx_node_attr.ints(1)); | ||||
| @@ -115,7 +107,7 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N | |||||
| } else if (onnx_node_attr.name() == "strides") { | } else if (onnx_node_attr.name() == "strides") { | ||||
| if (onnx_node_attr.ints().size() != 2) { | if (onnx_node_attr.ints().size() != 2) { | ||||
| MS_LOG(ERROR) << "strides size " << onnx_node_attr.ints().size() << " is not 2"; | MS_LOG(ERROR) << "strides size " << onnx_node_attr.ints().size() << " is not 2"; | ||||
| return RET_ERROR; | |||||
| return nullptr; | |||||
| } | } | ||||
| attr->strideH = static_cast<int32_t>(onnx_node_attr.ints(0)); | attr->strideH = static_cast<int32_t>(onnx_node_attr.ints(0)); | ||||
| attr->strideW = static_cast<int32_t>(onnx_node_attr.ints(1)); | attr->strideW = static_cast<int32_t>(onnx_node_attr.ints(1)); | ||||
| @@ -124,11 +116,11 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N | |||||
| attr->format = schema::Format::Format_NHWC; | attr->format = schema::Format::Format_NHWC; | ||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Unsupported format: " << onnx_node_attr.s().c_str(); | MS_LOG(ERROR) << "Unsupported format: " << onnx_node_attr.s().c_str(); | ||||
| return RET_ERROR; | |||||
| return nullptr; | |||||
| } | } | ||||
| } else if (onnx_node_attr.name() == "output_padding") { | } else if (onnx_node_attr.name() == "output_padding") { | ||||
| MS_LOG(ERROR) << "output_padding param hasn't been supported"; | MS_LOG(ERROR) << "output_padding param hasn't been supported"; | ||||
| return RET_NOT_SUPPORT; | |||||
| return nullptr; | |||||
| } | } | ||||
| } | } | ||||
| @@ -138,7 +130,7 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N | |||||
| [onnx_conv_weight](const onnx::TensorProto &proto) { return proto.name() == onnx_conv_weight; }); | [onnx_conv_weight](const onnx::TensorProto &proto) { return proto.name() == onnx_conv_weight; }); | ||||
| if (node_iter == onnx_graph.initializer().end()) { | if (node_iter == onnx_graph.initializer().end()) { | ||||
| MS_LOG(ERROR) << "not find node: " << onnx_conv_weight.c_str(); | MS_LOG(ERROR) << "not find node: " << onnx_conv_weight.c_str(); | ||||
| return RET_ERROR; | |||||
| return nullptr; | |||||
| } | } | ||||
| std::vector<int> weight_shape; | std::vector<int> weight_shape; | ||||
| auto size = (*node_iter).dims_size(); | auto size = (*node_iter).dims_size(); | ||||
| @@ -148,7 +140,7 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N | |||||
| } | } | ||||
| if (weight_shape.size() != 4) { | if (weight_shape.size() != 4) { | ||||
| MS_LOG(ERROR) << "weight_shape.size() should be 4, but is " << weight_shape.size(); | MS_LOG(ERROR) << "weight_shape.size() should be 4, but is " << weight_shape.size(); | ||||
| return RET_ERROR; | |||||
| return nullptr; | |||||
| } | } | ||||
| attr->channelIn = weight_shape[0]; | attr->channelIn = weight_shape[0]; | ||||
| attr->channelOut = weight_shape[1] * attr->group; | attr->channelOut = weight_shape[1] * attr->group; | ||||
| @@ -156,17 +148,22 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N | |||||
| attr->format = schema::Format::Format_NCHW; | attr->format = schema::Format::Format_NCHW; | ||||
| attr->hasBias = onnx_node.input().size() == 3; | attr->hasBias = onnx_node.input().size() == 3; | ||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitive failed"; | |||||
| return nullptr; | |||||
| } | |||||
| if (attr->group != 1) { | if (attr->group != 1) { | ||||
| if (!ParseGroupDeConvolution(attr, op)) { | |||||
| if (!ParseGroupDeConvolution(attr, primitive.get())) { | |||||
| MS_LOG(ERROR) << "Convert DeConvolution to DeDepthwise failed, generalized group deconv hasn't support"; | MS_LOG(ERROR) << "Convert DeConvolution to DeDepthwise failed, generalized group deconv hasn't support"; | ||||
| return RET_NOT_SUPPORT; | |||||
| return nullptr; | |||||
| } | } | ||||
| } else { | } else { | ||||
| op->primitive->value.type = schema::PrimitiveType_DeConv2D; | |||||
| op->primitive->value.value = attr.release(); | |||||
| primitive->value.type = schema::PrimitiveType_DeConv2D; | |||||
| primitive->value.value = attr.release(); | |||||
| } | } | ||||
| return RET_OK; | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | } | ||||
| OnnxNodeRegistrar g_onnxDeConvParser("ConvTranspose", new OnnxDeConvParser()); | OnnxNodeRegistrar g_onnxDeConvParser("ConvTranspose", new OnnxDeConvParser()); | ||||
| @@ -28,10 +28,10 @@ class OnnxDeConvParser : public OnnxNodeParser { | |||||
| OnnxDeConvParser() : OnnxNodeParser("DeConv") {} | OnnxDeConvParser() : OnnxNodeParser("DeConv") {} | ||||
| ~OnnxDeConvParser() override = default; | ~OnnxDeConvParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| private: | private: | ||||
| static bool ParseGroupDeConvolution(const std::unique_ptr<schema::DeConv2DT> &attr, schema::CNodeT *op); | |||||
| bool ParseGroupDeConvolution(const std::unique_ptr<schema::DeConv2DT> &attr, schema::PrimitiveT *primitive); | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,23 +19,13 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS OnnxDepthToSpaceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||||
| schema::CNodeT *op) { | |||||
| lite::PrimitiveC *OnnxDepthToSpaceParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||||
| const onnx::NodeProto &onnx_node) { | |||||
| MS_LOG(DEBUG) << "onnx DepthToSpaceParser"; | MS_LOG(DEBUG) << "onnx DepthToSpaceParser"; | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| std::unique_ptr<schema::DepthToSpaceT> attr = std::make_unique<schema::DepthToSpaceT>(); | |||||
| auto attr = std::make_unique<schema::DepthToSpaceT>(); | |||||
| if (attr == nullptr) { | if (attr == nullptr) { | ||||
| MS_LOG(ERROR) << "new op failed"; | MS_LOG(ERROR) << "new op failed"; | ||||
| return RET_NULL_PTR; | |||||
| return nullptr; | |||||
| } | } | ||||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | for (const auto &onnx_node_attr : onnx_node.attribute()) { | ||||
| @@ -44,10 +34,14 @@ STATUS OnnxDepthToSpaceParser::Parse(const onnx::GraphProto &onnx_graph, const o | |||||
| attr->blockSize = static_cast<int32_t>(onnx_node_attr.i()); | attr->blockSize = static_cast<int32_t>(onnx_node_attr.i()); | ||||
| } | } | ||||
| } | } | ||||
| op->primitive->value.type = schema::PrimitiveType_DepthToSpace; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitive failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_DepthToSpace; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | } | ||||
| OnnxNodeRegistrar g_onnxDepthToSpaceParser("DepthToSpace", new OnnxDepthToSpaceParser()); | OnnxNodeRegistrar g_onnxDepthToSpaceParser("DepthToSpace", new OnnxDepthToSpaceParser()); | ||||
| @@ -27,7 +27,7 @@ class OnnxDepthToSpaceParser : public OnnxNodeParser { | |||||
| OnnxDepthToSpaceParser() : OnnxNodeParser("DepthToSpace") {} | OnnxDepthToSpaceParser() : OnnxNodeParser("DepthToSpace") {} | ||||
| ~OnnxDepthToSpaceParser() override = default; | ~OnnxDepthToSpaceParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,23 +19,13 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS OnnxDropoutParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||||
| schema::CNodeT *op) { | |||||
| lite::PrimitiveC *OnnxDropoutParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||||
| const onnx::NodeProto &onnx_node) { | |||||
| MS_LOG(DEBUG) << "onnx DropoutParser"; | MS_LOG(DEBUG) << "onnx DropoutParser"; | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| std::unique_ptr<schema::DropoutT> attr = std::make_unique<schema::DropoutT>(); | |||||
| auto attr = std::make_unique<schema::DropoutT>(); | |||||
| if (attr == nullptr) { | if (attr == nullptr) { | ||||
| MS_LOG(ERROR) << "new op failed"; | MS_LOG(ERROR) << "new op failed"; | ||||
| return RET_NULL_PTR; | |||||
| return nullptr; | |||||
| } | } | ||||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | for (const auto &onnx_node_attr : onnx_node.attribute()) { | ||||
| @@ -44,10 +34,14 @@ STATUS OnnxDropoutParser::Parse(const onnx::GraphProto &onnx_graph, const onnx:: | |||||
| attr->ratio = static_cast<float>(onnx_node_attr.f()); | attr->ratio = static_cast<float>(onnx_node_attr.f()); | ||||
| } | } | ||||
| } | } | ||||
| op->primitive->value.type = schema::PrimitiveType_Dropout; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitive failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Dropout; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | } | ||||
| OnnxNodeRegistrar g_onnxDropoutParser("Dropout", new OnnxDropoutParser()); | OnnxNodeRegistrar g_onnxDropoutParser("Dropout", new OnnxDropoutParser()); | ||||
| @@ -27,7 +27,7 @@ class OnnxDropoutParser : public OnnxNodeParser { | |||||
| OnnxDropoutParser() : OnnxNodeParser("Dropout") {} | OnnxDropoutParser() : OnnxNodeParser("Dropout") {} | ||||
| ~OnnxDropoutParser() override = default; | ~OnnxDropoutParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,22 +19,13 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS OnnxEluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { | |||||
| lite::PrimitiveC *OnnxEluParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||||
| const onnx::NodeProto &onnx_node) { | |||||
| MS_LOG(DEBUG) << "onnx EluParser"; | MS_LOG(DEBUG) << "onnx EluParser"; | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| std::unique_ptr<schema::EluT> attr = std::make_unique<schema::EluT>(); | |||||
| auto attr = std::make_unique<schema::EluT>(); | |||||
| if (attr == nullptr) { | if (attr == nullptr) { | ||||
| MS_LOG(ERROR) << "new op failed"; | MS_LOG(ERROR) << "new op failed"; | ||||
| return RET_NULL_PTR; | |||||
| return nullptr; | |||||
| } | } | ||||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | for (const auto &onnx_node_attr : onnx_node.attribute()) { | ||||
| @@ -43,10 +34,14 @@ STATUS OnnxEluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Node | |||||
| attr->alpha = onnx_node_attr.f(); | attr->alpha = onnx_node_attr.f(); | ||||
| } | } | ||||
| } | } | ||||
| op->primitive->value.type = schema::PrimitiveType_Elu; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitive failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Elu; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | } | ||||
| OnnxNodeRegistrar g_onnxEluParser("Elu", new OnnxEluParser()); | OnnxNodeRegistrar g_onnxEluParser("Elu", new OnnxEluParser()); | ||||
| @@ -27,7 +27,7 @@ class OnnxEluParser : public OnnxNodeParser { | |||||
| OnnxEluParser() : OnnxNodeParser("Elu") {} | OnnxEluParser() : OnnxNodeParser("Elu") {} | ||||
| ~OnnxEluParser() override = default; | ~OnnxEluParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -20,23 +20,13 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS OnnxExpandParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||||
| schema::CNodeT *op) { | |||||
| lite::PrimitiveC *OnnxExpandParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||||
| const onnx::NodeProto &onnx_node) { | |||||
| MS_LOG(DEBUG) << "onnx ExpandParser"; | MS_LOG(DEBUG) << "onnx ExpandParser"; | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| std::unique_ptr<schema::BroadcastToT> attr = std::make_unique<schema::BroadcastToT>(); | |||||
| auto attr = std::make_unique<schema::BroadcastToT>(); | |||||
| if (attr == nullptr) { | if (attr == nullptr) { | ||||
| MS_LOG(ERROR) << "new op failed"; | MS_LOG(ERROR) << "new op failed"; | ||||
| return RET_NULL_PTR; | |||||
| return nullptr; | |||||
| } | } | ||||
| std::vector<int> dst_shape; | std::vector<int> dst_shape; | ||||
| @@ -46,7 +36,7 @@ STATUS OnnxExpandParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N | |||||
| [onnx_expand_power](const onnx::NodeProto &proto) { return proto.output(0) == onnx_expand_power; }); | [onnx_expand_power](const onnx::NodeProto &proto) { return proto.output(0) == onnx_expand_power; }); | ||||
| if (node_iter == onnx_graph.node().end()) { | if (node_iter == onnx_graph.node().end()) { | ||||
| MS_LOG(ERROR) << "can not find node: " << onnx_expand_power; | MS_LOG(ERROR) << "can not find node: " << onnx_expand_power; | ||||
| return RET_ERROR; | |||||
| return nullptr; | |||||
| } | } | ||||
| for (const auto &attrPower : node_iter->attribute()) { | for (const auto &attrPower : node_iter->attribute()) { | ||||
| if (attrPower.name() == "value") { | if (attrPower.name() == "value") { | ||||
| @@ -58,9 +48,14 @@ STATUS OnnxExpandParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N | |||||
| } | } | ||||
| } | } | ||||
| attr->dst_shape = dst_shape; | attr->dst_shape = dst_shape; | ||||
| op->primitive->value.type = schema::PrimitiveType_BroadcastTo; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitive failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_BroadcastTo; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | } | ||||
| OnnxNodeRegistrar g_onnxExpandSpaceParser("Expand", new OnnxExpandParser()); | OnnxNodeRegistrar g_onnxExpandSpaceParser("Expand", new OnnxExpandParser()); | ||||
| @@ -27,7 +27,7 @@ class OnnxExpandParser : public OnnxNodeParser { | |||||
| OnnxExpandParser() : OnnxNodeParser("Expand") {} | OnnxExpandParser() : OnnxNodeParser("Expand") {} | ||||
| ~OnnxExpandParser() override = default; | ~OnnxExpandParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,23 +19,13 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS OnnxFlattenParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||||
| schema::CNodeT *op) { | |||||
| lite::PrimitiveC *OnnxFlattenParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||||
| const onnx::NodeProto &onnx_node) { | |||||
| MS_LOG(DEBUG) << "onnx FlattenParser"; | MS_LOG(DEBUG) << "onnx FlattenParser"; | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| std::unique_ptr<schema::ReshapeT> attr = std::make_unique<schema::ReshapeT>(); | |||||
| auto attr = std::make_unique<schema::ReshapeT>(); | |||||
| if (attr == nullptr) { | if (attr == nullptr) { | ||||
| MS_LOG(ERROR) << "new op failed"; | MS_LOG(ERROR) << "new op failed"; | ||||
| return RET_NULL_PTR; | |||||
| return nullptr; | |||||
| } | } | ||||
| int axis = 1; | int axis = 1; | ||||
| @@ -49,10 +39,14 @@ STATUS OnnxFlattenParser::Parse(const onnx::GraphProto &onnx_graph, const onnx:: | |||||
| attr->shape.emplace_back(0); | attr->shape.emplace_back(0); | ||||
| } | } | ||||
| attr->shape.emplace_back(-1); | attr->shape.emplace_back(-1); | ||||
| op->primitive->value.type = schema::PrimitiveType_Reshape; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitive failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Reshape; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | } | ||||
| OnnxNodeRegistrar g_onnxFlattenParser("Flatten", new OnnxFlattenParser()); | OnnxNodeRegistrar g_onnxFlattenParser("Flatten", new OnnxFlattenParser()); | ||||
| @@ -27,7 +27,7 @@ class OnnxFlattenParser : public OnnxNodeParser { | |||||
| OnnxFlattenParser() : OnnxNodeParser("Fatten") {} | OnnxFlattenParser() : OnnxNodeParser("Fatten") {} | ||||
| ~OnnxFlattenParser() override = default; | ~OnnxFlattenParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,23 +19,13 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS OnnxGatherParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||||
| schema::CNodeT *op) { | |||||
| lite::PrimitiveC *OnnxGatherParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||||
| const onnx::NodeProto &onnx_node) { | |||||
| MS_LOG(DEBUG) << "onnx GatherParser"; | MS_LOG(DEBUG) << "onnx GatherParser"; | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| std::unique_ptr<schema::GatherT> attr = std::make_unique<schema::GatherT>(); | |||||
| auto attr = std::make_unique<schema::GatherT>(); | |||||
| if (attr == nullptr) { | if (attr == nullptr) { | ||||
| MS_LOG(ERROR) << "new op failed"; | MS_LOG(ERROR) << "new op failed"; | ||||
| return RET_NULL_PTR; | |||||
| return nullptr; | |||||
| } | } | ||||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | for (const auto &onnx_node_attr : onnx_node.attribute()) { | ||||
| @@ -45,9 +35,14 @@ STATUS OnnxGatherParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N | |||||
| } | } | ||||
| } | } | ||||
| op->primitive->value.type = schema::PrimitiveType_Gather; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitive failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Gather; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | } | ||||
| OnnxNodeRegistrar g_onnxGatherParser("Gather", new OnnxGatherParser()); | OnnxNodeRegistrar g_onnxGatherParser("Gather", new OnnxGatherParser()); | ||||
| @@ -27,7 +27,7 @@ class OnnxGatherParser : public OnnxNodeParser { | |||||
| OnnxGatherParser() : OnnxNodeParser("Gather") {} | OnnxGatherParser() : OnnxNodeParser("Gather") {} | ||||
| ~OnnxGatherParser() override = default; | ~OnnxGatherParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -0,0 +1,55 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tools/converter/parser/onnx/onnx_gemm_parser.h" | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| lite::PrimitiveC *OnnxGemmParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||||
| const onnx::NodeProto &onnx_node) { | |||||
| MS_LOG(DEBUG) << "onnx IdentityParser"; | |||||
| auto node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser("MatMul"); | |||||
| if (node_parser == nullptr) { | |||||
| MS_LOG(ERROR) << "parse node " << onnx_node.op_type() << " failed."; | |||||
| return nullptr; | |||||
| } | |||||
| auto *matmul_primitive = node_parser->ParseLitePrimitive(onnx_graph, onnx_node); | |||||
| node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser("BiasAdd"); | |||||
| if (node_parser == nullptr) { | |||||
| MS_LOG(ERROR) << "parse node " << onnx_node.op_type() << " failed."; | |||||
| return nullptr; | |||||
| } | |||||
| auto *bias_add_primitive = node_parser->ParseLitePrimitive(onnx_graph, onnx_node); | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitive failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_MakeTuple; | |||||
| auto primitve_c = PrimitiveC::Create(primitive.release()); | |||||
| primitve_c->set_attr("MatMul", std::shared_ptr<lite::PrimitiveC>(matmul_primitive)); | |||||
| primitve_c->set_attr("BiasAdd", std::shared_ptr<lite::PrimitiveC>(bias_add_primitive)); | |||||
| return primitve_c; | |||||
| } | |||||
| OnnxNodeRegistrar g_onnxGemmParser("Gemm", new OnnxGemmParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -14,26 +14,21 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_TENSOR_PARSER_H | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_TENSOR_PARSER_H | |||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GEMM_PARSER_H | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GEMM_PARSER_H | |||||
| #include "tools/common/tensor_util.h" | |||||
| #include "tools/converter/parser/onnx/onnx_node_parser.h" | |||||
| #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| class OnnxTensorParser { | |||||
| class OnnxGemmParser : public OnnxNodeParser { | |||||
| public: | public: | ||||
| ~OnnxTensorParser() = default; | |||||
| static OnnxTensorParser *GetInstance() { | |||||
| static OnnxTensorParser onnxTensorParser; | |||||
| return &onnxTensorParser; | |||||
| } | |||||
| TensorCache *GetTensorCache() { return &tensor_cache_; } | |||||
| OnnxGemmParser() : OnnxNodeParser("Gemm") {} | |||||
| ~OnnxGemmParser() override = default; | |||||
| private: | |||||
| OnnxTensorParser() = default; | |||||
| TensorCache tensor_cache_; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_TESNOR_PARSER_H | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GEMM_PARSER_H | |||||
| @@ -0,0 +1,127 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tools/converter/parser/onnx/onnx_given_tensor_fill_parser.h" | |||||
| #include <functional> | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include <algorithm> | |||||
| #include "src/param_value_lite.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS OnnxGivenTensorFillParser::ParseInt8GivenIntTensorFill(const onnx::NodeProto &onnx_node, | |||||
| lite::PrimitiveC *primitive_c, | |||||
| const std::vector<int> &shape) { | |||||
| ParamValueLitePtr param_value = std::make_shared<ParamValueLite>(); | |||||
| if (param_value == nullptr) { | |||||
| MS_LOG(ERROR) << "new a paramValueLite failed."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| int data_count = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>()); | |||||
| auto iter = std::find_if(onnx_node.attribute().begin(), onnx_node.attribute().end(), | |||||
| [](const onnx::AttributeProto &attr) { return attr.name() == "values"; }); | |||||
| if (iter == onnx_node.attribute().end()) { | |||||
| return RET_OK; | |||||
| } | |||||
| size_t data_size = data_count * sizeof(int64_t) / sizeof(uint8_t); | |||||
| char *param_data = new (std::nothrow) char[data_size]; | |||||
| if (param_data == nullptr) { | |||||
| MS_LOG(ERROR) << "new char[] failed"; | |||||
| return RET_MEMORY_FAILED; | |||||
| } | |||||
| if (memcpy_s(param_data, data_size, iter->ints().data(), data_size) != EOK) { | |||||
| MS_LOG(ERROR) << "memcpy data failed."; | |||||
| delete[] param_data; | |||||
| return RET_ERROR; | |||||
| } | |||||
| param_value->set_tensor_shape(shape); | |||||
| param_value->set_format(schema::Format_NUM_OF_FORMAT); | |||||
| param_value->set_tensor_type(kNumberTypeInt64); | |||||
| param_value->SetTensorData(param_data, data_size); | |||||
| primitive_c->set_attr("const_data", param_value); | |||||
| return RET_OK; | |||||
| } | |||||
| STATUS OnnxGivenTensorFillParser::ParseInt8GivenTensorFill(const onnx::NodeProto &onnx_node, | |||||
| lite::PrimitiveC *primitive_c, | |||||
| const std::vector<int> &shape) { | |||||
| ParamValueLitePtr param_value = std::make_shared<ParamValueLite>(); | |||||
| if (param_value == nullptr) { | |||||
| MS_LOG(ERROR) << "new a paramValueLite failed."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| int data_count = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int64_t>()); | |||||
| auto iter = std::find_if(onnx_node.attribute().begin(), onnx_node.attribute().end(), | |||||
| [](const onnx::AttributeProto &attr) { return attr.name() == "values"; }); | |||||
| if (iter == onnx_node.attribute().end()) { | |||||
| return RET_OK; | |||||
| } | |||||
| char *param_data = new (std::nothrow) char[data_count]; | |||||
| if (param_data == nullptr) { | |||||
| MS_LOG(ERROR) << "new char[] failed"; | |||||
| return RET_MEMORY_FAILED; | |||||
| } | |||||
| if (memcpy_s(param_data, data_count, iter->s().data(), data_count) != EOK) { | |||||
| MS_LOG(ERROR) << "memcpy data failed."; | |||||
| delete[] param_data; | |||||
| return RET_ERROR; | |||||
| } | |||||
| param_value->set_tensor_shape(shape); | |||||
| param_value->set_format(schema::Format_NUM_OF_FORMAT); | |||||
| param_value->set_tensor_type(kNumberTypeUInt8); | |||||
| param_value->SetTensorData(param_data, data_count); | |||||
| primitive_c->set_attr("const_data", param_value); | |||||
| return RET_OK; | |||||
| } | |||||
| lite::PrimitiveC *OnnxGivenTensorFillParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||||
| const onnx::NodeProto &onnx_node) { | |||||
| MS_LOG(DEBUG) << "onnx GivenTensorFillParser"; | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitive failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Constant; | |||||
| auto primitive_c = PrimitiveC::Create(primitive.release()); | |||||
| std::vector<int64_t> shape_vector; | |||||
| auto iter = std::find_if(onnx_node.attribute().begin(), onnx_node.attribute().end(), | |||||
| [](const onnx::AttributeProto &attr) { return attr.name() == "shape"; }); | |||||
| if (iter != onnx_node.attribute().end()) { | |||||
| shape_vector.insert(shape_vector.begin(), iter->ints().begin(), iter->ints().end()); | |||||
| } | |||||
| std::vector<int> shape; | |||||
| std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(shape), | |||||
| [](const int64_t &val) { return static_cast<int32_t>(val); }); | |||||
| if (onnx_node.op_type() == "Int8GivenIntTensorFill") { | |||||
| if (ParseInt8GivenIntTensorFill(onnx_node, primitive_c, shape) != RET_OK) { | |||||
| MS_LOG(ERROR) << "given tensor fill parse failed."; | |||||
| return nullptr; | |||||
| } | |||||
| } else if (onnx_node.op_type() == "Int8GivenTensorFill") { | |||||
| if (ParseInt8GivenTensorFill(onnx_node, primitive_c, shape) != RET_OK) { | |||||
| MS_LOG(ERROR) << "given tensor fill parse failed."; | |||||
| return nullptr; | |||||
| } | |||||
| } | |||||
| return primitive_c; | |||||
| } | |||||
| OnnxNodeRegistrar g_onnxInt8GivenIntTensorFillParser("Int8GivenIntTensorFill", new OnnxGivenTensorFillParser()); | |||||
| OnnxNodeRegistrar g_onnxInt8GivenTensorFillParser("Int8GivenTensorFill", new OnnxGivenTensorFillParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,39 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GIVEN_TENSOR_FILL_PARSER_H | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GIVEN_TENSOR_FILL_PARSER_H | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/onnx/onnx_node_parser.h" | |||||
| #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class OnnxGivenTensorFillParser : public OnnxNodeParser { | |||||
| public: | |||||
| OnnxGivenTensorFillParser() : OnnxNodeParser("GivenTensorFill") {} | |||||
| ~OnnxGivenTensorFillParser() override = default; | |||||
| STATUS ParseInt8GivenIntTensorFill(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c, | |||||
| const std::vector<int> &shape); | |||||
| STATUS ParseInt8GivenTensorFill(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c, | |||||
| const std::vector<int> &shape); | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GIVEN_TENSOR_FILL_PARSER_H | |||||
| @@ -20,28 +20,23 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS OnnxIdentityParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||||
| schema::CNodeT *op) { | |||||
| lite::PrimitiveC *OnnxIdentityParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||||
| const onnx::NodeProto &onnx_node) { | |||||
| MS_LOG(DEBUG) << "onnx IdentityParser"; | MS_LOG(DEBUG) << "onnx IdentityParser"; | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| std::unique_ptr<schema::IdentityT> attr = std::make_unique<schema::IdentityT>(); | |||||
| auto attr = std::make_unique<schema::IdentityT>(); | |||||
| if (attr == nullptr) { | if (attr == nullptr) { | ||||
| MS_LOG(ERROR) << "new op failed"; | MS_LOG(ERROR) << "new op failed"; | ||||
| return RET_NULL_PTR; | |||||
| return nullptr; | |||||
| } | } | ||||
| op->primitive->value.type = schema::PrimitiveType_Identity; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitive failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Identity; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | } | ||||
| OnnxNodeRegistrar g_onnxIdentityParser("Identity", new OnnxIdentityParser()); | OnnxNodeRegistrar g_onnxIdentityParser("Identity", new OnnxIdentityParser()); | ||||
| @@ -27,7 +27,7 @@ class OnnxIdentityParser : public OnnxNodeParser { | |||||
| OnnxIdentityParser() : OnnxNodeParser("Identity") {} | OnnxIdentityParser() : OnnxNodeParser("Identity") {} | ||||
| ~OnnxIdentityParser() override = default; | ~OnnxIdentityParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,23 +19,13 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS OnnxInstanceNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||||
| schema::CNodeT *op) { | |||||
| lite::PrimitiveC *OnnxInstanceNormParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||||
| const onnx::NodeProto &onnx_node) { | |||||
| MS_LOG(DEBUG) << "onnx InstanceNormParser"; | MS_LOG(DEBUG) << "onnx InstanceNormParser"; | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| std::unique_ptr<schema::InstanceNormT> attr = std::make_unique<schema::InstanceNormT>(); | |||||
| auto attr = std::make_unique<schema::InstanceNormT>(); | |||||
| if (attr == nullptr) { | if (attr == nullptr) { | ||||
| MS_LOG(ERROR) << "new op failed"; | MS_LOG(ERROR) << "new op failed"; | ||||
| return RET_NULL_PTR; | |||||
| return nullptr; | |||||
| } | } | ||||
| if (!onnx_node.attribute().empty()) { | if (!onnx_node.attribute().empty()) { | ||||
| @@ -44,10 +34,14 @@ STATUS OnnxInstanceNormParser::Parse(const onnx::GraphProto &onnx_graph, const o | |||||
| attr->epsilon = onnx_node_attr.f(); | attr->epsilon = onnx_node_attr.f(); | ||||
| } | } | ||||
| } | } | ||||
| op->primitive->value.type = schema::PrimitiveType_InstanceNorm; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitive failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_InstanceNorm; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | } | ||||
| OnnxNodeRegistrar g_onnxInstanceNormParser("InstanceNormalization", new OnnxInstanceNormParser()); | OnnxNodeRegistrar g_onnxInstanceNormParser("InstanceNormalization", new OnnxInstanceNormParser()); | ||||
| @@ -27,7 +27,7 @@ class OnnxInstanceNormParser : public OnnxNodeParser { | |||||
| OnnxInstanceNormParser() : OnnxNodeParser("InstanceNorm") {} | OnnxInstanceNormParser() : OnnxNodeParser("InstanceNorm") {} | ||||
| ~OnnxInstanceNormParser() override = default; | ~OnnxInstanceNormParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -18,23 +18,13 @@ | |||||
| #include <memory> | #include <memory> | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| STATUS OnnxLpNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||||
| schema::CNodeT *op) { | |||||
| lite::PrimitiveC *OnnxLpNormParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||||
| const onnx::NodeProto &onnx_node) { | |||||
| MS_LOG(DEBUG) << "onnx LpNormParser"; | MS_LOG(DEBUG) << "onnx LpNormParser"; | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| std::unique_ptr<schema::LpNormalizationT> attr = std::make_unique<schema::LpNormalizationT>(); | |||||
| auto attr = std::make_unique<schema::LpNormalizationT>(); | |||||
| if (attr == nullptr) { | if (attr == nullptr) { | ||||
| MS_LOG(ERROR) << "new op failed"; | MS_LOG(ERROR) << "new op failed"; | ||||
| return RET_NULL_PTR; | |||||
| return nullptr; | |||||
| } | } | ||||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | for (const auto &onnx_node_attr : onnx_node.attribute()) { | ||||
| @@ -45,10 +35,14 @@ STATUS OnnxLpNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N | |||||
| attr->p = onnx_node_attr.i(); | attr->p = onnx_node_attr.i(); | ||||
| } | } | ||||
| } | } | ||||
| op->primitive->value.type = schema::PrimitiveType_LpNormalization; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitive failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_LpNormalization; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | } | ||||
| OnnxNodeRegistrar g_onnxLpNormParser("LpNormalization", new OnnxLpNormParser()); | OnnxNodeRegistrar g_onnxLpNormParser("LpNormalization", new OnnxLpNormParser()); | ||||
| @@ -27,7 +27,7 @@ class OnnxLpNormParser : public OnnxNodeParser { | |||||
| OnnxLpNormParser() : OnnxNodeParser("LpNorm") {} | OnnxLpNormParser() : OnnxNodeParser("LpNorm") {} | ||||
| ~OnnxLpNormParser() override = default; | ~OnnxLpNormParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -18,22 +18,13 @@ | |||||
| #include <memory> | #include <memory> | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| STATUS OnnxLrnParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { | |||||
| lite::PrimitiveC *OnnxLrnParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||||
| const onnx::NodeProto &onnx_node) { | |||||
| MS_LOG(DEBUG) << "onnx LrnParser"; | MS_LOG(DEBUG) << "onnx LrnParser"; | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| std::unique_ptr<schema::LocalResponseNormalizationT> attr = std::make_unique<schema::LocalResponseNormalizationT>(); | |||||
| auto attr = std::make_unique<schema::LocalResponseNormalizationT>(); | |||||
| if (attr == nullptr) { | if (attr == nullptr) { | ||||
| MS_LOG(ERROR) << "new op failed"; | MS_LOG(ERROR) << "new op failed"; | ||||
| return RET_NULL_PTR; | |||||
| return nullptr; | |||||
| } | } | ||||
| int32_t size = 0; | int32_t size = 0; | ||||
| @@ -53,13 +44,18 @@ STATUS OnnxLrnParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Node | |||||
| if (size == 0) { | if (size == 0) { | ||||
| MS_LOG(ERROR) << "Divide-by-zero error."; | MS_LOG(ERROR) << "Divide-by-zero error."; | ||||
| return RET_ERROR; | |||||
| return nullptr; | |||||
| } | } | ||||
| attr->alpha /= size; | attr->alpha /= size; | ||||
| op->primitive->value.type = schema::PrimitiveType_LocalResponseNormalization; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitive failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_LocalResponseNormalization; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | } | ||||
| OnnxNodeRegistrar g_onnxLrnxParser("Lrn", new OnnxLrnParser()); | OnnxNodeRegistrar g_onnxLrnxParser("Lrn", new OnnxLrnParser()); | ||||
| @@ -27,7 +27,7 @@ class OnnxLrnParser : public OnnxNodeParser { | |||||
| OnnxLrnParser() : OnnxNodeParser("Lrn") {} | OnnxLrnParser() : OnnxNodeParser("Lrn") {} | ||||
| ~OnnxLrnParser() override = default; | ~OnnxLrnParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,22 +19,13 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS OnnxLstmParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { | |||||
| lite::PrimitiveC *OnnxLstmParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||||
| const onnx::NodeProto &onnx_node) { | |||||
| MS_LOG(DEBUG) << "onnx LstmParser"; | MS_LOG(DEBUG) << "onnx LstmParser"; | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| std::unique_ptr<schema::LstmT> attr = std::make_unique<schema::LstmT>(); | |||||
| auto attr = std::make_unique<schema::LstmT>(); | |||||
| if (attr == nullptr) { | if (attr == nullptr) { | ||||
| MS_LOG(ERROR) << "new op failed"; | MS_LOG(ERROR) << "new op failed"; | ||||
| return RET_NULL_PTR; | |||||
| return nullptr; | |||||
| } | } | ||||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | for (const auto &onnx_node_attr : onnx_node.attribute()) { | ||||
| @@ -44,9 +35,14 @@ STATUS OnnxLstmParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod | |||||
| } | } | ||||
| } | } | ||||
| op->primitive->value.type = schema::PrimitiveType_Lstm; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitive failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Lstm; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | } | ||||
| OnnxNodeRegistrar g_onnxLstmParser("LSTM", new OnnxLstmParser()); | OnnxNodeRegistrar g_onnxLstmParser("LSTM", new OnnxLstmParser()); | ||||
| @@ -27,7 +27,7 @@ class OnnxLstmParser : public OnnxNodeParser { | |||||
| OnnxLstmParser() : OnnxNodeParser("LSTM") {} | OnnxLstmParser() : OnnxNodeParser("LSTM") {} | ||||
| ~OnnxLstmParser() override = default; | ~OnnxLstmParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,23 +19,13 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS OnnxMatmulParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||||
| schema::CNodeT *op) { | |||||
| lite::PrimitiveC *OnnxMatmulParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||||
| const onnx::NodeProto &onnx_node) { | |||||
| MS_LOG(DEBUG) << "onnx MatMulParser"; | MS_LOG(DEBUG) << "onnx MatMulParser"; | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| std::unique_ptr<schema::MatMulT> attr = std::make_unique<schema::MatMulT>(); | |||||
| auto attr = std::make_unique<schema::MatMulT>(); | |||||
| if (attr == nullptr) { | if (attr == nullptr) { | ||||
| MS_LOG(ERROR) << "new op failed"; | MS_LOG(ERROR) << "new op failed"; | ||||
| return RET_NULL_PTR; | |||||
| return nullptr; | |||||
| } | } | ||||
| float alpha = 1.0f; | float alpha = 1.0f; | ||||
| @@ -54,12 +44,17 @@ STATUS OnnxMatmulParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N | |||||
| } | } | ||||
| if (alpha != 1 || beta != 1) { | if (alpha != 1 || beta != 1) { | ||||
| MS_LOG(ERROR) << "not support alpha * A * B + beta * C"; | MS_LOG(ERROR) << "not support alpha * A * B + beta * C"; | ||||
| return RET_ERROR; | |||||
| return nullptr; | |||||
| } | } | ||||
| op->primitive->value.type = schema::PrimitiveType_MatMul; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitive failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_MatMul; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | } | ||||
| OnnxNodeRegistrar g_onnxMatmulParser("MatMul", new OnnxMatmulParser()); | OnnxNodeRegistrar g_onnxMatmulParser("MatMul", new OnnxMatmulParser()); | ||||
| @@ -27,7 +27,7 @@ class OnnxMatmulParser : public OnnxNodeParser { | |||||
| OnnxMatmulParser() : OnnxNodeParser("MatMul") {} | OnnxMatmulParser() : OnnxNodeParser("MatMul") {} | ||||
| ~OnnxMatmulParser() override = default; | ~OnnxMatmulParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -26,75 +26,57 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include <set> | #include <set> | ||||
| #include <map> | |||||
| #include <unordered_map> | |||||
| #include "securec/include/securec.h" | #include "securec/include/securec.h" | ||||
| #include "tools/converter/model_parser.h" | #include "tools/converter/model_parser.h" | ||||
| #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" | #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" | ||||
| #include "tools/converter/parser/onnx/onnx_tensor_parser.h" | |||||
| #include "proto/onnx.pb.h" | #include "proto/onnx.pb.h" | ||||
| #include "src/param_value_lite.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| class OnnxModelParser : public ModelParser { | class OnnxModelParser : public ModelParser { | ||||
| public: | public: | ||||
| OnnxModelParser(); | |||||
| OnnxModelParser() = default; | |||||
| virtual ~OnnxModelParser(); | |||||
| ~OnnxModelParser() override = default; | |||||
| // schema::MetaGraphT *ParseGraph(const onnx::GraphProto &graph, const QuantType &quantType = QuantType_QUANT_NONE); | |||||
| int ParseGraph(schema::MetaGraphT *dst_graph, schema::SubGraphT *dst_sub_graph, const onnx::GraphProto &onnx_graph, | |||||
| const QuantType &quantType); | |||||
| MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file, | |||||
| const QuantType &quant_type) override { | |||||
| return nullptr; | |||||
| } | |||||
| FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file, | |||||
| const QuantType &quant_type) override; | |||||
| static TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type); | static TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type); | ||||
| static STATUS CopyOnnxTensorData(const onnx::TensorProto &onnx_const_value, | |||||
| const ParamValueLitePtr ¶m_value_lite); | |||||
| private: | private: | ||||
| schema::MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file, | |||||
| const QuantType &quant_type = QuantType_QUANT_NONE) override; | |||||
| std::vector<int32_t> GetDimsFromOnnxValue(const onnx::ValueInfoProto &onnx_value); | |||||
| STATUS SetGraphConstTensor(const onnx::GraphProto &onnx_graph); | |||||
| STATUS SetGraphInputTensor(const onnx::GraphProto &onnx_graph, schema::SubGraphT *graph); | |||||
| STATUS SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, schema::SubGraphT *graph); | |||||
| STATUS AddValueInfo(const onnx::ValueInfoProto &proto, const std::string &name, const Category &type, int *index); | |||||
| STATUS AddTensorProto(const onnx::TensorProto &proto, const std::string &name, const Category &type, int *index); | |||||
| STATUS ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||||
| schema::CNodeT *dst_op, const QuantType &quantType, schema::MetaGraphT *dst_graph); | |||||
| void ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||||
| schema::SubGraphT *sub_graph, schema::MetaGraphT *graph, const QuantType &quant_type); | |||||
| STATUS ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node); | |||||
| STATUS ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||||
| const string &onnx_op_type, schema::CNodeT *dst_op); | |||||
| void SetOpQuantParams(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *dst_op, | |||||
| schema::TensorT *dst_tensor); | |||||
| STATUS SetOpInputIndex(const std::vector<string> &node_inputs, schema::CNodeT *dst_op, | |||||
| const onnx::NodeProto &onnx_node); | |||||
| STATUS SetOpOutputIndex(const std::vector<string> &node_outputs, schema::CNodeT *dst_op); | |||||
| STATUS CopyOnnxTensorData(const onnx::TensorProto &onnx_init_value, schema::TensorT *tensor); | |||||
| STATUS SetAllTensors(schema::MetaGraphT *graphDef); | |||||
| void FindGraphInputAndConst(const onnx::GraphProto &onnx_graph); | |||||
| STATUS ParseSubgraph(schema::CNodeT *dst_op, const onnx::NodeProto &onnx_node, const QuantType &quantType, | |||||
| schema::MetaGraphT *dst_graph); | |||||
| private: | |||||
| std::vector<std::string> graphInputNames; | |||||
| std::vector<std::string> graphConstNames; | |||||
| int subGraphNum = 0; | |||||
| STATUS InitOriginModel(const std::string &model_file); | |||||
| STATUS ConvertNodes(); | |||||
| STATUS ConvertConstTensors(); | |||||
| STATUS ConvertGraphInputs(); | |||||
| STATUS ConvertGraphOutputs(); | |||||
| STATUS BuildReturnNode(const std::vector<AnfNodePtr> &return_inputs); | |||||
| STATUS BuildParameterNode(const ParameterPtr ¶meter_node, const onnx::TensorProto &tensor); | |||||
| STATUS BuildParameterNodeForQuantParam(void *data, const std::string &name, TypeId type); | |||||
| STATUS BuildCNode(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c); | |||||
| STATUS BuildOpOutputs(const onnx::NodeProto &onnx_node, const CNodePtr &cnode); | |||||
| STATUS ConvertSpecialOnnxNode(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c); | |||||
| STATUS ConvertOnnxGemmNode(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c); | |||||
| STATUS BuildCNodeForGemm(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c, const std::string &name); | |||||
| STATUS ConvertOpQuantParams(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c); | |||||
| STATUS ParseQuantParam(const onnx::NodeProto &onnx_node); | |||||
| STATUS SetTensorQuantParam(const std::string &tensor_name, std::vector<QuantParamT> *quant_params); | |||||
| STATUS SetTensorQuantParamFromNode(const std::string &tensor_name, std::vector<QuantParamT> *quant_params); | |||||
| STATUS CopyTensorQuantParam(const std::string &tensor_name, QuantParamT *quant_param, bool scale_or_not); | |||||
| bool IsSpecialOnnxNode(const onnx::NodeProto &onnx_node); | |||||
| onnx::ModelProto onnx_model_; | |||||
| onnx::GraphProto onnx_graph_; | |||||
| std::unordered_map<std::string, AnfNodePtr> nodes_; | |||||
| FuncGraphPtr func_graph_ptr_ = nullptr; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -15,7 +15,9 @@ | |||||
| */ | */ | ||||
| #include "tools/converter/parser/onnx/onnx_node_parser.h" | #include "tools/converter/parser/onnx/onnx_node_parser.h" | ||||
| #include <algorithm> | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | |||||
| #include "tools/converter/parser/onnx/onnx_model_parser.h" | #include "tools/converter/parser/onnx/onnx_model_parser.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -20,6 +20,7 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <utility> | #include <utility> | ||||
| #include <vector> | #include <vector> | ||||
| #include "src/ops/primitive_c.h" | |||||
| #include "google/protobuf/message.h" | #include "google/protobuf/message.h" | ||||
| #include "proto/onnx.pb.h" | #include "proto/onnx.pb.h" | ||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| @@ -34,7 +35,8 @@ class OnnxNodeParser { | |||||
| virtual ~OnnxNodeParser() = default; | virtual ~OnnxNodeParser() = default; | ||||
| virtual STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) = 0; | |||||
| virtual lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||||
| const onnx::NodeProto &onnx_node) = 0; | |||||
| static STATUS GetTensorDataFromOnnx(const onnx::TensorProto &onnx_tensor, std::vector<float> *value, int *type); | static STATUS GetTensorDataFromOnnx(const onnx::TensorProto &onnx_tensor, std::vector<float> *value, int *type); | ||||
| @@ -19,23 +19,13 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS OnnxNonMaxSuppressionParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||||
| schema::CNodeT *op) { | |||||
| lite::PrimitiveC *OnnxNonMaxSuppressionParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||||
| const onnx::NodeProto &onnx_node) { | |||||
| MS_LOG(DEBUG) << "onnx EluParser"; | MS_LOG(DEBUG) << "onnx EluParser"; | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| std::unique_ptr<schema::NonMaxSuppressionT> attr = std::make_unique<schema::NonMaxSuppressionT>(); | |||||
| auto attr = std::make_unique<schema::NonMaxSuppressionT>(); | |||||
| if (attr == nullptr) { | if (attr == nullptr) { | ||||
| MS_LOG(ERROR) << "new op failed"; | MS_LOG(ERROR) << "new op failed"; | ||||
| return RET_NULL_PTR; | |||||
| return nullptr; | |||||
| } | } | ||||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | for (const auto &onnx_node_attr : onnx_node.attribute()) { | ||||
| @@ -47,9 +37,14 @@ STATUS OnnxNonMaxSuppressionParser::Parse(const onnx::GraphProto &onnx_graph, co | |||||
| } | } | ||||
| } | } | ||||
| op->primitive->value.type = schema::PrimitiveType_NonMaxSuppression; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitive failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_NonMaxSuppression; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | } | ||||
| OnnxNodeRegistrar g_onnxNonMaxSuppressionParser("NonMaxSuppression", new OnnxNonMaxSuppressionParser()); | OnnxNodeRegistrar g_onnxNonMaxSuppressionParser("NonMaxSuppression", new OnnxNonMaxSuppressionParser()); | ||||
| @@ -27,7 +27,7 @@ class OnnxNonMaxSuppressionParser : public OnnxNodeParser { | |||||
| OnnxNonMaxSuppressionParser() : OnnxNodeParser("NonMaxSuppression") {} | OnnxNonMaxSuppressionParser() : OnnxNodeParser("NonMaxSuppression") {} | ||||
| ~OnnxNonMaxSuppressionParser() override = default; | ~OnnxNonMaxSuppressionParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,23 +19,13 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS OnnxOneHotParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||||
| schema::CNodeT *op) { | |||||
| lite::PrimitiveC *OnnxOneHotParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||||
| const onnx::NodeProto &onnx_node) { | |||||
| MS_LOG(DEBUG) << "onnx OneHotParser"; | MS_LOG(DEBUG) << "onnx OneHotParser"; | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| std::unique_ptr<schema::OneHotT> attr = std::make_unique<schema::OneHotT>(); | |||||
| auto attr = std::make_unique<schema::OneHotT>(); | |||||
| if (attr == nullptr) { | if (attr == nullptr) { | ||||
| MS_LOG(ERROR) << "new op failed"; | MS_LOG(ERROR) << "new op failed"; | ||||
| return RET_NULL_PTR; | |||||
| return nullptr; | |||||
| } | } | ||||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | for (const auto &onnx_node_attr : onnx_node.attribute()) { | ||||
| @@ -45,9 +35,14 @@ STATUS OnnxOneHotParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N | |||||
| } | } | ||||
| } | } | ||||
| op->primitive->value.type = schema::PrimitiveType_OneHot; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitive failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_OneHot; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | } | ||||
| OnnxNodeRegistrar g_onnxOneHotParser("OneHot", new OnnxOneHotParser()); | OnnxNodeRegistrar g_onnxOneHotParser("OneHot", new OnnxOneHotParser()); | ||||
| @@ -27,7 +27,7 @@ class OnnxOneHotParser : public OnnxNodeParser { | |||||
| OnnxOneHotParser() : OnnxNodeParser("OneHot") {} | OnnxOneHotParser() : OnnxNodeParser("OneHot") {} | ||||
| ~OnnxOneHotParser() override = default; | ~OnnxOneHotParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,22 +19,13 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS OnnxPadParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { | |||||
| lite::PrimitiveC *OnnxPadParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||||
| const onnx::NodeProto &onnx_node) { | |||||
| MS_LOG(DEBUG) << "onnx PadParser"; | MS_LOG(DEBUG) << "onnx PadParser"; | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| std::unique_ptr<schema::PadT> attr = std::make_unique<schema::PadT>(); | |||||
| auto attr = std::make_unique<schema::PadT>(); | |||||
| if (attr == nullptr) { | if (attr == nullptr) { | ||||
| MS_LOG(ERROR) << "new op failed"; | MS_LOG(ERROR) << "new op failed"; | ||||
| return RET_NULL_PTR; | |||||
| return nullptr; | |||||
| } | } | ||||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | for (const auto &onnx_node_attr : onnx_node.attribute()) { | ||||
| @@ -58,9 +49,14 @@ STATUS OnnxPadParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Node | |||||
| } | } | ||||
| } | } | ||||
| op->primitive->value.type = schema::PrimitiveType_Pad; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitive failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Pad; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | } | ||||
| OnnxNodeRegistrar g_onnxPadParser("Pad", new OnnxPadParser()); | OnnxNodeRegistrar g_onnxPadParser("Pad", new OnnxPadParser()); | ||||
| @@ -27,7 +27,7 @@ class OnnxPadParser : public OnnxNodeParser { | |||||
| OnnxPadParser() : OnnxNodeParser("Pad") {} | OnnxPadParser() : OnnxNodeParser("Pad") {} | ||||
| ~OnnxPadParser() override = default; | ~OnnxPadParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,22 +19,13 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS OnnxPoolParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { | |||||
| lite::PrimitiveC *OnnxPoolParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||||
| const onnx::NodeProto &onnx_node) { | |||||
| MS_LOG(DEBUG) << "onnx PoolParser"; | MS_LOG(DEBUG) << "onnx PoolParser"; | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| std::unique_ptr<schema::PoolingT> attr = std::make_unique<schema::PoolingT>(); | |||||
| auto attr = std::make_unique<schema::PoolingT>(); | |||||
| if (attr == nullptr) { | if (attr == nullptr) { | ||||
| MS_LOG(ERROR) << "new op failed"; | MS_LOG(ERROR) << "new op failed"; | ||||
| return RET_NULL_PTR; | |||||
| return nullptr; | |||||
| } | } | ||||
| attr->format = schema::Format::Format_NCHW; | attr->format = schema::Format::Format_NCHW; | ||||
| @@ -56,7 +47,7 @@ STATUS OnnxPoolParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod | |||||
| attr->global = false; | attr->global = false; | ||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Pooling param`s PoolingMode is not MAX either AVE. MindSpore support MAX and AVE only."; | MS_LOG(ERROR) << "Pooling param`s PoolingMode is not MAX either AVE. MindSpore support MAX and AVE only."; | ||||
| return RET_ERROR; | |||||
| return nullptr; | |||||
| } | } | ||||
| attr->roundMode = schema::RoundMode_FLOOR; | attr->roundMode = schema::RoundMode_FLOOR; | ||||
| @@ -101,13 +92,18 @@ STATUS OnnxPoolParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod | |||||
| } | } | ||||
| if (attribute_name == "dilations") { | if (attribute_name == "dilations") { | ||||
| MS_LOG(ERROR) << "pooling op not support dilations now"; | MS_LOG(ERROR) << "pooling op not support dilations now"; | ||||
| return RET_ERROR; | |||||
| return nullptr; | |||||
| } | } | ||||
| } | } | ||||
| op->primitive->value.type = schema::PrimitiveType_Pooling; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitive failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Pooling; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | } | ||||
| OnnxNodeRegistrar g_onnxMaxPoolParser("MaxPool", new OnnxPoolParser()); | OnnxNodeRegistrar g_onnxMaxPoolParser("MaxPool", new OnnxPoolParser()); | ||||
| @@ -27,7 +27,7 @@ class OnnxPoolParser : public OnnxNodeParser { | |||||
| OnnxPoolParser() : OnnxNodeParser("Pool") {} | OnnxPoolParser() : OnnxNodeParser("Pool") {} | ||||
| ~OnnxPoolParser() override = default; | ~OnnxPoolParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,23 +19,13 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS OnnxQuantizeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||||
| schema::CNodeT *op) { | |||||
| lite::PrimitiveC *OnnxQuantizeParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||||
| const onnx::NodeProto &onnx_node) { | |||||
| MS_LOG(DEBUG) << "onnx QuantizeDequantizeParser"; | MS_LOG(DEBUG) << "onnx QuantizeDequantizeParser"; | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| std::unique_ptr<schema::QuantDTypeCastT> attr = std::make_unique<schema::QuantDTypeCastT>(); | |||||
| auto attr = std::make_unique<schema::QuantDTypeCastT>(); | |||||
| if (attr == nullptr) { | if (attr == nullptr) { | ||||
| MS_LOG(ERROR) << "new op failed."; | MS_LOG(ERROR) << "new op failed."; | ||||
| return RET_NULL_PTR; | |||||
| return nullptr; | |||||
| } | } | ||||
| if (onnx_node.op_type() == "Int8Quantize") { | if (onnx_node.op_type() == "Int8Quantize") { | ||||
| attr->srcT = kNumberTypeFloat32; | attr->srcT = kNumberTypeFloat32; | ||||
| @@ -45,11 +35,16 @@ STATUS OnnxQuantizeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx: | |||||
| attr->dstT = kNumberTypeFloat32; | attr->dstT = kNumberTypeFloat32; | ||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Unsupported nodeType: " << onnx_node.op_type().c_str(); | MS_LOG(ERROR) << "Unsupported nodeType: " << onnx_node.op_type().c_str(); | ||||
| return RET_ERROR; | |||||
| return nullptr; | |||||
| } | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitive failed"; | |||||
| return nullptr; | |||||
| } | } | ||||
| op->primitive->value.type = schema::PrimitiveType_QuantDTypeCast; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| primitive->value.type = schema::PrimitiveType_QuantDTypeCast; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | } | ||||
| OnnxNodeRegistrar g_onnxInt8QuantizeParser("Int8Quantize", new OnnxQuantizeParser()); | OnnxNodeRegistrar g_onnxInt8QuantizeParser("Int8Quantize", new OnnxQuantizeParser()); | ||||
| @@ -27,7 +27,7 @@ class OnnxQuantizeParser : public OnnxNodeParser { | |||||
| OnnxQuantizeParser() : OnnxNodeParser("Quantize") {} | OnnxQuantizeParser() : OnnxNodeParser("Quantize") {} | ||||
| ~OnnxQuantizeParser() override = default; | ~OnnxQuantizeParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,28 +19,23 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS OnnxRangeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||||
| schema::CNodeT *op) { | |||||
| lite::PrimitiveC *OnnxRangeParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||||
| const onnx::NodeProto &onnx_node) { | |||||
| MS_LOG(DEBUG) << "onnx RangeParser"; | MS_LOG(DEBUG) << "onnx RangeParser"; | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| std::unique_ptr<schema::RangeT> attr = std::make_unique<schema::RangeT>(); | |||||
| auto attr = std::make_unique<schema::RangeT>(); | |||||
| if (attr == nullptr) { | if (attr == nullptr) { | ||||
| MS_LOG(ERROR) << "new op failed"; | MS_LOG(ERROR) << "new op failed"; | ||||
| return RET_NULL_PTR; | |||||
| return nullptr; | |||||
| } | } | ||||
| attr->dType = 0; | attr->dType = 0; | ||||
| op->primitive->value.type = schema::PrimitiveType_Range; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitive failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Range; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | } | ||||
| OnnxNodeRegistrar g_onnxRangeParser("Range", new OnnxRangeParser()); | OnnxNodeRegistrar g_onnxRangeParser("Range", new OnnxRangeParser()); | ||||
| @@ -27,7 +27,7 @@ class OnnxRangeParser : public OnnxNodeParser { | |||||
| OnnxRangeParser() : OnnxNodeParser("Range") {} | OnnxRangeParser() : OnnxNodeParser("Range") {} | ||||
| ~OnnxRangeParser() override = default; | ~OnnxRangeParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,23 +19,13 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS OnnxReduceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||||
| schema::CNodeT *op) { | |||||
| lite::PrimitiveC *OnnxReduceParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||||
| const onnx::NodeProto &onnx_node) { | |||||
| MS_LOG(DEBUG) << "onnx ReduceParser"; | MS_LOG(DEBUG) << "onnx ReduceParser"; | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| std::unique_ptr<schema::ReduceT> attr = std::make_unique<schema::ReduceT>(); | |||||
| auto attr = std::make_unique<schema::ReduceT>(); | |||||
| if (attr == nullptr) { | if (attr == nullptr) { | ||||
| MS_LOG(ERROR) << "new op failed"; | MS_LOG(ERROR) << "new op failed"; | ||||
| return RET_NULL_PTR; | |||||
| return nullptr; | |||||
| } | } | ||||
| attr->keepDims = 1; | attr->keepDims = 1; | ||||
| @@ -65,12 +55,17 @@ STATUS OnnxReduceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N | |||||
| attr->mode = schema::ReduceMode_ReduceSumSquare; | attr->mode = schema::ReduceMode_ReduceSumSquare; | ||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "unsupported type"; | MS_LOG(ERROR) << "unsupported type"; | ||||
| return RET_ERROR; | |||||
| return nullptr; | |||||
| } | } | ||||
| op->primitive->value.type = schema::PrimitiveType_Reduce; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitive failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Reduce; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | } | ||||
| OnnxNodeRegistrar g_onnxReduceMeanParser("ReduceMean", new OnnxReduceParser()); | OnnxNodeRegistrar g_onnxReduceMeanParser("ReduceMean", new OnnxReduceParser()); | ||||
| @@ -27,7 +27,7 @@ class OnnxReduceParser : public OnnxNodeParser { | |||||
| OnnxReduceParser() : OnnxNodeParser("Reduce") {} | OnnxReduceParser() : OnnxNodeParser("Reduce") {} | ||||
| ~OnnxReduceParser() override = default; | ~OnnxReduceParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -21,22 +21,13 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS OnnxReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { | |||||
| lite::PrimitiveC *OnnxReluParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||||
| const onnx::NodeProto &onnx_node) { | |||||
| MS_LOG(DEBUG) << "onnx ReluParser"; | MS_LOG(DEBUG) << "onnx ReluParser"; | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| std::unique_ptr<schema::ActivationT> attr = std::make_unique<schema::ActivationT>(); | |||||
| auto attr = std::make_unique<schema::ActivationT>(); | |||||
| if (attr == nullptr) { | if (attr == nullptr) { | ||||
| MS_LOG(ERROR) << "new op failed"; | MS_LOG(ERROR) << "new op failed"; | ||||
| return RET_NULL_PTR; | |||||
| return nullptr; | |||||
| } | } | ||||
| const auto &relu_type = onnx_node.op_type(); | const auto &relu_type = onnx_node.op_type(); | ||||
| @@ -54,29 +45,24 @@ STATUS OnnxReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod | |||||
| } | } | ||||
| } | } | ||||
| op->primitive->value.type = schema::PrimitiveType_Activation; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitive failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Activation; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | } | ||||
| STATUS OnnxPReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||||
| schema::CNodeT *op) { | |||||
| lite::PrimitiveC *OnnxPReluParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||||
| const onnx::NodeProto &onnx_node) { | |||||
| MS_LOG(DEBUG) << "onnx PReluParser"; | MS_LOG(DEBUG) << "onnx PReluParser"; | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| if (onnx_node.input_size() != 2) { | if (onnx_node.input_size() != 2) { | ||||
| MS_LOG(ERROR) << "input num should be 2"; | MS_LOG(ERROR) << "input num should be 2"; | ||||
| return RET_ERROR; | |||||
| return nullptr; | |||||
| } | } | ||||
| std::unique_ptr<schema::PReLUT> attr = std::make_unique<schema::PReLUT>(); | |||||
| auto attr = std::make_unique<schema::PReLUT>(); | |||||
| std::vector<onnx::TensorProto> params; | std::vector<onnx::TensorProto> params; | ||||
| const auto &input_name = onnx_node.input(1); | const auto &input_name = onnx_node.input(1); | ||||
| for (const auto &it : onnx_graph.initializer()) { | for (const auto &it : onnx_graph.initializer()) { | ||||
| @@ -90,7 +76,7 @@ STATUS OnnxPReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No | |||||
| const onnx::TensorProto *slope = ¶ms[0]; | const onnx::TensorProto *slope = ¶ms[0]; | ||||
| if (slope == nullptr) { | if (slope == nullptr) { | ||||
| MS_LOG(ERROR) << "input error: params[0] is null"; | MS_LOG(ERROR) << "input error: params[0] is null"; | ||||
| return RET_ERROR; | |||||
| return nullptr; | |||||
| } | } | ||||
| const auto slope_raw_data = reinterpret_cast<const float *>(slope->raw_data().data()); | const auto slope_raw_data = reinterpret_cast<const float *>(slope->raw_data().data()); | ||||
| const int64_t slope_size = slope->raw_data().size() / sizeof(float); | const int64_t slope_size = slope->raw_data().size() / sizeof(float); | ||||
| @@ -102,16 +88,21 @@ STATUS OnnxPReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No | |||||
| attr->channelShared = false; | attr->channelShared = false; | ||||
| if (memcpy_s(attr->slope.data(), slope_size * sizeof(float), slope_raw_data, slope_size * sizeof(float)) != EOK) { | if (memcpy_s(attr->slope.data(), slope_size * sizeof(float), slope_raw_data, slope_size * sizeof(float)) != EOK) { | ||||
| MS_LOG(ERROR) << "memcpy_s failed"; | MS_LOG(ERROR) << "memcpy_s failed"; | ||||
| return RET_ERROR; | |||||
| return nullptr; | |||||
| } | } | ||||
| } | } | ||||
| } else { | } else { | ||||
| MS_LOG(WARNING) << "The slope pf prelu is null, which may cause errors."; | MS_LOG(WARNING) << "The slope pf prelu is null, which may cause errors."; | ||||
| } | } | ||||
| op->primitive->value.type = schema::PrimitiveType_PReLU; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitive failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_PReLU; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | } | ||||
| OnnxNodeRegistrar g_onnxReluParser("Relu", new OnnxReluParser()); | OnnxNodeRegistrar g_onnxReluParser("Relu", new OnnxReluParser()); | ||||
| @@ -27,7 +27,7 @@ class OnnxReluParser : public OnnxNodeParser { | |||||
| OnnxReluParser() : OnnxNodeParser("Relu") {} | OnnxReluParser() : OnnxNodeParser("Relu") {} | ||||
| ~OnnxReluParser() override = default; | ~OnnxReluParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| class OnnxPReluParser : public OnnxNodeParser { | class OnnxPReluParser : public OnnxNodeParser { | ||||
| @@ -35,7 +35,7 @@ class OnnxPReluParser : public OnnxNodeParser { | |||||
| OnnxPReluParser() : OnnxNodeParser("Prelu") {} | OnnxPReluParser() : OnnxNodeParser("Prelu") {} | ||||
| ~OnnxPReluParser() override = default; | ~OnnxPReluParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -20,23 +20,13 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS OnnxReshapeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||||
| schema::CNodeT *op) { | |||||
| lite::PrimitiveC *OnnxReshapeParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||||
| const onnx::NodeProto &onnx_node) { | |||||
| MS_LOG(DEBUG) << "onnx ReshapeParser"; | MS_LOG(DEBUG) << "onnx ReshapeParser"; | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| std::unique_ptr<schema::ReshapeT> attr = std::make_unique<schema::ReshapeT>(); | |||||
| auto attr = std::make_unique<schema::ReshapeT>(); | |||||
| if (attr == nullptr) { | if (attr == nullptr) { | ||||
| MS_LOG(ERROR) << "new op failed"; | MS_LOG(ERROR) << "new op failed"; | ||||
| return RET_NULL_PTR; | |||||
| return nullptr; | |||||
| } | } | ||||
| attr->format = schema::Format_NCHW; | attr->format = schema::Format_NCHW; | ||||
| @@ -51,28 +41,17 @@ STATUS OnnxReshapeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx:: | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| } else { | |||||
| onnx::TensorProto input_shape; | |||||
| const auto &shape_name = onnx_node.input(1); | |||||
| for (const auto &it : onnx_graph.initializer()) { | |||||
| if (it.name() == shape_name) { | |||||
| input_shape = it; | |||||
| break; | |||||
| } | |||||
| } | |||||
| if (input_shape.int64_data_size() == 0) { | |||||
| MS_LOG(INFO) << "shape maybe from another op other than const initializer"; | |||||
| } else { | |||||
| for (int i = 0; i < input_shape.int64_data_size(); ++i) { | |||||
| shape.push_back(input_shape.int64_data(i)); | |||||
| } | |||||
| } | |||||
| } | } | ||||
| attr->shape = shape; | attr->shape = shape; | ||||
| op->primitive->value.type = schema::PrimitiveType_Reshape; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitive failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Reshape; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | } | ||||
| OnnxNodeRegistrar g_onnxReshapeParser("Reshape", new OnnxReshapeParser()); | OnnxNodeRegistrar g_onnxReshapeParser("Reshape", new OnnxReshapeParser()); | ||||
| @@ -27,7 +27,7 @@ class OnnxReshapeParser : public OnnxNodeParser { | |||||
| OnnxReshapeParser() : OnnxNodeParser("Reshape") {} | OnnxReshapeParser() : OnnxNodeParser("Reshape") {} | ||||
| ~OnnxReshapeParser() override = default; | ~OnnxReshapeParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -22,23 +22,13 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS OnnxResizeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||||
| schema::CNodeT *op) { | |||||
| lite::PrimitiveC *OnnxResizeParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||||
| const onnx::NodeProto &onnx_node) { | |||||
| MS_LOG(DEBUG) << "onnx ResizeParser"; | MS_LOG(DEBUG) << "onnx ResizeParser"; | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| std::unique_ptr<schema::ResizeT> attr = std::make_unique<schema::ResizeT>(); | |||||
| auto attr = std::make_unique<schema::ResizeT>(); | |||||
| if (attr == nullptr) { | if (attr == nullptr) { | ||||
| MS_LOG(ERROR) << "new op failed"; | MS_LOG(ERROR) << "new op failed"; | ||||
| return RET_NULL_PTR; | |||||
| return nullptr; | |||||
| } | } | ||||
| attr->format = schema::Format_NCHW; | attr->format = schema::Format_NCHW; | ||||
| @@ -85,9 +75,14 @@ STATUS OnnxResizeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N | |||||
| } | } | ||||
| } | } | ||||
| op->primitive->value.type = schema::PrimitiveType_Resize; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitive failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Resize; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | } | ||||
| OnnxNodeRegistrar g_onnxResizeParser("Resize", new OnnxResizeParser()); | OnnxNodeRegistrar g_onnxResizeParser("Resize", new OnnxResizeParser()); | ||||
| @@ -27,7 +27,7 @@ class OnnxResizeParser : public OnnxNodeParser { | |||||
| OnnxResizeParser() : OnnxNodeParser("Resize") {} | OnnxResizeParser() : OnnxNodeParser("Resize") {} | ||||
| ~OnnxResizeParser() override = default; | ~OnnxResizeParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,28 +19,23 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS OnnxShapeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||||
| schema::CNodeT *op) { | |||||
| lite::PrimitiveC *OnnxShapeParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||||
| const onnx::NodeProto &onnx_node) { | |||||
| MS_LOG(DEBUG) << "onnx ShapeParser"; | MS_LOG(DEBUG) << "onnx ShapeParser"; | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| std::unique_ptr<schema::ShapeT> attr = std::make_unique<schema::ShapeT>(); | |||||
| auto attr = std::make_unique<schema::ShapeT>(); | |||||
| if (attr == nullptr) { | if (attr == nullptr) { | ||||
| MS_LOG(ERROR) << "new op failed"; | MS_LOG(ERROR) << "new op failed"; | ||||
| return RET_NULL_PTR; | |||||
| return nullptr; | |||||
| } | } | ||||
| op->primitive->value.type = schema::PrimitiveType_Shape; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitive failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Shape; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | } | ||||
| OnnxNodeRegistrar g_onnxShapeParser("Shape", new OnnxShapeParser()); | OnnxNodeRegistrar g_onnxShapeParser("Shape", new OnnxShapeParser()); | ||||
| @@ -27,7 +27,7 @@ class OnnxShapeParser : public OnnxNodeParser { | |||||
| OnnxShapeParser() : OnnxNodeParser("Shape") {} | OnnxShapeParser() : OnnxNodeParser("Shape") {} | ||||
| ~OnnxShapeParser() override = default; | ~OnnxShapeParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,30 +19,25 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS OnnxSigmoidParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||||
| schema::CNodeT *op) { | |||||
| lite::PrimitiveC *OnnxSigmoidParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||||
| const onnx::NodeProto &onnx_node) { | |||||
| MS_LOG(DEBUG) << "onnx SigmoidParser"; | MS_LOG(DEBUG) << "onnx SigmoidParser"; | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| std::unique_ptr<schema::ActivationT> attr = std::make_unique<schema::ActivationT>(); | |||||
| auto attr = std::make_unique<schema::ActivationT>(); | |||||
| if (attr == nullptr) { | if (attr == nullptr) { | ||||
| MS_LOG(ERROR) << "new op failed"; | MS_LOG(ERROR) << "new op failed"; | ||||
| return RET_NULL_PTR; | |||||
| return nullptr; | |||||
| } | } | ||||
| attr->type = schema::ActivationType_SIGMOID; | attr->type = schema::ActivationType_SIGMOID; | ||||
| op->primitive->value.type = schema::PrimitiveType_Activation; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitive failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Activation; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | } | ||||
| OnnxNodeRegistrar g_onnxSigmoodParser("Sigmoid", new OnnxSigmoidParser()); | OnnxNodeRegistrar g_onnxSigmoodParser("Sigmoid", new OnnxSigmoidParser()); | ||||
| @@ -27,7 +27,7 @@ class OnnxSigmoidParser : public OnnxNodeParser { | |||||
| OnnxSigmoidParser() : OnnxNodeParser("Sigmoid") {} | OnnxSigmoidParser() : OnnxNodeParser("Sigmoid") {} | ||||
| ~OnnxSigmoidParser() override = default; | ~OnnxSigmoidParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -23,77 +23,13 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS OnnxSliceParser::InsertTensor(const std::vector<int> &onnx_val, const std::string &name, | |||||
| onnx::NodeProto *onnx_node) { | |||||
| std::unique_ptr<schema::TensorT> tensor = std::make_unique<schema::TensorT>(); | |||||
| if (tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "new tensor failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| tensor->dataType = mindspore::kNumberTypeInt32; | |||||
| tensor->dims.push_back(onnx_val.size()); | |||||
| tensor->format = schema::Format::Format_NCHW; | |||||
| tensor->nodeType = schema::NodeType::NodeType_ValueNode; | |||||
| int data_size = sizeof(int32_t) * onnx_val.size(); | |||||
| tensor->data.resize(data_size); | |||||
| if (data_size != 0 && | |||||
| memcpy_s(static_cast<void *>(tensor->data.data()), data_size, onnx_val.data(), data_size) != EOK) { | |||||
| MS_LOG(ERROR) << "memcpy_s failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| int tensor_num = OnnxTensorParser::GetInstance()->GetTensorCache()->GetCachedTensor().size(); | |||||
| std::string tensor_name = name + std::to_string(tensor_num); | |||||
| OnnxTensorParser::GetInstance()->GetTensorCache()->AddTensor(tensor_name, tensor.release(), GRAPH_INPUT); | |||||
| onnx_node->add_input(tensor_name); | |||||
| return RET_OK; | |||||
| } | |||||
| STATUS OnnxSliceParser::GetInputTensor(std::vector<int> *onnx_val, const std::string &name) { | |||||
| if (onnx_val == nullptr) { | |||||
| MS_LOG(ERROR) << "input vector is nullptr."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (OnnxTensorParser::GetInstance() == nullptr || OnnxTensorParser::GetInstance()->GetTensorCache() == nullptr) { | |||||
| MS_LOG(ERROR) << "cannot get tensorcache."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| int index = OnnxTensorParser::GetInstance()->GetTensorCache()->FindTensor(name); | |||||
| if (index == -1) { | |||||
| MS_LOG(ERROR) << "can not find node: " << name; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto input_tensor = OnnxTensorParser::GetInstance()->GetTensorCache()->GetCachedTensor()[index]; | |||||
| if (input_tensor->data.empty()) { | |||||
| MS_LOG(DEBUG) << "data is empty."; | |||||
| return RET_NO_CHANGE; | |||||
| } | |||||
| int data_num = std::accumulate(input_tensor->dims.begin(), input_tensor->dims.end(), 1, std::multiplies<int>()); | |||||
| onnx_val->resize(data_num); | |||||
| if (memcpy_s(onnx_val->data(), data_num * sizeof(int32_t), input_tensor->data.data(), data_num * sizeof(int32_t)) != | |||||
| EOK) { | |||||
| MS_LOG(ERROR) << "memcpy_s failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||||
| schema::CNodeT *op) { | |||||
| lite::PrimitiveC *OnnxSliceParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||||
| const onnx::NodeProto &onnx_node) { | |||||
| MS_LOG(DEBUG) << "onnx SliceParser"; | MS_LOG(DEBUG) << "onnx SliceParser"; | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| std::unique_ptr<schema::StridedSliceT> attr = std::make_unique<schema::StridedSliceT>(); | |||||
| auto attr = std::make_unique<schema::StridedSliceT>(); | |||||
| if (attr == nullptr) { | if (attr == nullptr) { | ||||
| MS_LOG(ERROR) << "new op failed"; | MS_LOG(ERROR) << "new op failed"; | ||||
| return RET_NULL_PTR; | |||||
| return nullptr; | |||||
| } | } | ||||
| std::vector<int> starts; | std::vector<int> starts; | ||||
| @@ -128,36 +64,17 @@ STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| int status = RET_OK; | |||||
| switch (onnx_node.input_size()) { | |||||
| case 5: { | |||||
| if (steps.empty()) { | |||||
| status = GetInputTensor(&steps, onnx_node.input(4)); | |||||
| } | |||||
| } | |||||
| case 4: { | |||||
| if (status != RET_ERROR && axes.empty()) { | |||||
| status = GetInputTensor(&axes, onnx_node.input(3)); | |||||
| } | |||||
| } | |||||
| case 3: { | |||||
| if (status != RET_ERROR && ends.empty()) { | |||||
| status = GetInputTensor(&ends, onnx_node.input(2)); | |||||
| } | |||||
| } | |||||
| case 2: { | |||||
| if (status != RET_ERROR && starts.empty()) { | |||||
| status = GetInputTensor(&starts, onnx_node.input(1)); | |||||
| } | |||||
| } | |||||
| default: { | |||||
| if (status == RET_ERROR) { | |||||
| MS_LOG(ERROR) << "onnx slice inputs are invalid."; | |||||
| return RET_INPUT_TENSOR_ERROR; | |||||
| } | |||||
| } | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitive failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_StridedSlice; | |||||
| primitive->value.value = attr.release(); | |||||
| auto primitive_c = PrimitiveC::Create(primitive.release()); | |||||
| if (starts.empty()) { | |||||
| return primitive_c; | |||||
| } | } | ||||
| if (axes.empty()) { | if (axes.empty()) { | ||||
| for (size_t i = 0; i < starts.size(); ++i) { | for (size_t i = 0; i < starts.size(); ++i) { | ||||
| axes.push_back(i); | axes.push_back(i); | ||||
| @@ -166,42 +83,11 @@ STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No | |||||
| if (steps.empty()) { | if (steps.empty()) { | ||||
| steps.assign(starts.size(), 1); | steps.assign(starts.size(), 1); | ||||
| } | } | ||||
| onnx::NodeProto *slice_node = nullptr; | |||||
| for (auto &node : onnx_graph.node()) { | |||||
| if (&node == &onnx_node) { | |||||
| slice_node = const_cast<onnx::NodeProto *>(&node); | |||||
| } | |||||
| } | |||||
| int insert_num = 5 - onnx_node.input_size(); | |||||
| switch (insert_num) { | |||||
| case 4: { | |||||
| std::string name = "slice/starts/"; | |||||
| status = InsertTensor(starts, name, slice_node); | |||||
| } | |||||
| case 3: | |||||
| if (status == RET_OK) { | |||||
| std::string name = "slice/ends/"; | |||||
| status = InsertTensor(ends, name, slice_node); | |||||
| } | |||||
| case 2: | |||||
| if (status == RET_OK) { | |||||
| std::string name = "slice/axes/"; | |||||
| status = InsertTensor(axes, name, slice_node); | |||||
| } | |||||
| case 1: | |||||
| if (status == RET_OK) { | |||||
| std::string name = "slice/steps/"; | |||||
| status = InsertTensor(steps, name, slice_node); | |||||
| } | |||||
| default: | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "onnx slice insert tensor failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| op->primitive->value.type = schema::PrimitiveType_StridedSlice; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| primitive_c->set_attr("starts", MakeValue<std::vector<int>>(starts)); | |||||
| primitive_c->set_attr("ends", MakeValue<std::vector<int>>(ends)); | |||||
| primitive_c->set_attr("axes", MakeValue<std::vector<int>>(axes)); | |||||
| primitive_c->set_attr("steps", MakeValue<std::vector<int>>(steps)); | |||||
| return primitive_c; | |||||
| } | } | ||||
| OnnxNodeRegistrar g_onnxSliceParser("Slice", new OnnxSliceParser()); | OnnxNodeRegistrar g_onnxSliceParser("Slice", new OnnxSliceParser()); | ||||
| @@ -21,7 +21,6 @@ | |||||
| #include <string> | #include <string> | ||||
| #include "tools/converter/parser/onnx/onnx_node_parser.h" | #include "tools/converter/parser/onnx/onnx_node_parser.h" | ||||
| #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" | #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" | ||||
| #include "tools/converter/parser/onnx/onnx_tensor_parser.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| @@ -30,9 +29,7 @@ class OnnxSliceParser : public OnnxNodeParser { | |||||
| OnnxSliceParser() : OnnxNodeParser("Slice") {} | OnnxSliceParser() : OnnxNodeParser("Slice") {} | ||||
| ~OnnxSliceParser() override = default; | ~OnnxSliceParser() override = default; | ||||
| STATUS InsertTensor(const std::vector<int> &onnx_val, const std::string &name, onnx::NodeProto *onnx_node); | |||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| STATUS GetInputTensor(std::vector<int> *onnx_val, const std::string &name); | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,23 +19,13 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS OnnxSoftMaxParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||||
| schema::CNodeT *op) { | |||||
| lite::PrimitiveC *OnnxSoftMaxParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||||
| const onnx::NodeProto &onnx_node) { | |||||
| MS_LOG(DEBUG) << "onnx SoftMaxParser"; | MS_LOG(DEBUG) << "onnx SoftMaxParser"; | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| std::unique_ptr<schema::SoftMaxT> attr = std::make_unique<schema::SoftMaxT>(); | |||||
| auto attr = std::make_unique<schema::SoftMaxT>(); | |||||
| if (attr == nullptr) { | if (attr == nullptr) { | ||||
| MS_LOG(ERROR) << "new op failed"; | MS_LOG(ERROR) << "new op failed"; | ||||
| return RET_NULL_PTR; | |||||
| return nullptr; | |||||
| } | } | ||||
| bool axis_is_def = true; | bool axis_is_def = true; | ||||
| @@ -53,9 +43,14 @@ STATUS OnnxSoftMaxParser::Parse(const onnx::GraphProto &onnx_graph, const onnx:: | |||||
| attr->axis = 1; | attr->axis = 1; | ||||
| } | } | ||||
| } | } | ||||
| op->primitive->value.type = schema::PrimitiveType_SoftMax; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitive failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_SoftMax; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | } | ||||
| OnnxNodeRegistrar g_onnxSoftMaxParser("Softmax", new OnnxSoftMaxParser()); | OnnxNodeRegistrar g_onnxSoftMaxParser("Softmax", new OnnxSoftMaxParser()); | ||||
| @@ -27,7 +27,7 @@ class OnnxSoftMaxParser : public OnnxNodeParser { | |||||
| OnnxSoftMaxParser() : OnnxNodeParser("Softmax") {} | OnnxSoftMaxParser() : OnnxNodeParser("Softmax") {} | ||||
| ~OnnxSoftMaxParser() override = default; | ~OnnxSoftMaxParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,23 +19,13 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS OnnxSpaceToDepthParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||||
| schema::CNodeT *op) { | |||||
| lite::PrimitiveC *OnnxSpaceToDepthParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||||
| const onnx::NodeProto &onnx_node) { | |||||
| MS_LOG(DEBUG) << "onnx SpaceToDepthParser"; | MS_LOG(DEBUG) << "onnx SpaceToDepthParser"; | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| std::unique_ptr<schema::SpaceToDepthT> attr = std::make_unique<schema::SpaceToDepthT>(); | |||||
| auto attr = std::make_unique<schema::SpaceToDepthT>(); | |||||
| if (attr == nullptr) { | if (attr == nullptr) { | ||||
| MS_LOG(ERROR) << "new op failed"; | MS_LOG(ERROR) << "new op failed"; | ||||
| return RET_NULL_PTR; | |||||
| return nullptr; | |||||
| } | } | ||||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | for (const auto &onnx_node_attr : onnx_node.attribute()) { | ||||
| @@ -45,9 +35,14 @@ STATUS OnnxSpaceToDepthParser::Parse(const onnx::GraphProto &onnx_graph, const o | |||||
| } | } | ||||
| } | } | ||||
| op->primitive->value.type = schema::PrimitiveType_SpaceToDepth; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitive failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_SpaceToDepth; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | } | ||||
| OnnxNodeRegistrar g_onnxSpaceToDepthParser("SpaceToDepth", new OnnxSpaceToDepthParser()); | OnnxNodeRegistrar g_onnxSpaceToDepthParser("SpaceToDepth", new OnnxSpaceToDepthParser()); | ||||
| @@ -27,7 +27,7 @@ class OnnxSpaceToDepthParser : public OnnxNodeParser { | |||||
| OnnxSpaceToDepthParser() : OnnxNodeParser("SpaceToDepth") {} | OnnxSpaceToDepthParser() : OnnxNodeParser("SpaceToDepth") {} | ||||
| ~OnnxSpaceToDepthParser() override = default; | ~OnnxSpaceToDepthParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,23 +19,13 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS OnnxSplitParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||||
| schema::CNodeT *op) { | |||||
| lite::PrimitiveC *OnnxSplitParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||||
| const onnx::NodeProto &onnx_node) { | |||||
| MS_LOG(DEBUG) << "onnx SplitParser"; | MS_LOG(DEBUG) << "onnx SplitParser"; | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| std::unique_ptr<schema::SplitT> attr = std::make_unique<schema::SplitT>(); | |||||
| auto attr = std::make_unique<schema::SplitT>(); | |||||
| if (attr == nullptr) { | if (attr == nullptr) { | ||||
| MS_LOG(ERROR) << "new op failed"; | MS_LOG(ERROR) << "new op failed"; | ||||
| return RET_NULL_PTR; | |||||
| return nullptr; | |||||
| } | } | ||||
| attr->splitDim = 0; | attr->splitDim = 0; | ||||
| @@ -51,9 +41,14 @@ STATUS OnnxSplitParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No | |||||
| } | } | ||||
| } | } | ||||
| op->primitive->value.type = schema::PrimitiveType_Split; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitive failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Split; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | } | ||||
| OnnxNodeRegistrar g_onnxSplitParser("Split", new OnnxSplitParser()); | OnnxNodeRegistrar g_onnxSplitParser("Split", new OnnxSplitParser()); | ||||
| @@ -27,7 +27,7 @@ class OnnxSplitParser : public OnnxNodeParser { | |||||
| OnnxSplitParser() : OnnxNodeParser("Split") {} | OnnxSplitParser() : OnnxNodeParser("Split") {} | ||||
| ~OnnxSplitParser() override = default; | ~OnnxSplitParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,23 +19,13 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS OnnxSqueezeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||||
| schema::CNodeT *op) { | |||||
| lite::PrimitiveC *OnnxSqueezeParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||||
| const onnx::NodeProto &onnx_node) { | |||||
| MS_LOG(DEBUG) << "onnx SqueezeParser"; | MS_LOG(DEBUG) << "onnx SqueezeParser"; | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| std::unique_ptr<schema::SqueezeT> attr = std::make_unique<schema::SqueezeT>(); | |||||
| auto attr = std::make_unique<schema::SqueezeT>(); | |||||
| if (attr == nullptr) { | if (attr == nullptr) { | ||||
| MS_LOG(ERROR) << "new op failed"; | MS_LOG(ERROR) << "new op failed"; | ||||
| return RET_NULL_PTR; | |||||
| return nullptr; | |||||
| } | } | ||||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | for (const auto &onnx_node_attr : onnx_node.attribute()) { | ||||
| @@ -47,9 +37,14 @@ STATUS OnnxSqueezeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx:: | |||||
| } | } | ||||
| } | } | ||||
| op->primitive->value.type = schema::PrimitiveType_Squeeze; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitive failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Squeeze; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | } | ||||
| OnnxNodeRegistrar g_onnxSqueezeParser("Squeeze", new OnnxSqueezeParser()); | OnnxNodeRegistrar g_onnxSqueezeParser("Squeeze", new OnnxSqueezeParser()); | ||||
| @@ -27,7 +27,7 @@ class OnnxSqueezeParser : public OnnxNodeParser { | |||||
| OnnxSqueezeParser() : OnnxNodeParser("Squeeze") {} | OnnxSqueezeParser() : OnnxNodeParser("Squeeze") {} | ||||
| ~OnnxSqueezeParser() override = default; | ~OnnxSqueezeParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -20,26 +20,22 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS OnnxTileParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { | |||||
| lite::PrimitiveC *OnnxTileParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||||
| const onnx::NodeProto &onnx_node) { | |||||
| MS_LOG(DEBUG) << "onnx TileParser"; | MS_LOG(DEBUG) << "onnx TileParser"; | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| std::unique_ptr<schema::TileT> attr = std::make_unique<schema::TileT>(); | |||||
| auto attr = std::make_unique<schema::TileT>(); | |||||
| if (attr == nullptr) { | if (attr == nullptr) { | ||||
| MS_LOG(ERROR) << "new op failed"; | MS_LOG(ERROR) << "new op failed"; | ||||
| return RET_NULL_PTR; | |||||
| return nullptr; | |||||
| } | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitive failed"; | |||||
| return nullptr; | |||||
| } | } | ||||
| op->primitive->value.type = schema::PrimitiveType_Tile; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| primitive->value.type = schema::PrimitiveType_Tile; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | } | ||||
| OnnxNodeRegistrar g_onnxTileParser("Tile", new OnnxTileParser()); | OnnxNodeRegistrar g_onnxTileParser("Tile", new OnnxTileParser()); | ||||
| @@ -27,7 +27,7 @@ class OnnxTileParser : public OnnxNodeParser { | |||||
| OnnxTileParser() : OnnxNodeParser("Tile") {} | OnnxTileParser() : OnnxNodeParser("Tile") {} | ||||
| ~OnnxTileParser() override = default; | ~OnnxTileParser() override = default; | ||||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,22 +19,13 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS OnnxTopkParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { | |||||
| lite::PrimitiveC *OnnxTopkParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||||
| const onnx::NodeProto &onnx_node) { | |||||
| MS_LOG(DEBUG) << "onnx TopKParser"; | MS_LOG(DEBUG) << "onnx TopKParser"; | ||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| std::unique_ptr<schema::TopKT> attr = std::make_unique<schema::TopKT>(); | |||||
| auto attr = std::make_unique<schema::TopKT>(); | |||||
| if (attr == nullptr) { | if (attr == nullptr) { | ||||
| MS_LOG(ERROR) << "new op failed"; | MS_LOG(ERROR) << "new op failed"; | ||||
| return RET_NULL_PTR; | |||||
| return nullptr; | |||||
| } | } | ||||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | for (const auto &onnx_node_attr : onnx_node.attribute()) { | ||||
| @@ -44,9 +35,14 @@ STATUS OnnxTopkParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod | |||||
| } | } | ||||
| } | } | ||||
| op->primitive->value.type = schema::PrimitiveType_TopK; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitive failed"; | |||||
| return nullptr; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_TopK; | |||||
| primitive->value.value = attr.release(); | |||||
| return PrimitiveC::Create(primitive.release()); | |||||
| } | } | ||||
| OnnxNodeRegistrar g_onnxTopkParser("TopK", new OnnxTopkParser()); | OnnxNodeRegistrar g_onnxTopkParser("TopK", new OnnxTopkParser()); | ||||