| @@ -0,0 +1,36 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| #ifndef LITE_MINDSPORE_LITE_C_OPS_CONSTANT_H_ | |||
| #define LITE_MINDSPORE_LITE_C_OPS_CONSTANT_H_ | |||
| #include "src/ops/primitive_c.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class Constant : public PrimitiveC { | |||
| public: | |||
| Constant() = default; | |||
| ~Constant() = default; | |||
| MS_DECLARE_PARENT(Constant, PrimitiveC); | |||
| explicit Constant(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // LITE_MINDSPORE_LITE_C_OPS_CONSTANT_H_ | |||
| #endif | |||
| @@ -149,6 +149,7 @@ | |||
| #include "src/ops/oneslike.h" | |||
| #include "src/ops/unsorted_segment_sum.h" | |||
| #include "src/ops/reciprocal.h" | |||
| #include "src/ops/constant.h" | |||
| #ifdef SUPPORT_TRAIN | |||
| #include "src/ops/neg_grad.h" | |||
| @@ -186,7 +187,7 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| std::vector<int> CastToInt(const ValuePtr value) { | |||
| std::vector<int> CastToInt(const ValuePtr &value) { | |||
| if (value == nullptr) { | |||
| MS_LOG(WARNING) << "valueptr is nullptr."; | |||
| return {}; | |||
| @@ -903,6 +904,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { | |||
| return new (std::nothrow) Dequant(primitive); | |||
| case schema::PrimitiveType_Reciprocal: | |||
| return new (std::nothrow) Reciprocal(primitive); | |||
| case schema::PrimitiveType_Constant: | |||
| return new (std::nothrow) Constant(primitive); | |||
| #ifdef SUPPORT_TRAIN | |||
| case schema::PrimitiveType_ActivationGrad: | |||
| @@ -57,7 +57,7 @@ static std::map<std::string, schema::ActivationType> kActivationTypeMap{ | |||
| {"LeakyRelu", schema::ActivationType_LEAKY_RELU}, | |||
| {"Tanh", schema::ActivationType_TANH}, | |||
| {"Logistic", schema::ActivationType_SIGMOID}}; | |||
| std::vector<int> CastToInt(const ValuePtr value); | |||
| std::vector<int> CastToInt(const ValuePtr &value); | |||
| class PrimitiveC : public mindspore::Primitive { | |||
| public: | |||
| // Argument primitive is deliverd into PrimitiveC and will be deleted in ~PrimitiveC(). | |||
| @@ -204,6 +204,7 @@ if(ENABLE_CONVERTER) | |||
| ${LITE_DIR}/tools/optimizer/graph/infershape_pass.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/slice_prepose_pass.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/mindir_adjust_pass.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/onnx_inputs_adjust_pass.cc | |||
| ) | |||
| endif() | |||
| ### train | |||
| @@ -58,6 +58,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
| ../optimizer/graph/infershape_pass.cc | |||
| ../optimizer/graph/slice_prepose_pass.cc | |||
| ../optimizer/graph/mindir_adjust_pass.cc | |||
| ../optimizer/graph/onnx_inputs_adjust_pass.cc | |||
| ) | |||
| add_subdirectory(../anf_importer anf_importer) | |||
| @@ -36,6 +36,7 @@ | |||
| #include "tools/optimizer/graph/clip_convert_activation_pass.h" | |||
| #include "tools/optimizer/graph/group_depthwise_op_convert_pass.h" | |||
| #include "tools/optimizer/graph/tflite_inputs_order_exchange_pass.h" | |||
| #include "tools/optimizer/graph/onnx_inputs_adjust_pass.h" | |||
| #include "tools/optimizer/graph/update_conv2d_param_pass.h" | |||
| #include "tools/optimizer/graph/unused_cast_node_remove_pass.h" | |||
| #include "tools/optimizer/graph/unused_transpose_node_remove_pass.h" | |||
| @@ -74,6 +75,16 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver | |||
| } | |||
| } | |||
| // onnx pre adjustment | |||
| if (config->fmk == converter::FmkType_ONNX) { | |||
| auto onnx_adjust_pass = std::make_shared<opt::OnnxInputAdjustOpPass>(); | |||
| if (!onnx_adjust_pass->Run(old_graph)) { | |||
| MS_LOG(ERROR) << "onnx adjust failed."; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); | |||
| return nullptr; | |||
| } | |||
| } | |||
| // for now - trainning is not supporting fuse operations | |||
| if (!config->trainModel) { | |||
| // remove quantdtype when awaretraining | |||
| @@ -90,6 +90,7 @@ STATUS CaffeModelParser::ConvertLayers() { | |||
| auto primitive_c = node_parser->ParseLitePrimitive(layer, weight); | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "parse node " << layer.name() << " failed."; | |||
| status = RET_ERROR; | |||
| continue; | |||
| } | |||
| @@ -98,8 +99,7 @@ STATUS CaffeModelParser::ConvertLayers() { | |||
| status = ConvertBottom(layer, &input_nodes); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Convert layer bottom for " << layer.name() << " failed."; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return status; | |||
| continue; | |||
| } | |||
| // build weights | |||
| @@ -107,8 +107,7 @@ STATUS CaffeModelParser::ConvertLayers() { | |||
| status = ConvertBlobs(weight, &const_parameters); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Convert blobs for " << layer.name() << " failed."; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return status; | |||
| continue; | |||
| } | |||
| // build cnode | |||
| @@ -122,15 +121,13 @@ STATUS CaffeModelParser::ConvertLayers() { | |||
| status = ConvertTop(layer, new_cnode); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Convert outputs for " << layer.name() << " failed."; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return status; | |||
| continue; | |||
| } | |||
| status = ConvertLayerQuantParams(layer, weight, primitive_c); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Convert quant params for " << layer.name() << " failed."; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return status; | |||
| continue; | |||
| } | |||
| } | |||
| return status; | |||
| @@ -19,27 +19,22 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxAdderParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| lite::PrimitiveC *OnnxAdderParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node) { | |||
| MS_LOG(DEBUG) << "onnx AdderParser"; | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto attr = std::make_unique<schema::AdderT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| return nullptr; | |||
| } | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive failed"; | |||
| return nullptr; | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Adder; | |||
| op->primitive->value.value = attr.release(); | |||
| return RET_OK; | |||
| primitive->value.type = schema::PrimitiveType_Adder; | |||
| primitive->value.value = attr.release(); | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| OnnxNodeRegistrar g_onnxAdderParser("adder_f", new OnnxAdderParser()); | |||
| @@ -26,8 +26,7 @@ class OnnxAdderParser : public OnnxNodeParser { | |||
| public: | |||
| OnnxAdderParser() : OnnxNodeParser("Adder") {} | |||
| ~OnnxAdderParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -19,23 +19,14 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxArgMaxParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| lite::PrimitiveC *OnnxArgMaxParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node) { | |||
| MS_LOG(DEBUG) << "onnx ArgMaxParser"; | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::ArgMaxT> attr = std::make_unique<schema::ArgMaxT>(); | |||
| auto attr = std::make_unique<schema::ArgMaxT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| return nullptr; | |||
| } | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| @@ -46,10 +37,14 @@ STATUS OnnxArgMaxParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N | |||
| attr->keepDims = static_cast<bool>(onnx_node_attr.i()); | |||
| } | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_ArgMax; | |||
| op->primitive->value.value = attr.release(); | |||
| return RET_OK; | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive failed"; | |||
| return nullptr; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_ArgMax; | |||
| primitive->value.value = attr.release(); | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| OnnxNodeRegistrar g_onnxArgMaxParser("ArgMax", new OnnxArgMaxParser()); | |||
| @@ -27,7 +27,7 @@ class OnnxArgMaxParser : public OnnxNodeParser { | |||
| OnnxArgMaxParser() : OnnxNodeParser("ArgMax") {} | |||
| ~OnnxArgMaxParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -26,203 +26,203 @@ class OnnxAddParser : public OnnxNodeParser { | |||
| public: | |||
| OnnxAddParser() : OnnxNodeParser("Add") {} | |||
| ~OnnxAddParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| class OnnxSubParser : public OnnxNodeParser { | |||
| public: | |||
| OnnxSubParser() : OnnxNodeParser("Sub") {} | |||
| ~OnnxSubParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| class OnnxMulParser : public OnnxNodeParser { | |||
| public: | |||
| OnnxMulParser() : OnnxNodeParser("Mul") {} | |||
| ~OnnxMulParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| class OnnxDivParser : public OnnxNodeParser { | |||
| public: | |||
| OnnxDivParser() : OnnxNodeParser("Div") {} | |||
| ~OnnxDivParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| class OnnxPowParser : public OnnxNodeParser { | |||
| public: | |||
| OnnxPowParser() : OnnxNodeParser("Power") {} | |||
| ~OnnxPowParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| class OnnxEqualParser : public OnnxNodeParser { | |||
| public: | |||
| OnnxEqualParser() : OnnxNodeParser("Equal") {} | |||
| ~OnnxEqualParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| class OnnxLessParser : public OnnxNodeParser { | |||
| public: | |||
| OnnxLessParser() : OnnxNodeParser("Less") {} | |||
| ~OnnxLessParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| class OnnxGreaterParser : public OnnxNodeParser { | |||
| public: | |||
| OnnxGreaterParser() : OnnxNodeParser("Greater") {} | |||
| ~OnnxGreaterParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| class OnnxMinParser : public OnnxNodeParser { | |||
| public: | |||
| OnnxMinParser() : OnnxNodeParser("Min") {} | |||
| ~OnnxMinParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| class OnnxEltwiseParser : public OnnxNodeParser { | |||
| public: | |||
| OnnxEltwiseParser() : OnnxNodeParser("Eltwise") {} | |||
| ~OnnxEltwiseParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| class OnnxFloorParser : public OnnxNodeParser { | |||
| public: | |||
| OnnxFloorParser() : OnnxNodeParser("Floor") {} | |||
| ~OnnxFloorParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| class OnnxAbsParser : public OnnxNodeParser { | |||
| public: | |||
| OnnxAbsParser() : OnnxNodeParser("Abs") {} | |||
| ~OnnxAbsParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| class OnnxNegParser : public OnnxNodeParser { | |||
| public: | |||
| OnnxNegParser() : OnnxNodeParser("Neg") {} | |||
| ~OnnxNegParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| class OnnxExpParser : public OnnxNodeParser { | |||
| public: | |||
| OnnxExpParser() : OnnxNodeParser("Exp") {} | |||
| ~OnnxExpParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| class OnnxCosParser : public OnnxNodeParser { | |||
| public: | |||
| OnnxCosParser() : OnnxNodeParser("Cos") {} | |||
| ~OnnxCosParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| class OnnxSinParser : public OnnxNodeParser { | |||
| public: | |||
| OnnxSinParser() : OnnxNodeParser("Sin") {} | |||
| ~OnnxSinParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| class OnnxSqrtParser : public OnnxNodeParser { | |||
| public: | |||
| OnnxSqrtParser() : OnnxNodeParser("Sqrt") {} | |||
| ~OnnxSqrtParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| class OnnxCeilParser : public OnnxNodeParser { | |||
| public: | |||
| OnnxCeilParser() : OnnxNodeParser("Ceil") {} | |||
| ~OnnxCeilParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| class OnnxLogParser : public OnnxNodeParser { | |||
| public: | |||
| OnnxLogParser() : OnnxNodeParser("Log") {} | |||
| ~OnnxLogParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| class OnnxTanParser : public OnnxNodeParser { | |||
| public: | |||
| OnnxTanParser() : OnnxNodeParser("Tan") {} | |||
| ~OnnxTanParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| class OnnxAtanParser : public OnnxNodeParser { | |||
| public: | |||
| OnnxAtanParser() : OnnxNodeParser("Atan") {} | |||
| ~OnnxAtanParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| class OnnxAsinParser : public OnnxNodeParser { | |||
| public: | |||
| OnnxAsinParser() : OnnxNodeParser("Asin") {} | |||
| ~OnnxAsinParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| class OnnxTanhParser : public OnnxNodeParser { | |||
| public: | |||
| OnnxTanhParser() : OnnxNodeParser("Tanh") {} | |||
| ~OnnxTanhParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| class OnnxSignParser : public OnnxNodeParser { | |||
| public: | |||
| OnnxSignParser() : OnnxNodeParser("Sign") {} | |||
| ~OnnxSignParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| class OnnxAndParser : public OnnxNodeParser { | |||
| public: | |||
| OnnxAndParser() : OnnxNodeParser("And") {} | |||
| ~OnnxAndParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| class OnnxOrParser : public OnnxNodeParser { | |||
| public: | |||
| OnnxOrParser() : OnnxNodeParser("Or") {} | |||
| ~OnnxOrParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| class OnnxNotParser : public OnnxNodeParser { | |||
| public: | |||
| OnnxNotParser() : OnnxNodeParser("Not") {} | |||
| ~OnnxNotParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| class OnnxRoundParser : public OnnxNodeParser { | |||
| public: | |||
| OnnxRoundParser() : OnnxNodeParser("Round") {} | |||
| ~OnnxRoundParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| class OnnxReciprocalParser : public OnnxNodeParser { | |||
| public: | |||
| OnnxReciprocalParser() : OnnxNodeParser("Reciprocal") {} | |||
| ~OnnxReciprocalParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -19,23 +19,13 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxBatchNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| lite::PrimitiveC *OnnxBatchNormParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node) { | |||
| MS_LOG(DEBUG) << "onnx BatchNormParser"; | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::FusedBatchNormT> attr = std::make_unique<schema::FusedBatchNormT>(); | |||
| auto attr = std::make_unique<schema::FusedBatchNormT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| return nullptr; | |||
| } | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| @@ -47,10 +37,14 @@ STATUS OnnxBatchNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx | |||
| attr->spatial = static_cast<int32_t>(onnx_node_attr.i()); | |||
| } | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_FusedBatchNorm; | |||
| op->primitive->value.value = attr.release(); | |||
| return RET_OK; | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive failed"; | |||
| return nullptr; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_FusedBatchNorm; | |||
| primitive->value.value = attr.release(); | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| OnnxNodeRegistrar g_onnxBatchNormParser("BatchNormalization", new OnnxBatchNormParser()); | |||
| @@ -27,7 +27,7 @@ class OnnxBatchNormParser : public OnnxNodeParser { | |||
| OnnxBatchNormParser() : OnnxNodeParser("BatchNormalization") {} | |||
| ~OnnxBatchNormParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -19,30 +19,25 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxBiasAddParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| lite::PrimitiveC *OnnxBiasAddParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node) { | |||
| MS_LOG(DEBUG) << "onnx BiasAddParser"; | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::BiasAddT> attr = std::make_unique<schema::BiasAddT>(); | |||
| auto attr = std::make_unique<schema::BiasAddT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| return nullptr; | |||
| } | |||
| attr->axis = {1}; | |||
| op->primitive->value.type = schema::PrimitiveType_BiasAdd; | |||
| op->primitive->value.value = attr.release(); | |||
| return RET_OK; | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive failed"; | |||
| return nullptr; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_BiasAdd; | |||
| primitive->value.value = attr.release(); | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| OnnxNodeRegistrar g_onnxBiasAddParser("BiasAdd", new OnnxBiasAddParser()); | |||
| @@ -27,7 +27,7 @@ class OnnxBiasAddParser : public OnnxNodeParser { | |||
| OnnxBiasAddParser() : OnnxNodeParser("BiasAdd") {} | |||
| ~OnnxBiasAddParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -20,22 +20,13 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxCastParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { | |||
| lite::PrimitiveC *OnnxCastParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node) { | |||
| MS_LOG(DEBUG) << "onnx CastParser"; | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::CastT> attr = std::make_unique<schema::CastT>(); | |||
| auto attr = std::make_unique<schema::CastT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| return nullptr; | |||
| } | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| @@ -48,10 +39,14 @@ STATUS OnnxCastParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod | |||
| attr->dstT = static_cast<int>(dst_type); | |||
| } | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Cast; | |||
| op->primitive->value.value = attr.release(); | |||
| return RET_OK; | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive failed"; | |||
| return nullptr; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_Cast; | |||
| primitive->value.value = attr.release(); | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| OnnxNodeRegistrar g_onnxCastParser("Cast", new OnnxCastParser()); | |||
| @@ -27,7 +27,7 @@ class OnnxCastParser : public OnnxNodeParser { | |||
| OnnxCastParser() : OnnxNodeParser("Cast") {} | |||
| ~OnnxCastParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -19,39 +19,32 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxClipParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { | |||
| lite::PrimitiveC *OnnxClipParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node) { | |||
| MS_LOG(DEBUG) << "onnx ClipParser"; | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| auto attr = std::make_unique<schema::ClipT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return nullptr; | |||
| } | |||
| float min = -1, max = -1; | |||
| attr->max = -1; | |||
| attr->min = -1; | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| const auto &attribute_name = onnx_node_attr.name(); | |||
| if (attribute_name == "max") { | |||
| max = onnx_node_attr.f(); | |||
| attr->max = onnx_node_attr.f(); | |||
| } else if (attribute_name == "min") { | |||
| min = onnx_node_attr.f(); | |||
| attr->min = onnx_node_attr.f(); | |||
| } | |||
| } | |||
| std::unique_ptr<schema::ClipT> attr = std::make_unique<schema::ClipT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive failed"; | |||
| return nullptr; | |||
| } | |||
| attr->max = max; | |||
| attr->min = min; | |||
| op->primitive->value.type = schema::PrimitiveType_Clip; | |||
| op->primitive->value.value = attr.release(); | |||
| return RET_OK; | |||
| primitive->value.type = schema::PrimitiveType_Clip; | |||
| primitive->value.value = attr.release(); | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| OnnxNodeRegistrar g_onnxClipParser("Clip", new OnnxClipParser()); | |||
| @@ -27,7 +27,7 @@ class OnnxClipParser : public OnnxNodeParser { | |||
| OnnxClipParser() : OnnxNodeParser("Clip") {} | |||
| ~OnnxClipParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -19,23 +19,13 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxConcatParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| lite::PrimitiveC *OnnxConcatParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node) { | |||
| MS_LOG(DEBUG) << "onnx ConcatParser"; | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::ConcatT> attr = std::make_unique<schema::ConcatT>(); | |||
| auto attr = std::make_unique<schema::ConcatT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| return nullptr; | |||
| } | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| @@ -44,10 +34,14 @@ STATUS OnnxConcatParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N | |||
| attr->axis = static_cast<int32_t>(onnx_node_attr.i()); | |||
| } | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Concat; | |||
| op->primitive->value.value = attr.release(); | |||
| return RET_OK; | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive failed"; | |||
| return nullptr; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_Concat; | |||
| primitive->value.value = attr.release(); | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| OnnxNodeRegistrar g_onnxConcatParser("Concat", new OnnxConcatParser()); | |||
| @@ -27,7 +27,7 @@ class OnnxConcatParser : public OnnxNodeParser { | |||
| OnnxConcatParser() : OnnxNodeParser("Concat") {} | |||
| ~OnnxConcatParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -20,23 +20,13 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxConstantOfShapeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| lite::PrimitiveC *OnnxConstantOfShapeParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node) { | |||
| MS_LOG(DEBUG) << "onnx ConstantOfShapeParser"; | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::ConstantOfShapeT> attr = std::make_unique<schema::ConstantOfShapeT>(); | |||
| auto attr = std::make_unique<schema::ConstantOfShapeT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| return nullptr; | |||
| } | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| @@ -55,19 +45,24 @@ STATUS OnnxConstantOfShapeParser::Parse(const onnx::GraphProto &onnx_graph, cons | |||
| const auto &tensor = onnx_node_attr.t(); | |||
| auto ret = GetTensorDataFromOnnx(tensor, &attr->value, &attr->dataType); | |||
| if (ret != RET_OK) { | |||
| return ret; | |||
| MS_LOG(ERROR) << "get data from tensor failed"; | |||
| return nullptr; | |||
| } | |||
| } break; | |||
| default: | |||
| MS_LOG(ERROR) << "The data type is not supported."; | |||
| return RET_ERROR; | |||
| return nullptr; | |||
| } | |||
| } | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_ConstantOfShape; | |||
| op->primitive->value.value = attr.release(); | |||
| return RET_OK; | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive failed"; | |||
| return nullptr; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_ConstantOfShape; | |||
| primitive->value.value = attr.release(); | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| OnnxNodeRegistrar g_onnxConstantOfShapeParser("ConstantOfShape", new OnnxConstantOfShapeParser()); | |||
| @@ -27,7 +27,7 @@ class OnnxConstantOfShapeParser : public OnnxNodeParser { | |||
| OnnxConstantOfShapeParser() : OnnxNodeParser("ConstantOfShape") {} | |||
| ~OnnxConstantOfShapeParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -16,33 +16,75 @@ | |||
| #include "tools/converter/parser/onnx/onnx_constant_parser.h" | |||
| #include <memory> | |||
| #include <vector> | |||
| #include <algorithm> | |||
| #include "tools/converter/parser/onnx/onnx_model_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxConstantParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| MS_LOG(DEBUG) << "onnx ConstantParser"; | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| STATUS OnnxConstantParser::AddDataInfoAttr(const onnx::TensorProto &onnx_const_tensor, lite::PrimitiveC *primitive_c) { | |||
| ParamValueLitePtr param_value = std::make_shared<ParamValueLite>(); | |||
| if (param_value == nullptr) { | |||
| MS_LOG(ERROR) << "new a paramValueLite failed."; | |||
| return RET_ERROR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| auto data_type = | |||
| OnnxModelParser::GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(onnx_const_tensor.data_type())); | |||
| if (data_type == kTypeUnknown) { | |||
| MS_LOG(ERROR) << "not support onnx data type " | |||
| << static_cast<onnx::TensorProto_DataType>(onnx_const_tensor.data_type()); | |||
| return RET_ERROR; | |||
| } | |||
| std::unique_ptr<schema::ConstantT> attr = std::make_unique<schema::ConstantT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| std::vector<int64_t> shape_vector(onnx_const_tensor.dims().begin(), onnx_const_tensor.dims().end()); | |||
| std::vector<int> shape; | |||
| std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(shape), | |||
| [](const int64_t &val) { return static_cast<int32_t>(val); }); | |||
| param_value->set_tensor_type(data_type); | |||
| param_value->set_tensor_shape(shape); | |||
| param_value->set_format(schema::Format_NCHW); | |||
| if (OnnxModelParser::CopyOnnxTensorData(onnx_const_tensor, param_value) != RET_OK) { | |||
| MS_LOG(ERROR) << "get value failed."; | |||
| return RET_ERROR; | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Constant; | |||
| op->primitive->value.value = attr.release(); | |||
| primitive_c->set_attr("const_data", param_value); | |||
| return RET_OK; | |||
| } | |||
| lite::PrimitiveC *OnnxConstantParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node) { | |||
| MS_LOG(DEBUG) << "onnx ConstantParser"; | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive failed"; | |||
| return nullptr; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_Constant; | |||
| auto primitive_c = PrimitiveC::Create(primitive.release()); | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "create primitiveC failed."; | |||
| return nullptr; | |||
| } | |||
| for (const auto &attr : onnx_node.attribute()) { | |||
| if (attr.name() == "sparse_value") { | |||
| MS_LOG(WARNING) << "sparse_value"; | |||
| continue; | |||
| } | |||
| if (attr.name() == "value") { | |||
| const auto &const_tensor = attr.t(); | |||
| if (AddDataInfoAttr(const_tensor, primitive_c) != RET_OK) { | |||
| MS_LOG(ERROR) << "add basic attr failed."; | |||
| delete primitive_c; | |||
| return nullptr; | |||
| } | |||
| } else { | |||
| MS_LOG(ERROR) << "processing Constant op attr " << attr.name() << " not implemented"; | |||
| delete primitive_c; | |||
| return nullptr; | |||
| } | |||
| } | |||
| return primitive_c; | |||
| } | |||
| OnnxNodeRegistrar g_onnxConstantParser("Constant", new OnnxConstantParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -27,7 +27,8 @@ class OnnxConstantParser : public OnnxNodeParser { | |||
| OnnxConstantParser() : OnnxNodeParser("Constant") {} | |||
| ~OnnxConstantParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| STATUS AddDataInfoAttr(const onnx::TensorProto &onnx_const_tensor, lite::PrimitiveC *primitive_c); | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -21,9 +21,14 @@ | |||
| namespace mindspore::lite { | |||
| constexpr int32_t kSingleGroup = 1; | |||
| bool OnnxConvParser::ParseGroupConvolution(const std::unique_ptr<schema::Conv2DT> &attr, schema::CNodeT *op) { | |||
| bool OnnxConvParser::ParseGroupConvolution(const std::unique_ptr<schema::Conv2DT> &attr, | |||
| schema::PrimitiveT *primitive) { | |||
| MS_LOG(DEBUG) << "onnx DepthwiseConvParser"; | |||
| std::unique_ptr<schema::DepthwiseConv2DT> depthwiseConv2DParam = std::make_unique<schema::DepthwiseConv2DT>(); | |||
| if (attr == nullptr || primitive == nullptr) { | |||
| MS_LOG(ERROR) << "input parameter is nullptr"; | |||
| return false; | |||
| } | |||
| auto depthwiseConv2DParam = std::make_unique<schema::DepthwiseConv2DT>(); | |||
| if (depthwiseConv2DParam == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return false; | |||
| @@ -45,27 +50,18 @@ bool OnnxConvParser::ParseGroupConvolution(const std::unique_ptr<schema::Conv2DT | |||
| depthwiseConv2DParam->hasBias = attr->hasBias; | |||
| depthwiseConv2DParam->activationType = attr->activationType; | |||
| op->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; | |||
| op->primitive->value.value = depthwiseConv2DParam.release(); | |||
| primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; | |||
| primitive->value.value = depthwiseConv2DParam.release(); | |||
| return true; | |||
| } | |||
| STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { | |||
| lite::PrimitiveC *OnnxConvParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node) { | |||
| MS_LOG(DEBUG) << "onnx ConvParser"; | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::Conv2DT> attr = std::make_unique<schema::Conv2DT>(); | |||
| auto attr = std::make_unique<schema::Conv2DT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| return nullptr; | |||
| } | |||
| attr->strideH = 1; | |||
| @@ -83,21 +79,21 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod | |||
| } else if (onnx_node_attr.name() == "dilations") { | |||
| if (onnx_node_attr.ints().size() != 2) { | |||
| MS_LOG(ERROR) << "dilations size " << onnx_node_attr.ints().size() << " is not 2"; | |||
| return RET_ERROR; | |||
| return nullptr; | |||
| } | |||
| attr->dilateH = static_cast<int32_t>(onnx_node_attr.ints(0)); | |||
| attr->dilateW = static_cast<int32_t>(onnx_node_attr.ints(1)); | |||
| } else if (onnx_node_attr.name() == "kernels") { | |||
| if (onnx_node_attr.ints().size() != 2) { | |||
| MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2"; | |||
| return RET_ERROR; | |||
| return nullptr; | |||
| } | |||
| attr->kernelH = static_cast<int32_t>(onnx_node_attr.ints(0)); | |||
| attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(1)); | |||
| } else if (onnx_node_attr.name() == "kernel_shape") { | |||
| if (onnx_node_attr.ints().size() != 2) { | |||
| MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2"; | |||
| return RET_ERROR; | |||
| return nullptr; | |||
| } | |||
| attr->kernelH = static_cast<int32_t>(onnx_node_attr.ints(0)); | |||
| attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(1)); | |||
| @@ -106,7 +102,7 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod | |||
| } else if (onnx_node_attr.name() == "pads") { | |||
| if (onnx_node_attr.ints().size() != 4) { | |||
| MS_LOG(ERROR) << "pads size " << onnx_node_attr.ints().size() << " is not 4"; | |||
| return RET_ERROR; | |||
| return nullptr; | |||
| } | |||
| attr->padUp = static_cast<int32_t>(onnx_node_attr.ints(0)); | |||
| attr->padLeft = static_cast<int32_t>(onnx_node_attr.ints(1)); | |||
| @@ -115,7 +111,7 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod | |||
| } else if (onnx_node_attr.name() == "strides") { | |||
| if (onnx_node_attr.ints().size() != 2) { | |||
| MS_LOG(ERROR) << "strides size " << onnx_node_attr.ints().size() << " is not 2"; | |||
| return RET_ERROR; | |||
| return nullptr; | |||
| } | |||
| attr->strideH = static_cast<int32_t>(onnx_node_attr.ints(0)); | |||
| attr->strideW = static_cast<int32_t>(onnx_node_attr.ints(1)); | |||
| @@ -124,7 +120,7 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod | |||
| attr->format = schema::Format::Format_NHWC; | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported format: " << onnx_node_attr.s(); | |||
| return RET_ERROR; | |||
| return nullptr; | |||
| } | |||
| } | |||
| } | |||
| @@ -152,7 +148,7 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod | |||
| [onnx_conv_weight](const onnx::NodeProto &proto) { return proto.output(0) == onnx_conv_weight; }); | |||
| if (node_iter == onnx_graph.node().end()) { | |||
| MS_LOG(ERROR) << "can not find node: " << onnx_conv_weight; | |||
| return RET_ERROR; | |||
| return nullptr; | |||
| } | |||
| std::vector<int> dims; | |||
| auto iter = std::find_if((*node_iter).attribute().begin(), (*node_iter).attribute().end(), | |||
| @@ -160,7 +156,7 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod | |||
| if (iter != (*node_iter).attribute().end()) { | |||
| if (iter->ints().begin() == nullptr || iter->ints().end() == nullptr) { | |||
| MS_LOG(ERROR) << "dims insert failed"; | |||
| return RET_ERROR; | |||
| return nullptr; | |||
| } | |||
| dims.insert(dims.begin(), iter->ints().begin(), iter->ints().end()); | |||
| } | |||
| @@ -174,16 +170,21 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod | |||
| attr->activationType = schema::ActivationType_NO_ACTIVATION; | |||
| } | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive failed"; | |||
| return nullptr; | |||
| } | |||
| if (attr->group > kSingleGroup && attr->group == attr->channelIn) { | |||
| if (!ParseGroupConvolution(attr, op)) { | |||
| if (!ParseGroupConvolution(attr, primitive.get())) { | |||
| MS_LOG(ERROR) << "Convert Convolution to Depthwise failed"; | |||
| return RET_ERROR; | |||
| return nullptr; | |||
| } | |||
| } else { | |||
| op->primitive->value.type = schema::PrimitiveType_Conv2D; | |||
| op->primitive->value.value = attr.release(); | |||
| primitive->value.type = schema::PrimitiveType_Conv2D; | |||
| primitive->value.value = attr.release(); | |||
| } | |||
| return RET_OK; | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| OnnxNodeRegistrar g_onnxConvParser("Conv", new OnnxConvParser()); | |||
| @@ -28,10 +28,10 @@ class OnnxConvParser : public OnnxNodeParser { | |||
| OnnxConvParser() : OnnxNodeParser("Conv") {} | |||
| ~OnnxConvParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| private: | |||
| static bool ParseGroupConvolution(const std::unique_ptr<schema::Conv2DT> &attr, schema::CNodeT *op); | |||
| static bool ParseGroupConvolution(const std::unique_ptr<schema::Conv2DT> &attr, schema::PrimitiveT *primitive); | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -21,11 +21,13 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| bool OnnxDeConvParser::ParseGroupDeConvolution(const std::unique_ptr<schema::DeConv2DT> &attr, schema::CNodeT *op) { | |||
| if (attr == nullptr || attr->group != attr->channelOut) { | |||
| bool OnnxDeConvParser::ParseGroupDeConvolution(const std::unique_ptr<schema::DeConv2DT> &attr, | |||
| schema::PrimitiveT *primitive) { | |||
| if (attr == nullptr || attr->group != attr->channelOut || primitive == nullptr) { | |||
| MS_LOG(ERROR) << "input parameter is nullptr"; | |||
| return false; | |||
| } | |||
| std::unique_ptr<schema::DeDepthwiseConv2DT> deDepthwiseConv2DParam = std::make_unique<schema::DeDepthwiseConv2DT>(); | |||
| auto deDepthwiseConv2DParam = std::make_unique<schema::DeDepthwiseConv2DT>(); | |||
| if (deDepthwiseConv2DParam == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return false; | |||
| @@ -47,28 +49,18 @@ bool OnnxDeConvParser::ParseGroupDeConvolution(const std::unique_ptr<schema::DeC | |||
| deDepthwiseConv2DParam->hasBias = attr->hasBias; | |||
| deDepthwiseConv2DParam->activationType = attr->activationType; | |||
| op->primitive->value.type = schema::PrimitiveType_DeDepthwiseConv2D; | |||
| op->primitive->value.value = deDepthwiseConv2DParam.release(); | |||
| primitive->value.type = schema::PrimitiveType_DeDepthwiseConv2D; | |||
| primitive->value.value = deDepthwiseConv2DParam.release(); | |||
| return true; | |||
| } | |||
| STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| lite::PrimitiveC *OnnxDeConvParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node) { | |||
| MS_LOG(DEBUG) << "onnx DeConvParser"; | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::DeConv2DT> attr = std::make_unique<schema::DeConv2DT>(); | |||
| auto attr = std::make_unique<schema::DeConv2DT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| return nullptr; | |||
| } | |||
| attr->padMode = schema::PadMode_NOTSET; | |||
| @@ -83,21 +75,21 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N | |||
| } else if (onnx_node_attr.name() == "dilations") { | |||
| if (onnx_node_attr.ints().size() != 2) { | |||
| MS_LOG(ERROR) << "dilations size " << onnx_node_attr.ints().size() << " is not 2"; | |||
| return RET_ERROR; | |||
| return nullptr; | |||
| } | |||
| attr->dilateH = static_cast<int32_t>(onnx_node_attr.ints(0)); | |||
| attr->dilateW = static_cast<int32_t>(onnx_node_attr.ints(1)); | |||
| } else if (onnx_node_attr.name() == "kernels") { | |||
| if (onnx_node_attr.ints().size() != 2) { | |||
| MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2"; | |||
| return RET_ERROR; | |||
| return nullptr; | |||
| } | |||
| attr->kernelH = static_cast<int32_t>(onnx_node_attr.ints(0)); | |||
| attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(1)); | |||
| } else if (onnx_node_attr.name() == "kernel_shape") { | |||
| if (onnx_node_attr.ints().size() != 2) { | |||
| MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2"; | |||
| return RET_ERROR; | |||
| return nullptr; | |||
| } | |||
| attr->kernelH = static_cast<int32_t>(onnx_node_attr.ints(0)); | |||
| attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(1)); | |||
| @@ -106,7 +98,7 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N | |||
| } else if (onnx_node_attr.name() == "pads") { | |||
| if (onnx_node_attr.ints().size() != 4) { | |||
| MS_LOG(ERROR) << "pads size " << onnx_node_attr.ints().size() << " is not 4"; | |||
| return RET_ERROR; | |||
| return nullptr; | |||
| } | |||
| attr->padUp = static_cast<int32_t>(onnx_node_attr.ints(0)); | |||
| attr->padLeft = static_cast<int32_t>(onnx_node_attr.ints(1)); | |||
| @@ -115,7 +107,7 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N | |||
| } else if (onnx_node_attr.name() == "strides") { | |||
| if (onnx_node_attr.ints().size() != 2) { | |||
| MS_LOG(ERROR) << "strides size " << onnx_node_attr.ints().size() << " is not 2"; | |||
| return RET_ERROR; | |||
| return nullptr; | |||
| } | |||
| attr->strideH = static_cast<int32_t>(onnx_node_attr.ints(0)); | |||
| attr->strideW = static_cast<int32_t>(onnx_node_attr.ints(1)); | |||
| @@ -124,11 +116,11 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N | |||
| attr->format = schema::Format::Format_NHWC; | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported format: " << onnx_node_attr.s().c_str(); | |||
| return RET_ERROR; | |||
| return nullptr; | |||
| } | |||
| } else if (onnx_node_attr.name() == "output_padding") { | |||
| MS_LOG(ERROR) << "output_padding param hasn't been supported"; | |||
| return RET_NOT_SUPPORT; | |||
| return nullptr; | |||
| } | |||
| } | |||
| @@ -138,7 +130,7 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N | |||
| [onnx_conv_weight](const onnx::TensorProto &proto) { return proto.name() == onnx_conv_weight; }); | |||
| if (node_iter == onnx_graph.initializer().end()) { | |||
| MS_LOG(ERROR) << "not find node: " << onnx_conv_weight.c_str(); | |||
| return RET_ERROR; | |||
| return nullptr; | |||
| } | |||
| std::vector<int> weight_shape; | |||
| auto size = (*node_iter).dims_size(); | |||
| @@ -148,7 +140,7 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N | |||
| } | |||
| if (weight_shape.size() != 4) { | |||
| MS_LOG(ERROR) << "weight_shape.size() should be 4, but is " << weight_shape.size(); | |||
| return RET_ERROR; | |||
| return nullptr; | |||
| } | |||
| attr->channelIn = weight_shape[0]; | |||
| attr->channelOut = weight_shape[1] * attr->group; | |||
| @@ -156,17 +148,22 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N | |||
| attr->format = schema::Format::Format_NCHW; | |||
| attr->hasBias = onnx_node.input().size() == 3; | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive failed"; | |||
| return nullptr; | |||
| } | |||
| if (attr->group != 1) { | |||
| if (!ParseGroupDeConvolution(attr, op)) { | |||
| if (!ParseGroupDeConvolution(attr, primitive.get())) { | |||
| MS_LOG(ERROR) << "Convert DeConvolution to DeDepthwise failed, generalized group deconv hasn't support"; | |||
| return RET_NOT_SUPPORT; | |||
| return nullptr; | |||
| } | |||
| } else { | |||
| op->primitive->value.type = schema::PrimitiveType_DeConv2D; | |||
| op->primitive->value.value = attr.release(); | |||
| primitive->value.type = schema::PrimitiveType_DeConv2D; | |||
| primitive->value.value = attr.release(); | |||
| } | |||
| return RET_OK; | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| OnnxNodeRegistrar g_onnxDeConvParser("ConvTranspose", new OnnxDeConvParser()); | |||
| @@ -28,10 +28,10 @@ class OnnxDeConvParser : public OnnxNodeParser { | |||
| OnnxDeConvParser() : OnnxNodeParser("DeConv") {} | |||
| ~OnnxDeConvParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| private: | |||
| static bool ParseGroupDeConvolution(const std::unique_ptr<schema::DeConv2DT> &attr, schema::CNodeT *op); | |||
| bool ParseGroupDeConvolution(const std::unique_ptr<schema::DeConv2DT> &attr, schema::PrimitiveT *primitive); | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -19,23 +19,13 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxDepthToSpaceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| lite::PrimitiveC *OnnxDepthToSpaceParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node) { | |||
| MS_LOG(DEBUG) << "onnx DepthToSpaceParser"; | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::DepthToSpaceT> attr = std::make_unique<schema::DepthToSpaceT>(); | |||
| auto attr = std::make_unique<schema::DepthToSpaceT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| return nullptr; | |||
| } | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| @@ -44,10 +34,14 @@ STATUS OnnxDepthToSpaceParser::Parse(const onnx::GraphProto &onnx_graph, const o | |||
| attr->blockSize = static_cast<int32_t>(onnx_node_attr.i()); | |||
| } | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_DepthToSpace; | |||
| op->primitive->value.value = attr.release(); | |||
| return RET_OK; | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive failed"; | |||
| return nullptr; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_DepthToSpace; | |||
| primitive->value.value = attr.release(); | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| OnnxNodeRegistrar g_onnxDepthToSpaceParser("DepthToSpace", new OnnxDepthToSpaceParser()); | |||
| @@ -27,7 +27,7 @@ class OnnxDepthToSpaceParser : public OnnxNodeParser { | |||
| OnnxDepthToSpaceParser() : OnnxNodeParser("DepthToSpace") {} | |||
| ~OnnxDepthToSpaceParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -19,23 +19,13 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxDropoutParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| lite::PrimitiveC *OnnxDropoutParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node) { | |||
| MS_LOG(DEBUG) << "onnx DropoutParser"; | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::DropoutT> attr = std::make_unique<schema::DropoutT>(); | |||
| auto attr = std::make_unique<schema::DropoutT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| return nullptr; | |||
| } | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| @@ -44,10 +34,14 @@ STATUS OnnxDropoutParser::Parse(const onnx::GraphProto &onnx_graph, const onnx:: | |||
| attr->ratio = static_cast<float>(onnx_node_attr.f()); | |||
| } | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Dropout; | |||
| op->primitive->value.value = attr.release(); | |||
| return RET_OK; | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive failed"; | |||
| return nullptr; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_Dropout; | |||
| primitive->value.value = attr.release(); | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| OnnxNodeRegistrar g_onnxDropoutParser("Dropout", new OnnxDropoutParser()); | |||
| @@ -27,7 +27,7 @@ class OnnxDropoutParser : public OnnxNodeParser { | |||
| OnnxDropoutParser() : OnnxNodeParser("Dropout") {} | |||
| ~OnnxDropoutParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -19,22 +19,13 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxEluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { | |||
| lite::PrimitiveC *OnnxEluParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node) { | |||
| MS_LOG(DEBUG) << "onnx EluParser"; | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::EluT> attr = std::make_unique<schema::EluT>(); | |||
| auto attr = std::make_unique<schema::EluT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| return nullptr; | |||
| } | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| @@ -43,10 +34,14 @@ STATUS OnnxEluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Node | |||
| attr->alpha = onnx_node_attr.f(); | |||
| } | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Elu; | |||
| op->primitive->value.value = attr.release(); | |||
| return RET_OK; | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive failed"; | |||
| return nullptr; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_Elu; | |||
| primitive->value.value = attr.release(); | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| OnnxNodeRegistrar g_onnxEluParser("Elu", new OnnxEluParser()); | |||
| @@ -27,7 +27,7 @@ class OnnxEluParser : public OnnxNodeParser { | |||
| OnnxEluParser() : OnnxNodeParser("Elu") {} | |||
| ~OnnxEluParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -20,23 +20,13 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxExpandParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| lite::PrimitiveC *OnnxExpandParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node) { | |||
| MS_LOG(DEBUG) << "onnx ExpandParser"; | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::BroadcastToT> attr = std::make_unique<schema::BroadcastToT>(); | |||
| auto attr = std::make_unique<schema::BroadcastToT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| return nullptr; | |||
| } | |||
| std::vector<int> dst_shape; | |||
| @@ -46,7 +36,7 @@ STATUS OnnxExpandParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N | |||
| [onnx_expand_power](const onnx::NodeProto &proto) { return proto.output(0) == onnx_expand_power; }); | |||
| if (node_iter == onnx_graph.node().end()) { | |||
| MS_LOG(ERROR) << "can not find node: " << onnx_expand_power; | |||
| return RET_ERROR; | |||
| return nullptr; | |||
| } | |||
| for (const auto &attrPower : node_iter->attribute()) { | |||
| if (attrPower.name() == "value") { | |||
| @@ -58,9 +48,14 @@ STATUS OnnxExpandParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N | |||
| } | |||
| } | |||
| attr->dst_shape = dst_shape; | |||
| op->primitive->value.type = schema::PrimitiveType_BroadcastTo; | |||
| op->primitive->value.value = attr.release(); | |||
| return RET_OK; | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive failed"; | |||
| return nullptr; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_BroadcastTo; | |||
| primitive->value.value = attr.release(); | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| OnnxNodeRegistrar g_onnxExpandSpaceParser("Expand", new OnnxExpandParser()); | |||
| @@ -27,7 +27,7 @@ class OnnxExpandParser : public OnnxNodeParser { | |||
| OnnxExpandParser() : OnnxNodeParser("Expand") {} | |||
| ~OnnxExpandParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -19,23 +19,13 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxFlattenParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| lite::PrimitiveC *OnnxFlattenParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node) { | |||
| MS_LOG(DEBUG) << "onnx FlattenParser"; | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::ReshapeT> attr = std::make_unique<schema::ReshapeT>(); | |||
| auto attr = std::make_unique<schema::ReshapeT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| return nullptr; | |||
| } | |||
| int axis = 1; | |||
| @@ -49,10 +39,14 @@ STATUS OnnxFlattenParser::Parse(const onnx::GraphProto &onnx_graph, const onnx:: | |||
| attr->shape.emplace_back(0); | |||
| } | |||
| attr->shape.emplace_back(-1); | |||
| op->primitive->value.type = schema::PrimitiveType_Reshape; | |||
| op->primitive->value.value = attr.release(); | |||
| return RET_OK; | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive failed"; | |||
| return nullptr; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_Reshape; | |||
| primitive->value.value = attr.release(); | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| OnnxNodeRegistrar g_onnxFlattenParser("Flatten", new OnnxFlattenParser()); | |||
| @@ -27,7 +27,7 @@ class OnnxFlattenParser : public OnnxNodeParser { | |||
| OnnxFlattenParser() : OnnxNodeParser("Fatten") {} | |||
| ~OnnxFlattenParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -19,23 +19,13 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxGatherParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| lite::PrimitiveC *OnnxGatherParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node) { | |||
| MS_LOG(DEBUG) << "onnx GatherParser"; | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::GatherT> attr = std::make_unique<schema::GatherT>(); | |||
| auto attr = std::make_unique<schema::GatherT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| return nullptr; | |||
| } | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| @@ -45,9 +35,14 @@ STATUS OnnxGatherParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N | |||
| } | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Gather; | |||
| op->primitive->value.value = attr.release(); | |||
| return RET_OK; | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive failed"; | |||
| return nullptr; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_Gather; | |||
| primitive->value.value = attr.release(); | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| OnnxNodeRegistrar g_onnxGatherParser("Gather", new OnnxGatherParser()); | |||
| @@ -27,7 +27,7 @@ class OnnxGatherParser : public OnnxNodeParser { | |||
| OnnxGatherParser() : OnnxNodeParser("Gather") {} | |||
| ~OnnxGatherParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,55 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/onnx/onnx_gemm_parser.h" | |||
| #include <vector> | |||
| #include <memory> | |||
| namespace mindspore { | |||
| namespace lite { | |||
| lite::PrimitiveC *OnnxGemmParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node) { | |||
| MS_LOG(DEBUG) << "onnx IdentityParser"; | |||
| auto node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser("MatMul"); | |||
| if (node_parser == nullptr) { | |||
| MS_LOG(ERROR) << "parse node " << onnx_node.op_type() << " failed."; | |||
| return nullptr; | |||
| } | |||
| auto *matmul_primitive = node_parser->ParseLitePrimitive(onnx_graph, onnx_node); | |||
| node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser("BiasAdd"); | |||
| if (node_parser == nullptr) { | |||
| MS_LOG(ERROR) << "parse node " << onnx_node.op_type() << " failed."; | |||
| return nullptr; | |||
| } | |||
| auto *bias_add_primitive = node_parser->ParseLitePrimitive(onnx_graph, onnx_node); | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive failed"; | |||
| return nullptr; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_MakeTuple; | |||
| auto primitve_c = PrimitiveC::Create(primitive.release()); | |||
| primitve_c->set_attr("MatMul", std::shared_ptr<lite::PrimitiveC>(matmul_primitive)); | |||
| primitve_c->set_attr("BiasAdd", std::shared_ptr<lite::PrimitiveC>(bias_add_primitive)); | |||
| return primitve_c; | |||
| } | |||
| OnnxNodeRegistrar g_onnxGemmParser("Gemm", new OnnxGemmParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -14,26 +14,21 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_TENSOR_PARSER_H | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_TENSOR_PARSER_H | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GEMM_PARSER_H | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GEMM_PARSER_H | |||
| #include "tools/common/tensor_util.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class OnnxTensorParser { | |||
| class OnnxGemmParser : public OnnxNodeParser { | |||
| public: | |||
| ~OnnxTensorParser() = default; | |||
| static OnnxTensorParser *GetInstance() { | |||
| static OnnxTensorParser onnxTensorParser; | |||
| return &onnxTensorParser; | |||
| } | |||
| TensorCache *GetTensorCache() { return &tensor_cache_; } | |||
| OnnxGemmParser() : OnnxNodeParser("Gemm") {} | |||
| ~OnnxGemmParser() override = default; | |||
| private: | |||
| OnnxTensorParser() = default; | |||
| TensorCache tensor_cache_; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_TESNOR_PARSER_H | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GEMM_PARSER_H | |||
| @@ -0,0 +1,127 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/onnx/onnx_given_tensor_fill_parser.h" | |||
| #include <functional> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include <algorithm> | |||
| #include "src/param_value_lite.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxGivenTensorFillParser::ParseInt8GivenIntTensorFill(const onnx::NodeProto &onnx_node, | |||
| lite::PrimitiveC *primitive_c, | |||
| const std::vector<int> &shape) { | |||
| ParamValueLitePtr param_value = std::make_shared<ParamValueLite>(); | |||
| if (param_value == nullptr) { | |||
| MS_LOG(ERROR) << "new a paramValueLite failed."; | |||
| return RET_ERROR; | |||
| } | |||
| int data_count = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>()); | |||
| auto iter = std::find_if(onnx_node.attribute().begin(), onnx_node.attribute().end(), | |||
| [](const onnx::AttributeProto &attr) { return attr.name() == "values"; }); | |||
| if (iter == onnx_node.attribute().end()) { | |||
| return RET_OK; | |||
| } | |||
| size_t data_size = data_count * sizeof(int64_t) / sizeof(uint8_t); | |||
| char *param_data = new (std::nothrow) char[data_size]; | |||
| if (param_data == nullptr) { | |||
| MS_LOG(ERROR) << "new char[] failed"; | |||
| return RET_MEMORY_FAILED; | |||
| } | |||
| if (memcpy_s(param_data, data_size, iter->ints().data(), data_size) != EOK) { | |||
| MS_LOG(ERROR) << "memcpy data failed."; | |||
| delete[] param_data; | |||
| return RET_ERROR; | |||
| } | |||
| param_value->set_tensor_shape(shape); | |||
| param_value->set_format(schema::Format_NUM_OF_FORMAT); | |||
| param_value->set_tensor_type(kNumberTypeInt64); | |||
| param_value->SetTensorData(param_data, data_size); | |||
| primitive_c->set_attr("const_data", param_value); | |||
| return RET_OK; | |||
| } | |||
| STATUS OnnxGivenTensorFillParser::ParseInt8GivenTensorFill(const onnx::NodeProto &onnx_node, | |||
| lite::PrimitiveC *primitive_c, | |||
| const std::vector<int> &shape) { | |||
| ParamValueLitePtr param_value = std::make_shared<ParamValueLite>(); | |||
| if (param_value == nullptr) { | |||
| MS_LOG(ERROR) << "new a paramValueLite failed."; | |||
| return RET_ERROR; | |||
| } | |||
| int data_count = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int64_t>()); | |||
| auto iter = std::find_if(onnx_node.attribute().begin(), onnx_node.attribute().end(), | |||
| [](const onnx::AttributeProto &attr) { return attr.name() == "values"; }); | |||
| if (iter == onnx_node.attribute().end()) { | |||
| return RET_OK; | |||
| } | |||
| char *param_data = new (std::nothrow) char[data_count]; | |||
| if (param_data == nullptr) { | |||
| MS_LOG(ERROR) << "new char[] failed"; | |||
| return RET_MEMORY_FAILED; | |||
| } | |||
| if (memcpy_s(param_data, data_count, iter->s().data(), data_count) != EOK) { | |||
| MS_LOG(ERROR) << "memcpy data failed."; | |||
| delete[] param_data; | |||
| return RET_ERROR; | |||
| } | |||
| param_value->set_tensor_shape(shape); | |||
| param_value->set_format(schema::Format_NUM_OF_FORMAT); | |||
| param_value->set_tensor_type(kNumberTypeUInt8); | |||
| param_value->SetTensorData(param_data, data_count); | |||
| primitive_c->set_attr("const_data", param_value); | |||
| return RET_OK; | |||
| } | |||
| lite::PrimitiveC *OnnxGivenTensorFillParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node) { | |||
| MS_LOG(DEBUG) << "onnx GivenTensorFillParser"; | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive failed"; | |||
| return nullptr; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_Constant; | |||
| auto primitive_c = PrimitiveC::Create(primitive.release()); | |||
| std::vector<int64_t> shape_vector; | |||
| auto iter = std::find_if(onnx_node.attribute().begin(), onnx_node.attribute().end(), | |||
| [](const onnx::AttributeProto &attr) { return attr.name() == "shape"; }); | |||
| if (iter != onnx_node.attribute().end()) { | |||
| shape_vector.insert(shape_vector.begin(), iter->ints().begin(), iter->ints().end()); | |||
| } | |||
| std::vector<int> shape; | |||
| std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(shape), | |||
| [](const int64_t &val) { return static_cast<int32_t>(val); }); | |||
| if (onnx_node.op_type() == "Int8GivenIntTensorFill") { | |||
| if (ParseInt8GivenIntTensorFill(onnx_node, primitive_c, shape) != RET_OK) { | |||
| MS_LOG(ERROR) << "given tensor fill parse failed."; | |||
| return nullptr; | |||
| } | |||
| } else if (onnx_node.op_type() == "Int8GivenTensorFill") { | |||
| if (ParseInt8GivenTensorFill(onnx_node, primitive_c, shape) != RET_OK) { | |||
| MS_LOG(ERROR) << "given tensor fill parse failed."; | |||
| return nullptr; | |||
| } | |||
| } | |||
| return primitive_c; | |||
| } | |||
| OnnxNodeRegistrar g_onnxInt8GivenIntTensorFillParser("Int8GivenIntTensorFill", new OnnxGivenTensorFillParser()); | |||
| OnnxNodeRegistrar g_onnxInt8GivenTensorFillParser("Int8GivenTensorFill", new OnnxGivenTensorFillParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,39 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GIVEN_TENSOR_FILL_PARSER_H | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GIVEN_TENSOR_FILL_PARSER_H | |||
| #include <vector> | |||
| #include "tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class OnnxGivenTensorFillParser : public OnnxNodeParser { | |||
| public: | |||
| OnnxGivenTensorFillParser() : OnnxNodeParser("GivenTensorFill") {} | |||
| ~OnnxGivenTensorFillParser() override = default; | |||
| STATUS ParseInt8GivenIntTensorFill(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c, | |||
| const std::vector<int> &shape); | |||
| STATUS ParseInt8GivenTensorFill(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c, | |||
| const std::vector<int> &shape); | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GIVEN_TENSOR_FILL_PARSER_H | |||
| @@ -20,28 +20,23 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxIdentityParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| lite::PrimitiveC *OnnxIdentityParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node) { | |||
| MS_LOG(DEBUG) << "onnx IdentityParser"; | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::IdentityT> attr = std::make_unique<schema::IdentityT>(); | |||
| auto attr = std::make_unique<schema::IdentityT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| return nullptr; | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Identity; | |||
| op->primitive->value.value = attr.release(); | |||
| return RET_OK; | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive failed"; | |||
| return nullptr; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_Identity; | |||
| primitive->value.value = attr.release(); | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| OnnxNodeRegistrar g_onnxIdentityParser("Identity", new OnnxIdentityParser()); | |||
| @@ -27,7 +27,7 @@ class OnnxIdentityParser : public OnnxNodeParser { | |||
| OnnxIdentityParser() : OnnxNodeParser("Identity") {} | |||
| ~OnnxIdentityParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -19,23 +19,13 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxInstanceNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| lite::PrimitiveC *OnnxInstanceNormParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node) { | |||
| MS_LOG(DEBUG) << "onnx InstanceNormParser"; | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::InstanceNormT> attr = std::make_unique<schema::InstanceNormT>(); | |||
| auto attr = std::make_unique<schema::InstanceNormT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| return nullptr; | |||
| } | |||
| if (!onnx_node.attribute().empty()) { | |||
| @@ -44,10 +34,14 @@ STATUS OnnxInstanceNormParser::Parse(const onnx::GraphProto &onnx_graph, const o | |||
| attr->epsilon = onnx_node_attr.f(); | |||
| } | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_InstanceNorm; | |||
| op->primitive->value.value = attr.release(); | |||
| return RET_OK; | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive failed"; | |||
| return nullptr; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_InstanceNorm; | |||
| primitive->value.value = attr.release(); | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| OnnxNodeRegistrar g_onnxInstanceNormParser("InstanceNormalization", new OnnxInstanceNormParser()); | |||
| @@ -27,7 +27,7 @@ class OnnxInstanceNormParser : public OnnxNodeParser { | |||
| OnnxInstanceNormParser() : OnnxNodeParser("InstanceNorm") {} | |||
| ~OnnxInstanceNormParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -18,23 +18,13 @@ | |||
| #include <memory> | |||
| namespace mindspore::lite { | |||
| STATUS OnnxLpNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| lite::PrimitiveC *OnnxLpNormParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node) { | |||
| MS_LOG(DEBUG) << "onnx LpNormParser"; | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::LpNormalizationT> attr = std::make_unique<schema::LpNormalizationT>(); | |||
| auto attr = std::make_unique<schema::LpNormalizationT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| return nullptr; | |||
| } | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| @@ -45,10 +35,14 @@ STATUS OnnxLpNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N | |||
| attr->p = onnx_node_attr.i(); | |||
| } | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_LpNormalization; | |||
| op->primitive->value.value = attr.release(); | |||
| return RET_OK; | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive failed"; | |||
| return nullptr; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_LpNormalization; | |||
| primitive->value.value = attr.release(); | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| OnnxNodeRegistrar g_onnxLpNormParser("LpNormalization", new OnnxLpNormParser()); | |||
| @@ -27,7 +27,7 @@ class OnnxLpNormParser : public OnnxNodeParser { | |||
| OnnxLpNormParser() : OnnxNodeParser("LpNorm") {} | |||
| ~OnnxLpNormParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -18,22 +18,13 @@ | |||
| #include <memory> | |||
| namespace mindspore::lite { | |||
| STATUS OnnxLrnParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { | |||
| lite::PrimitiveC *OnnxLrnParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node) { | |||
| MS_LOG(DEBUG) << "onnx LrnParser"; | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::LocalResponseNormalizationT> attr = std::make_unique<schema::LocalResponseNormalizationT>(); | |||
| auto attr = std::make_unique<schema::LocalResponseNormalizationT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| return nullptr; | |||
| } | |||
| int32_t size = 0; | |||
| @@ -53,13 +44,18 @@ STATUS OnnxLrnParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Node | |||
| if (size == 0) { | |||
| MS_LOG(ERROR) << "Divide-by-zero error."; | |||
| return RET_ERROR; | |||
| return nullptr; | |||
| } | |||
| attr->alpha /= size; | |||
| op->primitive->value.type = schema::PrimitiveType_LocalResponseNormalization; | |||
| op->primitive->value.value = attr.release(); | |||
| return RET_OK; | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive failed"; | |||
| return nullptr; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_LocalResponseNormalization; | |||
| primitive->value.value = attr.release(); | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| OnnxNodeRegistrar g_onnxLrnxParser("Lrn", new OnnxLrnParser()); | |||
| @@ -27,7 +27,7 @@ class OnnxLrnParser : public OnnxNodeParser { | |||
| OnnxLrnParser() : OnnxNodeParser("Lrn") {} | |||
| ~OnnxLrnParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -19,22 +19,13 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxLstmParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { | |||
| lite::PrimitiveC *OnnxLstmParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node) { | |||
| MS_LOG(DEBUG) << "onnx LstmParser"; | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::LstmT> attr = std::make_unique<schema::LstmT>(); | |||
| auto attr = std::make_unique<schema::LstmT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| return nullptr; | |||
| } | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| @@ -44,9 +35,14 @@ STATUS OnnxLstmParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod | |||
| } | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Lstm; | |||
| op->primitive->value.value = attr.release(); | |||
| return RET_OK; | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive failed"; | |||
| return nullptr; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_Lstm; | |||
| primitive->value.value = attr.release(); | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| OnnxNodeRegistrar g_onnxLstmParser("LSTM", new OnnxLstmParser()); | |||
| @@ -27,7 +27,7 @@ class OnnxLstmParser : public OnnxNodeParser { | |||
| OnnxLstmParser() : OnnxNodeParser("LSTM") {} | |||
| ~OnnxLstmParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -19,23 +19,13 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxMatmulParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| lite::PrimitiveC *OnnxMatmulParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node) { | |||
| MS_LOG(DEBUG) << "onnx MatMulParser"; | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::MatMulT> attr = std::make_unique<schema::MatMulT>(); | |||
| auto attr = std::make_unique<schema::MatMulT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| return nullptr; | |||
| } | |||
| float alpha = 1.0f; | |||
| @@ -54,12 +44,17 @@ STATUS OnnxMatmulParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N | |||
| } | |||
| if (alpha != 1 || beta != 1) { | |||
| MS_LOG(ERROR) << "not support alpha * A * B + beta * C"; | |||
| return RET_ERROR; | |||
| return nullptr; | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_MatMul; | |||
| op->primitive->value.value = attr.release(); | |||
| return RET_OK; | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive failed"; | |||
| return nullptr; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_MatMul; | |||
| primitive->value.value = attr.release(); | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| OnnxNodeRegistrar g_onnxMatmulParser("MatMul", new OnnxMatmulParser()); | |||
| @@ -27,7 +27,7 @@ class OnnxMatmulParser : public OnnxNodeParser { | |||
| OnnxMatmulParser() : OnnxNodeParser("MatMul") {} | |||
| ~OnnxMatmulParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -26,75 +26,57 @@ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <set> | |||
| #include <map> | |||
| #include <unordered_map> | |||
| #include "securec/include/securec.h" | |||
| #include "tools/converter/model_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| #include "tools/converter/parser/onnx/onnx_tensor_parser.h" | |||
| #include "proto/onnx.pb.h" | |||
| #include "src/param_value_lite.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class OnnxModelParser : public ModelParser { | |||
| public: | |||
| OnnxModelParser(); | |||
| OnnxModelParser() = default; | |||
| virtual ~OnnxModelParser(); | |||
| ~OnnxModelParser() override = default; | |||
| // schema::MetaGraphT *ParseGraph(const onnx::GraphProto &graph, const QuantType &quantType = QuantType_QUANT_NONE); | |||
| int ParseGraph(schema::MetaGraphT *dst_graph, schema::SubGraphT *dst_sub_graph, const onnx::GraphProto &onnx_graph, | |||
| const QuantType &quantType); | |||
| MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file, | |||
| const QuantType &quant_type) override { | |||
| return nullptr; | |||
| } | |||
| FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file, | |||
| const QuantType &quant_type) override; | |||
| static TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type); | |||
| static STATUS CopyOnnxTensorData(const onnx::TensorProto &onnx_const_value, | |||
| const ParamValueLitePtr ¶m_value_lite); | |||
| private: | |||
| schema::MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file, | |||
| const QuantType &quant_type = QuantType_QUANT_NONE) override; | |||
| std::vector<int32_t> GetDimsFromOnnxValue(const onnx::ValueInfoProto &onnx_value); | |||
| STATUS SetGraphConstTensor(const onnx::GraphProto &onnx_graph); | |||
| STATUS SetGraphInputTensor(const onnx::GraphProto &onnx_graph, schema::SubGraphT *graph); | |||
| STATUS SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, schema::SubGraphT *graph); | |||
| STATUS AddValueInfo(const onnx::ValueInfoProto &proto, const std::string &name, const Category &type, int *index); | |||
| STATUS AddTensorProto(const onnx::TensorProto &proto, const std::string &name, const Category &type, int *index); | |||
| STATUS ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *dst_op, const QuantType &quantType, schema::MetaGraphT *dst_graph); | |||
| void ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| schema::SubGraphT *sub_graph, schema::MetaGraphT *graph, const QuantType &quant_type); | |||
| STATUS ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node); | |||
| STATUS ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| const string &onnx_op_type, schema::CNodeT *dst_op); | |||
| void SetOpQuantParams(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *dst_op, | |||
| schema::TensorT *dst_tensor); | |||
| STATUS SetOpInputIndex(const std::vector<string> &node_inputs, schema::CNodeT *dst_op, | |||
| const onnx::NodeProto &onnx_node); | |||
| STATUS SetOpOutputIndex(const std::vector<string> &node_outputs, schema::CNodeT *dst_op); | |||
| STATUS CopyOnnxTensorData(const onnx::TensorProto &onnx_init_value, schema::TensorT *tensor); | |||
| STATUS SetAllTensors(schema::MetaGraphT *graphDef); | |||
| void FindGraphInputAndConst(const onnx::GraphProto &onnx_graph); | |||
| STATUS ParseSubgraph(schema::CNodeT *dst_op, const onnx::NodeProto &onnx_node, const QuantType &quantType, | |||
| schema::MetaGraphT *dst_graph); | |||
| private: | |||
| std::vector<std::string> graphInputNames; | |||
| std::vector<std::string> graphConstNames; | |||
| int subGraphNum = 0; | |||
| STATUS InitOriginModel(const std::string &model_file); | |||
| STATUS ConvertNodes(); | |||
| STATUS ConvertConstTensors(); | |||
| STATUS ConvertGraphInputs(); | |||
| STATUS ConvertGraphOutputs(); | |||
| STATUS BuildReturnNode(const std::vector<AnfNodePtr> &return_inputs); | |||
| STATUS BuildParameterNode(const ParameterPtr ¶meter_node, const onnx::TensorProto &tensor); | |||
| STATUS BuildParameterNodeForQuantParam(void *data, const std::string &name, TypeId type); | |||
| STATUS BuildCNode(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c); | |||
| STATUS BuildOpOutputs(const onnx::NodeProto &onnx_node, const CNodePtr &cnode); | |||
| STATUS ConvertSpecialOnnxNode(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c); | |||
| STATUS ConvertOnnxGemmNode(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c); | |||
| STATUS BuildCNodeForGemm(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c, const std::string &name); | |||
| STATUS ConvertOpQuantParams(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c); | |||
| STATUS ParseQuantParam(const onnx::NodeProto &onnx_node); | |||
| STATUS SetTensorQuantParam(const std::string &tensor_name, std::vector<QuantParamT> *quant_params); | |||
| STATUS SetTensorQuantParamFromNode(const std::string &tensor_name, std::vector<QuantParamT> *quant_params); | |||
| STATUS CopyTensorQuantParam(const std::string &tensor_name, QuantParamT *quant_param, bool scale_or_not); | |||
| bool IsSpecialOnnxNode(const onnx::NodeProto &onnx_node); | |||
| onnx::ModelProto onnx_model_; | |||
| onnx::GraphProto onnx_graph_; | |||
| std::unordered_map<std::string, AnfNodePtr> nodes_; | |||
| FuncGraphPtr func_graph_ptr_ = nullptr; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -15,7 +15,9 @@ | |||
| */ | |||
| #include "tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include <algorithm> | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "tools/converter/parser/onnx/onnx_model_parser.h" | |||
| namespace mindspore { | |||
| @@ -20,6 +20,7 @@ | |||
| #include <string> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "src/ops/primitive_c.h" | |||
| #include "google/protobuf/message.h" | |||
| #include "proto/onnx.pb.h" | |||
| #include "include/errorcode.h" | |||
| @@ -34,7 +35,8 @@ class OnnxNodeParser { | |||
| virtual ~OnnxNodeParser() = default; | |||
| virtual STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) = 0; | |||
| virtual lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node) = 0; | |||
| static STATUS GetTensorDataFromOnnx(const onnx::TensorProto &onnx_tensor, std::vector<float> *value, int *type); | |||
| @@ -19,23 +19,13 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxNonMaxSuppressionParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| lite::PrimitiveC *OnnxNonMaxSuppressionParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node) { | |||
| MS_LOG(DEBUG) << "onnx EluParser"; | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::NonMaxSuppressionT> attr = std::make_unique<schema::NonMaxSuppressionT>(); | |||
| auto attr = std::make_unique<schema::NonMaxSuppressionT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| return nullptr; | |||
| } | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| @@ -47,9 +37,14 @@ STATUS OnnxNonMaxSuppressionParser::Parse(const onnx::GraphProto &onnx_graph, co | |||
| } | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_NonMaxSuppression; | |||
| op->primitive->value.value = attr.release(); | |||
| return RET_OK; | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive failed"; | |||
| return nullptr; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_NonMaxSuppression; | |||
| primitive->value.value = attr.release(); | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| OnnxNodeRegistrar g_onnxNonMaxSuppressionParser("NonMaxSuppression", new OnnxNonMaxSuppressionParser()); | |||
| @@ -27,7 +27,7 @@ class OnnxNonMaxSuppressionParser : public OnnxNodeParser { | |||
| OnnxNonMaxSuppressionParser() : OnnxNodeParser("NonMaxSuppression") {} | |||
| ~OnnxNonMaxSuppressionParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -19,23 +19,13 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxOneHotParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| lite::PrimitiveC *OnnxOneHotParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node) { | |||
| MS_LOG(DEBUG) << "onnx OneHotParser"; | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::OneHotT> attr = std::make_unique<schema::OneHotT>(); | |||
| auto attr = std::make_unique<schema::OneHotT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| return nullptr; | |||
| } | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| @@ -45,9 +35,14 @@ STATUS OnnxOneHotParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N | |||
| } | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_OneHot; | |||
| op->primitive->value.value = attr.release(); | |||
| return RET_OK; | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive failed"; | |||
| return nullptr; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_OneHot; | |||
| primitive->value.value = attr.release(); | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| OnnxNodeRegistrar g_onnxOneHotParser("OneHot", new OnnxOneHotParser()); | |||
| @@ -27,7 +27,7 @@ class OnnxOneHotParser : public OnnxNodeParser { | |||
| OnnxOneHotParser() : OnnxNodeParser("OneHot") {} | |||
| ~OnnxOneHotParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -19,22 +19,13 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxPadParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { | |||
| lite::PrimitiveC *OnnxPadParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node) { | |||
| MS_LOG(DEBUG) << "onnx PadParser"; | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::PadT> attr = std::make_unique<schema::PadT>(); | |||
| auto attr = std::make_unique<schema::PadT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| return nullptr; | |||
| } | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| @@ -58,9 +49,14 @@ STATUS OnnxPadParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Node | |||
| } | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Pad; | |||
| op->primitive->value.value = attr.release(); | |||
| return RET_OK; | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive failed"; | |||
| return nullptr; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_Pad; | |||
| primitive->value.value = attr.release(); | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| OnnxNodeRegistrar g_onnxPadParser("Pad", new OnnxPadParser()); | |||
| @@ -27,7 +27,7 @@ class OnnxPadParser : public OnnxNodeParser { | |||
| OnnxPadParser() : OnnxNodeParser("Pad") {} | |||
| ~OnnxPadParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -19,22 +19,13 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxPoolParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { | |||
| lite::PrimitiveC *OnnxPoolParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node) { | |||
| MS_LOG(DEBUG) << "onnx PoolParser"; | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::PoolingT> attr = std::make_unique<schema::PoolingT>(); | |||
| auto attr = std::make_unique<schema::PoolingT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| return nullptr; | |||
| } | |||
| attr->format = schema::Format::Format_NCHW; | |||
| @@ -56,7 +47,7 @@ STATUS OnnxPoolParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod | |||
| attr->global = false; | |||
| } else { | |||
| MS_LOG(ERROR) << "Pooling param`s PoolingMode is not MAX either AVE. MindSpore support MAX and AVE only."; | |||
| return RET_ERROR; | |||
| return nullptr; | |||
| } | |||
| attr->roundMode = schema::RoundMode_FLOOR; | |||
| @@ -101,13 +92,18 @@ STATUS OnnxPoolParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod | |||
| } | |||
| if (attribute_name == "dilations") { | |||
| MS_LOG(ERROR) << "pooling op not support dilations now"; | |||
| return RET_ERROR; | |||
| return nullptr; | |||
| } | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Pooling; | |||
| op->primitive->value.value = attr.release(); | |||
| return RET_OK; | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive failed"; | |||
| return nullptr; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_Pooling; | |||
| primitive->value.value = attr.release(); | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| OnnxNodeRegistrar g_onnxMaxPoolParser("MaxPool", new OnnxPoolParser()); | |||
| @@ -27,7 +27,7 @@ class OnnxPoolParser : public OnnxNodeParser { | |||
| OnnxPoolParser() : OnnxNodeParser("Pool") {} | |||
| ~OnnxPoolParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -19,23 +19,13 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxQuantizeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| lite::PrimitiveC *OnnxQuantizeParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node) { | |||
| MS_LOG(DEBUG) << "onnx QuantizeDequantizeParser"; | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::QuantDTypeCastT> attr = std::make_unique<schema::QuantDTypeCastT>(); | |||
| auto attr = std::make_unique<schema::QuantDTypeCastT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed."; | |||
| return RET_NULL_PTR; | |||
| return nullptr; | |||
| } | |||
| if (onnx_node.op_type() == "Int8Quantize") { | |||
| attr->srcT = kNumberTypeFloat32; | |||
| @@ -45,11 +35,16 @@ STATUS OnnxQuantizeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx: | |||
| attr->dstT = kNumberTypeFloat32; | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported nodeType: " << onnx_node.op_type().c_str(); | |||
| return RET_ERROR; | |||
| return nullptr; | |||
| } | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive failed"; | |||
| return nullptr; | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_QuantDTypeCast; | |||
| op->primitive->value.value = attr.release(); | |||
| return RET_OK; | |||
| primitive->value.type = schema::PrimitiveType_QuantDTypeCast; | |||
| primitive->value.value = attr.release(); | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| OnnxNodeRegistrar g_onnxInt8QuantizeParser("Int8Quantize", new OnnxQuantizeParser()); | |||
| @@ -27,7 +27,7 @@ class OnnxQuantizeParser : public OnnxNodeParser { | |||
| OnnxQuantizeParser() : OnnxNodeParser("Quantize") {} | |||
| ~OnnxQuantizeParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -19,28 +19,23 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxRangeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| lite::PrimitiveC *OnnxRangeParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node) { | |||
| MS_LOG(DEBUG) << "onnx RangeParser"; | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::RangeT> attr = std::make_unique<schema::RangeT>(); | |||
| auto attr = std::make_unique<schema::RangeT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| return nullptr; | |||
| } | |||
| attr->dType = 0; | |||
| op->primitive->value.type = schema::PrimitiveType_Range; | |||
| op->primitive->value.value = attr.release(); | |||
| return RET_OK; | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive failed"; | |||
| return nullptr; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_Range; | |||
| primitive->value.value = attr.release(); | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| OnnxNodeRegistrar g_onnxRangeParser("Range", new OnnxRangeParser()); | |||
| @@ -27,7 +27,7 @@ class OnnxRangeParser : public OnnxNodeParser { | |||
| OnnxRangeParser() : OnnxNodeParser("Range") {} | |||
| ~OnnxRangeParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -19,23 +19,13 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxReduceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| lite::PrimitiveC *OnnxReduceParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node) { | |||
| MS_LOG(DEBUG) << "onnx ReduceParser"; | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::ReduceT> attr = std::make_unique<schema::ReduceT>(); | |||
| auto attr = std::make_unique<schema::ReduceT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| return nullptr; | |||
| } | |||
| attr->keepDims = 1; | |||
| @@ -65,12 +55,17 @@ STATUS OnnxReduceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N | |||
| attr->mode = schema::ReduceMode_ReduceSumSquare; | |||
| } else { | |||
| MS_LOG(ERROR) << "unsupported type"; | |||
| return RET_ERROR; | |||
| return nullptr; | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Reduce; | |||
| op->primitive->value.value = attr.release(); | |||
| return RET_OK; | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive failed"; | |||
| return nullptr; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_Reduce; | |||
| primitive->value.value = attr.release(); | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| OnnxNodeRegistrar g_onnxReduceMeanParser("ReduceMean", new OnnxReduceParser()); | |||
| @@ -27,7 +27,7 @@ class OnnxReduceParser : public OnnxNodeParser { | |||
| OnnxReduceParser() : OnnxNodeParser("Reduce") {} | |||
| ~OnnxReduceParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -21,22 +21,13 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { | |||
| lite::PrimitiveC *OnnxReluParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node) { | |||
| MS_LOG(DEBUG) << "onnx ReluParser"; | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::ActivationT> attr = std::make_unique<schema::ActivationT>(); | |||
| auto attr = std::make_unique<schema::ActivationT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| return nullptr; | |||
| } | |||
| const auto &relu_type = onnx_node.op_type(); | |||
| @@ -54,29 +45,24 @@ STATUS OnnxReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod | |||
| } | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Activation; | |||
| op->primitive->value.value = attr.release(); | |||
| return RET_OK; | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive failed"; | |||
| return nullptr; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_Activation; | |||
| primitive->value.value = attr.release(); | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| STATUS OnnxPReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| lite::PrimitiveC *OnnxPReluParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node) { | |||
| MS_LOG(DEBUG) << "onnx PReluParser"; | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| if (onnx_node.input_size() != 2) { | |||
| MS_LOG(ERROR) << "input num should be 2"; | |||
| return RET_ERROR; | |||
| return nullptr; | |||
| } | |||
| std::unique_ptr<schema::PReLUT> attr = std::make_unique<schema::PReLUT>(); | |||
| auto attr = std::make_unique<schema::PReLUT>(); | |||
| std::vector<onnx::TensorProto> params; | |||
| const auto &input_name = onnx_node.input(1); | |||
| for (const auto &it : onnx_graph.initializer()) { | |||
| @@ -90,7 +76,7 @@ STATUS OnnxPReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No | |||
| const onnx::TensorProto *slope = ¶ms[0]; | |||
| if (slope == nullptr) { | |||
| MS_LOG(ERROR) << "input error: params[0] is null"; | |||
| return RET_ERROR; | |||
| return nullptr; | |||
| } | |||
| const auto slope_raw_data = reinterpret_cast<const float *>(slope->raw_data().data()); | |||
| const int64_t slope_size = slope->raw_data().size() / sizeof(float); | |||
| @@ -102,16 +88,21 @@ STATUS OnnxPReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No | |||
| attr->channelShared = false; | |||
| if (memcpy_s(attr->slope.data(), slope_size * sizeof(float), slope_raw_data, slope_size * sizeof(float)) != EOK) { | |||
| MS_LOG(ERROR) << "memcpy_s failed"; | |||
| return RET_ERROR; | |||
| return nullptr; | |||
| } | |||
| } | |||
| } else { | |||
| MS_LOG(WARNING) << "The slope pf prelu is null, which may cause errors."; | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_PReLU; | |||
| op->primitive->value.value = attr.release(); | |||
| return RET_OK; | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive failed"; | |||
| return nullptr; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_PReLU; | |||
| primitive->value.value = attr.release(); | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| OnnxNodeRegistrar g_onnxReluParser("Relu", new OnnxReluParser()); | |||
| @@ -27,7 +27,7 @@ class OnnxReluParser : public OnnxNodeParser { | |||
| OnnxReluParser() : OnnxNodeParser("Relu") {} | |||
| ~OnnxReluParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| class OnnxPReluParser : public OnnxNodeParser { | |||
| @@ -35,7 +35,7 @@ class OnnxPReluParser : public OnnxNodeParser { | |||
| OnnxPReluParser() : OnnxNodeParser("Prelu") {} | |||
| ~OnnxPReluParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -20,23 +20,13 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxReshapeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| lite::PrimitiveC *OnnxReshapeParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node) { | |||
| MS_LOG(DEBUG) << "onnx ReshapeParser"; | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::ReshapeT> attr = std::make_unique<schema::ReshapeT>(); | |||
| auto attr = std::make_unique<schema::ReshapeT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| return nullptr; | |||
| } | |||
| attr->format = schema::Format_NCHW; | |||
| @@ -51,28 +41,17 @@ STATUS OnnxReshapeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx:: | |||
| } | |||
| } | |||
| } | |||
| } else { | |||
| onnx::TensorProto input_shape; | |||
| const auto &shape_name = onnx_node.input(1); | |||
| for (const auto &it : onnx_graph.initializer()) { | |||
| if (it.name() == shape_name) { | |||
| input_shape = it; | |||
| break; | |||
| } | |||
| } | |||
| if (input_shape.int64_data_size() == 0) { | |||
| MS_LOG(INFO) << "shape maybe from another op other than const initializer"; | |||
| } else { | |||
| for (int i = 0; i < input_shape.int64_data_size(); ++i) { | |||
| shape.push_back(input_shape.int64_data(i)); | |||
| } | |||
| } | |||
| } | |||
| attr->shape = shape; | |||
| op->primitive->value.type = schema::PrimitiveType_Reshape; | |||
| op->primitive->value.value = attr.release(); | |||
| return RET_OK; | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive failed"; | |||
| return nullptr; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_Reshape; | |||
| primitive->value.value = attr.release(); | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| OnnxNodeRegistrar g_onnxReshapeParser("Reshape", new OnnxReshapeParser()); | |||
| @@ -27,7 +27,7 @@ class OnnxReshapeParser : public OnnxNodeParser { | |||
| OnnxReshapeParser() : OnnxNodeParser("Reshape") {} | |||
| ~OnnxReshapeParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -22,23 +22,13 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxResizeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| lite::PrimitiveC *OnnxResizeParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node) { | |||
| MS_LOG(DEBUG) << "onnx ResizeParser"; | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::ResizeT> attr = std::make_unique<schema::ResizeT>(); | |||
| auto attr = std::make_unique<schema::ResizeT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| return nullptr; | |||
| } | |||
| attr->format = schema::Format_NCHW; | |||
| @@ -85,9 +75,14 @@ STATUS OnnxResizeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N | |||
| } | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Resize; | |||
| op->primitive->value.value = attr.release(); | |||
| return RET_OK; | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive failed"; | |||
| return nullptr; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_Resize; | |||
| primitive->value.value = attr.release(); | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| OnnxNodeRegistrar g_onnxResizeParser("Resize", new OnnxResizeParser()); | |||
| @@ -27,7 +27,7 @@ class OnnxResizeParser : public OnnxNodeParser { | |||
| OnnxResizeParser() : OnnxNodeParser("Resize") {} | |||
| ~OnnxResizeParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -19,28 +19,23 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxShapeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| lite::PrimitiveC *OnnxShapeParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node) { | |||
| MS_LOG(DEBUG) << "onnx ShapeParser"; | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::ShapeT> attr = std::make_unique<schema::ShapeT>(); | |||
| auto attr = std::make_unique<schema::ShapeT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| return nullptr; | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Shape; | |||
| op->primitive->value.value = attr.release(); | |||
| return RET_OK; | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive failed"; | |||
| return nullptr; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_Shape; | |||
| primitive->value.value = attr.release(); | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| OnnxNodeRegistrar g_onnxShapeParser("Shape", new OnnxShapeParser()); | |||
| @@ -27,7 +27,7 @@ class OnnxShapeParser : public OnnxNodeParser { | |||
| OnnxShapeParser() : OnnxNodeParser("Shape") {} | |||
| ~OnnxShapeParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -19,30 +19,25 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxSigmoidParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| lite::PrimitiveC *OnnxSigmoidParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node) { | |||
| MS_LOG(DEBUG) << "onnx SigmoidParser"; | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::ActivationT> attr = std::make_unique<schema::ActivationT>(); | |||
| auto attr = std::make_unique<schema::ActivationT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| return nullptr; | |||
| } | |||
| attr->type = schema::ActivationType_SIGMOID; | |||
| op->primitive->value.type = schema::PrimitiveType_Activation; | |||
| op->primitive->value.value = attr.release(); | |||
| return RET_OK; | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive failed"; | |||
| return nullptr; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_Activation; | |||
| primitive->value.value = attr.release(); | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| OnnxNodeRegistrar g_onnxSigmoodParser("Sigmoid", new OnnxSigmoidParser()); | |||
| @@ -27,7 +27,7 @@ class OnnxSigmoidParser : public OnnxNodeParser { | |||
| OnnxSigmoidParser() : OnnxNodeParser("Sigmoid") {} | |||
| ~OnnxSigmoidParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -23,77 +23,13 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxSliceParser::InsertTensor(const std::vector<int> &onnx_val, const std::string &name, | |||
| onnx::NodeProto *onnx_node) { | |||
| std::unique_ptr<schema::TensorT> tensor = std::make_unique<schema::TensorT>(); | |||
| if (tensor == nullptr) { | |||
| MS_LOG(ERROR) << "new tensor failed"; | |||
| return RET_ERROR; | |||
| } | |||
| tensor->dataType = mindspore::kNumberTypeInt32; | |||
| tensor->dims.push_back(onnx_val.size()); | |||
| tensor->format = schema::Format::Format_NCHW; | |||
| tensor->nodeType = schema::NodeType::NodeType_ValueNode; | |||
| int data_size = sizeof(int32_t) * onnx_val.size(); | |||
| tensor->data.resize(data_size); | |||
| if (data_size != 0 && | |||
| memcpy_s(static_cast<void *>(tensor->data.data()), data_size, onnx_val.data(), data_size) != EOK) { | |||
| MS_LOG(ERROR) << "memcpy_s failed"; | |||
| return RET_ERROR; | |||
| } | |||
| int tensor_num = OnnxTensorParser::GetInstance()->GetTensorCache()->GetCachedTensor().size(); | |||
| std::string tensor_name = name + std::to_string(tensor_num); | |||
| OnnxTensorParser::GetInstance()->GetTensorCache()->AddTensor(tensor_name, tensor.release(), GRAPH_INPUT); | |||
| onnx_node->add_input(tensor_name); | |||
| return RET_OK; | |||
| } | |||
| STATUS OnnxSliceParser::GetInputTensor(std::vector<int> *onnx_val, const std::string &name) { | |||
| if (onnx_val == nullptr) { | |||
| MS_LOG(ERROR) << "input vector is nullptr."; | |||
| return RET_ERROR; | |||
| } | |||
| if (OnnxTensorParser::GetInstance() == nullptr || OnnxTensorParser::GetInstance()->GetTensorCache() == nullptr) { | |||
| MS_LOG(ERROR) << "cannot get tensorcache."; | |||
| return RET_ERROR; | |||
| } | |||
| int index = OnnxTensorParser::GetInstance()->GetTensorCache()->FindTensor(name); | |||
| if (index == -1) { | |||
| MS_LOG(ERROR) << "can not find node: " << name; | |||
| return RET_ERROR; | |||
| } | |||
| auto input_tensor = OnnxTensorParser::GetInstance()->GetTensorCache()->GetCachedTensor()[index]; | |||
| if (input_tensor->data.empty()) { | |||
| MS_LOG(DEBUG) << "data is empty."; | |||
| return RET_NO_CHANGE; | |||
| } | |||
| int data_num = std::accumulate(input_tensor->dims.begin(), input_tensor->dims.end(), 1, std::multiplies<int>()); | |||
| onnx_val->resize(data_num); | |||
| if (memcpy_s(onnx_val->data(), data_num * sizeof(int32_t), input_tensor->data.data(), data_num * sizeof(int32_t)) != | |||
| EOK) { | |||
| MS_LOG(ERROR) << "memcpy_s failed"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| lite::PrimitiveC *OnnxSliceParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node) { | |||
| MS_LOG(DEBUG) << "onnx SliceParser"; | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::StridedSliceT> attr = std::make_unique<schema::StridedSliceT>(); | |||
| auto attr = std::make_unique<schema::StridedSliceT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| return nullptr; | |||
| } | |||
| std::vector<int> starts; | |||
| @@ -128,36 +64,17 @@ STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No | |||
| } | |||
| } | |||
| } | |||
| int status = RET_OK; | |||
| switch (onnx_node.input_size()) { | |||
| case 5: { | |||
| if (steps.empty()) { | |||
| status = GetInputTensor(&steps, onnx_node.input(4)); | |||
| } | |||
| } | |||
| case 4: { | |||
| if (status != RET_ERROR && axes.empty()) { | |||
| status = GetInputTensor(&axes, onnx_node.input(3)); | |||
| } | |||
| } | |||
| case 3: { | |||
| if (status != RET_ERROR && ends.empty()) { | |||
| status = GetInputTensor(&ends, onnx_node.input(2)); | |||
| } | |||
| } | |||
| case 2: { | |||
| if (status != RET_ERROR && starts.empty()) { | |||
| status = GetInputTensor(&starts, onnx_node.input(1)); | |||
| } | |||
| } | |||
| default: { | |||
| if (status == RET_ERROR) { | |||
| MS_LOG(ERROR) << "onnx slice inputs are invalid."; | |||
| return RET_INPUT_TENSOR_ERROR; | |||
| } | |||
| } | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive failed"; | |||
| return nullptr; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_StridedSlice; | |||
| primitive->value.value = attr.release(); | |||
| auto primitive_c = PrimitiveC::Create(primitive.release()); | |||
| if (starts.empty()) { | |||
| return primitive_c; | |||
| } | |||
| if (axes.empty()) { | |||
| for (size_t i = 0; i < starts.size(); ++i) { | |||
| axes.push_back(i); | |||
| @@ -166,42 +83,11 @@ STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No | |||
| if (steps.empty()) { | |||
| steps.assign(starts.size(), 1); | |||
| } | |||
| onnx::NodeProto *slice_node = nullptr; | |||
| for (auto &node : onnx_graph.node()) { | |||
| if (&node == &onnx_node) { | |||
| slice_node = const_cast<onnx::NodeProto *>(&node); | |||
| } | |||
| } | |||
| int insert_num = 5 - onnx_node.input_size(); | |||
| switch (insert_num) { | |||
| case 4: { | |||
| std::string name = "slice/starts/"; | |||
| status = InsertTensor(starts, name, slice_node); | |||
| } | |||
| case 3: | |||
| if (status == RET_OK) { | |||
| std::string name = "slice/ends/"; | |||
| status = InsertTensor(ends, name, slice_node); | |||
| } | |||
| case 2: | |||
| if (status == RET_OK) { | |||
| std::string name = "slice/axes/"; | |||
| status = InsertTensor(axes, name, slice_node); | |||
| } | |||
| case 1: | |||
| if (status == RET_OK) { | |||
| std::string name = "slice/steps/"; | |||
| status = InsertTensor(steps, name, slice_node); | |||
| } | |||
| default: | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "onnx slice insert tensor failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_StridedSlice; | |||
| op->primitive->value.value = attr.release(); | |||
| return RET_OK; | |||
| primitive_c->set_attr("starts", MakeValue<std::vector<int>>(starts)); | |||
| primitive_c->set_attr("ends", MakeValue<std::vector<int>>(ends)); | |||
| primitive_c->set_attr("axes", MakeValue<std::vector<int>>(axes)); | |||
| primitive_c->set_attr("steps", MakeValue<std::vector<int>>(steps)); | |||
| return primitive_c; | |||
| } | |||
| OnnxNodeRegistrar g_onnxSliceParser("Slice", new OnnxSliceParser()); | |||
| @@ -21,7 +21,6 @@ | |||
| #include <string> | |||
| #include "tools/converter/parser/onnx/onnx_node_parser.h" | |||
| #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||
| #include "tools/converter/parser/onnx/onnx_tensor_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -30,9 +29,7 @@ class OnnxSliceParser : public OnnxNodeParser { | |||
| OnnxSliceParser() : OnnxNodeParser("Slice") {} | |||
| ~OnnxSliceParser() override = default; | |||
| STATUS InsertTensor(const std::vector<int> &onnx_val, const std::string &name, onnx::NodeProto *onnx_node); | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| STATUS GetInputTensor(std::vector<int> *onnx_val, const std::string &name); | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -19,23 +19,13 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxSoftMaxParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| lite::PrimitiveC *OnnxSoftMaxParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node) { | |||
| MS_LOG(DEBUG) << "onnx SoftMaxParser"; | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::SoftMaxT> attr = std::make_unique<schema::SoftMaxT>(); | |||
| auto attr = std::make_unique<schema::SoftMaxT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| return nullptr; | |||
| } | |||
| bool axis_is_def = true; | |||
| @@ -53,9 +43,14 @@ STATUS OnnxSoftMaxParser::Parse(const onnx::GraphProto &onnx_graph, const onnx:: | |||
| attr->axis = 1; | |||
| } | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_SoftMax; | |||
| op->primitive->value.value = attr.release(); | |||
| return RET_OK; | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive failed"; | |||
| return nullptr; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_SoftMax; | |||
| primitive->value.value = attr.release(); | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| OnnxNodeRegistrar g_onnxSoftMaxParser("Softmax", new OnnxSoftMaxParser()); | |||
| @@ -27,7 +27,7 @@ class OnnxSoftMaxParser : public OnnxNodeParser { | |||
| OnnxSoftMaxParser() : OnnxNodeParser("Softmax") {} | |||
| ~OnnxSoftMaxParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -19,23 +19,13 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxSpaceToDepthParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| lite::PrimitiveC *OnnxSpaceToDepthParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node) { | |||
| MS_LOG(DEBUG) << "onnx SpaceToDepthParser"; | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::SpaceToDepthT> attr = std::make_unique<schema::SpaceToDepthT>(); | |||
| auto attr = std::make_unique<schema::SpaceToDepthT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| return nullptr; | |||
| } | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| @@ -45,9 +35,14 @@ STATUS OnnxSpaceToDepthParser::Parse(const onnx::GraphProto &onnx_graph, const o | |||
| } | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_SpaceToDepth; | |||
| op->primitive->value.value = attr.release(); | |||
| return RET_OK; | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive failed"; | |||
| return nullptr; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_SpaceToDepth; | |||
| primitive->value.value = attr.release(); | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| OnnxNodeRegistrar g_onnxSpaceToDepthParser("SpaceToDepth", new OnnxSpaceToDepthParser()); | |||
| @@ -27,7 +27,7 @@ class OnnxSpaceToDepthParser : public OnnxNodeParser { | |||
| OnnxSpaceToDepthParser() : OnnxNodeParser("SpaceToDepth") {} | |||
| ~OnnxSpaceToDepthParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -19,23 +19,13 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxSplitParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| lite::PrimitiveC *OnnxSplitParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node) { | |||
| MS_LOG(DEBUG) << "onnx SplitParser"; | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::SplitT> attr = std::make_unique<schema::SplitT>(); | |||
| auto attr = std::make_unique<schema::SplitT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| return nullptr; | |||
| } | |||
| attr->splitDim = 0; | |||
| @@ -51,9 +41,14 @@ STATUS OnnxSplitParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No | |||
| } | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Split; | |||
| op->primitive->value.value = attr.release(); | |||
| return RET_OK; | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive failed"; | |||
| return nullptr; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_Split; | |||
| primitive->value.value = attr.release(); | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| OnnxNodeRegistrar g_onnxSplitParser("Split", new OnnxSplitParser()); | |||
| @@ -27,7 +27,7 @@ class OnnxSplitParser : public OnnxNodeParser { | |||
| OnnxSplitParser() : OnnxNodeParser("Split") {} | |||
| ~OnnxSplitParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -19,23 +19,13 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxSqueezeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, | |||
| schema::CNodeT *op) { | |||
| lite::PrimitiveC *OnnxSqueezeParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node) { | |||
| MS_LOG(DEBUG) << "onnx SqueezeParser"; | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::SqueezeT> attr = std::make_unique<schema::SqueezeT>(); | |||
| auto attr = std::make_unique<schema::SqueezeT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| return nullptr; | |||
| } | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| @@ -47,9 +37,14 @@ STATUS OnnxSqueezeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx:: | |||
| } | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Squeeze; | |||
| op->primitive->value.value = attr.release(); | |||
| return RET_OK; | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive failed"; | |||
| return nullptr; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_Squeeze; | |||
| primitive->value.value = attr.release(); | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| OnnxNodeRegistrar g_onnxSqueezeParser("Squeeze", new OnnxSqueezeParser()); | |||
| @@ -27,7 +27,7 @@ class OnnxSqueezeParser : public OnnxNodeParser { | |||
| OnnxSqueezeParser() : OnnxNodeParser("Squeeze") {} | |||
| ~OnnxSqueezeParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -20,26 +20,22 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxTileParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { | |||
| lite::PrimitiveC *OnnxTileParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node) { | |||
| MS_LOG(DEBUG) << "onnx TileParser"; | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::TileT> attr = std::make_unique<schema::TileT>(); | |||
| auto attr = std::make_unique<schema::TileT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| return nullptr; | |||
| } | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive failed"; | |||
| return nullptr; | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Tile; | |||
| op->primitive->value.value = attr.release(); | |||
| return RET_OK; | |||
| primitive->value.type = schema::PrimitiveType_Tile; | |||
| primitive->value.value = attr.release(); | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| OnnxNodeRegistrar g_onnxTileParser("Tile", new OnnxTileParser()); | |||
| @@ -27,7 +27,7 @@ class OnnxTileParser : public OnnxNodeParser { | |||
| OnnxTileParser() : OnnxNodeParser("Tile") {} | |||
| ~OnnxTileParser() override = default; | |||
| STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; | |||
| lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -19,22 +19,13 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS OnnxTopkParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { | |||
| lite::PrimitiveC *OnnxTopkParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node) { | |||
| MS_LOG(DEBUG) << "onnx TopKParser"; | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::TopKT> attr = std::make_unique<schema::TopKT>(); | |||
| auto attr = std::make_unique<schema::TopKT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| return nullptr; | |||
| } | |||
| for (const auto &onnx_node_attr : onnx_node.attribute()) { | |||
| @@ -44,9 +35,14 @@ STATUS OnnxTopkParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod | |||
| } | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_TopK; | |||
| op->primitive->value.value = attr.release(); | |||
| return RET_OK; | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "new primitive failed"; | |||
| return nullptr; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_TopK; | |||
| primitive->value.value = attr.release(); | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||
| OnnxNodeRegistrar g_onnxTopkParser("TopK", new OnnxTopkParser()); | |||